From 86189e26fe276197f31d2d9ec23312d811615192 Mon Sep 17 00:00:00 2001 From: elburrito Date: Thu, 20 Apr 2023 13:47:40 -0700 Subject: [PATCH] Change SpendTokenData to use PublicMetadata instead of PublicMetadataInfo, and update its users. PiperOrigin-RevId: 525839307 git-subtree-dir: net/third_party/quiche/src git-subtree-split: 02c69dd28eef7ef2618782e8d54d53c14ae64382 --- .bazelrc | 11 + BUILD.bazel | 7 + CONTRIBUTING.md | 35 + LICENSE | 27 + README.md | 28 + WHITESPACE | 3 + WORKSPACE.bazel | 87 + build/BUILD.bazel | 7 + build/source_list.bzl | 1637 ++ build/source_list.gni | 1637 ++ build/source_list.json | 1637 ++ build/test.bzl | 30 + build/zlib.BUILD | 25 + depstool/deps/parse.go | 123 + depstool/deps/parse_test.go | 103 + depstool/depstool.go | 99 + depstool/go.mod | 7 + depstool/go.sum | 2 + quiche/BUILD.bazel | 555 + quiche/balsa/balsa_enums.cc | 117 + quiche/balsa/balsa_enums.h | 130 + quiche/balsa/balsa_frame.cc | 1338 ++ quiche/balsa/balsa_frame.h | 322 + quiche/balsa/balsa_frame_test.cc | 4011 ++++ quiche/balsa/balsa_headers.cc | 1157 + quiche/balsa/balsa_headers.h | 1468 ++ quiche/balsa/balsa_headers_test.cc | 3722 ++++ quiche/balsa/balsa_visitor_interface.h | 165 + quiche/balsa/framer_interface.h | 24 + quiche/balsa/header_api.h | 274 + quiche/balsa/header_properties.cc | 111 + quiche/balsa/header_properties.h | 50 + quiche/balsa/header_properties_test.cc | 80 + quiche/balsa/http_validation_policy.h | 36 + quiche/balsa/noop_balsa_visitor.h | 59 + quiche/balsa/simple_buffer.cc | 152 + quiche/balsa/simple_buffer.h | 118 + quiche/balsa/simple_buffer_test.cc | 411 + quiche/balsa/standard_header_map.cc | 143 + quiche/balsa/standard_header_map.h | 24 + quiche/binary_http/binary_http_message.cc | 454 + quiche/binary_http/binary_http_message.h | 291 + .../binary_http/binary_http_message_test.cc | 786 + .../anonymous_tokens_rsa_bssa_client.cc | 256 + .../client/anonymous_tokens_rsa_bssa_client.h | 104 + .../anonymous_tokens_rsa_bssa_client_test.cc | 470 + .../cpp/crypto/at_crypto_utils_test.cc | 397 + .../cpp/crypto/blind_signer.h | 37 + .../anonymous_tokens/cpp/crypto/blinder.h | 38 + .../anonymous_tokens/cpp/crypto/constants.h | 68 + .../cpp/crypto/crypto_utils.cc | 527 + .../cpp/crypto/crypto_utils.h | 189 + .../cpp/crypto/rsa_blind_signer.cc | 206 + .../cpp/crypto/rsa_blind_signer.h | 86 + .../cpp/crypto/rsa_blind_signer_test.cc | 262 + .../cpp/crypto/rsa_blinder.cc | 286 + .../anonymous_tokens/cpp/crypto/rsa_blinder.h | 97 + .../cpp/crypto/rsa_blinder_test.cc | 360 + .../cpp/crypto/rsa_ssa_pss_verifier.cc | 97 + .../cpp/crypto/rsa_ssa_pss_verifier.h | 81 + .../cpp/crypto/rsa_ssa_pss_verifier_test.cc | 287 + .../anonymous_tokens/cpp/crypto/verifier.h | 29 + .../cpp/shared/proto_utils.cc | 64 + .../anonymous_tokens/cpp/shared/proto_utils.h | 53 + .../cpp/shared/proto_utils_test.cc | 93 + .../cpp/shared/status_utils.h | 49 + .../anonymous_tokens/cpp/testing/utils.cc | 790 + .../anonymous_tokens/cpp/testing/utils.h | 156 + .../proto/anonymous_tokens.proto | 335 + quiche/blind_sign_auth/blind_sign_auth.cc | 299 + quiche/blind_sign_auth/blind_sign_auth.h | 64 + .../blind_sign_auth_interface.h | 32 + .../blind_sign_auth/blind_sign_auth_test.cc | 300 + .../blind_sign_http_interface.h | 42 + .../blind_sign_http_response.h | 33 + .../blind_sign_auth/cached_blind_sign_auth.cc | 115 + .../blind_sign_auth/cached_blind_sign_auth.h | 65 + .../cached_blind_sign_auth_test.cc | 337 + quiche/blind_sign_auth/proto/any.proto | 26 + .../blind_sign_auth/proto/attestation.proto | 114 + .../blind_sign_auth/proto/auth_and_sign.proto | 87 + .../proto/get_initial_data.proto | 61 + .../blind_sign_auth/proto/key_services.proto | 27 + .../proto/public_metadata.proto | 54 + .../proto/spend_token_data.proto | 38 + quiche/blind_sign_auth/proto/timestamp.proto | 32 + .../mock_blind_sign_auth_interface.h | 33 + .../mock_blind_sign_http_interface.h | 32 + .../strong_rsa_modulus2048_example.binarypb | Bin 0 -> 1178 bytes .../strong_rsa_modulus2048_example_2.binarypb | Bin 0 -> 1178 bytes .../strong_rsa_modulus3072_example.binarypb | Bin 0 -> 1754 bytes .../strong_rsa_modulus4096_example.binarypb | Bin 0 -> 2330 bytes quiche/common/btree_scheduler.h | 297 + quiche/common/btree_scheduler_test.cc | 281 + quiche/common/capsule.cc | 706 + quiche/common/capsule.h | 386 + quiche/common/capsule_test.cc | 521 + .../masque/connect_udp_datagram_payload.cc | 130 + .../masque/connect_udp_datagram_payload.h | 99 + .../connect_udp_datagram_payload_test.cc | 61 + .../common/platform/api/quiche_bug_tracker.h | 15 + .../common/platform/api/quiche_client_stats.h | 88 + .../platform/api/quiche_command_line_flags.h | 43 + .../common/platform/api/quiche_containers.h | 22 + .../api/quiche_default_proof_providers.h | 32 + .../common/platform/api/quiche_event_loop.h | 27 + .../common/platform/api/quiche_expect_bug.h | 14 + quiche/common/platform/api/quiche_export.h | 19 + .../common/platform/api/quiche_file_utils.cc | 51 + .../common/platform/api/quiche_file_utils.h | 40 + .../platform/api/quiche_file_utils_test.cc | 86 + .../common/platform/api/quiche_flag_utils.h | 19 + quiche/common/platform/api/quiche_flags.h | 28 + .../platform/api/quiche_header_policy.h | 20 + .../platform/api/quiche_hostname_utils.cc | 109 + .../platform/api/quiche_hostname_utils.h | 33 + .../api/quiche_hostname_utils_test.cc | 94 + quiche/common/platform/api/quiche_iovec.h | 23 + quiche/common/platform/api/quiche_logging.h | 56 + .../platform/api/quiche_lower_case_string.h | 16 + .../api/quiche_lower_case_string_test.cc | 29 + quiche/common/platform/api/quiche_mem_slice.h | 75 + .../platform/api/quiche_mem_slice_test.cc | 104 + quiche/common/platform/api/quiche_mutex.cc | 31 + quiche/common/platform/api/quiche_mutex.h | 101 + quiche/common/platform/api/quiche_prefetch.h | 39 + .../platform/api/quiche_reference_counted.h | 168 + .../api/quiche_reference_counted_test.cc | 173 + .../common/platform/api/quiche_server_stats.h | 82 + .../common/platform/api/quiche_stack_trace.h | 32 + .../platform/api/quiche_stack_trace_test.cc | 43 + .../platform/api/quiche_system_event_loop.h | 20 + quiche/common/platform/api/quiche_test.h | 42 + .../platform/api/quiche_test_loopback.cc | 21 + .../platform/api/quiche_test_loopback.h | 34 + .../common/platform/api/quiche_test_output.h | 40 + quiche/common/platform/api/quiche_testvalue.h | 25 + quiche/common/platform/api/quiche_thread.h | 28 + .../common/platform/api/quiche_time_utils.h | 27 + .../platform/api/quiche_time_utils_test.cc | 51 + .../api/quiche_udp_socket_platform_api.h | 46 + quiche/common/platform/api/quiche_url_utils.h | 38 + .../platform/api/quiche_url_utils_test.cc | 80 + quiche/common/platform/api/testdir/README.md | 1 + quiche/common/platform/api/testdir/a/b/c/d/e | 1 + .../platform/api/testdir/a/subdir/testfile | 1 + quiche/common/platform/api/testdir/a/z | 1 + quiche/common/platform/api/testdir/testfile | 1 + .../quiche_bug_tracker_impl.h | 16 + .../quiche_client_stats_impl.h | 44 + .../quiche_command_line_flags_impl.cc | 41 + .../quiche_command_line_flags_impl.h | 33 + .../quiche_containers_impl.h | 17 + .../quiche_default_proof_providers_impl.cc | 77 + .../quiche_default_proof_providers_impl.h | 21 + .../quiche_event_loop_impl.h | 27 + .../quiche_expect_bug_impl.h | 15 + .../quiche_platform_impl/quiche_export_impl.h | 18 + .../quiche_file_utils_impl.cc | 182 + .../quiche_file_utils_impl.h | 26 + .../quiche_flag_utils_impl.h | 29 + .../quiche_platform_impl/quiche_flags_impl.cc | 37 + .../quiche_platform_impl/quiche_flags_impl.h | 57 + .../quiche_header_policy_impl.h | 16 + .../quiche_platform_impl/quiche_iovec_impl.h | 24 + .../quiche_logging_impl.h | 160 + .../quiche_lower_case_string_impl.h | 25 + .../quiche_mem_slice_impl.h | 46 + .../quiche_platform_impl/quiche_mutex_impl.cc | 15 + .../quiche_platform_impl/quiche_mutex_impl.h | 68 + .../quiche_prefetch_impl.h | 28 + .../quiche_reference_counted_impl.h | 190 + .../quiche_server_stats_impl.h | 26 + .../quiche_stack_trace_impl.cc | 52 + .../quiche_stack_trace_impl.h | 17 + .../quiche_stream_buffer_allocator_impl.h | 16 + .../quiche_system_event_loop_impl.h | 24 + .../quiche_platform_impl/quiche_test_impl.cc | 25 + .../quiche_platform_impl/quiche_test_impl.h | 56 + .../quiche_test_loopback_impl.cc | 32 + .../quiche_test_loopback_impl.h | 30 + .../quiche_test_output_impl.h | 27 + .../quiche_testvalue_impl.h | 13 + .../quiche_platform_impl/quiche_thread_impl.h | 26 + .../quiche_time_utils_impl.cc | 48 + .../quiche_time_utils_impl.h | 21 + .../quiche_udp_socket_platform_impl.h | 37 + .../quiche_url_utils_impl.cc | 79 + .../quiche_url_utils_impl.h | 35 + quiche/common/print_elements.h | 37 + quiche/common/print_elements_test.cc | 61 + quiche/common/quiche_buffer_allocator.cc | 76 + quiche/common/quiche_buffer_allocator.h | 126 + quiche/common/quiche_buffer_allocator_test.cc | 141 + quiche/common/quiche_circular_deque.h | 754 + quiche/common/quiche_circular_deque_test.cc | 799 + quiche/common/quiche_crypto_logging.cc | 42 + quiche/common/quiche_crypto_logging.h | 26 + quiche/common/quiche_data_reader.cc | 321 + quiche/common/quiche_data_reader.h | 216 + quiche/common/quiche_data_reader_test.cc | 187 + quiche/common/quiche_data_writer.cc | 301 + quiche/common/quiche_data_writer.h | 152 + quiche/common/quiche_data_writer_test.cc | 835 + quiche/common/quiche_endian.h | 73 + quiche/common/quiche_endian_test.cc | 53 + quiche/common/quiche_ip_address.cc | 258 + quiche/common/quiche_ip_address.h | 130 + quiche/common/quiche_ip_address_family.cc | 47 + quiche/common/quiche_ip_address_family.h | 23 + quiche/common/quiche_ip_address_test.cc | 142 + quiche/common/quiche_linked_hash_map.h | 237 + quiche/common/quiche_linked_hash_map_test.cc | 393 + quiche/common/quiche_mem_slice_storage.cc | 34 + quiche/common/quiche_mem_slice_storage.h | 43 + .../common/quiche_mem_slice_storage_test.cc | 61 + quiche/common/quiche_protocol_flags_list.h | 15 + quiche/common/quiche_random.cc | 93 + quiche/common/quiche_random.h | 37 + quiche/common/quiche_random_test.cc | 47 + quiche/common/quiche_status_utils.h | 51 + quiche/common/quiche_stream.h | 106 + quiche/common/quiche_text_utils.cc | 76 + quiche/common/quiche_text_utils.h | 75 + quiche/common/quiche_text_utils_test.cc | 92 + quiche/common/simple_buffer_allocator.cc | 17 + quiche/common/simple_buffer_allocator.h | 30 + quiche/common/simple_buffer_allocator_test.cc | 63 + quiche/common/structured_headers.cc | 910 + quiche/common/structured_headers.h | 325 + quiche/common/structured_headers_fuzzer.cc | 22 + .../structured_headers_generated_test.cc | 3944 ++++ quiche/common/structured_headers_test.cc | 762 + quiche/common/test_tools/quiche_test_utils.cc | 102 + quiche/common/test_tools/quiche_test_utils.h | 85 + .../test_tools/quiche_test_utils_test.cc | 37 + quiche/common/wire_serialization.h | 398 + quiche/common/wire_serialization_test.cc | 256 + .../adapter/adapter_impl_comparison_test.cc | 171 + quiche/http2/adapter/callback_visitor.cc | 509 + quiche/http2/adapter/callback_visitor.h | 113 + quiche/http2/adapter/callback_visitor_test.cc | 536 + quiche/http2/adapter/data_source.h | 60 + quiche/http2/adapter/event_forwarder.cc | 198 + quiche/http2/adapter/event_forwarder.h | 81 + quiche/http2/adapter/event_forwarder_test.cc | 235 + quiche/http2/adapter/header_validator.cc | 298 + quiche/http2/adapter/header_validator.h | 54 + quiche/http2/adapter/header_validator_base.h | 70 + quiche/http2/adapter/header_validator_test.cc | 676 + quiche/http2/adapter/http2_adapter.h | 163 + quiche/http2/adapter/http2_protocol.cc | 85 + quiche/http2/adapter/http2_protocol.h | 150 + quiche/http2/adapter/http2_session.h | 33 + quiche/http2/adapter/http2_util.cc | 134 + quiche/http2/adapter/http2_util.h | 31 + .../http2/adapter/http2_visitor_interface.h | 264 + quiche/http2/adapter/mock_http2_visitor.h | 122 + .../http2/adapter/mock_nghttp2_callbacks.cc | 130 + quiche/http2/adapter/mock_nghttp2_callbacks.h | 70 + quiche/http2/adapter/nghttp2.h | 11 + quiche/http2/adapter/nghttp2_adapter.cc | 309 + quiche/http2/adapter/nghttp2_adapter.h | 115 + quiche/http2/adapter/nghttp2_adapter_test.cc | 7196 +++++++ quiche/http2/adapter/nghttp2_callbacks.cc | 388 + quiche/http2/adapter/nghttp2_callbacks.h | 90 + quiche/http2/adapter/nghttp2_data_provider.cc | 62 + quiche/http2/adapter/nghttp2_data_provider.h | 37 + .../adapter/nghttp2_data_provider_test.cc | 117 + quiche/http2/adapter/nghttp2_session.cc | 57 + quiche/http2/adapter/nghttp2_session.h | 41 + quiche/http2/adapter/nghttp2_session_test.cc | 374 + quiche/http2/adapter/nghttp2_test.cc | 262 + quiche/http2/adapter/nghttp2_test_utils.cc | 463 + quiche/http2/adapter/nghttp2_test_utils.h | 103 + quiche/http2/adapter/nghttp2_util.cc | 304 + quiche/http2/adapter/nghttp2_util.h | 76 + quiche/http2/adapter/nghttp2_util_test.cc | 109 + quiche/http2/adapter/noop_header_validator.cc | 22 + quiche/http2/adapter/noop_header_validator.h | 25 + .../adapter/noop_header_validator_test.cc | 523 + quiche/http2/adapter/oghttp2_adapter.cc | 168 + quiche/http2/adapter/oghttp2_adapter.h | 77 + quiche/http2/adapter/oghttp2_adapter_test.cc | 8352 ++++++++ quiche/http2/adapter/oghttp2_session.cc | 2024 ++ quiche/http2/adapter/oghttp2_session.h | 557 + quiche/http2/adapter/oghttp2_session_test.cc | 1076 + quiche/http2/adapter/oghttp2_util.cc | 17 + quiche/http2/adapter/oghttp2_util.h | 18 + quiche/http2/adapter/oghttp2_util_test.cc | 83 + .../http2/adapter/recording_http2_visitor.cc | 181 + .../http2/adapter/recording_http2_visitor.h | 80 + .../adapter/recording_http2_visitor_test.cc | 131 + quiche/http2/adapter/test_frame_sequence.cc | 188 + quiche/http2/adapter/test_frame_sequence.h | 72 + quiche/http2/adapter/test_utils.cc | 224 + quiche/http2/adapter/test_utils.h | 139 + quiche/http2/adapter/test_utils_test.cc | 123 + quiche/http2/adapter/window_manager.cc | 103 + quiche/http2/adapter/window_manager.h | 93 + quiche/http2/adapter/window_manager_test.cc | 342 + quiche/http2/core/http2_trace_logging.cc | 482 + quiche/http2/core/http2_trace_logging.h | 144 + quiche/http2/core/priority_write_scheduler.h | 381 + .../core/priority_write_scheduler_test.cc | 344 + quiche/http2/decoder/decode_buffer.cc | 93 + quiche/http2/decoder/decode_buffer.h | 171 + quiche/http2/decoder/decode_buffer_test.cc | 203 + .../http2/decoder/decode_http2_structures.cc | 121 + .../http2/decoder/decode_http2_structures.h | 33 + .../decoder/decode_http2_structures_test.cc | 460 + quiche/http2/decoder/decode_status.cc | 28 + quiche/http2/decoder/decode_status.h | 32 + quiche/http2/decoder/frame_decoder_state.cc | 80 + quiche/http2/decoder/frame_decoder_state.h | 252 + quiche/http2/decoder/http2_frame_decoder.cc | 456 + quiche/http2/decoder/http2_frame_decoder.h | 214 + .../decoder/http2_frame_decoder_listener.cc | 14 + .../decoder/http2_frame_decoder_listener.h | 385 + .../http2/decoder/http2_frame_decoder_test.cc | 919 + .../http2/decoder/http2_structure_decoder.cc | 95 + .../http2/decoder/http2_structure_decoder.h | 130 + .../decoder/http2_structure_decoder_test.cc | 535 + .../altsvc_payload_decoder.cc | 149 + .../payload_decoders/altsvc_payload_decoder.h | 64 + .../altsvc_payload_decoder_test.cc | 121 + .../continuation_payload_decoder.cc | 57 + .../continuation_payload_decoder.h | 31 + .../continuation_payload_decoder_test.cc | 84 + .../payload_decoders/data_payload_decoder.cc | 128 + .../payload_decoders/data_payload_decoder.h | 54 + .../data_payload_decoder_test.cc | 110 + .../goaway_payload_decoder.cc | 120 + .../payload_decoders/goaway_payload_decoder.h | 66 + .../goaway_payload_decoder_test.cc | 108 + .../headers_payload_decoder.cc | 176 + .../headers_payload_decoder.h | 67 + .../headers_payload_decoder_test.cc | 158 + .../payload_decoders/ping_payload_decoder.cc | 90 + .../payload_decoders/ping_payload_decoder.h | 43 + .../ping_payload_decoder_test.cc | 109 + .../priority_payload_decoder.cc | 62 + .../priority_payload_decoder.h | 44 + .../priority_payload_decoder_test.cc | 89 + .../priority_update_payload_decoder.cc | 123 + .../priority_update_payload_decoder.h | 63 + .../priority_update_payload_decoder_test.cc | 114 + .../push_promise_payload_decoder.cc | 171 + .../push_promise_payload_decoder.h | 66 + .../push_promise_payload_decoder_test.cc | 138 + .../rst_stream_payload_decoder.cc | 64 + .../rst_stream_payload_decoder.h | 42 + .../rst_stream_payload_decoder_test.cc | 92 + .../settings_payload_decoder.cc | 95 + .../settings_payload_decoder.h | 53 + .../settings_payload_decoder_test.cc | 159 + .../unknown_payload_decoder.cc | 55 + .../unknown_payload_decoder.h | 33 + .../unknown_payload_decoder_test.cc | 99 + .../window_update_payload_decoder.cc | 80 + .../window_update_payload_decoder.h | 42 + .../window_update_payload_decoder_test.cc | 95 + .../decoder/hpack_block_collector_test.cc | 121 + .../hpack/decoder/hpack_block_decoder.cc | 65 + .../http2/hpack/decoder/hpack_block_decoder.h | 69 + .../hpack/decoder/hpack_block_decoder_test.cc | 290 + quiche/http2/hpack/decoder/hpack_decoder.cc | 124 + quiche/http2/hpack/decoder/hpack_decoder.h | 132 + .../hpack/decoder/hpack_decoder_listener.cc | 29 + .../hpack/decoder/hpack_decoder_listener.h | 60 + .../hpack/decoder/hpack_decoder_state.cc | 223 + .../http2/hpack/decoder/hpack_decoder_state.h | 137 + .../hpack/decoder/hpack_decoder_state_test.cc | 541 + .../decoder/hpack_decoder_string_buffer.cc | 239 + .../decoder/hpack_decoder_string_buffer.h | 101 + .../hpack_decoder_string_buffer_test.cc | 250 + .../hpack/decoder/hpack_decoder_tables.cc | 148 + .../hpack/decoder/hpack_decoder_tables.h | 166 + .../decoder/hpack_decoder_tables_test.cc | 257 + .../http2/hpack/decoder/hpack_decoder_test.cc | 1187 ++ .../hpack/decoder/hpack_decoding_error.cc | 51 + .../hpack/decoder/hpack_decoding_error.h | 51 + .../decoder/hpack_entry_collector_test.cc | 155 + .../hpack/decoder/hpack_entry_decoder.cc | 294 + .../http2/hpack/decoder/hpack_entry_decoder.h | 90 + .../decoder/hpack_entry_decoder_listener.cc | 80 + .../decoder/hpack_entry_decoder_listener.h | 110 + .../hpack/decoder/hpack_entry_decoder_test.cc | 202 + .../hpack/decoder/hpack_entry_type_decoder.cc | 361 + .../hpack/decoder/hpack_entry_type_decoder.h | 57 + .../decoder/hpack_entry_type_decoder_test.cc | 86 + .../hpack/decoder/hpack_string_decoder.cc | 35 + .../hpack/decoder/hpack_string_decoder.h | 207 + .../decoder/hpack_string_decoder_listener.cc | 36 + .../decoder/hpack_string_decoder_listener.h | 62 + .../decoder/hpack_string_decoder_test.cc | 153 + .../hpack/decoder/hpack_whole_entry_buffer.cc | 152 + .../hpack/decoder/hpack_whole_entry_buffer.h | 101 + .../decoder/hpack_whole_entry_buffer_test.cc | 226 + .../decoder/hpack_whole_entry_listener.cc | 31 + .../decoder/hpack_whole_entry_listener.h | 80 + .../hpack/hpack_static_table_entries.inc | 65 + quiche/http2/hpack/http2_hpack_constants.cc | 31 + quiche/http2/hpack/http2_hpack_constants.h | 62 + .../http2/hpack/http2_hpack_constants_test.cc | 66 + .../hpack/huffman/hpack_huffman_decoder.cc | 483 + .../hpack/huffman/hpack_huffman_decoder.h | 134 + .../huffman/hpack_huffman_decoder_test.cc | 242 + .../hpack/huffman/hpack_huffman_encoder.cc | 127 + .../hpack/huffman/hpack_huffman_encoder.h | 38 + .../huffman/hpack_huffman_encoder_test.cc | 130 + .../huffman/hpack_huffman_transcoder_test.cc | 182 + .../hpack/huffman/huffman_spec_tables.cc | 572 + .../http2/hpack/huffman/huffman_spec_tables.h | 31 + .../hpack/varint/hpack_varint_decoder.cc | 143 + .../http2/hpack/varint/hpack_varint_decoder.h | 128 + .../hpack/varint/hpack_varint_decoder_test.cc | 309 + .../hpack/varint/hpack_varint_encoder.cc | 47 + .../http2/hpack/varint/hpack_varint_encoder.h | 29 + .../hpack/varint/hpack_varint_encoder_test.cc | 161 + .../varint/hpack_varint_round_trip_test.cc | 417 + quiche/http2/http2_constants.cc | 181 + quiche/http2/http2_constants.h | 270 + quiche/http2/http2_constants_test.cc | 271 + quiche/http2/http2_structures.cc | 153 + quiche/http2/http2_structures.h | 347 + quiche/http2/http2_structures_test.cc | 570 + .../frame_decoder_state_test_util.cc | 34 + .../frame_decoder_state_test_util.h | 37 + quiche/http2/test_tools/frame_parts.cc | 554 + quiche/http2/test_tools/frame_parts.h | 258 + .../http2/test_tools/frame_parts_collector.cc | 112 + .../http2/test_tools/frame_parts_collector.h | 113 + .../frame_parts_collector_listener.cc | 247 + .../frame_parts_collector_listener.h | 91 + .../http2/test_tools/hpack_block_builder.cc | 66 + quiche/http2/test_tools/hpack_block_builder.h | 97 + .../test_tools/hpack_block_builder_test.cc | 169 + .../http2/test_tools/hpack_block_collector.cc | 143 + .../http2/test_tools/hpack_block_collector.h | 122 + .../http2/test_tools/hpack_entry_collector.cc | 293 + .../http2/test_tools/hpack_entry_collector.h | 151 + quiche/http2/test_tools/hpack_example.cc | 59 + quiche/http2/test_tools/hpack_example.h | 32 + quiche/http2/test_tools/hpack_example_test.cc | 45 + .../test_tools/hpack_string_collector.cc | 117 + .../http2/test_tools/hpack_string_collector.h | 66 + .../test_tools/http2_constants_test_util.cc | 84 + .../test_tools/http2_constants_test_util.h | 34 + .../http2/test_tools/http2_frame_builder.cc | 179 + quiche/http2/test_tools/http2_frame_builder.h | 103 + .../test_tools/http2_frame_builder_test.cc | 228 + .../http2_frame_decoder_listener_test_util.cc | 511 + .../http2_frame_decoder_listener_test_util.h | 154 + quiche/http2/test_tools/http2_random.cc | 73 + quiche/http2/test_tools/http2_random.h | 89 + quiche/http2/test_tools/http2_random_test.cc | 93 + .../http2_structure_decoder_test_util.cc | 22 + .../http2_structure_decoder_test_util.h | 24 + .../test_tools/http2_structures_test_util.cc | 112 + .../test_tools/http2_structures_test_util.h | 61 + .../payload_decoder_base_test_util.cc | 97 + .../payload_decoder_base_test_util.h | 444 + .../test_tools/random_decoder_test_base.cc | 167 + .../test_tools/random_decoder_test_base.h | 255 + .../random_decoder_test_base_test.cc | 327 + quiche/http2/test_tools/random_util.cc | 39 + quiche/http2/test_tools/random_util.h | 30 + quiche/http2/test_tools/verify_macros.h | 32 + .../oblivious_http_integration_test.cc | 108 + .../buffers/oblivious_http_request.cc | 209 + .../buffers/oblivious_http_request.h | 120 + .../buffers/oblivious_http_request_test.cc | 287 + .../buffers/oblivious_http_response.cc | 353 + .../buffers/oblivious_http_response.h | 95 + .../buffers/oblivious_http_response_test.cc | 210 + .../oblivious_http_header_key_config.cc | 472 + .../common/oblivious_http_header_key_config.h | 220 + .../oblivious_http_header_key_config_test.cc | 356 + .../oblivious_http/oblivious_http_client.cc | 91 + quiche/oblivious_http/oblivious_http_client.h | 80 + .../oblivious_http_client_test.cc | 252 + .../oblivious_http/oblivious_http_gateway.cc | 68 + .../oblivious_http/oblivious_http_gateway.h | 83 + .../oblivious_http_gateway_test.cc | 227 + quiche/quic/bindings/quic_libevent.cc | 239 + quiche/quic/bindings/quic_libevent.h | 155 + quiche/quic/bindings/quic_libevent_test.cc | 68 + .../batch_writer/quic_batch_writer_base.cc | 176 + .../batch_writer/quic_batch_writer_base.h | 156 + .../batch_writer/quic_batch_writer_buffer.cc | 151 + .../batch_writer/quic_batch_writer_buffer.h | 94 + .../quic_batch_writer_buffer_test.cc | 281 + .../batch_writer/quic_batch_writer_test.cc | 76 + .../batch_writer/quic_batch_writer_test.h | 286 + .../batch_writer/quic_gso_batch_writer.cc | 159 + .../core/batch_writer/quic_gso_batch_writer.h | 113 + .../quic_gso_batch_writer_test.cc | 462 + .../quic_sendmmsg_batch_writer.cc | 81 + .../batch_writer/quic_sendmmsg_batch_writer.h | 34 + .../quic_sendmmsg_batch_writer_test.cc | 15 + quiche/quic/core/chlo_extractor.cc | 362 + quiche/quic/core/chlo_extractor.h | 44 + quiche/quic/core/chlo_extractor_test.cc | 177 + .../congestion_control/bandwidth_sampler.cc | 583 + .../congestion_control/bandwidth_sampler.h | 612 + .../bandwidth_sampler_test.cc | 888 + .../core/congestion_control/bbr2_drain.cc | 59 + .../quic/core/congestion_control/bbr2_drain.h | 59 + .../quic/core/congestion_control/bbr2_misc.cc | 460 + .../quic/core/congestion_control/bbr2_misc.h | 679 + .../core/congestion_control/bbr2_probe_bw.cc | 653 + .../core/congestion_control/bbr2_probe_bw.h | 138 + .../core/congestion_control/bbr2_probe_rtt.cc | 79 + .../core/congestion_control/bbr2_probe_rtt.h | 58 + .../core/congestion_control/bbr2_sender.cc | 577 + .../core/congestion_control/bbr2_sender.h | 218 + .../congestion_control/bbr2_simulator_test.cc | 2575 +++ .../core/congestion_control/bbr2_startup.cc | 154 + .../core/congestion_control/bbr2_startup.h | 68 + .../core/congestion_control/bbr_sender.cc | 896 + .../quic/core/congestion_control/bbr_sender.h | 391 + .../congestion_control/bbr_sender_test.cc | 1323 ++ .../core/congestion_control/cubic_bytes.cc | 189 + .../core/congestion_control/cubic_bytes.h | 102 + .../congestion_control/cubic_bytes_test.cc | 387 + .../general_loss_algorithm.cc | 190 + .../general_loss_algorithm.h | 137 + .../general_loss_algorithm_test.cc | 488 + .../congestion_control/hybrid_slow_start.cc | 104 + .../congestion_control/hybrid_slow_start.h | 82 + .../hybrid_slow_start_test.cc | 76 + .../loss_detection_interface.h | 71 + .../core/congestion_control/pacing_sender.cc | 167 + .../core/congestion_control/pacing_sender.h | 114 + .../congestion_control/pacing_sender_test.cc | 585 + .../core/congestion_control/prr_sender.cc | 62 + .../quic/core/congestion_control/prr_sender.h | 42 + .../congestion_control/prr_sender_test.cc | 123 + .../quic/core/congestion_control/rtt_stats.cc | 143 + .../quic/core/congestion_control/rtt_stats.h | 131 + .../core/congestion_control/rtt_stats_test.cc | 231 + .../send_algorithm_interface.cc | 58 + .../send_algorithm_interface.h | 179 + .../congestion_control/send_algorithm_test.cc | 347 + .../tcp_cubic_sender_bytes.cc | 387 + .../tcp_cubic_sender_bytes.h | 171 + .../tcp_cubic_sender_bytes_test.cc | 841 + .../congestion_control/uber_loss_algorithm.cc | 210 + .../congestion_control/uber_loss_algorithm.h | 139 + .../uber_loss_algorithm_test.cc | 360 + .../core/congestion_control/windowed_filter.h | 164 + .../windowed_filter_test.cc | 381 + quiche/quic/core/connecting_client_socket.h | 111 + quiche/quic/core/connection_id_generator.h | 34 + .../quic/core/crypto/aead_base_decrypter.cc | 190 + quiche/quic/core/crypto/aead_base_decrypter.h | 69 + .../quic/core/crypto/aead_base_encrypter.cc | 168 + quiche/quic/core/crypto/aead_base_encrypter.h | 73 + .../core/crypto/aes_128_gcm_12_decrypter.cc | 32 + .../core/crypto/aes_128_gcm_12_decrypter.h | 38 + .../crypto/aes_128_gcm_12_decrypter_test.cc | 288 + .../core/crypto/aes_128_gcm_12_encrypter.cc | 27 + .../core/crypto/aes_128_gcm_12_encrypter.h | 34 + .../crypto/aes_128_gcm_12_encrypter_test.cc | 244 + .../quic/core/crypto/aes_128_gcm_decrypter.cc | 34 + .../quic/core/crypto/aes_128_gcm_decrypter.h | 36 + .../core/crypto/aes_128_gcm_decrypter_test.cc | 291 + .../quic/core/crypto/aes_128_gcm_encrypter.cc | 27 + .../quic/core/crypto/aes_128_gcm_encrypter.h | 32 + .../core/crypto/aes_128_gcm_encrypter_test.cc | 273 + .../quic/core/crypto/aes_256_gcm_decrypter.cc | 34 + .../quic/core/crypto/aes_256_gcm_decrypter.h | 36 + .../core/crypto/aes_256_gcm_decrypter_test.cc | 297 + .../quic/core/crypto/aes_256_gcm_encrypter.cc | 27 + .../quic/core/crypto/aes_256_gcm_encrypter.h | 32 + .../core/crypto/aes_256_gcm_encrypter_test.cc | 259 + quiche/quic/core/crypto/aes_base_decrypter.cc | 52 + quiche/quic/core/crypto/aes_base_decrypter.h | 33 + quiche/quic/core/crypto/aes_base_encrypter.cc | 48 + quiche/quic/core/crypto/aes_base_encrypter.h | 32 + quiche/quic/core/crypto/boring_utils.h | 34 + quiche/quic/core/crypto/cert_compressor.cc | 598 + quiche/quic/core/crypto/cert_compressor.h | 45 + .../quic/core/crypto/cert_compressor_test.cc | 119 + quiche/quic/core/crypto/certificate_util.cc | 280 + quiche/quic/core/crypto/certificate_util.h | 46 + .../quic/core/crypto/certificate_util_test.cc | 49 + quiche/quic/core/crypto/certificate_view.cc | 664 + quiche/quic/core/crypto/certificate_view.h | 156 + .../crypto/certificate_view_der_fuzzer.cc | 19 + .../crypto/certificate_view_pem_fuzzer.cc | 18 + .../quic/core/crypto/certificate_view_test.cc | 230 + .../crypto/chacha20_poly1305_decrypter.cc | 41 + .../core/crypto/chacha20_poly1305_decrypter.h | 41 + .../chacha20_poly1305_decrypter_test.cc | 178 + .../crypto/chacha20_poly1305_encrypter.cc | 35 + .../core/crypto/chacha20_poly1305_encrypter.h | 38 + .../chacha20_poly1305_encrypter_test.cc | 159 + .../crypto/chacha20_poly1305_tls_decrypter.cc | 43 + .../crypto/chacha20_poly1305_tls_decrypter.h | 39 + .../chacha20_poly1305_tls_decrypter_test.cc | 188 + .../crypto/chacha20_poly1305_tls_encrypter.cc | 35 + .../crypto/chacha20_poly1305_tls_encrypter.h | 36 + .../chacha20_poly1305_tls_encrypter_test.cc | 173 + .../quic/core/crypto/chacha_base_decrypter.cc | 44 + .../quic/core/crypto/chacha_base_decrypter.h | 31 + .../quic/core/crypto/chacha_base_encrypter.cc | 41 + .../quic/core/crypto/chacha_base_encrypter.h | 30 + quiche/quic/core/crypto/channel_id.cc | 90 + quiche/quic/core/crypto/channel_id.h | 47 + quiche/quic/core/crypto/channel_id_test.cc | 285 + .../quic/core/crypto/client_proof_source.cc | 62 + quiche/quic/core/crypto/client_proof_source.h | 70 + .../core/crypto/client_proof_source_test.cc | 215 + quiche/quic/core/crypto/crypto_framer.cc | 351 + quiche/quic/core/crypto/crypto_framer.h | 136 + quiche/quic/core/crypto/crypto_framer_test.cc | 442 + quiche/quic/core/crypto/crypto_handshake.cc | 39 + quiche/quic/core/crypto/crypto_handshake.h | 190 + .../core/crypto/crypto_handshake_message.cc | 368 + .../core/crypto/crypto_handshake_message.h | 159 + .../crypto/crypto_handshake_message_test.cc | 105 + .../quic/core/crypto/crypto_message_parser.h | 35 + quiche/quic/core/crypto/crypto_protocol.h | 516 + .../quic/core/crypto/crypto_secret_boxer.cc | 146 + quiche/quic/core/crypto/crypto_secret_boxer.h | 67 + .../core/crypto/crypto_secret_boxer_test.cc | 82 + quiche/quic/core/crypto/crypto_server_test.cc | 1122 + quiche/quic/core/crypto/crypto_utils.cc | 812 + quiche/quic/core/crypto/crypto_utils.h | 259 + quiche/quic/core/crypto/crypto_utils_test.cc | 262 + .../core/crypto/curve25519_key_exchange.cc | 86 + .../core/crypto/curve25519_key_exchange.h | 52 + .../crypto/curve25519_key_exchange_test.cc | 104 + quiche/quic/core/crypto/key_exchange.cc | 42 + quiche/quic/core/crypto/key_exchange.h | 101 + quiche/quic/core/crypto/null_decrypter.cc | 121 + quiche/quic/core/crypto/null_decrypter.h | 62 + .../quic/core/crypto/null_decrypter_test.cc | 137 + quiche/quic/core/crypto/null_encrypter.cc | 88 + quiche/quic/core/crypto/null_encrypter.h | 53 + .../quic/core/crypto/null_encrypter_test.cc | 103 + quiche/quic/core/crypto/p256_key_exchange.cc | 121 + quiche/quic/core/crypto/p256_key_exchange.h | 68 + .../core/crypto/p256_key_exchange_test.cc | 109 + quiche/quic/core/crypto/proof_source.cc | 61 + quiche/quic/core/crypto/proof_source.h | 354 + quiche/quic/core/crypto/proof_source_x509.cc | 169 + quiche/quic/core/crypto/proof_source_x509.h | 84 + .../core/crypto/proof_source_x509_test.cc | 142 + quiche/quic/core/crypto/proof_verifier.h | 117 + .../core/crypto/quic_client_session_cache.cc | 173 + .../core/crypto/quic_client_session_cache.h | 82 + .../crypto/quic_client_session_cache_test.cc | 440 + .../crypto/quic_compressed_certs_cache.cc | 114 + .../core/crypto/quic_compressed_certs_cache.h | 103 + .../quic_compressed_certs_cache_test.cc | 91 + quiche/quic/core/crypto/quic_crypter.cc | 19 + quiche/quic/core/crypto/quic_crypter.h | 94 + .../core/crypto/quic_crypto_client_config.cc | 842 + .../core/crypto/quic_crypto_client_config.h | 467 + .../crypto/quic_crypto_client_config_test.cc | 550 + quiche/quic/core/crypto/quic_crypto_proof.cc | 12 + quiche/quic/core/crypto/quic_crypto_proof.h | 32 + .../core/crypto/quic_crypto_server_config.cc | 1896 ++ .../core/crypto/quic_crypto_server_config.h | 948 + .../crypto/quic_crypto_server_config_test.cc | 494 + quiche/quic/core/crypto/quic_decrypter.cc | 79 + quiche/quic/core/crypto/quic_decrypter.h | 93 + quiche/quic/core/crypto/quic_encrypter.cc | 60 + quiche/quic/core/crypto/quic_encrypter.h | 70 + quiche/quic/core/crypto/quic_hkdf.cc | 98 + quiche/quic/core/crypto/quic_hkdf.h | 72 + quiche/quic/core/crypto/quic_hkdf_test.cc | 91 + quiche/quic/core/crypto/quic_random.h | 16 + .../quic/core/crypto/tls_client_connection.cc | 48 + .../quic/core/crypto/tls_client_connection.h | 54 + quiche/quic/core/crypto/tls_connection.cc | 206 + quiche/quic/core/crypto/tls_connection.h | 153 + .../quic/core/crypto/tls_server_connection.cc | 172 + .../quic/core/crypto/tls_server_connection.h | 180 + .../quic/core/crypto/transport_parameters.cc | 1655 ++ .../quic/core/crypto/transport_parameters.h | 311 + .../core/crypto/transport_parameters_test.cc | 1192 ++ ...eb_transport_fingerprint_proof_verifier.cc | 231 + ...web_transport_fingerprint_proof_verifier.h | 126 + ...ansport_fingerprint_proof_verifier_test.cc | 183 + .../deterministic_connection_id_generator.cc | 73 + .../deterministic_connection_id_generator.h | 40 + ...erministic_connection_id_generator_test.cc | 126 + quiche/quic/core/frames/quic_ack_frame.cc | 188 + quiche/quic/core/frames/quic_ack_frame.h | 140 + .../core/frames/quic_ack_frequency_frame.cc | 29 + .../core/frames/quic_ack_frequency_frame.h | 50 + quiche/quic/core/frames/quic_blocked_frame.cc | 29 + quiche/quic/core/frames/quic_blocked_frame.h | 49 + .../frames/quic_connection_close_frame.cc | 73 + .../core/frames/quic_connection_close_frame.h | 68 + quiche/quic/core/frames/quic_crypto_frame.cc | 37 + quiche/quic/core/frames/quic_crypto_frame.h | 48 + quiche/quic/core/frames/quic_frame.cc | 531 + quiche/quic/core/frames/quic_frame.h | 174 + quiche/quic/core/frames/quic_frames_test.cc | 846 + quiche/quic/core/frames/quic_goaway_frame.cc | 29 + quiche/quic/core/frames/quic_goaway_frame.h | 35 + .../core/frames/quic_handshake_done_frame.cc | 24 + .../core/frames/quic_handshake_done_frame.h | 34 + quiche/quic/core/frames/quic_inlined_frame.h | 34 + .../core/frames/quic_max_streams_frame.cc | 28 + .../quic/core/frames/quic_max_streams_frame.h | 43 + quiche/quic/core/frames/quic_message_frame.cc | 42 + quiche/quic/core/frames/quic_message_frame.h | 51 + .../core/frames/quic_mtu_discovery_frame.h | 25 + .../frames/quic_new_connection_id_frame.cc | 30 + .../frames/quic_new_connection_id_frame.h | 39 + .../quic/core/frames/quic_new_token_frame.cc | 23 + .../quic/core/frames/quic_new_token_frame.h | 37 + quiche/quic/core/frames/quic_padding_frame.cc | 15 + quiche/quic/core/frames/quic_padding_frame.h | 36 + .../core/frames/quic_path_challenge_frame.cc | 32 + .../core/frames/quic_path_challenge_frame.h | 37 + .../core/frames/quic_path_response_frame.cc | 31 + .../core/frames/quic_path_response_frame.h | 37 + quiche/quic/core/frames/quic_ping_frame.cc | 19 + quiche/quic/core/frames/quic_ping_frame.h | 34 + .../frames/quic_retire_connection_id_frame.cc | 21 + .../frames/quic_retire_connection_id_frame.h | 32 + .../quic/core/frames/quic_rst_stream_frame.cc | 41 + .../quic/core/frames/quic_rst_stream_frame.h | 58 + .../core/frames/quic_stop_sending_frame.cc | 37 + .../core/frames/quic_stop_sending_frame.h | 52 + .../core/frames/quic_stop_waiting_frame.cc | 20 + .../core/frames/quic_stop_waiting_frame.h | 31 + quiche/quic/core/frames/quic_stream_frame.cc | 53 + quiche/quic/core/frames/quic_stream_frame.h | 50 + .../core/frames/quic_streams_blocked_frame.cc | 30 + .../core/frames/quic_streams_blocked_frame.h | 44 + .../core/frames/quic_window_update_frame.cc | 30 + .../core/frames/quic_window_update_frame.h | 45 + .../quic/core/handshaker_delegate_interface.h | 85 + quiche/quic/core/http/end_to_end_test.cc | 7427 +++++++ quiche/quic/core/http/http_constants.cc | 33 + quiche/quic/core/http/http_constants.h | 77 + quiche/quic/core/http/http_decoder.cc | 683 + quiche/quic/core/http/http_decoder.h | 278 + quiche/quic/core/http/http_decoder_test.cc | 1067 + quiche/quic/core/http/http_encoder.cc | 283 + quiche/quic/core/http/http_encoder.h | 65 + quiche/quic/core/http/http_encoder_test.cc | 139 + quiche/quic/core/http/http_frames.h | 163 + quiche/quic/core/http/http_frames_test.cc | 87 + .../core/http/quic_client_promised_info.cc | 146 + .../core/http/quic_client_promised_info.h | 115 + .../http/quic_client_promised_info_test.cc | 350 + .../http/quic_client_push_promise_index.cc | 45 + .../http/quic_client_push_promise_index.h | 99 + .../quic_client_push_promise_index_test.cc | 109 + quiche/quic/core/http/quic_header_list.cc | 75 + quiche/quic/core/http/quic_header_list.h | 88 + .../quic/core/http/quic_header_list_test.cc | 86 + quiche/quic/core/http/quic_headers_stream.cc | 163 + quiche/quic/core/http/quic_headers_stream.h | 96 + .../core/http/quic_headers_stream_test.cc | 936 + .../core/http/quic_receive_control_stream.cc | 234 + .../core/http/quic_receive_control_stream.h | 78 + .../http/quic_receive_control_stream_test.cc | 461 + .../core/http/quic_send_control_stream.cc | 121 + .../quic/core/http/quic_send_control_stream.h | 65 + .../http/quic_send_control_stream_test.cc | 301 + .../http/quic_server_initiated_spdy_stream.cc | 42 + .../http/quic_server_initiated_spdy_stream.h | 32 + .../core/http/quic_server_session_base.cc | 425 + .../quic/core/http/quic_server_session_base.h | 161 + .../http/quic_server_session_base_test.cc | 801 + .../core/http/quic_spdy_client_session.cc | 214 + .../quic/core/http/quic_spdy_client_session.h | 131 + .../http/quic_spdy_client_session_base.cc | 271 + .../core/http/quic_spdy_client_session_base.h | 146 + .../http/quic_spdy_client_session_test.cc | 1339 ++ .../quic/core/http/quic_spdy_client_stream.cc | 227 + .../quic/core/http/quic_spdy_client_stream.h | 108 + .../core/http/quic_spdy_client_stream_test.cc | 316 + .../core/http/quic_spdy_server_stream_base.cc | 134 + .../core/http/quic_spdy_server_stream_base.h | 31 + .../http/quic_spdy_server_stream_base_test.cc | 336 + quiche/quic/core/http/quic_spdy_session.cc | 1862 ++ quiche/quic/core/http/quic_spdy_session.h | 674 + .../quic/core/http/quic_spdy_session_test.cc | 3785 ++++ quiche/quic/core/http/quic_spdy_stream.cc | 1673 ++ quiche/quic/core/http/quic_spdy_stream.h | 500 + .../http/quic_spdy_stream_body_manager.cc | 146 + .../core/http/quic_spdy_stream_body_manager.h | 93 + .../quic_spdy_stream_body_manager_test.cc | 286 + .../quic/core/http/quic_spdy_stream_test.cc | 3275 +++ .../quic/core/http/spdy_server_push_utils.cc | 215 + .../quic/core/http/spdy_server_push_utils.h | 43 + .../core/http/spdy_server_push_utils_test.cc | 221 + quiche/quic/core/http/spdy_utils.cc | 176 + quiche/quic/core/http/spdy_utils.h | 68 + quiche/quic/core/http/spdy_utils_test.cc | 410 + quiche/quic/core/http/web_transport_http3.cc | 474 + quiche/quic/core/http/web_transport_http3.h | 182 + .../core/http/web_transport_http3_test.cc | 52 + .../core/http/web_transport_stream_adapter.cc | 156 + .../core/http/web_transport_stream_adapter.h | 68 + .../io/event_loop_connecting_client_socket.cc | 621 + .../io/event_loop_connecting_client_socket.h | 106 + ...vent_loop_connecting_client_socket_test.cc | 700 + .../quic/core/io/event_loop_socket_factory.cc | 47 + .../quic/core/io/event_loop_socket_factory.h | 45 + .../quic/core/io/quic_all_event_loops_test.cc | 440 + .../quic/core/io/quic_default_event_loop.cc | 43 + quiche/quic/core/io/quic_default_event_loop.h | 26 + quiche/quic/core/io/quic_event_loop.h | 101 + quiche/quic/core/io/quic_poll_event_loop.cc | 263 + quiche/quic/core/io/quic_poll_event_loop.h | 166 + .../quic/core/io/quic_poll_event_loop_test.cc | 342 + quiche/quic/core/io/socket.h | 131 + quiche/quic/core/io/socket_posix.cc | 521 + quiche/quic/core/io/socket_test.cc | 197 + .../core/legacy_quic_stream_id_manager.cc | 139 + .../quic/core/legacy_quic_stream_id_manager.h | 128 + .../legacy_quic_stream_id_manager_test.cc | 178 + .../quic/core/packet_number_indexed_queue.h | 252 + .../core/packet_number_indexed_queue_test.cc | 205 + .../proto/cached_network_parameters.proto | 43 + .../proto/cached_network_parameters_proto.h | 10 + .../core/proto/crypto_server_config.proto | 34 + .../core/proto/crypto_server_config_proto.h | 10 + .../core/proto/source_address_token.proto | 32 + .../core/proto/source_address_token_proto.h | 10 + .../core/qpack/fuzzer/qpack_decoder_fuzzer.cc | 193 + .../qpack_decoder_stream_receiver_fuzzer.cc | 62 + .../qpack_decoder_stream_sender_fuzzer.cc | 54 + .../qpack_encoder_stream_receiver_fuzzer.cc | 67 + .../qpack_encoder_stream_sender_fuzzer.cc | 72 + .../qpack/fuzzer/qpack_round_trip_fuzzer.cc | 661 + .../quic/core/qpack/qpack_blocking_manager.cc | 158 + .../quic/core/qpack/qpack_blocking_manager.h | 98 + .../core/qpack/qpack_blocking_manager_test.cc | 319 + .../qpack_decoded_headers_accumulator.cc | 100 + .../qpack/qpack_decoded_headers_accumulator.h | 104 + .../qpack_decoded_headers_accumulator_test.cc | 248 + quiche/quic/core/qpack/qpack_decoder.cc | 170 + quiche/quic/core/qpack/qpack_decoder.h | 137 + .../qpack/qpack_decoder_stream_receiver.cc | 62 + .../qpack/qpack_decoder_stream_receiver.h | 69 + .../qpack_decoder_stream_receiver_test.cc | 99 + .../core/qpack/qpack_decoder_stream_sender.cc | 44 + .../core/qpack/qpack_decoder_stream_sender.h | 51 + .../qpack/qpack_decoder_stream_sender_test.cc | 101 + quiche/quic/core/qpack/qpack_decoder_test.cc | 979 + quiche/quic/core/qpack/qpack_encoder.cc | 455 + quiche/quic/core/qpack/qpack_encoder.h | 170 + .../qpack/qpack_encoder_stream_receiver.cc | 80 + .../qpack/qpack_encoder_stream_receiver.h | 73 + .../qpack_encoder_stream_receiver_test.cc | 193 + .../core/qpack/qpack_encoder_stream_sender.cc | 66 + .../core/qpack/qpack_encoder_stream_sender.h | 68 + .../qpack/qpack_encoder_stream_sender_test.cc | 179 + quiche/quic/core/qpack/qpack_encoder_test.cc | 633 + quiche/quic/core/qpack/qpack_header_table.cc | 239 + quiche/quic/core/qpack/qpack_header_table.h | 364 + .../core/qpack/qpack_header_table_test.cc | 652 + .../core/qpack/qpack_index_conversions.cc | 59 + .../quic/core/qpack/qpack_index_conversions.h | 52 + .../qpack/qpack_index_conversions_test.cc | 99 + .../core/qpack/qpack_instruction_decoder.cc | 332 + .../core/qpack/qpack_instruction_decoder.h | 160 + .../qpack/qpack_instruction_decoder_test.cc | 222 + .../core/qpack/qpack_instruction_encoder.cc | 176 + .../core/qpack/qpack_instruction_encoder.h | 83 + .../qpack/qpack_instruction_encoder_test.cc | 204 + quiche/quic/core/qpack/qpack_instructions.cc | 326 + quiche/quic/core/qpack/qpack_instructions.h | 205 + .../core/qpack/qpack_progressive_decoder.cc | 406 + .../core/qpack/qpack_progressive_decoder.h | 183 + .../quic/core/qpack/qpack_receive_stream.cc | 33 + quiche/quic/core/qpack/qpack_receive_stream.h | 41 + .../core/qpack/qpack_receive_stream_test.cc | 95 + .../core/qpack/qpack_required_insert_count.cc | 71 + .../core/qpack/qpack_required_insert_count.h | 30 + .../qpack/qpack_required_insert_count_test.cc | 125 + .../quic/core/qpack/qpack_round_trip_test.cc | 137 + quiche/quic/core/qpack/qpack_send_stream.cc | 51 + quiche/quic/core/qpack/qpack_send_stream.h | 60 + .../quic/core/qpack/qpack_send_stream_test.cc | 133 + quiche/quic/core/qpack/qpack_static_table.cc | 139 + quiche/quic/core/qpack/qpack_static_table.h | 31 + .../core/qpack/qpack_static_table_test.cc | 54 + .../quic/core/qpack/qpack_stream_receiver.h | 24 + .../core/qpack/qpack_stream_sender_delegate.h | 27 + .../core/qpack/value_splitting_header_list.cc | 108 + .../core/qpack/value_splitting_header_list.h | 62 + .../qpack/value_splitting_header_list_test.cc | 158 + .../quic/core/quic_ack_listener_interface.cc | 11 + .../quic/core/quic_ack_listener_interface.h | 37 + quiche/quic/core/quic_alarm.cc | 105 + quiche/quic/core/quic_alarm.h | 125 + quiche/quic/core/quic_alarm_factory.h | 36 + quiche/quic/core/quic_alarm_test.cc | 259 + quiche/quic/core/quic_arena_scoped_ptr.h | 208 + .../quic/core/quic_arena_scoped_ptr_test.cc | 115 + quiche/quic/core/quic_bandwidth.cc | 41 + quiche/quic/core/quic_bandwidth.h | 168 + quiche/quic/core/quic_bandwidth_test.cc | 151 + .../quic/core/quic_blocked_writer_interface.h | 29 + .../quic/core/quic_buffered_packet_store.cc | 321 + quiche/quic/core/quic_buffered_packet_store.h | 194 + .../core/quic_buffered_packet_store_test.cc | 600 + quiche/quic/core/quic_chaos_protector.cc | 225 + quiche/quic/core/quic_chaos_protector.h | 96 + quiche/quic/core/quic_chaos_protector_test.cc | 229 + quiche/quic/core/quic_clock.h | 47 + quiche/quic/core/quic_coalesced_packet.cc | 194 + quiche/quic/core/quic_coalesced_packet.h | 96 + .../quic/core/quic_coalesced_packet_test.cc | 213 + quiche/quic/core/quic_config.cc | 1434 ++ quiche/quic/core/quic_config.h | 683 + quiche/quic/core/quic_config_test.cc | 776 + quiche/quic/core/quic_connection.cc | 7409 +++++++ quiche/quic/core/quic_connection.h | 2387 +++ quiche/quic/core/quic_connection_context.cc | 48 + quiche/quic/core/quic_connection_context.h | 153 + .../quic/core/quic_connection_context_test.cc | 173 + quiche/quic/core/quic_connection_id.cc | 180 + quiche/quic/core/quic_connection_id.h | 138 + .../quic/core/quic_connection_id_manager.cc | 487 + quiche/quic/core/quic_connection_id_manager.h | 197 + .../core/quic_connection_id_manager_test.cc | 1074 + quiche/quic/core/quic_connection_id_test.cc | 181 + quiche/quic/core/quic_connection_stats.cc | 77 + quiche/quic/core/quic_connection_stats.h | 253 + quiche/quic/core/quic_connection_test.cc | 17517 ++++++++++++++++ quiche/quic/core/quic_constants.cc | 25 + quiche/quic/core/quic_constants.h | 333 + .../quic/core/quic_control_frame_manager.cc | 364 + quiche/quic/core/quic_control_frame_manager.h | 192 + .../core/quic_control_frame_manager_test.cc | 363 + .../core/quic_crypto_client_handshaker.cc | 634 + .../quic/core/quic_crypto_client_handshaker.h | 213 + .../quic_crypto_client_handshaker_test.cc | 217 + quiche/quic/core/quic_crypto_client_stream.cc | 178 + quiche/quic/core/quic_crypto_client_stream.h | 320 + .../core/quic_crypto_client_stream_test.cc | 371 + quiche/quic/core/quic_crypto_handshaker.cc | 52 + quiche/quic/core/quic_crypto_handshaker.h | 52 + quiche/quic/core/quic_crypto_server_stream.cc | 548 + quiche/quic/core/quic_crypto_server_stream.h | 271 + .../core/quic_crypto_server_stream_base.cc | 50 + .../core/quic_crypto_server_stream_base.h | 122 + .../core/quic_crypto_server_stream_test.cc | 397 + quiche/quic/core/quic_crypto_stream.cc | 518 + quiche/quic/core/quic_crypto_stream.h | 281 + quiche/quic/core/quic_crypto_stream_test.cc | 815 + quiche/quic/core/quic_data_reader.cc | 87 + quiche/quic/core/quic_data_reader.h | 69 + quiche/quic/core/quic_data_writer.cc | 105 + quiche/quic/core/quic_data_writer.h | 61 + quiche/quic/core/quic_data_writer_test.cc | 874 + quiche/quic/core/quic_datagram_queue.cc | 102 + quiche/quic/core/quic_datagram_queue.h | 95 + quiche/quic/core/quic_datagram_queue_test.cc | 297 + quiche/quic/core/quic_default_clock.cc | 26 + quiche/quic/core/quic_default_clock.h | 32 + .../core/quic_default_connection_helper.h | 49 + .../quic/core/quic_default_packet_writer.cc | 65 + quiche/quic/core/quic_default_packet_writer.h | 56 + quiche/quic/core/quic_dispatcher.cc | 1382 ++ quiche/quic/core/quic_dispatcher.h | 470 + quiche/quic/core/quic_dispatcher_test.cc | 3003 +++ quiche/quic/core/quic_error_codes.cc | 992 + quiche/quic/core/quic_error_codes.h | 776 + quiche/quic/core/quic_error_codes_test.cc | 143 + quiche/quic/core/quic_flags_list.h | 106 + quiche/quic/core/quic_flow_controller.cc | 314 + quiche/quic/core/quic_flow_controller.h | 216 + quiche/quic/core/quic_flow_controller_test.cc | 416 + quiche/quic/core/quic_framer.cc | 7306 +++++++ quiche/quic/core/quic_framer.h | 1243 ++ quiche/quic/core/quic_framer_test.cc | 16544 +++++++++++++++ .../quic/core/quic_idle_network_detector.cc | 173 + quiche/quic/core/quic_idle_network_detector.h | 130 + .../core/quic_idle_network_detector_test.cc | 282 + quiche/quic/core/quic_interval.h | 381 + quiche/quic/core/quic_interval_deque.h | 391 + quiche/quic/core/quic_interval_deque_test.cc | 361 + quiche/quic/core/quic_interval_set.h | 885 + quiche/quic/core/quic_interval_set_test.cc | 1062 + quiche/quic/core/quic_interval_test.cc | 467 + quiche/quic/core/quic_linux_socket_utils.cc | 310 + quiche/quic/core/quic_linux_socket_utils.h | 285 + .../quic/core/quic_linux_socket_utils_test.cc | 324 + quiche/quic/core/quic_lru_cache.h | 98 + quiche/quic/core/quic_lru_cache_test.cc | 81 + quiche/quic/core/quic_mtu_discovery.cc | 137 + quiche/quic/core/quic_mtu_discovery.h | 116 + .../core/quic_network_blackhole_detector.cc | 135 + .../core/quic_network_blackhole_detector.h | 91 + .../quic_network_blackhole_detector_test.cc | 139 + quiche/quic/core/quic_one_block_arena.h | 77 + quiche/quic/core/quic_one_block_arena_test.cc | 59 + quiche/quic/core/quic_packet_creator.cc | 2289 ++ quiche/quic/core/quic_packet_creator.h | 693 + quiche/quic/core/quic_packet_creator_test.cc | 4148 ++++ quiche/quic/core/quic_packet_number.cc | 109 + quiche/quic/core/quic_packet_number.h | 164 + quiche/quic/core/quic_packet_number_test.cc | 67 + quiche/quic/core/quic_packet_reader.cc | 136 + quiche/quic/core/quic_packet_reader.h | 64 + quiche/quic/core/quic_packet_writer.h | 171 + .../quic/core/quic_packet_writer_wrapper.cc | 73 + quiche/quic/core/quic_packet_writer_wrapper.h | 62 + quiche/quic/core/quic_packets.cc | 601 + quiche/quic/core/quic_packets.h | 452 + quiche/quic/core/quic_packets_test.cc | 120 + quiche/quic/core/quic_path_validator.cc | 175 + quiche/quic/core/quic_path_validator.h | 194 + quiche/quic/core/quic_path_validator_test.cc | 276 + quiche/quic/core/quic_ping_manager.cc | 163 + quiche/quic/core/quic_ping_manager.h | 108 + quiche/quic/core/quic_ping_manager_test.cc | 429 + .../quic/core/quic_process_packet_interface.h | 24 + quiche/quic/core/quic_protocol_flags_list.h | 229 + .../quic/core/quic_received_packet_manager.cc | 362 + .../quic/core/quic_received_packet_manager.h | 221 + .../core/quic_received_packet_manager_test.cc | 704 + quiche/quic/core/quic_sent_packet_manager.cc | 1468 ++ quiche/quic/core/quic_sent_packet_manager.h | 680 + .../core/quic_sent_packet_manager_test.cc | 3216 +++ quiche/quic/core/quic_server_id.cc | 108 + quiche/quic/core/quic_server_id.h | 73 + quiche/quic/core/quic_server_id_test.cc | 225 + quiche/quic/core/quic_session.cc | 2728 +++ quiche/quic/core/quic_session.h | 1036 + quiche/quic/core/quic_session_test.cc | 3318 +++ quiche/quic/core/quic_socket_address_coder.cc | 92 + quiche/quic/core/quic_socket_address_coder.h | 42 + .../core/quic_socket_address_coder_test.cc | 130 + quiche/quic/core/quic_stream.cc | 1438 ++ quiche/quic/core/quic_stream.h | 610 + .../core/quic_stream_frame_data_producer.h | 38 + quiche/quic/core/quic_stream_id_manager.cc | 238 + quiche/quic/core/quic_stream_id_manager.h | 184 + .../quic/core/quic_stream_id_manager_test.cc | 472 + quiche/quic/core/quic_stream_priority.cc | 85 + quiche/quic/core/quic_stream_priority.h | 142 + quiche/quic/core/quic_stream_priority_test.cc | 160 + quiche/quic/core/quic_stream_send_buffer.cc | 293 + quiche/quic/core/quic_stream_send_buffer.h | 171 + .../quic/core/quic_stream_send_buffer_test.cc | 345 + quiche/quic/core/quic_stream_sequencer.cc | 315 + quiche/quic/core/quic_stream_sequencer.h | 220 + .../quic/core/quic_stream_sequencer_buffer.cc | 542 + .../quic/core/quic_stream_sequencer_buffer.h | 241 + .../core/quic_stream_sequencer_buffer_test.cc | 1139 + .../quic/core/quic_stream_sequencer_test.cc | 782 + quiche/quic/core/quic_stream_test.cc | 1752 ++ .../core/quic_sustained_bandwidth_recorder.cc | 59 + .../core/quic_sustained_bandwidth_recorder.h | 92 + .../quic_sustained_bandwidth_recorder_test.cc | 133 + quiche/quic/core/quic_syscall_wrapper.cc | 47 + quiche/quic/core/quic_syscall_wrapper.h | 47 + quiche/quic/core/quic_tag.cc | 109 + quiche/quic/core/quic_tag.h | 67 + quiche/quic/core/quic_tag_test.cc | 80 + quiche/quic/core/quic_time.cc | 81 + quiche/quic/core/quic_time.h | 295 + quiche/quic/core/quic_time_accumulator.h | 69 + .../quic/core/quic_time_accumulator_test.cc | 83 + quiche/quic/core/quic_time_test.cc | 186 + .../quic/core/quic_time_wait_list_manager.cc | 486 + .../quic/core/quic_time_wait_list_manager.h | 331 + .../core/quic_time_wait_list_manager_test.cc | 781 + quiche/quic/core/quic_trace_visitor.cc | 341 + quiche/quic/core/quic_trace_visitor.h | 75 + quiche/quic/core/quic_trace_visitor_test.cc | 184 + quiche/quic/core/quic_transmission_info.cc | 56 + quiche/quic/core/quic_transmission_info.h | 66 + quiche/quic/core/quic_types.cc | 465 + quiche/quic/core/quic_types.h | 928 + quiche/quic/core/quic_udp_socket.h | 270 + quiche/quic/core/quic_udp_socket_posix.cc | 711 + quiche/quic/core/quic_unacked_packet_map.cc | 652 + quiche/quic/core/quic_unacked_packet_map.h | 336 + .../quic/core/quic_unacked_packet_map_test.cc | 722 + quiche/quic/core/quic_utils.cc | 630 + quiche/quic/core/quic_utils.h | 290 + quiche/quic/core/quic_utils_test.cc | 320 + quiche/quic/core/quic_version_manager.cc | 94 + quiche/quic/core/quic_version_manager.h | 95 + quiche/quic/core/quic_version_manager_test.cc | 83 + quiche/quic/core/quic_versions.cc | 664 + quiche/quic/core/quic_versions.h | 649 + quiche/quic/core/quic_versions_test.cc | 523 + quiche/quic/core/quic_write_blocked_list.cc | 212 + quiche/quic/core/quic_write_blocked_list.h | 220 + .../quic/core/quic_write_blocked_list_test.cc | 678 + quiche/quic/core/session_notifier_interface.h | 48 + quiche/quic/core/socket_factory.h | 47 + quiche/quic/core/stream_delegate_interface.h | 55 + quiche/quic/core/tls_chlo_extractor.cc | 429 + quiche/quic/core/tls_chlo_extractor.h | 280 + quiche/quic/core/tls_chlo_extractor_test.cc | 291 + quiche/quic/core/tls_client_handshaker.cc | 665 + quiche/quic/core/tls_client_handshaker.h | 175 + .../quic/core/tls_client_handshaker_test.cc | 863 + quiche/quic/core/tls_handshaker.cc | 406 + quiche/quic/core/tls_handshaker.h | 230 + quiche/quic/core/tls_server_handshaker.cc | 1185 ++ quiche/quic/core/tls_server_handshaker.h | 386 + .../quic/core/tls_server_handshaker_test.cc | 1168 ++ .../quic/core/uber_quic_stream_id_manager.cc | 170 + .../quic/core/uber_quic_stream_id_manager.h | 106 + .../core/uber_quic_stream_id_manager_test.cc | 332 + .../quic/core/uber_received_packet_manager.cc | 246 + .../quic/core/uber_received_packet_manager.h | 112 + .../core/uber_received_packet_manager_test.cc | 568 + quiche/quic/core/web_transport_interface.h | 53 + .../load_balancer/load_balancer_config.cc | 202 + .../quic/load_balancer/load_balancer_config.h | 94 + .../load_balancer_config_test.cc | 190 + .../load_balancer/load_balancer_decoder.cc | 90 + .../load_balancer/load_balancer_decoder.h | 59 + .../load_balancer_decoder_test.cc | 242 + .../load_balancer/load_balancer_encoder.cc | 203 + .../load_balancer/load_balancer_encoder.h | 156 + .../load_balancer_encoder_test.cc | 451 + .../load_balancer/load_balancer_server_id.cc | 45 + .../load_balancer/load_balancer_server_id.h | 71 + .../load_balancer_server_id_map.h | 104 + .../load_balancer_server_id_map_test.cc | 94 + .../load_balancer_server_id_test.cc | 106 + quiche/quic/masque/README.md | 4 + quiche/quic/masque/masque_client.cc | 106 + quiche/quic/masque/masque_client.h | 67 + quiche/quic/masque/masque_client_bin.cc | 258 + quiche/quic/masque/masque_client_session.cc | 524 + quiche/quic/masque/masque_client_session.h | 238 + quiche/quic/masque/masque_client_tools.cc | 151 + quiche/quic/masque/masque_client_tools.h | 27 + quiche/quic/masque/masque_dispatcher.cc | 49 + quiche/quic/masque/masque_dispatcher.h | 52 + .../quic/masque/masque_encapsulated_client.cc | 262 + .../quic/masque/masque_encapsulated_client.h | 47 + .../masque_encapsulated_client_session.cc | 255 + .../masque_encapsulated_client_session.h | 78 + quiche/quic/masque/masque_server.cc | 31 + quiche/quic/masque/masque_server.h | 35 + quiche/quic/masque/masque_server_backend.cc | 153 + quiche/quic/masque/masque_server_backend.h | 80 + quiche/quic/masque/masque_server_bin.cc | 71 + quiche/quic/masque/masque_server_session.cc | 642 + quiche/quic/masque/masque_server_session.h | 155 + quiche/quic/masque/masque_utils.cc | 150 + quiche/quic/masque/masque_utils.h | 48 + quiche/quic/platform/README.md | 12 + quiche/quic/platform/api/README.md | 72 + quiche/quic/platform/api/quic_bug_tracker.h | 15 + quiche/quic/platform/api/quic_client_stats.h | 87 + .../api/quic_default_proof_providers.h | 33 + quiche/quic/platform/api/quic_expect_bug.h | 14 + quiche/quic/platform/api/quic_export.h | 21 + .../quic/platform/api/quic_exported_stats.h | 96 + quiche/quic/platform/api/quic_flag_utils.h | 19 + quiche/quic/platform/api/quic_flags.h | 21 + .../quic/platform/api/quic_hostname_utils.h | 16 + quiche/quic/platform/api/quic_ip_address.h | 16 + .../platform/api/quic_ip_address_family.h | 16 + quiche/quic/platform/api/quic_logging.h | 32 + quiche/quic/platform/api/quic_mutex.h | 32 + quiche/quic/platform/api/quic_server_stats.h | 25 + .../quic/platform/api/quic_socket_address.cc | 152 + .../quic/platform/api/quic_socket_address.h | 67 + .../platform/api/quic_socket_address_test.cc | 134 + quiche/quic/platform/api/quic_stack_trace.h | 18 + quiche/quic/platform/api/quic_test.h | 26 + quiche/quic/platform/api/quic_test_loopback.h | 38 + quiche/quic/platform/api/quic_test_output.h | 28 + quiche/quic/platform/api/quic_testvalue.h | 22 + quiche/quic/platform/api/quic_thread.h | 16 + .../api/quic_udp_socket_platform_api.h | 27 + quiche/quic/qbone/bonnet/icmp_reachable.cc | 209 + quiche/quic/qbone/bonnet/icmp_reachable.h | 146 + .../qbone/bonnet/icmp_reachable_interface.h | 27 + .../quic/qbone/bonnet/icmp_reachable_test.cc | 261 + .../quic/qbone/bonnet/mock_icmp_reachable.h | 20 + .../mock_packet_exchanger_stats_interface.h | 27 + quiche/quic/qbone/bonnet/mock_qbone_tunnel.h | 45 + quiche/quic/qbone/bonnet/mock_tun_device.h | 28 + .../qbone/bonnet/mock_tun_device_controller.h | 27 + quiche/quic/qbone/bonnet/qbone_tunnel_info.cc | 37 + quiche/quic/qbone/bonnet/qbone_tunnel_info.h | 29 + .../qbone/bonnet/qbone_tunnel_interface.h | 70 + quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc | 31 + quiche/quic/qbone/bonnet/qbone_tunnel_silo.h | 48 + .../qbone/bonnet/qbone_tunnel_silo_test.cc | 78 + quiche/quic/qbone/bonnet/tun_device.cc | 217 + quiche/quic/qbone/bonnet/tun_device.h | 82 + .../qbone/bonnet/tun_device_controller.cc | 176 + .../quic/qbone/bonnet/tun_device_controller.h | 73 + .../bonnet/tun_device_controller_test.cc | 263 + .../quic/qbone/bonnet/tun_device_interface.h | 38 + .../bonnet/tun_device_packet_exchanger.cc | 230 + .../bonnet/tun_device_packet_exchanger.h | 86 + .../tun_device_packet_exchanger_test.cc | 119 + quiche/quic/qbone/bonnet/tun_device_test.cc | 211 + quiche/quic/qbone/mock_qbone_client.h | 22 + quiche/quic/qbone/mock_qbone_server_session.h | 35 + quiche/quic/qbone/platform/icmp_packet.cc | 88 + quiche/quic/qbone/platform/icmp_packet.h | 27 + .../quic/qbone/platform/icmp_packet_test.cc | 128 + .../quic/qbone/platform/internet_checksum.cc | 36 + .../quic/qbone/platform/internet_checksum.h | 32 + .../qbone/platform/internet_checksum_test.cc | 67 + quiche/quic/qbone/platform/ip_range.cc | 97 + quiche/quic/qbone/platform/ip_range.h | 61 + quiche/quic/qbone/platform/ip_range_test.cc | 65 + quiche/quic/qbone/platform/kernel_interface.h | 149 + quiche/quic/qbone/platform/mock_kernel.h | 41 + quiche/quic/qbone/platform/mock_netlink.h | 43 + quiche/quic/qbone/platform/netlink.cc | 856 + quiche/quic/qbone/platform/netlink.h | 138 + .../quic/qbone/platform/netlink_interface.h | 144 + quiche/quic/qbone/platform/netlink_test.cc | 788 + .../quic/qbone/platform/rtnetlink_message.cc | 162 + .../quic/qbone/platform/rtnetlink_message.h | 112 + .../qbone/platform/rtnetlink_message_test.cc | 229 + quiche/quic/qbone/platform/tcp_packet.cc | 127 + quiche/quic/qbone/platform/tcp_packet.h | 25 + quiche/quic/qbone/platform/tcp_packet_test.cc | 117 + quiche/quic/qbone/qbone_client.cc | 100 + quiche/quic/qbone/qbone_client.h | 74 + quiche/quic/qbone/qbone_client_interface.h | 25 + quiche/quic/qbone/qbone_client_session.cc | 124 + quiche/quic/qbone/qbone_client_session.h | 93 + quiche/quic/qbone/qbone_client_test.cc | 264 + quiche/quic/qbone/qbone_constants.cc | 45 + quiche/quic/qbone/qbone_constants.h | 35 + quiche/quic/qbone/qbone_control.proto | 13 + .../qbone/qbone_control_placeholder.proto | 20 + quiche/quic/qbone/qbone_control_stream.cc | 85 + quiche/quic/qbone/qbone_control_stream.h | 82 + quiche/quic/qbone/qbone_packet_exchanger.cc | 73 + quiche/quic/qbone/qbone_packet_exchanger.h | 78 + .../quic/qbone/qbone_packet_exchanger_test.cc | 269 + quiche/quic/qbone/qbone_packet_processor.cc | 291 + quiche/quic/qbone/qbone_packet_processor.h | 200 + .../quic/qbone/qbone_packet_processor_test.cc | 388 + .../qbone_packet_processor_test_tools.cc | 40 + .../qbone/qbone_packet_processor_test_tools.h | 46 + quiche/quic/qbone/qbone_packet_writer.h | 24 + quiche/quic/qbone/qbone_server_session.cc | 113 + quiche/quic/qbone/qbone_server_session.h | 101 + quiche/quic/qbone/qbone_session_base.cc | 213 + quiche/quic/qbone/qbone_session_base.h | 111 + quiche/quic/qbone/qbone_session_test.cc | 636 + quiche/quic/qbone/qbone_stream.cc | 62 + quiche/quic/qbone/qbone_stream.h | 56 + quiche/quic/qbone/qbone_stream_test.cc | 262 + quiche/quic/test_tools/bad_packet_writer.cc | 35 + quiche/quic/test_tools/bad_packet_writer.h | 35 + quiche/quic/test_tools/crypto_test_utils.cc | 979 + quiche/quic/test_tools/crypto_test_utils.h | 222 + .../quic/test_tools/crypto_test_utils_test.cc | 187 + .../quic/test_tools/failing_proof_source.cc | 40 + quiche/quic/test_tools/failing_proof_source.h | 45 + quiche/quic/test_tools/fake_proof_source.cc | 146 + quiche/quic/test_tools/fake_proof_source.h | 129 + .../test_tools/fake_proof_source_handle.cc | 235 + .../test_tools/fake_proof_source_handle.h | 198 + quiche/quic/test_tools/first_flight.cc | 191 + quiche/quic/test_tools/first_flight.h | 133 + quiche/quic/test_tools/fuzzing/README.md | 22 + .../test_tools/fuzzing/quic_framer_fuzzer.cc | 30 + .../quic_framer_process_data_packet_fuzzer.cc | 285 + .../test_tools/limited_mtu_test_writer.cc | 27 + .../quic/test_tools/limited_mtu_test_writer.h | 36 + quiche/quic/test_tools/mock_clock.cc | 25 + quiche/quic/test_tools/mock_clock.h | 36 + .../test_tools/mock_connection_id_generator.h | 31 + .../mock_quic_client_promised_info.cc | 17 + .../mock_quic_client_promised_info.h | 33 + .../quic/test_tools/mock_quic_dispatcher.cc | 28 + quiche/quic/test_tools/mock_quic_dispatcher.h | 43 + .../test_tools/mock_quic_session_visitor.cc | 19 + .../test_tools/mock_quic_session_visitor.h | 62 + .../mock_quic_spdy_client_stream.cc | 17 + .../test_tools/mock_quic_spdy_client_stream.h | 33 + .../mock_quic_time_wait_list_manager.cc | 30 + .../mock_quic_time_wait_list_manager.h | 63 + quiche/quic/test_tools/mock_random.cc | 48 + quiche/quic/test_tools/mock_random.h | 57 + .../test_tools/packet_dropping_test_writer.cc | 252 + .../test_tools/packet_dropping_test_writer.h | 185 + .../test_tools/packet_reordering_writer.cc | 51 + .../test_tools/packet_reordering_writer.h | 44 + .../qpack/qpack_decoder_test_utils.cc | 85 + .../qpack/qpack_decoder_test_utils.h | 101 + .../test_tools/qpack/qpack_encoder_peer.cc | 30 + .../test_tools/qpack/qpack_encoder_peer.h | 30 + .../test_tools/qpack/qpack_offline_decoder.cc | 336 + .../test_tools/qpack/qpack_offline_decoder.h | 88 + .../quic/test_tools/qpack/qpack_test_utils.cc | 28 + .../quic/test_tools/qpack/qpack_test_utils.h | 51 + .../quic_buffered_packet_store_peer.cc | 25 + .../quic_buffered_packet_store_peer.h | 32 + .../quic_client_promised_info_peer.cc | 17 + .../quic_client_promised_info_peer.h | 22 + .../quic_client_session_cache_peer.h | 33 + .../test_tools/quic_coalesced_packet_peer.cc | 23 + .../test_tools/quic_coalesced_packet_peer.h | 26 + quiche/quic/test_tools/quic_config_peer.cc | 156 + quiche/quic/test_tools/quic_config_peer.h | 90 + .../quic_connection_id_manager_peer.h | 29 + .../quic/test_tools/quic_connection_peer.cc | 628 + quiche/quic/test_tools/quic_connection_peer.h | 253 + .../quic_crypto_server_config_peer.cc | 152 + .../quic_crypto_server_config_peer.h | 88 + .../quic/test_tools/quic_dispatcher_peer.cc | 136 + quiche/quic/test_tools/quic_dispatcher_peer.h | 82 + .../test_tools/quic_flow_controller_peer.cc | 65 + .../test_tools/quic_flow_controller_peer.h | 45 + quiche/quic/test_tools/quic_framer_peer.cc | 119 + quiche/quic/test_tools/quic_framer_peer.h | 69 + .../test.example.com/index.html | 63 + .../test.example.com/map.html | 65 + .../test_tools/quic_interval_deque_peer.h | 35 + .../test_tools/quic_mock_syscall_wrapper.cc | 22 + .../test_tools/quic_mock_syscall_wrapper.h | 33 + .../test_tools/quic_packet_creator_peer.cc | 161 + .../test_tools/quic_packet_creator_peer.h | 64 + .../test_tools/quic_path_validator_peer.cc | 15 + .../test_tools/quic_path_validator_peer.h | 20 + .../quic_sent_packet_manager_peer.cc | 186 + .../quic_sent_packet_manager_peer.h | 97 + quiche/quic/test_tools/quic_server_peer.cc | 32 + quiche/quic/test_tools/quic_server_peer.h | 28 + .../quic_server_session_base_peer.h | 33 + quiche/quic/test_tools/quic_session_peer.cc | 246 + quiche/quic/test_tools/quic_session_peer.h | 96 + .../quic/test_tools/quic_spdy_session_peer.cc | 119 + .../quic/test_tools/quic_spdy_session_peer.h | 62 + .../quic/test_tools/quic_spdy_stream_peer.cc | 33 + .../quic/test_tools/quic_spdy_stream_peer.h | 33 + .../test_tools/quic_stream_id_manager_peer.cc | 40 + .../test_tools/quic_stream_id_manager_peer.h | 38 + quiche/quic/test_tools/quic_stream_peer.cc | 119 + quiche/quic/test_tools/quic_stream_peer.h | 55 + .../quic_stream_send_buffer_peer.cc | 54 + .../test_tools/quic_stream_send_buffer_peer.h | 33 + .../quic_stream_sequencer_buffer_peer.cc | 163 + .../quic_stream_sequencer_buffer_peer.h | 65 + .../test_tools/quic_stream_sequencer_peer.cc | 39 + .../test_tools/quic_stream_sequencer_peer.h | 33 + .../quic_sustained_bandwidth_recorder_peer.cc | 34 + .../quic_sustained_bandwidth_recorder_peer.h | 35 + quiche/quic/test_tools/quic_test_backend.cc | 120 + quiche/quic/test_tools/quic_test_backend.h | 44 + quiche/quic/test_tools/quic_test_client.cc | 932 + quiche/quic/test_tools/quic_test_client.h | 444 + quiche/quic/test_tools/quic_test_server.cc | 260 + quiche/quic/test_tools/quic_test_server.h | 126 + quiche/quic/test_tools/quic_test_utils.cc | 1515 ++ quiche/quic/test_tools/quic_test_utils.h | 2197 ++ .../quic/test_tools/quic_test_utils_test.cc | 79 + .../quic_time_wait_list_manager_peer.cc | 45 + .../quic_time_wait_list_manager_peer.h | 36 + .../quic_unacked_packet_map_peer.cc | 29 + .../test_tools/quic_unacked_packet_map_peer.h | 27 + quiche/quic/test_tools/rtt_stats_peer.cc | 21 + quiche/quic/test_tools/rtt_stats_peer.h | 26 + .../send_algorithm_test_result.proto | 15 + .../test_tools/send_algorithm_test_utils.cc | 61 + .../test_tools/send_algorithm_test_utils.h | 29 + quiche/quic/test_tools/server_thread.cc | 143 + quiche/quic/test_tools/server_thread.h | 96 + .../quic/test_tools/simple_data_producer.cc | 67 + quiche/quic/test_tools/simple_data_producer.h | 73 + quiche/quic/test_tools/simple_quic_framer.cc | 439 + quiche/quic/test_tools/simple_quic_framer.h | 70 + .../quic/test_tools/simple_session_cache.cc | 78 + quiche/quic/test_tools/simple_session_cache.h | 53 + .../test_tools/simple_session_notifier.cc | 768 + .../quic/test_tools/simple_session_notifier.h | 167 + .../simple_session_notifier_test.cc | 367 + quiche/quic/test_tools/simulator/README.md | 99 + quiche/quic/test_tools/simulator/actor.cc | 28 + quiche/quic/test_tools/simulator/actor.h | 66 + .../test_tools/simulator/alarm_factory.cc | 80 + .../quic/test_tools/simulator/alarm_factory.h | 39 + quiche/quic/test_tools/simulator/link.cc | 115 + quiche/quic/test_tools/simulator/link.h | 97 + .../test_tools/simulator/packet_filter.cc | 39 + .../quic/test_tools/simulator/packet_filter.h | 75 + quiche/quic/test_tools/simulator/port.cc | 21 + quiche/quic/test_tools/simulator/port.h | 66 + quiche/quic/test_tools/simulator/queue.cc | 127 + quiche/quic/test_tools/simulator/queue.h | 119 + .../test_tools/simulator/quic_endpoint.cc | 246 + .../quic/test_tools/simulator/quic_endpoint.h | 171 + .../simulator/quic_endpoint_base.cc | 206 + .../test_tools/simulator/quic_endpoint_base.h | 160 + .../simulator/quic_endpoint_test.cc | 207 + quiche/quic/test_tools/simulator/simulator.cc | 160 + quiche/quic/test_tools/simulator/simulator.h | 166 + .../test_tools/simulator/simulator_test.cc | 827 + quiche/quic/test_tools/simulator/switch.cc | 77 + quiche/quic/test_tools/simulator/switch.h | 84 + .../quic/test_tools/simulator/test_harness.cc | 35 + .../quic/test_tools/simulator/test_harness.h | 83 + .../test_tools/simulator/traffic_policer.cc | 58 + .../test_tools/simulator/traffic_policer.h | 52 + quiche/quic/test_tools/test_certificates.cc | 731 + quiche/quic/test_tools/test_certificates.h | 50 + quiche/quic/test_tools/test_ticket_crypter.cc | 84 + quiche/quic/test_tools/test_ticket_crypter.h | 54 + .../web_transport_resets_backend.cc | 113 + .../test_tools/web_transport_resets_backend.h | 24 + .../test_tools/web_transport_test_tools.h | 43 + quiche/quic/tools/connect_server_backend.cc | 164 + quiche/quic/tools/connect_server_backend.h | 70 + quiche/quic/tools/connect_tunnel.cc | 295 + quiche/quic/tools/connect_tunnel.h | 89 + quiche/quic/tools/connect_tunnel_test.cc | 353 + quiche/quic/tools/connect_udp_tunnel.cc | 424 + quiche/quic/tools/connect_udp_tunnel.h | 99 + quiche/quic/tools/connect_udp_tunnel_test.cc | 362 + .../quic/tools/crypto_message_printer_bin.cc | 61 + quiche/quic/tools/fake_proof_verifier.h | 44 + .../quic/tools/qpack_offline_decoder_bin.cc | 46 + quiche/quic/tools/quic_backend_response.cc | 28 + quiche/quic/tools/quic_backend_response.h | 98 + quiche/quic/tools/quic_client_base.cc | 545 + quiche/quic/tools/quic_client_base.h | 479 + quiche/quic/tools/quic_client_bin.cc | 67 + .../quic_client_default_network_helper.cc | 258 + .../quic_client_default_network_helper.h | 133 + quiche/quic/tools/quic_client_factory.h | 35 + .../tools/quic_client_interop_test_bin.cc | 462 + quiche/quic/tools/quic_default_client.cc | 103 + quiche/quic/tools/quic_default_client.h | 87 + quiche/quic/tools/quic_default_client_test.cc | 146 + .../quic/tools/quic_epoll_client_factory.cc | 40 + quiche/quic/tools/quic_epoll_client_factory.h | 33 + .../quic/tools/quic_memory_cache_backend.cc | 507 + quiche/quic/tools/quic_memory_cache_backend.h | 210 + .../tools/quic_memory_cache_backend_test.cc | 264 + quiche/quic/tools/quic_name_lookup.cc | 54 + quiche/quic/tools/quic_name_lookup.h | 35 + quiche/quic/tools/quic_packet_printer_bin.cc | 287 + .../tools/quic_reject_reason_decoder_bin.cc | 45 + quiche/quic/tools/quic_server.cc | 231 + quiche/quic/tools/quic_server.h | 174 + quiche/quic/tools/quic_server_bin.cc | 29 + quiche/quic/tools/quic_server_factory.cc | 21 + quiche/quic/tools/quic_server_factory.h | 23 + quiche/quic/tools/quic_server_test.cc | 228 + .../quic/tools/quic_simple_client_session.cc | 82 + .../quic/tools/quic_simple_client_session.h | 52 + .../quic/tools/quic_simple_client_stream.cc | 29 + quiche/quic/tools/quic_simple_client_stream.h | 26 + ...quic_simple_crypto_server_stream_helper.cc | 26 + .../quic_simple_crypto_server_stream_helper.h | 31 + quiche/quic/tools/quic_simple_dispatcher.cc | 66 + quiche/quic/tools/quic_simple_dispatcher.h | 53 + .../quic/tools/quic_simple_server_backend.h | 123 + .../quic/tools/quic_simple_server_session.cc | 106 + .../quic/tools/quic_simple_server_session.h | 87 + .../tools/quic_simple_server_session_test.cc | 465 + .../quic/tools/quic_simple_server_stream.cc | 503 + quiche/quic/tools/quic_simple_server_stream.h | 126 + .../tools/quic_simple_server_stream_test.cc | 912 + quiche/quic/tools/quic_spdy_client_base.cc | 282 + quiche/quic/tools/quic_spdy_client_base.h | 237 + quiche/quic/tools/quic_spdy_server_base.h | 30 + .../tools/quic_tcp_like_trace_converter.cc | 118 + .../tools/quic_tcp_like_trace_converter.h | 85 + .../quic_tcp_like_trace_converter_test.cc | 124 + quiche/quic/tools/quic_toy_client.cc | 555 + quiche/quic/tools/quic_toy_client.h | 35 + quiche/quic/tools/quic_toy_server.cc | 174 + quiche/quic/tools/quic_toy_server.h | 63 + quiche/quic/tools/quic_url.cc | 101 + quiche/quic/tools/quic_url.h | 61 + quiche/quic/tools/quic_url_test.cc | 156 + quiche/quic/tools/simple_ticket_crypter.cc | 112 + quiche/quic/tools/simple_ticket_crypter.h | 56 + .../quic/tools/simple_ticket_crypter_test.cc | 111 + .../quic/tools/web_transport_test_visitors.h | 270 + quiche/spdy/core/array_output_buffer.cc | 21 + quiche/spdy/core/array_output_buffer.h | 47 + quiche/spdy/core/array_output_buffer_test.cc | 49 + .../core/header_byte_listener_interface.h | 22 + quiche/spdy/core/hpack/hpack_constants.cc | 374 + quiche/spdy/core/hpack/hpack_constants.h | 88 + .../spdy/core/hpack/hpack_decoder_adapter.cc | 164 + .../spdy/core/hpack/hpack_decoder_adapter.h | 156 + .../core/hpack/hpack_decoder_adapter_test.cc | 1119 + quiche/spdy/core/hpack/hpack_encoder.cc | 375 + quiche/spdy/core/hpack/hpack_encoder.h | 147 + quiche/spdy/core/hpack/hpack_encoder_test.cc | 754 + quiche/spdy/core/hpack/hpack_entry.cc | 24 + quiche/spdy/core/hpack/hpack_entry.h | 81 + quiche/spdy/core/hpack/hpack_entry_test.cc | 53 + quiche/spdy/core/hpack/hpack_header_table.cc | 188 + quiche/spdy/core/hpack/hpack_header_table.h | 153 + .../core/hpack/hpack_header_table_test.cc | 392 + quiche/spdy/core/hpack/hpack_output_stream.cc | 100 + quiche/spdy/core/hpack/hpack_output_stream.h | 75 + .../core/hpack/hpack_output_stream_test.cc | 284 + .../spdy/core/hpack/hpack_round_trip_test.cc | 224 + quiche/spdy/core/hpack/hpack_static_table.cc | 50 + quiche/spdy/core/hpack/hpack_static_table.h | 56 + .../core/hpack/hpack_static_table_test.cc | 63 + .../spdy/core/http2_frame_decoder_adapter.cc | 1111 + .../spdy/core/http2_frame_decoder_adapter.h | 564 + quiche/spdy/core/http2_header_block.cc | 315 + quiche/spdy/core/http2_header_block.h | 291 + .../core/http2_header_block_hpack_listener.h | 49 + quiche/spdy/core/http2_header_block_test.cc | 295 + quiche/spdy/core/http2_header_storage.cc | 59 + quiche/spdy/core/http2_header_storage.h | 58 + quiche/spdy/core/http2_header_storage_test.cc | 35 + quiche/spdy/core/metadata_extension.cc | 176 + quiche/spdy/core/metadata_extension.h | 122 + quiche/spdy/core/metadata_extension_test.cc | 281 + quiche/spdy/core/no_op_headers_handler.h | 38 + quiche/spdy/core/recording_headers_handler.cc | 38 + quiche/spdy/core/recording_headers_handler.h | 51 + quiche/spdy/core/spdy_alt_svc_wire_format.cc | 420 + quiche/spdy/core/spdy_alt_svc_wire_format.h | 104 + .../core/spdy_alt_svc_wire_format_test.cc | 636 + quiche/spdy/core/spdy_bitmasks.h | 18 + quiche/spdy/core/spdy_frame_builder.cc | 182 + quiche/spdy/core/spdy_frame_builder.h | 140 + quiche/spdy/core/spdy_frame_builder_test.cc | 86 + quiche/spdy/core/spdy_framer.cc | 1365 ++ quiche/spdy/core/spdy_framer.h | 376 + quiche/spdy/core/spdy_framer_test.cc | 5089 +++++ .../core/spdy_headers_handler_interface.h | 39 + quiche/spdy/core/spdy_intrusive_list.h | 341 + quiche/spdy/core/spdy_intrusive_list_test.cc | 420 + quiche/spdy/core/spdy_no_op_visitor.cc | 27 + quiche/spdy/core/spdy_no_op_visitor.h | 91 + .../spdy/core/spdy_pinnable_buffer_piece.cc | 36 + quiche/spdy/core/spdy_pinnable_buffer_piece.h | 53 + .../core/spdy_pinnable_buffer_piece_test.cc | 80 + .../spdy/core/spdy_prefixed_buffer_reader.cc | 84 + .../spdy/core/spdy_prefixed_buffer_reader.h | 43 + .../core/spdy_prefixed_buffer_reader_test.cc | 131 + quiche/spdy/core/spdy_protocol.cc | 616 + quiche/spdy/core/spdy_protocol.h | 1126 + quiche/spdy/core/spdy_protocol_test.cc | 275 + quiche/spdy/core/spdy_simple_arena.cc | 106 + quiche/spdy/core/spdy_simple_arena.h | 77 + quiche/spdy/core/spdy_simple_arena_test.cc | 141 + quiche/spdy/core/zero_copy_output_buffer.h | 32 + .../test_tools/mock_spdy_framer_visitor.cc | 17 + .../test_tools/mock_spdy_framer_visitor.h | 125 + quiche/spdy/test_tools/spdy_test_utils.cc | 101 + quiche/spdy/test_tools/spdy_test_utils.h | 41 + .../test_tools/mock_web_transport.h | 80 + quiche/web_transport/web_transport.h | 221 + 1565 files changed, 415866 insertions(+) create mode 100644 .bazelrc create mode 100644 BUILD.bazel create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 WHITESPACE create mode 100644 WORKSPACE.bazel create mode 100644 build/BUILD.bazel create mode 100644 build/source_list.bzl create mode 100644 build/source_list.gni create mode 100644 build/source_list.json create mode 100644 build/test.bzl create mode 100644 build/zlib.BUILD create mode 100644 depstool/deps/parse.go create mode 100644 depstool/deps/parse_test.go create mode 100644 depstool/depstool.go create mode 100644 depstool/go.mod create mode 100644 depstool/go.sum create mode 100644 quiche/BUILD.bazel create mode 100644 quiche/balsa/balsa_enums.cc create mode 100644 quiche/balsa/balsa_enums.h create mode 100644 quiche/balsa/balsa_frame.cc create mode 100644 quiche/balsa/balsa_frame.h create mode 100644 quiche/balsa/balsa_frame_test.cc create mode 100644 quiche/balsa/balsa_headers.cc create mode 100644 quiche/balsa/balsa_headers.h create mode 100644 quiche/balsa/balsa_headers_test.cc create mode 100644 quiche/balsa/balsa_visitor_interface.h create mode 100644 quiche/balsa/framer_interface.h create mode 100644 quiche/balsa/header_api.h create mode 100644 quiche/balsa/header_properties.cc create mode 100644 quiche/balsa/header_properties.h create mode 100644 quiche/balsa/header_properties_test.cc create mode 100644 quiche/balsa/http_validation_policy.h create mode 100644 quiche/balsa/noop_balsa_visitor.h create mode 100644 quiche/balsa/simple_buffer.cc create mode 100644 quiche/balsa/simple_buffer.h create mode 100644 quiche/balsa/simple_buffer_test.cc create mode 100644 quiche/balsa/standard_header_map.cc create mode 100644 quiche/balsa/standard_header_map.h create mode 100644 quiche/binary_http/binary_http_message.cc create mode 100644 quiche/binary_http/binary_http_message.h create mode 100644 quiche/binary_http/binary_http_message_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc create mode 100644 quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h create mode 100644 quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto create mode 100644 quiche/blind_sign_auth/blind_sign_auth.cc create mode 100644 quiche/blind_sign_auth/blind_sign_auth.h create mode 100644 quiche/blind_sign_auth/blind_sign_auth_interface.h create mode 100644 quiche/blind_sign_auth/blind_sign_auth_test.cc create mode 100644 quiche/blind_sign_auth/blind_sign_http_interface.h create mode 100644 quiche/blind_sign_auth/blind_sign_http_response.h create mode 100644 quiche/blind_sign_auth/cached_blind_sign_auth.cc create mode 100644 quiche/blind_sign_auth/cached_blind_sign_auth.h create mode 100644 quiche/blind_sign_auth/cached_blind_sign_auth_test.cc create mode 100644 quiche/blind_sign_auth/proto/any.proto create mode 100644 quiche/blind_sign_auth/proto/attestation.proto create mode 100644 quiche/blind_sign_auth/proto/auth_and_sign.proto create mode 100644 quiche/blind_sign_auth/proto/get_initial_data.proto create mode 100644 quiche/blind_sign_auth/proto/key_services.proto create mode 100644 quiche/blind_sign_auth/proto/public_metadata.proto create mode 100644 quiche/blind_sign_auth/proto/spend_token_data.proto create mode 100644 quiche/blind_sign_auth/proto/timestamp.proto create mode 100644 quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h create mode 100644 quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h create mode 100644 quiche/common/anonymous_tokens/testdata/strong_rsa_modulus2048_example.binarypb create mode 100644 quiche/common/anonymous_tokens/testdata/strong_rsa_modulus2048_example_2.binarypb create mode 100644 quiche/common/anonymous_tokens/testdata/strong_rsa_modulus3072_example.binarypb create mode 100644 quiche/common/anonymous_tokens/testdata/strong_rsa_modulus4096_example.binarypb create mode 100644 quiche/common/btree_scheduler.h create mode 100644 quiche/common/btree_scheduler_test.cc create mode 100644 quiche/common/capsule.cc create mode 100644 quiche/common/capsule.h create mode 100644 quiche/common/capsule_test.cc create mode 100644 quiche/common/masque/connect_udp_datagram_payload.cc create mode 100644 quiche/common/masque/connect_udp_datagram_payload.h create mode 100644 quiche/common/masque/connect_udp_datagram_payload_test.cc create mode 100644 quiche/common/platform/api/quiche_bug_tracker.h create mode 100644 quiche/common/platform/api/quiche_client_stats.h create mode 100644 quiche/common/platform/api/quiche_command_line_flags.h create mode 100644 quiche/common/platform/api/quiche_containers.h create mode 100644 quiche/common/platform/api/quiche_default_proof_providers.h create mode 100644 quiche/common/platform/api/quiche_event_loop.h create mode 100644 quiche/common/platform/api/quiche_expect_bug.h create mode 100644 quiche/common/platform/api/quiche_export.h create mode 100644 quiche/common/platform/api/quiche_file_utils.cc create mode 100644 quiche/common/platform/api/quiche_file_utils.h create mode 100644 quiche/common/platform/api/quiche_file_utils_test.cc create mode 100644 quiche/common/platform/api/quiche_flag_utils.h create mode 100644 quiche/common/platform/api/quiche_flags.h create mode 100644 quiche/common/platform/api/quiche_header_policy.h create mode 100644 quiche/common/platform/api/quiche_hostname_utils.cc create mode 100644 quiche/common/platform/api/quiche_hostname_utils.h create mode 100644 quiche/common/platform/api/quiche_hostname_utils_test.cc create mode 100644 quiche/common/platform/api/quiche_iovec.h create mode 100644 quiche/common/platform/api/quiche_logging.h create mode 100644 quiche/common/platform/api/quiche_lower_case_string.h create mode 100644 quiche/common/platform/api/quiche_lower_case_string_test.cc create mode 100644 quiche/common/platform/api/quiche_mem_slice.h create mode 100644 quiche/common/platform/api/quiche_mem_slice_test.cc create mode 100644 quiche/common/platform/api/quiche_mutex.cc create mode 100644 quiche/common/platform/api/quiche_mutex.h create mode 100644 quiche/common/platform/api/quiche_prefetch.h create mode 100644 quiche/common/platform/api/quiche_reference_counted.h create mode 100644 quiche/common/platform/api/quiche_reference_counted_test.cc create mode 100644 quiche/common/platform/api/quiche_server_stats.h create mode 100644 quiche/common/platform/api/quiche_stack_trace.h create mode 100644 quiche/common/platform/api/quiche_stack_trace_test.cc create mode 100644 quiche/common/platform/api/quiche_system_event_loop.h create mode 100644 quiche/common/platform/api/quiche_test.h create mode 100644 quiche/common/platform/api/quiche_test_loopback.cc create mode 100644 quiche/common/platform/api/quiche_test_loopback.h create mode 100644 quiche/common/platform/api/quiche_test_output.h create mode 100644 quiche/common/platform/api/quiche_testvalue.h create mode 100644 quiche/common/platform/api/quiche_thread.h create mode 100644 quiche/common/platform/api/quiche_time_utils.h create mode 100644 quiche/common/platform/api/quiche_time_utils_test.cc create mode 100644 quiche/common/platform/api/quiche_udp_socket_platform_api.h create mode 100644 quiche/common/platform/api/quiche_url_utils.h create mode 100644 quiche/common/platform/api/quiche_url_utils_test.cc create mode 100644 quiche/common/platform/api/testdir/README.md create mode 100644 quiche/common/platform/api/testdir/a/b/c/d/e create mode 100644 quiche/common/platform/api/testdir/a/subdir/testfile create mode 100644 quiche/common/platform/api/testdir/a/z create mode 100644 quiche/common/platform/api/testdir/testfile create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_iovec_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_logging_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_test_output_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_thread_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc create mode 100644 quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h create mode 100644 quiche/common/print_elements.h create mode 100644 quiche/common/print_elements_test.cc create mode 100644 quiche/common/quiche_buffer_allocator.cc create mode 100644 quiche/common/quiche_buffer_allocator.h create mode 100644 quiche/common/quiche_buffer_allocator_test.cc create mode 100644 quiche/common/quiche_circular_deque.h create mode 100644 quiche/common/quiche_circular_deque_test.cc create mode 100644 quiche/common/quiche_crypto_logging.cc create mode 100644 quiche/common/quiche_crypto_logging.h create mode 100644 quiche/common/quiche_data_reader.cc create mode 100644 quiche/common/quiche_data_reader.h create mode 100644 quiche/common/quiche_data_reader_test.cc create mode 100644 quiche/common/quiche_data_writer.cc create mode 100644 quiche/common/quiche_data_writer.h create mode 100644 quiche/common/quiche_data_writer_test.cc create mode 100644 quiche/common/quiche_endian.h create mode 100644 quiche/common/quiche_endian_test.cc create mode 100644 quiche/common/quiche_ip_address.cc create mode 100644 quiche/common/quiche_ip_address.h create mode 100644 quiche/common/quiche_ip_address_family.cc create mode 100644 quiche/common/quiche_ip_address_family.h create mode 100644 quiche/common/quiche_ip_address_test.cc create mode 100644 quiche/common/quiche_linked_hash_map.h create mode 100644 quiche/common/quiche_linked_hash_map_test.cc create mode 100644 quiche/common/quiche_mem_slice_storage.cc create mode 100644 quiche/common/quiche_mem_slice_storage.h create mode 100644 quiche/common/quiche_mem_slice_storage_test.cc create mode 100644 quiche/common/quiche_protocol_flags_list.h create mode 100644 quiche/common/quiche_random.cc create mode 100644 quiche/common/quiche_random.h create mode 100644 quiche/common/quiche_random_test.cc create mode 100644 quiche/common/quiche_status_utils.h create mode 100644 quiche/common/quiche_stream.h create mode 100644 quiche/common/quiche_text_utils.cc create mode 100644 quiche/common/quiche_text_utils.h create mode 100644 quiche/common/quiche_text_utils_test.cc create mode 100644 quiche/common/simple_buffer_allocator.cc create mode 100644 quiche/common/simple_buffer_allocator.h create mode 100644 quiche/common/simple_buffer_allocator_test.cc create mode 100644 quiche/common/structured_headers.cc create mode 100644 quiche/common/structured_headers.h create mode 100644 quiche/common/structured_headers_fuzzer.cc create mode 100644 quiche/common/structured_headers_generated_test.cc create mode 100644 quiche/common/structured_headers_test.cc create mode 100644 quiche/common/test_tools/quiche_test_utils.cc create mode 100644 quiche/common/test_tools/quiche_test_utils.h create mode 100644 quiche/common/test_tools/quiche_test_utils_test.cc create mode 100644 quiche/common/wire_serialization.h create mode 100644 quiche/common/wire_serialization_test.cc create mode 100644 quiche/http2/adapter/adapter_impl_comparison_test.cc create mode 100644 quiche/http2/adapter/callback_visitor.cc create mode 100644 quiche/http2/adapter/callback_visitor.h create mode 100644 quiche/http2/adapter/callback_visitor_test.cc create mode 100644 quiche/http2/adapter/data_source.h create mode 100644 quiche/http2/adapter/event_forwarder.cc create mode 100644 quiche/http2/adapter/event_forwarder.h create mode 100644 quiche/http2/adapter/event_forwarder_test.cc create mode 100644 quiche/http2/adapter/header_validator.cc create mode 100644 quiche/http2/adapter/header_validator.h create mode 100644 quiche/http2/adapter/header_validator_base.h create mode 100644 quiche/http2/adapter/header_validator_test.cc create mode 100644 quiche/http2/adapter/http2_adapter.h create mode 100644 quiche/http2/adapter/http2_protocol.cc create mode 100644 quiche/http2/adapter/http2_protocol.h create mode 100644 quiche/http2/adapter/http2_session.h create mode 100644 quiche/http2/adapter/http2_util.cc create mode 100644 quiche/http2/adapter/http2_util.h create mode 100644 quiche/http2/adapter/http2_visitor_interface.h create mode 100644 quiche/http2/adapter/mock_http2_visitor.h create mode 100644 quiche/http2/adapter/mock_nghttp2_callbacks.cc create mode 100644 quiche/http2/adapter/mock_nghttp2_callbacks.h create mode 100644 quiche/http2/adapter/nghttp2.h create mode 100644 quiche/http2/adapter/nghttp2_adapter.cc create mode 100644 quiche/http2/adapter/nghttp2_adapter.h create mode 100644 quiche/http2/adapter/nghttp2_adapter_test.cc create mode 100644 quiche/http2/adapter/nghttp2_callbacks.cc create mode 100644 quiche/http2/adapter/nghttp2_callbacks.h create mode 100644 quiche/http2/adapter/nghttp2_data_provider.cc create mode 100644 quiche/http2/adapter/nghttp2_data_provider.h create mode 100644 quiche/http2/adapter/nghttp2_data_provider_test.cc create mode 100644 quiche/http2/adapter/nghttp2_session.cc create mode 100644 quiche/http2/adapter/nghttp2_session.h create mode 100644 quiche/http2/adapter/nghttp2_session_test.cc create mode 100644 quiche/http2/adapter/nghttp2_test.cc create mode 100644 quiche/http2/adapter/nghttp2_test_utils.cc create mode 100644 quiche/http2/adapter/nghttp2_test_utils.h create mode 100644 quiche/http2/adapter/nghttp2_util.cc create mode 100644 quiche/http2/adapter/nghttp2_util.h create mode 100644 quiche/http2/adapter/nghttp2_util_test.cc create mode 100644 quiche/http2/adapter/noop_header_validator.cc create mode 100644 quiche/http2/adapter/noop_header_validator.h create mode 100644 quiche/http2/adapter/noop_header_validator_test.cc create mode 100644 quiche/http2/adapter/oghttp2_adapter.cc create mode 100644 quiche/http2/adapter/oghttp2_adapter.h create mode 100644 quiche/http2/adapter/oghttp2_adapter_test.cc create mode 100644 quiche/http2/adapter/oghttp2_session.cc create mode 100644 quiche/http2/adapter/oghttp2_session.h create mode 100644 quiche/http2/adapter/oghttp2_session_test.cc create mode 100644 quiche/http2/adapter/oghttp2_util.cc create mode 100644 quiche/http2/adapter/oghttp2_util.h create mode 100644 quiche/http2/adapter/oghttp2_util_test.cc create mode 100644 quiche/http2/adapter/recording_http2_visitor.cc create mode 100644 quiche/http2/adapter/recording_http2_visitor.h create mode 100644 quiche/http2/adapter/recording_http2_visitor_test.cc create mode 100644 quiche/http2/adapter/test_frame_sequence.cc create mode 100644 quiche/http2/adapter/test_frame_sequence.h create mode 100644 quiche/http2/adapter/test_utils.cc create mode 100644 quiche/http2/adapter/test_utils.h create mode 100644 quiche/http2/adapter/test_utils_test.cc create mode 100644 quiche/http2/adapter/window_manager.cc create mode 100644 quiche/http2/adapter/window_manager.h create mode 100644 quiche/http2/adapter/window_manager_test.cc create mode 100644 quiche/http2/core/http2_trace_logging.cc create mode 100644 quiche/http2/core/http2_trace_logging.h create mode 100644 quiche/http2/core/priority_write_scheduler.h create mode 100644 quiche/http2/core/priority_write_scheduler_test.cc create mode 100644 quiche/http2/decoder/decode_buffer.cc create mode 100644 quiche/http2/decoder/decode_buffer.h create mode 100644 quiche/http2/decoder/decode_buffer_test.cc create mode 100644 quiche/http2/decoder/decode_http2_structures.cc create mode 100644 quiche/http2/decoder/decode_http2_structures.h create mode 100644 quiche/http2/decoder/decode_http2_structures_test.cc create mode 100644 quiche/http2/decoder/decode_status.cc create mode 100644 quiche/http2/decoder/decode_status.h create mode 100644 quiche/http2/decoder/frame_decoder_state.cc create mode 100644 quiche/http2/decoder/frame_decoder_state.h create mode 100644 quiche/http2/decoder/http2_frame_decoder.cc create mode 100644 quiche/http2/decoder/http2_frame_decoder.h create mode 100644 quiche/http2/decoder/http2_frame_decoder_listener.cc create mode 100644 quiche/http2/decoder/http2_frame_decoder_listener.h create mode 100644 quiche/http2/decoder/http2_frame_decoder_test.cc create mode 100644 quiche/http2/decoder/http2_structure_decoder.cc create mode 100644 quiche/http2/decoder/http2_structure_decoder.h create mode 100644 quiche/http2/decoder/http2_structure_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/continuation_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/continuation_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/data_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/data_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/data_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/goaway_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/goaway_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/headers_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/headers_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/headers_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/ping_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/ping_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/ping_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/priority_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/priority_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/priority_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/settings_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/settings_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/settings_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/unknown_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/unknown_payload_decoder_test.cc create mode 100644 quiche/http2/decoder/payload_decoders/window_update_payload_decoder.cc create mode 100644 quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h create mode 100644 quiche/http2/decoder/payload_decoders/window_update_payload_decoder_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_block_collector_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_block_decoder.cc create mode 100644 quiche/http2/hpack/decoder/hpack_block_decoder.h create mode 100644 quiche/http2/hpack/decoder/hpack_block_decoder_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder.h create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_listener.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_listener.h create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_state.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_state.h create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_state_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_tables.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_tables.h create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_tables_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoder_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoding_error.cc create mode 100644 quiche/http2/hpack/decoder/hpack_decoding_error.h create mode 100644 quiche/http2/hpack/decoder/hpack_entry_collector_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_entry_decoder.cc create mode 100644 quiche/http2/hpack/decoder/hpack_entry_decoder.h create mode 100644 quiche/http2/hpack/decoder/hpack_entry_decoder_listener.cc create mode 100644 quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h create mode 100644 quiche/http2/hpack/decoder/hpack_entry_decoder_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_entry_type_decoder.cc create mode 100644 quiche/http2/hpack/decoder/hpack_entry_type_decoder.h create mode 100644 quiche/http2/hpack/decoder/hpack_entry_type_decoder_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_string_decoder.cc create mode 100644 quiche/http2/hpack/decoder/hpack_string_decoder.h create mode 100644 quiche/http2/hpack/decoder/hpack_string_decoder_listener.cc create mode 100644 quiche/http2/hpack/decoder/hpack_string_decoder_listener.h create mode 100644 quiche/http2/hpack/decoder/hpack_string_decoder_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc create mode 100644 quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h create mode 100644 quiche/http2/hpack/decoder/hpack_whole_entry_buffer_test.cc create mode 100644 quiche/http2/hpack/decoder/hpack_whole_entry_listener.cc create mode 100644 quiche/http2/hpack/decoder/hpack_whole_entry_listener.h create mode 100644 quiche/http2/hpack/hpack_static_table_entries.inc create mode 100644 quiche/http2/hpack/http2_hpack_constants.cc create mode 100644 quiche/http2/hpack/http2_hpack_constants.h create mode 100644 quiche/http2/hpack/http2_hpack_constants_test.cc create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_decoder.cc create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_decoder.h create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_encoder.cc create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_encoder.h create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc create mode 100644 quiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc create mode 100644 quiche/http2/hpack/huffman/huffman_spec_tables.cc create mode 100644 quiche/http2/hpack/huffman/huffman_spec_tables.h create mode 100644 quiche/http2/hpack/varint/hpack_varint_decoder.cc create mode 100644 quiche/http2/hpack/varint/hpack_varint_decoder.h create mode 100644 quiche/http2/hpack/varint/hpack_varint_decoder_test.cc create mode 100644 quiche/http2/hpack/varint/hpack_varint_encoder.cc create mode 100644 quiche/http2/hpack/varint/hpack_varint_encoder.h create mode 100644 quiche/http2/hpack/varint/hpack_varint_encoder_test.cc create mode 100644 quiche/http2/hpack/varint/hpack_varint_round_trip_test.cc create mode 100644 quiche/http2/http2_constants.cc create mode 100644 quiche/http2/http2_constants.h create mode 100644 quiche/http2/http2_constants_test.cc create mode 100644 quiche/http2/http2_structures.cc create mode 100644 quiche/http2/http2_structures.h create mode 100644 quiche/http2/http2_structures_test.cc create mode 100644 quiche/http2/test_tools/frame_decoder_state_test_util.cc create mode 100644 quiche/http2/test_tools/frame_decoder_state_test_util.h create mode 100644 quiche/http2/test_tools/frame_parts.cc create mode 100644 quiche/http2/test_tools/frame_parts.h create mode 100644 quiche/http2/test_tools/frame_parts_collector.cc create mode 100644 quiche/http2/test_tools/frame_parts_collector.h create mode 100644 quiche/http2/test_tools/frame_parts_collector_listener.cc create mode 100644 quiche/http2/test_tools/frame_parts_collector_listener.h create mode 100644 quiche/http2/test_tools/hpack_block_builder.cc create mode 100644 quiche/http2/test_tools/hpack_block_builder.h create mode 100644 quiche/http2/test_tools/hpack_block_builder_test.cc create mode 100644 quiche/http2/test_tools/hpack_block_collector.cc create mode 100644 quiche/http2/test_tools/hpack_block_collector.h create mode 100644 quiche/http2/test_tools/hpack_entry_collector.cc create mode 100644 quiche/http2/test_tools/hpack_entry_collector.h create mode 100644 quiche/http2/test_tools/hpack_example.cc create mode 100644 quiche/http2/test_tools/hpack_example.h create mode 100644 quiche/http2/test_tools/hpack_example_test.cc create mode 100644 quiche/http2/test_tools/hpack_string_collector.cc create mode 100644 quiche/http2/test_tools/hpack_string_collector.h create mode 100644 quiche/http2/test_tools/http2_constants_test_util.cc create mode 100644 quiche/http2/test_tools/http2_constants_test_util.h create mode 100644 quiche/http2/test_tools/http2_frame_builder.cc create mode 100644 quiche/http2/test_tools/http2_frame_builder.h create mode 100644 quiche/http2/test_tools/http2_frame_builder_test.cc create mode 100644 quiche/http2/test_tools/http2_frame_decoder_listener_test_util.cc create mode 100644 quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h create mode 100644 quiche/http2/test_tools/http2_random.cc create mode 100644 quiche/http2/test_tools/http2_random.h create mode 100644 quiche/http2/test_tools/http2_random_test.cc create mode 100644 quiche/http2/test_tools/http2_structure_decoder_test_util.cc create mode 100644 quiche/http2/test_tools/http2_structure_decoder_test_util.h create mode 100644 quiche/http2/test_tools/http2_structures_test_util.cc create mode 100644 quiche/http2/test_tools/http2_structures_test_util.h create mode 100644 quiche/http2/test_tools/payload_decoder_base_test_util.cc create mode 100644 quiche/http2/test_tools/payload_decoder_base_test_util.h create mode 100644 quiche/http2/test_tools/random_decoder_test_base.cc create mode 100644 quiche/http2/test_tools/random_decoder_test_base.h create mode 100644 quiche/http2/test_tools/random_decoder_test_base_test.cc create mode 100644 quiche/http2/test_tools/random_util.cc create mode 100644 quiche/http2/test_tools/random_util.h create mode 100644 quiche/http2/test_tools/verify_macros.h create mode 100644 quiche/oblivious_http/buffers/oblivious_http_integration_test.cc create mode 100644 quiche/oblivious_http/buffers/oblivious_http_request.cc create mode 100644 quiche/oblivious_http/buffers/oblivious_http_request.h create mode 100644 quiche/oblivious_http/buffers/oblivious_http_request_test.cc create mode 100644 quiche/oblivious_http/buffers/oblivious_http_response.cc create mode 100644 quiche/oblivious_http/buffers/oblivious_http_response.h create mode 100644 quiche/oblivious_http/buffers/oblivious_http_response_test.cc create mode 100644 quiche/oblivious_http/common/oblivious_http_header_key_config.cc create mode 100644 quiche/oblivious_http/common/oblivious_http_header_key_config.h create mode 100644 quiche/oblivious_http/common/oblivious_http_header_key_config_test.cc create mode 100644 quiche/oblivious_http/oblivious_http_client.cc create mode 100644 quiche/oblivious_http/oblivious_http_client.h create mode 100644 quiche/oblivious_http/oblivious_http_client_test.cc create mode 100644 quiche/oblivious_http/oblivious_http_gateway.cc create mode 100644 quiche/oblivious_http/oblivious_http_gateway.h create mode 100644 quiche/oblivious_http/oblivious_http_gateway_test.cc create mode 100644 quiche/quic/bindings/quic_libevent.cc create mode 100644 quiche/quic/bindings/quic_libevent.h create mode 100644 quiche/quic/bindings/quic_libevent_test.cc create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_base.cc create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_base.h create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_buffer.h create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_buffer_test.cc create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_test.cc create mode 100644 quiche/quic/core/batch_writer/quic_batch_writer_test.h create mode 100644 quiche/quic/core/batch_writer/quic_gso_batch_writer.cc create mode 100644 quiche/quic/core/batch_writer/quic_gso_batch_writer.h create mode 100644 quiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc create mode 100644 quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc create mode 100644 quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h create mode 100644 quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc create mode 100644 quiche/quic/core/chlo_extractor.cc create mode 100644 quiche/quic/core/chlo_extractor.h create mode 100644 quiche/quic/core/chlo_extractor_test.cc create mode 100644 quiche/quic/core/congestion_control/bandwidth_sampler.cc create mode 100644 quiche/quic/core/congestion_control/bandwidth_sampler.h create mode 100644 quiche/quic/core/congestion_control/bandwidth_sampler_test.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_drain.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_drain.h create mode 100644 quiche/quic/core/congestion_control/bbr2_misc.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_misc.h create mode 100644 quiche/quic/core/congestion_control/bbr2_probe_bw.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_probe_bw.h create mode 100644 quiche/quic/core/congestion_control/bbr2_probe_rtt.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_probe_rtt.h create mode 100644 quiche/quic/core/congestion_control/bbr2_sender.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_sender.h create mode 100644 quiche/quic/core/congestion_control/bbr2_simulator_test.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_startup.cc create mode 100644 quiche/quic/core/congestion_control/bbr2_startup.h create mode 100644 quiche/quic/core/congestion_control/bbr_sender.cc create mode 100644 quiche/quic/core/congestion_control/bbr_sender.h create mode 100644 quiche/quic/core/congestion_control/bbr_sender_test.cc create mode 100644 quiche/quic/core/congestion_control/cubic_bytes.cc create mode 100644 quiche/quic/core/congestion_control/cubic_bytes.h create mode 100644 quiche/quic/core/congestion_control/cubic_bytes_test.cc create mode 100644 quiche/quic/core/congestion_control/general_loss_algorithm.cc create mode 100644 quiche/quic/core/congestion_control/general_loss_algorithm.h create mode 100644 quiche/quic/core/congestion_control/general_loss_algorithm_test.cc create mode 100644 quiche/quic/core/congestion_control/hybrid_slow_start.cc create mode 100644 quiche/quic/core/congestion_control/hybrid_slow_start.h create mode 100644 quiche/quic/core/congestion_control/hybrid_slow_start_test.cc create mode 100644 quiche/quic/core/congestion_control/loss_detection_interface.h create mode 100644 quiche/quic/core/congestion_control/pacing_sender.cc create mode 100644 quiche/quic/core/congestion_control/pacing_sender.h create mode 100644 quiche/quic/core/congestion_control/pacing_sender_test.cc create mode 100644 quiche/quic/core/congestion_control/prr_sender.cc create mode 100644 quiche/quic/core/congestion_control/prr_sender.h create mode 100644 quiche/quic/core/congestion_control/prr_sender_test.cc create mode 100644 quiche/quic/core/congestion_control/rtt_stats.cc create mode 100644 quiche/quic/core/congestion_control/rtt_stats.h create mode 100644 quiche/quic/core/congestion_control/rtt_stats_test.cc create mode 100644 quiche/quic/core/congestion_control/send_algorithm_interface.cc create mode 100644 quiche/quic/core/congestion_control/send_algorithm_interface.h create mode 100644 quiche/quic/core/congestion_control/send_algorithm_test.cc create mode 100644 quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc create mode 100644 quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h create mode 100644 quiche/quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc create mode 100644 quiche/quic/core/congestion_control/uber_loss_algorithm.cc create mode 100644 quiche/quic/core/congestion_control/uber_loss_algorithm.h create mode 100644 quiche/quic/core/congestion_control/uber_loss_algorithm_test.cc create mode 100644 quiche/quic/core/congestion_control/windowed_filter.h create mode 100644 quiche/quic/core/congestion_control/windowed_filter_test.cc create mode 100644 quiche/quic/core/connecting_client_socket.h create mode 100644 quiche/quic/core/connection_id_generator.h create mode 100644 quiche/quic/core/crypto/aead_base_decrypter.cc create mode 100644 quiche/quic/core/crypto/aead_base_decrypter.h create mode 100644 quiche/quic/core/crypto/aead_base_encrypter.cc create mode 100644 quiche/quic/core/crypto/aead_base_encrypter.h create mode 100644 quiche/quic/core/crypto/aes_128_gcm_12_decrypter.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h create mode 100644 quiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_12_encrypter.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h create mode 100644 quiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_decrypter.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_decrypter.h create mode 100644 quiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_encrypter.cc create mode 100644 quiche/quic/core/crypto/aes_128_gcm_encrypter.h create mode 100644 quiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc create mode 100644 quiche/quic/core/crypto/aes_256_gcm_decrypter.cc create mode 100644 quiche/quic/core/crypto/aes_256_gcm_decrypter.h create mode 100644 quiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc create mode 100644 quiche/quic/core/crypto/aes_256_gcm_encrypter.cc create mode 100644 quiche/quic/core/crypto/aes_256_gcm_encrypter.h create mode 100644 quiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc create mode 100644 quiche/quic/core/crypto/aes_base_decrypter.cc create mode 100644 quiche/quic/core/crypto/aes_base_decrypter.h create mode 100644 quiche/quic/core/crypto/aes_base_encrypter.cc create mode 100644 quiche/quic/core/crypto/aes_base_encrypter.h create mode 100644 quiche/quic/core/crypto/boring_utils.h create mode 100644 quiche/quic/core/crypto/cert_compressor.cc create mode 100644 quiche/quic/core/crypto/cert_compressor.h create mode 100644 quiche/quic/core/crypto/cert_compressor_test.cc create mode 100644 quiche/quic/core/crypto/certificate_util.cc create mode 100644 quiche/quic/core/crypto/certificate_util.h create mode 100644 quiche/quic/core/crypto/certificate_util_test.cc create mode 100644 quiche/quic/core/crypto/certificate_view.cc create mode 100644 quiche/quic/core/crypto/certificate_view.h create mode 100644 quiche/quic/core/crypto/certificate_view_der_fuzzer.cc create mode 100644 quiche/quic/core/crypto/certificate_view_pem_fuzzer.cc create mode 100644 quiche/quic/core/crypto/certificate_view_test.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_decrypter.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_decrypter.h create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_encrypter.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_encrypter.h create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h create mode 100644 quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc create mode 100644 quiche/quic/core/crypto/chacha_base_decrypter.cc create mode 100644 quiche/quic/core/crypto/chacha_base_decrypter.h create mode 100644 quiche/quic/core/crypto/chacha_base_encrypter.cc create mode 100644 quiche/quic/core/crypto/chacha_base_encrypter.h create mode 100644 quiche/quic/core/crypto/channel_id.cc create mode 100644 quiche/quic/core/crypto/channel_id.h create mode 100644 quiche/quic/core/crypto/channel_id_test.cc create mode 100644 quiche/quic/core/crypto/client_proof_source.cc create mode 100644 quiche/quic/core/crypto/client_proof_source.h create mode 100644 quiche/quic/core/crypto/client_proof_source_test.cc create mode 100644 quiche/quic/core/crypto/crypto_framer.cc create mode 100644 quiche/quic/core/crypto/crypto_framer.h create mode 100644 quiche/quic/core/crypto/crypto_framer_test.cc create mode 100644 quiche/quic/core/crypto/crypto_handshake.cc create mode 100644 quiche/quic/core/crypto/crypto_handshake.h create mode 100644 quiche/quic/core/crypto/crypto_handshake_message.cc create mode 100644 quiche/quic/core/crypto/crypto_handshake_message.h create mode 100644 quiche/quic/core/crypto/crypto_handshake_message_test.cc create mode 100644 quiche/quic/core/crypto/crypto_message_parser.h create mode 100644 quiche/quic/core/crypto/crypto_protocol.h create mode 100644 quiche/quic/core/crypto/crypto_secret_boxer.cc create mode 100644 quiche/quic/core/crypto/crypto_secret_boxer.h create mode 100644 quiche/quic/core/crypto/crypto_secret_boxer_test.cc create mode 100644 quiche/quic/core/crypto/crypto_server_test.cc create mode 100644 quiche/quic/core/crypto/crypto_utils.cc create mode 100644 quiche/quic/core/crypto/crypto_utils.h create mode 100644 quiche/quic/core/crypto/crypto_utils_test.cc create mode 100644 quiche/quic/core/crypto/curve25519_key_exchange.cc create mode 100644 quiche/quic/core/crypto/curve25519_key_exchange.h create mode 100644 quiche/quic/core/crypto/curve25519_key_exchange_test.cc create mode 100644 quiche/quic/core/crypto/key_exchange.cc create mode 100644 quiche/quic/core/crypto/key_exchange.h create mode 100644 quiche/quic/core/crypto/null_decrypter.cc create mode 100644 quiche/quic/core/crypto/null_decrypter.h create mode 100644 quiche/quic/core/crypto/null_decrypter_test.cc create mode 100644 quiche/quic/core/crypto/null_encrypter.cc create mode 100644 quiche/quic/core/crypto/null_encrypter.h create mode 100644 quiche/quic/core/crypto/null_encrypter_test.cc create mode 100644 quiche/quic/core/crypto/p256_key_exchange.cc create mode 100644 quiche/quic/core/crypto/p256_key_exchange.h create mode 100644 quiche/quic/core/crypto/p256_key_exchange_test.cc create mode 100644 quiche/quic/core/crypto/proof_source.cc create mode 100644 quiche/quic/core/crypto/proof_source.h create mode 100644 quiche/quic/core/crypto/proof_source_x509.cc create mode 100644 quiche/quic/core/crypto/proof_source_x509.h create mode 100644 quiche/quic/core/crypto/proof_source_x509_test.cc create mode 100644 quiche/quic/core/crypto/proof_verifier.h create mode 100644 quiche/quic/core/crypto/quic_client_session_cache.cc create mode 100644 quiche/quic/core/crypto/quic_client_session_cache.h create mode 100644 quiche/quic/core/crypto/quic_client_session_cache_test.cc create mode 100644 quiche/quic/core/crypto/quic_compressed_certs_cache.cc create mode 100644 quiche/quic/core/crypto/quic_compressed_certs_cache.h create mode 100644 quiche/quic/core/crypto/quic_compressed_certs_cache_test.cc create mode 100644 quiche/quic/core/crypto/quic_crypter.cc create mode 100644 quiche/quic/core/crypto/quic_crypter.h create mode 100644 quiche/quic/core/crypto/quic_crypto_client_config.cc create mode 100644 quiche/quic/core/crypto/quic_crypto_client_config.h create mode 100644 quiche/quic/core/crypto/quic_crypto_client_config_test.cc create mode 100644 quiche/quic/core/crypto/quic_crypto_proof.cc create mode 100644 quiche/quic/core/crypto/quic_crypto_proof.h create mode 100644 quiche/quic/core/crypto/quic_crypto_server_config.cc create mode 100644 quiche/quic/core/crypto/quic_crypto_server_config.h create mode 100644 quiche/quic/core/crypto/quic_crypto_server_config_test.cc create mode 100644 quiche/quic/core/crypto/quic_decrypter.cc create mode 100644 quiche/quic/core/crypto/quic_decrypter.h create mode 100644 quiche/quic/core/crypto/quic_encrypter.cc create mode 100644 quiche/quic/core/crypto/quic_encrypter.h create mode 100644 quiche/quic/core/crypto/quic_hkdf.cc create mode 100644 quiche/quic/core/crypto/quic_hkdf.h create mode 100644 quiche/quic/core/crypto/quic_hkdf_test.cc create mode 100644 quiche/quic/core/crypto/quic_random.h create mode 100644 quiche/quic/core/crypto/tls_client_connection.cc create mode 100644 quiche/quic/core/crypto/tls_client_connection.h create mode 100644 quiche/quic/core/crypto/tls_connection.cc create mode 100644 quiche/quic/core/crypto/tls_connection.h create mode 100644 quiche/quic/core/crypto/tls_server_connection.cc create mode 100644 quiche/quic/core/crypto/tls_server_connection.h create mode 100644 quiche/quic/core/crypto/transport_parameters.cc create mode 100644 quiche/quic/core/crypto/transport_parameters.h create mode 100644 quiche/quic/core/crypto/transport_parameters_test.cc create mode 100644 quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.cc create mode 100644 quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h create mode 100644 quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc create mode 100644 quiche/quic/core/deterministic_connection_id_generator.cc create mode 100644 quiche/quic/core/deterministic_connection_id_generator.h create mode 100644 quiche/quic/core/deterministic_connection_id_generator_test.cc create mode 100644 quiche/quic/core/frames/quic_ack_frame.cc create mode 100644 quiche/quic/core/frames/quic_ack_frame.h create mode 100644 quiche/quic/core/frames/quic_ack_frequency_frame.cc create mode 100644 quiche/quic/core/frames/quic_ack_frequency_frame.h create mode 100644 quiche/quic/core/frames/quic_blocked_frame.cc create mode 100644 quiche/quic/core/frames/quic_blocked_frame.h create mode 100644 quiche/quic/core/frames/quic_connection_close_frame.cc create mode 100644 quiche/quic/core/frames/quic_connection_close_frame.h create mode 100644 quiche/quic/core/frames/quic_crypto_frame.cc create mode 100644 quiche/quic/core/frames/quic_crypto_frame.h create mode 100644 quiche/quic/core/frames/quic_frame.cc create mode 100644 quiche/quic/core/frames/quic_frame.h create mode 100644 quiche/quic/core/frames/quic_frames_test.cc create mode 100644 quiche/quic/core/frames/quic_goaway_frame.cc create mode 100644 quiche/quic/core/frames/quic_goaway_frame.h create mode 100644 quiche/quic/core/frames/quic_handshake_done_frame.cc create mode 100644 quiche/quic/core/frames/quic_handshake_done_frame.h create mode 100644 quiche/quic/core/frames/quic_inlined_frame.h create mode 100644 quiche/quic/core/frames/quic_max_streams_frame.cc create mode 100644 quiche/quic/core/frames/quic_max_streams_frame.h create mode 100644 quiche/quic/core/frames/quic_message_frame.cc create mode 100644 quiche/quic/core/frames/quic_message_frame.h create mode 100644 quiche/quic/core/frames/quic_mtu_discovery_frame.h create mode 100644 quiche/quic/core/frames/quic_new_connection_id_frame.cc create mode 100644 quiche/quic/core/frames/quic_new_connection_id_frame.h create mode 100644 quiche/quic/core/frames/quic_new_token_frame.cc create mode 100644 quiche/quic/core/frames/quic_new_token_frame.h create mode 100644 quiche/quic/core/frames/quic_padding_frame.cc create mode 100644 quiche/quic/core/frames/quic_padding_frame.h create mode 100644 quiche/quic/core/frames/quic_path_challenge_frame.cc create mode 100644 quiche/quic/core/frames/quic_path_challenge_frame.h create mode 100644 quiche/quic/core/frames/quic_path_response_frame.cc create mode 100644 quiche/quic/core/frames/quic_path_response_frame.h create mode 100644 quiche/quic/core/frames/quic_ping_frame.cc create mode 100644 quiche/quic/core/frames/quic_ping_frame.h create mode 100644 quiche/quic/core/frames/quic_retire_connection_id_frame.cc create mode 100644 quiche/quic/core/frames/quic_retire_connection_id_frame.h create mode 100644 quiche/quic/core/frames/quic_rst_stream_frame.cc create mode 100644 quiche/quic/core/frames/quic_rst_stream_frame.h create mode 100644 quiche/quic/core/frames/quic_stop_sending_frame.cc create mode 100644 quiche/quic/core/frames/quic_stop_sending_frame.h create mode 100644 quiche/quic/core/frames/quic_stop_waiting_frame.cc create mode 100644 quiche/quic/core/frames/quic_stop_waiting_frame.h create mode 100644 quiche/quic/core/frames/quic_stream_frame.cc create mode 100644 quiche/quic/core/frames/quic_stream_frame.h create mode 100644 quiche/quic/core/frames/quic_streams_blocked_frame.cc create mode 100644 quiche/quic/core/frames/quic_streams_blocked_frame.h create mode 100644 quiche/quic/core/frames/quic_window_update_frame.cc create mode 100644 quiche/quic/core/frames/quic_window_update_frame.h create mode 100644 quiche/quic/core/handshaker_delegate_interface.h create mode 100644 quiche/quic/core/http/end_to_end_test.cc create mode 100644 quiche/quic/core/http/http_constants.cc create mode 100644 quiche/quic/core/http/http_constants.h create mode 100644 quiche/quic/core/http/http_decoder.cc create mode 100644 quiche/quic/core/http/http_decoder.h create mode 100644 quiche/quic/core/http/http_decoder_test.cc create mode 100644 quiche/quic/core/http/http_encoder.cc create mode 100644 quiche/quic/core/http/http_encoder.h create mode 100644 quiche/quic/core/http/http_encoder_test.cc create mode 100644 quiche/quic/core/http/http_frames.h create mode 100644 quiche/quic/core/http/http_frames_test.cc create mode 100644 quiche/quic/core/http/quic_client_promised_info.cc create mode 100644 quiche/quic/core/http/quic_client_promised_info.h create mode 100644 quiche/quic/core/http/quic_client_promised_info_test.cc create mode 100644 quiche/quic/core/http/quic_client_push_promise_index.cc create mode 100644 quiche/quic/core/http/quic_client_push_promise_index.h create mode 100644 quiche/quic/core/http/quic_client_push_promise_index_test.cc create mode 100644 quiche/quic/core/http/quic_header_list.cc create mode 100644 quiche/quic/core/http/quic_header_list.h create mode 100644 quiche/quic/core/http/quic_header_list_test.cc create mode 100644 quiche/quic/core/http/quic_headers_stream.cc create mode 100644 quiche/quic/core/http/quic_headers_stream.h create mode 100644 quiche/quic/core/http/quic_headers_stream_test.cc create mode 100644 quiche/quic/core/http/quic_receive_control_stream.cc create mode 100644 quiche/quic/core/http/quic_receive_control_stream.h create mode 100644 quiche/quic/core/http/quic_receive_control_stream_test.cc create mode 100644 quiche/quic/core/http/quic_send_control_stream.cc create mode 100644 quiche/quic/core/http/quic_send_control_stream.h create mode 100644 quiche/quic/core/http/quic_send_control_stream_test.cc create mode 100644 quiche/quic/core/http/quic_server_initiated_spdy_stream.cc create mode 100644 quiche/quic/core/http/quic_server_initiated_spdy_stream.h create mode 100644 quiche/quic/core/http/quic_server_session_base.cc create mode 100644 quiche/quic/core/http/quic_server_session_base.h create mode 100644 quiche/quic/core/http/quic_server_session_base_test.cc create mode 100644 quiche/quic/core/http/quic_spdy_client_session.cc create mode 100644 quiche/quic/core/http/quic_spdy_client_session.h create mode 100644 quiche/quic/core/http/quic_spdy_client_session_base.cc create mode 100644 quiche/quic/core/http/quic_spdy_client_session_base.h create mode 100644 quiche/quic/core/http/quic_spdy_client_session_test.cc create mode 100644 quiche/quic/core/http/quic_spdy_client_stream.cc create mode 100644 quiche/quic/core/http/quic_spdy_client_stream.h create mode 100644 quiche/quic/core/http/quic_spdy_client_stream_test.cc create mode 100644 quiche/quic/core/http/quic_spdy_server_stream_base.cc create mode 100644 quiche/quic/core/http/quic_spdy_server_stream_base.h create mode 100644 quiche/quic/core/http/quic_spdy_server_stream_base_test.cc create mode 100644 quiche/quic/core/http/quic_spdy_session.cc create mode 100644 quiche/quic/core/http/quic_spdy_session.h create mode 100644 quiche/quic/core/http/quic_spdy_session_test.cc create mode 100644 quiche/quic/core/http/quic_spdy_stream.cc create mode 100644 quiche/quic/core/http/quic_spdy_stream.h create mode 100644 quiche/quic/core/http/quic_spdy_stream_body_manager.cc create mode 100644 quiche/quic/core/http/quic_spdy_stream_body_manager.h create mode 100644 quiche/quic/core/http/quic_spdy_stream_body_manager_test.cc create mode 100644 quiche/quic/core/http/quic_spdy_stream_test.cc create mode 100644 quiche/quic/core/http/spdy_server_push_utils.cc create mode 100644 quiche/quic/core/http/spdy_server_push_utils.h create mode 100644 quiche/quic/core/http/spdy_server_push_utils_test.cc create mode 100644 quiche/quic/core/http/spdy_utils.cc create mode 100644 quiche/quic/core/http/spdy_utils.h create mode 100644 quiche/quic/core/http/spdy_utils_test.cc create mode 100644 quiche/quic/core/http/web_transport_http3.cc create mode 100644 quiche/quic/core/http/web_transport_http3.h create mode 100644 quiche/quic/core/http/web_transport_http3_test.cc create mode 100644 quiche/quic/core/http/web_transport_stream_adapter.cc create mode 100644 quiche/quic/core/http/web_transport_stream_adapter.h create mode 100644 quiche/quic/core/io/event_loop_connecting_client_socket.cc create mode 100644 quiche/quic/core/io/event_loop_connecting_client_socket.h create mode 100644 quiche/quic/core/io/event_loop_connecting_client_socket_test.cc create mode 100644 quiche/quic/core/io/event_loop_socket_factory.cc create mode 100644 quiche/quic/core/io/event_loop_socket_factory.h create mode 100644 quiche/quic/core/io/quic_all_event_loops_test.cc create mode 100644 quiche/quic/core/io/quic_default_event_loop.cc create mode 100644 quiche/quic/core/io/quic_default_event_loop.h create mode 100644 quiche/quic/core/io/quic_event_loop.h create mode 100644 quiche/quic/core/io/quic_poll_event_loop.cc create mode 100644 quiche/quic/core/io/quic_poll_event_loop.h create mode 100644 quiche/quic/core/io/quic_poll_event_loop_test.cc create mode 100644 quiche/quic/core/io/socket.h create mode 100644 quiche/quic/core/io/socket_posix.cc create mode 100644 quiche/quic/core/io/socket_test.cc create mode 100644 quiche/quic/core/legacy_quic_stream_id_manager.cc create mode 100644 quiche/quic/core/legacy_quic_stream_id_manager.h create mode 100644 quiche/quic/core/legacy_quic_stream_id_manager_test.cc create mode 100644 quiche/quic/core/packet_number_indexed_queue.h create mode 100644 quiche/quic/core/packet_number_indexed_queue_test.cc create mode 100644 quiche/quic/core/proto/cached_network_parameters.proto create mode 100644 quiche/quic/core/proto/cached_network_parameters_proto.h create mode 100644 quiche/quic/core/proto/crypto_server_config.proto create mode 100644 quiche/quic/core/proto/crypto_server_config_proto.h create mode 100644 quiche/quic/core/proto/source_address_token.proto create mode 100644 quiche/quic/core/proto/source_address_token_proto.h create mode 100644 quiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc create mode 100644 quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc create mode 100644 quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc create mode 100644 quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc create mode 100644 quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc create mode 100644 quiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc create mode 100644 quiche/quic/core/qpack/qpack_blocking_manager.cc create mode 100644 quiche/quic/core/qpack/qpack_blocking_manager.h create mode 100644 quiche/quic/core/qpack/qpack_blocking_manager_test.cc create mode 100644 quiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc create mode 100644 quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h create mode 100644 quiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc create mode 100644 quiche/quic/core/qpack/qpack_decoder.cc create mode 100644 quiche/quic/core/qpack/qpack_decoder.h create mode 100644 quiche/quic/core/qpack/qpack_decoder_stream_receiver.cc create mode 100644 quiche/quic/core/qpack/qpack_decoder_stream_receiver.h create mode 100644 quiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc create mode 100644 quiche/quic/core/qpack/qpack_decoder_stream_sender.cc create mode 100644 quiche/quic/core/qpack/qpack_decoder_stream_sender.h create mode 100644 quiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc create mode 100644 quiche/quic/core/qpack/qpack_decoder_test.cc create mode 100644 quiche/quic/core/qpack/qpack_encoder.cc create mode 100644 quiche/quic/core/qpack/qpack_encoder.h create mode 100644 quiche/quic/core/qpack/qpack_encoder_stream_receiver.cc create mode 100644 quiche/quic/core/qpack/qpack_encoder_stream_receiver.h create mode 100644 quiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc create mode 100644 quiche/quic/core/qpack/qpack_encoder_stream_sender.cc create mode 100644 quiche/quic/core/qpack/qpack_encoder_stream_sender.h create mode 100644 quiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc create mode 100644 quiche/quic/core/qpack/qpack_encoder_test.cc create mode 100644 quiche/quic/core/qpack/qpack_header_table.cc create mode 100644 quiche/quic/core/qpack/qpack_header_table.h create mode 100644 quiche/quic/core/qpack/qpack_header_table_test.cc create mode 100644 quiche/quic/core/qpack/qpack_index_conversions.cc create mode 100644 quiche/quic/core/qpack/qpack_index_conversions.h create mode 100644 quiche/quic/core/qpack/qpack_index_conversions_test.cc create mode 100644 quiche/quic/core/qpack/qpack_instruction_decoder.cc create mode 100644 quiche/quic/core/qpack/qpack_instruction_decoder.h create mode 100644 quiche/quic/core/qpack/qpack_instruction_decoder_test.cc create mode 100644 quiche/quic/core/qpack/qpack_instruction_encoder.cc create mode 100644 quiche/quic/core/qpack/qpack_instruction_encoder.h create mode 100644 quiche/quic/core/qpack/qpack_instruction_encoder_test.cc create mode 100644 quiche/quic/core/qpack/qpack_instructions.cc create mode 100644 quiche/quic/core/qpack/qpack_instructions.h create mode 100644 quiche/quic/core/qpack/qpack_progressive_decoder.cc create mode 100644 quiche/quic/core/qpack/qpack_progressive_decoder.h create mode 100644 quiche/quic/core/qpack/qpack_receive_stream.cc create mode 100644 quiche/quic/core/qpack/qpack_receive_stream.h create mode 100644 quiche/quic/core/qpack/qpack_receive_stream_test.cc create mode 100644 quiche/quic/core/qpack/qpack_required_insert_count.cc create mode 100644 quiche/quic/core/qpack/qpack_required_insert_count.h create mode 100644 quiche/quic/core/qpack/qpack_required_insert_count_test.cc create mode 100644 quiche/quic/core/qpack/qpack_round_trip_test.cc create mode 100644 quiche/quic/core/qpack/qpack_send_stream.cc create mode 100644 quiche/quic/core/qpack/qpack_send_stream.h create mode 100644 quiche/quic/core/qpack/qpack_send_stream_test.cc create mode 100644 quiche/quic/core/qpack/qpack_static_table.cc create mode 100644 quiche/quic/core/qpack/qpack_static_table.h create mode 100644 quiche/quic/core/qpack/qpack_static_table_test.cc create mode 100644 quiche/quic/core/qpack/qpack_stream_receiver.h create mode 100644 quiche/quic/core/qpack/qpack_stream_sender_delegate.h create mode 100644 quiche/quic/core/qpack/value_splitting_header_list.cc create mode 100644 quiche/quic/core/qpack/value_splitting_header_list.h create mode 100644 quiche/quic/core/qpack/value_splitting_header_list_test.cc create mode 100644 quiche/quic/core/quic_ack_listener_interface.cc create mode 100644 quiche/quic/core/quic_ack_listener_interface.h create mode 100644 quiche/quic/core/quic_alarm.cc create mode 100644 quiche/quic/core/quic_alarm.h create mode 100644 quiche/quic/core/quic_alarm_factory.h create mode 100644 quiche/quic/core/quic_alarm_test.cc create mode 100644 quiche/quic/core/quic_arena_scoped_ptr.h create mode 100644 quiche/quic/core/quic_arena_scoped_ptr_test.cc create mode 100644 quiche/quic/core/quic_bandwidth.cc create mode 100644 quiche/quic/core/quic_bandwidth.h create mode 100644 quiche/quic/core/quic_bandwidth_test.cc create mode 100644 quiche/quic/core/quic_blocked_writer_interface.h create mode 100644 quiche/quic/core/quic_buffered_packet_store.cc create mode 100644 quiche/quic/core/quic_buffered_packet_store.h create mode 100644 quiche/quic/core/quic_buffered_packet_store_test.cc create mode 100644 quiche/quic/core/quic_chaos_protector.cc create mode 100644 quiche/quic/core/quic_chaos_protector.h create mode 100644 quiche/quic/core/quic_chaos_protector_test.cc create mode 100644 quiche/quic/core/quic_clock.h create mode 100644 quiche/quic/core/quic_coalesced_packet.cc create mode 100644 quiche/quic/core/quic_coalesced_packet.h create mode 100644 quiche/quic/core/quic_coalesced_packet_test.cc create mode 100644 quiche/quic/core/quic_config.cc create mode 100644 quiche/quic/core/quic_config.h create mode 100644 quiche/quic/core/quic_config_test.cc create mode 100644 quiche/quic/core/quic_connection.cc create mode 100644 quiche/quic/core/quic_connection.h create mode 100644 quiche/quic/core/quic_connection_context.cc create mode 100644 quiche/quic/core/quic_connection_context.h create mode 100644 quiche/quic/core/quic_connection_context_test.cc create mode 100644 quiche/quic/core/quic_connection_id.cc create mode 100644 quiche/quic/core/quic_connection_id.h create mode 100644 quiche/quic/core/quic_connection_id_manager.cc create mode 100644 quiche/quic/core/quic_connection_id_manager.h create mode 100644 quiche/quic/core/quic_connection_id_manager_test.cc create mode 100644 quiche/quic/core/quic_connection_id_test.cc create mode 100644 quiche/quic/core/quic_connection_stats.cc create mode 100644 quiche/quic/core/quic_connection_stats.h create mode 100644 quiche/quic/core/quic_connection_test.cc create mode 100644 quiche/quic/core/quic_constants.cc create mode 100644 quiche/quic/core/quic_constants.h create mode 100644 quiche/quic/core/quic_control_frame_manager.cc create mode 100644 quiche/quic/core/quic_control_frame_manager.h create mode 100644 quiche/quic/core/quic_control_frame_manager_test.cc create mode 100644 quiche/quic/core/quic_crypto_client_handshaker.cc create mode 100644 quiche/quic/core/quic_crypto_client_handshaker.h create mode 100644 quiche/quic/core/quic_crypto_client_handshaker_test.cc create mode 100644 quiche/quic/core/quic_crypto_client_stream.cc create mode 100644 quiche/quic/core/quic_crypto_client_stream.h create mode 100644 quiche/quic/core/quic_crypto_client_stream_test.cc create mode 100644 quiche/quic/core/quic_crypto_handshaker.cc create mode 100644 quiche/quic/core/quic_crypto_handshaker.h create mode 100644 quiche/quic/core/quic_crypto_server_stream.cc create mode 100644 quiche/quic/core/quic_crypto_server_stream.h create mode 100644 quiche/quic/core/quic_crypto_server_stream_base.cc create mode 100644 quiche/quic/core/quic_crypto_server_stream_base.h create mode 100644 quiche/quic/core/quic_crypto_server_stream_test.cc create mode 100644 quiche/quic/core/quic_crypto_stream.cc create mode 100644 quiche/quic/core/quic_crypto_stream.h create mode 100644 quiche/quic/core/quic_crypto_stream_test.cc create mode 100644 quiche/quic/core/quic_data_reader.cc create mode 100644 quiche/quic/core/quic_data_reader.h create mode 100644 quiche/quic/core/quic_data_writer.cc create mode 100644 quiche/quic/core/quic_data_writer.h create mode 100644 quiche/quic/core/quic_data_writer_test.cc create mode 100644 quiche/quic/core/quic_datagram_queue.cc create mode 100644 quiche/quic/core/quic_datagram_queue.h create mode 100644 quiche/quic/core/quic_datagram_queue_test.cc create mode 100644 quiche/quic/core/quic_default_clock.cc create mode 100644 quiche/quic/core/quic_default_clock.h create mode 100644 quiche/quic/core/quic_default_connection_helper.h create mode 100644 quiche/quic/core/quic_default_packet_writer.cc create mode 100644 quiche/quic/core/quic_default_packet_writer.h create mode 100644 quiche/quic/core/quic_dispatcher.cc create mode 100644 quiche/quic/core/quic_dispatcher.h create mode 100644 quiche/quic/core/quic_dispatcher_test.cc create mode 100644 quiche/quic/core/quic_error_codes.cc create mode 100644 quiche/quic/core/quic_error_codes.h create mode 100644 quiche/quic/core/quic_error_codes_test.cc create mode 100644 quiche/quic/core/quic_flags_list.h create mode 100644 quiche/quic/core/quic_flow_controller.cc create mode 100644 quiche/quic/core/quic_flow_controller.h create mode 100644 quiche/quic/core/quic_flow_controller_test.cc create mode 100644 quiche/quic/core/quic_framer.cc create mode 100644 quiche/quic/core/quic_framer.h create mode 100644 quiche/quic/core/quic_framer_test.cc create mode 100644 quiche/quic/core/quic_idle_network_detector.cc create mode 100644 quiche/quic/core/quic_idle_network_detector.h create mode 100644 quiche/quic/core/quic_idle_network_detector_test.cc create mode 100644 quiche/quic/core/quic_interval.h create mode 100644 quiche/quic/core/quic_interval_deque.h create mode 100644 quiche/quic/core/quic_interval_deque_test.cc create mode 100644 quiche/quic/core/quic_interval_set.h create mode 100644 quiche/quic/core/quic_interval_set_test.cc create mode 100644 quiche/quic/core/quic_interval_test.cc create mode 100644 quiche/quic/core/quic_linux_socket_utils.cc create mode 100644 quiche/quic/core/quic_linux_socket_utils.h create mode 100644 quiche/quic/core/quic_linux_socket_utils_test.cc create mode 100644 quiche/quic/core/quic_lru_cache.h create mode 100644 quiche/quic/core/quic_lru_cache_test.cc create mode 100644 quiche/quic/core/quic_mtu_discovery.cc create mode 100644 quiche/quic/core/quic_mtu_discovery.h create mode 100644 quiche/quic/core/quic_network_blackhole_detector.cc create mode 100644 quiche/quic/core/quic_network_blackhole_detector.h create mode 100644 quiche/quic/core/quic_network_blackhole_detector_test.cc create mode 100644 quiche/quic/core/quic_one_block_arena.h create mode 100644 quiche/quic/core/quic_one_block_arena_test.cc create mode 100644 quiche/quic/core/quic_packet_creator.cc create mode 100644 quiche/quic/core/quic_packet_creator.h create mode 100644 quiche/quic/core/quic_packet_creator_test.cc create mode 100644 quiche/quic/core/quic_packet_number.cc create mode 100644 quiche/quic/core/quic_packet_number.h create mode 100644 quiche/quic/core/quic_packet_number_test.cc create mode 100644 quiche/quic/core/quic_packet_reader.cc create mode 100644 quiche/quic/core/quic_packet_reader.h create mode 100644 quiche/quic/core/quic_packet_writer.h create mode 100644 quiche/quic/core/quic_packet_writer_wrapper.cc create mode 100644 quiche/quic/core/quic_packet_writer_wrapper.h create mode 100644 quiche/quic/core/quic_packets.cc create mode 100644 quiche/quic/core/quic_packets.h create mode 100644 quiche/quic/core/quic_packets_test.cc create mode 100644 quiche/quic/core/quic_path_validator.cc create mode 100644 quiche/quic/core/quic_path_validator.h create mode 100644 quiche/quic/core/quic_path_validator_test.cc create mode 100644 quiche/quic/core/quic_ping_manager.cc create mode 100644 quiche/quic/core/quic_ping_manager.h create mode 100644 quiche/quic/core/quic_ping_manager_test.cc create mode 100644 quiche/quic/core/quic_process_packet_interface.h create mode 100644 quiche/quic/core/quic_protocol_flags_list.h create mode 100644 quiche/quic/core/quic_received_packet_manager.cc create mode 100644 quiche/quic/core/quic_received_packet_manager.h create mode 100644 quiche/quic/core/quic_received_packet_manager_test.cc create mode 100644 quiche/quic/core/quic_sent_packet_manager.cc create mode 100644 quiche/quic/core/quic_sent_packet_manager.h create mode 100644 quiche/quic/core/quic_sent_packet_manager_test.cc create mode 100644 quiche/quic/core/quic_server_id.cc create mode 100644 quiche/quic/core/quic_server_id.h create mode 100644 quiche/quic/core/quic_server_id_test.cc create mode 100644 quiche/quic/core/quic_session.cc create mode 100644 quiche/quic/core/quic_session.h create mode 100644 quiche/quic/core/quic_session_test.cc create mode 100644 quiche/quic/core/quic_socket_address_coder.cc create mode 100644 quiche/quic/core/quic_socket_address_coder.h create mode 100644 quiche/quic/core/quic_socket_address_coder_test.cc create mode 100644 quiche/quic/core/quic_stream.cc create mode 100644 quiche/quic/core/quic_stream.h create mode 100644 quiche/quic/core/quic_stream_frame_data_producer.h create mode 100644 quiche/quic/core/quic_stream_id_manager.cc create mode 100644 quiche/quic/core/quic_stream_id_manager.h create mode 100644 quiche/quic/core/quic_stream_id_manager_test.cc create mode 100644 quiche/quic/core/quic_stream_priority.cc create mode 100644 quiche/quic/core/quic_stream_priority.h create mode 100644 quiche/quic/core/quic_stream_priority_test.cc create mode 100644 quiche/quic/core/quic_stream_send_buffer.cc create mode 100644 quiche/quic/core/quic_stream_send_buffer.h create mode 100644 quiche/quic/core/quic_stream_send_buffer_test.cc create mode 100644 quiche/quic/core/quic_stream_sequencer.cc create mode 100644 quiche/quic/core/quic_stream_sequencer.h create mode 100644 quiche/quic/core/quic_stream_sequencer_buffer.cc create mode 100644 quiche/quic/core/quic_stream_sequencer_buffer.h create mode 100644 quiche/quic/core/quic_stream_sequencer_buffer_test.cc create mode 100644 quiche/quic/core/quic_stream_sequencer_test.cc create mode 100644 quiche/quic/core/quic_stream_test.cc create mode 100644 quiche/quic/core/quic_sustained_bandwidth_recorder.cc create mode 100644 quiche/quic/core/quic_sustained_bandwidth_recorder.h create mode 100644 quiche/quic/core/quic_sustained_bandwidth_recorder_test.cc create mode 100644 quiche/quic/core/quic_syscall_wrapper.cc create mode 100644 quiche/quic/core/quic_syscall_wrapper.h create mode 100644 quiche/quic/core/quic_tag.cc create mode 100644 quiche/quic/core/quic_tag.h create mode 100644 quiche/quic/core/quic_tag_test.cc create mode 100644 quiche/quic/core/quic_time.cc create mode 100644 quiche/quic/core/quic_time.h create mode 100644 quiche/quic/core/quic_time_accumulator.h create mode 100644 quiche/quic/core/quic_time_accumulator_test.cc create mode 100644 quiche/quic/core/quic_time_test.cc create mode 100644 quiche/quic/core/quic_time_wait_list_manager.cc create mode 100644 quiche/quic/core/quic_time_wait_list_manager.h create mode 100644 quiche/quic/core/quic_time_wait_list_manager_test.cc create mode 100644 quiche/quic/core/quic_trace_visitor.cc create mode 100644 quiche/quic/core/quic_trace_visitor.h create mode 100644 quiche/quic/core/quic_trace_visitor_test.cc create mode 100644 quiche/quic/core/quic_transmission_info.cc create mode 100644 quiche/quic/core/quic_transmission_info.h create mode 100644 quiche/quic/core/quic_types.cc create mode 100644 quiche/quic/core/quic_types.h create mode 100644 quiche/quic/core/quic_udp_socket.h create mode 100644 quiche/quic/core/quic_udp_socket_posix.cc create mode 100644 quiche/quic/core/quic_unacked_packet_map.cc create mode 100644 quiche/quic/core/quic_unacked_packet_map.h create mode 100644 quiche/quic/core/quic_unacked_packet_map_test.cc create mode 100644 quiche/quic/core/quic_utils.cc create mode 100644 quiche/quic/core/quic_utils.h create mode 100644 quiche/quic/core/quic_utils_test.cc create mode 100644 quiche/quic/core/quic_version_manager.cc create mode 100644 quiche/quic/core/quic_version_manager.h create mode 100644 quiche/quic/core/quic_version_manager_test.cc create mode 100644 quiche/quic/core/quic_versions.cc create mode 100644 quiche/quic/core/quic_versions.h create mode 100644 quiche/quic/core/quic_versions_test.cc create mode 100644 quiche/quic/core/quic_write_blocked_list.cc create mode 100644 quiche/quic/core/quic_write_blocked_list.h create mode 100644 quiche/quic/core/quic_write_blocked_list_test.cc create mode 100644 quiche/quic/core/session_notifier_interface.h create mode 100644 quiche/quic/core/socket_factory.h create mode 100644 quiche/quic/core/stream_delegate_interface.h create mode 100644 quiche/quic/core/tls_chlo_extractor.cc create mode 100644 quiche/quic/core/tls_chlo_extractor.h create mode 100644 quiche/quic/core/tls_chlo_extractor_test.cc create mode 100644 quiche/quic/core/tls_client_handshaker.cc create mode 100644 quiche/quic/core/tls_client_handshaker.h create mode 100644 quiche/quic/core/tls_client_handshaker_test.cc create mode 100644 quiche/quic/core/tls_handshaker.cc create mode 100644 quiche/quic/core/tls_handshaker.h create mode 100644 quiche/quic/core/tls_server_handshaker.cc create mode 100644 quiche/quic/core/tls_server_handshaker.h create mode 100644 quiche/quic/core/tls_server_handshaker_test.cc create mode 100644 quiche/quic/core/uber_quic_stream_id_manager.cc create mode 100644 quiche/quic/core/uber_quic_stream_id_manager.h create mode 100644 quiche/quic/core/uber_quic_stream_id_manager_test.cc create mode 100644 quiche/quic/core/uber_received_packet_manager.cc create mode 100644 quiche/quic/core/uber_received_packet_manager.h create mode 100644 quiche/quic/core/uber_received_packet_manager_test.cc create mode 100644 quiche/quic/core/web_transport_interface.h create mode 100644 quiche/quic/load_balancer/load_balancer_config.cc create mode 100644 quiche/quic/load_balancer/load_balancer_config.h create mode 100644 quiche/quic/load_balancer/load_balancer_config_test.cc create mode 100644 quiche/quic/load_balancer/load_balancer_decoder.cc create mode 100644 quiche/quic/load_balancer/load_balancer_decoder.h create mode 100644 quiche/quic/load_balancer/load_balancer_decoder_test.cc create mode 100644 quiche/quic/load_balancer/load_balancer_encoder.cc create mode 100644 quiche/quic/load_balancer/load_balancer_encoder.h create mode 100644 quiche/quic/load_balancer/load_balancer_encoder_test.cc create mode 100644 quiche/quic/load_balancer/load_balancer_server_id.cc create mode 100644 quiche/quic/load_balancer/load_balancer_server_id.h create mode 100644 quiche/quic/load_balancer/load_balancer_server_id_map.h create mode 100644 quiche/quic/load_balancer/load_balancer_server_id_map_test.cc create mode 100644 quiche/quic/load_balancer/load_balancer_server_id_test.cc create mode 100644 quiche/quic/masque/README.md create mode 100644 quiche/quic/masque/masque_client.cc create mode 100644 quiche/quic/masque/masque_client.h create mode 100644 quiche/quic/masque/masque_client_bin.cc create mode 100644 quiche/quic/masque/masque_client_session.cc create mode 100644 quiche/quic/masque/masque_client_session.h create mode 100644 quiche/quic/masque/masque_client_tools.cc create mode 100644 quiche/quic/masque/masque_client_tools.h create mode 100644 quiche/quic/masque/masque_dispatcher.cc create mode 100644 quiche/quic/masque/masque_dispatcher.h create mode 100644 quiche/quic/masque/masque_encapsulated_client.cc create mode 100644 quiche/quic/masque/masque_encapsulated_client.h create mode 100644 quiche/quic/masque/masque_encapsulated_client_session.cc create mode 100644 quiche/quic/masque/masque_encapsulated_client_session.h create mode 100644 quiche/quic/masque/masque_server.cc create mode 100644 quiche/quic/masque/masque_server.h create mode 100644 quiche/quic/masque/masque_server_backend.cc create mode 100644 quiche/quic/masque/masque_server_backend.h create mode 100644 quiche/quic/masque/masque_server_bin.cc create mode 100644 quiche/quic/masque/masque_server_session.cc create mode 100644 quiche/quic/masque/masque_server_session.h create mode 100644 quiche/quic/masque/masque_utils.cc create mode 100644 quiche/quic/masque/masque_utils.h create mode 100644 quiche/quic/platform/README.md create mode 100644 quiche/quic/platform/api/README.md create mode 100644 quiche/quic/platform/api/quic_bug_tracker.h create mode 100644 quiche/quic/platform/api/quic_client_stats.h create mode 100644 quiche/quic/platform/api/quic_default_proof_providers.h create mode 100644 quiche/quic/platform/api/quic_expect_bug.h create mode 100644 quiche/quic/platform/api/quic_export.h create mode 100644 quiche/quic/platform/api/quic_exported_stats.h create mode 100644 quiche/quic/platform/api/quic_flag_utils.h create mode 100644 quiche/quic/platform/api/quic_flags.h create mode 100644 quiche/quic/platform/api/quic_hostname_utils.h create mode 100644 quiche/quic/platform/api/quic_ip_address.h create mode 100644 quiche/quic/platform/api/quic_ip_address_family.h create mode 100644 quiche/quic/platform/api/quic_logging.h create mode 100644 quiche/quic/platform/api/quic_mutex.h create mode 100644 quiche/quic/platform/api/quic_server_stats.h create mode 100644 quiche/quic/platform/api/quic_socket_address.cc create mode 100644 quiche/quic/platform/api/quic_socket_address.h create mode 100644 quiche/quic/platform/api/quic_socket_address_test.cc create mode 100644 quiche/quic/platform/api/quic_stack_trace.h create mode 100644 quiche/quic/platform/api/quic_test.h create mode 100644 quiche/quic/platform/api/quic_test_loopback.h create mode 100644 quiche/quic/platform/api/quic_test_output.h create mode 100644 quiche/quic/platform/api/quic_testvalue.h create mode 100644 quiche/quic/platform/api/quic_thread.h create mode 100644 quiche/quic/platform/api/quic_udp_socket_platform_api.h create mode 100644 quiche/quic/qbone/bonnet/icmp_reachable.cc create mode 100644 quiche/quic/qbone/bonnet/icmp_reachable.h create mode 100644 quiche/quic/qbone/bonnet/icmp_reachable_interface.h create mode 100644 quiche/quic/qbone/bonnet/icmp_reachable_test.cc create mode 100644 quiche/quic/qbone/bonnet/mock_icmp_reachable.h create mode 100644 quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h create mode 100644 quiche/quic/qbone/bonnet/mock_qbone_tunnel.h create mode 100644 quiche/quic/qbone/bonnet/mock_tun_device.h create mode 100644 quiche/quic/qbone/bonnet/mock_tun_device_controller.h create mode 100644 quiche/quic/qbone/bonnet/qbone_tunnel_info.cc create mode 100644 quiche/quic/qbone/bonnet/qbone_tunnel_info.h create mode 100644 quiche/quic/qbone/bonnet/qbone_tunnel_interface.h create mode 100644 quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc create mode 100644 quiche/quic/qbone/bonnet/qbone_tunnel_silo.h create mode 100644 quiche/quic/qbone/bonnet/qbone_tunnel_silo_test.cc create mode 100644 quiche/quic/qbone/bonnet/tun_device.cc create mode 100644 quiche/quic/qbone/bonnet/tun_device.h create mode 100644 quiche/quic/qbone/bonnet/tun_device_controller.cc create mode 100644 quiche/quic/qbone/bonnet/tun_device_controller.h create mode 100644 quiche/quic/qbone/bonnet/tun_device_controller_test.cc create mode 100644 quiche/quic/qbone/bonnet/tun_device_interface.h create mode 100644 quiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc create mode 100644 quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h create mode 100644 quiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc create mode 100644 quiche/quic/qbone/bonnet/tun_device_test.cc create mode 100644 quiche/quic/qbone/mock_qbone_client.h create mode 100644 quiche/quic/qbone/mock_qbone_server_session.h create mode 100644 quiche/quic/qbone/platform/icmp_packet.cc create mode 100644 quiche/quic/qbone/platform/icmp_packet.h create mode 100644 quiche/quic/qbone/platform/icmp_packet_test.cc create mode 100644 quiche/quic/qbone/platform/internet_checksum.cc create mode 100644 quiche/quic/qbone/platform/internet_checksum.h create mode 100644 quiche/quic/qbone/platform/internet_checksum_test.cc create mode 100644 quiche/quic/qbone/platform/ip_range.cc create mode 100644 quiche/quic/qbone/platform/ip_range.h create mode 100644 quiche/quic/qbone/platform/ip_range_test.cc create mode 100644 quiche/quic/qbone/platform/kernel_interface.h create mode 100644 quiche/quic/qbone/platform/mock_kernel.h create mode 100644 quiche/quic/qbone/platform/mock_netlink.h create mode 100644 quiche/quic/qbone/platform/netlink.cc create mode 100644 quiche/quic/qbone/platform/netlink.h create mode 100644 quiche/quic/qbone/platform/netlink_interface.h create mode 100644 quiche/quic/qbone/platform/netlink_test.cc create mode 100644 quiche/quic/qbone/platform/rtnetlink_message.cc create mode 100644 quiche/quic/qbone/platform/rtnetlink_message.h create mode 100644 quiche/quic/qbone/platform/rtnetlink_message_test.cc create mode 100644 quiche/quic/qbone/platform/tcp_packet.cc create mode 100644 quiche/quic/qbone/platform/tcp_packet.h create mode 100644 quiche/quic/qbone/platform/tcp_packet_test.cc create mode 100644 quiche/quic/qbone/qbone_client.cc create mode 100644 quiche/quic/qbone/qbone_client.h create mode 100644 quiche/quic/qbone/qbone_client_interface.h create mode 100644 quiche/quic/qbone/qbone_client_session.cc create mode 100644 quiche/quic/qbone/qbone_client_session.h create mode 100644 quiche/quic/qbone/qbone_client_test.cc create mode 100644 quiche/quic/qbone/qbone_constants.cc create mode 100644 quiche/quic/qbone/qbone_constants.h create mode 100644 quiche/quic/qbone/qbone_control.proto create mode 100644 quiche/quic/qbone/qbone_control_placeholder.proto create mode 100644 quiche/quic/qbone/qbone_control_stream.cc create mode 100644 quiche/quic/qbone/qbone_control_stream.h create mode 100644 quiche/quic/qbone/qbone_packet_exchanger.cc create mode 100644 quiche/quic/qbone/qbone_packet_exchanger.h create mode 100644 quiche/quic/qbone/qbone_packet_exchanger_test.cc create mode 100644 quiche/quic/qbone/qbone_packet_processor.cc create mode 100644 quiche/quic/qbone/qbone_packet_processor.h create mode 100644 quiche/quic/qbone/qbone_packet_processor_test.cc create mode 100644 quiche/quic/qbone/qbone_packet_processor_test_tools.cc create mode 100644 quiche/quic/qbone/qbone_packet_processor_test_tools.h create mode 100644 quiche/quic/qbone/qbone_packet_writer.h create mode 100644 quiche/quic/qbone/qbone_server_session.cc create mode 100644 quiche/quic/qbone/qbone_server_session.h create mode 100644 quiche/quic/qbone/qbone_session_base.cc create mode 100644 quiche/quic/qbone/qbone_session_base.h create mode 100644 quiche/quic/qbone/qbone_session_test.cc create mode 100644 quiche/quic/qbone/qbone_stream.cc create mode 100644 quiche/quic/qbone/qbone_stream.h create mode 100644 quiche/quic/qbone/qbone_stream_test.cc create mode 100644 quiche/quic/test_tools/bad_packet_writer.cc create mode 100644 quiche/quic/test_tools/bad_packet_writer.h create mode 100644 quiche/quic/test_tools/crypto_test_utils.cc create mode 100644 quiche/quic/test_tools/crypto_test_utils.h create mode 100644 quiche/quic/test_tools/crypto_test_utils_test.cc create mode 100644 quiche/quic/test_tools/failing_proof_source.cc create mode 100644 quiche/quic/test_tools/failing_proof_source.h create mode 100644 quiche/quic/test_tools/fake_proof_source.cc create mode 100644 quiche/quic/test_tools/fake_proof_source.h create mode 100644 quiche/quic/test_tools/fake_proof_source_handle.cc create mode 100644 quiche/quic/test_tools/fake_proof_source_handle.h create mode 100644 quiche/quic/test_tools/first_flight.cc create mode 100644 quiche/quic/test_tools/first_flight.h create mode 100644 quiche/quic/test_tools/fuzzing/README.md create mode 100644 quiche/quic/test_tools/fuzzing/quic_framer_fuzzer.cc create mode 100644 quiche/quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc create mode 100644 quiche/quic/test_tools/limited_mtu_test_writer.cc create mode 100644 quiche/quic/test_tools/limited_mtu_test_writer.h create mode 100644 quiche/quic/test_tools/mock_clock.cc create mode 100644 quiche/quic/test_tools/mock_clock.h create mode 100644 quiche/quic/test_tools/mock_connection_id_generator.h create mode 100644 quiche/quic/test_tools/mock_quic_client_promised_info.cc create mode 100644 quiche/quic/test_tools/mock_quic_client_promised_info.h create mode 100644 quiche/quic/test_tools/mock_quic_dispatcher.cc create mode 100644 quiche/quic/test_tools/mock_quic_dispatcher.h create mode 100644 quiche/quic/test_tools/mock_quic_session_visitor.cc create mode 100644 quiche/quic/test_tools/mock_quic_session_visitor.h create mode 100644 quiche/quic/test_tools/mock_quic_spdy_client_stream.cc create mode 100644 quiche/quic/test_tools/mock_quic_spdy_client_stream.h create mode 100644 quiche/quic/test_tools/mock_quic_time_wait_list_manager.cc create mode 100644 quiche/quic/test_tools/mock_quic_time_wait_list_manager.h create mode 100644 quiche/quic/test_tools/mock_random.cc create mode 100644 quiche/quic/test_tools/mock_random.h create mode 100644 quiche/quic/test_tools/packet_dropping_test_writer.cc create mode 100644 quiche/quic/test_tools/packet_dropping_test_writer.h create mode 100644 quiche/quic/test_tools/packet_reordering_writer.cc create mode 100644 quiche/quic/test_tools/packet_reordering_writer.h create mode 100644 quiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc create mode 100644 quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h create mode 100644 quiche/quic/test_tools/qpack/qpack_encoder_peer.cc create mode 100644 quiche/quic/test_tools/qpack/qpack_encoder_peer.h create mode 100644 quiche/quic/test_tools/qpack/qpack_offline_decoder.cc create mode 100644 quiche/quic/test_tools/qpack/qpack_offline_decoder.h create mode 100644 quiche/quic/test_tools/qpack/qpack_test_utils.cc create mode 100644 quiche/quic/test_tools/qpack/qpack_test_utils.h create mode 100644 quiche/quic/test_tools/quic_buffered_packet_store_peer.cc create mode 100644 quiche/quic/test_tools/quic_buffered_packet_store_peer.h create mode 100644 quiche/quic/test_tools/quic_client_promised_info_peer.cc create mode 100644 quiche/quic/test_tools/quic_client_promised_info_peer.h create mode 100644 quiche/quic/test_tools/quic_client_session_cache_peer.h create mode 100644 quiche/quic/test_tools/quic_coalesced_packet_peer.cc create mode 100644 quiche/quic/test_tools/quic_coalesced_packet_peer.h create mode 100644 quiche/quic/test_tools/quic_config_peer.cc create mode 100644 quiche/quic/test_tools/quic_config_peer.h create mode 100644 quiche/quic/test_tools/quic_connection_id_manager_peer.h create mode 100644 quiche/quic/test_tools/quic_connection_peer.cc create mode 100644 quiche/quic/test_tools/quic_connection_peer.h create mode 100644 quiche/quic/test_tools/quic_crypto_server_config_peer.cc create mode 100644 quiche/quic/test_tools/quic_crypto_server_config_peer.h create mode 100644 quiche/quic/test_tools/quic_dispatcher_peer.cc create mode 100644 quiche/quic/test_tools/quic_dispatcher_peer.h create mode 100644 quiche/quic/test_tools/quic_flow_controller_peer.cc create mode 100644 quiche/quic/test_tools/quic_flow_controller_peer.h create mode 100644 quiche/quic/test_tools/quic_framer_peer.cc create mode 100644 quiche/quic/test_tools/quic_framer_peer.h create mode 100644 quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/index.html create mode 100644 quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/map.html create mode 100644 quiche/quic/test_tools/quic_interval_deque_peer.h create mode 100644 quiche/quic/test_tools/quic_mock_syscall_wrapper.cc create mode 100644 quiche/quic/test_tools/quic_mock_syscall_wrapper.h create mode 100644 quiche/quic/test_tools/quic_packet_creator_peer.cc create mode 100644 quiche/quic/test_tools/quic_packet_creator_peer.h create mode 100644 quiche/quic/test_tools/quic_path_validator_peer.cc create mode 100644 quiche/quic/test_tools/quic_path_validator_peer.h create mode 100644 quiche/quic/test_tools/quic_sent_packet_manager_peer.cc create mode 100644 quiche/quic/test_tools/quic_sent_packet_manager_peer.h create mode 100644 quiche/quic/test_tools/quic_server_peer.cc create mode 100644 quiche/quic/test_tools/quic_server_peer.h create mode 100644 quiche/quic/test_tools/quic_server_session_base_peer.h create mode 100644 quiche/quic/test_tools/quic_session_peer.cc create mode 100644 quiche/quic/test_tools/quic_session_peer.h create mode 100644 quiche/quic/test_tools/quic_spdy_session_peer.cc create mode 100644 quiche/quic/test_tools/quic_spdy_session_peer.h create mode 100644 quiche/quic/test_tools/quic_spdy_stream_peer.cc create mode 100644 quiche/quic/test_tools/quic_spdy_stream_peer.h create mode 100644 quiche/quic/test_tools/quic_stream_id_manager_peer.cc create mode 100644 quiche/quic/test_tools/quic_stream_id_manager_peer.h create mode 100644 quiche/quic/test_tools/quic_stream_peer.cc create mode 100644 quiche/quic/test_tools/quic_stream_peer.h create mode 100644 quiche/quic/test_tools/quic_stream_send_buffer_peer.cc create mode 100644 quiche/quic/test_tools/quic_stream_send_buffer_peer.h create mode 100644 quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc create mode 100644 quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h create mode 100644 quiche/quic/test_tools/quic_stream_sequencer_peer.cc create mode 100644 quiche/quic/test_tools/quic_stream_sequencer_peer.h create mode 100644 quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc create mode 100644 quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h create mode 100644 quiche/quic/test_tools/quic_test_backend.cc create mode 100644 quiche/quic/test_tools/quic_test_backend.h create mode 100644 quiche/quic/test_tools/quic_test_client.cc create mode 100644 quiche/quic/test_tools/quic_test_client.h create mode 100644 quiche/quic/test_tools/quic_test_server.cc create mode 100644 quiche/quic/test_tools/quic_test_server.h create mode 100644 quiche/quic/test_tools/quic_test_utils.cc create mode 100644 quiche/quic/test_tools/quic_test_utils.h create mode 100644 quiche/quic/test_tools/quic_test_utils_test.cc create mode 100644 quiche/quic/test_tools/quic_time_wait_list_manager_peer.cc create mode 100644 quiche/quic/test_tools/quic_time_wait_list_manager_peer.h create mode 100644 quiche/quic/test_tools/quic_unacked_packet_map_peer.cc create mode 100644 quiche/quic/test_tools/quic_unacked_packet_map_peer.h create mode 100644 quiche/quic/test_tools/rtt_stats_peer.cc create mode 100644 quiche/quic/test_tools/rtt_stats_peer.h create mode 100644 quiche/quic/test_tools/send_algorithm_test_result.proto create mode 100644 quiche/quic/test_tools/send_algorithm_test_utils.cc create mode 100644 quiche/quic/test_tools/send_algorithm_test_utils.h create mode 100644 quiche/quic/test_tools/server_thread.cc create mode 100644 quiche/quic/test_tools/server_thread.h create mode 100644 quiche/quic/test_tools/simple_data_producer.cc create mode 100644 quiche/quic/test_tools/simple_data_producer.h create mode 100644 quiche/quic/test_tools/simple_quic_framer.cc create mode 100644 quiche/quic/test_tools/simple_quic_framer.h create mode 100644 quiche/quic/test_tools/simple_session_cache.cc create mode 100644 quiche/quic/test_tools/simple_session_cache.h create mode 100644 quiche/quic/test_tools/simple_session_notifier.cc create mode 100644 quiche/quic/test_tools/simple_session_notifier.h create mode 100644 quiche/quic/test_tools/simple_session_notifier_test.cc create mode 100644 quiche/quic/test_tools/simulator/README.md create mode 100644 quiche/quic/test_tools/simulator/actor.cc create mode 100644 quiche/quic/test_tools/simulator/actor.h create mode 100644 quiche/quic/test_tools/simulator/alarm_factory.cc create mode 100644 quiche/quic/test_tools/simulator/alarm_factory.h create mode 100644 quiche/quic/test_tools/simulator/link.cc create mode 100644 quiche/quic/test_tools/simulator/link.h create mode 100644 quiche/quic/test_tools/simulator/packet_filter.cc create mode 100644 quiche/quic/test_tools/simulator/packet_filter.h create mode 100644 quiche/quic/test_tools/simulator/port.cc create mode 100644 quiche/quic/test_tools/simulator/port.h create mode 100644 quiche/quic/test_tools/simulator/queue.cc create mode 100644 quiche/quic/test_tools/simulator/queue.h create mode 100644 quiche/quic/test_tools/simulator/quic_endpoint.cc create mode 100644 quiche/quic/test_tools/simulator/quic_endpoint.h create mode 100644 quiche/quic/test_tools/simulator/quic_endpoint_base.cc create mode 100644 quiche/quic/test_tools/simulator/quic_endpoint_base.h create mode 100644 quiche/quic/test_tools/simulator/quic_endpoint_test.cc create mode 100644 quiche/quic/test_tools/simulator/simulator.cc create mode 100644 quiche/quic/test_tools/simulator/simulator.h create mode 100644 quiche/quic/test_tools/simulator/simulator_test.cc create mode 100644 quiche/quic/test_tools/simulator/switch.cc create mode 100644 quiche/quic/test_tools/simulator/switch.h create mode 100644 quiche/quic/test_tools/simulator/test_harness.cc create mode 100644 quiche/quic/test_tools/simulator/test_harness.h create mode 100644 quiche/quic/test_tools/simulator/traffic_policer.cc create mode 100644 quiche/quic/test_tools/simulator/traffic_policer.h create mode 100644 quiche/quic/test_tools/test_certificates.cc create mode 100644 quiche/quic/test_tools/test_certificates.h create mode 100644 quiche/quic/test_tools/test_ticket_crypter.cc create mode 100644 quiche/quic/test_tools/test_ticket_crypter.h create mode 100644 quiche/quic/test_tools/web_transport_resets_backend.cc create mode 100644 quiche/quic/test_tools/web_transport_resets_backend.h create mode 100644 quiche/quic/test_tools/web_transport_test_tools.h create mode 100644 quiche/quic/tools/connect_server_backend.cc create mode 100644 quiche/quic/tools/connect_server_backend.h create mode 100644 quiche/quic/tools/connect_tunnel.cc create mode 100644 quiche/quic/tools/connect_tunnel.h create mode 100644 quiche/quic/tools/connect_tunnel_test.cc create mode 100644 quiche/quic/tools/connect_udp_tunnel.cc create mode 100644 quiche/quic/tools/connect_udp_tunnel.h create mode 100644 quiche/quic/tools/connect_udp_tunnel_test.cc create mode 100644 quiche/quic/tools/crypto_message_printer_bin.cc create mode 100644 quiche/quic/tools/fake_proof_verifier.h create mode 100644 quiche/quic/tools/qpack_offline_decoder_bin.cc create mode 100644 quiche/quic/tools/quic_backend_response.cc create mode 100644 quiche/quic/tools/quic_backend_response.h create mode 100644 quiche/quic/tools/quic_client_base.cc create mode 100644 quiche/quic/tools/quic_client_base.h create mode 100644 quiche/quic/tools/quic_client_bin.cc create mode 100644 quiche/quic/tools/quic_client_default_network_helper.cc create mode 100644 quiche/quic/tools/quic_client_default_network_helper.h create mode 100644 quiche/quic/tools/quic_client_factory.h create mode 100644 quiche/quic/tools/quic_client_interop_test_bin.cc create mode 100644 quiche/quic/tools/quic_default_client.cc create mode 100644 quiche/quic/tools/quic_default_client.h create mode 100644 quiche/quic/tools/quic_default_client_test.cc create mode 100644 quiche/quic/tools/quic_epoll_client_factory.cc create mode 100644 quiche/quic/tools/quic_epoll_client_factory.h create mode 100644 quiche/quic/tools/quic_memory_cache_backend.cc create mode 100644 quiche/quic/tools/quic_memory_cache_backend.h create mode 100644 quiche/quic/tools/quic_memory_cache_backend_test.cc create mode 100644 quiche/quic/tools/quic_name_lookup.cc create mode 100644 quiche/quic/tools/quic_name_lookup.h create mode 100644 quiche/quic/tools/quic_packet_printer_bin.cc create mode 100644 quiche/quic/tools/quic_reject_reason_decoder_bin.cc create mode 100644 quiche/quic/tools/quic_server.cc create mode 100644 quiche/quic/tools/quic_server.h create mode 100644 quiche/quic/tools/quic_server_bin.cc create mode 100644 quiche/quic/tools/quic_server_factory.cc create mode 100644 quiche/quic/tools/quic_server_factory.h create mode 100644 quiche/quic/tools/quic_server_test.cc create mode 100644 quiche/quic/tools/quic_simple_client_session.cc create mode 100644 quiche/quic/tools/quic_simple_client_session.h create mode 100644 quiche/quic/tools/quic_simple_client_stream.cc create mode 100644 quiche/quic/tools/quic_simple_client_stream.h create mode 100644 quiche/quic/tools/quic_simple_crypto_server_stream_helper.cc create mode 100644 quiche/quic/tools/quic_simple_crypto_server_stream_helper.h create mode 100644 quiche/quic/tools/quic_simple_dispatcher.cc create mode 100644 quiche/quic/tools/quic_simple_dispatcher.h create mode 100644 quiche/quic/tools/quic_simple_server_backend.h create mode 100644 quiche/quic/tools/quic_simple_server_session.cc create mode 100644 quiche/quic/tools/quic_simple_server_session.h create mode 100644 quiche/quic/tools/quic_simple_server_session_test.cc create mode 100644 quiche/quic/tools/quic_simple_server_stream.cc create mode 100644 quiche/quic/tools/quic_simple_server_stream.h create mode 100644 quiche/quic/tools/quic_simple_server_stream_test.cc create mode 100644 quiche/quic/tools/quic_spdy_client_base.cc create mode 100644 quiche/quic/tools/quic_spdy_client_base.h create mode 100644 quiche/quic/tools/quic_spdy_server_base.h create mode 100644 quiche/quic/tools/quic_tcp_like_trace_converter.cc create mode 100644 quiche/quic/tools/quic_tcp_like_trace_converter.h create mode 100644 quiche/quic/tools/quic_tcp_like_trace_converter_test.cc create mode 100644 quiche/quic/tools/quic_toy_client.cc create mode 100644 quiche/quic/tools/quic_toy_client.h create mode 100644 quiche/quic/tools/quic_toy_server.cc create mode 100644 quiche/quic/tools/quic_toy_server.h create mode 100644 quiche/quic/tools/quic_url.cc create mode 100644 quiche/quic/tools/quic_url.h create mode 100644 quiche/quic/tools/quic_url_test.cc create mode 100644 quiche/quic/tools/simple_ticket_crypter.cc create mode 100644 quiche/quic/tools/simple_ticket_crypter.h create mode 100644 quiche/quic/tools/simple_ticket_crypter_test.cc create mode 100644 quiche/quic/tools/web_transport_test_visitors.h create mode 100644 quiche/spdy/core/array_output_buffer.cc create mode 100644 quiche/spdy/core/array_output_buffer.h create mode 100644 quiche/spdy/core/array_output_buffer_test.cc create mode 100644 quiche/spdy/core/header_byte_listener_interface.h create mode 100644 quiche/spdy/core/hpack/hpack_constants.cc create mode 100644 quiche/spdy/core/hpack/hpack_constants.h create mode 100644 quiche/spdy/core/hpack/hpack_decoder_adapter.cc create mode 100644 quiche/spdy/core/hpack/hpack_decoder_adapter.h create mode 100644 quiche/spdy/core/hpack/hpack_decoder_adapter_test.cc create mode 100644 quiche/spdy/core/hpack/hpack_encoder.cc create mode 100644 quiche/spdy/core/hpack/hpack_encoder.h create mode 100644 quiche/spdy/core/hpack/hpack_encoder_test.cc create mode 100644 quiche/spdy/core/hpack/hpack_entry.cc create mode 100644 quiche/spdy/core/hpack/hpack_entry.h create mode 100644 quiche/spdy/core/hpack/hpack_entry_test.cc create mode 100644 quiche/spdy/core/hpack/hpack_header_table.cc create mode 100644 quiche/spdy/core/hpack/hpack_header_table.h create mode 100644 quiche/spdy/core/hpack/hpack_header_table_test.cc create mode 100644 quiche/spdy/core/hpack/hpack_output_stream.cc create mode 100644 quiche/spdy/core/hpack/hpack_output_stream.h create mode 100644 quiche/spdy/core/hpack/hpack_output_stream_test.cc create mode 100644 quiche/spdy/core/hpack/hpack_round_trip_test.cc create mode 100644 quiche/spdy/core/hpack/hpack_static_table.cc create mode 100644 quiche/spdy/core/hpack/hpack_static_table.h create mode 100644 quiche/spdy/core/hpack/hpack_static_table_test.cc create mode 100644 quiche/spdy/core/http2_frame_decoder_adapter.cc create mode 100644 quiche/spdy/core/http2_frame_decoder_adapter.h create mode 100644 quiche/spdy/core/http2_header_block.cc create mode 100644 quiche/spdy/core/http2_header_block.h create mode 100644 quiche/spdy/core/http2_header_block_hpack_listener.h create mode 100644 quiche/spdy/core/http2_header_block_test.cc create mode 100644 quiche/spdy/core/http2_header_storage.cc create mode 100644 quiche/spdy/core/http2_header_storage.h create mode 100644 quiche/spdy/core/http2_header_storage_test.cc create mode 100644 quiche/spdy/core/metadata_extension.cc create mode 100644 quiche/spdy/core/metadata_extension.h create mode 100644 quiche/spdy/core/metadata_extension_test.cc create mode 100644 quiche/spdy/core/no_op_headers_handler.h create mode 100644 quiche/spdy/core/recording_headers_handler.cc create mode 100644 quiche/spdy/core/recording_headers_handler.h create mode 100644 quiche/spdy/core/spdy_alt_svc_wire_format.cc create mode 100644 quiche/spdy/core/spdy_alt_svc_wire_format.h create mode 100644 quiche/spdy/core/spdy_alt_svc_wire_format_test.cc create mode 100644 quiche/spdy/core/spdy_bitmasks.h create mode 100644 quiche/spdy/core/spdy_frame_builder.cc create mode 100644 quiche/spdy/core/spdy_frame_builder.h create mode 100644 quiche/spdy/core/spdy_frame_builder_test.cc create mode 100644 quiche/spdy/core/spdy_framer.cc create mode 100644 quiche/spdy/core/spdy_framer.h create mode 100644 quiche/spdy/core/spdy_framer_test.cc create mode 100644 quiche/spdy/core/spdy_headers_handler_interface.h create mode 100644 quiche/spdy/core/spdy_intrusive_list.h create mode 100644 quiche/spdy/core/spdy_intrusive_list_test.cc create mode 100644 quiche/spdy/core/spdy_no_op_visitor.cc create mode 100644 quiche/spdy/core/spdy_no_op_visitor.h create mode 100644 quiche/spdy/core/spdy_pinnable_buffer_piece.cc create mode 100644 quiche/spdy/core/spdy_pinnable_buffer_piece.h create mode 100644 quiche/spdy/core/spdy_pinnable_buffer_piece_test.cc create mode 100644 quiche/spdy/core/spdy_prefixed_buffer_reader.cc create mode 100644 quiche/spdy/core/spdy_prefixed_buffer_reader.h create mode 100644 quiche/spdy/core/spdy_prefixed_buffer_reader_test.cc create mode 100644 quiche/spdy/core/spdy_protocol.cc create mode 100644 quiche/spdy/core/spdy_protocol.h create mode 100644 quiche/spdy/core/spdy_protocol_test.cc create mode 100644 quiche/spdy/core/spdy_simple_arena.cc create mode 100644 quiche/spdy/core/spdy_simple_arena.h create mode 100644 quiche/spdy/core/spdy_simple_arena_test.cc create mode 100644 quiche/spdy/core/zero_copy_output_buffer.h create mode 100644 quiche/spdy/test_tools/mock_spdy_framer_visitor.cc create mode 100644 quiche/spdy/test_tools/mock_spdy_framer_visitor.h create mode 100644 quiche/spdy/test_tools/spdy_test_utils.cc create mode 100644 quiche/spdy/test_tools/spdy_test_utils.h create mode 100644 quiche/web_transport/test_tools/mock_web_transport.h create mode 100644 quiche/web_transport/web_transport.h diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 000000000000..dda6d240e89e --- /dev/null +++ b/.bazelrc @@ -0,0 +1,11 @@ +build --cxxopt=-std=c++17 +build --cxxopt=-fno-rtti + +# Enable Abseil/Googletest integration +build --define absl=1 + +# Don't fail on converting "0xff" to char +build --copt=-Wno-narrowing + +# There is no system ICU on non-Linux platforms +build:macos --@com_google_googleurl//build_config:system_icu=0 diff --git a/BUILD.bazel b/BUILD.bazel new file mode 100644 index 000000000000..044dc1f96493 --- /dev/null +++ b/BUILD.bazel @@ -0,0 +1,7 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +licenses(["notice"]) + +exports_files(["LICENSE"]) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000000..cb06def9e13c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,35 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +The QUICHE repository is currently not set up to accept pull requests directly. +If you would like to make a contribution, please follow these steps: + +1. Sign the Contributor License Agreement (see above). +2. Create a Gerrit pull request at , or + a GitHub pull request at . +3. Email a link to your pull request to . +4. An engineer will review your pull request and merge it internally. + +Note: if you are a Google engineer with access to google3, please submit a CL to +google3 directly. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000000..a32e00ce6be3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 000000000000..0da8b0661380 --- /dev/null +++ b/README.md @@ -0,0 +1,28 @@ +# QUICHE + +QUICHE stands for QUIC, Http, Etc. It is Google's production-ready +implementation of QUIC, HTTP/2, HTTP/3, and related protocols and tools. It +powers Google's servers, Chromium, Envoy, and other projects. It is actively +developed and maintained. + +There are two public QUICHE repositories. Either one may be used by embedders, +as they are automatically kept in sync: + +* https://quiche.googlesource.com/quiche +* https://github.com/google/quiche + +To embed QUICHE in your project, platform APIs need to be implemented and build +files need to be created. Note that it is on the QUICHE team's roadmap to +include default implementation for all platform APIs and to open-source build +files. In the meanwhile, take a look at open source embedders like Chromium and +Envoy to get started: + +* [Platform implementations in Chromium](https://source.chromium.org/chromium/chromium/src/+/main:net/third_party/quiche/overrides/quiche_platform_impl/) +* [Build file in Chromium](https://source.chromium.org/chromium/chromium/src/+/main:net/third_party/quiche/BUILD.gn) +* [Platform implementations in Envoy](https://github.com/envoyproxy/envoy/tree/master/source/common/quic/platform) +* [Build file in Envoy](https://github.com/envoyproxy/envoy/blob/main/bazel/external/quiche.BUILD) + +To contribute to QUICHE, follow instructions at +[CONTRIBUTING.md](CONTRIBUTING.md). + +QUICHE is only supported on little-endian platforms. diff --git a/WHITESPACE b/WHITESPACE new file mode 100644 index 000000000000..13c1880b42ee --- /dev/null +++ b/WHITESPACE @@ -0,0 +1,3 @@ +Edits in this file will cause a Copybara migration. + +1 2 3 4 5 6 \ No newline at end of file diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel new file mode 100644 index 000000000000..62e2c6f756b2 --- /dev/null +++ b/WORKSPACE.bazel @@ -0,0 +1,87 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +workspace(name = "com_google_quiche") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# -------- Bazel tooling dependencies -------- + +http_archive( + name = "bazel_skylib", + sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728", # Last updated 2022-05-18 + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz"], +) + +# -------- Dependencies used in core QUICHE build -------- + +http_archive( + name = "com_google_absl", + sha256 = "d33809a982df8705f5220d1acb7cc63650e692a12dc2a8ef3e68b8959a1cee02", # Last updated 2023-04-12 + strip_prefix = "abseil-cpp-32d314d0f5bb0ca3ff71ece49c71a728c128d43e", + urls = ["https://github.com/abseil/abseil-cpp/archive/32d314d0f5bb0ca3ff71ece49c71a728c128d43e.zip"], +) + +http_archive( + name = "com_google_protobuf", + sha256 = "8b28fdd45bab62d15db232ec404248901842e5340299a57765e48abe8a80d930", # Last updated 2022-05-18 + strip_prefix = "protobuf-3.20.1", + urls = ["https://github.com/protocolbuffers/protobuf/archive/refs/tags/v3.20.1.tar.gz"], +) + +http_archive( + name = "boringssl", + sha256 = "5d299325d1db8b2f2db3d927c7bc1f9fcbd05a3f9b5c8239fa527c09bf97f995", # Last updated 2022-10-19 + strip_prefix = "boringssl-0acfcff4be10514aacb98eb8ab27bb60136d131b", + urls = ["https://github.com/google/boringssl/archive/0acfcff4be10514aacb98eb8ab27bb60136d131b.tar.gz"], +) + +http_archive( + name = "com_google_quic_trace", + sha256 = "079331de8c3cbf145a3b57adb3ad4e73d733ecfa84d3486e1c5a9eaeef286549", # Last updated 2022-05-18 + strip_prefix = "quic-trace-c7b993eb750e60c307e82f75763600d9c06a6de1", + urls = ["https://github.com/google/quic-trace/archive/c7b993eb750e60c307e82f75763600d9c06a6de1.tar.gz"], +) + +http_archive( + name = "com_google_googleurl", + sha256 = "a1bc96169d34dcc1406ffb750deef3bc8718bd1f9069a2878838e1bd905de989", # Last updated 2022-04-04 + urls = ["https://storage.googleapis.com/quiche-envoy-integration/googleurl_9cdb1f4d1a365ebdbcbf179dadf7f8aa5ee802e7.tar.gz"], +) + +http_archive( + name = "zlib", + build_file = "//build:zlib.BUILD", + sha256 = "d8688496ea40fb61787500e863cc63c9afcbc524468cedeb478068924eb54932", # Last updated 2022-05-18 + strip_prefix = "zlib-1.2.12", + urls = ["https://github.com/madler/zlib/archive/refs/tags/v1.2.12.tar.gz"], +) + +# -------- Dependencies used by QUICHE tests and extra tooling -------- + +http_archive( + name = "com_google_googletest", + sha256 = "82808543c49488e712d9bd84c50edf40d692ffdaca552b4b019b8b533d3cf8ef", # Last updated 2023-04-12 + strip_prefix = "googletest-12a5852e451baabc79c63a86c634912c563d57bc", + urls = ["https://github.com/google/googletest/archive/12a5852e451baabc79c63a86c634912c563d57bc.zip"], +) + +# Note this must use a commit from the `abseil` branch of the RE2 project. +# https://github.com/google/re2/tree/abseil +http_archive( + name = "com_googlesource_code_re2", + sha256 = "906d0df8ff48f8d3a00a808827f009a840190f404559f649cb8e4d7143255ef9", # Last updated 2022-04-08 + strip_prefix = "re2-a276a8c738735a0fe45a6ee590fe2df69bcf4502", + urls = ["https://github.com/google/re2/archive/a276a8c738735a0fe45a6ee590fe2df69bcf4502.zip"], +) + +# -------- Load and call dependencies of underlying libraries -------- + +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") + +bazel_skylib_workspace() + +load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") + +protobuf_deps() diff --git a/build/BUILD.bazel b/build/BUILD.bazel new file mode 100644 index 000000000000..0ba41e4d1400 --- /dev/null +++ b/build/BUILD.bazel @@ -0,0 +1,7 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +licenses(["notice"]) + +exports_files(["source_list.json"]) diff --git a/build/source_list.bzl b/build/source_list.bzl new file mode 100644 index 000000000000..3ddc596f74e9 --- /dev/null +++ b/build/source_list.bzl @@ -0,0 +1,1637 @@ +"""Autogenerated source file list for QUICHE Bazel build.""" + +protobuf = [ + "quic/core/proto/cached_network_parameters.proto", + "quic/core/proto/crypto_server_config.proto", + "quic/core/proto/source_address_token.proto", +] +protobuf_test_support = [ + "quic/test_tools/send_algorithm_test_result.proto", +] +quiche_core_hdrs = [ + "balsa/balsa_enums.h", + "balsa/balsa_frame.h", + "balsa/balsa_headers.h", + "balsa/balsa_visitor_interface.h", + "balsa/framer_interface.h", + "balsa/header_api.h", + "balsa/header_properties.h", + "balsa/http_validation_policy.h", + "balsa/noop_balsa_visitor.h", + "balsa/simple_buffer.h", + "balsa/standard_header_map.h", + "common/btree_scheduler.h", + "common/capsule.h", + "common/masque/connect_udp_datagram_payload.h", + "common/platform/api/quiche_bug_tracker.h", + "common/platform/api/quiche_client_stats.h", + "common/platform/api/quiche_containers.h", + "common/platform/api/quiche_export.h", + "common/platform/api/quiche_flag_utils.h", + "common/platform/api/quiche_flags.h", + "common/platform/api/quiche_header_policy.h", + "common/platform/api/quiche_hostname_utils.h", + "common/platform/api/quiche_iovec.h", + "common/platform/api/quiche_logging.h", + "common/platform/api/quiche_lower_case_string.h", + "common/platform/api/quiche_mem_slice.h", + "common/platform/api/quiche_mutex.h", + "common/platform/api/quiche_prefetch.h", + "common/platform/api/quiche_reference_counted.h", + "common/platform/api/quiche_server_stats.h", + "common/platform/api/quiche_stack_trace.h", + "common/platform/api/quiche_testvalue.h", + "common/platform/api/quiche_thread.h", + "common/platform/api/quiche_time_utils.h", + "common/platform/api/quiche_url_utils.h", + "common/print_elements.h", + "common/quiche_buffer_allocator.h", + "common/quiche_circular_deque.h", + "common/quiche_crypto_logging.h", + "common/quiche_data_reader.h", + "common/quiche_data_writer.h", + "common/quiche_endian.h", + "common/quiche_ip_address.h", + "common/quiche_ip_address_family.h", + "common/quiche_linked_hash_map.h", + "common/quiche_mem_slice_storage.h", + "common/quiche_protocol_flags_list.h", + "common/quiche_random.h", + "common/quiche_status_utils.h", + "common/quiche_stream.h", + "common/quiche_text_utils.h", + "common/simple_buffer_allocator.h", + "common/structured_headers.h", + "common/wire_serialization.h", + "http2/adapter/data_source.h", + "http2/adapter/event_forwarder.h", + "http2/adapter/header_validator.h", + "http2/adapter/header_validator_base.h", + "http2/adapter/http2_adapter.h", + "http2/adapter/http2_protocol.h", + "http2/adapter/http2_session.h", + "http2/adapter/http2_util.h", + "http2/adapter/http2_visitor_interface.h", + "http2/adapter/noop_header_validator.h", + "http2/adapter/oghttp2_adapter.h", + "http2/adapter/oghttp2_session.h", + "http2/adapter/oghttp2_util.h", + "http2/adapter/window_manager.h", + "http2/core/http2_trace_logging.h", + "http2/core/priority_write_scheduler.h", + "http2/decoder/decode_buffer.h", + "http2/decoder/decode_http2_structures.h", + "http2/decoder/decode_status.h", + "http2/decoder/frame_decoder_state.h", + "http2/decoder/http2_frame_decoder.h", + "http2/decoder/http2_frame_decoder_listener.h", + "http2/decoder/http2_structure_decoder.h", + "http2/decoder/payload_decoders/altsvc_payload_decoder.h", + "http2/decoder/payload_decoders/continuation_payload_decoder.h", + "http2/decoder/payload_decoders/data_payload_decoder.h", + "http2/decoder/payload_decoders/goaway_payload_decoder.h", + "http2/decoder/payload_decoders/headers_payload_decoder.h", + "http2/decoder/payload_decoders/ping_payload_decoder.h", + "http2/decoder/payload_decoders/priority_payload_decoder.h", + "http2/decoder/payload_decoders/priority_update_payload_decoder.h", + "http2/decoder/payload_decoders/push_promise_payload_decoder.h", + "http2/decoder/payload_decoders/rst_stream_payload_decoder.h", + "http2/decoder/payload_decoders/settings_payload_decoder.h", + "http2/decoder/payload_decoders/unknown_payload_decoder.h", + "http2/decoder/payload_decoders/window_update_payload_decoder.h", + "http2/hpack/decoder/hpack_block_decoder.h", + "http2/hpack/decoder/hpack_decoder.h", + "http2/hpack/decoder/hpack_decoder_listener.h", + "http2/hpack/decoder/hpack_decoder_state.h", + "http2/hpack/decoder/hpack_decoder_string_buffer.h", + "http2/hpack/decoder/hpack_decoder_tables.h", + "http2/hpack/decoder/hpack_decoding_error.h", + "http2/hpack/decoder/hpack_entry_decoder.h", + "http2/hpack/decoder/hpack_entry_decoder_listener.h", + "http2/hpack/decoder/hpack_entry_type_decoder.h", + "http2/hpack/decoder/hpack_string_decoder.h", + "http2/hpack/decoder/hpack_string_decoder_listener.h", + "http2/hpack/decoder/hpack_whole_entry_buffer.h", + "http2/hpack/decoder/hpack_whole_entry_listener.h", + "http2/hpack/http2_hpack_constants.h", + "http2/hpack/huffman/hpack_huffman_decoder.h", + "http2/hpack/huffman/hpack_huffman_encoder.h", + "http2/hpack/huffman/huffman_spec_tables.h", + "http2/hpack/varint/hpack_varint_decoder.h", + "http2/hpack/varint/hpack_varint_encoder.h", + "http2/http2_constants.h", + "http2/http2_structures.h", + "quic/core/chlo_extractor.h", + "quic/core/congestion_control/bandwidth_sampler.h", + "quic/core/congestion_control/bbr2_drain.h", + "quic/core/congestion_control/bbr2_misc.h", + "quic/core/congestion_control/bbr2_probe_bw.h", + "quic/core/congestion_control/bbr2_probe_rtt.h", + "quic/core/congestion_control/bbr2_sender.h", + "quic/core/congestion_control/bbr2_startup.h", + "quic/core/congestion_control/bbr_sender.h", + "quic/core/congestion_control/cubic_bytes.h", + "quic/core/congestion_control/general_loss_algorithm.h", + "quic/core/congestion_control/hybrid_slow_start.h", + "quic/core/congestion_control/loss_detection_interface.h", + "quic/core/congestion_control/pacing_sender.h", + "quic/core/congestion_control/prr_sender.h", + "quic/core/congestion_control/rtt_stats.h", + "quic/core/congestion_control/send_algorithm_interface.h", + "quic/core/congestion_control/tcp_cubic_sender_bytes.h", + "quic/core/congestion_control/uber_loss_algorithm.h", + "quic/core/congestion_control/windowed_filter.h", + "quic/core/connecting_client_socket.h", + "quic/core/connection_id_generator.h", + "quic/core/crypto/aead_base_decrypter.h", + "quic/core/crypto/aead_base_encrypter.h", + "quic/core/crypto/aes_128_gcm_12_decrypter.h", + "quic/core/crypto/aes_128_gcm_12_encrypter.h", + "quic/core/crypto/aes_128_gcm_decrypter.h", + "quic/core/crypto/aes_128_gcm_encrypter.h", + "quic/core/crypto/aes_256_gcm_decrypter.h", + "quic/core/crypto/aes_256_gcm_encrypter.h", + "quic/core/crypto/aes_base_decrypter.h", + "quic/core/crypto/aes_base_encrypter.h", + "quic/core/crypto/boring_utils.h", + "quic/core/crypto/cert_compressor.h", + "quic/core/crypto/certificate_util.h", + "quic/core/crypto/certificate_view.h", + "quic/core/crypto/chacha20_poly1305_decrypter.h", + "quic/core/crypto/chacha20_poly1305_encrypter.h", + "quic/core/crypto/chacha20_poly1305_tls_decrypter.h", + "quic/core/crypto/chacha20_poly1305_tls_encrypter.h", + "quic/core/crypto/chacha_base_decrypter.h", + "quic/core/crypto/chacha_base_encrypter.h", + "quic/core/crypto/channel_id.h", + "quic/core/crypto/client_proof_source.h", + "quic/core/crypto/crypto_framer.h", + "quic/core/crypto/crypto_handshake.h", + "quic/core/crypto/crypto_handshake_message.h", + "quic/core/crypto/crypto_message_parser.h", + "quic/core/crypto/crypto_protocol.h", + "quic/core/crypto/crypto_secret_boxer.h", + "quic/core/crypto/crypto_utils.h", + "quic/core/crypto/curve25519_key_exchange.h", + "quic/core/crypto/key_exchange.h", + "quic/core/crypto/null_decrypter.h", + "quic/core/crypto/null_encrypter.h", + "quic/core/crypto/p256_key_exchange.h", + "quic/core/crypto/proof_source.h", + "quic/core/crypto/proof_source_x509.h", + "quic/core/crypto/proof_verifier.h", + "quic/core/crypto/quic_client_session_cache.h", + "quic/core/crypto/quic_compressed_certs_cache.h", + "quic/core/crypto/quic_crypter.h", + "quic/core/crypto/quic_crypto_client_config.h", + "quic/core/crypto/quic_crypto_proof.h", + "quic/core/crypto/quic_crypto_server_config.h", + "quic/core/crypto/quic_decrypter.h", + "quic/core/crypto/quic_encrypter.h", + "quic/core/crypto/quic_hkdf.h", + "quic/core/crypto/quic_random.h", + "quic/core/crypto/tls_client_connection.h", + "quic/core/crypto/tls_connection.h", + "quic/core/crypto/tls_server_connection.h", + "quic/core/crypto/transport_parameters.h", + "quic/core/crypto/web_transport_fingerprint_proof_verifier.h", + "quic/core/deterministic_connection_id_generator.h", + "quic/core/frames/quic_ack_frame.h", + "quic/core/frames/quic_ack_frequency_frame.h", + "quic/core/frames/quic_blocked_frame.h", + "quic/core/frames/quic_connection_close_frame.h", + "quic/core/frames/quic_crypto_frame.h", + "quic/core/frames/quic_frame.h", + "quic/core/frames/quic_goaway_frame.h", + "quic/core/frames/quic_handshake_done_frame.h", + "quic/core/frames/quic_inlined_frame.h", + "quic/core/frames/quic_max_streams_frame.h", + "quic/core/frames/quic_message_frame.h", + "quic/core/frames/quic_mtu_discovery_frame.h", + "quic/core/frames/quic_new_connection_id_frame.h", + "quic/core/frames/quic_new_token_frame.h", + "quic/core/frames/quic_padding_frame.h", + "quic/core/frames/quic_path_challenge_frame.h", + "quic/core/frames/quic_path_response_frame.h", + "quic/core/frames/quic_ping_frame.h", + "quic/core/frames/quic_retire_connection_id_frame.h", + "quic/core/frames/quic_rst_stream_frame.h", + "quic/core/frames/quic_stop_sending_frame.h", + "quic/core/frames/quic_stop_waiting_frame.h", + "quic/core/frames/quic_stream_frame.h", + "quic/core/frames/quic_streams_blocked_frame.h", + "quic/core/frames/quic_window_update_frame.h", + "quic/core/handshaker_delegate_interface.h", + "quic/core/http/http_constants.h", + "quic/core/http/http_decoder.h", + "quic/core/http/http_encoder.h", + "quic/core/http/http_frames.h", + "quic/core/http/quic_client_promised_info.h", + "quic/core/http/quic_client_push_promise_index.h", + "quic/core/http/quic_header_list.h", + "quic/core/http/quic_headers_stream.h", + "quic/core/http/quic_receive_control_stream.h", + "quic/core/http/quic_send_control_stream.h", + "quic/core/http/quic_server_initiated_spdy_stream.h", + "quic/core/http/quic_server_session_base.h", + "quic/core/http/quic_spdy_client_session.h", + "quic/core/http/quic_spdy_client_session_base.h", + "quic/core/http/quic_spdy_client_stream.h", + "quic/core/http/quic_spdy_server_stream_base.h", + "quic/core/http/quic_spdy_session.h", + "quic/core/http/quic_spdy_stream.h", + "quic/core/http/quic_spdy_stream_body_manager.h", + "quic/core/http/spdy_server_push_utils.h", + "quic/core/http/spdy_utils.h", + "quic/core/http/web_transport_http3.h", + "quic/core/http/web_transport_stream_adapter.h", + "quic/core/legacy_quic_stream_id_manager.h", + "quic/core/packet_number_indexed_queue.h", + "quic/core/proto/cached_network_parameters_proto.h", + "quic/core/proto/crypto_server_config_proto.h", + "quic/core/proto/source_address_token_proto.h", + "quic/core/qpack/qpack_blocking_manager.h", + "quic/core/qpack/qpack_decoded_headers_accumulator.h", + "quic/core/qpack/qpack_decoder.h", + "quic/core/qpack/qpack_decoder_stream_receiver.h", + "quic/core/qpack/qpack_decoder_stream_sender.h", + "quic/core/qpack/qpack_encoder.h", + "quic/core/qpack/qpack_encoder_stream_receiver.h", + "quic/core/qpack/qpack_encoder_stream_sender.h", + "quic/core/qpack/qpack_header_table.h", + "quic/core/qpack/qpack_index_conversions.h", + "quic/core/qpack/qpack_instruction_decoder.h", + "quic/core/qpack/qpack_instruction_encoder.h", + "quic/core/qpack/qpack_instructions.h", + "quic/core/qpack/qpack_progressive_decoder.h", + "quic/core/qpack/qpack_receive_stream.h", + "quic/core/qpack/qpack_required_insert_count.h", + "quic/core/qpack/qpack_send_stream.h", + "quic/core/qpack/qpack_static_table.h", + "quic/core/qpack/qpack_stream_receiver.h", + "quic/core/qpack/qpack_stream_sender_delegate.h", + "quic/core/qpack/value_splitting_header_list.h", + "quic/core/quic_ack_listener_interface.h", + "quic/core/quic_alarm.h", + "quic/core/quic_alarm_factory.h", + "quic/core/quic_arena_scoped_ptr.h", + "quic/core/quic_bandwidth.h", + "quic/core/quic_blocked_writer_interface.h", + "quic/core/quic_buffered_packet_store.h", + "quic/core/quic_chaos_protector.h", + "quic/core/quic_clock.h", + "quic/core/quic_coalesced_packet.h", + "quic/core/quic_config.h", + "quic/core/quic_connection.h", + "quic/core/quic_connection_context.h", + "quic/core/quic_connection_id.h", + "quic/core/quic_connection_id_manager.h", + "quic/core/quic_connection_stats.h", + "quic/core/quic_constants.h", + "quic/core/quic_control_frame_manager.h", + "quic/core/quic_crypto_client_handshaker.h", + "quic/core/quic_crypto_client_stream.h", + "quic/core/quic_crypto_handshaker.h", + "quic/core/quic_crypto_server_stream.h", + "quic/core/quic_crypto_server_stream_base.h", + "quic/core/quic_crypto_stream.h", + "quic/core/quic_data_reader.h", + "quic/core/quic_data_writer.h", + "quic/core/quic_datagram_queue.h", + "quic/core/quic_default_clock.h", + "quic/core/quic_default_connection_helper.h", + "quic/core/quic_dispatcher.h", + "quic/core/quic_error_codes.h", + "quic/core/quic_flags_list.h", + "quic/core/quic_flow_controller.h", + "quic/core/quic_framer.h", + "quic/core/quic_idle_network_detector.h", + "quic/core/quic_interval.h", + "quic/core/quic_interval_deque.h", + "quic/core/quic_interval_set.h", + "quic/core/quic_lru_cache.h", + "quic/core/quic_mtu_discovery.h", + "quic/core/quic_network_blackhole_detector.h", + "quic/core/quic_one_block_arena.h", + "quic/core/quic_packet_creator.h", + "quic/core/quic_packet_number.h", + "quic/core/quic_packet_writer.h", + "quic/core/quic_packet_writer_wrapper.h", + "quic/core/quic_packets.h", + "quic/core/quic_path_validator.h", + "quic/core/quic_ping_manager.h", + "quic/core/quic_process_packet_interface.h", + "quic/core/quic_protocol_flags_list.h", + "quic/core/quic_received_packet_manager.h", + "quic/core/quic_sent_packet_manager.h", + "quic/core/quic_server_id.h", + "quic/core/quic_session.h", + "quic/core/quic_socket_address_coder.h", + "quic/core/quic_stream.h", + "quic/core/quic_stream_frame_data_producer.h", + "quic/core/quic_stream_id_manager.h", + "quic/core/quic_stream_priority.h", + "quic/core/quic_stream_send_buffer.h", + "quic/core/quic_stream_sequencer.h", + "quic/core/quic_stream_sequencer_buffer.h", + "quic/core/quic_sustained_bandwidth_recorder.h", + "quic/core/quic_tag.h", + "quic/core/quic_time.h", + "quic/core/quic_time_accumulator.h", + "quic/core/quic_time_wait_list_manager.h", + "quic/core/quic_trace_visitor.h", + "quic/core/quic_transmission_info.h", + "quic/core/quic_types.h", + "quic/core/quic_unacked_packet_map.h", + "quic/core/quic_utils.h", + "quic/core/quic_version_manager.h", + "quic/core/quic_versions.h", + "quic/core/quic_write_blocked_list.h", + "quic/core/session_notifier_interface.h", + "quic/core/socket_factory.h", + "quic/core/stream_delegate_interface.h", + "quic/core/tls_chlo_extractor.h", + "quic/core/tls_client_handshaker.h", + "quic/core/tls_handshaker.h", + "quic/core/tls_server_handshaker.h", + "quic/core/uber_quic_stream_id_manager.h", + "quic/core/uber_received_packet_manager.h", + "quic/core/web_transport_interface.h", + "quic/platform/api/quic_bug_tracker.h", + "quic/platform/api/quic_client_stats.h", + "quic/platform/api/quic_export.h", + "quic/platform/api/quic_exported_stats.h", + "quic/platform/api/quic_flag_utils.h", + "quic/platform/api/quic_flags.h", + "quic/platform/api/quic_hostname_utils.h", + "quic/platform/api/quic_ip_address.h", + "quic/platform/api/quic_ip_address_family.h", + "quic/platform/api/quic_logging.h", + "quic/platform/api/quic_mutex.h", + "quic/platform/api/quic_server_stats.h", + "quic/platform/api/quic_socket_address.h", + "quic/platform/api/quic_stack_trace.h", + "quic/platform/api/quic_testvalue.h", + "quic/platform/api/quic_thread.h", + "spdy/core/array_output_buffer.h", + "spdy/core/header_byte_listener_interface.h", + "spdy/core/hpack/hpack_constants.h", + "spdy/core/hpack/hpack_decoder_adapter.h", + "spdy/core/hpack/hpack_encoder.h", + "spdy/core/hpack/hpack_entry.h", + "spdy/core/hpack/hpack_header_table.h", + "spdy/core/hpack/hpack_output_stream.h", + "spdy/core/hpack/hpack_static_table.h", + "spdy/core/http2_frame_decoder_adapter.h", + "spdy/core/http2_header_block.h", + "spdy/core/http2_header_block_hpack_listener.h", + "spdy/core/http2_header_storage.h", + "spdy/core/metadata_extension.h", + "spdy/core/no_op_headers_handler.h", + "spdy/core/recording_headers_handler.h", + "spdy/core/spdy_alt_svc_wire_format.h", + "spdy/core/spdy_bitmasks.h", + "spdy/core/spdy_frame_builder.h", + "spdy/core/spdy_framer.h", + "spdy/core/spdy_headers_handler_interface.h", + "spdy/core/spdy_intrusive_list.h", + "spdy/core/spdy_no_op_visitor.h", + "spdy/core/spdy_pinnable_buffer_piece.h", + "spdy/core/spdy_prefixed_buffer_reader.h", + "spdy/core/spdy_protocol.h", + "spdy/core/spdy_simple_arena.h", + "spdy/core/zero_copy_output_buffer.h", + "web_transport/web_transport.h", +] +quiche_core_srcs = [ + "balsa/balsa_enums.cc", + "balsa/balsa_frame.cc", + "balsa/balsa_headers.cc", + "balsa/header_properties.cc", + "balsa/simple_buffer.cc", + "balsa/standard_header_map.cc", + "common/capsule.cc", + "common/masque/connect_udp_datagram_payload.cc", + "common/platform/api/quiche_hostname_utils.cc", + "common/platform/api/quiche_mutex.cc", + "common/quiche_buffer_allocator.cc", + "common/quiche_crypto_logging.cc", + "common/quiche_data_reader.cc", + "common/quiche_data_writer.cc", + "common/quiche_ip_address.cc", + "common/quiche_ip_address_family.cc", + "common/quiche_mem_slice_storage.cc", + "common/quiche_random.cc", + "common/quiche_text_utils.cc", + "common/simple_buffer_allocator.cc", + "common/structured_headers.cc", + "http2/adapter/event_forwarder.cc", + "http2/adapter/header_validator.cc", + "http2/adapter/http2_protocol.cc", + "http2/adapter/http2_util.cc", + "http2/adapter/noop_header_validator.cc", + "http2/adapter/oghttp2_adapter.cc", + "http2/adapter/oghttp2_session.cc", + "http2/adapter/oghttp2_util.cc", + "http2/adapter/window_manager.cc", + "http2/core/http2_trace_logging.cc", + "http2/decoder/decode_buffer.cc", + "http2/decoder/decode_http2_structures.cc", + "http2/decoder/decode_status.cc", + "http2/decoder/frame_decoder_state.cc", + "http2/decoder/http2_frame_decoder.cc", + "http2/decoder/http2_frame_decoder_listener.cc", + "http2/decoder/http2_structure_decoder.cc", + "http2/decoder/payload_decoders/altsvc_payload_decoder.cc", + "http2/decoder/payload_decoders/continuation_payload_decoder.cc", + "http2/decoder/payload_decoders/data_payload_decoder.cc", + "http2/decoder/payload_decoders/goaway_payload_decoder.cc", + "http2/decoder/payload_decoders/headers_payload_decoder.cc", + "http2/decoder/payload_decoders/ping_payload_decoder.cc", + "http2/decoder/payload_decoders/priority_payload_decoder.cc", + "http2/decoder/payload_decoders/priority_update_payload_decoder.cc", + "http2/decoder/payload_decoders/push_promise_payload_decoder.cc", + "http2/decoder/payload_decoders/rst_stream_payload_decoder.cc", + "http2/decoder/payload_decoders/settings_payload_decoder.cc", + "http2/decoder/payload_decoders/unknown_payload_decoder.cc", + "http2/decoder/payload_decoders/window_update_payload_decoder.cc", + "http2/hpack/decoder/hpack_block_decoder.cc", + "http2/hpack/decoder/hpack_decoder.cc", + "http2/hpack/decoder/hpack_decoder_listener.cc", + "http2/hpack/decoder/hpack_decoder_state.cc", + "http2/hpack/decoder/hpack_decoder_string_buffer.cc", + "http2/hpack/decoder/hpack_decoder_tables.cc", + "http2/hpack/decoder/hpack_decoding_error.cc", + "http2/hpack/decoder/hpack_entry_decoder.cc", + "http2/hpack/decoder/hpack_entry_decoder_listener.cc", + "http2/hpack/decoder/hpack_entry_type_decoder.cc", + "http2/hpack/decoder/hpack_string_decoder.cc", + "http2/hpack/decoder/hpack_string_decoder_listener.cc", + "http2/hpack/decoder/hpack_whole_entry_buffer.cc", + "http2/hpack/decoder/hpack_whole_entry_listener.cc", + "http2/hpack/http2_hpack_constants.cc", + "http2/hpack/huffman/hpack_huffman_decoder.cc", + "http2/hpack/huffman/hpack_huffman_encoder.cc", + "http2/hpack/huffman/huffman_spec_tables.cc", + "http2/hpack/varint/hpack_varint_decoder.cc", + "http2/hpack/varint/hpack_varint_encoder.cc", + "http2/http2_constants.cc", + "http2/http2_structures.cc", + "quic/core/chlo_extractor.cc", + "quic/core/congestion_control/bandwidth_sampler.cc", + "quic/core/congestion_control/bbr2_drain.cc", + "quic/core/congestion_control/bbr2_misc.cc", + "quic/core/congestion_control/bbr2_probe_bw.cc", + "quic/core/congestion_control/bbr2_probe_rtt.cc", + "quic/core/congestion_control/bbr2_sender.cc", + "quic/core/congestion_control/bbr2_startup.cc", + "quic/core/congestion_control/bbr_sender.cc", + "quic/core/congestion_control/cubic_bytes.cc", + "quic/core/congestion_control/general_loss_algorithm.cc", + "quic/core/congestion_control/hybrid_slow_start.cc", + "quic/core/congestion_control/pacing_sender.cc", + "quic/core/congestion_control/prr_sender.cc", + "quic/core/congestion_control/rtt_stats.cc", + "quic/core/congestion_control/send_algorithm_interface.cc", + "quic/core/congestion_control/tcp_cubic_sender_bytes.cc", + "quic/core/congestion_control/uber_loss_algorithm.cc", + "quic/core/crypto/aead_base_decrypter.cc", + "quic/core/crypto/aead_base_encrypter.cc", + "quic/core/crypto/aes_128_gcm_12_decrypter.cc", + "quic/core/crypto/aes_128_gcm_12_encrypter.cc", + "quic/core/crypto/aes_128_gcm_decrypter.cc", + "quic/core/crypto/aes_128_gcm_encrypter.cc", + "quic/core/crypto/aes_256_gcm_decrypter.cc", + "quic/core/crypto/aes_256_gcm_encrypter.cc", + "quic/core/crypto/aes_base_decrypter.cc", + "quic/core/crypto/aes_base_encrypter.cc", + "quic/core/crypto/cert_compressor.cc", + "quic/core/crypto/certificate_util.cc", + "quic/core/crypto/certificate_view.cc", + "quic/core/crypto/chacha20_poly1305_decrypter.cc", + "quic/core/crypto/chacha20_poly1305_encrypter.cc", + "quic/core/crypto/chacha20_poly1305_tls_decrypter.cc", + "quic/core/crypto/chacha20_poly1305_tls_encrypter.cc", + "quic/core/crypto/chacha_base_decrypter.cc", + "quic/core/crypto/chacha_base_encrypter.cc", + "quic/core/crypto/channel_id.cc", + "quic/core/crypto/client_proof_source.cc", + "quic/core/crypto/crypto_framer.cc", + "quic/core/crypto/crypto_handshake.cc", + "quic/core/crypto/crypto_handshake_message.cc", + "quic/core/crypto/crypto_secret_boxer.cc", + "quic/core/crypto/crypto_utils.cc", + "quic/core/crypto/curve25519_key_exchange.cc", + "quic/core/crypto/key_exchange.cc", + "quic/core/crypto/null_decrypter.cc", + "quic/core/crypto/null_encrypter.cc", + "quic/core/crypto/p256_key_exchange.cc", + "quic/core/crypto/proof_source.cc", + "quic/core/crypto/proof_source_x509.cc", + "quic/core/crypto/quic_client_session_cache.cc", + "quic/core/crypto/quic_compressed_certs_cache.cc", + "quic/core/crypto/quic_crypter.cc", + "quic/core/crypto/quic_crypto_client_config.cc", + "quic/core/crypto/quic_crypto_proof.cc", + "quic/core/crypto/quic_crypto_server_config.cc", + "quic/core/crypto/quic_decrypter.cc", + "quic/core/crypto/quic_encrypter.cc", + "quic/core/crypto/quic_hkdf.cc", + "quic/core/crypto/tls_client_connection.cc", + "quic/core/crypto/tls_connection.cc", + "quic/core/crypto/tls_server_connection.cc", + "quic/core/crypto/transport_parameters.cc", + "quic/core/crypto/web_transport_fingerprint_proof_verifier.cc", + "quic/core/deterministic_connection_id_generator.cc", + "quic/core/frames/quic_ack_frame.cc", + "quic/core/frames/quic_ack_frequency_frame.cc", + "quic/core/frames/quic_blocked_frame.cc", + "quic/core/frames/quic_connection_close_frame.cc", + "quic/core/frames/quic_crypto_frame.cc", + "quic/core/frames/quic_frame.cc", + "quic/core/frames/quic_goaway_frame.cc", + "quic/core/frames/quic_handshake_done_frame.cc", + "quic/core/frames/quic_max_streams_frame.cc", + "quic/core/frames/quic_message_frame.cc", + "quic/core/frames/quic_new_connection_id_frame.cc", + "quic/core/frames/quic_new_token_frame.cc", + "quic/core/frames/quic_padding_frame.cc", + "quic/core/frames/quic_path_challenge_frame.cc", + "quic/core/frames/quic_path_response_frame.cc", + "quic/core/frames/quic_ping_frame.cc", + "quic/core/frames/quic_retire_connection_id_frame.cc", + "quic/core/frames/quic_rst_stream_frame.cc", + "quic/core/frames/quic_stop_sending_frame.cc", + "quic/core/frames/quic_stop_waiting_frame.cc", + "quic/core/frames/quic_stream_frame.cc", + "quic/core/frames/quic_streams_blocked_frame.cc", + "quic/core/frames/quic_window_update_frame.cc", + "quic/core/http/http_constants.cc", + "quic/core/http/http_decoder.cc", + "quic/core/http/http_encoder.cc", + "quic/core/http/quic_client_promised_info.cc", + "quic/core/http/quic_client_push_promise_index.cc", + "quic/core/http/quic_header_list.cc", + "quic/core/http/quic_headers_stream.cc", + "quic/core/http/quic_receive_control_stream.cc", + "quic/core/http/quic_send_control_stream.cc", + "quic/core/http/quic_server_initiated_spdy_stream.cc", + "quic/core/http/quic_server_session_base.cc", + "quic/core/http/quic_spdy_client_session.cc", + "quic/core/http/quic_spdy_client_session_base.cc", + "quic/core/http/quic_spdy_client_stream.cc", + "quic/core/http/quic_spdy_server_stream_base.cc", + "quic/core/http/quic_spdy_session.cc", + "quic/core/http/quic_spdy_stream.cc", + "quic/core/http/quic_spdy_stream_body_manager.cc", + "quic/core/http/spdy_server_push_utils.cc", + "quic/core/http/spdy_utils.cc", + "quic/core/http/web_transport_http3.cc", + "quic/core/http/web_transport_stream_adapter.cc", + "quic/core/legacy_quic_stream_id_manager.cc", + "quic/core/qpack/qpack_blocking_manager.cc", + "quic/core/qpack/qpack_decoded_headers_accumulator.cc", + "quic/core/qpack/qpack_decoder.cc", + "quic/core/qpack/qpack_decoder_stream_receiver.cc", + "quic/core/qpack/qpack_decoder_stream_sender.cc", + "quic/core/qpack/qpack_encoder.cc", + "quic/core/qpack/qpack_encoder_stream_receiver.cc", + "quic/core/qpack/qpack_encoder_stream_sender.cc", + "quic/core/qpack/qpack_header_table.cc", + "quic/core/qpack/qpack_index_conversions.cc", + "quic/core/qpack/qpack_instruction_decoder.cc", + "quic/core/qpack/qpack_instruction_encoder.cc", + "quic/core/qpack/qpack_instructions.cc", + "quic/core/qpack/qpack_progressive_decoder.cc", + "quic/core/qpack/qpack_receive_stream.cc", + "quic/core/qpack/qpack_required_insert_count.cc", + "quic/core/qpack/qpack_send_stream.cc", + "quic/core/qpack/qpack_static_table.cc", + "quic/core/qpack/value_splitting_header_list.cc", + "quic/core/quic_ack_listener_interface.cc", + "quic/core/quic_alarm.cc", + "quic/core/quic_bandwidth.cc", + "quic/core/quic_buffered_packet_store.cc", + "quic/core/quic_chaos_protector.cc", + "quic/core/quic_coalesced_packet.cc", + "quic/core/quic_config.cc", + "quic/core/quic_connection.cc", + "quic/core/quic_connection_context.cc", + "quic/core/quic_connection_id.cc", + "quic/core/quic_connection_id_manager.cc", + "quic/core/quic_connection_stats.cc", + "quic/core/quic_constants.cc", + "quic/core/quic_control_frame_manager.cc", + "quic/core/quic_crypto_client_handshaker.cc", + "quic/core/quic_crypto_client_stream.cc", + "quic/core/quic_crypto_handshaker.cc", + "quic/core/quic_crypto_server_stream.cc", + "quic/core/quic_crypto_server_stream_base.cc", + "quic/core/quic_crypto_stream.cc", + "quic/core/quic_data_reader.cc", + "quic/core/quic_data_writer.cc", + "quic/core/quic_datagram_queue.cc", + "quic/core/quic_default_clock.cc", + "quic/core/quic_dispatcher.cc", + "quic/core/quic_error_codes.cc", + "quic/core/quic_flow_controller.cc", + "quic/core/quic_framer.cc", + "quic/core/quic_idle_network_detector.cc", + "quic/core/quic_mtu_discovery.cc", + "quic/core/quic_network_blackhole_detector.cc", + "quic/core/quic_packet_creator.cc", + "quic/core/quic_packet_number.cc", + "quic/core/quic_packet_writer_wrapper.cc", + "quic/core/quic_packets.cc", + "quic/core/quic_path_validator.cc", + "quic/core/quic_ping_manager.cc", + "quic/core/quic_received_packet_manager.cc", + "quic/core/quic_sent_packet_manager.cc", + "quic/core/quic_server_id.cc", + "quic/core/quic_session.cc", + "quic/core/quic_socket_address_coder.cc", + "quic/core/quic_stream.cc", + "quic/core/quic_stream_id_manager.cc", + "quic/core/quic_stream_priority.cc", + "quic/core/quic_stream_send_buffer.cc", + "quic/core/quic_stream_sequencer.cc", + "quic/core/quic_stream_sequencer_buffer.cc", + "quic/core/quic_sustained_bandwidth_recorder.cc", + "quic/core/quic_tag.cc", + "quic/core/quic_time.cc", + "quic/core/quic_time_wait_list_manager.cc", + "quic/core/quic_trace_visitor.cc", + "quic/core/quic_transmission_info.cc", + "quic/core/quic_types.cc", + "quic/core/quic_unacked_packet_map.cc", + "quic/core/quic_utils.cc", + "quic/core/quic_version_manager.cc", + "quic/core/quic_versions.cc", + "quic/core/quic_write_blocked_list.cc", + "quic/core/tls_chlo_extractor.cc", + "quic/core/tls_client_handshaker.cc", + "quic/core/tls_handshaker.cc", + "quic/core/tls_server_handshaker.cc", + "quic/core/uber_quic_stream_id_manager.cc", + "quic/core/uber_received_packet_manager.cc", + "quic/platform/api/quic_socket_address.cc", + "spdy/core/array_output_buffer.cc", + "spdy/core/hpack/hpack_constants.cc", + "spdy/core/hpack/hpack_decoder_adapter.cc", + "spdy/core/hpack/hpack_encoder.cc", + "spdy/core/hpack/hpack_entry.cc", + "spdy/core/hpack/hpack_header_table.cc", + "spdy/core/hpack/hpack_output_stream.cc", + "spdy/core/hpack/hpack_static_table.cc", + "spdy/core/http2_frame_decoder_adapter.cc", + "spdy/core/http2_header_block.cc", + "spdy/core/http2_header_storage.cc", + "spdy/core/metadata_extension.cc", + "spdy/core/recording_headers_handler.cc", + "spdy/core/spdy_alt_svc_wire_format.cc", + "spdy/core/spdy_frame_builder.cc", + "spdy/core/spdy_framer.cc", + "spdy/core/spdy_no_op_visitor.cc", + "spdy/core/spdy_pinnable_buffer_piece.cc", + "spdy/core/spdy_prefixed_buffer_reader.cc", + "spdy/core/spdy_protocol.cc", + "spdy/core/spdy_simple_arena.cc", +] +quiche_tool_support_hdrs = [ + "common/platform/api/quiche_command_line_flags.h", + "common/platform/api/quiche_default_proof_providers.h", + "common/platform/api/quiche_file_utils.h", + "common/platform/api/quiche_system_event_loop.h", + "quic/platform/api/quic_default_proof_providers.h", + "quic/tools/connect_server_backend.h", + "quic/tools/connect_tunnel.h", + "quic/tools/connect_udp_tunnel.h", + "quic/tools/fake_proof_verifier.h", + "quic/tools/quic_backend_response.h", + "quic/tools/quic_client_base.h", + "quic/tools/quic_memory_cache_backend.h", + "quic/tools/quic_name_lookup.h", + "quic/tools/quic_simple_client_session.h", + "quic/tools/quic_simple_client_stream.h", + "quic/tools/quic_simple_crypto_server_stream_helper.h", + "quic/tools/quic_simple_dispatcher.h", + "quic/tools/quic_simple_server_backend.h", + "quic/tools/quic_simple_server_session.h", + "quic/tools/quic_simple_server_stream.h", + "quic/tools/quic_spdy_client_base.h", + "quic/tools/quic_spdy_server_base.h", + "quic/tools/quic_tcp_like_trace_converter.h", + "quic/tools/quic_url.h", + "quic/tools/simple_ticket_crypter.h", + "quic/tools/web_transport_test_visitors.h", +] +quiche_tool_support_srcs = [ + "common/platform/api/quiche_file_utils.cc", + "quic/tools/connect_server_backend.cc", + "quic/tools/connect_tunnel.cc", + "quic/tools/connect_udp_tunnel.cc", + "quic/tools/quic_backend_response.cc", + "quic/tools/quic_client_base.cc", + "quic/tools/quic_memory_cache_backend.cc", + "quic/tools/quic_name_lookup.cc", + "quic/tools/quic_simple_client_session.cc", + "quic/tools/quic_simple_client_stream.cc", + "quic/tools/quic_simple_crypto_server_stream_helper.cc", + "quic/tools/quic_simple_dispatcher.cc", + "quic/tools/quic_simple_server_session.cc", + "quic/tools/quic_simple_server_stream.cc", + "quic/tools/quic_spdy_client_base.cc", + "quic/tools/quic_tcp_like_trace_converter.cc", + "quic/tools/quic_url.cc", + "quic/tools/simple_ticket_crypter.cc", +] +quiche_test_support_hdrs = [ + "common/platform/api/quiche_expect_bug.h", + "common/platform/api/quiche_test.h", + "common/platform/api/quiche_test_loopback.h", + "common/platform/api/quiche_test_output.h", + "common/test_tools/quiche_test_utils.h", + "http2/adapter/mock_http2_visitor.h", + "http2/adapter/recording_http2_visitor.h", + "http2/adapter/test_frame_sequence.h", + "http2/adapter/test_utils.h", + "http2/test_tools/frame_decoder_state_test_util.h", + "http2/test_tools/frame_parts.h", + "http2/test_tools/frame_parts_collector.h", + "http2/test_tools/frame_parts_collector_listener.h", + "http2/test_tools/hpack_block_builder.h", + "http2/test_tools/hpack_block_collector.h", + "http2/test_tools/hpack_entry_collector.h", + "http2/test_tools/hpack_example.h", + "http2/test_tools/hpack_string_collector.h", + "http2/test_tools/http2_constants_test_util.h", + "http2/test_tools/http2_frame_builder.h", + "http2/test_tools/http2_frame_decoder_listener_test_util.h", + "http2/test_tools/http2_random.h", + "http2/test_tools/http2_structure_decoder_test_util.h", + "http2/test_tools/http2_structures_test_util.h", + "http2/test_tools/payload_decoder_base_test_util.h", + "http2/test_tools/random_decoder_test_base.h", + "http2/test_tools/random_util.h", + "http2/test_tools/verify_macros.h", + "quic/platform/api/quic_expect_bug.h", + "quic/platform/api/quic_test.h", + "quic/platform/api/quic_test_loopback.h", + "quic/platform/api/quic_test_output.h", + "quic/test_tools/bad_packet_writer.h", + "quic/test_tools/crypto_test_utils.h", + "quic/test_tools/failing_proof_source.h", + "quic/test_tools/fake_proof_source.h", + "quic/test_tools/fake_proof_source_handle.h", + "quic/test_tools/first_flight.h", + "quic/test_tools/limited_mtu_test_writer.h", + "quic/test_tools/mock_clock.h", + "quic/test_tools/mock_connection_id_generator.h", + "quic/test_tools/mock_quic_client_promised_info.h", + "quic/test_tools/mock_quic_dispatcher.h", + "quic/test_tools/mock_quic_session_visitor.h", + "quic/test_tools/mock_quic_spdy_client_stream.h", + "quic/test_tools/mock_quic_time_wait_list_manager.h", + "quic/test_tools/mock_random.h", + "quic/test_tools/packet_dropping_test_writer.h", + "quic/test_tools/packet_reordering_writer.h", + "quic/test_tools/qpack/qpack_decoder_test_utils.h", + "quic/test_tools/qpack/qpack_encoder_peer.h", + "quic/test_tools/qpack/qpack_offline_decoder.h", + "quic/test_tools/qpack/qpack_test_utils.h", + "quic/test_tools/quic_buffered_packet_store_peer.h", + "quic/test_tools/quic_client_promised_info_peer.h", + "quic/test_tools/quic_client_session_cache_peer.h", + "quic/test_tools/quic_coalesced_packet_peer.h", + "quic/test_tools/quic_config_peer.h", + "quic/test_tools/quic_connection_id_manager_peer.h", + "quic/test_tools/quic_connection_peer.h", + "quic/test_tools/quic_crypto_server_config_peer.h", + "quic/test_tools/quic_dispatcher_peer.h", + "quic/test_tools/quic_flow_controller_peer.h", + "quic/test_tools/quic_framer_peer.h", + "quic/test_tools/quic_interval_deque_peer.h", + "quic/test_tools/quic_packet_creator_peer.h", + "quic/test_tools/quic_path_validator_peer.h", + "quic/test_tools/quic_sent_packet_manager_peer.h", + "quic/test_tools/quic_server_session_base_peer.h", + "quic/test_tools/quic_session_peer.h", + "quic/test_tools/quic_spdy_session_peer.h", + "quic/test_tools/quic_spdy_stream_peer.h", + "quic/test_tools/quic_stream_id_manager_peer.h", + "quic/test_tools/quic_stream_peer.h", + "quic/test_tools/quic_stream_send_buffer_peer.h", + "quic/test_tools/quic_stream_sequencer_buffer_peer.h", + "quic/test_tools/quic_stream_sequencer_peer.h", + "quic/test_tools/quic_sustained_bandwidth_recorder_peer.h", + "quic/test_tools/quic_test_backend.h", + "quic/test_tools/quic_test_utils.h", + "quic/test_tools/quic_time_wait_list_manager_peer.h", + "quic/test_tools/quic_unacked_packet_map_peer.h", + "quic/test_tools/rtt_stats_peer.h", + "quic/test_tools/send_algorithm_test_utils.h", + "quic/test_tools/simple_data_producer.h", + "quic/test_tools/simple_quic_framer.h", + "quic/test_tools/simple_session_cache.h", + "quic/test_tools/simple_session_notifier.h", + "quic/test_tools/simulator/actor.h", + "quic/test_tools/simulator/alarm_factory.h", + "quic/test_tools/simulator/link.h", + "quic/test_tools/simulator/packet_filter.h", + "quic/test_tools/simulator/port.h", + "quic/test_tools/simulator/queue.h", + "quic/test_tools/simulator/quic_endpoint.h", + "quic/test_tools/simulator/quic_endpoint_base.h", + "quic/test_tools/simulator/simulator.h", + "quic/test_tools/simulator/switch.h", + "quic/test_tools/simulator/test_harness.h", + "quic/test_tools/simulator/traffic_policer.h", + "quic/test_tools/test_certificates.h", + "quic/test_tools/test_ticket_crypter.h", + "quic/test_tools/web_transport_resets_backend.h", + "quic/test_tools/web_transport_test_tools.h", + "spdy/test_tools/mock_spdy_framer_visitor.h", + "spdy/test_tools/spdy_test_utils.h", + "web_transport/test_tools/mock_web_transport.h", +] +quiche_test_support_srcs = [ + "common/platform/api/quiche_test_loopback.cc", + "common/test_tools/quiche_test_utils.cc", + "http2/adapter/recording_http2_visitor.cc", + "http2/adapter/test_frame_sequence.cc", + "http2/adapter/test_utils.cc", + "http2/test_tools/frame_decoder_state_test_util.cc", + "http2/test_tools/frame_parts.cc", + "http2/test_tools/frame_parts_collector.cc", + "http2/test_tools/frame_parts_collector_listener.cc", + "http2/test_tools/hpack_block_builder.cc", + "http2/test_tools/hpack_block_collector.cc", + "http2/test_tools/hpack_entry_collector.cc", + "http2/test_tools/hpack_example.cc", + "http2/test_tools/hpack_string_collector.cc", + "http2/test_tools/http2_constants_test_util.cc", + "http2/test_tools/http2_frame_builder.cc", + "http2/test_tools/http2_frame_decoder_listener_test_util.cc", + "http2/test_tools/http2_random.cc", + "http2/test_tools/http2_structure_decoder_test_util.cc", + "http2/test_tools/http2_structures_test_util.cc", + "http2/test_tools/payload_decoder_base_test_util.cc", + "http2/test_tools/random_decoder_test_base.cc", + "http2/test_tools/random_util.cc", + "quic/test_tools/bad_packet_writer.cc", + "quic/test_tools/crypto_test_utils.cc", + "quic/test_tools/failing_proof_source.cc", + "quic/test_tools/fake_proof_source.cc", + "quic/test_tools/fake_proof_source_handle.cc", + "quic/test_tools/first_flight.cc", + "quic/test_tools/limited_mtu_test_writer.cc", + "quic/test_tools/mock_clock.cc", + "quic/test_tools/mock_quic_client_promised_info.cc", + "quic/test_tools/mock_quic_dispatcher.cc", + "quic/test_tools/mock_quic_session_visitor.cc", + "quic/test_tools/mock_quic_spdy_client_stream.cc", + "quic/test_tools/mock_quic_time_wait_list_manager.cc", + "quic/test_tools/mock_random.cc", + "quic/test_tools/packet_dropping_test_writer.cc", + "quic/test_tools/packet_reordering_writer.cc", + "quic/test_tools/qpack/qpack_decoder_test_utils.cc", + "quic/test_tools/qpack/qpack_encoder_peer.cc", + "quic/test_tools/qpack/qpack_offline_decoder.cc", + "quic/test_tools/qpack/qpack_test_utils.cc", + "quic/test_tools/quic_buffered_packet_store_peer.cc", + "quic/test_tools/quic_client_promised_info_peer.cc", + "quic/test_tools/quic_coalesced_packet_peer.cc", + "quic/test_tools/quic_config_peer.cc", + "quic/test_tools/quic_connection_peer.cc", + "quic/test_tools/quic_crypto_server_config_peer.cc", + "quic/test_tools/quic_dispatcher_peer.cc", + "quic/test_tools/quic_flow_controller_peer.cc", + "quic/test_tools/quic_framer_peer.cc", + "quic/test_tools/quic_packet_creator_peer.cc", + "quic/test_tools/quic_path_validator_peer.cc", + "quic/test_tools/quic_sent_packet_manager_peer.cc", + "quic/test_tools/quic_session_peer.cc", + "quic/test_tools/quic_spdy_session_peer.cc", + "quic/test_tools/quic_spdy_stream_peer.cc", + "quic/test_tools/quic_stream_id_manager_peer.cc", + "quic/test_tools/quic_stream_peer.cc", + "quic/test_tools/quic_stream_send_buffer_peer.cc", + "quic/test_tools/quic_stream_sequencer_buffer_peer.cc", + "quic/test_tools/quic_stream_sequencer_peer.cc", + "quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc", + "quic/test_tools/quic_test_backend.cc", + "quic/test_tools/quic_test_utils.cc", + "quic/test_tools/quic_time_wait_list_manager_peer.cc", + "quic/test_tools/quic_unacked_packet_map_peer.cc", + "quic/test_tools/rtt_stats_peer.cc", + "quic/test_tools/send_algorithm_test_utils.cc", + "quic/test_tools/simple_data_producer.cc", + "quic/test_tools/simple_quic_framer.cc", + "quic/test_tools/simple_session_cache.cc", + "quic/test_tools/simple_session_notifier.cc", + "quic/test_tools/simulator/actor.cc", + "quic/test_tools/simulator/alarm_factory.cc", + "quic/test_tools/simulator/link.cc", + "quic/test_tools/simulator/packet_filter.cc", + "quic/test_tools/simulator/port.cc", + "quic/test_tools/simulator/queue.cc", + "quic/test_tools/simulator/quic_endpoint.cc", + "quic/test_tools/simulator/quic_endpoint_base.cc", + "quic/test_tools/simulator/simulator.cc", + "quic/test_tools/simulator/switch.cc", + "quic/test_tools/simulator/test_harness.cc", + "quic/test_tools/simulator/traffic_policer.cc", + "quic/test_tools/test_certificates.cc", + "quic/test_tools/test_ticket_crypter.cc", + "quic/test_tools/web_transport_resets_backend.cc", + "spdy/test_tools/mock_spdy_framer_visitor.cc", + "spdy/test_tools/spdy_test_utils.cc", +] +io_tool_support_hdrs = [ + "common/platform/api/quiche_event_loop.h", + "common/platform/api/quiche_udp_socket_platform_api.h", + "quic/core/io/event_loop_connecting_client_socket.h", + "quic/core/io/event_loop_socket_factory.h", + "quic/core/io/quic_default_event_loop.h", + "quic/core/io/quic_event_loop.h", + "quic/core/io/quic_poll_event_loop.h", + "quic/core/io/socket.h", + "quic/core/quic_default_packet_writer.h", + "quic/core/quic_packet_reader.h", + "quic/core/quic_syscall_wrapper.h", + "quic/core/quic_udp_socket.h", + "quic/masque/masque_client.h", + "quic/masque/masque_client_session.h", + "quic/masque/masque_client_tools.h", + "quic/masque/masque_dispatcher.h", + "quic/masque/masque_encapsulated_client.h", + "quic/masque/masque_encapsulated_client_session.h", + "quic/masque/masque_server.h", + "quic/masque/masque_server_backend.h", + "quic/masque/masque_server_session.h", + "quic/masque/masque_utils.h", + "quic/platform/api/quic_udp_socket_platform_api.h", + "quic/tools/quic_client_default_network_helper.h", + "quic/tools/quic_client_factory.h", + "quic/tools/quic_default_client.h", + "quic/tools/quic_epoll_client_factory.h", + "quic/tools/quic_server.h", +] +io_tool_support_srcs = [ + "quic/core/io/event_loop_connecting_client_socket.cc", + "quic/core/io/event_loop_socket_factory.cc", + "quic/core/io/quic_default_event_loop.cc", + "quic/core/io/quic_poll_event_loop.cc", + "quic/core/io/socket_posix.cc", + "quic/core/quic_default_packet_writer.cc", + "quic/core/quic_packet_reader.cc", + "quic/core/quic_syscall_wrapper.cc", + "quic/core/quic_udp_socket_posix.cc", + "quic/masque/masque_client.cc", + "quic/masque/masque_client_session.cc", + "quic/masque/masque_client_tools.cc", + "quic/masque/masque_dispatcher.cc", + "quic/masque/masque_encapsulated_client.cc", + "quic/masque/masque_encapsulated_client_session.cc", + "quic/masque/masque_server.cc", + "quic/masque/masque_server_backend.cc", + "quic/masque/masque_server_session.cc", + "quic/masque/masque_utils.cc", + "quic/tools/quic_client_default_network_helper.cc", + "quic/tools/quic_default_client.cc", + "quic/tools/quic_epoll_client_factory.cc", + "quic/tools/quic_server.cc", +] +io_test_support_hdrs = [ + "quic/test_tools/quic_mock_syscall_wrapper.h", + "quic/test_tools/quic_server_peer.h", + "quic/test_tools/quic_test_client.h", + "quic/test_tools/quic_test_server.h", + "quic/test_tools/server_thread.h", +] +io_test_support_srcs = [ + "quic/test_tools/quic_mock_syscall_wrapper.cc", + "quic/test_tools/quic_server_peer.cc", + "quic/test_tools/quic_test_client.cc", + "quic/test_tools/quic_test_server.cc", + "quic/test_tools/server_thread.cc", +] +quiche_tests_hdrs = [ + +] +quiche_tests_srcs = [ + "balsa/balsa_frame_test.cc", + "balsa/balsa_headers_test.cc", + "balsa/header_properties_test.cc", + "balsa/simple_buffer_test.cc", + "binary_http/binary_http_message_test.cc", + "common/btree_scheduler_test.cc", + "common/capsule_test.cc", + "common/masque/connect_udp_datagram_payload_test.cc", + "common/platform/api/quiche_file_utils_test.cc", + "common/platform/api/quiche_hostname_utils_test.cc", + "common/platform/api/quiche_lower_case_string_test.cc", + "common/platform/api/quiche_mem_slice_test.cc", + "common/platform/api/quiche_reference_counted_test.cc", + "common/platform/api/quiche_stack_trace_test.cc", + "common/platform/api/quiche_time_utils_test.cc", + "common/platform/api/quiche_url_utils_test.cc", + "common/print_elements_test.cc", + "common/quiche_buffer_allocator_test.cc", + "common/quiche_circular_deque_test.cc", + "common/quiche_data_reader_test.cc", + "common/quiche_data_writer_test.cc", + "common/quiche_endian_test.cc", + "common/quiche_ip_address_test.cc", + "common/quiche_linked_hash_map_test.cc", + "common/quiche_mem_slice_storage_test.cc", + "common/quiche_random_test.cc", + "common/quiche_text_utils_test.cc", + "common/simple_buffer_allocator_test.cc", + "common/structured_headers_generated_test.cc", + "common/structured_headers_test.cc", + "common/test_tools/quiche_test_utils_test.cc", + "common/wire_serialization_test.cc", + "http2/adapter/event_forwarder_test.cc", + "http2/adapter/header_validator_test.cc", + "http2/adapter/noop_header_validator_test.cc", + "http2/adapter/oghttp2_adapter_test.cc", + "http2/adapter/oghttp2_session_test.cc", + "http2/adapter/oghttp2_util_test.cc", + "http2/adapter/recording_http2_visitor_test.cc", + "http2/adapter/test_utils_test.cc", + "http2/adapter/window_manager_test.cc", + "http2/core/priority_write_scheduler_test.cc", + "http2/decoder/decode_buffer_test.cc", + "http2/decoder/decode_http2_structures_test.cc", + "http2/decoder/http2_frame_decoder_test.cc", + "http2/decoder/http2_structure_decoder_test.cc", + "http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc", + "http2/decoder/payload_decoders/continuation_payload_decoder_test.cc", + "http2/decoder/payload_decoders/data_payload_decoder_test.cc", + "http2/decoder/payload_decoders/goaway_payload_decoder_test.cc", + "http2/decoder/payload_decoders/headers_payload_decoder_test.cc", + "http2/decoder/payload_decoders/ping_payload_decoder_test.cc", + "http2/decoder/payload_decoders/priority_payload_decoder_test.cc", + "http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc", + "http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc", + "http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc", + "http2/decoder/payload_decoders/settings_payload_decoder_test.cc", + "http2/decoder/payload_decoders/unknown_payload_decoder_test.cc", + "http2/decoder/payload_decoders/window_update_payload_decoder_test.cc", + "http2/hpack/decoder/hpack_block_collector_test.cc", + "http2/hpack/decoder/hpack_block_decoder_test.cc", + "http2/hpack/decoder/hpack_decoder_state_test.cc", + "http2/hpack/decoder/hpack_decoder_string_buffer_test.cc", + "http2/hpack/decoder/hpack_decoder_tables_test.cc", + "http2/hpack/decoder/hpack_decoder_test.cc", + "http2/hpack/decoder/hpack_entry_collector_test.cc", + "http2/hpack/decoder/hpack_entry_decoder_test.cc", + "http2/hpack/decoder/hpack_entry_type_decoder_test.cc", + "http2/hpack/decoder/hpack_string_decoder_test.cc", + "http2/hpack/decoder/hpack_whole_entry_buffer_test.cc", + "http2/hpack/http2_hpack_constants_test.cc", + "http2/hpack/huffman/hpack_huffman_decoder_test.cc", + "http2/hpack/huffman/hpack_huffman_encoder_test.cc", + "http2/hpack/huffman/hpack_huffman_transcoder_test.cc", + "http2/hpack/varint/hpack_varint_decoder_test.cc", + "http2/hpack/varint/hpack_varint_encoder_test.cc", + "http2/hpack/varint/hpack_varint_round_trip_test.cc", + "http2/http2_constants_test.cc", + "http2/http2_structures_test.cc", + "http2/test_tools/hpack_block_builder_test.cc", + "http2/test_tools/hpack_example_test.cc", + "http2/test_tools/http2_frame_builder_test.cc", + "http2/test_tools/http2_random_test.cc", + "http2/test_tools/random_decoder_test_base_test.cc", + "oblivious_http/buffers/oblivious_http_integration_test.cc", + "oblivious_http/buffers/oblivious_http_request_test.cc", + "oblivious_http/buffers/oblivious_http_response_test.cc", + "oblivious_http/common/oblivious_http_header_key_config_test.cc", + "oblivious_http/oblivious_http_client_test.cc", + "oblivious_http/oblivious_http_gateway_test.cc", + "quic/core/congestion_control/bandwidth_sampler_test.cc", + "quic/core/congestion_control/bbr2_simulator_test.cc", + "quic/core/congestion_control/bbr_sender_test.cc", + "quic/core/congestion_control/cubic_bytes_test.cc", + "quic/core/congestion_control/general_loss_algorithm_test.cc", + "quic/core/congestion_control/hybrid_slow_start_test.cc", + "quic/core/congestion_control/pacing_sender_test.cc", + "quic/core/congestion_control/prr_sender_test.cc", + "quic/core/congestion_control/rtt_stats_test.cc", + "quic/core/congestion_control/send_algorithm_test.cc", + "quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc", + "quic/core/congestion_control/uber_loss_algorithm_test.cc", + "quic/core/congestion_control/windowed_filter_test.cc", + "quic/core/crypto/aes_128_gcm_12_decrypter_test.cc", + "quic/core/crypto/aes_128_gcm_12_encrypter_test.cc", + "quic/core/crypto/aes_128_gcm_decrypter_test.cc", + "quic/core/crypto/aes_128_gcm_encrypter_test.cc", + "quic/core/crypto/aes_256_gcm_decrypter_test.cc", + "quic/core/crypto/aes_256_gcm_encrypter_test.cc", + "quic/core/crypto/cert_compressor_test.cc", + "quic/core/crypto/certificate_util_test.cc", + "quic/core/crypto/certificate_view_test.cc", + "quic/core/crypto/chacha20_poly1305_decrypter_test.cc", + "quic/core/crypto/chacha20_poly1305_encrypter_test.cc", + "quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc", + "quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc", + "quic/core/crypto/channel_id_test.cc", + "quic/core/crypto/client_proof_source_test.cc", + "quic/core/crypto/crypto_framer_test.cc", + "quic/core/crypto/crypto_handshake_message_test.cc", + "quic/core/crypto/crypto_secret_boxer_test.cc", + "quic/core/crypto/crypto_server_test.cc", + "quic/core/crypto/crypto_utils_test.cc", + "quic/core/crypto/curve25519_key_exchange_test.cc", + "quic/core/crypto/null_decrypter_test.cc", + "quic/core/crypto/null_encrypter_test.cc", + "quic/core/crypto/p256_key_exchange_test.cc", + "quic/core/crypto/proof_source_x509_test.cc", + "quic/core/crypto/quic_client_session_cache_test.cc", + "quic/core/crypto/quic_compressed_certs_cache_test.cc", + "quic/core/crypto/quic_crypto_client_config_test.cc", + "quic/core/crypto/quic_crypto_server_config_test.cc", + "quic/core/crypto/quic_hkdf_test.cc", + "quic/core/crypto/transport_parameters_test.cc", + "quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc", + "quic/core/deterministic_connection_id_generator_test.cc", + "quic/core/frames/quic_frames_test.cc", + "quic/core/http/http_decoder_test.cc", + "quic/core/http/http_encoder_test.cc", + "quic/core/http/http_frames_test.cc", + "quic/core/http/quic_client_promised_info_test.cc", + "quic/core/http/quic_client_push_promise_index_test.cc", + "quic/core/http/quic_header_list_test.cc", + "quic/core/http/quic_headers_stream_test.cc", + "quic/core/http/quic_receive_control_stream_test.cc", + "quic/core/http/quic_send_control_stream_test.cc", + "quic/core/http/quic_server_session_base_test.cc", + "quic/core/http/quic_spdy_session_test.cc", + "quic/core/http/quic_spdy_stream_body_manager_test.cc", + "quic/core/http/quic_spdy_stream_test.cc", + "quic/core/http/spdy_server_push_utils_test.cc", + "quic/core/http/spdy_utils_test.cc", + "quic/core/http/web_transport_http3_test.cc", + "quic/core/legacy_quic_stream_id_manager_test.cc", + "quic/core/packet_number_indexed_queue_test.cc", + "quic/core/qpack/qpack_blocking_manager_test.cc", + "quic/core/qpack/qpack_decoded_headers_accumulator_test.cc", + "quic/core/qpack/qpack_decoder_stream_receiver_test.cc", + "quic/core/qpack/qpack_decoder_stream_sender_test.cc", + "quic/core/qpack/qpack_decoder_test.cc", + "quic/core/qpack/qpack_encoder_stream_receiver_test.cc", + "quic/core/qpack/qpack_encoder_stream_sender_test.cc", + "quic/core/qpack/qpack_encoder_test.cc", + "quic/core/qpack/qpack_header_table_test.cc", + "quic/core/qpack/qpack_index_conversions_test.cc", + "quic/core/qpack/qpack_instruction_decoder_test.cc", + "quic/core/qpack/qpack_instruction_encoder_test.cc", + "quic/core/qpack/qpack_receive_stream_test.cc", + "quic/core/qpack/qpack_required_insert_count_test.cc", + "quic/core/qpack/qpack_round_trip_test.cc", + "quic/core/qpack/qpack_send_stream_test.cc", + "quic/core/qpack/qpack_static_table_test.cc", + "quic/core/qpack/value_splitting_header_list_test.cc", + "quic/core/quic_alarm_test.cc", + "quic/core/quic_arena_scoped_ptr_test.cc", + "quic/core/quic_bandwidth_test.cc", + "quic/core/quic_buffered_packet_store_test.cc", + "quic/core/quic_chaos_protector_test.cc", + "quic/core/quic_coalesced_packet_test.cc", + "quic/core/quic_config_test.cc", + "quic/core/quic_connection_context_test.cc", + "quic/core/quic_connection_id_manager_test.cc", + "quic/core/quic_connection_id_test.cc", + "quic/core/quic_connection_test.cc", + "quic/core/quic_control_frame_manager_test.cc", + "quic/core/quic_crypto_client_handshaker_test.cc", + "quic/core/quic_crypto_client_stream_test.cc", + "quic/core/quic_crypto_server_stream_test.cc", + "quic/core/quic_crypto_stream_test.cc", + "quic/core/quic_data_writer_test.cc", + "quic/core/quic_datagram_queue_test.cc", + "quic/core/quic_dispatcher_test.cc", + "quic/core/quic_error_codes_test.cc", + "quic/core/quic_flow_controller_test.cc", + "quic/core/quic_framer_test.cc", + "quic/core/quic_idle_network_detector_test.cc", + "quic/core/quic_interval_deque_test.cc", + "quic/core/quic_interval_set_test.cc", + "quic/core/quic_interval_test.cc", + "quic/core/quic_lru_cache_test.cc", + "quic/core/quic_network_blackhole_detector_test.cc", + "quic/core/quic_one_block_arena_test.cc", + "quic/core/quic_packet_creator_test.cc", + "quic/core/quic_packet_number_test.cc", + "quic/core/quic_packets_test.cc", + "quic/core/quic_path_validator_test.cc", + "quic/core/quic_ping_manager_test.cc", + "quic/core/quic_received_packet_manager_test.cc", + "quic/core/quic_sent_packet_manager_test.cc", + "quic/core/quic_server_id_test.cc", + "quic/core/quic_session_test.cc", + "quic/core/quic_socket_address_coder_test.cc", + "quic/core/quic_stream_id_manager_test.cc", + "quic/core/quic_stream_priority_test.cc", + "quic/core/quic_stream_send_buffer_test.cc", + "quic/core/quic_stream_sequencer_buffer_test.cc", + "quic/core/quic_stream_sequencer_test.cc", + "quic/core/quic_stream_test.cc", + "quic/core/quic_sustained_bandwidth_recorder_test.cc", + "quic/core/quic_tag_test.cc", + "quic/core/quic_time_accumulator_test.cc", + "quic/core/quic_time_test.cc", + "quic/core/quic_time_wait_list_manager_test.cc", + "quic/core/quic_trace_visitor_test.cc", + "quic/core/quic_unacked_packet_map_test.cc", + "quic/core/quic_utils_test.cc", + "quic/core/quic_version_manager_test.cc", + "quic/core/quic_versions_test.cc", + "quic/core/quic_write_blocked_list_test.cc", + "quic/core/tls_chlo_extractor_test.cc", + "quic/core/tls_client_handshaker_test.cc", + "quic/core/tls_server_handshaker_test.cc", + "quic/core/uber_quic_stream_id_manager_test.cc", + "quic/core/uber_received_packet_manager_test.cc", + "quic/platform/api/quic_socket_address_test.cc", + "quic/test_tools/crypto_test_utils_test.cc", + "quic/test_tools/quic_test_utils_test.cc", + "quic/test_tools/simple_session_notifier_test.cc", + "quic/test_tools/simulator/quic_endpoint_test.cc", + "quic/test_tools/simulator/simulator_test.cc", + "quic/tools/connect_tunnel_test.cc", + "quic/tools/connect_udp_tunnel_test.cc", + "quic/tools/quic_memory_cache_backend_test.cc", + "quic/tools/quic_tcp_like_trace_converter_test.cc", + "quic/tools/simple_ticket_crypter_test.cc", + "spdy/core/array_output_buffer_test.cc", + "spdy/core/hpack/hpack_decoder_adapter_test.cc", + "spdy/core/hpack/hpack_encoder_test.cc", + "spdy/core/hpack/hpack_entry_test.cc", + "spdy/core/hpack/hpack_header_table_test.cc", + "spdy/core/hpack/hpack_output_stream_test.cc", + "spdy/core/hpack/hpack_round_trip_test.cc", + "spdy/core/hpack/hpack_static_table_test.cc", + "spdy/core/http2_header_block_test.cc", + "spdy/core/http2_header_storage_test.cc", + "spdy/core/metadata_extension_test.cc", + "spdy/core/spdy_alt_svc_wire_format_test.cc", + "spdy/core/spdy_frame_builder_test.cc", + "spdy/core/spdy_framer_test.cc", + "spdy/core/spdy_intrusive_list_test.cc", + "spdy/core/spdy_pinnable_buffer_piece_test.cc", + "spdy/core/spdy_prefixed_buffer_reader_test.cc", + "spdy/core/spdy_protocol_test.cc", + "spdy/core/spdy_simple_arena_test.cc", +] +io_tests_hdrs = [ + +] +io_tests_srcs = [ + "quic/core/chlo_extractor_test.cc", + "quic/core/http/end_to_end_test.cc", + "quic/core/http/quic_spdy_client_session_test.cc", + "quic/core/http/quic_spdy_client_stream_test.cc", + "quic/core/http/quic_spdy_server_stream_base_test.cc", + "quic/core/io/event_loop_connecting_client_socket_test.cc", + "quic/core/io/quic_all_event_loops_test.cc", + "quic/core/io/quic_poll_event_loop_test.cc", + "quic/core/io/socket_test.cc", + "quic/tools/quic_default_client_test.cc", + "quic/tools/quic_server_test.cc", + "quic/tools/quic_simple_server_session_test.cc", + "quic/tools/quic_simple_server_stream_test.cc", + "quic/tools/quic_url_test.cc", +] +fuzzers_hdrs = [ + +] +fuzzers_srcs = [ + "common/structured_headers_fuzzer.cc", + "quic/core/crypto/certificate_view_der_fuzzer.cc", + "quic/core/crypto/certificate_view_pem_fuzzer.cc", + "quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc", + "quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc", + "quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc", + "quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc", + "quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc", + "quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc", + "quic/test_tools/fuzzing/quic_framer_fuzzer.cc", + "quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc", +] +cli_tools_hdrs = [ + "quic/tools/quic_server_factory.h", + "quic/tools/quic_toy_client.h", + "quic/tools/quic_toy_server.h", +] +cli_tools_srcs = [ + "quic/masque/masque_client_bin.cc", + "quic/masque/masque_server_bin.cc", + "quic/tools/crypto_message_printer_bin.cc", + "quic/tools/qpack_offline_decoder_bin.cc", + "quic/tools/quic_client_bin.cc", + "quic/tools/quic_client_interop_test_bin.cc", + "quic/tools/quic_packet_printer_bin.cc", + "quic/tools/quic_reject_reason_decoder_bin.cc", + "quic/tools/quic_server_bin.cc", + "quic/tools/quic_server_factory.cc", + "quic/tools/quic_toy_client.cc", + "quic/tools/quic_toy_server.cc", +] +nghttp2_hdrs = [ + "http2/adapter/callback_visitor.h", + "http2/adapter/nghttp2.h", + "http2/adapter/nghttp2_adapter.h", + "http2/adapter/nghttp2_callbacks.h", + "http2/adapter/nghttp2_data_provider.h", + "http2/adapter/nghttp2_session.h", + "http2/adapter/nghttp2_util.h", +] +nghttp2_srcs = [ + "http2/adapter/callback_visitor.cc", + "http2/adapter/nghttp2_adapter.cc", + "http2/adapter/nghttp2_callbacks.cc", + "http2/adapter/nghttp2_data_provider.cc", + "http2/adapter/nghttp2_session.cc", + "http2/adapter/nghttp2_test.cc", + "http2/adapter/nghttp2_util.cc", +] +nghttp2_test_support_hdrs = [ + "http2/adapter/mock_nghttp2_callbacks.h", + "http2/adapter/nghttp2_test_utils.h", +] +nghttp2_test_support_srcs = [ + "http2/adapter/mock_nghttp2_callbacks.cc", + "http2/adapter/nghttp2_test_utils.cc", +] +nghttp2_tests_hdrs = [ + +] +nghttp2_tests_srcs = [ + "http2/adapter/adapter_impl_comparison_test.cc", + "http2/adapter/callback_visitor_test.cc", + "http2/adapter/nghttp2_adapter_test.cc", + "http2/adapter/nghttp2_data_provider_test.cc", + "http2/adapter/nghttp2_session_test.cc", + "http2/adapter/nghttp2_util_test.cc", +] +default_platform_impl_hdrs = [ + "common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h", + "common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h", + "common/platform/default/quiche_platform_impl/quiche_containers_impl.h", + "common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h", + "common/platform/default/quiche_platform_impl/quiche_export_impl.h", + "common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h", + "common/platform/default/quiche_platform_impl/quiche_flags_impl.h", + "common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h", + "common/platform/default/quiche_platform_impl/quiche_iovec_impl.h", + "common/platform/default/quiche_platform_impl/quiche_logging_impl.h", + "common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h", + "common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h", + "common/platform/default/quiche_platform_impl/quiche_mutex_impl.h", + "common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h", + "common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h", + "common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h", + "common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h", + "common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h", + "common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h", + "common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h", + "common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h", +] +default_platform_impl_srcs = [ + "common/platform/default/quiche_platform_impl/quiche_flags_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc", +] +default_platform_impl_tool_support_hdrs = [ + "common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h", + "common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h", + "common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h", + "common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h", + "common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h", +] +default_platform_impl_tool_support_srcs = [ + "common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc", +] +default_platform_impl_test_support_hdrs = [ + "common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h", + "common/platform/default/quiche_platform_impl/quiche_test_impl.h", + "common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h", + "common/platform/default/quiche_platform_impl/quiche_test_output_impl.h", + "common/platform/default/quiche_platform_impl/quiche_thread_impl.h", +] +default_platform_impl_test_support_srcs = [ + "common/platform/default/quiche_platform_impl/quiche_test_impl.cc", + "common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc", +] +load_balancer_hdrs = [ + "quic/load_balancer/load_balancer_config.h", + "quic/load_balancer/load_balancer_decoder.h", + "quic/load_balancer/load_balancer_encoder.h", + "quic/load_balancer/load_balancer_server_id.h", + "quic/load_balancer/load_balancer_server_id_map.h", +] +load_balancer_srcs = [ + "quic/load_balancer/load_balancer_config.cc", + "quic/load_balancer/load_balancer_config_test.cc", + "quic/load_balancer/load_balancer_decoder.cc", + "quic/load_balancer/load_balancer_decoder_test.cc", + "quic/load_balancer/load_balancer_encoder.cc", + "quic/load_balancer/load_balancer_encoder_test.cc", + "quic/load_balancer/load_balancer_server_id.cc", + "quic/load_balancer/load_balancer_server_id_map_test.cc", + "quic/load_balancer/load_balancer_server_id_test.cc", +] +binary_http_hdrs = [ + "binary_http/binary_http_message.h", +] +binary_http_srcs = [ + "binary_http/binary_http_message.cc", +] +oblivious_http_hdrs = [ + "oblivious_http/buffers/oblivious_http_request.h", + "oblivious_http/buffers/oblivious_http_response.h", + "oblivious_http/common/oblivious_http_header_key_config.h", + "oblivious_http/oblivious_http_client.h", + "oblivious_http/oblivious_http_gateway.h", +] +oblivious_http_srcs = [ + "oblivious_http/buffers/oblivious_http_request.cc", + "oblivious_http/buffers/oblivious_http_response.cc", + "oblivious_http/common/oblivious_http_header_key_config.cc", + "oblivious_http/oblivious_http_client.cc", + "oblivious_http/oblivious_http_gateway.cc", +] +qbone_hdrs = [ + "quic/qbone/bonnet/icmp_reachable.h", + "quic/qbone/bonnet/icmp_reachable_interface.h", + "quic/qbone/bonnet/mock_icmp_reachable.h", + "quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h", + "quic/qbone/bonnet/mock_qbone_tunnel.h", + "quic/qbone/bonnet/mock_tun_device.h", + "quic/qbone/bonnet/mock_tun_device_controller.h", + "quic/qbone/bonnet/qbone_tunnel_info.h", + "quic/qbone/bonnet/qbone_tunnel_interface.h", + "quic/qbone/bonnet/qbone_tunnel_silo.h", + "quic/qbone/bonnet/tun_device.h", + "quic/qbone/bonnet/tun_device_controller.h", + "quic/qbone/bonnet/tun_device_interface.h", + "quic/qbone/bonnet/tun_device_packet_exchanger.h", + "quic/qbone/mock_qbone_client.h", + "quic/qbone/mock_qbone_server_session.h", + "quic/qbone/platform/icmp_packet.h", + "quic/qbone/platform/internet_checksum.h", + "quic/qbone/platform/ip_range.h", + "quic/qbone/platform/kernel_interface.h", + "quic/qbone/platform/mock_kernel.h", + "quic/qbone/platform/mock_netlink.h", + "quic/qbone/platform/netlink.h", + "quic/qbone/platform/netlink_interface.h", + "quic/qbone/platform/rtnetlink_message.h", + "quic/qbone/platform/tcp_packet.h", + "quic/qbone/qbone_client.h", + "quic/qbone/qbone_client_interface.h", + "quic/qbone/qbone_client_session.h", + "quic/qbone/qbone_constants.h", + "quic/qbone/qbone_control_stream.h", + "quic/qbone/qbone_packet_exchanger.h", + "quic/qbone/qbone_packet_processor.h", + "quic/qbone/qbone_packet_processor_test_tools.h", + "quic/qbone/qbone_packet_writer.h", + "quic/qbone/qbone_server_session.h", + "quic/qbone/qbone_session_base.h", + "quic/qbone/qbone_stream.h", +] +qbone_srcs = [ + "quic/qbone/bonnet/icmp_reachable.cc", + "quic/qbone/bonnet/icmp_reachable_test.cc", + "quic/qbone/bonnet/qbone_tunnel_info.cc", + "quic/qbone/bonnet/qbone_tunnel_silo.cc", + "quic/qbone/bonnet/qbone_tunnel_silo_test.cc", + "quic/qbone/bonnet/tun_device.cc", + "quic/qbone/bonnet/tun_device_controller.cc", + "quic/qbone/bonnet/tun_device_controller_test.cc", + "quic/qbone/bonnet/tun_device_packet_exchanger.cc", + "quic/qbone/bonnet/tun_device_packet_exchanger_test.cc", + "quic/qbone/bonnet/tun_device_test.cc", + "quic/qbone/platform/icmp_packet.cc", + "quic/qbone/platform/icmp_packet_test.cc", + "quic/qbone/platform/internet_checksum.cc", + "quic/qbone/platform/internet_checksum_test.cc", + "quic/qbone/platform/ip_range.cc", + "quic/qbone/platform/ip_range_test.cc", + "quic/qbone/platform/netlink.cc", + "quic/qbone/platform/netlink_test.cc", + "quic/qbone/platform/rtnetlink_message.cc", + "quic/qbone/platform/rtnetlink_message_test.cc", + "quic/qbone/platform/tcp_packet.cc", + "quic/qbone/platform/tcp_packet_test.cc", + "quic/qbone/qbone_client.cc", + "quic/qbone/qbone_client_session.cc", + "quic/qbone/qbone_client_test.cc", + "quic/qbone/qbone_constants.cc", + "quic/qbone/qbone_control_stream.cc", + "quic/qbone/qbone_packet_exchanger.cc", + "quic/qbone/qbone_packet_exchanger_test.cc", + "quic/qbone/qbone_packet_processor.cc", + "quic/qbone/qbone_packet_processor_test.cc", + "quic/qbone/qbone_packet_processor_test_tools.cc", + "quic/qbone/qbone_server_session.cc", + "quic/qbone/qbone_session_base.cc", + "quic/qbone/qbone_session_test.cc", + "quic/qbone/qbone_stream.cc", + "quic/qbone/qbone_stream_test.cc", +] +blind_sign_auth_hdrs = [ + "blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h", + "blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h", + "blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h", + "blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h", + "blind_sign_auth/anonymous_tokens/cpp/testing/utils.h", + "blind_sign_auth/blind_sign_auth.h", + "blind_sign_auth/blind_sign_auth_interface.h", + "blind_sign_auth/blind_sign_http_interface.h", + "blind_sign_auth/blind_sign_http_response.h", + "blind_sign_auth/cached_blind_sign_auth.h", + "blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h", + "blind_sign_auth/test_tools/mock_blind_sign_http_interface.h", +] +blind_sign_auth_srcs = [ + "blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc", + "blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc", + "blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc", + "blind_sign_auth/blind_sign_auth.cc", + "blind_sign_auth/cached_blind_sign_auth.cc", +] +blind_sign_auth_tests_hdrs = [ + +] +blind_sign_auth_tests_srcs = [ + "blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc", + "blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc", + "blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc", + "blind_sign_auth/blind_sign_auth_test.cc", + "blind_sign_auth/cached_blind_sign_auth_test.cc", +] +protobuf_blind_sign_auth = [ + "blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto", + "blind_sign_auth/proto/any.proto", + "blind_sign_auth/proto/attestation.proto", + "blind_sign_auth/proto/auth_and_sign.proto", + "blind_sign_auth/proto/get_initial_data.proto", + "blind_sign_auth/proto/key_services.proto", + "blind_sign_auth/proto/public_metadata.proto", + "blind_sign_auth/proto/spend_token_data.proto", + "blind_sign_auth/proto/timestamp.proto", +] +libevent_hdrs = [ + "quic/bindings/quic_libevent.h", +] +libevent_srcs = [ + "quic/bindings/quic_libevent.cc", + "quic/bindings/quic_libevent_test.cc", +] +linux_only_hdrs = [ + "quic/core/batch_writer/quic_batch_writer_base.h", + "quic/core/batch_writer/quic_batch_writer_buffer.h", + "quic/core/batch_writer/quic_batch_writer_test.h", + "quic/core/batch_writer/quic_gso_batch_writer.h", + "quic/core/batch_writer/quic_sendmmsg_batch_writer.h", + "quic/core/quic_linux_socket_utils.h", +] +linux_only_srcs = [ + "quic/core/batch_writer/quic_batch_writer_base.cc", + "quic/core/batch_writer/quic_batch_writer_buffer.cc", + "quic/core/batch_writer/quic_gso_batch_writer.cc", + "quic/core/batch_writer/quic_sendmmsg_batch_writer.cc", + "quic/core/quic_linux_socket_utils.cc", +] +linux_only_tests_hdrs = [ + +] +linux_only_tests_srcs = [ + "quic/core/batch_writer/quic_batch_writer_buffer_test.cc", + "quic/core/batch_writer/quic_batch_writer_test.cc", + "quic/core/batch_writer/quic_gso_batch_writer_test.cc", + "quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc", + "quic/core/quic_linux_socket_utils_test.cc", +] diff --git a/build/source_list.gni b/build/source_list.gni new file mode 100644 index 000000000000..3950cd341849 --- /dev/null +++ b/build/source_list.gni @@ -0,0 +1,1637 @@ +# Autogenerated source file list for QUICHE Chromium build. + +protobuf = [ + "src/quiche/quic/core/proto/cached_network_parameters.proto", + "src/quiche/quic/core/proto/crypto_server_config.proto", + "src/quiche/quic/core/proto/source_address_token.proto", +] +protobuf_test_support = [ + "src/quiche/quic/test_tools/send_algorithm_test_result.proto", +] +quiche_core_hdrs = [ + "src/quiche/balsa/balsa_enums.h", + "src/quiche/balsa/balsa_frame.h", + "src/quiche/balsa/balsa_headers.h", + "src/quiche/balsa/balsa_visitor_interface.h", + "src/quiche/balsa/framer_interface.h", + "src/quiche/balsa/header_api.h", + "src/quiche/balsa/header_properties.h", + "src/quiche/balsa/http_validation_policy.h", + "src/quiche/balsa/noop_balsa_visitor.h", + "src/quiche/balsa/simple_buffer.h", + "src/quiche/balsa/standard_header_map.h", + "src/quiche/common/btree_scheduler.h", + "src/quiche/common/capsule.h", + "src/quiche/common/masque/connect_udp_datagram_payload.h", + "src/quiche/common/platform/api/quiche_bug_tracker.h", + "src/quiche/common/platform/api/quiche_client_stats.h", + "src/quiche/common/platform/api/quiche_containers.h", + "src/quiche/common/platform/api/quiche_export.h", + "src/quiche/common/platform/api/quiche_flag_utils.h", + "src/quiche/common/platform/api/quiche_flags.h", + "src/quiche/common/platform/api/quiche_header_policy.h", + "src/quiche/common/platform/api/quiche_hostname_utils.h", + "src/quiche/common/platform/api/quiche_iovec.h", + "src/quiche/common/platform/api/quiche_logging.h", + "src/quiche/common/platform/api/quiche_lower_case_string.h", + "src/quiche/common/platform/api/quiche_mem_slice.h", + "src/quiche/common/platform/api/quiche_mutex.h", + "src/quiche/common/platform/api/quiche_prefetch.h", + "src/quiche/common/platform/api/quiche_reference_counted.h", + "src/quiche/common/platform/api/quiche_server_stats.h", + "src/quiche/common/platform/api/quiche_stack_trace.h", + "src/quiche/common/platform/api/quiche_testvalue.h", + "src/quiche/common/platform/api/quiche_thread.h", + "src/quiche/common/platform/api/quiche_time_utils.h", + "src/quiche/common/platform/api/quiche_url_utils.h", + "src/quiche/common/print_elements.h", + "src/quiche/common/quiche_buffer_allocator.h", + "src/quiche/common/quiche_circular_deque.h", + "src/quiche/common/quiche_crypto_logging.h", + "src/quiche/common/quiche_data_reader.h", + "src/quiche/common/quiche_data_writer.h", + "src/quiche/common/quiche_endian.h", + "src/quiche/common/quiche_ip_address.h", + "src/quiche/common/quiche_ip_address_family.h", + "src/quiche/common/quiche_linked_hash_map.h", + "src/quiche/common/quiche_mem_slice_storage.h", + "src/quiche/common/quiche_protocol_flags_list.h", + "src/quiche/common/quiche_random.h", + "src/quiche/common/quiche_status_utils.h", + "src/quiche/common/quiche_stream.h", + "src/quiche/common/quiche_text_utils.h", + "src/quiche/common/simple_buffer_allocator.h", + "src/quiche/common/structured_headers.h", + "src/quiche/common/wire_serialization.h", + "src/quiche/http2/adapter/data_source.h", + "src/quiche/http2/adapter/event_forwarder.h", + "src/quiche/http2/adapter/header_validator.h", + "src/quiche/http2/adapter/header_validator_base.h", + "src/quiche/http2/adapter/http2_adapter.h", + "src/quiche/http2/adapter/http2_protocol.h", + "src/quiche/http2/adapter/http2_session.h", + "src/quiche/http2/adapter/http2_util.h", + "src/quiche/http2/adapter/http2_visitor_interface.h", + "src/quiche/http2/adapter/noop_header_validator.h", + "src/quiche/http2/adapter/oghttp2_adapter.h", + "src/quiche/http2/adapter/oghttp2_session.h", + "src/quiche/http2/adapter/oghttp2_util.h", + "src/quiche/http2/adapter/window_manager.h", + "src/quiche/http2/core/http2_trace_logging.h", + "src/quiche/http2/core/priority_write_scheduler.h", + "src/quiche/http2/decoder/decode_buffer.h", + "src/quiche/http2/decoder/decode_http2_structures.h", + "src/quiche/http2/decoder/decode_status.h", + "src/quiche/http2/decoder/frame_decoder_state.h", + "src/quiche/http2/decoder/http2_frame_decoder.h", + "src/quiche/http2/decoder/http2_frame_decoder_listener.h", + "src/quiche/http2/decoder/http2_structure_decoder.h", + "src/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/data_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/headers_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/ping_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/priority_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/settings_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h", + "src/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h", + "src/quiche/http2/hpack/decoder/hpack_block_decoder.h", + "src/quiche/http2/hpack/decoder/hpack_decoder.h", + "src/quiche/http2/hpack/decoder/hpack_decoder_listener.h", + "src/quiche/http2/hpack/decoder/hpack_decoder_state.h", + "src/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h", + "src/quiche/http2/hpack/decoder/hpack_decoder_tables.h", + "src/quiche/http2/hpack/decoder/hpack_decoding_error.h", + "src/quiche/http2/hpack/decoder/hpack_entry_decoder.h", + "src/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h", + "src/quiche/http2/hpack/decoder/hpack_entry_type_decoder.h", + "src/quiche/http2/hpack/decoder/hpack_string_decoder.h", + "src/quiche/http2/hpack/decoder/hpack_string_decoder_listener.h", + "src/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h", + "src/quiche/http2/hpack/decoder/hpack_whole_entry_listener.h", + "src/quiche/http2/hpack/http2_hpack_constants.h", + "src/quiche/http2/hpack/huffman/hpack_huffman_decoder.h", + "src/quiche/http2/hpack/huffman/hpack_huffman_encoder.h", + "src/quiche/http2/hpack/huffman/huffman_spec_tables.h", + "src/quiche/http2/hpack/varint/hpack_varint_decoder.h", + "src/quiche/http2/hpack/varint/hpack_varint_encoder.h", + "src/quiche/http2/http2_constants.h", + "src/quiche/http2/http2_structures.h", + "src/quiche/quic/core/chlo_extractor.h", + "src/quiche/quic/core/congestion_control/bandwidth_sampler.h", + "src/quiche/quic/core/congestion_control/bbr2_drain.h", + "src/quiche/quic/core/congestion_control/bbr2_misc.h", + "src/quiche/quic/core/congestion_control/bbr2_probe_bw.h", + "src/quiche/quic/core/congestion_control/bbr2_probe_rtt.h", + "src/quiche/quic/core/congestion_control/bbr2_sender.h", + "src/quiche/quic/core/congestion_control/bbr2_startup.h", + "src/quiche/quic/core/congestion_control/bbr_sender.h", + "src/quiche/quic/core/congestion_control/cubic_bytes.h", + "src/quiche/quic/core/congestion_control/general_loss_algorithm.h", + "src/quiche/quic/core/congestion_control/hybrid_slow_start.h", + "src/quiche/quic/core/congestion_control/loss_detection_interface.h", + "src/quiche/quic/core/congestion_control/pacing_sender.h", + "src/quiche/quic/core/congestion_control/prr_sender.h", + "src/quiche/quic/core/congestion_control/rtt_stats.h", + "src/quiche/quic/core/congestion_control/send_algorithm_interface.h", + "src/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h", + "src/quiche/quic/core/congestion_control/uber_loss_algorithm.h", + "src/quiche/quic/core/congestion_control/windowed_filter.h", + "src/quiche/quic/core/connecting_client_socket.h", + "src/quiche/quic/core/connection_id_generator.h", + "src/quiche/quic/core/crypto/aead_base_decrypter.h", + "src/quiche/quic/core/crypto/aead_base_encrypter.h", + "src/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h", + "src/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h", + "src/quiche/quic/core/crypto/aes_128_gcm_decrypter.h", + "src/quiche/quic/core/crypto/aes_128_gcm_encrypter.h", + "src/quiche/quic/core/crypto/aes_256_gcm_decrypter.h", + "src/quiche/quic/core/crypto/aes_256_gcm_encrypter.h", + "src/quiche/quic/core/crypto/aes_base_decrypter.h", + "src/quiche/quic/core/crypto/aes_base_encrypter.h", + "src/quiche/quic/core/crypto/boring_utils.h", + "src/quiche/quic/core/crypto/cert_compressor.h", + "src/quiche/quic/core/crypto/certificate_util.h", + "src/quiche/quic/core/crypto/certificate_view.h", + "src/quiche/quic/core/crypto/chacha20_poly1305_decrypter.h", + "src/quiche/quic/core/crypto/chacha20_poly1305_encrypter.h", + "src/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h", + "src/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h", + "src/quiche/quic/core/crypto/chacha_base_decrypter.h", + "src/quiche/quic/core/crypto/chacha_base_encrypter.h", + "src/quiche/quic/core/crypto/channel_id.h", + "src/quiche/quic/core/crypto/client_proof_source.h", + "src/quiche/quic/core/crypto/crypto_framer.h", + "src/quiche/quic/core/crypto/crypto_handshake.h", + "src/quiche/quic/core/crypto/crypto_handshake_message.h", + "src/quiche/quic/core/crypto/crypto_message_parser.h", + "src/quiche/quic/core/crypto/crypto_protocol.h", + "src/quiche/quic/core/crypto/crypto_secret_boxer.h", + "src/quiche/quic/core/crypto/crypto_utils.h", + "src/quiche/quic/core/crypto/curve25519_key_exchange.h", + "src/quiche/quic/core/crypto/key_exchange.h", + "src/quiche/quic/core/crypto/null_decrypter.h", + "src/quiche/quic/core/crypto/null_encrypter.h", + "src/quiche/quic/core/crypto/p256_key_exchange.h", + "src/quiche/quic/core/crypto/proof_source.h", + "src/quiche/quic/core/crypto/proof_source_x509.h", + "src/quiche/quic/core/crypto/proof_verifier.h", + "src/quiche/quic/core/crypto/quic_client_session_cache.h", + "src/quiche/quic/core/crypto/quic_compressed_certs_cache.h", + "src/quiche/quic/core/crypto/quic_crypter.h", + "src/quiche/quic/core/crypto/quic_crypto_client_config.h", + "src/quiche/quic/core/crypto/quic_crypto_proof.h", + "src/quiche/quic/core/crypto/quic_crypto_server_config.h", + "src/quiche/quic/core/crypto/quic_decrypter.h", + "src/quiche/quic/core/crypto/quic_encrypter.h", + "src/quiche/quic/core/crypto/quic_hkdf.h", + "src/quiche/quic/core/crypto/quic_random.h", + "src/quiche/quic/core/crypto/tls_client_connection.h", + "src/quiche/quic/core/crypto/tls_connection.h", + "src/quiche/quic/core/crypto/tls_server_connection.h", + "src/quiche/quic/core/crypto/transport_parameters.h", + "src/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h", + "src/quiche/quic/core/deterministic_connection_id_generator.h", + "src/quiche/quic/core/frames/quic_ack_frame.h", + "src/quiche/quic/core/frames/quic_ack_frequency_frame.h", + "src/quiche/quic/core/frames/quic_blocked_frame.h", + "src/quiche/quic/core/frames/quic_connection_close_frame.h", + "src/quiche/quic/core/frames/quic_crypto_frame.h", + "src/quiche/quic/core/frames/quic_frame.h", + "src/quiche/quic/core/frames/quic_goaway_frame.h", + "src/quiche/quic/core/frames/quic_handshake_done_frame.h", + "src/quiche/quic/core/frames/quic_inlined_frame.h", + "src/quiche/quic/core/frames/quic_max_streams_frame.h", + "src/quiche/quic/core/frames/quic_message_frame.h", + "src/quiche/quic/core/frames/quic_mtu_discovery_frame.h", + "src/quiche/quic/core/frames/quic_new_connection_id_frame.h", + "src/quiche/quic/core/frames/quic_new_token_frame.h", + "src/quiche/quic/core/frames/quic_padding_frame.h", + "src/quiche/quic/core/frames/quic_path_challenge_frame.h", + "src/quiche/quic/core/frames/quic_path_response_frame.h", + "src/quiche/quic/core/frames/quic_ping_frame.h", + "src/quiche/quic/core/frames/quic_retire_connection_id_frame.h", + "src/quiche/quic/core/frames/quic_rst_stream_frame.h", + "src/quiche/quic/core/frames/quic_stop_sending_frame.h", + "src/quiche/quic/core/frames/quic_stop_waiting_frame.h", + "src/quiche/quic/core/frames/quic_stream_frame.h", + "src/quiche/quic/core/frames/quic_streams_blocked_frame.h", + "src/quiche/quic/core/frames/quic_window_update_frame.h", + "src/quiche/quic/core/handshaker_delegate_interface.h", + "src/quiche/quic/core/http/http_constants.h", + "src/quiche/quic/core/http/http_decoder.h", + "src/quiche/quic/core/http/http_encoder.h", + "src/quiche/quic/core/http/http_frames.h", + "src/quiche/quic/core/http/quic_client_promised_info.h", + "src/quiche/quic/core/http/quic_client_push_promise_index.h", + "src/quiche/quic/core/http/quic_header_list.h", + "src/quiche/quic/core/http/quic_headers_stream.h", + "src/quiche/quic/core/http/quic_receive_control_stream.h", + "src/quiche/quic/core/http/quic_send_control_stream.h", + "src/quiche/quic/core/http/quic_server_initiated_spdy_stream.h", + "src/quiche/quic/core/http/quic_server_session_base.h", + "src/quiche/quic/core/http/quic_spdy_client_session.h", + "src/quiche/quic/core/http/quic_spdy_client_session_base.h", + "src/quiche/quic/core/http/quic_spdy_client_stream.h", + "src/quiche/quic/core/http/quic_spdy_server_stream_base.h", + "src/quiche/quic/core/http/quic_spdy_session.h", + "src/quiche/quic/core/http/quic_spdy_stream.h", + "src/quiche/quic/core/http/quic_spdy_stream_body_manager.h", + "src/quiche/quic/core/http/spdy_server_push_utils.h", + "src/quiche/quic/core/http/spdy_utils.h", + "src/quiche/quic/core/http/web_transport_http3.h", + "src/quiche/quic/core/http/web_transport_stream_adapter.h", + "src/quiche/quic/core/legacy_quic_stream_id_manager.h", + "src/quiche/quic/core/packet_number_indexed_queue.h", + "src/quiche/quic/core/proto/cached_network_parameters_proto.h", + "src/quiche/quic/core/proto/crypto_server_config_proto.h", + "src/quiche/quic/core/proto/source_address_token_proto.h", + "src/quiche/quic/core/qpack/qpack_blocking_manager.h", + "src/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h", + "src/quiche/quic/core/qpack/qpack_decoder.h", + "src/quiche/quic/core/qpack/qpack_decoder_stream_receiver.h", + "src/quiche/quic/core/qpack/qpack_decoder_stream_sender.h", + "src/quiche/quic/core/qpack/qpack_encoder.h", + "src/quiche/quic/core/qpack/qpack_encoder_stream_receiver.h", + "src/quiche/quic/core/qpack/qpack_encoder_stream_sender.h", + "src/quiche/quic/core/qpack/qpack_header_table.h", + "src/quiche/quic/core/qpack/qpack_index_conversions.h", + "src/quiche/quic/core/qpack/qpack_instruction_decoder.h", + "src/quiche/quic/core/qpack/qpack_instruction_encoder.h", + "src/quiche/quic/core/qpack/qpack_instructions.h", + "src/quiche/quic/core/qpack/qpack_progressive_decoder.h", + "src/quiche/quic/core/qpack/qpack_receive_stream.h", + "src/quiche/quic/core/qpack/qpack_required_insert_count.h", + "src/quiche/quic/core/qpack/qpack_send_stream.h", + "src/quiche/quic/core/qpack/qpack_static_table.h", + "src/quiche/quic/core/qpack/qpack_stream_receiver.h", + "src/quiche/quic/core/qpack/qpack_stream_sender_delegate.h", + "src/quiche/quic/core/qpack/value_splitting_header_list.h", + "src/quiche/quic/core/quic_ack_listener_interface.h", + "src/quiche/quic/core/quic_alarm.h", + "src/quiche/quic/core/quic_alarm_factory.h", + "src/quiche/quic/core/quic_arena_scoped_ptr.h", + "src/quiche/quic/core/quic_bandwidth.h", + "src/quiche/quic/core/quic_blocked_writer_interface.h", + "src/quiche/quic/core/quic_buffered_packet_store.h", + "src/quiche/quic/core/quic_chaos_protector.h", + "src/quiche/quic/core/quic_clock.h", + "src/quiche/quic/core/quic_coalesced_packet.h", + "src/quiche/quic/core/quic_config.h", + "src/quiche/quic/core/quic_connection.h", + "src/quiche/quic/core/quic_connection_context.h", + "src/quiche/quic/core/quic_connection_id.h", + "src/quiche/quic/core/quic_connection_id_manager.h", + "src/quiche/quic/core/quic_connection_stats.h", + "src/quiche/quic/core/quic_constants.h", + "src/quiche/quic/core/quic_control_frame_manager.h", + "src/quiche/quic/core/quic_crypto_client_handshaker.h", + "src/quiche/quic/core/quic_crypto_client_stream.h", + "src/quiche/quic/core/quic_crypto_handshaker.h", + "src/quiche/quic/core/quic_crypto_server_stream.h", + "src/quiche/quic/core/quic_crypto_server_stream_base.h", + "src/quiche/quic/core/quic_crypto_stream.h", + "src/quiche/quic/core/quic_data_reader.h", + "src/quiche/quic/core/quic_data_writer.h", + "src/quiche/quic/core/quic_datagram_queue.h", + "src/quiche/quic/core/quic_default_clock.h", + "src/quiche/quic/core/quic_default_connection_helper.h", + "src/quiche/quic/core/quic_dispatcher.h", + "src/quiche/quic/core/quic_error_codes.h", + "src/quiche/quic/core/quic_flags_list.h", + "src/quiche/quic/core/quic_flow_controller.h", + "src/quiche/quic/core/quic_framer.h", + "src/quiche/quic/core/quic_idle_network_detector.h", + "src/quiche/quic/core/quic_interval.h", + "src/quiche/quic/core/quic_interval_deque.h", + "src/quiche/quic/core/quic_interval_set.h", + "src/quiche/quic/core/quic_lru_cache.h", + "src/quiche/quic/core/quic_mtu_discovery.h", + "src/quiche/quic/core/quic_network_blackhole_detector.h", + "src/quiche/quic/core/quic_one_block_arena.h", + "src/quiche/quic/core/quic_packet_creator.h", + "src/quiche/quic/core/quic_packet_number.h", + "src/quiche/quic/core/quic_packet_writer.h", + "src/quiche/quic/core/quic_packet_writer_wrapper.h", + "src/quiche/quic/core/quic_packets.h", + "src/quiche/quic/core/quic_path_validator.h", + "src/quiche/quic/core/quic_ping_manager.h", + "src/quiche/quic/core/quic_process_packet_interface.h", + "src/quiche/quic/core/quic_protocol_flags_list.h", + "src/quiche/quic/core/quic_received_packet_manager.h", + "src/quiche/quic/core/quic_sent_packet_manager.h", + "src/quiche/quic/core/quic_server_id.h", + "src/quiche/quic/core/quic_session.h", + "src/quiche/quic/core/quic_socket_address_coder.h", + "src/quiche/quic/core/quic_stream.h", + "src/quiche/quic/core/quic_stream_frame_data_producer.h", + "src/quiche/quic/core/quic_stream_id_manager.h", + "src/quiche/quic/core/quic_stream_priority.h", + "src/quiche/quic/core/quic_stream_send_buffer.h", + "src/quiche/quic/core/quic_stream_sequencer.h", + "src/quiche/quic/core/quic_stream_sequencer_buffer.h", + "src/quiche/quic/core/quic_sustained_bandwidth_recorder.h", + "src/quiche/quic/core/quic_tag.h", + "src/quiche/quic/core/quic_time.h", + "src/quiche/quic/core/quic_time_accumulator.h", + "src/quiche/quic/core/quic_time_wait_list_manager.h", + "src/quiche/quic/core/quic_trace_visitor.h", + "src/quiche/quic/core/quic_transmission_info.h", + "src/quiche/quic/core/quic_types.h", + "src/quiche/quic/core/quic_unacked_packet_map.h", + "src/quiche/quic/core/quic_utils.h", + "src/quiche/quic/core/quic_version_manager.h", + "src/quiche/quic/core/quic_versions.h", + "src/quiche/quic/core/quic_write_blocked_list.h", + "src/quiche/quic/core/session_notifier_interface.h", + "src/quiche/quic/core/socket_factory.h", + "src/quiche/quic/core/stream_delegate_interface.h", + "src/quiche/quic/core/tls_chlo_extractor.h", + "src/quiche/quic/core/tls_client_handshaker.h", + "src/quiche/quic/core/tls_handshaker.h", + "src/quiche/quic/core/tls_server_handshaker.h", + "src/quiche/quic/core/uber_quic_stream_id_manager.h", + "src/quiche/quic/core/uber_received_packet_manager.h", + "src/quiche/quic/core/web_transport_interface.h", + "src/quiche/quic/platform/api/quic_bug_tracker.h", + "src/quiche/quic/platform/api/quic_client_stats.h", + "src/quiche/quic/platform/api/quic_export.h", + "src/quiche/quic/platform/api/quic_exported_stats.h", + "src/quiche/quic/platform/api/quic_flag_utils.h", + "src/quiche/quic/platform/api/quic_flags.h", + "src/quiche/quic/platform/api/quic_hostname_utils.h", + "src/quiche/quic/platform/api/quic_ip_address.h", + "src/quiche/quic/platform/api/quic_ip_address_family.h", + "src/quiche/quic/platform/api/quic_logging.h", + "src/quiche/quic/platform/api/quic_mutex.h", + "src/quiche/quic/platform/api/quic_server_stats.h", + "src/quiche/quic/platform/api/quic_socket_address.h", + "src/quiche/quic/platform/api/quic_stack_trace.h", + "src/quiche/quic/platform/api/quic_testvalue.h", + "src/quiche/quic/platform/api/quic_thread.h", + "src/quiche/spdy/core/array_output_buffer.h", + "src/quiche/spdy/core/header_byte_listener_interface.h", + "src/quiche/spdy/core/hpack/hpack_constants.h", + "src/quiche/spdy/core/hpack/hpack_decoder_adapter.h", + "src/quiche/spdy/core/hpack/hpack_encoder.h", + "src/quiche/spdy/core/hpack/hpack_entry.h", + "src/quiche/spdy/core/hpack/hpack_header_table.h", + "src/quiche/spdy/core/hpack/hpack_output_stream.h", + "src/quiche/spdy/core/hpack/hpack_static_table.h", + "src/quiche/spdy/core/http2_frame_decoder_adapter.h", + "src/quiche/spdy/core/http2_header_block.h", + "src/quiche/spdy/core/http2_header_block_hpack_listener.h", + "src/quiche/spdy/core/http2_header_storage.h", + "src/quiche/spdy/core/metadata_extension.h", + "src/quiche/spdy/core/no_op_headers_handler.h", + "src/quiche/spdy/core/recording_headers_handler.h", + "src/quiche/spdy/core/spdy_alt_svc_wire_format.h", + "src/quiche/spdy/core/spdy_bitmasks.h", + "src/quiche/spdy/core/spdy_frame_builder.h", + "src/quiche/spdy/core/spdy_framer.h", + "src/quiche/spdy/core/spdy_headers_handler_interface.h", + "src/quiche/spdy/core/spdy_intrusive_list.h", + "src/quiche/spdy/core/spdy_no_op_visitor.h", + "src/quiche/spdy/core/spdy_pinnable_buffer_piece.h", + "src/quiche/spdy/core/spdy_prefixed_buffer_reader.h", + "src/quiche/spdy/core/spdy_protocol.h", + "src/quiche/spdy/core/spdy_simple_arena.h", + "src/quiche/spdy/core/zero_copy_output_buffer.h", + "src/quiche/web_transport/web_transport.h", +] +quiche_core_srcs = [ + "src/quiche/balsa/balsa_enums.cc", + "src/quiche/balsa/balsa_frame.cc", + "src/quiche/balsa/balsa_headers.cc", + "src/quiche/balsa/header_properties.cc", + "src/quiche/balsa/simple_buffer.cc", + "src/quiche/balsa/standard_header_map.cc", + "src/quiche/common/capsule.cc", + "src/quiche/common/masque/connect_udp_datagram_payload.cc", + "src/quiche/common/platform/api/quiche_hostname_utils.cc", + "src/quiche/common/platform/api/quiche_mutex.cc", + "src/quiche/common/quiche_buffer_allocator.cc", + "src/quiche/common/quiche_crypto_logging.cc", + "src/quiche/common/quiche_data_reader.cc", + "src/quiche/common/quiche_data_writer.cc", + "src/quiche/common/quiche_ip_address.cc", + "src/quiche/common/quiche_ip_address_family.cc", + "src/quiche/common/quiche_mem_slice_storage.cc", + "src/quiche/common/quiche_random.cc", + "src/quiche/common/quiche_text_utils.cc", + "src/quiche/common/simple_buffer_allocator.cc", + "src/quiche/common/structured_headers.cc", + "src/quiche/http2/adapter/event_forwarder.cc", + "src/quiche/http2/adapter/header_validator.cc", + "src/quiche/http2/adapter/http2_protocol.cc", + "src/quiche/http2/adapter/http2_util.cc", + "src/quiche/http2/adapter/noop_header_validator.cc", + "src/quiche/http2/adapter/oghttp2_adapter.cc", + "src/quiche/http2/adapter/oghttp2_session.cc", + "src/quiche/http2/adapter/oghttp2_util.cc", + "src/quiche/http2/adapter/window_manager.cc", + "src/quiche/http2/core/http2_trace_logging.cc", + "src/quiche/http2/decoder/decode_buffer.cc", + "src/quiche/http2/decoder/decode_http2_structures.cc", + "src/quiche/http2/decoder/decode_status.cc", + "src/quiche/http2/decoder/frame_decoder_state.cc", + "src/quiche/http2/decoder/http2_frame_decoder.cc", + "src/quiche/http2/decoder/http2_frame_decoder_listener.cc", + "src/quiche/http2/decoder/http2_structure_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/data_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/headers_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/ping_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/priority_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/settings_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.cc", + "src/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.cc", + "src/quiche/http2/hpack/decoder/hpack_block_decoder.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_listener.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_state.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_tables.cc", + "src/quiche/http2/hpack/decoder/hpack_decoding_error.cc", + "src/quiche/http2/hpack/decoder/hpack_entry_decoder.cc", + "src/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.cc", + "src/quiche/http2/hpack/decoder/hpack_entry_type_decoder.cc", + "src/quiche/http2/hpack/decoder/hpack_string_decoder.cc", + "src/quiche/http2/hpack/decoder/hpack_string_decoder_listener.cc", + "src/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc", + "src/quiche/http2/hpack/decoder/hpack_whole_entry_listener.cc", + "src/quiche/http2/hpack/http2_hpack_constants.cc", + "src/quiche/http2/hpack/huffman/hpack_huffman_decoder.cc", + "src/quiche/http2/hpack/huffman/hpack_huffman_encoder.cc", + "src/quiche/http2/hpack/huffman/huffman_spec_tables.cc", + "src/quiche/http2/hpack/varint/hpack_varint_decoder.cc", + "src/quiche/http2/hpack/varint/hpack_varint_encoder.cc", + "src/quiche/http2/http2_constants.cc", + "src/quiche/http2/http2_structures.cc", + "src/quiche/quic/core/chlo_extractor.cc", + "src/quiche/quic/core/congestion_control/bandwidth_sampler.cc", + "src/quiche/quic/core/congestion_control/bbr2_drain.cc", + "src/quiche/quic/core/congestion_control/bbr2_misc.cc", + "src/quiche/quic/core/congestion_control/bbr2_probe_bw.cc", + "src/quiche/quic/core/congestion_control/bbr2_probe_rtt.cc", + "src/quiche/quic/core/congestion_control/bbr2_sender.cc", + "src/quiche/quic/core/congestion_control/bbr2_startup.cc", + "src/quiche/quic/core/congestion_control/bbr_sender.cc", + "src/quiche/quic/core/congestion_control/cubic_bytes.cc", + "src/quiche/quic/core/congestion_control/general_loss_algorithm.cc", + "src/quiche/quic/core/congestion_control/hybrid_slow_start.cc", + "src/quiche/quic/core/congestion_control/pacing_sender.cc", + "src/quiche/quic/core/congestion_control/prr_sender.cc", + "src/quiche/quic/core/congestion_control/rtt_stats.cc", + "src/quiche/quic/core/congestion_control/send_algorithm_interface.cc", + "src/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc", + "src/quiche/quic/core/congestion_control/uber_loss_algorithm.cc", + "src/quiche/quic/core/crypto/aead_base_decrypter.cc", + "src/quiche/quic/core/crypto/aead_base_encrypter.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_decrypter.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_encrypter.cc", + "src/quiche/quic/core/crypto/aes_256_gcm_decrypter.cc", + "src/quiche/quic/core/crypto/aes_256_gcm_encrypter.cc", + "src/quiche/quic/core/crypto/aes_base_decrypter.cc", + "src/quiche/quic/core/crypto/aes_base_encrypter.cc", + "src/quiche/quic/core/crypto/cert_compressor.cc", + "src/quiche/quic/core/crypto/certificate_util.cc", + "src/quiche/quic/core/crypto/certificate_view.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_decrypter.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_encrypter.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc", + "src/quiche/quic/core/crypto/chacha_base_decrypter.cc", + "src/quiche/quic/core/crypto/chacha_base_encrypter.cc", + "src/quiche/quic/core/crypto/channel_id.cc", + "src/quiche/quic/core/crypto/client_proof_source.cc", + "src/quiche/quic/core/crypto/crypto_framer.cc", + "src/quiche/quic/core/crypto/crypto_handshake.cc", + "src/quiche/quic/core/crypto/crypto_handshake_message.cc", + "src/quiche/quic/core/crypto/crypto_secret_boxer.cc", + "src/quiche/quic/core/crypto/crypto_utils.cc", + "src/quiche/quic/core/crypto/curve25519_key_exchange.cc", + "src/quiche/quic/core/crypto/key_exchange.cc", + "src/quiche/quic/core/crypto/null_decrypter.cc", + "src/quiche/quic/core/crypto/null_encrypter.cc", + "src/quiche/quic/core/crypto/p256_key_exchange.cc", + "src/quiche/quic/core/crypto/proof_source.cc", + "src/quiche/quic/core/crypto/proof_source_x509.cc", + "src/quiche/quic/core/crypto/quic_client_session_cache.cc", + "src/quiche/quic/core/crypto/quic_compressed_certs_cache.cc", + "src/quiche/quic/core/crypto/quic_crypter.cc", + "src/quiche/quic/core/crypto/quic_crypto_client_config.cc", + "src/quiche/quic/core/crypto/quic_crypto_proof.cc", + "src/quiche/quic/core/crypto/quic_crypto_server_config.cc", + "src/quiche/quic/core/crypto/quic_decrypter.cc", + "src/quiche/quic/core/crypto/quic_encrypter.cc", + "src/quiche/quic/core/crypto/quic_hkdf.cc", + "src/quiche/quic/core/crypto/tls_client_connection.cc", + "src/quiche/quic/core/crypto/tls_connection.cc", + "src/quiche/quic/core/crypto/tls_server_connection.cc", + "src/quiche/quic/core/crypto/transport_parameters.cc", + "src/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.cc", + "src/quiche/quic/core/deterministic_connection_id_generator.cc", + "src/quiche/quic/core/frames/quic_ack_frame.cc", + "src/quiche/quic/core/frames/quic_ack_frequency_frame.cc", + "src/quiche/quic/core/frames/quic_blocked_frame.cc", + "src/quiche/quic/core/frames/quic_connection_close_frame.cc", + "src/quiche/quic/core/frames/quic_crypto_frame.cc", + "src/quiche/quic/core/frames/quic_frame.cc", + "src/quiche/quic/core/frames/quic_goaway_frame.cc", + "src/quiche/quic/core/frames/quic_handshake_done_frame.cc", + "src/quiche/quic/core/frames/quic_max_streams_frame.cc", + "src/quiche/quic/core/frames/quic_message_frame.cc", + "src/quiche/quic/core/frames/quic_new_connection_id_frame.cc", + "src/quiche/quic/core/frames/quic_new_token_frame.cc", + "src/quiche/quic/core/frames/quic_padding_frame.cc", + "src/quiche/quic/core/frames/quic_path_challenge_frame.cc", + "src/quiche/quic/core/frames/quic_path_response_frame.cc", + "src/quiche/quic/core/frames/quic_ping_frame.cc", + "src/quiche/quic/core/frames/quic_retire_connection_id_frame.cc", + "src/quiche/quic/core/frames/quic_rst_stream_frame.cc", + "src/quiche/quic/core/frames/quic_stop_sending_frame.cc", + "src/quiche/quic/core/frames/quic_stop_waiting_frame.cc", + "src/quiche/quic/core/frames/quic_stream_frame.cc", + "src/quiche/quic/core/frames/quic_streams_blocked_frame.cc", + "src/quiche/quic/core/frames/quic_window_update_frame.cc", + "src/quiche/quic/core/http/http_constants.cc", + "src/quiche/quic/core/http/http_decoder.cc", + "src/quiche/quic/core/http/http_encoder.cc", + "src/quiche/quic/core/http/quic_client_promised_info.cc", + "src/quiche/quic/core/http/quic_client_push_promise_index.cc", + "src/quiche/quic/core/http/quic_header_list.cc", + "src/quiche/quic/core/http/quic_headers_stream.cc", + "src/quiche/quic/core/http/quic_receive_control_stream.cc", + "src/quiche/quic/core/http/quic_send_control_stream.cc", + "src/quiche/quic/core/http/quic_server_initiated_spdy_stream.cc", + "src/quiche/quic/core/http/quic_server_session_base.cc", + "src/quiche/quic/core/http/quic_spdy_client_session.cc", + "src/quiche/quic/core/http/quic_spdy_client_session_base.cc", + "src/quiche/quic/core/http/quic_spdy_client_stream.cc", + "src/quiche/quic/core/http/quic_spdy_server_stream_base.cc", + "src/quiche/quic/core/http/quic_spdy_session.cc", + "src/quiche/quic/core/http/quic_spdy_stream.cc", + "src/quiche/quic/core/http/quic_spdy_stream_body_manager.cc", + "src/quiche/quic/core/http/spdy_server_push_utils.cc", + "src/quiche/quic/core/http/spdy_utils.cc", + "src/quiche/quic/core/http/web_transport_http3.cc", + "src/quiche/quic/core/http/web_transport_stream_adapter.cc", + "src/quiche/quic/core/legacy_quic_stream_id_manager.cc", + "src/quiche/quic/core/qpack/qpack_blocking_manager.cc", + "src/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc", + "src/quiche/quic/core/qpack/qpack_decoder.cc", + "src/quiche/quic/core/qpack/qpack_decoder_stream_receiver.cc", + "src/quiche/quic/core/qpack/qpack_decoder_stream_sender.cc", + "src/quiche/quic/core/qpack/qpack_encoder.cc", + "src/quiche/quic/core/qpack/qpack_encoder_stream_receiver.cc", + "src/quiche/quic/core/qpack/qpack_encoder_stream_sender.cc", + "src/quiche/quic/core/qpack/qpack_header_table.cc", + "src/quiche/quic/core/qpack/qpack_index_conversions.cc", + "src/quiche/quic/core/qpack/qpack_instruction_decoder.cc", + "src/quiche/quic/core/qpack/qpack_instruction_encoder.cc", + "src/quiche/quic/core/qpack/qpack_instructions.cc", + "src/quiche/quic/core/qpack/qpack_progressive_decoder.cc", + "src/quiche/quic/core/qpack/qpack_receive_stream.cc", + "src/quiche/quic/core/qpack/qpack_required_insert_count.cc", + "src/quiche/quic/core/qpack/qpack_send_stream.cc", + "src/quiche/quic/core/qpack/qpack_static_table.cc", + "src/quiche/quic/core/qpack/value_splitting_header_list.cc", + "src/quiche/quic/core/quic_ack_listener_interface.cc", + "src/quiche/quic/core/quic_alarm.cc", + "src/quiche/quic/core/quic_bandwidth.cc", + "src/quiche/quic/core/quic_buffered_packet_store.cc", + "src/quiche/quic/core/quic_chaos_protector.cc", + "src/quiche/quic/core/quic_coalesced_packet.cc", + "src/quiche/quic/core/quic_config.cc", + "src/quiche/quic/core/quic_connection.cc", + "src/quiche/quic/core/quic_connection_context.cc", + "src/quiche/quic/core/quic_connection_id.cc", + "src/quiche/quic/core/quic_connection_id_manager.cc", + "src/quiche/quic/core/quic_connection_stats.cc", + "src/quiche/quic/core/quic_constants.cc", + "src/quiche/quic/core/quic_control_frame_manager.cc", + "src/quiche/quic/core/quic_crypto_client_handshaker.cc", + "src/quiche/quic/core/quic_crypto_client_stream.cc", + "src/quiche/quic/core/quic_crypto_handshaker.cc", + "src/quiche/quic/core/quic_crypto_server_stream.cc", + "src/quiche/quic/core/quic_crypto_server_stream_base.cc", + "src/quiche/quic/core/quic_crypto_stream.cc", + "src/quiche/quic/core/quic_data_reader.cc", + "src/quiche/quic/core/quic_data_writer.cc", + "src/quiche/quic/core/quic_datagram_queue.cc", + "src/quiche/quic/core/quic_default_clock.cc", + "src/quiche/quic/core/quic_dispatcher.cc", + "src/quiche/quic/core/quic_error_codes.cc", + "src/quiche/quic/core/quic_flow_controller.cc", + "src/quiche/quic/core/quic_framer.cc", + "src/quiche/quic/core/quic_idle_network_detector.cc", + "src/quiche/quic/core/quic_mtu_discovery.cc", + "src/quiche/quic/core/quic_network_blackhole_detector.cc", + "src/quiche/quic/core/quic_packet_creator.cc", + "src/quiche/quic/core/quic_packet_number.cc", + "src/quiche/quic/core/quic_packet_writer_wrapper.cc", + "src/quiche/quic/core/quic_packets.cc", + "src/quiche/quic/core/quic_path_validator.cc", + "src/quiche/quic/core/quic_ping_manager.cc", + "src/quiche/quic/core/quic_received_packet_manager.cc", + "src/quiche/quic/core/quic_sent_packet_manager.cc", + "src/quiche/quic/core/quic_server_id.cc", + "src/quiche/quic/core/quic_session.cc", + "src/quiche/quic/core/quic_socket_address_coder.cc", + "src/quiche/quic/core/quic_stream.cc", + "src/quiche/quic/core/quic_stream_id_manager.cc", + "src/quiche/quic/core/quic_stream_priority.cc", + "src/quiche/quic/core/quic_stream_send_buffer.cc", + "src/quiche/quic/core/quic_stream_sequencer.cc", + "src/quiche/quic/core/quic_stream_sequencer_buffer.cc", + "src/quiche/quic/core/quic_sustained_bandwidth_recorder.cc", + "src/quiche/quic/core/quic_tag.cc", + "src/quiche/quic/core/quic_time.cc", + "src/quiche/quic/core/quic_time_wait_list_manager.cc", + "src/quiche/quic/core/quic_trace_visitor.cc", + "src/quiche/quic/core/quic_transmission_info.cc", + "src/quiche/quic/core/quic_types.cc", + "src/quiche/quic/core/quic_unacked_packet_map.cc", + "src/quiche/quic/core/quic_utils.cc", + "src/quiche/quic/core/quic_version_manager.cc", + "src/quiche/quic/core/quic_versions.cc", + "src/quiche/quic/core/quic_write_blocked_list.cc", + "src/quiche/quic/core/tls_chlo_extractor.cc", + "src/quiche/quic/core/tls_client_handshaker.cc", + "src/quiche/quic/core/tls_handshaker.cc", + "src/quiche/quic/core/tls_server_handshaker.cc", + "src/quiche/quic/core/uber_quic_stream_id_manager.cc", + "src/quiche/quic/core/uber_received_packet_manager.cc", + "src/quiche/quic/platform/api/quic_socket_address.cc", + "src/quiche/spdy/core/array_output_buffer.cc", + "src/quiche/spdy/core/hpack/hpack_constants.cc", + "src/quiche/spdy/core/hpack/hpack_decoder_adapter.cc", + "src/quiche/spdy/core/hpack/hpack_encoder.cc", + "src/quiche/spdy/core/hpack/hpack_entry.cc", + "src/quiche/spdy/core/hpack/hpack_header_table.cc", + "src/quiche/spdy/core/hpack/hpack_output_stream.cc", + "src/quiche/spdy/core/hpack/hpack_static_table.cc", + "src/quiche/spdy/core/http2_frame_decoder_adapter.cc", + "src/quiche/spdy/core/http2_header_block.cc", + "src/quiche/spdy/core/http2_header_storage.cc", + "src/quiche/spdy/core/metadata_extension.cc", + "src/quiche/spdy/core/recording_headers_handler.cc", + "src/quiche/spdy/core/spdy_alt_svc_wire_format.cc", + "src/quiche/spdy/core/spdy_frame_builder.cc", + "src/quiche/spdy/core/spdy_framer.cc", + "src/quiche/spdy/core/spdy_no_op_visitor.cc", + "src/quiche/spdy/core/spdy_pinnable_buffer_piece.cc", + "src/quiche/spdy/core/spdy_prefixed_buffer_reader.cc", + "src/quiche/spdy/core/spdy_protocol.cc", + "src/quiche/spdy/core/spdy_simple_arena.cc", +] +quiche_tool_support_hdrs = [ + "src/quiche/common/platform/api/quiche_command_line_flags.h", + "src/quiche/common/platform/api/quiche_default_proof_providers.h", + "src/quiche/common/platform/api/quiche_file_utils.h", + "src/quiche/common/platform/api/quiche_system_event_loop.h", + "src/quiche/quic/platform/api/quic_default_proof_providers.h", + "src/quiche/quic/tools/connect_server_backend.h", + "src/quiche/quic/tools/connect_tunnel.h", + "src/quiche/quic/tools/connect_udp_tunnel.h", + "src/quiche/quic/tools/fake_proof_verifier.h", + "src/quiche/quic/tools/quic_backend_response.h", + "src/quiche/quic/tools/quic_client_base.h", + "src/quiche/quic/tools/quic_memory_cache_backend.h", + "src/quiche/quic/tools/quic_name_lookup.h", + "src/quiche/quic/tools/quic_simple_client_session.h", + "src/quiche/quic/tools/quic_simple_client_stream.h", + "src/quiche/quic/tools/quic_simple_crypto_server_stream_helper.h", + "src/quiche/quic/tools/quic_simple_dispatcher.h", + "src/quiche/quic/tools/quic_simple_server_backend.h", + "src/quiche/quic/tools/quic_simple_server_session.h", + "src/quiche/quic/tools/quic_simple_server_stream.h", + "src/quiche/quic/tools/quic_spdy_client_base.h", + "src/quiche/quic/tools/quic_spdy_server_base.h", + "src/quiche/quic/tools/quic_tcp_like_trace_converter.h", + "src/quiche/quic/tools/quic_url.h", + "src/quiche/quic/tools/simple_ticket_crypter.h", + "src/quiche/quic/tools/web_transport_test_visitors.h", +] +quiche_tool_support_srcs = [ + "src/quiche/common/platform/api/quiche_file_utils.cc", + "src/quiche/quic/tools/connect_server_backend.cc", + "src/quiche/quic/tools/connect_tunnel.cc", + "src/quiche/quic/tools/connect_udp_tunnel.cc", + "src/quiche/quic/tools/quic_backend_response.cc", + "src/quiche/quic/tools/quic_client_base.cc", + "src/quiche/quic/tools/quic_memory_cache_backend.cc", + "src/quiche/quic/tools/quic_name_lookup.cc", + "src/quiche/quic/tools/quic_simple_client_session.cc", + "src/quiche/quic/tools/quic_simple_client_stream.cc", + "src/quiche/quic/tools/quic_simple_crypto_server_stream_helper.cc", + "src/quiche/quic/tools/quic_simple_dispatcher.cc", + "src/quiche/quic/tools/quic_simple_server_session.cc", + "src/quiche/quic/tools/quic_simple_server_stream.cc", + "src/quiche/quic/tools/quic_spdy_client_base.cc", + "src/quiche/quic/tools/quic_tcp_like_trace_converter.cc", + "src/quiche/quic/tools/quic_url.cc", + "src/quiche/quic/tools/simple_ticket_crypter.cc", +] +quiche_test_support_hdrs = [ + "src/quiche/common/platform/api/quiche_expect_bug.h", + "src/quiche/common/platform/api/quiche_test.h", + "src/quiche/common/platform/api/quiche_test_loopback.h", + "src/quiche/common/platform/api/quiche_test_output.h", + "src/quiche/common/test_tools/quiche_test_utils.h", + "src/quiche/http2/adapter/mock_http2_visitor.h", + "src/quiche/http2/adapter/recording_http2_visitor.h", + "src/quiche/http2/adapter/test_frame_sequence.h", + "src/quiche/http2/adapter/test_utils.h", + "src/quiche/http2/test_tools/frame_decoder_state_test_util.h", + "src/quiche/http2/test_tools/frame_parts.h", + "src/quiche/http2/test_tools/frame_parts_collector.h", + "src/quiche/http2/test_tools/frame_parts_collector_listener.h", + "src/quiche/http2/test_tools/hpack_block_builder.h", + "src/quiche/http2/test_tools/hpack_block_collector.h", + "src/quiche/http2/test_tools/hpack_entry_collector.h", + "src/quiche/http2/test_tools/hpack_example.h", + "src/quiche/http2/test_tools/hpack_string_collector.h", + "src/quiche/http2/test_tools/http2_constants_test_util.h", + "src/quiche/http2/test_tools/http2_frame_builder.h", + "src/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h", + "src/quiche/http2/test_tools/http2_random.h", + "src/quiche/http2/test_tools/http2_structure_decoder_test_util.h", + "src/quiche/http2/test_tools/http2_structures_test_util.h", + "src/quiche/http2/test_tools/payload_decoder_base_test_util.h", + "src/quiche/http2/test_tools/random_decoder_test_base.h", + "src/quiche/http2/test_tools/random_util.h", + "src/quiche/http2/test_tools/verify_macros.h", + "src/quiche/quic/platform/api/quic_expect_bug.h", + "src/quiche/quic/platform/api/quic_test.h", + "src/quiche/quic/platform/api/quic_test_loopback.h", + "src/quiche/quic/platform/api/quic_test_output.h", + "src/quiche/quic/test_tools/bad_packet_writer.h", + "src/quiche/quic/test_tools/crypto_test_utils.h", + "src/quiche/quic/test_tools/failing_proof_source.h", + "src/quiche/quic/test_tools/fake_proof_source.h", + "src/quiche/quic/test_tools/fake_proof_source_handle.h", + "src/quiche/quic/test_tools/first_flight.h", + "src/quiche/quic/test_tools/limited_mtu_test_writer.h", + "src/quiche/quic/test_tools/mock_clock.h", + "src/quiche/quic/test_tools/mock_connection_id_generator.h", + "src/quiche/quic/test_tools/mock_quic_client_promised_info.h", + "src/quiche/quic/test_tools/mock_quic_dispatcher.h", + "src/quiche/quic/test_tools/mock_quic_session_visitor.h", + "src/quiche/quic/test_tools/mock_quic_spdy_client_stream.h", + "src/quiche/quic/test_tools/mock_quic_time_wait_list_manager.h", + "src/quiche/quic/test_tools/mock_random.h", + "src/quiche/quic/test_tools/packet_dropping_test_writer.h", + "src/quiche/quic/test_tools/packet_reordering_writer.h", + "src/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h", + "src/quiche/quic/test_tools/qpack/qpack_encoder_peer.h", + "src/quiche/quic/test_tools/qpack/qpack_offline_decoder.h", + "src/quiche/quic/test_tools/qpack/qpack_test_utils.h", + "src/quiche/quic/test_tools/quic_buffered_packet_store_peer.h", + "src/quiche/quic/test_tools/quic_client_promised_info_peer.h", + "src/quiche/quic/test_tools/quic_client_session_cache_peer.h", + "src/quiche/quic/test_tools/quic_coalesced_packet_peer.h", + "src/quiche/quic/test_tools/quic_config_peer.h", + "src/quiche/quic/test_tools/quic_connection_id_manager_peer.h", + "src/quiche/quic/test_tools/quic_connection_peer.h", + "src/quiche/quic/test_tools/quic_crypto_server_config_peer.h", + "src/quiche/quic/test_tools/quic_dispatcher_peer.h", + "src/quiche/quic/test_tools/quic_flow_controller_peer.h", + "src/quiche/quic/test_tools/quic_framer_peer.h", + "src/quiche/quic/test_tools/quic_interval_deque_peer.h", + "src/quiche/quic/test_tools/quic_packet_creator_peer.h", + "src/quiche/quic/test_tools/quic_path_validator_peer.h", + "src/quiche/quic/test_tools/quic_sent_packet_manager_peer.h", + "src/quiche/quic/test_tools/quic_server_session_base_peer.h", + "src/quiche/quic/test_tools/quic_session_peer.h", + "src/quiche/quic/test_tools/quic_spdy_session_peer.h", + "src/quiche/quic/test_tools/quic_spdy_stream_peer.h", + "src/quiche/quic/test_tools/quic_stream_id_manager_peer.h", + "src/quiche/quic/test_tools/quic_stream_peer.h", + "src/quiche/quic/test_tools/quic_stream_send_buffer_peer.h", + "src/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h", + "src/quiche/quic/test_tools/quic_stream_sequencer_peer.h", + "src/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h", + "src/quiche/quic/test_tools/quic_test_backend.h", + "src/quiche/quic/test_tools/quic_test_utils.h", + "src/quiche/quic/test_tools/quic_time_wait_list_manager_peer.h", + "src/quiche/quic/test_tools/quic_unacked_packet_map_peer.h", + "src/quiche/quic/test_tools/rtt_stats_peer.h", + "src/quiche/quic/test_tools/send_algorithm_test_utils.h", + "src/quiche/quic/test_tools/simple_data_producer.h", + "src/quiche/quic/test_tools/simple_quic_framer.h", + "src/quiche/quic/test_tools/simple_session_cache.h", + "src/quiche/quic/test_tools/simple_session_notifier.h", + "src/quiche/quic/test_tools/simulator/actor.h", + "src/quiche/quic/test_tools/simulator/alarm_factory.h", + "src/quiche/quic/test_tools/simulator/link.h", + "src/quiche/quic/test_tools/simulator/packet_filter.h", + "src/quiche/quic/test_tools/simulator/port.h", + "src/quiche/quic/test_tools/simulator/queue.h", + "src/quiche/quic/test_tools/simulator/quic_endpoint.h", + "src/quiche/quic/test_tools/simulator/quic_endpoint_base.h", + "src/quiche/quic/test_tools/simulator/simulator.h", + "src/quiche/quic/test_tools/simulator/switch.h", + "src/quiche/quic/test_tools/simulator/test_harness.h", + "src/quiche/quic/test_tools/simulator/traffic_policer.h", + "src/quiche/quic/test_tools/test_certificates.h", + "src/quiche/quic/test_tools/test_ticket_crypter.h", + "src/quiche/quic/test_tools/web_transport_resets_backend.h", + "src/quiche/quic/test_tools/web_transport_test_tools.h", + "src/quiche/spdy/test_tools/mock_spdy_framer_visitor.h", + "src/quiche/spdy/test_tools/spdy_test_utils.h", + "src/quiche/web_transport/test_tools/mock_web_transport.h", +] +quiche_test_support_srcs = [ + "src/quiche/common/platform/api/quiche_test_loopback.cc", + "src/quiche/common/test_tools/quiche_test_utils.cc", + "src/quiche/http2/adapter/recording_http2_visitor.cc", + "src/quiche/http2/adapter/test_frame_sequence.cc", + "src/quiche/http2/adapter/test_utils.cc", + "src/quiche/http2/test_tools/frame_decoder_state_test_util.cc", + "src/quiche/http2/test_tools/frame_parts.cc", + "src/quiche/http2/test_tools/frame_parts_collector.cc", + "src/quiche/http2/test_tools/frame_parts_collector_listener.cc", + "src/quiche/http2/test_tools/hpack_block_builder.cc", + "src/quiche/http2/test_tools/hpack_block_collector.cc", + "src/quiche/http2/test_tools/hpack_entry_collector.cc", + "src/quiche/http2/test_tools/hpack_example.cc", + "src/quiche/http2/test_tools/hpack_string_collector.cc", + "src/quiche/http2/test_tools/http2_constants_test_util.cc", + "src/quiche/http2/test_tools/http2_frame_builder.cc", + "src/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.cc", + "src/quiche/http2/test_tools/http2_random.cc", + "src/quiche/http2/test_tools/http2_structure_decoder_test_util.cc", + "src/quiche/http2/test_tools/http2_structures_test_util.cc", + "src/quiche/http2/test_tools/payload_decoder_base_test_util.cc", + "src/quiche/http2/test_tools/random_decoder_test_base.cc", + "src/quiche/http2/test_tools/random_util.cc", + "src/quiche/quic/test_tools/bad_packet_writer.cc", + "src/quiche/quic/test_tools/crypto_test_utils.cc", + "src/quiche/quic/test_tools/failing_proof_source.cc", + "src/quiche/quic/test_tools/fake_proof_source.cc", + "src/quiche/quic/test_tools/fake_proof_source_handle.cc", + "src/quiche/quic/test_tools/first_flight.cc", + "src/quiche/quic/test_tools/limited_mtu_test_writer.cc", + "src/quiche/quic/test_tools/mock_clock.cc", + "src/quiche/quic/test_tools/mock_quic_client_promised_info.cc", + "src/quiche/quic/test_tools/mock_quic_dispatcher.cc", + "src/quiche/quic/test_tools/mock_quic_session_visitor.cc", + "src/quiche/quic/test_tools/mock_quic_spdy_client_stream.cc", + "src/quiche/quic/test_tools/mock_quic_time_wait_list_manager.cc", + "src/quiche/quic/test_tools/mock_random.cc", + "src/quiche/quic/test_tools/packet_dropping_test_writer.cc", + "src/quiche/quic/test_tools/packet_reordering_writer.cc", + "src/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc", + "src/quiche/quic/test_tools/qpack/qpack_encoder_peer.cc", + "src/quiche/quic/test_tools/qpack/qpack_offline_decoder.cc", + "src/quiche/quic/test_tools/qpack/qpack_test_utils.cc", + "src/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc", + "src/quiche/quic/test_tools/quic_client_promised_info_peer.cc", + "src/quiche/quic/test_tools/quic_coalesced_packet_peer.cc", + "src/quiche/quic/test_tools/quic_config_peer.cc", + "src/quiche/quic/test_tools/quic_connection_peer.cc", + "src/quiche/quic/test_tools/quic_crypto_server_config_peer.cc", + "src/quiche/quic/test_tools/quic_dispatcher_peer.cc", + "src/quiche/quic/test_tools/quic_flow_controller_peer.cc", + "src/quiche/quic/test_tools/quic_framer_peer.cc", + "src/quiche/quic/test_tools/quic_packet_creator_peer.cc", + "src/quiche/quic/test_tools/quic_path_validator_peer.cc", + "src/quiche/quic/test_tools/quic_sent_packet_manager_peer.cc", + "src/quiche/quic/test_tools/quic_session_peer.cc", + "src/quiche/quic/test_tools/quic_spdy_session_peer.cc", + "src/quiche/quic/test_tools/quic_spdy_stream_peer.cc", + "src/quiche/quic/test_tools/quic_stream_id_manager_peer.cc", + "src/quiche/quic/test_tools/quic_stream_peer.cc", + "src/quiche/quic/test_tools/quic_stream_send_buffer_peer.cc", + "src/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc", + "src/quiche/quic/test_tools/quic_stream_sequencer_peer.cc", + "src/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc", + "src/quiche/quic/test_tools/quic_test_backend.cc", + "src/quiche/quic/test_tools/quic_test_utils.cc", + "src/quiche/quic/test_tools/quic_time_wait_list_manager_peer.cc", + "src/quiche/quic/test_tools/quic_unacked_packet_map_peer.cc", + "src/quiche/quic/test_tools/rtt_stats_peer.cc", + "src/quiche/quic/test_tools/send_algorithm_test_utils.cc", + "src/quiche/quic/test_tools/simple_data_producer.cc", + "src/quiche/quic/test_tools/simple_quic_framer.cc", + "src/quiche/quic/test_tools/simple_session_cache.cc", + "src/quiche/quic/test_tools/simple_session_notifier.cc", + "src/quiche/quic/test_tools/simulator/actor.cc", + "src/quiche/quic/test_tools/simulator/alarm_factory.cc", + "src/quiche/quic/test_tools/simulator/link.cc", + "src/quiche/quic/test_tools/simulator/packet_filter.cc", + "src/quiche/quic/test_tools/simulator/port.cc", + "src/quiche/quic/test_tools/simulator/queue.cc", + "src/quiche/quic/test_tools/simulator/quic_endpoint.cc", + "src/quiche/quic/test_tools/simulator/quic_endpoint_base.cc", + "src/quiche/quic/test_tools/simulator/simulator.cc", + "src/quiche/quic/test_tools/simulator/switch.cc", + "src/quiche/quic/test_tools/simulator/test_harness.cc", + "src/quiche/quic/test_tools/simulator/traffic_policer.cc", + "src/quiche/quic/test_tools/test_certificates.cc", + "src/quiche/quic/test_tools/test_ticket_crypter.cc", + "src/quiche/quic/test_tools/web_transport_resets_backend.cc", + "src/quiche/spdy/test_tools/mock_spdy_framer_visitor.cc", + "src/quiche/spdy/test_tools/spdy_test_utils.cc", +] +io_tool_support_hdrs = [ + "src/quiche/common/platform/api/quiche_event_loop.h", + "src/quiche/common/platform/api/quiche_udp_socket_platform_api.h", + "src/quiche/quic/core/io/event_loop_connecting_client_socket.h", + "src/quiche/quic/core/io/event_loop_socket_factory.h", + "src/quiche/quic/core/io/quic_default_event_loop.h", + "src/quiche/quic/core/io/quic_event_loop.h", + "src/quiche/quic/core/io/quic_poll_event_loop.h", + "src/quiche/quic/core/io/socket.h", + "src/quiche/quic/core/quic_default_packet_writer.h", + "src/quiche/quic/core/quic_packet_reader.h", + "src/quiche/quic/core/quic_syscall_wrapper.h", + "src/quiche/quic/core/quic_udp_socket.h", + "src/quiche/quic/masque/masque_client.h", + "src/quiche/quic/masque/masque_client_session.h", + "src/quiche/quic/masque/masque_client_tools.h", + "src/quiche/quic/masque/masque_dispatcher.h", + "src/quiche/quic/masque/masque_encapsulated_client.h", + "src/quiche/quic/masque/masque_encapsulated_client_session.h", + "src/quiche/quic/masque/masque_server.h", + "src/quiche/quic/masque/masque_server_backend.h", + "src/quiche/quic/masque/masque_server_session.h", + "src/quiche/quic/masque/masque_utils.h", + "src/quiche/quic/platform/api/quic_udp_socket_platform_api.h", + "src/quiche/quic/tools/quic_client_default_network_helper.h", + "src/quiche/quic/tools/quic_client_factory.h", + "src/quiche/quic/tools/quic_default_client.h", + "src/quiche/quic/tools/quic_epoll_client_factory.h", + "src/quiche/quic/tools/quic_server.h", +] +io_tool_support_srcs = [ + "src/quiche/quic/core/io/event_loop_connecting_client_socket.cc", + "src/quiche/quic/core/io/event_loop_socket_factory.cc", + "src/quiche/quic/core/io/quic_default_event_loop.cc", + "src/quiche/quic/core/io/quic_poll_event_loop.cc", + "src/quiche/quic/core/io/socket_posix.cc", + "src/quiche/quic/core/quic_default_packet_writer.cc", + "src/quiche/quic/core/quic_packet_reader.cc", + "src/quiche/quic/core/quic_syscall_wrapper.cc", + "src/quiche/quic/core/quic_udp_socket_posix.cc", + "src/quiche/quic/masque/masque_client.cc", + "src/quiche/quic/masque/masque_client_session.cc", + "src/quiche/quic/masque/masque_client_tools.cc", + "src/quiche/quic/masque/masque_dispatcher.cc", + "src/quiche/quic/masque/masque_encapsulated_client.cc", + "src/quiche/quic/masque/masque_encapsulated_client_session.cc", + "src/quiche/quic/masque/masque_server.cc", + "src/quiche/quic/masque/masque_server_backend.cc", + "src/quiche/quic/masque/masque_server_session.cc", + "src/quiche/quic/masque/masque_utils.cc", + "src/quiche/quic/tools/quic_client_default_network_helper.cc", + "src/quiche/quic/tools/quic_default_client.cc", + "src/quiche/quic/tools/quic_epoll_client_factory.cc", + "src/quiche/quic/tools/quic_server.cc", +] +io_test_support_hdrs = [ + "src/quiche/quic/test_tools/quic_mock_syscall_wrapper.h", + "src/quiche/quic/test_tools/quic_server_peer.h", + "src/quiche/quic/test_tools/quic_test_client.h", + "src/quiche/quic/test_tools/quic_test_server.h", + "src/quiche/quic/test_tools/server_thread.h", +] +io_test_support_srcs = [ + "src/quiche/quic/test_tools/quic_mock_syscall_wrapper.cc", + "src/quiche/quic/test_tools/quic_server_peer.cc", + "src/quiche/quic/test_tools/quic_test_client.cc", + "src/quiche/quic/test_tools/quic_test_server.cc", + "src/quiche/quic/test_tools/server_thread.cc", +] +quiche_tests_hdrs = [ + +] +quiche_tests_srcs = [ + "src/quiche/balsa/balsa_frame_test.cc", + "src/quiche/balsa/balsa_headers_test.cc", + "src/quiche/balsa/header_properties_test.cc", + "src/quiche/balsa/simple_buffer_test.cc", + "src/quiche/binary_http/binary_http_message_test.cc", + "src/quiche/common/btree_scheduler_test.cc", + "src/quiche/common/capsule_test.cc", + "src/quiche/common/masque/connect_udp_datagram_payload_test.cc", + "src/quiche/common/platform/api/quiche_file_utils_test.cc", + "src/quiche/common/platform/api/quiche_hostname_utils_test.cc", + "src/quiche/common/platform/api/quiche_lower_case_string_test.cc", + "src/quiche/common/platform/api/quiche_mem_slice_test.cc", + "src/quiche/common/platform/api/quiche_reference_counted_test.cc", + "src/quiche/common/platform/api/quiche_stack_trace_test.cc", + "src/quiche/common/platform/api/quiche_time_utils_test.cc", + "src/quiche/common/platform/api/quiche_url_utils_test.cc", + "src/quiche/common/print_elements_test.cc", + "src/quiche/common/quiche_buffer_allocator_test.cc", + "src/quiche/common/quiche_circular_deque_test.cc", + "src/quiche/common/quiche_data_reader_test.cc", + "src/quiche/common/quiche_data_writer_test.cc", + "src/quiche/common/quiche_endian_test.cc", + "src/quiche/common/quiche_ip_address_test.cc", + "src/quiche/common/quiche_linked_hash_map_test.cc", + "src/quiche/common/quiche_mem_slice_storage_test.cc", + "src/quiche/common/quiche_random_test.cc", + "src/quiche/common/quiche_text_utils_test.cc", + "src/quiche/common/simple_buffer_allocator_test.cc", + "src/quiche/common/structured_headers_generated_test.cc", + "src/quiche/common/structured_headers_test.cc", + "src/quiche/common/test_tools/quiche_test_utils_test.cc", + "src/quiche/common/wire_serialization_test.cc", + "src/quiche/http2/adapter/event_forwarder_test.cc", + "src/quiche/http2/adapter/header_validator_test.cc", + "src/quiche/http2/adapter/noop_header_validator_test.cc", + "src/quiche/http2/adapter/oghttp2_adapter_test.cc", + "src/quiche/http2/adapter/oghttp2_session_test.cc", + "src/quiche/http2/adapter/oghttp2_util_test.cc", + "src/quiche/http2/adapter/recording_http2_visitor_test.cc", + "src/quiche/http2/adapter/test_utils_test.cc", + "src/quiche/http2/adapter/window_manager_test.cc", + "src/quiche/http2/core/priority_write_scheduler_test.cc", + "src/quiche/http2/decoder/decode_buffer_test.cc", + "src/quiche/http2/decoder/decode_http2_structures_test.cc", + "src/quiche/http2/decoder/http2_frame_decoder_test.cc", + "src/quiche/http2/decoder/http2_structure_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/continuation_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/data_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/goaway_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/headers_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/ping_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/priority_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/settings_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/unknown_payload_decoder_test.cc", + "src/quiche/http2/decoder/payload_decoders/window_update_payload_decoder_test.cc", + "src/quiche/http2/hpack/decoder/hpack_block_collector_test.cc", + "src/quiche/http2/hpack/decoder/hpack_block_decoder_test.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_state_test.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_tables_test.cc", + "src/quiche/http2/hpack/decoder/hpack_decoder_test.cc", + "src/quiche/http2/hpack/decoder/hpack_entry_collector_test.cc", + "src/quiche/http2/hpack/decoder/hpack_entry_decoder_test.cc", + "src/quiche/http2/hpack/decoder/hpack_entry_type_decoder_test.cc", + "src/quiche/http2/hpack/decoder/hpack_string_decoder_test.cc", + "src/quiche/http2/hpack/decoder/hpack_whole_entry_buffer_test.cc", + "src/quiche/http2/hpack/http2_hpack_constants_test.cc", + "src/quiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc", + "src/quiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc", + "src/quiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc", + "src/quiche/http2/hpack/varint/hpack_varint_decoder_test.cc", + "src/quiche/http2/hpack/varint/hpack_varint_encoder_test.cc", + "src/quiche/http2/hpack/varint/hpack_varint_round_trip_test.cc", + "src/quiche/http2/http2_constants_test.cc", + "src/quiche/http2/http2_structures_test.cc", + "src/quiche/http2/test_tools/hpack_block_builder_test.cc", + "src/quiche/http2/test_tools/hpack_example_test.cc", + "src/quiche/http2/test_tools/http2_frame_builder_test.cc", + "src/quiche/http2/test_tools/http2_random_test.cc", + "src/quiche/http2/test_tools/random_decoder_test_base_test.cc", + "src/quiche/oblivious_http/buffers/oblivious_http_integration_test.cc", + "src/quiche/oblivious_http/buffers/oblivious_http_request_test.cc", + "src/quiche/oblivious_http/buffers/oblivious_http_response_test.cc", + "src/quiche/oblivious_http/common/oblivious_http_header_key_config_test.cc", + "src/quiche/oblivious_http/oblivious_http_client_test.cc", + "src/quiche/oblivious_http/oblivious_http_gateway_test.cc", + "src/quiche/quic/core/congestion_control/bandwidth_sampler_test.cc", + "src/quiche/quic/core/congestion_control/bbr2_simulator_test.cc", + "src/quiche/quic/core/congestion_control/bbr_sender_test.cc", + "src/quiche/quic/core/congestion_control/cubic_bytes_test.cc", + "src/quiche/quic/core/congestion_control/general_loss_algorithm_test.cc", + "src/quiche/quic/core/congestion_control/hybrid_slow_start_test.cc", + "src/quiche/quic/core/congestion_control/pacing_sender_test.cc", + "src/quiche/quic/core/congestion_control/prr_sender_test.cc", + "src/quiche/quic/core/congestion_control/rtt_stats_test.cc", + "src/quiche/quic/core/congestion_control/send_algorithm_test.cc", + "src/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc", + "src/quiche/quic/core/congestion_control/uber_loss_algorithm_test.cc", + "src/quiche/quic/core/congestion_control/windowed_filter_test.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc", + "src/quiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc", + "src/quiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc", + "src/quiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc", + "src/quiche/quic/core/crypto/cert_compressor_test.cc", + "src/quiche/quic/core/crypto/certificate_util_test.cc", + "src/quiche/quic/core/crypto/certificate_view_test.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc", + "src/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc", + "src/quiche/quic/core/crypto/channel_id_test.cc", + "src/quiche/quic/core/crypto/client_proof_source_test.cc", + "src/quiche/quic/core/crypto/crypto_framer_test.cc", + "src/quiche/quic/core/crypto/crypto_handshake_message_test.cc", + "src/quiche/quic/core/crypto/crypto_secret_boxer_test.cc", + "src/quiche/quic/core/crypto/crypto_server_test.cc", + "src/quiche/quic/core/crypto/crypto_utils_test.cc", + "src/quiche/quic/core/crypto/curve25519_key_exchange_test.cc", + "src/quiche/quic/core/crypto/null_decrypter_test.cc", + "src/quiche/quic/core/crypto/null_encrypter_test.cc", + "src/quiche/quic/core/crypto/p256_key_exchange_test.cc", + "src/quiche/quic/core/crypto/proof_source_x509_test.cc", + "src/quiche/quic/core/crypto/quic_client_session_cache_test.cc", + "src/quiche/quic/core/crypto/quic_compressed_certs_cache_test.cc", + "src/quiche/quic/core/crypto/quic_crypto_client_config_test.cc", + "src/quiche/quic/core/crypto/quic_crypto_server_config_test.cc", + "src/quiche/quic/core/crypto/quic_hkdf_test.cc", + "src/quiche/quic/core/crypto/transport_parameters_test.cc", + "src/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc", + "src/quiche/quic/core/deterministic_connection_id_generator_test.cc", + "src/quiche/quic/core/frames/quic_frames_test.cc", + "src/quiche/quic/core/http/http_decoder_test.cc", + "src/quiche/quic/core/http/http_encoder_test.cc", + "src/quiche/quic/core/http/http_frames_test.cc", + "src/quiche/quic/core/http/quic_client_promised_info_test.cc", + "src/quiche/quic/core/http/quic_client_push_promise_index_test.cc", + "src/quiche/quic/core/http/quic_header_list_test.cc", + "src/quiche/quic/core/http/quic_headers_stream_test.cc", + "src/quiche/quic/core/http/quic_receive_control_stream_test.cc", + "src/quiche/quic/core/http/quic_send_control_stream_test.cc", + "src/quiche/quic/core/http/quic_server_session_base_test.cc", + "src/quiche/quic/core/http/quic_spdy_session_test.cc", + "src/quiche/quic/core/http/quic_spdy_stream_body_manager_test.cc", + "src/quiche/quic/core/http/quic_spdy_stream_test.cc", + "src/quiche/quic/core/http/spdy_server_push_utils_test.cc", + "src/quiche/quic/core/http/spdy_utils_test.cc", + "src/quiche/quic/core/http/web_transport_http3_test.cc", + "src/quiche/quic/core/legacy_quic_stream_id_manager_test.cc", + "src/quiche/quic/core/packet_number_indexed_queue_test.cc", + "src/quiche/quic/core/qpack/qpack_blocking_manager_test.cc", + "src/quiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc", + "src/quiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc", + "src/quiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc", + "src/quiche/quic/core/qpack/qpack_decoder_test.cc", + "src/quiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc", + "src/quiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc", + "src/quiche/quic/core/qpack/qpack_encoder_test.cc", + "src/quiche/quic/core/qpack/qpack_header_table_test.cc", + "src/quiche/quic/core/qpack/qpack_index_conversions_test.cc", + "src/quiche/quic/core/qpack/qpack_instruction_decoder_test.cc", + "src/quiche/quic/core/qpack/qpack_instruction_encoder_test.cc", + "src/quiche/quic/core/qpack/qpack_receive_stream_test.cc", + "src/quiche/quic/core/qpack/qpack_required_insert_count_test.cc", + "src/quiche/quic/core/qpack/qpack_round_trip_test.cc", + "src/quiche/quic/core/qpack/qpack_send_stream_test.cc", + "src/quiche/quic/core/qpack/qpack_static_table_test.cc", + "src/quiche/quic/core/qpack/value_splitting_header_list_test.cc", + "src/quiche/quic/core/quic_alarm_test.cc", + "src/quiche/quic/core/quic_arena_scoped_ptr_test.cc", + "src/quiche/quic/core/quic_bandwidth_test.cc", + "src/quiche/quic/core/quic_buffered_packet_store_test.cc", + "src/quiche/quic/core/quic_chaos_protector_test.cc", + "src/quiche/quic/core/quic_coalesced_packet_test.cc", + "src/quiche/quic/core/quic_config_test.cc", + "src/quiche/quic/core/quic_connection_context_test.cc", + "src/quiche/quic/core/quic_connection_id_manager_test.cc", + "src/quiche/quic/core/quic_connection_id_test.cc", + "src/quiche/quic/core/quic_connection_test.cc", + "src/quiche/quic/core/quic_control_frame_manager_test.cc", + "src/quiche/quic/core/quic_crypto_client_handshaker_test.cc", + "src/quiche/quic/core/quic_crypto_client_stream_test.cc", + "src/quiche/quic/core/quic_crypto_server_stream_test.cc", + "src/quiche/quic/core/quic_crypto_stream_test.cc", + "src/quiche/quic/core/quic_data_writer_test.cc", + "src/quiche/quic/core/quic_datagram_queue_test.cc", + "src/quiche/quic/core/quic_dispatcher_test.cc", + "src/quiche/quic/core/quic_error_codes_test.cc", + "src/quiche/quic/core/quic_flow_controller_test.cc", + "src/quiche/quic/core/quic_framer_test.cc", + "src/quiche/quic/core/quic_idle_network_detector_test.cc", + "src/quiche/quic/core/quic_interval_deque_test.cc", + "src/quiche/quic/core/quic_interval_set_test.cc", + "src/quiche/quic/core/quic_interval_test.cc", + "src/quiche/quic/core/quic_lru_cache_test.cc", + "src/quiche/quic/core/quic_network_blackhole_detector_test.cc", + "src/quiche/quic/core/quic_one_block_arena_test.cc", + "src/quiche/quic/core/quic_packet_creator_test.cc", + "src/quiche/quic/core/quic_packet_number_test.cc", + "src/quiche/quic/core/quic_packets_test.cc", + "src/quiche/quic/core/quic_path_validator_test.cc", + "src/quiche/quic/core/quic_ping_manager_test.cc", + "src/quiche/quic/core/quic_received_packet_manager_test.cc", + "src/quiche/quic/core/quic_sent_packet_manager_test.cc", + "src/quiche/quic/core/quic_server_id_test.cc", + "src/quiche/quic/core/quic_session_test.cc", + "src/quiche/quic/core/quic_socket_address_coder_test.cc", + "src/quiche/quic/core/quic_stream_id_manager_test.cc", + "src/quiche/quic/core/quic_stream_priority_test.cc", + "src/quiche/quic/core/quic_stream_send_buffer_test.cc", + "src/quiche/quic/core/quic_stream_sequencer_buffer_test.cc", + "src/quiche/quic/core/quic_stream_sequencer_test.cc", + "src/quiche/quic/core/quic_stream_test.cc", + "src/quiche/quic/core/quic_sustained_bandwidth_recorder_test.cc", + "src/quiche/quic/core/quic_tag_test.cc", + "src/quiche/quic/core/quic_time_accumulator_test.cc", + "src/quiche/quic/core/quic_time_test.cc", + "src/quiche/quic/core/quic_time_wait_list_manager_test.cc", + "src/quiche/quic/core/quic_trace_visitor_test.cc", + "src/quiche/quic/core/quic_unacked_packet_map_test.cc", + "src/quiche/quic/core/quic_utils_test.cc", + "src/quiche/quic/core/quic_version_manager_test.cc", + "src/quiche/quic/core/quic_versions_test.cc", + "src/quiche/quic/core/quic_write_blocked_list_test.cc", + "src/quiche/quic/core/tls_chlo_extractor_test.cc", + "src/quiche/quic/core/tls_client_handshaker_test.cc", + "src/quiche/quic/core/tls_server_handshaker_test.cc", + "src/quiche/quic/core/uber_quic_stream_id_manager_test.cc", + "src/quiche/quic/core/uber_received_packet_manager_test.cc", + "src/quiche/quic/platform/api/quic_socket_address_test.cc", + "src/quiche/quic/test_tools/crypto_test_utils_test.cc", + "src/quiche/quic/test_tools/quic_test_utils_test.cc", + "src/quiche/quic/test_tools/simple_session_notifier_test.cc", + "src/quiche/quic/test_tools/simulator/quic_endpoint_test.cc", + "src/quiche/quic/test_tools/simulator/simulator_test.cc", + "src/quiche/quic/tools/connect_tunnel_test.cc", + "src/quiche/quic/tools/connect_udp_tunnel_test.cc", + "src/quiche/quic/tools/quic_memory_cache_backend_test.cc", + "src/quiche/quic/tools/quic_tcp_like_trace_converter_test.cc", + "src/quiche/quic/tools/simple_ticket_crypter_test.cc", + "src/quiche/spdy/core/array_output_buffer_test.cc", + "src/quiche/spdy/core/hpack/hpack_decoder_adapter_test.cc", + "src/quiche/spdy/core/hpack/hpack_encoder_test.cc", + "src/quiche/spdy/core/hpack/hpack_entry_test.cc", + "src/quiche/spdy/core/hpack/hpack_header_table_test.cc", + "src/quiche/spdy/core/hpack/hpack_output_stream_test.cc", + "src/quiche/spdy/core/hpack/hpack_round_trip_test.cc", + "src/quiche/spdy/core/hpack/hpack_static_table_test.cc", + "src/quiche/spdy/core/http2_header_block_test.cc", + "src/quiche/spdy/core/http2_header_storage_test.cc", + "src/quiche/spdy/core/metadata_extension_test.cc", + "src/quiche/spdy/core/spdy_alt_svc_wire_format_test.cc", + "src/quiche/spdy/core/spdy_frame_builder_test.cc", + "src/quiche/spdy/core/spdy_framer_test.cc", + "src/quiche/spdy/core/spdy_intrusive_list_test.cc", + "src/quiche/spdy/core/spdy_pinnable_buffer_piece_test.cc", + "src/quiche/spdy/core/spdy_prefixed_buffer_reader_test.cc", + "src/quiche/spdy/core/spdy_protocol_test.cc", + "src/quiche/spdy/core/spdy_simple_arena_test.cc", +] +io_tests_hdrs = [ + +] +io_tests_srcs = [ + "src/quiche/quic/core/chlo_extractor_test.cc", + "src/quiche/quic/core/http/end_to_end_test.cc", + "src/quiche/quic/core/http/quic_spdy_client_session_test.cc", + "src/quiche/quic/core/http/quic_spdy_client_stream_test.cc", + "src/quiche/quic/core/http/quic_spdy_server_stream_base_test.cc", + "src/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc", + "src/quiche/quic/core/io/quic_all_event_loops_test.cc", + "src/quiche/quic/core/io/quic_poll_event_loop_test.cc", + "src/quiche/quic/core/io/socket_test.cc", + "src/quiche/quic/tools/quic_default_client_test.cc", + "src/quiche/quic/tools/quic_server_test.cc", + "src/quiche/quic/tools/quic_simple_server_session_test.cc", + "src/quiche/quic/tools/quic_simple_server_stream_test.cc", + "src/quiche/quic/tools/quic_url_test.cc", +] +fuzzers_hdrs = [ + +] +fuzzers_srcs = [ + "src/quiche/common/structured_headers_fuzzer.cc", + "src/quiche/quic/core/crypto/certificate_view_der_fuzzer.cc", + "src/quiche/quic/core/crypto/certificate_view_pem_fuzzer.cc", + "src/quiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc", + "src/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc", + "src/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc", + "src/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc", + "src/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc", + "src/quiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc", + "src/quiche/quic/test_tools/fuzzing/quic_framer_fuzzer.cc", + "src/quiche/quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc", +] +cli_tools_hdrs = [ + "src/quiche/quic/tools/quic_server_factory.h", + "src/quiche/quic/tools/quic_toy_client.h", + "src/quiche/quic/tools/quic_toy_server.h", +] +cli_tools_srcs = [ + "src/quiche/quic/masque/masque_client_bin.cc", + "src/quiche/quic/masque/masque_server_bin.cc", + "src/quiche/quic/tools/crypto_message_printer_bin.cc", + "src/quiche/quic/tools/qpack_offline_decoder_bin.cc", + "src/quiche/quic/tools/quic_client_bin.cc", + "src/quiche/quic/tools/quic_client_interop_test_bin.cc", + "src/quiche/quic/tools/quic_packet_printer_bin.cc", + "src/quiche/quic/tools/quic_reject_reason_decoder_bin.cc", + "src/quiche/quic/tools/quic_server_bin.cc", + "src/quiche/quic/tools/quic_server_factory.cc", + "src/quiche/quic/tools/quic_toy_client.cc", + "src/quiche/quic/tools/quic_toy_server.cc", +] +nghttp2_hdrs = [ + "src/quiche/http2/adapter/callback_visitor.h", + "src/quiche/http2/adapter/nghttp2.h", + "src/quiche/http2/adapter/nghttp2_adapter.h", + "src/quiche/http2/adapter/nghttp2_callbacks.h", + "src/quiche/http2/adapter/nghttp2_data_provider.h", + "src/quiche/http2/adapter/nghttp2_session.h", + "src/quiche/http2/adapter/nghttp2_util.h", +] +nghttp2_srcs = [ + "src/quiche/http2/adapter/callback_visitor.cc", + "src/quiche/http2/adapter/nghttp2_adapter.cc", + "src/quiche/http2/adapter/nghttp2_callbacks.cc", + "src/quiche/http2/adapter/nghttp2_data_provider.cc", + "src/quiche/http2/adapter/nghttp2_session.cc", + "src/quiche/http2/adapter/nghttp2_test.cc", + "src/quiche/http2/adapter/nghttp2_util.cc", +] +nghttp2_test_support_hdrs = [ + "src/quiche/http2/adapter/mock_nghttp2_callbacks.h", + "src/quiche/http2/adapter/nghttp2_test_utils.h", +] +nghttp2_test_support_srcs = [ + "src/quiche/http2/adapter/mock_nghttp2_callbacks.cc", + "src/quiche/http2/adapter/nghttp2_test_utils.cc", +] +nghttp2_tests_hdrs = [ + +] +nghttp2_tests_srcs = [ + "src/quiche/http2/adapter/adapter_impl_comparison_test.cc", + "src/quiche/http2/adapter/callback_visitor_test.cc", + "src/quiche/http2/adapter/nghttp2_adapter_test.cc", + "src/quiche/http2/adapter/nghttp2_data_provider_test.cc", + "src/quiche/http2/adapter/nghttp2_session_test.cc", + "src/quiche/http2/adapter/nghttp2_util_test.cc", +] +default_platform_impl_hdrs = [ + "src/quiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_iovec_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_logging_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h", +] +default_platform_impl_srcs = [ + "src/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc", +] +default_platform_impl_tool_support_hdrs = [ + "src/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h", +] +default_platform_impl_tool_support_srcs = [ + "src/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc", +] +default_platform_impl_test_support_hdrs = [ + "src/quiche/common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_test_output_impl.h", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_thread_impl.h", +] +default_platform_impl_test_support_srcs = [ + "src/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.cc", + "src/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc", +] +load_balancer_hdrs = [ + "src/quiche/quic/load_balancer/load_balancer_config.h", + "src/quiche/quic/load_balancer/load_balancer_decoder.h", + "src/quiche/quic/load_balancer/load_balancer_encoder.h", + "src/quiche/quic/load_balancer/load_balancer_server_id.h", + "src/quiche/quic/load_balancer/load_balancer_server_id_map.h", +] +load_balancer_srcs = [ + "src/quiche/quic/load_balancer/load_balancer_config.cc", + "src/quiche/quic/load_balancer/load_balancer_config_test.cc", + "src/quiche/quic/load_balancer/load_balancer_decoder.cc", + "src/quiche/quic/load_balancer/load_balancer_decoder_test.cc", + "src/quiche/quic/load_balancer/load_balancer_encoder.cc", + "src/quiche/quic/load_balancer/load_balancer_encoder_test.cc", + "src/quiche/quic/load_balancer/load_balancer_server_id.cc", + "src/quiche/quic/load_balancer/load_balancer_server_id_map_test.cc", + "src/quiche/quic/load_balancer/load_balancer_server_id_test.cc", +] +binary_http_hdrs = [ + "src/quiche/binary_http/binary_http_message.h", +] +binary_http_srcs = [ + "src/quiche/binary_http/binary_http_message.cc", +] +oblivious_http_hdrs = [ + "src/quiche/oblivious_http/buffers/oblivious_http_request.h", + "src/quiche/oblivious_http/buffers/oblivious_http_response.h", + "src/quiche/oblivious_http/common/oblivious_http_header_key_config.h", + "src/quiche/oblivious_http/oblivious_http_client.h", + "src/quiche/oblivious_http/oblivious_http_gateway.h", +] +oblivious_http_srcs = [ + "src/quiche/oblivious_http/buffers/oblivious_http_request.cc", + "src/quiche/oblivious_http/buffers/oblivious_http_response.cc", + "src/quiche/oblivious_http/common/oblivious_http_header_key_config.cc", + "src/quiche/oblivious_http/oblivious_http_client.cc", + "src/quiche/oblivious_http/oblivious_http_gateway.cc", +] +qbone_hdrs = [ + "src/quiche/quic/qbone/bonnet/icmp_reachable.h", + "src/quiche/quic/qbone/bonnet/icmp_reachable_interface.h", + "src/quiche/quic/qbone/bonnet/mock_icmp_reachable.h", + "src/quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h", + "src/quiche/quic/qbone/bonnet/mock_qbone_tunnel.h", + "src/quiche/quic/qbone/bonnet/mock_tun_device.h", + "src/quiche/quic/qbone/bonnet/mock_tun_device_controller.h", + "src/quiche/quic/qbone/bonnet/qbone_tunnel_info.h", + "src/quiche/quic/qbone/bonnet/qbone_tunnel_interface.h", + "src/quiche/quic/qbone/bonnet/qbone_tunnel_silo.h", + "src/quiche/quic/qbone/bonnet/tun_device.h", + "src/quiche/quic/qbone/bonnet/tun_device_controller.h", + "src/quiche/quic/qbone/bonnet/tun_device_interface.h", + "src/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h", + "src/quiche/quic/qbone/mock_qbone_client.h", + "src/quiche/quic/qbone/mock_qbone_server_session.h", + "src/quiche/quic/qbone/platform/icmp_packet.h", + "src/quiche/quic/qbone/platform/internet_checksum.h", + "src/quiche/quic/qbone/platform/ip_range.h", + "src/quiche/quic/qbone/platform/kernel_interface.h", + "src/quiche/quic/qbone/platform/mock_kernel.h", + "src/quiche/quic/qbone/platform/mock_netlink.h", + "src/quiche/quic/qbone/platform/netlink.h", + "src/quiche/quic/qbone/platform/netlink_interface.h", + "src/quiche/quic/qbone/platform/rtnetlink_message.h", + "src/quiche/quic/qbone/platform/tcp_packet.h", + "src/quiche/quic/qbone/qbone_client.h", + "src/quiche/quic/qbone/qbone_client_interface.h", + "src/quiche/quic/qbone/qbone_client_session.h", + "src/quiche/quic/qbone/qbone_constants.h", + "src/quiche/quic/qbone/qbone_control_stream.h", + "src/quiche/quic/qbone/qbone_packet_exchanger.h", + "src/quiche/quic/qbone/qbone_packet_processor.h", + "src/quiche/quic/qbone/qbone_packet_processor_test_tools.h", + "src/quiche/quic/qbone/qbone_packet_writer.h", + "src/quiche/quic/qbone/qbone_server_session.h", + "src/quiche/quic/qbone/qbone_session_base.h", + "src/quiche/quic/qbone/qbone_stream.h", +] +qbone_srcs = [ + "src/quiche/quic/qbone/bonnet/icmp_reachable.cc", + "src/quiche/quic/qbone/bonnet/icmp_reachable_test.cc", + "src/quiche/quic/qbone/bonnet/qbone_tunnel_info.cc", + "src/quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc", + "src/quiche/quic/qbone/bonnet/qbone_tunnel_silo_test.cc", + "src/quiche/quic/qbone/bonnet/tun_device.cc", + "src/quiche/quic/qbone/bonnet/tun_device_controller.cc", + "src/quiche/quic/qbone/bonnet/tun_device_controller_test.cc", + "src/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc", + "src/quiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc", + "src/quiche/quic/qbone/bonnet/tun_device_test.cc", + "src/quiche/quic/qbone/platform/icmp_packet.cc", + "src/quiche/quic/qbone/platform/icmp_packet_test.cc", + "src/quiche/quic/qbone/platform/internet_checksum.cc", + "src/quiche/quic/qbone/platform/internet_checksum_test.cc", + "src/quiche/quic/qbone/platform/ip_range.cc", + "src/quiche/quic/qbone/platform/ip_range_test.cc", + "src/quiche/quic/qbone/platform/netlink.cc", + "src/quiche/quic/qbone/platform/netlink_test.cc", + "src/quiche/quic/qbone/platform/rtnetlink_message.cc", + "src/quiche/quic/qbone/platform/rtnetlink_message_test.cc", + "src/quiche/quic/qbone/platform/tcp_packet.cc", + "src/quiche/quic/qbone/platform/tcp_packet_test.cc", + "src/quiche/quic/qbone/qbone_client.cc", + "src/quiche/quic/qbone/qbone_client_session.cc", + "src/quiche/quic/qbone/qbone_client_test.cc", + "src/quiche/quic/qbone/qbone_constants.cc", + "src/quiche/quic/qbone/qbone_control_stream.cc", + "src/quiche/quic/qbone/qbone_packet_exchanger.cc", + "src/quiche/quic/qbone/qbone_packet_exchanger_test.cc", + "src/quiche/quic/qbone/qbone_packet_processor.cc", + "src/quiche/quic/qbone/qbone_packet_processor_test.cc", + "src/quiche/quic/qbone/qbone_packet_processor_test_tools.cc", + "src/quiche/quic/qbone/qbone_server_session.cc", + "src/quiche/quic/qbone/qbone_session_base.cc", + "src/quiche/quic/qbone/qbone_session_test.cc", + "src/quiche/quic/qbone/qbone_stream.cc", + "src/quiche/quic/qbone/qbone_stream_test.cc", +] +blind_sign_auth_hdrs = [ + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h", + "src/quiche/blind_sign_auth/blind_sign_auth.h", + "src/quiche/blind_sign_auth/blind_sign_auth_interface.h", + "src/quiche/blind_sign_auth/blind_sign_http_interface.h", + "src/quiche/blind_sign_auth/blind_sign_http_response.h", + "src/quiche/blind_sign_auth/cached_blind_sign_auth.h", + "src/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h", + "src/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h", +] +blind_sign_auth_srcs = [ + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc", + "src/quiche/blind_sign_auth/blind_sign_auth.cc", + "src/quiche/blind_sign_auth/cached_blind_sign_auth.cc", +] +blind_sign_auth_tests_hdrs = [ + +] +blind_sign_auth_tests_srcs = [ + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc", + "src/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc", + "src/quiche/blind_sign_auth/blind_sign_auth_test.cc", + "src/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc", +] +protobuf_blind_sign_auth = [ + "src/quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto", + "src/quiche/blind_sign_auth/proto/any.proto", + "src/quiche/blind_sign_auth/proto/attestation.proto", + "src/quiche/blind_sign_auth/proto/auth_and_sign.proto", + "src/quiche/blind_sign_auth/proto/get_initial_data.proto", + "src/quiche/blind_sign_auth/proto/key_services.proto", + "src/quiche/blind_sign_auth/proto/public_metadata.proto", + "src/quiche/blind_sign_auth/proto/spend_token_data.proto", + "src/quiche/blind_sign_auth/proto/timestamp.proto", +] +libevent_hdrs = [ + "src/quiche/quic/bindings/quic_libevent.h", +] +libevent_srcs = [ + "src/quiche/quic/bindings/quic_libevent.cc", + "src/quiche/quic/bindings/quic_libevent_test.cc", +] +linux_only_hdrs = [ + "src/quiche/quic/core/batch_writer/quic_batch_writer_base.h", + "src/quiche/quic/core/batch_writer/quic_batch_writer_buffer.h", + "src/quiche/quic/core/batch_writer/quic_batch_writer_test.h", + "src/quiche/quic/core/batch_writer/quic_gso_batch_writer.h", + "src/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h", + "src/quiche/quic/core/quic_linux_socket_utils.h", +] +linux_only_srcs = [ + "src/quiche/quic/core/batch_writer/quic_batch_writer_base.cc", + "src/quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc", + "src/quiche/quic/core/batch_writer/quic_gso_batch_writer.cc", + "src/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc", + "src/quiche/quic/core/quic_linux_socket_utils.cc", +] +linux_only_tests_hdrs = [ + +] +linux_only_tests_srcs = [ + "src/quiche/quic/core/batch_writer/quic_batch_writer_buffer_test.cc", + "src/quiche/quic/core/batch_writer/quic_batch_writer_test.cc", + "src/quiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc", + "src/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc", + "src/quiche/quic/core/quic_linux_socket_utils_test.cc", +] diff --git a/build/source_list.json b/build/source_list.json new file mode 100644 index 000000000000..3e8c573c5b7c --- /dev/null +++ b/build/source_list.json @@ -0,0 +1,1637 @@ +{ + "protobuf": [ + "quiche/quic/core/proto/cached_network_parameters.proto", + "quiche/quic/core/proto/crypto_server_config.proto", + "quiche/quic/core/proto/source_address_token.proto" + ], + "protobuf_test_support": [ + "quiche/quic/test_tools/send_algorithm_test_result.proto" + ], + "quiche_core_hdrs": [ + "quiche/balsa/balsa_enums.h", + "quiche/balsa/balsa_frame.h", + "quiche/balsa/balsa_headers.h", + "quiche/balsa/balsa_visitor_interface.h", + "quiche/balsa/framer_interface.h", + "quiche/balsa/header_api.h", + "quiche/balsa/header_properties.h", + "quiche/balsa/http_validation_policy.h", + "quiche/balsa/noop_balsa_visitor.h", + "quiche/balsa/simple_buffer.h", + "quiche/balsa/standard_header_map.h", + "quiche/common/btree_scheduler.h", + "quiche/common/capsule.h", + "quiche/common/masque/connect_udp_datagram_payload.h", + "quiche/common/platform/api/quiche_bug_tracker.h", + "quiche/common/platform/api/quiche_client_stats.h", + "quiche/common/platform/api/quiche_containers.h", + "quiche/common/platform/api/quiche_export.h", + "quiche/common/platform/api/quiche_flag_utils.h", + "quiche/common/platform/api/quiche_flags.h", + "quiche/common/platform/api/quiche_header_policy.h", + "quiche/common/platform/api/quiche_hostname_utils.h", + "quiche/common/platform/api/quiche_iovec.h", + "quiche/common/platform/api/quiche_logging.h", + "quiche/common/platform/api/quiche_lower_case_string.h", + "quiche/common/platform/api/quiche_mem_slice.h", + "quiche/common/platform/api/quiche_mutex.h", + "quiche/common/platform/api/quiche_prefetch.h", + "quiche/common/platform/api/quiche_reference_counted.h", + "quiche/common/platform/api/quiche_server_stats.h", + "quiche/common/platform/api/quiche_stack_trace.h", + "quiche/common/platform/api/quiche_testvalue.h", + "quiche/common/platform/api/quiche_thread.h", + "quiche/common/platform/api/quiche_time_utils.h", + "quiche/common/platform/api/quiche_url_utils.h", + "quiche/common/print_elements.h", + "quiche/common/quiche_buffer_allocator.h", + "quiche/common/quiche_circular_deque.h", + "quiche/common/quiche_crypto_logging.h", + "quiche/common/quiche_data_reader.h", + "quiche/common/quiche_data_writer.h", + "quiche/common/quiche_endian.h", + "quiche/common/quiche_ip_address.h", + "quiche/common/quiche_ip_address_family.h", + "quiche/common/quiche_linked_hash_map.h", + "quiche/common/quiche_mem_slice_storage.h", + "quiche/common/quiche_protocol_flags_list.h", + "quiche/common/quiche_random.h", + "quiche/common/quiche_status_utils.h", + "quiche/common/quiche_stream.h", + "quiche/common/quiche_text_utils.h", + "quiche/common/simple_buffer_allocator.h", + "quiche/common/structured_headers.h", + "quiche/common/wire_serialization.h", + "quiche/http2/adapter/data_source.h", + "quiche/http2/adapter/event_forwarder.h", + "quiche/http2/adapter/header_validator.h", + "quiche/http2/adapter/header_validator_base.h", + "quiche/http2/adapter/http2_adapter.h", + "quiche/http2/adapter/http2_protocol.h", + "quiche/http2/adapter/http2_session.h", + "quiche/http2/adapter/http2_util.h", + "quiche/http2/adapter/http2_visitor_interface.h", + "quiche/http2/adapter/noop_header_validator.h", + "quiche/http2/adapter/oghttp2_adapter.h", + "quiche/http2/adapter/oghttp2_session.h", + "quiche/http2/adapter/oghttp2_util.h", + "quiche/http2/adapter/window_manager.h", + "quiche/http2/core/http2_trace_logging.h", + "quiche/http2/core/priority_write_scheduler.h", + "quiche/http2/decoder/decode_buffer.h", + "quiche/http2/decoder/decode_http2_structures.h", + "quiche/http2/decoder/decode_status.h", + "quiche/http2/decoder/frame_decoder_state.h", + "quiche/http2/decoder/http2_frame_decoder.h", + "quiche/http2/decoder/http2_frame_decoder_listener.h", + "quiche/http2/decoder/http2_structure_decoder.h", + "quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/data_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/headers_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/ping_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/priority_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/settings_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h", + "quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h", + "quiche/http2/hpack/decoder/hpack_block_decoder.h", + "quiche/http2/hpack/decoder/hpack_decoder.h", + "quiche/http2/hpack/decoder/hpack_decoder_listener.h", + "quiche/http2/hpack/decoder/hpack_decoder_state.h", + "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h", + "quiche/http2/hpack/decoder/hpack_decoder_tables.h", + "quiche/http2/hpack/decoder/hpack_decoding_error.h", + "quiche/http2/hpack/decoder/hpack_entry_decoder.h", + "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h", + "quiche/http2/hpack/decoder/hpack_entry_type_decoder.h", + "quiche/http2/hpack/decoder/hpack_string_decoder.h", + "quiche/http2/hpack/decoder/hpack_string_decoder_listener.h", + "quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h", + "quiche/http2/hpack/decoder/hpack_whole_entry_listener.h", + "quiche/http2/hpack/http2_hpack_constants.h", + "quiche/http2/hpack/huffman/hpack_huffman_decoder.h", + "quiche/http2/hpack/huffman/hpack_huffman_encoder.h", + "quiche/http2/hpack/huffman/huffman_spec_tables.h", + "quiche/http2/hpack/varint/hpack_varint_decoder.h", + "quiche/http2/hpack/varint/hpack_varint_encoder.h", + "quiche/http2/http2_constants.h", + "quiche/http2/http2_structures.h", + "quiche/quic/core/chlo_extractor.h", + "quiche/quic/core/congestion_control/bandwidth_sampler.h", + "quiche/quic/core/congestion_control/bbr2_drain.h", + "quiche/quic/core/congestion_control/bbr2_misc.h", + "quiche/quic/core/congestion_control/bbr2_probe_bw.h", + "quiche/quic/core/congestion_control/bbr2_probe_rtt.h", + "quiche/quic/core/congestion_control/bbr2_sender.h", + "quiche/quic/core/congestion_control/bbr2_startup.h", + "quiche/quic/core/congestion_control/bbr_sender.h", + "quiche/quic/core/congestion_control/cubic_bytes.h", + "quiche/quic/core/congestion_control/general_loss_algorithm.h", + "quiche/quic/core/congestion_control/hybrid_slow_start.h", + "quiche/quic/core/congestion_control/loss_detection_interface.h", + "quiche/quic/core/congestion_control/pacing_sender.h", + "quiche/quic/core/congestion_control/prr_sender.h", + "quiche/quic/core/congestion_control/rtt_stats.h", + "quiche/quic/core/congestion_control/send_algorithm_interface.h", + "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h", + "quiche/quic/core/congestion_control/uber_loss_algorithm.h", + "quiche/quic/core/congestion_control/windowed_filter.h", + "quiche/quic/core/connecting_client_socket.h", + "quiche/quic/core/connection_id_generator.h", + "quiche/quic/core/crypto/aead_base_decrypter.h", + "quiche/quic/core/crypto/aead_base_encrypter.h", + "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h", + "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h", + "quiche/quic/core/crypto/aes_128_gcm_decrypter.h", + "quiche/quic/core/crypto/aes_128_gcm_encrypter.h", + "quiche/quic/core/crypto/aes_256_gcm_decrypter.h", + "quiche/quic/core/crypto/aes_256_gcm_encrypter.h", + "quiche/quic/core/crypto/aes_base_decrypter.h", + "quiche/quic/core/crypto/aes_base_encrypter.h", + "quiche/quic/core/crypto/boring_utils.h", + "quiche/quic/core/crypto/cert_compressor.h", + "quiche/quic/core/crypto/certificate_util.h", + "quiche/quic/core/crypto/certificate_view.h", + "quiche/quic/core/crypto/chacha20_poly1305_decrypter.h", + "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h", + "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h", + "quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h", + "quiche/quic/core/crypto/chacha_base_decrypter.h", + "quiche/quic/core/crypto/chacha_base_encrypter.h", + "quiche/quic/core/crypto/channel_id.h", + "quiche/quic/core/crypto/client_proof_source.h", + "quiche/quic/core/crypto/crypto_framer.h", + "quiche/quic/core/crypto/crypto_handshake.h", + "quiche/quic/core/crypto/crypto_handshake_message.h", + "quiche/quic/core/crypto/crypto_message_parser.h", + "quiche/quic/core/crypto/crypto_protocol.h", + "quiche/quic/core/crypto/crypto_secret_boxer.h", + "quiche/quic/core/crypto/crypto_utils.h", + "quiche/quic/core/crypto/curve25519_key_exchange.h", + "quiche/quic/core/crypto/key_exchange.h", + "quiche/quic/core/crypto/null_decrypter.h", + "quiche/quic/core/crypto/null_encrypter.h", + "quiche/quic/core/crypto/p256_key_exchange.h", + "quiche/quic/core/crypto/proof_source.h", + "quiche/quic/core/crypto/proof_source_x509.h", + "quiche/quic/core/crypto/proof_verifier.h", + "quiche/quic/core/crypto/quic_client_session_cache.h", + "quiche/quic/core/crypto/quic_compressed_certs_cache.h", + "quiche/quic/core/crypto/quic_crypter.h", + "quiche/quic/core/crypto/quic_crypto_client_config.h", + "quiche/quic/core/crypto/quic_crypto_proof.h", + "quiche/quic/core/crypto/quic_crypto_server_config.h", + "quiche/quic/core/crypto/quic_decrypter.h", + "quiche/quic/core/crypto/quic_encrypter.h", + "quiche/quic/core/crypto/quic_hkdf.h", + "quiche/quic/core/crypto/quic_random.h", + "quiche/quic/core/crypto/tls_client_connection.h", + "quiche/quic/core/crypto/tls_connection.h", + "quiche/quic/core/crypto/tls_server_connection.h", + "quiche/quic/core/crypto/transport_parameters.h", + "quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h", + "quiche/quic/core/deterministic_connection_id_generator.h", + "quiche/quic/core/frames/quic_ack_frame.h", + "quiche/quic/core/frames/quic_ack_frequency_frame.h", + "quiche/quic/core/frames/quic_blocked_frame.h", + "quiche/quic/core/frames/quic_connection_close_frame.h", + "quiche/quic/core/frames/quic_crypto_frame.h", + "quiche/quic/core/frames/quic_frame.h", + "quiche/quic/core/frames/quic_goaway_frame.h", + "quiche/quic/core/frames/quic_handshake_done_frame.h", + "quiche/quic/core/frames/quic_inlined_frame.h", + "quiche/quic/core/frames/quic_max_streams_frame.h", + "quiche/quic/core/frames/quic_message_frame.h", + "quiche/quic/core/frames/quic_mtu_discovery_frame.h", + "quiche/quic/core/frames/quic_new_connection_id_frame.h", + "quiche/quic/core/frames/quic_new_token_frame.h", + "quiche/quic/core/frames/quic_padding_frame.h", + "quiche/quic/core/frames/quic_path_challenge_frame.h", + "quiche/quic/core/frames/quic_path_response_frame.h", + "quiche/quic/core/frames/quic_ping_frame.h", + "quiche/quic/core/frames/quic_retire_connection_id_frame.h", + "quiche/quic/core/frames/quic_rst_stream_frame.h", + "quiche/quic/core/frames/quic_stop_sending_frame.h", + "quiche/quic/core/frames/quic_stop_waiting_frame.h", + "quiche/quic/core/frames/quic_stream_frame.h", + "quiche/quic/core/frames/quic_streams_blocked_frame.h", + "quiche/quic/core/frames/quic_window_update_frame.h", + "quiche/quic/core/handshaker_delegate_interface.h", + "quiche/quic/core/http/http_constants.h", + "quiche/quic/core/http/http_decoder.h", + "quiche/quic/core/http/http_encoder.h", + "quiche/quic/core/http/http_frames.h", + "quiche/quic/core/http/quic_client_promised_info.h", + "quiche/quic/core/http/quic_client_push_promise_index.h", + "quiche/quic/core/http/quic_header_list.h", + "quiche/quic/core/http/quic_headers_stream.h", + "quiche/quic/core/http/quic_receive_control_stream.h", + "quiche/quic/core/http/quic_send_control_stream.h", + "quiche/quic/core/http/quic_server_initiated_spdy_stream.h", + "quiche/quic/core/http/quic_server_session_base.h", + "quiche/quic/core/http/quic_spdy_client_session.h", + "quiche/quic/core/http/quic_spdy_client_session_base.h", + "quiche/quic/core/http/quic_spdy_client_stream.h", + "quiche/quic/core/http/quic_spdy_server_stream_base.h", + "quiche/quic/core/http/quic_spdy_session.h", + "quiche/quic/core/http/quic_spdy_stream.h", + "quiche/quic/core/http/quic_spdy_stream_body_manager.h", + "quiche/quic/core/http/spdy_server_push_utils.h", + "quiche/quic/core/http/spdy_utils.h", + "quiche/quic/core/http/web_transport_http3.h", + "quiche/quic/core/http/web_transport_stream_adapter.h", + "quiche/quic/core/legacy_quic_stream_id_manager.h", + "quiche/quic/core/packet_number_indexed_queue.h", + "quiche/quic/core/proto/cached_network_parameters_proto.h", + "quiche/quic/core/proto/crypto_server_config_proto.h", + "quiche/quic/core/proto/source_address_token_proto.h", + "quiche/quic/core/qpack/qpack_blocking_manager.h", + "quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h", + "quiche/quic/core/qpack/qpack_decoder.h", + "quiche/quic/core/qpack/qpack_decoder_stream_receiver.h", + "quiche/quic/core/qpack/qpack_decoder_stream_sender.h", + "quiche/quic/core/qpack/qpack_encoder.h", + "quiche/quic/core/qpack/qpack_encoder_stream_receiver.h", + "quiche/quic/core/qpack/qpack_encoder_stream_sender.h", + "quiche/quic/core/qpack/qpack_header_table.h", + "quiche/quic/core/qpack/qpack_index_conversions.h", + "quiche/quic/core/qpack/qpack_instruction_decoder.h", + "quiche/quic/core/qpack/qpack_instruction_encoder.h", + "quiche/quic/core/qpack/qpack_instructions.h", + "quiche/quic/core/qpack/qpack_progressive_decoder.h", + "quiche/quic/core/qpack/qpack_receive_stream.h", + "quiche/quic/core/qpack/qpack_required_insert_count.h", + "quiche/quic/core/qpack/qpack_send_stream.h", + "quiche/quic/core/qpack/qpack_static_table.h", + "quiche/quic/core/qpack/qpack_stream_receiver.h", + "quiche/quic/core/qpack/qpack_stream_sender_delegate.h", + "quiche/quic/core/qpack/value_splitting_header_list.h", + "quiche/quic/core/quic_ack_listener_interface.h", + "quiche/quic/core/quic_alarm.h", + "quiche/quic/core/quic_alarm_factory.h", + "quiche/quic/core/quic_arena_scoped_ptr.h", + "quiche/quic/core/quic_bandwidth.h", + "quiche/quic/core/quic_blocked_writer_interface.h", + "quiche/quic/core/quic_buffered_packet_store.h", + "quiche/quic/core/quic_chaos_protector.h", + "quiche/quic/core/quic_clock.h", + "quiche/quic/core/quic_coalesced_packet.h", + "quiche/quic/core/quic_config.h", + "quiche/quic/core/quic_connection.h", + "quiche/quic/core/quic_connection_context.h", + "quiche/quic/core/quic_connection_id.h", + "quiche/quic/core/quic_connection_id_manager.h", + "quiche/quic/core/quic_connection_stats.h", + "quiche/quic/core/quic_constants.h", + "quiche/quic/core/quic_control_frame_manager.h", + "quiche/quic/core/quic_crypto_client_handshaker.h", + "quiche/quic/core/quic_crypto_client_stream.h", + "quiche/quic/core/quic_crypto_handshaker.h", + "quiche/quic/core/quic_crypto_server_stream.h", + "quiche/quic/core/quic_crypto_server_stream_base.h", + "quiche/quic/core/quic_crypto_stream.h", + "quiche/quic/core/quic_data_reader.h", + "quiche/quic/core/quic_data_writer.h", + "quiche/quic/core/quic_datagram_queue.h", + "quiche/quic/core/quic_default_clock.h", + "quiche/quic/core/quic_default_connection_helper.h", + "quiche/quic/core/quic_dispatcher.h", + "quiche/quic/core/quic_error_codes.h", + "quiche/quic/core/quic_flags_list.h", + "quiche/quic/core/quic_flow_controller.h", + "quiche/quic/core/quic_framer.h", + "quiche/quic/core/quic_idle_network_detector.h", + "quiche/quic/core/quic_interval.h", + "quiche/quic/core/quic_interval_deque.h", + "quiche/quic/core/quic_interval_set.h", + "quiche/quic/core/quic_lru_cache.h", + "quiche/quic/core/quic_mtu_discovery.h", + "quiche/quic/core/quic_network_blackhole_detector.h", + "quiche/quic/core/quic_one_block_arena.h", + "quiche/quic/core/quic_packet_creator.h", + "quiche/quic/core/quic_packet_number.h", + "quiche/quic/core/quic_packet_writer.h", + "quiche/quic/core/quic_packet_writer_wrapper.h", + "quiche/quic/core/quic_packets.h", + "quiche/quic/core/quic_path_validator.h", + "quiche/quic/core/quic_ping_manager.h", + "quiche/quic/core/quic_process_packet_interface.h", + "quiche/quic/core/quic_protocol_flags_list.h", + "quiche/quic/core/quic_received_packet_manager.h", + "quiche/quic/core/quic_sent_packet_manager.h", + "quiche/quic/core/quic_server_id.h", + "quiche/quic/core/quic_session.h", + "quiche/quic/core/quic_socket_address_coder.h", + "quiche/quic/core/quic_stream.h", + "quiche/quic/core/quic_stream_frame_data_producer.h", + "quiche/quic/core/quic_stream_id_manager.h", + "quiche/quic/core/quic_stream_priority.h", + "quiche/quic/core/quic_stream_send_buffer.h", + "quiche/quic/core/quic_stream_sequencer.h", + "quiche/quic/core/quic_stream_sequencer_buffer.h", + "quiche/quic/core/quic_sustained_bandwidth_recorder.h", + "quiche/quic/core/quic_tag.h", + "quiche/quic/core/quic_time.h", + "quiche/quic/core/quic_time_accumulator.h", + "quiche/quic/core/quic_time_wait_list_manager.h", + "quiche/quic/core/quic_trace_visitor.h", + "quiche/quic/core/quic_transmission_info.h", + "quiche/quic/core/quic_types.h", + "quiche/quic/core/quic_unacked_packet_map.h", + "quiche/quic/core/quic_utils.h", + "quiche/quic/core/quic_version_manager.h", + "quiche/quic/core/quic_versions.h", + "quiche/quic/core/quic_write_blocked_list.h", + "quiche/quic/core/session_notifier_interface.h", + "quiche/quic/core/socket_factory.h", + "quiche/quic/core/stream_delegate_interface.h", + "quiche/quic/core/tls_chlo_extractor.h", + "quiche/quic/core/tls_client_handshaker.h", + "quiche/quic/core/tls_handshaker.h", + "quiche/quic/core/tls_server_handshaker.h", + "quiche/quic/core/uber_quic_stream_id_manager.h", + "quiche/quic/core/uber_received_packet_manager.h", + "quiche/quic/core/web_transport_interface.h", + "quiche/quic/platform/api/quic_bug_tracker.h", + "quiche/quic/platform/api/quic_client_stats.h", + "quiche/quic/platform/api/quic_export.h", + "quiche/quic/platform/api/quic_exported_stats.h", + "quiche/quic/platform/api/quic_flag_utils.h", + "quiche/quic/platform/api/quic_flags.h", + "quiche/quic/platform/api/quic_hostname_utils.h", + "quiche/quic/platform/api/quic_ip_address.h", + "quiche/quic/platform/api/quic_ip_address_family.h", + "quiche/quic/platform/api/quic_logging.h", + "quiche/quic/platform/api/quic_mutex.h", + "quiche/quic/platform/api/quic_server_stats.h", + "quiche/quic/platform/api/quic_socket_address.h", + "quiche/quic/platform/api/quic_stack_trace.h", + "quiche/quic/platform/api/quic_testvalue.h", + "quiche/quic/platform/api/quic_thread.h", + "quiche/spdy/core/array_output_buffer.h", + "quiche/spdy/core/header_byte_listener_interface.h", + "quiche/spdy/core/hpack/hpack_constants.h", + "quiche/spdy/core/hpack/hpack_decoder_adapter.h", + "quiche/spdy/core/hpack/hpack_encoder.h", + "quiche/spdy/core/hpack/hpack_entry.h", + "quiche/spdy/core/hpack/hpack_header_table.h", + "quiche/spdy/core/hpack/hpack_output_stream.h", + "quiche/spdy/core/hpack/hpack_static_table.h", + "quiche/spdy/core/http2_frame_decoder_adapter.h", + "quiche/spdy/core/http2_header_block.h", + "quiche/spdy/core/http2_header_block_hpack_listener.h", + "quiche/spdy/core/http2_header_storage.h", + "quiche/spdy/core/metadata_extension.h", + "quiche/spdy/core/no_op_headers_handler.h", + "quiche/spdy/core/recording_headers_handler.h", + "quiche/spdy/core/spdy_alt_svc_wire_format.h", + "quiche/spdy/core/spdy_bitmasks.h", + "quiche/spdy/core/spdy_frame_builder.h", + "quiche/spdy/core/spdy_framer.h", + "quiche/spdy/core/spdy_headers_handler_interface.h", + "quiche/spdy/core/spdy_intrusive_list.h", + "quiche/spdy/core/spdy_no_op_visitor.h", + "quiche/spdy/core/spdy_pinnable_buffer_piece.h", + "quiche/spdy/core/spdy_prefixed_buffer_reader.h", + "quiche/spdy/core/spdy_protocol.h", + "quiche/spdy/core/spdy_simple_arena.h", + "quiche/spdy/core/zero_copy_output_buffer.h", + "quiche/web_transport/web_transport.h" + ], + "quiche_core_srcs": [ + "quiche/balsa/balsa_enums.cc", + "quiche/balsa/balsa_frame.cc", + "quiche/balsa/balsa_headers.cc", + "quiche/balsa/header_properties.cc", + "quiche/balsa/simple_buffer.cc", + "quiche/balsa/standard_header_map.cc", + "quiche/common/capsule.cc", + "quiche/common/masque/connect_udp_datagram_payload.cc", + "quiche/common/platform/api/quiche_hostname_utils.cc", + "quiche/common/platform/api/quiche_mutex.cc", + "quiche/common/quiche_buffer_allocator.cc", + "quiche/common/quiche_crypto_logging.cc", + "quiche/common/quiche_data_reader.cc", + "quiche/common/quiche_data_writer.cc", + "quiche/common/quiche_ip_address.cc", + "quiche/common/quiche_ip_address_family.cc", + "quiche/common/quiche_mem_slice_storage.cc", + "quiche/common/quiche_random.cc", + "quiche/common/quiche_text_utils.cc", + "quiche/common/simple_buffer_allocator.cc", + "quiche/common/structured_headers.cc", + "quiche/http2/adapter/event_forwarder.cc", + "quiche/http2/adapter/header_validator.cc", + "quiche/http2/adapter/http2_protocol.cc", + "quiche/http2/adapter/http2_util.cc", + "quiche/http2/adapter/noop_header_validator.cc", + "quiche/http2/adapter/oghttp2_adapter.cc", + "quiche/http2/adapter/oghttp2_session.cc", + "quiche/http2/adapter/oghttp2_util.cc", + "quiche/http2/adapter/window_manager.cc", + "quiche/http2/core/http2_trace_logging.cc", + "quiche/http2/decoder/decode_buffer.cc", + "quiche/http2/decoder/decode_http2_structures.cc", + "quiche/http2/decoder/decode_status.cc", + "quiche/http2/decoder/frame_decoder_state.cc", + "quiche/http2/decoder/http2_frame_decoder.cc", + "quiche/http2/decoder/http2_frame_decoder_listener.cc", + "quiche/http2/decoder/http2_structure_decoder.cc", + "quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/continuation_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/data_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/goaway_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/headers_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/ping_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/priority_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/settings_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/unknown_payload_decoder.cc", + "quiche/http2/decoder/payload_decoders/window_update_payload_decoder.cc", + "quiche/http2/hpack/decoder/hpack_block_decoder.cc", + "quiche/http2/hpack/decoder/hpack_decoder.cc", + "quiche/http2/hpack/decoder/hpack_decoder_listener.cc", + "quiche/http2/hpack/decoder/hpack_decoder_state.cc", + "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc", + "quiche/http2/hpack/decoder/hpack_decoder_tables.cc", + "quiche/http2/hpack/decoder/hpack_decoding_error.cc", + "quiche/http2/hpack/decoder/hpack_entry_decoder.cc", + "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.cc", + "quiche/http2/hpack/decoder/hpack_entry_type_decoder.cc", + "quiche/http2/hpack/decoder/hpack_string_decoder.cc", + "quiche/http2/hpack/decoder/hpack_string_decoder_listener.cc", + "quiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc", + "quiche/http2/hpack/decoder/hpack_whole_entry_listener.cc", + "quiche/http2/hpack/http2_hpack_constants.cc", + "quiche/http2/hpack/huffman/hpack_huffman_decoder.cc", + "quiche/http2/hpack/huffman/hpack_huffman_encoder.cc", + "quiche/http2/hpack/huffman/huffman_spec_tables.cc", + "quiche/http2/hpack/varint/hpack_varint_decoder.cc", + "quiche/http2/hpack/varint/hpack_varint_encoder.cc", + "quiche/http2/http2_constants.cc", + "quiche/http2/http2_structures.cc", + "quiche/quic/core/chlo_extractor.cc", + "quiche/quic/core/congestion_control/bandwidth_sampler.cc", + "quiche/quic/core/congestion_control/bbr2_drain.cc", + "quiche/quic/core/congestion_control/bbr2_misc.cc", + "quiche/quic/core/congestion_control/bbr2_probe_bw.cc", + "quiche/quic/core/congestion_control/bbr2_probe_rtt.cc", + "quiche/quic/core/congestion_control/bbr2_sender.cc", + "quiche/quic/core/congestion_control/bbr2_startup.cc", + "quiche/quic/core/congestion_control/bbr_sender.cc", + "quiche/quic/core/congestion_control/cubic_bytes.cc", + "quiche/quic/core/congestion_control/general_loss_algorithm.cc", + "quiche/quic/core/congestion_control/hybrid_slow_start.cc", + "quiche/quic/core/congestion_control/pacing_sender.cc", + "quiche/quic/core/congestion_control/prr_sender.cc", + "quiche/quic/core/congestion_control/rtt_stats.cc", + "quiche/quic/core/congestion_control/send_algorithm_interface.cc", + "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc", + "quiche/quic/core/congestion_control/uber_loss_algorithm.cc", + "quiche/quic/core/crypto/aead_base_decrypter.cc", + "quiche/quic/core/crypto/aead_base_encrypter.cc", + "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.cc", + "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.cc", + "quiche/quic/core/crypto/aes_128_gcm_decrypter.cc", + "quiche/quic/core/crypto/aes_128_gcm_encrypter.cc", + "quiche/quic/core/crypto/aes_256_gcm_decrypter.cc", + "quiche/quic/core/crypto/aes_256_gcm_encrypter.cc", + "quiche/quic/core/crypto/aes_base_decrypter.cc", + "quiche/quic/core/crypto/aes_base_encrypter.cc", + "quiche/quic/core/crypto/cert_compressor.cc", + "quiche/quic/core/crypto/certificate_util.cc", + "quiche/quic/core/crypto/certificate_view.cc", + "quiche/quic/core/crypto/chacha20_poly1305_decrypter.cc", + "quiche/quic/core/crypto/chacha20_poly1305_encrypter.cc", + "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc", + "quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc", + "quiche/quic/core/crypto/chacha_base_decrypter.cc", + "quiche/quic/core/crypto/chacha_base_encrypter.cc", + "quiche/quic/core/crypto/channel_id.cc", + "quiche/quic/core/crypto/client_proof_source.cc", + "quiche/quic/core/crypto/crypto_framer.cc", + "quiche/quic/core/crypto/crypto_handshake.cc", + "quiche/quic/core/crypto/crypto_handshake_message.cc", + "quiche/quic/core/crypto/crypto_secret_boxer.cc", + "quiche/quic/core/crypto/crypto_utils.cc", + "quiche/quic/core/crypto/curve25519_key_exchange.cc", + "quiche/quic/core/crypto/key_exchange.cc", + "quiche/quic/core/crypto/null_decrypter.cc", + "quiche/quic/core/crypto/null_encrypter.cc", + "quiche/quic/core/crypto/p256_key_exchange.cc", + "quiche/quic/core/crypto/proof_source.cc", + "quiche/quic/core/crypto/proof_source_x509.cc", + "quiche/quic/core/crypto/quic_client_session_cache.cc", + "quiche/quic/core/crypto/quic_compressed_certs_cache.cc", + "quiche/quic/core/crypto/quic_crypter.cc", + "quiche/quic/core/crypto/quic_crypto_client_config.cc", + "quiche/quic/core/crypto/quic_crypto_proof.cc", + "quiche/quic/core/crypto/quic_crypto_server_config.cc", + "quiche/quic/core/crypto/quic_decrypter.cc", + "quiche/quic/core/crypto/quic_encrypter.cc", + "quiche/quic/core/crypto/quic_hkdf.cc", + "quiche/quic/core/crypto/tls_client_connection.cc", + "quiche/quic/core/crypto/tls_connection.cc", + "quiche/quic/core/crypto/tls_server_connection.cc", + "quiche/quic/core/crypto/transport_parameters.cc", + "quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.cc", + "quiche/quic/core/deterministic_connection_id_generator.cc", + "quiche/quic/core/frames/quic_ack_frame.cc", + "quiche/quic/core/frames/quic_ack_frequency_frame.cc", + "quiche/quic/core/frames/quic_blocked_frame.cc", + "quiche/quic/core/frames/quic_connection_close_frame.cc", + "quiche/quic/core/frames/quic_crypto_frame.cc", + "quiche/quic/core/frames/quic_frame.cc", + "quiche/quic/core/frames/quic_goaway_frame.cc", + "quiche/quic/core/frames/quic_handshake_done_frame.cc", + "quiche/quic/core/frames/quic_max_streams_frame.cc", + "quiche/quic/core/frames/quic_message_frame.cc", + "quiche/quic/core/frames/quic_new_connection_id_frame.cc", + "quiche/quic/core/frames/quic_new_token_frame.cc", + "quiche/quic/core/frames/quic_padding_frame.cc", + "quiche/quic/core/frames/quic_path_challenge_frame.cc", + "quiche/quic/core/frames/quic_path_response_frame.cc", + "quiche/quic/core/frames/quic_ping_frame.cc", + "quiche/quic/core/frames/quic_retire_connection_id_frame.cc", + "quiche/quic/core/frames/quic_rst_stream_frame.cc", + "quiche/quic/core/frames/quic_stop_sending_frame.cc", + "quiche/quic/core/frames/quic_stop_waiting_frame.cc", + "quiche/quic/core/frames/quic_stream_frame.cc", + "quiche/quic/core/frames/quic_streams_blocked_frame.cc", + "quiche/quic/core/frames/quic_window_update_frame.cc", + "quiche/quic/core/http/http_constants.cc", + "quiche/quic/core/http/http_decoder.cc", + "quiche/quic/core/http/http_encoder.cc", + "quiche/quic/core/http/quic_client_promised_info.cc", + "quiche/quic/core/http/quic_client_push_promise_index.cc", + "quiche/quic/core/http/quic_header_list.cc", + "quiche/quic/core/http/quic_headers_stream.cc", + "quiche/quic/core/http/quic_receive_control_stream.cc", + "quiche/quic/core/http/quic_send_control_stream.cc", + "quiche/quic/core/http/quic_server_initiated_spdy_stream.cc", + "quiche/quic/core/http/quic_server_session_base.cc", + "quiche/quic/core/http/quic_spdy_client_session.cc", + "quiche/quic/core/http/quic_spdy_client_session_base.cc", + "quiche/quic/core/http/quic_spdy_client_stream.cc", + "quiche/quic/core/http/quic_spdy_server_stream_base.cc", + "quiche/quic/core/http/quic_spdy_session.cc", + "quiche/quic/core/http/quic_spdy_stream.cc", + "quiche/quic/core/http/quic_spdy_stream_body_manager.cc", + "quiche/quic/core/http/spdy_server_push_utils.cc", + "quiche/quic/core/http/spdy_utils.cc", + "quiche/quic/core/http/web_transport_http3.cc", + "quiche/quic/core/http/web_transport_stream_adapter.cc", + "quiche/quic/core/legacy_quic_stream_id_manager.cc", + "quiche/quic/core/qpack/qpack_blocking_manager.cc", + "quiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc", + "quiche/quic/core/qpack/qpack_decoder.cc", + "quiche/quic/core/qpack/qpack_decoder_stream_receiver.cc", + "quiche/quic/core/qpack/qpack_decoder_stream_sender.cc", + "quiche/quic/core/qpack/qpack_encoder.cc", + "quiche/quic/core/qpack/qpack_encoder_stream_receiver.cc", + "quiche/quic/core/qpack/qpack_encoder_stream_sender.cc", + "quiche/quic/core/qpack/qpack_header_table.cc", + "quiche/quic/core/qpack/qpack_index_conversions.cc", + "quiche/quic/core/qpack/qpack_instruction_decoder.cc", + "quiche/quic/core/qpack/qpack_instruction_encoder.cc", + "quiche/quic/core/qpack/qpack_instructions.cc", + "quiche/quic/core/qpack/qpack_progressive_decoder.cc", + "quiche/quic/core/qpack/qpack_receive_stream.cc", + "quiche/quic/core/qpack/qpack_required_insert_count.cc", + "quiche/quic/core/qpack/qpack_send_stream.cc", + "quiche/quic/core/qpack/qpack_static_table.cc", + "quiche/quic/core/qpack/value_splitting_header_list.cc", + "quiche/quic/core/quic_ack_listener_interface.cc", + "quiche/quic/core/quic_alarm.cc", + "quiche/quic/core/quic_bandwidth.cc", + "quiche/quic/core/quic_buffered_packet_store.cc", + "quiche/quic/core/quic_chaos_protector.cc", + "quiche/quic/core/quic_coalesced_packet.cc", + "quiche/quic/core/quic_config.cc", + "quiche/quic/core/quic_connection.cc", + "quiche/quic/core/quic_connection_context.cc", + "quiche/quic/core/quic_connection_id.cc", + "quiche/quic/core/quic_connection_id_manager.cc", + "quiche/quic/core/quic_connection_stats.cc", + "quiche/quic/core/quic_constants.cc", + "quiche/quic/core/quic_control_frame_manager.cc", + "quiche/quic/core/quic_crypto_client_handshaker.cc", + "quiche/quic/core/quic_crypto_client_stream.cc", + "quiche/quic/core/quic_crypto_handshaker.cc", + "quiche/quic/core/quic_crypto_server_stream.cc", + "quiche/quic/core/quic_crypto_server_stream_base.cc", + "quiche/quic/core/quic_crypto_stream.cc", + "quiche/quic/core/quic_data_reader.cc", + "quiche/quic/core/quic_data_writer.cc", + "quiche/quic/core/quic_datagram_queue.cc", + "quiche/quic/core/quic_default_clock.cc", + "quiche/quic/core/quic_dispatcher.cc", + "quiche/quic/core/quic_error_codes.cc", + "quiche/quic/core/quic_flow_controller.cc", + "quiche/quic/core/quic_framer.cc", + "quiche/quic/core/quic_idle_network_detector.cc", + "quiche/quic/core/quic_mtu_discovery.cc", + "quiche/quic/core/quic_network_blackhole_detector.cc", + "quiche/quic/core/quic_packet_creator.cc", + "quiche/quic/core/quic_packet_number.cc", + "quiche/quic/core/quic_packet_writer_wrapper.cc", + "quiche/quic/core/quic_packets.cc", + "quiche/quic/core/quic_path_validator.cc", + "quiche/quic/core/quic_ping_manager.cc", + "quiche/quic/core/quic_received_packet_manager.cc", + "quiche/quic/core/quic_sent_packet_manager.cc", + "quiche/quic/core/quic_server_id.cc", + "quiche/quic/core/quic_session.cc", + "quiche/quic/core/quic_socket_address_coder.cc", + "quiche/quic/core/quic_stream.cc", + "quiche/quic/core/quic_stream_id_manager.cc", + "quiche/quic/core/quic_stream_priority.cc", + "quiche/quic/core/quic_stream_send_buffer.cc", + "quiche/quic/core/quic_stream_sequencer.cc", + "quiche/quic/core/quic_stream_sequencer_buffer.cc", + "quiche/quic/core/quic_sustained_bandwidth_recorder.cc", + "quiche/quic/core/quic_tag.cc", + "quiche/quic/core/quic_time.cc", + "quiche/quic/core/quic_time_wait_list_manager.cc", + "quiche/quic/core/quic_trace_visitor.cc", + "quiche/quic/core/quic_transmission_info.cc", + "quiche/quic/core/quic_types.cc", + "quiche/quic/core/quic_unacked_packet_map.cc", + "quiche/quic/core/quic_utils.cc", + "quiche/quic/core/quic_version_manager.cc", + "quiche/quic/core/quic_versions.cc", + "quiche/quic/core/quic_write_blocked_list.cc", + "quiche/quic/core/tls_chlo_extractor.cc", + "quiche/quic/core/tls_client_handshaker.cc", + "quiche/quic/core/tls_handshaker.cc", + "quiche/quic/core/tls_server_handshaker.cc", + "quiche/quic/core/uber_quic_stream_id_manager.cc", + "quiche/quic/core/uber_received_packet_manager.cc", + "quiche/quic/platform/api/quic_socket_address.cc", + "quiche/spdy/core/array_output_buffer.cc", + "quiche/spdy/core/hpack/hpack_constants.cc", + "quiche/spdy/core/hpack/hpack_decoder_adapter.cc", + "quiche/spdy/core/hpack/hpack_encoder.cc", + "quiche/spdy/core/hpack/hpack_entry.cc", + "quiche/spdy/core/hpack/hpack_header_table.cc", + "quiche/spdy/core/hpack/hpack_output_stream.cc", + "quiche/spdy/core/hpack/hpack_static_table.cc", + "quiche/spdy/core/http2_frame_decoder_adapter.cc", + "quiche/spdy/core/http2_header_block.cc", + "quiche/spdy/core/http2_header_storage.cc", + "quiche/spdy/core/metadata_extension.cc", + "quiche/spdy/core/recording_headers_handler.cc", + "quiche/spdy/core/spdy_alt_svc_wire_format.cc", + "quiche/spdy/core/spdy_frame_builder.cc", + "quiche/spdy/core/spdy_framer.cc", + "quiche/spdy/core/spdy_no_op_visitor.cc", + "quiche/spdy/core/spdy_pinnable_buffer_piece.cc", + "quiche/spdy/core/spdy_prefixed_buffer_reader.cc", + "quiche/spdy/core/spdy_protocol.cc", + "quiche/spdy/core/spdy_simple_arena.cc" + ], + "quiche_tool_support_hdrs": [ + "quiche/common/platform/api/quiche_command_line_flags.h", + "quiche/common/platform/api/quiche_default_proof_providers.h", + "quiche/common/platform/api/quiche_file_utils.h", + "quiche/common/platform/api/quiche_system_event_loop.h", + "quiche/quic/platform/api/quic_default_proof_providers.h", + "quiche/quic/tools/connect_server_backend.h", + "quiche/quic/tools/connect_tunnel.h", + "quiche/quic/tools/connect_udp_tunnel.h", + "quiche/quic/tools/fake_proof_verifier.h", + "quiche/quic/tools/quic_backend_response.h", + "quiche/quic/tools/quic_client_base.h", + "quiche/quic/tools/quic_memory_cache_backend.h", + "quiche/quic/tools/quic_name_lookup.h", + "quiche/quic/tools/quic_simple_client_session.h", + "quiche/quic/tools/quic_simple_client_stream.h", + "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h", + "quiche/quic/tools/quic_simple_dispatcher.h", + "quiche/quic/tools/quic_simple_server_backend.h", + "quiche/quic/tools/quic_simple_server_session.h", + "quiche/quic/tools/quic_simple_server_stream.h", + "quiche/quic/tools/quic_spdy_client_base.h", + "quiche/quic/tools/quic_spdy_server_base.h", + "quiche/quic/tools/quic_tcp_like_trace_converter.h", + "quiche/quic/tools/quic_url.h", + "quiche/quic/tools/simple_ticket_crypter.h", + "quiche/quic/tools/web_transport_test_visitors.h" + ], + "quiche_tool_support_srcs": [ + "quiche/common/platform/api/quiche_file_utils.cc", + "quiche/quic/tools/connect_server_backend.cc", + "quiche/quic/tools/connect_tunnel.cc", + "quiche/quic/tools/connect_udp_tunnel.cc", + "quiche/quic/tools/quic_backend_response.cc", + "quiche/quic/tools/quic_client_base.cc", + "quiche/quic/tools/quic_memory_cache_backend.cc", + "quiche/quic/tools/quic_name_lookup.cc", + "quiche/quic/tools/quic_simple_client_session.cc", + "quiche/quic/tools/quic_simple_client_stream.cc", + "quiche/quic/tools/quic_simple_crypto_server_stream_helper.cc", + "quiche/quic/tools/quic_simple_dispatcher.cc", + "quiche/quic/tools/quic_simple_server_session.cc", + "quiche/quic/tools/quic_simple_server_stream.cc", + "quiche/quic/tools/quic_spdy_client_base.cc", + "quiche/quic/tools/quic_tcp_like_trace_converter.cc", + "quiche/quic/tools/quic_url.cc", + "quiche/quic/tools/simple_ticket_crypter.cc" + ], + "quiche_test_support_hdrs": [ + "quiche/common/platform/api/quiche_expect_bug.h", + "quiche/common/platform/api/quiche_test.h", + "quiche/common/platform/api/quiche_test_loopback.h", + "quiche/common/platform/api/quiche_test_output.h", + "quiche/common/test_tools/quiche_test_utils.h", + "quiche/http2/adapter/mock_http2_visitor.h", + "quiche/http2/adapter/recording_http2_visitor.h", + "quiche/http2/adapter/test_frame_sequence.h", + "quiche/http2/adapter/test_utils.h", + "quiche/http2/test_tools/frame_decoder_state_test_util.h", + "quiche/http2/test_tools/frame_parts.h", + "quiche/http2/test_tools/frame_parts_collector.h", + "quiche/http2/test_tools/frame_parts_collector_listener.h", + "quiche/http2/test_tools/hpack_block_builder.h", + "quiche/http2/test_tools/hpack_block_collector.h", + "quiche/http2/test_tools/hpack_entry_collector.h", + "quiche/http2/test_tools/hpack_example.h", + "quiche/http2/test_tools/hpack_string_collector.h", + "quiche/http2/test_tools/http2_constants_test_util.h", + "quiche/http2/test_tools/http2_frame_builder.h", + "quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h", + "quiche/http2/test_tools/http2_random.h", + "quiche/http2/test_tools/http2_structure_decoder_test_util.h", + "quiche/http2/test_tools/http2_structures_test_util.h", + "quiche/http2/test_tools/payload_decoder_base_test_util.h", + "quiche/http2/test_tools/random_decoder_test_base.h", + "quiche/http2/test_tools/random_util.h", + "quiche/http2/test_tools/verify_macros.h", + "quiche/quic/platform/api/quic_expect_bug.h", + "quiche/quic/platform/api/quic_test.h", + "quiche/quic/platform/api/quic_test_loopback.h", + "quiche/quic/platform/api/quic_test_output.h", + "quiche/quic/test_tools/bad_packet_writer.h", + "quiche/quic/test_tools/crypto_test_utils.h", + "quiche/quic/test_tools/failing_proof_source.h", + "quiche/quic/test_tools/fake_proof_source.h", + "quiche/quic/test_tools/fake_proof_source_handle.h", + "quiche/quic/test_tools/first_flight.h", + "quiche/quic/test_tools/limited_mtu_test_writer.h", + "quiche/quic/test_tools/mock_clock.h", + "quiche/quic/test_tools/mock_connection_id_generator.h", + "quiche/quic/test_tools/mock_quic_client_promised_info.h", + "quiche/quic/test_tools/mock_quic_dispatcher.h", + "quiche/quic/test_tools/mock_quic_session_visitor.h", + "quiche/quic/test_tools/mock_quic_spdy_client_stream.h", + "quiche/quic/test_tools/mock_quic_time_wait_list_manager.h", + "quiche/quic/test_tools/mock_random.h", + "quiche/quic/test_tools/packet_dropping_test_writer.h", + "quiche/quic/test_tools/packet_reordering_writer.h", + "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h", + "quiche/quic/test_tools/qpack/qpack_encoder_peer.h", + "quiche/quic/test_tools/qpack/qpack_offline_decoder.h", + "quiche/quic/test_tools/qpack/qpack_test_utils.h", + "quiche/quic/test_tools/quic_buffered_packet_store_peer.h", + "quiche/quic/test_tools/quic_client_promised_info_peer.h", + "quiche/quic/test_tools/quic_client_session_cache_peer.h", + "quiche/quic/test_tools/quic_coalesced_packet_peer.h", + "quiche/quic/test_tools/quic_config_peer.h", + "quiche/quic/test_tools/quic_connection_id_manager_peer.h", + "quiche/quic/test_tools/quic_connection_peer.h", + "quiche/quic/test_tools/quic_crypto_server_config_peer.h", + "quiche/quic/test_tools/quic_dispatcher_peer.h", + "quiche/quic/test_tools/quic_flow_controller_peer.h", + "quiche/quic/test_tools/quic_framer_peer.h", + "quiche/quic/test_tools/quic_interval_deque_peer.h", + "quiche/quic/test_tools/quic_packet_creator_peer.h", + "quiche/quic/test_tools/quic_path_validator_peer.h", + "quiche/quic/test_tools/quic_sent_packet_manager_peer.h", + "quiche/quic/test_tools/quic_server_session_base_peer.h", + "quiche/quic/test_tools/quic_session_peer.h", + "quiche/quic/test_tools/quic_spdy_session_peer.h", + "quiche/quic/test_tools/quic_spdy_stream_peer.h", + "quiche/quic/test_tools/quic_stream_id_manager_peer.h", + "quiche/quic/test_tools/quic_stream_peer.h", + "quiche/quic/test_tools/quic_stream_send_buffer_peer.h", + "quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h", + "quiche/quic/test_tools/quic_stream_sequencer_peer.h", + "quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h", + "quiche/quic/test_tools/quic_test_backend.h", + "quiche/quic/test_tools/quic_test_utils.h", + "quiche/quic/test_tools/quic_time_wait_list_manager_peer.h", + "quiche/quic/test_tools/quic_unacked_packet_map_peer.h", + "quiche/quic/test_tools/rtt_stats_peer.h", + "quiche/quic/test_tools/send_algorithm_test_utils.h", + "quiche/quic/test_tools/simple_data_producer.h", + "quiche/quic/test_tools/simple_quic_framer.h", + "quiche/quic/test_tools/simple_session_cache.h", + "quiche/quic/test_tools/simple_session_notifier.h", + "quiche/quic/test_tools/simulator/actor.h", + "quiche/quic/test_tools/simulator/alarm_factory.h", + "quiche/quic/test_tools/simulator/link.h", + "quiche/quic/test_tools/simulator/packet_filter.h", + "quiche/quic/test_tools/simulator/port.h", + "quiche/quic/test_tools/simulator/queue.h", + "quiche/quic/test_tools/simulator/quic_endpoint.h", + "quiche/quic/test_tools/simulator/quic_endpoint_base.h", + "quiche/quic/test_tools/simulator/simulator.h", + "quiche/quic/test_tools/simulator/switch.h", + "quiche/quic/test_tools/simulator/test_harness.h", + "quiche/quic/test_tools/simulator/traffic_policer.h", + "quiche/quic/test_tools/test_certificates.h", + "quiche/quic/test_tools/test_ticket_crypter.h", + "quiche/quic/test_tools/web_transport_resets_backend.h", + "quiche/quic/test_tools/web_transport_test_tools.h", + "quiche/spdy/test_tools/mock_spdy_framer_visitor.h", + "quiche/spdy/test_tools/spdy_test_utils.h", + "quiche/web_transport/test_tools/mock_web_transport.h" + ], + "quiche_test_support_srcs": [ + "quiche/common/platform/api/quiche_test_loopback.cc", + "quiche/common/test_tools/quiche_test_utils.cc", + "quiche/http2/adapter/recording_http2_visitor.cc", + "quiche/http2/adapter/test_frame_sequence.cc", + "quiche/http2/adapter/test_utils.cc", + "quiche/http2/test_tools/frame_decoder_state_test_util.cc", + "quiche/http2/test_tools/frame_parts.cc", + "quiche/http2/test_tools/frame_parts_collector.cc", + "quiche/http2/test_tools/frame_parts_collector_listener.cc", + "quiche/http2/test_tools/hpack_block_builder.cc", + "quiche/http2/test_tools/hpack_block_collector.cc", + "quiche/http2/test_tools/hpack_entry_collector.cc", + "quiche/http2/test_tools/hpack_example.cc", + "quiche/http2/test_tools/hpack_string_collector.cc", + "quiche/http2/test_tools/http2_constants_test_util.cc", + "quiche/http2/test_tools/http2_frame_builder.cc", + "quiche/http2/test_tools/http2_frame_decoder_listener_test_util.cc", + "quiche/http2/test_tools/http2_random.cc", + "quiche/http2/test_tools/http2_structure_decoder_test_util.cc", + "quiche/http2/test_tools/http2_structures_test_util.cc", + "quiche/http2/test_tools/payload_decoder_base_test_util.cc", + "quiche/http2/test_tools/random_decoder_test_base.cc", + "quiche/http2/test_tools/random_util.cc", + "quiche/quic/test_tools/bad_packet_writer.cc", + "quiche/quic/test_tools/crypto_test_utils.cc", + "quiche/quic/test_tools/failing_proof_source.cc", + "quiche/quic/test_tools/fake_proof_source.cc", + "quiche/quic/test_tools/fake_proof_source_handle.cc", + "quiche/quic/test_tools/first_flight.cc", + "quiche/quic/test_tools/limited_mtu_test_writer.cc", + "quiche/quic/test_tools/mock_clock.cc", + "quiche/quic/test_tools/mock_quic_client_promised_info.cc", + "quiche/quic/test_tools/mock_quic_dispatcher.cc", + "quiche/quic/test_tools/mock_quic_session_visitor.cc", + "quiche/quic/test_tools/mock_quic_spdy_client_stream.cc", + "quiche/quic/test_tools/mock_quic_time_wait_list_manager.cc", + "quiche/quic/test_tools/mock_random.cc", + "quiche/quic/test_tools/packet_dropping_test_writer.cc", + "quiche/quic/test_tools/packet_reordering_writer.cc", + "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc", + "quiche/quic/test_tools/qpack/qpack_encoder_peer.cc", + "quiche/quic/test_tools/qpack/qpack_offline_decoder.cc", + "quiche/quic/test_tools/qpack/qpack_test_utils.cc", + "quiche/quic/test_tools/quic_buffered_packet_store_peer.cc", + "quiche/quic/test_tools/quic_client_promised_info_peer.cc", + "quiche/quic/test_tools/quic_coalesced_packet_peer.cc", + "quiche/quic/test_tools/quic_config_peer.cc", + "quiche/quic/test_tools/quic_connection_peer.cc", + "quiche/quic/test_tools/quic_crypto_server_config_peer.cc", + "quiche/quic/test_tools/quic_dispatcher_peer.cc", + "quiche/quic/test_tools/quic_flow_controller_peer.cc", + "quiche/quic/test_tools/quic_framer_peer.cc", + "quiche/quic/test_tools/quic_packet_creator_peer.cc", + "quiche/quic/test_tools/quic_path_validator_peer.cc", + "quiche/quic/test_tools/quic_sent_packet_manager_peer.cc", + "quiche/quic/test_tools/quic_session_peer.cc", + "quiche/quic/test_tools/quic_spdy_session_peer.cc", + "quiche/quic/test_tools/quic_spdy_stream_peer.cc", + "quiche/quic/test_tools/quic_stream_id_manager_peer.cc", + "quiche/quic/test_tools/quic_stream_peer.cc", + "quiche/quic/test_tools/quic_stream_send_buffer_peer.cc", + "quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc", + "quiche/quic/test_tools/quic_stream_sequencer_peer.cc", + "quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc", + "quiche/quic/test_tools/quic_test_backend.cc", + "quiche/quic/test_tools/quic_test_utils.cc", + "quiche/quic/test_tools/quic_time_wait_list_manager_peer.cc", + "quiche/quic/test_tools/quic_unacked_packet_map_peer.cc", + "quiche/quic/test_tools/rtt_stats_peer.cc", + "quiche/quic/test_tools/send_algorithm_test_utils.cc", + "quiche/quic/test_tools/simple_data_producer.cc", + "quiche/quic/test_tools/simple_quic_framer.cc", + "quiche/quic/test_tools/simple_session_cache.cc", + "quiche/quic/test_tools/simple_session_notifier.cc", + "quiche/quic/test_tools/simulator/actor.cc", + "quiche/quic/test_tools/simulator/alarm_factory.cc", + "quiche/quic/test_tools/simulator/link.cc", + "quiche/quic/test_tools/simulator/packet_filter.cc", + "quiche/quic/test_tools/simulator/port.cc", + "quiche/quic/test_tools/simulator/queue.cc", + "quiche/quic/test_tools/simulator/quic_endpoint.cc", + "quiche/quic/test_tools/simulator/quic_endpoint_base.cc", + "quiche/quic/test_tools/simulator/simulator.cc", + "quiche/quic/test_tools/simulator/switch.cc", + "quiche/quic/test_tools/simulator/test_harness.cc", + "quiche/quic/test_tools/simulator/traffic_policer.cc", + "quiche/quic/test_tools/test_certificates.cc", + "quiche/quic/test_tools/test_ticket_crypter.cc", + "quiche/quic/test_tools/web_transport_resets_backend.cc", + "quiche/spdy/test_tools/mock_spdy_framer_visitor.cc", + "quiche/spdy/test_tools/spdy_test_utils.cc" + ], + "io_tool_support_hdrs": [ + "quiche/common/platform/api/quiche_event_loop.h", + "quiche/common/platform/api/quiche_udp_socket_platform_api.h", + "quiche/quic/core/io/event_loop_connecting_client_socket.h", + "quiche/quic/core/io/event_loop_socket_factory.h", + "quiche/quic/core/io/quic_default_event_loop.h", + "quiche/quic/core/io/quic_event_loop.h", + "quiche/quic/core/io/quic_poll_event_loop.h", + "quiche/quic/core/io/socket.h", + "quiche/quic/core/quic_default_packet_writer.h", + "quiche/quic/core/quic_packet_reader.h", + "quiche/quic/core/quic_syscall_wrapper.h", + "quiche/quic/core/quic_udp_socket.h", + "quiche/quic/masque/masque_client.h", + "quiche/quic/masque/masque_client_session.h", + "quiche/quic/masque/masque_client_tools.h", + "quiche/quic/masque/masque_dispatcher.h", + "quiche/quic/masque/masque_encapsulated_client.h", + "quiche/quic/masque/masque_encapsulated_client_session.h", + "quiche/quic/masque/masque_server.h", + "quiche/quic/masque/masque_server_backend.h", + "quiche/quic/masque/masque_server_session.h", + "quiche/quic/masque/masque_utils.h", + "quiche/quic/platform/api/quic_udp_socket_platform_api.h", + "quiche/quic/tools/quic_client_default_network_helper.h", + "quiche/quic/tools/quic_client_factory.h", + "quiche/quic/tools/quic_default_client.h", + "quiche/quic/tools/quic_epoll_client_factory.h", + "quiche/quic/tools/quic_server.h" + ], + "io_tool_support_srcs": [ + "quiche/quic/core/io/event_loop_connecting_client_socket.cc", + "quiche/quic/core/io/event_loop_socket_factory.cc", + "quiche/quic/core/io/quic_default_event_loop.cc", + "quiche/quic/core/io/quic_poll_event_loop.cc", + "quiche/quic/core/io/socket_posix.cc", + "quiche/quic/core/quic_default_packet_writer.cc", + "quiche/quic/core/quic_packet_reader.cc", + "quiche/quic/core/quic_syscall_wrapper.cc", + "quiche/quic/core/quic_udp_socket_posix.cc", + "quiche/quic/masque/masque_client.cc", + "quiche/quic/masque/masque_client_session.cc", + "quiche/quic/masque/masque_client_tools.cc", + "quiche/quic/masque/masque_dispatcher.cc", + "quiche/quic/masque/masque_encapsulated_client.cc", + "quiche/quic/masque/masque_encapsulated_client_session.cc", + "quiche/quic/masque/masque_server.cc", + "quiche/quic/masque/masque_server_backend.cc", + "quiche/quic/masque/masque_server_session.cc", + "quiche/quic/masque/masque_utils.cc", + "quiche/quic/tools/quic_client_default_network_helper.cc", + "quiche/quic/tools/quic_default_client.cc", + "quiche/quic/tools/quic_epoll_client_factory.cc", + "quiche/quic/tools/quic_server.cc" + ], + "io_test_support_hdrs": [ + "quiche/quic/test_tools/quic_mock_syscall_wrapper.h", + "quiche/quic/test_tools/quic_server_peer.h", + "quiche/quic/test_tools/quic_test_client.h", + "quiche/quic/test_tools/quic_test_server.h", + "quiche/quic/test_tools/server_thread.h" + ], + "io_test_support_srcs": [ + "quiche/quic/test_tools/quic_mock_syscall_wrapper.cc", + "quiche/quic/test_tools/quic_server_peer.cc", + "quiche/quic/test_tools/quic_test_client.cc", + "quiche/quic/test_tools/quic_test_server.cc", + "quiche/quic/test_tools/server_thread.cc" + ], + "quiche_tests_hdrs": [ + + ], + "quiche_tests_srcs": [ + "quiche/balsa/balsa_frame_test.cc", + "quiche/balsa/balsa_headers_test.cc", + "quiche/balsa/header_properties_test.cc", + "quiche/balsa/simple_buffer_test.cc", + "quiche/binary_http/binary_http_message_test.cc", + "quiche/common/btree_scheduler_test.cc", + "quiche/common/capsule_test.cc", + "quiche/common/masque/connect_udp_datagram_payload_test.cc", + "quiche/common/platform/api/quiche_file_utils_test.cc", + "quiche/common/platform/api/quiche_hostname_utils_test.cc", + "quiche/common/platform/api/quiche_lower_case_string_test.cc", + "quiche/common/platform/api/quiche_mem_slice_test.cc", + "quiche/common/platform/api/quiche_reference_counted_test.cc", + "quiche/common/platform/api/quiche_stack_trace_test.cc", + "quiche/common/platform/api/quiche_time_utils_test.cc", + "quiche/common/platform/api/quiche_url_utils_test.cc", + "quiche/common/print_elements_test.cc", + "quiche/common/quiche_buffer_allocator_test.cc", + "quiche/common/quiche_circular_deque_test.cc", + "quiche/common/quiche_data_reader_test.cc", + "quiche/common/quiche_data_writer_test.cc", + "quiche/common/quiche_endian_test.cc", + "quiche/common/quiche_ip_address_test.cc", + "quiche/common/quiche_linked_hash_map_test.cc", + "quiche/common/quiche_mem_slice_storage_test.cc", + "quiche/common/quiche_random_test.cc", + "quiche/common/quiche_text_utils_test.cc", + "quiche/common/simple_buffer_allocator_test.cc", + "quiche/common/structured_headers_generated_test.cc", + "quiche/common/structured_headers_test.cc", + "quiche/common/test_tools/quiche_test_utils_test.cc", + "quiche/common/wire_serialization_test.cc", + "quiche/http2/adapter/event_forwarder_test.cc", + "quiche/http2/adapter/header_validator_test.cc", + "quiche/http2/adapter/noop_header_validator_test.cc", + "quiche/http2/adapter/oghttp2_adapter_test.cc", + "quiche/http2/adapter/oghttp2_session_test.cc", + "quiche/http2/adapter/oghttp2_util_test.cc", + "quiche/http2/adapter/recording_http2_visitor_test.cc", + "quiche/http2/adapter/test_utils_test.cc", + "quiche/http2/adapter/window_manager_test.cc", + "quiche/http2/core/priority_write_scheduler_test.cc", + "quiche/http2/decoder/decode_buffer_test.cc", + "quiche/http2/decoder/decode_http2_structures_test.cc", + "quiche/http2/decoder/http2_frame_decoder_test.cc", + "quiche/http2/decoder/http2_structure_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/continuation_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/data_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/goaway_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/headers_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/ping_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/priority_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/settings_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/unknown_payload_decoder_test.cc", + "quiche/http2/decoder/payload_decoders/window_update_payload_decoder_test.cc", + "quiche/http2/hpack/decoder/hpack_block_collector_test.cc", + "quiche/http2/hpack/decoder/hpack_block_decoder_test.cc", + "quiche/http2/hpack/decoder/hpack_decoder_state_test.cc", + "quiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc", + "quiche/http2/hpack/decoder/hpack_decoder_tables_test.cc", + "quiche/http2/hpack/decoder/hpack_decoder_test.cc", + "quiche/http2/hpack/decoder/hpack_entry_collector_test.cc", + "quiche/http2/hpack/decoder/hpack_entry_decoder_test.cc", + "quiche/http2/hpack/decoder/hpack_entry_type_decoder_test.cc", + "quiche/http2/hpack/decoder/hpack_string_decoder_test.cc", + "quiche/http2/hpack/decoder/hpack_whole_entry_buffer_test.cc", + "quiche/http2/hpack/http2_hpack_constants_test.cc", + "quiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc", + "quiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc", + "quiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc", + "quiche/http2/hpack/varint/hpack_varint_decoder_test.cc", + "quiche/http2/hpack/varint/hpack_varint_encoder_test.cc", + "quiche/http2/hpack/varint/hpack_varint_round_trip_test.cc", + "quiche/http2/http2_constants_test.cc", + "quiche/http2/http2_structures_test.cc", + "quiche/http2/test_tools/hpack_block_builder_test.cc", + "quiche/http2/test_tools/hpack_example_test.cc", + "quiche/http2/test_tools/http2_frame_builder_test.cc", + "quiche/http2/test_tools/http2_random_test.cc", + "quiche/http2/test_tools/random_decoder_test_base_test.cc", + "quiche/oblivious_http/buffers/oblivious_http_integration_test.cc", + "quiche/oblivious_http/buffers/oblivious_http_request_test.cc", + "quiche/oblivious_http/buffers/oblivious_http_response_test.cc", + "quiche/oblivious_http/common/oblivious_http_header_key_config_test.cc", + "quiche/oblivious_http/oblivious_http_client_test.cc", + "quiche/oblivious_http/oblivious_http_gateway_test.cc", + "quiche/quic/core/congestion_control/bandwidth_sampler_test.cc", + "quiche/quic/core/congestion_control/bbr2_simulator_test.cc", + "quiche/quic/core/congestion_control/bbr_sender_test.cc", + "quiche/quic/core/congestion_control/cubic_bytes_test.cc", + "quiche/quic/core/congestion_control/general_loss_algorithm_test.cc", + "quiche/quic/core/congestion_control/hybrid_slow_start_test.cc", + "quiche/quic/core/congestion_control/pacing_sender_test.cc", + "quiche/quic/core/congestion_control/prr_sender_test.cc", + "quiche/quic/core/congestion_control/rtt_stats_test.cc", + "quiche/quic/core/congestion_control/send_algorithm_test.cc", + "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc", + "quiche/quic/core/congestion_control/uber_loss_algorithm_test.cc", + "quiche/quic/core/congestion_control/windowed_filter_test.cc", + "quiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc", + "quiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc", + "quiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc", + "quiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc", + "quiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc", + "quiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc", + "quiche/quic/core/crypto/cert_compressor_test.cc", + "quiche/quic/core/crypto/certificate_util_test.cc", + "quiche/quic/core/crypto/certificate_view_test.cc", + "quiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc", + "quiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc", + "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc", + "quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc", + "quiche/quic/core/crypto/channel_id_test.cc", + "quiche/quic/core/crypto/client_proof_source_test.cc", + "quiche/quic/core/crypto/crypto_framer_test.cc", + "quiche/quic/core/crypto/crypto_handshake_message_test.cc", + "quiche/quic/core/crypto/crypto_secret_boxer_test.cc", + "quiche/quic/core/crypto/crypto_server_test.cc", + "quiche/quic/core/crypto/crypto_utils_test.cc", + "quiche/quic/core/crypto/curve25519_key_exchange_test.cc", + "quiche/quic/core/crypto/null_decrypter_test.cc", + "quiche/quic/core/crypto/null_encrypter_test.cc", + "quiche/quic/core/crypto/p256_key_exchange_test.cc", + "quiche/quic/core/crypto/proof_source_x509_test.cc", + "quiche/quic/core/crypto/quic_client_session_cache_test.cc", + "quiche/quic/core/crypto/quic_compressed_certs_cache_test.cc", + "quiche/quic/core/crypto/quic_crypto_client_config_test.cc", + "quiche/quic/core/crypto/quic_crypto_server_config_test.cc", + "quiche/quic/core/crypto/quic_hkdf_test.cc", + "quiche/quic/core/crypto/transport_parameters_test.cc", + "quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc", + "quiche/quic/core/deterministic_connection_id_generator_test.cc", + "quiche/quic/core/frames/quic_frames_test.cc", + "quiche/quic/core/http/http_decoder_test.cc", + "quiche/quic/core/http/http_encoder_test.cc", + "quiche/quic/core/http/http_frames_test.cc", + "quiche/quic/core/http/quic_client_promised_info_test.cc", + "quiche/quic/core/http/quic_client_push_promise_index_test.cc", + "quiche/quic/core/http/quic_header_list_test.cc", + "quiche/quic/core/http/quic_headers_stream_test.cc", + "quiche/quic/core/http/quic_receive_control_stream_test.cc", + "quiche/quic/core/http/quic_send_control_stream_test.cc", + "quiche/quic/core/http/quic_server_session_base_test.cc", + "quiche/quic/core/http/quic_spdy_session_test.cc", + "quiche/quic/core/http/quic_spdy_stream_body_manager_test.cc", + "quiche/quic/core/http/quic_spdy_stream_test.cc", + "quiche/quic/core/http/spdy_server_push_utils_test.cc", + "quiche/quic/core/http/spdy_utils_test.cc", + "quiche/quic/core/http/web_transport_http3_test.cc", + "quiche/quic/core/legacy_quic_stream_id_manager_test.cc", + "quiche/quic/core/packet_number_indexed_queue_test.cc", + "quiche/quic/core/qpack/qpack_blocking_manager_test.cc", + "quiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc", + "quiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc", + "quiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc", + "quiche/quic/core/qpack/qpack_decoder_test.cc", + "quiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc", + "quiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc", + "quiche/quic/core/qpack/qpack_encoder_test.cc", + "quiche/quic/core/qpack/qpack_header_table_test.cc", + "quiche/quic/core/qpack/qpack_index_conversions_test.cc", + "quiche/quic/core/qpack/qpack_instruction_decoder_test.cc", + "quiche/quic/core/qpack/qpack_instruction_encoder_test.cc", + "quiche/quic/core/qpack/qpack_receive_stream_test.cc", + "quiche/quic/core/qpack/qpack_required_insert_count_test.cc", + "quiche/quic/core/qpack/qpack_round_trip_test.cc", + "quiche/quic/core/qpack/qpack_send_stream_test.cc", + "quiche/quic/core/qpack/qpack_static_table_test.cc", + "quiche/quic/core/qpack/value_splitting_header_list_test.cc", + "quiche/quic/core/quic_alarm_test.cc", + "quiche/quic/core/quic_arena_scoped_ptr_test.cc", + "quiche/quic/core/quic_bandwidth_test.cc", + "quiche/quic/core/quic_buffered_packet_store_test.cc", + "quiche/quic/core/quic_chaos_protector_test.cc", + "quiche/quic/core/quic_coalesced_packet_test.cc", + "quiche/quic/core/quic_config_test.cc", + "quiche/quic/core/quic_connection_context_test.cc", + "quiche/quic/core/quic_connection_id_manager_test.cc", + "quiche/quic/core/quic_connection_id_test.cc", + "quiche/quic/core/quic_connection_test.cc", + "quiche/quic/core/quic_control_frame_manager_test.cc", + "quiche/quic/core/quic_crypto_client_handshaker_test.cc", + "quiche/quic/core/quic_crypto_client_stream_test.cc", + "quiche/quic/core/quic_crypto_server_stream_test.cc", + "quiche/quic/core/quic_crypto_stream_test.cc", + "quiche/quic/core/quic_data_writer_test.cc", + "quiche/quic/core/quic_datagram_queue_test.cc", + "quiche/quic/core/quic_dispatcher_test.cc", + "quiche/quic/core/quic_error_codes_test.cc", + "quiche/quic/core/quic_flow_controller_test.cc", + "quiche/quic/core/quic_framer_test.cc", + "quiche/quic/core/quic_idle_network_detector_test.cc", + "quiche/quic/core/quic_interval_deque_test.cc", + "quiche/quic/core/quic_interval_set_test.cc", + "quiche/quic/core/quic_interval_test.cc", + "quiche/quic/core/quic_lru_cache_test.cc", + "quiche/quic/core/quic_network_blackhole_detector_test.cc", + "quiche/quic/core/quic_one_block_arena_test.cc", + "quiche/quic/core/quic_packet_creator_test.cc", + "quiche/quic/core/quic_packet_number_test.cc", + "quiche/quic/core/quic_packets_test.cc", + "quiche/quic/core/quic_path_validator_test.cc", + "quiche/quic/core/quic_ping_manager_test.cc", + "quiche/quic/core/quic_received_packet_manager_test.cc", + "quiche/quic/core/quic_sent_packet_manager_test.cc", + "quiche/quic/core/quic_server_id_test.cc", + "quiche/quic/core/quic_session_test.cc", + "quiche/quic/core/quic_socket_address_coder_test.cc", + "quiche/quic/core/quic_stream_id_manager_test.cc", + "quiche/quic/core/quic_stream_priority_test.cc", + "quiche/quic/core/quic_stream_send_buffer_test.cc", + "quiche/quic/core/quic_stream_sequencer_buffer_test.cc", + "quiche/quic/core/quic_stream_sequencer_test.cc", + "quiche/quic/core/quic_stream_test.cc", + "quiche/quic/core/quic_sustained_bandwidth_recorder_test.cc", + "quiche/quic/core/quic_tag_test.cc", + "quiche/quic/core/quic_time_accumulator_test.cc", + "quiche/quic/core/quic_time_test.cc", + "quiche/quic/core/quic_time_wait_list_manager_test.cc", + "quiche/quic/core/quic_trace_visitor_test.cc", + "quiche/quic/core/quic_unacked_packet_map_test.cc", + "quiche/quic/core/quic_utils_test.cc", + "quiche/quic/core/quic_version_manager_test.cc", + "quiche/quic/core/quic_versions_test.cc", + "quiche/quic/core/quic_write_blocked_list_test.cc", + "quiche/quic/core/tls_chlo_extractor_test.cc", + "quiche/quic/core/tls_client_handshaker_test.cc", + "quiche/quic/core/tls_server_handshaker_test.cc", + "quiche/quic/core/uber_quic_stream_id_manager_test.cc", + "quiche/quic/core/uber_received_packet_manager_test.cc", + "quiche/quic/platform/api/quic_socket_address_test.cc", + "quiche/quic/test_tools/crypto_test_utils_test.cc", + "quiche/quic/test_tools/quic_test_utils_test.cc", + "quiche/quic/test_tools/simple_session_notifier_test.cc", + "quiche/quic/test_tools/simulator/quic_endpoint_test.cc", + "quiche/quic/test_tools/simulator/simulator_test.cc", + "quiche/quic/tools/connect_tunnel_test.cc", + "quiche/quic/tools/connect_udp_tunnel_test.cc", + "quiche/quic/tools/quic_memory_cache_backend_test.cc", + "quiche/quic/tools/quic_tcp_like_trace_converter_test.cc", + "quiche/quic/tools/simple_ticket_crypter_test.cc", + "quiche/spdy/core/array_output_buffer_test.cc", + "quiche/spdy/core/hpack/hpack_decoder_adapter_test.cc", + "quiche/spdy/core/hpack/hpack_encoder_test.cc", + "quiche/spdy/core/hpack/hpack_entry_test.cc", + "quiche/spdy/core/hpack/hpack_header_table_test.cc", + "quiche/spdy/core/hpack/hpack_output_stream_test.cc", + "quiche/spdy/core/hpack/hpack_round_trip_test.cc", + "quiche/spdy/core/hpack/hpack_static_table_test.cc", + "quiche/spdy/core/http2_header_block_test.cc", + "quiche/spdy/core/http2_header_storage_test.cc", + "quiche/spdy/core/metadata_extension_test.cc", + "quiche/spdy/core/spdy_alt_svc_wire_format_test.cc", + "quiche/spdy/core/spdy_frame_builder_test.cc", + "quiche/spdy/core/spdy_framer_test.cc", + "quiche/spdy/core/spdy_intrusive_list_test.cc", + "quiche/spdy/core/spdy_pinnable_buffer_piece_test.cc", + "quiche/spdy/core/spdy_prefixed_buffer_reader_test.cc", + "quiche/spdy/core/spdy_protocol_test.cc", + "quiche/spdy/core/spdy_simple_arena_test.cc" + ], + "io_tests_hdrs": [ + + ], + "io_tests_srcs": [ + "quiche/quic/core/chlo_extractor_test.cc", + "quiche/quic/core/http/end_to_end_test.cc", + "quiche/quic/core/http/quic_spdy_client_session_test.cc", + "quiche/quic/core/http/quic_spdy_client_stream_test.cc", + "quiche/quic/core/http/quic_spdy_server_stream_base_test.cc", + "quiche/quic/core/io/event_loop_connecting_client_socket_test.cc", + "quiche/quic/core/io/quic_all_event_loops_test.cc", + "quiche/quic/core/io/quic_poll_event_loop_test.cc", + "quiche/quic/core/io/socket_test.cc", + "quiche/quic/tools/quic_default_client_test.cc", + "quiche/quic/tools/quic_server_test.cc", + "quiche/quic/tools/quic_simple_server_session_test.cc", + "quiche/quic/tools/quic_simple_server_stream_test.cc", + "quiche/quic/tools/quic_url_test.cc" + ], + "fuzzers_hdrs": [ + + ], + "fuzzers_srcs": [ + "quiche/common/structured_headers_fuzzer.cc", + "quiche/quic/core/crypto/certificate_view_der_fuzzer.cc", + "quiche/quic/core/crypto/certificate_view_pem_fuzzer.cc", + "quiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc", + "quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc", + "quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc", + "quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc", + "quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc", + "quiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc", + "quiche/quic/test_tools/fuzzing/quic_framer_fuzzer.cc", + "quiche/quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc" + ], + "cli_tools_hdrs": [ + "quiche/quic/tools/quic_server_factory.h", + "quiche/quic/tools/quic_toy_client.h", + "quiche/quic/tools/quic_toy_server.h" + ], + "cli_tools_srcs": [ + "quiche/quic/masque/masque_client_bin.cc", + "quiche/quic/masque/masque_server_bin.cc", + "quiche/quic/tools/crypto_message_printer_bin.cc", + "quiche/quic/tools/qpack_offline_decoder_bin.cc", + "quiche/quic/tools/quic_client_bin.cc", + "quiche/quic/tools/quic_client_interop_test_bin.cc", + "quiche/quic/tools/quic_packet_printer_bin.cc", + "quiche/quic/tools/quic_reject_reason_decoder_bin.cc", + "quiche/quic/tools/quic_server_bin.cc", + "quiche/quic/tools/quic_server_factory.cc", + "quiche/quic/tools/quic_toy_client.cc", + "quiche/quic/tools/quic_toy_server.cc" + ], + "nghttp2_hdrs": [ + "quiche/http2/adapter/callback_visitor.h", + "quiche/http2/adapter/nghttp2.h", + "quiche/http2/adapter/nghttp2_adapter.h", + "quiche/http2/adapter/nghttp2_callbacks.h", + "quiche/http2/adapter/nghttp2_data_provider.h", + "quiche/http2/adapter/nghttp2_session.h", + "quiche/http2/adapter/nghttp2_util.h" + ], + "nghttp2_srcs": [ + "quiche/http2/adapter/callback_visitor.cc", + "quiche/http2/adapter/nghttp2_adapter.cc", + "quiche/http2/adapter/nghttp2_callbacks.cc", + "quiche/http2/adapter/nghttp2_data_provider.cc", + "quiche/http2/adapter/nghttp2_session.cc", + "quiche/http2/adapter/nghttp2_test.cc", + "quiche/http2/adapter/nghttp2_util.cc" + ], + "nghttp2_test_support_hdrs": [ + "quiche/http2/adapter/mock_nghttp2_callbacks.h", + "quiche/http2/adapter/nghttp2_test_utils.h" + ], + "nghttp2_test_support_srcs": [ + "quiche/http2/adapter/mock_nghttp2_callbacks.cc", + "quiche/http2/adapter/nghttp2_test_utils.cc" + ], + "nghttp2_tests_hdrs": [ + + ], + "nghttp2_tests_srcs": [ + "quiche/http2/adapter/adapter_impl_comparison_test.cc", + "quiche/http2/adapter/callback_visitor_test.cc", + "quiche/http2/adapter/nghttp2_adapter_test.cc", + "quiche/http2/adapter/nghttp2_data_provider_test.cc", + "quiche/http2/adapter/nghttp2_session_test.cc", + "quiche/http2/adapter/nghttp2_util_test.cc" + ], + "default_platform_impl_hdrs": [ + "quiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_iovec_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_logging_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h" + ], + "default_platform_impl_srcs": [ + "quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc" + ], + "default_platform_impl_tool_support_hdrs": [ + "quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h" + ], + "default_platform_impl_tool_support_srcs": [ + "quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc" + ], + "default_platform_impl_test_support_hdrs": [ + "quiche/common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_test_output_impl.h", + "quiche/common/platform/default/quiche_platform_impl/quiche_thread_impl.h" + ], + "default_platform_impl_test_support_srcs": [ + "quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.cc", + "quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc" + ], + "load_balancer_hdrs": [ + "quiche/quic/load_balancer/load_balancer_config.h", + "quiche/quic/load_balancer/load_balancer_decoder.h", + "quiche/quic/load_balancer/load_balancer_encoder.h", + "quiche/quic/load_balancer/load_balancer_server_id.h", + "quiche/quic/load_balancer/load_balancer_server_id_map.h" + ], + "load_balancer_srcs": [ + "quiche/quic/load_balancer/load_balancer_config.cc", + "quiche/quic/load_balancer/load_balancer_config_test.cc", + "quiche/quic/load_balancer/load_balancer_decoder.cc", + "quiche/quic/load_balancer/load_balancer_decoder_test.cc", + "quiche/quic/load_balancer/load_balancer_encoder.cc", + "quiche/quic/load_balancer/load_balancer_encoder_test.cc", + "quiche/quic/load_balancer/load_balancer_server_id.cc", + "quiche/quic/load_balancer/load_balancer_server_id_map_test.cc", + "quiche/quic/load_balancer/load_balancer_server_id_test.cc" + ], + "binary_http_hdrs": [ + "quiche/binary_http/binary_http_message.h" + ], + "binary_http_srcs": [ + "quiche/binary_http/binary_http_message.cc" + ], + "oblivious_http_hdrs": [ + "quiche/oblivious_http/buffers/oblivious_http_request.h", + "quiche/oblivious_http/buffers/oblivious_http_response.h", + "quiche/oblivious_http/common/oblivious_http_header_key_config.h", + "quiche/oblivious_http/oblivious_http_client.h", + "quiche/oblivious_http/oblivious_http_gateway.h" + ], + "oblivious_http_srcs": [ + "quiche/oblivious_http/buffers/oblivious_http_request.cc", + "quiche/oblivious_http/buffers/oblivious_http_response.cc", + "quiche/oblivious_http/common/oblivious_http_header_key_config.cc", + "quiche/oblivious_http/oblivious_http_client.cc", + "quiche/oblivious_http/oblivious_http_gateway.cc" + ], + "qbone_hdrs": [ + "quiche/quic/qbone/bonnet/icmp_reachable.h", + "quiche/quic/qbone/bonnet/icmp_reachable_interface.h", + "quiche/quic/qbone/bonnet/mock_icmp_reachable.h", + "quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h", + "quiche/quic/qbone/bonnet/mock_qbone_tunnel.h", + "quiche/quic/qbone/bonnet/mock_tun_device.h", + "quiche/quic/qbone/bonnet/mock_tun_device_controller.h", + "quiche/quic/qbone/bonnet/qbone_tunnel_info.h", + "quiche/quic/qbone/bonnet/qbone_tunnel_interface.h", + "quiche/quic/qbone/bonnet/qbone_tunnel_silo.h", + "quiche/quic/qbone/bonnet/tun_device.h", + "quiche/quic/qbone/bonnet/tun_device_controller.h", + "quiche/quic/qbone/bonnet/tun_device_interface.h", + "quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h", + "quiche/quic/qbone/mock_qbone_client.h", + "quiche/quic/qbone/mock_qbone_server_session.h", + "quiche/quic/qbone/platform/icmp_packet.h", + "quiche/quic/qbone/platform/internet_checksum.h", + "quiche/quic/qbone/platform/ip_range.h", + "quiche/quic/qbone/platform/kernel_interface.h", + "quiche/quic/qbone/platform/mock_kernel.h", + "quiche/quic/qbone/platform/mock_netlink.h", + "quiche/quic/qbone/platform/netlink.h", + "quiche/quic/qbone/platform/netlink_interface.h", + "quiche/quic/qbone/platform/rtnetlink_message.h", + "quiche/quic/qbone/platform/tcp_packet.h", + "quiche/quic/qbone/qbone_client.h", + "quiche/quic/qbone/qbone_client_interface.h", + "quiche/quic/qbone/qbone_client_session.h", + "quiche/quic/qbone/qbone_constants.h", + "quiche/quic/qbone/qbone_control_stream.h", + "quiche/quic/qbone/qbone_packet_exchanger.h", + "quiche/quic/qbone/qbone_packet_processor.h", + "quiche/quic/qbone/qbone_packet_processor_test_tools.h", + "quiche/quic/qbone/qbone_packet_writer.h", + "quiche/quic/qbone/qbone_server_session.h", + "quiche/quic/qbone/qbone_session_base.h", + "quiche/quic/qbone/qbone_stream.h" + ], + "qbone_srcs": [ + "quiche/quic/qbone/bonnet/icmp_reachable.cc", + "quiche/quic/qbone/bonnet/icmp_reachable_test.cc", + "quiche/quic/qbone/bonnet/qbone_tunnel_info.cc", + "quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc", + "quiche/quic/qbone/bonnet/qbone_tunnel_silo_test.cc", + "quiche/quic/qbone/bonnet/tun_device.cc", + "quiche/quic/qbone/bonnet/tun_device_controller.cc", + "quiche/quic/qbone/bonnet/tun_device_controller_test.cc", + "quiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc", + "quiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc", + "quiche/quic/qbone/bonnet/tun_device_test.cc", + "quiche/quic/qbone/platform/icmp_packet.cc", + "quiche/quic/qbone/platform/icmp_packet_test.cc", + "quiche/quic/qbone/platform/internet_checksum.cc", + "quiche/quic/qbone/platform/internet_checksum_test.cc", + "quiche/quic/qbone/platform/ip_range.cc", + "quiche/quic/qbone/platform/ip_range_test.cc", + "quiche/quic/qbone/platform/netlink.cc", + "quiche/quic/qbone/platform/netlink_test.cc", + "quiche/quic/qbone/platform/rtnetlink_message.cc", + "quiche/quic/qbone/platform/rtnetlink_message_test.cc", + "quiche/quic/qbone/platform/tcp_packet.cc", + "quiche/quic/qbone/platform/tcp_packet_test.cc", + "quiche/quic/qbone/qbone_client.cc", + "quiche/quic/qbone/qbone_client_session.cc", + "quiche/quic/qbone/qbone_client_test.cc", + "quiche/quic/qbone/qbone_constants.cc", + "quiche/quic/qbone/qbone_control_stream.cc", + "quiche/quic/qbone/qbone_packet_exchanger.cc", + "quiche/quic/qbone/qbone_packet_exchanger_test.cc", + "quiche/quic/qbone/qbone_packet_processor.cc", + "quiche/quic/qbone/qbone_packet_processor_test.cc", + "quiche/quic/qbone/qbone_packet_processor_test_tools.cc", + "quiche/quic/qbone/qbone_server_session.cc", + "quiche/quic/qbone/qbone_session_base.cc", + "quiche/quic/qbone/qbone_session_test.cc", + "quiche/quic/qbone/qbone_stream.cc", + "quiche/quic/qbone/qbone_stream_test.cc" + ], + "blind_sign_auth_hdrs": [ + "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h", + "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h", + "quiche/blind_sign_auth/blind_sign_auth.h", + "quiche/blind_sign_auth/blind_sign_auth_interface.h", + "quiche/blind_sign_auth/blind_sign_http_interface.h", + "quiche/blind_sign_auth/blind_sign_http_response.h", + "quiche/blind_sign_auth/cached_blind_sign_auth.h", + "quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h", + "quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h" + ], + "blind_sign_auth_srcs": [ + "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc", + "quiche/blind_sign_auth/blind_sign_auth.cc", + "quiche/blind_sign_auth/cached_blind_sign_auth.cc" + ], + "blind_sign_auth_tests_hdrs": [ + + ], + "blind_sign_auth_tests_srcs": [ + "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc", + "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc", + "quiche/blind_sign_auth/blind_sign_auth_test.cc", + "quiche/blind_sign_auth/cached_blind_sign_auth_test.cc" + ], + "protobuf_blind_sign_auth": [ + "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto", + "quiche/blind_sign_auth/proto/any.proto", + "quiche/blind_sign_auth/proto/attestation.proto", + "quiche/blind_sign_auth/proto/auth_and_sign.proto", + "quiche/blind_sign_auth/proto/get_initial_data.proto", + "quiche/blind_sign_auth/proto/key_services.proto", + "quiche/blind_sign_auth/proto/public_metadata.proto", + "quiche/blind_sign_auth/proto/spend_token_data.proto", + "quiche/blind_sign_auth/proto/timestamp.proto" + ], + "libevent_hdrs": [ + "quiche/quic/bindings/quic_libevent.h" + ], + "libevent_srcs": [ + "quiche/quic/bindings/quic_libevent.cc", + "quiche/quic/bindings/quic_libevent_test.cc" + ], + "linux_only_hdrs": [ + "quiche/quic/core/batch_writer/quic_batch_writer_base.h", + "quiche/quic/core/batch_writer/quic_batch_writer_buffer.h", + "quiche/quic/core/batch_writer/quic_batch_writer_test.h", + "quiche/quic/core/batch_writer/quic_gso_batch_writer.h", + "quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h", + "quiche/quic/core/quic_linux_socket_utils.h" + ], + "linux_only_srcs": [ + "quiche/quic/core/batch_writer/quic_batch_writer_base.cc", + "quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc", + "quiche/quic/core/batch_writer/quic_gso_batch_writer.cc", + "quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc", + "quiche/quic/core/quic_linux_socket_utils.cc" + ], + "linux_only_tests_hdrs": [ + + ], + "linux_only_tests_srcs": [ + "quiche/quic/core/batch_writer/quic_batch_writer_buffer_test.cc", + "quiche/quic/core/batch_writer/quic_batch_writer_test.cc", + "quiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc", + "quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc", + "quiche/quic/core/quic_linux_socket_utils_test.cc" + ] +} \ No newline at end of file diff --git a/build/test.bzl b/build/test.bzl new file mode 100644 index 000000000000..33bef08aa247 --- /dev/null +++ b/build/test.bzl @@ -0,0 +1,30 @@ +"""Tools for building QUICHE tests.""" + +load("@bazel_skylib//lib:dicts.bzl", "dicts") +load("@bazel_skylib//lib:paths.bzl", "paths") + +def test_suite_from_source_list(name, srcs, **kwargs): + """ + Generates a test target for every individual test source file specified. + + Args: + name: the name of the resulting test_suite target. + srcs: the list of source files from which the test targets are generated. + **kwargs: other arguments that are passed to the cc_test rule directly.s + """ + + tests = [] + for sourcefile in srcs: + if not sourcefile.endswith("_test.cc"): + fail("All source files passed to test_suite_from_source_list() must end with _test.cc") + test_name, _ = paths.split_extension(paths.basename(sourcefile)) + extra_kwargs = {} + if test_name == "end_to_end_test": + extra_kwargs["shard_count"] = 16 + native.cc_test( + name = test_name, + srcs = [sourcefile], + **dicts.add(kwargs, extra_kwargs) + ) + tests.append(test_name) + native.test_suite(name = name, tests = tests) diff --git a/build/zlib.BUILD b/build/zlib.BUILD new file mode 100644 index 000000000000..61a55078084b --- /dev/null +++ b/build/zlib.BUILD @@ -0,0 +1,25 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +licenses(["notice"]) + +cc_library( + name = "zlib", + srcs = glob( + include = [ + "*.c", + "*.h", + ], + exclude = [ + "zlib.h", + "zconf.h", + ], + ), + hdrs = [ + "zconf.h", + "zlib.h", + ], + copts = ["-Wno-implicit-function-declaration"], + visibility = ["//visibility:public"], +) diff --git a/depstool/deps/parse.go b/depstool/deps/parse.go new file mode 100644 index 000000000000..238837bfb194 --- /dev/null +++ b/depstool/deps/parse.go @@ -0,0 +1,123 @@ +// Package deps package provides methods to extract and manipulate external code dependencies from the QUICHE WORKSPACE.bazel file. +package deps + +import ( + "fmt" + "regexp" + + "github.com/bazelbuild/buildtools/build" +) + +var lastUpdatedRE = regexp.MustCompile(`Last updated (\d{4}-\d{2}-\d{2})`) + +// Entry is a parsed representation of a dependency entry in the WORKSPACE.bazel file. +type Entry struct { + Name string + SHA256 string + Prefix string + URL string + LastUpdated string +} + +// HTTPArchiveRule returns a CallExpr describing the provided http_archive +// rule, or nil if the expr in question is not an http_archive rule. +func HTTPArchiveRule(expr build.Expr) (*build.CallExpr, bool) { + callexpr, ok := expr.(*build.CallExpr) + if !ok { + return nil, false + } + name, ok := callexpr.X.(*build.Ident) + if !ok || name.Name != "http_archive" { + return nil, false + } + return callexpr, true +} + +func parseString(expr build.Expr) (string, error) { + str, ok := expr.(*build.StringExpr) + if !ok { + return "", fmt.Errorf("expected string as the function argument") + } + return str.Value, nil +} + +func parseSingleElementList(expr build.Expr) (string, error) { + list, ok := expr.(*build.ListExpr) + if !ok { + return "", fmt.Errorf("expected a list as the function argument") + } + if len(list.List) != 1 { + return "", fmt.Errorf("expected a single-element list as the function argument, got %d elements", len(list.List)) + } + return parseString(list.List[0]) +} + +// ParseHTTPArchiveRule parses the provided http_archive rule and returns all of the dependency metadata embedded. +func ParseHTTPArchiveRule(callexpr *build.CallExpr) (*Entry, error) { + result := Entry{} + for _, arg := range callexpr.List { + assign, ok := arg.(*build.AssignExpr) + if !ok { + return nil, fmt.Errorf("a non-named argument passed as a function parameter") + } + argname, _ := build.GetParamName(assign.LHS) + var err error = nil + switch argname { + case "name": + result.Name, err = parseString(assign.RHS) + case "sha256": + result.SHA256, err = parseString(assign.RHS) + + if len(assign.Comments.Suffix) != 1 { + return nil, fmt.Errorf("missing the \"Last updated\" comment on the sha256 field") + } + comment := assign.Comments.Suffix[0].Token + match := lastUpdatedRE.FindStringSubmatch(comment) + if match == nil { + return nil, fmt.Errorf("unable to parse the \"Last updated\" comment, comment value: %s", comment) + } + result.LastUpdated = match[1] + case "strip_prefix": + result.Prefix, err = parseString(assign.RHS) + case "urls": + result.URL, err = parseSingleElementList(assign.RHS) + default: + continue + } + if err != nil { + return nil, err + } + } + if result.Name == "" { + return nil, fmt.Errorf("missing the name field") + } + if result.SHA256 == "" { + return nil, fmt.Errorf("missing the sha256 field") + } + if result.URL == "" { + return nil, fmt.Errorf("missing the urls field") + } + return &result, nil +} + +// ParseHTTPArchiveRules parses the entire WORKSPACE.bazel file and returns all of the http_archive rules in it. +func ParseHTTPArchiveRules(source []byte) ([]*Entry, error) { + file, err := build.ParseWorkspace("WORKSPACE.bazel", source) + if err != nil { + return []*Entry{}, err + } + + result := make([]*Entry, 0) + for _, expr := range file.Stmt { + callexpr, ok := HTTPArchiveRule(expr) + if !ok { + continue + } + parsed, err := ParseHTTPArchiveRule(callexpr) + if err != nil { + return []*Entry{}, err + } + result = append(result, parsed) + } + return result, nil +} diff --git a/depstool/deps/parse_test.go b/depstool/deps/parse_test.go new file mode 100644 index 000000000000..2a0ac1dc8bbe --- /dev/null +++ b/depstool/deps/parse_test.go @@ -0,0 +1,103 @@ +package deps + +import ( + "reflect" + "testing" + + "github.com/bazelbuild/buildtools/build" +) + +func TestRuleParser(t *testing.T) { + exampleRule := ` +http_archive( + name = "com_google_absl", + sha256 = "44634eae586a7158dceedda7d8fd5cec6d1ebae08c83399f75dd9ce76324de40", # Last updated 2022-05-18 + strip_prefix = "abseil-cpp-3e04aade4e7a53aebbbed1a1268117f1f522bfb0", + urls = ["https://github.com/abseil/abseil-cpp/archive/3e04aade4e7a53aebbbed1a1268117f1f522bfb0.zip"], +)` + + file, err := build.ParseWorkspace("WORKSPACE.bazel", []byte(exampleRule)) + if err != nil { + t.Fatal(err) + } + rule, ok := HTTPArchiveRule(file.Stmt[0]) + if !ok { + t.Fatal("The first rule encountered is not http_archive") + } + + deps, err := ParseHTTPArchiveRule(rule) + if err != nil { + t.Fatal(err) + } + + expected := Entry{ + Name: "com_google_absl", + SHA256: "44634eae586a7158dceedda7d8fd5cec6d1ebae08c83399f75dd9ce76324de40", + Prefix: "abseil-cpp-3e04aade4e7a53aebbbed1a1268117f1f522bfb0", + URL: "https://github.com/abseil/abseil-cpp/archive/3e04aade4e7a53aebbbed1a1268117f1f522bfb0.zip", + LastUpdated: "2022-05-18", + } + if !reflect.DeepEqual(*deps, expected) { + t.Errorf("Parsing returned incorret result, expected:\n %v\n, got:\n %v", expected, *deps) + } +} + +func TestMultipleRules(t *testing.T) { + exampleRules := ` +http_archive( + name = "com_google_absl", + sha256 = "44634eae586a7158dceedda7d8fd5cec6d1ebae08c83399f75dd9ce76324de40", # Last updated 2022-05-18 + strip_prefix = "abseil-cpp-3e04aade4e7a53aebbbed1a1268117f1f522bfb0", + urls = ["https://github.com/abseil/abseil-cpp/archive/3e04aade4e7a53aebbbed1a1268117f1f522bfb0.zip"], +) + +irrelevant_call() + +http_archive( + name = "com_google_protobuf", + sha256 = "8b28fdd45bab62d15db232ec404248901842e5340299a57765e48abe8a80d930", # Last updated 2022-05-18 + strip_prefix = "protobuf-3.20.1", + urls = ["https://github.com/protocolbuffers/protobuf/archive/refs/tags/v3.20.1.tar.gz"], +) +` + + rules, err := ParseHTTPArchiveRules([]byte(exampleRules)) + if err != nil { + t.Fatal(err) + } + if len(rules) != 2 { + t.Fatalf("Expected 2 rules, got %d", len(rules)) + } + if rules[0].Name != "com_google_absl" || rules[1].Name != "com_google_protobuf" { + t.Errorf("Expected the two rules to be com_google_absl and com_google_protobuf, got %s and %s", rules[0].Name, rules[1].Name) + } +} + +func TestBazelParseError(t *testing.T) { + exampleRule := ` +http_archive( + name = "com_google_absl", + sha256 = "44634eae586a7158dceedda7d8fd5cec6d1ebae08c83399f75dd9ce76324de40", # Last updated 2022-05-18 + strip_prefix = "abseil-cpp-3e04aade4e7a53aebbbed1a1268117f1f522bfb0", + urls = ["https://github.com/abseil/abseil-cpp/archive/3e04aade4e7a53aebbbed1a1268117f1f522bfb0.zip"], +` + + _, err := ParseHTTPArchiveRules([]byte(exampleRule)) + if err == nil { + t.Errorf("Expected parser error") + } +} + +func TestMissingField(t *testing.T) { + exampleRule := ` +http_archive( + name = "com_google_absl", + strip_prefix = "abseil-cpp-3e04aade4e7a53aebbbed1a1268117f1f522bfb0", + urls = ["https://github.com/abseil/abseil-cpp/archive/3e04aade4e7a53aebbbed1a1268117f1f522bfb0.zip"], +)` + + _, err := ParseHTTPArchiveRules([]byte(exampleRule)) + if err == nil || err.Error() != "missing the sha256 field" { + t.Errorf("Expected the missing sha256 error, got %v", err) + } +} diff --git a/depstool/depstool.go b/depstool/depstool.go new file mode 100644 index 000000000000..8ada2d940118 --- /dev/null +++ b/depstool/depstool.go @@ -0,0 +1,99 @@ +// depstool is a command-line tool for manipulating QUICHE WORKSPACE.bazel file. +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "time" + + "github.com/bazelbuild/buildtools/build" + "quiche.googlesource.com/quiche/depstool/deps" +) + +func list(path string, contents []byte) { + flags, err := deps.ParseHTTPArchiveRules(contents) + if err != nil { + log.Fatalf("Failed to parse %s: %v", path, err) + } + + fmt.Println("+------------------------------+--------------------------+") + fmt.Println("| Dependency | Last updated |") + fmt.Println("+------------------------------+--------------------------+") + for _, flag := range flags { + lastUpdated, err := time.Parse("2006-01-02", flag.LastUpdated) + if err != nil { + log.Fatalf("Failed to parse date %s: %v", flag.LastUpdated, err) + } + delta := time.Since(lastUpdated) + days := int(delta.Hours() / 24) + fmt.Printf("| %28s | %s, %3d days ago |\n", flag.Name, flag.LastUpdated, days) + } + fmt.Println("+------------------------------+--------------------------+") +} + +func validate(path string, contents []byte) { + file, err := build.ParseWorkspace(path, contents) + if err != nil { + log.Fatalf("Failed to parse the WORKSPACE.bazel file: %v", err) + } + + success := true + for _, stmt := range file.Stmt { + rule, ok := deps.HTTPArchiveRule(stmt) + if !ok { + // Skip unrelated rules + continue + } + if _, err := deps.ParseHTTPArchiveRule(rule); err != nil { + log.Printf("Failed to parse http_archive in %s on the line %d, issue: %v", path, rule.Pos.Line, err) + success = false + } + } + if !success { + os.Exit(1) + } + log.Printf("All http_archive rules have been validated successfully") + os.Exit(0) +} + +func usage() { + fmt.Fprintf(flag.CommandLine.Output(), ` +usage: depstool [WORKSPACE file] [subcommand] + +Available subcommands: + list Lists all of the rules in the file + validate Validates that the WORKSPACE file is parsable + +If no subcommand is specified, "list" is assumed. +`) + flag.PrintDefaults() +} + +func main() { + flag.Usage = usage + flag.Parse() + path := flag.Arg(0) + if path == "" { + usage() + os.Exit(1) + } + contents, err := ioutil.ReadFile(path) + if err != nil { + log.Fatalf("Failed to read WORKSPACE.bazel file: %v", err) + } + + subcommand := flag.Arg(1) + switch subcommand { + case "": + fallthrough // list is the default action + case "list": + list(path, contents) + case "validate": + validate(path, contents) + default: + log.Fatalf("Unknown command: %s", subcommand) + } +} diff --git a/depstool/go.mod b/depstool/go.mod new file mode 100644 index 000000000000..6277e2581b9e --- /dev/null +++ b/depstool/go.mod @@ -0,0 +1,7 @@ +module quiche.googlesource.com/quiche/depstool + +go 1.20 + +require ( + github.com/bazelbuild/buildtools v0.0.0-20221004120235-7186f635531b +) diff --git a/depstool/go.sum b/depstool/go.sum new file mode 100644 index 000000000000..781c5f2e5529 --- /dev/null +++ b/depstool/go.sum @@ -0,0 +1,2 @@ +github.com/bazelbuild/buildtools v0.0.0-20221004120235-7186f635531b h1:jhiMzJ+8unnLRtV8rpbWBFE9pFNzIqgUTyZU5aA++w8= +github.com/bazelbuild/buildtools v0.0.0-20221004120235-7186f635531b/go.mod h1:689QdV3hBP7Vo9dJMmzhoYIyo/9iMhEmHkJcnaPRCbo= diff --git a/quiche/BUILD.bazel b/quiche/BUILD.bazel new file mode 100644 index 000000000000..f63fb060e13f --- /dev/null +++ b/quiche/BUILD.bazel @@ -0,0 +1,555 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +load( + "//build:source_list.bzl", + "binary_http_hdrs", + "binary_http_srcs", + "default_platform_impl_hdrs", + "default_platform_impl_srcs", + "default_platform_impl_test_support_hdrs", + "default_platform_impl_test_support_srcs", + "default_platform_impl_tool_support_hdrs", + "default_platform_impl_tool_support_srcs", + "io_test_support_hdrs", + "io_test_support_srcs", + "io_tests_srcs", + "io_tool_support_hdrs", + "io_tool_support_srcs", + "oblivious_http_hdrs", + "oblivious_http_srcs", + "quiche_core_hdrs", + "quiche_core_srcs", + "quiche_test_support_hdrs", + "quiche_test_support_srcs", + "quiche_tests_srcs", + "quiche_tool_support_hdrs", + "quiche_tool_support_srcs", +) +load("//build:test.bzl", "test_suite_from_source_list") + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:private"], + features = [ + "parse_headers", + "layering_check", + ], +) + +cc_library( + name = "quiche_flags_list", + textual_hdrs = [ + "common/quiche_protocol_flags_list.h", + ], +) + +cc_library( + name = "quic_flags_list", + textual_hdrs = [ + "quic/core/quic_flags_list.h", + "quic/core/quic_protocol_flags_list.h", + ], +) + +cc_library( + name = "binary_http", + srcs = binary_http_srcs, + hdrs = binary_http_hdrs, + deps = [ + ":quiche_core", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "oblivious_http", + srcs = oblivious_http_srcs, + hdrs = oblivious_http_hdrs, + deps = [ + ":quiche_core", + "@boringssl//:crypto", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +# QUICHE_EXPORT is used by all platform definitions, and thus needs to be handled separately. +cc_library( + name = "quiche_platform_default_quiche_export", + hdrs = [ + "common/platform/default/quiche_platform_impl/quiche_export_impl.h", + ], + strip_include_prefix = "common/platform/default", +) + +cc_library( + name = "quiche_platform_quiche_export", + hdrs = [ + "common/platform/api/quiche_export.h", + ], + deps = [":quiche_platform_default_quiche_export"], +) + +cc_library( + name = "quiche_platform_default", + srcs = default_platform_impl_srcs, + hdrs = default_platform_impl_hdrs, + strip_include_prefix = "common/platform/default", + deps = [ + ":quic_flags_list", + ":quiche_flags_list", + ":quiche_platform_quiche_export", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/debugging:stacktrace", + "@com_google_absl//absl/debugging:symbolize", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_googleurl//url", + ], +) + +cc_library( + name = "quiche_platform_default_tools", + srcs = default_platform_impl_tool_support_srcs, + hdrs = default_platform_impl_tool_support_hdrs, + strip_include_prefix = "common/platform/default", + deps = [ + ":quiche_core", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:usage", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "quiche_platform_default_testonly", + testonly = 1, + srcs = default_platform_impl_test_support_srcs, + hdrs = default_platform_impl_test_support_hdrs, + strip_include_prefix = "common/platform/default", + deps = [ + ":quiche_core", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:usage", + "@com_google_googletest//:gtest", + ], +) + +proto_library( + name = "quiche_protobufs", + srcs = [ + "quic/core/proto/cached_network_parameters.proto", + "quic/core/proto/crypto_server_config.proto", + "quic/core/proto/source_address_token.proto", + ], +) + +cc_proto_library( + name = "quiche_protobufs_cc_proto", + deps = [":quiche_protobufs"], +) + +proto_library( + name = "quiche_protobufs_testonly", + srcs = [ + "quic/test_tools/send_algorithm_test_result.proto", + ], +) + +cc_proto_library( + name = "quiche_protobufs_testonly_cc_proto", + deps = [":quiche_protobufs_testonly"], +) + +cc_library( + name = "quiche_core", + srcs = quiche_core_srcs, + hdrs = quiche_core_hdrs, + textual_hdrs = ["http2/hpack/hpack_static_table_entries.inc"], + deps = [ + ":quiche_platform_default", + ":quiche_protobufs_cc_proto", + "@boringssl//:crypto", + "@boringssl//:ssl", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_googleurl//url", + "@com_google_quic_trace//quic_trace:quic_trace_cc_proto", + "@zlib", + ], +) + +cc_library( + name = "quiche_tool_support", + srcs = quiche_tool_support_srcs, + hdrs = quiche_tool_support_hdrs, + deps = [ + ":quiche_core", + ":quiche_platform_default_tools", + "@boringssl//:crypto", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@com_google_googleurl//url", + ], +) + +cc_library( + name = "quiche_test_support", + testonly = 1, + srcs = quiche_test_support_srcs, + hdrs = quiche_test_support_hdrs, + deps = [ + ":binary_http", + ":quiche_core", + ":quiche_platform_default_testonly", + ":quiche_protobufs_testonly_cc_proto", + ":quiche_tool_support", + "@boringssl//:crypto", + "@boringssl//:ssl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest", + "@com_google_googleurl//url", + ], +) + +cc_library( + name = "quic_toy_client", + srcs = [ + "quic/tools/quic_toy_client.cc", + ], + hdrs = [ + "quic/tools/quic_toy_client.h", + ], + deps = [ + ":io_tool_support", + ":quiche_core", + ":quiche_platform_default", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "quic_toy_server", + srcs = [ + "quic/tools/quic_toy_server.cc", + ], + hdrs = [ + "quic/tools/quic_toy_server.h", + ], + deps = [ + ":io_tool_support", + ":quiche_core", + ":quiche_platform_default_tools", + ":quiche_tool_support", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "quic_server_factory", + srcs = [ + "quic/tools/quic_server_factory.cc", + ], + hdrs = [ + "quic/tools/quic_server_factory.h", + ], + deps = [ + ":io_tool_support", + ":quic_toy_server", + ], +) + +test_suite_from_source_list( + name = "quiche_tests", + srcs = quiche_tests_srcs, + data = glob([ + "common/platform/api/testdir/**", + "quic/test_tools/quic_http_response_cache_data/**", + ]), + deps = [ + ":binary_http", + ":oblivious_http", + ":quiche_core", + ":quiche_platform_default_testonly", + ":quiche_protobufs_testonly_cc_proto", + ":quiche_test_support", + ":quiche_tool_support", + "@boringssl//:crypto", + "@boringssl//:ssl", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_googleurl//url", + ], +) + +cc_library( + name = "io_tool_support", + srcs = io_tool_support_srcs, + hdrs = io_tool_support_hdrs, + deps = [ + ":quiche_core", + ":quiche_platform_default_tools", + ":quiche_tool_support", + "@boringssl//:crypto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_googleurl//url", + ], +) + +cc_library( + name = "io_test_support", + testonly = 1, + srcs = io_test_support_srcs, + hdrs = io_test_support_hdrs, + deps = [ + ":io_tool_support", + ":quiche_core", + ":quiche_platform_default_tools", + ":quiche_test_support", + ":quiche_tool_support", + "@boringssl//:crypto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_googletest//:gtest", + "@com_google_googleurl//url", + ], +) + +test_suite_from_source_list( + name = "io_tests", + srcs = io_tests_srcs, + data = glob([ + "common/platform/api/testdir/**", + "quic/test_tools/quic_http_response_cache_data/**", + ]), + deps = [ + ":binary_http", + ":io_test_support", + ":io_tool_support", + ":quiche_core", + ":quiche_platform_default_testonly", + ":quiche_protobufs_testonly_cc_proto", + ":quiche_test_support", + ":quiche_tool_support", + "@boringssl//:crypto", + "@boringssl//:ssl", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@com_google_googleurl//url", + ], +) + +# TODO(vasilvv): make a rule that generates cc_binary rules for all _bin targets. +cc_binary( + name = "quic_packet_printer", + srcs = ["quic/tools/quic_packet_printer_bin.cc"], + deps = [ + ":quiche_core", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "crypto_message_printer", + srcs = ["quic/tools/crypto_message_printer_bin.cc"], + deps = [ + ":quiche_core", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "quic_client", + srcs = ["quic/tools/quic_client_bin.cc"], + deps = [ + ":io_tool_support", + ":quic_toy_client", + ":quiche_core", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "quic_server", + srcs = ["quic/tools/quic_server_bin.cc"], + deps = [ + ":io_tool_support", + ":quic_server_factory", + ":quic_toy_server", + ":quiche_core", + ":quiche_platform_default", + ":quiche_platform_default_tools", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "masque_client", + srcs = ["quic/masque/masque_client_bin.cc"], + deps = [ + ":io_tool_support", + ":quiche_core", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + "@com_google_googleurl//url", + ], +) + +cc_binary( + name = "masque_server", + srcs = ["quic/masque/masque_server_bin.cc"], + deps = [ + ":io_tool_support", + ":quiche_core", + ":quiche_platform_default", + ":quiche_platform_default_tools", + ":quiche_tool_support", + "@com_google_absl//absl/strings", + ], +) + +# Indicate that QUICHE APIs are explicitly unstable by providing only +# appropriately named aliases as publicly visible targets. +alias( + name = "quiche_unstable_api", + actual = ":quiche_core", + visibility = ["//visibility:public"], +) + +alias( + name = "binary_http_unstable_api", + actual = ":binary_http", + visibility = ["//visibility:public"], +) + +alias( + name = "oblivious_http_unstable_api", + actual = ":oblivious_http", + visibility = ["//visibility:public"], +) + +alias( + name = "quiche_unstable_api_tool_support", + actual = ":quiche_tool_support", + visibility = ["//visibility:public"], +) + +alias( + name = "quiche_unstable_api_test_support", + actual = ":quiche_test_support", + visibility = ["//visibility:public"], +) diff --git a/quiche/balsa/balsa_enums.cc b/quiche/balsa/balsa_enums.cc new file mode 100644 index 000000000000..bc6b68f6be56 --- /dev/null +++ b/quiche/balsa/balsa_enums.cc @@ -0,0 +1,117 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/balsa_enums.h" + +namespace quiche { + +const char* BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::ParseState error_code) { + switch (error_code) { + case ERROR: + return "ERROR"; + case READING_HEADER_AND_FIRSTLINE: + return "READING_HEADER_AND_FIRSTLINE"; + case READING_CHUNK_LENGTH: + return "READING_CHUNK_LENGTH"; + case READING_CHUNK_EXTENSION: + return "READING_CHUNK_EXTENSION"; + case READING_CHUNK_DATA: + return "READING_CHUNK_DATA"; + case READING_CHUNK_TERM: + return "READING_CHUNK_TERM"; + case READING_LAST_CHUNK_TERM: + return "READING_LAST_CHUNK_TERM"; + case READING_TRAILER: + return "READING_TRAILER"; + case READING_UNTIL_CLOSE: + return "READING_UNTIL_CLOSE"; + case READING_CONTENT: + return "READING_CONTENT"; + case MESSAGE_FULLY_READ: + return "MESSAGE_FULLY_READ"; + case NUM_STATES: + return "UNKNOWN_STATE"; + } + return "UNKNOWN_STATE"; +} + +const char* BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::ErrorCode error_code) { + switch (error_code) { + case BALSA_NO_ERROR: + return "BALSA_NO_ERROR"; + case NO_STATUS_LINE_IN_RESPONSE: + return "NO_STATUS_LINE_IN_RESPONSE"; + case NO_REQUEST_LINE_IN_REQUEST: + return "NO_REQUEST_LINE_IN_REQUEST"; + case FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION: + return "FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION"; + case FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD: + return "FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD"; + case FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE: + return "FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE"; + case FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI: + return "FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI"; + case FAILED_TO_FIND_NL_AFTER_RESPONSE_REASON_PHRASE: + return "FAILED_TO_FIND_NL_AFTER_RESPONSE_REASON_PHRASE"; + case FAILED_TO_FIND_NL_AFTER_REQUEST_HTTP_VERSION: + return "FAILED_TO_FIND_NL_AFTER_REQUEST_HTTP_VERSION"; + case FAILED_CONVERTING_STATUS_CODE_TO_INT: + return "FAILED_CONVERTING_STATUS_CODE_TO_INT"; + case HEADERS_TOO_LONG: + return "HEADERS_TOO_LONG"; + case UNPARSABLE_CONTENT_LENGTH: + return "UNPARSABLE_CONTENT_LENGTH"; + case HTTP2_CONTENT_LENGTH_ERROR: + return "HTTP2_CONTENT_LENGTH_ERROR"; + case MAYBE_BODY_BUT_NO_CONTENT_LENGTH: + return "MAYBE_BODY_BUT_NO_CONTENT_LENGTH"; + case REQUIRED_BODY_BUT_NO_CONTENT_LENGTH: + return "REQUIRED_BODY_BUT_NO_CONTENT_LENGTH"; + case HEADER_MISSING_COLON: + return "HEADER_MISSING_COLON"; + case INVALID_CHUNK_LENGTH: + return "INVALID_CHUNK_LENGTH"; + case CHUNK_LENGTH_OVERFLOW: + return "CHUNK_LENGTH_OVERFLOW"; + case CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO: + return "CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO"; + case CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT: + return "CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT"; + case MULTIPLE_CONTENT_LENGTH_KEYS: + return "MULTIPLE_CONTENT_LENGTH_KEYS"; + case MULTIPLE_TRANSFER_ENCODING_KEYS: + return "MULTIPLE_TRANSFER_ENCODING_KEYS"; + case UNKNOWN_TRANSFER_ENCODING: + return "UNKNOWN_TRANSFER_ENCODING"; + case BOTH_TRANSFER_ENCODING_AND_CONTENT_LENGTH: + return "BOTH_TRANSFER_ENCODING_AND_CONTENT_LENGTH"; + case INVALID_HEADER_FORMAT: + return "INVALID_HEADER_FORMAT"; + case HTTP2_INVALID_HEADER_FORMAT: + return "HTTP2_INVALID_HEADER_FORMAT"; + case INVALID_TRAILER_FORMAT: + return "INVALID_TRAILER_FORMAT"; + case TRAILER_TOO_LONG: + return "TRAILER_TOO_LONG"; + case TRAILER_MISSING_COLON: + return "TRAILER_MISSING_COLON"; + case INTERNAL_LOGIC_ERROR: + return "INTERNAL_LOGIC_ERROR"; + case INVALID_HEADER_CHARACTER: + return "INVALID_HEADER_CHARACTER"; + case INVALID_HEADER_NAME_CHARACTER: + return "INVALID_HEADER_NAME_CHARACTER"; + case INVALID_TRAILER_NAME_CHARACTER: + return "INVALID_TRAILER_NAME_CHARACTER"; + case UNSUPPORTED_100_CONTINUE: + return "UNSUPPORTED_100_CONTINUE"; + case NUM_ERROR_CODES: + return "UNKNOWN_ERROR"; + } + return "UNKNOWN_ERROR"; +} + +} // namespace quiche diff --git a/quiche/balsa/balsa_enums.h b/quiche/balsa/balsa_enums.h new file mode 100644 index 000000000000..60537a346227 --- /dev/null +++ b/quiche/balsa/balsa_enums.h @@ -0,0 +1,130 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_BALSA_ENUMS_H_ +#define QUICHE_BALSA_BALSA_ENUMS_H_ + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +struct QUICHE_EXPORT BalsaFrameEnums { + enum ParseState : int { + ERROR, + READING_HEADER_AND_FIRSTLINE, + READING_CHUNK_LENGTH, + READING_CHUNK_EXTENSION, + READING_CHUNK_DATA, + READING_CHUNK_TERM, + READING_LAST_CHUNK_TERM, + READING_TRAILER, + READING_UNTIL_CLOSE, + READING_CONTENT, + MESSAGE_FULLY_READ, + NUM_STATES, + }; + + enum ErrorCode : int { + // A sentinel value for convenience, none of the callbacks should ever see + // this error code. + BALSA_NO_ERROR = 0, + + // Header parsing errors + // Note that adding one to many of the REQUEST errors yields the + // appropriate RESPONSE error. + // Particularly, when parsing the first line of a request or response, + // there are three sequences of non-whitespace regardless of whether or + // not it is a request or response. These are listed below, in order. + // + // firstline_a firstline_b firstline_c + // REQ: method request_uri version + // RESP: version statuscode reason + // + // As you can see, the first token is the 'method' field for a request, + // and 'version' field for a response. We call the first non whitespace + // token firstline_a, the second firstline_b, and the third token + // followed by [^\r\n]*) firstline_c. + // + // This organization is important, as it lets us determine the error code + // to use without a branch based on is_response. Instead, we simply add + // is_response to the response error code-- If is_response is true, then + // we'll get the response error code, thanks to the fact that the error + // code numbers are organized to ensure that response error codes always + // precede request error codes. + // | Triggered + // | while processing + // | this NONWS + // | sequence... + NO_STATUS_LINE_IN_RESPONSE, // | + NO_REQUEST_LINE_IN_REQUEST, // | + FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION, // | firstline_a + FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD, // | firstline_a + FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE, // | firstline_b + FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI, // | firstline_b + FAILED_TO_FIND_NL_AFTER_RESPONSE_REASON_PHRASE, // | firstline_c + FAILED_TO_FIND_NL_AFTER_REQUEST_HTTP_VERSION, // | firstline_c + + FAILED_CONVERTING_STATUS_CODE_TO_INT, + + HEADERS_TOO_LONG, + UNPARSABLE_CONTENT_LENGTH, + // Warning: there may be a body but there was no content-length/chunked + // encoding + MAYBE_BODY_BUT_NO_CONTENT_LENGTH, + + // This is used if a body is required for a request. + REQUIRED_BODY_BUT_NO_CONTENT_LENGTH, + + HEADER_MISSING_COLON, + + // Chunking errors + INVALID_CHUNK_LENGTH, + CHUNK_LENGTH_OVERFLOW, + + // Other errors. + CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO, + CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT, + MULTIPLE_CONTENT_LENGTH_KEYS, + MULTIPLE_TRANSFER_ENCODING_KEYS, + UNKNOWN_TRANSFER_ENCODING, + BOTH_TRANSFER_ENCODING_AND_CONTENT_LENGTH, + INVALID_HEADER_FORMAT, + HTTP2_INVALID_HEADER_FORMAT, + HTTP2_CONTENT_LENGTH_ERROR, + + // Trailer errors. + INVALID_TRAILER_FORMAT, + TRAILER_TOO_LONG, + TRAILER_MISSING_COLON, + + // A detected internal inconsistency was found. + INTERNAL_LOGIC_ERROR, + + // A control character was found in a header key or value + INVALID_HEADER_CHARACTER, + INVALID_HEADER_NAME_CHARACTER, + INVALID_TRAILER_NAME_CHARACTER, + + // The client request included 'Expect: 100-continue' header on a protocol + // that doesn't support it. + UNSUPPORTED_100_CONTINUE, + + NUM_ERROR_CODES + }; + static const char* ParseStateToString(ParseState error_code); + static const char* ErrorCodeToString(ErrorCode error_code); +}; + +struct QUICHE_EXPORT BalsaHeadersEnums { + enum ContentLengthStatus : int { + INVALID_CONTENT_LENGTH, + CONTENT_LENGTH_OVERFLOW, + NO_CONTENT_LENGTH, + VALID_CONTENT_LENGTH, + }; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_BALSA_ENUMS_H_ diff --git a/quiche/balsa/balsa_frame.cc b/quiche/balsa/balsa_frame.cc new file mode 100644 index 000000000000..ac5e018528e9 --- /dev/null +++ b/quiche/balsa/balsa_frame.cc @@ -0,0 +1,1338 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/balsa_frame.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/balsa_headers.h" +#include "quiche/balsa/balsa_visitor_interface.h" +#include "quiche/balsa/header_properties.h" +#include "quiche/common/platform/api/quiche_logging.h" + +// When comparing characters (other than == and !=), cast to unsigned char +// to make sure values above 127 rank as expected, even on platforms where char +// is signed and thus such values are represented as negative numbers before the +// cast. +#define CHAR_LT(a, b) \ + (static_cast(a) < static_cast(b)) +#define CHAR_LE(a, b) \ + (static_cast(a) <= static_cast(b)) +#define CHAR_GT(a, b) \ + (static_cast(a) > static_cast(b)) +#define CHAR_GE(a, b) \ + (static_cast(a) >= static_cast(b)) +#define QUICHE_DCHECK_CHAR_GE(a, b) \ + QUICHE_DCHECK_GE(static_cast(a), static_cast(b)) + +namespace quiche { + +namespace { + +const size_t kContinueStatusCode = 100; + +constexpr absl::string_view kChunked = "chunked"; +constexpr absl::string_view kContentLength = "content-length"; +constexpr absl::string_view kIdentity = "identity"; +constexpr absl::string_view kTransferEncoding = "transfer-encoding"; + +bool IsInterimResponse(size_t response_code) { + return response_code >= 100 && response_code < 200; +} + +} // namespace + +void BalsaFrame::Reset() { + last_char_was_slash_r_ = false; + saw_non_newline_char_ = false; + start_was_space_ = true; + chunk_length_character_extracted_ = false; + // is_request_ = true; // not reset between messages. + allow_reading_until_close_for_request_ = false; + // request_was_head_ = false; // not reset between messages. + // max_header_length_ = 16 * 1024; // not reset between messages. + // visitor_ = &do_nothing_visitor_; // not reset between messages. + chunk_length_remaining_ = 0; + content_length_remaining_ = 0; + last_slash_n_loc_ = nullptr; + last_recorded_slash_n_loc_ = nullptr; + last_slash_n_idx_ = 0; + term_chars_ = 0; + parse_state_ = BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE; + last_error_ = BalsaFrameEnums::BALSA_NO_ERROR; + invalid_chars_.clear(); + lines_.clear(); + if (continue_headers_ != nullptr) { + continue_headers_->Clear(); + } + if (headers_ != nullptr) { + headers_->Clear(); + } + trailer_lines_.clear(); + start_of_trailer_line_ = 0; + trailer_length_ = 0; + if (trailer_ != nullptr) { + trailer_->Clear(); + } +} + +namespace { + +// Within the line bounded by [current, end), parses a single "island", +// comprising a (possibly empty) span of whitespace followed by a (possibly +// empty) span of non-whitespace. +// +// Returns a pointer to the first whitespace character beyond this island, or +// returns end if no additional whitespace characters are present after this +// island. (I.e., returnvalue == end || *returnvalue > ' ') +// +// Upon return, the whitespace span are the characters +// whose indices fall in [*first_whitespace, *first_nonwhite), while the +// non-whitespace span are the characters whose indices fall in +// [*first_nonwhite, returnvalue - begin). +inline const char* ParseOneIsland(const char* current, const char* begin, + const char* end, size_t* first_whitespace, + size_t* first_nonwhite) { + *first_whitespace = current - begin; + while (current < end && CHAR_LE(*current, ' ')) { + ++current; + } + *first_nonwhite = current - begin; + while (current < end && CHAR_GT(*current, ' ')) { + ++current; + } + return current; +} + +} // namespace + +// Summary: +// Parses the first line of either a request or response. +// Note that in the case of a detected warning, error_code will be set +// but the function will not return false. +// Exactly zero or one warning or error (but not both) may be detected +// by this function. +// Note that this function will not write the data of the first-line +// into the header's buffer (that should already have been done elsewhere). +// +// Pre-conditions: +// begin != end +// *begin should be a character which is > ' '. This implies that there +// is at least one non-whitespace characters between [begin, end). +// headers is a valid pointer to a BalsaHeaders class. +// error_code is a valid pointer to a BalsaFrameEnums::ErrorCode value. +// Entire first line must exist between [begin, end) +// Exactly zero or one newlines -may- exist between [begin, end) +// [begin, end) should exist in the header's buffer. +// +// Side-effects: +// headers will be modified +// error_code may be modified if either a warning or error is detected +// +// Returns: +// True if no error (as opposed to warning) is detected. +// False if an error (as opposed to warning) is detected. + +// +// If there is indeed non-whitespace in the line, then the following +// will take care of this for you: +// while (*begin <= ' ') ++begin; +// ProcessFirstLine(begin, end, is_request, &headers, &error_code); +// + +bool ParseHTTPFirstLine(const char* begin, const char* end, bool is_request, + BalsaHeaders* headers, + BalsaFrameEnums::ErrorCode* error_code) { + while (begin < end && (end[-1] == '\n' || end[-1] == '\r')) { + --end; + } + + const char* current = + ParseOneIsland(begin, begin, end, &headers->whitespace_1_idx_, + &headers->non_whitespace_1_idx_); + current = ParseOneIsland(current, begin, end, &headers->whitespace_2_idx_, + &headers->non_whitespace_2_idx_); + current = ParseOneIsland(current, begin, end, &headers->whitespace_3_idx_, + &headers->non_whitespace_3_idx_); + + // Clean up any trailing whitespace that comes after the third island + const char* last = end; + while (current <= last && CHAR_LE(*last, ' ')) { + --last; + } + headers->whitespace_4_idx_ = last - begin + 1; + + // Either the passed-in line is empty, or it starts with a non-whitespace + // character. + QUICHE_DCHECK(begin == end || static_cast(*begin) > ' '); + + QUICHE_DCHECK_EQ(0u, headers->whitespace_1_idx_); + QUICHE_DCHECK_EQ(0u, headers->non_whitespace_1_idx_); + + // If the line isn't empty, it has at least one non-whitespace character (see + // first QUICHE_DCHECK), which will have been identified as a non-empty + // [non_whitespace_1_idx_, whitespace_2_idx_). + QUICHE_DCHECK(begin == end || + headers->non_whitespace_1_idx_ < headers->whitespace_2_idx_); + + if (headers->non_whitespace_2_idx_ == headers->whitespace_3_idx_) { + // This error may be triggered if the second token is empty, OR there's no + // WS after the first token; we don't bother to distinguish exactly which. + // (I'm not sure why we distinguish different kinds of parse error at all, + // actually.) + // FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD for request + // FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION for response + *error_code = static_cast( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION + + static_cast(is_request)); + if (!is_request) { // FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION + return false; + } + } + if (headers->whitespace_3_idx_ == headers->non_whitespace_3_idx_) { + if (*error_code == BalsaFrameEnums::BALSA_NO_ERROR) { + // FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD for request + // FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION for response + *error_code = static_cast( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE + + static_cast(is_request)); + } + } + + if (!is_request) { + headers->parsed_response_code_ = 0; + // If the response code is non-empty: + if (headers->non_whitespace_2_idx_ < headers->whitespace_3_idx_) { + if (!absl::SimpleAtoi( + absl::string_view(begin + headers->non_whitespace_2_idx_, + headers->non_whitespace_3_idx_ - + headers->non_whitespace_2_idx_), + &headers->parsed_response_code_)) { + *error_code = BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT; + return false; + } + } + } + + return true; +} + +// begin - beginning of the firstline +// end - end of the firstline +// +// A precondition for this function is that there is non-whitespace between +// [begin, end). If this precondition is not met, the function will not perform +// as expected (and bad things may happen, and it will eat your first, second, +// and third unborn children!). +// +// Another precondition for this function is that [begin, end) includes +// at most one newline, which must be at the end of the line. +void BalsaFrame::ProcessFirstLine(const char* begin, const char* end) { + BalsaFrameEnums::ErrorCode previous_error = last_error_; + if (!ParseHTTPFirstLine(begin, end, is_request_, headers_, &last_error_)) { + parse_state_ = BalsaFrameEnums::ERROR; + HandleError(last_error_); + return; + } + if (previous_error != last_error_) { + HandleWarning(last_error_); + } + + const absl::string_view line_input( + begin + headers_->non_whitespace_1_idx_, + headers_->whitespace_4_idx_ - headers_->non_whitespace_1_idx_); + const absl::string_view part1( + begin + headers_->non_whitespace_1_idx_, + headers_->whitespace_2_idx_ - headers_->non_whitespace_1_idx_); + const absl::string_view part2( + begin + headers_->non_whitespace_2_idx_, + headers_->whitespace_3_idx_ - headers_->non_whitespace_2_idx_); + const absl::string_view part3( + begin + headers_->non_whitespace_3_idx_, + headers_->whitespace_4_idx_ - headers_->non_whitespace_3_idx_); + + if (is_request_) { + visitor_->OnRequestFirstLineInput(line_input, part1, part2, part3); + if (part3.empty()) { + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + } + return; + } + + visitor_->OnResponseFirstLineInput(line_input, part1, part2, part3); +} + +// 'stream_begin' points to the first character of the headers buffer. +// 'line_begin' points to the first character of the line. +// 'current' points to a char which is ':'. +// 'line_end' points to the position of '\n' + 1. +// 'line_begin' points to the position of first character of line. +void BalsaFrame::CleanUpKeyValueWhitespace( + const char* stream_begin, const char* line_begin, const char* current, + const char* line_end, HeaderLineDescription* current_header_line) { + const char* colon_loc = current; + QUICHE_DCHECK_LT(colon_loc, line_end); + QUICHE_DCHECK_EQ(':', *colon_loc); + QUICHE_DCHECK_EQ(':', *current); + QUICHE_DCHECK_CHAR_GE(' ', *line_end) + << "\"" << std::string(line_begin, line_end) << "\""; + + --current; + while (current > line_begin && CHAR_LE(*current, ' ')) { + --current; + } + current += static_cast(current != colon_loc); + current_header_line->key_end_idx = current - stream_begin; + + current = colon_loc; + QUICHE_DCHECK_EQ(':', *current); + ++current; + while (current < line_end && CHAR_LE(*current, ' ')) { + ++current; + } + current_header_line->value_begin_idx = current - stream_begin; + + QUICHE_DCHECK_GE(current_header_line->key_end_idx, + current_header_line->first_char_idx); + QUICHE_DCHECK_GE(current_header_line->value_begin_idx, + current_header_line->key_end_idx); + QUICHE_DCHECK_GE(current_header_line->last_char_idx, + current_header_line->value_begin_idx); +} + +bool BalsaFrame::FindColonsAndParseIntoKeyValue(const Lines& lines, + bool is_trailer, + BalsaHeaders* headers) { + QUICHE_DCHECK(!lines.empty()); + const char* stream_begin = headers->OriginalHeaderStreamBegin(); + // The last line is always just a newline (and is uninteresting). + const Lines::size_type lines_size_m1 = lines.size() - 1; + // For a trailer, there is no first line, so lines[0] is the first header. + // For real headers, the first line takes lines[0], so real header starts + // at index 1. + int first_header_idx = (is_trailer ? 0 : 1); + const char* current = stream_begin + lines[first_header_idx].first; + // This code is a bit more subtle than it may appear at first glance. + // This code looks for a colon in the current line... but it also looks + // beyond the current line. If there is no colon in the current line, then + // for each subsequent line (until the colon which -has- been found is + // associated with a line), no searching for a colon will be performed. In + // this way, we minimize the amount of bytes we have scanned for a colon. + for (Lines::size_type i = first_header_idx; i < lines_size_m1;) { + const char* line_begin = stream_begin + lines[i].first; + + // Here we handle possible continuations. Note that we do not replace + // the '\n' in the line before a continuation (at least, as of now), + // which implies that any code which looks for a value must deal with + // "\r\n", etc -within- the line (and not just at the end of it). + for (++i; i < lines_size_m1; ++i) { + const char c = *(stream_begin + lines[i].first); + if (CHAR_GT(c, ' ')) { + // Not a continuation, so stop. Note that if the 'original' i = 1, + // and the next line is not a continuation, we'll end up with i = 2 + // when we break. This handles the incrementing of i for the outer + // loop. + break; + } + + // Space and tab are valid starts to continuation lines. + // https://tools.ietf.org/html/rfc7230#section-3.2.4 says that a proxy + // can choose to reject or normalize continuation lines. + if ((c != ' ' && c != '\t') || + http_validation_policy().disallow_header_continuation_lines) { + HandleError(is_trailer ? BalsaFrameEnums::INVALID_TRAILER_FORMAT + : BalsaFrameEnums::INVALID_HEADER_FORMAT); + return false; + } + + // If disallow_header_continuation_lines() is false, we neither reject nor + // normalize continuation lines, in violation of RFC7230. + } + const char* line_end = stream_begin + lines[i - 1].second; + QUICHE_DCHECK_LT(line_begin - stream_begin, line_end - stream_begin); + + // We cleanup the whitespace at the end of the line before doing anything + // else of interest as it allows us to do nothing when irregularly formatted + // headers are parsed (e.g. those with only keys, only values, or no colon). + // + // We're guaranteed to have *line_end > ' ' while line_end >= line_begin. + --line_end; + QUICHE_DCHECK_EQ('\n', *line_end) + << "\"" << std::string(line_begin, line_end) << "\""; + while (CHAR_LE(*line_end, ' ') && line_end > line_begin) { + --line_end; + } + ++line_end; + QUICHE_DCHECK_CHAR_GE(' ', *line_end); + QUICHE_DCHECK_LT(line_begin, line_end); + + // We use '0' for the block idx, because we're always writing to the first + // block from the framer (we do this because the framer requires that the + // entire header sequence be in a contiguous buffer). + headers->header_lines_.push_back(HeaderLineDescription( + line_begin - stream_begin, line_end - stream_begin, + line_end - stream_begin, line_end - stream_begin, 0)); + if (current >= line_end) { + if (http_validation_policy().require_header_colon) { + HandleError(is_trailer ? BalsaFrameEnums::TRAILER_MISSING_COLON + : BalsaFrameEnums::HEADER_MISSING_COLON); + return false; + } + HandleWarning(is_trailer ? BalsaFrameEnums::TRAILER_MISSING_COLON + : BalsaFrameEnums::HEADER_MISSING_COLON); + // Then the next colon will not be found within this header line-- time + // to try again with another header-line. + continue; + } + if (current < line_begin) { + // When this condition is true, the last detected colon was part of a + // previous line. We reset to the beginning of the line as we don't care + // about the presence of any colon before the beginning of the current + // line. + current = line_begin; + } + for (; current < line_end; ++current) { + if (*current == ':') { + break; + } + + if (header_properties::IsInvalidHeaderKeyChar(*current)) { + // Generally invalid characters were found earlier. + HandleError(is_trailer + ? BalsaFrameEnums::INVALID_TRAILER_NAME_CHARACTER + : BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER); + return false; + } + } + + if (current == line_end) { + // There was no colon in the line. The arguments we passed into the + // construction for the HeaderLineDescription object should be OK-- it + // assumes that the entire content is 'key' by default (which is true, as + // there was no colon, there can be no value). Note that this is a + // construct which is technically not allowed by the spec. + + // In strict mode, we do treat this invalid value-less key as an error. + if (http_validation_policy().require_header_colon) { + HandleError(is_trailer ? BalsaFrameEnums::TRAILER_MISSING_COLON + : BalsaFrameEnums::HEADER_MISSING_COLON); + return false; + } + HandleWarning(is_trailer ? BalsaFrameEnums::TRAILER_MISSING_COLON + : BalsaFrameEnums::HEADER_MISSING_COLON); + continue; + } + + QUICHE_DCHECK_EQ(*current, ':'); + QUICHE_DCHECK_LE(current - stream_begin, line_end - stream_begin); + QUICHE_DCHECK_LE(stream_begin - stream_begin, current - stream_begin); + + HeaderLineDescription& current_header_line = headers->header_lines_.back(); + current_header_line.key_end_idx = current - stream_begin; + current_header_line.value_begin_idx = current_header_line.key_end_idx; + if (current < line_end) { + ++current_header_line.key_end_idx; + + CleanUpKeyValueWhitespace(stream_begin, line_begin, current, line_end, + ¤t_header_line); + } + + const absl::string_view key( + stream_begin + current_header_line.first_char_idx, + current_header_line.key_end_idx - current_header_line.first_char_idx); + const absl::string_view value( + stream_begin + current_header_line.value_begin_idx, + current_header_line.last_char_idx - + current_header_line.value_begin_idx); + visitor_->OnHeader(key, value); + } + + return true; +} + +void BalsaFrame::HandleWarning(BalsaFrameEnums::ErrorCode error_code) { + last_error_ = error_code; + visitor_->HandleWarning(last_error_); +} + +void BalsaFrame::HandleError(BalsaFrameEnums::ErrorCode error_code) { + last_error_ = error_code; + parse_state_ = BalsaFrameEnums::ERROR; + visitor_->HandleError(last_error_); +} + +BalsaHeadersEnums::ContentLengthStatus BalsaFrame::ProcessContentLengthLine( + HeaderLines::size_type line_idx, size_t* length) { + const HeaderLineDescription& header_line = headers_->header_lines_[line_idx]; + const char* stream_begin = headers_->OriginalHeaderStreamBegin(); + const char* line_end = stream_begin + header_line.last_char_idx; + const char* value_begin = (stream_begin + header_line.value_begin_idx); + + if (value_begin >= line_end) { + // There is no non-whitespace value data. + QUICHE_DVLOG(1) << "invalid content-length -- no non-whitespace value data"; + return BalsaHeadersEnums::INVALID_CONTENT_LENGTH; + } + + *length = 0; + while (value_begin < line_end) { + if (*value_begin < '0' || *value_begin > '9') { + // bad! content-length found, and couldn't parse all of it! + QUICHE_DVLOG(1) + << "invalid content-length - non numeric character detected"; + return BalsaHeadersEnums::INVALID_CONTENT_LENGTH; + } + const size_t kMaxDiv10 = std::numeric_limits::max() / 10; + size_t length_x_10 = *length * 10; + const size_t c = *value_begin - '0'; + if (*length > kMaxDiv10 || + (std::numeric_limits::max() - length_x_10) < c) { + QUICHE_DVLOG(1) << "content-length overflow"; + return BalsaHeadersEnums::CONTENT_LENGTH_OVERFLOW; + } + *length = length_x_10 + c; + ++value_begin; + } + QUICHE_DVLOG(1) << "content_length parsed: " << *length; + return BalsaHeadersEnums::VALID_CONTENT_LENGTH; +} + +void BalsaFrame::ProcessTransferEncodingLine(HeaderLines::size_type line_idx) { + const HeaderLineDescription& header_line = headers_->header_lines_[line_idx]; + const char* stream_begin = headers_->OriginalHeaderStreamBegin(); + const absl::string_view transfer_encoding( + stream_begin + header_line.value_begin_idx, + header_line.last_char_idx - header_line.value_begin_idx); + + if (absl::EqualsIgnoreCase(transfer_encoding, kChunked)) { + headers_->transfer_encoding_is_chunked_ = true; + return; + } + + if (absl::EqualsIgnoreCase(transfer_encoding, kIdentity)) { + headers_->transfer_encoding_is_chunked_ = false; + return; + } + + HandleError(BalsaFrameEnums::UNKNOWN_TRANSFER_ENCODING); +} + +bool BalsaFrame::CheckHeaderLinesForInvalidChars(const Lines& lines, + const BalsaHeaders* headers) { + // Read from the beginning of the first line to the end of the last line. + // Note we need to add the first line's offset as in the case of a trailer + // it's non-zero. + const char* stream_begin = + headers->OriginalHeaderStreamBegin() + lines.front().first; + const char* stream_end = + headers->OriginalHeaderStreamBegin() + lines.back().second; + bool found_invalid = false; + + for (const char* c = stream_begin; c < stream_end; c++) { + if (header_properties::IsInvalidHeaderChar(*c)) { + found_invalid = true; + invalid_chars_[*c]++; + } + } + + return found_invalid; +} + +void BalsaFrame::ProcessHeaderLines(const Lines& lines, bool is_trailer, + BalsaHeaders* headers) { + QUICHE_DCHECK(!lines.empty()); + QUICHE_DVLOG(1) << "******@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@**********\n"; + + if (is_request() && track_invalid_chars()) { + if (CheckHeaderLinesForInvalidChars(lines, headers)) { + if (invalid_chars_error_enabled()) { + HandleError(BalsaFrameEnums::INVALID_HEADER_CHARACTER); + return; + } + + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER); + } + } + + // There is no need to attempt to process headers (resp. trailers) + // if no header (resp. trailer) lines exist. + // + // The last line of the message, which is an empty line, is never a header + // (resp. trailer) line. Furthermore, the first line of the message is not + // a header line. Therefore there are at least two (resp. one) lines in the + // message which are not header (resp. trailer) lines. + // + // Thus, we test to see if we have more than two (resp. one) lines total + // before attempting to parse any header (resp. trailer) lines. + if (lines.size() <= (is_trailer ? 1 : 2)) { + return; + } + + HeaderLines::size_type content_length_idx = 0; + HeaderLines::size_type transfer_encoding_idx = 0; + const char* stream_begin = headers->OriginalHeaderStreamBegin(); + // Parse the rest of the header or trailer data into key-value pairs. + if (!FindColonsAndParseIntoKeyValue(lines, is_trailer, headers)) { + return; + } + // At this point, we've parsed all of the headers/trailers. Time to look + // for those headers which we require for framing or for format errors. + const HeaderLines::size_type lines_size = headers->header_lines_.size(); + for (HeaderLines::size_type i = 0; i < lines_size; ++i) { + const HeaderLineDescription& line = headers->header_lines_[i]; + const absl::string_view key(stream_begin + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + QUICHE_DVLOG(2) << "[" << i << "]: " << key << " key_len: " << key.length(); + + // If a header begins with either lowercase or uppercase 'c' or 't', then + // the header may be one of content-length, connection, content-encoding + // or transfer-encoding. These headers are special, as they change the way + // that the message is framed, and so the framer is required to search + // for them. However, first check for a formatting error, and skip + // special header treatment on trailer lines (when is_trailer is true). + if (key.empty() || key[0] == ' ') { + parse_state_ = BalsaFrameEnums::ERROR; + HandleError(is_trailer ? BalsaFrameEnums::INVALID_TRAILER_FORMAT + : BalsaFrameEnums::INVALID_HEADER_FORMAT); + return; + } + if (is_trailer) { + continue; + } + if (absl::EqualsIgnoreCase(key, kContentLength)) { + size_t length = 0; + BalsaHeadersEnums::ContentLengthStatus content_length_status = + ProcessContentLengthLine(i, &length); + if (content_length_idx == 0) { + content_length_idx = i + 1; + headers->content_length_status_ = content_length_status; + headers->content_length_ = length; + content_length_remaining_ = length; + continue; + } + if ((headers->content_length_status_ != content_length_status) || + ((headers->content_length_status_ == + BalsaHeadersEnums::VALID_CONTENT_LENGTH) && + (http_validation_policy().disallow_multiple_content_length || + length != headers->content_length_))) { + HandleError(BalsaFrameEnums::MULTIPLE_CONTENT_LENGTH_KEYS); + return; + } + continue; + } + if (absl::EqualsIgnoreCase(key, kTransferEncoding)) { + if (transfer_encoding_idx != 0) { + HandleError(BalsaFrameEnums::MULTIPLE_TRANSFER_ENCODING_KEYS); + return; + } + transfer_encoding_idx = i + 1; + } + } + + if (!is_trailer) { + if (http_validation_policy() + .disallow_transfer_encoding_with_content_length && + content_length_idx != 0 && transfer_encoding_idx != 0) { + HandleError(BalsaFrameEnums::BOTH_TRANSFER_ENCODING_AND_CONTENT_LENGTH); + return; + } + if (headers->transfer_encoding_is_chunked_) { + headers->content_length_ = 0; + headers->content_length_status_ = BalsaHeadersEnums::NO_CONTENT_LENGTH; + content_length_remaining_ = 0; + } + if (transfer_encoding_idx != 0) { + ProcessTransferEncodingLine(transfer_encoding_idx - 1); + } + } +} + +void BalsaFrame::AssignParseStateAfterHeadersHaveBeenParsed() { + // For responses, can't have a body if the request was a HEAD, or if it is + // one of these response-codes. rfc2616 section 4.3 + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + int response_code = headers_->parsed_response_code_; + if (!is_request_ && (request_was_head_ || + !BalsaHeaders::ResponseCanHaveBody(response_code))) { + // There is no body. + return; + } + + if (headers_->transfer_encoding_is_chunked_) { + // Note that + // if ( Transfer-Encoding: chunked && Content-length: ) + // then Transfer-Encoding: chunked trumps. + // This is as specified in the spec. + // rfc2616 section 4.4.3 + parse_state_ = BalsaFrameEnums::READING_CHUNK_LENGTH; + return; + } + + // Errors parsing content-length definitely can cause + // protocol errors/warnings + switch (headers_->content_length_status_) { + // If we have a content-length, and it is parsed + // properly, there are two options. + // 1) zero content, in which case the message is done, and + // 2) nonzero content, in which case we have to + // consume the body. + case BalsaHeadersEnums::VALID_CONTENT_LENGTH: + if (headers_->content_length_ == 0) { + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + } else { + parse_state_ = BalsaFrameEnums::READING_CONTENT; + } + break; + case BalsaHeadersEnums::CONTENT_LENGTH_OVERFLOW: + case BalsaHeadersEnums::INVALID_CONTENT_LENGTH: + // If there were characters left-over after parsing the + // content length, we should flag an error and stop. + HandleError(BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH); + break; + // We can have: no transfer-encoding, no content length, and no + // connection: close... + // Unfortunately, this case doesn't seem to be covered in the spec. + // We'll assume that the safest thing to do here is what the google + // binaries before 2008 already do, which is to assume that + // everything until the connection is closed is body. + case BalsaHeadersEnums::NO_CONTENT_LENGTH: + if (is_request_) { + const absl::string_view method = headers_->request_method(); + // POSTs and PUTs should have a detectable body length. If they + // do not we consider it an error. + if (method != "POST" && method != "PUT") { + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + break; + } else if (!allow_reading_until_close_for_request_) { + HandleError(BalsaFrameEnums::REQUIRED_BODY_BUT_NO_CONTENT_LENGTH); + break; + } + } + parse_state_ = BalsaFrameEnums::READING_UNTIL_CLOSE; + HandleWarning(BalsaFrameEnums::MAYBE_BODY_BUT_NO_CONTENT_LENGTH); + break; + // The COV_NF_... statements here provide hints to the apparatus + // which computes coverage reports/ratios that this code is never + // intended to be executed, and should technically be impossible. + // COV_NF_START + default: + QUICHE_LOG(FATAL) << "Saw a content_length_status: " + << headers_->content_length_status_ + << " which is unknown."; + // COV_NF_END + } +} + +size_t BalsaFrame::ProcessHeaders(const char* message_start, + size_t message_length) { + const char* const original_message_start = message_start; + const char* const message_end = message_start + message_length; + const char* message_current = message_start; + const char* checkpoint = message_start; + + if (message_length == 0) { + return message_current - original_message_start; + } + + while (message_current < message_end) { + size_t base_idx = headers_->GetReadableBytesFromHeaderStream(); + + // Yes, we could use strchr (assuming null termination), or + // memchr, but as it turns out that is slower than this tight loop + // for the input that we see. + if (!saw_non_newline_char_) { + do { + const char c = *message_current; + if (c != '\r' && c != '\n') { + if (CHAR_LE(c, ' ')) { + HandleError(BalsaFrameEnums::NO_REQUEST_LINE_IN_REQUEST); + return message_current - original_message_start; + } + break; + } + ++message_current; + if (message_current == message_end) { + return message_current - original_message_start; + } + } while (true); + saw_non_newline_char_ = true; + message_start = message_current; + checkpoint = message_current; + } + while (message_current < message_end) { + if (*message_current != '\n') { + ++message_current; + continue; + } + const size_t relative_idx = message_current - message_start; + const size_t message_current_idx = 1 + base_idx + relative_idx; + lines_.push_back(std::make_pair(last_slash_n_idx_, message_current_idx)); + if (lines_.size() == 1) { + headers_->WriteFromFramer(checkpoint, 1 + message_current - checkpoint); + checkpoint = message_current + 1; + const char* begin = headers_->OriginalHeaderStreamBegin(); + + QUICHE_DVLOG(1) << "First line " + << std::string(begin, lines_[0].second); + QUICHE_DVLOG(1) << "is_request_: " << is_request_; + ProcessFirstLine(begin, begin + lines_[0].second); + if (parse_state_ == BalsaFrameEnums::MESSAGE_FULLY_READ) { + break; + } + + if (parse_state_ == BalsaFrameEnums::ERROR) { + return message_current - original_message_start; + } + } + const size_t chars_since_last_slash_n = + (message_current_idx - last_slash_n_idx_); + last_slash_n_idx_ = message_current_idx; + if (chars_since_last_slash_n > 2) { + // false positive. + ++message_current; + continue; + } + if ((chars_since_last_slash_n == 1) || + (((message_current > message_start) && + (*(message_current - 1) == '\r')) || + (last_char_was_slash_r_))) { + break; + } + ++message_current; + } + + if (message_current == message_end) { + continue; + } + + ++message_current; + QUICHE_DCHECK(message_current >= message_start); + if (message_current > message_start) { + headers_->WriteFromFramer(checkpoint, message_current - checkpoint); + } + + // Check if we have exceeded maximum headers length + // Although we check for this limit before and after we call this function + // we check it here as well to make sure that in case the visitor changed + // the max_header_length_ (for example after processing the first line) + // we handle it gracefully. + if (headers_->GetReadableBytesFromHeaderStream() > max_header_length_) { + HandleError(BalsaFrameEnums::HEADERS_TOO_LONG); + return message_current - original_message_start; + } + + // Since we know that we won't be writing any more bytes of the header, + // we tell that to the headers object. The headers object may make + // more efficient allocation decisions when this is signaled. + headers_->DoneWritingFromFramer(); + visitor_->OnHeaderInput(headers_->GetReadablePtrFromHeaderStream()); + + // Ok, now that we've written everything into our header buffer, it is + // time to process the header lines (extract proper values for headers + // which are important for framing). + ProcessHeaderLines(lines_, false /*is_trailer*/, headers_); + if (parse_state_ == BalsaFrameEnums::ERROR) { + return message_current - original_message_start; + } + + if (use_interim_headers_callback_ && + IsInterimResponse(headers_->parsed_response_code())) { + // Deliver headers from this interim response but reset everything else to + // prepare for the next set of headers. + visitor_->OnInterimHeaders(std::move(*headers_)); + Reset(); + checkpoint = message_start = message_current; + continue; + } + if (continue_headers_ != nullptr && + headers_->parsed_response_code_ == kContinueStatusCode) { + // Save the headers from this 100 Continue response but reset everything + // else to prepare for the next set of headers. + BalsaHeaders saved_continue_headers = std::move(*headers_); + Reset(); + *continue_headers_ = std::move(saved_continue_headers); + visitor_->ContinueHeaderDone(); + checkpoint = message_start = message_current; + continue; + } + AssignParseStateAfterHeadersHaveBeenParsed(); + if (parse_state_ == BalsaFrameEnums::ERROR) { + return message_current - original_message_start; + } + visitor_->ProcessHeaders(*headers_); + visitor_->HeaderDone(); + if (parse_state_ == BalsaFrameEnums::MESSAGE_FULLY_READ) { + visitor_->MessageDone(); + } + return message_current - original_message_start; + } + // If we've gotten to here, it means that we've consumed all of the + // available input. We need to record whether or not the last character we + // saw was a '\r' so that a subsequent call to ProcessInput correctly finds + // a header framing that is split across the two calls. + last_char_was_slash_r_ = (*(message_end - 1) == '\r'); + QUICHE_DCHECK(message_current >= message_start); + if (message_current > message_start) { + headers_->WriteFromFramer(checkpoint, message_current - checkpoint); + } + return message_current - original_message_start; +} + +size_t BalsaFrame::BytesSafeToSplice() const { + switch (parse_state_) { + case BalsaFrameEnums::READING_CHUNK_DATA: + return chunk_length_remaining_; + case BalsaFrameEnums::READING_UNTIL_CLOSE: + return std::numeric_limits::max(); + case BalsaFrameEnums::READING_CONTENT: + return content_length_remaining_; + default: + return 0; + } +} + +void BalsaFrame::BytesSpliced(size_t bytes_spliced) { + switch (parse_state_) { + case BalsaFrameEnums::READING_CHUNK_DATA: + if (chunk_length_remaining_ < bytes_spliced) { + HandleError(BalsaFrameEnums:: + CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT); + return; + } + chunk_length_remaining_ -= bytes_spliced; + if (chunk_length_remaining_ == 0) { + parse_state_ = BalsaFrameEnums::READING_CHUNK_TERM; + } + return; + + case BalsaFrameEnums::READING_UNTIL_CLOSE: + return; + + case BalsaFrameEnums::READING_CONTENT: + if (content_length_remaining_ < bytes_spliced) { + HandleError(BalsaFrameEnums:: + CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT); + return; + } + content_length_remaining_ -= bytes_spliced; + if (content_length_remaining_ == 0) { + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + visitor_->MessageDone(); + } + return; + + default: + HandleError(BalsaFrameEnums::CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO); + return; + } +} + +size_t BalsaFrame::ProcessInput(const char* input, size_t size) { + const char* current = input; + const char* on_entry = current; + const char* end = current + size; + + QUICHE_DCHECK(headers_ != nullptr); + if (headers_ == nullptr) { + return 0; + } + + if (parse_state_ == BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE) { + const size_t header_length = headers_->GetReadableBytesFromHeaderStream(); + // Yes, we still have to check this here as the user can change the + // max_header_length amount! + // Also it is possible that we have reached the maximum allowed header size, + // and we have more to consume (remember we are still inside + // READING_HEADER_AND_FIRSTLINE) in which case we directly declare an error. + if (header_length > max_header_length_ || + (header_length == max_header_length_ && size > 0)) { + HandleError(BalsaFrameEnums::HEADERS_TOO_LONG); + return current - input; + } + const size_t bytes_to_process = + std::min(max_header_length_ - header_length, size); + current += ProcessHeaders(input, bytes_to_process); + // If we are still reading headers check if we have crossed the headers + // limit. Note that we check for >= as opposed to >. This is because if + // header_length_after equals max_header_length_ and we are still in the + // parse_state_ BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE we know for + // sure that the headers limit will be crossed later on + if (parse_state_ == BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE) { + // Note that headers_ is valid only if we are still reading headers. + const size_t header_length_after = + headers_->GetReadableBytesFromHeaderStream(); + if (header_length_after >= max_header_length_) { + HandleError(BalsaFrameEnums::HEADERS_TOO_LONG); + } + } + return current - input; + } + + if (parse_state_ == BalsaFrameEnums::MESSAGE_FULLY_READ || + parse_state_ == BalsaFrameEnums::ERROR) { + // Can do nothing more 'till we're reset. + return current - input; + } + + QUICHE_DCHECK_LE(current, end); + if (current == end) { + return current - input; + } + + while (true) { + switch (parse_state_) { + case BalsaFrameEnums::READING_CHUNK_LENGTH: + // In this state we read the chunk length. + // Note that once we hit a character which is not in: + // [0-9;A-Fa-f\n], we transition to a different state. + // + QUICHE_DCHECK_LE(current, end); + while (true) { + if (current == end) { + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + return current - input; + } + + const char c = *current; + ++current; + + static const signed char kBad = -1; + static const signed char kDelimiter = -2; + + // valid cases: + // "09123\n" // -> 09123 + // "09123\r\n" // -> 09123 + // "09123 \n" // -> 09123 + // "09123 \r\n" // -> 09123 + // "09123 12312\n" // -> 09123 + // "09123 12312\r\n" // -> 09123 + // "09123; foo=bar\n" // -> 09123 + // "09123; foo=bar\r\n" // -> 09123 + // "FFFFFFFFFFFFFFFF\r\n" // -> FFFFFFFFFFFFFFFF + // "FFFFFFFFFFFFFFFF 22\r\n" // -> FFFFFFFFFFFFFFFF + // invalid cases: + // "[ \t]+[^\n]*\n" + // "FFFFFFFFFFFFFFFFF\r\n" (would overflow) + // "\r\n" + // "\n" + signed char addition = kBad; + // clang-format off + switch (c) { + case '0': addition = 0; break; + case '1': addition = 1; break; + case '2': addition = 2; break; + case '3': addition = 3; break; + case '4': addition = 4; break; + case '5': addition = 5; break; + case '6': addition = 6; break; + case '7': addition = 7; break; + case '8': addition = 8; break; + case '9': addition = 9; break; + case 'a': addition = 0xA; break; + case 'b': addition = 0xB; break; + case 'c': addition = 0xC; break; + case 'd': addition = 0xD; break; + case 'e': addition = 0xE; break; + case 'f': addition = 0xF; break; + case 'A': addition = 0xA; break; + case 'B': addition = 0xB; break; + case 'C': addition = 0xC; break; + case 'D': addition = 0xD; break; + case 'E': addition = 0xE; break; + case 'F': addition = 0xF; break; + case '\t': + case '\n': + case '\r': + case ' ': + case ';': + addition = kDelimiter; + break; + default: + // Leave addition == kBad + break; + } + // clang-format on + if (addition >= 0) { + chunk_length_character_extracted_ = true; + size_t length_x_16 = chunk_length_remaining_ * 16; + const size_t kMaxDiv16 = std::numeric_limits::max() / 16; + if ((chunk_length_remaining_ > kMaxDiv16) || + (std::numeric_limits::max() - length_x_16) < + static_cast(addition)) { + // overflow -- asked for a chunk-length greater than 2^64 - 1!! + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + HandleError(BalsaFrameEnums::CHUNK_LENGTH_OVERFLOW); + return current - input; + } + chunk_length_remaining_ = length_x_16 + addition; + continue; + } + + if (!chunk_length_character_extracted_ || addition == kBad) { + // ^[0-9;A-Fa-f][ \t\n] -- was not matched, either because no + // characters were converted, or an unexpected character was + // seen. + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + HandleError(BalsaFrameEnums::INVALID_CHUNK_LENGTH); + return current - input; + } + + break; + } + + --current; + parse_state_ = BalsaFrameEnums::READING_CHUNK_EXTENSION; + visitor_->OnChunkLength(chunk_length_remaining_); + continue; + + case BalsaFrameEnums::READING_CHUNK_EXTENSION: { + // TODO(phython): Convert this scanning to be 16 bytes at a time if + // there is data to be read. + const char* extensions_start = current; + size_t extensions_length = 0; + QUICHE_DCHECK_LE(current, end); + while (true) { + if (current == end) { + visitor_->OnChunkExtensionInput( + absl::string_view(extensions_start, extensions_length)); + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + return current - input; + } + const char c = *current; + if (c == '\r' || c == '\n') { + extensions_length = (extensions_start == current) + ? 0 + : current - extensions_start - 1; + } + + ++current; + if (c == '\n') { + break; + } + } + + chunk_length_character_extracted_ = false; + visitor_->OnChunkExtensionInput( + absl::string_view(extensions_start, extensions_length)); + + if (chunk_length_remaining_ != 0) { + parse_state_ = BalsaFrameEnums::READING_CHUNK_DATA; + continue; + } + + HeaderFramingFound('\n'); + parse_state_ = BalsaFrameEnums::READING_LAST_CHUNK_TERM; + continue; + } + + case BalsaFrameEnums::READING_CHUNK_DATA: + while (current < end) { + if (chunk_length_remaining_ == 0) { + break; + } + // read in the chunk + size_t bytes_remaining = end - current; + size_t consumed_bytes = (chunk_length_remaining_ < bytes_remaining) + ? chunk_length_remaining_ + : bytes_remaining; + const char* tmp_current = current + consumed_bytes; + visitor_->OnRawBodyInput( + absl::string_view(on_entry, tmp_current - on_entry)); + visitor_->OnBodyChunkInput( + absl::string_view(current, consumed_bytes)); + on_entry = current = tmp_current; + chunk_length_remaining_ -= consumed_bytes; + } + + if (chunk_length_remaining_ == 0) { + parse_state_ = BalsaFrameEnums::READING_CHUNK_TERM; + continue; + } + + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + return current - input; + + case BalsaFrameEnums::READING_CHUNK_TERM: + QUICHE_DCHECK_LE(current, end); + while (true) { + if (current == end) { + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + return current - input; + } + + const char c = *current; + ++current; + + if (c == '\n') { + break; + } + } + parse_state_ = BalsaFrameEnums::READING_CHUNK_LENGTH; + continue; + + case BalsaFrameEnums::READING_LAST_CHUNK_TERM: + QUICHE_DCHECK_LE(current, end); + while (true) { + if (current == end) { + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + return current - input; + } + + const char c = *current; + if (HeaderFramingFound(c) != 0) { + // If we've found a "\r\n\r\n", then the message + // is done. + ++current; + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + visitor_->MessageDone(); + return current - input; + } + + // If not, however, since the spec only suggests that the + // client SHOULD indicate the presence of trailers, we get to + // *test* that they did or didn't. + // If all of the bytes we've seen since: + // OPTIONAL_WS 0 OPTIONAL_STUFF CRLF + // are either '\r', or '\n', then we can assume that we don't yet + // know if we need to parse headers, or if the next byte will make + // the HeaderFramingFound condition (above) true. + if (!HeaderFramingMayBeFound()) { + break; + } + + // If HeaderFramingMayBeFound(), then we have seen only characters + // '\r' or '\n'. + ++current; + + // Lets try again! There is no state change here. + } + + // If (!HeaderFramingMayBeFound()), then we know that we must be + // reading the first non CRLF character of a trailer. + parse_state_ = BalsaFrameEnums::READING_TRAILER; + visitor_->OnRawBodyInput( + absl::string_view(on_entry, current - on_entry)); + on_entry = current; + continue; + + // TODO(yongfa): No leading whitespace is allowed before field-name per + // RFC2616. Leading whitespace will cause header parsing error too. + case BalsaFrameEnums::READING_TRAILER: + while (current < end) { + const char c = *current; + ++current; + ++trailer_length_; + if (trailer_ != nullptr) { + // Reuse the header length limit for trailer, which is just a bunch + // of headers. + if (trailer_length_ > max_header_length_) { + --current; + HandleError(BalsaFrameEnums::TRAILER_TOO_LONG); + return current - input; + } + if (LineFramingFound(c)) { + trailer_lines_.push_back( + std::make_pair(start_of_trailer_line_, trailer_length_)); + start_of_trailer_line_ = trailer_length_; + } + } + if (HeaderFramingFound(c) != 0) { + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + if (trailer_ != nullptr) { + trailer_->WriteFromFramer(on_entry, current - on_entry); + trailer_->DoneWritingFromFramer(); + ProcessHeaderLines(trailer_lines_, true /*is_trailer*/, trailer_); + if (parse_state_ == BalsaFrameEnums::ERROR) { + return current - input; + } + visitor_->ProcessTrailers(*trailer_); + } + visitor_->OnTrailerInput( + absl::string_view(on_entry, current - on_entry)); + visitor_->MessageDone(); + return current - input; + } + } + if (trailer_ != nullptr) { + trailer_->WriteFromFramer(on_entry, current - on_entry); + } + visitor_->OnTrailerInput( + absl::string_view(on_entry, current - on_entry)); + return current - input; + + case BalsaFrameEnums::READING_UNTIL_CLOSE: { + const size_t bytes_remaining = end - current; + if (bytes_remaining > 0) { + visitor_->OnRawBodyInput(absl::string_view(current, bytes_remaining)); + visitor_->OnBodyChunkInput( + absl::string_view(current, bytes_remaining)); + current += bytes_remaining; + } + return current - input; + } + + case BalsaFrameEnums::READING_CONTENT: + while ((content_length_remaining_ != 0u) && current < end) { + // read in the content + const size_t bytes_remaining = end - current; + const size_t consumed_bytes = + (content_length_remaining_ < bytes_remaining) + ? content_length_remaining_ + : bytes_remaining; + visitor_->OnRawBodyInput(absl::string_view(current, consumed_bytes)); + visitor_->OnBodyChunkInput( + absl::string_view(current, consumed_bytes)); + current += consumed_bytes; + content_length_remaining_ -= consumed_bytes; + } + if (content_length_remaining_ == 0) { + parse_state_ = BalsaFrameEnums::MESSAGE_FULLY_READ; + visitor_->MessageDone(); + } + return current - input; + + default: + // The state-machine should never be in a state that isn't handled + // above. This is a glaring logic error, and we should do something + // drastic to ensure that this gets looked-at and fixed. + QUICHE_LOG(FATAL) << "Unknown state: " << parse_state_ // COV_NF_LINE + << " memory corruption?!"; // COV_NF_LINE + } + } +} + +const int32_t BalsaFrame::kValidTerm1; +const int32_t BalsaFrame::kValidTerm1Mask; +const int32_t BalsaFrame::kValidTerm2; +const int32_t BalsaFrame::kValidTerm2Mask; + +} // namespace quiche + +#undef CHAR_LT +#undef CHAR_LE +#undef CHAR_GT +#undef CHAR_GE +#undef QUICHE_DCHECK_CHAR_GE diff --git a/quiche/balsa/balsa_frame.h b/quiche/balsa/balsa_frame.h new file mode 100644 index 000000000000..f3f75e8037d2 --- /dev/null +++ b/quiche/balsa/balsa_frame.h @@ -0,0 +1,322 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_BALSA_FRAME_H_ +#define QUICHE_BALSA_BALSA_FRAME_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/balsa_headers.h" +#include "quiche/balsa/balsa_visitor_interface.h" +#include "quiche/balsa/framer_interface.h" +#include "quiche/balsa/http_validation_policy.h" +#include "quiche/balsa/noop_balsa_visitor.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" + +namespace quiche { + +namespace test { +class BalsaFrameTestPeer; +} // namespace test + +// BalsaFrame is a lightweight HTTP framer. +class QUICHE_EXPORT BalsaFrame : public FramerInterface { + public: + typedef std::vector > Lines; + + typedef BalsaHeaders::HeaderLineDescription HeaderLineDescription; + typedef BalsaHeaders::HeaderLines HeaderLines; + typedef BalsaHeaders::HeaderTokenList HeaderTokenList; + + enum class InvalidCharsLevel { kOff, kWarning, kError }; + + static constexpr int32_t kValidTerm1 = '\n' << 16 | '\r' << 8 | '\n'; + static constexpr int32_t kValidTerm1Mask = 0xFF << 16 | 0xFF << 8 | 0xFF; + static constexpr int32_t kValidTerm2 = '\n' << 8 | '\n'; + static constexpr int32_t kValidTerm2Mask = 0xFF << 8 | 0xFF; + BalsaFrame() + : last_char_was_slash_r_(false), + saw_non_newline_char_(false), + start_was_space_(true), + chunk_length_character_extracted_(false), + is_request_(true), + allow_reading_until_close_for_request_(false), + request_was_head_(false), + max_header_length_(16 * 1024), + visitor_(&do_nothing_visitor_), + chunk_length_remaining_(0), + content_length_remaining_(0), + last_slash_n_loc_(nullptr), + last_recorded_slash_n_loc_(nullptr), + last_slash_n_idx_(0), + term_chars_(0), + parse_state_(BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE), + last_error_(BalsaFrameEnums::BALSA_NO_ERROR), + continue_headers_(nullptr), + headers_(nullptr), + start_of_trailer_line_(0), + trailer_length_(0), + trailer_(nullptr), + invalid_chars_level_(InvalidCharsLevel::kOff), + use_interim_headers_callback_(false) {} + + ~BalsaFrame() override {} + + // Reset reinitializes all the member variables of the framer and clears the + // attached header object (but doesn't change the pointer value headers_). + void Reset(); + + // The method set_balsa_headers clears the headers provided and attaches them + // to the framer. This is a required step before the framer will process any + // input message data. + // To detach the header object from the framer, use + // set_balsa_headers(nullptr). + void set_balsa_headers(BalsaHeaders* headers) { + if (headers_ != headers) { + headers_ = headers; + } + if (headers_ != nullptr) { + // Clear the headers if they are non-null, even if the new headers are + // the same as the old. + headers_->Clear(); + } + } + + // If set to non-null, allow 100 Continue headers before the main headers. + // This method is a no-op if set_use_interim_headers_callback(true) is called. + void set_continue_headers(BalsaHeaders* continue_headers) { + if (continue_headers_ != continue_headers) { + continue_headers_ = continue_headers; + } + if (continue_headers_ != nullptr) { + // Clear the headers if they are non-null, even if the new headers are + // the same as the old. + continue_headers_->Clear(); + } + } + + // The method set_balsa_trailer() clears `trailer` and attaches it to the + // framer. This is a required step before the framer will process any input + // message data. To detach the trailer object from the framer, use + // set_balsa_trailer(nullptr). + void set_balsa_trailer(BalsaHeaders* trailer) { + if (trailer != nullptr && is_request()) { + QUICHE_CODE_COUNT(balsa_trailer_in_request); + } + + if (trailer_ != trailer) { + trailer_ = trailer; + } + if (trailer_ != nullptr) { + // Clear the trailer if it is non-null, even if the new trailer is + // the same as the old. + trailer_->Clear(); + } + } + + void set_balsa_visitor(BalsaVisitorInterface* visitor) { + visitor_ = visitor; + if (visitor_ == nullptr) { + visitor_ = &do_nothing_visitor_; + } + } + + void set_invalid_chars_level(InvalidCharsLevel v) { + invalid_chars_level_ = v; + } + + bool track_invalid_chars() { + return invalid_chars_level_ != InvalidCharsLevel::kOff; + } + + bool invalid_chars_error_enabled() { + return invalid_chars_level_ == InvalidCharsLevel::kError; + } + + void set_http_validation_policy(const quiche::HttpValidationPolicy& policy) { + http_validation_policy_ = policy; + } + const quiche::HttpValidationPolicy& http_validation_policy() const { + return http_validation_policy_; + } + + void set_is_request(bool is_request) { is_request_ = is_request; } + + bool is_request() const { return is_request_; } + + void set_request_was_head(bool request_was_head) { + request_was_head_ = request_was_head; + } + + void set_max_header_length(size_t max_header_length) { + max_header_length_ = max_header_length; + } + + size_t max_header_length() const { return max_header_length_; } + + bool MessageFullyRead() const { + return parse_state_ == BalsaFrameEnums::MESSAGE_FULLY_READ; + } + + BalsaFrameEnums::ParseState ParseState() const { return parse_state_; } + + bool Error() const { return parse_state_ == BalsaFrameEnums::ERROR; } + + BalsaFrameEnums::ErrorCode ErrorCode() const { return last_error_; } + + const absl::flat_hash_map& get_invalid_chars() const { + return invalid_chars_; + } + + const BalsaHeaders* headers() const { return headers_; } + BalsaHeaders* mutable_headers() { return headers_; } + + const BalsaHeaders* trailer() const { return trailer_; } + BalsaHeaders* mutable_trailer() { return trailer_; } + + size_t BytesSafeToSplice() const; + void BytesSpliced(size_t bytes_spliced); + + size_t ProcessInput(const char* input, size_t size) override; + + void set_allow_reading_until_close_for_request(bool set) { + allow_reading_until_close_for_request_ = set; + } + + // For websockets and possibly other uses, we suspend the usual expectations + // about when a message has a body and how long it should be. + void AllowArbitraryBody() { + parse_state_ = BalsaFrameEnums::READING_UNTIL_CLOSE; + } + + // If enabled, calls BalsaVisitorInterface::OnInterimHeaders() when parsing + // interim headers. For 100 Continue, this callback will be invoked instead of + // ContinueHeaderDone(), even when set_continue_headers() is called. + void set_use_interim_headers_callback(bool set) { + use_interim_headers_callback_ = set; + } + + protected: + inline BalsaHeadersEnums::ContentLengthStatus ProcessContentLengthLine( + size_t line_idx, size_t* length); + + inline void ProcessTransferEncodingLine(size_t line_idx); + + void ProcessFirstLine(const char* begin, const char* end); + + void CleanUpKeyValueWhitespace(const char* stream_begin, + const char* line_begin, const char* current, + const char* line_end, + HeaderLineDescription* current_header_line); + + void ProcessHeaderLines(const Lines& lines, bool is_trailer, + BalsaHeaders* headers); + + // Returns true if there are invalid characters, false otherwise. + // Will also update counts per invalid character in invalid_chars_. + bool CheckHeaderLinesForInvalidChars(const Lines& lines, + const BalsaHeaders* headers); + + inline size_t ProcessHeaders(const char* message_start, + size_t message_length); + + void AssignParseStateAfterHeadersHaveBeenParsed(); + + inline bool LineFramingFound(char current_char) { + return current_char == '\n'; + } + + // Return header framing pattern. Non-zero return value indicates found, + // which has two possible outcomes: kValidTerm1, which means \n\r\n + // or kValidTerm2, which means \n\n. Zero return value means not found. + inline int32_t HeaderFramingFound(char current_char) { + // Note that the 'if (current_char == '\n' ...)' test exists to ensure that + // the HeaderFramingMayBeFound test works properly. In benchmarking done on + // 2/13/2008, the 'if' actually speeds up performance of the function + // anyway.. + if (current_char == '\n' || current_char == '\r') { + term_chars_ <<= 8; + // This is necessary IFF architecture has > 8 bit char. Alas, I'm + // paranoid. + term_chars_ |= current_char & 0xFF; + + if ((term_chars_ & kValidTerm1Mask) == kValidTerm1) { + term_chars_ = 0; + return kValidTerm1; + } + if ((term_chars_ & kValidTerm2Mask) == kValidTerm2) { + term_chars_ = 0; + return kValidTerm2; + } + } else { + term_chars_ = 0; + } + return 0; + } + + inline bool HeaderFramingMayBeFound() const { return term_chars_ != 0; } + + private: + friend class test::BalsaFrameTestPeer; + + // Calls HandleError() and returns false on error. + bool FindColonsAndParseIntoKeyValue(const Lines& lines, bool is_trailer, + BalsaHeaders* headers); + + void HandleError(BalsaFrameEnums::ErrorCode error_code); + void HandleWarning(BalsaFrameEnums::ErrorCode error_code); + + bool last_char_was_slash_r_; + bool saw_non_newline_char_; + bool start_was_space_; + bool chunk_length_character_extracted_; + bool is_request_; // This is not reset in Reset() + // Generally, requests are not allowed to frame with connection: close. For + // protocols which do their own protocol-specific chunking, such as streamed + // stubby, we allow connection close semantics for requests. + bool allow_reading_until_close_for_request_; + bool request_was_head_; // This is not reset in Reset() + size_t max_header_length_; // This is not reset in Reset() + BalsaVisitorInterface* visitor_; + size_t chunk_length_remaining_; + size_t content_length_remaining_; + const char* last_slash_n_loc_; + const char* last_recorded_slash_n_loc_; + size_t last_slash_n_idx_; + uint32_t term_chars_; + BalsaFrameEnums::ParseState parse_state_; + BalsaFrameEnums::ErrorCode last_error_; + absl::flat_hash_map invalid_chars_; + + Lines lines_; + + BalsaHeaders* continue_headers_; // This is not reset to nullptr in Reset(). + BalsaHeaders* headers_; // This is not reset to nullptr in Reset(). + NoOpBalsaVisitor do_nothing_visitor_; + + Lines trailer_lines_; + size_t start_of_trailer_line_; + size_t trailer_length_; + BalsaHeaders* trailer_; // Does not own and is not reset to nullptr + // in Reset(). + InvalidCharsLevel invalid_chars_level_; // This is not reset in Reset(). + + quiche::HttpValidationPolicy http_validation_policy_; + + // This is not reset in Reset(). + // TODO(b/68801833): Default-enable and then deprecate this field, along with + // set_continue_headers(). + bool use_interim_headers_callback_; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_BALSA_FRAME_H_ diff --git a/quiche/balsa/balsa_frame_test.cc b/quiche/balsa/balsa_frame_test.cc new file mode 100644 index 000000000000..13a704a2c000 --- /dev/null +++ b/quiche/balsa/balsa_frame_test.cc @@ -0,0 +1,4011 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/balsa_frame.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/balsa_headers.h" +#include "quiche/balsa/balsa_visitor_interface.h" +#include "quiche/balsa/http_validation_policy.h" +#include "quiche/balsa/noop_balsa_visitor.h" +#include "quiche/balsa/simple_buffer.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::AtLeast; +using ::testing::InSequence; +using ::testing::IsEmpty; +using ::testing::Mock; +using ::testing::NiceMock; +using ::testing::Property; +using ::testing::Range; +using ::testing::StrEq; +using ::testing::StrictMock; + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, randseed, "", + "This is the seed for Pseudo-random number" + " generator used when generating random messages for unittests"); + +namespace quiche { + +namespace test { + +// This random engine from the standard library supports initialization with a +// seed, which is helpful for reproducing any unit test failures that are due to +// random sequence variation. +using RandomEngine = std::mt19937; + +class BalsaFrameTestPeer { + public: + static int32_t HeaderFramingFound(BalsaFrame* balsa_frame, char c) { + return balsa_frame->HeaderFramingFound(c); + } + + static void FindColonsAndParseIntoKeyValue(BalsaFrame* balsa_frame, + const BalsaFrame::Lines& lines, + bool is_trailer, + BalsaHeaders* headers) { + balsa_frame->FindColonsAndParseIntoKeyValue(lines, is_trailer, headers); + } +}; + +class BalsaHeadersTestPeer { + public: + static void WriteFromFramer(BalsaHeaders* headers, const char* ptr, + size_t size) { + headers->WriteFromFramer(ptr, size); + } +}; + +namespace { + +// This class encapsulates the policy of seed selection. If user supplies a +// valid use via the --randseed flag, GetSeed will only return the user +// supplied seed value. This is useful in reproducing bugs reported by the +// test. If an invalid seed value is supplied (likely due to bad numeric +// format), the test will abort (since this mode tend to be used for debugging, +// it is better to die early so the user knows a bad value is supplied). If no +// seed is supplied, the value supplied by ACMRandom::HostnamePidTimeSeed() is +// used. This class is supposed to be a singleton, but there is no ill-effect if +// multiple instances are created (although that tends not to be what the user +// wants). +class TestSeed { + public: + TestSeed() : test_seed_(0), user_supplied_seed_(false) {} + + void Initialize(const std::string& seed_flag) { + if (!seed_flag.empty()) { + ASSERT_TRUE(absl::SimpleAtoi(seed_flag, &test_seed_)); + user_supplied_seed_ = true; + } + } + + int GetSeed() const { + int seed = + (user_supplied_seed_ ? test_seed_ + : testing::UnitTest::GetInstance()->random_seed()); + QUICHE_LOG(INFO) << "**** The current seed is " << seed << " ****"; + return seed; + } + + private: + int test_seed_; + bool user_supplied_seed_; +}; + +static bool RandomBool(RandomEngine& rng) { return rng() % 2 != 0; } + +std::string EscapeString(absl::string_view message) { + return absl::StrReplaceAll( + message, {{"\n", "\\\\n\n"}, {"\\r", "\\\\r"}, {"\\t", "\\\\t"}}); +} + +char random_lws(RandomEngine& rng) { + if (RandomBool(rng)) { + return '\t'; + } + return ' '; +} + +const char* random_line_term(RandomEngine& rng) { + if (RandomBool(rng)) { + return "\r\n"; + } + return "\n"; +} + +void AppendRandomWhitespace(RandomEngine& rng, std::stringstream* s) { + // Appending a random amount of whitespace to the unparsed value. There is a + // max of 1000 pieces of whitespace that will be attached, however, it is + // extremely unlikely (1 in 2^1000) that we'll hit this limit, as we have a + // 50% probability of exiting the loop at any point in time. + for (int i = 0; i < 1000 && RandomBool(rng); ++i) { + *s << random_lws(rng); + } +} + +// Creates an HTTP message firstline from the given inputs. +// +// tokens - The list of nonwhitespace tokens (which should later be parsed out +// from the firstline). +// whitespace - the whitespace that occurs before, between, and +// after the tokens. Note that the last whitespace +// character should -not- include any '\n'. +// line_ending - one of "\n" or "\r\n" +// +// whitespace[0] occurs before the first token. +// whitespace[1] occurs between the first and second token +// whitespace[2] occurs between the second and third token +// whitespace[3] occurs between the third token and the line_ending. +// +// This code: +// const char tokens[3] = {"GET", "/", "HTTP/1.0"}; +// const char whitespace[4] = { "\n\n", " ", "\t", "\t"}; +// const char line_ending = "\r\n"; +// CreateFirstLine(tokens, whitespace, line_ending) -> +// Would yield the following string: +// string( +// "\n" +// "\n" +// "GET /\tHTTP/1.0\t\r\n" +// ); +// +std::string CreateFirstLine(const char* tokens[3], const char* whitespace[4], + const char* line_ending) { + QUICHE_CHECK(tokens != nullptr); + QUICHE_CHECK(whitespace != nullptr); + QUICHE_CHECK(line_ending != nullptr); + QUICHE_CHECK(std::string(line_ending) == "\n" || + std::string(line_ending) == "\r\n") + << "line_ending: " << EscapeString(line_ending); + SimpleBuffer firstline_buffer; + firstline_buffer.WriteString(whitespace[0]); + for (int i = 0; i < 3; ++i) { + firstline_buffer.WriteString(tokens[i]); + firstline_buffer.WriteString(whitespace[i + 1]); + } + firstline_buffer.WriteString(line_ending); + return std::string(firstline_buffer.GetReadableRegion()); +} + +// Creates a string (ostensibly an entire HTTP message) from the given input +// arguments. +// +// firstline - the first line of the request or response. +// The firstline should already have a line-ending on it. If you use the +// CreateFirstLine function, you'll get a valid firstline string for this +// function. This may include 'extraneous' whitespace before the first +// nonwhitespace character, including '\n's +// headers - a list of the -interpreted- key, value pairs. +// In other words, the value should be what you expect to get out of the +// headers after framing has occurred (and should include no whitespace +// before or after the first and list nonwhitespace characters, +// respectively). While this function will succeed if you don't follow +// these guidelines, the VerifyHeaderLines function will likely not agree +// with that input. +// headers_len - the number of key value pairs +// colon - the string that exists between the key and value pairs. +// It MUST include EXACTLY one colon, and may include any amount of either +// ' ' or '\t'. Note that for certain key strings, this value will be +// modified to exclude any leading whitespace. See the body of the function +// for more details. +// line_ending - one of "\r\n", or "\n\n" +// body - the appropriate body. +// The CreateMessage function does not do any checking that the headers +// agree with the present of any body, so the input must be correct given +// the set of headers. +std::string CreateMessage(const char* firstline, + const std::pair* headers, + size_t headers_len, const char* colon, + const char* line_ending, const char* body) { + SimpleBuffer request_buffer; + request_buffer.WriteString(firstline); + if (headers_len > 0) { + QUICHE_CHECK(headers != nullptr); + QUICHE_CHECK(colon != nullptr); + } + QUICHE_CHECK(line_ending != nullptr); + QUICHE_CHECK(std::string(line_ending) == "\n" || + std::string(line_ending) == "\r\n") + << "line_ending: " << EscapeString(line_ending); + QUICHE_CHECK(body != nullptr); + for (size_t i = 0; i < headers_len; ++i) { + bool only_whitespace_in_key = true; + { + // If the 'key' part includes no non-whitespace characters, then we need + // to be sure that the 'colon' part includes no whitespace before the + // ':'. If it did, then the line would be (correctly!) interpreted as a + // continuation, and the test would not work properly. + const char* tmp_key = headers[i].first.c_str(); + while (*tmp_key != '\0') { + if (*tmp_key > ' ') { + only_whitespace_in_key = false; + break; + } + ++tmp_key; + } + } + const char* tmp_colon = colon; + if (only_whitespace_in_key) { + while (*tmp_colon != ':') { + ++tmp_colon; + } + } + request_buffer.WriteString(headers[i].first); + request_buffer.WriteString(tmp_colon); + request_buffer.WriteString(headers[i].second); + request_buffer.WriteString(line_ending); + } + request_buffer.WriteString(line_ending); + request_buffer.WriteString(body); + return std::string(request_buffer.GetReadableRegion()); +} + +void VerifyRequestFirstLine(const char* tokens[3], + const BalsaHeaders& headers) { + EXPECT_EQ(tokens[0], headers.request_method()); + EXPECT_EQ(tokens[1], headers.request_uri()); + EXPECT_EQ(0u, headers.parsed_response_code()); + EXPECT_EQ(tokens[2], headers.request_version()); +} + +void VerifyResponseFirstLine(const char* tokens[3], + size_t expected_response_code, + const BalsaHeaders& headers) { + EXPECT_EQ(tokens[0], headers.response_version()); + EXPECT_EQ(tokens[1], headers.response_code()); + EXPECT_EQ(expected_response_code, headers.parsed_response_code()); + EXPECT_EQ(tokens[2], headers.response_reason_phrase()); +} + +// This function verifies that the expected_headers key and values +// are exactly equal to that returned by an iterator to a BalsaHeader +// object. +// +// expected_headers - key, value pairs, in the order in which they're +// expected to be returned from the iterator. +// headers_len - as expected, the number of expected key-value pairs. +// headers - the BalsaHeaders from which we'll examine the actual +// headers. +void VerifyHeaderLines( + const std::pair* expected_headers, + size_t headers_len, const BalsaHeaders& headers) { + BalsaHeaders::const_header_lines_iterator it = headers.lines().begin(); + for (size_t i = 0; it != headers.lines().end(); ++it, ++i) { + ASSERT_GT(headers_len, i); + std::string actual_key; + std::string actual_value; + if (!it->first.empty()) { + actual_key = std::string(it->first); + } + if (!it->second.empty()) { + actual_value = std::string(it->second); + } + EXPECT_THAT(actual_key, StrEq(expected_headers[i].first)); + EXPECT_THAT(actual_value, StrEq(expected_headers[i].second)); + } + EXPECT_TRUE(headers.lines().end() == it); +} + +void FirstLineParsedCorrectlyHelper(const char* tokens[3], + size_t expected_response_code, + bool is_request, const char* whitespace) { + BalsaHeaders headers; + BalsaFrame framer; + framer.set_is_request(is_request); + framer.set_balsa_headers(&headers); + const char* tmp_tokens[3] = {tokens[0], tokens[1], tokens[2]}; + const char* tmp_whitespace[4] = {"", whitespace, whitespace, ""}; + for (int j = 2; j >= 0; --j) { + framer.Reset(); + std::string firstline = CreateFirstLine(tmp_tokens, tmp_whitespace, "\n"); + std::string message = + CreateMessage(firstline.c_str(), nullptr, 0, nullptr, "\n", ""); + SCOPED_TRACE(absl::StrFormat("input: \n%s", EscapeString(message))); + EXPECT_GE(message.size(), + framer.ProcessInput(message.data(), message.size())); + // If this is a request then we don't expect a framer error (as we'll be + // getting back warnings that fields are missing). If, however, this is + // a response, and it is missing anything other than the reason phrase, + // the framer will signal an error instead. + if (is_request || j >= 1) { + EXPECT_FALSE(framer.Error()); + if (is_request) { + EXPECT_TRUE(framer.MessageFullyRead()); + } + if (j == 0) { + expected_response_code = 0; + } + if (is_request) { + VerifyRequestFirstLine(tmp_tokens, *framer.headers()); + } else { + VerifyResponseFirstLine(tmp_tokens, expected_response_code, + *framer.headers()); + } + } else { + EXPECT_TRUE(framer.Error()); + } + tmp_tokens[j] = ""; + tmp_whitespace[j] = ""; + } +} + +TEST(HTTPBalsaFrame, ParseStateToString) { + EXPECT_STREQ("ERROR", + BalsaFrameEnums::ParseStateToString(BalsaFrameEnums::ERROR)); + EXPECT_STREQ("READING_HEADER_AND_FIRSTLINE", + BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE)); + EXPECT_STREQ("READING_CHUNK_LENGTH", + BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_CHUNK_LENGTH)); + EXPECT_STREQ("READING_CHUNK_EXTENSION", + BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_CHUNK_EXTENSION)); + EXPECT_STREQ("READING_CHUNK_DATA", BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_CHUNK_DATA)); + EXPECT_STREQ("READING_CHUNK_TERM", BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_CHUNK_TERM)); + EXPECT_STREQ("READING_LAST_CHUNK_TERM", + BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_LAST_CHUNK_TERM)); + EXPECT_STREQ("READING_TRAILER", BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_TRAILER)); + EXPECT_STREQ("READING_UNTIL_CLOSE", + BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_UNTIL_CLOSE)); + EXPECT_STREQ("READING_CONTENT", BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::READING_CONTENT)); + EXPECT_STREQ("MESSAGE_FULLY_READ", BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::MESSAGE_FULLY_READ)); + + EXPECT_STREQ("UNKNOWN_STATE", BalsaFrameEnums::ParseStateToString( + BalsaFrameEnums::NUM_STATES)); + EXPECT_STREQ("UNKNOWN_STATE", + BalsaFrameEnums::ParseStateToString( + static_cast(-1))); + + for (int i = 0; i < BalsaFrameEnums::NUM_STATES; ++i) { + EXPECT_STRNE("UNKNOWN_STATE", + BalsaFrameEnums::ParseStateToString( + static_cast(i))); + } +} + +TEST(HTTPBalsaFrame, ErrorCodeToString) { + EXPECT_STREQ("NO_STATUS_LINE_IN_RESPONSE", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::NO_STATUS_LINE_IN_RESPONSE)); + EXPECT_STREQ("NO_REQUEST_LINE_IN_REQUEST", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::NO_REQUEST_LINE_IN_REQUEST)); + EXPECT_STREQ("FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION)); + EXPECT_STREQ("FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD)); + EXPECT_STREQ( + "FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE)); + EXPECT_STREQ( + "FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI)); + EXPECT_STREQ( + "FAILED_TO_FIND_NL_AFTER_RESPONSE_REASON_PHRASE", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_TO_FIND_NL_AFTER_RESPONSE_REASON_PHRASE)); + EXPECT_STREQ( + "FAILED_TO_FIND_NL_AFTER_REQUEST_HTTP_VERSION", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_TO_FIND_NL_AFTER_REQUEST_HTTP_VERSION)); + EXPECT_STREQ("FAILED_CONVERTING_STATUS_CODE_TO_INT", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT)); + EXPECT_STREQ("HEADERS_TOO_LONG", BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::HEADERS_TOO_LONG)); + EXPECT_STREQ("UNPARSABLE_CONTENT_LENGTH", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH)); + EXPECT_STREQ("MAYBE_BODY_BUT_NO_CONTENT_LENGTH", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::MAYBE_BODY_BUT_NO_CONTENT_LENGTH)); + EXPECT_STREQ("HEADER_MISSING_COLON", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::HEADER_MISSING_COLON)); + EXPECT_STREQ("INVALID_CHUNK_LENGTH", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::INVALID_CHUNK_LENGTH)); + EXPECT_STREQ("CHUNK_LENGTH_OVERFLOW", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::CHUNK_LENGTH_OVERFLOW)); + EXPECT_STREQ("CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO)); + EXPECT_STREQ("CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums:: + CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT)); + EXPECT_STREQ("MULTIPLE_CONTENT_LENGTH_KEYS", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::MULTIPLE_CONTENT_LENGTH_KEYS)); + EXPECT_STREQ("MULTIPLE_TRANSFER_ENCODING_KEYS", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::MULTIPLE_TRANSFER_ENCODING_KEYS)); + EXPECT_STREQ("INVALID_HEADER_FORMAT", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::INVALID_HEADER_FORMAT)); + EXPECT_STREQ("INVALID_TRAILER_FORMAT", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::INVALID_TRAILER_FORMAT)); + EXPECT_STREQ("TRAILER_TOO_LONG", BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::TRAILER_TOO_LONG)); + EXPECT_STREQ("TRAILER_MISSING_COLON", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::TRAILER_MISSING_COLON)); + EXPECT_STREQ("INTERNAL_LOGIC_ERROR", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::INTERNAL_LOGIC_ERROR)); + EXPECT_STREQ("INVALID_HEADER_CHARACTER", + BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + + EXPECT_STREQ("UNKNOWN_ERROR", BalsaFrameEnums::ErrorCodeToString( + BalsaFrameEnums::NUM_ERROR_CODES)); + EXPECT_STREQ("UNKNOWN_ERROR", + BalsaFrameEnums::ErrorCodeToString( + static_cast(-1))); + + for (int i = 0; i < BalsaFrameEnums::NUM_ERROR_CODES; ++i) { + EXPECT_STRNE("UNKNOWN_ERROR", + BalsaFrameEnums::ErrorCodeToString( + static_cast(i))); + } +} + +class FakeHeaders { + public: + struct KeyValuePair { + KeyValuePair(const std::string& key, const std::string& value) + : key(key), value(value) {} + KeyValuePair() {} + + std::string key; + std::string value; + }; + typedef std::vector KeyValuePairs; + KeyValuePairs key_value_pairs_; + + bool operator==(const FakeHeaders& other) const { + if (key_value_pairs_.size() != other.key_value_pairs_.size()) { + return false; + } + for (KeyValuePairs::size_type i = 0; i < key_value_pairs_.size(); ++i) { + if (key_value_pairs_[i].key != other.key_value_pairs_[i].key) { + return false; + } + if (key_value_pairs_[i].value != other.key_value_pairs_[i].value) { + return false; + } + } + return true; + } + + void AddKeyValue(const std::string& key, const std::string& value) { + key_value_pairs_.push_back(KeyValuePair(key, value)); + } +}; + +class BalsaVisitorMock : public BalsaVisitorInterface { + public: + ~BalsaVisitorMock() override = default; + + void ProcessHeaders(const BalsaHeaders& headers) override { + FakeHeaders fake_headers; + GenerateFakeHeaders(headers, &fake_headers); + ProcessHeaders(fake_headers); + } + void ProcessTrailers(const BalsaHeaders& trailer) override { + FakeHeaders fake_headers; + GenerateFakeHeaders(trailer, &fake_headers); + ProcessTrailers(fake_headers); + } + + MOCK_METHOD(void, OnRawBodyInput, (absl::string_view input), (override)); + MOCK_METHOD(void, OnBodyChunkInput, (absl::string_view input), (override)); + MOCK_METHOD(void, OnHeaderInput, (absl::string_view input), (override)); + MOCK_METHOD(void, OnHeader, (absl::string_view key, absl::string_view value), + (override)); + MOCK_METHOD(void, OnTrailerInput, (absl::string_view input), (override)); + MOCK_METHOD(void, ProcessHeaders, (const FakeHeaders& headers)); + MOCK_METHOD(void, ProcessTrailers, (const FakeHeaders& headers)); + MOCK_METHOD(void, OnRequestFirstLineInput, + (absl::string_view line_input, absl::string_view method_input, + absl::string_view request_uri, absl::string_view version_input), + (override)); + MOCK_METHOD(void, OnResponseFirstLineInput, + (absl::string_view line_input, absl::string_view version_input, + absl::string_view status_input, absl::string_view reason_input), + (override)); + MOCK_METHOD(void, OnChunkLength, (size_t length), (override)); + MOCK_METHOD(void, OnChunkExtensionInput, (absl::string_view input), + (override)); + MOCK_METHOD(void, OnInterimHeaders, (BalsaHeaders headers), (override)); + MOCK_METHOD(void, ContinueHeaderDone, (), (override)); + MOCK_METHOD(void, HeaderDone, (), (override)); + MOCK_METHOD(void, MessageDone, (), (override)); + MOCK_METHOD(void, HandleError, (BalsaFrameEnums::ErrorCode error_code), + (override)); + MOCK_METHOD(void, HandleWarning, (BalsaFrameEnums::ErrorCode error_code), + (override)); + + private: + static void GenerateFakeHeaders(const BalsaHeaders& headers, + FakeHeaders* fake_headers) { + for (const auto& line : headers.lines()) { + fake_headers->AddKeyValue(std::string(line.first), + std::string(line.second)); + } + } +}; + +class HTTPBalsaFrameTest : public QuicheTest { + protected: + void SetUp() override { + balsa_frame_.set_balsa_headers(&headers_); + balsa_frame_.set_balsa_trailer(&trailer_); + balsa_frame_.set_balsa_visitor(&visitor_mock_); + balsa_frame_.set_is_request(true); + + EXPECT_CALL(visitor_mock_, OnHeader).Times(AnyNumber()); + } + + void VerifyFirstLineParsing(const std::string& firstline, + BalsaFrameEnums::ErrorCode error_code) { + balsa_frame_.ProcessInput(firstline.data(), firstline.size()); + EXPECT_EQ(error_code, balsa_frame_.ErrorCode()); + } + + BalsaHeaders headers_; + BalsaHeaders trailer_; + BalsaFrame balsa_frame_; + NiceMock visitor_mock_; +}; + +// Test correct return value for HeaderFramingFound. +TEST_F(HTTPBalsaFrameTest, TestHeaderFramingFound) { + // Pattern \r\n\r\n should match kValidTerm1. + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, ' ')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\r')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\r')); + EXPECT_EQ(BalsaFrame::kValidTerm1, + BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + + // Pattern \n\r\n should match kValidTerm1. + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\t')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\r')); + EXPECT_EQ(BalsaFrame::kValidTerm1, + BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + + // Pattern \r\n\n should match kValidTerm2. + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, 'a')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\r')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + EXPECT_EQ(BalsaFrame::kValidTerm2, + BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + + // Pattern \n\n should match kValidTerm2. + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '1')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + EXPECT_EQ(BalsaFrame::kValidTerm2, + BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); + + // Other patterns should not match. + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, ':')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\r')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\r')); + EXPECT_EQ(0, BalsaFrameTestPeer::HeaderFramingFound(&balsa_frame_, '\n')); +} + +TEST_F(HTTPBalsaFrameTest, MissingColonInTrailer) { + const absl::string_view trailer = "kv\r\n\r\n"; + + BalsaFrame::Lines lines; + lines.push_back({0, 4}); + lines.push_back({4, trailer.length()}); + BalsaHeadersTestPeer::WriteFromFramer(&trailer_, trailer.data(), + trailer.length()); + BalsaFrameTestPeer::FindColonsAndParseIntoKeyValue( + &balsa_frame_, lines, true /*is_trailer*/, &trailer_); + // Note missing colon is not an error, just a warning. + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::TRAILER_MISSING_COLON, balsa_frame_.ErrorCode()); +} + +// Correctness of FindColonsAndParseIntoKeyValue is already verified for +// headers, so trailer related test is light. +TEST_F(HTTPBalsaFrameTest, FindColonsAndParseIntoKeyValueInTrailer) { + const absl::string_view trailer_line1 = "Fraction: 0.23\r\n"; + const absl::string_view trailer_line2 = "Some:junk \r\n"; + const absl::string_view trailer_line3 = "\r\n"; + const std::string trailer = + absl::StrCat(trailer_line1, trailer_line2, trailer_line3); + + BalsaFrame::Lines lines; + lines.push_back({0, trailer_line1.length()}); + lines.push_back({trailer_line1.length(), + trailer_line1.length() + trailer_line2.length()}); + lines.push_back( + {trailer_line1.length() + trailer_line2.length(), trailer.length()}); + BalsaHeadersTestPeer::WriteFromFramer(&trailer_, trailer.data(), + trailer.length()); + BalsaFrameTestPeer::FindColonsAndParseIntoKeyValue( + &balsa_frame_, lines, true /*is_trailer*/, &trailer_); + EXPECT_FALSE(balsa_frame_.Error()); + absl::string_view fraction = trailer_.GetHeader("Fraction"); + EXPECT_EQ("0.23", fraction); + absl::string_view some = trailer_.GetHeader("Some"); + EXPECT_EQ("junk", some); +} + +TEST_F(HTTPBalsaFrameTest, InvalidTrailer) { + const absl::string_view trailer_line1 = "Fraction : 0.23\r\n"; + const absl::string_view trailer_line2 = "Some\t :junk \r\n"; + const absl::string_view trailer_line3 = "\r\n"; + const std::string trailer = + absl::StrCat(trailer_line1, trailer_line2, trailer_line3); + + BalsaFrame::Lines lines; + lines.push_back({0, trailer_line1.length()}); + lines.push_back({trailer_line1.length(), + trailer_line1.length() + trailer_line2.length()}); + lines.push_back( + {trailer_line1.length() + trailer_line2.length(), trailer.length()}); + BalsaHeadersTestPeer::WriteFromFramer(&trailer_, trailer.data(), + trailer.length()); + BalsaFrameTestPeer::FindColonsAndParseIntoKeyValue( + &balsa_frame_, lines, true /*is_trailer*/, &trailer_); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_TRAILER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, OneCharacterFirstLineParsedAsExpected) { + VerifyFirstLineParsing( + "a\r\n\r\n", BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD); +} + +TEST_F(HTTPBalsaFrameTest, + OneCharacterFirstLineWithWhitespaceParsedAsExpected) { + VerifyFirstLineParsing( + "a \r\n\r\n", BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD); +} + +TEST_F(HTTPBalsaFrameTest, WhitespaceOnlyFirstLineIsNotACompleteHeader) { + VerifyFirstLineParsing(" \n\n", BalsaFrameEnums::NO_REQUEST_LINE_IN_REQUEST); +} + +TEST(HTTPBalsaFrame, RequestFirstLineParsedCorrectly) { + const char* request_tokens[3] = {"GET", "/jjsdjrqk", "HTTP/1.0"}; + FirstLineParsedCorrectlyHelper(request_tokens, 0, true, " "); + FirstLineParsedCorrectlyHelper(request_tokens, 0, true, "\t"); + FirstLineParsedCorrectlyHelper(request_tokens, 0, true, "\t "); + FirstLineParsedCorrectlyHelper(request_tokens, 0, true, " \t"); + FirstLineParsedCorrectlyHelper(request_tokens, 0, true, " \t \t "); +} + +TEST_F(HTTPBalsaFrameTest, NonnumericResponseCode) { + balsa_frame_.set_is_request(false); + + VerifyFirstLineParsing("HTTP/1.1 0x3 Digits only\r\n\r\n", + BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT); + + EXPECT_EQ("HTTP/1.1 0x3 Digits only", headers_.first_line()); +} + +TEST_F(HTTPBalsaFrameTest, NegativeResponseCode) { + balsa_frame_.set_is_request(false); + + VerifyFirstLineParsing("HTTP/1.1 -11 No sign allowed\r\n\r\n", + BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT); + + EXPECT_EQ("HTTP/1.1 -11 No sign allowed", headers_.first_line()); +} + +TEST_F(HTTPBalsaFrameTest, WithoutTrailingWhitespace) { + balsa_frame_.set_is_request(false); + + VerifyFirstLineParsing( + "HTTP/1.1 101\r\n\r\n", + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_STATUSCODE); + + EXPECT_EQ("HTTP/1.1 101", headers_.first_line()); +} + +TEST_F(HTTPBalsaFrameTest, TrailingWhitespace) { + balsa_frame_.set_is_request(false); + + // b/69446061 + std::string firstline = "HTTP/1.1 101 \r\n\r\n"; + balsa_frame_.ProcessInput(firstline.data(), firstline.size()); + + EXPECT_EQ("HTTP/1.1 101 ", headers_.first_line()); +} + +TEST(HTTPBalsaFrame, ResponseFirstLineParsedCorrectly) { + const char* response_tokens[3] = {"HTTP/1.1", "200", "A reason\tphrase"}; + FirstLineParsedCorrectlyHelper(response_tokens, 200, false, " "); + FirstLineParsedCorrectlyHelper(response_tokens, 200, false, "\t"); + FirstLineParsedCorrectlyHelper(response_tokens, 200, false, "\t "); + FirstLineParsedCorrectlyHelper(response_tokens, 200, false, " \t"); + FirstLineParsedCorrectlyHelper(response_tokens, 200, false, " \t \t "); + + response_tokens[1] = "312"; + FirstLineParsedCorrectlyHelper(response_tokens, 312, false, " "); + FirstLineParsedCorrectlyHelper(response_tokens, 312, false, "\t"); + FirstLineParsedCorrectlyHelper(response_tokens, 312, false, "\t "); + FirstLineParsedCorrectlyHelper(response_tokens, 312, false, " \t"); + FirstLineParsedCorrectlyHelper(response_tokens, 312, false, " \t \t "); + + // Who knows what the future may hold w.r.t. response codes?! + response_tokens[1] = "4242"; + FirstLineParsedCorrectlyHelper(response_tokens, 4242, false, " "); + FirstLineParsedCorrectlyHelper(response_tokens, 4242, false, "\t"); + FirstLineParsedCorrectlyHelper(response_tokens, 4242, false, "\t "); + FirstLineParsedCorrectlyHelper(response_tokens, 4242, false, " \t"); + FirstLineParsedCorrectlyHelper(response_tokens, 4242, false, " \t \t "); +} + +void HeaderLineTestHelper(const char* firstline, bool is_request, + const std::pair* headers, + size_t headers_len, const char* colon, + const char* line_ending) { + BalsaHeaders balsa_headers; + BalsaFrame framer; + framer.set_is_request(is_request); + framer.set_balsa_headers(&balsa_headers); + std::string message = + CreateMessage(firstline, headers, headers_len, colon, line_ending, ""); + SCOPED_TRACE(EscapeString(message)); + size_t bytes_consumed = framer.ProcessInput(message.data(), message.size()); + EXPECT_EQ(message.size(), bytes_consumed); + VerifyHeaderLines(headers, headers_len, *framer.headers()); +} + +TEST(HTTPBalsaFrame, RequestLinesParsedProperly) { + SCOPED_TRACE("Testing that lines are properly parsed."); + const char firstline[] = "GET / HTTP/1.1\r\n"; + const std::pair headers[] = { + std::pair("foo", "bar"), + std::pair("duck", "water"), + std::pair("goose", "neck"), + std::pair("key_is_fine", + "value:includes:colons"), + std::pair("trucks", + "along\rvalue\rincluding\rslash\rrs"), + std::pair("monster", "truck"), + std::pair("another_key", ":colons in value"), + std::pair("another_key", "colons in value:"), + std::pair("another_key", + "value includes\r\n continuation"), + std::pair("key_without_continuations", + "multiple\n in\r\n the\n value"), + std::pair("key_without_value", + ""), // empty value + std::pair("", + "value without key"), // empty key + std::pair("", ""), // both key and value empty + std::pair("normal_key", "normal_value"), + }; + const size_t headers_len = ABSL_ARRAYSIZE(headers); + HeaderLineTestHelper(firstline, true, headers, headers_len, ": ", "\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ": ", "\r\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t", "\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t", "\r\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t ", "\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t ", "\r\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t\t", "\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t\t", "\r\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t \t", "\n"); + HeaderLineTestHelper(firstline, true, headers, headers_len, ":\t \t", "\r\n"); +} + +TEST(HTTPBalsaFrame, ResponseLinesParsedProperly) { + SCOPED_TRACE("ResponseLineParsedProperly"); + const char firstline[] = "HTTP/1.0 200 A reason\tphrase\r\n"; + const std::pair headers[] = { + std::pair("foo", "bar"), + std::pair("duck", "water"), + std::pair("goose", "neck"), + std::pair("key_is_fine", + "value:includes:colons"), + std::pair("trucks", + "along\rvalue\rincluding\rslash\rrs"), + std::pair("monster", "truck"), + std::pair("another_key", ":colons in value"), + std::pair("another_key", "colons in value:"), + std::pair("another_key", + "value includes\r\n continuation"), + std::pair("key_includes_no_continuations", + "multiple\n in\r\n the\n value"), + std::pair("key_without_value", + ""), // empty value + std::pair("", + "value without key"), // empty key + std::pair("", ""), // both key and value empty + std::pair("normal_key", "normal_value"), + }; + const size_t headers_len = ABSL_ARRAYSIZE(headers); + HeaderLineTestHelper(firstline, false, headers, headers_len, ": ", "\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ": ", "\r\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t", "\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t", "\r\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t ", "\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t ", "\r\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t\t", "\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t\t", "\r\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t \t", "\n"); + HeaderLineTestHelper(firstline, false, headers, headers_len, ":\t \t", + "\r\n"); +} + +void WhitespaceHeaderTestHelper( + const std::string& message, bool is_request, + BalsaFrameEnums::ErrorCode expected_error_code) { + BalsaHeaders balsa_headers; + BalsaFrame framer; + framer.set_is_request(is_request); + framer.set_balsa_headers(&balsa_headers); + SCOPED_TRACE(EscapeString(message)); + size_t bytes_consumed = framer.ProcessInput(message.data(), message.size()); + EXPECT_EQ(message.size(), bytes_consumed); + if (expected_error_code == BalsaFrameEnums::BALSA_NO_ERROR) { + EXPECT_EQ(false, framer.Error()); + } else { + EXPECT_EQ(true, framer.Error()); + } + EXPECT_EQ(expected_error_code, framer.ErrorCode()); +} + +TEST(HTTPBalsaFrame, WhitespaceInRequestsProcessedProperly) { + SCOPED_TRACE( + "Test that a request header with a line with spaces and no " + "data generates an error."); + WhitespaceHeaderTestHelper( + "GET / HTTP/1.1\r\n" + " \r\n" + "\r\n", + true, BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER); + WhitespaceHeaderTestHelper( + "GET / HTTP/1.1\r\n" + " \r\n" + "test: test\r\n" + "\r\n", + true, BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER); + + SCOPED_TRACE("Test proper handling for line continuation in requests."); + WhitespaceHeaderTestHelper( + "GET / HTTP/1.1\r\n" + "test: test\r\n" + " continued\r\n" + "\r\n", + true, BalsaFrameEnums::BALSA_NO_ERROR); + WhitespaceHeaderTestHelper( + "GET / HTTP/1.1\r\n" + "test: test\r\n" + " \r\n" + "\r\n", + true, BalsaFrameEnums::BALSA_NO_ERROR); +} + +TEST(HTTPBalsaFrame, WhitespaceInResponsesProcessedProperly) { + SCOPED_TRACE( + "Test that a response header with a line with spaces and no " + "data generates an error."); + WhitespaceHeaderTestHelper( + "HTTP/1.0 200 Reason\r\n" + " \r\nContent-Length: 0\r\n" + "\r\n", + false, BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER); + + SCOPED_TRACE("Test proper handling for line continuation in responses."); + WhitespaceHeaderTestHelper( + "HTTP/1.0 200 Reason\r\n" + "test: test\r\n" + " continued\r\n" + "Content-Length: 0\r\n" + "\r\n", + false, BalsaFrameEnums::BALSA_NO_ERROR); + WhitespaceHeaderTestHelper( + "HTTP/1.0 200 Reason\r\n" + "test: test\r\n" + " \r\n" + "Content-Length: 0\r\n" + "\r\n", + false, BalsaFrameEnums::BALSA_NO_ERROR); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyForTrivialRequest) { + std::string message = "GET /foobar HTTP/1.0\r\n\n"; + + FakeHeaders fake_headers; + + { + InSequence s; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("GET /foobar HTTP/1.0", "GET", + "/foobar", "HTTP/1.0")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyForRequestWithBlankLines) { + std::string message = "\n\n\r\n\nGET /foobar HTTP/1.0\r\n\n"; + + FakeHeaders fake_headers; + + { + InSequence s1; + // Yes, that is correct-- the framer 'eats' the blank-lines at the beginning + // and never notifies the visitor. + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("GET /foobar HTTP/1.0", "GET", + "/foobar", "HTTP/1.0")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput("GET /foobar HTTP/1.0\r\n\n")); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWitSplithBlankLines) { + std::string blanks = + "\n" + "\n" + "\r\n" + "\n"; + std::string header_input = "GET /foobar HTTP/1.0\r\n\n"; + + FakeHeaders fake_headers; + + { + InSequence s1; + // Yes, that is correct-- the framer 'eats' the blank-lines at the beginning + // and never notifies the visitor. + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("GET /foobar HTTP/1.0", "GET", + "/foobar", "HTTP/1.0")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput("GET /foobar HTTP/1.0\r\n\n")); + + ASSERT_EQ(blanks.size(), + balsa_frame_.ProcessInput(blanks.data(), blanks.size())); + ASSERT_EQ(header_input.size(), balsa_frame_.ProcessInput( + header_input.data(), header_input.size())); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWithZeroContentLength) { + std::string message = + "PUT /search?q=fo HTTP/1.1\n" + "content-length: 0 \n" + "\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "0"); + + { + InSequence s1; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT /search?q=fo HTTP/1.1", "PUT", + "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWithMissingContentLength) { + std::string message = + "PUT /search?q=fo HTTP/1.1\n" + "\n"; + + auto error_code = + BalsaFrameEnums::BalsaFrameEnums::REQUIRED_BODY_BUT_NO_CONTENT_LENGTH; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(error_code, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForPermittedMissingContentLength) { + std::string message = + "PUT /search?q=fo HTTP/1.1\n" + "\n"; + + FakeHeaders fake_headers; + + { + InSequence s1; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT /search?q=fo HTTP/1.1", "PUT", + "/search?q=fo", "HTTP/1.1")); + } + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, NothingBadHappensWhenNothingInConnectionLine) { + // This is similar to the test above, but we use different whitespace + // throughout. + std::string message = + "PUT \t /search?q=fo \t HTTP/1.1 \t \r\n" + "Connection:\r\n" + "content-length: 0\r\n" + "\r\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("Connection", ""); + fake_headers.AddKeyValue("content-length", "0"); + + { + InSequence s1; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT \t /search?q=fo \t HTTP/1.1", + "PUT", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, NothingBadHappensWhenOnlyCommentsInConnectionLine) { + // This is similar to the test above, but we use different whitespace + // throughout. + std::string message = + "PUT \t /search?q=fo \t HTTP/1.1 \t \r\n" + "Connection: ,,,,,,,,\r\n" + "content-length: 0\r\n" + "\r\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("Connection", ",,,,,,,,"); + fake_headers.AddKeyValue("content-length", "0"); + + { + InSequence s1; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT \t /search?q=fo \t HTTP/1.1", + "PUT", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWithZeroContentLengthMk2) { + // This is similar to the test above, but we use different whitespace + // throughout. + std::string message = + "PUT \t /search?q=fo \t HTTP/1.1 \t \r\n" + "Connection: \t close \t\r\n" + "content-length: \t\t 0 \t\t \r\n" + "\r\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("Connection", "close"); + fake_headers.AddKeyValue("content-length", "0"); + + { + InSequence s1; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT \t /search?q=fo \t HTTP/1.1", + "PUT", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); +} + +TEST_F(HTTPBalsaFrameTest, NothingBadHappensWhenNoVisitorIsAssigned) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\r\n"; + + balsa_frame_.set_balsa_visitor(nullptr); + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + const absl::string_view crass = trailer_.GetHeader("crass"); + EXPECT_EQ("monkeys", crass); + const absl::string_view funky = trailer_.GetHeader("funky"); + EXPECT_EQ("monkeys", funky); +} + +TEST_F(HTTPBalsaFrameTest, RequestWithTrailers) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\r\n"; + + InSequence s; + + // OnHeader() visitor method is called as soon as headers are parsed. + EXPECT_CALL(visitor_mock_, OnHeader("Connection", "close")); + EXPECT_CALL(visitor_mock_, OnHeader("transfer-encoding", "chunked")); + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + testing::Mock::VerifyAndClearExpectations(&visitor_mock_); + + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + + EXPECT_CALL(visitor_mock_, OnHeader("crass", "monkeys")); + EXPECT_CALL(visitor_mock_, OnHeader("funky", "monkeys")); + + FakeHeaders fake_trailers; + fake_trailers.AddKeyValue("crass", "monkeys"); + fake_trailers.AddKeyValue("funky", "monkeys"); + EXPECT_CALL(visitor_mock_, ProcessTrailers(fake_trailers)); + + EXPECT_CALL(visitor_mock_, OnTrailerInput(_)).Times(AtLeast(1)); + + EXPECT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + const absl::string_view crass = trailer_.GetHeader("crass"); + EXPECT_EQ("monkeys", crass); + const absl::string_view funky = trailer_.GetHeader("funky"); + EXPECT_EQ("monkeys", funky); +} + +TEST_F(HTTPBalsaFrameTest, NothingBadHappensWhenNoVisitorIsAssignedInResponse) { + std::string headers = + "HTTP/1.1 502 Bad Gateway\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + balsa_frame_.set_balsa_visitor(nullptr); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + const absl::string_view crass = trailer_.GetHeader("crass"); + EXPECT_EQ("monkeys", crass); + const absl::string_view funky = trailer_.GetHeader("funky"); + EXPECT_EQ("monkeys", funky); +} + +TEST_F(HTTPBalsaFrameTest, TransferEncodingIdentityIsIgnored) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: identity\r\n" + "content-length: 10\r\n" + "\r\n"; + + std::string body = "1234567890"; + std::string message = (headers + body); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + ASSERT_EQ(body.size(), balsa_frame_.ProcessInput(body.data(), body.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + NothingBadHappensWhenAVisitorIsChangedToNULLInMidParsing) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + balsa_frame_.set_balsa_visitor(nullptr); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + ASSERT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + NothingBadHappensWhenAVisitorIsChangedToNULLInMidParsingInTrailer) { + std::string headers = + "HTTP/1.1 503 Server Not Available\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + balsa_frame_.set_is_request(false); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + balsa_frame_.set_balsa_visitor(nullptr); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + ASSERT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + const absl::string_view crass = trailer_.GetHeader("crass"); + EXPECT_EQ("monkeys", crass); + const absl::string_view funky = trailer_.GetHeader("funky"); + EXPECT_EQ("monkeys", funky); +} + +TEST_F(HTTPBalsaFrameTest, + NothingBadHappensWhenNoVisitorAssignedAndChunkingErrorOccurs) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF\r\n" // should overflow + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + balsa_frame_.set_balsa_visitor(nullptr); + EXPECT_GE(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::CHUNK_LENGTH_OVERFLOW, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FramerRecognizesSemicolonAsChunkSizeDelimiter) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "8; foo=bar\r\n" + "deadbeef\r\n" + "0\r\n" + "\r\n"; + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + + balsa_frame_.set_balsa_visitor(&visitor_mock_); + EXPECT_CALL(visitor_mock_, OnChunkLength(8)); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput("; foo=bar")); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput("")); + + EXPECT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); +} + +TEST_F(HTTPBalsaFrameTest, NonAsciiCharacterInChunkLength) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "555\xAB\r\n" // Character overflowing 7 bits, see b/20238315 + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("Connection", "close"); + fake_headers.AddKeyValue("transfer-encoding", "chunked"); + + auto error_code = BalsaFrameEnums::INVALID_CHUNK_LENGTH; + { + InSequence s1; + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET / HTTP/1.1", "GET", + "/", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnRawBodyInput("555\xAB")); + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + } + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + EXPECT_EQ(strlen("555\xAB"), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_CHUNK_LENGTH, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, VisitorCalledAsExpectedWhenChunkingOverflowOccurs) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF\r\n" // should overflow + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + const char* chunk_read_before_overflow = "FFFFFFFFFFFFFFFFF"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("Connection", "close"); + fake_headers.AddKeyValue("transfer-encoding", "chunked"); + + auto error_code = BalsaFrameEnums::CHUNK_LENGTH_OVERFLOW; + { + InSequence s1; + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET / HTTP/1.1", "GET", + "/", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnRawBodyInput(chunk_read_before_overflow)); + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + } + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + EXPECT_EQ(strlen(chunk_read_before_overflow), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::CHUNK_LENGTH_OVERFLOW, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorCalledAsExpectedWhenInvalidChunkLengthOccurs) { + std::string headers = + "GET / HTTP/1.1\r\n" + "Connection: close\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "12z123 \r\n" // invalid chunk length + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("Connection", "close"); + fake_headers.AddKeyValue("transfer-encoding", "chunked"); + + auto error_code = BalsaFrameEnums::INVALID_CHUNK_LENGTH; + { + InSequence s1; + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET / HTTP/1.1", "GET", + "/", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnRawBodyInput("12z")); + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + } + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + EXPECT_EQ(3u, balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_CHUNK_LENGTH, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyForRequestWithContentLength) { + std::string message_headers = + "PUT \t /search?q=fo \t HTTP/1.1 \t \r\n" + "content-length: \t\t 20 \t\t \r\n" + "\r\n"; + std::string message_body = "12345678901234567890"; + std::string message = + std::string(message_headers) + std::string(message_body); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "20"); + + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT \t /search?q=fo \t HTTP/1.1", + "PUT", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnRawBodyInput(message_body)); + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(message_body)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + ASSERT_EQ(message_body.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWithOneCharContentLength) { + std::string message_headers = + "PUT \t /search?q=fo \t HTTP/1.1 \t \r\n" + "content-length: \t\t 2 \t\t \r\n" + "\r\n"; + std::string message_body = "12"; + std::string message = + std::string(message_headers) + std::string(message_body); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "2"); + + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("PUT \t /search?q=fo \t HTTP/1.1", + "PUT", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnRawBodyInput(message_body)); + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(message_body)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + ASSERT_EQ(message_body.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWithTransferEncoding) { + std::string message_headers = + "DELETE /search?q=fo \t HTTP/1.1 \t \r\n" + "trAnsfer-eNcoding: chunked\r\n" + "\r\n"; + std::string message_body = + "A chunkjed extension \r\n" + "01234567890 more crud including numbers 123123\r\n" + "3f\n" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n" + "0 last one\r\n" + "\r\n"; + std::string message_body_data = + "0123456789" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + std::string message = + std::string(message_headers) + std::string(message_body); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("trAnsfer-eNcoding", "chunked"); + + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("DELETE /search?q=fo \t HTTP/1.1", + "DELETE", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnChunkLength(10)); + EXPECT_CALL(visitor_mock_, + OnChunkExtensionInput(" chunkjed extension ")); + EXPECT_CALL(visitor_mock_, OnChunkLength(63)); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput("")); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput(" last one")); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + std::string body_input; + EXPECT_CALL(visitor_mock_, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + std::string body_data; + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(_)) + .WillRepeatedly([&body_data](absl::string_view input) { + absl::StrAppend(&body_data, input); + }); + EXPECT_CALL(visitor_mock_, OnTrailerInput(_)).Times(0); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_EQ(message_body.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + EXPECT_EQ(message_body, body_input); + EXPECT_EQ(message_body_data, body_data); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForRequestWithTransferEncodingAndTrailers) { + std::string message_headers = + "DELETE /search?q=fo \t HTTP/1.1 \t \r\n" + "trAnsfer-eNcoding: chunked\r\n" + "another_random_header: \r\n" + " \t \n" + " \t includes a continuation\n" + "\r\n"; + std::string message_body = + "A chunkjed extension \r\n" + "01234567890 more crud including numbers 123123\r\n" + "3f\n" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n" + "1 \r\n" + "x \r\n" + "0 last one\r\n"; + std::string trailer_data = + "a_trailer_key: and a trailer value\r\n" + "\r\n"; + std::string message_body_data = + "0123456789" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + std::string message = (std::string(message_headers) + + std::string(message_body) + std::string(trailer_data)); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("trAnsfer-eNcoding", "chunked"); + fake_headers.AddKeyValue("another_random_header", "includes a continuation"); + + { + InSequence s1; + + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("DELETE /search?q=fo \t HTTP/1.1", + "DELETE", "/search?q=fo", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnChunkLength(10)); + EXPECT_CALL(visitor_mock_, OnChunkLength(63)); + EXPECT_CALL(visitor_mock_, OnChunkLength(1)); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + std::string body_input; + EXPECT_CALL(visitor_mock_, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + std::string body_data; + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(_)) + .WillRepeatedly([&body_data](absl::string_view input) { + absl::StrAppend(&body_data, input); + }); + EXPECT_CALL(visitor_mock_, OnTrailerInput(trailer_data)); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput(_)).Times(AnyNumber()); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_EQ(message_body.size() + trailer_data.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + EXPECT_EQ(message_body, body_input); + EXPECT_EQ(message_body_data, body_data); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyWithRequestFirstLineWarningWithOnlyMethod) { + std::string message = "GET\n"; + + FakeHeaders fake_headers; + + auto error_code = BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD; + { + InSequence s; + EXPECT_CALL(visitor_mock_, HandleWarning(error_code)); + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET", "GET", "", "")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyWithRequestFirstLineWarningWithOnlyMethodAndWS) { + std::string message = "GET \n"; + + FakeHeaders fake_headers; + + auto error_code = BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD; + { + InSequence s; + EXPECT_CALL(visitor_mock_, HandleWarning(error_code)); + // The flag setting here intentionally alters the framer's behavior with + // trailing whitespace. + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET ", "GET", "", "")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_METHOD, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyWithRequestFirstLineWarningWithMethodAndURI) { + std::string message = "GET /uri\n"; + + FakeHeaders fake_headers; + + auto error_code = + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI; + { + InSequence s; + EXPECT_CALL(visitor_mock_, HandleWarning(error_code)); + EXPECT_CALL(visitor_mock_, + OnRequestFirstLineInput("GET /uri", "GET", "/uri", "")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyWithResponseFirstLineError) { + std::string message = "HTTP/1.1\n\n"; + + FakeHeaders fake_headers; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION; + { + InSequence s; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + // The function returns before any of the following is called. + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput).Times(0); + EXPECT_CALL(visitor_mock_, ProcessHeaders(_)).Times(0); + EXPECT_CALL(visitor_mock_, HeaderDone()).Times(0); + EXPECT_CALL(visitor_mock_, MessageDone()).Times(0); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(_)).Times(0); + + EXPECT_GE(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_RESPONSE_VERSION, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FlagsErrorWithContentLengthOverflow) { + std::string message = + "HTTP/1.0 200 OK\r\n" + "content-length: 9999999999999999999999999999999999999999\n" + "\n"; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FlagsErrorWithInvalidResponseCode) { + std::string message = + "HTTP/1.0 x OK\r\n" + "\n"; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + EXPECT_GE(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FlagsErrorWithOverflowingResponseCode) { + std::string message = + "HTTP/1.0 999999999999999999999999999999999999999 OK\r\n" + "\n"; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + EXPECT_GE(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_CONVERTING_STATUS_CODE_TO_INT, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FlagsErrorWithInvalidContentLength) { + std::string message = + "HTTP/1.0 200 OK\r\n" + "content-length: xxx\n" + "\n"; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FlagsErrorWithNegativeContentLengthValue) { + std::string message = + "HTTP/1.0 200 OK\r\n" + "content-length: -20\n" + "\n"; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, FlagsErrorWithEmptyContentLengthValue) { + std::string message = + "HTTP/1.0 200 OK\r\n" + "content-length: \n" + "\n"; + + balsa_frame_.set_is_request(false); + auto error_code = BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::UNPARSABLE_CONTENT_LENGTH, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyForTrivialResponse) { + std::string message = + "HTTP/1.0 200 OK\r\n" + "content-length: 0\n" + "\n"; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "0"); + + balsa_frame_.set_is_request(false); + { + InSequence s; + EXPECT_CALL(visitor_mock_, OnResponseFirstLineInput( + "HTTP/1.0 200 OK", "HTTP/1.0", "200", "OK")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForResponseWithSplitBlankLines) { + std::string blanks = + "\n" + "\r\n" + "\r\n"; + std::string header_input = + "HTTP/1.0 200 OK\r\n" + "content-length: 0\n" + "\n"; + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "0"); + + balsa_frame_.set_is_request(false); + { + InSequence s; + EXPECT_CALL(visitor_mock_, OnResponseFirstLineInput( + "HTTP/1.0 200 OK", "HTTP/1.0", "200", "OK")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(header_input)); + + EXPECT_EQ(blanks.size(), + balsa_frame_.ProcessInput(blanks.data(), blanks.size())); + EXPECT_EQ(header_input.size(), balsa_frame_.ProcessInput( + header_input.data(), header_input.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyForResponseWithBlankLines) { + std::string blanks = + "\n" + "\r\n" + "\n" + "\n" + "\r\n" + "\r\n"; + std::string header_input = + "HTTP/1.0 200 OK\r\n" + "content-length: 0\n" + "\n"; + std::string message = blanks + header_input; + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "0"); + + balsa_frame_.set_is_request(false); + { + InSequence s; + EXPECT_CALL(visitor_mock_, OnResponseFirstLineInput( + "HTTP/1.0 200 OK", "HTTP/1.0", "200", "OK")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(header_input)); + + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, VisitorInvokedProperlyForResponseWithContentLength) { + std::string message_headers = + "HTTP/1.1 \t 200 Ok all is well\r\n" + "content-length: \t\t 20 \t\t \r\n" + "\r\n"; + std::string message_body = "12345678901234567890"; + std::string message = + std::string(message_headers) + std::string(message_body); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("content-length", "20"); + + balsa_frame_.set_is_request(false); + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnResponseFirstLineInput("HTTP/1.1 \t 200 Ok all is well", + "HTTP/1.1", "200", "Ok all is well")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnRawBodyInput(message_body)); + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(message_body)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_EQ(message_body.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForResponseWithTransferEncoding) { + std::string message_headers = + "HTTP/1.1 \t 200 Ok all is well\r\n" + "trAnsfer-eNcoding: chunked\r\n" + "\r\n"; + std::string message_body = + "A chunkjed extension \r\n" + "01234567890 more crud including numbers 123123\r\n" + "3f\n" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n" + "0 last one\r\n" + "\r\n"; + std::string message_body_data = + "0123456789" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + std::string message = + std::string(message_headers) + std::string(message_body); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("trAnsfer-eNcoding", "chunked"); + + balsa_frame_.set_is_request(false); + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnResponseFirstLineInput("HTTP/1.1 \t 200 Ok all is well", + "HTTP/1.1", "200", "Ok all is well")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnChunkLength(10)); + EXPECT_CALL(visitor_mock_, OnChunkLength(63)); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + std::string body_input; + EXPECT_CALL(visitor_mock_, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + std::string body_data; + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(_)) + .WillRepeatedly([&body_data](absl::string_view input) { + absl::StrAppend(&body_data, input); + }); + EXPECT_CALL(visitor_mock_, OnTrailerInput(_)).Times(0); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_EQ(message_body.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + EXPECT_EQ(message_body, body_input); + EXPECT_EQ(message_body_data, body_data); +} + +TEST_F(HTTPBalsaFrameTest, + VisitorInvokedProperlyForResponseWithTransferEncodingAndTrailers) { + std::string message_headers = + "HTTP/1.1 \t 200 Ok all is well\r\n" + "trAnsfer-eNcoding: chunked\r\n" + "\r\n"; + std::string message_body = + "A chunkjed extension \r\n" + "01234567890 more crud including numbers 123123\r\n" + "3f\n" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n" + "0 last one\r\n"; + std::string trailer_data = + "a_trailer_key: and a trailer value\r\n" + "\r\n"; + std::string message_body_data = + "0123456789" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + std::string message = (std::string(message_headers) + + std::string(message_body) + std::string(trailer_data)); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("trAnsfer-eNcoding", "chunked"); + + FakeHeaders fake_headers_in_trailer; + fake_headers_in_trailer.AddKeyValue("a_trailer_key", "and a trailer value"); + + balsa_frame_.set_is_request(false); + + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnResponseFirstLineInput("HTTP/1.1 \t 200 Ok all is well", + "HTTP/1.1", "200", "Ok all is well")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnChunkLength(10)); + EXPECT_CALL(visitor_mock_, OnChunkLength(63)); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, ProcessTrailers(fake_headers_in_trailer)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + std::string body_input; + EXPECT_CALL(visitor_mock_, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + std::string body_data; + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(_)) + .WillRepeatedly([&body_data](absl::string_view input) { + absl::StrAppend(&body_data, input); + }); + EXPECT_CALL(visitor_mock_, OnTrailerInput(trailer_data)); + + ASSERT_EQ(message_headers.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_EQ(message_body.size() + trailer_data.size(), + balsa_frame_.ProcessInput(message.data() + message_headers.size(), + message.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + EXPECT_EQ(message_body, body_input); + EXPECT_EQ(message_body_data, body_data); + + const absl::string_view a_trailer_key = trailer_.GetHeader("a_trailer_key"); + EXPECT_EQ("and a trailer value", a_trailer_key); +} + +TEST_F( + HTTPBalsaFrameTest, + VisitorInvokedProperlyForResponseWithTransferEncodingAndTrailersBytePer) { + std::string message_headers = + "HTTP/1.1 \t 200 Ok all is well\r\n" + "trAnsfer-eNcoding: chunked\r\n" + "\r\n"; + std::string message_body = + "A chunkjed extension \r\n" + "01234567890 more crud including numbers 123123\r\n" + "3f\n" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n" + "0 last one\r\n"; + std::string trailer_data = + "a_trailer_key: and a trailer value\r\n" + "\r\n"; + std::string message_body_data = + "0123456789" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + std::string message = (std::string(message_headers) + + std::string(message_body) + std::string(trailer_data)); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("trAnsfer-eNcoding", "chunked"); + FakeHeaders fake_headers_in_trailer; + fake_headers_in_trailer.AddKeyValue("a_trailer_key", "and a trailer value"); + + balsa_frame_.set_is_request(false); + + { + InSequence s1; + EXPECT_CALL(visitor_mock_, + OnResponseFirstLineInput("HTTP/1.1 \t 200 Ok all is well", + "HTTP/1.1", "200", "Ok all is well")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnChunkLength(10)); + EXPECT_CALL(visitor_mock_, OnChunkLength(63)); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, ProcessTrailers(fake_headers_in_trailer)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(message_headers)); + std::string body_input; + EXPECT_CALL(visitor_mock_, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + std::string body_data; + EXPECT_CALL(visitor_mock_, OnBodyChunkInput(_)) + .WillRepeatedly([&body_data](absl::string_view input) { + absl::StrAppend(&body_data, input); + }); + std::string trailer_input; + EXPECT_CALL(visitor_mock_, OnTrailerInput(_)) + .WillRepeatedly([&trailer_input](absl::string_view input) { + absl::StrAppend(&trailer_input, input); + }); + + for (size_t i = 0; i < message.size(); ++i) { + ASSERT_EQ(1u, balsa_frame_.ProcessInput(message.data() + i, 1)); + } + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + EXPECT_EQ(message_body, body_input); + EXPECT_EQ(message_body_data, body_data); + EXPECT_EQ(trailer_data, trailer_input); + + const absl::string_view a_trailer_key = trailer_.GetHeader("a_trailer_key"); + EXPECT_EQ("and a trailer value", a_trailer_key); +} + +TEST(HTTPBalsaFrame, + VisitorInvokedProperlyForResponseWithTransferEncodingAndTrailersRandom) { + TestSeed seed; + seed.Initialize(GetQuicheCommandLineFlag(FLAGS_randseed)); + RandomEngine rng; + rng.seed(seed.GetSeed()); + for (int i = 0; i < 1000; ++i) { + std::string message_headers = + "HTTP/1.1 \t 200 Ok all is well\r\n" + "trAnsfer-eNcoding: chunked\r\n" + "\r\n"; + std::string message_body = + "A chunkjed extension \r\n" + "01234567890 more crud including numbers 123123\r\n" + "3f\n" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n" + "0 last one\r\n"; + std::string trailer_data = + "a_trailer_key: and a trailer value\r\n" + "\r\n"; + std::string message_body_data = + "0123456789" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + std::string message = + (std::string(message_headers) + std::string(message_body) + + std::string(trailer_data)); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("trAnsfer-eNcoding", "chunked"); + FakeHeaders fake_headers_in_trailer; + fake_headers_in_trailer.AddKeyValue("a_trailer_key", "and a trailer value"); + + StrictMock visitor_mock; + + BalsaHeaders headers; + BalsaHeaders trailer; + BalsaFrame balsa_frame; + balsa_frame.set_is_request(false); + balsa_frame.set_balsa_headers(&headers); + balsa_frame.set_balsa_trailer(&trailer); + balsa_frame.set_balsa_visitor(&visitor_mock); + + { + InSequence s1; + EXPECT_CALL(visitor_mock, OnResponseFirstLineInput( + "HTTP/1.1 \t 200 Ok all is well", + "HTTP/1.1", "200", "Ok all is well")); + EXPECT_CALL(visitor_mock, OnHeader); + EXPECT_CALL(visitor_mock, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock, HeaderDone()); + EXPECT_CALL(visitor_mock, OnHeader); + EXPECT_CALL(visitor_mock, ProcessTrailers(fake_headers_in_trailer)); + EXPECT_CALL(visitor_mock, MessageDone()); + } + EXPECT_CALL(visitor_mock, OnHeaderInput(message_headers)); + std::string body_input; + EXPECT_CALL(visitor_mock, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + std::string body_data; + EXPECT_CALL(visitor_mock, OnBodyChunkInput(_)) + .WillRepeatedly([&body_data](absl::string_view input) { + absl::StrAppend(&body_data, input); + }); + std::string trailer_input; + EXPECT_CALL(visitor_mock, OnTrailerInput(_)) + .WillRepeatedly([&trailer_input](absl::string_view input) { + absl::StrAppend(&trailer_input, input); + }); + EXPECT_CALL(visitor_mock, OnChunkLength(_)).Times(AtLeast(1)); + EXPECT_CALL(visitor_mock, OnChunkExtensionInput(_)).Times(AtLeast(1)); + + size_t count = 0; + size_t total_processed = 0; + for (size_t i = 0; i < message.size();) { + auto dist = std::uniform_int_distribution<>(0, message.size() - i + 1); + count = dist(rng); + size_t processed = balsa_frame.ProcessInput(message.data() + i, count); + ASSERT_GE(count, processed); + total_processed += processed; + i += processed; + } + EXPECT_EQ(message.size(), total_processed); + EXPECT_TRUE(balsa_frame.MessageFullyRead()); + EXPECT_FALSE(balsa_frame.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame.ErrorCode()); + + EXPECT_EQ(message_body, body_input); + EXPECT_EQ(message_body_data, body_data); + EXPECT_EQ(trailer_data, trailer_input); + + const absl::string_view a_trailer_key = trailer.GetHeader("a_trailer_key"); + EXPECT_EQ("and a trailer value", a_trailer_key); + } +} + +TEST_F(HTTPBalsaFrameTest, + AppropriateActionTakenWhenHeadersTooLongWithTooMuchInput) { + const absl::string_view message = + "GET /asflkasfdhjsafdkljhasfdlkjhasdflkjhsafdlkjhh HTTP/1.1"; + const size_t kAmountLessThanHeaderLen = 10; + ASSERT_LE(kAmountLessThanHeaderLen, message.size()); + + auto error_code = BalsaFrameEnums::HEADERS_TOO_LONG; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + balsa_frame_.set_max_header_length(message.size() - kAmountLessThanHeaderLen); + + ASSERT_EQ(balsa_frame_.max_header_length(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::HEADERS_TOO_LONG, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, AppropriateActionTakenWhenHeadersTooLongWithBody) { + std::string message = + "PUT /foo HTTP/1.1\r\n" + "Content-Length: 4\r\n" + "header: xxxxxxxxx\r\n\r\n" + "B"; // body begin + + auto error_code = BalsaFrameEnums::HEADERS_TOO_LONG; + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + // -2 because we have 1 byte of body, and we want to refuse + // this. + balsa_frame_.set_max_header_length(message.size() - 2); + + ASSERT_EQ(balsa_frame_.max_header_length(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::HEADERS_TOO_LONG, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, AppropriateActionTakenWhenHeadersTooLongWhenReset) { + std::string message = + "GET /asflkasfdhjsafdkljhasfdlkjhasdflkjhsafdlkjhh HTTP/1.1\r\n" + "\r\n"; + const size_t kAmountLessThanHeaderLen = 10; + ASSERT_LE(kAmountLessThanHeaderLen, message.size()); + + auto error_code = BalsaFrameEnums::HEADERS_TOO_LONG; + + ASSERT_EQ(message.size() - 2, + balsa_frame_.ProcessInput(message.data(), message.size() - 2)); + + // Now set max header length to something smaller. + balsa_frame_.set_max_header_length(message.size() - kAmountLessThanHeaderLen); + EXPECT_CALL(visitor_mock_, HandleError(error_code)); + + ASSERT_EQ(0u, + balsa_frame_.ProcessInput(message.data() + message.size() - 2, 2)); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::HEADERS_TOO_LONG, balsa_frame_.ErrorCode()); +} + +class BalsaFrameParsingTest : public QuicheTest { + protected: + void SetUp() override { + balsa_frame_.set_is_request(true); + balsa_frame_.set_balsa_headers(&headers_); + balsa_frame_.set_balsa_visitor(&visitor_mock_); + } + + void TestEmptyHeaderKeyHelper(const std::string& message) { + InSequence s; + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET / HTTP/1.1", "GET", + "/", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, OnHeaderInput(_)); + EXPECT_CALL(visitor_mock_, OnHeader).Times(AnyNumber()); + EXPECT_CALL(visitor_mock_, + HandleError(BalsaFrameEnums::INVALID_HEADER_FORMAT)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + Mock::VerifyAndClearExpectations(&visitor_mock_); + } + + void TestInvalidTrailerFormat(const std::string& trailer, + bool invalid_name_char) { + balsa_frame_.set_is_request(false); + balsa_frame_.set_balsa_trailer(&trailer_); + + std::string headers = + "HTTP/1.0 200 ok\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + + InSequence s; + + EXPECT_CALL(visitor_mock_, OnResponseFirstLineInput); + EXPECT_CALL(visitor_mock_, OnHeaderInput); + EXPECT_CALL(visitor_mock_, OnHeader).Times(AnyNumber()); + EXPECT_CALL(visitor_mock_, ProcessHeaders); + EXPECT_CALL(visitor_mock_, HeaderDone); + EXPECT_CALL(visitor_mock_, OnChunkLength(3)); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput); + EXPECT_CALL(visitor_mock_, OnRawBodyInput); + EXPECT_CALL(visitor_mock_, OnBodyChunkInput); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, OnChunkExtensionInput); + EXPECT_CALL(visitor_mock_, OnRawBodyInput); + EXPECT_CALL(visitor_mock_, OnRawBodyInput); + EXPECT_CALL(visitor_mock_, OnHeader).Times(AnyNumber()); + const auto expected_error = + invalid_name_char ? BalsaFrameEnums::INVALID_TRAILER_NAME_CHARACTER + : BalsaFrameEnums::INVALID_TRAILER_FORMAT; + EXPECT_CALL(visitor_mock_, HandleError(expected_error)).Times(1); + + EXPECT_CALL(visitor_mock_, ProcessTrailers(_)).Times(0); + EXPECT_CALL(visitor_mock_, MessageDone()).Times(0); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(expected_error, balsa_frame_.ErrorCode()); + + Mock::VerifyAndClearExpectations(&visitor_mock_); + } + + BalsaHeaders headers_; + BalsaHeaders trailer_; + BalsaFrame balsa_frame_; + StrictMock visitor_mock_; +}; + +TEST_F(BalsaFrameParsingTest, AppropriateActionTakenWhenHeaderColonsAreFunny) { + // Believe it or not, the following message is not structured willy-nilly. + // It is structured so that both codepaths in both SSE2 and non SSE2 paths + // for finding colons are exersized. + std::string message = + "GET / HTTP/1.1\r\n" + "a\r\n" + "b\r\n" + "c\r\n" + "d\r\n" + "e\r\n" + "f\r\n" + "g\r\n" + "h\r\n" + "i:\r\n" + "j\r\n" + "k\r\n" + "l\r\n" + "m\r\n" + "n\r\n" + "o\r\n" + "p\r\n" + "q\r\n" + "r\r\n" + "s\r\n" + "t\r\n" + "u\r\n" + "v\r\n" + "w\r\n" + "x\r\n" + "y\r\n" + "z\r\n" + "A\r\n" + "B\r\n" + ": val\r\n" + "\r\n"; + + EXPECT_CALL(visitor_mock_, OnRequestFirstLineInput("GET / HTTP/1.1", "GET", + "/", "HTTP/1.1")); + EXPECT_CALL(visitor_mock_, OnHeaderInput(_)); + EXPECT_CALL(visitor_mock_, OnHeader("i", "")); + EXPECT_CALL(visitor_mock_, OnHeader("", "val")); + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::HEADER_MISSING_COLON)) + .Times(27); + EXPECT_CALL(visitor_mock_, + HandleError(BalsaFrameEnums::INVALID_HEADER_FORMAT)); + + ASSERT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + + EXPECT_TRUE(balsa_frame_.Error()); +} + +TEST_F(BalsaFrameParsingTest, ErrorWhenHeaderKeyIsEmpty) { + std::string firstKeyIsEmpty = + "GET / HTTP/1.1\r\n" + ": \r\n" + "a:b\r\n" + "c:d\r\n" + "\r\n"; + TestEmptyHeaderKeyHelper(firstKeyIsEmpty); + + balsa_frame_.Reset(); + + std::string laterKeyIsEmpty = + "GET / HTTP/1.1\r\n" + "a:b\r\n" + ": \r\n" + "c:d\r\n" + "\r\n"; + TestEmptyHeaderKeyHelper(laterKeyIsEmpty); +} + +TEST_F(BalsaFrameParsingTest, InvalidTrailerFormat) { + std::string trailer = + ":monkeys\n" + "\r\n"; + TestInvalidTrailerFormat(trailer, false); + + balsa_frame_.Reset(); + + std::string trailer2 = + " \r\n" + "test: test\r\n" + "\r\n"; + TestInvalidTrailerFormat(trailer2, true); + + balsa_frame_.Reset(); + + std::string trailer3 = + "a: b\r\n" + ": test\r\n" + "\r\n"; + TestInvalidTrailerFormat(trailer3, false); +} + +TEST_F(HTTPBalsaFrameTest, + EnsureHeaderFramingFoundWithVariousCombinationsOfRN_RN) { + const std::string message = + "GET / HTTP/1.1\r\n" + "content-length: 0\r\n" + "a\r\n" + "b\r\n" + "c\r\n" + "d\r\n" + "e\r\n" + "f\r\n" + "g\r\n" + "h\r\n" + "i\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + EnsureHeaderFramingFoundWithVariousCombinationsOfRN_N) { + const std::string message = + "GET / HTTP/1.1\n" + "content-length: 0\n" + "a\n" + "b\n" + "c\n" + "d\n" + "e\n" + "f\n" + "g\n" + "h\n" + "i\n" + "\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + EnsureHeaderFramingFoundWithVariousCombinationsOfRN_RN_N) { + const std::string message = + "GET / HTTP/1.1\n" + "content-length: 0\r\n" + "a\r\n" + "b\n" + "c\r\n" + "d\n" + "e\r\n" + "f\n" + "g\r\n" + "h\n" + "i\r\n" + "\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, + EnsureHeaderFramingFoundWithVariousCombinationsOfRN_N_RN) { + const std::string message = + "GET / HTTP/1.1\n" + "content-length: 0\r\n" + "a\n" + "b\r\n" + "c\n" + "d\r\n" + "e\n" + "f\r\n" + "g\n" + "h\r\n" + "i\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, ReadUntilCloseStateEnteredAsExpectedAndNotExited) { + std::string message = + "HTTP/1.1 200 OK\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()); + EXPECT_EQ(BalsaFrameEnums::READING_UNTIL_CLOSE, balsa_frame_.ParseState()); + + std::string gobldygook = "-198324-9182-43981-23498-98342-jasldfn-1294hj"; + for (int i = 0; i < 1000; ++i) { + EXPECT_EQ(gobldygook.size(), + balsa_frame_.ProcessInput(gobldygook.data(), gobldygook.size())); + EXPECT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()); + EXPECT_EQ(BalsaFrameEnums::READING_UNTIL_CLOSE, balsa_frame_.ParseState()); + } +} + +TEST_F(HTTPBalsaFrameTest, + BytesSafeToSpliceAndBytesSplicedWorksWithContentLength) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "content-length: 1000\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + size_t bytes_safe_to_splice = 1000; + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ(header.size(), + balsa_frame_.ProcessInput(header.data(), header.size())); + EXPECT_EQ(bytes_safe_to_splice, balsa_frame_.BytesSafeToSplice()); + while (bytes_safe_to_splice > 0) { + balsa_frame_.BytesSpliced(1); + bytes_safe_to_splice -= 1; + ASSERT_FALSE(balsa_frame_.Error()) + << BalsaFrameEnums::ParseStateToString(balsa_frame_.ParseState()) << " " + << BalsaFrameEnums::ErrorCodeToString(balsa_frame_.ErrorCode()) + << " with bytes_safe_to_splice: " << bytes_safe_to_splice + << " and BytesSafeToSplice(): " << balsa_frame_.BytesSafeToSplice(); + } + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, BytesSplicedFlagsErrorsWhenNotInProperState) { + balsa_frame_.set_is_request(false); + balsa_frame_.BytesSpliced(1); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::CALLED_BYTES_SPLICED_WHEN_UNSAFE_TO_DO_SO, + balsa_frame_.ErrorCode()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, + BytesSplicedFlagsErrorsWhenTooMuchSplicedForContentLen) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "content-length: 1000\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ(header.size(), + balsa_frame_.ProcessInput(header.data(), header.size())); + EXPECT_EQ(1000u, balsa_frame_.BytesSafeToSplice()); + balsa_frame_.BytesSpliced(1001); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ( + BalsaFrameEnums::CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT, + balsa_frame_.ErrorCode()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, BytesSplicedWorksAsExpectedForReadUntilClose) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ(header.size(), + balsa_frame_.ProcessInput(header.data(), header.size())); + EXPECT_EQ(BalsaFrameEnums::READING_UNTIL_CLOSE, balsa_frame_.ParseState()); + EXPECT_EQ(std::numeric_limits::max(), + balsa_frame_.BytesSafeToSplice()); + for (int i = 0; i < 1000; ++i) { + EXPECT_EQ(std::numeric_limits::max(), + balsa_frame_.BytesSafeToSplice()); + balsa_frame_.BytesSpliced(12312312); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + } + EXPECT_EQ(std::numeric_limits::max(), + balsa_frame_.BytesSafeToSplice()); +} + +TEST_F(HTTPBalsaFrameTest, + BytesSplicedFlagsErrorsWhenTooMuchSplicedForChunked) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + std::string body_fragment = "a\r\n"; + balsa_frame_.set_is_request(false); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ(header.size(), + balsa_frame_.ProcessInput(header.data(), header.size())); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ( + body_fragment.size(), + balsa_frame_.ProcessInput(body_fragment.data(), body_fragment.size())); + EXPECT_EQ(10u, balsa_frame_.BytesSafeToSplice()); + balsa_frame_.BytesSpliced(11); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ( + BalsaFrameEnums::CALLED_BYTES_SPLICED_AND_EXCEEDED_SAFE_SPLICE_AMOUNT, + balsa_frame_.ErrorCode()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, BytesSafeToSpliceAndBytesSplicedWorksWithChunks) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ(header.size(), + balsa_frame_.ProcessInput(header.data(), header.size())); + + { + std::string body_fragment = "3e8\r\n"; + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + size_t bytes_safe_to_splice = 1000; + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ( + body_fragment.size(), + balsa_frame_.ProcessInput(body_fragment.data(), body_fragment.size())); + EXPECT_EQ(bytes_safe_to_splice, balsa_frame_.BytesSafeToSplice()); + while (bytes_safe_to_splice > 0) { + balsa_frame_.BytesSpliced(1); + bytes_safe_to_splice -= 1; + ASSERT_FALSE(balsa_frame_.Error()); + } + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_FALSE(balsa_frame_.Error()); + } + { + std::string body_fragment = "\r\n7d0\r\n"; + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + size_t bytes_safe_to_splice = 2000; + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ( + body_fragment.size(), + balsa_frame_.ProcessInput(body_fragment.data(), body_fragment.size())); + EXPECT_EQ(bytes_safe_to_splice, balsa_frame_.BytesSafeToSplice()); + while (bytes_safe_to_splice > 0) { + balsa_frame_.BytesSpliced(1); + bytes_safe_to_splice -= 1; + ASSERT_FALSE(balsa_frame_.Error()); + } + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_FALSE(balsa_frame_.Error()); + } + { + std::string body_fragment = "\r\n1\r\n"; + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + size_t bytes_safe_to_splice = 1; + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ( + body_fragment.size(), + balsa_frame_.ProcessInput(body_fragment.data(), body_fragment.size())); + EXPECT_EQ(bytes_safe_to_splice, balsa_frame_.BytesSafeToSplice()); + while (bytes_safe_to_splice > 0) { + balsa_frame_.BytesSpliced(1); + bytes_safe_to_splice -= 1; + ASSERT_FALSE(balsa_frame_.Error()); + } + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_FALSE(balsa_frame_.Error()); + } + { + std::string body_fragment = "\r\n0\r\n\r\n"; + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_EQ( + body_fragment.size(), + balsa_frame_.ProcessInput(body_fragment.data(), body_fragment.size())); + EXPECT_EQ(0u, balsa_frame_.BytesSafeToSplice()); + EXPECT_FALSE(balsa_frame_.Error()); + } + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, TwoDifferentContentHeadersIsAnError) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "content-length: 12\r\n" + "content-length: 14\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + balsa_frame_.ProcessInput(header.data(), header.size()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::MULTIPLE_CONTENT_LENGTH_KEYS, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, TwoSameContentHeadersIsNotAnError) { + std::string header = + "POST / HTTP/1.1\r\n" + "content-length: 1\r\n" + "content-length: 1\r\n" + "\r\n" + "1"; + balsa_frame_.ProcessInput(header.data(), header.size()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + EXPECT_FALSE(balsa_frame_.Error()); + balsa_frame_.ProcessInput(header.data(), header.size()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, TwoTransferEncodingHeadersIsAnError) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n" + "transfer-encoding: identity\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + balsa_frame_.ProcessInput(header.data(), header.size()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::MULTIPLE_TRANSFER_ENCODING_KEYS, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, TwoTransferEncodingTokensIsAnError) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked, identity\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + balsa_frame_.ProcessInput(header.data(), header.size()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::UNKNOWN_TRANSFER_ENCODING, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, UnknownTransferEncodingTokenIsAnError) { + std::string header = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked-identity\r\n" + "\r\n"; + balsa_frame_.set_is_request(false); + balsa_frame_.ProcessInput(header.data(), header.size()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::UNKNOWN_TRANSFER_ENCODING, + balsa_frame_.ErrorCode()); +} + +class DetachOnDoneFramer : public NoOpBalsaVisitor { + public: + DetachOnDoneFramer() { + framer_.set_balsa_headers(&headers_); + framer_.set_balsa_visitor(this); + } + + void MessageDone() override { framer_.set_balsa_headers(nullptr); } + + BalsaFrame* framer() { return &framer_; } + + protected: + BalsaFrame framer_; + BalsaHeaders headers_; +}; + +TEST(HTTPBalsaFrame, TestDetachOnDone) { + DetachOnDoneFramer framer; + const char* message = "GET HTTP/1.1\r\n\r\n"; + // Frame the whole message. The framer will call MessageDone which will set + // the headers to nullptr. + framer.framer()->ProcessInput(message, strlen(message)); + EXPECT_TRUE(framer.framer()->MessageFullyRead()); + EXPECT_FALSE(framer.framer()->Error()); +} + +// We simply extend DetachOnDoneFramer so that we do not have +// to provide trivial implementation for various functions. +class ModifyMaxHeaderLengthFramerInFirstLine : public DetachOnDoneFramer { + public: + void MessageDone() override {} + // This sets to max_header_length to a low number and + // this would cause us to reject the query. Even though + // our original headers length was acceptable. + void OnRequestFirstLineInput(absl::string_view /*line_input*/, + absl::string_view /*method_input*/, + absl::string_view /*request_uri*/, + absl::string_view /*version_input*/ + ) override { + framer_.set_max_header_length(1); + } +}; + +// In this case we have already processed the headers and called on +// the visitor HeadersDone and hence its too late to reduce the +// max_header_length here. +class ModifyMaxHeaderLengthFramerInHeaderDone : public DetachOnDoneFramer { + public: + void MessageDone() override {} + void HeaderDone() override { framer_.set_max_header_length(1); } +}; + +TEST(HTTPBalsaFrame, ChangeMaxHeadersLengthOnFirstLine) { + std::string message = + "PUT /foo HTTP/1.1\r\n" + "Content-Length: 2\r\n" + "header: xxxxxxxxx\r\n\r\n" + "B"; // body begin + + ModifyMaxHeaderLengthFramerInFirstLine balsa_frame; + balsa_frame.framer()->set_is_request(true); + balsa_frame.framer()->set_max_header_length(message.size() - 1); + + balsa_frame.framer()->ProcessInput(message.data(), message.size()); + EXPECT_EQ(BalsaFrameEnums::HEADERS_TOO_LONG, + balsa_frame.framer()->ErrorCode()); +} + +TEST(HTTPBalsaFrame, ChangeMaxHeadersLengthOnHeaderDone) { + std::string message = + "PUT /foo HTTP/1.1\r\n" + "Content-Length: 2\r\n" + "header: xxxxxxxxx\r\n\r\n" + "B"; // body begin + + ModifyMaxHeaderLengthFramerInHeaderDone balsa_frame; + balsa_frame.framer()->set_is_request(true); + balsa_frame.framer()->set_max_header_length(message.size() - 1); + + balsa_frame.framer()->ProcessInput(message.data(), message.size()); + EXPECT_EQ(0, balsa_frame.framer()->ErrorCode()); +} + +// This is a simple test to ensure the simple case that we accept +// a query which has headers size same as the max_header_length. +// (i.e., there is no off by one error). +TEST(HTTPBalsaFrame, HeadersSizeSameAsMaxLengthIsAccepted) { + std::string message = + "GET /foo HTTP/1.1\r\n" + "header: xxxxxxxxx\r\n\r\n"; + + ModifyMaxHeaderLengthFramerInHeaderDone balsa_frame; + balsa_frame.framer()->set_is_request(true); + balsa_frame.framer()->set_max_header_length(message.size()); + balsa_frame.framer()->ProcessInput(message.data(), message.size()); + EXPECT_EQ(0, balsa_frame.framer()->ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, KeyHasSpaces) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key has spaces: lock\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, SpaceBeforeColon) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key : lock\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, SpaceBeforeColonNotAfter) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key :lock\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, KeyHasTabs) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key\thas\ttabs: lock\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, TabBeforeColon) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key\t: lock\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, KeyHasContinuation) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key\n includes continuation: but not value\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, KeyHasMultipleContinuations) { + const std::string message = + "GET / HTTP/1.1\r\n" + "key\n includes\r\n multiple\n continuations: but not value\r\n" + "\r\n"; + EXPECT_EQ(message.size(), + balsa_frame_.ProcessInput(message.data(), message.size())); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER, + balsa_frame_.ErrorCode()); +} + +// Missing colon is a warning, not an error. +TEST_F(HTTPBalsaFrameTest, TrailerMissingColon) { + std::string headers = + "HTTP/1.0 302 Redirect\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass_monkeys\n" + "\r\n"; + + balsa_frame_.set_is_request(false); + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::TRAILER_MISSING_COLON)); + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::TRAILER_MISSING_COLON, balsa_frame_.ErrorCode()); + EXPECT_FALSE(trailer_.HasHeader("crass")); + EXPECT_TRUE(trailer_.HasHeader("crass_monkeys")); + const absl::string_view crass_monkeys = trailer_.GetHeader("crass_monkeys"); + EXPECT_TRUE(crass_monkeys.empty()); +} + +// This tests multiple headers in trailer. We currently do not and have no plan +// to support Trailer field in headers to limit valid field-name in trailer. +// Test that we aren't confused by the non-alphanumeric characters in the +// trailer, especially ':'. +TEST_F(HTTPBalsaFrameTest, MultipleHeadersInTrailer) { + std::string headers = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\n" + "0\n"; + std::map trailer; + trailer["X-Trace"] = + "http://trace.example.com/trace?host=" + "foobar.example.com&start=2012-06-03_15:59:06&rpc_duration=0.243349"; + trailer["Date"] = "Sun, 03 Jun 2012 22:59:06 GMT"; + trailer["Content-Type"] = "text/html"; + trailer["X-Backends"] = "127.0.0.1_0,foo.example.com:39359"; + trailer["X-Request-Trace"] = + "foo.example.com:39359,127.0.0.1_1," + "foo.example.com:39359,127.0.0.1_0," + "foo.example.com:39359"; + trailer["X-Service-Trace"] = "default"; + trailer["X-Service"] = "default"; + + std::map::const_iterator iter; + std::string trailer_data; + TestSeed seed; + seed.Initialize(GetQuicheCommandLineFlag(FLAGS_randseed)); + RandomEngine rng; + rng.seed(seed.GetSeed()); + FakeHeaders fake_headers_in_trailer; + for (iter = trailer.begin(); iter != trailer.end(); ++iter) { + trailer_data += iter->first; + trailer_data += ":"; + std::stringstream leading_whitespace_for_value; + AppendRandomWhitespace(rng, &leading_whitespace_for_value); + trailer_data += leading_whitespace_for_value.str(); + trailer_data += iter->second; + std::stringstream trailing_whitespace_for_value; + AppendRandomWhitespace(rng, &trailing_whitespace_for_value); + trailer_data += trailing_whitespace_for_value.str(); + trailer_data += random_line_term(rng); + fake_headers_in_trailer.AddKeyValue(iter->first, iter->second); + } + trailer_data += random_line_term(rng); + + FakeHeaders fake_headers; + fake_headers.AddKeyValue("transfer-encoding", "chunked"); + + { + InSequence s1; + EXPECT_CALL(visitor_mock_, OnResponseFirstLineInput( + "HTTP/1.1 200 OK", "HTTP/1.1", "200", "OK")); + EXPECT_CALL(visitor_mock_, ProcessHeaders(fake_headers)); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, OnChunkLength(3)); + EXPECT_CALL(visitor_mock_, OnChunkLength(0)); + EXPECT_CALL(visitor_mock_, ProcessTrailers(fake_headers_in_trailer)); + EXPECT_CALL(visitor_mock_, OnTrailerInput(trailer_data)); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + EXPECT_CALL(visitor_mock_, OnHeaderInput(headers)); + std::string body_input; + EXPECT_CALL(visitor_mock_, OnRawBodyInput(_)) + .WillRepeatedly([&body_input](absl::string_view input) { + absl::StrAppend(&body_input, input); + }); + EXPECT_CALL(visitor_mock_, OnBodyChunkInput("123")); + + balsa_frame_.set_is_request(false); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(trailer_data.size(), balsa_frame_.ProcessInput( + trailer_data.data(), trailer_data.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + EXPECT_EQ(chunks, body_input); + + for (iter = trailer.begin(); iter != trailer.end(); ++iter) { + const absl::string_view value = trailer_.GetHeader(iter->first); + EXPECT_EQ(iter->second, value); + } +} + +// Test if trailer is not set (the common case), everything will be fine. +TEST_F(HTTPBalsaFrameTest, NothingBadHappensWithNULLTrailer) { + std::string headers = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "crass: monkeys\r\n" + "funky: monkeys\r\n" + "\n"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_balsa_visitor(nullptr); + balsa_frame_.set_balsa_trailer(nullptr); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + ASSERT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +// Test Reset() correctly resets trailer related states. +TEST_F(HTTPBalsaFrameTest, FrameAndResetAndFrameAgain) { + std::string headers = + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "k: v\n" + "\n"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_balsa_visitor(nullptr); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + ASSERT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + absl::string_view value = trailer_.GetHeader("k"); + EXPECT_EQ("v", value); + + balsa_frame_.Reset(); + + headers = + "HTTP/1.1 404 Error\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + chunks = + "4\r\n" + "1234\r\n" + "0\r\n"; + trailer = + "nk: nv\n" + "\n"; + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + ASSERT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + value = trailer_.GetHeader("k"); + EXPECT_TRUE(value.empty()); + value = trailer_.GetHeader("nk"); + EXPECT_EQ("nv", value); +} + +TEST_F(HTTPBalsaFrameTest, TrackInvalidChars) { + EXPECT_FALSE(balsa_frame_.track_invalid_chars()); +} + +// valid chars are 9 (tab), 10 (LF), 13(CR), and 32-255 +TEST_F(HTTPBalsaFrameTest, InvalidCharsInHeaderValueWarning) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kWarning); + // nulls are double escaped since otherwise this initialized wrong + const std::string kEscapedInvalid1 = + "GET /foo HTTP/1.1\r\n" + "Bogus-Head: val\\x00\r\n" + "More-Invalid: \\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0B\x0C\x0E\x0F\r\n" + "And-More: \x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1A\x1B\x1C\x1D" + "\x1E\x1F\r\n\r\n"; + std::string message; + // now we convert to real embedded nulls + absl::CUnescape(kEscapedInvalid1, &message); + + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +// Header names reject invalid chars even at the warning level. +TEST_F(HTTPBalsaFrameTest, InvalidCharsInHeaderKeyError) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kWarning); + // nulls are double escaped since otherwise this initialized wrong + const std::string kEscapedInvalid1 = + "GET /foo HTTP/1.1\r\n" + "Bogus\\x00-Head: val\r\n\r\n"; + std::string message; + // now we convert to real embedded nulls + absl::CUnescape(kEscapedInvalid1, &message); + + EXPECT_CALL(visitor_mock_, + HandleError(BalsaFrameEnums::INVALID_HEADER_NAME_CHARACTER)); + + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, InvalidCharsInHeaderError) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kError); + const std::string kEscapedInvalid = + "GET /foo HTTP/1.1\r\n" + "Smuggle-Me: \\x00GET /bar HTTP/1.1\r\n" + "Another-Header: value\r\n\r\n"; + std::string message; + absl::CUnescape(kEscapedInvalid, &message); + + EXPECT_CALL(visitor_mock_, + HandleError(BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); +} + +class HTTPBalsaFrameTestOneChar : public HTTPBalsaFrameTest, + public testing::WithParamInterface { + public: + char GetCharUnderTest() { return GetParam(); } +}; + +TEST_P(HTTPBalsaFrameTestOneChar, InvalidCharsWarningSet) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kWarning); + const std::string kRequest = + "GET /foo HTTP/1.1\r\n" + "Bogus-Char-Goes-Here: "; + const std::string kEnding = "\r\n\r\n"; + std::string message = kRequest; + const char c = GetCharUnderTest(); + message.append(1, c); + message.append(kEnding); + if (c == 9 || c == 10 || c == 13) { + // valid char + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER)) + .Times(0); + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_THAT(balsa_frame_.get_invalid_chars(), IsEmpty()); + } else { + // invalid char + absl::flat_hash_map expected_count = {{c, 1}}; + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_EQ(balsa_frame_.get_invalid_chars(), expected_count); + } + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +INSTANTIATE_TEST_SUITE_P(TestInvalidCharSet, HTTPBalsaFrameTestOneChar, + Range(0, 32)); + +TEST_F(HTTPBalsaFrameTest, InvalidCharEndOfLine) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kWarning); + const std::string kInvalid1 = + "GET /foo HTTP/1.1\r\n" + "Header-Key: headervalue\\x00\r\n" + "Legit-Header: legitvalue\r\n\r\n"; + std::string message; + absl::CUnescape(kInvalid1, &message); + + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, InvalidCharInFirstLine) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kWarning); + const std::string kInvalid1 = + "GET /foo \\x00HTTP/1.1\r\n" + "Legit-Header: legitvalue\r\n\r\n"; + std::string message; + absl::CUnescape(kInvalid1, &message); + + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + balsa_frame_.ProcessInput(message.data(), message.size()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, InvalidCharsAreCounted) { + balsa_frame_.set_invalid_chars_level(BalsaFrame::InvalidCharsLevel::kWarning); + const std::string kInvalid1 = + "GET /foo \\x00\\x00\\x00HTTP/1.1\r\n" + "Bogus-Header: \\x00\\x04\\x04value\r\n\r\n"; + std::string message; + absl::CUnescape(kInvalid1, &message); + + EXPECT_CALL(visitor_mock_, + HandleWarning(BalsaFrameEnums::INVALID_HEADER_CHARACTER)); + balsa_frame_.ProcessInput(message.data(), message.size()); + absl::flat_hash_map expected_count = {{'\0', 4}, {'\4', 2}}; + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_EQ(balsa_frame_.get_invalid_chars(), expected_count); + + absl::flat_hash_map empty_count; + balsa_frame_.Reset(); + EXPECT_EQ(balsa_frame_.get_invalid_chars(), empty_count); +} + +// Test gibberish in headers and trailer. GFE does not crash but garbage in +// garbage out. +TEST_F(HTTPBalsaFrameTest, GibberishInHeadersAndTrailer) { + // Use static_cast for values exceeding SCHAR_MAX to make sure this + // compiles on platforms where char is signed. + const char kGibberish1[] = {static_cast(138), static_cast(175), + static_cast(233), 0}; + const char kGibberish2[] = {'?', + '?', + static_cast(128), + static_cast(255), + static_cast(129), + static_cast(254), + 0}; + const char kGibberish3[] = "foo: bar : eeep : baz"; + + std::string gibberish_headers = + absl::StrCat(kGibberish1, ":", kGibberish2, "\r\n", kGibberish3, "\r\n"); + + std::string headers = absl::StrCat( + "HTTP/1.1 200 OK\r\n" + "transfer-encoding: chunked\r\n", + gibberish_headers, "\r\n"); + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + + std::string trailer = absl::StrCat("k: v\n", gibberish_headers, "\n"); + + balsa_frame_.set_is_request(false); + balsa_frame_.set_balsa_visitor(nullptr); + + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + ASSERT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); + + // Transfer-encoding can be multi-valued, so GetHeader does not work. + EXPECT_TRUE(headers_.transfer_encoding_is_chunked()); + absl::string_view field_value = headers_.GetHeader(kGibberish1); + EXPECT_EQ(kGibberish2, field_value); + field_value = headers_.GetHeader("foo"); + EXPECT_EQ("bar : eeep : baz", field_value); + + field_value = trailer_.GetHeader("k"); + EXPECT_EQ("v", field_value); + field_value = trailer_.GetHeader(kGibberish1); + EXPECT_EQ(kGibberish2, field_value); + field_value = trailer_.GetHeader("foo"); + EXPECT_EQ("bar : eeep : baz", field_value); +} + +// Note we reuse the header length limit because trailer is just multiple +// headers. +TEST_F(HTTPBalsaFrameTest, TrailerTooLong) { + std::string headers = + "HTTP/1.0 200 ok\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "very : long trailer\n" + "should:cause\r\n" + "trailer :too long error\n" + "\r\n"; + + balsa_frame_.set_is_request(false); + ASSERT_LT(headers.size(), trailer.size()); + balsa_frame_.set_max_header_length(headers.size()); + + EXPECT_CALL(visitor_mock_, HandleError(BalsaFrameEnums::TRAILER_TOO_LONG)); + EXPECT_CALL(visitor_mock_, ProcessTrailers(_)).Times(0); + EXPECT_CALL(visitor_mock_, MessageDone()).Times(0); + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(balsa_frame_.max_header_length(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); + EXPECT_TRUE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::TRAILER_TOO_LONG, balsa_frame_.ErrorCode()); +} + +// If the `trailer_` object in the framer is set to `nullptr`, +// ProcessTrailers() will not be called. +TEST_F(HTTPBalsaFrameTest, + NoProcessTrailersCallWhenFramerHasNullTrailerObject) { + std::string headers = + "HTTP/1.0 200 ok\r\n" + "transfer-encoding: chunked\r\n" + "\r\n"; + + std::string chunks = + "3\r\n" + "123\r\n" + "0\r\n"; + std::string trailer = + "trailer_key : trailer_value\n" + "\r\n"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_balsa_trailer(nullptr); + + EXPECT_CALL(visitor_mock_, ProcessTrailers(_)).Times(0); + ASSERT_EQ(headers.size(), + balsa_frame_.ProcessInput(headers.data(), headers.size())); + ASSERT_EQ(chunks.size(), + balsa_frame_.ProcessInput(chunks.data(), chunks.size())); + EXPECT_EQ(trailer.size(), + balsa_frame_.ProcessInput(trailer.data(), trailer.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, Parse100ContinueNoContinueHeadersNoCallback) { + std::string continue_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + + // Do not set continue headers (or use interim callbacks). Then the parsed + // continue headers are treated as final headers. + balsa_frame_.set_is_request(false); + balsa_frame_.set_use_interim_headers_callback(false); + + InSequence s; + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + + ASSERT_EQ(balsa_frame_.ProcessInput(continue_headers.data(), + continue_headers.size()), + continue_headers.size()) + << balsa_frame_.ErrorCode(); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(headers_.parsed_response_code(), 100); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); +} + +TEST_F(HTTPBalsaFrameTest, Parse100Continue) { + std::string continue_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + + // The parsed continue headers are delivered as interim headers. + balsa_frame_.set_is_request(false); + balsa_frame_.set_use_interim_headers_callback(true); + + InSequence s; + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 100))); + EXPECT_CALL(visitor_mock_, HeaderDone()).Times(0); + EXPECT_CALL(visitor_mock_, MessageDone()).Times(0); + + ASSERT_EQ(balsa_frame_.ProcessInput(continue_headers.data(), + continue_headers.size()), + continue_headers.size()) + << balsa_frame_.ErrorCode(); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(headers_.parsed_response_code(), 0u); + EXPECT_FALSE(balsa_frame_.MessageFullyRead()); +} + +// Handle two sets of headers when set up properly and the first is 100 +// Continue. +TEST_F(HTTPBalsaFrameTest, Support100ContinueNoCallback) { + std::string initial_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + std::string real_headers = + "HTTP/1.1 200 OK\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + BalsaHeaders continue_headers; + balsa_frame_.set_continue_headers(&continue_headers); + balsa_frame_.set_use_interim_headers_callback(false); + + ASSERT_EQ(initial_headers.size(), + balsa_frame_.ProcessInput(initial_headers.data(), + initial_headers.size())); + ASSERT_EQ(real_headers.size(), + balsa_frame_.ProcessInput(real_headers.data(), real_headers.size())) + << balsa_frame_.ErrorCode(); + ASSERT_EQ(body.size(), balsa_frame_.ProcessInput(body.data(), body.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +// Handle two sets of headers when set up properly and the first is 100 +// Continue. +TEST_F(HTTPBalsaFrameTest, Support100Continue) { + std::string initial_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + std::string real_headers = + "HTTP/1.1 200 OK\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_use_interim_headers_callback(true); + + InSequence s; + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 100))); + ASSERT_EQ( + balsa_frame_.ProcessInput(initial_headers.data(), initial_headers.size()), + initial_headers.size()); + ASSERT_FALSE(balsa_frame_.Error()); + + EXPECT_CALL(visitor_mock_, HeaderDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(real_headers.data(), real_headers.size()), + real_headers.size()) + << balsa_frame_.ErrorCode(); + EXPECT_EQ(headers_.parsed_response_code(), 200); + + EXPECT_CALL(visitor_mock_, MessageDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(body.data(), body.size()), body.size()); + + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(balsa_frame_.ErrorCode(), BalsaFrameEnums::BALSA_NO_ERROR); +} + +// If both the interim headers callback and continue headers are set, only the +// former should be used. +TEST_F(HTTPBalsaFrameTest, InterimHeadersCallbackTakesPrecedence) { + std::string initial_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + std::string real_headers = + "HTTP/1.1 200 OK\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + BalsaHeaders continue_headers; + balsa_frame_.set_continue_headers(&continue_headers); + balsa_frame_.set_use_interim_headers_callback(true); + + InSequence s; + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 100))); + EXPECT_CALL(visitor_mock_, ContinueHeaderDone).Times(0); + ASSERT_EQ( + balsa_frame_.ProcessInput(initial_headers.data(), initial_headers.size()), + initial_headers.size()); + EXPECT_EQ(continue_headers.parsed_response_code(), 0u); + ASSERT_FALSE(balsa_frame_.Error()); + + EXPECT_CALL(visitor_mock_, HeaderDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(real_headers.data(), real_headers.size()), + real_headers.size()) + << balsa_frame_.ErrorCode(); + EXPECT_EQ(headers_.parsed_response_code(), 200); + + EXPECT_CALL(visitor_mock_, MessageDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(body.data(), body.size()), body.size()); + + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(balsa_frame_.ErrorCode(), BalsaFrameEnums::BALSA_NO_ERROR); +} + +// Handle two sets of headers when set up properly and the first is 100 +// Continue and it meets the conditions for b/62408297. +TEST_F(HTTPBalsaFrameTest, Support100Continue401UnauthorizedNoCallback) { + std::string initial_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + std::string real_headers = + "HTTP/1.1 401 Unauthorized\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + BalsaHeaders continue_headers; + balsa_frame_.set_continue_headers(&continue_headers); + balsa_frame_.set_use_interim_headers_callback(false); + + ASSERT_EQ(initial_headers.size(), + balsa_frame_.ProcessInput(initial_headers.data(), + initial_headers.size())); + ASSERT_EQ(real_headers.size(), + balsa_frame_.ProcessInput(real_headers.data(), real_headers.size())) + << balsa_frame_.ErrorCode(); + ASSERT_EQ(body.size(), balsa_frame_.ProcessInput(body.data(), body.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +// Handle two sets of headers when set up properly and the first is 100 +// Continue and it meets the conditions for b/62408297. +TEST_F(HTTPBalsaFrameTest, Support100Continue401Unauthorized) { + std::string initial_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n"; + std::string real_headers = + "HTTP/1.1 401 Unauthorized\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_use_interim_headers_callback(true); + + InSequence s; + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 100))); + ASSERT_EQ( + balsa_frame_.ProcessInput(initial_headers.data(), initial_headers.size()), + initial_headers.size()); + ASSERT_FALSE(balsa_frame_.Error()); + + EXPECT_CALL(visitor_mock_, HeaderDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(real_headers.data(), real_headers.size()), + real_headers.size()) + << balsa_frame_.ErrorCode(); + EXPECT_EQ(headers_.parsed_response_code(), 401); + + EXPECT_CALL(visitor_mock_, MessageDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(body.data(), body.size()), body.size()); + + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(balsa_frame_.ErrorCode(), BalsaFrameEnums::BALSA_NO_ERROR); +} + +TEST_F(HTTPBalsaFrameTest, Support100ContinueRunTogetherNoCallback) { + std::string both_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n" + "HTTP/1.1 200 OK\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + { + InSequence s; + EXPECT_CALL(visitor_mock_, ContinueHeaderDone()); + EXPECT_CALL(visitor_mock_, HeaderDone()); + EXPECT_CALL(visitor_mock_, MessageDone()); + } + + balsa_frame_.set_is_request(false); + BalsaHeaders continue_headers; + balsa_frame_.set_continue_headers(&continue_headers); + balsa_frame_.set_use_interim_headers_callback(false); + + ASSERT_EQ(both_headers.size(), + balsa_frame_.ProcessInput(both_headers.data(), both_headers.size())) + << balsa_frame_.ErrorCode(); + ASSERT_EQ(body.size(), balsa_frame_.ProcessInput(body.data(), body.size())); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::BALSA_NO_ERROR, balsa_frame_.ErrorCode()); +} + +TEST_F(HTTPBalsaFrameTest, Support100ContinueRunTogether) { + std::string both_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n" + "HTTP/1.1 200 OK\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_use_interim_headers_callback(true); + + InSequence s; + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 100))); + EXPECT_CALL(visitor_mock_, HeaderDone()); + + ASSERT_EQ(balsa_frame_.ProcessInput(both_headers.data(), both_headers.size()), + both_headers.size()) + << balsa_frame_.ErrorCode(); + ASSERT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(headers_.parsed_response_code(), 200); + + EXPECT_CALL(visitor_mock_, MessageDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(body.data(), body.size()), body.size()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(balsa_frame_.ErrorCode(), BalsaFrameEnums::BALSA_NO_ERROR); +} + +TEST_F(HTTPBalsaFrameTest, MultipleInterimHeaders) { + std::string all_headers = + "HTTP/1.1 100 Continue\r\n" + "\r\n" + "HTTP/1.1 103 Early Hints\r\n" + "\r\n" + "HTTP/1.1 200 OK\r\n" + "content-length: 3\r\n" + "\r\n"; + std::string body = "foo"; + + balsa_frame_.set_is_request(false); + balsa_frame_.set_use_interim_headers_callback(true); + + InSequence s; + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 100))); + EXPECT_CALL(visitor_mock_, OnInterimHeaders(Property( + &BalsaHeaders::parsed_response_code, 103))); + EXPECT_CALL(visitor_mock_, HeaderDone()); + + ASSERT_EQ(balsa_frame_.ProcessInput(all_headers.data(), all_headers.size()), + all_headers.size()) + << balsa_frame_.ErrorCode(); + ASSERT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(headers_.parsed_response_code(), 200); + + EXPECT_CALL(visitor_mock_, MessageDone()); + ASSERT_EQ(balsa_frame_.ProcessInput(body.data(), body.size()), body.size()); + EXPECT_TRUE(balsa_frame_.MessageFullyRead()); + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(balsa_frame_.ErrorCode(), BalsaFrameEnums::BALSA_NO_ERROR); +} + +TEST_F(HTTPBalsaFrameTest, Http09) { + constexpr absl::string_view request = "GET /\r\n"; + + InSequence s; + StrictMock visitor_mock; + balsa_frame_.set_balsa_visitor(&visitor_mock); + + EXPECT_CALL( + visitor_mock, + HandleWarning( + BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI)); + EXPECT_CALL(visitor_mock, OnRequestFirstLineInput("GET /", "GET", "/", "")); + EXPECT_CALL(visitor_mock, OnHeaderInput(request)); + EXPECT_CALL(visitor_mock, ProcessHeaders(FakeHeaders{})); + EXPECT_CALL(visitor_mock, HeaderDone()); + EXPECT_CALL(visitor_mock, MessageDone()); + + EXPECT_EQ(request.size(), + balsa_frame_.ProcessInput(request.data(), request.size())); + + // HTTP/0.9 request is parsed with a warning. + EXPECT_FALSE(balsa_frame_.Error()); + EXPECT_EQ(BalsaFrameEnums::FAILED_TO_FIND_WS_AFTER_REQUEST_REQUEST_URI, + balsa_frame_.ErrorCode()); +} + +} // namespace + +} // namespace test + +} // namespace quiche diff --git a/quiche/balsa/balsa_headers.cc b/quiche/balsa/balsa_headers.cc new file mode 100644 index 000000000000..80642c3e5336 --- /dev/null +++ b/quiche/balsa/balsa_headers.cc @@ -0,0 +1,1157 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/balsa_headers.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/header_properties.h" +#include "quiche/common/platform/api/quiche_header_policy.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace { + +constexpr absl::string_view kContentLength("Content-Length"); +constexpr absl::string_view kCookie("Cookie"); +constexpr absl::string_view kHost("Host"); +constexpr absl::string_view kTransferEncoding("Transfer-Encoding"); + +// The following list defines list of headers that Envoy considers multivalue. +// Headers on this list are coalesced by EFG in order to provide forward +// compatibility with Envoy behavior. See b/143490671 for details. +// Date, Last-Modified and Location are excluded because they're found on Chrome +// HttpUtil::IsNonCoalescingHeader() list. +#define ALL_ENVOY_HEADERS(HEADER_FUNC) \ + HEADER_FUNC("Accept") \ + HEADER_FUNC("Accept-Encoding") \ + HEADER_FUNC("Access-Control-Request-Headers") \ + HEADER_FUNC("Access-Control-Request-Method") \ + HEADER_FUNC("Access-Control-Allow-Origin") \ + HEADER_FUNC("Access-Control-Allow-Headers") \ + HEADER_FUNC("Access-Control-Allow-Methods") \ + HEADER_FUNC("Access-Control-Allow-Credentials") \ + HEADER_FUNC("Access-Control-Expose-Headers") \ + HEADER_FUNC("Access-Control-Max-Age") \ + HEADER_FUNC("Authorization") \ + HEADER_FUNC("Cache-Control") \ + HEADER_FUNC("X-Client-Trace-Id") \ + HEADER_FUNC("Connection") \ + HEADER_FUNC("Content-Encoding") \ + HEADER_FUNC("Content-Length") \ + HEADER_FUNC("Content-Type") \ + /* HEADER_FUNC("Date") */ \ + HEADER_FUNC("Envoy-Attempt-Count") \ + HEADER_FUNC("Envoy-Degraded") \ + HEADER_FUNC("Envoy-Decorator-Operation") \ + HEADER_FUNC("Envoy-Downstream-Service-Cluster") \ + HEADER_FUNC("Envoy-Downstream-Service-Node") \ + HEADER_FUNC("Envoy-Expected-Request-Timeout-Ms") \ + HEADER_FUNC("Envoy-External-Address") \ + HEADER_FUNC("Envoy-Force-Trace") \ + HEADER_FUNC("Envoy-Hedge-On-Per-Try-Timeout") \ + HEADER_FUNC("Envoy-Immediate-Health-Check-Fail") \ + HEADER_FUNC("Envoy-Internal-Request") \ + HEADER_FUNC("Envoy-Ip-Tags") \ + HEADER_FUNC("Envoy-Max-Retries") \ + HEADER_FUNC("Envoy-Original-Path") \ + HEADER_FUNC("Envoy-Original-Url") \ + HEADER_FUNC("Envoy-Overloaded") \ + HEADER_FUNC("Envoy-Rate-Limited") \ + HEADER_FUNC("Envoy-Retry-On") \ + HEADER_FUNC("Envoy-Retry-Grpc-On") \ + HEADER_FUNC("Envoy-Retriable-StatusCodes") \ + HEADER_FUNC("Envoy-Retriable-HeaderNames") \ + HEADER_FUNC("Envoy-Upstream-AltStatName") \ + HEADER_FUNC("Envoy-Upstream-Canary") \ + HEADER_FUNC("Envoy-Upstream-HealthCheckedCluster") \ + HEADER_FUNC("Envoy-Upstream-RequestPerTryTimeoutMs") \ + HEADER_FUNC("Envoy-Upstream-RequestTimeoutAltResponse") \ + HEADER_FUNC("Envoy-Upstream-RequestTimeoutMs") \ + HEADER_FUNC("Envoy-Upstream-ServiceTime") \ + HEADER_FUNC("Etag") \ + HEADER_FUNC("Expect") \ + HEADER_FUNC("X-Forwarded-Client-Cert") \ + HEADER_FUNC("X-Forwarded-For") \ + HEADER_FUNC("X-Forwarded-Proto") \ + HEADER_FUNC("Grpc-Accept-Encoding") \ + HEADER_FUNC("Grpc-Message") \ + HEADER_FUNC("Grpc-Status") \ + HEADER_FUNC("Grpc-Timeout") \ + HEADER_FUNC("Host") \ + HEADER_FUNC("Keep-Alive") \ + /* HEADER_FUNC("Last-Modified") */ \ + /* HEADER_FUNC("Location") */ \ + HEADER_FUNC("Method") \ + HEADER_FUNC("No-Chunks") \ + HEADER_FUNC("Origin") \ + HEADER_FUNC("X-Ot-Span-Context") \ + HEADER_FUNC("Path") \ + HEADER_FUNC("Protocol") \ + HEADER_FUNC("Proxy-Connection") \ + HEADER_FUNC("Referer") \ + HEADER_FUNC("X-Request-Id") \ + HEADER_FUNC("Scheme") \ + HEADER_FUNC("Server") \ + HEADER_FUNC("Status") \ + HEADER_FUNC("TE") \ + HEADER_FUNC("Transfer-Encoding") \ + HEADER_FUNC("Upgrade") \ + HEADER_FUNC("User-Agent") \ + HEADER_FUNC("Vary") \ + HEADER_FUNC("Via") + +// HEADER_FUNC to insert "name" into the MultivaluedHeadersSet of Envoy headers. +#define MULTIVALUE_ENVOY_HEADER(name) {name}, + +absl::string_view::difference_type FindIgnoreCase(absl::string_view haystack, + absl::string_view needle) { + absl::string_view::difference_type pos = 0; + while (haystack.size() >= needle.size()) { + if (absl::StartsWithIgnoreCase(haystack, needle)) { + return pos; + } + ++pos; + haystack.remove_prefix(1); + } + + return absl::string_view::npos; +} + +absl::string_view::difference_type RemoveLeadingWhitespace( + absl::string_view* text) { + size_t count = 0; + const char* ptr = text->data(); + while (count < text->size() && absl::ascii_isspace(*ptr)) { + count++; + ptr++; + } + text->remove_prefix(count); + return count; +} + +absl::string_view::difference_type RemoveTrailingWhitespace( + absl::string_view* text) { + size_t count = 0; + const char* ptr = text->data() + text->size() - 1; + while (count < text->size() && absl::ascii_isspace(*ptr)) { + ++count; + --ptr; + } + text->remove_suffix(count); + return count; +} + +absl::string_view::difference_type RemoveWhitespaceContext( + absl::string_view* text) { + return RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text); +} + +} // namespace + +namespace quiche { + +const size_t BalsaBuffer::kDefaultBlocksize; + +const BalsaHeaders::MultivaluedHeadersSet& +BalsaHeaders::multivalued_envoy_headers() { + static const MultivaluedHeadersSet* multivalued_envoy_headers = + new MultivaluedHeadersSet({ALL_ENVOY_HEADERS(MULTIVALUE_ENVOY_HEADER)}); + return *multivalued_envoy_headers; +} + +void BalsaHeaders::ParseTokenList(absl::string_view header_value, + HeaderTokenList* tokens) { + if (header_value.empty()) { + return; + } + const char* start = header_value.data(); + const char* end = header_value.data() + header_value.size(); + while (true) { + // Cast `*start` to unsigned char to make values above 127 rank as expected + // on platforms with signed char, where such values are represented as + // negative numbers before the cast. + + // search for first nonwhitespace, non separator char. + while (*start == ',' || static_cast(*start) <= ' ') { + ++start; + if (start == end) { + return; + } + } + // found. marked. + const char* nws = start; + + // search for next whitspace or separator char. + while (*start != ',' && static_cast(*start) > ' ') { + ++start; + if (start == end) { + if (nws != start) { + tokens->push_back(absl::string_view(nws, start - nws)); + } + return; + } + } + tokens->push_back(absl::string_view(nws, start - nws)); + } +} + +// This can be called after a std::move() operation, so things might be +// in an unspecified state after the move. +void BalsaHeaders::Clear() { + balsa_buffer_.Clear(); + transfer_encoding_is_chunked_ = false; + content_length_ = 0; + content_length_status_ = BalsaHeadersEnums::NO_CONTENT_LENGTH; + parsed_response_code_ = 0; + firstline_buffer_base_idx_ = 0; + whitespace_1_idx_ = 0; + non_whitespace_1_idx_ = 0; + whitespace_2_idx_ = 0; + non_whitespace_2_idx_ = 0; + whitespace_3_idx_ = 0; + non_whitespace_3_idx_ = 0; + whitespace_4_idx_ = 0; + header_lines_.clear(); + header_lines_.shrink_to_fit(); +} + +void BalsaHeaders::CopyFrom(const BalsaHeaders& other) { + // Protect against copying with self. + if (this == &other) { + return; + } + + balsa_buffer_.CopyFrom(other.balsa_buffer_); + transfer_encoding_is_chunked_ = other.transfer_encoding_is_chunked_; + content_length_ = other.content_length_; + content_length_status_ = other.content_length_status_; + parsed_response_code_ = other.parsed_response_code_; + firstline_buffer_base_idx_ = other.firstline_buffer_base_idx_; + whitespace_1_idx_ = other.whitespace_1_idx_; + non_whitespace_1_idx_ = other.non_whitespace_1_idx_; + whitespace_2_idx_ = other.whitespace_2_idx_; + non_whitespace_2_idx_ = other.non_whitespace_2_idx_; + whitespace_3_idx_ = other.whitespace_3_idx_; + non_whitespace_3_idx_ = other.non_whitespace_3_idx_; + whitespace_4_idx_ = other.whitespace_4_idx_; + header_lines_ = other.header_lines_; +} + +void BalsaHeaders::AddAndMakeDescription(absl::string_view key, + absl::string_view value, + HeaderLineDescription* d) { + QUICHE_CHECK(d != nullptr); + + if (enforce_header_policy_) { + QuicheHandleHeaderPolicy(key); + } + + // + 2 to size for ": " + size_t line_size = key.size() + 2 + value.size(); + BalsaBuffer::Blocks::size_type block_buffer_idx = 0; + char* storage = balsa_buffer_.Reserve(line_size, &block_buffer_idx); + size_t base_idx = storage - GetPtr(block_buffer_idx); + + char* cur_loc = storage; + memcpy(cur_loc, key.data(), key.size()); + cur_loc += key.size(); + *cur_loc = ':'; + ++cur_loc; + *cur_loc = ' '; + ++cur_loc; + memcpy(cur_loc, value.data(), value.size()); + *d = HeaderLineDescription( + base_idx, base_idx + key.size(), base_idx + key.size() + 2, + base_idx + key.size() + 2 + value.size(), block_buffer_idx); +} + +void BalsaHeaders::AppendAndMakeDescription(absl::string_view key, + absl::string_view value, + HeaderLineDescription* d) { + // Figure out how much space we need to reserve for the new header size. + size_t old_value_size = d->last_char_idx - d->value_begin_idx; + if (old_value_size == 0) { + AddAndMakeDescription(key, value, d); + return; + } + absl::string_view old_value(GetPtr(d->buffer_base_idx) + d->value_begin_idx, + old_value_size); + + BalsaBuffer::Blocks::size_type block_buffer_idx = 0; + // + 3 because we potentially need to add ": ", and "," to the line. + size_t new_size = key.size() + 3 + old_value_size + value.size(); + char* storage = balsa_buffer_.Reserve(new_size, &block_buffer_idx); + size_t base_idx = storage - GetPtr(block_buffer_idx); + + absl::string_view first_value = old_value; + absl::string_view second_value = value; + char* cur_loc = storage; + memcpy(cur_loc, key.data(), key.size()); + cur_loc += key.size(); + *cur_loc = ':'; + ++cur_loc; + *cur_loc = ' '; + ++cur_loc; + memcpy(cur_loc, first_value.data(), first_value.size()); + cur_loc += first_value.size(); + *cur_loc = ','; + ++cur_loc; + memcpy(cur_loc, second_value.data(), second_value.size()); + + *d = HeaderLineDescription(base_idx, base_idx + key.size(), + base_idx + key.size() + 2, base_idx + new_size, + block_buffer_idx); +} + +// Reset internal flags for chunked transfer encoding or content length if a +// header we're removing is one of those headers. +void BalsaHeaders::MaybeClearSpecialHeaderValues(absl::string_view key) { + if (absl::EqualsIgnoreCase(key, kContentLength)) { + if (transfer_encoding_is_chunked_) { + return; + } + + content_length_status_ = BalsaHeadersEnums::NO_CONTENT_LENGTH; + content_length_ = 0; + return; + } + + if (absl::EqualsIgnoreCase(key, kTransferEncoding)) { + transfer_encoding_is_chunked_ = false; + } +} + +// Removes all keys value pairs with key 'key' starting at 'start'. +void BalsaHeaders::RemoveAllOfHeaderStartingAt(absl::string_view key, + HeaderLines::iterator start) { + MaybeClearSpecialHeaderValues(key); + while (start != header_lines_.end()) { + start->skip = true; + ++start; + start = GetHeaderLinesIterator(key, start); + } +} + +void BalsaHeaders::ReplaceOrAppendHeader(absl::string_view key, + absl::string_view value) { + const HeaderLines::iterator end = header_lines_.end(); + const HeaderLines::iterator begin = header_lines_.begin(); + HeaderLines::iterator i = GetHeaderLinesIterator(key, begin); + if (i != end) { + // First, remove all of the header lines including this one. We want to + // remove before replacing, in case our replacement ends up being appended + // at the end (and thus would be removed by this call) + RemoveAllOfHeaderStartingAt(key, i); + // Now, take the first instance and replace it. This will remove the + // 'skipped' tag if the replacement is done in-place. + AddAndMakeDescription(key, value, &(*i)); + return; + } + AppendHeader(key, value); +} + +void BalsaHeaders::AppendHeader(absl::string_view key, + absl::string_view value) { + HeaderLineDescription hld; + AddAndMakeDescription(key, value, &hld); + header_lines_.push_back(hld); +} + +void BalsaHeaders::AppendToHeader(absl::string_view key, + absl::string_view value) { + HeaderLines::iterator i = GetHeaderLinesIterator(key, header_lines_.begin()); + if (i == header_lines_.end()) { + // The header did not exist already. Instead of appending to an existing + // header simply append the key/value pair to the headers. + AppendHeader(key, value); + return; + } + HeaderLineDescription hld = *i; + + AppendAndMakeDescription(key, value, &hld); + + // Invalidate the old header line and add the new one. + i->skip = true; + header_lines_.push_back(hld); +} + +void BalsaHeaders::AppendToHeaderWithCommaAndSpace(absl::string_view key, + absl::string_view value) { + HeaderLines::iterator i = GetHeaderLinesIteratorForLastMultivaluedHeader(key); + if (i == header_lines_.end()) { + // The header did not exist already. Instead of appending to an existing + // header simply append the key/value pair to the headers. No extra + // space will be added before the value. + AppendHeader(key, value); + return; + } + + std::string space_and_value = absl::StrCat(" ", value); + + HeaderLineDescription hld = *i; + AppendAndMakeDescription(key, space_and_value, &hld); + + // Invalidate the old header line and add the new one. + i->skip = true; + header_lines_.push_back(hld); +} + +absl::string_view BalsaHeaders::GetValueFromHeaderLineDescription( + const HeaderLineDescription& line) const { + QUICHE_DCHECK_GE(line.last_char_idx, line.value_begin_idx); + return absl::string_view(GetPtr(line.buffer_base_idx) + line.value_begin_idx, + line.last_char_idx - line.value_begin_idx); +} + +absl::string_view BalsaHeaders::GetHeader(absl::string_view key) const { + QUICHE_DCHECK(!header_properties::IsMultivaluedHeader(key)) + << "Header '" << key << "' may consist of multiple lines. Do not " + << "use BalsaHeaders::GetHeader() or you may be missing some of its " + << "values."; + const HeaderLines::const_iterator end = header_lines_.end(); + HeaderLines::const_iterator i = GetConstHeaderLinesIterator(key); + if (i == end) { + return absl::string_view(); + } + return GetValueFromHeaderLineDescription(*i); +} + +BalsaHeaders::const_header_lines_iterator BalsaHeaders::GetHeaderPosition( + absl::string_view key) const { + const HeaderLines::const_iterator end = header_lines_.end(); + HeaderLines::const_iterator i = GetConstHeaderLinesIterator(key); + if (i == end) { + // TODO(tgreer) Convert from HeaderLines::const_iterator to + // const_header_lines_iterator without calling lines().end(), which is + // nontrivial. Look for other needless calls to lines().end(), or make + // lines().end() trivial. + return lines().end(); + } + + return const_header_lines_iterator(this, (i - header_lines_.begin())); +} + +BalsaHeaders::const_header_lines_key_iterator BalsaHeaders::GetIteratorForKey( + absl::string_view key) const { + HeaderLines::const_iterator i = GetConstHeaderLinesIterator(key); + if (i == header_lines_.end()) { + return header_lines_key_end(); + } + + return const_header_lines_key_iterator(this, (i - header_lines_.begin()), + key); +} + +BalsaHeaders::HeaderLines::const_iterator +BalsaHeaders::GetConstHeaderLinesIterator(absl::string_view key) const { + const HeaderLines::const_iterator end = header_lines_.end(); + for (HeaderLines::const_iterator i = header_lines_.begin(); i != end; ++i) { + const HeaderLineDescription& line = *i; + if (line.skip) { + continue; + } + const absl::string_view current_key( + GetPtr(line.buffer_base_idx) + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + if (absl::EqualsIgnoreCase(current_key, key)) { + QUICHE_DCHECK_GE(line.last_char_idx, line.value_begin_idx); + return i; + } + } + return end; +} + +BalsaHeaders::HeaderLines::iterator BalsaHeaders::GetHeaderLinesIterator( + absl::string_view key, BalsaHeaders::HeaderLines::iterator start) { + const HeaderLines::iterator end = header_lines_.end(); + for (HeaderLines::iterator i = start; i != end; ++i) { + const HeaderLineDescription& line = *i; + if (line.skip) { + continue; + } + const absl::string_view current_key( + GetPtr(line.buffer_base_idx) + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + if (absl::EqualsIgnoreCase(current_key, key)) { + QUICHE_DCHECK_GE(line.last_char_idx, line.value_begin_idx); + return i; + } + } + return end; +} + +BalsaHeaders::HeaderLines::iterator +BalsaHeaders::GetHeaderLinesIteratorForLastMultivaluedHeader( + absl::string_view key) { + const HeaderLines::iterator end = header_lines_.end(); + HeaderLines::iterator last_found_match; + bool found_a_match = false; + for (HeaderLines::iterator i = header_lines_.begin(); i != end; ++i) { + const HeaderLineDescription& line = *i; + if (line.skip) { + continue; + } + const absl::string_view current_key( + GetPtr(line.buffer_base_idx) + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + if (absl::EqualsIgnoreCase(current_key, key)) { + QUICHE_DCHECK_GE(line.last_char_idx, line.value_begin_idx); + last_found_match = i; + found_a_match = true; + } + } + return (found_a_match ? last_found_match : end); +} + +void BalsaHeaders::GetAllOfHeader(absl::string_view key, + std::vector* out) const { + for (const_header_lines_key_iterator it = GetIteratorForKey(key); + it != lines().end(); ++it) { + out->push_back(it->second); + } +} + +void BalsaHeaders::GetAllOfHeaderIncludeRemoved( + absl::string_view key, std::vector* out) const { + const HeaderLines::const_iterator begin = header_lines_.begin(); + const HeaderLines::const_iterator end = header_lines_.end(); + for (bool add_removed : {false, true}) { + for (HeaderLines::const_iterator i = begin; i != end; ++i) { + const HeaderLineDescription& line = *i; + if ((!add_removed && line.skip) || (add_removed && !line.skip)) { + continue; + } + const absl::string_view current_key( + GetPtr(line.buffer_base_idx) + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + if (absl::EqualsIgnoreCase(current_key, key)) { + QUICHE_DCHECK_GE(line.last_char_idx, line.value_begin_idx); + out->push_back(GetValueFromHeaderLineDescription(line)); + } + } + } +} + +namespace { + +// Helper function for HeaderHasValue that checks that the specified region +// within line is preceded by whitespace and a comma or beginning of line, +// and followed by whitespace and a comma or end of line. +bool SurroundedOnlyBySpacesAndCommas(absl::string_view::difference_type idx, + absl::string_view::difference_type end_idx, + absl::string_view line) { + for (idx = idx - 1; idx >= 0; --idx) { + if (line[idx] == ',') { + break; + } + if (line[idx] != ' ') { + return false; + } + } + + for (; end_idx < static_cast(line.size()); ++end_idx) { + if (line[end_idx] == ',') { + break; + } + if (line[end_idx] != ' ') { + return false; + } + } + return true; +} + +} // namespace + +bool BalsaHeaders::HeaderHasValueHelper(absl::string_view key, + absl::string_view value, + bool case_sensitive) const { + for (const_header_lines_key_iterator it = GetIteratorForKey(key); + it != lines().end(); ++it) { + absl::string_view line = it->second; + absl::string_view::size_type idx = + case_sensitive ? line.find(value, 0) : FindIgnoreCase(line, value); + while (idx != absl::string_view::npos) { + absl::string_view::difference_type end_idx = idx + value.size(); + if (SurroundedOnlyBySpacesAndCommas(idx, end_idx, line)) { + return true; + } + idx = line.find(value, idx + 1); + } + } + return false; +} + +bool BalsaHeaders::HasNonEmptyHeader(absl::string_view key) const { + for (const_header_lines_key_iterator it = GetIteratorForKey(key); + it != header_lines_key_end(); ++it) { + if (!it->second.empty()) { + return true; + } + } + return false; +} + +std::string BalsaHeaders::GetAllOfHeaderAsString(absl::string_view key) const { + // Use custom formatter to ignore header key and join only header values. + // absl::AlphaNumFormatter is the default formatter for absl::StrJoin(). + auto formatter = [](std::string* out, + std::pair header) { + return absl::AlphaNumFormatter()(out, header.second); + }; + return absl::StrJoin(GetIteratorForKey(key), header_lines_key_end(), ",", + formatter); +} + +void BalsaHeaders::RemoveAllOfHeaderInList(const HeaderTokenList& keys) { + if (keys.empty()) { + return; + } + + // This extra copy sacrifices some performance to prevent the possible + // mistakes that the caller does not lower case the headers in keys. + // Better performance can be achieved by asking caller to lower case + // the keys and RemoveAllOfheaderInlist just does lookup. + absl::flat_hash_set lowercase_keys; + for (const auto& key : keys) { + MaybeClearSpecialHeaderValues(key); + lowercase_keys.insert(absl::AsciiStrToLower(key)); + } + + for (HeaderLineDescription& line : header_lines_) { + if (line.skip) { + continue; + } + // Remove the header if it matches any of the keys to remove. + const size_t key_len = line.key_end_idx - line.first_char_idx; + absl::string_view key(GetPtr(line.buffer_base_idx) + line.first_char_idx, + key_len); + + std::string lowercase_key = absl::AsciiStrToLower(key); + if (lowercase_keys.count(lowercase_key) != 0) { + line.skip = true; + } + } +} + +void BalsaHeaders::RemoveAllOfHeader(absl::string_view key) { + HeaderLines::iterator it = GetHeaderLinesIterator(key, header_lines_.begin()); + RemoveAllOfHeaderStartingAt(key, it); +} + +void BalsaHeaders::RemoveAllHeadersWithPrefix(absl::string_view prefix) { + for (HeaderLines::size_type i = 0; i < header_lines_.size(); ++i) { + if (header_lines_[i].skip) { + continue; + } + + HeaderLineDescription& line = header_lines_[i]; + const size_t key_len = line.key_end_idx - line.first_char_idx; + if (key_len < prefix.size()) { + continue; + } + + const absl::string_view current_key_prefix( + GetPtr(line.buffer_base_idx) + line.first_char_idx, prefix.size()); + if (absl::EqualsIgnoreCase(current_key_prefix, prefix)) { + const absl::string_view current_key( + GetPtr(line.buffer_base_idx) + line.first_char_idx, key_len); + MaybeClearSpecialHeaderValues(current_key); + line.skip = true; + } + } +} + +bool BalsaHeaders::HasHeadersWithPrefix(absl::string_view prefix) const { + for (HeaderLines::size_type i = 0; i < header_lines_.size(); ++i) { + if (header_lines_[i].skip) { + continue; + } + + const HeaderLineDescription& line = header_lines_[i]; + if (line.key_end_idx - line.first_char_idx < prefix.size()) { + continue; + } + + const absl::string_view current_key_prefix( + GetPtr(line.buffer_base_idx) + line.first_char_idx, prefix.size()); + if (absl::EqualsIgnoreCase(current_key_prefix, prefix)) { + return true; + } + } + return false; +} + +void BalsaHeaders::GetAllOfHeaderWithPrefix( + absl::string_view prefix, + std::vector>* out) const { + for (HeaderLines::size_type i = 0; i < header_lines_.size(); ++i) { + if (header_lines_[i].skip) { + continue; + } + const HeaderLineDescription& line = header_lines_[i]; + absl::string_view key(GetPtr(line.buffer_base_idx) + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + if (absl::StartsWithIgnoreCase(key, prefix)) { + out->push_back(std::make_pair( + key, + absl::string_view(GetPtr(line.buffer_base_idx) + line.value_begin_idx, + line.last_char_idx - line.value_begin_idx))); + } + } +} + +void BalsaHeaders::GetAllHeadersWithLimit( + std::vector>* out, + int limit) const { + for (HeaderLines::size_type i = 0; i < header_lines_.size(); ++i) { + if (limit >= 0 && out->size() >= static_cast(limit)) { + return; + } + if (header_lines_[i].skip) { + continue; + } + const HeaderLineDescription& line = header_lines_[i]; + absl::string_view key(GetPtr(line.buffer_base_idx) + line.first_char_idx, + line.key_end_idx - line.first_char_idx); + out->push_back(std::make_pair( + key, + absl::string_view(GetPtr(line.buffer_base_idx) + line.value_begin_idx, + line.last_char_idx - line.value_begin_idx))); + } +} + +size_t BalsaHeaders::RemoveValue(absl::string_view key, + absl::string_view search_value) { + // Remove whitespace around search value. + absl::string_view needle = search_value; + RemoveWhitespaceContext(&needle); + QUICHE_BUG_IF(bug_22783_2, needle != search_value) + << "Search value should not be surrounded by spaces."; + + // We have nothing to do for empty needle strings. + if (needle.empty()) { + return 0; + } + + // The return value: number of removed values. + size_t removals = 0; + + // Iterate over all header lines matching key with skip=false. + for (HeaderLines::iterator it = + GetHeaderLinesIterator(key, header_lines_.begin()); + it != header_lines_.end(); it = GetHeaderLinesIterator(key, ++it)) { + HeaderLineDescription* line = &(*it); + + // If needle given to us is longer than this header, don't consider it. + if (line->ValuesLength() < needle.size()) { + continue; + } + + // If the values are equivalent, just remove the whole line. + char* buf = GetPtr(line->buffer_base_idx); // The head of our buffer. + char* value_begin = buf + line->value_begin_idx; + // StringPiece containing values that have yet to be processed. The head of + // this stringpiece will continually move forward, and its tail + // (head+length) will always remain the same. + absl::string_view values(value_begin, line->ValuesLength()); + RemoveWhitespaceContext(&values); + if (values.size() == needle.size()) { + if (values == needle) { + line->skip = true; + removals++; + } + continue; + } + + // Find all occurrences of the needle to be removed. + char* insertion = value_begin; + while (values.size() >= needle.size()) { + // Strip leading whitespace. + ssize_t cur_leading_whitespace = RemoveLeadingWhitespace(&values); + + // See if we've got a match (at least as a prefix). + bool found = absl::StartsWith(values, needle); + + // Find the entirety of this value (including trailing comma if existent). + const size_t next_comma = + values.find(',', /* pos = */ (found ? needle.size() : 0)); + const bool comma_found = next_comma != absl::string_view::npos; + const size_t cur_size = (comma_found ? next_comma + 1 : values.size()); + + // Make sure that our prefix match is a full match. + if (found && cur_size != needle.size()) { + absl::string_view cur(values.data(), cur_size); + if (comma_found) { + cur.remove_suffix(1); + } + RemoveTrailingWhitespace(&cur); + found = (cur.size() == needle.size()); + } + + // Move as necessary (avoid move just for the sake of leading whitespace). + if (found) { + removals++; + // Remove trailing comma if we happen to have found the last value. + if (!comma_found) { + // We modify insertion since it'll be used to update last_char_idx. + insertion--; + } + } else { + if (insertion + cur_leading_whitespace != values.data()) { + // Has the side-effect of also copying any trailing whitespace. + memmove(insertion, values.data(), cur_size); + insertion += cur_size; + } else { + insertion += cur_leading_whitespace + cur_size; + } + } + + // No longer consider the current value. (Increment.) + values.remove_prefix(cur_size); + } + // Move remaining data. + if (!values.empty()) { + if (insertion != values.data()) { + memmove(insertion, values.data(), values.size()); + } + insertion += values.size(); + } + // Set new line size. + if (insertion <= value_begin) { + // All values removed. + line->skip = true; + } else { + line->last_char_idx = insertion - buf; + } + } + + return removals; +} + +size_t BalsaHeaders::GetSizeForWriteBuffer() const { + // First add the space required for the first line + line separator. + size_t write_buf_size = whitespace_4_idx_ - non_whitespace_1_idx_ + 2; + // Then add the space needed for each header line to write out + line + // separator. + const HeaderLines::size_type end = header_lines_.size(); + for (HeaderLines::size_type i = 0; i < end; ++i) { + const HeaderLineDescription& line = header_lines_[i]; + if (!line.skip) { + // Add the key size and ": ". + write_buf_size += line.key_end_idx - line.first_char_idx + 2; + // Add the value size and the line separator. + write_buf_size += line.last_char_idx - line.value_begin_idx + 2; + } + } + // Finally tack on the terminal line separator. + return write_buf_size + 2; +} + +void BalsaHeaders::DumpToString(std::string* str) const { + DumpToPrefixedString(" ", str); +} + +std::string BalsaHeaders::DebugString() const { + std::string s; + DumpToString(&s); + return s; +} + +bool BalsaHeaders::ForEachHeader( + std::function + fn) const { + int s = header_lines_.size(); + for (int i = 0; i < s; ++i) { + const HeaderLineDescription& desc = header_lines_[i]; + if (!desc.skip && desc.KeyLength() > 0) { + const char* stream_begin = GetPtr(desc.buffer_base_idx); + if (!fn(absl::string_view(stream_begin + desc.first_char_idx, + desc.KeyLength()), + absl::string_view(stream_begin + desc.value_begin_idx, + desc.ValuesLength()))) { + return false; + } + } + } + return true; +} + +void BalsaHeaders::DumpToPrefixedString(const char* spaces, + std::string* str) const { + const absl::string_view firstline = first_line(); + const int buffer_length = GetReadableBytesFromHeaderStream(); + // First check whether the header object is empty. + if (firstline.empty() && buffer_length == 0) { + absl::StrAppend(str, "\n", spaces, "\n"); + return; + } + + // Then check whether the header is in a partially parsed state. If so, just + // dump the raw data. + if (!FramerIsDoneWriting()) { + absl::StrAppendFormat(str, "\n%s\n%s%.*s\n", + spaces, buffer_length, spaces, buffer_length, + OriginalHeaderStreamBegin()); + return; + } + + // If the header is complete, then just dump them with the logical key value + // pair. + str->reserve(str->size() + GetSizeForWriteBuffer()); + absl::StrAppend(str, "\n", spaces, firstline, "\n"); + for (const auto& line : lines()) { + absl::StrAppend(str, spaces, line.first, ": ", line.second, "\n"); + } +} + +void BalsaHeaders::SetContentLength(size_t length) { + // If the content-length is already the one we want, don't do anything. + if (content_length_status_ == BalsaHeadersEnums::VALID_CONTENT_LENGTH && + content_length_ == length) { + return; + } + // If header state indicates that there is either a content length or + // transfer encoding header, remove them before adding the new content + // length. There is always the possibility that client can manually add + // either header directly and cause content_length_status_ or + // transfer_encoding_is_chunked_ to be inconsistent with the actual header. + // In the interest of efficiency, however, we will assume that clients will + // use the header object correctly and thus we will not scan the all headers + // each time this function is called. + if (content_length_status_ != BalsaHeadersEnums::NO_CONTENT_LENGTH) { + RemoveAllOfHeader(kContentLength); + } else if (transfer_encoding_is_chunked_) { + RemoveAllOfHeader(kTransferEncoding); + } + content_length_status_ = BalsaHeadersEnums::VALID_CONTENT_LENGTH; + content_length_ = length; + + AppendHeader(kContentLength, absl::StrCat(length)); +} + +void BalsaHeaders::SetTransferEncodingToChunkedAndClearContentLength() { + if (transfer_encoding_is_chunked_) { + return; + } + if (content_length_status_ != BalsaHeadersEnums::NO_CONTENT_LENGTH) { + // Per https://httpwg.org/specs/rfc7230.html#header.content-length, we can't + // send both transfer-encoding and content-length. + ClearContentLength(); + } + ReplaceOrAppendHeader(kTransferEncoding, "chunked"); + transfer_encoding_is_chunked_ = true; +} + +void BalsaHeaders::SetNoTransferEncoding() { + if (transfer_encoding_is_chunked_) { + // clears transfer_encoding_is_chunked_ + RemoveAllOfHeader(kTransferEncoding); + } +} + +void BalsaHeaders::ClearContentLength() { RemoveAllOfHeader(kContentLength); } + +bool BalsaHeaders::IsEmpty() const { + return balsa_buffer_.GetTotalBytesUsed() == 0; +} + +absl::string_view BalsaHeaders::Authority() const { return GetHeader(kHost); } + +void BalsaHeaders::ReplaceOrAppendAuthority(absl::string_view value) { + ReplaceOrAppendHeader(kHost, value); +} + +void BalsaHeaders::RemoveAuthority() { RemoveAllOfHeader(kHost); } + +void BalsaHeaders::ApplyToCookie( + std::function f) const { + f(GetHeader(kCookie)); +} + +void BalsaHeaders::SetResponseFirstline(absl::string_view version, + size_t parsed_response_code, + absl::string_view reason_phrase) { + SetFirstlineFromStringPieces(version, absl::StrCat(parsed_response_code), + reason_phrase); + parsed_response_code_ = parsed_response_code; +} + +void BalsaHeaders::SetFirstlineFromStringPieces(absl::string_view firstline_a, + absl::string_view firstline_b, + absl::string_view firstline_c) { + size_t line_size = + (firstline_a.size() + firstline_b.size() + firstline_c.size() + 2); + char* storage = balsa_buffer_.Reserve(line_size, &firstline_buffer_base_idx_); + char* cur_loc = storage; + + memcpy(cur_loc, firstline_a.data(), firstline_a.size()); + cur_loc += firstline_a.size(); + + *cur_loc = ' '; + ++cur_loc; + + memcpy(cur_loc, firstline_b.data(), firstline_b.size()); + cur_loc += firstline_b.size(); + + *cur_loc = ' '; + ++cur_loc; + + memcpy(cur_loc, firstline_c.data(), firstline_c.size()); + + whitespace_1_idx_ = storage - BeginningOfFirstLine(); + non_whitespace_1_idx_ = whitespace_1_idx_; + whitespace_2_idx_ = non_whitespace_1_idx_ + firstline_a.size(); + non_whitespace_2_idx_ = whitespace_2_idx_ + 1; + whitespace_3_idx_ = non_whitespace_2_idx_ + firstline_b.size(); + non_whitespace_3_idx_ = whitespace_3_idx_ + 1; + whitespace_4_idx_ = non_whitespace_3_idx_ + firstline_c.size(); +} + +void BalsaHeaders::SetRequestMethod(absl::string_view method) { + // This is the first of the three parts of the firstline. + if (method.size() <= (whitespace_2_idx_ - non_whitespace_1_idx_)) { + non_whitespace_1_idx_ = whitespace_2_idx_ - method.size(); + if (!method.empty()) { + char* stream_begin = BeginningOfFirstLine(); + memcpy(stream_begin + non_whitespace_1_idx_, method.data(), + method.size()); + } + } else { + // The new method is too large to fit in the space available for the old + // one, so we have to reformat the firstline. + SetRequestFirstlineFromStringPieces(method, request_uri(), + request_version()); + } +} + +void BalsaHeaders::SetResponseVersion(absl::string_view version) { + // Note: There is no difference between request_method() and + // response_Version(). Thus, a function to set one is equivalent to a + // function to set the other. We maintain two functions for this as it is + // much more descriptive, and makes code more understandable. + SetRequestMethod(version); +} + +void BalsaHeaders::SetRequestUri(absl::string_view uri) { + SetRequestFirstlineFromStringPieces(request_method(), uri, request_version()); +} + +void BalsaHeaders::SetResponseCode(absl::string_view code) { + // Note: There is no difference between request_uri() and response_code(). + // Thus, a function to set one is equivalent to a function to set the other. + // We maintain two functions for this as it is much more descriptive, and + // makes code more understandable. + SetRequestUri(code); +} + +void BalsaHeaders::SetParsedResponseCodeAndUpdateFirstline( + size_t parsed_response_code) { + parsed_response_code_ = parsed_response_code; + SetResponseCode(absl::StrCat(parsed_response_code)); +} + +void BalsaHeaders::SetRequestVersion(absl::string_view version) { + // This is the last of the three parts of the firstline. + // Since whitespace_3_idx and non_whitespace_3_idx may point to the same + // place, we ensure below that any available space includes space for a + // literal space (' ') character between the second component and the third + // component. + bool fits_in_space_allowed = + version.size() + 1 <= whitespace_4_idx_ - whitespace_3_idx_; + + if (!fits_in_space_allowed) { + // If the new version is too large, then reformat the firstline. + SetRequestFirstlineFromStringPieces(request_method(), request_uri(), + version); + return; + } + + char* stream_begin = BeginningOfFirstLine(); + *(stream_begin + whitespace_3_idx_) = ' '; + non_whitespace_3_idx_ = whitespace_3_idx_ + 1; + whitespace_4_idx_ = non_whitespace_3_idx_ + version.size(); + memcpy(stream_begin + non_whitespace_3_idx_, version.data(), version.size()); +} + +void BalsaHeaders::SetResponseReasonPhrase(absl::string_view reason) { + // Note: There is no difference between request_version() and + // response_reason_phrase(). Thus, a function to set one is equivalent to a + // function to set the other. We maintain two functions for this as it is + // much more descriptive, and makes code more understandable. + SetRequestVersion(reason); +} + +void BalsaHeaders::RemoveLastTokenFromHeaderValue(absl::string_view key) { + BalsaHeaders::HeaderLines::iterator it = + GetHeaderLinesIterator(key, header_lines_.begin()); + if (it == header_lines_.end()) { + QUICHE_DLOG(WARNING) + << "Attempting to remove last token from a non-existent " + << "header \"" << key << "\""; + return; + } + + // Find the last line with that key. + BalsaHeaders::HeaderLines::iterator header_line; + do { + header_line = it; + it = GetHeaderLinesIterator(key, it + 1); + } while (it != header_lines_.end()); + + // Tokenize just that line. + BalsaHeaders::HeaderTokenList tokens; + // Find where this line is stored. + const char* stream_begin = GetPtr(header_line->buffer_base_idx); + absl::string_view value( + stream_begin + header_line->value_begin_idx, + header_line->last_char_idx - header_line->value_begin_idx); + // Tokenize. + ParseTokenList(value, &tokens); + + if (tokens.empty()) { + QUICHE_DLOG(WARNING) + << "Attempting to remove a token from an empty header value " + << "for header \"" << key << "\""; + header_line->skip = true; // remove the whole line + } else if (tokens.size() == 1) { + header_line->skip = true; // remove the whole line + } else { + // Shrink the line size and leave the extra data in the buffer. + absl::string_view new_last_token = tokens[tokens.size() - 2]; + const char* last_char_address = + new_last_token.data() + new_last_token.size() - 1; + const char* stream_begin = GetPtr(header_line->buffer_base_idx); + + header_line->last_char_idx = last_char_address - stream_begin + 1; + } +} + +bool BalsaHeaders::ResponseCanHaveBody(int response_code) { + // For responses, can't have a body if the request was a HEAD, or if it is + // one of these response-codes. rfc2616 section 4.3 + if (response_code >= 100 && response_code < 200) { + // 1xx responses can't have bodies. + return false; + } + + // No content and Not modified responses have no body. + return (response_code != 204) && (response_code != 304); +} + +} // namespace quiche diff --git a/quiche/balsa/balsa_headers.h b/quiche/balsa/balsa_headers.h new file mode 100644 index 000000000000..0005bb5d137e --- /dev/null +++ b/quiche/balsa/balsa_headers.h @@ -0,0 +1,1468 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A lightweight implementation for storing HTTP headers. + +#ifndef QUICHE_BALSA_BALSA_HEADERS_H_ +#define QUICHE_BALSA_BALSA_HEADERS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/header_api.h" +#include "quiche/balsa/standard_header_map.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace gfe2 { +class Http2HeaderValidator; +} // namespace gfe2 + +namespace quiche { + +namespace test { +class BalsaHeadersTestPeer; +} // namespace test + +// WARNING: +// Note that -no- char* returned by any function in this +// file is null-terminated. + +// This class exists to service the specific needs of BalsaHeaders. +// +// Functional goals: +// 1) provide a backing-store for all of the StringPieces that BalsaHeaders +// returns. Every StringPiece returned from BalsaHeaders should remain +// valid until the BalsaHeader's object is cleared, or the header-line is +// erased. +// 2) provide a backing-store for BalsaFrame, which requires contiguous memory +// for its fast-path parsing functions. Note that the cost of copying is +// less than the cost of requiring the parser to do slow-path parsing, as +// it would have to check for bounds every byte, instead of every 16 bytes. +// +// This class is optimized for the case where headers are stored in one of two +// buffers. It doesn't make a lot of effort to densely pack memory-- in fact, +// it -may- be somewhat memory inefficient. This possible inefficiency allows a +// certain simplicity of implementation and speed which makes it worthwhile. +// If, in the future, better memory density is required, it should be possible +// to reuse the abstraction presented by this object to achieve those goals. +// +// In the most common use-case, this memory inefficiency should be relatively +// small. +// +// Alternate implementations of BalsaBuffer may include: +// - vector of strings, one per header line (similar to HTTPHeaders) +// - densely packed strings: +// - keep a sorted array/map of free-space linked lists or numbers. +// - use the entry that most closely first your needs. +// - at this point, perhaps just use a vector of strings, and let +// the allocator do the right thing. +// +class QUICHE_EXPORT BalsaBuffer { + public: + static constexpr size_t kDefaultBlocksize = 4096; + + // The BufferBlock is a structure used internally by the + // BalsaBuffer class to store the base buffer pointers to + // each block, as well as the important metadata for buffer + // sizes and bytes free. It *may* be possible to replace this + // with a vector, but it's unclear whether moving a vector + // can invalidate pointers into it. LWG issue 2321 proposes to fix this. + struct QUICHE_EXPORT BufferBlock { + public: + std::unique_ptr buffer; + size_t buffer_size = 0; + size_t bytes_free = 0; + + size_t bytes_used() const { return buffer_size - bytes_free; } + char* start_of_unused_bytes() const { return buffer.get() + bytes_used(); } + + BufferBlock() {} + + BufferBlock(std::unique_ptr buf, size_t size, size_t free) + : buffer(std::move(buf)), buffer_size(size), bytes_free(free) {} + + BufferBlock(const BufferBlock&) = delete; + BufferBlock& operator=(const BufferBlock&) = delete; + BufferBlock(BufferBlock&&) = default; + BufferBlock& operator=(BufferBlock&&) = default; + + // Note: allocating a fresh buffer even if we could reuse an old one may let + // us shed memory, and invalidates old StringPieces (making them easier to + // catch with asan). + void CopyFrom(const BufferBlock& rhs) { + QUICHE_DCHECK(this != &rhs); + buffer_size = rhs.buffer_size; + bytes_free = rhs.bytes_free; + if (rhs.buffer == nullptr) { + buffer = nullptr; + } else { + buffer = std::make_unique(buffer_size); + memcpy(buffer.get(), rhs.buffer.get(), rhs.bytes_used()); + } + } + }; + + typedef std::vector Blocks; + + BalsaBuffer() + : blocksize_(kDefaultBlocksize), can_write_to_contiguous_buffer_(true) {} + + explicit BalsaBuffer(size_t blocksize) + : blocksize_(blocksize), can_write_to_contiguous_buffer_(true) {} + + BalsaBuffer(const BalsaBuffer&) = delete; + BalsaBuffer& operator=(const BalsaBuffer&) = delete; + BalsaBuffer(BalsaBuffer&&) = default; + BalsaBuffer& operator=(BalsaBuffer&&) = default; + + // Returns the total amount of memory reserved by the buffer blocks. + size_t GetTotalBufferBlockSize() const { + size_t buffer_size = 0; + for (Blocks::const_iterator iter = blocks_.begin(); iter != blocks_.end(); + ++iter) { + buffer_size += iter->buffer_size; + } + return buffer_size; + } + + // Returns the total amount of memory used by the buffer blocks. + size_t GetTotalBytesUsed() const { + size_t bytes_used = 0; + for (const auto& b : blocks_) { + bytes_used += b.bytes_used(); + } + return bytes_used; + } + + const char* GetPtr(Blocks::size_type block_idx) const { + QUICHE_DCHECK_LT(block_idx, blocks_.size()) + << block_idx << ", " << blocks_.size(); + return block_idx >= blocks_.size() ? nullptr + : blocks_[block_idx].buffer.get(); + } + + char* GetPtr(Blocks::size_type block_idx) { + QUICHE_DCHECK_LT(block_idx, blocks_.size()) + << block_idx << ", " << blocks_.size(); + return block_idx >= blocks_.size() ? nullptr + : blocks_[block_idx].buffer.get(); + } + + // This function is different from Reserve(), as it ensures that the data + // stored via subsequent calls to this function are all contiguous (and in + // the order in which these writes happened). This is essentially the same + // as a string append. + // + // You may call this function at any time between object + // construction/Clear(), and the calling of the + // NoMoreWriteToContiguousBuffer() function. + // + // You must not call this function after the NoMoreWriteToContiguousBuffer() + // function is called, unless a Clear() has been called since. + // If you do, the program will abort(). + // + // This condition is placed upon this code so that calls to Reserve() can + // append to the buffer in the first block safely, and without invaliding + // the StringPiece which it returns. + // + // This function's main intended user is the BalsaFrame class, which, + // for reasons of efficiency, requires that the buffer from which it parses + // the headers be contiguous. + // + void WriteToContiguousBuffer(absl::string_view sp) { + if (sp.empty()) { + return; + } + QUICHE_CHECK(can_write_to_contiguous_buffer_); + + if (blocks_.empty()) { + blocks_.push_back(AllocBlock()); + } + + QUICHE_DCHECK_GE(blocks_.size(), 1u); + if (blocks_[0].buffer == nullptr && sp.size() <= blocksize_) { + blocks_[0] = AllocBlock(); + memcpy(blocks_[0].start_of_unused_bytes(), sp.data(), sp.size()); + } else if (blocks_[0].bytes_free < sp.size()) { + // the first block isn't big enough, resize it. + const size_t old_storage_size_used = blocks_[0].bytes_used(); + // Increase to at least 2*old_storage_size_used; if sp.size() is larger, + // we'll increase by that amount. + const size_t new_storage_size = + old_storage_size_used + (old_storage_size_used < sp.size() + ? sp.size() + : old_storage_size_used); + std::unique_ptr new_storage{new char[new_storage_size]}; + char* old_storage = blocks_[0].buffer.get(); + if (old_storage_size_used != 0u) { + memcpy(new_storage.get(), old_storage, old_storage_size_used); + } + memcpy(new_storage.get() + old_storage_size_used, sp.data(), sp.size()); + blocks_[0].buffer = std::move(new_storage); + blocks_[0].bytes_free = new_storage_size - old_storage_size_used; + blocks_[0].buffer_size = new_storage_size; + } else { + memcpy(blocks_[0].start_of_unused_bytes(), sp.data(), sp.size()); + } + blocks_[0].bytes_free -= sp.size(); + } + + void NoMoreWriteToContiguousBuffer() { + can_write_to_contiguous_buffer_ = false; + } + + // Reserves "permanent" storage of the size indicated. Returns a pointer to + // the beginning of that storage, and assigns the index of the block used to + // block_buffer_idx. This function uses the first block IFF the + // NoMoreWriteToContiguousBuffer function has been called since the last + // Clear/Construction. + char* Reserve(size_t size, Blocks::size_type* block_buffer_idx) { + if (blocks_.empty()) { + blocks_.push_back(AllocBlock()); + } + + // There should always be a 'first_block', even if it + // contains nothing. + QUICHE_DCHECK_GE(blocks_.size(), 1u); + BufferBlock* block = nullptr; + Blocks::size_type block_idx = can_write_to_contiguous_buffer_ ? 1 : 0; + for (; block_idx < blocks_.size(); ++block_idx) { + if (blocks_[block_idx].bytes_free >= size) { + block = &blocks_[block_idx]; + break; + } + } + if (block == nullptr) { + if (blocksize_ < size) { + blocks_.push_back(AllocCustomBlock(size)); + } else { + blocks_.push_back(AllocBlock()); + } + block = &blocks_.back(); + } + + char* storage = block->start_of_unused_bytes(); + block->bytes_free -= size; + if (block_buffer_idx != nullptr) { + *block_buffer_idx = block_idx; + } + return storage; + } + + void Clear() { + blocks_.clear(); + blocks_.shrink_to_fit(); + can_write_to_contiguous_buffer_ = true; + } + + void CopyFrom(const BalsaBuffer& b) { + blocks_.resize(b.blocks_.size()); + for (Blocks::size_type i = 0; i < blocks_.size(); ++i) { + blocks_[i].CopyFrom(b.blocks_[i]); + } + blocksize_ = b.blocksize_; + can_write_to_contiguous_buffer_ = b.can_write_to_contiguous_buffer_; + } + + const char* StartOfFirstBlock() const { + QUICHE_BUG_IF(bug_if_1182_1, blocks_.empty()) + << "First block not allocated yet!"; + return blocks_.empty() ? nullptr : blocks_[0].buffer.get(); + } + + const char* EndOfFirstBlock() const { + QUICHE_BUG_IF(bug_if_1182_2, blocks_.empty()) + << "First block not allocated yet!"; + return blocks_.empty() ? nullptr : blocks_[0].start_of_unused_bytes(); + } + + size_t GetReadableBytesOfFirstBlock() const { + return blocks_.empty() ? 0 : blocks_[0].bytes_used(); + } + + bool can_write_to_contiguous_buffer() const { + return can_write_to_contiguous_buffer_; + } + size_t blocksize() const { return blocksize_; } + Blocks::size_type num_blocks() const { return blocks_.size(); } + size_t buffer_size(size_t idx) const { return blocks_[idx].buffer_size; } + size_t bytes_used(size_t idx) const { return blocks_[idx].bytes_used(); } + + private: + BufferBlock AllocBlock() { return AllocCustomBlock(blocksize_); } + + BufferBlock AllocCustomBlock(size_t blocksize) { + return BufferBlock{std::make_unique(blocksize), blocksize, + blocksize}; + } + + // A container of BufferBlocks + Blocks blocks_; + + // The default allocation size for a block. + // In general, blocksize_ bytes will be allocated for + // each buffer. + size_t blocksize_; + + // If set to true, then the first block cannot be used for Reserve() calls as + // the WriteToContiguous... function will modify the base pointer for this + // block, and the Reserve() calls need to be sure that the base pointer will + // not be changing in order to provide the user with StringPieces which + // continue to be valid. + bool can_write_to_contiguous_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// All of the functions in the BalsaHeaders class use string pieces, by either +// using the StringPiece class, or giving an explicit size and char* (as these +// are the native representation for these string pieces). +// This is done for several reasons. +// 1) This minimizes copying/allocation/deallocation as compared to using +// string parameters +// 2) This reduces the number of strlen() calls done (as the length of any +// string passed in is relatively likely to be known at compile time, and for +// those strings passed back we obviate the need for a strlen() to determine +// the size of new storage allocations if a new allocation is required. +// 3) This class attempts to store all of its data in two linear buffers in +// order to enhance the speed of parsing and writing out to a buffer. As a +// result, many string pieces are -not- terminated by '\0', and are not +// c-strings. Since this is the case, we must delineate the length of the +// string explicitly via a length. +// +// WARNING: The side effect of using StringPiece is that if the underlying +// buffer changes (due to modifying the headers) the StringPieces which point +// to the data which was modified, may now contain "garbage", and should not +// be dereferenced. +// For example, If you fetch some component of the first-line, (request or +// response), and then you modify the first line, the StringPieces you +// originally received from the original first-line may no longer be valid). +// +// StringPieces pointing to pieces of header lines which have not been +// erased() or modified should be valid until the object is cleared or +// destroyed. +// +// Key comparisons are case-insensitive. + +class QUICHE_EXPORT BalsaHeaders : public HeaderApi { + public: + // Each header line is parsed into a HeaderLineDescription, which maintains + // pointers into the BalsaBuffer. + // + // Succinctly describes one header line as indices into a buffer. + struct QUICHE_EXPORT HeaderLineDescription { + HeaderLineDescription(size_t first_character_index, size_t key_end_index, + size_t value_begin_index, size_t last_character_index, + size_t buffer_base_index) + : first_char_idx(first_character_index), + key_end_idx(key_end_index), + value_begin_idx(value_begin_index), + last_char_idx(last_character_index), + buffer_base_idx(buffer_base_index), + skip(false) {} + + HeaderLineDescription() + : first_char_idx(0), + key_end_idx(0), + value_begin_idx(0), + last_char_idx(0), + buffer_base_idx(0), + skip(false) {} + + size_t KeyLength() const { + QUICHE_DCHECK_GE(key_end_idx, first_char_idx); + return key_end_idx - first_char_idx; + } + size_t ValuesLength() const { + QUICHE_DCHECK_GE(last_char_idx, value_begin_idx); + return last_char_idx - value_begin_idx; + } + + size_t first_char_idx; + size_t key_end_idx; + size_t value_begin_idx; + size_t last_char_idx; + BalsaBuffer::Blocks::size_type buffer_base_idx; + bool skip; + }; + + using HeaderTokenList = std::vector; + + // An iterator for walking through all the header lines. + class const_header_lines_iterator; + + // An iterator that only stops at lines with a particular key + // (case-insensitive). See also GetIteratorForKey. + // + // Check against header_lines_key_end() to determine when iteration is + // finished. lines().end() will also work. + class const_header_lines_key_iterator; + + // A simple class that can be used in a range-based for loop. + template + class QUICHE_EXPORT iterator_range { + public: + using iterator = IteratorType; + using const_iterator = IteratorType; + using value_type = typename std::iterator_traits::value_type; + + iterator_range(IteratorType begin_iterator, IteratorType end_iterator) + : begin_iterator_(std::move(begin_iterator)), + end_iterator_(std::move(end_iterator)) {} + + IteratorType begin() const { return begin_iterator_; } + IteratorType end() const { return end_iterator_; } + + private: + IteratorType begin_iterator_, end_iterator_; + }; + + // Set of names of headers that might have multiple values. + // CoalesceOption::kCoalesce can be used to match Envoy behavior in + // WriteToBuffer(). + using MultivaluedHeadersSet = + absl::flat_hash_set; + + // Map of key => vector, where vector contains ordered list of all + // values for |key| (ignoring the casing). + using MultivaluedHeadersValuesMap = + absl::flat_hash_map, + StringPieceCaseHash, StringPieceCaseEqual>; + + BalsaHeaders() + : balsa_buffer_(4096), + content_length_(0), + content_length_status_(BalsaHeadersEnums::NO_CONTENT_LENGTH), + parsed_response_code_(0), + firstline_buffer_base_idx_(0), + whitespace_1_idx_(0), + non_whitespace_1_idx_(0), + whitespace_2_idx_(0), + non_whitespace_2_idx_(0), + whitespace_3_idx_(0), + non_whitespace_3_idx_(0), + whitespace_4_idx_(0), + transfer_encoding_is_chunked_(false) {} + + explicit BalsaHeaders(size_t bufsize) + : balsa_buffer_(bufsize), + content_length_(0), + content_length_status_(BalsaHeadersEnums::NO_CONTENT_LENGTH), + parsed_response_code_(0), + firstline_buffer_base_idx_(0), + whitespace_1_idx_(0), + non_whitespace_1_idx_(0), + whitespace_2_idx_(0), + non_whitespace_2_idx_(0), + whitespace_3_idx_(0), + non_whitespace_3_idx_(0), + whitespace_4_idx_(0), + transfer_encoding_is_chunked_(false) {} + + // Copying BalsaHeaders is expensive, so require that it be visible. + BalsaHeaders(const BalsaHeaders&) = delete; + BalsaHeaders& operator=(const BalsaHeaders&) = delete; + BalsaHeaders(BalsaHeaders&&) = default; + BalsaHeaders& operator=(BalsaHeaders&&) = default; + + // Returns a range that represents all of the header lines. + iterator_range lines() const; + + // Returns an iterator range consisting of the header lines matching key. + // String backing 'key' must remain valid for lifetime of range. + iterator_range lines( + absl::string_view key) const; + + // Returns a forward-only iterator that only stops at lines matching key. + // String backing 'key' must remain valid for lifetime of iterator. + // + // Check returned iterator against header_lines_key_end() to determine when + // iteration is finished. + // + // Consider calling lines(key)--it may be more readable. + const_header_lines_key_iterator GetIteratorForKey( + absl::string_view key) const; + + const_header_lines_key_iterator header_lines_key_end() const; + + void erase(const const_header_lines_iterator& it); + + void Clear(); + + // Explicit copy functions to avoid risk of accidental copies. + BalsaHeaders Copy() const { + BalsaHeaders copy; + copy.CopyFrom(*this); + return copy; + } + void CopyFrom(const BalsaHeaders& other); + + // Replaces header entries with key 'key' if they exist, or appends + // a new header if none exist. See 'AppendHeader' below for additional + // comments about ContentLength and TransferEncoding headers. Note that this + // will allocate new storage every time that it is called. + void ReplaceOrAppendHeader(absl::string_view key, + absl::string_view value) override; + + // Append a new header entry to the header object. Clients who wish to append + // Content-Length header should use SetContentLength() method instead of + // adding the content length header using AppendHeader (manually adding the + // content length header will not update the content_length_ and + // content_length_status_ values). + // Similarly, clients who wish to add or remove the transfer encoding header + // in order to apply or remove chunked encoding should use + // SetTransferEncodingToChunkedAndClearContentLength() or + // SetNoTransferEncoding() instead. + void AppendHeader(absl::string_view key, absl::string_view value) override; + + // Appends ',value' to an existing header named 'key'. If no header with the + // correct key exists, it will call AppendHeader(key, value). Calling this + // function on a key which exists several times in the headers will produce + // unpredictable results. + void AppendToHeader(absl::string_view key, absl::string_view value) override; + + // Appends ', value' to an existing header named 'key'. If no header with the + // correct key exists, it will call AppendHeader(key, value). Calling this + // function on a key which exists several times in the headers will produce + // unpredictable results. + void AppendToHeaderWithCommaAndSpace(absl::string_view key, + absl::string_view value) override; + + // Returns the value corresponding to the given header key. Returns an empty + // string if the header key does not exist. For headers that may consist of + // multiple lines, use GetAllOfHeader() instead. + // Make the QuicheLowerCaseString overload visible, + // and only override the absl::string_view one. + using HeaderApi::GetHeader; + absl::string_view GetHeader(absl::string_view key) const override; + + // Iterates over all currently valid header lines, appending their + // values into the vector 'out', in top-to-bottom order. + // Header-lines which have been erased are not currently valid, and + // will not have their values appended. Empty values will be + // represented as empty string. If 'key' doesn't exist in the headers at + // all, out will not be changed. We do not clear the vector out + // before adding new entries. If there are header lines with matching + // key but empty value then they are also added to the vector out. + // (Basically empty values are not treated in any special manner). + // + // Example: + // Input header: + // "GET / HTTP/1.0\r\n" + // "key1: v1\r\n" + // "key1: \r\n" + // "key1:\r\n" + // "key1: v1\r\n" + // "key1:v2\r\n" + // + // vector out is initially: ["foo"] + // vector out after GetAllOfHeader("key1", &out) is: + // ["foo", "v1", "", "", "v1", "v2"] + // + // See gfe::header_properties::IsMultivaluedHeader() for which headers + // GFE treats as being multivalued. + + // Make the QuicheLowerCaseString overload visible, + // and only override the absl::string_view one. + using HeaderApi::GetAllOfHeader; + void GetAllOfHeader(absl::string_view key, + std::vector* out) const override; + + // Same as above, but iterates over all header lines including removed ones. + // Appends their values into the vector 'out' in top-to-bottom order, + // first all valid headers then all that were removed. + void GetAllOfHeaderIncludeRemoved(absl::string_view key, + std::vector* out) const; + + // Joins all values for `key` into a comma-separated string. + // Make the QuicheLowerCaseString overload visible, + // and only override the absl::string_view one. + using HeaderApi::GetAllOfHeaderAsString; + std::string GetAllOfHeaderAsString(absl::string_view key) const override; + + // Determine if a given header is present. Case-insensitive. + inline bool HasHeader(absl::string_view key) const override { + return GetConstHeaderLinesIterator(key) != header_lines_.end(); + } + + // Goes through all headers with key 'key' and checks to see if one of the + // values is 'value'. Returns true if there are headers with the desired key + // and value, false otherwise. Case-insensitive for the key; case-sensitive + // for the value. + bool HeaderHasValue(absl::string_view key, + absl::string_view value) const override { + return HeaderHasValueHelper(key, value, true); + } + // Same as above, but also case-insensitive for the value. + bool HeaderHasValueIgnoreCase(absl::string_view key, + absl::string_view value) const override { + return HeaderHasValueHelper(key, value, false); + } + + // Returns true iff any header 'key' exists with non-empty value. + bool HasNonEmptyHeader(absl::string_view key) const override; + + const_header_lines_iterator GetHeaderPosition(absl::string_view key) const; + + // Removes all headers in given set |keys| at once efficiently. Keys + // are case insensitive. + // + // Alternatives considered: + // + // 1. Use string_hash_set<>, the caller (such as ClearHopByHopHeaders) lower + // cases the keys and RemoveAllOfHeaderInList just does lookup. This according + // to microbenchmark gives the best performance because it does not require + // an extra copy of the hash table. However, it is not taken because of the + // possible risk that caller could forget to lowercase the keys. + // + // 2. Use flat_hash_set + // or string_hash_set. Both appear + // to have (much) worse performance with WithoutDupToken and LongHeader case + // in microbenchmark. + void RemoveAllOfHeaderInList(const HeaderTokenList& keys) override; + + void RemoveAllOfHeader(absl::string_view key) override; + + // Removes all headers starting with 'key' [case insensitive] + void RemoveAllHeadersWithPrefix(absl::string_view prefix) override; + + // Returns true if we have at least one header with given prefix + // [case insensitive]. Currently for test use only. + bool HasHeadersWithPrefix(absl::string_view prefix) const override; + + // Returns the key value pairs for all headers where the header key begins + // with the specified prefix. + void GetAllOfHeaderWithPrefix( + absl::string_view prefix, + std::vector>* out) + const override; + + void GetAllHeadersWithLimit( + std::vector>* out, + int limit) const override; + + // Removes all values equal to a given value from header lines with given key. + // All string operations done here are case-sensitive. + // If a header line has only values matching the given value, the entire + // line is removed. + // If the given value is found in a multi-value header line mixed with other + // values, the line is edited in-place to remove the values. + // Returns the number of occurrences of value that were removed. + // This method runs in linear time. + size_t RemoveValue(absl::string_view key, absl::string_view value); + + // Returns the upper bound on the required buffer space to fully write out + // the header object (this include the first line, all header lines, and the + // final line separator that marks the ending of the header). + size_t GetSizeForWriteBuffer() const override; + + // Indicates if to serialize headers with lower-case header keys. + enum class CaseOption { kNoModification, kLowercase, kPropercase }; + + // Indicates if to coalesce headers with multiple values to match Envoy/GFE3. + enum class CoalesceOption { kNoCoalesce, kCoalesce }; + + // The following WriteHeader* methods are template member functions that + // place one requirement on the Buffer class: it must implement a Write + // method that takes a pointer and a length. The buffer passed in is not + // required to be stretchable. For non-stretchable buffers, the user must + // call GetSizeForWriteBuffer() to find out the upper bound on the output + // buffer space required to make sure that the entire header is serialized. + // BalsaHeaders will not check that there is adequate space in the buffer + // object during the write. + + // Writes the entire header and the final line separator that marks the end + // of the HTTP header section to the buffer. After this method returns, no + // more header data should be written to the buffer. + template + void WriteHeaderAndEndingToBuffer(Buffer* buffer, CaseOption case_option, + CoalesceOption coalesce_option) const { + WriteToBuffer(buffer, case_option, coalesce_option); + WriteHeaderEndingToBuffer(buffer); + } + + template + void WriteHeaderAndEndingToBuffer(Buffer* buffer) const { + WriteHeaderAndEndingToBuffer(buffer, CaseOption::kNoModification, + CoalesceOption::kNoCoalesce); + } + + // Writes the final line separator to the buffer to terminate the HTTP header + // section. After this method returns, no more header data should be written + // to the buffer. + template + static void WriteHeaderEndingToBuffer(Buffer* buffer) { + buffer->WriteString("\r\n"); + } + + // Writes the entire header to the buffer without the line separator that + // terminates the HTTP header. This lets users append additional header lines + // using WriteHeaderLineToBuffer and then terminate the header with + // WriteHeaderEndingToBuffer as the header is serialized to the buffer, + // without having to first copy the header. + template + void WriteToBuffer(Buffer* buffer, CaseOption case_option, + CoalesceOption coalesce_option) const; + + template + void WriteToBuffer(Buffer* buffer) const { + WriteToBuffer(buffer, CaseOption::kNoModification, + CoalesceOption::kNoCoalesce); + } + + // Used by WriteToBuffer to coalesce multiple values of headers listed in + // |multivalued_headers| into a single comma-separated value. Public for test. + template + void WriteToBufferCoalescingMultivaluedHeaders( + Buffer* buffer, const MultivaluedHeadersSet& multivalued_headers, + CaseOption case_option) const; + + // Populates |multivalues| with values of |header_lines_| with keys present + // in |multivalued_headers| set. + void GetValuesOfMultivaluedHeaders( + const MultivaluedHeadersSet& multivalued_headers, + MultivaluedHeadersValuesMap* multivalues) const; + + static std::string ToPropercase(absl::string_view header) { + std::string copy = std::string(header); + bool should_uppercase = true; + for (char& c : copy) { + if (!absl::ascii_isalnum(c)) { + should_uppercase = true; + } else if (should_uppercase) { + c = absl::ascii_toupper(c); + should_uppercase = false; + } else { + c = absl::ascii_tolower(c); + } + } + return copy; + } + + template + void WriteHeaderKeyToBuffer(Buffer* buffer, absl::string_view key, + CaseOption case_option) const { + if (case_option == CaseOption::kLowercase) { + buffer->WriteString(absl::AsciiStrToLower(key)); + } else if (case_option == CaseOption::kPropercase) { + const auto& header_set = quiche::GetStandardHeaderSet(); + auto it = header_set.find(key); + if (it != header_set.end()) { + buffer->WriteString(*it); + } else { + buffer->WriteString(ToPropercase(key)); + } + } else { + buffer->WriteString(key); + } + } + + // Takes a header line in the form of a key/value pair and append it to the + // buffer. This function should be called after WriteToBuffer to + // append additional header lines to the header without copying the header. + // When the user is done with appending to the buffer, + // WriteHeaderEndingToBuffer must be used to terminate the HTTP + // header in the buffer. This method is a no-op if key is empty. + template + void WriteHeaderLineToBuffer(Buffer* buffer, absl::string_view key, + absl::string_view value, + CaseOption case_option) const { + // If the key is empty, we don't want to write the rest because it + // will not be a well-formed header line. + if (!key.empty()) { + WriteHeaderKeyToBuffer(buffer, key, case_option); + buffer->WriteString(": "); + buffer->WriteString(value); + buffer->WriteString("\r\n"); + } + } + + // Takes a header line in the form of a key and vector of values and appends + // it to the buffer. This function should be called after WriteToBuffer to + // append additional header lines to the header without copying the header. + // When the user is done with appending to the buffer, + // WriteHeaderEndingToBuffer must be used to terminate the HTTP + // header in the buffer. This method is a no-op if the |key| is empty. + template + void WriteHeaderLineValuesToBuffer( + Buffer* buffer, absl::string_view key, + const std::vector& values, + CaseOption case_option) const { + // If the key is empty, we don't want to write the rest because it + // will not be a well-formed header line. + if (!key.empty()) { + WriteHeaderKeyToBuffer(buffer, key, case_option); + buffer->WriteString(": "); + for (auto it = values.begin();;) { + buffer->WriteString(*it); + if (++it == values.end()) { + break; + } + buffer->WriteString(","); + } + buffer->WriteString("\r\n"); + } + } + + // Dump the textural representation of the header object to a string, which + // is suitable for writing out to logs. All CRLF will be printed out as \n. + // This function can be called on a header object in any state. Raw header + // data will be printed out if the header object is not completely parsed, + // e.g., when there was an error in the middle of parsing. + // The header content is appended to the string; the original content is not + // cleared. + // If used in test cases, WillNotWriteFromFramer() may be of interest. + void DumpToString(std::string* str) const; + std::string DebugString() const override; + + bool ForEachHeader(std::function + fn) const override; + + void DumpToPrefixedString(const char* spaces, std::string* str) const; + + absl::string_view first_line() const { + QUICHE_DCHECK_GE(whitespace_4_idx_, non_whitespace_1_idx_); + return whitespace_4_idx_ == non_whitespace_1_idx_ + ? "" + : absl::string_view( + BeginningOfFirstLine() + non_whitespace_1_idx_, + whitespace_4_idx_ - non_whitespace_1_idx_); + } + std::string first_line_of_request() const override { + return std::string(first_line()); + } + + // Returns the parsed value of the response code if it has been parsed. + // Guaranteed to return 0 when unparsed (though it is a much better idea to + // verify that the BalsaFrame had no errors while parsing). + // This may return response codes which are outside the normal bounds of + // HTTP response codes-- it is up to the user of this class to ensure that + // the response code is one which is interpretable. + size_t parsed_response_code() const override { return parsed_response_code_; } + + absl::string_view request_method() const override { + QUICHE_DCHECK_GE(whitespace_2_idx_, non_whitespace_1_idx_); + return whitespace_2_idx_ == non_whitespace_1_idx_ + ? "" + : absl::string_view( + BeginningOfFirstLine() + non_whitespace_1_idx_, + whitespace_2_idx_ - non_whitespace_1_idx_); + } + + absl::string_view response_version() const override { + // Note: There is no difference between request_method() and + // response_version(). They both could be called + // GetFirstTokenFromFirstline()... but that wouldn't be anywhere near as + // descriptive. + return request_method(); + } + + absl::string_view request_uri() const override { + QUICHE_DCHECK_GE(whitespace_3_idx_, non_whitespace_2_idx_); + return whitespace_3_idx_ == non_whitespace_2_idx_ + ? "" + : absl::string_view( + BeginningOfFirstLine() + non_whitespace_2_idx_, + whitespace_3_idx_ - non_whitespace_2_idx_); + } + + absl::string_view response_code() const override { + // Note: There is no difference between request_uri() and response_code(). + // They both could be called GetSecondtTokenFromFirstline(), but, as noted + // in an earlier comment, that wouldn't be as descriptive. + return request_uri(); + } + + absl::string_view request_version() const override { + QUICHE_DCHECK_GE(whitespace_4_idx_, non_whitespace_3_idx_); + return whitespace_4_idx_ == non_whitespace_3_idx_ + ? "" + : absl::string_view( + BeginningOfFirstLine() + non_whitespace_3_idx_, + whitespace_4_idx_ - non_whitespace_3_idx_); + } + + absl::string_view response_reason_phrase() const override { + // Note: There is no difference between request_version() and + // response_reason_phrase(). They both could be called + // GetThirdTokenFromFirstline(), but, as noted in an earlier comment, that + // wouldn't be as descriptive. + return request_version(); + } + + void SetRequestFirstlineFromStringPieces(absl::string_view method, + absl::string_view uri, + absl::string_view version) { + SetFirstlineFromStringPieces(method, uri, version); + } + + void SetResponseFirstline(absl::string_view version, + size_t parsed_response_code, + absl::string_view reason_phrase); + + // These functions are exactly the same, except that their names are + // different. This is done so that the code using this class is more + // expressive. + void SetRequestMethod(absl::string_view method) override; + void SetResponseVersion(absl::string_view version) override; + + void SetRequestUri(absl::string_view uri) override; + void SetResponseCode(absl::string_view code) override; + void set_parsed_response_code(size_t parsed_response_code) { + parsed_response_code_ = parsed_response_code; + } + void SetParsedResponseCodeAndUpdateFirstline( + size_t parsed_response_code) override; + + // These functions are exactly the same, except that their names are + // different. This is done so that the code using this class is more + // expressive. + void SetRequestVersion(absl::string_view version) override; + void SetResponseReasonPhrase(absl::string_view reason_phrase) override; + + // Simple accessors to some of the internal state + bool transfer_encoding_is_chunked() const { + return transfer_encoding_is_chunked_; + } + + static bool ResponseCodeImpliesNoBody(size_t code) { + // From HTTP spec section 6.1.1 all 1xx responses must not have a body, + // as well as 204 No Content and 304 Not Modified. + return ((code >= 100) && (code <= 199)) || (code == 204) || (code == 304); + } + + // Note: never check this for requests. Nothing bad will happen if you do, + // but spec does not allow requests framed by connection close. + // TODO(vitaliyl): refactor. + bool is_framed_by_connection_close() const { + // We declare that response is framed by connection close if it has no + // content-length, no transfer encoding, and is allowed to have a body by + // the HTTP spec. + // parsed_response_code_ is 0 for requests, so ResponseCodeImpliesNoBody + // will return false. + return (content_length_status_ == BalsaHeadersEnums::NO_CONTENT_LENGTH) && + !transfer_encoding_is_chunked_ && + !ResponseCodeImpliesNoBody(parsed_response_code_); + } + + size_t content_length() const override { return content_length_; } + BalsaHeadersEnums::ContentLengthStatus content_length_status() const { + return content_length_status_; + } + bool content_length_valid() const override { + return content_length_status_ == BalsaHeadersEnums::VALID_CONTENT_LENGTH; + } + + // SetContentLength, SetTransferEncodingToChunkedAndClearContentLength, and + // SetNoTransferEncoding modifies the header object to use + // content-length and transfer-encoding headers in a consistent + // manner. They set all internal flags and status so client can get + // a consistent view from various accessors. + void SetContentLength(size_t length) override; + // Sets transfer-encoding to chunked and updates internal state. + void SetTransferEncodingToChunkedAndClearContentLength() override; + // Removes transfer-encoding headers and updates internal state. + void SetNoTransferEncoding() override; + + // If you have a response that needs framing by connection close, use this + // method instead of RemoveAllOfHeader("Content-Length"). Has no effect if + // transfer_encoding_is_chunked(). + void ClearContentLength(); + + // This should be called if balsa headers are created entirely manually (not + // by any of the framer classes) to make sure that function calls like + // DumpToString will work correctly. + void WillNotWriteFromFramer() { + balsa_buffer_.NoMoreWriteToContiguousBuffer(); + } + + // True if DoneWritingFromFramer or WillNotWriteFromFramer is called. + bool FramerIsDoneWriting() const { + return !balsa_buffer_.can_write_to_contiguous_buffer(); + } + + bool IsEmpty() const override; + + // From HeaderApi and ConstHeaderApi. + absl::string_view Authority() const override; + void ReplaceOrAppendAuthority(absl::string_view value) override; + void RemoveAuthority() override; + void ApplyToCookie( + std::function f) const override; + + void set_enforce_header_policy(bool enforce) override { + enforce_header_policy_ = enforce; + } + + // Removes the last token from the header value. In the presence of multiple + // header lines with given key, will remove the last token of the last line. + // Can be useful if the last encoding has to be removed. + void RemoveLastTokenFromHeaderValue(absl::string_view key); + + // Gets the list of names of headers that are multivalued in Envoy. + static const MultivaluedHeadersSet& multivalued_envoy_headers(); + + // Returns true if HTTP responses with this response code have bodies. + static bool ResponseCanHaveBody(int response_code); + + // Given a pointer to the beginning and the end of the header value + // in some buffer, populates tokens list with beginning and end indices + // of all tokens present in the value string. + static void ParseTokenList(absl::string_view header_value, + HeaderTokenList* tokens); + + private: + typedef std::vector HeaderLines; + + class iterator_base; + + friend class BalsaFrame; + friend class gfe2::Http2HeaderValidator; + friend class SpdyPayloadFramer; + friend class HTTPMessage; + friend class test::BalsaHeadersTestPeer; + + friend bool ParseHTTPFirstLine(const char* begin, const char* end, + bool is_request, BalsaHeaders* headers, + BalsaFrameEnums::ErrorCode* error_code); + + // Reverse iterators have been removed for lack of use, refer to + // cl/30618773 in case they are needed. + + const char* BeginningOfFirstLine() const { + return GetPtr(firstline_buffer_base_idx_); + } + + char* BeginningOfFirstLine() { return GetPtr(firstline_buffer_base_idx_); } + + char* GetPtr(BalsaBuffer::Blocks::size_type block_idx) { + return balsa_buffer_.GetPtr(block_idx); + } + + const char* GetPtr(BalsaBuffer::Blocks::size_type block_idx) const { + return balsa_buffer_.GetPtr(block_idx); + } + + void WriteFromFramer(const char* ptr, size_t size) { + balsa_buffer_.WriteToContiguousBuffer(absl::string_view(ptr, size)); + } + + void DoneWritingFromFramer() { + balsa_buffer_.NoMoreWriteToContiguousBuffer(); + } + + const char* OriginalHeaderStreamBegin() const { + return balsa_buffer_.StartOfFirstBlock(); + } + + const char* OriginalHeaderStreamEnd() const { + return balsa_buffer_.EndOfFirstBlock(); + } + + size_t GetReadableBytesFromHeaderStream() const { + return balsa_buffer_.GetReadableBytesOfFirstBlock(); + } + + absl::string_view GetReadablePtrFromHeaderStream() { + return {OriginalHeaderStreamBegin(), GetReadableBytesFromHeaderStream()}; + } + + absl::string_view GetValueFromHeaderLineDescription( + const HeaderLineDescription& line) const; + + void AddAndMakeDescription(absl::string_view key, absl::string_view value, + HeaderLineDescription* d); + + void AppendAndMakeDescription(absl::string_view key, absl::string_view value, + HeaderLineDescription* d); + + // Removes all header lines with the given key starting at start. + void RemoveAllOfHeaderStartingAt(absl::string_view key, + HeaderLines::iterator start); + + HeaderLines::const_iterator GetConstHeaderLinesIterator( + absl::string_view key) const; + + HeaderLines::iterator GetHeaderLinesIterator(absl::string_view key, + HeaderLines::iterator start); + + HeaderLines::iterator GetHeaderLinesIteratorForLastMultivaluedHeader( + absl::string_view key); + + template + const IteratorType HeaderLinesBeginHelper() const; + + template + const IteratorType HeaderLinesEndHelper() const; + + // Helper function for HeaderHasValue and HeaderHasValueIgnoreCase that + // does most of the work. + bool HeaderHasValueHelper(absl::string_view key, absl::string_view value, + bool case_sensitive) const; + + // Called by header removal methods to reset internal values for transfer + // encoding or content length if we're removing the corresponding headers. + void MaybeClearSpecialHeaderValues(absl::string_view key); + + void SetFirstlineFromStringPieces(absl::string_view firstline_a, + absl::string_view firstline_b, + absl::string_view firstline_c); + BalsaBuffer balsa_buffer_; + + size_t content_length_; + BalsaHeadersEnums::ContentLengthStatus content_length_status_; + size_t parsed_response_code_; + // HTTP firstlines all have the following structure: + // LWS NONWS LWS NONWS LWS NONWS NOTCRLF CRLF + // [\t \r\n]+ [^\t ]+ [\t ]+ [^\t ]+ [\t ]+ [^\t ]+ [^\r\n]+ "\r\n" + // ws1 nws1 ws2 nws2 ws3 nws3 ws4 + // | [-------) [-------) [----------------) + // REQ: method request_uri version + // RESP: version statuscode reason + // + // The first NONWS->LWS component we'll call firstline_a. + // The second firstline_b, and the third firstline_c. + // + // firstline_a goes from nws1 to (but not including) ws2 + // firstline_b goes from nws2 to (but not including) ws3 + // firstline_c goes from nws3 to (but not including) ws4 + // + // In the code: + // ws1 == whitespace_1_idx_ + // nws1 == non_whitespace_1_idx_ + // ws2 == whitespace_2_idx_ + // nws2 == non_whitespace_2_idx_ + // ws3 == whitespace_3_idx_ + // nws3 == non_whitespace_3_idx_ + // ws4 == whitespace_4_idx_ + BalsaBuffer::Blocks::size_type firstline_buffer_base_idx_; + size_t whitespace_1_idx_; + size_t non_whitespace_1_idx_; + size_t whitespace_2_idx_; + size_t non_whitespace_2_idx_; + size_t whitespace_3_idx_; + size_t non_whitespace_3_idx_; + size_t whitespace_4_idx_; + + bool transfer_encoding_is_chunked_; + + // If true, QUICHE_BUG if a header that starts with an invalid prefix is + // explicitly set. + bool enforce_header_policy_ = true; + + HeaderLines header_lines_; +}; + +// Base class for iterating the headers in a BalsaHeaders object, returning a +// pair of string_view's for each header. +class QUICHE_EXPORT BalsaHeaders::iterator_base + : public std::iterator> { + public: + iterator_base() : headers_(nullptr), idx_(0) {} + + std::pair& operator*() const { + return Lookup(idx_); + } + + std::pair* operator->() const { + return &(this->operator*()); + } + + bool operator==(const BalsaHeaders::iterator_base& it) const { + return idx_ == it.idx_; + } + + bool operator<(const BalsaHeaders::iterator_base& it) const { + return idx_ < it.idx_; + } + + bool operator<=(const BalsaHeaders::iterator_base& it) const { + return idx_ <= it.idx_; + } + + bool operator!=(const BalsaHeaders::iterator_base& it) const { + return !(*this == it); + } + + bool operator>(const BalsaHeaders::iterator_base& it) const { + return it < *this; + } + + bool operator>=(const BalsaHeaders::iterator_base& it) const { + return it <= *this; + } + + // This mainly exists so that we can have interesting output for + // unittesting. The EXPECT_EQ, EXPECT_NE functions require that + // operator<< work for the classes it sees. It would be better if there + // was an additional traits-like system for the gUnit output... but oh + // well. + friend QUICHE_EXPORT std::ostream& operator<<(std::ostream& os, + const iterator_base& it) { + os << "[" << it.headers_ << ", " << it.idx_ << "]"; + return os; + } + + private: + friend class BalsaHeaders; + + iterator_base(const BalsaHeaders* headers, HeaderLines::size_type index) + : headers_(headers), idx_(index) {} + + void increment() { + value_.reset(); + const HeaderLines& header_lines = headers_->header_lines_; + const HeaderLines::size_type header_lines_size = header_lines.size(); + const HeaderLines::size_type original_idx = idx_; + do { + ++idx_; + } while (idx_ < header_lines_size && header_lines[idx_].skip == true); + // The condition below exists so that ++(end() - 1) == end(), even + // if there are only 'skip == true' elements between the end() iterator + // and the end of the vector of HeaderLineDescriptions. + if (idx_ == header_lines_size) { + idx_ = original_idx + 1; + } + } + + std::pair& Lookup( + HeaderLines::size_type index) const { + QUICHE_DCHECK_LT(index, headers_->header_lines_.size()); + if (!value_.has_value()) { + const HeaderLineDescription& line = headers_->header_lines_[index]; + const char* stream_begin = headers_->GetPtr(line.buffer_base_idx); + value_ = + std::make_pair(absl::string_view(stream_begin + line.first_char_idx, + line.KeyLength()), + absl::string_view(stream_begin + line.value_begin_idx, + line.ValuesLength())); + } + return value_.value(); + } + + const BalsaHeaders* headers_; + HeaderLines::size_type idx_; + mutable absl::optional> + value_; +}; + +// A const iterator for all the header lines. +class QUICHE_EXPORT BalsaHeaders::const_header_lines_iterator + : public BalsaHeaders::iterator_base { + public: + const_header_lines_iterator() : iterator_base() {} + + const_header_lines_iterator& operator++() { + iterator_base::increment(); + return *this; + } + + private: + friend class BalsaHeaders; + + const_header_lines_iterator(const BalsaHeaders* headers, + HeaderLines::size_type index) + : iterator_base(headers, index) {} +}; + +// A const iterator that stops only on header lines for a particular key. +class QUICHE_EXPORT BalsaHeaders::const_header_lines_key_iterator + : public BalsaHeaders::iterator_base { + public: + const_header_lines_key_iterator& operator++() { + do { + iterator_base::increment(); + } while (!AtEnd() && !absl::EqualsIgnoreCase(key_, (**this).first)); + return *this; + } + + // Only forward-iteration makes sense, so no operator-- defined. + + private: + friend class BalsaHeaders; + + const_header_lines_key_iterator(const BalsaHeaders* headers, + HeaderLines::size_type index, + absl::string_view key) + : iterator_base(headers, index), key_(key) {} + + // Should only be used for creating an end iterator. + const_header_lines_key_iterator(const BalsaHeaders* headers, + HeaderLines::size_type index) + : iterator_base(headers, index) {} + + bool AtEnd() const { return *this >= headers_->lines().end(); } + + absl::string_view key_; +}; + +inline BalsaHeaders::iterator_range +BalsaHeaders::lines() const { + return {HeaderLinesBeginHelper(), + HeaderLinesEndHelper()}; +} + +inline BalsaHeaders::iterator_range< + BalsaHeaders::const_header_lines_key_iterator> +BalsaHeaders::lines(absl::string_view key) const { + return {GetIteratorForKey(key), header_lines_key_end()}; +} + +inline BalsaHeaders::const_header_lines_key_iterator +BalsaHeaders::header_lines_key_end() const { + return HeaderLinesEndHelper(); +} + +inline void BalsaHeaders::erase(const const_header_lines_iterator& it) { + QUICHE_DCHECK_EQ(it.headers_, this); + QUICHE_DCHECK_LT(it.idx_, header_lines_.size()); + header_lines_[it.idx_].skip = true; +} + +template +void BalsaHeaders::WriteToBuffer(Buffer* buffer, CaseOption case_option, + CoalesceOption coalesce_option) const { + // write the first line. + const absl::string_view firstline = first_line(); + if (!firstline.empty()) { + buffer->WriteString(firstline); + } + buffer->WriteString("\r\n"); + if (coalesce_option != CoalesceOption::kCoalesce) { + const HeaderLines::size_type end = header_lines_.size(); + for (HeaderLines::size_type i = 0; i < end; ++i) { + const HeaderLineDescription& line = header_lines_[i]; + if (line.skip) { + continue; + } + const char* line_ptr = GetPtr(line.buffer_base_idx); + WriteHeaderLineToBuffer( + buffer, + absl::string_view(line_ptr + line.first_char_idx, line.KeyLength()), + absl::string_view(line_ptr + line.value_begin_idx, + line.ValuesLength()), + case_option); + } + } else { + WriteToBufferCoalescingMultivaluedHeaders( + buffer, multivalued_envoy_headers(), case_option); + } +} + +inline void BalsaHeaders::GetValuesOfMultivaluedHeaders( + const MultivaluedHeadersSet& multivalued_headers, + MultivaluedHeadersValuesMap* multivalues) const { + multivalues->reserve(header_lines_.capacity()); + + // Find lines that need to be coalesced and store them in |multivalues|. + for (const auto& line : header_lines_) { + if (line.skip) { + continue; + } + const char* line_ptr = GetPtr(line.buffer_base_idx); + absl::string_view header_key = + absl::string_view(line_ptr + line.first_char_idx, line.KeyLength()); + // If this is multivalued header, it may need to be coalesced. + if (multivalued_headers.contains(header_key)) { + absl::string_view header_value = absl::string_view( + line_ptr + line.value_begin_idx, line.ValuesLength()); + // Add |header_value| to the vector of values for this |header_key|, + // therefore preserving the order of values for the same key. + (*multivalues)[header_key].push_back(header_value); + } + } +} + +template +void BalsaHeaders::WriteToBufferCoalescingMultivaluedHeaders( + Buffer* buffer, const MultivaluedHeadersSet& multivalued_headers, + CaseOption case_option) const { + MultivaluedHeadersValuesMap multivalues; + GetValuesOfMultivaluedHeaders(multivalued_headers, &multivalues); + + // Write out header lines while coalescing those that need to be coalesced. + for (const auto& line : header_lines_) { + if (line.skip) { + continue; + } + const char* line_ptr = GetPtr(line.buffer_base_idx); + absl::string_view header_key = + absl::string_view(line_ptr + line.first_char_idx, line.KeyLength()); + auto header_multivalue = multivalues.find(header_key); + // If current line doesn't need to be coalesced (as it is either not + // multivalue, or has just a single value so it equals to current line), + // then just write it out. + if (header_multivalue == multivalues.end() || + header_multivalue->second.size() == 1) { + WriteHeaderLineToBuffer(buffer, header_key, + absl::string_view(line_ptr + line.value_begin_idx, + line.ValuesLength()), + case_option); + } else { + // If this line needs to be coalesced, then write all its values and clear + // them, so the subsequent same header keys will not be written. + if (!header_multivalue->second.empty()) { + WriteHeaderLineValuesToBuffer(buffer, header_key, + header_multivalue->second, case_option); + // Clear the multivalue list as it is already written out, so subsequent + // same header keys will not be written. + header_multivalue->second.clear(); + } + } + } +} + +template +const IteratorType BalsaHeaders::HeaderLinesBeginHelper() const { + if (header_lines_.empty()) { + return IteratorType(this, 0); + } + const HeaderLines::size_type header_lines_size = header_lines_.size(); + for (HeaderLines::size_type i = 0; i < header_lines_size; ++i) { + if (header_lines_[i].skip == false) { + return IteratorType(this, i); + } + } + return IteratorType(this, 0); +} + +template +const IteratorType BalsaHeaders::HeaderLinesEndHelper() const { + if (header_lines_.empty()) { + return IteratorType(this, 0); + } + const HeaderLines::size_type header_lines_size = header_lines_.size(); + HeaderLines::size_type i = header_lines_size; + do { + --i; + if (header_lines_[i].skip == false) { + return IteratorType(this, i + 1); + } + } while (i != 0); + return IteratorType(this, 0); +} + +} // namespace quiche + +#endif // QUICHE_BALSA_BALSA_HEADERS_H_ diff --git a/quiche/balsa/balsa_headers_test.cc b/quiche/balsa/balsa_headers_test.cc new file mode 100644 index 000000000000..d132a5182be0 --- /dev/null +++ b/quiche/balsa/balsa_headers_test.cc @@ -0,0 +1,3722 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Note that several of the BalsaHeaders functions are +// tested in the balsa_frame_test as the BalsaFrame and +// BalsaHeaders classes are fairly related. + +#include "quiche/balsa/balsa_headers.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/balsa_frame.h" +#include "quiche/balsa/simple_buffer.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using absl::make_unique; +using testing::AnyOf; +using testing::Combine; +using testing::ElementsAre; +using testing::Eq; +using testing::StrEq; +using testing::ValuesIn; + +namespace quiche { + +namespace test { + +class BalsaHeadersTestPeer { + public: + static void WriteFromFramer(BalsaHeaders* headers, const char* ptr, + size_t size) { + headers->WriteFromFramer(ptr, size); + } +}; + +namespace { + +class BalsaBufferTest : public QuicheTest { + public: + void CreateBuffer(size_t blocksize) { + buffer_ = std::make_unique(blocksize); + } + void CreateBuffer() { buffer_ = std::make_unique(); } + static std::unique_ptr CreateUnmanagedBuffer(size_t blocksize) { + return std::make_unique(blocksize); + } + absl::string_view Write(absl::string_view sp, size_t* block_buffer_idx) { + if (sp.empty()) { + return sp; + } + char* storage = buffer_->Reserve(sp.size(), block_buffer_idx); + memcpy(storage, sp.data(), sp.size()); + return absl::string_view(storage, sp.size()); + } + + protected: + std::unique_ptr buffer_; +}; + +using BufferBlock = BalsaBuffer::BufferBlock; + +BufferBlock MakeBufferBlock(const std::string& s) { + // Make the buffer twice the size needed to verify that CopyFrom copies our + // buffer_size (as opposed to shrinking to fit or reusing an old buffer). + BufferBlock block{make_unique(s.size()), s.size() * 2, s.size()}; + std::memcpy(block.buffer.get(), s.data(), s.size()); + return block; +} + +BalsaHeaders CreateHTTPHeaders(bool request, absl::string_view s) { + BalsaHeaders headers; + BalsaFrame framer; + framer.set_is_request(request); + framer.set_balsa_headers(&headers); + QUICHE_CHECK_EQ(s.size(), framer.ProcessInput(s.data(), s.size())); + QUICHE_CHECK(framer.MessageFullyRead()); + return headers; +} + +class BufferBlockTest + : public QuicheTestWithParam> {}; + +TEST_P(BufferBlockTest, CopyFrom) { + const std::string s1 = std::get<0>(GetParam()); + const std::string s2 = std::get<1>(GetParam()); + BufferBlock block; + block.CopyFrom(MakeBufferBlock(s1)); + EXPECT_EQ(s1.size(), block.bytes_free); + ASSERT_EQ(2 * s1.size(), block.buffer_size); + EXPECT_EQ(0, memcmp(s1.data(), block.buffer.get(), s1.size())); + block.CopyFrom(MakeBufferBlock(s2)); + EXPECT_EQ(s2.size(), block.bytes_free); + ASSERT_EQ(2 * s2.size(), block.buffer_size); + EXPECT_EQ(0, memcmp(s2.data(), block.buffer.get(), s2.size())); +} + +const char* block_strings[] = {"short string", "longer than the other string"}; +INSTANTIATE_TEST_SUITE_P(VariousSizes, BufferBlockTest, + Combine(ValuesIn(block_strings), + ValuesIn(block_strings))); + +TEST_F(BalsaBufferTest, BlocksizeSet) { + CreateBuffer(); + EXPECT_EQ(BalsaBuffer::kDefaultBlocksize, buffer_->blocksize()); + CreateBuffer(1024); + EXPECT_EQ(1024u, buffer_->blocksize()); +} + +TEST_F(BalsaBufferTest, GetMemorySize) { + CreateBuffer(10); + EXPECT_EQ(0u, buffer_->GetTotalBytesUsed()); + EXPECT_EQ(0u, buffer_->GetTotalBufferBlockSize()); + BalsaBuffer::Blocks::size_type index; + buffer_->Reserve(1024, &index); + EXPECT_EQ(10u + 1024u, buffer_->GetTotalBufferBlockSize()); + EXPECT_EQ(1024u, buffer_->GetTotalBytesUsed()); +} + +TEST_F(BalsaBufferTest, ManyWritesToContiguousBuffer) { + CreateBuffer(0); + // The test is that the process completes. If it needs to do a resize on + // every write, it will timeout or run out of memory. + // ( 10 + 20 + 30 + ... + 1.2e6 bytes => ~1e11 bytes ) + std::string data = "0123456789"; + for (int i = 0; i < 120 * 1000; ++i) { + buffer_->WriteToContiguousBuffer(data); + } +} + +TEST_F(BalsaBufferTest, CopyFrom) { + CreateBuffer(10); + std::unique_ptr ptr = CreateUnmanagedBuffer(1024); + ASSERT_EQ(1024u, ptr->blocksize()); + EXPECT_EQ(0u, ptr->num_blocks()); + + std::string data1 = "foobarbaz01"; + buffer_->WriteToContiguousBuffer(data1); + buffer_->NoMoreWriteToContiguousBuffer(); + std::string data2 = "12345"; + Write(data2, nullptr); + std::string data3 = "6789"; + Write(data3, nullptr); + std::string data4 = "123456789012345"; + Write(data3, nullptr); + + ptr->CopyFrom(*buffer_); + + EXPECT_EQ(ptr->can_write_to_contiguous_buffer(), + buffer_->can_write_to_contiguous_buffer()); + ASSERT_EQ(ptr->num_blocks(), buffer_->num_blocks()); + for (size_t i = 0; i < buffer_->num_blocks(); ++i) { + ASSERT_EQ(ptr->bytes_used(i), buffer_->bytes_used(i)); + ASSERT_EQ(ptr->buffer_size(i), buffer_->buffer_size(i)); + EXPECT_EQ(0, + memcmp(ptr->GetPtr(i), buffer_->GetPtr(i), ptr->bytes_used(i))); + } +} + +TEST_F(BalsaBufferTest, ClearWorks) { + CreateBuffer(10); + + std::string data1 = "foobarbaz01"; + buffer_->WriteToContiguousBuffer(data1); + buffer_->NoMoreWriteToContiguousBuffer(); + std::string data2 = "12345"; + Write(data2, nullptr); + std::string data3 = "6789"; + Write(data3, nullptr); + std::string data4 = "123456789012345"; + Write(data3, nullptr); + + buffer_->Clear(); + + EXPECT_TRUE(buffer_->can_write_to_contiguous_buffer()); + EXPECT_EQ(10u, buffer_->blocksize()); + EXPECT_EQ(0u, buffer_->num_blocks()); +} + +TEST_F(BalsaBufferTest, ClearWorksWhenLargerThanBlocksize) { + CreateBuffer(10); + + std::string data1 = "foobarbaz01lkjasdlkjasdlkjasd"; + buffer_->WriteToContiguousBuffer(data1); + buffer_->NoMoreWriteToContiguousBuffer(); + std::string data2 = "12345"; + Write(data2, nullptr); + std::string data3 = "6789"; + Write(data3, nullptr); + std::string data4 = "123456789012345"; + Write(data3, nullptr); + + buffer_->Clear(); + + EXPECT_TRUE(buffer_->can_write_to_contiguous_buffer()); + EXPECT_EQ(10u, buffer_->blocksize()); + EXPECT_EQ(0u, buffer_->num_blocks()); +} + +TEST_F(BalsaBufferTest, ContiguousWriteSmallerThanBlocksize) { + CreateBuffer(1024); + + std::string data1 = "foo"; + buffer_->WriteToContiguousBuffer(data1); + std::string composite = data1; + const char* buf_ptr = buffer_->GetPtr(0); + ASSERT_LE(composite.size(), buffer_->buffer_size(0)); + EXPECT_EQ(0, memcmp(composite.data(), buf_ptr, composite.size())); + + std::string data2 = "barbaz"; + buffer_->WriteToContiguousBuffer(data2); + composite += data2; + buf_ptr = buffer_->GetPtr(0); + ASSERT_LE(composite.size(), buffer_->buffer_size(0)); + EXPECT_EQ(0, memcmp(composite.data(), buf_ptr, composite.size())); +} + +TEST_F(BalsaBufferTest, SingleContiguousWriteLargerThanBlocksize) { + CreateBuffer(10); + + std::string data1 = "abracadabrawords"; + buffer_->WriteToContiguousBuffer(data1); + std::string composite = data1; + const char* buf_ptr = buffer_->GetPtr(0); + ASSERT_LE(data1.size(), buffer_->buffer_size(0)); + EXPECT_EQ(0, memcmp(composite.data(), buf_ptr, composite.size())) + << composite << "\n" + << absl::string_view(buf_ptr, buffer_->bytes_used(0)); +} + +TEST_F(BalsaBufferTest, ContiguousWriteLargerThanBlocksize) { + CreateBuffer(10); + + std::string data1 = "123456789"; + buffer_->WriteToContiguousBuffer(data1); + std::string composite = data1; + ASSERT_LE(10u, buffer_->buffer_size(0)); + + std::string data2 = "0123456789"; + buffer_->WriteToContiguousBuffer(data2); + composite += data2; + + const char* buf_ptr = buffer_->GetPtr(0); + ASSERT_LE(composite.size(), buffer_->buffer_size(0)); + EXPECT_EQ(0, memcmp(composite.data(), buf_ptr, composite.size())) + << "composite: " << composite << "\n" + << " actual: " << absl::string_view(buf_ptr, buffer_->bytes_used(0)); +} + +TEST_F(BalsaBufferTest, TwoContiguousWritesLargerThanBlocksize) { + CreateBuffer(5); + + std::string data1 = "123456"; + buffer_->WriteToContiguousBuffer(data1); + std::string composite = data1; + ASSERT_LE(composite.size(), buffer_->buffer_size(0)); + + std::string data2 = "7890123"; + buffer_->WriteToContiguousBuffer(data2); + composite += data2; + + const char* buf_ptr = buffer_->GetPtr(0); + ASSERT_LE(composite.size(), buffer_->buffer_size(0)); + EXPECT_EQ(0, memcmp(composite.data(), buf_ptr, composite.size())) + << "composite: " << composite << "\n" + << " actual: " << absl::string_view(buf_ptr, buffer_->bytes_used(0)); +} + +TEST_F(BalsaBufferTest, WriteSmallerThanBlocksize) { + CreateBuffer(5); + std::string data1 = "1234"; + size_t block_idx = 0; + absl::string_view write_result = Write(data1, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data1)); + + CreateBuffer(5); + data1 = "1234"; + block_idx = 0; + write_result = Write(data1, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data1)); +} + +TEST_F(BalsaBufferTest, TwoWritesSmallerThanBlocksizeThenAnotherWrite) { + CreateBuffer(10); + std::string data1 = "12345"; + size_t block_idx = 0; + absl::string_view write_result = Write(data1, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data1)); + + std::string data2 = "data2"; + block_idx = 0; + write_result = Write(data2, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data2)); + + std::string data3 = "data3"; + block_idx = 0; + write_result = Write(data3, &block_idx); + ASSERT_EQ(2u, block_idx); + EXPECT_THAT(write_result, StrEq(data3)); + + CreateBuffer(10); + buffer_->NoMoreWriteToContiguousBuffer(); + data1 = "12345"; + block_idx = 0; + write_result = Write(data1, &block_idx); + ASSERT_EQ(0u, block_idx); + EXPECT_THAT(write_result, StrEq(data1)); + + data2 = "data2"; + block_idx = 0; + write_result = Write(data2, &block_idx); + ASSERT_EQ(0u, block_idx); + EXPECT_THAT(write_result, StrEq(data2)); + + data3 = "data3"; + block_idx = 0; + write_result = Write(data3, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data3)); +} + +TEST_F(BalsaBufferTest, WriteLargerThanBlocksize) { + CreateBuffer(5); + std::string data1 = "123456789"; + size_t block_idx = 0; + absl::string_view write_result = Write(data1, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data1)); + + CreateBuffer(5); + buffer_->NoMoreWriteToContiguousBuffer(); + data1 = "123456789"; + block_idx = 0; + write_result = Write(data1, &block_idx); + ASSERT_EQ(1u, block_idx); + EXPECT_THAT(write_result, StrEq(data1)); +} + +TEST_F(BalsaBufferTest, ContiguousThenTwoSmallerThanBlocksize) { + CreateBuffer(5); + std::string data1 = "1234567890"; + buffer_->WriteToContiguousBuffer(data1); + size_t block_idx = 0; + std::string data2 = "1234"; + absl::string_view write_result = Write(data2, &block_idx); + ASSERT_EQ(1u, block_idx); + std::string data3 = "1234"; + write_result = Write(data3, &block_idx); + ASSERT_EQ(2u, block_idx); +} + +TEST_F(BalsaBufferTest, AccessFirstBlockUninitialized) { + CreateBuffer(5); + EXPECT_EQ(0u, buffer_->GetReadableBytesOfFirstBlock()); + EXPECT_QUICHE_BUG(buffer_->StartOfFirstBlock(), + "First block not allocated yet!"); + EXPECT_QUICHE_BUG(buffer_->EndOfFirstBlock(), + "First block not allocated yet!"); +} + +TEST_F(BalsaBufferTest, AccessFirstBlockInitialized) { + CreateBuffer(5); + std::string data1 = "1234567890"; + buffer_->WriteToContiguousBuffer(data1); + const char* start = buffer_->StartOfFirstBlock(); + EXPECT_TRUE(start != nullptr); + const char* end = buffer_->EndOfFirstBlock(); + EXPECT_TRUE(end != nullptr); + EXPECT_EQ(data1.length(), static_cast(end - start)); + EXPECT_EQ(data1.length(), buffer_->GetReadableBytesOfFirstBlock()); +} + +TEST(BalsaHeaders, CanAssignBeginToIterator) { + { + BalsaHeaders header; + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + static_cast(chli); + } + { + const BalsaHeaders header; + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + static_cast(chli); + } +} + +TEST(BalsaHeaders, CanAssignEndToIterator) { + { + BalsaHeaders header; + BalsaHeaders::const_header_lines_iterator chli = header.lines().end(); + static_cast(chli); + } + { + const BalsaHeaders header; + BalsaHeaders::const_header_lines_iterator chli = header.lines().end(); + static_cast(chli); + } +} + +TEST(BalsaHeaders, ReplaceOrAppendHeaderTestAppending) { + BalsaHeaders header; + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.ReplaceOrAppendHeader(key_1, value_1); + + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ASSERT_EQ(absl::string_view("key_1"), chli->first); + ASSERT_EQ(absl::string_view("value_1"), chli->second); + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, ReplaceOrAppendHeaderTestReplacing) { + BalsaHeaders header; + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + std::string key_2 = "key_2"; + header.ReplaceOrAppendHeader(key_1, value_1); + header.ReplaceOrAppendHeader(key_2, value_1); + std::string value_2 = "value_2_string"; + header.ReplaceOrAppendHeader(key_1, value_2); + + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ASSERT_EQ(key_1, chli->first); + ASSERT_EQ(value_2, chli->second); + ++chli; + ASSERT_EQ(key_2, chli->first); + ASSERT_EQ(value_1, chli->second); + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, ReplaceOrAppendHeaderTestReplacingMultiple) { + BalsaHeaders header; + std::string key_1 = "key_1"; + std::string key_2 = "key_2"; + std::string value_1 = "val_1"; + std::string value_2 = "val_2"; + std::string value_3 = + "value_3_is_longer_than_value_1_and_value_2_and_their_keys"; + // Set up header keys 1, 1, 2. We will replace the value of key 1 with a long + // enough string that it should be moved to the end. This regression tests + // that replacement works if we move the header to the end. + header.AppendHeader(key_1, value_1); + header.AppendHeader(key_1, value_2); + header.AppendHeader(key_2, value_1); + header.ReplaceOrAppendHeader(key_1, value_3); + + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ASSERT_EQ(key_1, chli->first); + ASSERT_EQ(value_3, chli->second); + ++chli; + ASSERT_EQ(key_2, chli->first); + ASSERT_EQ(value_1, chli->second); + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); + + // Now test that replacement works with a shorter value, so that if we ever do + // in-place replacement it's tested. + header.ReplaceOrAppendHeader(key_1, value_1); + chli = header.lines().begin(); + ASSERT_EQ(key_1, chli->first); + ASSERT_EQ(value_1, chli->second); + ++chli; + ASSERT_EQ(key_2, chli->first); + ASSERT_EQ(value_1, chli->second); + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, AppendHeaderAndIteratorTest1) { + BalsaHeaders header; + ASSERT_EQ(header.lines().begin(), header.lines().end()); + { + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.AppendHeader(key_1, value_1); + key_1 = "garbage"; + value_1 = "garbage"; + } + + ASSERT_NE(header.lines().begin(), header.lines().end()); + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ASSERT_EQ(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_1"), chli->first); + ASSERT_EQ(absl::string_view("value_1"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, AppendHeaderAndIteratorTest2) { + BalsaHeaders header; + ASSERT_EQ(header.lines().begin(), header.lines().end()); + { + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.AppendHeader(key_1, value_1); + key_1 = "garbage"; + value_1 = "garbage"; + } + { + std::string key_2 = "key_2"; + std::string value_2 = "value_2"; + header.AppendHeader(key_2, value_2); + key_2 = "garbage"; + value_2 = "garbage"; + } + + ASSERT_NE(header.lines().begin(), header.lines().end()); + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ASSERT_EQ(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_1"), chli->first); + ASSERT_EQ(absl::string_view("value_1"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_2"), chli->first); + ASSERT_EQ(absl::string_view("value_2"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, AppendHeaderAndIteratorTest3) { + BalsaHeaders header; + ASSERT_EQ(header.lines().begin(), header.lines().end()); + { + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.AppendHeader(key_1, value_1); + key_1 = "garbage"; + value_1 = "garbage"; + } + { + std::string key_2 = "key_2"; + std::string value_2 = "value_2"; + header.AppendHeader(key_2, value_2); + key_2 = "garbage"; + value_2 = "garbage"; + } + { + std::string key_3 = "key_3"; + std::string value_3 = "value_3"; + header.AppendHeader(key_3, value_3); + key_3 = "garbage"; + value_3 = "garbage"; + } + + ASSERT_NE(header.lines().begin(), header.lines().end()); + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ASSERT_EQ(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_1"), chli->first); + ASSERT_EQ(absl::string_view("value_1"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_2"), chli->first); + ASSERT_EQ(absl::string_view("value_2"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_3"), chli->first); + ASSERT_EQ(absl::string_view("value_3"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, AppendHeaderAndTestEraseWithIterator) { + BalsaHeaders header; + ASSERT_EQ(header.lines().begin(), header.lines().end()); + { + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.AppendHeader(key_1, value_1); + key_1 = "garbage"; + value_1 = "garbage"; + } + { + std::string key_2 = "key_2"; + std::string value_2 = "value_2"; + header.AppendHeader(key_2, value_2); + key_2 = "garbage"; + value_2 = "garbage"; + } + { + std::string key_3 = "key_3"; + std::string value_3 = "value_3"; + header.AppendHeader(key_3, value_3); + key_3 = "garbage"; + value_3 = "garbage"; + } + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ++chli; // should now point to key_2. + ASSERT_EQ(absl::string_view("key_2"), chli->first); + header.erase(chli); + chli = header.lines().begin(); + + ASSERT_NE(header.lines().begin(), header.lines().end()); + ASSERT_EQ(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_1"), chli->first); + ASSERT_EQ(absl::string_view("value_1"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_NE(header.lines().end(), chli); + ASSERT_EQ(absl::string_view("key_3"), chli->first); + ASSERT_EQ(absl::string_view("value_3"), chli->second); + + ++chli; + ASSERT_NE(header.lines().begin(), chli); + ASSERT_EQ(header.lines().end(), chli); +} + +TEST(BalsaHeaders, TestSetFirstlineInAdditionalBuffer) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET / HTTP/1.0")); +} + +TEST(BalsaHeaders, TestSetFirstlineInOriginalBufferAndIsShorterThanOriginal) { + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET /foobar HTTP/1.0\r\n" + "\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("GET /foobar HTTP/1.0")); + // Note that this SetRequestFirstlineFromStringPieces should replace the + // original one in the -non- 'additional' buffer. + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET / HTTP/1.0")); +} + +TEST(BalsaHeaders, TestSetFirstlineInOriginalBufferAndIsLongerThanOriginal) { + // Similar to above, but this time the new firstline is larger than + // the original, yet it should still fit into the original -non- + // 'additional' buffer as the first header-line has been erased. + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "some_key: some_value\r\n" + "another_key: another_value\r\n" + "\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("GET / HTTP/1.0")); + headers.erase(headers.lines().begin()); + // Note that this SetRequestFirstlineFromStringPieces should replace the + // original one in the -non- 'additional' buffer. + headers.SetRequestFirstlineFromStringPieces("GET", "/foobar", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET /foobar HTTP/1.0")); +} + +TEST(BalsaHeaders, TestSetFirstlineInAdditionalDataAndIsShorterThanOriginal) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/foobar", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET /foobar HTTP/1.0")); + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET / HTTP/1.0")); +} + +TEST(BalsaHeaders, TestSetFirstlineInAdditionalDataAndIsLongerThanOriginal) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET / HTTP/1.0")); + headers.SetRequestFirstlineFromStringPieces("GET", "/foobar", "HTTP/1.0"); + ASSERT_THAT(headers.first_line(), StrEq("GET /foobar HTTP/1.0")); +} + +TEST(BalsaHeaders, TestDeletingSubstring) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key2", "value2"); + headers.AppendHeader("key", "value"); + headers.AppendHeader("unrelated", "value"); + + // RemoveAllOfHeader should not delete key1 or key2 given a substring. + headers.RemoveAllOfHeader("key"); + EXPECT_TRUE(headers.HasHeader("key1")); + EXPECT_TRUE(headers.HasHeader("key2")); + EXPECT_TRUE(headers.HasHeader("unrelated")); + EXPECT_FALSE(headers.HasHeader("key")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("key")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("KeY")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("UNREL")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("key3")); + + EXPECT_FALSE(headers.GetHeader("key1").empty()); + EXPECT_FALSE(headers.GetHeader("KEY1").empty()); + EXPECT_FALSE(headers.GetHeader("key2").empty()); + EXPECT_FALSE(headers.GetHeader("unrelated").empty()); + EXPECT_TRUE(headers.GetHeader("key").empty()); + + // Add key back in. + headers.AppendHeader("key", ""); + EXPECT_TRUE(headers.HasHeader("key")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("key")); + EXPECT_TRUE(headers.GetHeader("key").empty()); + + // RemoveAllHeadersWithPrefix should delete everything starting with key. + headers.RemoveAllHeadersWithPrefix("key"); + EXPECT_FALSE(headers.HasHeader("key1")); + EXPECT_FALSE(headers.HasHeader("key2")); + EXPECT_TRUE(headers.HasHeader("unrelated")); + EXPECT_FALSE(headers.HasHeader("key")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("key")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("key1")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("key2")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("kEy")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("unrelated")); + + EXPECT_TRUE(headers.GetHeader("key1").empty()); + EXPECT_TRUE(headers.GetHeader("key2").empty()); + EXPECT_FALSE(headers.GetHeader("unrelated").empty()); + EXPECT_TRUE(headers.GetHeader("key").empty()); +} + +TEST(BalsaHeaders, TestRemovingValues) { + // Remove entire line from headers, twice. Ensures working line-skipping. + // Skip consideration of a line whose key is larger than our search key. + // Skip consideration of a line whose key is smaller than our search key. + // Skip consideration of a line that is already marked for skipping. + // Skip consideration of a line whose value is too small. + // Skip consideration of a line whose key is correct in length but doesn't + // match. + { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("hi", "hello"); + headers.AppendHeader("key1", "val1"); + headers.AppendHeader("key1", "value2"); + headers.AppendHeader("key1", "value3"); + headers.AppendHeader("key2", "value4"); + headers.AppendHeader("unrelated", "value"); + + EXPECT_EQ(0u, headers.RemoveValue("key1", "")); + EXPECT_EQ(1u, headers.RemoveValue("key1", "value2")); + + std::string key1_vals = headers.GetAllOfHeaderAsString("key1"); + EXPECT_THAT(key1_vals, StrEq("val1,value3")); + + EXPECT_TRUE(headers.HeaderHasValue("key1", "val1")); + EXPECT_TRUE(headers.HeaderHasValue("key1", "value3")); + EXPECT_EQ("value4", headers.GetHeader("key2")); + EXPECT_EQ("hello", headers.GetHeader("hi")); + EXPECT_EQ("value", headers.GetHeader("unrelated")); + EXPECT_FALSE(headers.HeaderHasValue("key1", "value2")); + + EXPECT_EQ(1u, headers.RemoveValue("key1", "value3")); + + key1_vals = headers.GetAllOfHeaderAsString("key1"); + EXPECT_THAT(key1_vals, StrEq("val1")); + + EXPECT_TRUE(headers.HeaderHasValue("key1", "val1")); + EXPECT_EQ("value4", headers.GetHeader("key2")); + EXPECT_EQ("hello", headers.GetHeader("hi")); + EXPECT_EQ("value", headers.GetHeader("unrelated")); + EXPECT_FALSE(headers.HeaderHasValue("key1", "value3")); + EXPECT_FALSE(headers.HeaderHasValue("key1", "value2")); + } + + // Remove/keep values with surrounding spaces. + // Remove values from in between others in multi-value line. + // Remove entire multi-value line. + // Keep value in between removed values in multi-value line. + // Keep trailing value that is too small to be matched after removing a match. + // Keep value containing matched value (partial but not complete match). + // Keep an empty header. + { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key1", "value2, value3,value2"); + headers.AppendHeader("key1", "value4 ,value2,value5,val6"); + headers.AppendHeader("key1", "value2, value2 , value2"); + headers.AppendHeader("key1", " value2 , value2 "); + headers.AppendHeader("key1", " value2 a"); + headers.AppendHeader("key1", ""); + headers.AppendHeader("key1", ", ,,"); + headers.AppendHeader("unrelated", "value"); + + EXPECT_EQ(8u, headers.RemoveValue("key1", "value2")); + + std::string key1_vals = headers.GetAllOfHeaderAsString("key1"); + EXPECT_THAT(key1_vals, + StrEq("value1,value3,value4 ,value5,val6,value2 a,,, ,,")); + + EXPECT_EQ("value", headers.GetHeader("unrelated")); + EXPECT_TRUE(headers.HeaderHasValue("key1", "value1")); + EXPECT_TRUE(headers.HeaderHasValue("key1", "value3")); + EXPECT_TRUE(headers.HeaderHasValue("key1", "value4")); + EXPECT_TRUE(headers.HeaderHasValue("key1", "value5")); + EXPECT_TRUE(headers.HeaderHasValue("key1", "val6")); + EXPECT_FALSE(headers.HeaderHasValue("key1", "value2")); + } + + { + const absl::string_view key("key"); + const absl::string_view value1("foo\0bar", 7); + const absl::string_view value2("value2"); + const std::string value = absl::StrCat(value1, ",", value2); + + { + BalsaHeaders headers; + headers.AppendHeader(key, value); + + EXPECT_TRUE(headers.HeaderHasValue(key, value1)); + EXPECT_TRUE(headers.HeaderHasValue(key, value2)); + EXPECT_EQ(value, headers.GetAllOfHeaderAsString(key)); + + EXPECT_EQ(1u, headers.RemoveValue(key, value2)); + + EXPECT_TRUE(headers.HeaderHasValue(key, value1)); + EXPECT_FALSE(headers.HeaderHasValue(key, value2)); + EXPECT_EQ(value1, headers.GetAllOfHeaderAsString(key)); + } + + { + BalsaHeaders headers; + headers.AppendHeader(key, value1); + headers.AppendHeader(key, value2); + + EXPECT_TRUE(headers.HeaderHasValue(key, value1)); + EXPECT_TRUE(headers.HeaderHasValue(key, value2)); + EXPECT_EQ(value, headers.GetAllOfHeaderAsString(key)); + + EXPECT_EQ(1u, headers.RemoveValue(key, value2)); + + EXPECT_TRUE(headers.HeaderHasValue(key, value1)); + EXPECT_FALSE(headers.HeaderHasValue(key, value2)); + EXPECT_EQ(value1, headers.GetAllOfHeaderAsString(key)); + } + } +} + +TEST(BalsaHeaders, ZeroAppendToHeaderWithCommaAndSpace) { + // Create an initial header with zero 'X-Forwarded-For' headers. + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "\r\n"); + + // Use AppendToHeaderWithCommaAndSpace to add 4 new 'X-Forwarded-For' headers. + // Appending these headers should preserve the order in which they are added. + // i.e. 1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4 + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "1.1.1.1"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "2.2.2.2"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "3.3.3.3"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "4.4.4.4"); + + // Fetch the 'X-Forwarded-For' headers and compare them to the expected order. + EXPECT_THAT(headers.GetAllOfHeader("X-Forwarded-For"), + ElementsAre("1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4")); +} + +TEST(BalsaHeaders, SingleAppendToHeaderWithCommaAndSpace) { + // Create an initial header with one 'X-Forwarded-For' header. + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "X-Forwarded-For: 1.1.1.1\r\n" + "\r\n"); + + // Use AppendToHeaderWithCommaAndSpace to add 4 new 'X-Forwarded-For' headers. + // Appending these headers should preserve the order in which they are added. + // i.e. 1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "2.2.2.2"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "3.3.3.3"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "4.4.4.4"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "5.5.5.5"); + + // Fetch the 'X-Forwarded-For' headers and compare them to the expected order. + EXPECT_THAT(headers.GetAllOfHeader("X-Forwarded-For"), + ElementsAre("1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5")); +} + +TEST(BalsaHeaders, MultipleAppendToHeaderWithCommaAndSpace) { + // Create an initial header with two 'X-Forwarded-For' headers. + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "X-Forwarded-For: 1.1.1.1\r\n" + "X-Forwarded-For: 2.2.2.2\r\n" + "\r\n"); + + // Use AppendToHeaderWithCommaAndSpace to add 4 new 'X-Forwarded-For' headers. + // Appending these headers should preserve the order in which they are added. + // i.e. 1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5, 6.6.6.6 + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "3.3.3.3"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "4.4.4.4"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "5.5.5.5"); + headers.AppendToHeaderWithCommaAndSpace("X-Forwarded-For", "6.6.6.6"); + + // Fetch the 'X-Forwarded-For' headers and compare them to the expected order. + EXPECT_THAT( + headers.GetAllOfHeader("X-Forwarded-For"), + ElementsAre("1.1.1.1", "2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5, 6.6.6.6")); +} + +TEST(BalsaHeaders, HeaderHasValues) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + // Make sure we find values at the beginning, middle, and end, and we handle + // multiple .find() calls correctly. + headers.AppendHeader("key", "val1,val2val2,val2,val3"); + // Make sure we don't mess up comma/boundary checks for beginning, middle and + // end. + headers.AppendHeader("key", "val4val5val6"); + headers.AppendHeader("key", "val11 val12"); + headers.AppendHeader("key", "v val13"); + // Make sure we catch the line header + headers.AppendHeader("key", "val7"); + // Make sure there's no out-of-bounds indexing on an empty line. + headers.AppendHeader("key", ""); + // Make sure it works when there's spaces before or after a comma. + headers.AppendHeader("key", "val8 , val9 , val10"); + // Make sure it works when val is surrounded by spaces. + headers.AppendHeader("key", " val14 "); + // Make sure other keys aren't used. + headers.AppendHeader("key2", "val15"); + // Mixed case. + headers.AppendHeader("key", "Val16"); + headers.AppendHeader("key", "foo, Val17, bar"); + + // All case-sensitive. + EXPECT_TRUE(headers.HeaderHasValue("key", "val1")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val2")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val3")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val7")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val8")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val9")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val10")); + EXPECT_TRUE(headers.HeaderHasValue("key", "val14")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val4")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val5")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val6")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val11")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val12")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val13")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val15")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val16")); + EXPECT_FALSE(headers.HeaderHasValue("key", "val17")); + + // All case-insensitive, only change is for val16 and val17. + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val1")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val2")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val3")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val7")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val8")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val9")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val10")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val14")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val4")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val5")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val6")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val11")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val12")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val13")); + EXPECT_FALSE(headers.HeaderHasValueIgnoreCase("key", "val15")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val16")); + EXPECT_TRUE(headers.HeaderHasValueIgnoreCase("key", "val17")); +} + +// Because we're dealing with one giant buffer, make sure we don't go beyond +// the bounds of the key when doing compares! +TEST(BalsaHeaders, TestNotDeletingBeyondString) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + + headers.RemoveAllHeadersWithPrefix("key1: value1"); + EXPECT_NE(headers.lines().begin(), headers.lines().end()); +} + +TEST(BalsaHeaders, TestIteratingOverErasedHeaders) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key2", "value2"); + headers.AppendHeader("key3", "value3"); + headers.AppendHeader("key4", "value4"); + headers.AppendHeader("key5", "value5"); + headers.AppendHeader("key6", "value6"); + + headers.RemoveAllOfHeader("key6"); + headers.RemoveAllOfHeader("key5"); + headers.RemoveAllOfHeader("key4"); + + BalsaHeaders::const_header_lines_iterator chli = headers.lines().begin(); + EXPECT_NE(headers.lines().end(), chli); + EXPECT_EQ(headers.lines().begin(), chli); + EXPECT_THAT(chli->first, StrEq("key1")); + EXPECT_THAT(chli->second, StrEq("value1")); + + ++chli; + EXPECT_NE(headers.lines().end(), chli); + EXPECT_NE(headers.lines().begin(), chli); + EXPECT_THAT(chli->first, StrEq("key2")); + EXPECT_THAT(chli->second, StrEq("value2")); + + ++chli; + EXPECT_NE(headers.lines().end(), chli); + EXPECT_NE(headers.lines().begin(), chli); + EXPECT_THAT(chli->first, StrEq("key3")); + EXPECT_THAT(chli->second, StrEq("value3")); + + ++chli; + EXPECT_EQ(headers.lines().end(), chli); + EXPECT_NE(headers.lines().begin(), chli); + + headers.RemoveAllOfHeader("key1"); + headers.RemoveAllOfHeader("key2"); + chli = headers.lines().begin(); + EXPECT_THAT(chli->first, StrEq("key3")); + EXPECT_THAT(chli->second, StrEq("value3")); + EXPECT_NE(headers.lines().end(), chli); + EXPECT_EQ(headers.lines().begin(), chli); + + ++chli; + EXPECT_EQ(headers.lines().end(), chli); + EXPECT_NE(headers.lines().begin(), chli); +} + +TEST(BalsaHeaders, CanCompareIterators) { + BalsaHeaders header; + ASSERT_EQ(header.lines().begin(), header.lines().end()); + { + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.AppendHeader(key_1, value_1); + key_1 = "garbage"; + value_1 = "garbage"; + } + { + std::string key_2 = "key_2"; + std::string value_2 = "value_2"; + header.AppendHeader(key_2, value_2); + key_2 = "garbage"; + value_2 = "garbage"; + } + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + BalsaHeaders::const_header_lines_iterator chlj = header.lines().begin(); + EXPECT_EQ(chli, chlj); + ++chlj; + EXPECT_NE(chli, chlj); + EXPECT_LT(chli, chlj); + EXPECT_LE(chli, chlj); + EXPECT_LE(chli, chli); + EXPECT_GT(chlj, chli); + EXPECT_GE(chlj, chli); + EXPECT_GE(chlj, chlj); +} + +TEST(BalsaHeaders, AppendHeaderAndTestThatYouCanEraseEverything) { + BalsaHeaders header; + ASSERT_EQ(header.lines().begin(), header.lines().end()); + { + std::string key_1 = "key_1"; + std::string value_1 = "value_1"; + header.AppendHeader(key_1, value_1); + key_1 = "garbage"; + value_1 = "garbage"; + } + { + std::string key_2 = "key_2"; + std::string value_2 = "value_2"; + header.AppendHeader(key_2, value_2); + key_2 = "garbage"; + value_2 = "garbage"; + } + { + std::string key_3 = "key_3"; + std::string value_3 = "value_3"; + header.AppendHeader(key_3, value_3); + key_3 = "garbage"; + value_3 = "garbage"; + } + EXPECT_NE(header.lines().begin(), header.lines().end()); + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + while (chli != header.lines().end()) { + header.erase(chli); + chli = header.lines().begin(); + } + ASSERT_EQ(header.lines().begin(), header.lines().end()); +} + +TEST(BalsaHeaders, GetHeaderPositionWorksAsExpectedWithNoHeaderLines) { + BalsaHeaders header; + BalsaHeaders::const_header_lines_iterator i = header.GetHeaderPosition("foo"); + EXPECT_EQ(i, header.lines().end()); +} + +TEST(BalsaHeaders, GetHeaderPositionWorksAsExpectedWithBalsaFrameProcessInput) { + BalsaHeaders headers = CreateHTTPHeaders( + true, + "GET / HTTP/1.0\r\n" + "key1: value_1\r\n" + "key1: value_foo\r\n" // this one cannot be fetched via GetHeader + "key2: value_2\r\n" + "key3: value_3\r\n" + "a: value_a\r\n" + "b: value_b\r\n" + "\r\n"); + + BalsaHeaders::const_header_lines_iterator header_position_b = + headers.GetHeaderPosition("b"); + ASSERT_NE(header_position_b, headers.lines().end()); + absl::string_view header_key_b_value = header_position_b->second; + ASSERT_FALSE(header_key_b_value.empty()); + EXPECT_EQ(std::string("value_b"), header_key_b_value); + + BalsaHeaders::const_header_lines_iterator header_position_1 = + headers.GetHeaderPosition("key1"); + ASSERT_NE(header_position_1, headers.lines().end()); + absl::string_view header_key_1_value = header_position_1->second; + ASSERT_FALSE(header_key_1_value.empty()); + EXPECT_EQ(std::string("value_1"), header_key_1_value); + + BalsaHeaders::const_header_lines_iterator header_position_3 = + headers.GetHeaderPosition("key3"); + ASSERT_NE(header_position_3, headers.lines().end()); + absl::string_view header_key_3_value = header_position_3->second; + ASSERT_FALSE(header_key_3_value.empty()); + EXPECT_EQ(std::string("value_3"), header_key_3_value); + + BalsaHeaders::const_header_lines_iterator header_position_2 = + headers.GetHeaderPosition("key2"); + ASSERT_NE(header_position_2, headers.lines().end()); + absl::string_view header_key_2_value = header_position_2->second; + ASSERT_FALSE(header_key_2_value.empty()); + EXPECT_EQ(std::string("value_2"), header_key_2_value); + + BalsaHeaders::const_header_lines_iterator header_position_a = + headers.GetHeaderPosition("a"); + ASSERT_NE(header_position_a, headers.lines().end()); + absl::string_view header_key_a_value = header_position_a->second; + ASSERT_FALSE(header_key_a_value.empty()); + EXPECT_EQ(std::string("value_a"), header_key_a_value); +} + +TEST(BalsaHeaders, GetHeaderWorksAsExpectedWithNoHeaderLines) { + BalsaHeaders header; + absl::string_view value = header.GetHeader("foo"); + EXPECT_TRUE(value.empty()); + value = header.GetHeader(""); + EXPECT_TRUE(value.empty()); +} + +TEST(BalsaHeaders, HasHeaderWorksAsExpectedWithNoHeaderLines) { + BalsaHeaders header; + EXPECT_FALSE(header.HasHeader("foo")); + EXPECT_FALSE(header.HasHeader("")); + EXPECT_FALSE(header.HasHeadersWithPrefix("foo")); + EXPECT_FALSE(header.HasHeadersWithPrefix("")); +} + +TEST(BalsaHeaders, HasHeaderWorksAsExpectedWithBalsaFrameProcessInput) { + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "key1: value_1\r\n" + "key1: value_foo\r\n" + "key2:\r\n" + "\r\n"); + + EXPECT_FALSE(headers.HasHeader("foo")); + EXPECT_TRUE(headers.HasHeader("key1")); + EXPECT_TRUE(headers.HasHeader("key2")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("foo")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("key")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("KEY")); +} + +TEST(BalsaHeaders, GetHeaderWorksAsExpectedWithBalsaFrameProcessInput) { + BalsaHeaders headers = CreateHTTPHeaders( + true, + "GET / HTTP/1.0\r\n" + "key1: value_1\r\n" + "key1: value_foo\r\n" // this one cannot be fetched via GetHeader + "key2: value_2\r\n" + "key3: value_3\r\n" + "key4:\r\n" + "a: value_a\r\n" + "b: value_b\r\n" + "\r\n"); + + absl::string_view header_key_b_value = headers.GetHeader("b"); + ASSERT_FALSE(header_key_b_value.empty()); + EXPECT_EQ(std::string("value_b"), header_key_b_value); + + absl::string_view header_key_1_value = headers.GetHeader("key1"); + ASSERT_FALSE(header_key_1_value.empty()); + EXPECT_EQ(std::string("value_1"), header_key_1_value); + + absl::string_view header_key_3_value = headers.GetHeader("key3"); + ASSERT_FALSE(header_key_3_value.empty()); + EXPECT_EQ(std::string("value_3"), header_key_3_value); + + absl::string_view header_key_2_value = headers.GetHeader("key2"); + ASSERT_FALSE(header_key_2_value.empty()); + EXPECT_EQ(std::string("value_2"), header_key_2_value); + + absl::string_view header_key_a_value = headers.GetHeader("a"); + ASSERT_FALSE(header_key_a_value.empty()); + EXPECT_EQ(std::string("value_a"), header_key_a_value); + + EXPECT_TRUE(headers.GetHeader("key4").empty()); +} + +TEST(BalsaHeaders, GetHeaderWorksAsExpectedWithAppendHeader) { + BalsaHeaders header; + + header.AppendHeader("key1", "value_1"); + // note that this (following) one cannot be found using GetHeader. + header.AppendHeader("key1", "value_2"); + header.AppendHeader("key2", "value_2"); + header.AppendHeader("key3", "value_3"); + header.AppendHeader("a", "value_a"); + header.AppendHeader("b", "value_b"); + + absl::string_view header_key_b_value = header.GetHeader("b"); + absl::string_view header_key_1_value = header.GetHeader("key1"); + absl::string_view header_key_3_value = header.GetHeader("key3"); + absl::string_view header_key_2_value = header.GetHeader("key2"); + absl::string_view header_key_a_value = header.GetHeader("a"); + + ASSERT_FALSE(header_key_1_value.empty()); + ASSERT_FALSE(header_key_2_value.empty()); + ASSERT_FALSE(header_key_3_value.empty()); + ASSERT_FALSE(header_key_a_value.empty()); + ASSERT_FALSE(header_key_b_value.empty()); + + EXPECT_TRUE(header.HasHeader("key1")); + EXPECT_TRUE(header.HasHeader("key2")); + EXPECT_TRUE(header.HasHeader("key3")); + EXPECT_TRUE(header.HasHeader("a")); + EXPECT_TRUE(header.HasHeader("b")); + + EXPECT_TRUE(header.HasHeadersWithPrefix("key1")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key2")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key3")); + EXPECT_TRUE(header.HasHeadersWithPrefix("a")); + EXPECT_TRUE(header.HasHeadersWithPrefix("b")); + + EXPECT_EQ(std::string("value_1"), header_key_1_value); + EXPECT_EQ(std::string("value_2"), header_key_2_value); + EXPECT_EQ(std::string("value_3"), header_key_3_value); + EXPECT_EQ(std::string("value_a"), header_key_a_value); + EXPECT_EQ(std::string("value_b"), header_key_b_value); +} + +TEST(BalsaHeaders, HasHeaderWorksAsExpectedWithAppendHeader) { + BalsaHeaders header; + + ASSERT_FALSE(header.HasHeader("key1")); + EXPECT_FALSE(header.HasHeadersWithPrefix("K")); + EXPECT_FALSE(header.HasHeadersWithPrefix("ke")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key1")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key2")); + header.AppendHeader("key1", "value_1"); + EXPECT_TRUE(header.HasHeader("key1")); + EXPECT_TRUE(header.HasHeadersWithPrefix("K")); + EXPECT_TRUE(header.HasHeadersWithPrefix("ke")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key1")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key2")); + + header.AppendHeader("key1", "value_2"); + EXPECT_TRUE(header.HasHeader("key1")); + EXPECT_FALSE(header.HasHeader("key2")); + EXPECT_TRUE(header.HasHeadersWithPrefix("k")); + EXPECT_TRUE(header.HasHeadersWithPrefix("ke")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key1")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key2")); +} + +TEST(BalsaHeaders, GetHeaderWorksAsExpectedWithHeadersErased) { + BalsaHeaders header; + header.AppendHeader("key1", "value_1"); + header.AppendHeader("key1", "value_2"); + header.AppendHeader("key2", "value_2"); + header.AppendHeader("key3", "value_3"); + header.AppendHeader("a", "value_a"); + header.AppendHeader("b", "value_b"); + + header.erase(header.GetHeaderPosition("key2")); + + absl::string_view header_key_b_value = header.GetHeader("b"); + absl::string_view header_key_1_value = header.GetHeader("key1"); + absl::string_view header_key_3_value = header.GetHeader("key3"); + absl::string_view header_key_2_value = header.GetHeader("key2"); + absl::string_view header_key_a_value = header.GetHeader("a"); + + ASSERT_FALSE(header_key_1_value.empty()); + ASSERT_TRUE(header_key_2_value.empty()); + ASSERT_FALSE(header_key_3_value.empty()); + ASSERT_FALSE(header_key_a_value.empty()); + ASSERT_FALSE(header_key_b_value.empty()); + + EXPECT_EQ(std::string("value_1"), header_key_1_value); + EXPECT_EQ(std::string("value_3"), header_key_3_value); + EXPECT_EQ(std::string("value_a"), header_key_a_value); + EXPECT_EQ(std::string("value_b"), header_key_b_value); + + // Erasing one makes the next one visible: + header.erase(header.GetHeaderPosition("key1")); + header_key_1_value = header.GetHeader("key1"); + ASSERT_FALSE(header_key_1_value.empty()); + EXPECT_EQ(std::string("value_2"), header_key_1_value); + + // Erase both: + header.erase(header.GetHeaderPosition("key1")); + ASSERT_TRUE(header.GetHeader("key1").empty()); +} + +TEST(BalsaHeaders, HasHeaderWorksAsExpectedWithHeadersErased) { + BalsaHeaders header; + header.AppendHeader("key1", "value_1"); + header.AppendHeader("key2", "value_2a"); + header.AppendHeader("key2", "value_2b"); + + ASSERT_TRUE(header.HasHeader("key1")); + ASSERT_TRUE(header.HasHeadersWithPrefix("key1")); + ASSERT_TRUE(header.HasHeadersWithPrefix("key2")); + ASSERT_TRUE(header.HasHeadersWithPrefix("kEY")); + header.erase(header.GetHeaderPosition("key1")); + EXPECT_FALSE(header.HasHeader("key1")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key1")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key2")); + EXPECT_TRUE(header.HasHeadersWithPrefix("kEY")); + + ASSERT_TRUE(header.HasHeader("key2")); + header.erase(header.GetHeaderPosition("key2")); + ASSERT_TRUE(header.HasHeader("key2")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key1")); + EXPECT_TRUE(header.HasHeadersWithPrefix("key2")); + EXPECT_TRUE(header.HasHeadersWithPrefix("kEY")); + header.erase(header.GetHeaderPosition("key2")); + EXPECT_FALSE(header.HasHeader("key2")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key1")); + EXPECT_FALSE(header.HasHeadersWithPrefix("key2")); + EXPECT_FALSE(header.HasHeadersWithPrefix("kEY")); +} + +TEST(BalsaHeaders, HasNonEmptyHeaderWorksAsExpectedWithNoHeaderLines) { + BalsaHeaders header; + EXPECT_FALSE(header.HasNonEmptyHeader("foo")); + EXPECT_FALSE(header.HasNonEmptyHeader("")); +} + +TEST(BalsaHeaders, HasNonEmptyHeaderWorksAsExpectedWithAppendHeader) { + BalsaHeaders header; + + EXPECT_FALSE(header.HasNonEmptyHeader("key1")); + header.AppendHeader("key1", ""); + EXPECT_FALSE(header.HasNonEmptyHeader("key1")); + + header.AppendHeader("key1", "value_2"); + EXPECT_TRUE(header.HasNonEmptyHeader("key1")); + EXPECT_FALSE(header.HasNonEmptyHeader("key2")); +} + +TEST(BalsaHeaders, HasNonEmptyHeaderWorksAsExpectedWithHeadersErased) { + BalsaHeaders header; + header.AppendHeader("key1", "value_1"); + header.AppendHeader("key2", "value_2a"); + header.AppendHeader("key2", ""); + + EXPECT_TRUE(header.HasNonEmptyHeader("key1")); + header.erase(header.GetHeaderPosition("key1")); + EXPECT_FALSE(header.HasNonEmptyHeader("key1")); + + EXPECT_TRUE(header.HasNonEmptyHeader("key2")); + header.erase(header.GetHeaderPosition("key2")); + EXPECT_FALSE(header.HasNonEmptyHeader("key2")); + header.erase(header.GetHeaderPosition("key2")); + EXPECT_FALSE(header.HasNonEmptyHeader("key2")); +} + +TEST(BalsaHeaders, HasNonEmptyHeaderWorksAsExpectedWithBalsaFrameProcessInput) { + BalsaHeaders headers = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "key1: value_1\r\n" + "key2:\r\n" + "key3:\r\n" + "key3: value_3\r\n" + "key4:\r\n" + "key4:\r\n" + "key5: value_5\r\n" + "key5:\r\n" + "\r\n"); + + EXPECT_FALSE(headers.HasNonEmptyHeader("foo")); + EXPECT_TRUE(headers.HasNonEmptyHeader("key1")); + EXPECT_FALSE(headers.HasNonEmptyHeader("key2")); + EXPECT_TRUE(headers.HasNonEmptyHeader("key3")); + EXPECT_FALSE(headers.HasNonEmptyHeader("key4")); + EXPECT_TRUE(headers.HasNonEmptyHeader("key5")); + + headers.erase(headers.GetHeaderPosition("key5")); + EXPECT_FALSE(headers.HasNonEmptyHeader("key5")); +} + +TEST(BalsaHeaders, GetAllOfHeader) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("Key", "value_2,value_3"); + header.AppendHeader("key", ""); + header.AppendHeader("KEY", "value_4"); + + std::vector result; + header.GetAllOfHeader("key", &result); + ASSERT_EQ(4u, result.size()); + EXPECT_EQ("value_1", result[0]); + EXPECT_EQ("value_2,value_3", result[1]); + EXPECT_EQ("", result[2]); + EXPECT_EQ("value_4", result[3]); + + EXPECT_EQ(header.GetAllOfHeader("key"), result); +} + +TEST(BalsaHeaders, GetAllOfHeaderDoesWhatItSays) { + BalsaHeaders header; + // Multiple values for a given header. + // Some values appear multiple times + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + header.AppendHeader("key", ""); + header.AppendHeader("key", "value_1"); + + ASSERT_NE(header.lines().begin(), header.lines().end()); + std::vector out; + + header.GetAllOfHeader("key", &out); + ASSERT_EQ(4u, out.size()); + EXPECT_EQ("value_1", out[0]); + EXPECT_EQ("value_2", out[1]); + EXPECT_EQ("", out[2]); + EXPECT_EQ("value_1", out[3]); + + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithPrefix) { + BalsaHeaders header; + header.AppendHeader("foo-Foo", "value_1"); + header.AppendHeader("Foo-bar", "value_2,value_3"); + header.AppendHeader("foo-Foo", ""); + header.AppendHeader("bar", "value_not"); + header.AppendHeader("fOO-fOO", "value_4"); + + std::vector> result; + header.GetAllOfHeaderWithPrefix("abc", &result); + ASSERT_EQ(0u, result.size()); + + header.GetAllOfHeaderWithPrefix("foo", &result); + ASSERT_EQ(4u, result.size()); + EXPECT_EQ("foo-Foo", result[0].first); + EXPECT_EQ("value_1", result[0].second); + EXPECT_EQ("Foo-bar", result[1].first); + EXPECT_EQ("value_2,value_3", result[1].second); + EXPECT_EQ("", result[2].second); + EXPECT_EQ("value_4", result[3].second); + + std::vector> result2; + header.GetAllOfHeaderWithPrefix("FoO", &result2); + ASSERT_EQ(4u, result2.size()); +} + +TEST(BalsaHeaders, GetAllHeadersWithLimit) { + BalsaHeaders header; + header.AppendHeader("foo-Foo", "value_1"); + header.AppendHeader("Foo-bar", "value_2,value_3"); + header.AppendHeader("foo-Foo", ""); + header.AppendHeader("bar", "value_4"); + header.AppendHeader("fOO-fOO", "value_5"); + + std::vector> result; + header.GetAllHeadersWithLimit(&result, 4); + ASSERT_EQ(4u, result.size()); + EXPECT_EQ("foo-Foo", result[0].first); + EXPECT_EQ("value_1", result[0].second); + EXPECT_EQ("Foo-bar", result[1].first); + EXPECT_EQ("value_2,value_3", result[1].second); + EXPECT_EQ("", result[2].second); + EXPECT_EQ("value_4", result[3].second); + + std::vector> result2; + header.GetAllHeadersWithLimit(&result2, -1); + ASSERT_EQ(5u, result2.size()); +} + +TEST(BalsaHeaders, RangeFor) { + BalsaHeaders header; + // Multiple values for a given header. + // Some values appear multiple times + header.AppendHeader("key1", "value_1a"); + header.AppendHeader("key1", "value_1b"); + header.AppendHeader("key2", ""); + header.AppendHeader("key3", "value_3"); + + std::vector> out; + for (const auto& line : header.lines()) { + out.push_back(line); + } + const std::vector> expected = + {{"key1", "value_1a"}, + {"key1", "value_1b"}, + {"key2", ""}, + {"key3", "value_3"}}; + EXPECT_EQ(expected, out); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithNonExistentKey) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + std::vector out; + + header.GetAllOfHeader("key_non_existent", &out); + ASSERT_EQ(0u, out.size()); + + EXPECT_EQ(header.GetAllOfHeader("key_non_existent"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderEmptyValVariation1) { + BalsaHeaders header; + header.AppendHeader("key", ""); + header.AppendHeader("key", ""); + header.AppendHeader("key", "v1"); + std::vector out; + header.GetAllOfHeader("key", &out); + ASSERT_EQ(3u, out.size()); + EXPECT_EQ("", out[0]); + EXPECT_EQ("", out[1]); + EXPECT_EQ("v1", out[2]); + + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderEmptyValVariation2) { + BalsaHeaders header; + header.AppendHeader("key", ""); + header.AppendHeader("key", "v1"); + header.AppendHeader("key", ""); + std::vector out; + header.GetAllOfHeader("key", &out); + ASSERT_EQ(3u, out.size()); + EXPECT_EQ("", out[0]); + EXPECT_EQ("v1", out[1]); + EXPECT_EQ("", out[2]); + + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderEmptyValVariation3) { + BalsaHeaders header; + header.AppendHeader("key", ""); + header.AppendHeader("key", "v1"); + std::vector out; + header.GetAllOfHeader("key", &out); + ASSERT_EQ(2u, out.size()); + EXPECT_EQ("", out[0]); + EXPECT_EQ("v1", out[1]); + + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderEmptyValVariation4) { + BalsaHeaders header; + header.AppendHeader("key", "v1"); + header.AppendHeader("key", ""); + std::vector out; + header.GetAllOfHeader("key", &out); + ASSERT_EQ(2u, out.size()); + EXPECT_EQ("v1", out[0]); + EXPECT_EQ("", out[1]); + + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithAppendHeaders) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + std::vector out; + + header.GetAllOfHeader("key_new", &out); + ASSERT_EQ(0u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("key_new"), out); + + // Add key_new to the header + header.AppendHeader("key_new", "value_3"); + header.GetAllOfHeader("key_new", &out); + ASSERT_EQ(1u, out.size()); + EXPECT_EQ("value_3", out[0]); + EXPECT_EQ(header.GetAllOfHeader("key_new"), out); + + // Get the keys that are not modified + header.GetAllOfHeader("key", &out); + ASSERT_EQ(3u, out.size()); + EXPECT_EQ("value_1", out[1]); + EXPECT_EQ("value_2", out[2]); + EXPECT_THAT(header.GetAllOfHeader("key"), ElementsAre("value_1", "value_2")); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithRemoveHeaders) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + header.AppendHeader("a", "va"); + + header.RemoveAllOfHeader("key"); + std::vector out; + header.GetAllOfHeader("key", &out); + ASSERT_EQ(0u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("key"), out); + + header.GetAllOfHeader("a", &out); + ASSERT_EQ(1u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("a"), out); + + out.clear(); + header.RemoveAllOfHeader("a"); + header.GetAllOfHeader("a", &out); + ASSERT_EQ(0u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("a"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithRemoveNonExistentHeaders) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("Accept-Encoding", "deflate,compress"); + EXPECT_EQ(0u, headers.RemoveValue("Accept-Encoding", "gzip(gfe)")); + std::string accept_encoding_vals = + headers.GetAllOfHeaderAsString("Accept-Encoding"); + EXPECT_EQ("deflate,compress", accept_encoding_vals); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithEraseHeaders) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + header.AppendHeader("a", "va"); + + std::vector out; + + header.erase(header.GetHeaderPosition("key")); + header.GetAllOfHeader("key", &out); + ASSERT_EQ(1u, out.size()); + EXPECT_EQ("value_2", out[0]); + EXPECT_EQ(header.GetAllOfHeader("key"), out); + + out.clear(); + header.erase(header.GetHeaderPosition("key")); + header.GetAllOfHeader("key", &out); + ASSERT_EQ(0u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("key"), out); + + out.clear(); + header.GetAllOfHeader("a", &out); + ASSERT_EQ(1u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("a"), out); + + out.clear(); + header.erase(header.GetHeaderPosition("a")); + header.GetAllOfHeader("a", &out); + ASSERT_EQ(0u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithNoHeaderLines) { + BalsaHeaders header; + std::vector out; + header.GetAllOfHeader("key", &out); + EXPECT_EQ(0u, out.size()); + EXPECT_EQ(header.GetAllOfHeader("key"), out); +} + +TEST(BalsaHeaders, GetAllOfHeaderDoesWhatItSaysForVariousKeys) { + BalsaHeaders header; + header.AppendHeader("key1", "value_11"); + header.AppendHeader("key2", "value_21"); + header.AppendHeader("key1", "value_12"); + header.AppendHeader("key2", "value_22"); + + std::vector out; + + header.GetAllOfHeader("key1", &out); + EXPECT_EQ("value_11", out[0]); + EXPECT_EQ("value_12", out[1]); + EXPECT_EQ(header.GetAllOfHeader("key1"), out); + + header.GetAllOfHeader("key2", &out); + EXPECT_EQ("value_21", out[2]); + EXPECT_EQ("value_22", out[3]); + EXPECT_THAT(header.GetAllOfHeader("key2"), + ElementsAre("value_21", "value_22")); +} + +TEST(BalsaHeaders, GetAllOfHeaderWithBalsaFrameProcessInput) { + BalsaHeaders header = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "key1: value_1\r\n" + "key1: value_foo\r\n" + "key2: value_2\r\n" + "a: value_a\r\n" + "key2: \r\n" + "b: value_b\r\n" + "\r\n"); + + std::vector out; + int index = 0; + header.GetAllOfHeader("key1", &out); + EXPECT_EQ("value_1", out[index++]); + EXPECT_EQ("value_foo", out[index++]); + EXPECT_EQ(header.GetAllOfHeader("key1"), out); + + header.GetAllOfHeader("key2", &out); + EXPECT_EQ("value_2", out[index++]); + EXPECT_EQ("", out[index++]); + EXPECT_THAT(header.GetAllOfHeader("key2"), ElementsAre("value_2", "")); + + header.GetAllOfHeader("a", &out); + EXPECT_EQ("value_a", out[index++]); + EXPECT_THAT(header.GetAllOfHeader("a"), ElementsAre("value_a")); + + header.GetAllOfHeader("b", &out); + EXPECT_EQ("value_b", out[index++]); + EXPECT_THAT(header.GetAllOfHeader("b"), ElementsAre("value_b")); +} + +TEST(BalsaHeaders, GetAllOfHeaderIncludeRemovedDoesWhatItSays) { + BalsaHeaders header; + header.AppendHeader("key1", "value_11"); + header.AppendHeader("key2", "value_21"); + header.AppendHeader("key1", "value_12"); + header.AppendHeader("key2", "value_22"); + header.AppendHeader("key1", ""); + + std::vector out; + header.GetAllOfHeaderIncludeRemoved("key1", &out); + ASSERT_EQ(3u, out.size()); + EXPECT_EQ("value_11", out[0]); + EXPECT_EQ("value_12", out[1]); + EXPECT_EQ("", out[2]); + header.GetAllOfHeaderIncludeRemoved("key2", &out); + ASSERT_EQ(5u, out.size()); + EXPECT_EQ("value_21", out[3]); + EXPECT_EQ("value_22", out[4]); + + header.erase(header.GetHeaderPosition("key1")); + out.clear(); + header.GetAllOfHeaderIncludeRemoved("key1", &out); + ASSERT_EQ(3u, out.size()); + EXPECT_EQ("value_12", out[0]); + EXPECT_EQ("", out[1]); + EXPECT_EQ("value_11", out[2]); + header.GetAllOfHeaderIncludeRemoved("key2", &out); + ASSERT_EQ(5u, out.size()); + EXPECT_EQ("value_21", out[3]); + EXPECT_EQ("value_22", out[4]); + + header.RemoveAllOfHeader("key1"); + out.clear(); + header.GetAllOfHeaderIncludeRemoved("key1", &out); + ASSERT_EQ(3u, out.size()); + EXPECT_EQ("value_11", out[0]); + EXPECT_EQ("value_12", out[1]); + EXPECT_EQ("", out[2]); + header.GetAllOfHeaderIncludeRemoved("key2", &out); + ASSERT_EQ(5u, out.size()); + EXPECT_EQ("value_21", out[3]); + EXPECT_EQ("value_22", out[4]); + + header.Clear(); + out.clear(); + header.GetAllOfHeaderIncludeRemoved("key1", &out); + ASSERT_EQ(0u, out.size()); + header.GetAllOfHeaderIncludeRemoved("key2", &out); + ASSERT_EQ(0u, out.size()); +} + +TEST(BalsaHeaders, GetAllOfHeaderIncludeRemovedWithNonExistentKey) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + std::vector out; + header.GetAllOfHeaderIncludeRemoved("key_non_existent", &out); + ASSERT_EQ(0u, out.size()); +} + +TEST(BalsaHeaders, GetIteratorForKeyDoesWhatItSays) { + BalsaHeaders header; + // Multiple values for a given header. + // Some values appear multiple times + header.AppendHeader("key", "value_1"); + header.AppendHeader("Key", "value_2"); + header.AppendHeader("key", ""); + header.AppendHeader("KEY", "value_1"); + + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("key"); + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key", key_it->first); + EXPECT_EQ("value_1", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("Key", key_it->first); + EXPECT_EQ("value_2", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key", key_it->first); + EXPECT_EQ("", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("KEY", key_it->first); + EXPECT_EQ("value_1", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); +} + +TEST(BalsaHeaders, GetIteratorForKeyWithNonExistentKey) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("key_non_existent"); + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + const auto lines = header.lines("key_non_existent"); + EXPECT_EQ(lines.begin(), header.lines().end()); + EXPECT_EQ(lines.end(), header.header_lines_key_end()); +} + +TEST(BalsaHeaders, GetIteratorForKeyWithAppendHeaders) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("key_new"); + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + // Add key_new to the header + header.AppendHeader("key_new", "value_3"); + key_it = header.GetIteratorForKey("key_new"); + const auto lines1 = header.lines("key_new"); + EXPECT_EQ(lines1.begin(), key_it); + EXPECT_EQ(lines1.end(), header.header_lines_key_end()); + + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key_new", key_it->first); + EXPECT_EQ("value_3", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + // Get the keys that are not modified + key_it = header.GetIteratorForKey("key"); + const auto lines2 = header.lines("key"); + EXPECT_EQ(lines2.begin(), key_it); + EXPECT_EQ(lines2.end(), header.header_lines_key_end()); + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key", key_it->first); + EXPECT_EQ("value_1", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key", key_it->first); + EXPECT_EQ("value_2", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); +} + +TEST(BalsaHeaders, GetIteratorForKeyWithRemoveHeaders) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + header.AppendHeader("a", "va"); + + header.RemoveAllOfHeader("a"); + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("key"); + EXPECT_NE(header.lines().end(), key_it); + const auto lines1 = header.lines("key"); + EXPECT_EQ(lines1.begin(), key_it); + EXPECT_EQ(lines1.end(), header.header_lines_key_end()); + EXPECT_EQ("value_1", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key", key_it->first); + EXPECT_EQ("value_2", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + // Check that a typical loop works properly. + for (BalsaHeaders::const_header_lines_key_iterator it = + header.GetIteratorForKey("key"); + it != header.lines().end(); ++it) { + EXPECT_EQ("key", it->first); + } +} + +TEST(BalsaHeaders, GetIteratorForKeyWithEraseHeaders) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + header.AppendHeader("a", "va"); + header.erase(header.GetHeaderPosition("key")); + + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("key"); + EXPECT_NE(header.lines().end(), key_it); + const auto lines1 = header.lines("key"); + EXPECT_EQ(lines1.begin(), key_it); + EXPECT_EQ(lines1.end(), header.header_lines_key_end()); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key", key_it->first); + EXPECT_EQ("value_2", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + header.erase(header.GetHeaderPosition("key")); + key_it = header.GetIteratorForKey("key"); + const auto lines2 = header.lines("key"); + EXPECT_EQ(lines2.begin(), key_it); + EXPECT_EQ(lines2.end(), header.header_lines_key_end()); + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + key_it = header.GetIteratorForKey("a"); + const auto lines3 = header.lines("a"); + EXPECT_EQ(lines3.begin(), key_it); + EXPECT_EQ(lines3.end(), header.header_lines_key_end()); + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("a", key_it->first); + EXPECT_EQ("va", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + header.erase(header.GetHeaderPosition("a")); + key_it = header.GetIteratorForKey("a"); + const auto lines4 = header.lines("a"); + EXPECT_EQ(lines4.begin(), key_it); + EXPECT_EQ(lines4.end(), header.header_lines_key_end()); + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); +} + +TEST(BalsaHeaders, GetIteratorForKeyWithNoHeaderLines) { + BalsaHeaders header; + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("key"); + const auto lines = header.lines("key"); + EXPECT_EQ(lines.begin(), key_it); + EXPECT_EQ(lines.end(), header.header_lines_key_end()); + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); +} + +TEST(BalsaHeaders, GetIteratorForKeyWithBalsaFrameProcessInput) { + BalsaHeaders header = CreateHTTPHeaders(true, + "GET / HTTP/1.0\r\n" + "key1: value_1\r\n" + "Key1: value_foo\r\n" + "key2: value_2\r\n" + "a: value_a\r\n" + "key2: \r\n" + "b: value_b\r\n" + "\r\n"); + + BalsaHeaders::const_header_lines_key_iterator key_it = + header.GetIteratorForKey("Key1"); + const auto lines1 = header.lines("Key1"); + EXPECT_EQ(lines1.begin(), key_it); + EXPECT_EQ(lines1.end(), header.header_lines_key_end()); + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key1", key_it->first); + EXPECT_EQ("value_1", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("Key1", key_it->first); + EXPECT_EQ("value_foo", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + key_it = header.GetIteratorForKey("key2"); + EXPECT_NE(header.lines().end(), key_it); + const auto lines2 = header.lines("key2"); + EXPECT_EQ(lines2.begin(), key_it); + EXPECT_EQ(lines2.end(), header.header_lines_key_end()); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key2", key_it->first); + EXPECT_EQ("value_2", key_it->second); + ++key_it; + EXPECT_NE(header.lines().end(), key_it); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("key2", key_it->first); + EXPECT_EQ("", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + key_it = header.GetIteratorForKey("a"); + EXPECT_NE(header.lines().end(), key_it); + const auto lines3 = header.lines("a"); + EXPECT_EQ(lines3.begin(), key_it); + EXPECT_EQ(lines3.end(), header.header_lines_key_end()); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("a", key_it->first); + EXPECT_EQ("value_a", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); + + key_it = header.GetIteratorForKey("b"); + EXPECT_NE(header.lines().end(), key_it); + const auto lines4 = header.lines("b"); + EXPECT_EQ(lines4.begin(), key_it); + EXPECT_EQ(lines4.end(), header.header_lines_key_end()); + EXPECT_NE(header.header_lines_key_end(), key_it); + EXPECT_EQ("b", key_it->first); + EXPECT_EQ("value_b", key_it->second); + ++key_it; + EXPECT_EQ(header.lines().end(), key_it); + EXPECT_EQ(header.header_lines_key_end(), key_it); +} + +TEST(BalsaHeaders, GetAllOfHeaderAsStringDoesWhatItSays) { + BalsaHeaders header; + // Multiple values for a given header. + // Some values appear multiple times + header.AppendHeader("key", "value_1"); + header.AppendHeader("Key", "value_2"); + header.AppendHeader("key", ""); + header.AppendHeader("KEY", "value_1"); + + std::string result = header.GetAllOfHeaderAsString("key"); + EXPECT_EQ("value_1,value_2,,value_1", result); +} + +TEST(BalsaHeaders, RemoveAllOfHeaderDoesWhatItSays) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + ASSERT_NE(header.lines().begin(), header.lines().end()); + header.RemoveAllOfHeader("key"); + ASSERT_EQ(header.lines().begin(), header.lines().end()); +} + +TEST(BalsaHeaders, + RemoveAllOfHeaderDoesWhatItSaysEvenWhenThingsHaveBeenErased) { + BalsaHeaders header; + header.AppendHeader("key1", "value_1"); + header.AppendHeader("key1", "value_2"); + header.AppendHeader("key2", "value_3"); + header.AppendHeader("key1", "value_4"); + header.AppendHeader("key2", "value_5"); + header.AppendHeader("key1", "value_6"); + ASSERT_NE(header.lines().begin(), header.lines().end()); + + BalsaHeaders::const_header_lines_iterator chli = header.lines().begin(); + ++chli; + ++chli; + ++chli; + header.erase(chli); + + chli = header.lines().begin(); + ++chli; + header.erase(chli); + + header.RemoveAllOfHeader("key1"); + for (const auto& line : header.lines()) { + EXPECT_NE(std::string("key1"), line.first); + } +} + +TEST(BalsaHeaders, RemoveAllOfHeaderDoesNothingWhenNoKeyOfThatNameExists) { + BalsaHeaders header; + header.AppendHeader("key", "value_1"); + header.AppendHeader("key", "value_2"); + ASSERT_NE(header.lines().begin(), header.lines().end()); + header.RemoveAllOfHeader("foo"); + int num_found = 0; + for (const auto& line : header.lines()) { + ++num_found; + EXPECT_EQ(absl::string_view("key"), line.first); + } + EXPECT_EQ(2, num_found); + EXPECT_NE(header.lines().begin(), header.lines().end()); +} + +TEST(BalsaHeaders, WriteHeaderEndingToBuffer) { + BalsaHeaders header; + SimpleBuffer simple_buffer; + header.WriteHeaderEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq("\r\n")); +} + +TEST(BalsaHeaders, WriteToBufferDoesntCrashWithUninitializedHeader) { + BalsaHeaders header; + SimpleBuffer simple_buffer; + header.WriteHeaderAndEndingToBuffer(&simple_buffer); +} + +TEST(BalsaHeaders, WriteToBufferWorksWithBalsaHeadersParsedByFramer) { + std::string input = + "GET / HTTP/1.0\r\n" + "key_with_value: value\r\n" + "key_with_continuation_value: \r\n" + " with continuation\r\n" + "key_with_two_continuation_value: \r\n" + " continuation 1\r\n" + " continuation 2\r\n" + "a: foo \r\n" + "b-s:\n" + " bar\t\n" + "foo: \r\n" + "bazzzzzzzleriffic!: snaps\n" + "\n"; + std::string expected = + "GET / HTTP/1.0\r\n" + "key_with_value: value\r\n" + "key_with_continuation_value: with continuation\r\n" + "key_with_two_continuation_value: continuation 1\r\n" + " continuation 2\r\n" + "a: foo\r\n" + "b-s: bar\r\n" + "foo: \r\n" + "bazzzzzzzleriffic!: snaps\r\n" + "\r\n"; + + BalsaHeaders headers = CreateHTTPHeaders(true, input); + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, + WriteToBufferWorksWithBalsaHeadersParsedByFramerTabContinuations) { + std::string input = + "GET / HTTP/1.0\r\n" + "key_with_value: value\r\n" + "key_with_continuation_value: \r\n" + "\twith continuation\r\n" + "key_with_two_continuation_value: \r\n" + "\tcontinuation 1\r\n" + "\tcontinuation 2\r\n" + "a: foo \r\n" + "b-s:\n" + "\tbar\t\n" + "foo: \r\n" + "bazzzzzzzleriffic!: snaps\n" + "\n"; + std::string expected = + "GET / HTTP/1.0\r\n" + "key_with_value: value\r\n" + "key_with_continuation_value: with continuation\r\n" + "key_with_two_continuation_value: continuation 1\r\n" + "\tcontinuation 2\r\n" + "a: foo\r\n" + "b-s: bar\r\n" + "foo: \r\n" + "bazzzzzzzleriffic!: snaps\r\n" + "\r\n"; + + BalsaHeaders headers = CreateHTTPHeaders(true, input); + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, WriteToBufferWorksWhenFirstlineSetThroughHeaders) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + std::string expected = + "GET / HTTP/1.0\r\n" + "\r\n"; + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, WriteToBufferWorksWhenSetThroughHeaders) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key 2", "value\n 2"); + headers.AppendHeader("key\n 3", "value3"); + std::string expected = + "GET / HTTP/1.0\r\n" + "key1: value1\r\n" + "key 2: value\n" + " 2\r\n" + "key\n" + " 3: value3\r\n" + "\r\n"; + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, WriteToBufferWorkWhensOnlyLinesSetThroughHeaders) { + BalsaHeaders headers; + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key 2", "value\n 2"); + headers.AppendHeader("key\n 3", "value3"); + std::string expected = + "\r\n" + "key1: value1\r\n" + "key 2: value\n" + " 2\r\n" + "key\n" + " 3: value3\r\n" + "\r\n"; + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, WriteToBufferWorksWhenSetThroughHeadersWithElementsErased) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key 2", "value\n 2"); + headers.AppendHeader("key\n 3", "value3"); + headers.RemoveAllOfHeader("key1"); + headers.RemoveAllOfHeader("key\n 3"); + std::string expected = + "GET / HTTP/1.0\r\n" + "key 2: value\n" + " 2\r\n" + "\r\n"; + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, WriteToBufferWithManuallyAppendedHeaderLine) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key 2", "value\n 2"); + std::string expected = + "GET / HTTP/1.0\r\n" + "key1: value1\r\n" + "key 2: value\n" + " 2\r\n" + "key 3: value 3\r\n" + "\r\n"; + + SimpleBuffer simple_buffer; + size_t expected_write_buffer_size = headers.GetSizeForWriteBuffer(); + headers.WriteToBuffer(&simple_buffer); + headers.WriteHeaderLineToBuffer(&simple_buffer, "key 3", "value 3", + BalsaHeaders::CaseOption::kNoModification); + headers.WriteHeaderEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected)); + EXPECT_EQ(expected_write_buffer_size + 16, + static_cast(simple_buffer.ReadableBytes())); +} + +TEST(BalsaHeaders, DumpToStringEmptyHeaders) { + BalsaHeaders headers; + std::string headers_str; + headers.DumpToString(&headers_str); + EXPECT_EQ("\n \n", headers_str); +} + +TEST(BalsaHeaders, DumpToStringParsedHeaders) { + std::string input = + "GET / HTTP/1.0\r\n" + "Header1: value\r\n" + "Header2: value\r\n" + "\r\n"; + std::string output = + "\n" + " GET / HTTP/1.0\n" + " Header1: value\n" + " Header2: value\n"; + + BalsaHeaders headers = CreateHTTPHeaders(true, input); + std::string headers_str; + headers.DumpToString(&headers_str); + EXPECT_EQ(output, headers_str); + EXPECT_TRUE(headers.FramerIsDoneWriting()); +} + +TEST(BalsaHeaders, DumpToStringPartialHeaders) { + BalsaHeaders headers; + BalsaFrame balsa_frame; + balsa_frame.set_is_request(true); + balsa_frame.set_balsa_headers(&headers); + std::string input = + "GET / HTTP/1.0\r\n" + "Header1: value\r\n" + "Header2: value\r\n"; + std::string output = absl::StrFormat("\n \n ", + static_cast(input.size())); + output += input; + output += '\n'; + + ASSERT_EQ(input.size(), balsa_frame.ProcessInput(input.data(), input.size())); + ASSERT_FALSE(balsa_frame.MessageFullyRead()); + std::string headers_str; + headers.DumpToString(&headers_str); + EXPECT_EQ(output, headers_str); + EXPECT_FALSE(headers.FramerIsDoneWriting()); +} + +TEST(BalsaHeaders, DumpToStringParsingNonHeadersData) { + BalsaHeaders headers; + BalsaFrame balsa_frame; + balsa_frame.set_is_request(true); + balsa_frame.set_balsa_headers(&headers); + std::string input = + "This is not a header. " + "Just some random data to simulate mismatch."; + std::string output = absl::StrFormat("\n \n ", + static_cast(input.size())); + output += input; + output += '\n'; + + ASSERT_EQ(input.size(), balsa_frame.ProcessInput(input.data(), input.size())); + ASSERT_FALSE(balsa_frame.MessageFullyRead()); + std::string headers_str; + headers.DumpToString(&headers_str); + EXPECT_EQ(output, headers_str); +} + +TEST(BalsaHeaders, Clear) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key 2", "value\n 2"); + headers.AppendHeader("key\n 3", "value3"); + headers.RemoveAllOfHeader("key1"); + headers.RemoveAllOfHeader("key\n 3"); + headers.Clear(); + EXPECT_TRUE(headers.first_line().empty()); + EXPECT_EQ(headers.lines().begin(), headers.lines().end()); + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST(BalsaHeaders, + TestSetFromStringPiecesWithInitialFirstlineInHeaderStreamAndNewToo) { + BalsaHeaders headers = CreateHTTPHeaders(false, + "HTTP/1.1 200 reason phrase\r\n" + "content-length: 0\r\n" + "\r\n"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + + headers.SetResponseFirstline("HTTP/1.0", 404, "a reason"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.0")); + EXPECT_THAT(headers.response_code(), StrEq("404")); + EXPECT_THAT(headers.parsed_response_code(), Eq(404)); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("a reason")); + EXPECT_THAT(headers.first_line(), StrEq("HTTP/1.0 404 a reason")); +} + +TEST(BalsaHeaders, + TestSetFromStringPiecesWithInitialFirstlineInHeaderStreamButNotNew) { + BalsaHeaders headers = CreateHTTPHeaders(false, + "HTTP/1.1 200 reason phrase\r\n" + "content-length: 0\r\n" + "\r\n"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + + headers.SetResponseFirstline("HTTP/1.000", 404000, + "supercalifragilisticexpealidocious"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.000")); + EXPECT_THAT(headers.response_code(), StrEq("404000")); + EXPECT_THAT(headers.parsed_response_code(), Eq(404000)); + EXPECT_THAT(headers.response_reason_phrase(), + StrEq("supercalifragilisticexpealidocious")); + EXPECT_THAT(headers.first_line(), + StrEq("HTTP/1.000 404000 supercalifragilisticexpealidocious")); +} + +TEST(BalsaHeaders, + TestSetFromStringPiecesWithFirstFirstlineInHeaderStreamButNotNew2) { + SCOPED_TRACE( + "This test tests the codepath where the new firstline is" + " too large to fit within the space used by the original" + " firstline, but large enuogh to space in the free space" + " available in both firstline plus the space made available" + " with deleted header lines (specifically, the first one"); + BalsaHeaders headers = CreateHTTPHeaders( + false, + "HTTP/1.1 200 reason phrase\r\n" + "a: 0987123409871234078130948710938471093827401983740198327401982374\r\n" + "content-length: 0\r\n" + "\r\n"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + + headers.erase(headers.lines().begin()); + headers.SetResponseFirstline("HTTP/1.000", 404000, + "supercalifragilisticexpealidocious"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.000")); + EXPECT_THAT(headers.response_code(), StrEq("404000")); + EXPECT_THAT(headers.parsed_response_code(), Eq(404000)); + EXPECT_THAT(headers.response_reason_phrase(), + StrEq("supercalifragilisticexpealidocious")); + EXPECT_THAT(headers.first_line(), + StrEq("HTTP/1.000 404000 supercalifragilisticexpealidocious")); +} + +TEST(BalsaHeaders, TestSetFirstlineFromStringPiecesWithNoInitialFirstline) { + BalsaHeaders headers; + headers.SetResponseFirstline("HTTP/1.1", 200, "don't need a reason"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.parsed_response_code(), Eq(200)); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("don't need a reason")); + EXPECT_THAT(headers.first_line(), StrEq("HTTP/1.1 200 don't need a reason")); +} + +TEST(BalsaHeaders, TestSettingFirstlineElementsWithOtherElementsMissing) { + { + BalsaHeaders headers; + headers.SetRequestMethod("GET"); + headers.SetRequestUri("/"); + EXPECT_THAT(headers.first_line(), StrEq("GET / ")); + } + { + BalsaHeaders headers; + headers.SetRequestMethod("GET"); + headers.SetRequestVersion("HTTP/1.1"); + EXPECT_THAT(headers.first_line(), StrEq("GET HTTP/1.1")); + } + { + BalsaHeaders headers; + headers.SetRequestUri("/"); + headers.SetRequestVersion("HTTP/1.1"); + EXPECT_THAT(headers.first_line(), StrEq(" / HTTP/1.1")); + } +} + +TEST(BalsaHeaders, TestSettingMissingFirstlineElementsAfterBalsaHeadersParsed) { + { + BalsaHeaders headers = CreateHTTPHeaders(true, "GET /foo\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("GET /foo")); + + headers.SetRequestVersion("HTTP/1.1"); + EXPECT_THAT(headers.first_line(), StrEq("GET /foo HTTP/1.1")); + } + { + BalsaHeaders headers = CreateHTTPHeaders(true, "GET\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("GET")); + + headers.SetRequestUri("/foo"); + EXPECT_THAT(headers.first_line(), StrEq("GET /foo ")); + } +} + +// Here we exersize the codepaths involved in setting a new firstine when the +// previously set firstline is stored in the 'additional_data_stream_' +// variable, and the new firstline is larger than the previously set firstline. +TEST(BalsaHeaders, + SetFirstlineFromStringPiecesFirstInAdditionalDataAndNewLarger) { + BalsaHeaders headers; + // This one will end up being put into the additional data stream + headers.SetResponseFirstline("HTTP/1.1", 200, "don't need a reason"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.parsed_response_code(), Eq(200)); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("don't need a reason")); + EXPECT_THAT(headers.first_line(), StrEq("HTTP/1.1 200 don't need a reason")); + + // Now, we set it again, this time we're extending what exists + // here. + headers.SetResponseFirstline("HTTP/1.10", 2000, "REALLY don't need a reason"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.10")); + EXPECT_THAT(headers.response_code(), StrEq("2000")); + EXPECT_THAT(headers.parsed_response_code(), Eq(2000)); + EXPECT_THAT(headers.response_reason_phrase(), + StrEq("REALLY don't need a reason")); + EXPECT_THAT(headers.first_line(), + StrEq("HTTP/1.10 2000 REALLY don't need a reason")); +} + +// Here we exersize the codepaths involved in setting a new firstine when the +// previously set firstline is stored in the 'additional_data_stream_' +// variable, and the new firstline is smaller than the previously set firstline. +TEST(BalsaHeaders, + TestSetFirstlineFromStringPiecesWithPreviousInAdditionalDataNewSmaller) { + BalsaHeaders headers; + // This one will end up being put into the additional data stream + // + headers.SetResponseFirstline("HTTP/1.10", 2000, "REALLY don't need a reason"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.10")); + EXPECT_THAT(headers.response_code(), StrEq("2000")); + EXPECT_THAT(headers.parsed_response_code(), Eq(2000)); + EXPECT_THAT(headers.response_reason_phrase(), + StrEq("REALLY don't need a reason")); + EXPECT_THAT(headers.first_line(), + StrEq("HTTP/1.10 2000 REALLY don't need a reason")); + + // Now, we set it again, this time we're extending what exists + // here. + headers.SetResponseFirstline("HTTP/1.0", 200, "a reason"); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.0")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.parsed_response_code(), Eq(200)); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("a reason")); + EXPECT_THAT(headers.first_line(), StrEq("HTTP/1.0 200 a reason")); +} + +TEST(BalsaHeaders, CopyFrom) { + BalsaHeaders headers1, headers2; + absl::string_view method("GET"); + absl::string_view uri("/foo"); + absl::string_view version("HTTP/1.0"); + headers1.SetRequestFirstlineFromStringPieces(method, uri, version); + headers1.AppendHeader("key1", "value1"); + headers1.AppendHeader("key 2", "value\n 2"); + headers1.AppendHeader("key\n 3", "value3"); + + // "GET /foo HTTP/1.0" // 17 + // "key1: value1\r\n" // 14 + // "key 2: value\n 2\r\n" // 17 + // "key\n 3: value3\r\n" // 16 + + headers2.CopyFrom(headers1); + + EXPECT_THAT(headers1.first_line(), StrEq("GET /foo HTTP/1.0")); + BalsaHeaders::const_header_lines_iterator chli = headers1.lines().begin(); + EXPECT_THAT(chli->first, StrEq("key1")); + EXPECT_THAT(chli->second, StrEq("value1")); + ++chli; + EXPECT_THAT(chli->first, StrEq("key 2")); + EXPECT_THAT(chli->second, StrEq("value\n 2")); + ++chli; + EXPECT_THAT(chli->first, StrEq("key\n 3")); + EXPECT_THAT(chli->second, StrEq("value3")); + ++chli; + EXPECT_EQ(headers1.lines().end(), chli); + + EXPECT_THAT(headers1.request_method(), + StrEq((std::string(headers2.request_method())))); + EXPECT_THAT(headers1.request_uri(), + StrEq((std::string(headers2.request_uri())))); + EXPECT_THAT(headers1.request_version(), + StrEq((std::string(headers2.request_version())))); + + EXPECT_THAT(headers2.first_line(), StrEq("GET /foo HTTP/1.0")); + chli = headers2.lines().begin(); + EXPECT_THAT(chli->first, StrEq("key1")); + EXPECT_THAT(chli->second, StrEq("value1")); + ++chli; + EXPECT_THAT(chli->first, StrEq("key 2")); + EXPECT_THAT(chli->second, StrEq("value\n 2")); + ++chli; + EXPECT_THAT(chli->first, StrEq("key\n 3")); + EXPECT_THAT(chli->second, StrEq("value3")); + ++chli; + EXPECT_EQ(headers2.lines().end(), chli); + + version = absl::string_view("HTTP/1.1"); + int code = 200; + absl::string_view reason_phrase("reason phrase asdf"); + + headers1.RemoveAllOfHeader("key1"); + headers1.AppendHeader("key4", "value4"); + + headers1.SetResponseFirstline(version, code, reason_phrase); + + headers2.CopyFrom(headers1); + + // "GET /foo HTTP/1.0" // 17 + // "XXXXXXXXXXXXXX" // 14 + // "key 2: value\n 2\r\n" // 17 + // "key\n 3: value3\r\n" // 16 + // "key4: value4\r\n" // 14 + // + // -> + // + // "HTTP/1.1 200 reason phrase asdf" // 31 = (17 + 14) + // "key 2: value\n 2\r\n" // 17 + // "key\n 3: value3\r\n" // 16 + // "key4: value4\r\n" // 14 + + EXPECT_THAT(headers1.request_method(), + StrEq((std::string(headers2.request_method())))); + EXPECT_THAT(headers1.request_uri(), + StrEq((std::string(headers2.request_uri())))); + EXPECT_THAT(headers1.request_version(), + StrEq((std::string(headers2.request_version())))); + + EXPECT_THAT(headers2.first_line(), StrEq("HTTP/1.1 200 reason phrase asdf")); + chli = headers2.lines().begin(); + EXPECT_THAT(chli->first, StrEq("key 2")); + EXPECT_THAT(chli->second, StrEq("value\n 2")); + ++chli; + EXPECT_THAT(chli->first, StrEq("key\n 3")); + EXPECT_THAT(chli->second, StrEq("value3")); + ++chli; + EXPECT_THAT(chli->first, StrEq("key4")); + EXPECT_THAT(chli->second, StrEq("value4")); + ++chli; + EXPECT_EQ(headers2.lines().end(), chli); +} + +// Test BalsaHeaders move constructor and move assignment operator. +TEST(BalsaHeaders, Move) { + BalsaHeaders headers1, headers3; + absl::string_view method("GET"); + absl::string_view uri("/foo"); + absl::string_view version("HTTP/1.0"); + headers1.SetRequestFirstlineFromStringPieces(method, uri, version); + headers1.AppendHeader("key1", "value1"); + headers1.AppendHeader("key 2", "value\n 2"); + headers1.AppendHeader("key\n 3", "value3"); + + // "GET /foo HTTP/1.0" // 17 + // "key1: value1\r\n" // 14 + // "key 2: value\n 2\r\n" // 17 + // "key\n 3: value3\r\n" // 16 + + BalsaHeaders headers2 = std::move(headers1); + + EXPECT_EQ("GET /foo HTTP/1.0", headers2.first_line()); + BalsaHeaders::const_header_lines_iterator chli = headers2.lines().begin(); + EXPECT_EQ("key1", chli->first); + EXPECT_EQ("value1", chli->second); + ++chli; + EXPECT_EQ("key 2", chli->first); + EXPECT_EQ("value\n 2", chli->second); + ++chli; + EXPECT_EQ("key\n 3", chli->first); + EXPECT_EQ("value3", chli->second); + ++chli; + EXPECT_EQ(headers2.lines().end(), chli); + + EXPECT_EQ("GET", headers2.request_method()); + EXPECT_EQ("/foo", headers2.request_uri()); + EXPECT_EQ("HTTP/1.0", headers2.request_version()); + + headers3 = std::move(headers2); + version = absl::string_view("HTTP/1.1"); + int code = 200; + absl::string_view reason_phrase("reason phrase asdf"); + + headers3.RemoveAllOfHeader("key1"); + headers3.AppendHeader("key4", "value4"); + + headers3.SetResponseFirstline(version, code, reason_phrase); + + BalsaHeaders headers4 = std::move(headers3); + + // "GET /foo HTTP/1.0" // 17 + // "XXXXXXXXXXXXXX" // 14 + // "key 2: value\n 2\r\n" // 17 + // "key\n 3: value3\r\n" // 16 + // "key4: value4\r\n" // 14 + // + // -> + // + // "HTTP/1.1 200 reason phrase asdf" // 31 = (17 + 14) + // "key 2: value\n 2\r\n" // 17 + // "key\n 3: value3\r\n" // 16 + // "key4: value4\r\n" // 14 + + EXPECT_EQ("200", headers4.response_code()); + EXPECT_EQ("reason phrase asdf", headers4.response_reason_phrase()); + EXPECT_EQ("HTTP/1.1", headers4.response_version()); + + EXPECT_EQ("HTTP/1.1 200 reason phrase asdf", headers4.first_line()); + chli = headers4.lines().begin(); + EXPECT_EQ("key 2", chli->first); + EXPECT_EQ("value\n 2", chli->second); + ++chli; + EXPECT_EQ("key\n 3", chli->first); + EXPECT_EQ("value3", chli->second); + ++chli; + EXPECT_EQ("key4", chli->first); + EXPECT_EQ("value4", chli->second); + ++chli; + EXPECT_EQ(headers4.lines().end(), chli); +} + +TEST(BalsaHeaders, IteratorWorksWithOStreamAsExpected) { + { + std::stringstream actual; + BalsaHeaders::const_header_lines_iterator chli; + actual << chli; + // Note that the output depends on the flavor of standard library in use. + EXPECT_THAT(actual.str(), AnyOf(StrEq("[0, 0]"), // libstdc++ + StrEq("[(nil), 0]"), // libc++ + StrEq("[0x0, 0]"))); // libc++ on Mac + } + { + BalsaHeaders headers; + std::stringstream actual; + BalsaHeaders::const_header_lines_iterator chli = headers.lines().begin(); + actual << chli; + std::stringstream expected; + expected << "[" << &headers << ", 0]"; + EXPECT_THAT(expected.str(), StrEq(actual.str())); + } +} + +TEST(BalsaHeaders, TestSetResponseReasonPhraseWithNoInitialFirstline) { + BalsaHeaders balsa_headers; + balsa_headers.SetResponseReasonPhrase("don't need a reason"); + EXPECT_THAT(balsa_headers.first_line(), StrEq(" don't need a reason")); + EXPECT_TRUE(balsa_headers.response_version().empty()); + EXPECT_TRUE(balsa_headers.response_code().empty()); + EXPECT_THAT(balsa_headers.response_reason_phrase(), + StrEq("don't need a reason")); +} + +// Testing each of 9 combinations separately was taking up way too much of this +// file (not to mention the inordinate amount of stupid code duplication), thus +// this test tests all 9 combinations of smaller, equal, and larger in one +// place. +TEST(BalsaHeaders, TestSetResponseReasonPhrase) { + const char* response_reason_phrases[] = { + "qwerty asdfgh", + "qwerty", + "qwerty asdfghjkl", + }; + size_t arraysize_squared = (ABSL_ARRAYSIZE(response_reason_phrases) * + ABSL_ARRAYSIZE(response_reason_phrases)); + // We go through the 9 different permutations of (response_reason_phrases + // choose 2) in the loop below. For each permutation, we mutate the firstline + // twice-- once from the original, and once from the previous. + for (size_t iteration = 0; iteration < arraysize_squared; ++iteration) { + SCOPED_TRACE("Original firstline: \"HTTP/1.0 200 reason phrase\""); + BalsaHeaders headers = CreateHTTPHeaders(true, + "HTTP/1.0 200 reason phrase\r\n" + "content-length: 0\r\n" + "\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("HTTP/1.0 200 reason phrase")); + + { + int first = iteration / ABSL_ARRAYSIZE(response_reason_phrases); + const char* response_reason_phrase_first = response_reason_phrases[first]; + std::string expected_new_firstline = + absl::StrFormat("HTTP/1.0 200 %s", response_reason_phrase_first); + SCOPED_TRACE(absl::StrFormat("Then set response_reason_phrase(\"%s\")", + response_reason_phrase_first)); + + headers.SetResponseReasonPhrase(response_reason_phrase_first); + EXPECT_THAT(headers.first_line(), + StrEq(absl::StrFormat("HTTP/1.0 200 %s", + response_reason_phrase_first))); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.0")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), + StrEq(response_reason_phrase_first)); + } + + // Note that each iteration of the outer loop causes the headers to be left + // in a different state. Nothing wrong with that, but we should use each of + // these states, and try each of our scenarios again. This inner loop does + // that. + { + int second = iteration % ABSL_ARRAYSIZE(response_reason_phrases); + const char* response_reason_phrase_second = + response_reason_phrases[second]; + std::string expected_new_firstline = + absl::StrFormat("HTTP/1.0 200 %s", response_reason_phrase_second); + SCOPED_TRACE(absl::StrFormat("Then set response_reason_phrase(\"%s\")", + response_reason_phrase_second)); + + headers.SetResponseReasonPhrase(response_reason_phrase_second); + EXPECT_THAT(headers.first_line(), + StrEq(absl::StrFormat("HTTP/1.0 200 %s", + response_reason_phrase_second))); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.0")); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), + StrEq(response_reason_phrase_second)); + } + } +} + +TEST(BalsaHeaders, TestSetResponseVersionWithNoInitialFirstline) { + BalsaHeaders balsa_headers; + balsa_headers.SetResponseVersion("HTTP/1.1"); + EXPECT_THAT(balsa_headers.first_line(), StrEq("HTTP/1.1 ")); + EXPECT_THAT(balsa_headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_TRUE(balsa_headers.response_code().empty()); + EXPECT_TRUE(balsa_headers.response_reason_phrase().empty()); +} + +// Testing each of 9 combinations separately was taking up way too much of this +// file (not to mention the inordinate amount of stupid code duplication), thus +// this test tests all 9 combinations of smaller, equal, and larger in one +// place. +TEST(BalsaHeaders, TestSetResponseVersion) { + const char* response_versions[] = { + "ABCD/123", + "ABCD", + "ABCD/123456", + }; + size_t arraysize_squared = + (ABSL_ARRAYSIZE(response_versions) * ABSL_ARRAYSIZE(response_versions)); + // We go through the 9 different permutations of (response_versions choose 2) + // in the loop below. For each permutation, we mutate the firstline twice-- + // once from the original, and once from the previous. + for (size_t iteration = 0; iteration < arraysize_squared; ++iteration) { + SCOPED_TRACE("Original firstline: \"HTTP/1.0 200 reason phrase\""); + BalsaHeaders headers = CreateHTTPHeaders(false, + "HTTP/1.0 200 reason phrase\r\n" + "content-length: 0\r\n" + "\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("HTTP/1.0 200 reason phrase")); + + // This structure guarantees that we'll visit all of the possible + // variations of setting. + + { + int first = iteration / ABSL_ARRAYSIZE(response_versions); + const char* response_version_first = response_versions[first]; + std::string expected_new_firstline = + absl::StrFormat("%s 200 reason phrase", response_version_first); + SCOPED_TRACE(absl::StrFormat("Then set response_version(\"%s\")", + response_version_first)); + + headers.SetResponseVersion(response_version_first); + EXPECT_THAT(headers.first_line(), StrEq(expected_new_firstline)); + EXPECT_THAT(headers.response_version(), StrEq(response_version_first)); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + } + { + int second = iteration % ABSL_ARRAYSIZE(response_versions); + const char* response_version_second = response_versions[second]; + std::string expected_new_firstline = + absl::StrFormat("%s 200 reason phrase", response_version_second); + SCOPED_TRACE(absl::StrFormat("Then set response_version(\"%s\")", + response_version_second)); + + headers.SetResponseVersion(response_version_second); + EXPECT_THAT(headers.first_line(), StrEq(expected_new_firstline)); + EXPECT_THAT(headers.response_version(), StrEq(response_version_second)); + EXPECT_THAT(headers.response_code(), StrEq("200")); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + } + } +} + +TEST(BalsaHeaders, TestSetResponseReasonAndVersionWithNoInitialFirstline) { + BalsaHeaders headers; + headers.SetResponseVersion("HTTP/1.1"); + headers.SetResponseReasonPhrase("don't need a reason"); + EXPECT_THAT(headers.first_line(), StrEq("HTTP/1.1 don't need a reason")); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.1")); + EXPECT_TRUE(headers.response_code().empty()); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("don't need a reason")); +} + +TEST(BalsaHeaders, TestSetResponseCodeWithNoInitialFirstline) { + BalsaHeaders balsa_headers; + balsa_headers.SetParsedResponseCodeAndUpdateFirstline(2002); + EXPECT_THAT(balsa_headers.first_line(), StrEq(" 2002 ")); + EXPECT_TRUE(balsa_headers.response_version().empty()); + EXPECT_THAT(balsa_headers.response_code(), StrEq("2002")); + EXPECT_TRUE(balsa_headers.response_reason_phrase().empty()); + EXPECT_THAT(balsa_headers.parsed_response_code(), Eq(2002)); +} + +TEST(BalsaHeaders, TestSetParsedResponseCode) { + BalsaHeaders balsa_headers; + balsa_headers.set_parsed_response_code(std::numeric_limits::max()); + EXPECT_THAT(balsa_headers.parsed_response_code(), + Eq(std::numeric_limits::max())); +} + +TEST(BalsaHeaders, TestSetResponseCode) { + const char* response_codes[] = { + "200" + "23", + "200200", + }; + size_t arraysize_squared = + (ABSL_ARRAYSIZE(response_codes) * ABSL_ARRAYSIZE(response_codes)); + // We go through the 9 different permutations of (response_codes choose 2) + // in the loop below. For each permutation, we mutate the firstline twice-- + // once from the original, and once from the previous. + for (size_t iteration = 0; iteration < arraysize_squared; ++iteration) { + SCOPED_TRACE("Original firstline: \"HTTP/1.0 200 reason phrase\""); + BalsaHeaders headers = CreateHTTPHeaders(false, + "HTTP/1.0 200 reason phrase\r\n" + "content-length: 0\r\n" + "\r\n"); + ASSERT_THAT(headers.first_line(), StrEq("HTTP/1.0 200 reason phrase")); + + // This structure guarantees that we'll visit all of the possible + // variations of setting. + + { + int first = iteration / ABSL_ARRAYSIZE(response_codes); + const char* response_code_first = response_codes[first]; + std::string expected_new_firstline = + absl::StrFormat("HTTP/1.0 %s reason phrase", response_code_first); + SCOPED_TRACE(absl::StrFormat("Then set response_code(\"%s\")", + response_code_first)); + + headers.SetResponseCode(response_code_first); + + EXPECT_THAT(headers.first_line(), StrEq(expected_new_firstline)); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.0")); + EXPECT_THAT(headers.response_code(), StrEq(response_code_first)); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + } + { + int second = iteration % ABSL_ARRAYSIZE(response_codes); + const char* response_code_second = response_codes[second]; + std::string expected_new_secondline = + absl::StrFormat("HTTP/1.0 %s reason phrase", response_code_second); + SCOPED_TRACE(absl::StrFormat("Then set response_code(\"%s\")", + response_code_second)); + + headers.SetResponseCode(response_code_second); + + EXPECT_THAT(headers.first_line(), StrEq(expected_new_secondline)); + EXPECT_THAT(headers.response_version(), StrEq("HTTP/1.0")); + EXPECT_THAT(headers.response_code(), StrEq(response_code_second)); + EXPECT_THAT(headers.response_reason_phrase(), StrEq("reason phrase")); + } + } +} + +TEST(BalsaHeaders, TestAppendToHeader) { + // Test the basic case of appending to a header. + BalsaHeaders headers; + headers.AppendHeader("foo", "foo_value"); + headers.AppendHeader("bar", "bar_value"); + headers.AppendToHeader("foo", "foo_value2"); + + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value,foo_value2")); + EXPECT_THAT(headers.GetHeader("bar"), StrEq("bar_value")); +} + +TEST(BalsaHeaders, TestInitialAppend) { + // Test that AppendToHeader works properly when the header did not already + // exist. + BalsaHeaders headers; + headers.AppendToHeader("foo", "foo_value"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value")); + headers.AppendToHeader("foo", "foo_value2"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value,foo_value2")); +} + +TEST(BalsaHeaders, TestAppendAndRemove) { + // Test that AppendToHeader works properly with removing. + BalsaHeaders headers; + headers.AppendToHeader("foo", "foo_value"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value")); + headers.AppendToHeader("foo", "foo_value2"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value,foo_value2")); + headers.RemoveAllOfHeader("foo"); + headers.AppendToHeader("foo", "foo_value3"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value3")); + headers.AppendToHeader("foo", "foo_value4"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value3,foo_value4")); +} + +TEST(BalsaHeaders, TestAppendToHeaderWithCommaAndSpace) { + // Test the basic case of appending to a header with comma and space. + BalsaHeaders headers; + headers.AppendHeader("foo", "foo_value"); + headers.AppendHeader("bar", "bar_value"); + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value2"); + + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value, foo_value2")); + EXPECT_THAT(headers.GetHeader("bar"), StrEq("bar_value")); +} + +TEST(BalsaHeaders, TestInitialAppendWithCommaAndSpace) { + // Test that AppendToHeadeWithCommaAndSpace works properly when the + // header did not already exist. + BalsaHeaders headers; + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value")); + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value2"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value, foo_value2")); +} + +TEST(BalsaHeaders, TestAppendWithCommaAndSpaceAndRemove) { + // Test that AppendToHeadeWithCommaAndSpace works properly with removing. + BalsaHeaders headers; + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value")); + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value2"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value, foo_value2")); + headers.RemoveAllOfHeader("foo"); + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value3"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value3")); + headers.AppendToHeaderWithCommaAndSpace("foo", "foo_value4"); + EXPECT_THAT(headers.GetHeader("foo"), StrEq("foo_value3, foo_value4")); +} + +TEST(BalsaHeaders, SetContentLength) { + // Test that SetContentLength correctly sets the content-length header and + // sets the content length status. + BalsaHeaders headers; + headers.SetContentLength(10); + EXPECT_THAT(headers.GetHeader("Content-length"), StrEq("10")); + EXPECT_EQ(BalsaHeadersEnums::VALID_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_TRUE(headers.content_length_valid()); + + // Test overwriting the content-length. + headers.SetContentLength(0); + EXPECT_THAT(headers.GetHeader("Content-length"), StrEq("0")); + EXPECT_EQ(BalsaHeadersEnums::VALID_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_TRUE(headers.content_length_valid()); + + // Make sure there is only one header line after the overwrite. + BalsaHeaders::const_header_lines_iterator iter = + headers.GetHeaderPosition("Content-length"); + EXPECT_EQ(headers.lines().begin(), iter); + EXPECT_EQ(headers.lines().end(), ++iter); + + // Test setting the same content-length again, this should be no-op. + headers.SetContentLength(0); + EXPECT_THAT(headers.GetHeader("Content-length"), StrEq("0")); + EXPECT_EQ(BalsaHeadersEnums::VALID_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_TRUE(headers.content_length_valid()); + + // Make sure the number of header lines didn't change. + iter = headers.GetHeaderPosition("Content-length"); + EXPECT_EQ(headers.lines().begin(), iter); + EXPECT_EQ(headers.lines().end(), ++iter); +} + +TEST(BalsaHeaders, ToggleChunkedEncoding) { + // Test that SetTransferEncodingToChunkedAndClearContentLength correctly adds + // chunk-encoding header and sets the transfer_encoding_is_chunked_ + // flag. + BalsaHeaders headers; + headers.SetTransferEncodingToChunkedAndClearContentLength(); + EXPECT_EQ("chunked", headers.GetAllOfHeaderAsString("Transfer-Encoding")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("Transfer-Encoding")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("transfer-encoding")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("transfer")); + EXPECT_TRUE(headers.transfer_encoding_is_chunked()); + + // Set it to the same value, nothing should change. + headers.SetTransferEncodingToChunkedAndClearContentLength(); + EXPECT_EQ("chunked", headers.GetAllOfHeaderAsString("Transfer-Encoding")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("Transfer-Encoding")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("transfer-encoding")); + EXPECT_TRUE(headers.HasHeadersWithPrefix("transfer")); + EXPECT_TRUE(headers.transfer_encoding_is_chunked()); + BalsaHeaders::const_header_lines_iterator iter = + headers.GetHeaderPosition("Transfer-Encoding"); + EXPECT_EQ(headers.lines().begin(), iter); + EXPECT_EQ(headers.lines().end(), ++iter); + + // Removes the chunked encoding, and there should be no transfer-encoding + // headers left. + headers.SetNoTransferEncoding(); + EXPECT_FALSE(headers.HasHeader("Transfer-Encoding")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("Transfer-Encoding")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("transfer-encoding")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("transfer")); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); + EXPECT_EQ(headers.lines().end(), headers.lines().begin()); + + // Clear chunked again, this should be a no-op and the header should not + // change. + headers.SetNoTransferEncoding(); + EXPECT_FALSE(headers.HasHeader("Transfer-Encoding")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("Transfer-Encoding")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("transfer-encoding")); + EXPECT_FALSE(headers.HasHeadersWithPrefix("transfer")); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); + EXPECT_EQ(headers.lines().end(), headers.lines().begin()); +} + +TEST(BalsaHeaders, SetNoTransferEncodingByRemoveHeader) { + // Tests that calling Remove() methods to clear the Transfer-Encoding + // header correctly resets transfer_encoding_is_chunked_ internal state. + BalsaHeaders headers; + headers.SetTransferEncodingToChunkedAndClearContentLength(); + headers.RemoveAllOfHeader("Transfer-Encoding"); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); + + headers.SetTransferEncodingToChunkedAndClearContentLength(); + std::vector headers_to_remove; + headers_to_remove.emplace_back("Transfer-Encoding"); + headers.RemoveAllOfHeaderInList(headers_to_remove); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); + + headers.SetTransferEncodingToChunkedAndClearContentLength(); + headers.RemoveAllHeadersWithPrefix("Transfer"); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); +} + +TEST(BalsaHeaders, ClearContentLength) { + // Test that ClearContentLength() removes the content-length header and + // resets content_length_status(). + BalsaHeaders headers; + headers.SetContentLength(10); + headers.ClearContentLength(); + EXPECT_FALSE(headers.HasHeader("Content-length")); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_FALSE(headers.content_length_valid()); + + // Clear it again; nothing should change. + headers.ClearContentLength(); + EXPECT_FALSE(headers.HasHeader("Content-length")); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_FALSE(headers.content_length_valid()); + + // Set chunked encoding and test that ClearContentLength() has no effect. + headers.SetTransferEncodingToChunkedAndClearContentLength(); + headers.ClearContentLength(); + EXPECT_EQ("chunked", headers.GetAllOfHeaderAsString("Transfer-Encoding")); + EXPECT_TRUE(headers.transfer_encoding_is_chunked()); + BalsaHeaders::const_header_lines_iterator iter = + headers.GetHeaderPosition("Transfer-Encoding"); + EXPECT_EQ(headers.lines().begin(), iter); + EXPECT_EQ(headers.lines().end(), ++iter); + + // Remove chunked encoding, and verify that the state is the same as after + // ClearContentLength(). + headers.SetNoTransferEncoding(); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_FALSE(headers.content_length_valid()); +} + +TEST(BalsaHeaders, ClearContentLengthByRemoveHeader) { + // Test that calling Remove() methods to clear the content-length header + // correctly resets internal content length fields. + BalsaHeaders headers; + headers.SetContentLength(10); + headers.RemoveAllOfHeader("Content-Length"); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_EQ(0u, headers.content_length()); + EXPECT_FALSE(headers.content_length_valid()); + + headers.SetContentLength(11); + std::vector headers_to_remove; + headers_to_remove.emplace_back("Content-Length"); + headers.RemoveAllOfHeaderInList(headers_to_remove); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_EQ(0u, headers.content_length()); + EXPECT_FALSE(headers.content_length_valid()); + + headers.SetContentLength(12); + headers.RemoveAllHeadersWithPrefix("Content"); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_EQ(0u, headers.content_length()); + EXPECT_FALSE(headers.content_length_valid()); +} + +// Chunk-encoding an identity-coded BalsaHeaders removes the identity-coding. +TEST(BalsaHeaders, IdentityCodingToChunked) { + std::string message = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: identity\r\n\r\n"; + BalsaHeaders headers; + BalsaFrame balsa_frame; + balsa_frame.set_is_request(false); + balsa_frame.set_balsa_headers(&headers); + EXPECT_EQ(message.size(), + balsa_frame.ProcessInput(message.data(), message.size())); + + EXPECT_TRUE(headers.is_framed_by_connection_close()); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); + EXPECT_THAT(headers.GetAllOfHeader("Transfer-Encoding"), + ElementsAre("identity")); + + headers.SetTransferEncodingToChunkedAndClearContentLength(); + + EXPECT_FALSE(headers.is_framed_by_connection_close()); + EXPECT_TRUE(headers.transfer_encoding_is_chunked()); + EXPECT_THAT(headers.GetAllOfHeader("Transfer-Encoding"), + ElementsAre("chunked")); +} + +TEST(BalsaHeaders, SwitchContentLengthToChunk) { + // Test that a header originally with content length header is correctly + // switched to using chunk encoding. + BalsaHeaders headers; + headers.SetContentLength(10); + EXPECT_THAT(headers.GetHeader("Content-length"), StrEq("10")); + EXPECT_EQ(BalsaHeadersEnums::VALID_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_TRUE(headers.content_length_valid()); + + headers.SetTransferEncodingToChunkedAndClearContentLength(); + EXPECT_EQ("chunked", headers.GetAllOfHeaderAsString("Transfer-Encoding")); + EXPECT_TRUE(headers.transfer_encoding_is_chunked()); + EXPECT_FALSE(headers.HasHeader("Content-length")); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_FALSE(headers.content_length_valid()); +} + +TEST(BalsaHeaders, SwitchChunkedToContentLength) { + // Test that a header originally with chunk encoding is correctly + // switched to using content length. + BalsaHeaders headers; + headers.SetTransferEncodingToChunkedAndClearContentLength(); + EXPECT_EQ("chunked", headers.GetAllOfHeaderAsString("Transfer-Encoding")); + EXPECT_TRUE(headers.transfer_encoding_is_chunked()); + EXPECT_FALSE(headers.HasHeader("Content-length")); + EXPECT_EQ(BalsaHeadersEnums::NO_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_FALSE(headers.content_length_valid()); + + headers.SetContentLength(10); + EXPECT_THAT(headers.GetHeader("Content-length"), StrEq("10")); + EXPECT_EQ(BalsaHeadersEnums::VALID_CONTENT_LENGTH, + headers.content_length_status()); + EXPECT_TRUE(headers.content_length_valid()); + EXPECT_FALSE(headers.HasHeader("Transfer-Encoding")); + EXPECT_FALSE(headers.transfer_encoding_is_chunked()); +} + +TEST(BalsaHeaders, OneHundredResponseMessagesNoFramedByClose) { + BalsaHeaders headers; + headers.SetResponseFirstline("HTTP/1.1", 100, "Continue"); + EXPECT_FALSE(headers.is_framed_by_connection_close()); +} + +TEST(BalsaHeaders, TwoOhFourResponseMessagesNoFramedByClose) { + BalsaHeaders headers; + headers.SetResponseFirstline("HTTP/1.1", 204, "Continue"); + EXPECT_FALSE(headers.is_framed_by_connection_close()); +} + +TEST(BalsaHeaders, ThreeOhFourResponseMessagesNoFramedByClose) { + BalsaHeaders headers; + headers.SetResponseFirstline("HTTP/1.1", 304, "Continue"); + EXPECT_FALSE(headers.is_framed_by_connection_close()); +} + +TEST(BalsaHeaders, InvalidCharInHeaderValue) { + std::string message = + "GET http://www.256.com/foo HTTP/1.1\r\n" + "Host: \x01\x01www.265.com\r\n" + "\r\n"; + BalsaHeaders headers = CreateHTTPHeaders(true, message); + EXPECT_EQ("www.265.com", headers.GetHeader("Host")); + SimpleBuffer buffer; + headers.WriteHeaderAndEndingToBuffer(&buffer); + message.replace(message.find_first_of(0x1), 2, ""); + EXPECT_EQ(message, buffer.GetReadableRegion()); +} + +TEST(BalsaHeaders, CarriageReturnAtStartOfLine) { + std::string message = + "GET /foo HTTP/1.1\r\n" + "Host: www.265.com\r\n" + "Foo: bar\r\n" + "\rX-User-Ip: 1.2.3.4\r\n" + "\r\n"; + BalsaHeaders headers; + BalsaFrame balsa_frame; + balsa_frame.set_is_request(true); + balsa_frame.set_balsa_headers(&headers); + EXPECT_EQ(message.size(), + balsa_frame.ProcessInput(message.data(), message.size())); + EXPECT_EQ(BalsaFrameEnums::INVALID_HEADER_FORMAT, balsa_frame.ErrorCode()); + EXPECT_TRUE(balsa_frame.Error()); +} + +TEST(BalsaHeaders, CheckEmpty) { + BalsaHeaders headers; + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST(BalsaHeaders, CheckNonEmpty) { + BalsaHeaders headers; + BalsaHeadersTestPeer::WriteFromFramer(&headers, "a b c", 5); + EXPECT_FALSE(headers.IsEmpty()); +} + +TEST(BalsaHeaders, ForEachHeader) { + BalsaHeaders headers; + headers.AppendHeader(":host", "SomeHost"); + headers.AppendHeader("key", "val1,val2val2,val2,val3"); + headers.AppendHeader("key", "val4val5val6"); + headers.AppendHeader("key", "val11 val12"); + headers.AppendHeader("key", "v val13"); + headers.AppendHeader("key", "val7"); + headers.AppendHeader("key", ""); + headers.AppendHeader("key", "val8 , val9 ,, val10"); + headers.AppendHeader("key", " val14 "); + headers.AppendHeader("key2", "val15"); + headers.AppendHeader("key", "Val16"); + headers.AppendHeader("key", "foo, Val17, bar"); + headers.AppendHeader("date", "2 Jan 1970"); + headers.AppendHeader("AcceptEncoding", "MyFavoriteEncoding"); + + { + std::string result; + EXPECT_TRUE(headers.ForEachHeader( + [&result](const absl::string_view key, absl::string_view value) { + result.append("<") + .append(key.data(), key.size()) + .append("> = <") + .append(value.data(), value.size()) + .append(">\n"); + return true; + })); + + EXPECT_EQ(result, + "<:host> = \n" + " = \n" + " = \n" + " = \n" + " = \n" + " = \n" + " = <>\n" + " = \n" + " = < val14 >\n" + " = \n" + " = \n" + " = \n" + " = <2 Jan 1970>\n" + " = \n"); + } + + { + std::string result; + EXPECT_FALSE(headers.ForEachHeader( + [&result](const absl::string_view key, absl::string_view value) { + result.append("<") + .append(key.data(), key.size()) + .append("> = <") + .append(value.data(), value.size()) + .append(">\n"); + return !value.empty(); + })); + + EXPECT_EQ(result, + "<:host> = \n" + " = \n" + " = \n" + " = \n" + " = \n" + " = \n" + " = <>\n"); + } +} + +TEST(BalsaHeaders, WriteToBufferWithLowerCasedHeaderKey) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("Key1", "value1"); + headers.AppendHeader("Key2", "value2"); + std::string expected_lower_case = + "GET / HTTP/1.0\r\n" + "key1: value1\r\n" + "key2: value2\r\n"; + std::string expected_lower_case_with_end = + "GET / HTTP/1.0\r\n" + "key1: value1\r\n" + "key2: value2\r\n\r\n"; + std::string expected_upper_case = + "GET / HTTP/1.0\r\n" + "Key1: value1\r\n" + "Key2: value2\r\n"; + std::string expected_upper_case_with_end = + "GET / HTTP/1.0\r\n" + "Key1: value1\r\n" + "Key2: value2\r\n\r\n"; + + SimpleBuffer simple_buffer; + headers.WriteToBuffer(&simple_buffer, BalsaHeaders::CaseOption::kLowercase, + BalsaHeaders::CoalesceOption::kNoCoalesce); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected_lower_case)); + + simple_buffer.Clear(); + headers.WriteToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), StrEq(expected_upper_case)); + + simple_buffer.Clear(); + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_THAT(simple_buffer.GetReadableRegion(), + StrEq(expected_upper_case_with_end)); + + simple_buffer.Clear(); + headers.WriteHeaderAndEndingToBuffer( + &simple_buffer, BalsaHeaders::CaseOption::kLowercase, + BalsaHeaders::CoalesceOption::kNoCoalesce); + EXPECT_THAT(simple_buffer.GetReadableRegion(), + StrEq(expected_lower_case_with_end)); +} + +TEST(BalsaHeaders, WriteToBufferWithProperCasedHeaderKey) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("Te", "value1"); + headers.AppendHeader("my-Test-header", "value2"); + std::string expected_proper_case = + "GET / HTTP/1.0\r\n" + "TE: value1\r\n" + "My-Test-Header: value2\r\n"; + std::string expected_proper_case_with_end = + "GET / HTTP/1.0\r\n" + "TE: value1\r\n" + "My-Test-Header: value2\r\n\r\n"; + std::string expected_unmodified = + "GET / HTTP/1.0\r\n" + "Te: value1\r\n" + "my-Test-header: value2\r\n"; + std::string expected_unmodified_with_end = + "GET / HTTP/1.0\r\n" + "Te: value1\r\n" + "my-Test-header: value2\r\n\r\n"; + + SimpleBuffer simple_buffer; + headers.WriteToBuffer(&simple_buffer, BalsaHeaders::CaseOption::kPropercase, + BalsaHeaders::CoalesceOption::kNoCoalesce); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_proper_case); + + simple_buffer.Clear(); + headers.WriteToBuffer(&simple_buffer, + BalsaHeaders::CaseOption::kNoModification, + BalsaHeaders::CoalesceOption::kNoCoalesce); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_unmodified); + + simple_buffer.Clear(); + headers.WriteHeaderAndEndingToBuffer( + &simple_buffer, BalsaHeaders::CaseOption::kNoModification, + BalsaHeaders::CoalesceOption::kNoCoalesce); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_unmodified_with_end); + + simple_buffer.Clear(); + headers.WriteHeaderAndEndingToBuffer( + &simple_buffer, BalsaHeaders::CaseOption::kPropercase, + BalsaHeaders::CoalesceOption::kNoCoalesce); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_proper_case_with_end); +} + +TEST(BalsaHeadersTest, ToPropercaseTest) { + EXPECT_EQ(BalsaHeaders::ToPropercase(""), ""); + EXPECT_EQ(BalsaHeaders::ToPropercase("Foo"), "Foo"); + EXPECT_EQ(BalsaHeaders::ToPropercase("foO"), "Foo"); + EXPECT_EQ(BalsaHeaders::ToPropercase("my-test-header"), "My-Test-Header"); + EXPECT_EQ(BalsaHeaders::ToPropercase("my--test-header"), "My--Test-Header"); +} + +TEST(BalsaHeaders, WriteToBufferCoalescingMultivaluedHeaders) { + BalsaHeaders::MultivaluedHeadersSet multivalued_headers; + multivalued_headers.insert("KeY1"); + multivalued_headers.insert("another_KEY"); + + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("Key1", "value1"); + headers.AppendHeader("Key2", "value2"); + headers.AppendHeader("Key1", "value11"); + headers.AppendHeader("Key2", "value21"); + headers.AppendHeader("Key1", "multiples, values, already"); + std::string expected_non_coalesced = + "GET / HTTP/1.0\r\n" + "Key1: value1\r\n" + "Key2: value2\r\n" + "Key1: value11\r\n" + "Key2: value21\r\n" + "Key1: multiples, values, already\r\n"; + std::string expected_coalesced = + "Key1: value1,value11,multiples, values, already\r\n" + "Key2: value2\r\n" + "Key2: value21\r\n"; + + SimpleBuffer simple_buffer; + headers.WriteToBuffer(&simple_buffer); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_non_coalesced); + + simple_buffer.Clear(); + headers.WriteToBufferCoalescingMultivaluedHeaders( + &simple_buffer, multivalued_headers, + BalsaHeaders::CaseOption::kNoModification); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_coalesced); +} + +TEST(BalsaHeaders, WriteToBufferCoalescingMultivaluedHeadersMultiLine) { + BalsaHeaders::MultivaluedHeadersSet multivalued_headers; + multivalued_headers.insert("Key 2"); + multivalued_headers.insert("key\n 3"); + + BalsaHeaders headers; + headers.AppendHeader("key1", "value1"); + headers.AppendHeader("key 2", "value\n 2"); + headers.AppendHeader("key\n 3", "value3"); + headers.AppendHeader("key 2", "value 21"); + headers.AppendHeader("key 3", "value 33"); + std::string expected_non_coalesced = + "\r\n" + "key1: value1\r\n" + "key 2: value\n" + " 2\r\n" + "key\n" + " 3: value3\r\n" + "key 2: value 21\r\n" + "key 3: value 33\r\n"; + + SimpleBuffer simple_buffer; + headers.WriteToBuffer(&simple_buffer); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_non_coalesced); + + std::string expected_coalesced = + "key1: value1\r\n" + "key 2: value\n" + " 2,value 21\r\n" + "key\n" + " 3: value3\r\n" + "key 3: value 33\r\n"; + + simple_buffer.Clear(); + headers.WriteToBufferCoalescingMultivaluedHeaders( + &simple_buffer, multivalued_headers, + BalsaHeaders::CaseOption::kNoModification); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_coalesced); +} + +TEST(BalsaHeaders, WriteToBufferCoalescingEnvoyHeaders) { + BalsaHeaders headers; + headers.SetRequestFirstlineFromStringPieces("GET", "/", "HTTP/1.0"); + headers.AppendHeader("User-Agent", "UserAgent1"); + headers.AppendHeader("Key2", "value2"); + headers.AppendHeader("USER-AGENT", "UA2"); + headers.AppendHeader("Set-Cookie", "Cookie1=aaa"); + headers.AppendHeader("user-agent", "agent3"); + headers.AppendHeader("Set-Cookie", "Cookie2=bbb"); + std::string expected_non_coalesced = + "GET / HTTP/1.0\r\n" + "User-Agent: UserAgent1\r\n" + "Key2: value2\r\n" + "USER-AGENT: UA2\r\n" + "Set-Cookie: Cookie1=aaa\r\n" + "user-agent: agent3\r\n" + "Set-Cookie: Cookie2=bbb\r\n" + "\r\n"; + std::string expected_coalesced = + "GET / HTTP/1.0\r\n" + "User-Agent: UserAgent1,UA2,agent3\r\n" + "Key2: value2\r\n" + "Set-Cookie: Cookie1=aaa\r\n" + "Set-Cookie: Cookie2=bbb\r\n" + "\r\n"; + + SimpleBuffer simple_buffer; + headers.WriteHeaderAndEndingToBuffer(&simple_buffer); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_non_coalesced); + + simple_buffer.Clear(); + headers.WriteHeaderAndEndingToBuffer( + &simple_buffer, BalsaHeaders::CaseOption::kNoModification, + BalsaHeaders::CoalesceOption::kCoalesce); + EXPECT_EQ(simple_buffer.GetReadableRegion(), expected_coalesced); +} + +TEST(BalsaHeadersTest, RemoveLastTokenFromOneLineHeader) { + BalsaHeaders headers = + CreateHTTPHeaders(true, + "GET /foo HTTP/1.1\r\n" + "Content-Length: 0\r\n" + "Content-Encoding: gzip, 3des, tar, prc\r\n\r\n"); + + BalsaHeaders::const_header_lines_key_iterator it = + headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip, 3des, tar, prc", it->second); + EXPECT_EQ(headers.header_lines_key_end(), ++it); + + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + it = headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip, 3des, tar", it->second); + EXPECT_EQ(headers.header_lines_key_end(), ++it); + + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + it = headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip, 3des", it->second); + EXPECT_EQ(headers.header_lines_key_end(), ++it); + + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + it = headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip", it->second); + EXPECT_EQ(headers.header_lines_key_end(), ++it); + + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + + EXPECT_FALSE(headers.HasHeader("Content-Encoding")); +} + +TEST(BalsaHeadersTest, RemoveLastTokenFromMultiLineHeader) { + BalsaHeaders headers = + CreateHTTPHeaders(true, + "GET /foo HTTP/1.1\r\n" + "Content-Length: 0\r\n" + "Content-Encoding: gzip, 3des\r\n" + "Content-Encoding: tar, prc\r\n\r\n"); + + BalsaHeaders::const_header_lines_key_iterator it = + headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip, 3des", it->second); + ASSERT_EQ("tar, prc", (++it)->second); + ASSERT_EQ(headers.header_lines_key_end(), ++it); + + // First, we should start removing tokens from the second line. + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + it = headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip, 3des", it->second); + ASSERT_EQ("tar", (++it)->second); + ASSERT_EQ(headers.header_lines_key_end(), ++it); + + // Second line should be entirely removed after all its tokens are gone. + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + it = headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip, 3des", it->second); + ASSERT_EQ(headers.header_lines_key_end(), ++it); + + // Now we should be removing the tokens from the first line. + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + it = headers.GetIteratorForKey("Content-Encoding"); + ASSERT_EQ("gzip", it->second); + ASSERT_EQ(headers.header_lines_key_end(), ++it); + + headers.RemoveLastTokenFromHeaderValue("Content-Encoding"); + EXPECT_FALSE(headers.HasHeader("Content-Encoding")); +} + +TEST(BalsaHeadersTest, ResponseCanHaveBody) { + // 1xx, 204 no content and 304 not modified responses can't have bodies. + EXPECT_FALSE(BalsaHeaders::ResponseCanHaveBody(100)); + EXPECT_FALSE(BalsaHeaders::ResponseCanHaveBody(101)); + EXPECT_FALSE(BalsaHeaders::ResponseCanHaveBody(102)); + EXPECT_FALSE(BalsaHeaders::ResponseCanHaveBody(204)); + EXPECT_FALSE(BalsaHeaders::ResponseCanHaveBody(304)); + + // Other responses can have body. + EXPECT_TRUE(BalsaHeaders::ResponseCanHaveBody(200)); + EXPECT_TRUE(BalsaHeaders::ResponseCanHaveBody(302)); + EXPECT_TRUE(BalsaHeaders::ResponseCanHaveBody(404)); + EXPECT_TRUE(BalsaHeaders::ResponseCanHaveBody(502)); +} + +} // namespace + +} // namespace test + +} // namespace quiche diff --git a/quiche/balsa/balsa_visitor_interface.h b/quiche/balsa/balsa_visitor_interface.h new file mode 100644 index 000000000000..e2c4327d12e6 --- /dev/null +++ b/quiche/balsa/balsa_visitor_interface.h @@ -0,0 +1,165 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_BALSA_VISITOR_INTERFACE_H_ +#define QUICHE_BALSA_BALSA_VISITOR_INTERFACE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/balsa/balsa_enums.h" +#include "quiche/balsa/balsa_headers.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// By default the BalsaFrame instantiates a class derived from this interface +// that does absolutely nothing. If you'd prefer to have interesting +// functionality execute when any of the below functions are called by the +// BalsaFrame, then you should subclass it, and set an instantiation of your +// subclass as the current visitor for the BalsaFrame class using +// BalsaFrame::set_visitor(). +class QUICHE_EXPORT BalsaVisitorInterface { + public: + virtual ~BalsaVisitorInterface() {} + + // Summary: + // This is how the BalsaFrame passes you the raw input that it knows to be a + // part of the body. To be clear, every byte of the Balsa that isn't part of + // the header (or its framing), or trailers will be passed through this + // function. This includes data as well as chunking framing. + // Arguments: + // input - the raw input that is part of the body. + virtual void OnRawBodyInput(absl::string_view input) = 0; + + // Summary: + // This is like OnRawBodyInput, but it will only include those parts of the + // body that would be stored by a program such as wget, i.e. the bytes + // indicating chunking will have been removed. Trailers will not be passed + // in through this function-- they'll be passed in through OnTrailerInput. + // Arguments: + // input - the part of the body. + virtual void OnBodyChunkInput(absl::string_view input) = 0; + + // Summary: + // BalsaFrame passes the raw header data through this function. This is not + // cleaned up in any way. + // Arguments: + // input - raw header data. + virtual void OnHeaderInput(absl::string_view input) = 0; + + // Summary: + // BalsaFrame passes each header through this function as soon as it is + // parsed. + // Argument: + // key - the header name. + // value - the associated header value. + virtual void OnHeader(absl::string_view key, absl::string_view value) = 0; + + // Summary: + // BalsaFrame passes the raw trailer data through this function. This is not + // cleaned up in any way. Note that trailers only occur in a message if + // there was a chunked encoding, and not always then. + // Arguments: + // input - raw trailer data. + virtual void OnTrailerInput(absl::string_view input) = 0; + + // Summary: + // Since the BalsaFrame already has to parse the headers in order to + // determine proper framing, it might as well pass the parsed and cleaned-up + // results to whatever might need it. This function exists for that + // purpose-- parsed headers are passed into this function. + // Arguments: + // headers - contains the parsed headers in the order in which + // they occurred in the header. + virtual void ProcessHeaders(const BalsaHeaders& headers) = 0; + + // Summary: + // Since the BalsaFrame already has to parse the trailer, it might as well + // pass the parsed and cleaned-up results to whatever might need it. This + // function exists for that purpose-- parsed trailer is passed into this + // function. This will not be called if the trailer_ object is not set in + // the framer, even if trailer exists in request/response. + // Arguments: + // trailer - contains the parsed headers in the order in which + // they occurred in the trailer. + virtual void ProcessTrailers(const BalsaHeaders& trailer) = 0; + + // Summary: + // Called when the first line of the message is parsed, in this case, for a + // request. + // Arguments: + // line_input - the first line string, + // method_input - the method substring, + // request_uri_input - request uri substring, + // version_input - the version substring. + virtual void OnRequestFirstLineInput(absl::string_view line_input, + absl::string_view method_input, + absl::string_view request_uri, + absl::string_view version_input) = 0; + + // Summary: + // Called when the first line of the message is parsed, in this case, for a + // response. + // Arguments: + // line_input - the first line string, + // version_input - the version substring, + // status_input - the status substring, + // reason_input - the reason substring. + virtual void OnResponseFirstLineInput(absl::string_view line_input, + absl::string_view version_input, + absl::string_view status_input, + absl::string_view reason_input) = 0; + + // Summary: + // Called when a chunk length is parsed. + // Arguments: + // chunk length - the length of the next incoming chunk. + virtual void OnChunkLength(size_t chunk_length) = 0; + + // Summary: + // BalsaFrame passes the raw chunk extension data through this function. + // The data is not cleaned up at all. + // Arguments: + // input - contains the bytes available for read. + virtual void OnChunkExtensionInput(absl::string_view input) = 0; + + // Summary: + // Called when an interim response (response code 1xx) is framed and + // processed. This callback is mutually exclusive with ContinueHeaderDone(). + // Arguments: + // headers - contains the parsed headers in the order in which they occurred + // in the interim response. + virtual void OnInterimHeaders(BalsaHeaders headers) = 0; + + // Summary: + // Called when the 100 Continue headers are framed and processed. This + // callback is mutually exclusive with OnInterimHeaders(). + // TODO(b/68801833): Remove this and update the OnInterimHeaders() comment. + virtual void ContinueHeaderDone() = 0; + + // Summary: + // Called when the header is framed and processed. + virtual void HeaderDone() = 0; + + // Summary: + // Called when the message is framed and processed. + virtual void MessageDone() = 0; + + // Summary: + // Called when an error is detected + // Arguments: + // error_code - the error which is to be reported. + virtual void HandleError(BalsaFrameEnums::ErrorCode error_code) = 0; + + // Summary: + // Called when something meriting a warning is detected + // Arguments: + // error_code - the warning which is to be reported. + virtual void HandleWarning(BalsaFrameEnums::ErrorCode error_code) = 0; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_BALSA_VISITOR_INTERFACE_H_ diff --git a/quiche/balsa/framer_interface.h b/quiche/balsa/framer_interface.h new file mode 100644 index 000000000000..fdb0f7d9b026 --- /dev/null +++ b/quiche/balsa/framer_interface.h @@ -0,0 +1,24 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_FRAMER_INTERFACE_H_ +#define QUICHE_BALSA_FRAMER_INTERFACE_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// A minimal interface supported by BalsaFrame and other framer types. For use +// in HttpReader. +class QUICHE_EXPORT FramerInterface { + public: + virtual ~FramerInterface() {} + virtual size_t ProcessInput(const char* input, size_t length) = 0; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_FRAMER_INTERFACE_H_ diff --git a/quiche/balsa/header_api.h b/quiche/balsa/header_api.h new file mode 100644 index 000000000000..889c984f5fdd --- /dev/null +++ b/quiche/balsa/header_api.h @@ -0,0 +1,274 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_HEADER_API_H_ +#define QUICHE_BALSA_HEADER_API_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_lower_case_string.h" + +namespace quiche { + +// An API so we can reuse functions for BalsaHeaders and Envoy's HeaderMap. +// Contains only const member functions, so it can wrap const HeaderMaps; +// non-const functions are in HeaderApi. +// +// Depending on the implementation, the headers may act like HTTP/1 headers +// (BalsaHeaders) or HTTP/2 headers (HeaderMap). For HTTP-version-specific +// headers or pseudoheaders like "host" or ":authority", use this API's +// implementation-independent member functions, like Authority(). Looking those +// headers up by name is deprecated and may QUICHE_DCHECK-fail. +// For the differences between HTTP/1 and HTTP/2 headers, see RFC 7540: +// https://tools.ietf.org/html/rfc7540#section-8.1.2 +// +// Operations on header keys are case-insensitive while operations on header +// values are case-sensitive. +// +// Some methods have overloads which accept Envoy-style LowerCaseStrings. Often +// these keys are accessible from Envoy::Http::Headers::get().SomeHeader, +// already lowercaseified. It's faster to avoid converting them to and from +// lowercase. Additionally, some implementations of ConstHeaderApi might take +// advantage of a constant-time lookup for inlined headers. +class QUICHE_EXPORT ConstHeaderApi { + public: + virtual ~ConstHeaderApi() {} + + // Determine whether the headers are empty. + virtual bool IsEmpty() const = 0; + + // Returns the header entry for the first instance with key |key| + // If header isn't present, returns absl::string_view(). + virtual absl::string_view GetHeader(absl::string_view key) const = 0; + + virtual absl::string_view GetHeader(const QuicheLowerCaseString& key) const { + // Default impl for BalsaHeaders, etc. + return GetHeader(key.get()); + } + + // Collects all of the header entries with key |key| and returns them in |out| + // Headers are returned in the order they are inserted. + virtual void GetAllOfHeader(absl::string_view key, + std::vector* out) const = 0; + virtual std::vector GetAllOfHeader( + absl::string_view key) const { + std::vector out; + GetAllOfHeader(key, &out); + return out; + } + virtual void GetAllOfHeader(const QuicheLowerCaseString& key, + std::vector* out) const { + return GetAllOfHeader(key.get(), out); + } + + // Determine if a given header is present. + virtual bool HasHeader(absl::string_view key) const = 0; + + // Determines if a given header is present with non-empty value. + virtual bool HasNonEmptyHeader(absl::string_view key) const = 0; + + // Goes through all headers with key |key| and checks to see if one of the + // values is |value|. Returns true if there are headers with the desired key + // and value, false otherwise. + virtual bool HeaderHasValue(absl::string_view key, + absl::string_view value) const = 0; + + // Same as above, but value is treated as case insensitive. + virtual bool HeaderHasValueIgnoreCase(absl::string_view key, + absl::string_view value) const = 0; + + // Joins all values for header entries with `key` into a comma-separated + // string. Headers are returned in the order they are inserted. + virtual std::string GetAllOfHeaderAsString(absl::string_view key) const = 0; + virtual std::string GetAllOfHeaderAsString( + const QuicheLowerCaseString& key) const { + return GetAllOfHeaderAsString(key.get()); + } + + // Returns true if we have at least one header with given prefix + // [case insensitive]. Currently for test use only. + virtual bool HasHeadersWithPrefix(absl::string_view key) const = 0; + + // Returns the key value pairs for all headers where the header key begins + // with the specified prefix. + // Headers are returned in the order they are inserted. + virtual void GetAllOfHeaderWithPrefix( + absl::string_view prefix, + std::vector>* out) + const = 0; + + // Returns the key value pairs for all headers in this object. If 'limit' is + // >= 0, return at most 'limit' headers. + virtual void GetAllHeadersWithLimit( + std::vector>* out, + int limit) const = 0; + + // Returns a textual representation of the header object. The format of the + // string may depend on the underlying implementation. + virtual std::string DebugString() const = 0; + + // Applies the argument function to each header line. If the argument + // function returns false, iteration stops and ForEachHeader returns false; + // otherwise, ForEachHeader returns true. + virtual bool ForEachHeader(std::function + fn) const = 0; + + // Returns the upper bound byte size of the headers. This can be used to size + // a Buffer when serializing headers. + virtual size_t GetSizeForWriteBuffer() const = 0; + + // Returns the response code for response headers. If no status code exists, + // the return value is implementation-specific. + virtual absl::string_view response_code() const = 0; + + // Returns the response code for response headers or 0 if no status code + // exists. + virtual size_t parsed_response_code() const = 0; + + // Returns the response reason phrase; the stored one for HTTP/1 headers, or a + // phrase determined from the response code for HTTP/2 headers.. + virtual absl::string_view response_reason_phrase() const = 0; + + // Return the HTTP first line of this request, generally of the format: + // GET /path/ HTTP/1.1 + // TODO(b/110421449): deprecate this method. + virtual std::string first_line_of_request() const = 0; + + // Return the method for this request, such as GET or POST. + virtual absl::string_view request_method() const = 0; + + // Return the request URI from the first line of this request, such as + // "/path/". + virtual absl::string_view request_uri() const = 0; + + // Return the version portion of the first line of this request, such as + // "HTTP/1.1". + // TODO(b/110421449): deprecate this method. + virtual absl::string_view request_version() const = 0; + + virtual absl::string_view response_version() const = 0; + + // Returns the authority portion of a request, or an empty string if missing. + // This is the value of the host header for HTTP/1 headers and the value of + // the :authority pseudo-header for HTTP/2 headers. + virtual absl::string_view Authority() const = 0; + + // Call the provided function on the cookie, avoiding + // copies if possible. The cookie is the value of the Cookie header; for + // HTTP/2 headers, if there are multiple Cookie headers, they will be joined + // by "; ", per go/rfc/7540#section-8.1.2.5. If there is no Cookie header, + // cookie.data() will be nullptr. The lifetime of the cookie isn't guaranteed + // to extend beyond this call. + virtual void ApplyToCookie( + std::function f) const = 0; + + virtual size_t content_length() const = 0; + virtual bool content_length_valid() const = 0; + + // TODO(b/118501626): Add functions for working with other headers and + // pseudo-headers whose presence or value depends on HTTP version, including: + // :method, :scheme, :path, connection, and cookie. +}; + +// An API so we can reuse functions for BalsaHeaders and Envoy's HeaderMap. +// Inherits const functions from ConstHeaderApi and adds non-const functions, +// for use with non-const HeaderMaps. +// +// For HTTP-version-specific headers and pseudo-headers, the same caveats apply +// as with ConstHeaderApi. +// +// Operations on header keys are case-insensitive while operations on header +// values are case-sensitive. +class QUICHE_EXPORT HeaderApi : public virtual ConstHeaderApi { + public: + // Replaces header entries with key |key| if they exist, or appends + // a new header if none exist. + virtual void ReplaceOrAppendHeader(absl::string_view key, + absl::string_view value) = 0; + + // Removes all headers in given set of |keys| at once + virtual void RemoveAllOfHeaderInList( + const std::vector& keys) = 0; + + // Removes all headers with key |key|. + virtual void RemoveAllOfHeader(absl::string_view key) = 0; + + // Append a new header entry to the header object with key |key| and value + // |value|. + virtual void AppendHeader(absl::string_view key, absl::string_view value) = 0; + + // Removes all headers starting with 'key' [case insensitive] + virtual void RemoveAllHeadersWithPrefix(absl::string_view key) = 0; + + // Appends ',value' to an existing header named 'key'. If no header with the + // correct key exists, it will call AppendHeader(key, value). Calling this + // function on a key which exists several times in the headers will produce + // unpredictable results. + virtual void AppendToHeader(absl::string_view key, + absl::string_view value) = 0; + + // Appends ', value' to an existing header named 'key'. If no header with the + // correct key exists, it will call AppendHeader(key, value). Calling this + // function on a key which exists several times in the headers will produce + // unpredictable results. + virtual void AppendToHeaderWithCommaAndSpace(absl::string_view key, + absl::string_view value) = 0; + + // Set the header or pseudo-header corresponding to the authority portion of a + // request: host for HTTP/1 headers, or :authority for HTTP/2 headers. + virtual void ReplaceOrAppendAuthority(absl::string_view value) = 0; + virtual void RemoveAuthority() = 0; + + // These set portions of the first line for HTTP/1 headers, or the + // corresponding pseudo-headers for HTTP/2 headers. + virtual void SetRequestMethod(absl::string_view method) = 0; + virtual void SetResponseCode(absl::string_view code) = 0; + // As SetResponseCode, but slightly faster for BalsaHeaders if the caller + // represents the response code as an integer and not a string. + virtual void SetParsedResponseCodeAndUpdateFirstline( + size_t parsed_response_code) = 0; + + // Sets the request URI. + // + // For HTTP/1 headers, sets the request URI portion of the first line (the + // second token). Doesn't parse the URI; leaves the Host header unchanged. + // + // For HTTP/2 headers, sets the :path pseudo-header, and also :scheme and + // :authority if they're present in the URI; otherwise, leaves :scheme and + // :authority unchanged. + // + // The caller is responsible for verifying that the URI is in a valid format. + virtual void SetRequestUri(absl::string_view uri) = 0; + + // These are only meaningful for HTTP/1 headers; for HTTP/2 headers, they do + // nothing. + virtual void SetRequestVersion(absl::string_view version) = 0; + virtual void SetResponseVersion(absl::string_view version) = 0; + virtual void SetResponseReasonPhrase(absl::string_view reason_phrase) = 0; + + // SetContentLength, SetTransferEncodingToChunkedAndClearContentLength, and + // SetNoTransferEncoding modifies the header object to use + // content-length and transfer-encoding headers in a consistent + // manner. They set all internal flags and status, if applicable, so client + // can get a consistent view from various accessors. + virtual void SetContentLength(size_t length) = 0; + // Sets transfer-encoding to chunked and updates internal state. + virtual void SetTransferEncodingToChunkedAndClearContentLength() = 0; + // Removes transfer-encoding headers and updates internal state. + virtual void SetNoTransferEncoding() = 0; + + // If true, QUICHE_BUG if a header that starts with an invalid prefix is + // explicitly set. Not implemented for Envoy headers; can only be set false. + virtual void set_enforce_header_policy(bool enforce) = 0; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_HEADER_API_H_ diff --git a/quiche/balsa/header_properties.cc b/quiche/balsa/header_properties.cc new file mode 100644 index 000000000000..240979c12e1f --- /dev/null +++ b/quiche/balsa/header_properties.cc @@ -0,0 +1,111 @@ +#include "quiche/balsa/header_properties.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quiche::header_properties { + +namespace { + +using MultivaluedHeadersSet = + absl::flat_hash_set; + +MultivaluedHeadersSet* buildMultivaluedHeaders() { + return new MultivaluedHeadersSet({ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + // The follow four headers are all CORS standard headers + "access-control-allow-headers", + "access-control-allow-methods", + "access-control-expose-headers", + "access-control-request-headers", + "allow", + "cache-control", + // IETF draft makes this have cache-control syntax + "cdn-cache-control", + "connection", + "content-encoding", + "content-language", + "expect", + "if-match", + "if-none-match", + // See RFC 5988 section 5 + "link", + "pragma", + "proxy-authenticate", + "te", + // Used in the opening handshake of the WebSocket protocol. + "sec-websocket-extensions", + // Not mentioned in RFC 2616, but it can have multiple values. + "set-cookie", + "trailer", + "transfer-encoding", + "upgrade", + "vary", + "via", + "warning", + "www-authenticate", + // De facto standard not in the RFCs + "x-forwarded-for", + // Internal Google usage gives this cache-control syntax + "x-go" /**/ "ogle-cache-control", + }); +} + +std::array buildInvalidHeaderKeyCharLookupTable() { + std::array invalidCharTable; + invalidCharTable.fill(false); + for (uint8_t c : kInvalidHeaderKeyCharList) { + invalidCharTable[c] = true; + } + return invalidCharTable; +} + +std::array buildInvalidCharLookupTable() { + std::array invalidCharTable; + invalidCharTable.fill(false); + for (uint8_t c : kInvalidHeaderCharList) { + invalidCharTable[c] = true; + } + return invalidCharTable; +} + +} // anonymous namespace + +bool IsMultivaluedHeader(absl::string_view header) { + static const MultivaluedHeadersSet* const multivalued_headers = + buildMultivaluedHeaders(); + return multivalued_headers->contains(header); +} + +bool IsInvalidHeaderKeyChar(uint8_t c) { + static const std::array invalidHeaderKeyCharTable = + buildInvalidHeaderKeyCharLookupTable(); + + return invalidHeaderKeyCharTable[c]; +} + +bool IsInvalidHeaderChar(uint8_t c) { + static const std::array invalidCharTable = + buildInvalidCharLookupTable(); + + return invalidCharTable[c]; +} + +bool HasInvalidHeaderChars(absl::string_view value) { + for (const char c : value) { + if (IsInvalidHeaderChar(c)) { + return true; + } + } + return false; +} + +} // namespace quiche::header_properties diff --git a/quiche/balsa/header_properties.h b/quiche/balsa/header_properties.h new file mode 100644 index 000000000000..cbd5c596b91e --- /dev/null +++ b/quiche/balsa/header_properties.h @@ -0,0 +1,50 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_HEADER_PROPERTIES_H_ +#define QUICHE_BALSA_HEADER_PROPERTIES_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche::header_properties { + +// Returns true if RFC 2616 Section 14 (or other relevant standards or +// practices) indicates that header can have multiple values. Note that nothing +// stops clients from sending multiple values of other headers, so this may not +// be perfectly reliable in practice. +QUICHE_EXPORT bool IsMultivaluedHeader(absl::string_view header); + +// An array of characters that are invalid in HTTP header field names. +// These are control characters, including \t, \n, \r, as well as space and +// (),/;<=>?@[\]{} and \x7f (see +// https://tools.ietf.org/html/rfc7230#section-3.2.6). +inline constexpr char kInvalidHeaderKeyCharList[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, + 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, + 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, + 0x1E, 0x1F, ' ', '(', ')', ',', '/', ';', '<', '=', + '>', '?', '@', '[', '\\', ']', '{', '}', 0x7f}; + +// An array of characters that are invalid in HTTP header field values, +// according to RFC 7230 Section 3.2. Valid low characters not in this array +// are \t (0x09), \n (0x0A), and \r (0x0D). +// Note that HTTP header field names are even more restrictive. +inline constexpr char kInvalidHeaderCharList[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x0B, + 0x0C, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x7F}; + +// Returns true if the given `c` is invalid in a header field name. +QUICHE_EXPORT bool IsInvalidHeaderKeyChar(uint8_t c); +// Returns true if the given `c` is invalid in a header field or the `value` has +// invalid characters. +QUICHE_EXPORT bool IsInvalidHeaderChar(uint8_t c); +QUICHE_EXPORT bool HasInvalidHeaderChars(absl::string_view value); + +} // namespace quiche::header_properties + +#endif // QUICHE_BALSA_HEADER_PROPERTIES_H_ diff --git a/quiche/balsa/header_properties_test.cc b/quiche/balsa/header_properties_test.cc new file mode 100644 index 000000000000..ef5f8fc3bc4d --- /dev/null +++ b/quiche/balsa/header_properties_test.cc @@ -0,0 +1,80 @@ +#include "quiche/balsa/header_properties.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche::header_properties::test { +namespace { + +TEST(HeaderPropertiesTest, IsMultivaluedHeaderIsCaseInsensitive) { + EXPECT_TRUE(IsMultivaluedHeader("content-encoding")); + EXPECT_TRUE(IsMultivaluedHeader("Content-Encoding")); + EXPECT_TRUE(IsMultivaluedHeader("set-cookie")); + EXPECT_TRUE(IsMultivaluedHeader("sEt-cOOkie")); + EXPECT_TRUE(IsMultivaluedHeader("X-Goo" /**/ "gle-Cache-Control")); + EXPECT_TRUE(IsMultivaluedHeader("access-control-expose-HEADERS")); + + EXPECT_FALSE(IsMultivaluedHeader("set-cook")); + EXPECT_FALSE(IsMultivaluedHeader("content-length")); + EXPECT_FALSE(IsMultivaluedHeader("Content-Length")); +} + +TEST(HeaderPropertiesTest, IsInvalidHeaderKeyChar) { + EXPECT_TRUE(IsInvalidHeaderKeyChar(0x00)); + EXPECT_TRUE(IsInvalidHeaderKeyChar(0x06)); + EXPECT_TRUE(IsInvalidHeaderKeyChar(0x09)); + EXPECT_TRUE(IsInvalidHeaderKeyChar(0x1F)); + EXPECT_TRUE(IsInvalidHeaderKeyChar(0x7F)); + EXPECT_TRUE(IsInvalidHeaderKeyChar(' ')); + EXPECT_TRUE(IsInvalidHeaderKeyChar('\t')); + EXPECT_TRUE(IsInvalidHeaderKeyChar('\r')); + EXPECT_TRUE(IsInvalidHeaderKeyChar('\n')); + + EXPECT_FALSE(IsInvalidHeaderKeyChar('a')); + EXPECT_FALSE(IsInvalidHeaderKeyChar('B')); + EXPECT_FALSE(IsInvalidHeaderKeyChar('7')); + EXPECT_FALSE(IsInvalidHeaderKeyChar(0x42)); + EXPECT_FALSE(IsInvalidHeaderChar(0x7D)); +} + +TEST(HeaderPropertiesTest, IsInvalidHeaderChar) { + EXPECT_TRUE(IsInvalidHeaderChar(0x00)); + EXPECT_TRUE(IsInvalidHeaderChar(0x06)); + EXPECT_TRUE(IsInvalidHeaderChar(0x1F)); + EXPECT_TRUE(IsInvalidHeaderChar(0x7F)); + + EXPECT_FALSE(IsInvalidHeaderChar(0x09)); + EXPECT_FALSE(IsInvalidHeaderChar(' ')); + EXPECT_FALSE(IsInvalidHeaderChar('\t')); + EXPECT_FALSE(IsInvalidHeaderChar('\r')); + EXPECT_FALSE(IsInvalidHeaderChar('\n')); + EXPECT_FALSE(IsInvalidHeaderChar('a')); + EXPECT_FALSE(IsInvalidHeaderChar('B')); + EXPECT_FALSE(IsInvalidHeaderChar('7')); + EXPECT_FALSE(IsInvalidHeaderChar(0x42)); + EXPECT_FALSE(IsInvalidHeaderChar(0x7D)); +} + +TEST(HeaderPropertiesTest, KeyMoreRestrictiveThanValue) { + for (int c = 0; c < 255; ++c) { + if (IsInvalidHeaderChar(c)) { + EXPECT_TRUE(IsInvalidHeaderKeyChar(c)) << c; + } + } +} + +TEST(HeaderPropertiesTest, HasInvalidHeaderChars) { + const char with_null[] = "Here's l\x00king at you, kid"; + EXPECT_TRUE(HasInvalidHeaderChars(std::string(with_null, sizeof(with_null)))); + EXPECT_TRUE(HasInvalidHeaderChars("Why's \x06 afraid of \x07? \x07\x08\x09")); + EXPECT_TRUE(HasInvalidHeaderChars("\x1Flower power")); + EXPECT_TRUE(HasInvalidHeaderChars("\x7Flowers more powers")); + + EXPECT_FALSE(HasInvalidHeaderChars("Plenty of space")); + EXPECT_FALSE(HasInvalidHeaderChars("Keeping \tabs")); + EXPECT_FALSE(HasInvalidHeaderChars("Al\right")); + EXPECT_FALSE(HasInvalidHeaderChars("\new day")); + EXPECT_FALSE(HasInvalidHeaderChars("\x42 is a nice character")); +} + +} // namespace +} // namespace quiche::header_properties::test diff --git a/quiche/balsa/http_validation_policy.h b/quiche/balsa/http_validation_policy.h new file mode 100644 index 000000000000..725f48ee7ebf --- /dev/null +++ b/quiche/balsa/http_validation_policy.h @@ -0,0 +1,36 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_HTTP_VALIDATION_POLICY_H_ +#define QUICHE_BALSA_HTTP_VALIDATION_POLICY_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// An HttpValidationPolicy captures policy choices affecting parsing of HTTP +// requests. It offers individual Boolean members to be consulted during the +// parsing of an HTTP request. +struct QUICHE_EXPORT HttpValidationPolicy { + // https://tools.ietf.org/html/rfc7230#section-3.2.4 deprecates "folding" + // of long header lines onto continuation lines. + bool disallow_header_continuation_lines = false; + + // A valid header line requires a header name and a colon. + bool require_header_colon = false; + + // https://tools.ietf.org/html/rfc7230#section-3.3.2 disallows multiple + // Content-Length header fields with the same value. + bool disallow_multiple_content_length = false; + + // https://tools.ietf.org/html/rfc7230#section-3.3.2 disallows + // Transfer-Encoding and Content-Length header fields together. + bool disallow_transfer_encoding_with_content_length = false; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_HTTP_VALIDATION_POLICY_H_ diff --git a/quiche/balsa/noop_balsa_visitor.h b/quiche/balsa/noop_balsa_visitor.h new file mode 100644 index 000000000000..ce82d58bb2f0 --- /dev/null +++ b/quiche/balsa/noop_balsa_visitor.h @@ -0,0 +1,59 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_NOOP_BALSA_VISITOR_H_ +#define QUICHE_BALSA_NOOP_BALSA_VISITOR_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/balsa/balsa_visitor_interface.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +class BalsaHeaders; + +// Provides empty BalsaVisitorInterface overrides for convenience. +// Intended to be used as a base class for BalsaVisitorInterface subclasses that +// only need to override a small number of methods. +class QUICHE_EXPORT NoOpBalsaVisitor : public BalsaVisitorInterface { + public: + NoOpBalsaVisitor() = default; + + NoOpBalsaVisitor(const NoOpBalsaVisitor&) = delete; + NoOpBalsaVisitor& operator=(const NoOpBalsaVisitor&) = delete; + + ~NoOpBalsaVisitor() override {} + + void OnRawBodyInput(absl::string_view /*input*/) override {} + void OnBodyChunkInput(absl::string_view /*input*/) override {} + void OnHeaderInput(absl::string_view /*input*/) override {} + void OnHeader(absl::string_view /*key*/, + absl::string_view /*value*/) override {} + void OnTrailerInput(absl::string_view /*input*/) override {} + void ProcessHeaders(const BalsaHeaders& /*headers*/) override {} + void ProcessTrailers(const BalsaHeaders& /*trailer*/) override {} + + void OnRequestFirstLineInput(absl::string_view /*line_input*/, + absl::string_view /*method_input*/, + absl::string_view /*request_uri_input*/, + absl::string_view /*version_input*/) override {} + void OnResponseFirstLineInput(absl::string_view /*line_input*/, + absl::string_view /*version_input*/, + absl::string_view /*status_input*/, + absl::string_view /*reason_input*/) override {} + void OnChunkLength(size_t /*chunk_length*/) override {} + void OnChunkExtensionInput(absl::string_view /*input*/) override {} + void OnInterimHeaders(BalsaHeaders /*headers*/) override {} + void ContinueHeaderDone() override {} + void HeaderDone() override {} + void MessageDone() override {} + void HandleError(BalsaFrameEnums::ErrorCode /*error_code*/) override {} + void HandleWarning(BalsaFrameEnums::ErrorCode /*error_code*/) override {} +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_NOOP_BALSA_VISITOR_H_ diff --git a/quiche/balsa/simple_buffer.cc b/quiche/balsa/simple_buffer.cc new file mode 100644 index 000000000000..756c7da9e8ea --- /dev/null +++ b/quiche/balsa/simple_buffer.cc @@ -0,0 +1,152 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/simple_buffer.h" + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +constexpr int kMinimumSimpleBufferSize = 10; + +SimpleBuffer::SimpleBuffer(int size) { Reserve(size); } + +//////////////////////////////////////////////////////////////////////////////// + +int SimpleBuffer::Write(const char* bytes, int size) { + if (size <= 0) { + QUICHE_BUG_IF(simple_buffer_write_negative_size, size < 0) + << "size must not be negative: " << size; + return 0; + } + + Reserve(size); + memcpy(storage_ + write_idx_, bytes, size); + AdvanceWritablePtr(size); + return size; +} + +//////////////////////////////////////////////////////////////////////////////// + +int SimpleBuffer::Read(char* bytes, int size) { + if (size < 0) { + QUICHE_BUG(simple_buffer_read_negative_size) + << "size must not be negative: " << size; + return 0; + } + + char* read_ptr = nullptr; + int read_size = 0; + GetReadablePtr(&read_ptr, &read_size); + read_size = std::min(read_size, size); + if (read_size == 0) { + return 0; + } + + memcpy(bytes, read_ptr, read_size); + AdvanceReadablePtr(read_size); + return read_size; +} + +//////////////////////////////////////////////////////////////////////////////// + +// Attempts to reserve a contiguous block of buffer space either by reclaiming +// consumed data or by allocating a larger buffer. +void SimpleBuffer::Reserve(int size) { + if (size < 0) { + QUICHE_BUG(simple_buffer_reserve_negative_size) + << "size must not be negative: " << size; + return; + } + + if (size == 0 || storage_size_ - write_idx_ >= size) { + return; + } + + char* read_ptr = nullptr; + int read_size = 0; + GetReadablePtr(&read_ptr, &read_size); + + if (read_ptr == nullptr) { + QUICHE_DCHECK_EQ(0, read_size); + + size = std::max(size, kMinimumSimpleBufferSize); + storage_ = new char[size]; + storage_size_ = size; + return; + } + + if (read_size + size <= storage_size_) { + // Can reclaim space from consumed bytes by shifting. + memmove(storage_, read_ptr, read_size); + read_idx_ = 0; + write_idx_ = read_size; + return; + } + + // The new buffer needs to be at least `read_size + size` bytes. + // At least double the buffer to amortize allocation costs. + storage_size_ = std::max(2 * storage_size_, size + read_size); + + char* new_storage = new char[storage_size_]; + memcpy(new_storage, read_ptr, read_size); + delete[] storage_; + + read_idx_ = 0; + write_idx_ = read_size; + storage_ = new_storage; +} + +void SimpleBuffer::AdvanceReadablePtr(int amount_to_advance) { + if (amount_to_advance < 0) { + QUICHE_BUG(simple_buffer_advance_read_negative_arg) + << "amount_to_advance must not be negative: " << amount_to_advance; + return; + } + + read_idx_ += amount_to_advance; + if (read_idx_ > write_idx_) { + QUICHE_BUG(simple_buffer_read_ptr_too_far) + << "error: readable pointer advanced beyond writable one"; + read_idx_ = write_idx_; + } + + if (read_idx_ == write_idx_) { + // Buffer is empty, rewind `read_idx_` and `write_idx_` so that next write + // happens at the beginning of buffer instead of cutting free space in two. + Clear(); + } +} + +void SimpleBuffer::AdvanceWritablePtr(int amount_to_advance) { + if (amount_to_advance < 0) { + QUICHE_BUG(simple_buffer_advance_write_negative_arg) + << "amount_to_advance must not be negative: " << amount_to_advance; + return; + } + + write_idx_ += amount_to_advance; + if (write_idx_ > storage_size_) { + QUICHE_BUG(simple_buffer_write_ptr_too_far) + << "error: writable pointer advanced beyond end of storage"; + write_idx_ = storage_size_; + } +} + +QuicheMemSlice SimpleBuffer::ReleaseAsSlice() { + if (write_idx_ == 0) { + return QuicheMemSlice(); + } + QuicheMemSlice slice(std::unique_ptr(storage_), write_idx_); + Clear(); + storage_ = nullptr; + storage_size_ = 0; + return slice; +} +} // namespace quiche diff --git a/quiche/balsa/simple_buffer.h b/quiche/balsa/simple_buffer.h new file mode 100644 index 000000000000..96fd25f13e3d --- /dev/null +++ b/quiche/balsa/simple_buffer.h @@ -0,0 +1,118 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_SIMPLE_BUFFER_H_ +#define QUICHE_BALSA_SIMPLE_BUFFER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quiche { + +namespace test { +class SimpleBufferTest; +} // namespace test + +// SimpleBuffer stores data in a contiguous region. It can grow on demand, +// which involves moving its data. It keeps track of a read and a write +// position. Reading consumes data. +class QUICHE_EXPORT SimpleBuffer { + public: + SimpleBuffer() = default; + // Create SimpleBuffer with at least `size` reserved capacity. + explicit SimpleBuffer(int size); + + SimpleBuffer(const SimpleBuffer&) = delete; + SimpleBuffer& operator=(const SimpleBuffer&) = delete; + + virtual ~SimpleBuffer() { delete[] storage_; } + + // Returns the number of bytes that can be read from the buffer. + int ReadableBytes() const { return write_idx_ - read_idx_; } + + bool Empty() const { return read_idx_ == write_idx_; } + + // Copies `size` bytes to the buffer. Returns size. + int Write(const char* bytes, int size); + int WriteString(absl::string_view piece) { + return Write(piece.data(), piece.size()); + } + + // Stores the pointer into the buffer that can be written to in `*ptr`, and + // the number of characters that are allowed to be written in `*size`. The + // pointer and size can be used in functions like recv() or read(). If + // `*size` is zero upon returning from this function, then it is unsafe to + // dereference `*ptr`. Writing to this region after calling any other + // non-const method results in undefined behavior. + void GetWritablePtr(char** ptr, int* size) const { + *ptr = storage_ + write_idx_; + *size = storage_size_ - write_idx_; + } + + // Stores the pointer that can be read from in `*ptr`, and the number of bytes + // that are allowed to be read in `*size`. The pointer and size can be used + // in functions like send() or write(). If `*size` is zero upon returning + // from this function, then it is unsafe to dereference `*ptr`. Reading from + // this region after calling any other non-const method results in undefined + // behavior. + void GetReadablePtr(char** ptr, int* size) const { + *ptr = storage_ + read_idx_; + *size = write_idx_ - read_idx_; + } + + // Returns the readable region as a string_view. Reading from this region + // after calling any other non-const method results in undefined behavior. + absl::string_view GetReadableRegion() const { + return absl::string_view(storage_ + read_idx_, write_idx_ - read_idx_); + } + + // Reads bytes out of the buffer, and writes them into `bytes`. Returns the + // number of bytes read. Consumes bytes from the buffer. + int Read(char* bytes, int size); + + // Marks all data consumed, making the entire reserved buffer available for + // write. Does not resize or free up any memory. + void Clear() { read_idx_ = write_idx_ = 0; } + + // Makes sure at least `size` bytes can be written into the buffer. This can + // be an expensive operation: costing a new and a delete, and copying of all + // existing data. Even if the existing buffer does not need to be resized, + // unread data may need to be moved to consolidate fragmented free space. + void Reserve(int size); + + // Marks the oldest `amount_to_advance` bytes as consumed. + // `amount_to_advance` must not be negative and it must not exceed + // ReadableBytes(). + void AdvanceReadablePtr(int amount_to_advance); + + // Marks the first `amount_to_advance` bytes of the writable area written. + // `amount_to_advance` must not be negative and it must not exceed the size of + // the writable area, returned as the `size` outparam of GetWritablePtr(). + void AdvanceWritablePtr(int amount_to_advance); + + // Releases the current contents of the SimpleBuffer and returns them as a + // MemSlice. Logically, has the same effect as calling Clear(). + QuicheMemSlice ReleaseAsSlice(); + + private: + friend class test::SimpleBufferTest; + + // The buffer owned by this class starts at `*storage_` and is `storage_size_` + // bytes long. + // If `storage_` is nullptr, then `storage_size_` must be zero. + // `0 <= read_idx_ <= write_idx_ <= storage_size_` must always hold. + // If `read_idx_ == write_idx_`, then they must be equal to zero. + // The first `read_idx_` bytes of the buffer are consumed, + // the next `write_idx_ - read_idx_` bytes are the readable region, and the + // remaining `storage_size_ - write_idx_` bytes are the writable region. + char* storage_ = nullptr; + int write_idx_ = 0; + int read_idx_ = 0; + int storage_size_ = 0; +}; + +} // namespace quiche + +#endif // QUICHE_BALSA_SIMPLE_BUFFER_H_ diff --git a/quiche/balsa/simple_buffer_test.cc b/quiche/balsa/simple_buffer_test.cc new file mode 100644 index 000000000000..0f5bbbbf2979 --- /dev/null +++ b/quiche/balsa/simple_buffer_test.cc @@ -0,0 +1,411 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/simple_buffer.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { + +namespace test { + +namespace { + +constexpr int kMinimumSimpleBufferSize = 10; + +// Buffer full of 40 char strings. +const char ibuf[] = { + "123456789!@#$%^&*()abcdefghijklmnopqrstu" + "123456789!@#$%^&*()abcdefghijklmnopqrstu" + "123456789!@#$%^&*()abcdefghijklmnopqrstu" + "123456789!@#$%^&*()abcdefghijklmnopqrstu" + "123456789!@#$%^&*()abcdefghijklmnopqrstu"}; + +} // namespace + +class SimpleBufferTest : public QuicheTest { + public: + static char* storage(SimpleBuffer& buffer) { return buffer.storage_; } + static int write_idx(SimpleBuffer& buffer) { return buffer.write_idx_; } + static int read_idx(SimpleBuffer& buffer) { return buffer.read_idx_; } + static int storage_size(SimpleBuffer& buffer) { return buffer.storage_size_; } +}; + +namespace { + +TEST_F(SimpleBufferTest, CreationWithSize) { + SimpleBuffer buffer1(5); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer1)); + + SimpleBuffer buffer2(25); + EXPECT_EQ(25, storage_size(buffer2)); +} + +// Make sure that a zero-sized initial buffer does not throw things off. +TEST_F(SimpleBufferTest, CreationWithZeroSize) { + SimpleBuffer buffer(0); + EXPECT_EQ(0, storage_size(buffer)); + EXPECT_EQ(4, buffer.Write(ibuf, 4)); + EXPECT_EQ(4, write_idx(buffer)); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + EXPECT_EQ(4, buffer.ReadableBytes()); +} + +TEST_F(SimpleBufferTest, ReadZeroBytes) { + SimpleBuffer buffer; + + EXPECT_EQ(0, buffer.Read(nullptr, 0)); +} + +TEST_F(SimpleBufferTest, WriteZeroFromNullptr) { + SimpleBuffer buffer; + + EXPECT_EQ(0, buffer.Write(nullptr, 0)); +} + +TEST(SimpleBufferExpectBug, ReserveNegativeSize) { + SimpleBuffer buffer; + + EXPECT_QUICHE_BUG(buffer.Reserve(-1), "size must not be negative"); +} + +TEST(SimpleBufferExpectBug, ReadNegativeSize) { + SimpleBuffer buffer; + + EXPECT_QUICHE_BUG(buffer.Read(nullptr, -1), "size must not be negative"); +} + +TEST(SimpleBufferExpectBug, WriteNegativeSize) { + SimpleBuffer buffer; + + EXPECT_QUICHE_BUG(buffer.Write(nullptr, -1), "size must not be negative"); +} + +TEST_F(SimpleBufferTest, Basics) { + SimpleBuffer buffer; + + EXPECT_TRUE(buffer.Empty()); + EXPECT_EQ("", buffer.GetReadableRegion()); + EXPECT_EQ(0, storage_size(buffer)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + + char* readable_ptr = nullptr; + int readable_size = 0; + buffer.GetReadablePtr(&readable_ptr, &readable_size); + char* writeable_ptr = nullptr; + int writable_size = 0; + buffer.GetWritablePtr(&writeable_ptr, &writable_size); + + EXPECT_EQ(storage(buffer), readable_ptr); + EXPECT_EQ(0, readable_size); + EXPECT_EQ(storage(buffer), writeable_ptr); + EXPECT_EQ(0, writable_size); + EXPECT_EQ(0, buffer.ReadableBytes()); + + const SimpleBuffer buffer2; + EXPECT_EQ(0, buffer2.ReadableBytes()); +} + +TEST_F(SimpleBufferTest, BasicWR) { + SimpleBuffer buffer; + + EXPECT_EQ(4, buffer.Write(ibuf, 4)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(4, write_idx(buffer)); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + EXPECT_EQ(4, buffer.ReadableBytes()); + EXPECT_EQ("1234", buffer.GetReadableRegion()); + int bytes_written = 4; + EXPECT_TRUE(!buffer.Empty()); + + char* readable_ptr = nullptr; + int readable_size = 0; + buffer.GetReadablePtr(&readable_ptr, &readable_size); + char* writeable_ptr = nullptr; + int writable_size = 0; + buffer.GetWritablePtr(&writeable_ptr, &writable_size); + + EXPECT_EQ(storage(buffer), readable_ptr); + EXPECT_EQ(4, readable_size); + EXPECT_EQ(storage(buffer) + 4, writeable_ptr); + EXPECT_EQ(6, writable_size); + + char obuf[ABSL_ARRAYSIZE(ibuf)]; + int bytes_read = 0; + EXPECT_EQ(4, buffer.Read(obuf + bytes_read, 40)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + EXPECT_EQ(0, buffer.ReadableBytes()); + EXPECT_EQ("", buffer.GetReadableRegion()); + bytes_read += 4; + EXPECT_TRUE(buffer.Empty()); + buffer.GetReadablePtr(&readable_ptr, &readable_size); + buffer.GetWritablePtr(&writeable_ptr, &writable_size); + EXPECT_EQ(storage(buffer), readable_ptr); + EXPECT_EQ(0, readable_size); + EXPECT_EQ(storage(buffer), writeable_ptr); + EXPECT_EQ(kMinimumSimpleBufferSize, writable_size); + + EXPECT_EQ(bytes_written, bytes_read); + for (int i = 0; i < bytes_read; ++i) { + EXPECT_EQ(obuf[i], ibuf[i]); + } + + // More R/W tests. + EXPECT_EQ(10, buffer.Write(ibuf + bytes_written, 10)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(10, write_idx(buffer)); + EXPECT_EQ(10, storage_size(buffer)); + EXPECT_EQ(10, buffer.ReadableBytes()); + bytes_written += 10; + + EXPECT_TRUE(!buffer.Empty()); + + EXPECT_EQ(6, buffer.Read(obuf + bytes_read, 6)); + EXPECT_EQ(6, read_idx(buffer)); + EXPECT_EQ(10, write_idx(buffer)); + EXPECT_EQ(10, storage_size(buffer)); + EXPECT_EQ(4, buffer.ReadableBytes()); + bytes_read += 6; + + EXPECT_TRUE(!buffer.Empty()); + + EXPECT_EQ(4, buffer.Read(obuf + bytes_read, 7)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + EXPECT_EQ(10, storage_size(buffer)); + EXPECT_EQ(0, buffer.ReadableBytes()); + bytes_read += 4; + + EXPECT_TRUE(buffer.Empty()); + + EXPECT_EQ(bytes_written, bytes_read); + for (int i = 0; i < bytes_read; ++i) { + EXPECT_EQ(obuf[i], ibuf[i]); + } +} + +TEST_F(SimpleBufferTest, Reserve) { + SimpleBuffer buffer; + EXPECT_EQ(0, storage_size(buffer)); + + buffer.WriteString("foo"); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + + // Reserve by expanding the buffer. + buffer.Reserve(kMinimumSimpleBufferSize + 1); + EXPECT_EQ(2 * kMinimumSimpleBufferSize, storage_size(buffer)); + + buffer.Clear(); + buffer.AdvanceWritablePtr(kMinimumSimpleBufferSize); + buffer.AdvanceReadablePtr(kMinimumSimpleBufferSize - 2); + EXPECT_EQ(kMinimumSimpleBufferSize, write_idx(buffer)); + EXPECT_EQ(2 * kMinimumSimpleBufferSize, storage_size(buffer)); + + // Reserve by moving data around. `storage_size` does not change. + buffer.Reserve(kMinimumSimpleBufferSize + 1); + EXPECT_EQ(2, write_idx(buffer)); + EXPECT_EQ(2 * kMinimumSimpleBufferSize, storage_size(buffer)); +} + +TEST_F(SimpleBufferTest, Extend) { + SimpleBuffer buffer; + + // Test a write which should not extend the buffer. + EXPECT_EQ(7, buffer.Write(ibuf, 7)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(7, write_idx(buffer)); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + EXPECT_EQ(7, buffer.ReadableBytes()); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(7, write_idx(buffer)); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + EXPECT_EQ(7, buffer.ReadableBytes()); + int bytes_written = 7; + + // Test a write which should extend the buffer. + EXPECT_EQ(4, buffer.Write(ibuf + bytes_written, 4)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(11, write_idx(buffer)); + EXPECT_EQ(20, storage_size(buffer)); + EXPECT_EQ(11, buffer.ReadableBytes()); + bytes_written += 4; + + char obuf[ABSL_ARRAYSIZE(ibuf)]; + EXPECT_EQ(11, buffer.Read(obuf, 11)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + EXPECT_EQ(20, storage_size(buffer)); + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + EXPECT_EQ(0, buffer.ReadableBytes()); + + const int bytes_read = 11; + EXPECT_EQ(bytes_written, bytes_read); + for (int i = 0; i < bytes_read; ++i) { + EXPECT_EQ(obuf[i], ibuf[i]); + } +} + +TEST_F(SimpleBufferTest, Clear) { + SimpleBuffer buffer; + + buffer.Clear(); + + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + EXPECT_EQ(0, storage_size(buffer)); + EXPECT_EQ(0, buffer.ReadableBytes()); + + buffer.WriteString("foo"); + buffer.Clear(); + + EXPECT_EQ(0, read_idx(buffer)); + EXPECT_EQ(0, write_idx(buffer)); + EXPECT_EQ(kMinimumSimpleBufferSize, storage_size(buffer)); + EXPECT_EQ(0, buffer.ReadableBytes()); +} + +TEST_F(SimpleBufferTest, LongWrite) { + SimpleBuffer buffer; + + std::string s1 = "HTTP/1.1 500 Service Unavailable"; + buffer.Write(s1.data(), s1.size()); + buffer.Write("\r\n", 2); + std::string key = "Connection"; + std::string value = "close"; + buffer.Write(key.data(), key.size()); + buffer.Write(": ", 2); + buffer.Write(value.data(), value.size()); + buffer.Write("\r\n", 2); + buffer.Write("\r\n", 2); + std::string message = + "\n" + "\n" + "\n" + "\n" + "\n" + "" + "\n" + "\n" + "
\n" + "" + "G" + "o" + "o" + "g" + "l" + "e" + "  \n" + " 
" + " Error
 
\n" + "
\n" + "

Internal Server Error

\n" + " This server was unable to complete the request\n" + "

\n" + "" + "" + "
\"\"
" + "\n"; + buffer.Write(message.data(), message.size()); + const std::string correct_result = + "HTTP/1.1 500 Service Unavailable\r\n" + "Connection: close\r\n" + "\r\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "" + "\n" + "\n" + "
\n" + "" + "G" + "o" + "o" + "g" + "l" + "e" + "  \n" + " 
" + " Error
 
\n" + "
\n" + "

Internal Server Error

\n" + " This server was unable to complete the request\n" + "

\n" + "" + "" + "
\"\"
" + "\n"; + EXPECT_EQ(correct_result, buffer.GetReadableRegion()); +} + +TEST_F(SimpleBufferTest, ReleaseAsSlice) { + SimpleBuffer buffer; + + buffer.WriteString("abc"); + QuicheMemSlice slice = buffer.ReleaseAsSlice(); + EXPECT_EQ("abc", slice.AsStringView()); + + char* readable_ptr = nullptr; + int readable_size = 0; + buffer.GetReadablePtr(&readable_ptr, &readable_size); + EXPECT_EQ(nullptr, readable_ptr); + EXPECT_EQ(0, readable_size); + + buffer.WriteString("def"); + slice = buffer.ReleaseAsSlice(); + buffer.GetReadablePtr(&readable_ptr, &readable_size); + EXPECT_EQ(nullptr, readable_ptr); + EXPECT_EQ(0, readable_size); + EXPECT_EQ("def", slice.AsStringView()); +} + +TEST_F(SimpleBufferTest, EmptyBufferReleaseAsSlice) { + SimpleBuffer buffer; + char* readable_ptr = nullptr; + int readable_size = 0; + + QuicheMemSlice slice = buffer.ReleaseAsSlice(); + buffer.GetReadablePtr(&readable_ptr, &readable_size); + EXPECT_EQ(nullptr, readable_ptr); + EXPECT_EQ(0, readable_size); + EXPECT_TRUE(slice.empty()); +} + +} // namespace + +} // namespace test + +} // namespace quiche diff --git a/quiche/balsa/standard_header_map.cc b/quiche/balsa/standard_header_map.cc new file mode 100644 index 000000000000..f1c5f4bab708 --- /dev/null +++ b/quiche/balsa/standard_header_map.cc @@ -0,0 +1,143 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/balsa/standard_header_map.h" + +namespace quiche { + +const StandardHttpHeaderNameSet& GetStandardHeaderSet() { + static const StandardHttpHeaderNameSet* const header_map = + new StandardHttpHeaderNameSet({ + {"Accept"}, + {"Accept-Charset"}, + {"Accept-CH"}, + {"Accept-CH-Lifetime"}, + {"Accept-Encoding"}, + {"Accept-Language"}, + {"Accept-Ranges"}, + {"Access-Control-Allow-Credentials"}, + {"Access-Control-Allow-Headers"}, + {"Access-Control-Allow-Methods"}, + {"Access-Control-Allow-Origin"}, + {"Access-Control-Expose-Headers"}, + {"Access-Control-Max-Age"}, + {"Access-Control-Request-Headers"}, + {"Access-Control-Request-Method"}, + {"Age"}, + {"Allow"}, + {"Authorization"}, + {"Cache-Control"}, + {"Connection"}, + {"Content-Disposition"}, + {"Content-Encoding"}, + {"Content-Language"}, + {"Content-Length"}, + {"Content-Location"}, + {"Content-Range"}, + {"Content-Security-Policy"}, + {"Content-Security-Policy-Report-Only"}, + {"X-Content-Security-Policy"}, + {"X-Content-Security-Policy-Report-Only"}, + {"X-WebKit-CSP"}, + {"X-WebKit-CSP-Report-Only"}, + {"Content-Type"}, + {"Content-MD5"}, + {"X-Content-Type-Options"}, + {"Cookie"}, + {"Cookie2"}, + {"Cross-Origin-Resource-Policy"}, + {"Cross-Origin-Opener-Policy"}, + {"Date"}, + {"DAV"}, + {"Depth"}, + {"Destination"}, + {"DNT"}, + {"DPR"}, + {"Early-Data"}, + {"ETag"}, + {"Expect"}, + {"Expires"}, + {"Follow-Only-When-Prerender-Shown"}, + {"Forwarded"}, + {"From"}, + {"Host"}, + {"HTTP2-Settings"}, + {"If"}, + {"If-Match"}, + {"If-Modified-Since"}, + {"If-None-Match"}, + {"If-Range"}, + {"If-Unmodified-Since"}, + {"Keep-Alive"}, + {"Label"}, + {"Last-Modified"}, + {"Link"}, + {"Location"}, + {"Lock-Token"}, + {"Max-Forwards"}, + {"MS-Author-Via"}, + {"Origin"}, + {"Overwrite"}, + {"P3P"}, + {"Ping-From"}, + {"Ping-To"}, + {"Pragma"}, + {"Proxy-Connection"}, + {"Proxy-Authenticate"}, + {"Public-Key-Pins"}, + {"Public-Key-Pins-Report-Only"}, + {"Range"}, + {"Referer"}, + {"Referrer-Policy"}, + {"Refresh"}, + {"Report-To"}, + {"Retry-After"}, + {"Sec-Fetch-Dest"}, + {"Sec-Fetch-Mode"}, + {"Sec-Fetch-Site"}, + {"Sec-Fetch-User"}, + {"Sec-Metadata"}, + {"Sec-Token-Binding"}, + {"Sec-Provided-Token-Binding-ID"}, + {"Sec-Referred-Token-Binding-ID"}, + {"Sec-WebSocket-Accept"}, + {"Sec-WebSocket-Extensions"}, + {"Sec-WebSocket-Key"}, + {"Sec-WebSocket-Protocol"}, + {"Sec-WebSocket-Version"}, + {"Server"}, + {"Server-Timing"}, + {"Service-Worker"}, + {"Service-Worker-Allowed"}, + {"Service-Worker-Navigation-Preload"}, + {"Set-Cookie"}, + {"Set-Cookie2"}, + {"Status-URI"}, + {"Strict-Transport-Security"}, + {"SourceMap"}, + {"Timeout"}, + {"Timing-Allow-Origin"}, + {"Tk"}, + {"Trailer"}, + {"Trailers"}, + {"Transfer-Encoding"}, + {"TE"}, + {"Upgrade"}, + {"Upgrade-Insecure-Requests"}, + {"User-Agent"}, + {"X-OperaMini-Phone-UA"}, + {"X-UCBrowser-UA"}, + {"X-UCBrowser-Device-UA"}, + {"X-Device-User-Agent"}, + {"Vary"}, + {"Via"}, + {"CDN-Loop"}, + {"Warning"}, + {"WWW-Authenticate"}, + }); + + return *header_map; +} + +} // namespace quiche diff --git a/quiche/balsa/standard_header_map.h b/quiche/balsa/standard_header_map.h new file mode 100644 index 000000000000..d4a67df9920f --- /dev/null +++ b/quiche/balsa/standard_header_map.h @@ -0,0 +1,24 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BALSA_STANDARD_HEADER_MAP_H_ +#define QUICHE_BALSA_STANDARD_HEADER_MAP_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quiche { + +// This specifies an absl::flat_hash_set with case-insensitive lookup and +// hashing +using StandardHttpHeaderNameSet = + absl::flat_hash_set; + +const StandardHttpHeaderNameSet& GetStandardHeaderSet(); + +} // namespace quiche + +#endif // QUICHE_BALSA_STANDARD_HEADER_MAP_H_ diff --git a/quiche/binary_http/binary_http_message.cc b/quiche/binary_http/binary_http_message.cc new file mode 100644 index 000000000000..6d333ac2c566 --- /dev/null +++ b/quiche/binary_http/binary_http_message.cc @@ -0,0 +1,454 @@ +#include "quiche/binary_http/binary_http_message.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_data_writer.h" + +namespace quiche { +namespace { + +constexpr uint8_t kKnownLengthRequestFraming = 0; +constexpr uint8_t kKnownLengthResponseFraming = 1; + +bool ReadStringValue(quiche::QuicheDataReader& reader, std::string& data) { + absl::string_view data_view; + if (!reader.ReadStringPieceVarInt62(&data_view)) { + return false; + } + data = std::string(data_view); + return true; +} + +bool IsValidPadding(absl::string_view data) { + return std::all_of(data.begin(), data.end(), + [](char c) { return c == '\0'; }); +} + +absl::StatusOr DecodeControlData( + quiche::QuicheDataReader& reader) { + BinaryHttpRequest::ControlData control_data; + if (!ReadStringValue(reader, control_data.method)) { + return absl::InvalidArgumentError("Failed to read method."); + } + if (!ReadStringValue(reader, control_data.scheme)) { + return absl::InvalidArgumentError("Failed to read scheme."); + } + if (!ReadStringValue(reader, control_data.authority)) { + return absl::InvalidArgumentError("Failed to read authority."); + } + if (!ReadStringValue(reader, control_data.path)) { + return absl::InvalidArgumentError("Failed to read path."); + } + return control_data; +} + +absl::Status DecodeFields( + quiche::QuicheDataReader& reader, + const std::function& + callback) { + absl::string_view fields; + if (!reader.ReadStringPieceVarInt62(&fields)) { + return absl::InvalidArgumentError("Failed to read fields."); + } + quiche::QuicheDataReader fields_reader(fields); + while (!fields_reader.IsDoneReading()) { + absl::string_view name; + if (!fields_reader.ReadStringPieceVarInt62(&name)) { + return absl::InvalidArgumentError("Failed to read field name."); + } + absl::string_view value; + if (!fields_reader.ReadStringPieceVarInt62(&value)) { + return absl::InvalidArgumentError("Failed to read field value."); + } + callback(name, value); + } + return absl::OkStatus(); +} + +absl::Status DecodeFieldsAndBody(quiche::QuicheDataReader& reader, + BinaryHttpMessage& message) { + if (const absl::Status status = DecodeFields( + reader, + [&message](absl::string_view name, absl::string_view value) { + message.AddHeaderField({std::string(name), std::string(value)}); + }); + !status.ok()) { + return status; + } + // TODO(bschneider): Handle case where remaining message is truncated. + // Skip it on encode as well. + // https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html#name-padding-and-truncation + absl::string_view body; + if (!reader.ReadStringPieceVarInt62(&body)) { + return absl::InvalidArgumentError("Failed to read body."); + } + message.set_body(std::string(body)); + // TODO(bschneider): Check for / read-in any trailer-fields + return absl::OkStatus(); +} + +absl::StatusOr DecodeKnownLengthRequest( + quiche::QuicheDataReader& reader) { + const auto control_data = DecodeControlData(reader); + if (!control_data.ok()) { + return control_data.status(); + } + BinaryHttpRequest request(std::move(*control_data)); + if (const absl::Status status = DecodeFieldsAndBody(reader, request); + !status.ok()) { + return status; + } + if (!IsValidPadding(reader.PeekRemainingPayload())) { + return absl::InvalidArgumentError("Non-zero padding."); + } + request.set_num_padding_bytes(reader.BytesRemaining()); + return request; +} + +absl::StatusOr DecodeKnownLengthResponse( + quiche::QuicheDataReader& reader) { + std::vector>> + informational_responses; + uint64_t status_code; + bool reading_response_control_data = true; + while (reading_response_control_data) { + if (!reader.ReadVarInt62(&status_code)) { + return absl::InvalidArgumentError("Failed to read status code."); + } + if (status_code >= 100 && status_code <= 199) { + std::vector fields; + if (const absl::Status status = DecodeFields( + reader, + [&fields](absl::string_view name, absl::string_view value) { + fields.push_back({std::string(name), std::string(value)}); + }); + !status.ok()) { + return status; + } + informational_responses.emplace_back(status_code, std::move(fields)); + } else { + reading_response_control_data = false; + } + } + BinaryHttpResponse response(status_code); + for (const auto& informational_response : informational_responses) { + if (const absl::Status status = response.AddInformationalResponse( + informational_response.first, + std::move(informational_response.second)); + !status.ok()) { + return status; + } + } + if (const absl::Status status = DecodeFieldsAndBody(reader, response); + !status.ok()) { + return status; + } + if (!IsValidPadding(reader.PeekRemainingPayload())) { + return absl::InvalidArgumentError("Non-zero padding."); + } + response.set_num_padding_bytes(reader.BytesRemaining()); + return response; +} + +uint64_t StringPieceVarInt62Len(absl::string_view s) { + return quiche::QuicheDataWriter::GetVarInt62Len(s.length()) + s.length(); +} +} // namespace + +void BinaryHttpMessage::Fields::AddField(BinaryHttpMessage::Field field) { + fields_.push_back(std::move(field)); +} + +// Encode fields in the order they were initially inserted. +// Updates do not change order. +absl::Status BinaryHttpMessage::Fields::Encode( + quiche::QuicheDataWriter& writer) const { + if (!writer.WriteVarInt62(EncodedFieldsSize())) { + return absl::InvalidArgumentError("Failed to write encoded field size."); + } + for (const BinaryHttpMessage::Field& field : fields_) { + if (!writer.WriteStringPieceVarInt62(field.name)) { + return absl::InvalidArgumentError("Failed to write field name."); + } + if (!writer.WriteStringPieceVarInt62(field.value)) { + return absl::InvalidArgumentError("Failed to write field value."); + } + } + return absl::OkStatus(); +} + +size_t BinaryHttpMessage::Fields::EncodedSize() const { + const size_t size = EncodedFieldsSize(); + return size + quiche::QuicheDataWriter::GetVarInt62Len(size); +} + +size_t BinaryHttpMessage::Fields::EncodedFieldsSize() const { + size_t size = 0; + for (const BinaryHttpMessage::Field& field : fields_) { + size += StringPieceVarInt62Len(field.name) + + StringPieceVarInt62Len(field.value); + } + return size; +} + +BinaryHttpMessage* BinaryHttpMessage::AddHeaderField( + BinaryHttpMessage::Field field) { + const std::string lower_name = absl::AsciiStrToLower(field.name); + if (lower_name == "host") { + has_host_ = true; + } + header_fields_.AddField({std::move(lower_name), std::move(field.value)}); + return this; +} + +// Appends the encoded fields and body to data. +absl::Status BinaryHttpMessage::EncodeKnownLengthFieldsAndBody( + quiche::QuicheDataWriter& writer) const { + if (const absl::Status status = header_fields_.Encode(writer); !status.ok()) { + return status; + } + if (!writer.WriteStringPieceVarInt62(body_)) { + return absl::InvalidArgumentError("Failed to encode body."); + } + // TODO(bschneider): Consider support for trailer fields on known-length + // requests. Trailers are atypical for a known-length request. + return absl::OkStatus(); +} + +size_t BinaryHttpMessage::EncodedKnownLengthFieldsAndBodySize() const { + return header_fields_.EncodedSize() + StringPieceVarInt62Len(body_); +} + +absl::Status BinaryHttpResponse::AddInformationalResponse( + uint16_t status_code, std::vector header_fields) { + if (status_code < 100) { + return absl::InvalidArgumentError("status code < 100"); + } + if (status_code > 199) { + return absl::InvalidArgumentError("status code > 199"); + } + InformationalResponse data(status_code); + for (Field& header : header_fields) { + data.AddField(header.name, std::move(header.value)); + } + informational_response_control_data_.push_back(std::move(data)); + return absl::OkStatus(); +} + +absl::StatusOr BinaryHttpResponse::Serialize() const { + // Only supporting known length requests so far. + return EncodeAsKnownLength(); +} + +absl::StatusOr BinaryHttpResponse::EncodeAsKnownLength() const { + std::string data; + data.resize(EncodedSize()); + quiche::QuicheDataWriter writer(data.size(), data.data()); + if (!writer.WriteUInt8(kKnownLengthResponseFraming)) { + return absl::InvalidArgumentError("Failed to write framing indicator"); + } + // Informational response + for (const auto& informational : informational_response_control_data_) { + if (const absl::Status status = informational.Encode(writer); + !status.ok()) { + return status; + } + } + if (!writer.WriteVarInt62(status_code_)) { + return absl::InvalidArgumentError("Failed to write status code"); + } + if (const absl::Status status = EncodeKnownLengthFieldsAndBody(writer); + !status.ok()) { + return status; + } + QUICHE_DCHECK_EQ(writer.remaining(), num_padding_bytes()); + writer.WritePadding(); + return data; +} + +size_t BinaryHttpResponse::EncodedSize() const { + size_t size = sizeof(kKnownLengthResponseFraming); + for (const auto& informational : informational_response_control_data_) { + size += informational.EncodedSize(); + } + return size + quiche::QuicheDataWriter::GetVarInt62Len(status_code_) + + EncodedKnownLengthFieldsAndBodySize() + num_padding_bytes(); +} + +void BinaryHttpResponse::InformationalResponse::AddField(absl::string_view name, + std::string value) { + fields_.AddField({absl::AsciiStrToLower(name), std::move(value)}); +} + +// Appends the encoded fields and body to data. +absl::Status BinaryHttpResponse::InformationalResponse::Encode( + quiche::QuicheDataWriter& writer) const { + writer.WriteVarInt62(status_code_); + return fields_.Encode(writer); +} + +size_t BinaryHttpResponse::InformationalResponse::EncodedSize() const { + return quiche::QuicheDataWriter::GetVarInt62Len(status_code_) + + fields_.EncodedSize(); +} + +absl::StatusOr BinaryHttpRequest::Serialize() const { + // Only supporting known length requests so far. + return EncodeAsKnownLength(); +} + +// https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html#name-request-control-data +absl::Status BinaryHttpRequest::EncodeControlData( + quiche::QuicheDataWriter& writer) const { + if (!writer.WriteStringPieceVarInt62(control_data_.method)) { + return absl::InvalidArgumentError("Failed to encode method."); + } + if (!writer.WriteStringPieceVarInt62(control_data_.scheme)) { + return absl::InvalidArgumentError("Failed to encode scheme."); + } + // the Host header field is not replicated in the :authority field, as is + // required for ensuring that the request is reproduced accurately; see + // Section 8.1.2.3 of [H2]. + if (!has_host()) { + if (!writer.WriteStringPieceVarInt62(control_data_.authority)) { + return absl::InvalidArgumentError("Failed to encode authority."); + } + } else { + if (!writer.WriteStringPieceVarInt62("")) { + return absl::InvalidArgumentError("Failed to encode authority."); + } + } + if (!writer.WriteStringPieceVarInt62(control_data_.path)) { + return absl::InvalidArgumentError("Failed to encode path."); + } + return absl::OkStatus(); +} + +size_t BinaryHttpRequest::EncodedControlDataSize() const { + size_t size = StringPieceVarInt62Len(control_data_.method) + + StringPieceVarInt62Len(control_data_.scheme) + + StringPieceVarInt62Len(control_data_.path); + if (!has_host()) { + size += StringPieceVarInt62Len(control_data_.authority); + } else { + size += StringPieceVarInt62Len(""); + } + return size; +} + +size_t BinaryHttpRequest::EncodedSize() const { + return sizeof(kKnownLengthRequestFraming) + EncodedControlDataSize() + + EncodedKnownLengthFieldsAndBodySize() + num_padding_bytes(); +} + +// https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html#name-known-length-messages +absl::StatusOr BinaryHttpRequest::EncodeAsKnownLength() const { + std::string data; + data.resize(EncodedSize()); + quiche::QuicheDataWriter writer(data.size(), data.data()); + if (!writer.WriteUInt8(kKnownLengthRequestFraming)) { + return absl::InvalidArgumentError("Failed to encode framing indicator."); + } + if (const absl::Status status = EncodeControlData(writer); !status.ok()) { + return status; + } + if (const absl::Status status = EncodeKnownLengthFieldsAndBody(writer); + !status.ok()) { + return status; + } + QUICHE_DCHECK_EQ(writer.remaining(), num_padding_bytes()); + writer.WritePadding(); + return data; +} + +absl::StatusOr BinaryHttpRequest::Create( + absl::string_view data) { + quiche::QuicheDataReader reader(data); + uint8_t framing; + if (!reader.ReadUInt8(&framing)) { + return absl::InvalidArgumentError("Missing framing indicator."); + } + if (framing == kKnownLengthRequestFraming) { + return DecodeKnownLengthRequest(reader); + } + return absl::UnimplementedError( + absl::StrCat("Unsupported framing type ", framing)); +} + +absl::StatusOr BinaryHttpResponse::Create( + absl::string_view data) { + quiche::QuicheDataReader reader(data); + uint8_t framing; + if (!reader.ReadUInt8(&framing)) { + return absl::InvalidArgumentError("Missing framing indicator."); + } + if (framing == kKnownLengthResponseFraming) { + return DecodeKnownLengthResponse(reader); + } + return absl::UnimplementedError( + absl::StrCat("Unsupported framing type ", framing)); +} + +std::string BinaryHttpMessage::DebugString() const { + std::vector headers; + for (const auto& field : GetHeaderFields()) { + headers.emplace_back(field.DebugString()); + } + return absl::StrCat("BinaryHttpMessage{Headers{", absl::StrJoin(headers, ";"), + "}Body{", body(), "}}"); +} + +std::string BinaryHttpMessage::Field::DebugString() const { + return absl::StrCat("Field{", name, "=", value, "}"); +} + +std::string BinaryHttpResponse::InformationalResponse::DebugString() const { + std::vector fs; + for (const auto& field : fields()) { + fs.emplace_back(field.DebugString()); + } + return absl::StrCat("InformationalResponse{", absl::StrJoin(fs, ";"), "}"); +} + +std::string BinaryHttpResponse::DebugString() const { + std::vector irs; + for (const auto& ir : informational_responses()) { + irs.emplace_back(ir.DebugString()); + } + return absl::StrCat("BinaryHttpResponse(", status_code_, "){", + BinaryHttpMessage::DebugString(), absl::StrJoin(irs, ";"), + "}"); +} + +std::string BinaryHttpRequest::DebugString() const { + return absl::StrCat("BinaryHttpRequest{", BinaryHttpMessage::DebugString(), + "}"); +} + +void PrintTo(const BinaryHttpRequest& msg, std::ostream* os) { + *os << msg.DebugString(); +} + +void PrintTo(const BinaryHttpResponse& msg, std::ostream* os) { + *os << msg.DebugString(); +} + +void PrintTo(const BinaryHttpMessage::Field& msg, std::ostream* os) { + *os << msg.DebugString(); +} + +} // namespace quiche diff --git a/quiche/binary_http/binary_http_message.h b/quiche/binary_http/binary_http_message.h new file mode 100644 index 000000000000..8f1f8f883212 --- /dev/null +++ b/quiche/binary_http/binary_http_message.h @@ -0,0 +1,291 @@ +#ifndef QUICHE_BINARY_HTTP_BINARY_HTTP_MESSAGE_H_ +#define QUICHE_BINARY_HTTP_BINARY_HTTP_MESSAGE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_data_writer.h" + +namespace quiche { + +// Supports encoding and decoding Binary Http messages. +// Currently limited to known-length messages. +// https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html +class QUICHE_EXPORT BinaryHttpMessage { + public: + // Name value pair of either a header or trailer field. + struct QUICHE_EXPORT Field { + std::string name; + std::string value; + bool operator==(const BinaryHttpMessage::Field& rhs) const { + return name == rhs.name && value == rhs.value; + } + + bool operator!=(const BinaryHttpMessage::Field& rhs) const { + return !(*this == rhs); + } + + std::string DebugString() const; + }; + virtual ~BinaryHttpMessage() = default; + + // TODO(bschneider): Switch to use existing Http2HeaderBlock + BinaryHttpMessage* AddHeaderField(Field header_field); + + const std::vector& GetHeaderFields() const { + return header_fields_.fields(); + } + + BinaryHttpMessage* set_body(std::string body) { + body_ = std::move(body); + return this; + } + + void swap_body(std::string& body) { body_.swap(body); } + void set_num_padding_bytes(size_t num_padding_bytes) { + num_padding_bytes_ = num_padding_bytes; + } + size_t num_padding_bytes() const { return num_padding_bytes_; } + + absl::string_view body() const { return body_; } + + // Returns the number of bytes `Serialize` will return, including padding. + virtual size_t EncodedSize() const = 0; + + // Returns the Binary Http formatted message. + virtual absl::StatusOr Serialize() const = 0; + // TODO(bschneider): Add AddTrailerField for chunked messages + // TODO(bschneider): Add SetBodyCallback() for chunked messages + + virtual std::string DebugString() const; + + protected: + class Fields { + public: + // Appends `field` to list of fields. Can contain duplicates. + void AddField(BinaryHttpMessage::Field field); + + const std::vector& fields() const { + return fields_; + } + + bool operator==(const BinaryHttpMessage::Fields& rhs) const { + return fields_ == rhs.fields_; + } + + // Encode fields in insertion order. + // https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html#name-header-and-trailer-field-li + absl::Status Encode(quiche::QuicheDataWriter& writer) const; + + // The number of returned by EncodedFieldsSize + // plus the number of bytes used in the varint holding that value. + size_t EncodedSize() const; + + private: + // Number of bytes of just the set of fields. + size_t EncodedFieldsSize() const; + + // Fields in insertion order. + std::vector fields_; + }; + + // Checks equality excluding padding. + bool IsPayloadEqual(const BinaryHttpMessage& rhs) const { + // `has_host_` is derived from `header_fields_` so it doesn't need to be + // tested directly. + return body_ == rhs.body_ && header_fields_ == rhs.header_fields_; + } + + absl::Status EncodeKnownLengthFieldsAndBody( + quiche::QuicheDataWriter& writer) const; + size_t EncodedKnownLengthFieldsAndBodySize() const; + bool has_host() const { return has_host_; } + + private: + std::string body_; + Fields header_fields_; + bool has_host_ = false; + size_t num_padding_bytes_ = 0; +}; + +void QUICHE_EXPORT PrintTo(const BinaryHttpMessage::Field& msg, + std::ostream* os); + +class QUICHE_EXPORT BinaryHttpRequest : public BinaryHttpMessage { + public: + // HTTP request must have method, scheme, and path fields. + // The `authority` field is required unless a `host` header field is added. + // If a `host` header field is added, `authority` is serialized as the empty + // string. + // Some examples are: + // scheme: HTTP + // authority: www.example.com + // path: /index.html + struct QUICHE_EXPORT ControlData { + std::string method; + std::string scheme; + std::string authority; + std::string path; + bool operator==(const BinaryHttpRequest::ControlData& rhs) const { + return method == rhs.method && scheme == rhs.scheme && + authority == rhs.authority && path == rhs.path; + } + bool operator!=(const BinaryHttpRequest::ControlData& rhs) const { + return !(*this == rhs); + } + }; + explicit BinaryHttpRequest(ControlData control_data) + : control_data_(std::move(control_data)) {} + + // Deserialize + static absl::StatusOr Create(absl::string_view data); + + size_t EncodedSize() const override; + absl::StatusOr Serialize() const override; + const ControlData& control_data() const { return control_data_; } + + virtual std::string DebugString() const override; + + // Returns true if the contents of the requests are equal, excluding padding. + bool IsPayloadEqual(const BinaryHttpRequest& rhs) const { + return control_data_ == rhs.control_data_ && + BinaryHttpMessage::IsPayloadEqual(rhs); + } + + bool operator==(const BinaryHttpRequest& rhs) const { + return IsPayloadEqual(rhs) && + num_padding_bytes() == rhs.num_padding_bytes(); + } + + bool operator!=(const BinaryHttpRequest& rhs) const { + return !(*this == rhs); + } + + private: + absl::Status EncodeControlData(quiche::QuicheDataWriter& writer) const; + + size_t EncodedControlDataSize() const; + + // Returns Binary Http known length request formatted request. + absl::StatusOr EncodeAsKnownLength() const; + + const ControlData control_data_; +}; + +void QUICHE_EXPORT PrintTo(const BinaryHttpRequest& msg, std::ostream* os); + +class QUICHE_EXPORT BinaryHttpResponse : public BinaryHttpMessage { + public: + // https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html#name-response-control-data + // A response can contain 0 to N informational responses. Each informational + // response contains a status code followed by a header field. Valid status + // codes are [100,199]. + class QUICHE_EXPORT InformationalResponse { + public: + explicit InformationalResponse(uint16_t status_code) + : status_code_(status_code) {} + InformationalResponse(uint16_t status_code, + const std::vector& fields) + : status_code_(status_code) { + for (const BinaryHttpMessage::Field& field : fields) { + AddField(field.name, field.value); + } + } + + bool operator==( + const BinaryHttpResponse::InformationalResponse& rhs) const { + return status_code_ == rhs.status_code_ && fields_ == rhs.fields_; + } + + bool operator!=( + const BinaryHttpResponse::InformationalResponse& rhs) const { + return !(*this == rhs); + } + + // Adds a field with the provided name, converted to lower case. + // Fields are in the order they are added. + void AddField(absl::string_view name, std::string value); + + const std::vector& fields() const { + return fields_.fields(); + } + + uint16_t status_code() const { return status_code_; } + + std::string DebugString() const; + + private: + // Give BinaryHttpResponse access to Encoding functionality. + friend class BinaryHttpResponse; + + size_t EncodedSize() const; + + // Appends the encoded fields and body to `writer`. + absl::Status Encode(quiche::QuicheDataWriter& writer) const; + + const uint16_t status_code_; + BinaryHttpMessage::Fields fields_; + }; + + explicit BinaryHttpResponse(uint16_t status_code) + : status_code_(status_code) {} + + // Deserialize + static absl::StatusOr Create(absl::string_view data); + + size_t EncodedSize() const override; + absl::StatusOr Serialize() const override; + + // Informational status codes must be between 100 and 199 inclusive. + absl::Status AddInformationalResponse(uint16_t status_code, + std::vector header_fields); + + uint16_t status_code() const { return status_code_; } + + // References in the returned `ResponseControlData` are invalidated on + // `BinaryHttpResponse` object mutations. + const std::vector& informational_responses() const { + return informational_response_control_data_; + } + + virtual std::string DebugString() const override; + + // Returns true if the contents of the requests are equal, excluding padding. + bool IsPayloadEqual(const BinaryHttpResponse& rhs) const { + return informational_response_control_data_ == + rhs.informational_response_control_data_ && + status_code_ == rhs.status_code_ && + BinaryHttpMessage::IsPayloadEqual(rhs); + } + + bool operator==(const BinaryHttpResponse& rhs) const { + return IsPayloadEqual(rhs) && + num_padding_bytes() == rhs.num_padding_bytes(); + } + + bool operator!=(const BinaryHttpResponse& rhs) const { + return !(*this == rhs); + } + + private: + // Returns Binary Http known length request formatted response. + absl::StatusOr EncodeAsKnownLength() const; + + std::vector informational_response_control_data_; + const uint16_t status_code_; +}; + +void QUICHE_EXPORT PrintTo(const BinaryHttpResponse& msg, std::ostream* os); +} // namespace quiche + +#endif // QUICHE_BINARY_HTTP_BINARY_HTTP_MESSAGE_H_ diff --git a/quiche/binary_http/binary_http_message_test.cc b/quiche/binary_http/binary_http_message_test.cc new file mode 100644 index 000000000000..df8c0367b463 --- /dev/null +++ b/quiche/binary_http/binary_http_message_test.cc @@ -0,0 +1,786 @@ +#include "quiche/binary_http/binary_http_message.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::ContainerEq; +using ::testing::FieldsAre; +using ::testing::StrEq; + +namespace quiche { +namespace { + +std::string WordToBytes(uint32_t word) { + return std::string({static_cast(word >> 24), + static_cast(word >> 16), + static_cast(word >> 8), static_cast(word)}); +} + +template +void TestPrintTo(const T& resp) { + std::ostringstream os; + PrintTo(resp, &os); + EXPECT_EQ(os.str(), resp.DebugString()); +} +} // namespace +// Test examples from +// https://www.ietf.org/archive/id/draft-ietf-httpbis-binary-message-06.html + +TEST(BinaryHttpRequest, EncodeGetNoBody) { + /* + GET /hello.txt HTTP/1.1 + User-Agent: curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3 + Host: www.example.com + Accept-Language: en, mi + */ + BinaryHttpRequest request({"GET", "https", "www.example.com", "/hello.txt"}); + request + .AddHeaderField({"User-Agent", + "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"}) + ->AddHeaderField({"Host", "www.example.com"}) + ->AddHeaderField({"Accept-Language", "en, mi"}); + /* + 00000000: 00034745 54056874 74707300 0a2f6865 ..GET.https../he + 00000010: 6c6c6f2e 74787440 6c0a7573 65722d61 llo.txt@l.user-a + 00000020: 67656e74 34637572 6c2f372e 31362e33 gent4curl/7.16.3 + 00000030: 206c6962 6375726c 2f372e31 362e3320 libcurl/7.16.3 + 00000040: 4f70656e 53534c2f 302e392e 376c207a OpenSSL/0.9.7l z + 00000050: 6c69622f 312e322e 3304686f 73740f77 lib/1.2.3.host.w + 00000060: 77772e65 78616d70 6c652e63 6f6d0f61 ww.example.com.a + 00000070: 63636570 742d6c61 6e677561 67650665 ccept-language.e + 00000080: 6e2c206d 6900 n, mi.. + */ + const uint32_t expected_words[] = { + 0x00034745, 0x54056874, 0x74707300, 0x0a2f6865, 0x6c6c6f2e, 0x74787440, + 0x6c0a7573, 0x65722d61, 0x67656e74, 0x34637572, 0x6c2f372e, 0x31362e33, + 0x206c6962, 0x6375726c, 0x2f372e31, 0x362e3320, 0x4f70656e, 0x53534c2f, + 0x302e392e, 0x376c207a, 0x6c69622f, 0x312e322e, 0x3304686f, 0x73740f77, + 0x77772e65, 0x78616d70, 0x6c652e63, 0x6f6d0f61, 0x63636570, 0x742d6c61, + 0x6e677561, 0x67650665, 0x6e2c206d, 0x69000000}; + std::string expected; + for (const auto& word : expected_words) { + expected += WordToBytes(word); + } + // Remove padding. + expected.resize(expected.size() - 2); + + const auto result = request.Serialize(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(*result, expected); + EXPECT_THAT( + request.DebugString(), + StrEq("BinaryHttpRequest{BinaryHttpMessage{Headers{Field{user-agent=curl/" + "7.16.3 " + "libcurl/7.16.3 OpenSSL/0.9.7l " + "zlib/1.2.3};Field{host=www.example.com};Field{accept-language=en, " + "mi}}Body{}}}")); + TestPrintTo(request); +} + +TEST(BinaryHttpRequest, DecodeGetNoBody) { + const uint32_t words[] = { + 0x00034745, 0x54056874, 0x74707300, 0x0a2f6865, 0x6c6c6f2e, 0x74787440, + 0x6c0a7573, 0x65722d61, 0x67656e74, 0x34637572, 0x6c2f372e, 0x31362e33, + 0x206c6962, 0x6375726c, 0x2f372e31, 0x362e3320, 0x4f70656e, 0x53534c2f, + 0x302e392e, 0x376c207a, 0x6c69622f, 0x312e322e, 0x3304686f, 0x73740f77, + 0x77772e65, 0x78616d70, 0x6c652e63, 0x6f6d0f61, 0x63636570, 0x742d6c61, + 0x6e677561, 0x67650665, 0x6e2c206d, 0x69000000}; + std::string data; + for (const auto& word : words) { + data += WordToBytes(word); + } + const auto request_so = BinaryHttpRequest::Create(data); + ASSERT_TRUE(request_so.ok()); + const BinaryHttpRequest request = *request_so; + ASSERT_THAT(request.control_data(), + FieldsAre("GET", "https", "", "/hello.txt")); + std::vector expected_fields = { + {"user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"}, + {"host", "www.example.com"}, + {"accept-language", "en, mi"}}; + for (const auto& field : expected_fields) { + TestPrintTo(field); + } + ASSERT_THAT(request.GetHeaderFields(), ContainerEq(expected_fields)); + ASSERT_EQ(request.body(), ""); + EXPECT_THAT( + request.DebugString(), + StrEq("BinaryHttpRequest{BinaryHttpMessage{Headers{Field{user-agent=curl/" + "7.16.3 " + "libcurl/7.16.3 OpenSSL/0.9.7l " + "zlib/1.2.3};Field{host=www.example.com};Field{accept-language=en, " + "mi}}Body{}}}")); + TestPrintTo(request); +} + +TEST(BinaryHttpRequest, EncodeGetWithAuthority) { + /* + GET https://www.example.com/hello.txt HTTP/1.1 + User-Agent: curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3 + Accept-Language: en, mi + */ + BinaryHttpRequest request({"GET", "https", "www.example.com", "/hello.txt"}); + request + .AddHeaderField({"User-Agent", + "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"}) + ->AddHeaderField({"Accept-Language", "en, mi"}); + /* + 00000000: 00034745 54056874 7470730f 7777772e ..GET.https.www. + 00000010: 6578616d 706c652e 636f6d0a 2f68656c example.com./hel + 00000020: 6c6f2e74 78744057 0a757365 722d6167 lo.txt@W.user-ag + 00000030: 656e7434 6375726c 2f372e31 362e3320 ent4curl/7.16.3 + 00000040: 6c696263 75726c2f 372e3136 2e33204f libcurl/7.16.3 O + 00000050: 70656e53 534c2f30 2e392e37 6c207a6c penSSL/0.9.7l zl + 00000060: 69622f31 2e322e33 0f616363 6570742d ib/1.2.3.accept- + 00000070: 6c616e67 75616765 06656e2c 206d6900 language.en, mi. + */ + + const uint32_t expected_words[] = { + 0x00034745, 0x54056874, 0x7470730f, 0x7777772e, 0x6578616d, 0x706c652e, + 0x636f6d0a, 0x2f68656c, 0x6c6f2e74, 0x78744057, 0x0a757365, 0x722d6167, + 0x656e7434, 0x6375726c, 0x2f372e31, 0x362e3320, 0x6c696263, 0x75726c2f, + 0x372e3136, 0x2e33204f, 0x70656e53, 0x534c2f30, 0x2e392e37, 0x6c207a6c, + 0x69622f31, 0x2e322e33, 0x0f616363, 0x6570742d, 0x6c616e67, 0x75616765, + 0x06656e2c, 0x206d6900}; + std::string expected; + for (const auto& word : expected_words) { + expected += WordToBytes(word); + } + const auto result = request.Serialize(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(*result, expected); + EXPECT_THAT( + request.DebugString(), + StrEq("BinaryHttpRequest{BinaryHttpMessage{Headers{Field{user-agent=curl/" + "7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l " + "zlib/1.2.3};Field{accept-language=en, mi}}Body{}}}")); +} + +TEST(BinaryHttpRequest, DecodeGetWithAuthority) { + const uint32_t words[] = { + 0x00034745, 0x54056874, 0x7470730f, 0x7777772e, 0x6578616d, 0x706c652e, + 0x636f6d0a, 0x2f68656c, 0x6c6f2e74, 0x78744057, 0x0a757365, 0x722d6167, + 0x656e7434, 0x6375726c, 0x2f372e31, 0x362e3320, 0x6c696263, 0x75726c2f, + 0x372e3136, 0x2e33204f, 0x70656e53, 0x534c2f30, 0x2e392e37, 0x6c207a6c, + 0x69622f31, 0x2e322e33, 0x0f616363, 0x6570742d, 0x6c616e67, 0x75616765, + 0x06656e2c, 0x206d6900, 0x00}; + std::string data; + for (const auto& word : words) { + data += WordToBytes(word); + } + const auto request_so = BinaryHttpRequest::Create(data); + ASSERT_TRUE(request_so.ok()); + const BinaryHttpRequest request = *request_so; + ASSERT_THAT(request.control_data(), + FieldsAre("GET", "https", "www.example.com", "/hello.txt")); + std::vector expected_fields = { + {"user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"}, + {"accept-language", "en, mi"}}; + ASSERT_THAT(request.GetHeaderFields(), ContainerEq(expected_fields)); + ASSERT_EQ(request.body(), ""); + EXPECT_THAT( + request.DebugString(), + StrEq("BinaryHttpRequest{BinaryHttpMessage{Headers{Field{user-agent=curl/" + "7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l " + "zlib/1.2.3};Field{accept-language=en, mi}}Body{}}}")); +} + +TEST(BinaryHttpRequest, EncodePostBody) { + /* + POST /hello.txt HTTP/1.1 + User-Agent: not/telling + Host: www.example.com + Accept-Language: en + + Some body that I used to post. + */ + BinaryHttpRequest request({"POST", "https", "www.example.com", "/hello.txt"}); + request.AddHeaderField({"User-Agent", "not/telling"}) + ->AddHeaderField({"Host", "www.example.com"}) + ->AddHeaderField({"Accept-Language", "en"}) + ->set_body({"Some body that I used to post.\r\n"}); + /* + 00000000: 0004504f 53540568 74747073 000a2f68 ..POST.https../h + 00000010: 656c6c6f 2e747874 3f0a7573 65722d61 ello.txt?.user-a + 00000020: 67656e74 0b6e6f74 2f74656c 6c696e67 gent.not/telling + 00000030: 04686f73 740f7777 772e6578 616d706c .host.www.exampl + 00000040: 652e636f 6d0f6163 63657074 2d6c616e e.com.accept-lan + 00000050: 67756167 6502656e 20536f6d 6520626f guage.en Some bo + 00000060: 64792074 68617420 49207573 65642074 dy that I used t + 00000070: 6f20706f 73742e0d 0a o post.... + */ + const uint32_t expected_words[] = { + 0x0004504f, 0x53540568, 0x74747073, 0x000a2f68, 0x656c6c6f, 0x2e747874, + 0x3f0a7573, 0x65722d61, 0x67656e74, 0x0b6e6f74, 0x2f74656c, 0x6c696e67, + 0x04686f73, 0x740f7777, 0x772e6578, 0x616d706c, 0x652e636f, 0x6d0f6163, + 0x63657074, 0x2d6c616e, 0x67756167, 0x6502656e, 0x20536f6d, 0x6520626f, + 0x64792074, 0x68617420, 0x49207573, 0x65642074, 0x6f20706f, 0x73742e0d, + 0x0a000000}; + std::string expected; + for (const auto& word : expected_words) { + expected += WordToBytes(word); + } + // Remove padding. + expected.resize(expected.size() - 3); + const auto result = request.Serialize(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(*result, expected); + EXPECT_THAT( + request.DebugString(), + StrEq("BinaryHttpRequest{BinaryHttpMessage{Headers{Field{user-agent=not/" + "telling};Field{host=www.example.com};Field{accept-language=en}}" + "Body{Some " + "body that I used to post.\r\n}}}")); +} + +TEST(BinaryHttpRequest, DecodePostBody) { + const uint32_t words[] = { + 0x0004504f, 0x53540568, 0x74747073, 0x000a2f68, 0x656c6c6f, 0x2e747874, + 0x3f0a7573, 0x65722d61, 0x67656e74, 0x0b6e6f74, 0x2f74656c, 0x6c696e67, + 0x04686f73, 0x740f7777, 0x772e6578, 0x616d706c, 0x652e636f, 0x6d0f6163, + 0x63657074, 0x2d6c616e, 0x67756167, 0x6502656e, 0x20536f6d, 0x6520626f, + 0x64792074, 0x68617420, 0x49207573, 0x65642074, 0x6f20706f, 0x73742e0d, + 0x0a000000}; + std::string data; + for (const auto& word : words) { + data += WordToBytes(word); + } + const auto request_so = BinaryHttpRequest::Create(data); + ASSERT_TRUE(request_so.ok()); + BinaryHttpRequest request = *request_so; + ASSERT_THAT(request.control_data(), + FieldsAre("POST", "https", "", "/hello.txt")); + std::vector expected_fields = { + {"user-agent", "not/telling"}, + {"host", "www.example.com"}, + {"accept-language", "en"}}; + ASSERT_THAT(request.GetHeaderFields(), ContainerEq(expected_fields)); + ASSERT_EQ(request.body(), "Some body that I used to post.\r\n"); + EXPECT_THAT( + request.DebugString(), + StrEq("BinaryHttpRequest{BinaryHttpMessage{Headers{Field{user-agent=not/" + "telling};Field{host=www.example.com};Field{accept-language=en}}" + "Body{Some " + "body that I used to post.\r\n}}}")); +} + +TEST(BinaryHttpRequest, Equality) { + BinaryHttpRequest request({"POST", "https", "www.example.com", "/hello.txt"}); + request.AddHeaderField({"User-Agent", "not/telling"}) + ->set_body({"hello, world!\r\n"}); + + BinaryHttpRequest same({"POST", "https", "www.example.com", "/hello.txt"}); + same.AddHeaderField({"User-Agent", "not/telling"}) + ->set_body({"hello, world!\r\n"}); + EXPECT_EQ(request, same); +} + +TEST(BinaryHttpRequest, Inequality) { + BinaryHttpRequest request({"POST", "https", "www.example.com", "/hello.txt"}); + request.AddHeaderField({"User-Agent", "not/telling"}) + ->set_body({"hello, world!\r\n"}); + + BinaryHttpRequest different_control( + {"PUT", "https", "www.example.com", "/hello.txt"}); + different_control.AddHeaderField({"User-Agent", "not/telling"}) + ->set_body({"hello, world!\r\n"}); + EXPECT_NE(request, different_control); + + BinaryHttpRequest different_header( + {"PUT", "https", "www.example.com", "/hello.txt"}); + different_header.AddHeaderField({"User-Agent", "told/you"}) + ->set_body({"hello, world!\r\n"}); + EXPECT_NE(request, different_header); + + BinaryHttpRequest no_header( + {"PUT", "https", "www.example.com", "/hello.txt"}); + no_header.set_body({"hello, world!\r\n"}); + EXPECT_NE(request, no_header); + + BinaryHttpRequest different_body( + {"POST", "https", "www.example.com", "/hello.txt"}); + different_body.AddHeaderField({"User-Agent", "not/telling"}) + ->set_body({"goodbye, world!\r\n"}); + EXPECT_NE(request, different_body); + + BinaryHttpRequest no_body({"POST", "https", "www.example.com", "/hello.txt"}); + no_body.AddHeaderField({"User-Agent", "not/telling"}); + EXPECT_NE(request, no_body); +} + +TEST(BinaryHttpResponse, EncodeNoBody) { + /* + HTTP/1.1 404 Not Found + Server: Apache + */ + BinaryHttpResponse response(404); + response.AddHeaderField({"Server", "Apache"}); + /* + 0141940e 06736572 76657206 41706163 .A...server.Apac + 686500 he.. + */ + const uint32_t expected_words[] = {0x0141940e, 0x06736572, 0x76657206, + 0x41706163, 0x68650000}; + std::string expected; + for (const auto& word : expected_words) { + expected += WordToBytes(word); + } + // Remove padding. + expected.resize(expected.size() - 1); + const auto result = response.Serialize(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(*result, expected); + EXPECT_THAT( + response.DebugString(), + StrEq("BinaryHttpResponse(404){BinaryHttpMessage{Headers{Field{server=" + "Apache}}Body{}}}")); +} + +TEST(BinaryHttpResponse, DecodeNoBody) { + /* + HTTP/1.1 404 Not Found + Server: Apache + */ + const uint32_t words[] = {0x0141940e, 0x06736572, 0x76657206, 0x41706163, + 0x68650000}; + std::string data; + for (const auto& word : words) { + data += WordToBytes(word); + } + const auto response_so = BinaryHttpResponse::Create(data); + ASSERT_TRUE(response_so.ok()); + const BinaryHttpResponse response = *response_so; + ASSERT_EQ(response.status_code(), 404); + std::vector expected_fields = { + {"server", "Apache"}}; + ASSERT_THAT(response.GetHeaderFields(), ContainerEq(expected_fields)); + ASSERT_EQ(response.body(), ""); + ASSERT_TRUE(response.informational_responses().empty()); + EXPECT_THAT( + response.DebugString(), + StrEq("BinaryHttpResponse(404){BinaryHttpMessage{Headers{Field{server=" + "Apache}}Body{}}}")); +} + +TEST(BinaryHttpResponse, EncodeBody) { + /* + HTTP/1.1 200 OK + Server: Apache + + Hello, world! + */ + BinaryHttpResponse response(200); + response.AddHeaderField({"Server", "Apache"}); + response.set_body("Hello, world!\r\n"); + /* + 0140c80e 06736572 76657206 41706163 .@...server.Apac + 68650f48 656c6c6f 2c20776f 726c6421 he.Hello, world! + 0d0a .... + */ + const uint32_t expected_words[] = {0x0140c80e, 0x06736572, 0x76657206, + 0x41706163, 0x68650f48, 0x656c6c6f, + 0x2c20776f, 0x726c6421, 0x0d0a0000}; + std::string expected; + for (const auto& word : expected_words) { + expected += WordToBytes(word); + } + // Remove padding. + expected.resize(expected.size() - 2); + + const auto result = response.Serialize(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(*result, expected); + EXPECT_THAT( + response.DebugString(), + StrEq("BinaryHttpResponse(200){BinaryHttpMessage{Headers{Field{server=" + "Apache}}Body{Hello, world!\r\n}}}")); +} + +TEST(BinaryHttpResponse, DecodeBody) { + /* + HTTP/1.1 200 OK + + Hello, world! + */ + const uint32_t words[] = {0x0140c80e, 0x06736572, 0x76657206, + 0x41706163, 0x68650f48, 0x656c6c6f, + 0x2c20776f, 0x726c6421, 0x0d0a0000}; + std::string data; + for (const auto& word : words) { + data += WordToBytes(word); + } + const auto response_so = BinaryHttpResponse::Create(data); + ASSERT_TRUE(response_so.ok()); + const BinaryHttpResponse response = *response_so; + ASSERT_EQ(response.status_code(), 200); + std::vector expected_fields = { + {"server", "Apache"}}; + ASSERT_THAT(response.GetHeaderFields(), ContainerEq(expected_fields)); + ASSERT_EQ(response.body(), "Hello, world!\r\n"); + ASSERT_TRUE(response.informational_responses().empty()); + EXPECT_THAT( + response.DebugString(), + StrEq("BinaryHttpResponse(200){BinaryHttpMessage{Headers{Field{server=" + "Apache}}Body{Hello, world!\r\n}}}")); +} + +TEST(BHttpResponse, AddBadInformationalResponseCode) { + BinaryHttpResponse response(200); + ASSERT_FALSE(response.AddInformationalResponse(50, {}).ok()); + ASSERT_FALSE(response.AddInformationalResponse(300, {}).ok()); +} + +TEST(BinaryHttpResponse, EncodeMultiInformationalWithBody) { + /* + HTTP/1.1 102 Processing + Running: "sleep 15" + + HTTP/1.1 103 Early Hints + Link: ; rel=preload; as=style + Link: ; rel=preload; as=script + + HTTP/1.1 200 OK + Date: Mon, 27 Jul 2009 12:28:53 GMT + Server: Apache + Last-Modified: Wed, 22 Jul 2009 19:15:56 GMT + ETag: "34aa387-d-1568eb00" + Accept-Ranges: bytes + Content-Length: 51 + Vary: Accept-Encoding + Content-Type: text/plain + + Hello World! My content includes a trailing CRLF. + */ + BinaryHttpResponse response(200); + response.AddHeaderField({"Date", "Mon, 27 Jul 2009 12:28:53 GMT"}) + ->AddHeaderField({"Server", "Apache"}) + ->AddHeaderField({"Last-Modified", "Wed, 22 Jul 2009 19:15:56 GMT"}) + ->AddHeaderField({"ETag", "\"34aa387-d-1568eb00\""}) + ->AddHeaderField({"Accept-Ranges", "bytes"}) + ->AddHeaderField({"Content-Length", "51"}) + ->AddHeaderField({"Vary", "Accept-Encoding"}) + ->AddHeaderField({"Content-Type", "text/plain"}); + response.set_body("Hello World! My content includes a trailing CRLF.\r\n"); + ASSERT_TRUE( + response.AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + ASSERT_TRUE(response + .AddInformationalResponse( + 103, {{"Link", "; rel=preload; as=style"}, + {"Link", "; rel=preload; as=script"}}) + .ok()); + + /* + 01406613 0772756e 6e696e67 0a22736c .@f..running."sl + 65657020 31352240 67405304 6c696e6b eep 15"@g@S.link + 233c2f73 74796c65 2e637373 3e3b2072 #; r + 656c3d70 72656c6f 61643b20 61733d73 el=preload; as=s + 74796c65 046c696e 6b243c2f 73637269 tyle.link$; rel=prel + 6f61643b 2061733d 73637269 707440c8 oad; as=script@. + 40ca0464 6174651d 4d6f6e2c 20323720 @..date.Mon, 27 + 4a756c20 32303039 2031323a 32383a35 Jul 2009 12:28:5 + 3320474d 54067365 72766572 06417061 3 GMT.server.Apa + 6368650d 6c617374 2d6d6f64 69666965 che.last-modifie + 641d5765 642c2032 32204a75 6c203230 d.Wed, 22 Jul 20 + 30392031 393a3135 3a353620 474d5404 09 19:15:56 GMT. + 65746167 14223334 61613338 372d642d etag."34aa387-d- + 31353638 65623030 220d6163 63657074 1568eb00".accept + 2d72616e 67657305 62797465 730e636f -ranges.bytes.co + 6e74656e 742d6c65 6e677468 02353104 ntent-length.51. + 76617279 0f416363 6570742d 456e636f vary.Accept-Enco + 64696e67 0c636f6e 74656e74 2d747970 ding.content-typ + 650a7465 78742f70 6c61696e 3348656c e.text/plain3Hel + 6c6f2057 6f726c64 21204d79 20636f6e lo World! My con + 74656e74 20696e63 6c756465 73206120 tent includes a + 74726169 6c696e67 2043524c 462e0d0a trailing CRLF... + */ + const uint32_t expected_words[] = { + 0x01406613, 0x0772756e, 0x6e696e67, 0x0a22736c, 0x65657020, 0x31352240, + 0x67405304, 0x6c696e6b, 0x233c2f73, 0x74796c65, 0x2e637373, 0x3e3b2072, + 0x656c3d70, 0x72656c6f, 0x61643b20, 0x61733d73, 0x74796c65, 0x046c696e, + 0x6b243c2f, 0x73637269, 0x70742e6a, 0x733e3b20, 0x72656c3d, 0x7072656c, + 0x6f61643b, 0x2061733d, 0x73637269, 0x707440c8, 0x40ca0464, 0x6174651d, + 0x4d6f6e2c, 0x20323720, 0x4a756c20, 0x32303039, 0x2031323a, 0x32383a35, + 0x3320474d, 0x54067365, 0x72766572, 0x06417061, 0x6368650d, 0x6c617374, + 0x2d6d6f64, 0x69666965, 0x641d5765, 0x642c2032, 0x32204a75, 0x6c203230, + 0x30392031, 0x393a3135, 0x3a353620, 0x474d5404, 0x65746167, 0x14223334, + 0x61613338, 0x372d642d, 0x31353638, 0x65623030, 0x220d6163, 0x63657074, + 0x2d72616e, 0x67657305, 0x62797465, 0x730e636f, 0x6e74656e, 0x742d6c65, + 0x6e677468, 0x02353104, 0x76617279, 0x0f416363, 0x6570742d, 0x456e636f, + 0x64696e67, 0x0c636f6e, 0x74656e74, 0x2d747970, 0x650a7465, 0x78742f70, + 0x6c61696e, 0x3348656c, 0x6c6f2057, 0x6f726c64, 0x21204d79, 0x20636f6e, + 0x74656e74, 0x20696e63, 0x6c756465, 0x73206120, 0x74726169, 0x6c696e67, + 0x2043524c, 0x462e0d0a}; + std::string expected; + for (const auto& word : expected_words) { + expected += WordToBytes(word); + } + const auto result = response.Serialize(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(*result, expected); + EXPECT_THAT( + response.DebugString(), + StrEq( + "BinaryHttpResponse(200){BinaryHttpMessage{Headers{Field{date=Mon, " + "27 Jul 2009 12:28:53 " + "GMT};Field{server=Apache};Field{last-modified=Wed, 22 Jul 2009 " + "19:15:56 " + "GMT};Field{etag=\"34aa387-d-1568eb00\"};Field{accept-ranges=bytes};" + "Field{" + "content-length=51};Field{vary=Accept-Encoding};Field{content-type=" + "text/plain}}Body{Hello World! My content includes a trailing " + "CRLF.\r\n}}InformationalResponse{Field{running=\"sleep " + "15\"}};InformationalResponse{Field{link=; rel=preload; " + "as=style};Field{link=; rel=preload; as=script}}}")); + TestPrintTo(response); +} + +TEST(BinaryHttpResponse, DecodeMultiInformationalWithBody) { + /* + HTTP/1.1 102 Processing + Running: "sleep 15" + + HTTP/1.1 103 Early Hints + Link: ; rel=preload; as=style + Link: ; rel=preload; as=script + + HTTP/1.1 200 OK + Date: Mon, 27 Jul 2009 12:28:53 GMT + Server: Apache + Last-Modified: Wed, 22 Jul 2009 19:15:56 GMT + ETag: "34aa387-d-1568eb00" + Accept-Ranges: bytes + Content-Length: 51 + Vary: Accept-Encoding + Content-Type: text/plain + + Hello World! My content includes a trailing CRLF. + */ + const uint32_t words[] = { + 0x01406613, 0x0772756e, 0x6e696e67, 0x0a22736c, 0x65657020, 0x31352240, + 0x67405304, 0x6c696e6b, 0x233c2f73, 0x74796c65, 0x2e637373, 0x3e3b2072, + 0x656c3d70, 0x72656c6f, 0x61643b20, 0x61733d73, 0x74796c65, 0x046c696e, + 0x6b243c2f, 0x73637269, 0x70742e6a, 0x733e3b20, 0x72656c3d, 0x7072656c, + 0x6f61643b, 0x2061733d, 0x73637269, 0x707440c8, 0x40ca0464, 0x6174651d, + 0x4d6f6e2c, 0x20323720, 0x4a756c20, 0x32303039, 0x2031323a, 0x32383a35, + 0x3320474d, 0x54067365, 0x72766572, 0x06417061, 0x6368650d, 0x6c617374, + 0x2d6d6f64, 0x69666965, 0x641d5765, 0x642c2032, 0x32204a75, 0x6c203230, + 0x30392031, 0x393a3135, 0x3a353620, 0x474d5404, 0x65746167, 0x14223334, + 0x61613338, 0x372d642d, 0x31353638, 0x65623030, 0x220d6163, 0x63657074, + 0x2d72616e, 0x67657305, 0x62797465, 0x730e636f, 0x6e74656e, 0x742d6c65, + 0x6e677468, 0x02353104, 0x76617279, 0x0f416363, 0x6570742d, 0x456e636f, + 0x64696e67, 0x0c636f6e, 0x74656e74, 0x2d747970, 0x650a7465, 0x78742f70, + 0x6c61696e, 0x3348656c, 0x6c6f2057, 0x6f726c64, 0x21204d79, 0x20636f6e, + 0x74656e74, 0x20696e63, 0x6c756465, 0x73206120, 0x74726169, 0x6c696e67, + 0x2043524c, 0x462e0d0a, 0x00000000}; + std::string data; + for (const auto& word : words) { + data += WordToBytes(word); + } + const auto response_so = BinaryHttpResponse::Create(data); + ASSERT_TRUE(response_so.ok()); + const BinaryHttpResponse response = *response_so; + std::vector expected_fields = { + {"date", "Mon, 27 Jul 2009 12:28:53 GMT"}, + {"server", "Apache"}, + {"last-modified", "Wed, 22 Jul 2009 19:15:56 GMT"}, + {"etag", "\"34aa387-d-1568eb00\""}, + {"accept-ranges", "bytes"}, + {"content-length", "51"}, + {"vary", "Accept-Encoding"}, + {"content-type", "text/plain"}}; + + ASSERT_THAT(response.GetHeaderFields(), ContainerEq(expected_fields)); + ASSERT_EQ(response.body(), + "Hello World! My content includes a trailing CRLF.\r\n"); + std::vector header102 = { + {"running", "\"sleep 15\""}}; + std::vector header103 = { + {"link", "; rel=preload; as=style"}, + {"link", "; rel=preload; as=script"}}; + std::vector expected_control = { + {102, header102}, {103, header103}}; + ASSERT_THAT(response.informational_responses(), + ContainerEq(expected_control)); + EXPECT_THAT( + response.DebugString(), + StrEq( + "BinaryHttpResponse(200){BinaryHttpMessage{Headers{Field{date=Mon, " + "27 Jul 2009 12:28:53 " + "GMT};Field{server=Apache};Field{last-modified=Wed, 22 Jul 2009 " + "19:15:56 " + "GMT};Field{etag=\"34aa387-d-1568eb00\"};Field{accept-ranges=bytes};" + "Field{" + "content-length=51};Field{vary=Accept-Encoding};Field{content-type=" + "text/plain}}Body{Hello World! My content includes a trailing " + "CRLF.\r\n}}InformationalResponse{Field{running=\"sleep " + "15\"}};InformationalResponse{Field{link=; rel=preload; " + "as=style};Field{link=; rel=preload; as=script}}}")); + TestPrintTo(response); +} + +TEST(BinaryHttpMessage, SwapBody) { + BinaryHttpRequest request({}); + request.set_body("hello, world!"); + std::string other = "goodbye, world!"; + request.swap_body(other); + EXPECT_EQ(request.body(), "goodbye, world!"); + EXPECT_EQ(other, "hello, world!"); +} + +TEST(BinaryHttpResponse, Equality) { + BinaryHttpResponse response(200); + response.AddHeaderField({"Server", "Apache"})->set_body("Hello, world!\r\n"); + ASSERT_TRUE( + response.AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + + BinaryHttpResponse same(200); + same.AddHeaderField({"Server", "Apache"})->set_body("Hello, world!\r\n"); + ASSERT_TRUE( + same.AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}).ok()); + ASSERT_EQ(response, same); +} + +TEST(BinaryHttpResponse, Inequality) { + BinaryHttpResponse response(200); + response.AddHeaderField({"Server", "Apache"})->set_body("Hello, world!\r\n"); + ASSERT_TRUE( + response.AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + + BinaryHttpResponse different_status(201); + different_status.AddHeaderField({"Server", "Apache"}) + ->set_body("Hello, world!\r\n"); + EXPECT_TRUE(different_status + .AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + EXPECT_NE(response, different_status); + + BinaryHttpResponse different_header(200); + different_header.AddHeaderField({"Server", "python3"}) + ->set_body("Hello, world!\r\n"); + EXPECT_TRUE(different_header + .AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + EXPECT_NE(response, different_header); + + BinaryHttpResponse no_header(200); + no_header.set_body("Hello, world!\r\n"); + EXPECT_TRUE( + no_header.AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + EXPECT_NE(response, no_header); + + BinaryHttpResponse different_body(200); + different_body.AddHeaderField({"Server", "Apache"}) + ->set_body("Goodbye, world!\r\n"); + EXPECT_TRUE(different_body + .AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + EXPECT_NE(response, different_body); + + BinaryHttpResponse no_body(200); + no_body.AddHeaderField({"Server", "Apache"}); + EXPECT_TRUE( + no_body.AddInformationalResponse(102, {{"Running", "\"sleep 15\""}}) + .ok()); + EXPECT_NE(response, no_body); + + BinaryHttpResponse different_informational(200); + different_informational.AddHeaderField({"Server", "Apache"}) + ->set_body("Hello, world!\r\n"); + EXPECT_TRUE(different_informational + .AddInformationalResponse(198, {{"Running", "\"sleep 15\""}}) + .ok()); + EXPECT_NE(response, different_informational); + + BinaryHttpResponse no_informational(200); + no_informational.AddHeaderField({"Server", "Apache"}) + ->set_body("Hello, world!\r\n"); + EXPECT_NE(response, no_informational); +} + +MATCHER_P(HasEqPayload, value, "Payloads of messages are equivalent.") { + return arg.IsPayloadEqual(value); +} + +template +void TestPadding(T& message) { + const auto data_so = message.Serialize(); + ASSERT_TRUE(data_so.ok()); + auto data = *data_so; + ASSERT_EQ(data.size(), message.EncodedSize()); + + message.set_num_padding_bytes(10); + const auto padded_data_so = message.Serialize(); + ASSERT_TRUE(padded_data_so.ok()); + const auto padded_data = *padded_data_so; + ASSERT_EQ(padded_data.size(), message.EncodedSize()); + + // Check padding size output. + ASSERT_EQ(data.size() + 10, padded_data.size()); + // Check for valid null byte padding output + data.resize(data.size() + 10); + ASSERT_EQ(data, padded_data); + + // Deserialize padded and not padded, and verify they are the same. + const auto deserialized_padded_message_so = T::Create(data); + ASSERT_TRUE(deserialized_padded_message_so.ok()); + const auto deserialized_padded_message = *deserialized_padded_message_so; + ASSERT_EQ(deserialized_padded_message, message); + ASSERT_EQ(deserialized_padded_message.num_padding_bytes(), size_t(10)); + + // Invalid padding + data[data.size() - 1] = 'a'; + const auto bad_so = T::Create(data); + ASSERT_FALSE(bad_so.ok()); + + // Check that padding does not impact equality. + data.resize(data.size() - 10); + const auto deserialized_message_so = T::Create(data); + ASSERT_TRUE(deserialized_message_so.ok()); + const auto deserialized_message = *deserialized_message_so; + ASSERT_EQ(deserialized_message.num_padding_bytes(), size_t(0)); + // Confirm that the message payloads are equal, but not fully equivalent due + // to padding. + ASSERT_THAT(deserialized_message, HasEqPayload(deserialized_padded_message)); + ASSERT_NE(deserialized_message, deserialized_padded_message); +} + +TEST(BinaryHttpRequest, Padding) { + /* + GET /hello.txt HTTP/1.1 + User-Agent: curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3 + Host: www.example.com + Accept-Language: en, mi + */ + BinaryHttpRequest request({"GET", "https", "", "/hello.txt"}); + request + .AddHeaderField({"User-Agent", + "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"}) + ->AddHeaderField({"Host", "www.example.com"}) + ->AddHeaderField({"Accept-Language", "en, mi"}); + TestPadding(request); +} + +TEST(BinaryHttpResponse, Padding) { + /* + HTTP/1.1 200 OK + Server: Apache + + Hello, world! + */ + BinaryHttpResponse response(200); + response.AddHeaderField({"Server", "Apache"}); + response.set_body("Hello, world!\r\n"); + TestPadding(response); +} + +} // namespace quiche diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc new file mode 100644 index 000000000000..77e2e9ed569f --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.cc @@ -0,0 +1,256 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" + +namespace private_membership { +namespace anonymous_tokens { + +namespace { + +absl::Status ValidityChecksForClientCreation( + const RSABlindSignaturePublicKey& public_key) { + // Basic validity checks. + if (!ParseUseCase(public_key.use_case()).ok()) { + return absl::InvalidArgumentError("Invalid use case for public key."); + } else if (public_key.key_version() <= 0) { + return absl::InvalidArgumentError( + "Key version cannot be zero or negative."); + } else if (public_key.key_size() < 256) { + return absl::InvalidArgumentError( + "Key modulus size cannot be less than 256 bytes."); + } else if (public_key.mask_gen_function() == AT_TEST_MGF || + public_key.mask_gen_function() == AT_MGF_UNDEFINED) { + return absl::InvalidArgumentError("Unknown or unacceptable mgf1 hash."); + } else if (public_key.sig_hash_type() == AT_TEST_HASH_TYPE || + public_key.sig_hash_type() == AT_HASH_TYPE_UNDEFINED) { + return absl::InvalidArgumentError( + "Unknown or unacceptable signature hash."); + } else if (public_key.salt_length() <= 0) { + return absl::InvalidArgumentError( + "Non-positive salt length is not allowed."); + } else if (public_key.mask_gen_function() == AT_TEST_MGF || + public_key.mask_gen_function() == AT_MGF_UNDEFINED) { + return absl::InvalidArgumentError("Message mask type must be defined."); + } else if (public_key.message_mask_size() <= 0) { + return absl::InvalidArgumentError("Message mask size must be positive."); + } + + RSAPublicKey rsa_public_key; + if (!rsa_public_key.ParseFromString(public_key.serialized_public_key())) { + return absl::InvalidArgumentError("Public key is malformed."); + } + if (rsa_public_key.n().size() != static_cast(public_key.key_size())) { + return absl::InvalidArgumentError( + "Public key size does not match key size."); + } + return absl::OkStatus(); +} + +absl::Status CheckPublicKeyValidity( + const RSABlindSignaturePublicKey& public_key) { + absl::Time time_now = absl::Now(); + ANON_TOKENS_ASSIGN_OR_RETURN( + absl::Time start_time, + TimeFromProto(public_key.key_validity_start_time())); + if (start_time > time_now) { + return absl::FailedPreconditionError("Key is not valid yet."); + } + if (public_key.has_expiration_time()) { + ANON_TOKENS_ASSIGN_OR_RETURN(absl::Time expiration_time, + TimeFromProto(public_key.expiration_time())); + if (expiration_time <= time_now) { + return absl::FailedPreconditionError("Key is already expired."); + } + } + return absl::OkStatus(); +} + +} // namespace + +AnonymousTokensRsaBssaClient::AnonymousTokensRsaBssaClient( + const RSABlindSignaturePublicKey& public_key) + : public_key_(public_key) {} + +absl::StatusOr> +AnonymousTokensRsaBssaClient::Create( + const RSABlindSignaturePublicKey& public_key) { + ANON_TOKENS_RETURN_IF_ERROR(ValidityChecksForClientCreation(public_key)); + return absl::WrapUnique(new AnonymousTokensRsaBssaClient(public_key)); +} + +// TODO(b/261866075): Offer an API to simply return bytes of blinded requests. +absl::StatusOr +AnonymousTokensRsaBssaClient::CreateRequest( + const std::vector& inputs) { + if (inputs.empty()) { + return absl::InvalidArgumentError("Cannot create an empty request."); + } else if (!blinding_info_map_.empty()) { + return absl::FailedPreconditionError( + "Blind signature request already created."); + } + + ANON_TOKENS_RETURN_IF_ERROR(CheckPublicKeyValidity(public_key_)); + + AnonymousTokensSignRequest request; + for (const PlaintextMessageWithPublicMetadata& input : inputs) { + // Generate nonce and masked message. For more details, see + // https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/ + ANON_TOKENS_ASSIGN_OR_RETURN(std::string mask, GenerateMask(public_key_)); + std::string masked_message = + MaskMessageConcat(mask, input.plaintext_message()); + + std::optional public_metadata = std::nullopt; + if (public_key_.public_metadata_support()) { + // Empty public metadata is a valid value. + public_metadata = input.public_metadata(); + } + // Generate RSA blinder. + ANON_TOKENS_ASSIGN_OR_RETURN(auto rsa_bssa_blinder, + RsaBlinder::New(public_key_, public_metadata)); + ANON_TOKENS_ASSIGN_OR_RETURN(const std::string blinded_message, + rsa_bssa_blinder->Blind(masked_message)); + + // Store randomness needed to unblind. + BlindingInfo blinding_info = { + input, + mask, + std::move(rsa_bssa_blinder), + }; + + // Create the blinded token. + AnonymousTokensSignRequest_BlindedToken* blinded_token = + request.add_blinded_tokens(); + blinded_token->set_use_case(public_key_.use_case()); + blinded_token->set_key_version(public_key_.key_version()); + blinded_token->set_serialized_token(blinded_message); + blinded_token->set_public_metadata(input.public_metadata()); + blinding_info_map_[blinded_message] = std::move(blinding_info); + } + + return request; +} + +absl::StatusOr> +AnonymousTokensRsaBssaClient::ProcessResponse( + const AnonymousTokensSignResponse& response) { + if (blinding_info_map_.empty()) { + return absl::FailedPreconditionError( + "A valid Blind signature request was not created before calling " + "RetrieveAnonymousTokensFromSignResponse."); + } else if (response.anonymous_tokens().empty()) { + return absl::InvalidArgumentError("Cannot process an empty response."); + } else if (static_cast(response.anonymous_tokens().size()) != + blinding_info_map_.size()) { + return absl::InvalidArgumentError( + "Response is missing some requested tokens."); + } + + // Vector to accumulate output tokens. + std::vector tokens; + + // Temporary set structure to check for duplicate responses. + absl::flat_hash_set blinded_messages; + + // Loop over all the anonymous tokens in the response. + for (const AnonymousTokensSignResponse_AnonymousToken& anonymous_token : + response.anonymous_tokens()) { + // Basic validity checks on the response. + if (anonymous_token.use_case() != public_key_.use_case()) { + return absl::InvalidArgumentError("Use case does not match public key."); + } else if (anonymous_token.key_version() != public_key_.key_version()) { + return absl::InvalidArgumentError( + "Key version does not match public key."); + } else if (anonymous_token.serialized_blinded_message().empty()) { + return absl::InvalidArgumentError( + "Blinded message that was sent in request cannot be empty in " + "response."); + } else if (anonymous_token.serialized_token().empty()) { + return absl::InvalidArgumentError( + "Blinded anonymous token (serialized_token) in response cannot be " + "empty."); + } + + // Check for duplicate in responses. + if (!blinded_messages.insert(anonymous_token.serialized_blinded_message()) + .second) { + return absl::InvalidArgumentError( + "Blinded message was repeated in the response."); + } + + // Retrieve blinding info associated with blind response. + auto it = + blinding_info_map_.find(anonymous_token.serialized_blinded_message()); + if (it == blinding_info_map_.end()) { + return absl::InvalidArgumentError( + "Response has some tokens for some blinded messages that were not " + "requested."); + } + const BlindingInfo& blinding_info = it->second; + + if (blinding_info.input.public_metadata() != + anonymous_token.public_metadata()) { + return absl::InvalidArgumentError( + "Response public metadata does not match input."); + } + + // Unblind the blinded anonymous token to obtain the final anonymous token + // (signature). + ANON_TOKENS_ASSIGN_OR_RETURN( + const std::string final_anonymous_token, + blinding_info.rsa_blinder->Unblind(anonymous_token.serialized_token())); + + // Verify the signature for correctness. + ANON_TOKENS_RETURN_IF_ERROR(blinding_info.rsa_blinder->Verify( + final_anonymous_token, + MaskMessageConcat(blinding_info.mask, + blinding_info.input.plaintext_message()))); + + // Construct the final signature proto. + RSABlindSignatureTokenWithInput final_token_proto; + *final_token_proto.mutable_token()->mutable_token() = final_anonymous_token; + *final_token_proto.mutable_token()->mutable_message_mask() = + blinding_info.mask; + *final_token_proto.mutable_input() = blinding_info.input; + + tokens.push_back(final_token_proto); + } + + return tokens; +} + +absl::Status AnonymousTokensRsaBssaClient::Verify( + const RSABlindSignaturePublicKey& /*public_key*/, + const RSABlindSignatureToken& /*token*/, + const PlaintextMessageWithPublicMetadata& /*input*/) { + return absl::UnimplementedError("Verify not implemented yet."); +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h new file mode 100644 index 000000000000..e76018263ab9 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h @@ -0,0 +1,104 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CLIENT_ANONYMOUS_TOKENS_RSA_BSSA_CLIENT_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CLIENT_ANONYMOUS_TOKENS_RSA_BSSA_CLIENT_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/common/platform/api/quiche_export.h" +// copybara:strip_begin(internal comment) +// The QUICHE_EXPORT annotation is necessary for some classes and functions +// to link correctly on Windows. Please do not remove them! +// copybara:strip_end + +namespace private_membership { +namespace anonymous_tokens { + +// This class generates AnonymousTokens RSA blind signatures, +// (https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/) +// blind message signing request and processes the response. +// +// Each execution of the Anonymous Tokens RSA blind signatures protocol requires +// a new instance of the AnonymousTokensRsaBssaClient. +// +// This class is not thread-safe. +class QUICHE_EXPORT AnonymousTokensRsaBssaClient { + public: + // AnonymousTokensRsaBssaClient is neither copyable nor copy assignable. + AnonymousTokensRsaBssaClient(const AnonymousTokensRsaBssaClient&) = delete; + AnonymousTokensRsaBssaClient& operator=(const AnonymousTokensRsaBssaClient&) = + delete; + + // Create client with the specified public key which can be used to send a + // sign request and process a response. + // + // This method is to be used to create a client as its constructor is private. + // It takes as input RSABlindSignaturePublicKey which contains the public key + // and relevant parameters. + static absl::StatusOr> Create( + const RSABlindSignaturePublicKey& public_key); + + // Class method that creates the signature requests by taking a vector where + // each element in the vector is the plaintext message along with its + // respective public metadata (if the metadata exists). + // + // The library will also fail if the key has expired. + // + // It only puts the blinded version of the messages in the request. + absl::StatusOr CreateRequest( + const std::vector& inputs); + + // Class method that processes the signature response from the server. + // + // It outputs a vector of a protos where each element contains an input + // plaintext message and associated public metadata (if it exists) along with + // its final (unblinded) anonymous token resulting from the RSA blind + // signatures protocol. + absl::StatusOr> ProcessResponse( + const AnonymousTokensSignResponse& response); + + // Method to verify whether an anonymous token is valid or not. + // + // Returns OK on a valid token and non-OK otherwise. + absl::Status Verify(const RSABlindSignaturePublicKey& public_key, + const RSABlindSignatureToken& token, + const PlaintextMessageWithPublicMetadata& input); + + private: + struct BlindingInfo { + PlaintextMessageWithPublicMetadata input; + std::string mask; + std::unique_ptr rsa_blinder; + }; + + explicit AnonymousTokensRsaBssaClient( + const RSABlindSignaturePublicKey& public_key); + + const RSABlindSignaturePublicKey public_key_; + absl::flat_hash_map blinding_info_map_; +}; + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CLIENT_ANONYMOUS_TOKENS_RSA_BSSA_CLIENT_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc new file mode 100644 index 000000000000..37d3a0161713 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client_test.cc @@ -0,0 +1,470 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h" + +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "absl/time/time.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/base.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { +namespace { + +using ::testing::SizeIs; +using quiche::test::StatusIs; + +// Returns a fixed public private key pair by calling GetStrongRsaKeys4096(). +absl::StatusOr> +CreateClientTestKey(absl::string_view use_case = "TEST_USE_CASE", + int key_version = 1, + MessageMaskType mask_type = AT_MESSAGE_MASK_CONCAT, + int message_mask_size = 32, + bool enable_public_metadata = false) { + ANON_TOKENS_ASSIGN_OR_RETURN(auto key_pair, GetStrongRsaKeys4096()); + RSABlindSignaturePublicKey public_key; + public_key.set_use_case(std::string(use_case)); + public_key.set_key_version(key_version); + public_key.set_serialized_public_key(key_pair.first.SerializeAsString()); + absl::Time start_time = absl::Now() - absl::Minutes(100); + ANON_TOKENS_ASSIGN_OR_RETURN(*public_key.mutable_key_validity_start_time(), + TimeToProto(start_time)); + public_key.set_sig_hash_type(AT_HASH_TYPE_SHA384); + public_key.set_mask_gen_function(AT_MGF_SHA384); + public_key.set_salt_length(kSaltLengthInBytes48); + public_key.set_key_size(kRsaModulusSizeInBytes512); + public_key.set_message_mask_type(mask_type); + public_key.set_message_mask_size(message_mask_size); + public_key.set_public_metadata_support(enable_public_metadata); + + return std::make_pair(std::move(public_key), std::move(key_pair.second)); +} + +// Creates the input consisting on plaintext messages and public metadata that +// can be passed to the AnonymousTokensRsaBssaClient. +absl::StatusOr> CreateInput( + absl::Span messages, + absl::Span public_metadata = {}) { + // Check input parameter sizes. + if (!public_metadata.empty() && messages.size() != public_metadata.size()) { + return absl::InvalidArgumentError( + "Input vectors should be of the same size."); + } + + std::vector anonymmous_tokens_input_proto; + anonymmous_tokens_input_proto.reserve(messages.size()); + for (int i = 0; i < messages.size(); ++i) { + PlaintextMessageWithPublicMetadata input_message_and_metadata; + input_message_and_metadata.set_plaintext_message(messages[i]); + if (!public_metadata.empty()) { + input_message_and_metadata.set_public_metadata(public_metadata[i]); + } + anonymmous_tokens_input_proto.push_back(input_message_and_metadata); + } + return anonymmous_tokens_input_proto; +} + +// Creates the server response for anonymous tokens request by using +// RsaBlindSigner. +absl::StatusOr CreateResponse( + const AnonymousTokensSignRequest& request, const RSAPrivateKey& private_key, + bool enable_public_metadata = false) { + AnonymousTokensSignResponse response; + for (const auto& request_token : request.blinded_tokens()) { + auto* response_token = response.add_anonymous_tokens(); + response_token->set_use_case(request_token.use_case()); + response_token->set_key_version(request_token.key_version()); + response_token->set_public_metadata(request_token.public_metadata()); + response_token->set_serialized_blinded_message( + request_token.serialized_token()); + std::optional public_metadata = std::nullopt; + if (enable_public_metadata) { + public_metadata = request_token.public_metadata(); + } + ANON_TOKENS_ASSIGN_OR_RETURN( + std::unique_ptr blind_signer, + RsaBlindSigner::New(private_key, public_metadata)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *response_token->mutable_serialized_token(), + blind_signer->Sign(request_token.serialized_token())); + } + return response; +} + +TEST(CreateAnonymousTokensRsaBssaClientTest, Success) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto rsa_key, CreateClientTestKey()); + QUICHE_EXPECT_OK(AnonymousTokensRsaBssaClient::Create(rsa_key.first)); +} + +TEST(CreateAnonymousTokensRsaBssaClientTest, InvalidUseCase) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto rsa_key, + CreateClientTestKey("INVALID_USE_CASE")); + EXPECT_THAT(AnonymousTokensRsaBssaClient::Create(rsa_key.first), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CreateAnonymousTokensRsaBssaClientTest, NotAUseCase) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto rsa_key, + CreateClientTestKey("NOT_A_USE_CASE")); + EXPECT_THAT(AnonymousTokensRsaBssaClient::Create(rsa_key.first), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CreateAnonymousTokensRsaBssaClientTest, InvalidKeyVersion) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto rsa_key, + CreateClientTestKey("TEST_USE_CASE", 0)); + EXPECT_THAT(AnonymousTokensRsaBssaClient::Create(rsa_key.first), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CreateAnonymousTokensRsaBssaClientTest, InvalidMessageMaskType) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto rsa_key, + CreateClientTestKey("TEST_USE_CASE", 0, AT_MESSAGE_MASK_TYPE_UNDEFINED)); + EXPECT_THAT(AnonymousTokensRsaBssaClient::Create(rsa_key.first), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CreateAnonymousTokensRsaBssaClientTest, InvalidMessageMaskSize) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto rsa_key, + CreateClientTestKey("TEST_USE_CASE", 0, AT_MESSAGE_MASK_CONCAT, 0)); + EXPECT_THAT(AnonymousTokensRsaBssaClient::Create(rsa_key.first), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +class AnonymousTokensRsaBssaClientTest : public testing::Test { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::tie(public_key_, private_key_), + CreateClientTestKey()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + client_, AnonymousTokensRsaBssaClient::Create(public_key_)); + } + + RSAPrivateKey private_key_; + RSABlindSignaturePublicKey public_key_; + std::unique_ptr client_; +}; + +TEST_F(AnonymousTokensRsaBssaClientTest, SuccessOneMessage) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request, + client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response, + CreateResponse(request, private_key_)); + EXPECT_THAT(response.anonymous_tokens(), SizeIs(1)); + QUICHE_EXPECT_OK(client_->ProcessResponse(response)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, SuccessMultipleMessages) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message1", "msg2", "anotherMessage", "one_more_message"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request, + client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response, + CreateResponse(request, private_key_)); + EXPECT_THAT(response.anonymous_tokens(), SizeIs(4)); + QUICHE_EXPECT_OK(client_->ProcessResponse(response)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, EnsureRandomTokens) { + std::string message = "test_same_message"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({message, message})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request, + client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response, + CreateResponse(request, private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector tokens, + client_->ProcessResponse(response)); + ASSERT_EQ(tokens.size(), 2); + for (const RSABlindSignatureTokenWithInput& token : tokens) { + EXPECT_EQ(token.input().plaintext_message(), message); + } + EXPECT_NE(tokens[0].token().message_mask(), tokens[1].token().message_mask()); + EXPECT_NE(tokens[0].token().token(), tokens[1].token().token()); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, EmptyInput) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({})); + EXPECT_THAT(client_->CreateRequest(input_messages), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, NotYetValidKey) { + RSABlindSignaturePublicKey not_valid_key = public_key_; + absl::Time start_time = absl::Now() + absl::Minutes(100); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + *not_valid_key.mutable_key_validity_start_time(), + TimeToProto(start_time)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr client, + AnonymousTokensRsaBssaClient::Create(not_valid_key)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + EXPECT_THAT(client->CreateRequest(input_messages), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, ExpiredKey) { + RSABlindSignaturePublicKey expired_key = public_key_; + absl::Time end_time = absl::Now() - absl::Seconds(1); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(*expired_key.mutable_expiration_time(), + TimeToProto(end_time)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr client, + AnonymousTokensRsaBssaClient::Create(expired_key)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + EXPECT_THAT(client->CreateRequest(input_messages), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, CreateRequestTwice) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + QUICHE_EXPECT_OK(client_->CreateRequest(input_messages)); + EXPECT_THAT(client_->CreateRequest(input_messages), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, ProcessResponseWithoutCreateRequest) { + AnonymousTokensSignResponse response; + EXPECT_THAT(client_->ProcessResponse(response), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, ProcessEmptyResponse) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request, + client_->CreateRequest(input_messages)); + AnonymousTokensSignResponse response; + EXPECT_THAT(client_->ProcessResponse(response), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, ProcessResponseWithBadUseCase) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request, + client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response, + CreateResponse(request, private_key_)); + response.mutable_anonymous_tokens(0)->set_use_case("TEST_USE_CASE_2"); + EXPECT_THAT(client_->ProcessResponse(response), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, ProcessResponseWithBadKeyVersion) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request, + client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response, + CreateResponse(request, private_key_)); + response.mutable_anonymous_tokens(0)->set_key_version(2); + EXPECT_THAT(client_->ProcessResponse(response), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientTest, ProcessResponseFromDifferentClient) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr client2, + AnonymousTokensRsaBssaClient::Create(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request1, + client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignRequest request2, + client2->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response1, + CreateResponse(request1, private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensSignResponse response2, + CreateResponse(request2, private_key_)); + EXPECT_THAT(client_->ProcessResponse(response2), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(client2->ProcessResponse(response1), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +class AnonymousTokensRsaBssaClientWithPublicMetadataTest + : public testing::Test { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::tie(public_key_, private_key_), + CreateClientTestKey("TEST_USE_CASE", /*key_version=*/1, + AT_MESSAGE_MASK_CONCAT, + kRsaMessageMaskSizeInBytes32, + /*enable_public_metadata=*/true)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + public_metadata_client_, + AnonymousTokensRsaBssaClient::Create(public_key_)); + } + + RSAPrivateKey private_key_; + RSABlindSignaturePublicKey public_key_; + std::unique_ptr public_metadata_client_; +}; + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + SuccessOneMessageWithPublicMetadata) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"}, {"md1"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + public_metadata_client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/true)); + EXPECT_THAT(response.anonymous_tokens(), SizeIs(1)); + QUICHE_EXPECT_OK(public_metadata_client_->ProcessResponse(response)); +} + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + FailureWithEmptyPublicMetadata) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"}, {"md1"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + public_metadata_client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/false)); + EXPECT_THAT(public_metadata_client_->ProcessResponse(response), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + FailureWithWrongPublicMetadata) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"}, {"md1"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + public_metadata_client_->CreateRequest(input_messages)); + request.mutable_blinded_tokens(0)->set_public_metadata( + "wrong_public_metadata"); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/true)); + EXPECT_THAT(public_metadata_client_->ProcessResponse(response), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + FailureWithPublicMetadataSupportOff) { + // Create a client with public metadata support disabled. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto key_pair, CreateClientTestKey()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr non_public_metadata_client, + AnonymousTokensRsaBssaClient::Create(key_pair.first)); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message"}, {"md1"})); + // Use client_ that does not support public metadata. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + non_public_metadata_client->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/true)); + EXPECT_THAT(non_public_metadata_client->ProcessResponse(response), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + SuccessMultipleMessagesWithDistinctPublicMetadata) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message1", "msg2", "anotherMessage", "one_more_message"}, + {"md1", "md2", "md3", "md4"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + public_metadata_client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/true)); + EXPECT_THAT(response.anonymous_tokens(), SizeIs(4)); + QUICHE_EXPECT_OK(public_metadata_client_->ProcessResponse(response)); +} + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + SuccessMultipleMessagesWithRepeatedPublicMetadata) { + // Create input with repeated public metadata + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message1", "msg2", "anotherMessage", "one_more_message"}, + {"md1", "md2", "md2", "md1"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + public_metadata_client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/true)); + EXPECT_THAT(response.anonymous_tokens(), SizeIs(4)); + QUICHE_EXPECT_OK(public_metadata_client_->ProcessResponse(response)); +} + +TEST_F(AnonymousTokensRsaBssaClientWithPublicMetadataTest, + SuccessMultipleMessagesWithEmptyStringPublicMetadata) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::vector input_messages, + CreateInput({"message1", "msg2", "anotherMessage", "one_more_message"}, + {"md1", "", "", "md4"})); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignRequest request, + public_metadata_client_->CreateRequest(input_messages)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + AnonymousTokensSignResponse response, + CreateResponse(request, private_key_, /*enable_public_metadata=*/true)); + EXPECT_THAT(response.anonymous_tokens(), SizeIs(4)); + QUICHE_EXPECT_OK(public_metadata_client_->ProcessResponse(response)); +} + +} // namespace +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc new file mode 100644 index 000000000000..ed1a76b058e6 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/at_crypto_utils_test.cc @@ -0,0 +1,397 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" + +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "absl/strings/escaping.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/base.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { +namespace { + +struct IetfNewPublicExponentWithPublicMetadataTestVector { + RSAPublicKey public_key; + std::string public_metadata; + std::string new_e; +}; + +TEST(AnonymousTokensCryptoUtilsTest, BignumToStringAndBack) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(BnCtxPtr ctx, GetAndStartBigNumCtx()); + + // Create a new BIGNUM using the context and set it + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr bn_1, NewBigNum()); + ASSERT_EQ(BN_set_u64(bn_1.get(), 0x124435435), 1); + EXPECT_NE(bn_1, nullptr); + EXPECT_EQ(BN_is_zero(bn_1.get()), 0); + EXPECT_EQ(BN_is_one(bn_1.get()), 0); + + // Convert bn_1 to string from BIGNUM + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const std::string converted_str, + BignumToString(*bn_1, BN_num_bytes(bn_1.get()))); + // Convert the string version of bn_1 back to BIGNUM + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr bn_2, + StringToBignum(converted_str)); + // Check whether the conversion back worked + EXPECT_EQ(BN_cmp(bn_1.get(), bn_2.get()), 0); +} + +TEST(AnonymousTokensCryptoUtilsTest, PowerOfTwoAndRsaSqrtTwo) { + // Compute 2^(10-1/2). + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr sqrt2, + GetRsaSqrtTwo(10)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr small_pow2, + ComputePowerOfTwo(9)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr large_pow2, + ComputePowerOfTwo(10)); + EXPECT_GT(BN_cmp(sqrt2.get(), small_pow2.get()), 0); + EXPECT_LT(BN_cmp(sqrt2.get(), large_pow2.get()), 0); +} + +TEST(AnonymousTokensCryptoUtilsTest, ComputeHashAcceptsNullStringView) { + absl::StatusOr null_hash = + ComputeHash(absl::string_view(nullptr, 0), *EVP_sha512()); + absl::StatusOr empty_hash = ComputeHash("", *EVP_sha512()); + std::string str; + absl::StatusOr empty_str_hash = ComputeHash(str, *EVP_sha512()); + + QUICHE_EXPECT_OK(null_hash); + QUICHE_EXPECT_OK(empty_hash); + QUICHE_EXPECT_OK(empty_str_hash); + + EXPECT_EQ(*null_hash, *empty_hash); + EXPECT_EQ(*null_hash, *empty_str_hash); +} + +TEST(AnonymousTokensCryptoUtilsTest, ComputeCarmichaelLcm) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(BnCtxPtr ctx, GetAndStartBigNumCtx()); + + // Suppose that N = 1019 * 1187. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr phi_p, NewBigNum()); + ASSERT_TRUE(BN_set_word(phi_p.get(), 1019 - 1)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr phi_q, NewBigNum()); + ASSERT_TRUE(BN_set_word(phi_q.get(), 1187 - 1)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr expected_lcm, + NewBigNum()); + ASSERT_TRUE(BN_set_word(expected_lcm.get(), (1019 - 1) * (1187 - 1) / 2)); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr lcm, + ComputeCarmichaelLcm(*phi_p, *phi_q, *ctx)); + EXPECT_EQ(BN_cmp(lcm.get(), expected_lcm.get()), 0); +} + +struct ComputeHashTestParam { + const EVP_MD* hasher; + absl::string_view input_hex; + absl::string_view expected_digest_hex; +}; + +using ComputeHashTest = testing::TestWithParam; + +// Returns the test parameters for ComputeHashTestParam from NIST's +// samples. +std::vector GetComputeHashTestParams() { + std::vector params; + params.push_back({ + EVP_sha256(), + "af397a8b8dd73ab702ce8e53aa9f", + "d189498a3463b18e846b8ab1b41583b0b7efc789dad8a7fb885bbf8fb5b45c5c", + }); + params.push_back({ + EVP_sha256(), + "59eb45bbbeb054b0b97334d53580ce03f699", + "32c38c54189f2357e96bd77eb00c2b9c341ebebacc2945f97804f59a93238288", + }); + params.push_back({ + EVP_sha512(), + "16b17074d3e3d97557f9ed77d920b4b1bff4e845b345a922", + "6884134582a760046433abcbd53db8ff1a89995862f305b887020f6da6c7b903a314721e" + "972bf438483f452a8b09596298a576c903c91df4a414c7bd20fd1d07", + }); + params.push_back({ + EVP_sha512(), + "7651ab491b8fa86f969d42977d09df5f8bee3e5899180b52c968b0db057a6f02a886ad61" + "7a84915a", + "f35e50e2e02b8781345f8ceb2198f068ba103476f715cfb487a452882c9f0de0c720b2a0" + "88a39d06a8a6b64ce4d6470dfeadc4f65ae06672c057e29f14c4daf9", + }); + return params; +} + +TEST_P(ComputeHashTest, ComputesHash) { + const ComputeHashTestParam& params = GetParam(); + ASSERT_NE(params.hasher, nullptr); + std::string data = absl::HexStringToBytes(params.input_hex); + std::string expected_digest = + absl::HexStringToBytes(params.expected_digest_hex); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto computed_hash, + ComputeHash(data, *params.hasher)); + EXPECT_EQ(computed_hash, expected_digest); +} + +INSTANTIATE_TEST_SUITE_P(ComputeHashTests, ComputeHashTest, + testing::ValuesIn(GetComputeHashTestParams())); + +TEST(PublicMetadataCryptoUtilsInternalTest, PublicMetadataHashWithHKDF) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(BnCtxPtr ctx, GetAndStartBigNumCtx()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr max_value, + NewBigNum()); + ASSERT_TRUE(BN_set_word(max_value.get(), 4294967295)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto key_pair, GetStrongRsaKeys2048()); + std::string input1 = "ro1"; + std::string input2 = "ro2"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr output1, + internal::PublicMetadataHashWithHKDF(input1, key_pair.first.n(), + 1 + input1.size())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr another_output1, + internal::PublicMetadataHashWithHKDF(input1, key_pair.first.n(), + 1 + input1.size())); + EXPECT_EQ(BN_cmp(output1.get(), another_output1.get()), 0); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr output2, + internal::PublicMetadataHashWithHKDF(input2, key_pair.first.n(), + 1 + input2.size())); + EXPECT_NE(BN_cmp(output1.get(), output2.get()), 0); + + EXPECT_LE(BN_cmp(output1.get(), max_value.get()), 0); + EXPECT_LE(BN_cmp(output2.get(), max_value.get()), 0); +} + +TEST(PublicMetadataCryptoUtilsTest, PublicExponentHashDifferentModulus) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto key_pair_1, GetStrongRsaKeys2048()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto key_pair_2, + GetAnotherStrongRsaKeys2048()); + std::string metadata = "md"; + // Check that same metadata and different modulus result in different + // hashes. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr rsa_modulus_1, + StringToBignum(key_pair_1.first.n())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr exp1, + PublicMetadataExponent(*rsa_modulus_1.get(), metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto rsa_modulus_2, + StringToBignum(key_pair_2.first.n())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr exp2, + PublicMetadataExponent(*rsa_modulus_2.get(), metadata)); + EXPECT_NE(BN_cmp(exp1.get(), exp2.get()), 0); +} + +std::vector +GetIetfNewPublicExponentWithPublicMetadataTestVectors() { + std::vector test_vectors; + + RSAPublicKey public_key; + public_key.set_n(absl::HexStringToBytes( + "d6930820f71fe517bf3259d14d40209b02a5c0d3d61991c731dd7da39f8d69821552e231" + "8d6c9ad897e603887a476ea3162c1205da9ac96f02edf31df049bd55f142134c17d4382a" + "0e78e275345f165fbe8e49cdca6cf5c726c599dd39e09e75e0f330a33121e73976e4facb" + "a9cfa001c28b7c96f8134f9981db6750b43a41710f51da4240fe03106c12acb1e7bb53d7" + "5ec7256da3fddd0718b89c365410fce61bc7c99b115fb4c3c318081fa7e1b65a37774e8e" + "50c96e8ce2b2cc6b3b367982366a2bf9924c4bafdb3ff5e722258ab705c76d43e5f1f121" + "b984814e98ea2b2b8725cd9bc905c0bc3d75c2a8db70a7153213c39ae371b2b5dc1dafcb" + "19d6fae9")); + public_key.set_e(absl::HexStringToBytes("010001")); + + // Test vector 1 + test_vectors.push_back( + {.public_key = public_key, + .public_metadata = absl::HexStringToBytes("6d65746164617461"), + .new_e = absl::HexStringToBytes( + "30584b72f5cb557085106232f051d039e23358feee9204cf30ea567620e90d79e4a" + "7a81388b1f390e18ea5240a1d8cc296ce1325128b445c48aa5a3b34fa07c324bf17" + "bc7f1b3efebaff81d7e032948f1477493bc183d2f8d94c947c984c6f0757527615b" + "f2a2f0ef0db5ad80ce99905beed0440b47fa5cb9a2334fea40ad88e6ef1")}); + + // Test vector 2 + test_vectors.push_back( + {.public_key = public_key, + .public_metadata = "", + .new_e = absl::HexStringToBytes( + "2ed5a8d2592a11bbeef728bb39018ef5c3cf343507dd77dd156d5eec7f06f04732e" + "4be944c5d2443d244c59e52c9fa5e8de40f55ffd0e70fbe9093d3f7be2aafd77c14" + "b263b71c1c6b3ca2b9629842a902128fee4878392a950906fae35d6194e0d2548e5" + "8bbc20f841188ca2fceb20b2b1b45448da5c7d1c73fb6e83fa58867397b")}); + + return test_vectors; +} + +TEST(PublicMetadataCryptoUtilsTest, + IetfNewPublicExponentWithPublicMetadataTests) { + const auto test_vectors = + GetIetfNewPublicExponentWithPublicMetadataTestVectors(); + for (const IetfNewPublicExponentWithPublicMetadataTestVector& test_vector : + test_vectors) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr rsa_modulus, + StringToBignum(test_vector.public_key.n())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr rsa_e, + StringToBignum(test_vector.public_key.e())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr expected_new_e, + StringToBignum(test_vector.new_e)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr modified_e, + ComputeFinalExponentUnderPublicMetadata( + *rsa_modulus.get(), *rsa_e.get(), test_vector.public_metadata)); + + EXPECT_EQ(BN_cmp(modified_e.get(), expected_new_e.get()), 0); + } +} + +using CreateTestKeyPairFunction = + absl::StatusOr>(); + +class CryptoUtilsTest + : public testing::TestWithParam { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto keys_pair, (*GetParam())()); + public_key_ = std::move(keys_pair.first); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(rsa_modulus_, + StringToBignum(keys_pair.second.n())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(rsa_e_, + StringToBignum(keys_pair.second.e())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(rsa_p_, + StringToBignum(keys_pair.second.p())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(rsa_q_, + StringToBignum(keys_pair.second.q())); + } + + bssl::UniquePtr rsa_modulus_; + bssl::UniquePtr rsa_e_; + bssl::UniquePtr rsa_p_; + bssl::UniquePtr rsa_q_; + RSAPublicKey public_key_; +}; + +TEST_P(CryptoUtilsTest, PublicExponentCoprime) { + std::string metadata = "md"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr exp, + PublicMetadataExponent(*rsa_modulus_.get(), metadata)); + int rsa_mod_size_bits = BN_num_bits(rsa_modulus_.get()); + // Check that exponent is odd. + EXPECT_EQ(BN_is_odd(exp.get()), 1); + // Check that exponent is small enough. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr sqrt2, + GetRsaSqrtTwo(rsa_mod_size_bits / 2)); + EXPECT_LT(BN_cmp(exp.get(), sqrt2.get()), 0); + EXPECT_LT(BN_cmp(exp.get(), rsa_p_.get()), 0); + EXPECT_LT(BN_cmp(exp.get(), rsa_q_.get()), 0); +} + +TEST_P(CryptoUtilsTest, PublicExponentHash) { + std::string metadata1 = "md1"; + std::string metadata2 = "md2"; + // Check that hash is deterministic. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr exp1, + PublicMetadataExponent(*rsa_modulus_.get(), metadata1)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr another_exp1, + PublicMetadataExponent(*rsa_modulus_.get(), metadata1)); + EXPECT_EQ(BN_cmp(exp1.get(), another_exp1.get()), 0); + // Check that hashes are distinct for different metadata. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr exp2, + PublicMetadataExponent(*rsa_modulus_.get(), metadata2)); + EXPECT_NE(BN_cmp(exp1.get(), exp2.get()), 0); +} + +TEST_P(CryptoUtilsTest, FinalExponentCoprime) { + std::string metadata = "md"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr final_exponent, + ComputeFinalExponentUnderPublicMetadata(*rsa_modulus_.get(), + *rsa_e_.get(), metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(BnCtxPtr ctx, GetAndStartBigNumCtx()); + + // Check that exponent is odd. + EXPECT_EQ(BN_is_odd(final_exponent.get()), 1); + // Check that exponent is co-prime to factors of the rsa modulus. + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr gcd_p_fe, + NewBigNum()); + ASSERT_EQ( + BN_gcd(gcd_p_fe.get(), rsa_p_.get(), final_exponent.get(), ctx.get()), 1); + EXPECT_EQ(BN_cmp(gcd_p_fe.get(), BN_value_one()), 0); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr gcd_q_fe, + NewBigNum()); + ASSERT_EQ( + BN_gcd(gcd_q_fe.get(), rsa_q_.get(), final_exponent.get(), ctx.get()), 1); + EXPECT_EQ(BN_cmp(gcd_q_fe.get(), BN_value_one()), 0); +} + +TEST_P(CryptoUtilsTest, DeterministicModificationOfPublicExponentWithMetadata) { + std::string metadata = "md"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr public_exp_1, + ComputeFinalExponentUnderPublicMetadata(*rsa_modulus_.get(), + *rsa_e_.get(), metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr public_exp_2, + ComputeFinalExponentUnderPublicMetadata(*rsa_modulus_.get(), + *rsa_e_.get(), metadata)); + + EXPECT_EQ(BN_cmp(public_exp_1.get(), public_exp_2.get()), 0); +} + +TEST_P(CryptoUtilsTest, DifferentPublicExponentWithDifferentPublicMetadata) { + std::string metadata_1 = "md1"; + std::string metadata_2 = "md2"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr public_exp_1, + ComputeFinalExponentUnderPublicMetadata(*rsa_modulus_.get(), + *rsa_e_.get(), metadata_1)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + bssl::UniquePtr public_exp_2, + ComputeFinalExponentUnderPublicMetadata(*rsa_modulus_.get(), + *rsa_e_.get(), metadata_2)); + // Check that exponent is different in all keys + EXPECT_NE(BN_cmp(public_exp_1.get(), public_exp_2.get()), 0); + EXPECT_NE(BN_cmp(public_exp_1.get(), rsa_e_.get()), 0); + EXPECT_NE(BN_cmp(public_exp_2.get(), rsa_e_.get()), 0); +} + +TEST_P(CryptoUtilsTest, ModifiedPublicExponentWithEmptyPublicMetadata) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(bssl::UniquePtr new_public_exp, + ComputeFinalExponentUnderPublicMetadata( + *rsa_modulus_.get(), *rsa_e_.get(), "")); + + EXPECT_NE(BN_cmp(new_public_exp.get(), rsa_e_.get()), 0); +} + +INSTANTIATE_TEST_SUITE_P(CryptoUtilsTest, CryptoUtilsTest, + testing::Values(&GetStrongRsaKeys2048, + &GetAnotherStrongRsaKeys2048, + &GetStrongRsaKeys3072, + &GetStrongRsaKeys4096)); + +} // namespace +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h new file mode 100644 index 000000000000..3c3e5dba6824 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h @@ -0,0 +1,37 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_BLIND_SIGNER_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_BLIND_SIGNER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace private_membership { +namespace anonymous_tokens { + +class BlindSigner { + public: + virtual absl::StatusOr Sign( + absl::string_view blinded_data) const = 0; + + virtual ~BlindSigner() = default; +}; + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_BLIND_SIGNER_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h new file mode 100644 index 000000000000..fd29ad74a360 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h @@ -0,0 +1,38 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_BLINDER_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_BLINDER_H_ + +#include + +#include "absl/status/statusor.h" + +namespace private_membership { +namespace anonymous_tokens { + +class Blinder { + public: + enum class BlinderState { kCreated = 0, kBlinded, kUnblinded }; + virtual absl::StatusOr Blind(absl::string_view message) = 0; + + virtual absl::StatusOr Unblind( + absl::string_view blind_signature) = 0; + + virtual ~Blinder() = default; +}; + +} // namespace anonymous_tokens +} // namespace private_membership +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_BLINDER_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h new file mode 100644 index 000000000000..4f73afd7841b --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h @@ -0,0 +1,68 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_CONSTANTS_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_CONSTANTS_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace private_membership { +namespace anonymous_tokens { + +// Returned integer on successful execution of BoringSSL methods +constexpr int kBsslSuccess = 1; + +// RSA modulus size, 4096 bits +// +// Our recommended size. +constexpr int kRsaModulusSizeInBits4096 = 4096; + +// RSA modulus size, 512 bytes +constexpr int kRsaModulusSizeInBytes512 = 512; + +// RSA modulus size, 2048 bits +// +// Recommended size for RSA Blind Signatures without Public Metadata. +// +// https://www.ietf.org/archive/id/draft-ietf-privacypass-protocol-08.html#name-token-type-blind-rsa-2048-b. +constexpr int kRsaModulusSizeInBits2048 = 2048; + +// RSA modulus size, 256 bytes +constexpr int kRsaModulusSizeInBytes256 = 256; + +// Salt length, 48 bytes +// +// Recommended size. The convention is to use hLen, the length of the output of +// the hash function in bytes. A salt length of zero will result in a +// deterministic signature value. +// +// https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/ +constexpr int kSaltLengthInBytes48 = 48; + +// Length of message mask, 32 bytes. +// +// https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/ +constexpr int kRsaMessageMaskSizeInBytes32 = 32; + +// Info used in HKDF for Public Metadata Hash. +constexpr absl::string_view kHkdfPublicMetadataInfo = "PBRSA"; + +constexpr int kHkdfPublicMetadataInfoSizeInBytes = 5; + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_CONSTANTS_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc new file mode 100644 index 000000000000..16dad4f9c6d7 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.cc @@ -0,0 +1,527 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/err.h" +#include "openssl/hkdf.h" +#include "openssl/rand.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { + +namespace internal { + +// Approximation of sqrt(2) taken from +// //depot/google3/third_party/openssl/boringssl/src/crypto/fipsmodule/rsa/rsa_impl.c;l=997 +const std::vector kBoringSSLRSASqrtTwo = { + 0x4d7c60a5, 0xe633e3e1, 0x5fcf8f7b, 0xca3ea33b, 0xc246785e, 0x92957023, + 0xf9acce41, 0x797f2805, 0xfdfe170f, 0xd3b1f780, 0xd24f4a76, 0x3facb882, + 0x18838a2e, 0xaff5f3b2, 0xc1fcbdde, 0xa2f7dc33, 0xdea06241, 0xf7aa81c2, + 0xf6a1be3f, 0xca221307, 0x332a5e9f, 0x7bda1ebf, 0x0104dc01, 0xfe32352f, + 0xb8cf341b, 0x6f8236c7, 0x4264dabc, 0xd528b651, 0xf4d3a02c, 0xebc93e0c, + 0x81394ab6, 0xd8fd0efd, 0xeaa4a089, 0x9040ca4a, 0xf52f120f, 0x836e582e, + 0xcb2a6343, 0x31f3c84d, 0xc6d5a8a3, 0x8bb7e9dc, 0x460abc72, 0x2f7c4e33, + 0xcab1bc91, 0x1688458a, 0x53059c60, 0x11bc337b, 0xd2202e87, 0x42af1f4e, + 0x78048736, 0x3dfa2768, 0x0f74a85e, 0x439c7b4a, 0xa8b1fe6f, 0xdc83db39, + 0x4afc8304, 0x3ab8a2c3, 0xed17ac85, 0x83339915, 0x1d6f60ba, 0x893ba84c, + 0x597d89b3, 0x754abe9f, 0xb504f333, 0xf9de6484, +}; + +absl::StatusOr> PublicMetadataHashWithHKDF( + absl::string_view public_metadata, absl::string_view rsa_modulus_str, + size_t out_len_bytes) { + const EVP_MD* evp_md_sha_384 = EVP_sha384(); + // Prepend "key" to input. + std::string modified_input = absl::StrCat("key", public_metadata); + std::vector input_buffer(modified_input.begin(), + modified_input.end()); + // Append 0x00 to input. + input_buffer.push_back(0x00); + std::string out_e; + // We set the out_e size beyond out_len_bytes so that out_e bytes are + // indifferentiable from truly random bytes even after truncations. + // + // Expanding to 16 more bytes is sufficient. + // https://cfrg.github.io/draft-irtf-cfrg-hash-to-curve/draft-irtf-cfrg-hash-to-curve.html#name-hashing-to-a-finite-field + const size_t hkdf_output_size = out_len_bytes + 16; + out_e.resize(hkdf_output_size); + // The modulus is used as salt to ensure different outputs for same metadata + // and different modulus. + if (HKDF(reinterpret_cast(out_e.data()), hkdf_output_size, + evp_md_sha_384, input_buffer.data(), input_buffer.size(), + reinterpret_cast(rsa_modulus_str.data()), + rsa_modulus_str.size(), + reinterpret_cast(kHkdfPublicMetadataInfo.data()), + kHkdfPublicMetadataInfoSizeInBytes) != kBsslSuccess) { + return absl::InternalError("HKDF failed in public_metadata_crypto_utils"); + } + // Truncate out_e to out_len_bytes + out_e.resize(out_len_bytes); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr out, + StringToBignum(out_e)); + return out; +} + +} // namespace internal + +absl::StatusOr GetAndStartBigNumCtx() { + // Create context to be used in intermediate computation. + BnCtxPtr bn_ctx = BnCtxPtr(BN_CTX_new()); + if (!bn_ctx.get()) { + return absl::InternalError("Error generating bignum context."); + } + BN_CTX_start(bn_ctx.get()); + + return bn_ctx; +} + +absl::StatusOr> NewBigNum() { + bssl::UniquePtr bn(BN_new()); + if (!bn.get()) { + return absl::InternalError("Error generating bignum."); + } + return bn; +} + +absl::StatusOr BignumToString(const BIGNUM& big_num, + const size_t output_len) { + std::vector serialization(output_len); + if (BN_bn2bin_padded(serialization.data(), serialization.size(), &big_num) != + kBsslSuccess) { + return absl::InternalError( + absl::StrCat("Function BN_bn2bin_padded failed: ", GetSslErrors())); + } + return std::string(std::make_move_iterator(serialization.begin()), + std::make_move_iterator(serialization.end())); +} + +absl::StatusOr> StringToBignum( + const absl::string_view input_str) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr output, NewBigNum()); + if (!BN_bin2bn(reinterpret_cast(input_str.data()), + input_str.size(), output.get())) { + return absl::InternalError( + absl::StrCat("Function BN_bin2bn failed: ", GetSslErrors())); + } + if (!output.get()) { + return absl::InternalError("Function BN_bin2bn failed."); + } + return output; +} + +std::string GetSslErrors() { + std::string ret; + ERR_print_errors_cb( + [](const char* str, size_t len, void* ctx) -> int { + static_cast(ctx)->append(str, len); + return 1; + }, + &ret); + return ret; +} + +absl::StatusOr GenerateMask( + const RSABlindSignaturePublicKey& public_key) { + std::string mask; + if (public_key.message_mask_type() == AT_MESSAGE_MASK_CONCAT && + public_key.message_mask_size() >= kRsaMessageMaskSizeInBytes32) { + mask = std::string(public_key.message_mask_size(), '\0'); + RAND_bytes(reinterpret_cast(mask.data()), mask.size()); + } else { + return absl::InvalidArgumentError( + "Undefined or unsupported message mask type."); + } + return mask; +} + +std::string MaskMessageConcat(absl::string_view mask, + absl::string_view message) { + return absl::StrCat(mask, message); +} + +std::string EncodeMessagePublicMetadata(absl::string_view message, + absl::string_view public_metadata) { + // Prepend encoding of "msg" followed by 4 bytes representing public metadata + // length. + std::string tag = "msg"; + std::vector buffer(tag.begin(), tag.end()); + buffer.push_back((public_metadata.size() >> 24) & 0xFF); + buffer.push_back((public_metadata.size() >> 16) & 0xFF); + buffer.push_back((public_metadata.size() >> 8) & 0xFF); + buffer.push_back((public_metadata.size() >> 0) & 0xFF); + + // Finally append public metadata and then the message to the output. + std::string encoding(buffer.begin(), buffer.end()); + return absl::StrCat(encoding, public_metadata, message); +} + +absl::StatusOr ProtoHashTypeToEVPDigest( + const HashType hash_type) { + switch (hash_type) { + case AT_HASH_TYPE_SHA256: + return EVP_sha256(); + case AT_HASH_TYPE_SHA384: + return EVP_sha384(); + case AT_HASH_TYPE_UNDEFINED: + default: + return absl::InvalidArgumentError("Unknown hash type."); + } +} + +absl::StatusOr ProtoMaskGenFunctionToEVPDigest( + const MaskGenFunction mgf) { + switch (mgf) { + case AT_MGF_SHA256: + return EVP_sha256(); + case AT_MGF_SHA384: + return EVP_sha384(); + case AT_MGF_UNDEFINED: + default: + return absl::InvalidArgumentError( + "Unknown hash type for mask generation hash function."); + } +} + +absl::StatusOr> GetRsaSqrtTwo(int x) { + // Compute hard-coded sqrt(2). + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr sqrt2, NewBigNum()); + // TODO(b/277606961): simplify RsaSqrtTwo initialization logic + for (int i = internal::kBoringSSLRSASqrtTwo.size() - 2; i >= 0; i = i - 2) { + // Add the uint32_t values as words directly and shift. + // 'i' is the "hi" value of a uint64_t, and 'i+1' is the "lo" value. + if (BN_add_word(sqrt2.get(), internal::kBoringSSLRSASqrtTwo[i]) != 1) { + return absl::InternalError(absl::StrCat( + "Cannot add word to compute RSA sqrt(2): ", GetSslErrors())); + } + if (BN_lshift(sqrt2.get(), sqrt2.get(), 32) != 1) { + return absl::InternalError(absl::StrCat( + "Cannot shift to compute RSA sqrt(2): ", GetSslErrors())); + } + if (BN_add_word(sqrt2.get(), internal::kBoringSSLRSASqrtTwo[i+1]) != 1) { + return absl::InternalError(absl::StrCat( + "Cannot add word to compute RSA sqrt(2): ", GetSslErrors())); + } + if (i > 0) { + if (BN_lshift(sqrt2.get(), sqrt2.get(), 32) != 1) { + return absl::InternalError(absl::StrCat( + "Cannot shift to compute RSA sqrt(2): ", GetSslErrors())); + } + } + } + + // Check that hard-coded result is correct length. + int sqrt2_bits = 32 * internal::kBoringSSLRSASqrtTwo.size(); + if (BN_num_bits(sqrt2.get()) != sqrt2_bits) { + return absl::InternalError("RSA sqrt(2) is not correct length."); + } + + // Either shift left or right depending on value x. + if (sqrt2_bits > x) { + if (BN_rshift(sqrt2.get(), sqrt2.get(), sqrt2_bits - x) != 1) { + return absl::InternalError( + absl::StrCat("Cannot rshift to compute 2^(x-1/2): ", GetSslErrors())); + } + } else { + // Round up and be pessimistic about minimium factors. + if (BN_add_word(sqrt2.get(), 1) != 1 || + BN_lshift(sqrt2.get(), sqrt2.get(), x - sqrt2_bits) != 1) { + return absl::InternalError(absl::StrCat( + "Cannot add/lshift to compute 2^(x-1/2): ", GetSslErrors())); + } + } + + // Check that 2^(x - 1/2) is correct length. + if (BN_num_bits(sqrt2.get()) != x) { + return absl::InternalError( + "2^(x-1/2) is not correct length after shifting."); + } + + return std::move(sqrt2); +} + +absl::StatusOr> ComputePowerOfTwo(int x) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr ret, NewBigNum()); + if (BN_set_bit(ret.get(), x) != 1) { + return absl::InternalError( + absl::StrCat("Unable to set bit to compute 2^x: ", GetSslErrors())); + } + if (!BN_is_pow2(ret.get()) || !BN_is_bit_set(ret.get(), x)) { + return absl::InternalError(absl::StrCat("Unable to compute 2^", x, ".")); + } + return ret; +} + +absl::StatusOr ComputeHash(absl::string_view input, + const EVP_MD& hasher) { + std::string digest; + digest.resize(EVP_MAX_MD_SIZE); + + uint32_t digest_length = 0; + if (EVP_Digest(input.data(), input.length(), + reinterpret_cast(&digest[0]), &digest_length, + &hasher, /*impl=*/nullptr) != 1) { + return absl::InternalError(absl::StrCat( + "Openssl internal error computing hash: ", GetSslErrors())); + } + digest.resize(digest_length); + return digest; +} + +absl::StatusOr> AnonymousTokensRSAPrivateKeyToRSA( + const RSAPrivateKey& private_key) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr n, + StringToBignum(private_key.n())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr e, + StringToBignum(private_key.e())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr d, + StringToBignum(private_key.d())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr p, + StringToBignum(private_key.p())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr q, + StringToBignum(private_key.q())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr dp, + StringToBignum(private_key.dp())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr dq, + StringToBignum(private_key.dq())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr crt, + StringToBignum(private_key.crt())); + + bssl::UniquePtr rsa_private_key(RSA_new()); + // Populate private key. + if (!rsa_private_key.get()) { + return absl::InternalError( + absl::StrCat("RSA_new failed: ", GetSslErrors())); + } else if (RSA_set0_key(rsa_private_key.get(), n.get(), e.get(), d.get()) != + kBsslSuccess) { + return absl::InternalError( + absl::StrCat("RSA_set0_key failed: ", GetSslErrors())); + } else if (RSA_set0_factors(rsa_private_key.get(), p.get(), q.get()) != + kBsslSuccess) { + return absl::InternalError( + absl::StrCat("RSA_set0_factors failed: ", GetSslErrors())); + } else if (RSA_set0_crt_params(rsa_private_key.get(), dp.get(), dq.get(), + crt.get()) != kBsslSuccess) { + return absl::InternalError( + absl::StrCat("RSA_set0_crt_params failed: ", GetSslErrors())); + } else { + n.release(); + e.release(); + d.release(); + p.release(); + q.release(); + dp.release(); + dq.release(); + crt.release(); + } + return std::move(rsa_private_key); +} + +absl::StatusOr> AnonymousTokensRSAPublicKeyToRSA( + const RSAPublicKey& public_key) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_modulus, + StringToBignum(public_key.n())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_e, + StringToBignum(public_key.e())); + // Convert to OpenSSL RSA. + bssl::UniquePtr rsa_public_key(RSA_new()); + if (!rsa_public_key.get()) { + return absl::InternalError( + absl::StrCat("RSA_new failed: ", GetSslErrors())); + } else if (RSA_set0_key(rsa_public_key.get(), rsa_modulus.get(), rsa_e.get(), + nullptr) != kBsslSuccess) { + return absl::InternalError( + absl::StrCat("RSA_set0_key failed: ", GetSslErrors())); + } + // RSA_set0_key takes ownership of the pointers under rsa_modulus, new_e on + // success. + rsa_modulus.release(); + rsa_e.release(); + return rsa_public_key; +} + +absl::StatusOr> ComputeCarmichaelLcm( + const BIGNUM& phi_p, const BIGNUM& phi_q, BN_CTX& bn_ctx) { + // To compute lcm(phi(p), phi(q)), we first compute phi(n) = + // (p-1)(q-1). As n is assumed to be a safe RSA modulus (signing_key is + // assumed to be part of a strong rsa key pair), phi(n) = (p-1)(q-1) = + // (2 phi(p))(2 phi(q)) = 4 * phi(p) * phi(q) where phi(p) and phi(q) are also + // primes. So we get the lcm by outputting phi(n) >> 1 = 2 * phi(p) * phi(q). + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr phi_n, NewBigNum()); + if (BN_mul(phi_n.get(), &phi_p, &phi_q, &bn_ctx) != 1) { + return absl::InternalError( + absl::StrCat("Unable to compute phi(n): ", GetSslErrors())); + } + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr lcm, NewBigNum()); + if (BN_rshift1(lcm.get(), phi_n.get()) != 1) { + return absl::InternalError(absl::StrCat( + "Could not compute LCM(phi(p), phi(q)): ", GetSslErrors())); + } + return lcm; +} + +absl::StatusOr> PublicMetadataExponent( + const BIGNUM& n, absl::string_view public_metadata) { + // Check modulus length. + if (BN_num_bits(&n) % 2 == 1) { + return absl::InvalidArgumentError( + "Strong RSA modulus should be even length."); + } + int modulus_bytes = BN_num_bytes(&n); + // The integer modulus_bytes is expected to be a power of 2. + int prime_bytes = modulus_bytes / 2; + + ANON_TOKENS_ASSIGN_OR_RETURN(std::string rsa_modulus_str, + BignumToString(n, modulus_bytes)); + + // Get HKDF output of length prime_bytes. + ANON_TOKENS_ASSIGN_OR_RETURN( + bssl::UniquePtr exponent, + internal::PublicMetadataHashWithHKDF(public_metadata, rsa_modulus_str, + prime_bytes)); + + // We need to generate random odd exponents < 2^(primes_bits - 2) where + // prime_bits = prime_bytes * 8. This will guarantee that the resulting + // exponent is coprime to phi(N) = 4p'q' as 2^(prime_bits - 2) < p', q' < + // 2^(prime_bits - 1). + // + // To do this, we can truncate the HKDF output (exponent) which is prime_bits + // long, to prime_bits - 2, by clearing its top two bits. We then set the + // least significant bit to 1. This way the final exponent will be less than + // 2^(primes_bits - 2) and will always be odd. + if (BN_clear_bit(exponent.get(), (prime_bytes * 8) - 1) != kBsslSuccess || + BN_clear_bit(exponent.get(), (prime_bytes * 8) - 2) != kBsslSuccess || + BN_set_bit(exponent.get(), 0) != kBsslSuccess) { + return absl::InvalidArgumentError(absl::StrCat( + "Could not clear the two most significant bits and set the least " + "significant bit to zero: ", + GetSslErrors())); + } + // Check that exponent is small enough to ensure it is coprime to phi(n). + if (BN_num_bits(exponent.get()) >= (8 * prime_bytes - 1)) { + return absl::InternalError("Generated exponent is too large."); + } + + return exponent; +} + +absl::StatusOr> ComputeFinalExponentUnderPublicMetadata( + const BIGNUM& n, const BIGNUM& e, absl::string_view public_metadata) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr md_exp, + PublicMetadataExponent(n, public_metadata)); + ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr bn_ctx, GetAndStartBigNumCtx()); + // new_e=e*md_exp + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr new_e, NewBigNum()); + if (BN_mul(new_e.get(), md_exp.get(), &e, bn_ctx.get()) != kBsslSuccess) { + return absl::InternalError( + absl::StrCat("Unable to multiply e with md_exp: ", GetSslErrors())); + } + return new_e; +} + +absl::Status RsaBlindSignatureVerify( + const int salt_length, const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + RSA* rsa_public_key, const BIGNUM& rsa_modulus, + const BIGNUM& augmented_rsa_e, absl::string_view signature, + absl::string_view message, + std::optional public_metadata) { + std::string augmented_message(message); + if (public_metadata.has_value()) { + augmented_message = EncodeMessagePublicMetadata(message, *public_metadata); + } + ANON_TOKENS_ASSIGN_OR_RETURN(std::string message_digest, + ComputeHash(augmented_message, *sig_hash)); + const int hash_size = EVP_MD_size(sig_hash); + // Make sure the size of the digest is correct. + if (message_digest.size() != hash_size) { + return absl::InvalidArgumentError( + absl::StrCat("Size of the digest doesn't match the one " + "of the hashing algorithm; expected ", + hash_size, " got ", message_digest.size())); + } + const int rsa_modulus_size = BN_num_bytes(&rsa_modulus); + if (signature.size() != rsa_modulus_size) { + return absl::InvalidArgumentError( + "Signature size not equal to modulus size."); + } + + std::string recovered_message_digest(rsa_modulus_size, 0); + if (!public_metadata.has_value()) { + int recovered_message_digest_size = RSA_public_decrypt( + /*flen=*/signature.size(), + /*from=*/reinterpret_cast(signature.data()), + /*to=*/ + reinterpret_cast(recovered_message_digest.data()), + /*rsa=*/rsa_public_key, + /*padding=*/RSA_NO_PADDING); + if (recovered_message_digest_size != rsa_modulus_size) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid signature size (likely an incorrect key is " + "used); expected ", + rsa_modulus_size, " got ", recovered_message_digest_size, + ": ", GetSslErrors())); + } + } else { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr signature_bn, + StringToBignum(signature)); + if (BN_ucmp(signature_bn.get(), &rsa_modulus) >= 0) { + return absl::InternalError("Data too large for modulus."); + } + ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr bn_ctx, GetAndStartBigNumCtx()); + bssl::UniquePtr bn_mont_ctx( + BN_MONT_CTX_new_for_modulus(&rsa_modulus, bn_ctx.get())); + if (!bn_mont_ctx) { + return absl::InternalError("BN_MONT_CTX_new_for_modulus failed."); + } + ANON_TOKENS_ASSIGN_OR_RETURN( + bssl::UniquePtr recovered_message_digest_bn, NewBigNum()); + if (BN_mod_exp_mont(recovered_message_digest_bn.get(), signature_bn.get(), + &augmented_rsa_e, &rsa_modulus, bn_ctx.get(), + bn_mont_ctx.get()) != kBsslSuccess) { + return absl::InternalError("Exponentiation failed."); + } + ANON_TOKENS_ASSIGN_OR_RETURN( + recovered_message_digest, + BignumToString(*recovered_message_digest_bn, rsa_modulus_size)); + } + if (RSA_verify_PKCS1_PSS_mgf1( + rsa_public_key, reinterpret_cast(&message_digest[0]), + sig_hash, mgf1_hash, + reinterpret_cast(recovered_message_digest.data()), + salt_length) != kBsslSuccess) { + return absl::InvalidArgumentError( + absl::StrCat("PSS padding verification failed: ", GetSslErrors())); + } + return absl::OkStatus(); +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h new file mode 100644 index 000000000000..109476fe35dc --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h @@ -0,0 +1,189 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_CRYPTO_UTILS_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_CRYPTO_UTILS_H_ + +#include + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/base.h" +#include "openssl/bn.h" +#include "openssl/evp.h" +#include "openssl/rsa.h" +#include "quiche/common/platform/api/quiche_export.h" +// copybara:strip_begin(internal comment) +// The QUICHE_EXPORT annotation is necessary for some classes and functions +// to link correctly on Windows. Please do not remove them! +// copybara:strip_end + +namespace private_membership { +namespace anonymous_tokens { + +// Internal functions only exposed for testing. +namespace internal { + +// Outputs a public metadata `hash` using HKDF with the public metadata as +// input and the rsa modulus as salt. The expected output hash size is passed as +// out_len_bytes. +// +// Implementation follows the steps listed in +// https://datatracker.ietf.org/doc/draft-amjad-cfrg-partially-blind-rsa/ +// +// This method internally calls HKDF with output size of more than +// out_len_bytes and later truncates the output to out_len_bytes. This is done +// so that the output is indifferentiable from truly random bytes. +// https://cfrg.github.io/draft-irtf-cfrg-hash-to-curve/draft-irtf-cfrg-hash-to-curve.html#name-hashing-to-a-finite-field +absl::StatusOr> QUICHE_EXPORT +PublicMetadataHashWithHKDF(absl::string_view public_metadata, + absl::string_view rsa_modulus_str, + size_t out_len_bytes); + +} // namespace internal + +// Deletes a BN_CTX. +class BnCtxDeleter { + public: + void operator()(BN_CTX* ctx) { BN_CTX_free(ctx); } +}; +typedef std::unique_ptr BnCtxPtr; + +// Deletes a BN_MONT_CTX. +class BnMontCtxDeleter { + public: + void operator()(BN_MONT_CTX* mont_ctx) { BN_MONT_CTX_free(mont_ctx); } +}; +typedef std::unique_ptr BnMontCtxPtr; + +// Deletes an EVP_MD_CTX. +class EvpMdCtxDeleter { + public: + void operator()(EVP_MD_CTX* ctx) { EVP_MD_CTX_destroy(ctx); } +}; +typedef std::unique_ptr EvpMdCtxPtr; + +// Creates and starts a BIGNUM context. +absl::StatusOr QUICHE_EXPORT GetAndStartBigNumCtx(); + +// Creates a new BIGNUM. +absl::StatusOr> QUICHE_EXPORT NewBigNum(); + +// Converts a BIGNUM to string. +absl::StatusOr QUICHE_EXPORT BignumToString( + const BIGNUM& big_num, size_t output_len); + +// Converts a string to BIGNUM. +absl::StatusOr> QUICHE_EXPORT StringToBignum( + absl::string_view input_str); + +// Retrieve error messages from OpenSSL. +std::string QUICHE_EXPORT GetSslErrors(); + +// Generate a message mask. For more details, see +// https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/ +absl::StatusOr QUICHE_EXPORT GenerateMask( + const RSABlindSignaturePublicKey& public_key); + +// Mask message using protocol at +// https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/ +std::string QUICHE_EXPORT MaskMessageConcat(absl::string_view mask, + absl::string_view message); + +// Encode Message and Public Metadata using steps in +// https://datatracker.ietf.org/doc/draft-amjad-cfrg-partially-blind-rsa/ +// +// The length of public metadata must fit in 4 bytes. +std::string QUICHE_EXPORT EncodeMessagePublicMetadata( + absl::string_view message, absl::string_view public_metadata); + +// Compute 2^(x - 1/2). +absl::StatusOr> QUICHE_EXPORT GetRsaSqrtTwo( + int x); + +// Compute compute 2^x. +absl::StatusOr> QUICHE_EXPORT ComputePowerOfTwo( + int x); + +// Converts the AnonymousTokens proto hash type to the equivalent EVP digest. +absl::StatusOr QUICHE_EXPORT +ProtoHashTypeToEVPDigest(HashType hash_type); + +// Converts the AnonymousTokens proto hash type for mask generation function to +// the equivalent EVP digest. +absl::StatusOr QUICHE_EXPORT +ProtoMaskGenFunctionToEVPDigest(MaskGenFunction mgf); + +// ComputeHash sub-routine used during blindness and verification of RSA blind +// signatures protocol with or without public metadata. +absl::StatusOr QUICHE_EXPORT ComputeHash( + absl::string_view input, const EVP_MD& hasher); + +// Computes the Carmichael LCM given phi(p) and phi(q) where N = p*q is a safe +// RSA modulus. +absl::StatusOr> QUICHE_EXPORT +ComputeCarmichaelLcm(const BIGNUM& phi_p, const BIGNUM& phi_q, BN_CTX& bn_ctx); + +// Converts AnonymousTokens::RSAPrivateKey to bssl::UniquePtr without +// public metadata augmentation. +absl::StatusOr> QUICHE_EXPORT +AnonymousTokensRSAPrivateKeyToRSA(const RSAPrivateKey& private_key); + +// Converts AnonymousTokens::RSAPublicKey to bssl::UniquePtr without +// public metadata augmentation. +absl::StatusOr> QUICHE_EXPORT +AnonymousTokensRSAPublicKeyToRSA(const RSAPublicKey& public_key); + +// Compute exponent based only on the public metadata. Assumes that n is a safe +// modulus i.e. it produces a strong RSA key pair. If not, the exponent may be +// invalid. +absl::StatusOr> QUICHE_EXPORT +PublicMetadataExponent(const BIGNUM& n, absl::string_view public_metadata); + +// Computes final exponent by multiplying the public exponent e with the +// exponent derived from public metadata. Assumes that n is a safe modulus i.e. +// it produces a strong RSA key pair. If not, the exponent may be invalid. +// +// Empty public metadata is considered to be a valid value for public_metadata +// and will output an exponent different than `e` as well. +absl::StatusOr> QUICHE_EXPORT +ComputeFinalExponentUnderPublicMetadata(const BIGNUM& n, const BIGNUM& e, + absl::string_view public_metadata); + +// Helper method that implements RSA PSS Blind Signatures verification protocol +// for both the standard scheme as well as the public metadata version. +// +// The standard public exponent e in rsa_public_key should always have a +// standard value even if the public_metada is not std::nullopt. +// +// If the public_metadata is set to std::nullopt, augmented_rsa_e should be +// equal to a standard public exponent same as the value of e in rsa_public_key. +// Otherwise, it will be equal to a new public exponent value derived using the +// public metadata. +absl::Status QUICHE_EXPORT RsaBlindSignatureVerify( + int salt_length, const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + RSA* rsa_public_key, const BIGNUM& rsa_modulus, + const BIGNUM& augmented_rsa_e, absl::string_view signature, + absl::string_view message, + std::optional public_metadata = std::nullopt); + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_CRYPTO_UTILS_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc new file mode 100644 index 000000000000..f138eb41d471 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.cc @@ -0,0 +1,206 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { + +RsaBlindSigner::RsaBlindSigner(std::optional public_metadata, + bssl::UniquePtr rsa_modulus, + bssl::UniquePtr rsa_p, + bssl::UniquePtr rsa_q, + bssl::UniquePtr augmented_rsa_e, + bssl::UniquePtr augmented_rsa_d, + bssl::UniquePtr rsa_standard_key) + : public_metadata_(public_metadata), + rsa_modulus_(std::move(rsa_modulus)), + rsa_p_(std::move(rsa_p)), + rsa_q_(std::move(rsa_q)), + augmented_rsa_e_(std::move(augmented_rsa_e)), + augmented_rsa_d_(std::move(augmented_rsa_d)), + rsa_standard_key_(std::move(rsa_standard_key)) {} + +absl::StatusOr> RsaBlindSigner::New( + const RSAPrivateKey& signing_key, + std::optional public_metadata) { + if (!public_metadata.has_value()) { + // The RSA modulus and exponent are checked as part of the conversion to + // bssl::UniquePtr. + ANON_TOKENS_ASSIGN_OR_RETURN( + bssl::UniquePtr rsa_standard_key, + AnonymousTokensRSAPrivateKeyToRSA(signing_key)); + return absl::WrapUnique( + new RsaBlindSigner(public_metadata, nullptr, nullptr, nullptr, nullptr, + nullptr, std::move(rsa_standard_key))); + } + + // Convert RSA modulus n (=p*q) to BIGNUM + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_modulus, + StringToBignum(signing_key.n())); + // Convert p & q to BIGNUM + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_p, + StringToBignum(signing_key.p())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_q, + StringToBignum(signing_key.q())); + // Convert public exponent e to BIGNUM + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr old_e, + StringToBignum(signing_key.e())); + // Convert public exponent e to BIGNUM + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr old_d, + StringToBignum(signing_key.d())); + + // Compute new exponents based on public metadata. + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr augmented_rsa_e, + ComputeFinalExponentUnderPublicMetadata( + *rsa_modulus, *old_e, *public_metadata)); + + // Compute phi(p) = p-1 and phi(q) = q-1 + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr phi_p, NewBigNum()); + if (BN_sub(phi_p.get(), rsa_p.get(), BN_value_one()) != 1) { + return absl::InternalError( + absl::StrCat("Unable to compute phi(p): ", GetSslErrors())); + } + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr phi_q, NewBigNum()); + if (BN_sub(phi_q.get(), rsa_q.get(), BN_value_one()) != 1) { + return absl::InternalError( + absl::StrCat("Unable to compute phi(q): ", GetSslErrors())); + } + + bssl::UniquePtr bn_ctx(BN_CTX_new()); + if (!bn_ctx) { + return absl::InternalError("BN_CTX_new failed."); + } + // Compute lcm(phi(p), phi(q)). + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr lcm, + ComputeCarmichaelLcm(*phi_p, *phi_q, *bn_ctx)); + + // Compute the new private exponent new_d + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr augmented_rsa_d, + NewBigNum()); + if (!BN_mod_inverse(augmented_rsa_d.get(), augmented_rsa_e.get(), lcm.get(), + bn_ctx.get())) { + return absl::InternalError( + absl::StrCat("Could not compute private exponent d: ", GetSslErrors())); + } + + return absl::WrapUnique(new RsaBlindSigner( + *public_metadata, std::move(rsa_modulus), std::move(rsa_p), + std::move(rsa_q), std::move(augmented_rsa_e), + std::move(augmented_rsa_d))); +} + +// Helper Signature method that assumes RSA_NO_PADDING. +// TODO(b/271438729): Adding blinding of private operations in RSA Sign +// TODO(b/271438266): Implement RsaSign using the Chinese Remainder Theorem +absl::StatusOr RsaBlindSigner::SignInternal( + absl::string_view input) const { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr input_bn, + StringToBignum(input)); + if (BN_ucmp(input_bn.get(), rsa_modulus_.get()) >= 0) { + return absl::InvalidArgumentError( + "RsaSign input size too large for modulus size"); + } + + ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr ctx, GetAndStartBigNumCtx()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr result, NewBigNum()); + // TODO(b/271438266): Replace with constant-time implementation. + if (!BN_mod_exp(result.get(), input_bn.get(), augmented_rsa_d_.get(), + rsa_modulus_.get(), ctx.get())) { + return absl::InternalError("BN_mod_exp_mont_consttime failed in RsaSign"); + } + + // Verify the result to protect against fault attacks as described in + // boringssl. Also serves as a check for correctness. + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr vrfy, NewBigNum()); + if (vrfy == nullptr || + !BN_mod_exp(vrfy.get(), result.get(), augmented_rsa_e_.get(), + rsa_modulus_.get(), ctx.get()) || + BN_cmp(vrfy.get(), input_bn.get()) != 0) { + return absl::InternalError("Signature verification failed in RsaSign"); + } + + return BignumToString(*result, BN_num_bytes(rsa_modulus_.get())); +} + +absl::StatusOr RsaBlindSigner::Sign( + const absl::string_view blinded_data) const { + if (blinded_data.empty() || blinded_data.data() == nullptr) { + return absl::InvalidArgumentError("blinded_data string is empty."); + } + + int mod_size; + if (!public_metadata_.has_value()) { + mod_size = RSA_size(rsa_standard_key_.get()); + } else { + mod_size = BN_num_bytes(rsa_modulus_.get()); + } + if (blinded_data.size() != mod_size) { + return absl::InternalError(absl::StrCat( + "Expected blind data size = ", mod_size, + " actual blind data size = ", blinded_data.size(), " bytes.")); + } + + std::string signature(mod_size, 0); + if (!public_metadata_.has_value()) { + // Compute a raw RSA signature. + size_t out_len; + if (RSA_sign_raw(/*rsa=*/rsa_standard_key_.get(), /*out_len=*/&out_len, + /*out=*/reinterpret_cast(&signature[0]), + /*max_out=*/mod_size, + /*in=*/reinterpret_cast(&blinded_data[0]), + /*in_len=*/mod_size, + /*padding=*/RSA_NO_PADDING) != kBsslSuccess) { + return absl::InternalError( + "RSA_sign_raw failed when called from RsaBlindSigner::Sign"); + } + if (out_len != mod_size && out_len == signature.size()) { + return absl::InternalError(absl::StrCat( + "Expected value of out_len = ", mod_size, + " bytes, actual value of out_len and signature.size() = ", out_len, + " and ", signature.size(), " bytes.")); + } + } else { + // As public metadata is not empty, we cannot use RSA_sign_raw as it might + // err on exponent size. + ANON_TOKENS_ASSIGN_OR_RETURN(signature, SignInternal(blinded_data)); + if (signature.size() != mod_size) { + return absl::InternalError(absl::StrCat( + "Expected value of signature.size() = ", mod_size, + " bytes, actual value of signature.size() = ", signature.size(), + " bytes.")); + } + } + return signature; +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h new file mode 100644 index 000000000000..e7f503e033ce --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_BLIND_SIGNER_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_BLIND_SIGNER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blind_signer.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/common/platform/api/quiche_export.h" +// copybara:strip_begin(internal comment) +// The QUICHE_EXPORT annotation is necessary for some classes and functions +// to link correctly on Windows. Please do not remove them! +// copybara:strip_end + +namespace private_membership { +namespace anonymous_tokens { + +// The RSA SSA (Signature Schemes with Appendix) using PSS (Probabilistic +// Signature Scheme) encoding is defined at +// https://tools.ietf.org/html/rfc8017#section-8.1). This implementation uses +// Boring SSL for the underlying cryptographic operations. +class QUICHE_EXPORT RsaBlindSigner : public BlindSigner { + public: + ~RsaBlindSigner() override = default; + RsaBlindSigner(const RsaBlindSigner&) = delete; + RsaBlindSigner& operator=(const RsaBlindSigner&) = delete; + + // Passing of public_metadata is optional. If it is set to any value including + // an empty string, RsaBlindSigner will assume that partially blind RSA + // signature protocol is being executed. + static absl::StatusOr> New( + const RSAPrivateKey& signing_key, + std::optional public_metadata = std::nullopt); + + // Computes the signature for 'blinded_data'. + absl::StatusOr Sign( + absl::string_view blinded_data) const override; + + private: + // Use New to construct. + RsaBlindSigner(std::optional public_metadata, + bssl::UniquePtr rsa_modulus, + bssl::UniquePtr rsa_p, bssl::UniquePtr rsa_q, + bssl::UniquePtr augmented_rsa_e, + bssl::UniquePtr augmented_rsa_d, + bssl::UniquePtr rsa_standard_key = nullptr); + + absl::StatusOr SignInternal(absl::string_view input) const; + + const std::optional public_metadata_; + + // We only keep these for the case when we use RSA blind signatures with + // public metadata. Specifically augmented_rsa_e_ and augmented_rsa_d_ is + // derived using the public metadata. + const bssl::UniquePtr rsa_modulus_; + const bssl::UniquePtr rsa_p_; + const bssl::UniquePtr rsa_q_; + const bssl::UniquePtr augmented_rsa_e_; + const bssl::UniquePtr augmented_rsa_d_; + + // We only keep this for the case when we use standard RSA blind signatures + // without public metadata. + const bssl::UniquePtr rsa_standard_key_; +}; + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_BLIND_SIGNER_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc new file mode 100644 index 000000000000..7106eb55f18e --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer_test.cc @@ -0,0 +1,262 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blind_signer.h" + +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/digest.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { +namespace { + +using CreateTestKeyPairFunction = + absl::StatusOr>(); + +class RsaBlindSignerTest + : public ::testing::TestWithParam { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto keys_pair, (*GetParam())()); + public_key_ = std::move(keys_pair.first); + private_key_ = std::move(keys_pair.second); + generator_.seed(0); + // NOTE: using recommended RsaSsaPssParams + sig_hash_ = EVP_sha384(); + mgf1_hash_ = EVP_sha384(); + salt_length_ = kSaltLengthInBytes48; + } + + RSAPrivateKey private_key_; + RSAPublicKey public_key_; + std::mt19937_64 generator_; + const EVP_MD *sig_hash_; // Owned by BoringSSL. + const EVP_MD *mgf1_hash_; // Owned by BoringSSL. + int salt_length_; + std::uniform_int_distribution distr_u8_ = + std::uniform_int_distribution{0, 255}; +}; + +// This test only tests whether the implemented signer 'signs' properly. The +// outline of method calls in this test should not be assumed a secure signature +// scheme (and used in other places) as the security has not been +// proven/analyzed. +// +// Test for the standard signer does not take public metadata as a parameter +// which means public metadata is set to std::nullopt. +TEST_P(RsaBlindSignerTest, StandardSignerWorks) { + absl::string_view message = "Hello World!"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(message, public_key_, sig_hash_, mgf1_hash_, + salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr signer, + RsaBlindSigner::New(private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string potentially_insecure_signature, + signer->Sign(encoded_message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto verifier, + RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, public_key_)); + QUICHE_EXPECT_OK(verifier->Verify(potentially_insecure_signature, message)); +} + +TEST_P(RsaBlindSignerTest, SignerFails) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr signer, + RsaBlindSigner::New(private_key_)); + absl::string_view message = "Hello World!"; + EXPECT_THAT(signer->Sign(message), + quiche::test::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("Expected blind data size"))); + + int sig_size = public_key_.n().size(); + std::string message2 = RandomString(sig_size, &distr_u8_, &generator_); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string insecure_sig, + signer->Sign(message2)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto verifier, + RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, public_key_)); + EXPECT_THAT( + verifier->Verify(insecure_sig, message2), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +INSTANTIATE_TEST_SUITE_P(RsaBlindSignerTest, RsaBlindSignerTest, + ::testing::Values(&GetStrongRsaKeys2048, + &GetAnotherStrongRsaKeys2048, + &GetStrongRsaKeys3072, + &GetStrongRsaKeys4096)); + +class RsaBlindSignerTestWithPublicMetadata + : public ::testing::TestWithParam { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto keys_pair, (*GetParam())()); + public_key_ = std::move(keys_pair.first); + private_key_ = std::move(keys_pair.second); + // NOTE: using recommended RsaSsaPssParams + sig_hash_ = EVP_sha384(); + mgf1_hash_ = EVP_sha384(); + salt_length_ = kSaltLengthInBytes48; + } + + RSAPrivateKey private_key_; + RSAPublicKey public_key_; + const EVP_MD *sig_hash_; // Owned by BoringSSL. + const EVP_MD *mgf1_hash_; // Owned by BoringSSL. + int salt_length_; +}; + +// This test only tests whether the implemented signer 'signs' properly under +// some public metadata. The outline of method calls in this test should not +// be assumed a secure signature scheme (and used in other places) as the +// security has not been proven/analyzed. +TEST_P(RsaBlindSignerTestWithPublicMetadata, SignerWorksWithPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr signer, + RsaBlindSigner::New(private_key_, public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string potentially_insecure_signature, + signer->Sign(encoded_message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, + public_key_, public_metadata)); + QUICHE_EXPECT_OK(verifier->Verify(potentially_insecure_signature, message)); +} + +TEST_P(RsaBlindSignerTestWithPublicMetadata, + SignerWorksWithEmptyPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view empty_public_metadata = ""; + std::string augmented_message = + EncodeMessagePublicMetadata(message, empty_public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr signer, + RsaBlindSigner::New(private_key_, empty_public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string potentially_insecure_signature, + signer->Sign(encoded_message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, + RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, public_key_, + empty_public_metadata)); + QUICHE_EXPECT_OK(verifier->Verify(potentially_insecure_signature, message)); +} + +TEST_P(RsaBlindSignerTestWithPublicMetadata, + SignatureFailstoVerifyWithWrongPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + absl::string_view public_metadata_2 = "pubmd2"; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr signer, + RsaBlindSigner::New(private_key_, public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string potentially_insecure_signature, + signer->Sign(encoded_message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, + public_key_, public_metadata_2)); + EXPECT_THAT( + verifier->Verify(potentially_insecure_signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +TEST_P(RsaBlindSignerTestWithPublicMetadata, + SignatureFailsToVerifyWithNoPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + absl::string_view public_metadata_2 = ""; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr signer, + RsaBlindSigner::New(private_key_, public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string potentially_insecure_signature, + signer->Sign(encoded_message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, + public_key_, public_metadata_2)); + EXPECT_THAT( + verifier->Verify(potentially_insecure_signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +INSTANTIATE_TEST_SUITE_P( + RsaBlindSignerTestWithPublicMetadata, RsaBlindSignerTestWithPublicMetadata, + ::testing::Values(&GetStrongRsaKeys2048, &GetAnotherStrongRsaKeys2048, + &GetStrongRsaKeys3072, &GetStrongRsaKeys4096)); + +// TODO(b/275956922): Consolidate all tests that use IETF test vectors into one +// E2E test. +// +// This test uses IETF test vectors for RSA blind signatures with public +// metadata. The vectors includes tests for public metadata set to an empty +// string as well as a non-empty value. +TEST(IetfRsaBlindSignerTest, + IetfRsaBlindSignaturesWithPublicMetadataTestVectorsSuccess) { + auto test_vectors = GetIetfRsaBlindSignatureWithPublicMetadataTestVectors(); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto test_key, + GetIetfRsaBlindSignatureWithPublicMetadataTestKeys()); + for (const auto &test_vector : test_vectors) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr signer, + RsaBlindSigner::New(test_key.second, test_vector.public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blind_signature, + signer->Sign(test_vector.blinded_message)); + EXPECT_EQ(blind_signature, test_vector.blinded_signature); + } +} + +} // namespace +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc new file mode 100644 index 000000000000..aacd44569a7f --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.cc @@ -0,0 +1,286 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/digest.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { + +absl::StatusOr> RsaBlinder::New( + const RSABlindSignaturePublicKey& public_key, + std::optional public_metadata) { + RSAPublicKey rsa_public_key_proto; + if (!rsa_public_key_proto.ParseFromString( + public_key.serialized_public_key())) { + return absl::InvalidArgumentError("Public key is malformed."); + } + + // Convert to OpenSSL RSA which will be used in the code paths for the + // standard RSA blind signature scheme. + // + // Moreover, it will also be passed as an argument to PSS related padding and + // padding verification methods irrespective of whether RsaBlinder is being + // used as a part of the standard RSA blind signature scheme or the scheme + // with public metadata support. + ANON_TOKENS_ASSIGN_OR_RETURN( + bssl::UniquePtr rsa_public_key, + AnonymousTokensRSAPublicKeyToRSA(rsa_public_key_proto)); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_modulus, + StringToBignum(rsa_public_key_proto.n())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_e, + StringToBignum(rsa_public_key_proto.e())); + + bssl::UniquePtr augmented_rsa_e = nullptr; + // If public metadata is supported, RsaBlinder will compute a new public + // exponent using the public metadata. + // + // Empty string is a valid public metadata value. + if (public_metadata.has_value()) { + ANON_TOKENS_ASSIGN_OR_RETURN( + augmented_rsa_e, + ComputeFinalExponentUnderPublicMetadata( + *rsa_modulus.get(), *rsa_e.get(), *public_metadata)); + } else { + augmented_rsa_e = std::move(rsa_e); + } + + // Owned by BoringSSL. + ANON_TOKENS_ASSIGN_OR_RETURN( + const EVP_MD* sig_hash, + ProtoHashTypeToEVPDigest(public_key.sig_hash_type())); + + // Owned by BoringSSL. + ANON_TOKENS_ASSIGN_OR_RETURN( + const EVP_MD* mgf1_hash, + ProtoMaskGenFunctionToEVPDigest(public_key.mask_gen_function())); + + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr r, NewBigNum()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr r_inv_mont, NewBigNum()); + + // Limit r between [2, n) so that an r of 1 never happens. An r of 1 doesn't + // blind. + if (BN_rand_range_ex(r.get(), 2, rsa_modulus.get()) != kBsslSuccess) { + return absl::InternalError( + "BN_rand_range_ex failed when called from RsaBlinder::New."); + } + + bssl::UniquePtr bn_ctx(BN_CTX_new()); + if (!bn_ctx) { + return absl::InternalError("BN_CTX_new failed."); + } + + bssl::UniquePtr bn_mont_ctx( + BN_MONT_CTX_new_for_modulus(rsa_modulus.get(), bn_ctx.get())); + if (!bn_mont_ctx) { + return absl::InternalError("BN_MONT_CTX_new_for_modulus failed."); + } + + // We wish to compute r^-1 in the Montgomery domain, or r^-1 R mod n. This is + // can be done with BN_mod_inverse_blinded followed by BN_to_montgomery, but + // it is equivalent and slightly more efficient to first compute r R^-1 mod n + // with BN_from_montgomery, and then inverting that to give r^-1 R mod n. + int is_r_not_invertible = 0; + if (BN_from_montgomery(r_inv_mont.get(), r.get(), bn_mont_ctx.get(), + bn_ctx.get()) != kBsslSuccess || + BN_mod_inverse_blinded(r_inv_mont.get(), &is_r_not_invertible, + r_inv_mont.get(), bn_mont_ctx.get(), + bn_ctx.get()) != kBsslSuccess) { + return absl::InternalError( + absl::StrCat("BN_mod_inverse failed when called from RsaBlinder::New, " + "is_r_not_invertible = ", + is_r_not_invertible)); + } + + return absl::WrapUnique(new RsaBlinder( + public_key.salt_length(), public_metadata, sig_hash, mgf1_hash, + std::move(rsa_public_key), std::move(rsa_modulus), + std::move(augmented_rsa_e), std::move(r), std::move(r_inv_mont), + std::move(bn_mont_ctx))); +} + +RsaBlinder::RsaBlinder( + int salt_length, std::optional public_metadata, + const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + bssl::UniquePtr rsa_public_key, bssl::UniquePtr rsa_modulus, + bssl::UniquePtr augmented_rsa_e, bssl::UniquePtr r, + bssl::UniquePtr r_inv_mont, bssl::UniquePtr mont_n) + : salt_length_(salt_length), + public_metadata_(public_metadata), + sig_hash_(sig_hash), + mgf1_hash_(mgf1_hash), + rsa_public_key_(std::move(rsa_public_key)), + rsa_modulus_(std::move(rsa_modulus)), + augmented_rsa_e_(std::move(augmented_rsa_e)), + r_(std::move(r)), + r_inv_mont_(std::move(r_inv_mont)), + mont_n_(std::move(mont_n)), + blinder_state_(RsaBlinder::BlinderState::kCreated) {} + +absl::StatusOr RsaBlinder::Blind(const absl::string_view message) { + // Check that the blinder state was kCreated + if (blinder_state_ != RsaBlinder::BlinderState::kCreated) { + return absl::FailedPreconditionError( + "RsaBlinder is in wrong state to blind message."); + } + std::string augmented_message(message); + if (public_metadata_.has_value()) { + augmented_message = EncodeMessagePublicMetadata(message, *public_metadata_); + } + ANON_TOKENS_ASSIGN_OR_RETURN(std::string digest_str, + ComputeHash(augmented_message, *sig_hash_)); + std::vector digest(digest_str.begin(), digest_str.end()); + + // Construct the PSS padded message, using the same workflow as BoringSSL's + // RSA_sign_pss_mgf1 for processing the message (but not signing the message): + // google3/third_party/openssl/boringssl/src/crypto/fipsmodule/rsa/rsa.c?l=557 + if (digest.size() != EVP_MD_size(sig_hash_)) { + return absl::InternalError("Invalid input message length."); + } + + // Allocate for padded length + const int padded_len = BN_num_bytes(rsa_modulus_.get()); + std::vector padded(padded_len); + + // The |md| and |mgf1_md| arguments identify the hash used to calculate + // |digest| and the MGF1 hash, respectively. If |mgf1_md| is NULL, |md| is + // used. |salt_len| specifies the expected salt length in bytes. If |salt_len| + // is -1, then the salt length is the same as the hash length. If -2, then the + // salt length is maximal given the size of |rsa|. If unsure, use -1. + if (RSA_padding_add_PKCS1_PSS_mgf1( + /*rsa=*/rsa_public_key_.get(), /*EM=*/padded.data(), + /*mHash=*/digest.data(), /*Hash=*/sig_hash_, /*mgf1Hash=*/mgf1_hash_, + /*sLen=*/salt_length_) != kBsslSuccess) { + return absl::InternalError( + "RSA_padding_add_PKCS1_PSS_mgf1 failed when called from " + "RsaBlinder::Blind"); + } + + bssl::UniquePtr bn_ctx(BN_CTX_new()); + if (!bn_ctx) { + return absl::InternalError("BN_CTX_new failed."); + } + + std::string encoded_message(padded.begin(), padded.end()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr encoded_message_bn, + StringToBignum(encoded_message)); + + // Take `r^e mod n`. This is an equivalent operation to RSA_encrypt, without + // extra encode/decode trips. + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rE, NewBigNum()); + if (BN_mod_exp_mont(rE.get(), r_.get(), augmented_rsa_e_.get(), + rsa_modulus_.get(), bn_ctx.get(), + mont_n_.get()) != kBsslSuccess) { + return absl::InternalError( + "BN_mod_exp_mont failed when called from RsaBlinder::Blind."); + } + + // Do `encoded_message*r^e mod n`. + // + // To avoid leaking side channels, we use Montgomery reduction. This would be + // FromMontgomery(ModMulMontgomery(ToMontgomery(m), ToMontgomery(r^e))). + // However, this is equivalent to ModMulMontgomery(m, ToMontgomery(r^e)). + // Each BN_mod_mul_montgomery removes a factor of R, so by having only one + // input in the Montgomery domain, we save a To/FromMontgomery pair. + // + // Internally, BN_mod_exp_mont actually computes r^e in the Montgomery domain + // and converts it out, but there is no public API for this, so we perform an + // extra conversion. + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr multiplication_res, + NewBigNum()); + if (BN_to_montgomery(multiplication_res.get(), rE.get(), mont_n_.get(), + bn_ctx.get()) != kBsslSuccess || + BN_mod_mul_montgomery(multiplication_res.get(), encoded_message_bn.get(), + multiplication_res.get(), mont_n_.get(), + bn_ctx.get()) != kBsslSuccess) { + return absl::InternalError( + "BN_mod_mul failed when called from RsaBlinder::Blind."); + } + + absl::StatusOr blinded_msg = + BignumToString(*multiplication_res, BN_num_bytes(rsa_modulus_.get())); + + // Update RsaBlinder state to kBlinded + blinder_state_ = RsaBlinder::BlinderState::kBlinded; + + return blinded_msg; +} + +// Unblinds `blind_signature`. +absl::StatusOr RsaBlinder::Unblind( + const absl::string_view blind_signature) { + if (blinder_state_ != RsaBlinder::BlinderState::kBlinded) { + return absl::FailedPreconditionError( + "RsaBlinder is in wrong state to unblind signature."); + } + const int mod_size = BN_num_bytes(rsa_modulus_.get()); + // Parse the signed_blinded_data as BIGNUM. + if (blind_signature.size() != mod_size) { + return absl::InternalError(absl::StrCat( + "Expected blind signature size = ", mod_size, + " actual blind signature size = ", blind_signature.size(), " bytes.")); + } + + bssl::UniquePtr bn_ctx(BN_CTX_new()); + if (!bn_ctx) { + return absl::InternalError("BN_CTX_new failed."); + } + + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr signed_big_num, + StringToBignum(blind_signature)); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr unblinded_sig_big, + NewBigNum()); + // Do `signed_message*r^-1 mod n`. + // + // To avoid leaking side channels, we use Montgomery reduction. This would be + // FromMontgomery(ModMulMontgomery(ToMontgomery(m), ToMontgomery(r^-1))). + // However, this is equivalent to ModMulMontgomery(m, ToMontgomery(r^-1)). + // Each BN_mod_mul_montgomery removes a factor of R, so by having only one + // input in the Montgomery domain, we save a To/FromMontgomery pair. + if (BN_mod_mul_montgomery(unblinded_sig_big.get(), signed_big_num.get(), + r_inv_mont_.get(), mont_n_.get(), + bn_ctx.get()) != kBsslSuccess) { + return absl::InternalError( + "BN_mod_mul failed when called from RsaBlinder::Unblind."); + } + absl::StatusOr unblinded_signed_message = + BignumToString(*unblinded_sig_big, + /*output_len=*/BN_num_bytes(rsa_modulus_.get())); + blinder_state_ = RsaBlinder::BlinderState::kUnblinded; + return unblinded_signed_message; +} + +absl::Status RsaBlinder::Verify(absl::string_view signature, + absl::string_view message) { + return RsaBlindSignatureVerify(salt_length_, sig_hash_, mgf1_hash_, + rsa_public_key_.get(), *rsa_modulus_.get(), + *augmented_rsa_e_.get(), signature, message, + public_metadata_); +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h new file mode 100644 index 000000000000..0fb8304d8ff5 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h @@ -0,0 +1,97 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_BLINDER_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_BLINDER_H_ + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/blinder.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/common/platform/api/quiche_export.h" +// copybara:strip_begin(internal comment) +// The QUICHE_EXPORT annotation is necessary for some classes and functions +// to link correctly on Windows. Please do not remove them! +// copybara:strip_end + +namespace private_membership { +namespace anonymous_tokens { + +// RsaBlinder `blinds` input messages, and then unblinds them after they are +// signed. +class QUICHE_EXPORT RsaBlinder : public Blinder { + public: + // Passing of public_metadata is optional. If it is set to any value including + // an empty string, RsaBlinder will assume that partially blind RSA signature + // protocol is being executed. + static absl::StatusOr> New( + const RSABlindSignaturePublicKey& public_key, + std::optional public_metadata = std::nullopt); + + // Blind `message` using n and e derived from an RSA public key and the public + // metadata if applicable. + // + // Before blinding, the `message` will first be hashed and then encoded with + // the EMSA-PSS operation. + absl::StatusOr Blind(absl::string_view message) override; + + // Unblinds `blind_signature`. + absl::StatusOr Unblind( + absl::string_view blind_signature) override; + + // Verifies an `unblinded` signature against the input message. + absl::Status Verify(absl::string_view signature, absl::string_view message); + + private: + // Use `New` to construct + RsaBlinder(int salt_length, std::optional public_metadata, + const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + bssl::UniquePtr rsa_public_key, + bssl::UniquePtr rsa_modulus, + bssl::UniquePtr augmented_rsa_e, bssl::UniquePtr r, + bssl::UniquePtr r_inv_mont, + bssl::UniquePtr mont_n); + + const int salt_length_; + std::optional public_metadata_; + const EVP_MD* sig_hash_; // Owned by BoringSSL. + const EVP_MD* mgf1_hash_; // Owned by BoringSSL. + + const bssl::UniquePtr rsa_public_key_; + // Storing RSA modulus separately for helping with BN computations. + const bssl::UniquePtr rsa_modulus_; + // If public metadata is not supported, augmented_rsa_e_ will be equal to + // public exponent e in rsa_public_key_. + const bssl::UniquePtr augmented_rsa_e_; + + const bssl::UniquePtr r_; + // r^-1 mod n in the Montgomery domain + const bssl::UniquePtr r_inv_mont_; + const bssl::UniquePtr mont_n_; + + BlinderState blinder_state_; +}; + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_BLINDER_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc new file mode 100644 index 000000000000..cd1b11d38a8c --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder_test.cc @@ -0,0 +1,360 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_blinder.h" + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/base.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { +namespace { + +// TODO(b/275965524): Figure out a way to test RsaBlinder class with IETF test +// vectors in rsa_blinder_test.cc. + +using CreateTestKeyFunction = absl::StatusOr< + std::pair, RSABlindSignaturePublicKey>>(); + +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateStandardTestKey() { + return CreateTestKey(); +} + +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateShorterTestKey() { + return CreateTestKey(/*key_size=*/256); +} + +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateLongerTestKey() { + return CreateTestKey(/*key_size=*/544); +} + +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateSHA256TestKey() { + return CreateTestKey(/*key_size=*/512, AT_HASH_TYPE_SHA256, AT_MGF_SHA256); +} + +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateLongerSaltTestKey() { + return CreateTestKey(/*key_size=*/512, AT_HASH_TYPE_SHA384, AT_MGF_SHA384, + /*salt_length=*/64); +} + +class RsaBlinderTest : public testing::TestWithParam { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto test_key, (*GetParam())()); + rsa_key_ = std::move(test_key.first); + public_key_ = std::move(test_key.second); + } + + RSABlindSignaturePublicKey public_key_; + bssl::UniquePtr rsa_key_; +}; + +TEST_P(RsaBlinderTest, BlindSignUnblindEnd2EndTest) { + const absl::string_view message = "Hello World!"; + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr blinder, + RsaBlinder::New(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_message, + blinder->Blind(message)); + EXPECT_NE(blinded_message, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_signature, + TestSign(blinded_message, rsa_key_.get())); + EXPECT_NE(blinded_signature, blinded_message); + EXPECT_NE(blinded_signature, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + EXPECT_NE(signature, blinded_signature); + EXPECT_NE(signature, blinded_message); + EXPECT_NE(signature, message); + + QUICHE_EXPECT_OK(blinder->Verify(signature, message)); +} + +TEST_P(RsaBlinderTest, DoubleBlindingFailure) { + const absl::string_view message = "Hello World2!"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr blinder, + RsaBlinder::New(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_message, + blinder->Blind(message)); + // Blind the blinded_message + absl::StatusOr result = blinder->Blind(blinded_message); + EXPECT_EQ(result.status().code(), absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(result.status().message(), testing::HasSubstr("wrong state")); + // Blind a new message + const absl::string_view new_message = "Hello World3!"; + result = blinder->Blind(new_message); + EXPECT_EQ(result.status().code(), absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(result.status().message(), testing::HasSubstr("wrong state")); +} + +TEST_P(RsaBlinderTest, DoubleUnblindingFailure) { + const absl::string_view message = "Hello World2!"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr blinder, + RsaBlinder::New(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_message, + blinder->Blind(message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_signature, + TestSign(blinded_message, rsa_key_.get())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + // Unblind the unblinded signature + absl::StatusOr result = blinder->Unblind(signature); + EXPECT_EQ(result.status().code(), absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(result.status().message(), testing::HasSubstr("wrong state")); + // Unblind the blinded_signature again + result = blinder->Unblind(signature); + EXPECT_EQ(result.status().code(), absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(result.status().message(), testing::HasSubstr("wrong state")); +} + +TEST_P(RsaBlinderTest, InvalidSignature) { + const absl::string_view message = "Hello World2!"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr blinder, + RsaBlinder::New(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_message, + blinder->Blind(message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_signature, + TestSign(blinded_message, rsa_key_.get())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + QUICHE_EXPECT_OK(blinder->Verify(signature, message)); + + // Invalidate the signature by replacing the last 10 characters by 10 '0's + for (int i = 0; i < 10; i++) { + signature.pop_back(); + } + for (int i = 0; i < 10; i++) { + signature.push_back('0'); + } + + absl::Status result = blinder->Verify(signature, message); + EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(result.message(), testing::HasSubstr("verification failed")); +} + +TEST_P(RsaBlinderTest, InvalidVerificationKey) { + const absl::string_view message = "Hello World4!"; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr blinder, + RsaBlinder::New(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_message, + blinder->Blind(message)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const std::string blinded_signature, + TestSign(blinded_message, rsa_key_.get())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto bad_key, CreateTestKey()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr bad_blinder, + RsaBlinder::New(bad_key.second)); + EXPECT_THAT(bad_blinder->Verify(signature, message).code(), + absl::StatusCode::kInvalidArgument); +} + +INSTANTIATE_TEST_SUITE_P(RsaBlinderTest, RsaBlinderTest, + testing::Values(&CreateStandardTestKey, + &CreateShorterTestKey, + &CreateLongerTestKey, + &CreateSHA256TestKey, + &CreateLongerSaltTestKey)); + +using CreateTestKeyPairFunction = + absl::StatusOr>(); + +class RsaBlinderWithPublicMetadataTest + : public testing::TestWithParam { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto test_key, (*GetParam())()); + RSABlindSignaturePublicKey public_key; + public_key.set_sig_hash_type(HashType::AT_HASH_TYPE_SHA384); + public_key.set_mask_gen_function(AT_MGF_SHA384); + public_key.set_salt_length(kSaltLengthInBytes48); + public_key.set_serialized_public_key( + std::move(test_key.first).SerializeAsString()); + public_key_ = std::move(public_key); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + rsa_key_, AnonymousTokensRSAPrivateKeyToRSA(test_key.second)); + } + + RSABlindSignaturePublicKey public_key_; + bssl::UniquePtr rsa_key_; +}; + +TEST_P(RsaBlinderWithPublicMetadataTest, + BlindSignUnblindWithPublicMetadataEnd2EndTest) { + const absl::string_view message = "Hello World!"; + const absl::string_view public_metadata = "pubmd!"; + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr blinder, + RsaBlinder::New(public_key_, public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_message, + blinder->Blind(message)); + EXPECT_NE(blinded_message, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string blinded_signature, + TestSignWithPublicMetadata(blinded_message, public_metadata, *rsa_key_)); + EXPECT_NE(blinded_signature, blinded_message); + EXPECT_NE(blinded_signature, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + EXPECT_NE(signature, blinded_signature); + EXPECT_NE(signature, blinded_message); + EXPECT_NE(signature, message); + + QUICHE_EXPECT_OK(blinder->Verify(signature, message)); +} + +TEST_P(RsaBlinderWithPublicMetadataTest, + BlindSignUnblindWithEmptyPublicMetadataEnd2EndTest) { + const absl::string_view message = "Hello World!"; + const absl::string_view empty_public_metadata = ""; + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr blinder, + RsaBlinder::New(public_key_, empty_public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_message, + blinder->Blind(message)); + EXPECT_NE(blinded_message, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string blinded_signature, + TestSignWithPublicMetadata(blinded_message, empty_public_metadata, + *rsa_key_)); + EXPECT_NE(blinded_signature, blinded_message); + EXPECT_NE(blinded_signature, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + EXPECT_NE(signature, blinded_signature); + EXPECT_NE(signature, blinded_message); + EXPECT_NE(signature, message); + + QUICHE_EXPECT_OK(blinder->Verify(signature, message)); +} + +TEST_P(RsaBlinderWithPublicMetadataTest, WrongPublicMetadata) { + const absl::string_view message = "Hello World!"; + const absl::string_view public_metadata = "pubmd!"; + const absl::string_view public_metadata_2 = "pubmd2"; + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr blinder, + RsaBlinder::New(public_key_, public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_message, + blinder->Blind(message)); + EXPECT_NE(blinded_message, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string blinded_signature, + TestSignWithPublicMetadata(blinded_message, public_metadata_2, + *rsa_key_)); + EXPECT_NE(blinded_signature, blinded_message); + EXPECT_NE(blinded_signature, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + EXPECT_NE(signature, blinded_signature); + EXPECT_NE(signature, blinded_message); + EXPECT_NE(signature, message); + EXPECT_THAT( + blinder->Verify(signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +TEST_P(RsaBlinderWithPublicMetadataTest, NoPublicMetadataForSigning) { + const absl::string_view message = "Hello World!"; + const absl::string_view public_metadata = "pubmd!"; + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr blinder, + RsaBlinder::New(public_key_, public_metadata)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_message, + blinder->Blind(message)); + EXPECT_NE(blinded_message, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_signature, + TestSign(blinded_message, rsa_key_.get())); + EXPECT_NE(blinded_signature, blinded_message); + EXPECT_NE(blinded_signature, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + EXPECT_NE(signature, blinded_signature); + EXPECT_NE(signature, blinded_message); + EXPECT_NE(signature, message); + EXPECT_THAT( + blinder->Verify(signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +TEST_P(RsaBlinderWithPublicMetadataTest, NoPublicMetadataInBlinding) { + const absl::string_view message = "Hello World!"; + const absl::string_view public_metadata = "pubmd!"; + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::unique_ptr blinder, + RsaBlinder::New(public_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string blinded_message, + blinder->Blind(message)); + EXPECT_NE(blinded_message, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string blinded_signature, + TestSignWithPublicMetadata(blinded_message, public_metadata, *rsa_key_)); + EXPECT_NE(blinded_signature, blinded_message); + EXPECT_NE(blinded_signature, message); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(std::string signature, + blinder->Unblind(blinded_signature)); + EXPECT_NE(signature, blinded_signature); + EXPECT_NE(signature, blinded_message); + EXPECT_NE(signature, message); + EXPECT_THAT( + blinder->Verify(signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +INSTANTIATE_TEST_SUITE_P( + RsaBlinderWithPublicMetadataTest, RsaBlinderWithPublicMetadataTest, + testing::Values(&GetStrongRsaKeys2048, &GetAnotherStrongRsaKeys2048, + &GetStrongRsaKeys3072, &GetStrongRsaKeys4096)); + +} // namespace +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc new file mode 100644 index 000000000000..be883147a9bc --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.cc @@ -0,0 +1,97 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h" + +#include + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/bn.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { + +absl::StatusOr> RsaSsaPssVerifier::New( + const int salt_length, const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + const RSAPublicKey& public_key, + std::optional public_metadata) { + // Convert to OpenSSL RSA which will be used in the code paths for the + // standard RSA blind signature scheme. + // + // Moreover, it will also be passed as an argument to PSS related padding + // verification methods irrespective of whether RsaBlinder is being used as a + // part of the standard RSA blind signature scheme or the scheme with public + // metadata support. + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_public_key, + AnonymousTokensRSAPublicKeyToRSA(public_key)); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_modulus, + StringToBignum(public_key.n())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_e, + StringToBignum(public_key.e())); + + bssl::UniquePtr augmented_rsa_e = nullptr; + // If public metadata is supported, RsaSsaPssVerifier will compute a new + // public exponent using the public metadata. + // + // Empty string is a valid public metadata value. + if (public_metadata.has_value()) { + ANON_TOKENS_ASSIGN_OR_RETURN( + augmented_rsa_e, + ComputeFinalExponentUnderPublicMetadata( + *rsa_modulus.get(), *rsa_e.get(), *public_metadata)); + } else { + augmented_rsa_e = std::move(rsa_e); + } + return absl::WrapUnique( + new RsaSsaPssVerifier(salt_length, public_metadata, sig_hash, mgf1_hash, + std::move(rsa_public_key), std::move(rsa_modulus), + std::move(augmented_rsa_e))); +} + +RsaSsaPssVerifier::RsaSsaPssVerifier( + int salt_length, std::optional public_metadata, + const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + bssl::UniquePtr rsa_public_key, bssl::UniquePtr rsa_modulus, + bssl::UniquePtr augmented_rsa_e) + : salt_length_(salt_length), + public_metadata_(public_metadata), + sig_hash_(sig_hash), + mgf1_hash_(mgf1_hash), + rsa_public_key_(std::move(rsa_public_key)), + rsa_modulus_(std::move(rsa_modulus)), + augmented_rsa_e_(std::move(augmented_rsa_e)) {} + +absl::Status RsaSsaPssVerifier::Verify(absl::string_view unblind_token, + absl::string_view message) { + return RsaBlindSignatureVerify(salt_length_, sig_hash_, mgf1_hash_, + rsa_public_key_.get(), *rsa_modulus_.get(), + *augmented_rsa_e_.get(), unblind_token, + message, public_metadata_); +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h new file mode 100644 index 000000000000..199717fad4ec --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h @@ -0,0 +1,81 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_SSA_PSS_VERIFIER_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_SSA_PSS_VERIFIER_H_ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/common/platform/api/quiche_export.h" +// copybara:strip_begin(internal comment) +// The QUICHE_EXPORT annotation is necessary for some classes and functions +// to link correctly on Windows. Please do not remove them! +// copybara:strip_end + +namespace private_membership { +namespace anonymous_tokens { + +// RsaSsaPssVerifier is able to verify an unblinded token (signature) against an +// inputted message using a public key and other input parameters. +class QUICHE_EXPORT RsaSsaPssVerifier : public Verifier { + public: + // Passing of public_metadata is optional. If it is set to any value including + // an empty string, RsaSsaPssVerifier will assume that partially blind RSA + // signature protocol is being executed. + static absl::StatusOr> New( + int salt_length, const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + const RSAPublicKey& public_key, + std::optional public_metadata = std::nullopt); + + // Verifies the signature. + // + // Returns OkStatus() on successful verification. Otherwise returns an error. + absl::Status Verify(absl::string_view unblind_token, + absl::string_view message) override; + + private: + // Use `New` to construct + RsaSsaPssVerifier(int salt_length, + std::optional public_metadata, + const EVP_MD* sig_hash, const EVP_MD* mgf1_hash, + bssl::UniquePtr rsa_public_key, + bssl::UniquePtr rsa_modulus, + bssl::UniquePtr augmented_rsa_e); + + const int salt_length_; + std::optional public_metadata_; + const EVP_MD* sig_hash_; // Owned by BoringSSL. + const EVP_MD* mgf1_hash_; // Owned by BoringSSL. + + const bssl::UniquePtr rsa_public_key_; + // Storing RSA modulus separately for helping with BN computations. + const bssl::UniquePtr rsa_modulus_; + // If public metadata is not supported, augmented_rsa_e_ will be equal to + // public exponent e in rsa_public_key_. + const bssl::UniquePtr augmented_rsa_e_; +}; + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_RSA_SSA_PSS_VERIFIER_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc new file mode 100644 index 000000000000..1d19310dc389 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier_test.cc @@ -0,0 +1,287 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/rsa_ssa_pss_verifier.h" + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { +namespace { + +// TODO(b/259581423): Add tests incorporating blinder and signer. +// TODO(b/275956922): Consolidate all tests that use IETF test vectors into one +// E2E test. +TEST(RsaSsaPssVerifier, SuccessfulVerification) { + const IetfStandardRsaBlindSignatureTestVector test_vec = + GetIetfStandardRsaBlindSignatureTestVector(); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const auto test_keys, + GetIetfStandardRsaBlindSignatureTestKeys()); + const EVP_MD *sig_hash = EVP_sha384(); // Owned by BoringSSL + const EVP_MD *mgf1_hash = EVP_sha384(); // Owned by BoringSSL + const int salt_length = kSaltLengthInBytes48; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto verifier, RsaSsaPssVerifier::New(salt_length, sig_hash, + mgf1_hash, test_keys.first)); + QUICHE_EXPECT_OK(verifier->Verify(test_vec.signature, test_vec.message)); +} + +TEST(RsaSsaPssVerifier, InvalidSignature) { + const IetfStandardRsaBlindSignatureTestVector test_vec = + GetIetfStandardRsaBlindSignatureTestVector(); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const auto test_keys, + GetIetfStandardRsaBlindSignatureTestKeys()); + const EVP_MD *sig_hash = EVP_sha384(); // Owned by BoringSSL + const EVP_MD *mgf1_hash = EVP_sha384(); // Owned by BoringSSL + const int salt_length = kSaltLengthInBytes48; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto verifier, RsaSsaPssVerifier::New(salt_length, sig_hash, + mgf1_hash, test_keys.first)); + // corrupt signature + std::string wrong_sig = test_vec.signature; + wrong_sig.replace(10, 1, "x"); + + EXPECT_THAT( + verifier->Verify(wrong_sig, test_vec.message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("verification failed"))); +} + +TEST(RsaSsaPssVerifier, InvalidVerificationKey) { + const IetfStandardRsaBlindSignatureTestVector test_vec = + GetIetfStandardRsaBlindSignatureTestVector(); + const EVP_MD *sig_hash = EVP_sha384(); // Owned by BoringSSL + const EVP_MD *mgf1_hash = EVP_sha384(); // Owned by BoringSSL + const int salt_length = kSaltLengthInBytes48; + // wrong key + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto new_keys_pair, GetStandardRsaKeyPair()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto verifier, + RsaSsaPssVerifier::New(salt_length, sig_hash, mgf1_hash, + new_keys_pair.first)); + + EXPECT_THAT( + verifier->Verify(test_vec.signature, test_vec.message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("verification failed"))); +} + +TEST(RsaSsaPssVerifierTestWithPublicMetadata, + EmptyMessageStandardVerificationSuccess) { + absl::string_view message = ""; + const EVP_MD *sig_hash = EVP_sha384(); // Owned by BoringSSL + const EVP_MD *mgf1_hash = EVP_sha384(); // Owned by BoringSSL + const int salt_length = kSaltLengthInBytes48; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(const auto test_key, + GetStandardRsaKeyPair()); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto private_key, AnonymousTokensRSAPrivateKeyToRSA(test_key.second)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(message, test_key.first, sig_hash, mgf1_hash, + salt_length)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string potentially_insecure_signature, + TestSign(encoded_message, private_key.get())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, + RsaSsaPssVerifier::New(salt_length, sig_hash, mgf1_hash, test_key.first)); + QUICHE_EXPECT_OK(verifier->Verify(potentially_insecure_signature, message)); +} + +// TODO(b/275956922): Consolidate all tests that use IETF test vectors into one +// E2E test. +TEST(RsaSsaPssVerifierTestWithPublicMetadata, + IetfRsaBlindSignaturesWithPublicMetadataTestVectorsSuccess) { + auto test_vectors = GetIetfRsaBlindSignatureWithPublicMetadataTestVectors(); + const EVP_MD *sig_hash = EVP_sha384(); // Owned by BoringSSL + const EVP_MD *mgf1_hash = EVP_sha384(); // Owned by BoringSSL + const int salt_length = kSaltLengthInBytes48; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + const auto test_key, + GetIetfRsaBlindSignatureWithPublicMetadataTestKeys()); + for (const auto &test_vector : test_vectors) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, + RsaSsaPssVerifier::New(salt_length, sig_hash, mgf1_hash, test_key.first, + test_vector.public_metadata)); + QUICHE_EXPECT_OK(verifier->Verify( + test_vector.signature, + MaskMessageConcat(test_vector.message_mask, test_vector.message))); + } +} + +using CreateTestKeyPairFunction = + absl::StatusOr>(); + +class RsaSsaPssVerifierTestWithPublicMetadata + : public ::testing::TestWithParam { + protected: + void SetUp() override { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(auto keys_pair, (*GetParam())()); + public_key_ = std::move(keys_pair.first); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + private_key_, AnonymousTokensRSAPrivateKeyToRSA(keys_pair.second)); + // NOTE: using recommended RsaSsaPssParams + sig_hash_ = EVP_sha384(); + mgf1_hash_ = EVP_sha384(); + salt_length_ = kSaltLengthInBytes48; + } + + RSAPublicKey public_key_; + bssl::UniquePtr private_key_; + const EVP_MD *sig_hash_; // Owned by BoringSSL. + const EVP_MD *mgf1_hash_; // Owned by BoringSSL. + int salt_length_; +}; + +// This test only tests whether the implemented verfier 'verifies' properly +// under some public metadata. The outline of method calls in this test should +// not be assumed a secure signature scheme (and used in other places) as the +// security has not been proven/analyzed. +TEST_P(RsaSsaPssVerifierTestWithPublicMetadata, + VerifierWorksWithPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string potentially_insecure_signature, + TestSignWithPublicMetadata(encoded_message, public_metadata, + *private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, + public_key_, public_metadata)); + QUICHE_EXPECT_OK(verifier->Verify(potentially_insecure_signature, message)); +} + +TEST_P(RsaSsaPssVerifierTestWithPublicMetadata, + VerifierFailsToVerifyWithWrongPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + absl::string_view public_metadata_2 = "pubmd2"; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string potentially_insecure_signature, + TestSignWithPublicMetadata(encoded_message, public_metadata, + *private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, + public_key_, public_metadata_2)); + EXPECT_THAT( + verifier->Verify(potentially_insecure_signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +TEST_P(RsaSsaPssVerifierTestWithPublicMetadata, + VerifierFailsToVerifyWithEmptyPublicMetadata) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + absl::string_view empty_public_metadata = ""; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string potentially_insecure_signature, + TestSignWithPublicMetadata(encoded_message, public_metadata, + *private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, + RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, public_key_, + empty_public_metadata)); + EXPECT_THAT( + verifier->Verify(potentially_insecure_signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +TEST_P(RsaSsaPssVerifierTestWithPublicMetadata, + VerifierFailsToVerifyWithoutPublicMetadataSupport) { + absl::string_view message = "Hello World!"; + absl::string_view public_metadata = "pubmd!"; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string potentially_insecure_signature, + TestSignWithPublicMetadata(encoded_message, public_metadata, + *private_key_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, + RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, public_key_)); + EXPECT_THAT( + verifier->Verify(potentially_insecure_signature, message), + quiche::test::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("verification failed"))); +} + +TEST_P(RsaSsaPssVerifierTestWithPublicMetadata, + EmptyMessageEmptyPublicMetadataVerificationSuccess) { + absl::string_view message = ""; + absl::string_view public_metadata = ""; + std::string augmented_message = + EncodeMessagePublicMetadata(message, public_metadata); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string encoded_message, + EncodeMessageForTests(augmented_message, public_key_, sig_hash_, + mgf1_hash_, salt_length_)); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + std::string potentially_insecure_signature, + TestSignWithPublicMetadata(encoded_message, public_metadata, + *private_key_.get())); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + auto verifier, RsaSsaPssVerifier::New(salt_length_, sig_hash_, mgf1_hash_, + public_key_, public_metadata)); + QUICHE_EXPECT_OK(verifier->Verify(potentially_insecure_signature, message)); +} + +INSTANTIATE_TEST_SUITE_P(RsaSsaPssVerifierTestWithPublicMetadata, + RsaSsaPssVerifierTestWithPublicMetadata, + ::testing::Values(&GetStrongRsaKeys2048, + &GetAnotherStrongRsaKeys2048, + &GetStrongRsaKeys3072, + &GetStrongRsaKeys4096)); + +} // namespace +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h new file mode 100644 index 000000000000..14e3d2f2caba --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/verifier.h @@ -0,0 +1,29 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_VERIFIER_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_VERIFIER_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +class Verifier { + public: + virtual absl::Status Verify(absl::string_view signature, + absl::string_view message) = 0; + + virtual ~Verifier() = default; +}; + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_CRYPTO_VERIFIER_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc new file mode 100644 index 000000000000..2b7a7598fd9b --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.cc @@ -0,0 +1,64 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h" + +namespace private_membership { +namespace anonymous_tokens { + +absl::StatusOr ParseUseCase( + absl::string_view use_case) { + AnonymousTokensUseCase parsed_use_case; + if (!AnonymousTokensUseCase_Parse(std::string(use_case), &parsed_use_case) || + parsed_use_case == ANONYMOUS_TOKENS_USE_CASE_UNDEFINED) { + return absl::InvalidArgumentError( + "Invalid / undefined use case cannot be parsed."); + } + return parsed_use_case; +} + +absl::StatusOr TimeFromProto( + const quiche::protobuf::Timestamp& proto) { + const auto sec = proto.seconds(); + const auto ns = proto.nanos(); + // sec must be [0001-01-01T00:00:00Z, 9999-12-31T23:59:59.999999999Z] + if (sec < -62135596800 || sec > 253402300799) { + return absl::InvalidArgumentError(absl::StrCat("seconds=", sec)); + } + if (ns < 0 || ns > 999999999) { + return absl::InvalidArgumentError(absl::StrCat("nanos=", ns)); + } + return absl::FromUnixSeconds(proto.seconds()) + + absl::Nanoseconds(proto.nanos()); +} + +absl::StatusOr TimeToProto(absl::Time time) { + quiche::protobuf::Timestamp proto; + const int64_t seconds = absl::ToUnixSeconds(time); + proto.set_seconds(seconds); + proto.set_nanos((time - absl::FromUnixSeconds(seconds)) / + absl::Nanoseconds(1)); + // seconds must be [0001-01-01T00:00:00Z, 9999-12-31T23:59:59.999999999Z] + if (seconds < -62135596800 || seconds > 253402300799) { + return absl::InvalidArgumentError(absl::StrCat("seconds=", seconds)); + } + const int64_t ns = proto.nanos(); + if (ns < 0 || ns > 999999999) { + return absl::InvalidArgumentError(absl::StrCat("nanos=", ns)); + } + return proto; +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h new file mode 100644 index 000000000000..bc9b8df96cc1 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h @@ -0,0 +1,53 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_SHARED_PROTO_UTILS_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_SHARED_PROTO_UTILS_H_ + +#include "quiche/blind_sign_auth/proto/timestamp.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/common/platform/api/quiche_export.h" +// copybara:strip_begin(internal comment) +// The QUICHE_EXPORT annotation is necessary for some classes and functions +// to link correctly on Windows. Please do not remove them! +// copybara:strip_end + +namespace private_membership { +namespace anonymous_tokens { + +// Returns AnonymousTokensUseCase parsed from a string_view. +absl::StatusOr QUICHE_EXPORT ParseUseCase( + absl::string_view use_case); + +// Takes in quiche::protobuf::Timestamp and converts it to absl::Time. +// +// Timestamp is defined here: +// https://developers.google.com/protocol-buffers/docs/reference/quiche.protobuf#timestamp +absl::StatusOr QUICHE_EXPORT TimeFromProto( + const quiche::protobuf::Timestamp& proto); + +// Takes in absl::Time and converts it to quiche::protobuf::Timestamp. +// +// Timestamp is defined here: +// https://developers.google.com/protocol-buffers/docs/reference/quiche.protobuf#timestamp +absl::StatusOr QUICHE_EXPORT TimeToProto( + absl::Time time); + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_SHARED_PROTO_UTILS_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc new file mode 100644 index 000000000000..5c7d84533d7c --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils_test.cc @@ -0,0 +1,93 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h" + +#include "quiche/blind_sign_auth/proto/timestamp.pb.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" + +namespace private_membership { +namespace anonymous_tokens { +namespace { + +TEST(ProtoUtilsTest, EmptyUseCase) { + EXPECT_THAT(ParseUseCase("").status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ProtoUtilsTest, InvalidUseCase) { + EXPECT_THAT(ParseUseCase("NOT_A_USE_CASE").status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ProtoUtilsTest, UndefinedUseCase) { + EXPECT_THAT( + ParseUseCase("ANONYMOUS_TOKENS_USE_CASE_UNDEFINED").status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ProtoUtilsTest, ValidUseCase) { + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(AnonymousTokensUseCase use_case, + ParseUseCase("TEST_USE_CASE")); + EXPECT_EQ(use_case, AnonymousTokensUseCase::TEST_USE_CASE); +} + +TEST(ProtoUtilsTest, TimeFromProtoGood) { + quiche::protobuf::Timestamp timestamp; + timestamp.set_seconds(1234567890); + timestamp.set_nanos(12345); + ANON_TOKENS_ASSERT_OK_AND_ASSIGN(absl::Time time, TimeFromProto(timestamp)); + ASSERT_EQ(time, absl::FromUnixNanos(1234567890000012345)); +} + +TEST(ProtoUtilsTest, TimeFromProtoBad) { + quiche::protobuf::Timestamp proto; + proto.set_nanos(-1); + EXPECT_THAT(TimeFromProto(proto).status().code(), + absl::StatusCode::kInvalidArgument); + + proto.set_nanos(0); + proto.set_seconds(253402300800); + EXPECT_THAT(TimeFromProto(proto).status().code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ProtoUtilsTest, TimeToProtoGood) { + quiche::protobuf::Timestamp proto; + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + proto, TimeToProto(absl::FromUnixSeconds(1596762373))); + EXPECT_EQ(proto.seconds(), 1596762373); + EXPECT_EQ(proto.nanos(), 0); + + ANON_TOKENS_ASSERT_OK_AND_ASSIGN( + proto, TimeToProto(absl::FromUnixMillis(1596762373123L))); + EXPECT_EQ(proto.seconds(), 1596762373); + EXPECT_EQ(proto.nanos(), 123000000); +} + +TEST(ProtoUtilsTest, TimeToProtoBad) { + absl::StatusOr proto; + proto = TimeToProto(absl::FromUnixSeconds(253402300800)); + EXPECT_THAT(proto.status().code(), absl::StatusCode::kInvalidArgument); +} + +} // namespace +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h new file mode 100644 index 000000000000..bcdddf67e193 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_SHARED_STATUS_UTILS_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_SHARED_STATUS_UTILS_H_ + +#include "absl/base/optimization.h" +#include "absl/status/status.h" + +namespace private_membership { +namespace anonymous_tokens { + +#define _ANON_TOKENS_STATUS_MACROS_CONCAT_NAME(x, y) \ + _ANON_TOKENS_STATUS_MACROS_CONCAT_IMPL(x, y) +#define _ANON_TOKENS_STATUS_MACROS_CONCAT_IMPL(x, y) x##y + +#define ANON_TOKENS_ASSIGN_OR_RETURN(lhs, rexpr) \ + _ANON_TOKENS_ASSIGN_OR_RETURN_IMPL( \ + _ANON_TOKENS_STATUS_MACROS_CONCAT_NAME(_status_or_val, __LINE__), lhs, \ + rexpr) + +#define _ANON_TOKENS_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + return statusor.status(); \ + } \ + lhs = *std::move(statusor) + +#define ANON_TOKENS_RETURN_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if (ABSL_PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_SHARED_STATUS_UTILS_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc b/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc new file mode 100644 index 000000000000..63fd304081d9 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.cc @@ -0,0 +1,790 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/common/platform/api/quiche_file_utils.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "openssl/rsa.h" + +namespace private_membership { +namespace anonymous_tokens { + +namespace { + +absl::StatusOr ReadFileToString(absl::string_view path) { + std::ifstream file(std::string(path), std::ios::binary); + if (!file.is_open()) { + return absl::InternalError("Reading file failed."); + } + std::ostringstream ss(std::ios::binary); + ss << file.rdbuf(); + return ss.str(); +} + +absl::StatusOr> ParseRsaKeysFromFile( + absl::string_view path) { + ANON_TOKENS_ASSIGN_OR_RETURN(std::string binary_proto, + ReadFileToString(path)); + RSAPrivateKey private_key; + if (!private_key.ParseFromString(binary_proto)) { + return absl::InternalError("Parsing binary proto failed."); + } + RSAPublicKey public_key; + public_key.set_n(private_key.n()); + public_key.set_e(private_key.e()); + return std::make_pair(std::move(public_key), std::move(private_key)); +} + +absl::StatusOr> GenerateRSAKey(int modulus_bit_size, + const BIGNUM& e) { + bssl::UniquePtr rsa(RSA_new()); + if (!rsa.get()) { + return absl::InternalError( + absl::StrCat("RSA_new failed: ", GetSslErrors())); + } + if (RSA_generate_key_ex(rsa.get(), modulus_bit_size, &e, + /*cb=*/nullptr) != kBsslSuccess) { + return absl::InternalError( + absl::StrCat("Error generating private key: ", GetSslErrors())); + } + return rsa; +} + +absl::StatusOr> PopulateTestVectorKeys( + const std::string& n, const std::string& e, const std::string& d, + const std::string& p, const std::string& q) { + RSAPublicKey public_key; + RSAPrivateKey private_key; + + public_key.set_n(n); + public_key.set_e(e); + + private_key.set_n(n); + private_key.set_e(e); + private_key.set_d(d); + private_key.set_p(p); + private_key.set_q(q); + + // Computing CRT parameters + ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr bn_ctx, GetAndStartBigNumCtx()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr dp_bn, NewBigNum()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr dq_bn, NewBigNum()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr crt_bn, NewBigNum()); + + // p - 1 + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr pm1, StringToBignum(p)); + BN_sub_word(pm1.get(), 1); + // q - 1 + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr qm1, StringToBignum(q)); + BN_sub_word(qm1.get(), 1); + // d mod p-1 + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr d_bn, StringToBignum(d)); + BN_mod(dp_bn.get(), d_bn.get(), pm1.get(), bn_ctx.get()); + // d mod q-1 + BN_mod(dq_bn.get(), d_bn.get(), qm1.get(), bn_ctx.get()); + // crt q^(-1) mod p + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr q_bn, StringToBignum(q)); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr p_bn, StringToBignum(p)); + BN_mod_inverse(crt_bn.get(), q_bn.get(), p_bn.get(), bn_ctx.get()); + + // Populating crt params in private key + ANON_TOKENS_ASSIGN_OR_RETURN( + std::string dp_str, BignumToString(*dp_bn, BN_num_bytes(dp_bn.get()))); + ANON_TOKENS_ASSIGN_OR_RETURN( + std::string dq_str, BignumToString(*dq_bn, BN_num_bytes(dq_bn.get()))); + ANON_TOKENS_ASSIGN_OR_RETURN( + std::string crt_str, BignumToString(*crt_bn, BN_num_bytes(crt_bn.get()))); + private_key.set_dp(dp_str); + private_key.set_dq(dq_str); + private_key.set_crt(crt_str); + + return std::make_pair(std::move(public_key), std::move(private_key)); +} + +} // namespace + +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateTestKey(int key_size, HashType sig_hash, MaskGenFunction mfg1_hash, + int salt_length, MessageMaskType message_mask_type, + int message_mask_size) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_f4, NewBigNum()); + BN_set_u64(rsa_f4.get(), RSA_F4); + + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_key, + GenerateRSAKey(key_size * 8, *rsa_f4)); + + RSAPublicKey rsa_public_key; + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_public_key.mutable_n(), + BignumToString(*RSA_get0_n(rsa_key.get()), key_size)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_public_key.mutable_e(), + BignumToString(*RSA_get0_e(rsa_key.get()), key_size)); + + RSABlindSignaturePublicKey public_key; + public_key.set_serialized_public_key(rsa_public_key.SerializeAsString()); + public_key.set_sig_hash_type(sig_hash); + public_key.set_mask_gen_function(mfg1_hash); + public_key.set_salt_length(salt_length); + public_key.set_key_size(key_size); + public_key.set_message_mask_type(message_mask_type); + public_key.set_message_mask_size(message_mask_size); + + return std::make_pair(std::move(rsa_key), std::move(public_key)); +} + +absl::StatusOr TestSign(const absl::string_view blinded_data, + RSA* rsa_key) { + if (blinded_data.empty()) { + return absl::InvalidArgumentError("blinded_data string is empty."); + } + const size_t mod_size = RSA_size(rsa_key); + if (blinded_data.size() != mod_size) { + return absl::InternalError(absl::StrCat( + "Expected blind data size = ", mod_size, + " actual blind data size = ", blinded_data.size(), " bytes.")); + } + // Compute a raw RSA signature. + std::string signature(mod_size, 0); + size_t out_len; + if (RSA_sign_raw(/*rsa=*/rsa_key, /*out_len=*/&out_len, + /*out=*/reinterpret_cast(&signature[0]), + /*max_out=*/mod_size, + /*in=*/reinterpret_cast(&blinded_data[0]), + /*in_len=*/mod_size, + /*padding=*/RSA_NO_PADDING) != kBsslSuccess) { + return absl::InternalError( + "RSA_sign_raw failed when called from RsaBlindSigner::Sign"); + } + if (out_len != mod_size && out_len == signature.size()) { + return absl::InternalError(absl::StrCat( + "Expected value of out_len = ", mod_size, + " bytes, actual value of out_len and signature.size() = ", out_len, + " and ", signature.size(), " bytes.")); + } + return signature; +} + +absl::StatusOr TestSignWithPublicMetadata( + const absl::string_view blinded_data, absl::string_view public_metadata, + const RSA& rsa_key) { + if (blinded_data.empty()) { + return absl::InvalidArgumentError("blinded_data string is empty."); + } else if (blinded_data.size() != RSA_size(&rsa_key)) { + return absl::InternalError(absl::StrCat( + "Expected blind data size = ", RSA_size(&rsa_key), + " actual blind data size = ", blinded_data.size(), " bytes.")); + } + ANON_TOKENS_ASSIGN_OR_RETURN( + bssl::UniquePtr new_e, + ComputeFinalExponentUnderPublicMetadata( + *RSA_get0_n(&rsa_key), *RSA_get0_e(&rsa_key), public_metadata)); + // Compute phi(p) = p-1 + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr phi_p, NewBigNum()); + if (BN_sub(phi_p.get(), RSA_get0_p(&rsa_key), BN_value_one()) != 1) { + return absl::InternalError( + absl::StrCat("Unable to compute phi(p): ", GetSslErrors())); + } + // Compute phi(q) = q-1 + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr phi_q, NewBigNum()); + if (BN_sub(phi_q.get(), RSA_get0_q(&rsa_key), BN_value_one()) != 1) { + return absl::InternalError( + absl::StrCat("Unable to compute phi(q): ", GetSslErrors())); + } + // Compute phi(n) = phi(p)*phi(q) + ANON_TOKENS_ASSIGN_OR_RETURN(auto ctx, GetAndStartBigNumCtx()); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr phi_n, NewBigNum()); + if (BN_mul(phi_n.get(), phi_p.get(), phi_q.get(), ctx.get()) != 1) { + return absl::InternalError( + absl::StrCat("Unable to compute phi(n): ", GetSslErrors())); + } + // Compute lcm(phi(p), phi(q)). + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr lcm, NewBigNum()); + if (BN_rshift1(lcm.get(), phi_n.get()) != 1) { + return absl::InternalError(absl::StrCat( + "Could not compute LCM(phi(p), phi(q)): ", GetSslErrors())); + } + // Compute the new private exponent new_d + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr new_d, NewBigNum()); + if (!BN_mod_inverse(new_d.get(), new_e.get(), lcm.get(), ctx.get())) { + return absl::InternalError( + absl::StrCat("Could not compute private exponent d: ", GetSslErrors())); + } + + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr input_bn, + StringToBignum(blinded_data)); + if (BN_ucmp(input_bn.get(), RSA_get0_n(&rsa_key)) >= 0) { + return absl::InvalidArgumentError( + "RsaSign input size too large for modulus size"); + } + + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr result, NewBigNum()); + if (!BN_mod_exp(result.get(), input_bn.get(), new_d.get(), + RSA_get0_n(&rsa_key), ctx.get())) { + return absl::InternalError( + "BN_mod_exp failed in TestSignWithPublicMetadata"); + } + + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr vrfy, NewBigNum()); + if (vrfy == nullptr || + !BN_mod_exp(vrfy.get(), result.get(), new_e.get(), RSA_get0_n(&rsa_key), + ctx.get()) || + BN_cmp(vrfy.get(), input_bn.get()) != 0) { + return absl::InternalError("Signature verification failed in RsaSign"); + } + + return BignumToString(*result, BN_num_bytes(RSA_get0_n(&rsa_key))); +} + +absl::StatusOr EncodeMessageForTests(absl::string_view message, + RSAPublicKey public_key, + const EVP_MD* sig_hasher, + const EVP_MD* mgf1_hasher, + int32_t salt_length) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_modulus, + StringToBignum(public_key.n())); + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr e, + StringToBignum(public_key.e())); + // Convert to OpenSSL RSA. + bssl::UniquePtr rsa_public_key(RSA_new()); + if (!rsa_public_key.get()) { + return absl::InternalError( + absl::StrCat("RSA_new failed: ", GetSslErrors())); + } else if (RSA_set0_key(rsa_public_key.get(), rsa_modulus.release(), + e.release(), nullptr) != kBsslSuccess) { + return absl::InternalError( + absl::StrCat("RSA_set0_key failed: ", GetSslErrors())); + } + + const int padded_len = RSA_size(rsa_public_key.get()); + std::vector padded(padded_len); + ANON_TOKENS_ASSIGN_OR_RETURN(std::string digest, + ComputeHash(message, *sig_hasher)); + if (RSA_padding_add_PKCS1_PSS_mgf1( + /*rsa=*/rsa_public_key.get(), /*EM=*/padded.data(), + /*mHash=*/reinterpret_cast(&digest[0]), /*Hash=*/sig_hasher, + /*mgf1Hash=*/mgf1_hasher, + /*sLen=*/salt_length) != kBsslSuccess) { + return absl::InternalError( + "RSA_padding_add_PKCS1_PSS_mgf1 failed when called from " + "testing_utils"); + } + std::string encoded_message(padded.begin(), padded.end()); + return encoded_message; +} + +absl::StatusOr> GetStandardRsaKeyPair( + int modulus_size_in_bytes) { + ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr rsa_f4, NewBigNum()); + BN_set_u64(rsa_f4.get(), RSA_F4); + ANON_TOKENS_ASSIGN_OR_RETURN( + bssl::UniquePtr rsa_key, + GenerateRSAKey(modulus_size_in_bytes * 8, *rsa_f4)); + + RSAPublicKey rsa_public_key; + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_public_key.mutable_n(), + BignumToString(*RSA_get0_n(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_public_key.mutable_e(), + BignumToString(*RSA_get0_e(rsa_key.get()), modulus_size_in_bytes)); + + RSAPrivateKey rsa_private_key; + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_n(), + BignumToString(*RSA_get0_n(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_e(), + BignumToString(*RSA_get0_e(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_d(), + BignumToString(*RSA_get0_d(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_p(), + BignumToString(*RSA_get0_p(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_q(), + BignumToString(*RSA_get0_q(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_dp(), + BignumToString(*RSA_get0_dmp1(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_dq(), + BignumToString(*RSA_get0_dmq1(rsa_key.get()), modulus_size_in_bytes)); + ANON_TOKENS_ASSIGN_OR_RETURN( + *rsa_private_key.mutable_crt(), + BignumToString(*RSA_get0_iqmp(rsa_key.get()), modulus_size_in_bytes)); + + return std::make_pair(std::move(rsa_public_key), std::move(rsa_private_key)); +} + +absl::StatusOr> GetStrongRsaKeys2048() { + std::string path = absl::StrCat(quiche::test::QuicheGetCommonSourcePath(), + "/anonymous_tokens/testdata/strong_rsa_modulus2048_example.binarypb"); + ANON_TOKENS_ASSIGN_OR_RETURN(auto key_pair, ParseRsaKeysFromFile(path)); + return std::make_pair(std::move(key_pair.first), std::move(key_pair.second)); +} + +absl::StatusOr> +GetAnotherStrongRsaKeys2048() { + std::string path = absl::StrCat(quiche::test::QuicheGetCommonSourcePath(), + "/anonymous_tokens/testdata/strong_rsa_modulus2048_example_2.binarypb"); + ANON_TOKENS_ASSIGN_OR_RETURN(auto key_pair, ParseRsaKeysFromFile(path)); + return std::make_pair(std::move(key_pair.first), std::move(key_pair.second)); +} + +absl::StatusOr> GetStrongRsaKeys3072() { + std::string path = absl::StrCat(quiche::test::QuicheGetCommonSourcePath(), + "/anonymous_tokens/testdata/strong_rsa_modulus3072_example.binarypb"); + ANON_TOKENS_ASSIGN_OR_RETURN(auto key_pair, ParseRsaKeysFromFile(path)); + return std::make_pair(std::move(key_pair.first), std::move(key_pair.second)); +} + +absl::StatusOr> GetStrongRsaKeys4096() { + std::string path = absl::StrCat(quiche::test::QuicheGetCommonSourcePath(), + "/anonymous_tokens/testdata/strong_rsa_modulus4096_example.binarypb"); + ANON_TOKENS_ASSIGN_OR_RETURN(auto key_pair, ParseRsaKeysFromFile(path)); + return std::make_pair(std::move(key_pair.first), std::move(key_pair.second)); +} + +IetfStandardRsaBlindSignatureTestVector +GetIetfStandardRsaBlindSignatureTestVector() { + IetfStandardRsaBlindSignatureTestVector test_vector = { + // n + absl::HexStringToBytes( + "aec4d69addc70b990ea66a5e70603b6fee27aafebd08f2d94cbe1250c556e047a9" + "28d635c3f45ee9b66d1bc628a03bac9b7c3f416fe20dabea8f3d7b4bbf7f963be3" + "35d2328d67e6c13ee4a8f955e05a3283720d3e1f139c38e43e0338ad058a9495c5" + "3377fc35be64d208f89b4aa721bf7f7d3fef837be2a80e0f8adf0bcd1eec5bb040" + "443a2b2792fdca522a7472aed74f31a1ebe1eebc1f408660a0543dfe2a850f106a" + "617ec6685573702eaaa21a5640a5dcaf9b74e397fa3af18a2f1b7c03ba91a63361" + "58de420d63188ee143866ee415735d155b7c2d854d795b7bc236cffd71542df342" + "34221a0413e142d8c61355cc44d45bda94204974557ac2704cd8b593f035a5724b" + "1adf442e78c542cd4414fce6f1298182fb6d8e53cef1adfd2e90e1e4deec52999b" + "dc6c29144e8d52a125232c8c6d75c706ea3cc06841c7bda33568c63a6c03817f72" + "2b50fcf898237d788a4400869e44d90a3020923dc646388abcc914315215fcd1ba" + "e11b1c751fd52443aac8f601087d8d42737c18a3fa11ecd4131ecae017ae0a14ac" + "fc4ef85b83c19fed33cfd1cd629da2c4c09e222b398e18d822f77bb378dea3cb36" + "0b605e5aa58b20edc29d000a66bd177c682a17e7eb12a63ef7c2e4183e0d898f3d" + "6bf567ba8ae84f84f1d23bf8b8e261c3729e2fa6d07b832e07cddd1d14f55325c6" + "f924267957121902dc19b3b32948bdead5"), + // e + absl::HexStringToBytes("010001"), + // d + absl::HexStringToBytes( + "0d43242aefe1fb2c13fbc66e20b678c4336d20b1808c558b6e62ad16a287077180b1" + "77e1f01b12f9c6cd6c52630257ccef26a45135a990928773f3bd2fc01a313f1dac97" + "a51cec71cb1fd7efc7adffdeb05f1fb04812c924ed7f4a8269925dad88bd7dcfbc4e" + "f01020ebfc60cb3e04c54f981fdbd273e69a8a58b8ceb7c2d83fbcbd6f784d052201" + "b88a9848186f2a45c0d2826870733e6fd9aa46983e0a6e82e35ca20a439c5ee7b502" + "a9062e1066493bdadf8b49eb30d9558ed85abc7afb29b3c9bc644199654a4676681a" + "f4babcea4e6f71fe4565c9c1b85d9985b84ec1abf1a820a9bbebee0df1398aae2c85" + "ab580a9f13e7743afd3108eb32100b870648fa6bc17e8abac4d3c99246b1f0ea9f7f" + "93a5dd5458c56d9f3f81ff2216b3c3680a13591673c43194d8e6fc93fc1e37ce2986" + "bd628ac48088bc723d8fbe293861ca7a9f4a73e9fa63b1b6d0074f5dea2a624c5249" + "ff3ad811b6255b299d6bc5451ba7477f19c5a0db690c3e6476398b1483d10314afd3" + "8bbaf6e2fbdbcd62c3ca9797a420ca6034ec0a83360a3ee2adf4b9d4ba29731d131b" + "099a38d6a23cc463db754603211260e99d19affc902c915d7854554aabf608e3ac52" + "c19b8aa26ae042249b17b2d29669b5c859103ee53ef9bdc73ba3c6b537d5c34b6d8f" + "034671d7f3a8a6966cc4543df223565343154140fd7391c7e7be03e241f4ecfeb877" + "a051"), + // p + absl::HexStringToBytes( + "e1f4d7a34802e27c7392a3cea32a262a34dc3691bd87f3f310dc75673488930559c1" + "20fd0410194fb8a0da55bd0b81227e843fdca6692ae80e5a5d414116d4803fca7d8c" + "30eaaae57e44a1816ebb5c5b0606c536246c7f11985d731684150b63c9a3ad9e41b0" + "4c0b5b27cb188a692c84696b742a80d3cd00ab891f2457443dadfeba6d6daf108602" + "be26d7071803c67105a5426838e6889d77e8474b29244cefaf418e381b312048b457" + "d73419213063c60ee7b0d81820165864fef93523c9635c22210956e53a8d96322493" + "ffc58d845368e2416e078e5bcb5d2fd68ae6acfa54f9627c42e84a9d3f2774017e32" + "ebca06308a12ecc290c7cd1156dcccfb2311"), + // q + absl::HexStringToBytes( + "c601a9caea66dc3835827b539db9df6f6f5ae77244692780cd334a006ab353c80642" + "6b60718c05245650821d39445d3ab591ed10a7339f15d83fe13f6a3dfb20b9452c6a" + "9b42eaa62a68c970df3cadb2139f804ad8223d56108dfde30ba7d367e9b0a7a80c4f" + "dba2fd9dde6661fc73fc2947569d2029f2870fc02d8325acf28c9afa19ecf962daa7" + "916e21afad09eb62fe9f1cf91b77dc879b7974b490d3ebd2e95426057f35d0a3c9f4" + "5f79ac727ab81a519a8b9285932d9b2e5ccd347e59f3f32ad9ca359115e7da008ab7" + "406707bd0e8e185a5ed8758b5ba266e8828f8d863ae133846304a2936ad7bc7c9803" + "879d2fc4a28e69291d73dbd799f8bc238385"), + // message + absl::HexStringToBytes("8f3dc6fb8c4a02f4d6352edf0907822c1210a" + "9b32f9bdda4c45a698c80023aa6b5" + "9f8cfec5fdbb36331372ebefedae7d"), + // salt + absl::HexStringToBytes("051722b35f458781397c3a671a7d3bd3096503940e4c4f1aa" + "a269d60300ce449555cd7340100df9d46944c5356825abf"), + // inv + absl::HexStringToBytes( + "80682c48982407b489d53d1261b19ec8627d02b8cda5336750b8cee332ae260de57b" + "02d72609c1e0e9f28e2040fc65b6f02d56dbd6aa9af8fde656f70495dfb723ba0117" + "3d4707a12fddac628ca29f3e32340bd8f7ddb557cf819f6b01e445ad96f874ba2355" + "84ee71f6581f62d4f43bf03f910f6510deb85e8ef06c7f09d9794a008be7ff2529f0" + "ebb69decef646387dc767b74939265fec0223aa6d84d2a8a1cc912d5ca25b4e144ab" + "8f6ba054b54910176d5737a2cff011da431bd5f2a0d2d66b9e70b39f4b050e45c0d9" + "c16f02deda9ddf2d00f3e4b01037d7029cd49c2d46a8e1fc2c0c17520af1f4b5e25b" + "a396afc4cd60c494a4c426448b35b49635b337cfb08e7c22a39b256dd032c00addda" + "fb51a627f99a0e1704170ac1f1912e49d9db10ec04c19c58f420212973e0cb329524" + "223a6aa56c7937c5dffdb5d966b6cd4cbc26f3201dd25c80960a1a111b32947bb789" + "73d269fac7f5186530930ed19f68507540eed9e1bab8b00f00d8ca09b3f099aae461" + "80e04e3584bd7ca054df18a1504b89d1d1675d0966c4ae1407be325cdf623cf13ff1" + "3e4a28b594d59e3eadbadf6136eee7a59d6a444c9eb4e2198e8a974f27a39eb63af2" + "c9af3870488b8adaad444674f512133ad80b9220e09158521614f1faadfe8505ef57" + "b7df6813048603f0dd04f4280177a11380fbfc861dbcbd7418d62155248dad5fdec0" + "991f"), + // encoded_message + absl::HexStringToBytes( + "6e0c464d9c2f9fbc147b43570fc4f238e0d0b38870b3addcf7a4217df912ccef17a7" + "f629aa850f63a063925f312d61d6437be954b45025e8282f9c0b1131bc8ff19a8a92" + "8d859b37113db1064f92a27f64761c181c1e1f9b251ae5a2f8a4047573b67a270584" + "e089beadcb13e7c82337797119712e9b849ff56e04385d144d3ca9d8d92bf78adb20" + "b5bbeb3685f17038ec6afade3ef354429c51c687b45a7018ee3a6966b3af15c9ba8f" + "40e6461ba0a17ef5a799672ad882bab02b518f9da7c1a962945c2e9b0f02f29b31b9" + "cdf3e633f9d9d2a22e96e1de28e25241ca7dd04147112f578973403e0f4fd8086596" + "5475d22294f065e17a1c4a201de93bd14223e6b1b999fd548f2f759f52db71964528" + "b6f15b9c2d7811f2a0a35d534b8216301c47f4f04f412cae142b48c4cdff78bc54df" + "690fd43142d750c671dd8e2e938e6a440b2f825b6dbb3e19f1d7a3c0150428a47948" + "037c322365b7fe6fe57ac88d8f80889e9ff38177bad8c8d8d98db42908b389cb5969" + "2a58ce275aa15acb032ca951b3e0a3404b7f33f655b7c7d83a2f8d1b6bbff49d5fce" + "df2e030e80881aa436db27a5c0dea13f32e7d460dbf01240c2320c2bb5b3225b1714" + "5c72d61d47c8f84d1e19417ebd8ce3638a82d395cc6f7050b6209d9283dc7b93fecc" + "04f3f9e7f566829ac41568ef799480c733c09759aa9734e2013d7640dc6151018ea9" + "02bc"), + // blinded_message + absl::HexStringToBytes( + "10c166c6a711e81c46f45b18e5873cc4f494f003180dd7f115585d871a2893025965" + "4fe28a54dab319cc5011204c8373b50a57b0fdc7a678bd74c523259dfe4fd5ea9f52" + "f170e19dfa332930ad1609fc8a00902d725cfe50685c95e5b2968c9a2828a21207fc" + "f393d15f849769e2af34ac4259d91dfd98c3a707c509e1af55647efaa31290ddf48e" + "0133b798562af5eabd327270ac2fb6c594734ce339a14ea4fe1b9a2f81c0bc230ca5" + "23bda17ff42a377266bc2778a274c0ae5ec5a8cbbe364fcf0d2403f7ee178d77ff28" + "b67a20c7ceec009182dbcaa9bc99b51ebbf13b7d542be337172c6474f2cd3561219f" + "e0dfa3fb207cff89632091ab841cf38d8aa88af6891539f263adb8eac6402c41b6eb" + "d72984e43666e537f5f5fe27b2b5aa114957e9a580730308a5f5a9c63a1eb599f093" + "ab401d0c6003a451931b6d124180305705845060ebba6b0036154fcef3e5e9f9e4b8" + "7e8f084542fd1dd67e7782a5585150181c01eb6d90cb95883837384a5b91dbb606f2" + "66059ecc51b5acbaa280e45cfd2eec8cc1cdb1b7211c8e14805ba683f9b78824b2eb" + "005bc8a7d7179a36c152cb87c8219e5569bba911bb32a1b923ca83de0e03fb10fba7" + "5d85c55907dda5a2606bf918b056c3808ba496a4d95532212040a5f44f37e1097f26" + "dc27b98a51837daa78f23e532156296b64352669c94a8a855acf30533d8e0594ace7" + "c442"), + // blinded_signature + absl::HexStringToBytes( + "364f6a40dbfbc3bbb257943337eeff791a0f290898a6791283bba581d9eac90a6376" + "a837241f5f73a78a5c6746e1306ba3adab6067c32ff69115734ce014d354e2f259d4" + "cbfb890244fd451a497fe6ecf9aa90d19a2d441162f7eaa7ce3fc4e89fd4e76b7ae5" + "85be2a2c0fd6fb246b8ac8d58bcb585634e30c9168a434786fe5e0b74bfe8187b47a" + "c091aa571ffea0a864cb906d0e28c77a00e8cd8f6aba4317a8cc7bf32ce566bd1ef8" + "0c64de041728abe087bee6cadd0b7062bde5ceef308a23bd1ccc154fd0c3a26110df" + "6193464fc0d24ee189aea8979d722170ba945fdcce9b1b4b63349980f3a92dc2e541" + "8c54d38a862916926b3f9ca270a8cf40dfb9772bfbdd9a3e0e0892369c18249211ba" + "857f35963d0e05d8da98f1aa0c6bba58f47487b8f663e395091275f82941830b050b" + "260e4767ce2fa903e75ff8970c98bfb3a08d6db91ab1746c86420ee2e909bf681cac" + "173697135983c3594b2def673736220452fde4ddec867d40ff42dd3da36c84e3e525" + "08b891a00f50b4f62d112edb3b6b6cc3dbd546ba10f36b03f06c0d82aeec3b25e127" + "af545fac28e1613a0517a6095ad18a98ab79f68801e05c175e15bae21f821e80c80a" + "b4fdec6fb34ca315e194502b8f3dcf7892b511aee45060e3994cd15e003861bc7220" + "a2babd7b40eda03382548a34a7110f9b1779bf3ef6011361611e6bc5c0dc851e1509" + "de1a"), + // signature + absl::HexStringToBytes( + "6fef8bf9bc182cd8cf7ce45c7dcf0e6f3e518ae48f06f3c670c649ac737a8b8119" + "a34d51641785be151a697ed7825fdfece82865123445eab03eb4bb91cecf4d6951" + "738495f8481151b62de869658573df4e50a95c17c31b52e154ae26a04067d5ecdc" + "1592c287550bb982a5bb9c30fd53a768cee6baabb3d483e9f1e2da954c7f4cf492" + "fe3944d2fe456c1ecaf0840369e33fb4010e6b44bb1d721840513524d8e9a3519f" + "40d1b81ae34fb7a31ee6b7ed641cb16c2ac999004c2191de0201457523f5a4700d" + "d649267d9286f5c1d193f1454c9f868a57816bf5ff76c838a2eeb616a3fc9976f6" + "5d4371deecfbab29362caebdff69c635fe5a2113da4d4d8c24f0b16a0584fa05e8" + "0e607c5d9a2f765f1f069f8d4da21f27c2a3b5c984b4ab24899bef46c6d9323df4" + "862fe51ce300fca40fb539c3bb7fe2dcc9409e425f2d3b95e70e9c49c5feb6ecc9" + "d43442c33d50003ee936845892fb8be475647da9a080f5bc7f8a716590b3745c22" + "09fe05b17992830ce15f32c7b22cde755c8a2fe50bd814a0434130b807dc1b7218" + "d4e85342d70695a5d7f29306f25623ad1e8aa08ef71b54b8ee447b5f64e73d09bd" + "d6c3b7ca224058d7c67cc7551e9241688ada12d859cb7646fbd3ed8b34312f3b49" + "d69802f0eaa11bc4211c2f7a29cd5c01ed01a39001c5856fab36228f5ee2f2e111" + "0811872fe7c865c42ed59029c706195d52"), + }; + return test_vector; +} + +std::vector +GetIetfRsaBlindSignatureWithPublicMetadataTestVectors() { + // n + std::string n = absl::HexStringToBytes( + "d6930820f71fe517bf3259d14d40209b02a5c0d3d61991c731dd7da39f8d69821552" + "e2318d6c9ad897e603887a476ea3162c1205da9ac96f02edf31df049bd55f142134c" + "17d4382a0e78e275345f165fbe8e49cdca6cf5c726c599dd39e09e75e0f330a33121" + "e73976e4facba9cfa001c28b7c96f8134f9981db6750b43a41710f51da4240fe0310" + "6c12acb1e7bb53d75ec7256da3fddd0718b89c365410fce61bc7c99b115fb4c3c318" + "081fa7e1b65a37774e8e50c96e8ce2b2cc6b3b367982366a2bf9924c4bafdb3ff5e7" + "22258ab705c76d43e5f1f121b984814e98ea2b2b8725cd9bc905c0bc3d75c2a8db70" + "a7153213c39ae371b2b5dc1dafcb19d6fae9"); + std::string e = absl::HexStringToBytes("010001"); + std::string d = absl::HexStringToBytes( + "4e21356983722aa1adedb084a483401c1127b781aac89eab103e1cfc52215494981d" + "18dd8028566d9d499469c25476358de23821c78a6ae43005e26b394e3051b5ca206a" + "a9968d68cae23b5affd9cbb4cb16d64ac7754b3cdba241b72ad6ddfc000facdb0f0d" + "d03abd4efcfee1730748fcc47b7621182ef8af2eeb7c985349f62ce96ab373d2689b" + "aeaea0e28ea7d45f2d605451920ca4ea1f0c08b0f1f6711eaa4b7cca66d58a6b916f" + "9985480f90aca97210685ac7b12d2ec3e30a1c7b97b65a18d38a93189258aa346bf2" + "bc572cd7e7359605c20221b8909d599ed9d38164c9c4abf396f897b9993c1e805e57" + "4d704649985b600fa0ced8e5427071d7049d"); + std::string p = absl::HexStringToBytes( + "dcd90af1be463632c0d5ea555256a20605af3db667475e190e3af12a34a3324c46a3" + "094062c59fb4b249e0ee6afba8bee14e0276d126c99f4784b23009bf6168ff628ac1" + "486e5ae8e23ce4d362889de4df63109cbd90ef93db5ae64372bfe1c55f832766f21e" + "94ea3322eb2182f10a891546536ba907ad74b8d72469bea396f3"); + std::string q = absl::HexStringToBytes( + "f8ba5c89bd068f57234a3cf54a1c89d5b4cd0194f2633ca7c60b91a795a56fa8c868" + "6c0e37b1c4498b851e3420d08bea29f71d195cfbd3671c6ddc49cf4c1db5b478231e" + "a9d91377ffa98fe95685fca20ba4623212b2f2def4da5b281ed0100b651f6db32112" + "e4017d831c0da668768afa7141d45bbc279f1e0f8735d74395b3"); + + std::vector test_vectors; + // test_vector 1. + test_vectors.push_back({ + n, + e, + d, + p, + q, + // message + absl::HexStringToBytes("68656c6c6f20776f726c64"), + // public_metadata + absl::HexStringToBytes("6d65746164617461"), + // message_mask + absl::HexStringToBytes( + "64b5c5d2b2ca672690df59bab774a389606d85d56f92a18a57c42eb4cb164d43"), + // blinded_message + absl::HexStringToBytes( + "1b9e1057dd2d05a17ad2feba5f87a4083cc825fe06fc70f0b782062ea0043fa65ec8" + "096ce5d403cfa2aa3b11195b2a655d694386058f6266450715a936b5764f42977c0a" + "0933ff3054d456624734fd2c019def792f00d30b3ac2f27859ea56d835f80564a3ba" + "59f3c876dc926b2a785378ca83f177f7b378513b36a074e7db59448fd4007b54c647" + "91a33b61721ab3b5476165193af30f25164d480684d045a8d0782a53dd73774563e8" + "d29e48b175534f696763abaab49fa03a055ec9246c5e398a5563cc88d02eb57d725d" + "3fc9231ae5139aa7fcb9941060b0bf0192b8c81944fa0c54568b0ab4ea9c4c4c9829" + "d6dbcbf8b48006b322ee51d784ac93e4bf13"), + // blinded_signature + absl::HexStringToBytes( + "7ef75d9887f29f2232602acab43263afaea70313a0c90374388df5a7a7440d2584c4" + "b4e5b886accc065bf4824b4b22370ddde7fea99d4cd67f8ed2e4a6a2b7b5869e8d4d" + "0c52318320c5bf7b9f02bb132af7365c471e799edd111ca9441934c7db76c164b051" + "5afc5607b8ceb584f5b1d2177d5180e57218265c07aec9ebde982f3961e7ddaa432e" + "47297884da8f4512fe3dc9ab820121262e6a73850920299999c293b017cd800c6ec9" + "94f76b6ace35ff4232f9502e6a52262e19c03de7cc27d95ccbf4c381d698fcfe1f20" + "0209814e04ae2d6279883015bbf36cabf3e2350be1e175020ee9f4bb861ba409b467" + "e23d08027a699ac36b2e5ab988390f3c0ee9"), + // signature + absl::HexStringToBytes( + "abd6813bb4bbe3bc8dc9f8978655b22305e5481b35c5bdc4869b60e2d5cc74b84356" + "416abaaca0ca8602cd061248587f0d492fee3534b19a3fe089de18e4df9f3a6ad289" + "afb5323d7934487b8fafd25943766072bab873fa9cd69ce7328a57344c2c529fe969" + "83ca701483ca353a98a1a9610391b7d32b13e14e8ef87d04c0f56a724800655636cf" + "ff280d35d6b468f68f09f56e1b3acdb46bc6634b7a1eab5c25766cec3b5d97c37bbc" + "a302286c17ff557bcf1a4a0e342ea9b2713ab7f935c8174377bace2e5926b3983407" + "9761d9121f5df1fad47a51b03eab3d84d050c99cf1f68718101735267cca3213c0a4" + "6c0537887ffe92ca05371e26d587313cc3f4"), + }); + + // test_vector 2. + test_vectors.push_back({ + n, + e, + d, + p, + q, + // message + absl::HexStringToBytes("68656c6c6f20776f726c64"), + // public_metadata + "", + // message_mask + absl::HexStringToBytes( + "ebb56541b9a1758028033cfb085a4ffe048f072c6c82a71ce21d40842b5c0a89"), + // blinded_message + absl::HexStringToBytes( + "d1fc97f30efbf116fadd9895130cdd55f939211f7db19ce9a85287227a02b33fb698" + "b52399f81be0e1f598482000202ec89968085753eae1810f14676b514e08238c8aa7" + "9d8b999af54e9f4282c6220d4d760716e48e5413f3228cc59ce10b8252916640de7b" + "9b5c7dc9c2bff9f53b4fb5eb4a5f8bab49af3fd1b955d34312073d15030e7fdb44bd" + "b23460d1c5662597f9947092def7fff955a5f3e63419ae9858c6405f9609b63c4331" + "e0cf90d24c196bee554f2b78e0d8f6da3d4308c8d4ae9fbe18a8bb7fa4fc3b9cacd4" + "263e5bd6e12ed891cfdfba8b50d0f37d7a9abe065238367907c685ed2c224924caf5" + "d8fe41f5db898b09a0501d318d9f65d88cb8"), + // blinded_signature + absl::HexStringToBytes( + "400c1bcdfa56624f15d04f6954908b5605dbeff4cd56f384d7531669970290d70652" + "9d44cde4c972a1399635525a2859ef1d914b4130068ed407cfda3bd9d1259790a30f" + "6d8c07d190aa98bf21ae9581e5d61801565d96e9eec134335958b3d0b905739e2fd9" + "f39074da08f869089fe34de2d218062afa16170c1505c67b65af4dcc2f1aeccd4827" + "5c3dacf96116557b7f8c7044d84e296a0501c511ba1e6201703e1dd834bf47a96e1a" + "c4ec9b935233ed751239bd4b514b031522cd51615c1555e520312ed1fa43f55d4abe" + "b222ee48b4746c79006966590004714039bac7fd18cdd54761924d91a4648e871458" + "937061ef6549dd12d76e37ed417634d88914"), + // signature + absl::HexStringToBytes( + "4062960edb71cc071e7d101db4f595aae4a98e0bfe6843aca3e5f48c9dfb46d505e8" + "c19806ffa07f040313d44d0996ef9f69a86fa5946cb818a32627fe2df2a0e8035028" + "8ae4fedfbee4193554cc1433d9d27639db8b4635265504d87dca7054c85e0c882d32" + "887534405e6cc4e7eb4b174383e5ce4eebbfffb217f353102f6d1a0461ef89238de3" + "1b0a0c134dfac0d2a8c533c807ccdd557c6510637596a490d5258b77410421be4076" + "ecdf2d7e9044327e36e349751f3239681bba10fe633f1b246f5a9f694706316898c9" + "00af2294f47267f2e9ad1e61c7f56bf643280258875d29f3745dfdb74b9bbcd5fe3d" + "ea62d9be85e2c6f5aed68bc79f8b4a27b3de"), + }); + + // test_vector 3. + test_vectors.push_back({ + n, + e, + d, + p, + q, + // message + "", + // public_metadata + absl::HexStringToBytes("6d65746164617461"), + // message_mask + absl::HexStringToBytes( + "f2a4ed7c5aa338430c7026d7d92017f994ca1c8b123b236dae8666b1899059d0"), + // blinded_message + absl::HexStringToBytes( + "7756a1f89fa33cfc083567e02fd865d07d6e5cd4943f030a2f94b5c23f3fe79c83c4" + "9c594247d02885e2cd161638cff60803184c9e802a659d76a1c53340972e62e728cc" + "70cf684ef03ce2d05cefc729e6eee2ae46afa17b6b27a64f91e4c46cc12adc58d9cb" + "61a4306dac732c9789199cfe8bd28359d1911678e9709bc159dae34ac7aa59fd0c95" + "962c9f4904bf04aaba8a7e774735bd03be4a02fb0864a53354a2e2f3502506318a5b" + "03961366005c7b120f0e6b87b44bc15658c3e8985d69f6adea38c24fe5e7b4bafa1a" + "d6dc7d729281c26dffc88bd34fcc5a5f9df9b9781f99ea47472ba8bd679aaada5952" + "5b978ebc8a3ea2161de84b7398e4878b751b"), + // blinded_signature + absl::HexStringToBytes( + "2a13f73e4e255a9d5bc6f76cf48dfbf189581c2b170600fd3ab1a3def14884621323" + "9b9d0a981537541cb4f481a602aeebca9ef28c9fcdc63d15d4296f85d864f799edf0" + "8e9045180571ce1f1d3beff293b18aae9d8845068cc0d9a05b822295042dc56a1a2b" + "604c51aa65fd89e6d163fe1eac63cf603774797b7936a8b7494d43fa37039d3777b8" + "e57cf0d95227ab29d0bd9c01b3eae9dde5fca7141919bd83a17f9b1a3b401507f3e3" + "a8e8a2c8eb6c5c1921a781000fee65b6dd851d53c89cba2c3375f0900001c0485594" + "9b7fa499f2a78089a6f0c9b4d36fdfcac2d846076736c5eaedaf0ae70860633e51b0" + "de21d96c8b43c600afa2e4cc64cd66d77a8f"), + // signature + absl::HexStringToBytes( + "67985949f4e7c91edd5647223170d2a9b6611a191ca48ceadb6c568828b4c415b627" + "0b037cd8a68b5bca1992eb769aaef04549422889c8b156b9378c50e8a31c07dc1fe0" + "a80d25b870fadbcc1435197f0a31723740f3084ecb4e762c623546f6bd7d072aa565" + "bc2105b954244a2b03946c7d4093ba1216ec6bb65b8ca8d2f3f3c43468e80b257c54" + "a2c2ea15f640a08183a00488c7772b10df87232ee7879bee93d17e194d6b703aeceb" + "348c1b02ec7ce202086b6494f96a0f2d800f12e855f9c33dcd3abf6bd8044efd69d4" + "594a974d6297365479fe6c11f6ecc5ea333031c57deb6e14509777963a25cdf8db62" + "d6c8c68aa038555e4e3ae4411b28e43c8f57"), + }); + + // test_vector 4. + test_vectors.push_back({ + n, + e, + d, + p, + q, + // message + "", + // public_metadata + "", + // message_mask + absl::HexStringToBytes( + "ba3ea4b1e475eebe11d4bfe3a48521d3ba8cd62f3baed9ec29fbbf7ff0478bc0"), + // blinded_message + absl::HexStringToBytes( + "99d725c5613ff87d16464b0375b0976bf4d47319d6946e85f0d0c2ca79eb02a4c0c2" + "82642e090a910b80fee288f0b3b6777e517b757fc6c96ea44ac570216c8fcd868e15" + "da4b389b0c70898c5a2ed25c1d13451e4d407fe1301c231b4dc76826b1d4cc5e64b0" + "e28fb9c71f928ba48c87e308d851dd07fb5a7e0aa5d0dce61d1348afb4233355374e" + "5898f63adbd5ba215332d3329786fb7c30ef04c181b267562828d8cf1295f2ef4a05" + "ef1e03ed8fee65efb7725d8c8ae476f61a35987e40efc481bcb4b89cb363addfb2ad" + "acf690aff5425107d29b2a75b4665d49f255c5caa856cdc0c5667de93dbf3f500db8" + "fcce246a70a159526729d82c34df69c926a8"), + // blinded_signature + absl::HexStringToBytes( + "a9678acee80b528a836e4784f0690fdddce147e5d4ac506e9ec51c11b16ee2fd5a32" + "e382a3c3d276a681bb638b63040388d53894afab79249e159835cd6bd65849e5d139" + "7666f03d1351aaec3eae8d3e7cba3135e7ec4e7b478ef84d79d81039693adc6b130b" + "0771e3d6f0879723a20b7f72b476fe6fef6f21e00b9e3763a364ed918180f939c351" + "0bb5f46b35c06a00e51f049ade9e47a8e1c3d5689bd5a43df20b73d70dcacfeed9fa" + "23cabfbe750779997da6bc7269d08b2620acaa3daa0d9e9d4b87ef841ebcc06a4c0a" + "f13f1d13f0808f512c50898586b4fc76d2b32858a7ddf715a095b7989d8df50654e3" + "e05120a83cec275709cf79571d8f46af2b8e"), + // signature + absl::HexStringToBytes( + "ba57951810dbea7652209eb73e3b8edafc56ca7061475a048751cbfb995aeb4ccda2" + "e9eb309698b7c61012accc4c0414adeeb4b89cd29ba2b49b1cc661d5e7f30caee7a1" + "2ab36d6b52b5e4d487dbff98eb2d27d552ecd09ca022352c9480ae27e10c3a49a1fd" + "4912699cc01fba9dbbfd18d1adcec76ca4bc44100ea67b9f1e00748d80255a03371a" + "7b8f2c160cf632499cea48f99a6c2322978bd29107d0dffdd2e4934bb7dc81c90dd6" + "3ae744fd8e57bff5e83f98014ca502b6ace876b455d1e3673525ba01687dce998406" + "e89100f55316147ad510e854a064d99835554de8949d3662708d5f1e43bca473c14a" + "8b1729846c6092f18fc0e08520e9309a32de"), + }); + return test_vectors; +} + +absl::StatusOr> +GetIetfStandardRsaBlindSignatureTestKeys() { + IetfStandardRsaBlindSignatureTestVector test_vector = + GetIetfStandardRsaBlindSignatureTestVector(); + return PopulateTestVectorKeys(test_vector.n, test_vector.e, test_vector.d, + test_vector.p, test_vector.q); +} + +absl::StatusOr> +GetIetfRsaBlindSignatureWithPublicMetadataTestKeys() { + auto test_vectors = GetIetfRsaBlindSignatureWithPublicMetadataTestVectors(); + return PopulateTestVectorKeys(test_vectors[0].n, test_vectors[0].e, + test_vectors[0].d, test_vectors[0].p, + test_vectors[0].q); +} + +std::string RandomString(int n, std::uniform_int_distribution* distr_u8, + std::mt19937_64* generator) { + std::string rand(n, 0); + for (int i = 0; i < n; ++i) { + rand[i] = static_cast((*distr_u8)(*generator)); + } + return rand; +} + +} // namespace anonymous_tokens +} // namespace private_membership diff --git a/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h b/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h new file mode 100644 index 000000000000..5c6aa1c40a50 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h @@ -0,0 +1,156 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_ANONYMOUS_TOKENS_CPP_TESTING_UTILS_H_ +#define THIRD_PARTY_ANONYMOUS_TOKENS_CPP_TESTING_UTILS_H_ + +#include + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/base.h" + +namespace private_membership { +namespace anonymous_tokens { + +struct IetfStandardRsaBlindSignatureTestVector { + std::string n; + std::string e; + std::string d; + std::string p; + std::string q; + std::string message; + std::string salt; + std::string inv; + std::string encoded_message; + std::string blinded_message; + std::string blinded_signature; + std::string signature; +}; + +struct IetfRsaBlindSignatureWithPublicMetadataTestVector { + std::string n; + std::string e; + std::string d; + std::string p; + std::string q; + std::string message; + std::string public_metadata; + std::string message_mask; + std::string blinded_message; + std::string blinded_signature; + std::string signature; +}; + +// Creates a pair containing a standard RSA Private key and an Anonymous Tokens +// RSABlindSignaturePublicKey using RSA_F4 (65537) as the public exponent and +// other input parameters. +absl::StatusOr, RSABlindSignaturePublicKey>> +CreateTestKey(int key_size = 512, HashType sig_hash = AT_HASH_TYPE_SHA384, + MaskGenFunction mfg1_hash = AT_MGF_SHA384, int salt_length = 48, + MessageMaskType message_mask_type = AT_MESSAGE_MASK_CONCAT, + int message_mask_size = kRsaMessageMaskSizeInBytes32); + +// Prepares message for signing by computing its hash and then applying the PSS +// padding to the result by executing RSA_padding_add_PKCS1_PSS_mgf1 from the +// openssl library, using the input parameters. +// +// This is a test function and it skips the message blinding part. +absl::StatusOr EncodeMessageForTests(absl::string_view message, + RSAPublicKey public_key, + const EVP_MD* sig_hasher, + const EVP_MD* mgf1_hasher, + int32_t salt_length); + +// TestSign can be removed once rsa_blind_signer is moved to +// anonympous_tokens/public/cpp/crypto +absl::StatusOr TestSign(absl::string_view blinded_data, + RSA* rsa_key); + +// TestSignWithPublicMetadata can be removed once rsa_blind_signer is moved to +// anonympous_tokens/public/cpp/crypto +absl::StatusOr TestSignWithPublicMetadata( + absl::string_view blinded_data, absl::string_view public_metadata, + const RSA& rsa_key); + +// This method returns a newly generated RSA key pair, setting the public +// exponent to be the standard RSA_F4 (65537) and the default modulus size to +// 512 bytes. +absl::StatusOr> GetStandardRsaKeyPair( + int modulus_size_in_bytes = kRsaModulusSizeInBytes512); + +// Method returns fixed 2048-bit strong RSA modulus for testing. +absl::StatusOr> GetStrongRsaKeys2048(); + +// Method returns another fixed 2048-bit strong RSA modulus for testing. +absl::StatusOr> +GetAnotherStrongRsaKeys2048(); + +// Method returns fixed 3072-bit strong RSA modulus for testing. +absl::StatusOr> GetStrongRsaKeys3072(); + +// Method returns fixed 4096-bit strong RSA modulus for testing. +absl::StatusOr> GetStrongRsaKeys4096(); + +// Returns the IETF test example from +// https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/ +IetfStandardRsaBlindSignatureTestVector +GetIetfStandardRsaBlindSignatureTestVector(); + +// This method returns a RSA key pair as described in the IETF test example +// above. +absl::StatusOr> +GetIetfStandardRsaBlindSignatureTestKeys(); + +// Returns the IETF test with Public Metadata examples from +// https://datatracker.ietf.org/doc/draft-amjad-cfrg-partially-blind-rsa/ +// +// Note that all test vectors use the same RSA key pair. +std::vector +GetIetfRsaBlindSignatureWithPublicMetadataTestVectors(); + +// This method returns a RSA key pair as described in the IETF test with Public +// Metadata example. It can be used for all test vectors returned by +// GetIetfRsaBlindSignatureWithPublicMetadataTestVectors. +absl::StatusOr> +GetIetfRsaBlindSignatureWithPublicMetadataTestKeys(); + +// Outputs a random string of n characters. +std::string RandomString(int n, std::uniform_int_distribution* distr_u8, + std::mt19937_64* generator); + +#define ANON_TOKENS_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + ANON_TOKENS_ASSERT_OK_AND_ASSIGN_IMPL_( \ + ANON_TOKENS_STATUS_TESTING_IMPL_CONCAT_(_status_or_value, __LINE__), \ + lhs, rexpr) + +#define ANON_TOKENS_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_THAT(statusor.ok(), ::testing::Eq(true)); \ + lhs = std::move(statusor).value() + +#define ANON_TOKENS_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) x##y +#define ANON_TOKENS_STATUS_TESTING_IMPL_CONCAT_(x, y) \ + ANON_TOKENS_STATUS_TESTING_IMPL_CONCAT_INNER_(x, y) + +} // namespace anonymous_tokens +} // namespace private_membership + +#endif // THIRD_PARTY_ANONYMOUS_TOKENS_CPP_TESTING_UTILS_H_ diff --git a/quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto b/quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto new file mode 100644 index 000000000000..209ba0a33457 --- /dev/null +++ b/quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto @@ -0,0 +1,335 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package private_membership.anonymous_tokens; + +import "quiche/blind_sign_auth/proto/timestamp.proto"; + +// Different use cases for the Anonymous Tokens service. +// Next ID: 9 +enum AnonymousTokensUseCase { + // Test use cases here. + ANONYMOUS_TOKENS_USE_CASE_UNDEFINED = 0; + TEST_USE_CASE = 1; + TEST_USE_CASE_2 = 2; + TEST_USE_CASE_3 = 4; + TEST_USE_CASE_4 = 5; + TEST_USE_CASE_5 = 6; + + PROVABLY_PRIVATE_NETWORK = 3; + CHROME_IP_BLINDING = 7; + NOCTOGRAM_PPISSUER = 8; +} + +// An enum describing different types of available hash functions. +enum HashType { + AT_HASH_TYPE_UNDEFINED = 0; + AT_TEST_HASH_TYPE = 1; + AT_HASH_TYPE_SHA256 = 2; + AT_HASH_TYPE_SHA384 = 3; + // Add more hash types if necessary. +} + +// An enum describing different types of hash functions that can be used by the +// mask generation function. +enum MaskGenFunction { + AT_MGF_UNDEFINED = 0; + AT_TEST_MGF = 1; + AT_MGF_SHA256 = 2; + AT_MGF_SHA384 = 3; + // Add more hash types if necessary. +} + +// An enum describing different types of message masking. +enum MessageMaskType { + AT_MESSAGE_MASK_TYPE_UNDEFINED = 0; + AT_MESSAGE_MASK_XOR = 1; + AT_MESSAGE_MASK_CONCAT = 2; +} + +// Proto representation for RSA private key. +message RSAPrivateKey { + // Modulus. + bytes n = 1; + // Public exponent. + bytes e = 2; + // Private exponent. + bytes d = 3; + // The prime factor p of n. + bytes p = 4; + // The prime factor q of n. + bytes q = 5; + // d mod (p - 1). + bytes dp = 6; + // d mod (q - 1). + bytes dq = 7; + // Chinese Remainder Theorem coefficient q^(-1) mod p. + bytes crt = 8; +} + +// Proto representation for RSA public key. +message RSAPublicKey { + // Modulus. + bytes n = 1; + // Public exponent. + bytes e = 2; +} + +// Next ID: 13 +message RSABlindSignaturePublicKey { + // Use case associated with this public key. + bytes use_case = 9; + + // Version number of public key. + int64 key_version = 1; + + // Serialization of the public key. + bytes serialized_public_key = 2; + + // Timestamp of expiration. + // + // Note that we will not return keys whose expiration times are in the past. + quiche.protobuf.Timestamp expiration_time = 3; + + // Key becomes valid at key_validity_start_time. + quiche.protobuf.Timestamp key_validity_start_time = 8; + + // Hash function used in computing hash of the signing message + // (see https://tools.ietf.org/html/rfc8017#section-9.1.1) + HashType sig_hash_type = 4; + + // Hash function used in MGF1 (a mask generation function based on a + // hash function) (see https://tools.ietf.org/html/rfc8017#appendix-B.2.1). + MaskGenFunction mask_gen_function = 5; + + // Length in bytes of the salt (see + // https://tools.ietf.org/html/rfc8017#section-9.1.1) + int64 salt_length = 6; + + // Key size: bytes of RSA key. + int64 key_size = 7; + + // Type of masking of message (see https://eprint.iacr.org/2022/895.pdf). + MessageMaskType message_mask_type = 10; + + // Length of message mask in bytes. + int64 message_mask_size = 11; + + // Conveys whether public metadata support is enabled and RSA blind signatures + // with public metadata protocol should be used. If false, standard RSA blind + // signatures are used and all public metadata inputs are ignored. + bool public_metadata_support = 12; +} + +message AnonymousTokensPublicKeysGetRequest { + // Use case associated with this request. + // + // Returns an error if the token type does not support public key verification + // for the requested use_case. + bytes use_case = 1; + + // Key version associated with this request. + // + // Returns an error if the token type does not support public key verification + // for the requested use_case and key_version combination. + // + // If unset, all valid possibilities for the key are returned. + int64 key_version = 2; + + // Public key that becomes valid at or before this requested time and not + // after. More explicitly, we need the requested key to be valid at the + // requested key_validity_start_time. + // + // If unset it will be set to current time. + quiche.protobuf.Timestamp key_validity_start_time = 3 + ; + + // Public key that is definitely not valid after this particular time. If + // unset / null, only keys that are indefinitely valid are returned. + // + // Note: It is possible that the key becomes invalid before this time. But the + // key should not be valid after this time. + quiche.protobuf.Timestamp key_validity_end_time = 4 + ; +} + +message AnonymousTokensPublicKeysGetResponse { + // List of currently valid RSA public keys. + repeated RSABlindSignaturePublicKey rsa_public_keys = 1; +} + +message AnonymousTokensSignRequest { + // Next ID: 5 + message BlindedToken { + // Use case associated with this request. + bytes use_case = 1; + + // Version of key used to sign and generate the token. + int64 key_version = 2; + + // Public metadata to be tied to the `blinded message` (serialized_token). + // + // The length of public metadata must fit in 4 bytes. + bytes public_metadata = 4; + + // Serialization of the token. + bytes serialized_token = 3; + } + + // Token(s) that have been blinded by the user, not yet signed + repeated BlindedToken blinded_tokens = 1; +} + +message AnonymousTokensSignResponse { + // Next ID: 6 + message AnonymousToken { + // Use case associated with this anonymous token. + bytes use_case = 1; + + // Version of key used to sign and generate the token. + int64 key_version = 2; + + // Public metadata tied to the input (serialized_blinded_message) and the + // `blinded` signature (serialized_token). + // + // The length of public metadata must fit in 4 bytes. + bytes public_metadata = 4; + + // The serialized_token in BlindedToken in the AnonymousTokensSignRequest. + bytes serialized_blinded_message = 5; + + // Serialization of the signed token. This will have to be `unblinded` by + // the user before it can be used / redeemed. + bytes serialized_token = 3; + } + + // Returned anonymous token(s) + repeated AnonymousToken anonymous_tokens = 1; +} + +message AnonymousTokensRedemptionRequest { + // Next ID: 7 + message AnonymousTokenToRedeem { + // Use case associated with this anonymous token that needs to be redeemed. + bytes use_case = 1; + + // Version of key associated with this anonymous token that needs to be + // redeemed. + int64 key_version = 2; + + // Public metadata to be used for redeeming the signature + // (serialized_unblinded_token). + // + // The length of public metadata must fit in 4 bytes. + bytes public_metadata = 4; + + // Serialization of the unblinded anonymous token that needs to be redeemed. + bytes serialized_unblinded_token = 3; + + // Plaintext input message to verify the signature for. + bytes plaintext_message = 5; + + // Nonce used to mask plaintext message before cryptographic verification. + bytes message_mask = 6; + } + + // One or more anonymous tokens to redeem. + repeated AnonymousTokenToRedeem anonymous_tokens_to_redeem = 1; +} + +message AnonymousTokensRedemptionResponse { + // Next ID: 9 + message AnonymousTokenRedemptionResult { + // Use case associated with this redeemed anonymous token. + bytes use_case = 3; + + // Version of key associated with this redeemed anonymous token. + int64 key_version = 4; + + // Public metadata used for verifying the signature + // (serialized_unblinded_token). + // + // The length of public metadata must fit in 4 bytes. + bytes public_metadata = 5; + + // Serialization of this redeemed unblinded anonymous token. + bytes serialized_unblinded_token = 6; + + // Unblinded input message that the signature was verified against. + bytes plaintext_message = 7; + + // Nonce used to mask plaintext message before cryptographic verification. + bytes message_mask = 8; + + // Returns true if and only if the anonymous token was redeemed + // successfully i.e. token was cryptographically verified, all relevant + // state in the server was updated successfully and the token was not + // redeemed already. + // + bool verified = 1; + + // Returns true if and only if the anonymous token has already been + // redeemed. + bool double_spent = 2; + } + + // Redemption response for requested anonymous tokens. + repeated AnonymousTokenRedemptionResult anonymous_token_redemption_results = + 1; +} + +// Plaintext message with public metadata. +message PlaintextMessageWithPublicMetadata { + // Message to be signed. + bytes plaintext_message = 1; + + // Public metadata to be tied to the signature. + bytes public_metadata = 2; +} + +// Proto representing a token created during the blind signing protocol. +message RSABlindSignatureToken { + // Resulting token from the blind signing protocol. + bytes token = 1; + + // Nonce used to mask messages. + bytes message_mask = 2; +} + +// Proto representing a token along with the input. +message RSABlindSignatureTokenWithInput { + // Input consisting of plaintext message and public metadata. + PlaintextMessageWithPublicMetadata input = 1; + + // Resulting token after blind signing protocol. + RSABlindSignatureToken token = 2; +} + +// Proto representing redemption result along with the token and the token +// input. +message RSABlindSignatureRedemptionResult { + // Proto representing a token along with the input. + RSABlindSignatureTokenWithInput token_with_input = 1; + + // This is set to true if and only if the anonymous token was redeemed + // successfully i.e. token was cryptographically verified, all relevant + // state in the redemption server was updated successfully and the token was + // not redeemed already. + bool redeemed = 2; + + // True if and only if the token was redeemed before. + bool double_spent = 3; +} diff --git a/quiche/blind_sign_auth/blind_sign_auth.cc b/quiche/blind_sign_auth/blind_sign_auth.cc new file mode 100644 index 000000000000..d13dc332bcef --- /dev/null +++ b/quiche/blind_sign_auth/blind_sign_auth.cc @@ -0,0 +1,299 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/blind_sign_auth/blind_sign_auth.h" + +#include +#include +#include +#include + +#include "quiche/blind_sign_auth/proto/auth_and_sign.pb.h" +#include "quiche/blind_sign_auth/proto/get_initial_data.pb.h" +#include "quiche/blind_sign_auth/proto/key_services.pb.h" +#include "quiche/blind_sign_auth/proto/public_metadata.pb.h" +#include "quiche/blind_sign_auth/proto/spend_token_data.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/proto_utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/blind_sign_auth/blind_sign_http_response.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_random.h" + +namespace quiche { +namespace { + +template +std::string OmitDefault(T value) { + return value == 0 ? "" : absl::StrCat(value); +} + +} // namespace + +void BlindSignAuth::GetTokens( + absl::string_view oauth_token, int num_tokens, + std::function>)> + callback) { + // Create GetInitialData RPC. + privacy::ppn::GetInitialDataRequest request; + request.set_use_attestation(false); + request.set_service_type("chromeipblinding"); + request.set_location_granularity( + privacy::ppn::GetInitialDataRequest_LocationGranularity_CITY_GEOS); + + // Call GetInitialData on the HttpFetcher. + std::string path_and_query = "/v1/getInitialData"; + std::string body = request.SerializeAsString(); + http_fetcher_->DoRequest( + path_and_query, oauth_token.data(), body, + [this, callback, oauth_token, + num_tokens](absl::StatusOr response) { + GetInitialDataCallback(response, oauth_token, num_tokens, callback); + }); +} + +void BlindSignAuth::GetInitialDataCallback( + absl::StatusOr response, + absl::string_view oauth_token, int num_tokens, + std::function>)> callback) { + if (!response.ok()) { + QUICHE_LOG(WARNING) << "GetInitialDataRequest failed: " + << response.status(); + callback(response.status()); + return; + } + int status_code = response.value().status_code(); + if (response.value().status_code() != 200) { + QUICHE_LOG(WARNING) << "GetInitialDataRequest failed with code: " + << status_code; + callback(response.status()); + return; + } + // Parse GetInitialDataResponse. + privacy::ppn::GetInitialDataResponse initial_data_response; + if (!initial_data_response.ParseFromString(response.value().body())) { + QUICHE_LOG(WARNING) << "Failed to parse GetInitialDataResponse"; + callback(absl::InternalError("Failed to parse GetInitialDataResponse")); + return; + } + + // Create RSA BSSA client. + auto bssa_client = + private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient:: + Create(initial_data_response.at_public_metadata_public_key()); + if (!bssa_client.ok()) { + QUICHE_LOG(WARNING) << "Failed to create AT BSSA client: " + << bssa_client.status(); + callback(bssa_client.status()); + return; + } + + // Create plaintext tokens. + // Client blinds plaintext tokens (random 32-byte strings) in CreateRequest. + std::vector< + private_membership::anonymous_tokens::PlaintextMessageWithPublicMetadata> + plaintext_tokens; + QuicheRandom* random = QuicheRandom::GetInstance(); + for (int i = 0; i < num_tokens; i++) { + // Create random 32-byte string prefixed with "blind:". + private_membership::anonymous_tokens::PlaintextMessageWithPublicMetadata + plaintext_message; + std::string rand_bytes(32, '\0'); + random->RandBytes(rand_bytes.data(), rand_bytes.size()); + plaintext_message.set_plaintext_message(absl::StrCat("blind:", rand_bytes)); + uint64_t fingerprint = 0; + absl::Status fingerprint_status = FingerprintPublicMetadata( + initial_data_response.public_metadata_info().public_metadata(), + &fingerprint); + if (!fingerprint_status.ok()) { + QUICHE_LOG(WARNING) << "Failed to fingerprint public metadata: " + << fingerprint_status; + callback(fingerprint_status); + return; + } + plaintext_message.set_public_metadata(absl::StrCat(fingerprint)); + plaintext_tokens.push_back(plaintext_message); + } + + absl::StatusOr< + private_membership::anonymous_tokens::AnonymousTokensSignRequest> + at_sign_request = bssa_client.value()->CreateRequest(plaintext_tokens); + if (!at_sign_request.ok()) { + QUICHE_LOG(WARNING) << "Failed to create AT Sign Request: " + << at_sign_request.status(); + callback(at_sign_request.status()); + return; + } + + // Create AuthAndSign RPC. + privacy::ppn::AuthAndSignRequest sign_request; + sign_request.set_oauth_token(std::string(oauth_token)); + sign_request.set_service_type("chromeipblinding"); + sign_request.set_key_type(privacy::ppn::AT_PUBLIC_METADATA_KEY_TYPE); + sign_request.set_key_version( + initial_data_response.at_public_metadata_public_key().key_version()); + *sign_request.mutable_public_metadata_info() = + initial_data_response.public_metadata_info(); + for (int i = 0; i < at_sign_request->blinded_tokens_size(); i++) { + sign_request.add_blinded_token(absl::Base64Escape( + at_sign_request->blinded_tokens().at(i).serialized_token())); + } + + privacy::ppn::PublicMetadataInfo public_metadata_info = + initial_data_response.public_metadata_info(); + http_fetcher_->DoRequest( + "/v1/authWithHeaderCreds", oauth_token.data(), + sign_request.SerializeAsString(), + [this, at_sign_request, public_metadata_info, + bssa_client_ = bssa_client.value().get(), + callback](absl::StatusOr response) { + AuthAndSignCallback(response, public_metadata_info, *at_sign_request, + bssa_client_, callback); + }); +} + +void BlindSignAuth::AuthAndSignCallback( + absl::StatusOr response, + privacy::ppn::PublicMetadataInfo public_metadata_info, + private_membership::anonymous_tokens::AnonymousTokensSignRequest + at_sign_request, + private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient* + bssa_client, + std::function>)> callback) { + // Validate response. + if (!response.ok()) { + QUICHE_LOG(WARNING) << "AuthAndSign failed: " << response.status(); + callback(response.status()); + return; + } + int status_code = response.value().status_code(); + if (response.value().status_code() != 200) { + QUICHE_LOG(WARNING) << "AuthAndSign failed with code: " << status_code; + callback(response.status()); + return; + } + + // Decode AuthAndSignResponse. + privacy::ppn::AuthAndSignResponse sign_response; + if (!sign_response.ParseFromString(response.value().body())) { + QUICHE_LOG(WARNING) << "Failed to parse AuthAndSignResponse"; + callback(absl::InternalError("Failed to parse AuthAndSignResponse")); + return; + } + + // Create vector of unblinded anonymous tokens. + private_membership::anonymous_tokens::AnonymousTokensSignResponse + at_sign_response; + + if (sign_response.blinded_token_signature_size() != + at_sign_request.blinded_tokens_size()) { + QUICHE_LOG(WARNING) + << "Response signature size does not equal request tokens size"; + callback(absl::InternalError( + "Response signature size does not equal request tokens size")); + return; + } + // This depends on the signing server returning the signatures in the order + // that the tokens were sent. Phosphor does guarantee this. + for (int i = 0; i < sign_response.blinded_token_signature_size(); i++) { + std::string blinded_token; + if (!absl::Base64Unescape(sign_response.blinded_token_signature(i), + &blinded_token)) { + QUICHE_LOG(WARNING) << "Failed to unescape blinded token signature"; + callback( + absl::InternalError("Failed to unescape blinded token signature")); + return; + } + private_membership::anonymous_tokens::AnonymousTokensSignResponse:: + AnonymousToken anon_token_proto; + *anon_token_proto.mutable_use_case() = + at_sign_request.blinded_tokens(i).use_case(); + anon_token_proto.set_key_version( + at_sign_request.blinded_tokens(i).key_version()); + *anon_token_proto.mutable_public_metadata() = + at_sign_request.blinded_tokens(i).public_metadata(); + *anon_token_proto.mutable_serialized_blinded_message() = + at_sign_request.blinded_tokens(i).serialized_token(); + *anon_token_proto.mutable_serialized_token() = blinded_token; + at_sign_response.add_anonymous_tokens()->Swap(&anon_token_proto); + } + + auto signed_tokens = bssa_client->ProcessResponse(at_sign_response); + if (!signed_tokens.ok()) { + QUICHE_LOG(WARNING) << "AuthAndSign ProcessResponse failed: " + << signed_tokens.status(); + callback(signed_tokens.status()); + return; + } + if (signed_tokens->size() != + static_cast(at_sign_response.anonymous_tokens_size())) { + QUICHE_LOG(WARNING) + << "ProcessResponse did not output the right number of signed tokens"; + callback(absl::InternalError( + "ProcessResponse did not output the right number of signed tokens")); + return; + } + + // Output SpendTokenData with data for the redeemer to make a SpendToken RPC. + std::vector tokens_vec; + for (size_t i = 0; i < signed_tokens->size(); i++) { + privacy::ppn::SpendTokenData spend_token_data; + *spend_token_data.mutable_public_metadata() = + public_metadata_info.public_metadata(); + *spend_token_data.mutable_unblinded_token() = + signed_tokens->at(i).input().plaintext_message(); + *spend_token_data.mutable_unblinded_token_signature() = + signed_tokens->at(i).token().token(); + spend_token_data.set_signing_key_version( + at_sign_response.anonymous_tokens(i).key_version()); + auto use_case = private_membership::anonymous_tokens::ParseUseCase( + at_sign_response.anonymous_tokens(i).use_case()); + if (!use_case.ok()) { + QUICHE_LOG(WARNING) << "Failed to parse use case: " << use_case.status(); + callback(use_case.status()); + return; + } + spend_token_data.set_use_case(*use_case); + spend_token_data.set_message_mask( + signed_tokens->at(i).token().message_mask()); + tokens_vec.push_back(spend_token_data.SerializeAsString()); + } + + callback(absl::Span(tokens_vec)); +} + +absl::Status BlindSignAuth::FingerprintPublicMetadata( + const privacy::ppn::PublicMetadata& metadata, uint64_t* fingerprint) { + const EVP_MD* hasher = EVP_sha256(); + std::string digest; + digest.resize(EVP_MAX_MD_SIZE); + + uint32_t digest_length = 0; + // Concatenate fields in tag number order, omitting fields whose values match + // the default. This enables new fields to be added without changing the + // resulting encoding. The signer needs to ensure that | is not allowed in any + // metadata value so intentional collisions cannot be created. + const std::vector parts = { + metadata.exit_location().country(), + metadata.exit_location().city_geo_id(), + metadata.service_type(), + OmitDefault(metadata.expiration().seconds()), + OmitDefault(metadata.expiration().nanos()), + }; + const std::string input = absl::StrJoin(parts, "|"); + if (EVP_Digest(input.data(), input.length(), + reinterpret_cast(&digest[0]), &digest_length, hasher, + nullptr) != 1) { + return absl::InternalError("EVP_Digest failed"); + } + // Return the first uint64_t of the SHA-256 hash. + memcpy(fingerprint, digest.data(), sizeof(*fingerprint)); + return absl::OkStatus(); +} + +} // namespace quiche diff --git a/quiche/blind_sign_auth/blind_sign_auth.h b/quiche/blind_sign_auth/blind_sign_auth.h new file mode 100644 index 000000000000..5a9384bb81e2 --- /dev/null +++ b/quiche/blind_sign_auth/blind_sign_auth.h @@ -0,0 +1,64 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_H_ +#define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_H_ + +#include +#include +#include +#include + +#include "quiche/blind_sign_auth/proto/public_metadata.pb.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/client/anonymous_tokens_rsa_bssa_client.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "quiche/blind_sign_auth/blind_sign_auth_interface.h" +#include "quiche/blind_sign_auth/blind_sign_http_interface.h" +#include "quiche/blind_sign_auth/blind_sign_http_response.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// BlindSignAuth provides signed, unblinded tokens to callers. +class QUICHE_EXPORT BlindSignAuth : public BlindSignAuthInterface { + public: + explicit BlindSignAuth(BlindSignHttpInterface* http_fetcher) + : http_fetcher_(http_fetcher) {} + + // Returns signed unblinded tokens in a callback. Tokens are single-use. + // GetTokens starts asynchronous HTTP POST requests to a signer hostname + // specified by the caller, with path and query params given in the request. + // The GetTokens callback will run on the same thread as the + // BlindSignHttpInterface callbacks. + // Callers can make multiple concurrent requests to GetTokens. + void GetTokens( + absl::string_view oauth_token, int num_tokens, + std::function>)> + callback) override; + + private: + void GetInitialDataCallback( + absl::StatusOr response, + absl::string_view oauth_token, int num_tokens, + std::function>)> callback); + void AuthAndSignCallback( + absl::StatusOr response, + privacy::ppn::PublicMetadataInfo public_metadata_info, + private_membership::anonymous_tokens::AnonymousTokensSignRequest + at_sign_request, + private_membership::anonymous_tokens::AnonymousTokensRsaBssaClient* + bssa_client, + std::function>)> callback); + absl::Status FingerprintPublicMetadata( + const privacy::ppn::PublicMetadata& metadata, uint64_t* fingerprint); + + BlindSignHttpInterface* http_fetcher_ = nullptr; +}; + +} // namespace quiche + +#endif // QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_H_ diff --git a/quiche/blind_sign_auth/blind_sign_auth_interface.h b/quiche/blind_sign_auth/blind_sign_auth_interface.h new file mode 100644 index 000000000000..f7e390546f96 --- /dev/null +++ b/quiche/blind_sign_auth/blind_sign_auth_interface.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_INTERFACE_H_ +#define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_INTERFACE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// BlindSignAuth provides signed, unblinded tokens to callers. +class QUICHE_EXPORT BlindSignAuthInterface { + public: + virtual ~BlindSignAuthInterface() = default; + + // Returns signed unblinded tokens in a callback. Tokens are single-use. + virtual void GetTokens( + absl::string_view oauth_token, int num_tokens, + std::function>)> + callback) = 0; +}; + +} // namespace quiche + +#endif // QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_AUTH_INTERFACE_H_ diff --git a/quiche/blind_sign_auth/blind_sign_auth_test.cc b/quiche/blind_sign_auth/blind_sign_auth_test.cc new file mode 100644 index 000000000000..bb9d8daffe1b --- /dev/null +++ b/quiche/blind_sign_auth/blind_sign_auth_test.cc @@ -0,0 +1,300 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/blind_sign_auth/blind_sign_auth.h" + +#include +#include +#include +#include + +#include "quiche/blind_sign_auth/proto/timestamp.pb.h" +#include "quiche/blind_sign_auth/proto/auth_and_sign.pb.h" +#include "quiche/blind_sign_auth/proto/get_initial_data.pb.h" +#include "quiche/blind_sign_auth/proto/key_services.pb.h" +#include "quiche/blind_sign_auth/proto/public_metadata.pb.h" +#include "quiche/blind_sign_auth/proto/spend_token_data.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/blind_sign_auth/anonymous_tokens/cpp/testing/utils.h" +#include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h" +#include "openssl/base.h" +#include "quiche/blind_sign_auth/blind_sign_http_response.h" +#include "quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h" +#include "quiche/common/platform/api/quiche_mutex.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche { +namespace test { +namespace { + +using ::testing::_; +using ::testing::Eq; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::InvokeArgument; +using ::testing::StartsWith; +using ::testing::Unused; + +class BlindSignAuthTest : public QuicheTest { + protected: + void SetUp() override { + // Create public key. + auto keypair = private_membership::anonymous_tokens::CreateTestKey(); + if (!keypair.ok()) { + return; + } + keypair_ = *std::move(keypair); + keypair_.second.set_key_version(1); + keypair_.second.set_use_case("TEST_USE_CASE"); + + // Create fake GetInitialDataRequest. + expected_get_initial_data_request_.set_use_attestation(false); + expected_get_initial_data_request_.set_service_type("chromeipblinding"); + expected_get_initial_data_request_.set_location_granularity( + privacy::ppn::GetInitialDataRequest_LocationGranularity_CITY_GEOS); + + // Create fake public key response. + privacy::ppn::GetInitialDataResponse fake_get_initial_data_response; + private_membership::anonymous_tokens::RSABlindSignaturePublicKey public_key; + ASSERT_TRUE( + public_key.ParseFromString(keypair_.second.SerializeAsString())); + *fake_get_initial_data_response.mutable_at_public_metadata_public_key() = + public_key; + + // Create public metadata info. + std::string public_metadata_str = R"pb( + public_metadata { + exit_location { country: "US" } + service_type: "chromeipblinding" + expiration { seconds: 3600 } + } + validation_version: 1 + )pb"; + privacy::ppn::PublicMetadata::Location location; + location.set_country("US"); + quiche::protobuf::Timestamp expiration; + expiration.set_seconds(3600); + privacy::ppn::PublicMetadata public_metadata; + *public_metadata.mutable_exit_location() = location; + public_metadata.set_service_type("chromeipblinding"); + *public_metadata.mutable_expiration() = expiration; + public_metadata_info_.set_validation_version(1); + *public_metadata_info_.mutable_public_metadata() = public_metadata; + *fake_get_initial_data_response.mutable_public_metadata_info() = + public_metadata_info_; + fake_get_initial_data_response_ = fake_get_initial_data_response; + + blind_sign_auth_ = std::make_unique(&mock_http_interface_); + } + + void TearDown() override { + blind_sign_auth_.reset(nullptr); + keypair_.first.reset(nullptr); + keypair_.second.Clear(); + } + + public: + void CreateSignResponse(const std::string& body) { + privacy::ppn::AuthAndSignRequest request; + ASSERT_TRUE(request.ParseFromString(body)); + + // Validate AuthAndSignRequest. + EXPECT_EQ(request.oauth_token(), oauth_token_); + EXPECT_EQ(request.service_type(), "chromeipblinding"); + // Phosphor does not need the public key hash if the KeyType is + // privacy::ppn::AT_PUBLIC_METADATA_KEY_TYPE. + EXPECT_EQ(request.key_type(), privacy::ppn::AT_PUBLIC_METADATA_KEY_TYPE); + EXPECT_EQ(request.public_key_hash(), ""); + EXPECT_EQ(request.public_metadata_info().SerializeAsString(), + public_metadata_info_.SerializeAsString()); + EXPECT_EQ(request.key_version(), keypair_.second.key_version()); + + // Construct AuthAndSignResponse. + privacy::ppn::AuthAndSignResponse response; + for (const auto& request_token : request.blinded_token()) { + std::string decoded_blinded_token; + ASSERT_TRUE(absl::Base64Unescape(request_token, &decoded_blinded_token)); + absl::StatusOr serialized_token = + private_membership::anonymous_tokens::TestSign(decoded_blinded_token, + keypair_.first.get()); + QUICHE_EXPECT_OK(serialized_token); + response.add_blinded_token_signature( + absl::Base64Escape(*serialized_token)); + } + sign_response_ = response; + } + + void ValidateGetTokensOutput(const absl::Span& tokens) { + for (const auto& token : tokens) { + privacy::ppn::SpendTokenData spend_token_data; + ASSERT_TRUE(spend_token_data.ParseFromString(token)); + // Validate token structure. + EXPECT_EQ(spend_token_data.public_metadata().SerializeAsString(), + public_metadata_info_.public_metadata().SerializeAsString()); + EXPECT_THAT(spend_token_data.unblinded_token(), StartsWith("blind:")); + EXPECT_GE(spend_token_data.unblinded_token_signature().size(), + spend_token_data.unblinded_token().size()); + EXPECT_EQ(spend_token_data.signing_key_version(), + keypair_.second.key_version()); + EXPECT_NE(spend_token_data.use_case(), + private_membership::anonymous_tokens::AnonymousTokensUseCase:: + ANONYMOUS_TOKENS_USE_CASE_UNDEFINED); + EXPECT_NE(spend_token_data.message_mask(), ""); + } + } + + MockBlindSignHttpInterface mock_http_interface_; + std::unique_ptr blind_sign_auth_; + std::pair, + private_membership::anonymous_tokens::RSABlindSignaturePublicKey> + keypair_; + privacy::ppn::PublicMetadataInfo public_metadata_info_; + privacy::ppn::AuthAndSignResponse sign_response_; + privacy::ppn::GetInitialDataResponse fake_get_initial_data_response_; + std::string oauth_token_ = "oauth_token"; + privacy::ppn::GetInitialDataRequest expected_get_initial_data_request_; +}; + +TEST_F(BlindSignAuthTest, TestGetTokensSuccessful) { + BlindSignHttpResponse fake_public_key_response( + 200, fake_get_initial_data_response_.SerializeAsString()); + + { + InSequence seq; + + EXPECT_CALL( + mock_http_interface_, + DoRequest(Eq("/v1/getInitialData"), Eq(oauth_token_), + Eq(expected_get_initial_data_request_.SerializeAsString()), + _)) + .Times(1) + .WillOnce(InvokeArgument<3>(fake_public_key_response)); + + EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/authWithHeaderCreds"), + Eq(oauth_token_), _, _)) + .Times(1) + .WillOnce(Invoke( + [this](Unused, Unused, const std::string& body, + std::function)> + callback) { + CreateSignResponse(body); + BlindSignHttpResponse http_response( + 200, sign_response_.SerializeAsString()); + callback(http_response); + })); + } + + int num_tokens = 1; + QuicheNotification done; + std::function>)> callback = + [this, &done, + num_tokens](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(tokens->size(), num_tokens); + ValidateGetTokensOutput(*tokens); + done.Notify(); + }; + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + done.WaitForNotification(); +} + +TEST_F(BlindSignAuthTest, TestGetTokensFailedNetworkError) { + EXPECT_CALL(mock_http_interface_, + DoRequest(Eq("/v1/getInitialData"), Eq(oauth_token_), _, _)) + .Times(1) + .WillOnce( + InvokeArgument<3>(absl::InternalError("Failed to create socket"))); + + EXPECT_CALL(mock_http_interface_, + DoRequest(Eq("/v1/authWithHeaderCreds"), _, _, _)) + .Times(0); + + int num_tokens = 1; + QuicheNotification done; + std::function>)> callback = + [&done](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); + done.Notify(); + }; + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + done.WaitForNotification(); +} + +TEST_F(BlindSignAuthTest, TestGetTokensFailedBadGetInitialDataResponse) { + *fake_get_initial_data_response_.mutable_at_public_metadata_public_key() + ->mutable_use_case() = "SPAM"; + + BlindSignHttpResponse fake_public_key_response( + 200, fake_get_initial_data_response_.SerializeAsString()); + + EXPECT_CALL( + mock_http_interface_, + DoRequest(Eq("/v1/getInitialData"), Eq(oauth_token_), + Eq(expected_get_initial_data_request_.SerializeAsString()), _)) + .Times(1) + .WillOnce(InvokeArgument<3>(fake_public_key_response)); + + EXPECT_CALL(mock_http_interface_, + DoRequest(Eq("/v1/authWithHeaderCreds"), _, _, _)) + .Times(0); + + int num_tokens = 1; + QuicheNotification done; + std::function>)> callback = + [&done](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); + done.Notify(); + }; + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + done.WaitForNotification(); +} + +TEST_F(BlindSignAuthTest, TestGetTokensFailedBadAuthAndSignResponse) { + BlindSignHttpResponse fake_public_key_response( + 200, fake_get_initial_data_response_.SerializeAsString()); + { + InSequence seq; + + EXPECT_CALL( + mock_http_interface_, + DoRequest(Eq("/v1/getInitialData"), Eq(oauth_token_), + Eq(expected_get_initial_data_request_.SerializeAsString()), + _)) + .Times(1) + .WillOnce(InvokeArgument<3>(fake_public_key_response)); + + EXPECT_CALL(mock_http_interface_, DoRequest(Eq("/v1/authWithHeaderCreds"), + Eq(oauth_token_), _, _)) + .Times(1) + .WillOnce(Invoke( + [this](Unused, Unused, const std::string& body, + std::function)> + callback) { + CreateSignResponse(body); + // Add an invalid signature that can't be Base64 decoded. + sign_response_.add_blinded_token_signature("invalid_signature%"); + BlindSignHttpResponse http_response( + 200, sign_response_.SerializeAsString()); + callback(http_response); + })); + } + + int num_tokens = 1; + QuicheNotification done; + std::function>)> callback = + [&done](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); + done.Notify(); + }; + blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + done.WaitForNotification(); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/blind_sign_auth/blind_sign_http_interface.h b/quiche/blind_sign_auth/blind_sign_http_interface.h new file mode 100644 index 000000000000..d8111b4b36ab --- /dev/null +++ b/quiche/blind_sign_auth/blind_sign_http_interface.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_INTERFACE_H_ +#define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_INTERFACE_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "quiche/blind_sign_auth/blind_sign_http_response.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Interface for async HTTP POST requests in BlindSignAuth. +// Implementers must send a request to a signer hostname, using the request's +// arguments, and call the provided callback when a request is complete. +class QUICHE_EXPORT BlindSignHttpInterface { + public: + virtual ~BlindSignHttpInterface() = default; + // Non-HTTP errors (like failing to create a socket) must return an + // absl::Status. + // HTTP errors must set status_code and body in BlindSignHttpResponse. + // DoRequest must be a HTTP POST request. + // Requests do not need cookies and must follow redirects. + // The implementer must set Content-Type and Accept headers to + // "application/x-protobuf". + // DoRequest is async. When the request completes, the implementer must call + // the provided callback. + virtual void DoRequest( + const std::string& path_and_query, + const std::string& authorization_header, const std::string& body, + std::function)> callback) = 0; +}; + +} // namespace quiche + +#endif // QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_INTERFACE_H_ diff --git a/quiche/blind_sign_auth/blind_sign_http_response.h b/quiche/blind_sign_auth/blind_sign_http_response.h new file mode 100644 index 000000000000..89d90728f3b3 --- /dev/null +++ b/quiche/blind_sign_auth/blind_sign_http_response.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_RESPONSE_H_ +#define QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_RESPONSE_H_ + +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Contains a response to a HTTP POST request issued by BlindSignAuth. +class QUICHE_EXPORT BlindSignHttpResponse { + public: + BlindSignHttpResponse(int status_code, std::string body) + : status_code_(status_code), body_(std::move(body)) {} + + int status_code() const { return status_code_; } + const std::string& body() const { return body_; } + + private: + int status_code_; + std::string body_; +}; + +} // namespace quiche + +#endif // QUICHE_BLIND_SIGN_AUTH_BLIND_SIGN_HTTP_RESPONSE_H_ diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.cc b/quiche/blind_sign_auth/cached_blind_sign_auth.cc new file mode 100644 index 000000000000..34e5e73e0beb --- /dev/null +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/blind_sign_auth/cached_blind_sign_auth.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mutex.h" + +namespace quiche { + +void CachedBlindSignAuth::GetTokens( + absl::string_view oauth_token, int num_tokens, + std::function>)> + callback) { + if (num_tokens > max_tokens_per_request_) { + callback(absl::InvalidArgumentError( + absl::StrFormat("Number of tokens requested exceeds maximum: %d", + kBlindSignAuthRequestMaxTokens))); + return; + } + if (num_tokens < 0) { + callback(absl::InvalidArgumentError(absl::StrFormat( + "Negative number of tokens requested: %d", num_tokens))); + return; + } + + std::vector output_tokens; + { + QuicheWriterMutexLock lock(&mutex_); + + // Try to fill the request from cache. + if (static_cast(num_tokens) <= cached_tokens_.size()) { + output_tokens = CreateOutputTokens(num_tokens); + } + } + if (!output_tokens.empty() || num_tokens == 0) { + callback(output_tokens); + return; + } + + // Make a GetTokensRequest if the cache can't handle the request size. + std::function>)> + caching_callback = + [this, num_tokens, + callback](absl::StatusOr> tokens) { + HandleGetTokensResponse(tokens, num_tokens, callback); + }; + blind_sign_auth_->GetTokens(oauth_token, kBlindSignAuthRequestMaxTokens, + caching_callback); +} + +void CachedBlindSignAuth::HandleGetTokensResponse( + absl::StatusOr> tokens, int num_tokens, + std::function>)> + callback) { + if (!tokens.ok()) { + QUICHE_LOG(WARNING) << "BlindSignAuth::GetTokens failed: " + << tokens.status(); + callback(tokens); + return; + } + if (tokens->size() < static_cast(num_tokens) || + tokens->size() > kBlindSignAuthRequestMaxTokens) { + QUICHE_LOG(WARNING) << "Expected " << num_tokens << " tokens, got " + << tokens->size(); + } + + std::vector output_tokens; + size_t cache_size; + { + QuicheWriterMutexLock lock(&mutex_); + + // Add returned tokens to cache. + for (const std::string& token : *tokens) { + cached_tokens_.push_back(token); + } + + // Return tokens or a ResourceExhaustedError. + cache_size = cached_tokens_.size(); + if (cache_size >= static_cast(num_tokens)) { + output_tokens = CreateOutputTokens(num_tokens); + } + } + + if (!output_tokens.empty()) { + callback(output_tokens); + return; + } + callback(absl::ResourceExhaustedError(absl::StrFormat( + "Requested %d tokens, cache only has %d after GetTokensRequest", + num_tokens, cache_size))); +} + +std::vector CachedBlindSignAuth::CreateOutputTokens( + int num_tokens) { + std::vector output_tokens; + if (cached_tokens_.size() < static_cast(num_tokens)) { + QUICHE_LOG(FATAL) << "Check failed, not enough tokens in cache: " + << cached_tokens_.size() << " < " << num_tokens; + } + for (int i = 0; i < num_tokens; i++) { + output_tokens.push_back(std::move(cached_tokens_.front())); + cached_tokens_.pop_front(); + } + return output_tokens; +} + +} // namespace quiche diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth.h b/quiche/blind_sign_auth/cached_blind_sign_auth.h new file mode 100644 index 000000000000..ee405a11606d --- /dev/null +++ b/quiche/blind_sign_auth/cached_blind_sign_auth.h @@ -0,0 +1,65 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_CACHED_BLIND_SIGN_AUTH_H_ +#define QUICHE_BLIND_SIGN_AUTH_CACHED_BLIND_SIGN_AUTH_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/blind_sign_auth/blind_sign_auth_interface.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_mutex.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quiche { + +inline constexpr int kBlindSignAuthRequestMaxTokens = 1024; + +// CachedBlindSignAuth caches signed tokens generated by BlindSignAuth. +// This class does not guarantee that tokens returned are fresh. +// Tokens may be stale if the backend has rotated its signing key since tokens +// were generated. +// This class is thread-safe. +class QUICHE_EXPORT CachedBlindSignAuth : public BlindSignAuthInterface { + public: + CachedBlindSignAuth( + BlindSignAuthInterface* blind_sign_auth, + int max_tokens_per_request = kBlindSignAuthRequestMaxTokens) + : blind_sign_auth_(blind_sign_auth), + max_tokens_per_request_(max_tokens_per_request) {} + + // Returns signed unblinded tokens in a callback. Tokens are single-use. + // + // The GetTokens callback may be called synchronously on the calling thread, + // or asynchronously on BlindSignAuth's BlindSignHttpInterface thread. + // The GetTokens callback must not acquire any locks that the calling thread + // owns, otherwise the callback will deadlock. + void GetTokens( + absl::string_view oauth_token, int num_tokens, + std::function>)> + callback) override; + + private: + void HandleGetTokensResponse( + absl::StatusOr> tokens, int num_tokens, + std::function>)> + callback); + std::vector CreateOutputTokens(int num_tokens) + QUICHE_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + BlindSignAuthInterface* blind_sign_auth_; + int max_tokens_per_request_; + QuicheMutex mutex_; + QuicheCircularDeque cached_tokens_ QUICHE_GUARDED_BY(mutex_); +}; + +} // namespace quiche + +#endif // QUICHE_BLIND_SIGN_AUTH_CACHED_BLIND_SIGN_AUTH_H_ diff --git a/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc new file mode 100644 index 000000000000..dfad523fe3fb --- /dev/null +++ b/quiche/blind_sign_auth/cached_blind_sign_auth_test.cc @@ -0,0 +1,337 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/blind_sign_auth/cached_blind_sign_auth.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h" +#include "quiche/common/platform/api/quiche_mutex.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche { +namespace test { +namespace { + +using ::testing::_; +using ::testing::Invoke; +using ::testing::InvokeArgument; +using ::testing::Unused; + +class CachedBlindSignAuthTest : public QuicheTest { + protected: + void SetUp() override { + cached_blind_sign_auth_ = + std::make_unique(&mock_blind_sign_auth_interface_); + } + + void TearDown() override { + fake_tokens_.clear(); + cached_blind_sign_auth_.reset(); + } + + public: + std::vector MakeFakeTokens(int num_tokens) { + std::vector fake_tokens; + for (int i = 0; i < kBlindSignAuthRequestMaxTokens; i++) { + fake_tokens.push_back(absl::StrCat("token:", i)); + } + return fake_tokens; + } + MockBlindSignAuthInterface mock_blind_sign_auth_interface_; + std::unique_ptr cached_blind_sign_auth_; + std::string oauth_token_ = "oauth_token"; + std::vector fake_tokens_; +}; + +TEST_F(CachedBlindSignAuthTest, TestGetTokensOneCallSuccessful) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(1) + .WillOnce(Invoke( + [this](Unused, int num_tokens, + std::function>)> + callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + callback(absl::MakeSpan(fake_tokens_)); + })); + + int num_tokens = 5; + QuicheNotification done; + std::function>)> callback = + [num_tokens, + &done](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); + } + done.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); + done.WaitForNotification(); +} + +TEST_F(CachedBlindSignAuthTest, TestGetTokensMultipleRemoteCallsSuccessful) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(2) + .WillRepeatedly(Invoke( + [this](Unused, int num_tokens, + std::function>)> + callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + callback(absl::MakeSpan(fake_tokens_)); + })); + + int num_tokens = kBlindSignAuthRequestMaxTokens - 1; + QuicheNotification first; + std::function>)> + first_callback = + [num_tokens, + &first](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); + } + first.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + first.WaitForNotification(); + + QuicheNotification second; + std::function>)> + second_callback = + [num_tokens, + &second](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + EXPECT_EQ( + tokens->at(0), + absl::StrCat("token:", kBlindSignAuthRequestMaxTokens - 1)); + for (int i = 1; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i - 1)); + } + second.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + second.WaitForNotification(); +} + +TEST_F(CachedBlindSignAuthTest, TestGetTokensSecondRequestFilledFromCache) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(1) + .WillOnce(Invoke( + [this](Unused, int num_tokens, + std::function>)> + callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + callback(absl::MakeSpan(fake_tokens_)); + })); + + int num_tokens = kBlindSignAuthRequestMaxTokens / 2; + QuicheNotification first; + std::function>)> + first_callback = + [num_tokens, + &first](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); + } + first.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + first.WaitForNotification(); + + QuicheNotification second; + std::function>)> + second_callback = + [num_tokens, + &second](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i + num_tokens)); + } + second.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + second.WaitForNotification(); +} + +TEST_F(CachedBlindSignAuthTest, TestGetTokensThirdRequestRefillsCache) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(2) + .WillRepeatedly(Invoke( + [this](Unused, int num_tokens, + std::function>)> + callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + callback(absl::MakeSpan(fake_tokens_)); + })); + + int num_tokens = kBlindSignAuthRequestMaxTokens / 2; + QuicheNotification first; + std::function>)> + first_callback = + [num_tokens, + &first](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); + } + first.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + first.WaitForNotification(); + + QuicheNotification second; + std::function>)> + second_callback = + [num_tokens, + &second](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(num_tokens, tokens->size()); + for (int i = 0; i < num_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i + num_tokens)); + } + second.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + second.WaitForNotification(); + + QuicheNotification third; + int third_request_tokens = 10; + std::function>)> + third_callback = + [third_request_tokens, + &third](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(third_request_tokens, tokens->size()); + for (int i = 0; i < third_request_tokens; i++) { + EXPECT_EQ(tokens->at(i), absl::StrCat("token:", i)); + } + third.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, third_request_tokens, + third_callback); + third.WaitForNotification(); +} + +TEST_F(CachedBlindSignAuthTest, TestGetTokensRequestTooLarge) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(0); + + int num_tokens = kBlindSignAuthRequestMaxTokens + 1; + std::function>)> callback = + [](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + tokens.status().message(), + absl::StrFormat("Number of tokens requested exceeds maximum: %d", + kBlindSignAuthRequestMaxTokens)); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); +} + +TEST_F(CachedBlindSignAuthTest, TestGetTokensRequestNegative) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(0); + + int num_tokens = -1; + std::function>)> callback = + [num_tokens](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(tokens.status().message(), + absl::StrFormat("Negative number of tokens requested: %d", + num_tokens)); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); +} + +TEST_F(CachedBlindSignAuthTest, TestHandleGetTokensResponseErrorHandling) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(2) + .WillOnce(InvokeArgument<2>(absl::InternalError("AuthAndSign failed"))) + .WillOnce(Invoke( + [this](Unused, int num_tokens, + std::function>)> + callback) { + fake_tokens_ = MakeFakeTokens(num_tokens); + fake_tokens_.pop_back(); + callback(absl::MakeSpan(fake_tokens_)); + })); + + int num_tokens = kBlindSignAuthRequestMaxTokens; + QuicheNotification first; + std::function>)> + first_callback = + [&first](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), absl::StatusCode::kInternal); + EXPECT_THAT(tokens.status().message(), "AuthAndSign failed"); + first.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, first_callback); + first.WaitForNotification(); + + QuicheNotification second; + std::function>)> + second_callback = + [&second](absl::StatusOr> tokens) { + EXPECT_THAT(tokens.status().code(), + absl::StatusCode::kResourceExhausted); + second.Notify(); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, second_callback); + second.WaitForNotification(); +} + +TEST_F(CachedBlindSignAuthTest, TestGetTokensZeroTokensRequested) { + EXPECT_CALL(mock_blind_sign_auth_interface_, + GetTokens(oauth_token_, kBlindSignAuthRequestMaxTokens, _)) + .Times(0); + + int num_tokens = 0; + std::function>)> callback = + [](absl::StatusOr> tokens) { + QUICHE_EXPECT_OK(tokens); + EXPECT_EQ(tokens->size(), 0); + }; + + cached_blind_sign_auth_->GetTokens(oauth_token_, num_tokens, callback); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/blind_sign_auth/proto/any.proto b/quiche/blind_sign_auth/proto/any.proto new file mode 100644 index 000000000000..c4fa8353ec4b --- /dev/null +++ b/quiche/blind_sign_auth/proto/any.proto @@ -0,0 +1,26 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package quiche.protobuf; + +// Cloned from +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/any.proto. +message Any { + string type_url = 1; + + // Must be a valid serialized protocol buffer of the above specified type. + bytes value = 2; +} diff --git a/quiche/blind_sign_auth/proto/attestation.proto b/quiche/blind_sign_auth/proto/attestation.proto new file mode 100644 index 000000000000..8e658e9595cd --- /dev/null +++ b/quiche/blind_sign_auth/proto/attestation.proto @@ -0,0 +1,114 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package privacy.ppn; + +import "quiche/blind_sign_auth/proto/any.proto"; + +option java_multiple_files = true; +option java_outer_classname = "AttestationProto"; +option java_package = "com.google.android.libraries.privacy.ppn.proto"; + +message NonceRequest {} + +message NonceResponse { + // A nonce with the following format: + // ECDSA( + // SHA256( + // .)). + bytes nonce = 1 ; + + // Nonce signature. + bytes sig = 2; + + // Algorithm used to sign the nonce. Should be "es256". + bytes alg = 3; +} + +message ValidateDeviceRequest { + // Attestation data that is returned by the client. + oneof attestation_data { + AndroidAttestationData android_attestation_data = 1 [deprecated = true]; + IosAttestationData ios_attestation_data = 2 [deprecated = true]; + } + AttestationData attestation = 3; + + string package_name = 4; + + // If attestation is AndroidAttestationData device models should be listed in: + // https://storage.googleapis.com/play_public/supported_devices.html + repeated string allowed_models = 5; +} + +message ValidateDeviceResponse { + // True iff all checks passed + // (integrity token, nonce, hardware properties are legitimate). + // Hardware properties check will be performed by the calling service + // as attestation only checks to see if the device's hardware properties + // are genuine. + bool device_verified = 1; + + // Detailed information on what specifically passed and what did not. + VerdictBreakdown breakdown = 2; + + // If verified, contains the device model. + string verified_device_type = 3; +} + +message VerdictBreakdown { + enum Verdict { + VERDICT_UNKNOWN = 0; + VERDICT_PASS = 1; + VERDICT_FAIL = 2; + } + + // Integrity verdict as determined by either Play Server or AppAttest. + Verdict integrity_verdict = 1; + + // Whether nonce check passed. + Verdict nonce_verdict = 2; + + // Whether or not the device properties sent by the client are + // legitimate. + Verdict device_properties_verdict = 3; +} + +message PrepareAttestationData { + bytes attestation_nonce = 2 [ + + json_name = "attestation_nonce" + ]; +} + +message AndroidAttestationData { + // Play IntegrityToken returned by Play Integrity API is detailed in + // https://developer.android.com/google/play/integrity/verdict. + string attestation_token = 1 ; + + // X509 Certificate chain generated by Android Keystore used for + // Hardware-Backed Key Attestation. + repeated bytes hardware_backed_certs = 2; +} + +message IosAttestationData { + // AppAttest attestation token. + // Encoded in CBOR format. + bytes attestation_token = 1 ; +} + +message AttestationData { + quiche.protobuf.Any attestation_data = 1; +} diff --git a/quiche/blind_sign_auth/proto/auth_and_sign.proto b/quiche/blind_sign_auth/proto/auth_and_sign.proto new file mode 100644 index 000000000000..38f82d19918c --- /dev/null +++ b/quiche/blind_sign_auth/proto/auth_and_sign.proto @@ -0,0 +1,87 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package privacy.ppn; + +import "quiche/blind_sign_auth/proto/attestation.proto"; +import "quiche/blind_sign_auth/proto/key_services.proto"; +import "quiche/blind_sign_auth/proto/public_metadata.proto"; + +// Client is requesting to auth using the provided auth token. +// Next ID: 9 +message AuthAndSignRequest { + reserved 3; + + // A 'bearer' oauth token to be validated. + // https://datatracker.ietf.org/doc/html/rfc6750#section-6.1.1 + string oauth_token = 1 ; + + // A string uniquely identifying the strategy this client should be + // authenticated with. + string service_type = 2 ; + + // A set of blinded tokens to be signed by zinc. b64 encoded. + repeated string blinded_token = 4 + ; + + // A sha256 of the public key PEM used in generated `blinded_token`. This + // Ensures the signer signs with the matching key. Only required if key_type + // is ZINC_KEY_TYPE. + string public_key_hash = 5 ; + + oneof attestation_data { + AndroidAttestationData android_attestation_data = 6 [deprecated = true]; + IosAttestationData ios_attestation_data = 7 [deprecated = true]; + } + privacy.ppn.AttestationData attestation = 8; + + privacy.ppn.KeyType key_type = 10 ; + + privacy.ppn.PublicMetadataInfo public_metadata_info = 11 + ; + + // Indicates which key to use for signing. Only set if key type is + // PUBLIC_METADATA. + int64 key_version = 12 ; +} + +message AuthAndSignResponse { + reserved 1, 2, 3; + + // A set of signatures corresponding by index to `blinded_token` in the + // request. b64 encoded. + repeated string blinded_token_signature = 4 [ + + json_name = "blinded_token_signature" + ]; + + // The marconi server hostname bridge-proxy used to set up tunnel. + string copper_controller_hostname = 5 [ + + json_name = "copper_controller_hostname" + ]; + + // The base64 encoding of override_region token and signature for white listed + // users in the format of "${Region}.${timestamp}.${signature}". + string region_token_and_signature = 6 [ + + json_name = "region_token_and_signature" + ]; + + // The APN type bridge-proxy use to deside which APN to use for connecting. + string apn_type = 7 + [ json_name = "apn_type"]; +} diff --git a/quiche/blind_sign_auth/proto/get_initial_data.proto b/quiche/blind_sign_auth/proto/get_initial_data.proto new file mode 100644 index 000000000000..bd6cc34b83d9 --- /dev/null +++ b/quiche/blind_sign_auth/proto/get_initial_data.proto @@ -0,0 +1,61 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package privacy.ppn; + +import "quiche/blind_sign_auth/proto/attestation.proto"; +import "quiche/blind_sign_auth/proto/public_metadata.proto"; +import "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto"; + +option java_multiple_files = true; + +// Request data needed to prepare for AuthAndSign. +message GetInitialDataRequest { + // Whether the client wants to use attestation as part of authentication. + bool use_attestation = 1 ; + + // A string uniquely identifying the strategy this client should be + // authenticated with. + string service_type = 2 ; + + enum LocationGranularity { + UNKNOWN = 0; + COUNTRY = 1; + // Geographic area with population greater than 1 million. + CITY_GEOS = 2; + } + // The user selected granularity of exit IP location. + LocationGranularity location_granularity = 3 + ; + + // Indicates what validation rules the client uses for public metadata. + int64 validation_version = 4 ; +} + +// Contains data needed to perform blind signing and prepare for calling +// AuthAndSign. +message GetInitialDataResponse { + private_membership.anonymous_tokens.RSABlindSignaturePublicKey + at_public_metadata_public_key = 1; + + // Metadata to associate with the token. Version will match the validation + // version in the request. + privacy.ppn.PublicMetadataInfo public_metadata_info = 2; + + // Data needed to set up attestation, included if use_attestation is true or + // if the service_type input requires it. + privacy.ppn.PrepareAttestationData attestation = 3; +} diff --git a/quiche/blind_sign_auth/proto/key_services.proto b/quiche/blind_sign_auth/proto/key_services.proto new file mode 100644 index 000000000000..343fea85dae1 --- /dev/null +++ b/quiche/blind_sign_auth/proto/key_services.proto @@ -0,0 +1,27 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package privacy.ppn; + +option java_multiple_files = true; + +// Indicates client's desired or capable key support. +enum KeyType { + UNKNOWN_KEY_TYPE = 0; + ZINC_KEY_TYPE = 1; + AT_PUBLIC_METADATA_KEY_TYPE = 2; + AT_PUBLIC_METADATA_VERIFIED_KEY_TYPE = 3; +} diff --git a/quiche/blind_sign_auth/proto/public_metadata.proto b/quiche/blind_sign_auth/proto/public_metadata.proto new file mode 100644 index 000000000000..5154200d081e --- /dev/null +++ b/quiche/blind_sign_auth/proto/public_metadata.proto @@ -0,0 +1,54 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package privacy.ppn; + +import "quiche/blind_sign_auth/proto/timestamp.proto"; + +option java_multiple_files = true; + +// Contains fields which will be cryptographically linked to a blinded token and +// visible to client, signer, and verifier. Clients should validate/set fields +// contained within such that the values are reasonable for the security and +// privacy constraints of the application. +message PublicMetadata { + // Contains desired exit IP address's declared location. + message Location { + // TODO(b/268354975): fix copybara regex to strip this line automatically + + // All caps ISO 3166-1 alpha-2. + string country = 1; + + // City region geo id if requested by the client. + string city_geo_id = 2; + } + Location exit_location = 1; + + // Indicates which service this token is associated with. + string service_type = 2; + + // When the token and metadata expire. + quiche.protobuf.Timestamp expiration = 3; +} + +// Contains PublicMetadata and associated information. Only the public_metadata +// is cryptographically associated with the token. +message PublicMetadataInfo { + PublicMetadata public_metadata = 1; + + // Earliest validation version that this public metadata conforms to. + int32 validation_version = 2; +} diff --git a/quiche/blind_sign_auth/proto/spend_token_data.proto b/quiche/blind_sign_auth/proto/spend_token_data.proto new file mode 100644 index 000000000000..68c2e6ceb831 --- /dev/null +++ b/quiche/blind_sign_auth/proto/spend_token_data.proto @@ -0,0 +1,38 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package privacy.ppn; + +import "quiche/blind_sign_auth/proto/public_metadata.proto"; +import "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.proto"; + +message SpendTokenData { + // Public metadata associated with the token being spent. + // See go/ppn-token-spend and go/ppn-phosphor-at-service for details. + PublicMetadata public_metadata = 1; + // The unblinded token to be spent which was blind-signed by Phosphor. + bytes unblinded_token = 2; + // The signature for the token to be spent, obtained from Phosphor and + // unblinded. + bytes unblinded_token_signature = 3; + // The version number of the signing key that was used during blind-signing. + int64 signing_key_version = 4; + // A use case identifying the caller. Should be a fixed, hardcoded value to + // prevent cross-spending tokens. + private_membership.anonymous_tokens.AnonymousTokensUseCase use_case = 5; + // Nonce used to mask plaintext message before cryptographic verification. + bytes message_mask = 6; +} diff --git a/quiche/blind_sign_auth/proto/timestamp.proto b/quiche/blind_sign_auth/proto/timestamp.proto new file mode 100644 index 000000000000..1d99392b742d --- /dev/null +++ b/quiche/blind_sign_auth/proto/timestamp.proto @@ -0,0 +1,32 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS-IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package quiche.protobuf; + +// Copied from +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto. +message Timestamp { + // Represents seconds of UTC time since Unix epoch + // 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to + // 9999-12-31T23:59:59Z inclusive. + int64 seconds = 1; + + // Non-negative fractions of a second at nanosecond resolution. Negative + // second values with fractions must still have non-negative nanos values + // that count forward in time. Must be from 0 to 999,999,999 + // inclusive. + int32 nanos = 2; +} diff --git a/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h new file mode 100644 index 000000000000..dcb487610a02 --- /dev/null +++ b/quiche/blind_sign_auth/test_tools/mock_blind_sign_auth_interface.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_AUTH_INTERFACE_H_ +#define QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_AUTH_INTERFACE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/blind_sign_auth/blind_sign_auth_interface.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche::test { + +class QUICHE_NO_EXPORT MockBlindSignAuthInterface + : public BlindSignAuthInterface { + public: + MOCK_METHOD( + void, GetTokens, + (absl::string_view oauth_token, int num_tokens, + std::function>)> + callback), + (override)); +}; + +} // namespace quiche::test + +#endif // QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_AUTH_INTERFACE_H_ diff --git a/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h b/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h new file mode 100644 index 000000000000..15e970bcf417 --- /dev/null +++ b/quiche/blind_sign_auth/test_tools/mock_blind_sign_http_interface.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_HTTP_INTERFACE_H_ +#define QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_HTTP_INTERFACE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "quiche/blind_sign_auth/blind_sign_http_interface.h" +#include "quiche/blind_sign_auth/blind_sign_http_response.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche::test { + +class QUICHE_NO_EXPORT MockBlindSignHttpInterface + : public BlindSignHttpInterface { + public: + MOCK_METHOD( + void, DoRequest, + (const std::string& path_and_query, + const std::string& authorization_header, const std::string& body, + std::function)> callback), + (override)); +}; + +} // namespace quiche::test + +#endif // QUICHE_BLIND_SIGN_AUTH_TEST_TOOLS_MOCK_BLIND_SIGN_HTTP_INTERFACE_H_ diff --git a/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus2048_example.binarypb b/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus2048_example.binarypb new file mode 100644 index 0000000000000000000000000000000000000000..c54070b3306f9b27f7288b1f23ee01f213a50cdc GIT binary patch literal 1178 zcmV;L1ZDdQfCA4NC)U(?ug)O>ISCG_96E(ti}B6qv3i7=BxSK7PqapJ8diVO-E>p%y1oXZve(H9Q`OzT zaBKUPzP*<{kDSDUK>sBAYG^Ivn$sm{iBTt_>2e_InG;!JDKb@rd0@sW^R^Kp()Tss z6QVC_5cY;FB?b@e1+fP)mbn)&o0x*^XNXOq_t{cO`cR>&^AeBK2J&yr46}=e*ek)- zcu>uFj~NmJ0RRCSfC5tdqZ}oc0ZDS!Kbek$TPb7+v;r5!cVMN%i{U0}oFgf9>YAx$ z=KIQFA;CTe^#-!$+2Sn@xc>FBG%4JoD`=gy%C1M*r!N4;CLpZ56n`M`YZ1!(`L4dW zb$y{H0A~GFWGoqdT4mpBPK5W6&=ihUNsE?YjVw9DTI1at0AUfNsN0z)V|&zMIS5fT ztq9V}WpH=cXT6WO_c)Mx44O8#4%t9ldQ5>T3lsOHS}V|s2Z8ebs9&5#J0x^O7bA+G z5Zb?;XPqxxj4aQ*f(pvTGa5k5NrKSsEs+pUINJw`>37}$x5_atoi%~0DvD2T{htAo;KaDL;y(pk1XJMQ=XUQGFR7%DX2n zu$qk@WT7PR(b*B&SNTVA_EQj*b%;8JtO1_lbq6r~j@wd-9RhAu8algllTmV|Tz%d6`C&BFU!ze1OZ~Lg?vu|D_9)MWm zDFVU~lVVsm#xrT|%OKoaQJtVYWugW{BJoNP5*6o+P8qJK;E^;v4E!OB6*9;eUd11N zxM6N~?uL1lcxz@tlp~(aR8)U8w?wt{mQxhTc%J=#mmo{_AGYf{fB}=H#kf8XvBN5> zm-JgqakGP`i>{6ciJd-s_96XAG+p$?q@BYq^d9-b#ewdjJg)ljQbR<;tIXi-g{gRt z1GBOEpHAtE4zt!6G;aeK(rYq|#F)zq1nGOqs+}ycy~zJSf44*cZ0~lt<_~Ie$c}QJ zF1JD!-+UFjChvpZypuwJ0X?w=Qx(9Wm0awYP3u(?zAXlH?zJ$l8mdxDB}%?VN5gJF zuYJFRJH)>ZpqX5=c?K-e?D45b)Db#QaFb8b$Tosbbl;^}px6^pi?U+moZy`526y;) sL0iFC2C8QOUy2>I7J5!d(q%Vb$BTCKPL;TiNGr<-QKC7!mdV#Y{b8v+k^lez literal 0 HcmV?d00001 diff --git a/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus2048_example_2.binarypb b/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus2048_example_2.binarypb new file mode 100644 index 0000000000000000000000000000000000000000..50faa9693365c7d7abdec293569f38aed7d5f61e GIT binary patch literal 1178 zcmV;L1ZDdQfC8?2nH9^N4BqsCVllC*keFF#Xek9CbrBR0 zzN2ZCQDC(if?#jKD~SL$aenj4L&j&;NmBzEq;`AB++T*0Q_)Z)?0#)iKiYihvsX-* znDXSOL3z~p$sAbE?lOE(it8vSiOsUMZPd$F|8hbA>jqHA3Nxi+un&!ML1P{qgg$_o z=^L{BW^ru}gf!0EkVC#iIeLd>o(Q7|Wd5|%goXrozBdbmDAmIk4~}3N4kl7Emj)PwWKe z(gizcbrli=0RRCSfC3iGIdxi;I(iHodo2hrTeF>7e+;*FcYsYWPV6;Jmh}MveBYNg z<#goDM@wwEK-a|b0L1?F8*UsV0P^);WMgTP|C~s z$f!YJY=(!*sRbf{0oW91fnc|LAv+)|P@i3hN~8!y=_RoCu0j9pFLpN+fU@po&c{zY zgkKN0NtsK(vUlBj7Lj5;1B+`FM=VYhlWUs!A;_?qi2%=S!IlmR(-S36IkmfNC^aE; z?ncnN7X`W6TPyfAXb6_WhULm;nx?lm>k*)PPKd#Md=}M^)kk+#YxaRZ%Pk9UQlIs5 zOyLX-_FJ9@fBmjIGJpY8QL=72HJ?D*ceO8K;P%{EqBzySHH(#HF-_>9!eKI{q)m5c+9 z#6`L2Rw9r3YCUj!Fsaj{idDXGy_M`Yz1Y43euOADqClI}0XNk;fB_xJ37<7;^xe9m z;!@7!iI%;U1XL@dWnD5km0==%X(_#U2sGEoJ~ZH%4y^z>_4EUe`!$%t!3hoLteCm> zW!kn@htAzoTH_{-;8wY%yDM`odkbr?C+RD?zjx&6`Hwy+03N+eZQ5oM|E`pD>KEZW z;T_fs`!lQ3siTDgrEfxj0TY9oaZOBLo`CL2R!+N{KpoLpg|VMe(rJg=?!#Q+DeFBX zIslQ5F%VA^B`WWfiD^wg=xu<}yD&rM_z|V|I6}K*sH(F&StBQuu59(~u6oIrbz3)X stw9>eJ3P^?gaEeg0c6#bek87A#u#&WTzTUuWmwfdL+Yj_fh#D7ECEY7>i_@% literal 0 HcmV?d00001 diff --git a/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus3072_example.binarypb b/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus3072_example.binarypb new file mode 100644 index 0000000000000000000000000000000000000000..3e4bae91980163663409d41eb7c9ba3d7233193f GIT binary patch literal 1754 zcmV<01||6lfCIc{lhM6cPF{nYx=m{fvq#d=&-}#45uurRMr4sR$gQ2WdwOy{x76%B`FPFvSL}OOuj;$^Z(JY5#=9j$Q7eq9&&ee| ze!415R~P$E2D&7a>ttc1Nqe#;w+?XV?2^fYIeLsp#B0Z)dq9Sr16UzT2X5#tC`h@a zumuy7bX??0q0tGy$o9d~zdXiksM*14S-+2s2qxzwI(+M=qvPWkGz5vu7?MWhB9<}J zD%+cHYOo)}7!IPt-T-!iuMRH!e&%rY&sSwRC!%dMAFX-Pd$-@w)<{{8AdqB}814>i zxPDr1lbo{<_+G_P>+RaY*t+j&r zrn;NILrQOLCXc{d*L4FSnfN2ic59kLP1{?Xr*{+X>qkL{d*wTJdc1xj+<=E`6;Gij zFGM{qDf?~I;RfJw#X~90kR@%yW3|S!uP2_T3qJy$A?T&((R5@v)%FPAP;bM%Q~w2p zq8IL$<6hTddC`u$(TukQ513-G+7lS~!cW|(xiJi}M6zP)A@Mfm*~-y`Xtt8u`kt7@ z;a$sux%JAO5X~+EoqBFC@!6&(Q0yKpyn;|^9;(k0Tf&k3hR-4L`^0UJO-N}-%K6j< zzBGlov=!}eNe8#D#e9S=^5@>WXf`{_vflTYjZu6|S8KBC#F@u9s(O$zn*=$7x@%RC zK(~&UnTVGvzr(agVUDTsb@RmU8_+~A23|` z<4$>V=9Uq?T_iceL{iR8vw|O>3%)Wkhs&(K;1uRbRs$-K?o=_&t=Q{{{w_ zV}of>K$wvr9uUPF?63EH=0bXg=KI)=$g zY-80ko?sq~s`>7{4qY8;tz7F5T% z)`0cufVz0-O(~zL6!JTZ9sDD|VQ=+K$)?}69)&~+9VvvQZ2-*nPnfmXD0Shjh*oy=O_3AA{S&HMpO@!C#zV_XA!O0kWs0>)xYm`Bz$WK0p&1L^Z}Tr@7v$-|Heey6YfSk!4BhkYSB_5U;_C z@;|@2jh6$+XgJAOnP2#&fVUao`LJi> zaqh*2X;*s!Q)3YRy{U|mM2&Trcn}t^VB=qruOXH!{qhutTsfzCsryo7s0_0MeW12S wDSP%6?)4>XQJ-PX;it<}#W{XZ7y8%>d%@xvxhi$ues1Mifv}|_8ZCDVY26QVq5uE@ literal 0 HcmV?d00001 diff --git a/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus4096_example.binarypb b/quiche/common/anonymous_tokens/testdata/strong_rsa_modulus4096_example.binarypb new file mode 100644 index 0000000000000000000000000000000000000000..f322a9bfe2e240251f7455cca8f4686ee5599ce0 GIT binary patch literal 2330 zcmV+#3FY<*fCR`?Clgjd^$WlHMU6Wd&U5-Hg}t;%0mq6k$~#uljbfxyjvgh z;_%xi_^|uFFX*Q5k(Pmbo?dEbBfIN!sA1gqS0BF-_+d9vgvyww8^j$S2mfQY8!9O9 zPWR8qVW+IEaLOS`1Gv7jegfZ^q7pn55c&*t=F!UUA*&DkK%wY?Xf)dDoVrzU2AHiO z5SDM8sXV_XMaK!9hr(%>6hN?$njj2;7c-3@CgslhnZuvM(4$5OWL&k#&IA-87cv9! zo!0wQNxeVOQ_FhfSWl7aP1e+lVxc5bJW*f`{6snWHHzY9wr$orT1Xe@A2+ypW;&pG zV5!6)lTQP=vR%U}z<@~57WxS!ETHI`w^mHMIyz1z#9;cii6=IbBZ%`e_wLsn zP|oH+S(nwS#-9y`{&5f_UGnn)4)aBm(QPsC!R#jVN|v-TIo!PB9OL2T@k?!q3X#7v zdhHSe0RRCSfCMW&^8;H7xV-fbE@e7>D-qz$hPQaji;YQmm!`04neZsid%o#MpR3al zpI>2Rv_US-_Xcwy_^6%Pc7PliE(2yPa^w)Z>Hq#he}|vCQKY&2@Xp_n+xu{**h>hu z{r~~Vrt~U<$OLo&d!Vk3Vts~w_<%DN5(Gcn^*wvcqcuv1U@joD^{k#;v-rA)L$z4) zVUC9JiS{nJd5i;f#{og{ch(Hhx&u}(D+7V6&g1u(c^mS1 z{_Ke@62ISV_xr+$ew&`Q#DDZ@*`7`3FK6ouquHRf&)ILnsI&F_VT^mX%B9w`@~j(w=wm_R6n@`H@J>m{J#MX-p9t7W35 zrRN)-9#?&oT4`MM48P9dB^tJtup-A3g%WpV$Z>FnRjZfC8QzaPMRPaMJNmgdUi*$X z_G0^G4c4T4eg>uwxFiSdLEUnmN(r`Biy~Lh4z?b;@5uTV?8A1`sO3*Z=d+?dp;?WW z#{Dxw5@)zN+VYNXE~wJ3U28&7v4BrnFU{Z0KijdaGFg04Q4SW=NeORaMeLs#11Olen2BkT>BGz_GNoVt^~nFNn%xnDVfx9Y-#Dt zSh_8$mi%aI|6m$p<@O}amfdVj*KS%9^9gG znht1V-4dRCQLI63(#3JqPvgpXic$-80F;NyG3FYajS!F{&2eh-g)@5nq@02_;00x> zFVE*+fuX<~iGEXMO}5KF0U0NJt<=WA^c-1tnhVsui4bBXchMF9y|clsO0GWDd_s#YTA6`ZUU(#qF?)t!y?fu2EG$`!DS)3k zfC7&`32;DBJ+od)wPYKGRyP!(rvPQ(R?t)!++>Fm1O`eSbYOa?{PG=fvZlyM68~vt z=Z%rvkvnLH%E#13vt?2fei!p({SLnT3;l#Ib>bB8)?P{d7^i@gc~r6F2UCAt+@Ajv zLBn{+$``|=JU90)=Weu_#pFO$+FbN*z9LekyJVD#px$ALiupL_7FMLSbmv56#Q(zk z4w1~8D;k=ou9_{vT~*r-i{#MH9>R@0CC~6n?X$ARNjPOg)z%+3!!TRJB5=h-$9cmSt#z AE&u=k literal 0 HcmV?d00001 diff --git a/quiche/common/btree_scheduler.h b/quiche/common/btree_scheduler.h new file mode 100644 index 000000000000..fd07d36a35da --- /dev/null +++ b/quiche/common/btree_scheduler.h @@ -0,0 +1,297 @@ +// Copyright 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_BTREE_SCHEDULER_H_ +#define QUICHE_COMMON_BTREE_SCHEDULER_H_ + +#include + +#include "absl/container/btree_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// BTreeScheduler is a data structure that allows streams (and potentially other +// entities) to be scheduled according to the arbitrary priorities. The API for +// using the scheduler can be used as follows: +// - A stream has to be registered with a priority before being scheduled. +// - A stream can be unregistered, or can be re-prioritized. +// - A stream can be scheduled; that adds it into the queue. +// - PopFront() will return the stream with highest priority. +// - ShouldYield() will return if there is a stream with higher priority than +// the specified one. +// +// The prioritization works as following: +// - If two streams have different priorities, the higher priority stream goes +// first. +// - If two streams have the same priority, the one that got scheduled earlier +// goes first. Internally, this is implemented by assigning a monotonically +// decreasing sequence number to every newly scheduled stream. +// +// The Id type has to define operator==, be hashable via absl::Hash, and +// printable via operator<<; the Priority type has to define operator<. +template +class QUICHE_EXPORT BTreeScheduler { + public: + // Returns true if there are any streams scheduled. + bool HasScheduled() const { return !schedule_.empty(); } + // Returns the number of currently scheduled streams. + size_t NumScheduled() const { return schedule_.size(); } + + // Counts the number of scheduled entries in the range [min, max]. If either + // min or max is omitted, negative or positive infinity is assumed. + size_t NumScheduledInPriorityRange(absl::optional min, + absl::optional max) const; + + // Returns true if there is a stream that would go before `id` in the + // schedule. + absl::StatusOr ShouldYield(Id id) const; + + // Returns the priority for `id`, or nullopt if stream is not registered. + absl::optional GetPriorityFor(Id id) const { + auto it = streams_.find(id); + if (it == streams_.end()) { + return absl::nullopt; + } + return it->second.priority; + } + + // Pops the highest priority stream. Will fail if the schedule is empty. + absl::StatusOr PopFront(); + + // Registers the specified stream with the supplied priority. The stream must + // not be already registered. + absl::Status Register(Id stream_id, const Priority& priority); + // Unregisters a previously registered stream. + absl::Status Unregister(Id stream_id); + // Alters the priority of an already registered stream. + absl::Status UpdatePriority(Id stream_id, const Priority& new_priority); + + // Adds the `stream` into the schedule if it's not already there. + absl::Status Schedule(Id stream_id); + // Returns true if `stream` is in the schedule. + bool IsScheduled(Id stream_id) const; + + private: + // A record for a registered stream. + struct StreamEntry { + // The current priority of the stream. + Priority priority; + // If present, the sequence number with which the stream is currently + // scheduled. If absent, indicates that the stream is not scheduled. + absl::optional current_sequence_number; + + bool scheduled() const { return current_sequence_number.has_value(); } + }; + // The full entry for the stream (includes the ID that's used as a hashmap + // key). + using FullStreamEntry = std::pair; + + // A key that is used to order entities within the schedule. + struct ScheduleKey { + // The main order key: the priority of the stream. + Priority priority; + // The secondary order key: the sequence number. + int sequence_number; + + // Orders schedule keys in order of decreasing priority. + bool operator<(const ScheduleKey& other) const { + return std::make_tuple(priority, sequence_number) > + std::make_tuple(other.priority, other.sequence_number); + } + + // In order to find all entities with priority `p`, one can iterate between + // `lower_bound(MinForPriority(p))` and `upper_bound(MaxForPriority(p))`. + static ScheduleKey MinForPriority(Priority priority) { + return ScheduleKey{priority, std::numeric_limits::max()}; + } + static ScheduleKey MaxForPriority(Priority priority) { + return ScheduleKey{priority, std::numeric_limits::min()}; + } + }; + using FullScheduleEntry = std::pair; + using ScheduleIterator = + typename absl::btree_map::const_iterator; + + // Convenience method to get the stream ID for a schedule entry. + static Id StreamId(const FullScheduleEntry& entry) { + return entry.second->first; + } + + // Removes a stream from the schedule, and returns the old entry if it were + // present. + absl::StatusOr DescheduleStream(const StreamEntry& entry); + + // The map of currently registered streams. + absl::node_hash_map streams_; + // The stream schedule, ordered starting from the highest priority stream. + absl::btree_map schedule_; + + // The counter that is used to ensure that streams with the same priority are + // handled in the FIFO order. Decreases with every write. + int current_write_sequence_number_ = 0; +}; + +template +size_t BTreeScheduler::NumScheduledInPriorityRange( + absl::optional min, absl::optional max) const { + if (min.has_value() && max.has_value()) { + QUICHE_DCHECK(*min <= *max); + } + // This is reversed, since the schedule is ordered in the descending priority + // order. + ScheduleIterator begin = + max.has_value() ? schedule_.lower_bound(ScheduleKey::MinForPriority(*max)) + : schedule_.begin(); + ScheduleIterator end = + min.has_value() ? schedule_.upper_bound(ScheduleKey::MaxForPriority(*min)) + : schedule_.end(); + return end - begin; +} + +template +absl::Status BTreeScheduler::Register(Id stream_id, + const Priority& priority) { + auto [it, success] = streams_.insert({stream_id, StreamEntry{priority}}); + if (!success) { + return absl::AlreadyExistsError("ID already registered"); + } + return absl::OkStatus(); +} + +template +auto BTreeScheduler::DescheduleStream(const StreamEntry& entry) + -> absl::StatusOr { + QUICHE_DCHECK(entry.scheduled()); + auto it = schedule_.find( + ScheduleKey{entry.priority, *entry.current_sequence_number}); + if (it == schedule_.end()) { + return absl::InternalError( + "Calling DescheduleStream() on an entry that is not in the schedule at " + "the expected key."); + } + FullScheduleEntry result = *it; + schedule_.erase(it); + return result; +} + +template +absl::Status BTreeScheduler::Unregister(Id stream_id) { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return absl::NotFoundError("Stream not registered"); + } + const StreamEntry& stream = it->second; + + if (stream.scheduled()) { + if (!DescheduleStream(stream).ok()) { + QUICHE_BUG(BTreeSchedule_Unregister_NotInSchedule) + << "UnregisterStream() called on a stream ID " << stream_id + << ", which is marked ready, but is not in the schedule"; + } + } + + streams_.erase(it); + return absl::OkStatus(); +} + +template +absl::Status BTreeScheduler::UpdatePriority( + Id stream_id, const Priority& new_priority) { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return absl::NotFoundError("ID not registered"); + } + + StreamEntry& stream = it->second; + absl::optional sequence_number; + if (stream.scheduled()) { + absl::StatusOr old_entry = DescheduleStream(stream); + if (old_entry.ok()) { + sequence_number = old_entry->first.sequence_number; + QUICHE_DCHECK_EQ(old_entry->second, &*it); + } else { + QUICHE_BUG(BTreeScheduler_Update_Not_In_Schedule) + << "UpdatePriority() called on a stream ID " << stream_id + << ", which is marked ready, but is not in the schedule"; + } + } + + stream.priority = new_priority; + if (sequence_number.has_value()) { + schedule_.insert({ScheduleKey{stream.priority, *sequence_number}, &*it}); + } + return absl::OkStatus(); +} + +template +absl::StatusOr BTreeScheduler::ShouldYield( + Id stream_id) const { + const auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return absl::NotFoundError("ID not registered"); + } + const StreamEntry& stream = stream_it->second; + + if (schedule_.empty()) { + return false; + } + const FullScheduleEntry& next = *schedule_.begin(); + if (StreamId(next) == stream_id) { + return false; + } + return next.first.priority >= stream.priority; +} + +template +absl::StatusOr BTreeScheduler::PopFront() { + if (schedule_.empty()) { + return absl::NotFoundError("No streams scheduled"); + } + auto schedule_it = schedule_.begin(); + QUICHE_DCHECK(schedule_it->second->second.scheduled()); + schedule_it->second->second.current_sequence_number = absl::nullopt; + + Id result = StreamId(*schedule_it); + schedule_.erase(schedule_it); + return result; +} + +template +absl::Status BTreeScheduler::Schedule(Id stream_id) { + const auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return absl::NotFoundError("ID not registered"); + } + if (stream_it->second.scheduled()) { + return absl::OkStatus(); + } + auto [schedule_it, success] = + schedule_.insert({ScheduleKey{stream_it->second.priority, + --current_write_sequence_number_}, + &*stream_it}); + QUICHE_BUG_IF(WebTransportWriteBlockedList_AddStream_conflict, !success) + << "Conflicting key in scheduler for stream " << stream_id; + stream_it->second.current_sequence_number = + schedule_it->first.sequence_number; + return absl::OkStatus(); +} + +template +bool BTreeScheduler::IsScheduled(Id stream_id) const { + const auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return false; + } + return stream_it->second.scheduled(); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_BTREE_SCHEDULER_H_ diff --git a/quiche/common/btree_scheduler_test.cc b/quiche/common/btree_scheduler_test.cc new file mode 100644 index 000000000000..d3a806c550f9 --- /dev/null +++ b/quiche/common/btree_scheduler_test.cc @@ -0,0 +1,281 @@ +// Copyright 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/btree_scheduler.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche::test { +namespace { + +using ::testing::ElementsAre; +using ::testing::Optional; + +template +void ScheduleIds(BTreeScheduler& scheduler, + absl::Span ids) { + for (Id id : ids) { + QUICHE_EXPECT_OK(scheduler.Schedule(id)); + } +} + +template +std::vector PopAll(BTreeScheduler& scheduler) { + std::vector result; + result.reserve(scheduler.NumScheduled()); + for (;;) { + absl::StatusOr id = scheduler.PopFront(); + if (id.ok()) { + result.push_back(*id); + } else { + EXPECT_THAT(id, StatusIs(absl::StatusCode::kNotFound)); + break; + } + } + return result; +} + +TEST(BTreeSchedulerTest, SimplePop) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 100)); + QUICHE_EXPECT_OK(scheduler.Register(2, 101)); + QUICHE_EXPECT_OK(scheduler.Register(3, 102)); + + EXPECT_THAT(scheduler.GetPriorityFor(1), Optional(100)); + EXPECT_THAT(scheduler.GetPriorityFor(3), Optional(102)); + EXPECT_EQ(scheduler.GetPriorityFor(5), absl::nullopt); + + EXPECT_EQ(scheduler.NumScheduled(), 0u); + EXPECT_FALSE(scheduler.HasScheduled()); + QUICHE_EXPECT_OK(scheduler.Schedule(1)); + QUICHE_EXPECT_OK(scheduler.Schedule(2)); + QUICHE_EXPECT_OK(scheduler.Schedule(3)); + EXPECT_EQ(scheduler.NumScheduled(), 3u); + EXPECT_TRUE(scheduler.HasScheduled()); + + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(3)); + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(2)); + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(1)); + + QUICHE_EXPECT_OK(scheduler.Schedule(2)); + QUICHE_EXPECT_OK(scheduler.Schedule(1)); + QUICHE_EXPECT_OK(scheduler.Schedule(3)); + + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(3)); + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(2)); + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(1)); + + QUICHE_EXPECT_OK(scheduler.Schedule(3)); + QUICHE_EXPECT_OK(scheduler.Schedule(1)); + + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(3)); + EXPECT_THAT(scheduler.PopFront(), IsOkAndHolds(1)); +} + +TEST(BTreeSchedulerTest, FIFO) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 100)); + QUICHE_EXPECT_OK(scheduler.Register(2, 100)); + QUICHE_EXPECT_OK(scheduler.Register(3, 100)); + + ScheduleIds(scheduler, {2, 1, 3}); + EXPECT_THAT(PopAll(scheduler), ElementsAre(2, 1, 3)); + + QUICHE_EXPECT_OK(scheduler.Register(4, 101)); + QUICHE_EXPECT_OK(scheduler.Register(5, 99)); + + ScheduleIds(scheduler, {5, 1, 2, 3, 4}); + EXPECT_THAT(PopAll(scheduler), ElementsAre(4, 1, 2, 3, 5)); + ScheduleIds(scheduler, {1, 5, 2, 4, 3}); + EXPECT_THAT(PopAll(scheduler), ElementsAre(4, 1, 2, 3, 5)); + ScheduleIds(scheduler, {3, 5, 2, 4, 1}); + EXPECT_THAT(PopAll(scheduler), ElementsAre(4, 3, 2, 1, 5)); + ScheduleIds(scheduler, {3, 2, 1, 2, 3}); + EXPECT_THAT(PopAll(scheduler), ElementsAre(3, 2, 1)); +} + +TEST(BTreeSchedulerTest, NumEntriesInRange) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 0)); + QUICHE_EXPECT_OK(scheduler.Register(2, 0)); + QUICHE_EXPECT_OK(scheduler.Register(3, 0)); + QUICHE_EXPECT_OK(scheduler.Register(4, -2)); + QUICHE_EXPECT_OK(scheduler.Register(5, -5)); + QUICHE_EXPECT_OK(scheduler.Register(6, 10)); + QUICHE_EXPECT_OK(scheduler.Register(7, 16)); + QUICHE_EXPECT_OK(scheduler.Register(8, 32)); + QUICHE_EXPECT_OK(scheduler.Register(9, 64)); + + EXPECT_EQ(scheduler.NumScheduled(), 0u); + EXPECT_EQ(scheduler.NumScheduledInPriorityRange(absl::nullopt, absl::nullopt), + 0u); + EXPECT_EQ(scheduler.NumScheduledInPriorityRange(-1, 1), 0u); + + for (int stream = 1; stream <= 9; ++stream) { + QUICHE_ASSERT_OK(scheduler.Schedule(stream)); + } + + EXPECT_EQ(scheduler.NumScheduled(), 9u); + EXPECT_EQ(scheduler.NumScheduledInPriorityRange(absl::nullopt, absl::nullopt), + 9u); + EXPECT_EQ(scheduler.NumScheduledInPriorityRange(0, 0), 3u); + EXPECT_EQ(scheduler.NumScheduledInPriorityRange(absl::nullopt, -1), 2u); + EXPECT_EQ(scheduler.NumScheduledInPriorityRange(1, absl::nullopt), 4u); +} + +TEST(BTreeSchedulerTest, Registration) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 0)); + QUICHE_EXPECT_OK(scheduler.Register(2, 0)); + + QUICHE_EXPECT_OK(scheduler.Schedule(1)); + QUICHE_EXPECT_OK(scheduler.Schedule(2)); + EXPECT_EQ(scheduler.NumScheduled(), 2u); + EXPECT_TRUE(scheduler.IsScheduled(2)); + + EXPECT_THAT(scheduler.Register(2, 0), + StatusIs(absl::StatusCode::kAlreadyExists)); + QUICHE_EXPECT_OK(scheduler.Unregister(2)); + EXPECT_EQ(scheduler.NumScheduled(), 1u); + EXPECT_FALSE(scheduler.IsScheduled(2)); + + EXPECT_THAT(scheduler.UpdatePriority(2, 1234), + StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(scheduler.Unregister(2), StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT(scheduler.Schedule(2), StatusIs(absl::StatusCode::kNotFound)); + QUICHE_EXPECT_OK(scheduler.Register(2, 0)); + EXPECT_EQ(scheduler.NumScheduled(), 1u); + EXPECT_TRUE(scheduler.IsScheduled(1)); + EXPECT_FALSE(scheduler.IsScheduled(2)); +} + +TEST(BTreeSchedulerTest, UpdatePriorityUp) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 0)); + QUICHE_EXPECT_OK(scheduler.Register(2, 0)); + QUICHE_EXPECT_OK(scheduler.Register(3, 0)); + + ScheduleIds(scheduler, {1, 2, 3}); + QUICHE_EXPECT_OK(scheduler.UpdatePriority(2, 1000)); + EXPECT_THAT(PopAll(scheduler), ElementsAre(2, 1, 3)); +} + +TEST(BTreeSchedulerTest, UpdatePriorityDown) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 0)); + QUICHE_EXPECT_OK(scheduler.Register(2, 0)); + QUICHE_EXPECT_OK(scheduler.Register(3, 0)); + + ScheduleIds(scheduler, {1, 2, 3}); + QUICHE_EXPECT_OK(scheduler.UpdatePriority(2, -1000)); + EXPECT_THAT(PopAll(scheduler), ElementsAre(1, 3, 2)); +} + +TEST(BTreeSchedulerTest, UpdatePriorityEqual) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 0)); + QUICHE_EXPECT_OK(scheduler.Register(2, 0)); + QUICHE_EXPECT_OK(scheduler.Register(3, 0)); + + ScheduleIds(scheduler, {1, 2, 3}); + QUICHE_EXPECT_OK(scheduler.UpdatePriority(2, 0)); + EXPECT_THAT(PopAll(scheduler), ElementsAre(1, 2, 3)); +} + +TEST(BTreeSchedulerTest, UpdatePriorityIntoSameBucket) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(1, 0)); + QUICHE_EXPECT_OK(scheduler.Register(2, -100)); + QUICHE_EXPECT_OK(scheduler.Register(3, 0)); + + ScheduleIds(scheduler, {1, 2, 3}); + QUICHE_EXPECT_OK(scheduler.UpdatePriority(2, 0)); + EXPECT_THAT(PopAll(scheduler), ElementsAre(1, 2, 3)); +} + +TEST(BTreeSchedulerTest, ShouldYield) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(10, 100)); + QUICHE_EXPECT_OK(scheduler.Register(20, 101)); + QUICHE_EXPECT_OK(scheduler.Register(21, 101)); + QUICHE_EXPECT_OK(scheduler.Register(30, 102)); + + EXPECT_THAT(scheduler.ShouldYield(10), IsOkAndHolds(false)); + EXPECT_THAT(scheduler.ShouldYield(20), IsOkAndHolds(false)); + EXPECT_THAT(scheduler.ShouldYield(21), IsOkAndHolds(false)); + EXPECT_THAT(scheduler.ShouldYield(30), IsOkAndHolds(false)); + EXPECT_THAT(scheduler.ShouldYield(40), StatusIs(absl::StatusCode::kNotFound)); + + QUICHE_EXPECT_OK(scheduler.Schedule(20)); + + EXPECT_THAT(scheduler.ShouldYield(10), IsOkAndHolds(true)); + EXPECT_THAT(scheduler.ShouldYield(20), IsOkAndHolds(false)); + EXPECT_THAT(scheduler.ShouldYield(21), IsOkAndHolds(true)); + EXPECT_THAT(scheduler.ShouldYield(30), IsOkAndHolds(false)); +} + +struct CustomPriority { + int a; + int b; + + bool operator<(const CustomPriority& other) const { + return std::make_tuple(a, b) < std::make_tuple(other.a, other.b); + } +}; + +TEST(BTreeSchedulerTest, CustomPriority) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(10, CustomPriority{0, 1})); + QUICHE_EXPECT_OK(scheduler.Register(11, CustomPriority{0, 0})); + QUICHE_EXPECT_OK(scheduler.Register(12, CustomPriority{0, 0})); + QUICHE_EXPECT_OK(scheduler.Register(13, CustomPriority{10, 0})); + QUICHE_EXPECT_OK(scheduler.Register(14, CustomPriority{-10, 0})); + + ScheduleIds(scheduler, {10, 11, 12, 13, 14}); + EXPECT_THAT(PopAll(scheduler), ElementsAre(13, 10, 11, 12, 14)); +} + +struct CustomId { + int a; + std::string b; + + bool operator==(const CustomId& other) const { + return a == other.a && b == other.b; + } + + template + friend H AbslHashValue(H h, const CustomId& c) { + return H::combine(std::move(h), c.a, c.b); + } +}; + +std::ostream& operator<<(std::ostream& os, const CustomId& id) { + os << id.a << ":" << id.b; + return os; +} + +TEST(BTreeSchedulerTest, CustomIds) { + BTreeScheduler scheduler; + QUICHE_EXPECT_OK(scheduler.Register(CustomId{1, "foo"}, 10)); + QUICHE_EXPECT_OK(scheduler.Register(CustomId{1, "bar"}, 12)); + QUICHE_EXPECT_OK(scheduler.Register(CustomId{2, "foo"}, 11)); + EXPECT_THAT(scheduler.Register(CustomId{1, "foo"}, 10), + StatusIs(absl::StatusCode::kAlreadyExists)); + + ScheduleIds(scheduler, + {CustomId{1, "foo"}, CustomId{1, "bar"}, CustomId{2, "foo"}}); + EXPECT_THAT(scheduler.ShouldYield(CustomId{1, "foo"}), IsOkAndHolds(true)); + EXPECT_THAT(scheduler.ShouldYield(CustomId{1, "bar"}), IsOkAndHolds(false)); + EXPECT_THAT( + PopAll(scheduler), + ElementsAre(CustomId{1, "bar"}, CustomId{2, "foo"}, CustomId{1, "foo"})); +} + +} // namespace +} // namespace quiche::test diff --git a/quiche/common/capsule.cc b/quiche/common/capsule.cc new file mode 100644 index 000000000000..4c5ad5ab0f35 --- /dev/null +++ b/quiche/common/capsule.cc @@ -0,0 +1,706 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/capsule.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_data_writer.h" +#include "quiche/common/quiche_ip_address.h" +#include "quiche/common/quiche_status_utils.h" +#include "quiche/common/wire_serialization.h" +#include "quiche/web_transport/web_transport.h" + +namespace quiche { + +std::string CapsuleTypeToString(CapsuleType capsule_type) { + switch (capsule_type) { + case CapsuleType::DATAGRAM: + return "DATAGRAM"; + case CapsuleType::LEGACY_DATAGRAM: + return "LEGACY_DATAGRAM"; + case CapsuleType::LEGACY_DATAGRAM_WITHOUT_CONTEXT: + return "LEGACY_DATAGRAM_WITHOUT_CONTEXT"; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + return "CLOSE_WEBTRANSPORT_SESSION"; + case CapsuleType::ADDRESS_REQUEST: + return "ADDRESS_REQUEST"; + case CapsuleType::ADDRESS_ASSIGN: + return "ADDRESS_ASSIGN"; + case CapsuleType::ROUTE_ADVERTISEMENT: + return "ROUTE_ADVERTISEMENT"; + case CapsuleType::WT_STREAM: + return "WT_STREAM"; + case CapsuleType::WT_STREAM_WITH_FIN: + return "WT_STREAM_WITH_FIN"; + case CapsuleType::WT_RESET_STREAM: + return "WT_RESET_STREAM"; + case CapsuleType::WT_STOP_SENDING: + return "WT_STOP_SENDING"; + case CapsuleType::WT_MAX_STREAM_DATA: + return "WT_MAX_STREAM_DATA"; + case CapsuleType::WT_MAX_STREAMS_BIDI: + return "WT_MAX_STREAMS_BIDI"; + case CapsuleType::WT_MAX_STREAMS_UNIDI: + return "WT_MAX_STREAMS_UNIDI"; + } + return absl::StrCat("Unknown(", static_cast(capsule_type), ")"); +} + +std::ostream& operator<<(std::ostream& os, const CapsuleType& capsule_type) { + os << CapsuleTypeToString(capsule_type); + return os; +} + +// static +Capsule Capsule::Datagram(absl::string_view http_datagram_payload) { + return Capsule(DatagramCapsule{http_datagram_payload}); +} + +// static +Capsule Capsule::LegacyDatagram(absl::string_view http_datagram_payload) { + return Capsule(LegacyDatagramCapsule{http_datagram_payload}); +} + +// static +Capsule Capsule::LegacyDatagramWithoutContext( + absl::string_view http_datagram_payload) { + return Capsule(LegacyDatagramWithoutContextCapsule{http_datagram_payload}); +} + +// static +Capsule Capsule::CloseWebTransportSession( + webtransport::SessionErrorCode error_code, + absl::string_view error_message) { + return Capsule(CloseWebTransportSessionCapsule({error_code, error_message})); +} + +// static +Capsule Capsule::AddressRequest() { return Capsule(AddressRequestCapsule()); } + +// static +Capsule Capsule::AddressAssign() { return Capsule(AddressAssignCapsule()); } + +// static +Capsule Capsule::RouteAdvertisement() { + return Capsule(RouteAdvertisementCapsule()); +} + +// static +Capsule Capsule::Unknown(uint64_t capsule_type, + absl::string_view unknown_capsule_data) { + return Capsule(UnknownCapsule{capsule_type, unknown_capsule_data}); +} + +bool Capsule::operator==(const Capsule& other) const { + return capsule_ == other.capsule_; +} + +std::string DatagramCapsule::ToString() const { + return absl::StrCat("DATAGRAM[", + absl::BytesToHexString(http_datagram_payload), "]"); +} + +std::string LegacyDatagramCapsule::ToString() const { + return absl::StrCat("LEGACY_DATAGRAM[", + absl::BytesToHexString(http_datagram_payload), "]"); +} + +std::string LegacyDatagramWithoutContextCapsule::ToString() const { + return absl::StrCat("LEGACY_DATAGRAM_WITHOUT_CONTEXT[", + absl::BytesToHexString(http_datagram_payload), "]"); +} + +std::string CloseWebTransportSessionCapsule::ToString() const { + return absl::StrCat("CLOSE_WEBTRANSPORT_SESSION(error_code=", error_code, + ",error_message=\"", error_message, "\")"); +} + +std::string AddressRequestCapsule::ToString() const { + std::string rv = "ADDRESS_REQUEST["; + for (auto requested_address : requested_addresses) { + absl::StrAppend(&rv, "(", requested_address.request_id, "-", + requested_address.ip_prefix.ToString(), ")"); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::string AddressAssignCapsule::ToString() const { + std::string rv = "ADDRESS_ASSIGN["; + for (auto assigned_address : assigned_addresses) { + absl::StrAppend(&rv, "(", assigned_address.request_id, "-", + assigned_address.ip_prefix.ToString(), ")"); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::string RouteAdvertisementCapsule::ToString() const { + std::string rv = "ROUTE_ADVERTISEMENT["; + for (auto ip_address_range : ip_address_ranges) { + absl::StrAppend(&rv, "(", ip_address_range.start_ip_address.ToString(), "-", + ip_address_range.end_ip_address.ToString(), "-", + static_cast(ip_address_range.ip_protocol), ")"); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::string UnknownCapsule::ToString() const { + return absl::StrCat("Unknown(", type, ") [", absl::BytesToHexString(payload), + "]"); +} + +std::string WebTransportStreamDataCapsule::ToString() const { + return absl::StrCat(CapsuleTypeToString(capsule_type()), + " [stream_id=", stream_id, + ", data=", absl::BytesToHexString(data), "]"); +} + +std::string WebTransportResetStreamCapsule::ToString() const { + return absl::StrCat("WT_RESET_STREAM(stream_id=", stream_id, + ", error_code=", error_code, ")"); +} + +std::string WebTransportStopSendingCapsule::ToString() const { + return absl::StrCat("WT_STOP_SENDING(stream_id=", stream_id, + ", error_code=", error_code, ")"); +} + +std::string WebTransportMaxStreamDataCapsule::ToString() const { + return absl::StrCat("WT_MAX_STREAM_DATA (stream_id=", stream_id, + ", max_stream_data=", max_stream_data, ")"); +} + +std::string WebTransportMaxStreamsCapsule::ToString() const { + return absl::StrCat(CapsuleTypeToString(capsule_type()), + " (max_streams=", max_stream_count, ")"); +} + +std::string Capsule::ToString() const { + return absl::visit([](const auto& capsule) { return capsule.ToString(); }, + capsule_); +} + +std::ostream& operator<<(std::ostream& os, const Capsule& capsule) { + os << capsule.ToString(); + return os; +} + +CapsuleParser::CapsuleParser(Visitor* visitor) : visitor_(visitor) { + QUICHE_DCHECK_NE(visitor_, nullptr); +} + +// Serialization logic for quiche::PrefixWithId. +class WirePrefixWithId { + public: + using DataType = PrefixWithId; + + WirePrefixWithId(const PrefixWithId& prefix) : prefix_(prefix) {} + + size_t GetLengthOnWire() { + return ComputeLengthOnWire( + WireVarInt62(prefix_.request_id), + WireUint8(prefix_.ip_prefix.address().IsIPv4() ? 4 : 6), + WireBytes(prefix_.ip_prefix.address().ToPackedString()), + WireUint8(prefix_.ip_prefix.prefix_length())); + } + + absl::Status SerializeIntoWriter(QuicheDataWriter& writer) { + return AppendToStatus( + quiche::SerializeIntoWriter( + writer, WireVarInt62(prefix_.request_id), + WireUint8(prefix_.ip_prefix.address().IsIPv4() ? 4 : 6), + WireBytes(prefix_.ip_prefix.address().ToPackedString()), + WireUint8(prefix_.ip_prefix.prefix_length())), + " while serializing a PrefixWithId"); + } + + private: + const PrefixWithId& prefix_; +}; + +// Serialization logic for quiche::IpAddressRange. +class WireIpAddressRange { + public: + using DataType = IpAddressRange; + + explicit WireIpAddressRange(const IpAddressRange& range) : range_(range) {} + + size_t GetLengthOnWire() { + return ComputeLengthOnWire( + WireUint8(range_.start_ip_address.IsIPv4() ? 4 : 6), + WireBytes(range_.start_ip_address.ToPackedString()), + WireBytes(range_.end_ip_address.ToPackedString()), + WireUint8(range_.ip_protocol)); + } + + absl::Status SerializeIntoWriter(QuicheDataWriter& writer) { + return AppendToStatus( + ::quiche::SerializeIntoWriter( + writer, WireUint8(range_.start_ip_address.IsIPv4() ? 4 : 6), + WireBytes(range_.start_ip_address.ToPackedString()), + WireBytes(range_.end_ip_address.ToPackedString()), + WireUint8(range_.ip_protocol)), + " while serializing an IpAddressRange"); + } + + private: + const IpAddressRange& range_; +}; + +template +absl::StatusOr SerializeCapsuleFields( + CapsuleType type, QuicheBufferAllocator* allocator, T... fields) { + size_t capsule_payload_size = ComputeLengthOnWire(fields...); + return SerializeIntoBuffer(allocator, WireVarInt62(type), + WireVarInt62(capsule_payload_size), fields...); +} + +absl::StatusOr SerializeCapsuleWithStatus( + const Capsule& capsule, quiche::QuicheBufferAllocator* allocator) { + switch (capsule.capsule_type()) { + case CapsuleType::DATAGRAM: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireBytes(capsule.datagram_capsule().http_datagram_payload)); + case CapsuleType::LEGACY_DATAGRAM: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireBytes(capsule.legacy_datagram_capsule().http_datagram_payload)); + case CapsuleType::LEGACY_DATAGRAM_WITHOUT_CONTEXT: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireBytes(capsule.legacy_datagram_without_context_capsule() + .http_datagram_payload)); + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireUint32(capsule.close_web_transport_session_capsule().error_code), + WireBytes( + capsule.close_web_transport_session_capsule().error_message)); + case CapsuleType::ADDRESS_REQUEST: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireSpan(absl::MakeConstSpan( + capsule.address_request_capsule().requested_addresses))); + case CapsuleType::ADDRESS_ASSIGN: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireSpan(absl::MakeConstSpan( + capsule.address_assign_capsule().assigned_addresses))); + case CapsuleType::ROUTE_ADVERTISEMENT: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireSpan(absl::MakeConstSpan( + capsule.route_advertisement_capsule().ip_address_ranges))); + case CapsuleType::WT_STREAM: + case CapsuleType::WT_STREAM_WITH_FIN: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireVarInt62(capsule.web_transport_stream_data().stream_id), + WireBytes(capsule.web_transport_stream_data().data)); + case CapsuleType::WT_RESET_STREAM: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireVarInt62(capsule.web_transport_reset_stream().stream_id), + WireVarInt62(capsule.web_transport_reset_stream().error_code)); + case CapsuleType::WT_STOP_SENDING: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireVarInt62(capsule.web_transport_stop_sending().stream_id), + WireVarInt62(capsule.web_transport_stop_sending().error_code)); + case CapsuleType::WT_MAX_STREAM_DATA: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireVarInt62(capsule.web_transport_max_stream_data().stream_id), + WireVarInt62( + capsule.web_transport_max_stream_data().max_stream_data)); + case CapsuleType::WT_MAX_STREAMS_BIDI: + case CapsuleType::WT_MAX_STREAMS_UNIDI: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireVarInt62(capsule.web_transport_max_streams().max_stream_count)); + default: + return SerializeCapsuleFields( + capsule.capsule_type(), allocator, + WireBytes(capsule.unknown_capsule().payload)); + } +} + +QuicheBuffer SerializeCapsule(const Capsule& capsule, + quiche::QuicheBufferAllocator* allocator) { + absl::StatusOr serialized = + SerializeCapsuleWithStatus(capsule, allocator); + if (!serialized.ok()) { + QUICHE_BUG(capsule_serialization_failed) + << "Failed to serialize the following capsule:\n" + << capsule << "Serialization error: " << serialized.status(); + return QuicheBuffer(); + } + return *std::move(serialized); +} + +bool CapsuleParser::IngestCapsuleFragment(absl::string_view capsule_fragment) { + if (parsing_error_occurred_) { + return false; + } + absl::StrAppend(&buffered_data_, capsule_fragment); + while (true) { + const absl::StatusOr buffered_data_read = AttemptParseCapsule(); + if (!buffered_data_read.ok()) { + ReportParseFailure(buffered_data_read.status().message()); + buffered_data_.clear(); + return false; + } + if (*buffered_data_read == 0) { + break; + } + buffered_data_.erase(0, *buffered_data_read); + } + static constexpr size_t kMaxCapsuleBufferSize = 1024 * 1024; + if (buffered_data_.size() > kMaxCapsuleBufferSize) { + buffered_data_.clear(); + ReportParseFailure("Refusing to buffer too much capsule data"); + return false; + } + return true; +} + +namespace { +absl::Status ReadWebTransportStreamId(QuicheDataReader& reader, + webtransport::StreamId& id) { + uint64_t raw_id; + if (!reader.ReadVarInt62(&raw_id)) { + return absl::InvalidArgumentError("Failed to read WebTransport Stream ID"); + } + if (raw_id > std::numeric_limits::max()) { + return absl::InvalidArgumentError("Stream ID does not fit into a uint32_t"); + } + id = static_cast(raw_id); + return absl::OkStatus(); +} + +absl::StatusOr ParseCapsulePayload(QuicheDataReader& reader, + CapsuleType type) { + switch (type) { + case CapsuleType::DATAGRAM: + return Capsule::Datagram(reader.ReadRemainingPayload()); + case CapsuleType::LEGACY_DATAGRAM: + return Capsule::LegacyDatagram(reader.ReadRemainingPayload()); + case CapsuleType::LEGACY_DATAGRAM_WITHOUT_CONTEXT: + return Capsule::LegacyDatagramWithoutContext( + reader.ReadRemainingPayload()); + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: { + CloseWebTransportSessionCapsule capsule; + if (!reader.ReadUInt32(&capsule.error_code)) { + return absl::InvalidArgumentError( + "Unable to parse capsule CLOSE_WEBTRANSPORT_SESSION error code"); + } + capsule.error_message = reader.ReadRemainingPayload(); + return Capsule(std::move(capsule)); + } + case CapsuleType::ADDRESS_REQUEST: { + AddressRequestCapsule capsule; + while (!reader.IsDoneReading()) { + PrefixWithId requested_address; + if (!reader.ReadVarInt62(&requested_address.request_id)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_REQUEST request ID"); + } + uint8_t address_family; + if (!reader.ReadUInt8(&address_family)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_REQUEST family"); + } + if (address_family != 4 && address_family != 6) { + return absl::InvalidArgumentError("Bad ADDRESS_REQUEST family"); + } + absl::string_view ip_address_bytes; + if (!reader.ReadStringPiece(&ip_address_bytes, + address_family == 4 + ? QuicheIpAddress::kIPv4AddressSize + : QuicheIpAddress::kIPv6AddressSize)) { + return absl::InvalidArgumentError( + "Unable to read capsule ADDRESS_REQUEST address"); + } + quiche::QuicheIpAddress ip_address; + if (!ip_address.FromPackedString(ip_address_bytes.data(), + ip_address_bytes.size())) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_REQUEST address"); + } + uint8_t ip_prefix_length; + if (!reader.ReadUInt8(&ip_prefix_length)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_REQUEST IP prefix length"); + } + if (ip_prefix_length > QuicheIpPrefix(ip_address).prefix_length()) { + return absl::InvalidArgumentError("Invalid IP prefix length"); + } + requested_address.ip_prefix = + QuicheIpPrefix(ip_address, ip_prefix_length); + capsule.requested_addresses.push_back(requested_address); + } + return Capsule(std::move(capsule)); + } + case CapsuleType::ADDRESS_ASSIGN: { + AddressAssignCapsule capsule; + while (!reader.IsDoneReading()) { + PrefixWithId assigned_address; + if (!reader.ReadVarInt62(&assigned_address.request_id)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_ASSIGN request ID"); + } + uint8_t address_family; + if (!reader.ReadUInt8(&address_family)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_ASSIGN family"); + } + if (address_family != 4 && address_family != 6) { + return absl::InvalidArgumentError("Bad ADDRESS_ASSIGN family"); + } + absl::string_view ip_address_bytes; + if (!reader.ReadStringPiece(&ip_address_bytes, + address_family == 4 + ? QuicheIpAddress::kIPv4AddressSize + : QuicheIpAddress::kIPv6AddressSize)) { + return absl::InvalidArgumentError( + "Unable to read capsule ADDRESS_ASSIGN address"); + } + quiche::QuicheIpAddress ip_address; + if (!ip_address.FromPackedString(ip_address_bytes.data(), + ip_address_bytes.size())) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_ASSIGN address"); + } + uint8_t ip_prefix_length; + if (!reader.ReadUInt8(&ip_prefix_length)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ADDRESS_ASSIGN IP prefix length"); + } + if (ip_prefix_length > QuicheIpPrefix(ip_address).prefix_length()) { + return absl::InvalidArgumentError("Invalid IP prefix length"); + } + assigned_address.ip_prefix = + QuicheIpPrefix(ip_address, ip_prefix_length); + capsule.assigned_addresses.push_back(assigned_address); + } + return Capsule(std::move(capsule)); + } + case CapsuleType::ROUTE_ADVERTISEMENT: { + RouteAdvertisementCapsule capsule; + while (!reader.IsDoneReading()) { + uint8_t address_family; + if (!reader.ReadUInt8(&address_family)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ROUTE_ADVERTISEMENT family"); + } + if (address_family != 4 && address_family != 6) { + return absl::InvalidArgumentError("Bad ROUTE_ADVERTISEMENT family"); + } + IpAddressRange ip_address_range; + absl::string_view start_ip_address_bytes; + if (!reader.ReadStringPiece(&start_ip_address_bytes, + address_family == 4 + ? QuicheIpAddress::kIPv4AddressSize + : QuicheIpAddress::kIPv6AddressSize)) { + return absl::InvalidArgumentError( + "Unable to read capsule ROUTE_ADVERTISEMENT start address"); + } + if (!ip_address_range.start_ip_address.FromPackedString( + start_ip_address_bytes.data(), start_ip_address_bytes.size())) { + return absl::InvalidArgumentError( + "Unable to parse capsule ROUTE_ADVERTISEMENT start address"); + } + absl::string_view end_ip_address_bytes; + if (!reader.ReadStringPiece(&end_ip_address_bytes, + address_family == 4 + ? QuicheIpAddress::kIPv4AddressSize + : QuicheIpAddress::kIPv6AddressSize)) { + return absl::InvalidArgumentError( + "Unable to read capsule ROUTE_ADVERTISEMENT end address"); + } + if (!ip_address_range.end_ip_address.FromPackedString( + end_ip_address_bytes.data(), end_ip_address_bytes.size())) { + return absl::InvalidArgumentError( + "Unable to parse capsule ROUTE_ADVERTISEMENT end address"); + } + if (!reader.ReadUInt8(&ip_address_range.ip_protocol)) { + return absl::InvalidArgumentError( + "Unable to parse capsule ROUTE_ADVERTISEMENT IP protocol"); + } + capsule.ip_address_ranges.push_back(ip_address_range); + } + return Capsule(std::move(capsule)); + } + case CapsuleType::WT_STREAM: + case CapsuleType::WT_STREAM_WITH_FIN: { + WebTransportStreamDataCapsule capsule; + capsule.fin = (type == CapsuleType::WT_STREAM_WITH_FIN); + QUICHE_RETURN_IF_ERROR( + ReadWebTransportStreamId(reader, capsule.stream_id)); + capsule.data = reader.ReadRemainingPayload(); + return Capsule(std::move(capsule)); + } + case CapsuleType::WT_RESET_STREAM: { + WebTransportResetStreamCapsule capsule; + QUICHE_RETURN_IF_ERROR( + ReadWebTransportStreamId(reader, capsule.stream_id)); + if (!reader.ReadVarInt62(&capsule.error_code)) { + return absl::InvalidArgumentError( + "Failed to parse the RESET_STREAM error code"); + } + return Capsule(std::move(capsule)); + } + case CapsuleType::WT_STOP_SENDING: { + WebTransportStopSendingCapsule capsule; + QUICHE_RETURN_IF_ERROR( + ReadWebTransportStreamId(reader, capsule.stream_id)); + if (!reader.ReadVarInt62(&capsule.error_code)) { + return absl::InvalidArgumentError( + "Failed to parse the STOP_SENDING error code"); + } + return Capsule(std::move(capsule)); + } + case CapsuleType::WT_MAX_STREAM_DATA: { + WebTransportMaxStreamDataCapsule capsule; + QUICHE_RETURN_IF_ERROR( + ReadWebTransportStreamId(reader, capsule.stream_id)); + if (!reader.ReadVarInt62(&capsule.max_stream_data)) { + return absl::InvalidArgumentError( + "Failed to parse the max stream data field"); + } + return Capsule(std::move(capsule)); + } + case CapsuleType::WT_MAX_STREAMS_UNIDI: + case CapsuleType::WT_MAX_STREAMS_BIDI: { + WebTransportMaxStreamsCapsule capsule; + capsule.stream_type = type == CapsuleType::WT_MAX_STREAMS_UNIDI + ? webtransport::StreamType::kUnidirectional + : webtransport::StreamType::kBidirectional; + if (!reader.ReadVarInt62(&capsule.max_stream_count)) { + return absl::InvalidArgumentError( + "Failed to parse the max streams field"); + } + return Capsule(std::move(capsule)); + } + default: + return Capsule(UnknownCapsule{static_cast(type), + reader.ReadRemainingPayload()}); + } +} +} // namespace + +absl::StatusOr CapsuleParser::AttemptParseCapsule() { + QUICHE_DCHECK(!parsing_error_occurred_); + if (buffered_data_.empty()) { + return 0; + } + QuicheDataReader capsule_fragment_reader(buffered_data_); + uint64_t capsule_type64; + if (!capsule_fragment_reader.ReadVarInt62(&capsule_type64)) { + QUICHE_DVLOG(2) << "Partial read: not enough data to read capsule type"; + return 0; + } + absl::string_view capsule_data; + if (!capsule_fragment_reader.ReadStringPieceVarInt62(&capsule_data)) { + QUICHE_DVLOG(2) + << "Partial read: not enough data to read capsule length or " + "full capsule data"; + return 0; + } + QuicheDataReader capsule_data_reader(capsule_data); + absl::StatusOr capsule = ParseCapsulePayload( + capsule_data_reader, static_cast(capsule_type64)); + QUICHE_RETURN_IF_ERROR(capsule.status()); + if (!visitor_->OnCapsule(*capsule)) { + return absl::AbortedError("Visitor failed to process capsule"); + } + return capsule_fragment_reader.PreviouslyReadPayload().length(); +} + +void CapsuleParser::ReportParseFailure(absl::string_view error_message) { + if (parsing_error_occurred_) { + QUICHE_BUG(multiple parse errors) << "Experienced multiple parse failures"; + return; + } + parsing_error_occurred_ = true; + visitor_->OnCapsuleParseFailure(error_message); +} + +void CapsuleParser::ErrorIfThereIsRemainingBufferedData() { + if (parsing_error_occurred_) { + return; + } + if (!buffered_data_.empty()) { + ReportParseFailure("Incomplete capsule left at the end of the stream"); + } +} + +bool PrefixWithId::operator==(const PrefixWithId& other) const { + return request_id == other.request_id && ip_prefix == other.ip_prefix; +} + +bool IpAddressRange::operator==(const IpAddressRange& other) const { + return start_ip_address == other.start_ip_address && + end_ip_address == other.end_ip_address && + ip_protocol == other.ip_protocol; +} + +bool AddressAssignCapsule::operator==(const AddressAssignCapsule& other) const { + return assigned_addresses == other.assigned_addresses; +} + +bool AddressRequestCapsule::operator==( + const AddressRequestCapsule& other) const { + return requested_addresses == other.requested_addresses; +} + +bool RouteAdvertisementCapsule::operator==( + const RouteAdvertisementCapsule& other) const { + return ip_address_ranges == other.ip_address_ranges; +} + +bool WebTransportStreamDataCapsule::operator==( + const WebTransportStreamDataCapsule& other) const { + return stream_id == other.stream_id && data == other.data && fin == other.fin; +} + +bool WebTransportResetStreamCapsule::operator==( + const WebTransportResetStreamCapsule& other) const { + return stream_id == other.stream_id && error_code == other.error_code; +} + +bool WebTransportStopSendingCapsule::operator==( + const WebTransportStopSendingCapsule& other) const { + return stream_id == other.stream_id && error_code == other.error_code; +} + +bool WebTransportMaxStreamDataCapsule::operator==( + const WebTransportMaxStreamDataCapsule& other) const { + return stream_id == other.stream_id && + max_stream_data == other.max_stream_data; +} + +bool WebTransportMaxStreamsCapsule::operator==( + const WebTransportMaxStreamsCapsule& other) const { + return stream_type == other.stream_type && + max_stream_count == other.max_stream_count; +} + +} // namespace quiche diff --git a/quiche/common/capsule.h b/quiche/common/capsule.h new file mode 100644 index 000000000000..08bde5b4bb2b --- /dev/null +++ b/quiche/common/capsule.h @@ -0,0 +1,386 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_CAPSULE_H_ +#define QUICHE_COMMON_CAPSULE_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_ip_address.h" +#include "quiche/web_transport/web_transport.h" + +namespace quiche { + +enum class CapsuleType : uint64_t { + // Casing in this enum matches the IETF specifications. + DATAGRAM = 0x00, // RFC 9297. + LEGACY_DATAGRAM = 0xff37a0, // draft-ietf-masque-h3-datagram-04. + LEGACY_DATAGRAM_WITHOUT_CONTEXT = + 0xff37a5, // draft-ietf-masque-h3-datagram-05 to -08. + + // + CLOSE_WEBTRANSPORT_SESSION = 0x2843, + + // draft-ietf-masque-connect-ip-03. + ADDRESS_ASSIGN = 0x1ECA6A00, + ADDRESS_REQUEST = 0x1ECA6A01, + ROUTE_ADVERTISEMENT = 0x1ECA6A02, + + // + WT_RESET_STREAM = 0x190b4d39, + WT_STOP_SENDING = 0x190b4d3a, + WT_STREAM = 0x190b4d3b, + WT_STREAM_WITH_FIN = 0x190b4d3c, + // Should be removed as a result of + // . + // WT_MAX_DATA = 0x190b4d3d, + WT_MAX_STREAM_DATA = 0x190b4d3e, + WT_MAX_STREAMS_BIDI = 0x190b4d3f, + WT_MAX_STREAMS_UNIDI = 0x190b4d40, + + // TODO(b/264263113): implement those. + // PADDING = 0x190b4d38, + // WT_DATA_BLOCKED = 0x190b4d41, + // WT_STREAM_DATA_BLOCKED = 0x190b4d42, + // WT_STREAMS_BLOCKED_BIDI = 0x190b4d43, + // WT_STREAMS_BLOCKED_UNIDI = 0x190b4d44, +}; + +QUICHE_EXPORT std::string CapsuleTypeToString(CapsuleType capsule_type); +QUICHE_EXPORT std::ostream& operator<<(std::ostream& os, + const CapsuleType& capsule_type); + +// General. +struct QUICHE_EXPORT DatagramCapsule { + absl::string_view http_datagram_payload; + + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::DATAGRAM; } + bool operator==(const DatagramCapsule& other) const { + return http_datagram_payload == other.http_datagram_payload; + } +}; + +struct QUICHE_EXPORT LegacyDatagramCapsule { + absl::string_view http_datagram_payload; + + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::LEGACY_DATAGRAM; } + bool operator==(const LegacyDatagramCapsule& other) const { + return http_datagram_payload == other.http_datagram_payload; + } +}; + +struct QUICHE_EXPORT LegacyDatagramWithoutContextCapsule { + absl::string_view http_datagram_payload; + + std::string ToString() const; + CapsuleType capsule_type() const { + return CapsuleType::LEGACY_DATAGRAM_WITHOUT_CONTEXT; + } + bool operator==(const LegacyDatagramWithoutContextCapsule& other) const { + return http_datagram_payload == other.http_datagram_payload; + } +}; + +// WebTransport over HTTP/3. +struct QUICHE_EXPORT CloseWebTransportSessionCapsule { + webtransport::SessionErrorCode error_code; + absl::string_view error_message; + + std::string ToString() const; + CapsuleType capsule_type() const { + return CapsuleType::CLOSE_WEBTRANSPORT_SESSION; + } + bool operator==(const CloseWebTransportSessionCapsule& other) const { + return error_code == other.error_code && + error_message == other.error_message; + } +}; + +// MASQUE CONNECT-IP. +struct QUICHE_EXPORT PrefixWithId { + uint64_t request_id; + quiche::QuicheIpPrefix ip_prefix; + bool operator==(const PrefixWithId& other) const; +}; +struct QUICHE_EXPORT IpAddressRange { + quiche::QuicheIpAddress start_ip_address; + quiche::QuicheIpAddress end_ip_address; + uint8_t ip_protocol; + bool operator==(const IpAddressRange& other) const; +}; + +struct QUICHE_EXPORT AddressAssignCapsule { + std::vector assigned_addresses; + bool operator==(const AddressAssignCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::ADDRESS_ASSIGN; } +}; +struct QUICHE_EXPORT AddressRequestCapsule { + std::vector requested_addresses; + bool operator==(const AddressRequestCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::ADDRESS_REQUEST; } +}; +struct QUICHE_EXPORT RouteAdvertisementCapsule { + std::vector ip_address_ranges; + bool operator==(const RouteAdvertisementCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::ROUTE_ADVERTISEMENT; } +}; +struct QUICHE_EXPORT UnknownCapsule { + uint64_t type; + absl::string_view payload; + + std::string ToString() const; + CapsuleType capsule_type() const { return static_cast(type); } + bool operator==(const UnknownCapsule& other) const { + return type == other.type && payload == other.payload; + } +}; + +// WebTransport over HTTP/2. +struct QUICHE_EXPORT WebTransportStreamDataCapsule { + webtransport::StreamId stream_id; + absl::string_view data; + bool fin; + + bool operator==(const WebTransportStreamDataCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { + return fin ? CapsuleType::WT_STREAM_WITH_FIN : CapsuleType::WT_STREAM; + } +}; +struct QUICHE_EXPORT WebTransportResetStreamCapsule { + webtransport::StreamId stream_id; + uint64_t error_code; + + bool operator==(const WebTransportResetStreamCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::WT_RESET_STREAM; } +}; +struct QUICHE_EXPORT WebTransportStopSendingCapsule { + webtransport::StreamId stream_id; + uint64_t error_code; + + bool operator==(const WebTransportStopSendingCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::WT_STOP_SENDING; } +}; +struct QUICHE_EXPORT WebTransportMaxStreamDataCapsule { + webtransport::StreamId stream_id; + uint64_t max_stream_data; + + bool operator==(const WebTransportMaxStreamDataCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { return CapsuleType::WT_MAX_STREAM_DATA; } +}; +struct QUICHE_EXPORT WebTransportMaxStreamsCapsule { + webtransport::StreamType stream_type; + uint64_t max_stream_count; + + bool operator==(const WebTransportMaxStreamsCapsule& other) const; + std::string ToString() const; + CapsuleType capsule_type() const { + return stream_type == webtransport::StreamType::kBidirectional + ? CapsuleType::WT_MAX_STREAMS_BIDI + : CapsuleType::WT_MAX_STREAMS_UNIDI; + } +}; + +// Capsule from RFC 9297. +// IMPORTANT NOTE: Capsule does not own any of the absl::string_view memory it +// points to. Strings saved into a capsule must outlive the capsule object. Any +// code that sees a capsule in a callback needs to either process it immediately +// or perform its own deep copy. +class QUICHE_EXPORT Capsule { + public: + static Capsule Datagram( + absl::string_view http_datagram_payload = absl::string_view()); + static Capsule LegacyDatagram( + absl::string_view http_datagram_payload = absl::string_view()); + static Capsule LegacyDatagramWithoutContext( + absl::string_view http_datagram_payload = absl::string_view()); + static Capsule CloseWebTransportSession( + webtransport::SessionErrorCode error_code = 0, + absl::string_view error_message = ""); + static Capsule AddressRequest(); + static Capsule AddressAssign(); + static Capsule RouteAdvertisement(); + static Capsule Unknown( + uint64_t capsule_type, + absl::string_view unknown_capsule_data = absl::string_view()); + + template + explicit Capsule(CapsuleStruct capsule) : capsule_(std::move(capsule)) {} + bool operator==(const Capsule& other) const; + + // Human-readable information string for debugging purposes. + std::string ToString() const; + friend QUICHE_EXPORT std::ostream& operator<<(std::ostream& os, + const Capsule& capsule); + + CapsuleType capsule_type() const { + return absl::visit( + [](const auto& capsule) { return capsule.capsule_type(); }, capsule_); + } + DatagramCapsule& datagram_capsule() { + return absl::get(capsule_); + } + const DatagramCapsule& datagram_capsule() const { + return absl::get(capsule_); + } + LegacyDatagramCapsule& legacy_datagram_capsule() { + return absl::get(capsule_); + } + const LegacyDatagramCapsule& legacy_datagram_capsule() const { + return absl::get(capsule_); + } + LegacyDatagramWithoutContextCapsule& + legacy_datagram_without_context_capsule() { + return absl::get(capsule_); + } + const LegacyDatagramWithoutContextCapsule& + legacy_datagram_without_context_capsule() const { + return absl::get(capsule_); + } + CloseWebTransportSessionCapsule& close_web_transport_session_capsule() { + return absl::get(capsule_); + } + const CloseWebTransportSessionCapsule& close_web_transport_session_capsule() + const { + return absl::get(capsule_); + } + AddressRequestCapsule& address_request_capsule() { + return absl::get(capsule_); + } + const AddressRequestCapsule& address_request_capsule() const { + return absl::get(capsule_); + } + AddressAssignCapsule& address_assign_capsule() { + return absl::get(capsule_); + } + const AddressAssignCapsule& address_assign_capsule() const { + return absl::get(capsule_); + } + RouteAdvertisementCapsule& route_advertisement_capsule() { + return absl::get(capsule_); + } + const RouteAdvertisementCapsule& route_advertisement_capsule() const { + return absl::get(capsule_); + } + WebTransportStreamDataCapsule& web_transport_stream_data() { + return absl::get(capsule_); + } + const WebTransportStreamDataCapsule& web_transport_stream_data() const { + return absl::get(capsule_); + } + WebTransportResetStreamCapsule& web_transport_reset_stream() { + return absl::get(capsule_); + } + const WebTransportResetStreamCapsule& web_transport_reset_stream() const { + return absl::get(capsule_); + } + WebTransportStopSendingCapsule& web_transport_stop_sending() { + return absl::get(capsule_); + } + const WebTransportStopSendingCapsule& web_transport_stop_sending() const { + return absl::get(capsule_); + } + WebTransportMaxStreamDataCapsule& web_transport_max_stream_data() { + return absl::get(capsule_); + } + const WebTransportMaxStreamDataCapsule& web_transport_max_stream_data() + const { + return absl::get(capsule_); + } + WebTransportMaxStreamsCapsule& web_transport_max_streams() { + return absl::get(capsule_); + } + const WebTransportMaxStreamsCapsule& web_transport_max_streams() const { + return absl::get(capsule_); + } + UnknownCapsule& unknown_capsule() { + return absl::get(capsule_); + } + const UnknownCapsule& unknown_capsule() const { + return absl::get(capsule_); + } + + private: + absl::variant + capsule_; +}; + +namespace test { +class CapsuleParserPeer; +} // namespace test + +class QUICHE_EXPORT CapsuleParser { + public: + class QUICHE_EXPORT Visitor { + public: + virtual ~Visitor() {} + + // Called when a capsule has been successfully parsed. The return value + // indicates whether the contents of the capsule are valid: if false is + // returned, the parse operation will be considered failed and + // OnCapsuleParseFailure will be called. Note that since Capsule does not + // own the memory backing its string_views, that memory is only valid until + // this callback returns. Visitors that wish to access the capsule later + // MUST make a deep copy before this returns. + virtual bool OnCapsule(const Capsule& capsule) = 0; + + virtual void OnCapsuleParseFailure(absl::string_view error_message) = 0; + }; + + // |visitor| must be non-null, and must outlive CapsuleParser. + explicit CapsuleParser(Visitor* visitor); + + // Ingests a capsule fragment (any fragment of bytes from the capsule data + // stream) and parses and complete capsules it encounters. Returns false if a + // parsing error occurred. + bool IngestCapsuleFragment(absl::string_view capsule_fragment); + + void ErrorIfThereIsRemainingBufferedData(); + + friend class test::CapsuleParserPeer; + + private: + // Attempts to parse a single capsule from |buffered_data_|. If a full capsule + // is not available, returns 0. If a parsing error occurs, returns an error. + // Otherwise, returns the number of bytes in the parsed capsule. + absl::StatusOr AttemptParseCapsule(); + void ReportParseFailure(absl::string_view error_message); + + // Whether a parsing error has occurred. + bool parsing_error_occurred_ = false; + // Visitor which will receive callbacks, unowned. + Visitor* visitor_; + + std::string buffered_data_; +}; + +// Serializes |capsule| into a newly allocated buffer. +QUICHE_EXPORT quiche::QuicheBuffer SerializeCapsule( + const Capsule& capsule, quiche::QuicheBufferAllocator* allocator); + +} // namespace quiche + +#endif // QUICHE_COMMON_CAPSULE_H_ diff --git a/quiche/common/capsule_test.cc b/quiche/common/capsule_test.cc new file mode 100644 index 000000000000..5ed4d1a350fb --- /dev/null +++ b/quiche/common/capsule_test.cc @@ -0,0 +1,521 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/capsule.h" + +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_ip_address.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "quiche/web_transport/web_transport.h" + +using ::testing::_; +using ::testing::InSequence; +using ::testing::Return; +using ::webtransport::StreamType; + +namespace quiche { +namespace test { + +class CapsuleParserPeer { + public: + static std::string* buffered_data(CapsuleParser* capsule_parser) { + return &capsule_parser->buffered_data_; + } +}; + +namespace { + +class MockCapsuleParserVisitor : public CapsuleParser::Visitor { + public: + MockCapsuleParserVisitor() { + ON_CALL(*this, OnCapsule(_)).WillByDefault(Return(true)); + } + ~MockCapsuleParserVisitor() override = default; + MOCK_METHOD(bool, OnCapsule, (const Capsule& capsule), (override)); + MOCK_METHOD(void, OnCapsuleParseFailure, (absl::string_view error_message), + (override)); +}; + +class CapsuleTest : public QuicheTest { + public: + CapsuleTest() : capsule_parser_(&visitor_) {} + + void ValidateParserIsEmpty() { + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + EXPECT_CALL(visitor_, OnCapsuleParseFailure(_)).Times(0); + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); + } + + void TestSerialization(const Capsule& capsule, + const std::string& expected_bytes) { + quiche::QuicheBuffer serialized_capsule = + SerializeCapsule(capsule, SimpleBufferAllocator::Get()); + quiche::test::CompareCharArraysWithHexError( + "Serialized capsule", serialized_capsule.data(), + serialized_capsule.size(), expected_bytes.data(), + expected_bytes.size()); + } + + ::testing::StrictMock visitor_; + CapsuleParser capsule_parser_; +}; + +TEST_F(CapsuleTest, DatagramCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "00" // DATAGRAM capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = Capsule::Datagram(datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, LegacyDatagramCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a0" // LEGACY_DATAGRAM capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = Capsule::LegacyDatagram(datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, LegacyDatagramWithoutContextCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a5" // LEGACY_DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = + Capsule::LegacyDatagramWithoutContext(datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, CloseWebTransportStreamCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "6843" // CLOSE_WEBTRANSPORT_STREAM capsule type + "09" // capsule length + "00001234" // 0x1234 error code + "68656c6c6f" // "hello" error message + ); + Capsule expected_capsule = Capsule::CloseWebTransportSession( + /*error_code=*/0x1234, /*error_message=*/"hello"); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, AddressAssignCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "9ECA6A00" // ADDRESS_ASSIGN capsule type + "1A" // capsule length = 26 + // first assigned address + "00" // request ID = 0 + "04" // IP version = 4 + "C000022A" // 192.0.2.42 + "1F" // prefix length = 31 + // second assigned address + "01" // request ID = 1 + "06" // IP version = 6 + "20010db8123456780000000000000000" // 2001:db8:1234:5678:: + "40" // prefix length = 64 + ); + Capsule expected_capsule = Capsule::AddressAssign(); + quiche::QuicheIpAddress ip_address1; + ip_address1.FromString("192.0.2.42"); + PrefixWithId assigned_address1; + assigned_address1.request_id = 0; + assigned_address1.ip_prefix = + quiche::QuicheIpPrefix(ip_address1, /*prefix_length=*/31); + expected_capsule.address_assign_capsule().assigned_addresses.push_back( + assigned_address1); + quiche::QuicheIpAddress ip_address2; + ip_address2.FromString("2001:db8:1234:5678::"); + PrefixWithId assigned_address2; + assigned_address2.request_id = 1; + assigned_address2.ip_prefix = + quiche::QuicheIpPrefix(ip_address2, /*prefix_length=*/64); + expected_capsule.address_assign_capsule().assigned_addresses.push_back( + assigned_address2); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, AddressRequestCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "9ECA6A01" // ADDRESS_REQUEST capsule type + "1A" // capsule length = 26 + // first requested address + "00" // request ID = 0 + "04" // IP version = 4 + "C000022A" // 192.0.2.42 + "1F" // prefix length = 31 + // second requested address + "01" // request ID = 1 + "06" // IP version = 6 + "20010db8123456780000000000000000" // 2001:db8:1234:5678:: + "40" // prefix length = 64 + ); + Capsule expected_capsule = Capsule::AddressRequest(); + quiche::QuicheIpAddress ip_address1; + ip_address1.FromString("192.0.2.42"); + PrefixWithId requested_address1; + requested_address1.request_id = 0; + requested_address1.ip_prefix = + quiche::QuicheIpPrefix(ip_address1, /*prefix_length=*/31); + expected_capsule.address_request_capsule().requested_addresses.push_back( + requested_address1); + quiche::QuicheIpAddress ip_address2; + ip_address2.FromString("2001:db8:1234:5678::"); + PrefixWithId requested_address2; + requested_address2.request_id = 1; + requested_address2.ip_prefix = + quiche::QuicheIpPrefix(ip_address2, /*prefix_length=*/64); + expected_capsule.address_request_capsule().requested_addresses.push_back( + requested_address2); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, RouteAdvertisementCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "9ECA6A02" // ROUTE_ADVERTISEMENT capsule type + "2C" // capsule length = 44 + // first IP address range + "04" // IP version = 4 + "C0000218" // 192.0.2.24 + "C000022A" // 192.0.2.42 + "00" // ip protocol = 0 + // second IP address range + "06" // IP version = 6 + "00000000000000000000000000000000" // :: + "ffffffffffffffffffffffffffffffff" // all ones IPv6 address + "01" // ip protocol = 1 (ICMP) + ); + Capsule expected_capsule = Capsule::RouteAdvertisement(); + IpAddressRange ip_address_range1; + ip_address_range1.start_ip_address.FromString("192.0.2.24"); + ip_address_range1.end_ip_address.FromString("192.0.2.42"); + ip_address_range1.ip_protocol = 0; + expected_capsule.route_advertisement_capsule().ip_address_ranges.push_back( + ip_address_range1); + IpAddressRange ip_address_range2; + ip_address_range2.start_ip_address.FromString("::"); + ip_address_range2.end_ip_address.FromString( + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"); + ip_address_range2.ip_protocol = 1; + expected_capsule.route_advertisement_capsule().ip_address_ranges.push_back( + ip_address_range2); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, WebTransportStreamData) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d3b" // WT_STREAM without FIN + "04" // capsule length + "17" // stream ID + "abcdef" // stream payload + ); + Capsule expected_capsule = Capsule(WebTransportStreamDataCapsule()); + expected_capsule.web_transport_stream_data().stream_id = 0x17; + expected_capsule.web_transport_stream_data().data = "\xab\xcd\xef"; + expected_capsule.web_transport_stream_data().fin = false; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} +TEST_F(CapsuleTest, WebTransportStreamDataWithFin) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d3c" // data with FIN + "04" // capsule length + "17" // stream ID + "abcdef" // stream payload + ); + Capsule expected_capsule = Capsule(WebTransportStreamDataCapsule()); + expected_capsule.web_transport_stream_data().stream_id = 0x17; + expected_capsule.web_transport_stream_data().data = "\xab\xcd\xef"; + expected_capsule.web_transport_stream_data().fin = true; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, WebTransportResetStream) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d39" // WT_RESET_STREAM + "02" // capsule length + "17" // stream ID + "07" // error code + ); + Capsule expected_capsule = Capsule(WebTransportResetStreamCapsule()); + expected_capsule.web_transport_reset_stream().stream_id = 0x17; + expected_capsule.web_transport_reset_stream().error_code = 0x07; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, WebTransportStopSending) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d3a" // WT_STOP_SENDING + "02" // capsule length + "17" // stream ID + "07" // error code + ); + Capsule expected_capsule = Capsule(WebTransportStopSendingCapsule()); + expected_capsule.web_transport_stop_sending().stream_id = 0x17; + expected_capsule.web_transport_stop_sending().error_code = 0x07; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, WebTransportMaxStreamData) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d3e" // WT_MAX_STREAM_DATA + "02" // capsule length + "17" // stream ID + "10" // max stream data + ); + Capsule expected_capsule = Capsule(WebTransportMaxStreamDataCapsule()); + expected_capsule.web_transport_max_stream_data().stream_id = 0x17; + expected_capsule.web_transport_max_stream_data().max_stream_data = 0x10; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, WebTransportMaxStreamsBi) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d3f" // WT_MAX_STREAMS (bidi) + "01" // capsule length + "17" // max streams + ); + Capsule expected_capsule = Capsule(WebTransportMaxStreamsCapsule()); + expected_capsule.web_transport_max_streams().stream_type = + StreamType::kBidirectional; + expected_capsule.web_transport_max_streams().max_stream_count = 0x17; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, WebTransportMaxStreamsUni) { + std::string capsule_fragment = absl::HexStringToBytes( + "990b4d40" // WT_MAX_STREAMS (unidi) + "01" // capsule length + "17" // max streams + ); + Capsule expected_capsule = Capsule(WebTransportMaxStreamsCapsule()); + expected_capsule.web_transport_max_streams().stream_type = + StreamType::kUnidirectional; + expected_capsule.web_transport_max_streams().max_stream_count = 0x17; + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, UnknownCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "17" // unknown capsule type of 0x17 + "08" // capsule length + "a1a2a3a4a5a6a7a8" // unknown capsule data + ); + std::string unknown_capsule_data = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = Capsule::Unknown(0x17, unknown_capsule_data); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, TwoCapsules) { + std::string capsule_fragment = absl::HexStringToBytes( + "00" // DATAGRAM capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + "00" // DATAGRAM capsule type + "08" // capsule length + "b1b2b3b4b5b6b7b8" // HTTP Datagram payload + ); + std::string datagram_payload1 = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + std::string datagram_payload2 = absl::HexStringToBytes("b1b2b3b4b5b6b7b8"); + Capsule expected_capsule1 = Capsule::Datagram(datagram_payload1); + Capsule expected_capsule2 = Capsule::Datagram(datagram_payload2); + { + InSequence s; + EXPECT_CALL(visitor_, OnCapsule(expected_capsule1)); + EXPECT_CALL(visitor_, OnCapsule(expected_capsule2)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); +} + +TEST_F(CapsuleTest, TwoCapsulesPartialReads) { + std::string capsule_fragment1 = absl::HexStringToBytes( + "00" // first capsule DATAGRAM capsule type + "08" // first capsule length + "a1a2a3a4" // first half of HTTP Datagram payload of first capsule + ); + std::string capsule_fragment2 = absl::HexStringToBytes( + "a5a6a7a8" // second half of HTTP Datagram payload 1 + "00" // second capsule DATAGRAM capsule type + ); + std::string capsule_fragment3 = absl::HexStringToBytes( + "08" // second capsule length + "b1b2b3b4b5b6b7b8" // HTTP Datagram payload of second capsule + ); + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + std::string datagram_payload1 = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + std::string datagram_payload2 = absl::HexStringToBytes("b1b2b3b4b5b6b7b8"); + Capsule expected_capsule1 = Capsule::Datagram(datagram_payload1); + Capsule expected_capsule2 = Capsule::Datagram(datagram_payload2); + { + InSequence s; + EXPECT_CALL(visitor_, OnCapsule(expected_capsule1)); + EXPECT_CALL(visitor_, OnCapsule(expected_capsule2)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment1)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment2)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment3)); + } + ValidateParserIsEmpty(); +} + +TEST_F(CapsuleTest, TwoCapsulesOneByteAtATime) { + std::string capsule_fragment = absl::HexStringToBytes( + "00" // DATAGRAM capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + "00" // DATAGRAM capsule type + "08" // capsule length + "b1b2b3b4b5b6b7b8" // HTTP Datagram payload + ); + std::string datagram_payload1 = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + std::string datagram_payload2 = absl::HexStringToBytes("b1b2b3b4b5b6b7b8"); + Capsule expected_capsule1 = Capsule::Datagram(datagram_payload1); + Capsule expected_capsule2 = Capsule::Datagram(datagram_payload2); + for (size_t i = 0; i < capsule_fragment.size(); i++) { + if (i < capsule_fragment.size() / 2 - 1) { + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + } else if (i == capsule_fragment.size() / 2 - 1) { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule1)); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); + } else if (i < capsule_fragment.size() - 1) { + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + } else { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule2)); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); + } + } + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); +} + +TEST_F(CapsuleTest, PartialCapsuleThenError) { + std::string capsule_fragment = absl::HexStringToBytes( + "00" // DATAGRAM capsule type + "08" // capsule length + "a1a2a3a4" // first half of HTTP Datagram payload + ); + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + { + EXPECT_CALL(visitor_, OnCapsuleParseFailure(_)).Times(0); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + { + EXPECT_CALL(visitor_, + OnCapsuleParseFailure( + "Incomplete capsule left at the end of the stream")); + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + } +} + +TEST_F(CapsuleTest, RejectOverlyLongCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "17" // unknown capsule type of 0x17 + "80123456" // capsule length + ) + + std::string(1111111, '?'); + EXPECT_CALL(visitor_, OnCapsuleParseFailure( + "Refusing to buffer too much capsule data")); + EXPECT_FALSE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/masque/connect_udp_datagram_payload.cc b/quiche/common/masque/connect_udp_datagram_payload.cc new file mode 100644 index 000000000000..ae01817eeaf0 --- /dev/null +++ b/quiche/common/masque/connect_udp_datagram_payload.cc @@ -0,0 +1,130 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/masque/connect_udp_datagram_payload.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_data_writer.h" + +namespace quiche { + +// static +std::unique_ptr ConnectUdpDatagramPayload::Parse( + absl::string_view datagram_payload) { + QuicheDataReader data_reader(datagram_payload); + + uint64_t context_id; + if (!data_reader.ReadVarInt62(&context_id)) { + QUICHE_DVLOG(1) << "Could not parse malformed UDP proxy payload"; + return nullptr; + } + + if (ContextId{context_id} == ConnectUdpDatagramUdpPacketPayload::kContextId) { + return std::make_unique( + data_reader.ReadRemainingPayload()); + } else { + return std::make_unique( + ContextId{context_id}, data_reader.ReadRemainingPayload()); + } +} + +std::string ConnectUdpDatagramPayload::Serialize() const { + std::string buffer(SerializedLength(), '\0'); + QuicheDataWriter writer(buffer.size(), buffer.data()); + + bool result = SerializeTo(writer); + QUICHE_DCHECK(result); + QUICHE_DCHECK_EQ(writer.remaining(), 0u); + + return buffer; +} + +ConnectUdpDatagramUdpPacketPayload::ConnectUdpDatagramUdpPacketPayload( + absl::string_view udp_packet) + : udp_packet_(udp_packet) {} + +ConnectUdpDatagramPayload::ContextId +ConnectUdpDatagramUdpPacketPayload::GetContextId() const { + return kContextId; +} + +ConnectUdpDatagramPayload::Type ConnectUdpDatagramUdpPacketPayload::GetType() + const { + return Type::kUdpPacket; +} + +absl::string_view ConnectUdpDatagramUdpPacketPayload::GetUdpProxyingPayload() + const { + return udp_packet_; +} + +size_t ConnectUdpDatagramUdpPacketPayload::SerializedLength() const { + return udp_packet_.size() + + QuicheDataWriter::GetVarInt62Len(uint64_t{kContextId}); +} + +bool ConnectUdpDatagramUdpPacketPayload::SerializeTo( + QuicheDataWriter& writer) const { + if (!writer.WriteVarInt62(uint64_t{kContextId})) { + return false; + } + + if (!writer.WriteStringPiece(udp_packet_)) { + return false; + } + + return true; +} + +ConnectUdpDatagramUnknownPayload::ConnectUdpDatagramUnknownPayload( + ContextId context_id, absl::string_view udp_proxying_payload) + : context_id_(context_id), udp_proxying_payload_(udp_proxying_payload) { + if (context_id == ConnectUdpDatagramUdpPacketPayload::kContextId) { + QUICHE_BUG(udp_proxy_unknown_payload_udp_context) + << "ConnectUdpDatagramUnknownPayload created with UDP packet context " + "type (0). Should instead create a " + "ConnectUdpDatagramUdpPacketPayload."; + } +} + +ConnectUdpDatagramPayload::ContextId +ConnectUdpDatagramUnknownPayload::GetContextId() const { + return context_id_; +} + +ConnectUdpDatagramPayload::Type ConnectUdpDatagramUnknownPayload::GetType() + const { + return Type::kUnknown; +} +absl::string_view ConnectUdpDatagramUnknownPayload::GetUdpProxyingPayload() + const { + return udp_proxying_payload_; +} + +size_t ConnectUdpDatagramUnknownPayload::SerializedLength() const { + return udp_proxying_payload_.size() + + QuicheDataWriter::GetVarInt62Len(uint64_t{context_id_}); +} + +bool ConnectUdpDatagramUnknownPayload::SerializeTo( + QuicheDataWriter& writer) const { + if (!writer.WriteVarInt62(uint64_t{context_id_})) { + return false; + } + + if (!writer.WriteStringPiece(udp_proxying_payload_)) { + return false; + } + + return true; +} + +} // namespace quiche diff --git a/quiche/common/masque/connect_udp_datagram_payload.h b/quiche/common/masque/connect_udp_datagram_payload.h new file mode 100644 index 000000000000..dba4e3db60ac --- /dev/null +++ b/quiche/common/masque/connect_udp_datagram_payload.h @@ -0,0 +1,99 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_MASQUE_CONNECT_UDP_DATAGRAM_PAYLOAD_H_ +#define QUICHE_COMMON_MASQUE_CONNECT_UDP_DATAGRAM_PAYLOAD_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_data_writer.h" + +namespace quiche { + +// UDP-proxying HTTP Datagram payload for use with CONNECT-UDP. See RFC 9298, +// Section 5. +class ConnectUdpDatagramPayload { + public: + using ContextId = uint64_t; + enum class Type { kUdpPacket, kUnknown }; + + // Parse from `datagram_payload` (a wire-format UDP-proxying HTTP datagram + // payload). Returns nullptr on error. The created ConnectUdpDatagramPayload + // object may use absl::string_views pointing into `datagram_payload`, so the + // data pointed to by `datagram_payload` must outlive the created + // ConnectUdpDatagramPayload object. + static std::unique_ptr Parse( + absl::string_view datagram_payload); + + ConnectUdpDatagramPayload() = default; + + ConnectUdpDatagramPayload(const ConnectUdpDatagramPayload&) = delete; + ConnectUdpDatagramPayload& operator=(const ConnectUdpDatagramPayload&) = + delete; + + virtual ~ConnectUdpDatagramPayload() = default; + + virtual ContextId GetContextId() const = 0; + virtual Type GetType() const = 0; + // Get the inner payload (the UDP Proxying Payload). + virtual absl::string_view GetUdpProxyingPayload() const = 0; + + // Length of this UDP-proxying HTTP datagram payload in wire format. + virtual size_t SerializedLength() const = 0; + // Write a wire-format buffer for the payload. Returns false on write failure + // (typically due to `writer` buffer being full). + virtual bool SerializeTo(QuicheDataWriter& writer) const = 0; + + // Write a wire-format buffer. + std::string Serialize() const; +}; + +// UDP-proxying HTTP Datagram payload that encodes a UDP packet. +class ConnectUdpDatagramUdpPacketPayload final + : public ConnectUdpDatagramPayload { + public: + static constexpr ContextId kContextId = 0; + + // The string pointed to by `udp_packet` must outlive the created + // ConnectUdpDatagramUdpPacketPayload. + explicit ConnectUdpDatagramUdpPacketPayload(absl::string_view udp_packet); + + ContextId GetContextId() const override; + Type GetType() const override; + absl::string_view GetUdpProxyingPayload() const override; + size_t SerializedLength() const override; + bool SerializeTo(QuicheDataWriter& writer) const override; + + absl::string_view udp_packet() const { return udp_packet_; } + + private: + absl::string_view udp_packet_; +}; + +class ConnectUdpDatagramUnknownPayload final + : public ConnectUdpDatagramPayload { + public: + // `udp_proxying_payload` represents the inner payload contained by the UDP- + // proxying HTTP datagram payload. The string pointed to by `inner_payload` + // must outlive the created ConnectUdpDatagramUnknownPayload. + ConnectUdpDatagramUnknownPayload(ContextId context_id, + absl::string_view udp_proxying_payload); + + ContextId GetContextId() const override; + Type GetType() const override; + absl::string_view GetUdpProxyingPayload() const override; + size_t SerializedLength() const override; + bool SerializeTo(QuicheDataWriter& writer) const override; + + private: + ContextId context_id_; + absl::string_view udp_proxying_payload_; // The inner payload. +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_MASQUE_CONNECT_UDP_DATAGRAM_PAYLOAD_H_ diff --git a/quiche/common/masque/connect_udp_datagram_payload_test.cc b/quiche/common/masque/connect_udp_datagram_payload_test.cc new file mode 100644 index 000000000000..52503e3739e1 --- /dev/null +++ b/quiche/common/masque/connect_udp_datagram_payload_test.cc @@ -0,0 +1,61 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/masque/connect_udp_datagram_payload.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche::test { +namespace { + +TEST(ConnectUdpDatagramPayloadTest, ParseUdpPacket) { + static constexpr char kDatagramPayload[] = "\x00packet"; + + std::unique_ptr parsed = + ConnectUdpDatagramPayload::Parse( + absl::string_view(kDatagramPayload, sizeof(kDatagramPayload) - 1)); + ASSERT_TRUE(parsed); + + EXPECT_EQ(parsed->GetContextId(), + ConnectUdpDatagramUdpPacketPayload::kContextId); + EXPECT_EQ(parsed->GetType(), ConnectUdpDatagramPayload::Type::kUdpPacket); + EXPECT_EQ(parsed->GetUdpProxyingPayload(), "packet"); +} + +TEST(ConnectUdpDatagramPayloadTest, SerializeUdpPacket) { + static constexpr absl::string_view kUdpPacket = "packet"; + + ConnectUdpDatagramUdpPacketPayload payload(kUdpPacket); + EXPECT_EQ(payload.GetUdpProxyingPayload(), kUdpPacket); + + EXPECT_EQ(payload.Serialize(), std::string("\x00packet", 7)); +} + +TEST(ConnectUdpDatagramPayloadTest, ParseUnknownPacket) { + static constexpr char kDatagramPayload[] = "\x05packet"; + + std::unique_ptr parsed = + ConnectUdpDatagramPayload::Parse( + absl::string_view(kDatagramPayload, sizeof(kDatagramPayload) - 1)); + ASSERT_TRUE(parsed); + + EXPECT_EQ(parsed->GetContextId(), 5); + EXPECT_EQ(parsed->GetType(), ConnectUdpDatagramPayload::Type::kUnknown); + EXPECT_EQ(parsed->GetUdpProxyingPayload(), "packet"); +} + +TEST(ConnectUdpDatagramPayloadTest, SerializeUnknownPacket) { + static constexpr absl::string_view kInnerUdpProxyingPayload = "packet"; + + ConnectUdpDatagramUnknownPayload payload(4u, kInnerUdpProxyingPayload); + EXPECT_EQ(payload.GetUdpProxyingPayload(), kInnerUdpProxyingPayload); + + EXPECT_EQ(payload.Serialize(), std::string("\x04packet", 7)); +} + +} // namespace +} // namespace quiche::test diff --git a/quiche/common/platform/api/quiche_bug_tracker.h b/quiche/common/platform/api/quiche_bug_tracker.h new file mode 100644 index 000000000000..27ac0d209480 --- /dev/null +++ b/quiche/common/platform/api/quiche_bug_tracker.h @@ -0,0 +1,15 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_BUG_TRACKER_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_BUG_TRACKER_H_ + +#include "quiche_platform_impl/quiche_bug_tracker_impl.h" + +#define QUICHE_BUG QUICHE_BUG_IMPL +#define QUICHE_BUG_IF QUICHE_BUG_IF_IMPL +#define QUICHE_PEER_BUG QUICHE_PEER_BUG_IMPL +#define QUICHE_PEER_BUG_IF QUICHE_PEER_BUG_IF_IMPL + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_BUG_TRACKER_H_ diff --git a/quiche/common/platform/api/quiche_client_stats.h b/quiche/common/platform/api/quiche_client_stats.h new file mode 100644 index 000000000000..5b1b08c2a673 --- /dev/null +++ b/quiche/common/platform/api/quiche_client_stats.h @@ -0,0 +1,88 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_CLIENT_STATS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_CLIENT_STATS_H_ + +#include + +#include "quiche_platform_impl/quiche_client_stats_impl.h" + +namespace quiche { + +//------------------------------------------------------------------------------ +// Enumeration histograms. +// +// Sample usage: +// // In Chrome, these values are persisted to logs. Entries should not be +// // renumbered and numeric values should never be reused. +// enum class MyEnum { +// FIRST_VALUE = 0, +// SECOND_VALUE = 1, +// ... +// FINAL_VALUE = N, +// COUNT +// }; +// QUICHE_CLIENT_HISTOGRAM_ENUM("My.Enumeration", MyEnum::SOME_VALUE, +// MyEnum::COUNT, "Number of time $foo equals to some enum value"); +// +// Note: The value in |sample| must be strictly less than |enum_size|. + +#define QUICHE_CLIENT_HISTOGRAM_ENUM(name, sample, enum_size, docstring) \ + QUICHE_CLIENT_HISTOGRAM_ENUM_IMPL(name, sample, enum_size, docstring) + +//------------------------------------------------------------------------------ +// Histogram for boolean values. + +// Sample usage: +// QUICHE_CLIENT_HISTOGRAM_BOOL("My.Boolean", bool, +// "Number of times $foo is true or false"); +#define QUICHE_CLIENT_HISTOGRAM_BOOL(name, sample, docstring) \ + QUICHE_CLIENT_HISTOGRAM_BOOL_IMPL(name, sample, docstring) + +//------------------------------------------------------------------------------ +// Timing histograms. These are used for collecting timing data (generally +// latencies). + +// These macros create exponentially sized histograms (lengths of the bucket +// ranges exponentially increase as the sample range increases). The units for +// sample and max are unspecified, but they must be the same for one histogram. + +// Sample usage: +// QUICHE_CLIENT_HISTOGRAM_TIMES("Very.Long.Timing.Histogram", time_delta, +// QuicTime::Delta::FromSeconds(1), QuicTime::Delta::FromSecond(3600 * +// 24), 100, "Time spent in doing operation."); +#define QUICHE_CLIENT_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_CLIENT_HISTOGRAM_TIMES_IMPL(name, sample, min, max, bucket_count, \ + docstring) + +//------------------------------------------------------------------------------ +// Count histograms. These are used for collecting numeric data. + +// These macros default to exponential histograms - i.e. the lengths of the +// bucket ranges exponentially increase as the sample range increases. + +// All of these macros must be called with |name| as a runtime constant. + +// Any data outside the range here will be put in underflow and overflow +// buckets. Min values should be >=1 as emitted 0s will still go into the +// underflow bucket. + +// Sample usage: +// UMA_CLIENT_HISTOGRAM_CUSTOM_COUNTS("My.Histogram", 1, 100000000, 100, +// "Counters of hitting certain code."); + +#define QUICHE_CLIENT_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_CLIENT_HISTOGRAM_COUNTS_IMPL(name, sample, min, max, bucket_count, \ + docstring) + +inline void QuicheClientSparseHistogram(const std::string& name, int sample) { + QuicheClientSparseHistogramImpl(name, sample); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_CLIENT_STATS_H_ diff --git a/quiche/common/platform/api/quiche_command_line_flags.h b/quiche/common/platform/api/quiche_command_line_flags.h new file mode 100644 index 000000000000..57caeb33b978 --- /dev/null +++ b/quiche/common/platform/api/quiche_command_line_flags.h @@ -0,0 +1,43 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_COMMAND_LINE_FLAGS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_COMMAND_LINE_FLAGS_H_ + +#include +#include + +#include "quiche_platform_impl/quiche_command_line_flags_impl.h" + +// Define a command-line flag that can be automatically set via +// QuicheParseCommandLineFlags(). The macro has to be called in the .cc file of +// a unit test or the CLI tool reading the flag. +#define DEFINE_QUICHE_COMMAND_LINE_FLAG(type, name, default_value, help) \ + DEFINE_QUICHE_COMMAND_LINE_FLAG_IMPL(type, name, default_value, help) + +namespace quiche { + +// The impl header must provide GetQuicheCommandLineFlag(), which takes +// PlatformSpecificFlag variable defined by the macro above, and returns the +// flag value of type T. + +// Parses command-line flags, setting flag variables defined using +// DEFINE_QUICHE_COMMAND_LINE_FLAG if they appear in the command line, and +// returning a list of any non-flag arguments specified in the command line. If +// the command line specifies '-h' or '--help', prints a usage message with flag +// descriptions to stdout and exits with status 0. If a flag has an unparsable +// value, writes an error message to stderr and exits with status 1. +inline std::vector QuicheParseCommandLineFlags( + const char* usage, int argc, const char* const* argv) { + return QuicheParseCommandLineFlagsImpl(usage, argc, argv); +} + +// Prints a usage message with flag descriptions to stdout. +inline void QuichePrintCommandLineFlagHelp(const char* usage) { + QuichePrintCommandLineFlagHelpImpl(usage); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_COMMAND_LINE_FLAGS_H_ diff --git a/quiche/common/platform/api/quiche_containers.h b/quiche/common/platform/api/quiche_containers.h new file mode 100644 index 000000000000..b3b929ab04d8 --- /dev/null +++ b/quiche/common/platform/api/quiche_containers.h @@ -0,0 +1,22 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_CONTAINERS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_CONTAINERS_H_ + +#include "quiche_platform_impl/quiche_containers_impl.h" + +namespace quiche { + +// An ordered container optimized for small sets. +// An implementation with O(n) mutations might be chosen +// in case it has better memory usage and/or faster access. +// +// DOES NOT GUARANTEE POINTER OR ITERATOR STABILITY! +template > +using QuicheSmallOrderedSet = QuicheSmallOrderedSetImpl; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_CONTAINERS_H_ diff --git a/quiche/common/platform/api/quiche_default_proof_providers.h b/quiche/common/platform/api/quiche_default_proof_providers.h new file mode 100644 index 000000000000..9d5522e7f937 --- /dev/null +++ b/quiche/common/platform/api/quiche_default_proof_providers.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_DEFAULT_PROOF_PROVIDERS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_DEFAULT_PROOF_PROVIDERS_H_ + +#include + +#include "quiche_platform_impl/quiche_default_proof_providers_impl.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/proof_verifier.h" + +namespace quiche { + +// Provides a default proof verifier that can verify a cert chain for |host|. +// The verifier has to do a good faith attempt at verifying the certificate +// against a reasonable root store, and not just always return success. +inline std::unique_ptr CreateDefaultProofVerifier( + const std::string& host) { + return CreateDefaultProofVerifierImpl(host); +} + +// Provides a default proof source for CLI-based tools. The actual certificates +// used in the proof source should be confifgurable via command-line flags. +inline std::unique_ptr CreateDefaultProofSource() { + return CreateDefaultProofSourceImpl(); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_DEFAULT_PROOF_PROVIDERS_H_ diff --git a/quiche/common/platform/api/quiche_event_loop.h b/quiche/common/platform/api/quiche_event_loop.h new file mode 100644 index 000000000000..fd5bde0f8839 --- /dev/null +++ b/quiche/common/platform/api/quiche_event_loop.h @@ -0,0 +1,27 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUIC_EVENT_LOOP_H_ +#define QUICHE_COMMON_PLATFORM_API_QUIC_EVENT_LOOP_H_ + +#include "quiche_platform_impl/quiche_event_loop_impl.h" + +namespace quic { +class QuicEventLoopFactory; +} + +namespace quiche { + +inline quic::QuicEventLoopFactory* GetOverrideForDefaultEventLoop() { + return GetOverrideForDefaultEventLoopImpl(); +} + +inline std::vector +GetExtraEventLoopImplementations() { + return GetExtraEventLoopImplementationsImpl(); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUIC_EVENT_LOOP_H_ diff --git a/quiche/common/platform/api/quiche_expect_bug.h b/quiche/common/platform/api/quiche_expect_bug.h new file mode 100644 index 000000000000..02ba8d1c2a5d --- /dev/null +++ b/quiche/common/platform/api/quiche_expect_bug.h @@ -0,0 +1,14 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_EXPECT_BUG_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_EXPECT_BUG_H_ + +#include "quiche_platform_impl/quiche_expect_bug_impl.h" + +#define EXPECT_QUICHE_BUG EXPECT_QUICHE_BUG_IMPL +#define EXPECT_QUICHE_PEER_BUG(statement, regex) \ + EXPECT_QUICHE_PEER_BUG_IMPL(statement, regex) + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_EXPECT_BUG_H_ diff --git a/quiche/common/platform/api/quiche_export.h b/quiche/common/platform/api/quiche_export.h new file mode 100644 index 000000000000..3f11ccc44409 --- /dev/null +++ b/quiche/common/platform/api/quiche_export.h @@ -0,0 +1,19 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_EXPORT_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_EXPORT_H_ + +#include "quiche_platform_impl/quiche_export_impl.h" + +// QUICHE_EXPORT is meant for QUICHE functionality that is built in +// Chromium as part of //net/third_party/quiche component, and not fully +// contained in headers. It is required for Windows DLL builds to work. +#define QUICHE_EXPORT QUICHE_EXPORT_IMPL + +// QUICHE_NO_EXPORT is meant for QUICHE functionality that is either fully +// defined in a header, or is built in Chromium as part of tests or tools. +#define QUICHE_NO_EXPORT QUICHE_NO_EXPORT_IMPL + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_EXPORT_H_ diff --git a/quiche/common/platform/api/quiche_file_utils.cc b/quiche/common/platform/api/quiche_file_utils.cc new file mode 100644 index 000000000000..6453ea24f47c --- /dev/null +++ b/quiche/common/platform/api/quiche_file_utils.cc @@ -0,0 +1,51 @@ +#include "quiche/common/platform/api/quiche_file_utils.h" + +#include "quiche_platform_impl/quiche_file_utils_impl.h" + +namespace quiche { + +std::string JoinPath(absl::string_view a, absl::string_view b) { + return JoinPathImpl(a, b); +} + +absl::optional ReadFileContents(absl::string_view file) { + return ReadFileContentsImpl(file); +} + +bool EnumerateDirectory(absl::string_view path, + std::vector& directories, + std::vector& files) { + return EnumerateDirectoryImpl(path, directories, files); +} + +bool EnumerateDirectoryRecursivelyInner(absl::string_view path, + int recursion_limit, + std::vector& files) { + if (recursion_limit < 0) { + return false; + } + + std::vector local_files; + std::vector directories; + if (!EnumerateDirectory(path, directories, local_files)) { + return false; + } + for (const std::string& directory : directories) { + if (!EnumerateDirectoryRecursivelyInner(JoinPath(path, directory), + recursion_limit - 1, files)) { + return false; + } + } + for (const std::string& file : local_files) { + files.push_back(JoinPath(path, file)); + } + return true; +} + +bool EnumerateDirectoryRecursively(absl::string_view path, + std::vector& files) { + constexpr int kRecursionLimit = 20; + return EnumerateDirectoryRecursivelyInner(path, kRecursionLimit, files); +} + +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_file_utils.h b/quiche/common/platform/api/quiche_file_utils.h new file mode 100644 index 000000000000..47723d19e925 --- /dev/null +++ b/quiche/common/platform/api/quiche_file_utils.h @@ -0,0 +1,40 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This header contains basic filesystem functions for use in unit tests and CLI +// tools. Note that those are not 100% suitable for production use, as in, they +// might be prone to race conditions and not always handle non-ASCII filenames +// correctly. +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_FILE_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_FILE_UTILS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace quiche { + +// Join two paths in a platform-specific way. Returns |a| if |b| is empty, and +// vice versa. +std::string JoinPath(absl::string_view a, absl::string_view b); + +// Reads the entire file into the memory. +absl::optional ReadFileContents(absl::string_view file); + +// Lists all files and directories in the directory specified by |path|. Returns +// true on success, false on failure. +bool EnumerateDirectory(absl::string_view path, + std::vector& directories, + std::vector& files); + +// Recursively enumerates all of the files in the directory and all of the +// internal subdirectories. Has a fairly small recursion limit. +bool EnumerateDirectoryRecursively(absl::string_view path, + std::vector& files); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_FILE_UTILS_H_ diff --git a/quiche/common/platform/api/quiche_file_utils_test.cc b/quiche/common/platform/api/quiche_file_utils_test.cc new file mode 100644 index 000000000000..68387a421cc4 --- /dev/null +++ b/quiche/common/platform/api/quiche_file_utils_test.cc @@ -0,0 +1,86 @@ +#include "quiche/common/platform/api/quiche_file_utils.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { +namespace { + +using testing::UnorderedElementsAre; +using testing::UnorderedElementsAreArray; + +TEST(QuicheFileUtilsTest, ReadFileContents) { + std::string path = absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/testfile"); + absl::optional contents = ReadFileContents(path); + ASSERT_TRUE(contents.has_value()); + EXPECT_EQ(*contents, "This is a test file."); +} + +TEST(QuicheFileUtilsTest, ReadFileContentsFileNotFound) { + std::string path = + absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/file-that-does-not-exist"); + absl::optional contents = ReadFileContents(path); + EXPECT_FALSE(contents.has_value()); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectory) { + std::string path = + absl::StrCat(QuicheGetCommonSourcePath(), "/platform/api/testdir"); + std::vector dirs; + std::vector files; + bool success = EnumerateDirectory(path, dirs, files); + EXPECT_TRUE(success); + EXPECT_THAT(files, UnorderedElementsAre("testfile", "README.md")); + EXPECT_THAT(dirs, UnorderedElementsAre("a")); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectoryNoSuchDirectory) { + std::string path = absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/no-such-directory"); + std::vector dirs; + std::vector files; + bool success = EnumerateDirectory(path, dirs, files); + EXPECT_FALSE(success); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectoryNotADirectory) { + std::string path = absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/testfile"); + std::vector dirs; + std::vector files; + bool success = EnumerateDirectory(path, dirs, files); + EXPECT_FALSE(success); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectoryRecursively) { + std::vector expected_paths = {"a/b/c/d/e", "a/subdir/testfile", + "a/z", "testfile", "README.md"}; + + std::string root_path = + absl::StrCat(QuicheGetCommonSourcePath(), "/platform/api/testdir"); + for (std::string& path : expected_paths) { + // For Windows, use Windows path separators. + if (JoinPath("a", "b") == "a\\b") { + absl::c_replace(path, '/', '\\'); + } + + path = JoinPath(root_path, path); + } + + std::vector files; + bool success = EnumerateDirectoryRecursively(root_path, files); + EXPECT_TRUE(success); + EXPECT_THAT(files, UnorderedElementsAreArray(expected_paths)); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_flag_utils.h b/quiche/common/platform/api/quiche_flag_utils.h new file mode 100644 index 000000000000..fcd66231e7d3 --- /dev/null +++ b/quiche/common/platform/api/quiche_flag_utils.h @@ -0,0 +1,19 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_FLAG_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_FLAG_UTILS_H_ + +#include "quiche_platform_impl/quiche_flag_utils_impl.h" + +#define QUICHE_RELOADABLE_FLAG_COUNT QUICHE_RELOADABLE_FLAG_COUNT_IMPL +#define QUICHE_RELOADABLE_FLAG_COUNT_N QUICHE_RELOADABLE_FLAG_COUNT_N_IMPL + +#define QUICHE_RESTART_FLAG_COUNT QUICHE_RESTART_FLAG_COUNT_IMPL +#define QUICHE_RESTART_FLAG_COUNT_N QUICHE_RESTART_FLAG_COUNT_N_IMPL + +#define QUICHE_CODE_COUNT QUICHE_CODE_COUNT_IMPL +#define QUICHE_CODE_COUNT_N QUICHE_CODE_COUNT_N_IMPL + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_FLAG_UTILS_H_ diff --git a/quiche/common/platform/api/quiche_flags.h b/quiche/common/platform/api/quiche_flags.h new file mode 100644 index 000000000000..5fb23bf709dd --- /dev/null +++ b/quiche/common/platform/api/quiche_flags.h @@ -0,0 +1,28 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_FLAGS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_FLAGS_H_ + +#include "quiche_platform_impl/quiche_flags_impl.h" + +// Flags accessed via GetQuicheReloadableFlag/GetQuicheRestartFlag are temporary +// boolean flags that are used to enable or disable code behavior changes. The +// current list is available in the quiche/quic/core/quic_flags_list.h file. +#define GetQuicheReloadableFlag(module, flag) \ + GetQuicheReloadableFlagImpl(module, flag) +#define SetQuicheReloadableFlag(module, flag, value) \ + SetQuicheReloadableFlagImpl(module, flag, value) +#define GetQuicheRestartFlag(module, flag) \ + GetQuicheRestartFlagImpl(module, flag) +#define SetQuicheRestartFlag(module, flag, value) \ + SetQuicheRestartFlagImpl(module, flag, value) + +// Flags accessed via GetQuicheFlag are permanent flags used to control QUICHE +// library behavior. The current list is available in the +// quiche/quic/core/quic_protocol_flags_list.h file. +#define GetQuicheFlag(flag) GetQuicheFlagImpl(flag) +#define SetQuicheFlag(flag, value) SetQuicheFlagImpl(flag, value) + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_FLAGS_H_ diff --git a/quiche/common/platform/api/quiche_header_policy.h b/quiche/common/platform/api/quiche_header_policy.h new file mode 100644 index 000000000000..4562778086cf --- /dev/null +++ b/quiche/common/platform/api/quiche_header_policy.h @@ -0,0 +1,20 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_HEADER_POLICY_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_HEADER_POLICY_H_ + +#include "quiche_platform_impl/quiche_header_policy_impl.h" +#include "absl/strings/string_view.h" + +namespace quiche { + +// Invoke some platform-specific action based on header key. +inline void QuicheHandleHeaderPolicy(absl::string_view key) { + QuicheHandleHeaderPolicyImpl(key); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_HEADER_POLICY_H_ diff --git a/quiche/common/platform/api/quiche_hostname_utils.cc b/quiche/common/platform/api/quiche_hostname_utils.cc new file mode 100644 index 000000000000..19ac83e61cc0 --- /dev/null +++ b/quiche/common/platform/api/quiche_hostname_utils.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_hostname_utils.h" + +#include + +#include "absl/strings/string_view.h" +#include "url/url_canon.h" +#include "url/url_canon_stdstring.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +// TODO(vasilvv): the functions below are forked from Chromium's +// net/base/url_util.h; those should be moved to googleurl. +namespace { + +std::string CanonicalizeHost(absl::string_view host, + url::CanonHostInfo* host_info) { + // Try to canonicalize the host. + const url::Component raw_host_component(0, static_cast(host.length())); + std::string canon_host; + url::StdStringCanonOutput canon_host_output(&canon_host); + url::CanonicalizeHostVerbose(host.data(), raw_host_component, + &canon_host_output, host_info); + + if (host_info->out_host.is_nonempty() && + host_info->family != url::CanonHostInfo::BROKEN) { + // Success! Assert that there's no extra garbage. + canon_host_output.Complete(); + QUICHE_DCHECK_EQ(host_info->out_host.len, + static_cast(canon_host.length())); + } else { + // Empty host, or canonicalization failed. We'll return empty. + canon_host.clear(); + } + + return canon_host; +} + +bool IsHostCharAlphanumeric(char c) { + // We can just check lowercase because uppercase characters have already been + // normalized. + return ((c >= 'a') && (c <= 'z')) || ((c >= '0') && (c <= '9')); +} + +bool IsCanonicalizedHostCompliant(const std::string& host) { + if (host.empty()) { + return false; + } + + bool in_component = false; + bool most_recent_component_started_alphanumeric = false; + + for (char c : host) { + if (!in_component) { + most_recent_component_started_alphanumeric = IsHostCharAlphanumeric(c); + if (!most_recent_component_started_alphanumeric && (c != '-') && + (c != '_')) { + return false; + } + in_component = true; + } else if (c == '.') { + in_component = false; + } else if (!IsHostCharAlphanumeric(c) && (c != '-') && (c != '_')) { + return false; + } + } + + return most_recent_component_started_alphanumeric; +} + +} // namespace + +// static +bool QuicheHostnameUtils::IsValidSNI(absl::string_view sni) { + // TODO(rtenneti): Support RFC2396 hostname. + // NOTE: Microsoft does NOT enforce this spec, so if we throw away hostnames + // based on the above spec, we may be losing some hostnames that windows + // would consider valid. By far the most common hostname character NOT + // accepted by the above spec is '_'. + url::CanonHostInfo host_info; + std::string canonicalized_host = CanonicalizeHost(sni, &host_info); + return !host_info.IsIPAddress() && + IsCanonicalizedHostCompliant(canonicalized_host); +} + +// static +std::string QuicheHostnameUtils::NormalizeHostname(absl::string_view hostname) { + url::CanonHostInfo host_info; + std::string host = CanonicalizeHost(hostname, &host_info); + + // Walk backwards over the string, stopping at the first trailing dot. + size_t host_end = host.length(); + while (host_end != 0 && host[host_end - 1] == '.') { + host_end--; + } + + // Erase the trailing dots. + if (host_end != host.length()) { + host.erase(host_end, host.length() - host_end); + } + + return host; +} + +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_hostname_utils.h b/quiche/common/platform/api/quiche_hostname_utils.h new file mode 100644 index 000000000000..10aa3991f1b4 --- /dev/null +++ b/quiche/common/platform/api/quiche_hostname_utils.h @@ -0,0 +1,33 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_HOSTNAME_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_HOSTNAME_UTILS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +class QUICHE_EXPORT QuicheHostnameUtils { + public: + QuicheHostnameUtils() = delete; + + // Returns true if the sni is valid, false otherwise. + // (1) disallow IP addresses; + // (2) check that the hostname contains valid characters only; and + // (3) contains at least one dot. + static bool IsValidSNI(absl::string_view sni); + + // Canonicalizes the specified hostname. This involves a wide variety of + // transformations, including lowercasing, removing trailing dots and IDNA + // conversion. + static std::string NormalizeHostname(absl::string_view hostname); +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_HOSTNAME_UTILS_H_ diff --git a/quiche/common/platform/api/quiche_hostname_utils_test.cc b/quiche/common/platform/api/quiche_hostname_utils_test.cc new file mode 100644 index 000000000000..59a38b320890 --- /dev/null +++ b/quiche/common/platform/api/quiche_hostname_utils_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_hostname_utils.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche { +namespace test { +namespace { + +class QuicheHostnameUtilsTest : public QuicheTest {}; + +TEST_F(QuicheHostnameUtilsTest, IsValidSNI) { + // IP as SNI. + EXPECT_FALSE(QuicheHostnameUtils::IsValidSNI("192.168.0.1")); + // SNI without any dot. + EXPECT_TRUE(QuicheHostnameUtils::IsValidSNI("somedomain")); + // Invalid by RFC2396 but unfortunately domains of this form exist. + EXPECT_TRUE(QuicheHostnameUtils::IsValidSNI("some_domain.com")); + // An empty string must be invalid otherwise the QUIC client will try sending + // it. + EXPECT_FALSE(QuicheHostnameUtils::IsValidSNI("")); + + // Valid SNI + EXPECT_TRUE(QuicheHostnameUtils::IsValidSNI("test.google.com")); +} + +TEST_F(QuicheHostnameUtilsTest, NormalizeHostname) { + // clang-format off + struct { + const char *input, *expected; + } tests[] = { + { + "www.google.com", + "www.google.com", + }, + { + "WWW.GOOGLE.COM", + "www.google.com", + }, + { + "www.google.com.", + "www.google.com", + }, + { + "www.google.COM.", + "www.google.com", + }, + { + "www.google.com..", + "www.google.com", + }, + { + "www.google.com........", + "www.google.com", + }, + { + "", + "", + }, + { + ".", + "", + }, + { + "........", + "", + }, + }; + // clang-format on + + for (size_t i = 0; i < ABSL_ARRAYSIZE(tests); ++i) { + EXPECT_EQ(std::string(tests[i].expected), + QuicheHostnameUtils::NormalizeHostname(tests[i].input)); + } + + if (GoogleUrlSupportsIdnaForTest()) { + EXPECT_EQ("xn--54q.google.com", QuicheHostnameUtils::NormalizeHostname( + "\xe5\x85\x89.google.com")); + } else { + EXPECT_EQ( + "", QuicheHostnameUtils::NormalizeHostname("\xe5\x85\x89.google.com")); + } +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_iovec.h b/quiche/common/platform/api/quiche_iovec.h new file mode 100644 index 000000000000..351e627c1513 --- /dev/null +++ b/quiche/common/platform/api/quiche_iovec.h @@ -0,0 +1,23 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_IOVEC_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_IOVEC_H_ + +#include +#include + +#include "quiche_platform_impl/quiche_iovec_impl.h" + +// The impl header has to export struct iovec, or a POSIX-compatible polyfill. +// Below, we mostly assert that what we have is appropriate. +static_assert(std::is_standard_layout::value, + "iovec has to be a standard-layout struct"); + +static_assert(offsetof(struct iovec, iov_base) < sizeof(struct iovec), + "iovec has to have iov_base"); +static_assert(offsetof(struct iovec, iov_len) < sizeof(struct iovec), + "iovec has to have iov_len"); + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_IOVEC_H_ diff --git a/quiche/common/platform/api/quiche_logging.h b/quiche/common/platform/api/quiche_logging.h new file mode 100644 index 000000000000..217a2016ef92 --- /dev/null +++ b/quiche/common/platform/api/quiche_logging.h @@ -0,0 +1,56 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_LOGGING_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_LOGGING_H_ + +#include "quiche_platform_impl/quiche_logging_impl.h" + +// Please note following QUICHE_LOG are platform dependent: +// INFO severity can be degraded (to VLOG(1) or DVLOG(1)). +// Some platforms may not support QUICHE_LOG_FIRST_N or QUICHE_LOG_EVERY_N_SEC, +// and they would simply be translated to LOG. + +#define QUICHE_DVLOG(verbose_level) QUICHE_DVLOG_IMPL(verbose_level) +#define QUICHE_DVLOG_IF(verbose_level, condition) \ + QUICHE_DVLOG_IF_IMPL(verbose_level, condition) +#define QUICHE_DLOG(severity) QUICHE_DLOG_IMPL(severity) +#define QUICHE_DLOG_IF(severity, condition) \ + QUICHE_DLOG_IF_IMPL(severity, condition) +#define QUICHE_VLOG(verbose_level) QUICHE_VLOG_IMPL(verbose_level) +#define QUICHE_LOG(severity) QUICHE_LOG_IMPL(severity) +#define QUICHE_LOG_FIRST_N(severity, n) QUICHE_LOG_FIRST_N_IMPL(severity, n) +#define QUICHE_LOG_EVERY_N_SEC(severity, seconds) \ + QUICHE_LOG_EVERY_N_SEC_IMPL(severity, seconds) +#define QUICHE_LOG_IF(severity, condition) \ + QUICHE_LOG_IF_IMPL(severity, condition) + +// This is a noop in release build. +#define QUICHE_NOTREACHED() QUICHE_NOTREACHED_IMPL() + +#define QUICHE_PLOG(severity) QUICHE_PLOG_IMPL(severity) + +#define QUICHE_DLOG_INFO_IS_ON() QUICHE_DLOG_INFO_IS_ON_IMPL() +#define QUICHE_LOG_INFO_IS_ON() QUICHE_LOG_INFO_IS_ON_IMPL() +#define QUICHE_LOG_WARNING_IS_ON() QUICHE_LOG_WARNING_IS_ON_IMPL() +#define QUICHE_LOG_ERROR_IS_ON() QUICHE_LOG_ERROR_IS_ON_IMPL() + +#define QUICHE_CHECK(condition) QUICHE_CHECK_IMPL(condition) +#define QUICHE_CHECK_OK(condition) QUICHE_CHECK_OK_IMPL(condition) +#define QUICHE_CHECK_EQ(val1, val2) QUICHE_CHECK_EQ_IMPL(val1, val2) +#define QUICHE_CHECK_NE(val1, val2) QUICHE_CHECK_NE_IMPL(val1, val2) +#define QUICHE_CHECK_LE(val1, val2) QUICHE_CHECK_LE_IMPL(val1, val2) +#define QUICHE_CHECK_LT(val1, val2) QUICHE_CHECK_LT_IMPL(val1, val2) +#define QUICHE_CHECK_GE(val1, val2) QUICHE_CHECK_GE_IMPL(val1, val2) +#define QUICHE_CHECK_GT(val1, val2) QUICHE_CHECK_GT_IMPL(val1, val2) + +#define QUICHE_DCHECK(condition) QUICHE_DCHECK_IMPL(condition) +#define QUICHE_DCHECK_EQ(val1, val2) QUICHE_DCHECK_EQ_IMPL(val1, val2) +#define QUICHE_DCHECK_NE(val1, val2) QUICHE_DCHECK_NE_IMPL(val1, val2) +#define QUICHE_DCHECK_LE(val1, val2) QUICHE_DCHECK_LE_IMPL(val1, val2) +#define QUICHE_DCHECK_LT(val1, val2) QUICHE_DCHECK_LT_IMPL(val1, val2) +#define QUICHE_DCHECK_GE(val1, val2) QUICHE_DCHECK_GE_IMPL(val1, val2) +#define QUICHE_DCHECK_GT(val1, val2) QUICHE_DCHECK_GT_IMPL(val1, val2) + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_LOGGING_H_ diff --git a/quiche/common/platform/api/quiche_lower_case_string.h b/quiche/common/platform/api/quiche_lower_case_string.h new file mode 100644 index 000000000000..dcaef3cd7b3d --- /dev/null +++ b/quiche/common/platform/api/quiche_lower_case_string.h @@ -0,0 +1,16 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_LOWER_CASE_STRING_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_LOWER_CASE_STRING_H_ + +#include "quiche_platform_impl/quiche_lower_case_string_impl.h" + +namespace quiche { + +using QuicheLowerCaseString = QuicheLowerCaseStringImpl; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_LOWER_CASE_STRING_H_ diff --git a/quiche/common/platform/api/quiche_lower_case_string_test.cc b/quiche/common/platform/api/quiche_lower_case_string_test.cc new file mode 100644 index 000000000000..1685ab3c9b0b --- /dev/null +++ b/quiche/common/platform/api/quiche_lower_case_string_test.cc @@ -0,0 +1,29 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_lower_case_string.h" + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche::test { +namespace { + +TEST(QuicheLowerCaseString, Basic) { + QuicheLowerCaseString empty(""); + EXPECT_EQ("", empty.get()); + + QuicheLowerCaseString from_lower_case("foo"); + EXPECT_EQ("foo", from_lower_case.get()); + + QuicheLowerCaseString from_mixed_case("BaR"); + EXPECT_EQ("bar", from_mixed_case.get()); + + const absl::string_view kData = "FooBar"; + QuicheLowerCaseString from_string_view(kData); + EXPECT_EQ("foobar", from_string_view.get()); +} + +} // namespace +} // namespace quiche::test diff --git a/quiche/common/platform/api/quiche_mem_slice.h b/quiche/common/platform/api/quiche_mem_slice.h new file mode 100644 index 000000000000..18319f80baae --- /dev/null +++ b/quiche/common/platform/api/quiche_mem_slice.h @@ -0,0 +1,75 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_MEM_SLICE_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_MEM_SLICE_H_ + +#include + +#include "quiche_platform_impl/quiche_mem_slice_impl.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quiche { + +// QuicheMemSlice is a wrapper around a platform-specific I/O buffer type. It +// may be reference counted, though QUICHE itself does not rely on that. +class QUICHE_EXPORT QuicheMemSlice { + public: + // Constructs a empty QuicheMemSlice with no underlying data. + QuicheMemSlice() = default; + + // Constructs a QuicheMemSlice that takes ownership of |buffer|. The length + // of the |buffer| must not be zero. To construct an empty QuicheMemSlice, + // use the zero-argument constructor instead. + explicit QuicheMemSlice(QuicheBuffer buffer) : impl_(std::move(buffer)) {} + + // Constructs a QuicheMemSlice that takes ownership of |buffer| allocated on + // heap. |length| must not be zero. + QuicheMemSlice(std::unique_ptr buffer, size_t length) + : impl_(std::move(buffer), length) {} + + // Ensures the use of the in-place constructor (below) is intentional. + struct InPlace {}; + + // Constructs a QuicheMemSlice by constructing |impl_| in-place. + template + explicit QuicheMemSlice(InPlace, Args&&... args) + : impl_{std::forward(args)...} {} + + QuicheMemSlice(const QuicheMemSlice& other) = delete; + QuicheMemSlice& operator=(const QuicheMemSlice& other) = delete; + + // Move constructors. |other| will not hold a reference to the data buffer + // after this call completes. + QuicheMemSlice(QuicheMemSlice&& other) = default; + QuicheMemSlice& operator=(QuicheMemSlice&& other) = default; + + ~QuicheMemSlice() = default; + + // Release the underlying reference. Further access the memory will result in + // undefined behavior. + void Reset() { impl_.Reset(); } + + // Returns a const char pointer to underlying data buffer. + const char* data() const { return impl_.data(); } + // Returns the length of underlying data buffer. + size_t length() const { return impl_.length(); } + // Returns the representation of the underlying data as a string view. + absl::string_view AsStringView() const { + return absl::string_view(data(), length()); + } + + bool empty() const { return impl_.empty(); } + + QuicheMemSliceImpl* impl() { return &impl_; } + + private: + QuicheMemSliceImpl impl_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_MEM_SLICE_H_ diff --git a/quiche/common/platform/api/quiche_mem_slice_test.cc b/quiche/common/platform/api/quiche_mem_slice_test.cc new file mode 100644 index 000000000000..4eadcc9cfc68 --- /dev/null +++ b/quiche/common/platform/api/quiche_mem_slice_test.cc @@ -0,0 +1,104 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_mem_slice.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quiche { +namespace test { +namespace { + +class QuicheMemSliceTest : public QuicheTest { + public: + QuicheMemSliceTest() { + size_t length = 1024; + slice_ = QuicheMemSlice(QuicheBuffer(&allocator_, length)); + orig_data_ = slice_.data(); + orig_length_ = slice_.length(); + } + + SimpleBufferAllocator allocator_; + QuicheMemSlice slice_; + const char* orig_data_; + size_t orig_length_; +}; + +TEST_F(QuicheMemSliceTest, MoveConstruct) { + QuicheMemSlice moved(std::move(slice_)); + EXPECT_EQ(moved.data(), orig_data_); + EXPECT_EQ(moved.length(), orig_length_); + EXPECT_EQ(nullptr, slice_.data()); + EXPECT_EQ(0u, slice_.length()); + EXPECT_TRUE(slice_.empty()); +} + +TEST_F(QuicheMemSliceTest, MoveAssign) { + QuicheMemSlice moved; + moved = std::move(slice_); + EXPECT_EQ(moved.data(), orig_data_); + EXPECT_EQ(moved.length(), orig_length_); + EXPECT_EQ(nullptr, slice_.data()); + EXPECT_EQ(0u, slice_.length()); + EXPECT_TRUE(slice_.empty()); +} + +TEST_F(QuicheMemSliceTest, MoveAssignNonEmpty) { + const absl::string_view data("foo"); + auto buffer = std::make_unique(data.length()); + std::memcpy(buffer.get(), data.data(), data.length()); + + QuicheMemSlice moved(std::move(buffer), data.length()); + EXPECT_EQ(data, moved.AsStringView()); + + moved = std::move(slice_); + EXPECT_EQ(moved.data(), orig_data_); + EXPECT_EQ(moved.length(), orig_length_); + EXPECT_EQ(nullptr, slice_.data()); + EXPECT_EQ(0u, slice_.length()); + EXPECT_TRUE(slice_.empty()); +} + +TEST_F(QuicheMemSliceTest, Reset) { + EXPECT_EQ(slice_.data(), orig_data_); + EXPECT_EQ(slice_.length(), orig_length_); + EXPECT_FALSE(slice_.empty()); + + slice_.Reset(); + + EXPECT_EQ(slice_.length(), 0u); + EXPECT_TRUE(slice_.empty()); +} + +TEST_F(QuicheMemSliceTest, SliceAllocatedOnHeap) { + auto buffer = std::make_unique(128); + char* orig_data = buffer.get(); + size_t used_length = 105; + QuicheMemSlice slice = QuicheMemSlice(std::move(buffer), used_length); + QuicheMemSlice moved = std::move(slice); + EXPECT_EQ(moved.data(), orig_data); + EXPECT_EQ(moved.length(), used_length); +} + +TEST_F(QuicheMemSliceTest, SliceFromBuffer) { + const absl::string_view kTestString = + "RFC 9000 Release Celebration Memorial Test String"; + auto buffer = QuicheBuffer::Copy(&allocator_, kTestString); + QuicheMemSlice slice(std::move(buffer)); + + EXPECT_EQ(buffer.data(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(buffer.size(), 0u); + EXPECT_EQ(slice.AsStringView(), kTestString); + EXPECT_EQ(slice.length(), kTestString.length()); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_mutex.cc b/quiche/common/platform/api/quiche_mutex.cc new file mode 100644 index 000000000000..e6d4b0c24eaa --- /dev/null +++ b/quiche/common/platform/api/quiche_mutex.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_mutex.h" + +namespace quiche { + +void QuicheMutex::WriterLock() { impl_.WriterLock(); } + +void QuicheMutex::WriterUnlock() { impl_.WriterUnlock(); } + +void QuicheMutex::ReaderLock() { impl_.ReaderLock(); } + +void QuicheMutex::ReaderUnlock() { impl_.ReaderUnlock(); } + +void QuicheMutex::AssertReaderHeld() const { impl_.AssertReaderHeld(); } + +QuicheReaderMutexLock::QuicheReaderMutexLock(QuicheMutex* lock) : lock_(lock) { + lock->ReaderLock(); +} + +QuicheReaderMutexLock::~QuicheReaderMutexLock() { lock_->ReaderUnlock(); } + +QuicheWriterMutexLock::QuicheWriterMutexLock(QuicheMutex* lock) : lock_(lock) { + lock->WriterLock(); +} + +QuicheWriterMutexLock::~QuicheWriterMutexLock() { lock_->WriterUnlock(); } + +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_mutex.h b/quiche/common/platform/api/quiche_mutex.h new file mode 100644 index 000000000000..11e6287d8c61 --- /dev/null +++ b/quiche/common/platform/api/quiche_mutex.h @@ -0,0 +1,101 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_MUTEX_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_MUTEX_H_ + +#include "quiche_platform_impl/quiche_mutex_impl.h" + +#define QUICHE_EXCLUSIVE_LOCKS_REQUIRED QUICHE_EXCLUSIVE_LOCKS_REQUIRED_IMPL +#define QUICHE_GUARDED_BY QUICHE_GUARDED_BY_IMPL +#define QUICHE_LOCKABLE QUICHE_LOCKABLE_IMPL +#define QUICHE_LOCKS_EXCLUDED QUICHE_LOCKS_EXCLUDED_IMPL +#define QUICHE_SHARED_LOCKS_REQUIRED QUICHE_SHARED_LOCKS_REQUIRED_IMPL +#define QUICHE_EXCLUSIVE_LOCK_FUNCTION QUICHE_EXCLUSIVE_LOCK_FUNCTION_IMPL +#define QUICHE_UNLOCK_FUNCTION QUICHE_UNLOCK_FUNCTION_IMPL +#define QUICHE_SHARED_LOCK_FUNCTION QUICHE_SHARED_LOCK_FUNCTION_IMPL +#define QUICHE_SCOPED_LOCKABLE QUICHE_SCOPED_LOCKABLE_IMPL +#define QUICHE_ASSERT_SHARED_LOCK QUICHE_ASSERT_SHARED_LOCK_IMPL + +namespace quiche { + +// A class representing a non-reentrant mutex in QUIC. +class QUICHE_LOCKABLE QUICHE_EXPORT QuicheMutex { + public: + QuicheMutex() = default; + QuicheMutex(const QuicheMutex&) = delete; + QuicheMutex& operator=(const QuicheMutex&) = delete; + + // Block until this Mutex is free, then acquire it exclusively. + void WriterLock() QUICHE_EXCLUSIVE_LOCK_FUNCTION(); + + // Release this Mutex. Caller must hold it exclusively. + void WriterUnlock() QUICHE_UNLOCK_FUNCTION(); + + // Block until this Mutex is free or shared, then acquire a share of it. + void ReaderLock() QUICHE_SHARED_LOCK_FUNCTION(); + + // Release this Mutex. Caller could hold it in shared mode. + void ReaderUnlock() QUICHE_UNLOCK_FUNCTION(); + + // Returns immediately if current thread holds the Mutex in at least shared + // mode. Otherwise, may report an error (typically by crashing with a + // diagnostic), or may return immediately. + void AssertReaderHeld() const QUICHE_ASSERT_SHARED_LOCK(); + + private: + QuicheLockImpl impl_; +}; + +// A helper class that acquires the given QuicheMutex shared lock while the +// QuicheReaderMutexLock is in scope. +class QUICHE_SCOPED_LOCKABLE QUICHE_EXPORT QuicheReaderMutexLock { + public: + explicit QuicheReaderMutexLock(QuicheMutex* lock) + QUICHE_SHARED_LOCK_FUNCTION(lock); + QuicheReaderMutexLock(const QuicheReaderMutexLock&) = delete; + QuicheReaderMutexLock& operator=(const QuicheReaderMutexLock&) = delete; + + ~QuicheReaderMutexLock() QUICHE_UNLOCK_FUNCTION(); + + private: + QuicheMutex* const lock_; +}; + +// A helper class that acquires the given QuicheMutex exclusive lock while the +// QuicheWriterMutexLock is in scope. +class QUICHE_SCOPED_LOCKABLE QUICHE_EXPORT QuicheWriterMutexLock { + public: + explicit QuicheWriterMutexLock(QuicheMutex* lock) + QUICHE_EXCLUSIVE_LOCK_FUNCTION(lock); + QuicheWriterMutexLock(const QuicheWriterMutexLock&) = delete; + QuicheWriterMutexLock& operator=(const QuicheWriterMutexLock&) = delete; + + ~QuicheWriterMutexLock() QUICHE_UNLOCK_FUNCTION(); + + private: + QuicheMutex* const lock_; +}; + +// A Notification allows threads to receive notification of a single occurrence +// of a single event. +class QUICHE_EXPORT QuicheNotification { + public: + QuicheNotification() = default; + QuicheNotification(const QuicheNotification&) = delete; + QuicheNotification& operator=(const QuicheNotification&) = delete; + + bool HasBeenNotified() { return impl_.HasBeenNotified(); } + + void Notify() { impl_.Notify(); } + + void WaitForNotification() { impl_.WaitForNotification(); } + + private: + QuicheNotificationImpl impl_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_MUTEX_H_ diff --git a/quiche/common/platform/api/quiche_prefetch.h b/quiche/common/platform/api/quiche_prefetch.h new file mode 100644 index 000000000000..706a7090bea3 --- /dev/null +++ b/quiche/common/platform/api/quiche_prefetch.h @@ -0,0 +1,39 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_PREFETCH_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_PREFETCH_H_ + +#include "quiche_platform_impl/quiche_prefetch_impl.h" + +namespace quiche { + +// Move data into the cache before it is read, or "prefetch" it. +// +// The value of `addr` is the address of the memory to prefetch. If +// the target and compiler support it, data prefetch instructions are +// generated. If the prefetch is done some time before the memory is +// read, it may be in the cache by the time the read occurs. +// +// The function names specify the temporal locality heuristic applied, +// using the names of Intel prefetch instructions: +// +// T0 - high degree of temporal locality; data should be left in as +// many levels of the cache possible +// T1 - moderate degree of temporal locality +// T2 - low degree of temporal locality +// Nta - no temporal locality, data need not be left in the cache +// after the read +// +// Incorrect or gratuitous use of these functions can degrade +// performance, so use them only when representative benchmarks show +// an improvement. + +inline void QuichePrefetchT0(const void* addr) { + return QuichePrefetchT0Impl(addr); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_PREFETCH_H_ diff --git a/quiche/common/platform/api/quiche_reference_counted.h b/quiche/common/platform/api/quiche_reference_counted.h new file mode 100644 index 000000000000..226a155ac9cd --- /dev/null +++ b/quiche/common/platform/api/quiche_reference_counted.h @@ -0,0 +1,168 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_REFERENCE_COUNTED_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_REFERENCE_COUNTED_H_ + +#include "quiche_platform_impl/quiche_reference_counted_impl.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Base class for explicitly reference-counted objects in QUIC. +class QUICHE_EXPORT QuicheReferenceCounted : public QuicheReferenceCountedImpl { + public: + QuicheReferenceCounted() {} + + protected: + ~QuicheReferenceCounted() override {} +}; + +// A class representing a reference counted pointer in QUIC. +// +// Construct or initialize QuicheReferenceCountedPointer from raw pointer. Here +// raw pointer MUST be a newly created object. Reference count of a newly +// created object is undefined, but that will be 1 after being added to +// QuicheReferenceCountedPointer. +// QuicheReferenceCountedPointer is used as a local variable. +// QuicheReferenceCountedPointer r_ptr(new T()); +// or, equivalently: +// QuicheReferenceCountedPointer r_ptr; +// T* p = new T(); +// r_ptr = T; +// +// QuicheReferenceCountedPointer is used as a member variable: +// MyClass::MyClass() : r_ptr(new T()) {} +// +// This is WRONG, since *p is not guaranteed to be newly created: +// MyClass::MyClass(T* p) : r_ptr(p) {} +// +// Given an existing QuicheReferenceCountedPointer, create a duplicate that has +// its own reference on the object: +// QuicheReferenceCountedPointer r_ptr_b(r_ptr_a); +// or, equivalently: +// QuicheReferenceCountedPointer r_ptr_b = r_ptr_a; +// +// Given an existing QuicheReferenceCountedPointer, create a +// QuicheReferenceCountedPointer that adopts the reference: +// QuicheReferenceCountedPointer r_ptr_b(std::move(r_ptr_a)); +// or, equivalently: +// QuicheReferenceCountedPointer r_ptr_b = std::move(r_ptr_a); + +template +class QUICHE_NO_EXPORT QuicheReferenceCountedPointer { + public: + QuicheReferenceCountedPointer() = default; + + // Constructor from raw pointer |p|. This guarantees that the reference count + // of *p is 1. This should be only called when a new object is created. + // Calling this on an already existent object does not increase its reference + // count. + explicit QuicheReferenceCountedPointer(T* p) : impl_(p) {} + + // Allows implicit conversion from nullptr. + QuicheReferenceCountedPointer(std::nullptr_t) : impl_(nullptr) {} // NOLINT + + // Copy and copy conversion constructors. It does not take the reference away + // from |other| and they each end up with their own reference. + template + QuicheReferenceCountedPointer( // NOLINT + const QuicheReferenceCountedPointer& other) + : impl_(other.impl()) {} + QuicheReferenceCountedPointer(const QuicheReferenceCountedPointer& other) + : impl_(other.impl()) {} + + // Move constructors. After move, it adopts the reference from |other|. + template + QuicheReferenceCountedPointer( + QuicheReferenceCountedPointer&& other) // NOLINT + : impl_(std::move(other.impl())) {} + QuicheReferenceCountedPointer(QuicheReferenceCountedPointer&& other) + : impl_(std::move(other.impl())) {} + + ~QuicheReferenceCountedPointer() = default; + + // Copy assignments. + QuicheReferenceCountedPointer& operator=( + const QuicheReferenceCountedPointer& other) { + impl_ = other.impl(); + return *this; + } + template + QuicheReferenceCountedPointer& operator=( + const QuicheReferenceCountedPointer& other) { + impl_ = other.impl(); + return *this; + } + + // Move assignments. + QuicheReferenceCountedPointer& operator=( + QuicheReferenceCountedPointer&& other) { + impl_ = std::move(other.impl()); + return *this; + } + template + QuicheReferenceCountedPointer& operator=( + QuicheReferenceCountedPointer&& other) { + impl_ = std::move(other.impl()); + return *this; + } + + // Accessors for the referenced object. + // operator*() and operator->() will assert() if there is no current object. + T& operator*() const { return *impl_; } + T* operator->() const { return impl_.get(); } + + explicit operator bool() const { return static_cast(impl_); } + + // Assignment operator on raw pointer. Drops a reference to current pointee, + // if any, and replaces it with |p|. This guarantees that the reference count + // of *p is 1. This should only be used when a new object is created. Calling + // this on an already existent object is undefined behavior. + QuicheReferenceCountedPointer& operator=(T* p) { + impl_ = p; + return *this; + } + + // Returns the raw pointer with no change in reference count. + T* get() const { return impl_.get(); } + + QuicheReferenceCountedPointerImpl& impl() { return impl_; } + const QuicheReferenceCountedPointerImpl& impl() const { return impl_; } + + // Comparisons against same type. + friend bool operator==(const QuicheReferenceCountedPointer& a, + const QuicheReferenceCountedPointer& b) { + return a.get() == b.get(); + } + friend bool operator!=(const QuicheReferenceCountedPointer& a, + const QuicheReferenceCountedPointer& b) { + return a.get() != b.get(); + } + + // Comparisons against nullptr. + friend bool operator==(const QuicheReferenceCountedPointer& a, + std::nullptr_t) { + return a.get() == nullptr; + } + friend bool operator==(std::nullptr_t, + const QuicheReferenceCountedPointer& b) { + return nullptr == b.get(); + } + friend bool operator!=(const QuicheReferenceCountedPointer& a, + std::nullptr_t) { + return a.get() != nullptr; + } + friend bool operator!=(std::nullptr_t, + const QuicheReferenceCountedPointer& b) { + return nullptr != b.get(); + } + + private: + QuicheReferenceCountedPointerImpl impl_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_REFERENCE_COUNTED_H_ diff --git a/quiche/common/platform/api/quiche_reference_counted_test.cc b/quiche/common/platform/api/quiche_reference_counted_test.cc new file mode 100644 index 000000000000..7c05d7ac1971 --- /dev/null +++ b/quiche/common/platform/api/quiche_reference_counted_test.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_reference_counted.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { +namespace { + +class Base : public QuicheReferenceCounted { + public: + explicit Base(bool* destroyed) : destroyed_(destroyed) { + *destroyed_ = false; + } + + protected: + ~Base() override { *destroyed_ = true; } + + private: + bool* destroyed_; +}; + +class Derived : public Base { + public: + explicit Derived(bool* destroyed) : Base(destroyed) {} + + private: + ~Derived() override {} +}; + +class QuicheReferenceCountedTest : public QuicheTest {}; + +TEST_F(QuicheReferenceCountedTest, DefaultConstructor) { + QuicheReferenceCountedPointer a; + EXPECT_EQ(nullptr, a); + EXPECT_EQ(nullptr, a.get()); + EXPECT_FALSE(a); +} + +TEST_F(QuicheReferenceCountedTest, ConstructFromRawPointer) { + bool destroyed = false; + { + QuicheReferenceCountedPointer a(new Base(&destroyed)); + EXPECT_FALSE(destroyed); + } + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, RawPointerAssignment) { + bool destroyed = false; + { + QuicheReferenceCountedPointer a; + Base* rct = new Base(&destroyed); + a = rct; + EXPECT_FALSE(destroyed); + } + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerCopy) { + bool destroyed = false; + { + QuicheReferenceCountedPointer a(new Base(&destroyed)); + { + QuicheReferenceCountedPointer b(a); + EXPECT_EQ(a, b); + EXPECT_FALSE(destroyed); + } + EXPECT_FALSE(destroyed); + } + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerCopyAssignment) { + bool destroyed = false; + { + QuicheReferenceCountedPointer a(new Base(&destroyed)); + { + QuicheReferenceCountedPointer b = a; + EXPECT_EQ(a, b); + EXPECT_FALSE(destroyed); + } + EXPECT_FALSE(destroyed); + } + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerCopyFromOtherType) { + bool destroyed = false; + { + QuicheReferenceCountedPointer a(new Derived(&destroyed)); + { + QuicheReferenceCountedPointer b(a); + EXPECT_EQ(a.get(), b.get()); + EXPECT_FALSE(destroyed); + } + EXPECT_FALSE(destroyed); + } + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerCopyAssignmentFromOtherType) { + bool destroyed = false; + { + QuicheReferenceCountedPointer a(new Derived(&destroyed)); + { + QuicheReferenceCountedPointer b = a; + EXPECT_EQ(a.get(), b.get()); + EXPECT_FALSE(destroyed); + } + EXPECT_FALSE(destroyed); + } + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerMove) { + bool destroyed = false; + QuicheReferenceCountedPointer a(new Derived(&destroyed)); + EXPECT_FALSE(destroyed); + QuicheReferenceCountedPointer b(std::move(a)); + EXPECT_FALSE(destroyed); + EXPECT_NE(nullptr, b); + EXPECT_EQ(nullptr, a); // NOLINT + + b = nullptr; + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerMoveAssignment) { + bool destroyed = false; + QuicheReferenceCountedPointer a(new Derived(&destroyed)); + EXPECT_FALSE(destroyed); + QuicheReferenceCountedPointer b = std::move(a); + EXPECT_FALSE(destroyed); + EXPECT_NE(nullptr, b); + EXPECT_EQ(nullptr, a); // NOLINT + + b = nullptr; + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerMoveFromOtherType) { + bool destroyed = false; + QuicheReferenceCountedPointer a(new Derived(&destroyed)); + EXPECT_FALSE(destroyed); + QuicheReferenceCountedPointer b(std::move(a)); + EXPECT_FALSE(destroyed); + EXPECT_NE(nullptr, b); + EXPECT_EQ(nullptr, a); // NOLINT + + b = nullptr; + EXPECT_TRUE(destroyed); +} + +TEST_F(QuicheReferenceCountedTest, PointerMoveAssignmentFromOtherType) { + bool destroyed = false; + QuicheReferenceCountedPointer a(new Derived(&destroyed)); + EXPECT_FALSE(destroyed); + QuicheReferenceCountedPointer b = std::move(a); + EXPECT_FALSE(destroyed); + EXPECT_NE(nullptr, b); + EXPECT_EQ(nullptr, a); // NOLINT + + b = nullptr; + EXPECT_TRUE(destroyed); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_server_stats.h b/quiche/common/platform/api/quiche_server_stats.h new file mode 100644 index 000000000000..e8ad499cd93a --- /dev/null +++ b/quiche/common/platform/api/quiche_server_stats.h @@ -0,0 +1,82 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_SERVER_STATS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_SERVER_STATS_H_ + +#include "quiche_platform_impl/quiche_server_stats_impl.h" + +namespace quiche { + +//------------------------------------------------------------------------------ +// Enumeration histograms. +// +// Sample usage: +// // In Chrome, these values are persisted to logs. Entries should not be +// // renumbered and numeric values should never be reused. +// enum class MyEnum { +// FIRST_VALUE = 0, +// SECOND_VALUE = 1, +// ... +// FINAL_VALUE = N, +// COUNT +// }; +// QUICHE_SERVER_HISTOGRAM_ENUM("My.Enumeration", MyEnum::SOME_VALUE, +// MyEnum::COUNT, "Number of time $foo equals to some enum value"); +// +// Note: The value in |sample| must be strictly less than |enum_size|. + +#define QUICHE_SERVER_HISTOGRAM_ENUM(name, sample, enum_size, docstring) \ + QUICHE_SERVER_HISTOGRAM_ENUM_IMPL(name, sample, enum_size, docstring) + +//------------------------------------------------------------------------------ +// Histogram for boolean values. + +// Sample usage: +// QUICHE_SERVER_HISTOGRAM_BOOL("My.Boolean", bool, +// "Number of times $foo is true or false"); +#define QUICHE_SERVER_HISTOGRAM_BOOL(name, sample, docstring) \ + QUICHE_SERVER_HISTOGRAM_BOOL_IMPL(name, sample, docstring) + +//------------------------------------------------------------------------------ +// Timing histograms. These are used for collecting timing data (generally +// latencies). + +// These macros create exponentially sized histograms (lengths of the bucket +// ranges exponentially increase as the sample range increases). The units for +// sample and max are unspecified, but they must be the same for one histogram. + +// Sample usage: +// QUICHE_SERVER_HISTOGRAM_TIMES("Very.Long.Timing.Histogram", time_delta, +// QuicTime::Delta::FromSeconds(1), QuicTime::Delta::FromSecond(3600 * +// 24), 100, "Time spent in doing operation."); +#define QUICHE_SERVER_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_SERVER_HISTOGRAM_TIMES_IMPL(name, sample, min, max, bucket_count, \ + docstring) + +//------------------------------------------------------------------------------ +// Count histograms. These are used for collecting numeric data. + +// These macros default to exponential histograms - i.e. the lengths of the +// bucket ranges exponentially increase as the sample range increases. + +// All of these macros must be called with |name| as a runtime constant. + +// Any data outside the range here will be put in underflow and overflow +// buckets. Min values should be >=1 as emitted 0s will still go into the +// underflow bucket. + +// Sample usage: +// QUICHE_SERVER_SERVER_HISTOGRAM_CUSTOM_COUNTS("My.Histogram", 1, 100000000, +// 100, "Counters of hitting certian code."); + +#define QUICHE_SERVER_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_SERVER_HISTOGRAM_COUNTS_IMPL(name, sample, min, max, bucket_count, \ + docstring) + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_SERVER_STATS_H_ diff --git a/quiche/common/platform/api/quiche_stack_trace.h b/quiche/common/platform/api/quiche_stack_trace.h new file mode 100644 index 000000000000..4c07577cc2ef --- /dev/null +++ b/quiche/common/platform/api/quiche_stack_trace.h @@ -0,0 +1,32 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_STACK_TRACE_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_STACK_TRACE_H_ + +#include + +#include "quiche_platform_impl/quiche_stack_trace_impl.h" + +namespace quiche { + +// Returns a human-readable stack trace. Mostly used in error logging and +// related features. +inline std::string QuicheStackTrace() { return QuicheStackTraceImpl(); } + +// Indicates whether the unit test for QuicheStackTrace() should be run. The +// unit test calls QuicheStackTrace() from a specific function and checks +// whether that specific function is in the stack trace. This function should +// return false if: +// (1) QuicheStackTrace() is unimplemented, +// (2) QuicheStackTrace() does not work on the current platform, or +// (3) QuicheStackTrace() works, but the symbols are not guaranteed to be +// available. +inline bool QuicheShouldRunStackTraceTest() { + return QuicheShouldRunStackTraceTestImpl(); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_STACK_TRACE_H_ diff --git a/quiche/common/platform/api/quiche_stack_trace_test.cc b/quiche/common/platform/api/quiche_stack_trace_test.cc new file mode 100644 index 000000000000..74a4ca90beee --- /dev/null +++ b/quiche/common/platform/api/quiche_stack_trace_test.cc @@ -0,0 +1,43 @@ +#include "quiche/common/platform/api/quiche_stack_trace.h" + +#include + +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { +namespace { + +bool ShouldRunTest() { +#if defined(ABSL_HAVE_ATTRIBUTE_NOINLINE) + return QuicheShouldRunStackTraceTest(); +#else + // If QuicheDesignatedStackTraceTestFunction gets inlined, the test will + // inevitably fail, since the function won't be on the stack trace. Disable + // the test in that scenario. + return false; +#endif +} + +ABSL_ATTRIBUTE_NOINLINE std::string QuicheDesignatedStackTraceTestFunction() { + std::string result = QuicheStackTrace(); + ABSL_BLOCK_TAIL_CALL_OPTIMIZATION(); + return result; +} + +TEST(QuicheStackTraceTest, GetStackTrace) { + if (!ShouldRunTest()) { + return; + } + + std::string stacktrace = QuicheDesignatedStackTraceTestFunction(); + EXPECT_THAT(stacktrace, + testing::HasSubstr("QuicheDesignatedStackTraceTestFunction")); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_system_event_loop.h b/quiche/common/platform/api/quiche_system_event_loop.h new file mode 100644 index 000000000000..41ed45a693de --- /dev/null +++ b/quiche/common/platform/api/quiche_system_event_loop.h @@ -0,0 +1,20 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_SYSTEM_EVENT_LOOP_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_SYSTEM_EVENT_LOOP_H_ + +#include "quiche_platform_impl/quiche_system_event_loop_impl.h" + +namespace quiche { + +inline void QuicheRunSystemEventLoopIteration() { + QuicheRunSystemEventLoopIterationImpl(); +} + +using QuicheSystemEventLoop = QuicheSystemEventLoopImpl; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_SYSTEM_EVENT_LOOP_H_ diff --git a/quiche/common/platform/api/quiche_test.h b/quiche/common/platform/api/quiche_test.h new file mode 100644 index 000000000000..2fbb81da54fa --- /dev/null +++ b/quiche/common/platform/api/quiche_test.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_H_ + +#include "quiche_platform_impl/quiche_test_impl.h" + +namespace quiche::test { + +using QuicheTest = QuicheTestImpl; + +template +using QuicheTestWithParam = QuicheTestWithParamImpl; + +using QuicheFlagSaver = QuicheFlagSaverImpl; + +// Class which needs to be instantiated in tests which use threads. +using ScopedEnvironmentForThreads = ScopedEnvironmentForThreadsImpl; + +inline std::string QuicheGetTestMemoryCachePath() { + return QuicheGetTestMemoryCachePathImpl(); +} + +// Returns the path to quiche/common directory where the test data could be +// located. +inline std::string QuicheGetCommonSourcePath() { + return QuicheGetCommonSourcePathImpl(); +} + +} // namespace quiche::test + +#define EXPECT_QUICHE_DEBUG_DEATH(condition, message) \ + EXPECT_QUICHE_DEBUG_DEATH_IMPL(condition, message) + +#define QUICHE_TEST_DISABLED_IN_CHROME(name) \ + QUICHE_TEST_DISABLED_IN_CHROME_IMPL(name) + +#define QUICHE_SLOW_TEST(test) QUICHE_SLOW_TEST_IMPL(test) + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_H_ diff --git a/quiche/common/platform/api/quiche_test_loopback.cc b/quiche/common/platform/api/quiche_test_loopback.cc new file mode 100644 index 000000000000..07d3ccf956e3 --- /dev/null +++ b/quiche/common/platform/api/quiche_test_loopback.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_test_loopback.h" + +namespace quiche { + +quic::IpAddressFamily AddressFamilyUnderTest() { + return AddressFamilyUnderTestImpl(); +} + +quic::QuicIpAddress TestLoopback4() { return TestLoopback4Impl(); } + +quic::QuicIpAddress TestLoopback6() { return TestLoopback6Impl(); } + +quic::QuicIpAddress TestLoopback() { return TestLoopbackImpl(); } + +quic::QuicIpAddress TestLoopback(int index) { return TestLoopbackImpl(index); } + +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_test_loopback.h b/quiche/common/platform/api/quiche_test_loopback.h new file mode 100644 index 000000000000..b493ca4528ec --- /dev/null +++ b/quiche/common/platform/api/quiche_test_loopback.h @@ -0,0 +1,34 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_LOOPBACK_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_LOOPBACK_H_ + +#include "quiche_platform_impl/quiche_test_loopback_impl.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" + +namespace quiche { + +// Returns the address family (IPv4 or IPv6) used to run test under. +quic::IpAddressFamily AddressFamilyUnderTest(); + +// Returns an IPv4 loopback address. +quic::QuicIpAddress TestLoopback4(); + +// Returns the only IPv6 loopback address. +quic::QuicIpAddress TestLoopback6(); + +// Returns an appropriate IPv4/Ipv6 loopback address based upon whether the +// test's environment. +quic::QuicIpAddress TestLoopback(); + +// If address family under test is IPv4, returns an indexed IPv4 loopback +// address. If address family under test is IPv6, the address returned is +// platform-dependent. +quic::QuicIpAddress TestLoopback(int index); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_LOOPBACK_H_ diff --git a/quiche/common/platform/api/quiche_test_output.h b/quiche/common/platform/api/quiche_test_output.h new file mode 100644 index 000000000000..b28a8a7d754d --- /dev/null +++ b/quiche/common/platform/api/quiche_test_output.h @@ -0,0 +1,40 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_OUTPUT_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_OUTPUT_H_ + +#include "quiche_platform_impl/quiche_test_output_impl.h" +#include "absl/strings/string_view.h" + +namespace quiche { + +// Save |data| into ${QUICHE_TEST_OUTPUT_DIR}/filename. If a file with the same +// path already exists, overwrite it. +inline void QuicheSaveTestOutput(absl::string_view filename, + absl::string_view data) { + QuicheSaveTestOutputImpl(filename, data); +} + +// Load the content of ${QUICHE_TEST_OUTPUT_DIR}/filename into |*data|. +// Return whether it is successfully loaded. +inline bool QuicheLoadTestOutput(absl::string_view filename, + std::string* data) { + return QuicheLoadTestOutputImpl(filename, data); +} + +// Records a QUIC trace file(.qtr) into a directory specified by the +// QUICHE_TEST_OUTPUT_DIR environment variable. Assumes that it's called from a +// unit test. +// +// The |identifier| is a human-readable identifier that will be combined with +// the name of the unit test and a timestamp. |data| is the serialized +// quic_trace.Trace protobuf that is being recorded into the file. +inline void QuicheRecordTrace(absl::string_view identifier, + absl::string_view data) { + QuicheRecordTraceImpl(identifier, data); +} + +} // namespace quiche +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_OUTPUT_H_ diff --git a/quiche/common/platform/api/quiche_testvalue.h b/quiche/common/platform/api/quiche_testvalue.h new file mode 100644 index 000000000000..ec50cd9a720d --- /dev/null +++ b/quiche/common/platform/api/quiche_testvalue.h @@ -0,0 +1,25 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TESTVALUE_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_TESTVALUE_H_ + +#include "quiche_platform_impl/quiche_testvalue_impl.h" +#include "absl/strings/string_view.h" + +namespace quiche { + +// Interface allowing injection of test-specific code in production codepaths. +// |label| is an arbitrary value identifying the location, and |var| is a +// pointer to the value to be modified. +// +// Note that this method does nothing in Chromium. +template +void AdjustTestValue(absl::string_view label, T* var) { + AdjustTestValueImpl(label, var); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_TESTVALUE_H_ diff --git a/quiche/common/platform/api/quiche_thread.h b/quiche/common/platform/api/quiche_thread.h new file mode 100644 index 000000000000..0a15e23b49d5 --- /dev/null +++ b/quiche/common/platform/api/quiche_thread.h @@ -0,0 +1,28 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_THREAD_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_THREAD_H_ + +#include + +#include "quiche_platform_impl/quiche_thread_impl.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// A class representing a thread of execution in QUIC. +class QUICHE_EXPORT QuicheThread : public QuicheThreadImpl { + public: + QuicheThread(const std::string& string) : QuicheThreadImpl(string) {} + QuicheThread(const QuicheThread&) = delete; + QuicheThread& operator=(const QuicheThread&) = delete; + + // Impl defines a virtual void Run() method which subclasses + // must implement. +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_THREAD_H_ diff --git a/quiche/common/platform/api/quiche_time_utils.h b/quiche/common/platform/api/quiche_time_utils.h new file mode 100644 index 000000000000..60226d5d7086 --- /dev/null +++ b/quiche/common/platform/api/quiche_time_utils.h @@ -0,0 +1,27 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TIME_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_TIME_UTILS_H_ + +#include + +#include "quiche_platform_impl/quiche_time_utils_impl.h" + +namespace quiche { + +// Converts a civil time specified in UTC into a number of seconds since the +// Unix epoch. This function is strict about validity of accepted dates. For +// instance, it will reject February 29 on non-leap years, or 25 hours in a day. +// As a notable exception, 60 seconds is accepted to deal with potential leap +// seconds. If the date predates Unix epoch, nullopt will be returned. +inline absl::optional QuicheUtcDateTimeToUnixSeconds( + int year, int month, int day, int hour, int minute, int second) { + return QuicheUtcDateTimeToUnixSecondsImpl(year, month, day, hour, minute, + second); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_TIME_UTILS_H_ diff --git a/quiche/common/platform/api/quiche_time_utils_test.cc b/quiche/common/platform/api/quiche_time_utils_test.cc new file mode 100644 index 000000000000..5d0900486a3f --- /dev/null +++ b/quiche/common/platform/api/quiche_time_utils_test.cc @@ -0,0 +1,51 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_time_utils.h" + +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace { + +TEST(QuicheTimeUtilsTest, Basic) { + EXPECT_EQ(1, QuicheUtcDateTimeToUnixSeconds(1970, 1, 1, 0, 0, 1)); + EXPECT_EQ(365 * 86400, QuicheUtcDateTimeToUnixSeconds(1971, 1, 1, 0, 0, 0)); + // Some arbitrary timestamps closer to the present, compared to the output of + // "Date(...).getTime()" from the JavaScript console. + EXPECT_EQ(1152966896, + QuicheUtcDateTimeToUnixSeconds(2006, 7, 15, 12, 34, 56)); + EXPECT_EQ(1591130001, QuicheUtcDateTimeToUnixSeconds(2020, 6, 2, 20, 33, 21)); + + EXPECT_EQ(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1970, 2, 29, 0, 0, 1)); + EXPECT_NE(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1972, 2, 29, 0, 0, 1)); +} + +TEST(QuicheTimeUtilsTest, Bounds) { + EXPECT_EQ(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1970, 1, 32, 0, 0, 1)); + EXPECT_EQ(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1970, 4, 31, 0, 0, 1)); + EXPECT_EQ(absl::nullopt, QuicheUtcDateTimeToUnixSeconds(1970, 1, 0, 0, 0, 1)); + EXPECT_EQ(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1970, 13, 1, 0, 0, 1)); + EXPECT_EQ(absl::nullopt, QuicheUtcDateTimeToUnixSeconds(1970, 0, 1, 0, 0, 1)); + EXPECT_EQ(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1970, 1, 1, 24, 0, 0)); + EXPECT_EQ(absl::nullopt, + QuicheUtcDateTimeToUnixSeconds(1970, 1, 1, 0, 60, 0)); +} + +TEST(QuicheTimeUtilsTest, LeapSecond) { + EXPECT_EQ(QuicheUtcDateTimeToUnixSeconds(2015, 6, 30, 23, 59, 60), + QuicheUtcDateTimeToUnixSeconds(2015, 7, 1, 0, 0, 0)); + EXPECT_EQ(QuicheUtcDateTimeToUnixSeconds(2015, 6, 30, 25, 59, 60), + absl::nullopt); +} + +} // namespace +} // namespace quiche diff --git a/quiche/common/platform/api/quiche_udp_socket_platform_api.h b/quiche/common/platform/api/quiche_udp_socket_platform_api.h new file mode 100644 index 000000000000..a426c14bdf94 --- /dev/null +++ b/quiche/common/platform/api/quiche_udp_socket_platform_api.h @@ -0,0 +1,46 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_UDP_SOCKET_PLATFORM_API_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_UDP_SOCKET_PLATFORM_API_H_ + +#include "quiche_platform_impl/quiche_udp_socket_platform_impl.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/quiche_ip_address_family.h" + +namespace quiche { + +const size_t kCmsgSpaceForGooglePacketHeader = + kCmsgSpaceForGooglePacketHeaderImpl; + +inline bool GetGooglePacketHeadersFromControlMessage( + struct ::cmsghdr* cmsg, char** packet_headers, size_t* packet_headers_len) { + return GetGooglePacketHeadersFromControlMessageImpl(cmsg, packet_headers, + packet_headers_len); +} + +inline void SetGoogleSocketOptions(int fd) { SetGoogleSocketOptionsImpl(fd); } + +// Retrieves the IP TOS byte for |fd| and |address_family|, based on the correct +// sockopt for the platform, replaces the two ECN bits of that byte with the +// value in |ecn_codepoint|. +// The result is stored in |value| in the proper format to set the TOS byte +// using a cmsg. |value| must point to memory of size |value_len|. Stores the +// correct cmsg type to use in |type|. +// Returns 0 on success. Returns EINVAL if |address_family| is neither IP_V4 nor +// IP_V6, or if |value_len| is not large enough to store the appropriately +// formatted argument. If getting the socket option fails, returns the +// associated error code. +inline int GetEcnCmsgArgsPreserveDscp( + const int fd, const quiche::IpAddressFamily address_family, + quic::QuicEcnCodepoint ecn_codepoint, int& type, void* value, + socklen_t& value_len) { + return GetEcnCmsgArgsPreserveDscpImpl( + fd, ToPlatformAddressFamily(address_family), + static_cast(ecn_codepoint), type, value, value_len); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_UDP_SOCKET_PLATFORM_API_H_ diff --git a/quiche/common/platform/api/quiche_url_utils.h b/quiche/common/platform/api/quiche_url_utils.h new file mode 100644 index 000000000000..e6c9fc90c07a --- /dev/null +++ b/quiche/common/platform/api/quiche_url_utils.h @@ -0,0 +1,38 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_URL_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_URL_UTILS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche_platform_impl/quiche_url_utils_impl.h" + +namespace quiche { + +// Produces concrete URLs in |target| from templated ones in |uri_template|. +// Parameters are URL-encoded. Collects the names of any expanded variables in +// |vars_found|. Returns true if the template was parseable, false if it was +// malformed. +inline bool ExpandURITemplate( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + std::string* target, + absl::flat_hash_set* vars_found = nullptr) { + return ExpandURITemplateImpl(uri_template, parameters, target, vars_found); +} + +// Decodes a URL-encoded string and converts it to ASCII. If the decoded input +// contains non-ASCII characters, decoding fails and absl::nullopt is returned. +inline absl::optional AsciiUrlDecode(absl::string_view input) { + return AsciiUrlDecodeImpl(input); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_URL_UTILS_H_ diff --git a/quiche/common/platform/api/quiche_url_utils_test.cc b/quiche/common/platform/api/quiche_url_utils_test.cc new file mode 100644 index 000000000000..33a30ece1965 --- /dev/null +++ b/quiche/common/platform/api/quiche_url_utils_test.cc @@ -0,0 +1,80 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_url_utils.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace { + +void ValidateExpansion( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + const std::string& expected_expansion, + const absl::flat_hash_set& expected_vars_found) { + absl::flat_hash_set vars_found; + std::string target; + ASSERT_TRUE( + ExpandURITemplate(uri_template, parameters, &target, &vars_found)); + EXPECT_EQ(expected_expansion, target); + EXPECT_EQ(vars_found, expected_vars_found); +} + +TEST(QuicheUrlUtilsTest, Basic) { + ValidateExpansion("/{foo}/{bar}/", {{"foo", "123"}, {"bar", "456"}}, + "/123/456/", {"foo", "bar"}); +} + +TEST(QuicheUrlUtilsTest, ExtraParameter) { + ValidateExpansion("/{foo}/{bar}/{baz}/", {{"foo", "123"}, {"bar", "456"}}, + "/123/456//", {"foo", "bar"}); +} + +TEST(QuicheUrlUtilsTest, MissingParameter) { + ValidateExpansion("/{foo}/{baz}/", {{"foo", "123"}, {"bar", "456"}}, "/123//", + {"foo"}); +} + +TEST(QuicheUrlUtilsTest, RepeatedParameter) { + ValidateExpansion("/{foo}/{bar}/{foo}/", {{"foo", "123"}, {"bar", "456"}}, + "/123/456/123/", {"foo", "bar"}); +} + +TEST(QuicheUrlUtilsTest, URLEncoding) { + ValidateExpansion("/{foo}/{bar}/", {{"foo", "123"}, {"bar", ":"}}, + "/123/%3A/", {"foo", "bar"}); +} + +void ValidateUrlDecode(const std::string& input, + const absl::optional& expected_output) { + absl::optional decode_result = AsciiUrlDecode(input); + if (!expected_output.has_value()) { + EXPECT_FALSE(decode_result.has_value()); + return; + } + ASSERT_TRUE(decode_result.has_value()); + EXPECT_EQ(decode_result.value(), expected_output); +} + +TEST(QuicheUrlUtilsTest, DecodeNoChange) { + ValidateUrlDecode("foobar", "foobar"); +} + +TEST(QuicheUrlUtilsTest, DecodeReplace) { + ValidateUrlDecode("%7Bfoobar%7D", "{foobar}"); +} + +TEST(QuicheUrlUtilsTest, DecodeFail) { + ValidateUrlDecode("%FF", absl::nullopt); +} + +} // namespace +} // namespace quiche diff --git a/quiche/common/platform/api/testdir/README.md b/quiche/common/platform/api/testdir/README.md new file mode 100644 index 000000000000..8be29a969d54 --- /dev/null +++ b/quiche/common/platform/api/testdir/README.md @@ -0,0 +1 @@ +This directory is used in the QUICHE filesystem API tests. diff --git a/quiche/common/platform/api/testdir/a/b/c/d/e b/quiche/common/platform/api/testdir/a/b/c/d/e new file mode 100644 index 000000000000..fa8570e15711 --- /dev/null +++ b/quiche/common/platform/api/testdir/a/b/c/d/e @@ -0,0 +1 @@ +Test file for deeply nested folders. \ No newline at end of file diff --git a/quiche/common/platform/api/testdir/a/subdir/testfile b/quiche/common/platform/api/testdir/a/subdir/testfile new file mode 100644 index 000000000000..52a65318e22a --- /dev/null +++ b/quiche/common/platform/api/testdir/a/subdir/testfile @@ -0,0 +1 @@ +Test for a file with the same name as the other file. \ No newline at end of file diff --git a/quiche/common/platform/api/testdir/a/z b/quiche/common/platform/api/testdir/a/z new file mode 100644 index 000000000000..4f3124c37362 --- /dev/null +++ b/quiche/common/platform/api/testdir/a/z @@ -0,0 +1 @@ +Test for a file in a subdirectory. \ No newline at end of file diff --git a/quiche/common/platform/api/testdir/testfile b/quiche/common/platform/api/testdir/testfile new file mode 100644 index 000000000000..af27ff4986a7 --- /dev/null +++ b/quiche/common/platform/api/testdir/testfile @@ -0,0 +1 @@ +This is a test file. \ No newline at end of file diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h new file mode 100644 index 000000000000..2725f9468d69 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h @@ -0,0 +1,16 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_BUG_TRACKER_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_BUG_TRACKER_IMPL_H_ + +#include "quiche/common/platform/api/quiche_logging.h" + +#define QUICHE_BUG_IMPL(b) QUICHE_LOG(DFATAL) << #b ": " +#define QUICHE_BUG_IF_IMPL(b, condition) \ + QUICHE_LOG_IF(DFATAL, condition) << #b ": " +#define QUICHE_PEER_BUG_IMPL(b) QUICHE_LOG(DFATAL) +#define QUICHE_PEER_BUG_IF_IMPL(b, condition) QUICHE_LOG_IF(DFATAL, condition) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_BUG_TRACKER_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h new file mode 100644 index 000000000000..fde75f78e0e5 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_client_stats_impl.h @@ -0,0 +1,44 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CLIENT_STATS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CLIENT_STATS_IMPL_H_ + +#include + +namespace quiche { + +// Use namespace qualifier in case the macro is used outside the quiche +// namespace. + +#define QUICHE_CLIENT_HISTOGRAM_ENUM_IMPL(name, sample, enum_size, docstring) \ + do { \ + quiche::QuicheClientSparseHistogramImpl(name, static_cast(sample)); \ + } while (0) + +#define QUICHE_CLIENT_HISTOGRAM_BOOL_IMPL(name, sample, docstring) \ + do { \ + (void)sample; /* Workaround for -Wunused-variable. */ \ + } while (0) + +#define QUICHE_CLIENT_HISTOGRAM_TIMES_IMPL(name, sample, min, max, \ + num_buckets, docstring) \ + do { \ + (void)sample; /* Workaround for -Wunused-variable. */ \ + } while (0) + +#define QUICHE_CLIENT_HISTOGRAM_COUNTS_IMPL(name, sample, min, max, \ + num_buckets, docstring) \ + do { \ + quiche::QuicheClientSparseHistogramImpl(name, sample); \ + } while (0) + +inline void QuicheClientSparseHistogramImpl(const std::string& /*name*/, + int /*sample*/) { + // No-op. +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CLIENT_STATS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc new file mode 100644 index 000000000000..ae1537696dfe --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.cc @@ -0,0 +1,41 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_command_line_flags_impl.h" + +#include + +#include "absl/flags/parse.h" +#include "absl/flags/usage.h" + +namespace quiche { + +static void SetUsage(absl::string_view usage) { + static bool usage_set = false; + if (!usage_set) { + absl::SetProgramUsageMessage(usage); + usage_set = true; + } +} + +std::vector QuicheParseCommandLineFlagsImpl( + const char* usage, int argc, const char* const* argv, bool /*parse_only*/) { + SetUsage(usage); + std::vector parsed = + absl::ParseCommandLine(argc, const_cast(argv)); + std::vector result; + result.reserve(parsed.size()); + // Remove the first argument, which is the name of the binary. + for (size_t i = 1; i < parsed.size(); i++) { + result.push_back(std::string(parsed[i])); + } + return result; +} + +void QuichePrintCommandLineFlagHelpImpl(const char* usage) { + SetUsage(usage); + std::cerr << absl::ProgramUsageMessage() << std::endl; +} + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h new file mode 100644 index 000000000000..79d91d2aa2f5 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_command_line_flags_impl.h @@ -0,0 +1,33 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_COMMAND_LINE_FLAGS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_COMMAND_LINE_FLAGS_IMPL_H_ + +#include "absl/flags/flag.h" + +#define DEFINE_QUICHE_COMMAND_LINE_FLAG_IMPL(type, name, default_value, help) \ + ABSL_FLAG(type, name, default_value, help) + +namespace quiche { + +template +T GetQuicheCommandLineFlag(const absl::Flag& flag) { + return absl::GetFlag(flag); +} + +std::vector QuicheParseCommandLineFlagsImpl( + const char* usage, int argc, const char* const* argv, + bool parse_only = false); + +void QuichePrintCommandLineFlagHelpImpl(const char* usage); + +} // namespace quiche + +template +T GetQuicheFlagImplImpl(const absl::Flag& flag) { + return absl::GetFlag(flag); +} + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_COMMAND_LINE_FLAGS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h new file mode 100644 index 000000000000..ded2e40a71ac --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h @@ -0,0 +1,17 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CONTAINERS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CONTAINERS_IMPL_H_ + +#include "absl/container/btree_set.h" + +namespace quiche { + +template +using QuicheSmallOrderedSetImpl = absl::btree_set; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CONTAINERS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc new file mode 100644 index 000000000000..97b14825dec7 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.cc @@ -0,0 +1,77 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_default_proof_providers_impl.h" + +#include +#include +#include +#include + +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/proof_source_x509.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche_platform_impl/quiche_command_line_flags_impl.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG_IMPL(std::string, certificate_file, "", + "Path to the certificate chain."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG_IMPL(std::string, key_file, "", + "Path to the pkcs8 private key."); + +namespace quiche { + +// TODO(vasilvv): implement this in order for the CLI tools to work. +std::unique_ptr CreateDefaultProofVerifierImpl( + const std::string& /*host*/) { + return nullptr; +} + +std::unique_ptr CreateDefaultProofSourceImpl() { + std::string certificate_file = + quiche::GetQuicheCommandLineFlag(FLAGS_certificate_file); + if (certificate_file.empty()) { + // TODO(b/275440369): switch to QUICHE_LOG(FATAL) when available. + std::cerr << "QUIC ProofSource needs a certificate file, but " + "--certificate_file was empty." + << std::endl; + exit(1); + } + + std::string key_file = quiche::GetQuicheCommandLineFlag(FLAGS_key_file); + if (key_file.empty()) { + // TODO(b/275440369): switch to QUICHE_LOG(FATAL) when available. + std::cerr + << "QUIC ProofSource needs a private key, but --key_file was empty." + << std::endl; + exit(1); + } + + std::ifstream cert_stream(certificate_file, std::ios::binary); + std::vector certs = + quic::CertificateView::LoadPemFromStream(&cert_stream); + if (certs.empty()) { + // TODO(b/275440369): switch to QUICHE_LOG(FATAL) when available. + std::cerr << "Failed to load certificate chain from --certificate_file=" + << certificate_file << std::endl; + exit(1); + } + + std::ifstream key_stream(key_file, std::ios::binary); + std::unique_ptr private_key = + quic::CertificatePrivateKey::LoadPemFromStream(&key_stream); + if (private_key == nullptr) { + // TODO(b/275440369): switch to QUICHE_LOG(FATAL) when available. + std::cerr << "Failed to load private key from --key_file=" << key_file + << std::endl; + exit(1); + } + + QuicheReferenceCountedPointer chain( + new quic::ProofSource::Chain({certs})); + return quic::ProofSourceX509::Create(chain, std::move(*private_key)); +} + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h new file mode 100644 index 000000000000..208acc6cfff0 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_default_proof_providers_impl.h @@ -0,0 +1,21 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_DEFAULT_PROOF_PROVIDERS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_DEFAULT_PROOF_PROVIDERS_IMPL_H_ + +#include + +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/proof_verifier.h" + +namespace quiche { + +std::unique_ptr CreateDefaultProofVerifierImpl( + const std::string& host); +std::unique_ptr CreateDefaultProofSourceImpl(); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_DEFAULT_PROOF_PROVIDERS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h new file mode 100644 index 000000000000..44613ce46484 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_event_loop_impl.h @@ -0,0 +1,27 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EVENT_LOOP_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EVENT_LOOP_IMPL_H_ + +#include + +namespace quic { +class QuicEventLoopFactory; +} + +namespace quiche { + +inline quic::QuicEventLoopFactory* GetOverrideForDefaultEventLoopImpl() { + return nullptr; +} + +inline std::vector +GetExtraEventLoopImplementationsImpl() { + return {}; +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EVENT_LOOP_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h new file mode 100644 index 000000000000..89b55562ecb6 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_expect_bug_impl.h @@ -0,0 +1,15 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EXPECT_BUG_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EXPECT_BUG_IMPL_H_ + +#include "quiche/common/platform/api/quiche_test.h" + +#define EXPECT_QUICHE_BUG_IMPL(statement, regex) \ + EXPECT_QUICHE_DEBUG_DEATH(statement, regex) +#define EXPECT_QUICHE_PEER_BUG_IMPL(statement, regex) \ + EXPECT_QUICHE_DEBUG_DEATH(statement, regex) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EXPECT_BUG_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h new file mode 100644 index 000000000000..64396f7044c3 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h @@ -0,0 +1,18 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EXPORT_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EXPORT_IMPL_H_ + +#include "absl/base/attributes.h" + +// These macros are documented in: quiche/quic/platform/api/quic_export.h + +#if defined(_WIN32) +#define QUICHE_EXPORT_IMPL __declspec(dllexport) +#elif ABSL_HAVE_ATTRIBUTE(visibility) +#define QUICHE_EXPORT_IMPL __attribute__((visibility("default"))) +#else +#define QUICHE_EXPORT_IMPL +#endif + +#define QUICHE_NO_EXPORT_IMPL QUICHE_EXPORT_IMPL + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_EXPORT_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc new file mode 100644 index 000000000000..65965b222d5c --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc @@ -0,0 +1,182 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_file_utils_impl.h" + +#if defined(_WIN32) +#include +#else +#include +#include +#include +#include +#endif // defined(_WIN32) + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" + +namespace quiche { + +#if defined(_WIN32) +std::string JoinPathImpl(absl::string_view a, absl::string_view b) { + if (a.empty()) { + return std::string(b); + } + if (b.empty()) { + return std::string(a); + } + // Win32 actually provides two different APIs for combining paths; one of them + // has issues that could potentially lead to buffer overflow, and another is + // not supported in Windows 7, which is why we're doing it manually. + a = absl::StripSuffix(a, "/"); + a = absl::StripSuffix(a, "\\"); + return absl::StrCat(a, "\\", b); +} +#else +std::string JoinPathImpl(absl::string_view a, absl::string_view b) { + if (a.empty()) { + return std::string(b); + } + if (b.empty()) { + return std::string(a); + } + return absl::StrCat(absl::StripSuffix(a, "/"), "/", b); +} +#endif // defined(_WIN32) + +absl::optional ReadFileContentsImpl(absl::string_view file) { + std::ifstream input_file(std::string{file}, std::ios::binary); + if (!input_file || !input_file.is_open()) { + return absl::nullopt; + } + + input_file.seekg(0, std::ios_base::end); + auto file_size = input_file.tellg(); + if (!input_file) { + return absl::nullopt; + } + input_file.seekg(0, std::ios_base::beg); + + std::string output; + output.resize(file_size); + input_file.read(&output[0], file_size); + if (!input_file) { + return absl::nullopt; + } + + return output; +} + +#if defined(_WIN32) + +class ScopedDir { + public: + ScopedDir(HANDLE dir) : dir_(dir) {} + ~ScopedDir() { + if (dir_ != INVALID_HANDLE_VALUE) { + // The API documentation explicitly says that CloseHandle() should not be + // used on directory search handles. + FindClose(dir_); + dir_ = INVALID_HANDLE_VALUE; + } + } + + HANDLE get() { return dir_; } + + private: + HANDLE dir_; +}; + +bool EnumerateDirectoryImpl(absl::string_view path, + std::vector& directories, + std::vector& files) { + std::string path_owned(path); + + // Explicitly check that the directory we are trying to search is in fact a + // directory. + DWORD attributes = GetFileAttributesA(path_owned.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES) { + return false; + } + if ((attributes & FILE_ATTRIBUTE_DIRECTORY) == 0) { + return false; + } + + std::string search_path = JoinPathImpl(path, "*"); + WIN32_FIND_DATAA file_data; + ScopedDir dir(FindFirstFileA(search_path.c_str(), &file_data)); + if (dir.get() == INVALID_HANDLE_VALUE) { + return GetLastError() == ERROR_FILE_NOT_FOUND; + } + do { + std::string filename(file_data.cFileName); + if (filename == "." || filename == "..") { + continue; + } + if ((file_data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0) { + directories.push_back(std::move(filename)); + } else { + files.push_back(std::move(filename)); + } + } while (FindNextFileA(dir.get(), &file_data)); + return GetLastError() == ERROR_NO_MORE_FILES; +} + +#else // defined(_WIN32) + +class ScopedDir { + public: + ScopedDir(DIR* dir) : dir_(dir) {} + ~ScopedDir() { + if (dir_ != nullptr) { + closedir(dir_); + dir_ = nullptr; + } + } + + DIR* get() { return dir_; } + + private: + DIR* dir_; +}; + +bool EnumerateDirectoryImpl(absl::string_view path, + std::vector& directories, + std::vector& files) { + std::string path_owned(path); + ScopedDir dir(opendir(path_owned.c_str())); + if (dir.get() == nullptr) { + return false; + } + + dirent* entry; + while ((entry = readdir(dir.get()))) { + const std::string filename(entry->d_name); + if (filename == "." || filename == "..") { + continue; + } + + const std::string entry_path = JoinPathImpl(path, filename); + struct stat stat_entry; + if (stat(entry_path.c_str(), &stat_entry) != 0) { + return false; + } + if (S_ISREG(stat_entry.st_mode)) { + files.push_back(std::move(filename)); + } else if (S_ISDIR(stat_entry.st_mode)) { + directories.push_back(std::move(filename)); + } + } + return true; +} + +#endif // defined(_WIN32) + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h new file mode 100644 index 000000000000..ad5ff1a9084f --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h @@ -0,0 +1,26 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FILE_UTILS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FILE_UTILS_IMPL_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace quiche { + +std::string JoinPathImpl(absl::string_view a, absl::string_view b); + +absl::optional ReadFileContentsImpl(absl::string_view file); + +bool EnumerateDirectoryImpl(absl::string_view path, + std::vector& directories, + std::vector& files); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FILE_UTILS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h new file mode 100644 index 000000000000..c38f75c67dac --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_flag_utils_impl.h @@ -0,0 +1,29 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FLAG_UTILS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FLAG_UTILS_IMPL_H_ + +#define QUICHE_RELOADABLE_FLAG_COUNT_IMPL(flag) \ + do { \ + } while (0) +#define QUICHE_RELOADABLE_FLAG_COUNT_N_IMPL(flag, instance, total) \ + do { \ + } while (0) + +#define QUICHE_RESTART_FLAG_COUNT_IMPL(flag) \ + do { \ + } while (0) +#define QUICHE_RESTART_FLAG_COUNT_N_IMPL(flag, instance, total) \ + do { \ + } while (0) + +#define QUICHE_CODE_COUNT_IMPL(name) \ + do { \ + } while (0) +#define QUICHE_CODE_COUNT_N_IMPL(name, instance, total) \ + do { \ + } while (0) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FLAG_UTILS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.cc new file mode 100644 index 000000000000..eb609835fa91 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_flags_impl.h" + +#define QUIC_FLAG(flag, value) bool FLAGS_##flag = value; +#include "quiche/quic/core/quic_flags_list.h" +#undef QUIC_FLAG + +#define DEFINE_QUIC_PROTOCOL_FLAG_SINGLE_VALUE(type, flag, value, doc) \ + type FLAGS_##flag = value; + +#define DEFINE_QUIC_PROTOCOL_FLAG_TWO_VALUES(type, flag, internal_value, \ + external_value, doc) \ + type FLAGS_##flag = external_value; + +// Preprocessor macros can only have one definition. +// Select the right macro based on the number of arguments. +#define GET_6TH_ARG(arg1, arg2, arg3, arg4, arg5, arg6, ...) arg6 +#define QUIC_PROTOCOL_FLAG_MACRO_CHOOSER(...) \ + GET_6TH_ARG(__VA_ARGS__, DEFINE_QUIC_PROTOCOL_FLAG_TWO_VALUES, \ + DEFINE_QUIC_PROTOCOL_FLAG_SINGLE_VALUE) +#define QUIC_PROTOCOL_FLAG(...) \ + QUIC_PROTOCOL_FLAG_MACRO_CHOOSER(__VA_ARGS__)(__VA_ARGS__) + +#include "quiche/quic/core/quic_protocol_flags_list.h" + +#undef QUIC_PROTOCOL_FLAG +#undef QUIC_PROTOCOL_FLAG_MACRO_CHOOSER +#undef GET_6TH_ARG +#undef DEFINE_QUIC_PROTOCOL_FLAG_TWO_VALUES +#undef DEFINE_QUIC_PROTOCOL_FLAG_SINGLE_VALUE + +#define QUICHE_PROTOCOL_FLAG(type, flag, value, doc) type FLAGS_##flag = value; +#include "quiche/common/quiche_protocol_flags_list.h" +#undef QUICHE_PROTOCOL_FLAG diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.h new file mode 100644 index 000000000000..4565cd8d5dab --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_flags_impl.h @@ -0,0 +1,57 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FLAGS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FLAGS_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +#define QUIC_FLAG(flag, value) QUICHE_EXPORT extern bool FLAGS_##flag; +#include "quiche/quic/core/quic_flags_list.h" +#undef QUIC_FLAG + +// Protocol flags. TODO(bnc): Move to quiche_protocol_flags_list.h. +#define QUIC_PROTOCOL_FLAG(type, flag, ...) \ + QUICHE_EXPORT extern type FLAGS_##flag; +#include "quiche/quic/core/quic_protocol_flags_list.h" +#undef QUIC_PROTOCOL_FLAG + +// Protocol flags. +#define QUICHE_PROTOCOL_FLAG(type, flag, ...) \ + QUICHE_EXPORT extern type FLAGS_##flag; +#include "quiche/common/quiche_protocol_flags_list.h" +#undef QUICHE_PROTOCOL_FLAG + +#define GetQuicheFlagImpl(flag) GetQuicheFlagImplImpl(FLAGS_##flag) +inline bool GetQuicheFlagImplImpl(bool flag) { return flag; } +inline int32_t GetQuicheFlagImplImpl(int32_t flag) { return flag; } +inline int64_t GetQuicheFlagImplImpl(int64_t flag) { return flag; } +inline uint64_t GetQuicheFlagImplImpl(uint64_t flag) { return flag; } +inline double GetQuicheFlagImplImpl(double flag) { return flag; } +inline std::string GetQuicheFlagImplImpl(const std::string& flag) { + return flag; +} +#define SetQuicheFlagImpl(flag, value) ((FLAGS_##flag) = (value)) + +// ------------------------------------------------------------------------ +// QUICHE feature flags implementation. +// ------------------------------------------------------------------------ +#define QUICHE_RELOADABLE_FLAG(flag) quic_reloadable_flag_##flag +#define QUICHE_RESTART_FLAG(flag) quic_restart_flag_##flag +#define GetQuicheReloadableFlagImpl(module, flag) \ + GetQuicheFlag(QUICHE_RELOADABLE_FLAG(flag)) +#define SetQuicheReloadableFlagImpl(module, flag, value) \ + SetQuicheFlag(QUICHE_RELOADABLE_FLAG(flag), value) +#define GetQuicheRestartFlagImpl(module, flag) \ + GetQuicheFlag(QUICHE_RESTART_FLAG(flag)) +#define SetQuicheRestartFlagImpl(module, flag, value) \ + SetQuicheFlag(QUICHE_RESTART_FLAG(flag), value) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FLAGS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h new file mode 100644 index 000000000000..f9f32b581587 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_header_policy_impl.h @@ -0,0 +1,16 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_HEADER_POLICY_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_HEADER_POLICY_IMPL_H_ + +#include "absl/strings/string_view.h" + +namespace quiche { + +inline void QuicheHandleHeaderPolicyImpl(absl::string_view /*key*/) {} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_HEADER_POLICY_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_iovec_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_iovec_impl.h new file mode 100644 index 000000000000..b4ac17fb89b3 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_iovec_impl.h @@ -0,0 +1,24 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_IOVEC_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_IOVEC_IMPL_H_ + +#include "quiche/common/platform/api/quiche_export.h" + +#if defined(_WIN32) + +// See +struct QUICHE_EXPORT iovec { + void* iov_base; + size_t iov_len; +}; + +#else + +#include // IWYU pragma: export + +#endif // defined(_WIN32) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_IOVEC_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_logging_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_logging_impl.h new file mode 100644 index 000000000000..7ab2966e133f --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_logging_impl.h @@ -0,0 +1,160 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file does not actually implement logging, it merely provides enough of +// logging code for QUICHE to compile and pass the unit tests. QUICHE embedders +// are encouraged to override this file with their own logic. If at some point +// logging becomes a part of Abseil, this file will likely start using that +// instead. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_LOGGING_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_LOGGING_IMPL_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche_platform_impl/quiche_stack_trace_impl.h" + +namespace quiche { + +class QUICHE_EXPORT LogStreamVoidHelper { + public: + // This operator has lower precedence than << but higher than ?:, which is + // useful for implementing QUICHE_DISREGARD_LOG_STREAM below. + constexpr void operator&(std::ostream&) {} +}; + +// NoopLogSink provides a log sink that does not put the data that it logs +// anywhere. +class QUICHE_EXPORT NoopLogSink { + public: + NoopLogSink() {} + + template + constexpr NoopLogSink(const T&) {} + + template + constexpr NoopLogSink(const T1&, const T2&) {} + + constexpr std::ostream& stream() { return stream_; } + + protected: + std::string str() { return stream_.str(); } + + private: + std::stringstream stream_; +}; + +// We need to actually implement LOG(FATAL), otherwise some functions will fail +// to compile due to the "failed to return value from non-void function" error. +class QUICHE_EXPORT FatalLogSink : public NoopLogSink { + public: + ABSL_ATTRIBUTE_NORETURN ~FatalLogSink() { + std::cerr << str() << std::endl; + std::cerr << quiche::QuicheStackTraceImpl() << std::endl; + abort(); + } +}; + +class QUICHE_EXPORT CheckLogSink : public NoopLogSink { + public: + CheckLogSink(bool condition) : condition_(condition) {} + ~CheckLogSink() { + if (!condition_) { + std::cerr << "Check failed: " << str() << std::endl; + std::cerr << quiche::QuicheStackTraceImpl() << std::endl; + abort(); + } + } + + private: + const bool condition_; +}; + +} // namespace quiche + +// This is necessary because we sometimes call QUICHE_DCHECK inside constexpr +// functions, and then write non-constexpr expressions into the resulting log. +#define QUICHE_CONDITIONAL_LOG_STREAM(stream, condition) \ + !(condition) ? (void)0 : ::quiche::LogStreamVoidHelper() & (stream) +#define QUICHE_DISREGARD_LOG_STREAM(stream) \ + QUICHE_CONDITIONAL_LOG_STREAM(stream, /*condition=*/false) +#define QUICHE_NOOP_STREAM() \ + QUICHE_DISREGARD_LOG_STREAM(::quiche::NoopLogSink().stream()) +#define QUICHE_NOOP_STREAM_WITH_CONDITION(condition) \ + QUICHE_DISREGARD_LOG_STREAM(::quiche::NoopLogSink(condition).stream()) + +#define QUICHE_DVLOG_IMPL(verbose_level) QUICHE_NOOP_STREAM() +#define QUICHE_DVLOG_IF_IMPL(verbose_level, condition) \ + QUICHE_NOOP_STREAM_WITH_CONDITION(condition) +#define QUICHE_DLOG_IMPL(severity) QUICHE_NOOP_STREAM() +#define QUICHE_VLOG_IMPL(verbose_level) QUICHE_NOOP_STREAM() +#define QUICHE_LOG_FIRST_N_IMPL(severity, n) QUICHE_NOOP_STREAM() +#define QUICHE_LOG_EVERY_N_SEC_IMPL(severity, seconds) QUICHE_NOOP_STREAM() + +#define QUICHE_LOG_IMPL(severity) QUICHE_LOG_IMPL_##severity() +#define QUICHE_LOG_IMPL_FATAL() ::quiche::FatalLogSink().stream() +#define QUICHE_LOG_IMPL_ERROR() ::quiche::NoopLogSink().stream() +#define QUICHE_LOG_IMPL_WARNING() ::quiche::NoopLogSink().stream() +#define QUICHE_LOG_IMPL_INFO() ::quiche::NoopLogSink().stream() + +#define QUICHE_LOG_IF_IMPL(severity, condition) \ + QUICHE_CONDITIONAL_LOG_STREAM(QUICHE_LOG_IMPL_##severity(), condition) + +#ifdef NDEBUG +#define QUICHE_LOG_IMPL_DFATAL() ::quiche::NoopLogSink().stream() +#define QUICHE_DLOG_IF_IMPL(severity, condition) \ + QUICHE_NOOP_STREAM_WITH_CONDITION(condition) +#else +#define QUICHE_LOG_IMPL_DFATAL() ::quiche::FatalLogSink().stream() +#define QUICHE_DLOG_IF_IMPL(severity, condition) \ + QUICHE_CONDITIONAL_LOG_STREAM(QUICHE_LOG_IMPL_##severity(), condition) +#endif + +#define QUICHE_PLOG_IMPL(severity) QUICHE_NOOP_STREAM() + +#define QUICHE_DLOG_INFO_IS_ON_IMPL() false +#define QUICHE_LOG_INFO_IS_ON_IMPL() false +#define QUICHE_LOG_WARNING_IS_ON_IMPL() false +#define QUICHE_LOG_ERROR_IS_ON_IMPL() false + +#define QUICHE_CHECK_IMPL(condition) \ + ::quiche::CheckLogSink(static_cast(condition)).stream() +#define QUICHE_CHECK_EQ_IMPL(val1, val2) \ + ::quiche::CheckLogSink((val1) == (val2)).stream() +#define QUICHE_CHECK_NE_IMPL(val1, val2) \ + ::quiche::CheckLogSink((val1) != (val2)).stream() +#define QUICHE_CHECK_LE_IMPL(val1, val2) \ + ::quiche::CheckLogSink((val1) <= (val2)).stream() +#define QUICHE_CHECK_LT_IMPL(val1, val2) \ + ::quiche::CheckLogSink((val1) < (val2)).stream() +#define QUICHE_CHECK_GE_IMPL(val1, val2) \ + ::quiche::CheckLogSink((val1) >= (val2)).stream() +#define QUICHE_CHECK_GT_IMPL(val1, val2) \ + ::quiche::CheckLogSink((val1) > (val2)).stream() +#define QUICHE_CHECK_OK_IMPL(status) \ + QUICHE_CHECK_EQ_IMPL(absl::OkStatus(), (status)) + +#ifdef NDEBUG +#define QUICHE_DCHECK_IMPL(condition) \ + QUICHE_NOOP_STREAM_WITH_CONDITION((condition)) +#else +#define QUICHE_DCHECK_IMPL(condition) \ + QUICHE_LOG_IF_IMPL(DFATAL, !static_cast(condition)) \ + << "Check failed: " << #condition +#endif +#define QUICHE_DCHECK_EQ_IMPL(val1, val2) QUICHE_DCHECK_IMPL((val1) == (val2)) +#define QUICHE_DCHECK_NE_IMPL(val1, val2) QUICHE_DCHECK_IMPL((val1) != (val2)) +#define QUICHE_DCHECK_LE_IMPL(val1, val2) QUICHE_DCHECK_IMPL((val1) <= (val2)) +#define QUICHE_DCHECK_LT_IMPL(val1, val2) QUICHE_DCHECK_IMPL((val1) < (val2)) +#define QUICHE_DCHECK_GE_IMPL(val1, val2) QUICHE_DCHECK_IMPL((val1) >= (val2)) +#define QUICHE_DCHECK_GT_IMPL(val1, val2) QUICHE_DCHECK_IMPL((val1) > (val2)) + +#define QUICHE_NOTREACHED_IMPL() QUICHE_DCHECK_IMPL(false) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_LOGGING_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h new file mode 100644 index 000000000000..576101232a80 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_lower_case_string_impl.h @@ -0,0 +1,25 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_LOWER_CASE_STRING_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_LOWER_CASE_STRING_IMPL_H_ + +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +class QUICHE_EXPORT QuicheLowerCaseStringImpl { + public: + QuicheLowerCaseStringImpl(absl::string_view str) + : str_(absl::AsciiStrToLower(str)) {} + + const std::string& get() const { return str_; } + + private: + std::string str_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_LOWER_CASE_STRING_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h new file mode 100644 index 000000000000..b422b5c7f03e --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_mem_slice_impl.h @@ -0,0 +1,46 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_MEM_SLICE_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_MEM_SLICE_IMPL_H_ + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quiche { + +class QUICHE_EXPORT QuicheMemSliceImpl { + public: + QuicheMemSliceImpl() = default; + + explicit QuicheMemSliceImpl(QuicheBuffer buffer) + : buffer_(std::move(buffer)) {} + + QuicheMemSliceImpl(std::unique_ptr buffer, size_t length) + : buffer_( + QuicheBuffer(QuicheUniqueBufferPtr( + buffer.release(), + QuicheBufferDeleter(SimpleBufferAllocator::Get())), + length)) {} + + QuicheMemSliceImpl(const QuicheMemSliceImpl& other) = delete; + QuicheMemSliceImpl& operator=(const QuicheMemSliceImpl& other) = delete; + + // Move constructors. |other| will not hold a reference to the data buffer + // after this call completes. + QuicheMemSliceImpl(QuicheMemSliceImpl&& other) = default; + QuicheMemSliceImpl& operator=(QuicheMemSliceImpl&& other) = default; + + ~QuicheMemSliceImpl() = default; + + void Reset() { buffer_ = QuicheBuffer(); } + + const char* data() const { return buffer_.data(); } + size_t length() const { return buffer_.size(); } + bool empty() const { return buffer_.empty(); } + + private: + QuicheBuffer buffer_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_MEM_SLICE_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc new file mode 100644 index 000000000000..78ca92661536 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.cc @@ -0,0 +1,15 @@ +#include "quiche_platform_impl/quiche_mutex_impl.h" + +namespace quiche { + +void QuicheLockImpl::WriterLock() { mu_.WriterLock(); } + +void QuicheLockImpl::WriterUnlock() { mu_.WriterUnlock(); } + +void QuicheLockImpl::ReaderLock() { mu_.ReaderLock(); } + +void QuicheLockImpl::ReaderUnlock() { mu_.ReaderUnlock(); } + +void QuicheLockImpl::AssertReaderHeld() const { mu_.AssertReaderHeld(); } + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.h new file mode 100644 index 000000000000..c6f6655671d8 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_mutex_impl.h @@ -0,0 +1,68 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_MUTEX_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_MUTEX_IMPL_H_ + +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "quiche/common/platform/api/quiche_export.h" + +#define QUICHE_EXCLUSIVE_LOCKS_REQUIRED_IMPL ABSL_EXCLUSIVE_LOCKS_REQUIRED +#define QUICHE_GUARDED_BY_IMPL ABSL_GUARDED_BY +#define QUICHE_LOCKABLE_IMPL ABSL_LOCKABLE +#define QUICHE_LOCKS_EXCLUDED_IMPL ABSL_LOCKS_EXCLUDED +#define QUICHE_SHARED_LOCKS_REQUIRED_IMPL ABSL_SHARED_LOCKS_REQUIRED +#define QUICHE_EXCLUSIVE_LOCK_FUNCTION_IMPL ABSL_EXCLUSIVE_LOCK_FUNCTION +#define QUICHE_UNLOCK_FUNCTION_IMPL ABSL_UNLOCK_FUNCTION +#define QUICHE_SHARED_LOCK_FUNCTION_IMPL ABSL_SHARED_LOCK_FUNCTION +#define QUICHE_SCOPED_LOCKABLE_IMPL ABSL_SCOPED_LOCKABLE +#define QUICHE_ASSERT_SHARED_LOCK_IMPL ABSL_ASSERT_SHARED_LOCK + +namespace quiche { + +// A class wrapping a non-reentrant mutex. +class ABSL_LOCKABLE QUICHE_EXPORT QuicheLockImpl { + public: + QuicheLockImpl() = default; + QuicheLockImpl(const QuicheLockImpl&) = delete; + QuicheLockImpl& operator=(const QuicheLockImpl&) = delete; + + // Block until mu_ is free, then acquire it exclusively. + void WriterLock() ABSL_EXCLUSIVE_LOCK_FUNCTION(); + + // Release mu_. Caller must hold it exclusively. + void WriterUnlock() ABSL_UNLOCK_FUNCTION(); + + // Block until mu_ is free or shared, then acquire a share of it. + void ReaderLock() ABSL_SHARED_LOCK_FUNCTION(); + + // Release mu_. Caller could hold it in shared mode. + void ReaderUnlock() ABSL_UNLOCK_FUNCTION(); + + // Returns immediately if current thread holds mu_ in at least shared + // mode. Otherwise, reports an error by crashing with a diagnostic. + void AssertReaderHeld() const ABSL_ASSERT_SHARED_LOCK(); + + private: + absl::Mutex mu_; +}; + +// A Notification allows threads to receive notification of a single occurrence +// of a single event. +class QUICHE_EXPORT QuicheNotificationImpl { + public: + QuicheNotificationImpl() = default; + QuicheNotificationImpl(const QuicheNotificationImpl&) = delete; + QuicheNotificationImpl& operator=(const QuicheNotificationImpl&) = delete; + + bool HasBeenNotified() { return notification_.HasBeenNotified(); } + + void Notify() { notification_.Notify(); } + + void WaitForNotification() { notification_.WaitForNotification(); } + + private: + absl::Notification notification_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_MUTEX_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h new file mode 100644 index 000000000000..89454817d877 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h @@ -0,0 +1,28 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_PREFETCH_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_PREFETCH_IMPL_H_ + +#if defined(_MSC_VER) +#include +#endif + +namespace quiche { + +inline void QuichePrefetchT0Impl(const void* addr) { +#if !defined(DISABLE_BUILTIN_PREFETCH) +#if defined(__GNUC__) || (defined(_M_ARM64) && defined(__clang__)) + __builtin_prefetch(addr, 0, 3); +#elif defined(_MSC_VER) + _mm_prefetch(reinterpret_cast(addr), _MM_HINT_T0); +#else + (void*)addr; +#endif +#endif // !defined(DISABLE_BUILTIN_PREFETCH) +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_PREFETCH_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h new file mode 100644 index 000000000000..b568da6ca2f6 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_reference_counted_impl.h @@ -0,0 +1,190 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_REFERENCE_COUNTED_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_REFERENCE_COUNTED_IMPL_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +class QUICHE_EXPORT QuicheReferenceCountedImpl { + public: + virtual ~QuicheReferenceCountedImpl() { QUICHE_DCHECK_EQ(ref_count_, 0); } + + void AddReference() { ref_count_.fetch_add(1, std::memory_order_relaxed); } + + // Returns true if the objects needs to be deleted. + ABSL_MUST_USE_RESULT bool RemoveReference() { + int new_count = ref_count_.fetch_sub(1, std::memory_order_acq_rel) - 1; + QUICHE_DCHECK_GE(new_count, 0); + return new_count == 0; + } + + bool HasUniqueReference() const { + return ref_count_.load(std::memory_order_acquire) == 1; + } + + private: + std::atomic ref_count_ = 1; +}; + +template +class QUICHE_NO_EXPORT QuicheReferenceCountedPointerImpl { + public: + QuicheReferenceCountedPointerImpl() = default; + ~QuicheReferenceCountedPointerImpl() { RemoveReference(); } + + // Constructor from raw pointer |p|. This guarantees that the reference count + // of *p is 1. This should be only called when a new object is created. + explicit QuicheReferenceCountedPointerImpl(T* p) : object_(p) { + if (p != nullptr) { + QUICHE_DCHECK(p->HasUniqueReference()); + } + } + + explicit QuicheReferenceCountedPointerImpl(std::nullptr_t) + : object_(nullptr) {} + + // Copy and copy conversion constructors. + QuicheReferenceCountedPointerImpl( + const QuicheReferenceCountedPointerImpl& other) { + AssignObject(other.get()); + AddReference(); + } + template + QuicheReferenceCountedPointerImpl( // NOLINT + const QuicheReferenceCountedPointerImpl& other) { + AssignObject(other.get()); + AddReference(); + } + + // Move constructors. + QuicheReferenceCountedPointerImpl(QuicheReferenceCountedPointerImpl&& other) { + object_ = other.object_; + other.object_ = nullptr; + } + template + QuicheReferenceCountedPointerImpl( + QuicheReferenceCountedPointerImpl&& other) { // NOLINT + // We can't access other.object_ since other has different T and object_ is + // private. + object_ = other.get(); + AddReference(); + other = nullptr; + } + + // Copy assignments. + QuicheReferenceCountedPointerImpl& operator=( + const QuicheReferenceCountedPointerImpl& other) { + AssignObject(other.object_); + AddReference(); + return *this; + } + template + QuicheReferenceCountedPointerImpl& operator=( + const QuicheReferenceCountedPointerImpl& other) { + AssignObject(other.object_); + AddReference(); + return *this; + } + + // Move assignments. + QuicheReferenceCountedPointerImpl& operator=( + QuicheReferenceCountedPointerImpl&& other) { + AssignObject(other.object_); + other.object_ = nullptr; + return *this; + } + template + QuicheReferenceCountedPointerImpl& operator=( + QuicheReferenceCountedPointerImpl&& other) { + AssignObject(other.get()); + AddReference(); + other = nullptr; + return *this; + } + + T& operator*() const { return *object_; } + T* operator->() const { return object_; } + + explicit operator bool() const { return object_ != nullptr; } + + // Assignment operator on raw pointer. Behaves similar to the raw pointer + // constructor. + QuicheReferenceCountedPointerImpl& operator=(T* p) { + AssignObject(p); + if (p != nullptr) { + QUICHE_DCHECK(p->HasUniqueReference()); + } + return *this; + } + + // Returns the raw pointer with no change in reference count. + T* get() const { return object_; } + + // Comparisons against same type. + friend bool operator==(const QuicheReferenceCountedPointerImpl& a, + const QuicheReferenceCountedPointerImpl& b) { + return a.get() == b.get(); + } + friend bool operator!=(const QuicheReferenceCountedPointerImpl& a, + const QuicheReferenceCountedPointerImpl& b) { + return a.get() != b.get(); + } + + // Comparisons against nullptr. + friend bool operator==(const QuicheReferenceCountedPointerImpl& a, + std::nullptr_t) { + return a.get() == nullptr; + } + friend bool operator==(std::nullptr_t, + const QuicheReferenceCountedPointerImpl& b) { + return nullptr == b.get(); + } + friend bool operator!=(const QuicheReferenceCountedPointerImpl& a, + std::nullptr_t) { + return a.get() != nullptr; + } + friend bool operator!=(std::nullptr_t, + const QuicheReferenceCountedPointerImpl& b) { + return nullptr != b.get(); + } + + private: + void AddReference() { + if (object_ == nullptr) { + return; + } + QuicheReferenceCountedImpl* implicitly_cast_object = object_; + implicitly_cast_object->AddReference(); + } + + void RemoveReference() { + if (object_ == nullptr) { + return; + } + QuicheReferenceCountedImpl* implicitly_cast_object = object_; + if (implicitly_cast_object->RemoveReference()) { + delete implicitly_cast_object; + } + object_ = nullptr; + } + + void AssignObject(T* new_object) { + RemoveReference(); + object_ = new_object; + } + + T* object_ = nullptr; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_REFERENCE_COUNTED_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h new file mode 100644 index 000000000000..289c5990cc2c --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_server_stats_impl.h @@ -0,0 +1,26 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SERVER_STATS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SERVER_STATS_IMPL_H_ + +#define QUICHE_SERVER_HISTOGRAM_ENUM_IMPL(name, sample, enum_size, docstring) \ + do { \ + } while (0) + +#define QUICHE_SERVER_HISTOGRAM_BOOL_IMPL(name, sample, docstring) \ + do { \ + } while (0) + +#define QUICHE_SERVER_HISTOGRAM_TIMES_IMPL(name, sample, min, max, \ + bucket_count, docstring) \ + do { \ + } while (0) + +#define QUICHE_SERVER_HISTOGRAM_COUNTS_IMPL(name, sample, min, max, \ + bucket_count, docstring) \ + do { \ + } while (0) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SERVER_STATS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc new file mode 100644 index 000000000000..9b0c969fe0d2 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.cc @@ -0,0 +1,52 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_stack_trace_impl.h" + +#include + +#include "absl/base/macros.h" +#include "absl/debugging/stacktrace.h" +#include "absl/debugging/symbolize.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace quiche { + +namespace { +constexpr int kMaxStackSize = 4096; +constexpr int kMaxSymbolSize = 1024; +constexpr absl::string_view kUnknownSymbol = "(unknown)"; +} // namespace + +std::string QuicheStackTraceImpl() { + std::vector stacktrace(kMaxStackSize, nullptr); + int num_frames = absl::GetStackTrace(stacktrace.data(), stacktrace.size(), + /*skip_count=*/0); + if (num_frames <= 0) { + return ""; + } + stacktrace.resize(num_frames); + + std::string formatted_trace = "Stack trace:\n"; + for (void* function : stacktrace) { + char symbol_name[kMaxSymbolSize]; + bool success = absl::Symbolize(function, symbol_name, sizeof(symbol_name)); + absl::StrAppendFormat( + &formatted_trace, " %p %s\n", function, + success ? absl::string_view(symbol_name) : kUnknownSymbol); + } + return formatted_trace; +} + +bool QuicheShouldRunStackTraceTestImpl() { + void* unused[4]; // An arbitrary small number of stack frames to trace. + int stack_traces_found = + absl::GetStackTrace(unused, ABSL_ARRAYSIZE(unused), /*skip_count=*/0); + // absl::GetStackTrace() always returns 0 if the current platform is + // unsupported. + return stack_traces_found > 0; +} + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h new file mode 100644 index 000000000000..5228e41197f8 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_stack_trace_impl.h @@ -0,0 +1,17 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_STACK_TRACE_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_STACK_TRACE_IMPL_H_ + +#include + +namespace quiche { + +std::string QuicheStackTraceImpl(); +bool QuicheShouldRunStackTraceTestImpl(); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_STACK_TRACE_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h new file mode 100644 index 000000000000..ec56ba55b7fc --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_stream_buffer_allocator_impl.h @@ -0,0 +1,16 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_STREAM_BUFFER_ALLOCATOR_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_STREAM_BUFFER_ALLOCATOR_IMPL_H_ + +#include "quiche/common/simple_buffer_allocator.h" + +namespace quiche { + +using QuicheStreamBufferAllocatorImpl = quiche::SimpleBufferAllocator; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_STREAM_BUFFER_ALLOCATOR_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h new file mode 100644 index 000000000000..f810fa825cd1 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_system_event_loop_impl.h @@ -0,0 +1,24 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SYSTEM_EVENT_LOOP_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SYSTEM_EVENT_LOOP_IMPL_H_ + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +inline void QuicheRunSystemEventLoopIterationImpl() {} + +class QUICHE_EXPORT QuicheSystemEventLoopImpl { + public: + QuicheSystemEventLoopImpl(std::string context_name) { + QUICHE_LOG(INFO) << "Starting event loop for " << context_name; + } +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SYSTEM_EVENT_LOOP_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.cc new file mode 100644 index 000000000000..933b996af5e2 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.cc @@ -0,0 +1,25 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_test_impl.h" + +#include "quiche/common/platform/api/quiche_flags.h" + +QuicheFlagSaverImpl::QuicheFlagSaverImpl() { +#define QUIC_FLAG(flag, value) saved_##flag##_ = FLAGS_##flag; +#include "quiche/quic/core/quic_flags_list.h" +#undef QUIC_FLAG +#define QUIC_PROTOCOL_FLAG(type, flag, ...) saved_##flag##_ = FLAGS_##flag; +#include "quiche/quic/core/quic_protocol_flags_list.h" +#undef QUIC_PROTOCOL_FLAG +} + +QuicheFlagSaverImpl::~QuicheFlagSaverImpl() { +#define QUIC_FLAG(flag, value) FLAGS_##flag = saved_##flag##_; +#include "quiche/quic/core/quic_flags_list.h" // NOLINT +#undef QUIC_FLAG +#define QUIC_PROTOCOL_FLAG(type, flag, ...) FLAGS_##flag = saved_##flag##_; +#include "quiche/quic/core/quic_protocol_flags_list.h" // NOLINT +#undef QUIC_PROTOCOL_FLAG +} diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.h new file mode 100644 index 000000000000..258b86286c36 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_test_impl.h @@ -0,0 +1,56 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_IMPL_H_ + +#include "gmock/gmock.h" +#include "gtest/gtest-spi.h" +#include "gtest/gtest.h" + +#define EXPECT_QUICHE_DEBUG_DEATH_IMPL(condition, message) \ + EXPECT_DEBUG_DEATH(condition, message) + +#define QUICHE_TEST_DISABLED_IN_CHROME_IMPL(name) name +#define QUICHE_SLOW_TEST_IMPL(test) test + +class QuicheFlagSaverImpl { + public: + QuicheFlagSaverImpl(); + ~QuicheFlagSaverImpl(); + + private: +#define QUIC_FLAG(flag, value) bool saved_##flag##_; +#include "quiche/quic/core/quic_flags_list.h" +#undef QUIC_FLAG + +#define QUIC_PROTOCOL_FLAG(type, flag, ...) type saved_##flag##_; +#include "quiche/quic/core/quic_protocol_flags_list.h" +#undef QUIC_PROTOCOL_FLAG +}; + +class ScopedEnvironmentForThreadsImpl {}; + +namespace quiche::test { + +class QuicheTestImpl : public ::testing::Test { + private: + QuicheFlagSaverImpl saver_; +}; + +template +class QuicheTestWithParamImpl : public ::testing::TestWithParam { + private: + QuicheFlagSaverImpl saver_; +}; + +inline std::string QuicheGetCommonSourcePathImpl() { return "quiche/common"; } + +} // namespace quiche::test + +inline std::string QuicheGetTestMemoryCachePathImpl() { + return "quiche/quic/test_tools/quic_http_response_cache_data"; +} + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc new file mode 100644 index 000000000000..661952e895ed --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_test_loopback_impl.h" + +namespace quiche { + +quic::IpAddressFamily AddressFamilyUnderTestImpl() { + return quic::IpAddressFamily::IP_V4; +} + +quic::QuicIpAddress TestLoopback4Impl() { + return quic::QuicIpAddress::Loopback4(); +} + +quic::QuicIpAddress TestLoopback6Impl() { + return quic::QuicIpAddress::Loopback6(); +} + +quic::QuicIpAddress TestLoopbackImpl() { + return quic::QuicIpAddress::Loopback4(); +} + +quic::QuicIpAddress TestLoopbackImpl(int index) { + const char kLocalhostIPv4[] = {127, 0, 0, static_cast(index)}; + quic::QuicIpAddress address; + address.FromPackedString(kLocalhostIPv4, 4); + return address; +} + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h new file mode 100644 index 000000000000..6c377c866832 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_test_loopback_impl.h @@ -0,0 +1,30 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_LOOPBACK_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_LOOPBACK_IMPL_H_ + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" + +namespace quiche { + +// Returns the address family IPv4 used to run test under. +quic::IpAddressFamily AddressFamilyUnderTestImpl(); + +// Returns an IPv4 loopback address. +quic::QuicIpAddress TestLoopback4Impl(); + +// Returns the only IPv6 loopback address. +quic::QuicIpAddress TestLoopback6Impl(); + +// Returns an IPv4 loopback address. +quic::QuicIpAddress TestLoopbackImpl(); + +// Returns an indexed IPv4 loopback address. +quic::QuicIpAddress TestLoopbackImpl(int index); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_LOOPBACK_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_test_output_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_test_output_impl.h new file mode 100644 index 000000000000..5fae61a4ee4f --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_test_output_impl.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_OUTPUT_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_OUTPUT_IMPL_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace quiche { + +inline void QuicheSaveTestOutputImpl(absl::string_view /*filename*/, + absl::string_view /*data*/) {} + +inline bool QuicheLoadTestOutputImpl(absl::string_view /*filename*/, + std::string* /*data*/) { + return false; +} + +inline void QuicheRecordTraceImpl(absl::string_view /*identifier*/, + absl::string_view /*data*/) {} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TEST_OUTPUT_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h new file mode 100644 index 000000000000..e88283fd892f --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_testvalue_impl.h @@ -0,0 +1,13 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TESTVALUE_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TESTVALUE_IMPL_H_ + +#include "absl/strings/string_view.h" + +namespace quiche { + +template +void AdjustTestValueImpl(absl::string_view /*label*/, T* /*var*/) {} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TESTVALUE_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_thread_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_thread_impl.h new file mode 100644 index 000000000000..7b5748c1114d --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_thread_impl.h @@ -0,0 +1,26 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_THREAD_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_THREAD_IMPL_H_ + +#include +#include // NOLINT: only used outside of google3 + +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_export.h" + +class QUICHE_NO_EXPORT QuicheThreadImpl { + public: + QuicheThreadImpl(const std::string&) {} + virtual ~QuicheThreadImpl() {} + + virtual void Run() = 0; + + void Start() { + thread_.emplace([this]() { Run(); }); + } + void Join() { thread_->join(); } + + private: + absl::optional thread_; +}; + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_THREAD_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc new file mode 100644 index 000000000000..ea4926986a80 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.cc @@ -0,0 +1,48 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_time_utils_impl.h" + +#include "absl/time/civil_time.h" +#include "absl/time/time.h" + +namespace quiche { + +namespace { +absl::optional QuicheUtcDateTimeToUnixSecondsInner(int year, int month, + int day, int hour, + int minute, + int second) { + const absl::CivilSecond civil_time(year, month, day, hour, minute, second); + if (second != 60 && + (civil_time.year() != year || civil_time.month() != month || + civil_time.day() != day || civil_time.hour() != hour || + civil_time.minute() != minute || civil_time.second() != second)) { + return absl::nullopt; + } + + const absl::Time time = absl::FromCivil(civil_time, absl::UTCTimeZone()); + return absl::ToUnixSeconds(time); +} +} // namespace + +absl::optional QuicheUtcDateTimeToUnixSecondsImpl(int year, int month, + int day, int hour, + int minute, + int second) { + // Handle leap seconds without letting any other irregularities happen. + if (second == 60) { + auto previous_second = QuicheUtcDateTimeToUnixSecondsInner( + year, month, day, hour, minute, second - 1); + if (!previous_second.has_value()) { + return absl::nullopt; + } + return *previous_second + 1; + } + + return QuicheUtcDateTimeToUnixSecondsInner(year, month, day, hour, minute, + second); +} + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h new file mode 100644 index 000000000000..24ef40fe77b8 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h @@ -0,0 +1,21 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TIME_UTILS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TIME_UTILS_IMPL_H_ + +#include + +#include "absl/types/optional.h" + +namespace quiche { + +absl::optional QuicheUtcDateTimeToUnixSecondsImpl(int year, int month, + int day, int hour, + int minute, + int second); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_TIME_UTILS_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h new file mode 100644 index 000000000000..e79ac225006e --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_udp_socket_platform_impl.h @@ -0,0 +1,37 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_UDP_SOCKET_PLATFORM_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_UDP_SOCKET_PLATFORM_IMPL_H_ + +#include +#include + +#include +#include + +namespace quiche { + +constexpr size_t kCmsgSpaceForGooglePacketHeaderImpl = 0; + +inline bool GetGooglePacketHeadersFromControlMessageImpl( + struct ::cmsghdr* /*cmsg*/, char** /*packet_headers*/, + size_t* /*packet_headers_len*/) { + return false; +} + +inline void SetGoogleSocketOptionsImpl(int /*fd*/) {} + +inline int GetEcnCmsgArgsPreserveDscpImpl(const int /*fd*/, + const int /*address_family*/, + uint8_t /*ecn_codepoint*/, + int& /*type*/, void* /*value*/, + socklen_t& /*value_len*/) { + // TODO(b/273081493): implement this. + return 0; +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_UDP_SOCKET_PLATFORM_IMPL_H_ diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc b/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc new file mode 100644 index 000000000000..68e4a5e207cb --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc @@ -0,0 +1,79 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche_platform_impl/quiche_url_utils_impl.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "url/url_canon.h" +#include "url/url_util.h" + +namespace quiche { + +bool ExpandURITemplateImpl( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + std::string* target, absl::flat_hash_set* vars_found) { + absl::flat_hash_set found; + std::string result = uri_template; + for (const auto& pair : parameters) { + const std::string& name = pair.first; + const std::string& value = pair.second; + std::string name_input = absl::StrCat("{", name, "}"); + url::RawCanonOutputT canon_value; + url::EncodeURIComponent(value.c_str(), value.length(), &canon_value); + std::string encoded_value(canon_value.data(), canon_value.length()); + int num_replaced = + absl::StrReplaceAll({{name_input, encoded_value}}, &result); + if (num_replaced > 0) { + found.insert(name); + } + } + // Remove any remaining variables that were not present in |parameters|. + while (true) { + size_t start = result.find('{'); + if (start == std::string::npos) { + break; + } + size_t end = result.find('}'); + if (end == std::string::npos || end <= start) { + return false; + } + result.erase(start, (end - start) + 1); + } + if (vars_found != nullptr) { + *vars_found = found; + } + *target = result; + return true; +} + +absl::optional AsciiUrlDecodeImpl(absl::string_view input) { + std::string input_encoded = std::string(input); + url::RawCanonOutputW<1024> canon_output; + url::DecodeURLEscapeSequences(input_encoded.c_str(), input_encoded.length(), + url::DecodeURLMode::kUTF8, + &canon_output); + std::string output; + output.reserve(canon_output.length()); + for (int i = 0; i < canon_output.length(); i++) { + const uint16_t c = reinterpret_cast(canon_output.data())[i]; + if (c > std::numeric_limits::max()) { + return absl::nullopt; + } + output += static_cast(c); + } + return output; +} + +} // namespace quiche diff --git a/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h b/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h new file mode 100644 index 000000000000..45d87e0b1d62 --- /dev/null +++ b/quiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h @@ -0,0 +1,35 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_URL_UTILS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_URL_UTILS_IMPL_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Produces concrete URLs in |target| from templated ones in |uri_template|. +// Parameters are URL-encoded. Collects the names of any expanded variables in +// |vars_found|. Supports level 1 templates as specified in RFC 6570. Returns +// true if the template was parseable, false if it was malformed. +QUICHE_EXPORT bool ExpandURITemplateImpl( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + std::string* target, + absl::flat_hash_set* vars_found = nullptr); + +// Decodes a URL-encoded string and converts it to ASCII. If the decoded input +// contains non-ASCII characters, decoding fails and absl::nullopt is returned. +QUICHE_EXPORT absl::optional AsciiUrlDecodeImpl( + absl::string_view input); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_URL_UTILS_IMPL_H_ diff --git a/quiche/common/print_elements.h b/quiche/common/print_elements.h new file mode 100644 index 000000000000..ae69c400292f --- /dev/null +++ b/quiche/common/print_elements.h @@ -0,0 +1,37 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PRINT_ELEMENTS_H_ +#define QUICHE_COMMON_PRINT_ELEMENTS_H_ + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Print elements of any iterable container that has cbegin() and cend() methods +// and the elements have operator<<(ostream) override. +template +QUICHE_EXPORT inline std::string PrintElements(const T& container) { + std::stringstream debug_string; + debug_string << "{"; + auto it = container.cbegin(); + if (it != container.cend()) { + debug_string << *it; + ++it; + while (it != container.cend()) { + debug_string << ", " << *it; + ++it; + } + } + debug_string << "}"; + return debug_string.str(); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PRINT_ELEMENTS_H_ diff --git a/quiche/common/print_elements_test.cc b/quiche/common/print_elements_test.cc new file mode 100644 index 000000000000..17f1b1ce29ff --- /dev/null +++ b/quiche/common/print_elements_test.cc @@ -0,0 +1,61 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/print_elements.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/common/platform/api/quiche_test.h" + +using quic::QuicIetfTransportErrorCodes; + +namespace quiche { +namespace test { +namespace { + +TEST(PrintElementsTest, Empty) { + std::vector empty{}; + EXPECT_EQ("{}", PrintElements(empty)); +} + +TEST(PrintElementsTest, StdContainers) { + std::vector one{"foo"}; + EXPECT_EQ("{foo}", PrintElements(one)); + + std::list two{"foo", "bar"}; + EXPECT_EQ("{foo, bar}", PrintElements(two)); + + std::deque three{"foo", "bar", "baz"}; + EXPECT_EQ("{foo, bar, baz}", PrintElements(three)); +} + +// QuicIetfTransportErrorCodes has a custom operator<<() override. +TEST(PrintElementsTest, CustomPrinter) { + std::vector empty{}; + EXPECT_EQ("{}", PrintElements(empty)); + + std::list one{ + QuicIetfTransportErrorCodes::NO_IETF_QUIC_ERROR}; + EXPECT_EQ("{NO_IETF_QUIC_ERROR}", PrintElements(one)); + + std::vector two{ + QuicIetfTransportErrorCodes::FLOW_CONTROL_ERROR, + QuicIetfTransportErrorCodes::STREAM_LIMIT_ERROR}; + EXPECT_EQ("{FLOW_CONTROL_ERROR, STREAM_LIMIT_ERROR}", PrintElements(two)); + + std::list three{ + QuicIetfTransportErrorCodes::CONNECTION_ID_LIMIT_ERROR, + QuicIetfTransportErrorCodes::PROTOCOL_VIOLATION, + QuicIetfTransportErrorCodes::INVALID_TOKEN}; + EXPECT_EQ("{CONNECTION_ID_LIMIT_ERROR, PROTOCOL_VIOLATION, INVALID_TOKEN}", + PrintElements(three)); +} + +} // anonymous namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_buffer_allocator.cc b/quiche/common/quiche_buffer_allocator.cc new file mode 100644 index 000000000000..9d53d9dada98 --- /dev/null +++ b/quiche/common/quiche_buffer_allocator.cc @@ -0,0 +1,76 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_buffer_allocator.h" + +#include + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_prefetch.h" + +namespace quiche { + +QuicheBuffer QuicheBuffer::CopyFromIovec(QuicheBufferAllocator* allocator, + const struct iovec* iov, int iov_count, + size_t iov_offset, + size_t buffer_length) { + if (buffer_length == 0) { + return {}; + } + + int iovnum = 0; + while (iovnum < iov_count && iov_offset >= iov[iovnum].iov_len) { + iov_offset -= iov[iovnum].iov_len; + ++iovnum; + } + QUICHE_DCHECK_LE(iovnum, iov_count); + if (iovnum >= iov_count) { + QUICHE_BUG(quiche_bug_10839_1) + << "iov_offset larger than iovec total size."; + return {}; + } + QUICHE_DCHECK_LE(iov_offset, iov[iovnum].iov_len); + + // Unroll the first iteration that handles iov_offset. + const size_t iov_available = iov[iovnum].iov_len - iov_offset; + size_t copy_len = std::min(buffer_length, iov_available); + + // Try to prefetch the next iov if there is at least one more after the + // current. Otherwise, it looks like an irregular access that the hardware + // prefetcher won't speculatively prefetch. Only prefetch one iov because + // generally, the iov_offset is not 0, input iov consists of 2K buffers and + // the output buffer is ~1.4K. + if (copy_len == iov_available && iovnum + 1 < iov_count) { + char* next_base = static_cast(iov[iovnum + 1].iov_base); + // Prefetch 2 cachelines worth of data to get the prefetcher started; leave + // it to the hardware prefetcher after that. + quiche::QuichePrefetchT0(next_base); + if (iov[iovnum + 1].iov_len >= 64) { + quiche::QuichePrefetchT0(next_base + ABSL_CACHELINE_SIZE); + } + } + + QuicheBuffer buffer(allocator, buffer_length); + + const char* src = static_cast(iov[iovnum].iov_base) + iov_offset; + char* dst = buffer.data(); + while (true) { + memcpy(dst, src, copy_len); + buffer_length -= copy_len; + dst += copy_len; + if (buffer_length == 0 || ++iovnum >= iov_count) { + break; + } + src = static_cast(iov[iovnum].iov_base); + copy_len = std::min(buffer_length, iov[iovnum].iov_len); + } + + QUICHE_BUG_IF(quiche_bug_10839_2, buffer_length > 0) + << "iov_offset + buffer_length larger than iovec total size."; + + return buffer; +} + +} // namespace quiche diff --git a/quiche/common/quiche_buffer_allocator.h b/quiche/common/quiche_buffer_allocator.h new file mode 100644 index 000000000000..0f3684299556 --- /dev/null +++ b/quiche/common/quiche_buffer_allocator.h @@ -0,0 +1,126 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_BUFFER_ALLOCATOR_H_ +#define QUICHE_COMMON_QUICHE_BUFFER_ALLOCATOR_H_ + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_iovec.h" + +namespace quiche { + +// Abstract base class for classes which allocate and delete buffers. +class QUICHE_EXPORT QuicheBufferAllocator { + public: + virtual ~QuicheBufferAllocator() = default; + + // Returns or allocates a new buffer of |size|. Never returns null. + virtual char* New(size_t size) = 0; + + // Returns or allocates a new buffer of |size| if |flag_enable| is true. + // Otherwise, returns a buffer that is compatible with this class directly + // with operator new. Never returns null. + virtual char* New(size_t size, bool flag_enable) = 0; + + // Releases a buffer. + virtual void Delete(char* buffer) = 0; + + // Marks the allocator as being idle. Serves as a hint to notify the allocator + // that it should release any resources it's still holding on to. + virtual void MarkAllocatorIdle() {} +}; + +// A deleter that can be used to manage ownership of buffers allocated via +// QuicheBufferAllocator through std::unique_ptr. +class QUICHE_EXPORT QuicheBufferDeleter { + public: + explicit QuicheBufferDeleter(QuicheBufferAllocator* allocator) + : allocator_(allocator) {} + + QuicheBufferAllocator* allocator() { return allocator_; } + void operator()(char* buffer) { allocator_->Delete(buffer); } + + private: + QuicheBufferAllocator* allocator_; +}; + +using QuicheUniqueBufferPtr = std::unique_ptr; + +inline QuicheUniqueBufferPtr MakeUniqueBuffer(QuicheBufferAllocator* allocator, + size_t size) { + return QuicheUniqueBufferPtr(allocator->New(size), + QuicheBufferDeleter(allocator)); +} + +// QuicheUniqueBufferPtr with a length attached to it. Similar to +// QuicheMemSlice, except unlike QuicheMemSlice, QuicheBuffer is mutable and is +// not platform-specific. Also unlike QuicheMemSlice, QuicheBuffer can be +// empty. +class QUICHE_EXPORT QuicheBuffer { + public: + QuicheBuffer() : buffer_(nullptr, QuicheBufferDeleter(nullptr)), size_(0) {} + QuicheBuffer(QuicheBufferAllocator* allocator, size_t size) + : buffer_(MakeUniqueBuffer(allocator, size)), size_(size) {} + + QuicheBuffer(QuicheUniqueBufferPtr buffer, size_t size) + : buffer_(std::move(buffer)), size_(size) {} + + // Make sure the move constructor zeroes out the size field. + QuicheBuffer(QuicheBuffer&& other) + : buffer_(std::move(other.buffer_)), size_(other.size_) { + other.buffer_ = nullptr; + other.size_ = 0; + } + QuicheBuffer& operator=(QuicheBuffer&& other) { + buffer_ = std::move(other.buffer_); + size_ = other.size_; + + other.buffer_ = nullptr; + other.size_ = 0; + return *this; + } + + // Factory method to create a QuicheBuffer that holds a copy of `data`. + static QuicheBuffer Copy(QuicheBufferAllocator* allocator, + absl::string_view data) { + QuicheBuffer buffer(allocator, data.size()); + memcpy(buffer.data(), data.data(), data.size()); + return buffer; + } + + // Factory method to create a QuicheBuffer of length `buffer_length` that + // holds a copy of `buffer_length` bytes from `iov` starting at offset + // `iov_offset`. `iov` must be at least `iov_offset + buffer_length` total + // length. + static QuicheBuffer CopyFromIovec(QuicheBufferAllocator* allocator, + const struct iovec* iov, int iov_count, + size_t iov_offset, size_t buffer_length); + + const char* data() const { return buffer_.get(); } + char* data() { return buffer_.get(); } + size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + absl::string_view AsStringView() const { + return absl::string_view(data(), size()); + } + + // Releases the ownership of the underlying buffer. + QuicheUniqueBufferPtr Release() { + size_ = 0; + return std::move(buffer_); + } + + private: + QuicheUniqueBufferPtr buffer_; + size_t size_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_BUFFER_ALLOCATOR_H_ diff --git a/quiche/common/quiche_buffer_allocator_test.cc b/quiche/common/quiche_buffer_allocator_test.cc new file mode 100644 index 000000000000..6110a9c865e1 --- /dev/null +++ b/quiche/common/quiche_buffer_allocator_test.cc @@ -0,0 +1,141 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_buffer_allocator.h" + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche { +namespace test { +namespace { + +TEST(QuicheBuffer, CopyFromEmpty) { + SimpleBufferAllocator allocator; + QuicheBuffer buffer = QuicheBuffer::Copy(&allocator, ""); + EXPECT_TRUE(buffer.empty()); +} + +TEST(QuicheBuffer, Copy) { + SimpleBufferAllocator allocator; + QuicheBuffer buffer = QuicheBuffer::Copy(&allocator, "foobar"); + EXPECT_EQ("foobar", buffer.AsStringView()); +} + +TEST(QuicheBuffer, CopyFromIovecZeroBytes) { + const int buffer_length = 0; + + SimpleBufferAllocator allocator; + QuicheBuffer buffer = QuicheBuffer::CopyFromIovec( + &allocator, nullptr, + /* iov_count = */ 0, /* iov_offset = */ 0, buffer_length); + EXPECT_TRUE(buffer.empty()); + + constexpr absl::string_view kData("foobar"); + iovec iov = MakeIOVector(kData); + + buffer = QuicheBuffer::CopyFromIovec(&allocator, &iov, + /* iov_count = */ 1, + /* iov_offset = */ 0, buffer_length); + EXPECT_TRUE(buffer.empty()); + + buffer = QuicheBuffer::CopyFromIovec(&allocator, &iov, + /* iov_count = */ 1, + /* iov_offset = */ 3, buffer_length); + EXPECT_TRUE(buffer.empty()); +} + +TEST(QuicheBuffer, CopyFromIovecSimple) { + constexpr absl::string_view kData("foobar"); + iovec iov = MakeIOVector(kData); + + SimpleBufferAllocator allocator; + QuicheBuffer buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov, + /* iov_count = */ 1, /* iov_offset = */ 0, + /* buffer_length = */ 6); + EXPECT_EQ("foobar", buffer.AsStringView()); + + buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov, + /* iov_count = */ 1, /* iov_offset = */ 0, + /* buffer_length = */ 3); + EXPECT_EQ("foo", buffer.AsStringView()); + + buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov, + /* iov_count = */ 1, /* iov_offset = */ 3, + /* buffer_length = */ 3); + EXPECT_EQ("bar", buffer.AsStringView()); + + buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov, + /* iov_count = */ 1, /* iov_offset = */ 1, + /* buffer_length = */ 4); + EXPECT_EQ("ooba", buffer.AsStringView()); +} + +TEST(QuicheBuffer, CopyFromIovecMultiple) { + constexpr absl::string_view kData1("foo"); + constexpr absl::string_view kData2("bar"); + iovec iov[] = {MakeIOVector(kData1), MakeIOVector(kData2)}; + + SimpleBufferAllocator allocator; + QuicheBuffer buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov[0], + /* iov_count = */ 2, /* iov_offset = */ 0, + /* buffer_length = */ 6); + EXPECT_EQ("foobar", buffer.AsStringView()); + + buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov[0], + /* iov_count = */ 2, /* iov_offset = */ 0, + /* buffer_length = */ 3); + EXPECT_EQ("foo", buffer.AsStringView()); + + buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov[0], + /* iov_count = */ 2, /* iov_offset = */ 3, + /* buffer_length = */ 3); + EXPECT_EQ("bar", buffer.AsStringView()); + + buffer = + QuicheBuffer::CopyFromIovec(&allocator, &iov[0], + /* iov_count = */ 2, /* iov_offset = */ 1, + /* buffer_length = */ 4); + EXPECT_EQ("ooba", buffer.AsStringView()); +} + +TEST(QuicheBuffer, CopyFromIovecOffsetTooLarge) { + constexpr absl::string_view kData1("foo"); + constexpr absl::string_view kData2("bar"); + iovec iov[] = {MakeIOVector(kData1), MakeIOVector(kData2)}; + + SimpleBufferAllocator allocator; + EXPECT_QUICHE_BUG( + QuicheBuffer::CopyFromIovec(&allocator, &iov[0], + /* iov_count = */ 2, /* iov_offset = */ 10, + /* buffer_length = */ 6), + "iov_offset larger than iovec total size"); +} + +TEST(QuicheBuffer, CopyFromIovecTooManyBytesRequested) { + constexpr absl::string_view kData1("foo"); + constexpr absl::string_view kData2("bar"); + iovec iov[] = {MakeIOVector(kData1), MakeIOVector(kData2)}; + + SimpleBufferAllocator allocator; + EXPECT_QUICHE_BUG( + QuicheBuffer::CopyFromIovec(&allocator, &iov[0], + /* iov_count = */ 2, /* iov_offset = */ 2, + /* buffer_length = */ 10), + R"(iov_offset \+ buffer_length larger than iovec total size)"); +} + +} // anonymous namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_circular_deque.h b/quiche/common/quiche_circular_deque.h new file mode 100644 index 000000000000..5c8fc1f1f688 --- /dev/null +++ b/quiche/common/quiche_circular_deque.h @@ -0,0 +1,754 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_CIRCULAR_DEQUE_H_ +#define QUICHE_COMMON_QUICHE_CIRCULAR_DEQUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +// QuicheCircularDeque is a STL-style container that is similar to std::deque in +// API and std::vector in capacity management. The goal is to optimize a common +// QUIC use case where we keep adding new elements to the end and removing old +// elements from the beginning, under such scenarios, if the container's size() +// remain relatively stable, QuicheCircularDeque requires little to no memory +// allocations or deallocations. +// +// The implementation, as the name suggests, uses a flat circular buffer to hold +// all elements. At any point in time, either +// a) All elements are placed in a contiguous portion of this buffer, like a +// c-array, or +// b) Elements are phycially divided into two parts: the first part occupies the +// end of the buffer and the second part occupies the beginning of the +// buffer. +// +// Currently, elements can only be pushed or poped from either ends, it can't be +// inserted or erased in the middle. +// +// TODO(wub): Make memory grow/shrink strategies customizable. +template > +class QUICHE_NO_EXPORT QuicheCircularDeque { + using AllocatorTraits = std::allocator_traits; + + // Pointee is either T or const T. + template + class QUICHE_NO_EXPORT basic_iterator { + using size_type = typename AllocatorTraits::size_type; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = typename AllocatorTraits::value_type; + using difference_type = typename AllocatorTraits::difference_type; + using pointer = Pointee*; + using reference = Pointee&; + + basic_iterator() = default; + + // A copy constructor if Pointee is T. + // A conversion from iterator to const_iterator if Pointee is const T. + basic_iterator( + const basic_iterator& it) // NOLINT(runtime/explicit) + : deque_(it.deque_), index_(it.index_) {} + + // A copy assignment if Pointee is T. + // A assignment from iterator to const_iterator if Pointee is const T. + basic_iterator& operator=(const basic_iterator& it) { + if (this != &it) { + deque_ = it.deque_; + index_ = it.index_; + } + return *this; + } + + reference operator*() const { return *deque_->index_to_address(index_); } + pointer operator->() const { return deque_->index_to_address(index_); } + reference operator[](difference_type i) { return *(*this + i); } + + basic_iterator& operator++() { + Increment(); + return *this; + } + + basic_iterator operator++(int) { + basic_iterator result = *this; + Increment(); + return result; + } + + basic_iterator operator--() { + Decrement(); + return *this; + } + + basic_iterator operator--(int) { + basic_iterator result = *this; + Decrement(); + return result; + } + + friend basic_iterator operator+(const basic_iterator& it, + difference_type delta) { + basic_iterator result = it; + result.IncrementBy(delta); + return result; + } + + basic_iterator& operator+=(difference_type delta) { + IncrementBy(delta); + return *this; + } + + friend basic_iterator operator-(const basic_iterator& it, + difference_type delta) { + basic_iterator result = it; + result.IncrementBy(-delta); + return result; + } + + basic_iterator& operator-=(difference_type delta) { + IncrementBy(-delta); + return *this; + } + + friend difference_type operator-(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.ExternalPosition() - rhs.ExternalPosition(); + } + + friend bool operator==(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.index_ == rhs.index_; + } + + friend bool operator!=(const basic_iterator& lhs, + const basic_iterator& rhs) { + return !(lhs == rhs); + } + + friend bool operator<(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.ExternalPosition() < rhs.ExternalPosition(); + } + + friend bool operator<=(const basic_iterator& lhs, + const basic_iterator& rhs) { + return !(lhs > rhs); + } + + friend bool operator>(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.ExternalPosition() > rhs.ExternalPosition(); + } + + friend bool operator>=(const basic_iterator& lhs, + const basic_iterator& rhs) { + return !(lhs < rhs); + } + + private: + basic_iterator(const QuicheCircularDeque* deque, size_type index) + : deque_(deque), index_(index) {} + + void Increment() { + QUICHE_DCHECK_LE(ExternalPosition() + 1, deque_->size()); + index_ = deque_->index_next(index_); + } + + void Decrement() { + QUICHE_DCHECK_GE(ExternalPosition(), 1u); + index_ = deque_->index_prev(index_); + } + + void IncrementBy(difference_type delta) { + if (delta >= 0) { + // After increment we are before or at end(). + QUICHE_DCHECK_LE(static_cast(ExternalPosition() + delta), + deque_->size()); + } else { + // After decrement we are after or at begin(). + QUICHE_DCHECK_GE(ExternalPosition(), static_cast(-delta)); + } + index_ = deque_->index_increment_by(index_, delta); + } + + size_type ExternalPosition() const { + if (index_ >= deque_->begin_) { + return index_ - deque_->begin_; + } + return index_ + deque_->data_capacity() - deque_->begin_; + } + + friend class QuicheCircularDeque; + const QuicheCircularDeque* deque_ = nullptr; + size_type index_ = 0; + }; + + public: + using allocator_type = typename AllocatorTraits::allocator_type; + using value_type = typename AllocatorTraits::value_type; + using size_type = typename AllocatorTraits::size_type; + using difference_type = typename AllocatorTraits::difference_type; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = typename AllocatorTraits::pointer; + using const_pointer = typename AllocatorTraits::const_pointer; + using iterator = basic_iterator; + using const_iterator = basic_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + QuicheCircularDeque() : QuicheCircularDeque(allocator_type()) {} + explicit QuicheCircularDeque(const allocator_type& alloc) + : allocator_and_data_(alloc) {} + + QuicheCircularDeque(size_type count, const T& value, + const Allocator& alloc = allocator_type()) + : allocator_and_data_(alloc) { + resize(count, value); + } + + explicit QuicheCircularDeque(size_type count, + const Allocator& alloc = allocator_type()) + : allocator_and_data_(alloc) { + resize(count); + } + + template < + class InputIt, + typename = std::enable_if_t::iterator_category>::value>> + QuicheCircularDeque(InputIt first, InputIt last, + const Allocator& alloc = allocator_type()) + : allocator_and_data_(alloc) { + AssignRange(first, last); + } + + QuicheCircularDeque(const QuicheCircularDeque& other) + : QuicheCircularDeque( + other, AllocatorTraits::select_on_container_copy_construction( + other.allocator_and_data_.allocator())) {} + + QuicheCircularDeque(const QuicheCircularDeque& other, + const allocator_type& alloc) + : allocator_and_data_(alloc) { + assign(other.begin(), other.end()); + } + + QuicheCircularDeque(QuicheCircularDeque&& other) + : begin_(other.begin_), + end_(other.end_), + allocator_and_data_(std::move(other.allocator_and_data_)) { + other.begin_ = other.end_ = 0; + other.allocator_and_data_.data = nullptr; + other.allocator_and_data_.data_capacity = 0; + } + + QuicheCircularDeque(QuicheCircularDeque&& other, const allocator_type& alloc) + : allocator_and_data_(alloc) { + MoveRetainAllocator(std::move(other)); + } + + QuicheCircularDeque(std::initializer_list init, + const allocator_type& alloc = allocator_type()) + : QuicheCircularDeque(init.begin(), init.end(), alloc) {} + + QuicheCircularDeque& operator=(const QuicheCircularDeque& other) { + if (this == &other) { + return *this; + } + if (AllocatorTraits::propagate_on_container_copy_assignment::value && + (allocator_and_data_.allocator() != + other.allocator_and_data_.allocator())) { + // Destroy all current elements and blocks with the current allocator, + // before switching this to use the allocator propagated from "other". + DestroyAndDeallocateAll(); + begin_ = end_ = 0; + allocator_and_data_ = + AllocatorAndData(other.allocator_and_data_.allocator()); + } + assign(other.begin(), other.end()); + return *this; + } + + QuicheCircularDeque& operator=(QuicheCircularDeque&& other) { + if (this == &other) { + return *this; + } + if (AllocatorTraits::propagate_on_container_move_assignment::value) { + // Take over the storage of "other", along with its allocator. + this->~QuicheCircularDeque(); + new (this) QuicheCircularDeque(std::move(other)); + } else { + MoveRetainAllocator(std::move(other)); + } + return *this; + } + + ~QuicheCircularDeque() { DestroyAndDeallocateAll(); } + + void assign(size_type count, const T& value) { + ClearRetainCapacity(); + reserve(count); + for (size_t i = 0; i < count; ++i) { + emplace_back(value); + } + } + + template < + class InputIt, + typename = std::enable_if_t::iterator_category>::value>> + void assign(InputIt first, InputIt last) { + AssignRange(first, last); + } + + void assign(std::initializer_list ilist) { + assign(ilist.begin(), ilist.end()); + } + + reference at(size_type pos) { + QUICHE_DCHECK(pos < size()) << "pos:" << pos << ", size():" << size(); + size_type index = begin_ + pos; + if (index < data_capacity()) { + return *index_to_address(index); + } + return *index_to_address(index - data_capacity()); + } + + const_reference at(size_type pos) const { + return const_cast(this)->at(pos); + } + + reference operator[](size_type pos) { return at(pos); } + + const_reference operator[](size_type pos) const { return at(pos); } + + reference front() { + QUICHE_DCHECK(!empty()); + return *index_to_address(begin_); + } + + const_reference front() const { + return const_cast(this)->front(); + } + + reference back() { + QUICHE_DCHECK(!empty()); + return *(index_to_address(end_ == 0 ? data_capacity() - 1 : end_ - 1)); + } + + const_reference back() const { + return const_cast(this)->back(); + } + + iterator begin() { return iterator(this, begin_); } + const_iterator begin() const { return const_iterator(this, begin_); } + const_iterator cbegin() const { return const_iterator(this, begin_); } + + iterator end() { return iterator(this, end_); } + const_iterator end() const { return const_iterator(this, end_); } + const_iterator cend() const { return const_iterator(this, end_); } + + reverse_iterator rbegin() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + const_reverse_iterator crbegin() const { return rbegin(); } + + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + const_reverse_iterator crend() const { return rend(); } + + size_type capacity() const { + return data_capacity() == 0 ? 0 : data_capacity() - 1; + } + + void reserve(size_type new_cap) { + if (new_cap > capacity()) { + Relocate(new_cap); + } + } + + // Remove all elements. Leave capacity unchanged. + void clear() { ClearRetainCapacity(); } + + bool empty() const { return begin_ == end_; } + + size_type size() const { + if (begin_ <= end_) { + return end_ - begin_; + } + return data_capacity() + end_ - begin_; + } + + void resize(size_type count) { ResizeInternal(count); } + + void resize(size_type count, const value_type& value) { + ResizeInternal(count, value); + } + + void push_front(const T& value) { emplace_front(value); } + void push_front(T&& value) { emplace_front(std::move(value)); } + + template + reference emplace_front(Args&&... args) { + MaybeExpandCapacity(1); + begin_ = index_prev(begin_); + new (index_to_address(begin_)) T(std::forward(args)...); + return front(); + } + + void push_back(const T& value) { emplace_back(value); } + void push_back(T&& value) { emplace_back(std::move(value)); } + + template + reference emplace_back(Args&&... args) { + MaybeExpandCapacity(1); + new (index_to_address(end_)) T(std::forward(args)...); + end_ = index_next(end_); + return back(); + } + + void pop_front() { + QUICHE_DCHECK(!empty()); + DestroyByIndex(begin_); + begin_ = index_next(begin_); + MaybeShrinkCapacity(); + } + + size_type pop_front_n(size_type count) { + size_type num_elements_to_pop = std::min(count, size()); + size_type new_begin = index_increment_by(begin_, num_elements_to_pop); + DestroyRange(begin_, new_begin); + begin_ = new_begin; + MaybeShrinkCapacity(); + return num_elements_to_pop; + } + + void pop_back() { + QUICHE_DCHECK(!empty()); + end_ = index_prev(end_); + DestroyByIndex(end_); + MaybeShrinkCapacity(); + } + + size_type pop_back_n(size_type count) { + size_type num_elements_to_pop = std::min(count, size()); + size_type new_end = index_increment_by(end_, -num_elements_to_pop); + DestroyRange(new_end, end_); + end_ = new_end; + MaybeShrinkCapacity(); + return num_elements_to_pop; + } + + void swap(QuicheCircularDeque& other) { + using std::swap; + swap(begin_, other.begin_); + swap(end_, other.end_); + + if (AllocatorTraits::propagate_on_container_swap::value) { + swap(allocator_and_data_, other.allocator_and_data_); + } else { + // When propagate_on_container_swap is false, it is undefined behavior, by + // c++ standard, to swap between two AllocatorAwareContainer(s) with + // unequal allocators. + QUICHE_DCHECK(get_allocator() == other.get_allocator()) + << "Undefined swap behavior"; + swap(allocator_and_data_.data, other.allocator_and_data_.data); + swap(allocator_and_data_.data_capacity, + other.allocator_and_data_.data_capacity); + } + } + + friend void swap(QuicheCircularDeque& lhs, QuicheCircularDeque& rhs) { + lhs.swap(rhs); + } + + allocator_type get_allocator() const { + return allocator_and_data_.allocator(); + } + + friend bool operator==(const QuicheCircularDeque& lhs, + const QuicheCircularDeque& rhs) { + return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); + } + + friend bool operator!=(const QuicheCircularDeque& lhs, + const QuicheCircularDeque& rhs) { + return !(lhs == rhs); + } + + friend QUICHE_NO_EXPORT std::ostream& operator<<( + std::ostream& os, const QuicheCircularDeque& dq) { + os << "{"; + for (size_type pos = 0; pos != dq.size(); ++pos) { + if (pos != 0) { + os << ","; + } + os << " " << dq[pos]; + } + os << " }"; + return os; + } + + private: + void MoveRetainAllocator(QuicheCircularDeque&& other) { + if (get_allocator() == other.get_allocator()) { + // Take over the storage of "other", with which we share an allocator. + DestroyAndDeallocateAll(); + + begin_ = other.begin_; + end_ = other.end_; + allocator_and_data_.data = other.allocator_and_data_.data; + allocator_and_data_.data_capacity = + other.allocator_and_data_.data_capacity; + + other.begin_ = other.end_ = 0; + other.allocator_and_data_.data = nullptr; + other.allocator_and_data_.data_capacity = 0; + } else { + // We cannot take over of the storage from "other", since it has a + // different allocator; we're stuck move-assigning elements individually. + ClearRetainCapacity(); + for (auto& elem : other) { + push_back(std::move(elem)); + } + other.clear(); + } + } + + template < + typename InputIt, + typename = std::enable_if_t::iterator_category>::value>> + void AssignRange(InputIt first, InputIt last) { + ClearRetainCapacity(); + if (std::is_base_of< + std::random_access_iterator_tag, + typename std::iterator_traits::iterator_category>::value) { + reserve(std::distance(first, last)); + } + for (; first != last; ++first) { + emplace_back(*first); + } + } + + // WARNING: begin_, end_ and allocator_and_data_ are not modified. + void DestroyAndDeallocateAll() { + DestroyRange(begin_, end_); + + if (data_capacity() > 0) { + QUICHE_DCHECK_NE(nullptr, allocator_and_data_.data); + AllocatorTraits::deallocate(allocator_and_data_.allocator(), + allocator_and_data_.data, data_capacity()); + } + } + + void ClearRetainCapacity() { + DestroyRange(begin_, end_); + begin_ = end_ = 0; + } + + void MaybeShrinkCapacity() { + // TODO(wub): Implement a storage policy that actually shrinks. + } + + void MaybeExpandCapacity(size_t num_additional_elements) { + size_t new_size = size() + num_additional_elements; + if (capacity() >= new_size) { + return; + } + + // The minimum amount of additional capacity to grow. + size_t min_additional_capacity = + std::max(MinCapacityIncrement, capacity() / 4); + size_t new_capacity = + std::max(new_size, capacity() + min_additional_capacity); + + Relocate(new_capacity); + } + + void Relocate(size_t new_capacity) { + const size_t num_elements = size(); + QUICHE_DCHECK_GT(new_capacity, num_elements) + << "new_capacity:" << new_capacity << ", num_elements:" << num_elements; + + size_t new_data_capacity = new_capacity + 1; + pointer new_data = AllocatorTraits::allocate( + allocator_and_data_.allocator(), new_data_capacity); + + if (begin_ < end_) { + // Not wrapped. + RelocateUnwrappedRange(begin_, end_, new_data); + } else if (begin_ > end_) { + // Wrapped. + const size_t num_elements_before_wrap = data_capacity() - begin_; + RelocateUnwrappedRange(begin_, data_capacity(), new_data); + RelocateUnwrappedRange(0, end_, new_data + num_elements_before_wrap); + } + + if (data_capacity()) { + AllocatorTraits::deallocate(allocator_and_data_.allocator(), + allocator_and_data_.data, data_capacity()); + } + + allocator_and_data_.data = new_data; + allocator_and_data_.data_capacity = new_data_capacity; + begin_ = 0; + end_ = num_elements; + } + + template + typename std::enable_if::value, void>::type + RelocateUnwrappedRange(size_type begin, size_type end, pointer dest) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + pointer src = index_to_address(begin); + QUICHE_DCHECK_NE(src, nullptr); + memcpy(dest, src, sizeof(T) * (end - begin)); + DestroyRange(begin, end); + } + + template + typename std::enable_if::value && + std::is_move_constructible::value, + void>::type + RelocateUnwrappedRange(size_type begin, size_type end, pointer dest) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + pointer src = index_to_address(begin); + pointer src_end = index_to_address(end); + while (src != src_end) { + new (dest) T(std::move(*src)); + DestroyByAddress(src); + ++dest; + ++src; + } + } + + template + typename std::enable_if::value && + !std::is_move_constructible::value, + void>::type + RelocateUnwrappedRange(size_type begin, size_type end, pointer dest) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + pointer src = index_to_address(begin); + pointer src_end = index_to_address(end); + while (src != src_end) { + new (dest) T(*src); + DestroyByAddress(src); + ++dest; + ++src; + } + } + + template + void ResizeInternal(size_type count, U&&... u) { + if (count > size()) { + // Expanding. + MaybeExpandCapacity(count - size()); + while (size() < count) { + emplace_back(std::forward(u)...); + } + } else { + // Most likely shrinking. No-op if count == size(). + size_type new_end = (begin_ + count) % data_capacity(); + DestroyRange(new_end, end_); + end_ = new_end; + + MaybeShrinkCapacity(); + } + } + + void DestroyRange(size_type begin, size_type end) const { + if (std::is_trivially_destructible::value) { + return; + } + if (end >= begin) { + DestroyUnwrappedRange(begin, end); + } else { + DestroyUnwrappedRange(begin, data_capacity()); + DestroyUnwrappedRange(0, end); + } + } + + // Should only be called from DestroyRange. + void DestroyUnwrappedRange(size_type begin, size_type end) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + for (; begin != end; ++begin) { + DestroyByIndex(begin); + } + } + + void DestroyByIndex(size_type index) const { + DestroyByAddress(index_to_address(index)); + } + + void DestroyByAddress(pointer address) const { + if (std::is_trivially_destructible::value) { + return; + } + address->~T(); + } + + size_type data_capacity() const { return allocator_and_data_.data_capacity; } + + pointer index_to_address(size_type index) const { + return allocator_and_data_.data + index; + } + + size_type index_prev(size_type index) const { + return index == 0 ? data_capacity() - 1 : index - 1; + } + + size_type index_next(size_type index) const { + return index == data_capacity() - 1 ? 0 : index + 1; + } + + size_type index_increment_by(size_type index, difference_type delta) const { + if (delta == 0) { + return index; + } + + QUICHE_DCHECK_LT(static_cast(std::abs(delta)), data_capacity()); + return (index + data_capacity() + delta) % data_capacity(); + } + + // Empty base-class optimization: bundle storage for our allocator together + // with the fields we had to store anyway, via inheriting from the allocator, + // so this allocator instance doesn't consume any storage when its type has no + // data members. + struct QUICHE_NO_EXPORT AllocatorAndData : private allocator_type { + explicit AllocatorAndData(const allocator_type& alloc) + : allocator_type(alloc) {} + + const allocator_type& allocator() const { return *this; } + allocator_type& allocator() { return *this; } + + pointer data = nullptr; + size_type data_capacity = 0; + }; + + size_type begin_ = 0; + size_type end_ = 0; + AllocatorAndData allocator_and_data_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_CIRCULAR_DEQUE_H_ diff --git a/quiche/common/quiche_circular_deque_test.cc b/quiche/common/quiche_circular_deque_test.cc new file mode 100644 index 000000000000..a239cd5d13e9 --- /dev/null +++ b/quiche/common/quiche_circular_deque_test.cc @@ -0,0 +1,799 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_circular_deque.h" + +#include +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using testing::ElementsAre; + +namespace quiche { +namespace test { +namespace { + +template class BaseAllocator = std::allocator> +class CountingAllocator : public BaseAllocator { + using BaseType = BaseAllocator; + + public: + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + T* allocate(std::size_t n) { + ++shared_counts_->allocate_count; + return BaseType::allocate(n); + } + + void deallocate(T* ptr, std::size_t n) { + ++shared_counts_->deallocate_count; + return BaseType::deallocate(ptr, n); + } + + size_t allocate_count() const { return shared_counts_->allocate_count; } + + size_t deallocate_count() const { return shared_counts_->deallocate_count; } + + friend bool operator==(const CountingAllocator& lhs, + const CountingAllocator& rhs) { + return lhs.shared_counts_ == rhs.shared_counts_; + } + + friend bool operator!=(const CountingAllocator& lhs, + const CountingAllocator& rhs) { + return !(lhs == rhs); + } + + private: + struct Counts { + size_t allocate_count = 0; + size_t deallocate_count = 0; + }; + + std::shared_ptr shared_counts_ = std::make_shared(); +}; + +template class BaseAllocator = std::allocator> +struct ConfigurableAllocator : public BaseAllocator { + using propagate_on_container_copy_assignment = propagate_on_copy_assignment; + using propagate_on_container_move_assignment = propagate_on_move_assignment; + using propagate_on_container_swap = propagate_on_swap; + + friend bool operator==(const ConfigurableAllocator& /*lhs*/, + const ConfigurableAllocator& /*rhs*/) { + return equality_result; + } + + friend bool operator!=(const ConfigurableAllocator& lhs, + const ConfigurableAllocator& rhs) { + return !(lhs == rhs); + } +}; + +// [1, 2, 3, 4] ==> [4, 1, 2, 3] +template +void ShiftRight(Deque* dq, bool emplace) { + auto back = *(&dq->back()); + dq->pop_back(); + if (emplace) { + dq->emplace_front(back); + } else { + dq->push_front(back); + } +} + +// [1, 2, 3, 4] ==> [2, 3, 4, 1] +template +void ShiftLeft(Deque* dq, bool emplace) { + auto front = *(&dq->front()); + dq->pop_front(); + if (emplace) { + dq->emplace_back(front); + } else { + dq->push_back(front); + } +} + +class QuicheCircularDequeTest : public QuicheTest {}; + +TEST_F(QuicheCircularDequeTest, Empty) { + QuicheCircularDeque dq; + EXPECT_TRUE(dq.empty()); + EXPECT_EQ(0u, dq.size()); + dq.clear(); + dq.push_back(10); + EXPECT_FALSE(dq.empty()); + EXPECT_EQ(1u, dq.size()); + EXPECT_EQ(10, dq.front()); + EXPECT_EQ(10, dq.back()); + dq.pop_front(); + EXPECT_TRUE(dq.empty()); + EXPECT_EQ(0u, dq.size()); + + EXPECT_QUICHE_DEBUG_DEATH(dq.front(), ""); + EXPECT_QUICHE_DEBUG_DEATH(dq.back(), ""); + EXPECT_QUICHE_DEBUG_DEATH(dq.at(0), ""); + EXPECT_QUICHE_DEBUG_DEATH(dq[0], ""); +} + +TEST_F(QuicheCircularDequeTest, Constructor) { + QuicheCircularDeque dq; + EXPECT_TRUE(dq.empty()); + + std::allocator alloc; + QuicheCircularDeque dq1(alloc); + EXPECT_TRUE(dq1.empty()); + + QuicheCircularDeque dq2(8, 100, alloc); + EXPECT_THAT(dq2, ElementsAre(100, 100, 100, 100, 100, 100, 100, 100)); + + QuicheCircularDeque dq3(5, alloc); + EXPECT_THAT(dq3, ElementsAre(0, 0, 0, 0, 0)); + + QuicheCircularDeque dq4_rand_iter(dq3.begin(), dq3.end(), alloc); + EXPECT_THAT(dq4_rand_iter, ElementsAre(0, 0, 0, 0, 0)); + EXPECT_EQ(dq4_rand_iter, dq3); + + std::list dq4_src = {4, 4, 4, 4}; + QuicheCircularDeque dq4_bidi_iter(dq4_src.begin(), dq4_src.end()); + EXPECT_THAT(dq4_bidi_iter, ElementsAre(4, 4, 4, 4)); + + QuicheCircularDeque dq5(dq4_bidi_iter); + EXPECT_THAT(dq5, ElementsAre(4, 4, 4, 4)); + EXPECT_EQ(dq5, dq4_bidi_iter); + + QuicheCircularDeque dq6(dq5, alloc); + EXPECT_THAT(dq6, ElementsAre(4, 4, 4, 4)); + EXPECT_EQ(dq6, dq5); + + QuicheCircularDeque dq7(std::move(*&dq6)); + EXPECT_THAT(dq7, ElementsAre(4, 4, 4, 4)); + EXPECT_TRUE(dq6.empty()); + + QuicheCircularDeque dq8_equal_allocator(std::move(*&dq7), alloc); + EXPECT_THAT(dq8_equal_allocator, ElementsAre(4, 4, 4, 4)); + EXPECT_TRUE(dq7.empty()); + + QuicheCircularDeque> dq8_temp = {5, 6, 7, 8, + 9}; + QuicheCircularDeque> dq8_unequal_allocator( + std::move(*&dq8_temp), CountingAllocator()); + EXPECT_THAT(dq8_unequal_allocator, ElementsAre(5, 6, 7, 8, 9)); + EXPECT_TRUE(dq8_temp.empty()); + + QuicheCircularDeque dq9({3, 4, 5, 6, 7}, alloc); + EXPECT_THAT(dq9, ElementsAre(3, 4, 5, 6, 7)); +} + +TEST_F(QuicheCircularDequeTest, Assign) { + // assign() + QuicheCircularDeque> dq; + dq.assign(7, 1); + EXPECT_THAT(dq, ElementsAre(1, 1, 1, 1, 1, 1, 1)); + EXPECT_EQ(1u, dq.get_allocator().allocate_count()); + + QuicheCircularDeque> dq2; + dq2.assign(dq.begin(), dq.end()); + EXPECT_THAT(dq2, ElementsAre(1, 1, 1, 1, 1, 1, 1)); + EXPECT_EQ(1u, dq2.get_allocator().allocate_count()); + EXPECT_TRUE(std::equal(dq.begin(), dq.end(), dq2.begin(), dq2.end())); + + dq2.assign({2, 2, 2, 2, 2, 2}); + EXPECT_THAT(dq2, ElementsAre(2, 2, 2, 2, 2, 2)); + + // Assign from a non random access iterator. + std::list dq3_src = {3, 3, 3, 3, 3}; + QuicheCircularDeque> dq3; + dq3.assign(dq3_src.begin(), dq3_src.end()); + EXPECT_THAT(dq3, ElementsAre(3, 3, 3, 3, 3)); + EXPECT_LT(1u, dq3.get_allocator().allocate_count()); + + // Copy assignment + dq3 = *&dq3; + EXPECT_THAT(dq3, ElementsAre(3, 3, 3, 3, 3)); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq4, dq5; + dq4.assign(dq3.begin(), dq3.end()); + dq5 = dq4; + EXPECT_THAT(dq5, ElementsAre(3, 3, 3, 3, 3)); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq6, dq7; + dq6.assign(dq3.begin(), dq3.end()); + dq7 = dq6; + EXPECT_THAT(dq7, ElementsAre(3, 3, 3, 3, 3)); + + // Move assignment + dq3 = std::move(*&dq3); + EXPECT_THAT(dq3, ElementsAre(3, 3, 3, 3, 3)); + + ASSERT_TRUE(decltype(dq3.get_allocator()):: + propagate_on_container_move_assignment::value); + decltype(dq3) dq8; + dq8 = std::move(*&dq3); + EXPECT_THAT(dq8, ElementsAre(3, 3, 3, 3, 3)); + EXPECT_TRUE(dq3.empty()); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq9, dq10; + dq9.assign(dq8.begin(), dq8.end()); + dq10.assign(dq2.begin(), dq2.end()); + dq9 = std::move(*&dq10); + EXPECT_THAT(dq9, ElementsAre(2, 2, 2, 2, 2, 2)); + EXPECT_TRUE(dq10.empty()); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq11, dq12; + dq11.assign(dq8.begin(), dq8.end()); + dq12.assign(dq2.begin(), dq2.end()); + dq11 = std::move(*&dq12); + EXPECT_THAT(dq11, ElementsAre(2, 2, 2, 2, 2, 2)); + EXPECT_TRUE(dq12.empty()); +} + +TEST_F(QuicheCircularDequeTest, Access) { + // at() + // operator[] + // front() + // back() + + QuicheCircularDeque> dq; + dq.push_back(10); + EXPECT_EQ(dq.front(), 10); + EXPECT_EQ(dq.back(), 10); + EXPECT_EQ(dq.at(0), 10); + EXPECT_EQ(dq[0], 10); + dq.front() = 12; + EXPECT_EQ(dq.front(), 12); + EXPECT_EQ(dq.back(), 12); + EXPECT_EQ(dq.at(0), 12); + EXPECT_EQ(dq[0], 12); + + const auto& dqref = dq; + EXPECT_EQ(dqref.front(), 12); + EXPECT_EQ(dqref.back(), 12); + EXPECT_EQ(dqref.at(0), 12); + EXPECT_EQ(dqref[0], 12); + + dq.pop_front(); + EXPECT_TRUE(dqref.empty()); + + // Push to capacity. + dq.push_back(15); + dq.push_front(5); + dq.push_back(25); + EXPECT_EQ(dq.size(), dq.capacity()); + EXPECT_THAT(dq, ElementsAre(5, 15, 25)); + EXPECT_LT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 5); + EXPECT_EQ(dq.back(), 25); + EXPECT_EQ(dq.at(0), 5); + EXPECT_EQ(dq.at(1), 15); + EXPECT_EQ(dq.at(2), 25); + EXPECT_EQ(dq[0], 5); + EXPECT_EQ(dq[1], 15); + EXPECT_EQ(dq[2], 25); + + // Shift right such that begin=1 and end=0. Data is still not wrapped. + dq.pop_front(); + dq.push_back(35); + EXPECT_THAT(dq, ElementsAre(15, 25, 35)); + EXPECT_LT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 15); + EXPECT_EQ(dq.back(), 35); + EXPECT_EQ(dq.at(0), 15); + EXPECT_EQ(dq.at(1), 25); + EXPECT_EQ(dq.at(2), 35); + EXPECT_EQ(dq[0], 15); + EXPECT_EQ(dq[1], 25); + EXPECT_EQ(dq[2], 35); + + // Shift right such that data is wrapped. + dq.pop_front(); + dq.push_back(45); + EXPECT_THAT(dq, ElementsAre(25, 35, 45)); + EXPECT_GT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 25); + EXPECT_EQ(dq.back(), 45); + EXPECT_EQ(dq.at(0), 25); + EXPECT_EQ(dq.at(1), 35); + EXPECT_EQ(dq.at(2), 45); + EXPECT_EQ(dq[0], 25); + EXPECT_EQ(dq[1], 35); + EXPECT_EQ(dq[2], 45); + + // Shift right again, data is still wrapped. + dq.pop_front(); + dq.push_back(55); + EXPECT_THAT(dq, ElementsAre(35, 45, 55)); + EXPECT_GT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 35); + EXPECT_EQ(dq.back(), 55); + EXPECT_EQ(dq.at(0), 35); + EXPECT_EQ(dq.at(1), 45); + EXPECT_EQ(dq.at(2), 55); + EXPECT_EQ(dq[0], 35); + EXPECT_EQ(dq[1], 45); + EXPECT_EQ(dq[2], 55); + + // Shift right one last time. begin returns to 0. Data is no longer wrapped. + dq.pop_front(); + dq.push_back(65); + EXPECT_THAT(dq, ElementsAre(45, 55, 65)); + EXPECT_LT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 45); + EXPECT_EQ(dq.back(), 65); + EXPECT_EQ(dq.at(0), 45); + EXPECT_EQ(dq.at(1), 55); + EXPECT_EQ(dq.at(2), 65); + EXPECT_EQ(dq[0], 45); + EXPECT_EQ(dq[1], 55); + EXPECT_EQ(dq[2], 65); + + EXPECT_EQ(1u, dq.get_allocator().allocate_count()); +} + +TEST_F(QuicheCircularDequeTest, Iterate) { + QuicheCircularDeque dq; + EXPECT_EQ(dq.begin(), dq.end()); + EXPECT_EQ(dq.cbegin(), dq.cend()); + EXPECT_EQ(dq.rbegin(), dq.rend()); + EXPECT_EQ(dq.crbegin(), dq.crend()); + + dq.emplace_back(2); + QuicheCircularDeque::const_iterator citer = dq.begin(); + EXPECT_NE(citer, dq.end()); + EXPECT_EQ(*citer, 2); + ++citer; + EXPECT_EQ(citer, dq.end()); + + EXPECT_EQ(*dq.begin(), 2); + EXPECT_EQ(*dq.cbegin(), 2); + EXPECT_EQ(*dq.rbegin(), 2); + EXPECT_EQ(*dq.crbegin(), 2); + + dq.emplace_front(1); + QuicheCircularDeque::const_reverse_iterator criter = dq.rbegin(); + EXPECT_NE(criter, dq.rend()); + EXPECT_EQ(*criter, 2); + ++criter; + EXPECT_NE(criter, dq.rend()); + EXPECT_EQ(*criter, 1); + ++criter; + EXPECT_EQ(criter, dq.rend()); + + EXPECT_EQ(*dq.begin(), 1); + EXPECT_EQ(*dq.cbegin(), 1); + EXPECT_EQ(*dq.rbegin(), 2); + EXPECT_EQ(*dq.crbegin(), 2); + + dq.push_back(3); + + // Forward iterate. + int expected_value = 1; + for (QuicheCircularDeque::iterator it = dq.begin(); it != dq.end(); + ++it) { + EXPECT_EQ(expected_value++, *it); + } + + expected_value = 1; + for (QuicheCircularDeque::const_iterator it = dq.cbegin(); + it != dq.cend(); ++it) { + EXPECT_EQ(expected_value++, *it); + } + + // Reverse iterate. + expected_value = 3; + for (QuicheCircularDeque::reverse_iterator it = dq.rbegin(); + it != dq.rend(); ++it) { + EXPECT_EQ(expected_value--, *it); + } + + expected_value = 3; + for (QuicheCircularDeque::const_reverse_iterator it = dq.crbegin(); + it != dq.crend(); ++it) { + EXPECT_EQ(expected_value--, *it); + } +} + +TEST_F(QuicheCircularDequeTest, Iterator) { + // Default constructed iterators of the same type compare equal. + EXPECT_EQ(QuicheCircularDeque::iterator(), + QuicheCircularDeque::iterator()); + EXPECT_EQ(QuicheCircularDeque::const_iterator(), + QuicheCircularDeque::const_iterator()); + EXPECT_EQ(QuicheCircularDeque::reverse_iterator(), + QuicheCircularDeque::reverse_iterator()); + EXPECT_EQ(QuicheCircularDeque::const_reverse_iterator(), + QuicheCircularDeque::const_reverse_iterator()); + + QuicheCircularDeque, 3> dqdq = { + {1, 2}, {10, 20, 30}, {100, 200, 300, 400}}; + + // iter points to {1, 2} + decltype(dqdq)::iterator iter = dqdq.begin(); + EXPECT_EQ(iter->size(), 2u); + EXPECT_THAT(*iter, ElementsAre(1, 2)); + + // citer points to {10, 20, 30} + decltype(dqdq)::const_iterator citer = dqdq.cbegin() + 1; + EXPECT_NE(*iter, *citer); + EXPECT_EQ(citer->size(), 3u); + int x = 10; + for (auto it = citer->begin(); it != citer->end(); ++it) { + EXPECT_EQ(*it, x); + x += 10; + } + + EXPECT_LT(iter, citer); + EXPECT_LE(iter, iter); + EXPECT_GT(citer, iter); + EXPECT_GE(citer, citer); + + // iter points to {100, 200, 300, 400} + iter += 2; + EXPECT_NE(*iter, *citer); + EXPECT_EQ(iter->size(), 4u); + for (int i = 1; i <= 4; ++i) { + EXPECT_EQ(iter->begin()[i - 1], i * 100); + } + + EXPECT_LT(citer, iter); + EXPECT_LE(iter, iter); + EXPECT_GT(iter, citer); + EXPECT_GE(citer, citer); + + // iter points to {10, 20, 30}. (same as citer) + iter -= 1; + EXPECT_EQ(*iter, *citer); + EXPECT_EQ(iter->size(), 3u); + x = 10; + for (auto it = iter->begin(); it != iter->end();) { + EXPECT_EQ(*(it++), x); + x += 10; + } + x = 30; + for (auto it = iter->begin() + 2; it != iter->begin();) { + EXPECT_EQ(*(it--), x); + x -= 10; + } +} + +TEST_F(QuicheCircularDequeTest, Resize) { + QuicheCircularDeque> dq; + dq.resize(8); + EXPECT_THAT(dq, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); + EXPECT_EQ(1u, dq.get_allocator().allocate_count()); + + dq.resize(10, 5); + EXPECT_THAT(dq, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 5, 5)); + + QuicheCircularDeque> dq2 = dq; + + for (size_t new_size = dq.size(); new_size != 0; --new_size) { + dq.resize(new_size); + EXPECT_TRUE( + std::equal(dq.begin(), dq.end(), dq2.begin(), dq2.begin() + new_size)); + } + + dq.resize(0); + EXPECT_TRUE(dq.empty()); + + // Resize when data is wrapped. + ASSERT_EQ(dq2.size(), dq2.capacity()); + while (dq2.size() < dq2.capacity()) { + dq2.push_back(5); + } + + // Shift left once such that data is wrapped. + ASSERT_LT(&dq2.front(), &dq2.back()); + dq2.pop_back(); + dq2.push_front(-5); + ASSERT_GT(&dq2.front(), &dq2.back()); + + EXPECT_EQ(-5, dq2.front()); + EXPECT_EQ(5, dq2.back()); + dq2.resize(dq2.size() + 1, 10); + + // Data should be unwrapped after the resize. + ASSERT_LT(&dq2.front(), &dq2.back()); + EXPECT_EQ(-5, dq2.front()); + EXPECT_EQ(10, dq2.back()); + EXPECT_EQ(5, *(dq2.rbegin() + 1)); +} + +namespace { +class Foo { + public: + Foo() : Foo(0xF00) {} + + explicit Foo(int i) : i_(new int(i)) {} + + ~Foo() { + if (i_ != nullptr) { + delete i_; + // Do not set i_ to nullptr such that if the container calls destructor + // multiple times, asan can detect it. + } + } + + Foo(const Foo& other) : i_(new int(*other.i_)) {} + + Foo(Foo&& other) = delete; + + void Set(int i) { *i_ = i; } + + int i() const { return *i_; } + + friend bool operator==(const Foo& lhs, const Foo& rhs) { + return lhs.i() == rhs.i(); + } + + friend std::ostream& operator<<(std::ostream& os, const Foo& foo) { + return os << "Foo(" << foo.i() << ")"; + } + + private: + // By pointing i_ to a dynamically allocated integer, a memory leak will be + // reported if the container forget to properly destruct this object. + int* i_ = nullptr; +}; +} // namespace + +TEST_F(QuicheCircularDequeTest, RelocateNonTriviallyCopyable) { + // When relocating non-trivially-copyable objects: + // - Move constructor is preferred, if available. + // - Copy constructor is used otherwise. + + { + // Move construct in Relocate. + using MoveConstructible = std::unique_ptr; + ASSERT_FALSE(std::is_trivially_copyable::value); + ASSERT_TRUE(std::is_move_constructible::value); + QuicheCircularDeque> + dq1; + dq1.resize(3); + EXPECT_EQ(dq1.size(), dq1.capacity()); + EXPECT_EQ(1u, dq1.get_allocator().allocate_count()); + + dq1.emplace_back(new Foo(0xF1)); // Cause existing elements to relocate. + EXPECT_EQ(4u, dq1.size()); + EXPECT_EQ(2u, dq1.get_allocator().allocate_count()); + EXPECT_EQ(dq1[0], nullptr); + EXPECT_EQ(dq1[1], nullptr); + EXPECT_EQ(dq1[2], nullptr); + EXPECT_EQ(dq1[3]->i(), 0xF1); + } + + { + // Copy construct in Relocate. + using NonMoveConstructible = Foo; + ASSERT_FALSE(std::is_trivially_copyable::value); + ASSERT_FALSE(std::is_move_constructible::value); + QuicheCircularDeque> + dq2; + dq2.resize(3); + EXPECT_EQ(dq2.size(), dq2.capacity()); + EXPECT_EQ(1u, dq2.get_allocator().allocate_count()); + + dq2.emplace_back(0xF1); // Cause existing elements to relocate. + EXPECT_EQ(4u, dq2.size()); + EXPECT_EQ(2u, dq2.get_allocator().allocate_count()); + EXPECT_EQ(dq2[0].i(), 0xF00); + EXPECT_EQ(dq2[1].i(), 0xF00); + EXPECT_EQ(dq2[2].i(), 0xF00); + EXPECT_EQ(dq2[3].i(), 0xF1); + } +} + +TEST_F(QuicheCircularDequeTest, PushPop) { + // (push|pop|emplace)_(back|front) + + { + QuicheCircularDeque> dq(4); + for (size_t i = 0; i < dq.size(); ++i) { + dq[i].Set(i + 1); + } + QUICHE_LOG(INFO) << "dq initialized to " << dq; + EXPECT_THAT(dq, ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4))); + + ShiftLeft(&dq, false); + QUICHE_LOG(INFO) << "shift left once : " << dq; + EXPECT_THAT(dq, ElementsAre(Foo(2), Foo(3), Foo(4), Foo(1))); + + ShiftLeft(&dq, true); + QUICHE_LOG(INFO) << "shift left twice: " << dq; + EXPECT_THAT(dq, ElementsAre(Foo(3), Foo(4), Foo(1), Foo(2))); + ASSERT_GT(&dq.front(), &dq.back()); + // dq destructs with wrapped data. + } + + { + QuicheCircularDeque> dq1(4); + for (size_t i = 0; i < dq1.size(); ++i) { + dq1[i].Set(i + 1); + } + QUICHE_LOG(INFO) << "dq1 initialized to " << dq1; + EXPECT_THAT(dq1, ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4))); + + ShiftRight(&dq1, false); + QUICHE_LOG(INFO) << "shift right once : " << dq1; + EXPECT_THAT(dq1, ElementsAre(Foo(4), Foo(1), Foo(2), Foo(3))); + + ShiftRight(&dq1, true); + QUICHE_LOG(INFO) << "shift right twice: " << dq1; + EXPECT_THAT(dq1, ElementsAre(Foo(3), Foo(4), Foo(1), Foo(2))); + ASSERT_GT(&dq1.front(), &dq1.back()); + // dq1 destructs with wrapped data. + } + + { // Pop n elements from front. + QuicheCircularDeque> dq2(5); + for (size_t i = 0; i < dq2.size(); ++i) { + dq2[i].Set(i + 1); + } + EXPECT_THAT(dq2, ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4), Foo(5))); + + EXPECT_EQ(2u, dq2.pop_front_n(2)); + EXPECT_THAT(dq2, ElementsAre(Foo(3), Foo(4), Foo(5))); + + EXPECT_EQ(3u, dq2.pop_front_n(100)); + EXPECT_TRUE(dq2.empty()); + } + + { // Pop n elements from back. + QuicheCircularDeque> dq3(6); + for (size_t i = 0; i < dq3.size(); ++i) { + dq3[i].Set(i + 1); + } + EXPECT_THAT(dq3, + ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4), Foo(5), Foo(6))); + + ShiftRight(&dq3, true); + ShiftRight(&dq3, true); + ShiftRight(&dq3, true); + EXPECT_THAT(dq3, + ElementsAre(Foo(4), Foo(5), Foo(6), Foo(1), Foo(2), Foo(3))); + + EXPECT_EQ(2u, dq3.pop_back_n(2)); + EXPECT_THAT(dq3, ElementsAre(Foo(4), Foo(5), Foo(6), Foo(1))); + + EXPECT_EQ(2u, dq3.pop_back_n(2)); + EXPECT_THAT(dq3, ElementsAre(Foo(4), Foo(5))); + } +} + +TEST_F(QuicheCircularDequeTest, Allocation) { + CountingAllocator alloc; + + { + QuicheCircularDeque> dq(alloc); + EXPECT_EQ(alloc, dq.get_allocator()); + EXPECT_EQ(0u, dq.size()); + EXPECT_EQ(0u, dq.capacity()); + EXPECT_EQ(0u, alloc.allocate_count()); + EXPECT_EQ(0u, alloc.deallocate_count()); + + for (int i = 1; i <= 18; ++i) { + SCOPED_TRACE(testing::Message() + << "i=" << i << ", capacity_b4_push=" << dq.capacity()); + dq.push_back(i); + EXPECT_EQ(i, static_cast(dq.size())); + + const size_t capacity = 3 + (i - 1) / 3 * 3; + EXPECT_EQ(capacity, dq.capacity()); + EXPECT_EQ(capacity / 3, alloc.allocate_count()); + EXPECT_EQ(capacity / 3 - 1, alloc.deallocate_count()); + } + + dq.push_back(19); + EXPECT_EQ(22u, dq.capacity()); // 18 + 18 / 4 + EXPECT_EQ(7u, alloc.allocate_count()); + EXPECT_EQ(6u, alloc.deallocate_count()); + } + + EXPECT_EQ(7u, alloc.deallocate_count()); +} + +} // namespace +} // namespace test +} // namespace quiche + +// Use a non-quiche namespace to make sure swap can be used via ADL. +namespace { + +template +using SwappableAllocator = quiche::test::ConfigurableAllocator< + T, + /*propagate_on_copy_assignment=*/std::true_type, + /*propagate_on_move_assignment=*/std::true_type, + /*propagate_on_swap=*/std::true_type, + /*equality_result=*/true>; + +template +using UnswappableEqualAllocator = quiche::test::ConfigurableAllocator< + T, + /*propagate_on_copy_assignment=*/std::true_type, + /*propagate_on_move_assignment=*/std::true_type, + /*propagate_on_swap=*/std::false_type, + /*equality_result=*/true>; + +template +using UnswappableUnequalAllocator = quiche::test::ConfigurableAllocator< + T, + /*propagate_on_copy_assignment=*/std::true_type, + /*propagate_on_move_assignment=*/std::true_type, + /*propagate_on_swap=*/std::false_type, + /*equality_result=*/false>; + +using quiche::test::QuicheCircularDequeTest; + +TEST_F(QuicheCircularDequeTest, Swap) { + using std::swap; + + quiche::QuicheCircularDeque> dq1, dq2; + dq1.push_back(10); + dq1.push_back(11); + dq2.push_back(20); + swap(dq1, dq2); + EXPECT_THAT(dq1, ElementsAre(20)); + EXPECT_THAT(dq2, ElementsAre(10, 11)); + + quiche::QuicheCircularDeque> dq3, + dq4; + dq3 = {1, 2, 3, 4, 5}; + dq4 = {6, 7, 8, 9, 0}; + swap(dq3, dq4); + EXPECT_THAT(dq3, ElementsAre(6, 7, 8, 9, 0)); + EXPECT_THAT(dq4, ElementsAre(1, 2, 3, 4, 5)); + + quiche::QuicheCircularDeque> dq5, + dq6; + dq6.push_front(4); + + // Using UnswappableUnequalAllocator is ok as long as swap is not called. + dq5.assign(dq6.begin(), dq6.end()); + EXPECT_THAT(dq5, ElementsAre(4)); + + // Undefined behavior to swap between two containers with unequal allocators. + EXPECT_QUICHE_DEBUG_DEATH(swap(dq5, dq6), "Undefined swap behavior"); +} +} // namespace diff --git a/quiche/common/quiche_crypto_logging.cc b/quiche/common/quiche_crypto_logging.cc new file mode 100644 index 000000000000..17d5407a762d --- /dev/null +++ b/quiche/common/quiche_crypto_logging.cc @@ -0,0 +1,42 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_crypto_logging.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "openssl/err.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { +void DLogOpenSslErrors() { +#ifdef NDEBUG + // Clear OpenSSL error stack. + ClearOpenSslErrors(); +#else + while (uint32_t error = ERR_get_error()) { + char buf[120]; + ERR_error_string_n(error, buf, ABSL_ARRAYSIZE(buf)); + QUICHE_DLOG(ERROR) << "OpenSSL error: " << buf; + } +#endif +} + +void ClearOpenSslErrors() { + while (ERR_get_error()) { + } +} + +absl::Status SslErrorAsStatus(absl::string_view msg, absl::StatusCode code) { + std::string message; + absl::StrAppend(&message, msg, "OpenSSL error: "); + while (uint32_t error = ERR_get_error()) { + char buf[120]; + ERR_error_string_n(error, buf, ABSL_ARRAYSIZE(buf)); + absl::StrAppend(&message, buf); + } + return absl::Status(code, message); +} + +} // namespace quiche diff --git a/quiche/common/quiche_crypto_logging.h b/quiche/common/quiche_crypto_logging.h new file mode 100644 index 000000000000..8806672e8018 --- /dev/null +++ b/quiche/common/quiche_crypto_logging.h @@ -0,0 +1,26 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_CRYPTO_LOGGING_H_ +#define QUICHE_COMMON_QUICHE_CRYPTO_LOGGING_H_ + +#include "absl/status/status.h" + +namespace quiche { + +// In debug builds only, log OpenSSL error stack. Then clear OpenSSL error +// stack. +void DLogOpenSslErrors(); + +// Clears OpenSSL error stack. +void ClearOpenSslErrors(); + +// Include OpenSSL error stack in Status msg so that callers could choose to +// only log it in debug builds if required. +absl::Status SslErrorAsStatus( + absl::string_view msg, absl::StatusCode code = absl::StatusCode::kInternal); + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_CRYPTO_LOGGING_H_ diff --git a/quiche/common/quiche_data_reader.cc b/quiche/common/quiche_data_reader.cc new file mode 100644 index 000000000000..84eca4d2dcc6 --- /dev/null +++ b/quiche/common/quiche_data_reader.cc @@ -0,0 +1,321 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_data_reader.h" + +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace quiche { + +QuicheDataReader::QuicheDataReader(absl::string_view data) + : QuicheDataReader(data.data(), data.length(), quiche::NETWORK_BYTE_ORDER) { +} + +QuicheDataReader::QuicheDataReader(const char* data, const size_t len) + : QuicheDataReader(data, len, quiche::NETWORK_BYTE_ORDER) {} + +QuicheDataReader::QuicheDataReader(const char* data, const size_t len, + quiche::Endianness endianness) + : data_(data), len_(len), pos_(0), endianness_(endianness) {} + +bool QuicheDataReader::ReadUInt8(uint8_t* result) { + return ReadBytes(result, sizeof(*result)); +} + +bool QuicheDataReader::ReadUInt16(uint16_t* result) { + if (!ReadBytes(result, sizeof(*result))) { + return false; + } + if (endianness_ == quiche::NETWORK_BYTE_ORDER) { + *result = quiche::QuicheEndian::NetToHost16(*result); + } + return true; +} + +bool QuicheDataReader::ReadUInt24(uint32_t* result) { + if (endianness_ != quiche::NETWORK_BYTE_ORDER) { + // TODO(b/214573190): Implement and test HOST_BYTE_ORDER case. + QUICHE_BUG(QuicheDataReader_ReadUInt24_NotImplemented); + return false; + } + + *result = 0; + if (!ReadBytes(reinterpret_cast(result) + 1, 3u)) { + return false; + } + *result = quiche::QuicheEndian::NetToHost32(*result); + return true; +} + +bool QuicheDataReader::ReadUInt32(uint32_t* result) { + if (!ReadBytes(result, sizeof(*result))) { + return false; + } + if (endianness_ == quiche::NETWORK_BYTE_ORDER) { + *result = quiche::QuicheEndian::NetToHost32(*result); + } + return true; +} + +bool QuicheDataReader::ReadUInt64(uint64_t* result) { + if (!ReadBytes(result, sizeof(*result))) { + return false; + } + if (endianness_ == quiche::NETWORK_BYTE_ORDER) { + *result = quiche::QuicheEndian::NetToHost64(*result); + } + return true; +} + +bool QuicheDataReader::ReadBytesToUInt64(size_t num_bytes, uint64_t* result) { + *result = 0u; + if (num_bytes > sizeof(*result)) { + return false; + } + if (endianness_ == quiche::HOST_BYTE_ORDER) { + return ReadBytes(result, num_bytes); + } + + if (!ReadBytes(reinterpret_cast(result) + sizeof(*result) - num_bytes, + num_bytes)) { + return false; + } + *result = quiche::QuicheEndian::NetToHost64(*result); + return true; +} + +bool QuicheDataReader::ReadStringPiece16(absl::string_view* result) { + // Read resultant length. + uint16_t result_len; + if (!ReadUInt16(&result_len)) { + // OnFailure() already called. + return false; + } + + return ReadStringPiece(result, result_len); +} + +bool QuicheDataReader::ReadStringPiece8(absl::string_view* result) { + // Read resultant length. + uint8_t result_len; + if (!ReadUInt8(&result_len)) { + // OnFailure() already called. + return false; + } + + return ReadStringPiece(result, result_len); +} + +bool QuicheDataReader::ReadStringPiece(absl::string_view* result, size_t size) { + // Make sure that we have enough data to read. + if (!CanRead(size)) { + OnFailure(); + return false; + } + + // Set result. + *result = absl::string_view(data_ + pos_, size); + + // Iterate. + pos_ += size; + + return true; +} + +bool QuicheDataReader::ReadTag(uint32_t* tag) { + return ReadBytes(tag, sizeof(*tag)); +} + +bool QuicheDataReader::ReadDecimal64(size_t num_digits, uint64_t* result) { + absl::string_view digits; + if (!ReadStringPiece(&digits, num_digits)) { + return false; + } + + return absl::SimpleAtoi(digits, result); +} + +QuicheVariableLengthIntegerLength QuicheDataReader::PeekVarInt62Length() { + QUICHE_DCHECK_EQ(endianness(), NETWORK_BYTE_ORDER); + const unsigned char* next = + reinterpret_cast(data() + pos()); + if (BytesRemaining() == 0) { + return VARIABLE_LENGTH_INTEGER_LENGTH_0; + } + return static_cast( + 1 << ((*next & 0b11000000) >> 6)); +} + +// Read an RFC 9000 62-bit Variable Length Integer. +// +// Performance notes +// +// Measurements and experiments showed that unrolling the four cases +// like this and dereferencing next_ as we do (*(next_+n) --- and then +// doing a single pos_+=x at the end) gains about 10% over making a +// loop and dereferencing next_ such as *(next_++) +// +// Using a register for pos_ was not helpful. +// +// Branches are ordered to increase the likelihood of the first being +// taken. +// +// Low-level optimization is useful here because this function will be +// called frequently, leading to outsize benefits. +bool QuicheDataReader::ReadVarInt62(uint64_t* result) { + QUICHE_DCHECK_EQ(endianness(), quiche::NETWORK_BYTE_ORDER); + + size_t remaining = BytesRemaining(); + const unsigned char* next = + reinterpret_cast(data() + pos()); + if (remaining != 0) { + switch (*next & 0xc0) { + case 0xc0: + // Leading 0b11...... is 8 byte encoding + if (remaining >= 8) { + *result = (static_cast((*(next)) & 0x3f) << 56) + + (static_cast(*(next + 1)) << 48) + + (static_cast(*(next + 2)) << 40) + + (static_cast(*(next + 3)) << 32) + + (static_cast(*(next + 4)) << 24) + + (static_cast(*(next + 5)) << 16) + + (static_cast(*(next + 6)) << 8) + + (static_cast(*(next + 7)) << 0); + AdvancePos(8); + return true; + } + return false; + + case 0x80: + // Leading 0b10...... is 4 byte encoding + if (remaining >= 4) { + *result = (((*(next)) & 0x3f) << 24) + (((*(next + 1)) << 16)) + + (((*(next + 2)) << 8)) + (((*(next + 3)) << 0)); + AdvancePos(4); + return true; + } + return false; + + case 0x40: + // Leading 0b01...... is 2 byte encoding + if (remaining >= 2) { + *result = (((*(next)) & 0x3f) << 8) + (*(next + 1)); + AdvancePos(2); + return true; + } + return false; + + case 0x00: + // Leading 0b00...... is 1 byte encoding + *result = (*next) & 0x3f; + AdvancePos(1); + return true; + } + } + return false; +} + +bool QuicheDataReader::ReadStringPieceVarInt62(absl::string_view* result) { + uint64_t result_length; + if (!ReadVarInt62(&result_length)) { + return false; + } + return ReadStringPiece(result, result_length); +} + +absl::string_view QuicheDataReader::ReadRemainingPayload() { + absl::string_view payload = PeekRemainingPayload(); + pos_ = len_; + return payload; +} + +absl::string_view QuicheDataReader::PeekRemainingPayload() const { + return absl::string_view(data_ + pos_, len_ - pos_); +} + +absl::string_view QuicheDataReader::FullPayload() const { + return absl::string_view(data_, len_); +} + +absl::string_view QuicheDataReader::PreviouslyReadPayload() const { + return absl::string_view(data_, pos_); +} + +bool QuicheDataReader::ReadBytes(void* result, size_t size) { + // Make sure that we have enough data to read. + if (!CanRead(size)) { + OnFailure(); + return false; + } + + // Read into result. + memcpy(result, data_ + pos_, size); + + // Iterate. + pos_ += size; + + return true; +} + +bool QuicheDataReader::Seek(size_t size) { + if (!CanRead(size)) { + OnFailure(); + return false; + } + pos_ += size; + return true; +} + +bool QuicheDataReader::IsDoneReading() const { return len_ == pos_; } + +size_t QuicheDataReader::BytesRemaining() const { + if (pos_ > len_) { + QUICHE_BUG(quiche_reader_pos_out_of_bound) + << "QUIC reader pos out of bound: " << pos_ << ", len: " << len_; + return 0; + } + return len_ - pos_; +} + +bool QuicheDataReader::TruncateRemaining(size_t truncation_length) { + if (truncation_length > BytesRemaining()) { + return false; + } + len_ = pos_ + truncation_length; + return true; +} + +bool QuicheDataReader::CanRead(size_t bytes) const { + return bytes <= (len_ - pos_); +} + +void QuicheDataReader::OnFailure() { + // Set our iterator to the end of the buffer so that further reads fail + // immediately. + pos_ = len_; +} + +uint8_t QuicheDataReader::PeekByte() const { + if (pos_ >= len_) { + QUICHE_LOG(FATAL) + << "Reading is done, cannot peek next byte. Tried to read pos = " + << pos_ << " buffer length = " << len_; + return 0; + } + return data_[pos_]; +} + +std::string QuicheDataReader::DebugString() const { + return absl::StrCat(" { length: ", len_, ", position: ", pos_, " }"); +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quiche diff --git a/quiche/common/quiche_data_reader.h b/quiche/common/quiche_data_reader.h new file mode 100644 index 000000000000..9f7dd56cd80c --- /dev/null +++ b/quiche/common/quiche_data_reader.h @@ -0,0 +1,216 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_DATA_READER_H_ +#define QUICHE_COMMON_QUICHE_DATA_READER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace quiche { + +// To use, simply construct a QuicheDataReader using the underlying buffer that +// you'd like to read fields from, then call one of the Read*() methods to +// actually do some reading. +// +// This class keeps an internal iterator to keep track of what's already been +// read and each successive Read*() call automatically increments said iterator +// on success. On failure, internal state of the QuicheDataReader should not be +// trusted and it is up to the caller to throw away the failed instance and +// handle the error as appropriate. None of the Read*() methods should ever be +// called after failure, as they will also fail immediately. +class QUICHE_EXPORT QuicheDataReader { + public: + // Constructs a reader using NETWORK_BYTE_ORDER endianness. + // Caller must provide an underlying buffer to work on. + explicit QuicheDataReader(absl::string_view data); + // Constructs a reader using NETWORK_BYTE_ORDER endianness. + // Caller must provide an underlying buffer to work on. + QuicheDataReader(const char* data, const size_t len); + // Constructs a reader using the specified endianness. + // Caller must provide an underlying buffer to work on. + QuicheDataReader(const char* data, const size_t len, + quiche::Endianness endianness); + QuicheDataReader(const QuicheDataReader&) = delete; + QuicheDataReader& operator=(const QuicheDataReader&) = delete; + + // Empty destructor. + ~QuicheDataReader() {} + + // Reads an 8/16/24/32/64-bit unsigned integer into the given output + // parameter. Forwards the internal iterator on success. Returns true on + // success, false otherwise. + bool ReadUInt8(uint8_t* result); + bool ReadUInt16(uint16_t* result); + bool ReadUInt24(uint32_t* result); + bool ReadUInt32(uint32_t* result); + bool ReadUInt64(uint64_t* result); + + // Set |result| to 0, then read |num_bytes| bytes in the correct byte order + // into least significant bytes of |result|. + bool ReadBytesToUInt64(size_t num_bytes, uint64_t* result); + + // Reads a string prefixed with 16-bit length into the given output parameter. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadStringPiece16(absl::string_view* result); + + // Reads a string prefixed with 8-bit length into the given output parameter. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadStringPiece8(absl::string_view* result); + + // Reads a given number of bytes into the given buffer. The buffer + // must be of adequate size. + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadStringPiece(absl::string_view* result, size_t size); + + // Reads tag represented as 32-bit unsigned integer into given output + // parameter. Tags are in big endian on the wire (e.g., CHLO is + // 'C','H','L','O') and are read in byte order, so tags in memory are in big + // endian. + bool ReadTag(uint32_t* tag); + + // Reads a sequence of a fixed number of decimal digits, parses them as an + // unsigned integer and returns them as a uint64_t. Forwards internal + // iterator on success, may forward it even in case of failure. + bool ReadDecimal64(size_t num_digits, uint64_t* result); + + // Returns the length in bytes of a variable length integer based on the next + // two bits available. Returns 1, 2, 4, or 8 on success, and 0 on failure. + QuicheVariableLengthIntegerLength PeekVarInt62Length(); + + // Read an RFC 9000 62-bit Variable Length Integer and place the result in + // |*result|. Returns false if there is not enough space in the buffer to read + // the number, true otherwise. If false is returned, |*result| is not altered. + bool ReadVarInt62(uint64_t* result); + + // Reads a string prefixed with a RFC 9000 62-bit variable Length integer + // length into the given output parameter. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // Returns false if there is not enough space in the buffer to read + // the number and subsequent string, true otherwise. + bool ReadStringPieceVarInt62(absl::string_view* result); + + // Returns the remaining payload as a absl::string_view. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // Forwards the internal iterator. + absl::string_view ReadRemainingPayload(); + + // Returns the remaining payload as a absl::string_view. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // DOES NOT forward the internal iterator. + absl::string_view PeekRemainingPayload() const; + + // Returns the entire payload as a absl::string_view. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // DOES NOT forward the internal iterator. + absl::string_view FullPayload() const; + + // Returns the part of the payload that has been already read as a + // absl::string_view. + // + // NOTE: Does not copy but rather references strings in the underlying buffer. + // This should be kept in mind when handling memory management! + // + // DOES NOT forward the internal iterator. + absl::string_view PreviouslyReadPayload() const; + + // Reads a given number of bytes into the given buffer. The buffer + // must be of adequate size. + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadBytes(void* result, size_t size); + + // Skips over |size| bytes from the buffer and forwards the internal iterator. + // Returns true if there are at least |size| bytes remaining to read, false + // otherwise. + bool Seek(size_t size); + + // Returns true if the entirety of the underlying buffer has been read via + // Read*() calls. + bool IsDoneReading() const; + + // Returns the number of bytes remaining to be read. + size_t BytesRemaining() const; + + // Truncates the reader down by reducing its internal length. + // If called immediately after calling this, BytesRemaining will + // return |truncation_length|. If truncation_length is less than the + // current value of BytesRemaining, this does nothing and returns false. + bool TruncateRemaining(size_t truncation_length); + + // Returns the next byte that to be read. Must not be called when there are no + // bytes to be read. + // + // DOES NOT forward the internal iterator. + uint8_t PeekByte() const; + + std::string DebugString() const; + + protected: + // Returns true if the underlying buffer has enough room to read the given + // amount of bytes. + bool CanRead(size_t bytes) const; + + // To be called when a read fails for any reason. + void OnFailure(); + + const char* data() const { return data_; } + + size_t pos() const { return pos_; } + + void AdvancePos(size_t amount) { + QUICHE_DCHECK_LE(pos_, std::numeric_limits::max() - amount); + QUICHE_DCHECK_LE(pos_, len_ - amount); + pos_ += amount; + } + + quiche::Endianness endianness() const { return endianness_; } + + private: + // TODO(fkastenholz, b/73004262) change buffer_, et al, to be uint8_t, not + // char. The data buffer that we're reading from. + const char* data_; + + // The length of the data buffer that we're reading from. + size_t len_; + + // The location of the next read from our data buffer. + size_t pos_; + + // The endianness to read integers and floating numbers. + quiche::Endianness endianness_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_DATA_READER_H_ diff --git a/quiche/common/quiche_data_reader_test.cc b/quiche/common/quiche_data_reader_test.cc new file mode 100644 index 000000000000..d65dd8831ed6 --- /dev/null +++ b/quiche/common/quiche_data_reader_test.cc @@ -0,0 +1,187 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_data_reader.h" + +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_endian.h" + +namespace quiche { + +// TODO(b/214573190): Test Endianness::HOST_BYTE_ORDER. +// TODO(b/214573190): Test ReadUInt8, ReadUInt24, ReadUInt64, ReadBytesToUInt64, +// ReadStringPiece8, ReadStringPiece, ReadTag, etc. + +TEST(QuicheDataReaderTest, ReadUInt16) { + // Data in network byte order. + const uint16_t kData[] = { + QuicheEndian::HostToNet16(1), + QuicheEndian::HostToNet16(1 << 15), + }; + + QuicheDataReader reader(reinterpret_cast(kData), sizeof(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + uint16_t uint16_val; + EXPECT_TRUE(reader.ReadUInt16(&uint16_val)); + EXPECT_FALSE(reader.IsDoneReading()); + EXPECT_EQ(1, uint16_val); + + EXPECT_TRUE(reader.ReadUInt16(&uint16_val)); + EXPECT_TRUE(reader.IsDoneReading()); + EXPECT_EQ(1 << 15, uint16_val); +} + +TEST(QuicheDataReaderTest, ReadUInt32) { + // Data in network byte order. + const uint32_t kData[] = { + QuicheEndian::HostToNet32(1), + QuicheEndian::HostToNet32(0x80000000), + }; + + QuicheDataReader reader(reinterpret_cast(kData), + ABSL_ARRAYSIZE(kData) * sizeof(uint32_t)); + EXPECT_FALSE(reader.IsDoneReading()); + + uint32_t uint32_val; + EXPECT_TRUE(reader.ReadUInt32(&uint32_val)); + EXPECT_FALSE(reader.IsDoneReading()); + EXPECT_EQ(1u, uint32_val); + + EXPECT_TRUE(reader.ReadUInt32(&uint32_val)); + EXPECT_TRUE(reader.IsDoneReading()); + EXPECT_EQ(1u << 31, uint32_val); +} + +TEST(QuicheDataReaderTest, ReadStringPiece16) { + // Data in network byte order. + const char kData[] = { + 0x00, 0x02, // uint16_t(2) + 0x48, 0x69, // "Hi" + 0x00, 0x10, // uint16_t(16) + 0x54, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2c, + 0x20, 0x31, 0x2c, 0x20, 0x32, 0x2c, 0x20, 0x33, // "Testing, 1, 2, 3" + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + absl::string_view stringpiece_val; + EXPECT_TRUE(reader.ReadStringPiece16(&stringpiece_val)); + EXPECT_FALSE(reader.IsDoneReading()); + EXPECT_EQ(0, stringpiece_val.compare("Hi")); + + EXPECT_TRUE(reader.ReadStringPiece16(&stringpiece_val)); + EXPECT_TRUE(reader.IsDoneReading()); + EXPECT_EQ(0, stringpiece_val.compare("Testing, 1, 2, 3")); +} + +TEST(QuicheDataReaderTest, ReadUInt16WithBufferTooSmall) { + // Data in network byte order. + const char kData[] = { + 0x00, // part of a uint16_t + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + uint16_t uint16_val; + EXPECT_FALSE(reader.ReadUInt16(&uint16_val)); +} + +TEST(QuicheDataReaderTest, ReadUInt32WithBufferTooSmall) { + // Data in network byte order. + const char kData[] = { + 0x00, 0x00, 0x00, // part of a uint32_t + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + uint32_t uint32_val; + EXPECT_FALSE(reader.ReadUInt32(&uint32_val)); + + // Also make sure that trying to read a uint16_t, which technically could + // work, fails immediately due to previously encountered failed read. + uint16_t uint16_val; + EXPECT_FALSE(reader.ReadUInt16(&uint16_val)); +} + +// Tests ReadStringPiece16() with a buffer too small to fit the entire string. +TEST(QuicheDataReaderTest, ReadStringPiece16WithBufferTooSmall) { + // Data in network byte order. + const char kData[] = { + 0x00, 0x03, // uint16_t(3) + 0x48, 0x69, // "Hi" + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + absl::string_view stringpiece_val; + EXPECT_FALSE(reader.ReadStringPiece16(&stringpiece_val)); + + // Also make sure that trying to read a uint16_t, which technically could + // work, fails immediately due to previously encountered failed read. + uint16_t uint16_val; + EXPECT_FALSE(reader.ReadUInt16(&uint16_val)); +} + +// Tests ReadStringPiece16() with a buffer too small even to fit the length. +TEST(QuicheDataReaderTest, ReadStringPiece16WithBufferWayTooSmall) { + // Data in network byte order. + const char kData[] = { + 0x00, // part of a uint16_t + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + absl::string_view stringpiece_val; + EXPECT_FALSE(reader.ReadStringPiece16(&stringpiece_val)); + + // Also make sure that trying to read a uint16_t, which technically could + // work, fails immediately due to previously encountered failed read. + uint16_t uint16_val; + EXPECT_FALSE(reader.ReadUInt16(&uint16_val)); +} + +TEST(QuicheDataReaderTest, ReadBytes) { + // Data in network byte order. + const char kData[] = { + 0x66, 0x6f, 0x6f, // "foo" + 0x48, 0x69, // "Hi" + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + char dest1[3] = {}; + EXPECT_TRUE(reader.ReadBytes(&dest1, ABSL_ARRAYSIZE(dest1))); + EXPECT_FALSE(reader.IsDoneReading()); + EXPECT_EQ("foo", absl::string_view(dest1, ABSL_ARRAYSIZE(dest1))); + + char dest2[2] = {}; + EXPECT_TRUE(reader.ReadBytes(&dest2, ABSL_ARRAYSIZE(dest2))); + EXPECT_TRUE(reader.IsDoneReading()); + EXPECT_EQ("Hi", absl::string_view(dest2, ABSL_ARRAYSIZE(dest2))); +} + +TEST(QuicheDataReaderTest, ReadBytesWithBufferTooSmall) { + // Data in network byte order. + const char kData[] = { + 0x01, + }; + + QuicheDataReader reader(kData, ABSL_ARRAYSIZE(kData)); + EXPECT_FALSE(reader.IsDoneReading()); + + char dest[ABSL_ARRAYSIZE(kData) + 2] = {}; + EXPECT_FALSE(reader.ReadBytes(&dest, ABSL_ARRAYSIZE(kData) + 1)); + EXPECT_STREQ("", dest); +} + +} // namespace quiche diff --git a/quiche/common/quiche_data_writer.cc b/quiche/common/quiche_data_writer.cc new file mode 100644 index 000000000000..5e7943925d0b --- /dev/null +++ b/quiche/common/quiche_data_writer.cc @@ -0,0 +1,301 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_data_writer.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/quiche_endian.h" + +namespace quiche { + +QuicheDataWriter::QuicheDataWriter(size_t size, char* buffer) + : QuicheDataWriter(size, buffer, quiche::NETWORK_BYTE_ORDER) {} + +QuicheDataWriter::QuicheDataWriter(size_t size, char* buffer, + quiche::Endianness endianness) + : buffer_(buffer), capacity_(size), length_(0), endianness_(endianness) {} + +QuicheDataWriter::~QuicheDataWriter() {} + +char* QuicheDataWriter::data() { return buffer_; } + +bool QuicheDataWriter::WriteUInt8(uint8_t value) { + return WriteBytes(&value, sizeof(value)); +} + +bool QuicheDataWriter::WriteUInt16(uint16_t value) { + if (endianness_ == quiche::NETWORK_BYTE_ORDER) { + value = quiche::QuicheEndian::HostToNet16(value); + } + return WriteBytes(&value, sizeof(value)); +} + +bool QuicheDataWriter::WriteUInt32(uint32_t value) { + if (endianness_ == quiche::NETWORK_BYTE_ORDER) { + value = quiche::QuicheEndian::HostToNet32(value); + } + return WriteBytes(&value, sizeof(value)); +} + +bool QuicheDataWriter::WriteUInt64(uint64_t value) { + if (endianness_ == quiche::NETWORK_BYTE_ORDER) { + value = quiche::QuicheEndian::HostToNet64(value); + } + return WriteBytes(&value, sizeof(value)); +} + +bool QuicheDataWriter::WriteBytesToUInt64(size_t num_bytes, uint64_t value) { + if (num_bytes > sizeof(value)) { + return false; + } + if (endianness_ == quiche::HOST_BYTE_ORDER) { + return WriteBytes(&value, num_bytes); + } + + value = quiche::QuicheEndian::HostToNet64(value); + return WriteBytes(reinterpret_cast(&value) + sizeof(value) - num_bytes, + num_bytes); +} + +bool QuicheDataWriter::WriteStringPiece16(absl::string_view val) { + if (val.size() > std::numeric_limits::max()) { + return false; + } + if (!WriteUInt16(static_cast(val.size()))) { + return false; + } + return WriteBytes(val.data(), val.size()); +} + +bool QuicheDataWriter::WriteStringPiece(absl::string_view val) { + return WriteBytes(val.data(), val.size()); +} + +char* QuicheDataWriter::BeginWrite(size_t length) { + if (length_ > capacity_) { + return nullptr; + } + + if (capacity_ - length_ < length) { + return nullptr; + } + +#ifdef ARCH_CPU_64_BITS + QUICHE_DCHECK_LE(length, std::numeric_limits::max()); +#endif + + return buffer_ + length_; +} + +bool QuicheDataWriter::WriteBytes(const void* data, size_t data_len) { + char* dest = BeginWrite(data_len); + if (!dest) { + return false; + } + + memcpy(dest, data, data_len); + + length_ += data_len; + return true; +} + +bool QuicheDataWriter::WriteRepeatedByte(uint8_t byte, size_t count) { + char* dest = BeginWrite(count); + if (!dest) { + return false; + } + + memset(dest, byte, count); + + length_ += count; + return true; +} + +void QuicheDataWriter::WritePadding() { + QUICHE_DCHECK_LE(length_, capacity_); + if (length_ > capacity_) { + return; + } + memset(buffer_ + length_, 0x00, capacity_ - length_); + length_ = capacity_; +} + +bool QuicheDataWriter::WritePaddingBytes(size_t count) { + return WriteRepeatedByte(0x00, count); +} + +bool QuicheDataWriter::WriteTag(uint32_t tag) { + return WriteBytes(&tag, sizeof(tag)); +} + +// Converts a uint64_t into a 62-bit RFC 9000 Variable Length Integer. +// +// Performance notes +// +// Measurements and experiments showed that unrolling the four cases +// like this and dereferencing next_ as we do (*(next_+n)) gains about +// 10% over making a loop and dereferencing it as *(next_++) +// +// Using a register for next didn't help. +// +// Branches are ordered to increase the likelihood of the first being +// taken. +// +// Low-level optimization is useful here because this function will be +// called frequently, leading to outsize benefits. +bool QuicheDataWriter::WriteVarInt62(uint64_t value) { + QUICHE_DCHECK_EQ(endianness(), quiche::NETWORK_BYTE_ORDER); + + size_t remaining_bytes = remaining(); + char* next = buffer() + length(); + + if ((value & kVarInt62ErrorMask) == 0) { + // We know the high 2 bits are 0 so |value| is legal. + // We can do the encoding. + if ((value & kVarInt62Mask8Bytes) != 0) { + // Someplace in the high-4 bytes is a 1-bit. Do an 8-byte + // encoding. + if (remaining_bytes >= 8) { + *(next + 0) = ((value >> 56) & 0x3f) + 0xc0; + *(next + 1) = (value >> 48) & 0xff; + *(next + 2) = (value >> 40) & 0xff; + *(next + 3) = (value >> 32) & 0xff; + *(next + 4) = (value >> 24) & 0xff; + *(next + 5) = (value >> 16) & 0xff; + *(next + 6) = (value >> 8) & 0xff; + *(next + 7) = value & 0xff; + IncreaseLength(8); + return true; + } + return false; + } + // The high-order-4 bytes are all 0, check for a 1, 2, or 4-byte + // encoding + if ((value & kVarInt62Mask4Bytes) != 0) { + // The encoding will not fit into 2 bytes, Do a 4-byte + // encoding. + if (remaining_bytes >= 4) { + *(next + 0) = ((value >> 24) & 0x3f) + 0x80; + *(next + 1) = (value >> 16) & 0xff; + *(next + 2) = (value >> 8) & 0xff; + *(next + 3) = value & 0xff; + IncreaseLength(4); + return true; + } + return false; + } + // The high-order bits are all 0. Check to see if the number + // can be encoded as one or two bytes. One byte encoding has + // only 6 significant bits (bits 0xffffffff ffffffc0 are all 0). + // Two byte encoding has more than 6, but 14 or less significant + // bits (bits 0xffffffff ffffc000 are 0 and 0x00000000 00003fc0 + // are not 0) + if ((value & kVarInt62Mask2Bytes) != 0) { + // Do 2-byte encoding + if (remaining_bytes >= 2) { + *(next + 0) = ((value >> 8) & 0x3f) + 0x40; + *(next + 1) = (value)&0xff; + IncreaseLength(2); + return true; + } + return false; + } + if (remaining_bytes >= 1) { + // Do 1-byte encoding + *next = (value & 0x3f); + IncreaseLength(1); + return true; + } + return false; + } + // Can not encode, high 2 bits not 0 + return false; +} + +bool QuicheDataWriter::WriteStringPieceVarInt62( + const absl::string_view& string_piece) { + if (!WriteVarInt62(string_piece.size())) { + return false; + } + if (!string_piece.empty()) { + if (!WriteBytes(string_piece.data(), string_piece.size())) { + return false; + } + } + return true; +} + +// static +QuicheVariableLengthIntegerLength QuicheDataWriter::GetVarInt62Len( + uint64_t value) { + if ((value & kVarInt62ErrorMask) != 0) { + QUICHE_BUG(invalid_varint) << "Attempted to encode a value, " << value + << ", that is too big for VarInt62"; + return VARIABLE_LENGTH_INTEGER_LENGTH_0; + } + if ((value & kVarInt62Mask8Bytes) != 0) { + return VARIABLE_LENGTH_INTEGER_LENGTH_8; + } + if ((value & kVarInt62Mask4Bytes) != 0) { + return VARIABLE_LENGTH_INTEGER_LENGTH_4; + } + if ((value & kVarInt62Mask2Bytes) != 0) { + return VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + return VARIABLE_LENGTH_INTEGER_LENGTH_1; +} + +bool QuicheDataWriter::WriteVarInt62WithForcedLength( + uint64_t value, QuicheVariableLengthIntegerLength write_length) { + QUICHE_DCHECK_EQ(endianness(), NETWORK_BYTE_ORDER); + + size_t remaining_bytes = remaining(); + if (remaining_bytes < write_length) { + return false; + } + + const QuicheVariableLengthIntegerLength min_length = GetVarInt62Len(value); + if (write_length < min_length) { + QUICHE_BUG(invalid_varint_forced) << "Cannot write value " << value + << " with write_length " << write_length; + return false; + } + if (write_length == min_length) { + return WriteVarInt62(value); + } + + if (write_length == VARIABLE_LENGTH_INTEGER_LENGTH_2) { + return WriteUInt8(0b01000000) && WriteUInt8(value); + } + if (write_length == VARIABLE_LENGTH_INTEGER_LENGTH_4) { + return WriteUInt8(0b10000000) && WriteUInt8(0) && WriteUInt16(value); + } + if (write_length == VARIABLE_LENGTH_INTEGER_LENGTH_8) { + return WriteUInt8(0b11000000) && WriteUInt8(0) && WriteUInt16(0) && + WriteUInt32(value); + } + + QUICHE_BUG(invalid_write_length) + << "Invalid write_length " << static_cast(write_length); + return false; +} + +bool QuicheDataWriter::Seek(size_t length) { + if (!BeginWrite(length)) { + return false; + } + length_ += length; + return true; +} + +std::string QuicheDataWriter::DebugString() const { + return absl::StrCat(" { capacity: ", capacity_, ", length: ", length_, " }"); +} + +} // namespace quiche diff --git a/quiche/common/quiche_data_writer.h b/quiche/common/quiche_data_writer.h new file mode 100644 index 000000000000..bb691cdc1f65 --- /dev/null +++ b/quiche/common/quiche_data_writer.h @@ -0,0 +1,152 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_DATA_WRITER_H_ +#define QUICHE_COMMON_QUICHE_DATA_WRITER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace quiche { + +// Maximum value that can be properly encoded using RFC 9000 62-bit Variable +// Length Integer encoding. +enum : uint64_t { + kVarInt62MaxValue = UINT64_C(0x3fffffffffffffff), +}; + +// RFC 9000 62-bit Variable Length Integer encoding masks +// If a uint64_t anded with a mask is not 0 then the value is encoded +// using that length (or is too big, in the case of kVarInt62ErrorMask). +// Values must be checked in order (error, 8-, 4-, and then 2- bytes) +// and if none are non-0, the value is encoded in 1 byte. +enum : uint64_t { + kVarInt62ErrorMask = UINT64_C(0xc000000000000000), + kVarInt62Mask8Bytes = UINT64_C(0x3fffffffc0000000), + kVarInt62Mask4Bytes = UINT64_C(0x000000003fffc000), + kVarInt62Mask2Bytes = UINT64_C(0x0000000000003fc0), +}; + +// This class provides facilities for packing binary data. +// +// The QuicheDataWriter supports appending primitive values (int, string, etc) +// to a frame instance. The internal memory buffer is exposed as the "data" +// of the QuicheDataWriter. +class QUICHE_EXPORT QuicheDataWriter { + public: + // Creates a QuicheDataWriter where |buffer| is not owned + // using NETWORK_BYTE_ORDER endianness. + QuicheDataWriter(size_t size, char* buffer); + // Creates a QuicheDataWriter where |buffer| is not owned + // using the specified endianness. + QuicheDataWriter(size_t size, char* buffer, quiche::Endianness endianness); + QuicheDataWriter(const QuicheDataWriter&) = delete; + QuicheDataWriter& operator=(const QuicheDataWriter&) = delete; + + ~QuicheDataWriter(); + + // Returns the size of the QuicheDataWriter's data. + size_t length() const { return length_; } + + // Retrieves the buffer from the QuicheDataWriter without changing ownership. + char* data(); + + // Methods for adding to the payload. These values are appended to the end + // of the QuicheDataWriter payload. + + // Writes 8/16/32/64-bit unsigned integers. + bool WriteUInt8(uint8_t value); + bool WriteUInt16(uint16_t value); + bool WriteUInt32(uint32_t value); + bool WriteUInt64(uint64_t value); + + // Writes least significant |num_bytes| of a 64-bit unsigned integer in the + // correct byte order. + bool WriteBytesToUInt64(size_t num_bytes, uint64_t value); + + bool WriteStringPiece(absl::string_view val); + bool WriteStringPiece16(absl::string_view val); + bool WriteBytes(const void* data, size_t data_len); + bool WriteRepeatedByte(uint8_t byte, size_t count); + // Fills the remaining buffer with null characters. + void WritePadding(); + // Write padding of |count| bytes. + bool WritePaddingBytes(size_t count); + + // Write tag as a 32-bit unsigned integer to the payload. As tags are already + // converted to big endian (e.g., CHLO is 'C','H','L','O') in memory by TAG or + // MakeQuicTag and tags are written in byte order, so tags on the wire are + // in big endian. + bool WriteTag(uint32_t tag); + + // Write a 62-bit unsigned integer using RFC 9000 Variable Length Integer + // encoding. Returns false if the value is out of range or if there is no room + // in the buffer. + bool WriteVarInt62(uint64_t value); + + // Same as WriteVarInt62(uint64_t), but forces an encoding size to write to. + // This is not as optimized as WriteVarInt62(uint64_t). Returns false if the + // value does not fit in the specified write_length or if there is no room in + // the buffer. + bool WriteVarInt62WithForcedLength( + uint64_t value, QuicheVariableLengthIntegerLength write_length); + + // Writes a string piece as a consecutive length/content pair. The + // length uses RFC 9000 Variable Length Integer encoding. + bool WriteStringPieceVarInt62(const absl::string_view& string_piece); + + // Utility function to return the number of bytes needed to encode + // the given value using IETF VarInt62 encoding. Returns the number + // of bytes required to encode the given integer or 0 if the value + // is too large to encode. + static QuicheVariableLengthIntegerLength GetVarInt62Len(uint64_t value); + + // Advance the writer's position for writing by |length| bytes without writing + // anything. This method only makes sense to be used on a buffer that has + // already been written to (and is having certain parts rewritten). + bool Seek(size_t length); + + size_t capacity() const { return capacity_; } + + size_t remaining() const { return capacity_ - length_; } + + std::string DebugString() const; + + protected: + // Returns the location that the data should be written at, or nullptr if + // there is not enough room. Call EndWrite with the returned offset and the + // given length to pad out for the next write. + char* BeginWrite(size_t length); + + quiche::Endianness endianness() const { return endianness_; } + + char* buffer() const { return buffer_; } + + void IncreaseLength(size_t delta) { + QUICHE_DCHECK_LE(length_, std::numeric_limits::max() - delta); + QUICHE_DCHECK_LE(length_, capacity_ - delta); + length_ += delta; + } + + private: + // TODO(fkastenholz, b/73004262) change buffer_, et al, to be uint8_t, not + // char. + char* buffer_; + size_t capacity_; // Allocation size of payload (or -1 if buffer is const). + size_t length_; // Current length of the buffer. + + // The endianness to write integers and floating numbers. + quiche::Endianness endianness_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_DATA_WRITER_H_ diff --git a/quiche/common/quiche_data_writer_test.cc b/quiche/common/quiche_data_writer_test.cc new file mode 100644 index 000000000000..eb143502b141 --- /dev/null +++ b/quiche/common/quiche_data_writer_test.cc @@ -0,0 +1,835 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_data_writer.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche { +namespace test { +namespace { + +char* AsChars(unsigned char* data) { return reinterpret_cast(data); } + +struct TestParams { + explicit TestParams(quiche::Endianness endianness) : endianness(endianness) {} + + quiche::Endianness endianness; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + (p.endianness == quiche::NETWORK_BYTE_ORDER ? "Network" : "Host"), + "ByteOrder"); +} + +std::vector GetTestParams() { + std::vector params; + for (quiche::Endianness endianness : + {quiche::NETWORK_BYTE_ORDER, quiche::HOST_BYTE_ORDER}) { + params.push_back(TestParams(endianness)); + } + return params; +} + +class QuicheDataWriterTest : public QuicheTestWithParam {}; + +INSTANTIATE_TEST_SUITE_P(QuicheDataWriterTests, QuicheDataWriterTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicheDataWriterTest, Write16BitUnsignedIntegers) { + char little_endian16[] = {0x22, 0x11}; + char big_endian16[] = {0x11, 0x22}; + char buffer16[2]; + { + uint16_t in_memory16 = 0x1122; + QuicheDataWriter writer(2, buffer16, GetParam().endianness); + writer.WriteUInt16(in_memory16); + test::CompareCharArraysWithHexError( + "uint16_t", buffer16, 2, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian16 + : little_endian16, + 2); + + uint16_t read_number16; + QuicheDataReader reader(buffer16, 2, GetParam().endianness); + reader.ReadUInt16(&read_number16); + EXPECT_EQ(in_memory16, read_number16); + } + + { + uint64_t in_memory16 = 0x0000000000001122; + QuicheDataWriter writer(2, buffer16, GetParam().endianness); + writer.WriteBytesToUInt64(2, in_memory16); + test::CompareCharArraysWithHexError( + "uint16_t", buffer16, 2, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian16 + : little_endian16, + 2); + + uint64_t read_number16; + QuicheDataReader reader(buffer16, 2, GetParam().endianness); + reader.ReadBytesToUInt64(2, &read_number16); + EXPECT_EQ(in_memory16, read_number16); + } +} + +TEST_P(QuicheDataWriterTest, Write24BitUnsignedIntegers) { + char little_endian24[] = {0x33, 0x22, 0x11}; + char big_endian24[] = {0x11, 0x22, 0x33}; + char buffer24[3]; + uint64_t in_memory24 = 0x0000000000112233; + QuicheDataWriter writer(3, buffer24, GetParam().endianness); + writer.WriteBytesToUInt64(3, in_memory24); + test::CompareCharArraysWithHexError( + "uint24", buffer24, 3, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian24 + : little_endian24, + 3); + + uint64_t read_number24; + QuicheDataReader reader(buffer24, 3, GetParam().endianness); + reader.ReadBytesToUInt64(3, &read_number24); + EXPECT_EQ(in_memory24, read_number24); +} + +TEST_P(QuicheDataWriterTest, Write32BitUnsignedIntegers) { + char little_endian32[] = {0x44, 0x33, 0x22, 0x11}; + char big_endian32[] = {0x11, 0x22, 0x33, 0x44}; + char buffer32[4]; + { + uint32_t in_memory32 = 0x11223344; + QuicheDataWriter writer(4, buffer32, GetParam().endianness); + writer.WriteUInt32(in_memory32); + test::CompareCharArraysWithHexError( + "uint32_t", buffer32, 4, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian32 + : little_endian32, + 4); + + uint32_t read_number32; + QuicheDataReader reader(buffer32, 4, GetParam().endianness); + reader.ReadUInt32(&read_number32); + EXPECT_EQ(in_memory32, read_number32); + } + + { + uint64_t in_memory32 = 0x11223344; + QuicheDataWriter writer(4, buffer32, GetParam().endianness); + writer.WriteBytesToUInt64(4, in_memory32); + test::CompareCharArraysWithHexError( + "uint32_t", buffer32, 4, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian32 + : little_endian32, + 4); + + uint64_t read_number32; + QuicheDataReader reader(buffer32, 4, GetParam().endianness); + reader.ReadBytesToUInt64(4, &read_number32); + EXPECT_EQ(in_memory32, read_number32); + } +} + +TEST_P(QuicheDataWriterTest, Write40BitUnsignedIntegers) { + uint64_t in_memory40 = 0x0000001122334455; + char little_endian40[] = {0x55, 0x44, 0x33, 0x22, 0x11}; + char big_endian40[] = {0x11, 0x22, 0x33, 0x44, 0x55}; + char buffer40[5]; + QuicheDataWriter writer(5, buffer40, GetParam().endianness); + writer.WriteBytesToUInt64(5, in_memory40); + test::CompareCharArraysWithHexError( + "uint40", buffer40, 5, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian40 + : little_endian40, + 5); + + uint64_t read_number40; + QuicheDataReader reader(buffer40, 5, GetParam().endianness); + reader.ReadBytesToUInt64(5, &read_number40); + EXPECT_EQ(in_memory40, read_number40); +} + +TEST_P(QuicheDataWriterTest, Write48BitUnsignedIntegers) { + uint64_t in_memory48 = 0x0000112233445566; + char little_endian48[] = {0x66, 0x55, 0x44, 0x33, 0x22, 0x11}; + char big_endian48[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}; + char buffer48[6]; + QuicheDataWriter writer(6, buffer48, GetParam().endianness); + writer.WriteBytesToUInt64(6, in_memory48); + test::CompareCharArraysWithHexError( + "uint48", buffer48, 6, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian48 + : little_endian48, + 6); + + uint64_t read_number48; + QuicheDataReader reader(buffer48, 6, GetParam().endianness); + reader.ReadBytesToUInt64(6., &read_number48); + EXPECT_EQ(in_memory48, read_number48); +} + +TEST_P(QuicheDataWriterTest, Write56BitUnsignedIntegers) { + uint64_t in_memory56 = 0x0011223344556677; + char little_endian56[] = {0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11}; + char big_endian56[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; + char buffer56[7]; + QuicheDataWriter writer(7, buffer56, GetParam().endianness); + writer.WriteBytesToUInt64(7, in_memory56); + test::CompareCharArraysWithHexError( + "uint56", buffer56, 7, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian56 + : little_endian56, + 7); + + uint64_t read_number56; + QuicheDataReader reader(buffer56, 7, GetParam().endianness); + reader.ReadBytesToUInt64(7, &read_number56); + EXPECT_EQ(in_memory56, read_number56); +} + +TEST_P(QuicheDataWriterTest, Write64BitUnsignedIntegers) { + uint64_t in_memory64 = 0x1122334455667788; + unsigned char little_endian64[] = {0x88, 0x77, 0x66, 0x55, + 0x44, 0x33, 0x22, 0x11}; + unsigned char big_endian64[] = {0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88}; + char buffer64[8]; + QuicheDataWriter writer(8, buffer64, GetParam().endianness); + writer.WriteBytesToUInt64(8, in_memory64); + test::CompareCharArraysWithHexError( + "uint64_t", buffer64, 8, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER + ? AsChars(big_endian64) + : AsChars(little_endian64), + 8); + + uint64_t read_number64; + QuicheDataReader reader(buffer64, 8, GetParam().endianness); + reader.ReadBytesToUInt64(8, &read_number64); + EXPECT_EQ(in_memory64, read_number64); + + QuicheDataWriter writer2(8, buffer64, GetParam().endianness); + writer2.WriteUInt64(in_memory64); + test::CompareCharArraysWithHexError( + "uint64_t", buffer64, 8, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER + ? AsChars(big_endian64) + : AsChars(little_endian64), + 8); + read_number64 = 0u; + QuicheDataReader reader2(buffer64, 8, GetParam().endianness); + reader2.ReadUInt64(&read_number64); + EXPECT_EQ(in_memory64, read_number64); +} + +TEST_P(QuicheDataWriterTest, WriteIntegers) { + char buf[43]; + uint8_t i8 = 0x01; + uint16_t i16 = 0x0123; + uint32_t i32 = 0x01234567; + uint64_t i64 = 0x0123456789ABCDEF; + QuicheDataWriter writer(46, buf, GetParam().endianness); + for (size_t i = 0; i < 10; ++i) { + switch (i) { + case 0u: + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 1u: + EXPECT_TRUE(writer.WriteUInt8(i8)); + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 2u: + EXPECT_TRUE(writer.WriteUInt16(i16)); + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 3u: + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 4u: + EXPECT_TRUE(writer.WriteUInt32(i32)); + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 5u: + case 6u: + case 7u: + case 8u: + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + default: + EXPECT_FALSE(writer.WriteBytesToUInt64(i, i64)); + } + } + + QuicheDataReader reader(buf, 46, GetParam().endianness); + for (size_t i = 0; i < 10; ++i) { + uint8_t read8; + uint16_t read16; + uint32_t read32; + uint64_t read64; + switch (i) { + case 0u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0u, read64); + break; + case 1u: + EXPECT_TRUE(reader.ReadUInt8(&read8)); + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(i8, read8); + EXPECT_EQ(0xEFu, read64); + break; + case 2u: + EXPECT_TRUE(reader.ReadUInt16(&read16)); + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(i16, read16); + EXPECT_EQ(0xCDEFu, read64); + break; + case 3u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0xABCDEFu, read64); + break; + case 4u: + EXPECT_TRUE(reader.ReadUInt32(&read32)); + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(i32, read32); + EXPECT_EQ(0x89ABCDEFu, read64); + break; + case 5u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x6789ABCDEFu, read64); + break; + case 6u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x456789ABCDEFu, read64); + break; + case 7u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x23456789ABCDEFu, read64); + break; + case 8u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x0123456789ABCDEFu, read64); + break; + default: + EXPECT_FALSE(reader.ReadBytesToUInt64(i, &read64)); + } + } +} + +TEST_P(QuicheDataWriterTest, WriteBytes) { + char bytes[] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + char buf[ABSL_ARRAYSIZE(bytes)]; + QuicheDataWriter writer(ABSL_ARRAYSIZE(buf), buf, GetParam().endianness); + EXPECT_TRUE(writer.WriteBytes(bytes, ABSL_ARRAYSIZE(bytes))); + for (unsigned int i = 0; i < ABSL_ARRAYSIZE(bytes); ++i) { + EXPECT_EQ(bytes[i], buf[i]); + } +} + +const int kVarIntBufferLength = 1024; + +// Encodes and then decodes a specified value, checks that the +// value that was encoded is the same as the decoded value, the length +// is correct, and that after decoding, all data in the buffer has +// been consumed.. +// Returns true if everything works, false if not. +bool EncodeDecodeValue(uint64_t value_in, char* buffer, size_t size_of_buffer) { + // Init the buffer to all 0, just for cleanliness. Makes for better + // output if, in debugging, we need to dump out the buffer. + memset(buffer, 0, size_of_buffer); + // make a writer. Note that for IETF encoding + // we do not care about endianness... It's always big-endian, + // but the c'tor expects to be told what endianness is in force... + QuicheDataWriter writer(size_of_buffer, buffer, + quiche::Endianness::NETWORK_BYTE_ORDER); + + // Try to write the value. + if (writer.WriteVarInt62(value_in) != true) { + return false; + } + // Look at the value we encoded. Determine how much should have been + // used based on the value, and then check the state of the writer + // to see that it matches. + size_t expected_length = 0; + if (value_in <= 0x3f) { + expected_length = 1; + } else if (value_in <= 0x3fff) { + expected_length = 2; + } else if (value_in <= 0x3fffffff) { + expected_length = 4; + } else { + expected_length = 8; + } + if (writer.length() != expected_length) { + return false; + } + + // set up a reader, just the length we've used, no more, no less. + QuicheDataReader reader(buffer, expected_length, + quiche::Endianness::NETWORK_BYTE_ORDER); + uint64_t value_out; + + if (reader.ReadVarInt62(&value_out) == false) { + return false; + } + if (value_in != value_out) { + return false; + } + // We only write one value so there had better be nothing left to read + return reader.IsDoneReading(); +} + +// Test that 8-byte-encoded Variable Length Integers are properly laid +// out in the buffer. +TEST_P(QuicheDataWriterTest, VarInt8Layout) { + char buffer[1024]; + + // Check that the layout of bytes in the buffer is correct. Bytes + // are always encoded big endian... + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer.WriteVarInt62(UINT64_C(0x3142f3e4d5c6b7a8))); + EXPECT_EQ(static_cast(*(writer.data() + 0)), + (0x31 + 0xc0)); // 0xc0 for encoding + EXPECT_EQ(static_cast(*(writer.data() + 1)), 0x42); + EXPECT_EQ(static_cast(*(writer.data() + 2)), 0xf3); + EXPECT_EQ(static_cast(*(writer.data() + 3)), 0xe4); + EXPECT_EQ(static_cast(*(writer.data() + 4)), 0xd5); + EXPECT_EQ(static_cast(*(writer.data() + 5)), 0xc6); + EXPECT_EQ(static_cast(*(writer.data() + 6)), 0xb7); + EXPECT_EQ(static_cast(*(writer.data() + 7)), 0xa8); +} + +// Test that 4-byte-encoded Variable Length Integers are properly laid +// out in the buffer. +TEST_P(QuicheDataWriterTest, VarInt4Layout) { + char buffer[1024]; + + // Check that the layout of bytes in the buffer is correct. Bytes + // are always encoded big endian... + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer.WriteVarInt62(0x3243f4e5)); + EXPECT_EQ(static_cast(*(writer.data() + 0)), + (0x32 + 0x80)); // 0x80 for encoding + EXPECT_EQ(static_cast(*(writer.data() + 1)), 0x43); + EXPECT_EQ(static_cast(*(writer.data() + 2)), 0xf4); + EXPECT_EQ(static_cast(*(writer.data() + 3)), 0xe5); +} + +// Test that 2-byte-encoded Variable Length Integers are properly laid +// out in the buffer. +TEST_P(QuicheDataWriterTest, VarInt2Layout) { + char buffer[1024]; + + // Check that the layout of bytes in the buffer is correct. Bytes + // are always encoded big endian... + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer.WriteVarInt62(0x3647)); + EXPECT_EQ(static_cast(*(writer.data() + 0)), + (0x36 + 0x40)); // 0x40 for encoding + EXPECT_EQ(static_cast(*(writer.data() + 1)), 0x47); +} + +// Test that 1-byte-encoded Variable Length Integers are properly laid +// out in the buffer. +TEST_P(QuicheDataWriterTest, VarInt1Layout) { + char buffer[1024]; + + // Check that the layout of bytes in the buffer + // is correct. Bytes are always encoded big endian... + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer.WriteVarInt62(0x3f)); + EXPECT_EQ(static_cast(*(writer.data() + 0)), 0x3f); +} + +// Test certain, targeted, values that are expected to succeed: +// 0, 1, +// 0x3e, 0x3f, 0x40, 0x41 (around the 1-2 byte transitions) +// 0x3ffe, 0x3fff, 0x4000, 0x4001 (the 2-4 byte transition) +// 0x3ffffffe, 0x3fffffff, 0x40000000, 0x40000001 (the 4-8 byte +// transition) +// 0x3ffffffffffffffe, 0x3fffffffffffffff, (the highest valid values) +// 0xfe, 0xff, 0x100, 0x101, +// 0xfffe, 0xffff, 0x10000, 0x10001, +// 0xfffffe, 0xffffff, 0x1000000, 0x1000001, +// 0xfffffffe, 0xffffffff, 0x100000000, 0x100000001, +// 0xfffffffffe, 0xffffffffff, 0x10000000000, 0x10000000001, +// 0xfffffffffffe, 0xffffffffffff, 0x1000000000000, 0x1000000000001, +// 0xfffffffffffffe, 0xffffffffffffff, 0x100000000000000, 0x100000000000001, +TEST_P(QuicheDataWriterTest, VarIntGoodTargetedValues) { + char buffer[kVarIntBufferLength]; + uint64_t passing_values[] = { + 0, + 1, + 0x3e, + 0x3f, + 0x40, + 0x41, + 0x3ffe, + 0x3fff, + 0x4000, + 0x4001, + 0x3ffffffe, + 0x3fffffff, + 0x40000000, + 0x40000001, + 0x3ffffffffffffffe, + 0x3fffffffffffffff, + 0xfe, + 0xff, + 0x100, + 0x101, + 0xfffe, + 0xffff, + 0x10000, + 0x10001, + 0xfffffe, + 0xffffff, + 0x1000000, + 0x1000001, + 0xfffffffe, + 0xffffffff, + 0x100000000, + 0x100000001, + 0xfffffffffe, + 0xffffffffff, + 0x10000000000, + 0x10000000001, + 0xfffffffffffe, + 0xffffffffffff, + 0x1000000000000, + 0x1000000000001, + 0xfffffffffffffe, + 0xffffffffffffff, + 0x100000000000000, + 0x100000000000001, + }; + for (uint64_t test_val : passing_values) { + EXPECT_TRUE( + EncodeDecodeValue(test_val, static_cast(buffer), sizeof(buffer))) + << " encode/decode of " << test_val << " failed"; + } +} +// +// Test certain, targeted, values where failure is expected (the +// values are invalid w.r.t. IETF VarInt encoding): +// 0x4000000000000000, 0x4000000000000001, ( Just above max allowed value) +// 0xfffffffffffffffe, 0xffffffffffffffff, (should fail) +TEST_P(QuicheDataWriterTest, VarIntBadTargetedValues) { + char buffer[kVarIntBufferLength]; + uint64_t failing_values[] = { + 0x4000000000000000, + 0x4000000000000001, + 0xfffffffffffffffe, + 0xffffffffffffffff, + }; + for (uint64_t test_val : failing_values) { + EXPECT_FALSE( + EncodeDecodeValue(test_val, static_cast(buffer), sizeof(buffer))) + << " encode/decode of " << test_val << " succeeded, but was an " + << "invalid value"; + } +} +// Test writing varints with a forced length. +TEST_P(QuicheDataWriterTest, WriteVarInt62WithForcedLength) { + char buffer[90]; + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer)); + + writer.WriteVarInt62WithForcedLength(1, VARIABLE_LENGTH_INTEGER_LENGTH_1); + writer.WriteVarInt62WithForcedLength(1, VARIABLE_LENGTH_INTEGER_LENGTH_2); + writer.WriteVarInt62WithForcedLength(1, VARIABLE_LENGTH_INTEGER_LENGTH_4); + writer.WriteVarInt62WithForcedLength(1, VARIABLE_LENGTH_INTEGER_LENGTH_8); + + writer.WriteVarInt62WithForcedLength(63, VARIABLE_LENGTH_INTEGER_LENGTH_1); + writer.WriteVarInt62WithForcedLength(63, VARIABLE_LENGTH_INTEGER_LENGTH_2); + writer.WriteVarInt62WithForcedLength(63, VARIABLE_LENGTH_INTEGER_LENGTH_4); + writer.WriteVarInt62WithForcedLength(63, VARIABLE_LENGTH_INTEGER_LENGTH_8); + + writer.WriteVarInt62WithForcedLength(64, VARIABLE_LENGTH_INTEGER_LENGTH_2); + writer.WriteVarInt62WithForcedLength(64, VARIABLE_LENGTH_INTEGER_LENGTH_4); + writer.WriteVarInt62WithForcedLength(64, VARIABLE_LENGTH_INTEGER_LENGTH_8); + + writer.WriteVarInt62WithForcedLength(16383, VARIABLE_LENGTH_INTEGER_LENGTH_2); + writer.WriteVarInt62WithForcedLength(16383, VARIABLE_LENGTH_INTEGER_LENGTH_4); + writer.WriteVarInt62WithForcedLength(16383, VARIABLE_LENGTH_INTEGER_LENGTH_8); + + writer.WriteVarInt62WithForcedLength(16384, VARIABLE_LENGTH_INTEGER_LENGTH_4); + writer.WriteVarInt62WithForcedLength(16384, VARIABLE_LENGTH_INTEGER_LENGTH_8); + + writer.WriteVarInt62WithForcedLength(1073741823, + VARIABLE_LENGTH_INTEGER_LENGTH_4); + writer.WriteVarInt62WithForcedLength(1073741823, + VARIABLE_LENGTH_INTEGER_LENGTH_8); + + writer.WriteVarInt62WithForcedLength(1073741824, + VARIABLE_LENGTH_INTEGER_LENGTH_8); + + QuicheDataReader reader(buffer, sizeof(buffer)); + + uint64_t test_val = 0; + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 1u); + } + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 63u); + } + + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 64u); + } + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 16383u); + } + + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 16384u); + } + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 1073741823u); + } + + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, 1073741824u); + + // We are at the end of the buffer so this should fail. + EXPECT_FALSE(reader.ReadVarInt62(&test_val)); +} + +// Following tests all try to fill the buffer with multiple values, +// go one value more than the buffer can accommodate, then read +// the successfully encoded values, and try to read the unsuccessfully +// encoded value. The following is the number of values to encode. +const int kMultiVarCount = 1000; + +// Test writing & reading multiple 8-byte-encoded varints +TEST_P(QuicheDataWriterTest, MultiVarInt8) { + uint64_t test_val; + char buffer[8 * kMultiVarCount]; + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + // Put N values into the buffer. Adding i to the value ensures that + // each value is different so we can detect if we overwrite values, + // or read the same value over and over. + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(writer.WriteVarInt62(UINT64_C(0x3142f3e4d5c6b7a8) + i)); + } + EXPECT_EQ(writer.length(), 8u * kMultiVarCount); + + // N+1st should fail, the buffer is full. + EXPECT_FALSE(writer.WriteVarInt62(UINT64_C(0x3142f3e4d5c6b7a8))); + + // Now we should be able to read out the N values that were + // successfully encoded. + QuicheDataReader reader(buffer, sizeof(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, (UINT64_C(0x3142f3e4d5c6b7a8) + i)); + } + // And the N+1st should fail. + EXPECT_FALSE(reader.ReadVarInt62(&test_val)); +} + +// Test writing & reading multiple 4-byte-encoded varints +TEST_P(QuicheDataWriterTest, MultiVarInt4) { + uint64_t test_val; + char buffer[4 * kMultiVarCount]; + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + // Put N values into the buffer. Adding i to the value ensures that + // each value is different so we can detect if we overwrite values, + // or read the same value over and over. + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(writer.WriteVarInt62(UINT64_C(0x3142f3e4) + i)); + } + EXPECT_EQ(writer.length(), 4u * kMultiVarCount); + + // N+1st should fail, the buffer is full. + EXPECT_FALSE(writer.WriteVarInt62(UINT64_C(0x3142f3e4))); + + // Now we should be able to read out the N values that were + // successfully encoded. + QuicheDataReader reader(buffer, sizeof(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, (UINT64_C(0x3142f3e4) + i)); + } + // And the N+1st should fail. + EXPECT_FALSE(reader.ReadVarInt62(&test_val)); +} + +// Test writing & reading multiple 2-byte-encoded varints +TEST_P(QuicheDataWriterTest, MultiVarInt2) { + uint64_t test_val; + char buffer[2 * kMultiVarCount]; + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + // Put N values into the buffer. Adding i to the value ensures that + // each value is different so we can detect if we overwrite values, + // or read the same value over and over. + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(writer.WriteVarInt62(UINT64_C(0x3142) + i)); + } + EXPECT_EQ(writer.length(), 2u * kMultiVarCount); + + // N+1st should fail, the buffer is full. + EXPECT_FALSE(writer.WriteVarInt62(UINT64_C(0x3142))); + + // Now we should be able to read out the N values that were + // successfully encoded. + QuicheDataReader reader(buffer, sizeof(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, (UINT64_C(0x3142) + i)); + } + // And the N+1st should fail. + EXPECT_FALSE(reader.ReadVarInt62(&test_val)); +} + +// Test writing & reading multiple 1-byte-encoded varints +TEST_P(QuicheDataWriterTest, MultiVarInt1) { + uint64_t test_val; + char buffer[1 * kMultiVarCount]; + memset(buffer, 0, sizeof(buffer)); + QuicheDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + // Put N values into the buffer. Adding i to the value ensures that + // each value is different so we can detect if we overwrite values, + // or read the same value over and over. &0xf ensures we do not + // overflow the max value for single-byte encoding. + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(writer.WriteVarInt62(UINT64_C(0x30) + (i & 0xf))); + } + EXPECT_EQ(writer.length(), 1u * kMultiVarCount); + + // N+1st should fail, the buffer is full. + EXPECT_FALSE(writer.WriteVarInt62(UINT64_C(0x31))); + + // Now we should be able to read out the N values that were + // successfully encoded. + QuicheDataReader reader(buffer, sizeof(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + for (int i = 0; i < kMultiVarCount; i++) { + EXPECT_TRUE(reader.ReadVarInt62(&test_val)); + EXPECT_EQ(test_val, (UINT64_C(0x30) + (i & 0xf))); + } + // And the N+1st should fail. + EXPECT_FALSE(reader.ReadVarInt62(&test_val)); +} + +TEST_P(QuicheDataWriterTest, Seek) { + char buffer[3] = {}; + QuicheDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.WriteUInt8(42)); + EXPECT_TRUE(writer.Seek(1)); + EXPECT_TRUE(writer.WriteUInt8(3)); + + char expected[] = {42, 0, 3}; + for (size_t i = 0; i < ABSL_ARRAYSIZE(expected); ++i) { + EXPECT_EQ(buffer[i], expected[i]); + } +} + +TEST_P(QuicheDataWriterTest, SeekTooFarFails) { + char buffer[20]; + + // Check that one can seek to the end of the writer, but not past. + { + QuicheDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.Seek(20)); + EXPECT_FALSE(writer.Seek(1)); + } + + // Seeking several bytes past the end fails. + { + QuicheDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_FALSE(writer.Seek(100)); + } + + // Seeking so far that arithmetic overflow could occur also fails. + { + QuicheDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.Seek(10)); + EXPECT_FALSE(writer.Seek(std::numeric_limits::max())); + } +} + +TEST_P(QuicheDataWriterTest, PayloadReads) { + char buffer[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + char expected_first_read[4] = {1, 2, 3, 4}; + char expected_remaining[12] = {5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + QuicheDataReader reader(buffer, sizeof(buffer)); + absl::string_view previously_read_payload1 = reader.PreviouslyReadPayload(); + EXPECT_TRUE(previously_read_payload1.empty()); + char first_read_buffer[4] = {}; + EXPECT_TRUE(reader.ReadBytes(first_read_buffer, sizeof(first_read_buffer))); + test::CompareCharArraysWithHexError( + "first read", first_read_buffer, sizeof(first_read_buffer), + expected_first_read, sizeof(expected_first_read)); + absl::string_view peeked_remaining_payload = reader.PeekRemainingPayload(); + test::CompareCharArraysWithHexError( + "peeked_remaining_payload", peeked_remaining_payload.data(), + peeked_remaining_payload.length(), expected_remaining, + sizeof(expected_remaining)); + absl::string_view full_payload = reader.FullPayload(); + test::CompareCharArraysWithHexError("full_payload", full_payload.data(), + full_payload.length(), buffer, + sizeof(buffer)); + absl::string_view previously_read_payload2 = reader.PreviouslyReadPayload(); + test::CompareCharArraysWithHexError( + "previously_read_payload2", previously_read_payload2.data(), + previously_read_payload2.length(), first_read_buffer, + sizeof(first_read_buffer)); + absl::string_view read_remaining_payload = reader.ReadRemainingPayload(); + test::CompareCharArraysWithHexError( + "read_remaining_payload", read_remaining_payload.data(), + read_remaining_payload.length(), expected_remaining, + sizeof(expected_remaining)); + EXPECT_TRUE(reader.IsDoneReading()); + absl::string_view full_payload2 = reader.FullPayload(); + test::CompareCharArraysWithHexError("full_payload2", full_payload2.data(), + full_payload2.length(), buffer, + sizeof(buffer)); + absl::string_view previously_read_payload3 = reader.PreviouslyReadPayload(); + test::CompareCharArraysWithHexError( + "previously_read_payload3", previously_read_payload3.data(), + previously_read_payload3.length(), buffer, sizeof(buffer)); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_endian.h b/quiche/common/quiche_endian.h new file mode 100644 index 000000000000..2aaa47831771 --- /dev/null +++ b/quiche/common/quiche_endian.h @@ -0,0 +1,73 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_ENDIAN_H_ +#define QUICHE_COMMON_QUICHE_ENDIAN_H_ + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +enum Endianness { + NETWORK_BYTE_ORDER, // big endian + HOST_BYTE_ORDER // little endian +}; + +// Provide utility functions that convert from/to network order (big endian) +// to/from host order (little endian). +class QUICHE_EXPORT QuicheEndian { + public: + // Convert |x| from host order (little endian) to network order (big endian). +#if defined(__clang__) || \ + (defined(__GNUC__) && \ + ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ >= 5)) + static uint16_t HostToNet16(uint16_t x) { return __builtin_bswap16(x); } + static uint32_t HostToNet32(uint32_t x) { return __builtin_bswap32(x); } + static uint64_t HostToNet64(uint64_t x) { return __builtin_bswap64(x); } +#else + static uint16_t HostToNet16(uint16_t x) { return PortableByteSwap(x); } + static uint32_t HostToNet32(uint32_t x) { return PortableByteSwap(x); } + static uint64_t HostToNet64(uint64_t x) { return PortableByteSwap(x); } +#endif + + // Convert |x| from network order (big endian) to host order (little endian). + static uint16_t NetToHost16(uint16_t x) { return HostToNet16(x); } + static uint32_t NetToHost32(uint32_t x) { return HostToNet32(x); } + static uint64_t NetToHost64(uint64_t x) { return HostToNet64(x); } + + // Left public for tests. + template + static T PortableByteSwap(T input) { + static_assert(std::is_unsigned::value, "T has to be uintNN_t"); + union { + T number; + char bytes[sizeof(T)]; + } value; + value.number = input; + std::reverse(&value.bytes[0], &value.bytes[sizeof(T)]); + return value.number; + } +}; + +enum QuicheVariableLengthIntegerLength : uint8_t { + // Length zero means the variable length integer is not present. + VARIABLE_LENGTH_INTEGER_LENGTH_0 = 0, + VARIABLE_LENGTH_INTEGER_LENGTH_1 = 1, + VARIABLE_LENGTH_INTEGER_LENGTH_2 = 2, + VARIABLE_LENGTH_INTEGER_LENGTH_4 = 4, + VARIABLE_LENGTH_INTEGER_LENGTH_8 = 8, + + // By default we write the IETF long header length using the 2-byte encoding + // of variable length integers, even when the length is below 64, which allows + // us to fill in the length before knowing what the length actually is. + kQuicheDefaultLongHeaderLengthLength = VARIABLE_LENGTH_INTEGER_LENGTH_2, +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_ENDIAN_H_ diff --git a/quiche/common/quiche_endian_test.cc b/quiche/common/quiche_endian_test.cc new file mode 100644 index 000000000000..66527a9a9fb6 --- /dev/null +++ b/quiche/common/quiche_endian_test.cc @@ -0,0 +1,53 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_endian.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { +namespace { + +const uint16_t k16BitTestData = 0xaabb; +const uint16_t k16BitSwappedTestData = 0xbbaa; +const uint32_t k32BitTestData = 0xaabbccdd; +const uint32_t k32BitSwappedTestData = 0xddccbbaa; +const uint64_t k64BitTestData = 0xaabbccdd44332211; +const uint64_t k64BitSwappedTestData = 0x11223344ddccbbaa; + +class QuicheEndianTest : public QuicheTest {}; + +// Test portable version. Since we normally compile with either GCC or Clang, +// it will very rarely used otherwise. +TEST_F(QuicheEndianTest, Portable) { + EXPECT_EQ(k16BitSwappedTestData, + QuicheEndian::PortableByteSwap(k16BitTestData)); + EXPECT_EQ(k32BitSwappedTestData, + QuicheEndian::PortableByteSwap(k32BitTestData)); + EXPECT_EQ(k64BitSwappedTestData, + QuicheEndian::PortableByteSwap(k64BitTestData)); +} + +TEST_F(QuicheEndianTest, HostToNet) { + EXPECT_EQ(k16BitSwappedTestData, + quiche::QuicheEndian::HostToNet16(k16BitTestData)); + EXPECT_EQ(k32BitSwappedTestData, + quiche::QuicheEndian::HostToNet32(k32BitTestData)); + EXPECT_EQ(k64BitSwappedTestData, + quiche::QuicheEndian::HostToNet64(k64BitTestData)); +} + +TEST_F(QuicheEndianTest, NetToHost) { + EXPECT_EQ(k16BitTestData, + quiche::QuicheEndian::NetToHost16(k16BitSwappedTestData)); + EXPECT_EQ(k32BitTestData, + quiche::QuicheEndian::NetToHost32(k32BitSwappedTestData)); + EXPECT_EQ(k64BitTestData, + quiche::QuicheEndian::NetToHost64(k64BitSwappedTestData)); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_ip_address.cc b/quiche/common/quiche_ip_address.cc new file mode 100644 index 000000000000..342597bfab31 --- /dev/null +++ b/quiche/common/quiche_ip_address.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_ip_address.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_ip_address_family.h" + +namespace quiche { + +QuicheIpAddress QuicheIpAddress::Loopback4() { + QuicheIpAddress result; + result.family_ = IpAddressFamily::IP_V4; + result.address_.bytes[0] = 127; + result.address_.bytes[1] = 0; + result.address_.bytes[2] = 0; + result.address_.bytes[3] = 1; + return result; +} + +QuicheIpAddress QuicheIpAddress::Loopback6() { + QuicheIpAddress result; + result.family_ = IpAddressFamily::IP_V6; + uint8_t* bytes = result.address_.bytes; + memset(bytes, 0, 15); + bytes[15] = 1; + return result; +} + +QuicheIpAddress QuicheIpAddress::Any4() { + in_addr address; + memset(&address, 0, sizeof(address)); + return QuicheIpAddress(address); +} + +QuicheIpAddress QuicheIpAddress::Any6() { + in6_addr address; + memset(&address, 0, sizeof(address)); + return QuicheIpAddress(address); +} + +QuicheIpAddress::QuicheIpAddress() : family_(IpAddressFamily::IP_UNSPEC) {} + +QuicheIpAddress::QuicheIpAddress(const in_addr& ipv4_address) + : family_(IpAddressFamily::IP_V4) { + address_.v4 = ipv4_address; +} +QuicheIpAddress::QuicheIpAddress(const in6_addr& ipv6_address) + : family_(IpAddressFamily::IP_V6) { + address_.v6 = ipv6_address; +} + +bool operator==(QuicheIpAddress lhs, QuicheIpAddress rhs) { + if (lhs.family_ != rhs.family_) { + return false; + } + switch (lhs.family_) { + case IpAddressFamily::IP_V4: + return std::equal(lhs.address_.bytes, + lhs.address_.bytes + QuicheIpAddress::kIPv4AddressSize, + rhs.address_.bytes); + case IpAddressFamily::IP_V6: + return std::equal(lhs.address_.bytes, + lhs.address_.bytes + QuicheIpAddress::kIPv6AddressSize, + rhs.address_.bytes); + case IpAddressFamily::IP_UNSPEC: + return true; + } + QUICHE_BUG(quiche_bug_10126_2) + << "Invalid IpAddressFamily " << static_cast(lhs.family_); + return false; +} + +bool operator!=(QuicheIpAddress lhs, QuicheIpAddress rhs) { + return !(lhs == rhs); +} + +bool QuicheIpAddress::IsInitialized() const { + return family_ != IpAddressFamily::IP_UNSPEC; +} + +IpAddressFamily QuicheIpAddress::address_family() const { return family_; } + +int QuicheIpAddress::AddressFamilyToInt() const { + return ToPlatformAddressFamily(family_); +} + +std::string QuicheIpAddress::ToPackedString() const { + switch (family_) { + case IpAddressFamily::IP_V4: + return std::string(address_.chars, sizeof(address_.v4)); + case IpAddressFamily::IP_V6: + return std::string(address_.chars, sizeof(address_.v6)); + case IpAddressFamily::IP_UNSPEC: + return ""; + } + QUICHE_BUG(quiche_bug_10126_3) + << "Invalid IpAddressFamily " << static_cast(family_); + return ""; +} + +std::string QuicheIpAddress::ToString() const { + if (!IsInitialized()) { + return ""; + } + + char buffer[INET6_ADDRSTRLEN] = {0}; + const char* result = + inet_ntop(AddressFamilyToInt(), address_.bytes, buffer, sizeof(buffer)); + QUICHE_BUG_IF(quiche_bug_10126_4, result == nullptr) + << "Failed to convert an IP address to string"; + return buffer; +} + +static const uint8_t kMappedAddressPrefix[] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, +}; + +QuicheIpAddress QuicheIpAddress::Normalized() const { + if (!IsIPv6()) { + return *this; + } + if (!std::equal(std::begin(kMappedAddressPrefix), + std::end(kMappedAddressPrefix), address_.bytes)) { + return *this; + } + + in_addr result; + memcpy(&result, &address_.bytes[12], sizeof(result)); + return QuicheIpAddress(result); +} + +QuicheIpAddress QuicheIpAddress::DualStacked() const { + if (!IsIPv4()) { + return *this; + } + + QuicheIpAddress result; + result.family_ = IpAddressFamily::IP_V6; + memcpy(result.address_.bytes, kMappedAddressPrefix, + sizeof(kMappedAddressPrefix)); + memcpy(result.address_.bytes + 12, address_.bytes, kIPv4AddressSize); + return result; +} + +bool QuicheIpAddress::FromPackedString(const char* data, size_t length) { + switch (length) { + case kIPv4AddressSize: + family_ = IpAddressFamily::IP_V4; + break; + case kIPv6AddressSize: + family_ = IpAddressFamily::IP_V6; + break; + default: + return false; + } + memcpy(address_.chars, data, length); + return true; +} + +bool QuicheIpAddress::FromString(std::string str) { + for (IpAddressFamily family : + {IpAddressFamily::IP_V6, IpAddressFamily::IP_V4}) { + int result = + inet_pton(ToPlatformAddressFamily(family), str.c_str(), address_.bytes); + if (result > 0) { + family_ = family; + return true; + } + } + return false; +} + +bool QuicheIpAddress::IsIPv4() const { + return family_ == IpAddressFamily::IP_V4; +} + +bool QuicheIpAddress::IsIPv6() const { + return family_ == IpAddressFamily::IP_V6; +} + +bool QuicheIpAddress::InSameSubnet(const QuicheIpAddress& other, + int subnet_length) { + if (!IsInitialized()) { + QUICHE_BUG(quiche_bug_10126_5) + << "Attempting to do subnet matching on undefined address"; + return false; + } + if ((IsIPv4() && subnet_length > 32) || (IsIPv6() && subnet_length > 128)) { + QUICHE_BUG(quiche_bug_10126_6) << "Subnet mask is out of bounds"; + return false; + } + + int bytes_to_check = subnet_length / 8; + int bits_to_check = subnet_length % 8; + const uint8_t* const lhs = address_.bytes; + const uint8_t* const rhs = other.address_.bytes; + if (!std::equal(lhs, lhs + bytes_to_check, rhs)) { + return false; + } + if (bits_to_check == 0) { + return true; + } + QUICHE_DCHECK_LT(static_cast(bytes_to_check), sizeof(address_.bytes)); + int mask = (~0u) << (8u - bits_to_check); + return (lhs[bytes_to_check] & mask) == (rhs[bytes_to_check] & mask); +} + +in_addr QuicheIpAddress::GetIPv4() const { + QUICHE_DCHECK(IsIPv4()); + return address_.v4; +} + +in6_addr QuicheIpAddress::GetIPv6() const { + QUICHE_DCHECK(IsIPv6()); + return address_.v6; +} + +QuicheIpPrefix::QuicheIpPrefix() : prefix_length_(0) {} +QuicheIpPrefix::QuicheIpPrefix(const QuicheIpAddress& address) + : address_(address) { + if (address_.IsIPv6()) { + prefix_length_ = QuicheIpAddress::kIPv6AddressSize * 8; + } else if (address_.IsIPv4()) { + prefix_length_ = QuicheIpAddress::kIPv4AddressSize * 8; + } else { + prefix_length_ = 0; + } +} +QuicheIpPrefix::QuicheIpPrefix(const QuicheIpAddress& address, + uint8_t prefix_length) + : address_(address), prefix_length_(prefix_length) { + QUICHE_DCHECK(prefix_length <= QuicheIpPrefix(address).prefix_length()) + << "prefix_length cannot be longer than the size of the IP address"; +} + +std::string QuicheIpPrefix::ToString() const { + return absl::StrCat(address_.ToString(), "/", prefix_length_); +} + +bool operator==(const QuicheIpPrefix& lhs, const QuicheIpPrefix& rhs) { + return lhs.address_ == rhs.address_ && + lhs.prefix_length_ == rhs.prefix_length_; +} + +bool operator!=(const QuicheIpPrefix& lhs, const QuicheIpPrefix& rhs) { + return !(lhs == rhs); +} + +} // namespace quiche diff --git a/quiche/common/quiche_ip_address.h b/quiche/common/quiche_ip_address.h new file mode 100644 index 000000000000..a6eeffd86bb1 --- /dev/null +++ b/quiche/common/quiche_ip_address.h @@ -0,0 +1,130 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_IP_ADDRESS_H_ +#define QUICHE_COMMON_QUICHE_IP_ADDRESS_H_ + +#include +#if defined(_WIN32) +#include +#include +#else +#include +#include +#include +#include +#endif + +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_ip_address_family.h" + +namespace quiche { + +// Represents an IP address. +class QUICHE_EXPORT QuicheIpAddress { + public: + // Sizes of IP addresses of different types, in bytes. + enum : size_t { + kIPv4AddressSize = 32 / 8, + kIPv6AddressSize = 128 / 8, + kMaxAddressSize = kIPv6AddressSize, + }; + + // TODO(fayang): Remove Loopback*() and use TestLoopback*() in tests. + static QuicheIpAddress Loopback4(); + static QuicheIpAddress Loopback6(); + static QuicheIpAddress Any4(); + static QuicheIpAddress Any6(); + + QuicheIpAddress(); + QuicheIpAddress(const QuicheIpAddress& other) = default; + explicit QuicheIpAddress(const in_addr& ipv4_address); + explicit QuicheIpAddress(const in6_addr& ipv6_address); + QuicheIpAddress& operator=(const QuicheIpAddress& other) = default; + QuicheIpAddress& operator=(QuicheIpAddress&& other) = default; + QUICHE_EXPORT friend bool operator==(QuicheIpAddress lhs, + QuicheIpAddress rhs); + QUICHE_EXPORT friend bool operator!=(QuicheIpAddress lhs, + QuicheIpAddress rhs); + + bool IsInitialized() const; + IpAddressFamily address_family() const; + int AddressFamilyToInt() const; + // Returns the address as a sequence of bytes in network-byte-order. IPv4 will + // be 4 bytes. IPv6 will be 16 bytes. + std::string ToPackedString() const; + // Returns string representation of the address. + std::string ToString() const; + // Normalizes the address representation with respect to IPv4 addresses, i.e, + // mapped IPv4 addresses ("::ffff:X.Y.Z.Q") are converted to pure IPv4 + // addresses. All other IPv4, IPv6, and empty values are left unchanged. + QuicheIpAddress Normalized() const; + // Returns an address suitable for use in IPv6-aware contexts. This is the + // opposite of NormalizeIPAddress() above. IPv4 addresses are converted into + // their IPv4-mapped address equivalents (e.g. 192.0.2.1 becomes + // ::ffff:192.0.2.1). IPv6 addresses are a noop (they are returned + // unchanged). + QuicheIpAddress DualStacked() const; + bool FromPackedString(const char* data, size_t length); + bool FromString(std::string str); + bool IsIPv4() const; + bool IsIPv6() const; + bool InSameSubnet(const QuicheIpAddress& other, int subnet_length); + + in_addr GetIPv4() const; + in6_addr GetIPv6() const; + + private: + union { + in_addr v4; + in6_addr v6; + uint8_t bytes[kMaxAddressSize]; + char chars[kMaxAddressSize]; + } address_; + IpAddressFamily family_; +}; + +inline std::ostream& operator<<(std::ostream& os, + const QuicheIpAddress address) { + os << address.ToString(); + return os; +} + +// Represents an IP prefix, which is an IP address and a prefix length in bits. +class QUICHE_EXPORT QuicheIpPrefix { + public: + QuicheIpPrefix(); + explicit QuicheIpPrefix(const QuicheIpAddress& address); + explicit QuicheIpPrefix(const QuicheIpAddress& address, + uint8_t prefix_length); + + QuicheIpAddress address() const { return address_; } + uint8_t prefix_length() const { return prefix_length_; } + // Human-readable string representation of the prefix suitable for logging. + std::string ToString() const; + + QuicheIpPrefix(const QuicheIpPrefix& other) = default; + QuicheIpPrefix& operator=(const QuicheIpPrefix& other) = default; + QuicheIpPrefix& operator=(QuicheIpPrefix&& other) = default; + QUICHE_EXPORT friend bool operator==(const QuicheIpPrefix& lhs, + const QuicheIpPrefix& rhs); + QUICHE_EXPORT friend bool operator!=(const QuicheIpPrefix& lhs, + const QuicheIpPrefix& rhs); + + private: + QuicheIpAddress address_; + uint8_t prefix_length_; +}; + +inline std::ostream& operator<<(std::ostream& os, const QuicheIpPrefix prefix) { + os << prefix.ToString(); + return os; +} + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_IP_ADDRESS_H_ diff --git a/quiche/common/quiche_ip_address_family.cc b/quiche/common/quiche_ip_address_family.cc new file mode 100644 index 000000000000..885ddb6393a3 --- /dev/null +++ b/quiche/common/quiche_ip_address_family.cc @@ -0,0 +1,47 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_ip_address_family.h" + +#include "quiche/common/platform/api/quiche_bug_tracker.h" + +#if defined(_WIN32) +#include +#else +#include +#endif // defined(_WIN32) + +namespace quiche { + +int ToPlatformAddressFamily(IpAddressFamily family) { + switch (family) { + case IpAddressFamily::IP_V4: + return AF_INET; + case IpAddressFamily::IP_V6: + return AF_INET6; + case IpAddressFamily::IP_UNSPEC: + return AF_UNSPEC; + default: + QUICHE_BUG(quic_bug_10126_1) + << "Invalid IpAddressFamily " << static_cast(family); + return AF_UNSPEC; + } +} + +IpAddressFamily FromPlatformAddressFamily(int family) { + switch (family) { + case AF_INET: + return IpAddressFamily::IP_V4; + case AF_INET6: + return IpAddressFamily::IP_V6; + case AF_UNSPEC: + return IpAddressFamily::IP_UNSPEC; + default: + QUICHE_BUG(quic_FromPlatformAddressFamily_unrecognized_family) + << "Invalid platform address family int " << family; + return IpAddressFamily::IP_UNSPEC; + } +} + +} // namespace quiche diff --git a/quiche/common/quiche_ip_address_family.h b/quiche/common/quiche_ip_address_family.h new file mode 100644 index 000000000000..1fbec53b8ac1 --- /dev/null +++ b/quiche/common/quiche_ip_address_family.h @@ -0,0 +1,23 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_IP_ADDRESS_FAMILY_H_ +#define QUICHE_COMMON_QUICHE_IP_ADDRESS_FAMILY_H_ + +namespace quiche { + +// IP address family type used in QUIC. This hides platform dependant IP address +// family types. +enum class IpAddressFamily { + IP_V4, + IP_V6, + IP_UNSPEC, +}; + +int ToPlatformAddressFamily(IpAddressFamily family); +IpAddressFamily FromPlatformAddressFamily(int family); + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_IP_ADDRESS_FAMILY_H_ diff --git a/quiche/common/quiche_ip_address_test.cc b/quiche/common/quiche_ip_address_test.cc new file mode 100644 index 000000000000..609b6b250fdc --- /dev/null +++ b/quiche/common/quiche_ip_address_test.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_ip_address.h" + +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_ip_address_family.h" + +namespace quiche { +namespace test { +namespace { + +TEST(QuicheIpAddressTest, IPv4) { + QuicheIpAddress ip_address; + EXPECT_FALSE(ip_address.IsInitialized()); + + EXPECT_TRUE(ip_address.FromString("127.0.52.223")); + EXPECT_TRUE(ip_address.IsInitialized()); + + EXPECT_EQ(IpAddressFamily::IP_V4, ip_address.address_family()); + EXPECT_TRUE(ip_address.IsIPv4()); + EXPECT_FALSE(ip_address.IsIPv6()); + + EXPECT_EQ("127.0.52.223", ip_address.ToString()); + const in_addr v4_address = ip_address.GetIPv4(); + const uint8_t* const v4_address_ptr = + reinterpret_cast(&v4_address); + EXPECT_EQ(127u, *(v4_address_ptr + 0)); + EXPECT_EQ(0u, *(v4_address_ptr + 1)); + EXPECT_EQ(52u, *(v4_address_ptr + 2)); + EXPECT_EQ(223u, *(v4_address_ptr + 3)); +} + +TEST(QuicheIpAddressTest, IPv6) { + QuicheIpAddress ip_address; + EXPECT_FALSE(ip_address.IsInitialized()); + + EXPECT_TRUE(ip_address.FromString("fe80::1ff:fe23:4567")); + EXPECT_TRUE(ip_address.IsInitialized()); + + EXPECT_EQ(IpAddressFamily::IP_V6, ip_address.address_family()); + EXPECT_FALSE(ip_address.IsIPv4()); + EXPECT_TRUE(ip_address.IsIPv6()); + + EXPECT_EQ("fe80::1ff:fe23:4567", ip_address.ToString()); + const in6_addr v6_address = ip_address.GetIPv6(); + const uint16_t* const v6_address_ptr = + reinterpret_cast(&v6_address); + EXPECT_EQ(0x80feu, *(v6_address_ptr + 0)); + EXPECT_EQ(0x0000u, *(v6_address_ptr + 1)); + EXPECT_EQ(0x0000u, *(v6_address_ptr + 2)); + EXPECT_EQ(0x0000u, *(v6_address_ptr + 3)); + EXPECT_EQ(0x0000u, *(v6_address_ptr + 4)); + EXPECT_EQ(0xff01u, *(v6_address_ptr + 5)); + EXPECT_EQ(0x23feu, *(v6_address_ptr + 6)); + EXPECT_EQ(0x6745u, *(v6_address_ptr + 7)); + + EXPECT_EQ(ip_address, ip_address.Normalized()); + EXPECT_EQ(ip_address, ip_address.DualStacked()); +} + +TEST(QuicheIpAddressTest, FromPackedString) { + QuicheIpAddress loopback4, loopback6; + const char loopback4_packed[] = "\x7f\0\0\x01"; + const char loopback6_packed[] = "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x01"; + EXPECT_TRUE(loopback4.FromPackedString(loopback4_packed, 4)); + EXPECT_TRUE(loopback6.FromPackedString(loopback6_packed, 16)); + EXPECT_EQ(loopback4, QuicheIpAddress::Loopback4()); + EXPECT_EQ(loopback6, QuicheIpAddress::Loopback6()); +} + +TEST(QuicheIpAddressTest, MappedAddress) { + QuicheIpAddress ipv4_address; + QuicheIpAddress mapped_address; + + EXPECT_TRUE(ipv4_address.FromString("127.0.0.1")); + EXPECT_TRUE(mapped_address.FromString("::ffff:7f00:1")); + + EXPECT_EQ(mapped_address, ipv4_address.DualStacked()); + EXPECT_EQ(ipv4_address, mapped_address.Normalized()); +} + +TEST(QuicheIpAddressTest, Subnets) { + struct { + const char* address1; + const char* address2; + int subnet_size; + bool same_subnet; + } test_cases[] = { + {"127.0.0.1", "127.0.0.2", 24, true}, + {"8.8.8.8", "127.0.0.1", 24, false}, + {"8.8.8.8", "127.0.0.1", 16, false}, + {"8.8.8.8", "127.0.0.1", 8, false}, + {"8.8.8.8", "127.0.0.1", 2, false}, + {"8.8.8.8", "127.0.0.1", 1, true}, + + {"127.0.0.1", "127.0.0.128", 24, true}, + {"127.0.0.1", "127.0.0.128", 25, false}, + {"127.0.0.1", "127.0.0.127", 25, true}, + + {"127.0.0.1", "127.0.0.0", 30, true}, + {"127.0.0.1", "127.0.0.1", 30, true}, + {"127.0.0.1", "127.0.0.2", 30, true}, + {"127.0.0.1", "127.0.0.3", 30, true}, + {"127.0.0.1", "127.0.0.4", 30, false}, + + {"127.0.0.1", "127.0.0.2", 31, false}, + {"127.0.0.1", "127.0.0.0", 31, true}, + + {"::1", "fe80::1", 8, false}, + {"::1", "fe80::1", 1, false}, + {"::1", "fe80::1", 0, true}, + {"fe80::1", "fe80::2", 126, true}, + {"fe80::1", "fe80::2", 127, false}, + }; + + for (const auto& test_case : test_cases) { + QuicheIpAddress address1, address2; + ASSERT_TRUE(address1.FromString(test_case.address1)); + ASSERT_TRUE(address2.FromString(test_case.address2)); + EXPECT_EQ(test_case.same_subnet, + address1.InSameSubnet(address2, test_case.subnet_size)) + << "Addresses: " << test_case.address1 << ", " << test_case.address2 + << "; subnet: /" << test_case.subnet_size; + } +} + +TEST(QuicheIpAddress, LoopbackAddresses) { + QuicheIpAddress loopback4; + QuicheIpAddress loopback6; + ASSERT_TRUE(loopback4.FromString("127.0.0.1")); + ASSERT_TRUE(loopback6.FromString("::1")); + EXPECT_EQ(loopback4, QuicheIpAddress::Loopback4()); + EXPECT_EQ(loopback6, QuicheIpAddress::Loopback6()); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_linked_hash_map.h b/quiche/common/quiche_linked_hash_map.h new file mode 100644 index 000000000000..25c70fb0314b --- /dev/null +++ b/quiche/common/quiche_linked_hash_map.h @@ -0,0 +1,237 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This is a simplistic insertion-ordered map. It behaves similarly to an STL +// map, but only implements a small subset of the map's methods. Internally, we +// just keep a map and a list going in parallel. +// +// This class provides no thread safety guarantees, beyond what you would +// normally see with std::list. +// +// Iterators point into the list and should be stable in the face of +// mutations, except for an iterator pointing to an element that was just +// deleted. + +#ifndef QUICHE_COMMON_QUICHE_LINKED_HASH_MAP_H_ +#define QUICHE_COMMON_QUICHE_LINKED_HASH_MAP_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +// This holds a list of pair items. This list is what gets +// traversed, and it's iterators from this list that we return from +// begin/end/find. +// +// We also keep a set for find. Since std::list is a +// doubly-linked list, the iterators should remain stable. + +// QUICHE_NO_EXPORT comments suppress erroneous presubmit failures. +template , // QUICHE_NO_EXPORT + class Eq = std::equal_to> // QUICHE_NO_EXPORT +class QuicheLinkedHashMap { // QUICHE_NO_EXPORT + private: + typedef std::list> ListType; + typedef absl::flat_hash_map + MapType; + + public: + typedef typename ListType::iterator iterator; + typedef typename ListType::reverse_iterator reverse_iterator; + typedef typename ListType::const_iterator const_iterator; + typedef typename ListType::const_reverse_iterator const_reverse_iterator; + typedef typename MapType::key_type key_type; + typedef typename ListType::value_type value_type; + typedef typename ListType::size_type size_type; + + QuicheLinkedHashMap() = default; + explicit QuicheLinkedHashMap(size_type bucket_count) : map_(bucket_count) {} + + QuicheLinkedHashMap(const QuicheLinkedHashMap& other) = delete; + QuicheLinkedHashMap& operator=(const QuicheLinkedHashMap& other) = delete; + QuicheLinkedHashMap(QuicheLinkedHashMap&& other) = default; + QuicheLinkedHashMap& operator=(QuicheLinkedHashMap&& other) = default; + + // Returns an iterator to the first (insertion-ordered) element. Like a map, + // this can be dereferenced to a pair. + iterator begin() { return list_.begin(); } + const_iterator begin() const { return list_.begin(); } + + // Returns an iterator beyond the last element. + iterator end() { return list_.end(); } + const_iterator end() const { return list_.end(); } + + // Returns an iterator to the last (insertion-ordered) element. Like a map, + // this can be dereferenced to a pair. + reverse_iterator rbegin() { return list_.rbegin(); } + const_reverse_iterator rbegin() const { return list_.rbegin(); } + + // Returns an iterator beyond the first element. + reverse_iterator rend() { return list_.rend(); } + const_reverse_iterator rend() const { return list_.rend(); } + + // Front and back accessors common to many stl containers. + + // Returns the earliest-inserted element + const value_type& front() const { return list_.front(); } + + // Returns the earliest-inserted element. + value_type& front() { return list_.front(); } + + // Returns the most-recently-inserted element. + const value_type& back() const { return list_.back(); } + + // Returns the most-recently-inserted element. + value_type& back() { return list_.back(); } + + // Clears the map of all values. + void clear() { + map_.clear(); + list_.clear(); + } + + // Returns true iff the map is empty. + bool empty() const { return list_.empty(); } + + // Removes the first element from the list. + void pop_front() { erase(begin()); } + + // Erases values with the provided key. Returns the number of elements + // erased. In this implementation, this will be 0 or 1. + size_type erase(const Key& key) { + typename MapType::iterator found = map_.find(key); + if (found == map_.end()) { + return 0; + } + + list_.erase(found->second); + map_.erase(found); + + return 1; + } + + // Erases the item that 'position' points to. Returns an iterator that points + // to the item that comes immediately after the deleted item in the list, or + // end(). + // If the provided iterator is invalid or there is inconsistency between the + // map and list, a QUICHE_CHECK() error will occur. + iterator erase(iterator position) { + typename MapType::iterator found = map_.find(position->first); + QUICHE_CHECK(found->second == position) + << "Inconsistent iterator for map and list, or the iterator is " + "invalid."; + + map_.erase(found); + return list_.erase(position); + } + + // Erases all the items in the range [first, last). Returns an iterator that + // points to the item that comes immediately after the last deleted item in + // the list, or end(). + iterator erase(iterator first, iterator last) { + while (first != last && first != end()) { + first = erase(first); + } + return first; + } + + // Finds the element with the given key. Returns an iterator to the + // value found, or to end() if the value was not found. Like a map, this + // iterator points to a pair. + iterator find(const Key& key) { + typename MapType::iterator found = map_.find(key); + if (found == map_.end()) { + return end(); + } + return found->second; + } + + const_iterator find(const Key& key) const { + typename MapType::const_iterator found = map_.find(key); + if (found == map_.end()) { + return end(); + } + return found->second; + } + + bool contains(const Key& key) const { return find(key) != end(); } + + // Returns the value mapped to key, or an inserted iterator to that position + // in the map. + Value& operator[](const key_type& key) { + return (*((this->insert(std::make_pair(key, Value()))).first)).second; + } + + // Inserts an element into the map + std::pair insert(const std::pair& pair) { + return InsertInternal(pair); + } + + // Inserts an element into the map + std::pair insert(std::pair&& pair) { + return InsertInternal(std::move(pair)); + } + + // Derive size_ from map_, as list::size might be O(N). + size_type size() const { return map_.size(); } + + template + std::pair emplace(Args&&... args) { + ListType node_donor; + auto node_pos = + node_donor.emplace(node_donor.end(), std::forward(args)...); + const auto& k = node_pos->first; + auto ins = map_.insert({k, node_pos}); + if (!ins.second) { + return {ins.first->second, false}; + } + list_.splice(list_.end(), node_donor, node_pos); + return {ins.first->second, true}; + } + + void swap(QuicheLinkedHashMap& other) { + map_.swap(other.map_); + list_.swap(other.list_); + } + + private: + template + std::pair InsertInternal(U&& pair) { + auto insert_result = map_.try_emplace(pair.first); + auto map_iter = insert_result.first; + + // If the map already contains this key, return a pair with an iterator to + // it, and false indicating that we didn't insert anything. + if (!insert_result.second) { + return {map_iter->second, false}; + } + + // Otherwise, insert into the list, and set value in map. + auto list_iter = list_.insert(list_.end(), std::forward(pair)); + map_iter->second = list_iter; + + return {list_iter, true}; + } + + // The map component, used for speedy lookups + MapType map_; + + // The list component, used for maintaining insertion order + ListType list_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_LINKED_HASH_MAP_H_ diff --git a/quiche/common/quiche_linked_hash_map_test.cc b/quiche/common/quiche_linked_hash_map_test.cc new file mode 100644 index 000000000000..0aa9c54bfc29 --- /dev/null +++ b/quiche/common/quiche_linked_hash_map_test.cc @@ -0,0 +1,393 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Tests QuicheLinkedHashMap. + +#include "quiche/common/quiche_linked_hash_map.h" + +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" + +using testing::Pair; +using testing::Pointee; +using testing::UnorderedElementsAre; + +namespace quiche { +namespace test { + +// Tests that move constructor works. +TEST(LinkedHashMapTest, Move) { + // Use unique_ptr as an example of a non-copyable type. + QuicheLinkedHashMap> m; + m[2] = std::make_unique(12); + m[3] = std::make_unique(13); + QuicheLinkedHashMap> n = std::move(m); + EXPECT_THAT(n, + UnorderedElementsAre(Pair(2, Pointee(12)), Pair(3, Pointee(13)))); +} + +TEST(LinkedHashMapTest, CanEmplaceMoveOnly) { + QuicheLinkedHashMap> m; + struct Data { + int k, v; + }; + const Data data[] = {{1, 123}, {3, 345}, {2, 234}, {4, 456}}; + for (const auto& kv : data) { + m.emplace(std::piecewise_construct, std::make_tuple(kv.k), + std::make_tuple(new int{kv.v})); + } + EXPECT_TRUE(m.contains(2)); + auto found = m.find(2); + ASSERT_TRUE(found != m.end()); + EXPECT_EQ(234, *found->second); +} + +struct NoCopy { + explicit NoCopy(int x) : x(x) {} + NoCopy(const NoCopy&) = delete; + NoCopy& operator=(const NoCopy&) = delete; + NoCopy(NoCopy&&) = delete; + NoCopy& operator=(NoCopy&&) = delete; + int x; +}; + +TEST(LinkedHashMapTest, CanEmplaceNoMoveNoCopy) { + QuicheLinkedHashMap m; + struct Data { + int k, v; + }; + const Data data[] = {{1, 123}, {3, 345}, {2, 234}, {4, 456}}; + for (const auto& kv : data) { + m.emplace(std::piecewise_construct, std::make_tuple(kv.k), + std::make_tuple(kv.v)); + } + EXPECT_TRUE(m.contains(2)); + auto found = m.find(2); + ASSERT_TRUE(found != m.end()); + EXPECT_EQ(234, found->second.x); +} + +TEST(LinkedHashMapTest, ConstKeys) { + QuicheLinkedHashMap m; + m.insert(std::make_pair(1, 2)); + // Test that keys are const in iteration. + std::pair& p = *m.begin(); + EXPECT_EQ(1, p.first); +} + +// Tests that iteration from begin() to end() works +TEST(LinkedHashMapTest, Iteration) { + QuicheLinkedHashMap m; + EXPECT_TRUE(m.begin() == m.end()); + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + QuicheLinkedHashMap::iterator i = m.begin(); + ASSERT_TRUE(m.begin() == i); + ASSERT_TRUE(m.end() != i); + EXPECT_EQ(2, i->first); + EXPECT_EQ(12, i->second); + + ++i; + ASSERT_TRUE(m.end() != i); + EXPECT_EQ(1, i->first); + EXPECT_EQ(11, i->second); + + ++i; + ASSERT_TRUE(m.end() != i); + EXPECT_EQ(3, i->first); + EXPECT_EQ(13, i->second); + + ++i; // Should be the end of the line. + ASSERT_TRUE(m.end() == i); +} + +// Tests that reverse iteration from rbegin() to rend() works +TEST(LinkedHashMapTest, ReverseIteration) { + QuicheLinkedHashMap m; + EXPECT_TRUE(m.rbegin() == m.rend()); + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + QuicheLinkedHashMap::reverse_iterator i = m.rbegin(); + ASSERT_TRUE(m.rbegin() == i); + ASSERT_TRUE(m.rend() != i); + EXPECT_EQ(3, i->first); + EXPECT_EQ(13, i->second); + + ++i; + ASSERT_TRUE(m.rend() != i); + EXPECT_EQ(1, i->first); + EXPECT_EQ(11, i->second); + + ++i; + ASSERT_TRUE(m.rend() != i); + EXPECT_EQ(2, i->first); + EXPECT_EQ(12, i->second); + + ++i; // Should be the end of the line. + ASSERT_TRUE(m.rend() == i); +} + +// Tests that clear() works +TEST(LinkedHashMapTest, Clear) { + QuicheLinkedHashMap m; + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + ASSERT_EQ(3u, m.size()); + + m.clear(); + + EXPECT_EQ(0u, m.size()); + + m.clear(); // Make sure we can call it on an empty map. + + EXPECT_EQ(0u, m.size()); +} + +// Tests that size() works. +TEST(LinkedHashMapTest, Size) { + QuicheLinkedHashMap m; + EXPECT_EQ(0u, m.size()); + m.insert(std::make_pair(2, 12)); + EXPECT_EQ(1u, m.size()); + m.insert(std::make_pair(1, 11)); + EXPECT_EQ(2u, m.size()); + m.insert(std::make_pair(3, 13)); + EXPECT_EQ(3u, m.size()); + m.clear(); + EXPECT_EQ(0u, m.size()); +} + +// Tests empty() +TEST(LinkedHashMapTest, Empty) { + QuicheLinkedHashMap m; + ASSERT_TRUE(m.empty()); + m.insert(std::make_pair(2, 12)); + ASSERT_FALSE(m.empty()); + m.clear(); + ASSERT_TRUE(m.empty()); +} + +TEST(LinkedHashMapTest, Erase) { + QuicheLinkedHashMap m; + ASSERT_EQ(0u, m.size()); + EXPECT_EQ(0u, m.erase(2)); // Nothing to erase yet + + m.insert(std::make_pair(2, 12)); + ASSERT_EQ(1u, m.size()); + EXPECT_EQ(1u, m.erase(2)); + EXPECT_EQ(0u, m.size()); + + EXPECT_EQ(0u, m.erase(2)); // Make sure nothing bad happens if we repeat. + EXPECT_EQ(0u, m.size()); +} + +TEST(LinkedHashMapTest, Erase2) { + QuicheLinkedHashMap m; + ASSERT_EQ(0u, m.size()); + EXPECT_EQ(0u, m.erase(2)); // Nothing to erase yet + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + m.insert(std::make_pair(4, 14)); + ASSERT_EQ(4u, m.size()); + + // Erase middle two + EXPECT_EQ(1u, m.erase(1)); + EXPECT_EQ(1u, m.erase(3)); + + EXPECT_EQ(2u, m.size()); + + // Make sure we can still iterate over everything that's left. + QuicheLinkedHashMap::iterator it = m.begin(); + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(12, it->second); + ++it; + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(14, it->second); + ++it; + ASSERT_TRUE(it == m.end()); + + EXPECT_EQ(0u, m.erase(1)); // Make sure nothing bad happens if we repeat. + ASSERT_EQ(2u, m.size()); + + EXPECT_EQ(1u, m.erase(2)); + EXPECT_EQ(1u, m.erase(4)); + ASSERT_EQ(0u, m.size()); + + EXPECT_EQ(0u, m.erase(1)); // Make sure nothing bad happens if we repeat. + ASSERT_EQ(0u, m.size()); +} + +// Test that erase(iter,iter) and erase(iter) compile and work. +TEST(LinkedHashMapTest, Erase3) { + QuicheLinkedHashMap m; + + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(3, 13)); + m.insert(std::make_pair(4, 14)); + + // Erase middle two + QuicheLinkedHashMap::iterator it2 = m.find(2); + QuicheLinkedHashMap::iterator it4 = m.find(4); + EXPECT_EQ(m.erase(it2, it4), m.find(4)); + EXPECT_EQ(2u, m.size()); + + // Make sure we can still iterate over everything that's left. + QuicheLinkedHashMap::iterator it = m.begin(); + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(11, it->second); + ++it; + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(14, it->second); + ++it; + ASSERT_TRUE(it == m.end()); + + // Erase first one using an iterator. + EXPECT_EQ(m.erase(m.begin()), m.find(4)); + + // Only the last element should be left. + it = m.begin(); + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(14, it->second); + ++it; + ASSERT_TRUE(it == m.end()); +} + +TEST(LinkedHashMapTest, Insertion) { + QuicheLinkedHashMap m; + ASSERT_EQ(0u, m.size()); + std::pair::iterator, bool> result; + + result = m.insert(std::make_pair(2, 12)); + ASSERT_EQ(1u, m.size()); + EXPECT_TRUE(result.second); + EXPECT_EQ(2, result.first->first); + EXPECT_EQ(12, result.first->second); + + result = m.insert(std::make_pair(1, 11)); + ASSERT_EQ(2u, m.size()); + EXPECT_TRUE(result.second); + EXPECT_EQ(1, result.first->first); + EXPECT_EQ(11, result.first->second); + + result = m.insert(std::make_pair(3, 13)); + QuicheLinkedHashMap::iterator result_iterator = result.first; + ASSERT_EQ(3u, m.size()); + EXPECT_TRUE(result.second); + EXPECT_EQ(3, result.first->first); + EXPECT_EQ(13, result.first->second); + + result = m.insert(std::make_pair(3, 13)); + EXPECT_EQ(3u, m.size()); + EXPECT_FALSE(result.second) << "No insertion should have occurred."; + EXPECT_TRUE(result_iterator == result.first) + << "Duplicate insertion should have given us the original iterator."; +} + +static std::pair Pair(int i, int j) { return {i, j}; } + +// Test front accessors. +TEST(LinkedHashMapTest, Front) { + QuicheLinkedHashMap m; + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + EXPECT_EQ(3u, m.size()); + EXPECT_EQ(Pair(2, 12), m.front()); + m.pop_front(); + EXPECT_EQ(2u, m.size()); + EXPECT_EQ(Pair(1, 11), m.front()); + m.pop_front(); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(Pair(3, 13), m.front()); + m.pop_front(); + EXPECT_TRUE(m.empty()); +} + +TEST(LinkedHashMapTest, Find) { + QuicheLinkedHashMap m; + + EXPECT_TRUE(m.end() == m.find(1)) + << "We shouldn't find anything in an empty map."; + + m.insert(std::make_pair(2, 12)); + EXPECT_TRUE(m.end() == m.find(1)) + << "We shouldn't find an element that doesn't exist in the map."; + + std::pair::iterator, bool> result = + m.insert(std::make_pair(1, 11)); + ASSERT_TRUE(result.second); + ASSERT_TRUE(m.end() != result.first); + EXPECT_TRUE(result.first == m.find(1)) + << "We should have found an element we know exists in the map."; + EXPECT_EQ(11, result.first->second); + + // Check that a follow-up insertion doesn't affect our original + m.insert(std::make_pair(3, 13)); + QuicheLinkedHashMap::iterator it = m.find(1); + ASSERT_TRUE(m.end() != it); + EXPECT_EQ(11, it->second); + + m.clear(); + EXPECT_TRUE(m.end() == m.find(1)) + << "We shouldn't find anything in a map that we've cleared."; +} + +TEST(LinkedHashMapTest, Contains) { + QuicheLinkedHashMap m; + + EXPECT_FALSE(m.contains(1)) << "An empty map shouldn't contain anything."; + + m.insert(std::make_pair(2, 12)); + EXPECT_FALSE(m.contains(1)) + << "The map shouldn't contain an element that doesn't exist."; + + m.insert(std::make_pair(1, 11)); + EXPECT_TRUE(m.contains(1)) + << "The map should contain an element that we know exists."; + + m.clear(); + EXPECT_FALSE(m.contains(1)) + << "A map that we've cleared shouldn't contain anything."; +} + +TEST(LinkedHashMapTest, Swap) { + QuicheLinkedHashMap m1; + QuicheLinkedHashMap m2; + m1.insert(std::make_pair(1, 1)); + m1.insert(std::make_pair(2, 2)); + m2.insert(std::make_pair(3, 3)); + ASSERT_EQ(2u, m1.size()); + ASSERT_EQ(1u, m2.size()); + m1.swap(m2); + ASSERT_EQ(1u, m1.size()); + ASSERT_EQ(2u, m2.size()); +} + +TEST(LinkedHashMapTest, CustomHashAndEquality) { + struct CustomIntHash { + size_t operator()(int x) const { return x; } + }; + QuicheLinkedHashMap m; + m.insert(std::make_pair(1, 1)); + EXPECT_TRUE(m.contains(1)); + EXPECT_EQ(1, m[1]); +} + +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_mem_slice_storage.cc b/quiche/common/quiche_mem_slice_storage.cc new file mode 100644 index 000000000000..4b304af9bbfc --- /dev/null +++ b/quiche/common/quiche_mem_slice_storage.cc @@ -0,0 +1,34 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_mem_slice_storage.h" + +#include "quiche/quic/core/quic_utils.h" + +namespace quiche { + +QuicheMemSliceStorage::QuicheMemSliceStorage( + const struct iovec* iov, int iov_count, QuicheBufferAllocator* allocator, + const quic::QuicByteCount max_slice_len) { + if (iov == nullptr) { + return; + } + quic::QuicByteCount write_len = 0; + for (int i = 0; i < iov_count; ++i) { + write_len += iov[i].iov_len; + } + QUICHE_DCHECK_LT(0u, write_len); + + size_t io_offset = 0; + while (write_len > 0) { + size_t slice_len = std::min(write_len, max_slice_len); + QuicheBuffer buffer = QuicheBuffer::CopyFromIovec(allocator, iov, iov_count, + io_offset, slice_len); + storage_.push_back(QuicheMemSlice(std::move(buffer))); + write_len -= slice_len; + io_offset += slice_len; + } +} + +} // namespace quiche diff --git a/quiche/common/quiche_mem_slice_storage.h b/quiche/common/quiche_mem_slice_storage.h new file mode 100644 index 000000000000..1439d636f8d5 --- /dev/null +++ b/quiche/common/quiche_mem_slice_storage.h @@ -0,0 +1,43 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_MEM_SLICE_STORAGE_H_ +#define QUICHE_COMMON_QUICHE_MEM_SLICE_STORAGE_H_ + +#include + +#include "absl/types/span.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_iovec.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quiche { + +// QuicheMemSliceStorage is a container class that store QuicheMemSlices for +// further use cases such as turning into QuicheMemSliceSpan. +class QUICHE_EXPORT QuicheMemSliceStorage { + public: + QuicheMemSliceStorage(const struct iovec* iov, int iov_count, + QuicheBufferAllocator* allocator, + const quic::QuicByteCount max_slice_len); + + QuicheMemSliceStorage(const QuicheMemSliceStorage& other) = delete; + QuicheMemSliceStorage& operator=(const QuicheMemSliceStorage& other) = delete; + QuicheMemSliceStorage(QuicheMemSliceStorage&& other) = default; + QuicheMemSliceStorage& operator=(QuicheMemSliceStorage&& other) = default; + + ~QuicheMemSliceStorage() = default; + + // Return a QuicheMemSliceSpan form of the storage. + absl::Span ToSpan() { return absl::MakeSpan(storage_); } + + private: + std::vector storage_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_MEM_SLICE_STORAGE_H_ diff --git a/quiche/common/quiche_mem_slice_storage_test.cc b/quiche/common/quiche_mem_slice_storage_test.cc new file mode 100644 index 000000000000..8b7ed1a4d032 --- /dev/null +++ b/quiche/common/quiche_mem_slice_storage_test.cc @@ -0,0 +1,61 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_mem_slice_storage.h" + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quiche { +namespace test { +namespace { + +class QuicheMemSliceStorageImplTest : public QuicheTest { + public: + QuicheMemSliceStorageImplTest() = default; +}; + +TEST_F(QuicheMemSliceStorageImplTest, EmptyIov) { + QuicheMemSliceStorage storage(nullptr, 0, nullptr, 1024); + EXPECT_TRUE(storage.ToSpan().empty()); +} + +TEST_F(QuicheMemSliceStorageImplTest, SingleIov) { + SimpleBufferAllocator allocator; + std::string body(3, 'c'); + struct iovec iov = {const_cast(body.data()), body.length()}; + QuicheMemSliceStorage storage(&iov, 1, &allocator, 1024); + auto span = storage.ToSpan(); + EXPECT_EQ("ccc", span[0].AsStringView()); + EXPECT_NE(static_cast(span[0].data()), body.data()); +} + +TEST_F(QuicheMemSliceStorageImplTest, MultipleIovInSingleSlice) { + SimpleBufferAllocator allocator; + std::string body1(3, 'a'); + std::string body2(4, 'b'); + struct iovec iov[] = {{const_cast(body1.data()), body1.length()}, + {const_cast(body2.data()), body2.length()}}; + + QuicheMemSliceStorage storage(iov, 2, &allocator, 1024); + auto span = storage.ToSpan(); + EXPECT_EQ("aaabbbb", span[0].AsStringView()); +} + +TEST_F(QuicheMemSliceStorageImplTest, MultipleIovInMultipleSlice) { + SimpleBufferAllocator allocator; + std::string body1(4, 'a'); + std::string body2(4, 'b'); + struct iovec iov[] = {{const_cast(body1.data()), body1.length()}, + {const_cast(body2.data()), body2.length()}}; + + QuicheMemSliceStorage storage(iov, 2, &allocator, 4); + auto span = storage.ToSpan(); + EXPECT_EQ("aaaa", span[0].AsStringView()); + EXPECT_EQ("bbbb", span[1].AsStringView()); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/quiche/common/quiche_protocol_flags_list.h b/quiche/common/quiche_protocol_flags_list.h new file mode 100644 index 000000000000..90d21a361c65 --- /dev/null +++ b/quiche/common/quiche_protocol_flags_list.h @@ -0,0 +1,15 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// NOLINT(build/header_guard) +// This file intentionally does not have header guards, it's intended to be +// included multiple times, each time with a different definition of +// QUICHE_PROTOCOL_FLAG. + +#if defined(QUICHE_PROTOCOL_FLAG) + +QUICHE_PROTOCOL_FLAG(bool, quiche_oghttp2_debug_trace, false, + "If true, emits trace logs for HTTP/2 events.") + +#endif diff --git a/quiche/common/quiche_random.cc b/quiche/common/quiche_random.cc new file mode 100644 index 000000000000..12eeac203786 --- /dev/null +++ b/quiche/common/quiche_random.cc @@ -0,0 +1,93 @@ +#include "quiche/common/quiche_random.h" + +#include +#include + +#include "openssl/rand.h" +#include "quiche/common/platform/api/quiche_logging.h" +namespace quiche { + +namespace { + +// Insecure randomness in DefaultRandom uses an implementation of +// xoshiro256++ 1.0 based on code in the public domain from +// . + +inline uint64_t Xoshiro256InitializeRngStateMember() { + uint64_t result; + RAND_bytes(reinterpret_cast(&result), sizeof(result)); + return result; +} + +inline uint64_t Xoshiro256PlusPlusRotLeft(uint64_t x, int k) { + return (x << k) | (x >> (64 - k)); +} + +uint64_t Xoshiro256PlusPlus() { + static thread_local uint64_t rng_state[4] = { + Xoshiro256InitializeRngStateMember(), + Xoshiro256InitializeRngStateMember(), + Xoshiro256InitializeRngStateMember(), + Xoshiro256InitializeRngStateMember()}; + const uint64_t result = + Xoshiro256PlusPlusRotLeft(rng_state[0] + rng_state[3], 23) + rng_state[0]; + const uint64_t t = rng_state[1] << 17; + rng_state[2] ^= rng_state[0]; + rng_state[3] ^= rng_state[1]; + rng_state[1] ^= rng_state[2]; + rng_state[0] ^= rng_state[3]; + rng_state[2] ^= t; + rng_state[3] = Xoshiro256PlusPlusRotLeft(rng_state[3], 45); + return result; +} + +class DefaultQuicheRandom : public QuicheRandom { + public: + DefaultQuicheRandom() {} + DefaultQuicheRandom(const DefaultQuicheRandom&) = delete; + DefaultQuicheRandom& operator=(const DefaultQuicheRandom&) = delete; + ~DefaultQuicheRandom() override {} + + // QuicRandom implementation + void RandBytes(void* data, size_t len) override; + uint64_t RandUint64() override; + void InsecureRandBytes(void* data, size_t len) override; + uint64_t InsecureRandUint64() override; +}; + +void DefaultQuicheRandom::RandBytes(void* data, size_t len) { + RAND_bytes(reinterpret_cast(data), len); +} + +uint64_t DefaultQuicheRandom::RandUint64() { + uint64_t value; + RandBytes(&value, sizeof(value)); + return value; +} + +void DefaultQuicheRandom::InsecureRandBytes(void* data, size_t len) { + while (len >= sizeof(uint64_t)) { + uint64_t random_bytes64 = Xoshiro256PlusPlus(); + memcpy(data, &random_bytes64, sizeof(uint64_t)); + data = reinterpret_cast(data) + sizeof(uint64_t); + len -= sizeof(uint64_t); + } + if (len > 0) { + QUICHE_DCHECK_LT(len, sizeof(uint64_t)); + uint64_t random_bytes64 = Xoshiro256PlusPlus(); + memcpy(data, &random_bytes64, len); + } +} + +uint64_t DefaultQuicheRandom::InsecureRandUint64() { + return Xoshiro256PlusPlus(); +} + +} // namespace + +// static +QuicheRandom* QuicheRandom::GetInstance() { + static DefaultQuicheRandom* random = new DefaultQuicheRandom(); + return random; +} +} // namespace quiche diff --git a/quiche/common/quiche_random.h b/quiche/common/quiche_random.h new file mode 100644 index 000000000000..724bcca857df --- /dev/null +++ b/quiche/common/quiche_random.h @@ -0,0 +1,37 @@ +#ifndef QUICHE_COMMON_QUICHE_RANDOM_H_ +#define QUICHE_COMMON_QUICHE_RANDOM_H_ + +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// The interface for a random number generator. +class QUICHE_EXPORT QuicheRandom { + public: + virtual ~QuicheRandom() {} + + // Returns the default random number generator, which is cryptographically + // secure and thread-safe. + static QuicheRandom* GetInstance(); + + // Generates |len| random bytes in the |data| buffer. + virtual void RandBytes(void* data, size_t len) = 0; + + // Returns a random number in the range [0, kuint64max]. + virtual uint64_t RandUint64() = 0; + + // Generates |len| random bytes in the |data| buffer. This MUST NOT be used + // for any application that requires cryptographically-secure randomness. + virtual void InsecureRandBytes(void* data, size_t len) = 0; + + // Returns a random number in the range [0, kuint64max]. This MUST NOT be used + // for any application that requires cryptographically-secure randomness. + virtual uint64_t InsecureRandUint64() = 0; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_RANDOM_H_ diff --git a/quiche/common/quiche_random_test.cc b/quiche/common/quiche_random_test.cc new file mode 100644 index 000000000000..2f1aacc4f4d5 --- /dev/null +++ b/quiche/common/quiche_random_test.cc @@ -0,0 +1,47 @@ +#include "quiche/common/quiche_random.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace { + +TEST(QuicheRandom, RandBytes) { + unsigned char buf1[16]; + unsigned char buf2[16]; + memset(buf1, 0xaf, sizeof(buf1)); + memset(buf2, 0xaf, sizeof(buf2)); + ASSERT_EQ(0, memcmp(buf1, buf2, sizeof(buf1))); + + auto rng = QuicheRandom::GetInstance(); + rng->RandBytes(buf1, sizeof(buf1)); + EXPECT_NE(0, memcmp(buf1, buf2, sizeof(buf1))); +} + +TEST(QuicheRandom, RandUint64) { + auto rng = QuicheRandom::GetInstance(); + uint64_t value1 = rng->RandUint64(); + uint64_t value2 = rng->RandUint64(); + EXPECT_NE(value1, value2); +} + +TEST(QuicheRandom, InsecureRandBytes) { + unsigned char buf1[16]; + unsigned char buf2[16]; + memset(buf1, 0xaf, sizeof(buf1)); + memset(buf2, 0xaf, sizeof(buf2)); + ASSERT_EQ(0, memcmp(buf1, buf2, sizeof(buf1))); + + auto rng = QuicheRandom::GetInstance(); + rng->InsecureRandBytes(buf1, sizeof(buf1)); + EXPECT_NE(0, memcmp(buf1, buf2, sizeof(buf1))); +} + +TEST(QuicheRandom, InsecureRandUint64) { + auto rng = QuicheRandom::GetInstance(); + uint64_t value1 = rng->InsecureRandUint64(); + uint64_t value2 = rng->InsecureRandUint64(); + EXPECT_NE(value1, value2); +} + +} // namespace +} // namespace quiche diff --git a/quiche/common/quiche_status_utils.h b/quiche/common/quiche_status_utils.h new file mode 100644 index 000000000000..7b14e5dd0cb3 --- /dev/null +++ b/quiche/common/quiche_status_utils.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_STATUS_UTILS_H_ +#define QUICHE_COMMON_QUICHE_STATUS_UTILS_H_ + +#include + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" + +namespace quiche { + +// A simplified version of the standard google3 "return if error" macro. Unlike +// the standard version, this does not come with a StatusBuilder support; the +// AppendToStatus() function below is meant to partially fill that gap. +#define QUICHE_RETURN_IF_ERROR(expr) \ + do { \ + absl::Status quiche_status_macro_value = (expr); \ + if (ABSL_PREDICT_FALSE(!quiche_status_macro_value.ok())) { \ + return quiche_status_macro_value; \ + } \ + } while (0) + +// Copies absl::Status payloads from `original` to `target`; required to copy a +// status correctly. +inline void CopyStatusPayloads(const absl::Status& original, + absl::Status& target) { + original.ForEachPayload([&](absl::string_view key, const absl::Cord& value) { + target.SetPayload(key, value); + }); +} + +// Appends additional into to a status message if the status message is +// an error. +template +absl::Status AppendToStatus(absl::Status input, T&&... args) { + if (ABSL_PREDICT_TRUE(input.ok())) { + return input; + } + absl::Status result = absl::Status( + input.code(), absl::StrCat(input.message(), std::forward(args)...)); + CopyStatusPayloads(input, result); + return result; +} + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_STATUS_UTILS_H_ diff --git a/quiche/common/quiche_stream.h b/quiche/common/quiche_stream.h new file mode 100644 index 000000000000..072b3cad4bd8 --- /dev/null +++ b/quiche/common/quiche_stream.h @@ -0,0 +1,106 @@ +// Copyright 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// General-purpose abstractions for a write stream. + +#ifndef QUICHE_COMMON_QUICHE_STREAM_H_ +#define QUICHE_COMMON_QUICHE_STREAM_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// A shared base class for read and write stream to support abrupt termination. +class QUICHE_EXPORT TerminableStream { + public: + virtual ~TerminableStream() = default; + + // Abruptly terminate the stream due to an error. If `error` is not OK, it may + // carry the error information that could be potentially communicated to the + // peer in case the stream is remote. If the stream is a duplex stream, both + // ends of the stream are terminated. + virtual void AbruptlyTerminate(absl::Status error) = 0; +}; + +// A general-purpose visitor API that gets notifications for WriteStream-related +// events. +class QUICHE_EXPORT WriteStreamVisitor { + public: + virtual ~WriteStreamVisitor() {} + + // Called whenever the stream is not write-blocked and can accept new data. + virtual void OnCanWrite() = 0; +}; + +// Options for writing data into a WriteStream. +class QUICHE_EXPORT StreamWriteOptions { + public: + StreamWriteOptions() = default; + + // If send_fin() is sent to true, the write operation also sends a FIN on the + // stream. + bool send_fin() const { return send_fin_; } + void set_send_fin(bool send_fin) { send_fin_ = send_fin; } + + private: + bool send_fin_ = false; +}; + +inline constexpr StreamWriteOptions kDefaultStreamWriteOptions = + StreamWriteOptions(); + +// WriteStream is an object that can accept a stream of bytes. +// +// The writes into a WriteStream are all-or-nothing. A WriteStream object has +// to either accept all data written into it by returning absl::OkStatus, or ask +// the caller to try again once via OnCanWrite() by returning +// absl::UnavailableError. +class QUICHE_EXPORT WriteStream : public TerminableStream { + public: + virtual ~WriteStream() {} + + // Writes |data| into the stream. + virtual absl::Status Writev(absl::Span data, + const StreamWriteOptions& options) = 0; + + // Indicates whether it is possible to write into stream right now. + virtual bool CanWrite() const = 0; + + // Legacy convenience method for writing a single string_view. New users + // should use quiche::WriteIntoStream instead, since this method does not + // return useful failure information. + [[nodiscard]] bool SendFin() { + StreamWriteOptions options; + options.set_send_fin(true); + return Writev(absl::Span(), options).ok(); + } + + // Legacy convenience method for writing a single string_view. New users + // should use quiche::WriteIntoStream instead, since this method does not + // return useful failure information. + [[nodiscard]] bool Write(absl::string_view data) { + return Writev(absl::MakeSpan(&data, 1), kDefaultStreamWriteOptions).ok(); + } +}; + +// Convenience methods to write a single chunk of data into the stream. +inline absl::Status WriteIntoStream( + WriteStream& stream, absl::string_view data, + const StreamWriteOptions& options = kDefaultStreamWriteOptions) { + return stream.Writev(absl::MakeSpan(&data, 1), options); +} + +// Convenience methods to send a FIN on the stream. +inline absl::Status SendFinOnStream(WriteStream& stream) { + StreamWriteOptions options; + options.set_send_fin(true); + return stream.Writev(absl::Span(), options); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_STREAM_H_ diff --git a/quiche/common/quiche_text_utils.cc b/quiche/common/quiche_text_utils.cc new file mode 100644 index 000000000000..5b4ee8e1426e --- /dev/null +++ b/quiche/common/quiche_text_utils.cc @@ -0,0 +1,76 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_text_utils.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +namespace quiche { + +// static +void QuicheTextUtils::Base64Encode(const uint8_t* data, size_t data_len, + std::string* output) { + absl::Base64Escape(std::string(reinterpret_cast(data), data_len), + output); + // Remove padding. + size_t len = output->size(); + if (len >= 2) { + if ((*output)[len - 1] == '=') { + len--; + if ((*output)[len - 1] == '=') { + len--; + } + output->resize(len); + } + } +} + +// static +absl::optional QuicheTextUtils::Base64Decode( + absl::string_view input) { + std::string output; + if (!absl::Base64Unescape(input, &output)) { + return absl::nullopt; + } + return output; +} + +// static +std::string QuicheTextUtils::HexDump(absl::string_view binary_data) { + const int kBytesPerLine = 16; // Maximum bytes dumped per line. + int offset = 0; + const char* p = binary_data.data(); + int bytes_remaining = binary_data.size(); + std::string output; + while (bytes_remaining > 0) { + const int line_bytes = std::min(bytes_remaining, kBytesPerLine); + absl::StrAppendFormat(&output, "0x%04x: ", offset); + for (int i = 0; i < kBytesPerLine; ++i) { + if (i < line_bytes) { + absl::StrAppendFormat(&output, "%02x", + static_cast(p[i])); + } else { + absl::StrAppend(&output, " "); + } + if (i % 2) { + absl::StrAppend(&output, " "); + } + } + absl::StrAppend(&output, " "); + for (int i = 0; i < line_bytes; ++i) { + // Replace non-printable characters and 0x20 (space) with '.' + output += absl::ascii_isgraph(p[i]) ? p[i] : '.'; + } + + bytes_remaining -= line_bytes; + offset += line_bytes; + p += line_bytes; + absl::StrAppend(&output, "\n"); + } + return output; +} + +} // namespace quiche diff --git a/quiche/common/quiche_text_utils.h b/quiche/common/quiche_text_utils.h new file mode 100644 index 000000000000..b433718f7af9 --- /dev/null +++ b/quiche/common/quiche_text_utils.h @@ -0,0 +1,75 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_TEXT_UTILS_H_ +#define QUICHE_COMMON_QUICHE_TEXT_UTILS_H_ + +#include + +#include "absl/hash/hash.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quiche { + +struct QUICHE_EXPORT StringPieceCaseHash { + size_t operator()(absl::string_view data) const { + std::string lower = absl::AsciiStrToLower(data); + absl::Hash hasher; + return hasher(lower); + } +}; + +struct QUICHE_EXPORT StringPieceCaseEqual { + bool operator()(absl::string_view piece1, absl::string_view piece2) const { + return absl::EqualsIgnoreCase(piece1, piece2); + } +}; + +// Various utilities for manipulating text. +class QUICHE_EXPORT QuicheTextUtils { + public: + // Returns a new string in which |data| has been converted to lower case. + static std::string ToLower(absl::string_view data) { + return absl::AsciiStrToLower(data); + } + + // Removes leading and trailing whitespace from |data|. + static void RemoveLeadingAndTrailingWhitespace(absl::string_view* data) { + *data = absl::StripAsciiWhitespace(*data); + } + + // Base64 encodes with no padding |data_len| bytes of |data| into |output|. + static void Base64Encode(const uint8_t* data, size_t data_len, + std::string* output); + + // Decodes a base64-encoded |input|. Returns nullopt when the input is + // invalid. + static absl::optional Base64Decode(absl::string_view input); + + // Returns a string containing hex and ASCII representations of |binary|, + // side-by-side in the style of hexdump. Non-printable characters will be + // printed as '.' in the ASCII output. + // For example, given the input "Hello, QUIC!\01\02\03\04", returns: + // "0x0000: 4865 6c6c 6f2c 2051 5549 4321 0102 0304 Hello,.QUIC!...." + static std::string HexDump(absl::string_view binary_data); + + // Returns true if |data| contains any uppercase characters. + static bool ContainsUpperCase(absl::string_view data) { + return std::any_of(data.begin(), data.end(), absl::ascii_isupper); + } + + // Returns true if |data| contains only decimal digits. + static bool IsAllDigits(absl::string_view data) { + return std::all_of(data.begin(), data.end(), absl::ascii_isdigit); + } +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_TEXT_UTILS_H_ diff --git a/quiche/common/quiche_text_utils_test.cc b/quiche/common/quiche_text_utils_test.cc new file mode 100644 index 000000000000..aeaa9404e59a --- /dev/null +++ b/quiche/common/quiche_text_utils_test.cc @@ -0,0 +1,92 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/quiche_text_utils.h" + +#include + +#include "absl/strings/escaping.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { + +TEST(QuicheTextUtilsTest, ToLower) { + EXPECT_EQ("lower", quiche::QuicheTextUtils::ToLower("LOWER")); + EXPECT_EQ("lower", quiche::QuicheTextUtils::ToLower("lower")); + EXPECT_EQ("lower", quiche::QuicheTextUtils::ToLower("lOwEr")); + EXPECT_EQ("123", quiche::QuicheTextUtils::ToLower("123")); + EXPECT_EQ("", quiche::QuicheTextUtils::ToLower("")); +} + +TEST(QuicheTextUtilsTest, RemoveLeadingAndTrailingWhitespace) { + std::string input; + + for (auto* input : {"text", " text", " text", "text ", "text ", " text ", + " text ", "\r\n\ttext", "text\n\r\t"}) { + absl::string_view piece(input); + quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&piece); + EXPECT_EQ("text", piece); + } +} + +TEST(QuicheTextUtilsTest, HexDump) { + // Verify output for empty input. + EXPECT_EQ("", quiche::QuicheTextUtils::HexDump(absl::HexStringToBytes(""))); + // Verify output of the HexDump method is as expected. + char packet[] = { + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x51, 0x55, 0x49, 0x43, 0x21, + 0x20, 0x54, 0x68, 0x69, 0x73, 0x20, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, + 0x20, 0x73, 0x68, 0x6f, 0x75, 0x6c, 0x64, 0x20, 0x62, 0x65, 0x20, 0x6c, + 0x6f, 0x6e, 0x67, 0x20, 0x65, 0x6e, 0x6f, 0x75, 0x67, 0x68, 0x20, 0x74, + 0x6f, 0x20, 0x73, 0x70, 0x61, 0x6e, 0x20, 0x6d, 0x75, 0x6c, 0x74, 0x69, + 0x70, 0x6c, 0x65, 0x20, 0x6c, 0x69, 0x6e, 0x65, 0x73, 0x20, 0x6f, 0x66, + 0x20, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x2e, 0x01, 0x02, 0x03, 0x00, + }; + EXPECT_EQ( + quiche::QuicheTextUtils::HexDump(packet), + "0x0000: 4865 6c6c 6f2c 2051 5549 4321 2054 6869 Hello,.QUIC!.Thi\n" + "0x0010: 7320 7374 7269 6e67 2073 686f 756c 6420 s.string.should.\n" + "0x0020: 6265 206c 6f6e 6720 656e 6f75 6768 2074 be.long.enough.t\n" + "0x0030: 6f20 7370 616e 206d 756c 7469 706c 6520 o.span.multiple.\n" + "0x0040: 6c69 6e65 7320 6f66 206f 7574 7075 742e lines.of.output.\n" + "0x0050: 0102 03 ...\n"); + // Verify that 0x21 and 0x7e are printable, 0x20 and 0x7f are not. + EXPECT_EQ( + "0x0000: 2021 7e7f .!~.\n", + quiche::QuicheTextUtils::HexDump(absl::HexStringToBytes("20217e7f"))); + // Verify that values above numeric_limits::max() are formatted + // properly on platforms where char is unsigned. + EXPECT_EQ("0x0000: 90aa ff ...\n", + quiche::QuicheTextUtils::HexDump(absl::HexStringToBytes("90aaff"))); +} + +TEST(QuicheTextUtilsTest, Base64Encode) { + std::string output; + std::string input = "Hello"; + quiche::QuicheTextUtils::Base64Encode( + reinterpret_cast(input.data()), input.length(), &output); + EXPECT_EQ("SGVsbG8", output); + + input = + "Hello, QUIC! This string should be long enough to span" + "multiple lines of output\n"; + quiche::QuicheTextUtils::Base64Encode( + reinterpret_cast(input.data()), input.length(), &output); + EXPECT_EQ( + "SGVsbG8sIFFVSUMhIFRoaXMgc3RyaW5nIHNob3VsZCBiZSBsb25n" + "IGVub3VnaCB0byBzcGFubXVsdGlwbGUgbGluZXMgb2Ygb3V0cHV0Cg", + output); +} + +TEST(QuicheTextUtilsTest, ContainsUpperCase) { + EXPECT_FALSE(quiche::QuicheTextUtils::ContainsUpperCase("abc")); + EXPECT_FALSE(quiche::QuicheTextUtils::ContainsUpperCase("")); + EXPECT_FALSE(quiche::QuicheTextUtils::ContainsUpperCase("123")); + EXPECT_TRUE(quiche::QuicheTextUtils::ContainsUpperCase("ABC")); + EXPECT_TRUE(quiche::QuicheTextUtils::ContainsUpperCase("aBc")); +} + +} // namespace test +} // namespace quiche diff --git a/quiche/common/simple_buffer_allocator.cc b/quiche/common/simple_buffer_allocator.cc new file mode 100644 index 000000000000..ff3498451acf --- /dev/null +++ b/quiche/common/simple_buffer_allocator.cc @@ -0,0 +1,17 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/simple_buffer_allocator.h" + +namespace quiche { + +char* SimpleBufferAllocator::New(size_t size) { return new char[size]; } + +char* SimpleBufferAllocator::New(size_t size, bool /* flag_enable */) { + return New(size); +} + +void SimpleBufferAllocator::Delete(char* buffer) { delete[] buffer; } + +} // namespace quiche diff --git a/quiche/common/simple_buffer_allocator.h b/quiche/common/simple_buffer_allocator.h new file mode 100644 index 000000000000..babfa5527c72 --- /dev/null +++ b/quiche/common/simple_buffer_allocator.h @@ -0,0 +1,30 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_SIMPLE_BUFFER_ALLOCATOR_H_ +#define QUICHE_COMMON_SIMPLE_BUFFER_ALLOCATOR_H_ + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quiche { + +// Provides buffer allocation using operators new[] and delete[] on char arrays. +// Note that some of the QUICHE code relies on this being the case for deleting +// new[]-allocated arrays from elsewhere. +class QUICHE_EXPORT SimpleBufferAllocator : public QuicheBufferAllocator { + public: + static SimpleBufferAllocator* Get() { + static SimpleBufferAllocator* singleton = new SimpleBufferAllocator(); + return singleton; + } + + char* New(size_t size) override; + char* New(size_t size, bool flag_enable) override; + void Delete(char* buffer) override; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_SIMPLE_BUFFER_ALLOCATOR_H_ diff --git a/quiche/common/simple_buffer_allocator_test.cc b/quiche/common/simple_buffer_allocator_test.cc new file mode 100644 index 000000000000..c23d6b7ba9d4 --- /dev/null +++ b/quiche/common/simple_buffer_allocator_test.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/simple_buffer_allocator.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace { + +TEST(SimpleBufferAllocatorTest, NewDelete) { + SimpleBufferAllocator alloc; + char* buf = alloc.New(4); + EXPECT_NE(nullptr, buf); + alloc.Delete(buf); +} + +TEST(SimpleBufferAllocatorTest, DeleteNull) { + SimpleBufferAllocator alloc; + alloc.Delete(nullptr); +} + +TEST(SimpleBufferAllocatorTest, MoveBuffersConstructor) { + SimpleBufferAllocator alloc; + QuicheBuffer buffer1(&alloc, 16); + + EXPECT_NE(buffer1.data(), nullptr); + EXPECT_EQ(buffer1.size(), 16u); + + QuicheBuffer buffer2(std::move(buffer1)); + EXPECT_EQ(buffer1.data(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(buffer1.size(), 0u); + EXPECT_NE(buffer2.data(), nullptr); + EXPECT_EQ(buffer2.size(), 16u); +} + +TEST(SimpleBufferAllocatorTest, MoveBuffersAssignment) { + SimpleBufferAllocator alloc; + QuicheBuffer buffer1(&alloc, 16); + QuicheBuffer buffer2; + + EXPECT_NE(buffer1.data(), nullptr); + EXPECT_EQ(buffer1.size(), 16u); + EXPECT_EQ(buffer2.data(), nullptr); + EXPECT_EQ(buffer2.size(), 0u); + + buffer2 = std::move(buffer1); + EXPECT_EQ(buffer1.data(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(buffer1.size(), 0u); + EXPECT_NE(buffer2.data(), nullptr); + EXPECT_EQ(buffer2.size(), 16u); +} + +TEST(SimpleBufferAllocatorTest, CopyBuffer) { + SimpleBufferAllocator alloc; + const absl::string_view original = "Test string"; + QuicheBuffer copy = QuicheBuffer::Copy(&alloc, original); + EXPECT_EQ(copy.AsStringView(), original); +} + +} // namespace +} // namespace quiche diff --git a/quiche/common/structured_headers.cc b/quiche/common/structured_headers.cc new file mode 100644 index 000000000000..58aec62d7f73 --- /dev/null +++ b/quiche/common/structured_headers.cc @@ -0,0 +1,910 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/structured_headers.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { +namespace structured_headers { + +namespace { + +#define DIGIT "0123456789" +#define LCALPHA "abcdefghijklmnopqrstuvwxyz" +#define UCALPHA "ABCDEFGHIJKLMNOPQRSTUVWXYZ" +#define TCHAR DIGIT LCALPHA UCALPHA "!#$%&'*+-.^_`|~" +// https://tools.ietf.org/html/draft-ietf-httpbis-header-structure-09#section-3.9 +constexpr char kTokenChars09[] = DIGIT UCALPHA LCALPHA "_-.:%*/"; +// https://www.rfc-editor.org/rfc/rfc8941.html#section-3.3.4 +constexpr char kTokenChars[] = TCHAR ":/"; +// https://tools.ietf.org/html/draft-ietf-httpbis-header-structure-09#section-3.1 +constexpr char kKeyChars09[] = DIGIT LCALPHA "_-"; +// https://www.rfc-editor.org/rfc/rfc8941.html#section-3.1.2 +constexpr char kKeyChars[] = DIGIT LCALPHA "_-.*"; +constexpr char kSP[] = " "; +constexpr char kOWS[] = " \t"; +#undef DIGIT +#undef LCALPHA +#undef UCALPHA + +// https://www.rfc-editor.org/rfc/rfc8941.html#section-3.3.1 +constexpr int64_t kMaxInteger = 999'999'999'999'999L; +constexpr int64_t kMinInteger = -999'999'999'999'999L; + +// Smallest value which is too large for an sh-decimal. This is the smallest +// double which will round up to 1e12 when serialized, which exceeds the range +// for sh-decimal. Any float less than this should round down. This behaviour is +// verified by unit tests. +constexpr double kTooLargeDecimal = 1e12 - 0.0005; + +// Removes characters in remove from the beginning of s. +void StripLeft(absl::string_view& s, absl::string_view remove) { + size_t i = s.find_first_not_of(remove); + if (i == absl::string_view::npos) { + i = s.size(); + } + s.remove_prefix(i); +} + +// Parser for (a subset of) Structured Headers for HTTP defined in [SH09] and +// [RFC8941]. [SH09] compatibility is retained for use by Web Packaging, and can +// be removed once that spec is updated, and users have migrated to new headers. +// [SH09] https://tools.ietf.org/html/draft-ietf-httpbis-header-structure-09 +// [RFC8941] https://www.rfc-editor.org/rfc/rfc8941.html +class StructuredHeaderParser { + public: + enum DraftVersion { + kDraft09, + kFinal, + }; + explicit StructuredHeaderParser(absl::string_view str, DraftVersion version) + : input_(str), version_(version) { + // [SH09] 4.2 Step 1. + // Discard any leading OWS from input_string. + // [RFC8941] 4.2 Step 2. + // Discard any leading SP characters from input_string. + SkipWhitespaces(); + } + StructuredHeaderParser(const StructuredHeaderParser&) = delete; + StructuredHeaderParser& operator=(const StructuredHeaderParser&) = delete; + + // Callers should call this after ReadSomething(), to check if parser has + // consumed all the input successfully. + bool FinishParsing() { + // [SH09] 4.2 Step 7. + // Discard any leading OWS from input_string. + // [RFC8941] 4.2 Step 6. + // Discard any leading SP characters from input_string. + SkipWhitespaces(); + // [SH09] 4.2 Step 8. [RFC8941] 4.2 Step 7. + // If input_string is not empty, fail parsing. + return input_.empty(); + } + + // Parses a List of Lists ([SH09] 4.2.4). + absl::optional ReadListOfLists() { + QUICHE_CHECK_EQ(version_, kDraft09); + ListOfLists result; + while (true) { + std::vector inner_list; + while (true) { + absl::optional item(ReadBareItem()); + if (!item) return absl::nullopt; + inner_list.push_back(std::move(*item)); + SkipWhitespaces(); + if (!ConsumeChar(';')) break; + SkipWhitespaces(); + } + result.push_back(std::move(inner_list)); + SkipWhitespaces(); + if (!ConsumeChar(',')) break; + SkipWhitespaces(); + } + return result; + } + + // Parses a List ([RFC8941] 4.2.1). + absl::optional ReadList() { + QUICHE_CHECK_EQ(version_, kFinal); + List members; + while (!input_.empty()) { + absl::optional member(ReadItemOrInnerList()); + if (!member) return absl::nullopt; + members.push_back(std::move(*member)); + SkipOWS(); + if (input_.empty()) break; + if (!ConsumeChar(',')) return absl::nullopt; + SkipOWS(); + if (input_.empty()) return absl::nullopt; + } + return members; + } + + // Parses an Item ([RFC8941] 4.2.3). + absl::optional ReadItem() { + absl::optional item = ReadBareItem(); + if (!item) return absl::nullopt; + absl::optional parameters = ReadParameters(); + if (!parameters) return absl::nullopt; + return ParameterizedItem(std::move(*item), std::move(*parameters)); + } + + // Parses a bare Item ([RFC8941] 4.2.3.1, though this is also the algorithm + // for parsing an Item from [SH09] 4.2.7). + absl::optional ReadBareItem() { + if (input_.empty()) { + QUICHE_DVLOG(1) << "ReadBareItem: unexpected EOF"; + return absl::nullopt; + } + switch (input_.front()) { + case '"': + return ReadString(); + case '*': + if (version_ == kDraft09) return ReadByteSequence(); + return ReadToken(); + case ':': + if (version_ == kFinal) return ReadByteSequence(); + return absl::nullopt; + case '?': + return ReadBoolean(); + default: + if (input_.front() == '-' || absl::ascii_isdigit(input_.front())) + return ReadNumber(); + if (absl::ascii_isalpha(input_.front())) return ReadToken(); + return absl::nullopt; + } + } + + // Parses a Dictionary ([RFC8941] 4.2.2). + absl::optional ReadDictionary() { + QUICHE_CHECK_EQ(version_, kFinal); + Dictionary members; + while (!input_.empty()) { + absl::optional key(ReadKey()); + if (!key) return absl::nullopt; + absl::optional member; + if (ConsumeChar('=')) { + member = ReadItemOrInnerList(); + if (!member) return absl::nullopt; + } else { + absl::optional parameters; + parameters = ReadParameters(); + if (!parameters) return absl::nullopt; + member = ParameterizedMember{Item(true), std::move(*parameters)}; + } + members[*key] = std::move(*member); + SkipOWS(); + if (input_.empty()) break; + if (!ConsumeChar(',')) return absl::nullopt; + SkipOWS(); + if (input_.empty()) return absl::nullopt; + } + return members; + } + + // Parses a Parameterised List ([SH09] 4.2.5). + absl::optional ReadParameterisedList() { + QUICHE_CHECK_EQ(version_, kDraft09); + ParameterisedList items; + while (true) { + absl::optional item = + ReadParameterisedIdentifier(); + if (!item) return absl::nullopt; + items.push_back(std::move(*item)); + SkipWhitespaces(); + if (!ConsumeChar(',')) return items; + SkipWhitespaces(); + } + } + + private: + // Parses a Parameterised Identifier ([SH09] 4.2.6). + absl::optional ReadParameterisedIdentifier() { + QUICHE_CHECK_EQ(version_, kDraft09); + absl::optional primary_identifier = ReadToken(); + if (!primary_identifier) return absl::nullopt; + + ParameterisedIdentifier::Parameters parameters; + + SkipWhitespaces(); + while (ConsumeChar(';')) { + SkipWhitespaces(); + + absl::optional name = ReadKey(); + if (!name) return absl::nullopt; + + Item value; + if (ConsumeChar('=')) { + auto item = ReadBareItem(); + if (!item) return absl::nullopt; + value = std::move(*item); + } + if (!parameters.emplace(*name, value).second) { + QUICHE_DVLOG(1) << "ReadParameterisedIdentifier: duplicated parameter: " + << *name; + return absl::nullopt; + } + SkipWhitespaces(); + } + return ParameterisedIdentifier(std::move(*primary_identifier), + std::move(parameters)); + } + + // Parses an Item or Inner List ([RFC8941] 4.2.1.1). + absl::optional ReadItemOrInnerList() { + QUICHE_CHECK_EQ(version_, kFinal); + std::vector member; + bool member_is_inner_list = (!input_.empty() && input_.front() == '('); + if (member_is_inner_list) { + return ReadInnerList(); + } else { + auto item = ReadItem(); + if (!item) return absl::nullopt; + return ParameterizedMember(std::move(item->item), + std::move(item->params)); + } + } + + // Parses Parameters ([RFC8941] 4.2.3.2) + absl::optional ReadParameters() { + Parameters parameters; + absl::flat_hash_set keys; + + while (ConsumeChar(';')) { + SkipWhitespaces(); + + absl::optional name = ReadKey(); + if (!name) return absl::nullopt; + bool is_duplicate_key = !keys.insert(*name).second; + + Item value{true}; + if (ConsumeChar('=')) { + auto item = ReadBareItem(); + if (!item) return absl::nullopt; + value = std::move(*item); + } + if (is_duplicate_key) { + for (auto& param : parameters) { + if (param.first == name) { + param.second = std::move(value); + break; + } + } + } else { + parameters.emplace_back(std::move(*name), std::move(value)); + } + } + return parameters; + } + + // Parses an Inner List ([RFC8941] 4.2.1.2). + absl::optional ReadInnerList() { + QUICHE_CHECK_EQ(version_, kFinal); + if (!ConsumeChar('(')) return absl::nullopt; + std::vector inner_list; + while (true) { + SkipWhitespaces(); + if (ConsumeChar(')')) { + absl::optional parameters; + parameters = ReadParameters(); + if (!parameters) return absl::nullopt; + return ParameterizedMember(std::move(inner_list), true, + std::move(*parameters)); + } + auto item = ReadItem(); + if (!item) return absl::nullopt; + inner_list.push_back(std::move(*item)); + if (input_.empty() || (input_.front() != ' ' && input_.front() != ')')) + return absl::nullopt; + } + QUICHE_NOTREACHED(); + return absl::nullopt; + } + + // Parses a Key ([SH09] 4.2.2, [RFC8941] 4.2.3.3). + absl::optional ReadKey() { + if (version_ == kDraft09) { + if (input_.empty() || !absl::ascii_islower(input_.front())) { + LogParseError("ReadKey", "lcalpha"); + return absl::nullopt; + } + } else { + if (input_.empty() || + (!absl::ascii_islower(input_.front()) && input_.front() != '*')) { + LogParseError("ReadKey", "lcalpha | *"); + return absl::nullopt; + } + } + const char* allowed_chars = + (version_ == kDraft09 ? kKeyChars09 : kKeyChars); + size_t len = input_.find_first_not_of(allowed_chars); + if (len == absl::string_view::npos) len = input_.size(); + std::string key(input_.substr(0, len)); + input_.remove_prefix(len); + return key; + } + + // Parses a Token ([SH09] 4.2.10, [RFC8941] 4.2.6). + absl::optional ReadToken() { + if (input_.empty() || + !(absl::ascii_isalpha(input_.front()) || input_.front() == '*')) { + LogParseError("ReadToken", "ALPHA"); + return absl::nullopt; + } + size_t len = input_.find_first_not_of(version_ == kDraft09 ? kTokenChars09 + : kTokenChars); + if (len == absl::string_view::npos) len = input_.size(); + std::string token(input_.substr(0, len)); + input_.remove_prefix(len); + return Item(std::move(token), Item::kTokenType); + } + + // Parses a Number ([SH09] 4.2.8, [RFC8941] 4.2.4). + absl::optional ReadNumber() { + bool is_negative = ConsumeChar('-'); + bool is_decimal = false; + size_t decimal_position = 0; + size_t i = 0; + for (; i < input_.size(); ++i) { + if (i > 0 && input_[i] == '.' && !is_decimal) { + is_decimal = true; + decimal_position = i; + continue; + } + if (!absl::ascii_isdigit(input_[i])) break; + } + if (i == 0) { + LogParseError("ReadNumber", "DIGIT"); + return absl::nullopt; + } + if (!is_decimal) { + // [RFC8941] restricts the range of integers further. + if (version_ == kFinal && i > 15) { + LogParseError("ReadNumber", "integer too long"); + return absl::nullopt; + } + } else { + if (version_ != kFinal && i > 16) { + LogParseError("ReadNumber", "float too long"); + return absl::nullopt; + } + if (version_ == kFinal && decimal_position > 12) { + LogParseError("ReadNumber", "decimal too long"); + return absl::nullopt; + } + if (i - decimal_position > (version_ == kFinal ? 4 : 7)) { + LogParseError("ReadNumber", "too many digits after decimal"); + return absl::nullopt; + } + if (i == decimal_position) { + LogParseError("ReadNumber", "no digits after decimal"); + return absl::nullopt; + } + } + std::string output_number_string(input_.substr(0, i)); + input_.remove_prefix(i); + + if (is_decimal) { + // Convert to a 64-bit double, and return if the conversion is + // successful. + double f; + if (!absl::SimpleAtod(output_number_string, &f)) return absl::nullopt; + return Item(is_negative ? -f : f); + } else { + // Convert to a 64-bit signed integer, and return if the conversion is + // successful. + int64_t n; + if (!absl::SimpleAtoi(output_number_string, &n)) return absl::nullopt; + QUICHE_CHECK(version_ != kFinal || + (n <= kMaxInteger && n >= kMinInteger)); + return Item(is_negative ? -n : n); + } + } + + // Parses a String ([SH09] 4.2.9, [RFC8941] 4.2.5). + absl::optional ReadString() { + std::string s; + if (!ConsumeChar('"')) { + LogParseError("ReadString", "'\"'"); + return absl::nullopt; + } + while (!ConsumeChar('"')) { + size_t i = 0; + for (; i < input_.size(); ++i) { + if (!absl::ascii_isprint(input_[i])) { + QUICHE_DVLOG(1) << "ReadString: non printable-ASCII character"; + return absl::nullopt; + } + if (input_[i] == '"' || input_[i] == '\\') break; + } + if (i == input_.size()) { + QUICHE_DVLOG(1) << "ReadString: missing closing '\"'"; + return absl::nullopt; + } + s.append(std::string(input_.substr(0, i))); + input_.remove_prefix(i); + if (ConsumeChar('\\')) { + if (input_.empty()) { + QUICHE_DVLOG(1) << "ReadString: backslash at string end"; + return absl::nullopt; + } + if (input_[0] != '"' && input_[0] != '\\') { + QUICHE_DVLOG(1) << "ReadString: invalid escape"; + return absl::nullopt; + } + s.push_back(input_.front()); + input_.remove_prefix(1); + } + } + return s; + } + + // Parses a Byte Sequence ([SH09] 4.2.11, [RFC8941] 4.2.7). + absl::optional ReadByteSequence() { + char delimiter = (version_ == kDraft09 ? '*' : ':'); + if (!ConsumeChar(delimiter)) { + LogParseError("ReadByteSequence", "delimiter"); + return absl::nullopt; + } + size_t len = input_.find(delimiter); + if (len == absl::string_view::npos) { + QUICHE_DVLOG(1) << "ReadByteSequence: missing closing delimiter"; + return absl::nullopt; + } + std::string base64(input_.substr(0, len)); + // Append the necessary padding characters. + base64.resize((base64.size() + 3) / 4 * 4, '='); + + std::string binary; + if (!absl::Base64Unescape(base64, &binary)) { + QUICHE_DVLOG(1) << "ReadByteSequence: failed to decode base64: " + << base64; + return absl::nullopt; + } + input_.remove_prefix(len); + ConsumeChar(delimiter); + return Item(std::move(binary), Item::kByteSequenceType); + } + + // Parses a Boolean ([RFC8941] 4.2.8). + // Note that this only parses ?0 and ?1 forms from SH version 10+, not the + // previous ?F and ?T, which were not needed by any consumers of SH version 9. + absl::optional ReadBoolean() { + if (!ConsumeChar('?')) { + LogParseError("ReadBoolean", "'?'"); + return absl::nullopt; + } + if (ConsumeChar('1')) { + return Item(true); + } + if (ConsumeChar('0')) { + return Item(false); + } + return absl::nullopt; + } + + // There are several points in the specs where the handling of whitespace + // differs between Draft 9 and the final RFC. In those cases, Draft 9 allows + // any OWS character, while the RFC allows only a U+0020 SPACE. + void SkipWhitespaces() { + if (version_ == kDraft09) { + StripLeft(input_, kOWS); + } else { + StripLeft(input_, kSP); + } + } + + void SkipOWS() { StripLeft(input_, kOWS); } + + bool ConsumeChar(char expected) { + if (!input_.empty() && input_.front() == expected) { + input_.remove_prefix(1); + return true; + } + return false; + } + + void LogParseError(const char* func, const char* expected) { + QUICHE_DVLOG(1) << func << ": " << expected << " expected, got " + << (input_.empty() + ? "EOS" + : "'" + std::string(input_.substr(0, 1)) + "'"); + } + + absl::string_view input_; + DraftVersion version_; +}; + +// Serializer for (a subset of) Structured Field Values for HTTP defined in +// [RFC8941]. Note that this serializer does not attempt to support [SH09]. +class StructuredHeaderSerializer { + public: + StructuredHeaderSerializer() = default; + ~StructuredHeaderSerializer() = default; + StructuredHeaderSerializer(const StructuredHeaderSerializer&) = delete; + StructuredHeaderSerializer& operator=(const StructuredHeaderSerializer&) = + delete; + + std::string Output() { return output_.str(); } + + // Serializes a List ([RFC8941] 4.1.1). + bool WriteList(const List& value) { + bool first = true; + for (const auto& member : value) { + if (!first) output_ << ", "; + if (!WriteParameterizedMember(member)) return false; + first = false; + } + return true; + } + + // Serializes an Item ([RFC8941] 4.1.3). + bool WriteItem(const ParameterizedItem& value) { + if (!WriteBareItem(value.item)) return false; + return WriteParameters(value.params); + } + + // Serializes an Item ([RFC8941] 4.1.3). + bool WriteBareItem(const Item& value) { + if (value.is_string()) { + // Serializes a String ([RFC8941] 4.1.6). + output_ << "\""; + for (const char& c : value.GetString()) { + if (!absl::ascii_isprint(c)) return false; + if (c == '\\' || c == '\"') output_ << "\\"; + output_ << c; + } + output_ << "\""; + return true; + } + if (value.is_token()) { + // Serializes a Token ([RFC8941] 4.1.7). + if (value.GetString().empty() || + !(absl::ascii_isalpha(value.GetString().front()) || + value.GetString().front() == '*')) + return false; + if (value.GetString().find_first_not_of(kTokenChars) != std::string::npos) + return false; + output_ << value.GetString(); + return true; + } + if (value.is_byte_sequence()) { + // Serializes a Byte Sequence ([RFC8941] 4.1.8). + output_ << ":"; + output_ << absl::Base64Escape(value.GetString()); + output_ << ":"; + return true; + } + if (value.is_integer()) { + // Serializes an Integer ([RFC8941] 4.1.4). + if (value.GetInteger() > kMaxInteger || value.GetInteger() < kMinInteger) + return false; + output_ << value.GetInteger(); + return true; + } + if (value.is_decimal()) { + // Serializes a Decimal ([RFC8941] 4.1.5). + double decimal_value = value.GetDecimal(); + if (!std::isfinite(decimal_value) || + fabs(decimal_value) >= kTooLargeDecimal) + return false; + + // Handle sign separately to simplify the rest of the formatting. + if (decimal_value < 0) output_ << "-"; + // Unconditionally take absolute value to ensure that -0 is serialized as + // "0.0", with no negative sign, as required by spec. (4.1.5, step 2). + decimal_value = fabs(decimal_value); + double remainder = fmod(decimal_value, 0.002); + if (remainder == 0.0005) { + // Value ended in exactly 0.0005, 0.0025, 0.0045, etc. Round down. + decimal_value -= 0.0005; + } else if (remainder == 0.0015) { + // Value ended in exactly 0.0015, 0.0035, 0,0055, etc. Round up. + decimal_value += 0.0005; + } else { + // Standard rounding will work in all other cases. + decimal_value = round(decimal_value * 1000.0) / 1000.0; + } + + // Use standard library functions to write the decimal, and then truncate + // if necessary to conform to spec. + + // Maximum is 12 integer digits, one decimal point, three fractional + // digits, and a null terminator. + char buffer[17]; + absl::SNPrintF(buffer, std::size(buffer), "%#.3f", decimal_value); + + // Strip any trailing 0s after the decimal point, but leave at least one + // digit after it in all cases. (So 1.230 becomes 1.23, but 1.000 becomes + // 1.0.) + absl::string_view formatted_number(buffer); + auto truncate_index = formatted_number.find_last_not_of('0'); + if (formatted_number[truncate_index] == '.') truncate_index++; + output_ << formatted_number.substr(0, truncate_index + 1); + return true; + } + if (value.is_boolean()) { + // Serializes a Boolean ([RFC8941] 4.1.9). + output_ << (value.GetBoolean() ? "?1" : "?0"); + return true; + } + return false; + } + + // Serializes a Dictionary ([RFC8941] 4.1.2). + bool WriteDictionary(const Dictionary& value) { + bool first = true; + for (const auto& [dict_key, dict_value] : value) { + if (!first) output_ << ", "; + if (!WriteKey(dict_key)) return false; + first = false; + if (!dict_value.member_is_inner_list && !dict_value.member.empty() && + dict_value.member.front().item.is_boolean() && + dict_value.member.front().item.GetBoolean()) { + if (!WriteParameters(dict_value.params)) return false; + } else { + output_ << "="; + if (!WriteParameterizedMember(dict_value)) return false; + } + } + return true; + } + + private: + bool WriteParameterizedMember(const ParameterizedMember& value) { + // Serializes a parameterized member ([RFC8941] 4.1.1). + if (value.member_is_inner_list) { + if (!WriteInnerList(value.member)) return false; + } else { + QUICHE_CHECK_EQ(value.member.size(), 1UL); + if (!WriteItem(value.member[0])) return false; + } + return WriteParameters(value.params); + } + + bool WriteInnerList(const std::vector& value) { + // Serializes an inner list ([RFC8941] 4.1.1.1). + output_ << "("; + bool first = true; + for (const ParameterizedItem& member : value) { + if (!first) output_ << " "; + if (!WriteItem(member)) return false; + first = false; + } + output_ << ")"; + return true; + } + + bool WriteParameters(const Parameters& value) { + // Serializes a parameter list ([RFC8941] 4.1.1.2). + for (const auto& param_name_and_value : value) { + const std::string& param_name = param_name_and_value.first; + const Item& param_value = param_name_and_value.second; + output_ << ";"; + if (!WriteKey(param_name)) return false; + if (!param_value.is_null()) { + if (param_value.is_boolean() && param_value.GetBoolean()) continue; + output_ << "="; + if (!WriteBareItem(param_value)) return false; + } + } + return true; + } + + bool WriteKey(const std::string& value) { + // Serializes a Key ([RFC8941] 4.1.1.3). + if (value.empty()) return false; + if (value.find_first_not_of(kKeyChars) != std::string::npos) return false; + if (!absl::ascii_islower(value[0]) && value[0] != '*') return false; + output_ << value; + return true; + } + + std::ostringstream output_; +}; + +} // namespace + +Item::Item() {} +Item::Item(std::string value, Item::ItemType type) { + switch (type) { + case kStringType: + value_.emplace(std::move(value)); + break; + case kTokenType: + value_.emplace(std::move(value)); + break; + case kByteSequenceType: + value_.emplace(std::move(value)); + break; + default: + QUICHE_CHECK(false); + break; + } +} +Item::Item(const char* value, Item::ItemType type) + : Item(std::string(value), type) {} +Item::Item(int64_t value) : value_(value) {} +Item::Item(double value) : value_(value) {} +Item::Item(bool value) : value_(value) {} + +bool operator==(const Item& lhs, const Item& rhs) { + return lhs.value_ == rhs.value_; +} + +ParameterizedItem::ParameterizedItem() = default; +ParameterizedItem::ParameterizedItem(const ParameterizedItem&) = default; +ParameterizedItem& ParameterizedItem::operator=(const ParameterizedItem&) = + default; +ParameterizedItem::ParameterizedItem(Item id, Parameters ps) + : item(std::move(id)), params(std::move(ps)) {} +ParameterizedItem::~ParameterizedItem() = default; + +ParameterizedMember::ParameterizedMember() = default; +ParameterizedMember::ParameterizedMember(const ParameterizedMember&) = default; +ParameterizedMember& ParameterizedMember::operator=( + const ParameterizedMember&) = default; +ParameterizedMember::ParameterizedMember(std::vector id, + bool member_is_inner_list, + Parameters ps) + : member(std::move(id)), + member_is_inner_list(member_is_inner_list), + params(std::move(ps)) {} +ParameterizedMember::ParameterizedMember(std::vector id, + Parameters ps) + : member(std::move(id)), + member_is_inner_list(true), + params(std::move(ps)) {} +ParameterizedMember::ParameterizedMember(Item id, Parameters ps) + : member({{std::move(id), {}}}), + member_is_inner_list(false), + params(std::move(ps)) {} +ParameterizedMember::~ParameterizedMember() = default; + +ParameterisedIdentifier::ParameterisedIdentifier() = default; +ParameterisedIdentifier::ParameterisedIdentifier( + const ParameterisedIdentifier&) = default; +ParameterisedIdentifier& ParameterisedIdentifier::operator=( + const ParameterisedIdentifier&) = default; +ParameterisedIdentifier::ParameterisedIdentifier(Item id, Parameters ps) + : identifier(std::move(id)), params(std::move(ps)) {} +ParameterisedIdentifier::~ParameterisedIdentifier() = default; + +Dictionary::Dictionary() = default; +Dictionary::Dictionary(const Dictionary&) = default; +Dictionary::Dictionary(std::vector members) + : members_(std::move(members)) {} +Dictionary::~Dictionary() = default; +std::vector::iterator Dictionary::begin() { + return members_.begin(); +} +std::vector::const_iterator Dictionary::begin() const { + return members_.begin(); +} +std::vector::iterator Dictionary::end() { + return members_.end(); +} +std::vector::const_iterator Dictionary::end() const { + return members_.end(); +} +ParameterizedMember& Dictionary::operator[](std::size_t idx) { + return members_[idx].second; +} +const ParameterizedMember& Dictionary::operator[](std::size_t idx) const { + return members_[idx].second; +} +ParameterizedMember& Dictionary::at(std::size_t idx) { return (*this)[idx]; } +const ParameterizedMember& Dictionary::at(std::size_t idx) const { + return (*this)[idx]; +} +ParameterizedMember& Dictionary::operator[](absl::string_view key) { + auto it = absl::c_find_if( + members_, [key](const auto& member) { return member.first == key; }); + if (it != members_.end()) return it->second; + members_.push_back({std::string(key), ParameterizedMember()}); + return members_.back().second; +} +ParameterizedMember& Dictionary::at(absl::string_view key) { + auto it = absl::c_find_if( + members_, [key](const auto& member) { return member.first == key; }); + QUICHE_CHECK(it != members_.end()) << "Provided key not found in dictionary"; + return it->second; +} +const ParameterizedMember& Dictionary::at(absl::string_view key) const { + auto it = absl::c_find_if( + members_, [key](const auto& member) { return member.first == key; }); + QUICHE_CHECK(it != members_.end()) << "Provided key not found in dictionary"; + return it->second; +} +bool Dictionary::empty() const { return members_.empty(); } +std::size_t Dictionary::size() const { return members_.size(); } +bool Dictionary::contains(absl::string_view key) const { + for (auto& member : members_) { + if (member.first == key) return true; + } + return false; +} + +absl::optional ParseItem(absl::string_view str) { + StructuredHeaderParser parser(str, StructuredHeaderParser::kFinal); + absl::optional item = parser.ReadItem(); + if (item && parser.FinishParsing()) return item; + return absl::nullopt; +} + +absl::optional ParseBareItem(absl::string_view str) { + StructuredHeaderParser parser(str, StructuredHeaderParser::kFinal); + absl::optional item = parser.ReadBareItem(); + if (item && parser.FinishParsing()) return item; + return absl::nullopt; +} + +absl::optional ParseParameterisedList( + absl::string_view str) { + StructuredHeaderParser parser(str, StructuredHeaderParser::kDraft09); + absl::optional param_list = parser.ReadParameterisedList(); + if (param_list && parser.FinishParsing()) return param_list; + return absl::nullopt; +} + +absl::optional ParseListOfLists(absl::string_view str) { + StructuredHeaderParser parser(str, StructuredHeaderParser::kDraft09); + absl::optional list_of_lists = parser.ReadListOfLists(); + if (list_of_lists && parser.FinishParsing()) return list_of_lists; + return absl::nullopt; +} + +absl::optional ParseList(absl::string_view str) { + StructuredHeaderParser parser(str, StructuredHeaderParser::kFinal); + absl::optional list = parser.ReadList(); + if (list && parser.FinishParsing()) return list; + return absl::nullopt; +} + +absl::optional ParseDictionary(absl::string_view str) { + StructuredHeaderParser parser(str, StructuredHeaderParser::kFinal); + absl::optional dictionary = parser.ReadDictionary(); + if (dictionary && parser.FinishParsing()) return dictionary; + return absl::nullopt; +} + +absl::optional SerializeItem(const Item& value) { + StructuredHeaderSerializer s; + if (s.WriteItem(ParameterizedItem(value, {}))) return s.Output(); + return absl::nullopt; +} + +absl::optional SerializeItem(const ParameterizedItem& value) { + StructuredHeaderSerializer s; + if (s.WriteItem(value)) return s.Output(); + return absl::nullopt; +} + +absl::optional SerializeList(const List& value) { + StructuredHeaderSerializer s; + if (s.WriteList(value)) return s.Output(); + return absl::nullopt; +} + +absl::optional SerializeDictionary(const Dictionary& value) { + StructuredHeaderSerializer s; + if (s.WriteDictionary(value)) return s.Output(); + return absl::nullopt; +} + +} // namespace structured_headers +} // namespace quiche diff --git a/quiche/common/structured_headers.h b/quiche/common/structured_headers.h new file mode 100644 index 000000000000..00bf795ffdbd --- /dev/null +++ b/quiche/common/structured_headers.h @@ -0,0 +1,325 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_STRUCTURED_HEADERS_H_ +#define QUICHE_COMMON_STRUCTURED_HEADERS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quiche { +namespace structured_headers { + +// This file implements parsing of HTTP structured headers, as defined in +// RFC8941 (https://www.rfc-editor.org/rfc/rfc8941.html). For compatibility with +// the shipped implementation of Web Packaging, this file also supports a +// previous revision of the standard, referred to here as "Draft 9". +// (https://datatracker.ietf.org/doc/draft-ietf-httpbis-header-structure/09/) +// +// The major difference between the two revisions is in the various list +// formats: Draft 9 describes "parameterised lists" and "lists-of-lists", while +// the final RFC uses a single "list" syntax, whose members may be inner lists. +// There should be no ambiguity, however, as the code which calls this parser +// should be expecting only a single type for a given header. +// +// References within the code are tagged with either [SH09] or [RFC8941], +// depending on which revision they refer to. +// +// Currently supported data types are: +// Item: +// integer: 123 +// string: "abc" +// token: abc +// byte sequence: *YWJj* +// Parameterised list: abc_123;a=1;b=2; cdef_456, ghi;q="9";r="w" +// List-of-lists: "foo";"bar", "baz", "bat"; "one" +// List: "foo", "bar", "It was the best of times." +// ("foo" "bar"), ("baz"), ("bat" "one"), () +// abc;a=1;b=2; cde_456, (ghi jkl);q="9";r=w +// Dictionary: a=(1 2), b=3, c=4;aa=bb, d=(5 6);valid=?0 +// +// Functions are provided to parse each of these, which are intended to be +// called with the complete value of an HTTP header (that is, any +// sub-structure will be handled internally by the parser; the exported +// functions are not intended to be called on partial header strings.) Input +// values should be ASCII byte strings (non-ASCII characters should not be +// present in Structured Header values, and will cause the entire header to fail +// to parse.) + +class QUICHE_EXPORT Item { + public: + enum ItemType { + kNullType, + kIntegerType, + kDecimalType, + kStringType, + kTokenType, + kByteSequenceType, + kBooleanType + }; + Item(); + explicit Item(int64_t value); + explicit Item(double value); + explicit Item(bool value); + + // Constructors for string-like items: Strings, Tokens and Byte Sequences. + Item(const char* value, Item::ItemType type = kStringType); + Item(std::string value, Item::ItemType type = kStringType); + + QUICHE_EXPORT friend bool operator==(const Item& lhs, const Item& rhs); + inline friend bool operator!=(const Item& lhs, const Item& rhs) { + return !(lhs == rhs); + } + + bool is_null() const { return Type() == kNullType; } + bool is_integer() const { return Type() == kIntegerType; } + bool is_decimal() const { return Type() == kDecimalType; } + bool is_string() const { return Type() == kStringType; } + bool is_token() const { return Type() == kTokenType; } + bool is_byte_sequence() const { return Type() == kByteSequenceType; } + bool is_boolean() const { return Type() == kBooleanType; } + + int64_t GetInteger() const { + const auto* value = absl::get_if(&value_); + QUICHE_CHECK(value); + return *value; + } + double GetDecimal() const { + const auto* value = absl::get_if(&value_); + QUICHE_CHECK(value); + return *value; + } + bool GetBoolean() const { + const auto* value = absl::get_if(&value_); + QUICHE_CHECK(value); + return *value; + } + // TODO(iclelland): Split up accessors for String, Token and Byte Sequence. + const std::string& GetString() const { + struct Visitor { + const std::string* operator()(const absl::monostate&) { return nullptr; } + const std::string* operator()(const int64_t&) { return nullptr; } + const std::string* operator()(const double&) { return nullptr; } + const std::string* operator()(const std::string& value) { return &value; } + const std::string* operator()(const bool&) { return nullptr; } + }; + const std::string* value = absl::visit(Visitor(), value_); + QUICHE_CHECK(value); + return *value; + } + + // Transfers ownership of the underlying String, Token, or Byte Sequence. + std::string TakeString() && { + struct Visitor { + std::string* operator()(absl::monostate&) { return nullptr; } + std::string* operator()(int64_t&) { return nullptr; } + std::string* operator()(double&) { return nullptr; } + std::string* operator()(std::string& value) { return &value; } + std::string* operator()(bool&) { return nullptr; } + }; + std::string* value = absl::visit(Visitor(), value_); + QUICHE_CHECK(value); + return std::move(*value); + } + + ItemType Type() const { return static_cast(value_.index()); } + + private: + absl::variant + value_; +}; + +// Holds a ParameterizedIdentifier (draft 9 only). The contained Item must be a +// Token, and there may be any number of parameters. Parameter ordering is not +// significant. +struct QUICHE_EXPORT ParameterisedIdentifier { + using Parameters = std::map; + + Item identifier; + Parameters params; + + ParameterisedIdentifier(); + ParameterisedIdentifier(const ParameterisedIdentifier&); + ParameterisedIdentifier& operator=(const ParameterisedIdentifier&); + ParameterisedIdentifier(Item, Parameters); + ~ParameterisedIdentifier(); +}; + +inline bool operator==(const ParameterisedIdentifier& lhs, + const ParameterisedIdentifier& rhs) { + return std::tie(lhs.identifier, lhs.params) == + std::tie(rhs.identifier, rhs.params); +} + +using Parameters = std::vector>; + +struct QUICHE_EXPORT ParameterizedItem { + Item item; + Parameters params; + + ParameterizedItem(); + ParameterizedItem(const ParameterizedItem&); + ParameterizedItem& operator=(const ParameterizedItem&); + ParameterizedItem(Item, Parameters); + ~ParameterizedItem(); +}; + +inline bool operator==(const ParameterizedItem& lhs, + const ParameterizedItem& rhs) { + return std::tie(lhs.item, lhs.params) == std::tie(rhs.item, rhs.params); +} + +inline bool operator!=(const ParameterizedItem& lhs, + const ParameterizedItem& rhs) { + return !(lhs == rhs); +} + +// Holds a ParameterizedMember, which may be either an single Item, or an Inner +// List of ParameterizedItems, along with any number of parameters. Parameter +// ordering is significant. +struct QUICHE_EXPORT ParameterizedMember { + std::vector member; + // If false, then |member| should only hold one Item. + bool member_is_inner_list = false; + + Parameters params; + + ParameterizedMember(); + ParameterizedMember(const ParameterizedMember&); + ParameterizedMember& operator=(const ParameterizedMember&); + ParameterizedMember(std::vector, bool member_is_inner_list, + Parameters); + // Shorthand constructor for a member which is an inner list. + ParameterizedMember(std::vector, Parameters); + // Shorthand constructor for a member which is a single Item. + ParameterizedMember(Item, Parameters); + ~ParameterizedMember(); +}; + +inline bool operator==(const ParameterizedMember& lhs, + const ParameterizedMember& rhs) { + return std::tie(lhs.member, lhs.member_is_inner_list, lhs.params) == + std::tie(rhs.member, rhs.member_is_inner_list, rhs.params); +} + +using DictionaryMember = std::pair; + +// Structured Headers RFC8941 Dictionary. +class QUICHE_EXPORT Dictionary { + public: + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + Dictionary(); + Dictionary(const Dictionary&); + explicit Dictionary(std::vector members); + ~Dictionary(); + Dictionary& operator=(const Dictionary&) = default; + iterator begin(); + const_iterator begin() const; + iterator end(); + const_iterator end() const; + + // operator[](size_t) and at(size_t) will both abort the program in case of + // out of bounds access. + ParameterizedMember& operator[](std::size_t idx); + const ParameterizedMember& operator[](std::size_t idx) const; + ParameterizedMember& at(std::size_t idx); + const ParameterizedMember& at(std::size_t idx) const; + + // Consistent with std::map, if |key| does not exist in the Dictionary, then + // operator[](absl::string_view) will create an entry for it, but at() will + // abort the entire program. + ParameterizedMember& operator[](absl::string_view key); + ParameterizedMember& at(absl::string_view key); + const ParameterizedMember& at(absl::string_view key) const; + + bool empty() const; + std::size_t size() const; + bool contains(absl::string_view key) const; + friend bool operator==(const Dictionary& lhs, const Dictionary& rhs); + friend bool operator!=(const Dictionary& lhs, const Dictionary& rhs); + + private: + // Uses a vector to hold pairs of key and dictionary member. This makes + // look up by index and serialization much easier. + std::vector members_; +}; + +inline bool operator==(const Dictionary& lhs, const Dictionary& rhs) { + return lhs.members_ == rhs.members_; +} + +inline bool operator!=(const Dictionary& lhs, const Dictionary& rhs) { + return !(lhs == rhs); +} + +// Structured Headers Draft 09 Parameterised List. +using ParameterisedList = std::vector; +// Structured Headers Draft 09 List of Lists. +using ListOfLists = std::vector>; +// Structured Headers RFC8941 List. +using List = std::vector; + +// Returns the result of parsing the header value as an Item, if it can be +// parsed as one, or nullopt if it cannot. Note that this uses the Draft 15 +// parsing rules, and so applies tighter range limits to integers. +QUICHE_EXPORT absl::optional ParseItem( + absl::string_view str); + +// Returns the result of parsing the header value as an Item with no parameters, +// or nullopt if it cannot. Note that this uses the Draft 15 parsing rules, and +// so applies tighter range limits to integers. +QUICHE_EXPORT absl::optional ParseBareItem(absl::string_view str); + +// Returns the result of parsing the header value as a Parameterised List, if it +// can be parsed as one, or nullopt if it cannot. Note that parameter keys will +// be returned as strings, which are guaranteed to be ASCII-encoded. List items, +// as well as parameter values, will be returned as Items. This method uses the +// Draft 09 parsing rules for Items, so integers have the 64-bit int range. +// Structured-Headers Draft 09 only. +QUICHE_EXPORT absl::optional ParseParameterisedList( + absl::string_view str); + +// Returns the result of parsing the header value as a List of Lists, if it can +// be parsed as one, or nullopt if it cannot. Inner list items will be returned +// as Items. This method uses the Draft 09 parsing rules for Items, so integers +// have the 64-bit int range. +// Structured-Headers Draft 09 only. +QUICHE_EXPORT absl::optional ParseListOfLists( + absl::string_view str); + +// Returns the result of parsing the header value as a general List, if it can +// be parsed as one, or nullopt if it cannot. +// Structured-Headers Draft 15 only. +QUICHE_EXPORT absl::optional ParseList(absl::string_view str); + +// Returns the result of parsing the header value as a general Dictionary, if it +// can be parsed as one, or nullopt if it cannot. Structured-Headers Draft 15 +// only. +QUICHE_EXPORT absl::optional ParseDictionary(absl::string_view str); + +// Serialization is implemented for Structured-Headers Draft 15 only. +QUICHE_EXPORT absl::optional SerializeItem(const Item& value); +QUICHE_EXPORT absl::optional SerializeItem( + const ParameterizedItem& value); +QUICHE_EXPORT absl::optional SerializeList(const List& value); +QUICHE_EXPORT absl::optional SerializeDictionary( + const Dictionary& value); + +} // namespace structured_headers +} // namespace quiche + +#endif // QUICHE_COMMON_STRUCTURED_HEADERS_H_ diff --git a/quiche/common/structured_headers_fuzzer.cc b/quiche/common/structured_headers_fuzzer.cc new file mode 100644 index 000000000000..aefeea5e6a7e --- /dev/null +++ b/quiche/common/structured_headers_fuzzer.cc @@ -0,0 +1,22 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "absl/strings/string_view.h" +#include "quiche/common/structured_headers.h" + +namespace quiche { +namespace structured_headers { + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + absl::string_view input(reinterpret_cast(data), size); + ParseItem(input); + ParseListOfLists(input); + ParseList(input); + ParseDictionary(input); + ParseParameterisedList(input); + return 0; +} + +} // namespace structured_headers +} // namespace quiche diff --git a/quiche/common/structured_headers_generated_test.cc b/quiche/common/structured_headers_generated_test.cc new file mode 100644 index 000000000000..7e94e49fce07 --- /dev/null +++ b/quiche/common/structured_headers_generated_test.cc @@ -0,0 +1,3944 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/structured_headers.h" + +// This file contains tests cases for the Structured Header parser and +// serializer, taken from the public test case repository at +// https://github.com/httpwg/structured-field-tests. All of the tests are named, +// so a given test case can be found in the JSON files in that repository by +// searching for the test name. This file is generated, with the test cases +// being automatically translated from the JSON source to C++ unit tests. Please +// do not modify, as the contents will be overwritten when this is re-generated. + +// Generated on 2022-03-15 from structured-field-tests.git @ +// faed1f92942abd4fb5d61b1f9f0dc359f499f1d7. + +namespace quiche { +namespace structured_headers { +namespace { + +// Helpers to make test cases clearer + +Item Integer(int64_t value) { return Item(value); } + +std::pair BooleanParam(std::string key, bool value) { + return std::make_pair(key, Item(value)); +} + +std::pair DoubleParam(std::string key, double value) { + return std::make_pair(key, Item(value)); +} + +std::pair Param(std::string key, int64_t value) { + return std::make_pair(key, Item(value)); +} + +std::pair Param(std::string key, std::string value) { + return std::make_pair(key, Item(value)); +} + +std::pair TokenParam(std::string key, std::string value) { + return std::make_pair(key, Item(value, Item::kTokenType)); +} + +const struct ParameterizedItemTestCase { + const char* name; + const char* raw; + size_t raw_len; + const absl::optional + expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} parameterized_item_test_cases[] = { + // binary.json + {"basic binary", + ":aGVsbG8=:", + 10, + {{Item("hello", Item::kByteSequenceType), {}}}, + nullptr}, + {"empty binary", + "::", + 2, + {{Item("", Item::kByteSequenceType), {}}}, + nullptr}, + {"bad paddding", + ":aGVsbG8:", + 9, + {{Item("hello", Item::kByteSequenceType), {}}}, + ":aGVsbG8=:"}, + {"bad end delimiter", ":aGVsbG8=", 9, absl::nullopt, nullptr}, + {"extra whitespace", ":aGVsb G8=:", 11, absl::nullopt, nullptr}, + {"extra chars", ":aGVsbG!8=:", 11, absl::nullopt, nullptr}, + {"suffix chars", ":aGVsbG8=!:", 11, absl::nullopt, nullptr}, + {"non-zero pad bits", + ":iZ==:", + 6, + {{Item("\211", Item::kByteSequenceType), {}}}, + ":iQ==:"}, + {"non-ASCII binary", + ":/+Ah:", + 6, + {{Item("\377\340!", Item::kByteSequenceType), {}}}, + nullptr}, + {"base64url binary", ":_-Ah:", 6, absl::nullopt, nullptr}, + // boolean.json + {"basic true boolean", "?1", 2, {{Item(true), {}}}, nullptr}, + {"basic false boolean", "?0", 2, {{Item(false), {}}}, nullptr}, + {"unknown boolean", "?Q", 2, absl::nullopt, nullptr}, + {"whitespace boolean", "? 1", 3, absl::nullopt, nullptr}, + {"negative zero boolean", "?-0", 3, absl::nullopt, nullptr}, + {"T boolean", "?T", 2, absl::nullopt, nullptr}, + {"F boolean", "?F", 2, absl::nullopt, nullptr}, + {"t boolean", "?t", 2, absl::nullopt, nullptr}, + {"f boolean", "?f", 2, absl::nullopt, nullptr}, + {"spelled-out True boolean", "?True", 5, absl::nullopt, nullptr}, + {"spelled-out False boolean", "?False", 6, absl::nullopt, nullptr}, + // examples.json + {"Foo-Example", + "2; foourl=\"https://foo.example.com/\"", + 36, + {{Integer(2), {Param("foourl", "https://foo.example.com/")}}}, + "2;foourl=\"https://foo.example.com/\""}, + {"Example-IntHeader", + "1; a; b=?0", + 10, + {{Integer(1), {BooleanParam("a", true), BooleanParam("b", false)}}}, + "1;a;b=?0"}, + {"Example-IntItemHeader", "5", 1, {{Integer(5), {}}}, nullptr}, + {"Example-IntItemHeader (params)", + "5; foo=bar", + 10, + {{Integer(5), {TokenParam("foo", "bar")}}}, + "5;foo=bar"}, + {"Example-IntegerHeader", "42", 2, {{Integer(42), {}}}, nullptr}, + {"Example-FloatHeader", "4.5", 3, {{Item(4.500000), {}}}, nullptr}, + {"Example-StringHeader", + "\"hello world\"", + 13, + {{Item("hello world"), {}}}, + nullptr}, + {"Example-BinaryHdr", + ":cHJldGVuZCB0aGlzIGlzIGJpbmFyeSBjb250ZW50Lg==:", + 46, + {{Item("pretend this is binary content.", Item::kByteSequenceType), {}}}, + nullptr}, + {"Example-BoolHdr", "?1", 2, {{Item(true), {}}}, nullptr}, + // item.json + {"empty item", "", 0, absl::nullopt, nullptr}, + {"leading space", " \t 1", 4, absl::nullopt, nullptr}, + {"trailing space", "1 \t ", 4, absl::nullopt, nullptr}, + {"leading and trailing space", " 1 ", 5, {{Integer(1), {}}}, "1"}, + {"leading and trailing whitespace", " 1 ", 8, {{Integer(1), {}}}, "1"}, + // number-generated.json + {"1 digits of zero", "0", 1, {{Integer(0), {}}}, "0"}, + {"1 digit small integer", "1", 1, {{Integer(1), {}}}, nullptr}, + {"1 digit large integer", "9", 1, {{Integer(9), {}}}, nullptr}, + {"2 digits of zero", "00", 2, {{Integer(0), {}}}, "0"}, + {"2 digit small integer", "11", 2, {{Integer(11), {}}}, nullptr}, + {"2 digit large integer", "99", 2, {{Integer(99), {}}}, nullptr}, + {"3 digits of zero", "000", 3, {{Integer(0), {}}}, "0"}, + {"3 digit small integer", "111", 3, {{Integer(111), {}}}, nullptr}, + {"3 digit large integer", "999", 3, {{Integer(999), {}}}, nullptr}, + {"4 digits of zero", "0000", 4, {{Integer(0), {}}}, "0"}, + {"4 digit small integer", "1111", 4, {{Integer(1111), {}}}, nullptr}, + {"4 digit large integer", "9999", 4, {{Integer(9999), {}}}, nullptr}, + {"5 digits of zero", "00000", 5, {{Integer(0), {}}}, "0"}, + {"5 digit small integer", "11111", 5, {{Integer(11111), {}}}, nullptr}, + {"5 digit large integer", "99999", 5, {{Integer(99999), {}}}, nullptr}, + {"6 digits of zero", "000000", 6, {{Integer(0), {}}}, "0"}, + {"6 digit small integer", "111111", 6, {{Integer(111111), {}}}, nullptr}, + {"6 digit large integer", "999999", 6, {{Integer(999999), {}}}, nullptr}, + {"7 digits of zero", "0000000", 7, {{Integer(0), {}}}, "0"}, + {"7 digit small integer", "1111111", 7, {{Integer(1111111), {}}}, nullptr}, + {"7 digit large integer", "9999999", 7, {{Integer(9999999), {}}}, nullptr}, + {"8 digits of zero", "00000000", 8, {{Integer(0), {}}}, "0"}, + {"8 digit small integer", + "11111111", + 8, + {{Integer(11111111), {}}}, + nullptr}, + {"8 digit large integer", + "99999999", + 8, + {{Integer(99999999), {}}}, + nullptr}, + {"9 digits of zero", "000000000", 9, {{Integer(0), {}}}, "0"}, + {"9 digit small integer", + "111111111", + 9, + {{Integer(111111111), {}}}, + nullptr}, + {"9 digit large integer", + "999999999", + 9, + {{Integer(999999999), {}}}, + nullptr}, + {"10 digits of zero", "0000000000", 10, {{Integer(0), {}}}, "0"}, + {"10 digit small integer", + "1111111111", + 10, + {{Integer(1111111111), {}}}, + nullptr}, + {"10 digit large integer", + "9999999999", + 10, + {{Integer(9999999999), {}}}, + nullptr}, + {"11 digits of zero", "00000000000", 11, {{Integer(0), {}}}, "0"}, + {"11 digit small integer", + "11111111111", + 11, + {{Integer(11111111111), {}}}, + nullptr}, + {"11 digit large integer", + "99999999999", + 11, + {{Integer(99999999999), {}}}, + nullptr}, + {"12 digits of zero", "000000000000", 12, {{Integer(0), {}}}, "0"}, + {"12 digit small integer", + "111111111111", + 12, + {{Integer(111111111111), {}}}, + nullptr}, + {"12 digit large integer", + "999999999999", + 12, + {{Integer(999999999999), {}}}, + nullptr}, + {"13 digits of zero", "0000000000000", 13, {{Integer(0), {}}}, "0"}, + {"13 digit small integer", + "1111111111111", + 13, + {{Integer(1111111111111), {}}}, + nullptr}, + {"13 digit large integer", + "9999999999999", + 13, + {{Integer(9999999999999), {}}}, + nullptr}, + {"14 digits of zero", "00000000000000", 14, {{Integer(0), {}}}, "0"}, + {"14 digit small integer", + "11111111111111", + 14, + {{Integer(11111111111111), {}}}, + nullptr}, + {"14 digit large integer", + "99999999999999", + 14, + {{Integer(99999999999999), {}}}, + nullptr}, + {"15 digits of zero", "000000000000000", 15, {{Integer(0), {}}}, "0"}, + {"15 digit small integer", + "111111111111111", + 15, + {{Integer(111111111111111), {}}}, + nullptr}, + {"15 digit large integer", + "999999999999999", + 15, + {{Integer(999999999999999), {}}}, + nullptr}, + {"2 digit 0, 1 fractional small decimal", + "0.1", + 3, + {{Item(0.100000), {}}}, + "0.1"}, + {"2 digit, 1 fractional 0 decimal", + "1.0", + 3, + {{Item(1.000000), {}}}, + "1.0"}, + {"2 digit, 1 fractional small decimal", + "1.1", + 3, + {{Item(1.100000), {}}}, + nullptr}, + {"2 digit, 1 fractional large decimal", + "9.9", + 3, + {{Item(9.900000), {}}}, + nullptr}, + {"3 digit 0, 2 fractional small decimal", + "0.11", + 4, + {{Item(0.110000), {}}}, + "0.11"}, + {"3 digit, 2 fractional 0 decimal", + "1.00", + 4, + {{Item(1.000000), {}}}, + "1.0"}, + {"3 digit, 2 fractional small decimal", + "1.11", + 4, + {{Item(1.110000), {}}}, + nullptr}, + {"3 digit, 2 fractional large decimal", + "9.99", + 4, + {{Item(9.990000), {}}}, + nullptr}, + {"4 digit 0, 3 fractional small decimal", + "0.111", + 5, + {{Item(0.111000), {}}}, + "0.111"}, + {"4 digit, 3 fractional 0 decimal", + "1.000", + 5, + {{Item(1.000000), {}}}, + "1.0"}, + {"4 digit, 3 fractional small decimal", + "1.111", + 5, + {{Item(1.111000), {}}}, + nullptr}, + {"4 digit, 3 fractional large decimal", + "9.999", + 5, + {{Item(9.999000), {}}}, + nullptr}, + {"3 digit 0, 1 fractional small decimal", + "00.1", + 4, + {{Item(0.100000), {}}}, + "0.1"}, + {"3 digit, 1 fractional 0 decimal", + "11.0", + 4, + {{Item(11.000000), {}}}, + "11.0"}, + {"3 digit, 1 fractional small decimal", + "11.1", + 4, + {{Item(11.100000), {}}}, + nullptr}, + {"3 digit, 1 fractional large decimal", + "99.9", + 4, + {{Item(99.900000), {}}}, + nullptr}, + {"4 digit 0, 2 fractional small decimal", + "00.11", + 5, + {{Item(0.110000), {}}}, + "0.11"}, + {"4 digit, 2 fractional 0 decimal", + "11.00", + 5, + {{Item(11.000000), {}}}, + "11.0"}, + {"4 digit, 2 fractional small decimal", + "11.11", + 5, + {{Item(11.110000), {}}}, + nullptr}, + {"4 digit, 2 fractional large decimal", + "99.99", + 5, + {{Item(99.990000), {}}}, + nullptr}, + {"5 digit 0, 3 fractional small decimal", + "00.111", + 6, + {{Item(0.111000), {}}}, + "0.111"}, + {"5 digit, 3 fractional 0 decimal", + "11.000", + 6, + {{Item(11.000000), {}}}, + "11.0"}, + {"5 digit, 3 fractional small decimal", + "11.111", + 6, + {{Item(11.111000), {}}}, + nullptr}, + {"5 digit, 3 fractional large decimal", + "99.999", + 6, + {{Item(99.999000), {}}}, + nullptr}, + {"4 digit 0, 1 fractional small decimal", + "000.1", + 5, + {{Item(0.100000), {}}}, + "0.1"}, + {"4 digit, 1 fractional 0 decimal", + "111.0", + 5, + {{Item(111.000000), {}}}, + "111.0"}, + {"4 digit, 1 fractional small decimal", + "111.1", + 5, + {{Item(111.100000), {}}}, + nullptr}, + {"4 digit, 1 fractional large decimal", + "999.9", + 5, + {{Item(999.900000), {}}}, + nullptr}, + {"5 digit 0, 2 fractional small decimal", + "000.11", + 6, + {{Item(0.110000), {}}}, + "0.11"}, + {"5 digit, 2 fractional 0 decimal", + "111.00", + 6, + {{Item(111.000000), {}}}, + "111.0"}, + {"5 digit, 2 fractional small decimal", + "111.11", + 6, + {{Item(111.110000), {}}}, + nullptr}, + {"5 digit, 2 fractional large decimal", + "999.99", + 6, + {{Item(999.990000), {}}}, + nullptr}, + {"6 digit 0, 3 fractional small decimal", + "000.111", + 7, + {{Item(0.111000), {}}}, + "0.111"}, + {"6 digit, 3 fractional 0 decimal", + "111.000", + 7, + {{Item(111.000000), {}}}, + "111.0"}, + {"6 digit, 3 fractional small decimal", + "111.111", + 7, + {{Item(111.111000), {}}}, + nullptr}, + {"6 digit, 3 fractional large decimal", + "999.999", + 7, + {{Item(999.999000), {}}}, + nullptr}, + {"5 digit 0, 1 fractional small decimal", + "0000.1", + 6, + {{Item(0.100000), {}}}, + "0.1"}, + {"5 digit, 1 fractional 0 decimal", + "1111.0", + 6, + {{Item(1111.000000), {}}}, + "1111.0"}, + {"5 digit, 1 fractional small decimal", + "1111.1", + 6, + {{Item(1111.100000), {}}}, + nullptr}, + {"5 digit, 1 fractional large decimal", + "9999.9", + 6, + {{Item(9999.900000), {}}}, + nullptr}, + {"6 digit 0, 2 fractional small decimal", + "0000.11", + 7, + {{Item(0.110000), {}}}, + "0.11"}, + {"6 digit, 2 fractional 0 decimal", + "1111.00", + 7, + {{Item(1111.000000), {}}}, + "1111.0"}, + {"6 digit, 2 fractional small decimal", + "1111.11", + 7, + {{Item(1111.110000), {}}}, + nullptr}, + {"6 digit, 2 fractional large decimal", + "9999.99", + 7, + {{Item(9999.990000), {}}}, + nullptr}, + {"7 digit 0, 3 fractional small decimal", + "0000.111", + 8, + {{Item(0.111000), {}}}, + "0.111"}, + {"7 digit, 3 fractional 0 decimal", + "1111.000", + 8, + {{Item(1111.000000), {}}}, + "1111.0"}, + {"7 digit, 3 fractional small decimal", + "1111.111", + 8, + {{Item(1111.111000), {}}}, + nullptr}, + {"7 digit, 3 fractional large decimal", + "9999.999", + 8, + {{Item(9999.999000), {}}}, + nullptr}, + {"6 digit 0, 1 fractional small decimal", + "00000.1", + 7, + {{Item(0.100000), {}}}, + "0.1"}, + {"6 digit, 1 fractional 0 decimal", + "11111.0", + 7, + {{Item(11111.000000), {}}}, + "11111.0"}, + {"6 digit, 1 fractional small decimal", + "11111.1", + 7, + {{Item(11111.100000), {}}}, + nullptr}, + {"6 digit, 1 fractional large decimal", + "99999.9", + 7, + {{Item(99999.900000), {}}}, + nullptr}, + {"7 digit 0, 2 fractional small decimal", + "00000.11", + 8, + {{Item(0.110000), {}}}, + "0.11"}, + {"7 digit, 2 fractional 0 decimal", + "11111.00", + 8, + {{Item(11111.000000), {}}}, + "11111.0"}, + {"7 digit, 2 fractional small decimal", + "11111.11", + 8, + {{Item(11111.110000), {}}}, + nullptr}, + {"7 digit, 2 fractional large decimal", + "99999.99", + 8, + {{Item(99999.990000), {}}}, + nullptr}, + {"8 digit 0, 3 fractional small decimal", + "00000.111", + 9, + {{Item(0.111000), {}}}, + "0.111"}, + {"8 digit, 3 fractional 0 decimal", + "11111.000", + 9, + {{Item(11111.000000), {}}}, + "11111.0"}, + {"8 digit, 3 fractional small decimal", + "11111.111", + 9, + {{Item(11111.111000), {}}}, + nullptr}, + {"8 digit, 3 fractional large decimal", + "99999.999", + 9, + {{Item(99999.999000), {}}}, + nullptr}, + {"7 digit 0, 1 fractional small decimal", + "000000.1", + 8, + {{Item(0.100000), {}}}, + "0.1"}, + {"7 digit, 1 fractional 0 decimal", + "111111.0", + 8, + {{Item(111111.000000), {}}}, + "111111.0"}, + {"7 digit, 1 fractional small decimal", + "111111.1", + 8, + {{Item(111111.100000), {}}}, + nullptr}, + {"7 digit, 1 fractional large decimal", + "999999.9", + 8, + {{Item(999999.900000), {}}}, + nullptr}, + {"8 digit 0, 2 fractional small decimal", + "000000.11", + 9, + {{Item(0.110000), {}}}, + "0.11"}, + {"8 digit, 2 fractional 0 decimal", + "111111.00", + 9, + {{Item(111111.000000), {}}}, + "111111.0"}, + {"8 digit, 2 fractional small decimal", + "111111.11", + 9, + {{Item(111111.110000), {}}}, + nullptr}, + {"8 digit, 2 fractional large decimal", + "999999.99", + 9, + {{Item(999999.990000), {}}}, + nullptr}, + {"9 digit 0, 3 fractional small decimal", + "000000.111", + 10, + {{Item(0.111000), {}}}, + "0.111"}, + {"9 digit, 3 fractional 0 decimal", + "111111.000", + 10, + {{Item(111111.000000), {}}}, + "111111.0"}, + {"9 digit, 3 fractional small decimal", + "111111.111", + 10, + {{Item(111111.111000), {}}}, + nullptr}, + {"9 digit, 3 fractional large decimal", + "999999.999", + 10, + {{Item(999999.999000), {}}}, + nullptr}, + {"8 digit 0, 1 fractional small decimal", + "0000000.1", + 9, + {{Item(0.100000), {}}}, + "0.1"}, + {"8 digit, 1 fractional 0 decimal", + "1111111.0", + 9, + {{Item(1111111.000000), {}}}, + "1111111.0"}, + {"8 digit, 1 fractional small decimal", + "1111111.1", + 9, + {{Item(1111111.100000), {}}}, + nullptr}, + {"8 digit, 1 fractional large decimal", + "9999999.9", + 9, + {{Item(9999999.900000), {}}}, + nullptr}, + {"9 digit 0, 2 fractional small decimal", + "0000000.11", + 10, + {{Item(0.110000), {}}}, + "0.11"}, + {"9 digit, 2 fractional 0 decimal", + "1111111.00", + 10, + {{Item(1111111.000000), {}}}, + "1111111.0"}, + {"9 digit, 2 fractional small decimal", + "1111111.11", + 10, + {{Item(1111111.110000), {}}}, + nullptr}, + {"9 digit, 2 fractional large decimal", + "9999999.99", + 10, + {{Item(9999999.990000), {}}}, + nullptr}, + {"10 digit 0, 3 fractional small decimal", + "0000000.111", + 11, + {{Item(0.111000), {}}}, + "0.111"}, + {"10 digit, 3 fractional 0 decimal", + "1111111.000", + 11, + {{Item(1111111.000000), {}}}, + "1111111.0"}, + {"10 digit, 3 fractional small decimal", + "1111111.111", + 11, + {{Item(1111111.111000), {}}}, + nullptr}, + {"10 digit, 3 fractional large decimal", + "9999999.999", + 11, + {{Item(9999999.999000), {}}}, + nullptr}, + {"9 digit 0, 1 fractional small decimal", + "00000000.1", + 10, + {{Item(0.100000), {}}}, + "0.1"}, + {"9 digit, 1 fractional 0 decimal", + "11111111.0", + 10, + {{Item(11111111.000000), {}}}, + "11111111.0"}, + {"9 digit, 1 fractional small decimal", + "11111111.1", + 10, + {{Item(11111111.100000), {}}}, + nullptr}, + {"9 digit, 1 fractional large decimal", + "99999999.9", + 10, + {{Item(99999999.900000), {}}}, + nullptr}, + {"10 digit 0, 2 fractional small decimal", + "00000000.11", + 11, + {{Item(0.110000), {}}}, + "0.11"}, + {"10 digit, 2 fractional 0 decimal", + "11111111.00", + 11, + {{Item(11111111.000000), {}}}, + "11111111.0"}, + {"10 digit, 2 fractional small decimal", + "11111111.11", + 11, + {{Item(11111111.110000), {}}}, + nullptr}, + {"10 digit, 2 fractional large decimal", + "99999999.99", + 11, + {{Item(99999999.990000), {}}}, + nullptr}, + {"11 digit 0, 3 fractional small decimal", + "00000000.111", + 12, + {{Item(0.111000), {}}}, + "0.111"}, + {"11 digit, 3 fractional 0 decimal", + "11111111.000", + 12, + {{Item(11111111.000000), {}}}, + "11111111.0"}, + {"11 digit, 3 fractional small decimal", + "11111111.111", + 12, + {{Item(11111111.111000), {}}}, + nullptr}, + {"11 digit, 3 fractional large decimal", + "99999999.999", + 12, + {{Item(99999999.999000), {}}}, + nullptr}, + {"10 digit 0, 1 fractional small decimal", + "000000000.1", + 11, + {{Item(0.100000), {}}}, + "0.1"}, + {"10 digit, 1 fractional 0 decimal", + "111111111.0", + 11, + {{Item(111111111.000000), {}}}, + "111111111.0"}, + {"10 digit, 1 fractional small decimal", + "111111111.1", + 11, + {{Item(111111111.100000), {}}}, + nullptr}, + {"10 digit, 1 fractional large decimal", + "999999999.9", + 11, + {{Item(999999999.900000), {}}}, + nullptr}, + {"11 digit 0, 2 fractional small decimal", + "000000000.11", + 12, + {{Item(0.110000), {}}}, + "0.11"}, + {"11 digit, 2 fractional 0 decimal", + "111111111.00", + 12, + {{Item(111111111.000000), {}}}, + "111111111.0"}, + {"11 digit, 2 fractional small decimal", + "111111111.11", + 12, + {{Item(111111111.110000), {}}}, + nullptr}, + {"11 digit, 2 fractional large decimal", + "999999999.99", + 12, + {{Item(999999999.990000), {}}}, + nullptr}, + {"12 digit 0, 3 fractional small decimal", + "000000000.111", + 13, + {{Item(0.111000), {}}}, + "0.111"}, + {"12 digit, 3 fractional 0 decimal", + "111111111.000", + 13, + {{Item(111111111.000000), {}}}, + "111111111.0"}, + {"12 digit, 3 fractional small decimal", + "111111111.111", + 13, + {{Item(111111111.111000), {}}}, + nullptr}, + {"12 digit, 3 fractional large decimal", + "999999999.999", + 13, + {{Item(999999999.999000), {}}}, + nullptr}, + {"11 digit 0, 1 fractional small decimal", + "0000000000.1", + 12, + {{Item(0.100000), {}}}, + "0.1"}, + {"11 digit, 1 fractional 0 decimal", + "1111111111.0", + 12, + {{Item(1111111111.000000), {}}}, + "1111111111.0"}, + {"11 digit, 1 fractional small decimal", + "1111111111.1", + 12, + {{Item(1111111111.100000), {}}}, + nullptr}, + {"11 digit, 1 fractional large decimal", + "9999999999.9", + 12, + {{Item(9999999999.900000), {}}}, + nullptr}, + {"12 digit 0, 2 fractional small decimal", + "0000000000.11", + 13, + {{Item(0.110000), {}}}, + "0.11"}, + {"12 digit, 2 fractional 0 decimal", + "1111111111.00", + 13, + {{Item(1111111111.000000), {}}}, + "1111111111.0"}, + {"12 digit, 2 fractional small decimal", + "1111111111.11", + 13, + {{Item(1111111111.110000), {}}}, + nullptr}, + {"12 digit, 2 fractional large decimal", + "9999999999.99", + 13, + {{Item(9999999999.990000), {}}}, + nullptr}, + {"13 digit 0, 3 fractional small decimal", + "0000000000.111", + 14, + {{Item(0.111000), {}}}, + "0.111"}, + {"13 digit, 3 fractional 0 decimal", + "1111111111.000", + 14, + {{Item(1111111111.000000), {}}}, + "1111111111.0"}, + {"13 digit, 3 fractional small decimal", + "1111111111.111", + 14, + {{Item(1111111111.111000), {}}}, + nullptr}, + {"13 digit, 3 fractional large decimal", + "9999999999.999", + 14, + {{Item(9999999999.999001), {}}}, + nullptr}, + {"12 digit 0, 1 fractional small decimal", + "00000000000.1", + 13, + {{Item(0.100000), {}}}, + "0.1"}, + {"12 digit, 1 fractional 0 decimal", + "11111111111.0", + 13, + {{Item(11111111111.000000), {}}}, + "11111111111.0"}, + {"12 digit, 1 fractional small decimal", + "11111111111.1", + 13, + {{Item(11111111111.100000), {}}}, + nullptr}, + {"12 digit, 1 fractional large decimal", + "99999999999.9", + 13, + {{Item(99999999999.899994), {}}}, + nullptr}, + {"13 digit 0, 2 fractional small decimal", + "00000000000.11", + 14, + {{Item(0.110000), {}}}, + "0.11"}, + {"13 digit, 2 fractional 0 decimal", + "11111111111.00", + 14, + {{Item(11111111111.000000), {}}}, + "11111111111.0"}, + {"13 digit, 2 fractional small decimal", + "11111111111.11", + 14, + {{Item(11111111111.110001), {}}}, + nullptr}, + {"13 digit, 2 fractional large decimal", + "99999999999.99", + 14, + {{Item(99999999999.990005), {}}}, + nullptr}, + {"14 digit 0, 3 fractional small decimal", + "00000000000.111", + 15, + {{Item(0.111000), {}}}, + "0.111"}, + {"14 digit, 3 fractional 0 decimal", + "11111111111.000", + 15, + {{Item(11111111111.000000), {}}}, + "11111111111.0"}, + {"14 digit, 3 fractional small decimal", + "11111111111.111", + 15, + {{Item(11111111111.111000), {}}}, + nullptr}, + {"14 digit, 3 fractional large decimal", + "99999999999.999", + 15, + {{Item(99999999999.998993), {}}}, + nullptr}, + {"13 digit 0, 1 fractional small decimal", + "000000000000.1", + 14, + {{Item(0.100000), {}}}, + "0.1"}, + {"13 digit, 1 fractional 0 decimal", + "111111111111.0", + 14, + {{Item(111111111111.000000), {}}}, + "111111111111.0"}, + {"13 digit, 1 fractional small decimal", + "111111111111.1", + 14, + {{Item(111111111111.100006), {}}}, + nullptr}, + {"13 digit, 1 fractional large decimal", + "999999999999.9", + 14, + {{Item(999999999999.900024), {}}}, + nullptr}, + {"14 digit 0, 2 fractional small decimal", + "000000000000.11", + 15, + {{Item(0.110000), {}}}, + "0.11"}, + {"14 digit, 2 fractional 0 decimal", + "111111111111.00", + 15, + {{Item(111111111111.000000), {}}}, + "111111111111.0"}, + {"14 digit, 2 fractional small decimal", + "111111111111.11", + 15, + {{Item(111111111111.110001), {}}}, + nullptr}, + {"14 digit, 2 fractional large decimal", + "999999999999.99", + 15, + {{Item(999999999999.989990), {}}}, + nullptr}, + {"15 digit 0, 3 fractional small decimal", + "000000000000.111", + 16, + {{Item(0.111000), {}}}, + "0.111"}, + {"15 digit, 3 fractional 0 decimal", + "111111111111.000", + 16, + {{Item(111111111111.000000), {}}}, + "111111111111.0"}, + {"15 digit, 3 fractional small decimal", + "111111111111.111", + 16, + {{Item(111111111111.110992), {}}}, + nullptr}, + {"15 digit, 3 fractional large decimal", + "999999999999.999", + 16, + {{Item(999999999999.999023), {}}}, + nullptr}, + {"too many digit 0 decimal", "000000000000000.0", 17, absl::nullopt, + nullptr}, + {"too many fractional digits 0 decimal", "000000000000.0000", 17, + absl::nullopt, nullptr}, + {"too many digit 9 decimal", "999999999999999.9", 17, absl::nullopt, + nullptr}, + {"too many fractional digits 9 decimal", "999999999999.9999", 17, + absl::nullopt, nullptr}, + // number.json + {"basic integer", "42", 2, {{Integer(42), {}}}, nullptr}, + {"zero integer", "0", 1, {{Integer(0), {}}}, nullptr}, + {"negative zero", "-0", 2, {{Integer(0), {}}}, "0"}, + {"double negative zero", "--0", 3, absl::nullopt, nullptr}, + {"negative integer", "-42", 3, {{Integer(-42), {}}}, nullptr}, + {"leading 0 integer", "042", 3, {{Integer(42), {}}}, "42"}, + {"leading 0 negative integer", "-042", 4, {{Integer(-42), {}}}, "-42"}, + {"leading 0 zero", "00", 2, {{Integer(0), {}}}, "0"}, + {"comma", "2,3", 3, absl::nullopt, nullptr}, + {"negative non-DIGIT first character", "-a23", 4, absl::nullopt, nullptr}, + {"sign out of place", "4-2", 3, absl::nullopt, nullptr}, + {"whitespace after sign", "- 42", 4, absl::nullopt, nullptr}, + {"long integer", + "123456789012345", + 15, + {{Integer(123456789012345), {}}}, + nullptr}, + {"long negative integer", + "-123456789012345", + 16, + {{Integer(-123456789012345), {}}}, + nullptr}, + {"too long integer", "1234567890123456", 16, absl::nullopt, nullptr}, + {"negative too long integer", "-1234567890123456", 17, absl::nullopt, + nullptr}, + {"simple decimal", "1.23", 4, {{Item(1.230000), {}}}, nullptr}, + {"negative decimal", "-1.23", 5, {{Item(-1.230000), {}}}, nullptr}, + {"decimal, whitespace after decimal", "1. 23", 5, absl::nullopt, nullptr}, + {"decimal, whitespace before decimal", "1 .23", 5, absl::nullopt, nullptr}, + {"negative decimal, whitespace after sign", "- 1.23", 6, absl::nullopt, + nullptr}, + {"tricky precision decimal", + "123456789012.1", + 14, + {{Item(123456789012.100006), {}}}, + nullptr}, + {"double decimal decimal", "1.5.4", 5, absl::nullopt, nullptr}, + {"adjacent double decimal decimal", "1..4", 4, absl::nullopt, nullptr}, + {"decimal with three fractional digits", + "1.123", + 5, + {{Item(1.123000), {}}}, + nullptr}, + {"negative decimal with three fractional digits", + "-1.123", + 6, + {{Item(-1.123000), {}}}, + nullptr}, + {"decimal with four fractional digits", "1.1234", 6, absl::nullopt, + nullptr}, + {"negative decimal with four fractional digits", "-1.1234", 7, + absl::nullopt, nullptr}, + {"decimal with thirteen integer digits", "1234567890123.0", 15, + absl::nullopt, nullptr}, + {"negative decimal with thirteen integer digits", "-1234567890123.0", 16, + absl::nullopt, nullptr}, + // string-generated.json + {"0x00 in string", "\" \000 \"", 5, absl::nullopt, nullptr}, + {"0x01 in string", "\" \001 \"", 5, absl::nullopt, nullptr}, + {"0x02 in string", "\" \002 \"", 5, absl::nullopt, nullptr}, + {"0x03 in string", "\" \003 \"", 5, absl::nullopt, nullptr}, + {"0x04 in string", "\" \004 \"", 5, absl::nullopt, nullptr}, + {"0x05 in string", "\" \005 \"", 5, absl::nullopt, nullptr}, + {"0x06 in string", "\" \006 \"", 5, absl::nullopt, nullptr}, + {"0x07 in string", "\" \a \"", 5, absl::nullopt, nullptr}, + {"0x08 in string", "\" \b \"", 5, absl::nullopt, nullptr}, + {"0x09 in string", "\" \t \"", 5, absl::nullopt, nullptr}, + {"0x0a in string", "\" \n \"", 5, absl::nullopt, nullptr}, + {"0x0b in string", "\" \v \"", 5, absl::nullopt, nullptr}, + {"0x0c in string", "\" \f \"", 5, absl::nullopt, nullptr}, + {"0x0d in string", "\" \r \"", 5, absl::nullopt, nullptr}, + {"0x0e in string", "\" \016 \"", 5, absl::nullopt, nullptr}, + {"0x0f in string", "\" \017 \"", 5, absl::nullopt, nullptr}, + {"0x10 in string", "\" \020 \"", 5, absl::nullopt, nullptr}, + {"0x11 in string", "\" \021 \"", 5, absl::nullopt, nullptr}, + {"0x12 in string", "\" \022 \"", 5, absl::nullopt, nullptr}, + {"0x13 in string", "\" \023 \"", 5, absl::nullopt, nullptr}, + {"0x14 in string", "\" \024 \"", 5, absl::nullopt, nullptr}, + {"0x15 in string", "\" \025 \"", 5, absl::nullopt, nullptr}, + {"0x16 in string", "\" \026 \"", 5, absl::nullopt, nullptr}, + {"0x17 in string", "\" \027 \"", 5, absl::nullopt, nullptr}, + {"0x18 in string", "\" \030 \"", 5, absl::nullopt, nullptr}, + {"0x19 in string", "\" \031 \"", 5, absl::nullopt, nullptr}, + {"0x1a in string", "\" \032 \"", 5, absl::nullopt, nullptr}, + {"0x1b in string", "\" \033 \"", 5, absl::nullopt, nullptr}, + {"0x1c in string", "\" \034 \"", 5, absl::nullopt, nullptr}, + {"0x1d in string", "\" \035 \"", 5, absl::nullopt, nullptr}, + {"0x1e in string", "\" \036 \"", 5, absl::nullopt, nullptr}, + {"0x1f in string", "\" \037 \"", 5, absl::nullopt, nullptr}, + {"0x20 in string", "\" \"", 5, {{Item(" "), {}}}, nullptr}, + {"0x21 in string", "\" ! \"", 5, {{Item(" ! "), {}}}, nullptr}, + {"0x22 in string", "\" \" \"", 5, absl::nullopt, nullptr}, + {"0x23 in string", "\" # \"", 5, {{Item(" # "), {}}}, nullptr}, + {"0x24 in string", "\" $ \"", 5, {{Item(" $ "), {}}}, nullptr}, + {"0x25 in string", "\" % \"", 5, {{Item(" % "), {}}}, nullptr}, + {"0x26 in string", "\" & \"", 5, {{Item(" & "), {}}}, nullptr}, + {"0x27 in string", "\" ' \"", 5, {{Item(" ' "), {}}}, nullptr}, + {"0x28 in string", "\" ( \"", 5, {{Item(" ( "), {}}}, nullptr}, + {"0x29 in string", "\" ) \"", 5, {{Item(" ) "), {}}}, nullptr}, + {"0x2a in string", "\" * \"", 5, {{Item(" * "), {}}}, nullptr}, + {"0x2b in string", "\" + \"", 5, {{Item(" + "), {}}}, nullptr}, + {"0x2c in string", "\" , \"", 5, {{Item(" , "), {}}}, nullptr}, + {"0x2d in string", "\" - \"", 5, {{Item(" - "), {}}}, nullptr}, + {"0x2e in string", "\" . \"", 5, {{Item(" . "), {}}}, nullptr}, + {"0x2f in string", "\" / \"", 5, {{Item(" / "), {}}}, nullptr}, + {"0x30 in string", "\" 0 \"", 5, {{Item(" 0 "), {}}}, nullptr}, + {"0x31 in string", "\" 1 \"", 5, {{Item(" 1 "), {}}}, nullptr}, + {"0x32 in string", "\" 2 \"", 5, {{Item(" 2 "), {}}}, nullptr}, + {"0x33 in string", "\" 3 \"", 5, {{Item(" 3 "), {}}}, nullptr}, + {"0x34 in string", "\" 4 \"", 5, {{Item(" 4 "), {}}}, nullptr}, + {"0x35 in string", "\" 5 \"", 5, {{Item(" 5 "), {}}}, nullptr}, + {"0x36 in string", "\" 6 \"", 5, {{Item(" 6 "), {}}}, nullptr}, + {"0x37 in string", "\" 7 \"", 5, {{Item(" 7 "), {}}}, nullptr}, + {"0x38 in string", "\" 8 \"", 5, {{Item(" 8 "), {}}}, nullptr}, + {"0x39 in string", "\" 9 \"", 5, {{Item(" 9 "), {}}}, nullptr}, + {"0x3a in string", "\" : \"", 5, {{Item(" : "), {}}}, nullptr}, + {"0x3b in string", "\" ; \"", 5, {{Item(" ; "), {}}}, nullptr}, + {"0x3c in string", "\" < \"", 5, {{Item(" < "), {}}}, nullptr}, + {"0x3d in string", "\" = \"", 5, {{Item(" = "), {}}}, nullptr}, + {"0x3e in string", "\" > \"", 5, {{Item(" > "), {}}}, nullptr}, + {"0x3f in string", "\" ? \"", 5, {{Item(" ? "), {}}}, nullptr}, + {"0x40 in string", "\" @ \"", 5, {{Item(" @ "), {}}}, nullptr}, + {"0x41 in string", "\" A \"", 5, {{Item(" A "), {}}}, nullptr}, + {"0x42 in string", "\" B \"", 5, {{Item(" B "), {}}}, nullptr}, + {"0x43 in string", "\" C \"", 5, {{Item(" C "), {}}}, nullptr}, + {"0x44 in string", "\" D \"", 5, {{Item(" D "), {}}}, nullptr}, + {"0x45 in string", "\" E \"", 5, {{Item(" E "), {}}}, nullptr}, + {"0x46 in string", "\" F \"", 5, {{Item(" F "), {}}}, nullptr}, + {"0x47 in string", "\" G \"", 5, {{Item(" G "), {}}}, nullptr}, + {"0x48 in string", "\" H \"", 5, {{Item(" H "), {}}}, nullptr}, + {"0x49 in string", "\" I \"", 5, {{Item(" I "), {}}}, nullptr}, + {"0x4a in string", "\" J \"", 5, {{Item(" J "), {}}}, nullptr}, + {"0x4b in string", "\" K \"", 5, {{Item(" K "), {}}}, nullptr}, + {"0x4c in string", "\" L \"", 5, {{Item(" L "), {}}}, nullptr}, + {"0x4d in string", "\" M \"", 5, {{Item(" M "), {}}}, nullptr}, + {"0x4e in string", "\" N \"", 5, {{Item(" N "), {}}}, nullptr}, + {"0x4f in string", "\" O \"", 5, {{Item(" O "), {}}}, nullptr}, + {"0x50 in string", "\" P \"", 5, {{Item(" P "), {}}}, nullptr}, + {"0x51 in string", "\" Q \"", 5, {{Item(" Q "), {}}}, nullptr}, + {"0x52 in string", "\" R \"", 5, {{Item(" R "), {}}}, nullptr}, + {"0x53 in string", "\" S \"", 5, {{Item(" S "), {}}}, nullptr}, + {"0x54 in string", "\" T \"", 5, {{Item(" T "), {}}}, nullptr}, + {"0x55 in string", "\" U \"", 5, {{Item(" U "), {}}}, nullptr}, + {"0x56 in string", "\" V \"", 5, {{Item(" V "), {}}}, nullptr}, + {"0x57 in string", "\" W \"", 5, {{Item(" W "), {}}}, nullptr}, + {"0x58 in string", "\" X \"", 5, {{Item(" X "), {}}}, nullptr}, + {"0x59 in string", "\" Y \"", 5, {{Item(" Y "), {}}}, nullptr}, + {"0x5a in string", "\" Z \"", 5, {{Item(" Z "), {}}}, nullptr}, + {"0x5b in string", "\" [ \"", 5, {{Item(" [ "), {}}}, nullptr}, + {"0x5c in string", "\" \\ \"", 5, absl::nullopt, nullptr}, + {"0x5d in string", "\" ] \"", 5, {{Item(" ] "), {}}}, nullptr}, + {"0x5e in string", "\" ^ \"", 5, {{Item(" ^ "), {}}}, nullptr}, + {"0x5f in string", "\" _ \"", 5, {{Item(" _ "), {}}}, nullptr}, + {"0x60 in string", "\" ` \"", 5, {{Item(" ` "), {}}}, nullptr}, + {"0x61 in string", "\" a \"", 5, {{Item(" a "), {}}}, nullptr}, + {"0x62 in string", "\" b \"", 5, {{Item(" b "), {}}}, nullptr}, + {"0x63 in string", "\" c \"", 5, {{Item(" c "), {}}}, nullptr}, + {"0x64 in string", "\" d \"", 5, {{Item(" d "), {}}}, nullptr}, + {"0x65 in string", "\" e \"", 5, {{Item(" e "), {}}}, nullptr}, + {"0x66 in string", "\" f \"", 5, {{Item(" f "), {}}}, nullptr}, + {"0x67 in string", "\" g \"", 5, {{Item(" g "), {}}}, nullptr}, + {"0x68 in string", "\" h \"", 5, {{Item(" h "), {}}}, nullptr}, + {"0x69 in string", "\" i \"", 5, {{Item(" i "), {}}}, nullptr}, + {"0x6a in string", "\" j \"", 5, {{Item(" j "), {}}}, nullptr}, + {"0x6b in string", "\" k \"", 5, {{Item(" k "), {}}}, nullptr}, + {"0x6c in string", "\" l \"", 5, {{Item(" l "), {}}}, nullptr}, + {"0x6d in string", "\" m \"", 5, {{Item(" m "), {}}}, nullptr}, + {"0x6e in string", "\" n \"", 5, {{Item(" n "), {}}}, nullptr}, + {"0x6f in string", "\" o \"", 5, {{Item(" o "), {}}}, nullptr}, + {"0x70 in string", "\" p \"", 5, {{Item(" p "), {}}}, nullptr}, + {"0x71 in string", "\" q \"", 5, {{Item(" q "), {}}}, nullptr}, + {"0x72 in string", "\" r \"", 5, {{Item(" r "), {}}}, nullptr}, + {"0x73 in string", "\" s \"", 5, {{Item(" s "), {}}}, nullptr}, + {"0x74 in string", "\" t \"", 5, {{Item(" t "), {}}}, nullptr}, + {"0x75 in string", "\" u \"", 5, {{Item(" u "), {}}}, nullptr}, + {"0x76 in string", "\" v \"", 5, {{Item(" v "), {}}}, nullptr}, + {"0x77 in string", "\" w \"", 5, {{Item(" w "), {}}}, nullptr}, + {"0x78 in string", "\" x \"", 5, {{Item(" x "), {}}}, nullptr}, + {"0x79 in string", "\" y \"", 5, {{Item(" y "), {}}}, nullptr}, + {"0x7a in string", "\" z \"", 5, {{Item(" z "), {}}}, nullptr}, + {"0x7b in string", "\" { \"", 5, {{Item(" { "), {}}}, nullptr}, + {"0x7c in string", "\" | \"", 5, {{Item(" | "), {}}}, nullptr}, + {"0x7d in string", "\" } \"", 5, {{Item(" } "), {}}}, nullptr}, + {"0x7e in string", "\" ~ \"", 5, {{Item(" ~ "), {}}}, nullptr}, + {"0x7f in string", "\" \177 \"", 5, absl::nullopt, nullptr}, + {"Escaped 0x00 in string", "\"\\\000\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x01 in string", "\"\\\001\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x02 in string", "\"\\\002\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x03 in string", "\"\\\003\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x04 in string", "\"\\\004\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x05 in string", "\"\\\005\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x06 in string", "\"\\\006\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x07 in string", "\"\\\a\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x08 in string", "\"\\\b\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x09 in string", "\"\\\t\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x0a in string", "\"\\\n\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x0b in string", "\"\\\v\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x0c in string", "\"\\\f\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x0d in string", "\"\\\r\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x0e in string", "\"\\\016\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x0f in string", "\"\\\017\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x10 in string", "\"\\\020\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x11 in string", "\"\\\021\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x12 in string", "\"\\\022\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x13 in string", "\"\\\023\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x14 in string", "\"\\\024\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x15 in string", "\"\\\025\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x16 in string", "\"\\\026\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x17 in string", "\"\\\027\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x18 in string", "\"\\\030\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x19 in string", "\"\\\031\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x1a in string", "\"\\\032\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x1b in string", "\"\\\033\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x1c in string", "\"\\\034\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x1d in string", "\"\\\035\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x1e in string", "\"\\\036\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x1f in string", "\"\\\037\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x20 in string", "\"\\ \"", 4, absl::nullopt, nullptr}, + {"Escaped 0x21 in string", "\"\\!\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x22 in string", "\"\\\"\"", 4, {{Item("\""), {}}}, nullptr}, + {"Escaped 0x23 in string", "\"\\#\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x24 in string", "\"\\$\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x25 in string", "\"\\%\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x26 in string", "\"\\&\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x27 in string", "\"\\'\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x28 in string", "\"\\(\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x29 in string", "\"\\)\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x2a in string", "\"\\*\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x2b in string", "\"\\+\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x2c in string", "\"\\,\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x2d in string", "\"\\-\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x2e in string", "\"\\.\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x2f in string", "\"\\/\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x30 in string", "\"\\0\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x31 in string", "\"\\1\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x32 in string", "\"\\2\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x33 in string", "\"\\3\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x34 in string", "\"\\4\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x35 in string", "\"\\5\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x36 in string", "\"\\6\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x37 in string", "\"\\7\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x38 in string", "\"\\8\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x39 in string", "\"\\9\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x3a in string", "\"\\:\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x3b in string", "\"\\;\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x3c in string", "\"\\<\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x3d in string", "\"\\=\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x3e in string", "\"\\>\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x3f in string", "\"\\?\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x40 in string", "\"\\@\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x41 in string", "\"\\A\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x42 in string", "\"\\B\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x43 in string", "\"\\C\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x44 in string", "\"\\D\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x45 in string", "\"\\E\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x46 in string", "\"\\F\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x47 in string", "\"\\G\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x48 in string", "\"\\H\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x49 in string", "\"\\I\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x4a in string", "\"\\J\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x4b in string", "\"\\K\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x4c in string", "\"\\L\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x4d in string", "\"\\M\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x4e in string", "\"\\N\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x4f in string", "\"\\O\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x50 in string", "\"\\P\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x51 in string", "\"\\Q\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x52 in string", "\"\\R\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x53 in string", "\"\\S\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x54 in string", "\"\\T\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x55 in string", "\"\\U\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x56 in string", "\"\\V\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x57 in string", "\"\\W\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x58 in string", "\"\\X\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x59 in string", "\"\\Y\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x5a in string", "\"\\Z\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x5b in string", "\"\\[\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x5c in string", "\"\\\\\"", 4, {{Item("\\"), {}}}, nullptr}, + {"Escaped 0x5d in string", "\"\\]\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x5e in string", "\"\\^\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x5f in string", "\"\\_\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x60 in string", "\"\\`\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x61 in string", "\"\\a\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x62 in string", "\"\\b\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x63 in string", "\"\\c\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x64 in string", "\"\\d\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x65 in string", "\"\\e\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x66 in string", "\"\\f\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x67 in string", "\"\\g\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x68 in string", "\"\\h\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x69 in string", "\"\\i\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x6a in string", "\"\\j\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x6b in string", "\"\\k\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x6c in string", "\"\\l\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x6d in string", "\"\\m\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x6e in string", "\"\\n\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x6f in string", "\"\\o\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x70 in string", "\"\\p\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x71 in string", "\"\\q\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x72 in string", "\"\\r\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x73 in string", "\"\\s\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x74 in string", "\"\\t\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x75 in string", "\"\\u\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x76 in string", "\"\\v\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x77 in string", "\"\\w\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x78 in string", "\"\\x\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x79 in string", "\"\\y\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x7a in string", "\"\\z\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x7b in string", "\"\\{\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x7c in string", "\"\\|\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x7d in string", "\"\\}\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x7e in string", "\"\\~\"", 4, absl::nullopt, nullptr}, + {"Escaped 0x7f in string", "\"\\\177\"", 4, absl::nullopt, nullptr}, + // string.json + {"basic string", "\"foo bar\"", 9, {{Item("foo bar"), {}}}, nullptr}, + {"empty string", "\"\"", 2, {{Item(""), {}}}, nullptr}, + {"long string", + "\"foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo foo foo foo foo foo foo foo foo foo foo foo \"", + 262, + {{Item("foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo foo " + "foo "), + {}}}, + nullptr}, + {"whitespace string", "\" \"", 5, {{Item(" "), {}}}, nullptr}, + {"non-ascii string", "\"f\374\374\"", 5, absl::nullopt, nullptr}, + {"tab in string", "\"\\t\"", 4, absl::nullopt, nullptr}, + {"newline in string", "\" \\n \"", 6, absl::nullopt, nullptr}, + {"single quoted string", "'foo'", 5, absl::nullopt, nullptr}, + {"unbalanced string", "\"foo", 4, absl::nullopt, nullptr}, + {"string quoting", + "\"foo \\\"bar\\\" \\\\ baz\"", + 20, + {{Item("foo \"bar\" \\ baz"), {}}}, + nullptr}, + {"bad string quoting", "\"foo \\,\"", 8, absl::nullopt, nullptr}, + {"ending string quote", "\"foo \\\"", 7, absl::nullopt, nullptr}, + {"abruptly ending string quote", "\"foo \\", 6, absl::nullopt, nullptr}, + // token-generated.json + {"0x00 in token", "a\000a", 3, absl::nullopt, nullptr}, + {"0x01 in token", "a\001a", 3, absl::nullopt, nullptr}, + {"0x02 in token", "a\002a", 3, absl::nullopt, nullptr}, + {"0x03 in token", "a\003a", 3, absl::nullopt, nullptr}, + {"0x04 in token", "a\004a", 3, absl::nullopt, nullptr}, + {"0x05 in token", "a\005a", 3, absl::nullopt, nullptr}, + {"0x06 in token", "a\006a", 3, absl::nullopt, nullptr}, + {"0x07 in token", "a\aa", 3, absl::nullopt, nullptr}, + {"0x08 in token", "a\ba", 3, absl::nullopt, nullptr}, + {"0x09 in token", "a\ta", 3, absl::nullopt, nullptr}, + {"0x0a in token", "a\na", 3, absl::nullopt, nullptr}, + {"0x0b in token", "a\va", 3, absl::nullopt, nullptr}, + {"0x0c in token", "a\fa", 3, absl::nullopt, nullptr}, + {"0x0d in token", "a\ra", 3, absl::nullopt, nullptr}, + {"0x0e in token", "a\016a", 3, absl::nullopt, nullptr}, + {"0x0f in token", "a\017a", 3, absl::nullopt, nullptr}, + {"0x10 in token", "a\020a", 3, absl::nullopt, nullptr}, + {"0x11 in token", "a\021a", 3, absl::nullopt, nullptr}, + {"0x12 in token", "a\022a", 3, absl::nullopt, nullptr}, + {"0x13 in token", "a\023a", 3, absl::nullopt, nullptr}, + {"0x14 in token", "a\024a", 3, absl::nullopt, nullptr}, + {"0x15 in token", "a\025a", 3, absl::nullopt, nullptr}, + {"0x16 in token", "a\026a", 3, absl::nullopt, nullptr}, + {"0x17 in token", "a\027a", 3, absl::nullopt, nullptr}, + {"0x18 in token", "a\030a", 3, absl::nullopt, nullptr}, + {"0x19 in token", "a\031a", 3, absl::nullopt, nullptr}, + {"0x1a in token", "a\032a", 3, absl::nullopt, nullptr}, + {"0x1b in token", "a\033a", 3, absl::nullopt, nullptr}, + {"0x1c in token", "a\034a", 3, absl::nullopt, nullptr}, + {"0x1d in token", "a\035a", 3, absl::nullopt, nullptr}, + {"0x1e in token", "a\036a", 3, absl::nullopt, nullptr}, + {"0x1f in token", "a\037a", 3, absl::nullopt, nullptr}, + {"0x20 in token", "a a", 3, absl::nullopt, nullptr}, + {"0x21 in token", "a!a", 3, {{Item("a!a", Item::kTokenType), {}}}, nullptr}, + {"0x22 in token", "a\"a", 3, absl::nullopt, nullptr}, + {"0x23 in token", "a#a", 3, {{Item("a#a", Item::kTokenType), {}}}, nullptr}, + {"0x24 in token", "a$a", 3, {{Item("a$a", Item::kTokenType), {}}}, nullptr}, + {"0x25 in token", "a%a", 3, {{Item("a%a", Item::kTokenType), {}}}, nullptr}, + {"0x26 in token", "a&a", 3, {{Item("a&a", Item::kTokenType), {}}}, nullptr}, + {"0x27 in token", "a'a", 3, {{Item("a'a", Item::kTokenType), {}}}, nullptr}, + {"0x28 in token", "a(a", 3, absl::nullopt, nullptr}, + {"0x29 in token", "a)a", 3, absl::nullopt, nullptr}, + {"0x2a in token", "a*a", 3, {{Item("a*a", Item::kTokenType), {}}}, nullptr}, + {"0x2b in token", "a+a", 3, {{Item("a+a", Item::kTokenType), {}}}, nullptr}, + {"0x2c in token", "a,a", 3, absl::nullopt, nullptr}, + {"0x2d in token", "a-a", 3, {{Item("a-a", Item::kTokenType), {}}}, nullptr}, + {"0x2e in token", "a.a", 3, {{Item("a.a", Item::kTokenType), {}}}, nullptr}, + {"0x2f in token", "a/a", 3, {{Item("a/a", Item::kTokenType), {}}}, nullptr}, + {"0x30 in token", "a0a", 3, {{Item("a0a", Item::kTokenType), {}}}, nullptr}, + {"0x31 in token", "a1a", 3, {{Item("a1a", Item::kTokenType), {}}}, nullptr}, + {"0x32 in token", "a2a", 3, {{Item("a2a", Item::kTokenType), {}}}, nullptr}, + {"0x33 in token", "a3a", 3, {{Item("a3a", Item::kTokenType), {}}}, nullptr}, + {"0x34 in token", "a4a", 3, {{Item("a4a", Item::kTokenType), {}}}, nullptr}, + {"0x35 in token", "a5a", 3, {{Item("a5a", Item::kTokenType), {}}}, nullptr}, + {"0x36 in token", "a6a", 3, {{Item("a6a", Item::kTokenType), {}}}, nullptr}, + {"0x37 in token", "a7a", 3, {{Item("a7a", Item::kTokenType), {}}}, nullptr}, + {"0x38 in token", "a8a", 3, {{Item("a8a", Item::kTokenType), {}}}, nullptr}, + {"0x39 in token", "a9a", 3, {{Item("a9a", Item::kTokenType), {}}}, nullptr}, + {"0x3a in token", "a:a", 3, {{Item("a:a", Item::kTokenType), {}}}, nullptr}, + {"0x3b in token", + "a;a", + 3, + {{Item("a", Item::kTokenType), {BooleanParam("a", true)}}}, + nullptr}, + {"0x3c in token", "aa", 3, absl::nullopt, nullptr}, + {"0x3f in token", "a?a", 3, absl::nullopt, nullptr}, + {"0x40 in token", "a@a", 3, absl::nullopt, nullptr}, + {"0x41 in token", "aAa", 3, {{Item("aAa", Item::kTokenType), {}}}, nullptr}, + {"0x42 in token", "aBa", 3, {{Item("aBa", Item::kTokenType), {}}}, nullptr}, + {"0x43 in token", "aCa", 3, {{Item("aCa", Item::kTokenType), {}}}, nullptr}, + {"0x44 in token", "aDa", 3, {{Item("aDa", Item::kTokenType), {}}}, nullptr}, + {"0x45 in token", "aEa", 3, {{Item("aEa", Item::kTokenType), {}}}, nullptr}, + {"0x46 in token", "aFa", 3, {{Item("aFa", Item::kTokenType), {}}}, nullptr}, + {"0x47 in token", "aGa", 3, {{Item("aGa", Item::kTokenType), {}}}, nullptr}, + {"0x48 in token", "aHa", 3, {{Item("aHa", Item::kTokenType), {}}}, nullptr}, + {"0x49 in token", "aIa", 3, {{Item("aIa", Item::kTokenType), {}}}, nullptr}, + {"0x4a in token", "aJa", 3, {{Item("aJa", Item::kTokenType), {}}}, nullptr}, + {"0x4b in token", "aKa", 3, {{Item("aKa", Item::kTokenType), {}}}, nullptr}, + {"0x4c in token", "aLa", 3, {{Item("aLa", Item::kTokenType), {}}}, nullptr}, + {"0x4d in token", "aMa", 3, {{Item("aMa", Item::kTokenType), {}}}, nullptr}, + {"0x4e in token", "aNa", 3, {{Item("aNa", Item::kTokenType), {}}}, nullptr}, + {"0x4f in token", "aOa", 3, {{Item("aOa", Item::kTokenType), {}}}, nullptr}, + {"0x50 in token", "aPa", 3, {{Item("aPa", Item::kTokenType), {}}}, nullptr}, + {"0x51 in token", "aQa", 3, {{Item("aQa", Item::kTokenType), {}}}, nullptr}, + {"0x52 in token", "aRa", 3, {{Item("aRa", Item::kTokenType), {}}}, nullptr}, + {"0x53 in token", "aSa", 3, {{Item("aSa", Item::kTokenType), {}}}, nullptr}, + {"0x54 in token", "aTa", 3, {{Item("aTa", Item::kTokenType), {}}}, nullptr}, + {"0x55 in token", "aUa", 3, {{Item("aUa", Item::kTokenType), {}}}, nullptr}, + {"0x56 in token", "aVa", 3, {{Item("aVa", Item::kTokenType), {}}}, nullptr}, + {"0x57 in token", "aWa", 3, {{Item("aWa", Item::kTokenType), {}}}, nullptr}, + {"0x58 in token", "aXa", 3, {{Item("aXa", Item::kTokenType), {}}}, nullptr}, + {"0x59 in token", "aYa", 3, {{Item("aYa", Item::kTokenType), {}}}, nullptr}, + {"0x5a in token", "aZa", 3, {{Item("aZa", Item::kTokenType), {}}}, nullptr}, + {"0x5b in token", "a[a", 3, absl::nullopt, nullptr}, + {"0x5c in token", "a\\a", 3, absl::nullopt, nullptr}, + {"0x5d in token", "a]a", 3, absl::nullopt, nullptr}, + {"0x5e in token", "a^a", 3, {{Item("a^a", Item::kTokenType), {}}}, nullptr}, + {"0x5f in token", "a_a", 3, {{Item("a_a", Item::kTokenType), {}}}, nullptr}, + {"0x60 in token", "a`a", 3, {{Item("a`a", Item::kTokenType), {}}}, nullptr}, + {"0x61 in token", "aaa", 3, {{Item("aaa", Item::kTokenType), {}}}, nullptr}, + {"0x62 in token", "aba", 3, {{Item("aba", Item::kTokenType), {}}}, nullptr}, + {"0x63 in token", "aca", 3, {{Item("aca", Item::kTokenType), {}}}, nullptr}, + {"0x64 in token", "ada", 3, {{Item("ada", Item::kTokenType), {}}}, nullptr}, + {"0x65 in token", "aea", 3, {{Item("aea", Item::kTokenType), {}}}, nullptr}, + {"0x66 in token", "afa", 3, {{Item("afa", Item::kTokenType), {}}}, nullptr}, + {"0x67 in token", "aga", 3, {{Item("aga", Item::kTokenType), {}}}, nullptr}, + {"0x68 in token", "aha", 3, {{Item("aha", Item::kTokenType), {}}}, nullptr}, + {"0x69 in token", "aia", 3, {{Item("aia", Item::kTokenType), {}}}, nullptr}, + {"0x6a in token", "aja", 3, {{Item("aja", Item::kTokenType), {}}}, nullptr}, + {"0x6b in token", "aka", 3, {{Item("aka", Item::kTokenType), {}}}, nullptr}, + {"0x6c in token", "ala", 3, {{Item("ala", Item::kTokenType), {}}}, nullptr}, + {"0x6d in token", "ama", 3, {{Item("ama", Item::kTokenType), {}}}, nullptr}, + {"0x6e in token", "ana", 3, {{Item("ana", Item::kTokenType), {}}}, nullptr}, + {"0x6f in token", "aoa", 3, {{Item("aoa", Item::kTokenType), {}}}, nullptr}, + {"0x70 in token", "apa", 3, {{Item("apa", Item::kTokenType), {}}}, nullptr}, + {"0x71 in token", "aqa", 3, {{Item("aqa", Item::kTokenType), {}}}, nullptr}, + {"0x72 in token", "ara", 3, {{Item("ara", Item::kTokenType), {}}}, nullptr}, + {"0x73 in token", "asa", 3, {{Item("asa", Item::kTokenType), {}}}, nullptr}, + {"0x74 in token", "ata", 3, {{Item("ata", Item::kTokenType), {}}}, nullptr}, + {"0x75 in token", "aua", 3, {{Item("aua", Item::kTokenType), {}}}, nullptr}, + {"0x76 in token", "ava", 3, {{Item("ava", Item::kTokenType), {}}}, nullptr}, + {"0x77 in token", "awa", 3, {{Item("awa", Item::kTokenType), {}}}, nullptr}, + {"0x78 in token", "axa", 3, {{Item("axa", Item::kTokenType), {}}}, nullptr}, + {"0x79 in token", "aya", 3, {{Item("aya", Item::kTokenType), {}}}, nullptr}, + {"0x7a in token", "aza", 3, {{Item("aza", Item::kTokenType), {}}}, nullptr}, + {"0x7b in token", "a{a", 3, absl::nullopt, nullptr}, + {"0x7c in token", "a|a", 3, {{Item("a|a", Item::kTokenType), {}}}, nullptr}, + {"0x7d in token", "a}a", 3, absl::nullopt, nullptr}, + {"0x7e in token", "a~a", 3, {{Item("a~a", Item::kTokenType), {}}}, nullptr}, + {"0x7f in token", "a\177a", 3, absl::nullopt, nullptr}, + {"0x00 starting an token", "\000a", 2, absl::nullopt, nullptr}, + {"0x01 starting an token", "\001a", 2, absl::nullopt, nullptr}, + {"0x02 starting an token", "\002a", 2, absl::nullopt, nullptr}, + {"0x03 starting an token", "\003a", 2, absl::nullopt, nullptr}, + {"0x04 starting an token", "\004a", 2, absl::nullopt, nullptr}, + {"0x05 starting an token", "\005a", 2, absl::nullopt, nullptr}, + {"0x06 starting an token", "\006a", 2, absl::nullopt, nullptr}, + {"0x07 starting an token", "\aa", 2, absl::nullopt, nullptr}, + {"0x08 starting an token", "\ba", 2, absl::nullopt, nullptr}, + {"0x09 starting an token", "\ta", 2, absl::nullopt, nullptr}, + {"0x0a starting an token", "\na", 2, absl::nullopt, nullptr}, + {"0x0b starting an token", "\va", 2, absl::nullopt, nullptr}, + {"0x0c starting an token", "\fa", 2, absl::nullopt, nullptr}, + {"0x0d starting an token", "\ra", 2, absl::nullopt, nullptr}, + {"0x0e starting an token", "\016a", 2, absl::nullopt, nullptr}, + {"0x0f starting an token", "\017a", 2, absl::nullopt, nullptr}, + {"0x10 starting an token", "\020a", 2, absl::nullopt, nullptr}, + {"0x11 starting an token", "\021a", 2, absl::nullopt, nullptr}, + {"0x12 starting an token", "\022a", 2, absl::nullopt, nullptr}, + {"0x13 starting an token", "\023a", 2, absl::nullopt, nullptr}, + {"0x14 starting an token", "\024a", 2, absl::nullopt, nullptr}, + {"0x15 starting an token", "\025a", 2, absl::nullopt, nullptr}, + {"0x16 starting an token", "\026a", 2, absl::nullopt, nullptr}, + {"0x17 starting an token", "\027a", 2, absl::nullopt, nullptr}, + {"0x18 starting an token", "\030a", 2, absl::nullopt, nullptr}, + {"0x19 starting an token", "\031a", 2, absl::nullopt, nullptr}, + {"0x1a starting an token", "\032a", 2, absl::nullopt, nullptr}, + {"0x1b starting an token", "\033a", 2, absl::nullopt, nullptr}, + {"0x1c starting an token", "\034a", 2, absl::nullopt, nullptr}, + {"0x1d starting an token", "\035a", 2, absl::nullopt, nullptr}, + {"0x1e starting an token", "\036a", 2, absl::nullopt, nullptr}, + {"0x1f starting an token", "\037a", 2, absl::nullopt, nullptr}, + {"0x20 starting an token", + " a", + 2, + {{Item("a", Item::kTokenType), {}}}, + "a"}, + {"0x21 starting an token", "!a", 2, absl::nullopt, nullptr}, + {"0x22 starting an token", "\"a", 2, absl::nullopt, nullptr}, + {"0x23 starting an token", "#a", 2, absl::nullopt, nullptr}, + {"0x24 starting an token", "$a", 2, absl::nullopt, nullptr}, + {"0x25 starting an token", "%a", 2, absl::nullopt, nullptr}, + {"0x26 starting an token", "&a", 2, absl::nullopt, nullptr}, + {"0x27 starting an token", "'a", 2, absl::nullopt, nullptr}, + {"0x28 starting an token", "(a", 2, absl::nullopt, nullptr}, + {"0x29 starting an token", ")a", 2, absl::nullopt, nullptr}, + {"0x2a starting an token", + "*a", + 2, + {{Item("*a", Item::kTokenType), {}}}, + nullptr}, + {"0x2b starting an token", "+a", 2, absl::nullopt, nullptr}, + {"0x2c starting an token", ",a", 2, absl::nullopt, nullptr}, + {"0x2d starting an token", "-a", 2, absl::nullopt, nullptr}, + {"0x2e starting an token", ".a", 2, absl::nullopt, nullptr}, + {"0x2f starting an token", "/a", 2, absl::nullopt, nullptr}, + {"0x30 starting an token", "0a", 2, absl::nullopt, nullptr}, + {"0x31 starting an token", "1a", 2, absl::nullopt, nullptr}, + {"0x32 starting an token", "2a", 2, absl::nullopt, nullptr}, + {"0x33 starting an token", "3a", 2, absl::nullopt, nullptr}, + {"0x34 starting an token", "4a", 2, absl::nullopt, nullptr}, + {"0x35 starting an token", "5a", 2, absl::nullopt, nullptr}, + {"0x36 starting an token", "6a", 2, absl::nullopt, nullptr}, + {"0x37 starting an token", "7a", 2, absl::nullopt, nullptr}, + {"0x38 starting an token", "8a", 2, absl::nullopt, nullptr}, + {"0x39 starting an token", "9a", 2, absl::nullopt, nullptr}, + {"0x3a starting an token", ":a", 2, absl::nullopt, nullptr}, + {"0x3b starting an token", ";a", 2, absl::nullopt, nullptr}, + {"0x3c starting an token", "a", 2, absl::nullopt, nullptr}, + {"0x3f starting an token", "?a", 2, absl::nullopt, nullptr}, + {"0x40 starting an token", "@a", 2, absl::nullopt, nullptr}, + {"0x41 starting an token", + "Aa", + 2, + {{Item("Aa", Item::kTokenType), {}}}, + nullptr}, + {"0x42 starting an token", + "Ba", + 2, + {{Item("Ba", Item::kTokenType), {}}}, + nullptr}, + {"0x43 starting an token", + "Ca", + 2, + {{Item("Ca", Item::kTokenType), {}}}, + nullptr}, + {"0x44 starting an token", + "Da", + 2, + {{Item("Da", Item::kTokenType), {}}}, + nullptr}, + {"0x45 starting an token", + "Ea", + 2, + {{Item("Ea", Item::kTokenType), {}}}, + nullptr}, + {"0x46 starting an token", + "Fa", + 2, + {{Item("Fa", Item::kTokenType), {}}}, + nullptr}, + {"0x47 starting an token", + "Ga", + 2, + {{Item("Ga", Item::kTokenType), {}}}, + nullptr}, + {"0x48 starting an token", + "Ha", + 2, + {{Item("Ha", Item::kTokenType), {}}}, + nullptr}, + {"0x49 starting an token", + "Ia", + 2, + {{Item("Ia", Item::kTokenType), {}}}, + nullptr}, + {"0x4a starting an token", + "Ja", + 2, + {{Item("Ja", Item::kTokenType), {}}}, + nullptr}, + {"0x4b starting an token", + "Ka", + 2, + {{Item("Ka", Item::kTokenType), {}}}, + nullptr}, + {"0x4c starting an token", + "La", + 2, + {{Item("La", Item::kTokenType), {}}}, + nullptr}, + {"0x4d starting an token", + "Ma", + 2, + {{Item("Ma", Item::kTokenType), {}}}, + nullptr}, + {"0x4e starting an token", + "Na", + 2, + {{Item("Na", Item::kTokenType), {}}}, + nullptr}, + {"0x4f starting an token", + "Oa", + 2, + {{Item("Oa", Item::kTokenType), {}}}, + nullptr}, + {"0x50 starting an token", + "Pa", + 2, + {{Item("Pa", Item::kTokenType), {}}}, + nullptr}, + {"0x51 starting an token", + "Qa", + 2, + {{Item("Qa", Item::kTokenType), {}}}, + nullptr}, + {"0x52 starting an token", + "Ra", + 2, + {{Item("Ra", Item::kTokenType), {}}}, + nullptr}, + {"0x53 starting an token", + "Sa", + 2, + {{Item("Sa", Item::kTokenType), {}}}, + nullptr}, + {"0x54 starting an token", + "Ta", + 2, + {{Item("Ta", Item::kTokenType), {}}}, + nullptr}, + {"0x55 starting an token", + "Ua", + 2, + {{Item("Ua", Item::kTokenType), {}}}, + nullptr}, + {"0x56 starting an token", + "Va", + 2, + {{Item("Va", Item::kTokenType), {}}}, + nullptr}, + {"0x57 starting an token", + "Wa", + 2, + {{Item("Wa", Item::kTokenType), {}}}, + nullptr}, + {"0x58 starting an token", + "Xa", + 2, + {{Item("Xa", Item::kTokenType), {}}}, + nullptr}, + {"0x59 starting an token", + "Ya", + 2, + {{Item("Ya", Item::kTokenType), {}}}, + nullptr}, + {"0x5a starting an token", + "Za", + 2, + {{Item("Za", Item::kTokenType), {}}}, + nullptr}, + {"0x5b starting an token", "[a", 2, absl::nullopt, nullptr}, + {"0x5c starting an token", "\\a", 2, absl::nullopt, nullptr}, + {"0x5d starting an token", "]a", 2, absl::nullopt, nullptr}, + {"0x5e starting an token", "^a", 2, absl::nullopt, nullptr}, + {"0x5f starting an token", "_a", 2, absl::nullopt, nullptr}, + {"0x60 starting an token", "`a", 2, absl::nullopt, nullptr}, + {"0x61 starting an token", + "aa", + 2, + {{Item("aa", Item::kTokenType), {}}}, + nullptr}, + {"0x62 starting an token", + "ba", + 2, + {{Item("ba", Item::kTokenType), {}}}, + nullptr}, + {"0x63 starting an token", + "ca", + 2, + {{Item("ca", Item::kTokenType), {}}}, + nullptr}, + {"0x64 starting an token", + "da", + 2, + {{Item("da", Item::kTokenType), {}}}, + nullptr}, + {"0x65 starting an token", + "ea", + 2, + {{Item("ea", Item::kTokenType), {}}}, + nullptr}, + {"0x66 starting an token", + "fa", + 2, + {{Item("fa", Item::kTokenType), {}}}, + nullptr}, + {"0x67 starting an token", + "ga", + 2, + {{Item("ga", Item::kTokenType), {}}}, + nullptr}, + {"0x68 starting an token", + "ha", + 2, + {{Item("ha", Item::kTokenType), {}}}, + nullptr}, + {"0x69 starting an token", + "ia", + 2, + {{Item("ia", Item::kTokenType), {}}}, + nullptr}, + {"0x6a starting an token", + "ja", + 2, + {{Item("ja", Item::kTokenType), {}}}, + nullptr}, + {"0x6b starting an token", + "ka", + 2, + {{Item("ka", Item::kTokenType), {}}}, + nullptr}, + {"0x6c starting an token", + "la", + 2, + {{Item("la", Item::kTokenType), {}}}, + nullptr}, + {"0x6d starting an token", + "ma", + 2, + {{Item("ma", Item::kTokenType), {}}}, + nullptr}, + {"0x6e starting an token", + "na", + 2, + {{Item("na", Item::kTokenType), {}}}, + nullptr}, + {"0x6f starting an token", + "oa", + 2, + {{Item("oa", Item::kTokenType), {}}}, + nullptr}, + {"0x70 starting an token", + "pa", + 2, + {{Item("pa", Item::kTokenType), {}}}, + nullptr}, + {"0x71 starting an token", + "qa", + 2, + {{Item("qa", Item::kTokenType), {}}}, + nullptr}, + {"0x72 starting an token", + "ra", + 2, + {{Item("ra", Item::kTokenType), {}}}, + nullptr}, + {"0x73 starting an token", + "sa", + 2, + {{Item("sa", Item::kTokenType), {}}}, + nullptr}, + {"0x74 starting an token", + "ta", + 2, + {{Item("ta", Item::kTokenType), {}}}, + nullptr}, + {"0x75 starting an token", + "ua", + 2, + {{Item("ua", Item::kTokenType), {}}}, + nullptr}, + {"0x76 starting an token", + "va", + 2, + {{Item("va", Item::kTokenType), {}}}, + nullptr}, + {"0x77 starting an token", + "wa", + 2, + {{Item("wa", Item::kTokenType), {}}}, + nullptr}, + {"0x78 starting an token", + "xa", + 2, + {{Item("xa", Item::kTokenType), {}}}, + nullptr}, + {"0x79 starting an token", + "ya", + 2, + {{Item("ya", Item::kTokenType), {}}}, + nullptr}, + {"0x7a starting an token", + "za", + 2, + {{Item("za", Item::kTokenType), {}}}, + nullptr}, + {"0x7b starting an token", "{a", 2, absl::nullopt, nullptr}, + {"0x7c starting an token", "|a", 2, absl::nullopt, nullptr}, + {"0x7d starting an token", "}a", 2, absl::nullopt, nullptr}, + {"0x7e starting an token", "~a", 2, absl::nullopt, nullptr}, + {"0x7f starting an token", "\177a", 2, absl::nullopt, nullptr}, + // token.json + {"basic token - item", + "a_b-c.d3:f%00/*", + 15, + {{Item("a_b-c.d3:f%00/*", Item::kTokenType), {}}}, + nullptr}, + {"token with capitals - item", + "fooBar", + 6, + {{Item("fooBar", Item::kTokenType), {}}}, + nullptr}, + {"token starting with capitals - item", + "FooBar", + 6, + {{Item("FooBar", Item::kTokenType), {}}}, + nullptr}, +}; + +const struct ListTestCase { + const char* name; + const char* raw; + size_t raw_len; + const absl::optional expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} list_test_cases[] = { + // examples.json + {"Example-StrListHeader", + "\"foo\", \"bar\", \"It was the best of times.\"", + 41, + {{{Item("foo"), {}}, + {Item("bar"), {}}, + {Item("It was the best of times."), {}}}}, + nullptr}, + {"Example-Hdr (list on one line)", + "foo, bar", + 8, + {{{Item("foo", Item::kTokenType), {}}, + {Item("bar", Item::kTokenType), {}}}}, + nullptr}, + {"Example-Hdr (list on two lines)", + "foo, bar", + 8, + {{{Item("foo", Item::kTokenType), {}}, + {Item("bar", Item::kTokenType), {}}}}, + "foo, bar"}, + {"Example-StrListListHeader", + "(\"foo\" \"bar\"), (\"baz\"), (\"bat\" \"one\"), ()", + 41, + {{{{{Item("foo"), {}}, {Item("bar"), {}}}, {}}, + {{{Item("baz"), {}}}, {}}, + {{{Item("bat"), {}}, {Item("one"), {}}}, {}}, + {std::vector(), {}}}}, + nullptr}, + {"Example-ListListParam", + "(\"foo\"; a=1;b=2);lvl=5, (\"bar\" \"baz\");lvl=1", + 43, + {{{{{Item("foo"), {Param("a", 1), Param("b", 2)}}}, {Param("lvl", 5)}}, + {{{Item("bar"), {}}, {Item("baz"), {}}}, {Param("lvl", 1)}}}}, + "(\"foo\";a=1;b=2);lvl=5, (\"bar\" \"baz\");lvl=1"}, + {"Example-ParamListHeader", + "abc;a=1;b=2; cde_456, (ghi;jk=4 l);q=\"9\";r=w", + 44, + {{{Item("abc", Item::kTokenType), + {Param("a", 1), Param("b", 2), BooleanParam("cde_456", true)}}, + {{{Item("ghi", Item::kTokenType), {Param("jk", 4)}}, + {Item("l", Item::kTokenType), {}}}, + {Param("q", "9"), TokenParam("r", "w")}}}}, + "abc;a=1;b=2;cde_456, (ghi;jk=4 l);q=\"9\";r=w"}, + // key-generated.json + {"0x00 in parameterised list key", "foo; a\000a=1", 10, absl::nullopt, + nullptr}, + {"0x01 in parameterised list key", "foo; a\001a=1", 10, absl::nullopt, + nullptr}, + {"0x02 in parameterised list key", "foo; a\002a=1", 10, absl::nullopt, + nullptr}, + {"0x03 in parameterised list key", "foo; a\003a=1", 10, absl::nullopt, + nullptr}, + {"0x04 in parameterised list key", "foo; a\004a=1", 10, absl::nullopt, + nullptr}, + {"0x05 in parameterised list key", "foo; a\005a=1", 10, absl::nullopt, + nullptr}, + {"0x06 in parameterised list key", "foo; a\006a=1", 10, absl::nullopt, + nullptr}, + {"0x07 in parameterised list key", "foo; a\aa=1", 10, absl::nullopt, + nullptr}, + {"0x08 in parameterised list key", "foo; a\ba=1", 10, absl::nullopt, + nullptr}, + {"0x09 in parameterised list key", "foo; a\ta=1", 10, absl::nullopt, + nullptr}, + {"0x0a in parameterised list key", "foo; a\na=1", 10, absl::nullopt, + nullptr}, + {"0x0b in parameterised list key", "foo; a\va=1", 10, absl::nullopt, + nullptr}, + {"0x0c in parameterised list key", "foo; a\fa=1", 10, absl::nullopt, + nullptr}, + {"0x0d in parameterised list key", "foo; a\ra=1", 10, absl::nullopt, + nullptr}, + {"0x0e in parameterised list key", "foo; a\016a=1", 10, absl::nullopt, + nullptr}, + {"0x0f in parameterised list key", "foo; a\017a=1", 10, absl::nullopt, + nullptr}, + {"0x10 in parameterised list key", "foo; a\020a=1", 10, absl::nullopt, + nullptr}, + {"0x11 in parameterised list key", "foo; a\021a=1", 10, absl::nullopt, + nullptr}, + {"0x12 in parameterised list key", "foo; a\022a=1", 10, absl::nullopt, + nullptr}, + {"0x13 in parameterised list key", "foo; a\023a=1", 10, absl::nullopt, + nullptr}, + {"0x14 in parameterised list key", "foo; a\024a=1", 10, absl::nullopt, + nullptr}, + {"0x15 in parameterised list key", "foo; a\025a=1", 10, absl::nullopt, + nullptr}, + {"0x16 in parameterised list key", "foo; a\026a=1", 10, absl::nullopt, + nullptr}, + {"0x17 in parameterised list key", "foo; a\027a=1", 10, absl::nullopt, + nullptr}, + {"0x18 in parameterised list key", "foo; a\030a=1", 10, absl::nullopt, + nullptr}, + {"0x19 in parameterised list key", "foo; a\031a=1", 10, absl::nullopt, + nullptr}, + {"0x1a in parameterised list key", "foo; a\032a=1", 10, absl::nullopt, + nullptr}, + {"0x1b in parameterised list key", "foo; a\033a=1", 10, absl::nullopt, + nullptr}, + {"0x1c in parameterised list key", "foo; a\034a=1", 10, absl::nullopt, + nullptr}, + {"0x1d in parameterised list key", "foo; a\035a=1", 10, absl::nullopt, + nullptr}, + {"0x1e in parameterised list key", "foo; a\036a=1", 10, absl::nullopt, + nullptr}, + {"0x1f in parameterised list key", "foo; a\037a=1", 10, absl::nullopt, + nullptr}, + {"0x20 in parameterised list key", "foo; a a=1", 10, absl::nullopt, + nullptr}, + {"0x21 in parameterised list key", "foo; a!a=1", 10, absl::nullopt, + nullptr}, + {"0x22 in parameterised list key", "foo; a\"a=1", 10, absl::nullopt, + nullptr}, + {"0x23 in parameterised list key", "foo; a#a=1", 10, absl::nullopt, + nullptr}, + {"0x24 in parameterised list key", "foo; a$a=1", 10, absl::nullopt, + nullptr}, + {"0x25 in parameterised list key", "foo; a%a=1", 10, absl::nullopt, + nullptr}, + {"0x26 in parameterised list key", "foo; a&a=1", 10, absl::nullopt, + nullptr}, + {"0x27 in parameterised list key", "foo; a'a=1", 10, absl::nullopt, + nullptr}, + {"0x28 in parameterised list key", "foo; a(a=1", 10, absl::nullopt, + nullptr}, + {"0x29 in parameterised list key", "foo; a)a=1", 10, absl::nullopt, + nullptr}, + {"0x2a in parameterised list key", + "foo; a*a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a*a", 1)}}}}, + "foo;a*a=1"}, + {"0x2b in parameterised list key", "foo; a+a=1", 10, absl::nullopt, + nullptr}, + {"0x2c in parameterised list key", "foo; a,a=1", 10, absl::nullopt, + nullptr}, + {"0x2d in parameterised list key", + "foo; a-a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a-a", 1)}}}}, + "foo;a-a=1"}, + {"0x2e in parameterised list key", + "foo; a.a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a.a", 1)}}}}, + "foo;a.a=1"}, + {"0x2f in parameterised list key", "foo; a/a=1", 10, absl::nullopt, + nullptr}, + {"0x30 in parameterised list key", + "foo; a0a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a0a", 1)}}}}, + "foo;a0a=1"}, + {"0x31 in parameterised list key", + "foo; a1a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a1a", 1)}}}}, + "foo;a1a=1"}, + {"0x32 in parameterised list key", + "foo; a2a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a2a", 1)}}}}, + "foo;a2a=1"}, + {"0x33 in parameterised list key", + "foo; a3a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a3a", 1)}}}}, + "foo;a3a=1"}, + {"0x34 in parameterised list key", + "foo; a4a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a4a", 1)}}}}, + "foo;a4a=1"}, + {"0x35 in parameterised list key", + "foo; a5a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a5a", 1)}}}}, + "foo;a5a=1"}, + {"0x36 in parameterised list key", + "foo; a6a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a6a", 1)}}}}, + "foo;a6a=1"}, + {"0x37 in parameterised list key", + "foo; a7a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a7a", 1)}}}}, + "foo;a7a=1"}, + {"0x38 in parameterised list key", + "foo; a8a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a8a", 1)}}}}, + "foo;a8a=1"}, + {"0x39 in parameterised list key", + "foo; a9a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a9a", 1)}}}}, + "foo;a9a=1"}, + {"0x3a in parameterised list key", "foo; a:a=1", 10, absl::nullopt, + nullptr}, + {"0x3b in parameterised list key", + "foo; a;a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a", 1)}}}}, + "foo;a=1"}, + {"0x3c in parameterised list key", "foo; aa=1", 10, absl::nullopt, + nullptr}, + {"0x3f in parameterised list key", "foo; a?a=1", 10, absl::nullopt, + nullptr}, + {"0x40 in parameterised list key", "foo; a@a=1", 10, absl::nullopt, + nullptr}, + {"0x41 in parameterised list key", "foo; aAa=1", 10, absl::nullopt, + nullptr}, + {"0x42 in parameterised list key", "foo; aBa=1", 10, absl::nullopt, + nullptr}, + {"0x43 in parameterised list key", "foo; aCa=1", 10, absl::nullopt, + nullptr}, + {"0x44 in parameterised list key", "foo; aDa=1", 10, absl::nullopt, + nullptr}, + {"0x45 in parameterised list key", "foo; aEa=1", 10, absl::nullopt, + nullptr}, + {"0x46 in parameterised list key", "foo; aFa=1", 10, absl::nullopt, + nullptr}, + {"0x47 in parameterised list key", "foo; aGa=1", 10, absl::nullopt, + nullptr}, + {"0x48 in parameterised list key", "foo; aHa=1", 10, absl::nullopt, + nullptr}, + {"0x49 in parameterised list key", "foo; aIa=1", 10, absl::nullopt, + nullptr}, + {"0x4a in parameterised list key", "foo; aJa=1", 10, absl::nullopt, + nullptr}, + {"0x4b in parameterised list key", "foo; aKa=1", 10, absl::nullopt, + nullptr}, + {"0x4c in parameterised list key", "foo; aLa=1", 10, absl::nullopt, + nullptr}, + {"0x4d in parameterised list key", "foo; aMa=1", 10, absl::nullopt, + nullptr}, + {"0x4e in parameterised list key", "foo; aNa=1", 10, absl::nullopt, + nullptr}, + {"0x4f in parameterised list key", "foo; aOa=1", 10, absl::nullopt, + nullptr}, + {"0x50 in parameterised list key", "foo; aPa=1", 10, absl::nullopt, + nullptr}, + {"0x51 in parameterised list key", "foo; aQa=1", 10, absl::nullopt, + nullptr}, + {"0x52 in parameterised list key", "foo; aRa=1", 10, absl::nullopt, + nullptr}, + {"0x53 in parameterised list key", "foo; aSa=1", 10, absl::nullopt, + nullptr}, + {"0x54 in parameterised list key", "foo; aTa=1", 10, absl::nullopt, + nullptr}, + {"0x55 in parameterised list key", "foo; aUa=1", 10, absl::nullopt, + nullptr}, + {"0x56 in parameterised list key", "foo; aVa=1", 10, absl::nullopt, + nullptr}, + {"0x57 in parameterised list key", "foo; aWa=1", 10, absl::nullopt, + nullptr}, + {"0x58 in parameterised list key", "foo; aXa=1", 10, absl::nullopt, + nullptr}, + {"0x59 in parameterised list key", "foo; aYa=1", 10, absl::nullopt, + nullptr}, + {"0x5a in parameterised list key", "foo; aZa=1", 10, absl::nullopt, + nullptr}, + {"0x5b in parameterised list key", "foo; a[a=1", 10, absl::nullopt, + nullptr}, + {"0x5c in parameterised list key", "foo; a\\a=1", 10, absl::nullopt, + nullptr}, + {"0x5d in parameterised list key", "foo; a]a=1", 10, absl::nullopt, + nullptr}, + {"0x5e in parameterised list key", "foo; a^a=1", 10, absl::nullopt, + nullptr}, + {"0x5f in parameterised list key", + "foo; a_a=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("a_a", 1)}}}}, + "foo;a_a=1"}, + {"0x60 in parameterised list key", "foo; a`a=1", 10, absl::nullopt, + nullptr}, + {"0x61 in parameterised list key", + "foo; aaa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aaa", 1)}}}}, + "foo;aaa=1"}, + {"0x62 in parameterised list key", + "foo; aba=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aba", 1)}}}}, + "foo;aba=1"}, + {"0x63 in parameterised list key", + "foo; aca=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aca", 1)}}}}, + "foo;aca=1"}, + {"0x64 in parameterised list key", + "foo; ada=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ada", 1)}}}}, + "foo;ada=1"}, + {"0x65 in parameterised list key", + "foo; aea=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aea", 1)}}}}, + "foo;aea=1"}, + {"0x66 in parameterised list key", + "foo; afa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("afa", 1)}}}}, + "foo;afa=1"}, + {"0x67 in parameterised list key", + "foo; aga=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aga", 1)}}}}, + "foo;aga=1"}, + {"0x68 in parameterised list key", + "foo; aha=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aha", 1)}}}}, + "foo;aha=1"}, + {"0x69 in parameterised list key", + "foo; aia=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aia", 1)}}}}, + "foo;aia=1"}, + {"0x6a in parameterised list key", + "foo; aja=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aja", 1)}}}}, + "foo;aja=1"}, + {"0x6b in parameterised list key", + "foo; aka=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aka", 1)}}}}, + "foo;aka=1"}, + {"0x6c in parameterised list key", + "foo; ala=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ala", 1)}}}}, + "foo;ala=1"}, + {"0x6d in parameterised list key", + "foo; ama=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ama", 1)}}}}, + "foo;ama=1"}, + {"0x6e in parameterised list key", + "foo; ana=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ana", 1)}}}}, + "foo;ana=1"}, + {"0x6f in parameterised list key", + "foo; aoa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aoa", 1)}}}}, + "foo;aoa=1"}, + {"0x70 in parameterised list key", + "foo; apa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("apa", 1)}}}}, + "foo;apa=1"}, + {"0x71 in parameterised list key", + "foo; aqa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aqa", 1)}}}}, + "foo;aqa=1"}, + {"0x72 in parameterised list key", + "foo; ara=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ara", 1)}}}}, + "foo;ara=1"}, + {"0x73 in parameterised list key", + "foo; asa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("asa", 1)}}}}, + "foo;asa=1"}, + {"0x74 in parameterised list key", + "foo; ata=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ata", 1)}}}}, + "foo;ata=1"}, + {"0x75 in parameterised list key", + "foo; aua=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aua", 1)}}}}, + "foo;aua=1"}, + {"0x76 in parameterised list key", + "foo; ava=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("ava", 1)}}}}, + "foo;ava=1"}, + {"0x77 in parameterised list key", + "foo; awa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("awa", 1)}}}}, + "foo;awa=1"}, + {"0x78 in parameterised list key", + "foo; axa=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("axa", 1)}}}}, + "foo;axa=1"}, + {"0x79 in parameterised list key", + "foo; aya=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aya", 1)}}}}, + "foo;aya=1"}, + {"0x7a in parameterised list key", + "foo; aza=1", + 10, + {{{Item("foo", Item::kTokenType), {Param("aza", 1)}}}}, + "foo;aza=1"}, + {"0x7b in parameterised list key", "foo; a{a=1", 10, absl::nullopt, + nullptr}, + {"0x7c in parameterised list key", "foo; a|a=1", 10, absl::nullopt, + nullptr}, + {"0x7d in parameterised list key", "foo; a}a=1", 10, absl::nullopt, + nullptr}, + {"0x7e in parameterised list key", "foo; a~a=1", 10, absl::nullopt, + nullptr}, + {"0x7f in parameterised list key", "foo; a\177a=1", 10, absl::nullopt, + nullptr}, + {"0x00 starting a parameterised list key", "foo; \000a=1", 9, absl::nullopt, + nullptr}, + {"0x01 starting a parameterised list key", "foo; \001a=1", 9, absl::nullopt, + nullptr}, + {"0x02 starting a parameterised list key", "foo; \002a=1", 9, absl::nullopt, + nullptr}, + {"0x03 starting a parameterised list key", "foo; \003a=1", 9, absl::nullopt, + nullptr}, + {"0x04 starting a parameterised list key", "foo; \004a=1", 9, absl::nullopt, + nullptr}, + {"0x05 starting a parameterised list key", "foo; \005a=1", 9, absl::nullopt, + nullptr}, + {"0x06 starting a parameterised list key", "foo; \006a=1", 9, absl::nullopt, + nullptr}, + {"0x07 starting a parameterised list key", "foo; \aa=1", 9, absl::nullopt, + nullptr}, + {"0x08 starting a parameterised list key", "foo; \ba=1", 9, absl::nullopt, + nullptr}, + {"0x09 starting a parameterised list key", "foo; \ta=1", 9, absl::nullopt, + nullptr}, + {"0x0a starting a parameterised list key", "foo; \na=1", 9, absl::nullopt, + nullptr}, + {"0x0b starting a parameterised list key", "foo; \va=1", 9, absl::nullopt, + nullptr}, + {"0x0c starting a parameterised list key", "foo; \fa=1", 9, absl::nullopt, + nullptr}, + {"0x0d starting a parameterised list key", "foo; \ra=1", 9, absl::nullopt, + nullptr}, + {"0x0e starting a parameterised list key", "foo; \016a=1", 9, absl::nullopt, + nullptr}, + {"0x0f starting a parameterised list key", "foo; \017a=1", 9, absl::nullopt, + nullptr}, + {"0x10 starting a parameterised list key", "foo; \020a=1", 9, absl::nullopt, + nullptr}, + {"0x11 starting a parameterised list key", "foo; \021a=1", 9, absl::nullopt, + nullptr}, + {"0x12 starting a parameterised list key", "foo; \022a=1", 9, absl::nullopt, + nullptr}, + {"0x13 starting a parameterised list key", "foo; \023a=1", 9, absl::nullopt, + nullptr}, + {"0x14 starting a parameterised list key", "foo; \024a=1", 9, absl::nullopt, + nullptr}, + {"0x15 starting a parameterised list key", "foo; \025a=1", 9, absl::nullopt, + nullptr}, + {"0x16 starting a parameterised list key", "foo; \026a=1", 9, absl::nullopt, + nullptr}, + {"0x17 starting a parameterised list key", "foo; \027a=1", 9, absl::nullopt, + nullptr}, + {"0x18 starting a parameterised list key", "foo; \030a=1", 9, absl::nullopt, + nullptr}, + {"0x19 starting a parameterised list key", "foo; \031a=1", 9, absl::nullopt, + nullptr}, + {"0x1a starting a parameterised list key", "foo; \032a=1", 9, absl::nullopt, + nullptr}, + {"0x1b starting a parameterised list key", "foo; \033a=1", 9, absl::nullopt, + nullptr}, + {"0x1c starting a parameterised list key", "foo; \034a=1", 9, absl::nullopt, + nullptr}, + {"0x1d starting a parameterised list key", "foo; \035a=1", 9, absl::nullopt, + nullptr}, + {"0x1e starting a parameterised list key", "foo; \036a=1", 9, absl::nullopt, + nullptr}, + {"0x1f starting a parameterised list key", "foo; \037a=1", 9, absl::nullopt, + nullptr}, + {"0x20 starting a parameterised list key", + "foo; a=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("a", 1)}}}}, + "foo;a=1"}, + {"0x21 starting a parameterised list key", "foo; !a=1", 9, absl::nullopt, + nullptr}, + {"0x22 starting a parameterised list key", "foo; \"a=1", 9, absl::nullopt, + nullptr}, + {"0x23 starting a parameterised list key", "foo; #a=1", 9, absl::nullopt, + nullptr}, + {"0x24 starting a parameterised list key", "foo; $a=1", 9, absl::nullopt, + nullptr}, + {"0x25 starting a parameterised list key", "foo; %a=1", 9, absl::nullopt, + nullptr}, + {"0x26 starting a parameterised list key", "foo; &a=1", 9, absl::nullopt, + nullptr}, + {"0x27 starting a parameterised list key", "foo; 'a=1", 9, absl::nullopt, + nullptr}, + {"0x28 starting a parameterised list key", "foo; (a=1", 9, absl::nullopt, + nullptr}, + {"0x29 starting a parameterised list key", "foo; )a=1", 9, absl::nullopt, + nullptr}, + {"0x2a starting a parameterised list key", + "foo; *a=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("*a", 1)}}}}, + "foo;*a=1"}, + {"0x2b starting a parameterised list key", "foo; +a=1", 9, absl::nullopt, + nullptr}, + {"0x2c starting a parameterised list key", "foo; ,a=1", 9, absl::nullopt, + nullptr}, + {"0x2d starting a parameterised list key", "foo; -a=1", 9, absl::nullopt, + nullptr}, + {"0x2e starting a parameterised list key", "foo; .a=1", 9, absl::nullopt, + nullptr}, + {"0x2f starting a parameterised list key", "foo; /a=1", 9, absl::nullopt, + nullptr}, + {"0x30 starting a parameterised list key", "foo; 0a=1", 9, absl::nullopt, + nullptr}, + {"0x31 starting a parameterised list key", "foo; 1a=1", 9, absl::nullopt, + nullptr}, + {"0x32 starting a parameterised list key", "foo; 2a=1", 9, absl::nullopt, + nullptr}, + {"0x33 starting a parameterised list key", "foo; 3a=1", 9, absl::nullopt, + nullptr}, + {"0x34 starting a parameterised list key", "foo; 4a=1", 9, absl::nullopt, + nullptr}, + {"0x35 starting a parameterised list key", "foo; 5a=1", 9, absl::nullopt, + nullptr}, + {"0x36 starting a parameterised list key", "foo; 6a=1", 9, absl::nullopt, + nullptr}, + {"0x37 starting a parameterised list key", "foo; 7a=1", 9, absl::nullopt, + nullptr}, + {"0x38 starting a parameterised list key", "foo; 8a=1", 9, absl::nullopt, + nullptr}, + {"0x39 starting a parameterised list key", "foo; 9a=1", 9, absl::nullopt, + nullptr}, + {"0x3a starting a parameterised list key", "foo; :a=1", 9, absl::nullopt, + nullptr}, + {"0x3b starting a parameterised list key", "foo; ;a=1", 9, absl::nullopt, + nullptr}, + {"0x3c starting a parameterised list key", "foo; a=1", 9, absl::nullopt, + nullptr}, + {"0x3f starting a parameterised list key", "foo; ?a=1", 9, absl::nullopt, + nullptr}, + {"0x40 starting a parameterised list key", "foo; @a=1", 9, absl::nullopt, + nullptr}, + {"0x41 starting a parameterised list key", "foo; Aa=1", 9, absl::nullopt, + nullptr}, + {"0x42 starting a parameterised list key", "foo; Ba=1", 9, absl::nullopt, + nullptr}, + {"0x43 starting a parameterised list key", "foo; Ca=1", 9, absl::nullopt, + nullptr}, + {"0x44 starting a parameterised list key", "foo; Da=1", 9, absl::nullopt, + nullptr}, + {"0x45 starting a parameterised list key", "foo; Ea=1", 9, absl::nullopt, + nullptr}, + {"0x46 starting a parameterised list key", "foo; Fa=1", 9, absl::nullopt, + nullptr}, + {"0x47 starting a parameterised list key", "foo; Ga=1", 9, absl::nullopt, + nullptr}, + {"0x48 starting a parameterised list key", "foo; Ha=1", 9, absl::nullopt, + nullptr}, + {"0x49 starting a parameterised list key", "foo; Ia=1", 9, absl::nullopt, + nullptr}, + {"0x4a starting a parameterised list key", "foo; Ja=1", 9, absl::nullopt, + nullptr}, + {"0x4b starting a parameterised list key", "foo; Ka=1", 9, absl::nullopt, + nullptr}, + {"0x4c starting a parameterised list key", "foo; La=1", 9, absl::nullopt, + nullptr}, + {"0x4d starting a parameterised list key", "foo; Ma=1", 9, absl::nullopt, + nullptr}, + {"0x4e starting a parameterised list key", "foo; Na=1", 9, absl::nullopt, + nullptr}, + {"0x4f starting a parameterised list key", "foo; Oa=1", 9, absl::nullopt, + nullptr}, + {"0x50 starting a parameterised list key", "foo; Pa=1", 9, absl::nullopt, + nullptr}, + {"0x51 starting a parameterised list key", "foo; Qa=1", 9, absl::nullopt, + nullptr}, + {"0x52 starting a parameterised list key", "foo; Ra=1", 9, absl::nullopt, + nullptr}, + {"0x53 starting a parameterised list key", "foo; Sa=1", 9, absl::nullopt, + nullptr}, + {"0x54 starting a parameterised list key", "foo; Ta=1", 9, absl::nullopt, + nullptr}, + {"0x55 starting a parameterised list key", "foo; Ua=1", 9, absl::nullopt, + nullptr}, + {"0x56 starting a parameterised list key", "foo; Va=1", 9, absl::nullopt, + nullptr}, + {"0x57 starting a parameterised list key", "foo; Wa=1", 9, absl::nullopt, + nullptr}, + {"0x58 starting a parameterised list key", "foo; Xa=1", 9, absl::nullopt, + nullptr}, + {"0x59 starting a parameterised list key", "foo; Ya=1", 9, absl::nullopt, + nullptr}, + {"0x5a starting a parameterised list key", "foo; Za=1", 9, absl::nullopt, + nullptr}, + {"0x5b starting a parameterised list key", "foo; [a=1", 9, absl::nullopt, + nullptr}, + {"0x5c starting a parameterised list key", "foo; \\a=1", 9, absl::nullopt, + nullptr}, + {"0x5d starting a parameterised list key", "foo; ]a=1", 9, absl::nullopt, + nullptr}, + {"0x5e starting a parameterised list key", "foo; ^a=1", 9, absl::nullopt, + nullptr}, + {"0x5f starting a parameterised list key", "foo; _a=1", 9, absl::nullopt, + nullptr}, + {"0x60 starting a parameterised list key", "foo; `a=1", 9, absl::nullopt, + nullptr}, + {"0x61 starting a parameterised list key", + "foo; aa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("aa", 1)}}}}, + "foo;aa=1"}, + {"0x62 starting a parameterised list key", + "foo; ba=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ba", 1)}}}}, + "foo;ba=1"}, + {"0x63 starting a parameterised list key", + "foo; ca=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ca", 1)}}}}, + "foo;ca=1"}, + {"0x64 starting a parameterised list key", + "foo; da=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("da", 1)}}}}, + "foo;da=1"}, + {"0x65 starting a parameterised list key", + "foo; ea=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ea", 1)}}}}, + "foo;ea=1"}, + {"0x66 starting a parameterised list key", + "foo; fa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("fa", 1)}}}}, + "foo;fa=1"}, + {"0x67 starting a parameterised list key", + "foo; ga=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ga", 1)}}}}, + "foo;ga=1"}, + {"0x68 starting a parameterised list key", + "foo; ha=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ha", 1)}}}}, + "foo;ha=1"}, + {"0x69 starting a parameterised list key", + "foo; ia=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ia", 1)}}}}, + "foo;ia=1"}, + {"0x6a starting a parameterised list key", + "foo; ja=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ja", 1)}}}}, + "foo;ja=1"}, + {"0x6b starting a parameterised list key", + "foo; ka=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ka", 1)}}}}, + "foo;ka=1"}, + {"0x6c starting a parameterised list key", + "foo; la=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("la", 1)}}}}, + "foo;la=1"}, + {"0x6d starting a parameterised list key", + "foo; ma=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ma", 1)}}}}, + "foo;ma=1"}, + {"0x6e starting a parameterised list key", + "foo; na=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("na", 1)}}}}, + "foo;na=1"}, + {"0x6f starting a parameterised list key", + "foo; oa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("oa", 1)}}}}, + "foo;oa=1"}, + {"0x70 starting a parameterised list key", + "foo; pa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("pa", 1)}}}}, + "foo;pa=1"}, + {"0x71 starting a parameterised list key", + "foo; qa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("qa", 1)}}}}, + "foo;qa=1"}, + {"0x72 starting a parameterised list key", + "foo; ra=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ra", 1)}}}}, + "foo;ra=1"}, + {"0x73 starting a parameterised list key", + "foo; sa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("sa", 1)}}}}, + "foo;sa=1"}, + {"0x74 starting a parameterised list key", + "foo; ta=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ta", 1)}}}}, + "foo;ta=1"}, + {"0x75 starting a parameterised list key", + "foo; ua=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ua", 1)}}}}, + "foo;ua=1"}, + {"0x76 starting a parameterised list key", + "foo; va=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("va", 1)}}}}, + "foo;va=1"}, + {"0x77 starting a parameterised list key", + "foo; wa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("wa", 1)}}}}, + "foo;wa=1"}, + {"0x78 starting a parameterised list key", + "foo; xa=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("xa", 1)}}}}, + "foo;xa=1"}, + {"0x79 starting a parameterised list key", + "foo; ya=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("ya", 1)}}}}, + "foo;ya=1"}, + {"0x7a starting a parameterised list key", + "foo; za=1", + 9, + {{{Item("foo", Item::kTokenType), {Param("za", 1)}}}}, + "foo;za=1"}, + {"0x7b starting a parameterised list key", "foo; {a=1", 9, absl::nullopt, + nullptr}, + {"0x7c starting a parameterised list key", "foo; |a=1", 9, absl::nullopt, + nullptr}, + {"0x7d starting a parameterised list key", "foo; }a=1", 9, absl::nullopt, + nullptr}, + {"0x7e starting a parameterised list key", "foo; ~a=1", 9, absl::nullopt, + nullptr}, + {"0x7f starting a parameterised list key", "foo; \177a=1", 9, absl::nullopt, + nullptr}, + // list.json + {"basic list", + "1, 42", + 5, + {{{Integer(1), {}}, {Integer(42), {}}}}, + nullptr}, + {"empty list", "", 0, {List()}, nullptr}, + {"leading SP list", + " 42, 43", + 8, + {{{Integer(42), {}}, {Integer(43), {}}}}, + "42, 43"}, + {"single item list", "42", 2, {{{Integer(42), {}}}}, nullptr}, + {"no whitespace list", + "1,42", + 4, + {{{Integer(1), {}}, {Integer(42), {}}}}, + "1, 42"}, + {"extra whitespace list", + "1 , 42", + 6, + {{{Integer(1), {}}, {Integer(42), {}}}}, + "1, 42"}, + {"tab separated list", + "1\t,\t42", + 6, + {{{Integer(1), {}}, {Integer(42), {}}}}, + "1, 42"}, + {"two line list", + "1, 42", + 5, + {{{Integer(1), {}}, {Integer(42), {}}}}, + "1, 42"}, + {"trailing comma list", "1, 42,", 6, absl::nullopt, nullptr}, + {"empty item list", "1,,42", 5, absl::nullopt, nullptr}, + {"empty item list (multiple field lines)", "1, , 42", 7, absl::nullopt, + nullptr}, + // listlist.json + {"basic list of lists", + "(1 2), (42 43)", + 14, + {{{{{Integer(1), {}}, {Integer(2), {}}}, {}}, + {{{Integer(42), {}}, {Integer(43), {}}}, {}}}}, + nullptr}, + {"single item list of lists", + "(42)", + 4, + {{{{{Integer(42), {}}}, {}}}}, + nullptr}, + {"empty item list of lists", + "()", + 2, + {{{std::vector(), {}}}}, + nullptr}, + {"empty middle item list of lists", + "(1),(),(42)", + 11, + {{{{{Integer(1), {}}}, {}}, + {std::vector(), {}}, + {{{Integer(42), {}}}, {}}}}, + "(1), (), (42)"}, + {"extra whitespace list of lists", + "( 1 42 )", + 11, + {{{{{Integer(1), {}}, {Integer(42), {}}}, {}}}}, + "(1 42)"}, + {"wrong whitespace list of lists", "(1\t 42)", 7, absl::nullopt, nullptr}, + {"no trailing parenthesis list of lists", "(1 42", 5, absl::nullopt, + nullptr}, + {"no trailing parenthesis middle list of lists", "(1 2, (42 43)", 13, + absl::nullopt, nullptr}, + {"no spaces in inner-list", "(abc\"def\"?0123*dXZ3*xyz)", 24, absl::nullopt, + nullptr}, + {"no closing parenthesis", "(", 1, absl::nullopt, nullptr}, + // param-list.json + {"basic parameterised list", + "abc_123;a=1;b=2; cdef_456, ghi;q=9;r=\"+w\"", + 41, + {{{Item("abc_123", Item::kTokenType), + {Param("a", 1), Param("b", 2), BooleanParam("cdef_456", true)}}, + {Item("ghi", Item::kTokenType), {Param("q", 9), Param("r", "+w")}}}}, + "abc_123;a=1;b=2;cdef_456, ghi;q=9;r=\"+w\""}, + {"single item parameterised list", + "text/html;q=1.0", + 15, + {{{Item("text/html", Item::kTokenType), {DoubleParam("q", 1.000000)}}}}, + nullptr}, + {"missing parameter value parameterised list", + "text/html;a;q=1.0", + 17, + {{{Item("text/html", Item::kTokenType), + {BooleanParam("a", true), DoubleParam("q", 1.000000)}}}}, + nullptr}, + {"missing terminal parameter value parameterised list", + "text/html;q=1.0;a", + 17, + {{{Item("text/html", Item::kTokenType), + {DoubleParam("q", 1.000000), BooleanParam("a", true)}}}}, + nullptr}, + {"no whitespace parameterised list", + "text/html,text/plain;q=0.5", + 26, + {{{Item("text/html", Item::kTokenType), {}}, + {Item("text/plain", Item::kTokenType), {DoubleParam("q", 0.500000)}}}}, + "text/html, text/plain;q=0.5"}, + {"whitespace before = parameterised list", "text/html, text/plain;q =0.5", + 28, absl::nullopt, nullptr}, + {"whitespace after = parameterised list", "text/html, text/plain;q= 0.5", + 28, absl::nullopt, nullptr}, + {"whitespace before ; parameterised list", "text/html, text/plain ;q=0.5", + 28, absl::nullopt, nullptr}, + {"whitespace after ; parameterised list", + "text/html, text/plain; q=0.5", + 28, + {{{Item("text/html", Item::kTokenType), {}}, + {Item("text/plain", Item::kTokenType), {DoubleParam("q", 0.500000)}}}}, + "text/html, text/plain;q=0.5"}, + {"extra whitespace parameterised list", + "text/html , text/plain; q=0.5; charset=utf-8", + 48, + {{{Item("text/html", Item::kTokenType), {}}, + {Item("text/plain", Item::kTokenType), + {DoubleParam("q", 0.500000), TokenParam("charset", "utf-8")}}}}, + "text/html, text/plain;q=0.5;charset=utf-8"}, + {"two lines parameterised list", + "text/html, text/plain;q=0.5", + 27, + {{{Item("text/html", Item::kTokenType), {}}, + {Item("text/plain", Item::kTokenType), {DoubleParam("q", 0.500000)}}}}, + "text/html, text/plain;q=0.5"}, + {"trailing comma parameterised list", "text/html,text/plain;q=0.5,", 27, + absl::nullopt, nullptr}, + {"empty item parameterised list", "text/html,,text/plain;q=0.5,", 28, + absl::nullopt, nullptr}, + // param-listlist.json + {"parameterised inner list", + "(abc_123);a=1;b=2, cdef_456", + 27, + {{{{{Item("abc_123", Item::kTokenType), {}}}, + {Param("a", 1), Param("b", 2)}}, + {Item("cdef_456", Item::kTokenType), {}}}}, + nullptr}, + {"parameterised inner list item", + "(abc_123;a=1;b=2;cdef_456)", + 26, + {{{{{Item("abc_123", Item::kTokenType), + {Param("a", 1), Param("b", 2), BooleanParam("cdef_456", true)}}}, + {}}}}, + nullptr}, + {"parameterised inner list with parameterised item", + "(abc_123;a=1;b=2);cdef_456", + 26, + {{{{{Item("abc_123", Item::kTokenType), {Param("a", 1), Param("b", 2)}}}, + {BooleanParam("cdef_456", true)}}}}, + nullptr}, + // token.json + {"basic token - list", + "a_b-c3/*", + 8, + {{{Item("a_b-c3/*", Item::kTokenType), {}}}}, + nullptr}, + {"token with capitals - list", + "fooBar", + 6, + {{{Item("fooBar", Item::kTokenType), {}}}}, + nullptr}, + {"token starting with capitals - list", + "FooBar", + 6, + {{{Item("FooBar", Item::kTokenType), {}}}}, + nullptr}, +}; + +const struct DictionaryTestCase { + const char* name; + const char* raw; + size_t raw_len; + const absl::optional + expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} dictionary_test_cases[] = { + // dictionary.json + {"basic dictionary", + "en=\"Applepie\", da=:w4ZibGV0w6ZydGUK:", + 36, + {Dictionary{ + {{"en", {Item("Applepie"), {}}}, + {"da", + {Item("\303\206blet\303\246rte\n", Item::kByteSequenceType), {}}}}}}, + nullptr}, + {"empty dictionary", "", 0, {Dictionary{{}}}, nullptr}, + {"single item dictionary", + "a=1", + 3, + {Dictionary{{{"a", {Integer(1), {}}}}}}, + nullptr}, + {"list item dictionary", + "a=(1 2)", + 7, + {Dictionary{{{"a", {{{Integer(1), {}}, {Integer(2), {}}}, {}}}}}}, + nullptr}, + {"single list item dictionary", + "a=(1)", + 5, + {Dictionary{{{"a", {{{Integer(1), {}}}, {}}}}}}, + nullptr}, + {"empty list item dictionary", + "a=()", + 4, + {Dictionary{{{"a", {std::vector(), {}}}}}}, + nullptr}, + {"no whitespace dictionary", + "a=1,b=2", + 7, + {Dictionary{{{"a", {Integer(1), {}}}, {"b", {Integer(2), {}}}}}}, + "a=1, b=2"}, + {"extra whitespace dictionary", + "a=1 , b=2", + 10, + {Dictionary{{{"a", {Integer(1), {}}}, {"b", {Integer(2), {}}}}}}, + "a=1, b=2"}, + {"tab separated dictionary", + "a=1\t,\tb=2", + 9, + {Dictionary{{{"a", {Integer(1), {}}}, {"b", {Integer(2), {}}}}}}, + "a=1, b=2"}, + {"leading whitespace dictionary", + " a=1 , b=2", + 15, + {Dictionary{{{"a", {Integer(1), {}}}, {"b", {Integer(2), {}}}}}}, + "a=1, b=2"}, + {"whitespace before = dictionary", "a =1, b=2", 9, absl::nullopt, nullptr}, + {"whitespace after = dictionary", "a=1, b= 2", 9, absl::nullopt, nullptr}, + {"two lines dictionary", + "a=1, b=2", + 8, + {Dictionary{{{"a", {Integer(1), {}}}, {"b", {Integer(2), {}}}}}}, + "a=1, b=2"}, + {"missing value dictionary", + "a=1, b, c=3", + 11, + {Dictionary{{{"a", {Integer(1), {}}}, + {"b", {Item(true), {}}}, + {"c", {Integer(3), {}}}}}}, + nullptr}, + {"all missing value dictionary", + "a, b, c", + 7, + {Dictionary{{{"a", {Item(true), {}}}, + {"b", {Item(true), {}}}, + {"c", {Item(true), {}}}}}}, + nullptr}, + {"start missing value dictionary", + "a, b=2", + 6, + {Dictionary{{{"a", {Item(true), {}}}, {"b", {Integer(2), {}}}}}}, + nullptr}, + {"end missing value dictionary", + "a=1, b", + 6, + {Dictionary{{{"a", {Integer(1), {}}}, {"b", {Item(true), {}}}}}}, + nullptr}, + {"missing value with params dictionary", + "a=1, b;foo=9, c=3", + 17, + {Dictionary{{{"a", {Integer(1), {}}}, + {"b", {Item(true), {Param("foo", 9)}}}, + {"c", {Integer(3), {}}}}}}, + nullptr}, + {"explicit true value with params dictionary", + "a=1, b=?1;foo=9, c=3", + 20, + {Dictionary{{{"a", {Integer(1), {}}}, + {"b", {Item(true), {Param("foo", 9)}}}, + {"c", {Integer(3), {}}}}}}, + "a=1, b;foo=9, c=3"}, + {"trailing comma dictionary", "a=1, b=2,", 9, absl::nullopt, nullptr}, + {"empty item dictionary", "a=1,,b=2,", 9, absl::nullopt, nullptr}, + {"duplicate key dictionary", + "a=1,b=2,a=3", + 11, + {Dictionary{{{"a", {Integer(3), {}}}, {"b", {Integer(2), {}}}}}}, + "a=3, b=2"}, + {"numeric key dictionary", "a=1,1b=2,a=1", 12, absl::nullopt, nullptr}, + {"uppercase key dictionary", "a=1,B=2,a=1", 11, absl::nullopt, nullptr}, + {"bad key dictionary", "a=1,b!=2,a=1", 12, absl::nullopt, nullptr}, + // examples.json + {"Example-DictHeader", + "en=\"Applepie\", da=:w4ZibGV0w6ZydGU=:", + 36, + {Dictionary{ + {{"en", {Item("Applepie"), {}}}, + {"da", + {Item("\303\206blet\303\246rte", Item::kByteSequenceType), {}}}}}}, + nullptr}, + {"Example-DictHeader (boolean values)", + "a=?0, b, c; foo=bar", + 19, + {Dictionary{{{"a", {Item(false), {}}}, + {"b", {Item(true), {}}}, + {"c", {Item(true), {TokenParam("foo", "bar")}}}}}}, + "a=?0, b, c;foo=bar"}, + {"Example-DictListHeader", + "rating=1.5, feelings=(joy sadness)", + 34, + {Dictionary{{{"rating", {Item(1.500000), {}}}, + {"feelings", + {{{Item("joy", Item::kTokenType), {}}, + {Item("sadness", Item::kTokenType), {}}}, + {}}}}}}, + nullptr}, + {"Example-MixDict", + "a=(1 2), b=3, c=4;aa=bb, d=(5 6);valid", + 38, + {Dictionary{{{"a", {{{Integer(1), {}}, {Integer(2), {}}}, {}}}, + {"b", {Integer(3), {}}}, + {"c", {Integer(4), {TokenParam("aa", "bb")}}}, + {"d", + {{{Integer(5), {}}, {Integer(6), {}}}, + {BooleanParam("valid", true)}}}}}}, + "a=(1 2), b=3, c=4;aa=bb, d=(5 6);valid"}, + {"Example-Hdr (dictionary on one line)", + "foo=1, bar=2", + 12, + {Dictionary{{{"foo", {Integer(1), {}}}, {"bar", {Integer(2), {}}}}}}, + nullptr}, + {"Example-Hdr (dictionary on two lines)", + "foo=1, bar=2", + 12, + {Dictionary{{{"foo", {Integer(1), {}}}, {"bar", {Integer(2), {}}}}}}, + "foo=1, bar=2"}, + // key-generated.json + {"0x00 as a single-character dictionary key", "\000=1", 3, absl::nullopt, + nullptr}, + {"0x01 as a single-character dictionary key", "\001=1", 3, absl::nullopt, + nullptr}, + {"0x02 as a single-character dictionary key", "\002=1", 3, absl::nullopt, + nullptr}, + {"0x03 as a single-character dictionary key", "\003=1", 3, absl::nullopt, + nullptr}, + {"0x04 as a single-character dictionary key", "\004=1", 3, absl::nullopt, + nullptr}, + {"0x05 as a single-character dictionary key", "\005=1", 3, absl::nullopt, + nullptr}, + {"0x06 as a single-character dictionary key", "\006=1", 3, absl::nullopt, + nullptr}, + {"0x07 as a single-character dictionary key", "\a=1", 3, absl::nullopt, + nullptr}, + {"0x08 as a single-character dictionary key", "\b=1", 3, absl::nullopt, + nullptr}, + {"0x09 as a single-character dictionary key", "\t=1", 3, absl::nullopt, + nullptr}, + {"0x0a as a single-character dictionary key", "\n=1", 3, absl::nullopt, + nullptr}, + {"0x0b as a single-character dictionary key", "\v=1", 3, absl::nullopt, + nullptr}, + {"0x0c as a single-character dictionary key", "\f=1", 3, absl::nullopt, + nullptr}, + {"0x0d as a single-character dictionary key", "\r=1", 3, absl::nullopt, + nullptr}, + {"0x0e as a single-character dictionary key", "\016=1", 3, absl::nullopt, + nullptr}, + {"0x0f as a single-character dictionary key", "\017=1", 3, absl::nullopt, + nullptr}, + {"0x10 as a single-character dictionary key", "\020=1", 3, absl::nullopt, + nullptr}, + {"0x11 as a single-character dictionary key", "\021=1", 3, absl::nullopt, + nullptr}, + {"0x12 as a single-character dictionary key", "\022=1", 3, absl::nullopt, + nullptr}, + {"0x13 as a single-character dictionary key", "\023=1", 3, absl::nullopt, + nullptr}, + {"0x14 as a single-character dictionary key", "\024=1", 3, absl::nullopt, + nullptr}, + {"0x15 as a single-character dictionary key", "\025=1", 3, absl::nullopt, + nullptr}, + {"0x16 as a single-character dictionary key", "\026=1", 3, absl::nullopt, + nullptr}, + {"0x17 as a single-character dictionary key", "\027=1", 3, absl::nullopt, + nullptr}, + {"0x18 as a single-character dictionary key", "\030=1", 3, absl::nullopt, + nullptr}, + {"0x19 as a single-character dictionary key", "\031=1", 3, absl::nullopt, + nullptr}, + {"0x1a as a single-character dictionary key", "\032=1", 3, absl::nullopt, + nullptr}, + {"0x1b as a single-character dictionary key", "\033=1", 3, absl::nullopt, + nullptr}, + {"0x1c as a single-character dictionary key", "\034=1", 3, absl::nullopt, + nullptr}, + {"0x1d as a single-character dictionary key", "\035=1", 3, absl::nullopt, + nullptr}, + {"0x1e as a single-character dictionary key", "\036=1", 3, absl::nullopt, + nullptr}, + {"0x1f as a single-character dictionary key", "\037=1", 3, absl::nullopt, + nullptr}, + {"0x20 as a single-character dictionary key", "=1", 2, absl::nullopt, + nullptr}, + {"0x21 as a single-character dictionary key", "!=1", 3, absl::nullopt, + nullptr}, + {"0x22 as a single-character dictionary key", "\"=1", 3, absl::nullopt, + nullptr}, + {"0x23 as a single-character dictionary key", "#=1", 3, absl::nullopt, + nullptr}, + {"0x24 as a single-character dictionary key", "$=1", 3, absl::nullopt, + nullptr}, + {"0x25 as a single-character dictionary key", "%=1", 3, absl::nullopt, + nullptr}, + {"0x26 as a single-character dictionary key", "&=1", 3, absl::nullopt, + nullptr}, + {"0x27 as a single-character dictionary key", "'=1", 3, absl::nullopt, + nullptr}, + {"0x28 as a single-character dictionary key", "(=1", 3, absl::nullopt, + nullptr}, + {"0x29 as a single-character dictionary key", ")=1", 3, absl::nullopt, + nullptr}, + {"0x2a as a single-character dictionary key", + "*=1", + 3, + {Dictionary{{{"*", {Integer(1), {}}}}}}, + nullptr}, + {"0x2b as a single-character dictionary key", "+=1", 3, absl::nullopt, + nullptr}, + {"0x2c as a single-character dictionary key", ",=1", 3, absl::nullopt, + nullptr}, + {"0x2d as a single-character dictionary key", "-=1", 3, absl::nullopt, + nullptr}, + {"0x2e as a single-character dictionary key", ".=1", 3, absl::nullopt, + nullptr}, + {"0x2f as a single-character dictionary key", "/=1", 3, absl::nullopt, + nullptr}, + {"0x30 as a single-character dictionary key", "0=1", 3, absl::nullopt, + nullptr}, + {"0x31 as a single-character dictionary key", "1=1", 3, absl::nullopt, + nullptr}, + {"0x32 as a single-character dictionary key", "2=1", 3, absl::nullopt, + nullptr}, + {"0x33 as a single-character dictionary key", "3=1", 3, absl::nullopt, + nullptr}, + {"0x34 as a single-character dictionary key", "4=1", 3, absl::nullopt, + nullptr}, + {"0x35 as a single-character dictionary key", "5=1", 3, absl::nullopt, + nullptr}, + {"0x36 as a single-character dictionary key", "6=1", 3, absl::nullopt, + nullptr}, + {"0x37 as a single-character dictionary key", "7=1", 3, absl::nullopt, + nullptr}, + {"0x38 as a single-character dictionary key", "8=1", 3, absl::nullopt, + nullptr}, + {"0x39 as a single-character dictionary key", "9=1", 3, absl::nullopt, + nullptr}, + {"0x3a as a single-character dictionary key", ":=1", 3, absl::nullopt, + nullptr}, + {"0x3b as a single-character dictionary key", ";=1", 3, absl::nullopt, + nullptr}, + {"0x3c as a single-character dictionary key", "<=1", 3, absl::nullopt, + nullptr}, + {"0x3d as a single-character dictionary key", "==1", 3, absl::nullopt, + nullptr}, + {"0x3e as a single-character dictionary key", ">=1", 3, absl::nullopt, + nullptr}, + {"0x3f as a single-character dictionary key", "?=1", 3, absl::nullopt, + nullptr}, + {"0x40 as a single-character dictionary key", "@=1", 3, absl::nullopt, + nullptr}, + {"0x41 as a single-character dictionary key", "A=1", 3, absl::nullopt, + nullptr}, + {"0x42 as a single-character dictionary key", "B=1", 3, absl::nullopt, + nullptr}, + {"0x43 as a single-character dictionary key", "C=1", 3, absl::nullopt, + nullptr}, + {"0x44 as a single-character dictionary key", "D=1", 3, absl::nullopt, + nullptr}, + {"0x45 as a single-character dictionary key", "E=1", 3, absl::nullopt, + nullptr}, + {"0x46 as a single-character dictionary key", "F=1", 3, absl::nullopt, + nullptr}, + {"0x47 as a single-character dictionary key", "G=1", 3, absl::nullopt, + nullptr}, + {"0x48 as a single-character dictionary key", "H=1", 3, absl::nullopt, + nullptr}, + {"0x49 as a single-character dictionary key", "I=1", 3, absl::nullopt, + nullptr}, + {"0x4a as a single-character dictionary key", "J=1", 3, absl::nullopt, + nullptr}, + {"0x4b as a single-character dictionary key", "K=1", 3, absl::nullopt, + nullptr}, + {"0x4c as a single-character dictionary key", "L=1", 3, absl::nullopt, + nullptr}, + {"0x4d as a single-character dictionary key", "M=1", 3, absl::nullopt, + nullptr}, + {"0x4e as a single-character dictionary key", "N=1", 3, absl::nullopt, + nullptr}, + {"0x4f as a single-character dictionary key", "O=1", 3, absl::nullopt, + nullptr}, + {"0x50 as a single-character dictionary key", "P=1", 3, absl::nullopt, + nullptr}, + {"0x51 as a single-character dictionary key", "Q=1", 3, absl::nullopt, + nullptr}, + {"0x52 as a single-character dictionary key", "R=1", 3, absl::nullopt, + nullptr}, + {"0x53 as a single-character dictionary key", "S=1", 3, absl::nullopt, + nullptr}, + {"0x54 as a single-character dictionary key", "T=1", 3, absl::nullopt, + nullptr}, + {"0x55 as a single-character dictionary key", "U=1", 3, absl::nullopt, + nullptr}, + {"0x56 as a single-character dictionary key", "V=1", 3, absl::nullopt, + nullptr}, + {"0x57 as a single-character dictionary key", "W=1", 3, absl::nullopt, + nullptr}, + {"0x58 as a single-character dictionary key", "X=1", 3, absl::nullopt, + nullptr}, + {"0x59 as a single-character dictionary key", "Y=1", 3, absl::nullopt, + nullptr}, + {"0x5a as a single-character dictionary key", "Z=1", 3, absl::nullopt, + nullptr}, + {"0x5b as a single-character dictionary key", "[=1", 3, absl::nullopt, + nullptr}, + {"0x5c as a single-character dictionary key", "\\=1", 3, absl::nullopt, + nullptr}, + {"0x5d as a single-character dictionary key", "]=1", 3, absl::nullopt, + nullptr}, + {"0x5e as a single-character dictionary key", "^=1", 3, absl::nullopt, + nullptr}, + {"0x5f as a single-character dictionary key", "_=1", 3, absl::nullopt, + nullptr}, + {"0x60 as a single-character dictionary key", "`=1", 3, absl::nullopt, + nullptr}, + {"0x61 as a single-character dictionary key", + "a=1", + 3, + {Dictionary{{{"a", {Integer(1), {}}}}}}, + nullptr}, + {"0x62 as a single-character dictionary key", + "b=1", + 3, + {Dictionary{{{"b", {Integer(1), {}}}}}}, + nullptr}, + {"0x63 as a single-character dictionary key", + "c=1", + 3, + {Dictionary{{{"c", {Integer(1), {}}}}}}, + nullptr}, + {"0x64 as a single-character dictionary key", + "d=1", + 3, + {Dictionary{{{"d", {Integer(1), {}}}}}}, + nullptr}, + {"0x65 as a single-character dictionary key", + "e=1", + 3, + {Dictionary{{{"e", {Integer(1), {}}}}}}, + nullptr}, + {"0x66 as a single-character dictionary key", + "f=1", + 3, + {Dictionary{{{"f", {Integer(1), {}}}}}}, + nullptr}, + {"0x67 as a single-character dictionary key", + "g=1", + 3, + {Dictionary{{{"g", {Integer(1), {}}}}}}, + nullptr}, + {"0x68 as a single-character dictionary key", + "h=1", + 3, + {Dictionary{{{"h", {Integer(1), {}}}}}}, + nullptr}, + {"0x69 as a single-character dictionary key", + "i=1", + 3, + {Dictionary{{{"i", {Integer(1), {}}}}}}, + nullptr}, + {"0x6a as a single-character dictionary key", + "j=1", + 3, + {Dictionary{{{"j", {Integer(1), {}}}}}}, + nullptr}, + {"0x6b as a single-character dictionary key", + "k=1", + 3, + {Dictionary{{{"k", {Integer(1), {}}}}}}, + nullptr}, + {"0x6c as a single-character dictionary key", + "l=1", + 3, + {Dictionary{{{"l", {Integer(1), {}}}}}}, + nullptr}, + {"0x6d as a single-character dictionary key", + "m=1", + 3, + {Dictionary{{{"m", {Integer(1), {}}}}}}, + nullptr}, + {"0x6e as a single-character dictionary key", + "n=1", + 3, + {Dictionary{{{"n", {Integer(1), {}}}}}}, + nullptr}, + {"0x6f as a single-character dictionary key", + "o=1", + 3, + {Dictionary{{{"o", {Integer(1), {}}}}}}, + nullptr}, + {"0x70 as a single-character dictionary key", + "p=1", + 3, + {Dictionary{{{"p", {Integer(1), {}}}}}}, + nullptr}, + {"0x71 as a single-character dictionary key", + "q=1", + 3, + {Dictionary{{{"q", {Integer(1), {}}}}}}, + nullptr}, + {"0x72 as a single-character dictionary key", + "r=1", + 3, + {Dictionary{{{"r", {Integer(1), {}}}}}}, + nullptr}, + {"0x73 as a single-character dictionary key", + "s=1", + 3, + {Dictionary{{{"s", {Integer(1), {}}}}}}, + nullptr}, + {"0x74 as a single-character dictionary key", + "t=1", + 3, + {Dictionary{{{"t", {Integer(1), {}}}}}}, + nullptr}, + {"0x75 as a single-character dictionary key", + "u=1", + 3, + {Dictionary{{{"u", {Integer(1), {}}}}}}, + nullptr}, + {"0x76 as a single-character dictionary key", + "v=1", + 3, + {Dictionary{{{"v", {Integer(1), {}}}}}}, + nullptr}, + {"0x77 as a single-character dictionary key", + "w=1", + 3, + {Dictionary{{{"w", {Integer(1), {}}}}}}, + nullptr}, + {"0x78 as a single-character dictionary key", + "x=1", + 3, + {Dictionary{{{"x", {Integer(1), {}}}}}}, + nullptr}, + {"0x79 as a single-character dictionary key", + "y=1", + 3, + {Dictionary{{{"y", {Integer(1), {}}}}}}, + nullptr}, + {"0x7a as a single-character dictionary key", + "z=1", + 3, + {Dictionary{{{"z", {Integer(1), {}}}}}}, + nullptr}, + {"0x7b as a single-character dictionary key", "{=1", 3, absl::nullopt, + nullptr}, + {"0x7c as a single-character dictionary key", "|=1", 3, absl::nullopt, + nullptr}, + {"0x7d as a single-character dictionary key", "}=1", 3, absl::nullopt, + nullptr}, + {"0x7e as a single-character dictionary key", "~=1", 3, absl::nullopt, + nullptr}, + {"0x7f as a single-character dictionary key", "\177=1", 3, absl::nullopt, + nullptr}, + {"0x00 in dictionary key", "a\000a=1", 5, absl::nullopt, nullptr}, + {"0x01 in dictionary key", "a\001a=1", 5, absl::nullopt, nullptr}, + {"0x02 in dictionary key", "a\002a=1", 5, absl::nullopt, nullptr}, + {"0x03 in dictionary key", "a\003a=1", 5, absl::nullopt, nullptr}, + {"0x04 in dictionary key", "a\004a=1", 5, absl::nullopt, nullptr}, + {"0x05 in dictionary key", "a\005a=1", 5, absl::nullopt, nullptr}, + {"0x06 in dictionary key", "a\006a=1", 5, absl::nullopt, nullptr}, + {"0x07 in dictionary key", "a\aa=1", 5, absl::nullopt, nullptr}, + {"0x08 in dictionary key", "a\ba=1", 5, absl::nullopt, nullptr}, + {"0x09 in dictionary key", "a\ta=1", 5, absl::nullopt, nullptr}, + {"0x0a in dictionary key", "a\na=1", 5, absl::nullopt, nullptr}, + {"0x0b in dictionary key", "a\va=1", 5, absl::nullopt, nullptr}, + {"0x0c in dictionary key", "a\fa=1", 5, absl::nullopt, nullptr}, + {"0x0d in dictionary key", "a\ra=1", 5, absl::nullopt, nullptr}, + {"0x0e in dictionary key", "a\016a=1", 5, absl::nullopt, nullptr}, + {"0x0f in dictionary key", "a\017a=1", 5, absl::nullopt, nullptr}, + {"0x10 in dictionary key", "a\020a=1", 5, absl::nullopt, nullptr}, + {"0x11 in dictionary key", "a\021a=1", 5, absl::nullopt, nullptr}, + {"0x12 in dictionary key", "a\022a=1", 5, absl::nullopt, nullptr}, + {"0x13 in dictionary key", "a\023a=1", 5, absl::nullopt, nullptr}, + {"0x14 in dictionary key", "a\024a=1", 5, absl::nullopt, nullptr}, + {"0x15 in dictionary key", "a\025a=1", 5, absl::nullopt, nullptr}, + {"0x16 in dictionary key", "a\026a=1", 5, absl::nullopt, nullptr}, + {"0x17 in dictionary key", "a\027a=1", 5, absl::nullopt, nullptr}, + {"0x18 in dictionary key", "a\030a=1", 5, absl::nullopt, nullptr}, + {"0x19 in dictionary key", "a\031a=1", 5, absl::nullopt, nullptr}, + {"0x1a in dictionary key", "a\032a=1", 5, absl::nullopt, nullptr}, + {"0x1b in dictionary key", "a\033a=1", 5, absl::nullopt, nullptr}, + {"0x1c in dictionary key", "a\034a=1", 5, absl::nullopt, nullptr}, + {"0x1d in dictionary key", "a\035a=1", 5, absl::nullopt, nullptr}, + {"0x1e in dictionary key", "a\036a=1", 5, absl::nullopt, nullptr}, + {"0x1f in dictionary key", "a\037a=1", 5, absl::nullopt, nullptr}, + {"0x20 in dictionary key", "a a=1", 5, absl::nullopt, nullptr}, + {"0x21 in dictionary key", "a!a=1", 5, absl::nullopt, nullptr}, + {"0x22 in dictionary key", "a\"a=1", 5, absl::nullopt, nullptr}, + {"0x23 in dictionary key", "a#a=1", 5, absl::nullopt, nullptr}, + {"0x24 in dictionary key", "a$a=1", 5, absl::nullopt, nullptr}, + {"0x25 in dictionary key", "a%a=1", 5, absl::nullopt, nullptr}, + {"0x26 in dictionary key", "a&a=1", 5, absl::nullopt, nullptr}, + {"0x27 in dictionary key", "a'a=1", 5, absl::nullopt, nullptr}, + {"0x28 in dictionary key", "a(a=1", 5, absl::nullopt, nullptr}, + {"0x29 in dictionary key", "a)a=1", 5, absl::nullopt, nullptr}, + {"0x2a in dictionary key", + "a*a=1", + 5, + {Dictionary{{{"a*a", {Integer(1), {}}}}}}, + nullptr}, + {"0x2b in dictionary key", "a+a=1", 5, absl::nullopt, nullptr}, + {"0x2c in dictionary key", + "a,a=1", + 5, + {Dictionary{{{"a", {Integer(1), {}}}}}}, + "a=1"}, + {"0x2d in dictionary key", + "a-a=1", + 5, + {Dictionary{{{"a-a", {Integer(1), {}}}}}}, + nullptr}, + {"0x2e in dictionary key", + "a.a=1", + 5, + {Dictionary{{{"a.a", {Integer(1), {}}}}}}, + nullptr}, + {"0x2f in dictionary key", "a/a=1", 5, absl::nullopt, nullptr}, + {"0x30 in dictionary key", + "a0a=1", + 5, + {Dictionary{{{"a0a", {Integer(1), {}}}}}}, + nullptr}, + {"0x31 in dictionary key", + "a1a=1", + 5, + {Dictionary{{{"a1a", {Integer(1), {}}}}}}, + nullptr}, + {"0x32 in dictionary key", + "a2a=1", + 5, + {Dictionary{{{"a2a", {Integer(1), {}}}}}}, + nullptr}, + {"0x33 in dictionary key", + "a3a=1", + 5, + {Dictionary{{{"a3a", {Integer(1), {}}}}}}, + nullptr}, + {"0x34 in dictionary key", + "a4a=1", + 5, + {Dictionary{{{"a4a", {Integer(1), {}}}}}}, + nullptr}, + {"0x35 in dictionary key", + "a5a=1", + 5, + {Dictionary{{{"a5a", {Integer(1), {}}}}}}, + nullptr}, + {"0x36 in dictionary key", + "a6a=1", + 5, + {Dictionary{{{"a6a", {Integer(1), {}}}}}}, + nullptr}, + {"0x37 in dictionary key", + "a7a=1", + 5, + {Dictionary{{{"a7a", {Integer(1), {}}}}}}, + nullptr}, + {"0x38 in dictionary key", + "a8a=1", + 5, + {Dictionary{{{"a8a", {Integer(1), {}}}}}}, + nullptr}, + {"0x39 in dictionary key", + "a9a=1", + 5, + {Dictionary{{{"a9a", {Integer(1), {}}}}}}, + nullptr}, + {"0x3a in dictionary key", "a:a=1", 5, absl::nullopt, nullptr}, + {"0x3b in dictionary key", + "a;a=1", + 5, + {Dictionary{{{"a", {Item(true), {Param("a", 1)}}}}}}, + nullptr}, + {"0x3c in dictionary key", "aa=1", 5, absl::nullopt, nullptr}, + {"0x3f in dictionary key", "a?a=1", 5, absl::nullopt, nullptr}, + {"0x40 in dictionary key", "a@a=1", 5, absl::nullopt, nullptr}, + {"0x41 in dictionary key", "aAa=1", 5, absl::nullopt, nullptr}, + {"0x42 in dictionary key", "aBa=1", 5, absl::nullopt, nullptr}, + {"0x43 in dictionary key", "aCa=1", 5, absl::nullopt, nullptr}, + {"0x44 in dictionary key", "aDa=1", 5, absl::nullopt, nullptr}, + {"0x45 in dictionary key", "aEa=1", 5, absl::nullopt, nullptr}, + {"0x46 in dictionary key", "aFa=1", 5, absl::nullopt, nullptr}, + {"0x47 in dictionary key", "aGa=1", 5, absl::nullopt, nullptr}, + {"0x48 in dictionary key", "aHa=1", 5, absl::nullopt, nullptr}, + {"0x49 in dictionary key", "aIa=1", 5, absl::nullopt, nullptr}, + {"0x4a in dictionary key", "aJa=1", 5, absl::nullopt, nullptr}, + {"0x4b in dictionary key", "aKa=1", 5, absl::nullopt, nullptr}, + {"0x4c in dictionary key", "aLa=1", 5, absl::nullopt, nullptr}, + {"0x4d in dictionary key", "aMa=1", 5, absl::nullopt, nullptr}, + {"0x4e in dictionary key", "aNa=1", 5, absl::nullopt, nullptr}, + {"0x4f in dictionary key", "aOa=1", 5, absl::nullopt, nullptr}, + {"0x50 in dictionary key", "aPa=1", 5, absl::nullopt, nullptr}, + {"0x51 in dictionary key", "aQa=1", 5, absl::nullopt, nullptr}, + {"0x52 in dictionary key", "aRa=1", 5, absl::nullopt, nullptr}, + {"0x53 in dictionary key", "aSa=1", 5, absl::nullopt, nullptr}, + {"0x54 in dictionary key", "aTa=1", 5, absl::nullopt, nullptr}, + {"0x55 in dictionary key", "aUa=1", 5, absl::nullopt, nullptr}, + {"0x56 in dictionary key", "aVa=1", 5, absl::nullopt, nullptr}, + {"0x57 in dictionary key", "aWa=1", 5, absl::nullopt, nullptr}, + {"0x58 in dictionary key", "aXa=1", 5, absl::nullopt, nullptr}, + {"0x59 in dictionary key", "aYa=1", 5, absl::nullopt, nullptr}, + {"0x5a in dictionary key", "aZa=1", 5, absl::nullopt, nullptr}, + {"0x5b in dictionary key", "a[a=1", 5, absl::nullopt, nullptr}, + {"0x5c in dictionary key", "a\\a=1", 5, absl::nullopt, nullptr}, + {"0x5d in dictionary key", "a]a=1", 5, absl::nullopt, nullptr}, + {"0x5e in dictionary key", "a^a=1", 5, absl::nullopt, nullptr}, + {"0x5f in dictionary key", + "a_a=1", + 5, + {Dictionary{{{"a_a", {Integer(1), {}}}}}}, + nullptr}, + {"0x60 in dictionary key", "a`a=1", 5, absl::nullopt, nullptr}, + {"0x61 in dictionary key", + "aaa=1", + 5, + {Dictionary{{{"aaa", {Integer(1), {}}}}}}, + nullptr}, + {"0x62 in dictionary key", + "aba=1", + 5, + {Dictionary{{{"aba", {Integer(1), {}}}}}}, + nullptr}, + {"0x63 in dictionary key", + "aca=1", + 5, + {Dictionary{{{"aca", {Integer(1), {}}}}}}, + nullptr}, + {"0x64 in dictionary key", + "ada=1", + 5, + {Dictionary{{{"ada", {Integer(1), {}}}}}}, + nullptr}, + {"0x65 in dictionary key", + "aea=1", + 5, + {Dictionary{{{"aea", {Integer(1), {}}}}}}, + nullptr}, + {"0x66 in dictionary key", + "afa=1", + 5, + {Dictionary{{{"afa", {Integer(1), {}}}}}}, + nullptr}, + {"0x67 in dictionary key", + "aga=1", + 5, + {Dictionary{{{"aga", {Integer(1), {}}}}}}, + nullptr}, + {"0x68 in dictionary key", + "aha=1", + 5, + {Dictionary{{{"aha", {Integer(1), {}}}}}}, + nullptr}, + {"0x69 in dictionary key", + "aia=1", + 5, + {Dictionary{{{"aia", {Integer(1), {}}}}}}, + nullptr}, + {"0x6a in dictionary key", + "aja=1", + 5, + {Dictionary{{{"aja", {Integer(1), {}}}}}}, + nullptr}, + {"0x6b in dictionary key", + "aka=1", + 5, + {Dictionary{{{"aka", {Integer(1), {}}}}}}, + nullptr}, + {"0x6c in dictionary key", + "ala=1", + 5, + {Dictionary{{{"ala", {Integer(1), {}}}}}}, + nullptr}, + {"0x6d in dictionary key", + "ama=1", + 5, + {Dictionary{{{"ama", {Integer(1), {}}}}}}, + nullptr}, + {"0x6e in dictionary key", + "ana=1", + 5, + {Dictionary{{{"ana", {Integer(1), {}}}}}}, + nullptr}, + {"0x6f in dictionary key", + "aoa=1", + 5, + {Dictionary{{{"aoa", {Integer(1), {}}}}}}, + nullptr}, + {"0x70 in dictionary key", + "apa=1", + 5, + {Dictionary{{{"apa", {Integer(1), {}}}}}}, + nullptr}, + {"0x71 in dictionary key", + "aqa=1", + 5, + {Dictionary{{{"aqa", {Integer(1), {}}}}}}, + nullptr}, + {"0x72 in dictionary key", + "ara=1", + 5, + {Dictionary{{{"ara", {Integer(1), {}}}}}}, + nullptr}, + {"0x73 in dictionary key", + "asa=1", + 5, + {Dictionary{{{"asa", {Integer(1), {}}}}}}, + nullptr}, + {"0x74 in dictionary key", + "ata=1", + 5, + {Dictionary{{{"ata", {Integer(1), {}}}}}}, + nullptr}, + {"0x75 in dictionary key", + "aua=1", + 5, + {Dictionary{{{"aua", {Integer(1), {}}}}}}, + nullptr}, + {"0x76 in dictionary key", + "ava=1", + 5, + {Dictionary{{{"ava", {Integer(1), {}}}}}}, + nullptr}, + {"0x77 in dictionary key", + "awa=1", + 5, + {Dictionary{{{"awa", {Integer(1), {}}}}}}, + nullptr}, + {"0x78 in dictionary key", + "axa=1", + 5, + {Dictionary{{{"axa", {Integer(1), {}}}}}}, + nullptr}, + {"0x79 in dictionary key", + "aya=1", + 5, + {Dictionary{{{"aya", {Integer(1), {}}}}}}, + nullptr}, + {"0x7a in dictionary key", + "aza=1", + 5, + {Dictionary{{{"aza", {Integer(1), {}}}}}}, + nullptr}, + {"0x7b in dictionary key", "a{a=1", 5, absl::nullopt, nullptr}, + {"0x7c in dictionary key", "a|a=1", 5, absl::nullopt, nullptr}, + {"0x7d in dictionary key", "a}a=1", 5, absl::nullopt, nullptr}, + {"0x7e in dictionary key", "a~a=1", 5, absl::nullopt, nullptr}, + {"0x7f in dictionary key", "a\177a=1", 5, absl::nullopt, nullptr}, + {"0x00 starting an dictionary key", "\000a=1", 4, absl::nullopt, nullptr}, + {"0x01 starting an dictionary key", "\001a=1", 4, absl::nullopt, nullptr}, + {"0x02 starting an dictionary key", "\002a=1", 4, absl::nullopt, nullptr}, + {"0x03 starting an dictionary key", "\003a=1", 4, absl::nullopt, nullptr}, + {"0x04 starting an dictionary key", "\004a=1", 4, absl::nullopt, nullptr}, + {"0x05 starting an dictionary key", "\005a=1", 4, absl::nullopt, nullptr}, + {"0x06 starting an dictionary key", "\006a=1", 4, absl::nullopt, nullptr}, + {"0x07 starting an dictionary key", "\aa=1", 4, absl::nullopt, nullptr}, + {"0x08 starting an dictionary key", "\ba=1", 4, absl::nullopt, nullptr}, + {"0x09 starting an dictionary key", "\ta=1", 4, absl::nullopt, nullptr}, + {"0x0a starting an dictionary key", "\na=1", 4, absl::nullopt, nullptr}, + {"0x0b starting an dictionary key", "\va=1", 4, absl::nullopt, nullptr}, + {"0x0c starting an dictionary key", "\fa=1", 4, absl::nullopt, nullptr}, + {"0x0d starting an dictionary key", "\ra=1", 4, absl::nullopt, nullptr}, + {"0x0e starting an dictionary key", "\016a=1", 4, absl::nullopt, nullptr}, + {"0x0f starting an dictionary key", "\017a=1", 4, absl::nullopt, nullptr}, + {"0x10 starting an dictionary key", "\020a=1", 4, absl::nullopt, nullptr}, + {"0x11 starting an dictionary key", "\021a=1", 4, absl::nullopt, nullptr}, + {"0x12 starting an dictionary key", "\022a=1", 4, absl::nullopt, nullptr}, + {"0x13 starting an dictionary key", "\023a=1", 4, absl::nullopt, nullptr}, + {"0x14 starting an dictionary key", "\024a=1", 4, absl::nullopt, nullptr}, + {"0x15 starting an dictionary key", "\025a=1", 4, absl::nullopt, nullptr}, + {"0x16 starting an dictionary key", "\026a=1", 4, absl::nullopt, nullptr}, + {"0x17 starting an dictionary key", "\027a=1", 4, absl::nullopt, nullptr}, + {"0x18 starting an dictionary key", "\030a=1", 4, absl::nullopt, nullptr}, + {"0x19 starting an dictionary key", "\031a=1", 4, absl::nullopt, nullptr}, + {"0x1a starting an dictionary key", "\032a=1", 4, absl::nullopt, nullptr}, + {"0x1b starting an dictionary key", "\033a=1", 4, absl::nullopt, nullptr}, + {"0x1c starting an dictionary key", "\034a=1", 4, absl::nullopt, nullptr}, + {"0x1d starting an dictionary key", "\035a=1", 4, absl::nullopt, nullptr}, + {"0x1e starting an dictionary key", "\036a=1", 4, absl::nullopt, nullptr}, + {"0x1f starting an dictionary key", "\037a=1", 4, absl::nullopt, nullptr}, + {"0x20 starting an dictionary key", + " a=1", + 4, + {Dictionary{{{"a", {Integer(1), {}}}}}}, + "a=1"}, + {"0x21 starting an dictionary key", "!a=1", 4, absl::nullopt, nullptr}, + {"0x22 starting an dictionary key", "\"a=1", 4, absl::nullopt, nullptr}, + {"0x23 starting an dictionary key", "#a=1", 4, absl::nullopt, nullptr}, + {"0x24 starting an dictionary key", "$a=1", 4, absl::nullopt, nullptr}, + {"0x25 starting an dictionary key", "%a=1", 4, absl::nullopt, nullptr}, + {"0x26 starting an dictionary key", "&a=1", 4, absl::nullopt, nullptr}, + {"0x27 starting an dictionary key", "'a=1", 4, absl::nullopt, nullptr}, + {"0x28 starting an dictionary key", "(a=1", 4, absl::nullopt, nullptr}, + {"0x29 starting an dictionary key", ")a=1", 4, absl::nullopt, nullptr}, + {"0x2a starting an dictionary key", + "*a=1", + 4, + {Dictionary{{{"*a", {Integer(1), {}}}}}}, + nullptr}, + {"0x2b starting an dictionary key", "+a=1", 4, absl::nullopt, nullptr}, + {"0x2c starting an dictionary key", ",a=1", 4, absl::nullopt, nullptr}, + {"0x2d starting an dictionary key", "-a=1", 4, absl::nullopt, nullptr}, + {"0x2e starting an dictionary key", ".a=1", 4, absl::nullopt, nullptr}, + {"0x2f starting an dictionary key", "/a=1", 4, absl::nullopt, nullptr}, + {"0x30 starting an dictionary key", "0a=1", 4, absl::nullopt, nullptr}, + {"0x31 starting an dictionary key", "1a=1", 4, absl::nullopt, nullptr}, + {"0x32 starting an dictionary key", "2a=1", 4, absl::nullopt, nullptr}, + {"0x33 starting an dictionary key", "3a=1", 4, absl::nullopt, nullptr}, + {"0x34 starting an dictionary key", "4a=1", 4, absl::nullopt, nullptr}, + {"0x35 starting an dictionary key", "5a=1", 4, absl::nullopt, nullptr}, + {"0x36 starting an dictionary key", "6a=1", 4, absl::nullopt, nullptr}, + {"0x37 starting an dictionary key", "7a=1", 4, absl::nullopt, nullptr}, + {"0x38 starting an dictionary key", "8a=1", 4, absl::nullopt, nullptr}, + {"0x39 starting an dictionary key", "9a=1", 4, absl::nullopt, nullptr}, + {"0x3a starting an dictionary key", ":a=1", 4, absl::nullopt, nullptr}, + {"0x3b starting an dictionary key", ";a=1", 4, absl::nullopt, nullptr}, + {"0x3c starting an dictionary key", "a=1", 4, absl::nullopt, nullptr}, + {"0x3f starting an dictionary key", "?a=1", 4, absl::nullopt, nullptr}, + {"0x40 starting an dictionary key", "@a=1", 4, absl::nullopt, nullptr}, + {"0x41 starting an dictionary key", "Aa=1", 4, absl::nullopt, nullptr}, + {"0x42 starting an dictionary key", "Ba=1", 4, absl::nullopt, nullptr}, + {"0x43 starting an dictionary key", "Ca=1", 4, absl::nullopt, nullptr}, + {"0x44 starting an dictionary key", "Da=1", 4, absl::nullopt, nullptr}, + {"0x45 starting an dictionary key", "Ea=1", 4, absl::nullopt, nullptr}, + {"0x46 starting an dictionary key", "Fa=1", 4, absl::nullopt, nullptr}, + {"0x47 starting an dictionary key", "Ga=1", 4, absl::nullopt, nullptr}, + {"0x48 starting an dictionary key", "Ha=1", 4, absl::nullopt, nullptr}, + {"0x49 starting an dictionary key", "Ia=1", 4, absl::nullopt, nullptr}, + {"0x4a starting an dictionary key", "Ja=1", 4, absl::nullopt, nullptr}, + {"0x4b starting an dictionary key", "Ka=1", 4, absl::nullopt, nullptr}, + {"0x4c starting an dictionary key", "La=1", 4, absl::nullopt, nullptr}, + {"0x4d starting an dictionary key", "Ma=1", 4, absl::nullopt, nullptr}, + {"0x4e starting an dictionary key", "Na=1", 4, absl::nullopt, nullptr}, + {"0x4f starting an dictionary key", "Oa=1", 4, absl::nullopt, nullptr}, + {"0x50 starting an dictionary key", "Pa=1", 4, absl::nullopt, nullptr}, + {"0x51 starting an dictionary key", "Qa=1", 4, absl::nullopt, nullptr}, + {"0x52 starting an dictionary key", "Ra=1", 4, absl::nullopt, nullptr}, + {"0x53 starting an dictionary key", "Sa=1", 4, absl::nullopt, nullptr}, + {"0x54 starting an dictionary key", "Ta=1", 4, absl::nullopt, nullptr}, + {"0x55 starting an dictionary key", "Ua=1", 4, absl::nullopt, nullptr}, + {"0x56 starting an dictionary key", "Va=1", 4, absl::nullopt, nullptr}, + {"0x57 starting an dictionary key", "Wa=1", 4, absl::nullopt, nullptr}, + {"0x58 starting an dictionary key", "Xa=1", 4, absl::nullopt, nullptr}, + {"0x59 starting an dictionary key", "Ya=1", 4, absl::nullopt, nullptr}, + {"0x5a starting an dictionary key", "Za=1", 4, absl::nullopt, nullptr}, + {"0x5b starting an dictionary key", "[a=1", 4, absl::nullopt, nullptr}, + {"0x5c starting an dictionary key", "\\a=1", 4, absl::nullopt, nullptr}, + {"0x5d starting an dictionary key", "]a=1", 4, absl::nullopt, nullptr}, + {"0x5e starting an dictionary key", "^a=1", 4, absl::nullopt, nullptr}, + {"0x5f starting an dictionary key", "_a=1", 4, absl::nullopt, nullptr}, + {"0x60 starting an dictionary key", "`a=1", 4, absl::nullopt, nullptr}, + {"0x61 starting an dictionary key", + "aa=1", + 4, + {Dictionary{{{"aa", {Integer(1), {}}}}}}, + nullptr}, + {"0x62 starting an dictionary key", + "ba=1", + 4, + {Dictionary{{{"ba", {Integer(1), {}}}}}}, + nullptr}, + {"0x63 starting an dictionary key", + "ca=1", + 4, + {Dictionary{{{"ca", {Integer(1), {}}}}}}, + nullptr}, + {"0x64 starting an dictionary key", + "da=1", + 4, + {Dictionary{{{"da", {Integer(1), {}}}}}}, + nullptr}, + {"0x65 starting an dictionary key", + "ea=1", + 4, + {Dictionary{{{"ea", {Integer(1), {}}}}}}, + nullptr}, + {"0x66 starting an dictionary key", + "fa=1", + 4, + {Dictionary{{{"fa", {Integer(1), {}}}}}}, + nullptr}, + {"0x67 starting an dictionary key", + "ga=1", + 4, + {Dictionary{{{"ga", {Integer(1), {}}}}}}, + nullptr}, + {"0x68 starting an dictionary key", + "ha=1", + 4, + {Dictionary{{{"ha", {Integer(1), {}}}}}}, + nullptr}, + {"0x69 starting an dictionary key", + "ia=1", + 4, + {Dictionary{{{"ia", {Integer(1), {}}}}}}, + nullptr}, + {"0x6a starting an dictionary key", + "ja=1", + 4, + {Dictionary{{{"ja", {Integer(1), {}}}}}}, + nullptr}, + {"0x6b starting an dictionary key", + "ka=1", + 4, + {Dictionary{{{"ka", {Integer(1), {}}}}}}, + nullptr}, + {"0x6c starting an dictionary key", + "la=1", + 4, + {Dictionary{{{"la", {Integer(1), {}}}}}}, + nullptr}, + {"0x6d starting an dictionary key", + "ma=1", + 4, + {Dictionary{{{"ma", {Integer(1), {}}}}}}, + nullptr}, + {"0x6e starting an dictionary key", + "na=1", + 4, + {Dictionary{{{"na", {Integer(1), {}}}}}}, + nullptr}, + {"0x6f starting an dictionary key", + "oa=1", + 4, + {Dictionary{{{"oa", {Integer(1), {}}}}}}, + nullptr}, + {"0x70 starting an dictionary key", + "pa=1", + 4, + {Dictionary{{{"pa", {Integer(1), {}}}}}}, + nullptr}, + {"0x71 starting an dictionary key", + "qa=1", + 4, + {Dictionary{{{"qa", {Integer(1), {}}}}}}, + nullptr}, + {"0x72 starting an dictionary key", + "ra=1", + 4, + {Dictionary{{{"ra", {Integer(1), {}}}}}}, + nullptr}, + {"0x73 starting an dictionary key", + "sa=1", + 4, + {Dictionary{{{"sa", {Integer(1), {}}}}}}, + nullptr}, + {"0x74 starting an dictionary key", + "ta=1", + 4, + {Dictionary{{{"ta", {Integer(1), {}}}}}}, + nullptr}, + {"0x75 starting an dictionary key", + "ua=1", + 4, + {Dictionary{{{"ua", {Integer(1), {}}}}}}, + nullptr}, + {"0x76 starting an dictionary key", + "va=1", + 4, + {Dictionary{{{"va", {Integer(1), {}}}}}}, + nullptr}, + {"0x77 starting an dictionary key", + "wa=1", + 4, + {Dictionary{{{"wa", {Integer(1), {}}}}}}, + nullptr}, + {"0x78 starting an dictionary key", + "xa=1", + 4, + {Dictionary{{{"xa", {Integer(1), {}}}}}}, + nullptr}, + {"0x79 starting an dictionary key", + "ya=1", + 4, + {Dictionary{{{"ya", {Integer(1), {}}}}}}, + nullptr}, + {"0x7a starting an dictionary key", + "za=1", + 4, + {Dictionary{{{"za", {Integer(1), {}}}}}}, + nullptr}, + {"0x7b starting an dictionary key", "{a=1", 4, absl::nullopt, nullptr}, + {"0x7c starting an dictionary key", "|a=1", 4, absl::nullopt, nullptr}, + {"0x7d starting an dictionary key", "}a=1", 4, absl::nullopt, nullptr}, + {"0x7e starting an dictionary key", "~a=1", 4, absl::nullopt, nullptr}, + {"0x7f starting an dictionary key", "\177a=1", 4, absl::nullopt, nullptr}, + // param-dict.json + {"basic parameterised dict", + "abc=123;a=1;b=2, def=456, ghi=789;q=9;r=\"+w\"", + 44, + {Dictionary{{{"abc", {Integer(123), {Param("a", 1), Param("b", 2)}}}, + {"def", {Integer(456), {}}}, + {"ghi", {Integer(789), {Param("q", 9), Param("r", "+w")}}}}}}, + nullptr}, + {"single item parameterised dict", + "a=b; q=1.0", + 10, + {Dictionary{ + {{"a", {Item("b", Item::kTokenType), {DoubleParam("q", 1.000000)}}}}}}, + "a=b;q=1.0"}, + {"list item parameterised dictionary", + "a=(1 2); q=1.0", + 14, + {Dictionary{{{"a", + {{{Integer(1), {}}, {Integer(2), {}}}, + {DoubleParam("q", 1.000000)}}}}}}, + "a=(1 2);q=1.0"}, + {"missing parameter value parameterised dict", + "a=3;c;d=5", + 9, + {Dictionary{ + {{"a", {Integer(3), {BooleanParam("c", true), Param("d", 5)}}}}}}, + nullptr}, + {"terminal missing parameter value parameterised dict", + "a=3;c=5;d", + 9, + {Dictionary{ + {{"a", {Integer(3), {Param("c", 5), BooleanParam("d", true)}}}}}}, + nullptr}, + {"no whitespace parameterised dict", + "a=b;c=1,d=e;f=2", + 15, + {Dictionary{{{"a", {Item("b", Item::kTokenType), {Param("c", 1)}}}, + {"d", {Item("e", Item::kTokenType), {Param("f", 2)}}}}}}, + "a=b;c=1, d=e;f=2"}, + {"whitespace before = parameterised dict", "a=b;q =0.5", 10, absl::nullopt, + nullptr}, + {"whitespace after = parameterised dict", "a=b;q= 0.5", 10, absl::nullopt, + nullptr}, + {"whitespace before ; parameterised dict", "a=b ;q=0.5", 10, absl::nullopt, + nullptr}, + {"whitespace after ; parameterised dict", + "a=b; q=0.5", + 10, + {Dictionary{ + {{"a", {Item("b", Item::kTokenType), {DoubleParam("q", 0.500000)}}}}}}, + "a=b;q=0.5"}, + {"extra whitespace parameterised dict", + "a=b; c=1 , d=e; f=2; g=3", + 27, + {Dictionary{ + {{"a", {Item("b", Item::kTokenType), {Param("c", 1)}}}, + {"d", + {Item("e", Item::kTokenType), {Param("f", 2), Param("g", 3)}}}}}}, + "a=b;c=1, d=e;f=2;g=3"}, + {"two lines parameterised list", + "a=b;c=1, d=e;f=2", + 16, + {Dictionary{{{"a", {Item("b", Item::kTokenType), {Param("c", 1)}}}, + {"d", {Item("e", Item::kTokenType), {Param("f", 2)}}}}}}, + "a=b;c=1, d=e;f=2"}, + {"trailing comma parameterised list", "a=b; q=1.0,", 11, absl::nullopt, + nullptr}, + {"empty item parameterised list", "a=b; q=1.0,,c=d", 15, absl::nullopt, + nullptr}, +}; + +} // namespace + +TEST(StructuredHeaderGeneratedTest, ParseItem) { + for (const auto& c : parameterized_item_test_cases) { + if (c.raw) { + SCOPED_TRACE(c.name); + std::string raw{c.raw, c.raw_len}; + absl::optional result = ParseItem(raw); + EXPECT_EQ(result, c.expected); + } + } +} + +TEST(StructuredHeaderGeneratedTest, ParseList) { + for (const auto& c : list_test_cases) { + if (c.raw) { + SCOPED_TRACE(c.name); + std::string raw{c.raw, c.raw_len}; + absl::optional result = ParseList(raw); + EXPECT_EQ(result, c.expected); + } + } +} + +TEST(StructuredHeaderGeneratedTest, ParseDictionary) { + for (const auto& c : dictionary_test_cases) { + if (c.raw) { + SCOPED_TRACE(c.name); + std::string raw{c.raw, c.raw_len}; + absl::optional result = ParseDictionary(raw); + EXPECT_EQ(result, c.expected); + } + } +} + +TEST(StructuredHeaderGeneratedTest, SerializeItem) { + for (const auto& c : parameterized_item_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeItem(*c.expected); + if (c.raw || c.canonical) { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), + std::string(c.canonical ? c.canonical : c.raw)); + } else { + EXPECT_FALSE(result.has_value()); + } + } + } +} + +TEST(StructuredHeaderGeneratedTest, SerializeList) { + for (const auto& c : list_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeList(*c.expected); + if (c.raw || c.canonical) { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), + std::string(c.canonical ? c.canonical : c.raw)); + } else { + EXPECT_FALSE(result.has_value()); + } + } + } +} + +TEST(StructuredHeaderGeneratedTest, SerializeDictionary) { + for (const auto& c : dictionary_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeDictionary(*c.expected); + if (c.raw || c.canonical) { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), + std::string(c.canonical ? c.canonical : c.raw)); + } else { + EXPECT_FALSE(result.has_value()); + } + } + } +} + +} // namespace structured_headers +} // namespace quiche diff --git a/quiche/common/structured_headers_test.cc b/quiche/common/structured_headers_test.cc new file mode 100644 index 000000000000..73a7ab31b592 --- /dev/null +++ b/quiche/common/structured_headers_test.cc @@ -0,0 +1,762 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/structured_headers.h" + +#include + +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace structured_headers { +namespace { + +// Helpers to make test cases clearer + +Item Token(std::string value) { return Item(value, Item::kTokenType); } + +Item Integer(int64_t value) { return Item(value); } + +// Parameter with null value, only used in Structured Headers Draft 09 +std::pair NullParam(std::string key) { + return std::make_pair(key, Item()); +} + +std::pair BooleanParam(std::string key, bool value) { + return std::make_pair(key, Item(value)); +} + +std::pair DoubleParam(std::string key, double value) { + return std::make_pair(key, Item(value)); +} + +std::pair Param(std::string key, int64_t value) { + return std::make_pair(key, Item(value)); +} + +std::pair Param(std::string key, std::string value) { + return std::make_pair(key, Item(value)); +} + +std::pair ByteSequenceParam(std::string key, + std::string value) { + return std::make_pair(key, Item(value, Item::kByteSequenceType)); +} + +std::pair TokenParam(std::string key, std::string value) { + return std::make_pair(key, Token(value)); +} + +// Test cases taken from https://github.com/httpwg/structured-header-tests can +// be found in structured_headers_generated_unittest.cc + +const struct ItemTestCase { + const char* name; + const char* raw; + const absl::optional expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} item_test_cases[] = { + // Token + {"bad token - item", "abc$@%!", absl::nullopt, nullptr}, + {"leading whitespace", " foo", Token("foo"), "foo"}, + {"trailing whitespace", "foo ", Token("foo"), "foo"}, + {"leading asterisk", "*foo", Token("*foo"), nullptr}, + // Number + {"long integer", "999999999999999", Integer(999999999999999L), nullptr}, + {"long negative integer", "-999999999999999", Integer(-999999999999999L), + nullptr}, + {"too long integer", "1000000000000000", absl::nullopt, nullptr}, + {"negative too long integer", "-1000000000000000", absl::nullopt, nullptr}, + {"integral decimal", "1.0", Item(1.0), nullptr}, + // String + {"basic string", "\"foo\"", Item("foo"), nullptr}, + {"non-ascii string", "\"f\xC3\xBC\xC3\xBC\"", absl::nullopt, nullptr}, + // Additional tests + {"valid quoting containing \\n", "\"\\\\n\"", Item("\\n"), nullptr}, + {"valid quoting containing \\t", "\"\\\\t\"", Item("\\t"), nullptr}, + {"valid quoting containing \\x", "\"\\\\x61\"", Item("\\x61"), nullptr}, + {"c-style hex escape in string", "\"\\x61\"", absl::nullopt, nullptr}, + {"valid quoting containing \\u", "\"\\\\u0061\"", Item("\\u0061"), nullptr}, + {"c-style unicode escape in string", "\"\\u0061\"", absl::nullopt, nullptr}, +}; + +const ItemTestCase sh09_item_test_cases[] = { + // Integer + {"large integer", "9223372036854775807", Integer(9223372036854775807L), + nullptr}, + {"large negative integer", "-9223372036854775807", + Integer(-9223372036854775807L), nullptr}, + {"too large integer", "9223372036854775808", absl::nullopt, nullptr}, + {"too large negative integer", "-9223372036854775808", absl::nullopt, + nullptr}, + // Byte Sequence + {"basic binary", "*aGVsbG8=*", Item("hello", Item::kByteSequenceType), + nullptr}, + {"empty binary", "**", Item("", Item::kByteSequenceType), nullptr}, + {"bad paddding", "*aGVsbG8*", Item("hello", Item::kByteSequenceType), + "*aGVsbG8=*"}, + {"bad end delimiter", "*aGVsbG8=", absl::nullopt, nullptr}, + {"extra whitespace", "*aGVsb G8=*", absl::nullopt, nullptr}, + {"extra chars", "*aGVsbG!8=*", absl::nullopt, nullptr}, + {"suffix chars", "*aGVsbG8=!*", absl::nullopt, nullptr}, + {"non-zero pad bits", "*iZ==*", Item("\x89", Item::kByteSequenceType), + "*iQ==*"}, + {"non-ASCII binary", "*/+Ah*", Item("\xFF\xE0!", Item::kByteSequenceType), + nullptr}, + {"base64url binary", "*_-Ah*", absl::nullopt, nullptr}, + {"token with leading asterisk", "*foo", absl::nullopt, nullptr}, +}; + +// For Structured Headers Draft 15 +const struct ParameterizedItemTestCase { + const char* name; + const char* raw; + const absl::optional + expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} parameterized_item_test_cases[] = { + {"single parameter item", + "text/html;q=1.0", + {{Token("text/html"), {DoubleParam("q", 1)}}}, + nullptr}, + {"missing parameter value item", + "text/html;a;q=1.0", + {{Token("text/html"), {BooleanParam("a", true), DoubleParam("q", 1)}}}, + nullptr}, + {"missing terminal parameter value item", + "text/html;q=1.0;a", + {{Token("text/html"), {DoubleParam("q", 1), BooleanParam("a", true)}}}, + nullptr}, + {"duplicate parameter keys with different value", + "text/html;a=1;b=2;a=3.0", + {{Token("text/html"), {DoubleParam("a", 3), Param("b", 2L)}}}, + "text/html;a=3.0;b=2"}, + {"multiple duplicate parameter keys at different position", + "text/html;c=1;a=2;b;b=3.0;a", + {{Token("text/html"), + {Param("c", 1L), BooleanParam("a", true), DoubleParam("b", 3)}}}, + "text/html;c=1;a;b=3.0"}, + {"duplicate parameter keys with missing value", + "text/html;a;a=1", + {{Token("text/html"), {Param("a", 1L)}}}, + "text/html;a=1"}, + {"whitespace before = parameterised item", "text/html, text/plain;q =0.5", + absl::nullopt, nullptr}, + {"whitespace after = parameterised item", "text/html, text/plain;q= 0.5", + absl::nullopt, nullptr}, + {"whitespace before ; parameterised item", "text/html, text/plain ;q=0.5", + absl::nullopt, nullptr}, + {"whitespace after ; parameterised item", + "text/plain; q=0.5", + {{Token("text/plain"), {DoubleParam("q", 0.5)}}}, + "text/plain;q=0.5"}, + {"extra whitespace parameterised item", + "text/plain; q=0.5; charset=utf-8", + {{Token("text/plain"), + {DoubleParam("q", 0.5), TokenParam("charset", "utf-8")}}}, + "text/plain;q=0.5;charset=utf-8"}, +}; + +// For Structured Headers Draft 15 +const struct ListTestCase { + const char* name; + const char* raw; + const absl::optional expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} list_test_cases[] = { + // Lists of lists + {"extra whitespace list of lists", + "(1 42)", + {{{{{Integer(1L), {}}, {Integer(42L), {}}}, {}}}}, + "(1 42)"}, + // Parameterized Lists + {"basic parameterised list", + "abc_123;a=1;b=2; cdef_456, ghi;q=\"9\";r=\"+w\"", + {{{Token("abc_123"), + {Param("a", 1), Param("b", 2), BooleanParam("cdef_456", true)}}, + {Token("ghi"), {Param("q", "9"), Param("r", "+w")}}}}, + "abc_123;a=1;b=2;cdef_456, ghi;q=\"9\";r=\"+w\""}, + // Parameterized inner lists + {"parameterised basic list of lists", + "(1;a=1.0 2), (42 43)", + {{{{{Integer(1L), {DoubleParam("a", 1.0)}}, {Integer(2L), {}}}, {}}, + {{{Integer(42L), {}}, {Integer(43L), {}}}, {}}}}, + nullptr}, + {"parameters on inner members", + "(1;a=1.0 2;b=c), (42;d=?0 43;e=:Zmdo:)", + {{{{{Integer(1L), {DoubleParam("a", 1.0)}}, + {Integer(2L), {TokenParam("b", "c")}}}, + {}}, + {{{Integer(42L), {BooleanParam("d", false)}}, + {Integer(43L), {ByteSequenceParam("e", "fgh")}}}, + {}}}}, + nullptr}, + {"parameters on inner lists", + "(1 2);a=1.0, (42 43);b=?0", + {{{{{Integer(1L), {}}, {Integer(2L), {}}}, {DoubleParam("a", 1.0)}}, + {{{Integer(42L), {}}, {Integer(43L), {}}}, {BooleanParam("b", false)}}}}, + nullptr}, + {"default true values for parameters on inner list members", + "(1;a 2), (42 43;b)", + {{{{{Integer(1L), {BooleanParam("a", true)}}, {Integer(2L), {}}}, {}}, + {{{Integer(42L), {}}, {Integer(43L), {BooleanParam("b", true)}}}, {}}}}, + nullptr}, + {"default true values for parameters on inner lists", + "(1 2);a, (42 43);b", + {{{{{Integer(1L), {}}, {Integer(2L), {}}}, {BooleanParam("a", true)}}, + {{{Integer(42L), {}}, {Integer(43L), {}}}, {BooleanParam("b", true)}}}}, + nullptr}, + {"extra whitespace before semicolon in parameters on inner list member", + "(a;b ;c b)", absl::nullopt, nullptr}, + {"extra whitespace between parameters on inner list member", + "(a;b; c b)", + {{{{{Token("a"), {BooleanParam("b", true), BooleanParam("c", true)}}, + {Token("b"), {}}}, + {}}}}, + "(a;b;c b)"}, + {"extra whitespace before semicolon in parameters on inner list", + "(a b);c ;d, (e)", absl::nullopt, nullptr}, + {"extra whitespace between parameters on inner list", + "(a b);c; d, (e)", + {{{{{Token("a"), {}}, {Token("b"), {}}}, + {BooleanParam("c", true), BooleanParam("d", true)}}, + {{{Token("e"), {}}}, {}}}}, + "(a b);c;d, (e)"}, +}; + +// For Structured Headers Draft 15 +const struct DictionaryTestCase { + const char* name; + const char* raw; + const absl::optional + expected; // nullopt if parse error is expected. + const char* canonical; // nullptr if parse error is expected, or if canonical + // format is identical to raw. +} dictionary_test_cases[] = { + {"basic dictionary", + "en=\"Applepie\", da=:aGVsbG8=:", + {Dictionary{{{"en", {Item("Applepie"), {}}}, + {"da", {Item("hello", Item::kByteSequenceType), {}}}}}}, + nullptr}, + {"tab separated dictionary", + "a=1\t,\tb=2", + {Dictionary{{{"a", {Integer(1L), {}}}, {"b", {Integer(2L), {}}}}}}, + "a=1, b=2"}, + {"missing value with params dictionary", + "a=1, b;foo=9, c=3", + {Dictionary{{{"a", {Integer(1L), {}}}, + {"b", {Item(true), {Param("foo", 9)}}}, + {"c", {Integer(3L), {}}}}}}, + nullptr}, + // Parameterised dictionary tests + {"parameterised inner list member dict", + "a=(\"1\";b=1;c=?0 \"2\");d=\"e\"", + {Dictionary{{{"a", + {{{Item("1"), {Param("b", 1), BooleanParam("c", false)}}, + {Item("2"), {}}}, + {Param("d", "e")}}}}}}, + nullptr}, + {"explicit true value with parameter", + "a=?1;b=1", + {Dictionary{{{"a", {Item(true), {Param("b", 1)}}}}}}, + "a;b=1"}, + {"implicit true value with parameter", + "a;b=1", + {Dictionary{{{"a", {Item(true), {Param("b", 1)}}}}}}, + nullptr}, + {"implicit true value with implicitly-valued parameter", + "a;b", + {Dictionary{{{"a", {Item(true), {BooleanParam("b", true)}}}}}}, + nullptr}, +}; +} // namespace + +TEST(StructuredHeaderTest, ParseBareItem) { + for (const auto& c : item_test_cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseBareItem(c.raw); + EXPECT_EQ(result, c.expected); + } +} + +// For Structured Headers Draft 15, these tests include parameters on Items. +TEST(StructuredHeaderTest, ParseItem) { + for (const auto& c : parameterized_item_test_cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseItem(c.raw); + EXPECT_EQ(result, c.expected); + } +} + +// Structured Headers Draft 9 parsing rules are different than Draft 15, and +// some strings which are considered invalid in SH15 should parse in SH09. +// The SH09 Item parser is not directly exposed, but can be used indirectly by +// calling the parser for SH09-specific lists. +TEST(StructuredHeaderTest, ParseSH09Item) { + for (const auto& c : sh09_item_test_cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseListOfLists(c.raw); + if (c.expected.has_value()) { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->size(), 1UL); + EXPECT_EQ((*result)[0].size(), 1UL); + EXPECT_EQ((*result)[0][0], c.expected); + } else { + EXPECT_FALSE(result.has_value()); + } + } +} + +// In Structured Headers Draft 9, floats can have more than three fractional +// digits, and can be larger than 1e12. This behaviour is exposed in the parser +// for SH09-specific lists, so test it through that interface. +TEST(StructuredHeaderTest, SH09HighPrecisionFloats) { + // These values are exactly representable in binary floating point, so no + // accuracy issues are expected in this test. + absl::optional result = + ParseListOfLists("1.03125;-1.03125;12345678901234.5;-12345678901234.5"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, + (ListOfLists{{Item(1.03125), Item(-1.03125), Item(12345678901234.5), + Item(-12345678901234.5)}})); + + result = ParseListOfLists("123456789012345.0"); + EXPECT_FALSE(result.has_value()); + + result = ParseListOfLists("-123456789012345.0"); + EXPECT_FALSE(result.has_value()); +} + +// For Structured Headers Draft 9 +TEST(StructuredHeaderTest, ParseListOfLists) { + static const struct TestCase { + const char* name; + const char* raw; + ListOfLists expected; // empty if parse error is expected + } cases[] = { + {"basic list of lists", + "1;2, 42;43", + {{Integer(1L), Integer(2L)}, {Integer(42L), Integer(43L)}}}, + {"empty list of lists", "", {}}, + {"single item list of lists", "42", {{Integer(42L)}}}, + {"no whitespace list of lists", "1,42", {{Integer(1L)}, {Integer(42L)}}}, + {"no inner whitespace list of lists", + "1;2, 42;43", + {{Integer(1L), Integer(2L)}, {Integer(42L), Integer(43L)}}}, + {"extra whitespace list of lists", + "1 , 42", + {{Integer(1L)}, {Integer(42L)}}}, + {"extra inner whitespace list of lists", + "1 ; 2,42 ; 43", + {{Integer(1L), Integer(2L)}, {Integer(42L), Integer(43L)}}}, + {"trailing comma list of lists", "1;2, 42,", {}}, + {"trailing semicolon list of lists", "1;2, 42;43;", {}}, + {"leading comma list of lists", ",1;2, 42", {}}, + {"leading semicolon list of lists", ";1;2, 42;43", {}}, + {"empty item list of lists", "1,,42", {}}, + {"empty inner item list of lists", "1;;2,42", {}}, + }; + for (const auto& c : cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseListOfLists(c.raw); + if (!c.expected.empty()) { + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, c.expected); + } else { + EXPECT_FALSE(result.has_value()); + } + } +} + +// For Structured Headers Draft 9 +TEST(StructuredHeaderTest, ParseParameterisedList) { + static const struct TestCase { + const char* name; + const char* raw; + ParameterisedList expected; // empty if parse error is expected + } cases[] = { + {"basic param-list", + "abc_123;a=1;b=2; cdef_456, ghi;q=\"9\";r=\"w\"", + { + {Token("abc_123"), + {Param("a", 1), Param("b", 2), NullParam("cdef_456")}}, + {Token("ghi"), {Param("q", "9"), Param("r", "w")}}, + }}, + {"empty param-list", "", {}}, + {"single item param-list", + "text/html;q=1", + {{Token("text/html"), {Param("q", 1)}}}}, + {"empty param-list", "", {}}, + {"no whitespace param-list", + "text/html,text/plain;q=1", + {{Token("text/html"), {}}, {Token("text/plain"), {Param("q", 1)}}}}, + {"whitespace before = param-list", "text/html, text/plain;q =1", {}}, + {"whitespace after = param-list", "text/html, text/plain;q= 1", {}}, + {"extra whitespace param-list", + "text/html , text/plain ; q=1", + {{Token("text/html"), {}}, {Token("text/plain"), {Param("q", 1)}}}}, + {"duplicate key", "abc;a=1;b=2;a=1", {}}, + {"numeric key", "abc;a=1;1b=2;c=1", {}}, + {"uppercase key", "abc;a=1;B=2;c=1", {}}, + {"bad key", "abc;a=1;b!=2;c=1", {}}, + {"another bad key", "abc;a=1;b==2;c=1", {}}, + {"empty key name", "abc;a=1;=2;c=1", {}}, + {"empty parameter", "abc;a=1;;c=1", {}}, + {"empty list item", "abc;a=1,,def;b=1", {}}, + {"extra semicolon", "abc;a=1;b=1;", {}}, + {"extra comma", "abc;a=1,def;b=1,", {}}, + {"leading semicolon", ";abc;a=1", {}}, + {"leading comma", ",abc;a=1", {}}, + }; + for (const auto& c : cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseParameterisedList(c.raw); + if (c.expected.empty()) { + EXPECT_FALSE(result.has_value()); + continue; + } + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->size(), c.expected.size()); + if (result->size() == c.expected.size()) { + for (size_t i = 0; i < c.expected.size(); ++i) { + EXPECT_EQ((*result)[i], c.expected[i]); + } + } + } +} + +// For Structured Headers Draft 15 +TEST(StructuredHeaderTest, ParseList) { + for (const auto& c : list_test_cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseList(c.raw); + EXPECT_EQ(result, c.expected); + } +} + +// For Structured Headers Draft 15 +TEST(StructuredHeaderTest, ParseDictionary) { + for (const auto& c : dictionary_test_cases) { + SCOPED_TRACE(c.name); + absl::optional result = ParseDictionary(c.raw); + EXPECT_EQ(result, c.expected); + } +} + +// Serializer tests are all exclusively for Structured Headers Draft 15 + +TEST(StructuredHeaderTest, SerializeItem) { + for (const auto& c : item_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeItem(*c.expected); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), std::string(c.canonical ? c.canonical : c.raw)); + } + } +} + +TEST(StructuredHeaderTest, SerializeParameterizedItem) { + for (const auto& c : parameterized_item_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeItem(*c.expected); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), std::string(c.canonical ? c.canonical : c.raw)); + } + } +} + +TEST(StructuredHeaderTest, UnserializableItems) { + // Test that items with unknown type are not serialized. + EXPECT_FALSE(SerializeItem(Item()).has_value()); +} + +TEST(StructuredHeaderTest, UnserializableTokens) { + static const struct UnserializableString { + const char* name; + const char* value; + } bad_tokens[] = { + {"empty token", ""}, + {"contains high ascii", "a\xff"}, + {"contains nonprintable character", "a\x7f"}, + {"contains C0", "a\x01"}, + {"UTF-8 encoded", "a\xc3\xa9"}, + {"contains TAB", "a\t"}, + {"contains LF", "a\n"}, + {"contains CR", "a\r"}, + {"contains SP", "a "}, + {"begins with digit", "9token"}, + {"begins with hyphen", "-token"}, + {"begins with LF", "\ntoken"}, + {"begins with SP", " token"}, + {"begins with colon", ":token"}, + {"begins with percent", "%token"}, + {"begins with period", ".token"}, + {"begins with slash", "/token"}, + }; + for (const auto& bad_token : bad_tokens) { + SCOPED_TRACE(bad_token.name); + absl::optional serialization = + SerializeItem(Token(bad_token.value)); + EXPECT_FALSE(serialization.has_value()) << *serialization; + } +} + +TEST(StructuredHeaderTest, UnserializableKeys) { + static const struct UnserializableString { + const char* name; + const char* value; + } bad_keys[] = { + {"empty key", ""}, + {"contains high ascii", "a\xff"}, + {"contains nonprintable character", "a\x7f"}, + {"contains C0", "a\x01"}, + {"UTF-8 encoded", "a\xc3\xa9"}, + {"contains TAB", "a\t"}, + {"contains LF", "a\n"}, + {"contains CR", "a\r"}, + {"contains SP", "a "}, + {"begins with uppercase", "Atoken"}, + {"begins with digit", "9token"}, + {"begins with hyphen", "-token"}, + {"begins with LF", "\ntoken"}, + {"begins with SP", " token"}, + {"begins with colon", ":token"}, + {"begins with percent", "%token"}, + {"begins with period", ".token"}, + {"begins with slash", "/token"}, + }; + for (const auto& bad_key : bad_keys) { + SCOPED_TRACE(bad_key.name); + absl::optional serialization = + SerializeItem(ParameterizedItem("a", {{bad_key.value, "a"}})); + EXPECT_FALSE(serialization.has_value()) << *serialization; + } +} + +TEST(StructuredHeaderTest, UnserializableStrings) { + static const struct UnserializableString { + const char* name; + const char* value; + } bad_strings[] = { + {"contains high ascii", "a\xff"}, + {"contains nonprintable character", "a\x7f"}, + {"UTF-8 encoded", "a\xc3\xa9"}, + {"contains TAB", "a\t"}, + {"contains LF", "a\n"}, + {"contains CR", "a\r"}, + {"contains C0", "a\x01"}, + }; + for (const auto& bad_string : bad_strings) { + SCOPED_TRACE(bad_string.name); + absl::optional serialization = + SerializeItem(Item(bad_string.value)); + EXPECT_FALSE(serialization.has_value()) << *serialization; + } +} + +TEST(StructuredHeaderTest, UnserializableIntegers) { + EXPECT_FALSE(SerializeItem(Integer(1e15L)).has_value()); + EXPECT_FALSE(SerializeItem(Integer(-1e15L)).has_value()); +} + +TEST(StructuredHeaderTest, UnserializableDecimals) { + for (double value : + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), 1e12, 1e12 - 0.0001, + 1e12 - 0.0005, -1e12, -1e12 + 0.0001, -1e12 + 0.0005}) { + auto x = SerializeItem(Item(value)); + EXPECT_FALSE(SerializeItem(Item(value)).has_value()); + } +} + +// These values cannot be directly parsed from headers, but are valid doubles +// which can be serialized as sh-floats (though rounding is expected.) +TEST(StructuredHeaderTest, SerializeUnparseableDecimals) { + struct UnparseableDecimal { + const char* name; + double value; + const char* canonical; + } float_test_cases[] = { + {"negative 0", -0.0, "0.0"}, + {"0.0001", 0.0001, "0.0"}, + {"0.0000001", 0.0000001, "0.0"}, + {"1.0001", 1.0001, "1.0"}, + {"1.0009", 1.0009, "1.001"}, + {"round positive odd decimal", 0.0015, "0.002"}, + {"round positive even decimal", 0.0025, "0.002"}, + {"round negative odd decimal", -0.0015, "-0.002"}, + {"round negative even decimal", -0.0025, "-0.002"}, + {"round decimal up to integer part", 9.9995, "10.0"}, + {"subnormal numbers", std::numeric_limits::denorm_min(), "0.0"}, + {"round up to 10 digits", 1e9 - 0.0000001, "1000000000.0"}, + {"round up to 11 digits", 1e10 - 0.000001, "10000000000.0"}, + {"round up to 12 digits", 1e11 - 0.00001, "100000000000.0"}, + {"largest serializable float", nextafter(1e12 - 0.0005, 0), + "999999999999.999"}, + {"largest serializable negative float", -nextafter(1e12 - 0.0005, 0), + "-999999999999.999"}, + // This will fail if we simply truncate the fractional portion. + {"float rounds up to next int", 3.9999999, "4.0"}, + // This will fail if we first round to >3 digits, and then round again to + // 3 digits. + {"don't double round", 3.99949, "3.999"}, + // This will fail if we first round to 3 digits, and then round again to + // max_avail_digits. + {"don't double round", 123456789.99949, "123456789.999"}, + }; + for (const auto& test_case : float_test_cases) { + SCOPED_TRACE(test_case.name); + absl::optional serialization = + SerializeItem(Item(test_case.value)); + EXPECT_TRUE(serialization.has_value()); + EXPECT_EQ(*serialization, test_case.canonical); + } +} + +TEST(StructuredHeaderTest, SerializeList) { + for (const auto& c : list_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeList(*c.expected); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), std::string(c.canonical ? c.canonical : c.raw)); + } + } +} + +TEST(StructuredHeaderTest, UnserializableLists) { + static const struct UnserializableList { + const char* name; + const List value; + } bad_lists[] = { + {"Null item as member", {{Item(), {}}}}, + {"Unserializable item as member", {{Token("\n"), {}}}}, + {"Key is empty", {{Token("abc"), {Param("", 1)}}}}, + {"Key containswhitespace", {{Token("abc"), {Param("a\n", 1)}}}}, + {"Key contains UTF8", {{Token("abc"), {Param("a\xc3\xa9", 1)}}}}, + {"Key contains unprintable characters", + {{Token("abc"), {Param("a\x7f", 1)}}}}, + {"Key contains disallowed characters", + {{Token("abc"), {Param("a:", 1)}}}}, + {"Param value is unserializable", {{Token("abc"), {{"a", Token("\n")}}}}}, + {"Inner list contains unserializable item", + {{std::vector{{Token("\n"), {}}}, {}}}}, + }; + for (const auto& bad_list : bad_lists) { + SCOPED_TRACE(bad_list.name); + absl::optional serialization = SerializeList(bad_list.value); + EXPECT_FALSE(serialization.has_value()) << *serialization; + } +} + +TEST(StructuredHeaderTest, SerializeDictionary) { + for (const auto& c : dictionary_test_cases) { + SCOPED_TRACE(c.name); + if (c.expected) { + absl::optional result = SerializeDictionary(*c.expected); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), std::string(c.canonical ? c.canonical : c.raw)); + } + } +} + +TEST(StructuredHeaderTest, DictionaryConstructors) { + const std::string key0 = "key0"; + const std::string key1 = "key1"; + const ParameterizedMember member0{Item("Applepie"), {}}; + const ParameterizedMember member1{Item("hello", Item::kByteSequenceType), {}}; + + Dictionary dict; + EXPECT_TRUE(dict.empty()); + EXPECT_EQ(0U, dict.size()); + dict[key0] = member0; + EXPECT_FALSE(dict.empty()); + EXPECT_EQ(1U, dict.size()); + + const Dictionary dict_copy = dict; + EXPECT_FALSE(dict_copy.empty()); + EXPECT_EQ(1U, dict_copy.size()); + EXPECT_EQ(dict, dict_copy); + + const Dictionary dict_init{{{key0, member0}, {key1, member1}}}; + EXPECT_FALSE(dict_init.empty()); + EXPECT_EQ(2U, dict_init.size()); + EXPECT_EQ(member0, dict_init.at(key0)); + EXPECT_EQ(member1, dict_init.at(key1)); +} + +TEST(StructuredHeaderTest, DictionaryAccessors) { + const std::string key0 = "key0"; + const std::string key1 = "key1"; + + const ParameterizedMember nonempty_member0{Item("Applepie"), {}}; + const ParameterizedMember nonempty_member1{ + Item("hello", Item::kByteSequenceType), {}}; + const ParameterizedMember empty_member; + + Dictionary dict{{{key0, nonempty_member0}}}; + EXPECT_TRUE(dict.contains(key0)); + EXPECT_EQ(nonempty_member0, dict[key0]); + EXPECT_EQ(&dict[key0], &dict.at(key0)); + EXPECT_EQ(&dict[key0], &dict[0]); + EXPECT_EQ(&dict[key0], &dict.at(0)); + + // Even if the key does not yet exist in |dict|, operator[]() should + // automatically create an empty entry. + ASSERT_FALSE(dict.contains(key1)); + ParameterizedMember& member1 = dict[key1]; + EXPECT_TRUE(dict.contains(key1)); + EXPECT_EQ(empty_member, member1); + EXPECT_EQ(&member1, &dict[key1]); + EXPECT_EQ(&member1, &dict.at(key1)); + EXPECT_EQ(&member1, &dict[1]); + EXPECT_EQ(&member1, &dict.at(1)); + + member1 = nonempty_member1; + EXPECT_EQ(nonempty_member1, dict[key1]); + EXPECT_EQ(&dict[key1], &dict.at(key1)); + EXPECT_EQ(&dict[key1], &dict[1]); + EXPECT_EQ(&dict[key1], &dict.at(1)); + + // at(StringPiece) and indexed accessors have const overloads. + const Dictionary& dict_ref = dict; + EXPECT_EQ(&member1, &dict_ref.at(key1)); + EXPECT_EQ(&member1, &dict_ref[1]); + EXPECT_EQ(&member1, &dict_ref.at(1)); +} + +TEST(StructuredHeaderTest, UnserializableDictionary) { + static const struct UnserializableDictionary { + const char* name; + const Dictionary value; + } bad_dictionaries[] = { + {"Unserializable dict key", Dictionary{{{"ABC", {Token("abc"), {}}}}}}, + {"Dictionary item is unserializable", + Dictionary{{{"abc", {Token("abc="), {}}}}}}, + {"Param value is unserializable", + Dictionary{{{"abc", {Token("abc"), {{"a", Token("\n")}}}}}}}, + {"Dictionary inner-list contains unserializable item", + Dictionary{ + {{"abc", + {std::vector{{Token("abc="), {}}}, {}}}}}}, + }; + for (const auto& bad_dictionary : bad_dictionaries) { + SCOPED_TRACE(bad_dictionary.name); + absl::optional serialization = + SerializeDictionary(bad_dictionary.value); + EXPECT_FALSE(serialization.has_value()) << *serialization; + } +} + +} // namespace structured_headers +} // namespace quiche diff --git a/quiche/common/test_tools/quiche_test_utils.cc b/quiche/common/test_tools/quiche_test_utils.cc new file mode 100644 index 000000000000..64707866d0a6 --- /dev/null +++ b/quiche/common/test_tools/quiche_test_utils.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/test_tools/quiche_test_utils.h" + +#include + +#include "url/gurl.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace { + +std::string HexDumpWithMarks(const char* data, int length, const bool* marks, + int mark_length) { + static const char kHexChars[] = "0123456789abcdef"; + static const int kColumns = 4; + + const int kSizeLimit = 1024; + if (length > kSizeLimit || mark_length > kSizeLimit) { + QUICHE_LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes."; + length = std::min(length, kSizeLimit); + mark_length = std::min(mark_length, kSizeLimit); + } + + std::string hex; + for (const char* row = data; length > 0; + row += kColumns, length -= kColumns) { + for (const char* p = row; p < row + 4; ++p) { + if (p < row + length) { + const bool mark = + (marks && (p - data) < mark_length && marks[p - data]); + hex += mark ? '*' : ' '; + hex += kHexChars[(*p & 0xf0) >> 4]; + hex += kHexChars[*p & 0x0f]; + hex += mark ? '*' : ' '; + } else { + hex += " "; + } + } + hex = hex + " "; + + for (const char* p = row; p < row + 4 && p < row + length; ++p) { + hex += (*p >= 0x20 && *p < 0x7f) ? (*p) : '.'; + } + + hex = hex + '\n'; + } + return hex; +} + +} // namespace + +namespace quiche { +namespace test { + +void CompareCharArraysWithHexError(const std::string& description, + const char* actual, const int actual_len, + const char* expected, + const int expected_len) { + EXPECT_EQ(actual_len, expected_len); + const int min_len = std::min(actual_len, expected_len); + const int max_len = std::max(actual_len, expected_len); + std::unique_ptr marks(new bool[max_len]); + bool identical = (actual_len == expected_len); + for (int i = 0; i < min_len; ++i) { + if (actual[i] != expected[i]) { + marks[i] = true; + identical = false; + } else { + marks[i] = false; + } + } + for (int i = min_len; i < max_len; ++i) { + marks[i] = true; + } + if (identical) return; + ADD_FAILURE() << "Description:\n" + << description << "\n\nExpected:\n" + << HexDumpWithMarks(expected, expected_len, marks.get(), + max_len) + << "\nActual:\n" + << HexDumpWithMarks(actual, actual_len, marks.get(), max_len); +} + +iovec MakeIOVector(absl::string_view str) { + return iovec{const_cast(str.data()), static_cast(str.size())}; +} + +bool GoogleUrlSupportsIdnaForTest() { + const std::string kTestInput = "https://\xe5\x85\x89.example.org/"; + const std::string kExpectedOutput = "https://xn--54q.example.org/"; + + GURL url(kTestInput); + bool valid = url.is_valid() && url.spec() == kExpectedOutput; + QUICHE_CHECK(valid || !url.is_valid()) << url.spec(); + return valid; +} + +} // namespace test +} // namespace quiche diff --git a/quiche/common/test_tools/quiche_test_utils.h b/quiche/common/test_tools/quiche_test_utils.h new file mode 100644 index 000000000000..811611aae3b5 --- /dev/null +++ b/quiche/common/test_tools/quiche_test_utils.h @@ -0,0 +1,85 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_TEST_TOOLS_QUICHE_TEST_UTILS_H_ +#define QUICHE_COMMON_TEST_TOOLS_QUICHE_TEST_UTILS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_iovec.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { + +void CompareCharArraysWithHexError(const std::string& description, + const char* actual, const int actual_len, + const char* expected, + const int expected_len); + +// Create iovec that points to that data that `str` points to. +iovec MakeIOVector(absl::string_view str); + +// Due to binary size considerations, googleurl library can be built with or +// without IDNA support, meaning that we have to adjust our tests accordingly. +// This function checks if IDNAs are supported. +bool GoogleUrlSupportsIdnaForTest(); + +// Takes either a Status or StatusOr, and returns just the Status. +inline const absl::Status& ExtractStatus(const absl::Status& status) { + return status; +} +template +const absl::Status& ExtractStatus(const absl::StatusOr& status_or) { + return status_or.status(); +} + +// Abseil does not provide absl::Status-related macros, so we have to provide +// those instead. +MATCHER(IsOk, "Checks if an instance of absl::Status is ok.") { + if (arg.ok()) { + return true; + } + *result_listener << "Expected status OK, got " << ExtractStatus(arg); + return false; +} + +MATCHER_P(IsOkAndHolds, matcher, + "Matcher against the inner value of absl::StatusOr") { + if (!arg.ok()) { + *result_listener << "Expected status OK, got " << arg.status(); + return false; + } + return ::testing::ExplainMatchResult(matcher, arg.value(), result_listener); +} + +MATCHER_P(StatusIs, code, "Matcher against only a specific status code") { + if (ExtractStatus(arg).code() != code) { + *result_listener << "Expected status " << absl::StatusCodeToString(code) + << ", got " << ExtractStatus(arg); + return false; + } + return true; +} + +MATCHER_P2(StatusIs, code, matcher, "Matcher against a specific status code") { + if (ExtractStatus(arg).code() != code) { + *result_listener << "Expected status " << absl::StatusCodeToString(code) + << ", got " << ExtractStatus(arg); + return false; + } + return ::testing::ExplainMatchResult(matcher, ExtractStatus(arg).message(), + result_listener); +} + +#define QUICHE_EXPECT_OK(arg) EXPECT_THAT((arg), ::quiche::test::IsOk()) +#define QUICHE_ASSERT_OK(arg) ASSERT_THAT((arg), ::quiche::test::IsOk()) + +} // namespace test +} // namespace quiche + +#endif // QUICHE_COMMON_TEST_TOOLS_QUICHE_TEST_UTILS_H_ diff --git a/quiche/common/test_tools/quiche_test_utils_test.cc b/quiche/common/test_tools/quiche_test_utils_test.cc new file mode 100644 index 000000000000..17427a534715 --- /dev/null +++ b/quiche/common/test_tools/quiche_test_utils_test.cc @@ -0,0 +1,37 @@ +#include "quiche/common/test_tools/quiche_test_utils.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quiche::test { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +TEST(QuicheTestUtilsTest, StatusMatchers) { + const absl::Status ok = absl::OkStatus(); + QUICHE_EXPECT_OK(ok); + QUICHE_ASSERT_OK(ok); + EXPECT_THAT(ok, IsOk()); + + const absl::StatusOr ok_with_value = 2023; + QUICHE_EXPECT_OK(ok_with_value); + QUICHE_ASSERT_OK(ok_with_value); + EXPECT_THAT(ok_with_value, IsOk()); + EXPECT_THAT(ok_with_value, IsOkAndHolds(2023)); + + const absl::Status err = absl::InternalError("test error"); + EXPECT_THAT(err, Not(IsOk())); + EXPECT_THAT(err, StatusIs(absl::StatusCode::kInternal, HasSubstr("test"))); + + const absl::StatusOr err_with_value = absl::InternalError("test error"); + EXPECT_THAT(err_with_value, Not(IsOk())); + EXPECT_THAT(err_with_value, Not(IsOkAndHolds(2023))); + EXPECT_THAT(err_with_value, + StatusIs(absl::StatusCode::kInternal, HasSubstr("test"))); +} + +} // namespace +} // namespace quiche::test diff --git a/quiche/common/wire_serialization.h b/quiche/common/wire_serialization.h new file mode 100644 index 000000000000..89a54f2711e1 --- /dev/null +++ b/quiche/common/wire_serialization.h @@ -0,0 +1,398 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// wire_serialization.h -- absl::StrCat()-like interface for QUICHE wire format. +// +// When serializing a data structure, there are two common approaches: +// (1) Allocate into a dynamically sized buffer and incur the costs of memory +// allocations. +// (2) Precompute the length of the structure, allocate a buffer of the +// exact required size and then write into the said buffer. +// QUICHE generally takes the second approach, but as a result, a lot of +// serialization code is written twice. This API avoids this issue by letting +// the caller declaratively describe the wire format; the description provided +// is used both for the size computation and for the serialization. +// +// Consider the following struct in RFC 9000 language: +// Test Struct { +// Magic Value (32), +// Some Number (i), +// [Optional Number (i)], +// Magical String Length (i), +// Magical String (..), +// } +// +// Using the functions in this header, it can be serialized as follows: +// absl::StatusOr test_struct = SerializeIntoBuffer( +// WireUint32(magic_value), +// WireVarInt62(some_number), +// WireOptional(optional_number), +// WireStringWithVarInt62Length(magical_string) +// ); +// +// This header provides three main functions with fairly self-explanatory names: +// - size_t ComputeLengthOnWire(d1, d2, ... dN) +// - absl::Status SerializeIntoWriter(writer, d1, d2, ... dN) +// - absl::StatusOr SerializeIntoBuffer(allocator, d1, ... dN) +// +// It is possible to define a custom serializer for individual structs. Those +// would normally look like this: +// +// struct AwesomeStruct { ... } +// class WireAwesomeStruct { +// public: +// using DataType = AwesomeStruct; +// WireAwesomeStruct(const AwesomeStruct& awesome) : awesome_(awesome) {} +// size_t GetLengthOnWire() { ... } +// absl::Status SerializeIntoWriter(QuicheDataWriter& writer) { ... } +// }; +// +// See the unit test for the full version of the example above. + +#ifndef QUICHE_COMMON_WIRE_SERIALIZATION_H_ +#define QUICHE_COMMON_WIRE_SERIALIZATION_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_data_writer.h" +#include "quiche/common/quiche_status_utils.h" + +namespace quiche { + +// T::SerializeIntoWriter() is allowed to return both a bool and an +// absl::Status. There are two reasons for that: +// 1. Most QuicheDataWriter methods return a bool. +// 2. While cheap, absl::Status has a non-trivial destructor and thus is not +// as free as a bool is. +// To accomodate this, SerializeIntoWriterStatus provides a way to deduce +// what is the status type returned by the SerializeIntoWriter method. +template +class QUICHE_NO_EXPORT SerializeIntoWriterStatus { + public: + static_assert(std::is_trivially_copyable_v && sizeof(T) <= 32, + "The types passed into SerializeInto() APIs are passed by " + "value; if your type has non-trivial copy costs, it should be " + "wrapped into a type that carries a pointer"); + + using Type = decltype(std::declval().SerializeIntoWriter( + std::declval())); + static constexpr bool kIsBool = std::is_same_v; + static constexpr bool kIsStatus = std::is_same_v; + static_assert( + kIsBool || kIsStatus, + "SerializeIntoWriter() has to return either a bool or an absl::Status"); + + static ABSL_ATTRIBUTE_ALWAYS_INLINE Type OkValue() { + if constexpr (kIsStatus) { + return absl::OkStatus(); + } else { + return true; + } + } +}; + +inline ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsWriterStatusOk(bool status) { + return status; +} +inline ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsWriterStatusOk( + const absl::Status& status) { + return status.ok(); +} + +// ------------------- WireType() wrapper definitions ------------------- + +// Base class for WireUint8/16/32/64. +template +class QUICHE_EXPORT WireFixedSizeIntBase { + public: + using DataType = T; + static_assert(std::is_integral_v, + "WireFixedSizeIntBase is only usable with integral types"); + + explicit WireFixedSizeIntBase(T value) { value_ = value; } + size_t GetLengthOnWire() const { return sizeof(T); } + T value() const { return value_; } + + private: + T value_; +}; + +// Fixed-size integer fields. Correspond to (8), (16), (32) and (64) fields in +// RFC 9000 language. +class QUICHE_EXPORT WireUint8 : public WireFixedSizeIntBase { + public: + using WireFixedSizeIntBase::WireFixedSizeIntBase; + bool SerializeIntoWriter(QuicheDataWriter& writer) const { + return writer.WriteUInt8(value()); + } +}; +class QUICHE_EXPORT WireUint16 : public WireFixedSizeIntBase { + public: + using WireFixedSizeIntBase::WireFixedSizeIntBase; + bool SerializeIntoWriter(QuicheDataWriter& writer) const { + return writer.WriteUInt16(value()); + } +}; +class QUICHE_EXPORT WireUint32 : public WireFixedSizeIntBase { + public: + using WireFixedSizeIntBase::WireFixedSizeIntBase; + bool SerializeIntoWriter(QuicheDataWriter& writer) const { + return writer.WriteUInt32(value()); + } +}; +class QUICHE_EXPORT WireUint64 : public WireFixedSizeIntBase { + public: + using WireFixedSizeIntBase::WireFixedSizeIntBase; + bool SerializeIntoWriter(QuicheDataWriter& writer) const { + return writer.WriteUInt64(value()); + } +}; + +// Represents a 62-bit variable-length non-negative integer. Those are +// described in the Section 16 of RFC 9000, and are denoted as (i) in type +// descriptions. +class QUICHE_EXPORT WireVarInt62 { + public: + using DataType = uint64_t; + + explicit WireVarInt62(uint64_t value) { value_ = value; } + // Convenience wrapper. This is safe, since it is clear from the context that + // the enum is being treated as an integer. + template + explicit WireVarInt62(T value) { + static_assert(std::is_enum_v || std::is_convertible_v); + value_ = static_cast(value); + } + + size_t GetLengthOnWire() const { + return QuicheDataWriter::GetVarInt62Len(value_); + } + bool SerializeIntoWriter(QuicheDataWriter& writer) const { + return writer.WriteVarInt62(value_); + } + + private: + uint64_t value_; +}; + +// Represents unframed raw string. +class QUICHE_EXPORT WireBytes { + public: + using DataType = absl::string_view; + + explicit WireBytes(absl::string_view value) { value_ = value; } + size_t GetLengthOnWire() { return value_.size(); } + bool SerializeIntoWriter(QuicheDataWriter& writer) { + return writer.WriteStringPiece(value_); + } + + private: + absl::string_view value_; +}; + +// Represents a string where another wire type is used as a length prefix. +template +class QUICHE_EXPORT WireStringWithLengthPrefix { + public: + using DataType = absl::string_view; + + explicit WireStringWithLengthPrefix(absl::string_view value) { + value_ = value; + } + size_t GetLengthOnWire() { + return LengthWireType(value_.size()).GetLengthOnWire() + value_.size(); + } + absl::Status SerializeIntoWriter(QuicheDataWriter& writer) { + if (!LengthWireType(value_.size()).SerializeIntoWriter(writer)) { + return absl::InternalError("Failed to serialize the length prefix"); + } + if (!writer.WriteStringPiece(value_)) { + return absl::InternalError("Failed to serialize the string proper"); + } + return absl::OkStatus(); + } + + private: + absl::string_view value_; +}; + +// Represents varint62-prefixed strings. +using WireStringWithVarInt62Length = WireStringWithLengthPrefix; + +// Allows absl::optional to be used with this API. For instance, if the spec +// defines +// [Context ID (i)] +// and the value is stored as absl::optional context_id, this can be +// recorded as +// WireOptional(context_id) +// When optional is absent, nothing is written onto the wire. +template +class QUICHE_EXPORT WireOptional { + public: + using DataType = absl::optional; + using Status = SerializeIntoWriterStatus; + + explicit WireOptional(DataType value) { value_ = value; } + size_t GetLengthOnWire() const { + return value_.has_value() ? WireType(*value_).GetLengthOnWire() : 0; + } + typename Status::Type SerializeIntoWriter(QuicheDataWriter& writer) const { + if (value_.has_value()) { + return WireType(*value_).SerializeIntoWriter(writer); + } + return Status::OkValue(); + } + + private: + DataType value_; +}; + +// Allows multiple entries of the same type to be serialized in a single call. +template +class QUICHE_EXPORT WireSpan { + public: + using DataType = absl::Span; + + explicit WireSpan(DataType value) { value_ = value; } + size_t GetLengthOnWire() const { + size_t total = 0; + for (const SpanElementType& value : value_) { + total += WireType(value).GetLengthOnWire(); + } + return total; + } + absl::Status SerializeIntoWriter(QuicheDataWriter& writer) const { + for (size_t i = 0; i < value_.size(); i++) { + // `status` here can be either a bool or an absl::Status. + auto status = WireType(value_[i]).SerializeIntoWriter(writer); + if (IsWriterStatusOk(status)) { + continue; + } + if constexpr (SerializeIntoWriterStatus::kIsStatus) { + return AppendToStatus(std::move(status), + " while serializing the value #", i); + } else { + return absl::InternalError( + absl::StrCat("Failed to serialize vector value #", i)); + } + } + return absl::OkStatus(); + } + + private: + DataType value_; +}; + +// ------------------- Top-level serialization API ------------------- + +namespace wire_serialization_internal { +template +auto SerializeIntoWriterWrapper(QuicheDataWriter& writer, int argno, T data) { +#if defined(NDEBUG) + (void)argno; + (void)data; + return data.SerializeIntoWriter(writer); +#else + // When running in the debug build, we check that the length reported by + // GetLengthOnWire() matches what is actually being written onto the wire. + // While any mismatch will most likely lead to an error further down the line, + // this simplifies the debugging process. + const size_t initial_offset = writer.length(); + const size_t expected_size = data.GetLengthOnWire(); + auto result = data.SerializeIntoWriter(writer); + const size_t final_offset = writer.length(); + if (IsWriterStatusOk(result)) { + QUICHE_DCHECK_EQ(initial_offset + expected_size, final_offset) + << "while serializing field #" << argno; + } + return result; +#endif +} + +template +std::enable_if_t::kIsBool, absl::Status> +SerializeIntoWriterCore(QuicheDataWriter& writer, int argno, T data) { + const bool success = SerializeIntoWriterWrapper(writer, argno, data); + if (!success) { + return absl::InternalError( + absl::StrCat("Failed to serialize field #", argno)); + } + return absl::OkStatus(); +} + +template +std::enable_if_t::kIsStatus, absl::Status> +SerializeIntoWriterCore(QuicheDataWriter& writer, int argno, T data) { + return AppendToStatus(SerializeIntoWriterWrapper(writer, argno, data), + " while serializing field #", argno); +} + +template +absl::Status SerializeIntoWriterCore(QuicheDataWriter& writer, int argno, + T1 data1, Ts... rest) { + QUICHE_RETURN_IF_ERROR(SerializeIntoWriterCore(writer, argno, data1)); + return SerializeIntoWriterCore(writer, argno + 1, rest...); +} +} // namespace wire_serialization_internal + +// SerializeIntoWriter(writer, d1, d2, ... dN) serializes all of supplied data +// into the writer |writer|. True is returned on success, and false is returned +// if serialization fails (typically because the writer ran out of buffer). This +// is conceptually similar to absl::StrAppend(). +template +absl::Status SerializeIntoWriter(QuicheDataWriter& writer, Ts... data) { + return wire_serialization_internal::SerializeIntoWriterCore( + writer, /*argno=*/0, data...); +} + +// ComputeLengthOnWire(writer, d1, d2, ... dN) calculates the number of bytes +// necessary to serialize the supplied data. +template +size_t ComputeLengthOnWire(T data) { + return data.GetLengthOnWire(); +} +template +size_t ComputeLengthOnWire(T1 data1, Ts... rest) { + return data1.GetLengthOnWire() + ComputeLengthOnWire(rest...); +} + +// SerializeIntoBuffer(allocator, d1, d2, ... dN) computes the length required +// to store the supplied data, allocates the buffer of appropriate size using +// |allocator|, and serializes the result into it. In a rare event that the +// serialization fails (e.g. due to invalid varint62 value), an empty buffer is +// returned. +template +absl::StatusOr SerializeIntoBuffer( + QuicheBufferAllocator* allocator, Ts... data) { + size_t buffer_size = ComputeLengthOnWire(data...); + if (buffer_size == 0) { + return QuicheBuffer(); + } + + QuicheBuffer buffer(allocator, buffer_size); + QuicheDataWriter writer(buffer.size(), buffer.data()); + QUICHE_RETURN_IF_ERROR(SerializeIntoWriter(writer, data...)); + if (writer.remaining() != 0) { + return absl::InternalError(absl::StrCat( + "Excess ", writer.remaining(), " bytes allocated while serializing")); + } + return buffer; +} + +} // namespace quiche + +#endif // QUICHE_COMMON_WIRE_SERIALIZATION_H_ diff --git a/quiche/common/wire_serialization_test.cc b/quiche/common/wire_serialization_test.cc new file mode 100644 index 000000000000..b1dea91024ba --- /dev/null +++ b/quiche/common/wire_serialization_test.cc @@ -0,0 +1,256 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/common/wire_serialization.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/common/quiche_status_utils.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quiche::test { +namespace { + +using ::testing::ElementsAre; + +constexpr uint64_t kInvalidVarInt = std::numeric_limits::max(); + +template +absl::StatusOr SerializeIntoSimpleBuffer(Ts... data) { + return SerializeIntoBuffer(quiche::SimpleBufferAllocator::Get(), data...); +} + +template +void ExpectEncoding(const std::string& description, absl::string_view expected, + Ts... data) { + absl::StatusOr actual = + SerializeIntoSimpleBuffer(data...); + QUICHE_ASSERT_OK(actual); + quiche::test::CompareCharArraysWithHexError(description, actual->data(), + actual->size(), expected.data(), + expected.size()); +} + +template +void ExpectEncodingHex(const std::string& description, + absl::string_view expected_hex, Ts... data) { + ExpectEncoding(description, absl::HexStringToBytes(expected_hex), data...); +} + +TEST(SerializationTest, SerializeStrings) { + absl::StatusOr one_string = + SerializeIntoSimpleBuffer(WireBytes("test")); + QUICHE_ASSERT_OK(one_string); + EXPECT_EQ(one_string->AsStringView(), "test"); + + absl::StatusOr two_strings = + SerializeIntoSimpleBuffer(WireBytes("Hello"), WireBytes("World")); + QUICHE_ASSERT_OK(two_strings); + EXPECT_EQ(two_strings->AsStringView(), "HelloWorld"); +} + +TEST(SerializationTest, SerializeIntegers) { + ExpectEncodingHex("one uint8_t value", "42", WireUint8(0x42)); + ExpectEncodingHex("two uint8_t values", "ab01", WireUint8(0xab), + WireUint8(0x01)); + ExpectEncodingHex("one uint16_t value", "1234", WireUint16(0x1234)); + ExpectEncodingHex("one uint32_t value", "12345678", WireUint32(0x12345678)); + ExpectEncodingHex("one uint64_t value", "123456789abcdef0", + WireUint64(UINT64_C(0x123456789abcdef0))); + ExpectEncodingHex("mix of values", "aabbcc000000dd", WireUint8(0xaa), + WireUint16(0xbbcc), WireUint32(0xdd)); +} + +TEST(SerializationTest, SerializeLittleEndian) { + char buffer[4]; + QuicheDataWriter writer(sizeof(buffer), buffer, + quiche::Endianness::HOST_BYTE_ORDER); + QUICHE_ASSERT_OK( + SerializeIntoWriter(writer, WireUint16(0x1234), WireUint16(0xabcd))); + absl::string_view actual(writer.data(), writer.length()); + EXPECT_EQ(actual, absl::HexStringToBytes("3412cdab")); +} + +TEST(SerializationTest, SerializeVarInt62) { + // Test cases from RFC 9000, Appendix A.1 + ExpectEncodingHex("1-byte varint", "25", WireVarInt62(37)); + ExpectEncodingHex("2-byte varint", "7bbd", WireVarInt62(15293)); + ExpectEncodingHex("4-byte varint", "9d7f3e7d", WireVarInt62(494878333)); + ExpectEncodingHex("8-byte varint", "c2197c5eff14e88c", + WireVarInt62(UINT64_C(151288809941952652))); +} + +TEST(SerializationTest, SerializeStringWithVarInt62Length) { + ExpectEncodingHex("short string", "0474657374", + WireStringWithVarInt62Length("test")); + const std::string long_string(15293, 'a'); + ExpectEncoding("long string", absl::StrCat("\x7b\xbd", long_string), + WireStringWithVarInt62Length(long_string)); + ExpectEncodingHex("empty string", "00", WireStringWithVarInt62Length("")); +} + +TEST(SerializationTest, SerializeOptionalValues) { + absl::optional has_no_value; + absl::optional has_value = 0x42; + ExpectEncodingHex("optional without value", "00", WireUint8(0), + WireOptional(has_no_value)); + ExpectEncodingHex("optional with value", "0142", WireUint8(1), + WireOptional(has_value)); + ExpectEncodingHex("empty data", "", WireOptional(has_no_value)); + + absl::optional has_no_string; + absl::optional has_string = "\x42"; + ExpectEncodingHex("optional no string", "", + WireOptional(has_no_string)); + ExpectEncodingHex("optional string", "0142", + WireOptional(has_string)); +} + +enum class TestEnum { + kValue1 = 0x17, + kValue2 = 0x19, +}; + +TEST(SerializationTest, SerializeEnumValue) { + ExpectEncodingHex("enum value", "17", WireVarInt62(TestEnum::kValue1)); +} + +TEST(SerializationTest, SerializeLotsOfValues) { + ExpectEncodingHex("ten values", "00010203040506070809", WireUint8(0), + WireUint8(1), WireUint8(2), WireUint8(3), WireUint8(4), + WireUint8(5), WireUint8(6), WireUint8(7), WireUint8(8), + WireUint8(9)); +} + +TEST(SerializationTest, FailDueToLackOfSpace) { + char buffer[4]; + QuicheDataWriter writer(sizeof(buffer), buffer); + QUICHE_EXPECT_OK(SerializeIntoWriter(writer, WireUint32(0))); + ASSERT_EQ(writer.remaining(), 0u); + EXPECT_THAT( + SerializeIntoWriter(writer, WireUint32(0)), + StatusIs(absl::StatusCode::kInternal, "Failed to serialize field #0")); + EXPECT_THAT( + SerializeIntoWriter(writer, WireStringWithVarInt62Length("test")), + StatusIs( + absl::StatusCode::kInternal, + "Failed to serialize the length prefix while serializing field #0")); +} + +TEST(SerializationTest, FailDueToInvalidValue) { + EXPECT_QUICHE_BUG( + ExpectEncoding("invalid varint", "", WireVarInt62(kInvalidVarInt)), + "too big for VarInt62"); +} + +TEST(SerializationTest, InvalidValueCausesPartialWrite) { + char buffer[3] = {'\0'}; + QuicheDataWriter writer(sizeof(buffer), buffer); + QUICHE_EXPECT_OK(SerializeIntoWriter(writer, WireBytes("a"))); + EXPECT_THAT( + SerializeIntoWriter(writer, WireBytes("b"), + WireBytes("A considerably long string, writing which " + "will most likely cause ASAN to crash"), + WireBytes("c")), + StatusIs(absl::StatusCode::kInternal, "Failed to serialize field #1")); + EXPECT_THAT(buffer, ElementsAre('a', 'b', '\0')); + + QUICHE_EXPECT_OK(SerializeIntoWriter(writer, WireBytes("z"))); + EXPECT_EQ(buffer[2], 'z'); +} + +TEST(SerializationTest, SerializeVector) { + std::vector strs = {"foo", "test", "bar"}; + absl::StatusOr serialized = + SerializeIntoSimpleBuffer(WireSpan(absl::MakeSpan(strs))); + QUICHE_ASSERT_OK(serialized); + EXPECT_EQ(serialized->AsStringView(), "footestbar"); +} + +struct AwesomeStruct { + uint64_t awesome_number; + std::string awesome_text; +}; + +class WireAwesomeStruct { + public: + using DataType = AwesomeStruct; + + WireAwesomeStruct(const AwesomeStruct& awesome) : awesome_(awesome) {} + + size_t GetLengthOnWire() { + return quiche::ComputeLengthOnWire(WireUint16(awesome_.awesome_number), + WireBytes(awesome_.awesome_text)); + } + absl::Status SerializeIntoWriter(QuicheDataWriter& writer) { + return AppendToStatus(::quiche::SerializeIntoWriter( + writer, WireUint16(awesome_.awesome_number), + WireBytes(awesome_.awesome_text)), + " while serializing AwesomeStruct"); + } + + private: + const AwesomeStruct& awesome_; +}; + +TEST(SerializationTest, CustomStruct) { + AwesomeStruct awesome; + awesome.awesome_number = 0xabcd; + awesome.awesome_text = "test"; + ExpectEncodingHex("struct", "abcd74657374", WireAwesomeStruct(awesome)); +} + +TEST(SerializationTest, CustomStructSpan) { + std::array awesome; + awesome[0].awesome_number = 0xabcd; + awesome[0].awesome_text = "test"; + awesome[1].awesome_number = 0x1234; + awesome[1].awesome_text = std::string(3, '\0'); + ExpectEncodingHex("struct", "abcd746573741234000000", + WireSpan(absl::MakeSpan(awesome))); +} + +class WireFormatterThatWritesTooLittle { + public: + using DataType = absl::string_view; + + explicit WireFormatterThatWritesTooLittle(absl::string_view s) : s_(s) {} + + size_t GetLengthOnWire() const { return s_.size(); } + bool SerializeIntoWriter(QuicheDataWriter& writer) { + return writer.WriteStringPiece(s_.substr(0, s_.size() - 1)); + } + + private: + absl::string_view s_; +}; + +TEST(SerializationTest, CustomStructWritesTooLittle) { + constexpr absl::string_view kStr = "\xaa\xbb\xcc\xdd"; +#if defined(NDEBUG) + absl::Status status = + SerializeIntoSimpleBuffer(WireFormatterThatWritesTooLittle(kStr)) + .status(); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, + ::testing::HasSubstr("Excess 1 bytes"))); +#else + EXPECT_DEATH(QUICHE_LOG(INFO) << SerializeIntoSimpleBuffer( + WireFormatterThatWritesTooLittle(kStr)) + .status(), + "while serializing field #0"); +#endif +} + +} // namespace +} // namespace quiche::test diff --git a/quiche/http2/adapter/adapter_impl_comparison_test.cc b/quiche/http2/adapter/adapter_impl_comparison_test.cc new file mode 100644 index 000000000000..aaa2f99f254b --- /dev/null +++ b/quiche/http2/adapter/adapter_impl_comparison_test.cc @@ -0,0 +1,171 @@ +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/nghttp2_adapter.h" +#include "quiche/http2/adapter/oghttp2_adapter.h" +#include "quiche/http2/adapter/recording_http2_visitor.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +TEST(AdapterImplComparisonTest, ClientHandlesFrames) { + RecordingHttp2Visitor nghttp2_visitor; + std::unique_ptr nghttp2_adapter = + NgHttp2Adapter::CreateClientAdapter(nghttp2_visitor); + + RecordingHttp2Visitor oghttp2_visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + std::unique_ptr oghttp2_adapter = + OgHttp2Adapter::Create(oghttp2_visitor, options); + + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + + nghttp2_adapter->ProcessBytes(initial_frames); + oghttp2_adapter->ProcessBytes(initial_frames); + + EXPECT_EQ(nghttp2_visitor.GetEventSequence(), + oghttp2_visitor.GetEventSequence()); + + // TODO(b/181586191): Consider consistent behavior for delivering events on + // non-existent streams between nghttp2_adapter and oghttp2_adapter. +} + +TEST(AdapterImplComparisonTest, SubmitWindowUpdateBumpsWindow) { + RecordingHttp2Visitor nghttp2_visitor; + std::unique_ptr nghttp2_adapter = + NgHttp2Adapter::CreateClientAdapter(nghttp2_visitor); + + RecordingHttp2Visitor oghttp2_visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + std::unique_ptr oghttp2_adapter = + OgHttp2Adapter::Create(oghttp2_visitor, options); + + int result; + + const std::vector
request_headers = + ToHeaders({{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}); + const int kInitialFlowControlWindow = 65535; + const int kConnectionWindowIncrease = 192 * 1024; + + const int32_t nghttp2_stream_id = + nghttp2_adapter->SubmitRequest(request_headers, nullptr, nullptr); + + // Both the connection and stream flow control windows are increased. + nghttp2_adapter->SubmitWindowUpdate(0, kConnectionWindowIncrease); + nghttp2_adapter->SubmitWindowUpdate(nghttp2_stream_id, + kConnectionWindowIncrease); + result = nghttp2_adapter->Send(); + EXPECT_EQ(0, result); + int nghttp2_window = nghttp2_adapter->GetReceiveWindowSize(); + EXPECT_EQ(kInitialFlowControlWindow + kConnectionWindowIncrease, + nghttp2_window); + + const int32_t oghttp2_stream_id = + oghttp2_adapter->SubmitRequest(request_headers, nullptr, nullptr); + // Both the connection and stream flow control windows are increased. + oghttp2_adapter->SubmitWindowUpdate(0, kConnectionWindowIncrease); + oghttp2_adapter->SubmitWindowUpdate(oghttp2_stream_id, + kConnectionWindowIncrease); + result = oghttp2_adapter->Send(); + EXPECT_EQ(0, result); + int oghttp2_window = oghttp2_adapter->GetReceiveWindowSize(); + EXPECT_EQ(kInitialFlowControlWindow + kConnectionWindowIncrease, + oghttp2_window); + + // nghttp2 and oghttp2 agree on the advertised window. + EXPECT_EQ(nghttp2_window, oghttp2_window); + + ASSERT_EQ(nghttp2_stream_id, oghttp2_stream_id); + + const int kMaxFrameSize = 16 * 1024; + const std::string body_chunk(kMaxFrameSize, 'a'); + auto sequence = TestFrameSequence(); + sequence.ServerPreface().Headers(nghttp2_stream_id, {{":status", "200"}}, + /*fin=*/false); + // This loop generates enough DATA frames to consume the window increase. + const int kNumFrames = kConnectionWindowIncrease / kMaxFrameSize; + for (int i = 0; i < kNumFrames; ++i) { + sequence.Data(nghttp2_stream_id, body_chunk); + } + const std::string frames = sequence.Serialize(); + + nghttp2_adapter->ProcessBytes(frames); + // Marking the data consumed causes a window update, which is reflected in the + // advertised window size. + nghttp2_adapter->MarkDataConsumedForStream(nghttp2_stream_id, + kNumFrames * kMaxFrameSize); + result = nghttp2_adapter->Send(); + EXPECT_EQ(0, result); + nghttp2_window = nghttp2_adapter->GetReceiveWindowSize(); + + oghttp2_adapter->ProcessBytes(frames); + // Marking the data consumed causes a window update, which is reflected in the + // advertised window size. + oghttp2_adapter->MarkDataConsumedForStream(oghttp2_stream_id, + kNumFrames * kMaxFrameSize); + result = oghttp2_adapter->Send(); + EXPECT_EQ(0, result); + oghttp2_window = oghttp2_adapter->GetReceiveWindowSize(); + + const int kMinExpectation = + (kInitialFlowControlWindow + kConnectionWindowIncrease) / 2; + EXPECT_GT(nghttp2_window, kMinExpectation); + EXPECT_GT(oghttp2_window, kMinExpectation); +} + +TEST(AdapterImplComparisonTest, ServerHandlesFrames) { + RecordingHttp2Visitor nghttp2_visitor; + std::unique_ptr nghttp2_adapter = + NgHttp2Adapter::CreateServerAdapter(nghttp2_visitor); + + RecordingHttp2Visitor oghttp2_visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + std::unique_ptr oghttp2_adapter = + OgHttp2Adapter::Create(oghttp2_visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + + nghttp2_adapter->ProcessBytes(frames); + oghttp2_adapter->ProcessBytes(frames); + + EXPECT_EQ(nghttp2_visitor.GetEventSequence(), + oghttp2_visitor.GetEventSequence()); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/callback_visitor.cc b/quiche/http2/adapter/callback_visitor.cc new file mode 100644 index 000000000000..64f0ed6d9cf5 --- /dev/null +++ b/quiche/http2/adapter/callback_visitor.cc @@ -0,0 +1,509 @@ +#include "quiche/http2/adapter/callback_visitor.h" + +#include "absl/strings/escaping.h" +#include "quiche/http2/adapter/http2_util.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/quiche_endian.h" + +// This visitor implementation needs visibility into the +// nghttp2_session_callbacks type. There's no public header, so we'll redefine +// the struct here. +struct nghttp2_session_callbacks { + nghttp2_send_callback send_callback; + nghttp2_recv_callback recv_callback; + nghttp2_on_frame_recv_callback on_frame_recv_callback; + nghttp2_on_invalid_frame_recv_callback on_invalid_frame_recv_callback; + nghttp2_on_data_chunk_recv_callback on_data_chunk_recv_callback; + nghttp2_before_frame_send_callback before_frame_send_callback; + nghttp2_on_frame_send_callback on_frame_send_callback; + nghttp2_on_frame_not_send_callback on_frame_not_send_callback; + nghttp2_on_stream_close_callback on_stream_close_callback; + nghttp2_on_begin_headers_callback on_begin_headers_callback; + nghttp2_on_header_callback on_header_callback; + nghttp2_on_header_callback2 on_header_callback2; + nghttp2_on_invalid_header_callback on_invalid_header_callback; + nghttp2_on_invalid_header_callback2 on_invalid_header_callback2; + nghttp2_select_padding_callback select_padding_callback; + nghttp2_data_source_read_length_callback read_length_callback; + nghttp2_on_begin_frame_callback on_begin_frame_callback; + nghttp2_send_data_callback send_data_callback; + nghttp2_pack_extension_callback pack_extension_callback; + nghttp2_unpack_extension_callback unpack_extension_callback; + nghttp2_on_extension_chunk_recv_callback on_extension_chunk_recv_callback; + nghttp2_error_callback error_callback; + nghttp2_error_callback2 error_callback2; +}; + +namespace http2 { +namespace adapter { + +CallbackVisitor::CallbackVisitor(Perspective perspective, + const nghttp2_session_callbacks& callbacks, + void* user_data) + : perspective_(perspective), + callbacks_(MakeCallbacksPtr(nullptr)), + user_data_(user_data) { + nghttp2_session_callbacks* c; + nghttp2_session_callbacks_new(&c); + *c = callbacks; + callbacks_ = MakeCallbacksPtr(c); + memset(¤t_frame_, 0, sizeof(current_frame_)); +} + +int64_t CallbackVisitor::OnReadyToSend(absl::string_view serialized) { + if (!callbacks_->send_callback) { + return kSendError; + } + int64_t result = callbacks_->send_callback( + nullptr, ToUint8Ptr(serialized.data()), serialized.size(), 0, user_data_); + QUICHE_VLOG(1) << "CallbackVisitor::OnReadyToSend called with " + << serialized.size() << " bytes, returning " << result; + QUICHE_VLOG(2) << (perspective_ == Perspective::kClient ? "Client" : "Server") + << " sending: [" << absl::CEscape(serialized) << "]"; + if (result > 0) { + return result; + } else if (result == NGHTTP2_ERR_WOULDBLOCK) { + return kSendBlocked; + } else { + return kSendError; + } +} + +void CallbackVisitor::OnConnectionError(ConnectionError /*error*/) { + QUICHE_VLOG(1) << "OnConnectionError not implemented"; +} + +bool CallbackVisitor::OnFrameHeader(Http2StreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + QUICHE_VLOG(1) << "CallbackVisitor::OnFrameHeader(stream_id=" << stream_id + << ", type=" << int(type) << ", length=" << length + << ", flags=" << int(flags) << ")"; + if (static_cast(type) == FrameType::CONTINUATION) { + if (static_cast(current_frame_.hd.type) != FrameType::HEADERS || + current_frame_.hd.stream_id == 0 || + current_frame_.hd.stream_id != stream_id) { + // CONTINUATION frames must follow HEADERS on the same stream. If no + // frames have been received, the type is initialized to zero, and the + // comparison will fail. + return false; + } + current_frame_.hd.length += length; + current_frame_.hd.flags |= flags; + QUICHE_DLOG_IF(ERROR, length == 0) << "Empty CONTINUATION!"; + // Still need to deliver the CONTINUATION to the begin frame callback. + nghttp2_frame_hd hd; + memset(&hd, 0, sizeof(hd)); + hd.stream_id = stream_id; + hd.length = length; + hd.type = type; + hd.flags = flags; + if (callbacks_->on_begin_frame_callback) { + const int result = + callbacks_->on_begin_frame_callback(nullptr, &hd, user_data_); + return result == 0; + } + return true; + } + // The general strategy is to clear |current_frame_| at the start of a new + // frame, accumulate frame information from the various callback events, then + // invoke the on_frame_recv_callback() with the accumulated frame data. + memset(¤t_frame_, 0, sizeof(current_frame_)); + current_frame_.hd.stream_id = stream_id; + current_frame_.hd.length = length; + current_frame_.hd.type = type; + current_frame_.hd.flags = flags; + if (callbacks_->on_begin_frame_callback) { + const int result = callbacks_->on_begin_frame_callback( + nullptr, ¤t_frame_.hd, user_data_); + return result == 0; + } + return true; +} + +void CallbackVisitor::OnSettingsStart() {} + +void CallbackVisitor::OnSetting(Http2Setting setting) { + settings_.push_back({setting.id, setting.value}); +} + +void CallbackVisitor::OnSettingsEnd() { + current_frame_.settings.niv = settings_.size(); + current_frame_.settings.iv = settings_.data(); + QUICHE_VLOG(1) << "OnSettingsEnd, received settings of size " + << current_frame_.settings.niv; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } + settings_.clear(); +} + +void CallbackVisitor::OnSettingsAck() { + // ACK is part of the flags, which were set in OnFrameHeader(). + QUICHE_VLOG(1) << "OnSettingsAck()"; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +bool CallbackVisitor::OnBeginHeadersForStream(Http2StreamId stream_id) { + auto it = GetStreamInfo(stream_id); + if (it->second.received_headers) { + // At least one headers frame has already been received. + QUICHE_VLOG(1) + << "Headers already received for stream " << stream_id + << ", these are trailers or headers following a 100 response"; + current_frame_.headers.cat = NGHTTP2_HCAT_HEADERS; + } else { + switch (perspective_) { + case Perspective::kClient: + QUICHE_VLOG(1) << "First headers at the client for stream " << stream_id + << "; these are response headers"; + current_frame_.headers.cat = NGHTTP2_HCAT_RESPONSE; + break; + case Perspective::kServer: + QUICHE_VLOG(1) << "First headers at the server for stream " << stream_id + << "; these are request headers"; + current_frame_.headers.cat = NGHTTP2_HCAT_REQUEST; + break; + } + } + it->second.received_headers = true; + if (callbacks_->on_begin_headers_callback) { + const int result = callbacks_->on_begin_headers_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +Http2VisitorInterface::OnHeaderResult CallbackVisitor::OnHeaderForStream( + Http2StreamId stream_id, absl::string_view name, absl::string_view value) { + QUICHE_VLOG(2) << "OnHeaderForStream(stream_id=" << stream_id << ", name=[" + << absl::CEscape(name) << "], value=[" << absl::CEscape(value) + << "])"; + if (callbacks_->on_header_callback) { + const int result = callbacks_->on_header_callback( + nullptr, ¤t_frame_, ToUint8Ptr(name.data()), name.size(), + ToUint8Ptr(value.data()), value.size(), NGHTTP2_NV_FLAG_NONE, + user_data_); + if (result == 0) { + return HEADER_OK; + } else if (result == NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE) { + return HEADER_RST_STREAM; + } else { + // Assume NGHTTP2_ERR_CALLBACK_FAILURE. + return HEADER_CONNECTION_ERROR; + } + } + return HEADER_OK; +} + +bool CallbackVisitor::OnEndHeadersForStream(Http2StreamId stream_id) { + QUICHE_VLOG(1) << "OnEndHeadersForStream(stream_id=" << stream_id << ")"; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnDataPaddingLength(Http2StreamId /*stream_id*/, + size_t padding_length) { + QUICHE_DCHECK_GE(remaining_data_, padding_length); + current_frame_.data.padlen = padding_length; + remaining_data_ -= padding_length; + if (remaining_data_ == 0 && + (current_frame_.hd.flags & NGHTTP2_FLAG_END_STREAM) == 0 && + callbacks_->on_frame_recv_callback != nullptr) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnBeginDataForStream(Http2StreamId /*stream_id*/, + size_t payload_length) { + remaining_data_ = payload_length; + if (remaining_data_ == 0 && + (current_frame_.hd.flags & NGHTTP2_FLAG_END_STREAM) == 0 && + callbacks_->on_frame_recv_callback != nullptr) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnDataForStream(Http2StreamId stream_id, + absl::string_view data) { + QUICHE_VLOG(1) << "OnDataForStream(stream_id=" << stream_id + << ", data.size()=" << data.size() << ")"; + int result = 0; + if (callbacks_->on_data_chunk_recv_callback) { + result = callbacks_->on_data_chunk_recv_callback( + nullptr, current_frame_.hd.flags, stream_id, ToUint8Ptr(data.data()), + data.size(), user_data_); + } + remaining_data_ -= data.size(); + if (result == 0 && remaining_data_ == 0 && + (current_frame_.hd.flags & NGHTTP2_FLAG_END_STREAM) == 0 && + callbacks_->on_frame_recv_callback) { + // If the DATA frame contains the END_STREAM flag, `on_frame_recv` is + // invoked later. + result = callbacks_->on_frame_recv_callback(nullptr, ¤t_frame_, + user_data_); + } + return result == 0; +} + +bool CallbackVisitor::OnEndStream(Http2StreamId stream_id) { + QUICHE_VLOG(1) << "OnEndStream(stream_id=" << stream_id << ")"; + int result = 0; + if (static_cast(current_frame_.hd.type) == FrameType::DATA && + (current_frame_.hd.flags & NGHTTP2_FLAG_END_STREAM) != 0 && + callbacks_->on_frame_recv_callback) { + // `on_frame_recv` is invoked here to ensure that the Http2Adapter + // implementation has successfully validated and processed the entire DATA + // frame. + result = callbacks_->on_frame_recv_callback(nullptr, ¤t_frame_, + user_data_); + } + return result == 0; +} + +void CallbackVisitor::OnRstStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + QUICHE_VLOG(1) << "OnRstStream(stream_id=" << stream_id + << ", error_code=" << static_cast(error_code) << ")"; + current_frame_.rst_stream.error_code = static_cast(error_code); + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +bool CallbackVisitor::OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + QUICHE_VLOG(1) << "OnCloseStream(stream_id=" << stream_id + << ", error_code=" << static_cast(error_code) << ")"; + int result = 0; + if (callbacks_->on_stream_close_callback) { + result = callbacks_->on_stream_close_callback( + nullptr, stream_id, static_cast(error_code), user_data_); + } + stream_map_.erase(stream_id); + if (stream_close_listener_) { + stream_close_listener_(stream_id); + } + return result == 0; +} + +void CallbackVisitor::OnPriorityForStream(Http2StreamId /*stream_id*/, + Http2StreamId parent_stream_id, + int weight, bool exclusive) { + current_frame_.priority.pri_spec.stream_id = parent_stream_id; + current_frame_.priority.pri_spec.weight = weight; + current_frame_.priority.pri_spec.exclusive = exclusive; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::OnPing(Http2PingId ping_id, bool is_ack) { + QUICHE_VLOG(1) << "OnPing(ping_id=" << static_cast(ping_id) + << ", is_ack=" << is_ack << ")"; + uint64_t network_order_opaque_data = + quiche::QuicheEndian::HostToNet64(ping_id); + std::memcpy(current_frame_.ping.opaque_data, &network_order_opaque_data, + sizeof(network_order_opaque_data)); + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::OnPushPromiseForStream( + Http2StreamId /*stream_id*/, Http2StreamId /*promised_stream_id*/) { + QUICHE_LOG(DFATAL) << "Not implemented"; +} + +bool CallbackVisitor::OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + QUICHE_VLOG(1) << "OnGoAway(last_accepted_stream_id=" + << last_accepted_stream_id + << ", error_code=" << static_cast(error_code) + << ", opaque_data=[" << absl::CEscape(opaque_data) << "])"; + current_frame_.goaway.last_stream_id = last_accepted_stream_id; + current_frame_.goaway.error_code = static_cast(error_code); + current_frame_.goaway.opaque_data = ToUint8Ptr(opaque_data.data()); + current_frame_.goaway.opaque_data_len = opaque_data.size(); + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +void CallbackVisitor::OnWindowUpdate(Http2StreamId stream_id, + int window_increment) { + QUICHE_VLOG(1) << "OnWindowUpdate(stream_id=" << stream_id + << ", delta=" << window_increment << ")"; + current_frame_.window_update.window_size_increment = window_increment; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::PopulateFrame(nghttp2_frame& frame, uint8_t frame_type, + Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code, + bool sent_headers) { + frame.hd.type = frame_type; + frame.hd.stream_id = stream_id; + frame.hd.length = length; + frame.hd.flags = flags; + const FrameType frame_type_enum = static_cast(frame_type); + if (frame_type_enum == FrameType::HEADERS) { + if (sent_headers) { + frame.headers.cat = NGHTTP2_HCAT_HEADERS; + } else { + switch (perspective_) { + case Perspective::kClient: + QUICHE_VLOG(1) << "First headers sent by the client for stream " + << stream_id << "; these are request headers"; + frame.headers.cat = NGHTTP2_HCAT_REQUEST; + break; + case Perspective::kServer: + QUICHE_VLOG(1) << "First headers sent by the server for stream " + << stream_id << "; these are response headers"; + frame.headers.cat = NGHTTP2_HCAT_RESPONSE; + break; + } + } + } else if (frame_type_enum == FrameType::RST_STREAM) { + frame.rst_stream.error_code = error_code; + } else if (frame_type_enum == FrameType::GOAWAY) { + frame.goaway.error_code = error_code; + } +} + +int CallbackVisitor::OnBeforeFrameSent(uint8_t frame_type, + Http2StreamId stream_id, size_t length, + uint8_t flags) { + QUICHE_VLOG(1) << "OnBeforeFrameSent(stream_id=" << stream_id + << ", type=" << int(frame_type) << ", length=" << length + << ", flags=" << int(flags) << ")"; + if (callbacks_->before_frame_send_callback) { + nghttp2_frame frame; + auto it = GetStreamInfo(stream_id); + // The implementation of the before_frame_send_callback doesn't look at the + // error code, so for now it's populated with 0. + PopulateFrame(frame, frame_type, stream_id, length, flags, /*error_code=*/0, + it->second.before_sent_headers); + it->second.before_sent_headers = true; + return callbacks_->before_frame_send_callback(nullptr, &frame, user_data_); + } + return 0; +} + +int CallbackVisitor::OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags, + uint32_t error_code) { + QUICHE_VLOG(1) << "OnFrameSent(stream_id=" << stream_id + << ", type=" << int(frame_type) << ", length=" << length + << ", flags=" << int(flags) << ", error_code=" << error_code + << ")"; + if (callbacks_->on_frame_send_callback) { + nghttp2_frame frame; + auto it = GetStreamInfo(stream_id); + PopulateFrame(frame, frame_type, stream_id, length, flags, error_code, + it->second.sent_headers); + it->second.sent_headers = true; + return callbacks_->on_frame_send_callback(nullptr, &frame, user_data_); + } + return 0; +} + +bool CallbackVisitor::OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) { + QUICHE_VLOG(1) << "OnInvalidFrame(" << stream_id << ", " + << InvalidFrameErrorToString(error) << ")"; + QUICHE_DCHECK_EQ(stream_id, current_frame_.hd.stream_id); + if (callbacks_->on_invalid_frame_recv_callback) { + return 0 == + callbacks_->on_invalid_frame_recv_callback( + nullptr, ¤t_frame_, ToNgHttp2ErrorCode(error), user_data_); + } + return true; +} + +void CallbackVisitor::OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) { + QUICHE_VLOG(1) << "OnBeginMetadataForStream(stream_id=" << stream_id + << ", payload_length=" << payload_length << ")"; +} + +bool CallbackVisitor::OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) { + QUICHE_VLOG(1) << "OnMetadataForStream(stream_id=" << stream_id + << ", len=" << metadata.size() << ")"; + if (callbacks_->on_extension_chunk_recv_callback) { + int result = callbacks_->on_extension_chunk_recv_callback( + nullptr, ¤t_frame_.hd, ToUint8Ptr(metadata.data()), + metadata.size(), user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnMetadataEndForStream(Http2StreamId stream_id) { + QUICHE_LOG_IF(DFATAL, current_frame_.hd.flags != kMetadataEndFlag); + QUICHE_VLOG(1) << "OnMetadataEndForStream(stream_id=" << stream_id << ")"; + if (callbacks_->unpack_extension_callback) { + void* payload; + int result = callbacks_->unpack_extension_callback( + nullptr, &payload, ¤t_frame_.hd, user_data_); + if (result == 0 && callbacks_->on_frame_recv_callback) { + current_frame_.ext.payload = payload; + result = callbacks_->on_frame_recv_callback(nullptr, ¤t_frame_, + user_data_); + } + return (result == 0); + } + return true; +} + +void CallbackVisitor::OnErrorDebug(absl::string_view message) { + QUICHE_VLOG(1) << "OnErrorDebug(message=[" << absl::CEscape(message) << "])"; + if (callbacks_->error_callback2) { + callbacks_->error_callback2(nullptr, -1, message.data(), message.size(), + user_data_); + } +} + +CallbackVisitor::StreamInfoMap::iterator CallbackVisitor::GetStreamInfo( + Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + auto p = stream_map_.insert({stream_id, {}}); + it = p.first; + } + return it; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/callback_visitor.h b/quiche/http2/adapter/callback_visitor.h new file mode 100644 index 000000000000..5132104c66bf --- /dev/null +++ b/quiche/http2/adapter/callback_visitor.h @@ -0,0 +1,113 @@ +#ifndef QUICHE_HTTP2_ADAPTER_CALLBACK_VISITOR_H_ +#define QUICHE_HTTP2_ADAPTER_CALLBACK_VISITOR_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// This visitor implementation accepts a set of nghttp2 callbacks and a "user +// data" pointer, and invokes the callbacks according to HTTP/2 events received. +class QUICHE_EXPORT CallbackVisitor : public Http2VisitorInterface { + public: + // Called when the visitor receives a close event for `stream_id`. + using StreamCloseListener = std::function; + + explicit CallbackVisitor(Perspective perspective, + const nghttp2_session_callbacks& callbacks, + void* user_data); + + int64_t OnReadyToSend(absl::string_view serialized) override; + void OnConnectionError(ConnectionError error) override; + bool OnFrameHeader(Http2StreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnSettingsStart() override; + void OnSetting(Http2Setting setting) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + bool OnBeginHeadersForStream(Http2StreamId stream_id) override; + OnHeaderResult OnHeaderForStream(Http2StreamId stream_id, + absl::string_view name, + absl::string_view value) override; + bool OnEndHeadersForStream(Http2StreamId stream_id) override; + bool OnDataPaddingLength(Http2StreamId stream_id, + size_t padding_length) override; + bool OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnDataForStream(Http2StreamId stream_id, + absl::string_view data) override; + bool OnEndStream(Http2StreamId stream_id) override; + void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) override; + bool OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) override; + void OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, int weight, + bool exclusive) override; + void OnPing(Http2PingId ping_id, bool is_ack) override; + void OnPushPromiseForStream(Http2StreamId stream_id, + Http2StreamId promised_stream_id) override; + bool OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + void OnWindowUpdate(Http2StreamId stream_id, int window_increment) override; + int OnBeforeFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags) override; + int OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code) override; + bool OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) override; + void OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) override; + bool OnMetadataEndForStream(Http2StreamId stream_id) override; + void OnErrorDebug(absl::string_view message) override; + + size_t stream_map_size() const { return stream_map_.size(); } + + void set_stream_close_listener(StreamCloseListener stream_close_listener) { + stream_close_listener_ = std::move(stream_close_listener); + } + + private: + struct QUICHE_EXPORT StreamInfo { + bool before_sent_headers = false; + bool sent_headers = false; + bool received_headers = false; + }; + + using StreamInfoMap = absl::flat_hash_map; + + void PopulateFrame(nghttp2_frame& frame, uint8_t frame_type, + Http2StreamId stream_id, size_t length, uint8_t flags, + uint32_t error_code, bool sent_headers); + + // Creates the StreamInfoMap entry if it doesn't exist. + StreamInfoMap::iterator GetStreamInfo(Http2StreamId stream_id); + + StreamInfoMap stream_map_; + + StreamCloseListener stream_close_listener_; + + Perspective perspective_; + nghttp2_session_callbacks_unique_ptr callbacks_; + void* user_data_; + + nghttp2_frame current_frame_; + std::vector settings_; + size_t remaining_data_ = 0; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_CALLBACK_VISITOR_H_ diff --git a/quiche/http2/adapter/callback_visitor_test.cc b/quiche/http2/adapter/callback_visitor_test.cc new file mode 100644 index 000000000000..4d0355e62aaa --- /dev/null +++ b/quiche/http2/adapter/callback_visitor_test.cc @@ -0,0 +1,536 @@ +#include "quiche/http2/adapter/callback_visitor.h" + +#include "absl/container/flat_hash_map.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/mock_nghttp2_callbacks.h" +#include "quiche/http2/adapter/nghttp2_adapter.h" +#include "quiche/http2/adapter/nghttp2_test_utils.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using testing::_; +using testing::IsEmpty; +using testing::Pair; +using testing::UnorderedElementsAre; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +// Tests connection-level events. +TEST(ClientCallbackVisitorUnitTest, ConnectionFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // SETTINGS + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, SETTINGS, _))); + visitor.OnFrameHeader(0, 0, SETTINGS, 0); + + visitor.OnSettingsStart(); + EXPECT_CALL(callbacks, OnFrameRecv(IsSettings(testing::IsEmpty()))); + visitor.OnSettingsEnd(); + + // PING + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, PING, _))); + visitor.OnFrameHeader(0, 8, PING, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPing(42))); + visitor.OnPing(42, false); + + // WINDOW_UPDATE + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, WINDOW_UPDATE, _))); + visitor.OnFrameHeader(0, 4, WINDOW_UPDATE, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsWindowUpdate(1000))); + visitor.OnWindowUpdate(0, 1000); + + // PING ack + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(0, PING, NGHTTP2_FLAG_ACK))); + visitor.OnFrameHeader(0, 8, PING, 1); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPingAck(247))); + visitor.OnPing(247, true); + + // GOAWAY + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, GOAWAY, 0))); + visitor.OnFrameHeader(0, 19, GOAWAY, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsGoAway(5, NGHTTP2_ENHANCE_YOUR_CALM, + "calm down!!"))); + visitor.OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!"); + + EXPECT_EQ(visitor.stream_map_size(), 0); +} + +TEST(ClientCallbackVisitorUnitTest, StreamFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + absl::flat_hash_map stream_close_counts; + visitor.set_stream_close_listener( + [&stream_close_counts](Http2StreamId stream_id) { + ++stream_close_counts[stream_id]; + }); + + testing::InSequence seq; + + EXPECT_EQ(visitor.stream_map_size(), 0); + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, _))); + visitor.OnFrameHeader(1, 23, HEADERS, 4); + + EXPECT_CALL(callbacks, + OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_EQ(visitor.stream_map_size(), 1); + + EXPECT_CALL(callbacks, OnHeader(_, ":status", "200", _)); + visitor.OnHeaderForStream(1, ":status", "200"); + + EXPECT_CALL(callbacks, OnHeader(_, "server", "my-fake-server", _)); + visitor.OnHeaderForStream(1, "server", "my-fake-server"); + + EXPECT_CALL(callbacks, + OnHeader(_, "date", "Tue, 6 Apr 2021 12:54:01 GMT", _)); + visitor.OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT"); + + EXPECT_CALL(callbacks, OnHeader(_, "trailer", "x-server-status", _)); + visitor.OnHeaderForStream(1, "trailer", "x-server-status"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnEndHeadersForStream(1); + + // DATA for stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, DATA, 0))); + visitor.OnFrameHeader(1, 26, DATA, 0); + + visitor.OnBeginDataForStream(1, 26); + EXPECT_CALL(callbacks, OnDataChunkRecv(0, 1, "This is the response body.")); + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, 0))); + visitor.OnDataForStream(1, "This is the response body."); + + // Trailers for stream 1, with a different nghttp2 "category". + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, _))); + visitor.OnFrameHeader(1, 23, HEADERS, 4); + + EXPECT_CALL(callbacks, OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_HEADERS))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, "x-server-status", "OK", _)); + visitor.OnHeaderForStream(1, "x-server-status", "OK"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, _, NGHTTP2_HCAT_HEADERS))); + visitor.OnEndHeadersForStream(1); + + EXPECT_THAT(stream_close_counts, IsEmpty()); + + // RST_STREAM on stream 3 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(3, RST_STREAM, 0))); + visitor.OnFrameHeader(3, 4, RST_STREAM, 0); + + // No change in stream map size. + EXPECT_EQ(visitor.stream_map_size(), 1); + EXPECT_THAT(stream_close_counts, IsEmpty()); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(3, NGHTTP2_INTERNAL_ERROR))); + visitor.OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_CALL(callbacks, OnStreamClose(3, NGHTTP2_INTERNAL_ERROR)); + visitor.OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_THAT(stream_close_counts, UnorderedElementsAre(Pair(3, 1))); + + // More stream close events + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(1, DATA, NGHTTP2_FLAG_END_STREAM))); + visitor.OnFrameHeader(1, 0, DATA, 1); + + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, NGHTTP2_FLAG_END_STREAM))); + visitor.OnBeginDataForStream(1, 0); + EXPECT_TRUE(visitor.OnEndStream(1)); + + EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR); + + // Stream map is empty again after both streams were closed. + EXPECT_EQ(visitor.stream_map_size(), 0); + EXPECT_THAT(stream_close_counts, + UnorderedElementsAre(Pair(3, 1), Pair(1, 1))); + + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(5, RST_STREAM, _))); + visitor.OnFrameHeader(5, 4, RST_STREAM, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(5, NGHTTP2_REFUSED_STREAM))); + visitor.OnRstStream(5, Http2ErrorCode::REFUSED_STREAM); + + EXPECT_CALL(callbacks, OnStreamClose(5, NGHTTP2_REFUSED_STREAM)); + visitor.OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM); + + EXPECT_EQ(visitor.stream_map_size(), 0); + EXPECT_THAT(stream_close_counts, + UnorderedElementsAre(Pair(3, 1), Pair(1, 1), Pair(5, 1))); +} + +TEST(ClientCallbackVisitorUnitTest, HeadersWithContinuation) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, 0x0))); + ASSERT_TRUE(visitor.OnFrameHeader(1, 23, HEADERS, 0x0)); + + EXPECT_CALL(callbacks, + OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, ":status", "200", _)); + visitor.OnHeaderForStream(1, ":status", "200"); + + EXPECT_CALL(callbacks, OnHeader(_, "server", "my-fake-server", _)); + visitor.OnHeaderForStream(1, "server", "my-fake-server"); + + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(1, CONTINUATION, END_HEADERS_FLAG))); + ASSERT_TRUE(visitor.OnFrameHeader(1, 23, CONTINUATION, END_HEADERS_FLAG)); + + EXPECT_CALL(callbacks, + OnHeader(_, "date", "Tue, 6 Apr 2021 12:54:01 GMT", _)); + visitor.OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT"); + + EXPECT_CALL(callbacks, OnHeader(_, "trailer", "x-server-status", _)); + visitor.OnHeaderForStream(1, "trailer", "x-server-status"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnEndHeadersForStream(1); +} + +TEST(ClientCallbackVisitorUnitTest, ContinuationNoHeaders) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + // Because no stream precedes the CONTINUATION frame, the stream ID does not + // match, and the method returns false. + EXPECT_FALSE(visitor.OnFrameHeader(1, 23, CONTINUATION, END_HEADERS_FLAG)); +} + +TEST(ClientCallbackVisitorUnitTest, ContinuationWrongPrecedingType) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, WINDOW_UPDATE, _))); + visitor.OnFrameHeader(1, 4, WINDOW_UPDATE, 0); + + // Because the CONTINUATION frame does not follow HEADERS, the method returns + // false. + EXPECT_FALSE(visitor.OnFrameHeader(1, 23, CONTINUATION, END_HEADERS_FLAG)); +} + +TEST(ClientCallbackVisitorUnitTest, ContinuationWrongStream) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, 0x0))); + ASSERT_TRUE(visitor.OnFrameHeader(1, 23, HEADERS, 0x0)); + + EXPECT_CALL(callbacks, + OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, ":status", "200", _)); + visitor.OnHeaderForStream(1, ":status", "200"); + + EXPECT_CALL(callbacks, OnHeader(_, "server", "my-fake-server", _)); + visitor.OnHeaderForStream(1, "server", "my-fake-server"); + + // The CONTINUATION stream ID does not match the one from the HEADERS. + EXPECT_FALSE(visitor.OnFrameHeader(3, 23, CONTINUATION, END_HEADERS_FLAG)); +} + +TEST(ClientCallbackVisitorUnitTest, ResetAndGoaway) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // RST_STREAM on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, RST_STREAM, 0x0))); + EXPECT_TRUE(visitor.OnFrameHeader(1, 13, RST_STREAM, 0x0)); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(1, NGHTTP2_INTERNAL_ERROR))); + visitor.OnRstStream(1, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_INTERNAL_ERROR)); + EXPECT_TRUE(visitor.OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, GOAWAY, 0x0))); + EXPECT_TRUE(visitor.OnFrameHeader(0, 13, GOAWAY, 0x0)); + + EXPECT_CALL(callbacks, + OnFrameRecv(IsGoAway(3, NGHTTP2_ENHANCE_YOUR_CALM, "calma te"))); + EXPECT_TRUE( + visitor.OnGoAway(3, Http2ErrorCode::ENHANCE_YOUR_CALM, "calma te")); + + EXPECT_CALL(callbacks, OnStreamClose(5, NGHTTP2_STREAM_CLOSED)) + .WillOnce(testing::Return(NGHTTP2_ERR_CALLBACK_FAILURE)); + EXPECT_FALSE(visitor.OnCloseStream(5, Http2ErrorCode::STREAM_CLOSED)); +} + +TEST(ServerCallbackVisitorUnitTest, ConnectionFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kServer, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // SETTINGS + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, SETTINGS, _))); + visitor.OnFrameHeader(0, 0, SETTINGS, 0); + + visitor.OnSettingsStart(); + EXPECT_CALL(callbacks, OnFrameRecv(IsSettings(testing::IsEmpty()))); + visitor.OnSettingsEnd(); + + // PING + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, PING, _))); + visitor.OnFrameHeader(0, 8, PING, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPing(42))); + visitor.OnPing(42, false); + + // WINDOW_UPDATE + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, WINDOW_UPDATE, _))); + visitor.OnFrameHeader(0, 4, WINDOW_UPDATE, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsWindowUpdate(1000))); + visitor.OnWindowUpdate(0, 1000); + + // PING ack + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(0, PING, NGHTTP2_FLAG_ACK))); + visitor.OnFrameHeader(0, 8, PING, 1); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPingAck(247))); + visitor.OnPing(247, true); + + EXPECT_EQ(visitor.stream_map_size(), 0); +} + +TEST(ServerCallbackVisitorUnitTest, StreamFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kServer, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader( + 1, HEADERS, NGHTTP2_FLAG_END_HEADERS))); + visitor.OnFrameHeader(1, 23, HEADERS, 4); + + EXPECT_CALL(callbacks, OnBeginHeaders(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_EQ(visitor.stream_map_size(), 1); + + EXPECT_CALL(callbacks, OnHeader(_, ":method", "POST", _)); + visitor.OnHeaderForStream(1, ":method", "POST"); + + EXPECT_CALL(callbacks, OnHeader(_, ":path", "/example/path", _)); + visitor.OnHeaderForStream(1, ":path", "/example/path"); + + EXPECT_CALL(callbacks, OnHeader(_, ":scheme", "https", _)); + visitor.OnHeaderForStream(1, ":scheme", "https"); + + EXPECT_CALL(callbacks, OnHeader(_, ":authority", "example.com", _)); + visitor.OnHeaderForStream(1, ":authority", "example.com"); + + EXPECT_CALL(callbacks, OnHeader(_, "accept", "text/html", _)); + visitor.OnHeaderForStream(1, "accept", "text/html"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + visitor.OnEndHeadersForStream(1); + + // DATA on stream 1 + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(1, DATA, NGHTTP2_FLAG_END_STREAM))); + visitor.OnFrameHeader(1, 25, DATA, NGHTTP2_FLAG_END_STREAM); + + visitor.OnBeginDataForStream(1, 25); + EXPECT_CALL(callbacks, OnDataChunkRecv(NGHTTP2_FLAG_END_STREAM, 1, + "This is the request body.")); + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, NGHTTP2_FLAG_END_STREAM))); + visitor.OnDataForStream(1, "This is the request body."); + EXPECT_TRUE(visitor.OnEndStream(1)); + + EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR); + + EXPECT_EQ(visitor.stream_map_size(), 0); + + // RST_STREAM on stream 3 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(3, RST_STREAM, 0))); + visitor.OnFrameHeader(3, 4, RST_STREAM, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(3, NGHTTP2_INTERNAL_ERROR))); + visitor.OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_CALL(callbacks, OnStreamClose(3, NGHTTP2_INTERNAL_ERROR)); + visitor.OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_EQ(visitor.stream_map_size(), 0); +} + +TEST(ServerCallbackVisitorUnitTest, DataWithPadding) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kServer, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + const size_t kPaddingLength = 39; + const uint8_t kFlags = NGHTTP2_FLAG_PADDED | NGHTTP2_FLAG_END_STREAM; + + testing::InSequence seq; + + // DATA on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, DATA, kFlags))); + EXPECT_TRUE(visitor.OnFrameHeader(1, 25 + kPaddingLength, DATA, kFlags)); + + EXPECT_TRUE(visitor.OnBeginDataForStream(1, 25 + kPaddingLength)); + + // Padding before data. + EXPECT_TRUE(visitor.OnDataPaddingLength(1, kPaddingLength)); + + EXPECT_CALL(callbacks, + OnDataChunkRecv(kFlags, 1, "This is the request body.")); + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, kFlags, kPaddingLength))); + EXPECT_TRUE(visitor.OnDataForStream(1, "This is the request body.")); + EXPECT_TRUE(visitor.OnEndStream(1)); + + EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR); + + // DATA on stream 3 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(3, DATA, kFlags))); + EXPECT_TRUE(visitor.OnFrameHeader(3, 25 + kPaddingLength, DATA, kFlags)); + + EXPECT_TRUE(visitor.OnBeginDataForStream(3, 25 + kPaddingLength)); + + // Data before padding. + EXPECT_CALL(callbacks, + OnDataChunkRecv(kFlags, 3, "This is the request body.")); + EXPECT_TRUE(visitor.OnDataForStream(3, "This is the request body.")); + + EXPECT_CALL(callbacks, OnFrameRecv(IsData(3, _, kFlags, kPaddingLength))); + EXPECT_TRUE(visitor.OnDataPaddingLength(3, kPaddingLength)); + EXPECT_TRUE(visitor.OnEndStream(3)); + + EXPECT_CALL(callbacks, OnStreamClose(3, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR); + + // DATA on stream 5 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(5, DATA, kFlags))); + EXPECT_TRUE(visitor.OnFrameHeader(5, 25 + kPaddingLength, DATA, kFlags)); + + EXPECT_TRUE(visitor.OnBeginDataForStream(5, 25 + kPaddingLength)); + + // Error during padding. + EXPECT_CALL(callbacks, + OnDataChunkRecv(kFlags, 5, "This is the request body.")); + EXPECT_TRUE(visitor.OnDataForStream(5, "This is the request body.")); + + EXPECT_CALL(callbacks, OnFrameRecv(IsData(5, _, kFlags, kPaddingLength))) + .WillOnce(testing::Return(NGHTTP2_ERR_CALLBACK_FAILURE)); + EXPECT_TRUE(visitor.OnDataPaddingLength(5, kPaddingLength)); + EXPECT_FALSE(visitor.OnEndStream(3)); + + EXPECT_CALL(callbacks, OnStreamClose(5, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(5, Http2ErrorCode::HTTP2_NO_ERROR); +} + +// In the case of a Content-Length mismatch where the header value is larger +// than the actual data for the stream, nghttp2 will call +// `on_begin_frame_callback` and `on_data_chunk_recv_callback`, but not the +// `on_frame_recv_callback`. +TEST(ServerCallbackVisitorUnitTest, MismatchedContentLengthCallbacks) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kServer, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"content-length", "50"}}, + /*fin=*/false) + .Data(1, "Less than 50 bytes.", true) + .Serialize(); + + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, SETTINGS, _))); + + EXPECT_CALL(callbacks, OnFrameRecv(IsSettings(testing::IsEmpty()))); + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader( + 1, HEADERS, NGHTTP2_FLAG_END_HEADERS))); + + EXPECT_CALL(callbacks, OnBeginHeaders(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + + EXPECT_CALL(callbacks, OnHeader(_, ":method", "POST", _)); + EXPECT_CALL(callbacks, OnHeader(_, ":path", "/", _)); + EXPECT_CALL(callbacks, OnHeader(_, ":scheme", "https", _)); + EXPECT_CALL(callbacks, OnHeader(_, ":authority", "example.com", _)); + EXPECT_CALL(callbacks, OnHeader(_, "content-length", "50", _)); + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + + // DATA on stream 1 + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(1, DATA, NGHTTP2_FLAG_END_STREAM))); + + EXPECT_CALL(callbacks, OnDataChunkRecv(NGHTTP2_FLAG_END_STREAM, 1, + "Less than 50 bytes.")); + + // Like nghttp2, CallbackVisitor does not pass on a call to OnFrameRecv in the + // case of Content-Length mismatch. + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/data_source.h b/quiche/http2/adapter/data_source.h new file mode 100644 index 000000000000..ffd78a0eef86 --- /dev/null +++ b/quiche/http2/adapter/data_source.h @@ -0,0 +1,60 @@ +#ifndef QUICHE_HTTP2_ADAPTER_DATA_SOURCE_H_ +#define QUICHE_HTTP2_ADAPTER_DATA_SOURCE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// Represents a source of DATA frames for transmission to the peer. +class QUICHE_EXPORT DataFrameSource { + public: + virtual ~DataFrameSource() {} + + enum : int64_t { kBlocked = 0, kError = -1 }; + + // Returns the number of bytes to send in the next DATA frame, and whether + // this frame indicates the end of the data. Returns {kBlocked, false} if + // blocked, {kError, false} on error. + virtual std::pair SelectPayloadLength(size_t max_length) = 0; + + // This method is called with a frame header and a payload length to send. The + // source should send or buffer the entire frame and return true, or return + // false without sending or buffering anything. + virtual bool Send(absl::string_view frame_header, size_t payload_length) = 0; + + // If true, the end of this data source indicates the end of the stream. + // Otherwise, this data will be followed by trailers. + virtual bool send_fin() const = 0; +}; + +// Represents a source of metadata frames for transmission to the peer. +class QUICHE_EXPORT MetadataSource { + public: + virtual ~MetadataSource() {} + + // Returns the number of frames of at most |max_frame_size| required to + // serialize the metadata for this source. Only required by the nghttp2 + // implementation. + virtual size_t NumFrames(size_t max_frame_size) const = 0; + + // This method is called with a destination buffer and length. It should + // return the number of payload bytes copied to |dest|, or a negative integer + // to indicate an error, as well as a boolean indicating whether the metadata + // has been completely copied. + virtual std::pair Pack(uint8_t* dest, size_t dest_len) = 0; + + // This method is called when transmission of the metadata for this source + // fails in a non-recoverable way. + virtual void OnFailure() = 0; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_DATA_SOURCE_H_ diff --git a/quiche/http2/adapter/event_forwarder.cc b/quiche/http2/adapter/event_forwarder.cc new file mode 100644 index 000000000000..62a57c6f49b0 --- /dev/null +++ b/quiche/http2/adapter/event_forwarder.cc @@ -0,0 +1,198 @@ +#include "quiche/http2/adapter/event_forwarder.h" + +namespace http2 { +namespace adapter { + +EventForwarder::EventForwarder(ForwardPredicate can_forward, + spdy::SpdyFramerVisitorInterface& receiver) + : can_forward_(std::move(can_forward)), receiver_(receiver) {} + +void EventForwarder::OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) { + if (can_forward_()) { + receiver_.OnError(error, std::move(detailed_error)); + } +} + +void EventForwarder::OnCommonHeader(spdy::SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + if (can_forward_()) { + receiver_.OnCommonHeader(stream_id, length, type, flags); + } +} + +void EventForwarder::OnDataFrameHeader(spdy::SpdyStreamId stream_id, + size_t length, bool fin) { + if (can_forward_()) { + receiver_.OnDataFrameHeader(stream_id, length, fin); + } +} + +void EventForwarder::OnStreamFrameData(spdy::SpdyStreamId stream_id, + const char* data, size_t len) { + if (can_forward_()) { + receiver_.OnStreamFrameData(stream_id, data, len); + } +} + +void EventForwarder::OnStreamEnd(spdy::SpdyStreamId stream_id) { + if (can_forward_()) { + receiver_.OnStreamEnd(stream_id); + } +} + +void EventForwarder::OnStreamPadLength(spdy::SpdyStreamId stream_id, + size_t value) { + if (can_forward_()) { + receiver_.OnStreamPadLength(stream_id, value); + } +} + +void EventForwarder::OnStreamPadding(spdy::SpdyStreamId stream_id, size_t len) { + if (can_forward_()) { + receiver_.OnStreamPadding(stream_id, len); + } +} + +spdy::SpdyHeadersHandlerInterface* EventForwarder::OnHeaderFrameStart( + spdy::SpdyStreamId stream_id) { + return receiver_.OnHeaderFrameStart(stream_id); +} + +void EventForwarder::OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) { + if (can_forward_()) { + receiver_.OnHeaderFrameEnd(stream_id); + } +} + +void EventForwarder::OnRstStream(spdy::SpdyStreamId stream_id, + spdy::SpdyErrorCode error_code) { + if (can_forward_()) { + receiver_.OnRstStream(stream_id, error_code); + } +} + +void EventForwarder::OnSettings() { + if (can_forward_()) { + receiver_.OnSettings(); + } +} + +void EventForwarder::OnSetting(spdy::SpdySettingsId id, uint32_t value) { + if (can_forward_()) { + receiver_.OnSetting(id, value); + } +} + +void EventForwarder::OnSettingsEnd() { + if (can_forward_()) { + receiver_.OnSettingsEnd(); + } +} + +void EventForwarder::OnSettingsAck() { + if (can_forward_()) { + receiver_.OnSettingsAck(); + } +} + +void EventForwarder::OnPing(spdy::SpdyPingId unique_id, bool is_ack) { + if (can_forward_()) { + receiver_.OnPing(unique_id, is_ack); + } +} + +void EventForwarder::OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, + spdy::SpdyErrorCode error_code) { + if (can_forward_()) { + receiver_.OnGoAway(last_accepted_stream_id, error_code); + } +} + +bool EventForwarder::OnGoAwayFrameData(const char* goaway_data, size_t len) { + if (can_forward_()) { + return receiver_.OnGoAwayFrameData(goaway_data, len); + } + return false; +} + +void EventForwarder::OnHeaders(spdy::SpdyStreamId stream_id, + size_t payload_length, bool has_priority, + int weight, spdy::SpdyStreamId parent_stream_id, + bool exclusive, bool fin, bool end) { + if (can_forward_()) { + receiver_.OnHeaders(stream_id, payload_length, has_priority, weight, + parent_stream_id, exclusive, fin, end); + } +} + +void EventForwarder::OnWindowUpdate(spdy::SpdyStreamId stream_id, + int delta_window_size) { + if (can_forward_()) { + receiver_.OnWindowUpdate(stream_id, delta_window_size); + } +} + +void EventForwarder::OnPushPromise(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId promised_stream_id, + bool end) { + if (can_forward_()) { + receiver_.OnPushPromise(stream_id, promised_stream_id, end); + } +} + +void EventForwarder::OnContinuation(spdy::SpdyStreamId stream_id, + size_t payload_length, bool end) { + if (can_forward_()) { + receiver_.OnContinuation(stream_id, payload_length, end); + } +} + +void EventForwarder::OnAltSvc( + spdy::SpdyStreamId stream_id, absl::string_view origin, + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector) { + if (can_forward_()) { + receiver_.OnAltSvc(stream_id, origin, altsvc_vector); + } +} + +void EventForwarder::OnPriority(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId parent_stream_id, int weight, + bool exclusive) { + if (can_forward_()) { + receiver_.OnPriority(stream_id, parent_stream_id, weight, exclusive); + } +} + +void EventForwarder::OnPriorityUpdate(spdy::SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) { + if (can_forward_()) { + receiver_.OnPriorityUpdate(prioritized_stream_id, priority_field_value); + } +} + +bool EventForwarder::OnUnknownFrame(spdy::SpdyStreamId stream_id, + uint8_t frame_type) { + if (can_forward_()) { + return receiver_.OnUnknownFrame(stream_id, frame_type); + } + return false; +} + +void EventForwarder::OnUnknownFrameStart(spdy::SpdyStreamId stream_id, + size_t length, uint8_t type, + uint8_t flags) { + if (can_forward_()) { + receiver_.OnUnknownFrameStart(stream_id, length, type, flags); + } +} + +void EventForwarder::OnUnknownFramePayload(spdy::SpdyStreamId stream_id, + absl::string_view payload) { + if (can_forward_()) { + receiver_.OnUnknownFramePayload(stream_id, payload); + } +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/event_forwarder.h b/quiche/http2/adapter/event_forwarder.h new file mode 100644 index 000000000000..74140f97e6c4 --- /dev/null +++ b/quiche/http2/adapter/event_forwarder.h @@ -0,0 +1,81 @@ +#ifndef QUICHE_HTTP2_ADAPTER_EVENT_FORWARDER_H_ +#define QUICHE_HTTP2_ADAPTER_EVENT_FORWARDER_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" + +namespace http2 { +namespace adapter { + +// Forwards events to a provided SpdyFramerVisitorInterface receiver if the +// provided predicate succeeds. Currently, OnHeaderFrameStart() is always +// forwarded regardless of the predicate. +// TODO(diannahu): Add a NoOpHeadersHandler if needed. +class QUICHE_EXPORT EventForwarder : public spdy::SpdyFramerVisitorInterface { + public: + // Whether the forwarder can forward events to the receiver. + using ForwardPredicate = std::function; + + EventForwarder(ForwardPredicate can_forward, + spdy::SpdyFramerVisitorInterface& receiver); + + void OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) override; + void OnCommonHeader(spdy::SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnDataFrameHeader(spdy::SpdyStreamId stream_id, size_t length, + bool fin) override; + void OnStreamFrameData(spdy::SpdyStreamId stream_id, const char* data, + size_t len) override; + void OnStreamEnd(spdy::SpdyStreamId stream_id) override; + void OnStreamPadLength(spdy::SpdyStreamId stream_id, size_t value) override; + void OnStreamPadding(spdy::SpdyStreamId stream_id, size_t len) override; + spdy::SpdyHeadersHandlerInterface* OnHeaderFrameStart( + spdy::SpdyStreamId stream_id) override; + void OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) override; + void OnRstStream(spdy::SpdyStreamId stream_id, + spdy::SpdyErrorCode error_code) override; + void OnSettings() override; + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + void OnPing(spdy::SpdyPingId unique_id, bool is_ack) override; + void OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, + spdy::SpdyErrorCode error_code) override; + bool OnGoAwayFrameData(const char* goaway_data, size_t len) override; + void OnHeaders(spdy::SpdyStreamId stream_id, size_t payload_length, + bool has_priority, int weight, + spdy::SpdyStreamId parent_stream_id, bool exclusive, bool fin, + bool end) override; + void OnWindowUpdate(spdy::SpdyStreamId stream_id, + int delta_window_size) override; + void OnPushPromise(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId promised_stream_id, bool end) override; + void OnContinuation(spdy::SpdyStreamId stream_id, size_t payload_length, + bool end) override; + void OnAltSvc(spdy::SpdyStreamId stream_id, absl::string_view origin, + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector& + altsvc_vector) override; + void OnPriority(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId parent_stream_id, int weight, + bool exclusive) override; + void OnPriorityUpdate(spdy::SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) override; + bool OnUnknownFrame(spdy::SpdyStreamId stream_id, + uint8_t frame_type) override; + void OnUnknownFrameStart(spdy::SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) override; + void OnUnknownFramePayload(spdy::SpdyStreamId stream_id, + absl::string_view payload) override; + + private: + ForwardPredicate can_forward_; + spdy::SpdyFramerVisitorInterface& receiver_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_EVENT_FORWARDER_H_ diff --git a/quiche/http2/adapter/event_forwarder_test.cc b/quiche/http2/adapter/event_forwarder_test.cc new file mode 100644 index 000000000000..5e89504b5124 --- /dev/null +++ b/quiche/http2/adapter/event_forwarder_test.cc @@ -0,0 +1,235 @@ +#include "quiche/http2/adapter/event_forwarder.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/test_tools/mock_spdy_framer_visitor.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +constexpr absl::string_view some_data = "Here is some data for events"; +constexpr spdy::SpdyStreamId stream_id = 1; +constexpr spdy::SpdyErrorCode error_code = + spdy::SpdyErrorCode::ERROR_CODE_ENHANCE_YOUR_CALM; +constexpr size_t length = 42; + +TEST(EventForwarderTest, ForwardsEventsWithTruePredicate) { + spdy::test::MockSpdyFramerVisitor receiver; + receiver.DelegateHeaderHandling(); + EventForwarder event_forwarder([]() { return true; }, receiver); + + EXPECT_CALL( + receiver, + OnError(Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + std::string(some_data))); + event_forwarder.OnError( + Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + std::string(some_data)); + + EXPECT_CALL(receiver, + OnCommonHeader(stream_id, length, /*type=*/0x0, END_STREAM_FLAG)); + event_forwarder.OnCommonHeader(stream_id, length, /*type=*/0x0, + END_STREAM_FLAG); + + EXPECT_CALL(receiver, OnDataFrameHeader(stream_id, length, /*fin=*/true)); + event_forwarder.OnDataFrameHeader(stream_id, length, /*fin=*/true); + + EXPECT_CALL(receiver, + OnStreamFrameData(stream_id, some_data.data(), some_data.size())); + event_forwarder.OnStreamFrameData(stream_id, some_data.data(), + some_data.size()); + + EXPECT_CALL(receiver, OnStreamEnd(stream_id)); + event_forwarder.OnStreamEnd(stream_id); + + EXPECT_CALL(receiver, OnStreamPadLength(stream_id, length)); + event_forwarder.OnStreamPadLength(stream_id, length); + + EXPECT_CALL(receiver, OnStreamPadding(stream_id, length)); + event_forwarder.OnStreamPadding(stream_id, length); + + EXPECT_CALL(receiver, OnHeaderFrameStart(stream_id)); + spdy::SpdyHeadersHandlerInterface* handler = + event_forwarder.OnHeaderFrameStart(stream_id); + EXPECT_EQ(handler, receiver.ReturnTestHeadersHandler(stream_id)); + + EXPECT_CALL(receiver, OnHeaderFrameEnd(stream_id)); + event_forwarder.OnHeaderFrameEnd(stream_id); + + EXPECT_CALL(receiver, OnRstStream(stream_id, error_code)); + event_forwarder.OnRstStream(stream_id, error_code); + + EXPECT_CALL(receiver, OnSettings()); + event_forwarder.OnSettings(); + + EXPECT_CALL( + receiver, + OnSetting(spdy::SpdyKnownSettingsId::SETTINGS_MAX_CONCURRENT_STREAMS, + 100)); + event_forwarder.OnSetting( + spdy::SpdyKnownSettingsId::SETTINGS_MAX_CONCURRENT_STREAMS, 100); + + EXPECT_CALL(receiver, OnSettingsEnd()); + event_forwarder.OnSettingsEnd(); + + EXPECT_CALL(receiver, OnSettingsAck()); + event_forwarder.OnSettingsAck(); + + EXPECT_CALL(receiver, OnPing(/*unique_id=*/42, /*is_ack=*/false)); + event_forwarder.OnPing(/*unique_id=*/42, /*is_ack=*/false); + + EXPECT_CALL(receiver, OnGoAway(stream_id, error_code)); + event_forwarder.OnGoAway(stream_id, error_code); + + EXPECT_CALL(receiver, OnGoAwayFrameData(some_data.data(), some_data.size())); + event_forwarder.OnGoAwayFrameData(some_data.data(), some_data.size()); + + EXPECT_CALL(receiver, + OnHeaders(stream_id, /*payload_length=*/1234, + /*has_priority=*/false, /*weight=*/42, stream_id + 2, + /*exclusive=*/false, /*fin=*/true, /*end=*/true)); + event_forwarder.OnHeaders(stream_id, /*payload_length=*/1234, + /*has_priority=*/false, /*weight=*/42, + stream_id + 2, /*exclusive=*/false, /*fin=*/true, + /*end=*/true); + + EXPECT_CALL(receiver, OnWindowUpdate(stream_id, /*delta_window_size=*/42)); + event_forwarder.OnWindowUpdate(stream_id, /*delta_window_size=*/42); + + EXPECT_CALL(receiver, OnPushPromise(stream_id, stream_id + 1, /*end=*/true)); + event_forwarder.OnPushPromise(stream_id, stream_id + 1, /*end=*/true); + + EXPECT_CALL(receiver, + OnContinuation(stream_id, /*payload_length=*/42, /*end=*/true)); + event_forwarder.OnContinuation(stream_id, /*payload_length=*/42, + /*end=*/true); + + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + EXPECT_CALL(receiver, OnAltSvc(stream_id, some_data, altsvc_vector)); + event_forwarder.OnAltSvc(stream_id, some_data, altsvc_vector); + + EXPECT_CALL(receiver, OnPriority(stream_id, stream_id + 2, /*weight=*/42, + /*exclusive=*/false)); + event_forwarder.OnPriority(stream_id, stream_id + 2, /*weight=*/42, + /*exclusive=*/false); + + EXPECT_CALL(receiver, OnPriorityUpdate(stream_id, some_data)); + event_forwarder.OnPriorityUpdate(stream_id, some_data); + + EXPECT_CALL(receiver, OnUnknownFrame(stream_id, /*frame_type=*/0x4D)); + event_forwarder.OnUnknownFrame(stream_id, /*frame_type=*/0x4D); + + EXPECT_CALL(receiver, OnUnknownFrameStart(stream_id, /*length=*/42, + /*type=*/0x4D, /*flags=*/0x0)); + event_forwarder.OnUnknownFrameStart(stream_id, /*length=*/42, /*type=*/0x4D, + /*flags=*/0x0); +} + +TEST(EventForwarderTest, DoesNotForwardEventsWithFalsePredicate) { + spdy::test::MockSpdyFramerVisitor receiver; + receiver.DelegateHeaderHandling(); + EventForwarder event_forwarder([]() { return false; }, receiver); + + EXPECT_CALL(receiver, OnError).Times(0); + event_forwarder.OnError( + Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + std::string(some_data)); + + EXPECT_CALL(receiver, OnCommonHeader).Times(0); + event_forwarder.OnCommonHeader(stream_id, length, /*type=*/0x0, + END_STREAM_FLAG); + + EXPECT_CALL(receiver, OnDataFrameHeader).Times(0); + event_forwarder.OnDataFrameHeader(stream_id, length, /*fin=*/true); + + EXPECT_CALL(receiver, OnStreamFrameData).Times(0); + event_forwarder.OnStreamFrameData(stream_id, some_data.data(), + some_data.size()); + + EXPECT_CALL(receiver, OnStreamEnd).Times(0); + event_forwarder.OnStreamEnd(stream_id); + + EXPECT_CALL(receiver, OnStreamPadLength).Times(0); + event_forwarder.OnStreamPadLength(stream_id, length); + + EXPECT_CALL(receiver, OnStreamPadding).Times(0); + event_forwarder.OnStreamPadding(stream_id, length); + + EXPECT_CALL(receiver, OnHeaderFrameStart(stream_id)); + spdy::SpdyHeadersHandlerInterface* handler = + event_forwarder.OnHeaderFrameStart(stream_id); + EXPECT_EQ(handler, receiver.ReturnTestHeadersHandler(stream_id)); + + EXPECT_CALL(receiver, OnHeaderFrameEnd).Times(0); + event_forwarder.OnHeaderFrameEnd(stream_id); + + EXPECT_CALL(receiver, OnRstStream).Times(0); + event_forwarder.OnRstStream(stream_id, error_code); + + EXPECT_CALL(receiver, OnSettings).Times(0); + event_forwarder.OnSettings(); + + EXPECT_CALL(receiver, OnSetting).Times(0); + event_forwarder.OnSetting( + spdy::SpdyKnownSettingsId::SETTINGS_MAX_CONCURRENT_STREAMS, 100); + + EXPECT_CALL(receiver, OnSettingsEnd).Times(0); + event_forwarder.OnSettingsEnd(); + + EXPECT_CALL(receiver, OnSettingsAck).Times(0); + event_forwarder.OnSettingsAck(); + + EXPECT_CALL(receiver, OnPing).Times(0); + event_forwarder.OnPing(/*unique_id=*/42, /*is_ack=*/false); + + EXPECT_CALL(receiver, OnGoAway).Times(0); + event_forwarder.OnGoAway(stream_id, error_code); + + EXPECT_CALL(receiver, OnGoAwayFrameData).Times(0); + event_forwarder.OnGoAwayFrameData(some_data.data(), some_data.size()); + + EXPECT_CALL(receiver, OnHeaders).Times(0); + event_forwarder.OnHeaders(stream_id, /*payload_length=*/1234, + /*has_priority=*/false, /*weight=*/42, + stream_id + 2, /*exclusive=*/false, /*fin=*/true, + /*end=*/true); + + EXPECT_CALL(receiver, OnWindowUpdate).Times(0); + event_forwarder.OnWindowUpdate(stream_id, /*delta_window_size=*/42); + + EXPECT_CALL(receiver, OnPushPromise).Times(0); + event_forwarder.OnPushPromise(stream_id, stream_id + 1, /*end=*/true); + + EXPECT_CALL(receiver, OnContinuation).Times(0); + event_forwarder.OnContinuation(stream_id, /*payload_length=*/42, + /*end=*/true); + + EXPECT_CALL(receiver, OnAltSvc).Times(0); + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + event_forwarder.OnAltSvc(stream_id, some_data, altsvc_vector); + + EXPECT_CALL(receiver, OnPriority).Times(0); + event_forwarder.OnPriority(stream_id, stream_id + 2, /*weight=*/42, + /*exclusive=*/false); + + EXPECT_CALL(receiver, OnPriorityUpdate).Times(0); + event_forwarder.OnPriorityUpdate(stream_id, some_data); + + EXPECT_CALL(receiver, OnUnknownFrame).Times(0); + event_forwarder.OnUnknownFrame(stream_id, /*frame_type=*/0x4D); + + EXPECT_CALL(receiver, OnUnknownFrameStart).Times(0); + event_forwarder.OnUnknownFrameStart(stream_id, /*length=*/42, /*type=*/0x4D, + /*flags=*/0x0); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/header_validator.cc b/quiche/http2/adapter/header_validator.cc new file mode 100644 index 000000000000..e474558fe061 --- /dev/null +++ b/quiche/http2/adapter/header_validator.cc @@ -0,0 +1,298 @@ +#include "quiche/http2/adapter/header_validator.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace adapter { + +namespace { + +const absl::string_view kHttp2HeaderNameAllowedChars = + "!#$%&\'*+-.0123456789" + "^_`abcdefghijklmnopqrstuvwxyz|~"; + +const absl::string_view kHttp2HeaderValueAllowedChars = + "\t " + "!\"#$%&'()*+,-./" + "0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`" + "abcdefghijklmnopqrstuvwxyz{|}~"; + +const absl::string_view kHttp2StatusValueAllowedChars = "0123456789"; + +const absl::string_view kValidAuthorityChars = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~%!$&'()[" + "]*+,;=:"; + +using CharMap = std::array; + +CharMap BuildValidCharMap(absl::string_view valid_chars) { + CharMap map; + map.fill(false); + for (char c : valid_chars) { + // Cast to uint8_t, guaranteed to have 8 bits. A char may have more, leading + // to possible indices above 256. + map[static_cast(c)] = true; + } + return map; +} +CharMap AllowObsText(CharMap map) { + // Characters above 0x80 are allowed in header field values as `obs-text` in + // RFC 7230. + for (uint8_t c = 0xff; c >= 0x80; --c) { + map[c] = true; + } + return map; +} + +bool AllCharsInMap(absl::string_view str, const CharMap& map) { + for (char c : str) { + if (!map[static_cast(c)]) { + return false; + } + } + return true; +} + +bool IsValidHeaderName(absl::string_view name) { + static const CharMap valid_chars = + BuildValidCharMap(kHttp2HeaderNameAllowedChars); + return AllCharsInMap(name, valid_chars); +} + +bool IsValidStatus(absl::string_view status) { + static const CharMap valid_chars = + BuildValidCharMap(kHttp2StatusValueAllowedChars); + return AllCharsInMap(status, valid_chars); +} + +bool ValidateRequestHeaders(const std::vector& pseudo_headers, + absl::optional& authority, + absl::string_view method, absl::string_view path, + bool allow_extended_connect) { + QUICHE_VLOG(2) << "Request pseudo-headers: [" + << absl::StrJoin(pseudo_headers, ", ") + << "], allow_extended_connect: " << allow_extended_connect + << ", authority: " + << (authority ? authority.value() : "") + << ", method: " << method << ", path: " << path; + if (method == "CONNECT") { + if (allow_extended_connect) { + // See RFC 8441. + static const std::vector* kExtendedConnectHeaders = + new std::vector( + {":authority", ":method", ":path", ":protocol", ":scheme"}); + if (pseudo_headers == *kExtendedConnectHeaders) { + return true; + } + } + // See RFC 7540 Section 8.3. + static const std::vector* kConnectHeaders = + new std::vector({":authority", ":method"}); + return authority.has_value() && !authority.value().empty() && + pseudo_headers == *kConnectHeaders; + } + + if (path.empty()) { + return false; + } + if (path == "*") { + if (method != "OPTIONS") { + return false; + } + } else if (path[0] != '/') { + return false; + } + + static const std::vector* kRequiredHeaders = + new std::vector( + {":authority", ":method", ":path", ":scheme"}); + return pseudo_headers == *kRequiredHeaders; +} + +bool ValidateRequestTrailers(const std::vector& pseudo_headers) { + return pseudo_headers.empty(); +} + +bool ValidateResponseHeaders(const std::vector& pseudo_headers) { + static const std::vector* kRequiredHeaders = + new std::vector({":status"}); + return pseudo_headers == *kRequiredHeaders; +} + +bool ValidateResponseTrailers(const std::vector& pseudo_headers) { + return pseudo_headers.empty(); +} + +} // namespace + +void HeaderValidator::StartHeaderBlock() { + HeaderValidatorBase::StartHeaderBlock(); + pseudo_headers_.clear(); + method_.clear(); + path_.clear(); + authority_ = absl::nullopt; +} + +HeaderValidator::HeaderStatus HeaderValidator::ValidateSingleHeader( + absl::string_view key, absl::string_view value) { + if (key.empty()) { + return HEADER_FIELD_INVALID; + } + if (max_field_size_.has_value() && + key.size() + value.size() > max_field_size_.value()) { + QUICHE_VLOG(2) << "Header field size is " << key.size() + value.size() + << ", exceeds max size of " << max_field_size_.value(); + return HEADER_FIELD_TOO_LONG; + } + const absl::string_view validated_key = key[0] == ':' ? key.substr(1) : key; + if (!IsValidHeaderName(validated_key)) { + QUICHE_VLOG(2) << "invalid chars in header name: [" + << absl::CEscape(validated_key) << "]"; + return HEADER_FIELD_INVALID; + } + if (!IsValidHeaderValue(value, obs_text_option_)) { + QUICHE_VLOG(2) << "invalid chars in header value: [" << absl::CEscape(value) + << "]"; + return HEADER_FIELD_INVALID; + } + if (key[0] == ':') { + if (key == ":status") { + if (value.size() != 3 || !IsValidStatus(value)) { + QUICHE_VLOG(2) << "malformed status value: [" << absl::CEscape(value) + << "]"; + return HEADER_FIELD_INVALID; + } + if (value == "101") { + // Switching protocols is not allowed on a HTTP/2 stream. + return HEADER_FIELD_INVALID; + } + status_ = std::string(value); + } else if (key == ":method") { + method_ = std::string(value); + } else if (key == ":authority" && !ValidateAndSetAuthority(value)) { + return HEADER_FIELD_INVALID; + } else if (key == ":path") { + if (value.empty()) { + // For now, reject an empty path regardless of scheme. + return HEADER_FIELD_INVALID; + } + path_ = std::string(value); + } + pseudo_headers_.push_back(std::string(key)); + } else if (key == "host") { + if (!status_.empty()) { + // Response headers can contain "Host". + } else { + if (!authority_.has_value()) { + pseudo_headers_.push_back(std::string(":authority")); + } + if (!ValidateAndSetAuthority(value)) { + return HEADER_FIELD_INVALID; + } + } + } else if (key == "content-length") { + const ContentLengthStatus status = HandleContentLength(value); + switch (status) { + case CONTENT_LENGTH_ERROR: + return HEADER_FIELD_INVALID; + case CONTENT_LENGTH_SKIP: + return HEADER_SKIP; + case CONTENT_LENGTH_OK: + return HEADER_OK; + default: + return HEADER_FIELD_INVALID; + } + } else if (key == "te" && value != "trailers") { + return HEADER_FIELD_INVALID; + } else if (key == "upgrade" || GetInvalidHttp2HeaderSet().contains(key)) { + // TODO(b/78024822): Remove the "upgrade" here once it's added to + // GetInvalidHttp2HeaderSet(). + return HEADER_FIELD_INVALID; + } + return HEADER_OK; +} + +// Returns true if all required pseudoheaders and no extra pseudoheaders are +// present for the given header type. +bool HeaderValidator::FinishHeaderBlock(HeaderType type) { + std::sort(pseudo_headers_.begin(), pseudo_headers_.end()); + switch (type) { + case HeaderType::REQUEST: + return ValidateRequestHeaders(pseudo_headers_, authority_, method_, path_, + allow_extended_connect_); + case HeaderType::REQUEST_TRAILER: + return ValidateRequestTrailers(pseudo_headers_); + case HeaderType::RESPONSE_100: + case HeaderType::RESPONSE: + return ValidateResponseHeaders(pseudo_headers_); + case HeaderType::RESPONSE_TRAILER: + return ValidateResponseTrailers(pseudo_headers_); + } + return false; +} + +bool HeaderValidator::IsValidHeaderValue(absl::string_view value, + ObsTextOption option) { + static const CharMap valid_chars = + BuildValidCharMap(kHttp2HeaderValueAllowedChars); + static const CharMap valid_chars_with_obs_text = + AllowObsText(BuildValidCharMap(kHttp2HeaderValueAllowedChars)); + return AllCharsInMap(value, option == ObsTextOption::kAllow + ? valid_chars_with_obs_text + : valid_chars); +} + +bool HeaderValidator::IsValidAuthority(absl::string_view authority) { + static const CharMap valid_chars = BuildValidCharMap(kValidAuthorityChars); + return AllCharsInMap(authority, valid_chars); +} + +HeaderValidator::ContentLengthStatus HeaderValidator::HandleContentLength( + absl::string_view value) { + if (value.empty()) { + return CONTENT_LENGTH_ERROR; + } + + if (status_ == "204" && value != "0") { + // There should be no body in a "204 No Content" response. + return CONTENT_LENGTH_ERROR; + } + if (!status_.empty() && status_[0] == '1' && value != "0") { + // There should also be no body in a 1xx response. + return CONTENT_LENGTH_ERROR; + } + + size_t content_length = 0; + const bool valid = absl::SimpleAtoi(value, &content_length); + if (!valid) { + return CONTENT_LENGTH_ERROR; + } + + if (content_length_.has_value()) { + return content_length == content_length_.value() ? CONTENT_LENGTH_SKIP + : CONTENT_LENGTH_ERROR; + } + content_length_ = content_length; + return CONTENT_LENGTH_OK; +} + +// Returns whether `authority` contains only characters from the `host` ABNF +// from RFC 3986 section 3.2.2. +bool HeaderValidator::ValidateAndSetAuthority(absl::string_view authority) { + if (!IsValidAuthority(authority)) { + return false; + } + if (authority_.has_value() && authority != authority_.value()) { + return false; + } + authority_ = std::string(authority); + return true; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/header_validator.h b/quiche/http2/adapter/header_validator.h new file mode 100644 index 000000000000..a6070db77062 --- /dev/null +++ b/quiche/http2/adapter/header_validator.h @@ -0,0 +1,54 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_H_ +#define QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/http2/adapter/header_validator_base.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +class QUICHE_EXPORT HeaderValidator : public HeaderValidatorBase { + public: + HeaderValidator() = default; + + void StartHeaderBlock() override; + + HeaderStatus ValidateSingleHeader(absl::string_view key, + absl::string_view value) override; + + // Returns true if all required pseudoheaders and no extra pseudoheaders are + // present for the given header type. + bool FinishHeaderBlock(HeaderType type) override; + + // Returns whether `value` is valid according to RFC 9110 Section 5.5 and RFC + // 9112 Section 8.2.1. + static bool IsValidHeaderValue(absl::string_view value, + ObsTextOption ops_text_option); + + // Returns whether `authority` is valid according to RFC 3986 Section 3.2. + static bool IsValidAuthority(absl::string_view authority); + + private: + enum ContentLengthStatus { + CONTENT_LENGTH_OK, + CONTENT_LENGTH_SKIP, // Used to handle duplicate content length values. + CONTENT_LENGTH_ERROR, + }; + ContentLengthStatus HandleContentLength(absl::string_view value); + bool ValidateAndSetAuthority(absl::string_view authority); + + std::vector pseudo_headers_; + absl::optional authority_ = absl::nullopt; + std::string method_; + std::string path_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_H_ diff --git a/quiche/http2/adapter/header_validator_base.h b/quiche/http2/adapter/header_validator_base.h new file mode 100644 index 000000000000..2b25afae3236 --- /dev/null +++ b/quiche/http2/adapter/header_validator_base.h @@ -0,0 +1,70 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_BASE_H_ +#define QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_BASE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +enum class HeaderType : uint8_t { + REQUEST, + REQUEST_TRAILER, + RESPONSE_100, + RESPONSE, + RESPONSE_TRAILER, +}; + +enum class ObsTextOption : uint8_t { + kAllow, + kDisallow, +}; + +class QUICHE_EXPORT HeaderValidatorBase { + public: + HeaderValidatorBase() = default; + virtual ~HeaderValidatorBase() = default; + + virtual void StartHeaderBlock() { + status_.clear(); + content_length_ = absl::nullopt; + } + + enum HeaderStatus { + HEADER_OK, + HEADER_SKIP, + HEADER_FIELD_INVALID, + HEADER_FIELD_TOO_LONG, + }; + virtual HeaderStatus ValidateSingleHeader(absl::string_view key, + absl::string_view value) = 0; + + // Should return true if validation was successful. + virtual bool FinishHeaderBlock(HeaderType type) = 0; + + // For responses, returns the value of the ":status" header, if present. + absl::string_view status_header() const { return status_; } + + absl::optional content_length() const { return content_length_; } + + void SetMaxFieldSize(uint32_t field_size) { max_field_size_ = field_size; } + void SetObsTextOption(ObsTextOption option) { obs_text_option_ = option; } + // Allows the "extended CONNECT" syntax described in RFC 8441. + void SetAllowExtendedConnect() { allow_extended_connect_ = true; } + + protected: + std::string status_; + absl::optional max_field_size_; + absl::optional content_length_; + ObsTextOption obs_text_option_ = ObsTextOption::kDisallow; + bool allow_extended_connect_ = false; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_BASE_H_ diff --git a/quiche/http2/adapter/header_validator_test.cc b/quiche/http2/adapter/header_validator_test.cc new file mode 100644 index 000000000000..3dfbd96f6c1b --- /dev/null +++ b/quiche/http2/adapter/header_validator_test.cc @@ -0,0 +1,676 @@ +#include "quiche/http2/adapter/header_validator.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +using ::testing::Optional; + +using Header = std::pair; +constexpr Header kSampleRequestPseudoheaders[] = {{":authority", "www.foo.com"}, + {":method", "GET"}, + {":path", "/foo"}, + {":scheme", "https"}}; + +TEST(HeaderValidatorTest, HeaderNameEmpty) { + HeaderValidator v; + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader("", "value"); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); +} + +TEST(HeaderValidatorTest, HeaderValueEmpty) { + HeaderValidator v; + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader("name", ""); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); +} + +TEST(HeaderValidatorTest, ExceedsMaxSize) { + HeaderValidator v; + v.SetMaxFieldSize(64u); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", "value"); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + status = v.ValidateSingleHeader( + "name2", + "Antidisestablishmentariansism is supercalifragilisticexpialodocious."); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_TOO_LONG, status); +} + +TEST(HeaderValidatorTest, NameHasInvalidChar) { + HeaderValidator v; + for (const bool is_pseudo_header : {true, false}) { + // These characters should be allowed. (Not exhaustive.) + for (const char* c : {"!", "3", "a", "_", "|", "~"}) { + const std::string name = is_pseudo_header ? absl::StrCat(":met", c, "hod") + : absl::StrCat("na", c, "me"); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + // These should not. (Not exhaustive.) + for (const char* c : {"\\", "<", ";", "[", "=", " ", "\r", "\n", ",", "\"", + "\x1F", "\x91"}) { + const std::string name = is_pseudo_header ? absl::StrCat(":met", c, "hod") + : absl::StrCat("na", c, "me"); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); + } + // Test nul separately. + { + const absl::string_view name = is_pseudo_header + ? absl::string_view(":met\0hod", 8) + : absl::string_view("na\0me", 5); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); + } + // Uppercase characters in header names should not be allowed. + const std::string uc_name = is_pseudo_header ? ":Method" : "Name"; + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(uc_name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); + } +} + +TEST(HeaderValidatorTest, ValueHasInvalidChar) { + HeaderValidator v; + // These characters should be allowed. (Not exhaustive.) + for (const char* c : + {"!", "3", "a", "_", "|", "~", "\\", "<", ";", "[", "=", "A", "\t"}) { + const std::string value = absl::StrCat("val", c, "ue"); + EXPECT_TRUE( + HeaderValidator::IsValidHeaderValue(value, ObsTextOption::kDisallow)); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", value); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + // These should not. + for (const char* c : {"\r", "\n"}) { + const std::string value = absl::StrCat("val", c, "ue"); + EXPECT_FALSE( + HeaderValidator::IsValidHeaderValue(value, ObsTextOption::kDisallow)); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", value); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); + } + // Test nul separately. + { + const std::string value("val\0ue", 6); + EXPECT_FALSE( + HeaderValidator::IsValidHeaderValue(value, ObsTextOption::kDisallow)); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", value); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); + } + { + const std::string obs_text_value = "val\xa9ue"; + // Test that obs-text is disallowed by default. + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("name", obs_text_value)); + // Test that obs-text is disallowed when configured. + v.SetObsTextOption(ObsTextOption::kDisallow); + EXPECT_FALSE(HeaderValidator::IsValidHeaderValue(obs_text_value, + ObsTextOption::kDisallow)); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("name", obs_text_value)); + // Test that obs-text is allowed when configured. + v.SetObsTextOption(ObsTextOption::kAllow); + EXPECT_TRUE(HeaderValidator::IsValidHeaderValue(obs_text_value, + ObsTextOption::kAllow)); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("name", obs_text_value)); + } +} + +TEST(HeaderValidatorTest, StatusHasInvalidChar) { + HeaderValidator v; + + for (HeaderType type : {HeaderType::RESPONSE, HeaderType::RESPONSE_100}) { + // When `:status` has a non-digit value, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader(":status", "bar")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When `:status` is too short, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader(":status", "10")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When `:status` is too long, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader(":status", "9000")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When `:status` is just right, validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "400")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + } +} + +TEST(HeaderValidatorTest, AuthorityHasInvalidChar) { + for (absl::string_view key : {":authority", "host"}) { + // These characters should be allowed. (Not exhaustive.) + for (const absl::string_view c : {"1", "-", "!", ":", "+", "=", ","}) { + const std::string value = absl::StrCat("ho", c, "st.example.com"); + EXPECT_TRUE(HeaderValidator::IsValidAuthority(value)); + + HeaderValidator v; + v.StartHeaderBlock(); + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader(key, value); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + // These should not. + for (const absl::string_view c : {"\r", "\n", "|", "\\", "`"}) { + const std::string value = absl::StrCat("ho", c, "st.example.com"); + EXPECT_FALSE(HeaderValidator::IsValidAuthority(value)); + + HeaderValidator v; + v.StartHeaderBlock(); + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader(key, value); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, status); + } + + { + // IPv4 example + const std::string value = "123.45.67.89"; + EXPECT_TRUE(HeaderValidator::IsValidAuthority(value)); + + HeaderValidator v; + v.StartHeaderBlock(); + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader(key, value); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + + { + // IPv6 examples + const std::string value1 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334"; + EXPECT_TRUE(HeaderValidator::IsValidAuthority(value1)); + + HeaderValidator v; + v.StartHeaderBlock(); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(key, value1); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + + const std::string value2 = "[::1]:80"; + EXPECT_TRUE(HeaderValidator::IsValidAuthority(value2)); + HeaderValidator v2; + v2.StartHeaderBlock(); + status = v2.ValidateSingleHeader(key, value2); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + + { + // Empty field + EXPECT_TRUE(HeaderValidator::IsValidAuthority("")); + + HeaderValidator v; + v.StartHeaderBlock(); + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader(key, ""); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + } +} + +TEST(HeaderValidatorTest, RequestHostAndAuthority) { + HeaderValidator v; + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + // If both "host" and ":authority" have the same value, validation succeeds. + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("host", "www.foo.com")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + // If "host" and ":authority" have different values, validation fails. + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("host", "www.bar.com")); +} + +TEST(HeaderValidatorTest, RequestPseudoHeaders) { + HeaderValidator v; + for (Header to_skip : kSampleRequestPseudoheaders) { + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add != to_skip) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + // When any pseudo-header is missing, final validation will fail. + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + } + + // When all pseudo-headers are present, final validation will succeed. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // When an extra pseudo-header is present, final validation will fail. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":extra", "blah")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // When a required pseudo-header is repeated, final validation will fail. + for (Header to_repeat : kSampleRequestPseudoheaders) { + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + if (to_add == to_repeat) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + } +} + +TEST(HeaderValidatorTest, ConnectHeaders) { + // Too few headers. + HeaderValidator v; + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":authority", "athena.dialup.mit.edu:23")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":method", "CONNECT")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // Too many headers. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":authority", "athena.dialup.mit.edu:23")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":method", "CONNECT")); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader(":path", "/")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // Empty :authority + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":authority", "")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":method", "CONNECT")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // Just right. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":authority", "athena.dialup.mit.edu:23")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":method", "CONNECT")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.SetAllowExtendedConnect(); + // "Classic" CONNECT headers should still be accepted. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":authority", "athena.dialup.mit.edu:23")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":method", "CONNECT")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(HeaderValidatorTest, WebsocketPseudoHeaders) { + HeaderValidator v; + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // At this point, `:protocol` is treated as an extra pseudo-header. + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // Future header blocks may send the `:protocol` pseudo-header for CONNECT + // requests. + v.SetAllowExtendedConnect(); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // The method is not "CONNECT", so `:protocol` is still treated as an extra + // pseudo-header. + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":method") { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "CONNECT")); + } else { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // After allowing the method, `:protocol` is acepted for CONNECT requests. + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(HeaderValidatorTest, AsteriskPathPseudoHeader) { + HeaderValidator v; + + // An asterisk :path should not be allowed for non-OPTIONS requests. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "*")); + } else { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // An asterisk :path should be allowed for OPTIONS requests. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "*")); + } else if (to_add.first == ":method") { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "OPTIONS")); + } else { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(HeaderValidatorTest, InvalidPathPseudoHeader) { + HeaderValidator v; + + // An empty path should fail on single header validation and finish. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader(to_add.first, "")); + } else { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // A path that does not start with a slash should fail on finish. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "shawarma")); + } else { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(HeaderValidatorTest, ResponsePseudoHeaders) { + HeaderValidator v; + + for (HeaderType type : {HeaderType::RESPONSE, HeaderType::RESPONSE_100}) { + // When `:status` is missing, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader("foo", "bar")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When all pseudo-headers are present, final validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + EXPECT_EQ("199", v.status_header()); + + // When `:status` is repeated, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "299")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When an extra pseudo-header is present, final validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":extra", "blorp")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + } +} + +TEST(HeaderValidatorTest, ResponseWithHost) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("host", "myserver.com")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(HeaderValidatorTest, Response204) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "204")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(HeaderValidatorTest, ResponseWithMultipleIdenticalContentLength) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "13")); + EXPECT_EQ(HeaderValidator::HEADER_SKIP, + v.ValidateSingleHeader("content-length", "13")); +} + +TEST(HeaderValidatorTest, ResponseWithMultipleDifferingContentLength) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "13")); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("content-length", "17")); +} + +TEST(HeaderValidatorTest, Response204WithContentLengthZero) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "204")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "0")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(HeaderValidatorTest, Response204WithContentLength) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "204")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("content-length", "1")); +} + +TEST(HeaderValidatorTest, Response100) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "100")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(HeaderValidatorTest, Response100WithContentLengthZero) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "100")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "0")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(HeaderValidatorTest, Response100WithContentLength) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "100")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("content-length", "1")); +} + +TEST(HeaderValidatorTest, ResponseTrailerPseudoHeaders) { + HeaderValidator v; + + // When no pseudo-headers are present, validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader("foo", "bar")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE_TRAILER)); + + // When any pseudo-header is present, final validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader("foo", "bar")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::RESPONSE_TRAILER)); +} + +TEST(HeaderValidatorTest, ValidContentLength) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "41")); + EXPECT_THAT(v.content_length(), Optional(41)); + + v.StartHeaderBlock(); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "42")); + EXPECT_THAT(v.content_length(), Optional(42)); +} + +TEST(HeaderValidatorTest, InvalidContentLength) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("content-length", "")); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("content-length", "nan")); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("content-length", "-42")); + EXPECT_EQ(v.content_length(), absl::nullopt); + // End on a positive note. + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "42")); + EXPECT_THAT(v.content_length(), Optional(42)); +} + +TEST(HeaderValidatorTest, TeHeader) { + HeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader("te", "trailers")); + + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader("te", "trailers, deflate")); +} + +TEST(HeaderValidatorTest, ConnectionSpecificHeaders) { + const std::vector
connection_headers = { + {"connection", "keep-alive"}, {"proxy-connection", "keep-alive"}, + {"keep-alive", "timeout=42"}, {"transfer-encoding", "chunked"}, + {"upgrade", "h2c"}, + }; + for (const auto& [connection_key, connection_value] : connection_headers) { + HeaderValidator v; + v.StartHeaderBlock(); + for (const auto& [sample_key, sample_value] : kSampleRequestPseudoheaders) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(sample_key, sample_value)); + } + EXPECT_EQ(HeaderValidator::HEADER_FIELD_INVALID, + v.ValidateSingleHeader(connection_key, connection_value)); + } +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/http2_adapter.h b/quiche/http2/adapter/http2_adapter.h new file mode 100644 index 000000000000..fb9df138d9d5 --- /dev/null +++ b/quiche/http2/adapter/http2_adapter.h @@ -0,0 +1,163 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HTTP2_ADAPTER_H_ +#define QUICHE_HTTP2_ADAPTER_HTTP2_ADAPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "quiche/http2/adapter/data_source.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_session.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// Http2Adapter is an HTTP/2-processing class that exposes an interface similar +// to the nghttp2 library for processing the HTTP/2 wire format. As nghttp2 +// parses HTTP/2 frames and invokes callbacks on Http2Adapter, Http2Adapter then +// invokes corresponding callbacks on its passed-in Http2VisitorInterface. +// Http2Adapter is a base class shared between client-side and server-side +// implementations. +class QUICHE_EXPORT Http2Adapter { + public: + Http2Adapter(const Http2Adapter&) = delete; + Http2Adapter& operator=(const Http2Adapter&) = delete; + + virtual ~Http2Adapter() {} + + virtual bool IsServerSession() const = 0; + + virtual bool want_read() const = 0; + virtual bool want_write() const = 0; + + // Processes the incoming |bytes| as HTTP/2 and invokes callbacks on the + // |visitor_| as appropriate. + virtual int64_t ProcessBytes(absl::string_view bytes) = 0; + + // Submits the |settings| to be written to the peer, e.g., as part of the + // HTTP/2 connection preface. + virtual void SubmitSettings(absl::Span settings) = 0; + + // Submits a PRIORITY frame for the given stream. + virtual void SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, bool exclusive) = 0; + + // Submits a PING on the connection. + virtual void SubmitPing(Http2PingId ping_id) = 0; + + // Starts a graceful shutdown. A no-op for clients. + virtual void SubmitShutdownNotice() = 0; + + // Submits a GOAWAY on the connection. Note that |last_accepted_stream_id| + // refers to stream IDs initiated by the peer. For a server sending this + // frame, this last stream ID must be odd (or 0). + virtual void SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) = 0; + + // Submits a WINDOW_UPDATE for the given stream (a |stream_id| of 0 indicates + // a connection-level WINDOW_UPDATE). + virtual void SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) = 0; + + // Submits a RST_STREAM for the given |stream_id| and |error_code|. + virtual void SubmitRst(Http2StreamId stream_id, + Http2ErrorCode error_code) = 0; + + // Submits a sequence of METADATA frames for the given stream. A |stream_id| + // of 0 indicates connection-level METADATA. + virtual void SubmitMetadata(Http2StreamId stream_id, size_t max_frame_size, + std::unique_ptr source) = 0; + + // Invokes the visitor's OnReadyToSend() method for serialized frame data. + // Returns 0 on success. + virtual int Send() = 0; + + // Returns the connection-level flow control window advertised by the peer. + virtual int GetSendWindowSize() const = 0; + + // Returns the stream-level flow control window advertised by the peer. + virtual int GetStreamSendWindowSize(Http2StreamId stream_id) const = 0; + + // Returns the current upper bound on the flow control receive window for this + // stream. This value does not account for data received from the peer. + virtual int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const = 0; + + // Returns the amount of data a peer could send on a given stream. This is + // the outstanding stream receive window. + virtual int GetStreamReceiveWindowSize(Http2StreamId stream_id) const = 0; + + // Returns the total amount of data a peer could send on the connection. This + // is the outstanding connection receive window. + virtual int GetReceiveWindowSize() const = 0; + + // Returns the size of the HPACK encoder's dynamic table, including the + // per-entry overhead from the specification. + virtual int GetHpackEncoderDynamicTableSize() const = 0; + + // Returns the size of the HPACK decoder's dynamic table, including the + // per-entry overhead from the specification. + virtual int GetHpackDecoderDynamicTableSize() const = 0; + + // Gets the highest stream ID value seen in a frame received by this endpoint. + // This method is only guaranteed to work for server endpoints. + virtual Http2StreamId GetHighestReceivedStreamId() const = 0; + + // Marks the given amount of data as consumed for the given stream, which + // enables the implementation layer to send WINDOW_UPDATEs as appropriate. + virtual void MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) = 0; + + // Returns the assigned stream ID if the operation succeeds. Otherwise, + // returns a negative integer indicating an error code. |data_source| may be + // nullptr if the request does not have a body. + virtual int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data) = 0; + + // Returns 0 on success. |data_source| may be nullptr if the response does not + // have a body. + virtual int SubmitResponse(Http2StreamId stream_id, + absl::Span headers, + std::unique_ptr data_source) = 0; + + // Queues trailers to be sent after any outstanding data on the stream with ID + // |stream_id|. Returns 0 on success. + virtual int SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) = 0; + + // Sets a user data pointer for the given stream. Can be called after + // SubmitRequest/SubmitResponse, or after receiving any frame for a given + // stream. + virtual void SetStreamUserData(Http2StreamId stream_id, void* user_data) = 0; + + // Returns nullptr if the stream does not exist, or if stream user data has + // not been set. + virtual void* GetStreamUserData(Http2StreamId stream_id) = 0; + + // Resumes a stream that was previously blocked (for example, due to + // DataFrameSource::SelectPayloadLength() returning kBlocked). Returns true if + // the stream was successfully resumed. + virtual bool ResumeStream(Http2StreamId stream_id) = 0; + + protected: + // Subclasses should expose a public factory method for constructing and + // initializing (via Initialize()) adapter instances. + explicit Http2Adapter(Http2VisitorInterface& visitor) : visitor_(visitor) {} + + // Accessors. Do not transfer ownership. + Http2VisitorInterface& visitor() { return visitor_; } + + private: + // Http2Adapter will invoke callbacks upon the |visitor_| while processing. + Http2VisitorInterface& visitor_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HTTP2_ADAPTER_H_ diff --git a/quiche/http2/adapter/http2_protocol.cc b/quiche/http2/adapter/http2_protocol.cc new file mode 100644 index 000000000000..6469d324fdb7 --- /dev/null +++ b/quiche/http2/adapter/http2_protocol.cc @@ -0,0 +1,85 @@ +#include "quiche/http2/adapter/http2_protocol.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace http2 { +namespace adapter { + +ABSL_CONST_INIT const char kHttp2MethodPseudoHeader[] = ":method"; +ABSL_CONST_INIT const char kHttp2SchemePseudoHeader[] = ":scheme"; +ABSL_CONST_INIT const char kHttp2AuthorityPseudoHeader[] = ":authority"; +ABSL_CONST_INIT const char kHttp2PathPseudoHeader[] = ":path"; +ABSL_CONST_INIT const char kHttp2StatusPseudoHeader[] = ":status"; + +ABSL_CONST_INIT const uint8_t kMetadataFrameType = 0x4d; +ABSL_CONST_INIT const uint8_t kMetadataEndFlag = 0x04; +ABSL_CONST_INIT const uint16_t kMetadataExtensionId = 0x4d44; + +std::pair GetStringView(const HeaderRep& rep) { + if (absl::holds_alternative(rep)) { + return std::make_pair(absl::get(rep), true); + } else { + absl::string_view view = absl::get(rep); + return std::make_pair(view, false); + } +} + +bool operator==(const Http2Setting& a, const Http2Setting& b) { + return a.id == b.id && a.value == b.value; +} + +absl::string_view Http2SettingsIdToString(uint16_t id) { + switch (id) { + case Http2KnownSettingsId::HEADER_TABLE_SIZE: + return "SETTINGS_HEADER_TABLE_SIZE"; + case Http2KnownSettingsId::ENABLE_PUSH: + return "SETTINGS_ENABLE_PUSH"; + case Http2KnownSettingsId::MAX_CONCURRENT_STREAMS: + return "SETTINGS_MAX_CONCURRENT_STREAMS"; + case Http2KnownSettingsId::INITIAL_WINDOW_SIZE: + return "SETTINGS_INITIAL_WINDOW_SIZE"; + case Http2KnownSettingsId::MAX_FRAME_SIZE: + return "SETTINGS_MAX_FRAME_SIZE"; + case Http2KnownSettingsId::MAX_HEADER_LIST_SIZE: + return "SETTINGS_MAX_HEADER_LIST_SIZE"; + } + return "SETTINGS_UNKNOWN"; +} + +absl::string_view Http2ErrorCodeToString(Http2ErrorCode error_code) { + switch (error_code) { + case Http2ErrorCode::HTTP2_NO_ERROR: + return "HTTP2_NO_ERROR"; + case Http2ErrorCode::PROTOCOL_ERROR: + return "PROTOCOL_ERROR"; + case Http2ErrorCode::INTERNAL_ERROR: + return "INTERNAL_ERROR"; + case Http2ErrorCode::FLOW_CONTROL_ERROR: + return "FLOW_CONTROL_ERROR"; + case Http2ErrorCode::SETTINGS_TIMEOUT: + return "SETTINGS_TIMEOUT"; + case Http2ErrorCode::STREAM_CLOSED: + return "STREAM_CLOSED"; + case Http2ErrorCode::FRAME_SIZE_ERROR: + return "FRAME_SIZE_ERROR"; + case Http2ErrorCode::REFUSED_STREAM: + return "REFUSED_STREAM"; + case Http2ErrorCode::CANCEL: + return "CANCEL"; + case Http2ErrorCode::COMPRESSION_ERROR: + return "COMPRESSION_ERROR"; + case Http2ErrorCode::CONNECT_ERROR: + return "CONNECT_ERROR"; + case Http2ErrorCode::ENHANCE_YOUR_CALM: + return "ENHANCE_YOUR_CALM"; + case Http2ErrorCode::INADEQUATE_SECURITY: + return "INADEQUATE_SECURITY"; + case Http2ErrorCode::HTTP_1_1_REQUIRED: + return "HTTP_1_1_REQUIRED"; + } + return "UNKNOWN_ERROR"; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/http2_protocol.h b/quiche/http2/adapter/http2_protocol.h new file mode 100644 index 000000000000..79225c3f5864 --- /dev/null +++ b/quiche/http2/adapter/http2_protocol.h @@ -0,0 +1,150 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HTTP2_PROTOCOL_H_ +#define QUICHE_HTTP2_ADAPTER_HTTP2_PROTOCOL_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// Represents an HTTP/2 stream ID, consistent with nghttp2. +using Http2StreamId = int32_t; + +// Represents an HTTP/2 SETTINGS parameter as specified in RFC 7540 Section 6.5. +using Http2SettingsId = uint16_t; + +// Represents the payload of an HTTP/2 PING frame. +using Http2PingId = uint64_t; + +// Represents a single header name or value. +using HeaderRep = absl::variant; + +// Boolean return value is true if |rep| holds a string_view, which is assumed +// to have an indefinite lifetime. +std::pair GetStringView(const HeaderRep& rep); + +// Represents an HTTP/2 header field. A header field is a key-value pair with +// lowercase keys (as specified in RFC 7540 Section 8.1.2). +using Header = std::pair; + +// Represents an HTTP/2 SETTINGS key-value parameter. +struct QUICHE_EXPORT Http2Setting { + Http2SettingsId id; + uint32_t value; +}; + +QUICHE_EXPORT bool operator==(const Http2Setting& a, const Http2Setting& b); + +// The maximum possible stream ID. +const Http2StreamId kMaxStreamId = 0x7FFFFFFF; + +// The stream ID that represents the connection (e.g., for connection-level flow +// control updates). +const Http2StreamId kConnectionStreamId = 0; + +// The default value for the size of the largest frame payload, according to RFC +// 7540 Section 6.5.2 (SETTINGS_MAX_FRAME_SIZE). +const uint32_t kDefaultFramePayloadSizeLimit = 16u * 1024u; + +// The maximum value for the size of the largest frame payload, according to RFC +// 7540 Section 6.5.2 (SETTINGS_MAX_FRAME_SIZE). +const uint32_t kMaximumFramePayloadSizeLimit = 16777215u; + +// The default value for the initial stream and connection flow control window +// size, according to RFC 7540 Section 6.9.2. +const int kInitialFlowControlWindowSize = 64 * 1024 - 1; + +// The pseudo-header fields as specified in RFC 7540 Section 8.1.2.3 (request) +// and Section 8.1.2.4 (response). +ABSL_CONST_INIT QUICHE_EXPORT extern const char kHttp2MethodPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT extern const char kHttp2SchemePseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT extern const char kHttp2AuthorityPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT extern const char kHttp2PathPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT extern const char kHttp2StatusPseudoHeader[]; + +ABSL_CONST_INIT QUICHE_EXPORT extern const uint8_t kMetadataFrameType; +ABSL_CONST_INIT QUICHE_EXPORT extern const uint8_t kMetadataEndFlag; +ABSL_CONST_INIT QUICHE_EXPORT extern const uint16_t kMetadataExtensionId; + +enum class FrameType : uint8_t { + DATA = 0x0, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +enum FrameFlags : uint8_t { + END_STREAM_FLAG = 0x1, + ACK_FLAG = END_STREAM_FLAG, + END_HEADERS_FLAG = 0x4, + PADDED_FLAG = 0x8, + PRIORITY_FLAG = 0x20, +}; + +// HTTP/2 error codes as specified in RFC 7540 Section 7. +enum class Http2ErrorCode { + HTTP2_NO_ERROR = 0x0, + PROTOCOL_ERROR = 0x1, + INTERNAL_ERROR = 0x2, + FLOW_CONTROL_ERROR = 0x3, + SETTINGS_TIMEOUT = 0x4, + STREAM_CLOSED = 0x5, + FRAME_SIZE_ERROR = 0x6, + REFUSED_STREAM = 0x7, + CANCEL = 0x8, + COMPRESSION_ERROR = 0x9, + CONNECT_ERROR = 0xA, + ENHANCE_YOUR_CALM = 0xB, + INADEQUATE_SECURITY = 0xC, + HTTP_1_1_REQUIRED = 0xD, + MAX_ERROR_CODE = HTTP_1_1_REQUIRED, +}; + +// The SETTINGS parameters defined in RFC 7540 Section 6.5.2. Endpoints may send +// SETTINGS parameters outside of these definitions as per RFC 7540 Section 5.5. +// This is explicitly an enum instead of an enum class for ease of implicit +// conversion to the underlying Http2SettingsId type and use with non-standard +// extension SETTINGS parameters. +enum Http2KnownSettingsId : Http2SettingsId { + HEADER_TABLE_SIZE = 0x1, + MIN_SETTING = HEADER_TABLE_SIZE, + ENABLE_PUSH = 0x2, + MAX_CONCURRENT_STREAMS = 0x3, + INITIAL_WINDOW_SIZE = 0x4, + MAX_FRAME_SIZE = 0x5, + MAX_HEADER_LIST_SIZE = 0x6, + ENABLE_CONNECT_PROTOCOL = 0x8, // See RFC 8441 + MAX_SETTING = ENABLE_CONNECT_PROTOCOL +}; + +// Returns a human-readable string representation of the given SETTINGS |id| for +// logging/debugging. Returns "SETTINGS_UNKNOWN" for IDs outside of the RFC 7540 +// Section 6.5.2 definitions. +absl::string_view Http2SettingsIdToString(uint16_t id); + +// Returns a human-readable string representation of the given |error_code| for +// logging/debugging. Returns "UNKNOWN_ERROR" for errors outside of RFC 7540 +// Section 7 definitions. +absl::string_view Http2ErrorCodeToString(Http2ErrorCode error_code); + +enum class Perspective { + kClient, + kServer, +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HTTP2_PROTOCOL_H_ diff --git a/quiche/http2/adapter/http2_session.h b/quiche/http2/adapter/http2_session.h new file mode 100644 index 000000000000..7b01000aca4b --- /dev/null +++ b/quiche/http2/adapter/http2_session.h @@ -0,0 +1,33 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HTTP2_SESSION_H_ +#define QUICHE_HTTP2_ADAPTER_HTTP2_SESSION_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +struct QUICHE_EXPORT Http2SessionCallbacks {}; + +// A class to represent the state of a single HTTP/2 connection. +class QUICHE_EXPORT Http2Session { + public: + Http2Session() = default; + virtual ~Http2Session() {} + + virtual int64_t ProcessBytes(absl::string_view bytes) = 0; + + virtual int Consume(Http2StreamId stream_id, size_t num_bytes) = 0; + + virtual bool want_read() const = 0; + virtual bool want_write() const = 0; + virtual int GetRemoteWindowSize() const = 0; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HTTP2_SESSION_H_ diff --git a/quiche/http2/adapter/http2_util.cc b/quiche/http2/adapter/http2_util.cc new file mode 100644 index 000000000000..707fa00029f9 --- /dev/null +++ b/quiche/http2/adapter/http2_util.cc @@ -0,0 +1,134 @@ +#include "quiche/http2/adapter/http2_util.h" + +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; +using InvalidFrameError = Http2VisitorInterface::InvalidFrameError; + +} // anonymous namespace + +spdy::SpdyErrorCode TranslateErrorCode(Http2ErrorCode code) { + switch (code) { + case Http2ErrorCode::HTTP2_NO_ERROR: + return spdy::ERROR_CODE_NO_ERROR; + case Http2ErrorCode::PROTOCOL_ERROR: + return spdy::ERROR_CODE_PROTOCOL_ERROR; + case Http2ErrorCode::INTERNAL_ERROR: + return spdy::ERROR_CODE_INTERNAL_ERROR; + case Http2ErrorCode::FLOW_CONTROL_ERROR: + return spdy::ERROR_CODE_FLOW_CONTROL_ERROR; + case Http2ErrorCode::SETTINGS_TIMEOUT: + return spdy::ERROR_CODE_SETTINGS_TIMEOUT; + case Http2ErrorCode::STREAM_CLOSED: + return spdy::ERROR_CODE_STREAM_CLOSED; + case Http2ErrorCode::FRAME_SIZE_ERROR: + return spdy::ERROR_CODE_FRAME_SIZE_ERROR; + case Http2ErrorCode::REFUSED_STREAM: + return spdy::ERROR_CODE_REFUSED_STREAM; + case Http2ErrorCode::CANCEL: + return spdy::ERROR_CODE_CANCEL; + case Http2ErrorCode::COMPRESSION_ERROR: + return spdy::ERROR_CODE_COMPRESSION_ERROR; + case Http2ErrorCode::CONNECT_ERROR: + return spdy::ERROR_CODE_CONNECT_ERROR; + case Http2ErrorCode::ENHANCE_YOUR_CALM: + return spdy::ERROR_CODE_ENHANCE_YOUR_CALM; + case Http2ErrorCode::INADEQUATE_SECURITY: + return spdy::ERROR_CODE_INADEQUATE_SECURITY; + case Http2ErrorCode::HTTP_1_1_REQUIRED: + return spdy::ERROR_CODE_HTTP_1_1_REQUIRED; + } + return spdy::ERROR_CODE_INTERNAL_ERROR; +} + +Http2ErrorCode TranslateErrorCode(spdy::SpdyErrorCode code) { + switch (code) { + case spdy::ERROR_CODE_NO_ERROR: + return Http2ErrorCode::HTTP2_NO_ERROR; + case spdy::ERROR_CODE_PROTOCOL_ERROR: + return Http2ErrorCode::PROTOCOL_ERROR; + case spdy::ERROR_CODE_INTERNAL_ERROR: + return Http2ErrorCode::INTERNAL_ERROR; + case spdy::ERROR_CODE_FLOW_CONTROL_ERROR: + return Http2ErrorCode::FLOW_CONTROL_ERROR; + case spdy::ERROR_CODE_SETTINGS_TIMEOUT: + return Http2ErrorCode::SETTINGS_TIMEOUT; + case spdy::ERROR_CODE_STREAM_CLOSED: + return Http2ErrorCode::STREAM_CLOSED; + case spdy::ERROR_CODE_FRAME_SIZE_ERROR: + return Http2ErrorCode::FRAME_SIZE_ERROR; + case spdy::ERROR_CODE_REFUSED_STREAM: + return Http2ErrorCode::REFUSED_STREAM; + case spdy::ERROR_CODE_CANCEL: + return Http2ErrorCode::CANCEL; + case spdy::ERROR_CODE_COMPRESSION_ERROR: + return Http2ErrorCode::COMPRESSION_ERROR; + case spdy::ERROR_CODE_CONNECT_ERROR: + return Http2ErrorCode::CONNECT_ERROR; + case spdy::ERROR_CODE_ENHANCE_YOUR_CALM: + return Http2ErrorCode::ENHANCE_YOUR_CALM; + case spdy::ERROR_CODE_INADEQUATE_SECURITY: + return Http2ErrorCode::INADEQUATE_SECURITY; + case spdy::ERROR_CODE_HTTP_1_1_REQUIRED: + return Http2ErrorCode::HTTP_1_1_REQUIRED; + } + return Http2ErrorCode::INTERNAL_ERROR; +} + +absl::string_view ConnectionErrorToString(ConnectionError error) { + switch (error) { + case ConnectionError::kInvalidConnectionPreface: + return "InvalidConnectionPreface"; + case ConnectionError::kSendError: + return "SendError"; + case ConnectionError::kParseError: + return "ParseError"; + case ConnectionError::kHeaderError: + return "HeaderError"; + case ConnectionError::kInvalidNewStreamId: + return "InvalidNewStreamId"; + case ConnectionError::kWrongFrameSequence: + return "kWrongFrameSequence"; + case ConnectionError::kInvalidPushPromise: + return "InvalidPushPromise"; + case ConnectionError::kExceededMaxConcurrentStreams: + return "ExceededMaxConcurrentStreams"; + case ConnectionError::kFlowControlError: + return "FlowControlError"; + case ConnectionError::kInvalidGoAwayLastStreamId: + return "InvalidGoAwayLastStreamId"; + case ConnectionError::kInvalidSetting: + return "InvalidSetting"; + } + return "UnknownConnectionError"; +} + +absl::string_view InvalidFrameErrorToString( + Http2VisitorInterface::InvalidFrameError error) { + switch (error) { + case InvalidFrameError::kProtocol: + return "Protocol"; + case InvalidFrameError::kRefusedStream: + return "RefusedStream"; + case InvalidFrameError::kHttpHeader: + return "HttpHeader"; + case InvalidFrameError::kHttpMessaging: + return "HttpMessaging"; + case InvalidFrameError::kFlowControl: + return "FlowControl"; + case InvalidFrameError::kStreamClosed: + return "StreamClosed"; + } + return "UnknownInvalidFrameError"; +} + +bool DeltaAtLeastHalfLimit(int64_t limit, int64_t /*size*/, int64_t delta) { + return delta > 0 && delta >= limit / 2; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/http2_util.h b/quiche/http2/adapter/http2_util.h new file mode 100644 index 000000000000..12015a3ac935 --- /dev/null +++ b/quiche/http2/adapter/http2_util.h @@ -0,0 +1,31 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HTTP2_UTIL_H_ +#define QUICHE_HTTP2_ADAPTER_HTTP2_UTIL_H_ + +#include + +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { + +QUICHE_EXPORT spdy::SpdyErrorCode TranslateErrorCode(Http2ErrorCode code); +QUICHE_EXPORT Http2ErrorCode TranslateErrorCode(spdy::SpdyErrorCode code); + +QUICHE_EXPORT absl::string_view ConnectionErrorToString( + Http2VisitorInterface::ConnectionError error); + +QUICHE_EXPORT absl::string_view InvalidFrameErrorToString( + Http2VisitorInterface::InvalidFrameError error); + +// A WINDOW_UPDATE sending strategy that returns true if the `delta` to be sent +// is positive and at least half of the window `limit`. +QUICHE_EXPORT bool DeltaAtLeastHalfLimit(int64_t limit, int64_t /*size*/, + int64_t delta); + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HTTP2_UTIL_H_ diff --git a/quiche/http2/adapter/http2_visitor_interface.h b/quiche/http2/adapter/http2_visitor_interface.h new file mode 100644 index 000000000000..fa8491d879f1 --- /dev/null +++ b/quiche/http2/adapter/http2_visitor_interface.h @@ -0,0 +1,264 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HTTP2_VISITOR_INTERFACE_H_ +#define QUICHE_HTTP2_ADAPTER_HTTP2_VISITOR_INTERFACE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// Http2VisitorInterface contains callbacks for receiving HTTP/2-level events. A +// processor like NghttpAdapter parses HTTP/2 frames and invokes the callbacks +// on an instance of this interface. Prefer a void return type for these +// callbacks, instead setting output parameters as needed. +// +// Example sequences of calls/events: +// GET: +// - OnBeginHeadersForStream() +// - OnHeaderForStream() +// - OnEndHeadersForStream() +// - OnEndStream() +// +// POST: +// - OnBeginHeadersForStream() +// - OnHeaderForStream() +// - OnEndHeadersForStream() +// - OnBeginDataForStream() +// - OnDataForStream() +// - OnEndStream() +// +// Request canceled mid-stream, e.g, with error code CANCEL: +// - OnBeginHeadersForStream() +// - OnHeaderForStream() +// - OnEndHeadersForStream() +// - OnRstStream() +// - OnCloseStream() +// +// Request closed mid-stream, e.g., with error code NO_ERROR: +// - OnBeginHeadersForStream() +// - OnHeaderForStream() +// - OnEndHeadersForStream() +// - OnRstStream() +// - OnCloseStream() +// +// More details are at RFC 7540 (go/http2spec). +class QUICHE_EXPORT Http2VisitorInterface { + public: + Http2VisitorInterface(const Http2VisitorInterface&) = delete; + Http2VisitorInterface& operator=(const Http2VisitorInterface&) = delete; + virtual ~Http2VisitorInterface() = default; + + enum : int64_t { + kSendBlocked = 0, + kSendError = -1, + }; + // Called when there are serialized frames to send. Should return how many + // bytes were actually sent. May return kSendBlocked or kSendError. + virtual int64_t OnReadyToSend(absl::string_view serialized) = 0; + + // Called when a connection-level error has occurred. + enum class ConnectionError { + // The peer sent an invalid connection preface. + kInvalidConnectionPreface, + // The visitor encountered an error sending bytes to the peer. + kSendError, + // There was an error reading and framing bytes from the peer. + kParseError, + // The visitor considered a received header to be a connection error. + kHeaderError, + // The peer attempted to open a stream with an invalid stream ID. + kInvalidNewStreamId, + // The peer sent a frame that is invalid on an idle stream (before HEADERS). + kWrongFrameSequence, + // The peer sent an invalid PUSH_PROMISE frame. + kInvalidPushPromise, + // The peer exceeded the max concurrent streams limit. + kExceededMaxConcurrentStreams, + // The peer caused a flow control error. + kFlowControlError, + // The peer sent a GOAWAY with an invalid last-stream-ID field. + kInvalidGoAwayLastStreamId, + // The peer sent an invalid SETTINGS value. + kInvalidSetting, + }; + virtual void OnConnectionError(ConnectionError error) = 0; + + // Called when the header for a frame is received. Returns false if a fatal + // error has occurred. + virtual bool OnFrameHeader(Http2StreamId /*stream_id*/, size_t /*length*/, + uint8_t /*type*/, uint8_t /*flags*/) { + return true; + } + + // Called when a non-ack SETTINGS frame is received. + virtual void OnSettingsStart() = 0; + + // Called for each SETTINGS id-value pair. + virtual void OnSetting(Http2Setting setting) = 0; + + // Called at the end of a non-ack SETTINGS frame. + virtual void OnSettingsEnd() = 0; + + // Called when a SETTINGS ack frame is received. + virtual void OnSettingsAck() = 0; + + // Called when the connection receives the header block for a HEADERS frame on + // a stream but has not yet parsed individual headers. Returns false if a + // fatal error has occurred. + virtual bool OnBeginHeadersForStream(Http2StreamId stream_id) = 0; + + // Called when the connection receives the header |key| and |value| for a + // stream. The HTTP/2 pseudo-headers defined in RFC 7540 Sections 8.1.2.3 and + // 8.1.2.4 are also conveyed in this callback. This method is called after + // OnBeginHeadersForStream(). May return HEADER_RST_STREAM to indicate the + // header block should be rejected. This will cause the library to queue a + // RST_STREAM frame, which will have a default error code of INTERNAL_ERROR. + // The visitor implementation may choose to queue a RST_STREAM with a + // different error code instead, which should be done before returning + // HEADER_RST_STREAM. Returning HEADER_CONNECTION_ERROR will lead to a + // non-recoverable error on the connection. + enum OnHeaderResult { + // The header was accepted. + HEADER_OK, + // The application considers the header a connection error. + HEADER_CONNECTION_ERROR, + // The application rejects the header and requests the stream be reset. + HEADER_RST_STREAM, + // The header field is invalid and will be reset with error code + // PROTOCOL_ERROR. + HEADER_FIELD_INVALID, + // The headers are a violation of HTTP messaging semantics and will be reset + // with error code PROTOCOL_ERROR. + HEADER_HTTP_MESSAGING, + // The headers caused a compression context error. + HEADER_COMPRESSION_ERROR, + }; + virtual OnHeaderResult OnHeaderForStream(Http2StreamId stream_id, + absl::string_view key, + absl::string_view value) = 0; + + // Called when the connection has received the complete header block for a + // logical HEADERS frame on a stream (which may contain CONTINUATION frames, + // transparent to the user). Returns false if a fatal error has occurred. + virtual bool OnEndHeadersForStream(Http2StreamId stream_id) = 0; + + // Called when the connection receives the beginning of a DATA frame. The data + // payload will be provided via subsequent calls to OnDataForStream(). Returns + // false if a fatal error has occurred. + virtual bool OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) = 0; + + // Called when the optional padding length field is parsed as part of a DATA + // frame payload. `padding_length` represents the total amount of padding for + // this frame, including the length byte itself. Returns false if a fatal + // error has occurred. + virtual bool OnDataPaddingLength(Http2StreamId stream_id, + size_t padding_length) = 0; + + // Called when the connection receives some |data| (as part of a DATA frame + // payload) for a stream. Returns false if a fatal error has occurred. + virtual bool OnDataForStream(Http2StreamId stream_id, + absl::string_view data) = 0; + + // Called when the peer sends the END_STREAM flag on a stream, indicating that + // the peer will not send additional headers or data for that stream. + virtual bool OnEndStream(Http2StreamId stream_id) = 0; + + // Called when the connection receives a RST_STREAM for a stream. This call + // will be followed by either OnCloseStream(). + virtual void OnRstStream(Http2StreamId stream_id, + Http2ErrorCode error_code) = 0; + + // Called when a stream is closed. Returns false if a fatal error has + // occurred. + virtual bool OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) = 0; + + // Called when the connection receives a PRIORITY frame. + virtual void OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, int weight, + bool exclusive) = 0; + + // Called when the connection receives a PING frame. + virtual void OnPing(Http2PingId ping_id, bool is_ack) = 0; + + // Called when the connection receives a PUSH_PROMISE frame. The server push + // request headers follow in calls to OnHeaderForStream() with |stream_id|. + virtual void OnPushPromiseForStream(Http2StreamId stream_id, + Http2StreamId promised_stream_id) = 0; + + // Called when the connection receives a GOAWAY frame. Returns false if a + // fatal error has occurred. + virtual bool OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) = 0; + + // Called when the connection receives a WINDOW_UPDATE frame. For + // connection-level window updates, the |stream_id| will be 0. + virtual void OnWindowUpdate(Http2StreamId stream_id, + int window_increment) = 0; + + // Called immediately before a frame of the given type is sent. Should return + // 0 on success. + virtual int OnBeforeFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags) = 0; + + // Called immediately after a frame of the given type is sent. Should return 0 + // on success. |error_code| is only populated for RST_STREAM and GOAWAY frame + // types. + virtual int OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags, + uint32_t error_code) = 0; + + // Called when the connection receives an invalid frame. A return value of + // false will result in the connection entering an error state, with no + // further frame processing possible. + enum class InvalidFrameError { + // The frame contains a general protocol error. + kProtocol, + // The frame would have caused a new (invalid) stream to be opened. + kRefusedStream, + // The frame contains an invalid header field. + kHttpHeader, + // The frame contains a violation in HTTP messaging rules. + kHttpMessaging, + // The frame causes a flow control error. + kFlowControl, + // The frame is on an already closed stream or has an invalid stream ID. + kStreamClosed, + }; + virtual bool OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) = 0; + + // Called when the connection receives the beginning of a METADATA frame + // (which may itself be the middle of a logical metadata block). The metadata + // payload will be provided via subsequent calls to OnMetadataForStream(). + // TODO(birenroy): Consider removing this unnecessary method. + virtual void OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) = 0; + + // Called when the connection receives |metadata| as part of a METADATA frame + // payload for a stream. Returns false if a fatal error has occurred. + virtual bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) = 0; + + // Called when the connection has finished receiving a logical metadata block + // for a stream. Note that there may be multiple metadata blocks for a stream. + // Returns false if there was an error unpacking the metadata payload. + virtual bool OnMetadataEndForStream(Http2StreamId stream_id) = 0; + + // Invoked with an error message from the application. + virtual void OnErrorDebug(absl::string_view message) = 0; + + protected: + Http2VisitorInterface() = default; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HTTP2_VISITOR_INTERFACE_H_ diff --git a/quiche/http2/adapter/mock_http2_visitor.h b/quiche/http2/adapter/mock_http2_visitor.h new file mode 100644 index 000000000000..86345d3651aa --- /dev/null +++ b/quiche/http2/adapter/mock_http2_visitor.h @@ -0,0 +1,122 @@ +#ifndef QUICHE_HTTP2_ADAPTER_MOCK_HTTP2_VISITOR_H_ +#define QUICHE_HTTP2_ADAPTER_MOCK_HTTP2_VISITOR_H_ + +#include + +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// A mock visitor class, for use in tests. +class QUICHE_NO_EXPORT MockHttp2Visitor : public Http2VisitorInterface { + public: + MockHttp2Visitor() { + ON_CALL(*this, OnFrameHeader).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnBeginHeadersForStream) + .WillByDefault(testing::Return(true)); + ON_CALL(*this, OnHeaderForStream).WillByDefault(testing::Return(HEADER_OK)); + ON_CALL(*this, OnEndHeadersForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnDataPaddingLength).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnBeginDataForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnDataForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnEndStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnCloseStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnGoAway).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnInvalidFrame).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnMetadataForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnMetadataEndForStream).WillByDefault(testing::Return(true)); + } + + MOCK_METHOD(int64_t, OnReadyToSend, (absl::string_view serialized), + (override)); + MOCK_METHOD(void, OnConnectionError, (ConnectionError error), (override)); + MOCK_METHOD(bool, OnFrameHeader, + (Http2StreamId stream_id, size_t length, uint8_t type, + uint8_t flags), + (override)); + MOCK_METHOD(void, OnSettingsStart, (), (override)); + MOCK_METHOD(void, OnSetting, (Http2Setting setting), (override)); + MOCK_METHOD(void, OnSettingsEnd, (), (override)); + MOCK_METHOD(void, OnSettingsAck, (), (override)); + MOCK_METHOD(bool, OnBeginHeadersForStream, (Http2StreamId stream_id), + (override)); + + MOCK_METHOD(OnHeaderResult, OnHeaderForStream, + (Http2StreamId stream_id, absl::string_view key, + absl::string_view value), + (override)); + + MOCK_METHOD(bool, OnEndHeadersForStream, (Http2StreamId stream_id), + (override)); + + MOCK_METHOD(bool, OnDataPaddingLength, + (Http2StreamId strema_id, size_t padding_length), (override)); + + MOCK_METHOD(bool, OnBeginDataForStream, + (Http2StreamId stream_id, size_t payload_length), (override)); + + MOCK_METHOD(bool, OnDataForStream, + (Http2StreamId stream_id, absl::string_view data), (override)); + + MOCK_METHOD(bool, OnEndStream, (Http2StreamId stream_id), (override)); + + MOCK_METHOD(void, OnRstStream, + (Http2StreamId stream_id, Http2ErrorCode error_code), (override)); + + MOCK_METHOD(bool, OnCloseStream, + (Http2StreamId stream_id, Http2ErrorCode error_code), (override)); + + MOCK_METHOD(void, OnPriorityForStream, + (Http2StreamId stream_id, Http2StreamId parent_stream_id, + int weight, bool exclusive), + (override)); + + MOCK_METHOD(void, OnPing, (Http2PingId ping_id, bool is_ack), (override)); + + MOCK_METHOD(void, OnPushPromiseForStream, + (Http2StreamId stream_id, Http2StreamId promised_stream_id), + (override)); + + MOCK_METHOD(bool, OnGoAway, + (Http2StreamId last_accepted_stream_id, Http2ErrorCode error_code, + absl::string_view opaque_data), + (override)); + + MOCK_METHOD(void, OnWindowUpdate, + (Http2StreamId stream_id, int window_increment), (override)); + + MOCK_METHOD(int, OnBeforeFrameSent, + (uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags), + (override)); + + MOCK_METHOD(int, OnFrameSent, + (uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code), + (override)); + + MOCK_METHOD(bool, OnInvalidFrame, + (Http2StreamId stream_id, InvalidFrameError error), (override)); + + MOCK_METHOD(void, OnBeginMetadataForStream, + (Http2StreamId stream_id, size_t payload_length), (override)); + + MOCK_METHOD(bool, OnMetadataForStream, + (Http2StreamId stream_id, absl::string_view metadata), + (override)); + + MOCK_METHOD(bool, OnMetadataEndForStream, (Http2StreamId stream_id), + (override)); + + MOCK_METHOD(void, OnErrorDebug, (absl::string_view message), (override)); +}; + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_MOCK_HTTP2_VISITOR_H_ diff --git a/quiche/http2/adapter/mock_nghttp2_callbacks.cc b/quiche/http2/adapter/mock_nghttp2_callbacks.cc new file mode 100644 index 000000000000..20ac327aab00 --- /dev/null +++ b/quiche/http2/adapter/mock_nghttp2_callbacks.cc @@ -0,0 +1,130 @@ +#include "quiche/http2/adapter/mock_nghttp2_callbacks.h" + +#include "quiche/http2/adapter/nghttp2_util.h" + +namespace http2 { +namespace adapter { +namespace test { + +/* static */ +nghttp2_session_callbacks_unique_ptr MockNghttp2Callbacks::GetCallbacks() { + nghttp2_session_callbacks* callbacks; + nghttp2_session_callbacks_new(&callbacks); + + // All of the callback implementations below just delegate to the mock methods + // of |user_data|, which is assumed to be a MockNghttp2Callbacks*. + nghttp2_session_callbacks_set_send_callback( + callbacks, + [](nghttp2_session*, const uint8_t* data, size_t length, int flags, + void* user_data) -> ssize_t { + return static_cast(user_data)->Send(data, length, + flags); + }); + + nghttp2_session_callbacks_set_send_data_callback( + callbacks, + [](nghttp2_session*, nghttp2_frame* frame, const uint8_t* framehd, + size_t length, nghttp2_data_source* source, void* user_data) -> int { + return static_cast(user_data)->SendData( + frame, framehd, length, source); + }); + + nghttp2_session_callbacks_set_on_begin_headers_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->OnBeginHeaders( + frame); + }); + + nghttp2_session_callbacks_set_on_header_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, const uint8_t* raw_name, + size_t name_length, const uint8_t* raw_value, size_t value_length, + uint8_t flags, void* user_data) -> int { + absl::string_view name = ToStringView(raw_name, name_length); + absl::string_view value = ToStringView(raw_value, value_length); + return static_cast(user_data)->OnHeader( + frame, name, value, flags); + }); + + nghttp2_session_callbacks_set_on_data_chunk_recv_callback( + callbacks, + [](nghttp2_session*, uint8_t flags, int32_t stream_id, + const uint8_t* data, size_t len, void* user_data) -> int { + absl::string_view chunk = ToStringView(data, len); + return static_cast(user_data)->OnDataChunkRecv( + flags, stream_id, chunk); + }); + + nghttp2_session_callbacks_set_on_begin_frame_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame_hd* hd, void* user_data) -> int { + return static_cast(user_data)->OnBeginFrame(hd); + }); + + nghttp2_session_callbacks_set_on_frame_recv_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->OnFrameRecv( + frame); + }); + + nghttp2_session_callbacks_set_on_stream_close_callback( + callbacks, + [](nghttp2_session*, int32_t stream_id, uint32_t error_code, + void* user_data) -> int { + return static_cast(user_data)->OnStreamClose( + stream_id, error_code); + }); + + nghttp2_session_callbacks_set_on_frame_send_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->OnFrameSend( + frame); + }); + + nghttp2_session_callbacks_set_before_frame_send_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->BeforeFrameSend( + frame); + }); + + nghttp2_session_callbacks_set_on_frame_not_send_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, int lib_error_code, + void* user_data) -> int { + return static_cast(user_data)->OnFrameNotSend( + frame, lib_error_code); + }); + + nghttp2_session_callbacks_set_on_invalid_frame_recv_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, int error_code, + void* user_data) -> int { + return static_cast(user_data) + ->OnInvalidFrameRecv(frame, error_code); + }); + + nghttp2_session_callbacks_set_error_callback2( + callbacks, + [](nghttp2_session* /*session*/, int lib_error_code, const char* msg, + size_t len, void* user_data) -> int { + return static_cast(user_data)->OnErrorCallback2( + lib_error_code, msg, len); + }); + + nghttp2_session_callbacks_set_pack_extension_callback( + callbacks, + [](nghttp2_session*, uint8_t* buf, size_t len, const nghttp2_frame* frame, + void* user_data) -> ssize_t { + return static_cast(user_data)->OnPackExtension( + buf, len, frame); + }); + return MakeCallbacksPtr(callbacks); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/mock_nghttp2_callbacks.h b/quiche/http2/adapter/mock_nghttp2_callbacks.h new file mode 100644 index 000000000000..1bda75636194 --- /dev/null +++ b/quiche/http2/adapter/mock_nghttp2_callbacks.h @@ -0,0 +1,70 @@ +#ifndef QUICHE_HTTP2_ADAPTER_MOCK_NGHTTP2_CALLBACKS_H_ +#define QUICHE_HTTP2_ADAPTER_MOCK_NGHTTP2_CALLBACKS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// This class provides a set of mock nghttp2 callbacks for use in unit test +// expectations. +class QUICHE_NO_EXPORT MockNghttp2Callbacks { + public: + MockNghttp2Callbacks() = default; + + // The caller takes ownership of the |nghttp2_session_callbacks|. + static nghttp2_session_callbacks_unique_ptr GetCallbacks(); + + MOCK_METHOD(ssize_t, Send, (const uint8_t* data, size_t length, int flags), + ()); + + MOCK_METHOD(int, SendData, + (nghttp2_frame * frame, const uint8_t* framehd, size_t length, + nghttp2_data_source* source), + ()); + + MOCK_METHOD(int, OnBeginHeaders, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, OnHeader, + (const nghttp2_frame* frame, absl::string_view name, + absl::string_view value, uint8_t flags), + ()); + + MOCK_METHOD(int, OnDataChunkRecv, + (uint8_t flags, int32_t stream_id, absl::string_view data), ()); + + MOCK_METHOD(int, OnBeginFrame, (const nghttp2_frame_hd* hd), ()); + + MOCK_METHOD(int, OnFrameRecv, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, OnStreamClose, (int32_t stream_id, uint32_t error_code), ()); + + MOCK_METHOD(int, BeforeFrameSend, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, OnFrameSend, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, OnFrameNotSend, + (const nghttp2_frame* frame, int lib_error_code), ()); + + MOCK_METHOD(int, OnInvalidFrameRecv, + (const nghttp2_frame* frame, int error_code), ()); + + MOCK_METHOD(int, OnErrorCallback2, + (int lib_error_code, const char* msg, size_t len), ()); + + MOCK_METHOD(ssize_t, OnPackExtension, + (uint8_t * buf, size_t len, const nghttp2_frame* frame), ()); +}; + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_MOCK_NGHTTP2_CALLBACKS_H_ diff --git a/quiche/http2/adapter/nghttp2.h b/quiche/http2/adapter/nghttp2.h new file mode 100644 index 000000000000..eed3c8621296 --- /dev/null +++ b/quiche/http2/adapter/nghttp2.h @@ -0,0 +1,11 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_H_ + +#include + +// Required to build on Windows. +using ssize_t = ptrdiff_t; + +#include "nghttp2/nghttp2.h" + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_H_ diff --git a/quiche/http2/adapter/nghttp2_adapter.cc b/quiche/http2/adapter/nghttp2_adapter.cc new file mode 100644 index 000000000000..f724ea395ccf --- /dev/null +++ b/quiche/http2/adapter/nghttp2_adapter.cc @@ -0,0 +1,309 @@ +#include "quiche/http2/adapter/nghttp2_adapter.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/http2/adapter/nghttp2_callbacks.h" +#include "quiche/http2/adapter/nghttp2_data_provider.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace http2 { +namespace adapter { + +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; + +// A metadata source that deletes itself upon completion. +class SelfDeletingMetadataSource : public MetadataSource { + public: + explicit SelfDeletingMetadataSource(std::unique_ptr source) + : source_(std::move(source)) {} + + size_t NumFrames(size_t max_frame_size) const override { + return source_->NumFrames(max_frame_size); + } + + std::pair Pack(uint8_t* dest, size_t dest_len) override { + const auto result = source_->Pack(dest, dest_len); + if (result.first < 0 || result.second) { + delete this; + } + return result; + } + + void OnFailure() override { + source_->OnFailure(); + delete this; + } + + private: + std::unique_ptr source_; +}; + +} // anonymous namespace + +/* static */ +std::unique_ptr NgHttp2Adapter::CreateClientAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options) { + auto adapter = new NgHttp2Adapter(visitor, Perspective::kClient, options); + adapter->Initialize(); + return absl::WrapUnique(adapter); +} + +/* static */ +std::unique_ptr NgHttp2Adapter::CreateServerAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options) { + auto adapter = new NgHttp2Adapter(visitor, Perspective::kServer, options); + adapter->Initialize(); + return absl::WrapUnique(adapter); +} + +bool NgHttp2Adapter::IsServerSession() const { + int result = nghttp2_session_check_server_session(session_->raw_ptr()); + QUICHE_DCHECK_EQ(perspective_ == Perspective::kServer, result > 0); + return result > 0; +} + +int64_t NgHttp2Adapter::ProcessBytes(absl::string_view bytes) { + const int64_t processed_bytes = session_->ProcessBytes(bytes); + if (processed_bytes < 0) { + visitor_.OnConnectionError(ConnectionError::kParseError); + } + return processed_bytes; +} + +void NgHttp2Adapter::SubmitSettings(absl::Span settings) { + // Submit SETTINGS, converting each Http2Setting to an nghttp2_settings_entry. + std::vector nghttp2_settings; + absl::c_transform(settings, std::back_inserter(nghttp2_settings), + [](const Http2Setting& setting) { + return nghttp2_settings_entry{setting.id, setting.value}; + }); + nghttp2_submit_settings(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + nghttp2_settings.data(), nghttp2_settings.size()); +} + +void NgHttp2Adapter::SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, bool exclusive) { + nghttp2_priority_spec priority_spec; + nghttp2_priority_spec_init(&priority_spec, parent_stream_id, weight, + static_cast(exclusive)); + nghttp2_submit_priority(session_->raw_ptr(), NGHTTP2_FLAG_NONE, stream_id, + &priority_spec); +} + +void NgHttp2Adapter::SubmitPing(Http2PingId ping_id) { + uint8_t opaque_data[8] = {}; + Http2PingId ping_id_to_serialize = quiche::QuicheEndian::HostToNet64(ping_id); + std::memcpy(opaque_data, &ping_id_to_serialize, sizeof(Http2PingId)); + nghttp2_submit_ping(session_->raw_ptr(), NGHTTP2_FLAG_NONE, opaque_data); +} + +void NgHttp2Adapter::SubmitShutdownNotice() { + nghttp2_submit_shutdown_notice(session_->raw_ptr()); +} + +void NgHttp2Adapter::SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + nghttp2_submit_goaway(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + last_accepted_stream_id, + static_cast(error_code), + ToUint8Ptr(opaque_data.data()), opaque_data.size()); +} + +void NgHttp2Adapter::SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) { + nghttp2_submit_window_update(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + stream_id, window_increment); +} + +void NgHttp2Adapter::SubmitMetadata(Http2StreamId stream_id, + size_t max_frame_size, + std::unique_ptr source) { + auto* wrapped_source = new SelfDeletingMetadataSource(std::move(source)); + const size_t num_frames = wrapped_source->NumFrames(max_frame_size); + size_t num_successes = 0; + for (size_t i = 1; i <= num_frames; ++i) { + const int result = nghttp2_submit_extension( + session_->raw_ptr(), kMetadataFrameType, + i == num_frames ? kMetadataEndFlag : 0, stream_id, wrapped_source); + if (result != 0) { + QUICHE_LOG(DFATAL) << "Failed to submit extension frame " << i << " of " + << num_frames; + break; + } + ++num_successes; + } + if (num_successes == 0) { + delete wrapped_source; + } +} + +int NgHttp2Adapter::Send() { + const int result = nghttp2_session_send(session_->raw_ptr()); + if (result != 0) { + QUICHE_VLOG(1) << "nghttp2_session_send returned " << result; + visitor_.OnConnectionError(ConnectionError::kSendError); + } + return result; +} + +int NgHttp2Adapter::GetSendWindowSize() const { + return session_->GetRemoteWindowSize(); +} + +int NgHttp2Adapter::GetStreamSendWindowSize(Http2StreamId stream_id) const { + return nghttp2_session_get_stream_remote_window_size(session_->raw_ptr(), + stream_id); +} + +int NgHttp2Adapter::GetStreamReceiveWindowLimit(Http2StreamId stream_id) const { + return nghttp2_session_get_stream_effective_local_window_size( + session_->raw_ptr(), stream_id); +} + +int NgHttp2Adapter::GetStreamReceiveWindowSize(Http2StreamId stream_id) const { + return nghttp2_session_get_stream_local_window_size(session_->raw_ptr(), + stream_id); +} + +int NgHttp2Adapter::GetReceiveWindowSize() const { + return nghttp2_session_get_local_window_size(session_->raw_ptr()); +} + +int NgHttp2Adapter::GetHpackEncoderDynamicTableSize() const { + return nghttp2_session_get_hd_deflate_dynamic_table_size(session_->raw_ptr()); +} + +int NgHttp2Adapter::GetHpackDecoderDynamicTableSize() const { + return nghttp2_session_get_hd_inflate_dynamic_table_size(session_->raw_ptr()); +} + +Http2StreamId NgHttp2Adapter::GetHighestReceivedStreamId() const { + return nghttp2_session_get_last_proc_stream_id(session_->raw_ptr()); +} + +void NgHttp2Adapter::MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) { + int rc = session_->Consume(stream_id, num_bytes); + if (rc != 0) { + QUICHE_LOG(ERROR) << "Error " << rc << " marking " << num_bytes + << " bytes consumed for stream " << stream_id; + } +} + +void NgHttp2Adapter::SubmitRst(Http2StreamId stream_id, + Http2ErrorCode error_code) { + int status = + nghttp2_submit_rst_stream(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + stream_id, static_cast(error_code)); + if (status < 0) { + QUICHE_LOG(WARNING) << "Reset stream failed: " << stream_id + << " with status code " << status; + } +} + +int32_t NgHttp2Adapter::SubmitRequest( + absl::Span headers, + std::unique_ptr data_source, void* stream_user_data) { + auto nvs = GetNghttp2Nvs(headers); + std::unique_ptr provider = + MakeDataProvider(data_source.get()); + + int32_t stream_id = + nghttp2_submit_request(session_->raw_ptr(), nullptr, nvs.data(), + nvs.size(), provider.get(), stream_user_data); + sources_.emplace(stream_id, std::move(data_source)); + QUICHE_VLOG(1) << "Submitted request with " << nvs.size() + << " request headers and user data " << stream_user_data + << "; resulted in stream " << stream_id; + return stream_id; +} + +int NgHttp2Adapter::SubmitResponse( + Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) { + auto nvs = GetNghttp2Nvs(headers); + std::unique_ptr provider = + MakeDataProvider(data_source.get()); + + sources_.emplace(stream_id, std::move(data_source)); + + int result = nghttp2_submit_response(session_->raw_ptr(), stream_id, + nvs.data(), nvs.size(), provider.get()); + QUICHE_VLOG(1) << "Submitted response with " << nvs.size() + << " response headers; result = " << result; + return result; +} + +int NgHttp2Adapter::SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) { + auto nvs = GetNghttp2Nvs(trailers); + int result = nghttp2_submit_trailer(session_->raw_ptr(), stream_id, + nvs.data(), nvs.size()); + QUICHE_VLOG(1) << "Submitted trailers with " << nvs.size() + << " response trailers; result = " << result; + return result; +} + +void NgHttp2Adapter::SetStreamUserData(Http2StreamId stream_id, + void* stream_user_data) { + nghttp2_session_set_stream_user_data(session_->raw_ptr(), stream_id, + stream_user_data); +} + +void* NgHttp2Adapter::GetStreamUserData(Http2StreamId stream_id) { + return nghttp2_session_get_stream_user_data(session_->raw_ptr(), stream_id); +} + +bool NgHttp2Adapter::ResumeStream(Http2StreamId stream_id) { + return 0 == nghttp2_session_resume_data(session_->raw_ptr(), stream_id); +} + +void NgHttp2Adapter::RemoveStream(Http2StreamId stream_id) { + sources_.erase(stream_id); +} + +NgHttp2Adapter::NgHttp2Adapter(Http2VisitorInterface& visitor, + Perspective perspective, + const nghttp2_option* options) + : Http2Adapter(visitor), + visitor_(visitor), + options_(options), + perspective_(perspective) {} + +NgHttp2Adapter::~NgHttp2Adapter() {} + +void NgHttp2Adapter::Initialize() { + nghttp2_option* owned_options = nullptr; + if (options_ == nullptr) { + nghttp2_option_new(&owned_options); + // Set some common options for compatibility. + nghttp2_option_set_no_closed_streams(owned_options, 1); + nghttp2_option_set_no_auto_window_update(owned_options, 1); + nghttp2_option_set_max_send_header_block_length(owned_options, 0x2000000); + nghttp2_option_set_max_outbound_ack(owned_options, 10000); + nghttp2_option_set_user_recv_extension_type(owned_options, + kMetadataFrameType); + options_ = owned_options; + } + + session_ = + std::make_unique(perspective_, callbacks::Create(), + options_, static_cast(&visitor_)); + if (owned_options != nullptr) { + nghttp2_option_del(owned_options); + } + options_ = nullptr; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_adapter.h b/quiche/http2/adapter/nghttp2_adapter.h new file mode 100644 index 000000000000..7ae9645aaebc --- /dev/null +++ b/quiche/http2/adapter/nghttp2_adapter.h @@ -0,0 +1,115 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_ADAPTER_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_ADAPTER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/http2/adapter/http2_adapter.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/nghttp2_session.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +class QUICHE_EXPORT NgHttp2Adapter : public Http2Adapter { + public: + ~NgHttp2Adapter() override; + + // Creates an adapter that functions as a client. Does not take ownership of + // |options|. + static std::unique_ptr CreateClientAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options = nullptr); + + // Creates an adapter that functions as a server. Does not take ownership of + // |options|. + static std::unique_ptr CreateServerAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options = nullptr); + + bool IsServerSession() const override; + bool want_read() const override { return session_->want_read(); } + bool want_write() const override { return session_->want_write(); } + + int64_t ProcessBytes(absl::string_view bytes) override; + void SubmitSettings(absl::Span settings) override; + void SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, int weight, + bool exclusive) override; + + // Submits a PING on the connection. Note that nghttp2 automatically submits + // PING acks upon receiving non-ack PINGs from the peer, so callers only use + // this method to originate PINGs. See nghttp2_option_set_no_auto_ping_ack(). + void SubmitPing(Http2PingId ping_id) override; + + void SubmitShutdownNotice() override; + void SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + + void SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) override; + + void SubmitRst(Http2StreamId stream_id, Http2ErrorCode error_code) override; + + void SubmitMetadata(Http2StreamId stream_id, size_t max_frame_size, + std::unique_ptr source) override; + + int Send() override; + + int GetSendWindowSize() const override; + int GetStreamSendWindowSize(Http2StreamId stream_id) const override; + + int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const override; + int GetStreamReceiveWindowSize(Http2StreamId stream_id) const override; + int GetReceiveWindowSize() const override; + + int GetHpackEncoderDynamicTableSize() const override; + int GetHpackDecoderDynamicTableSize() const override; + + Http2StreamId GetHighestReceivedStreamId() const override; + + void MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) override; + + int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data) override; + + int SubmitResponse(Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) override; + + int SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) override; + + void SetStreamUserData(Http2StreamId stream_id, void* user_data) override; + void* GetStreamUserData(Http2StreamId stream_id) override; + + bool ResumeStream(Http2StreamId stream_id) override; + + // Removes references to the `stream_id` from this adapter. + void RemoveStream(Http2StreamId stream_id); + + // Accessor for testing. + size_t sources_size() const { return sources_.size(); } + + private: + NgHttp2Adapter(Http2VisitorInterface& visitor, Perspective perspective, + const nghttp2_option* options); + + // Performs any necessary initialization of the underlying HTTP/2 session, + // such as preparing initial SETTINGS. + void Initialize(); + + std::unique_ptr session_; + Http2VisitorInterface& visitor_; + const nghttp2_option* options_; + Perspective perspective_; + + absl::flat_hash_map> sources_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_ADAPTER_H_ diff --git a/quiche/http2/adapter/nghttp2_adapter_test.cc b/quiche/http2/adapter/nghttp2_adapter_test.cc new file mode 100644 index 000000000000..464135ad877e --- /dev/null +++ b/quiche/http2/adapter/nghttp2_adapter_test.cc @@ -0,0 +1,7196 @@ +#include "quiche/http2/adapter/nghttp2_adapter.h" + +#include + +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/mock_http2_visitor.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/http2/adapter/nghttp2_test_utils.h" +#include "quiche/http2/adapter/oghttp2_util.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; + +using spdy::SpdyFrameType; +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +// This send callback assumes |source|'s pointer is a TestDataSource, and +// |user_data| is a Http2VisitorInterface. +int TestSendCallback(nghttp2_session*, nghttp2_frame* /*frame*/, + const uint8_t* framehd, size_t length, + nghttp2_data_source* source, void* user_data) { + auto* visitor = static_cast(user_data); + // Send the frame header via the visitor. + ssize_t result = visitor->OnReadyToSend(ToStringView(framehd, 9)); + if (result == 0) { + return NGHTTP2_ERR_WOULDBLOCK; + } + auto* test_source = static_cast(source->ptr); + absl::string_view payload = test_source->ReadNext(length); + // Send the frame payload via the visitor. + visitor->OnReadyToSend(payload); + return 0; +} + +TEST(NgHttp2AdapterTest, ClientConstruction) { + testing::StrictMock visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + ASSERT_NE(nullptr, adapter); + EXPECT_TRUE(adapter->want_read()); + EXPECT_FALSE(adapter->want_write()); + EXPECT_FALSE(adapter->IsServerSession()); +} + +TEST(NgHttp2AdapterTest, ClientHandlesFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_EQ(adapter->GetSendWindowSize(), kInitialFlowControlWindowSize + 1000); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x1, 0)); + + result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); + visitor.Clear(); + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const std::vector
headers3 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const char* kSentinel3 = "arbitrary pointer 3"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id2; + + const int32_t stream_id3 = + adapter->SubmitRequest(headers3, nullptr, const_cast(kSentinel3)); + ASSERT_GT(stream_id3, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id3; + + const char* kSentinel2 = "arbitrary pointer 2"; + adapter->SetStreamUserData(stream_id2, const_cast(kSentinel2)); + adapter->SetStreamUserData(stream_id3, nullptr); + + EXPECT_EQ(adapter->sources_size(), 3); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id3, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id3, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::HEADERS, + SpdyFrameType::HEADERS})); + visitor.Clear(); + + // All streams are active and have not yet received any data, so the receive + // window should be at the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id1)); + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id2)); + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id3)); + + // Upper bound on the flow control receive window should be the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(stream_id1)); + + // Connection has not yet received any data. + EXPECT_EQ(kInitialFlowControlWindowSize, adapter->GetReceiveWindowSize()); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(stream_id1)); + EXPECT_EQ(kSentinel2, adapter->GetStreamUserData(stream_id2)); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(stream_id3)); + + EXPECT_EQ(0, adapter->GetHpackDecoderDynamicTableSize()); + + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .RstStream(3, Http2ErrorCode::INTERNAL_ERROR) + .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR)) + .WillOnce( + [&adapter](Http2StreamId stream_id, Http2ErrorCode /*error_code*/) { + adapter->RemoveStream(stream_id); + return true; + }); + EXPECT_CALL(visitor, OnFrameHeader(0, 19, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!")); + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + // First stream has received some data. + EXPECT_GT(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id1)); + // Second stream was closed. + EXPECT_EQ(-1, adapter->GetStreamReceiveWindowSize(stream_id2)); + // Third stream has not received any data. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id3)); + + // One stream was closed. + EXPECT_EQ(adapter->sources_size(), 2); + + // Connection window should be the same as the first stream. + EXPECT_EQ(adapter->GetReceiveWindowSize(), + adapter->GetStreamReceiveWindowSize(stream_id1)); + + // Upper bound on the flow control receive window should still be the initial + // value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(stream_id1)); + + EXPECT_GT(adapter->GetHpackDecoderDynamicTableSize(), 0); + + // Should be 3, but this method only works for server adapters. + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + // Even though the client recieved a GOAWAY, streams 1 and 5 are still active. + EXPECT_TRUE(adapter->want_read()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 0, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 0)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)) + .WillOnce( + [&adapter](Http2StreamId stream_id, Http2ErrorCode /*error_code*/) { + adapter->RemoveStream(stream_id); + return true; + }); + EXPECT_CALL(visitor, OnFrameHeader(5, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(5, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM)) + .WillOnce( + [&adapter](Http2StreamId stream_id, Http2ErrorCode /*error_code*/) { + adapter->RemoveStream(stream_id); + return true; + }); + adapter->ProcessBytes(TestFrameSequence() + .Data(1, "", true) + .RstStream(5, Http2ErrorCode::REFUSED_STREAM) + .Serialize()); + + // Should be 5, but this method only works for server adapters. + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + // After receiving END_STREAM for 1 and RST_STREAM for 5, the session no + // longer expects reads. + EXPECT_FALSE(adapter->want_read()); + EXPECT_EQ(adapter->sources_size(), 0); + + // Client will not have anything else to write. + EXPECT_FALSE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(NgHttp2AdapterTest, QueuingWindowUpdateAffectsWindow) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + EXPECT_EQ(adapter->GetReceiveWindowSize(), kInitialFlowControlWindowSize); + adapter->SubmitWindowUpdate(0, 10000); + EXPECT_EQ(adapter->GetReceiveWindowSize(), + kInitialFlowControlWindowSize + 10000); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id), + kInitialFlowControlWindowSize); + adapter->SubmitWindowUpdate(1, 20000); + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id), + kInitialFlowControlWindowSize + 20000); +} + +TEST(NgHttp2AdapterTest, AckOfSettingInitialWindowSizeAffectsWindow) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id1 = adapter->SubmitRequest(headers, nullptr, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + int64_t parse_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(parse_result)); + + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), + kInitialFlowControlWindowSize); + adapter->SubmitSettings({{INITIAL_WINDOW_SIZE, 80000u}}); + // No update for the first stream, yet. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), + kInitialFlowControlWindowSize); + + // Ack of server's initial settings. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + // Outbound SETTINGS containing INITIAL_WINDOW_SIZE. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + // Still no update, as a SETTINGS ack has not yet been received. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), + kInitialFlowControlWindowSize); + + const std::string settings_ack = + TestFrameSequence().SettingsAck().Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + parse_result = adapter->ProcessBytes(settings_ack); + EXPECT_EQ(settings_ack.size(), static_cast(parse_result)); + + // Stream window has been updated. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), 80000); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + const int32_t stream_id2 = adapter->SubmitRequest(headers, nullptr, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x5, 0)); + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id2), 80000); +} + +TEST(NgHttp2AdapterTest, ClientRejects100HeadersWithFin) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, /*fin=*/false) + .Headers(1, {{":status", "100"}}, /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientRejects100HeadersWithContent) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, + /*fin=*/false) + .Data(1, "We needed the final headers before data, whoops") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientRejects100HeadersWithContentLength) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}, {"content-length", "42"}}, + /*fin=*/false) + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [content-length], value: [42]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientHandles204WithContent) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "204"}, {"content-length", "2"}}, + /*fin=*/false) + .Data(1, "hi") + .Headers(3, {{":status", "204"}}, /*fin=*/false) + .Data(3, "hi") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "204")); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [content-length], value: [2]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":status", "204")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 2)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientHandles304WithContent) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "304"}, {"content-length", "2"}}, + /*fin=*/false) + .Data(1, "hi") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "304")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "2")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 2)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientHandles304WithContentLength) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "304"}, {"content-length", "2"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "304")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "2")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{"final-status", "A-OK"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, "final-status", "A-OK")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientSendsTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const std::string kBody = "This is an example request body."; + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + // nghttp2 does not require that the data source indicate the end of data + // before trailers are enqueued. + + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, std::move(body1), nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id1, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + visitor.Clear(); + + const std::vector
trailers1 = + ToHeaders({{"extra-info", "Trailers are weird but good?"}}); + adapter->SubmitTrailer(stream_id1, trailers1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + data = visitor.data(); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesMetadata) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesMetadataWithEmptyPayload) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(3); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); +} + +TEST(NgHttp2AdapterTest, ClientHandlesMetadataWithError) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)) + .WillOnce(testing::Return(false)); + // Remaining frames are not processed due to the error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + // The false return from OnMetadataForStream() results in a connection error. + EXPECT_EQ(stream_result, NGHTTP2_ERR_CALLBACK_FAILURE); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_TRUE(adapter->want_read()); // Even after an error. Why? + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesHpackHeaderTableSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"x-i-do-not-like", "green eggs and ham"}, + {"x-i-will-not-eat-them", "here or there, in a box, with a fox"}, + {"x-like-them-in-a-house", "no"}, + {"x-like-them-with-a-mouse", "no"}, + }); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 100); + + const std::string stream_frames = + TestFrameSequence().Settings({{HEADER_TABLE_SIZE, 100u}}).Serialize(); + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 100u})); + + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_LE(adapter->GetHpackEncoderDynamicTableSize(), 100); +} + +TEST(NgHttp2AdapterTest, ClientHandlesInvalidTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{":bad-status", "9000"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [:bad-status], value: [9000]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + // Bad status trailer will cause a PROTOCOL_ERROR. The header is never + // delivered in an OnHeaderForStream callback. + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientRstStreamWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce(testing::DoAll( + testing::InvokeWithoutArgs([&adapter]() { + adapter->SubmitRst(1, Http2ErrorCode::REFUSED_STREAM); + }), + testing::Return(Http2VisitorInterface::HEADER_RST_STREAM))); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::REFUSED_STREAM)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + // Translation to nghttp2 treats this error as a general parsing error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(-902 /* NGHTTP2_ERR_CALLBACK_FAILURE */, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientConnectionErrorWhileHandlingHeadersOnly) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + // Translation to nghttp2 treats this error as a general parsing error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(-902 /* NGHTTP2_ERR_CALLBACK_FAILURE */, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientRejectsHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)) + .WillOnce(testing::Return(false)); + // Rejecting headers leads to a connection error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientStartsShutdown) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + EXPECT_FALSE(adapter->want_write()); + + // No-op for a client implementation. + adapter->SubmitShutdownNotice(); + EXPECT_FALSE(adapter->want_write()); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(visitor.data(), spdy::kHttp2ConnectionHeaderPrefix); +} + +TEST(NgHttp2AdapterTest, ClientReceivesGoAway) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Submit a pending WINDOW_UPDATE for a stream that will be closed due to + // GOAWAY. The WINDOW_UPDATE should not be sent. + adapter->SubmitWindowUpdate(3, 42); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .RstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM) + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .WindowUpdate(0, 42) + .WindowUpdate(1, 42) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion")); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 42)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + // SETTINGS ack (but only after the enqueue of the seemingly unrelated + // WINDOW_UPDATE). The WINDOW_UPDATE is not written. + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientReceivesMultipleGoAways) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface() + .GoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, + "indigestion")); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Submit a WINDOW_UPDATE for the open stream. Because the stream is below the + // GOAWAY's last_stream_id, it should be sent. + adapter->SubmitWindowUpdate(1, 42); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 1, 4, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::WINDOW_UPDATE})); + visitor.Clear(); + + const std::string final_frames = + TestFrameSequence() + .GoAway(0, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(0, Http2ErrorCode::INTERNAL_ERROR, "indigestion")); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::REFUSED_STREAM)); + + const int64_t final_result = adapter->ProcessBytes(final_frames); + EXPECT_EQ(final_frames.size(), static_cast(final_result)); + + EXPECT_FALSE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(NgHttp2AdapterTest, ClientReceivesMultipleGoAwaysWithIncreasingStreamId) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string frames = + TestFrameSequence() + .ServerPreface() + .GoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "") + .GoAway(0, Http2ErrorCode::ENHANCE_YOUR_CALM, "") + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "")); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::ENHANCE_YOUR_CALM, "")); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + + const int64_t frames_result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(frames_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ClientReceivesGoAwayWithPendingStreams) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + testing::InSequence s; + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{MAX_CONCURRENT_STREAMS, 1}}) + .Serialize(); + + // Server preface (SETTINGS with MAX_CONCURRENT_STREAMS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + // The second request should be pending because of + // SETTINGS_MAX_CONCURRENT_STREAMS. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Let the client receive a GOAWAY and raise MAX_CONCURRENT_STREAMS. Even + // though the GOAWAY last_stream_id is higher than the pending request's + // stream ID, pending request should not be sent. + const std::string stream_frames = + TestFrameSequence() + .GoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Settings({{MAX_CONCURRENT_STREAMS, 42u}}) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, + "indigestion")); + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{MAX_CONCURRENT_STREAMS, 42u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + // Nghttp2 closes the pending stream on the next write attempt. + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::REFUSED_STREAM)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Requests submitted after receiving the GOAWAY should not be sent. + const std::vector
headers3 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}}); + + const int32_t stream_id3 = adapter->SubmitRequest(headers3, nullptr, nullptr); + ASSERT_GT(stream_id3, stream_id2); + + // Nghttp2 closes the pending stream on the next write attempt. + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientFailsOnGoAway) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion")) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientRejects101Response) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"upgrade", "new-protocol"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "101"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [:status], value: [101]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_frames.size()), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientSubmitRequest) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHpackEncoderDynamicTableSize()); + EXPECT_FALSE(adapter->want_write()); + const char* kSentinel = ""; + const absl::string_view kBody = "This is an example request body."; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kBody); + body1->EndData(); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id)); + EXPECT_EQ(kInitialFlowControlWindowSize, adapter->GetReceiveWindowSize()); + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(stream_id)); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 0); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(adapter->GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamSendWindowSize(stream_id), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(-1, adapter->GetStreamSendWindowSize(stream_id + 2)); + + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + const char* kSentinel2 = "arbitrary pointer 2"; + EXPECT_EQ(nullptr, adapter->GetStreamUserData(stream_id)); + adapter->SetStreamUserData(stream_id, const_cast(kSentinel2)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + + EXPECT_EQ(kSentinel2, adapter->GetStreamUserData(stream_id)); + + // No data was sent (just HEADERS), so the remaining send window size should + // still be the default. + EXPECT_EQ(adapter->GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); +} + +// This is really a test of the MakeZeroCopyDataFrameSource adapter, but I +// wasn't sure where else to put it. +TEST(NgHttp2AdapterTest, ClientSubmitRequestWithDataProvider) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example request body."; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + EXPECT_FALSE(adapter->want_write()); +} + +// This test verifies how nghttp2 behaves when a data source becomes +// read-blocked. +TEST(NgHttp2AdapterTest, ClientSubmitRequestWithDataProviderAndReadBlock) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const absl::string_view kBody = "This is an example request body."; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kBody}; + body1.set_is_data_available(false); + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // Resume the deferred stream. + body1.set_is_data_available(true); + EXPECT_TRUE(adapter->ResumeStream(stream_id)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); + EXPECT_FALSE(adapter->want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(adapter->ResumeStream(stream_id)); + EXPECT_FALSE(adapter->want_write()); +} + +// This test verifies how nghttp2 behaves when a data source is read block, then +// ends with an empty DATA frame. +TEST(NgHttp2AdapterTest, ClientSubmitRequestEmptyDataWithFin) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const absl::string_view kEmptyBody = ""; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kEmptyBody}; + body1.set_is_data_available(false); + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // Resume the deferred stream. + body1.set_is_data_available(true); + EXPECT_TRUE(adapter->ResumeStream(stream_id)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); + EXPECT_FALSE(adapter->want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(adapter->ResumeStream(stream_id)); + EXPECT_FALSE(adapter->want_write()); +} + +// This test verifies how nghttp2 behaves when a connection becomes +// write-blocked. +TEST(NgHttp2AdapterTest, ClientSubmitRequestWithDataProviderAndWriteBlock) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const absl::string_view kBody = "This is an example request body."; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + visitor.set_is_write_blocked(true); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + visitor.set_is_write_blocked(false); + result = adapter->Send(); + EXPECT_EQ(0, result); + + // Client preface does not appear to include the mandatory SETTINGS frame. + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientReceivesDataOnClosedStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + // Client SETTINGS ack + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client open a stream with a request. + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Let the client RST_STREAM the stream it opened. + adapter->SubmitRst(stream_id, Http2ErrorCode::CANCEL); + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id, _, 0x0, + static_cast(Http2ErrorCode::CANCEL))); + EXPECT_CALL(visitor, OnCloseStream(stream_id, Http2ErrorCode::CANCEL)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::RST_STREAM})); + visitor.Clear(); + + // Let the server send a response on the stream. (It might not have received + // the RST_STREAM yet.) + const std::string response_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + // The visitor gets notified about the HEADERS frame but not the DATA frame on + // the closed stream. No further processing for either frame occurs. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, DATA, _)).Times(0); + + const int64_t response_result = adapter->ProcessBytes(response_frames); + EXPECT_EQ(response_frames.size(), response_result); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientQueuesRequests) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + adapter->SubmitSettings({}); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + adapter->Send(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{MAX_CONCURRENT_STREAMS, 2}}) + .SettingsAck() + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ + Http2KnownSettingsId::MAX_CONCURRENT_STREAMS, 2u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + + adapter->ProcessBytes(initial_frames); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/example/request"}}); + std::vector stream_ids; + // Start two, which hits the limit. + int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + // Start two more, which must be queued. + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[0], _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[0], _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[1], _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[1], _, 0x5, 0)); + + adapter->Send(); + + const std::string update_streams = + TestFrameSequence().Settings({{MAX_CONCURRENT_STREAMS, 5}}).Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ + Http2KnownSettingsId::MAX_CONCURRENT_STREAMS, 5u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + adapter->ProcessBytes(update_streams); + + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[2], _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[2], _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[3], _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[3], _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[4], _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[4], _, 0x5, 0)); + // Header frames should all have been sent in order, regardless of any + // queuing. + + adapter->Send(); +} + +TEST(NgHttp2AdapterTest, ClientAcceptsHeadResponseWithContentLength) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const std::vector
headers = ToHeaders({{":method", "HEAD"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + adapter->Send(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface() + .Headers(stream_id, {{":status", "200"}, {"content-length", "101"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(2); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnEndStream(stream_id)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + adapter->ProcessBytes(initial_frames); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + adapter->Send(); +} + +TEST(NgHttp2AdapterTest, SubmitMetadata) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, SubmitMetadataMultipleFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const auto kLargeValue = std::string(63 * 1024, 'a'); + auto source = std::make_unique( + ToHeaderBlock(ToHeaders({{"large-value", kLargeValue}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence seq; + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, SubmitConnectionMetadata) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(0, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 0, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 0, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientObeysMaxConcurrentStreams) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{MAX_CONCURRENT_STREAMS, 1}}) + .Serialize(); + testing::InSequence s; + + // Server preface (SETTINGS with MAX_CONCURRENT_STREAMS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example request body."; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + const int next_stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + + // A new pending stream is created, but because of MAX_CONCURRENT_STREAMS, the + // session should not want to write it at the moment. + EXPECT_GT(next_stream_id, stream_id); + EXPECT_FALSE(adapter->want_write()); + + const std::string stream_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, + OnHeaderForStream(stream_id, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, "date", + "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 26, DATA, 0x1)); + EXPECT_CALL(visitor, OnBeginDataForStream(stream_id, 26)); + EXPECT_CALL(visitor, + OnDataForStream(stream_id, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(stream_id)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + // The first stream should close, which should make the session want to write + // the next stream. + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, next_stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, next_stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientReceivesInitialWindowSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const std::string initial_frames = + TestFrameSequence() + .Settings({{INITIAL_WINDOW_SIZE, 80000u}}) + .WindowUpdate(0, 65536) + .Serialize(); + // Server preface (SETTINGS with INITIAL_STREAM_WINDOW) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{INITIAL_WINDOW_SIZE, 80000u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 65536)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int64_t result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string kLongBody = std::string(81000, 'c'); + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kLongBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + // The client can send more than 4 frames (65536 bytes) of data. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 16384, 0x0, 0)).Times(4); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 14464, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA, + SpdyFrameType::DATA, SpdyFrameType::DATA, + SpdyFrameType::DATA, SpdyFrameType::DATA})); +} + +TEST(NgHttp2AdapterTest, ClientReceivesInitialWindowSettingAfterStreamStart) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().WindowUpdate(0, 65536).Serialize(); + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 65536)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int64_t result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string kLongBody = std::string(81000, 'c'); + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kLongBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + // The client can only send 65535 bytes of data, as the stream window has not + // yet been increased. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 16384, 0x0, 0)).Times(3); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 16383, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA, + SpdyFrameType::DATA, SpdyFrameType::DATA, + SpdyFrameType::DATA})); + visitor.Clear(); + + // Can't write any more due to flow control. + EXPECT_FALSE(adapter->want_write()); + + const std::string settings_frame = + TestFrameSequence().Settings({{INITIAL_WINDOW_SIZE, 80000u}}).Serialize(); + // SETTINGS with INITIAL_STREAM_WINDOW + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{INITIAL_WINDOW_SIZE, 80000u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t settings_result = adapter->ProcessBytes(settings_frame); + EXPECT_EQ(settings_frame.size(), static_cast(settings_result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + // The client can write more after receiving the INITIAL_WINDOW_SIZE setting. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 14465, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::DATA})); +} + +TEST(NgHttp2AdapterTest, InvalidInitialWindowSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const uint32_t kTooLargeInitialWindow = 1u << 31; + const std::string initial_frames = + TestFrameSequence() + .Settings({{INITIAL_WINDOW_SIZE, kTooLargeInitialWindow}}) + .Serialize(); + // Server preface (SETTINGS with INITIAL_STREAM_WINDOW) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 0, Http2VisitorInterface::InvalidFrameError::kFlowControl)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a GOAWAY. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + int64_t result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::GOAWAY})); + visitor.Clear(); +} + +TEST(NgHttp2AdapterTest, InitialWindowSettingCausesOverflow) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + int64_t write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const uint32_t kLargeInitialWindow = (1u << 31) - 1; + const std::string frames = + TestFrameSequence() + .ServerPreface() + .Headers(stream_id, {{":status", "200"}}, /*fin=*/false) + .WindowUpdate(stream_id, 65536u) + .Settings({{INITIAL_WINDOW_SIZE, kLargeInitialWindow}}) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 4, WINDOW_UPDATE, 0x0)); + EXPECT_CALL(visitor, OnWindowUpdate(stream_id, 65536)); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{INITIAL_WINDOW_SIZE, + kLargeInitialWindow})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + // The stream window update plus the SETTINGS frame with INITIAL_WINDOW_SIZE + // pushes the stream's flow control window outside of the acceptable range. + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, stream_id, 4, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::FLOW_CONTROL_ERROR)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientForbidsPushPromise) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + adapter->SubmitSettings({{ENABLE_PUSH, 0}}); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::vector
push_headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/push"}}); + const std::string frames = TestFrameSequence() + .ServerPreface() + .SettingsAck() + .PushPromise(stream_id, 2, push_headers) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The PUSH_PROMISE is now treated as an invalid frame. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, PUSH_PROMISE, _)); + EXPECT_CALL(visitor, OnInvalidFrame(stream_id, _)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), read_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); +} + +TEST(NgHttp2AdapterTest, ClientForbidsPushStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + adapter->SubmitSettings({{ENABLE_PUSH, 0}}); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(2, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The push HEADERS are invalid. + EXPECT_CALL(visitor, OnFrameHeader(2, _, HEADERS, _)); + EXPECT_CALL(visitor, OnInvalidFrame(2, _)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), read_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); +} + +TEST(NgHttp2AdapterTest, FailureSendingConnectionPreface) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + visitor.set_has_write_error(); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int result = adapter->Send(); + EXPECT_EQ(result, NGHTTP2_ERR_CALLBACK_FAILURE); +} + +TEST(NgHttp2AdapterTest, MaxFrameSizeSettingNotAppliedBeforeAck) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const uint32_t large_frame_size = kDefaultFramePayloadSizeLimit + 42; + adapter->SubmitSettings({{MAX_FRAME_SIZE, large_frame_size}}); + const int32_t stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + /*data_source=*/nullptr, /*user_data=*/nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence s; + + // Client preface (SETTINGS with MAX_FRAME_SIZE) and request HEADERS + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string server_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "200"}}, /*fin=*/false) + .Data(1, std::string(large_frame_size, 'a')) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Response HEADERS. Because the SETTINGS with MAX_FRAME_SIZE was not + // acknowledged, the large DATA is treated as a connection error. Note that + // nghttp2 does not deliver any DATA or connection error events. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t process_result = adapter->ProcessBytes(server_frames); + EXPECT_EQ(server_frames.size(), static_cast(process_result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FRAME_SIZE_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, MaxFrameSizeSettingAppliedAfterAck) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const uint32_t large_frame_size = kDefaultFramePayloadSizeLimit + 42; + adapter->SubmitSettings({{MAX_FRAME_SIZE, large_frame_size}}); + const int32_t stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + /*data_source=*/nullptr, /*user_data=*/nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence s; + + // Client preface (SETTINGS with MAX_FRAME_SIZE) and request HEADERS + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string server_frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(1, {{":status", "200"}}, /*fin=*/false) + .Data(1, std::string(large_frame_size, 'a')) + .Serialize(); + + // Server preface (empty SETTINGS) and ack of SETTINGS. + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + + // Response HEADERS and DATA. Because the SETTINGS with MAX_FRAME_SIZE was + // acknowledged, the large DATA is accepted without any error. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, large_frame_size, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, large_frame_size)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + + const int64_t process_result = adapter->ProcessBytes(server_frames); + EXPECT_EQ(server_frames.size(), static_cast(process_result)); + + // Client ack of SETTINGS. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, WindowUpdateRaisesFlowControlWindowLimit) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string data_chunk(kDefaultFramePayloadSizeLimit, 'a'); + const std::string request = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/false) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + adapter->ProcessBytes(request); + + // Updates the advertised window for the connection and stream 1. + adapter->SubmitWindowUpdate(0, 2 * kDefaultFramePayloadSizeLimit); + adapter->SubmitWindowUpdate(1, 2 * kDefaultFramePayloadSizeLimit); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 1, 4, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + // Verifies the advertised window. + EXPECT_EQ(kInitialFlowControlWindowSize + 2 * kDefaultFramePayloadSizeLimit, + adapter->GetReceiveWindowSize()); + EXPECT_EQ(kInitialFlowControlWindowSize + 2 * kDefaultFramePayloadSizeLimit, + adapter->GetStreamReceiveWindowSize(1)); + + const std::string request_body = TestFrameSequence() + .Data(1, data_chunk) + .Data(1, data_chunk) + .Data(1, data_chunk) + .Data(1, data_chunk) + .Data(1, data_chunk) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)).Times(5); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)).Times(5); + EXPECT_CALL(visitor, OnDataForStream(1, _)).Times(5); + + // DATA frames on stream 1 consume most of the window. + adapter->ProcessBytes(request_body); + EXPECT_EQ(kInitialFlowControlWindowSize - 3 * kDefaultFramePayloadSizeLimit, + adapter->GetReceiveWindowSize()); + EXPECT_EQ(kInitialFlowControlWindowSize - 3 * kDefaultFramePayloadSizeLimit, + adapter->GetStreamReceiveWindowSize(1)); + + // Marking the data consumed should result in an advertised window larger than + // the initial window. + adapter->MarkDataConsumedForStream(1, 4 * kDefaultFramePayloadSizeLimit); + EXPECT_GT(adapter->GetReceiveWindowSize(), kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamReceiveWindowSize(1), + kInitialFlowControlWindowSize); +} + +TEST(NgHttp2AdapterTest, ConnectionErrorOnControlFrameSent) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)) + .WillOnce(testing::Return(-902)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int send_result = adapter->Send(); + EXPECT_LT(send_result, 0); + + // Apparently nghttp2 retries sending the frames that had failed before. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + send_result = adapter->Send(); + EXPECT_EQ(send_result, 0); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ConnectionErrorOnDataFrameSent) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + auto body = std::make_unique(visitor, true); + body->AppendPayload("Here is some data, which will lead to a fatal error"); + TestDataFrameSource* body_ptr = body.get(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + EXPECT_TRUE(adapter->want_write()); + + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // Stream 1, with doomed DATA + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)) + .WillOnce(testing::Return(-902)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int send_result = adapter->Send(); + EXPECT_LT(send_result, 0); + + // The test data source got a signal that the first chunk of data was sent + // successfully, so discarded that data internally. However, due to the send + // error, the next Send() from nghttp2 will try to send that exact same data + // again. Without this line appending the exact same data back to the data + // source, the test crashes. It is not clear how the data source would know to + // not discard the data, unless told by the session? This is not intuitive. + body_ptr->AppendPayload( + "Here is some data, which will lead to a fatal error"); + + // Apparently nghttp2 retries sending the frames that had failed before. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(send_result, 0); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ServerConstruction) { + testing::StrictMock visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + ASSERT_NE(nullptr, adapter); + EXPECT_TRUE(adapter->want_read()); + EXPECT_FALSE(adapter->want_write()); + EXPECT_TRUE(adapter->IsServerSession()); +} + +TEST(NgHttp2AdapterTest, ServerHandlesFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + EXPECT_EQ(0, adapter->GetHpackDecoderDynamicTableSize()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&adapter, kSentinel1]() { + adapter->SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "http")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/two")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(47, false)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(1)); + + EXPECT_GT(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(1)); + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), + adapter->GetReceiveWindowSize()); + // Upper bound should still be the original value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(1)); + + EXPECT_GT(adapter->GetHpackDecoderDynamicTableSize(), 0); + + // Because stream 3 has already been closed, it's not possible to set user + // data. + const char* kSentinel3 = "another arbitrary pointer"; + adapter->SetStreamUserData(3, const_cast(kSentinel3)); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(3)); + + EXPECT_EQ(3, adapter->GetHighestReceivedStreamId()); + + EXPECT_EQ(adapter->GetSendWindowSize(), kInitialFlowControlWindowSize + 1000); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack, two PING acks. + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING, + SpdyFrameType::PING})); +} + +TEST(OgHttp2AdapterTest, HeaderValuesWithObsTextAllowed) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"name", "val\xa1ue"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "name", "val\xa1ue")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(NgHttp2AdapterTest, ServerHandlesDataWithPadding) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.", + /*fin=*/true, /*padding_length=*/39) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25 + 39, DATA, 0x9)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25 + 39)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + // Note: nghttp2 passes padding information after the actual data. + EXPECT_CALL(visitor, OnDataPaddingLength(1, 39)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerHandlesHostHeader) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":path", "/this/is/request/one"}, + {"host", "example.com"}}, + /*fin=*/true) + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"host", "example.com"}}, + /*fin=*/true) + .Headers(5, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "foo.com"}, + {":path", "/this/is/request/one"}, + {"host", "bar.com"}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(5)); + EXPECT_CALL(visitor, OnEndStream(5)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); +} + +// Tests the case where the response body is in the progress of being sent while +// trailers are queued. +TEST(NgHttp2AdapterTest, ServerSubmitsTrailersWhileDataDeferred) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + auto* body1_ptr = body1.get(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{"final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + + // Even though the data source has not finished sending data, nghttp2 will + // write the trailers anyway. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Resuming the stream results in the library wanting to write again. + body1_ptr->AppendPayload(kBody); + body1_ptr->EndData(); + adapter->ResumeStream(1); + EXPECT_TRUE(adapter->want_write()); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + // But no data is written for the stream. + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientDisobeysConnectionFlowControl) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + // 70000 bytes of data + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(4464, 'a')) + .Serialize(); + + testing::InSequence s; + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + // No further frame data or headers are delivered. + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + // No SETTINGS ack is written. + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ClientDisobeysConnectionFlowControlWithOneDataFrame) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + // Allow the client to send a DATA frame that exceeds the connection flow + // control window. + const uint32_t window_overflow_bytes = kInitialFlowControlWindowSize + 1; + adapter->SubmitSettings({{MAX_FRAME_SIZE, window_overflow_bytes}}); + + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + int64_t process_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(process_result)); + + EXPECT_TRUE(adapter->want_write()); + + // Outbound SETTINGS containing MAX_FRAME_SIZE. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // Ack of client's initial settings. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Now let the client ack the MAX_FRAME_SIZE SETTINGS and send a DATA frame to + // overflow the connection-level window. The result should be a GOAWAY. + const std::string overflow_frames = + TestFrameSequence() + .SettingsAck() + .Data(1, std::string(window_overflow_bytes, 'a')) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, window_overflow_bytes, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, window_overflow_bytes)); + // No further frame data is delivered. + + process_result = adapter->ProcessBytes(overflow_frames); + EXPECT_EQ(overflow_frames.size(), static_cast(process_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ClientDisobeysConnectionFlowControlAcrossReads) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + // Allow the client to send a DATA frame that exceeds the connection flow + // control window. + const uint32_t window_overflow_bytes = kInitialFlowControlWindowSize + 1; + adapter->SubmitSettings({{MAX_FRAME_SIZE, window_overflow_bytes}}); + + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + int64_t process_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(process_result)); + + EXPECT_TRUE(adapter->want_write()); + + // Outbound SETTINGS containing MAX_FRAME_SIZE. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // Ack of client's initial settings. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Now let the client ack the MAX_FRAME_SIZE SETTINGS and send a DATA frame to + // overflow the connection-level window. The result should be a GOAWAY, but + // because the processing is split across several calls, nghttp2 instead + // delivers the data payloads (which the visitor then consumes). This is a bug + // in nghttp2, which should recognize the flow control error. + const std::string overflow_frames = + TestFrameSequence() + .SettingsAck() + .Data(1, std::string(window_overflow_bytes, 'a')) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, window_overflow_bytes, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, window_overflow_bytes)); + // BUG: The visitor should not have received the data. + EXPECT_CALL(visitor, OnDataForStream(1, _)) + .WillRepeatedly( + [&adapter](Http2StreamId stream_id, absl::string_view data) { + adapter->MarkDataConsumedForStream(stream_id, data.size()); + return true; + }); + + const size_t chunk_length = 16384; + ASSERT_GE(overflow_frames.size(), chunk_length); + absl::string_view remaining = overflow_frames; + while (!remaining.empty()) { + absl::string_view chunk = remaining.substr(0, chunk_length); + process_result = adapter->ProcessBytes(chunk); + EXPECT_EQ(chunk.length(), static_cast(process_result)); + + remaining.remove_prefix(chunk.length()); + } + + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 1, 4, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::WINDOW_UPDATE, + SpdyFrameType::WINDOW_UPDATE})); +} + +TEST(NgHttp2AdapterTest, ClientDisobeysStreamFlowControl) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .Serialize(); + const std::string more_frames = TestFrameSequence() + // 70000 bytes of data + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(4464, 'a')) + .Serialize(); + + testing::InSequence s; + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + adapter->SubmitWindowUpdate(0, 20000); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::WINDOW_UPDATE})); + visitor.Clear(); + + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + // No further frame data or headers for stream 1 are delivered. + + result = adapter->ProcessBytes(more_frames); + EXPECT_EQ(more_frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::FLOW_CONTROL_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "some bogus value!")) + .WillOnce(testing::Return(Http2VisitorInterface::HEADER_RST_STREAM)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + // DATA frame is not delivered to the visitor. + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerErrorWhileHandlingHeadersDropsFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Metadata(1, "This is the request metadata.") + .RstStream(1, Http2ErrorCode::CANCEL) + .WindowUpdate(0, 2000) + .Headers(3, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/false) + .Metadata(3, "This is the request metadata.", + /*multiple_frames=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "some bogus value!")) + .WillOnce(testing::Return(Http2VisitorInterface::HEADER_RST_STREAM)); + // For the RST_STREAM-marked stream, the control frames and METADATA frame but + // not the DATA frame are delivered to the visitor. + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(1, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(3, _)); + EXPECT_CALL(visitor, OnMetadataForStream(3, "This is the re")) + .WillOnce(testing::DoAll(testing::InvokeWithoutArgs([&adapter]() { + adapter->SubmitRst( + 3, Http2ErrorCode::REFUSED_STREAM); + }), + testing::Return(true))); + // The rest of the metadata is still delivered to the visitor. + EXPECT_CALL(visitor, OnFrameHeader(3, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(3, _)); + EXPECT_CALL(visitor, OnMetadataForStream(3, "quest metadata.")); + EXPECT_CALL(visitor, OnMetadataEndForStream(3)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, 4, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::REFUSED_STREAM)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"Accept", "uppercase, oh boy!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnErrorDebug); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)) + .WillOnce(testing::Return(false)); + // Translation to nghttp2 treats this error as a general parsing error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(result, NGHTTP2_ERR_CALLBACK_FAILURE); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack and RST_STREAM + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerErrorAfterHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(-902, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +// Exercises the case when a visitor chooses to reject a frame based solely on +// the frame header, which is a fatal error for the connection. +TEST(NgHttp2AdapterTest, ServerRejectsFrameHeader) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(64) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(-902, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerRejectsBeginningOfData) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack. + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerRejectsStreamData) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, _)).WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack. + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerReceivesTooLargeHeader) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + // nghttp2 will accept a maximum of 64kB of huffman encoded data per header + // field. + const std::string too_large_value = std::string(80 * 1024, 'q'); + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"x-toobig", too_large_value}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + // Further header processing is skipped, as the header field is too large. + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_TRUE(adapter->want_write()); + + // Since nghttp2 opted not to process the header, it generates a GOAWAY with + // error code COMPRESSION_ERROR. + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, 8, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, 8, 0x0, + static_cast(Http2ErrorCode::COMPRESSION_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerReceivesInvalidAuthority) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "ex|ample.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [:authority], value: [ex|ample.com]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0x0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttpAdapterTest, ServerReceivesGoAway) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .GoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "") + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0x0)); + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "")); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(frames.size()), result); + + // The server should still be able to send a response after receiving a GOAWAY + // with a lower last-stream-ID field, as the stream was client-initiated. + const int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), + /*data_source=*/nullptr); + ASSERT_EQ(0, submit_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0x0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); +} + +TEST(NgHttp2AdapterTest, ServerSubmitResponse) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&adapter, kSentinel1]() { + adapter->SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + // Server will want to send a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHpackEncoderDynamicTableSize()); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + // A data fin is not sent so that the stream remains open, and the flow + // control state can be verified. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + int submit_result = adapter->SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + // Stream user data should have been set successfully after receiving headers. + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(1)); + adapter->SetStreamUserData(1, nullptr); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(1)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + EXPECT_FALSE(adapter->want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(adapter->GetStreamSendWindowSize(1), kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamSendWindowSize(1), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(adapter->GetStreamSendWindowSize(3), -1); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 0); +} + +TEST(NgHttp2AdapterTest, ServerSubmitResponseWithResetFromClient) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + // Server will want to send a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kBody); + int submit_result = adapter->SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + // Client resets the stream before the server can send the response. + const std::string reset = + TestFrameSequence().RstStream(1, Http2ErrorCode::CANCEL).Serialize(); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(1, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::CANCEL)); + const int64_t reset_result = adapter->ProcessBytes(reset); + EXPECT_EQ(reset.size(), static_cast(reset_result)); + + // Outbound HEADERS and DATA are dropped. + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, _)).Times(0); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, _, _)).Times(0); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, _, _)).Times(0); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +// Should also test: client attempts shutdown, server attempts shutdown after an +// explicit GOAWAY. +TEST(NgHttp2AdapterTest, ServerSendsShutdown) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + adapter->SubmitShutdownNotice(); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerSendsTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + // Server will want to send a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = adapter->SubmitTrailer( + 1, ToHeaders({{"final-status", "a-ok"}, + {"x-comment", "trailers sure are cool"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); +} + +TEST(NgHttp2AdapterTest, ClientSendsContinuation) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true, + /*add_continuation=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 1)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); +} + +TEST(NgHttp2AdapterTest, ClientSendsMetadataWithContinuation) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Metadata(0, "Example connection metadata in multiple frames", true) + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false, + /*add_continuation=*/true) + .Metadata(1, + "Some stream metadata that's also sent in multiple frames", + true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Metadata on stream 0 + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + // Metadata on stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + EXPECT_EQ("Example connection metadata in multiple frames", + absl::StrJoin(visitor.GetMetadata(0), "")); + EXPECT_EQ("Some stream metadata that's also sent in multiple frames", + absl::StrJoin(visitor.GetMetadata(1), "")); +} + +TEST(NgHttp2AdapterTest, RepeatedHeaderNames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "text/plain"}, + {"accept", "text/html"}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "text/plain")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "text/html")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const std::vector
headers1 = ToHeaders( + {{":status", "200"}, {"content-length", "10"}, {"content-length", "10"}}); + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload("perfection"); + body1->EndData(); + + int submit_result = adapter->SubmitResponse(1, headers1, std::move(body1)); + ASSERT_EQ(0, submit_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, 10, 0x1, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::DATA})); +} + +TEST(NgHttp2AdapterTest, ServerRespondsToRequestWithTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}) + .Data(1, "Example data, woohoo.") + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const std::vector
headers1 = ToHeaders({{":status", "200"}}); + auto body1 = std::make_unique(visitor, true); + TestDataFrameSource* body1_ptr = body1.get(); + + int submit_result = adapter->SubmitResponse(1, headers1, std::move(body1)); + ASSERT_EQ(0, submit_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string more_frames = + TestFrameSequence() + .Headers(1, {{"extra-info", "Trailers are weird but good?"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, "extra-info", + "Trailers are weird but good?")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + result = adapter->ProcessBytes(more_frames); + EXPECT_EQ(more_frames.size(), static_cast(result)); + + body1_ptr->EndData(); + EXPECT_EQ(true, adapter->ResumeStream(1)); + + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); +} + +TEST(NgHttp2AdapterTest, ServerSubmitsResponseWithDataSourceError) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + auto body1 = std::make_unique(visitor, false); + body1->SimulateError(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 2)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::RST_STREAM})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + // The library does not object to the user queuing trailers, even through the + // stream has already been closed. + EXPECT_EQ(trailer_result, 0); +} + +TEST(NgHttp2AdapterTest, CompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the response body.", /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, IncompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + // BUG: Should send RST_STREAM NO_ERROR as well, but nghttp2 does not. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ServerHandlesMultipleContentLength) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/1"}, + {"content-length", "7"}, + {"content-length", "7"}}, + /*fin=*/false) + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/3"}, + {"content-length", "11"}, + {"content-length", "13"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/1")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "7")); + // nghttp2 does not like duplicate Content-Length headers. + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [content-length], value: [7]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + // Stream 3 + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/3")); + EXPECT_CALL(visitor, OnHeaderForStream(3, "content-length", "11")); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 3, name: [content-length], value: [13]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(NgHttp2AdapterTest, ServerSendsInvalidTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); +} + +TEST(NgHttp2AdapterTest, ServerDropsNewStreamBelowWatermark) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(3, "This is the request body.") + .Headers(1, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 25)); + EXPECT_CALL(visitor, OnDataForStream(3, "This is the request body.")); + + // It looks like nghttp2 delivers the under-watermark frame header but + // otherwise silently drops the rest of the frame without error. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnInvalidFrame).Times(0); + EXPECT_CALL(visitor, OnConnectionError).Times(0); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(3, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterInteractionTest, + ClientServerInteractionRepeatedHeaderNames) { + DataSavingVisitor client_visitor; + auto client_adapter = NgHttp2Adapter::CreateClientAdapter(client_visitor); + + client_adapter->SubmitSettings({}); + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "text/plain"}, + {"accept", "text/html"}}); + + const int32_t stream_id1 = + client_adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(client_visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(client_visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(client_visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(client_visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + int send_result = client_adapter->Send(); + EXPECT_EQ(0, send_result); + + DataSavingVisitor server_visitor; + auto server_adapter = NgHttp2Adapter::CreateServerAdapter(server_visitor); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(server_visitor, OnFrameHeader(0, _, SETTINGS, 0)); + EXPECT_CALL(server_visitor, OnSettingsStart()); + EXPECT_CALL(server_visitor, OnSetting(_)).Times(testing::AnyNumber()); + EXPECT_CALL(server_visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(server_visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(server_visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, ":scheme", "http")); + EXPECT_CALL(server_visitor, + OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(server_visitor, + OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, "accept", "text/plain")); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, "accept", "text/html")); + EXPECT_CALL(server_visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(server_visitor, OnEndStream(1)); + + int64_t result = server_adapter->ProcessBytes(client_visitor.data()); + EXPECT_EQ(client_visitor.data().size(), static_cast(result)); +} + +TEST(NgHttp2AdapterTest, ServerForbidsWindowUpdateOnIdleStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(1, _)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the SETTINGS ack to be dropped. + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsDataOnIdleStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Data(1, "Sorry, out of order") + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + // In this case, nghttp2 goes straight to GOAWAY and does not invoke the + // invalid frame callback. + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the SETTINGS ack to be dropped. + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsRstStreamOnIdleStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .RstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(1, _)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the SETTINGS ack to be dropped. + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsNewStreamAboveStreamLimit) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client send a SETTINGS ack and then attempt to open more than the + // advertised number of streams. The overflow stream should be rejected. + const std::string stream_frames = + TestFrameSequence() + .SettingsAck() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kProtocol)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + // The server should send a GOAWAY for this error, even though + // OnInvalidFrame() returns true. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerRstStreamsNewStreamAboveStreamLimitBeforeAck) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client avoid sending a SETTINGS ack and attempt to open more than + // the advertised number of streams. The server should still reject the + // overflow stream, albeit with RST_STREAM REFUSED_STREAM instead of GOAWAY. + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 3, Http2VisitorInterface::InvalidFrameError::kRefusedStream)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_result, stream_frames.size()); + + // The server sends a RST_STREAM for the offending stream. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, AutomaticSettingsAndPingAcks) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface does not appear to include the mandatory SETTINGS frame. + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // PING ack + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); +} + +TEST(NgHttp2AdapterTest, AutomaticPingAcksDisabled) { + DataSavingVisitor visitor; + nghttp2_option* options; + nghttp2_option_new(&options); + nghttp2_option_set_no_auto_ping_ack(options, 1); + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor, options); + nghttp2_option_del(options); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface does not appear to include the mandatory SETTINGS frame. + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // No PING ack expected because automatic PING acks are disabled. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, InvalidMaxFrameSizeSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = + TestFrameSequence().ClientPreface({{MAX_FRAME_SIZE, 3u}}).Serialize(); + testing::InSequence s; + + // Client preface + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, InvalidPushSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = + TestFrameSequence().ClientPreface({{ENABLE_PUSH, 3u}}).Serialize(); + testing::InSequence s; + + // Client preface + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(0, _)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, InvalidConnectProtocolSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface({{ENABLE_CONNECT_PROTOCOL, 3u}}) + .Serialize(); + testing::InSequence s; + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + + int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); + + auto adapter2 = NgHttp2Adapter::CreateServerAdapter(visitor); + const std::string frames2 = TestFrameSequence() + .ClientPreface({{ENABLE_CONNECT_PROTOCOL, 1}}) + .Settings({{ENABLE_CONNECT_PROTOCOL, 0}}) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ENABLE_CONNECT_PROTOCOL, 1u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + // Surprisingly, nghttp2 allows this behavior, which is prohibited in RFC + // 8441. + EXPECT_CALL(visitor, OnSetting(Http2Setting{ENABLE_CONNECT_PROTOCOL, 0u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + read_result = adapter2->ProcessBytes(frames2); + EXPECT_EQ(static_cast(read_result), frames2.size()); + + EXPECT_TRUE(adapter2->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + adapter2->Send(); +} + +TEST(NgHttp2AdapterTest, ServerForbidsProtocolPseudoheaderBeforeAck) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + // The client attempts to send a CONNECT request with the `:protocol` + // pseudoheader before receiving the server's SETTINGS frame. + const std::string stream1_frames = + TestFrameSequence() + .Headers(1, + {{":method", "CONNECT"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {":protocol", "websocket"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [:protocol], value: [websocket]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + int64_t stream_result = adapter->ProcessBytes(stream1_frames); + EXPECT_EQ(static_cast(stream_result), stream1_frames.size()); + + // Server sends a SETTINGS ack and initial SETTINGS (with + // ENABLE_CONNECT_PROTOCOL). + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // The server sends a RST_STREAM for the offending stream. + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}}); + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); + visitor.Clear(); + + // The client attempts to send a CONNECT request with the `:protocol` + // pseudoheader before acking the server's SETTINGS frame. + const std::string stream3_frames = + TestFrameSequence() + .Headers(3, + {{":method", "CONNECT"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + {":protocol", "websocket"}}, + /*fin=*/true) + .Serialize(); + + // Surprisingly, nghttp2 is okay with this. + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + stream_result = adapter->ProcessBytes(stream3_frames); + EXPECT_EQ(static_cast(stream_result), stream3_frames.size()); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ServerAllowsProtocolPseudoheaderAfterAck) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + // Server initial SETTINGS (with ENABLE_CONNECT_PROTOCOL) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + + // The client attempts to send a CONNECT request with the `:protocol` + // pseudoheader after acking the server's SETTINGS frame. + const std::string stream_frames = + TestFrameSequence() + .SettingsAck() + .Headers(1, + {{":method", "CONNECT"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {":protocol", "websocket"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_result), stream_frames.size()); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, SkipsSendingFramesForRejectedStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + auto body = std::make_unique(visitor, true); + body->AppendPayload("Here is some data, which will be completely ignored!"); + + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + + adapter->SubmitWindowUpdate(1, 1024); + adapter->SubmitRst(1, Http2ErrorCode::INTERNAL_ERROR); + + // Server initial SETTINGS and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + // nghttp2 apparently allows extension frames to be sent on reset streams. + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + // The server sends a RST_STREAM for the offending stream. + // The response HEADERS, DATA and WINDOW_UPDATE are all ignored. + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType), + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerStartsShutdown) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_FALSE(adapter->want_write()); + + adapter->SubmitShutdownNotice(); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerStartsShutdownAfterGoaway) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_FALSE(adapter->want_write()); + + adapter->SubmitGoAway(1, Http2ErrorCode::HTTP2_NO_ERROR, + "and don't come back!"); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); + + // No-op, since a GOAWAY has previously been enqueued. + adapter->SubmitShutdownNotice(); + EXPECT_FALSE(adapter->want_write()); +} + +// Verifies that a connection-level processing error results in repeatedly +// returning a positive value for ProcessBytes() to mark all data as consumed. +TEST(NgHttp2AdapterTest, ConnectionErrorWithBlackholeSinkingData) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(1, _)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + // Ask the connection to process more bytes. Because the option is enabled, + // the data should be marked as consumed. + const std::string next_frame = TestFrameSequence().Ping(42).Serialize(); + const int64_t next_result = adapter->ProcessBytes(next_frame); + EXPECT_EQ(static_cast(next_result), next_frame.size()); +} + +TEST(NgHttp2AdapterTest, ServerDoesNotSendFramesAfterImmediateGoAway) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + // Submit a custom initial SETTINGS frame with one setting. + adapter->SubmitSettings({{HEADER_TABLE_SIZE, 100u}}); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + // Submit a response for the stream. + auto body = std::make_unique(visitor, true); + body->AppendPayload("This data is doomed to never be written."); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + // Submit a WINDOW_UPDATE frame. + adapter->SubmitWindowUpdate(kConnectionStreamId, 42); + + // Submit another SETTINGS frame. + adapter->SubmitSettings({}); + + // Submit some metadata. + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + + EXPECT_TRUE(adapter->want_write()); + + // Trigger a connection error. Only the response headers will be written. + const std::string connection_error_frames = + TestFrameSequence().WindowUpdate(3, 42).Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(3, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(3, _)); + + const int64_t result = adapter->ProcessBytes(connection_error_frames); + EXPECT_EQ(static_cast(result), connection_error_frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the other frames to be dropped except for the + // non-ack SETTINGS frames; nghttp2 sends non-ack SETTINGS frames because they + // could be the initial SETTINGS frame. However, nghttp2 still allows sending + // multiple non-ack SETTINGS, which feels non-ideal. + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::GOAWAY})); + visitor.Clear(); + + // Try to submit more frames for writing. They should not be written. + adapter->SubmitPing(42); + EXPECT_FALSE(adapter->want_write()); + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(NgHttp2AdapterTest, ServerHandlesContentLength) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"content-length", "2"}}) + .Data(1, "hi", /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + {"content-length", "nan"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Stream 1: content-length is correct + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 2)); + EXPECT_CALL(visitor, OnDataForStream(1, "hi")); + EXPECT_CALL(visitor, OnEndStream(1)); + + // Stream 3: content-length is not a number + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 3, name: [content-length], value: [nan]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerHandlesContentLengthMismatch) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + {"content-length", "2"}}) + .Data(1, "h", /*fin=*/true) + .Headers(3, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}, + {"content-length", "2"}}) + .Data(3, "howdy", /*fin=*/true) + .Headers(5, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/four"}, + {"content-length", "2"}}, + /*fin=*/true) + .Headers(7, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/four"}, + {"content-length", "2"}}, + /*fin=*/false) + .Data(7, "h", /*fin=*/false) + .Headers(7, {{"extra-info", "Trailers with content-length mismatch"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Stream 1: content-length is larger than actual data + // All data is delivered to the visitor, but OnInvalidFrame() is not. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 1)); + EXPECT_CALL(visitor, OnDataForStream(1, "h")); + + // Stream 3: content-length is smaller than actual data + // The beginning of data is delivered to the visitor, but not the actual data, + // and neither is OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 5)); + + // Stream 5: content-length is invalid and HEADERS ends the stream + // When the stream ends with HEADERS, nghttp2 invokes OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(5); + EXPECT_CALL(visitor, + OnInvalidFrame( + 5, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + // Stream 7: content-length is invalid and trailers end the stream + // When the stream ends with trailers, nghttp2 invokes OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(7)); + EXPECT_CALL(visitor, OnFrameHeader(7, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(7, 1)); + EXPECT_CALL(visitor, OnDataForStream(7, "h")); + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, _, _)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 7, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 7, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 7, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(7, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerHandlesAsteriskPathForOptions) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + testing::InSequence s; + + const std::string stream_frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "*"}, + {":method", "OPTIONS"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerHandlesInvalidPath) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "*"}, + {":method", "GET"}}, + /*fin=*/true) + .Headers(3, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "other/non/slash/starter"}, + {":method", "GET"}}, + /*fin=*/true) + .Headers(5, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", ""}, + {":method", "GET"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, + OnInvalidFrame( + 3, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(2); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 5, name: [:path], value: []")); + EXPECT_CALL( + visitor, + OnInvalidFrame(5, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerHandlesTeHeader) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + testing::InSequence s; + + const std::string stream_frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"te", "trailers"}}, + /*fin=*/true) + .Headers(3, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"te", "trailers, deflate"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Stream 1: TE: trailers should be allowed. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + // Stream 3: TE: should be rejected. + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 3, name: [te], value: [trailers, deflate]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerHandlesConnectionSpecificHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"connection", "keep-alive"}}, + /*fin=*/true) + .Headers(3, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"proxy-connection", "keep-alive"}}, + /*fin=*/true) + .Headers(5, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"keep-alive", "timeout=42"}}, + /*fin=*/true) + .Headers(7, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"transfer-encoding", "chunked"}}, + /*fin=*/true) + .Headers(9, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"upgrade", "h2c"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // All streams contain a connection-specific header and should be rejected. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [connection], value: [keep-alive]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 3, name: [proxy-connection], value: [keep-alive]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 5, name: [keep-alive], value: [timeout=42]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(5, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 7, name: [transfer-encoding], value: [chunked]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(7, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(9, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(9)); + EXPECT_CALL(visitor, OnHeaderForStream(9, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 9, name: [upgrade], value: [h2c]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(9, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 7, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 7, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(7, Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 9, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 9, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(9, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, NegativeFlowControlStreamResumption) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = + TestFrameSequence() + .ClientPreface({{INITIAL_WINDOW_SIZE, 128u * 1024u}}) + .WindowUpdate(0, 1 << 20) + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::INITIAL_WINDOW_SIZE, + 128u * 1024u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1 << 20)); + + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + // Submit a response for the stream. + auto body = std::make_unique(visitor, true); + TestDataFrameSource& body_ref = *body; + body_ref.AppendPayload(std::string(70000, 'a')); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)).Times(5); + + adapter->Send(); + EXPECT_FALSE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::INITIAL_WINDOW_SIZE, + 64u * 1024u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Processing these SETTINGS will cause stream 1's send window to become + // negative. + adapter->ProcessBytes(TestFrameSequence() + .Settings({{INITIAL_WINDOW_SIZE, 64u * 1024u}}) + .Serialize()); + EXPECT_TRUE(adapter->want_write()); + // nghttp2 does not expose the fact that the send window size is negative. + EXPECT_EQ(adapter->GetStreamSendWindowSize(1), 0); + + body_ref.AppendPayload("Stream should be resumed."); + adapter->ResumeStream(1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + adapter->Send(); + EXPECT_FALSE(adapter->want_write()); + + // Upon receiving the WINDOW_UPDATE, stream 1 should be ready to write. + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 10000)); + adapter->ProcessBytes(TestFrameSequence().WindowUpdate(1, 10000).Serialize()); + EXPECT_TRUE(adapter->want_write()); + EXPECT_GT(adapter->GetStreamSendWindowSize(1), 0); + + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + adapter->Send(); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_callbacks.cc b/quiche/http2/adapter/nghttp2_callbacks.cc new file mode 100644 index 000000000000..200a70999055 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_callbacks.cc @@ -0,0 +1,388 @@ +#include "quiche/http2/adapter/nghttp2_callbacks.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/data_source.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/nghttp2_data_provider.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace http2 { +namespace adapter { +namespace callbacks { + +ssize_t OnReadyToSend(nghttp2_session* /* session */, const uint8_t* data, + size_t length, int flags, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const int64_t result = visitor->OnReadyToSend(ToStringView(data, length)); + QUICHE_VLOG(1) << "callbacks::OnReadyToSend(length=" << length + << ", flags=" << flags << ") returning " << result; + if (result > 0) { + return result; + } else if (result == Http2VisitorInterface::kSendBlocked) { + return -504; // NGHTTP2_ERR_WOULDBLOCK + } else { + return -902; // NGHTTP2_ERR_CALLBACK_FAILURE + } +} + +int OnBeginFrame(nghttp2_session* /* session */, const nghttp2_frame_hd* header, + void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnBeginFrame(stream_id=" << header->stream_id + << ", type=" << int(header->type) + << ", length=" << header->length + << ", flags=" << int(header->flags) << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + bool result = visitor->OnFrameHeader(header->stream_id, header->length, + header->type, header->flags); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + if (header->type == NGHTTP2_DATA) { + result = visitor->OnBeginDataForStream(header->stream_id, header->length); + } else if (header->type == kMetadataFrameType) { + visitor->OnBeginMetadataForStream(header->stream_id, header->length); + } + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnFrameReceived(nghttp2_session* /* session */, const nghttp2_frame* frame, + void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnFrameReceived(stream_id=" + << frame->hd.stream_id << ", type=" << int(frame->hd.type) + << ", length=" << frame->hd.length + << ", flags=" << int(frame->hd.flags) << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const Http2StreamId stream_id = frame->hd.stream_id; + switch (frame->hd.type) { + // The beginning of the DATA frame is handled in OnBeginFrame(), and the + // beginning of the header block is handled in client/server-specific + // callbacks. This callback handles the point at which the entire logical + // frame has been received and processed. + case NGHTTP2_DATA: + if ((frame->hd.flags & NGHTTP2_FLAG_PADDED) != 0) { + visitor->OnDataPaddingLength(stream_id, frame->data.padlen); + } + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + const bool result = visitor->OnEndStream(stream_id); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + break; + case NGHTTP2_HEADERS: { + if (frame->hd.flags & NGHTTP2_FLAG_END_HEADERS) { + const bool result = visitor->OnEndHeadersForStream(stream_id); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + const bool result = visitor->OnEndStream(stream_id); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + break; + } + case NGHTTP2_PRIORITY: { + nghttp2_priority_spec priority_spec = frame->priority.pri_spec; + visitor->OnPriorityForStream(stream_id, priority_spec.stream_id, + priority_spec.weight, + priority_spec.exclusive != 0); + break; + } + case NGHTTP2_RST_STREAM: { + visitor->OnRstStream(stream_id, + ToHttp2ErrorCode(frame->rst_stream.error_code)); + break; + } + case NGHTTP2_SETTINGS: + if (frame->hd.flags & NGHTTP2_FLAG_ACK) { + visitor->OnSettingsAck(); + } else { + visitor->OnSettingsStart(); + for (size_t i = 0; i < frame->settings.niv; ++i) { + nghttp2_settings_entry entry = frame->settings.iv[i]; + // The nghttp2_settings_entry uses int32_t for the ID; we must cast. + visitor->OnSetting(Http2Setting{ + static_cast(entry.settings_id), entry.value}); + } + visitor->OnSettingsEnd(); + } + break; + case NGHTTP2_PUSH_PROMISE: + // This case is handled by headers-related callbacks: + // 1. visitor->OnPushPromiseForStream() is invoked in the client-side + // OnHeadersStart() adapter callback, as nghttp2 only allows clients + // to receive PUSH_PROMISE frames. + // 2. visitor->OnHeaderForStream() is invoked for each server push + // request header in the PUSH_PROMISE header block. + // 3. This switch statement is reached once all server push request + // headers have been parsed. + break; + case NGHTTP2_PING: { + Http2PingId ping_id; + std::memcpy(&ping_id, frame->ping.opaque_data, sizeof(Http2PingId)); + visitor->OnPing(quiche::QuicheEndian::NetToHost64(ping_id), + (frame->hd.flags & NGHTTP2_FLAG_ACK) != 0); + break; + } + case NGHTTP2_GOAWAY: { + absl::string_view opaque_data( + reinterpret_cast(frame->goaway.opaque_data), + frame->goaway.opaque_data_len); + const bool result = visitor->OnGoAway( + frame->goaway.last_stream_id, + ToHttp2ErrorCode(frame->goaway.error_code), opaque_data); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + break; + } + case NGHTTP2_WINDOW_UPDATE: { + visitor->OnWindowUpdate(stream_id, + frame->window_update.window_size_increment); + break; + } + case NGHTTP2_CONTINUATION: + // This frame type should not be passed to any callbacks, according to + // https://nghttp2.org/documentation/enums.html#c.NGHTTP2_CONTINUATION. + QUICHE_LOG(ERROR) << "Unexpected receipt of NGHTTP2_CONTINUATION type!"; + break; + case NGHTTP2_ALTSVC: + break; + case NGHTTP2_ORIGIN: + break; + } + + return 0; +} + +int OnBeginHeaders(nghttp2_session* /* session */, const nghttp2_frame* frame, + void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnBeginHeaders(stream_id=" + << frame->hd.stream_id << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const bool result = visitor->OnBeginHeadersForStream(frame->hd.stream_id); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnHeader(nghttp2_session* /* session */, const nghttp2_frame* frame, + nghttp2_rcbuf* name, nghttp2_rcbuf* value, uint8_t /*flags*/, + void* user_data) { + QUICHE_VLOG(2) << "callbacks::OnHeader(stream_id=" << frame->hd.stream_id + << ", name=[" << absl::CEscape(ToStringView(name)) + << "], value=[" << absl::CEscape(ToStringView(value)) << "])"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const Http2VisitorInterface::OnHeaderResult result = + visitor->OnHeaderForStream(frame->hd.stream_id, ToStringView(name), + ToStringView(value)); + switch (result) { + case Http2VisitorInterface::HEADER_OK: + return 0; + case Http2VisitorInterface::HEADER_CONNECTION_ERROR: + case Http2VisitorInterface::HEADER_COMPRESSION_ERROR: + return NGHTTP2_ERR_CALLBACK_FAILURE; + case Http2VisitorInterface::HEADER_RST_STREAM: + case Http2VisitorInterface::HEADER_FIELD_INVALID: + case Http2VisitorInterface::HEADER_HTTP_MESSAGING: + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; + } + // Unexpected value. + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; +} + +int OnBeforeFrameSent(nghttp2_session* /* session */, + const nghttp2_frame* frame, void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnBeforeFrameSent(stream_id=" + << frame->hd.stream_id << ", type=" << int(frame->hd.type) + << ", length=" << frame->hd.length + << ", flags=" << int(frame->hd.flags) << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + LogBeforeSend(*frame); + auto* visitor = static_cast(user_data); + return visitor->OnBeforeFrameSent(frame->hd.type, frame->hd.stream_id, + frame->hd.length, frame->hd.flags); +} + +int OnFrameSent(nghttp2_session* /* session */, const nghttp2_frame* frame, + void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnFrameSent(stream_id=" << frame->hd.stream_id + << ", type=" << int(frame->hd.type) + << ", length=" << frame->hd.length + << ", flags=" << int(frame->hd.flags) << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + uint32_t error_code = 0; + if (frame->hd.type == NGHTTP2_RST_STREAM) { + error_code = frame->rst_stream.error_code; + } else if (frame->hd.type == NGHTTP2_GOAWAY) { + error_code = frame->goaway.error_code; + } + return visitor->OnFrameSent(frame->hd.type, frame->hd.stream_id, + frame->hd.length, frame->hd.flags, error_code); +} + +int OnFrameNotSent(nghttp2_session* /* session */, const nghttp2_frame* frame, + int /* lib_error_code */, void* /* user_data */) { + QUICHE_VLOG(1) << "callbacks::OnFrameNotSent(stream_id=" + << frame->hd.stream_id << ", type=" << int(frame->hd.type) + << ", length=" << frame->hd.length + << ", flags=" << int(frame->hd.flags) << ")"; + if (frame->hd.type == kMetadataFrameType) { + auto* source = static_cast(frame->ext.payload); + if (source == nullptr) { + QUICHE_BUG(not_sent_payload_is_nullptr) + << "Extension frame payload for stream " << frame->hd.stream_id + << " is null!"; + } else { + source->OnFailure(); + } + } + return 0; +} + +int OnInvalidFrameReceived(nghttp2_session* /* session */, + const nghttp2_frame* frame, int lib_error_code, + void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnInvalidFrameReceived(stream_id=" + << frame->hd.stream_id << ", InvalidFrameError=" + << int(ToInvalidFrameError(lib_error_code)) << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const bool result = visitor->OnInvalidFrame( + frame->hd.stream_id, ToInvalidFrameError(lib_error_code)); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnDataChunk(nghttp2_session* /* session */, uint8_t /*flags*/, + Http2StreamId stream_id, const uint8_t* data, size_t len, + void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnDataChunk(stream_id=" << stream_id + << ", length=" << len << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const bool result = visitor->OnDataForStream( + stream_id, absl::string_view(reinterpret_cast(data), len)); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnStreamClosed(nghttp2_session* /* session */, Http2StreamId stream_id, + uint32_t error_code, void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnStreamClosed(stream_id=" << stream_id + << ", error_code=" << error_code << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const bool result = + visitor->OnCloseStream(stream_id, ToHttp2ErrorCode(error_code)); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnExtensionChunkReceived(nghttp2_session* /*session*/, + const nghttp2_frame_hd* hd, const uint8_t* data, + size_t len, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + if (hd->type != kMetadataFrameType) { + QUICHE_LOG(ERROR) << "Unexpected frame type: " + << static_cast(hd->type); + return NGHTTP2_ERR_CANCEL; + } + const bool result = + visitor->OnMetadataForStream(hd->stream_id, ToStringView(data, len)); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnUnpackExtensionCallback(nghttp2_session* /*session*/, void** /*payload*/, + const nghttp2_frame_hd* hd, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + if (hd->flags == kMetadataEndFlag) { + const bool result = visitor->OnMetadataEndForStream(hd->stream_id); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + return 0; +} + +ssize_t OnPackExtensionCallback(nghttp2_session* /*session*/, uint8_t* buf, + size_t len, const nghttp2_frame* frame, + void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* source = static_cast(frame->ext.payload); + if (source == nullptr) { + QUICHE_BUG(payload_is_nullptr) << "Extension frame payload for stream " + << frame->hd.stream_id << " is null!"; + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + const std::pair result = source->Pack(buf, len); + if (result.first < 0) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + const bool end_metadata_flag = (frame->hd.flags & kMetadataEndFlag); + QUICHE_LOG_IF(DFATAL, result.second != end_metadata_flag) + << "Metadata ends: " << result.second + << " has kMetadataEndFlag: " << end_metadata_flag; + return result.first; +} + +int OnError(nghttp2_session* /*session*/, int /*lib_error_code*/, + const char* msg, size_t len, void* user_data) { + QUICHE_VLOG(1) << "callbacks::OnError(" << absl::string_view(msg, len) << ")"; + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + visitor->OnErrorDebug(absl::string_view(msg, len)); + return 0; +} + +nghttp2_session_callbacks_unique_ptr Create() { + nghttp2_session_callbacks* callbacks; + nghttp2_session_callbacks_new(&callbacks); + + nghttp2_session_callbacks_set_send_callback(callbacks, &OnReadyToSend); + nghttp2_session_callbacks_set_on_begin_frame_callback(callbacks, + &OnBeginFrame); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks, + &OnFrameReceived); + nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks, + &OnBeginHeaders); + nghttp2_session_callbacks_set_on_header_callback2(callbacks, &OnHeader); + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks, + &OnDataChunk); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks, + &OnStreamClosed); + nghttp2_session_callbacks_set_before_frame_send_callback(callbacks, + &OnBeforeFrameSent); + nghttp2_session_callbacks_set_on_frame_send_callback(callbacks, &OnFrameSent); + nghttp2_session_callbacks_set_on_frame_not_send_callback(callbacks, + &OnFrameNotSent); + nghttp2_session_callbacks_set_on_invalid_frame_recv_callback( + callbacks, &OnInvalidFrameReceived); + nghttp2_session_callbacks_set_error_callback2(callbacks, &OnError); + nghttp2_session_callbacks_set_send_data_callback( + callbacks, &DataFrameSourceSendCallback); + nghttp2_session_callbacks_set_pack_extension_callback( + callbacks, &OnPackExtensionCallback); + nghttp2_session_callbacks_set_unpack_extension_callback( + callbacks, &OnUnpackExtensionCallback); + nghttp2_session_callbacks_set_on_extension_chunk_recv_callback( + callbacks, &OnExtensionChunkReceived); + return MakeCallbacksPtr(callbacks); +} + +} // namespace callbacks +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_callbacks.h b/quiche/http2/adapter/nghttp2_callbacks.h new file mode 100644 index 000000000000..dfdff066457e --- /dev/null +++ b/quiche/http2/adapter/nghttp2_callbacks.h @@ -0,0 +1,90 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_CALLBACKS_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_CALLBACKS_H_ + +#include + +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/http2/adapter/nghttp2_util.h" + +namespace http2 { +namespace adapter { +namespace callbacks { + +// The following functions are nghttp2 callbacks that Nghttp2Adapter sets at the +// beginning of its lifetime. It is expected that |user_data| holds an +// Http2VisitorInterface. + +// Callback once the library is ready to send serialized frames. +ssize_t OnReadyToSend(nghttp2_session* session, const uint8_t* data, + size_t length, int flags, void* user_data); + +// Callback once a frame header has been received. +int OnBeginFrame(nghttp2_session* session, const nghttp2_frame_hd* header, + void* user_data); + +// Callback once a complete frame has been received. +int OnFrameReceived(nghttp2_session* session, const nghttp2_frame* frame, + void* user_data); + +// Callback at the start of a frame carrying headers. +int OnBeginHeaders(nghttp2_session* session, const nghttp2_frame* frame, + void* user_data); + +// Callback once a name-value header has been received. +int OnHeader(nghttp2_session* session, const nghttp2_frame* frame, + nghttp2_rcbuf* name, nghttp2_rcbuf* value, uint8_t flags, + void* user_data); + +// Invoked immediately before sending a frame. +int OnBeforeFrameSent(nghttp2_session* session, const nghttp2_frame* frame, + void* user_data); + +// Invoked immediately after a frame is sent. +int OnFrameSent(nghttp2_session* session, const nghttp2_frame* frame, + void* user_data); + +// Invoked when a non-DATA frame is not sent because of an error. +int OnFrameNotSent(nghttp2_session* session, const nghttp2_frame* frame, + int lib_error_code, void* user_data); + +// Invoked when an invalid frame is received. +int OnInvalidFrameReceived(nghttp2_session* session, const nghttp2_frame* frame, + int lib_error_code, void* user_data); + +// Invoked when a chunk of data (from a DATA frame payload) has been received. +int OnDataChunk(nghttp2_session* session, uint8_t flags, + Http2StreamId stream_id, const uint8_t* data, size_t len, + void* user_data); + +// Callback once a stream has been closed. +int OnStreamClosed(nghttp2_session* session, Http2StreamId stream_id, + uint32_t error_code, void* user_data); + +// Invoked when nghttp2 has a chunk of extension frame data to pass to the +// application. +int OnExtensionChunkReceived(nghttp2_session* session, + const nghttp2_frame_hd* hd, const uint8_t* data, + size_t len, void* user_data); + +// Invoked when nghttp2 wants the application to unpack an extension payload. +int OnUnpackExtensionCallback(nghttp2_session* session, void** payload, + const nghttp2_frame_hd* hd, void* user_data); + +// Invoked when nghttp2 is ready to pack an extension payload. Returns the +// number of bytes serialized to |buf|. +ssize_t OnPackExtensionCallback(nghttp2_session* session, uint8_t* buf, + size_t len, const nghttp2_frame* frame, + void* user_data); + +// Invoked when the library has an error message to deliver. +int OnError(nghttp2_session* session, int lib_error_code, const char* msg, + size_t len, void* user_data); + +nghttp2_session_callbacks_unique_ptr Create(); + +} // namespace callbacks +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_CALLBACKS_H_ diff --git a/quiche/http2/adapter/nghttp2_data_provider.cc b/quiche/http2/adapter/nghttp2_data_provider.cc new file mode 100644 index 000000000000..c0d76d197b36 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_data_provider.cc @@ -0,0 +1,62 @@ +#include "quiche/http2/adapter/nghttp2_data_provider.h" + +#include + +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/nghttp2_util.h" + +namespace http2 { +namespace adapter { +namespace callbacks { + +namespace { +const size_t kFrameHeaderSize = 9; +} + +ssize_t DataFrameSourceReadCallback(nghttp2_session* /* session */, + int32_t /* stream_id */, uint8_t* /* buf */, + size_t length, uint32_t* data_flags, + nghttp2_data_source* source, + void* /* user_data */) { + *data_flags |= NGHTTP2_DATA_FLAG_NO_COPY; + auto* frame_source = static_cast(source->ptr); + auto [result_length, done] = frame_source->SelectPayloadLength(length); + if (result_length == 0 && !done) { + return NGHTTP2_ERR_DEFERRED; + } else if (result_length == DataFrameSource::kError) { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; + } + if (done) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + if (!frame_source->send_fin()) { + *data_flags |= NGHTTP2_DATA_FLAG_NO_END_STREAM; + } + return result_length; +} + +int DataFrameSourceSendCallback(nghttp2_session* /* session */, + nghttp2_frame* /* frame */, + const uint8_t* framehd, size_t length, + nghttp2_data_source* source, + void* /* user_data */) { + auto* frame_source = static_cast(source->ptr); + frame_source->Send(ToStringView(framehd, kFrameHeaderSize), length); + return 0; +} + +} // namespace callbacks + +std::unique_ptr MakeDataProvider( + DataFrameSource* source) { + if (source == nullptr) { + return nullptr; + } + auto provider = std::make_unique(); + provider->source.ptr = source; + provider->read_callback = &callbacks::DataFrameSourceReadCallback; + return provider; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_data_provider.h b/quiche/http2/adapter/nghttp2_data_provider.h new file mode 100644 index 000000000000..a3f095773440 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_data_provider.h @@ -0,0 +1,37 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_DATA_PROVIDER_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_DATA_PROVIDER_H_ + +#include +#include + +#include "quiche/http2/adapter/data_source.h" +#include "quiche/http2/adapter/nghttp2.h" + +namespace http2 { +namespace adapter { +namespace callbacks { + +// Assumes |source| is a DataFrameSource. +ssize_t DataFrameSourceReadCallback(nghttp2_session* /*session */, + int32_t /* stream_id */, uint8_t* /* buf */, + size_t length, uint32_t* data_flags, + nghttp2_data_source* source, + void* /* user_data */); + +int DataFrameSourceSendCallback(nghttp2_session* /* session */, + nghttp2_frame* /* frame */, + const uint8_t* framehd, size_t length, + nghttp2_data_source* source, + void* /* user_data */); + +} // namespace callbacks + +// Transforms a DataFrameSource into a nghttp2_data_provider. Does not take +// ownership of |source|. Returns nullptr if |source| is nullptr. +std::unique_ptr MakeDataProvider( + DataFrameSource* source); + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_DATA_PROVIDER_H_ diff --git a/quiche/http2/adapter/nghttp2_data_provider_test.cc b/quiche/http2/adapter/nghttp2_data_provider_test.cc new file mode 100644 index 000000000000..3d90855c9288 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_data_provider_test.cc @@ -0,0 +1,117 @@ +#include "quiche/http2/adapter/nghttp2_data_provider.h" + +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +const size_t kFrameHeaderSize = 9; + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the amount of data read is less +// than what the source provides. +TEST(DataProviderTest, ReadLessThanSourceProvides) { + DataSavingVisitor visitor; + TestDataFrameSource source(visitor, true); + source.AppendPayload("Example payload"); + source.EndData(); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 10; + // Read callback selects a payload length given an upper bound. + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + ASSERT_EQ(kReadLength, result); + EXPECT_EQ(NGHTTP2_DATA_FLAG_NO_COPY, data_flags); + + const uint8_t framehd[kFrameHeaderSize] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + // Sends the frame header and some payload bytes. + int send_result = callbacks::DataFrameSourceSendCallback( + nullptr, nullptr, framehd, result, &provider->source, nullptr); + EXPECT_EQ(0, send_result); + // Data accepted by the visitor includes a frame header and kReadLength bytes + // of payload. + EXPECT_EQ(visitor.data().size(), kFrameHeaderSize + kReadLength); +} + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the amount of data read is more +// than what the source provides. +TEST(DataProviderTest, ReadMoreThanSourceProvides) { + DataSavingVisitor visitor; + const absl::string_view kPayload = "Example payload"; + TestDataFrameSource source(visitor, true); + source.AppendPayload(kPayload); + source.EndData(); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 30; + // Read callback selects a payload length given an upper bound. + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + ASSERT_EQ(kPayload.size(), result); + EXPECT_EQ(NGHTTP2_DATA_FLAG_NO_COPY | NGHTTP2_DATA_FLAG_EOF, data_flags); + + const uint8_t framehd[kFrameHeaderSize] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + // Sends the frame header and some payload bytes. + int send_result = callbacks::DataFrameSourceSendCallback( + nullptr, nullptr, framehd, result, &provider->source, nullptr); + EXPECT_EQ(0, send_result); + // Data accepted by the visitor includes a frame header and the entire + // payload. + EXPECT_EQ(visitor.data().size(), kFrameHeaderSize + kPayload.size()); +} + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the source is blocked. +TEST(DataProviderTest, ReadFromBlockedSource) { + DataSavingVisitor visitor; + // Source has no payload, but also no fin, so it's blocked. + TestDataFrameSource source(visitor, false); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 10; + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + // Read operation is deferred, since the source is blocked. + EXPECT_EQ(NGHTTP2_ERR_DEFERRED, result); +} + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the source provides only fin and +// no data. +TEST(DataProviderTest, ReadFromZeroLengthSource) { + DataSavingVisitor visitor; + // Empty payload and fin=true indicates the source is done. + TestDataFrameSource source(visitor, true); + source.EndData(); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 10; + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + ASSERT_EQ(0, result); + EXPECT_EQ(NGHTTP2_DATA_FLAG_NO_COPY | NGHTTP2_DATA_FLAG_EOF, data_flags); + + const uint8_t framehd[kFrameHeaderSize] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + int send_result = callbacks::DataFrameSourceSendCallback( + nullptr, nullptr, framehd, result, &provider->source, nullptr); + EXPECT_EQ(0, send_result); + // Data accepted by the visitor includes a frame header with fin and zero + // bytes of payload. + EXPECT_EQ(visitor.data().size(), kFrameHeaderSize); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_session.cc b/quiche/http2/adapter/nghttp2_session.cc new file mode 100644 index 000000000000..f7f49fb6008c --- /dev/null +++ b/quiche/http2/adapter/nghttp2_session.cc @@ -0,0 +1,57 @@ +#include "quiche/http2/adapter/nghttp2_session.h" + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace adapter { + +NgHttp2Session::NgHttp2Session(Perspective perspective, + nghttp2_session_callbacks_unique_ptr callbacks, + const nghttp2_option* options, void* userdata) + : session_(MakeSessionPtr(nullptr)), perspective_(perspective) { + nghttp2_session* session; + switch (perspective_) { + case Perspective::kClient: + nghttp2_session_client_new2(&session, callbacks.get(), userdata, options); + break; + case Perspective::kServer: + nghttp2_session_server_new2(&session, callbacks.get(), userdata, options); + break; + } + session_ = MakeSessionPtr(session); +} + +NgHttp2Session::~NgHttp2Session() { + // Can't invoke want_read() or want_write(), as they are virtual methods. + const bool pending_reads = nghttp2_session_want_read(session_.get()) != 0; + const bool pending_writes = nghttp2_session_want_write(session_.get()) != 0; + if (pending_reads || pending_writes) { + QUICHE_VLOG(1) << "Shutting down connection with pending reads: " + << pending_reads << " or pending writes: " << pending_writes; + } +} + +int64_t NgHttp2Session::ProcessBytes(absl::string_view bytes) { + return nghttp2_session_mem_recv( + session_.get(), reinterpret_cast(bytes.data()), + bytes.size()); +} + +int NgHttp2Session::Consume(Http2StreamId stream_id, size_t num_bytes) { + return nghttp2_session_consume(session_.get(), stream_id, num_bytes); +} + +bool NgHttp2Session::want_read() const { + return nghttp2_session_want_read(session_.get()) != 0; +} + +bool NgHttp2Session::want_write() const { + return nghttp2_session_want_write(session_.get()) != 0; +} + +int NgHttp2Session::GetRemoteWindowSize() const { + return nghttp2_session_get_remote_window_size(session_.get()); +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_session.h b/quiche/http2/adapter/nghttp2_session.h new file mode 100644 index 000000000000..b07119fa2eb3 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_session.h @@ -0,0 +1,41 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_SESSION_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_SESSION_H_ + +#include + +#include "quiche/http2/adapter/http2_session.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// A C++ wrapper around common nghttp2_session operations. +class QUICHE_EXPORT NgHttp2Session : public Http2Session { + public: + // Does not take ownership of |options|. + NgHttp2Session(Perspective perspective, + nghttp2_session_callbacks_unique_ptr callbacks, + const nghttp2_option* options, void* userdata); + ~NgHttp2Session() override; + + int64_t ProcessBytes(absl::string_view bytes) override; + + int Consume(Http2StreamId stream_id, size_t num_bytes) override; + + bool want_read() const override; + bool want_write() const override; + int GetRemoteWindowSize() const override; + + nghttp2_session* raw_ptr() const { return session_.get(); } + + private: + nghttp2_session_unique_ptr session_; + Perspective perspective_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_SESSION_H_ diff --git a/quiche/http2/adapter/nghttp2_session_test.cc b/quiche/http2/adapter/nghttp2_session_test.cc new file mode 100644 index 000000000000..f11c490569bd --- /dev/null +++ b/quiche/http2/adapter/nghttp2_session_test.cc @@ -0,0 +1,374 @@ +#include "quiche/http2/adapter/nghttp2_session.h" + +#include "quiche/http2/adapter/mock_http2_visitor.h" +#include "quiche/http2/adapter/nghttp2_callbacks.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, +}; + +class NgHttp2SessionTest : public quiche::test::QuicheTest { + public: + void SetUp() override { + nghttp2_option_new(&options_); + nghttp2_option_set_no_auto_window_update(options_, 1); + } + + void TearDown() override { nghttp2_option_del(options_); } + + nghttp2_session_callbacks_unique_ptr CreateCallbacks() { + nghttp2_session_callbacks_unique_ptr callbacks = callbacks::Create(); + return callbacks; + } + + DataSavingVisitor visitor_; + nghttp2_option* options_ = nullptr; +}; + +TEST_F(NgHttp2SessionTest, ClientConstruction) { + NgHttp2Session session(Perspective::kClient, CreateCallbacks(), options_, + &visitor_); + EXPECT_TRUE(session.want_read()); + EXPECT_FALSE(session.want_write()); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); + EXPECT_NE(session.raw_ptr(), nullptr); +} + +TEST_F(NgHttp2SessionTest, ClientHandlesFrames) { + NgHttp2Session session(Perspective::kClient, CreateCallbacks(), options_, + &visitor_); + + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + ASSERT_GT(visitor_.data().size(), 0); + + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor_, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor_, OnSettingsStart()); + EXPECT_CALL(visitor_, OnSettingsEnd()); + + EXPECT_CALL(visitor_, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor_, OnPing(42, false)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor_, OnWindowUpdate(0, 1000)); + + const int64_t initial_result = session.ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_EQ(session.GetRemoteWindowSize(), + kInitialFlowControlWindowSize + 1000); + + EXPECT_CALL(visitor_, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(PING, 0, 8, 0x1, 0)); + + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + // Some bytes should have been serialized. + absl::string_view serialized = visitor_.data(); + ASSERT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING})); + visitor_.Clear(); + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const auto nvs1 = GetNghttp2Nvs(headers1); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + const auto nvs2 = GetNghttp2Nvs(headers2); + + const std::vector
headers3 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}}); + const auto nvs3 = GetNghttp2Nvs(headers3); + + const int32_t stream_id1 = nghttp2_submit_request( + session.raw_ptr(), nullptr, nvs1.data(), nvs1.size(), nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + const int32_t stream_id2 = nghttp2_submit_request( + session.raw_ptr(), nullptr, nvs2.data(), nvs2.size(), nullptr, nullptr); + ASSERT_GT(stream_id2, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id2; + + const int32_t stream_id3 = nghttp2_submit_request( + session.raw_ptr(), nullptr, nvs3.data(), nvs3.size(), nullptr, nullptr); + ASSERT_GT(stream_id3, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id3; + + EXPECT_CALL(visitor_, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor_, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(HEADERS, 3, _, 0x5)); + EXPECT_CALL(visitor_, OnFrameSent(HEADERS, 3, _, 0x5, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(HEADERS, 5, _, 0x5)); + EXPECT_CALL(visitor_, OnFrameSent(HEADERS, 5, _, 0x5, 0)); + + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + serialized = visitor_.data(); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::HEADERS})); + visitor_.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .RstStream(3, Http2ErrorCode::INTERNAL_ERROR) + .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") + .Serialize(); + + EXPECT_CALL(visitor_, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor_, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor_, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor_, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor_, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor_, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor_, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor_, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor_, OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor_, OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 19, GOAWAY, 0)); + EXPECT_CALL(visitor_, + OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!")); + const int64_t stream_result = session.ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + // Even though the client recieved a GOAWAY, streams 1 and 5 are still active. + EXPECT_TRUE(session.want_read()); + + EXPECT_CALL(visitor_, OnFrameHeader(1, 0, DATA, 1)); + EXPECT_CALL(visitor_, OnBeginDataForStream(1, 0)); + EXPECT_CALL(visitor_, OnEndStream(1)); + EXPECT_CALL(visitor_, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor_, OnFrameHeader(5, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor_, OnRstStream(5, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor_, OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM)); + session.ProcessBytes(TestFrameSequence() + .Data(1, "", true) + .RstStream(5, Http2ErrorCode::REFUSED_STREAM) + .Serialize()); + // After receiving END_STREAM for 1 and RST_STREAM for 5, the session no + // longer expects reads. + EXPECT_FALSE(session.want_read()); + + // Client will not have anything else to write. + EXPECT_FALSE(session.want_write()); + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + serialized = visitor_.data(); + EXPECT_EQ(serialized.size(), 0); +} + +TEST_F(NgHttp2SessionTest, ServerConstruction) { + NgHttp2Session session(Perspective::kServer, CreateCallbacks(), options_, + &visitor_); + EXPECT_TRUE(session.want_read()); + EXPECT_FALSE(session.want_write()); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); + EXPECT_NE(session.raw_ptr(), nullptr); +} + +TEST_F(NgHttp2SessionTest, ServerHandlesFrames) { + NgHttp2Session session(Perspective::kServer, CreateCallbacks(), options_, + &visitor_); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor_, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor_, OnSettingsStart()); + EXPECT_CALL(visitor_, OnSettingsEnd()); + + EXPECT_CALL(visitor_, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor_, OnPing(42, false)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor_, OnWindowUpdate(0, 1000)); + EXPECT_CALL(visitor_, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor_, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor_, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor_, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor_, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor_, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor_, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor_, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor_, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor_, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor_, OnHeaderForStream(3, ":method", "GET")); + EXPECT_CALL(visitor_, OnHeaderForStream(3, ":scheme", "http")); + EXPECT_CALL(visitor_, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor_, OnHeaderForStream(3, ":path", "/this/is/request/two")); + EXPECT_CALL(visitor_, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor_, OnEndStream(3)); + EXPECT_CALL(visitor_, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor_, OnRstStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor_, OnCloseStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor_, OnPing(47, false)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(session.GetRemoteWindowSize(), + kInitialFlowControlWindowSize + 1000); + + EXPECT_CALL(visitor_, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(PING, 0, 8, 0x1, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(PING, 0, 8, 0x1, 0)); + + EXPECT_TRUE(session.want_write()); + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + // Some bytes should have been serialized. + absl::string_view serialized = visitor_.data(); + // SETTINGS ack, two PING acks. + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING, + spdy::SpdyFrameType::PING})); +} + +// Verifies that a null payload is caught by the OnPackExtensionCallback +// implementation. +TEST_F(NgHttp2SessionTest, NullPayload) { + NgHttp2Session session(Perspective::kClient, CreateCallbacks(), options_, + &visitor_); + + void* payload = nullptr; + const int result = nghttp2_submit_extension( + session.raw_ptr(), kMetadataFrameType, 0, 1, payload); + ASSERT_EQ(0, result); + EXPECT_TRUE(session.want_write()); + int send_result = -1; + EXPECT_QUICHE_BUG( + { + send_result = nghttp2_session_send(session.raw_ptr()); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, send_result); + }, + "Extension frame payload for stream 1 is null!"); +} + +TEST_F(NgHttp2SessionTest, ServerSeesErrorOnEndStream) { + NgHttp2Session session(Perspective::kServer, CreateCallbacks(), options_, + &visitor_); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/false) + .Data(1, "Request body", true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor_, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor_, OnSettingsStart()); + EXPECT_CALL(visitor_, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor_, OnFrameHeader(1, _, HEADERS, 0x4)); + EXPECT_CALL(visitor_, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor_, OnHeaderForStream(1, ":path", "/")); + EXPECT_CALL(visitor_, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor_, OnFrameHeader(1, _, DATA, 0x1)); + EXPECT_CALL(visitor_, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor_, OnDataForStream(1, "Request body")); + EXPECT_CALL(visitor_, OnEndStream(1)).WillOnce(testing::Return(false)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, result); + + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + EXPECT_THAT(visitor_.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor_.Clear(); + + EXPECT_FALSE(session.want_write()); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_test.cc b/quiche/http2/adapter/nghttp2_test.cc new file mode 100644 index 000000000000..1363ba690dd3 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_test.cc @@ -0,0 +1,262 @@ +#include "quiche/http2/adapter/nghttp2.h" + +#include "absl/strings/str_cat.h" +#include "quiche/http2/adapter/mock_nghttp2_callbacks.h" +#include "quiche/http2/adapter/nghttp2_test_utils.h" +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, +}; + +nghttp2_option* GetOptions() { + nghttp2_option* options; + nghttp2_option_new(&options); + // Set some common options for compatibility. + nghttp2_option_set_no_closed_streams(options, 1); + nghttp2_option_set_no_auto_window_update(options, 1); + nghttp2_option_set_max_send_header_block_length(options, 0x2000000); + nghttp2_option_set_max_outbound_ack(options, 10000); + return options; +} + +class Nghttp2Test : public quiche::test::QuicheTest { + public: + Nghttp2Test() : session_(MakeSessionPtr(nullptr)) {} + + void SetUp() override { InitializeSession(); } + + virtual Perspective GetPerspective() = 0; + + void InitializeSession() { + auto nghttp2_callbacks = MockNghttp2Callbacks::GetCallbacks(); + nghttp2_option* options = GetOptions(); + nghttp2_session* ptr; + if (GetPerspective() == Perspective::kClient) { + nghttp2_session_client_new2(&ptr, nghttp2_callbacks.get(), + &mock_callbacks_, options); + } else { + nghttp2_session_server_new2(&ptr, nghttp2_callbacks.get(), + &mock_callbacks_, options); + } + nghttp2_option_del(options); + + // Sets up the Send() callback to append to |serialized_|. + EXPECT_CALL(mock_callbacks_, Send(_, _, _)) + .WillRepeatedly( + [this](const uint8_t* data, size_t length, int /*flags*/) { + absl::StrAppend(&serialized_, ToStringView(data, length)); + return length; + }); + // Sets up the SendData() callback to fetch and append data from a + // TestDataSource. + EXPECT_CALL(mock_callbacks_, SendData(_, _, _, _)) + .WillRepeatedly([this](nghttp2_frame* /*frame*/, const uint8_t* framehd, + size_t length, nghttp2_data_source* source) { + QUICHE_LOG(INFO) << "Appending frame header and " << length + << " bytes of data"; + auto* s = static_cast(source->ptr); + absl::StrAppend(&serialized_, ToStringView(framehd, 9), + s->ReadNext(length)); + return 0; + }); + session_ = MakeSessionPtr(ptr); + } + + testing::StrictMock mock_callbacks_; + nghttp2_session_unique_ptr session_; + std::string serialized_; +}; + +class Nghttp2ClientTest : public Nghttp2Test { + public: + Perspective GetPerspective() override { return Perspective::kClient; } +}; + +// Verifies nghttp2 behavior when acting as a client. +TEST_F(Nghttp2ClientTest, ClientReceivesUnexpectedHeaders) { + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + + testing::InSequence seq; + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, SETTINGS, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsSettings(testing::IsEmpty()))); + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, PING, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsPing(42))); + EXPECT_CALL(mock_callbacks_, + OnBeginFrame(HasFrameHeader(0, WINDOW_UPDATE, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsWindowUpdate(1000))); + + ssize_t result = nghttp2_session_mem_recv( + session_.get(), ToUint8Ptr(initial_frames.data()), initial_frames.size()); + ASSERT_EQ(result, initial_frames.size()); + + const std::string unexpected_stream_frames = + TestFrameSequence() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .RstStream(3, Http2ErrorCode::INTERNAL_ERROR) + .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") + .Serialize(); + + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(1, HEADERS, _))); + EXPECT_CALL(mock_callbacks_, OnInvalidFrameRecv(IsHeaders(1, _, _), _)); + // No events from the DATA, RST_STREAM or GOAWAY. + + nghttp2_session_mem_recv(session_.get(), + ToUint8Ptr(unexpected_stream_frames.data()), + unexpected_stream_frames.size()); +} + +// Tests the request-sending behavior of nghttp2 when acting as a client. +TEST_F(Nghttp2ClientTest, ClientSendsRequest) { + int result = nghttp2_session_send(session_.get()); + ASSERT_EQ(result, 0); + + EXPECT_THAT(serialized_, testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + serialized_.clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, SETTINGS, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsSettings(testing::IsEmpty()))); + + ssize_t recv_result = nghttp2_session_mem_recv( + session_.get(), ToUint8Ptr(initial_frames.data()), initial_frames.size()); + EXPECT_EQ(initial_frames.size(), recv_result); + + // Client wants to send a SETTINGS ack. + EXPECT_CALL(mock_callbacks_, BeforeFrameSend(IsSettings(testing::IsEmpty()))); + EXPECT_CALL(mock_callbacks_, OnFrameSend(IsSettings(testing::IsEmpty()))); + EXPECT_TRUE(nghttp2_session_want_write(session_.get())); + result = nghttp2_session_send(session_.get()); + EXPECT_THAT(serialized_, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + serialized_.clear(); + + EXPECT_FALSE(nghttp2_session_want_write(session_.get())); + + // The following sets up the client request. + std::vector> headers = { + {":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}; + std::vector nvs; + for (const auto& h : headers) { + nvs.push_back({.name = ToUint8Ptr(h.first.data()), + .value = ToUint8Ptr(h.second.data()), + .namelen = h.first.size(), + .valuelen = h.second.size()}); + } + const absl::string_view kBody = "This is an example request body."; + TestDataSource source{kBody}; + nghttp2_data_provider provider = source.MakeDataProvider(); + // After submitting the request, the client will want to write. + int stream_id = + nghttp2_submit_request(session_.get(), nullptr /* pri_spec */, nvs.data(), + nvs.size(), &provider, nullptr /* stream_data */); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(nghttp2_session_want_write(session_.get())); + + // We expect that the client will want to write HEADERS, then DATA. + EXPECT_CALL(mock_callbacks_, BeforeFrameSend(IsHeaders(stream_id, _, _))); + EXPECT_CALL(mock_callbacks_, OnFrameSend(IsHeaders(stream_id, _, _))); + EXPECT_CALL(mock_callbacks_, OnFrameSend(IsData(stream_id, kBody.size(), _))); + nghttp2_session_send(session_.get()); + EXPECT_THAT(serialized_, EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(serialized_, testing::HasSubstr(kBody)); + + // Once the request is flushed, the client no longer wants to write. + EXPECT_FALSE(nghttp2_session_want_write(session_.get())); +} + +class Nghttp2ServerTest : public Nghttp2Test { + public: + Perspective GetPerspective() override { return Perspective::kServer; } +}; + +// Verifies the behavior when a stream ends early. +TEST_F(Nghttp2ServerTest, MismatchedContentLength) { + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"content-length", "50"}}, + /*fin=*/false) + .Data(1, "Less than 50 bytes.", true) + .Serialize(); + + testing::InSequence seq; + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, SETTINGS, _))); + + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsSettings(testing::IsEmpty()))); + + // HEADERS on stream 1 + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader( + 1, HEADERS, NGHTTP2_FLAG_END_HEADERS))); + + EXPECT_CALL(mock_callbacks_, + OnBeginHeaders(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + + EXPECT_CALL(mock_callbacks_, OnHeader(_, ":method", "POST", _)); + EXPECT_CALL(mock_callbacks_, OnHeader(_, ":scheme", "https", _)); + EXPECT_CALL(mock_callbacks_, OnHeader(_, ":authority", "example.com", _)); + EXPECT_CALL(mock_callbacks_, OnHeader(_, ":path", "/", _)); + EXPECT_CALL(mock_callbacks_, OnHeader(_, "content-length", "50", _)); + EXPECT_CALL(mock_callbacks_, + OnFrameRecv(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + + // DATA on stream 1 + EXPECT_CALL(mock_callbacks_, + OnBeginFrame(HasFrameHeader(1, DATA, NGHTTP2_FLAG_END_STREAM))); + + EXPECT_CALL(mock_callbacks_, OnDataChunkRecv(NGHTTP2_FLAG_END_STREAM, 1, + "Less than 50 bytes.")); + + // No OnFrameRecv() callback for the DATA frame, since there is a + // Content-Length mismatch error. + + ssize_t result = nghttp2_session_mem_recv( + session_.get(), ToUint8Ptr(initial_frames.data()), initial_frames.size()); + ASSERT_EQ(result, initial_frames.size()); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_test_utils.cc b/quiche/http2/adapter/nghttp2_test_utils.cc new file mode 100644 index 000000000000..0de00d85d30c --- /dev/null +++ b/quiche/http2/adapter/nghttp2_test_utils.cc @@ -0,0 +1,463 @@ +#include "quiche/http2/adapter/nghttp2_test_utils.h" + +#include "quiche/http2/adapter/nghttp2_util.h" +#include "quiche/common/quiche_endian.h" + +namespace http2 { +namespace adapter { +namespace test { + +namespace { + +// Custom gMock matcher, used to implement HasFrameHeader(). +class FrameHeaderMatcher { + public: + FrameHeaderMatcher(int32_t streamid, uint8_t type, + const testing::Matcher flags) + : stream_id_(streamid), type_(type), flags_(flags) {} + + bool Match(const nghttp2_frame_hd& frame, + testing::MatchResultListener* listener) const { + bool matched = true; + if (stream_id_ != frame.stream_id) { + *listener << "; expected stream " << stream_id_ << ", saw " + << frame.stream_id; + matched = false; + } + if (type_ != frame.type) { + *listener << "; expected frame type " << type_ << ", saw " + << static_cast(frame.type); + matched = false; + } + if (!flags_.MatchAndExplain(frame.flags, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const { + *os << "contains a frame header with stream " << stream_id_ << ", type " + << type_ << ", "; + flags_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not contain a frame header with stream " << stream_id_ + << ", type " << type_ << ", "; + flags_.DescribeNegationTo(os); + } + + private: + const int32_t stream_id_; + const int type_; + const testing::Matcher flags_; +}; + +class PointerToFrameHeaderMatcher + : public FrameHeaderMatcher, + public testing::MatcherInterface { + public: + PointerToFrameHeaderMatcher(int32_t streamid, uint8_t type, + const testing::Matcher flags) + : FrameHeaderMatcher(streamid, type, flags) {} + + bool MatchAndExplain(const nghttp2_frame_hd* frame, + testing::MatchResultListener* listener) const override { + return FrameHeaderMatcher::Match(*frame, listener); + } + + void DescribeTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeNegationTo(os); + } +}; + +class ReferenceToFrameHeaderMatcher + : public FrameHeaderMatcher, + public testing::MatcherInterface { + public: + ReferenceToFrameHeaderMatcher(int32_t streamid, uint8_t type, + const testing::Matcher flags) + : FrameHeaderMatcher(streamid, type, flags) {} + + bool MatchAndExplain(const nghttp2_frame_hd& frame, + testing::MatchResultListener* listener) const override { + return FrameHeaderMatcher::Match(frame, listener); + } + + void DescribeTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeNegationTo(os); + } +}; + +class DataMatcher : public testing::MatcherInterface { + public: + DataMatcher(const testing::Matcher stream_id, + const testing::Matcher length, + const testing::Matcher flags, + const testing::Matcher padding) + : stream_id_(stream_id), + length_(length), + flags_(flags), + padding_(padding) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_DATA) { + *listener << "; expected DATA frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!stream_id_.MatchAndExplain(frame->hd.stream_id, listener)) { + matched = false; + } + if (!length_.MatchAndExplain(frame->hd.length, listener)) { + matched = false; + } + if (!flags_.MatchAndExplain(frame->hd.flags, listener)) { + matched = false; + } + if (!padding_.MatchAndExplain(frame->data.padlen, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a DATA frame, "; + stream_id_.DescribeTo(os); + length_.DescribeTo(os); + flags_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a DATA frame, "; + stream_id_.DescribeNegationTo(os); + length_.DescribeNegationTo(os); + flags_.DescribeNegationTo(os); + } + + private: + const testing::Matcher stream_id_; + const testing::Matcher length_; + const testing::Matcher flags_; + const testing::Matcher padding_; +}; + +class HeadersMatcher : public testing::MatcherInterface { + public: + HeadersMatcher(const testing::Matcher stream_id, + const testing::Matcher flags, + const testing::Matcher category) + : stream_id_(stream_id), flags_(flags), category_(category) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_HEADERS) { + *listener << "; expected HEADERS frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!stream_id_.MatchAndExplain(frame->hd.stream_id, listener)) { + matched = false; + } + if (!flags_.MatchAndExplain(frame->hd.flags, listener)) { + matched = false; + } + if (!category_.MatchAndExplain(frame->headers.cat, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a HEADERS frame, "; + stream_id_.DescribeTo(os); + flags_.DescribeTo(os); + category_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a HEADERS frame, "; + stream_id_.DescribeNegationTo(os); + flags_.DescribeNegationTo(os); + category_.DescribeNegationTo(os); + } + + private: + const testing::Matcher stream_id_; + const testing::Matcher flags_; + const testing::Matcher category_; +}; + +class RstStreamMatcher + : public testing::MatcherInterface { + public: + RstStreamMatcher(const testing::Matcher stream_id, + const testing::Matcher error_code) + : stream_id_(stream_id), error_code_(error_code) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_RST_STREAM) { + *listener << "; expected RST_STREAM frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!stream_id_.MatchAndExplain(frame->hd.stream_id, listener)) { + matched = false; + } + if (!error_code_.MatchAndExplain(frame->rst_stream.error_code, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a RST_STREAM frame, "; + stream_id_.DescribeTo(os); + error_code_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a RST_STREAM frame, "; + stream_id_.DescribeNegationTo(os); + error_code_.DescribeNegationTo(os); + } + + private: + const testing::Matcher stream_id_; + const testing::Matcher error_code_; +}; + +class SettingsMatcher : public testing::MatcherInterface { + public: + SettingsMatcher(const testing::Matcher> values) + : values_(values) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_SETTINGS) { + *listener << "; expected SETTINGS frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + std::vector settings; + settings.reserve(frame->settings.niv); + for (size_t i = 0; i < frame->settings.niv; ++i) { + const auto& p = frame->settings.iv[i]; + settings.push_back({static_cast(p.settings_id), p.value}); + } + return values_.MatchAndExplain(settings, listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a SETTINGS frame, "; + values_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a SETTINGS frame, "; + values_.DescribeNegationTo(os); + } + + private: + const testing::Matcher> values_; +}; + +class PingMatcher : public testing::MatcherInterface { + public: + PingMatcher(const testing::Matcher id, bool is_ack) + : id_(id), is_ack_(is_ack) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_PING) { + *listener << "; expected PING frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + bool frame_ack = frame->hd.flags & NGHTTP2_FLAG_ACK; + if (is_ack_ != frame_ack) { + *listener << "; expected is_ack=" << is_ack_ << ", saw " << frame_ack; + matched = false; + } + uint64_t data; + std::memcpy(&data, frame->ping.opaque_data, sizeof(data)); + data = quiche::QuicheEndian::HostToNet64(data); + if (!id_.MatchAndExplain(data, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a PING frame, "; + id_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a PING frame, "; + id_.DescribeNegationTo(os); + } + + private: + const testing::Matcher id_; + const bool is_ack_; +}; + +class GoAwayMatcher : public testing::MatcherInterface { + public: + GoAwayMatcher(const testing::Matcher last_stream_id, + const testing::Matcher error_code, + const testing::Matcher opaque_data) + : last_stream_id_(last_stream_id), + error_code_(error_code), + opaque_data_(opaque_data) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_GOAWAY) { + *listener << "; expected GOAWAY frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!last_stream_id_.MatchAndExplain(frame->goaway.last_stream_id, + listener)) { + matched = false; + } + if (!error_code_.MatchAndExplain(frame->goaway.error_code, listener)) { + matched = false; + } + auto opaque_data = + ToStringView(frame->goaway.opaque_data, frame->goaway.opaque_data_len); + if (!opaque_data_.MatchAndExplain(opaque_data, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a GOAWAY frame, "; + last_stream_id_.DescribeTo(os); + error_code_.DescribeTo(os); + opaque_data_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a GOAWAY frame, "; + last_stream_id_.DescribeNegationTo(os); + error_code_.DescribeNegationTo(os); + opaque_data_.DescribeNegationTo(os); + } + + private: + const testing::Matcher last_stream_id_; + const testing::Matcher error_code_; + const testing::Matcher opaque_data_; +}; + +class WindowUpdateMatcher + : public testing::MatcherInterface { + public: + WindowUpdateMatcher(const testing::Matcher delta) : delta_(delta) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_WINDOW_UPDATE) { + *listener << "; expected WINDOW_UPDATE frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + return delta_.MatchAndExplain(frame->window_update.window_size_increment, + listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a WINDOW_UPDATE frame, "; + delta_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a WINDOW_UPDATE frame, "; + delta_.DescribeNegationTo(os); + } + + private: + const testing::Matcher delta_; +}; + +} // namespace + +testing::Matcher HasFrameHeader( + uint32_t streamid, uint8_t type, const testing::Matcher flags) { + return MakeMatcher(new PointerToFrameHeaderMatcher(streamid, type, flags)); +} + +testing::Matcher HasFrameHeaderRef( + uint32_t streamid, uint8_t type, const testing::Matcher flags) { + return MakeMatcher(new ReferenceToFrameHeaderMatcher(streamid, type, flags)); +} + +testing::Matcher IsData( + const testing::Matcher stream_id, + const testing::Matcher length, const testing::Matcher flags, + const testing::Matcher padding) { + return MakeMatcher(new DataMatcher(stream_id, length, flags, padding)); +} + +testing::Matcher IsHeaders( + const testing::Matcher stream_id, + const testing::Matcher flags, const testing::Matcher category) { + return MakeMatcher(new HeadersMatcher(stream_id, flags, category)); +} + +testing::Matcher IsRstStream( + const testing::Matcher stream_id, + const testing::Matcher error_code) { + return MakeMatcher(new RstStreamMatcher(stream_id, error_code)); +} + +testing::Matcher IsSettings( + const testing::Matcher> values) { + return MakeMatcher(new SettingsMatcher(values)); +} + +testing::Matcher IsPing( + const testing::Matcher id) { + return MakeMatcher(new PingMatcher(id, false)); +} + +testing::Matcher IsPingAck( + const testing::Matcher id) { + return MakeMatcher(new PingMatcher(id, true)); +} + +testing::Matcher IsGoAway( + const testing::Matcher last_stream_id, + const testing::Matcher error_code, + const testing::Matcher opaque_data) { + return MakeMatcher( + new GoAwayMatcher(last_stream_id, error_code, opaque_data)); +} + +testing::Matcher IsWindowUpdate( + const testing::Matcher delta) { + return MakeMatcher(new WindowUpdateMatcher(delta)); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_test_utils.h b/quiche/http2/adapter/nghttp2_test_utils.h new file mode 100644 index 000000000000..9b359f4cef65 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_test_utils.h @@ -0,0 +1,103 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_TEST_UTILS_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_TEST_UTILS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// A simple class that can easily be adapted to act as a nghttp2_data_source. +class QUICHE_NO_EXPORT TestDataSource { + public: + explicit TestDataSource(absl::string_view data) : data_(std::string(data)) {} + + absl::string_view ReadNext(size_t size) { + const size_t to_send = std::min(size, remaining_.size()); + auto ret = remaining_.substr(0, to_send); + remaining_.remove_prefix(to_send); + return ret; + } + + size_t SelectPayloadLength(size_t max_length) { + return std::min(max_length, remaining_.size()); + } + + nghttp2_data_provider MakeDataProvider() { + nghttp2_data_source s; + s.ptr = this; + return nghttp2_data_provider{ + s, + [](nghttp2_session*, int32_t, uint8_t*, size_t length, + uint32_t* data_flags, nghttp2_data_source* source, + void*) -> ssize_t { + *data_flags |= NGHTTP2_DATA_FLAG_NO_COPY; + auto* s = static_cast(source->ptr); + if (!s->is_data_available()) { + return NGHTTP2_ERR_DEFERRED; + } + const ssize_t ret = s->SelectPayloadLength(length); + if (ret < static_cast(length)) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return ret; + }}; + } + + bool is_data_available() const { return is_data_available_; } + void set_is_data_available(bool value) { is_data_available_ = value; } + + private: + const std::string data_; + absl::string_view remaining_ = data_; + bool is_data_available_ = true; +}; + +// Matchers for nghttp2 data types. +testing::Matcher HasFrameHeader( + uint32_t streamid, uint8_t type, const testing::Matcher flags); +testing::Matcher HasFrameHeaderRef( + uint32_t streamid, uint8_t type, const testing::Matcher flags); + +testing::Matcher IsData( + const testing::Matcher stream_id, + const testing::Matcher length, const testing::Matcher flags, + const testing::Matcher padding = testing::_); + +testing::Matcher IsHeaders( + const testing::Matcher stream_id, + const testing::Matcher flags, const testing::Matcher category); + +testing::Matcher IsRstStream( + const testing::Matcher stream_id, + const testing::Matcher error_code); + +testing::Matcher IsSettings( + const testing::Matcher> values); + +testing::Matcher IsPing( + const testing::Matcher id); + +testing::Matcher IsPingAck( + const testing::Matcher id); + +testing::Matcher IsGoAway( + const testing::Matcher last_stream_id, + const testing::Matcher error_code, + const testing::Matcher opaque_data); + +testing::Matcher IsWindowUpdate( + const testing::Matcher delta); + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_TEST_UTILS_H_ diff --git a/quiche/http2/adapter/nghttp2_util.cc b/quiche/http2/adapter/nghttp2_util.cc new file mode 100644 index 000000000000..0d5dc5e41d1e --- /dev/null +++ b/quiche/http2/adapter/nghttp2_util.cc @@ -0,0 +1,304 @@ +#include "quiche/http2/adapter/nghttp2_util.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace http2 { +namespace adapter { + +namespace { + +using InvalidFrameError = Http2VisitorInterface::InvalidFrameError; + +void DeleteCallbacks(nghttp2_session_callbacks* callbacks) { + if (callbacks) { + nghttp2_session_callbacks_del(callbacks); + } +} + +void DeleteSession(nghttp2_session* session) { + if (session) { + nghttp2_session_del(session); + } +} + +} // namespace + +nghttp2_session_callbacks_unique_ptr MakeCallbacksPtr( + nghttp2_session_callbacks* callbacks) { + return nghttp2_session_callbacks_unique_ptr(callbacks, &DeleteCallbacks); +} + +nghttp2_session_unique_ptr MakeSessionPtr(nghttp2_session* session) { + return nghttp2_session_unique_ptr(session, &DeleteSession); +} + +uint8_t* ToUint8Ptr(char* str) { return reinterpret_cast(str); } +uint8_t* ToUint8Ptr(const char* str) { + return const_cast(reinterpret_cast(str)); +} + +absl::string_view ToStringView(nghttp2_rcbuf* rc_buffer) { + nghttp2_vec buffer = nghttp2_rcbuf_get_buf(rc_buffer); + return absl::string_view(reinterpret_cast(buffer.base), + buffer.len); +} + +absl::string_view ToStringView(uint8_t* pointer, size_t length) { + return absl::string_view(reinterpret_cast(pointer), length); +} + +absl::string_view ToStringView(const uint8_t* pointer, size_t length) { + return absl::string_view(reinterpret_cast(pointer), length); +} + +std::vector GetNghttp2Nvs(absl::Span headers) { + const int num_headers = headers.size(); + std::vector nghttp2_nvs; + nghttp2_nvs.reserve(num_headers); + for (int i = 0; i < num_headers; ++i) { + nghttp2_nv header; + uint8_t flags = NGHTTP2_NV_FLAG_NONE; + + const auto [name, no_copy_name] = GetStringView(headers[i].first); + header.name = ToUint8Ptr(name.data()); + header.namelen = name.size(); + if (no_copy_name) { + flags |= NGHTTP2_NV_FLAG_NO_COPY_NAME; + } + const auto [value, no_copy_value] = GetStringView(headers[i].second); + header.value = ToUint8Ptr(value.data()); + header.valuelen = value.size(); + if (no_copy_value) { + flags |= NGHTTP2_NV_FLAG_NO_COPY_VALUE; + } + header.flags = flags; + nghttp2_nvs.push_back(std::move(header)); + } + + return nghttp2_nvs; +} + +std::vector GetResponseNghttp2Nvs( + const spdy::Http2HeaderBlock& headers, absl::string_view response_code) { + // Allocate enough for all headers and also the :status pseudoheader. + const int num_headers = headers.size(); + std::vector nghttp2_nvs; + nghttp2_nvs.reserve(num_headers + 1); + + // Add the :status pseudoheader first. + nghttp2_nv status; + status.name = ToUint8Ptr(kHttp2StatusPseudoHeader); + status.namelen = strlen(kHttp2StatusPseudoHeader); + status.value = ToUint8Ptr(response_code.data()); + status.valuelen = response_code.size(); + status.flags = NGHTTP2_FLAG_NONE; + nghttp2_nvs.push_back(std::move(status)); + + // Add the remaining headers. + for (const auto& header_pair : headers) { + nghttp2_nv header; + header.name = ToUint8Ptr(header_pair.first.data()); + header.namelen = header_pair.first.size(); + header.value = ToUint8Ptr(header_pair.second.data()); + header.valuelen = header_pair.second.size(); + header.flags = NGHTTP2_FLAG_NONE; + nghttp2_nvs.push_back(std::move(header)); + } + + return nghttp2_nvs; +} + +Http2ErrorCode ToHttp2ErrorCode(uint32_t wire_error_code) { + if (wire_error_code > static_cast(Http2ErrorCode::MAX_ERROR_CODE)) { + return Http2ErrorCode::INTERNAL_ERROR; + } + return static_cast(wire_error_code); +} + +int ToNgHttp2ErrorCode(InvalidFrameError error) { + switch (error) { + case InvalidFrameError::kProtocol: + return NGHTTP2_ERR_PROTO; + case InvalidFrameError::kRefusedStream: + return NGHTTP2_ERR_REFUSED_STREAM; + case InvalidFrameError::kHttpHeader: + return NGHTTP2_ERR_HTTP_HEADER; + case InvalidFrameError::kHttpMessaging: + return NGHTTP2_ERR_HTTP_MESSAGING; + case InvalidFrameError::kFlowControl: + return NGHTTP2_ERR_FLOW_CONTROL; + case InvalidFrameError::kStreamClosed: + return NGHTTP2_ERR_STREAM_CLOSED; + } + return NGHTTP2_ERR_PROTO; +} + +InvalidFrameError ToInvalidFrameError(int error) { + switch (error) { + case NGHTTP2_ERR_PROTO: + return InvalidFrameError::kProtocol; + case NGHTTP2_ERR_REFUSED_STREAM: + return InvalidFrameError::kRefusedStream; + case NGHTTP2_ERR_HTTP_HEADER: + return InvalidFrameError::kHttpHeader; + case NGHTTP2_ERR_HTTP_MESSAGING: + return InvalidFrameError::kHttpMessaging; + case NGHTTP2_ERR_FLOW_CONTROL: + return InvalidFrameError::kFlowControl; + case NGHTTP2_ERR_STREAM_CLOSED: + return InvalidFrameError::kStreamClosed; + } + return InvalidFrameError::kProtocol; +} + +class Nghttp2DataFrameSource : public DataFrameSource { + public: + Nghttp2DataFrameSource(nghttp2_data_provider provider, + nghttp2_send_data_callback send_data, void* user_data) + : provider_(std::move(provider)), + send_data_(std::move(send_data)), + user_data_(user_data) {} + + std::pair SelectPayloadLength(size_t max_length) override { + const int32_t stream_id = 0; + uint32_t data_flags = 0; + int64_t result = provider_.read_callback( + nullptr /* session */, stream_id, nullptr /* buf */, max_length, + &data_flags, &provider_.source, nullptr /* user_data */); + if (result == NGHTTP2_ERR_DEFERRED) { + return {kBlocked, false}; + } else if (result < 0) { + return {kError, false}; + } else if ((data_flags & NGHTTP2_DATA_FLAG_NO_COPY) == 0) { + QUICHE_LOG(ERROR) << "Source did not use the zero-copy API!"; + return {kError, false}; + } else { + const bool eof = data_flags & NGHTTP2_DATA_FLAG_EOF; + if (eof && (data_flags & NGHTTP2_DATA_FLAG_NO_END_STREAM) == 0) { + send_fin_ = true; + } + return {result, eof}; + } + } + + bool Send(absl::string_view frame_header, size_t payload_length) override { + nghttp2_frame frame; + frame.hd.type = 0; + frame.hd.length = payload_length; + frame.hd.flags = 0; + frame.hd.stream_id = 0; + frame.data.padlen = 0; + const int result = send_data_( + nullptr /* session */, &frame, ToUint8Ptr(frame_header.data()), + payload_length, &provider_.source, user_data_); + QUICHE_LOG_IF(ERROR, result < 0 && result != NGHTTP2_ERR_WOULDBLOCK) + << "Unexpected error code from send: " << result; + return result == 0; + } + + bool send_fin() const override { return send_fin_; } + + private: + nghttp2_data_provider provider_; + nghttp2_send_data_callback send_data_; + void* user_data_; + bool send_fin_ = false; +}; + +std::unique_ptr MakeZeroCopyDataFrameSource( + nghttp2_data_provider provider, void* user_data, + nghttp2_send_data_callback send_data) { + return std::make_unique( + std::move(provider), std::move(send_data), user_data); +} + +absl::string_view ErrorString(uint32_t error_code) { + return Http2ErrorCodeToString(static_cast(error_code)); +} + +size_t PaddingLength(uint8_t flags, size_t padlen) { + return (flags & PADDED_FLAG ? 1 : 0) + padlen; +} + +struct NvFormatter { + void operator()(std::string* out, const nghttp2_nv& nv) { + absl::StrAppend(out, ToStringView(nv.name, nv.namelen), ": ", + ToStringView(nv.value, nv.valuelen)); + } +}; + +std::string NvsAsString(nghttp2_nv* nva, size_t nvlen) { + return absl::StrJoin(absl::MakeConstSpan(nva, nvlen), ", ", NvFormatter()); +} + +#define HTTP2_FRAME_SEND_LOG QUICHE_VLOG(1) + +void LogBeforeSend(const nghttp2_frame& frame) { + switch (static_cast(frame.hd.type)) { + case FrameType::DATA: + HTTP2_FRAME_SEND_LOG << "Sending DATA on stream " << frame.hd.stream_id + << " with length " + << frame.hd.length - PaddingLength(frame.hd.flags, + frame.data.padlen) + << " and padding " + << PaddingLength(frame.hd.flags, frame.data.padlen); + break; + case FrameType::HEADERS: + HTTP2_FRAME_SEND_LOG << "Sending HEADERS on stream " << frame.hd.stream_id + << " with headers [" + << NvsAsString(frame.headers.nva, + frame.headers.nvlen) + << "]"; + break; + case FrameType::PRIORITY: + HTTP2_FRAME_SEND_LOG << "Sending PRIORITY"; + break; + case FrameType::RST_STREAM: + HTTP2_FRAME_SEND_LOG << "Sending RST_STREAM on stream " + << frame.hd.stream_id << " with error code " + << ErrorString(frame.rst_stream.error_code); + break; + case FrameType::SETTINGS: + HTTP2_FRAME_SEND_LOG << "Sending SETTINGS with " << frame.settings.niv + << " entries, is_ack: " + << (frame.hd.flags & ACK_FLAG); + break; + case FrameType::PUSH_PROMISE: + HTTP2_FRAME_SEND_LOG << "Sending PUSH_PROMISE"; + break; + case FrameType::PING: { + Http2PingId ping_id; + std::memcpy(&ping_id, frame.ping.opaque_data, sizeof(Http2PingId)); + HTTP2_FRAME_SEND_LOG << "Sending PING with unique_id " + << quiche::QuicheEndian::NetToHost64(ping_id) + << ", is_ack: " << (frame.hd.flags & ACK_FLAG); + break; + } + case FrameType::GOAWAY: + HTTP2_FRAME_SEND_LOG << "Sending GOAWAY with last_stream: " + << frame.goaway.last_stream_id << " and error " + << ErrorString(frame.goaway.error_code); + break; + case FrameType::WINDOW_UPDATE: + HTTP2_FRAME_SEND_LOG << "Sending WINDOW_UPDATE on stream " + << frame.hd.stream_id << " with update delta " + << frame.window_update.window_size_increment; + break; + case FrameType::CONTINUATION: + HTTP2_FRAME_SEND_LOG << "Sending CONTINUATION, which is unexpected"; + break; + } +} + +#undef HTTP2_FRAME_SEND_LOG + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/nghttp2_util.h b/quiche/http2/adapter/nghttp2_util.h new file mode 100644 index 000000000000..423ad1bb7c0c --- /dev/null +++ b/quiche/http2/adapter/nghttp2_util.h @@ -0,0 +1,76 @@ +// Various utility/conversion functions for compatibility with the nghttp2 API. + +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_UTIL_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_UTIL_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/http2/adapter/data_source.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/nghttp2.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace http2 { +namespace adapter { + +// Return codes to represent various errors. +inline constexpr int kStreamCallbackFailureStatus = + NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; +inline constexpr int kCancelStatus = NGHTTP2_ERR_CANCEL; + +using CallbacksDeleter = void (*)(nghttp2_session_callbacks*); +using SessionDeleter = void (*)(nghttp2_session*); + +using nghttp2_session_callbacks_unique_ptr = + std::unique_ptr; +using nghttp2_session_unique_ptr = + std::unique_ptr; + +nghttp2_session_callbacks_unique_ptr MakeCallbacksPtr( + nghttp2_session_callbacks* callbacks); +nghttp2_session_unique_ptr MakeSessionPtr(nghttp2_session* session); + +uint8_t* ToUint8Ptr(char* str); +uint8_t* ToUint8Ptr(const char* str); + +absl::string_view ToStringView(nghttp2_rcbuf* rc_buffer); +absl::string_view ToStringView(uint8_t* pointer, size_t length); +absl::string_view ToStringView(const uint8_t* pointer, size_t length); + +// Returns the nghttp2 header structure from the given |headers|, which +// must have the correct pseudoheaders preceding other headers. +std::vector GetNghttp2Nvs(absl::Span headers); + +// Returns the nghttp2 header structure from the given response |headers|, with +// the :status pseudoheader first based on the given |response_code|. The +// |response_code| is passed in separately from |headers| for lifetime reasons. +std::vector GetResponseNghttp2Nvs( + const spdy::Http2HeaderBlock& headers, absl::string_view response_code); + +// Returns the HTTP/2 error code corresponding to the raw wire value, as defined +// in RFC 7540 Section 7. Unrecognized error codes are treated as INTERNAL_ERROR +// based on the RFC 7540 Section 7 suggestion. +Http2ErrorCode ToHttp2ErrorCode(uint32_t wire_error_code); + +// Converts between the integer error code used by nghttp2 and the corresponding +// InvalidFrameError value. +int ToNgHttp2ErrorCode(Http2VisitorInterface::InvalidFrameError error); +Http2VisitorInterface::InvalidFrameError ToInvalidFrameError(int error); + +// Transforms a nghttp2_data_provider into a DataFrameSource. Assumes that +// |provider| uses the zero-copy nghttp2_data_source_read_callback API. Unsafe +// otherwise. +std::unique_ptr MakeZeroCopyDataFrameSource( + nghttp2_data_provider provider, void* user_data, + nghttp2_send_data_callback send_data); + +void LogBeforeSend(const nghttp2_frame& frame); + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_UTIL_H_ diff --git a/quiche/http2/adapter/nghttp2_util_test.cc b/quiche/http2/adapter/nghttp2_util_test.cc new file mode 100644 index 000000000000..9ac6ef687942 --- /dev/null +++ b/quiche/http2/adapter/nghttp2_util_test.cc @@ -0,0 +1,109 @@ +#include "quiche/http2/adapter/nghttp2_util.h" + +#include "quiche/http2/adapter/nghttp2_test_utils.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +// This send callback assumes |source|'s pointer is a TestDataSource, and +// |user_data| is a std::string. +int FakeSendCallback(nghttp2_session*, nghttp2_frame* /*frame*/, + const uint8_t* framehd, size_t length, + nghttp2_data_source* source, void* user_data) { + auto* dest = static_cast(user_data); + // Appends the frame header to the string. + absl::StrAppend(dest, ToStringView(framehd, 9)); + auto* test_source = static_cast(source->ptr); + absl::string_view payload = test_source->ReadNext(length); + // Appends the frame payload to the string. + absl::StrAppend(dest, payload); + return 0; +} + +TEST(MakeZeroCopyDataFrameSource, EmptyPayload) { + std::string result; + + const absl::string_view kEmptyBody = ""; + TestDataSource body1{kEmptyBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &result, FakeSendCallback); + auto [length, eof] = frame_source->SelectPayloadLength(100); + EXPECT_EQ(length, 0); + EXPECT_TRUE(eof); + frame_source->Send("ninebytes", 0); + EXPECT_EQ(result, "ninebytes"); +} + +TEST(MakeZeroCopyDataFrameSource, ShortPayload) { + std::string result; + + const absl::string_view kShortBody = + "Example Page!" + "
Wow!!" + "
" + ""; + TestDataSource body1{kShortBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &result, FakeSendCallback); + auto [length, eof] = frame_source->SelectPayloadLength(200); + EXPECT_EQ(length, kShortBody.size()); + EXPECT_TRUE(eof); + frame_source->Send("ninebytes", length); + EXPECT_EQ(result, absl::StrCat("ninebytes", kShortBody)); +} + +TEST(MakeZeroCopyDataFrameSource, MultiFramePayload) { + std::string result; + + const absl::string_view kShortBody = + "Example Page!" + "
Wow!!" + "
" + ""; + TestDataSource body1{kShortBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &result, FakeSendCallback); + auto ret = frame_source->SelectPayloadLength(50); + EXPECT_EQ(ret.first, 50); + EXPECT_FALSE(ret.second); + frame_source->Send("ninebyte1", ret.first); + + ret = frame_source->SelectPayloadLength(50); + EXPECT_EQ(ret.first, 50); + EXPECT_FALSE(ret.second); + frame_source->Send("ninebyte2", ret.first); + + ret = frame_source->SelectPayloadLength(50); + EXPECT_EQ(ret.first, 44); + EXPECT_TRUE(ret.second); + frame_source->Send("ninebyte3", ret.first); + + EXPECT_EQ(result, + "ninebyte1Example Page!
Wow!!<" + "ninebyte3/th>
"); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/noop_header_validator.cc b/quiche/http2/adapter/noop_header_validator.cc new file mode 100644 index 000000000000..f39342d5bcae --- /dev/null +++ b/quiche/http2/adapter/noop_header_validator.cc @@ -0,0 +1,22 @@ +#include "quiche/http2/adapter/noop_header_validator.h" + +#include "absl/strings/escaping.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace adapter { + +HeaderValidatorBase::HeaderStatus NoopHeaderValidator::ValidateSingleHeader( + absl::string_view key, absl::string_view value) { + if (key == ":status") { + status_ = std::string(value); + } + return HEADER_OK; +} + +bool NoopHeaderValidator::FinishHeaderBlock(HeaderType /* type */) { + return true; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/noop_header_validator.h b/quiche/http2/adapter/noop_header_validator.h new file mode 100644 index 000000000000..f6b95e940994 --- /dev/null +++ b/quiche/http2/adapter/noop_header_validator.h @@ -0,0 +1,25 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NOOP_HEADER_VALIDATOR_H_ +#define QUICHE_HTTP2_ADAPTER_NOOP_HEADER_VALIDATOR_H_ + +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/header_validator_base.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// A validator that does not actually perform any validation. +class QUICHE_EXPORT NoopHeaderValidator : public HeaderValidatorBase { + public: + NoopHeaderValidator() = default; + + HeaderStatus ValidateSingleHeader(absl::string_view key, + absl::string_view value) override; + + bool FinishHeaderBlock(HeaderType type) override; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NOOP_HEADER_VALIDATOR_H_ diff --git a/quiche/http2/adapter/noop_header_validator_test.cc b/quiche/http2/adapter/noop_header_validator_test.cc new file mode 100644 index 000000000000..6340c606c763 --- /dev/null +++ b/quiche/http2/adapter/noop_header_validator_test.cc @@ -0,0 +1,523 @@ +#include "quiche/http2/adapter/noop_header_validator.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +using ::testing::Optional; + +using Header = std::pair; +constexpr Header kSampleRequestPseudoheaders[] = {{":authority", "www.foo.com"}, + {":method", "GET"}, + {":path", "/foo"}, + {":scheme", "https"}}; + +TEST(NoopHeaderValidatorTest, HeaderNameEmpty) { + NoopHeaderValidator v; + NoopHeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("", "value"); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, status); +} + +TEST(NoopHeaderValidatorTest, HeaderValueEmpty) { + NoopHeaderValidator v; + NoopHeaderValidator::HeaderStatus status = v.ValidateSingleHeader("name", ""); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, status); +} + +TEST(NoopHeaderValidatorTest, ExceedsMaxSize) { + NoopHeaderValidator v; + v.SetMaxFieldSize(64u); + NoopHeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", "value"); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, status); + status = v.ValidateSingleHeader( + "name2", + "Antidisestablishmentariansism is supercalifragilisticexpialodocious."); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, status); +} + +TEST(NoopHeaderValidatorTest, AnyNameCharIsValid) { + NoopHeaderValidator v; + char pseudo_name[] = ":met hod"; + char name[] = "na me"; + for (int i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + char c = static_cast(i); + // Test a pseudo-header name with this char. + pseudo_name[3] = c; + auto sv = absl::string_view(pseudo_name, 8); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(sv, "value")); + // Test a regular header name with this char. + name[2] = c; + sv = absl::string_view(name, 5); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(sv, "value")); + } +} + +TEST(NoopHeaderValidatorTest, AnyValueCharIsValid) { + NoopHeaderValidator v; + char value[] = "val ue"; + for (int i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + char c = static_cast(i); + value[3] = c; + auto sv = absl::string_view(value, 6); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("name", sv)); + } +} + +TEST(NoopHeaderValidatorTest, AnyStatusIsValid) { + NoopHeaderValidator v; + + for (HeaderType type : {HeaderType::RESPONSE, HeaderType::RESPONSE_100}) { + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "bar")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "10")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "9000")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "400")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + } +} + +TEST(NoopHeaderValidatorTest, AnyAuthorityCharIsValid) { + char value[] = "ho st.example.com"; + for (int i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + char c = static_cast(i); + value[2] = c; + auto sv = absl::string_view(value, 17); + for (absl::string_view key : {":authority", "host"}) { + NoopHeaderValidator v; + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(key, sv)); + } + } +} + +TEST(NoopHeaderValidatorTest, RequestHostAndAuthority) { + NoopHeaderValidator v; + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + // If both "host" and ":authority" have the same value, validation succeeds. + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("host", "www.foo.com")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + // If "host" and ":authority" have different values, validation still + // succeeds. + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("host", "www.bar.com")); +} + +TEST(NoopHeaderValidatorTest, RequestPseudoHeaders) { + NoopHeaderValidator v; + for (Header to_skip : kSampleRequestPseudoheaders) { + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add != to_skip) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + // Even if a pseudo-header is missing, final validation will succeed. + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + } + + // When all pseudo-headers are present, final validation will succeed. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // When an extra pseudo-header is present, final validation will still + // succeed. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":extra", "blah")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // When a required pseudo-header is repeated, final validation will succeed. + for (Header to_repeat : kSampleRequestPseudoheaders) { + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + if (to_add == to_repeat) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + } +} + +TEST(NoopHeaderValidatorTest, WebsocketPseudoHeaders) { + NoopHeaderValidator v; + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // Validation always succeeds. + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // This is a no-op for NoopHeaderValidator. + v.SetAllowExtendedConnect(); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // The validator does not check for a CONNECT request. + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":method") { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "CONNECT")); + } else { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // After allowing the method, `:protocol` is acepted for CONNECT requests. + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(NoopHeaderValidatorTest, AsteriskPathPseudoHeader) { + NoopHeaderValidator v; + + // The validator does not perform any path validation. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "*")); + } else { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "*")); + } else if (to_add.first == ":method") { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "OPTIONS")); + } else { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(NoopHeaderValidatorTest, InvalidPathPseudoHeader) { + NoopHeaderValidator v; + + // An empty path is allowed. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "")); + } else { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // A path that does not start with a slash is allowed. + v.StartHeaderBlock(); + for (Header to_add : kSampleRequestPseudoheaders) { + if (to_add.first == ":path") { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, "shawarma")); + } else { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add.first, to_add.second)); + } + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(NoopHeaderValidatorTest, ResponsePseudoHeaders) { + NoopHeaderValidator v; + + for (HeaderType type : {HeaderType::RESPONSE, HeaderType::RESPONSE_100}) { + // When `:status` is missing, validation succeeds. + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("foo", "bar")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + + // When all pseudo-headers are present, final validation succeeds. + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + EXPECT_EQ("199", v.status_header()); + + // When `:status` is repeated, validation succeeds. + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "299")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + + // When an extra pseudo-header is present, final validation succeeds. + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":extra", "blorp")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + } +} + +TEST(NoopHeaderValidatorTest, ResponseWithHost) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("host", "myserver.com")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(NoopHeaderValidatorTest, Response204) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "204")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(NoopHeaderValidatorTest, ResponseWithMultipleIdenticalContentLength) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "13")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "13")); +} + +TEST(NoopHeaderValidatorTest, ResponseWithMultipleDifferingContentLength) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "13")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "17")); +} + +TEST(NoopHeaderValidatorTest, Response204WithContentLengthZero) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "204")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "0")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(NoopHeaderValidatorTest, Response204WithContentLength) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "204")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "1")); +} + +TEST(NoopHeaderValidatorTest, Response100) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "100")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(NoopHeaderValidatorTest, Response100WithContentLengthZero) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "100")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "0")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE)); +} + +TEST(NoopHeaderValidatorTest, Response100WithContentLength) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "100")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("x-content", "is not present")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "1")); +} + +TEST(NoopHeaderValidatorTest, ResponseTrailerPseudoHeaders) { + NoopHeaderValidator v; + + // When no pseudo-headers are present, validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("foo", "bar")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE_TRAILER)); + + // When a pseudo-header is present, validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("foo", "bar")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE_TRAILER)); +} + +TEST(NoopHeaderValidatorTest, ValidContentLength) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "41")); + EXPECT_EQ(v.content_length(), absl::nullopt); + + v.StartHeaderBlock(); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "42")); + EXPECT_EQ(v.content_length(), absl::nullopt); +} + +TEST(NoopHeaderValidatorTest, InvalidContentLength) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "")); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "nan")); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "-42")); + EXPECT_EQ(v.content_length(), absl::nullopt); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("content-length", "42")); + EXPECT_EQ(v.content_length(), absl::nullopt); +} + +TEST(NoopHeaderValidatorTest, TeHeader) { + NoopHeaderValidator v; + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("te", "trailers")); + + v.StartHeaderBlock(); + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader("te", "trailers, deflate")); +} + +TEST(NoopHeaderValidatorTest, ConnectionSpecificHeaders) { + const std::vector
connection_headers = { + {"connection", "keep-alive"}, {"proxy-connection", "keep-alive"}, + {"keep-alive", "timeout=42"}, {"transfer-encoding", "chunked"}, + {"upgrade", "h2c"}, + }; + for (const auto& [connection_key, connection_value] : connection_headers) { + NoopHeaderValidator v; + v.StartHeaderBlock(); + for (const auto& [sample_key, sample_value] : kSampleRequestPseudoheaders) { + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(sample_key, sample_value)); + } + EXPECT_EQ(NoopHeaderValidator::HEADER_OK, + v.ValidateSingleHeader(connection_key, connection_value)); + } +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/oghttp2_adapter.cc b/quiche/http2/adapter/oghttp2_adapter.cc new file mode 100644 index 000000000000..fbfc76aae6ae --- /dev/null +++ b/quiche/http2/adapter/oghttp2_adapter.cc @@ -0,0 +1,168 @@ +#include "quiche/http2/adapter/oghttp2_adapter.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "quiche/http2/adapter/http2_util.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { + +namespace { + +using spdy::SpdyGoAwayIR; +using spdy::SpdyPingIR; +using spdy::SpdyPriorityIR; +using spdy::SpdyWindowUpdateIR; + +} // namespace + +/* static */ +std::unique_ptr OgHttp2Adapter::Create( + Http2VisitorInterface& visitor, Options options) { + // Using `new` to access a non-public constructor. + return absl::WrapUnique(new OgHttp2Adapter(visitor, std::move(options))); +} + +OgHttp2Adapter::~OgHttp2Adapter() {} + +bool OgHttp2Adapter::IsServerSession() const { + return session_->IsServerSession(); +} + +int64_t OgHttp2Adapter::ProcessBytes(absl::string_view bytes) { + return session_->ProcessBytes(bytes); +} + +void OgHttp2Adapter::SubmitSettings(absl::Span settings) { + session_->SubmitSettings(settings); +} + +void OgHttp2Adapter::SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, bool exclusive) { + session_->EnqueueFrame(std::make_unique( + stream_id, parent_stream_id, weight, exclusive)); +} + +void OgHttp2Adapter::SubmitPing(Http2PingId ping_id) { + session_->EnqueueFrame(std::make_unique(ping_id)); +} + +void OgHttp2Adapter::SubmitShutdownNotice() { + session_->StartGracefulShutdown(); +} + +void OgHttp2Adapter::SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + session_->EnqueueFrame(std::make_unique( + last_accepted_stream_id, TranslateErrorCode(error_code), + std::string(opaque_data))); +} +void OgHttp2Adapter::SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) { + session_->EnqueueFrame( + std::make_unique(stream_id, window_increment)); +} + +void OgHttp2Adapter::SubmitMetadata(Http2StreamId stream_id, + size_t /* max_frame_size */, + std::unique_ptr source) { + // Not necessary to pass max_frame_size along, since OgHttp2Session tracks the + // peer's advertised max frame size. + session_->SubmitMetadata(stream_id, std::move(source)); +} + +int OgHttp2Adapter::Send() { return session_->Send(); } + +int OgHttp2Adapter::GetSendWindowSize() const { + return session_->GetRemoteWindowSize(); +} + +int OgHttp2Adapter::GetStreamSendWindowSize(Http2StreamId stream_id) const { + return session_->GetStreamSendWindowSize(stream_id); +} + +int OgHttp2Adapter::GetStreamReceiveWindowLimit(Http2StreamId stream_id) const { + return session_->GetStreamReceiveWindowLimit(stream_id); +} + +int OgHttp2Adapter::GetStreamReceiveWindowSize(Http2StreamId stream_id) const { + return session_->GetStreamReceiveWindowSize(stream_id); +} + +int OgHttp2Adapter::GetReceiveWindowSize() const { + return session_->GetReceiveWindowSize(); +} + +int OgHttp2Adapter::GetHpackEncoderDynamicTableSize() const { + return session_->GetHpackEncoderDynamicTableSize(); +} + +int OgHttp2Adapter::GetHpackEncoderDynamicTableCapacity() const { + return session_->GetHpackEncoderDynamicTableCapacity(); +} + +int OgHttp2Adapter::GetHpackDecoderDynamicTableSize() const { + return session_->GetHpackDecoderDynamicTableSize(); +} + +int OgHttp2Adapter::GetHpackDecoderSizeLimit() const { + return session_->GetHpackDecoderSizeLimit(); +} + +Http2StreamId OgHttp2Adapter::GetHighestReceivedStreamId() const { + return session_->GetHighestReceivedStreamId(); +} + +void OgHttp2Adapter::MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) { + session_->Consume(stream_id, num_bytes); +} + +void OgHttp2Adapter::SubmitRst(Http2StreamId stream_id, + Http2ErrorCode error_code) { + session_->EnqueueFrame(std::make_unique( + stream_id, TranslateErrorCode(error_code))); +} + +int32_t OgHttp2Adapter::SubmitRequest( + absl::Span headers, + std::unique_ptr data_source, void* user_data) { + return session_->SubmitRequest(headers, std::move(data_source), user_data); +} + +int OgHttp2Adapter::SubmitResponse( + Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) { + return session_->SubmitResponse(stream_id, headers, std::move(data_source)); +} + +int OgHttp2Adapter::SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) { + return session_->SubmitTrailer(stream_id, trailers); +} + +void OgHttp2Adapter::SetStreamUserData(Http2StreamId stream_id, + void* user_data) { + session_->SetStreamUserData(stream_id, user_data); +} + +void* OgHttp2Adapter::GetStreamUserData(Http2StreamId stream_id) { + return session_->GetStreamUserData(stream_id); +} + +bool OgHttp2Adapter::ResumeStream(Http2StreamId stream_id) { + return session_->ResumeStream(stream_id); +} + +OgHttp2Adapter::OgHttp2Adapter(Http2VisitorInterface& visitor, Options options) + : Http2Adapter(visitor), + session_(std::make_unique(visitor, std::move(options))) {} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/oghttp2_adapter.h b/quiche/http2/adapter/oghttp2_adapter.h new file mode 100644 index 000000000000..76e3b93c7d2b --- /dev/null +++ b/quiche/http2/adapter/oghttp2_adapter.h @@ -0,0 +1,77 @@ +#ifndef QUICHE_HTTP2_ADAPTER_OGHTTP2_ADAPTER_H_ +#define QUICHE_HTTP2_ADAPTER_OGHTTP2_ADAPTER_H_ + +#include +#include + +#include "quiche/http2/adapter/http2_adapter.h" +#include "quiche/http2/adapter/http2_session.h" +#include "quiche/http2/adapter/oghttp2_session.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +class QUICHE_EXPORT OgHttp2Adapter : public Http2Adapter { + public: + using Options = OgHttp2Session::Options; + static std::unique_ptr Create(Http2VisitorInterface& visitor, + Options options); + + ~OgHttp2Adapter() override; + + // From Http2Adapter. + bool IsServerSession() const override; + bool want_read() const override { return session_->want_read(); } + bool want_write() const override { return session_->want_write(); } + int64_t ProcessBytes(absl::string_view bytes) override; + void SubmitSettings(absl::Span settings) override; + void SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, int weight, + bool exclusive) override; + void SubmitPing(Http2PingId ping_id) override; + void SubmitShutdownNotice() override; + void SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + void SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) override; + void SubmitMetadata(Http2StreamId stream_id, size_t max_frame_size, + std::unique_ptr source) override; + int Send() override; + int GetSendWindowSize() const override; + int GetStreamSendWindowSize(Http2StreamId stream_id) const override; + int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const override; + int GetStreamReceiveWindowSize(Http2StreamId stream_id) const override; + int GetReceiveWindowSize() const override; + int GetHpackEncoderDynamicTableSize() const override; + int GetHpackEncoderDynamicTableCapacity() const; + int GetHpackDecoderDynamicTableSize() const override; + int GetHpackDecoderSizeLimit() const; + Http2StreamId GetHighestReceivedStreamId() const override; + void MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) override; + void SubmitRst(Http2StreamId stream_id, Http2ErrorCode error_code) override; + int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data) override; + int SubmitResponse(Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) override; + + int SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) override; + + void SetStreamUserData(Http2StreamId stream_id, void* user_data) override; + void* GetStreamUserData(Http2StreamId stream_id) override; + bool ResumeStream(Http2StreamId stream_id) override; + + private: + OgHttp2Adapter(Http2VisitorInterface& visitor, Options options); + + std::unique_ptr session_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_OGHTTP2_ADAPTER_H_ diff --git a/quiche/http2/adapter/oghttp2_adapter_test.cc b/quiche/http2/adapter/oghttp2_adapter_test.cc new file mode 100644 index 000000000000..39ac613d2416 --- /dev/null +++ b/quiche/http2/adapter/oghttp2_adapter_test.cc @@ -0,0 +1,8352 @@ +#include "quiche/http2/adapter/oghttp2_adapter.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/mock_http2_visitor.h" +#include "quiche/http2/adapter/oghttp2_util.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; + +using spdy::SpdyFrameType; +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +TEST(OgHttp2AdapterTest, IsServerSession) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_TRUE(adapter->IsServerSession()); +} + +TEST(OgHttp2AdapterTest, ProcessBytes) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence seq; + EXPECT_CALL(visitor, OnFrameHeader(0, 0, 4, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, 6, 0)); + EXPECT_CALL(visitor, OnPing(17, false)); + adapter->ProcessBytes( + TestFrameSequence().ClientPreface().Ping(17).Serialize()); +} + +TEST(OgHttp2AdapterTest, HeaderValuesWithObsTextAllowedByDefault) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + ASSERT_TRUE(options.allow_obs_text); + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"name", "val\xa1ue"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "name", "val\xa1ue")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(OgHttp2AdapterTest, HeaderValuesWithObsTextDisallowed) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.allow_obs_text = false; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"name", "val\xa1ue"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(OgHttp2AdapterTest, InitialSettingsNoExtendedConnect) { + DataSavingVisitor client_visitor; + OgHttp2Adapter::Options client_options; + client_options.perspective = Perspective::kClient; + client_options.max_header_list_bytes = 42; + client_options.allow_extended_connect = false; + auto client_adapter = OgHttp2Adapter::Create(client_visitor, client_options); + + DataSavingVisitor server_visitor; + OgHttp2Adapter::Options server_options; + server_options.perspective = Perspective::kServer; + server_options.allow_extended_connect = false; + auto server_adapter = OgHttp2Adapter::Create(server_visitor, server_options); + + testing::InSequence s; + + // Client sends the connection preface, including the initial SETTINGS. + EXPECT_CALL(client_visitor, OnBeforeFrameSent(SETTINGS, 0, 12, 0x0)); + EXPECT_CALL(client_visitor, OnFrameSent(SETTINGS, 0, 12, 0x0, 0)); + { + int result = client_adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = client_visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + } + + // Server sends the connection preface, including the initial SETTINGS. + EXPECT_CALL(server_visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(server_visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + { + int result = server_adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = server_visitor.data(); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + } + + // Client processes the server's initial bytes, including initial SETTINGS. + EXPECT_CALL(client_visitor, OnFrameHeader(0, 0, SETTINGS, 0x0)); + EXPECT_CALL(client_visitor, OnSettingsStart()); + EXPECT_CALL(client_visitor, OnSettingsEnd()); + { + const int64_t result = client_adapter->ProcessBytes(server_visitor.data()); + EXPECT_EQ(server_visitor.data().size(), static_cast(result)); + } + + // Server processes the client's initial bytes, including initial SETTINGS. + EXPECT_CALL(server_visitor, OnFrameHeader(0, 12, SETTINGS, 0x0)); + EXPECT_CALL(server_visitor, OnSettingsStart()); + EXPECT_CALL(server_visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::ENABLE_PUSH, 0u})); + EXPECT_CALL( + server_visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::MAX_HEADER_LIST_SIZE, 42u})); + EXPECT_CALL(server_visitor, OnSettingsEnd()); + { + const int64_t result = server_adapter->ProcessBytes(client_visitor.data()); + EXPECT_EQ(client_visitor.data().size(), static_cast(result)); + } +} + +TEST(OgHttp2AdapterTest, InitialSettings) { + DataSavingVisitor client_visitor; + OgHttp2Adapter::Options client_options; + client_options.perspective = Perspective::kClient; + client_options.max_header_list_bytes = 42; + ASSERT_TRUE(client_options.allow_extended_connect); + auto client_adapter = OgHttp2Adapter::Create(client_visitor, client_options); + + DataSavingVisitor server_visitor; + OgHttp2Adapter::Options server_options; + server_options.perspective = Perspective::kServer; + ASSERT_TRUE(server_options.allow_extended_connect); + auto server_adapter = OgHttp2Adapter::Create(server_visitor, server_options); + + testing::InSequence s; + + // Client sends the connection preface, including the initial SETTINGS. + EXPECT_CALL(client_visitor, OnBeforeFrameSent(SETTINGS, 0, 12, 0x0)); + EXPECT_CALL(client_visitor, OnFrameSent(SETTINGS, 0, 12, 0x0, 0)); + { + int result = client_adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = client_visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + } + + // Server sends the connection preface, including the initial SETTINGS. + EXPECT_CALL(server_visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(server_visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + { + int result = server_adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = server_visitor.data(); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + } + + // Client processes the server's initial bytes, including initial SETTINGS. + EXPECT_CALL(client_visitor, OnFrameHeader(0, 6, SETTINGS, 0x0)); + EXPECT_CALL(client_visitor, OnSettingsStart()); + EXPECT_CALL(client_visitor, + OnSetting(Http2Setting{ + Http2KnownSettingsId::ENABLE_CONNECT_PROTOCOL, 1u})); + EXPECT_CALL(client_visitor, OnSettingsEnd()); + { + const int64_t result = client_adapter->ProcessBytes(server_visitor.data()); + EXPECT_EQ(server_visitor.data().size(), static_cast(result)); + } + + // Server processes the client's initial bytes, including initial SETTINGS. + EXPECT_CALL(server_visitor, OnFrameHeader(0, 12, SETTINGS, 0x0)); + EXPECT_CALL(server_visitor, OnSettingsStart()); + EXPECT_CALL(server_visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::ENABLE_PUSH, 0u})); + EXPECT_CALL( + server_visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::MAX_HEADER_LIST_SIZE, 42u})); + EXPECT_CALL(server_visitor, OnSettingsEnd()); + { + const int64_t result = server_adapter->ProcessBytes(client_visitor.data()); + EXPECT_EQ(client_visitor.data().size(), static_cast(result)); + } +} + +TEST(OgHttp2AdapterTest, AutomaticSettingsAndPingAcks) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + // PING ack + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::PING})); +} + +TEST(OgHttp2AdapterTest, AutomaticPingAcksDisabled) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.auto_ping_ack = false; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + // No PING ack expected because automatic PING acks are disabled. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, InvalidMaxFrameSizeSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface({{MAX_FRAME_SIZE, 3u}}).Serialize(); + testing::InSequence s; + + // Client preface + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidSetting)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, InvalidPushSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface({{ENABLE_PUSH, 3u}}).Serialize(); + testing::InSequence s; + + // Client preface + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidSetting)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, InvalidConnectProtocolSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface({{ENABLE_CONNECT_PROTOCOL, 3u}}) + .Serialize(); + testing::InSequence s; + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidSetting)); + + int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); + + auto adapter2 = OgHttp2Adapter::Create(visitor, options); + const std::string frames2 = TestFrameSequence() + .ClientPreface({{ENABLE_CONNECT_PROTOCOL, 1}}) + .Settings({{ENABLE_CONNECT_PROTOCOL, 0}}) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ENABLE_CONNECT_PROTOCOL, 1u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidSetting)); + + read_result = adapter2->ProcessBytes(frames2); + EXPECT_EQ(static_cast(read_result), frames2.size()); + + EXPECT_TRUE(adapter2->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + adapter2->Send(); +} + +TEST(OgHttp2AdapterTest, ClientHandles100Headers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, + /*fin=*/false) + .Ping(101) + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(101, false)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, ACK_FLAG, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); +} + +TEST(OgHttp2AdapterTest, QueuingWindowUpdateAffectsWindow) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(adapter->GetReceiveWindowSize(), kInitialFlowControlWindowSize); + adapter->SubmitWindowUpdate(0, 10000); + EXPECT_EQ(adapter->GetReceiveWindowSize(), + kInitialFlowControlWindowSize + 10000); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id), + kInitialFlowControlWindowSize); + adapter->SubmitWindowUpdate(1, 20000); + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id), + kInitialFlowControlWindowSize + 20000); +} + +TEST(OgHttp2AdapterTest, AckOfSettingInitialWindowSizeAffectsWindow) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id1 = adapter->SubmitRequest(headers, nullptr, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() // Ack of the client's initial settings. + .Serialize(); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck); + + int64_t parse_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(parse_result)); + + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), + kInitialFlowControlWindowSize); + adapter->SubmitSettings({{INITIAL_WINDOW_SIZE, 80000u}}); + // No update for the first stream, yet. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), + kInitialFlowControlWindowSize); + + // Ack of server's initial settings. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + // Outbound SETTINGS containing INITIAL_WINDOW_SIZE. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + // Still no update, as a SETTINGS ack has not yet been received. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), + kInitialFlowControlWindowSize); + + const std::string settings_ack = + TestFrameSequence().SettingsAck().Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck); + + parse_result = adapter->ProcessBytes(settings_ack); + EXPECT_EQ(settings_ack.size(), static_cast(parse_result)); + + // Stream window has been updated. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id1), 80000); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + const int32_t stream_id2 = adapter->SubmitRequest(headers, nullptr, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(stream_id2), 80000); +} + +TEST(OgHttp2AdapterTest, ClientRejects100HeadersWithFin) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, /*fin=*/false) + .Headers(1, {{":status", "100"}}, /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientRejects100HeadersWithContent) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, + /*fin=*/false) + .Data(1, "We needed the final headers before data, whoops") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientRejects100HeadersWithContentLength) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}, {"content-length", "42"}}, + /*fin=*/false) + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientHandles204WithContent) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "204"}, {"content-length", "2"}}, + /*fin=*/false) + .Data(1, "hi") + .Headers(3, {{":status", "204"}}, /*fin=*/false) + .Data(3, "hi") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "204")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":status", "204")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 2)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientHandles304WithContent) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "304"}, {"content-length", "2"}}, + /*fin=*/false) + .Data(1, "hi") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "304")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "2")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 2)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientHandles304WithContentLength) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "304"}, {"content-length", "2"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "304")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "2")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ClientHandlesTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{"final-status", "A-OK"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, "final-status", "A-OK")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ClientSendsTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const std::string kBody = "This is an example request body."; + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, std::move(body1), nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id1, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::DATA})); + visitor.Clear(); + + const std::vector
trailers1 = + ToHeaders({{"extra-info", "Trailers are weird but good?"}}); + adapter->SubmitTrailer(stream_id1, trailers1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + data = visitor.data(); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::HEADERS})); +} + +TEST(OgHttp2AdapterTest, ClientHandlesMetadata) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ClientHandlesMetadataWithEmptyPayload) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(3); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); +} + +TEST(OgHttp2AdapterTest, ClientHandlesMetadataWithPayloadError) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(stream_id, "Example stream metadata") + .Data(stream_id, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, _, _)).Times(3); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(stream_id, _)); + EXPECT_CALL(visitor, OnMetadataForStream(stream_id, _)) + .WillOnce(testing::Return(false)); + // Remaining frames are not processed due to the error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + // Negative integer returned to indicate an error. + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + EXPECT_FALSE(adapter->want_read()); + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientHandlesMetadataWithCompletionError) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(stream_id, "Example stream metadata") + .Data(stream_id, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, _, _)).Times(3); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(stream_id, _)); + EXPECT_CALL(visitor, OnMetadataForStream(stream_id, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(stream_id)) + .WillOnce(testing::Return(false)); + // Remaining frames are not processed due to the error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + // Negative integer returned to indicate an error. + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + EXPECT_FALSE(adapter->want_read()); + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientRstStreamWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce(testing::DoAll( + testing::InvokeWithoutArgs([&adapter]() { + adapter->SubmitRst(1, Http2ErrorCode::REFUSED_STREAM); + }), + testing::Return(Http2VisitorInterface::HEADER_RST_STREAM))); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientConnectionErrorWhileHandlingHeadersOnly) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientRejectsHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)) + .WillOnce(testing::Return(false)); + // Rejecting headers leads to a connection error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientHandlesSmallerHpackHeaderTableSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"x-i-do-not-like", "green eggs and ham"}, + {"x-i-will-not-eat-them", "here or there, in a box, with a fox"}, + {"x-like-them-in-a-house", "no"}, + {"x-like-them-with-a-mouse", "no"}, + }); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 100); + + const std::string stream_frames = + TestFrameSequence().Settings({{HEADER_TABLE_SIZE, 100u}}).Serialize(); + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 100u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 100); + EXPECT_LE(adapter->GetHpackEncoderDynamicTableSize(), 100); +} + +TEST(OgHttp2AdapterTest, ClientHandlesLargerHpackHeaderTableSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 4096); + + const std::string stream_frames = + TestFrameSequence().Settings({{HEADER_TABLE_SIZE, 40960u}}).Serialize(); + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 40960u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + // The increased capacity will not be applied until a SETTINGS ack is + // serialized. + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 4096); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 40960); +} + +TEST(OgHttp2AdapterTest, ClientSendsHpackHeaderTableSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + }); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers( + 1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}, + {"x-i-do-not-like", "green eggs and ham"}, + {"x-i-will-not-eat-them", "here or there, in a box, with a fox"}, + {"x-like-them-in-a-house", "no"}, + {"x-like-them-with-a-mouse", "no"}}, + /*fin=*/true) + .Serialize(); + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Server acks client's initial SETTINGS. + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 1)); + EXPECT_CALL(visitor, OnSettingsAck()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(7); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + // Submit settings, check decoder table size. + adapter->SubmitSettings({{HEADER_TABLE_SIZE, 100u}}); + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + // Server preface SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + // SETTINGS with the new header table size value + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // Because the client has not yet seen an ack from the server for the SETTINGS + // with header table size, it has not applied the new value. + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::vector
headers2 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + }); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string response_frames = + TestFrameSequence() + .Headers(stream_id2, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id2, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id2)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id2, _, _)).Times(3); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id2)); + EXPECT_CALL(visitor, OnEndStream(stream_id2)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id2, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t response_result = adapter->ProcessBytes(response_frames); + EXPECT_EQ(response_frames.size(), static_cast(response_result)); + + // Still no ack for the outbound settings. + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + const std::string settings_ack = + TestFrameSequence().SettingsAck().Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 1)); + EXPECT_CALL(visitor, OnSettingsAck()); + + const int64_t ack_result = adapter->ProcessBytes(settings_ack); + EXPECT_EQ(settings_ack.size(), static_cast(ack_result)); + // Ack has finally arrived. + EXPECT_EQ(adapter->GetHpackDecoderSizeLimit(), 100); +} + +// TODO(birenroy): Validate headers and re-enable this test. The library should +// invoke OnErrorDebug() with an error message for the invalid header. The +// library should also invoke OnInvalidFrame() for the invalid HEADERS frame. +TEST(OgHttp2AdapterTest, DISABLED_ClientHandlesInvalidTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{":bad-status", "9000"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + + // Bad status trailer will cause a PROTOCOL_ERROR. The header is never + // delivered in an OnHeaderForStream callback. + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientStartsShutdown) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_FALSE(adapter->want_write()); + + // No-op (except for logging) for a client implementation. + adapter->SubmitShutdownNotice(); + EXPECT_FALSE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ClientReceivesGoAway) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Submit a pending WINDOW_UPDATE for a stream that will be closed due to + // GOAWAY. The WINDOW_UPDATE should not be sent. + adapter->SubmitWindowUpdate(3, 42); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .RstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM) + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .WindowUpdate(0, 42) + .WindowUpdate(1, 42) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + // Currently, oghttp2 does not pass the opaque data to the visitor. + EXPECT_CALL(visitor, OnGoAway(1, Http2ErrorCode::INTERNAL_ERROR, "")); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 42)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ClientReceivesMultipleGoAways) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface() + .GoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + // Currently, oghttp2 does not pass the opaque data to the visitor. + EXPECT_CALL(visitor, + OnGoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, "")); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Submit a WINDOW_UPDATE for the open stream. Because the stream is below the + // GOAWAY's last_stream_id, it should be sent. + adapter->SubmitWindowUpdate(1, 42); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 1, 4, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::WINDOW_UPDATE})); + visitor.Clear(); + + const std::string final_frames = + TestFrameSequence() + .GoAway(0, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + // Currently, oghttp2 does not pass the opaque data to the visitor. + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::INTERNAL_ERROR, "")); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::REFUSED_STREAM)); + + const int64_t final_result = adapter->ProcessBytes(final_frames); + EXPECT_EQ(final_frames.size(), static_cast(final_result)); + + EXPECT_FALSE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(OgHttp2AdapterTest, ClientReceivesMultipleGoAwaysWithIncreasingStreamId) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string frames = + TestFrameSequence() + .ServerPreface() + .GoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "") + .GoAway(0, Http2ErrorCode::ENHANCE_YOUR_CALM, "") + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "")); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::ENHANCE_YOUR_CALM, "")); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL( + visitor, + OnInvalidFrame(0, Http2VisitorInterface::InvalidFrameError::kProtocol)); + // The oghttp2 stack also signals the error via OnConnectionError(). + EXPECT_CALL(visitor, + OnConnectionError(ConnectionError::kInvalidGoAwayLastStreamId)); + + const int64_t frames_result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(frames_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientReceivesGoAwayWithPendingStreams) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{MAX_CONCURRENT_STREAMS, 1}}) + .Serialize(); + + // Server preface (SETTINGS with MAX_CONCURRENT_STREAMS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + // The second request should be pending because of + // SETTINGS_MAX_CONCURRENT_STREAMS. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Let the client receive a GOAWAY and raise MAX_CONCURRENT_STREAMS. Even + // though the GOAWAY last_stream_id is higher than the pending request's + // stream ID, pending request should not be sent. + const std::string stream_frames = + TestFrameSequence() + .GoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Settings({{MAX_CONCURRENT_STREAMS, 42u}}) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(kMaxStreamId, Http2ErrorCode::INTERNAL_ERROR, "")); + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{MAX_CONCURRENT_STREAMS, 42u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + // We close the pending stream on the next write attempt. + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::REFUSED_STREAM)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Requests submitted after receiving the GOAWAY should not be sent. + const std::vector
headers3 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}}); + + const int32_t stream_id3 = adapter->SubmitRequest(headers3, nullptr, nullptr); + ASSERT_GT(stream_id3, stream_id2); + + // We close the pending stream on the next write attempt. + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, ClientFailsOnGoAway) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + // TODO(birenroy): Pass the GOAWAY opaque data through the oghttp2 stack. + EXPECT_CALL(visitor, OnGoAway(1, Http2ErrorCode::INTERNAL_ERROR, "")) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientRejects101Response) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"upgrade", "new-protocol"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "101"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_frames.size()), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ClientObeysMaxConcurrentStreams) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_FALSE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + // Even though the user has not queued any frames for the session, it should + // still send the connection preface. + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + // Initial SETTINGS. + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{MAX_CONCURRENT_STREAMS, 1}}) + .Serialize(); + testing::InSequence s; + + // Server preface (SETTINGS with MAX_CONCURRENT_STREAMS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string kBody = "This is an example request body."; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, + OnBeforeFrameSent(HEADERS, stream_id, _, END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, END_STREAM_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + const int next_stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + + // A new pending stream is created, but because of MAX_CONCURRENT_STREAMS, the + // session should not want to write it at the moment. + EXPECT_GT(next_stream_id, stream_id); + EXPECT_FALSE(adapter->want_write()); + + const std::string stream_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, + OnHeaderForStream(stream_id, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, "date", + "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 26, DATA, END_STREAM_FLAG)); + EXPECT_CALL(visitor, OnBeginDataForStream(stream_id, 26)); + EXPECT_CALL(visitor, + OnDataForStream(stream_id, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(stream_id)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + // The first stream should close, which should make the session want to write + // the next stream. + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, next_stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, next_stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, ClientReceivesInitialWindowSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string initial_frames = + TestFrameSequence() + .Settings({{INITIAL_WINDOW_SIZE, 80000u}}) + .WindowUpdate(0, 65536) + .Serialize(); + // Server preface (SETTINGS with INITIAL_STREAM_WINDOW) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{INITIAL_WINDOW_SIZE, 80000u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 65536)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + int64_t result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string kLongBody = std::string(81000, 'c'); + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kLongBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + // The client can send more than 4 frames (65536 bytes) of data. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 16384, 0x0, 0)).Times(4); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 14464, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA, + SpdyFrameType::DATA, SpdyFrameType::DATA, + SpdyFrameType::DATA, SpdyFrameType::DATA})); +} + +TEST(OgHttp2AdapterTest, ClientReceivesInitialWindowSettingAfterStreamStart) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().WindowUpdate(0, 65536).Serialize(); + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 65536)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + int64_t result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string kLongBody = std::string(81000, 'c'); + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kLongBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + // The client can only send 65535 bytes of data, as the stream window has not + // yet been increased. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 16384, 0x0, 0)).Times(3); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 16383, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA, + SpdyFrameType::DATA, SpdyFrameType::DATA, + SpdyFrameType::DATA})); + visitor.Clear(); + + // Can't write any more due to flow control. + EXPECT_FALSE(adapter->want_write()); + + const std::string settings_frame = + TestFrameSequence().Settings({{INITIAL_WINDOW_SIZE, 80000u}}).Serialize(); + // SETTINGS with INITIAL_STREAM_WINDOW + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{INITIAL_WINDOW_SIZE, 80000u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t settings_result = adapter->ProcessBytes(settings_frame); + EXPECT_EQ(settings_frame.size(), static_cast(settings_result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + // The client can write more after receiving the INITIAL_WINDOW_SIZE setting. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 14465, 0x0, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::DATA})); +} + +TEST(OgHttp2AdapterTest, InvalidInitialWindowSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const uint32_t kTooLargeInitialWindow = 1u << 31; + const std::string initial_frames = + TestFrameSequence() + .Settings({{INITIAL_WINDOW_SIZE, kTooLargeInitialWindow}}) + .Serialize(); + // Server preface (SETTINGS with INITIAL_STREAM_WINDOW) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, + OnInvalidFrame( + 0, Http2VisitorInterface::InvalidFrameError::kFlowControl)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kFlowControlError)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a GOAWAY. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + int64_t result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); + visitor.Clear(); +} + +TEST(OggHttp2AdapterClientTest, InitialWindowSettingCausesOverflow) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + int64_t write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const uint32_t kLargeInitialWindow = (1u << 31) - 1; + const std::string frames = + TestFrameSequence() + .ServerPreface() + .Headers(stream_id, {{":status", "200"}}, /*fin=*/false) + .WindowUpdate(stream_id, 65536u) + .Settings({{INITIAL_WINDOW_SIZE, kLargeInitialWindow}}) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 4, WINDOW_UPDATE, 0x0)); + EXPECT_CALL(visitor, OnWindowUpdate(stream_id, 65536)); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{INITIAL_WINDOW_SIZE, + kLargeInitialWindow})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + // The stream window update plus the SETTINGS frame with INITIAL_WINDOW_SIZE + // pushes the stream's flow control window outside of the acceptable range. + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, stream_id, 4, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, FailureSendingConnectionPreface) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + visitor.set_has_write_error(); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int result = adapter->Send(); + EXPECT_LT(result, 0); +} + +TEST(OgHttp2AdapterTest, MaxFrameSizeSettingNotAppliedBeforeAck) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const uint32_t large_frame_size = kDefaultFramePayloadSizeLimit + 42; + adapter->SubmitSettings({{MAX_FRAME_SIZE, large_frame_size}}); + const int32_t stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + /*data_source=*/nullptr, /*user_data=*/nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence s; + + // Client preface (SETTINGS with MAX_FRAME_SIZE) and request HEADERS + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string server_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "200"}}, /*fin=*/false) + .Data(1, std::string(large_frame_size, 'a')) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Response HEADERS. Because the SETTINGS with MAX_FRAME_SIZE was not + // acknowledged, the large DATA is treated as a connection error. Note that + // oghttp2 delivers the DATA frame header and connection error events. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, large_frame_size, DATA, 0x0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t process_result = adapter->ProcessBytes(server_frames); + EXPECT_EQ(server_frames.size(), static_cast(process_result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FRAME_SIZE_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, MaxFrameSizeSettingAppliedAfterAck) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const uint32_t large_frame_size = kDefaultFramePayloadSizeLimit + 42; + adapter->SubmitSettings({{MAX_FRAME_SIZE, large_frame_size}}); + const int32_t stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + /*data_source=*/nullptr, /*user_data=*/nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence s; + + // Client preface (SETTINGS with MAX_FRAME_SIZE) and request HEADERS + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string server_frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(1, {{":status", "200"}}, /*fin=*/false) + .Data(1, std::string(large_frame_size, 'a')) + .Serialize(); + + // Server preface (empty SETTINGS) and ack of SETTINGS. + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + + // Response HEADERS and DATA. Because the SETTINGS with MAX_FRAME_SIZE was + // acknowledged, the large DATA is accepted without any error. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, large_frame_size, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, large_frame_size)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + + const int64_t process_result = adapter->ProcessBytes(server_frames); + EXPECT_EQ(server_frames.size(), static_cast(process_result)); + + // Client ack of SETTINGS. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ClientForbidsPushPromise) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::vector
push_headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/push"}}); + const std::string frames = TestFrameSequence() + .ServerPreface() + .SettingsAck() + .PushPromise(stream_id, 2, push_headers) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0, though this is not explicitly + // required for OgHttp2: should it be?) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The PUSH_PROMISE is treated as an invalid frame. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, PUSH_PROMISE, _)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidPushPromise)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientForbidsPushStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(2, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0, though this is not explicitly + // required for OgHttp2: should it be?) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The push HEADERS are invalid. + EXPECT_CALL(visitor, OnFrameHeader(2, _, HEADERS, _)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidNewStreamId)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientReceivesDataOnClosedStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Client SETTINGS ack + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client open a stream with a request. + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Let the client RST_STREAM the stream it opened. + adapter->SubmitRst(stream_id, Http2ErrorCode::CANCEL); + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id, _, 0x0, + static_cast(Http2ErrorCode::CANCEL))); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::RST_STREAM})); + visitor.Clear(); + + // Let the server send a response on the stream. (It might not have received + // the RST_STREAM yet.) + const std::string response_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + // The visitor gets notified about the HEADERS frame and DATA frame for the + // closed stream with no further processing on either frame. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, DATA, END_STREAM_FLAG)); + + const int64_t response_result = adapter->ProcessBytes(response_frames); + EXPECT_EQ(response_frames.size(), static_cast(response_result)); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, ClientEncountersFlowControlBlock) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const std::string kBody = std::string(100 * 1024, 'a'); + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, std::move(body1), nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + auto body2 = std::make_unique(visitor, false); + body2->AppendPayload(kBody); + body2->EndData(); + + const int32_t stream_id2 = + adapter->SubmitRequest(headers2, std::move(body2), nullptr); + ASSERT_GT(stream_id2, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x4, 0)); + // 4 DATA frames should saturate the default 64kB stream/connection flow + // control window. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id1, _, 0x0, 0)).Times(4); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_EQ(0, adapter->GetSendWindowSize()); + + const std::string stream_frames = TestFrameSequence() + .ServerPreface() + .WindowUpdate(0, 80000) + .WindowUpdate(stream_id1, 20000) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 80000)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 20000)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id2, _, 0x0, 0)) + .Times(testing::AtLeast(1)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id1, _, 0x0, 0)) + .Times(testing::AtLeast(1)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); +} + +TEST(OgHttp2AdapterTest, ClientSendsTrailersAfterFlowControlBlock) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload("Really small body."); + body1->EndData(); + + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, std::move(body1), nullptr); + ASSERT_GT(stream_id1, 0); + + const std::vector
headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const std::string kBody = std::string(100 * 1024, 'a'); + auto body2 = std::make_unique(visitor, false); + body2->AppendPayload(kBody); + body2->EndData(); + + const int32_t stream_id2 = + adapter->SubmitRequest(headers2, std::move(body2), nullptr); + ASSERT_GT(stream_id2, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id1, _, 0x0, 0)).Times(1); + // 4 DATA frames should saturate the default 64kB stream/connection flow + // control window. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id2, _, 0x0, 0)).Times(4); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_FALSE(adapter->want_write()); + EXPECT_EQ(0, adapter->GetSendWindowSize()); + + const std::vector
trailers1 = + ToHeaders({{"extra-info", "Trailers are weird but good?"}}); + adapter->SubmitTrailer(stream_id1, trailers1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); +} + +TEST(OgHttp2AdapterTest, ClientSendsMetadataAfterFlowControlBlock) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const std::string kBody = std::string(100 * 1024, 'a'); + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, std::move(body1), nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x4, 0)); + // 4 DATA frames should saturate the default 64kB stream/connection flow + // control window. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id1, _, 0x0, 0)).Times(4); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_FALSE(adapter->want_write()); + EXPECT_EQ(0, adapter->GetSendWindowSize()); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); +} + +TEST(OgHttp2AdapterTest, ClientQueuesRequests) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + adapter->Send(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{MAX_CONCURRENT_STREAMS, 2}}) + .SettingsAck() + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ + Http2KnownSettingsId::MAX_CONCURRENT_STREAMS, 2u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + + adapter->ProcessBytes(initial_frames); + + const std::vector
headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/example/request"}}); + std::vector stream_ids; + // Start two, which hits the limit. + int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + // Start two more, which must be queued. + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[0], _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[0], _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[1], _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[1], _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + adapter->Send(); + + const std::string update_streams = + TestFrameSequence().Settings({{MAX_CONCURRENT_STREAMS, 5}}).Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ + Http2KnownSettingsId::MAX_CONCURRENT_STREAMS, 5u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + adapter->ProcessBytes(update_streams); + stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + stream_ids.push_back(stream_id); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[2], _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[2], _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[3], _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[3], _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_ids[4], _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_ids[4], _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + // Header frames should all have been sent in order, regardless of any + // queuing. + + adapter->Send(); +} + +TEST(OgHttp2AdapterTest, ClientAcceptsHeadResponseWithContentLength) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kClient; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::vector
headers = ToHeaders({{":method", "HEAD"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + adapter->Send(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(stream_id, {{":status", "200"}, {"content-length", "101"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, 0x0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(2); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnEndStream(stream_id)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + adapter->ProcessBytes(initial_frames); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + adapter->Send(); +} + +TEST(OgHttp2AdapterTest, SubmitMetadata) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, SubmitMetadataMultipleFrames) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const auto kLargeValue = std::string(63 * 1024, 'a'); + auto source = std::make_unique( + ToHeaderBlock(ToHeaders({{"large-value", kLargeValue}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence seq; + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, SubmitConnectionMetadata) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(0, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 0, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 0, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, GetSendWindowSize) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const int peer_window = adapter->GetSendWindowSize(); + EXPECT_EQ(peer_window, kInitialFlowControlWindowSize); +} + +TEST(OgHttp2AdapterTest, WindowUpdateZeroDelta) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string data_chunk(kDefaultFramePayloadSizeLimit, 'a'); + const std::string request = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/false) + .WindowUpdate(1, 0) + .Data(1, "Subsequent frames on stream 1 are not delivered.") + .Serialize(); + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + + adapter->ProcessBytes(request); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, _)); + + adapter->Send(); + + const std::string window_update = + TestFrameSequence().WindowUpdate(0, 0).Serialize(); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kFlowControlError)); + adapter->ProcessBytes(window_update); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + adapter->Send(); +} + +TEST(OgHttp2AdapterTest, WindowUpdateCausesWindowOverflow) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string data_chunk(kDefaultFramePayloadSizeLimit, 'a'); + const std::string request = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/false) + .WindowUpdate(1, std::numeric_limits::max()) + .Data(1, "Subsequent frames on stream 1 are not delivered.") + .Serialize(); + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + + adapter->ProcessBytes(request); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, _)); + + adapter->Send(); + + const std::string window_update = + TestFrameSequence() + .WindowUpdate(0, std::numeric_limits::max()) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kFlowControlError)); + adapter->ProcessBytes(window_update); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + adapter->Send(); +} + +TEST(OgHttp2AdapterTest, WindowUpdateRaisesFlowControlWindowLimit) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string data_chunk(kDefaultFramePayloadSizeLimit, 'a'); + const std::string request = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/false) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + adapter->ProcessBytes(request); + + // Updates the advertised window for the connection and stream 1. + adapter->SubmitWindowUpdate(0, 2 * kDefaultFramePayloadSizeLimit); + adapter->SubmitWindowUpdate(1, 2 * kDefaultFramePayloadSizeLimit); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 1, 4, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + + // Verifies the advertised window. + EXPECT_EQ(kInitialFlowControlWindowSize + 2 * kDefaultFramePayloadSizeLimit, + adapter->GetReceiveWindowSize()); + EXPECT_EQ(kInitialFlowControlWindowSize + 2 * kDefaultFramePayloadSizeLimit, + adapter->GetStreamReceiveWindowSize(1)); + + const std::string request_body = TestFrameSequence() + .Data(1, data_chunk) + .Data(1, data_chunk) + .Data(1, data_chunk) + .Data(1, data_chunk) + .Data(1, data_chunk) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)).Times(5); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)).Times(5); + EXPECT_CALL(visitor, OnDataForStream(1, _)).Times(5); + + // DATA frames on stream 1 consume most of the window. + adapter->ProcessBytes(request_body); + EXPECT_EQ(kInitialFlowControlWindowSize - 3 * kDefaultFramePayloadSizeLimit, + adapter->GetReceiveWindowSize()); + EXPECT_EQ(kInitialFlowControlWindowSize - 3 * kDefaultFramePayloadSizeLimit, + adapter->GetStreamReceiveWindowSize(1)); + + // Marking the data consumed should result in an advertised window larger than + // the initial window. + adapter->MarkDataConsumedForStream(1, 4 * kDefaultFramePayloadSizeLimit); + EXPECT_GT(adapter->GetReceiveWindowSize(), kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamReceiveWindowSize(1), + kInitialFlowControlWindowSize); +} + +TEST(OgHttp2AdapterTest, MarkDataConsumedForNonexistentStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + // Send some data on stream 1 so the connection window manager doesn't + // underflow later. + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "Some data on stream 1") + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + + adapter->ProcessBytes(frames); + + // This should not cause a crash or QUICHE_BUG. + adapter->MarkDataConsumedForStream(3, 11); +} + +TEST(OgHttp2AdapterTest, TestSerialize) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_TRUE(adapter->want_read()); + EXPECT_FALSE(adapter->want_write()); + + adapter->SubmitSettings( + {{HEADER_TABLE_SIZE, 128}, {MAX_FRAME_SIZE, 128 << 10}}); + EXPECT_TRUE(adapter->want_write()); + + const Http2StreamId accepted_stream = 3; + const Http2StreamId rejected_stream = 7; + adapter->SubmitPriorityForStream(accepted_stream, 1, 255, true); + adapter->SubmitRst(rejected_stream, Http2ErrorCode::CANCEL); + adapter->SubmitPing(42); + adapter->SubmitGoAway(13, Http2ErrorCode::HTTP2_NO_ERROR, ""); + adapter->SubmitWindowUpdate(accepted_stream, 127); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PRIORITY, accepted_stream, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PRIORITY, accepted_stream, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, rejected_stream, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, rejected_stream, _, 0x0, 0x8)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, + OnBeforeFrameSent(WINDOW_UPDATE, accepted_stream, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, accepted_stream, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PRIORITY, + SpdyFrameType::RST_STREAM, SpdyFrameType::PING, + SpdyFrameType::GOAWAY, SpdyFrameType::WINDOW_UPDATE})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, TestPartialSerialize) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_FALSE(adapter->want_write()); + + adapter->SubmitSettings( + {{HEADER_TABLE_SIZE, 128}, {MAX_FRAME_SIZE, 128 << 10}}); + adapter->SubmitGoAway(13, Http2ErrorCode::HTTP2_NO_ERROR, + "And don't come back!"); + adapter->SubmitPing(42); + EXPECT_TRUE(adapter->want_write()); + + visitor.set_send_limit(20); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x0, 0)); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_FALSE(adapter->want_write()); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY, + SpdyFrameType::PING})); +} + +TEST(OgHttp2AdapterTest, TestStreamInitialWindowSizeUpdates) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + adapter->SubmitSettings({{INITIAL_WINDOW_SIZE, 80000}}); + EXPECT_TRUE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + // New stream window size has not yet been applied. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), 65535); + + // Server initial SETTINGS + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + int result = adapter->Send(); + EXPECT_EQ(0, result); + + // New stream window size has still not been applied. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), 65535); + + const std::string ack = TestFrameSequence().SettingsAck().Serialize(); + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + adapter->ProcessBytes(ack); + + // New stream window size has finally been applied upon SETTINGS ack. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), 80000); + + // Update the stream window size again. + adapter->SubmitSettings({{INITIAL_WINDOW_SIZE, 90000}}); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + result = adapter->Send(); + EXPECT_EQ(0, result); + + // New stream window size has not yet been applied. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), 80000); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + adapter->ProcessBytes(ack); + + // New stream window size is applied after the ack. + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), 90000); +} + +TEST(OgHttp2AdapterTest, ConnectionErrorOnControlFrameSent) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)) + .WillOnce(testing::Return(-902)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int send_result = adapter->Send(); + EXPECT_LT(send_result, 0); + + EXPECT_FALSE(adapter->want_write()); + + send_result = adapter->Send(); + EXPECT_LT(send_result, 0); +} + +TEST(OgHttp2AdapterTest, ConnectionErrorOnDataFrameSent) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + auto body = std::make_unique(visitor, true); + body->AppendPayload("Here is some data, which will lead to a fatal error"); + TestDataFrameSource* body_ptr = body.get(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + // Stream 1, with doomed DATA + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)) + .WillOnce(testing::Return(-902)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int send_result = adapter->Send(); + EXPECT_LT(send_result, 0); + + body_ptr->AppendPayload("After the fatal error, data will be sent no more"); + + EXPECT_FALSE(adapter->want_write()); + + send_result = adapter->Send(); + EXPECT_LT(send_result, 0); +} + +TEST(OgHttp2AdapterTest, ClientSendsContinuation) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true, + /*add_continuation=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 1)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(OgHttp2AdapterTest, ClientSendsMetadataWithContinuation) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Metadata(0, "Example connection metadata in multiple frames", true) + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false, + /*add_continuation=*/true) + .Metadata(1, + "Some stream metadata that's also sent in multiple frames", + true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Metadata on stream 0 + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + // Metadata on stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + EXPECT_EQ("Example connection metadata in multiple frames", + absl::StrJoin(visitor.GetMetadata(0), "")); + EXPECT_EQ("Some stream metadata that's also sent in multiple frames", + absl::StrJoin(visitor.GetMetadata(1), "")); +} + +TEST(OgHttp2AdapterTest, RepeatedHeaderNames) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "text/plain"}, + {"accept", "text/html"}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "text/plain")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "text/html")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const std::vector
headers1 = ToHeaders( + {{":status", "200"}, {"content-length", "10"}, {"content-length", "10"}}); + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload("perfection"); + body1->EndData(); + + int submit_result = adapter->SubmitResponse(1, headers1, std::move(body1)); + ASSERT_EQ(0, submit_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, 10, END_STREAM, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS, SpdyFrameType::DATA})); +} + +TEST(OgHttp2AdapterTest, ServerRespondsToRequestWithTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}) + .Data(1, "Example data, woohoo.") + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const std::vector
headers1 = ToHeaders({{":status", "200"}}); + auto body1 = std::make_unique(visitor, true); + TestDataFrameSource* body1_ptr = body1.get(); + + int submit_result = adapter->SubmitResponse(1, headers1, std::move(body1)); + ASSERT_EQ(0, submit_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string more_frames = + TestFrameSequence() + .Headers(1, {{"extra-info", "Trailers are weird but good?"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, "extra-info", + "Trailers are weird but good?")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + result = adapter->ProcessBytes(more_frames); + EXPECT_EQ(more_frames.size(), static_cast(result)); + + body1_ptr->EndData(); + EXPECT_EQ(true, adapter->ResumeStream(1)); + + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, 0, END_STREAM, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); +} + +TEST(OgHttp2AdapterTest, ServerReceivesMoreHeaderBytesThanConfigured) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.max_header_list_bytes = 42; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"from-douglas-de-fermat", + "I have discovered a truly marvelous answer to the life, " + "the universe, and everything that the header setting is " + "too narrow to contain."}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::COMPRESSION_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerSubmitsResponseWithDataSourceError) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + auto body1 = std::make_unique(visitor, false); + body1->SimulateError(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + // TODO(birenroy): Send RST_STREAM INTERNAL_ERROR to the client as well. + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // Since the stream has been closed, it is not possible to submit trailers for + // the stream. + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_LT(trailer_result, 0); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, CompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the response body.", /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, IncompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + // RST_STREAM NO_ERROR option is disabled. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, IncompleteRequestWithServerResponseRstStreamEnabled) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.rst_stream_no_error_when_incomplete = true; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS, SpdyFrameType::RST_STREAM})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, ServerHandlesMultipleContentLength) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/1"}, + {"content-length", "7"}, + {"content-length", "7"}}, + /*fin=*/false) + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/3"}, + {"content-length", "11"}, + {"content-length", "13"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/1")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "7")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + // Stream 3 + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/3")); + EXPECT_CALL(visitor, OnHeaderForStream(3, "content-length", "11")); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(OgHttp2AdapterTest, ServerSendsInvalidTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); +} + +TEST(OgHttp2AdapterTest, ServerQueuesMetadataThenTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + spdy::Http2HeaderBlock block; + block["key"] = "wild value!"; + adapter->SubmitMetadata( + 1, 16384u, std::make_unique(std::move(block))); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({static_cast(kMetadataFrameType), + SpdyFrameType::HEADERS})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesDataWithPadding) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.", + /*fin=*/true, /*padding_length=*/39) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25 + 39, DATA, 0x9)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25 + 39)); + // Note: oghttp2 passes padding information before the actual data. + EXPECT_CALL(visitor, OnDataPaddingLength(1, 39)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(frames.size()), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesHostHeader) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":path", "/this/is/request/one"}, + {"host", "example.com"}}, + /*fin=*/true) + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"host", "example.com"}}, + /*fin=*/true) + .Headers(5, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "foo.com"}, + {":path", "/this/is/request/one"}, + {"host", "bar.com"}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(5, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); +} + +// Tests the case where the response body is in the progress of being sent while +// trailers are queued. +TEST(OgHttp2AdapterTest, ServerSubmitsTrailersWhileDataDeferred) { + DataSavingVisitor visitor; + for (const bool queue_trailers : {true, false}) { + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.trailers_require_end_data = queue_trailers; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + auto* body1_ptr = body1.get(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{"final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + if (queue_trailers) { + // Even though there are new trailers to write, the data source has not + // finished writing data and is blocked. + EXPECT_FALSE(adapter->want_write()); + + body1_ptr->EndData(); + adapter->ResumeStream(1); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL( + visitor, + OnBeforeFrameSent(HEADERS, 1, _, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + } else { + // Even though the data source has not finished sending data, the library + // will write the trailers anyway. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL( + visitor, + OnBeforeFrameSent(HEADERS, 1, _, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); + } + } +} + +TEST(OgHttp2AdapterTest, ClientDisobeysConnectionFlowControl) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + // 70000 bytes of data + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(4464, 'a')) + .Serialize(); + + testing::InSequence s; + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kFlowControlError)); + // No further frame data or headers are delivered. + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientDisobeysConnectionFlowControlWithOneDataFrame) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + // Allow the client to send a DATA frame that exceeds the connection flow + // control window. + const uint32_t window_overflow_bytes = kInitialFlowControlWindowSize + 1; + adapter->SubmitSettings({{MAX_FRAME_SIZE, window_overflow_bytes}}); + + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + int64_t process_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(process_result)); + + EXPECT_TRUE(adapter->want_write()); + + // Outbound SETTINGS containing MAX_FRAME_SIZE. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // Ack of client's initial settings. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Now let the client ack the MAX_FRAME_SIZE SETTINGS and send a DATA frame to + // overflow the connection-level window. The result should be a GOAWAY. + const std::string overflow_frames = + TestFrameSequence() + .SettingsAck() + .Data(1, std::string(window_overflow_bytes, 'a')) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, window_overflow_bytes, DATA, 0x0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kFlowControlError)); + // No further frame data is delivered. + + process_result = adapter->ProcessBytes(overflow_frames); + EXPECT_EQ(overflow_frames.size(), static_cast(process_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientDisobeysConnectionFlowControlAcrossReads) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + // Allow the client to send a DATA frame that exceeds the connection flow + // control window. + const uint32_t window_overflow_bytes = kInitialFlowControlWindowSize + 1; + adapter->SubmitSettings({{MAX_FRAME_SIZE, window_overflow_bytes}}); + + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + int64_t process_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(process_result)); + + EXPECT_TRUE(adapter->want_write()); + + // Outbound SETTINGS containing MAX_FRAME_SIZE. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // Ack of client's initial settings. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Now let the client ack the MAX_FRAME_SIZE SETTINGS and send a DATA frame to + // overflow the connection-level window. The result should be a GOAWAY. + const std::string overflow_frames = + TestFrameSequence() + .SettingsAck() + .Data(1, std::string(window_overflow_bytes, 'a')) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, window_overflow_bytes, DATA, 0x0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kFlowControlError)); + + const size_t chunk_length = 16384; + ASSERT_GE(overflow_frames.size(), chunk_length); + process_result = + adapter->ProcessBytes(overflow_frames.substr(0, chunk_length)); + EXPECT_EQ(chunk_length, static_cast(process_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ClientDisobeysStreamFlowControl) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .Serialize(); + const std::string more_frames = TestFrameSequence() + // 70000 bytes of data + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(16384, 'a')) + .Data(1, std::string(4464, 'a')) + .Serialize(); + + testing::InSequence s; + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + adapter->SubmitWindowUpdate(0, 20000); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::WINDOW_UPDATE})); + visitor.Clear(); + + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 16384)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, 16384, DATA, 0x0)); + // No further frame data or headers are delivered. + + result = adapter->ProcessBytes(more_frames); + EXPECT_EQ(more_frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::FLOW_CONTROL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "some bogus value!")) + .WillOnce(testing::Return(Http2VisitorInterface::HEADER_RST_STREAM)); + // Stream WINDOW_UPDATE and DATA frames are not delivered to the visitor. + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerErrorWhileHandlingHeadersDropsFrames) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Metadata(1, "This is the request metadata.") + .RstStream(1, Http2ErrorCode::CANCEL) + .WindowUpdate(0, 2000) + .Headers(3, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/false) + .Metadata(3, "This is the request metadata.", + /*multiple_frames=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "some bogus value!")) + .WillOnce(testing::Return(Http2VisitorInterface::HEADER_RST_STREAM)); + // Frames for the RST_STREAM-marked stream are not delivered to the visitor. + // Note: nghttp2 still delivers control frames and metadata for the stream. + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(3, _)); + EXPECT_CALL(visitor, OnMetadataForStream(3, "This is the re")) + .WillOnce(testing::DoAll(testing::InvokeWithoutArgs([&adapter]() { + adapter->SubmitRst( + 3, Http2ErrorCode::REFUSED_STREAM); + }), + testing::Return(true))); + // The rest of the metadata is not delivered to the visitor. + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, 4, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"Accept", "uppercase, oh boy!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::RST_STREAM, + SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerErrorAfterHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +// Exercises the case when a visitor chooses to reject a frame based solely on +// the frame header, which is a fatal error for the connection. +TEST(OgHttp2AdapterTest, ServerRejectsFrameHeader) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(64) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerRejectsBeginningOfData) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerReceivesTooLargeHeader) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.max_header_list_bytes = 64 * 1024; + options.max_header_field_size = 64 * 1024; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + // Due to configuration, the library will accept a maximum of 64kB of huffman + // encoded data per header field. + const std::string too_large_value = std::string(80 * 1024, 'q'); + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"x-toobig", too_large_value}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 0)).Times(3); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, END_HEADERS_FLAG)); + // Further header processing is skipped, as the header field is too large. + + EXPECT_CALL(visitor, + OnFrameHeader(3, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(frames.size()), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerReceivesInvalidAuthority) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "ex|ample.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(frames.size()), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0x0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0x0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttpAdapterTest, ServerReceivesGoAway) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .GoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "") + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0x0)); + EXPECT_CALL(visitor, OnGoAway(0, Http2ErrorCode::HTTP2_NO_ERROR, "")); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(frames.size()), result); + + // The server should still be able to send a response after receiving a GOAWAY + // with a lower last-stream-ID field, as the stream was client-initiated. + const int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), + /*data_source=*/nullptr); + ASSERT_EQ(0, submit_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0x0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0x0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::HEADERS})); +} + +TEST(OgHttp2AdapterTest, ServerSubmitResponse) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&adapter, kSentinel1]() { + adapter->SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + // Server will want to send a SETTINGS and a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHpackEncoderDynamicTableSize()); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + // A data fin is not sent so that the stream remains open, and the flow + // control state can be verified. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload(kBody); + int submit_result = adapter->SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + // Stream user data should have been set successfully after receiving headers. + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(1)); + adapter->SetStreamUserData(1, nullptr); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(1)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + EXPECT_FALSE(adapter->want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(adapter->GetStreamSendWindowSize(1), kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamSendWindowSize(1), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(adapter->GetStreamSendWindowSize(3), -1); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 0); +} + +TEST(OgHttp2AdapterTest, ServerSubmitResponseWithResetFromClient) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + // Server will want to send a SETTINGS and a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(kBody); + int submit_result = adapter->SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + // Client resets the stream before the server can send the response. + const std::string reset = + TestFrameSequence().RstStream(1, Http2ErrorCode::CANCEL).Serialize(); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(1, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::CANCEL)); + const int64_t reset_result = adapter->ProcessBytes(reset); + EXPECT_EQ(reset.size(), static_cast(reset_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, _)).Times(0); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, _, _)).Times(0); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, _, _)).Times(0); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(OgHttp2AdapterTest, ServerRejectsStreamData) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, _)).WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +// Exercises a naive mutually recursive test client and server. This test fails +// without recursion guards in OgHttp2Session. +TEST(OgHttp2AdapterInteractionTest, ClientServerInteractionTest) { + testing::NiceMock client_visitor; + OgHttp2Adapter::Options client_options; + client_options.perspective = Perspective::kClient; + auto client_adapter = OgHttp2Adapter::Create(client_visitor, client_options); + testing::NiceMock server_visitor; + OgHttp2Adapter::Options server_options; + server_options.perspective = Perspective::kServer; + auto server_adapter = OgHttp2Adapter::Create(server_visitor, server_options); + + // Feeds bytes sent from the client into the server's ProcessBytes. + EXPECT_CALL(client_visitor, OnReadyToSend(_)) + .WillRepeatedly( + testing::Invoke(server_adapter.get(), &OgHttp2Adapter::ProcessBytes)); + // Feeds bytes sent from the server into the client's ProcessBytes. + EXPECT_CALL(server_visitor, OnReadyToSend(_)) + .WillRepeatedly( + testing::Invoke(client_adapter.get(), &OgHttp2Adapter::ProcessBytes)); + // Sets up the server to respond automatically to a request from a client. + EXPECT_CALL(server_visitor, OnEndHeadersForStream(_)) + .WillRepeatedly([&server_adapter](Http2StreamId stream_id) { + server_adapter->SubmitResponse( + stream_id, ToHeaders({{":status", "200"}}), nullptr); + server_adapter->Send(); + return true; + }); + // Sets up the client to create a new stream automatically when receiving a + // response. + EXPECT_CALL(client_visitor, OnEndHeadersForStream(_)) + .WillRepeatedly([&client_adapter, + &client_visitor](Http2StreamId stream_id) { + if (stream_id < 10) { + const Http2StreamId new_stream_id = stream_id + 2; + auto body = + std::make_unique(client_visitor, true); + body->AppendPayload("This is an example request body."); + body->EndData(); + const int created_stream_id = client_adapter->SubmitRequest( + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", + absl::StrCat("/this/is/request/", new_stream_id)}}), + std::move(body), nullptr); + EXPECT_EQ(new_stream_id, created_stream_id); + client_adapter->Send(); + } + return true; + }); + + // Submit a request to ensure the first stream is created. + int stream_id = client_adapter->SubmitRequest( + ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_EQ(stream_id, 1); + + client_adapter->Send(); +} + +TEST(OgHttp2AdapterInteractionTest, + ClientServerInteractionRepeatedHeaderNames) { + DataSavingVisitor client_visitor; + OgHttp2Adapter::Options client_options; + client_options.perspective = Perspective::kClient; + auto client_adapter = OgHttp2Adapter::Create(client_visitor, client_options); + + const std::vector
headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "text/plain"}, + {"accept", "text/html"}}); + + const int32_t stream_id1 = + client_adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(client_visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(client_visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(client_visitor, + OnBeforeFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(client_visitor, + OnFrameSent(HEADERS, stream_id1, _, + END_STREAM_FLAG | END_HEADERS_FLAG, 0)); + int send_result = client_adapter->Send(); + EXPECT_EQ(0, send_result); + + DataSavingVisitor server_visitor; + OgHttp2Adapter::Options server_options; + server_options.perspective = Perspective::kServer; + auto server_adapter = OgHttp2Adapter::Create(server_visitor, server_options); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(server_visitor, OnFrameHeader(0, _, SETTINGS, 0)); + EXPECT_CALL(server_visitor, OnSettingsStart()); + EXPECT_CALL(server_visitor, OnSetting).Times(testing::AnyNumber()); + EXPECT_CALL(server_visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(server_visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(server_visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, ":scheme", "http")); + EXPECT_CALL(server_visitor, + OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(server_visitor, + OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, "accept", "text/plain")); + EXPECT_CALL(server_visitor, OnHeaderForStream(1, "accept", "text/html")); + EXPECT_CALL(server_visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(server_visitor, OnEndStream(1)); + + int64_t result = server_adapter->ProcessBytes(client_visitor.data()); + EXPECT_EQ(client_visitor.data().size(), static_cast(result)); +} + +TEST(OgHttp2AdapterTest, ServerForbidsNewStreamBelowWatermark) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(3, "This is the request body.") + .Headers(1, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 25)); + EXPECT_CALL(visitor, OnDataForStream(3, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidNewStreamId)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + EXPECT_EQ(3, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerForbidsWindowUpdateOnIdleStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerForbidsDataOnIdleStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Data(1, "Sorry, out of order") + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerForbidsRstStreamOnIdleStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .RstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerForbidsNewStreamAboveStreamLimit) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client send a SETTINGS ack and then attempt to open more than the + // advertised number of streams. The overflow stream should be rejected. + const std::string stream_frames = + TestFrameSequence() + .SettingsAck() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, + OnFrameHeader(3, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kProtocol)); + // The oghttp2 stack also signals the error via OnConnectionError(). + EXPECT_CALL(visitor, OnConnectionError( + ConnectionError::kExceededMaxConcurrentStreams)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_result), stream_frames.size()); + + // The server should send a GOAWAY for this error, even though + // OnInvalidFrame() returns true. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerRstStreamsNewStreamAboveStreamLimitBeforeAck) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client avoid sending a SETTINGS ack and attempt to open more than + // the advertised number of streams. The server should still reject the + // overflow stream, albeit with RST_STREAM REFUSED_STREAM instead of GOAWAY. + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, + OnFrameHeader(3, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 3, Http2VisitorInterface::InvalidFrameError::kRefusedStream)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_result), stream_frames.size()); + + // The server sends a RST_STREAM for the offending stream. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerForbidsProtocolPseudoheaderBeforeAck) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.allow_extended_connect = false; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + // The client attempts to send a CONNECT request with the `:protocol` + // pseudoheader before receiving the server's SETTINGS frame. + const std::string stream1_frames = + TestFrameSequence() + .Headers(1, + {{":method", "CONNECT"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {":protocol", "websocket"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + int64_t stream_result = adapter->ProcessBytes(stream1_frames); + EXPECT_EQ(static_cast(stream_result), stream1_frames.size()); + + // Server initial SETTINGS and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + // The server sends a RST_STREAM for the offending stream. + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + // Server settings with ENABLE_CONNECT_PROTOCOL. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}}); + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // The client attempts to send a CONNECT request with the `:protocol` + // pseudoheader before acking the server's SETTINGS frame. + const std::string stream3_frames = + TestFrameSequence() + .Headers(3, + {{":method", "CONNECT"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + {":protocol", "websocket"}}, + /*fin=*/true) + .Serialize(); + + // After sending SETTINGS with `ENABLE_CONNECT_PROTOCOL`, oghttp2 matches + // nghttp2 in allowing this, even though the `allow_extended_connect` option + // is false. + EXPECT_CALL(visitor, + OnFrameHeader(3, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + + stream_result = adapter->ProcessBytes(stream3_frames); + EXPECT_EQ(static_cast(stream_result), stream3_frames.size()); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, ServerAllowsProtocolPseudoheaderAfterAck) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + adapter->SubmitSettings({{ENABLE_CONNECT_PROTOCOL, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + // Server initial SETTINGS (with ENABLE_CONNECT_PROTOCOL) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + + // The client attempts to send a CONNECT request with the `:protocol` + // pseudoheader after acking the server's SETTINGS frame. + const std::string stream_frames = + TestFrameSequence() + .SettingsAck() + .Headers(1, + {{":method", "CONNECT"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {":protocol", "websocket"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, SETTINGS, ACK_FLAG)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_result), stream_frames.size()); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterTest, SkipsSendingFramesForRejectedStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string initial_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + auto body = std::make_unique(visitor, true); + body->AppendPayload("Here is some data, which will be completely ignored!"); + + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + + adapter->SubmitWindowUpdate(1, 1024); + adapter->SubmitRst(1, Http2ErrorCode::INTERNAL_ERROR); + + // Server initial SETTINGS and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + + // The server sends a RST_STREAM for the offending stream. + // The response HEADERS, DATA and WINDOW_UPDATE are all ignored. + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttpAdapterServerTest, ServerStartsShutdown) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_FALSE(adapter->want_write()); + + adapter->SubmitShutdownNotice(); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterTest, ServerStartsShutdownAfterGoaway) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_FALSE(adapter->want_write()); + + adapter->SubmitGoAway(1, Http2ErrorCode::HTTP2_NO_ERROR, + "and don't come back!"); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); + + // No-op, since a GOAWAY has previously been enqueued. + adapter->SubmitShutdownNotice(); + EXPECT_FALSE(adapter->want_write()); +} + +// Verifies that a connection-level processing error results in repeatedly +// returning a positive value for ProcessBytes() to mark all data as consumed +// when the blackhole option is enabled. +TEST(OgHttp2AdapterTest, ConnectionErrorWithBlackholingData) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.blackhole_data_on_connection_error = true; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(result), frames.size()); + + // Ask the connection to process more bytes. Because the option is enabled, + // the data should be marked as consumed. + const std::string next_frame = TestFrameSequence().Ping(42).Serialize(); + const int64_t next_result = adapter->ProcessBytes(next_frame); + EXPECT_EQ(static_cast(next_result), next_frame.size()); +} + +// Verifies that a connection-level processing error results in returning a +// negative value for ProcessBytes() when the blackhole option is disabled. +TEST(OgHttp2AdapterTest, ConnectionErrorWithoutBlackholingData) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.blackhole_data_on_connection_error = false; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + // Ask the connection to process more bytes. Because the option is disabled, + // ProcessBytes() should continue to return an error. + const std::string next_frame = TestFrameSequence().Ping(42).Serialize(); + const int64_t next_result = adapter->ProcessBytes(next_frame); + EXPECT_LT(next_result, 0); +} + +TEST(OgHttp2AdapterTest, ServerDoesNotSendFramesAfterImmediateGoAway) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + // Submit a custom initial SETTINGS frame with one setting. + adapter->SubmitSettings({{HEADER_TABLE_SIZE, 100u}}); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + // Submit a response for the stream. + auto body = std::make_unique(visitor, true); + body->AppendPayload("This data is doomed to never be written."); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + // Submit a WINDOW_UPDATE frame. + adapter->SubmitWindowUpdate(kConnectionStreamId, 42); + + // Submit another SETTINGS frame. + adapter->SubmitSettings({}); + + // Submit some metadata. + auto source = std::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + + EXPECT_TRUE(adapter->want_write()); + + // Trigger a connection error. Only the response headers will be written. + const std::string connection_error_frames = + TestFrameSequence().WindowUpdate(3, 42).Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(3, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(connection_error_frames); + EXPECT_EQ(static_cast(result), connection_error_frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); + visitor.Clear(); + + // Try to submit more frames for writing. They should not be written. + adapter->SubmitPing(42); + // TODO(diannahu): Enable the below expectation. + // EXPECT_FALSE(adapter->want_write()); + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(OgHttp2AdapterTest, ServerHandlesContentLength) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"content-length", "2"}}) + .Data(1, "hi", /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}, + {"content-length", "nan"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Stream 1: content-length is correct + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 2)); + EXPECT_CALL(visitor, OnDataForStream(1, "hi")); + EXPECT_CALL(visitor, OnEndStream(1)); + + // Stream 3: content-length is not a number + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesContentLengthMismatch) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + {"content-length", "2"}}) + .Data(1, "h", /*fin=*/true) + .Headers(3, {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}, + {"content-length", "2"}}) + .Data(3, "howdy", /*fin=*/true) + .Headers(5, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/four"}, + {"content-length", "2"}}, + /*fin=*/true) + .Headers(7, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/four"}, + {"content-length", "2"}}, + /*fin=*/false) + .Data(7, "h", /*fin=*/false) + .Headers(7, {{"extra-info", "Trailers with content-length mismatch"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Stream 1: content-length is larger than actual data + // All data is delivered to the visitor. Note that neither oghttp2 nor + // nghttp2 delivers OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 1)); + EXPECT_CALL(visitor, OnDataForStream(1, "h")); + + // Stream 3: content-length is smaller than actual data + // The beginning of data is delivered to the visitor, but not the actual data. + // Again, neither oghttp2 nor nghttp2 delivers OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 5)); + + // Stream 5: content-length is invalid and HEADERS ends the stream + // Only oghttp2 invokes OnEndHeadersForStream(). Only nghttp2 invokes + // OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(5)); + + // Stream 7: content-length is invalid and trailers end the stream + // Only oghttp2 invokes OnEndHeadersForStream(). Only nghttp2 invokes + // OnInvalidFrame(). + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(7)); + EXPECT_CALL(visitor, OnFrameHeader(7, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(7, 1)); + EXPECT_CALL(visitor, OnDataForStream(7, "h")); + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, _, _)); + EXPECT_CALL(visitor, OnEndHeadersForStream(7)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 7, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 7, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(7, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesAsteriskPathForOptions) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::string stream_frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "*"}, + {":method", "OPTIONS"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesInvalidPath) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "*"}, + {":method", "GET"}}, + /*fin=*/true) + .Headers(3, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "other/non/slash/starter"}, + {":method", "GET"}}, + /*fin=*/true) + .Headers(5, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", ""}, + {":method", "GET"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL(visitor, + OnInvalidFrame( + 3, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(2); + EXPECT_CALL( + visitor, + OnInvalidFrame(5, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesTeHeader) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::string stream_frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"te", "trailers"}}, + /*fin=*/true) + .Headers(3, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"te", "trailers, deflate"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Stream 1: TE: trailers should be allowed. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(5); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + // Stream 3: TE: should be rejected. + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerHandlesConnectionSpecificHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::string stream_frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"connection", "keep-alive"}}, + /*fin=*/true) + .Headers(3, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"proxy-connection", "keep-alive"}}, + /*fin=*/true) + .Headers(5, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"keep-alive", "timeout=42"}}, + /*fin=*/true) + .Headers(7, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"transfer-encoding", "chunked"}}, + /*fin=*/true) + .Headers(9, + {{":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {":method", "GET"}, + {"upgrade", "h2c"}}, + /*fin=*/true) + .Serialize(); + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // All streams contain a connection-specific header and should be rejected. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(5, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(7, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(9, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(9)); + EXPECT_CALL(visitor, OnHeaderForStream(9, _, _)).Times(4); + EXPECT_CALL( + visitor, + OnInvalidFrame(9, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 5, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 5, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 7, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 7, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(7, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 9, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 9, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(9, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM, SpdyFrameType::RST_STREAM, + SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterTest, ServerUsesCustomWindowUpdateStrategy) { + // Test the use of a custom WINDOW_UPDATE strategy. + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.should_window_update_fn = [](int64_t /*limit*/, int64_t /*size*/, + int64_t /*delta*/) { return true; }; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.", + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, END_STREAM_FLAG)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(frames.size()), result); + + // Mark a small number of bytes for the stream as consumed. Because of the + // custom WINDOW_UPDATE strategy, the session should send WINDOW_UPDATEs. + adapter->MarkDataConsumedForStream(1, 5); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 1, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(WINDOW_UPDATE, 0, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(WINDOW_UPDATE, 0, 4, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS, + SpdyFrameType::WINDOW_UPDATE, + SpdyFrameType::WINDOW_UPDATE})); +} + +// Verifies that NoopHeaderValidator allows several header combinations that +// would otherwise be invalid. +TEST(OgHttp2AdapterTest, NoopHeaderValidatorTest) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + options.validate_http_headers = false; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/1"}, + {"content-length", "7"}, + {"content-length", "7"}}, + /*fin=*/false) + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/3"}, + {"content-length", "11"}, + {"content-length", "13"}}, + /*fin=*/false) + .Headers(5, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "foo.com"}, + {":path", "/"}, + {"host", "bar.com"}}, + /*fin=*/true) + .Headers(7, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"Accept", "uppercase, oh boy!"}}, + /*fin=*/false) + .Headers(9, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "ex|ample.com"}, + {":path", "/"}}, + /*fin=*/false) + .Headers(11, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}, + {"content-length", "nan"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/1")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "7")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "content-length", "7")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + // Stream 3 + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/3")); + EXPECT_CALL(visitor, OnHeaderForStream(3, "content-length", "11")); + EXPECT_CALL(visitor, OnHeaderForStream(3, "content-length", "13")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + // Stream 5 + EXPECT_CALL(visitor, OnFrameHeader(5, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(5)); + EXPECT_CALL(visitor, OnHeaderForStream(5, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(5, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(5, ":authority", "foo.com")); + EXPECT_CALL(visitor, OnHeaderForStream(5, ":path", "/")); + EXPECT_CALL(visitor, OnHeaderForStream(5, "host", "bar.com")); + EXPECT_CALL(visitor, OnEndHeadersForStream(5)); + EXPECT_CALL(visitor, OnEndStream(5)); + // Stream 7 + EXPECT_CALL(visitor, OnFrameHeader(7, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(7)); + EXPECT_CALL(visitor, OnHeaderForStream(7, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(7, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(7, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(7, ":path", "/")); + EXPECT_CALL(visitor, OnHeaderForStream(7, "Accept", "uppercase, oh boy!")); + EXPECT_CALL(visitor, OnEndHeadersForStream(7)); + // Stream 9 + EXPECT_CALL(visitor, OnFrameHeader(9, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(9)); + EXPECT_CALL(visitor, OnHeaderForStream(9, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(9, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(9, ":authority", "ex|ample.com")); + EXPECT_CALL(visitor, OnHeaderForStream(9, ":path", "/")); + EXPECT_CALL(visitor, OnEndHeadersForStream(9)); + // Stream 11 + EXPECT_CALL(visitor, OnFrameHeader(11, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(11)); + EXPECT_CALL(visitor, OnHeaderForStream(11, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(11, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(11, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(11, ":path", "/")); + EXPECT_CALL(visitor, OnHeaderForStream(11, "content-length", "nan")); + EXPECT_CALL(visitor, OnEndHeadersForStream(11)); + EXPECT_CALL(visitor, OnEndStream(11)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(OgHttp2AdapterTest, NegativeFlowControlStreamResumption) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options; + options.perspective = Perspective::kServer; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence() + .ClientPreface({{INITIAL_WINDOW_SIZE, 128u * 1024u}}) + .WindowUpdate(0, 1 << 20) + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::INITIAL_WINDOW_SIZE, + 128u * 1024u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1 << 20)); + + // Stream 1 + EXPECT_CALL(visitor, + OnFrameHeader(1, _, HEADERS, END_STREAM_FLAG | END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + // Submit a response for the stream. + auto body = std::make_unique(visitor, true); + TestDataFrameSource& body_ref = *body; + body_ref.AppendPayload(std::string(70000, 'a')); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}}), std::move(body)); + ASSERT_EQ(0, submit_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, END_HEADERS_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, END_HEADERS_FLAG, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)).Times(5); + + adapter->Send(); + EXPECT_FALSE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, + OnSetting(Http2Setting{Http2KnownSettingsId::INITIAL_WINDOW_SIZE, + 64u * 1024u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // Processing these SETTINGS will cause stream 1's send window to become + // negative. + adapter->ProcessBytes(TestFrameSequence() + .Settings({{INITIAL_WINDOW_SIZE, 64u * 1024u}}) + .Serialize()); + EXPECT_TRUE(adapter->want_write()); + EXPECT_LT(adapter->GetStreamSendWindowSize(1), 0); + + body_ref.AppendPayload("Stream should be resumed."); + adapter->ResumeStream(1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, ACK_FLAG)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, ACK_FLAG, 0)); + adapter->Send(); + EXPECT_FALSE(adapter->want_write()); + + // Upon receiving the WINDOW_UPDATE, stream 1 should be ready to write. + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 10000)); + adapter->ProcessBytes(TestFrameSequence().WindowUpdate(1, 10000).Serialize()); + EXPECT_TRUE(adapter->want_write()); + EXPECT_GT(adapter->GetStreamSendWindowSize(1), 0); + + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + adapter->Send(); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/oghttp2_session.cc b/quiche/http2/adapter/oghttp2_session.cc new file mode 100644 index 000000000000..fe6127a3cc0b --- /dev/null +++ b/quiche/http2/adapter/oghttp2_session.cc @@ -0,0 +1,2024 @@ +#include "quiche/http2/adapter/oghttp2_session.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "quiche/http2/adapter/header_validator.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_util.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/noop_header_validator.h" +#include "quiche/http2/adapter/oghttp2_util.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { + +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; +using SpdyFramerError = Http2DecoderAdapter::SpdyFramerError; + +using ::spdy::SpdySettingsIR; + +const uint32_t kMaxAllowedMetadataFrameSize = 65536u; +const uint32_t kDefaultHpackTableCapacity = 4096u; +const uint32_t kMaximumHpackTableCapacity = 65536u; + +// Corresponds to NGHTTP2_ERR_CALLBACK_FAILURE. +const int kSendError = -902; + +constexpr absl::string_view kHeadValue = "HEAD"; + +// TODO(birenroy): Consider incorporating spdy::FlagsSerializionVisitor here. +class FrameAttributeCollector : public spdy::SpdyFrameVisitor { + public: + FrameAttributeCollector() = default; + void VisitData(const spdy::SpdyDataIR& data) override { + frame_type_ = static_cast(data.frame_type()); + stream_id_ = data.stream_id(); + flags_ = + (data.fin() ? END_STREAM_FLAG : 0) | (data.padded() ? PADDED_FLAG : 0); + } + void VisitHeaders(const spdy::SpdyHeadersIR& headers) override { + frame_type_ = static_cast(headers.frame_type()); + stream_id_ = headers.stream_id(); + flags_ = END_HEADERS_FLAG | (headers.fin() ? END_STREAM_FLAG : 0) | + (headers.padded() ? PADDED_FLAG : 0) | + (headers.has_priority() ? PRIORITY_FLAG : 0); + } + void VisitPriority(const spdy::SpdyPriorityIR& priority) override { + frame_type_ = static_cast(priority.frame_type()); + frame_type_ = 2; + stream_id_ = priority.stream_id(); + } + void VisitRstStream(const spdy::SpdyRstStreamIR& rst_stream) override { + frame_type_ = static_cast(rst_stream.frame_type()); + frame_type_ = 3; + stream_id_ = rst_stream.stream_id(); + error_code_ = rst_stream.error_code(); + } + void VisitSettings(const spdy::SpdySettingsIR& settings) override { + frame_type_ = static_cast(settings.frame_type()); + frame_type_ = 4; + flags_ = (settings.is_ack() ? ACK_FLAG : 0); + } + void VisitPushPromise(const spdy::SpdyPushPromiseIR& push_promise) override { + frame_type_ = static_cast(push_promise.frame_type()); + frame_type_ = 5; + stream_id_ = push_promise.stream_id(); + flags_ = (push_promise.padded() ? PADDED_FLAG : 0); + } + void VisitPing(const spdy::SpdyPingIR& ping) override { + frame_type_ = static_cast(ping.frame_type()); + frame_type_ = 6; + flags_ = (ping.is_ack() ? ACK_FLAG : 0); + } + void VisitGoAway(const spdy::SpdyGoAwayIR& goaway) override { + frame_type_ = static_cast(goaway.frame_type()); + frame_type_ = 7; + error_code_ = goaway.error_code(); + } + void VisitWindowUpdate( + const spdy::SpdyWindowUpdateIR& window_update) override { + frame_type_ = static_cast(window_update.frame_type()); + frame_type_ = 8; + stream_id_ = window_update.stream_id(); + } + void VisitContinuation( + const spdy::SpdyContinuationIR& continuation) override { + frame_type_ = static_cast(continuation.frame_type()); + stream_id_ = continuation.stream_id(); + flags_ = continuation.end_headers() ? END_HEADERS_FLAG : 0; + } + void VisitUnknown(const spdy::SpdyUnknownIR& unknown) override { + frame_type_ = static_cast(unknown.frame_type()); + stream_id_ = unknown.stream_id(); + flags_ = unknown.flags(); + } + void VisitAltSvc(const spdy::SpdyAltSvcIR& /*altsvc*/) override {} + void VisitPriorityUpdate( + const spdy::SpdyPriorityUpdateIR& /*priority_update*/) override {} + void VisitAcceptCh(const spdy::SpdyAcceptChIR& /*accept_ch*/) override {} + + uint32_t stream_id() { return stream_id_; } + uint32_t error_code() { return error_code_; } + uint8_t frame_type() { return frame_type_; } + uint8_t flags() { return flags_; } + + private: + uint32_t stream_id_ = 0; + uint32_t error_code_ = 0; + uint8_t frame_type_ = 0; + uint8_t flags_ = 0; +}; + +absl::string_view TracePerspectiveAsString(Perspective p) { + switch (p) { + case Perspective::kClient: + return "OGHTTP2_CLIENT"; + case Perspective::kServer: + return "OGHTTP2_SERVER"; + } + return "OGHTTP2_SERVER"; +} + +class RunOnExit { + public: + RunOnExit() = default; + explicit RunOnExit(std::function f) : f_(std::move(f)) {} + + RunOnExit(const RunOnExit& other) = delete; + RunOnExit& operator=(const RunOnExit& other) = delete; + RunOnExit(RunOnExit&& other) = delete; + RunOnExit& operator=(RunOnExit&& other) = delete; + + ~RunOnExit() { + if (f_) { + f_(); + } + f_ = {}; + } + + void emplace(std::function f) { f_ = std::move(f); } + + private: + std::function f_; +}; + +Http2ErrorCode GetHttp2ErrorCode(SpdyFramerError error) { + switch (error) { + case SpdyFramerError::SPDY_NO_ERROR: + return Http2ErrorCode::HTTP2_NO_ERROR; + case SpdyFramerError::SPDY_INVALID_STREAM_ID: + case SpdyFramerError::SPDY_INVALID_CONTROL_FRAME: + case SpdyFramerError::SPDY_INVALID_PADDING: + case SpdyFramerError::SPDY_INVALID_DATA_FRAME_FLAGS: + case SpdyFramerError::SPDY_UNEXPECTED_FRAME: + return Http2ErrorCode::PROTOCOL_ERROR; + case SpdyFramerError::SPDY_CONTROL_PAYLOAD_TOO_LARGE: + case SpdyFramerError::SPDY_INVALID_CONTROL_FRAME_SIZE: + case SpdyFramerError::SPDY_OVERSIZED_PAYLOAD: + return Http2ErrorCode::FRAME_SIZE_ERROR; + case SpdyFramerError::SPDY_DECOMPRESS_FAILURE: + case SpdyFramerError::SPDY_HPACK_INDEX_VARINT_ERROR: + case SpdyFramerError::SPDY_HPACK_NAME_LENGTH_VARINT_ERROR: + case SpdyFramerError::SPDY_HPACK_VALUE_LENGTH_VARINT_ERROR: + case SpdyFramerError::SPDY_HPACK_NAME_TOO_LONG: + case SpdyFramerError::SPDY_HPACK_VALUE_TOO_LONG: + case SpdyFramerError::SPDY_HPACK_NAME_HUFFMAN_ERROR: + case SpdyFramerError::SPDY_HPACK_VALUE_HUFFMAN_ERROR: + case SpdyFramerError::SPDY_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE: + case SpdyFramerError::SPDY_HPACK_INVALID_INDEX: + case SpdyFramerError::SPDY_HPACK_INVALID_NAME_INDEX: + case SpdyFramerError::SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED: + case SpdyFramerError:: + SPDY_HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK: + case SpdyFramerError:: + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING: + case SpdyFramerError::SPDY_HPACK_TRUNCATED_BLOCK: + case SpdyFramerError::SPDY_HPACK_FRAGMENT_TOO_LONG: + case SpdyFramerError::SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT: + return Http2ErrorCode::COMPRESSION_ERROR; + case SpdyFramerError::SPDY_INTERNAL_FRAMER_ERROR: + case SpdyFramerError::SPDY_STOP_PROCESSING: + case SpdyFramerError::LAST_ERROR: + return Http2ErrorCode::INTERNAL_ERROR; + } + return Http2ErrorCode::INTERNAL_ERROR; +} + +bool IsResponse(HeaderType type) { + return type == HeaderType::RESPONSE_100 || type == HeaderType::RESPONSE; +} + +bool StatusIs1xx(absl::string_view status) { + return status.size() == 3 && status[0] == '1'; +} + +// Returns the upper bound on HPACK encoder table capacity. If not specified in +// the Options, a reasonable default upper bound is used. +uint32_t HpackCapacityBound(const OgHttp2Session::Options& o) { + return o.max_hpack_encoding_table_capacity.value_or( + kMaximumHpackTableCapacity); +} + +bool IsNonAckSettings(const spdy::SpdyFrameIR& frame) { + return frame.frame_type() == spdy::SpdyFrameType::SETTINGS && + !reinterpret_cast(frame).is_ack(); +} + +} // namespace + +OgHttp2Session::PassthroughHeadersHandler::PassthroughHeadersHandler( + OgHttp2Session& session, Http2VisitorInterface& visitor) + : session_(session), visitor_(visitor) { + if (session_.options_.validate_http_headers) { + QUICHE_VLOG(2) << "instantiating regular header validator"; + validator_ = std::make_unique(); + } else { + QUICHE_VLOG(2) << "instantiating noop header validator"; + validator_ = std::make_unique(); + } +} + +void OgHttp2Session::PassthroughHeadersHandler::OnHeaderBlockStart() { + result_ = Http2VisitorInterface::HEADER_OK; + const bool status = visitor_.OnBeginHeadersForStream(stream_id_); + if (!status) { + QUICHE_VLOG(1) + << "Visitor rejected header block, returning HEADER_CONNECTION_ERROR"; + result_ = Http2VisitorInterface::HEADER_CONNECTION_ERROR; + } + validator_->StartHeaderBlock(); +} + +Http2VisitorInterface::OnHeaderResult InterpretHeaderStatus( + HeaderValidator::HeaderStatus status) { + switch (status) { + case HeaderValidator::HEADER_OK: + case HeaderValidator::HEADER_SKIP: + return Http2VisitorInterface::HEADER_OK; + case HeaderValidator::HEADER_FIELD_INVALID: + return Http2VisitorInterface::HEADER_FIELD_INVALID; + case HeaderValidator::HEADER_FIELD_TOO_LONG: + return Http2VisitorInterface::HEADER_RST_STREAM; + } + return Http2VisitorInterface::HEADER_CONNECTION_ERROR; +} + +void OgHttp2Session::PassthroughHeadersHandler::OnHeader( + absl::string_view key, absl::string_view value) { + if (result_ != Http2VisitorInterface::HEADER_OK) { + QUICHE_VLOG(2) << "Early return; status not HEADER_OK"; + return; + } + const HeaderValidator::HeaderStatus validation_result = + validator_->ValidateSingleHeader(key, value); + if (validation_result == HeaderValidator::HEADER_SKIP) { + return; + } + if (validation_result != HeaderValidator::HEADER_OK) { + QUICHE_VLOG(2) << "Header validation failed with result " + << static_cast(validation_result); + result_ = InterpretHeaderStatus(validation_result); + return; + } + result_ = visitor_.OnHeaderForStream(stream_id_, key, value); +} + +void OgHttp2Session::PassthroughHeadersHandler::OnHeaderBlockEnd( + size_t /* uncompressed_header_bytes */, + size_t /* compressed_header_bytes */) { + if (result_ == Http2VisitorInterface::HEADER_OK) { + if (!validator_->FinishHeaderBlock(type_)) { + QUICHE_VLOG(1) << "FinishHeaderBlock returned false; returning " + "HEADER_HTTP_MESSAGING"; + result_ = Http2VisitorInterface::HEADER_HTTP_MESSAGING; + } + } + if (frame_contains_fin_ && IsResponse(type_) && + StatusIs1xx(status_header())) { + QUICHE_VLOG(1) << "Unexpected end of stream without final headers"; + result_ = Http2VisitorInterface::HEADER_HTTP_MESSAGING; + } + if (result_ == Http2VisitorInterface::HEADER_OK) { + const bool result = visitor_.OnEndHeadersForStream(stream_id_); + if (!result) { + session_.fatal_visitor_callback_failure_ = true; + session_.decoder_.StopProcessing(); + } + } else { + session_.OnHeaderStatus(stream_id_, result_); + } + frame_contains_fin_ = false; +} + +// TODO(diannahu): Add checks for request methods. +bool OgHttp2Session::PassthroughHeadersHandler::CanReceiveBody() const { + switch (header_type()) { + case HeaderType::REQUEST_TRAILER: + case HeaderType::RESPONSE_TRAILER: + case HeaderType::RESPONSE_100: + return false; + case HeaderType::RESPONSE: + // 304 responses should not have a body: + // https://httpwg.org/specs/rfc7230.html#rfc.section.3.3.2 + // Neither should 204 responses: + // https://httpwg.org/specs/rfc7231.html#rfc.section.6.3.5 + return status_header() != "304" && status_header() != "204"; + case HeaderType::REQUEST: + return true; + } + return true; +} + +// A visitor that extracts an int64_t from each type of a ProcessBytesResult. +struct OgHttp2Session::ProcessBytesResultVisitor { + int64_t operator()(const int64_t bytes) const { return bytes; } + + int64_t operator()(const ProcessBytesError error) const { + switch (error) { + case ProcessBytesError::kUnspecified: + return -1; + case ProcessBytesError::kInvalidConnectionPreface: + return -903; // NGHTTP2_ERR_BAD_CLIENT_MAGIC + case ProcessBytesError::kVisitorCallbackFailed: + return -902; // NGHTTP2_ERR_CALLBACK_FAILURE + } + return -1; + } +}; + +OgHttp2Session::OgHttp2Session(Http2VisitorInterface& visitor, Options options) + : visitor_(visitor), + options_(options), + event_forwarder_([this]() { return !latched_error_; }, *this), + receive_logger_( + &event_forwarder_, TracePerspectiveAsString(options.perspective), + [logging_enabled = GetQuicheFlag(quiche_oghttp2_debug_trace)]() { + return logging_enabled; + }, + this), + send_logger_( + TracePerspectiveAsString(options.perspective), + [logging_enabled = GetQuicheFlag(quiche_oghttp2_debug_trace)]() { + return logging_enabled; + }, + this), + headers_handler_(*this, visitor), + noop_headers_handler_(/*listener=*/nullptr), + connection_window_manager_( + kInitialFlowControlWindowSize, + [this](size_t window_update_delta) { + SendWindowUpdate(kConnectionStreamId, window_update_delta); + }, + options.should_window_update_fn, + /*update_window_on_notify=*/false) { + decoder_.set_visitor(&receive_logger_); + if (options_.max_header_list_bytes) { + // Limit buffering of encoded HPACK data to 2x the decoded limit. + decoder_.GetHpackDecoder()->set_max_decode_buffer_size_bytes( + 2 * *options_.max_header_list_bytes); + // Limit the total bytes accepted for HPACK decoding to 4x the limit. + decoder_.GetHpackDecoder()->set_max_header_block_bytes( + 4 * *options_.max_header_list_bytes); + } + if (IsServerSession()) { + remaining_preface_ = {spdy::kHttp2ConnectionHeaderPrefix, + spdy::kHttp2ConnectionHeaderPrefixSize}; + } + if (options_.max_header_field_size.has_value()) { + headers_handler_.SetMaxFieldSize(options_.max_header_field_size.value()); + } + headers_handler_.SetAllowObsText(options_.allow_obs_text); +} + +OgHttp2Session::~OgHttp2Session() {} + +void OgHttp2Session::SetStreamUserData(Http2StreamId stream_id, + void* user_data) { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + it->second.user_data = user_data; + } +} + +void* OgHttp2Session::GetStreamUserData(Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.user_data; + } + auto p = pending_streams_.find(stream_id); + if (p != pending_streams_.end()) { + return p->second.user_data; + } + return nullptr; +} + +bool OgHttp2Session::ResumeStream(Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end() || it->second.outbound_body == nullptr || + !write_scheduler_.StreamRegistered(stream_id)) { + return false; + } + it->second.data_deferred = false; + write_scheduler_.MarkStreamReady(stream_id, /*add_to_front=*/false); + return true; +} + +int OgHttp2Session::GetStreamSendWindowSize(Http2StreamId stream_id) const { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.send_window; + } + return -1; +} + +int OgHttp2Session::GetStreamReceiveWindowLimit(Http2StreamId stream_id) const { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.window_manager.WindowSizeLimit(); + } + return -1; +} + +int OgHttp2Session::GetStreamReceiveWindowSize(Http2StreamId stream_id) const { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.window_manager.CurrentWindowSize(); + } + return -1; +} + +int OgHttp2Session::GetReceiveWindowSize() const { + return connection_window_manager_.CurrentWindowSize(); +} + +int OgHttp2Session::GetHpackEncoderDynamicTableSize() const { + const spdy::HpackEncoder* encoder = framer_.GetHpackEncoder(); + return encoder == nullptr ? 0 : encoder->GetDynamicTableSize(); +} + +int OgHttp2Session::GetHpackEncoderDynamicTableCapacity() const { + const spdy::HpackEncoder* encoder = framer_.GetHpackEncoder(); + return encoder == nullptr ? kDefaultHpackTableCapacity + : encoder->CurrentHeaderTableSizeSetting(); +} + +int OgHttp2Session::GetHpackDecoderDynamicTableSize() const { + const spdy::HpackDecoderAdapter* decoder = decoder_.GetHpackDecoder(); + return decoder == nullptr ? 0 : decoder->GetDynamicTableSize(); +} + +int OgHttp2Session::GetHpackDecoderSizeLimit() const { + const spdy::HpackDecoderAdapter* decoder = decoder_.GetHpackDecoder(); + return decoder == nullptr ? 0 : decoder->GetCurrentHeaderTableSizeSetting(); +} + +int64_t OgHttp2Session::ProcessBytes(absl::string_view bytes) { + QUICHE_VLOG(2) << TracePerspectiveAsString(options_.perspective) + << " processing [" << absl::CEscape(bytes) << "]"; + return absl::visit(ProcessBytesResultVisitor(), ProcessBytesImpl(bytes)); +} + +absl::variant +OgHttp2Session::ProcessBytesImpl(absl::string_view bytes) { + if (processing_bytes_) { + QUICHE_VLOG(1) << "Returning early; already processing bytes."; + return 0; + } + processing_bytes_ = true; + RunOnExit r{[this]() { processing_bytes_ = false; }}; + + if (options_.blackhole_data_on_connection_error && latched_error_) { + return static_cast(bytes.size()); + } + + int64_t preface_consumed = 0; + if (!remaining_preface_.empty()) { + QUICHE_VLOG(2) << "Preface bytes remaining: " << remaining_preface_.size(); + // decoder_ does not understand the client connection preface. + size_t min_size = std::min(remaining_preface_.size(), bytes.size()); + if (!absl::StartsWith(remaining_preface_, bytes.substr(0, min_size))) { + // Preface doesn't match! + QUICHE_DLOG(INFO) << "Preface doesn't match! Expected: [" + << absl::CEscape(remaining_preface_) << "], actual: [" + << absl::CEscape(bytes) << "]"; + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidConnectionPreface); + return ProcessBytesError::kInvalidConnectionPreface; + } + remaining_preface_.remove_prefix(min_size); + bytes.remove_prefix(min_size); + if (!remaining_preface_.empty()) { + QUICHE_VLOG(2) << "Preface bytes remaining: " + << remaining_preface_.size(); + return static_cast(min_size); + } + preface_consumed = min_size; + } + int64_t result = decoder_.ProcessInput(bytes.data(), bytes.size()); + QUICHE_VLOG(2) << "ProcessBytes result: " << result; + if (fatal_visitor_callback_failure_) { + QUICHE_DCHECK(latched_error_); + QUICHE_VLOG(2) << "Visitor callback failed while processing bytes."; + return ProcessBytesError::kVisitorCallbackFailed; + } + if (latched_error_ || result < 0) { + QUICHE_VLOG(2) << "ProcessBytes encountered an error."; + if (options_.blackhole_data_on_connection_error) { + return static_cast(bytes.size() + preface_consumed); + } else { + return ProcessBytesError::kUnspecified; + } + } + return result + preface_consumed; +} + +int OgHttp2Session::Consume(Http2StreamId stream_id, size_t num_bytes) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Stream " << stream_id << " not found when consuming " + << num_bytes << " bytes"; + } else { + it->second.window_manager.MarkDataFlushed(num_bytes); + } + connection_window_manager_.MarkDataFlushed(num_bytes); + return 0; // Remove? +} + +void OgHttp2Session::StartGracefulShutdown() { + if (IsServerSession()) { + if (!queued_goaway_) { + EnqueueFrame(std::make_unique( + std::numeric_limits::max(), spdy::ERROR_CODE_NO_ERROR, + "graceful_shutdown")); + } + } else { + QUICHE_LOG(ERROR) << "Graceful shutdown not needed for clients."; + } +} + +void OgHttp2Session::EnqueueFrame(std::unique_ptr frame) { + if (queued_immediate_goaway_) { + // Do not allow additional frames to be enqueued after the GOAWAY. + return; + } + + const bool is_non_ack_settings = IsNonAckSettings(*frame); + MaybeSetupPreface(is_non_ack_settings); + + if (frame->frame_type() == spdy::SpdyFrameType::GOAWAY) { + queued_goaway_ = true; + if (latched_error_) { + PrepareForImmediateGoAway(); + } + } else if (frame->fin() || + frame->frame_type() == spdy::SpdyFrameType::RST_STREAM) { + auto iter = stream_map_.find(frame->stream_id()); + if (iter != stream_map_.end()) { + iter->second.half_closed_local = true; + } + if (frame->frame_type() == spdy::SpdyFrameType::RST_STREAM) { + // TODO(diannahu): Condition on existence in the stream map? + streams_reset_.insert(frame->stream_id()); + } + } else if (frame->frame_type() == spdy::SpdyFrameType::WINDOW_UPDATE) { + UpdateReceiveWindow( + frame->stream_id(), + reinterpret_cast(*frame).delta()); + } else if (is_non_ack_settings) { + HandleOutboundSettings( + *reinterpret_cast(frame.get())); + } + if (frame->stream_id() != 0) { + auto result = queued_frames_.insert({frame->stream_id(), 1}); + if (!result.second) { + ++(result.first->second); + } + } + frames_.push_back(std::move(frame)); +} + +int OgHttp2Session::Send() { + if (sending_) { + QUICHE_VLOG(1) << TracePerspectiveAsString(options_.perspective) + << " returning early; already sending."; + return 0; + } + sending_ = true; + RunOnExit r{[this]() { sending_ = false; }}; + + if (fatal_send_error_) { + return kSendError; + } + + MaybeSetupPreface(/*sending_outbound_settings=*/false); + + SendResult continue_writing = SendQueuedFrames(); + if (queued_immediate_goaway_) { + // If an immediate GOAWAY was queued, then the above flush either sent the + // GOAWAY or buffered it to be sent on the next successful flush. In either + // case, return early here to avoid sending other frames. + return InterpretSendResult(continue_writing); + } + // Notify on new/pending streams closed due to GOAWAY receipt. + CloseGoAwayRejectedStreams(); + // Wake streams for writes. + while (continue_writing == SendResult::SEND_OK && HasReadyStream()) { + const Http2StreamId stream_id = GetNextReadyStream(); + // TODO(birenroy): Add a return value to indicate write blockage, so streams + // aren't woken unnecessarily. + QUICHE_VLOG(1) << "Waking stream " << stream_id << " for writes."; + continue_writing = WriteForStream(stream_id); + } + if (continue_writing == SendResult::SEND_OK) { + continue_writing = SendQueuedFrames(); + } + return InterpretSendResult(continue_writing); +} + +int OgHttp2Session::InterpretSendResult(SendResult result) { + if (result == SendResult::SEND_ERROR) { + fatal_send_error_ = true; + return kSendError; + } else { + return 0; + } +} + +bool OgHttp2Session::HasReadyStream() const { + return !trailers_ready_.empty() || + (write_scheduler_.HasReadyStreams() && connection_send_window_ > 0); +} + +Http2StreamId OgHttp2Session::GetNextReadyStream() { + QUICHE_DCHECK(HasReadyStream()); + if (!trailers_ready_.empty()) { + const Http2StreamId stream_id = *trailers_ready_.begin(); + // WriteForStream() will re-mark the stream as ready, if necessary. + write_scheduler_.MarkStreamNotReady(stream_id); + return stream_id; + } + return write_scheduler_.PopNextReadyStream(); +} + +OgHttp2Session::SendResult OgHttp2Session::MaybeSendBufferedData() { + int64_t result = std::numeric_limits::max(); + while (result > 0 && !buffered_data_.empty()) { + result = visitor_.OnReadyToSend(buffered_data_); + if (result > 0) { + buffered_data_.erase(0, result); + } + } + if (result < 0) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kSendError); + return SendResult::SEND_ERROR; + } + return buffered_data_.empty() ? SendResult::SEND_OK + : SendResult::SEND_BLOCKED; +} + +OgHttp2Session::SendResult OgHttp2Session::SendQueuedFrames() { + // Flush any serialized prefix. + const SendResult result = MaybeSendBufferedData(); + if (result != SendResult::SEND_OK) { + return result; + } + // Serialize and send frames in the queue. + while (!frames_.empty()) { + const auto& frame_ptr = frames_.front(); + FrameAttributeCollector c; + frame_ptr->Visit(&c); + + // DATA frames should never be queued. + QUICHE_DCHECK_NE(c.frame_type(), 0); + + const bool stream_reset = + c.stream_id() != 0 && streams_reset_.count(c.stream_id()) > 0; + if (stream_reset && + c.frame_type() != static_cast(FrameType::RST_STREAM)) { + // The stream has been reset, so any other remaining frames can be + // skipped. + // TODO(birenroy): inform the visitor of frames that are skipped. + DecrementQueuedFrameCount(c.stream_id(), c.frame_type()); + frames_.pop_front(); + continue; + } else if (!IsServerSession() && received_goaway_ && + c.stream_id() > + static_cast(received_goaway_stream_id_)) { + // This frame will be ignored by the server, so don't send it. The stream + // associated with this frame should have been closed in OnGoAway(). + frames_.pop_front(); + continue; + } + // Frames can't accurately report their own length; the actual serialized + // length must be used instead. + spdy::SpdySerializedFrame frame = framer_.SerializeFrame(*frame_ptr); + const size_t frame_payload_length = frame.size() - spdy::kFrameHeaderSize; + frame_ptr->Visit(&send_logger_); + visitor_.OnBeforeFrameSent(c.frame_type(), c.stream_id(), + frame_payload_length, c.flags()); + const int64_t result = visitor_.OnReadyToSend(absl::string_view(frame)); + if (result < 0) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kSendError); + return SendResult::SEND_ERROR; + } else if (result == 0) { + // Write blocked. + return SendResult::SEND_BLOCKED; + } else { + frames_.pop_front(); + + const bool ok = + AfterFrameSent(c.frame_type(), c.stream_id(), frame_payload_length, + c.flags(), c.error_code()); + if (!ok) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kSendError); + return SendResult::SEND_ERROR; + } + if (static_cast(result) < frame.size()) { + // The frame was partially written, so the rest must be buffered. + buffered_data_.append(frame.data() + result, frame.size() - result); + return SendResult::SEND_BLOCKED; + } + } + } + return SendResult::SEND_OK; +} + +bool OgHttp2Session::AfterFrameSent(uint8_t frame_type_int, uint32_t stream_id, + size_t payload_length, uint8_t flags, + uint32_t error_code) { + const FrameType frame_type = static_cast(frame_type_int); + int result = visitor_.OnFrameSent(frame_type_int, stream_id, payload_length, + flags, error_code); + if (result < 0) { + return false; + } + if (stream_id == 0) { + if (frame_type == FrameType::SETTINGS) { + const bool is_settings_ack = (flags & ACK_FLAG); + if (is_settings_ack && encoder_header_table_capacity_when_acking_) { + framer_.UpdateHeaderEncoderTableSize( + encoder_header_table_capacity_when_acking_.value()); + encoder_header_table_capacity_when_acking_ = absl::nullopt; + } else if (!is_settings_ack) { + sent_non_ack_settings_ = true; + } + } + return true; + } + + const bool contains_fin = + (frame_type == FrameType::DATA || frame_type == FrameType::HEADERS) && + (flags & END_STREAM_FLAG) == END_STREAM_FLAG; + auto it = stream_map_.find(stream_id); + const bool still_open_remote = + it != stream_map_.end() && !it->second.half_closed_remote; + if (contains_fin && still_open_remote && + options_.rst_stream_no_error_when_incomplete && IsServerSession()) { + // Since the peer has not yet ended the stream, this endpoint should + // send a RST_STREAM NO_ERROR. See RFC 7540 Section 8.1. + frames_.push_front(std::make_unique( + stream_id, spdy::SpdyErrorCode::ERROR_CODE_NO_ERROR)); + auto queued_result = queued_frames_.insert({stream_id, 1}); + if (!queued_result.second) { + ++(queued_result.first->second); + } + it->second.half_closed_remote = true; + } + + DecrementQueuedFrameCount(stream_id, frame_type_int); + return true; +} + +OgHttp2Session::SendResult OgHttp2Session::WriteForStream( + Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Can't find stream " << stream_id + << " which is ready to write!"; + return SendResult::SEND_OK; + } + StreamState& state = it->second; + auto reset_it = streams_reset_.find(stream_id); + if (reset_it != streams_reset_.end()) { + // The stream has been reset; there's no point in sending DATA or trailing + // HEADERS. + state.outbound_body = nullptr; + state.trailers = nullptr; + return SendResult::SEND_OK; + } + + SendResult connection_can_write = SendResult::SEND_OK; + if (state.outbound_body == nullptr || + (!options_.trailers_require_end_data && state.data_deferred)) { + // No data to send, but there might be trailers. + if (state.trailers != nullptr) { + // Trailers will include END_STREAM, so the data source can be discarded. + // Since data_deferred is true, there is no data waiting to be flushed for + // this stream. + state.outbound_body = nullptr; + auto block_ptr = std::move(state.trailers); + if (state.half_closed_local) { + QUICHE_LOG(ERROR) << "Sent fin; can't send trailers."; + } else { + SendTrailers(stream_id, std::move(*block_ptr)); + } + } + return SendResult::SEND_OK; + } + int32_t available_window = + std::min({connection_send_window_, state.send_window, + static_cast(max_frame_payload_)}); + while (connection_can_write == SendResult::SEND_OK && available_window > 0 && + state.outbound_body != nullptr && !state.data_deferred) { + auto [length, end_data] = + state.outbound_body->SelectPayloadLength(available_window); + QUICHE_VLOG(2) << "WriteForStream | length: " << length + << " end_data: " << end_data + << " trailers: " << state.trailers.get(); + if (length == 0 && !end_data && + (options_.trailers_require_end_data || state.trailers == nullptr)) { + // An unproductive call to SelectPayloadLength() results in this stream + // entering the "deferred" state only if either no trailers are available + // to send, or trailers require an explicit end_data before being sent. + state.data_deferred = true; + break; + } else if (length == DataFrameSource::kError) { + // TODO(birenroy): Consider queuing a RST_STREAM INTERNAL_ERROR instead. + CloseStream(stream_id, Http2ErrorCode::INTERNAL_ERROR); + // No more work on the stream; it has been closed. + break; + } + const bool fin = end_data ? state.outbound_body->send_fin() : false; + if (length > 0 || fin) { + spdy::SpdyDataIR data(stream_id); + data.set_fin(fin); + data.SetDataShallow(length); + spdy::SpdySerializedFrame header = + spdy::SpdyFramer::SerializeDataFrameHeaderWithPaddingLengthField( + data); + QUICHE_DCHECK(buffered_data_.empty() && frames_.empty()); + const bool success = + state.outbound_body->Send(absl::string_view(header), length); + if (!success) { + connection_can_write = SendResult::SEND_BLOCKED; + break; + } + connection_send_window_ -= length; + state.send_window -= length; + available_window = std::min({connection_send_window_, state.send_window, + static_cast(max_frame_payload_)}); + if (fin) { + state.half_closed_local = true; + MaybeFinWithRstStream(it); + } + const bool ok = AfterFrameSent(/* DATA */ 0, stream_id, length, + fin ? END_STREAM_FLAG : 0x0, 0); + if (!ok) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kSendError); + return SendResult::SEND_ERROR; + } + if (!stream_map_.contains(stream_id)) { + // Note: the stream may have been closed if `fin` is true. + break; + } + } + if (end_data || (length == 0 && state.trailers != nullptr && + !options_.trailers_require_end_data)) { + // If SelectPayloadLength() returned {0, false}, and there are trailers to + // send, and the safety feature is disabled, it's okay to send the + // trailers. + if (state.trailers != nullptr) { + auto block_ptr = std::move(state.trailers); + if (fin) { + QUICHE_LOG(ERROR) << "Sent fin; can't send trailers."; + } else { + SendTrailers(stream_id, std::move(*block_ptr)); + } + } + state.outbound_body = nullptr; + } + } + // If the stream still exists and has data to send, it should be marked as + // ready in the write scheduler. + if (stream_map_.contains(stream_id) && !state.data_deferred && + state.send_window > 0 && state.outbound_body != nullptr) { + write_scheduler_.MarkStreamReady(stream_id, false); + } + // Streams can continue writing as long as the connection is not write-blocked + // and there is additional flow control quota available. + if (connection_can_write != SendResult::SEND_OK) { + return connection_can_write; + } + return connection_send_window_ <= 0 ? SendResult::SEND_BLOCKED + : SendResult::SEND_OK; +} + +void OgHttp2Session::SerializeMetadata(Http2StreamId stream_id, + std::unique_ptr source) { + const uint32_t max_payload_size = + std::min(kMaxAllowedMetadataFrameSize, max_frame_payload_); + auto payload_buffer = std::make_unique(max_payload_size); + + while (true) { + auto [written, end_metadata] = + source->Pack(payload_buffer.get(), max_payload_size); + if (written < 0) { + // Unable to pack any metadata. + return; + } + QUICHE_DCHECK_LE(static_cast(written), max_payload_size); + auto payload = absl::string_view( + reinterpret_cast(payload_buffer.get()), written); + EnqueueFrame(std::make_unique( + stream_id, kMetadataFrameType, end_metadata ? kMetadataEndFlag : 0u, + std::string(payload))); + if (end_metadata) { + return; + } + } +} + +int32_t OgHttp2Session::SubmitRequest( + absl::Span headers, + std::unique_ptr data_source, void* user_data) { + // TODO(birenroy): return an error for the incorrect perspective + const Http2StreamId stream_id = next_stream_id_; + next_stream_id_ += 2; + if (!pending_streams_.empty() || !CanCreateStream()) { + // TODO(diannahu): There should probably be a limit to the number of allowed + // pending streams. + pending_streams_.insert( + {stream_id, PendingStreamState{ToHeaderBlock(headers), + std::move(data_source), user_data}}); + StartPendingStreams(); + } else { + StartRequest(stream_id, ToHeaderBlock(headers), std::move(data_source), + user_data); + } + return stream_id; +} + +int OgHttp2Session::SubmitResponse( + Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) { + // TODO(birenroy): return an error for the incorrect perspective + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Unable to find stream " << stream_id; + return -501; // NGHTTP2_ERR_INVALID_ARGUMENT + } + const bool end_stream = data_source == nullptr; + if (!end_stream) { + // Add data source to stream state + iter->second.outbound_body = std::move(data_source); + write_scheduler_.MarkStreamReady(stream_id, false); + } + SendHeaders(stream_id, ToHeaderBlock(headers), end_stream); + return 0; +} + +int OgHttp2Session::SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) { + // TODO(birenroy): Reject trailers when acting as a client? + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Unable to find stream " << stream_id; + return -501; // NGHTTP2_ERR_INVALID_ARGUMENT + } + StreamState& state = iter->second; + if (state.half_closed_local) { + QUICHE_LOG(ERROR) << "Stream " << stream_id << " is half closed (local)"; + return -514; // NGHTTP2_ERR_INVALID_STREAM_STATE + } + if (state.trailers != nullptr) { + QUICHE_LOG(ERROR) << "Stream " << stream_id + << " already has trailers queued"; + return -514; // NGHTTP2_ERR_INVALID_STREAM_STATE + } + if (state.outbound_body == nullptr) { + // Enqueue trailers immediately. + SendTrailers(stream_id, ToHeaderBlock(trailers)); + } else { + QUICHE_LOG_IF(ERROR, state.outbound_body->send_fin()) + << "DataFrameSource will send fin, preventing trailers!"; + // Save trailers so they can be written once data is done. + state.trailers = + std::make_unique(ToHeaderBlock(trailers)); + if (!options_.trailers_require_end_data || !iter->second.data_deferred) { + trailers_ready_.insert(stream_id); + } + } + return 0; +} + +void OgHttp2Session::SubmitMetadata(Http2StreamId stream_id, + std::unique_ptr source) { + SerializeMetadata(stream_id, std::move(source)); +} + +void OgHttp2Session::SubmitSettings(absl::Span settings) { + auto frame = PrepareSettingsFrame(settings); + EnqueueFrame(std::move(frame)); +} + +void OgHttp2Session::OnError(SpdyFramerError error, + std::string detailed_error) { + QUICHE_VLOG(1) << "Error: " + << http2::Http2DecoderAdapter::SpdyFramerErrorToString(error) + << " details: " << detailed_error; + // TODO(diannahu): Consider propagating `detailed_error`. + LatchErrorAndNotify(GetHttp2ErrorCode(error), ConnectionError::kParseError); +} + +void OgHttp2Session::OnCommonHeader(spdy::SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + highest_received_stream_id_ = std::max(static_cast(stream_id), + highest_received_stream_id_); + if (streams_reset_.contains(stream_id)) { + return; + } + const bool result = visitor_.OnFrameHeader(stream_id, length, type, flags); + if (!result) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } +} + +void OgHttp2Session::OnDataFrameHeader(spdy::SpdyStreamId stream_id, + size_t length, bool /*fin*/) { + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end() || streams_reset_.contains(stream_id)) { + // The stream does not exist; it could be an error or a benign close, e.g., + // getting data for a stream this connection recently closed. + if (static_cast(stream_id) > highest_processed_stream_id_) { + // Receiving DATA before HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kWrongFrameSequence); + } + return; + } + + if (static_cast(length) > + connection_window_manager_.CurrentWindowSize()) { + // Peer exceeded the connection flow control limit. + LatchErrorAndNotify( + Http2ErrorCode::FLOW_CONTROL_ERROR, + Http2VisitorInterface::ConnectionError::kFlowControlError); + return; + } + + if (static_cast(length) > + iter->second.window_manager.CurrentWindowSize()) { + // Peer exceeded the stream flow control limit. + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_FLOW_CONTROL_ERROR)); + return; + } + + const bool result = visitor_.OnBeginDataForStream(stream_id, length); + if (!result) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } + + if (!iter->second.can_receive_body && length > 0) { + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_PROTOCOL_ERROR)); + return; + } + + // Validate against the content-length if it exists. + if (iter->second.remaining_content_length.has_value()) { + if (length > *iter->second.remaining_content_length) { + HandleContentLengthError(stream_id); + iter->second.remaining_content_length.reset(); + } else { + *iter->second.remaining_content_length -= length; + } + } +} + +void OgHttp2Session::OnStreamFrameData(spdy::SpdyStreamId stream_id, + const char* data, size_t len) { + // Count the data against flow control, even if the stream is unknown. + MarkDataBuffered(stream_id, len); + + if (!stream_map_.contains(stream_id) || streams_reset_.contains(stream_id)) { + // If the stream was unknown due to a protocol error, the visitor was + // informed in OnDataFrameHeader(). + return; + } + + const bool result = + visitor_.OnDataForStream(stream_id, absl::string_view(data, len)); + if (!result) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } +} + +void OgHttp2Session::OnStreamEnd(spdy::SpdyStreamId stream_id) { + auto iter = stream_map_.find(stream_id); + if (iter != stream_map_.end()) { + iter->second.half_closed_remote = true; + if (streams_reset_.contains(stream_id)) { + return; + } + + // Validate against the content-length if it exists. + if (iter->second.remaining_content_length.has_value() && + *iter->second.remaining_content_length != 0) { + HandleContentLengthError(stream_id); + return; + } + + const bool result = visitor_.OnEndStream(stream_id); + if (!result) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } + } + + auto queued_frames_iter = queued_frames_.find(stream_id); + const bool no_queued_frames = queued_frames_iter == queued_frames_.end() || + queued_frames_iter->second == 0; + if (iter != stream_map_.end() && iter->second.half_closed_local && + !IsServerSession() && no_queued_frames) { + // From the client's perspective, the stream can be closed if it's already + // half_closed_local. + CloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR); + } +} + +void OgHttp2Session::OnStreamPadLength(spdy::SpdyStreamId stream_id, + size_t value) { + bool result = visitor_.OnDataPaddingLength(stream_id, 1 + value); + if (!result) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } + MarkDataBuffered(stream_id, 1 + value); +} + +void OgHttp2Session::OnStreamPadding(spdy::SpdyStreamId /*stream_id*/, size_t + /*len*/) { + // Flow control was accounted for in OnStreamPadLength(). + // TODO(181586191): Pass padding to the visitor? +} + +spdy::SpdyHeadersHandlerInterface* OgHttp2Session::OnHeaderFrameStart( + spdy::SpdyStreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end() && !streams_reset_.contains(stream_id)) { + headers_handler_.set_stream_id(stream_id); + headers_handler_.set_header_type( + NextHeaderType(it->second.received_header_type)); + return &headers_handler_; + } else { + return &noop_headers_handler_; + } +} + +void OgHttp2Session::OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + if (headers_handler_.header_type() == HeaderType::RESPONSE && + !headers_handler_.status_header().empty() && + headers_handler_.status_header()[0] == '1') { + // If response headers carried a 1xx response code, final response headers + // should still be forthcoming. + headers_handler_.set_header_type(HeaderType::RESPONSE_100); + } + it->second.received_header_type = headers_handler_.header_type(); + + // Track the content-length if the headers indicate that a body can follow. + it->second.can_receive_body = + headers_handler_.CanReceiveBody() && !it->second.sent_head_method; + if (it->second.can_receive_body) { + it->second.remaining_content_length = headers_handler_.content_length(); + } + + headers_handler_.set_stream_id(0); + } +} + +void OgHttp2Session::OnRstStream(spdy::SpdyStreamId stream_id, + spdy::SpdyErrorCode error_code) { + auto iter = stream_map_.find(stream_id); + if (iter != stream_map_.end()) { + iter->second.half_closed_remote = true; + iter->second.outbound_body = nullptr; + } else if (static_cast(stream_id) > + highest_processed_stream_id_) { + // Receiving RST_STREAM before HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kWrongFrameSequence); + return; + } + if (streams_reset_.contains(stream_id)) { + return; + } + visitor_.OnRstStream(stream_id, TranslateErrorCode(error_code)); + // TODO(birenroy): Consider whether there are outbound frames queued for the + // stream. + CloseStream(stream_id, TranslateErrorCode(error_code)); +} + +void OgHttp2Session::OnSettings() { + visitor_.OnSettingsStart(); + auto settings = std::make_unique(); + settings->set_is_ack(true); + EnqueueFrame(std::move(settings)); +} + +void OgHttp2Session::OnSetting(spdy::SpdySettingsId id, uint32_t value) { + switch (id) { + case HEADER_TABLE_SIZE: + value = std::min(value, HpackCapacityBound(options_)); + if (value < framer_.GetHpackEncoder()->CurrentHeaderTableSizeSetting()) { + // Safe to apply a smaller table capacity immediately. + QUICHE_VLOG(2) << TracePerspectiveAsString(options_.perspective) + << " applying encoder table capacity " << value; + framer_.GetHpackEncoder()->ApplyHeaderTableSizeSetting(value); + } else { + QUICHE_VLOG(2) + << TracePerspectiveAsString(options_.perspective) + << " NOT applying encoder table capacity until writing ack: " + << value; + encoder_header_table_capacity_when_acking_ = value; + } + break; + case ENABLE_PUSH: + if (value > 1u) { + visitor_.OnInvalidFrame( + 0, Http2VisitorInterface::InvalidFrameError::kProtocol); + // The specification says this is a connection-level protocol error. + LatchErrorAndNotify( + Http2ErrorCode::PROTOCOL_ERROR, + Http2VisitorInterface::ConnectionError::kInvalidSetting); + return; + } + // Aside from validation, this setting is ignored. + break; + case MAX_CONCURRENT_STREAMS: + max_outbound_concurrent_streams_ = value; + if (!IsServerSession()) { + // We may now be able to start pending streams. + StartPendingStreams(); + } + break; + case INITIAL_WINDOW_SIZE: + if (value > spdy::kSpdyMaximumWindowSize) { + visitor_.OnInvalidFrame( + 0, Http2VisitorInterface::InvalidFrameError::kFlowControl); + // The specification says this is a connection-level flow control error. + LatchErrorAndNotify( + Http2ErrorCode::FLOW_CONTROL_ERROR, + Http2VisitorInterface::ConnectionError::kFlowControlError); + return; + } else { + UpdateStreamSendWindowSizes(value); + } + break; + case MAX_FRAME_SIZE: + if (value < kDefaultFramePayloadSizeLimit || + value > kMaximumFramePayloadSizeLimit) { + visitor_.OnInvalidFrame( + 0, Http2VisitorInterface::InvalidFrameError::kProtocol); + // The specification says this is a connection-level protocol error. + LatchErrorAndNotify( + Http2ErrorCode::PROTOCOL_ERROR, + Http2VisitorInterface::ConnectionError::kInvalidSetting); + return; + } + max_frame_payload_ = value; + break; + case ENABLE_CONNECT_PROTOCOL: + if (value > 1u || (value == 0 && peer_enables_connect_protocol_)) { + visitor_.OnInvalidFrame( + 0, Http2VisitorInterface::InvalidFrameError::kProtocol); + LatchErrorAndNotify( + Http2ErrorCode::PROTOCOL_ERROR, + Http2VisitorInterface::ConnectionError::kInvalidSetting); + return; + } + peer_enables_connect_protocol_ = (value == 1u); + break; + default: + // TODO(bnc): See if C++17 inline constants are allowed in QUICHE. + if (id == kMetadataExtensionId) { + peer_supports_metadata_ = (value != 0); + } else { + QUICHE_VLOG(1) << "Unimplemented SETTING id: " << id; + } + } + visitor_.OnSetting({id, value}); +} + +void OgHttp2Session::OnSettingsEnd() { visitor_.OnSettingsEnd(); } + +void OgHttp2Session::OnSettingsAck() { + if (!settings_ack_callbacks_.empty()) { + SettingsAckCallback callback = std::move(settings_ack_callbacks_.front()); + settings_ack_callbacks_.pop_front(); + callback(); + } + + visitor_.OnSettingsAck(); +} + +void OgHttp2Session::OnPing(spdy::SpdyPingId unique_id, bool is_ack) { + visitor_.OnPing(unique_id, is_ack); + if (options_.auto_ping_ack && !is_ack) { + auto ping = std::make_unique(unique_id); + ping->set_is_ack(true); + EnqueueFrame(std::move(ping)); + } +} + +void OgHttp2Session::OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, + spdy::SpdyErrorCode error_code) { + if (received_goaway_ && + last_accepted_stream_id > + static_cast(received_goaway_stream_id_)) { + // This GOAWAY has a higher `last_accepted_stream_id` than a previous + // GOAWAY, a connection-level spec violation. + const bool ok = visitor_.OnInvalidFrame( + kConnectionStreamId, + Http2VisitorInterface::InvalidFrameError::kProtocol); + if (!ok) { + fatal_visitor_callback_failure_ = true; + } + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidGoAwayLastStreamId); + return; + } + + received_goaway_ = true; + received_goaway_stream_id_ = last_accepted_stream_id; + const bool result = visitor_.OnGoAway(last_accepted_stream_id, + TranslateErrorCode(error_code), ""); + if (!result) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } + + // Close the streams above `last_accepted_stream_id`. Only applies if the + // session receives a GOAWAY as a client, as we do not support server push. + if (last_accepted_stream_id == spdy::kMaxStreamId || IsServerSession()) { + return; + } + std::vector streams_to_close; + for (const auto& [stream_id, stream_state] : stream_map_) { + if (static_cast(stream_id) > last_accepted_stream_id) { + streams_to_close.push_back(stream_id); + } + } + for (Http2StreamId stream_id : streams_to_close) { + CloseStream(stream_id, Http2ErrorCode::REFUSED_STREAM); + } +} + +bool OgHttp2Session::OnGoAwayFrameData(const char* /*goaway_data*/, size_t + /*len*/) { + // Opaque data is currently ignored. + return true; +} + +void OgHttp2Session::OnHeaders(spdy::SpdyStreamId stream_id, + size_t /*payload_length*/, bool /*has_priority*/, + int /*weight*/, + spdy::SpdyStreamId /*parent_stream_id*/, + bool /*exclusive*/, bool fin, bool /*end*/) { + if (stream_id % 2 == 0) { + // Server push is disabled; receiving push HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidNewStreamId); + return; + } + if (fin) { + headers_handler_.set_frame_contains_fin(); + } + if (IsServerSession()) { + const auto new_stream_id = static_cast(stream_id); + if (stream_map_.find(new_stream_id) != stream_map_.end() && fin) { + // Not a new stream, must be trailers. + return; + } + if (new_stream_id <= highest_processed_stream_id_) { + // A new stream ID lower than the watermark is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidNewStreamId); + return; + } + + if (stream_map_.size() >= max_inbound_concurrent_streams_) { + // The new stream would exceed our advertised and acknowledged + // MAX_CONCURRENT_STREAMS. For parity with nghttp2, treat this error as a + // connection-level PROTOCOL_ERROR. + bool ok = visitor_.OnInvalidFrame( + stream_id, Http2VisitorInterface::InvalidFrameError::kProtocol); + if (!ok) { + fatal_visitor_callback_failure_ = true; + } + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kExceededMaxConcurrentStreams); + return; + } + if (stream_map_.size() >= pending_max_inbound_concurrent_streams_) { + // The new stream would exceed our advertised but unacked + // MAX_CONCURRENT_STREAMS. Refuse the stream for parity with nghttp2. + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_REFUSED_STREAM)); + const bool ok = visitor_.OnInvalidFrame( + stream_id, Http2VisitorInterface::InvalidFrameError::kRefusedStream); + if (!ok) { + fatal_visitor_callback_failure_ = true; + LatchErrorAndNotify(Http2ErrorCode::REFUSED_STREAM, + ConnectionError::kExceededMaxConcurrentStreams); + } + return; + } + + CreateStream(stream_id); + } +} + +void OgHttp2Session::OnWindowUpdate(spdy::SpdyStreamId stream_id, + int delta_window_size) { + constexpr int kMaxWindowValue = 2147483647; // (1 << 31) - 1 + if (stream_id == 0) { + if (delta_window_size == 0) { + // A PROTOCOL_ERROR, according to RFC 9113 Section 6.9. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kFlowControlError); + return; + } + if (connection_send_window_ > 0 && + delta_window_size > (kMaxWindowValue - connection_send_window_)) { + // Window overflow is a FLOW_CONTROL_ERROR. + LatchErrorAndNotify(Http2ErrorCode::FLOW_CONTROL_ERROR, + ConnectionError::kFlowControlError); + return; + } + connection_send_window_ += delta_window_size; + } else { + if (delta_window_size == 0) { + // A PROTOCOL_ERROR, according to RFC 9113 Section 6.9. + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_PROTOCOL_ERROR)); + return; + } + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + QUICHE_VLOG(1) << "Stream " << stream_id << " not found!"; + if (static_cast(stream_id) > + highest_processed_stream_id_) { + // Receiving WINDOW_UPDATE before HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kWrongFrameSequence); + } + // Do not inform the visitor of a WINDOW_UPDATE for a non-existent stream. + return; + } else { + if (streams_reset_.contains(stream_id)) { + return; + } + if (it->second.send_window > 0 && + delta_window_size > (kMaxWindowValue - it->second.send_window)) { + // Window overflow is a FLOW_CONTROL_ERROR. + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_FLOW_CONTROL_ERROR)); + return; + } + const bool was_blocked = (it->second.send_window <= 0); + it->second.send_window += delta_window_size; + if (was_blocked && it->second.send_window > 0) { + // The stream was blocked on flow control. + QUICHE_VLOG(1) << "Marking stream " << stream_id << " ready to write."; + write_scheduler_.MarkStreamReady(stream_id, false); + } + } + } + visitor_.OnWindowUpdate(stream_id, delta_window_size); +} + +void OgHttp2Session::OnPushPromise(spdy::SpdyStreamId /*stream_id*/, + spdy::SpdyStreamId /*promised_stream_id*/, + bool /*end*/) { + // Server push is disabled; PUSH_PROMISE is an invalid frame. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidPushPromise); +} + +void OgHttp2Session::OnContinuation(spdy::SpdyStreamId /*stream_id*/, + size_t /*payload_length*/, bool /*end*/) {} + +void OgHttp2Session::OnAltSvc(spdy::SpdyStreamId /*stream_id*/, + absl::string_view /*origin*/, + const spdy::SpdyAltSvcWireFormat:: + AlternativeServiceVector& /*altsvc_vector*/) { +} + +void OgHttp2Session::OnPriority(spdy::SpdyStreamId /*stream_id*/, + spdy::SpdyStreamId /*parent_stream_id*/, + int /*weight*/, bool /*exclusive*/) {} + +void OgHttp2Session::OnPriorityUpdate( + spdy::SpdyStreamId /*prioritized_stream_id*/, + absl::string_view /*priority_field_value*/) {} + +bool OgHttp2Session::OnUnknownFrame(spdy::SpdyStreamId /*stream_id*/, + uint8_t /*frame_type*/) { + return true; +} + +void OgHttp2Session::OnUnknownFrameStart(spdy::SpdyStreamId stream_id, + size_t length, uint8_t type, + uint8_t flags) { + process_metadata_ = false; + if (streams_reset_.contains(stream_id)) { + return; + } + if (type == kMetadataFrameType) { + QUICHE_DCHECK_EQ(metadata_length_, 0u); + visitor_.OnBeginMetadataForStream(stream_id, length); + metadata_length_ = length; + process_metadata_ = true; + end_metadata_ = flags & kMetadataEndFlag; + + // Empty metadata payloads will not trigger OnUnknownFramePayload(), so + // handle that possibility here. + MaybeHandleMetadataEndForStream(stream_id); + } else { + QUICHE_DLOG(INFO) << "Received unexpected frame type " + << static_cast(type); + } +} + +void OgHttp2Session::OnUnknownFramePayload(spdy::SpdyStreamId stream_id, + absl::string_view payload) { + if (!process_metadata_) { + return; + } + if (streams_reset_.contains(stream_id)) { + return; + } + if (metadata_length_ > 0) { + QUICHE_DCHECK_LE(payload.size(), metadata_length_); + const bool payload_success = + visitor_.OnMetadataForStream(stream_id, payload); + if (payload_success) { + metadata_length_ -= payload.size(); + MaybeHandleMetadataEndForStream(stream_id); + } else { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } + } else { + QUICHE_DLOG(INFO) << "Unexpected metadata payload for stream " << stream_id; + } +} + +void OgHttp2Session::OnHeaderStatus( + Http2StreamId stream_id, Http2VisitorInterface::OnHeaderResult result) { + QUICHE_DCHECK_NE(result, Http2VisitorInterface::HEADER_OK); + QUICHE_VLOG(1) << "OnHeaderStatus(stream_id=" << stream_id + << ", result=" << result << ")"; + const bool should_reset_stream = + result == Http2VisitorInterface::HEADER_RST_STREAM || + result == Http2VisitorInterface::HEADER_FIELD_INVALID || + result == Http2VisitorInterface::HEADER_HTTP_MESSAGING; + if (should_reset_stream) { + const Http2ErrorCode error_code = + (result == Http2VisitorInterface::HEADER_RST_STREAM) + ? Http2ErrorCode::INTERNAL_ERROR + : Http2ErrorCode::PROTOCOL_ERROR; + const spdy::SpdyErrorCode spdy_error_code = TranslateErrorCode(error_code); + const Http2VisitorInterface::InvalidFrameError frame_error = + (result == Http2VisitorInterface::HEADER_RST_STREAM || + result == Http2VisitorInterface::HEADER_FIELD_INVALID) + ? Http2VisitorInterface::InvalidFrameError::kHttpHeader + : Http2VisitorInterface::InvalidFrameError::kHttpMessaging; + auto it = streams_reset_.find(stream_id); + if (it == streams_reset_.end()) { + EnqueueFrame( + std::make_unique(stream_id, spdy_error_code)); + + if (result == Http2VisitorInterface::HEADER_FIELD_INVALID || + result == Http2VisitorInterface::HEADER_HTTP_MESSAGING) { + const bool ok = visitor_.OnInvalidFrame(stream_id, frame_error); + if (!ok) { + fatal_visitor_callback_failure_ = true; + LatchErrorAndNotify(error_code, ConnectionError::kHeaderError); + } + } + } + } else if (result == Http2VisitorInterface::HEADER_CONNECTION_ERROR) { + fatal_visitor_callback_failure_ = true; + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kHeaderError); + } else if (result == Http2VisitorInterface::HEADER_COMPRESSION_ERROR) { + LatchErrorAndNotify(Http2ErrorCode::COMPRESSION_ERROR, + ConnectionError::kHeaderError); + } +} + +void OgHttp2Session::MaybeSetupPreface(bool sending_outbound_settings) { + if (!queued_preface_) { + queued_preface_ = true; + if (!IsServerSession()) { + buffered_data_.assign(spdy::kHttp2ConnectionHeaderPrefix, + spdy::kHttp2ConnectionHeaderPrefixSize); + } + if (!sending_outbound_settings) { + QUICHE_DCHECK(frames_.empty()); + // First frame must be a non-ack SETTINGS. + EnqueueFrame(PrepareSettingsFrame(GetInitialSettings())); + } + } +} + +std::vector OgHttp2Session::GetInitialSettings() const { + std::vector settings; + if (!IsServerSession()) { + // Disable server push. Note that server push from clients is already + // disabled, so the server does not need to send this disabling setting. + // TODO(diannahu): Consider applying server push disabling on SETTINGS ack. + settings.push_back({Http2KnownSettingsId::ENABLE_PUSH, 0}); + } + if (options_.max_header_list_bytes) { + settings.push_back({Http2KnownSettingsId::MAX_HEADER_LIST_SIZE, + *options_.max_header_list_bytes}); + } + if (options_.allow_extended_connect && IsServerSession()) { + settings.push_back({Http2KnownSettingsId::ENABLE_CONNECT_PROTOCOL, 1u}); + } + return settings; +} + +std::unique_ptr OgHttp2Session::PrepareSettingsFrame( + absl::Span settings) { + auto settings_ir = std::make_unique(); + for (const Http2Setting& setting : settings) { + settings_ir->AddSetting(setting.id, setting.value); + } + return settings_ir; +} + +void OgHttp2Session::HandleOutboundSettings( + const spdy::SpdySettingsIR& settings_frame) { + for (const auto& [id, value] : settings_frame.values()) { + switch (static_cast(id)) { + case MAX_CONCURRENT_STREAMS: + pending_max_inbound_concurrent_streams_ = value; + break; + case ENABLE_CONNECT_PROTOCOL: + if (value == 1u && IsServerSession()) { + // Allow extended CONNECT semantics even before SETTINGS are acked, to + // make things easier for clients. + headers_handler_.SetAllowExtendedConnect(); + } + break; + case HEADER_TABLE_SIZE: + case ENABLE_PUSH: + case INITIAL_WINDOW_SIZE: + case MAX_FRAME_SIZE: + case MAX_HEADER_LIST_SIZE: + QUICHE_VLOG(2) + << "Not adjusting internal state for outbound setting with id " + << id; + break; + } + } + + // Copy the (small) map of settings we are about to send so that we can set + // values in the SETTINGS ack callback. + settings_ack_callbacks_.push_back( + [this, settings_map = settings_frame.values()]() { + for (const auto& [id, value] : settings_map) { + switch (static_cast(id)) { + case MAX_CONCURRENT_STREAMS: + max_inbound_concurrent_streams_ = value; + break; + case HEADER_TABLE_SIZE: + decoder_.GetHpackDecoder()->ApplyHeaderTableSizeSetting(value); + break; + case INITIAL_WINDOW_SIZE: + UpdateStreamReceiveWindowSizes(value); + initial_stream_receive_window_ = value; + break; + case MAX_FRAME_SIZE: + decoder_.SetMaxFrameSize(value); + break; + case ENABLE_PUSH: + case MAX_HEADER_LIST_SIZE: + case ENABLE_CONNECT_PROTOCOL: + QUICHE_VLOG(2) + << "No action required in ack for outbound setting with id " + << id; + break; + } + } + }); +} + +void OgHttp2Session::SendWindowUpdate(Http2StreamId stream_id, + size_t update_delta) { + EnqueueFrame( + std::make_unique(stream_id, update_delta)); +} + +void OgHttp2Session::SendHeaders(Http2StreamId stream_id, + spdy::Http2HeaderBlock headers, + bool end_stream) { + auto frame = + std::make_unique(stream_id, std::move(headers)); + frame->set_fin(end_stream); + EnqueueFrame(std::move(frame)); +} + +void OgHttp2Session::SendTrailers(Http2StreamId stream_id, + spdy::Http2HeaderBlock trailers) { + auto frame = + std::make_unique(stream_id, std::move(trailers)); + frame->set_fin(true); + EnqueueFrame(std::move(frame)); + trailers_ready_.erase(stream_id); +} + +void OgHttp2Session::MaybeFinWithRstStream(StreamStateMap::iterator iter) { + QUICHE_DCHECK(iter != stream_map_.end() && iter->second.half_closed_local); + + if (options_.rst_stream_no_error_when_incomplete && IsServerSession() && + !iter->second.half_closed_remote) { + // Since the peer has not yet ended the stream, this endpoint should + // send a RST_STREAM NO_ERROR. See RFC 7540 Section 8.1. + EnqueueFrame(std::make_unique( + iter->first, spdy::SpdyErrorCode::ERROR_CODE_NO_ERROR)); + iter->second.half_closed_remote = true; + } +} + +void OgHttp2Session::MarkDataBuffered(Http2StreamId stream_id, size_t bytes) { + connection_window_manager_.MarkDataBuffered(bytes); + if (auto it = stream_map_.find(stream_id); it != stream_map_.end()) { + it->second.window_manager.MarkDataBuffered(bytes); + } +} + +OgHttp2Session::StreamStateMap::iterator OgHttp2Session::CreateStream( + Http2StreamId stream_id) { + WindowManager::WindowUpdateListener listener = + [this, stream_id](size_t window_update_delta) { + SendWindowUpdate(stream_id, window_update_delta); + }; + auto [iter, inserted] = stream_map_.try_emplace( + stream_id, + StreamState(initial_stream_receive_window_, initial_stream_send_window_, + std::move(listener), options_.should_window_update_fn)); + if (inserted) { + // Add the stream to the write scheduler. + const spdy::SpdyPriority priority = 3; + write_scheduler_.RegisterStream(stream_id, priority); + + highest_processed_stream_id_ = + std::max(highest_processed_stream_id_, stream_id); + } + return iter; +} + +void OgHttp2Session::StartRequest(Http2StreamId stream_id, + spdy::Http2HeaderBlock headers, + std::unique_ptr data_source, + void* user_data) { + if (received_goaway_) { + // Do not start new streams after receiving a GOAWAY. + goaway_rejected_streams_.insert(stream_id); + return; + } + + auto iter = CreateStream(stream_id); + const bool end_stream = data_source == nullptr; + if (!end_stream) { + iter->second.outbound_body = std::move(data_source); + write_scheduler_.MarkStreamReady(stream_id, false); + } + iter->second.user_data = user_data; + for (const auto& [name, value] : headers) { + if (name == kHttp2MethodPseudoHeader && value == kHeadValue) { + iter->second.sent_head_method = true; + } + } + SendHeaders(stream_id, std::move(headers), end_stream); +} + +void OgHttp2Session::StartPendingStreams() { + while (!pending_streams_.empty() && CanCreateStream()) { + auto& [stream_id, pending_stream] = pending_streams_.front(); + StartRequest(stream_id, std::move(pending_stream.headers), + std::move(pending_stream.data_source), + pending_stream.user_data); + pending_streams_.pop_front(); + } +} + +void OgHttp2Session::CloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + const bool result = visitor_.OnCloseStream(stream_id, error_code); + if (!result) { + latched_error_ = true; + decoder_.StopProcessing(); + } + stream_map_.erase(stream_id); + trailers_ready_.erase(stream_id); + streams_reset_.erase(stream_id); + auto queued_it = queued_frames_.find(stream_id); + if (queued_it != queued_frames_.end()) { + // Remove any queued frames for this stream. + int frames_remaining = queued_it->second; + queued_frames_.erase(queued_it); + for (auto it = frames_.begin(); + frames_remaining > 0 && it != frames_.end();) { + if (static_cast((*it)->stream_id()) == stream_id) { + it = frames_.erase(it); + --frames_remaining; + } else { + ++it; + } + } + } + if (write_scheduler_.StreamRegistered(stream_id)) { + write_scheduler_.UnregisterStream(stream_id); + } + + StartPendingStreams(); +} + +bool OgHttp2Session::CanCreateStream() const { + return stream_map_.size() < max_outbound_concurrent_streams_; +} + +HeaderType OgHttp2Session::NextHeaderType( + absl::optional current_type) { + if (IsServerSession()) { + if (!current_type) { + return HeaderType::REQUEST; + } else { + QUICHE_DCHECK(current_type == HeaderType::REQUEST); + return HeaderType::REQUEST_TRAILER; + } + } else if (!current_type || + current_type.value() == HeaderType::RESPONSE_100) { + return HeaderType::RESPONSE; + } else { + return HeaderType::RESPONSE_TRAILER; + } +} + +void OgHttp2Session::LatchErrorAndNotify(Http2ErrorCode error_code, + ConnectionError error) { + if (latched_error_) { + // Do not kick a connection when it is down. + return; + } + + latched_error_ = true; + visitor_.OnConnectionError(error); + decoder_.StopProcessing(); + EnqueueFrame(std::make_unique( + highest_processed_stream_id_, TranslateErrorCode(error_code), + ConnectionErrorToString(error))); +} + +void OgHttp2Session::CloseStreamIfReady(uint8_t frame_type, + uint32_t stream_id) { + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end()) { + return; + } + const StreamState& state = iter->second; + if (static_cast(frame_type) == FrameType::RST_STREAM || + (state.half_closed_local && state.half_closed_remote)) { + CloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR); + } +} + +void OgHttp2Session::CloseGoAwayRejectedStreams() { + for (Http2StreamId stream_id : goaway_rejected_streams_) { + const bool result = + visitor_.OnCloseStream(stream_id, Http2ErrorCode::REFUSED_STREAM); + if (!result) { + latched_error_ = true; + decoder_.StopProcessing(); + } + } + goaway_rejected_streams_.clear(); +} + +void OgHttp2Session::PrepareForImmediateGoAway() { + queued_immediate_goaway_ = true; + + // Keep the initial SETTINGS frame if the session has SETTINGS at the front of + // the queue but has not sent SETTINGS yet. The session should send initial + // SETTINGS before GOAWAY. + std::unique_ptr initial_settings; + if (!sent_non_ack_settings_ && !frames_.empty() && + IsNonAckSettings(*frames_.front())) { + initial_settings = std::move(frames_.front()); + frames_.pop_front(); + } + + // Remove all pending frames except for RST_STREAMs. It is important to send + // RST_STREAMs so the peer knows of errors below the GOAWAY last stream ID. + // TODO(diannahu): Consider informing the visitor of dropped frames. This may + // mean keeping the frames and invoking a frame-not-sent callback, similar to + // nghttp2. Could add a closure to each frame in the frames queue. + frames_.remove_if([](const auto& frame) { + return frame->frame_type() != spdy::SpdyFrameType::RST_STREAM; + }); + + if (initial_settings != nullptr) { + frames_.push_front(std::move(initial_settings)); + } +} + +void OgHttp2Session::MaybeHandleMetadataEndForStream(Http2StreamId stream_id) { + if (metadata_length_ == 0 && end_metadata_) { + const bool completion_success = visitor_.OnMetadataEndForStream(stream_id); + if (!completion_success) { + fatal_visitor_callback_failure_ = true; + decoder_.StopProcessing(); + } + process_metadata_ = false; + end_metadata_ = false; + } +} + +void OgHttp2Session::DecrementQueuedFrameCount(uint32_t stream_id, + uint8_t frame_type) { + auto iter = queued_frames_.find(stream_id); + if (iter == queued_frames_.end()) { + QUICHE_LOG(ERROR) << "Unable to find a queued frame count for stream " + << stream_id; + return; + } + if (static_cast(frame_type) != FrameType::DATA) { + --iter->second; + } + if (iter->second == 0) { + // TODO(birenroy): Consider passing through `error_code` here. + CloseStreamIfReady(frame_type, stream_id); + } +} + +void OgHttp2Session::HandleContentLengthError(Http2StreamId stream_id) { + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_PROTOCOL_ERROR)); +} + +void OgHttp2Session::UpdateReceiveWindow(Http2StreamId stream_id, + int32_t delta) { + if (stream_id == 0) { + connection_window_manager_.IncreaseWindow(delta); + // TODO(b/181586191): Provide an explicit way to set the desired window + // limit, remove the upsize-on-window-update behavior. + const int64_t current_window = + connection_window_manager_.CurrentWindowSize(); + if (current_window > connection_window_manager_.WindowSizeLimit()) { + connection_window_manager_.SetWindowSizeLimit(current_window); + } + } else { + auto iter = stream_map_.find(stream_id); + if (iter != stream_map_.end()) { + WindowManager& manager = iter->second.window_manager; + manager.IncreaseWindow(delta); + // TODO(b/181586191): Provide an explicit way to set the desired window + // limit, remove the upsize-on-window-update behavior. + const int64_t current_window = manager.CurrentWindowSize(); + if (current_window > manager.WindowSizeLimit()) { + manager.SetWindowSizeLimit(current_window); + } + } + } +} + +void OgHttp2Session::UpdateStreamSendWindowSizes(uint32_t new_value) { + const int32_t delta = + static_cast(new_value) - initial_stream_send_window_; + initial_stream_send_window_ = new_value; + for (auto& [stream_id, stream_state] : stream_map_) { + const int64_t current_window_size = stream_state.send_window; + const int64_t new_window_size = current_window_size + delta; + if (new_window_size > spdy::kSpdyMaximumWindowSize) { + EnqueueFrame(std::make_unique( + stream_id, spdy::ERROR_CODE_FLOW_CONTROL_ERROR)); + } else { + stream_state.send_window += delta; + } + if (current_window_size <= 0 && new_window_size > 0) { + write_scheduler_.MarkStreamReady(stream_id, false); + } + } +} + +void OgHttp2Session::UpdateStreamReceiveWindowSizes(uint32_t new_value) { + for (auto& [stream_id, stream_state] : stream_map_) { + stream_state.window_manager.OnWindowSizeLimitChange(new_value); + } +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/oghttp2_session.h b/quiche/http2/adapter/oghttp2_session.h new file mode 100644 index 000000000000..9b31fa4b0909 --- /dev/null +++ b/quiche/http2/adapter/oghttp2_session.h @@ -0,0 +1,557 @@ +#ifndef QUICHE_HTTP2_ADAPTER_OGHTTP2_SESSION_H_ +#define QUICHE_HTTP2_ADAPTER_OGHTTP2_SESSION_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "quiche/http2/adapter/data_source.h" +#include "quiche/http2/adapter/event_forwarder.h" +#include "quiche/http2/adapter/header_validator.h" +#include "quiche/http2/adapter/header_validator_base.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_session.h" +#include "quiche/http2/adapter/http2_util.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/adapter/window_manager.h" +#include "quiche/http2/core/http2_trace_logging.h" +#include "quiche/http2/core/priority_write_scheduler.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_flags.h" +#include "quiche/common/quiche_linked_hash_map.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/no_op_headers_handler.h" +#include "quiche/spdy/core/spdy_framer.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { + +// This class manages state associated with a single multiplexed HTTP/2 session. +class QUICHE_EXPORT OgHttp2Session : public Http2Session, + public spdy::SpdyFramerVisitorInterface { + public: + struct QUICHE_EXPORT Options { + // Returns whether to send a WINDOW_UPDATE based on the window limit, window + // size, and delta that would be sent in the WINDOW_UPDATE. + WindowManager::ShouldWindowUpdateFn should_window_update_fn = + DeltaAtLeastHalfLimit; + // The perspective of this session. + Perspective perspective = Perspective::kClient; + // The maximum HPACK table size to use. + absl::optional max_hpack_encoding_table_capacity = absl::nullopt; + // The maximum number of decoded header bytes that a stream can receive. + absl::optional max_header_list_bytes = absl::nullopt; + // The maximum size of an individual header field, including name and value. + absl::optional max_header_field_size = absl::nullopt; + // Whether to automatically send PING acks when receiving a PING. + bool auto_ping_ack = true; + // Whether (as server) to send a RST_STREAM NO_ERROR when sending a fin on + // an incomplete stream. + bool rst_stream_no_error_when_incomplete = false; + // Whether (as server) to queue trailers until after a stream's data source + // has indicated the end of data. If false, the server will assume that + // submitting trailers indicates the end of data. + bool trailers_require_end_data = false; + // Whether to mark all input data as consumed upon encountering a connection + // error while processing bytes. If true, subsequent processing will also + // mark all input data as consumed. + bool blackhole_data_on_connection_error = true; + // Whether to advertise support for the extended CONNECT semantics described + // in RFC 8441. If true, this endpoint will send the appropriate setting in + // initial SETTINGS. + bool allow_extended_connect = true; + // Whether to allow `obs-text` (characters from hexadecimal 0x80 to 0xff) in + // header field values. + bool allow_obs_text = true; + // If true, validates header field names and values according to RFC 7230 + // and RFC 7540. + bool validate_http_headers = true; + }; + + OgHttp2Session(Http2VisitorInterface& visitor, Options options); + ~OgHttp2Session() override; + + // Enqueues a frame for transmission to the peer. + void EnqueueFrame(std::unique_ptr frame); + + // Starts a graceful shutdown sequence. No-op if a GOAWAY has already been + // sent. + void StartGracefulShutdown(); + + // Invokes the visitor's OnReadyToSend() method for serialized frames and + // DataFrameSource::Send() for data frames. + int Send(); + + int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data); + int SubmitResponse(Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source); + int SubmitTrailer(Http2StreamId stream_id, absl::Span trailers); + void SubmitMetadata(Http2StreamId stream_id, + std::unique_ptr source); + void SubmitSettings(absl::Span settings); + + bool IsServerSession() const { + return options_.perspective == Perspective::kServer; + } + Http2StreamId GetHighestReceivedStreamId() const { + return highest_received_stream_id_; + } + void SetStreamUserData(Http2StreamId stream_id, void* user_data); + void* GetStreamUserData(Http2StreamId stream_id); + + // Resumes a stream that was previously blocked. Returns true on success. + bool ResumeStream(Http2StreamId stream_id); + + // Returns the peer's outstanding stream receive window for the given stream. + int GetStreamSendWindowSize(Http2StreamId stream_id) const; + + // Returns the current upper bound on the flow control receive window for this + // stream. + int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const; + + // Returns the outstanding stream receive window, or -1 if the stream does not + // exist. + int GetStreamReceiveWindowSize(Http2StreamId stream_id) const; + + // Returns the outstanding connection receive window. + int GetReceiveWindowSize() const; + + // Returns the size of the HPACK encoder's dynamic table, including the + // per-entry overhead from the specification. + int GetHpackEncoderDynamicTableSize() const; + + // Returns the maximum capacity of the HPACK encoder's dynamic table. + int GetHpackEncoderDynamicTableCapacity() const; + + // Returns the size of the HPACK decoder's dynamic table, including the + // per-entry overhead from the specification. + int GetHpackDecoderDynamicTableSize() const; + + // Returns the size of the HPACK decoder's most recently applied size limit. + int GetHpackDecoderSizeLimit() const; + + // From Http2Session. + int64_t ProcessBytes(absl::string_view bytes) override; + int Consume(Http2StreamId stream_id, size_t num_bytes) override; + bool want_read() const override { + return !received_goaway_ && !decoder_.HasError(); + } + bool want_write() const override { + return !fatal_send_error_ && + (!frames_.empty() || !buffered_data_.empty() || HasReadyStream() || + !goaway_rejected_streams_.empty()); + } + int GetRemoteWindowSize() const override { return connection_send_window_; } + bool peer_enables_connect_protocol() { + return peer_enables_connect_protocol_; + } + + // From SpdyFramerVisitorInterface + void OnError(http2::Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) override; + void OnCommonHeader(spdy::SpdyStreamId /*stream_id*/, size_t /*length*/, + uint8_t /*type*/, uint8_t /*flags*/) override; + void OnDataFrameHeader(spdy::SpdyStreamId stream_id, size_t length, + bool fin) override; + void OnStreamFrameData(spdy::SpdyStreamId stream_id, const char* data, + size_t len) override; + void OnStreamEnd(spdy::SpdyStreamId stream_id) override; + void OnStreamPadLength(spdy::SpdyStreamId /*stream_id*/, + size_t /*value*/) override; + void OnStreamPadding(spdy::SpdyStreamId stream_id, size_t len) override; + spdy::SpdyHeadersHandlerInterface* OnHeaderFrameStart( + spdy::SpdyStreamId stream_id) override; + void OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) override; + void OnRstStream(spdy::SpdyStreamId stream_id, + spdy::SpdyErrorCode error_code) override; + void OnSettings() override; + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + void OnPing(spdy::SpdyPingId unique_id, bool is_ack) override; + void OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, + spdy::SpdyErrorCode error_code) override; + bool OnGoAwayFrameData(const char* goaway_data, size_t len) override; + void OnHeaders(spdy::SpdyStreamId stream_id, size_t payload_length, + bool has_priority, int weight, + spdy::SpdyStreamId parent_stream_id, bool exclusive, bool fin, + bool end) override; + void OnWindowUpdate(spdy::SpdyStreamId stream_id, + int delta_window_size) override; + void OnPushPromise(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId promised_stream_id, bool end) override; + void OnContinuation(spdy::SpdyStreamId stream_id, size_t payload_length, + bool end) override; + void OnAltSvc(spdy::SpdyStreamId /*stream_id*/, absl::string_view /*origin*/, + const spdy::SpdyAltSvcWireFormat:: + AlternativeServiceVector& /*altsvc_vector*/) override; + void OnPriority(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId parent_stream_id, int weight, + bool exclusive) override; + void OnPriorityUpdate(spdy::SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) override; + bool OnUnknownFrame(spdy::SpdyStreamId stream_id, + uint8_t frame_type) override; + void OnUnknownFrameStart(spdy::SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) override; + void OnUnknownFramePayload(spdy::SpdyStreamId stream_id, + absl::string_view payload) override; + + // Invoked when header processing encounters an invalid or otherwise + // problematic header. + void OnHeaderStatus(Http2StreamId stream_id, + Http2VisitorInterface::OnHeaderResult result); + + private: + struct QUICHE_EXPORT StreamState { + StreamState(int32_t stream_receive_window, int32_t stream_send_window, + WindowManager::WindowUpdateListener listener, + WindowManager::ShouldWindowUpdateFn should_window_update_fn) + : window_manager(stream_receive_window, std::move(listener), + std::move(should_window_update_fn), + /*update_window_on_notify=*/false), + send_window(stream_send_window) {} + + WindowManager window_manager; + std::unique_ptr outbound_body; + std::unique_ptr trailers; + void* user_data = nullptr; + int32_t send_window; + absl::optional received_header_type; + absl::optional remaining_content_length; + bool half_closed_local = false; + bool half_closed_remote = false; + // Indicates that `outbound_body` temporarily cannot produce data. + bool data_deferred = false; + bool sent_head_method = false; + bool can_receive_body = true; + }; + using StreamStateMap = absl::flat_hash_map; + + struct QUICHE_EXPORT PendingStreamState { + spdy::Http2HeaderBlock headers; + std::unique_ptr data_source; + void* user_data = nullptr; + }; + + class QUICHE_EXPORT PassthroughHeadersHandler + : public spdy::SpdyHeadersHandlerInterface { + public: + PassthroughHeadersHandler(OgHttp2Session& session, + Http2VisitorInterface& visitor); + + void set_stream_id(Http2StreamId stream_id) { + stream_id_ = stream_id; + result_ = Http2VisitorInterface::HEADER_OK; + } + + void set_frame_contains_fin() { frame_contains_fin_ = true; } + void set_header_type(HeaderType type) { type_ = type; } + HeaderType header_type() const { return type_; } + + void OnHeaderBlockStart() override; + void OnHeader(absl::string_view key, absl::string_view value) override; + void OnHeaderBlockEnd(size_t /* uncompressed_header_bytes */, + size_t /* compressed_header_bytes */) override; + absl::string_view status_header() const { + QUICHE_DCHECK(type_ == HeaderType::RESPONSE || + type_ == HeaderType::RESPONSE_100); + return validator_->status_header(); + } + absl::optional content_length() const { + return validator_->content_length(); + } + void SetAllowExtendedConnect() { validator_->SetAllowExtendedConnect(); } + void SetMaxFieldSize(uint32_t field_size) { + validator_->SetMaxFieldSize(field_size); + } + void SetAllowObsText(bool allow) { + validator_->SetObsTextOption(allow ? ObsTextOption::kAllow + : ObsTextOption::kDisallow); + } + bool CanReceiveBody() const; + + private: + OgHttp2Session& session_; + Http2VisitorInterface& visitor_; + Http2StreamId stream_id_ = 0; + Http2VisitorInterface::OnHeaderResult result_ = + Http2VisitorInterface::HEADER_OK; + // Validates header blocks according to the HTTP/2 specification. + std::unique_ptr validator_; + HeaderType type_ = HeaderType::RESPONSE; + bool frame_contains_fin_ = false; + }; + + struct QUICHE_EXPORT ProcessBytesResultVisitor; + + // Queues the connection preface, if not already done. If not + // `sending_outbound_settings` and the preface has not yet been queued, this + // method will generate and enqueue initial SETTINGS. + void MaybeSetupPreface(bool sending_outbound_settings); + + // Gets the settings to be sent in the initial SETTINGS frame sent as part of + // the connection preface. + std::vector GetInitialSettings() const; + + // Prepares and returns a SETTINGS frame with the given `settings`. + std::unique_ptr PrepareSettingsFrame( + absl::Span settings); + + // Updates internal state to match the SETTINGS advertised to the peer. + void HandleOutboundSettings(const spdy::SpdySettingsIR& settings_frame); + + void SendWindowUpdate(Http2StreamId stream_id, size_t update_delta); + + enum class SendResult { + // All data was flushed. + SEND_OK, + // Not all data was flushed (due to flow control or TCP back pressure). + SEND_BLOCKED, + // An error occurred while sending data. + SEND_ERROR, + }; + + // Returns the int corresponding to the `result`, updating state as needed. + int InterpretSendResult(SendResult result); + + enum class ProcessBytesError { + // A general, unspecified error. + kUnspecified, + // The (server-side) session received an invalid client connection preface. + kInvalidConnectionPreface, + // A user/visitor callback failed with a fatal error. + kVisitorCallbackFailed, + }; + using ProcessBytesResult = absl::variant; + + // Attempts to process `bytes` and returns the number of bytes proccessed on + // success or the processing error on failure. + ProcessBytesResult ProcessBytesImpl(absl::string_view bytes); + + // Returns true if at least one stream has data or control frames to write. + bool HasReadyStream() const; + + // Returns the next stream that has something to write. If there are no such + // streams, returns zero. + Http2StreamId GetNextReadyStream(); + + // Sends the buffered connection preface or serialized frame data, if any. + SendResult MaybeSendBufferedData(); + + // Serializes and sends queued frames. + SendResult SendQueuedFrames(); + + // Returns false if a fatal connection error occurred. + bool AfterFrameSent(uint8_t frame_type_int, uint32_t stream_id, + size_t payload_length, uint8_t flags, + uint32_t error_code); + + // Writes DATA frames for stream `stream_id`. + SendResult WriteForStream(Http2StreamId stream_id); + + void SerializeMetadata(Http2StreamId stream_id, + std::unique_ptr source); + + void SendHeaders(Http2StreamId stream_id, spdy::Http2HeaderBlock headers, + bool end_stream); + + void SendTrailers(Http2StreamId stream_id, spdy::Http2HeaderBlock trailers); + + // Encapsulates the RST_STREAM NO_ERROR behavior described in RFC 7540 + // Section 8.1. + void MaybeFinWithRstStream(StreamStateMap::iterator iter); + + // Performs flow control accounting for data sent by the peer. + void MarkDataBuffered(Http2StreamId stream_id, size_t bytes); + + // Creates a stream for `stream_id` if not already present and returns an + // iterator pointing to it. + StreamStateMap::iterator CreateStream(Http2StreamId stream_id); + + // Creates a stream for `stream_id`, stores the `data_source` and `user_data` + // in the stream state, and sends the `headers`. + void StartRequest(Http2StreamId stream_id, spdy::Http2HeaderBlock headers, + std::unique_ptr data_source, + void* user_data); + + // Sends headers for pending streams as long as the stream limit allows. + void StartPendingStreams(); + + // Closes the given `stream_id` with the given `error_code`. + void CloseStream(Http2StreamId stream_id, Http2ErrorCode error_code); + + // Calculates the next expected header type for a stream in a given state. + HeaderType NextHeaderType(absl::optional current_type); + + // Returns true if the session can create a new stream. + bool CanCreateStream() const; + + // Informs the visitor of the connection `error` and stops processing on the + // connection. If server-side, also sends a GOAWAY with `error_code`. + void LatchErrorAndNotify(Http2ErrorCode error_code, + Http2VisitorInterface::ConnectionError error); + + void CloseStreamIfReady(uint8_t frame_type, uint32_t stream_id); + + // Informs the visitor of rejected, non-active streams due to GOAWAY receipt. + void CloseGoAwayRejectedStreams(); + + // Updates internal state to prepare for sending an immediate GOAWAY. + void PrepareForImmediateGoAway(); + + // Handles the potential end of received metadata for the given `stream_id`. + void MaybeHandleMetadataEndForStream(Http2StreamId stream_id); + + void DecrementQueuedFrameCount(uint32_t stream_id, uint8_t frame_type); + + void HandleContentLengthError(Http2StreamId stream_id); + + // Invoked when sending a flow control window update to the peer. + void UpdateReceiveWindow(Http2StreamId stream_id, int32_t delta); + + // Updates stream send window accounting to respect the peer's advertised + // initial window setting. + void UpdateStreamSendWindowSizes(uint32_t new_value); + + // Updates stream receive window managers to use the newly advertised stream + // initial window. + void UpdateStreamReceiveWindowSizes(uint32_t new_value); + + // Receives events when inbound frames are parsed. + Http2VisitorInterface& visitor_; + + const Options options_; + + // Forwards received events to the session if it can accept them. + EventForwarder event_forwarder_; + + // Logs received frames when enabled. + Http2TraceLogger receive_logger_; + // Logs sent frames when enabled. + Http2FrameLogger send_logger_; + + // Encodes outbound frames. + spdy::SpdyFramer framer_{spdy::SpdyFramer::ENABLE_COMPRESSION}; + + // Decodes inbound frames. + http2::Http2DecoderAdapter decoder_; + + // Maintains the state of active streams known to this session. + StreamStateMap stream_map_; + + // Maintains the state of pending streams known to this session. A pending + // stream is kept in this list until it can be created while complying with + // `max_outbound_concurrent_streams_`. + quiche::QuicheLinkedHashMap + pending_streams_; + + // The queue of outbound frames. + std::list> frames_; + // Buffered data (connection preface, serialized frames) that has not yet been + // sent. + std::string buffered_data_; + + // Maintains the set of streams ready to write data to the peer. + using WriteScheduler = PriorityWriteScheduler; + WriteScheduler write_scheduler_; + + // Stores the queue of callbacks to invoke upon receiving SETTINGS acks. At + // most one callback is invoked for each SETTINGS ack. + using SettingsAckCallback = std::function; + std::list settings_ack_callbacks_; + + // Delivers header name-value pairs to the visitor. + PassthroughHeadersHandler headers_handler_; + + // Ignores header data, e.g., for an unknown or rejected stream. + spdy::NoOpHeadersHandler noop_headers_handler_; + + // Tracks the remaining client connection preface, in the case of a server + // session. + absl::string_view remaining_preface_; + + WindowManager connection_window_manager_; + + // Tracks the streams that have been marked for reset. A stream is removed + // from this set once it is closed. + absl::flat_hash_set streams_reset_; + + // The number of frames currently queued per stream. + absl::flat_hash_map queued_frames_; + // Includes streams that are currently ready to write trailers. + absl::flat_hash_set trailers_ready_; + // Includes streams that will not be written due to receipt of GOAWAY. + absl::flat_hash_set goaway_rejected_streams_; + + Http2StreamId next_stream_id_ = 1; + // The highest received stream ID is the highest stream ID in any frame read + // from the peer. The highest processed stream ID is the highest stream ID for + // which this endpoint created a stream in the stream map. + Http2StreamId highest_received_stream_id_ = 0; + Http2StreamId highest_processed_stream_id_ = 0; + Http2StreamId received_goaway_stream_id_ = 0; + size_t metadata_length_ = 0; + int32_t connection_send_window_ = kInitialFlowControlWindowSize; + // The initial flow control receive window size for any newly created streams. + int32_t initial_stream_receive_window_ = kInitialFlowControlWindowSize; + // The initial flow control send window size for any newly created streams. + int32_t initial_stream_send_window_ = kInitialFlowControlWindowSize; + uint32_t max_frame_payload_ = kDefaultFramePayloadSizeLimit; + // The maximum number of concurrent streams that this connection can open to + // its peer and allow from its peer, respectively. Although the initial value + // is unlimited, the spec encourages a value of at least 100. We limit + // ourselves to opening 100 until told otherwise by the peer and allow an + // unlimited number from the peer until updated from SETTINGS we send. + uint32_t max_outbound_concurrent_streams_ = 100u; + uint32_t pending_max_inbound_concurrent_streams_ = + std::numeric_limits::max(); + uint32_t max_inbound_concurrent_streams_ = + std::numeric_limits::max(); + + // The HPACK encoder header table capacity that will be applied when + // acking SETTINGS from the peer. Only contains a value if the peer advertises + // a larger table capacity than currently used; a smaller value can safely be + // applied immediately upon receipt. + absl::optional encoder_header_table_capacity_when_acking_; + + bool received_goaway_ = false; + bool queued_preface_ = false; + bool peer_supports_metadata_ = false; + bool end_metadata_ = false; + bool process_metadata_ = false; + bool sent_non_ack_settings_ = false; + + // Recursion guard for ProcessBytes(). + bool processing_bytes_ = false; + // Recursion guard for Send(). + bool sending_ = false; + + bool peer_enables_connect_protocol_ = false; + + // Replace this with a stream ID, for multiple GOAWAY support. + bool queued_goaway_ = false; + bool queued_immediate_goaway_ = false; + bool latched_error_ = false; + + // True if a fatal sending error has occurred. + bool fatal_send_error_ = false; + + // True if a fatal processing visitor callback failed. + bool fatal_visitor_callback_failure_ = false; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_OGHTTP2_SESSION_H_ diff --git a/quiche/http2/adapter/oghttp2_session_test.cc b/quiche/http2/adapter/oghttp2_session_test.cc new file mode 100644 index 000000000000..2953f1e93041 --- /dev/null +++ b/quiche/http2/adapter/oghttp2_session_test.cc @@ -0,0 +1,1076 @@ +#include "quiche/http2/adapter/oghttp2_session.h" + +#include + +#include "quiche/http2/adapter/mock_http2_visitor.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/http2/adapter/test_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using spdy::SpdyFrameType; +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, +}; + +} // namespace + +TEST(OgHttp2SessionTest, ClientConstruction) { + testing::StrictMock visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_TRUE(session.want_read()); + EXPECT_FALSE(session.want_write()); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); + EXPECT_FALSE(session.IsServerSession()); + EXPECT_EQ(0, session.GetHighestReceivedStreamId()); +} + +TEST(OgHttp2SessionTest, ClientHandlesFrames) { + testing::StrictMock visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + + const int64_t initial_result = session.ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + EXPECT_EQ(session.GetRemoteWindowSize(), + kInitialFlowControlWindowSize + 1000); + EXPECT_EQ(0, session.GetHighestReceivedStreamId()); + + // Connection has not yet received any data. + EXPECT_EQ(kInitialFlowControlWindowSize, session.GetReceiveWindowSize()); + + EXPECT_EQ(0, session.GetHpackDecoderDynamicTableSize()); + + // Submit a request to ensure the first stream is created. + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload("This is an example request body."); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_EQ(stream_id, 1); + + // Submit another request to ensure the next stream is created. + int stream_id2 = + session.SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + EXPECT_EQ(stream_id2, 3); + + const std::string stream_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.") + .RstStream(stream_id2, Http2ErrorCode::INTERNAL_ERROR) + .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, + OnHeaderForStream(stream_id, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, "date", + "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(stream_id, 26)); + EXPECT_CALL(visitor, + OnDataForStream(stream_id, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(stream_id2, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(stream_id2, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id2, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnFrameHeader(0, 19, GOAWAY, 0)); + EXPECT_CALL(visitor, OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "")); + const int64_t stream_result = session.ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + EXPECT_EQ(stream_id2, session.GetHighestReceivedStreamId()); + + // The first stream is active and has received some data. + EXPECT_GT(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowSize(stream_id)); + // Connection receive window is equivalent to the first stream's. + EXPECT_EQ(session.GetReceiveWindowSize(), + session.GetStreamReceiveWindowSize(stream_id)); + // Receive window upper bound is still the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowLimit(stream_id)); + + EXPECT_GT(session.GetHpackDecoderDynamicTableSize(), 0); +} + +// Verifies that a client session enqueues initial SETTINGS if Send() is called +// before any frames are explicitly queued. +TEST(OgHttp2SessionTest, ClientEnqueuesSettingsOnSend) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); +} + +// Verifies that a client session enqueues initial SETTINGS before whatever +// frame type is passed to the first invocation of EnqueueFrame(). +TEST(OgHttp2SessionTest, ClientEnqueuesSettingsBeforeOtherFrame) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(std::make_unique(42)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); +} + +// Verifies that if the first call to EnqueueFrame() passes a SETTINGS frame, +// the client session will not enqueue an additional SETTINGS frame. +TEST(OgHttp2SessionTest, ClientEnqueuesSettingsOnce) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(std::make_unique()); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2SessionTest, ClientSubmitRequest) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + + EXPECT_FALSE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + // Even though the user has not queued any frames for the session, it should + // still send the connection preface. + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + // Initial SETTINGS. + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = session.ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, session.GetHpackEncoderDynamicTableSize()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload("This is an example request body."); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(session.GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); + EXPECT_GT(session.GetStreamSendWindowSize(stream_id), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(-1, session.GetStreamSendWindowSize(stream_id + 2)); + + EXPECT_GT(session.GetHpackEncoderDynamicTableSize(), 0); + + stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + const char* kSentinel2 = "arbitrary pointer 2"; + EXPECT_EQ(nullptr, session.GetStreamUserData(stream_id)); + session.SetStreamUserData(stream_id, const_cast(kSentinel2)); + EXPECT_EQ(kSentinel2, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + + // No data was sent (just HEADERS), so the remaining send window size should + // still be the default. + EXPECT_EQ(session.GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); +} + +TEST(OgHttp2SessionTest, ClientSubmitRequestWithLargePayload) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + + EXPECT_FALSE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + // Even though the user has not queued any frames for the session, it should + // still send the connection preface. + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + // Initial SETTINGS. + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface( + {Http2Setting{Http2KnownSettingsId::MAX_FRAME_SIZE, 32768u}}) + .Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{ + Http2KnownSettingsId::MAX_FRAME_SIZE, 32768u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = session.ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload(std::string(20000, 'a')); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + // Single DATA frame with fin, indicating all 20k bytes fit in one frame. + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + visitor.Clear(); + EXPECT_FALSE(session.want_write()); +} + +// This test exercises the case where the client request body source is read +// blocked. +TEST(OgHttp2SessionTest, ClientSubmitRequestWithReadBlock) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = std::make_unique(visitor, true); + TestDataFrameSource* body_ref = body1.get(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + // No data frame, as body1 was read blocked. + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + body_ref->AppendPayload("This is an example request body."); + body_ref->EndData(); + EXPECT_TRUE(session.ResumeStream(stream_id)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(session.ResumeStream(stream_id)); + EXPECT_FALSE(session.want_write()); +} + +// This test exercises the case where the client request body source is read +// blocked, then ends with an empty DATA frame. +TEST(OgHttp2SessionTest, ClientSubmitRequestEmptyDataWithFin) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = std::make_unique(visitor, true); + TestDataFrameSource* body_ref = body1.get(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + // No data frame, as body1 was read blocked. + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + body_ref->EndData(); + EXPECT_TRUE(session.ResumeStream(stream_id)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 0, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(session.ResumeStream(stream_id)); + EXPECT_FALSE(session.want_write()); +} + +// This test exercises the case where the connection to the peer is write +// blocked. +TEST(OgHttp2SessionTest, ClientSubmitRequestWithWriteBlock) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kClient; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = std::make_unique(visitor, true); + body1->AppendPayload("This is an example request body."); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + visitor.set_is_write_blocked(true); + int result = session.Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_TRUE(session.want_write()); + visitor.set_is_write_blocked(false); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); +} + +TEST(OgHttp2SessionTest, ServerConstruction) { + testing::StrictMock visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + EXPECT_TRUE(session.want_read()); + EXPECT_FALSE(session.want_write()); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); + EXPECT_TRUE(session.IsServerSession()); + EXPECT_EQ(0, session.GetHighestReceivedStreamId()); +} + +TEST(OgHttp2SessionTest, ServerHandlesFrames) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + + EXPECT_EQ(0, session.GetHpackDecoderDynamicTableSize()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&session, kSentinel1]() { + session.SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "http")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/two")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(47, false)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_EQ(kSentinel1, session.GetStreamUserData(1)); + + // The first stream is active and has received some data. + EXPECT_GT(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowSize(1)); + // Connection receive window is equivalent to the first stream's. + EXPECT_EQ(session.GetReceiveWindowSize(), + session.GetStreamReceiveWindowSize(1)); + // Receive window upper bound is still the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowLimit(1)); + + EXPECT_GT(session.GetHpackDecoderDynamicTableSize(), 0); + + // It should no longer be possible to set user data on a closed stream. + const char* kSentinel3 = "another arbitrary pointer"; + session.SetStreamUserData(3, const_cast(kSentinel3)); + EXPECT_EQ(nullptr, session.GetStreamUserData(3)); + + EXPECT_EQ(session.GetRemoteWindowSize(), + kInitialFlowControlWindowSize + 1000); + EXPECT_EQ(3, session.GetHighestReceivedStreamId()); + + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + + // Some bytes should have been serialized. + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + // Initial SETTINGS, SETTINGS ack, and PING acks (for PING IDs 42 and 47). + EXPECT_THAT(visitor.data(), + EqualsFrames( + {spdy::SpdyFrameType::SETTINGS, spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING, spdy::SpdyFrameType::PING})); +} + +// Verifies that a server session enqueues initial SETTINGS before whatever +// frame type is passed to the first invocation of EnqueueFrame(). +TEST(OgHttp2SessionTest, ServerEnqueuesSettingsBeforeOtherFrame) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(std::make_unique(42)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); +} + +// Verifies that if the first call to EnqueueFrame() passes a SETTINGS frame, +// the server session will not enqueue an additional SETTINGS frame. +TEST(OgHttp2SessionTest, ServerEnqueuesSettingsOnce) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(std::make_unique()); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2SessionTest, ServerSubmitResponse) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + + EXPECT_FALSE(session.want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&session, kSentinel1]() { + session.SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_EQ(1, session.GetHighestReceivedStreamId()); + + EXPECT_EQ(0, session.GetHpackEncoderDynamicTableSize()); + + // Server will want to send initial SETTINGS, and a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); + // A data fin is not sent so that the stream remains open, and the flow + // control state can be verified. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload("This is an example response body."); + int submit_result = session.SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(session.want_write()); + + // Stream user data should have been set successfully after receiving headers. + EXPECT_EQ(kSentinel1, session.GetStreamUserData(1)); + session.SetStreamUserData(1, nullptr); + EXPECT_EQ(nullptr, session.GetStreamUserData(1)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(session.GetStreamSendWindowSize(1), kInitialFlowControlWindowSize); + EXPECT_GT(session.GetStreamSendWindowSize(1), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(session.GetStreamSendWindowSize(3), -1); + + EXPECT_GT(session.GetHpackEncoderDynamicTableSize(), 0); +} + +// Tests the case where the server queues trailers after the data stream is +// exhausted. +TEST(OgHttp2SessionTest, ServerSendsTrailers) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + + EXPECT_FALSE(session.want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + // Server will want to send initial SETTINGS, and a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload("This is an example response body."); + body1->EndData(); + int submit_result = session.SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = session.SubmitTrailer( + 1, ToHeaders({{"final-status", "a-ok"}, + {"x-comment", "trailers sure are cool"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); +} + +// Tests the case where the server queues trailers immediately after headers and +// data, and before any writes have taken place. +TEST(OgHttp2SessionTest, ServerQueuesTrailersWithResponse) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + + EXPECT_FALSE(session.want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + // Server will want to send initial SETTINGS, and a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = std::make_unique(visitor, false); + body1->AppendPayload("This is an example response body."); + body1->EndData(); + int submit_result = session.SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(session.want_write()); + // There has not been a call to Send() yet, so neither headers nor body have + // been written. + int trailer_result = session.SubmitTrailer( + 1, ToHeaders({{"final-status", "a-ok"}, + {"x-comment", "trailers sure are cool"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA, + SpdyFrameType::HEADERS})); +} + +TEST(OgHttp2SessionTest, ServerSeesErrorOnEndStream) { + DataSavingVisitor visitor; + OgHttp2Session::Options options; + options.perspective = Perspective::kServer; + OgHttp2Session session(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/"}}, + /*fin=*/false) + .Data(1, "Request body", true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0x1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "Request body")); + EXPECT_CALL(visitor, OnEndStream(1)).WillOnce(testing::Return(false)); + EXPECT_CALL( + visitor, + OnConnectionError(Http2VisitorInterface::ConnectionError::kParseError)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(/*NGHTTP2_ERR_CALLBACK_FAILURE=*/-902, result); + + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast( + Http2VisitorInterface::ConnectionError::kParseError))); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/oghttp2_util.cc b/quiche/http2/adapter/oghttp2_util.cc new file mode 100644 index 000000000000..26a3fb3461c7 --- /dev/null +++ b/quiche/http2/adapter/oghttp2_util.cc @@ -0,0 +1,17 @@ +#include "quiche/http2/adapter/oghttp2_util.h" + +namespace http2 { +namespace adapter { + +spdy::Http2HeaderBlock ToHeaderBlock(absl::Span headers) { + spdy::Http2HeaderBlock block; + for (const Header& header : headers) { + absl::string_view name = GetStringView(header.first).first; + absl::string_view value = GetStringView(header.second).first; + block.AppendValueOrAddHeader(name, value); + } + return block; +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/oghttp2_util.h b/quiche/http2/adapter/oghttp2_util.h new file mode 100644 index 000000000000..0aba100a4a20 --- /dev/null +++ b/quiche/http2/adapter/oghttp2_util.h @@ -0,0 +1,18 @@ +#ifndef QUICHE_HTTP2_ADAPTER_OGHTTP2_UTIL_H_ +#define QUICHE_HTTP2_ADAPTER_OGHTTP2_UTIL_H_ + +#include "absl/types/span.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace http2 { +namespace adapter { + +QUICHE_EXPORT spdy::Http2HeaderBlock ToHeaderBlock( + absl::Span headers); + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_OGHTTP2_UTIL_H_ diff --git a/quiche/http2/adapter/oghttp2_util_test.cc b/quiche/http2/adapter/oghttp2_util_test.cc new file mode 100644 index 000000000000..d1a4177d9fbb --- /dev/null +++ b/quiche/http2/adapter/oghttp2_util_test.cc @@ -0,0 +1,83 @@ +#include "quiche/http2/adapter/oghttp2_util.h" + +#include +#include + +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/test_frame_sequence.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using HeaderPair = std::pair; + +TEST(ToHeaderBlock, EmptySpan) { + spdy::Http2HeaderBlock block = ToHeaderBlock({}); + EXPECT_TRUE(block.empty()); +} + +TEST(ToHeaderBlock, ExampleRequestHeaders) { + const std::vector pairs = {{":authority", "example.com"}, + {":method", "GET"}, + {":path", "/example.html"}, + {":scheme", "http"}, + {"accept", "text/plain, text/html"}}; + const std::vector
headers = ToHeaders(pairs); + spdy::Http2HeaderBlock block = ToHeaderBlock(headers); + EXPECT_THAT(block, testing::ElementsAreArray(pairs)); +} + +TEST(ToHeaderBlock, ExampleResponseHeaders) { + const std::vector pairs = { + {":status", "403"}, + {"content-length", "1023"}, + {"x-extra-info", "humblest apologies"}}; + const std::vector
headers = ToHeaders(pairs); + spdy::Http2HeaderBlock block = ToHeaderBlock(headers); + EXPECT_THAT(block, testing::ElementsAreArray(pairs)); +} + +TEST(ToHeaderBlock, RepeatedRequestHeaderNames) { + const std::vector pairs = { + {":authority", "example.com"}, {":method", "GET"}, + {":path", "/example.html"}, {":scheme", "http"}, + {"cookie", "chocolate_chips=yes"}, {"accept", "text/plain, text/html"}, + {"cookie", "raisins=no"}}; + const std::vector expected = { + {":authority", "example.com"}, + {":method", "GET"}, + {":path", "/example.html"}, + {":scheme", "http"}, + {"cookie", "chocolate_chips=yes; raisins=no"}, + {"accept", "text/plain, text/html"}}; + const std::vector
headers = ToHeaders(pairs); + spdy::Http2HeaderBlock block = ToHeaderBlock(headers); + EXPECT_THAT(block, testing::ElementsAreArray(expected)); +} + +TEST(ToHeaderBlock, RepeatedResponseHeaderNames) { + const std::vector pairs = { + {":status", "403"}, {"x-extra-info", "sorry"}, + {"content-length", "1023"}, {"x-extra-info", "humblest apologies"}, + {"content-length", "1024"}, {"set-cookie", "chocolate_chips=yes"}, + {"set-cookie", "raisins=no"}}; + const std::vector expected = { + {":status", "403"}, + {"x-extra-info", absl::string_view("sorry\0humblest apologies", 24)}, + {"content-length", absl::string_view("1023" + "\0" + "1024", + 9)}, + {"set-cookie", absl::string_view("chocolate_chips=yes\0raisins=no", 30)}}; + const std::vector
headers = ToHeaders(pairs); + spdy::Http2HeaderBlock block = ToHeaderBlock(headers); + EXPECT_THAT(block, testing::ElementsAreArray(expected)); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/recording_http2_visitor.cc b/quiche/http2/adapter/recording_http2_visitor.cc new file mode 100644 index 000000000000..d55045f3c8b3 --- /dev/null +++ b/quiche/http2/adapter/recording_http2_visitor.cc @@ -0,0 +1,181 @@ +#include "quiche/http2/adapter/recording_http2_visitor.h" + +#include "absl/strings/str_format.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_util.h" + +namespace http2 { +namespace adapter { +namespace test { + +int64_t RecordingHttp2Visitor::OnReadyToSend(absl::string_view serialized) { + events_.push_back(absl::StrFormat("OnReadyToSend %d", serialized.size())); + return serialized.size(); +} + +void RecordingHttp2Visitor::OnConnectionError(ConnectionError error) { + events_.push_back( + absl::StrFormat("OnConnectionError %s", ConnectionErrorToString(error))); +} + +bool RecordingHttp2Visitor::OnFrameHeader(Http2StreamId stream_id, + size_t length, uint8_t type, + uint8_t flags) { + events_.push_back(absl::StrFormat("OnFrameHeader %d %d %d %d", stream_id, + length, type, flags)); + return true; +} + +void RecordingHttp2Visitor::OnSettingsStart() { + events_.push_back("OnSettingsStart"); +} + +void RecordingHttp2Visitor::OnSetting(Http2Setting setting) { + events_.push_back(absl::StrFormat( + "OnSetting %s %d", Http2SettingsIdToString(setting.id), setting.value)); +} + +void RecordingHttp2Visitor::OnSettingsEnd() { + events_.push_back("OnSettingsEnd"); +} + +void RecordingHttp2Visitor::OnSettingsAck() { + events_.push_back("OnSettingsAck"); +} + +bool RecordingHttp2Visitor::OnBeginHeadersForStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnBeginHeadersForStream %d", stream_id)); + return true; +} + +Http2VisitorInterface::OnHeaderResult RecordingHttp2Visitor::OnHeaderForStream( + Http2StreamId stream_id, absl::string_view name, absl::string_view value) { + events_.push_back( + absl::StrFormat("OnHeaderForStream %d %s %s", stream_id, name, value)); + return HEADER_OK; +} + +bool RecordingHttp2Visitor::OnEndHeadersForStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnEndHeadersForStream %d", stream_id)); + return true; +} + +bool RecordingHttp2Visitor::OnDataPaddingLength(Http2StreamId stream_id, + size_t padding_length) { + events_.push_back( + absl::StrFormat("OnDataPaddingLength %d %d", stream_id, padding_length)); + return true; +} + +bool RecordingHttp2Visitor::OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) { + events_.push_back( + absl::StrFormat("OnBeginDataForStream %d %d", stream_id, payload_length)); + return true; +} + +bool RecordingHttp2Visitor::OnDataForStream(Http2StreamId stream_id, + absl::string_view data) { + events_.push_back(absl::StrFormat("OnDataForStream %d %s", stream_id, data)); + return true; +} + +bool RecordingHttp2Visitor::OnEndStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnEndStream %d", stream_id)); + return true; +} + +void RecordingHttp2Visitor::OnRstStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + events_.push_back(absl::StrFormat("OnRstStream %d %s", stream_id, + Http2ErrorCodeToString(error_code))); +} + +bool RecordingHttp2Visitor::OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + events_.push_back(absl::StrFormat("OnCloseStream %d %s", stream_id, + Http2ErrorCodeToString(error_code))); + return true; +} + +void RecordingHttp2Visitor::OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, bool exclusive) { + events_.push_back(absl::StrFormat("OnPriorityForStream %d %d %d %d", + stream_id, parent_stream_id, weight, + exclusive)); +} + +void RecordingHttp2Visitor::OnPing(Http2PingId ping_id, bool is_ack) { + events_.push_back(absl::StrFormat("OnPing %d %d", ping_id, is_ack)); +} + +void RecordingHttp2Visitor::OnPushPromiseForStream( + Http2StreamId stream_id, Http2StreamId promised_stream_id) { + events_.push_back(absl::StrFormat("OnPushPromiseForStream %d %d", stream_id, + promised_stream_id)); +} + +bool RecordingHttp2Visitor::OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + events_.push_back( + absl::StrFormat("OnGoAway %d %s %s", last_accepted_stream_id, + Http2ErrorCodeToString(error_code), opaque_data)); + return true; +} + +void RecordingHttp2Visitor::OnWindowUpdate(Http2StreamId stream_id, + int window_increment) { + events_.push_back( + absl::StrFormat("OnWindowUpdate %d %d", stream_id, window_increment)); +} + +int RecordingHttp2Visitor::OnBeforeFrameSent(uint8_t frame_type, + Http2StreamId stream_id, + size_t length, uint8_t flags) { + events_.push_back(absl::StrFormat("OnBeforeFrameSent %d %d %d %d", frame_type, + stream_id, length, flags)); + return 0; +} + +int RecordingHttp2Visitor::OnFrameSent(uint8_t frame_type, + Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code) { + events_.push_back(absl::StrFormat("OnFrameSent %d %d %d %d %d", frame_type, + stream_id, length, flags, error_code)); + return 0; +} + +bool RecordingHttp2Visitor::OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) { + events_.push_back(absl::StrFormat("OnInvalidFrame %d %s", stream_id, + InvalidFrameErrorToString(error))); + return true; +} + +void RecordingHttp2Visitor::OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) { + events_.push_back(absl::StrFormat("OnBeginMetadataForStream %d %d", stream_id, + payload_length)); +} + +bool RecordingHttp2Visitor::OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) { + events_.push_back( + absl::StrFormat("OnMetadataForStream %d %s", stream_id, metadata)); + return true; +} + +bool RecordingHttp2Visitor::OnMetadataEndForStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnMetadataEndForStream %d", stream_id)); + return true; +} + +void RecordingHttp2Visitor::OnErrorDebug(absl::string_view message) { + events_.push_back(absl::StrFormat("OnErrorDebug %s", message)); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/recording_http2_visitor.h b/quiche/http2/adapter/recording_http2_visitor.h new file mode 100644 index 000000000000..d796fe721ba9 --- /dev/null +++ b/quiche/http2/adapter/recording_http2_visitor.h @@ -0,0 +1,80 @@ +#ifndef QUICHE_HTTP2_ADAPTER_RECORDING_HTTP2_VISITOR_H_ +#define QUICHE_HTTP2_ADAPTER_RECORDING_HTTP2_VISITOR_H_ + +#include +#include +#include + +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// A visitor implementation that records the sequence of callbacks it receives. +class QUICHE_NO_EXPORT RecordingHttp2Visitor : public Http2VisitorInterface { + public: + using Event = std::string; + using EventSequence = std::list; + + // From Http2VisitorInterface + int64_t OnReadyToSend(absl::string_view serialized) override; + void OnConnectionError(ConnectionError error) override; + bool OnFrameHeader(Http2StreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnSettingsStart() override; + void OnSetting(Http2Setting setting) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + bool OnBeginHeadersForStream(Http2StreamId stream_id) override; + OnHeaderResult OnHeaderForStream(Http2StreamId stream_id, + absl::string_view name, + absl::string_view value) override; + bool OnEndHeadersForStream(Http2StreamId stream_id) override; + bool OnDataPaddingLength(Http2StreamId stream_id, + size_t padding_length) override; + bool OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnDataForStream(Http2StreamId stream_id, + absl::string_view data) override; + bool OnEndStream(Http2StreamId stream_id) override; + void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) override; + bool OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) override; + void OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, int weight, + bool exclusive) override; + void OnPing(Http2PingId ping_id, bool is_ack) override; + void OnPushPromiseForStream(Http2StreamId stream_id, + Http2StreamId promised_stream_id) override; + bool OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + void OnWindowUpdate(Http2StreamId stream_id, int window_increment) override; + int OnBeforeFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags) override; + int OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code) override; + bool OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) override; + void OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) override; + bool OnMetadataEndForStream(Http2StreamId stream_id) override; + void OnErrorDebug(absl::string_view message) override; + + const EventSequence& GetEventSequence() const { return events_; } + void Clear() { events_.clear(); } + + private: + EventSequence events_; +}; + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_RECORDING_HTTP2_VISITOR_H_ diff --git a/quiche/http2/adapter/recording_http2_visitor_test.cc b/quiche/http2/adapter/recording_http2_visitor_test.cc new file mode 100644 index 000000000000..bf2dee2016a8 --- /dev/null +++ b/quiche/http2/adapter/recording_http2_visitor_test.cc @@ -0,0 +1,131 @@ +#include "quiche/http2/adapter/recording_http2_visitor.h" + +#include + +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using ::testing::IsEmpty; + +TEST(RecordingHttp2VisitorTest, EmptySequence) { + RecordingHttp2Visitor chocolate_visitor; + RecordingHttp2Visitor vanilla_visitor; + + EXPECT_THAT(chocolate_visitor.GetEventSequence(), IsEmpty()); + EXPECT_THAT(vanilla_visitor.GetEventSequence(), IsEmpty()); + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + chocolate_visitor.OnSettingsStart(); + + EXPECT_THAT(chocolate_visitor.GetEventSequence(), testing::Not(IsEmpty())); + EXPECT_THAT(vanilla_visitor.GetEventSequence(), IsEmpty()); + EXPECT_NE(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + chocolate_visitor.Clear(); + + EXPECT_THAT(chocolate_visitor.GetEventSequence(), IsEmpty()); + EXPECT_THAT(vanilla_visitor.GetEventSequence(), IsEmpty()); + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); +} + +TEST(RecordingHttp2VisitorTest, SameEventsProduceSameSequence) { + RecordingHttp2Visitor chocolate_visitor; + RecordingHttp2Visitor vanilla_visitor; + + // Prepare some random values to deliver with the events. + http2::test::Http2Random random; + const Http2StreamId stream_id = random.Uniform(kMaxStreamId); + const Http2StreamId another_stream_id = random.Uniform(kMaxStreamId); + const size_t length = random.Rand16(); + const uint8_t type = random.Rand8(); + const uint8_t flags = random.Rand8(); + const Http2ErrorCode error_code = static_cast( + random.Uniform(static_cast(Http2ErrorCode::MAX_ERROR_CODE))); + const Http2Setting setting = {random.Rand16(), random.Rand32()}; + const absl::string_view alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-"; + const std::string some_string = + random.RandStringWithAlphabet(random.Rand8(), alphabet); + const std::string another_string = + random.RandStringWithAlphabet(random.Rand8(), alphabet); + const uint16_t some_int = random.Rand16(); + const bool some_bool = random.OneIn(2); + + // Send the same arbitrary sequence of events to both visitors. + std::list visitors = {&chocolate_visitor, + &vanilla_visitor}; + for (RecordingHttp2Visitor* visitor : visitors) { + visitor->OnConnectionError( + Http2VisitorInterface::ConnectionError::kSendError); + visitor->OnFrameHeader(stream_id, length, type, flags); + visitor->OnSettingsStart(); + visitor->OnSetting(setting); + visitor->OnSettingsEnd(); + visitor->OnSettingsAck(); + visitor->OnBeginHeadersForStream(stream_id); + visitor->OnHeaderForStream(stream_id, some_string, another_string); + visitor->OnEndHeadersForStream(stream_id); + visitor->OnBeginDataForStream(stream_id, length); + visitor->OnDataForStream(stream_id, some_string); + visitor->OnDataForStream(stream_id, another_string); + visitor->OnEndStream(stream_id); + visitor->OnRstStream(stream_id, error_code); + visitor->OnCloseStream(stream_id, error_code); + visitor->OnPriorityForStream(stream_id, another_stream_id, some_int, + some_bool); + visitor->OnPing(some_int, some_bool); + visitor->OnPushPromiseForStream(stream_id, another_stream_id); + visitor->OnGoAway(stream_id, error_code, some_string); + visitor->OnWindowUpdate(stream_id, some_int); + visitor->OnBeginMetadataForStream(stream_id, length); + visitor->OnMetadataForStream(stream_id, some_string); + visitor->OnMetadataForStream(stream_id, another_string); + visitor->OnMetadataEndForStream(stream_id); + } + + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); +} + +TEST(RecordingHttp2VisitorTest, DifferentEventsProduceDifferentSequence) { + RecordingHttp2Visitor chocolate_visitor; + RecordingHttp2Visitor vanilla_visitor; + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + const Http2StreamId stream_id = 1; + const size_t length = 42; + + // Different events with the same method arguments should produce different + // event sequences. + chocolate_visitor.OnBeginDataForStream(stream_id, length); + vanilla_visitor.OnBeginMetadataForStream(stream_id, length); + EXPECT_NE(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + chocolate_visitor.Clear(); + vanilla_visitor.Clear(); + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + // The same events with different method arguments should produce different + // event sequences. + chocolate_visitor.OnBeginHeadersForStream(stream_id); + vanilla_visitor.OnBeginHeadersForStream(stream_id + 2); + EXPECT_NE(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/test_frame_sequence.cc b/quiche/http2/adapter/test_frame_sequence.cc new file mode 100644 index 000000000000..d73f036f3c00 --- /dev/null +++ b/quiche/http2/adapter/test_frame_sequence.cc @@ -0,0 +1,188 @@ +#include "quiche/http2/adapter/test_frame_sequence.h" + +#include + +#include "quiche/http2/adapter/http2_util.h" +#include "quiche/http2/adapter/oghttp2_util.h" +#include "quiche/spdy/core/hpack/hpack_encoder.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace http2 { +namespace adapter { +namespace test { + +std::vector
ToHeaders( + absl::Span> headers) { + std::vector
out; + for (auto [name, value] : headers) { + out.push_back(std::make_pair(HeaderRep(name), HeaderRep(value))); + } + return out; +} + +TestFrameSequence& TestFrameSequence::ClientPreface( + absl::Span settings) { + preface_ = spdy::kHttp2ConnectionHeaderPrefix; + return Settings(settings); +} + +TestFrameSequence& TestFrameSequence::ServerPreface( + absl::Span settings) { + return Settings(settings); +} + +TestFrameSequence& TestFrameSequence::Data(Http2StreamId stream_id, + absl::string_view payload, bool fin, + absl::optional padding_length) { + auto data = std::make_unique(stream_id, payload); + data->set_fin(fin); + if (padding_length) { + data->set_padding_len(padding_length.value()); + } + frames_.push_back(std::move(data)); + return *this; +} + +TestFrameSequence& TestFrameSequence::RstStream(Http2StreamId stream_id, + Http2ErrorCode error) { + frames_.push_back(std::make_unique( + stream_id, TranslateErrorCode(error))); + return *this; +} + +TestFrameSequence& TestFrameSequence::Settings( + absl::Span settings) { + auto settings_frame = std::make_unique(); + for (const Http2Setting& setting : settings) { + settings_frame->AddSetting(setting.id, setting.value); + } + frames_.push_back(std::move(settings_frame)); + return *this; +} + +TestFrameSequence& TestFrameSequence::SettingsAck() { + auto settings = std::make_unique(); + settings->set_is_ack(true); + frames_.push_back(std::move(settings)); + return *this; +} + +TestFrameSequence& TestFrameSequence::PushPromise( + Http2StreamId stream_id, Http2StreamId promised_stream_id, + absl::Span headers) { + frames_.push_back(std::make_unique( + stream_id, promised_stream_id, ToHeaderBlock(headers))); + return *this; +} + +TestFrameSequence& TestFrameSequence::Ping(Http2PingId id) { + frames_.push_back(std::make_unique(id)); + return *this; +} + +TestFrameSequence& TestFrameSequence::PingAck(Http2PingId id) { + auto ping = std::make_unique(id); + ping->set_is_ack(true); + frames_.push_back(std::move(ping)); + return *this; +} + +TestFrameSequence& TestFrameSequence::GoAway(Http2StreamId last_good_stream_id, + Http2ErrorCode error, + absl::string_view payload) { + frames_.push_back(std::make_unique( + last_good_stream_id, TranslateErrorCode(error), std::string(payload))); + return *this; +} + +TestFrameSequence& TestFrameSequence::Headers( + Http2StreamId stream_id, + absl::Span> headers, + bool fin, bool add_continuation) { + return Headers(stream_id, ToHeaders(headers), fin, add_continuation); +} + +TestFrameSequence& TestFrameSequence::Headers(Http2StreamId stream_id, + spdy::Http2HeaderBlock block, + bool fin, bool add_continuation) { + if (add_continuation) { + // The normal intermediate representations don't allow you to represent a + // nonterminal HEADERS frame explicitly, so we'll need to use + // SpdyUnknownIRs. For simplicity, and in order not to mess up HPACK state, + // the payload will be uncompressed. + spdy::HpackEncoder encoder; + encoder.DisableCompression(); + std::string encoded_block = encoder.EncodeHeaderBlock(block); + const size_t pos = encoded_block.size() / 2; + const uint8_t flags = fin ? END_STREAM_FLAG : 0x0; + frames_.push_back(std::make_unique( + stream_id, static_cast(spdy::SpdyFrameType::HEADERS), flags, + encoded_block.substr(0, pos))); + + auto continuation = std::make_unique(stream_id); + continuation->set_end_headers(true); + continuation->take_encoding(encoded_block.substr(pos)); + frames_.push_back(std::move(continuation)); + } else { + auto headers = + std::make_unique(stream_id, std::move(block)); + headers->set_fin(fin); + frames_.push_back(std::move(headers)); + } + return *this; +} + +TestFrameSequence& TestFrameSequence::Headers(Http2StreamId stream_id, + absl::Span headers, + bool fin, bool add_continuation) { + return Headers(stream_id, ToHeaderBlock(headers), fin, add_continuation); +} + +TestFrameSequence& TestFrameSequence::WindowUpdate(Http2StreamId stream_id, + int32_t delta) { + frames_.push_back( + std::make_unique(stream_id, delta)); + return *this; +} + +TestFrameSequence& TestFrameSequence::Priority(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, bool exclusive) { + frames_.push_back(std::make_unique( + stream_id, parent_stream_id, weight, exclusive)); + return *this; +} + +TestFrameSequence& TestFrameSequence::Metadata(Http2StreamId stream_id, + absl::string_view payload, + bool multiple_frames) { + if (multiple_frames) { + const size_t pos = payload.size() / 2; + frames_.push_back(std::make_unique( + stream_id, kMetadataFrameType, 0, std::string(payload.substr(0, pos)))); + frames_.push_back(std::make_unique( + stream_id, kMetadataFrameType, kMetadataEndFlag, + std::string(payload.substr(pos)))); + } else { + frames_.push_back(std::make_unique( + stream_id, kMetadataFrameType, kMetadataEndFlag, std::string(payload))); + } + return *this; +} + +std::string TestFrameSequence::Serialize() { + std::string result; + if (!preface_.empty()) { + result = preface_; + } + spdy::SpdyFramer framer(spdy::SpdyFramer::ENABLE_COMPRESSION); + for (const auto& frame : frames_) { + spdy::SpdySerializedFrame f = framer.SerializeFrame(*frame); + absl::StrAppend(&result, absl::string_view(f)); + } + return result; +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/test_frame_sequence.h b/quiche/http2/adapter/test_frame_sequence.h new file mode 100644 index 000000000000..953aed111f08 --- /dev/null +++ b/quiche/http2/adapter/test_frame_sequence.h @@ -0,0 +1,72 @@ +#ifndef QUICHE_HTTP2_ADAPTER_TEST_FRAME_SEQUENCE_H_ +#define QUICHE_HTTP2_ADAPTER_TEST_FRAME_SEQUENCE_H_ + +#include +#include +#include +#include + +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { +namespace test { + +std::vector
QUICHE_NO_EXPORT ToHeaders( + absl::Span> headers); + +class QUICHE_NO_EXPORT TestFrameSequence { + public: + TestFrameSequence() = default; + + TestFrameSequence& ClientPreface( + absl::Span settings = {}); + TestFrameSequence& ServerPreface( + absl::Span settings = {}); + TestFrameSequence& Data(Http2StreamId stream_id, absl::string_view payload, + bool fin = false, + absl::optional padding_length = absl::nullopt); + TestFrameSequence& RstStream(Http2StreamId stream_id, Http2ErrorCode error); + TestFrameSequence& Settings(absl::Span settings); + TestFrameSequence& SettingsAck(); + TestFrameSequence& PushPromise(Http2StreamId stream_id, + Http2StreamId promised_stream_id, + absl::Span headers); + TestFrameSequence& Ping(Http2PingId id); + TestFrameSequence& PingAck(Http2PingId id); + TestFrameSequence& GoAway(Http2StreamId last_good_stream_id, + Http2ErrorCode error, + absl::string_view payload = ""); + TestFrameSequence& Headers( + Http2StreamId stream_id, + absl::Span> headers, + bool fin = false, bool add_continuation = false); + TestFrameSequence& Headers(Http2StreamId stream_id, + spdy::Http2HeaderBlock block, bool fin = false, + bool add_continuation = false); + TestFrameSequence& Headers(Http2StreamId stream_id, + absl::Span headers, bool fin = false, + bool add_continuation = false); + TestFrameSequence& WindowUpdate(Http2StreamId stream_id, int32_t delta); + TestFrameSequence& Priority(Http2StreamId stream_id, + Http2StreamId parent_stream_id, int weight, + bool exclusive); + TestFrameSequence& Metadata(Http2StreamId stream_id, + absl::string_view payload, + bool multiple_frames = false); + + std::string Serialize(); + + private: + std::string preface_; + std::vector> frames_; +}; + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_TEST_FRAME_SEQUENCE_H_ diff --git a/quiche/http2/adapter/test_utils.cc b/quiche/http2/adapter/test_utils.cc new file mode 100644 index 000000000000..f85d5a3f5e22 --- /dev/null +++ b/quiche/http2/adapter/test_utils.cc @@ -0,0 +1,224 @@ +#include "quiche/http2/adapter/test_utils.h" + +#include + +#include "absl/strings/str_format.h" +#include "quiche/http2/adapter/http2_visitor_interface.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/spdy/core/hpack/hpack_encoder.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; + +} // anonymous namespace + +TestDataFrameSource::TestDataFrameSource(Http2VisitorInterface& visitor, + bool has_fin) + : visitor_(visitor), has_fin_(has_fin) {} + +void TestDataFrameSource::AppendPayload(absl::string_view payload) { + QUICHE_CHECK(!end_data_); + if (!payload.empty()) { + payload_fragments_.push_back(std::string(payload)); + current_fragment_ = payload_fragments_.front(); + } +} + +void TestDataFrameSource::EndData() { end_data_ = true; } + +std::pair TestDataFrameSource::SelectPayloadLength( + size_t max_length) { + if (return_error_) { + return {DataFrameSource::kError, false}; + } + // The stream is done if there's no more data, or if |max_length| is at least + // as large as the remaining data. + const bool end_data = end_data_ && (current_fragment_.empty() || + (payload_fragments_.size() == 1 && + max_length >= current_fragment_.size())); + const int64_t length = std::min(max_length, current_fragment_.size()); + return {length, end_data}; +} + +bool TestDataFrameSource::Send(absl::string_view frame_header, + size_t payload_length) { + QUICHE_LOG_IF(DFATAL, payload_length > current_fragment_.size()) + << "payload_length: " << payload_length + << " current_fragment_size: " << current_fragment_.size(); + const std::string concatenated = + absl::StrCat(frame_header, current_fragment_.substr(0, payload_length)); + const int64_t result = visitor_.OnReadyToSend(concatenated); + if (result < 0) { + // Write encountered error. + visitor_.OnConnectionError(ConnectionError::kSendError); + current_fragment_ = {}; + payload_fragments_.clear(); + return false; + } else if (result == 0) { + // Write blocked. + return false; + } else if (static_cast(result) < concatenated.size()) { + // Probably need to handle this better within this test class. + QUICHE_LOG(DFATAL) + << "DATA frame not fully flushed. Connection will be corrupt!"; + visitor_.OnConnectionError(ConnectionError::kSendError); + current_fragment_ = {}; + payload_fragments_.clear(); + return false; + } + if (payload_length > 0) { + current_fragment_.remove_prefix(payload_length); + } + if (current_fragment_.empty() && !payload_fragments_.empty()) { + payload_fragments_.erase(payload_fragments_.begin()); + if (!payload_fragments_.empty()) { + current_fragment_ = payload_fragments_.front(); + } + } + return true; +} + +std::string EncodeHeaders(const spdy::Http2HeaderBlock& entries) { + spdy::HpackEncoder encoder; + encoder.DisableCompression(); + return encoder.EncodeHeaderBlock(entries); +} + +TestMetadataSource::TestMetadataSource(const spdy::Http2HeaderBlock& entries) + : encoded_entries_(EncodeHeaders(entries)) { + remaining_ = encoded_entries_; +} + +std::pair TestMetadataSource::Pack(uint8_t* dest, + size_t dest_len) { + const size_t copied = std::min(dest_len, remaining_.size()); + std::memcpy(dest, remaining_.data(), copied); + remaining_.remove_prefix(copied); + return std::make_pair(copied, remaining_.empty()); +} + +namespace { + +using TypeAndOptionalLength = + std::pair>; + +std::ostream& operator<<( + std::ostream& os, + const std::vector& types_and_lengths) { + for (const auto& type_and_length : types_and_lengths) { + os << "(" << spdy::FrameTypeToString(type_and_length.first) << ", " + << (type_and_length.second ? absl::StrCat(type_and_length.second.value()) + : "") + << ") "; + } + return os; +} + +std::string FrameTypeToString(uint8_t frame_type) { + if (spdy::IsDefinedFrameType(frame_type)) { + return spdy::FrameTypeToString(spdy::ParseFrameType(frame_type)); + } else { + return absl::StrFormat("0x%x", static_cast(frame_type)); + } +} + +// Custom gMock matcher, used to implement EqualsFrames(). +class SpdyControlFrameMatcher + : public testing::MatcherInterface { + public: + explicit SpdyControlFrameMatcher( + std::vector types_and_lengths) + : expected_types_and_lengths_(std::move(types_and_lengths)) {} + + bool MatchAndExplain(absl::string_view s, + testing::MatchResultListener* listener) const override { + quiche::QuicheDataReader reader(s.data(), s.size()); + + for (TypeAndOptionalLength expected : expected_types_and_lengths_) { + if (!MatchAndExplainOneFrame(expected.first, expected.second, &reader, + listener)) { + return false; + } + } + if (!reader.IsDoneReading()) { + *listener << "; " << reader.BytesRemaining() << " bytes left to read!"; + return false; + } + return true; + } + + bool MatchAndExplainOneFrame(spdy::SpdyFrameType expected_type, + absl::optional expected_length, + quiche::QuicheDataReader* reader, + testing::MatchResultListener* listener) const { + uint32_t payload_length; + if (!reader->ReadUInt24(&payload_length)) { + *listener << "; unable to read length field for expected_type " + << FrameTypeToString(expected_type) << ". data too short!"; + return false; + } + + if (expected_length && payload_length != expected_length.value()) { + *listener << "; actual length: " << payload_length + << " but expected length: " << expected_length.value(); + return false; + } + + uint8_t raw_type; + if (!reader->ReadUInt8(&raw_type)) { + *listener << "; unable to read type field for expected_type " + << FrameTypeToString(expected_type) << ". data too short!"; + return false; + } + + if (raw_type != static_cast(expected_type)) { + *listener << "; actual type: " << FrameTypeToString(raw_type) + << " but expected type: " << FrameTypeToString(expected_type); + return false; + } + + // Seek past flags (1B), stream ID (4B), and payload. Reach the next frame. + reader->Seek(5 + payload_length); + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "Data contains frames of types in sequence " + << expected_types_and_lengths_; + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "Data does not contain frames of types in sequence " + << expected_types_and_lengths_; + } + + private: + const std::vector expected_types_and_lengths_; +}; + +} // namespace + +testing::Matcher EqualsFrames( + std::vector>> + types_and_lengths) { + return MakeMatcher(new SpdyControlFrameMatcher(std::move(types_and_lengths))); +} + +testing::Matcher EqualsFrames( + std::vector types) { + std::vector>> + types_and_lengths; + types_and_lengths.reserve(types.size()); + for (spdy::SpdyFrameType type : types) { + types_and_lengths.push_back({type, absl::nullopt}); + } + return MakeMatcher(new SpdyControlFrameMatcher(std::move(types_and_lengths))); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/test_utils.h b/quiche/http2/adapter/test_utils.h new file mode 100644 index 000000000000..cd1aa8a3c1d3 --- /dev/null +++ b/quiche/http2/adapter/test_utils.h @@ -0,0 +1,139 @@ +#ifndef QUICHE_HTTP2_ADAPTER_TEST_UTILS_H_ +#define QUICHE_HTTP2_ADAPTER_TEST_UTILS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/adapter/data_source.h" +#include "quiche/http2/adapter/http2_protocol.h" +#include "quiche/http2/adapter/mock_http2_visitor.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { +namespace test { + +class QUICHE_NO_EXPORT DataSavingVisitor + : public testing::StrictMock { + public: + int64_t OnReadyToSend(absl::string_view data) override { + if (has_write_error_) { + return kSendError; + } + if (is_write_blocked_) { + return kSendBlocked; + } + const size_t to_accept = std::min(send_limit_, data.size()); + if (to_accept == 0) { + return kSendBlocked; + } + absl::StrAppend(&data_, data.substr(0, to_accept)); + return to_accept; + } + + bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) override { + const bool ret = testing::StrictMock::OnMetadataForStream( + stream_id, metadata); + if (ret) { + auto result = + metadata_map_.try_emplace(stream_id, std::vector()); + result.first->second.push_back(std::string(metadata)); + } + return ret; + } + + const std::vector GetMetadata(Http2StreamId stream_id) { + auto it = metadata_map_.find(stream_id); + if (it == metadata_map_.end()) { + return {}; + } else { + return it->second; + } + } + + const std::string& data() { return data_; } + void Clear() { data_.clear(); } + + void set_send_limit(size_t limit) { send_limit_ = limit; } + + bool is_write_blocked() const { return is_write_blocked_; } + void set_is_write_blocked(bool value) { is_write_blocked_ = value; } + + void set_has_write_error() { has_write_error_ = true; } + + private: + std::string data_; + absl::flat_hash_map> metadata_map_; + size_t send_limit_ = std::numeric_limits::max(); + bool is_write_blocked_ = false; + bool has_write_error_ = false; +}; + +// A test DataFrameSource. Starts out in the empty, blocked state. +class QUICHE_NO_EXPORT TestDataFrameSource : public DataFrameSource { + public: + TestDataFrameSource(Http2VisitorInterface& visitor, bool has_fin); + + void AppendPayload(absl::string_view payload); + void EndData(); + void SimulateError() { return_error_ = true; } + + std::pair SelectPayloadLength(size_t max_length) override; + bool Send(absl::string_view frame_header, size_t payload_length) override; + bool send_fin() const override { return has_fin_; } + + private: + Http2VisitorInterface& visitor_; + std::vector payload_fragments_; + absl::string_view current_fragment_; + // Whether the stream should end with the final frame of data. + const bool has_fin_; + // Whether |payload_fragments_| contains the final segment of data. + bool end_data_ = false; + // Whether SelectPayloadLength() should return an error. + bool return_error_ = false; +}; + +class QUICHE_NO_EXPORT TestMetadataSource : public MetadataSource { + public: + explicit TestMetadataSource(const spdy::Http2HeaderBlock& entries); + + size_t NumFrames(size_t max_frame_size) const override { + // Round up to the next frame. + return (encoded_entries_.size() + max_frame_size - 1) / max_frame_size; + } + std::pair Pack(uint8_t* dest, size_t dest_len) override; + void OnFailure() override {} + + private: + const std::string encoded_entries_; + absl::string_view remaining_; +}; + +// These matchers check whether a string consists entirely of HTTP/2 frames of +// the specified ordered sequence. This is useful in tests where we want to show +// that one or more particular frame types are serialized for sending to the +// peer. The match will fail if there are input bytes not consumed by the +// matcher. + +// Requires that frames match both types and lengths. +testing::Matcher EqualsFrames( + std::vector>> + types_and_lengths); + +// Requires that frames match the specified types. +testing::Matcher EqualsFrames( + std::vector types); + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_TEST_UTILS_H_ diff --git a/quiche/http2/adapter/test_utils_test.cc b/quiche/http2/adapter/test_utils_test.cc new file mode 100644 index 000000000000..0ea44dc54f84 --- /dev/null +++ b/quiche/http2/adapter/test_utils_test.cc @@ -0,0 +1,123 @@ +#include "quiche/http2/adapter/test_utils.h" + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using spdy::SpdyFramer; + +TEST(EqualsFrames, Empty) { + EXPECT_THAT("", EqualsFrames(std::vector{})); +} + +TEST(EqualsFrames, SingleFrameWithLength) { + SpdyFramer framer{SpdyFramer::ENABLE_COMPRESSION}; + + spdy::SpdyPingIR ping{511}; + EXPECT_THAT(framer.SerializeFrame(ping), + EqualsFrames({{spdy::SpdyFrameType::PING, 8}})); + + spdy::SpdyWindowUpdateIR window_update{1, 101}; + EXPECT_THAT(framer.SerializeFrame(window_update), + EqualsFrames({{spdy::SpdyFrameType::WINDOW_UPDATE, 4}})); + + spdy::SpdyDataIR data{3, "Some example data, ha ha!"}; + EXPECT_THAT(framer.SerializeFrame(data), + EqualsFrames({{spdy::SpdyFrameType::DATA, 25}})); +} + +TEST(EqualsFrames, SingleFrameWithoutLength) { + SpdyFramer framer{SpdyFramer::ENABLE_COMPRESSION}; + + spdy::SpdyRstStreamIR rst_stream{7, spdy::ERROR_CODE_REFUSED_STREAM}; + EXPECT_THAT(framer.SerializeFrame(rst_stream), + EqualsFrames({{spdy::SpdyFrameType::RST_STREAM, absl::nullopt}})); + + spdy::SpdyGoAwayIR goaway{13, spdy::ERROR_CODE_ENHANCE_YOUR_CALM, + "Consider taking some deep breaths."}; + EXPECT_THAT(framer.SerializeFrame(goaway), + EqualsFrames({{spdy::SpdyFrameType::GOAWAY, absl::nullopt}})); + + spdy::Http2HeaderBlock block; + block[":method"] = "GET"; + block[":path"] = "/example"; + block[":authority"] = "example.com"; + spdy::SpdyHeadersIR headers{17, std::move(block)}; + EXPECT_THAT(framer.SerializeFrame(headers), + EqualsFrames({{spdy::SpdyFrameType::HEADERS, absl::nullopt}})); +} + +TEST(EqualsFrames, MultipleFrames) { + SpdyFramer framer{SpdyFramer::ENABLE_COMPRESSION}; + + spdy::SpdyPingIR ping{511}; + spdy::SpdyWindowUpdateIR window_update{1, 101}; + spdy::SpdyDataIR data{3, "Some example data, ha ha!"}; + spdy::SpdyRstStreamIR rst_stream{7, spdy::ERROR_CODE_REFUSED_STREAM}; + spdy::SpdyGoAwayIR goaway{13, spdy::ERROR_CODE_ENHANCE_YOUR_CALM, + "Consider taking some deep breaths."}; + spdy::Http2HeaderBlock block; + block[":method"] = "GET"; + block[":path"] = "/example"; + block[":authority"] = "example.com"; + spdy::SpdyHeadersIR headers{17, std::move(block)}; + + const std::string frame_sequence = + absl::StrCat(absl::string_view(framer.SerializeFrame(ping)), + absl::string_view(framer.SerializeFrame(window_update)), + absl::string_view(framer.SerializeFrame(data)), + absl::string_view(framer.SerializeFrame(rst_stream)), + absl::string_view(framer.SerializeFrame(goaway)), + absl::string_view(framer.SerializeFrame(headers))); + absl::string_view frame_sequence_view = frame_sequence; + EXPECT_THAT(frame_sequence, + EqualsFrames({{spdy::SpdyFrameType::PING, absl::nullopt}, + {spdy::SpdyFrameType::WINDOW_UPDATE, absl::nullopt}, + {spdy::SpdyFrameType::DATA, 25}, + {spdy::SpdyFrameType::RST_STREAM, absl::nullopt}, + {spdy::SpdyFrameType::GOAWAY, 42}, + {spdy::SpdyFrameType::HEADERS, 19}})); + EXPECT_THAT(frame_sequence_view, + EqualsFrames({{spdy::SpdyFrameType::PING, absl::nullopt}, + {spdy::SpdyFrameType::WINDOW_UPDATE, absl::nullopt}, + {spdy::SpdyFrameType::DATA, 25}, + {spdy::SpdyFrameType::RST_STREAM, absl::nullopt}, + {spdy::SpdyFrameType::GOAWAY, 42}, + {spdy::SpdyFrameType::HEADERS, 19}})); + EXPECT_THAT( + frame_sequence, + EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY, spdy::SpdyFrameType::HEADERS})); + EXPECT_THAT( + frame_sequence_view, + EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY, spdy::SpdyFrameType::HEADERS})); + + // If the final frame type is removed the expectation fails, as there are + // bytes left to read. + EXPECT_THAT( + frame_sequence, + testing::Not(EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY}))); + EXPECT_THAT( + frame_sequence_view, + testing::Not(EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY}))); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/window_manager.cc b/quiche/http2/adapter/window_manager.cc new file mode 100644 index 000000000000..81eaa775cad0 --- /dev/null +++ b/quiche/http2/adapter/window_manager.cc @@ -0,0 +1,103 @@ +#include "quiche/http2/adapter/window_manager.h" + +#include + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace adapter { + +bool DefaultShouldWindowUpdateFn(int64_t limit, int64_t window, int64_t delta) { + // For the sake of efficiency, we want to send window updates if less than + // half of the max quota is available to the peer at any point in time. + const int64_t kDesiredMinWindow = limit / 2; + const int64_t kDesiredMinDelta = limit / 3; + if (delta >= kDesiredMinDelta) { + // This particular window update was sent because the available delta + // exceeded the desired minimum. + return true; + } else if (window < kDesiredMinWindow) { + // This particular window update was sent because the quota available to the + // peer at this moment is less than the desired minimum. + return true; + } + return false; +} + +WindowManager::WindowManager(int64_t window_size_limit, + WindowUpdateListener listener, + ShouldWindowUpdateFn should_window_update_fn, + bool update_window_on_notify) + : limit_(window_size_limit), + window_(window_size_limit), + buffered_(0), + listener_(std::move(listener)), + should_window_update_fn_(std::move(should_window_update_fn)), + update_window_on_notify_(update_window_on_notify) { + if (!should_window_update_fn_) { + should_window_update_fn_ = DefaultShouldWindowUpdateFn; + } +} + +void WindowManager::OnWindowSizeLimitChange(const int64_t new_limit) { + QUICHE_VLOG(2) << "WindowManager@" << this + << " OnWindowSizeLimitChange from old limit of " << limit_ + << " to new limit of " << new_limit; + window_ += (new_limit - limit_); + limit_ = new_limit; +} + +void WindowManager::SetWindowSizeLimit(int64_t new_limit) { + QUICHE_VLOG(2) << "WindowManager@" << this + << " SetWindowSizeLimit from old limit of " << limit_ + << " to new limit of " << new_limit; + limit_ = new_limit; + MaybeNotifyListener(); +} + +bool WindowManager::MarkDataBuffered(int64_t bytes) { + QUICHE_VLOG(2) << "WindowManager@" << this << " window: " << window_ + << " bytes: " << bytes; + if (window_ < bytes) { + QUICHE_VLOG(2) << "WindowManager@" << this << " window underflow " + << "window: " << window_ << " bytes: " << bytes; + window_ = 0; + } else { + window_ -= bytes; + } + buffered_ += bytes; + if (window_ == 0) { + // If data hasn't been flushed in a while there may be space available. + MaybeNotifyListener(); + } + return window_ > 0; +} + +void WindowManager::MarkDataFlushed(int64_t bytes) { + QUICHE_VLOG(2) << "WindowManager@" << this << " buffered: " << buffered_ + << " bytes: " << bytes; + if (buffered_ < bytes) { + QUICHE_BUG(bug_2816_1) << "WindowManager@" << this << " buffered underflow " + << "buffered_: " << buffered_ << " bytes: " << bytes; + buffered_ = 0; + } else { + buffered_ -= bytes; + } + MaybeNotifyListener(); +} + +void WindowManager::MaybeNotifyListener() { + const int64_t delta = limit_ - (buffered_ + window_); + if (should_window_update_fn_(limit_, window_, delta) && delta > 0) { + QUICHE_VLOG(2) << "WindowManager@" << this + << " Informing listener of delta: " << delta; + listener_(delta); + if (update_window_on_notify_) { + window_ += delta; + } + } +} + +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/adapter/window_manager.h b/quiche/http2/adapter/window_manager.h new file mode 100644 index 000000000000..700597416ebd --- /dev/null +++ b/quiche/http2/adapter/window_manager.h @@ -0,0 +1,93 @@ +#ifndef QUICHE_HTTP2_ADAPTER_WINDOW_MANAGER_H_ +#define QUICHE_HTTP2_ADAPTER_WINDOW_MANAGER_H_ + +#include +#include + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +namespace test { +class WindowManagerPeer; +} + +// This class keeps track of a HTTP/2 flow control window, notifying a listener +// when a window update needs to be sent. This class is not thread-safe. +class QUICHE_EXPORT WindowManager { + public: + // A WindowUpdateListener is invoked when it is time to send a window update. + using WindowUpdateListener = std::function; + + // Invoked to determine whether to call the listener based on the window + // limit, window size, and delta that would be sent. + using ShouldWindowUpdateFn = + std::function; + + WindowManager(int64_t window_size_limit, WindowUpdateListener listener, + ShouldWindowUpdateFn should_window_update_fn = {}, + bool update_window_on_notify = true); + + int64_t CurrentWindowSize() const { return window_; } + int64_t WindowSizeLimit() const { return limit_; } + + // Called when the window size limit is changed (typically via settings) but + // no window update should be sent. + void OnWindowSizeLimitChange(int64_t new_limit); + + // Sets the window size limit to |new_limit| and notifies the listener to + // update as necessary. + void SetWindowSizeLimit(int64_t new_limit); + + // Increments the running total of data bytes buffered. Returns true iff there + // is more window remaining. + bool MarkDataBuffered(int64_t bytes); + + // Increments the running total of data bytes that have been flushed or + // dropped. Invokes the listener if the current window is smaller than some + // threshold and there is quota available to send. + void MarkDataFlushed(int64_t bytes); + + // Convenience method, used when incoming data is immediately dropped or + // ignored. + void MarkWindowConsumed(int64_t bytes) { + MarkDataBuffered(bytes); + MarkDataFlushed(bytes); + } + + // Increments the window size without affecting the limit. Useful if this end + // of a stream or connection issues a one-time WINDOW_UPDATE. + void IncreaseWindow(int64_t delta) { window_ += delta; } + + private: + friend class test::WindowManagerPeer; + + void MaybeNotifyListener(); + + // The upper bound on the flow control window. The GFE attempts to maintain a + // window of this size at the peer as data is proxied through. + int64_t limit_; + + // The current flow control window that has not been advertised to the peer + // and not yet consumed. The peer can send this many bytes before becoming + // blocked. + int64_t window_; + + // The amount of data already buffered, which should count against the flow + // control window upper bound. + int64_t buffered_; + + WindowUpdateListener listener_; + + ShouldWindowUpdateFn should_window_update_fn_; + + bool update_window_on_notify_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_WINDOW_MANAGER_H_ diff --git a/quiche/http2/adapter/window_manager_test.cc b/quiche/http2/adapter/window_manager_test.cc new file mode 100644 index 000000000000..f6617c567734 --- /dev/null +++ b/quiche/http2/adapter/window_manager_test.cc @@ -0,0 +1,342 @@ +#include "quiche/http2/adapter/window_manager.h" + +#include + +#include "absl/functional/bind_front.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// Use the peer to access private vars of WindowManager. +class WindowManagerPeer { + public: + explicit WindowManagerPeer(const WindowManager& wm) : wm_(wm) {} + + int64_t buffered() { return wm_.buffered_; } + + private: + const WindowManager& wm_; +}; + +namespace { + +class WindowManagerTest : public quiche::test::QuicheTest { + protected: + WindowManagerTest() + : wm_(kDefaultLimit, absl::bind_front(&WindowManagerTest::OnCall, this)), + peer_(wm_) {} + + void OnCall(int64_t s) { call_sequence_.push_back(s); } + + const int64_t kDefaultLimit = 32 * 1024 * 3; + std::list call_sequence_; + WindowManager wm_; + WindowManagerPeer peer_; + ::http2::test::Http2Random random_; +}; + +// A few no-op calls. +TEST_F(WindowManagerTest, NoOps) { + wm_.SetWindowSizeLimit(kDefaultLimit); + wm_.SetWindowSizeLimit(0); + wm_.SetWindowSizeLimit(kDefaultLimit); + wm_.MarkDataBuffered(0); + wm_.MarkDataFlushed(0); + EXPECT_TRUE(call_sequence_.empty()); +} + +// This test verifies that WindowManager does not notify its listener when data +// is only buffered, and never flushed. +TEST_F(WindowManagerTest, DataOnlyBuffered) { + int64_t total = 0; + while (total < kDefaultLimit) { + int64_t s = std::min(kDefaultLimit - total, random_.Uniform(1024)); + total += s; + wm_.MarkDataBuffered(s); + } + EXPECT_THAT(call_sequence_, ::testing::IsEmpty()); +} + +// This test verifies that WindowManager does notify its listener when data is +// buffered and subsequently flushed. +TEST_F(WindowManagerTest, DataBufferedAndFlushed) { + int64_t total_buffered = 0; + int64_t total_flushed = 0; + while (call_sequence_.empty()) { + int64_t buffered = std::min(kDefaultLimit - total_buffered, + random_.Uniform(1024)); + wm_.MarkDataBuffered(buffered); + total_buffered += buffered; + EXPECT_TRUE(call_sequence_.empty()); + int64_t flushed = (total_buffered - total_flushed) > 0 + ? random_.Uniform(total_buffered - total_flushed) + : 0; + wm_.MarkDataFlushed(flushed); + total_flushed += flushed; + } + // If WindowManager decided to send an update, at least one third of the + // window must have been consumed by buffered data. + EXPECT_GE(total_buffered, kDefaultLimit / 3); +} + +// Window manager should avoid window underflow. +TEST_F(WindowManagerTest, AvoidWindowUnderflow) { + EXPECT_EQ(wm_.CurrentWindowSize(), wm_.WindowSizeLimit()); + // Don't buffer more than the total window! + wm_.MarkDataBuffered(wm_.WindowSizeLimit() + 1); + EXPECT_EQ(wm_.CurrentWindowSize(), 0u); +} + +// Window manager should GFE_BUG and avoid buffered underflow. +TEST_F(WindowManagerTest, AvoidBufferedUnderflow) { + EXPECT_EQ(peer_.buffered(), 0u); + // Don't flush more than has been buffered! + EXPECT_QUICHE_BUG(wm_.MarkDataFlushed(1), "buffered underflow"); + EXPECT_EQ(peer_.buffered(), 0u); + + wm_.MarkDataBuffered(42); + EXPECT_EQ(peer_.buffered(), 42u); + // Don't flush more than has been buffered! + EXPECT_QUICHE_BUG( + { + wm_.MarkDataFlushed(43); + EXPECT_EQ(peer_.buffered(), 0u); + }, + "buffered underflow"); +} + +// This test verifies that WindowManager notifies its listener when window is +// consumed (data is ignored or immediately dropped). +TEST_F(WindowManagerTest, WindowConsumed) { + int64_t consumed = kDefaultLimit / 3 - 1; + wm_.MarkWindowConsumed(consumed); + EXPECT_TRUE(call_sequence_.empty()); + const int64_t extra = 1; + wm_.MarkWindowConsumed(extra); + EXPECT_THAT(call_sequence_, testing::ElementsAre(consumed + extra)); +} + +// This test verifies that WindowManager notifies its listener when the window +// size limit is increased. +TEST_F(WindowManagerTest, ListenerCalledOnSizeUpdate) { + wm_.SetWindowSizeLimit(kDefaultLimit - 1024); + EXPECT_TRUE(call_sequence_.empty()); + wm_.SetWindowSizeLimit(kDefaultLimit * 5); + // Because max(outstanding window, previous limit) is kDefaultLimit, it is + // only appropriate to increase the window by kDefaultLimit * 4. + EXPECT_THAT(call_sequence_, testing::ElementsAre(kDefaultLimit * 4)); +} + +// This test verifies that when data is buffered and then the limit is +// decreased, WindowManager only notifies the listener once any outstanding +// window has been consumed. +TEST_F(WindowManagerTest, WindowUpdateAfterLimitDecreased) { + wm_.MarkDataBuffered(kDefaultLimit - 1024); + wm_.SetWindowSizeLimit(kDefaultLimit - 2048); + + // Now there are 2048 bytes of window outstanding beyond the current limit, + // and we have 1024 bytes of data buffered beyond the current limit. This is + // intentional, to be sure that WindowManager works properly if the limit is + // decreased at runtime. + + wm_.MarkDataFlushed(512); + EXPECT_TRUE(call_sequence_.empty()); + wm_.MarkDataFlushed(512); + EXPECT_TRUE(call_sequence_.empty()); + wm_.MarkDataFlushed(512); + EXPECT_TRUE(call_sequence_.empty()); + wm_.MarkDataFlushed(1024); + EXPECT_THAT(call_sequence_, testing::ElementsAre(512)); +} + +// For normal behavior, we only call MaybeNotifyListener() when data is +// flushed. But if window runs out entirely, we still need to call +// MaybeNotifyListener() to avoid becoming artificially blocked when data isn't +// being flushed. +TEST_F(WindowManagerTest, ZeroWindowNotification) { + // Consume a byte of window, but not enough to trigger an update. + wm_.MarkWindowConsumed(1); + + // Buffer the remaining window. + wm_.MarkDataBuffered(kDefaultLimit - 1); + // Listener is notified of the remaining byte of possible window. + EXPECT_THAT(call_sequence_, testing::ElementsAre(1)); +} + +TEST_F(WindowManagerTest, OnWindowSizeLimitChange) { + wm_.MarkDataBuffered(10000); + EXPECT_EQ(wm_.CurrentWindowSize(), kDefaultLimit - 10000); + EXPECT_EQ(wm_.WindowSizeLimit(), kDefaultLimit); + + wm_.OnWindowSizeLimitChange(kDefaultLimit + 1000); + EXPECT_EQ(wm_.CurrentWindowSize(), kDefaultLimit - 9000); + EXPECT_EQ(wm_.WindowSizeLimit(), kDefaultLimit + 1000); + + wm_.OnWindowSizeLimitChange(kDefaultLimit - 1000); + EXPECT_EQ(wm_.CurrentWindowSize(), kDefaultLimit - 11000); + EXPECT_EQ(wm_.WindowSizeLimit(), kDefaultLimit - 1000); +} + +TEST_F(WindowManagerTest, NegativeWindowSize) { + wm_.MarkDataBuffered(80000); + // 98304 window - 80000 buffered = 18304 available + EXPECT_EQ(wm_.CurrentWindowSize(), 18304); + wm_.OnWindowSizeLimitChange(65535); + // limit decreases by 98304 - 65535 = 32769, window becomes -14465 + EXPECT_EQ(wm_.CurrentWindowSize(), -14465); + wm_.MarkDataFlushed(70000); + // Still 10000 bytes buffered, so window manager grants sufficient quota to + // reach a window of 65535 - 10000. + EXPECT_EQ(wm_.CurrentWindowSize(), 55535); + // Desired window minus existing window: 55535 - (-14465) = 70000 + EXPECT_THAT(call_sequence_, testing::ElementsAre(70000)); +} + +TEST_F(WindowManagerTest, IncreaseWindow) { + wm_.MarkDataBuffered(1000); + EXPECT_EQ(wm_.CurrentWindowSize(), kDefaultLimit - 1000); + EXPECT_EQ(wm_.WindowSizeLimit(), kDefaultLimit); + + // Increasing the window beyond the limit is allowed. + wm_.IncreaseWindow(5000); + EXPECT_EQ(wm_.CurrentWindowSize(), kDefaultLimit + 4000); + EXPECT_EQ(wm_.WindowSizeLimit(), kDefaultLimit); + + // 80000 bytes are buffered, then flushed. + wm_.MarkWindowConsumed(80000); + // The window manager replenishes the consumed quota up to the limit. + EXPECT_THAT(call_sequence_, testing::ElementsAre(75000)); + // The window is the limit, minus buffered data, as expected. + EXPECT_EQ(wm_.CurrentWindowSize(), kDefaultLimit - 1000); +} + +// This test verifies that when the constructor option is specified, +// WindowManager does not update its internal accounting of the flow control +// window when notifying the listener. +TEST(WindowManagerNoUpdateTest, NoWindowUpdateOnListener) { + const int64_t kDefaultLimit = 65535; + + std::list call_sequence1; + WindowManager wm1( + kDefaultLimit, + [&call_sequence1](int64_t delta) { call_sequence1.push_back(delta); }, + /*should_notify_listener=*/{}, + /*update_window_on_notify=*/true); // default + std::list call_sequence2; + WindowManager wm2( + kDefaultLimit, + [&call_sequence2](int64_t delta) { call_sequence2.push_back(delta); }, + /*should_notify_listener=*/{}, + /*update_window_on_notify=*/false); + + const int64_t consumed = kDefaultLimit / 3 - 1; + + wm1.MarkWindowConsumed(consumed); + EXPECT_TRUE(call_sequence1.empty()); + wm2.MarkWindowConsumed(consumed); + EXPECT_TRUE(call_sequence2.empty()); + + EXPECT_EQ(wm1.CurrentWindowSize(), kDefaultLimit - consumed); + EXPECT_EQ(wm2.CurrentWindowSize(), kDefaultLimit - consumed); + + const int64_t extra = 1; + wm1.MarkWindowConsumed(extra); + EXPECT_THAT(call_sequence1, testing::ElementsAre(consumed + extra)); + // Window size *is* updated after invoking the listener. + EXPECT_EQ(wm1.CurrentWindowSize(), kDefaultLimit); + call_sequence1.clear(); + + wm2.MarkWindowConsumed(extra); + EXPECT_THAT(call_sequence2, testing::ElementsAre(consumed + extra)); + // Window size is *not* updated after invoking the listener. + EXPECT_EQ(wm2.CurrentWindowSize(), kDefaultLimit - (consumed + extra)); + call_sequence2.clear(); + + // Manually increase the window by the listener notification amount. + wm2.IncreaseWindow(consumed + extra); + EXPECT_EQ(wm2.CurrentWindowSize(), kDefaultLimit); + + wm1.SetWindowSizeLimit(kDefaultLimit * 5); + EXPECT_THAT(call_sequence1, testing::ElementsAre(kDefaultLimit * 4)); + // *Does* update the window size. + EXPECT_EQ(wm1.CurrentWindowSize(), kDefaultLimit * 5); + + wm2.SetWindowSizeLimit(kDefaultLimit * 5); + EXPECT_THAT(call_sequence2, testing::ElementsAre(kDefaultLimit * 4)); + // Does *not* update the window size. + EXPECT_EQ(wm2.CurrentWindowSize(), kDefaultLimit); +} + +// This test verifies that when the constructor option is specified, +// WindowManager uses the provided ShouldWindowUpdateFn to determine when to +// notify the listener. +TEST(WindowManagerShouldUpdateTest, CustomShouldWindowUpdateFn) { + const int64_t kDefaultLimit = 65535; + + // This window manager should always notify. + std::list call_sequence1; + WindowManager wm1( + kDefaultLimit, + [&call_sequence1](int64_t delta) { call_sequence1.push_back(delta); }, + [](int64_t /*limit*/, int64_t /*window*/, int64_t /*delta*/) { + return true; + }); + // This window manager should never notify. + std::list call_sequence2; + WindowManager wm2( + kDefaultLimit, + [&call_sequence2](int64_t delta) { call_sequence2.push_back(delta); }, + [](int64_t /*limit*/, int64_t /*window*/, int64_t /*delta*/) { + return false; + }); + // This window manager should notify as long as no data is buffered. + std::list call_sequence3; + WindowManager wm3( + kDefaultLimit, + [&call_sequence3](int64_t delta) { call_sequence3.push_back(delta); }, + [](int64_t limit, int64_t window, int64_t delta) { + return delta == limit - window; + }); + + const int64_t consumed = kDefaultLimit / 4; + + wm1.MarkWindowConsumed(consumed); + EXPECT_THAT(call_sequence1, testing::ElementsAre(consumed)); + wm2.MarkWindowConsumed(consumed); + EXPECT_TRUE(call_sequence2.empty()); + wm3.MarkWindowConsumed(consumed); + EXPECT_THAT(call_sequence3, testing::ElementsAre(consumed)); + + const int64_t buffered = 42; + + wm1.MarkDataBuffered(buffered); + EXPECT_THAT(call_sequence1, testing::ElementsAre(consumed)); + wm2.MarkDataBuffered(buffered); + EXPECT_TRUE(call_sequence2.empty()); + wm3.MarkDataBuffered(buffered); + EXPECT_THAT(call_sequence3, testing::ElementsAre(consumed)); + + wm1.MarkDataFlushed(buffered / 3); + EXPECT_THAT(call_sequence1, testing::ElementsAre(consumed, buffered / 3)); + wm2.MarkDataFlushed(buffered / 3); + EXPECT_TRUE(call_sequence2.empty()); + wm3.MarkDataFlushed(buffered / 3); + EXPECT_THAT(call_sequence3, testing::ElementsAre(consumed)); + + wm1.MarkDataFlushed(2 * buffered / 3); + EXPECT_THAT(call_sequence1, + testing::ElementsAre(consumed, buffered / 3, 2 * buffered / 3)); + wm2.MarkDataFlushed(2 * buffered / 3); + EXPECT_TRUE(call_sequence2.empty()); + wm3.MarkDataFlushed(2 * buffered / 3); + EXPECT_THAT(call_sequence3, testing::ElementsAre(consumed, buffered)); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/quiche/http2/core/http2_trace_logging.cc b/quiche/http2/core/http2_trace_logging.cc new file mode 100644 index 000000000000..ac0a82746d40 --- /dev/null +++ b/quiche/http2/core/http2_trace_logging.cc @@ -0,0 +1,482 @@ +#include "quiche/http2/core/http2_trace_logging.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" + +// Convenience macros for printing function arguments in log lines in the +// format arg_name=value. +#define FORMAT_ARG(arg) " " #arg "=" << arg +#define FORMAT_INT_ARG(arg) " " #arg "=" << static_cast(arg) + +// Convenience macros for printing Spdy*IR attributes in log lines in the +// format attrib_name=value. +#define FORMAT_ATTR(ir, attrib) " " #attrib "=" << ir.attrib() +#define FORMAT_INT_ATTR(ir, attrib) \ + " " #attrib "=" << static_cast(ir.attrib()) + +namespace { + +// Logs a container, using a user-provided object to log each individual item. +template +struct ContainerLogger { + explicit ContainerLogger(const T& c, ItemLogger l) + : container(c), item_logger(l) {} + + friend std::ostream& operator<<(std::ostream& out, + const ContainerLogger& logger) { + out << "["; + auto begin = logger.container.begin(); + for (auto it = begin; it != logger.container.end(); ++it) { + if (it != begin) { + out << ", "; + } + logger.item_logger.Log(out, *it); + } + out << "]"; + return out; + } + const T& container; + ItemLogger item_logger; +}; + +// Returns a ContainerLogger that will log |container| using |item_logger|. +template +auto LogContainer(const T& container, ItemLogger item_logger) + -> decltype(ContainerLogger(container, item_logger)) { + return ContainerLogger(container, item_logger); +} + +} // anonymous namespace + +#define FORMAT_HEADER_BLOCK(ir) \ + " header_block=" << LogContainer(ir.header_block(), LogHeaderBlockEntry()) + +namespace http2 { + +using spdy::Http2HeaderBlock; +using spdy::SettingsMap; +using spdy::SpdyAltSvcIR; +using spdy::SpdyContinuationIR; +using spdy::SpdyDataIR; +using spdy::SpdyGoAwayIR; +using spdy::SpdyHeadersIR; +using spdy::SpdyPingIR; +using spdy::SpdyPriorityIR; +using spdy::SpdyPushPromiseIR; +using spdy::SpdyRstStreamIR; +using spdy::SpdySettingsIR; +using spdy::SpdyStreamId; +using spdy::SpdyUnknownIR; +using spdy::SpdyWindowUpdateIR; + +namespace { + +// Defines how elements of Http2HeaderBlocks are logged. +struct LogHeaderBlockEntry { + void Log(std::ostream& out, + const Http2HeaderBlock::value_type& entry) const { // NOLINT + out << "\"" << entry.first << "\": \"" << entry.second << "\""; + } +}; + +// Defines how elements of SettingsMap are logged. +struct LogSettingsEntry { + void Log(std::ostream& out, + const SettingsMap::value_type& entry) const { // NOLINT + out << spdy::SettingsIdToString(entry.first) << ": " << entry.second; + } +}; + +// Defines how elements of AlternativeServiceVector are logged. +struct LogAlternativeService { + void Log(std::ostream& out, + const spdy::SpdyAltSvcWireFormat::AlternativeService& altsvc) + const { // NOLINT + out << "{" + << "protocol_id=" << altsvc.protocol_id << " host=" << altsvc.host + << " port=" << altsvc.port + << " max_age_seconds=" << altsvc.max_age_seconds << " version="; + for (auto v : altsvc.version) { + out << v << ","; + } + out << "}"; + } +}; + +} // anonymous namespace + +Http2TraceLogger::Http2TraceLogger(SpdyFramerVisitorInterface* parent, + absl::string_view perspective, + std::function is_enabled, + const void* connection_id) + : wrapped_(parent), + perspective_(perspective), + is_enabled_(std::move(is_enabled)), + connection_id_(connection_id) {} + +Http2TraceLogger::~Http2TraceLogger() { + if (recording_headers_handler_ != nullptr && + !recording_headers_handler_->decoded_block().empty()) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "connection_id=" << connection_id_ + << " Received headers that were never logged! keys/values:" + << recording_headers_handler_->decoded_block().DebugString(); + } +} + +void Http2TraceLogger::OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnError:" << FORMAT_ARG(connection_id_) + << ", error=" << Http2DecoderAdapter::SpdyFramerErrorToString(error); + wrapped_->OnError(error, detailed_error); +} + +void Http2TraceLogger::OnCommonHeader(SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnCommonHeader:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(length) << FORMAT_INT_ARG(type) + << FORMAT_INT_ARG(flags); + wrapped_->OnCommonHeader(stream_id, length, type, flags); +} + +void Http2TraceLogger::OnDataFrameHeader(SpdyStreamId stream_id, size_t length, + bool fin) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnDataFrameHeader:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(length) << FORMAT_ARG(fin); + wrapped_->OnDataFrameHeader(stream_id, length, fin); +} + +void Http2TraceLogger::OnStreamFrameData(SpdyStreamId stream_id, + const char* data, size_t len) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamFrameData:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(len); + wrapped_->OnStreamFrameData(stream_id, data, len); +} + +void Http2TraceLogger::OnStreamEnd(SpdyStreamId stream_id) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamEnd:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id); + wrapped_->OnStreamEnd(stream_id); +} + +void Http2TraceLogger::OnStreamPadLength(SpdyStreamId stream_id, size_t value) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamPadLength:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(value); + wrapped_->OnStreamPadLength(stream_id, value); +} + +void Http2TraceLogger::OnStreamPadding(SpdyStreamId stream_id, size_t len) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamPadding:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(len); + wrapped_->OnStreamPadding(stream_id, len); +} + +spdy::SpdyHeadersHandlerInterface* Http2TraceLogger::OnHeaderFrameStart( + SpdyStreamId stream_id) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnHeaderFrameStart:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id); + spdy::SpdyHeadersHandlerInterface* result = + wrapped_->OnHeaderFrameStart(stream_id); + if (is_enabled_()) { + recording_headers_handler_ = + std::make_unique(result); + result = recording_headers_handler_.get(); + } else { + recording_headers_handler_ = nullptr; + } + return result; +} + +void Http2TraceLogger::OnHeaderFrameEnd(SpdyStreamId stream_id) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnHeaderFrameEnd:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id); + LogReceivedHeaders(); + wrapped_->OnHeaderFrameEnd(stream_id); + recording_headers_handler_ = nullptr; +} + +void Http2TraceLogger::OnRstStream(SpdyStreamId stream_id, + SpdyErrorCode error_code) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnRstStream:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << " error_code=" << spdy::ErrorCodeToString(error_code); + wrapped_->OnRstStream(stream_id, error_code); +} + +void Http2TraceLogger::OnSettings() { wrapped_->OnSettings(); } + +void Http2TraceLogger::OnSetting(SpdySettingsId id, uint32_t value) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnSetting:" << FORMAT_ARG(connection_id_) + << " id=" << spdy::SettingsIdToString(id) << FORMAT_ARG(value); + wrapped_->OnSetting(id, value); +} + +void Http2TraceLogger::OnSettingsEnd() { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnSettingsEnd:" << FORMAT_ARG(connection_id_); + wrapped_->OnSettingsEnd(); +} + +void Http2TraceLogger::OnSettingsAck() { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnSettingsAck:" << FORMAT_ARG(connection_id_); + wrapped_->OnSettingsAck(); +} + +void Http2TraceLogger::OnPing(SpdyPingId unique_id, bool is_ack) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPing:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(unique_id) + << FORMAT_ARG(is_ack); + wrapped_->OnPing(unique_id, is_ack); +} + +void Http2TraceLogger::OnGoAway(SpdyStreamId last_accepted_stream_id, + SpdyErrorCode error_code) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnGoAway:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(last_accepted_stream_id) + << " error_code=" << spdy::ErrorCodeToString(error_code); + wrapped_->OnGoAway(last_accepted_stream_id, error_code); +} + +bool Http2TraceLogger::OnGoAwayFrameData(const char* goaway_data, size_t len) { + return wrapped_->OnGoAwayFrameData(goaway_data, len); +} + +void Http2TraceLogger::OnHeaders(SpdyStreamId stream_id, size_t payload_length, + bool has_priority, int weight, + SpdyStreamId parent_stream_id, bool exclusive, + bool fin, bool end) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnHeaders:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << FORMAT_ARG(payload_length) << FORMAT_ARG(has_priority) + << FORMAT_INT_ARG(weight) << FORMAT_ARG(parent_stream_id) + << FORMAT_ARG(exclusive) << FORMAT_ARG(fin) << FORMAT_ARG(end); + wrapped_->OnHeaders(stream_id, payload_length, has_priority, weight, + parent_stream_id, exclusive, fin, end); +} + +void Http2TraceLogger::OnWindowUpdate(SpdyStreamId stream_id, + int delta_window_size) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnWindowUpdate:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(delta_window_size); + wrapped_->OnWindowUpdate(stream_id, delta_window_size); +} + +void Http2TraceLogger::OnPushPromise(SpdyStreamId original_stream_id, + SpdyStreamId promised_stream_id, + bool end) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPushPromise:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(original_stream_id) << FORMAT_ARG(promised_stream_id) + << FORMAT_ARG(end); + wrapped_->OnPushPromise(original_stream_id, promised_stream_id, end); +} + +void Http2TraceLogger::OnContinuation(SpdyStreamId stream_id, + size_t payload_length, bool end) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnContinuation:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(payload_length) << FORMAT_ARG(end); + wrapped_->OnContinuation(stream_id, payload_length, end); +} + +void Http2TraceLogger::OnAltSvc( + SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnAltSvc:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << FORMAT_ARG(origin) << " altsvc_vector=" + << LogContainer(altsvc_vector, LogAlternativeService()); + wrapped_->OnAltSvc(stream_id, origin, altsvc_vector); +} + +void Http2TraceLogger::OnPriority(SpdyStreamId stream_id, + SpdyStreamId parent_stream_id, int weight, + bool exclusive) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPriority:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << FORMAT_ARG(parent_stream_id) << FORMAT_INT_ARG(weight) + << FORMAT_ARG(exclusive); + wrapped_->OnPriority(stream_id, parent_stream_id, weight, exclusive); +} + +void Http2TraceLogger::OnPriorityUpdate( + SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPriorityUpdate:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(prioritized_stream_id) << FORMAT_ARG(priority_field_value); + wrapped_->OnPriorityUpdate(prioritized_stream_id, priority_field_value); +} + +bool Http2TraceLogger::OnUnknownFrame(SpdyStreamId stream_id, + uint8_t frame_type) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnUnknownFrame:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_INT_ARG(frame_type); + return wrapped_->OnUnknownFrame(stream_id, frame_type); +} + +void Http2TraceLogger::OnUnknownFrameStart(spdy::SpdyStreamId stream_id, + size_t length, uint8_t type, + uint8_t flags) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnUnknownFrameStart:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(length) << FORMAT_INT_ARG(type) + << FORMAT_INT_ARG(flags); + wrapped_->OnUnknownFrameStart(stream_id, length, type, flags); +} + +void Http2TraceLogger::OnUnknownFramePayload(spdy::SpdyStreamId stream_id, + absl::string_view payload) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnUnknownFramePayload:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << " length=" << payload.size(); + wrapped_->OnUnknownFramePayload(stream_id, payload); +} + +void Http2TraceLogger::LogReceivedHeaders() const { + if (recording_headers_handler_ == nullptr) { + // Trace logging was not enabled when the start of the header block was + // received. + return; + } + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Received headers;" << FORMAT_ARG(connection_id_) << " keys/values:" + << recording_headers_handler_->decoded_block().DebugString() + << " compressed_bytes=" + << recording_headers_handler_->compressed_header_bytes() + << " uncompressed_bytes=" + << recording_headers_handler_->uncompressed_header_bytes(); +} + +void Http2FrameLogger::VisitRstStream(const SpdyRstStreamIR& rst_stream) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyRstStreamIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(rst_stream, stream_id) + << " error_code=" << spdy::ErrorCodeToString(rst_stream.error_code()); +} + +void Http2FrameLogger::VisitSettings(const SpdySettingsIR& settings) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdySettingsIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(settings, is_ack) + << " values=" << LogContainer(settings.values(), LogSettingsEntry()); +} + +void Http2FrameLogger::VisitPing(const SpdyPingIR& ping) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPingIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(ping, id) << FORMAT_ATTR(ping, is_ack); +} + +void Http2FrameLogger::VisitGoAway(const SpdyGoAwayIR& goaway) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyGoAwayIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(goaway, last_good_stream_id) + << " error_code=" << spdy::ErrorCodeToString(goaway.error_code()) + << FORMAT_ATTR(goaway, description); +} + +void Http2FrameLogger::VisitHeaders(const SpdyHeadersIR& headers) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyHeadersIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(headers, stream_id) << FORMAT_ATTR(headers, fin) + << FORMAT_ATTR(headers, has_priority) << FORMAT_INT_ATTR(headers, weight) + << FORMAT_ATTR(headers, parent_stream_id) + << FORMAT_ATTR(headers, exclusive) << FORMAT_ATTR(headers, padded) + << FORMAT_ATTR(headers, padding_payload_len) + << FORMAT_HEADER_BLOCK(headers); +} + +void Http2FrameLogger::VisitWindowUpdate( + const SpdyWindowUpdateIR& window_update) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyWindowUpdateIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(window_update, stream_id) + << FORMAT_ATTR(window_update, delta); +} + +void Http2FrameLogger::VisitPushPromise(const SpdyPushPromiseIR& push_promise) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPushPromiseIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(push_promise, stream_id) << FORMAT_ATTR(push_promise, fin) + << FORMAT_ATTR(push_promise, promised_stream_id) + << FORMAT_ATTR(push_promise, padded) + << FORMAT_ATTR(push_promise, padding_payload_len) + << FORMAT_HEADER_BLOCK(push_promise); +} + +void Http2FrameLogger::VisitContinuation( + const SpdyContinuationIR& continuation) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyContinuationIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(continuation, stream_id) + << FORMAT_ATTR(continuation, end_headers); +} + +void Http2FrameLogger::VisitAltSvc(const SpdyAltSvcIR& altsvc) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyAltSvcIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(altsvc, stream_id) << FORMAT_ATTR(altsvc, origin) + << " altsvc_vector=" + << LogContainer(altsvc.altsvc_vector(), LogAlternativeService()); +} + +void Http2FrameLogger::VisitPriority(const SpdyPriorityIR& priority) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPriorityIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(priority, stream_id) + << FORMAT_ATTR(priority, parent_stream_id) + << FORMAT_INT_ATTR(priority, weight) << FORMAT_ATTR(priority, exclusive); +} + +void Http2FrameLogger::VisitData(const SpdyDataIR& data) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyDataIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(data, stream_id) << FORMAT_ATTR(data, fin) + << " data_len=" << data.data_len() << FORMAT_ATTR(data, padded) + << FORMAT_ATTR(data, padding_payload_len); +} + +void Http2FrameLogger::VisitPriorityUpdate( + const spdy::SpdyPriorityUpdateIR& priority_update) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPriorityUpdateIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(priority_update, stream_id) + << FORMAT_ATTR(priority_update, prioritized_stream_id) + << FORMAT_ATTR(priority_update, priority_field_value); +} + +void Http2FrameLogger::VisitAcceptCh( + const spdy::SpdyAcceptChIR& /*accept_ch*/) { + QUICHE_BUG(bug_2794_2) + << "Sending ACCEPT_CH frames is currently unimplemented."; +} + +void Http2FrameLogger::VisitUnknown(const SpdyUnknownIR& ir) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyUnknownIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(ir, stream_id) << FORMAT_INT_ATTR(ir, type) + << FORMAT_INT_ATTR(ir, flags) << FORMAT_ATTR(ir, length); +} + +} // namespace http2 diff --git a/quiche/http2/core/http2_trace_logging.h b/quiche/http2/core/http2_trace_logging.h new file mode 100644 index 000000000000..f7218fe18f09 --- /dev/null +++ b/quiche/http2/core/http2_trace_logging.h @@ -0,0 +1,144 @@ +// Classes and utilities for supporting HTTP/2 trace logging, which logs +// information about all control and data frames sent and received over +// HTTP/2 connections. + +#ifndef QUICHE_HTTP2_CORE_HTTP2_TRACE_LOGGING_H_ +#define QUICHE_HTTP2_CORE_HTTP2_TRACE_LOGGING_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/recording_headers_handler.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" +#include "quiche/spdy/core/spdy_protocol.h" + +// Logging macro to use for all HTTP/2 trace logging. Iff trace logging is +// enabled, logs at level INFO with a common prefix prepended (to facilitate +// post-hoc filtering of trace logging output). +#define HTTP2_TRACE_LOG(perspective, is_enabled) \ + QUICHE_LOG_IF(INFO, is_enabled()) << "[HTTP2_TRACE " << perspective << "] " + +namespace http2 { + +// Intercepts deframing events to provide detailed logs. Intended to be used for +// manual debugging. +// +// Note any new methods in SpdyFramerVisitorInterface MUST be overridden here to +// properly forward the event. This could be ensured by making every event in +// SpdyFramerVisitorInterface a pure virtual. +class QUICHE_EXPORT Http2TraceLogger : public spdy::SpdyFramerVisitorInterface { + public: + typedef spdy::SpdyAltSvcWireFormat SpdyAltSvcWireFormat; + typedef spdy::SpdyErrorCode SpdyErrorCode; + typedef spdy::SpdyFramerVisitorInterface SpdyFramerVisitorInterface; + typedef spdy::SpdyPingId SpdyPingId; + typedef spdy::SpdyPriority SpdyPriority; + typedef spdy::SpdySettingsId SpdySettingsId; + typedef spdy::SpdyStreamId SpdyStreamId; + + Http2TraceLogger(SpdyFramerVisitorInterface* parent, + absl::string_view perspective, + std::function is_enabled, const void* connection_id); + ~Http2TraceLogger() override; + + Http2TraceLogger(const Http2TraceLogger&) = delete; + Http2TraceLogger& operator=(const Http2TraceLogger&) = delete; + + void OnError(http2::Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) override; + void OnCommonHeader(SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + spdy::SpdyHeadersHandlerInterface* OnHeaderFrameStart( + SpdyStreamId stream_id) override; + void OnHeaderFrameEnd(SpdyStreamId stream_id) override; + void OnDataFrameHeader(SpdyStreamId stream_id, size_t length, + bool fin) override; + void OnStreamFrameData(SpdyStreamId stream_id, const char* data, + size_t len) override; + void OnStreamEnd(SpdyStreamId stream_id) override; + void OnStreamPadLength(SpdyStreamId stream_id, size_t value) override; + void OnStreamPadding(SpdyStreamId stream_id, size_t len) override; + void OnRstStream(SpdyStreamId stream_id, SpdyErrorCode error_code) override; + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + void OnPing(SpdyPingId unique_id, bool is_ack) override; + void OnSettings() override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + void OnGoAway(SpdyStreamId last_accepted_stream_id, + SpdyErrorCode error_code) override; + bool OnGoAwayFrameData(const char* goaway_data, size_t len) override; + void OnHeaders(SpdyStreamId stream_id, size_t payload_length, + bool has_priority, int weight, SpdyStreamId parent_stream_id, + bool exclusive, bool fin, bool end) override; + void OnWindowUpdate(SpdyStreamId stream_id, int delta_window_size) override; + void OnPushPromise(SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + bool end) override; + void OnContinuation(SpdyStreamId stream_id, size_t payload_length, + bool end) override; + void OnAltSvc(SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& + altsvc_vector) override; + void OnPriority(SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive) override; + void OnPriorityUpdate(SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) override; + bool OnUnknownFrame(SpdyStreamId stream_id, uint8_t frame_type) override; + void OnUnknownFrameStart(SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnUnknownFramePayload(SpdyStreamId stream_id, + absl::string_view payload) override; + + private: + void LogReceivedHeaders() const; + + std::unique_ptr recording_headers_handler_; + + SpdyFramerVisitorInterface* wrapped_; + const absl::string_view perspective_; + const std::function is_enabled_; + const void* connection_id_; +}; + +// Visitor to log control frames that have been written. +class QUICHE_EXPORT Http2FrameLogger : public spdy::SpdyFrameVisitor { + public: + // This class will preface all of its log messages with the value of + // |connection_id| in hexadecimal. + Http2FrameLogger(absl::string_view perspective, + std::function is_enabled, const void* connection_id) + : perspective_(perspective), + is_enabled_(std::move(is_enabled)), + connection_id_(connection_id) {} + + Http2FrameLogger(const Http2FrameLogger&) = delete; + Http2FrameLogger& operator=(const Http2FrameLogger&) = delete; + + void VisitRstStream(const spdy::SpdyRstStreamIR& rst_stream) override; + void VisitSettings(const spdy::SpdySettingsIR& settings) override; + void VisitPing(const spdy::SpdyPingIR& ping) override; + void VisitGoAway(const spdy::SpdyGoAwayIR& goaway) override; + void VisitHeaders(const spdy::SpdyHeadersIR& headers) override; + void VisitWindowUpdate( + const spdy::SpdyWindowUpdateIR& window_update) override; + void VisitPushPromise(const spdy::SpdyPushPromiseIR& push_promise) override; + void VisitContinuation(const spdy::SpdyContinuationIR& continuation) override; + void VisitAltSvc(const spdy::SpdyAltSvcIR& altsvc) override; + void VisitPriority(const spdy::SpdyPriorityIR& priority) override; + void VisitData(const spdy::SpdyDataIR& data) override; + void VisitPriorityUpdate( + const spdy::SpdyPriorityUpdateIR& priority_update) override; + void VisitAcceptCh(const spdy::SpdyAcceptChIR& accept_ch) override; + void VisitUnknown(const spdy::SpdyUnknownIR& ir) override; + + private: + const absl::string_view perspective_; + const std::function is_enabled_; + const void* connection_id_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_CORE_HTTP2_TRACE_LOGGING_H_ diff --git a/quiche/http2/core/priority_write_scheduler.h b/quiche/http2/core/priority_write_scheduler.h new file mode 100644 index 000000000000..d63a316360fe --- /dev/null +++ b/quiche/http2/core/priority_write_scheduler.h @@ -0,0 +1,381 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_CORE_PRIORITY_WRITE_SCHEDULER_H_ +#define QUICHE_HTTP2_CORE_PRIORITY_WRITE_SCHEDULER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace http2 { + +namespace test { +template +class PriorityWriteSchedulerPeer; +} + +// SpdyPriority is an integer type, so this functor can be used both as +// PriorityTypeToInt and as IntToPriorityType. +struct QUICHE_EXPORT SpdyPriorityToSpdyPriority { + spdy::SpdyPriority operator()(spdy::SpdyPriority priority) { + return priority; + } +}; + +// PriorityWriteScheduler manages the order in which HTTP/2 or HTTP/3 streams +// are written. Each stream has a priority of type PriorityType. This includes +// an integer between 0 and 7, and optionally other information that is stored +// but otherwise ignored by this class. Higher priority (lower integer value) +// streams are always given precedence over lower priority (higher value) +// streams, as long as the higher priority stream is not blocked. +// +// Each stream can be in one of two states: ready or not ready (for writing). +// Ready state is changed by calling the MarkStreamReady() and +// MarkStreamNotReady() methods. Only streams in the ready state can be returned +// by PopNextReadyStream(). When returned by that method, the stream's state +// changes to not ready. +// +template +class QUICHE_EXPORT PriorityWriteScheduler { + public: + static constexpr int kHighestPriority = 0; + static constexpr int kLowestPriority = 7; + + static_assert(spdy::kV3HighestPriority == kHighestPriority); + static_assert(spdy::kV3LowestPriority == kLowestPriority); + + // Registers new stream `stream_id` with the scheduler, assigning it the + // given priority. + // + // Preconditions: `stream_id` should be unregistered. + void RegisterStream(StreamIdType stream_id, PriorityType priority) { + auto stream_info = std::make_unique( + StreamInfo{std::move(priority), stream_id, false}); + bool inserted = + stream_infos_.insert(std::make_pair(stream_id, std::move(stream_info))) + .second; + QUICHE_BUG_IF(spdy_bug_19_2, !inserted) + << "Stream " << stream_id << " already registered"; + } + + // Unregisters the given stream from the scheduler, which will no longer keep + // state for it. + // + // Preconditions: `stream_id` should be registered. + void UnregisterStream(StreamIdType stream_id) { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_BUG(spdy_bug_19_3) << "Stream " << stream_id << " not registered"; + return; + } + const StreamInfo* const stream_info = it->second.get(); + if (stream_info->ready) { + bool erased = + Erase(&priority_infos_[PriorityTypeToInt()(stream_info->priority)] + .ready_list, + stream_info); + QUICHE_DCHECK(erased); + } + stream_infos_.erase(it); + } + + // Returns true if the given stream is currently registered. + bool StreamRegistered(StreamIdType stream_id) const { + return stream_infos_.find(stream_id) != stream_infos_.end(); + } + + // Returns the priority of the specified stream. + // + // Preconditions: `stream_id` should be registered. + PriorityType GetStreamPriority(StreamIdType stream_id) const { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_DVLOG(1) << "Stream " << stream_id << " not registered"; + return IntToPriorityType()(kLowestPriority); + } + return it->second->priority; + } + + // Updates the priority of the given stream. + // + // Preconditions: `stream_id` should be registered. + void UpdateStreamPriority(StreamIdType stream_id, PriorityType priority) { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + // TODO(mpw): add to stream_infos_ on demand--see b/15676312. + QUICHE_DVLOG(1) << "Stream " << stream_id << " not registered"; + return; + } + + StreamInfo* const stream_info = it->second.get(); + if (stream_info->priority == priority) { + return; + } + + // Only move `stream_info` to a different bucket if the integral priority + // value changes. + if (PriorityTypeToInt()(stream_info->priority) != + PriorityTypeToInt()(priority) && + stream_info->ready) { + bool erased = + Erase(&priority_infos_[PriorityTypeToInt()(stream_info->priority)] + .ready_list, + stream_info); + QUICHE_DCHECK(erased); + priority_infos_[PriorityTypeToInt()(priority)].ready_list.push_back( + stream_info); + ++num_ready_streams_; + } + + // But override `priority` for the stream regardless of the integral value, + // because it might contain additional information. + stream_info->priority = std::move(priority); + } + + // Records time of a read/write event for the given stream. + // + // Preconditions: `stream_id` should be registered. + void RecordStreamEventTime(StreamIdType stream_id, absl::Time now) { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_BUG(spdy_bug_19_4) << "Stream " << stream_id << " not registered"; + return; + } + PriorityInfo& priority_info = + priority_infos_[PriorityTypeToInt()(it->second->priority)]; + priority_info.last_event_time = + std::max(priority_info.last_event_time, absl::make_optional(now)); + } + + // Returns time of the last read/write event for a stream with higher priority + // than the priority of the given stream, or nullopt if there is no such + // event. + // + // Preconditions: `stream_id` should be registered. + absl::optional GetLatestEventWithPriority( + StreamIdType stream_id) const { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_BUG(spdy_bug_19_5) << "Stream " << stream_id << " not registered"; + return absl::nullopt; + } + absl::optional last_event_time; + const StreamInfo* const stream_info = it->second.get(); + for (int p = kHighestPriority; + p < PriorityTypeToInt()(stream_info->priority); ++p) { + last_event_time = + std::max(last_event_time, priority_infos_[p].last_event_time); + } + return last_event_time; + } + + // If the scheduler has any ready streams, returns the next scheduled + // ready stream, in the process transitioning the stream from ready to not + // ready. + // + // Preconditions: `HasReadyStreams() == true` + StreamIdType PopNextReadyStream() { + return std::get<0>(PopNextReadyStreamAndPriority()); + } + + // If the scheduler has any ready streams, returns the next scheduled + // ready stream and its priority, in the process transitioning the stream from + // ready to not ready. + // + // Preconditions: `HasReadyStreams() == true` + std::tuple PopNextReadyStreamAndPriority() { + for (int p = kHighestPriority; p <= kLowestPriority; ++p) { + ReadyList& ready_list = priority_infos_[p].ready_list; + if (!ready_list.empty()) { + StreamInfo* const info = ready_list.front(); + ready_list.pop_front(); + --num_ready_streams_; + + QUICHE_DCHECK(stream_infos_.find(info->stream_id) != + stream_infos_.end()); + info->ready = false; + return std::make_tuple(info->stream_id, info->priority); + } + } + QUICHE_BUG(spdy_bug_19_6) << "No ready streams available"; + return std::make_tuple(0, IntToPriorityType()(kLowestPriority)); + } + + // Returns true if there's another stream ahead of the given stream in the + // scheduling queue. This function can be called to see if the given stream + // should yield work to another stream. + // + // Preconditions: `stream_id` should be registered. + bool ShouldYield(StreamIdType stream_id) const { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_BUG(spdy_bug_19_7) << "Stream " << stream_id << " not registered"; + return false; + } + + // If there's a higher priority stream, this stream should yield. + const StreamInfo* const stream_info = it->second.get(); + for (int p = kHighestPriority; + p < PriorityTypeToInt()(stream_info->priority); ++p) { + if (!priority_infos_[p].ready_list.empty()) { + return true; + } + } + + // If this priority level is empty, or this stream is the next up, there's + // no need to yield. + const auto& ready_list = + priority_infos_[PriorityTypeToInt()(it->second->priority)].ready_list; + if (ready_list.empty() || ready_list.front()->stream_id == stream_id) { + return false; + } + + // There are other streams in this priority level which take precedence. + // Yield. + return true; + } + + // Marks the stream as ready to write. If the stream was already ready, does + // nothing. If add_to_front is true, the stream is scheduled ahead of other + // streams of the same priority/weight, otherwise it is scheduled behind them. + // + // Preconditions: `stream_id` should be registered. + void MarkStreamReady(StreamIdType stream_id, bool add_to_front) { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_BUG(spdy_bug_19_8) << "Stream " << stream_id << " not registered"; + return; + } + StreamInfo* const stream_info = it->second.get(); + if (stream_info->ready) { + return; + } + ReadyList& ready_list = + priority_infos_[PriorityTypeToInt()(stream_info->priority)].ready_list; + if (add_to_front) { + ready_list.push_front(stream_info); + } else { + ready_list.push_back(stream_info); + } + ++num_ready_streams_; + stream_info->ready = true; + } + + // Marks the stream as not ready to write. If the stream is not registered or + // not ready, does nothing. + // + // Preconditions: `stream_id` should be registered. + void MarkStreamNotReady(StreamIdType stream_id) { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_BUG(spdy_bug_19_9) << "Stream " << stream_id << " not registered"; + return; + } + StreamInfo* const stream_info = it->second.get(); + if (!stream_info->ready) { + return; + } + bool erased = Erase( + &priority_infos_[PriorityTypeToInt()(stream_info->priority)].ready_list, + stream_info); + QUICHE_DCHECK(erased); + stream_info->ready = false; + } + + // Returns true iff the scheduler has any ready streams. + bool HasReadyStreams() const { return num_ready_streams_ > 0; } + + // Returns the number of streams currently marked ready. + size_t NumReadyStreams() const { return num_ready_streams_; } + + // Returns the number of registered streams. + size_t NumRegisteredStreams() const { return stream_infos_.size(); } + + // Returns summary of internal state, for logging/debugging. + std::string DebugString() const { + return absl::StrCat( + "PriorityWriteScheduler {num_streams=", stream_infos_.size(), + " num_ready_streams=", NumReadyStreams(), "}"); + } + + // Returns true if stream with `stream_id` is ready. + bool IsStreamReady(StreamIdType stream_id) const { + auto it = stream_infos_.find(stream_id); + if (it == stream_infos_.end()) { + QUICHE_DLOG(INFO) << "Stream " << stream_id << " not registered"; + return false; + } + return it->second->ready; + } + + private: + friend class test::PriorityWriteSchedulerPeer; + + // State kept for all registered streams. + // All ready streams have `ready == true` and should be present in + // `priority_infos_[priority].ready_list`. + struct QUICHE_EXPORT StreamInfo { + PriorityType priority; + StreamIdType stream_id; + bool ready; + }; + + // O(1) size lookup, O(1) insert at front or back (amortized). + using ReadyList = quiche::QuicheCircularDeque; + + // State kept for each priority level. + struct QUICHE_EXPORT PriorityInfo { + // IDs of streams that are ready to write. + ReadyList ready_list; + // Time of latest write event for stream of this priority. + absl::optional last_event_time; + }; + + // Use std::unique_ptr, because absl::flat_hash_map does not have pointer + // stability, but ReadyList stores pointers to the StreamInfo objects. + using StreamInfoMap = + absl::flat_hash_map>; + + // Erases `info` from `ready_list`, returning true if found (and erased), or + // false otherwise. Decrements `num_ready_streams_` if an entry is erased. + bool Erase(ReadyList* ready_list, const StreamInfo* info) { + auto it = std::remove(ready_list->begin(), ready_list->end(), info); + if (it == ready_list->end()) { + // `info` was not found. + return false; + } + ready_list->pop_back(); + --num_ready_streams_; + return true; + } + + // Number of ready streams. + size_t num_ready_streams_ = 0; + // Per-priority state, including ready lists. + PriorityInfo priority_infos_[kLowestPriority + 1]; + // StreamInfos for all registered streams. + StreamInfoMap stream_infos_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_CORE_PRIORITY_WRITE_SCHEDULER_H_ diff --git a/quiche/http2/core/priority_write_scheduler_test.cc b/quiche/http2/core/priority_write_scheduler_test.cc new file mode 100644 index 000000000000..96c4bd2fbe40 --- /dev/null +++ b/quiche/http2/core/priority_write_scheduler_test.cc @@ -0,0 +1,344 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/core/priority_write_scheduler.h" + +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +namespace http2 { +namespace test { + +using ::spdy::kHttp2RootStreamId; +using ::spdy::SpdyPriority; +using ::spdy::SpdyStreamId; +using ::testing::Eq; +using ::testing::Optional; + +template +class PriorityWriteSchedulerPeer { + public: + explicit PriorityWriteSchedulerPeer( + PriorityWriteScheduler* scheduler) + : scheduler_(scheduler) {} + + size_t NumReadyStreams(SpdyPriority priority) const { + return scheduler_->priority_infos_[priority].ready_list.size(); + } + + private: + PriorityWriteScheduler* scheduler_; +}; + +namespace { + +class PriorityWriteSchedulerTest : public quiche::test::QuicheTest { + public: + static constexpr int kLowestPriority = + PriorityWriteScheduler::kLowestPriority; + + PriorityWriteSchedulerTest() : peer_(&scheduler_) {} + + PriorityWriteScheduler scheduler_; + PriorityWriteSchedulerPeer peer_; +}; + +TEST_F(PriorityWriteSchedulerTest, RegisterUnregisterStreams) { + EXPECT_FALSE(scheduler_.HasReadyStreams()); + EXPECT_FALSE(scheduler_.StreamRegistered(1)); + EXPECT_EQ(0u, scheduler_.NumRegisteredStreams()); + scheduler_.RegisterStream(1, 1); + EXPECT_TRUE(scheduler_.StreamRegistered(1)); + EXPECT_EQ(1u, scheduler_.NumRegisteredStreams()); + + // Try redundant registrations. + EXPECT_QUICHE_BUG(scheduler_.RegisterStream(1, 1), + "Stream 1 already registered"); + EXPECT_EQ(1u, scheduler_.NumRegisteredStreams()); + + EXPECT_QUICHE_BUG(scheduler_.RegisterStream(1, 2), + "Stream 1 already registered"); + EXPECT_EQ(1u, scheduler_.NumRegisteredStreams()); + + scheduler_.RegisterStream(2, 3); + EXPECT_EQ(2u, scheduler_.NumRegisteredStreams()); + + // Verify registration != ready. + EXPECT_FALSE(scheduler_.HasReadyStreams()); + + scheduler_.UnregisterStream(1); + EXPECT_EQ(1u, scheduler_.NumRegisteredStreams()); + scheduler_.UnregisterStream(2); + EXPECT_EQ(0u, scheduler_.NumRegisteredStreams()); + + // Try redundant unregistration. + EXPECT_QUICHE_BUG(scheduler_.UnregisterStream(1), "Stream 1 not registered"); + EXPECT_QUICHE_BUG(scheduler_.UnregisterStream(2), "Stream 2 not registered"); + EXPECT_EQ(0u, scheduler_.NumRegisteredStreams()); +} + +TEST_F(PriorityWriteSchedulerTest, GetStreamPriority) { + // Unknown streams tolerated due to b/15676312. However, return lowest + // priority. + EXPECT_EQ(kLowestPriority, scheduler_.GetStreamPriority(1)); + + scheduler_.RegisterStream(1, 3); + EXPECT_EQ(3, scheduler_.GetStreamPriority(1)); + + // Redundant registration shouldn't change stream priority. + EXPECT_QUICHE_BUG(scheduler_.RegisterStream(1, 4), + "Stream 1 already registered"); + EXPECT_EQ(3, scheduler_.GetStreamPriority(1)); + + scheduler_.UpdateStreamPriority(1, 5); + EXPECT_EQ(5, scheduler_.GetStreamPriority(1)); + + // Toggling ready state shouldn't change stream priority. + scheduler_.MarkStreamReady(1, true); + EXPECT_EQ(5, scheduler_.GetStreamPriority(1)); + + // Test changing priority of ready stream. + EXPECT_EQ(1u, peer_.NumReadyStreams(5)); + scheduler_.UpdateStreamPriority(1, 6); + EXPECT_EQ(6, scheduler_.GetStreamPriority(1)); + EXPECT_EQ(0u, peer_.NumReadyStreams(5)); + EXPECT_EQ(1u, peer_.NumReadyStreams(6)); + + EXPECT_EQ(1u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(6, scheduler_.GetStreamPriority(1)); + + scheduler_.UnregisterStream(1); + EXPECT_EQ(kLowestPriority, scheduler_.GetStreamPriority(1)); +} + +TEST_F(PriorityWriteSchedulerTest, PopNextReadyStreamAndPriority) { + scheduler_.RegisterStream(1, 3); + scheduler_.MarkStreamReady(1, true); + EXPECT_EQ(std::make_tuple(1u, 3), scheduler_.PopNextReadyStreamAndPriority()); + scheduler_.UnregisterStream(1); +} + +TEST_F(PriorityWriteSchedulerTest, UpdateStreamPriority) { + // For the moment, updating stream priority on a non-registered stream should + // have no effect. In the future, it will lazily cause the stream to be + // registered (b/15676312). + EXPECT_EQ(kLowestPriority, scheduler_.GetStreamPriority(3)); + EXPECT_FALSE(scheduler_.StreamRegistered(3)); + scheduler_.UpdateStreamPriority(3, 1); + EXPECT_FALSE(scheduler_.StreamRegistered(3)); + EXPECT_EQ(kLowestPriority, scheduler_.GetStreamPriority(3)); + + scheduler_.RegisterStream(3, 1); + EXPECT_EQ(1, scheduler_.GetStreamPriority(3)); + scheduler_.UpdateStreamPriority(3, 2); + EXPECT_EQ(2, scheduler_.GetStreamPriority(3)); + + // Updating priority of stream to current priority value is valid, but has no + // effect. + scheduler_.UpdateStreamPriority(3, 2); + EXPECT_EQ(2, scheduler_.GetStreamPriority(3)); + + // Even though stream 4 is marked ready after stream 5, it should be returned + // first by PopNextReadyStream() since it has higher priority. + scheduler_.RegisterStream(4, 1); + scheduler_.MarkStreamReady(3, false); // priority 2 + EXPECT_TRUE(scheduler_.IsStreamReady(3)); + scheduler_.MarkStreamReady(4, false); // priority 1 + EXPECT_TRUE(scheduler_.IsStreamReady(4)); + EXPECT_EQ(4u, scheduler_.PopNextReadyStream()); + EXPECT_FALSE(scheduler_.IsStreamReady(4)); + EXPECT_EQ(3u, scheduler_.PopNextReadyStream()); + EXPECT_FALSE(scheduler_.IsStreamReady(3)); + + // Verify that lowering priority of stream 4 causes it to be returned later + // by PopNextReadyStream(). + scheduler_.MarkStreamReady(3, false); // priority 2 + scheduler_.MarkStreamReady(4, false); // priority 1 + scheduler_.UpdateStreamPriority(4, 3); + EXPECT_EQ(3u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(4u, scheduler_.PopNextReadyStream()); + + scheduler_.UnregisterStream(3); +} + +TEST_F(PriorityWriteSchedulerTest, MarkStreamReadyBack) { + EXPECT_FALSE(scheduler_.HasReadyStreams()); + EXPECT_QUICHE_BUG(scheduler_.MarkStreamReady(1, false), + "Stream 1 not registered"); + EXPECT_FALSE(scheduler_.HasReadyStreams()); + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); + + // Add a bunch of ready streams to tail of per-priority lists. + // Expected order: (P2) 4, (P3) 1, 2, 3, (P5) 5. + scheduler_.RegisterStream(1, 3); + scheduler_.MarkStreamReady(1, false); + EXPECT_TRUE(scheduler_.HasReadyStreams()); + scheduler_.RegisterStream(2, 3); + scheduler_.MarkStreamReady(2, false); + scheduler_.RegisterStream(3, 3); + scheduler_.MarkStreamReady(3, false); + scheduler_.RegisterStream(4, 2); + scheduler_.MarkStreamReady(4, false); + scheduler_.RegisterStream(5, 5); + scheduler_.MarkStreamReady(5, false); + + EXPECT_EQ(4u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(1u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(2u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(3u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(5u, scheduler_.PopNextReadyStream()); + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); +} + +TEST_F(PriorityWriteSchedulerTest, MarkStreamReadyFront) { + EXPECT_FALSE(scheduler_.HasReadyStreams()); + EXPECT_QUICHE_BUG(scheduler_.MarkStreamReady(1, true), + "Stream 1 not registered"); + EXPECT_FALSE(scheduler_.HasReadyStreams()); + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); + + // Add a bunch of ready streams to head of per-priority lists. + // Expected order: (P2) 4, (P3) 3, 2, 1, (P5) 5 + scheduler_.RegisterStream(1, 3); + scheduler_.MarkStreamReady(1, true); + EXPECT_TRUE(scheduler_.HasReadyStreams()); + scheduler_.RegisterStream(2, 3); + scheduler_.MarkStreamReady(2, true); + scheduler_.RegisterStream(3, 3); + scheduler_.MarkStreamReady(3, true); + scheduler_.RegisterStream(4, 2); + scheduler_.MarkStreamReady(4, true); + scheduler_.RegisterStream(5, 5); + scheduler_.MarkStreamReady(5, true); + + EXPECT_EQ(4u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(3u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(2u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(1u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(5u, scheduler_.PopNextReadyStream()); + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); +} + +TEST_F(PriorityWriteSchedulerTest, MarkStreamReadyBackAndFront) { + scheduler_.RegisterStream(1, 4); + scheduler_.RegisterStream(2, 3); + scheduler_.RegisterStream(3, 3); + scheduler_.RegisterStream(4, 3); + scheduler_.RegisterStream(5, 4); + scheduler_.RegisterStream(6, 1); + + // Add a bunch of ready streams to per-priority lists, with variety of adding + // at head and tail. + // Expected order: (P1) 6, (P3) 4, 2, 3, (P4) 1, 5 + scheduler_.MarkStreamReady(1, true); + scheduler_.MarkStreamReady(2, true); + scheduler_.MarkStreamReady(3, false); + scheduler_.MarkStreamReady(4, true); + scheduler_.MarkStreamReady(5, false); + scheduler_.MarkStreamReady(6, true); + + EXPECT_EQ(6u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(4u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(2u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(3u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(1u, scheduler_.PopNextReadyStream()); + EXPECT_EQ(5u, scheduler_.PopNextReadyStream()); + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); +} + +TEST_F(PriorityWriteSchedulerTest, MarkStreamNotReady) { + // Verify ready state reflected in NumReadyStreams(). + scheduler_.RegisterStream(1, 1); + EXPECT_EQ(0u, scheduler_.NumReadyStreams()); + scheduler_.MarkStreamReady(1, false); + EXPECT_EQ(1u, scheduler_.NumReadyStreams()); + scheduler_.MarkStreamNotReady(1); + EXPECT_EQ(0u, scheduler_.NumReadyStreams()); + + // Empty pop should fail. + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); + + // Tolerate redundant marking of a stream as not ready. + scheduler_.MarkStreamNotReady(1); + EXPECT_EQ(0u, scheduler_.NumReadyStreams()); + + // Should only be able to mark registered streams. + EXPECT_QUICHE_BUG(scheduler_.MarkStreamNotReady(3), + "Stream 3 not registered"); +} + +TEST_F(PriorityWriteSchedulerTest, UnregisterRemovesStream) { + scheduler_.RegisterStream(3, 4); + scheduler_.MarkStreamReady(3, false); + EXPECT_EQ(1u, scheduler_.NumReadyStreams()); + + // Unregistering a stream should remove it from set of ready streams. + scheduler_.UnregisterStream(3); + EXPECT_EQ(0u, scheduler_.NumReadyStreams()); + EXPECT_QUICHE_BUG(EXPECT_EQ(0u, scheduler_.PopNextReadyStream()), + "No ready streams available"); +} + +TEST_F(PriorityWriteSchedulerTest, ShouldYield) { + scheduler_.RegisterStream(1, 1); + scheduler_.RegisterStream(4, 4); + scheduler_.RegisterStream(5, 4); + scheduler_.RegisterStream(7, 7); + + // Make sure we don't yield when the list is empty. + EXPECT_FALSE(scheduler_.ShouldYield(1)); + + // Add a low priority stream. + scheduler_.MarkStreamReady(4, false); + // 4 should not yield to itself. + EXPECT_FALSE(scheduler_.ShouldYield(4)); + // 7 should yield as 4 is blocked and a higher priority. + EXPECT_TRUE(scheduler_.ShouldYield(7)); + // 5 should yield to 4 as they are the same priority. + EXPECT_TRUE(scheduler_.ShouldYield(5)); + // 1 should not yield as 1 is higher priority. + EXPECT_FALSE(scheduler_.ShouldYield(1)); + + // Add a second stream in that priority class. + scheduler_.MarkStreamReady(5, false); + // 4 and 5 are both blocked, but 4 is at the front so should not yield. + EXPECT_FALSE(scheduler_.ShouldYield(4)); + EXPECT_TRUE(scheduler_.ShouldYield(5)); +} + +TEST_F(PriorityWriteSchedulerTest, GetLatestEventWithPriority) { + EXPECT_QUICHE_BUG( + scheduler_.RecordStreamEventTime(3, absl::FromUnixMicros(5)), + "Stream 3 not registered"); + EXPECT_QUICHE_BUG( + EXPECT_FALSE(scheduler_.GetLatestEventWithPriority(4).has_value()), + "Stream 4 not registered"); + + for (int i = 1; i < 5; ++i) { + scheduler_.RegisterStream(i, i); + } + for (int i = 1; i < 5; ++i) { + EXPECT_FALSE(scheduler_.GetLatestEventWithPriority(i).has_value()); + } + for (int i = 1; i < 5; ++i) { + scheduler_.RecordStreamEventTime(i, absl::FromUnixMicros(i * 100)); + } + EXPECT_FALSE(scheduler_.GetLatestEventWithPriority(1).has_value()); + for (int i = 2; i < 5; ++i) { + EXPECT_THAT(scheduler_.GetLatestEventWithPriority(i), + Optional(Eq(absl::FromUnixMicros((i - 1) * 100)))); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/decode_buffer.cc b/quiche/http2/decoder/decode_buffer.cc new file mode 100644 index 000000000000..1d13ca085892 --- /dev/null +++ b/quiche/http2/decoder/decode_buffer.cc @@ -0,0 +1,93 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/decode_buffer.h" + +namespace http2 { + +uint8_t DecodeBuffer::DecodeUInt8() { + return static_cast(DecodeChar()); +} + +uint16_t DecodeBuffer::DecodeUInt16() { + QUICHE_DCHECK_LE(2u, Remaining()); + const uint8_t b1 = DecodeUInt8(); + const uint8_t b2 = DecodeUInt8(); + // Note that chars are automatically promoted to ints during arithmetic, + // so the b1 << 8 doesn't end up as zero before being or-ed with b2. + // And the left-shift operator has higher precedence than the or operator. + return b1 << 8 | b2; +} + +uint32_t DecodeBuffer::DecodeUInt24() { + QUICHE_DCHECK_LE(3u, Remaining()); + const uint8_t b1 = DecodeUInt8(); + const uint8_t b2 = DecodeUInt8(); + const uint8_t b3 = DecodeUInt8(); + return b1 << 16 | b2 << 8 | b3; +} + +uint32_t DecodeBuffer::DecodeUInt31() { + QUICHE_DCHECK_LE(4u, Remaining()); + const uint8_t b1 = DecodeUInt8() & 0x7f; // Mask out the high order bit. + const uint8_t b2 = DecodeUInt8(); + const uint8_t b3 = DecodeUInt8(); + const uint8_t b4 = DecodeUInt8(); + return b1 << 24 | b2 << 16 | b3 << 8 | b4; +} + +uint32_t DecodeBuffer::DecodeUInt32() { + QUICHE_DCHECK_LE(4u, Remaining()); + const uint8_t b1 = DecodeUInt8(); + const uint8_t b2 = DecodeUInt8(); + const uint8_t b3 = DecodeUInt8(); + const uint8_t b4 = DecodeUInt8(); + return b1 << 24 | b2 << 16 | b3 << 8 | b4; +} + +#ifndef NDEBUG +void DecodeBuffer::set_subset_of_base(DecodeBuffer* base, + const DecodeBufferSubset* subset) { + QUICHE_DCHECK_EQ(this, subset); + base->set_subset(subset); +} +void DecodeBuffer::clear_subset_of_base(DecodeBuffer* base, + const DecodeBufferSubset* subset) { + QUICHE_DCHECK_EQ(this, subset); + base->clear_subset(subset); +} +void DecodeBuffer::set_subset(const DecodeBufferSubset* subset) { + QUICHE_DCHECK(subset != nullptr); + QUICHE_DCHECK_EQ(subset_, nullptr) << "There is already a subset"; + subset_ = subset; +} +void DecodeBuffer::clear_subset(const DecodeBufferSubset* subset) { + QUICHE_DCHECK(subset != nullptr); + QUICHE_DCHECK_EQ(subset_, subset); + subset_ = nullptr; +} +void DecodeBufferSubset::DebugSetup() { + start_base_offset_ = base_buffer_->Offset(); + max_base_offset_ = start_base_offset_ + FullSize(); + QUICHE_DCHECK_LE(max_base_offset_, base_buffer_->FullSize()); + + // Ensure that there is only one DecodeBufferSubset at a time for a base. + set_subset_of_base(base_buffer_, this); +} +void DecodeBufferSubset::DebugTearDown() { + // Ensure that the base hasn't been modified. + QUICHE_DCHECK_EQ(start_base_offset_, base_buffer_->Offset()) + << "The base buffer was modified"; + + // Ensure that we haven't gone beyond the maximum allowed offset. + size_t offset = Offset(); + QUICHE_DCHECK_LE(offset, FullSize()); + QUICHE_DCHECK_LE(start_base_offset_ + offset, max_base_offset_); + QUICHE_DCHECK_LE(max_base_offset_, base_buffer_->FullSize()); + + clear_subset_of_base(base_buffer_, this); +} +#endif + +} // namespace http2 diff --git a/quiche/http2/decoder/decode_buffer.h b/quiche/http2/decoder/decode_buffer.h new file mode 100644 index 000000000000..5c9262f36951 --- /dev/null +++ b/quiche/http2/decoder/decode_buffer.h @@ -0,0 +1,171 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_DECODE_BUFFER_H_ +#define QUICHE_HTTP2_DECODER_DECODE_BUFFER_H_ + +// DecodeBuffer provides primitives for decoding various integer types found in +// HTTP/2 frames. It wraps a byte array from which we can read and decode +// serialized HTTP/2 frames, or parts thereof. DecodeBuffer is intended only for +// stack allocation, where the caller is typically going to use the DecodeBuffer +// instance as part of decoding the entire buffer before returning to its own +// caller. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +class DecodeBufferSubset; + +class QUICHE_EXPORT DecodeBuffer { + public: + // We assume the decode buffers will typically be modest in size (i.e. often a + // few KB, perhaps as high as 100KB). Let's make sure during testing that we + // don't go very high, with 32MB selected rather arbitrarily. This is exposed + // to support testing. + static constexpr size_t kMaxDecodeBufferLength = 1 << 25; + + DecodeBuffer(const char* buffer, size_t len) + : buffer_(buffer), cursor_(buffer), beyond_(buffer + len) { + QUICHE_DCHECK(buffer != nullptr); + QUICHE_DCHECK_LE(len, kMaxDecodeBufferLength); + } + explicit DecodeBuffer(absl::string_view s) + : DecodeBuffer(s.data(), s.size()) {} + // Constructor for character arrays, typically in tests. For example: + // const char input[] = { 0x11 }; + // DecodeBuffer b(input); + template + explicit DecodeBuffer(const char (&buf)[N]) : DecodeBuffer(buf, N) {} + + DecodeBuffer(const DecodeBuffer&) = delete; + DecodeBuffer operator=(const DecodeBuffer&) = delete; + + bool Empty() const { return cursor_ >= beyond_; } + bool HasData() const { return cursor_ < beyond_; } + size_t Remaining() const { + QUICHE_DCHECK_LE(cursor_, beyond_); + return beyond_ - cursor_; + } + size_t Offset() const { return cursor_ - buffer_; } + size_t FullSize() const { return beyond_ - buffer_; } + + // Returns the minimum of the number of bytes remaining in this DecodeBuffer + // and |length|, in support of determining how much of some structure/payload + // is in this DecodeBuffer. + size_t MinLengthRemaining(size_t length) const { + return std::min(length, Remaining()); + } + + // For string decoding, returns a pointer to the next byte/char to be decoded. + const char* cursor() const { return cursor_; } + // Advances the cursor (pointer to the next byte/char to be decoded). + void AdvanceCursor(size_t amount) { + QUICHE_DCHECK_LE(amount, + Remaining()); // Need at least that much remaining. + QUICHE_DCHECK_EQ(subset_, nullptr) + << "Access via subset only when present."; + cursor_ += amount; + } + + // Only call methods starting "Decode" when there is enough input remaining. + char DecodeChar() { + QUICHE_DCHECK_LE(1u, Remaining()); // Need at least one byte remaining. + QUICHE_DCHECK_EQ(subset_, nullptr) + << "Access via subset only when present."; + return *cursor_++; + } + + uint8_t DecodeUInt8(); + uint16_t DecodeUInt16(); + uint32_t DecodeUInt24(); + + // For 31-bit unsigned integers, where the 32nd bit is reserved for future + // use (i.e. the high-bit of the first byte of the encoding); examples: + // the Stream Id in a frame header or the Window Size Increment in a + // WINDOW_UPDATE frame. + uint32_t DecodeUInt31(); + + uint32_t DecodeUInt32(); + + protected: +#ifndef NDEBUG + // These are part of validating during tests that there is at most one + // DecodeBufferSubset instance at a time for any DecodeBuffer instance. + void set_subset_of_base(DecodeBuffer* base, const DecodeBufferSubset* subset); + void clear_subset_of_base(DecodeBuffer* base, + const DecodeBufferSubset* subset); +#endif + + private: +#ifndef NDEBUG + void set_subset(const DecodeBufferSubset* subset); + void clear_subset(const DecodeBufferSubset* subset); +#endif + + // Prevent heap allocation of DecodeBuffer. + static void* operator new(size_t s); + static void* operator new[](size_t s); + static void operator delete(void* p); + static void operator delete[](void* p); + + const char* const buffer_; + const char* cursor_; + const char* const beyond_; + const DecodeBufferSubset* subset_ = nullptr; // Used for QUICHE_DCHECKs. +}; + +// DecodeBufferSubset is used when decoding a known sized chunk of data, which +// starts at base->cursor(), and continues for subset_len, which may be +// entirely in |base|, or may extend beyond it (hence the MinLengthRemaining +// in the constructor). +// There are two benefits to using DecodeBufferSubset: it ensures that the +// cursor of |base| is advanced when the subset's destructor runs, and it +// ensures that the consumer of the subset can't go beyond the subset which +// it is intended to decode. +// There must be only a single DecodeBufferSubset at a time for a base +// DecodeBuffer, though they can be nested (i.e. a DecodeBufferSubset's +// base may itself be a DecodeBufferSubset). This avoids the AdvanceCursor +// being called erroneously. +class QUICHE_EXPORT DecodeBufferSubset : public DecodeBuffer { + public: + DecodeBufferSubset(DecodeBuffer* base, size_t subset_len) + : DecodeBuffer(base->cursor(), base->MinLengthRemaining(subset_len)), + base_buffer_(base) { +#ifndef NDEBUG + DebugSetup(); +#endif + } + + DecodeBufferSubset(const DecodeBufferSubset&) = delete; + DecodeBufferSubset operator=(const DecodeBufferSubset&) = delete; + + ~DecodeBufferSubset() { + size_t offset = Offset(); +#ifndef NDEBUG + DebugTearDown(); +#endif + base_buffer_->AdvanceCursor(offset); + } + + private: + DecodeBuffer* const base_buffer_; +#ifndef NDEBUG + size_t start_base_offset_; // Used for QUICHE_DCHECKs. + size_t max_base_offset_; // Used for QUICHE_DCHECKs. + + void DebugSetup(); + void DebugTearDown(); +#endif +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_DECODE_BUFFER_H_ diff --git a/quiche/http2/decoder/decode_buffer_test.cc b/quiche/http2/decoder/decode_buffer_test.cc new file mode 100644 index 000000000000..7eaae79f7bdc --- /dev/null +++ b/quiche/http2/decoder/decode_buffer_test.cc @@ -0,0 +1,203 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/decode_buffer.h" + +#include + +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +enum class TestEnumClass32 { + kValue1 = 1, + kValue99 = 99, + kValue1M = 1000000, +}; + +enum class TestEnumClass8 { + kValue1 = 1, + kValue2 = 1, + kValue99 = 99, + kValue255 = 255, +}; + +enum TestEnum8 { + kMaskLo = 0x01, + kMaskHi = 0x80, +}; + +struct TestStruct { + uint8_t f1; + uint16_t f2; + uint32_t f3; // Decoded as a uint24 + uint32_t f4; + uint32_t f5; // Decoded as if uint31 + TestEnumClass32 f6; + TestEnumClass8 f7; + TestEnum8 f8; +}; + +class DecodeBufferTest : public quiche::test::QuicheTest { + protected: + Http2Random random_; + uint32_t decode_offset_; +}; + +TEST_F(DecodeBufferTest, DecodesFixedInts) { + const char data[] = "\x01\x12\x23\x34\x45\x56\x67\x78\x89\x9a"; + DecodeBuffer b1(data, strlen(data)); + EXPECT_EQ(1, b1.DecodeUInt8()); + EXPECT_EQ(0x1223u, b1.DecodeUInt16()); + EXPECT_EQ(0x344556u, b1.DecodeUInt24()); + EXPECT_EQ(0x6778899Au, b1.DecodeUInt32()); +} + +// Make sure that DecodeBuffer is not copying input, just pointing into +// provided input buffer. +TEST_F(DecodeBufferTest, HasNotCopiedInput) { + const char data[] = "ab"; + DecodeBuffer b1(data, 2); + + EXPECT_EQ(2u, b1.Remaining()); + EXPECT_EQ(0u, b1.Offset()); + EXPECT_FALSE(b1.Empty()); + EXPECT_EQ(data, b1.cursor()); // cursor points to input buffer + EXPECT_TRUE(b1.HasData()); + + b1.AdvanceCursor(1); + + EXPECT_EQ(1u, b1.Remaining()); + EXPECT_EQ(1u, b1.Offset()); + EXPECT_FALSE(b1.Empty()); + EXPECT_EQ(&data[1], b1.cursor()); + EXPECT_TRUE(b1.HasData()); + + b1.AdvanceCursor(1); + + EXPECT_EQ(0u, b1.Remaining()); + EXPECT_EQ(2u, b1.Offset()); + EXPECT_TRUE(b1.Empty()); + EXPECT_EQ(&data[2], b1.cursor()); + EXPECT_FALSE(b1.HasData()); + + DecodeBuffer b2(data, 0); + + EXPECT_EQ(0u, b2.Remaining()); + EXPECT_EQ(0u, b2.Offset()); + EXPECT_TRUE(b2.Empty()); + EXPECT_EQ(data, b2.cursor()); + EXPECT_FALSE(b2.HasData()); +} + +// DecodeBufferSubset can't go beyond the end of the base buffer. +TEST_F(DecodeBufferTest, DecodeBufferSubsetLimited) { + const char data[] = "abc"; + DecodeBuffer base(data, 3); + base.AdvanceCursor(1); + DecodeBufferSubset subset(&base, 100); + EXPECT_EQ(2u, subset.FullSize()); +} + +// DecodeBufferSubset advances the cursor of its base upon destruction. +TEST_F(DecodeBufferTest, DecodeBufferSubsetAdvancesCursor) { + const char data[] = "abc"; + const size_t size = sizeof(data) - 1; + EXPECT_EQ(3u, size); + DecodeBuffer base(data, size); + { + // First no change to the cursor. + DecodeBufferSubset subset(&base, size + 100); + EXPECT_EQ(size, subset.FullSize()); + EXPECT_EQ(base.FullSize(), subset.FullSize()); + EXPECT_EQ(0u, subset.Offset()); + } + EXPECT_EQ(0u, base.Offset()); + EXPECT_EQ(size, base.Remaining()); +} + +// Make sure that DecodeBuffer ctor complains about bad args. +#if GTEST_HAS_DEATH_TEST && !defined(NDEBUG) +TEST(DecodeBufferDeathTest, NonNullBufferRequired) { + EXPECT_QUICHE_DEBUG_DEATH({ DecodeBuffer b(nullptr, 3); }, "nullptr"); +} + +// Make sure that DecodeBuffer ctor complains about bad args. +TEST(DecodeBufferDeathTest, ModestBufferSizeRequired) { + EXPECT_QUICHE_DEBUG_DEATH( + { + constexpr size_t kLength = DecodeBuffer::kMaxDecodeBufferLength + 1; + auto data = std::make_unique(kLength); + DecodeBuffer b(data.get(), kLength); + }, + "Max.*Length"); +} + +// Make sure that DecodeBuffer detects advance beyond end, in debug mode. +TEST(DecodeBufferDeathTest, LimitedAdvance) { + { + // Advance right up to end is OK. + const char data[] = "abc"; + DecodeBuffer b(data, 3); + b.AdvanceCursor(3); // OK + EXPECT_TRUE(b.Empty()); + } + EXPECT_QUICHE_DEBUG_DEATH( + { + // Going beyond is not OK. + const char data[] = "abc"; + DecodeBuffer b(data, 3); + b.AdvanceCursor(4); + }, + "Remaining"); +} + +// Make sure that DecodeBuffer detects decode beyond end, in debug mode. +TEST(DecodeBufferDeathTest, DecodeUInt8PastEnd) { + const char data[] = {0x12, 0x23}; + DecodeBuffer b(data, sizeof data); + EXPECT_EQ(2u, b.FullSize()); + EXPECT_EQ(0x1223, b.DecodeUInt16()); + EXPECT_QUICHE_DEBUG_DEATH({ b.DecodeUInt8(); }, "Remaining"); +} + +// Make sure that DecodeBuffer detects decode beyond end, in debug mode. +TEST(DecodeBufferDeathTest, DecodeUInt16OverEnd) { + const char data[] = {0x12, 0x23, 0x34}; + DecodeBuffer b(data, sizeof data); + EXPECT_EQ(3u, b.FullSize()); + EXPECT_EQ(0x1223, b.DecodeUInt16()); + EXPECT_QUICHE_DEBUG_DEATH({ b.DecodeUInt16(); }, "Remaining"); +} + +// Make sure that DecodeBuffer doesn't agree with having two subsets. +TEST(DecodeBufferSubsetDeathTest, TwoSubsets) { + const char data[] = "abc"; + DecodeBuffer base(data, 3); + DecodeBufferSubset subset1(&base, 1); + EXPECT_QUICHE_DEBUG_DEATH({ DecodeBufferSubset subset2(&base, 1); }, + "There is already a subset"); +} + +// Make sure that DecodeBufferSubset notices when the base's cursor has moved. +TEST(DecodeBufferSubsetDeathTest, BaseCursorAdvanced) { + const char data[] = "abc"; + DecodeBuffer base(data, 3); + base.AdvanceCursor(1); + EXPECT_QUICHE_DEBUG_DEATH( + { + DecodeBufferSubset subset1(&base, 2); + base.AdvanceCursor(1); + }, + "Access via subset only when present"); +} +#endif // GTEST_HAS_DEATH_TEST && !defined(NDEBUG) + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/decode_http2_structures.cc b/quiche/http2/decoder/decode_http2_structures.cc new file mode 100644 index 000000000000..1730a21e520f --- /dev/null +++ b/quiche/http2/decoder/decode_http2_structures.cc @@ -0,0 +1,121 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/decode_http2_structures.h" + +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +// Http2FrameHeader decoding: + +void DoDecode(Http2FrameHeader* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2FrameHeader::EncodedSize(), b->Remaining()); + out->payload_length = b->DecodeUInt24(); + out->type = static_cast(b->DecodeUInt8()); + out->flags = static_cast(b->DecodeUInt8()); + out->stream_id = b->DecodeUInt31(); +} + +// Http2PriorityFields decoding: + +void DoDecode(Http2PriorityFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2PriorityFields::EncodedSize(), b->Remaining()); + uint32_t stream_id_and_flag = b->DecodeUInt32(); + out->stream_dependency = stream_id_and_flag & StreamIdMask(); + if (out->stream_dependency == stream_id_and_flag) { + out->is_exclusive = false; + } else { + out->is_exclusive = true; + } + // Note that chars are automatically promoted to ints during arithmetic, + // so 255 + 1 doesn't end up as zero. + out->weight = b->DecodeUInt8() + 1; +} + +// Http2RstStreamFields decoding: + +void DoDecode(Http2RstStreamFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2RstStreamFields::EncodedSize(), b->Remaining()); + out->error_code = static_cast(b->DecodeUInt32()); +} + +// Http2SettingFields decoding: + +void DoDecode(Http2SettingFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2SettingFields::EncodedSize(), b->Remaining()); + out->parameter = static_cast(b->DecodeUInt16()); + out->value = b->DecodeUInt32(); +} + +// Http2PushPromiseFields decoding: + +void DoDecode(Http2PushPromiseFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2PushPromiseFields::EncodedSize(), b->Remaining()); + out->promised_stream_id = b->DecodeUInt31(); +} + +// Http2PingFields decoding: + +void DoDecode(Http2PingFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2PingFields::EncodedSize(), b->Remaining()); + memcpy(out->opaque_bytes, b->cursor(), Http2PingFields::EncodedSize()); + b->AdvanceCursor(Http2PingFields::EncodedSize()); +} + +// Http2GoAwayFields decoding: + +void DoDecode(Http2GoAwayFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2GoAwayFields::EncodedSize(), b->Remaining()); + out->last_stream_id = b->DecodeUInt31(); + out->error_code = static_cast(b->DecodeUInt32()); +} + +// Http2WindowUpdateFields decoding: + +void DoDecode(Http2WindowUpdateFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2WindowUpdateFields::EncodedSize(), b->Remaining()); + out->window_size_increment = b->DecodeUInt31(); +} + +// Http2PriorityUpdateFields decoding: + +void DoDecode(Http2PriorityUpdateFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2PriorityUpdateFields::EncodedSize(), b->Remaining()); + out->prioritized_stream_id = b->DecodeUInt31(); +} + +// Http2AltSvcFields decoding: + +void DoDecode(Http2AltSvcFields* out, DecodeBuffer* b) { + QUICHE_DCHECK_NE(nullptr, out); + QUICHE_DCHECK_NE(nullptr, b); + QUICHE_DCHECK_LE(Http2AltSvcFields::EncodedSize(), b->Remaining()); + out->origin_length = b->DecodeUInt16(); +} + +} // namespace http2 diff --git a/quiche/http2/decoder/decode_http2_structures.h b/quiche/http2/decoder/decode_http2_structures.h new file mode 100644 index 000000000000..9740497f4272 --- /dev/null +++ b/quiche/http2/decoder/decode_http2_structures.h @@ -0,0 +1,33 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_DECODE_HTTP2_STRUCTURES_H_ +#define QUICHE_HTTP2_DECODER_DECODE_HTTP2_STRUCTURES_H_ + +// Provides functions for decoding the fixed size structures in the HTTP/2 spec. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// DoDecode(STRUCTURE* out, DecodeBuffer* b) decodes the structure from start +// to end, advancing the cursor by STRUCTURE::EncodedSize(). The decode buffer +// must be large enough (i.e. b->Remaining() >= STRUCTURE::EncodedSize()). + +QUICHE_EXPORT void DoDecode(Http2FrameHeader* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2PriorityFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2RstStreamFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2SettingFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2PushPromiseFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2PingFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2GoAwayFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2WindowUpdateFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2AltSvcFields* out, DecodeBuffer* b); +QUICHE_EXPORT void DoDecode(Http2PriorityUpdateFields* out, DecodeBuffer* b); + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_DECODE_HTTP2_STRUCTURES_H_ diff --git a/quiche/http2/decoder/decode_http2_structures_test.cc b/quiche/http2/decoder/decode_http2_structures_test.cc new file mode 100644 index 000000000000..61413f3afc29 --- /dev/null +++ b/quiche/http2/decoder/decode_http2_structures_test.cc @@ -0,0 +1,460 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/decode_http2_structures.h" + +// Tests decoding all of the fixed size HTTP/2 structures (i.e. those defined +// in quiche/http2/http2_structures.h). + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +template +absl::string_view ToStringPiece(T (&data)[N]) { + return absl::string_view(reinterpret_cast(data), N * sizeof(T)); +} + +template +std::string SerializeStructure(const S& s) { + Http2FrameBuilder fb; + fb.Append(s); + EXPECT_EQ(S::EncodedSize(), fb.size()); + return fb.buffer(); +} + +template +class StructureDecoderTest : public quiche::test::QuicheTest { + protected: + typedef S Structure; + + StructureDecoderTest() : random_(), random_decode_count_(100) {} + + // Set the fields of |*p| to random values. + void Randomize(S* p) { ::http2::test::Randomize(p, &random_); } + + // Fully decodes the Structure at the start of data, and confirms it matches + // *expected (if provided). + void DecodeLeadingStructure(const S* expected, absl::string_view data) { + ASSERT_LE(S::EncodedSize(), data.size()); + DecodeBuffer db(data); + Randomize(&structure_); + DoDecode(&structure_, &db); + EXPECT_EQ(db.Offset(), S::EncodedSize()); + if (expected != nullptr) { + EXPECT_EQ(structure_, *expected); + } + } + + template + void DecodeLeadingStructure(const char (&data)[N]) { + DecodeLeadingStructure(nullptr, absl::string_view(data, N)); + } + + // Encode the structure |in_s| into bytes, then decode the bytes + // and validate that the decoder produced the same field values. + void EncodeThenDecode(const S& in_s) { + std::string bytes = SerializeStructure(in_s); + EXPECT_EQ(S::EncodedSize(), bytes.size()); + DecodeLeadingStructure(&in_s, bytes); + } + + // Generate + void TestDecodingRandomizedStructures(size_t count) { + for (size_t i = 0; i < count && !HasFailure(); ++i) { + Structure input; + Randomize(&input); + EncodeThenDecode(input); + } + } + + void TestDecodingRandomizedStructures() { + TestDecodingRandomizedStructures(random_decode_count_); + } + + Http2Random random_; + const size_t random_decode_count_; + uint32_t decode_offset_ = 0; + S structure_; + size_t fast_decode_count_ = 0; + size_t slow_decode_count_ = 0; +}; + +class FrameHeaderDecoderTest : public StructureDecoderTest {}; + +TEST_F(FrameHeaderDecoderTest, DecodesLiteral) { + { + // Realistic input. + const char kData[] = { + '\x00', '\x00', '\x05', // Payload length: 5 + '\x01', // Frame type: HEADERS + '\x08', // Flags: PADDED + '\x00', '\x00', '\x00', '\x01', // Stream ID: 1 + '\x04', // Padding length: 4 + '\x00', '\x00', '\x00', '\x00', // Padding bytes + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(5u, structure_.payload_length); + EXPECT_EQ(Http2FrameType::HEADERS, structure_.type); + EXPECT_EQ(Http2FrameFlag::PADDED, structure_.flags); + EXPECT_EQ(1u, structure_.stream_id); + } + } + { + // Unlikely input. + const char kData[] = { + '\xff', '\xff', '\xff', // Payload length: uint24 max + '\xff', // Frame type: Unknown + '\xff', // Flags: Unknown/All + '\xff', '\xff', '\xff', '\xff', // Stream ID: uint31 max, plus R-bit + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ((1u << 24) - 1, structure_.payload_length); + EXPECT_EQ(static_cast(255), structure_.type); + EXPECT_EQ(255, structure_.flags); + EXPECT_EQ(0x7FFFFFFFu, structure_.stream_id); + } + } +} + +TEST_F(FrameHeaderDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class PriorityFieldsDecoderTest + : public StructureDecoderTest {}; + +TEST_F(PriorityFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x80', '\x00', '\x00', '\x05', // Exclusive (yes) and Dependency (5) + '\xff', // Weight: 256 (after adding 1) + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(5u, structure_.stream_dependency); + EXPECT_EQ(256u, structure_.weight); + EXPECT_EQ(true, structure_.is_exclusive); + } + } + { + const char kData[] = { + '\x7f', '\xff', + '\xff', '\xff', // Exclusive (no) and Dependency (0x7fffffff) + '\x00', // Weight: 1 (after adding 1) + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(StreamIdMask(), structure_.stream_dependency); + EXPECT_EQ(1u, structure_.weight); + EXPECT_FALSE(structure_.is_exclusive); + } + } +} + +TEST_F(PriorityFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class RstStreamFieldsDecoderTest + : public StructureDecoderTest {}; + +TEST_F(RstStreamFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x00', '\x00', '\x00', '\x01', // Error: PROTOCOL_ERROR + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_TRUE(structure_.IsSupportedErrorCode()); + EXPECT_EQ(Http2ErrorCode::PROTOCOL_ERROR, structure_.error_code); + } + } + { + const char kData[] = { + '\xff', '\xff', '\xff', + '\xff', // Error: max uint32 (Unknown error code) + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_FALSE(structure_.IsSupportedErrorCode()); + EXPECT_EQ(static_cast(0xffffffff), structure_.error_code); + } + } +} + +TEST_F(RstStreamFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class SettingFieldsDecoderTest + : public StructureDecoderTest {}; + +TEST_F(SettingFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x00', '\x01', // Setting: HEADER_TABLE_SIZE + '\x00', '\x00', '\x40', '\x00', // Value: 16K + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_TRUE(structure_.IsSupportedParameter()); + EXPECT_EQ(Http2SettingsParameter::HEADER_TABLE_SIZE, + structure_.parameter); + EXPECT_EQ(1u << 14, structure_.value); + } + } + { + const char kData[] = { + '\x00', '\x00', // Setting: Unknown (0) + '\xff', '\xff', '\xff', '\xff', // Value: max uint32 + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_FALSE(structure_.IsSupportedParameter()); + EXPECT_EQ(static_cast(0), structure_.parameter); + } + } +} + +TEST_F(SettingFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class PushPromiseFieldsDecoderTest + : public StructureDecoderTest {}; + +TEST_F(PushPromiseFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x00', '\x01', '\x8a', '\x92', // Promised Stream ID: 101010 + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(101010u, structure_.promised_stream_id); + } + } + { + // Promised stream id has R-bit (reserved for future use) set, which + // should be cleared by the decoder. + const char kData[] = { + '\xff', '\xff', '\xff', + '\xff', // Promised Stream ID: max uint31 and R-bit + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(StreamIdMask(), structure_.promised_stream_id); + } + } +} + +TEST_F(PushPromiseFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class PingFieldsDecoderTest : public StructureDecoderTest {}; + +TEST_F(PingFieldsDecoderTest, DecodesLiteral) { + { + // Each byte is different, so can detect if order changed. + const char kData[] = { + '\x00', '\x01', '\x02', '\x03', '\x04', '\x05', '\x06', '\x07', + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(absl::string_view(kData, 8), + ToStringPiece(structure_.opaque_bytes)); + } + } + { + // All zeros, detect problems handling NULs. + const char kData[] = { + '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(absl::string_view(kData, 8), + ToStringPiece(structure_.opaque_bytes)); + } + } + { + const char kData[] = { + '\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff', + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(absl::string_view(kData, 8), + ToStringPiece(structure_.opaque_bytes)); + } + } +} + +TEST_F(PingFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class GoAwayFieldsDecoderTest : public StructureDecoderTest { +}; + +TEST_F(GoAwayFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x00', '\x00', '\x00', '\x00', // Last Stream ID: 0 + '\x00', '\x00', '\x00', '\x00', // Error: NO_ERROR (0) + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(0u, structure_.last_stream_id); + EXPECT_TRUE(structure_.IsSupportedErrorCode()); + EXPECT_EQ(Http2ErrorCode::HTTP2_NO_ERROR, structure_.error_code); + } + } + { + const char kData[] = { + '\x00', '\x00', '\x00', '\x01', // Last Stream ID: 1 + '\x00', '\x00', '\x00', '\x0d', // Error: HTTP_1_1_REQUIRED + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(1u, structure_.last_stream_id); + EXPECT_TRUE(structure_.IsSupportedErrorCode()); + EXPECT_EQ(Http2ErrorCode::HTTP_1_1_REQUIRED, structure_.error_code); + } + } + { + const char kData[] = { + '\xff', '\xff', + '\xff', '\xff', // Last Stream ID: max uint31 and R-bit + '\xff', '\xff', + '\xff', '\xff', // Error: max uint32 (Unknown error code) + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(StreamIdMask(), structure_.last_stream_id); // No high-bit. + EXPECT_FALSE(structure_.IsSupportedErrorCode()); + EXPECT_EQ(static_cast(0xffffffff), structure_.error_code); + } + } +} + +TEST_F(GoAwayFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class WindowUpdateFieldsDecoderTest + : public StructureDecoderTest {}; + +TEST_F(WindowUpdateFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x00', '\x01', '\x00', '\x00', // Window Size Increment: 2 ^ 16 + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(1u << 16, structure_.window_size_increment); + } + } + { + // Increment must be non-zero, but we need to be able to decode the invalid + // zero to detect it. + const char kData[] = { + '\x00', '\x00', '\x00', '\x00', // Window Size Increment: 0 + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(0u, structure_.window_size_increment); + } + } + { + // Increment has R-bit (reserved for future use) set, which + // should be cleared by the decoder. + // clang-format off + const char kData[] = { + // Window Size Increment: max uint31 and R-bit + '\xff', '\xff', '\xff', '\xff', + }; + // clang-format on + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(StreamIdMask(), structure_.window_size_increment); + } + } +} + +TEST_F(WindowUpdateFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +//------------------------------------------------------------------------------ + +class AltSvcFieldsDecoderTest : public StructureDecoderTest { +}; + +TEST_F(AltSvcFieldsDecoderTest, DecodesLiteral) { + { + const char kData[] = { + '\x00', '\x00', // Origin Length: 0 + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(0, structure_.origin_length); + } + } + { + const char kData[] = { + '\x00', '\x14', // Origin Length: 20 + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(20, structure_.origin_length); + } + } + { + const char kData[] = { + '\xff', '\xff', // Origin Length: uint16 max + }; + DecodeLeadingStructure(kData); + if (!HasFailure()) { + EXPECT_EQ(65535, structure_.origin_length); + } + } +} + +TEST_F(AltSvcFieldsDecoderTest, DecodesRandomized) { + TestDecodingRandomizedStructures(); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/decode_status.cc b/quiche/http2/decoder/decode_status.cc new file mode 100644 index 000000000000..c5887ad4b24c --- /dev/null +++ b/quiche/http2/decoder/decode_status.cc @@ -0,0 +1,28 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/decode_status.h" + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, DecodeStatus v) { + switch (v) { + case DecodeStatus::kDecodeDone: + return out << "DecodeDone"; + case DecodeStatus::kDecodeInProgress: + return out << "DecodeInProgress"; + case DecodeStatus::kDecodeError: + return out << "DecodeError"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_147_1) << "Unknown DecodeStatus " << unknown; + return out << "DecodeStatus(" << unknown << ")"; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/decode_status.h b/quiche/http2/decoder/decode_status.h new file mode 100644 index 000000000000..7571938af486 --- /dev/null +++ b/quiche/http2/decoder/decode_status.h @@ -0,0 +1,32 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_DECODE_STATUS_H_ +#define QUICHE_HTTP2_DECODER_DECODE_STATUS_H_ + +// Enum DecodeStatus is used to report the status of decoding of many +// types of HTTP/2 and HPACK objects. + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +enum class DecodeStatus { + // Decoding is done. + kDecodeDone, + + // Decoder needs more input to be able to make progress. + kDecodeInProgress, + + // Decoding failed (e.g. HPACK variable length integer is too large, or + // an HTTP/2 frame has padding declared to be larger than the payload). + kDecodeError, +}; +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, DecodeStatus v); + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_DECODE_STATUS_H_ diff --git a/quiche/http2/decoder/frame_decoder_state.cc b/quiche/http2/decoder/frame_decoder_state.cc new file mode 100644 index 000000000000..fe535a54c9ec --- /dev/null +++ b/quiche/http2/decoder/frame_decoder_state.cc @@ -0,0 +1,80 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/frame_decoder_state.h" + +namespace http2 { + +DecodeStatus FrameDecoderState::ReadPadLength(DecodeBuffer* db, + bool report_pad_length) { + QUICHE_DVLOG(2) << "ReadPadLength db->Remaining=" << db->Remaining() + << "; payload_length=" << frame_header().payload_length; + QUICHE_DCHECK(IsPaddable()); + QUICHE_DCHECK(frame_header().IsPadded()); + + // Pad Length is always at the start of the frame, so remaining_payload_ + // should equal payload_length at this point. + const uint32_t total_payload = frame_header().payload_length; + QUICHE_DCHECK_EQ(total_payload, remaining_payload_); + QUICHE_DCHECK_EQ(0u, remaining_padding_); + + if (db->HasData()) { + const uint32_t pad_length = db->DecodeUInt8(); + const uint32_t total_padding = pad_length + 1; + if (total_padding <= total_payload) { + remaining_padding_ = pad_length; + remaining_payload_ = total_payload - total_padding; + if (report_pad_length) { + listener()->OnPadLength(pad_length); + } + return DecodeStatus::kDecodeDone; + } + const uint32_t missing_length = total_padding - total_payload; + // To allow for the possibility of recovery, record the number of + // remaining bytes of the frame's payload (invalid though it is) + // in remaining_payload_. + remaining_payload_ = total_payload - 1; // 1 for sizeof(Pad Length). + remaining_padding_ = 0; + listener()->OnPaddingTooLong(frame_header(), missing_length); + return DecodeStatus::kDecodeError; + } + + if (total_payload == 0) { + remaining_payload_ = 0; + remaining_padding_ = 0; + listener()->OnPaddingTooLong(frame_header(), 1); + return DecodeStatus::kDecodeError; + } + // Need to wait for another buffer. + return DecodeStatus::kDecodeInProgress; +} + +bool FrameDecoderState::SkipPadding(DecodeBuffer* db) { + QUICHE_DVLOG(2) << "SkipPadding remaining_padding_=" << remaining_padding_ + << ", db->Remaining=" << db->Remaining() + << ", header: " << frame_header(); + QUICHE_DCHECK_EQ(remaining_payload_, 0u); + QUICHE_DCHECK(IsPaddable()) << "header: " << frame_header(); + QUICHE_DCHECK(remaining_padding_ == 0 || frame_header().IsPadded()) + << "remaining_padding_=" << remaining_padding_ + << ", header: " << frame_header(); + const size_t avail = AvailablePadding(db); + if (avail > 0) { + listener()->OnPadding(db->cursor(), avail); + db->AdvanceCursor(avail); + remaining_padding_ -= avail; + } + return remaining_padding_ == 0; +} + +DecodeStatus FrameDecoderState::ReportFrameSizeError() { + QUICHE_DVLOG(2) << "FrameDecoderState::ReportFrameSizeError: " + << " remaining_payload_=" << remaining_payload_ + << "; remaining_padding_=" << remaining_padding_ + << ", header: " << frame_header(); + listener()->OnFrameSizeError(frame_header()); + return DecodeStatus::kDecodeError; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/frame_decoder_state.h b/quiche/http2/decoder/frame_decoder_state.h new file mode 100644 index 000000000000..39757da89a9b --- /dev/null +++ b/quiche/http2/decoder/frame_decoder_state.h @@ -0,0 +1,252 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_FRAME_DECODER_STATE_H_ +#define QUICHE_HTTP2_DECODER_FRAME_DECODER_STATE_H_ + +// FrameDecoderState provides state and behaviors in support of decoding +// the common frame header and the payload of all frame types. +// It is an input to all of the payload decoders. + +// TODO(jamessynge): Since FrameDecoderState has far more than state in it, +// rename to FrameDecoderHelper, or similar. + +#include + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/decoder/http2_structure_decoder.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { +class FrameDecoderStatePeer; +} // namespace test + +class QUICHE_EXPORT FrameDecoderState { + public: + FrameDecoderState() {} + + // Sets the listener which the decoders should call as they decode HTTP/2 + // frames. The listener can be changed at any time, which allows for replacing + // it with a no-op listener when an error is detected, either by the payload + // decoder (OnPaddingTooLong or OnFrameSizeError) or by the "real" listener. + // That in turn allows us to define Http2FrameDecoderListener such that all + // methods have return type void, with no direct way to indicate whether the + // decoder should stop, and to eliminate from the decoder all checks of the + // return value. Instead the listener/caller can simply replace the current + // listener with a no-op listener implementation. + // TODO(jamessynge): Make set_listener private as only Http2FrameDecoder + // and tests need to set it, so it doesn't need to be public. + void set_listener(Http2FrameDecoderListener* listener) { + listener_ = listener; + } + Http2FrameDecoderListener* listener() const { return listener_; } + + // The most recently decoded frame header. + const Http2FrameHeader& frame_header() const { return frame_header_; } + + // Decode a structure in the payload, adjusting remaining_payload_ to account + // for the consumed portion of the payload. Returns kDecodeDone when fully + // decoded, kDecodeError if it ran out of payload before decoding completed, + // and kDecodeInProgress if the decode buffer didn't have enough of the + // remaining payload. + template + DecodeStatus StartDecodingStructureInPayload(S* out, DecodeBuffer* db) { + QUICHE_DVLOG(2) << __func__ << "\n\tdb->Remaining=" << db->Remaining() + << "\n\tremaining_payload_=" << remaining_payload_ + << "\n\tneed=" << S::EncodedSize(); + DecodeStatus status = + structure_decoder_.Start(out, db, &remaining_payload_); + if (status != DecodeStatus::kDecodeError) { + return status; + } + QUICHE_DVLOG(2) + << "StartDecodingStructureInPayload: detected frame size error"; + return ReportFrameSizeError(); + } + + // Resume decoding of a structure that has been split across buffers, + // adjusting remaining_payload_ to account for the consumed portion of + // the payload. Returns values are as for StartDecodingStructureInPayload. + template + DecodeStatus ResumeDecodingStructureInPayload(S* out, DecodeBuffer* db) { + QUICHE_DVLOG(2) << __func__ << "\n\tdb->Remaining=" << db->Remaining() + << "\n\tremaining_payload_=" << remaining_payload_; + if (structure_decoder_.Resume(out, db, &remaining_payload_)) { + return DecodeStatus::kDecodeDone; + } else if (remaining_payload_ > 0) { + return DecodeStatus::kDecodeInProgress; + } else { + QUICHE_DVLOG(2) + << "ResumeDecodingStructureInPayload: detected frame size error"; + return ReportFrameSizeError(); + } + } + + // Initializes the two remaining* fields, which is needed if the frame's + // payload is split across buffers, or the decoder calls ReadPadLength or + // StartDecodingStructureInPayload, and of course the methods below which + // read those fields, as their names imply. + void InitializeRemainders() { + remaining_payload_ = frame_header().payload_length; + // Note that remaining_total_payload() relies on remaining_padding_ being + // zero for frames that have no padding. + remaining_padding_ = 0; + } + + // Returns the number of bytes of the frame's payload that remain to be + // decoded, including any trailing padding. This method must only be called + // after the variables have been initialized, which in practice means once a + // payload decoder has called InitializeRemainders and/or ReadPadLength. + size_t remaining_total_payload() const { + QUICHE_DCHECK(IsPaddable() || remaining_padding_ == 0) << frame_header(); + return remaining_payload_ + remaining_padding_; + } + + // Returns the number of bytes of the frame's payload that remain to be + // decoded, excluding any trailing padding. This method must only be called + // after the variable has been initialized, which in practice means once a + // payload decoder has called InitializeRemainders; ReadPadLength will deduct + // the total number of padding bytes from remaining_payload_, including the + // size of the Pad Length field itself (1 byte). + size_t remaining_payload() const { return remaining_payload_; } + + // Returns the number of bytes of the frame's payload that remain to be + // decoded, including any trailing padding. This method must only be called if + // the frame type allows padding, and after the variable has been initialized, + // which in practice means once a payload decoder has called + // InitializeRemainders and/or ReadPadLength. + size_t remaining_payload_and_padding() const { + QUICHE_DCHECK(IsPaddable()) << frame_header(); + return remaining_payload_ + remaining_padding_; + } + + // Returns the number of bytes of trailing padding after the payload that + // remain to be decoded. This method must only be called if the frame type + // allows padding, and after the variable has been initialized, which in + // practice means once a payload decoder has called InitializeRemainders, + // and isn't set to a non-zero value until ReadPadLength has been called. + uint32_t remaining_padding() const { + QUICHE_DCHECK(IsPaddable()) << frame_header(); + return remaining_padding_; + } + + // How many bytes of the remaining payload are in db? + size_t AvailablePayload(DecodeBuffer* db) const { + return db->MinLengthRemaining(remaining_payload_); + } + + // How many bytes of the remaining payload and padding are in db? + // Call only for frames whose type is paddable. + size_t AvailablePayloadAndPadding(DecodeBuffer* db) const { + QUICHE_DCHECK(IsPaddable()) << frame_header(); + return db->MinLengthRemaining(remaining_payload_ + remaining_padding_); + } + + // How many bytes of the padding that have not yet been skipped are in db? + // Call only after remaining_padding_ has been set (for padded frames), or + // been cleared (for unpadded frames); and after all of the non-padding + // payload has been decoded. + size_t AvailablePadding(DecodeBuffer* db) const { + QUICHE_DCHECK(IsPaddable()) << frame_header(); + QUICHE_DCHECK_EQ(remaining_payload_, 0u); + return db->MinLengthRemaining(remaining_padding_); + } + + // Reduces remaining_payload_ by amount. To be called by a payload decoder + // after it has passed a variable length portion of the payload to the + // listener; remaining_payload_ will be automatically reduced when fixed + // size structures and padding, including the Pad Length field, are decoded. + void ConsumePayload(size_t amount) { + QUICHE_DCHECK_LE(amount, remaining_payload_); + remaining_payload_ -= amount; + } + + // Reads the Pad Length field into remaining_padding_, and appropriately sets + // remaining_payload_. When present, the Pad Length field is always the first + // field in the payload, which this method relies on so that the caller need + // not set remaining_payload_ before calling this method. + // If report_pad_length is true, calls the listener's OnPadLength method when + // it decodes the Pad Length field. + // Returns kDecodeDone if the decode buffer was not empty (i.e. because the + // field is only a single byte long, it can always be decoded if the buffer is + // not empty). + // Returns kDecodeError if the buffer is empty because the frame has no + // payload (i.e. payload_length() == 0). + // Returns kDecodeInProgress if the buffer is empty but the frame has a + // payload. + DecodeStatus ReadPadLength(DecodeBuffer* db, bool report_pad_length); + + // Skip the trailing padding bytes; only call once remaining_payload_==0. + // Returns true when the padding has been skipped. + // Does NOT check that the padding is all zeroes. + bool SkipPadding(DecodeBuffer* db); + + // Calls the listener's OnFrameSizeError method and returns kDecodeError. + DecodeStatus ReportFrameSizeError(); + + private: + friend class Http2FrameDecoder; + friend class test::FrameDecoderStatePeer; + + // Starts the decoding of a common frame header. Returns true if completed the + // decoding, false if the decode buffer didn't have enough data in it, in + // which case the decode buffer will have been drained and the caller should + // call ResumeDecodingFrameHeader when more data is available. This is called + // from Http2FrameDecoder, a friend class. + bool StartDecodingFrameHeader(DecodeBuffer* db) { + return structure_decoder_.Start(&frame_header_, db); + } + + // Resumes decoding the common frame header after the preceding call to + // StartDecodingFrameHeader returned false, as did any subsequent calls to + // ResumeDecodingFrameHeader. This is called from Http2FrameDecoder, + // a friend class. + bool ResumeDecodingFrameHeader(DecodeBuffer* db) { + return structure_decoder_.Resume(&frame_header_, db); + } + + // Clear any of the flags in the frame header that aren't set in valid_flags. + void RetainFlags(uint8_t valid_flags) { + frame_header_.RetainFlags(valid_flags); + } + + // Clear all of the flags in the frame header; for use with frame types that + // don't define any flags, such as WINDOW_UPDATE. + void ClearFlags() { frame_header_.flags = Http2FrameFlag(); } + + // Returns true if the type of frame being decoded can have padding. + bool IsPaddable() const { + return frame_header().type == Http2FrameType::DATA || + frame_header().type == Http2FrameType::HEADERS || + frame_header().type == Http2FrameType::PUSH_PROMISE; + } + + Http2FrameDecoderListener* listener_ = nullptr; + Http2FrameHeader frame_header_; + + // Number of bytes remaining to be decoded, if set; does not include the + // trailing padding once the length of padding has been determined. + // See ReadPadLength. + uint32_t remaining_payload_; + + // The amount of trailing padding after the payload that remains to be + // decoded. See ReadPadLength. + uint32_t remaining_padding_; + + // Generic decoder of structures, which takes care of buffering the needed + // bytes if the encoded structure is split across decode buffers. + Http2StructureDecoder structure_decoder_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_FRAME_DECODER_STATE_H_ diff --git a/quiche/http2/decoder/http2_frame_decoder.cc b/quiche/http2/decoder/http2_frame_decoder.cc new file mode 100644 index 000000000000..14b34308ff71 --- /dev/null +++ b/quiche/http2/decoder/http2_frame_decoder.cc @@ -0,0 +1,456 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/http2_frame_decoder.h" + +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, Http2FrameDecoder::State v) { + switch (v) { + case Http2FrameDecoder::State::kStartDecodingHeader: + return out << "kStartDecodingHeader"; + case Http2FrameDecoder::State::kResumeDecodingHeader: + return out << "kResumeDecodingHeader"; + case Http2FrameDecoder::State::kResumeDecodingPayload: + return out << "kResumeDecodingPayload"; + case Http2FrameDecoder::State::kDiscardPayload: + return out << "kDiscardPayload"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_155_1) << "Http2FrameDecoder::State " << unknown; + return out << "Http2FrameDecoder::State(" << unknown << ")"; +} + +Http2FrameDecoder::Http2FrameDecoder(Http2FrameDecoderListener* listener) + : state_(State::kStartDecodingHeader), + maximum_payload_size_(Http2SettingsInfo::DefaultMaxFrameSize()) { + set_listener(listener); +} + +void Http2FrameDecoder::set_listener(Http2FrameDecoderListener* listener) { + if (listener == nullptr) { + listener = &no_op_listener_; + } + frame_decoder_state_.set_listener(listener); +} + +Http2FrameDecoderListener* Http2FrameDecoder::listener() const { + return frame_decoder_state_.listener(); +} + +DecodeStatus Http2FrameDecoder::DecodeFrame(DecodeBuffer* db) { + QUICHE_DVLOG(2) << "Http2FrameDecoder::DecodeFrame state=" << state_; + switch (state_) { + case State::kStartDecodingHeader: + if (frame_decoder_state_.StartDecodingFrameHeader(db)) { + return StartDecodingPayload(db); + } + state_ = State::kResumeDecodingHeader; + return DecodeStatus::kDecodeInProgress; + + case State::kResumeDecodingHeader: + if (frame_decoder_state_.ResumeDecodingFrameHeader(db)) { + return StartDecodingPayload(db); + } + return DecodeStatus::kDecodeInProgress; + + case State::kResumeDecodingPayload: + return ResumeDecodingPayload(db); + + case State::kDiscardPayload: + return DiscardPayload(db); + } + + QUICHE_NOTREACHED(); + return DecodeStatus::kDecodeError; +} + +size_t Http2FrameDecoder::remaining_payload() const { + return frame_decoder_state_.remaining_payload(); +} + +uint32_t Http2FrameDecoder::remaining_padding() const { + return frame_decoder_state_.remaining_padding(); +} + +DecodeStatus Http2FrameDecoder::StartDecodingPayload(DecodeBuffer* db) { + const Http2FrameHeader& header = frame_header(); + + // TODO(jamessynge): Remove OnFrameHeader once done with supporting + // SpdyFramer's exact states. + if (!listener()->OnFrameHeader(header)) { + QUICHE_DVLOG(2) + << "OnFrameHeader rejected the frame, will discard; header: " << header; + state_ = State::kDiscardPayload; + frame_decoder_state_.InitializeRemainders(); + return DecodeStatus::kDecodeError; + } + + if (header.payload_length > maximum_payload_size_) { + QUICHE_DVLOG(2) << "Payload length is greater than allowed: " + << header.payload_length << " > " << maximum_payload_size_ + << "\n header: " << header; + state_ = State::kDiscardPayload; + frame_decoder_state_.InitializeRemainders(); + listener()->OnFrameSizeError(header); + return DecodeStatus::kDecodeError; + } + + // The decode buffer can extend across many frames. Make sure that the + // buffer we pass to the start method that is specific to the frame type + // does not exend beyond this frame. + DecodeBufferSubset subset(db, header.payload_length); + DecodeStatus status; + switch (header.type) { + case Http2FrameType::DATA: + status = StartDecodingDataPayload(&subset); + break; + + case Http2FrameType::HEADERS: + status = StartDecodingHeadersPayload(&subset); + break; + + case Http2FrameType::PRIORITY: + status = StartDecodingPriorityPayload(&subset); + break; + + case Http2FrameType::RST_STREAM: + status = StartDecodingRstStreamPayload(&subset); + break; + + case Http2FrameType::SETTINGS: + status = StartDecodingSettingsPayload(&subset); + break; + + case Http2FrameType::PUSH_PROMISE: + status = StartDecodingPushPromisePayload(&subset); + break; + + case Http2FrameType::PING: + status = StartDecodingPingPayload(&subset); + break; + + case Http2FrameType::GOAWAY: + status = StartDecodingGoAwayPayload(&subset); + break; + + case Http2FrameType::WINDOW_UPDATE: + status = StartDecodingWindowUpdatePayload(&subset); + break; + + case Http2FrameType::CONTINUATION: + status = StartDecodingContinuationPayload(&subset); + break; + + case Http2FrameType::ALTSVC: + status = StartDecodingAltSvcPayload(&subset); + break; + + case Http2FrameType::PRIORITY_UPDATE: + status = StartDecodingPriorityUpdatePayload(&subset); + break; + + default: + status = StartDecodingUnknownPayload(&subset); + break; + } + + if (status == DecodeStatus::kDecodeDone) { + state_ = State::kStartDecodingHeader; + return status; + } else if (status == DecodeStatus::kDecodeInProgress) { + state_ = State::kResumeDecodingPayload; + return status; + } else { + state_ = State::kDiscardPayload; + return status; + } +} + +DecodeStatus Http2FrameDecoder::ResumeDecodingPayload(DecodeBuffer* db) { + // The decode buffer can extend across many frames. Make sure that the + // buffer we pass to the start method that is specific to the frame type + // does not exend beyond this frame. + size_t remaining = frame_decoder_state_.remaining_total_payload(); + QUICHE_DCHECK_LE(remaining, frame_header().payload_length); + DecodeBufferSubset subset(db, remaining); + DecodeStatus status; + switch (frame_header().type) { + case Http2FrameType::DATA: + status = ResumeDecodingDataPayload(&subset); + break; + + case Http2FrameType::HEADERS: + status = ResumeDecodingHeadersPayload(&subset); + break; + + case Http2FrameType::PRIORITY: + status = ResumeDecodingPriorityPayload(&subset); + break; + + case Http2FrameType::RST_STREAM: + status = ResumeDecodingRstStreamPayload(&subset); + break; + + case Http2FrameType::SETTINGS: + status = ResumeDecodingSettingsPayload(&subset); + break; + + case Http2FrameType::PUSH_PROMISE: + status = ResumeDecodingPushPromisePayload(&subset); + break; + + case Http2FrameType::PING: + status = ResumeDecodingPingPayload(&subset); + break; + + case Http2FrameType::GOAWAY: + status = ResumeDecodingGoAwayPayload(&subset); + break; + + case Http2FrameType::WINDOW_UPDATE: + status = ResumeDecodingWindowUpdatePayload(&subset); + break; + + case Http2FrameType::CONTINUATION: + status = ResumeDecodingContinuationPayload(&subset); + break; + + case Http2FrameType::ALTSVC: + status = ResumeDecodingAltSvcPayload(&subset); + break; + + case Http2FrameType::PRIORITY_UPDATE: + status = ResumeDecodingPriorityUpdatePayload(&subset); + break; + + default: + status = ResumeDecodingUnknownPayload(&subset); + break; + } + + if (status == DecodeStatus::kDecodeDone) { + state_ = State::kStartDecodingHeader; + return status; + } else if (status == DecodeStatus::kDecodeInProgress) { + return status; + } else { + state_ = State::kDiscardPayload; + return status; + } +} + +// Clear any of the flags in the frame header that aren't set in valid_flags. +void Http2FrameDecoder::RetainFlags(uint8_t valid_flags) { + frame_decoder_state_.RetainFlags(valid_flags); +} + +// Clear all of the flags in the frame header; for use with frame types that +// don't define any flags, such as WINDOW_UPDATE. +void Http2FrameDecoder::ClearFlags() { frame_decoder_state_.ClearFlags(); } + +DecodeStatus Http2FrameDecoder::StartDecodingAltSvcPayload(DecodeBuffer* db) { + ClearFlags(); + return altsvc_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingAltSvcPayload(DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return altsvc_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, + db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingContinuationPayload( + DecodeBuffer* db) { + RetainFlags(Http2FrameFlag::END_HEADERS); + return continuation_payload_decoder_.StartDecodingPayload( + &frame_decoder_state_, db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingContinuationPayload( + DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return continuation_payload_decoder_.ResumeDecodingPayload( + &frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingDataPayload(DecodeBuffer* db) { + RetainFlags(Http2FrameFlag::END_STREAM | Http2FrameFlag::PADDED); + return data_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingDataPayload(DecodeBuffer* db) { + return data_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingGoAwayPayload(DecodeBuffer* db) { + ClearFlags(); + return goaway_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingGoAwayPayload(DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return goaway_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, + db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingHeadersPayload(DecodeBuffer* db) { + RetainFlags(Http2FrameFlag::END_STREAM | Http2FrameFlag::END_HEADERS | + Http2FrameFlag::PADDED | Http2FrameFlag::PRIORITY); + return headers_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingHeadersPayload(DecodeBuffer* db) { + QUICHE_DCHECK_LE(frame_decoder_state_.remaining_payload_and_padding(), + frame_header().payload_length); + return headers_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, + db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingPingPayload(DecodeBuffer* db) { + RetainFlags(Http2FrameFlag::ACK); + return ping_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingPingPayload(DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return ping_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingPriorityPayload(DecodeBuffer* db) { + ClearFlags(); + return priority_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingPriorityPayload( + DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return priority_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, + db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingPriorityUpdatePayload( + DecodeBuffer* db) { + ClearFlags(); + return priority_payload_update_decoder_.StartDecodingPayload( + &frame_decoder_state_, db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingPriorityUpdatePayload( + DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return priority_payload_update_decoder_.ResumeDecodingPayload( + &frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingPushPromisePayload( + DecodeBuffer* db) { + RetainFlags(Http2FrameFlag::END_HEADERS | Http2FrameFlag::PADDED); + return push_promise_payload_decoder_.StartDecodingPayload( + &frame_decoder_state_, db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingPushPromisePayload( + DecodeBuffer* db) { + QUICHE_DCHECK_LE(frame_decoder_state_.remaining_payload_and_padding(), + frame_header().payload_length); + return push_promise_payload_decoder_.ResumeDecodingPayload( + &frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingRstStreamPayload( + DecodeBuffer* db) { + ClearFlags(); + return rst_stream_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingRstStreamPayload( + DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return rst_stream_payload_decoder_.ResumeDecodingPayload( + &frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingSettingsPayload(DecodeBuffer* db) { + RetainFlags(Http2FrameFlag::ACK); + return settings_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingSettingsPayload( + DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return settings_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, + db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingUnknownPayload(DecodeBuffer* db) { + // We don't known what type of frame this is, so we don't know which flags + // are valid, so we don't touch them. + return unknown_payload_decoder_.StartDecodingPayload(&frame_decoder_state_, + db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingUnknownPayload(DecodeBuffer* db) { + // We don't known what type of frame this is, so we treat it as not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return unknown_payload_decoder_.ResumeDecodingPayload(&frame_decoder_state_, + db); +} + +DecodeStatus Http2FrameDecoder::StartDecodingWindowUpdatePayload( + DecodeBuffer* db) { + ClearFlags(); + return window_update_payload_decoder_.StartDecodingPayload( + &frame_decoder_state_, db); +} +DecodeStatus Http2FrameDecoder::ResumeDecodingWindowUpdatePayload( + DecodeBuffer* db) { + // The frame is not paddable. + QUICHE_DCHECK_EQ(frame_decoder_state_.remaining_total_payload(), + frame_decoder_state_.remaining_payload()); + return window_update_payload_decoder_.ResumeDecodingPayload( + &frame_decoder_state_, db); +} + +DecodeStatus Http2FrameDecoder::DiscardPayload(DecodeBuffer* db) { + QUICHE_DVLOG(2) << "remaining_payload=" + << frame_decoder_state_.remaining_payload_ + << "; remaining_padding=" + << frame_decoder_state_.remaining_padding_; + frame_decoder_state_.remaining_payload_ += + frame_decoder_state_.remaining_padding_; + frame_decoder_state_.remaining_padding_ = 0; + const size_t avail = frame_decoder_state_.AvailablePayload(db); + QUICHE_DVLOG(2) << "avail=" << avail; + if (avail > 0) { + frame_decoder_state_.ConsumePayload(avail); + db->AdvanceCursor(avail); + } + if (frame_decoder_state_.remaining_payload_ == 0) { + state_ = State::kStartDecodingHeader; + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/http2_frame_decoder.h b/quiche/http2/decoder/http2_frame_decoder.h new file mode 100644 index 000000000000..3d8f1ac8291a --- /dev/null +++ b/quiche/http2/decoder/http2_frame_decoder.h @@ -0,0 +1,214 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_HTTP2_FRAME_DECODER_H_ +#define QUICHE_HTTP2_DECODER_HTTP2_FRAME_DECODER_H_ + +// Http2FrameDecoder decodes the available input until it reaches the end of +// the input or it reaches the end of the first frame in the input. +// Note that Http2FrameDecoder does only minimal validation; for example, +// stream ids are not checked, nor is the sequence of frames such as +// CONTINUATION frame placement. +// +// Http2FrameDecoder enters state kError once it has called the listener's +// OnFrameSizeError or OnPaddingTooLong methods, and at this time has no +// provision for leaving that state. While the HTTP/2 spec (RFC7540) allows +// for some such errors to be considered as just stream errors in some cases, +// this implementation treats them all as connection errors. + +#include + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/data_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/headers_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/ping_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/priority_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/settings_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h" +#include "quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { +class Http2FrameDecoderPeer; +} // namespace test + +class QUICHE_EXPORT Http2FrameDecoder { + public: + explicit Http2FrameDecoder(Http2FrameDecoderListener* listener); + + Http2FrameDecoder(const Http2FrameDecoder&) = delete; + Http2FrameDecoder& operator=(const Http2FrameDecoder&) = delete; + + // The decoder will call the listener's methods as it decodes a frame. + void set_listener(Http2FrameDecoderListener* listener); + Http2FrameDecoderListener* listener() const; + + // The decoder will reject frame's whose payload + // length field exceeds the maximum payload size. + void set_maximum_payload_size(size_t v) { maximum_payload_size_ = v; } + size_t maximum_payload_size() const { return maximum_payload_size_; } + + // Decodes the input up to the next frame boundary (i.e. at most one frame). + // + // Returns kDecodeDone if it decodes the final byte of a frame, OR if there + // is no input and it is awaiting the start of a new frame (e.g. if this + // is the first call to DecodeFrame, or if the previous call returned + // kDecodeDone). + // + // Returns kDecodeInProgress if it decodes all of the decode buffer, but has + // not reached the end of the frame. + // + // Returns kDecodeError if the frame's padding or length wasn't valid (i.e. if + // the decoder called either the listener's OnPaddingTooLong or + // OnFrameSizeError method). + // + // If the decode buffer contains the entirety of a frame payload or field, + // then the corresponding Http2FrameDecoderListener::On*Payload(), + // OnHpackFragment(), OnGoAwayOpaqueData(), or OnAltSvcValueData() method is + // guaranteed to be called exactly once, with the entire payload or field in a + // single chunk. + DecodeStatus DecodeFrame(DecodeBuffer* db); + + ////////////////////////////////////////////////////////////////////////////// + // Methods that support Http2FrameDecoderAdapter. + + // Is the remainder of the frame's payload being discarded? + bool IsDiscardingPayload() const { return state_ == State::kDiscardPayload; } + + // Returns the number of bytes of the frame's payload that remain to be + // decoded, excluding any trailing padding. This method must only be called + // after the frame header has been decoded AND DecodeFrame has returned + // kDecodeInProgress. + size_t remaining_payload() const; + + // Returns the number of bytes of trailing padding after the payload that + // remain to be decoded. This method must only be called if the frame type + // allows padding, and after the frame header has been decoded AND + // DecodeFrame has returned. Will return 0 if the Pad Length field has not + // yet been decoded. + uint32_t remaining_padding() const; + + private: + enum class State { + // Ready to start decoding a new frame's header. + kStartDecodingHeader, + // Was in state kStartDecodingHeader, but unable to read the entire frame + // header, so needs more input to complete decoding the header. + kResumeDecodingHeader, + + // Have decoded the frame header, and started decoding the available bytes + // of the frame's payload, but need more bytes to finish the job. + kResumeDecodingPayload, + + // Decoding of the most recently started frame resulted in an error: + // OnPaddingTooLong or OnFrameSizeError was called to indicate that the + // decoder detected a problem, or OnFrameHeader returned false, indicating + // that the listener detected a problem. Regardless of which, the decoder + // will stay in state kDiscardPayload until it has been passed the rest + // of the bytes of the frame's payload that it hasn't yet seen, after + // which it will be ready to decode another frame. + kDiscardPayload, + }; + + friend class test::Http2FrameDecoderPeer; + QUICHE_EXPORT friend std::ostream& operator<<(std::ostream& out, State v); + + DecodeStatus StartDecodingPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingPayload(DecodeBuffer* db); + DecodeStatus DiscardPayload(DecodeBuffer* db); + + const Http2FrameHeader& frame_header() const { + return frame_decoder_state_.frame_header(); + } + + // Clear any of the flags in the frame header that aren't set in valid_flags. + void RetainFlags(uint8_t valid_flags); + + // Clear all of the flags in the frame header; for use with frame types that + // don't define any flags, such as WINDOW_UPDATE. + void ClearFlags(); + + // These methods call the StartDecodingPayload() method of the frame type's + // payload decoder, after first clearing invalid flags in the header. The + // caller must ensure that the decode buffer does not extend beyond the + // end of the payload (handled by Http2FrameDecoder::StartDecodingPayload). + DecodeStatus StartDecodingAltSvcPayload(DecodeBuffer* db); + DecodeStatus StartDecodingContinuationPayload(DecodeBuffer* db); + DecodeStatus StartDecodingDataPayload(DecodeBuffer* db); + DecodeStatus StartDecodingGoAwayPayload(DecodeBuffer* db); + DecodeStatus StartDecodingHeadersPayload(DecodeBuffer* db); + DecodeStatus StartDecodingPingPayload(DecodeBuffer* db); + DecodeStatus StartDecodingPriorityPayload(DecodeBuffer* db); + DecodeStatus StartDecodingPriorityUpdatePayload(DecodeBuffer* db); + DecodeStatus StartDecodingPushPromisePayload(DecodeBuffer* db); + DecodeStatus StartDecodingRstStreamPayload(DecodeBuffer* db); + DecodeStatus StartDecodingSettingsPayload(DecodeBuffer* db); + DecodeStatus StartDecodingUnknownPayload(DecodeBuffer* db); + DecodeStatus StartDecodingWindowUpdatePayload(DecodeBuffer* db); + + // These methods call the ResumeDecodingPayload() method of the frame type's + // payload decoder; they are called only if the preceding call to the + // corresponding Start method (above) returned kDecodeInProgress, as did any + // subsequent calls to the resume method. + // Unlike the Start methods, the decode buffer may extend beyond the + // end of the payload, so the method will create a DecodeBufferSubset + // before calling the ResumeDecodingPayload method of the frame type's + // payload decoder. + DecodeStatus ResumeDecodingAltSvcPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingContinuationPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingDataPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingGoAwayPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingHeadersPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingPingPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingPriorityPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingPriorityUpdatePayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingPushPromisePayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingRstStreamPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingSettingsPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingUnknownPayload(DecodeBuffer* db); + DecodeStatus ResumeDecodingWindowUpdatePayload(DecodeBuffer* db); + + FrameDecoderState frame_decoder_state_; + + // We only need one payload decoder at a time, so they share the same storage. + union { + AltSvcPayloadDecoder altsvc_payload_decoder_; + ContinuationPayloadDecoder continuation_payload_decoder_; + DataPayloadDecoder data_payload_decoder_; + GoAwayPayloadDecoder goaway_payload_decoder_; + HeadersPayloadDecoder headers_payload_decoder_; + PingPayloadDecoder ping_payload_decoder_; + PriorityPayloadDecoder priority_payload_decoder_; + PriorityUpdatePayloadDecoder priority_payload_update_decoder_; + PushPromisePayloadDecoder push_promise_payload_decoder_; + RstStreamPayloadDecoder rst_stream_payload_decoder_; + SettingsPayloadDecoder settings_payload_decoder_; + UnknownPayloadDecoder unknown_payload_decoder_; + WindowUpdatePayloadDecoder window_update_payload_decoder_; + }; + + State state_; + size_t maximum_payload_size_; + + // Listener used whenever caller passes nullptr to set_listener. + Http2FrameDecoderNoOpListener no_op_listener_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_HTTP2_FRAME_DECODER_H_ diff --git a/quiche/http2/decoder/http2_frame_decoder_listener.cc b/quiche/http2/decoder/http2_frame_decoder_listener.cc new file mode 100644 index 000000000000..76f19da5bb09 --- /dev/null +++ b/quiche/http2/decoder/http2_frame_decoder_listener.cc @@ -0,0 +1,14 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" + +namespace http2 { + +bool Http2FrameDecoderNoOpListener::OnFrameHeader( + const Http2FrameHeader& /*header*/) { + return true; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/http2_frame_decoder_listener.h b/quiche/http2/decoder/http2_frame_decoder_listener.h new file mode 100644 index 000000000000..0a98b3da0511 --- /dev/null +++ b/quiche/http2/decoder/http2_frame_decoder_listener.h @@ -0,0 +1,385 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_HTTP2_FRAME_DECODER_LISTENER_H_ +#define QUICHE_HTTP2_DECODER_HTTP2_FRAME_DECODER_LISTENER_H_ + +// Http2FrameDecoderListener is the interface which the HTTP/2 decoder uses +// to report the decoded frames to a listener. +// +// The general design is to assume that the listener will copy the data it needs +// (e.g. frame headers) and will keep track of the implicit state of the +// decoding process (i.e. the decoder maintains just the information it needs in +// order to perform the decoding). Therefore, the parameters are just those with +// (potentially) new data, not previously provided info about the current frame. +// +// The calls are described as if they are made in quick succession, i.e. one +// after another, but of course the decoder needs input to decode, and the +// decoder will only call the listener once the necessary input has been +// provided. For example: OnDataStart can only be called once the 9 bytes of +// of an HTTP/2 common frame header have been received. The decoder will call +// the listener methods as soon as possible to avoid almost all buffering. +// +// The listener interface is designed so that it is possible to exactly +// reconstruct the serialized frames, with the exception of reserved bits, +// including in the frame header's flags and stream_id fields, which will have +// been cleared before the methods below are called. + +#include + +#include +#include + +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// TODO(jamessynge): Consider sorting the methods by frequency of call, if that +// helps at all. +class QUICHE_EXPORT Http2FrameDecoderListener { + public: + Http2FrameDecoderListener() {} + virtual ~Http2FrameDecoderListener() {} + + // Called once the common frame header has been decoded for any frame, and + // before any of the methods below, which will also be called. This method is + // included in this interface only for the purpose of supporting SpdyFramer + // semantics via an adapter. This is the only method that has a non-void + // return type, and this is just so that Http2FrameDecoderAdapter (called + // from SpdyFramer) can more readily pass existing tests that expect decoding + // to stop if the headers alone indicate an error. Return false to stop + // decoding just after decoding the header, else return true to continue + // decoding. + // TODO(jamessynge): Remove OnFrameHeader once done with supporting + // SpdyFramer's exact states. + virtual bool OnFrameHeader(const Http2FrameHeader& header) = 0; + + ////////////////////////////////////////////////////////////////////////////// + + // Called once the common frame header has been decoded for a DATA frame, + // before examining the frame's payload, after which: + // OnPadLength will be called if header.IsPadded() is true, i.e. if the + // PADDED flag is set; + // OnDataPayload will be called as the non-padding portion of the payload + // is available until all of it has been provided; + // OnPadding will be called if the frame is padded AND the Pad Length field + // is greater than zero; + // OnDataEnd will be called last. If the frame is unpadded and has no + // payload, then this will be called immediately after OnDataStart. + virtual void OnDataStart(const Http2FrameHeader& header) = 0; + + // Called when the next non-padding portion of a DATA frame's payload is + // received. + // |data| The start of |len| bytes of data. + // |len| The length of the data buffer. Maybe zero in some cases, which does + // not mean anything special. + virtual void OnDataPayload(const char* data, size_t len) = 0; + + // Called after an entire DATA frame has been received. + // If header.IsEndStream() == true, this is the last data for the stream. + virtual void OnDataEnd() = 0; + + // Called once the common frame header has been decoded for a HEADERS frame, + // before examining the frame's payload, after which: + // OnPadLength will be called if header.IsPadded() is true, i.e. if the + // PADDED flag is set; + // OnHeadersPriority will be called if header.HasPriority() is true, i.e. if + // the frame has the PRIORITY flag; + // OnHpackFragment as the remainder of the non-padding payload is available + // until all if has been provided; + // OnPadding will be called if the frame is padded AND the Pad Length field + // is greater than zero; + // OnHeadersEnd will be called last; If the frame is unpadded and has no + // payload, then this will be called immediately after OnHeadersStart; + // OnHeadersEnd indicates the end of the HPACK block only if the frame + // header had the END_HEADERS flag set, else the END_HEADERS should be + // looked for on a subsequent CONTINUATION frame. + virtual void OnHeadersStart(const Http2FrameHeader& header) = 0; + + // Called when a HEADERS frame is received with the PRIORITY flag set and + // the priority fields have been decoded. + virtual void OnHeadersPriority( + const Http2PriorityFields& priority_fields) = 0; + + // Called when a fragment (i.e. some or all of an HPACK Block) is received; + // this may be part of a HEADERS, PUSH_PROMISE or CONTINUATION frame. + // |data| The start of |len| bytes of data. + // |len| The length of the data buffer. Maybe zero in some cases, which does + // not mean anything special, except that it simplified the decoder. + virtual void OnHpackFragment(const char* data, size_t len) = 0; + + // Called after an entire HEADERS frame has been received. The frame is the + // end of the HEADERS if the END_HEADERS flag is set; else there should be + // CONTINUATION frames after this frame. + virtual void OnHeadersEnd() = 0; + + // Called when an entire PRIORITY frame has been decoded. + virtual void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority_fields) = 0; + + // Called once the common frame header has been decoded for a CONTINUATION + // frame, before examining the frame's payload, after which: + // OnHpackFragment as the frame's payload is available until all of it + // has been provided; + // OnContinuationEnd will be called last; If the frame has no payload, + // then this will be called immediately after OnContinuationStart; + // the HPACK block is at an end if and only if the frame header passed + // to OnContinuationStart had the END_HEADERS flag set. + virtual void OnContinuationStart(const Http2FrameHeader& header) = 0; + + // Called after an entire CONTINUATION frame has been received. The frame is + // the end of the HEADERS if the END_HEADERS flag is set. + virtual void OnContinuationEnd() = 0; + + // Called when Pad Length field has been read. Applies to DATA and HEADERS + // frames. For PUSH_PROMISE frames, the Pad Length + 1 is provided in the + // OnPushPromiseStart call as total_padding_length. + virtual void OnPadLength(size_t pad_length) = 0; + + // Called when padding is skipped over. + virtual void OnPadding(const char* padding, size_t skipped_length) = 0; + + // Called when an entire RST_STREAM frame has been decoded. + // This is the only callback for RST_STREAM frames. + virtual void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) = 0; + + // Called once the common frame header has been decoded for a SETTINGS frame + // without the ACK flag, before examining the frame's payload, after which: + // OnSetting will be called in turn for each pair of settings parameter and + // value found in the payload; + // OnSettingsEnd will be called last; If the frame has no payload, + // then this will be called immediately after OnSettingsStart. + // The frame header is passed so that the caller can check the stream_id, + // which should be zero, but that hasn't been checked by the decoder. + virtual void OnSettingsStart(const Http2FrameHeader& header) = 0; + + // Called for each setting parameter and value within a SETTINGS frame. + virtual void OnSetting(const Http2SettingFields& setting_fields) = 0; + + // Called after parsing the complete payload of SETTINGS frame (non-ACK). + virtual void OnSettingsEnd() = 0; + + // Called when an entire SETTINGS frame, with the ACK flag, has been decoded. + virtual void OnSettingsAck(const Http2FrameHeader& header) = 0; + + // Called just before starting to process the HPACK block of a PUSH_PROMISE + // frame. The Pad Length field has already been decoded at this point, so + // OnPadLength will not be called; note that total_padding_length is Pad + // Length + 1. After OnPushPromiseStart: + // OnHpackFragment as the remainder of the non-padding payload is available + // until all if has been provided; + // OnPadding will be called if the frame is padded AND the Pad Length field + // is greater than zero (i.e. total_padding_length > 1); + // OnPushPromiseEnd will be called last; If the frame is unpadded and has no + // payload, then this will be called immediately after OnPushPromiseStart. + virtual void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) = 0; + + // Called after all of the HPACK block fragment and padding of a PUSH_PROMISE + // has been decoded and delivered to the listener. This call indicates the end + // of the HPACK block if and only if the frame header had the END_HEADERS flag + // set (i.e. header.IsEndHeaders() is true); otherwise the next block must be + // a CONTINUATION frame with the same stream id (not the same promised stream + // id). + virtual void OnPushPromiseEnd() = 0; + + // Called when an entire PING frame, without the ACK flag, has been decoded. + virtual void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) = 0; + + // Called when an entire PING frame, with the ACK flag, has been decoded. + virtual void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) = 0; + + // Called after parsing a GOAWAY frame's header and fixed size fields, after + // which: + // OnGoAwayOpaqueData will be called as opaque data of the payload becomes + // available to the decoder, until all of it has been provided to the + // listener; + // OnGoAwayEnd will be called last, after all the opaque data has been + // provided to the listener; if there is no opaque data, then OnGoAwayEnd + // will be called immediately after OnGoAwayStart. + virtual void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) = 0; + + // Called when the next portion of a GOAWAY frame's payload is received. + // |data| The start of |len| bytes of opaque data. + // |len| The length of the opaque data buffer. Maybe zero in some cases, + // which does not mean anything special. + virtual void OnGoAwayOpaqueData(const char* data, size_t len) = 0; + + // Called after finishing decoding all of a GOAWAY frame. + virtual void OnGoAwayEnd() = 0; + + // Called when an entire WINDOW_UPDATE frame has been decoded. The + // window_size_increment is required to be non-zero, but that has not been + // checked. If header.stream_id==0, the connection's flow control window is + // being increased, else the specified stream's flow control is being + // increased. + virtual void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t window_size_increment) = 0; + + // Called when an ALTSVC frame header and origin length have been parsed. + // Either or both lengths may be zero. After OnAltSvcStart: + // OnAltSvcOriginData will be called until all of the (optional) Origin + // has been provided; + // OnAltSvcValueData will be called until all of the Alt-Svc-Field-Value + // has been provided; + // OnAltSvcEnd will called last, after all of the origin and + // Alt-Svc-Field-Value have been delivered to the listener. + virtual void OnAltSvcStart(const Http2FrameHeader& header, + size_t origin_length, size_t value_length) = 0; + + // Called when decoding the (optional) origin of an ALTSVC; + // the field is uninterpreted. + virtual void OnAltSvcOriginData(const char* data, size_t len) = 0; + + // Called when decoding the Alt-Svc-Field-Value of an ALTSVC; + // the field is uninterpreted. + virtual void OnAltSvcValueData(const char* data, size_t len) = 0; + + // Called after decoding all of a ALTSVC frame and providing to the listener + // via the above methods. + virtual void OnAltSvcEnd() = 0; + + // Called when an PRIORITY_UPDATE frame header and Prioritized Stream ID have + // been parsed. Afterwards: + // OnPriorityUpdatePayload will be called each time a portion of the + // Priority Field Value field is available until all of it has been + // provided; + // OnPriorityUpdateEnd will be called last. If the frame has an empty + // Priority Field Value, then this will be called immediately after + // OnPriorityUpdateStart. + virtual void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) = 0; + + // Called when the next portion of a PRIORITY_UPDATE frame's Priority Field + // Value field is received. + // |data| The start of |len| bytes of data. + // |len| The length of the data buffer. May be zero in some cases, which does + // not mean anything special. + virtual void OnPriorityUpdatePayload(const char* data, size_t len) = 0; + + // Called after an entire PRIORITY_UPDATE frame has been received. + virtual void OnPriorityUpdateEnd() = 0; + + // Called when the common frame header has been decoded, but the frame type + // is unknown, after which: + // OnUnknownPayload is called as the payload of the frame is provided to the + // decoder, until all of the payload has been decoded; + // OnUnknownEnd will called last, after the entire frame of the unknown type + // has been decoded and provided to the listener. + virtual void OnUnknownStart(const Http2FrameHeader& header) = 0; + + // Called when the payload of an unknown frame type is received. + // |data| A buffer containing the data received. + // |len| The length of the data buffer. + virtual void OnUnknownPayload(const char* data, size_t len) = 0; + + // Called after decoding all of the payload of an unknown frame type. + virtual void OnUnknownEnd() = 0; + + ////////////////////////////////////////////////////////////////////////////// + // Below here are events indicating a problem has been detected during + // decoding (i.e. the received frames are malformed in some way). + + // Padding field (uint8) has a value that is too large (i.e. the amount of + // padding is greater than the remainder of the payload that isn't required). + // From RFC Section 6.1, DATA: + // If the length of the padding is the length of the frame payload or + // greater, the recipient MUST treat this as a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + // The same is true for HEADERS and PUSH_PROMISE. + virtual void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) = 0; + + // Frame size error. Depending upon the effected frame, this may or may not + // require terminating the connection, though that is probably the best thing + // to do. + // From RFC Section 4.2, Frame Size: + // An endpoint MUST send an error code of FRAME_SIZE_ERROR if a frame + // exceeds the size defined in SETTINGS_MAX_FRAME_SIZE, exceeds any limit + // defined for the frame type, or is too small to contain mandatory frame + // data. A frame size error in a frame that could alter the state of the + // the entire connection MUST be treated as a connection error + // (Section 5.4.1); this includes any frame carrying a header block + // (Section 4.3) (that is, HEADERS, PUSH_PROMISE, and CONTINUATION), + // SETTINGS, and any frame with a stream identifier of 0. + virtual void OnFrameSizeError(const Http2FrameHeader& header) = 0; +}; + +// Do nothing for each call. Useful for ignoring a frame that is invalid. +class QUICHE_EXPORT Http2FrameDecoderNoOpListener + : public Http2FrameDecoderListener { + public: + Http2FrameDecoderNoOpListener() {} + ~Http2FrameDecoderNoOpListener() override {} + + // TODO(jamessynge): Remove OnFrameHeader once done with supporting + // SpdyFramer's exact states. + bool OnFrameHeader(const Http2FrameHeader& header) override; + + void OnDataStart(const Http2FrameHeader& /*header*/) override {} + void OnDataPayload(const char* /*data*/, size_t /*len*/) override {} + void OnDataEnd() override {} + void OnHeadersStart(const Http2FrameHeader& /*header*/) override {} + void OnHeadersPriority(const Http2PriorityFields& /*priority*/) override {} + void OnHpackFragment(const char* /*data*/, size_t /*len*/) override {} + void OnHeadersEnd() override {} + void OnPriorityFrame(const Http2FrameHeader& /*header*/, + const Http2PriorityFields& /*priority*/) override {} + void OnContinuationStart(const Http2FrameHeader& /*header*/) override {} + void OnContinuationEnd() override {} + void OnPadLength(size_t /*trailing_length*/) override {} + void OnPadding(const char* /*padding*/, size_t /*skipped_length*/) override {} + void OnRstStream(const Http2FrameHeader& /*header*/, + Http2ErrorCode /*error_code*/) override {} + void OnSettingsStart(const Http2FrameHeader& /*header*/) override {} + void OnSetting(const Http2SettingFields& /*setting_fields*/) override {} + void OnSettingsEnd() override {} + void OnSettingsAck(const Http2FrameHeader& /*header*/) override {} + void OnPushPromiseStart(const Http2FrameHeader& /*header*/, + const Http2PushPromiseFields& /*promise*/, + size_t /*total_padding_length*/) override {} + void OnPushPromiseEnd() override {} + void OnPing(const Http2FrameHeader& /*header*/, + const Http2PingFields& /*ping*/) override {} + void OnPingAck(const Http2FrameHeader& /*header*/, + const Http2PingFields& /*ping*/) override {} + void OnGoAwayStart(const Http2FrameHeader& /*header*/, + const Http2GoAwayFields& /*goaway*/) override {} + void OnGoAwayOpaqueData(const char* /*data*/, size_t /*len*/) override {} + void OnGoAwayEnd() override {} + void OnWindowUpdate(const Http2FrameHeader& /*header*/, + uint32_t /*increment*/) override {} + void OnAltSvcStart(const Http2FrameHeader& /*header*/, + size_t /*origin_length*/, + size_t /*value_length*/) override {} + void OnAltSvcOriginData(const char* /*data*/, size_t /*len*/) override {} + void OnAltSvcValueData(const char* /*data*/, size_t /*len*/) override {} + void OnAltSvcEnd() override {} + void OnPriorityUpdateStart( + const Http2FrameHeader& /*header*/, + const Http2PriorityUpdateFields& /*priority_update*/) override {} + void OnPriorityUpdatePayload(const char* /*data*/, size_t /*len*/) override {} + void OnPriorityUpdateEnd() override {} + void OnUnknownStart(const Http2FrameHeader& /*header*/) override {} + void OnUnknownPayload(const char* /*data*/, size_t /*len*/) override {} + void OnUnknownEnd() override {} + void OnPaddingTooLong(const Http2FrameHeader& /*header*/, + size_t /*missing_length*/) override {} + void OnFrameSizeError(const Http2FrameHeader& /*header*/) override {} +}; + +static_assert(!std::is_abstract(), + "Http2FrameDecoderNoOpListener ought to be concrete."); + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_HTTP2_FRAME_DECODER_LISTENER_H_ diff --git a/quiche/http2/decoder/http2_frame_decoder_test.cc b/quiche/http2/decoder/http2_frame_decoder_test.cc new file mode 100644 index 000000000000..cd408805dd15 --- /dev/null +++ b/quiche/http2/decoder/http2_frame_decoder_test.cc @@ -0,0 +1,919 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/http2_frame_decoder.h" + +// Tests of Http2FrameDecoder. + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector_listener.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" + +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { +class Http2FrameDecoderPeer { + public: + static size_t remaining_total_payload(Http2FrameDecoder* decoder) { + return decoder->frame_decoder_state_.remaining_total_payload(); + } +}; + +namespace { + +class Http2FrameDecoderTest : public RandomDecoderTest { + protected: + DecodeStatus StartDecoding(DecodeBuffer* db) override { + QUICHE_DVLOG(2) << "StartDecoding, db->Remaining=" << db->Remaining(); + collector_.Reset(); + PrepareDecoder(); + + DecodeStatus status = decoder_->DecodeFrame(db); + if (status != DecodeStatus::kDecodeInProgress) { + // Keep track of this so that a concrete test can verify that both fast + // and slow decoding paths have been tested. + ++fast_decode_count_; + if (status == DecodeStatus::kDecodeError) { + ConfirmDiscardsRemainingPayload(); + } + } + return status; + } + + DecodeStatus ResumeDecoding(DecodeBuffer* db) override { + QUICHE_DVLOG(2) << "ResumeDecoding, db->Remaining=" << db->Remaining(); + DecodeStatus status = decoder_->DecodeFrame(db); + if (status != DecodeStatus::kDecodeInProgress) { + // Keep track of this so that a concrete test can verify that both fast + // and slow decoding paths have been tested. + ++slow_decode_count_; + if (status == DecodeStatus::kDecodeError) { + ConfirmDiscardsRemainingPayload(); + } + } + return status; + } + + // When an error is returned, the decoder is in state kDiscardPayload, and + // stays there until the remaining bytes of the frame's payload have been + // skipped over. There are no callbacks for this situation. + void ConfirmDiscardsRemainingPayload() { + ASSERT_TRUE(decoder_->IsDiscardingPayload()); + size_t remaining = + Http2FrameDecoderPeer::remaining_total_payload(decoder_.get()); + // The decoder will discard the remaining bytes, but not go beyond that, + // which these conditions verify. + size_t extra = 10; + std::string junk(remaining + extra, '0'); + DecodeBuffer tmp(junk); + EXPECT_EQ(DecodeStatus::kDecodeDone, decoder_->DecodeFrame(&tmp)); + EXPECT_EQ(remaining, tmp.Offset()); + EXPECT_EQ(extra, tmp.Remaining()); + EXPECT_FALSE(decoder_->IsDiscardingPayload()); + } + + void PrepareDecoder() { + decoder_ = std::make_unique(&collector_); + decoder_->set_maximum_payload_size(maximum_payload_size_); + } + + void ResetDecodeSpeedCounters() { + fast_decode_count_ = 0; + slow_decode_count_ = 0; + } + + AssertionResult VerifyCollected(const FrameParts& expected) { + HTTP2_VERIFY_FALSE(collector_.IsInProgress()); + HTTP2_VERIFY_EQ(1u, collector_.size()); + return expected.VerifyEquals(*collector_.frame(0)); + } + + AssertionResult DecodePayloadAndValidateSeveralWays(absl::string_view payload, + Validator validator) { + DecodeBuffer db(payload); + bool start_decoding_requires_non_empty = false; + return DecodeAndValidateSeveralWays(&db, start_decoding_requires_non_empty, + validator); + } + + // Decode one frame's payload and confirm that the listener recorded the + // expected FrameParts instance, and only one FrameParts instance. The + // payload will be decoded several times with different partitionings + // of the payload, and after each the validator will be called. + AssertionResult DecodePayloadAndValidateSeveralWays( + absl::string_view payload, const FrameParts& expected) { + auto validator = [&expected, this](const DecodeBuffer& /*input*/, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeDone); + return VerifyCollected(expected); + }; + ResetDecodeSpeedCounters(); + HTTP2_VERIFY_SUCCESS(DecodePayloadAndValidateSeveralWays( + payload, ValidateDoneAndEmpty(validator))); + HTTP2_VERIFY_GT(fast_decode_count_, 0u); + HTTP2_VERIFY_GT(slow_decode_count_, 0u); + + // Repeat with more input; it should stop without reading that input. + std::string next_frame = Random().RandString(10); + std::string input(payload.data(), payload.size()); + input += next_frame; + + ResetDecodeSpeedCounters(); + HTTP2_VERIFY_SUCCESS(DecodePayloadAndValidateSeveralWays( + payload, ValidateDoneAndOffset(payload.size(), validator))); + HTTP2_VERIFY_GT(fast_decode_count_, 0u); + HTTP2_VERIFY_GT(slow_decode_count_, 0u); + + return AssertionSuccess(); + } + + template + AssertionResult DecodePayloadAndValidateSeveralWays( + const char (&buf)[N], const FrameParts& expected) { + return DecodePayloadAndValidateSeveralWays(absl::string_view(buf, N), + expected); + } + + template + AssertionResult DecodePayloadAndValidateSeveralWays( + const char (&buf)[N], const Http2FrameHeader& header) { + return DecodePayloadAndValidateSeveralWays(absl::string_view(buf, N), + FrameParts(header)); + } + + template + AssertionResult DecodePayloadExpectingError(const char (&buf)[N], + const FrameParts& expected) { + auto validator = [&expected, this](const DecodeBuffer& /*input*/, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeError); + return VerifyCollected(expected); + }; + ResetDecodeSpeedCounters(); + EXPECT_TRUE( + DecodePayloadAndValidateSeveralWays(ToStringPiece(buf), validator)); + EXPECT_GT(fast_decode_count_, 0u); + EXPECT_GT(slow_decode_count_, 0u); + return AssertionSuccess(); + } + + template + AssertionResult DecodePayloadExpectingFrameSizeError(const char (&buf)[N], + FrameParts expected) { + expected.SetHasFrameSizeError(true); + return DecodePayloadExpectingError(buf, expected); + } + + template + AssertionResult DecodePayloadExpectingFrameSizeError( + const char (&buf)[N], const Http2FrameHeader& header) { + return DecodePayloadExpectingFrameSizeError(buf, FrameParts(header)); + } + + // Count of payloads that are fully decoded by StartDecodingPayload or for + // which an error was detected by StartDecodingPayload. + size_t fast_decode_count_ = 0; + + // Count of payloads that required calling ResumeDecodingPayload in order to + // decode completely, or for which an error was detected by + // ResumeDecodingPayload. + size_t slow_decode_count_ = 0; + + uint32_t maximum_payload_size_ = Http2SettingsInfo::DefaultMaxFrameSize(); + FramePartsCollectorListener collector_; + std::unique_ptr decoder_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Tests that pass the minimum allowed size for the frame type, which is often +// empty. The tests are in order by frame type value (i.e. 0 for DATA frames). + +TEST_F(Http2FrameDecoderTest, DataEmpty) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Payload length: 0 + '\x00', // DATA + '\x00', // Flags: none + '\x00', '\x00', '\x00', + '\x00', // Stream ID: 0 (invalid but unchecked here) + }; + Http2FrameHeader header(0, Http2FrameType::DATA, 0, 0); + FrameParts expected(header, ""); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeadersEmpty) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Payload length: 0 + '\x01', // HEADERS + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x01', // Stream ID: 0 (REQUIRES ID) + }; + Http2FrameHeader header(0, Http2FrameType::HEADERS, 0, 1); + FrameParts expected(header, ""); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, Priority) { + const char kFrameData[] = { + '\x00', '\x00', '\x05', // Length: 5 + '\x02', // Type: PRIORITY + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream: 2 + '\x80', '\x00', '\x00', '\x01', // Parent: 1 (Exclusive) + '\x10', // Weight: 17 + }; + Http2FrameHeader header(5, Http2FrameType::PRIORITY, 0, 2); + FrameParts expected(header); + expected.SetOptPriority(Http2PriorityFields(1, 17, true)); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, RstStream) { + const char kFrameData[] = { + '\x00', '\x00', '\x04', // Length: 4 + '\x03', // Type: RST_STREAM + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x01', // Stream: 1 + '\x00', '\x00', '\x00', '\x01', // Error: PROTOCOL_ERROR + }; + Http2FrameHeader header(4, Http2FrameType::RST_STREAM, 0, 1); + FrameParts expected(header); + expected.SetOptRstStreamErrorCode(Http2ErrorCode::PROTOCOL_ERROR); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, SettingsEmpty) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Length: 0 + '\x04', // Type: SETTINGS + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x01', // Stream: 1 (invalid but unchecked here) + }; + Http2FrameHeader header(0, Http2FrameType::SETTINGS, 0, 1); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, SettingsAck) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Length: 6 + '\x04', // Type: SETTINGS + '\x01', // Flags: ACK + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + }; + Http2FrameHeader header(0, Http2FrameType::SETTINGS, Http2FrameFlag::ACK, 0); + FrameParts expected(header); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, PushPromiseMinimal) { + const char kFrameData[] = { + '\x00', '\x00', '\x04', // Payload length: 4 + '\x05', // PUSH_PROMISE + '\x04', // Flags: END_HEADERS + '\x00', '\x00', '\x00', + '\x02', // Stream: 2 (invalid but unchecked here) + '\x00', '\x00', '\x00', + '\x01', // Promised: 1 (invalid but unchecked here) + }; + Http2FrameHeader header(4, Http2FrameType::PUSH_PROMISE, + Http2FrameFlag::END_HEADERS, 2); + FrameParts expected(header, ""); + expected.SetOptPushPromise(Http2PushPromiseFields{1}); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, Ping) { + const char kFrameData[] = { + '\x00', '\x00', '\x08', // Length: 8 + '\x06', // Type: PING + '\xfe', // Flags: no valid flags + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + 's', 'o', 'm', 'e', // "some" + 'd', 'a', 't', 'a', // "data" + }; + Http2FrameHeader header(8, Http2FrameType::PING, 0, 0); + FrameParts expected(header); + expected.SetOptPing( + Http2PingFields{{'s', 'o', 'm', 'e', 'd', 'a', 't', 'a'}}); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, PingAck) { + const char kFrameData[] = { + '\x00', '\x00', '\x08', // Length: 8 + '\x06', // Type: PING + '\xff', // Flags: ACK (plus all invalid flags) + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + 's', 'o', 'm', 'e', // "some" + 'd', 'a', 't', 'a', // "data" + }; + Http2FrameHeader header(8, Http2FrameType::PING, Http2FrameFlag::ACK, 0); + FrameParts expected(header); + expected.SetOptPing( + Http2PingFields{{'s', 'o', 'm', 'e', 'd', 'a', 't', 'a'}}); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, GoAwayMinimal) { + const char kFrameData[] = { + '\x00', '\x00', '\x08', // Length: 8 (no opaque data) + '\x07', // Type: GOAWAY + '\xff', // Flags: 0xff (no valid flags) + '\x00', '\x00', '\x00', '\x01', // Stream: 1 (invalid but unchecked here) + '\x80', '\x00', '\x00', '\xff', // Last: 255 (plus R bit) + '\x00', '\x00', '\x00', '\x09', // Error: COMPRESSION_ERROR + }; + Http2FrameHeader header(8, Http2FrameType::GOAWAY, 0, 1); + FrameParts expected(header); + expected.SetOptGoaway( + Http2GoAwayFields(255, Http2ErrorCode::COMPRESSION_ERROR)); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, WindowUpdate) { + const char kFrameData[] = { + '\x00', '\x00', '\x04', // Length: 4 + '\x08', // Type: WINDOW_UPDATE + '\x0f', // Flags: 0xff (no valid flags) + '\x00', '\x00', '\x00', '\x01', // Stream: 1 + '\x80', '\x00', '\x04', '\x00', // Incr: 1024 (plus R bit) + }; + Http2FrameHeader header(4, Http2FrameType::WINDOW_UPDATE, 0, 1); + FrameParts expected(header); + expected.SetOptWindowUpdateIncrement(1024); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, ContinuationEmpty) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Payload length: 0 + '\x09', // CONTINUATION + '\x00', // Flags: none + '\x00', '\x00', '\x00', + '\x00', // Stream ID: 0 (invalid but unchecked here) + }; + Http2FrameHeader header(0, Http2FrameType::CONTINUATION, 0, 0); + FrameParts expected(header); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, AltSvcMinimal) { + const char kFrameData[] = { + '\x00', '\x00', '\x02', // Payload length: 2 + '\x0a', // ALTSVC + '\xff', // Flags: none (plus 0xff) + '\x00', '\x00', '\x00', + '\x00', // Stream ID: 0 (invalid but unchecked here) + '\x00', '\x00', // Origin Length: 0 + }; + Http2FrameHeader header(2, Http2FrameType::ALTSVC, 0, 0); + FrameParts expected(header); + expected.SetOptAltsvcOriginLength(0); + expected.SetOptAltsvcValueLength(0); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, UnknownEmpty) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Payload length: 0 + '\x20', // 32 (unknown) + '\xff', // Flags: all + '\x00', '\x00', '\x00', '\x00', // Stream ID: 0 + }; + Http2FrameHeader header(0, static_cast(32), 0xff, 0); + FrameParts expected(header); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +//////////////////////////////////////////////////////////////////////////////// +// Tests of longer payloads, for those frame types that allow longer payloads. + +TEST_F(Http2FrameDecoderTest, DataPayload) { + const char kFrameData[] = { + '\x00', '\x00', '\x03', // Payload length: 7 + '\x00', // DATA + '\x80', // Flags: 0x80 + '\x00', '\x00', '\x02', '\x02', // Stream ID: 514 + 'a', 'b', 'c', // Data + }; + Http2FrameHeader header(3, Http2FrameType::DATA, 0, 514); + FrameParts expected(header, "abc"); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeadersPayload) { + const char kFrameData[] = { + '\x00', '\x00', '\x03', // Payload length: 3 + '\x01', // HEADERS + '\x05', // Flags: END_STREAM | END_HEADERS + '\x00', '\x00', '\x00', '\x02', // Stream ID: 0 (REQUIRES ID) + 'a', 'b', 'c', // HPACK fragment (doesn't have to be valid) + }; + Http2FrameHeader header( + 3, Http2FrameType::HEADERS, + Http2FrameFlag::END_STREAM | Http2FrameFlag::END_HEADERS, 2); + FrameParts expected(header, "abc"); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeadersPriority) { + const char kFrameData[] = { + '\x00', '\x00', '\x05', // Payload length: 5 + '\x01', // HEADERS + '\x20', // Flags: PRIORITY + '\x00', '\x00', '\x00', '\x02', // Stream ID: 0 (REQUIRES ID) + '\x00', '\x00', '\x00', '\x01', // Parent: 1 (Not Exclusive) + '\xff', // Weight: 256 + }; + Http2FrameHeader header(5, Http2FrameType::HEADERS, Http2FrameFlag::PRIORITY, + 2); + FrameParts expected(header); + expected.SetOptPriority(Http2PriorityFields(1, 256, false)); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, Settings) { + const char kFrameData[] = { + '\x00', '\x00', '\x0c', // Length: 12 + '\x04', // Type: SETTINGS + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + '\x00', '\x04', // Param: INITIAL_WINDOW_SIZE + '\x0a', '\x0b', '\x0c', '\x0d', // Value: 168496141 + '\x00', '\x02', // Param: ENABLE_PUSH + '\x00', '\x00', '\x00', '\x03', // Value: 3 (invalid but unchecked here) + }; + Http2FrameHeader header(12, Http2FrameType::SETTINGS, 0, 0); + FrameParts expected(header); + expected.AppendSetting(Http2SettingFields( + Http2SettingsParameter::INITIAL_WINDOW_SIZE, 168496141)); + expected.AppendSetting( + Http2SettingFields(Http2SettingsParameter::ENABLE_PUSH, 3)); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, PushPromisePayload) { + const char kFrameData[] = { + '\x00', '\x00', 7, // Payload length: 7 + '\x05', // PUSH_PROMISE + '\x04', // Flags: END_HEADERS + '\x00', '\x00', '\x00', '\xff', // Stream ID: 255 + '\x00', '\x00', '\x01', '\x00', // Promised: 256 + 'a', 'b', 'c', // HPACK fragment (doesn't have to be valid) + }; + Http2FrameHeader header(7, Http2FrameType::PUSH_PROMISE, + Http2FrameFlag::END_HEADERS, 255); + FrameParts expected(header, "abc"); + expected.SetOptPushPromise(Http2PushPromiseFields{256}); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, GoAwayOpaqueData) { + const char kFrameData[] = { + '\x00', '\x00', '\x0e', // Length: 14 + '\x07', // Type: GOAWAY + '\xff', // Flags: 0xff (no valid flags) + '\x80', '\x00', '\x00', '\x00', // Stream: 0 (plus R bit) + '\x00', '\x00', '\x01', '\x00', // Last: 256 + '\x00', '\x00', '\x00', '\x03', // Error: FLOW_CONTROL_ERROR + 'o', 'p', 'a', 'q', 'u', 'e', + }; + Http2FrameHeader header(14, Http2FrameType::GOAWAY, 0, 0); + FrameParts expected(header, "opaque"); + expected.SetOptGoaway( + Http2GoAwayFields(256, Http2ErrorCode::FLOW_CONTROL_ERROR)); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, ContinuationPayload) { + const char kFrameData[] = { + '\x00', '\x00', '\x03', // Payload length: 3 + '\x09', // CONTINUATION + '\xff', // Flags: END_HEADERS | 0xfb + '\x00', '\x00', '\x00', '\x02', // Stream ID: 2 + 'a', 'b', 'c', // Data + }; + Http2FrameHeader header(3, Http2FrameType::CONTINUATION, + Http2FrameFlag::END_HEADERS, 2); + FrameParts expected(header, "abc"); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, AltSvcPayload) { + const char kFrameData[] = { + '\x00', '\x00', '\x08', // Payload length: 3 + '\x0a', // ALTSVC + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream ID: 2 + '\x00', '\x03', // Origin Length: 0 + 'a', 'b', 'c', // Origin + 'd', 'e', 'f', // Value + }; + Http2FrameHeader header(8, Http2FrameType::ALTSVC, 0, 2); + FrameParts expected(header); + expected.SetAltSvcExpected("abc", "def"); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, PriorityUpdatePayload) { + const char kFrameData[] = { + '\x00', '\x00', '\x07', // Payload length: 7 + '\x10', // PRIORITY_UPDATE + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x00', // Stream ID: 0 + '\x00', '\x00', '\x00', '\x05', // Prioritized Stream ID: 5 + 'a', 'b', 'c', // Priority Field Value + }; + Http2FrameHeader header(7, Http2FrameType::PRIORITY_UPDATE, 0, 0); + + FrameParts expected(header, "abc"); + expected.SetOptPriorityUpdate(Http2PriorityUpdateFields{5}); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, UnknownPayload) { + const char kFrameData[] = { + '\x00', '\x00', '\x03', // Payload length: 3 + '\x30', // 48 (unknown) + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream ID: 2 + 'a', 'b', 'c', // Payload + }; + Http2FrameHeader header(3, static_cast(48), 0, 2); + FrameParts expected(header, "abc"); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +//////////////////////////////////////////////////////////////////////////////// +// Tests of padded payloads, for those frame types that allow padding. + +TEST_F(Http2FrameDecoderTest, DataPayloadAndPadding) { + const char kFrameData[] = { + '\x00', '\x00', '\x07', // Payload length: 7 + '\x00', // DATA + '\x09', // Flags: END_STREAM | PADDED + '\x00', '\x00', '\x00', '\x02', // Stream ID: 0 (REQUIRES ID) + '\x03', // Pad Len + 'a', 'b', 'c', // Data + '\x00', '\x00', '\x00', // Padding + }; + Http2FrameHeader header(7, Http2FrameType::DATA, + Http2FrameFlag::END_STREAM | Http2FrameFlag::PADDED, + 2); + size_t total_pad_length = 4; // Including the Pad Length field. + FrameParts expected(header, "abc", total_pad_length); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeadersPayloadAndPadding) { + const char kFrameData[] = { + '\x00', '\x00', '\x07', // Payload length: 7 + '\x01', // HEADERS + '\x08', // Flags: PADDED + '\x00', '\x00', '\x00', '\x02', // Stream ID: 0 (REQUIRES ID) + '\x03', // Pad Len + 'a', 'b', 'c', // HPACK fragment (doesn't have to be valid) + '\x00', '\x00', '\x00', // Padding + }; + Http2FrameHeader header(7, Http2FrameType::HEADERS, Http2FrameFlag::PADDED, + 2); + size_t total_pad_length = 4; // Including the Pad Length field. + FrameParts expected(header, "abc", total_pad_length); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeadersPayloadPriorityAndPadding) { + const char kFrameData[] = { + '\x00', '\x00', '\x0c', // Payload length: 12 + '\x01', // HEADERS + '\xff', // Flags: all, including undefined + '\x00', '\x00', '\x00', '\x02', // Stream ID: 0 (REQUIRES ID) + '\x03', // Pad Len + '\x80', '\x00', '\x00', '\x01', // Parent: 1 (Exclusive) + '\x10', // Weight: 17 + 'a', 'b', 'c', // HPACK fragment (doesn't have to be valid) + '\x00', '\x00', '\x00', // Padding + }; + Http2FrameHeader header(12, Http2FrameType::HEADERS, + Http2FrameFlag::END_STREAM | + Http2FrameFlag::END_HEADERS | + Http2FrameFlag::PADDED | Http2FrameFlag::PRIORITY, + 2); + size_t total_pad_length = 4; // Including the Pad Length field. + FrameParts expected(header, "abc", total_pad_length); + expected.SetOptPriority(Http2PriorityFields(1, 17, true)); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, PushPromisePayloadAndPadding) { + const char kFrameData[] = { + '\x00', '\x00', 11, // Payload length: 11 + '\x05', // PUSH_PROMISE + '\xff', // Flags: END_HEADERS | PADDED | 0xf3 + '\x00', '\x00', '\x00', '\x01', // Stream ID: 1 + '\x03', // Pad Len + '\x00', '\x00', '\x00', '\x02', // Promised: 2 + 'a', 'b', 'c', // HPACK fragment (doesn't have to be valid) + '\x00', '\x00', '\x00', // Padding + }; + Http2FrameHeader header(11, Http2FrameType::PUSH_PROMISE, + Http2FrameFlag::END_HEADERS | Http2FrameFlag::PADDED, + 1); + size_t total_pad_length = 4; // Including the Pad Length field. + FrameParts expected(header, "abc", total_pad_length); + expected.SetOptPushPromise(Http2PushPromiseFields{2}); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(kFrameData, expected)); +} + +//////////////////////////////////////////////////////////////////////////////// +// Payload too short errors. + +TEST_F(Http2FrameDecoderTest, DataMissingPadLengthField) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Payload length: 0 + '\x00', // DATA + '\x08', // Flags: PADDED + '\x00', '\x00', '\x00', '\x01', // Stream ID: 1 + }; + Http2FrameHeader header(0, Http2FrameType::DATA, Http2FrameFlag::PADDED, 1); + FrameParts expected(header); + expected.SetOptMissingLength(1); + EXPECT_TRUE(DecodePayloadExpectingError(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeaderPaddingTooLong) { + const char kFrameData[] = { + '\x00', '\x00', '\x02', // Payload length: 0 + '\x01', // HEADERS + '\x08', // Flags: PADDED + '\x00', '\x01', '\x00', '\x00', // Stream ID: 65536 + '\xff', // Pad Len: 255 + '\x00', // Only one byte of padding + }; + Http2FrameHeader header(2, Http2FrameType::HEADERS, Http2FrameFlag::PADDED, + 65536); + FrameParts expected(header); + expected.SetOptMissingLength(254); + EXPECT_TRUE(DecodePayloadExpectingError(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, HeaderMissingPriority) { + const char kFrameData[] = { + '\x00', '\x00', '\x04', // Payload length: 0 + '\x01', // HEADERS + '\x20', // Flags: PRIORITY + '\x00', '\x01', '\x00', '\x00', // Stream ID: 65536 + '\x00', '\x00', '\x00', '\x00', // Priority (truncated) + }; + Http2FrameHeader header(4, Http2FrameType::HEADERS, Http2FrameFlag::PRIORITY, + 65536); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, PriorityTooShort) { + const char kFrameData[] = { + '\x00', '\x00', '\x04', // Length: 5 + '\x02', // Type: PRIORITY + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream: 2 + '\x80', '\x00', '\x00', '\x01', // Parent: 1 (Exclusive) + }; + Http2FrameHeader header(4, Http2FrameType::PRIORITY, 0, 2); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, RstStreamTooShort) { + const char kFrameData[] = { + '\x00', '\x00', '\x03', // Length: 4 + '\x03', // Type: RST_STREAM + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x01', // Stream: 1 + '\x00', '\x00', '\x00', // Truncated + }; + Http2FrameHeader header(3, Http2FrameType::RST_STREAM, 0, 1); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +// SETTINGS frames must a multiple of 6 bytes long, so an 9 byte payload is +// invalid. +TEST_F(Http2FrameDecoderTest, SettingsWrongSize) { + const char kFrameData[] = { + '\x00', '\x00', '\x09', // Length: 2 + '\x04', // Type: SETTINGS + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + '\x00', '\x02', // Param: ENABLE_PUSH + '\x00', '\x00', '\x00', '\x03', // Value: 1 + '\x00', '\x04', // Param: INITIAL_WINDOW_SIZE + '\x00', // Value: Truncated + }; + Http2FrameHeader header(9, Http2FrameType::SETTINGS, 0, 0); + FrameParts expected(header); + expected.AppendSetting( + Http2SettingFields(Http2SettingsParameter::ENABLE_PUSH, 3)); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, expected)); +} + +TEST_F(Http2FrameDecoderTest, PushPromiseTooShort) { + const char kFrameData[] = { + '\x00', '\x00', 3, // Payload length: 3 + '\x05', // PUSH_PROMISE + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x01', // Stream ID: 1 + '\x00', '\x00', '\x00', // Truncated promise id + }; + Http2FrameHeader header(3, Http2FrameType::PUSH_PROMISE, 0, 1); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, PushPromisePaddedTruncatedPromise) { + const char kFrameData[] = { + '\x00', '\x00', 4, // Payload length: 4 + '\x05', // PUSH_PROMISE + '\x08', // Flags: PADDED + '\x00', '\x00', '\x00', '\x01', // Stream ID: 1 + '\x00', // Pad Len + '\x00', '\x00', '\x00', // Truncated promise id + }; + Http2FrameHeader header(4, Http2FrameType::PUSH_PROMISE, + Http2FrameFlag::PADDED, 1); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, PingTooShort) { + const char kFrameData[] = { + '\x00', '\x00', '\x07', // Length: 8 + '\x06', // Type: PING + '\xfe', // Flags: no valid flags + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + 's', 'o', 'm', 'e', // "some" + 'd', 'a', 't', // Too little + }; + Http2FrameHeader header(7, Http2FrameType::PING, 0, 0); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, GoAwayTooShort) { + const char kFrameData[] = { + '\x00', '\x00', '\x00', // Length: 0 + '\x07', // Type: GOAWAY + '\xff', // Flags: 0xff (no valid flags) + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + }; + Http2FrameHeader header(0, Http2FrameType::GOAWAY, 0, 0); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, WindowUpdateTooShort) { + const char kFrameData[] = { + '\x00', '\x00', '\x03', // Length: 3 + '\x08', // Type: WINDOW_UPDATE + '\x0f', // Flags: 0xff (no valid flags) + '\x00', '\x00', '\x00', '\x01', // Stream: 1 + '\x80', '\x00', '\x04', // Truncated + }; + Http2FrameHeader header(3, Http2FrameType::WINDOW_UPDATE, 0, 1); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, AltSvcTruncatedOriginLength) { + const char kFrameData[] = { + '\x00', '\x00', '\x01', // Payload length: 3 + '\x0a', // ALTSVC + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream ID: 2 + '\x00', // Origin Length: truncated + }; + Http2FrameHeader header(1, Http2FrameType::ALTSVC, 0, 2); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, AltSvcTruncatedOrigin) { + const char kFrameData[] = { + '\x00', '\x00', '\x05', // Payload length: 3 + '\x0a', // ALTSVC + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream ID: 2 + '\x00', '\x04', // Origin Length: 4 (too long) + 'a', 'b', 'c', // Origin + }; + Http2FrameHeader header(5, Http2FrameType::ALTSVC, 0, 2); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +//////////////////////////////////////////////////////////////////////////////// +// Payload too long errors. + +// The decoder calls the listener's OnFrameSizeError method if the frame's +// payload is longer than the currently configured maximum payload size. +TEST_F(Http2FrameDecoderTest, BeyondMaximum) { + maximum_payload_size_ = 2; + const char kFrameData[] = { + '\x00', '\x00', '\x07', // Payload length: 7 + '\x00', // DATA + '\x09', // Flags: END_STREAM | PADDED + '\x00', '\x00', '\x00', '\x02', // Stream ID: 0 (REQUIRES ID) + '\x03', // Pad Len + 'a', 'b', 'c', // Data + '\x00', '\x00', '\x00', // Padding + }; + Http2FrameHeader header(7, Http2FrameType::DATA, + Http2FrameFlag::END_STREAM | Http2FrameFlag::PADDED, + 2); + FrameParts expected(header); + expected.SetHasFrameSizeError(true); + auto validator = [&expected, this](const DecodeBuffer& input, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeError); + // The decoder detects this error after decoding the header, and without + // trying to decode the payload. + HTTP2_VERIFY_EQ(input.Offset(), Http2FrameHeader::EncodedSize()); + return VerifyCollected(expected); + }; + ResetDecodeSpeedCounters(); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(ToStringPiece(kFrameData), + validator)); + EXPECT_GT(fast_decode_count_, 0u); + EXPECT_GT(slow_decode_count_, 0u); +} + +TEST_F(Http2FrameDecoderTest, PriorityTooLong) { + const char kFrameData[] = { + '\x00', '\x00', '\x06', // Length: 5 + '\x02', // Type: PRIORITY + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x02', // Stream: 2 + '\x80', '\x00', '\x00', '\x01', // Parent: 1 (Exclusive) + '\x10', // Weight: 17 + '\x00', // Too much + }; + Http2FrameHeader header(6, Http2FrameType::PRIORITY, 0, 2); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, RstStreamTooLong) { + const char kFrameData[] = { + '\x00', '\x00', '\x05', // Length: 4 + '\x03', // Type: RST_STREAM + '\x00', // Flags: none + '\x00', '\x00', '\x00', '\x01', // Stream: 1 + '\x00', '\x00', '\x00', '\x01', // Error: PROTOCOL_ERROR + '\x00', // Too much + }; + Http2FrameHeader header(5, Http2FrameType::RST_STREAM, 0, 1); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, SettingsAckTooLong) { + const char kFrameData[] = { + '\x00', '\x00', '\x06', // Length: 6 + '\x04', // Type: SETTINGS + '\x01', // Flags: ACK + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + '\x00', '\x00', // Extra + '\x00', '\x00', '\x00', '\x00', // Extra + }; + Http2FrameHeader header(6, Http2FrameType::SETTINGS, Http2FrameFlag::ACK, 0); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, PingAckTooLong) { + const char kFrameData[] = { + '\x00', '\x00', '\x09', // Length: 8 + '\x06', // Type: PING + '\xff', // Flags: ACK | 0xfe + '\x00', '\x00', '\x00', '\x00', // Stream: 0 + 's', 'o', 'm', 'e', // "some" + 'd', 'a', 't', 'a', // "data" + '\x00', // Too much + }; + Http2FrameHeader header(9, Http2FrameType::PING, Http2FrameFlag::ACK, 0); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +TEST_F(Http2FrameDecoderTest, WindowUpdateTooLong) { + const char kFrameData[] = { + '\x00', '\x00', '\x05', // Length: 5 + '\x08', // Type: WINDOW_UPDATE + '\x0f', // Flags: 0xff (no valid flags) + '\x00', '\x00', '\x00', '\x01', // Stream: 1 + '\x80', '\x00', '\x04', '\x00', // Incr: 1024 (plus R bit) + '\x00', // Too much + }; + Http2FrameHeader header(5, Http2FrameType::WINDOW_UPDATE, 0, 1); + EXPECT_TRUE(DecodePayloadExpectingFrameSizeError(kFrameData, header)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/http2_structure_decoder.cc b/quiche/http2/decoder/http2_structure_decoder.cc new file mode 100644 index 000000000000..0569c984d0e8 --- /dev/null +++ b/quiche/http2/decoder/http2_structure_decoder.cc @@ -0,0 +1,95 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/http2_structure_decoder.h" + +#include +#include + +#include "quiche/common/platform/api/quiche_bug_tracker.h" + +namespace http2 { + +// Below we have some defensive coding: if we somehow run off the end, don't +// overwrite lots of memory. Note that most of this decoder is not defensive +// against bugs in the decoder, only against malicious encoders, but since +// we're copying memory into a buffer here, let's make sure we don't allow a +// small mistake to grow larger. The decoder will get stuck if we hit the +// QUICHE_BUG conditions, but shouldn't corrupt memory. + +uint32_t Http2StructureDecoder::IncompleteStart(DecodeBuffer* db, + uint32_t target_size) { + if (target_size > sizeof buffer_) { + QUICHE_BUG(http2_bug_154_1) + << "target_size too large for buffer: " << target_size; + return 0; + } + const uint32_t num_to_copy = db->MinLengthRemaining(target_size); + memcpy(buffer_, db->cursor(), num_to_copy); + offset_ = num_to_copy; + db->AdvanceCursor(num_to_copy); + return num_to_copy; +} + +DecodeStatus Http2StructureDecoder::IncompleteStart(DecodeBuffer* db, + uint32_t* remaining_payload, + uint32_t target_size) { + QUICHE_DVLOG(1) << "IncompleteStart@" << this + << ": *remaining_payload=" << *remaining_payload + << "; target_size=" << target_size + << "; db->Remaining=" << db->Remaining(); + *remaining_payload -= + IncompleteStart(db, std::min(target_size, *remaining_payload)); + if (*remaining_payload > 0 && db->Empty()) { + return DecodeStatus::kDecodeInProgress; + } + QUICHE_DVLOG(1) << "IncompleteStart: kDecodeError"; + return DecodeStatus::kDecodeError; +} + +bool Http2StructureDecoder::ResumeFillingBuffer(DecodeBuffer* db, + uint32_t target_size) { + QUICHE_DVLOG(2) << "ResumeFillingBuffer@" << this + << ": target_size=" << target_size << "; offset_=" << offset_ + << "; db->Remaining=" << db->Remaining(); + if (target_size < offset_) { + QUICHE_BUG(http2_bug_154_2) + << "Already filled buffer_! target_size=" << target_size + << " offset_=" << offset_; + return false; + } + const uint32_t needed = target_size - offset_; + const uint32_t num_to_copy = db->MinLengthRemaining(needed); + QUICHE_DVLOG(2) << "ResumeFillingBuffer num_to_copy=" << num_to_copy; + memcpy(&buffer_[offset_], db->cursor(), num_to_copy); + db->AdvanceCursor(num_to_copy); + offset_ += num_to_copy; + return needed == num_to_copy; +} + +bool Http2StructureDecoder::ResumeFillingBuffer(DecodeBuffer* db, + uint32_t* remaining_payload, + uint32_t target_size) { + QUICHE_DVLOG(2) << "ResumeFillingBuffer@" << this + << ": target_size=" << target_size << "; offset_=" << offset_ + << "; *remaining_payload=" << *remaining_payload + << "; db->Remaining=" << db->Remaining(); + if (target_size < offset_) { + QUICHE_BUG(http2_bug_154_3) + << "Already filled buffer_! target_size=" << target_size + << " offset_=" << offset_; + return false; + } + const uint32_t needed = target_size - offset_; + const uint32_t num_to_copy = + db->MinLengthRemaining(std::min(needed, *remaining_payload)); + QUICHE_DVLOG(2) << "ResumeFillingBuffer num_to_copy=" << num_to_copy; + memcpy(&buffer_[offset_], db->cursor(), num_to_copy); + db->AdvanceCursor(num_to_copy); + offset_ += num_to_copy; + *remaining_payload -= num_to_copy; + return needed == num_to_copy; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/http2_structure_decoder.h b/quiche/http2/decoder/http2_structure_decoder.h new file mode 100644 index 000000000000..1daee07b309c --- /dev/null +++ b/quiche/http2/decoder/http2_structure_decoder.h @@ -0,0 +1,130 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_HTTP2_STRUCTURE_DECODER_H_ +#define QUICHE_HTTP2_DECODER_HTTP2_STRUCTURE_DECODER_H_ + +// Http2StructureDecoder is a class for decoding the fixed size structures in +// the HTTP/2 spec, defined in quiche/http2/http2_structures.h. This class +// is in aid of deciding whether to keep the SlowDecode methods which I +// (jamessynge) now think may not be worth their complexity. In particular, +// if most transport buffers are large, so it is rare that a structure is +// split across buffer boundaries, than the cost of buffering upon +// those rare occurrences is small, which then simplifies the callers. + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_http2_structures.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { +class Http2StructureDecoderPeer; +} // namespace test + +class QUICHE_EXPORT Http2StructureDecoder { + public: + // The caller needs to keep track of whether to call Start or Resume. + // + // Start has an optimization for the case where the DecodeBuffer holds the + // entire encoded structure; in that case it decodes into *out and returns + // true, and does NOT touch the data members of the Http2StructureDecoder + // instance because the caller won't be calling Resume later. + // + // However, if the DecodeBuffer is too small to hold the entire encoded + // structure, Start copies the available bytes into the Http2StructureDecoder + // instance, and returns false to indicate that it has not been able to + // complete the decoding. + // + template + bool Start(S* out, DecodeBuffer* db) { + static_assert(S::EncodedSize() <= sizeof buffer_, "buffer_ is too small"); + QUICHE_DVLOG(2) << __func__ << "@" << this + << ": db->Remaining=" << db->Remaining() + << "; EncodedSize=" << S::EncodedSize(); + if (db->Remaining() >= S::EncodedSize()) { + DoDecode(out, db); + return true; + } + IncompleteStart(db, S::EncodedSize()); + return false; + } + + template + bool Resume(S* out, DecodeBuffer* db) { + QUICHE_DVLOG(2) << __func__ << "@" << this << ": offset_=" << offset_ + << "; db->Remaining=" << db->Remaining(); + if (ResumeFillingBuffer(db, S::EncodedSize())) { + // We have the whole thing now. + QUICHE_DVLOG(2) << __func__ << "@" << this << " offset_=" << offset_ + << " Ready to decode from buffer_."; + DecodeBuffer buffer_db(buffer_, S::EncodedSize()); + DoDecode(out, &buffer_db); + return true; + } + QUICHE_DCHECK_LT(offset_, S::EncodedSize()); + return false; + } + + // A second pair of Start and Resume, where the caller has a variable, + // |remaining_payload| that is both tested for sufficiency and updated + // during decoding. Note that the decode buffer may extend beyond the + // remaining payload because the buffer may include padding. + template + DecodeStatus Start(S* out, DecodeBuffer* db, uint32_t* remaining_payload) { + static_assert(S::EncodedSize() <= sizeof buffer_, "buffer_ is too small"); + QUICHE_DVLOG(2) << __func__ << "@" << this + << ": *remaining_payload=" << *remaining_payload + << "; db->Remaining=" << db->Remaining() + << "; EncodedSize=" << S::EncodedSize(); + if (db->MinLengthRemaining(*remaining_payload) >= S::EncodedSize()) { + DoDecode(out, db); + *remaining_payload -= S::EncodedSize(); + return DecodeStatus::kDecodeDone; + } + return IncompleteStart(db, remaining_payload, S::EncodedSize()); + } + + template + bool Resume(S* out, DecodeBuffer* db, uint32_t* remaining_payload) { + QUICHE_DVLOG(3) << __func__ << "@" << this << ": offset_=" << offset_ + << "; *remaining_payload=" << *remaining_payload + << "; db->Remaining=" << db->Remaining() + << "; EncodedSize=" << S::EncodedSize(); + if (ResumeFillingBuffer(db, remaining_payload, S::EncodedSize())) { + // We have the whole thing now. + QUICHE_DVLOG(2) << __func__ << "@" << this << ": offset_=" << offset_ + << "; Ready to decode from buffer_."; + DecodeBuffer buffer_db(buffer_, S::EncodedSize()); + DoDecode(out, &buffer_db); + return true; + } + QUICHE_DCHECK_LT(offset_, S::EncodedSize()); + return false; + } + + uint32_t offset() const { return offset_; } + + private: + friend class test::Http2StructureDecoderPeer; + + uint32_t IncompleteStart(DecodeBuffer* db, uint32_t target_size); + DecodeStatus IncompleteStart(DecodeBuffer* db, uint32_t* remaining_payload, + uint32_t target_size); + + bool ResumeFillingBuffer(DecodeBuffer* db, uint32_t target_size); + bool ResumeFillingBuffer(DecodeBuffer* db, uint32_t* remaining_payload, + uint32_t target_size); + + uint32_t offset_; + char buffer_[Http2FrameHeader::EncodedSize()]; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_HTTP2_STRUCTURE_DECODER_H_ diff --git a/quiche/http2/decoder/http2_structure_decoder_test.cc b/quiche/http2/decoder/http2_structure_decoder_test.cc new file mode 100644 index 000000000000..bbb4ba842f8b --- /dev/null +++ b/quiche/http2/decoder/http2_structure_decoder_test.cc @@ -0,0 +1,535 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/http2_structure_decoder.h" + +// Tests decoding all of the fixed size HTTP/2 structures (i.e. those defined in +// quiche/http2/http2_structures.h) using Http2StructureDecoder, which +// handles buffering of structures split across input buffer boundaries, and in +// turn uses DoDecode when it has all of a structure in a contiguous buffer. + +// NOTE: This tests the first pair of Start and Resume, which don't take +// a remaining_payload parameter. The other pair are well tested via the +// payload decoder tests, though... +// TODO(jamessynge): Create type parameterized tests for Http2StructureDecoder +// where the type is the type of structure, and with testing of both pairs of +// Start and Resume methods; note that it appears that the first pair will be +// used only for Http2FrameHeader, and the other pair only for structures in the +// frame payload. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" + +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { +namespace { +const bool kMayReturnZeroOnFirst = false; + +template +class Http2StructureDecoderTest : public RandomDecoderTest { + protected: + typedef S Structure; + + Http2StructureDecoderTest() { + // IF the test adds more data after the encoded structure, stop as + // soon as the structure is decoded. + stop_decode_on_done_ = true; + } + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + // Overwrite the current contents of |structure_|, into which we'll + // decode the buffer, so that we can be confident that we really decoded + // the structure every time. + structure_ = std::make_unique(); + uint32_t old_remaining = b->Remaining(); + if (structure_decoder_.Start(structure_.get(), b)) { + EXPECT_EQ(old_remaining - S::EncodedSize(), b->Remaining()); + ++fast_decode_count_; + return DecodeStatus::kDecodeDone; + } else { + EXPECT_LT(structure_decoder_.offset(), S::EncodedSize()); + EXPECT_EQ(0u, b->Remaining()); + EXPECT_EQ(old_remaining - structure_decoder_.offset(), b->Remaining()); + ++incomplete_start_count_; + return DecodeStatus::kDecodeInProgress; + } + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + uint32_t old_offset = structure_decoder_.offset(); + EXPECT_LT(old_offset, S::EncodedSize()); + uint32_t avail = b->Remaining(); + if (structure_decoder_.Resume(structure_.get(), b)) { + EXPECT_LE(S::EncodedSize(), old_offset + avail); + EXPECT_EQ(b->Remaining(), avail - (S::EncodedSize() - old_offset)); + ++slow_decode_count_; + return DecodeStatus::kDecodeDone; + } else { + EXPECT_LT(structure_decoder_.offset(), S::EncodedSize()); + EXPECT_EQ(0u, b->Remaining()); + EXPECT_GT(S::EncodedSize(), old_offset + avail); + ++incomplete_resume_count_; + return DecodeStatus::kDecodeInProgress; + } + } + + // Fully decodes the Structure at the start of data, and confirms it matches + // *expected (if provided). + AssertionResult DecodeLeadingStructure(const S* expected, + absl::string_view data) { + HTTP2_VERIFY_LE(S::EncodedSize(), data.size()); + DecodeBuffer original(data); + + // The validator is called after each of the several times that the input + // DecodeBuffer is decoded, each with a different segmentation of the input. + // Validate that structure_ matches the expected value, if provided. + Validator validator; + if (expected != nullptr) { + validator = [expected, this](const DecodeBuffer& /*db*/, + DecodeStatus /*status*/) -> AssertionResult { + HTTP2_VERIFY_EQ(*expected, *structure_); + return AssertionSuccess(); + }; + } + + // Before that, validate that decoding is done and that we've advanced + // the cursor the expected amount. + validator = ValidateDoneAndOffset(S::EncodedSize(), validator); + + // Decode several times, with several segmentations of the input buffer. + fast_decode_count_ = 0; + slow_decode_count_ = 0; + incomplete_start_count_ = 0; + incomplete_resume_count_ = 0; + HTTP2_VERIFY_SUCCESS(DecodeAndValidateSeveralWays( + &original, kMayReturnZeroOnFirst, validator)); + HTTP2_VERIFY_FALSE(HasFailure()); + HTTP2_VERIFY_EQ(S::EncodedSize(), structure_decoder_.offset()); + HTTP2_VERIFY_EQ(S::EncodedSize(), original.Offset()); + HTTP2_VERIFY_LT(0u, fast_decode_count_); + HTTP2_VERIFY_LT(0u, slow_decode_count_); + HTTP2_VERIFY_LT(0u, incomplete_start_count_); + + // If the structure is large enough so that SelectZeroOrOne will have + // caused Resume to return false, check that occurred. + if (S::EncodedSize() >= 2) { + HTTP2_VERIFY_LE(0u, incomplete_resume_count_); + } else { + HTTP2_VERIFY_EQ(0u, incomplete_resume_count_); + } + if (expected != nullptr) { + QUICHE_DVLOG(1) << "DecodeLeadingStructure expected: " << *expected; + QUICHE_DVLOG(1) << "DecodeLeadingStructure actual: " << *structure_; + HTTP2_VERIFY_EQ(*expected, *structure_); + } + return AssertionSuccess(); + } + + template + AssertionResult DecodeLeadingStructure(const char (&data)[N]) { + return DecodeLeadingStructure(nullptr, absl::string_view(data, N)); + } + + template + AssertionResult DecodeLeadingStructure(const unsigned char (&data)[N]) { + return DecodeLeadingStructure(nullptr, ToStringPiece(data)); + } + + // Encode the structure |in_s| into bytes, then decode the bytes + // and validate that the decoder produced the same field values. + AssertionResult EncodeThenDecode(const S& in_s) { + std::string bytes = SerializeStructure(in_s); + HTTP2_VERIFY_EQ(S::EncodedSize(), bytes.size()); + return DecodeLeadingStructure(&in_s, bytes); + } + + // Repeatedly fill a structure with random but valid contents, encode it, then + // decode it, and finally validate that the decoded structure matches the + // random input. Lather-rinse-and-repeat. + AssertionResult TestDecodingRandomizedStructures(size_t count) { + for (size_t i = 0; i < count; ++i) { + Structure input; + Randomize(&input, RandomPtr()); + HTTP2_VERIFY_SUCCESS(EncodeThenDecode(input)); + } + return AssertionSuccess(); + } + + AssertionResult TestDecodingRandomizedStructures() { + HTTP2_VERIFY_SUCCESS(TestDecodingRandomizedStructures(100)); + return AssertionSuccess(); + } + + uint32_t decode_offset_ = 0; + std::unique_ptr structure_; + Http2StructureDecoder structure_decoder_; + size_t fast_decode_count_ = 0; + size_t slow_decode_count_ = 0; + size_t incomplete_start_count_ = 0; + size_t incomplete_resume_count_ = 0; +}; + +class Http2FrameHeaderDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2FrameHeaderDecoderTest, DecodesLiteral) { + { + // Realistic input. + // clang-format off + const char kData[] = { + 0x00, 0x00, 0x05, // Payload length: 5 + 0x01, // Frame type: HEADERS + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream ID: 1 + 0x04, // Padding length: 4 + 0x00, 0x00, 0x00, 0x00, // Padding bytes + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(5u, structure_->payload_length); + EXPECT_EQ(Http2FrameType::HEADERS, structure_->type); + EXPECT_EQ(Http2FrameFlag::PADDED, structure_->flags); + EXPECT_EQ(1u, structure_->stream_id); + } + { + // Unlikely input. + // clang-format off + const unsigned char kData[] = { + 0xff, 0xff, 0xff, // Payload length: uint24 max + 0xff, // Frame type: Unknown + 0xff, // Flags: Unknown/All + 0xff, 0xff, 0xff, 0xff, // Stream ID: uint31 max, plus R-bit + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ((1u << 24) - 1u, structure_->payload_length); + EXPECT_EQ(static_cast(255), structure_->type); + EXPECT_EQ(255, structure_->flags); + EXPECT_EQ(0x7FFFFFFFu, structure_->stream_id); + } +} + +TEST_F(Http2FrameHeaderDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2PriorityFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2PriorityFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const unsigned char kData[] = { + 0x80, 0x00, 0x00, 0x05, // Exclusive (yes) and Dependency (5) + 0xff, // Weight: 256 (after adding 1) + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(5u, structure_->stream_dependency); + EXPECT_EQ(256u, structure_->weight); + EXPECT_EQ(true, structure_->is_exclusive); + } + { + // clang-format off + const unsigned char kData[] = { + 0x7f, 0xff, 0xff, 0xff, // Excl. (no) and Dependency (uint31 max) + 0x00, // Weight: 1 (after adding 1) + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(StreamIdMask(), structure_->stream_dependency); + EXPECT_EQ(1u, structure_->weight); + EXPECT_FALSE(structure_->is_exclusive); + } +} + +TEST_F(Http2PriorityFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2RstStreamFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2RstStreamFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const char kData[] = { + 0x00, 0x00, 0x00, 0x01, // Error: PROTOCOL_ERROR + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_TRUE(structure_->IsSupportedErrorCode()); + EXPECT_EQ(Http2ErrorCode::PROTOCOL_ERROR, structure_->error_code); + } + { + // clang-format off + const unsigned char kData[] = { + 0xff, 0xff, 0xff, 0xff, // Error: max uint32 (Unknown error code) + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_FALSE(structure_->IsSupportedErrorCode()); + EXPECT_EQ(static_cast(0xffffffff), structure_->error_code); + } +} + +TEST_F(Http2RstStreamFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2SettingFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2SettingFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const char kData[] = { + 0x00, 0x01, // Setting: HEADER_TABLE_SIZE + 0x00, 0x00, 0x40, 0x00, // Value: 16K + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_TRUE(structure_->IsSupportedParameter()); + EXPECT_EQ(Http2SettingsParameter::HEADER_TABLE_SIZE, structure_->parameter); + EXPECT_EQ(1u << 14, structure_->value); + } + { + // clang-format off + const unsigned char kData[] = { + 0x00, 0x00, // Setting: Unknown (0) + 0xff, 0xff, 0xff, 0xff, // Value: max uint32 + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_FALSE(structure_->IsSupportedParameter()); + EXPECT_EQ(static_cast(0), structure_->parameter); + } +} + +TEST_F(Http2SettingFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2PushPromiseFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2PushPromiseFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const unsigned char kData[] = { + 0x00, 0x01, 0x8a, 0x92, // Promised Stream ID: 101010 + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(101010u, structure_->promised_stream_id); + } + { + // Promised stream id has R-bit (reserved for future use) set, which + // should be cleared by the decoder. + // clang-format off + const unsigned char kData[] = { + // Promised Stream ID: max uint31 and R-bit + 0xff, 0xff, 0xff, 0xff, + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(StreamIdMask(), structure_->promised_stream_id); + } +} + +TEST_F(Http2PushPromiseFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2PingFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2PingFieldsDecoderTest, DecodesLiteral) { + { + // Each byte is different, so can detect if order changed. + const char kData[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + }; + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(ToStringPiece(kData), ToStringPiece(structure_->opaque_bytes)); + } + { + // All zeros, detect problems handling NULs. + const char kData[] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(ToStringPiece(kData), ToStringPiece(structure_->opaque_bytes)); + } + { + const unsigned char kData[] = { + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + }; + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(ToStringPiece(kData), ToStringPiece(structure_->opaque_bytes)); + } +} + +TEST_F(Http2PingFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2GoAwayFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2GoAwayFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const char kData[] = { + 0x00, 0x00, 0x00, 0x00, // Last Stream ID: 0 + 0x00, 0x00, 0x00, 0x00, // Error: NO_ERROR (0) + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(0u, structure_->last_stream_id); + EXPECT_TRUE(structure_->IsSupportedErrorCode()); + EXPECT_EQ(Http2ErrorCode::HTTP2_NO_ERROR, structure_->error_code); + } + { + // clang-format off + const char kData[] = { + 0x00, 0x00, 0x00, 0x01, // Last Stream ID: 1 + 0x00, 0x00, 0x00, 0x0d, // Error: HTTP_1_1_REQUIRED + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(1u, structure_->last_stream_id); + EXPECT_TRUE(structure_->IsSupportedErrorCode()); + EXPECT_EQ(Http2ErrorCode::HTTP_1_1_REQUIRED, structure_->error_code); + } + { + // clang-format off + const unsigned char kData[] = { + 0xff, 0xff, 0xff, 0xff, // Last Stream ID: max uint31 and R-bit + 0xff, 0xff, 0xff, 0xff, // Error: max uint32 (Unknown error code) + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(StreamIdMask(), structure_->last_stream_id); // No high-bit. + EXPECT_FALSE(structure_->IsSupportedErrorCode()); + EXPECT_EQ(static_cast(0xffffffff), structure_->error_code); + } +} + +TEST_F(Http2GoAwayFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2WindowUpdateFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2WindowUpdateFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const char kData[] = { + 0x00, 0x01, 0x00, 0x00, // Window Size Increment: 2 ^ 16 + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(1u << 16, structure_->window_size_increment); + } + { + // Increment must be non-zero, but we need to be able to decode the invalid + // zero to detect it. + // clang-format off + const char kData[] = { + 0x00, 0x00, 0x00, 0x00, // Window Size Increment: 0 + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(0u, structure_->window_size_increment); + } + { + // Increment has R-bit (reserved for future use) set, which + // should be cleared by the decoder. + // clang-format off + const unsigned char kData[] = { + // Window Size Increment: max uint31 and R-bit + 0xff, 0xff, 0xff, 0xff, + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(StreamIdMask(), structure_->window_size_increment); + } +} + +TEST_F(Http2WindowUpdateFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +//------------------------------------------------------------------------------ + +class Http2AltSvcFieldsDecoderTest + : public Http2StructureDecoderTest {}; + +TEST_F(Http2AltSvcFieldsDecoderTest, DecodesLiteral) { + { + // clang-format off + const char kData[] = { + 0x00, 0x00, // Origin Length: 0 + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(0, structure_->origin_length); + } + { + // clang-format off + const char kData[] = { + 0x00, 0x14, // Origin Length: 20 + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(20, structure_->origin_length); + } + { + // clang-format off + const unsigned char kData[] = { + 0xff, 0xff, // Origin Length: uint16 max + }; + // clang-format on + ASSERT_TRUE(DecodeLeadingStructure(kData)); + EXPECT_EQ(65535, structure_->origin_length); + } +} + +TEST_F(Http2AltSvcFieldsDecoderTest, DecodesRandomized) { + EXPECT_TRUE(TestDecodingRandomizedStructures()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.cc new file mode 100644 index 000000000000..8602f703a181 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.cc @@ -0,0 +1,149 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + AltSvcPayloadDecoder::PayloadState v) { + switch (v) { + case AltSvcPayloadDecoder::PayloadState::kStartDecodingStruct: + return out << "kStartDecodingStruct"; + case AltSvcPayloadDecoder::PayloadState::kMaybeDecodedStruct: + return out << "kMaybeDecodedStruct"; + case AltSvcPayloadDecoder::PayloadState::kDecodingStrings: + return out << "kDecodingStrings"; + case AltSvcPayloadDecoder::PayloadState::kResumeDecodingStruct: + return out << "kResumeDecodingStruct"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_163_1) + << "Invalid AltSvcPayloadDecoder::PayloadState: " << unknown; + return out << "AltSvcPayloadDecoder::PayloadState(" << unknown << ")"; +} + +DecodeStatus AltSvcPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "AltSvcPayloadDecoder::StartDecodingPayload: " + << state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::ALTSVC, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + QUICHE_DCHECK_EQ(0, state->frame_header().flags); + + state->InitializeRemainders(); + payload_state_ = PayloadState::kStartDecodingStruct; + + return ResumeDecodingPayload(state, db); +} + +DecodeStatus AltSvcPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + QUICHE_DVLOG(2) << "AltSvcPayloadDecoder::ResumeDecodingPayload: " + << frame_header; + QUICHE_DCHECK_EQ(Http2FrameType::ALTSVC, frame_header.type); + QUICHE_DCHECK_LE(state->remaining_payload(), frame_header.payload_length); + QUICHE_DCHECK_LE(db->Remaining(), state->remaining_payload()); + QUICHE_DCHECK_NE(PayloadState::kMaybeDecodedStruct, payload_state_); + // |status| has to be initialized to some value to avoid compiler error in + // case PayloadState::kMaybeDecodedStruct below, but value does not matter, + // see QUICHE_DCHECK_NE above. + DecodeStatus status = DecodeStatus::kDecodeError; + while (true) { + QUICHE_DVLOG(2) + << "AltSvcPayloadDecoder::ResumeDecodingPayload payload_state_=" + << payload_state_; + switch (payload_state_) { + case PayloadState::kStartDecodingStruct: + status = state->StartDecodingStructureInPayload(&altsvc_fields_, db); + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kMaybeDecodedStruct: + if (status == DecodeStatus::kDecodeDone && + altsvc_fields_.origin_length <= state->remaining_payload()) { + size_t origin_length = altsvc_fields_.origin_length; + size_t value_length = state->remaining_payload() - origin_length; + state->listener()->OnAltSvcStart(frame_header, origin_length, + value_length); + } else if (status != DecodeStatus::kDecodeDone) { + QUICHE_DCHECK(state->remaining_payload() > 0 || + status == DecodeStatus::kDecodeError) + << "\nremaining_payload: " << state->remaining_payload() + << "\nstatus: " << status << "\nheader: " << frame_header; + // Assume in progress. + payload_state_ = PayloadState::kResumeDecodingStruct; + return status; + } else { + // The origin's length is longer than the remaining payload. + QUICHE_DCHECK_GT(altsvc_fields_.origin_length, + state->remaining_payload()); + return state->ReportFrameSizeError(); + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kDecodingStrings: + return DecodeStrings(state, db); + + case PayloadState::kResumeDecodingStruct: + status = state->ResumeDecodingStructureInPayload(&altsvc_fields_, db); + payload_state_ = PayloadState::kMaybeDecodedStruct; + continue; + } + QUICHE_BUG(http2_bug_163_2) << "PayloadState: " << payload_state_; + } +} + +DecodeStatus AltSvcPayloadDecoder::DecodeStrings(FrameDecoderState* state, + DecodeBuffer* db) { + QUICHE_DVLOG(2) << "AltSvcPayloadDecoder::DecodeStrings remaining_payload=" + << state->remaining_payload() + << ", db->Remaining=" << db->Remaining(); + // Note that we don't explicitly keep track of exactly how far through the + // origin; instead we compute it from how much is left of the original + // payload length and the decoded total length of the origin. + size_t origin_length = altsvc_fields_.origin_length; + size_t value_length = state->frame_header().payload_length - origin_length - + Http2AltSvcFields::EncodedSize(); + if (state->remaining_payload() > value_length) { + size_t remaining_origin_length = state->remaining_payload() - value_length; + size_t avail = db->MinLengthRemaining(remaining_origin_length); + state->listener()->OnAltSvcOriginData(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + if (remaining_origin_length > avail) { + payload_state_ = PayloadState::kDecodingStrings; + return DecodeStatus::kDecodeInProgress; + } + } + // All that is left is the value string. + QUICHE_DCHECK_LE(state->remaining_payload(), value_length); + QUICHE_DCHECK_LE(db->Remaining(), state->remaining_payload()); + if (db->HasData()) { + size_t avail = db->Remaining(); + state->listener()->OnAltSvcValueData(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() == 0) { + state->listener()->OnAltSvcEnd(); + return DecodeStatus::kDecodeDone; + } + payload_state_ = PayloadState::kDecodingStrings; + return DecodeStatus::kDecodeInProgress; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h b/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h new file mode 100644 index 000000000000..2829249fbbbb --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h @@ -0,0 +1,64 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_ALTSVC_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_ALTSVC_PAYLOAD_DECODER_H_ + +// Decodes the payload of a ALTSVC frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class AltSvcPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT AltSvcPayloadDecoder { + public: + // States during decoding of a ALTSVC frame. + enum class PayloadState { + // Start decoding the fixed size structure at the start of an ALTSVC + // frame (Http2AltSvcFields). + kStartDecodingStruct, + + // Handle the DecodeStatus returned from starting or resuming the + // decoding of Http2AltSvcFields. If complete, calls OnAltSvcStart. + kMaybeDecodedStruct, + + // Reports the value of the strings (origin and value) of an ALTSVC frame + // to the listener. + kDecodingStrings, + + // The initial decode buffer wasn't large enough for the Http2AltSvcFields, + // so this state resumes the decoding when ResumeDecodingPayload is called + // later with a new DecodeBuffer. + kResumeDecodingStruct, + }; + + // Starts the decoding of a ALTSVC frame's payload, and completes it if the + // entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a ALTSVC frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::AltSvcPayloadDecoderPeer; + + // Implements state kDecodingStrings. + DecodeStatus DecodeStrings(FrameDecoderState* state, DecodeBuffer* db); + + Http2AltSvcFields altsvc_fields_; + PayloadState payload_state_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_ALTSVC_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc new file mode 100644 index 000000000000..f41934cbbd85 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/altsvc_payload_decoder_test.cc @@ -0,0 +1,121 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/altsvc_payload_decoder.h" + +#include + +#include +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +// Provides friend access to an instance of the payload decoder, and also +// provides info to aid in testing. +class AltSvcPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { return Http2FrameType::ALTSVC; } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnAltSvcStart(const Http2FrameHeader& header, size_t origin_length, + size_t value_length) override { + QUICHE_VLOG(1) << "OnAltSvcStart header: " << header + << "; origin_length=" << origin_length + << "; value_length=" << value_length; + StartFrame(header)->OnAltSvcStart(header, origin_length, value_length); + } + + void OnAltSvcOriginData(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnAltSvcOriginData: len=" << len; + CurrentFrame()->OnAltSvcOriginData(data, len); + } + + void OnAltSvcValueData(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnAltSvcValueData: len=" << len; + CurrentFrame()->OnAltSvcValueData(data, len); + } + + void OnAltSvcEnd() override { + QUICHE_VLOG(1) << "OnAltSvcEnd"; + EndFrame()->OnAltSvcEnd(); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class AltSvcPayloadDecoderTest + : public AbstractPayloadDecoderTest {}; + +// Confirm we get an error if the payload is not long enough to hold +// Http2AltSvcFields and the indicated length of origin. +TEST_F(AltSvcPayloadDecoderTest, Truncated) { + Http2FrameBuilder fb; + fb.Append(Http2AltSvcFields{0xffff}); // The longest possible origin length. + fb.Append("Too little origin!"); + EXPECT_TRUE( + VerifyDetectsFrameSizeError(0, fb.buffer(), /*approve_size*/ nullptr)); +} + +class AltSvcPayloadLengthTests + : public AltSvcPayloadDecoderTest, + public ::testing::WithParamInterface> { + protected: + AltSvcPayloadLengthTests() + : origin_length_(std::get<0>(GetParam())), + value_length_(std::get<1>(GetParam())) { + QUICHE_VLOG(1) << "################ origin_length_=" << origin_length_ + << " value_length_=" << value_length_ + << " ################"; + } + + const uint16_t origin_length_; + const uint32_t value_length_; +}; + +INSTANTIATE_TEST_SUITE_P(VariousOriginAndValueLengths, AltSvcPayloadLengthTests, + ::testing::Combine(::testing::Values(0, 1, 3, 65535), + ::testing::Values(0, 1, 3, 65537))); + +TEST_P(AltSvcPayloadLengthTests, ValidOriginAndValueLength) { + std::string origin = Random().RandString(origin_length_); + std::string value = Random().RandString(value_length_); + Http2FrameBuilder fb; + fb.Append(Http2AltSvcFields{origin_length_}); + fb.Append(origin); + fb.Append(value); + Http2FrameHeader header(fb.size(), Http2FrameType::ALTSVC, RandFlags(), + RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + expected.SetAltSvcExpected(origin, value); + ASSERT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.cc new file mode 100644 index 000000000000..0ed64a97e491 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.cc @@ -0,0 +1,57 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h" + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus ContinuationPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "ContinuationPayloadDecoder::StartDecodingPayload: " + << frame_header; + QUICHE_DCHECK_EQ(Http2FrameType::CONTINUATION, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + QUICHE_DCHECK_EQ(0, frame_header.flags & ~(Http2FrameFlag::END_HEADERS)); + + state->InitializeRemainders(); + state->listener()->OnContinuationStart(frame_header); + return ResumeDecodingPayload(state, db); +} + +DecodeStatus ContinuationPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "ContinuationPayloadDecoder::ResumeDecodingPayload" + << " remaining_payload=" << state->remaining_payload() + << " db->Remaining=" << db->Remaining(); + QUICHE_DCHECK_EQ(Http2FrameType::CONTINUATION, state->frame_header().type); + QUICHE_DCHECK_LE(state->remaining_payload(), + state->frame_header().payload_length); + QUICHE_DCHECK_LE(db->Remaining(), state->remaining_payload()); + + size_t avail = db->Remaining(); + QUICHE_DCHECK_LE(avail, state->remaining_payload()); + if (avail > 0) { + state->listener()->OnHpackFragment(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() == 0) { + state->listener()->OnContinuationEnd(); + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h b/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h new file mode 100644 index 000000000000..b599fb59aeee --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h @@ -0,0 +1,31 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_CONTINUATION_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_CONTINUATION_PAYLOAD_DECODER_H_ + +// Decodes the payload of a CONTINUATION frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +class QUICHE_EXPORT ContinuationPayloadDecoder { + public: + // Starts the decoding of a CONTINUATION frame's payload, and completes + // it if the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a CONTINUATION frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_CONTINUATION_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/continuation_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/continuation_payload_decoder_test.cc new file mode 100644 index 000000000000..e7aec65b7817 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/continuation_payload_decoder_test.cc @@ -0,0 +1,84 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/continuation_payload_decoder.h" + +#include + +#include +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +// Provides friend access to an instance of the payload decoder, and also +// provides info to aid in testing. +class ContinuationPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::CONTINUATION; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnContinuationStart(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnContinuationStart: " << header; + StartFrame(header)->OnContinuationStart(header); + } + + void OnHpackFragment(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnHpackFragment: len=" << len; + CurrentFrame()->OnHpackFragment(data, len); + } + + void OnContinuationEnd() override { + QUICHE_VLOG(1) << "OnContinuationEnd"; + EndFrame()->OnContinuationEnd(); + } +}; + +class ContinuationPayloadDecoderTest + : public AbstractPayloadDecoderTest< + ContinuationPayloadDecoder, ContinuationPayloadDecoderPeer, Listener>, + public ::testing::WithParamInterface { + protected: + ContinuationPayloadDecoderTest() : length_(GetParam()) { + QUICHE_VLOG(1) << "################ length_=" << length_ + << " ################"; + } + + const uint32_t length_; +}; + +INSTANTIATE_TEST_SUITE_P(VariousLengths, ContinuationPayloadDecoderTest, + ::testing::Values(0, 1, 2, 3, 4, 5, 6)); + +TEST_P(ContinuationPayloadDecoderTest, ValidLength) { + std::string hpack_payload = Random().RandString(length_); + Http2FrameHeader frame_header(length_, Http2FrameType::CONTINUATION, + RandFlags(), RandStreamId()); + set_frame_header(frame_header); + FrameParts expected(frame_header, hpack_payload); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(hpack_payload, expected)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/data_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/data_payload_decoder.cc new file mode 100644 index 000000000000..e0c7ab057747 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/data_payload_decoder.cc @@ -0,0 +1,128 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/data_payload_decoder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + DataPayloadDecoder::PayloadState v) { + switch (v) { + case DataPayloadDecoder::PayloadState::kReadPadLength: + return out << "kReadPadLength"; + case DataPayloadDecoder::PayloadState::kReadPayload: + return out << "kReadPayload"; + case DataPayloadDecoder::PayloadState::kSkipPadding: + return out << "kSkipPadding"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_174_1) + << "Invalid DataPayloadDecoder::PayloadState: " << unknown; + return out << "DataPayloadDecoder::PayloadState(" << unknown << ")"; +} + +DecodeStatus DataPayloadDecoder::StartDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "DataPayloadDecoder::StartDecodingPayload: " + << frame_header; + QUICHE_DCHECK_EQ(Http2FrameType::DATA, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + QUICHE_DCHECK_EQ(0, frame_header.flags & ~(Http2FrameFlag::END_STREAM | + Http2FrameFlag::PADDED)); + + // Special case for the hoped for common case: unpadded and fits fully into + // the decode buffer. TO BE SEEN if that is true. It certainly requires that + // the transport buffers be large (e.g. >> 16KB typically). + // TODO(jamessynge) Add counters. + QUICHE_DVLOG(2) << "StartDecodingPayload total_length=" << total_length; + if (!frame_header.IsPadded()) { + QUICHE_DVLOG(2) << "StartDecodingPayload !IsPadded"; + if (db->Remaining() == total_length) { + QUICHE_DVLOG(2) << "StartDecodingPayload all present"; + // Note that we don't cache the listener field so that the callee can + // replace it if the frame is bad. + // If this case is common enough, consider combining the 3 callbacks + // into one. + state->listener()->OnDataStart(frame_header); + if (total_length > 0) { + state->listener()->OnDataPayload(db->cursor(), total_length); + db->AdvanceCursor(total_length); + } + state->listener()->OnDataEnd(); + return DecodeStatus::kDecodeDone; + } + payload_state_ = PayloadState::kReadPayload; + } else { + payload_state_ = PayloadState::kReadPadLength; + } + state->InitializeRemainders(); + state->listener()->OnDataStart(frame_header); + return ResumeDecodingPayload(state, db); +} + +DecodeStatus DataPayloadDecoder::ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db) { + QUICHE_DVLOG(2) << "DataPayloadDecoder::ResumeDecodingPayload payload_state_=" + << payload_state_; + const Http2FrameHeader& frame_header = state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::DATA, frame_header.type); + QUICHE_DCHECK_LE(state->remaining_payload_and_padding(), + frame_header.payload_length); + QUICHE_DCHECK_LE(db->Remaining(), state->remaining_payload_and_padding()); + DecodeStatus status; + size_t avail; + switch (payload_state_) { + case PayloadState::kReadPadLength: + // ReadPadLength handles the OnPadLength callback, and updating the + // remaining_payload and remaining_padding fields. If the amount of + // padding is too large to fit in the frame's payload, ReadPadLength + // instead calls OnPaddingTooLong and returns kDecodeError. + status = state->ReadPadLength(db, /*report_pad_length*/ true); + if (status != DecodeStatus::kDecodeDone) { + return status; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kReadPayload: + avail = state->AvailablePayload(db); + if (avail > 0) { + state->listener()->OnDataPayload(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() > 0) { + payload_state_ = PayloadState::kReadPayload; + return DecodeStatus::kDecodeInProgress; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kSkipPadding: + // SkipPadding handles the OnPadding callback. + if (state->SkipPadding(db)) { + state->listener()->OnDataEnd(); + return DecodeStatus::kDecodeDone; + } + payload_state_ = PayloadState::kSkipPadding; + return DecodeStatus::kDecodeInProgress; + } + QUICHE_BUG(http2_bug_174_2) << "PayloadState: " << payload_state_; + return DecodeStatus::kDecodeError; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/data_payload_decoder.h b/quiche/http2/decoder/payload_decoders/data_payload_decoder.h new file mode 100644 index 000000000000..b0b117d3a964 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/data_payload_decoder.h @@ -0,0 +1,54 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_DATA_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_DATA_PAYLOAD_DECODER_H_ + +// Decodes the payload of a DATA frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class DataPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT DataPayloadDecoder { + public: + // States during decoding of a DATA frame. + enum class PayloadState { + // The frame is padded and we need to read the PAD_LENGTH field (1 byte), + // and then call OnPadLength + kReadPadLength, + + // Report the non-padding portion of the payload to the listener's + // OnDataPayload method. + kReadPayload, + + // The decoder has finished with the non-padding portion of the payload, + // and is now ready to skip the trailing padding, if the frame has any. + kSkipPadding, + }; + + // Starts decoding a DATA frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a DATA frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::DataPayloadDecoderPeer; + + PayloadState payload_state_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_DATA_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/data_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/data_payload_decoder_test.cc new file mode 100644 index 000000000000..bbd49135bb1b --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/data_payload_decoder_test.cc @@ -0,0 +1,110 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/data_payload_decoder.h" + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +// Provides friend access to an instance of the payload decoder, and also +// provides info to aid in testing. +class DataPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { return Http2FrameType::DATA; } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { + return Http2FrameFlag::PADDED; + } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnDataStart(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnDataStart: " << header; + StartFrame(header)->OnDataStart(header); + } + + void OnDataPayload(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnDataPayload: len=" << len; + CurrentFrame()->OnDataPayload(data, len); + } + + void OnDataEnd() override { + QUICHE_VLOG(1) << "OnDataEnd"; + EndFrame()->OnDataEnd(); + } + + void OnPadLength(size_t pad_length) override { + QUICHE_VLOG(1) << "OnPadLength: " << pad_length; + CurrentFrame()->OnPadLength(pad_length); + } + + void OnPadding(const char* padding, size_t skipped_length) override { + QUICHE_VLOG(1) << "OnPadding: " << skipped_length; + CurrentFrame()->OnPadding(padding, skipped_length); + } + + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override { + QUICHE_VLOG(1) << "OnPaddingTooLong: " << header + << " missing_length: " << missing_length; + EndFrame()->OnPaddingTooLong(header, missing_length); + } +}; + +class DataPayloadDecoderTest + : public AbstractPaddablePayloadDecoderTest< + DataPayloadDecoder, DataPayloadDecoderPeer, Listener> { + protected: + AssertionResult CreateAndDecodeDataOfSize(size_t data_size) { + Reset(); + uint8_t flags = RandFlags(); + + std::string data_payload = Random().RandString(data_size); + frame_builder_.Append(data_payload); + MaybeAppendTrailingPadding(); + + Http2FrameHeader frame_header(frame_builder_.size(), Http2FrameType::DATA, + flags, RandStreamId()); + set_frame_header(frame_header); + ScrubFlagsOfHeader(&frame_header); + FrameParts expected(frame_header, data_payload, total_pad_length_); + return DecodePayloadAndValidateSeveralWays(frame_builder_.buffer(), + expected); + } +}; + +INSTANTIATE_TEST_SUITE_P(VariousPadLengths, DataPayloadDecoderTest, + ::testing::Values(0, 1, 2, 3, 4, 254, 255, 256)); + +TEST_P(DataPayloadDecoderTest, VariousDataPayloadSizes) { + for (size_t data_size : {0, 1, 2, 3, 255, 256, 1024}) { + EXPECT_TRUE(CreateAndDecodeDataOfSize(data_size)); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.cc new file mode 100644 index 000000000000..fca781c2465a --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.cc @@ -0,0 +1,120 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + GoAwayPayloadDecoder::PayloadState v) { + switch (v) { + case GoAwayPayloadDecoder::PayloadState::kStartDecodingFixedFields: + return out << "kStartDecodingFixedFields"; + case GoAwayPayloadDecoder::PayloadState::kHandleFixedFieldsStatus: + return out << "kHandleFixedFieldsStatus"; + case GoAwayPayloadDecoder::PayloadState::kReadOpaqueData: + return out << "kReadOpaqueData"; + case GoAwayPayloadDecoder::PayloadState::kResumeDecodingFixedFields: + return out << "kResumeDecodingFixedFields"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_167_1) + << "Invalid GoAwayPayloadDecoder::PayloadState: " << unknown; + return out << "GoAwayPayloadDecoder::PayloadState(" << unknown << ")"; +} + +DecodeStatus GoAwayPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "GoAwayPayloadDecoder::StartDecodingPayload: " + << state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::GOAWAY, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + QUICHE_DCHECK_EQ(0, state->frame_header().flags); + + state->InitializeRemainders(); + payload_state_ = PayloadState::kStartDecodingFixedFields; + return ResumeDecodingPayload(state, db); +} + +DecodeStatus GoAwayPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) + << "GoAwayPayloadDecoder::ResumeDecodingPayload: remaining_payload=" + << state->remaining_payload() << ", db->Remaining=" << db->Remaining(); + + const Http2FrameHeader& frame_header = state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::GOAWAY, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), frame_header.payload_length); + QUICHE_DCHECK_NE(PayloadState::kHandleFixedFieldsStatus, payload_state_); + + // |status| has to be initialized to some value to avoid compiler error in + // case PayloadState::kHandleFixedFieldsStatus below, but value does not + // matter, see QUICHE_DCHECK_NE above. + DecodeStatus status = DecodeStatus::kDecodeError; + size_t avail; + while (true) { + QUICHE_DVLOG(2) + << "GoAwayPayloadDecoder::ResumeDecodingPayload payload_state_=" + << payload_state_; + switch (payload_state_) { + case PayloadState::kStartDecodingFixedFields: + status = state->StartDecodingStructureInPayload(&goaway_fields_, db); + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kHandleFixedFieldsStatus: + if (status == DecodeStatus::kDecodeDone) { + state->listener()->OnGoAwayStart(frame_header, goaway_fields_); + } else { + // Not done decoding the structure. Either we've got more payload + // to decode, or we've run out because the payload is too short, + // in which case OnFrameSizeError will have already been called. + QUICHE_DCHECK((status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && + state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload(); + payload_state_ = PayloadState::kResumeDecodingFixedFields; + return status; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kReadOpaqueData: + // The opaque data is all the remains to be decoded, so anything left + // in the decode buffer is opaque data. + avail = db->Remaining(); + if (avail > 0) { + state->listener()->OnGoAwayOpaqueData(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() > 0) { + payload_state_ = PayloadState::kReadOpaqueData; + return DecodeStatus::kDecodeInProgress; + } + state->listener()->OnGoAwayEnd(); + return DecodeStatus::kDecodeDone; + + case PayloadState::kResumeDecodingFixedFields: + status = state->ResumeDecodingStructureInPayload(&goaway_fields_, db); + payload_state_ = PayloadState::kHandleFixedFieldsStatus; + continue; + } + QUICHE_BUG(http2_bug_167_2) << "PayloadState: " << payload_state_; + } +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h b/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h new file mode 100644 index 000000000000..1f3b69cc4c1e --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h @@ -0,0 +1,66 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_GOAWAY_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_GOAWAY_PAYLOAD_DECODER_H_ + +// Decodes the payload of a GOAWAY frame. + +// TODO(jamessynge): Sweep through all payload decoders, changing the names of +// the PayloadState enums so that they are really states, and not actions. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class GoAwayPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT GoAwayPayloadDecoder { + public: + // States during decoding of a GOAWAY frame. + enum class PayloadState { + // At the start of the GOAWAY frame payload, ready to start decoding the + // fixed size fields into goaway_fields_. + kStartDecodingFixedFields, + + // Handle the DecodeStatus returned from starting or resuming the + // decoding of Http2GoAwayFields into goaway_fields_. If complete, + // calls OnGoAwayStart. + kHandleFixedFieldsStatus, + + // Report the Opaque Data portion of the payload to the listener's + // OnGoAwayOpaqueData method, and call OnGoAwayEnd when the end of the + // payload is reached. + kReadOpaqueData, + + // The fixed size fields weren't all available when the decoder first + // tried to decode them (state kStartDecodingFixedFields); this state + // resumes the decoding when ResumeDecodingPayload is called later. + kResumeDecodingFixedFields, + }; + + // Starts the decoding of a GOAWAY frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a GOAWAY frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::GoAwayPayloadDecoderPeer; + + Http2GoAwayFields goaway_fields_; + PayloadState payload_state_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_GOAWAY_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/goaway_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/goaway_payload_decoder_test.cc new file mode 100644 index 000000000000..78f44d12e3d3 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/goaway_payload_decoder_test.cc @@ -0,0 +1,108 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/goaway_payload_decoder.h" + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class GoAwayPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { return Http2FrameType::GOAWAY; } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) override { + QUICHE_VLOG(1) << "OnGoAwayStart header: " << header + << "; goaway: " << goaway; + StartFrame(header)->OnGoAwayStart(header, goaway); + } + + void OnGoAwayOpaqueData(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnGoAwayOpaqueData: len=" << len; + CurrentFrame()->OnGoAwayOpaqueData(data, len); + } + + void OnGoAwayEnd() override { + QUICHE_VLOG(1) << "OnGoAwayEnd"; + EndFrame()->OnGoAwayEnd(); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class GoAwayPayloadDecoderTest + : public AbstractPayloadDecoderTest {}; + +// Confirm we get an error if the payload is not long enough to hold +// Http2GoAwayFields. +TEST_F(GoAwayPayloadDecoderTest, Truncated) { + auto approve_size = [](size_t size) { + return size != Http2GoAwayFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(Http2GoAwayFields(123, Http2ErrorCode::ENHANCE_YOUR_CALM)); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +class GoAwayOpaqueDataLengthTests + : public GoAwayPayloadDecoderTest, + public ::testing::WithParamInterface { + protected: + GoAwayOpaqueDataLengthTests() : length_(GetParam()) { + QUICHE_VLOG(1) << "################ length_=" << length_ + << " ################"; + } + + const uint32_t length_; +}; + +INSTANTIATE_TEST_SUITE_P(VariousLengths, GoAwayOpaqueDataLengthTests, + ::testing::Values(0, 1, 2, 3, 4, 5, 6)); + +TEST_P(GoAwayOpaqueDataLengthTests, ValidLength) { + Http2GoAwayFields goaway; + Randomize(&goaway, RandomPtr()); + std::string opaque_data = Random().RandString(length_); + Http2FrameBuilder fb; + fb.Append(goaway); + fb.Append(opaque_data); + Http2FrameHeader header(fb.size(), Http2FrameType::GOAWAY, RandFlags(), + RandStreamId()); + set_frame_header(header); + FrameParts expected(header, opaque_data); + expected.SetOptGoaway(goaway); + ASSERT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/headers_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/headers_payload_decoder.cc new file mode 100644 index 000000000000..e3517ffabe25 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/headers_payload_decoder.cc @@ -0,0 +1,176 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/headers_payload_decoder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + HeadersPayloadDecoder::PayloadState v) { + switch (v) { + case HeadersPayloadDecoder::PayloadState::kReadPadLength: + return out << "kReadPadLength"; + case HeadersPayloadDecoder::PayloadState::kStartDecodingPriorityFields: + return out << "kStartDecodingPriorityFields"; + case HeadersPayloadDecoder::PayloadState::kResumeDecodingPriorityFields: + return out << "kResumeDecodingPriorityFields"; + case HeadersPayloadDecoder::PayloadState::kReadPayload: + return out << "kReadPayload"; + case HeadersPayloadDecoder::PayloadState::kSkipPadding: + return out << "kSkipPadding"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_189_1) + << "Invalid HeadersPayloadDecoder::PayloadState: " << unknown; + return out << "HeadersPayloadDecoder::PayloadState(" << unknown << ")"; +} + +DecodeStatus HeadersPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "HeadersPayloadDecoder::StartDecodingPayload: " + << frame_header; + + QUICHE_DCHECK_EQ(Http2FrameType::HEADERS, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + QUICHE_DCHECK_EQ( + 0, frame_header.flags & + ~(Http2FrameFlag::END_STREAM | Http2FrameFlag::END_HEADERS | + Http2FrameFlag::PADDED | Http2FrameFlag::PRIORITY)); + + // Special case for HEADERS frames that contain only the HPACK block + // (fragment or whole) and that fit fully into the decode buffer. + // Why? Unencoded browser GET requests are typically under 1K and HPACK + // commonly shrinks request headers by 80%, so we can expect this to + // be common. + // TODO(jamessynge) Add counters here and to Spdy for determining how + // common this situation is. A possible approach is to create a + // Http2FrameDecoderListener that counts the callbacks and then forwards + // them on to another listener, which makes it easy to add and remove + // counting on a connection or even frame basis. + + // PADDED and PRIORITY both extra steps to decode, but if neither flag is + // set then we can decode faster. + const auto payload_flags = Http2FrameFlag::PADDED | Http2FrameFlag::PRIORITY; + if (!frame_header.HasAnyFlags(payload_flags)) { + QUICHE_DVLOG(2) << "StartDecodingPayload !IsPadded && !HasPriority"; + if (db->Remaining() == total_length) { + QUICHE_DVLOG(2) << "StartDecodingPayload all present"; + // Note that we don't cache the listener field so that the callee can + // replace it if the frame is bad. + // If this case is common enough, consider combining the 3 callbacks + // into one, especially if END_HEADERS is also set. + state->listener()->OnHeadersStart(frame_header); + if (total_length > 0) { + state->listener()->OnHpackFragment(db->cursor(), total_length); + db->AdvanceCursor(total_length); + } + state->listener()->OnHeadersEnd(); + return DecodeStatus::kDecodeDone; + } + payload_state_ = PayloadState::kReadPayload; + } else if (frame_header.IsPadded()) { + payload_state_ = PayloadState::kReadPadLength; + } else { + QUICHE_DCHECK(frame_header.HasPriority()) << frame_header; + payload_state_ = PayloadState::kStartDecodingPriorityFields; + } + state->InitializeRemainders(); + state->listener()->OnHeadersStart(frame_header); + return ResumeDecodingPayload(state, db); +} + +DecodeStatus HeadersPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "HeadersPayloadDecoder::ResumeDecodingPayload " + << "remaining_payload=" << state->remaining_payload() + << "; db->Remaining=" << db->Remaining(); + + const Http2FrameHeader& frame_header = state->frame_header(); + + QUICHE_DCHECK_EQ(Http2FrameType::HEADERS, frame_header.type); + QUICHE_DCHECK_LE(state->remaining_payload_and_padding(), + frame_header.payload_length); + QUICHE_DCHECK_LE(db->Remaining(), state->remaining_payload_and_padding()); + DecodeStatus status; + size_t avail; + while (true) { + QUICHE_DVLOG(2) + << "HeadersPayloadDecoder::ResumeDecodingPayload payload_state_=" + << payload_state_; + switch (payload_state_) { + case PayloadState::kReadPadLength: + // ReadPadLength handles the OnPadLength callback, and updating the + // remaining_payload and remaining_padding fields. If the amount of + // padding is too large to fit in the frame's payload, ReadPadLength + // instead calls OnPaddingTooLong and returns kDecodeError. + status = state->ReadPadLength(db, /*report_pad_length*/ true); + if (status != DecodeStatus::kDecodeDone) { + return status; + } + if (!frame_header.HasPriority()) { + payload_state_ = PayloadState::kReadPayload; + continue; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kStartDecodingPriorityFields: + status = state->StartDecodingStructureInPayload(&priority_fields_, db); + if (status != DecodeStatus::kDecodeDone) { + payload_state_ = PayloadState::kResumeDecodingPriorityFields; + return status; + } + state->listener()->OnHeadersPriority(priority_fields_); + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kReadPayload: + avail = state->AvailablePayload(db); + if (avail > 0) { + state->listener()->OnHpackFragment(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() > 0) { + payload_state_ = PayloadState::kReadPayload; + return DecodeStatus::kDecodeInProgress; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kSkipPadding: + // SkipPadding handles the OnPadding callback. + if (state->SkipPadding(db)) { + state->listener()->OnHeadersEnd(); + return DecodeStatus::kDecodeDone; + } + payload_state_ = PayloadState::kSkipPadding; + return DecodeStatus::kDecodeInProgress; + + case PayloadState::kResumeDecodingPriorityFields: + status = state->ResumeDecodingStructureInPayload(&priority_fields_, db); + if (status != DecodeStatus::kDecodeDone) { + return status; + } + state->listener()->OnHeadersPriority(priority_fields_); + payload_state_ = PayloadState::kReadPayload; + continue; + } + QUICHE_BUG(http2_bug_189_2) << "PayloadState: " << payload_state_; + } +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/headers_payload_decoder.h b/quiche/http2/decoder/payload_decoders/headers_payload_decoder.h new file mode 100644 index 000000000000..05979cd313c2 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/headers_payload_decoder.h @@ -0,0 +1,67 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_HEADERS_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_HEADERS_PAYLOAD_DECODER_H_ + +// Decodes the payload of a HEADERS frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class HeadersPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT HeadersPayloadDecoder { + public: + // States during decoding of a HEADERS frame, unless the fast path kicks + // in, in which case the state machine will be bypassed. + enum class PayloadState { + // The PADDED flag is set, and we now need to read the Pad Length field + // (the first byte of the payload, after the common frame header). + kReadPadLength, + + // The PRIORITY flag is set, and we now need to read the fixed size priority + // fields (E, Stream Dependency, Weight) into priority_fields_. Calls on + // OnHeadersPriority if completely decodes those fields. + kStartDecodingPriorityFields, + + // The decoder passes the non-padding portion of the remaining payload + // (i.e. the HPACK block fragment) to the listener's OnHpackFragment method. + kReadPayload, + + // The decoder has finished with the HPACK block fragment, and is now + // ready to skip the trailing padding, if the frame has any. + kSkipPadding, + + // The fixed size fields weren't all available when the decoder first tried + // to decode them (state kStartDecodingPriorityFields); this state resumes + // the decoding when ResumeDecodingPayload is called later. + kResumeDecodingPriorityFields, + }; + + // Starts the decoding of a HEADERS frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a HEADERS frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::HeadersPayloadDecoderPeer; + + PayloadState payload_state_; + Http2PriorityFields priority_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_HEADERS_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/headers_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/headers_payload_decoder_test.cc new file mode 100644 index 000000000000..f6f9af150fb3 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/headers_payload_decoder_test.cc @@ -0,0 +1,158 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/headers_payload_decoder.h" + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class HeadersPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::HEADERS; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { + return Http2FrameFlag::PADDED | Http2FrameFlag::PRIORITY; + } +}; + +namespace { + +// Listener handles all On* methods that are expected to be called. If any other +// On* methods of Http2FrameDecoderListener is called then the test fails; this +// is achieved by way of FailingHttp2FrameDecoderListener, the base class of +// FramePartsCollector. +// These On* methods make use of StartFrame, EndFrame, etc. of the base class +// to create and access to FrameParts instance(s) that will record the details. +// After decoding, the test validation code can access the FramePart instance(s) +// via the public methods of FramePartsCollector. +struct Listener : public FramePartsCollector { + void OnHeadersStart(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnHeadersStart: " << header; + StartFrame(header)->OnHeadersStart(header); + } + + void OnHeadersPriority(const Http2PriorityFields& priority) override { + QUICHE_VLOG(1) << "OnHeadersPriority: " << priority; + CurrentFrame()->OnHeadersPriority(priority); + } + + void OnHpackFragment(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnHpackFragment: len=" << len; + CurrentFrame()->OnHpackFragment(data, len); + } + + void OnHeadersEnd() override { + QUICHE_VLOG(1) << "OnHeadersEnd"; + EndFrame()->OnHeadersEnd(); + } + + void OnPadLength(size_t pad_length) override { + QUICHE_VLOG(1) << "OnPadLength: " << pad_length; + CurrentFrame()->OnPadLength(pad_length); + } + + void OnPadding(const char* padding, size_t skipped_length) override { + QUICHE_VLOG(1) << "OnPadding: " << skipped_length; + CurrentFrame()->OnPadding(padding, skipped_length); + } + + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override { + QUICHE_VLOG(1) << "OnPaddingTooLong: " << header + << "; missing_length: " << missing_length; + FrameError(header)->OnPaddingTooLong(header, missing_length); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class HeadersPayloadDecoderTest + : public AbstractPaddablePayloadDecoderTest< + HeadersPayloadDecoder, HeadersPayloadDecoderPeer, Listener> {}; + +INSTANTIATE_TEST_SUITE_P(VariousPadLengths, HeadersPayloadDecoderTest, + ::testing::Values(0, 1, 2, 3, 4, 254, 255, 256)); + +// Decode various sizes of (fake) HPACK payload, both with and without the +// PRIORITY flag set. +TEST_P(HeadersPayloadDecoderTest, VariousHpackPayloadSizes) { + for (size_t hpack_size : {0, 1, 2, 3, 255, 256, 1024}) { + QUICHE_LOG(INFO) << "########### hpack_size = " << hpack_size + << " ###########"; + Http2PriorityFields priority(RandStreamId(), 1 + Random().Rand8(), + Random().OneIn(2)); + + for (bool has_priority : {false, true}) { + Reset(); + ASSERT_EQ(IsPadded() ? 1u : 0u, frame_builder_.size()); + uint8_t flags = RandFlags(); + if (has_priority) { + flags |= Http2FrameFlag::PRIORITY; + frame_builder_.Append(priority); + } + + std::string hpack_payload = Random().RandString(hpack_size); + frame_builder_.Append(hpack_payload); + + MaybeAppendTrailingPadding(); + Http2FrameHeader frame_header(frame_builder_.size(), + Http2FrameType::HEADERS, flags, + RandStreamId()); + set_frame_header(frame_header); + ScrubFlagsOfHeader(&frame_header); + FrameParts expected(frame_header, hpack_payload, total_pad_length_); + if (has_priority) { + expected.SetOptPriority(priority); + } + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(frame_builder_.buffer(), + expected)); + } + } +} + +// Confirm we get an error if the PRIORITY flag is set but the payload is +// not long enough, regardless of the amount of (valid) padding. +TEST_P(HeadersPayloadDecoderTest, Truncated) { + auto approve_size = [](size_t size) { + return size != Http2PriorityFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(Http2PriorityFields(RandStreamId(), 1 + Random().Rand8(), + Random().OneIn(2))); + EXPECT_TRUE(VerifyDetectsMultipleFrameSizeErrors( + Http2FrameFlag::PRIORITY, fb.buffer(), approve_size, total_pad_length_)); +} + +// Confirm we get an error if the PADDED flag is set but the payload is not +// long enough to hold even the Pad Length amount of padding. +TEST_P(HeadersPayloadDecoderTest, PaddingTooLong) { + EXPECT_TRUE(VerifyDetectsPaddingTooLong()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/ping_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/ping_payload_decoder.cc new file mode 100644 index 000000000000..c700ac291632 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/ping_payload_decoder.cc @@ -0,0 +1,90 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/ping_payload_decoder.h" + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace { +constexpr auto kOpaqueSize = Http2PingFields::EncodedSize(); +} + +DecodeStatus PingPayloadDecoder::StartDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "PingPayloadDecoder::StartDecodingPayload: " + << frame_header; + QUICHE_DCHECK_EQ(Http2FrameType::PING, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + QUICHE_DCHECK_EQ(0, frame_header.flags & ~(Http2FrameFlag::ACK)); + + // Is the payload entirely in the decode buffer and is it the correct size? + // Given the size of the header and payload (17 bytes total), this is most + // likely the case the vast majority of the time. + if (db->Remaining() == kOpaqueSize && total_length == kOpaqueSize) { + // Special case this situation as it allows us to avoid any copying; + // the other path makes two copies, first into the buffer in + // Http2StructureDecoder as it accumulates the 8 bytes of opaque data, + // and a second copy into the Http2PingFields member of in this class. + // This supports the claim that this decoder is (mostly) non-buffering. + static_assert(sizeof(Http2PingFields) == kOpaqueSize, + "If not, then can't enter this block!"); + auto* ping = reinterpret_cast(db->cursor()); + if (frame_header.IsAck()) { + state->listener()->OnPingAck(frame_header, *ping); + } else { + state->listener()->OnPing(frame_header, *ping); + } + db->AdvanceCursor(kOpaqueSize); + return DecodeStatus::kDecodeDone; + } + state->InitializeRemainders(); + return HandleStatus( + state, state->StartDecodingStructureInPayload(&ping_fields_, db)); +} + +DecodeStatus PingPayloadDecoder::ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db) { + QUICHE_DVLOG(2) << "ResumeDecodingPayload: remaining_payload=" + << state->remaining_payload(); + QUICHE_DCHECK_EQ(Http2FrameType::PING, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + return HandleStatus( + state, state->ResumeDecodingStructureInPayload(&ping_fields_, db)); +} + +DecodeStatus PingPayloadDecoder::HandleStatus(FrameDecoderState* state, + DecodeStatus status) { + QUICHE_DVLOG(2) << "HandleStatus: status=" << status + << "; remaining_payload=" << state->remaining_payload(); + if (status == DecodeStatus::kDecodeDone) { + if (state->remaining_payload() == 0) { + const Http2FrameHeader& frame_header = state->frame_header(); + if (frame_header.IsAck()) { + state->listener()->OnPingAck(frame_header, ping_fields_); + } else { + state->listener()->OnPing(frame_header, ping_fields_); + } + return DecodeStatus::kDecodeDone; + } + // Payload is too long. + return state->ReportFrameSizeError(); + } + // Not done decoding the structure. Either we've got more payload to decode, + // or we've run out because the payload is too short. + QUICHE_DCHECK( + (status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload(); + return status; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/ping_payload_decoder.h b/quiche/http2/decoder/payload_decoders/ping_payload_decoder.h new file mode 100644 index 000000000000..0b9e963dbf49 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/ping_payload_decoder.h @@ -0,0 +1,43 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PING_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PING_PAYLOAD_DECODER_H_ + +// Decodes the payload of a PING frame; for the RFC, see: +// http://httpwg.org/specs/rfc7540.html#PING + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class PingPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT PingPayloadDecoder { + public: + // Starts the decoding of a PING frame's payload, and completes it if the + // entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a PING frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::PingPayloadDecoderPeer; + + DecodeStatus HandleStatus(FrameDecoderState* state, DecodeStatus status); + + Http2PingFields ping_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PING_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/ping_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/ping_payload_decoder_test.cc new file mode 100644 index 000000000000..dba03b1437c4 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/ping_payload_decoder_test.cc @@ -0,0 +1,109 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/ping_payload_decoder.h" + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class PingPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { return Http2FrameType::PING; } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) override { + QUICHE_VLOG(1) << "OnPing: " << header << "; " << ping; + StartAndEndFrame(header)->OnPing(header, ping); + } + + void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) override { + QUICHE_VLOG(1) << "OnPingAck: " << header << "; " << ping; + StartAndEndFrame(header)->OnPingAck(header, ping); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class PingPayloadDecoderTest + : public AbstractPayloadDecoderTest { + protected: + Http2PingFields RandPingFields() { + Http2PingFields fields; + test::Randomize(&fields, RandomPtr()); + return fields; + } +}; + +// Confirm we get an error if the payload is not the correct size to hold +// exactly one Http2PingFields. +TEST_F(PingPayloadDecoderTest, WrongSize) { + auto approve_size = [](size_t size) { + return size != Http2PingFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(RandPingFields()); + fb.Append(RandPingFields()); + fb.Append(RandPingFields()); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +TEST_F(PingPayloadDecoderTest, Ping) { + for (int n = 0; n < 100; ++n) { + Http2PingFields fields = RandPingFields(); + Http2FrameBuilder fb; + fb.Append(fields); + Http2FrameHeader header(fb.size(), Http2FrameType::PING, + RandFlags() & ~Http2FrameFlag::ACK, RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + expected.SetOptPing(fields); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); + } +} + +TEST_F(PingPayloadDecoderTest, PingAck) { + for (int n = 0; n < 100; ++n) { + Http2PingFields fields; + Randomize(&fields, RandomPtr()); + Http2FrameBuilder fb; + fb.Append(fields); + Http2FrameHeader header(fb.size(), Http2FrameType::PING, + RandFlags() | Http2FrameFlag::ACK, RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + expected.SetOptPing(fields); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/priority_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/priority_payload_decoder.cc new file mode 100644 index 000000000000..1a5a57ba2f92 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/priority_payload_decoder.cc @@ -0,0 +1,62 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/priority_payload_decoder.h" + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus PriorityPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "PriorityPayloadDecoder::StartDecodingPayload: " + << state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::PRIORITY, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + // PRIORITY frames have no flags. + QUICHE_DCHECK_EQ(0, state->frame_header().flags); + state->InitializeRemainders(); + return HandleStatus( + state, state->StartDecodingStructureInPayload(&priority_fields_, db)); +} + +DecodeStatus PriorityPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "PriorityPayloadDecoder::ResumeDecodingPayload" + << " remaining_payload=" << state->remaining_payload() + << " db->Remaining=" << db->Remaining(); + QUICHE_DCHECK_EQ(Http2FrameType::PRIORITY, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + return HandleStatus( + state, state->ResumeDecodingStructureInPayload(&priority_fields_, db)); +} + +DecodeStatus PriorityPayloadDecoder::HandleStatus(FrameDecoderState* state, + DecodeStatus status) { + if (status == DecodeStatus::kDecodeDone) { + if (state->remaining_payload() == 0) { + state->listener()->OnPriorityFrame(state->frame_header(), + priority_fields_); + return DecodeStatus::kDecodeDone; + } + // Payload is too long. + return state->ReportFrameSizeError(); + } + // Not done decoding the structure. Either we've got more payload to decode, + // or we've run out because the payload is too short, in which case + // OnFrameSizeError will have already been called. + QUICHE_DCHECK( + (status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload(); + return status; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/priority_payload_decoder.h b/quiche/http2/decoder/payload_decoders/priority_payload_decoder.h new file mode 100644 index 000000000000..2056ff968ab1 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/priority_payload_decoder.h @@ -0,0 +1,44 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PRIORITY_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PRIORITY_PAYLOAD_DECODER_H_ + +// Decodes the payload of a PRIORITY frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class PriorityPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT PriorityPayloadDecoder { + public: + // Starts the decoding of a PRIORITY frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a PRIORITY frame that has been split across decode + // buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::PriorityPayloadDecoderPeer; + + // Determines whether to report the PRIORITY to the listener, wait for more + // input, or to report a Frame Size Error. + DecodeStatus HandleStatus(FrameDecoderState* state, DecodeStatus status); + + Http2PriorityFields priority_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PRIORITY_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/priority_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/priority_payload_decoder_test.cc new file mode 100644 index 000000000000..573433b49ad4 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/priority_payload_decoder_test.cc @@ -0,0 +1,89 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/priority_payload_decoder.h" + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class PriorityPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::PRIORITY; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority_fields) override { + QUICHE_VLOG(1) << "OnPriority: " << header << "; " << priority_fields; + StartAndEndFrame(header)->OnPriorityFrame(header, priority_fields); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class PriorityPayloadDecoderTest + : public AbstractPayloadDecoderTest { + protected: + Http2PriorityFields RandPriorityFields() { + Http2PriorityFields fields; + test::Randomize(&fields, RandomPtr()); + return fields; + } +}; + +// Confirm we get an error if the payload is not the correct size to hold +// exactly one Http2PriorityFields. +TEST_F(PriorityPayloadDecoderTest, WrongSize) { + auto approve_size = [](size_t size) { + return size != Http2PriorityFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(RandPriorityFields()); + fb.Append(RandPriorityFields()); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +TEST_F(PriorityPayloadDecoderTest, VariousPayloads) { + for (int n = 0; n < 100; ++n) { + Http2PriorityFields fields = RandPriorityFields(); + Http2FrameBuilder fb; + fb.Append(fields); + Http2FrameHeader header(fb.size(), Http2FrameType::PRIORITY, RandFlags(), + RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + expected.SetOptPriority(fields); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.cc new file mode 100644 index 000000000000..2c74251c2043 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.cc @@ -0,0 +1,123 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + PriorityUpdatePayloadDecoder::PayloadState v) { + switch (v) { + case PriorityUpdatePayloadDecoder::PayloadState::kStartDecodingFixedFields: + return out << "kStartDecodingFixedFields"; + case PriorityUpdatePayloadDecoder::PayloadState::kResumeDecodingFixedFields: + return out << "kResumeDecodingFixedFields"; + case PriorityUpdatePayloadDecoder::PayloadState::kHandleFixedFieldsStatus: + return out << "kHandleFixedFieldsStatus"; + case PriorityUpdatePayloadDecoder::PayloadState::kReadPriorityFieldValue: + return out << "kReadPriorityFieldValue"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_173_1) + << "Invalid PriorityUpdatePayloadDecoder::PayloadState: " << unknown; + return out << "PriorityUpdatePayloadDecoder::PayloadState(" << unknown << ")"; +} + +DecodeStatus PriorityUpdatePayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "PriorityUpdatePayloadDecoder::StartDecodingPayload: " + << state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::PRIORITY_UPDATE, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + QUICHE_DCHECK_EQ(0, state->frame_header().flags); + + state->InitializeRemainders(); + payload_state_ = PayloadState::kStartDecodingFixedFields; + return ResumeDecodingPayload(state, db); +} + +DecodeStatus PriorityUpdatePayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "PriorityUpdatePayloadDecoder::ResumeDecodingPayload: " + "remaining_payload=" + << state->remaining_payload() + << ", db->Remaining=" << db->Remaining(); + + const Http2FrameHeader& frame_header = state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::PRIORITY_UPDATE, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), frame_header.payload_length); + QUICHE_DCHECK_NE(PayloadState::kHandleFixedFieldsStatus, payload_state_); + + // |status| has to be initialized to some value to avoid compiler error in + // case PayloadState::kHandleFixedFieldsStatus below, but value does not + // matter, see QUICHE_DCHECK_NE above. + DecodeStatus status = DecodeStatus::kDecodeError; + size_t avail; + while (true) { + QUICHE_DVLOG(2) + << "PriorityUpdatePayloadDecoder::ResumeDecodingPayload payload_state_=" + << payload_state_; + switch (payload_state_) { + case PayloadState::kStartDecodingFixedFields: + status = state->StartDecodingStructureInPayload( + &priority_update_fields_, db); + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kHandleFixedFieldsStatus: + if (status == DecodeStatus::kDecodeDone) { + state->listener()->OnPriorityUpdateStart(frame_header, + priority_update_fields_); + } else { + // Not done decoding the structure. Either we've got more payload + // to decode, or we've run out because the payload is too short, + // in which case OnFrameSizeError will have already been called. + QUICHE_DCHECK((status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && + state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload(); + payload_state_ = PayloadState::kResumeDecodingFixedFields; + return status; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kReadPriorityFieldValue: + // Anything left in the decode buffer is the Priority Field Value. + avail = db->Remaining(); + if (avail > 0) { + state->listener()->OnPriorityUpdatePayload(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() > 0) { + payload_state_ = PayloadState::kReadPriorityFieldValue; + return DecodeStatus::kDecodeInProgress; + } + state->listener()->OnPriorityUpdateEnd(); + return DecodeStatus::kDecodeDone; + + case PayloadState::kResumeDecodingFixedFields: + status = state->ResumeDecodingStructureInPayload( + &priority_update_fields_, db); + payload_state_ = PayloadState::kHandleFixedFieldsStatus; + continue; + } + QUICHE_BUG(http2_bug_173_2) << "PayloadState: " << payload_state_; + } +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h b/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h new file mode 100644 index 000000000000..39d081765bdb --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h @@ -0,0 +1,63 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PRIORITY_UPDATE_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PRIORITY_UPDATE_PAYLOAD_DECODER_H_ + +// Decodes the payload of a PRIORITY_UPDATE frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class PriorityUpdatePayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT PriorityUpdatePayloadDecoder { + public: + // States during decoding of a PRIORITY_UPDATE frame. + enum class PayloadState { + // At the start of the PRIORITY_UPDATE frame payload, ready to start + // decoding the fixed size fields into priority_update_fields_. + kStartDecodingFixedFields, + + // The fixed size fields weren't all available when the decoder first + // tried to decode them; this state resumes the decoding when + // ResumeDecodingPayload is called later. + kResumeDecodingFixedFields, + + // Handle the DecodeStatus returned from starting or resuming the decoding + // of Http2PriorityUpdateFields into priority_update_fields_. If complete, + // calls OnPriorityUpdateStart. + kHandleFixedFieldsStatus, + + // Report the Priority Field Value portion of the payload to the listener's + // OnPriorityUpdatePayload method, and call OnPriorityUpdateEnd when the end + // of the payload is reached. + kReadPriorityFieldValue, + }; + + // Starts the decoding of a PRIORITY_UPDATE frame's payload, and completes it + // if the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a PRIORITY_UPDATE frame that has been split across decode + // buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::PriorityUpdatePayloadDecoderPeer; + + Http2PriorityUpdateFields priority_update_fields_; + PayloadState payload_state_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PRIORITY_UPDATE_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc new file mode 100644 index 000000000000..14f3557b2bf4 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/priority_update_payload_decoder_test.cc @@ -0,0 +1,114 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/priority_update_payload_decoder.h" + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class PriorityUpdatePayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::PRIORITY_UPDATE; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) override { + QUICHE_VLOG(1) << "OnPriorityUpdateStart header: " << header + << "; priority_update: " << priority_update; + StartFrame(header)->OnPriorityUpdateStart(header, priority_update); + } + + void OnPriorityUpdatePayload(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnPriorityUpdatePayload: len=" << len; + CurrentFrame()->OnPriorityUpdatePayload(data, len); + } + + void OnPriorityUpdateEnd() override { + QUICHE_VLOG(1) << "OnPriorityUpdateEnd"; + EndFrame()->OnPriorityUpdateEnd(); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class PriorityUpdatePayloadDecoderTest + : public AbstractPayloadDecoderTest {}; + +// Confirm we get an error if the payload is not long enough to hold +// Http2PriorityUpdateFields. +TEST_F(PriorityUpdatePayloadDecoderTest, Truncated) { + auto approve_size = [](size_t size) { + return size != Http2PriorityUpdateFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(Http2PriorityUpdateFields(123)); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +class PriorityUpdatePayloadLengthTests + : public AbstractPayloadDecoderTest, + public ::testing::WithParamInterface { + protected: + PriorityUpdatePayloadLengthTests() : length_(GetParam()) { + QUICHE_VLOG(1) << "################ length_=" << length_ + << " ################"; + } + + const uint32_t length_; +}; + +INSTANTIATE_TEST_SUITE_P(VariousLengths, PriorityUpdatePayloadLengthTests, + ::testing::Values(0, 1, 2, 3, 4, 5, 6)); + +TEST_P(PriorityUpdatePayloadLengthTests, ValidLength) { + Http2PriorityUpdateFields priority_update; + Randomize(&priority_update, RandomPtr()); + std::string priority_field_value = Random().RandString(length_); + Http2FrameBuilder fb; + fb.Append(priority_update); + fb.Append(priority_field_value); + Http2FrameHeader header(fb.size(), Http2FrameType::PRIORITY_UPDATE, + RandFlags(), RandStreamId()); + set_frame_header(header); + FrameParts expected(header, priority_field_value); + expected.SetOptPriorityUpdate(Http2PriorityUpdateFields{priority_update}); + ASSERT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.cc new file mode 100644 index 000000000000..50a403e1126a --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.cc @@ -0,0 +1,171 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + PushPromisePayloadDecoder::PayloadState v) { + switch (v) { + case PushPromisePayloadDecoder::PayloadState::kReadPadLength: + return out << "kReadPadLength"; + case PushPromisePayloadDecoder::PayloadState:: + kStartDecodingPushPromiseFields: + return out << "kStartDecodingPushPromiseFields"; + case PushPromisePayloadDecoder::PayloadState::kReadPayload: + return out << "kReadPayload"; + case PushPromisePayloadDecoder::PayloadState::kSkipPadding: + return out << "kSkipPadding"; + case PushPromisePayloadDecoder::PayloadState:: + kResumeDecodingPushPromiseFields: + return out << "kResumeDecodingPushPromiseFields"; + } + return out << static_cast(v); +} + +DecodeStatus PushPromisePayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "PushPromisePayloadDecoder::StartDecodingPayload: " + << frame_header; + + QUICHE_DCHECK_EQ(Http2FrameType::PUSH_PROMISE, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + QUICHE_DCHECK_EQ(0, frame_header.flags & ~(Http2FrameFlag::END_HEADERS | + Http2FrameFlag::PADDED)); + + if (!frame_header.IsPadded()) { + // If it turns out that PUSH_PROMISE frames without padding are sufficiently + // common, and that they are usually short enough that they fit entirely + // into one DecodeBuffer, we can detect that here and implement a special + // case, avoiding the state machine in ResumeDecodingPayload. + payload_state_ = PayloadState::kStartDecodingPushPromiseFields; + } else { + payload_state_ = PayloadState::kReadPadLength; + } + state->InitializeRemainders(); + return ResumeDecodingPayload(state, db); +} + +DecodeStatus PushPromisePayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "UnknownPayloadDecoder::ResumeDecodingPayload" + << " remaining_payload=" << state->remaining_payload() + << " db->Remaining=" << db->Remaining(); + + const Http2FrameHeader& frame_header = state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::PUSH_PROMISE, frame_header.type); + QUICHE_DCHECK_LE(state->remaining_payload(), frame_header.payload_length); + QUICHE_DCHECK_LE(db->Remaining(), frame_header.payload_length); + + DecodeStatus status; + while (true) { + QUICHE_DVLOG(2) + << "PushPromisePayloadDecoder::ResumeDecodingPayload payload_state_=" + << payload_state_; + switch (payload_state_) { + case PayloadState::kReadPadLength: + QUICHE_DCHECK_EQ(state->remaining_payload(), + frame_header.payload_length); + // ReadPadLength handles the OnPadLength callback, and updating the + // remaining_payload and remaining_padding fields. If the amount of + // padding is too large to fit in the frame's payload, ReadPadLength + // instead calls OnPaddingTooLong and returns kDecodeError. + // Suppress the call to OnPadLength because we haven't yet called + // OnPushPromiseStart, which needs to wait until we've decoded the + // Promised Stream ID. + status = state->ReadPadLength(db, /*report_pad_length*/ false); + if (status != DecodeStatus::kDecodeDone) { + payload_state_ = PayloadState::kReadPadLength; + return status; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kStartDecodingPushPromiseFields: + status = + state->StartDecodingStructureInPayload(&push_promise_fields_, db); + if (status != DecodeStatus::kDecodeDone) { + payload_state_ = PayloadState::kResumeDecodingPushPromiseFields; + return status; + } + // Finished decoding the Promised Stream ID. Can now tell the listener + // that we're starting to decode a PUSH_PROMISE frame. + ReportPushPromise(state); + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kReadPayload: + QUICHE_DCHECK_LT(state->remaining_payload(), + frame_header.payload_length); + QUICHE_DCHECK_LE(state->remaining_payload(), + frame_header.payload_length - + Http2PushPromiseFields::EncodedSize()); + QUICHE_DCHECK_LE( + state->remaining_payload(), + frame_header.payload_length - + Http2PushPromiseFields::EncodedSize() - + (frame_header.IsPadded() ? (1 + state->remaining_padding()) + : 0)); + { + size_t avail = state->AvailablePayload(db); + state->listener()->OnHpackFragment(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() > 0) { + payload_state_ = PayloadState::kReadPayload; + return DecodeStatus::kDecodeInProgress; + } + ABSL_FALLTHROUGH_INTENDED; + + case PayloadState::kSkipPadding: + // SkipPadding handles the OnPadding callback. + if (state->SkipPadding(db)) { + state->listener()->OnPushPromiseEnd(); + return DecodeStatus::kDecodeDone; + } + payload_state_ = PayloadState::kSkipPadding; + return DecodeStatus::kDecodeInProgress; + + case PayloadState::kResumeDecodingPushPromiseFields: + status = + state->ResumeDecodingStructureInPayload(&push_promise_fields_, db); + if (status == DecodeStatus::kDecodeDone) { + // Finished decoding the Promised Stream ID. Can now tell the listener + // that we're starting to decode a PUSH_PROMISE frame. + ReportPushPromise(state); + payload_state_ = PayloadState::kReadPayload; + continue; + } + payload_state_ = PayloadState::kResumeDecodingPushPromiseFields; + return status; + } + QUICHE_BUG(http2_bug_183_1) << "PayloadState: " << payload_state_; + } +} + +void PushPromisePayloadDecoder::ReportPushPromise(FrameDecoderState* state) { + const Http2FrameHeader& frame_header = state->frame_header(); + if (frame_header.IsPadded()) { + state->listener()->OnPushPromiseStart(frame_header, push_promise_fields_, + 1 + state->remaining_padding()); + } else { + state->listener()->OnPushPromiseStart(frame_header, push_promise_fields_, + 0); + } +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h b/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h new file mode 100644 index 000000000000..c6bf2a2b1a24 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h @@ -0,0 +1,66 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PUSH_PROMISE_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PUSH_PROMISE_PAYLOAD_DECODER_H_ + +// Decodes the payload of a PUSH_PROMISE frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class PushPromisePayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT PushPromisePayloadDecoder { + public: + // States during decoding of a PUSH_PROMISE frame. + enum class PayloadState { + // The frame is padded and we need to read the PAD_LENGTH field (1 byte). + kReadPadLength, + + // Ready to start decoding the fixed size fields of the PUSH_PROMISE + // frame into push_promise_fields_. + kStartDecodingPushPromiseFields, + + // The decoder has already called OnPushPromiseStart, and is now reporting + // the HPACK block fragment to the listener's OnHpackFragment method. + kReadPayload, + + // The decoder has finished with the HPACK block fragment, and is now + // ready to skip the trailing padding, if the frame has any. + kSkipPadding, + + // The fixed size fields weren't all available when the decoder first tried + // to decode them (state kStartDecodingPushPromiseFields); this state + // resumes the decoding when ResumeDecodingPayload is called later. + kResumeDecodingPushPromiseFields, + }; + + // Starts the decoding of a PUSH_PROMISE frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a PUSH_PROMISE frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::PushPromisePayloadDecoderPeer; + + void ReportPushPromise(FrameDecoderState* state); + + PayloadState payload_state_; + Http2PushPromiseFields push_promise_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_PUSH_PROMISE_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc new file mode 100644 index 000000000000..2ea5e62e1eec --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/push_promise_payload_decoder_test.cc @@ -0,0 +1,138 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/push_promise_payload_decoder.h" + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +// Provides friend access to an instance of the payload decoder, and also +// provides info to aid in testing. +class PushPromisePayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::PUSH_PROMISE; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { + return Http2FrameFlag::PADDED; + } +}; + +namespace { + +// Listener listens for only those methods expected by the payload decoder +// under test, and forwards them onto the FrameParts instance for the current +// frame. +struct Listener : public FramePartsCollector { + void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) override { + QUICHE_VLOG(1) << "OnPushPromiseStart header: " << header + << " promise: " << promise + << " total_padding_length: " << total_padding_length; + EXPECT_EQ(Http2FrameType::PUSH_PROMISE, header.type); + StartFrame(header)->OnPushPromiseStart(header, promise, + total_padding_length); + } + + void OnHpackFragment(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnHpackFragment: len=" << len; + CurrentFrame()->OnHpackFragment(data, len); + } + + void OnPushPromiseEnd() override { + QUICHE_VLOG(1) << "OnPushPromiseEnd"; + EndFrame()->OnPushPromiseEnd(); + } + + void OnPadding(const char* padding, size_t skipped_length) override { + QUICHE_VLOG(1) << "OnPadding: " << skipped_length; + CurrentFrame()->OnPadding(padding, skipped_length); + } + + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override { + QUICHE_VLOG(1) << "OnPaddingTooLong: " << header + << "; missing_length: " << missing_length; + FrameError(header)->OnPaddingTooLong(header, missing_length); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class PushPromisePayloadDecoderTest + : public AbstractPaddablePayloadDecoderTest< + PushPromisePayloadDecoder, PushPromisePayloadDecoderPeer, Listener> { +}; + +INSTANTIATE_TEST_SUITE_P(VariousPadLengths, PushPromisePayloadDecoderTest, + ::testing::Values(0, 1, 2, 3, 4, 254, 255, 256)); + +// Payload contains the required Http2PushPromiseFields, followed by some +// (fake) HPACK payload. +TEST_P(PushPromisePayloadDecoderTest, VariousHpackPayloadSizes) { + for (size_t hpack_size : {0, 1, 2, 3, 255, 256, 1024}) { + QUICHE_LOG(INFO) << "########### hpack_size = " << hpack_size + << " ###########"; + Reset(); + std::string hpack_payload = Random().RandString(hpack_size); + Http2PushPromiseFields push_promise{RandStreamId()}; + frame_builder_.Append(push_promise); + frame_builder_.Append(hpack_payload); + MaybeAppendTrailingPadding(); + Http2FrameHeader frame_header(frame_builder_.size(), + Http2FrameType::PUSH_PROMISE, RandFlags(), + RandStreamId()); + set_frame_header(frame_header); + FrameParts expected(frame_header, hpack_payload, total_pad_length_); + expected.SetOptPushPromise(push_promise); + EXPECT_TRUE( + DecodePayloadAndValidateSeveralWays(frame_builder_.buffer(), expected)); + } +} + +// Confirm we get an error if the payload is not long enough for the required +// portion of the payload, regardless of the amount of (valid) padding. +TEST_P(PushPromisePayloadDecoderTest, Truncated) { + auto approve_size = [](size_t size) { + return size != Http2PushPromiseFields::EncodedSize(); + }; + Http2PushPromiseFields push_promise{RandStreamId()}; + Http2FrameBuilder fb; + fb.Append(push_promise); + EXPECT_TRUE(VerifyDetectsMultipleFrameSizeErrors(0, fb.buffer(), approve_size, + total_pad_length_)); +} + +// Confirm we get an error if the PADDED flag is set but the payload is not +// long enough to hold even the Pad Length amount of padding. +TEST_P(PushPromisePayloadDecoderTest, PaddingTooLong) { + EXPECT_TRUE(VerifyDetectsPaddingTooLong()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc new file mode 100644 index 000000000000..27d410c154a4 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc @@ -0,0 +1,64 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h" + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus RstStreamPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "RstStreamPayloadDecoder::StartDecodingPayload: " + << state->frame_header(); + QUICHE_DCHECK_EQ(Http2FrameType::RST_STREAM, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + // RST_STREAM has no flags. + QUICHE_DCHECK_EQ(0, state->frame_header().flags); + state->InitializeRemainders(); + return HandleStatus( + state, state->StartDecodingStructureInPayload(&rst_stream_fields_, db)); +} + +DecodeStatus RstStreamPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "RstStreamPayloadDecoder::ResumeDecodingPayload" + << " remaining_payload=" << state->remaining_payload() + << " db->Remaining=" << db->Remaining(); + QUICHE_DCHECK_EQ(Http2FrameType::RST_STREAM, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + return HandleStatus( + state, state->ResumeDecodingStructureInPayload(&rst_stream_fields_, db)); +} + +DecodeStatus RstStreamPayloadDecoder::HandleStatus(FrameDecoderState* state, + DecodeStatus status) { + QUICHE_DVLOG(2) << "HandleStatus: status=" << status + << "; remaining_payload=" << state->remaining_payload(); + if (status == DecodeStatus::kDecodeDone) { + if (state->remaining_payload() == 0) { + state->listener()->OnRstStream(state->frame_header(), + rst_stream_fields_.error_code); + return DecodeStatus::kDecodeDone; + } + // Payload is too long. + return state->ReportFrameSizeError(); + } + // Not done decoding the structure. Either we've got more payload to decode, + // or we've run out because the payload is too short, in which case + // OnFrameSizeError will have already been called by the FrameDecoderState. + QUICHE_DCHECK( + (status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload(); + return status; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h b/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h new file mode 100644 index 000000000000..68e701ef7249 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h @@ -0,0 +1,42 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_RST_STREAM_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_RST_STREAM_PAYLOAD_DECODER_H_ + +// Decodes the payload of a RST_STREAM frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class RstStreamPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT RstStreamPayloadDecoder { + public: + // Starts the decoding of a RST_STREAM frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a RST_STREAM frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::RstStreamPayloadDecoderPeer; + + DecodeStatus HandleStatus(FrameDecoderState* state, DecodeStatus status); + + Http2RstStreamFields rst_stream_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_RST_STREAM_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc new file mode 100644 index 000000000000..999d7d07f860 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder_test.cc @@ -0,0 +1,92 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/rst_stream_payload_decoder.h" + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_constants_test_util.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class RstStreamPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::RST_STREAM; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) override { + QUICHE_VLOG(1) << "OnRstStream: " << header + << "; error_code=" << error_code; + StartAndEndFrame(header)->OnRstStream(header, error_code); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class RstStreamPayloadDecoderTest + : public AbstractPayloadDecoderTest { + protected: + Http2RstStreamFields RandRstStreamFields() { + Http2RstStreamFields fields; + test::Randomize(&fields, RandomPtr()); + return fields; + } +}; + +// Confirm we get an error if the payload is not the correct size to hold +// exactly one Http2RstStreamFields. +TEST_F(RstStreamPayloadDecoderTest, WrongSize) { + auto approve_size = [](size_t size) { + return size != Http2RstStreamFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(RandRstStreamFields()); + fb.Append(RandRstStreamFields()); + fb.Append(RandRstStreamFields()); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +TEST_F(RstStreamPayloadDecoderTest, AllErrors) { + for (auto error_code : AllHttp2ErrorCodes()) { + Http2RstStreamFields fields{error_code}; + Http2FrameBuilder fb; + fb.Append(fields); + Http2FrameHeader header(fb.size(), Http2FrameType::RST_STREAM, RandFlags(), + RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + expected.SetOptRstStreamErrorCode(error_code); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/settings_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/settings_payload_decoder.cc new file mode 100644 index 000000000000..0ba328ff3e15 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/settings_payload_decoder.cc @@ -0,0 +1,95 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/settings_payload_decoder.h" + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus SettingsPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "SettingsPayloadDecoder::StartDecodingPayload: " + << frame_header; + QUICHE_DCHECK_EQ(Http2FrameType::SETTINGS, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + QUICHE_DCHECK_EQ(0, frame_header.flags & ~(Http2FrameFlag::ACK)); + + if (frame_header.IsAck()) { + if (total_length == 0) { + state->listener()->OnSettingsAck(frame_header); + return DecodeStatus::kDecodeDone; + } else { + state->InitializeRemainders(); + return state->ReportFrameSizeError(); + } + } else { + state->InitializeRemainders(); + state->listener()->OnSettingsStart(frame_header); + return StartDecodingSettings(state, db); + } +} + +DecodeStatus SettingsPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "SettingsPayloadDecoder::ResumeDecodingPayload" + << " remaining_payload=" << state->remaining_payload() + << " db->Remaining=" << db->Remaining(); + QUICHE_DCHECK_EQ(Http2FrameType::SETTINGS, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + + DecodeStatus status = + state->ResumeDecodingStructureInPayload(&setting_fields_, db); + if (status == DecodeStatus::kDecodeDone) { + state->listener()->OnSetting(setting_fields_); + return StartDecodingSettings(state, db); + } + return HandleNotDone(state, db, status); +} + +DecodeStatus SettingsPayloadDecoder::StartDecodingSettings( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "SettingsPayloadDecoder::StartDecodingSettings" + << " remaining_payload=" << state->remaining_payload() + << " db->Remaining=" << db->Remaining(); + while (state->remaining_payload() > 0) { + DecodeStatus status = + state->StartDecodingStructureInPayload(&setting_fields_, db); + if (status == DecodeStatus::kDecodeDone) { + state->listener()->OnSetting(setting_fields_); + continue; + } + return HandleNotDone(state, db, status); + } + QUICHE_DVLOG(2) << "LEAVING SettingsPayloadDecoder::StartDecodingSettings" + << "\n\tdb->Remaining=" << db->Remaining() + << "\n\t remaining_payload=" << state->remaining_payload(); + state->listener()->OnSettingsEnd(); + return DecodeStatus::kDecodeDone; +} + +DecodeStatus SettingsPayloadDecoder::HandleNotDone(FrameDecoderState* state, + DecodeBuffer* db, + DecodeStatus status) { + // Not done decoding the structure. Either we've got more payload to decode, + // or we've run out because the payload is too short, in which case + // OnFrameSizeError will have already been called. + QUICHE_DCHECK( + (status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload() + << "; db->Remaining=" << db->Remaining(); + return status; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/settings_payload_decoder.h b/quiche/http2/decoder/payload_decoders/settings_payload_decoder.h new file mode 100644 index 000000000000..c7a62662dfa2 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/settings_payload_decoder.h @@ -0,0 +1,53 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_SETTINGS_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_SETTINGS_PAYLOAD_DECODER_H_ + +// Decodes the payload of a SETTINGS frame; for the RFC, see: +// http://httpwg.org/specs/rfc7540.html#SETTINGS + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class SettingsPayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT SettingsPayloadDecoder { + public: + // Starts the decoding of a SETTINGS frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a SETTINGS frame that has been split across decode + // buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::SettingsPayloadDecoderPeer; + + // Decodes as many settings as are available in the decode buffer, starting at + // the first byte of one setting; if a single setting is split across buffers, + // ResumeDecodingPayload will handle starting from where the previous call + // left off, and then will call StartDecodingSettings. + DecodeStatus StartDecodingSettings(FrameDecoderState* state, + DecodeBuffer* db); + + // Decoding a single SETTING returned a status other than kDecodeDone; this + // method just brings together the QUICHE_DCHECKs to reduce duplication. + DecodeStatus HandleNotDone(FrameDecoderState* state, DecodeBuffer* db, + DecodeStatus status); + + Http2SettingFields setting_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_SETTINGS_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/settings_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/settings_payload_decoder_test.cc new file mode 100644 index 000000000000..3533fd4f8c20 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/settings_payload_decoder_test.cc @@ -0,0 +1,159 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/settings_payload_decoder.h" + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_constants_test_util.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class SettingsPayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::SETTINGS; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { + return Http2FrameFlag::ACK; + } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnSettingsStart(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnSettingsStart: " << header; + EXPECT_EQ(Http2FrameType::SETTINGS, header.type) << header; + EXPECT_EQ(Http2FrameFlag(), header.flags) << header; + StartFrame(header)->OnSettingsStart(header); + } + + void OnSetting(const Http2SettingFields& setting_fields) override { + QUICHE_VLOG(1) << "Http2SettingFields: setting_fields=" << setting_fields; + CurrentFrame()->OnSetting(setting_fields); + } + + void OnSettingsEnd() override { + QUICHE_VLOG(1) << "OnSettingsEnd"; + EndFrame()->OnSettingsEnd(); + } + + void OnSettingsAck(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnSettingsAck: " << header; + StartAndEndFrame(header)->OnSettingsAck(header); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class SettingsPayloadDecoderTest + : public AbstractPayloadDecoderTest { + protected: + Http2SettingFields RandSettingsFields() { + Http2SettingFields fields; + test::Randomize(&fields, RandomPtr()); + return fields; + } +}; + +// Confirm we get an error if the SETTINGS payload is not the correct size +// to hold exactly zero or more whole Http2SettingFields. +TEST_F(SettingsPayloadDecoderTest, SettingsWrongSize) { + auto approve_size = [](size_t size) { + // Should get an error if size is not an integral multiple of the size + // of one setting. + return 0 != (size % Http2SettingFields::EncodedSize()); + }; + Http2FrameBuilder fb; + fb.Append(RandSettingsFields()); + fb.Append(RandSettingsFields()); + fb.Append(RandSettingsFields()); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +// Confirm we get an error if the SETTINGS ACK payload is not empty. +TEST_F(SettingsPayloadDecoderTest, SettingsAkcWrongSize) { + auto approve_size = [](size_t size) { return size != 0; }; + Http2FrameBuilder fb; + fb.Append(RandSettingsFields()); + fb.Append(RandSettingsFields()); + fb.Append(RandSettingsFields()); + EXPECT_TRUE(VerifyDetectsFrameSizeError(Http2FrameFlag::ACK, fb.buffer(), + approve_size)); +} + +// SETTINGS must have stream_id==0, but the payload decoder doesn't check that. +TEST_F(SettingsPayloadDecoderTest, SettingsAck) { + for (int stream_id = 0; stream_id < 3; ++stream_id) { + Http2FrameHeader header(0, Http2FrameType::SETTINGS, + RandFlags() | Http2FrameFlag::ACK, stream_id); + set_frame_header(header); + FrameParts expected(header); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays("", expected)); + } +} + +// Try several values of each known SETTINGS parameter. +TEST_F(SettingsPayloadDecoderTest, OneRealSetting) { + std::vector values = {0, 1, 0xffffffff, Random().Rand32()}; + for (auto param : AllHttp2SettingsParameters()) { + for (uint32_t value : values) { + Http2SettingFields fields(param, value); + Http2FrameBuilder fb; + fb.Append(fields); + Http2FrameHeader header(fb.size(), Http2FrameType::SETTINGS, RandFlags(), + RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + expected.AppendSetting(fields); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); + } + } +} + +// Decode a SETTINGS frame with lots of fields. +TEST_F(SettingsPayloadDecoderTest, ManySettings) { + const size_t num_settings = 100; + const size_t size = Http2SettingFields::EncodedSize() * num_settings; + Http2FrameHeader header(size, Http2FrameType::SETTINGS, + RandFlags(), // & ~Http2FrameFlag::ACK, + RandStreamId()); + set_frame_header(header); + FrameParts expected(header); + Http2FrameBuilder fb; + for (size_t n = 0; n < num_settings; ++n) { + Http2SettingFields fields(static_cast(n), + Random().Rand32()); + fb.Append(fields); + expected.AppendSetting(fields); + } + ASSERT_EQ(size, fb.size()); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.cc new file mode 100644 index 000000000000..bb8dae7ac927 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.cc @@ -0,0 +1,55 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h" + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus UnknownPayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + + QUICHE_DVLOG(2) << "UnknownPayloadDecoder::StartDecodingPayload: " + << frame_header; + QUICHE_DCHECK(!IsSupportedHttp2FrameType(frame_header.type)) << frame_header; + QUICHE_DCHECK_LE(db->Remaining(), frame_header.payload_length); + + state->InitializeRemainders(); + state->listener()->OnUnknownStart(frame_header); + return ResumeDecodingPayload(state, db); +} + +DecodeStatus UnknownPayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "UnknownPayloadDecoder::ResumeDecodingPayload " + << "remaining_payload=" << state->remaining_payload() + << "; db->Remaining=" << db->Remaining(); + QUICHE_DCHECK(!IsSupportedHttp2FrameType(state->frame_header().type)) + << state->frame_header(); + QUICHE_DCHECK_LE(state->remaining_payload(), + state->frame_header().payload_length); + QUICHE_DCHECK_LE(db->Remaining(), state->remaining_payload()); + + size_t avail = db->Remaining(); + if (avail > 0) { + state->listener()->OnUnknownPayload(db->cursor(), avail); + db->AdvanceCursor(avail); + state->ConsumePayload(avail); + } + if (state->remaining_payload() == 0) { + state->listener()->OnUnknownEnd(); + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h b/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h new file mode 100644 index 000000000000..bdc5d81e7f94 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h @@ -0,0 +1,33 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_UNKNOWN_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_UNKNOWN_PAYLOAD_DECODER_H_ + +// Decodes the payload of a frame whose type unknown. According to the HTTP/2 +// specification (http://httpwg.org/specs/rfc7540.html#FrameHeader): +// Implementations MUST ignore and discard any frame that has +// a type that is unknown. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +class QUICHE_EXPORT UnknownPayloadDecoder { + public: + // Starts decoding a payload of unknown type; just passes it to the listener. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a payload of unknown type that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_UNKNOWN_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/unknown_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/unknown_payload_decoder_test.cc new file mode 100644 index 000000000000..47f89b7dd6fe --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/unknown_payload_decoder_test.cc @@ -0,0 +1,99 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/unknown_payload_decoder.h" + +#include + +#include +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { +Http2FrameType g_unknown_frame_type; +} // namespace + +// Provides friend access to an instance of the payload decoder, and also +// provides info to aid in testing. +class UnknownPayloadDecoderPeer { + public: + static Http2FrameType FrameType() { return g_unknown_frame_type; } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnUnknownStart(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnUnknownStart: " << header; + StartFrame(header)->OnUnknownStart(header); + } + + void OnUnknownPayload(const char* data, size_t len) override { + QUICHE_VLOG(1) << "OnUnknownPayload: len=" << len; + CurrentFrame()->OnUnknownPayload(data, len); + } + + void OnUnknownEnd() override { + QUICHE_VLOG(1) << "OnUnknownEnd"; + EndFrame()->OnUnknownEnd(); + } +}; + +constexpr bool SupportedFrameType = false; + +class UnknownPayloadDecoderTest + : public AbstractPayloadDecoderTest, + public ::testing::WithParamInterface { + protected: + UnknownPayloadDecoderTest() : length_(GetParam()) { + QUICHE_VLOG(1) << "################ length_=" << length_ + << " ################"; + + // Each test case will choose a random frame type that isn't supported. + do { + g_unknown_frame_type = static_cast(Random().Rand8()); + } while (IsSupportedHttp2FrameType(g_unknown_frame_type)); + } + + const uint32_t length_; +}; + +INSTANTIATE_TEST_SUITE_P(VariousLengths, UnknownPayloadDecoderTest, + ::testing::Values(0, 1, 2, 3, 255, 256)); + +TEST_P(UnknownPayloadDecoderTest, ValidLength) { + std::string unknown_payload = Random().RandString(length_); + Http2FrameHeader frame_header(length_, g_unknown_frame_type, Random().Rand8(), + RandStreamId()); + set_frame_header(frame_header); + FrameParts expected(frame_header, unknown_payload); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(unknown_payload, expected)); + // TODO(jamessynge): Check here (and in other such tests) that the fast + // and slow decode counts are both non-zero. Perhaps also add some kind of + // test for the listener having been called. That could simply be a test + // that there is a single collected FrameParts instance, and that it matches + // expected. +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.cc b/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.cc new file mode 100644 index 000000000000..4174df8fd46b --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.cc @@ -0,0 +1,80 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h" + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_http2_structures.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus WindowUpdatePayloadDecoder::StartDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + const Http2FrameHeader& frame_header = state->frame_header(); + const uint32_t total_length = frame_header.payload_length; + + QUICHE_DVLOG(2) << "WindowUpdatePayloadDecoder::StartDecodingPayload: " + << frame_header; + + QUICHE_DCHECK_EQ(Http2FrameType::WINDOW_UPDATE, frame_header.type); + QUICHE_DCHECK_LE(db->Remaining(), total_length); + + // WINDOW_UPDATE frames have no flags. + QUICHE_DCHECK_EQ(0, frame_header.flags); + + // Special case for when the payload is the correct size and entirely in + // the buffer. + if (db->Remaining() == Http2WindowUpdateFields::EncodedSize() && + total_length == Http2WindowUpdateFields::EncodedSize()) { + DoDecode(&window_update_fields_, db); + state->listener()->OnWindowUpdate( + frame_header, window_update_fields_.window_size_increment); + return DecodeStatus::kDecodeDone; + } + state->InitializeRemainders(); + return HandleStatus(state, state->StartDecodingStructureInPayload( + &window_update_fields_, db)); +} + +DecodeStatus WindowUpdatePayloadDecoder::ResumeDecodingPayload( + FrameDecoderState* state, DecodeBuffer* db) { + QUICHE_DVLOG(2) << "ResumeDecodingPayload: remaining_payload=" + << state->remaining_payload() + << "; db->Remaining=" << db->Remaining(); + QUICHE_DCHECK_EQ(Http2FrameType::WINDOW_UPDATE, state->frame_header().type); + QUICHE_DCHECK_LE(db->Remaining(), state->frame_header().payload_length); + return HandleStatus(state, state->ResumeDecodingStructureInPayload( + &window_update_fields_, db)); +} + +DecodeStatus WindowUpdatePayloadDecoder::HandleStatus(FrameDecoderState* state, + DecodeStatus status) { + QUICHE_DVLOG(2) << "HandleStatus: status=" << status + << "; remaining_payload=" << state->remaining_payload(); + if (status == DecodeStatus::kDecodeDone) { + if (state->remaining_payload() == 0) { + state->listener()->OnWindowUpdate( + state->frame_header(), window_update_fields_.window_size_increment); + return DecodeStatus::kDecodeDone; + } + // Payload is too long. + return state->ReportFrameSizeError(); + } + // Not done decoding the structure. Either we've got more payload to decode, + // or we've run out because the payload is too short, in which case + // OnFrameSizeError will have already been called. + QUICHE_DCHECK( + (status == DecodeStatus::kDecodeInProgress && + state->remaining_payload() > 0) || + (status == DecodeStatus::kDecodeError && state->remaining_payload() == 0)) + << "\n status=" << status + << "; remaining_payload=" << state->remaining_payload(); + return status; +} + +} // namespace http2 diff --git a/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h b/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h new file mode 100644 index 000000000000..659a7151701c --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h @@ -0,0 +1,42 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_WINDOW_UPDATE_PAYLOAD_DECODER_H_ +#define QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_WINDOW_UPDATE_PAYLOAD_DECODER_H_ + +// Decodes the payload of a WINDOW_UPDATE frame. + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class WindowUpdatePayloadDecoderPeer; +} // namespace test + +class QUICHE_EXPORT WindowUpdatePayloadDecoder { + public: + // Starts decoding a WINDOW_UPDATE frame's payload, and completes it if + // the entire payload is in the provided decode buffer. + DecodeStatus StartDecodingPayload(FrameDecoderState* state, DecodeBuffer* db); + + // Resumes decoding a WINDOW_UPDATE frame's payload that has been split across + // decode buffers. + DecodeStatus ResumeDecodingPayload(FrameDecoderState* state, + DecodeBuffer* db); + + private: + friend class test::WindowUpdatePayloadDecoderPeer; + + DecodeStatus HandleStatus(FrameDecoderState* state, DecodeStatus status); + + Http2WindowUpdateFields window_update_fields_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_DECODER_PAYLOAD_DECODERS_WINDOW_UPDATE_PAYLOAD_DECODER_H_ diff --git a/quiche/http2/decoder/payload_decoders/window_update_payload_decoder_test.cc b/quiche/http2/decoder/payload_decoders/window_update_payload_decoder_test.cc new file mode 100644 index 000000000000..6dd82059e0e1 --- /dev/null +++ b/quiche/http2/decoder/payload_decoders/window_update_payload_decoder_test.cc @@ -0,0 +1,95 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/decoder/payload_decoders/window_update_payload_decoder.h" + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class WindowUpdatePayloadDecoderPeer { + public: + static constexpr Http2FrameType FrameType() { + return Http2FrameType::WINDOW_UPDATE; + } + + // Returns the mask of flags that affect the decoding of the payload (i.e. + // flags that that indicate the presence of certain fields or padding). + static constexpr uint8_t FlagsAffectingPayloadDecoding() { return 0; } +}; + +namespace { + +struct Listener : public FramePartsCollector { + void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t window_size_increment) override { + QUICHE_VLOG(1) << "OnWindowUpdate: " << header + << "; window_size_increment=" << window_size_increment; + EXPECT_EQ(Http2FrameType::WINDOW_UPDATE, header.type); + StartAndEndFrame(header)->OnWindowUpdate(header, window_size_increment); + } + + void OnFrameSizeError(const Http2FrameHeader& header) override { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); + } +}; + +class WindowUpdatePayloadDecoderTest + : public AbstractPayloadDecoderTest { + protected: + Http2WindowUpdateFields RandWindowUpdateFields() { + Http2WindowUpdateFields fields; + test::Randomize(&fields, RandomPtr()); + QUICHE_VLOG(3) << "RandWindowUpdateFields: " << fields; + return fields; + } +}; + +// Confirm we get an error if the payload is not the correct size to hold +// exactly one Http2WindowUpdateFields. +TEST_F(WindowUpdatePayloadDecoderTest, WrongSize) { + auto approve_size = [](size_t size) { + return size != Http2WindowUpdateFields::EncodedSize(); + }; + Http2FrameBuilder fb; + fb.Append(RandWindowUpdateFields()); + fb.Append(RandWindowUpdateFields()); + fb.Append(RandWindowUpdateFields()); + EXPECT_TRUE(VerifyDetectsFrameSizeError(0, fb.buffer(), approve_size)); +} + +TEST_F(WindowUpdatePayloadDecoderTest, VariousPayloads) { + for (int n = 0; n < 100; ++n) { + uint32_t stream_id = n == 0 ? 0 : RandStreamId(); + Http2WindowUpdateFields fields = RandWindowUpdateFields(); + Http2FrameBuilder fb; + fb.Append(fields); + Http2FrameHeader header(fb.size(), Http2FrameType::WINDOW_UPDATE, + RandFlags(), stream_id); + set_frame_header(header); + FrameParts expected(header); + expected.SetOptWindowUpdateIncrement(fields.window_size_increment); + EXPECT_TRUE(DecodePayloadAndValidateSeveralWays(fb.buffer(), expected)); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_block_collector_test.cc b/quiche/http2/hpack/decoder/hpack_block_collector_test.cc new file mode 100644 index 000000000000..5edfe1c02020 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_block_collector_test.cc @@ -0,0 +1,121 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_block_collector.h" + +// Tests of HpackBlockCollector. Not intended to be comprehensive, as +// HpackBlockCollector is itself support for testing HpackBlockDecoder, and +// should be pretty thoroughly exercised via the tests of HpackBlockDecoder. + +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +TEST(HpackBlockCollectorTest, Clear) { + HpackBlockCollector collector; + EXPECT_TRUE(collector.IsClear()); + EXPECT_TRUE(collector.IsNotPending()); + + collector.OnIndexedHeader(234); + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsNotPending()); + + collector.Clear(); + EXPECT_TRUE(collector.IsClear()); + EXPECT_TRUE(collector.IsNotPending()); + + collector.OnDynamicTableSizeUpdate(0); + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsNotPending()); + + collector.Clear(); + collector.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 1); + EXPECT_FALSE(collector.IsClear()); + EXPECT_FALSE(collector.IsNotPending()); +} + +TEST(HpackBlockCollectorTest, IndexedHeader) { + HpackBlockCollector a; + a.OnIndexedHeader(123); + EXPECT_TRUE(a.ValidateSoleIndexedHeader(123)); + + HpackBlockCollector b; + EXPECT_FALSE(a.VerifyEq(b)); + + b.OnIndexedHeader(1); + EXPECT_TRUE(b.ValidateSoleIndexedHeader(1)); + EXPECT_FALSE(a.VerifyEq(b)); + + b.Clear(); + b.OnIndexedHeader(123); + EXPECT_TRUE(a.VerifyEq(b)); + + b.OnIndexedHeader(234); + EXPECT_FALSE(b.VerifyEq(a)); + a.OnIndexedHeader(234); + EXPECT_TRUE(b.VerifyEq(a)); + + std::string expected; + { + HpackBlockBuilder hbb; + hbb.AppendIndexedHeader(123); + hbb.AppendIndexedHeader(234); + EXPECT_EQ(3u, hbb.size()); + expected = hbb.buffer(); + } + std::string actual; + { + HpackBlockBuilder hbb; + a.AppendToHpackBlockBuilder(&hbb); + EXPECT_EQ(3u, hbb.size()); + actual = hbb.buffer(); + } + EXPECT_EQ(expected, actual); +} + +TEST(HpackBlockCollectorTest, DynamicTableSizeUpdate) { + HpackBlockCollector a; + a.OnDynamicTableSizeUpdate(0); + EXPECT_TRUE(a.ValidateSoleDynamicTableSizeUpdate(0)); + + HpackBlockCollector b; + EXPECT_FALSE(a.VerifyEq(b)); + + b.OnDynamicTableSizeUpdate(1); + EXPECT_TRUE(b.ValidateSoleDynamicTableSizeUpdate(1)); + EXPECT_FALSE(a.VerifyEq(b)); + + b.Clear(); + b.OnDynamicTableSizeUpdate(0); + EXPECT_TRUE(a.VerifyEq(b)); + + b.OnDynamicTableSizeUpdate(4096); + EXPECT_FALSE(b.VerifyEq(a)); + a.OnDynamicTableSizeUpdate(4096); + EXPECT_TRUE(b.VerifyEq(a)); + + std::string expected; + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(0); + hbb.AppendDynamicTableSizeUpdate(4096); + EXPECT_EQ(4u, hbb.size()); + expected = hbb.buffer(); + } + std::string actual; + { + HpackBlockBuilder hbb; + a.AppendToHpackBlockBuilder(&hbb); + EXPECT_EQ(4u, hbb.size()); + actual = hbb.buffer(); + } + EXPECT_EQ(expected, actual); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_block_decoder.cc b/quiche/http2/hpack/decoder/hpack_block_decoder.cc new file mode 100644 index 000000000000..f47c806de6dd --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_block_decoder.cc @@ -0,0 +1,65 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_block_decoder.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +DecodeStatus HpackBlockDecoder::Decode(DecodeBuffer* db) { + if (!before_entry_) { + QUICHE_DVLOG(2) << "HpackBlockDecoder::Decode resume entry, db->Remaining=" + << db->Remaining(); + DecodeStatus status = entry_decoder_.Resume(db, listener_); + switch (status) { + case DecodeStatus::kDecodeDone: + before_entry_ = true; + break; + case DecodeStatus::kDecodeInProgress: + QUICHE_DCHECK_EQ(0u, db->Remaining()); + return DecodeStatus::kDecodeInProgress; + case DecodeStatus::kDecodeError: + QUICHE_CODE_COUNT_N(decompress_failure_3, 1, 23); + return DecodeStatus::kDecodeError; + } + } + QUICHE_DCHECK(before_entry_); + while (db->HasData()) { + QUICHE_DVLOG(2) << "HpackBlockDecoder::Decode start entry, db->Remaining=" + << db->Remaining(); + DecodeStatus status = entry_decoder_.Start(db, listener_); + switch (status) { + case DecodeStatus::kDecodeDone: + continue; + case DecodeStatus::kDecodeInProgress: + QUICHE_DCHECK_EQ(0u, db->Remaining()); + before_entry_ = false; + return DecodeStatus::kDecodeInProgress; + case DecodeStatus::kDecodeError: + QUICHE_CODE_COUNT_N(decompress_failure_3, 2, 23); + return DecodeStatus::kDecodeError; + } + QUICHE_DCHECK(false); + } + QUICHE_DCHECK(before_entry_); + return DecodeStatus::kDecodeDone; +} + +std::string HpackBlockDecoder::DebugString() const { + return absl::StrCat( + "HpackBlockDecoder(", entry_decoder_.DebugString(), ", listener@", + absl::Hex(reinterpret_cast(listener_)), + (before_entry_ ? ", between entries)" : ", in an entry)")); +} + +std::ostream& operator<<(std::ostream& out, const HpackBlockDecoder& v) { + return out << v.DebugString(); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_block_decoder.h b/quiche/http2/hpack/decoder/hpack_block_decoder.h new file mode 100644 index 000000000000..4eeb6d3a309a --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_block_decoder.h @@ -0,0 +1,69 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_BLOCK_DECODER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_BLOCK_DECODER_H_ + +// HpackBlockDecoder decodes an entire HPACK block (or the available portion +// thereof in the DecodeBuffer) into entries, but doesn't include HPACK static +// or dynamic table support, so table indices remain indices at this level. +// Reports the entries to an HpackEntryDecoderListener. + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" +#include "quiche/http2/hpack/decoder/hpack_entry_decoder.h" +#include "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +class QUICHE_EXPORT HpackBlockDecoder { + public: + explicit HpackBlockDecoder(HpackEntryDecoderListener* listener) + : listener_(listener) { + QUICHE_DCHECK_NE(listener_, nullptr); + } + ~HpackBlockDecoder() {} + + HpackBlockDecoder(const HpackBlockDecoder&) = delete; + HpackBlockDecoder& operator=(const HpackBlockDecoder&) = delete; + + // Prepares the decoder to start decoding a new HPACK block. Expected + // to be called from an implementation of Http2FrameDecoderListener's + // OnHeadersStart or OnPushPromiseStart methods. + void Reset() { + QUICHE_DVLOG(2) << "HpackBlockDecoder::Reset"; + before_entry_ = true; + } + + // Decode the fragment of the HPACK block contained in the decode buffer. + // Expected to be called from an implementation of Http2FrameDecoderListener's + // OnHpackFragment method. + DecodeStatus Decode(DecodeBuffer* db); + + // Is the decoding process between entries (i.e. would the next byte be the + // first byte of a new HPACK entry)? + bool before_entry() const { return before_entry_; } + + // Return error code after decoding error occurred in HpackEntryDecoder. + HpackDecodingError error() const { return entry_decoder_.error(); } + + std::string DebugString() const; + + private: + HpackEntryDecoder entry_decoder_; + HpackEntryDecoderListener* const listener_; + bool before_entry_ = true; +}; + +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackBlockDecoder& v); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_BLOCK_DECODER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_block_decoder_test.cc b/quiche/http2/hpack/decoder/hpack_block_decoder_test.cc new file mode 100644 index 000000000000..064ae44041a4 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_block_decoder_test.cc @@ -0,0 +1,290 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_block_decoder.h" + +// Tests of HpackBlockDecoder. + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/hpack_block_collector.h" +#include "quiche/http2/test_tools/hpack_example.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +class HpackBlockDecoderTest : public RandomDecoderTest { + protected: + HpackBlockDecoderTest() : listener_(&collector_), decoder_(&listener_) { + stop_decode_on_done_ = false; + decoder_.Reset(); + // Make sure logging doesn't crash. Not examining the result. + std::ostringstream strm; + strm << decoder_; + } + + DecodeStatus StartDecoding(DecodeBuffer* db) override { + collector_.Clear(); + decoder_.Reset(); + return ResumeDecoding(db); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* db) override { + DecodeStatus status = decoder_.Decode(db); + + // Make sure logging doesn't crash. Not examining the result. + std::ostringstream strm; + strm << decoder_; + + return status; + } + + AssertionResult DecodeAndValidateSeveralWays(DecodeBuffer* db, + const Validator& validator) { + bool return_non_zero_on_first = false; + return RandomDecoderTest::DecodeAndValidateSeveralWays( + db, return_non_zero_on_first, validator); + } + + AssertionResult DecodeAndValidateSeveralWays(const HpackBlockBuilder& hbb, + const Validator& validator) { + DecodeBuffer db(hbb.buffer()); + return DecodeAndValidateSeveralWays(&db, validator); + } + + AssertionResult DecodeHpackExampleAndValidateSeveralWays( + absl::string_view hpack_example, Validator validator) { + std::string input = HpackExampleToStringOrDie(hpack_example); + DecodeBuffer db(input); + return DecodeAndValidateSeveralWays(&db, validator); + } + + uint8_t Rand8() { return Random().Rand8(); } + + std::string Rand8String() { return Random().RandString(Rand8()); } + + HpackBlockCollector collector_; + HpackEntryDecoderVLoggingListener listener_; + HpackBlockDecoder decoder_; +}; + +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.2.1 +TEST_F(HpackBlockDecoderTest, SpecExample_C_2_1) { + auto do_check = [this]() { + return collector_.ValidateSoleLiteralNameValueHeader( + HpackEntryType::kIndexedLiteralHeader, false, "custom-key", false, + "custom-header"); + }; + const char hpack_example[] = R"( + 40 | == Literal indexed == + 0a | Literal name (len = 10) + 6375 7374 6f6d 2d6b 6579 | custom-key + 0d | Literal value (len = 13) + 6375 7374 6f6d 2d68 6561 6465 72 | custom-header + | -> custom-key: + | custom-header + )"; + EXPECT_TRUE(DecodeHpackExampleAndValidateSeveralWays( + hpack_example, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.2.2 +TEST_F(HpackBlockDecoderTest, SpecExample_C_2_2) { + auto do_check = [this]() { + return collector_.ValidateSoleLiteralValueHeader( + HpackEntryType::kUnindexedLiteralHeader, 4, false, "/sample/path"); + }; + const char hpack_example[] = R"( + 04 | == Literal not indexed == + | Indexed name (idx = 4) + | :path + 0c | Literal value (len = 12) + 2f73 616d 706c 652f 7061 7468 | /sample/path + | -> :path: /sample/path + )"; + EXPECT_TRUE(DecodeHpackExampleAndValidateSeveralWays( + hpack_example, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.2.3 +TEST_F(HpackBlockDecoderTest, SpecExample_C_2_3) { + auto do_check = [this]() { + return collector_.ValidateSoleLiteralNameValueHeader( + HpackEntryType::kNeverIndexedLiteralHeader, false, "password", false, + "secret"); + }; + const char hpack_example[] = R"( + 10 | == Literal never indexed == + 08 | Literal name (len = 8) + 7061 7373 776f 7264 | password + 06 | Literal value (len = 6) + 7365 6372 6574 | secret + | -> password: secret + )"; + EXPECT_TRUE(DecodeHpackExampleAndValidateSeveralWays( + hpack_example, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.2.4 +TEST_F(HpackBlockDecoderTest, SpecExample_C_2_4) { + auto do_check = [this]() { return collector_.ValidateSoleIndexedHeader(2); }; + const char hpack_example[] = R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + )"; + EXPECT_TRUE(DecodeHpackExampleAndValidateSeveralWays( + hpack_example, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.3.1 +TEST_F(HpackBlockDecoderTest, SpecExample_C_3_1) { + std::string example = R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 86 | == Indexed - Add == + | idx = 6 + | -> :scheme: http + 84 | == Indexed - Add == + | idx = 4 + | -> :path: / + 41 | == Literal indexed == + | Indexed name (idx = 1) + | :authority + 0f | Literal value (len = 15) + 7777 772e 6578 616d 706c 652e 636f 6d | www.example.com + | -> :authority: + | www.example.com + )"; + HpackBlockCollector expected; + expected.ExpectIndexedHeader(2); + expected.ExpectIndexedHeader(6); + expected.ExpectIndexedHeader(4); + expected.ExpectNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, + 1, false, "www.example.com"); + NoArgValidator do_check = [expected, this]() { + return collector_.VerifyEq(expected); + }; + EXPECT_TRUE(DecodeHpackExampleAndValidateSeveralWays( + example, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.5.1 +TEST_F(HpackBlockDecoderTest, SpecExample_C_5_1) { + std::string example = R"( + 48 | == Literal indexed == + | Indexed name (idx = 8) + | :status + 03 | Literal value (len = 3) + 3330 32 | 302 + | -> :status: 302 + 58 | == Literal indexed == + | Indexed name (idx = 24) + | cache-control + 07 | Literal value (len = 7) + 7072 6976 6174 65 | private + | -> cache-control: private + 61 | == Literal indexed == + | Indexed name (idx = 33) + | date + 1d | Literal value (len = 29) + 4d6f 6e2c 2032 3120 4f63 7420 3230 3133 | Mon, 21 Oct 2013 + 2032 303a 3133 3a32 3120 474d 54 | 20:13:21 GMT + | -> date: Mon, 21 Oct 2013 + | 20:13:21 GMT + 6e | == Literal indexed == + | Indexed name (idx = 46) + | location + 17 | Literal value (len = 23) + 6874 7470 733a 2f2f 7777 772e 6578 616d | https://www.exam + 706c 652e 636f 6d | ple.com + | -> location: + | https://www.example.com + )"; + HpackBlockCollector expected; + expected.ExpectNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, + 8, false, "302"); + expected.ExpectNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, + 24, false, "private"); + expected.ExpectNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, + 33, false, + "Mon, 21 Oct 2013 20:13:21 GMT"); + expected.ExpectNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, + 46, false, "https://www.example.com"); + NoArgValidator do_check = [expected, this]() { + return collector_.VerifyEq(expected); + }; + EXPECT_TRUE(DecodeHpackExampleAndValidateSeveralWays( + example, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +// Generate a bunch of HPACK block entries to expect, use those expectations +// to generate an HPACK block, then decode it and confirm it matches those +// expectations. Some of these are invalid (such as Indexed, with index=0), +// but well-formed, and the decoder doesn't check for validity, just +// well-formedness. That includes the validity of the strings not being checked, +// such as lower-case ascii for the names, and valid Huffman encodings. +TEST_F(HpackBlockDecoderTest, Computed) { + HpackBlockCollector expected; + expected.ExpectIndexedHeader(0); + expected.ExpectIndexedHeader(1); + expected.ExpectIndexedHeader(126); + expected.ExpectIndexedHeader(127); + expected.ExpectIndexedHeader(128); + expected.ExpectDynamicTableSizeUpdate(0); + expected.ExpectDynamicTableSizeUpdate(1); + expected.ExpectDynamicTableSizeUpdate(14); + expected.ExpectDynamicTableSizeUpdate(15); + expected.ExpectDynamicTableSizeUpdate(30); + expected.ExpectDynamicTableSizeUpdate(31); + expected.ExpectDynamicTableSizeUpdate(4095); + expected.ExpectDynamicTableSizeUpdate(4096); + expected.ExpectDynamicTableSizeUpdate(8192); + for (auto type : {HpackEntryType::kIndexedLiteralHeader, + HpackEntryType::kUnindexedLiteralHeader, + HpackEntryType::kNeverIndexedLiteralHeader}) { + for (bool value_huffman : {false, true}) { + // An entry with an index for the name. Ensure the name index + // is not zero by adding one to the Rand8() result. + expected.ExpectNameIndexAndLiteralValue(type, Rand8() + 1, value_huffman, + Rand8String()); + // And two entries with literal names, one plain, one huffman encoded. + expected.ExpectLiteralNameAndValue(type, false, Rand8String(), + value_huffman, Rand8String()); + expected.ExpectLiteralNameAndValue(type, true, Rand8String(), + value_huffman, Rand8String()); + } + } + // Shuffle the entries and serialize them to produce an HPACK block. + expected.ShuffleEntries(RandomPtr()); + HpackBlockBuilder hbb; + expected.AppendToHpackBlockBuilder(&hbb); + + NoArgValidator do_check = [expected, this]() { + return collector_.VerifyEq(expected); + }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(hbb, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder.cc b/quiche/http2/hpack/decoder/hpack_decoder.cc new file mode 100644 index 000000000000..935cd7f29672 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder.cc @@ -0,0 +1,124 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder.h" + +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +HpackDecoder::HpackDecoder(HpackDecoderListener* listener, + size_t max_string_size) + : decoder_state_(listener), + entry_buffer_(&decoder_state_, max_string_size), + block_decoder_(&entry_buffer_), + error_(HpackDecodingError::kOk) {} + +HpackDecoder::~HpackDecoder() = default; + +void HpackDecoder::set_max_string_size_bytes(size_t max_string_size_bytes) { + entry_buffer_.set_max_string_size_bytes(max_string_size_bytes); +} + +void HpackDecoder::ApplyHeaderTableSizeSetting(uint32_t max_header_table_size) { + decoder_state_.ApplyHeaderTableSizeSetting(max_header_table_size); +} + +bool HpackDecoder::StartDecodingBlock() { + QUICHE_DVLOG(3) << "HpackDecoder::StartDecodingBlock, error_detected=" + << (DetectError() ? "true" : "false"); + if (DetectError()) { + return false; + } + // TODO(jamessynge): Eliminate Reset(), which shouldn't be necessary + // if there are no errors, and shouldn't be necessary with errors if + // we never resume decoding after an error has been detected. + block_decoder_.Reset(); + decoder_state_.OnHeaderBlockStart(); + return true; +} + +bool HpackDecoder::DecodeFragment(DecodeBuffer* db) { + QUICHE_DVLOG(3) << "HpackDecoder::DecodeFragment, error_detected=" + << (DetectError() ? "true" : "false") + << ", size=" << db->Remaining(); + if (DetectError()) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 3, 23); + return false; + } + // Decode contents of db as an HPACK block fragment, forwards the decoded + // entries to entry_buffer_, which in turn forwards them to decode_state_, + // which finally forwards them to the HpackDecoderListener. + DecodeStatus status = block_decoder_.Decode(db); + if (status == DecodeStatus::kDecodeError) { + ReportError(block_decoder_.error(), ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 4, 23); + return false; + } else if (DetectError()) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 5, 23); + return false; + } + // Should be positioned between entries iff decoding is complete. + QUICHE_DCHECK_EQ(block_decoder_.before_entry(), + status == DecodeStatus::kDecodeDone) + << status; + if (!block_decoder_.before_entry()) { + entry_buffer_.BufferStringsIfUnbuffered(); + } + return true; +} + +bool HpackDecoder::EndDecodingBlock() { + QUICHE_DVLOG(3) << "HpackDecoder::EndDecodingBlock, error_detected=" + << (DetectError() ? "true" : "false"); + if (DetectError()) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 6, 23); + return false; + } + if (!block_decoder_.before_entry()) { + // The HPACK block ended in the middle of an entry. + ReportError(HpackDecodingError::kTruncatedBlock, ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 7, 23); + return false; + } + decoder_state_.OnHeaderBlockEnd(); + if (DetectError()) { + // HpackDecoderState will have reported the error. + QUICHE_CODE_COUNT_N(decompress_failure_3, 8, 23); + return false; + } + return true; +} + +bool HpackDecoder::DetectError() { + if (error_ != HpackDecodingError::kOk) { + return true; + } + + if (decoder_state_.error() != HpackDecodingError::kOk) { + QUICHE_DVLOG(2) << "Error detected in decoder_state_"; + QUICHE_CODE_COUNT_N(decompress_failure_3, 10, 23); + error_ = decoder_state_.error(); + detailed_error_ = decoder_state_.detailed_error(); + } + + return error_ != HpackDecodingError::kOk; +} + +void HpackDecoder::ReportError(HpackDecodingError error, + std::string detailed_error) { + QUICHE_DVLOG(3) << "HpackDecoder::ReportError is new=" + << (error_ == HpackDecodingError::kOk ? "true" : "false") + << ", error: " << HpackDecodingErrorToString(error); + if (error_ == HpackDecodingError::kOk) { + error_ = error; + detailed_error_ = detailed_error; + decoder_state_.listener()->OnHeaderErrorDetected( + HpackDecodingErrorToString(error)); + } +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder.h b/quiche/http2/hpack/decoder/hpack_decoder.h new file mode 100644 index 000000000000..9e4b68bc927d --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder.h @@ -0,0 +1,132 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_H_ + +// Decodes HPACK blocks, calls an HpackDecoderListener with the decoded header +// entries. Also notifies the listener of errors and of the boundaries of the +// HPACK blocks. + +// TODO(jamessynge): Add feature allowing an HpackEntryDecoderListener +// sub-class (and possibly others) to be passed in for counting events, +// so that deciding whether to count is not done by having lots of if +// statements, but instead by inserting an indirection only when needed. + +// TODO(jamessynge): Consider whether to return false from methods below +// when an error has been previously detected. It protects calling code +// from its failure to pay attention to previous errors, but should we +// spend time to do that? + +#include + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/hpack/decoder/hpack_block_decoder.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_listener.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_state.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" +#include "quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class HpackDecoderPeer; +} // namespace test + +class QUICHE_EXPORT HpackDecoder { + public: + HpackDecoder(HpackDecoderListener* listener, size_t max_string_size); + virtual ~HpackDecoder(); + + HpackDecoder(const HpackDecoder&) = delete; + HpackDecoder& operator=(const HpackDecoder&) = delete; + + // max_string_size specifies the maximum size of an on-the-wire string (name + // or value, plain or Huffman encoded) that will be accepted. See sections + // 5.1 and 5.2 of RFC 7541. This is a defense against OOM attacks; HTTP/2 + // allows a decoder to enforce any limit of the size of the header lists + // that it is willing to decode, including less than the MAX_HEADER_LIST_SIZE + // setting, a setting that is initially unlimited. For example, we might + // choose to send a MAX_HEADER_LIST_SIZE of 64KB, and to use that same value + // as the upper bound for individual strings. + void set_max_string_size_bytes(size_t max_string_size_bytes); + + // ApplyHeaderTableSizeSetting notifies this object that this endpoint has + // received a SETTINGS ACK frame acknowledging an earlier SETTINGS frame from + // this endpoint specifying a new value for SETTINGS_HEADER_TABLE_SIZE (the + // maximum size of the dynamic table that this endpoint will use to decode + // HPACK blocks). + // Because a SETTINGS frame can contain SETTINGS_HEADER_TABLE_SIZE values, + // the caller must keep track of those multiple changes, and make + // corresponding calls to this method. In particular, a call must be made + // with the lowest value acknowledged by the peer, and a call must be made + // with the final value acknowledged, in that order; additional calls may + // be made if additional values were sent. These calls must be made between + // decoding the SETTINGS ACK, and before the next HPACK block is decoded. + void ApplyHeaderTableSizeSetting(uint32_t max_header_table_size); + + // Returns the most recently applied value of SETTINGS_HEADER_TABLE_SIZE. + size_t GetCurrentHeaderTableSizeSetting() const { + return decoder_state_.GetCurrentHeaderTableSizeSetting(); + } + + // Prepares the decoder for decoding a new HPACK block, and announces this to + // its listener. Returns true if OK to continue with decoding, false if an + // error has been detected, which for StartDecodingBlock means the error was + // detected while decoding a previous HPACK block. + bool StartDecodingBlock(); + + // Decodes a fragment (some or all of the remainder) of an HPACK block, + // reporting header entries (name & value pairs) that it completely decodes + // in the process to the listener. Returns true successfully decoded, false if + // an error has been detected, either during decoding of the fragment, or + // prior to this call. + bool DecodeFragment(DecodeBuffer* db); + + // Completes the process of decoding an HPACK block: if the HPACK block was + // properly terminated, announces the end of the header list to the listener + // and returns true; else returns false. + bool EndDecodingBlock(); + + // If no error has been detected so far, query |decoder_state_| for errors and + // set |error_| if necessary. Returns true if an error has ever been + // detected. + bool DetectError(); + + size_t GetDynamicTableSize() const { + return decoder_state_.GetDynamicTableSize(); + } + + // Error code if an error has occurred, HpackDecodingError::kOk otherwise. + HpackDecodingError error() const { return error_; } + + std::string detailed_error() const { return detailed_error_; } + + private: + friend class test::HpackDecoderPeer; + + // Reports an error to the listener IF this is the first error detected. + void ReportError(HpackDecodingError error, std::string detailed_error); + + // The decompressor state, as defined by HPACK (i.e. the static and dynamic + // tables). + HpackDecoderState decoder_state_; + + // Assembles the various parts of a header entry into whole entries. + HpackWholeEntryBuffer entry_buffer_; + + // The decoder of HPACK blocks into entry parts, passed to entry_buffer_. + HpackBlockDecoder block_decoder_; + + // Error code if an error has occurred, HpackDecodingError::kOk otherwise. + HpackDecodingError error_; + std::string detailed_error_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_decoder_listener.cc b/quiche/http2/hpack/decoder/hpack_decoder_listener.cc new file mode 100644 index 000000000000..75a59695c27c --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_listener.cc @@ -0,0 +1,29 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_listener.h" + +namespace http2 { + +HpackDecoderListener::HpackDecoderListener() = default; +HpackDecoderListener::~HpackDecoderListener() = default; + +HpackDecoderNoOpListener::HpackDecoderNoOpListener() = default; +HpackDecoderNoOpListener::~HpackDecoderNoOpListener() = default; + +void HpackDecoderNoOpListener::OnHeaderListStart() {} +void HpackDecoderNoOpListener::OnHeader(const std::string& /*name*/, + const std::string& /*value*/) {} +void HpackDecoderNoOpListener::OnHeaderListEnd() {} +void HpackDecoderNoOpListener::OnHeaderErrorDetected( + absl::string_view /*error_message*/) {} + +// static +HpackDecoderNoOpListener* HpackDecoderNoOpListener::NoOpListener() { + static HpackDecoderNoOpListener* static_instance = + new HpackDecoderNoOpListener(); + return static_instance; +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_listener.h b/quiche/http2/hpack/decoder/hpack_decoder_listener.h new file mode 100644 index 000000000000..37564a4bca35 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_listener.h @@ -0,0 +1,60 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Defines HpackDecoderListener, the base class of listeners for HTTP header +// lists decoded from an HPACK block. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_LISTENER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_LISTENER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +class QUICHE_EXPORT HpackDecoderListener { + public: + HpackDecoderListener(); + virtual ~HpackDecoderListener(); + + // OnHeaderListStart is called at the start of decoding an HPACK block into + // an HTTP/2 header list. Will only be called once per block, even if it + // extends into CONTINUATION frames. + virtual void OnHeaderListStart() = 0; + + // Called for each header name-value pair that is decoded, in the order they + // appear in the HPACK block. Multiple values for a given key will be emitted + // as multiple calls to OnHeader. + virtual void OnHeader(const std::string& name, const std::string& value) = 0; + + // OnHeaderListEnd is called after successfully decoding an HPACK block into + // an HTTP/2 header list. Will only be called once per block, even if it + // extends into CONTINUATION frames. + virtual void OnHeaderListEnd() = 0; + + // OnHeaderErrorDetected is called if an error is detected while decoding. + // error_message may be used in a GOAWAY frame as the Opaque Data. + virtual void OnHeaderErrorDetected(absl::string_view error_message) = 0; +}; + +// A no-op implementation of HpackDecoderListener, useful for ignoring +// callbacks once an error is detected. +class QUICHE_EXPORT HpackDecoderNoOpListener : public HpackDecoderListener { + public: + HpackDecoderNoOpListener(); + ~HpackDecoderNoOpListener() override; + + void OnHeaderListStart() override; + void OnHeader(const std::string& name, const std::string& value) override; + void OnHeaderListEnd() override; + void OnHeaderErrorDetected(absl::string_view error_message) override; + + // Returns a listener that ignores all the calls. + static HpackDecoderNoOpListener* NoOpListener(); +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_LISTENER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_decoder_state.cc b/quiche/http2/hpack/decoder/hpack_decoder_state.cc new file mode 100644 index 000000000000..e34be72b20b4 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_state.cc @@ -0,0 +1,223 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_state.h" + +#include + +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace { + +std::string ExtractString(HpackDecoderStringBuffer* string_buffer) { + if (string_buffer->IsBuffered()) { + return string_buffer->ReleaseString(); + } else { + auto result = std::string(string_buffer->str()); + string_buffer->Reset(); + return result; + } +} + +} // namespace + +HpackDecoderState::HpackDecoderState(HpackDecoderListener* listener) + : listener_(listener), + final_header_table_size_(Http2SettingsInfo::DefaultHeaderTableSize()), + lowest_header_table_size_(final_header_table_size_), + require_dynamic_table_size_update_(false), + allow_dynamic_table_size_update_(true), + saw_dynamic_table_size_update_(false), + error_(HpackDecodingError::kOk) { + QUICHE_CHECK(listener_); +} + +HpackDecoderState::~HpackDecoderState() = default; + +void HpackDecoderState::ApplyHeaderTableSizeSetting( + uint32_t header_table_size) { + QUICHE_DVLOG(2) << "HpackDecoderState::ApplyHeaderTableSizeSetting(" + << header_table_size << ")"; + QUICHE_DCHECK_LE(lowest_header_table_size_, final_header_table_size_); + if (header_table_size < lowest_header_table_size_) { + lowest_header_table_size_ = header_table_size; + } + final_header_table_size_ = header_table_size; + QUICHE_DVLOG(2) << "low water mark: " << lowest_header_table_size_; + QUICHE_DVLOG(2) << "final limit: " << final_header_table_size_; +} + +// Called to notify this object that we're starting to decode an HPACK block +// (e.g. a HEADERS or PUSH_PROMISE frame's header has been decoded). +void HpackDecoderState::OnHeaderBlockStart() { + QUICHE_DVLOG(2) << "HpackDecoderState::OnHeaderBlockStart"; + // This instance can't be reused after an error has been detected, as we must + // assume that the encoder and decoder compression states are no longer + // synchronized. + QUICHE_DCHECK(error_ == HpackDecodingError::kOk) + << HpackDecodingErrorToString(error_); + QUICHE_DCHECK_LE(lowest_header_table_size_, final_header_table_size_); + allow_dynamic_table_size_update_ = true; + saw_dynamic_table_size_update_ = false; + // If the peer has acknowledged a HEADER_TABLE_SIZE smaller than that which + // its HPACK encoder has been using, then the next HPACK block it sends MUST + // start with a Dynamic Table Size Update entry that is at least as low as + // lowest_header_table_size_. That may be followed by another as great as + // final_header_table_size_, if those are different. + require_dynamic_table_size_update_ = + (lowest_header_table_size_ < + decoder_tables_.current_header_table_size() || + final_header_table_size_ < decoder_tables_.header_table_size_limit()); + QUICHE_DVLOG(2) << "HpackDecoderState::OnHeaderListStart " + << "require_dynamic_table_size_update_=" + << require_dynamic_table_size_update_; + listener_->OnHeaderListStart(); +} + +void HpackDecoderState::OnIndexedHeader(size_t index) { + QUICHE_DVLOG(2) << "HpackDecoderState::OnIndexedHeader: " << index; + if (error_ != HpackDecodingError::kOk) { + return; + } + if (require_dynamic_table_size_update_) { + ReportError(HpackDecodingError::kMissingDynamicTableSizeUpdate, ""); + return; + } + allow_dynamic_table_size_update_ = false; + const HpackStringPair* entry = decoder_tables_.Lookup(index); + if (entry != nullptr) { + listener_->OnHeader(entry->name, entry->value); + } else { + ReportError(HpackDecodingError::kInvalidIndex, ""); + } +} + +void HpackDecoderState::OnNameIndexAndLiteralValue( + HpackEntryType entry_type, size_t name_index, + HpackDecoderStringBuffer* value_buffer) { + QUICHE_DVLOG(2) << "HpackDecoderState::OnNameIndexAndLiteralValue " + << entry_type << ", " << name_index << ", " + << value_buffer->str(); + if (error_ != HpackDecodingError::kOk) { + return; + } + if (require_dynamic_table_size_update_) { + ReportError(HpackDecodingError::kMissingDynamicTableSizeUpdate, ""); + return; + } + allow_dynamic_table_size_update_ = false; + const HpackStringPair* entry = decoder_tables_.Lookup(name_index); + if (entry != nullptr) { + std::string value(ExtractString(value_buffer)); + listener_->OnHeader(entry->name, value); + if (entry_type == HpackEntryType::kIndexedLiteralHeader) { + decoder_tables_.Insert(entry->name, std::move(value)); + } + } else { + ReportError(HpackDecodingError::kInvalidNameIndex, ""); + } +} + +void HpackDecoderState::OnLiteralNameAndValue( + HpackEntryType entry_type, HpackDecoderStringBuffer* name_buffer, + HpackDecoderStringBuffer* value_buffer) { + QUICHE_DVLOG(2) << "HpackDecoderState::OnLiteralNameAndValue " << entry_type + << ", " << name_buffer->str() << ", " << value_buffer->str(); + if (error_ != HpackDecodingError::kOk) { + return; + } + if (require_dynamic_table_size_update_) { + ReportError(HpackDecodingError::kMissingDynamicTableSizeUpdate, ""); + return; + } + allow_dynamic_table_size_update_ = false; + std::string name(ExtractString(name_buffer)); + std::string value(ExtractString(value_buffer)); + listener_->OnHeader(name, value); + if (entry_type == HpackEntryType::kIndexedLiteralHeader) { + decoder_tables_.Insert(std::move(name), std::move(value)); + } +} + +void HpackDecoderState::OnDynamicTableSizeUpdate(size_t size_limit) { + QUICHE_DVLOG(2) << "HpackDecoderState::OnDynamicTableSizeUpdate " + << size_limit << ", required=" + << (require_dynamic_table_size_update_ ? "true" : "false") + << ", allowed=" + << (allow_dynamic_table_size_update_ ? "true" : "false"); + if (error_ != HpackDecodingError::kOk) { + return; + } + QUICHE_DCHECK_LE(lowest_header_table_size_, final_header_table_size_); + if (!allow_dynamic_table_size_update_) { + // At most two dynamic table size updates allowed at the start, and not + // after a header. + ReportError(HpackDecodingError::kDynamicTableSizeUpdateNotAllowed, ""); + return; + } + if (require_dynamic_table_size_update_) { + // The new size must not be greater than the low water mark. + if (size_limit > lowest_header_table_size_) { + ReportError( + HpackDecodingError::kInitialDynamicTableSizeUpdateIsAboveLowWaterMark, + ""); + return; + } + require_dynamic_table_size_update_ = false; + } else if (size_limit > final_header_table_size_) { + // The new size must not be greater than the final max header table size + // that the peer acknowledged. + ReportError( + HpackDecodingError::kDynamicTableSizeUpdateIsAboveAcknowledgedSetting, + ""); + return; + } + decoder_tables_.DynamicTableSizeUpdate(size_limit); + if (saw_dynamic_table_size_update_) { + allow_dynamic_table_size_update_ = false; + } else { + saw_dynamic_table_size_update_ = true; + } + // We no longer need to keep an eye out for a lower header table size. + lowest_header_table_size_ = final_header_table_size_; +} + +void HpackDecoderState::OnHpackDecodeError(HpackDecodingError error, + std::string detailed_error) { + QUICHE_DVLOG(2) << "HpackDecoderState::OnHpackDecodeError " + << HpackDecodingErrorToString(error); + if (error_ == HpackDecodingError::kOk) { + ReportError(error, detailed_error); + } +} + +void HpackDecoderState::OnHeaderBlockEnd() { + QUICHE_DVLOG(2) << "HpackDecoderState::OnHeaderBlockEnd"; + if (error_ != HpackDecodingError::kOk) { + return; + } + if (require_dynamic_table_size_update_) { + // Apparently the HPACK block was empty, but we needed it to contain at + // least 1 dynamic table size update. + ReportError(HpackDecodingError::kMissingDynamicTableSizeUpdate, ""); + } else { + listener_->OnHeaderListEnd(); + } +} + +void HpackDecoderState::ReportError(HpackDecodingError error, + std::string detailed_error) { + QUICHE_DVLOG(2) << "HpackDecoderState::ReportError is new=" + << (error_ == HpackDecodingError::kOk ? "true" : "false") + << ", error: " << HpackDecodingErrorToString(error); + if (error_ == HpackDecodingError::kOk) { + listener_->OnHeaderErrorDetected(HpackDecodingErrorToString(error)); + error_ = error; + detailed_error_ = detailed_error; + } +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_state.h b/quiche/http2/hpack/decoder/hpack_decoder_state.h new file mode 100644 index 000000000000..4198ef4d170a --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_state.h @@ -0,0 +1,137 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// HpackDecoderState maintains the HPACK decompressor state; i.e. updates the +// HPACK dynamic table according to RFC 7541 as the entries in an HPACK block +// are decoded, and reads from the static and dynamic tables in order to build +// complete header entries. Calls an HpackDecoderListener with the completely +// decoded headers (i.e. after resolving table indices into names or values), +// thus translating the decoded HPACK entries into HTTP/2 headers. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_STATE_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_STATE_H_ + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_listener.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" +#include "quiche/http2/hpack/decoder/hpack_whole_entry_listener.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { +class HpackDecoderStatePeer; +} // namespace test + +class QUICHE_EXPORT HpackDecoderState : public HpackWholeEntryListener { + public: + explicit HpackDecoderState(HpackDecoderListener* listener); + ~HpackDecoderState() override; + + HpackDecoderState(const HpackDecoderState&) = delete; + HpackDecoderState& operator=(const HpackDecoderState&) = delete; + + // Set the listener to be notified when a whole entry has been decoded, + // including resolving name or name and value references. + // The listener may be changed at any time. + HpackDecoderListener* listener() const { return listener_; } + + // ApplyHeaderTableSizeSetting notifies this object that this endpoint has + // received a SETTINGS ACK frame acknowledging an earlier SETTINGS frame from + // this endpoint specifying a new value for SETTINGS_HEADER_TABLE_SIZE (the + // maximum size of the dynamic table that this endpoint will use to decode + // HPACK blocks). + // Because a SETTINGS frame can contain SETTINGS_HEADER_TABLE_SIZE values, + // the caller must keep track of those multiple changes, and make + // corresponding calls to this method. In particular, a call must be made + // with the lowest value acknowledged by the peer, and a call must be made + // with the final value acknowledged, in that order; additional calls may + // be made if additional values were sent. These calls must be made between + // decoding the SETTINGS ACK, and before the next HPACK block is decoded. + void ApplyHeaderTableSizeSetting(uint32_t max_header_table_size); + + // Returns the most recently applied value of SETTINGS_HEADER_TABLE_SIZE. + size_t GetCurrentHeaderTableSizeSetting() const { + return final_header_table_size_; + } + + // OnHeaderBlockStart notifies this object that we're starting to decode the + // HPACK payload of a HEADERS or PUSH_PROMISE frame. + void OnHeaderBlockStart(); + + // Implement the HpackWholeEntryListener methods, each of which notifies this + // object when an entire entry has been decoded. + void OnIndexedHeader(size_t index) override; + void OnNameIndexAndLiteralValue( + HpackEntryType entry_type, size_t name_index, + HpackDecoderStringBuffer* value_buffer) override; + void OnLiteralNameAndValue(HpackEntryType entry_type, + HpackDecoderStringBuffer* name_buffer, + HpackDecoderStringBuffer* value_buffer) override; + void OnDynamicTableSizeUpdate(size_t size) override; + void OnHpackDecodeError(HpackDecodingError error, + std::string detailed_error) override; + + // OnHeaderBlockEnd notifies this object that an entire HPACK block has been + // decoded, which might have extended into CONTINUATION blocks. + void OnHeaderBlockEnd(); + + // Returns error code after an error has been detected and reported. + // No further callbacks will be made to the listener. + HpackDecodingError error() const { return error_; } + + size_t GetDynamicTableSize() const { + return decoder_tables_.current_header_table_size(); + } + + const HpackDecoderTables& decoder_tables_for_test() const { + return decoder_tables_; + } + + std::string detailed_error() const { return detailed_error_; } + + private: + friend class test::HpackDecoderStatePeer; + + // Reports an error to the listener IF this is the first error detected. + void ReportError(HpackDecodingError error, std::string detailed_error); + + // The static and dynamic HPACK tables. + HpackDecoderTables decoder_tables_; + + // The listener to be notified of headers, the start and end of header + // lists, and of errors. + HpackDecoderListener* listener_; + + // The most recent HEADER_TABLE_SIZE setting acknowledged by the peer. + uint32_t final_header_table_size_; + + // The lowest HEADER_TABLE_SIZE setting acknowledged by the peer; valid until + // the next HPACK block is decoded. + // TODO(jamessynge): Test raising the HEADER_TABLE_SIZE. + uint32_t lowest_header_table_size_; + + // Must the next (first) HPACK entry be a dynamic table size update? + bool require_dynamic_table_size_update_; + + // May the next (first or second) HPACK entry be a dynamic table size update? + bool allow_dynamic_table_size_update_; + + // Have we already seen a dynamic table size update in this HPACK block? + bool saw_dynamic_table_size_update_; + + // Has an error already been detected and reported to the listener? + HpackDecodingError error_; + std::string detailed_error_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_STATE_H_ diff --git a/quiche/http2/hpack/decoder/hpack_decoder_state_test.cc b/quiche/http2/hpack/decoder/hpack_decoder_state_test.cc new file mode 100644 index 000000000000..46872aa1a852 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_state_test.cc @@ -0,0 +1,541 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_state.h" + +// Tests of HpackDecoderState. + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; +using ::testing::Eq; +using ::testing::Mock; +using ::testing::StrictMock; + +namespace http2 { +namespace test { +class HpackDecoderStatePeer { + public: + static HpackDecoderTables* GetDecoderTables(HpackDecoderState* state) { + return &state->decoder_tables_; + } +}; + +namespace { + +class MockHpackDecoderListener : public HpackDecoderListener { + public: + MOCK_METHOD(void, OnHeaderListStart, (), (override)); + MOCK_METHOD(void, OnHeader, + (const std::string& name, const std::string& value), (override)); + MOCK_METHOD(void, OnHeaderListEnd, (), (override)); + MOCK_METHOD(void, OnHeaderErrorDetected, (absl::string_view error_message), + (override)); +}; + +enum StringBacking { STATIC, UNBUFFERED, BUFFERED }; + +class HpackDecoderStateTest : public quiche::test::QuicheTest { + protected: + HpackDecoderStateTest() : decoder_state_(&listener_) {} + + HpackDecoderTables* GetDecoderTables() { + return HpackDecoderStatePeer::GetDecoderTables(&decoder_state_); + } + + const HpackStringPair* Lookup(size_t index) { + return GetDecoderTables()->Lookup(index); + } + + size_t current_header_table_size() { + return GetDecoderTables()->current_header_table_size(); + } + + size_t header_table_size_limit() { + return GetDecoderTables()->header_table_size_limit(); + } + + void set_header_table_size_limit(size_t size) { + GetDecoderTables()->DynamicTableSizeUpdate(size); + } + + void SetStringBuffer(const char* s, StringBacking backing, + HpackDecoderStringBuffer* string_buffer) { + switch (backing) { + case STATIC: + string_buffer->Set(s, true); + break; + case UNBUFFERED: + string_buffer->Set(s, false); + break; + case BUFFERED: + string_buffer->Set(s, false); + string_buffer->BufferStringIfUnbuffered(); + break; + } + } + + void SetName(const char* s, StringBacking backing) { + SetStringBuffer(s, backing, &name_buffer_); + } + + void SetValue(const char* s, StringBacking backing) { + SetStringBuffer(s, backing, &value_buffer_); + } + + void SendStartAndVerifyCallback() { + EXPECT_CALL(listener_, OnHeaderListStart()); + decoder_state_.OnHeaderBlockStart(); + Mock::VerifyAndClearExpectations(&listener_); + } + + void SendSizeUpdate(size_t size) { + decoder_state_.OnDynamicTableSizeUpdate(size); + Mock::VerifyAndClearExpectations(&listener_); + } + + void SendIndexAndVerifyCallback(size_t index, + HpackEntryType /*expected_type*/, + const char* expected_name, + const char* expected_value) { + EXPECT_CALL(listener_, OnHeader(Eq(expected_name), Eq(expected_value))); + decoder_state_.OnIndexedHeader(index); + Mock::VerifyAndClearExpectations(&listener_); + } + + void SendValueAndVerifyCallback(size_t name_index, HpackEntryType entry_type, + const char* name, const char* value, + StringBacking value_backing) { + SetValue(value, value_backing); + EXPECT_CALL(listener_, OnHeader(Eq(name), Eq(value))); + decoder_state_.OnNameIndexAndLiteralValue(entry_type, name_index, + &value_buffer_); + Mock::VerifyAndClearExpectations(&listener_); + } + + void SendNameAndValueAndVerifyCallback(HpackEntryType entry_type, + const char* name, + StringBacking name_backing, + const char* value, + StringBacking value_backing) { + SetName(name, name_backing); + SetValue(value, value_backing); + EXPECT_CALL(listener_, OnHeader(Eq(name), Eq(value))); + decoder_state_.OnLiteralNameAndValue(entry_type, &name_buffer_, + &value_buffer_); + Mock::VerifyAndClearExpectations(&listener_); + } + + void SendEndAndVerifyCallback() { + EXPECT_CALL(listener_, OnHeaderListEnd()); + decoder_state_.OnHeaderBlockEnd(); + Mock::VerifyAndClearExpectations(&listener_); + } + + // dynamic_index is one-based, because that is the way RFC 7541 shows it. + AssertionResult VerifyEntry(size_t dynamic_index, const char* name, + const char* value) { + const HpackStringPair* entry = + Lookup(dynamic_index + kFirstDynamicTableIndex - 1); + HTTP2_VERIFY_NE(entry, nullptr); + HTTP2_VERIFY_EQ(entry->name, name); + HTTP2_VERIFY_EQ(entry->value, value); + return AssertionSuccess(); + } + AssertionResult VerifyNoEntry(size_t dynamic_index) { + const HpackStringPair* entry = + Lookup(dynamic_index + kFirstDynamicTableIndex - 1); + HTTP2_VERIFY_EQ(entry, nullptr); + return AssertionSuccess(); + } + AssertionResult VerifyDynamicTableContents( + const std::vector>& entries) { + size_t index = 1; + for (const auto& entry : entries) { + HTTP2_VERIFY_SUCCESS(VerifyEntry(index, entry.first, entry.second)); + ++index; + } + HTTP2_VERIFY_SUCCESS(VerifyNoEntry(index)); + return AssertionSuccess(); + } + + StrictMock listener_; + HpackDecoderState decoder_state_; + HpackDecoderStringBuffer name_buffer_, value_buffer_; +}; + +// Test based on RFC 7541, section C.3: Request Examples without Huffman Coding. +// This section shows several consecutive header lists, corresponding to HTTP +// requests, on the same connection. +TEST_F(HpackDecoderStateTest, C3_RequestExamples) { + // C.3.1 First Request + // + // Header list to encode: + // + // :method: GET + // :scheme: http + // :path: / + // :authority: www.example.com + + SendStartAndVerifyCallback(); + SendIndexAndVerifyCallback(2, HpackEntryType::kIndexedHeader, ":method", + "GET"); + SendIndexAndVerifyCallback(6, HpackEntryType::kIndexedHeader, ":scheme", + "http"); + SendIndexAndVerifyCallback(4, HpackEntryType::kIndexedHeader, ":path", "/"); + SendValueAndVerifyCallback(1, HpackEntryType::kIndexedLiteralHeader, + ":authority", "www.example.com", UNBUFFERED); + SendEndAndVerifyCallback(); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 57) :authority: www.example.com + // Table size: 57 + + ASSERT_TRUE(VerifyDynamicTableContents({{":authority", "www.example.com"}})); + ASSERT_EQ(57u, current_header_table_size()); + + // C.3.2 Second Request + // + // Header list to encode: + // + // :method: GET + // :scheme: http + // :path: / + // :authority: www.example.com + // cache-control: no-cache + + SendStartAndVerifyCallback(); + SendIndexAndVerifyCallback(2, HpackEntryType::kIndexedHeader, ":method", + "GET"); + SendIndexAndVerifyCallback(6, HpackEntryType::kIndexedHeader, ":scheme", + "http"); + SendIndexAndVerifyCallback(4, HpackEntryType::kIndexedHeader, ":path", "/"); + SendIndexAndVerifyCallback(62, HpackEntryType::kIndexedHeader, ":authority", + "www.example.com"); + SendValueAndVerifyCallback(24, HpackEntryType::kIndexedLiteralHeader, + "cache-control", "no-cache", UNBUFFERED); + SendEndAndVerifyCallback(); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 53) cache-control: no-cache + // [ 2] (s = 57) :authority: www.example.com + // Table size: 110 + + ASSERT_TRUE(VerifyDynamicTableContents( + {{"cache-control", "no-cache"}, {":authority", "www.example.com"}})); + ASSERT_EQ(110u, current_header_table_size()); + + // C.3.3 Third Request + // + // Header list to encode: + // + // :method: GET + // :scheme: https + // :path: /index.html + // :authority: www.example.com + // custom-key: custom-value + + SendStartAndVerifyCallback(); + SendIndexAndVerifyCallback(2, HpackEntryType::kIndexedHeader, ":method", + "GET"); + SendIndexAndVerifyCallback(7, HpackEntryType::kIndexedHeader, ":scheme", + "https"); + SendIndexAndVerifyCallback(5, HpackEntryType::kIndexedHeader, ":path", + "/index.html"); + SendIndexAndVerifyCallback(63, HpackEntryType::kIndexedHeader, ":authority", + "www.example.com"); + SendNameAndValueAndVerifyCallback(HpackEntryType::kIndexedLiteralHeader, + "custom-key", UNBUFFERED, "custom-value", + UNBUFFERED); + SendEndAndVerifyCallback(); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 54) custom-key: custom-value + // [ 2] (s = 53) cache-control: no-cache + // [ 3] (s = 57) :authority: www.example.com + // Table size: 164 + + ASSERT_TRUE(VerifyDynamicTableContents({{"custom-key", "custom-value"}, + {"cache-control", "no-cache"}, + {":authority", "www.example.com"}})); + ASSERT_EQ(164u, current_header_table_size()); +} + +// Test based on RFC 7541, section C.5: Response Examples without Huffman +// Coding. This section shows several consecutive header lists, corresponding +// to HTTP responses, on the same connection. The HTTP/2 setting parameter +// SETTINGS_HEADER_TABLE_SIZE is set to the value of 256 octets, causing +// some evictions to occur. +TEST_F(HpackDecoderStateTest, C5_ResponseExamples) { + set_header_table_size_limit(256); + + // C.5.1 First Response + // + // Header list to encode: + // + // :status: 302 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:21 GMT + // location: https://www.example.com + + SendStartAndVerifyCallback(); + SendValueAndVerifyCallback(8, HpackEntryType::kIndexedLiteralHeader, + ":status", "302", BUFFERED); + SendValueAndVerifyCallback(24, HpackEntryType::kIndexedLiteralHeader, + "cache-control", "private", UNBUFFERED); + SendValueAndVerifyCallback(33, HpackEntryType::kIndexedLiteralHeader, "date", + "Mon, 21 Oct 2013 20:13:21 GMT", UNBUFFERED); + SendValueAndVerifyCallback(46, HpackEntryType::kIndexedLiteralHeader, + "location", "https://www.example.com", UNBUFFERED); + SendEndAndVerifyCallback(); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 63) location: https://www.example.com + // [ 2] (s = 65) date: Mon, 21 Oct 2013 20:13:21 GMT + // [ 3] (s = 52) cache-control: private + // [ 4] (s = 42) :status: 302 + // Table size: 222 + + ASSERT_TRUE( + VerifyDynamicTableContents({{"location", "https://www.example.com"}, + {"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"cache-control", "private"}, + {":status", "302"}})); + ASSERT_EQ(222u, current_header_table_size()); + + // C.5.2 Second Response + // + // The (":status", "302") header field is evicted from the dynamic table to + // free space to allow adding the (":status", "307") header field. + // + // Header list to encode: + // + // :status: 307 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:21 GMT + // location: https://www.example.com + + SendStartAndVerifyCallback(); + SendValueAndVerifyCallback(8, HpackEntryType::kIndexedLiteralHeader, + ":status", "307", BUFFERED); + SendIndexAndVerifyCallback(65, HpackEntryType::kIndexedHeader, + "cache-control", "private"); + SendIndexAndVerifyCallback(64, HpackEntryType::kIndexedHeader, "date", + "Mon, 21 Oct 2013 20:13:21 GMT"); + SendIndexAndVerifyCallback(63, HpackEntryType::kIndexedHeader, "location", + "https://www.example.com"); + SendEndAndVerifyCallback(); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 42) :status: 307 + // [ 2] (s = 63) location: https://www.example.com + // [ 3] (s = 65) date: Mon, 21 Oct 2013 20:13:21 GMT + // [ 4] (s = 52) cache-control: private + // Table size: 222 + + ASSERT_TRUE( + VerifyDynamicTableContents({{":status", "307"}, + {"location", "https://www.example.com"}, + {"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"cache-control", "private"}})); + ASSERT_EQ(222u, current_header_table_size()); + + // C.5.3 Third Response + // + // Several header fields are evicted from the dynamic table during the + // processing of this header list. + // + // Header list to encode: + // + // :status: 200 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:22 GMT + // location: https://www.example.com + // content-encoding: gzip + // set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1 + + SendStartAndVerifyCallback(); + SendIndexAndVerifyCallback(8, HpackEntryType::kIndexedHeader, ":status", + "200"); + SendIndexAndVerifyCallback(65, HpackEntryType::kIndexedHeader, + "cache-control", "private"); + SendValueAndVerifyCallback(33, HpackEntryType::kIndexedLiteralHeader, "date", + "Mon, 21 Oct 2013 20:13:22 GMT", BUFFERED); + SendIndexAndVerifyCallback(64, HpackEntryType::kIndexedHeader, "location", + "https://www.example.com"); + SendValueAndVerifyCallback(26, HpackEntryType::kIndexedLiteralHeader, + "content-encoding", "gzip", UNBUFFERED); + SendValueAndVerifyCallback( + 55, HpackEntryType::kIndexedLiteralHeader, "set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1", BUFFERED); + SendEndAndVerifyCallback(); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 98) set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; + // max-age=3600; version=1 + // [ 2] (s = 52) content-encoding: gzip + // [ 3] (s = 65) date: Mon, 21 Oct 2013 20:13:22 GMT + // Table size: 215 + + ASSERT_TRUE(VerifyDynamicTableContents( + {{"set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"}, + {"content-encoding", "gzip"}, + {"date", "Mon, 21 Oct 2013 20:13:22 GMT"}})); + ASSERT_EQ(215u, current_header_table_size()); +} + +// Confirm that the table size can be changed, but at most twice. +TEST_F(HpackDecoderStateTest, OptionalTableSizeChanges) { + SendStartAndVerifyCallback(); + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); + SendSizeUpdate(1024); + EXPECT_EQ(1024u, header_table_size_limit()); + SendSizeUpdate(0); + EXPECT_EQ(0u, header_table_size_limit()); + + // Three updates aren't allowed. + EXPECT_CALL(listener_, OnHeaderErrorDetected( + Eq("Dynamic table size update not allowed"))); + SendSizeUpdate(0); +} + +// Confirm that required size updates are indeed required before headers. +TEST_F(HpackDecoderStateTest, RequiredTableSizeChangeBeforeHeader) { + EXPECT_EQ(4096u, decoder_state_.GetCurrentHeaderTableSizeSetting()); + decoder_state_.ApplyHeaderTableSizeSetting(1024); + decoder_state_.ApplyHeaderTableSizeSetting(2048); + EXPECT_EQ(2048u, decoder_state_.GetCurrentHeaderTableSizeSetting()); + + // First provide the required update, and an allowed second update. + SendStartAndVerifyCallback(); + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); + SendSizeUpdate(1024); + EXPECT_EQ(1024u, header_table_size_limit()); + SendSizeUpdate(1500); + EXPECT_EQ(1500u, header_table_size_limit()); + SendEndAndVerifyCallback(); + + // Another HPACK block, but this time missing the required size update. + decoder_state_.ApplyHeaderTableSizeSetting(1024); + EXPECT_EQ(1024u, decoder_state_.GetCurrentHeaderTableSizeSetting()); + SendStartAndVerifyCallback(); + EXPECT_CALL(listener_, + OnHeaderErrorDetected(Eq("Missing dynamic table size update"))); + decoder_state_.OnIndexedHeader(1); + + // Further decoded entries are ignored. + decoder_state_.OnIndexedHeader(1); + decoder_state_.OnDynamicTableSizeUpdate(1); + SetValue("value", UNBUFFERED); + decoder_state_.OnNameIndexAndLiteralValue( + HpackEntryType::kIndexedLiteralHeader, 4, &value_buffer_); + SetName("name", UNBUFFERED); + decoder_state_.OnLiteralNameAndValue(HpackEntryType::kIndexedLiteralHeader, + &name_buffer_, &value_buffer_); + decoder_state_.OnHeaderBlockEnd(); + decoder_state_.OnHpackDecodeError(HpackDecodingError::kIndexVarintError, ""); +} + +// Confirm that required size updates are validated. +TEST_F(HpackDecoderStateTest, InvalidRequiredSizeUpdate) { + // Require a size update, but provide one that isn't small enough. + decoder_state_.ApplyHeaderTableSizeSetting(1024); + SendStartAndVerifyCallback(); + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); + EXPECT_CALL( + listener_, + OnHeaderErrorDetected( + Eq("Initial dynamic table size update is above low water mark"))); + SendSizeUpdate(2048); +} + +// Confirm that required size updates are indeed required before the end. +TEST_F(HpackDecoderStateTest, RequiredTableSizeChangeBeforeEnd) { + decoder_state_.ApplyHeaderTableSizeSetting(1024); + SendStartAndVerifyCallback(); + EXPECT_CALL(listener_, + OnHeaderErrorDetected(Eq("Missing dynamic table size update"))); + decoder_state_.OnHeaderBlockEnd(); +} + +// Confirm that optional size updates are validated. +TEST_F(HpackDecoderStateTest, InvalidOptionalSizeUpdate) { + // Require a size update, but provide one that isn't small enough. + SendStartAndVerifyCallback(); + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); + EXPECT_CALL(listener_, + OnHeaderErrorDetected(Eq( + "Dynamic table size update is above acknowledged setting"))); + SendSizeUpdate(Http2SettingsInfo::DefaultHeaderTableSize() + 1); +} + +TEST_F(HpackDecoderStateTest, InvalidStaticIndex) { + SendStartAndVerifyCallback(); + EXPECT_CALL(listener_, + OnHeaderErrorDetected( + Eq("Invalid index in indexed header field representation"))); + decoder_state_.OnIndexedHeader(0); +} + +TEST_F(HpackDecoderStateTest, InvalidDynamicIndex) { + SendStartAndVerifyCallback(); + EXPECT_CALL(listener_, + OnHeaderErrorDetected( + Eq("Invalid index in indexed header field representation"))); + decoder_state_.OnIndexedHeader(kFirstDynamicTableIndex); +} + +TEST_F(HpackDecoderStateTest, InvalidNameIndex) { + SendStartAndVerifyCallback(); + EXPECT_CALL(listener_, + OnHeaderErrorDetected(Eq("Invalid index in literal header field " + "with indexed name representation"))); + SetValue("value", UNBUFFERED); + decoder_state_.OnNameIndexAndLiteralValue( + HpackEntryType::kIndexedLiteralHeader, kFirstDynamicTableIndex, + &value_buffer_); +} + +TEST_F(HpackDecoderStateTest, ErrorsSuppressCallbacks) { + SendStartAndVerifyCallback(); + EXPECT_CALL(listener_, + OnHeaderErrorDetected(Eq("Name Huffman encoding error"))); + decoder_state_.OnHpackDecodeError(HpackDecodingError::kNameHuffmanError, ""); + + // Further decoded entries are ignored. + decoder_state_.OnIndexedHeader(1); + decoder_state_.OnDynamicTableSizeUpdate(1); + SetValue("value", UNBUFFERED); + decoder_state_.OnNameIndexAndLiteralValue( + HpackEntryType::kIndexedLiteralHeader, 4, &value_buffer_); + SetName("name", UNBUFFERED); + decoder_state_.OnLiteralNameAndValue(HpackEntryType::kIndexedLiteralHeader, + &name_buffer_, &value_buffer_); + decoder_state_.OnHeaderBlockEnd(); + decoder_state_.OnHpackDecodeError(HpackDecodingError::kIndexVarintError, ""); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc b/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc new file mode 100644 index 000000000000..c57da8e28532 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc @@ -0,0 +1,239 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h" + +#include + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::ostream& operator<<(std::ostream& out, + const HpackDecoderStringBuffer::State v) { + switch (v) { + case HpackDecoderStringBuffer::State::RESET: + return out << "RESET"; + case HpackDecoderStringBuffer::State::COLLECTING: + return out << "COLLECTING"; + case HpackDecoderStringBuffer::State::COMPLETE: + return out << "COMPLETE"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + int unknown = static_cast(v); + QUICHE_BUG(http2_bug_50_1) + << "Invalid HpackDecoderStringBuffer::State: " << unknown; + return out << "HpackDecoderStringBuffer::State(" << unknown << ")"; +} + +std::ostream& operator<<(std::ostream& out, + const HpackDecoderStringBuffer::Backing v) { + switch (v) { + case HpackDecoderStringBuffer::Backing::RESET: + return out << "RESET"; + case HpackDecoderStringBuffer::Backing::UNBUFFERED: + return out << "UNBUFFERED"; + case HpackDecoderStringBuffer::Backing::BUFFERED: + return out << "BUFFERED"; + case HpackDecoderStringBuffer::Backing::STATIC: + return out << "STATIC"; + } + // Since the value doesn't come over the wire, only a programming bug should + // result in reaching this point. + auto v2 = static_cast(v); + QUICHE_BUG(http2_bug_50_2) + << "Invalid HpackDecoderStringBuffer::Backing: " << v2; + return out << "HpackDecoderStringBuffer::Backing(" << v2 << ")"; +} + +HpackDecoderStringBuffer::HpackDecoderStringBuffer() + : remaining_len_(0), + is_huffman_encoded_(false), + state_(State::RESET), + backing_(Backing::RESET) {} +HpackDecoderStringBuffer::~HpackDecoderStringBuffer() = default; + +void HpackDecoderStringBuffer::Reset() { + QUICHE_DVLOG(3) << "HpackDecoderStringBuffer::Reset"; + state_ = State::RESET; +} + +void HpackDecoderStringBuffer::Set(absl::string_view value, bool is_static) { + QUICHE_DVLOG(2) << "HpackDecoderStringBuffer::Set"; + QUICHE_DCHECK_EQ(state_, State::RESET); + value_ = value; + state_ = State::COMPLETE; + backing_ = is_static ? Backing::STATIC : Backing::UNBUFFERED; + // TODO(jamessynge): Determine which of these two fields must be set. + remaining_len_ = 0; + is_huffman_encoded_ = false; +} + +void HpackDecoderStringBuffer::OnStart(bool huffman_encoded, size_t len) { + QUICHE_DVLOG(2) << "HpackDecoderStringBuffer::OnStart"; + QUICHE_DCHECK_EQ(state_, State::RESET); + + remaining_len_ = len; + is_huffman_encoded_ = huffman_encoded; + state_ = State::COLLECTING; + + if (huffman_encoded) { + // We don't set, clear or use value_ for buffered strings until OnEnd. + decoder_.Reset(); + buffer_.clear(); + backing_ = Backing::BUFFERED; + + // Reserve space in buffer_ for the uncompressed string, assuming the + // maximum expansion. The shortest Huffman codes in the RFC are 5 bits long, + // which then expand to 8 bits during decoding (i.e. each code is for one + // plain text octet, aka byte), so the maximum size is 60% longer than the + // encoded size. + len = len * 8 / 5; + if (buffer_.capacity() < len) { + buffer_.reserve(len); + } + } else { + // Assume for now that we won't need to use buffer_, so don't reserve space + // in it. + backing_ = Backing::RESET; + // OnData is not called for empty (zero length) strings, so make sure that + // value_ is cleared. + value_ = absl::string_view(); + } +} + +bool HpackDecoderStringBuffer::OnData(const char* data, size_t len) { + QUICHE_DVLOG(2) << "HpackDecoderStringBuffer::OnData state=" << state_ + << ", backing=" << backing_; + QUICHE_DCHECK_EQ(state_, State::COLLECTING); + QUICHE_DCHECK_LE(len, remaining_len_); + remaining_len_ -= len; + + if (is_huffman_encoded_) { + QUICHE_DCHECK_EQ(backing_, Backing::BUFFERED); + return decoder_.Decode(absl::string_view(data, len), &buffer_); + } + + if (backing_ == Backing::RESET) { + // This is the first call to OnData. If data contains the entire string, + // don't copy the string. If we later find that the HPACK entry is split + // across input buffers, then we'll copy the string into buffer_. + if (remaining_len_ == 0) { + value_ = absl::string_view(data, len); + backing_ = Backing::UNBUFFERED; + return true; + } + + // We need to buffer the string because it is split across input buffers. + // Reserve space in buffer_ for the entire string. + backing_ = Backing::BUFFERED; + buffer_.reserve(remaining_len_ + len); + buffer_.assign(data, len); + return true; + } + + // This is not the first call to OnData for this string, so it should be + // buffered. + QUICHE_DCHECK_EQ(backing_, Backing::BUFFERED); + + // Append to the current contents of the buffer. + buffer_.append(data, len); + return true; +} + +bool HpackDecoderStringBuffer::OnEnd() { + QUICHE_DVLOG(2) << "HpackDecoderStringBuffer::OnEnd"; + QUICHE_DCHECK_EQ(state_, State::COLLECTING); + QUICHE_DCHECK_EQ(0u, remaining_len_); + + if (is_huffman_encoded_) { + QUICHE_DCHECK_EQ(backing_, Backing::BUFFERED); + // Did the Huffman encoding of the string end properly? + if (!decoder_.InputProperlyTerminated()) { + return false; // No, it didn't. + } + value_ = buffer_; + } else if (backing_ == Backing::BUFFERED) { + value_ = buffer_; + } + state_ = State::COMPLETE; + return true; +} + +void HpackDecoderStringBuffer::BufferStringIfUnbuffered() { + QUICHE_DVLOG(3) << "HpackDecoderStringBuffer::BufferStringIfUnbuffered state=" + << state_ << ", backing=" << backing_; + if (state_ != State::RESET && backing_ == Backing::UNBUFFERED) { + QUICHE_DVLOG(2) + << "HpackDecoderStringBuffer buffering std::string of length " + << value_.size(); + buffer_.assign(value_.data(), value_.size()); + if (state_ == State::COMPLETE) { + value_ = buffer_; + } + backing_ = Backing::BUFFERED; + } +} + +bool HpackDecoderStringBuffer::IsBuffered() const { + QUICHE_DVLOG(3) << "HpackDecoderStringBuffer::IsBuffered"; + return state_ != State::RESET && backing_ == Backing::BUFFERED; +} + +size_t HpackDecoderStringBuffer::BufferedLength() const { + QUICHE_DVLOG(3) << "HpackDecoderStringBuffer::BufferedLength"; + return IsBuffered() ? buffer_.size() : 0; +} + +absl::string_view HpackDecoderStringBuffer::str() const { + QUICHE_DVLOG(3) << "HpackDecoderStringBuffer::str"; + QUICHE_DCHECK_EQ(state_, State::COMPLETE); + return value_; +} + +absl::string_view HpackDecoderStringBuffer::GetStringIfComplete() const { + if (state_ != State::COMPLETE) { + return {}; + } + return str(); +} + +std::string HpackDecoderStringBuffer::ReleaseString() { + QUICHE_DVLOG(3) << "HpackDecoderStringBuffer::ReleaseString"; + QUICHE_DCHECK_EQ(state_, State::COMPLETE); + QUICHE_DCHECK_EQ(backing_, Backing::BUFFERED); + if (state_ == State::COMPLETE) { + state_ = State::RESET; + if (backing_ == Backing::BUFFERED) { + return std::move(buffer_); + } else { + return std::string(value_); + } + } + return ""; +} + +void HpackDecoderStringBuffer::OutputDebugStringTo(std::ostream& out) const { + out << "{state=" << state_; + if (state_ != State::RESET) { + out << ", backing=" << backing_; + out << ", remaining_len=" << remaining_len_; + out << ", is_huffman_encoded=" << is_huffman_encoded_; + if (backing_ == Backing::BUFFERED) { + out << ", buffer: " << buffer_; + } else { + out << ", value: " << value_; + } + } + out << "}"; +} + +std::ostream& operator<<(std::ostream& out, const HpackDecoderStringBuffer& v) { + v.OutputDebugStringTo(out); + return out; +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h b/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h new file mode 100644 index 000000000000..a2b8605a5bbd --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h @@ -0,0 +1,101 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_STRING_BUFFER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_STRING_BUFFER_H_ + +// HpackDecoderStringBuffer helps an HPACK decoder to avoid copies of a string +// literal (name or value) except when necessary (e.g. when split across two +// or more HPACK block fragments). + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/huffman/hpack_huffman_decoder.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +class QUICHE_EXPORT HpackDecoderStringBuffer { + public: + enum class State : uint8_t { RESET, COLLECTING, COMPLETE }; + enum class Backing : uint8_t { RESET, UNBUFFERED, BUFFERED, STATIC }; + + HpackDecoderStringBuffer(); + ~HpackDecoderStringBuffer(); + + HpackDecoderStringBuffer(const HpackDecoderStringBuffer&) = delete; + HpackDecoderStringBuffer& operator=(const HpackDecoderStringBuffer&) = delete; + + void Reset(); + void Set(absl::string_view value, bool is_static); + + // Note that for Huffman encoded strings the length of the string after + // decoding may be larger (expected), the same or even smaller; the latter + // are unlikely, but possible if the encoder makes odd choices. + void OnStart(bool huffman_encoded, size_t len); + bool OnData(const char* data, size_t len); + bool OnEnd(); + void BufferStringIfUnbuffered(); + bool IsBuffered() const; + size_t BufferedLength() const; + + // Accessors for the completely collected string (i.e. Set or OnEnd has just + // been called, and no reset of the state has occurred). + + // Returns a string_view pointing to the backing store for the string, + // either the internal buffer or the original transport buffer (e.g. for a + // literal value that wasn't Huffman encoded, and that wasn't split across + // transport buffers). + absl::string_view str() const; + + // Same as str() if state_ is COMPLETE. Otherwise, returns empty string piece. + absl::string_view GetStringIfComplete() const; + + // Returns the completely collected string by value, using std::move in an + // effort to avoid unnecessary copies. ReleaseString() must not be called + // unless the string has been buffered (to avoid forcing a potentially + // unnecessary copy). ReleaseString() also resets the instance so that it can + // be used to collect another string. + std::string ReleaseString(); + + State state_for_testing() const { return state_; } + Backing backing_for_testing() const { return backing_; } + void OutputDebugStringTo(std::ostream& out) const; + + private: + // Storage for the string being buffered, if buffering is necessary + // (e.g. if Huffman encoded, buffer_ is storage for the decoded string). + std::string buffer_; + + // The string_view to be returned by HpackDecoderStringBuffer::str(). If + // a string has been collected, but not buffered, value_ points to that + // string. + absl::string_view value_; + + // The decoder to use if the string is Huffman encoded. + HpackHuffmanDecoder decoder_; + + // Count of bytes not yet passed to OnData. + size_t remaining_len_; + + // Is the HPACK string Huffman encoded? + bool is_huffman_encoded_; + + // State of the string decoding process. + State state_; + + // Where is the string stored? + Backing backing_; +}; + +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackDecoderStringBuffer& v); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_STRING_BUFFER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc b/quiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc new file mode 100644 index 000000000000..299391f11cf4 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc @@ -0,0 +1,250 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h" + +// Tests of HpackDecoderStringBuffer. + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; +using ::testing::HasSubstr; + +namespace http2 { +namespace test { +namespace { + +class HpackDecoderStringBufferTest : public quiche::test::QuicheTest { + protected: + typedef HpackDecoderStringBuffer::State State; + typedef HpackDecoderStringBuffer::Backing Backing; + + State state() const { return buf_.state_for_testing(); } + Backing backing() const { return buf_.backing_for_testing(); } + + // We want to know that QUICHE_LOG(x) << buf_ will work in production should + // that be needed, so we test that it outputs the expected values. + AssertionResult VerifyLogHasSubstrs(std::initializer_list strs) { + QUICHE_VLOG(1) << buf_; + std::ostringstream ss; + buf_.OutputDebugStringTo(ss); + std::string dbg_str(ss.str()); + for (const auto& expected : strs) { + HTTP2_VERIFY_TRUE(absl::StrContains(dbg_str, expected)); + } + return AssertionSuccess(); + } + + HpackDecoderStringBuffer buf_; +}; + +TEST_F(HpackDecoderStringBufferTest, SetStatic) { + absl::string_view data("static string"); + + EXPECT_EQ(state(), State::RESET); + EXPECT_TRUE(VerifyLogHasSubstrs({"state=RESET"})); + + buf_.Set(data, /*is_static*/ true); + QUICHE_LOG(INFO) << buf_; + EXPECT_EQ(state(), State::COMPLETE); + EXPECT_EQ(backing(), Backing::STATIC); + EXPECT_EQ(data, buf_.str()); + EXPECT_EQ(data.data(), buf_.str().data()); + EXPECT_TRUE(VerifyLogHasSubstrs( + {"state=COMPLETE", "backing=STATIC", "value: static string"})); + + // The string is static, so BufferStringIfUnbuffered won't change anything. + buf_.BufferStringIfUnbuffered(); + EXPECT_EQ(state(), State::COMPLETE); + EXPECT_EQ(backing(), Backing::STATIC); + EXPECT_EQ(data, buf_.str()); + EXPECT_EQ(data.data(), buf_.str().data()); + EXPECT_TRUE(VerifyLogHasSubstrs( + {"state=COMPLETE", "backing=STATIC", "value: static string"})); +} + +TEST_F(HpackDecoderStringBufferTest, PlainWhole) { + absl::string_view data("some text."); + + QUICHE_LOG(INFO) << buf_; + EXPECT_EQ(state(), State::RESET); + + buf_.OnStart(/*huffman_encoded*/ false, data.size()); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::RESET); + QUICHE_LOG(INFO) << buf_; + + EXPECT_TRUE(buf_.OnData(data.data(), data.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::UNBUFFERED); + + EXPECT_TRUE(buf_.OnEnd()); + EXPECT_EQ(state(), State::COMPLETE); + EXPECT_EQ(backing(), Backing::UNBUFFERED); + EXPECT_EQ(0u, buf_.BufferedLength()); + EXPECT_TRUE(VerifyLogHasSubstrs( + {"state=COMPLETE", "backing=UNBUFFERED", "value: some text."})); + + // We expect that the string buffer points to the passed in + // string_view's backing store. + EXPECT_EQ(data.data(), buf_.str().data()); + + // Now force it to buffer the string, after which it will still have the same + // string value, but the backing store will be different. + buf_.BufferStringIfUnbuffered(); + QUICHE_LOG(INFO) << buf_; + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), data.size()); + EXPECT_EQ(data, buf_.str()); + EXPECT_NE(data.data(), buf_.str().data()); + EXPECT_TRUE(VerifyLogHasSubstrs( + {"state=COMPLETE", "backing=BUFFERED", "buffer: some text."})); +} + +TEST_F(HpackDecoderStringBufferTest, PlainSplit) { + absl::string_view data("some text."); + absl::string_view part1 = data.substr(0, 1); + absl::string_view part2 = data.substr(1); + + EXPECT_EQ(state(), State::RESET); + buf_.OnStart(/*huffman_encoded*/ false, data.size()); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::RESET); + + // OnData with only a part of the data, not the whole, so buf_ will buffer + // the data. + EXPECT_TRUE(buf_.OnData(part1.data(), part1.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), part1.size()); + QUICHE_LOG(INFO) << buf_; + + EXPECT_TRUE(buf_.OnData(part2.data(), part2.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), data.size()); + + EXPECT_TRUE(buf_.OnEnd()); + EXPECT_EQ(state(), State::COMPLETE); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), data.size()); + QUICHE_LOG(INFO) << buf_; + + absl::string_view buffered = buf_.str(); + EXPECT_EQ(data, buffered); + EXPECT_NE(data.data(), buffered.data()); + + // The string is already buffered, so BufferStringIfUnbuffered should not make + // any change. + buf_.BufferStringIfUnbuffered(); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), data.size()); + EXPECT_EQ(buffered, buf_.str()); + EXPECT_EQ(buffered.data(), buf_.str().data()); +} + +TEST_F(HpackDecoderStringBufferTest, HuffmanWhole) { + std::string encoded = absl::HexStringToBytes("f1e3c2e5f23a6ba0ab90f4ff"); + absl::string_view decoded("www.example.com"); + + EXPECT_EQ(state(), State::RESET); + buf_.OnStart(/*huffman_encoded*/ true, encoded.size()); + EXPECT_EQ(state(), State::COLLECTING); + + EXPECT_TRUE(buf_.OnData(encoded.data(), encoded.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + + EXPECT_TRUE(buf_.OnEnd()); + EXPECT_EQ(state(), State::COMPLETE); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), decoded.size()); + EXPECT_EQ(decoded, buf_.str()); + EXPECT_TRUE(VerifyLogHasSubstrs( + {"{state=COMPLETE", "backing=BUFFERED", "buffer: www.example.com}"})); + + std::string s = buf_.ReleaseString(); + EXPECT_EQ(s, decoded); + EXPECT_EQ(state(), State::RESET); +} + +TEST_F(HpackDecoderStringBufferTest, HuffmanSplit) { + std::string encoded = absl::HexStringToBytes("f1e3c2e5f23a6ba0ab90f4ff"); + std::string part1 = encoded.substr(0, 5); + std::string part2 = encoded.substr(5); + absl::string_view decoded("www.example.com"); + + EXPECT_EQ(state(), State::RESET); + buf_.OnStart(/*huffman_encoded*/ true, encoded.size()); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(0u, buf_.BufferedLength()); + QUICHE_LOG(INFO) << buf_; + + EXPECT_TRUE(buf_.OnData(part1.data(), part1.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_GT(buf_.BufferedLength(), 0u); + EXPECT_LT(buf_.BufferedLength(), decoded.size()); + QUICHE_LOG(INFO) << buf_; + + EXPECT_TRUE(buf_.OnData(part2.data(), part2.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), decoded.size()); + QUICHE_LOG(INFO) << buf_; + + EXPECT_TRUE(buf_.OnEnd()); + EXPECT_EQ(state(), State::COMPLETE); + EXPECT_EQ(backing(), Backing::BUFFERED); + EXPECT_EQ(buf_.BufferedLength(), decoded.size()); + EXPECT_EQ(decoded, buf_.str()); + QUICHE_LOG(INFO) << buf_; + + buf_.Reset(); + EXPECT_EQ(state(), State::RESET); + QUICHE_LOG(INFO) << buf_; +} + +TEST_F(HpackDecoderStringBufferTest, InvalidHuffmanOnData) { + // Explicitly encode the End-of-String symbol, a no-no. + std::string encoded = absl::HexStringToBytes("ffffffff"); + + buf_.OnStart(/*huffman_encoded*/ true, encoded.size()); + EXPECT_EQ(state(), State::COLLECTING); + + EXPECT_FALSE(buf_.OnData(encoded.data(), encoded.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + + QUICHE_LOG(INFO) << buf_; +} + +TEST_F(HpackDecoderStringBufferTest, InvalidHuffmanOnEnd) { + // Last byte of string doesn't end with prefix of End-of-String symbol. + std::string encoded = absl::HexStringToBytes("00"); + + buf_.OnStart(/*huffman_encoded*/ true, encoded.size()); + EXPECT_EQ(state(), State::COLLECTING); + + EXPECT_TRUE(buf_.OnData(encoded.data(), encoded.size())); + EXPECT_EQ(state(), State::COLLECTING); + EXPECT_EQ(backing(), Backing::BUFFERED); + + EXPECT_FALSE(buf_.OnEnd()); + QUICHE_LOG(INFO) << buf_; +} + +// TODO(jamessynge): Add tests for ReleaseString(). + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_tables.cc b/quiche/http2/hpack/decoder/hpack_decoder_tables.cc new file mode 100644 index 000000000000..41cffa75758b --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_tables.cc @@ -0,0 +1,148 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" + +#include "absl/strings/str_cat.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace { + +std::vector* MakeStaticTable() { + auto* ptr = new std::vector(); + ptr->reserve(kFirstDynamicTableIndex); + ptr->emplace_back("", ""); + +#define STATIC_TABLE_ENTRY(name, value, index) \ + QUICHE_DCHECK_EQ(ptr->size(), static_cast(index)); \ + ptr->emplace_back(name, value) + +#include "quiche/http2/hpack/hpack_static_table_entries.inc" + +#undef STATIC_TABLE_ENTRY + + return ptr; +} + +const std::vector* GetStaticTable() { + static const std::vector* const g_static_table = + MakeStaticTable(); + return g_static_table; +} + +} // namespace + +HpackStringPair::HpackStringPair(std::string name, std::string value) + : name(std::move(name)), value(std::move(value)) { + QUICHE_DVLOG(3) << DebugString() << " ctor"; +} + +HpackStringPair::~HpackStringPair() { + QUICHE_DVLOG(3) << DebugString() << " dtor"; +} + +std::string HpackStringPair::DebugString() const { + return absl::StrCat("HpackStringPair(name=", name, ", value=", value, ")"); +} + +std::ostream& operator<<(std::ostream& os, const HpackStringPair& p) { + os << p.DebugString(); + return os; +} + +HpackDecoderStaticTable::HpackDecoderStaticTable( + const std::vector* table) + : table_(table) {} + +HpackDecoderStaticTable::HpackDecoderStaticTable() : table_(GetStaticTable()) {} + +const HpackStringPair* HpackDecoderStaticTable::Lookup(size_t index) const { + if (0 < index && index < kFirstDynamicTableIndex) { + return &((*table_)[index]); + } + return nullptr; +} + +HpackDecoderDynamicTable::HpackDecoderDynamicTable() + : insert_count_(kFirstDynamicTableIndex - 1) {} +HpackDecoderDynamicTable::~HpackDecoderDynamicTable() = default; + +void HpackDecoderDynamicTable::DynamicTableSizeUpdate(size_t size_limit) { + QUICHE_DVLOG(3) << "HpackDecoderDynamicTable::DynamicTableSizeUpdate " + << size_limit; + EnsureSizeNoMoreThan(size_limit); + QUICHE_DCHECK_LE(current_size_, size_limit); + size_limit_ = size_limit; +} + +// TODO(jamessynge): Check somewhere before here that names received from the +// peer are valid (e.g. are lower-case, no whitespace, etc.). +void HpackDecoderDynamicTable::Insert(std::string name, std::string value) { + HpackStringPair entry(std::move(name), std::move(value)); + size_t entry_size = entry.size(); + QUICHE_DVLOG(2) << "InsertEntry of size=" << entry_size + << "\n name: " << entry.name + << "\n value: " << entry.value; + if (entry_size > size_limit_) { + QUICHE_DVLOG(2) << "InsertEntry: entry larger than table, removing " + << table_.size() << " entries, of total size " + << current_size_ << " bytes."; + table_.clear(); + current_size_ = 0; + return; + } + ++insert_count_; + size_t insert_limit = size_limit_ - entry_size; + EnsureSizeNoMoreThan(insert_limit); + table_.push_front(entry); + current_size_ += entry_size; + QUICHE_DVLOG(2) << "InsertEntry: current_size_=" << current_size_; + QUICHE_DCHECK_GE(current_size_, entry_size); + QUICHE_DCHECK_LE(current_size_, size_limit_); +} + +const HpackStringPair* HpackDecoderDynamicTable::Lookup(size_t index) const { + if (index < table_.size()) { + return &table_[index]; + } + return nullptr; +} + +void HpackDecoderDynamicTable::EnsureSizeNoMoreThan(size_t limit) { + QUICHE_DVLOG(2) << "EnsureSizeNoMoreThan limit=" << limit + << ", current_size_=" << current_size_; + // Not the most efficient choice, but any easy way to start. + while (current_size_ > limit) { + RemoveLastEntry(); + } + QUICHE_DCHECK_LE(current_size_, limit); +} + +void HpackDecoderDynamicTable::RemoveLastEntry() { + QUICHE_DCHECK(!table_.empty()); + if (!table_.empty()) { + QUICHE_DVLOG(2) << "RemoveLastEntry current_size_=" << current_size_ + << ", last entry size=" << table_.back().size(); + QUICHE_DCHECK_GE(current_size_, table_.back().size()); + current_size_ -= table_.back().size(); + table_.pop_back(); + // Empty IFF current_size_ == 0. + QUICHE_DCHECK_EQ(table_.empty(), current_size_ == 0); + } +} + +HpackDecoderTables::HpackDecoderTables() = default; +HpackDecoderTables::~HpackDecoderTables() = default; + +const HpackStringPair* HpackDecoderTables::Lookup(size_t index) const { + if (index < kFirstDynamicTableIndex) { + return static_table_.Lookup(index); + } else { + return dynamic_table_.Lookup(index - kFirstDynamicTableIndex); + } +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_tables.h b/quiche/http2/hpack/decoder/hpack_decoder_tables.h new file mode 100644 index 000000000000..38f665d311e4 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_tables.h @@ -0,0 +1,166 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_TABLES_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_TABLES_H_ + +// Static and dynamic tables for the HPACK decoder. See: +// http://httpwg.org/specs/rfc7541.html#indexing.tables + +// Note that the Lookup methods return nullptr if the requested index was not +// found. This should be treated as a COMPRESSION error according to the HTTP/2 +// spec, which is a connection level protocol error (i.e. the connection must +// be terminated). See these sections in the two RFCs: +// http://httpwg.org/specs/rfc7541.html#indexed.header.representation +// http://httpwg.org/specs/rfc7541.html#index.address.space +// http://httpwg.org/specs/rfc7540.html#HeaderBlock + +#include + +#include +#include +#include +#include +#include + +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace http2 { +namespace test { +class HpackDecoderTablesPeer; +} // namespace test + +struct QUICHE_EXPORT HpackStringPair { + HpackStringPair(std::string name, std::string value); + ~HpackStringPair(); + + // Returns the size of a header entry with this name and value, per the RFC: + // http://httpwg.org/specs/rfc7541.html#calculating.table.size + size_t size() const { return 32 + name.size() + value.size(); } + + std::string DebugString() const; + + const std::string name; + const std::string value; +}; + +QUICHE_EXPORT std::ostream& operator<<(std::ostream& os, + const HpackStringPair& p); + +// See http://httpwg.org/specs/rfc7541.html#static.table.definition for the +// contents, and http://httpwg.org/specs/rfc7541.html#index.address.space for +// info about accessing the static table. +class QUICHE_EXPORT HpackDecoderStaticTable { + public: + explicit HpackDecoderStaticTable(const std::vector* table); + // Uses a global table shared by all threads. + HpackDecoderStaticTable(); + + // If index is valid, returns a pointer to the entry, otherwise returns + // nullptr. + const HpackStringPair* Lookup(size_t index) const; + + private: + friend class test::HpackDecoderTablesPeer; + const std::vector* const table_; +}; + +// HpackDecoderDynamicTable implements HPACK compression feature "indexed +// headers"; previously sent headers may be referenced later by their index +// in the dynamic table. See these sections of the RFC: +// http://httpwg.org/specs/rfc7541.html#dynamic.table +// http://httpwg.org/specs/rfc7541.html#dynamic.table.management +class QUICHE_EXPORT HpackDecoderDynamicTable { + public: + HpackDecoderDynamicTable(); + ~HpackDecoderDynamicTable(); + + HpackDecoderDynamicTable(const HpackDecoderDynamicTable&) = delete; + HpackDecoderDynamicTable& operator=(const HpackDecoderDynamicTable&) = delete; + + // Sets a new size limit, received from the peer; performs evictions if + // necessary to ensure that the current size does not exceed the new limit. + // The caller needs to have validated that size_limit does not + // exceed the acknowledged value of SETTINGS_HEADER_TABLE_SIZE. + void DynamicTableSizeUpdate(size_t size_limit); + + // Insert entry if possible. + // If entry is too large to insert, then dynamic table will be empty. + void Insert(std::string name, std::string value); + + // If index is valid, returns a pointer to the entry, otherwise returns + // nullptr. + const HpackStringPair* Lookup(size_t index) const; + + size_t size_limit() const { return size_limit_; } + size_t current_size() const { return current_size_; } + + private: + friend class test::HpackDecoderTablesPeer; + + // Drop older entries to ensure the size is not greater than limit. + void EnsureSizeNoMoreThan(size_t limit); + + // Removes the oldest dynamic table entry. + void RemoveLastEntry(); + + quiche::QuicheCircularDeque table_; + + // The last received DynamicTableSizeUpdate value, initialized to + // SETTINGS_HEADER_TABLE_SIZE. + size_t size_limit_ = Http2SettingsInfo::DefaultHeaderTableSize(); + + size_t current_size_ = 0; + + // insert_count_ and debug_listener_ are used by a QUIC experiment; remove + // when the experiment is done. + size_t insert_count_; +}; + +class QUICHE_EXPORT HpackDecoderTables { + public: + HpackDecoderTables(); + ~HpackDecoderTables(); + + HpackDecoderTables(const HpackDecoderTables&) = delete; + HpackDecoderTables& operator=(const HpackDecoderTables&) = delete; + + // Sets a new size limit, received from the peer; performs evictions if + // necessary to ensure that the current size does not exceed the new limit. + // The caller needs to have validated that size_limit does not + // exceed the acknowledged value of SETTINGS_HEADER_TABLE_SIZE. + void DynamicTableSizeUpdate(size_t size_limit) { + dynamic_table_.DynamicTableSizeUpdate(size_limit); + } + + // Insert entry if possible. + // If entry is too large to insert, then dynamic table will be empty. + void Insert(std::string name, std::string value) { + dynamic_table_.Insert(std::move(name), std::move(value)); + } + + // If index is valid, returns a pointer to the entry, otherwise returns + // nullptr. + const HpackStringPair* Lookup(size_t index) const; + + // The size limit that the peer (the HPACK encoder) has told the decoder it is + // currently operating with. Defaults to SETTINGS_HEADER_TABLE_SIZE, 4096. + size_t header_table_size_limit() const { return dynamic_table_.size_limit(); } + + // Sum of the sizes of the dynamic table entries. + size_t current_header_table_size() const { + return dynamic_table_.current_size(); + } + + private: + friend class test::HpackDecoderTablesPeer; + HpackDecoderStaticTable static_table_; + HpackDecoderDynamicTable dynamic_table_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODER_TABLES_H_ diff --git a/quiche/http2/hpack/decoder/hpack_decoder_tables_test.cc b/quiche/http2/hpack/decoder/hpack_decoder_tables_test.cc new file mode 100644 index 000000000000..10917da74513 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_tables_test.cc @@ -0,0 +1,257 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" + +#include +#include +#include +#include + +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/random_util.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { +class HpackDecoderTablesPeer { + public: + static size_t num_dynamic_entries(const HpackDecoderTables& tables) { + return tables.dynamic_table_.table_.size(); + } +}; + +namespace { +struct StaticEntry { + const char* name; + const char* value; + size_t index; +}; + +std::vector MakeSpecStaticEntries() { + std::vector static_entries; + +#define STATIC_TABLE_ENTRY(name, value, index) \ + QUICHE_DCHECK_EQ(static_entries.size() + 1, static_cast(index)); \ + static_entries.push_back({name, value, index}); + +#include "quiche/http2/hpack/hpack_static_table_entries.inc" + +#undef STATIC_TABLE_ENTRY + + return static_entries; +} + +template +void ShuffleCollection(C* collection, Http2Random* r) { + std::shuffle(collection->begin(), collection->end(), *r); +} + +class HpackDecoderStaticTableTest : public quiche::test::QuicheTest { + protected: + HpackDecoderStaticTableTest() = default; + + std::vector shuffled_static_entries() { + std::vector entries = MakeSpecStaticEntries(); + ShuffleCollection(&entries, &random_); + return entries; + } + + // This test is in a function so that it can be applied to both the static + // table and the combined static+dynamic tables. + AssertionResult VerifyStaticTableContents() { + for (const auto& expected : shuffled_static_entries()) { + const HpackStringPair* found = Lookup(expected.index); + HTTP2_VERIFY_NE(found, nullptr); + HTTP2_VERIFY_EQ(expected.name, found->name) << expected.index; + HTTP2_VERIFY_EQ(expected.value, found->value) << expected.index; + } + + // There should be no entry with index 0. + HTTP2_VERIFY_EQ(nullptr, Lookup(0)); + return AssertionSuccess(); + } + + virtual const HpackStringPair* Lookup(size_t index) { + return static_table_.Lookup(index); + } + + Http2Random* RandomPtr() { return &random_; } + + Http2Random random_; + + private: + HpackDecoderStaticTable static_table_; +}; + +TEST_F(HpackDecoderStaticTableTest, StaticTableContents) { + EXPECT_TRUE(VerifyStaticTableContents()); +} + +size_t Size(const std::string& name, const std::string& value) { + return name.size() + value.size() + 32; +} + +// To support tests with more than a few of hand crafted changes to the dynamic +// table, we have another, exceedingly simple, implementation of the HPACK +// dynamic table containing FakeHpackEntry instances. We can thus compare the +// contents of the actual table with those in fake_dynamic_table_. + +typedef std::tuple FakeHpackEntry; +const std::string& Name(const FakeHpackEntry& entry) { + return std::get<0>(entry); +} +const std::string& Value(const FakeHpackEntry& entry) { + return std::get<1>(entry); +} +size_t Size(const FakeHpackEntry& entry) { return std::get<2>(entry); } + +class HpackDecoderTablesTest : public HpackDecoderStaticTableTest { + protected: + const HpackStringPair* Lookup(size_t index) override { + return tables_.Lookup(index); + } + + size_t dynamic_size_limit() const { + return tables_.header_table_size_limit(); + } + size_t current_dynamic_size() const { + return tables_.current_header_table_size(); + } + size_t num_dynamic_entries() const { + return HpackDecoderTablesPeer::num_dynamic_entries(tables_); + } + + // Insert the name and value into fake_dynamic_table_. + void FakeInsert(const std::string& name, const std::string& value) { + FakeHpackEntry entry(name, value, Size(name, value)); + fake_dynamic_table_.insert(fake_dynamic_table_.begin(), entry); + } + + // Add up the size of all entries in fake_dynamic_table_. + size_t FakeSize() { + size_t sz = 0; + for (const auto& entry : fake_dynamic_table_) { + sz += Size(entry); + } + return sz; + } + + // If the total size of the fake_dynamic_table_ is greater than limit, + // keep the first N entries such that those N entries have a size not + // greater than limit, and such that keeping entry N+1 would have a size + // greater than limit. Returns the count of removed bytes. + size_t FakeTrim(size_t limit) { + size_t original_size = FakeSize(); + size_t total_size = 0; + for (size_t ndx = 0; ndx < fake_dynamic_table_.size(); ++ndx) { + total_size += Size(fake_dynamic_table_[ndx]); + if (total_size > limit) { + // Need to get rid of ndx and all following entries. + fake_dynamic_table_.erase(fake_dynamic_table_.begin() + ndx, + fake_dynamic_table_.end()); + return original_size - FakeSize(); + } + } + return 0; + } + + // Verify that the contents of the actual dynamic table match those in + // fake_dynamic_table_. + AssertionResult VerifyDynamicTableContents() { + HTTP2_VERIFY_EQ(current_dynamic_size(), FakeSize()); + HTTP2_VERIFY_EQ(num_dynamic_entries(), fake_dynamic_table_.size()); + + for (size_t ndx = 0; ndx < fake_dynamic_table_.size(); ++ndx) { + const HpackStringPair* found = Lookup(ndx + kFirstDynamicTableIndex); + HTTP2_VERIFY_NE(found, nullptr); + + const auto& expected = fake_dynamic_table_[ndx]; + HTTP2_VERIFY_EQ(Name(expected), found->name); + HTTP2_VERIFY_EQ(Value(expected), found->value); + } + + // Make sure there are no more entries. + HTTP2_VERIFY_EQ( + nullptr, Lookup(fake_dynamic_table_.size() + kFirstDynamicTableIndex)); + return AssertionSuccess(); + } + + // Apply an update to the limit on the maximum size of the dynamic table. + AssertionResult DynamicTableSizeUpdate(size_t size_limit) { + HTTP2_VERIFY_EQ(current_dynamic_size(), FakeSize()); + if (size_limit < current_dynamic_size()) { + // Will need to trim the dynamic table's oldest entries. + tables_.DynamicTableSizeUpdate(size_limit); + FakeTrim(size_limit); + return VerifyDynamicTableContents(); + } + // Shouldn't change the size. + tables_.DynamicTableSizeUpdate(size_limit); + return VerifyDynamicTableContents(); + } + + // Insert an entry into the dynamic table, confirming that trimming of entries + // occurs if the total size is greater than the limit, and that older entries + // move up by 1 index. + AssertionResult Insert(const std::string& name, const std::string& value) { + size_t old_count = num_dynamic_entries(); + tables_.Insert(name, value); + FakeInsert(name, value); + HTTP2_VERIFY_EQ(old_count + 1, fake_dynamic_table_.size()); + FakeTrim(dynamic_size_limit()); + HTTP2_VERIFY_EQ(current_dynamic_size(), FakeSize()); + HTTP2_VERIFY_EQ(num_dynamic_entries(), fake_dynamic_table_.size()); + return VerifyDynamicTableContents(); + } + + private: + HpackDecoderTables tables_; + + std::vector fake_dynamic_table_; +}; + +TEST_F(HpackDecoderTablesTest, StaticTableContents) { + EXPECT_TRUE(VerifyStaticTableContents()); +} + +// Generate a bunch of random header entries, insert them, and confirm they +// present, as required by the RFC, using VerifyDynamicTableContents above on +// each Insert. Also apply various resizings of the dynamic table. +TEST_F(HpackDecoderTablesTest, RandomDynamicTable) { + EXPECT_EQ(0u, current_dynamic_size()); + EXPECT_TRUE(VerifyStaticTableContents()); + EXPECT_TRUE(VerifyDynamicTableContents()); + + std::vector table_sizes; + table_sizes.push_back(dynamic_size_limit()); + table_sizes.push_back(0); + table_sizes.push_back(dynamic_size_limit() / 2); + table_sizes.push_back(dynamic_size_limit()); + table_sizes.push_back(dynamic_size_limit() / 2); + table_sizes.push_back(0); + table_sizes.push_back(dynamic_size_limit()); + + for (size_t limit : table_sizes) { + ASSERT_TRUE(DynamicTableSizeUpdate(limit)); + for (int insert_count = 0; insert_count < 100; ++insert_count) { + std::string name = + GenerateHttp2HeaderName(random_.UniformInRange(2, 40), RandomPtr()); + std::string value = + GenerateWebSafeString(random_.UniformInRange(2, 600), RandomPtr()); + ASSERT_TRUE(Insert(name, value)); + } + EXPECT_TRUE(VerifyStaticTableContents()); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoder_test.cc b/quiche/http2/hpack/decoder/hpack_decoder_test.cc new file mode 100644 index 000000000000..0840203d66dc --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoder_test.cc @@ -0,0 +1,1187 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoder.h" + +// Tests of HpackDecoder. + +#include +#include +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_listener.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_state.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/hpack_example.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/random_util.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; +using ::testing::ElementsAreArray; +using ::testing::Eq; + +namespace http2 { +namespace test { +class HpackDecoderStatePeer { + public: + static HpackDecoderTables* GetDecoderTables(HpackDecoderState* state) { + return &state->decoder_tables_; + } + static void set_listener(HpackDecoderState* state, + HpackDecoderListener* listener) { + state->listener_ = listener; + } +}; +class HpackDecoderPeer { + public: + static HpackDecoderState* GetDecoderState(HpackDecoder* decoder) { + return &decoder->decoder_state_; + } + static HpackDecoderTables* GetDecoderTables(HpackDecoder* decoder) { + return HpackDecoderStatePeer::GetDecoderTables(GetDecoderState(decoder)); + } +}; + +namespace { + +typedef std::pair HpackHeaderEntry; +typedef std::vector HpackHeaderEntries; + +// TODO(jamessynge): Create a ...test_utils.h file with the mock listener +// and with VerifyDynamicTableContents. +class MockHpackDecoderListener : public HpackDecoderListener { + public: + MOCK_METHOD(void, OnHeaderListStart, (), (override)); + MOCK_METHOD(void, OnHeader, + (const std::string& name, const std::string& value), (override)); + MOCK_METHOD(void, OnHeaderListEnd, (), (override)); + MOCK_METHOD(void, OnHeaderErrorDetected, (absl::string_view error_message), + (override)); +}; + +class HpackDecoderTest : public quiche::test::QuicheTestWithParam, + public HpackDecoderListener { + protected: + // Note that we initialize the random number generator with the same seed + // for each individual test, therefore the order in which the tests are + // executed does not effect the sequence produced by the RNG within any + // one test. + HpackDecoderTest() : decoder_(this, 4096) { + fragment_the_hpack_block_ = GetParam(); + } + ~HpackDecoderTest() override = default; + + void OnHeaderListStart() override { + ASSERT_FALSE(saw_start_); + ASSERT_FALSE(saw_end_); + saw_start_ = true; + header_entries_.clear(); + } + + // Called for each header name-value pair that is decoded, in the order they + // appear in the HPACK block. Multiple values for a given key will be emitted + // as multiple calls to OnHeader. + void OnHeader(const std::string& name, const std::string& value) override { + ASSERT_TRUE(saw_start_); + ASSERT_FALSE(saw_end_); + header_entries_.emplace_back(name, value); + } + + // OnHeaderBlockEnd is called after successfully decoding an HPACK block. Will + // only be called once per block, even if it extends into CONTINUATION frames. + // A callback method which notifies when the parser finishes handling a + // header block (i.e. the containing frame has the END_STREAM flag set). + // Also indicates the total number of bytes in this block. + void OnHeaderListEnd() override { + ASSERT_TRUE(saw_start_); + ASSERT_FALSE(saw_end_); + ASSERT_TRUE(error_messages_.empty()); + saw_end_ = true; + } + + // OnHeaderErrorDetected is called if an error is detected while decoding. + // error_message may be used in a GOAWAY frame as the Opaque Data. + void OnHeaderErrorDetected(absl::string_view error_message) override { + ASSERT_TRUE(saw_start_); + error_messages_.push_back(std::string(error_message)); + // No further callbacks should be made at this point, so replace 'this' as + // the listener with mock_listener_, which is a strict mock, so will + // generate an error for any calls. + HpackDecoderStatePeer::set_listener( + HpackDecoderPeer::GetDecoderState(&decoder_), &mock_listener_); + } + + AssertionResult DecodeBlock(absl::string_view block) { + QUICHE_VLOG(1) << "HpackDecoderTest::DecodeBlock"; + + HTTP2_VERIFY_FALSE(decoder_.DetectError()); + HTTP2_VERIFY_TRUE(error_messages_.empty()); + HTTP2_VERIFY_FALSE(saw_start_); + HTTP2_VERIFY_FALSE(saw_end_); + header_entries_.clear(); + + HTTP2_VERIFY_FALSE(decoder_.DetectError()); + HTTP2_VERIFY_TRUE(decoder_.StartDecodingBlock()); + HTTP2_VERIFY_FALSE(decoder_.DetectError()); + + if (fragment_the_hpack_block_) { + // See note in ctor regarding RNG. + while (!block.empty()) { + size_t fragment_size = random_.RandomSizeSkewedLow(block.size()); + DecodeBuffer db(block.substr(0, fragment_size)); + HTTP2_VERIFY_TRUE(decoder_.DecodeFragment(&db)); + HTTP2_VERIFY_EQ(0u, db.Remaining()); + block.remove_prefix(fragment_size); + } + } else { + DecodeBuffer db(block); + HTTP2_VERIFY_TRUE(decoder_.DecodeFragment(&db)); + HTTP2_VERIFY_EQ(0u, db.Remaining()); + } + HTTP2_VERIFY_FALSE(decoder_.DetectError()); + + HTTP2_VERIFY_TRUE(decoder_.EndDecodingBlock()); + if (saw_end_) { + HTTP2_VERIFY_FALSE(decoder_.DetectError()); + HTTP2_VERIFY_TRUE(error_messages_.empty()); + } else { + HTTP2_VERIFY_TRUE(decoder_.DetectError()); + HTTP2_VERIFY_FALSE(error_messages_.empty()); + } + + saw_start_ = saw_end_ = false; + return AssertionSuccess(); + } + + const HpackDecoderTables& GetDecoderTables() { + return *HpackDecoderPeer::GetDecoderTables(&decoder_); + } + const HpackStringPair* Lookup(size_t index) { + return GetDecoderTables().Lookup(index); + } + size_t current_header_table_size() { + return GetDecoderTables().current_header_table_size(); + } + size_t header_table_size_limit() { + return GetDecoderTables().header_table_size_limit(); + } + void set_header_table_size_limit(size_t size) { + HpackDecoderPeer::GetDecoderTables(&decoder_)->DynamicTableSizeUpdate(size); + } + + // dynamic_index is one-based, because that is the way RFC 7541 shows it. + AssertionResult VerifyEntry(size_t dynamic_index, const char* name, + const char* value) { + const HpackStringPair* entry = + Lookup(dynamic_index + kFirstDynamicTableIndex - 1); + HTTP2_VERIFY_NE(entry, nullptr); + HTTP2_VERIFY_EQ(entry->name, name); + HTTP2_VERIFY_EQ(entry->value, value); + return AssertionSuccess(); + } + AssertionResult VerifyNoEntry(size_t dynamic_index) { + const HpackStringPair* entry = + Lookup(dynamic_index + kFirstDynamicTableIndex - 1); + HTTP2_VERIFY_EQ(entry, nullptr); + return AssertionSuccess(); + } + AssertionResult VerifyDynamicTableContents( + const std::vector>& entries) { + size_t index = 1; + for (const auto& entry : entries) { + HTTP2_VERIFY_SUCCESS(VerifyEntry(index, entry.first, entry.second)); + ++index; + } + HTTP2_VERIFY_SUCCESS(VerifyNoEntry(index)); + return AssertionSuccess(); + } + + Http2Random random_; + HpackDecoder decoder_; + testing::StrictMock mock_listener_; + HpackHeaderEntries header_entries_; + std::vector error_messages_; + bool fragment_the_hpack_block_; + bool saw_start_ = false; + bool saw_end_ = false; +}; +INSTANTIATE_TEST_SUITE_P(AllWays, HpackDecoderTest, ::testing::Bool()); + +// Test based on RFC 7541, section C.3: Request Examples without Huffman Coding. +// This section shows several consecutive header lists, corresponding to HTTP +// requests, on the same connection. +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.3 +TEST_P(HpackDecoderTest, C3_RequestExamples) { + // C.3.1 First Request + std::string hpack_block = HpackExampleToStringOrDie(R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 86 | == Indexed - Add == + | idx = 6 + | -> :scheme: http + 84 | == Indexed - Add == + | idx = 4 + | -> :path: / + 41 | == Literal indexed == + | Indexed name (idx = 1) + | :authority + 0f | Literal value (len = 15) + 7777 772e 6578 616d 706c 652e 636f 6d | www.example.com + | -> :authority: + | www.example.com + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":method", "GET"}, + HpackHeaderEntry{":scheme", "http"}, + HpackHeaderEntry{":path", "/"}, + HpackHeaderEntry{":authority", "www.example.com"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 57) :authority: www.example.com + // Table size: 57 + ASSERT_TRUE(VerifyDynamicTableContents({{":authority", "www.example.com"}})); + ASSERT_EQ(57u, current_header_table_size()); + + // C.3.2 Second Request + hpack_block = HpackExampleToStringOrDie(R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 86 | == Indexed - Add == + | idx = 6 + | -> :scheme: http + 84 | == Indexed - Add == + | idx = 4 + | -> :path: / + be | == Indexed - Add == + | idx = 62 + | -> :authority: + | www.example.com + 58 | == Literal indexed == + | Indexed name (idx = 24) + | cache-control + 08 | Literal value (len = 8) + 6e6f 2d63 6163 6865 | no-cache + | -> cache-control: no-cache + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":method", "GET"}, + HpackHeaderEntry{":scheme", "http"}, + HpackHeaderEntry{":path", "/"}, + HpackHeaderEntry{":authority", "www.example.com"}, + HpackHeaderEntry{"cache-control", "no-cache"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 53) cache-control: no-cache + // [ 2] (s = 57) :authority: www.example.com + // Table size: 110 + ASSERT_TRUE(VerifyDynamicTableContents( + {{"cache-control", "no-cache"}, {":authority", "www.example.com"}})); + ASSERT_EQ(110u, current_header_table_size()); + + // C.3.2 Third Request + hpack_block = HpackExampleToStringOrDie(R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 87 | == Indexed - Add == + | idx = 7 + | -> :scheme: https + 85 | == Indexed - Add == + | idx = 5 + | -> :path: /index.html + bf | == Indexed - Add == + | idx = 63 + | -> :authority: + | www.example.com + 40 | == Literal indexed == + 0a | Literal name (len = 10) + 6375 7374 6f6d 2d6b 6579 | custom-key + 0c | Literal value (len = 12) + 6375 7374 6f6d 2d76 616c 7565 | custom-value + | -> custom-key: + | custom-value + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":method", "GET"}, + HpackHeaderEntry{":scheme", "https"}, + HpackHeaderEntry{":path", "/index.html"}, + HpackHeaderEntry{":authority", "www.example.com"}, + HpackHeaderEntry{"custom-key", "custom-value"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 54) custom-key: custom-value + // [ 2] (s = 53) cache-control: no-cache + // [ 3] (s = 57) :authority: www.example.com + // Table size: 164 + ASSERT_TRUE(VerifyDynamicTableContents({{"custom-key", "custom-value"}, + {"cache-control", "no-cache"}, + {":authority", "www.example.com"}})); + ASSERT_EQ(164u, current_header_table_size()); +} + +// Test based on RFC 7541, section C.4 Request Examples with Huffman Coding. +// This section shows the same examples as the previous section but uses +// Huffman encoding for the literal values. +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.4 +TEST_P(HpackDecoderTest, C4_RequestExamplesWithHuffmanEncoding) { + // C.4.1 First Request + std::string hpack_block = HpackExampleToStringOrDie(R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 86 | == Indexed - Add == + | idx = 6 + | -> :scheme: http + 84 | == Indexed - Add == + | idx = 4 + | -> :path: / + 41 | == Literal indexed == + | Indexed name (idx = 1) + | :authority + 8c | Literal value (len = 12) + | Huffman encoded: + f1e3 c2e5 f23a 6ba0 ab90 f4ff | .....:k..... + | Decoded: + | www.example.com + | -> :authority: + | www.example.com + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":method", "GET"}, + HpackHeaderEntry{":scheme", "http"}, + HpackHeaderEntry{":path", "/"}, + HpackHeaderEntry{":authority", "www.example.com"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 57) :authority: www.example.com + // Table size: 57 + ASSERT_TRUE(VerifyDynamicTableContents({{":authority", "www.example.com"}})); + ASSERT_EQ(57u, current_header_table_size()); + + // C.4.2 Second Request + hpack_block = HpackExampleToStringOrDie(R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 86 | == Indexed - Add == + | idx = 6 + | -> :scheme: http + 84 | == Indexed - Add == + | idx = 4 + | -> :path: / + be | == Indexed - Add == + | idx = 62 + | -> :authority: + | www.example.com + 58 | == Literal indexed == + | Indexed name (idx = 24) + | cache-control + 86 | Literal value (len = 6) + | Huffman encoded: + a8eb 1064 9cbf | ...d.. + | Decoded: + | no-cache + | -> cache-control: no-cache + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":method", "GET"}, + HpackHeaderEntry{":scheme", "http"}, + HpackHeaderEntry{":path", "/"}, + HpackHeaderEntry{":authority", "www.example.com"}, + HpackHeaderEntry{"cache-control", "no-cache"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 53) cache-control: no-cache + // [ 2] (s = 57) :authority: www.example.com + // Table size: 110 + ASSERT_TRUE(VerifyDynamicTableContents( + {{"cache-control", "no-cache"}, {":authority", "www.example.com"}})); + ASSERT_EQ(110u, current_header_table_size()); + + // C.4.2 Third Request + hpack_block = HpackExampleToStringOrDie(R"( + 82 | == Indexed - Add == + | idx = 2 + | -> :method: GET + 87 | == Indexed - Add == + | idx = 7 + | -> :scheme: https + 85 | == Indexed - Add == + | idx = 5 + | -> :path: /index.html + bf | == Indexed - Add == + | idx = 63 + | -> :authority: + | www.example.com + 40 | == Literal indexed == + 88 | Literal name (len = 8) + | Huffman encoded: + 25a8 49e9 5ba9 7d7f | %.I.[.}. + | Decoded: + | custom-key + 89 | Literal value (len = 9) + | Huffman encoded: + 25a8 49e9 5bb8 e8b4 bf | %.I.[.... + | Decoded: + | custom-value + | -> custom-key: + | custom-value + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":method", "GET"}, + HpackHeaderEntry{":scheme", "https"}, + HpackHeaderEntry{":path", "/index.html"}, + HpackHeaderEntry{":authority", "www.example.com"}, + HpackHeaderEntry{"custom-key", "custom-value"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 54) custom-key: custom-value + // [ 2] (s = 53) cache-control: no-cache + // [ 3] (s = 57) :authority: www.example.com + // Table size: 164 + ASSERT_TRUE(VerifyDynamicTableContents({{"custom-key", "custom-value"}, + {"cache-control", "no-cache"}, + {":authority", "www.example.com"}})); + ASSERT_EQ(164u, current_header_table_size()); +} + +// Test based on RFC 7541, section C.5: Response Examples without Huffman +// Coding. This section shows several consecutive header lists, corresponding +// to HTTP responses, on the same connection. The HTTP/2 setting parameter +// SETTINGS_HEADER_TABLE_SIZE is set to the value of 256 octets, causing +// some evictions to occur. +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.5 +TEST_P(HpackDecoderTest, C5_ResponseExamples) { + set_header_table_size_limit(256); + + // C.5.1 First Response + // + // Header list to encode: + // + // :status: 302 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:21 GMT + // location: https://www.example.com + + std::string hpack_block = HpackExampleToStringOrDie(R"( + 48 | == Literal indexed == + | Indexed name (idx = 8) + | :status + 03 | Literal value (len = 3) + 3330 32 | 302 + | -> :status: 302 + 58 | == Literal indexed == + | Indexed name (idx = 24) + | cache-control + 07 | Literal value (len = 7) + 7072 6976 6174 65 | private + | -> cache-control: private + 61 | == Literal indexed == + | Indexed name (idx = 33) + | date + 1d | Literal value (len = 29) + 4d6f 6e2c 2032 3120 4f63 7420 3230 3133 | Mon, 21 Oct 2013 + 2032 303a 3133 3a32 3120 474d 54 | 20:13:21 GMT + | -> date: Mon, 21 Oct 2013 + | 20:13:21 GMT + 6e | == Literal indexed == + | Indexed name (idx = 46) + | location + 17 | Literal value (len = 23) + 6874 7470 733a 2f2f 7777 772e 6578 616d | https://www.exam + 706c 652e 636f 6d | ple.com + | -> location: + | https://www.example.com + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":status", "302"}, + HpackHeaderEntry{"cache-control", "private"}, + HpackHeaderEntry{"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + HpackHeaderEntry{"location", "https://www.example.com"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 63) location: https://www.example.com + // [ 2] (s = 65) date: Mon, 21 Oct 2013 20:13:21 GMT + // [ 3] (s = 52) cache-control: private + // [ 4] (s = 42) :status: 302 + // Table size: 222 + ASSERT_TRUE( + VerifyDynamicTableContents({{"location", "https://www.example.com"}, + {"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"cache-control", "private"}, + {":status", "302"}})); + ASSERT_EQ(222u, current_header_table_size()); + + // C.5.2 Second Response + // + // The (":status", "302") header field is evicted from the dynamic table to + // free space to allow adding the (":status", "307") header field. + // + // Header list to encode: + // + // :status: 307 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:21 GMT + // location: https://www.example.com + + hpack_block = HpackExampleToStringOrDie(R"( + 48 | == Literal indexed == + | Indexed name (idx = 8) + | :status + 03 | Literal value (len = 3) + 3330 37 | 307 + | - evict: :status: 302 + | -> :status: 307 + c1 | == Indexed - Add == + | idx = 65 + | -> cache-control: private + c0 | == Indexed - Add == + | idx = 64 + | -> date: Mon, 21 Oct 2013 + | 20:13:21 GMT + bf | == Indexed - Add == + | idx = 63 + | -> location: + | https://www.example.com + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":status", "307"}, + HpackHeaderEntry{"cache-control", "private"}, + HpackHeaderEntry{"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + HpackHeaderEntry{"location", "https://www.example.com"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 42) :status: 307 + // [ 2] (s = 63) location: https://www.example.com + // [ 3] (s = 65) date: Mon, 21 Oct 2013 20:13:21 GMT + // [ 4] (s = 52) cache-control: private + // Table size: 222 + + ASSERT_TRUE( + VerifyDynamicTableContents({{":status", "307"}, + {"location", "https://www.example.com"}, + {"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"cache-control", "private"}})); + ASSERT_EQ(222u, current_header_table_size()); + + // C.5.3 Third Response + // + // Several header fields are evicted from the dynamic table during the + // processing of this header list. + // + // Header list to encode: + // + // :status: 200 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:22 GMT + // location: https://www.example.com + // content-encoding: gzip + // set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1 + hpack_block = HpackExampleToStringOrDie(R"( + 88 | == Indexed - Add == + | idx = 8 + | -> :status: 200 + c1 | == Indexed - Add == + | idx = 65 + | -> cache-control: private + 61 | == Literal indexed == + | Indexed name (idx = 33) + | date + 1d | Literal value (len = 29) + 4d6f 6e2c 2032 3120 4f63 7420 3230 3133 | Mon, 21 Oct 2013 + 2032 303a 3133 3a32 3220 474d 54 | 20:13:22 GMT + | - evict: cache-control: + | private + | -> date: Mon, 21 Oct 2013 + | 20:13:22 GMT + c0 | == Indexed - Add == + | idx = 64 + | -> location: + | https://www.example.com + 5a | == Literal indexed == + | Indexed name (idx = 26) + | content-encoding + 04 | Literal value (len = 4) + 677a 6970 | gzip + | - evict: date: Mon, 21 Oct + | 2013 20:13:21 GMT + | -> content-encoding: gzip + 77 | == Literal indexed == + | Indexed name (idx = 55) + | set-cookie + 38 | Literal value (len = 56) + 666f 6f3d 4153 444a 4b48 514b 425a 584f | foo=ASDJKHQKBZXO + 5157 454f 5049 5541 5851 5745 4f49 553b | QWEOPIUAXQWEOIU; + 206d 6178 2d61 6765 3d33 3630 303b 2076 | max-age=3600; v + 6572 7369 6f6e 3d31 | ersion=1 + | - evict: location: + | https://www.example.com + | - evict: :status: 307 + | -> set-cookie: foo=ASDJKHQ + | KBZXOQWEOPIUAXQWEOIU; ma + | x-age=3600; version=1 + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT( + header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":status", "200"}, + HpackHeaderEntry{"cache-control", "private"}, + HpackHeaderEntry{"date", "Mon, 21 Oct 2013 20:13:22 GMT"}, + HpackHeaderEntry{"location", "https://www.example.com"}, + HpackHeaderEntry{"content-encoding", "gzip"}, + HpackHeaderEntry{ + "set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 98) set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; + // max-age=3600; version=1 + // [ 2] (s = 52) content-encoding: gzip + // [ 3] (s = 65) date: Mon, 21 Oct 2013 20:13:22 GMT + // Table size: 215 + ASSERT_TRUE(VerifyDynamicTableContents( + {{"set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"}, + {"content-encoding", "gzip"}, + {"date", "Mon, 21 Oct 2013 20:13:22 GMT"}})); + ASSERT_EQ(215u, current_header_table_size()); +} + +// Test based on RFC 7541, section C.6: Response Examples with Huffman Coding. +// This section shows the same examples as the previous section but uses Huffman +// encoding for the literal values. The HTTP/2 setting parameter +// SETTINGS_HEADER_TABLE_SIZE is set to the value of 256 octets, causing some +// evictions to occur. The eviction mechanism uses the length of the decoded +// literal values, so the same evictions occur as in the previous section. +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.6 +TEST_P(HpackDecoderTest, C6_ResponseExamplesWithHuffmanEncoding) { + set_header_table_size_limit(256); + + // C.5.1 First Response + // + // Header list to encode: + // + // :status: 302 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:21 GMT + // location: https://www.example.com + std::string hpack_block = HpackExampleToStringOrDie(R"( + 48 | == Literal indexed == + | Indexed name (idx = 8) + | :status + 03 | Literal value (len = 3) + 3330 32 | 302 + | -> :status: 302 + 58 | == Literal indexed == + | Indexed name (idx = 24) + | cache-control + 07 | Literal value (len = 7) + 7072 6976 6174 65 | private + | -> cache-control: private + 61 | == Literal indexed == + | Indexed name (idx = 33) + | date + 1d | Literal value (len = 29) + 4d6f 6e2c 2032 3120 4f63 7420 3230 3133 | Mon, 21 Oct 2013 + 2032 303a 3133 3a32 3120 474d 54 | 20:13:21 GMT + | -> date: Mon, 21 Oct 2013 + | 20:13:21 GMT + 6e | == Literal indexed == + | Indexed name (idx = 46) + | location + 17 | Literal value (len = 23) + 6874 7470 733a 2f2f 7777 772e 6578 616d | https://www.exam + 706c 652e 636f 6d | ple.com + | -> location: + | https://www.example.com + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":status", "302"}, + HpackHeaderEntry{"cache-control", "private"}, + HpackHeaderEntry{"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + HpackHeaderEntry{"location", "https://www.example.com"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 63) location: https://www.example.com + // [ 2] (s = 65) date: Mon, 21 Oct 2013 20:13:21 GMT + // [ 3] (s = 52) cache-control: private + // [ 4] (s = 42) :status: 302 + // Table size: 222 + ASSERT_TRUE( + VerifyDynamicTableContents({{"location", "https://www.example.com"}, + {"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"cache-control", "private"}, + {":status", "302"}})); + ASSERT_EQ(222u, current_header_table_size()); + + // C.5.2 Second Response + // + // The (":status", "302") header field is evicted from the dynamic table to + // free space to allow adding the (":status", "307") header field. + // + // Header list to encode: + // + // :status: 307 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:21 GMT + // location: https://www.example.com + hpack_block = HpackExampleToStringOrDie(R"( + 48 | == Literal indexed == + | Indexed name (idx = 8) + | :status + 03 | Literal value (len = 3) + 3330 37 | 307 + | - evict: :status: 302 + | -> :status: 307 + c1 | == Indexed - Add == + | idx = 65 + | -> cache-control: private + c0 | == Indexed - Add == + | idx = 64 + | -> date: Mon, 21 Oct 2013 + | 20:13:21 GMT + bf | == Indexed - Add == + | idx = 63 + | -> location: + | https://www.example.com + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT(header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":status", "307"}, + HpackHeaderEntry{"cache-control", "private"}, + HpackHeaderEntry{"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + HpackHeaderEntry{"location", "https://www.example.com"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 42) :status: 307 + // [ 2] (s = 63) location: https://www.example.com + // [ 3] (s = 65) date: Mon, 21 Oct 2013 20:13:21 GMT + // [ 4] (s = 52) cache-control: private + // Table size: 222 + ASSERT_TRUE( + VerifyDynamicTableContents({{":status", "307"}, + {"location", "https://www.example.com"}, + {"date", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"cache-control", "private"}})); + ASSERT_EQ(222u, current_header_table_size()); + + // C.5.3 Third Response + // + // Several header fields are evicted from the dynamic table during the + // processing of this header list. + // + // Header list to encode: + // + // :status: 200 + // cache-control: private + // date: Mon, 21 Oct 2013 20:13:22 GMT + // location: https://www.example.com + // content-encoding: gzip + // set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1 + hpack_block = HpackExampleToStringOrDie(R"( + 88 | == Indexed - Add == + | idx = 8 + | -> :status: 200 + c1 | == Indexed - Add == + | idx = 65 + | -> cache-control: private + 61 | == Literal indexed == + | Indexed name (idx = 33) + | date + 1d | Literal value (len = 29) + 4d6f 6e2c 2032 3120 4f63 7420 3230 3133 | Mon, 21 Oct 2013 + 2032 303a 3133 3a32 3220 474d 54 | 20:13:22 GMT + | - evict: cache-control: + | private + | -> date: Mon, 21 Oct 2013 + | 20:13:22 GMT + c0 | == Indexed - Add == + | idx = 64 + | -> location: + | https://www.example.com + 5a | == Literal indexed == + | Indexed name (idx = 26) + | content-encoding + 04 | Literal value (len = 4) + 677a 6970 | gzip + | - evict: date: Mon, 21 Oct + | 2013 20:13:21 GMT + | -> content-encoding: gzip + 77 | == Literal indexed == + | Indexed name (idx = 55) + | set-cookie + 38 | Literal value (len = 56) + 666f 6f3d 4153 444a 4b48 514b 425a 584f | foo=ASDJKHQKBZXO + 5157 454f 5049 5541 5851 5745 4f49 553b | QWEOPIUAXQWEOIU; + 206d 6178 2d61 6765 3d33 3630 303b 2076 | max-age=3600; v + 6572 7369 6f6e 3d31 | ersion=1 + | - evict: location: + | https://www.example.com + | - evict: :status: 307 + | -> set-cookie: foo=ASDJKHQ + | KBZXOQWEOPIUAXQWEOIU; ma + | x-age=3600; version=1 + )"); + EXPECT_TRUE(DecodeBlock(hpack_block)); + ASSERT_THAT( + header_entries_, + ElementsAreArray({ + HpackHeaderEntry{":status", "200"}, + HpackHeaderEntry{"cache-control", "private"}, + HpackHeaderEntry{"date", "Mon, 21 Oct 2013 20:13:22 GMT"}, + HpackHeaderEntry{"location", "https://www.example.com"}, + HpackHeaderEntry{"content-encoding", "gzip"}, + HpackHeaderEntry{ + "set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"}, + })); + + // Dynamic Table (after decoding): + // + // [ 1] (s = 98) set-cookie: foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; + // max-age=3600; version=1 + // [ 2] (s = 52) content-encoding: gzip + // [ 3] (s = 65) date: Mon, 21 Oct 2013 20:13:22 GMT + // Table size: 215 + ASSERT_TRUE(VerifyDynamicTableContents( + {{"set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"}, + {"content-encoding", "gzip"}, + {"date", "Mon, 21 Oct 2013 20:13:22 GMT"}})); + ASSERT_EQ(215u, current_header_table_size()); +} + +// Confirm that the table size can be changed, but at most twice. +TEST_P(HpackDecoderTest, ProcessesOptionalTableSizeUpdates) { + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); + // One update allowed. + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(3000); + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(3000u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + EXPECT_TRUE(header_entries_.empty()); + } + // Two updates allowed. + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(2000); + hbb.AppendDynamicTableSizeUpdate(2500); + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(2500u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + EXPECT_TRUE(header_entries_.empty()); + } + // A third update in the same HPACK block is rejected, so the final + // size is 1000, not 500. + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(1500); + hbb.AppendDynamicTableSizeUpdate(1000); + hbb.AppendDynamicTableSizeUpdate(500); + EXPECT_FALSE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(HpackDecodingError::kDynamicTableSizeUpdateNotAllowed, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], + Eq("Dynamic table size update not allowed")); + EXPECT_EQ(1000u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + EXPECT_TRUE(header_entries_.empty()); + } + // An error has been detected, so calls to HpackDecoder::DecodeFragment + // should return immediately. + DecodeBuffer db("\x80"); + EXPECT_FALSE(decoder_.DecodeFragment(&db)); + EXPECT_EQ(0u, db.Offset()); + EXPECT_EQ(1u, error_messages_.size()); +} + +// Confirm that the table size can be changed when required, but at most twice. +TEST_P(HpackDecoderTest, ProcessesRequiredTableSizeUpdate) { + EXPECT_EQ(4096u, decoder_.GetCurrentHeaderTableSizeSetting()); + // One update required, two allowed, one provided, followed by a header. + decoder_.ApplyHeaderTableSizeSetting(1024); + decoder_.ApplyHeaderTableSizeSetting(2048); + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); + EXPECT_EQ(2048u, decoder_.GetCurrentHeaderTableSizeSetting()); + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(1024); + hbb.AppendIndexedHeader(4); // :path: / + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_THAT(header_entries_, + ElementsAreArray({HpackHeaderEntry{":path", "/"}})); + EXPECT_EQ(1024u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + } + // One update required, two allowed, two provided, followed by a header. + decoder_.ApplyHeaderTableSizeSetting(1000); + decoder_.ApplyHeaderTableSizeSetting(1500); + EXPECT_EQ(1500u, decoder_.GetCurrentHeaderTableSizeSetting()); + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(500); + hbb.AppendDynamicTableSizeUpdate(1250); + hbb.AppendIndexedHeader(5); // :path: /index.html + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_THAT(header_entries_, + ElementsAreArray({HpackHeaderEntry{":path", "/index.html"}})); + EXPECT_EQ(1250u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + } + // One update required, two allowed, three provided, followed by a header. + // The third update is rejected, so the final size is 1000, not 500. + decoder_.ApplyHeaderTableSizeSetting(500); + decoder_.ApplyHeaderTableSizeSetting(1000); + EXPECT_EQ(1000u, decoder_.GetCurrentHeaderTableSizeSetting()); + { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(200); + hbb.AppendDynamicTableSizeUpdate(700); + hbb.AppendDynamicTableSizeUpdate(900); + hbb.AppendIndexedHeader(5); // Not decoded. + EXPECT_FALSE(DecodeBlock(hbb.buffer())); + EXPECT_FALSE(saw_end_); + EXPECT_EQ(HpackDecodingError::kDynamicTableSizeUpdateNotAllowed, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], + Eq("Dynamic table size update not allowed")); + EXPECT_EQ(700u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + EXPECT_TRUE(header_entries_.empty()); + } + EXPECT_EQ(1000u, decoder_.GetCurrentHeaderTableSizeSetting()); + // Now that an error has been detected, StartDecodingBlock should return + // false. + EXPECT_FALSE(decoder_.StartDecodingBlock()); +} + +// Confirm that required size updates are validated. +TEST_P(HpackDecoderTest, InvalidRequiredSizeUpdate) { + // Require a size update, but provide one that isn't small enough (must be + // zero or one, in this case). + decoder_.ApplyHeaderTableSizeSetting(1); + decoder_.ApplyHeaderTableSizeSetting(1024); + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(2); + EXPECT_TRUE(decoder_.StartDecodingBlock()); + DecodeBuffer db(hbb.buffer()); + EXPECT_FALSE(decoder_.DecodeFragment(&db)); + EXPECT_FALSE(saw_end_); + EXPECT_EQ( + HpackDecodingError::kInitialDynamicTableSizeUpdateIsAboveLowWaterMark, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], + Eq("Initial dynamic table size update is above low water mark")); + EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), + header_table_size_limit()); +} + +// Confirm that required size updates are indeed required before the end. +TEST_P(HpackDecoderTest, RequiredTableSizeChangeBeforeEnd) { + decoder_.ApplyHeaderTableSizeSetting(1024); + EXPECT_FALSE(DecodeBlock("")); + EXPECT_EQ(HpackDecodingError::kMissingDynamicTableSizeUpdate, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], Eq("Missing dynamic table size update")); + EXPECT_FALSE(saw_end_); +} + +// Confirm that required size updates are indeed required before an +// indexed header. +TEST_P(HpackDecoderTest, RequiredTableSizeChangeBeforeIndexedHeader) { + decoder_.ApplyHeaderTableSizeSetting(1024); + HpackBlockBuilder hbb; + hbb.AppendIndexedHeader(1); + EXPECT_FALSE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(HpackDecodingError::kMissingDynamicTableSizeUpdate, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], Eq("Missing dynamic table size update")); + EXPECT_FALSE(saw_end_); + EXPECT_TRUE(header_entries_.empty()); +} + +// Confirm that required size updates are indeed required before an indexed +// header name. +// TODO(jamessynge): Move some of these to hpack_decoder_state_test.cc. +TEST_P(HpackDecoderTest, RequiredTableSizeChangeBeforeIndexedHeaderName) { + decoder_.ApplyHeaderTableSizeSetting(1024); + HpackBlockBuilder hbb; + hbb.AppendNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, 2, + false, "PUT"); + EXPECT_FALSE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(HpackDecodingError::kMissingDynamicTableSizeUpdate, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], Eq("Missing dynamic table size update")); + EXPECT_FALSE(saw_end_); + EXPECT_TRUE(header_entries_.empty()); +} + +// Confirm that required size updates are indeed required before a literal +// header name. +TEST_P(HpackDecoderTest, RequiredTableSizeChangeBeforeLiteralName) { + decoder_.ApplyHeaderTableSizeSetting(1024); + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(HpackEntryType::kNeverIndexedLiteralHeader, + false, "name", false, "some data."); + EXPECT_FALSE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(HpackDecodingError::kMissingDynamicTableSizeUpdate, + decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], Eq("Missing dynamic table size update")); + EXPECT_FALSE(saw_end_); + EXPECT_TRUE(header_entries_.empty()); +} + +// Confirm that an excessively long varint is detected, in this case an +// index of 127, but with lots of additional high-order 0 bits provided, +// too many to be allowed. +TEST_P(HpackDecoderTest, InvalidIndexedHeaderVarint) { + EXPECT_TRUE(decoder_.StartDecodingBlock()); + DecodeBuffer db("\xff\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x00"); + EXPECT_FALSE(decoder_.DecodeFragment(&db)); + EXPECT_TRUE(decoder_.DetectError()); + EXPECT_FALSE(saw_end_); + EXPECT_EQ(HpackDecodingError::kIndexVarintError, decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], + Eq("Index varint beyond implementation limit")); + EXPECT_TRUE(header_entries_.empty()); + // Now that an error has been detected, EndDecodingBlock should not succeed. + EXPECT_FALSE(decoder_.EndDecodingBlock()); +} + +// Confirm that an invalid index into the tables is detected, in this case an +// index of 0. +TEST_P(HpackDecoderTest, InvalidIndex) { + EXPECT_TRUE(decoder_.StartDecodingBlock()); + DecodeBuffer db("\x80"); + EXPECT_FALSE(decoder_.DecodeFragment(&db)); + EXPECT_TRUE(decoder_.DetectError()); + EXPECT_FALSE(saw_end_); + EXPECT_EQ(HpackDecodingError::kInvalidIndex, decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], + Eq("Invalid index in indexed header field representation")); + EXPECT_TRUE(header_entries_.empty()); + // Now that an error has been detected, EndDecodingBlock should not succeed. + EXPECT_FALSE(decoder_.EndDecodingBlock()); +} + +// Confirm that EndDecodingBlock detects a truncated HPACK block. +TEST_P(HpackDecoderTest, TruncatedBlock) { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(3000); + EXPECT_EQ(3u, hbb.size()); + hbb.AppendDynamicTableSizeUpdate(4000); + EXPECT_EQ(6u, hbb.size()); + // Decodes this block if the whole thing is provided. + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(4000u, header_table_size_limit()); + // Multiple times even. + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_EQ(4000u, header_table_size_limit()); + // But not if the block is truncated. + EXPECT_FALSE(DecodeBlock(hbb.buffer().substr(0, hbb.size() - 1))); + EXPECT_FALSE(saw_end_); + EXPECT_EQ(HpackDecodingError::kTruncatedBlock, decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], + Eq("Block ends in the middle of an instruction")); + // The first update was decoded. + EXPECT_EQ(3000u, header_table_size_limit()); + EXPECT_EQ(0u, current_header_table_size()); + EXPECT_TRUE(header_entries_.empty()); +} + +// Confirm that an oversized string is detected, ending decoding. +TEST_P(HpackDecoderTest, OversizeStringDetected) { + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(HpackEntryType::kNeverIndexedLiteralHeader, + false, "name", false, "some data."); + hbb.AppendLiteralNameAndValue(HpackEntryType::kUnindexedLiteralHeader, false, + "name2", false, "longer data"); + + // Normally able to decode this block. + EXPECT_TRUE(DecodeBlock(hbb.buffer())); + EXPECT_THAT(header_entries_, + ElementsAreArray({HpackHeaderEntry{"name", "some data."}, + HpackHeaderEntry{"name2", "longer data"}})); + + // But not if the maximum size of strings is less than the longest string. + decoder_.set_max_string_size_bytes(10); + EXPECT_FALSE(DecodeBlock(hbb.buffer())); + EXPECT_THAT(header_entries_, + ElementsAreArray({HpackHeaderEntry{"name", "some data."}})); + EXPECT_FALSE(saw_end_); + EXPECT_EQ(HpackDecodingError::kValueTooLong, decoder_.error()); + EXPECT_EQ(1u, error_messages_.size()); + EXPECT_THAT(error_messages_[0], Eq("Value length exceeds buffer limit")); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoding_error.cc b/quiche/http2/hpack/decoder/hpack_decoding_error.cc new file mode 100644 index 000000000000..9adb2e29b9bc --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoding_error.cc @@ -0,0 +1,51 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" + +namespace http2 { + +// static +absl::string_view HpackDecodingErrorToString(HpackDecodingError error) { + switch (error) { + case HpackDecodingError::kOk: + return "No error detected"; + case HpackDecodingError::kIndexVarintError: + return "Index varint beyond implementation limit"; + case HpackDecodingError::kNameLengthVarintError: + return "Name length varint beyond implementation limit"; + case HpackDecodingError::kValueLengthVarintError: + return "Value length varint beyond implementation limit"; + case HpackDecodingError::kNameTooLong: + return "Name length exceeds buffer limit"; + case HpackDecodingError::kValueTooLong: + return "Value length exceeds buffer limit"; + case HpackDecodingError::kNameHuffmanError: + return "Name Huffman encoding error"; + case HpackDecodingError::kValueHuffmanError: + return "Value Huffman encoding error"; + case HpackDecodingError::kMissingDynamicTableSizeUpdate: + return "Missing dynamic table size update"; + case HpackDecodingError::kInvalidIndex: + return "Invalid index in indexed header field representation"; + case HpackDecodingError::kInvalidNameIndex: + return "Invalid index in literal header field with indexed name " + "representation"; + case HpackDecodingError::kDynamicTableSizeUpdateNotAllowed: + return "Dynamic table size update not allowed"; + case HpackDecodingError::kInitialDynamicTableSizeUpdateIsAboveLowWaterMark: + return "Initial dynamic table size update is above low water mark"; + case HpackDecodingError::kDynamicTableSizeUpdateIsAboveAcknowledgedSetting: + return "Dynamic table size update is above acknowledged setting"; + case HpackDecodingError::kTruncatedBlock: + return "Block ends in the middle of an instruction"; + case HpackDecodingError::kFragmentTooLong: + return "Incoming data fragment exceeds buffer limit"; + case HpackDecodingError::kCompressedHeaderSizeExceedsLimit: + return "Total compressed HPACK data size exceeds limit"; + } + return "invalid HpackDecodingError value"; +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_decoding_error.h b/quiche/http2/hpack/decoder/hpack_decoding_error.h new file mode 100644 index 000000000000..538237ff2fb8 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_decoding_error.h @@ -0,0 +1,51 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODING_ERROR_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODING_ERROR_H_ + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +enum class HpackDecodingError { + // No error detected so far. + kOk, + // Varint beyond implementation limit. + kIndexVarintError, + kNameLengthVarintError, + kValueLengthVarintError, + // String literal length exceeds buffer limit. + kNameTooLong, + kValueTooLong, + // Error in Huffman encoding. + kNameHuffmanError, + kValueHuffmanError, + // Next instruction should have been a dynamic table size update. + kMissingDynamicTableSizeUpdate, + // Invalid index in indexed header field representation. + kInvalidIndex, + // Invalid index in literal header field with indexed name representation. + kInvalidNameIndex, + // Dynamic table size update not allowed. + kDynamicTableSizeUpdateNotAllowed, + // Initial dynamic table size update is above low water mark. + kInitialDynamicTableSizeUpdateIsAboveLowWaterMark, + // Dynamic table size update is above acknowledged setting. + kDynamicTableSizeUpdateIsAboveAcknowledgedSetting, + // HPACK block ends in the middle of an instruction. + kTruncatedBlock, + // Incoming data fragment exceeds buffer limit. + kFragmentTooLong, + // Total compressed HPACK data size exceeds limit. + kCompressedHeaderSizeExceedsLimit, +}; + +QUICHE_EXPORT absl::string_view HpackDecodingErrorToString( + HpackDecodingError error); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_DECODING_ERROR_H_ diff --git a/quiche/http2/hpack/decoder/hpack_entry_collector_test.cc b/quiche/http2/hpack/decoder/hpack_entry_collector_test.cc new file mode 100644 index 000000000000..8512ce085427 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_collector_test.cc @@ -0,0 +1,155 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_entry_collector.h" + +// Tests of HpackEntryCollector. + +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::HasSubstr; + +namespace http2 { +namespace test { +namespace { + +TEST(HpackEntryCollectorTest, Clear) { + HpackEntryCollector collector; + QUICHE_VLOG(1) << collector; + EXPECT_THAT(collector.ToString(), HasSubstr("!started")); + EXPECT_TRUE(collector.IsClear()); + collector.set_header_type(HpackEntryType::kIndexedLiteralHeader); + EXPECT_FALSE(collector.IsClear()); + QUICHE_VLOG(1) << collector; + collector.Clear(); + EXPECT_TRUE(collector.IsClear()); + collector.set_index(123); + EXPECT_FALSE(collector.IsClear()); + QUICHE_VLOG(1) << collector; + collector.Clear(); + EXPECT_TRUE(collector.IsClear()); + collector.set_name(HpackStringCollector("name", true)); + EXPECT_FALSE(collector.IsClear()); + QUICHE_VLOG(1) << collector; + collector.Clear(); + EXPECT_TRUE(collector.IsClear()); + collector.set_value(HpackStringCollector("value", false)); + EXPECT_FALSE(collector.IsClear()); + QUICHE_VLOG(1) << collector; +} + +// EXPECT_FATAL_FAILURE can not access variables in the scope of a test body, +// including the this variable so can not access non-static members. So, we +// define this test outside of the test body. +void IndexedHeaderErrorTest() { + HpackEntryCollector collector; + collector.OnIndexedHeader(1); + // The next statement will fail because the collector + // has already been used. + collector.OnIndexedHeader(234); +} + +TEST(HpackEntryCollectorTest, IndexedHeader) { + HpackEntryCollector collector; + collector.OnIndexedHeader(123); + QUICHE_VLOG(1) << collector; + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsComplete()); + EXPECT_TRUE(collector.ValidateIndexedHeader(123)); + EXPECT_THAT(collector.ToString(), HasSubstr("IndexedHeader")); + EXPECT_THAT(collector.ToString(), HasSubstr("Complete")); + EXPECT_FATAL_FAILURE(IndexedHeaderErrorTest(), "Value of: started_"); +} + +void LiteralValueErrorTest() { + HpackEntryCollector collector; + collector.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 1); + // OnNameStart is not expected because an index was specified for the name. + collector.OnNameStart(false, 10); +} + +TEST(HpackEntryCollectorTest, LiteralValueHeader) { + HpackEntryCollector collector; + collector.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 4); + QUICHE_VLOG(1) << collector; + EXPECT_FALSE(collector.IsClear()); + EXPECT_FALSE(collector.IsComplete()); + EXPECT_THAT(collector.ToString(), HasSubstr("!ended")); + collector.OnValueStart(true, 5); + QUICHE_VLOG(1) << collector; + collector.OnValueData("value", 5); + collector.OnValueEnd(); + QUICHE_VLOG(1) << collector; + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsComplete()); + EXPECT_TRUE(collector.ValidateLiteralValueHeader( + HpackEntryType::kIndexedLiteralHeader, 4, true, "value")); + EXPECT_THAT(collector.ToString(), HasSubstr("IndexedLiteralHeader")); + EXPECT_THAT(collector.ToString(), HasSubstr("Complete")); + EXPECT_FATAL_FAILURE(LiteralValueErrorTest(), + "Value of: LiteralNameExpected"); +} + +void LiteralNameValueHeaderErrorTest() { + HpackEntryCollector collector; + collector.OnStartLiteralHeader(HpackEntryType::kNeverIndexedLiteralHeader, 0); + // OnValueStart is not expected until the name has ended. + collector.OnValueStart(false, 10); +} + +TEST(HpackEntryCollectorTest, LiteralNameValueHeader) { + HpackEntryCollector collector; + collector.OnStartLiteralHeader(HpackEntryType::kUnindexedLiteralHeader, 0); + QUICHE_VLOG(1) << collector; + EXPECT_FALSE(collector.IsClear()); + EXPECT_FALSE(collector.IsComplete()); + collector.OnNameStart(false, 4); + collector.OnNameData("na", 2); + QUICHE_VLOG(1) << collector; + collector.OnNameData("me", 2); + collector.OnNameEnd(); + collector.OnValueStart(true, 5); + QUICHE_VLOG(1) << collector; + collector.OnValueData("Value", 5); + collector.OnValueEnd(); + QUICHE_VLOG(1) << collector; + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsComplete()); + EXPECT_TRUE(collector.ValidateLiteralNameValueHeader( + HpackEntryType::kUnindexedLiteralHeader, false, "name", true, "Value")); + EXPECT_FATAL_FAILURE(LiteralNameValueHeaderErrorTest(), + "Value of: name_.HasEnded"); +} + +void DynamicTableSizeUpdateErrorTest() { + HpackEntryCollector collector; + collector.OnDynamicTableSizeUpdate(123); + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsComplete()); + EXPECT_TRUE(collector.ValidateDynamicTableSizeUpdate(123)); + // The next statement will fail because the collector + // has already been used. + collector.OnDynamicTableSizeUpdate(234); +} + +TEST(HpackEntryCollectorTest, DynamicTableSizeUpdate) { + HpackEntryCollector collector; + collector.OnDynamicTableSizeUpdate(8192); + QUICHE_VLOG(1) << collector; + EXPECT_FALSE(collector.IsClear()); + EXPECT_TRUE(collector.IsComplete()); + EXPECT_TRUE(collector.ValidateDynamicTableSizeUpdate(8192)); + EXPECT_EQ(collector, + HpackEntryCollector(HpackEntryType::kDynamicTableSizeUpdate, 8192)); + EXPECT_NE(collector, + HpackEntryCollector(HpackEntryType::kIndexedHeader, 8192)); + EXPECT_NE(collector, + HpackEntryCollector(HpackEntryType::kDynamicTableSizeUpdate, 8191)); + EXPECT_FATAL_FAILURE(DynamicTableSizeUpdateErrorTest(), "Value of: started_"); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_entry_decoder.cc b/quiche/http2/hpack/decoder/hpack_entry_decoder.cc new file mode 100644 index 000000000000..23ef25ab38df --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_decoder.cc @@ -0,0 +1,294 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_entry_decoder.h" + +#include + +#include + +#include "absl/base/macros.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace { +// Converts calls from HpackStringDecoder when decoding a header name into the +// appropriate HpackEntryDecoderListener::OnName* calls. +class NameDecoderListener { + public: + explicit NameDecoderListener(HpackEntryDecoderListener* listener) + : listener_(listener) {} + bool OnStringStart(bool huffman_encoded, size_t len) { + listener_->OnNameStart(huffman_encoded, len); + return true; + } + void OnStringData(const char* data, size_t len) { + listener_->OnNameData(data, len); + } + void OnStringEnd() { listener_->OnNameEnd(); } + + private: + HpackEntryDecoderListener* listener_; +}; + +// Converts calls from HpackStringDecoder when decoding a header value into +// the appropriate HpackEntryDecoderListener::OnValue* calls. +class ValueDecoderListener { + public: + explicit ValueDecoderListener(HpackEntryDecoderListener* listener) + : listener_(listener) {} + bool OnStringStart(bool huffman_encoded, size_t len) { + listener_->OnValueStart(huffman_encoded, len); + return true; + } + void OnStringData(const char* data, size_t len) { + listener_->OnValueData(data, len); + } + void OnStringEnd() { listener_->OnValueEnd(); } + + private: + HpackEntryDecoderListener* listener_; +}; +} // namespace + +DecodeStatus HpackEntryDecoder::Start(DecodeBuffer* db, + HpackEntryDecoderListener* listener) { + QUICHE_DCHECK(db != nullptr); + QUICHE_DCHECK(listener != nullptr); + QUICHE_DCHECK(db->HasData()); + DecodeStatus status = entry_type_decoder_.Start(db); + switch (status) { + case DecodeStatus::kDecodeDone: + // The type of the entry and its varint fit into the current decode + // buffer. + if (entry_type_decoder_.entry_type() == HpackEntryType::kIndexedHeader) { + // The entry consists solely of the entry type and varint. + // This is by far the most common case in practice. + listener->OnIndexedHeader(entry_type_decoder_.varint()); + return DecodeStatus::kDecodeDone; + } + state_ = EntryDecoderState::kDecodedType; + return Resume(db, listener); + case DecodeStatus::kDecodeInProgress: + // Hit the end of the decode buffer before fully decoding + // the entry type and varint. + QUICHE_DCHECK_EQ(0u, db->Remaining()); + state_ = EntryDecoderState::kResumeDecodingType; + return status; + case DecodeStatus::kDecodeError: + QUICHE_CODE_COUNT_N(decompress_failure_3, 11, 23); + error_ = HpackDecodingError::kIndexVarintError; + // The varint must have been invalid (too long). + return status; + } + + QUICHE_BUG(http2_bug_63_1) << "Unreachable"; + return DecodeStatus::kDecodeError; +} + +DecodeStatus HpackEntryDecoder::Resume(DecodeBuffer* db, + HpackEntryDecoderListener* listener) { + QUICHE_DCHECK(db != nullptr); + QUICHE_DCHECK(listener != nullptr); + + DecodeStatus status; + + do { + switch (state_) { + case EntryDecoderState::kResumeDecodingType: + // entry_type_decoder_ returned kDecodeInProgress when last called. + QUICHE_DVLOG(1) << "kResumeDecodingType: db->Remaining=" + << db->Remaining(); + status = entry_type_decoder_.Resume(db); + if (status == DecodeStatus::kDecodeError) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 12, 23); + error_ = HpackDecodingError::kIndexVarintError; + } + if (status != DecodeStatus::kDecodeDone) { + return status; + } + state_ = EntryDecoderState::kDecodedType; + ABSL_FALLTHROUGH_INTENDED; + + case EntryDecoderState::kDecodedType: + // entry_type_decoder_ returned kDecodeDone, now need to decide how + // to proceed. + QUICHE_DVLOG(1) << "kDecodedType: db->Remaining=" << db->Remaining(); + if (DispatchOnType(listener)) { + // All done. + return DecodeStatus::kDecodeDone; + } + continue; + + case EntryDecoderState::kStartDecodingName: + QUICHE_DVLOG(1) << "kStartDecodingName: db->Remaining=" + << db->Remaining(); + { + NameDecoderListener ncb(listener); + status = string_decoder_.Start(db, &ncb); + } + if (status != DecodeStatus::kDecodeDone) { + // On the assumption that the status is kDecodeInProgress, set + // state_ accordingly; unnecessary if status is kDecodeError, but + // that will only happen if the varint encoding the name's length + // is too long. + state_ = EntryDecoderState::kResumeDecodingName; + if (status == DecodeStatus::kDecodeError) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 13, 23); + error_ = HpackDecodingError::kNameLengthVarintError; + } + return status; + } + state_ = EntryDecoderState::kStartDecodingValue; + ABSL_FALLTHROUGH_INTENDED; + + case EntryDecoderState::kStartDecodingValue: + QUICHE_DVLOG(1) << "kStartDecodingValue: db->Remaining=" + << db->Remaining(); + { + ValueDecoderListener vcb(listener); + status = string_decoder_.Start(db, &vcb); + } + if (status == DecodeStatus::kDecodeError) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 14, 23); + error_ = HpackDecodingError::kValueLengthVarintError; + } + if (status == DecodeStatus::kDecodeDone) { + // Done with decoding the literal value, so we've reached the + // end of the header entry. + return status; + } + // On the assumption that the status is kDecodeInProgress, set + // state_ accordingly; unnecessary if status is kDecodeError, but + // that will only happen if the varint encoding the value's length + // is too long. + state_ = EntryDecoderState::kResumeDecodingValue; + return status; + + case EntryDecoderState::kResumeDecodingName: + // The literal name was split across decode buffers. + QUICHE_DVLOG(1) << "kResumeDecodingName: db->Remaining=" + << db->Remaining(); + { + NameDecoderListener ncb(listener); + status = string_decoder_.Resume(db, &ncb); + } + if (status != DecodeStatus::kDecodeDone) { + // On the assumption that the status is kDecodeInProgress, set + // state_ accordingly; unnecessary if status is kDecodeError, but + // that will only happen if the varint encoding the name's length + // is too long. + state_ = EntryDecoderState::kResumeDecodingName; + if (status == DecodeStatus::kDecodeError) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 15, 23); + error_ = HpackDecodingError::kNameLengthVarintError; + } + return status; + } + state_ = EntryDecoderState::kStartDecodingValue; + break; + + case EntryDecoderState::kResumeDecodingValue: + // The literal value was split across decode buffers. + QUICHE_DVLOG(1) << "kResumeDecodingValue: db->Remaining=" + << db->Remaining(); + { + ValueDecoderListener vcb(listener); + status = string_decoder_.Resume(db, &vcb); + } + if (status == DecodeStatus::kDecodeError) { + QUICHE_CODE_COUNT_N(decompress_failure_3, 16, 23); + error_ = HpackDecodingError::kValueLengthVarintError; + } + if (status == DecodeStatus::kDecodeDone) { + // Done with decoding the value, therefore the entry as a whole. + return status; + } + // On the assumption that the status is kDecodeInProgress, set + // state_ accordingly; unnecessary if status is kDecodeError, but + // that will only happen if the varint encoding the value's length + // is too long. + state_ = EntryDecoderState::kResumeDecodingValue; + return status; + } + } while (true); +} + +bool HpackEntryDecoder::DispatchOnType(HpackEntryDecoderListener* listener) { + const HpackEntryType entry_type = entry_type_decoder_.entry_type(); + const uint32_t varint = static_cast(entry_type_decoder_.varint()); + switch (entry_type) { + case HpackEntryType::kIndexedHeader: + // The entry consists solely of the entry type and varint. See: + // http://httpwg.org/specs/rfc7541.html#indexed.header.representation + listener->OnIndexedHeader(varint); + return true; + + case HpackEntryType::kIndexedLiteralHeader: + case HpackEntryType::kUnindexedLiteralHeader: + case HpackEntryType::kNeverIndexedLiteralHeader: + // The entry has a literal value, and if the varint is zero also has a + // literal name preceding the value. See: + // http://httpwg.org/specs/rfc7541.html#literal.header.representation + listener->OnStartLiteralHeader(entry_type, varint); + if (varint == 0) { + state_ = EntryDecoderState::kStartDecodingName; + } else { + state_ = EntryDecoderState::kStartDecodingValue; + } + return false; + + case HpackEntryType::kDynamicTableSizeUpdate: + // The entry consists solely of the entry type and varint. FWIW, I've + // never seen this type of entry in production (primarily browser + // traffic) so if you're designing an HPACK successor someday, consider + // dropping it or giving it a much longer prefix. See: + // http://httpwg.org/specs/rfc7541.html#encoding.context.update + listener->OnDynamicTableSizeUpdate(varint); + return true; + } + + QUICHE_BUG(http2_bug_63_2) << "Unreachable, entry_type=" << entry_type; + return true; +} + +void HpackEntryDecoder::OutputDebugString(std::ostream& out) const { + out << "HpackEntryDecoder(state=" << state_ << ", " << entry_type_decoder_ + << ", " << string_decoder_ << ")"; +} + +std::string HpackEntryDecoder::DebugString() const { + std::stringstream s; + s << *this; + return s.str(); +} + +std::ostream& operator<<(std::ostream& out, const HpackEntryDecoder& v) { + v.OutputDebugString(out); + return out; +} + +std::ostream& operator<<(std::ostream& out, + HpackEntryDecoder::EntryDecoderState state) { + typedef HpackEntryDecoder::EntryDecoderState EntryDecoderState; + switch (state) { + case EntryDecoderState::kResumeDecodingType: + return out << "kResumeDecodingType"; + case EntryDecoderState::kDecodedType: + return out << "kDecodedType"; + case EntryDecoderState::kStartDecodingName: + return out << "kStartDecodingName"; + case EntryDecoderState::kResumeDecodingName: + return out << "kResumeDecodingName"; + case EntryDecoderState::kStartDecodingValue: + return out << "kStartDecodingValue"; + case EntryDecoderState::kResumeDecodingValue: + return out << "kResumeDecodingValue"; + } + return out << static_cast(state); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_entry_decoder.h b/quiche/http2/hpack/decoder/hpack_entry_decoder.h new file mode 100644 index 000000000000..85b8a781e9c2 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_decoder.h @@ -0,0 +1,90 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_DECODER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_DECODER_H_ + +// HpackEntryDecoder decodes a single HPACK entry (i.e. one header or one +// dynamic table size update), in a resumable fashion. The first call, Start(), +// must provide a non-empty decode buffer. Continue with calls to Resume() if +// Start, and any subsequent calls to Resume, returns kDecodeInProgress. + +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" +#include "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h" +#include "quiche/http2/hpack/decoder/hpack_entry_type_decoder.h" +#include "quiche/http2/hpack/decoder/hpack_string_decoder.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +class QUICHE_EXPORT HpackEntryDecoder { + public: + enum class EntryDecoderState { + // Have started decoding the type/varint, but didn't finish on the previous + // attempt. Next state is kResumeDecodingType or kDecodedType. + kResumeDecodingType, + + // Have just finished decoding the type/varint. Final state if the type is + // kIndexedHeader or kDynamicTableSizeUpdate. Otherwise, the next state is + // kStartDecodingName (if the varint is 0), else kStartDecodingValue. + kDecodedType, + + // Ready to start decoding the literal name of a header entry. Next state + // is kResumeDecodingName (if the name is split across decode buffers), + // else kStartDecodingValue. + kStartDecodingName, + + // Resume decoding the literal name of a header that is split across decode + // buffers. + kResumeDecodingName, + + // Ready to start decoding the literal value of a header entry. Final state + // if the value string is entirely in the decode buffer, else the next state + // is kResumeDecodingValue. + kStartDecodingValue, + + // Resume decoding the literal value of a header that is split across decode + // buffers. + kResumeDecodingValue, + }; + + // Only call when the decode buffer has data (i.e. HpackBlockDecoder must + // not call until there is data). + DecodeStatus Start(DecodeBuffer* db, HpackEntryDecoderListener* listener); + + // Only call Resume if the previous call (Start or Resume) returned + // kDecodeInProgress; Resume is also called from Start when it has succeeded + // in decoding the entry type and its varint. + DecodeStatus Resume(DecodeBuffer* db, HpackEntryDecoderListener* listener); + + // Return error code after decoding error occurred. + HpackDecodingError error() const { return error_; } + + std::string DebugString() const; + void OutputDebugString(std::ostream& out) const; + + private: + // Implements handling state kDecodedType. + bool DispatchOnType(HpackEntryDecoderListener* listener); + + HpackEntryTypeDecoder entry_type_decoder_; + HpackStringDecoder string_decoder_; + EntryDecoderState state_ = EntryDecoderState(); + HpackDecodingError error_ = HpackDecodingError::kOk; +}; + +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackEntryDecoder& v); +QUICHE_EXPORT std::ostream& operator<<( + std::ostream& out, HpackEntryDecoder::EntryDecoderState state); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_DECODER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.cc b/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.cc new file mode 100644 index 000000000000..16b1420d97e4 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.cc @@ -0,0 +1,80 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h" + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +void HpackEntryDecoderVLoggingListener::OnIndexedHeader(size_t index) { + QUICHE_VLOG(1) << "OnIndexedHeader, index=" << index; + if (wrapped_) { + wrapped_->OnIndexedHeader(index); + } +} + +void HpackEntryDecoderVLoggingListener::OnStartLiteralHeader( + HpackEntryType entry_type, size_t maybe_name_index) { + QUICHE_VLOG(1) << "OnStartLiteralHeader: entry_type=" << entry_type + << ", maybe_name_index=" << maybe_name_index; + if (wrapped_) { + wrapped_->OnStartLiteralHeader(entry_type, maybe_name_index); + } +} + +void HpackEntryDecoderVLoggingListener::OnNameStart(bool huffman_encoded, + size_t len) { + QUICHE_VLOG(1) << "OnNameStart: H=" << huffman_encoded << ", len=" << len; + if (wrapped_) { + wrapped_->OnNameStart(huffman_encoded, len); + } +} + +void HpackEntryDecoderVLoggingListener::OnNameData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnNameData: len=" << len; + if (wrapped_) { + wrapped_->OnNameData(data, len); + } +} + +void HpackEntryDecoderVLoggingListener::OnNameEnd() { + QUICHE_VLOG(1) << "OnNameEnd"; + if (wrapped_) { + wrapped_->OnNameEnd(); + } +} + +void HpackEntryDecoderVLoggingListener::OnValueStart(bool huffman_encoded, + size_t len) { + QUICHE_VLOG(1) << "OnValueStart: H=" << huffman_encoded << ", len=" << len; + if (wrapped_) { + wrapped_->OnValueStart(huffman_encoded, len); + } +} + +void HpackEntryDecoderVLoggingListener::OnValueData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnValueData: len=" << len; + if (wrapped_) { + wrapped_->OnValueData(data, len); + } +} + +void HpackEntryDecoderVLoggingListener::OnValueEnd() { + QUICHE_VLOG(1) << "OnValueEnd"; + if (wrapped_) { + wrapped_->OnValueEnd(); + } +} + +void HpackEntryDecoderVLoggingListener::OnDynamicTableSizeUpdate(size_t size) { + QUICHE_VLOG(1) << "OnDynamicTableSizeUpdate: size=" << size; + if (wrapped_) { + wrapped_->OnDynamicTableSizeUpdate(size); + } +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h b/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h new file mode 100644 index 000000000000..86230df7c427 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h @@ -0,0 +1,110 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_DECODER_LISTENER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_DECODER_LISTENER_H_ + +// Defines HpackEntryDecoderListener, the base class of listeners that +// HpackEntryDecoder calls. Also defines HpackEntryDecoderVLoggingListener +// which logs before calling another HpackEntryDecoderListener implementation. + +#include + +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +class QUICHE_EXPORT HpackEntryDecoderListener { + public: + virtual ~HpackEntryDecoderListener() {} + + // Called when an indexed header (i.e. one in the static or dynamic table) has + // been decoded from an HPACK block. index is supposed to be non-zero, but + // that has not been checked by the caller. + virtual void OnIndexedHeader(size_t index) = 0; + + // Called when the start of a header with a literal value, and maybe a literal + // name, has been decoded. maybe_name_index is zero if the header has a + // literal name, else it is a reference into the static or dynamic table, from + // which the name should be determined. When the name is literal, the next + // call will be to OnNameStart; else it will be to OnValueStart. entry_type + // indicates whether the peer has added the entry to its dynamic table, and + // whether a proxy is permitted to do so when forwarding the entry. + virtual void OnStartLiteralHeader(HpackEntryType entry_type, + size_t maybe_name_index) = 0; + + // Called when the encoding (Huffman compressed or plain text) and the encoded + // length of a literal name has been decoded. OnNameData will be called next, + // and repeatedly until the sum of lengths passed to OnNameData is len. + virtual void OnNameStart(bool huffman_encoded, size_t len) = 0; + + // Called when len bytes of an encoded header name have been decoded. + virtual void OnNameData(const char* data, size_t len) = 0; + + // Called after the entire name has been passed to OnNameData. + // OnValueStart will be called next. + virtual void OnNameEnd() = 0; + + // Called when the encoding (Huffman compressed or plain text) and the encoded + // length of a literal value has been decoded. OnValueData will be called + // next, and repeatedly until the sum of lengths passed to OnValueData is len. + virtual void OnValueStart(bool huffman_encoded, size_t len) = 0; + + // Called when len bytes of an encoded header value have been decoded. + virtual void OnValueData(const char* data, size_t len) = 0; + + // Called after the entire value has been passed to OnValueData, marking the + // end of a header entry with a literal value, and maybe a literal name. + virtual void OnValueEnd() = 0; + + // Called when an update to the size of the peer's dynamic table has been + // decoded. + virtual void OnDynamicTableSizeUpdate(size_t size) = 0; +}; + +class QUICHE_EXPORT HpackEntryDecoderVLoggingListener + : public HpackEntryDecoderListener { + public: + HpackEntryDecoderVLoggingListener() : wrapped_(nullptr) {} + explicit HpackEntryDecoderVLoggingListener(HpackEntryDecoderListener* wrapped) + : wrapped_(wrapped) {} + ~HpackEntryDecoderVLoggingListener() override {} + + void OnIndexedHeader(size_t index) override; + void OnStartLiteralHeader(HpackEntryType entry_type, + size_t maybe_name_index) override; + void OnNameStart(bool huffman_encoded, size_t len) override; + void OnNameData(const char* data, size_t len) override; + void OnNameEnd() override; + void OnValueStart(bool huffman_encoded, size_t len) override; + void OnValueData(const char* data, size_t len) override; + void OnValueEnd() override; + void OnDynamicTableSizeUpdate(size_t size) override; + + private: + HpackEntryDecoderListener* const wrapped_; +}; + +// A no-op implementation of HpackEntryDecoderListener. +class QUICHE_EXPORT HpackEntryDecoderNoOpListener + : public HpackEntryDecoderListener { + public: + ~HpackEntryDecoderNoOpListener() override {} + + void OnIndexedHeader(size_t /*index*/) override {} + void OnStartLiteralHeader(HpackEntryType /*entry_type*/, + size_t /*maybe_name_index*/) override {} + void OnNameStart(bool /*huffman_encoded*/, size_t /*len*/) override {} + void OnNameData(const char* /*data*/, size_t /*len*/) override {} + void OnNameEnd() override {} + void OnValueStart(bool /*huffman_encoded*/, size_t /*len*/) override {} + void OnValueData(const char* /*data*/, size_t /*len*/) override {} + void OnValueEnd() override {} + void OnDynamicTableSizeUpdate(size_t /*size*/) override {} +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_DECODER_LISTENER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_entry_decoder_test.cc b/quiche/http2/hpack/decoder/hpack_entry_decoder_test.cc new file mode 100644 index 000000000000..aefadd1d2eaa --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_decoder_test.cc @@ -0,0 +1,202 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_entry_decoder.h" + +// Tests of HpackEntryDecoder. + +#include + +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/hpack_entry_collector.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +class HpackEntryDecoderTest : public RandomDecoderTest { + protected: + HpackEntryDecoderTest() : listener_(&collector_) {} + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + collector_.Clear(); + return decoder_.Start(b, &listener_); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + return decoder_.Resume(b, &listener_); + } + + AssertionResult DecodeAndValidateSeveralWays(DecodeBuffer* db, + const Validator& validator) { + // StartDecoding, above, requires the DecodeBuffer be non-empty so that it + // can call Start with the prefix byte. + bool return_non_zero_on_first = true; + return RandomDecoderTest::DecodeAndValidateSeveralWays( + db, return_non_zero_on_first, validator); + } + + AssertionResult DecodeAndValidateSeveralWays(const HpackBlockBuilder& hbb, + const Validator& validator) { + DecodeBuffer db(hbb.buffer()); + return DecodeAndValidateSeveralWays(&db, validator); + } + + HpackEntryDecoder decoder_; + HpackEntryCollector collector_; + HpackEntryDecoderVLoggingListener listener_; +}; + +TEST_F(HpackEntryDecoderTest, IndexedHeader_Literals) { + { + const char input[] = {'\x82'}; // == Index 2 == + DecodeBuffer b(input); + auto do_check = [this]() { return collector_.ValidateIndexedHeader(2); }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); + } + collector_.Clear(); + { + const char input[] = {'\xfe'}; // == Index 126 == + DecodeBuffer b(input); + auto do_check = [this]() { return collector_.ValidateIndexedHeader(126); }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); + } + collector_.Clear(); + { + const char input[] = {'\xff', '\x00'}; // == Index 127 == + DecodeBuffer b(input); + auto do_check = [this]() { return collector_.ValidateIndexedHeader(127); }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); + } +} + +TEST_F(HpackEntryDecoderTest, IndexedHeader_Various) { + // Indices chosen to hit encoding and table boundaries. + for (const uint32_t ndx : {1, 2, 61, 62, 63, 126, 127, 254, 255, 256}) { + HpackBlockBuilder hbb; + hbb.AppendIndexedHeader(ndx); + + auto do_check = [this, ndx]() { + return collector_.ValidateIndexedHeader(ndx); + }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(hbb, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); + } +} + +TEST_F(HpackEntryDecoderTest, IndexedLiteralValue_Literal) { + const char input[] = + "\x7f" // == Literal indexed, name index 0x40 == + "\x01" // 2nd byte of name index (0x01 + 0x3f == 0x40) + "\x0d" // Value length (13) + "custom-header"; // Value + DecodeBuffer b(input, sizeof input - 1); + auto do_check = [this]() { + return collector_.ValidateLiteralValueHeader( + HpackEntryType::kIndexedLiteralHeader, 0x40, false, "custom-header"); + }; + EXPECT_TRUE(DecodeAndValidateSeveralWays(&b, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +TEST_F(HpackEntryDecoderTest, IndexedLiteralNameValue_Literal) { + const char input[] = + "\x40" // == Literal indexed == + "\x0a" // Name length (10) + "custom-key" // Name + "\x0d" // Value length (13) + "custom-header"; // Value + + DecodeBuffer b(input, sizeof input - 1); + auto do_check = [this]() { + return collector_.ValidateLiteralNameValueHeader( + HpackEntryType::kIndexedLiteralHeader, false, "custom-key", false, + "custom-header"); + }; + EXPECT_TRUE(DecodeAndValidateSeveralWays(&b, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +TEST_F(HpackEntryDecoderTest, DynamicTableSizeUpdate_Literal) { + // Size update, length 31. + const char input[] = "\x3f\x00"; + DecodeBuffer b(input, 2); + auto do_check = [this]() { + return collector_.ValidateDynamicTableSizeUpdate(31); + }; + EXPECT_TRUE(DecodeAndValidateSeveralWays(&b, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); +} + +class HpackLiteralEntryDecoderTest + : public HpackEntryDecoderTest, + public ::testing::WithParamInterface { + protected: + HpackLiteralEntryDecoderTest() : entry_type_(GetParam()) {} + + const HpackEntryType entry_type_; +}; + +INSTANTIATE_TEST_SUITE_P( + AllLiteralTypes, HpackLiteralEntryDecoderTest, + testing::Values(HpackEntryType::kIndexedLiteralHeader, + HpackEntryType::kUnindexedLiteralHeader, + HpackEntryType::kNeverIndexedLiteralHeader)); + +TEST_P(HpackLiteralEntryDecoderTest, RandNameIndexAndLiteralValue) { + for (int n = 0; n < 10; n++) { + const uint32_t ndx = 1 + Random().Rand8(); + const bool value_is_huffman_encoded = (n % 2) == 0; + const std::string value = Random().RandString(Random().Rand8()); + HpackBlockBuilder hbb; + hbb.AppendNameIndexAndLiteralValue(entry_type_, ndx, + value_is_huffman_encoded, value); + auto do_check = [this, ndx, value_is_huffman_encoded, + value]() -> AssertionResult { + return collector_.ValidateLiteralValueHeader( + entry_type_, ndx, value_is_huffman_encoded, value); + }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(hbb, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); + } +} + +TEST_P(HpackLiteralEntryDecoderTest, RandLiteralNameAndValue) { + for (int n = 0; n < 10; n++) { + const bool name_is_huffman_encoded = (n & 1) == 0; + const int name_len = 1 + Random().Rand8(); + const std::string name = Random().RandString(name_len); + const bool value_is_huffman_encoded = (n & 2) == 0; + const int value_len = Random().Skewed(10); + const std::string value = Random().RandString(value_len); + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(entry_type_, name_is_huffman_encoded, name, + value_is_huffman_encoded, value); + auto do_check = [this, name_is_huffman_encoded, name, + value_is_huffman_encoded, value]() -> AssertionResult { + return collector_.ValidateLiteralNameValueHeader( + entry_type_, name_is_huffman_encoded, name, value_is_huffman_encoded, + value); + }; + EXPECT_TRUE( + DecodeAndValidateSeveralWays(hbb, ValidateDoneAndEmpty(do_check))); + EXPECT_TRUE(do_check()); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_entry_type_decoder.cc b/quiche/http2/hpack/decoder/hpack_entry_type_decoder.cc new file mode 100644 index 000000000000..e5694d4641b6 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_type_decoder.cc @@ -0,0 +1,361 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_entry_type_decoder.h" + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::string HpackEntryTypeDecoder::DebugString() const { + return absl::StrCat( + "HpackEntryTypeDecoder(varint_decoder=", varint_decoder_.DebugString(), + ", entry_type=", entry_type_, ")"); +} + +std::ostream& operator<<(std::ostream& out, const HpackEntryTypeDecoder& v) { + return out << v.DebugString(); +} + +// This ridiculous looking function turned out to be the winner in benchmarking +// of several very different alternative implementations. It would be even +// faster (~7%) if inlined in the header file, but I'm not sure if that is +// worth doing... yet. +// TODO(jamessynge): Benchmark again at a higher level (e.g. at least at the +// full HTTP/2 decoder level, but preferably still higher) to determine if the +// alternatives that take less code/data space are preferable in that situation. +DecodeStatus HpackEntryTypeDecoder::Start(DecodeBuffer* db) { + QUICHE_DCHECK(db != nullptr); + QUICHE_DCHECK(db->HasData()); + + // The high four bits (nibble) of first byte of the entry determine the type + // of the entry, and may also be the initial bits of the varint that + // represents an index or table size. Note the use of the word 'initial' + // rather than 'high'; the HPACK encoding of varints is not in network + // order (i.e. not big-endian, the high-order byte isn't first), nor in + // little-endian order. See: + // http://httpwg.org/specs/rfc7541.html#integer.representation + uint8_t byte = db->DecodeUInt8(); + switch (byte) { + case 0b00000000: + case 0b00000001: + case 0b00000010: + case 0b00000011: + case 0b00000100: + case 0b00000101: + case 0b00000110: + case 0b00000111: + case 0b00001000: + case 0b00001001: + case 0b00001010: + case 0b00001011: + case 0b00001100: + case 0b00001101: + case 0b00001110: + // The low 4 bits of |byte| are the initial bits of the varint. + // One of those bits is 0, so the varint is only one byte long. + entry_type_ = HpackEntryType::kUnindexedLiteralHeader; + varint_decoder_.set_value(byte); + return DecodeStatus::kDecodeDone; + + case 0b00001111: + // The low 4 bits of |byte| are the initial bits of the varint. All 4 + // are 1, so the varint extends into another byte. + entry_type_ = HpackEntryType::kUnindexedLiteralHeader; + return varint_decoder_.StartExtended(4, db); + + case 0b00010000: + case 0b00010001: + case 0b00010010: + case 0b00010011: + case 0b00010100: + case 0b00010101: + case 0b00010110: + case 0b00010111: + case 0b00011000: + case 0b00011001: + case 0b00011010: + case 0b00011011: + case 0b00011100: + case 0b00011101: + case 0b00011110: + // The low 4 bits of |byte| are the initial bits of the varint. + // One of those bits is 0, so the varint is only one byte long. + entry_type_ = HpackEntryType::kNeverIndexedLiteralHeader; + varint_decoder_.set_value(byte & 0x0f); + return DecodeStatus::kDecodeDone; + + case 0b00011111: + // The low 4 bits of |byte| are the initial bits of the varint. + // All of those bits are 1, so the varint extends into another byte. + entry_type_ = HpackEntryType::kNeverIndexedLiteralHeader; + return varint_decoder_.StartExtended(4, db); + + case 0b00100000: + case 0b00100001: + case 0b00100010: + case 0b00100011: + case 0b00100100: + case 0b00100101: + case 0b00100110: + case 0b00100111: + case 0b00101000: + case 0b00101001: + case 0b00101010: + case 0b00101011: + case 0b00101100: + case 0b00101101: + case 0b00101110: + case 0b00101111: + case 0b00110000: + case 0b00110001: + case 0b00110010: + case 0b00110011: + case 0b00110100: + case 0b00110101: + case 0b00110110: + case 0b00110111: + case 0b00111000: + case 0b00111001: + case 0b00111010: + case 0b00111011: + case 0b00111100: + case 0b00111101: + case 0b00111110: + entry_type_ = HpackEntryType::kDynamicTableSizeUpdate; + // The low 5 bits of |byte| are the initial bits of the varint. + // One of those bits is 0, so the varint is only one byte long. + varint_decoder_.set_value(byte & 0x01f); + return DecodeStatus::kDecodeDone; + + case 0b00111111: + entry_type_ = HpackEntryType::kDynamicTableSizeUpdate; + // The low 5 bits of |byte| are the initial bits of the varint. + // All of those bits are 1, so the varint extends into another byte. + return varint_decoder_.StartExtended(5, db); + + case 0b01000000: + case 0b01000001: + case 0b01000010: + case 0b01000011: + case 0b01000100: + case 0b01000101: + case 0b01000110: + case 0b01000111: + case 0b01001000: + case 0b01001001: + case 0b01001010: + case 0b01001011: + case 0b01001100: + case 0b01001101: + case 0b01001110: + case 0b01001111: + case 0b01010000: + case 0b01010001: + case 0b01010010: + case 0b01010011: + case 0b01010100: + case 0b01010101: + case 0b01010110: + case 0b01010111: + case 0b01011000: + case 0b01011001: + case 0b01011010: + case 0b01011011: + case 0b01011100: + case 0b01011101: + case 0b01011110: + case 0b01011111: + case 0b01100000: + case 0b01100001: + case 0b01100010: + case 0b01100011: + case 0b01100100: + case 0b01100101: + case 0b01100110: + case 0b01100111: + case 0b01101000: + case 0b01101001: + case 0b01101010: + case 0b01101011: + case 0b01101100: + case 0b01101101: + case 0b01101110: + case 0b01101111: + case 0b01110000: + case 0b01110001: + case 0b01110010: + case 0b01110011: + case 0b01110100: + case 0b01110101: + case 0b01110110: + case 0b01110111: + case 0b01111000: + case 0b01111001: + case 0b01111010: + case 0b01111011: + case 0b01111100: + case 0b01111101: + case 0b01111110: + entry_type_ = HpackEntryType::kIndexedLiteralHeader; + // The low 6 bits of |byte| are the initial bits of the varint. + // One of those bits is 0, so the varint is only one byte long. + varint_decoder_.set_value(byte & 0x03f); + return DecodeStatus::kDecodeDone; + + case 0b01111111: + entry_type_ = HpackEntryType::kIndexedLiteralHeader; + // The low 6 bits of |byte| are the initial bits of the varint. + // All of those bits are 1, so the varint extends into another byte. + return varint_decoder_.StartExtended(6, db); + + case 0b10000000: + case 0b10000001: + case 0b10000010: + case 0b10000011: + case 0b10000100: + case 0b10000101: + case 0b10000110: + case 0b10000111: + case 0b10001000: + case 0b10001001: + case 0b10001010: + case 0b10001011: + case 0b10001100: + case 0b10001101: + case 0b10001110: + case 0b10001111: + case 0b10010000: + case 0b10010001: + case 0b10010010: + case 0b10010011: + case 0b10010100: + case 0b10010101: + case 0b10010110: + case 0b10010111: + case 0b10011000: + case 0b10011001: + case 0b10011010: + case 0b10011011: + case 0b10011100: + case 0b10011101: + case 0b10011110: + case 0b10011111: + case 0b10100000: + case 0b10100001: + case 0b10100010: + case 0b10100011: + case 0b10100100: + case 0b10100101: + case 0b10100110: + case 0b10100111: + case 0b10101000: + case 0b10101001: + case 0b10101010: + case 0b10101011: + case 0b10101100: + case 0b10101101: + case 0b10101110: + case 0b10101111: + case 0b10110000: + case 0b10110001: + case 0b10110010: + case 0b10110011: + case 0b10110100: + case 0b10110101: + case 0b10110110: + case 0b10110111: + case 0b10111000: + case 0b10111001: + case 0b10111010: + case 0b10111011: + case 0b10111100: + case 0b10111101: + case 0b10111110: + case 0b10111111: + case 0b11000000: + case 0b11000001: + case 0b11000010: + case 0b11000011: + case 0b11000100: + case 0b11000101: + case 0b11000110: + case 0b11000111: + case 0b11001000: + case 0b11001001: + case 0b11001010: + case 0b11001011: + case 0b11001100: + case 0b11001101: + case 0b11001110: + case 0b11001111: + case 0b11010000: + case 0b11010001: + case 0b11010010: + case 0b11010011: + case 0b11010100: + case 0b11010101: + case 0b11010110: + case 0b11010111: + case 0b11011000: + case 0b11011001: + case 0b11011010: + case 0b11011011: + case 0b11011100: + case 0b11011101: + case 0b11011110: + case 0b11011111: + case 0b11100000: + case 0b11100001: + case 0b11100010: + case 0b11100011: + case 0b11100100: + case 0b11100101: + case 0b11100110: + case 0b11100111: + case 0b11101000: + case 0b11101001: + case 0b11101010: + case 0b11101011: + case 0b11101100: + case 0b11101101: + case 0b11101110: + case 0b11101111: + case 0b11110000: + case 0b11110001: + case 0b11110010: + case 0b11110011: + case 0b11110100: + case 0b11110101: + case 0b11110110: + case 0b11110111: + case 0b11111000: + case 0b11111001: + case 0b11111010: + case 0b11111011: + case 0b11111100: + case 0b11111101: + case 0b11111110: + entry_type_ = HpackEntryType::kIndexedHeader; + // The low 7 bits of |byte| are the initial bits of the varint. + // One of those bits is 0, so the varint is only one byte long. + varint_decoder_.set_value(byte & 0x07f); + return DecodeStatus::kDecodeDone; + + case 0b11111111: + entry_type_ = HpackEntryType::kIndexedHeader; + // The low 7 bits of |byte| are the initial bits of the varint. + // All of those bits are 1, so the varint extends into another byte. + return varint_decoder_.StartExtended(7, db); + } + QUICHE_BUG(http2_bug_66_1) + << "Unreachable, byte=" << std::hex << static_cast(byte); + QUICHE_CODE_COUNT_N(decompress_failure_3, 17, 23); + return DecodeStatus::kDecodeError; +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_entry_type_decoder.h b/quiche/http2/hpack/decoder/hpack_entry_type_decoder.h new file mode 100644 index 000000000000..2548ef5b78a9 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_type_decoder.h @@ -0,0 +1,57 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_TYPE_DECODER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_TYPE_DECODER_H_ + +// Decodes the type of an HPACK entry, and the variable length integer whose +// prefix is in the low-order bits of the same byte, "below" the type bits. +// The integer represents an index into static or dynamic table, which may be +// zero, or is the new size limit of the dynamic table. + +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +class QUICHE_EXPORT HpackEntryTypeDecoder { + public: + // Only call when the decode buffer has data (i.e. HpackEntryDecoder must + // not call until there is data). + DecodeStatus Start(DecodeBuffer* db); + + // Only call Resume if the previous call (Start or Resume) returned + // DecodeStatus::kDecodeInProgress. + DecodeStatus Resume(DecodeBuffer* db) { return varint_decoder_.Resume(db); } + + // Returns the decoded entry type. Only call if the preceding call to Start + // or Resume returned kDecodeDone. + HpackEntryType entry_type() const { return entry_type_; } + + // Returns the decoded variable length integer. Only call if the + // preceding call to Start or Resume returned kDecodeDone. + uint64_t varint() const { return varint_decoder_.value(); } + + std::string DebugString() const; + + private: + HpackVarintDecoder varint_decoder_; + + // This field is initialized just to keep ASAN happy about reading it + // from DebugString(). + HpackEntryType entry_type_ = HpackEntryType::kIndexedHeader; +}; + +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackEntryTypeDecoder& v); + +} // namespace http2 +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_ENTRY_TYPE_DECODER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_entry_type_decoder_test.cc b/quiche/http2/hpack/decoder/hpack_entry_type_decoder_test.cc new file mode 100644 index 000000000000..c48f8f41794f --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_entry_type_decoder_test.cc @@ -0,0 +1,86 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_entry_type_decoder.h" + +#include + +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { +namespace { +const bool kReturnNonZeroOnFirst = true; + +class HpackEntryTypeDecoderTest : public RandomDecoderTest { + protected: + DecodeStatus StartDecoding(DecodeBuffer* b) override { + QUICHE_CHECK_LT(0u, b->Remaining()); + return decoder_.Start(b); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + return decoder_.Resume(b); + } + + HpackEntryTypeDecoder decoder_; +}; + +TEST_F(HpackEntryTypeDecoderTest, DynamicTableSizeUpdate) { + for (uint32_t size = 0; size < 1000 * 1000; size += 256) { + HpackBlockBuilder bb; + bb.AppendDynamicTableSizeUpdate(size); + DecodeBuffer db(bb.buffer()); + auto validator = [size, this]() -> AssertionResult { + HTTP2_VERIFY_EQ(HpackEntryType::kDynamicTableSizeUpdate, + decoder_.entry_type()); + HTTP2_VERIFY_EQ(size, decoder_.varint()); + return AssertionSuccess(); + }; + EXPECT_TRUE(DecodeAndValidateSeveralWays(&db, kReturnNonZeroOnFirst, + ValidateDoneAndEmpty(validator))) + << "\nentry_type=kDynamicTableSizeUpdate, size=" << size; + // Run the validator again to make sure that DecodeAndValidateSeveralWays + // did the right thing. + EXPECT_TRUE(validator()); + } +} + +TEST_F(HpackEntryTypeDecoderTest, HeaderWithIndex) { + std::vector entry_types = { + HpackEntryType::kIndexedHeader, + HpackEntryType::kIndexedLiteralHeader, + HpackEntryType::kUnindexedLiteralHeader, + HpackEntryType::kNeverIndexedLiteralHeader, + }; + for (const HpackEntryType entry_type : entry_types) { + const uint32_t first = entry_type == HpackEntryType::kIndexedHeader ? 1 : 0; + for (uint32_t index = first; index < 1000; ++index) { + HpackBlockBuilder bb; + bb.AppendEntryTypeAndVarint(entry_type, index); + DecodeBuffer db(bb.buffer()); + auto validator = [entry_type, index, this]() -> AssertionResult { + HTTP2_VERIFY_EQ(entry_type, decoder_.entry_type()); + HTTP2_VERIFY_EQ(index, decoder_.varint()); + return AssertionSuccess(); + }; + EXPECT_TRUE(DecodeAndValidateSeveralWays(&db, kReturnNonZeroOnFirst, + ValidateDoneAndEmpty(validator))) + << "\nentry_type=" << entry_type << ", index=" << index; + // Run the validator again to make sure that DecodeAndValidateSeveralWays + // did the right thing. + EXPECT_TRUE(validator()); + } + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_string_decoder.cc b/quiche/http2/hpack/decoder/hpack_string_decoder.cc new file mode 100644 index 000000000000..f2a4bf80664d --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_string_decoder.cc @@ -0,0 +1,35 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_string_decoder.h" + +#include "absl/strings/str_cat.h" + +namespace http2 { + +std::string HpackStringDecoder::DebugString() const { + return absl::StrCat("HpackStringDecoder(state=", StateToString(state_), + ", length=", length_decoder_.DebugString(), + ", remaining=", remaining_, + ", huffman=", huffman_encoded_ ? "true)" : "false)"); +} + +// static +std::string HpackStringDecoder::StateToString(StringDecoderState v) { + switch (v) { + case kStartDecodingLength: + return "kStartDecodingLength"; + case kDecodingString: + return "kDecodingString"; + case kResumeDecodingLength: + return "kResumeDecodingLength"; + } + return absl::StrCat("UNKNOWN_STATE(", static_cast(v), ")"); +} + +std::ostream& operator<<(std::ostream& out, const HpackStringDecoder& v) { + return out << v.DebugString(); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_string_decoder.h b/quiche/http2/hpack/decoder/hpack_string_decoder.h new file mode 100644 index 000000000000..897ce761a2b8 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_string_decoder.h @@ -0,0 +1,207 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_STRING_DECODER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_STRING_DECODER_H_ + +// HpackStringDecoder decodes strings encoded per the HPACK spec; this does +// not mean decompressing Huffman encoded strings, just identifying the length, +// encoding and contents for a listener. + +#include + +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +// Decodes a single string in an HPACK header entry. The high order bit of +// the first byte of the length is the H (Huffman) bit indicating whether +// the value is Huffman encoded, and the remainder of the byte is the first +// 7 bits of an HPACK varint. +// +// Call Start() to begin decoding; if it returns kDecodeInProgress, then call +// Resume() when more input is available, repeating until kDecodeInProgress is +// not returned. If kDecodeDone or kDecodeError is returned, then Resume() must +// not be called until Start() has been called to start decoding a new string. +class QUICHE_EXPORT HpackStringDecoder { + public: + enum StringDecoderState { + kStartDecodingLength, + kDecodingString, + kResumeDecodingLength, + }; + + template + DecodeStatus Start(DecodeBuffer* db, Listener* cb) { + // Fast decode path is used if the string is under 127 bytes and the + // entire length of the string is in the decode buffer. More than 83% of + // string lengths are encoded in just one byte. + if (db->HasData() && (*db->cursor() & 0x7f) != 0x7f) { + // The string is short. + uint8_t h_and_prefix = db->DecodeUInt8(); + uint8_t length = h_and_prefix & 0x7f; + bool huffman_encoded = (h_and_prefix & 0x80) == 0x80; + cb->OnStringStart(huffman_encoded, length); + if (length <= db->Remaining()) { + // Yeah, we've got the whole thing in the decode buffer. + // Ideally this will be the common case. Note that we don't + // update any of the member variables in this path. + cb->OnStringData(db->cursor(), length); + db->AdvanceCursor(length); + cb->OnStringEnd(); + return DecodeStatus::kDecodeDone; + } + // Not all in the buffer. + huffman_encoded_ = huffman_encoded; + remaining_ = length; + // Call Resume to decode the string body, which is only partially + // in the decode buffer (or not at all). + state_ = kDecodingString; + return Resume(db, cb); + } + // Call Resume to decode the string length, which is either not in + // the decode buffer, or spans multiple bytes. + state_ = kStartDecodingLength; + return Resume(db, cb); + } + + template + DecodeStatus Resume(DecodeBuffer* db, Listener* cb) { + DecodeStatus status; + while (true) { + switch (state_) { + case kStartDecodingLength: + QUICHE_DVLOG(2) << "kStartDecodingLength: db->Remaining=" + << db->Remaining(); + if (!StartDecodingLength(db, cb, &status)) { + // The length is split across decode buffers. + return status; + } + // We've finished decoding the length, which spanned one or more + // bytes. Approximately 17% of strings have a length that is greater + // than 126 bytes, and thus the length is encoded in more than one + // byte, and so doesn't get the benefit of the optimization in + // Start() for single byte lengths. But, we still expect that most + // of such strings will be contained entirely in a single decode + // buffer, and hence this fall through skips another trip through the + // switch above and more importantly skips setting the state_ variable + // again in those cases where we don't need it. + ABSL_FALLTHROUGH_INTENDED; + + case kDecodingString: + QUICHE_DVLOG(2) << "kDecodingString: db->Remaining=" + << db->Remaining() << " remaining_=" << remaining_; + return DecodeString(db, cb); + + case kResumeDecodingLength: + QUICHE_DVLOG(2) << "kResumeDecodingLength: db->Remaining=" + << db->Remaining(); + if (!ResumeDecodingLength(db, cb, &status)) { + return status; + } + } + } + } + + std::string DebugString() const; + + private: + static std::string StateToString(StringDecoderState v); + + // Returns true if the length is fully decoded and the listener wants the + // decoding to continue, false otherwise; status is set to the status from + // the varint decoder. + // If the length is not fully decoded, case state_ is set appropriately + // for the next call to Resume. + template + bool StartDecodingLength(DecodeBuffer* db, Listener* cb, + DecodeStatus* status) { + if (db->Empty()) { + *status = DecodeStatus::kDecodeInProgress; + state_ = kStartDecodingLength; + return false; + } + uint8_t h_and_prefix = db->DecodeUInt8(); + huffman_encoded_ = (h_and_prefix & 0x80) == 0x80; + *status = length_decoder_.Start(h_and_prefix, 7, db); + if (*status == DecodeStatus::kDecodeDone) { + OnStringStart(cb, status); + return true; + } + // Set the state to cover the DecodeStatus::kDecodeInProgress case. + // Won't be needed if the status is kDecodeError. + state_ = kResumeDecodingLength; + return false; + } + + // Returns true if the length is fully decoded and the listener wants the + // decoding to continue, false otherwise; status is set to the status from + // the varint decoder; state_ is updated when fully decoded. + // If the length is not fully decoded, case state_ is set appropriately + // for the next call to Resume. + template + bool ResumeDecodingLength(DecodeBuffer* db, Listener* cb, + DecodeStatus* status) { + QUICHE_DCHECK_EQ(state_, kResumeDecodingLength); + *status = length_decoder_.Resume(db); + if (*status == DecodeStatus::kDecodeDone) { + state_ = kDecodingString; + OnStringStart(cb, status); + return true; + } + return false; + } + + // Returns true if the listener wants the decoding to continue, and + // false otherwise, in which case status set. + template + void OnStringStart(Listener* cb, DecodeStatus* /*status*/) { + // TODO(vasilvv): fail explicitly in case of truncation. + remaining_ = static_cast(length_decoder_.value()); + // Make callback so consumer knows what is coming. + cb->OnStringStart(huffman_encoded_, remaining_); + } + + // Passes the available portion of the string to the listener, and signals + // the end of the string when it is reached. Returns kDecodeDone or + // kDecodeInProgress as appropriate. + template + DecodeStatus DecodeString(DecodeBuffer* db, Listener* cb) { + size_t len = std::min(remaining_, db->Remaining()); + if (len > 0) { + cb->OnStringData(db->cursor(), len); + db->AdvanceCursor(len); + remaining_ -= len; + } + if (remaining_ == 0) { + cb->OnStringEnd(); + return DecodeStatus::kDecodeDone; + } + state_ = kDecodingString; + return DecodeStatus::kDecodeInProgress; + } + + HpackVarintDecoder length_decoder_; + + // These fields are initialized just to keep ASAN happy about reading + // them from DebugString(). + size_t remaining_ = 0; + StringDecoderState state_ = kStartDecodingLength; + bool huffman_encoded_ = false; +}; + +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackStringDecoder& v); + +} // namespace http2 +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_STRING_DECODER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_string_decoder_listener.cc b/quiche/http2/hpack/decoder/hpack_string_decoder_listener.cc new file mode 100644 index 000000000000..afc172115c4b --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_string_decoder_listener.cc @@ -0,0 +1,36 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_string_decoder_listener.h" + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { + +void HpackStringDecoderVLoggingListener::OnStringStart(bool huffman_encoded, + size_t len) { + QUICHE_VLOG(1) << "OnStringStart: H=" << huffman_encoded << ", len=" << len; + if (wrapped_) { + wrapped_->OnStringStart(huffman_encoded, len); + } +} + +void HpackStringDecoderVLoggingListener::OnStringData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnStringData: len=" << len; + if (wrapped_) { + return wrapped_->OnStringData(data, len); + } +} + +void HpackStringDecoderVLoggingListener::OnStringEnd() { + QUICHE_VLOG(1) << "OnStringEnd"; + if (wrapped_) { + return wrapped_->OnStringEnd(); + } +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_string_decoder_listener.h b/quiche/http2/hpack/decoder/hpack_string_decoder_listener.h new file mode 100644 index 000000000000..942056035af8 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_string_decoder_listener.h @@ -0,0 +1,62 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_STRING_DECODER_LISTENER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_STRING_DECODER_LISTENER_H_ + +// Defines HpackStringDecoderListener which defines the methods required by an +// HpackStringDecoder. Also defines HpackStringDecoderVLoggingListener which +// logs before calling another HpackStringDecoderListener implementation. +// For now these are only used by tests, so placed in the test namespace. + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +// HpackStringDecoder methods require a listener that implements the methods +// below, but it is NOT necessary to extend this class because the methods +// are templates. +class QUICHE_EXPORT HpackStringDecoderListener { + public: + virtual ~HpackStringDecoderListener() {} + + // Called at the start of decoding an HPACK string. The encoded length of the + // string is |len| bytes, which may be zero. The string is Huffman encoded + // if huffman_encoded is true, else it is plain text (i.e. the encoded length + // is then the plain text length). + virtual void OnStringStart(bool huffman_encoded, size_t len) = 0; + + // Called when some data is available, or once when the string length is zero + // (to simplify the decoder, it doesn't have a special case for len==0). + virtual void OnStringData(const char* data, size_t len) = 0; + + // Called after OnStringData has provided all of the encoded bytes of the + // string. + virtual void OnStringEnd() = 0; +}; + +class QUICHE_EXPORT HpackStringDecoderVLoggingListener + : public HpackStringDecoderListener { + public: + HpackStringDecoderVLoggingListener() : wrapped_(nullptr) {} + explicit HpackStringDecoderVLoggingListener( + HpackStringDecoderListener* wrapped) + : wrapped_(wrapped) {} + ~HpackStringDecoderVLoggingListener() override {} + + void OnStringStart(bool huffman_encoded, size_t len) override; + void OnStringData(const char* data, size_t len) override; + void OnStringEnd() override; + + private: + HpackStringDecoderListener* const wrapped_; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_STRING_DECODER_LISTENER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_string_decoder_test.cc b/quiche/http2/hpack/decoder/hpack_string_decoder_test.cc new file mode 100644 index 000000000000..8a15e4b4f9d5 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_string_decoder_test.cc @@ -0,0 +1,153 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_string_decoder.h" + +// Tests of HpackStringDecoder. + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_string_decoder_listener.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/hpack_string_collector.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +const bool kMayReturnZeroOnFirst = false; +const bool kCompressed = true; +const bool kUncompressed = false; + +class HpackStringDecoderTest : public RandomDecoderTest { + protected: + HpackStringDecoderTest() : listener_(&collector_) {} + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + ++start_decoding_calls_; + collector_.Clear(); + return decoder_.Start(b, &listener_); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + // Provides coverage of DebugString and StateToString. + // Not validating output. + QUICHE_VLOG(1) << decoder_.DebugString(); + QUICHE_VLOG(2) << collector_; + return decoder_.Resume(b, &listener_); + } + + AssertionResult Collected(absl::string_view s, bool huffman_encoded) { + QUICHE_VLOG(1) << collector_; + return collector_.Collected(s, huffman_encoded); + } + + // expected_str is a std::string rather than a const std::string& or + // absl::string_view so that the lambda makes a copy of the string, and thus + // the string to be passed to Collected outlives the call to MakeValidator. + Validator MakeValidator(const std::string& expected_str, + bool expected_huffman) { + return [expected_str, expected_huffman, this]( + const DecodeBuffer& /*input*/, + DecodeStatus /*status*/) -> AssertionResult { + AssertionResult result = Collected(expected_str, expected_huffman); + if (result) { + HTTP2_VERIFY_EQ(collector_, + HpackStringCollector(expected_str, expected_huffman)); + } else { + HTTP2_VERIFY_NE(collector_, + HpackStringCollector(expected_str, expected_huffman)); + } + QUICHE_VLOG(2) << collector_.ToString(); + collector_.Clear(); + QUICHE_VLOG(2) << collector_; + return result; + }; + } + + HpackStringDecoder decoder_; + HpackStringCollector collector_; + HpackStringDecoderVLoggingListener listener_; + size_t start_decoding_calls_ = 0; +}; + +TEST_F(HpackStringDecoderTest, DecodeEmptyString) { + { + Validator validator = ValidateDoneAndEmpty(MakeValidator("", kCompressed)); + const char kData[] = {'\x80'}; + DecodeBuffer b(kData); + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, kMayReturnZeroOnFirst, validator)); + } + { + // Make sure it stops after decoding the empty string. + Validator validator = + ValidateDoneAndOffset(1, MakeValidator("", kUncompressed)); + const char kData[] = {'\x00', '\xff'}; + DecodeBuffer b(kData); + EXPECT_EQ(2u, b.Remaining()); + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, kMayReturnZeroOnFirst, validator)); + EXPECT_EQ(1u, b.Remaining()); + } +} + +TEST_F(HpackStringDecoderTest, DecodeShortString) { + { + // Make sure it stops after decoding the non-empty string. + Validator validator = + ValidateDoneAndOffset(11, MakeValidator("start end.", kCompressed)); + const char kData[] = "\x8astart end.Don't peek at this."; + DecodeBuffer b(kData); + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, kMayReturnZeroOnFirst, validator)); + } + { + Validator validator = + ValidateDoneAndOffset(11, MakeValidator("start end.", kUncompressed)); + absl::string_view data("\x0astart end."); + DecodeBuffer b(data); + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, kMayReturnZeroOnFirst, validator)); + } +} + +TEST_F(HpackStringDecoderTest, DecodeLongStrings) { + std::string name = Random().RandString(1024); + std::string value = Random().RandString(65536); + HpackBlockBuilder hbb; + + hbb.AppendString(false, name); + uint32_t offset_after_name = hbb.size(); + EXPECT_EQ(3 + name.size(), offset_after_name); + + hbb.AppendString(true, value); + uint32_t offset_after_value = hbb.size(); + EXPECT_EQ(3 + name.size() + 4 + value.size(), offset_after_value); + + DecodeBuffer b(hbb.buffer()); + + // Decode the name... + EXPECT_TRUE(DecodeAndValidateSeveralWays( + &b, kMayReturnZeroOnFirst, + ValidateDoneAndOffset(offset_after_name, + MakeValidator(name, kUncompressed)))); + EXPECT_EQ(offset_after_name, b.Offset()); + EXPECT_EQ(offset_after_value - offset_after_name, b.Remaining()); + + // Decode the value... + EXPECT_TRUE(DecodeAndValidateSeveralWays( + &b, kMayReturnZeroOnFirst, + ValidateDoneAndOffset(offset_after_value - offset_after_name, + MakeValidator(value, kCompressed)))); + EXPECT_EQ(offset_after_value, b.Offset()); + EXPECT_EQ(0u, b.Remaining()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc b/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc new file mode 100644 index 000000000000..8cf2e14af7da --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc @@ -0,0 +1,152 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h" + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_text_utils.h" + +namespace http2 { + +HpackWholeEntryBuffer::HpackWholeEntryBuffer(HpackWholeEntryListener* listener, + size_t max_string_size_bytes) + : max_string_size_bytes_(max_string_size_bytes) { + set_listener(listener); +} +HpackWholeEntryBuffer::~HpackWholeEntryBuffer() = default; + +void HpackWholeEntryBuffer::set_listener(HpackWholeEntryListener* listener) { + QUICHE_CHECK(listener); + listener_ = listener; +} + +void HpackWholeEntryBuffer::set_max_string_size_bytes( + size_t max_string_size_bytes) { + max_string_size_bytes_ = max_string_size_bytes; +} + +void HpackWholeEntryBuffer::BufferStringsIfUnbuffered() { + name_.BufferStringIfUnbuffered(); + value_.BufferStringIfUnbuffered(); +} + +void HpackWholeEntryBuffer::OnIndexedHeader(size_t index) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnIndexedHeader: index=" << index; + listener_->OnIndexedHeader(index); +} + +void HpackWholeEntryBuffer::OnStartLiteralHeader(HpackEntryType entry_type, + size_t maybe_name_index) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnStartLiteralHeader: entry_type=" + << entry_type << ", maybe_name_index=" << maybe_name_index; + entry_type_ = entry_type; + maybe_name_index_ = maybe_name_index; +} + +void HpackWholeEntryBuffer::OnNameStart(bool huffman_encoded, size_t len) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnNameStart: huffman_encoded=" + << (huffman_encoded ? "true" : "false") << ", len=" << len; + QUICHE_DCHECK_EQ(maybe_name_index_, 0u); + if (!error_detected_) { + if (len > max_string_size_bytes_) { + QUICHE_DVLOG(1) << "Name length (" << len + << ") is longer than permitted (" + << max_string_size_bytes_ << ")"; + ReportError(HpackDecodingError::kNameTooLong, ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 18, 23); + return; + } + name_.OnStart(huffman_encoded, len); + } +} + +void HpackWholeEntryBuffer::OnNameData(const char* data, size_t len) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnNameData: len=" << len + << " data:\n" + << quiche::QuicheTextUtils::HexDump( + absl::string_view(data, len)); + QUICHE_DCHECK_EQ(maybe_name_index_, 0u); + if (!error_detected_ && !name_.OnData(data, len)) { + ReportError(HpackDecodingError::kNameHuffmanError, ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 19, 23); + } +} + +void HpackWholeEntryBuffer::OnNameEnd() { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnNameEnd"; + QUICHE_DCHECK_EQ(maybe_name_index_, 0u); + if (!error_detected_ && !name_.OnEnd()) { + ReportError(HpackDecodingError::kNameHuffmanError, ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 20, 23); + } +} + +void HpackWholeEntryBuffer::OnValueStart(bool huffman_encoded, size_t len) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnValueStart: huffman_encoded=" + << (huffman_encoded ? "true" : "false") << ", len=" << len; + if (!error_detected_) { + if (len > max_string_size_bytes_) { + std::string detailed_error = absl::StrCat( + "Value length (", len, ") of [", name_.GetStringIfComplete(), + "] is longer than permitted (", max_string_size_bytes_, ")"); + QUICHE_DVLOG(1) << detailed_error; + ReportError(HpackDecodingError::kValueTooLong, detailed_error); + QUICHE_CODE_COUNT_N(decompress_failure_3, 21, 23); + return; + } + value_.OnStart(huffman_encoded, len); + } +} + +void HpackWholeEntryBuffer::OnValueData(const char* data, size_t len) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnValueData: len=" << len + << " data:\n" + << quiche::QuicheTextUtils::HexDump( + absl::string_view(data, len)); + if (!error_detected_ && !value_.OnData(data, len)) { + ReportError(HpackDecodingError::kValueHuffmanError, ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 22, 23); + } +} + +void HpackWholeEntryBuffer::OnValueEnd() { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnValueEnd"; + if (error_detected_) { + return; + } + if (!value_.OnEnd()) { + ReportError(HpackDecodingError::kValueHuffmanError, ""); + QUICHE_CODE_COUNT_N(decompress_failure_3, 23, 23); + return; + } + if (maybe_name_index_ == 0) { + listener_->OnLiteralNameAndValue(entry_type_, &name_, &value_); + name_.Reset(); + } else { + listener_->OnNameIndexAndLiteralValue(entry_type_, maybe_name_index_, + &value_); + } + value_.Reset(); +} + +void HpackWholeEntryBuffer::OnDynamicTableSizeUpdate(size_t size) { + QUICHE_DVLOG(2) << "HpackWholeEntryBuffer::OnDynamicTableSizeUpdate: size=" + << size; + listener_->OnDynamicTableSizeUpdate(size); +} + +void HpackWholeEntryBuffer::ReportError(HpackDecodingError error, + std::string detailed_error) { + if (!error_detected_) { + QUICHE_DVLOG(1) << "HpackWholeEntryBuffer::ReportError: " + << HpackDecodingErrorToString(error); + error_detected_ = true; + listener_->OnHpackDecodeError(error, detailed_error); + listener_ = HpackWholeEntryNoOpListener::NoOpListener(); + } +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h b/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h new file mode 100644 index 000000000000..b7bd1088e7bc --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h @@ -0,0 +1,101 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_WHOLE_ENTRY_BUFFER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_WHOLE_ENTRY_BUFFER_H_ + +// HpackWholeEntryBuffer isolates a listener from the fact that an entry may +// be split across multiple input buffers, providing one callback per entry. +// HpackWholeEntryBuffer requires that the HpackEntryDecoderListener be made in +// the correct order, which is tested by hpack_entry_decoder_test.cc. + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h" +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" +#include "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h" +#include "quiche/http2/hpack/decoder/hpack_whole_entry_listener.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// TODO(jamessynge): Consider renaming HpackEntryDecoderListener to +// HpackEntryPartsListener or HpackEntryFragmentsListener. +class QUICHE_EXPORT HpackWholeEntryBuffer : public HpackEntryDecoderListener { + public: + // max_string_size specifies the maximum size of an on-the-wire string (name + // or value, plain or Huffman encoded) that will be accepted. See sections + // 5.1 and 5.2 of RFC 7541. This is a defense against OOM attacks; HTTP/2 + // allows a decoder to enforce any limit of the size of the header lists + // that it is willing decode, including less than the MAX_HEADER_LIST_SIZE + // setting, a setting that is initially unlimited. For example, we might + // choose to send a MAX_HEADER_LIST_SIZE of 64KB, and to use that same value + // as the upper bound for individual strings. + HpackWholeEntryBuffer(HpackWholeEntryListener* listener, + size_t max_string_size); + ~HpackWholeEntryBuffer() override; + + HpackWholeEntryBuffer(const HpackWholeEntryBuffer&) = delete; + HpackWholeEntryBuffer& operator=(const HpackWholeEntryBuffer&) = delete; + + // Set the listener to be notified when a whole entry has been decoded. + // The listener may be changed at any time. + void set_listener(HpackWholeEntryListener* listener); + + // Set how much encoded data this decoder is willing to buffer. + // TODO(jamessynge): Come up with consistent semantics for this protection + // across the various decoders; e.g. should it be for a single string or + // a single header entry? + void set_max_string_size_bytes(size_t max_string_size_bytes); + + // Ensure that decoded strings pointed to by the HpackDecoderStringBuffer + // instances name_ and value_ are buffered, which allows any underlying + // transport buffer to be freed or reused without overwriting the decoded + // strings. This is needed only when an HPACK entry is split across transport + // buffers. See HpackDecoder::DecodeFragment. + void BufferStringsIfUnbuffered(); + + // Was an error detected? After an error has been detected and reported, + // no further callbacks will be made to the listener. + bool error_detected() const { return error_detected_; } + + // Implement the HpackEntryDecoderListener methods. + + void OnIndexedHeader(size_t index) override; + void OnStartLiteralHeader(HpackEntryType entry_type, + size_t maybe_name_index) override; + void OnNameStart(bool huffman_encoded, size_t len) override; + void OnNameData(const char* data, size_t len) override; + void OnNameEnd() override; + void OnValueStart(bool huffman_encoded, size_t len) override; + void OnValueData(const char* data, size_t len) override; + void OnValueEnd() override; + void OnDynamicTableSizeUpdate(size_t size) override; + + private: + void ReportError(HpackDecodingError error, std::string detailed_error); + + HpackWholeEntryListener* listener_; + HpackDecoderStringBuffer name_, value_; + + // max_string_size_bytes_ specifies the maximum allowed size of an on-the-wire + // string. Larger strings will be reported as errors to the listener; the + // endpoint should treat these as COMPRESSION errors, which are CONNECTION + // level errors. + size_t max_string_size_bytes_; + + // The name index (or zero) of the current header entry with a literal value. + size_t maybe_name_index_; + + // The type of the current header entry (with literals) that is being decoded. + HpackEntryType entry_type_; + + bool error_detected_ = false; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_WHOLE_ENTRY_BUFFER_H_ diff --git a/quiche/http2/hpack/decoder/hpack_whole_entry_buffer_test.cc b/quiche/http2/hpack/decoder/hpack_whole_entry_buffer_test.cc new file mode 100644 index 000000000000..d2bb81c7b56c --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_whole_entry_buffer_test.cc @@ -0,0 +1,226 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_whole_entry_buffer.h" + +// Tests of HpackWholeEntryBuffer: does it buffer correctly, and does it +// detect Huffman decoding errors and oversize string errors? + +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::_; +using ::testing::AllOf; +using ::testing::InSequence; +using ::testing::Property; +using ::testing::StrictMock; + +namespace http2 { +namespace test { +namespace { + +constexpr size_t kMaxStringSize = 20; + +class MockHpackWholeEntryListener : public HpackWholeEntryListener { + public: + ~MockHpackWholeEntryListener() override = default; + + MOCK_METHOD(void, OnIndexedHeader, (size_t index), (override)); + MOCK_METHOD(void, OnNameIndexAndLiteralValue, + (HpackEntryType entry_type, size_t name_index, + HpackDecoderStringBuffer* value_buffer), + (override)); + MOCK_METHOD(void, OnLiteralNameAndValue, + (HpackEntryType entry_type, HpackDecoderStringBuffer* name_buffer, + HpackDecoderStringBuffer* value_buffer), + (override)); + MOCK_METHOD(void, OnDynamicTableSizeUpdate, (size_t size), (override)); + MOCK_METHOD(void, OnHpackDecodeError, + (HpackDecodingError error, std::string detailed_error), + (override)); +}; + +class HpackWholeEntryBufferTest : public quiche::test::QuicheTest { + protected: + HpackWholeEntryBufferTest() : entry_buffer_(&listener_, kMaxStringSize) {} + ~HpackWholeEntryBufferTest() override = default; + + StrictMock listener_; + HpackWholeEntryBuffer entry_buffer_; +}; + +// OnIndexedHeader is an immediate pass through. +TEST_F(HpackWholeEntryBufferTest, OnIndexedHeader) { + { + InSequence seq; + EXPECT_CALL(listener_, OnIndexedHeader(17)); + entry_buffer_.OnIndexedHeader(17); + } + { + InSequence seq; + EXPECT_CALL(listener_, OnIndexedHeader(62)); + entry_buffer_.OnIndexedHeader(62); + } + { + InSequence seq; + EXPECT_CALL(listener_, OnIndexedHeader(62)); + entry_buffer_.OnIndexedHeader(62); + } + { + InSequence seq; + EXPECT_CALL(listener_, OnIndexedHeader(128)); + entry_buffer_.OnIndexedHeader(128); + } + StrictMock listener2; + entry_buffer_.set_listener(&listener2); + { + InSequence seq; + EXPECT_CALL(listener2, OnIndexedHeader(100)); + entry_buffer_.OnIndexedHeader(100); + } +} + +// OnDynamicTableSizeUpdate is an immediate pass through. +TEST_F(HpackWholeEntryBufferTest, OnDynamicTableSizeUpdate) { + { + InSequence seq; + EXPECT_CALL(listener_, OnDynamicTableSizeUpdate(4096)); + entry_buffer_.OnDynamicTableSizeUpdate(4096); + } + { + InSequence seq; + EXPECT_CALL(listener_, OnDynamicTableSizeUpdate(0)); + entry_buffer_.OnDynamicTableSizeUpdate(0); + } + { + InSequence seq; + EXPECT_CALL(listener_, OnDynamicTableSizeUpdate(1024)); + entry_buffer_.OnDynamicTableSizeUpdate(1024); + } + { + InSequence seq; + EXPECT_CALL(listener_, OnDynamicTableSizeUpdate(1024)); + entry_buffer_.OnDynamicTableSizeUpdate(1024); + } + StrictMock listener2; + entry_buffer_.set_listener(&listener2); + { + InSequence seq; + EXPECT_CALL(listener2, OnDynamicTableSizeUpdate(0)); + entry_buffer_.OnDynamicTableSizeUpdate(0); + } +} + +TEST_F(HpackWholeEntryBufferTest, OnNameIndexAndLiteralValue) { + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kNeverIndexedLiteralHeader, + 123); + entry_buffer_.OnValueStart(false, 10); + entry_buffer_.OnValueData("some data.", 10); + + // Force the value to be buffered. + entry_buffer_.BufferStringsIfUnbuffered(); + + EXPECT_CALL( + listener_, + OnNameIndexAndLiteralValue( + HpackEntryType::kNeverIndexedLiteralHeader, 123, + AllOf(Property(&HpackDecoderStringBuffer::str, "some data."), + Property(&HpackDecoderStringBuffer::BufferedLength, 10)))); + + entry_buffer_.OnValueEnd(); +} + +TEST_F(HpackWholeEntryBufferTest, OnLiteralNameAndValue) { + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 0); + // Force the name to be buffered by delivering it in two pieces. + entry_buffer_.OnNameStart(false, 9); + entry_buffer_.OnNameData("some-", 5); + entry_buffer_.OnNameData("name", 4); + entry_buffer_.OnNameEnd(); + entry_buffer_.OnValueStart(false, 12); + entry_buffer_.OnValueData("Header Value", 12); + + EXPECT_CALL( + listener_, + OnLiteralNameAndValue( + HpackEntryType::kIndexedLiteralHeader, + AllOf(Property(&HpackDecoderStringBuffer::str, "some-name"), + Property(&HpackDecoderStringBuffer::BufferedLength, 9)), + AllOf(Property(&HpackDecoderStringBuffer::str, "Header Value"), + Property(&HpackDecoderStringBuffer::BufferedLength, 0)))); + + entry_buffer_.OnValueEnd(); +} + +// Verify that a name longer than the allowed size generates an error. +TEST_F(HpackWholeEntryBufferTest, NameTooLong) { + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 0); + EXPECT_CALL(listener_, + OnHpackDecodeError(HpackDecodingError::kNameTooLong, _)); + entry_buffer_.OnNameStart(false, kMaxStringSize + 1); +} + +// Verify that a value longer than the allowed size generates an error. +TEST_F(HpackWholeEntryBufferTest, ValueTooLong) { + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 0); + EXPECT_CALL(listener_, + OnHpackDecodeError( + HpackDecodingError::kValueTooLong, + "Value length (21) of [path] is longer than permitted (20)")); + entry_buffer_.OnNameStart(false, 4); + entry_buffer_.OnNameData("path", 4); + entry_buffer_.OnNameEnd(); + entry_buffer_.OnValueStart(false, kMaxStringSize + 1); +} + +// Regression test for b/162141899. +TEST_F(HpackWholeEntryBufferTest, ValueTooLongWithoutName) { + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kIndexedLiteralHeader, 1); + EXPECT_CALL(listener_, + OnHpackDecodeError( + HpackDecodingError::kValueTooLong, + "Value length (21) of [] is longer than permitted (20)")); + entry_buffer_.OnValueStart(false, kMaxStringSize + 1); +} + +// Verify that a Huffman encoded name with an explicit EOS generates an error +// for an explicit EOS. +TEST_F(HpackWholeEntryBufferTest, NameHuffmanError) { + const char data[] = "\xff\xff\xff"; + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kUnindexedLiteralHeader, + 0); + entry_buffer_.OnNameStart(true, 4); + entry_buffer_.OnNameData(data, 3); + + EXPECT_CALL(listener_, + OnHpackDecodeError(HpackDecodingError::kNameHuffmanError, _)); + + entry_buffer_.OnNameData(data, 1); + + // After an error is reported, the listener is not called again. + EXPECT_CALL(listener_, OnDynamicTableSizeUpdate(8096)).Times(0); + entry_buffer_.OnDynamicTableSizeUpdate(8096); +} + +// Verify that a Huffman encoded value that isn't properly terminated with +// a partial EOS symbol generates an error. +TEST_F(HpackWholeEntryBufferTest, ValueHuffmanError) { + const char data[] = "\x00\x00\x00"; + entry_buffer_.OnStartLiteralHeader(HpackEntryType::kNeverIndexedLiteralHeader, + 61); + entry_buffer_.OnValueStart(true, 3); + entry_buffer_.OnValueData(data, 3); + + EXPECT_CALL(listener_, + OnHpackDecodeError(HpackDecodingError::kValueHuffmanError, _)); + + entry_buffer_.OnValueEnd(); + + // After an error is reported, the listener is not called again. + EXPECT_CALL(listener_, OnIndexedHeader(17)).Times(0); + entry_buffer_.OnIndexedHeader(17); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_whole_entry_listener.cc b/quiche/http2/hpack/decoder/hpack_whole_entry_listener.cc new file mode 100644 index 000000000000..d6fe82b65723 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_whole_entry_listener.cc @@ -0,0 +1,31 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/decoder/hpack_whole_entry_listener.h" + +namespace http2 { + +HpackWholeEntryListener::~HpackWholeEntryListener() = default; + +HpackWholeEntryNoOpListener::~HpackWholeEntryNoOpListener() = default; + +void HpackWholeEntryNoOpListener::OnIndexedHeader(size_t /*index*/) {} +void HpackWholeEntryNoOpListener::OnNameIndexAndLiteralValue( + HpackEntryType /*entry_type*/, size_t /*name_index*/, + HpackDecoderStringBuffer* /*value_buffer*/) {} +void HpackWholeEntryNoOpListener::OnLiteralNameAndValue( + HpackEntryType /*entry_type*/, HpackDecoderStringBuffer* /*name_buffer*/, + HpackDecoderStringBuffer* /*value_buffer*/) {} +void HpackWholeEntryNoOpListener::OnDynamicTableSizeUpdate(size_t /*size*/) {} +void HpackWholeEntryNoOpListener::OnHpackDecodeError( + HpackDecodingError /*error*/, std::string /*detailed_error*/) {} + +// static +HpackWholeEntryNoOpListener* HpackWholeEntryNoOpListener::NoOpListener() { + static HpackWholeEntryNoOpListener* static_instance = + new HpackWholeEntryNoOpListener(); + return static_instance; +} + +} // namespace http2 diff --git a/quiche/http2/hpack/decoder/hpack_whole_entry_listener.h b/quiche/http2/hpack/decoder/hpack_whole_entry_listener.h new file mode 100644 index 000000000000..54e45063fee5 --- /dev/null +++ b/quiche/http2/hpack/decoder/hpack_whole_entry_listener.h @@ -0,0 +1,80 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Defines HpackWholeEntryListener, the base class of listeners for decoded +// complete HPACK entries, as opposed to HpackEntryDecoderListener which +// receives multiple callbacks for some single entries. + +#ifndef QUICHE_HTTP2_HPACK_DECODER_HPACK_WHOLE_ENTRY_LISTENER_H_ +#define QUICHE_HTTP2_HPACK_DECODER_HPACK_WHOLE_ENTRY_LISTENER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_string_buffer.h" +#include "quiche/http2/hpack/decoder/hpack_decoding_error.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +class QUICHE_EXPORT HpackWholeEntryListener { + public: + virtual ~HpackWholeEntryListener(); + + // Called when an indexed header (i.e. one in the static or dynamic table) has + // been decoded from an HPACK block. index is supposed to be non-zero, but + // that has not been checked by the caller. + virtual void OnIndexedHeader(size_t index) = 0; + + // Called when a header entry with a name index and literal value has + // been fully decoded from an HPACK block. name_index is NOT zero. + // entry_type will be kIndexedLiteralHeader, kUnindexedLiteralHeader, or + // kNeverIndexedLiteralHeader. + virtual void OnNameIndexAndLiteralValue( + HpackEntryType entry_type, size_t name_index, + HpackDecoderStringBuffer* value_buffer) = 0; + + // Called when a header entry with a literal name and literal value + // has been fully decoded from an HPACK block. entry_type will be + // kIndexedLiteralHeader, kUnindexedLiteralHeader, or + // kNeverIndexedLiteralHeader. + virtual void OnLiteralNameAndValue( + HpackEntryType entry_type, HpackDecoderStringBuffer* name_buffer, + HpackDecoderStringBuffer* value_buffer) = 0; + + // Called when an update to the size of the peer's dynamic table has been + // decoded. + virtual void OnDynamicTableSizeUpdate(size_t size) = 0; + + // OnHpackDecodeError is called if an error is detected while decoding. + virtual void OnHpackDecodeError(HpackDecodingError error, + std::string detailed_error) = 0; +}; + +// A no-op implementation of HpackWholeEntryDecoderListener, useful for ignoring +// callbacks once an error is detected. +class QUICHE_EXPORT HpackWholeEntryNoOpListener + : public HpackWholeEntryListener { + public: + ~HpackWholeEntryNoOpListener() override; + + void OnIndexedHeader(size_t index) override; + void OnNameIndexAndLiteralValue( + HpackEntryType entry_type, size_t name_index, + HpackDecoderStringBuffer* value_buffer) override; + void OnLiteralNameAndValue(HpackEntryType entry_type, + HpackDecoderStringBuffer* name_buffer, + HpackDecoderStringBuffer* value_buffer) override; + void OnDynamicTableSizeUpdate(size_t size) override; + void OnHpackDecodeError(HpackDecodingError error, + std::string detailed_error) override; + + // Returns a listener that ignores all the calls. + static HpackWholeEntryNoOpListener* NoOpListener(); +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_DECODER_HPACK_WHOLE_ENTRY_LISTENER_H_ diff --git a/quiche/http2/hpack/hpack_static_table_entries.inc b/quiche/http2/hpack/hpack_static_table_entries.inc new file mode 100644 index 000000000000..c6ae125f3b22 --- /dev/null +++ b/quiche/http2/hpack/hpack_static_table_entries.inc @@ -0,0 +1,65 @@ +// This file is designed to be included by C/C++ files which need the contents +// of the HPACK static table. It may be included more than once if necessary. +// See http://httpwg.org/specs/rfc7541.html#static.table.definition + +STATIC_TABLE_ENTRY(":authority", "", 1); +STATIC_TABLE_ENTRY(":method", "GET", 2); +STATIC_TABLE_ENTRY(":method", "POST", 3); +STATIC_TABLE_ENTRY(":path", "/", 4); +STATIC_TABLE_ENTRY(":path", "/index.html", 5); +STATIC_TABLE_ENTRY(":scheme", "http", 6); +STATIC_TABLE_ENTRY(":scheme", "https", 7); +STATIC_TABLE_ENTRY(":status", "200", 8); +STATIC_TABLE_ENTRY(":status", "204", 9); +STATIC_TABLE_ENTRY(":status", "206", 10); +STATIC_TABLE_ENTRY(":status", "304", 11); +STATIC_TABLE_ENTRY(":status", "400", 12); +STATIC_TABLE_ENTRY(":status", "404", 13); +STATIC_TABLE_ENTRY(":status", "500", 14); +STATIC_TABLE_ENTRY("accept-charset", "", 15); +STATIC_TABLE_ENTRY("accept-encoding", "gzip, deflate", 16); +STATIC_TABLE_ENTRY("accept-language", "", 17); +STATIC_TABLE_ENTRY("accept-ranges", "", 18); +STATIC_TABLE_ENTRY("accept", "", 19); +STATIC_TABLE_ENTRY("access-control-allow-origin", "", 20); +STATIC_TABLE_ENTRY("age", "", 21); +STATIC_TABLE_ENTRY("allow", "", 22); +STATIC_TABLE_ENTRY("authorization", "", 23); +STATIC_TABLE_ENTRY("cache-control", "", 24); +STATIC_TABLE_ENTRY("content-disposition", "", 25); +STATIC_TABLE_ENTRY("content-encoding", "", 26); +STATIC_TABLE_ENTRY("content-language", "", 27); +STATIC_TABLE_ENTRY("content-length", "", 28); +STATIC_TABLE_ENTRY("content-location", "", 29); +STATIC_TABLE_ENTRY("content-range", "", 30); +STATIC_TABLE_ENTRY("content-type", "", 31); +STATIC_TABLE_ENTRY("cookie", "", 32); +STATIC_TABLE_ENTRY("date", "", 33); +STATIC_TABLE_ENTRY("etag", "", 34); +STATIC_TABLE_ENTRY("expect", "", 35); +STATIC_TABLE_ENTRY("expires", "", 36); +STATIC_TABLE_ENTRY("from", "", 37); +STATIC_TABLE_ENTRY("host", "", 38); +STATIC_TABLE_ENTRY("if-match", "", 39); +STATIC_TABLE_ENTRY("if-modified-since", "", 40); +STATIC_TABLE_ENTRY("if-none-match", "", 41); +STATIC_TABLE_ENTRY("if-range", "", 42); +STATIC_TABLE_ENTRY("if-unmodified-since", "", 43); +STATIC_TABLE_ENTRY("last-modified", "", 44); +STATIC_TABLE_ENTRY("link", "", 45); +STATIC_TABLE_ENTRY("location", "", 46); +STATIC_TABLE_ENTRY("max-forwards", "", 47); +STATIC_TABLE_ENTRY("proxy-authenticate", "", 48); +STATIC_TABLE_ENTRY("proxy-authorization", "", 49); +STATIC_TABLE_ENTRY("range", "", 50); +STATIC_TABLE_ENTRY("referer", "", 51); +STATIC_TABLE_ENTRY("refresh", "", 52); +STATIC_TABLE_ENTRY("retry-after", "", 53); +STATIC_TABLE_ENTRY("server", "", 54); +STATIC_TABLE_ENTRY("set-cookie", "", 55); +STATIC_TABLE_ENTRY("strict-transport-security", "", 56); +STATIC_TABLE_ENTRY("transfer-encoding", "", 57); +STATIC_TABLE_ENTRY("user-agent", "", 58); +STATIC_TABLE_ENTRY("vary", "", 59); +STATIC_TABLE_ENTRY("via", "", 60); +STATIC_TABLE_ENTRY("www-authenticate", "", 61); diff --git a/quiche/http2/hpack/http2_hpack_constants.cc b/quiche/http2/hpack/http2_hpack_constants.cc new file mode 100644 index 000000000000..e4a71b8fe9f6 --- /dev/null +++ b/quiche/http2/hpack/http2_hpack_constants.cc @@ -0,0 +1,31 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/http2_hpack_constants.h" + +#include "absl/strings/str_cat.h" + +namespace http2 { + +std::string HpackEntryTypeToString(HpackEntryType v) { + switch (v) { + case HpackEntryType::kIndexedHeader: + return "kIndexedHeader"; + case HpackEntryType::kDynamicTableSizeUpdate: + return "kDynamicTableSizeUpdate"; + case HpackEntryType::kIndexedLiteralHeader: + return "kIndexedLiteralHeader"; + case HpackEntryType::kUnindexedLiteralHeader: + return "kUnindexedLiteralHeader"; + case HpackEntryType::kNeverIndexedLiteralHeader: + return "kNeverIndexedLiteralHeader"; + } + return absl::StrCat("UnknownHpackEntryType(", static_cast(v), ")"); +} + +std::ostream& operator<<(std::ostream& out, HpackEntryType v) { + return out << HpackEntryTypeToString(v); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/http2_hpack_constants.h b/quiche/http2/hpack/http2_hpack_constants.h new file mode 100644 index 000000000000..1deaf05bce00 --- /dev/null +++ b/quiche/http2/hpack/http2_hpack_constants.h @@ -0,0 +1,62 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_HTTP2_HPACK_CONSTANTS_H_ +#define QUICHE_HTTP2_HPACK_HTTP2_HPACK_CONSTANTS_H_ + +// Enum HpackEntryType identifies the 5 basic types of HPACK Block Entries. +// +// See the spec for details: +// https://http2.github.io/http2-spec/compression.html#rfc.section.6 + +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +const size_t kFirstDynamicTableIndex = 62; + +enum class HpackEntryType { + // Entry is an index into the static or dynamic table. Decoding it has no + // effect on the dynamic table. + kIndexedHeader, + + // The entry contains a literal value. The name may be either a literal or a + // reference to an entry in the static or dynamic table. + // The entry is added to the dynamic table after decoding. + kIndexedLiteralHeader, + + // The entry contains a literal value. The name may be either a literal or a + // reference to an entry in the static or dynamic table. + // The entry is not added to the dynamic table after decoding, but a proxy + // may choose to insert the entry into its dynamic table when forwarding + // to another endpoint. + kUnindexedLiteralHeader, + + // The entry contains a literal value. The name may be either a literal or a + // reference to an entry in the static or dynamic table. + // The entry is not added to the dynamic table after decoding, and a proxy + // must NOT insert the entry into its dynamic table when forwarding to another + // endpoint. + kNeverIndexedLiteralHeader, + + // Entry conveys the size limit of the dynamic table of the encoder to + // the decoder. May be used to flush the table by sending a zero and then + // resetting the size back up to the maximum that the encoder will use + // (within the limits of SETTINGS_HEADER_TABLE_SIZE sent by the + // decoder to the encoder, with the default of 4096 assumed). + kDynamicTableSizeUpdate, +}; + +// Returns the name of the enum member. +QUICHE_EXPORT std::string HpackEntryTypeToString(HpackEntryType v); + +// Inserts the name of the enum member into |out|. +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, HpackEntryType v); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_HTTP2_HPACK_CONSTANTS_H_ diff --git a/quiche/http2/hpack/http2_hpack_constants_test.cc b/quiche/http2/hpack/http2_hpack_constants_test.cc new file mode 100644 index 000000000000..39bacd4c930a --- /dev/null +++ b/quiche/http2/hpack/http2_hpack_constants_test.cc @@ -0,0 +1,66 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/http2_hpack_constants.h" + +#include + +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +TEST(HpackEntryTypeTest, HpackEntryTypeToString) { + EXPECT_EQ("kIndexedHeader", + HpackEntryTypeToString(HpackEntryType::kIndexedHeader)); + EXPECT_EQ("kDynamicTableSizeUpdate", + HpackEntryTypeToString(HpackEntryType::kDynamicTableSizeUpdate)); + EXPECT_EQ("kIndexedLiteralHeader", + HpackEntryTypeToString(HpackEntryType::kIndexedLiteralHeader)); + EXPECT_EQ("kUnindexedLiteralHeader", + HpackEntryTypeToString(HpackEntryType::kUnindexedLiteralHeader)); + EXPECT_EQ("kNeverIndexedLiteralHeader", + HpackEntryTypeToString(HpackEntryType::kNeverIndexedLiteralHeader)); + EXPECT_EQ("UnknownHpackEntryType(12321)", + HpackEntryTypeToString(static_cast(12321))); +} + +TEST(HpackEntryTypeTest, OutputHpackEntryType) { + { + std::stringstream log; + log << HpackEntryType::kIndexedHeader; + EXPECT_EQ("kIndexedHeader", log.str()); + } + { + std::stringstream log; + log << HpackEntryType::kDynamicTableSizeUpdate; + EXPECT_EQ("kDynamicTableSizeUpdate", log.str()); + } + { + std::stringstream log; + log << HpackEntryType::kIndexedLiteralHeader; + EXPECT_EQ("kIndexedLiteralHeader", log.str()); + } + { + std::stringstream log; + log << HpackEntryType::kUnindexedLiteralHeader; + EXPECT_EQ("kUnindexedLiteralHeader", log.str()); + } + { + std::stringstream log; + log << HpackEntryType::kNeverIndexedLiteralHeader; + EXPECT_EQ("kNeverIndexedLiteralHeader", log.str()); + } + { + std::stringstream log; + log << static_cast(1234321); + EXPECT_EQ("UnknownHpackEntryType(1234321)", log.str()); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/hpack_huffman_decoder.cc b/quiche/http2/hpack/huffman/hpack_huffman_decoder.cc new file mode 100644 index 000000000000..3727557d2711 --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_decoder.cc @@ -0,0 +1,483 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/huffman/hpack_huffman_decoder.h" + +#include +#include + +#include "quiche/common/platform/api/quiche_logging.h" + +// Terminology: +// +// Symbol - a plain text (unencoded) character (uint8), or the End-of-String +// (EOS) symbol, 256. +// +// Code - the sequence of bits used to encode a symbol, varying in length from +// 5 bits for the most common symbols (e.g. '0', '1', and 'a'), to +// 30 bits for the least common (e.g. the EOS symbol). +// For those symbols whose codes have the same length, their code values +// are sorted such that the lower symbol value has a lower code value. +// +// Canonical - a symbol's cardinal value when sorted first by code length, and +// then by symbol value. For example, canonical 0 is for ASCII '0' +// (uint8 value 0x30), which is the first of the symbols whose code +// is 5 bits long, and the last canonical is EOS, which is the last +// of the symbols whose code is 30 bits long. + +namespace http2 { +namespace { + +// HuffmanCode is used to store the codes associated with symbols (a pattern of +// from 5 to 30 bits). +typedef uint32_t HuffmanCode; + +// HuffmanCodeBitCount is used to store a count of bits in a code. +typedef uint16_t HuffmanCodeBitCount; + +// HuffmanCodeBitSet is used for producing a string version of a code because +// std::bitset logs nicely. +typedef std::bitset<32> HuffmanCodeBitSet; +typedef std::bitset<64> HuffmanAccumulatorBitSet; + +static constexpr HuffmanCodeBitCount kMinCodeBitCount = 5; +static constexpr HuffmanCodeBitCount kMaxCodeBitCount = 30; +static constexpr HuffmanCodeBitCount kHuffmanCodeBitCount = + std::numeric_limits::digits; + +static_assert(std::numeric_limits::digits >= kMaxCodeBitCount, + "HuffmanCode isn't big enough."); + +static_assert(std::numeric_limits::digits >= + kMaxCodeBitCount, + "HuffmanAccumulator isn't big enough."); + +static constexpr HuffmanAccumulatorBitCount kHuffmanAccumulatorBitCount = + std::numeric_limits::digits; +static constexpr HuffmanAccumulatorBitCount kExtraAccumulatorBitCount = + kHuffmanAccumulatorBitCount - kHuffmanCodeBitCount; + +// PrefixInfo holds info about a group of codes that are all of the same length. +struct PrefixInfo { + // Given the leading bits (32 in this case) of the encoded string, and that + // they start with a code of length |code_length|, return the corresponding + // canonical for that leading code. + uint32_t DecodeToCanonical(HuffmanCode bits) const { + // What is the position of the canonical symbol being decoded within + // the canonical symbols of |length|? + HuffmanCode ordinal_in_length = + ((bits - first_code) >> (kHuffmanCodeBitCount - code_length)); + + // Combined with |canonical| to produce the position of the canonical symbol + // being decoded within all of the canonical symbols. + return first_canonical + ordinal_in_length; + } + + const HuffmanCode first_code; // First code of this length, left justified in + // the field (i.e. the first bit of the code is + // the high-order bit). + const uint16_t code_length; // Length of the prefix code |base|. + const uint16_t first_canonical; // First canonical symbol of this length. +}; + +inline std::ostream& operator<<(std::ostream& out, const PrefixInfo& v) { + return out << "{first_code: " << HuffmanCodeBitSet(v.first_code) + << ", code_length: " << v.code_length + << ", first_canonical: " << v.first_canonical << "}"; +} + +// Given |value|, a sequence of the leading bits remaining to be decoded, +// figure out which group of canonicals (by code length) that value starts +// with. This function was generated. +PrefixInfo PrefixToInfo(HuffmanCode value) { + if (value < 0b10111000000000000000000000000000) { + if (value < 0b01010000000000000000000000000000) { + return {0b00000000000000000000000000000000, 5, 0}; + } else { + return {0b01010000000000000000000000000000, 6, 10}; + } + } else { + if (value < 0b11111110000000000000000000000000) { + if (value < 0b11111000000000000000000000000000) { + return {0b10111000000000000000000000000000, 7, 36}; + } else { + return {0b11111000000000000000000000000000, 8, 68}; + } + } else { + if (value < 0b11111111110000000000000000000000) { + if (value < 0b11111111101000000000000000000000) { + if (value < 0b11111111010000000000000000000000) { + return {0b11111110000000000000000000000000, 10, 74}; + } else { + return {0b11111111010000000000000000000000, 11, 79}; + } + } else { + return {0b11111111101000000000000000000000, 12, 82}; + } + } else { + if (value < 0b11111111111111100000000000000000) { + if (value < 0b11111111111110000000000000000000) { + if (value < 0b11111111111100000000000000000000) { + return {0b11111111110000000000000000000000, 13, 84}; + } else { + return {0b11111111111100000000000000000000, 14, 90}; + } + } else { + return {0b11111111111110000000000000000000, 15, 92}; + } + } else { + if (value < 0b11111111111111110100100000000000) { + if (value < 0b11111111111111101110000000000000) { + if (value < 0b11111111111111100110000000000000) { + return {0b11111111111111100000000000000000, 19, 95}; + } else { + return {0b11111111111111100110000000000000, 20, 98}; + } + } else { + return {0b11111111111111101110000000000000, 21, 106}; + } + } else { + if (value < 0b11111111111111111110101000000000) { + if (value < 0b11111111111111111011000000000000) { + return {0b11111111111111110100100000000000, 22, 119}; + } else { + return {0b11111111111111111011000000000000, 23, 145}; + } + } else { + if (value < 0b11111111111111111111101111000000) { + if (value < 0b11111111111111111111100000000000) { + if (value < 0b11111111111111111111011000000000) { + return {0b11111111111111111110101000000000, 24, 174}; + } else { + return {0b11111111111111111111011000000000, 25, 186}; + } + } else { + return {0b11111111111111111111100000000000, 26, 190}; + } + } else { + if (value < 0b11111111111111111111111111110000) { + if (value < 0b11111111111111111111111000100000) { + return {0b11111111111111111111101111000000, 27, 205}; + } else { + return {0b11111111111111111111111000100000, 28, 224}; + } + } else { + return {0b11111111111111111111111111110000, 30, 253}; + } + } + } + } + } + } + } + } +} + +// Mapping from canonical symbol (0 to 255) to actual symbol. +// clang-format off +constexpr unsigned char kCanonicalToSymbol[] = { + '0', '1', '2', 'a', 'c', 'e', 'i', 'o', + 's', 't', 0x20, '%', '-', '.', '/', '3', + '4', '5', '6', '7', '8', '9', '=', 'A', + '_', 'b', 'd', 'f', 'g', 'h', 'l', 'm', + 'n', 'p', 'r', 'u', ':', 'B', 'C', 'D', + 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', + 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', + 'U', 'V', 'W', 'Y', 'j', 'k', 'q', 'v', + 'w', 'x', 'y', 'z', '&', '*', ',', ';', + 'X', 'Z', '!', '\"', '(', ')', '?', '\'', + '+', '|', '#', '>', 0x00, '$', '@', '[', + ']', '~', '^', '}', '<', '`', '{', '\\', + 0xc3, 0xd0, 0x80, 0x82, 0x83, 0xa2, 0xb8, 0xc2, + 0xe0, 0xe2, 0x99, 0xa1, 0xa7, 0xac, 0xb0, 0xb1, + 0xb3, 0xd1, 0xd8, 0xd9, 0xe3, 0xe5, 0xe6, 0x81, + 0x84, 0x85, 0x86, 0x88, 0x92, 0x9a, 0x9c, 0xa0, + 0xa3, 0xa4, 0xa9, 0xaa, 0xad, 0xb2, 0xb5, 0xb9, + 0xba, 0xbb, 0xbd, 0xbe, 0xc4, 0xc6, 0xe4, 0xe8, + 0xe9, 0x01, 0x87, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, + 0x8f, 0x93, 0x95, 0x96, 0x97, 0x98, 0x9b, 0x9d, + 0x9e, 0xa5, 0xa6, 0xa8, 0xae, 0xaf, 0xb4, 0xb6, + 0xb7, 0xbc, 0xbf, 0xc5, 0xe7, 0xef, 0x09, 0x8e, + 0x90, 0x91, 0x94, 0x9f, 0xab, 0xce, 0xd7, 0xe1, + 0xec, 0xed, 0xc7, 0xcf, 0xea, 0xeb, 0xc0, 0xc1, + 0xc8, 0xc9, 0xca, 0xcd, 0xd2, 0xd5, 0xda, 0xdb, + 0xee, 0xf0, 0xf2, 0xf3, 0xff, 0xcb, 0xcc, 0xd3, + 0xd4, 0xd6, 0xdd, 0xde, 0xdf, 0xf1, 0xf4, 0xf5, + 0xf6, 0xf7, 0xf8, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x0b, + 0x0c, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, 0x7f, 0xdc, 0xf9, 0x0a, 0x0d, 0x16, +}; +// clang-format on + +constexpr size_t kShortCodeTableSize = 124; +struct ShortCodeInfo { + uint8_t symbol; + uint8_t length; +} kShortCodeTable[kShortCodeTableSize] = { + {0x30, 5}, // Match: 0b0000000, Symbol: 0 + {0x30, 5}, // Match: 0b0000001, Symbol: 0 + {0x30, 5}, // Match: 0b0000010, Symbol: 0 + {0x30, 5}, // Match: 0b0000011, Symbol: 0 + {0x31, 5}, // Match: 0b0000100, Symbol: 1 + {0x31, 5}, // Match: 0b0000101, Symbol: 1 + {0x31, 5}, // Match: 0b0000110, Symbol: 1 + {0x31, 5}, // Match: 0b0000111, Symbol: 1 + {0x32, 5}, // Match: 0b0001000, Symbol: 2 + {0x32, 5}, // Match: 0b0001001, Symbol: 2 + {0x32, 5}, // Match: 0b0001010, Symbol: 2 + {0x32, 5}, // Match: 0b0001011, Symbol: 2 + {0x61, 5}, // Match: 0b0001100, Symbol: a + {0x61, 5}, // Match: 0b0001101, Symbol: a + {0x61, 5}, // Match: 0b0001110, Symbol: a + {0x61, 5}, // Match: 0b0001111, Symbol: a + {0x63, 5}, // Match: 0b0010000, Symbol: c + {0x63, 5}, // Match: 0b0010001, Symbol: c + {0x63, 5}, // Match: 0b0010010, Symbol: c + {0x63, 5}, // Match: 0b0010011, Symbol: c + {0x65, 5}, // Match: 0b0010100, Symbol: e + {0x65, 5}, // Match: 0b0010101, Symbol: e + {0x65, 5}, // Match: 0b0010110, Symbol: e + {0x65, 5}, // Match: 0b0010111, Symbol: e + {0x69, 5}, // Match: 0b0011000, Symbol: i + {0x69, 5}, // Match: 0b0011001, Symbol: i + {0x69, 5}, // Match: 0b0011010, Symbol: i + {0x69, 5}, // Match: 0b0011011, Symbol: i + {0x6f, 5}, // Match: 0b0011100, Symbol: o + {0x6f, 5}, // Match: 0b0011101, Symbol: o + {0x6f, 5}, // Match: 0b0011110, Symbol: o + {0x6f, 5}, // Match: 0b0011111, Symbol: o + {0x73, 5}, // Match: 0b0100000, Symbol: s + {0x73, 5}, // Match: 0b0100001, Symbol: s + {0x73, 5}, // Match: 0b0100010, Symbol: s + {0x73, 5}, // Match: 0b0100011, Symbol: s + {0x74, 5}, // Match: 0b0100100, Symbol: t + {0x74, 5}, // Match: 0b0100101, Symbol: t + {0x74, 5}, // Match: 0b0100110, Symbol: t + {0x74, 5}, // Match: 0b0100111, Symbol: t + {0x20, 6}, // Match: 0b0101000, Symbol: (space) + {0x20, 6}, // Match: 0b0101001, Symbol: (space) + {0x25, 6}, // Match: 0b0101010, Symbol: % + {0x25, 6}, // Match: 0b0101011, Symbol: % + {0x2d, 6}, // Match: 0b0101100, Symbol: - + {0x2d, 6}, // Match: 0b0101101, Symbol: - + {0x2e, 6}, // Match: 0b0101110, Symbol: . + {0x2e, 6}, // Match: 0b0101111, Symbol: . + {0x2f, 6}, // Match: 0b0110000, Symbol: / + {0x2f, 6}, // Match: 0b0110001, Symbol: / + {0x33, 6}, // Match: 0b0110010, Symbol: 3 + {0x33, 6}, // Match: 0b0110011, Symbol: 3 + {0x34, 6}, // Match: 0b0110100, Symbol: 4 + {0x34, 6}, // Match: 0b0110101, Symbol: 4 + {0x35, 6}, // Match: 0b0110110, Symbol: 5 + {0x35, 6}, // Match: 0b0110111, Symbol: 5 + {0x36, 6}, // Match: 0b0111000, Symbol: 6 + {0x36, 6}, // Match: 0b0111001, Symbol: 6 + {0x37, 6}, // Match: 0b0111010, Symbol: 7 + {0x37, 6}, // Match: 0b0111011, Symbol: 7 + {0x38, 6}, // Match: 0b0111100, Symbol: 8 + {0x38, 6}, // Match: 0b0111101, Symbol: 8 + {0x39, 6}, // Match: 0b0111110, Symbol: 9 + {0x39, 6}, // Match: 0b0111111, Symbol: 9 + {0x3d, 6}, // Match: 0b1000000, Symbol: = + {0x3d, 6}, // Match: 0b1000001, Symbol: = + {0x41, 6}, // Match: 0b1000010, Symbol: A + {0x41, 6}, // Match: 0b1000011, Symbol: A + {0x5f, 6}, // Match: 0b1000100, Symbol: _ + {0x5f, 6}, // Match: 0b1000101, Symbol: _ + {0x62, 6}, // Match: 0b1000110, Symbol: b + {0x62, 6}, // Match: 0b1000111, Symbol: b + {0x64, 6}, // Match: 0b1001000, Symbol: d + {0x64, 6}, // Match: 0b1001001, Symbol: d + {0x66, 6}, // Match: 0b1001010, Symbol: f + {0x66, 6}, // Match: 0b1001011, Symbol: f + {0x67, 6}, // Match: 0b1001100, Symbol: g + {0x67, 6}, // Match: 0b1001101, Symbol: g + {0x68, 6}, // Match: 0b1001110, Symbol: h + {0x68, 6}, // Match: 0b1001111, Symbol: h + {0x6c, 6}, // Match: 0b1010000, Symbol: l + {0x6c, 6}, // Match: 0b1010001, Symbol: l + {0x6d, 6}, // Match: 0b1010010, Symbol: m + {0x6d, 6}, // Match: 0b1010011, Symbol: m + {0x6e, 6}, // Match: 0b1010100, Symbol: n + {0x6e, 6}, // Match: 0b1010101, Symbol: n + {0x70, 6}, // Match: 0b1010110, Symbol: p + {0x70, 6}, // Match: 0b1010111, Symbol: p + {0x72, 6}, // Match: 0b1011000, Symbol: r + {0x72, 6}, // Match: 0b1011001, Symbol: r + {0x75, 6}, // Match: 0b1011010, Symbol: u + {0x75, 6}, // Match: 0b1011011, Symbol: u + {0x3a, 7}, // Match: 0b1011100, Symbol: : + {0x42, 7}, // Match: 0b1011101, Symbol: B + {0x43, 7}, // Match: 0b1011110, Symbol: C + {0x44, 7}, // Match: 0b1011111, Symbol: D + {0x45, 7}, // Match: 0b1100000, Symbol: E + {0x46, 7}, // Match: 0b1100001, Symbol: F + {0x47, 7}, // Match: 0b1100010, Symbol: G + {0x48, 7}, // Match: 0b1100011, Symbol: H + {0x49, 7}, // Match: 0b1100100, Symbol: I + {0x4a, 7}, // Match: 0b1100101, Symbol: J + {0x4b, 7}, // Match: 0b1100110, Symbol: K + {0x4c, 7}, // Match: 0b1100111, Symbol: L + {0x4d, 7}, // Match: 0b1101000, Symbol: M + {0x4e, 7}, // Match: 0b1101001, Symbol: N + {0x4f, 7}, // Match: 0b1101010, Symbol: O + {0x50, 7}, // Match: 0b1101011, Symbol: P + {0x51, 7}, // Match: 0b1101100, Symbol: Q + {0x52, 7}, // Match: 0b1101101, Symbol: R + {0x53, 7}, // Match: 0b1101110, Symbol: S + {0x54, 7}, // Match: 0b1101111, Symbol: T + {0x55, 7}, // Match: 0b1110000, Symbol: U + {0x56, 7}, // Match: 0b1110001, Symbol: V + {0x57, 7}, // Match: 0b1110010, Symbol: W + {0x59, 7}, // Match: 0b1110011, Symbol: Y + {0x6a, 7}, // Match: 0b1110100, Symbol: j + {0x6b, 7}, // Match: 0b1110101, Symbol: k + {0x71, 7}, // Match: 0b1110110, Symbol: q + {0x76, 7}, // Match: 0b1110111, Symbol: v + {0x77, 7}, // Match: 0b1111000, Symbol: w + {0x78, 7}, // Match: 0b1111001, Symbol: x + {0x79, 7}, // Match: 0b1111010, Symbol: y + {0x7a, 7}, // Match: 0b1111011, Symbol: z +}; + +} // namespace + +HuffmanBitBuffer::HuffmanBitBuffer() { Reset(); } + +void HuffmanBitBuffer::Reset() { + accumulator_ = 0; + count_ = 0; +} + +size_t HuffmanBitBuffer::AppendBytes(absl::string_view input) { + HuffmanAccumulatorBitCount free_cnt = free_count(); + size_t bytes_available = input.size(); + if (free_cnt < 8 || bytes_available == 0) { + return 0; + } + + // Top up |accumulator_| until there isn't room for a whole byte. + size_t bytes_used = 0; + auto* ptr = reinterpret_cast(input.data()); + do { + auto b = static_cast(*ptr++); + free_cnt -= 8; + accumulator_ |= (b << free_cnt); + ++bytes_used; + } while (free_cnt >= 8 && bytes_used < bytes_available); + count_ += (bytes_used * 8); + return bytes_used; +} + +HuffmanAccumulatorBitCount HuffmanBitBuffer::free_count() const { + return kHuffmanAccumulatorBitCount - count_; +} + +void HuffmanBitBuffer::ConsumeBits(HuffmanAccumulatorBitCount code_length) { + QUICHE_DCHECK_LE(code_length, count_); + accumulator_ <<= code_length; + count_ -= code_length; +} + +bool HuffmanBitBuffer::InputProperlyTerminated() const { + auto cnt = count(); + if (cnt < 8) { + if (cnt == 0) { + return true; + } + HuffmanAccumulator expected = ~(~HuffmanAccumulator() >> cnt); + // We expect all the bits below the high order |cnt| bits of accumulator_ + // to be cleared as we perform left shift operations while decoding. + QUICHE_DCHECK_EQ(accumulator_ & ~expected, 0u) + << "\n expected: " << HuffmanAccumulatorBitSet(expected) << "\n " + << *this; + return accumulator_ == expected; + } + return false; +} + +std::string HuffmanBitBuffer::DebugString() const { + std::stringstream ss; + ss << "{accumulator: " << HuffmanAccumulatorBitSet(accumulator_) + << "; count: " << count_ << "}"; + return ss.str(); +} + +HpackHuffmanDecoder::HpackHuffmanDecoder() = default; + +HpackHuffmanDecoder::~HpackHuffmanDecoder() = default; + +bool HpackHuffmanDecoder::Decode(absl::string_view input, std::string* output) { + QUICHE_DVLOG(1) << "HpackHuffmanDecoder::Decode"; + + // Fill bit_buffer_ from input. + input.remove_prefix(bit_buffer_.AppendBytes(input)); + + while (true) { + QUICHE_DVLOG(3) << "Enter Decode Loop, bit_buffer_: " << bit_buffer_; + if (bit_buffer_.count() >= 7) { + // Get high 7 bits of the bit buffer, see if that contains a complete + // code of 5, 6 or 7 bits. + uint8_t short_code = + bit_buffer_.value() >> (kHuffmanAccumulatorBitCount - 7); + QUICHE_DCHECK_LT(short_code, 128); + if (short_code < kShortCodeTableSize) { + ShortCodeInfo info = kShortCodeTable[short_code]; + bit_buffer_.ConsumeBits(info.length); + output->push_back(static_cast(info.symbol)); + continue; + } + // The code is more than 7 bits long. Use PrefixToInfo, etc. to decode + // longer codes. + } else { + // We may have (mostly) drained bit_buffer_. If we can top it up, try + // using the table decoder above. + size_t byte_count = bit_buffer_.AppendBytes(input); + if (byte_count > 0) { + input.remove_prefix(byte_count); + continue; + } + } + + HuffmanCode code_prefix = bit_buffer_.value() >> kExtraAccumulatorBitCount; + QUICHE_DVLOG(3) << "code_prefix: " << HuffmanCodeBitSet(code_prefix); + + PrefixInfo prefix_info = PrefixToInfo(code_prefix); + QUICHE_DVLOG(3) << "prefix_info: " << prefix_info; + QUICHE_DCHECK_LE(kMinCodeBitCount, prefix_info.code_length); + QUICHE_DCHECK_LE(prefix_info.code_length, kMaxCodeBitCount); + + if (prefix_info.code_length <= bit_buffer_.count()) { + // We have enough bits for one code. + uint32_t canonical = prefix_info.DecodeToCanonical(code_prefix); + if (canonical < 256) { + // Valid code. + char c = kCanonicalToSymbol[canonical]; + output->push_back(c); + bit_buffer_.ConsumeBits(prefix_info.code_length); + continue; + } + // Encoder is not supposed to explicity encode the EOS symbol. + QUICHE_DLOG(ERROR) << "EOS explicitly encoded!\n " << bit_buffer_ << "\n " + << prefix_info; + return false; + } + // bit_buffer_ doesn't have enough bits in it to decode the next symbol. + // Append to it as many bytes as are available AND fit. + size_t byte_count = bit_buffer_.AppendBytes(input); + if (byte_count == 0) { + QUICHE_DCHECK_EQ(input.size(), 0u); + return true; + } + input.remove_prefix(byte_count); + } +} + +std::string HpackHuffmanDecoder::DebugString() const { + return bit_buffer_.DebugString(); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/hpack_huffman_decoder.h b/quiche/http2/hpack/huffman/hpack_huffman_decoder.h new file mode 100644 index 000000000000..910aba592d5f --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_decoder.h @@ -0,0 +1,134 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_HUFFMAN_HPACK_HUFFMAN_DECODER_H_ +#define QUICHE_HTTP2_HPACK_HUFFMAN_HPACK_HUFFMAN_DECODER_H_ + +// HpackHuffmanDecoder is an incremental decoder of strings that have been +// encoded using the Huffman table defined in the HPACK spec. +// By incremental, we mean that the HpackHuffmanDecoder::Decode method does +// not require the entire string to be provided, and can instead decode the +// string as fragments of it become available (e.g. as HPACK block fragments +// are received for decoding by HpackEntryDecoder). + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// HuffmanAccumulator is used to store bits during decoding, e.g. next N bits +// that have not yet been decoded, but have been extracted from the encoded +// string). An advantage of using a uint64 for the accumulator +// is that it has room for the bits of the longest code plus the bits of a full +// byte; that means that when adding more bits to the accumulator, it can always +// be done in whole bytes. For example, if we currently have 26 bits in the +// accumulator, and need more to decode the current symbol, we can add a whole +// byte to the accumulator, and not have to do juggling with adding 6 bits (to +// reach 30), and then keep track of the last two bits we've not been able to +// add to the accumulator. +typedef uint64_t HuffmanAccumulator; +typedef size_t HuffmanAccumulatorBitCount; + +// HuffmanBitBuffer stores the leading edge of bits to be decoded. The high +// order bit of accumulator_ is the next bit to be decoded. +class QUICHE_EXPORT HuffmanBitBuffer { + public: + HuffmanBitBuffer(); + + // Prepare for decoding a new Huffman encoded string. + void Reset(); + + // Add as many whole bytes to the accumulator (accumulator_) as possible, + // returning the number of bytes added. + size_t AppendBytes(absl::string_view input); + + // Get the bits of the accumulator. + HuffmanAccumulator value() const { return accumulator_; } + + // Number of bits of the encoded string that are in the accumulator + // (accumulator_). + HuffmanAccumulatorBitCount count() const { return count_; } + + // Are there no bits in the accumulator? + bool IsEmpty() const { return count_ == 0; } + + // Number of additional bits that can be added to the accumulator. + HuffmanAccumulatorBitCount free_count() const; + + // Consume the leading |code_length| bits of the accumulator. + void ConsumeBits(HuffmanAccumulatorBitCount code_length); + + // Are the contents valid for the end of a Huffman encoded string? The RFC + // states that EOS (end-of-string) symbol must not be explicitly encoded in + // the bit stream, but any unused bits in the final byte must be set to the + // prefix of the EOS symbol, which is all 1 bits. So there can be at most 7 + // such bits. + // Returns true if the bit buffer is empty, or contains at most 7 bits, all + // of them 1. Otherwise returns false. + bool InputProperlyTerminated() const; + + std::string DebugString() const; + + private: + HuffmanAccumulator accumulator_; + HuffmanAccumulatorBitCount count_; +}; + +inline std::ostream& operator<<(std::ostream& out, const HuffmanBitBuffer& v) { + return out << v.DebugString(); +} + +class QUICHE_EXPORT HpackHuffmanDecoder { + public: + HpackHuffmanDecoder(); + ~HpackHuffmanDecoder(); + + // Prepare for decoding a new Huffman encoded string. + void Reset() { bit_buffer_.Reset(); } + + // Decode the portion of a HPACK Huffman encoded string that is in |input|, + // appending the decoded symbols into |*output|, stopping when more bits are + // needed to determine the next symbol, which/ means that the input has been + // drained, and also that the bit_buffer_ is empty or that the bits that are + // in it are not a whole symbol. + // If |input| is the start of a string, the caller must first call Reset. + // If |input| includes the end of the encoded string, the caller must call + // InputProperlyTerminated after Decode has returned true in order to + // determine if the encoded string was properly terminated. + // Returns false if something went wrong (e.g. the encoding contains the code + // EOS symbol). Otherwise returns true, in which case input has been fully + // decoded or buffered; in particular, if the low-order bit of the final byte + // of the input is not the last bit of an encoded symbol, then bit_buffer_ + // will contain the leading bits of the code for that symbol, but not the + // final bits of that code. + // Note that output should be empty, but that it is not cleared by Decode(). + bool Decode(absl::string_view input, std::string* output); + + // Is what remains in the bit_buffer_ valid at the end of an encoded string? + // Call after passing the the final portion of a Huffman string to Decode, + // and getting true as the result. + bool InputProperlyTerminated() const { + return bit_buffer_.InputProperlyTerminated(); + } + + std::string DebugString() const; + + private: + HuffmanBitBuffer bit_buffer_; +}; + +inline std::ostream& operator<<(std::ostream& out, + const HpackHuffmanDecoder& v) { + return out << v.DebugString(); +} + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_HUFFMAN_HPACK_HUFFMAN_DECODER_H_ diff --git a/quiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc b/quiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc new file mode 100644 index 000000000000..1a81a5d5204a --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc @@ -0,0 +1,242 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/huffman/hpack_huffman_decoder.h" + +// Tests of HpackHuffmanDecoder and HuffmanBitBuffer. + +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +TEST(HuffmanBitBufferTest, Reset) { + HuffmanBitBuffer bb; + EXPECT_TRUE(bb.IsEmpty()); + EXPECT_TRUE(bb.InputProperlyTerminated()); + EXPECT_EQ(bb.count(), 0u); + EXPECT_EQ(bb.free_count(), 64u); + EXPECT_EQ(bb.value(), 0u); +} + +TEST(HuffmanBitBufferTest, AppendBytesAligned) { + std::string s; + s.push_back('\x11'); + s.push_back('\x22'); + s.push_back('\x33'); + absl::string_view sp(s); + + HuffmanBitBuffer bb; + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_TRUE(sp.empty()); + EXPECT_FALSE(bb.IsEmpty()) << bb; + EXPECT_FALSE(bb.InputProperlyTerminated()); + EXPECT_EQ(bb.count(), 24u) << bb; + EXPECT_EQ(bb.free_count(), 40u) << bb; + EXPECT_EQ(bb.value(), HuffmanAccumulator(0x112233) << 40) << bb; + + s.clear(); + s.push_back('\x44'); + sp = s; + + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_TRUE(sp.empty()); + EXPECT_EQ(bb.count(), 32u) << bb; + EXPECT_EQ(bb.free_count(), 32u) << bb; + EXPECT_EQ(bb.value(), HuffmanAccumulator(0x11223344) << 32) << bb; + + s.clear(); + s.push_back('\x55'); + s.push_back('\x66'); + s.push_back('\x77'); + s.push_back('\x88'); + s.push_back('\x99'); + sp = s; + + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_EQ(sp.size(), 1u); + EXPECT_EQ('\x99', sp[0]); + EXPECT_EQ(bb.count(), 64u) << bb; + EXPECT_EQ(bb.free_count(), 0u) << bb; + EXPECT_EQ(bb.value(), HuffmanAccumulator(0x1122334455667788LL)) << bb; + + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_EQ(sp.size(), 1u); + EXPECT_EQ('\x99', sp[0]); + EXPECT_EQ(bb.count(), 64u) << bb; + EXPECT_EQ(bb.free_count(), 0u) << bb; + EXPECT_EQ(bb.value(), HuffmanAccumulator(0x1122334455667788LL)) << bb; +} + +TEST(HuffmanBitBufferTest, ConsumeBits) { + std::string s; + s.push_back('\x11'); + s.push_back('\x22'); + s.push_back('\x33'); + absl::string_view sp(s); + + HuffmanBitBuffer bb; + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_TRUE(sp.empty()); + + bb.ConsumeBits(1); + EXPECT_EQ(bb.count(), 23u) << bb; + EXPECT_EQ(bb.free_count(), 41u) << bb; + EXPECT_EQ(bb.value(), HuffmanAccumulator(0x112233) << 41) << bb; + + bb.ConsumeBits(20); + EXPECT_EQ(bb.count(), 3u) << bb; + EXPECT_EQ(bb.free_count(), 61u) << bb; + EXPECT_EQ(bb.value(), HuffmanAccumulator(0x3) << 61) << bb; +} + +TEST(HuffmanBitBufferTest, AppendBytesUnaligned) { + std::string s; + s.push_back('\x11'); + s.push_back('\x22'); + s.push_back('\x33'); + s.push_back('\x44'); + s.push_back('\x55'); + s.push_back('\x66'); + s.push_back('\x77'); + s.push_back('\x88'); + s.push_back('\x99'); + s.push_back('\xaa'); + s.push_back('\xbb'); + s.push_back('\xcc'); + s.push_back('\xdd'); + absl::string_view sp(s); + + HuffmanBitBuffer bb; + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_EQ(sp.size(), 5u); + EXPECT_FALSE(bb.InputProperlyTerminated()); + + bb.ConsumeBits(15); + EXPECT_EQ(bb.count(), 49u) << bb; + EXPECT_EQ(bb.free_count(), 15u) << bb; + + HuffmanAccumulator expected(0x1122334455667788); + expected <<= 15; + EXPECT_EQ(bb.value(), expected); + + sp.remove_prefix(bb.AppendBytes(sp)); + EXPECT_EQ(sp.size(), 4u); + EXPECT_EQ(bb.count(), 57u) << bb; + EXPECT_EQ(bb.free_count(), 7u) << bb; + + expected |= (HuffmanAccumulator(0x99) << 7); + EXPECT_EQ(bb.value(), expected) + << bb << std::hex << "\n actual: " << bb.value() + << "\n expected: " << expected; +} + +class HpackHuffmanDecoderTest : public RandomDecoderTest { + protected: + HpackHuffmanDecoderTest() { + // The decoder may return true, and its accumulator may be empty, at + // many boundaries while decoding, and yet the whole string hasn't + // been decoded. + stop_decode_on_done_ = false; + } + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + input_bytes_seen_ = 0; + output_buffer_.clear(); + decoder_.Reset(); + return ResumeDecoding(b); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + input_bytes_seen_ += b->Remaining(); + absl::string_view sp(b->cursor(), b->Remaining()); + if (decoder_.Decode(sp, &output_buffer_)) { + b->AdvanceCursor(b->Remaining()); + // Successfully decoded (or buffered) the bytes in absl::string_view. + EXPECT_LE(input_bytes_seen_, input_bytes_expected_); + // Have we reached the end of the encoded string? + if (input_bytes_expected_ == input_bytes_seen_) { + if (decoder_.InputProperlyTerminated()) { + return DecodeStatus::kDecodeDone; + } else { + return DecodeStatus::kDecodeError; + } + } + return DecodeStatus::kDecodeInProgress; + } + return DecodeStatus::kDecodeError; + } + + HpackHuffmanDecoder decoder_; + std::string output_buffer_; + size_t input_bytes_seen_; + size_t input_bytes_expected_; +}; + +TEST_F(HpackHuffmanDecoderTest, SpecRequestExamples) { + HpackHuffmanDecoder decoder; + std::string test_table[] = { + absl::HexStringToBytes("f1e3c2e5f23a6ba0ab90f4ff"), + "www.example.com", + absl::HexStringToBytes("a8eb10649cbf"), + "no-cache", + absl::HexStringToBytes("25a849e95ba97d7f"), + "custom-key", + absl::HexStringToBytes("25a849e95bb8e8b4bf"), + "custom-value", + }; + for (size_t i = 0; i != ABSL_ARRAYSIZE(test_table); i += 2) { + const std::string& huffman_encoded(test_table[i]); + const std::string& plain_string(test_table[i + 1]); + std::string buffer; + decoder.Reset(); + EXPECT_TRUE(decoder.Decode(huffman_encoded, &buffer)) << decoder; + EXPECT_TRUE(decoder.InputProperlyTerminated()) << decoder; + EXPECT_EQ(buffer, plain_string); + } +} + +TEST_F(HpackHuffmanDecoderTest, SpecResponseExamples) { + HpackHuffmanDecoder decoder; + // clang-format off + std::string test_table[] = { + absl::HexStringToBytes("6402"), + "302", + absl::HexStringToBytes("aec3771a4b"), + "private", + absl::HexStringToBytes("d07abe941054d444a8200595040b8166" + "e082a62d1bff"), + "Mon, 21 Oct 2013 20:13:21 GMT", + absl::HexStringToBytes("9d29ad171863c78f0b97c8e9ae82ae43" + "d3"), + "https://www.example.com", + absl::HexStringToBytes("94e7821dd7f2e6c7b335dfdfcd5b3960" + "d5af27087f3672c1ab270fb5291f9587" + "316065c003ed4ee5b1063d5007"), + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1", + }; + // clang-format on + for (size_t i = 0; i != ABSL_ARRAYSIZE(test_table); i += 2) { + const std::string& huffman_encoded(test_table[i]); + const std::string& plain_string(test_table[i + 1]); + std::string buffer; + decoder.Reset(); + EXPECT_TRUE(decoder.Decode(huffman_encoded, &buffer)) << decoder; + EXPECT_TRUE(decoder.InputProperlyTerminated()) << decoder; + EXPECT_EQ(buffer, plain_string); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/hpack_huffman_encoder.cc b/quiche/http2/hpack/huffman/hpack_huffman_encoder.cc new file mode 100644 index 000000000000..aa16ea34e0bb --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_encoder.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/huffman/hpack_huffman_encoder.h" + +#include "quiche/http2/hpack/huffman/huffman_spec_tables.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +size_t HuffmanSize(absl::string_view plain) { + size_t bits = 0; + for (const uint8_t c : plain) { + bits += HuffmanSpecTables::kCodeLengths[c]; + } + return (bits + 7) / 8; +} + +void HuffmanEncode(absl::string_view plain, size_t encoded_size, + std::string* huffman) { + QUICHE_DCHECK(huffman != nullptr); + huffman->reserve(huffman->size() + encoded_size); + uint64_t bit_buffer = 0; // High-bit is next bit to output. Not clear if that + // is more performant than having the low-bit be the + // last to be output. + size_t bits_unused = 64; // Number of bits available for the next code. + for (uint8_t c : plain) { + size_t code_length = HuffmanSpecTables::kCodeLengths[c]; + if (bits_unused < code_length) { + // There isn't enough room in bit_buffer for the code of c. + // Flush until bits_unused > 56 (i.e. 64 - 8). + do { + char h = static_cast(bit_buffer >> 56); + bit_buffer <<= 8; + bits_unused += 8; + // Perhaps would be more efficient if we populated an array of chars, + // so we don't have to call push_back each time. Reconsider if used + // for production. + huffman->push_back(h); + } while (bits_unused <= 56); + } + uint64_t code = HuffmanSpecTables::kRightCodes[c]; + size_t shift_by = bits_unused - code_length; + bit_buffer |= (code << shift_by); + bits_unused -= code_length; + } + // bit_buffer contains (64-bits_unused) bits that still need to be flushed. + // Output whole bytes until we don't have any whole bytes left. + size_t bits_used = 64 - bits_unused; + while (bits_used >= 8) { + char h = static_cast(bit_buffer >> 56); + bit_buffer <<= 8; + bits_used -= 8; + huffman->push_back(h); + } + if (bits_used > 0) { + // We have less than a byte left to output. The spec calls for padding out + // the final byte with the leading bits of the EOS symbol (30 1-bits). + constexpr uint64_t leading_eos_bits = 0b11111111; + bit_buffer |= (leading_eos_bits << (56 - bits_used)); + char h = static_cast(bit_buffer >> 56); + huffman->push_back(h); + } +} + +void HuffmanEncodeFast(absl::string_view input, size_t encoded_size, + std::string* output) { + const size_t original_size = output->size(); + const size_t final_size = original_size + encoded_size; + // Reserve an extra four bytes to avoid accessing unallocated memory (even + // though it would only be OR'd with zeros and thus not modified). + output->resize(final_size + 4, 0); + + // Pointer to first appended byte. + char* const first = &*output->begin() + original_size; + size_t bit_counter = 0; + for (uint8_t c : input) { + // Align the Huffman code to byte boundaries as it needs to be written. + // The longest Huffman code is 30 bits long, and it can be shifted by up to + // 7 bits, requiring 37 bits in total. The most significant 25 bits and + // least significant 2 bits of |code| are always zero. + uint64_t code = static_cast(HuffmanSpecTables::kLeftCodes[c]) + << (8 - (bit_counter % 8)); + // The byte where the first bit of |code| needs to be written. + char* const current = first + (bit_counter / 8); + + bit_counter += HuffmanSpecTables::kCodeLengths[c]; + + *current |= code >> 32; + + // Do not check if this write is zero before executing it, because with + // uniformly random shifts and an ideal random input distribution + // corresponding to the Huffman tree it would only be zero in 29% of the + // cases. + *(current + 1) |= (code >> 24) & 0xff; + + // Continue to next input character if there is nothing else to write. + // (If next byte is zero, then rest must also be zero.) + if ((code & 0xff0000) == 0) { + continue; + } + *(current + 2) |= (code >> 16) & 0xff; + + // Continue to next input character if there is nothing else to write. + // (If next byte is zero, then rest must also be zero.) + if ((code & 0xff00) == 0) { + continue; + } + *(current + 3) |= (code >> 8) & 0xff; + + // Do not check if this write is zero, because the check would probably be + // as expensive as the write. + *(current + 4) |= code & 0xff; + } + + QUICHE_DCHECK_EQ(encoded_size, (bit_counter + 7) / 8); + + // EOF + if (bit_counter % 8 != 0) { + *(first + encoded_size - 1) |= 0xff >> (bit_counter & 7); + } + + output->resize(final_size); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/hpack_huffman_encoder.h b/quiche/http2/hpack/huffman/hpack_huffman_encoder.h new file mode 100644 index 000000000000..e6056c1ce855 --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_encoder.h @@ -0,0 +1,38 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_HUFFMAN_HPACK_HUFFMAN_ENCODER_H_ +#define QUICHE_HTTP2_HPACK_HUFFMAN_HPACK_HUFFMAN_ENCODER_H_ + +// Functions supporting the encoding of strings using the HPACK-defined Huffman +// table. + +#include // For size_t +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// Returns the size of the Huffman encoding of |plain|, which may be greater +// than plain.size(). +QUICHE_EXPORT size_t HuffmanSize(absl::string_view plain); + +// Encode the plain text string |plain| with the Huffman encoding defined in the +// HPACK RFC, 7541. |encoded_size| is used to pre-allocate storage and it +// should be the value returned by HuffmanSize(). Appends the result to +// |*huffman|. +QUICHE_EXPORT void HuffmanEncode(absl::string_view plain, size_t encoded_size, + std::string* huffman); + +// Encode |input| with the Huffman encoding defined RFC7541, used in HPACK and +// QPACK. |encoded_size| must be the value returned by HuffmanSize(). +// Appends the result to the end of |*output|. +QUICHE_EXPORT void HuffmanEncodeFast(absl::string_view input, + size_t encoded_size, std::string* output); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_HUFFMAN_HPACK_HUFFMAN_ENCODER_H_ diff --git a/quiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc b/quiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc new file mode 100644 index 000000000000..0c4f5a415ed4 --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc @@ -0,0 +1,130 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/huffman/hpack_huffman_encoder.h" + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace { + +class HuffmanEncoderTest : public quiche::test::QuicheTestWithParam { + protected: + HuffmanEncoderTest() : use_fast_encoder_(GetParam()) {} + virtual ~HuffmanEncoderTest() = default; + + void Encode(absl::string_view input, size_t encoded_size, + std::string* output) { + use_fast_encoder_ ? HuffmanEncodeFast(input, encoded_size, output) + : HuffmanEncode(input, encoded_size, output); + } + + const bool use_fast_encoder_; +}; + +INSTANTIATE_TEST_SUITE_P(TwoEncoders, HuffmanEncoderTest, ::testing::Bool()); + +TEST_P(HuffmanEncoderTest, Empty) { + std::string empty(""); + size_t encoded_size = HuffmanSize(empty); + EXPECT_EQ(0u, encoded_size); + + std::string buffer; + Encode(empty, encoded_size, &buffer); + EXPECT_EQ("", buffer); +} + +TEST_P(HuffmanEncoderTest, SpecRequestExamples) { + std::string test_table[] = { + absl::HexStringToBytes("f1e3c2e5f23a6ba0ab90f4ff"), + "www.example.com", + absl::HexStringToBytes("a8eb10649cbf"), + "no-cache", + absl::HexStringToBytes("25a849e95ba97d7f"), + "custom-key", + absl::HexStringToBytes("25a849e95bb8e8b4bf"), + "custom-value", + }; + for (size_t i = 0; i != ABSL_ARRAYSIZE(test_table); i += 2) { + const std::string& huffman_encoded(test_table[i]); + const std::string& plain_string(test_table[i + 1]); + size_t encoded_size = HuffmanSize(plain_string); + EXPECT_EQ(huffman_encoded.size(), encoded_size); + std::string buffer; + buffer.reserve(huffman_encoded.size()); + Encode(plain_string, encoded_size, &buffer); + EXPECT_EQ(buffer, huffman_encoded) << "Error encoding " << plain_string; + } +} + +TEST_P(HuffmanEncoderTest, SpecResponseExamples) { + // clang-format off + std::string test_table[] = { + absl::HexStringToBytes("6402"), + "302", + absl::HexStringToBytes("aec3771a4b"), + "private", + absl::HexStringToBytes("d07abe941054d444a8200595040b8166" + "e082a62d1bff"), + "Mon, 21 Oct 2013 20:13:21 GMT", + absl::HexStringToBytes("9d29ad171863c78f0b97c8e9ae82ae43" + "d3"), + "https://www.example.com", + absl::HexStringToBytes("94e7821dd7f2e6c7b335dfdfcd5b3960" + "d5af27087f3672c1ab270fb5291f9587" + "316065c003ed4ee5b1063d5007"), + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1", + }; + // clang-format on + for (size_t i = 0; i != ABSL_ARRAYSIZE(test_table); i += 2) { + const std::string& huffman_encoded(test_table[i]); + const std::string& plain_string(test_table[i + 1]); + size_t encoded_size = HuffmanSize(plain_string); + EXPECT_EQ(huffman_encoded.size(), encoded_size); + std::string buffer; + Encode(plain_string, encoded_size, &buffer); + EXPECT_EQ(buffer, huffman_encoded) << "Error encoding " << plain_string; + } +} + +TEST_P(HuffmanEncoderTest, EncodedSizeAgreesWithEncodeString) { + std::string test_table[] = { + "", + "Mon, 21 Oct 2013 20:13:21 GMT", + "https://www.example.com", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1", + std::string(1, '\0'), + std::string("foo\0bar", 7), + std::string(256, '\0'), + }; + // Modify last |test_table| entry to cover all codes. + for (size_t i = 0; i != 256; ++i) { + test_table[ABSL_ARRAYSIZE(test_table) - 1][i] = static_cast(i); + } + + for (size_t i = 0; i != ABSL_ARRAYSIZE(test_table); ++i) { + const std::string& plain_string = test_table[i]; + size_t encoded_size = HuffmanSize(plain_string); + std::string huffman_encoded; + Encode(plain_string, encoded_size, &huffman_encoded); + EXPECT_EQ(encoded_size, huffman_encoded.size()); + } +} + +// Test that encoding appends to output without overwriting it. +TEST_P(HuffmanEncoderTest, AppendToOutput) { + size_t encoded_size = HuffmanSize("foo"); + std::string buffer; + Encode("foo", encoded_size, &buffer); + EXPECT_EQ(absl::HexStringToBytes("94e7"), buffer); + + encoded_size = HuffmanSize("bar"); + Encode("bar", encoded_size, &buffer); + EXPECT_EQ(absl::HexStringToBytes("94e78c767f"), buffer); +} + +} // namespace +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc b/quiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc new file mode 100644 index 000000000000..b3addbf44cc5 --- /dev/null +++ b/quiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A test of roundtrips through the encoder and decoder. + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/hpack/huffman/hpack_huffman_decoder.h" +#include "quiche/http2/hpack/huffman/hpack_huffman_encoder.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_text_utils.h" + +using ::testing::AssertionSuccess; +using ::testing::Combine; +using ::testing::Range; +using ::testing::Values; + +namespace http2 { +namespace test { +namespace { + +std::string GenAsciiNonControlSet() { + std::string s; + const char space = ' '; // First character after the control characters: 0x20 + const char del = 127; // First character after the non-control characters. + for (char c = space; c < del; ++c) { + s.push_back(c); + } + return s; +} + +class HpackHuffmanTranscoderTest : public RandomDecoderTest { + protected: + HpackHuffmanTranscoderTest() + : ascii_non_control_set_(GenAsciiNonControlSet()) { + // The decoder may return true, and its accumulator may be empty, at + // many boundaries while decoding, and yet the whole string hasn't + // been decoded. + stop_decode_on_done_ = false; + } + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + input_bytes_seen_ = 0; + output_buffer_.clear(); + decoder_.Reset(); + return ResumeDecoding(b); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + input_bytes_seen_ += b->Remaining(); + absl::string_view sp(b->cursor(), b->Remaining()); + if (decoder_.Decode(sp, &output_buffer_)) { + b->AdvanceCursor(b->Remaining()); + // Successfully decoded (or buffered) the bytes in absl::string_view. + EXPECT_LE(input_bytes_seen_, input_bytes_expected_); + // Have we reached the end of the encoded string? + if (input_bytes_expected_ == input_bytes_seen_) { + if (decoder_.InputProperlyTerminated()) { + return DecodeStatus::kDecodeDone; + } else { + return DecodeStatus::kDecodeError; + } + } + return DecodeStatus::kDecodeInProgress; + } + return DecodeStatus::kDecodeError; + } + + AssertionResult TranscodeAndValidateSeveralWays( + absl::string_view plain, absl::string_view expected_huffman) { + size_t encoded_size = HuffmanSize(plain); + std::string encoded; + HuffmanEncode(plain, encoded_size, &encoded); + HTTP2_VERIFY_EQ(encoded_size, encoded.size()); + if (!expected_huffman.empty() || plain.empty()) { + HTTP2_VERIFY_EQ(encoded, expected_huffman); + } + input_bytes_expected_ = encoded.size(); + auto validator = [plain, this]() -> AssertionResult { + HTTP2_VERIFY_EQ(output_buffer_.size(), plain.size()); + HTTP2_VERIFY_EQ(output_buffer_, plain); + return AssertionSuccess(); + }; + DecodeBuffer db(encoded); + bool return_non_zero_on_first = false; + return DecodeAndValidateSeveralWays(&db, return_non_zero_on_first, + ValidateDoneAndEmpty(validator)); + } + + AssertionResult TranscodeAndValidateSeveralWays(absl::string_view plain) { + return TranscodeAndValidateSeveralWays(plain, ""); + } + + std::string RandomAsciiNonControlString(int length) { + return Random().RandStringWithAlphabet(length, ascii_non_control_set_); + } + + std::string RandomBytes(int length) { return Random().RandString(length); } + + const std::string ascii_non_control_set_; + HpackHuffmanDecoder decoder_; + std::string output_buffer_; + size_t input_bytes_seen_; + size_t input_bytes_expected_; +}; + +TEST_F(HpackHuffmanTranscoderTest, RoundTripRandomAsciiNonControlString) { + for (size_t length = 0; length != 20; length++) { + const std::string s = RandomAsciiNonControlString(length); + ASSERT_TRUE(TranscodeAndValidateSeveralWays(s)) + << "Unable to decode:\n\n" + << quiche::QuicheTextUtils::HexDump(s) << "\n\noutput_buffer_:\n" + << quiche::QuicheTextUtils::HexDump(output_buffer_); + } +} + +TEST_F(HpackHuffmanTranscoderTest, RoundTripRandomBytes) { + for (size_t length = 0; length != 20; length++) { + const std::string s = RandomBytes(length); + ASSERT_TRUE(TranscodeAndValidateSeveralWays(s)) + << "Unable to decode:\n\n" + << quiche::QuicheTextUtils::HexDump(s) << "\n\noutput_buffer_:\n" + << quiche::QuicheTextUtils::HexDump(output_buffer_); + } +} + +// Two parameters: decoder choice, and the character to round-trip. +class HpackHuffmanTranscoderAdjacentCharTest + : public HpackHuffmanTranscoderTest, + public testing::WithParamInterface { + protected: + HpackHuffmanTranscoderAdjacentCharTest() + : c_(static_cast(GetParam())) {} + + const char c_; +}; + +INSTANTIATE_TEST_SUITE_P(HpackHuffmanTranscoderAdjacentCharTest, + HpackHuffmanTranscoderAdjacentCharTest, Range(0, 256)); + +// Test c_ adjacent to every other character, both before and after. +TEST_P(HpackHuffmanTranscoderAdjacentCharTest, RoundTripAdjacentChar) { + std::string s; + for (int a = 0; a < 256; ++a) { + s.push_back(static_cast(a)); + s.push_back(c_); + s.push_back(static_cast(a)); + } + ASSERT_TRUE(TranscodeAndValidateSeveralWays(s)); +} + +// Two parameters: character to repeat, number of repeats. +class HpackHuffmanTranscoderRepeatedCharTest + : public HpackHuffmanTranscoderTest, + public testing::WithParamInterface> { + protected: + HpackHuffmanTranscoderRepeatedCharTest() + : c_(static_cast(std::get<0>(GetParam()))), + length_(std::get<1>(GetParam())) {} + std::string MakeString() { return std::string(length_, c_); } + + private: + const char c_; + const size_t length_; +}; + +INSTANTIATE_TEST_SUITE_P(HpackHuffmanTranscoderRepeatedCharTest, + HpackHuffmanTranscoderRepeatedCharTest, + Combine(Range(0, 256), Values(1, 2, 3, 4, 8, 16, 32))); + +TEST_P(HpackHuffmanTranscoderRepeatedCharTest, RoundTripRepeatedChar) { + ASSERT_TRUE(TranscodeAndValidateSeveralWays(MakeString())); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/huffman_spec_tables.cc b/quiche/http2/hpack/huffman/huffman_spec_tables.cc new file mode 100644 index 000000000000..f4b103b93f65 --- /dev/null +++ b/quiche/http2/hpack/huffman/huffman_spec_tables.cc @@ -0,0 +1,572 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/huffman/huffman_spec_tables.h" + +namespace http2 { + +// clang-format off +// static +const uint8_t HuffmanSpecTables::kCodeLengths[] = { + 13, 23, 28, 28, 28, 28, 28, 28, // 0 - 7 + 28, 24, 30, 28, 28, 30, 28, 28, // 8 - 15 + 28, 28, 28, 28, 28, 28, 30, 28, // 16 - 23 + 28, 28, 28, 28, 28, 28, 28, 28, // 24 - 31 + 6, 10, 10, 12, 13, 6, 8, 11, // 32 - 39 + 10, 10, 8, 11, 8, 6, 6, 6, // 40 - 47 + 5, 5, 5, 6, 6, 6, 6, 6, // 48 - 55 + 6, 6, 7, 8, 15, 6, 12, 10, // 56 - 63 + 13, 6, 7, 7, 7, 7, 7, 7, // 64 - 71 + 7, 7, 7, 7, 7, 7, 7, 7, // 72 - 79 + 7, 7, 7, 7, 7, 7, 7, 7, // 80 - 87 + 8, 7, 8, 13, 19, 13, 14, 6, // 88 - 95 + 15, 5, 6, 5, 6, 5, 6, 6, // 96 - 103 + 6, 5, 7, 7, 6, 6, 6, 5, // 104 - 111 + 6, 7, 6, 5, 5, 6, 7, 7, // 112 - 119 + 7, 7, 7, 15, 11, 14, 13, 28, // 120 - 127 + 20, 22, 20, 20, 22, 22, 22, 23, // 128 - 135 + 22, 23, 23, 23, 23, 23, 24, 23, // 136 - 143 + 24, 24, 22, 23, 24, 23, 23, 23, // 144 - 151 + 23, 21, 22, 23, 22, 23, 23, 24, // 152 - 159 + 22, 21, 20, 22, 22, 23, 23, 21, // 160 - 167 + 23, 22, 22, 24, 21, 22, 23, 23, // 168 - 175 + 21, 21, 22, 21, 23, 22, 23, 23, // 176 - 183 + 20, 22, 22, 22, 23, 22, 22, 23, // 184 - 191 + 26, 26, 20, 19, 22, 23, 22, 25, // 192 - 199 + 26, 26, 26, 27, 27, 26, 24, 25, // 200 - 207 + 19, 21, 26, 27, 27, 26, 27, 24, // 208 - 215 + 21, 21, 26, 26, 28, 27, 27, 27, // 216 - 223 + 20, 24, 20, 21, 22, 21, 21, 23, // 224 - 231 + 22, 22, 25, 25, 24, 24, 26, 23, // 232 - 239 + 26, 27, 26, 26, 27, 27, 27, 27, // 240 - 247 + 27, 28, 27, 27, 27, 27, 27, 26, // 248 - 255 + 30, // 256 +}; + +// The encoding of each symbol, left justified (as printed), which means that +// the first bit of the encoding is the high-order bit of the uint32. +// static +const uint32_t HuffmanSpecTables::kLeftCodes[] = { + 0b11111111110000000000000000000000, // 0x00 + 0b11111111111111111011000000000000, // 0x01 + 0b11111111111111111111111000100000, // 0x02 + 0b11111111111111111111111000110000, // 0x03 + 0b11111111111111111111111001000000, // 0x04 + 0b11111111111111111111111001010000, // 0x05 + 0b11111111111111111111111001100000, // 0x06 + 0b11111111111111111111111001110000, // 0x07 + 0b11111111111111111111111010000000, // 0x08 + 0b11111111111111111110101000000000, // 0x09 + 0b11111111111111111111111111110000, // 0x0a + 0b11111111111111111111111010010000, // 0x0b + 0b11111111111111111111111010100000, // 0x0c + 0b11111111111111111111111111110100, // 0x0d + 0b11111111111111111111111010110000, // 0x0e + 0b11111111111111111111111011000000, // 0x0f + 0b11111111111111111111111011010000, // 0x10 + 0b11111111111111111111111011100000, // 0x11 + 0b11111111111111111111111011110000, // 0x12 + 0b11111111111111111111111100000000, // 0x13 + 0b11111111111111111111111100010000, // 0x14 + 0b11111111111111111111111100100000, // 0x15 + 0b11111111111111111111111111111000, // 0x16 + 0b11111111111111111111111100110000, // 0x17 + 0b11111111111111111111111101000000, // 0x18 + 0b11111111111111111111111101010000, // 0x19 + 0b11111111111111111111111101100000, // 0x1a + 0b11111111111111111111111101110000, // 0x1b + 0b11111111111111111111111110000000, // 0x1c + 0b11111111111111111111111110010000, // 0x1d + 0b11111111111111111111111110100000, // 0x1e + 0b11111111111111111111111110110000, // 0x1f + 0b01010000000000000000000000000000, // 0x20 + 0b11111110000000000000000000000000, // '!' + 0b11111110010000000000000000000000, // '\"' + 0b11111111101000000000000000000000, // '#' + 0b11111111110010000000000000000000, // '$' + 0b01010100000000000000000000000000, // '%' + 0b11111000000000000000000000000000, // '&' + 0b11111111010000000000000000000000, // '\'' + 0b11111110100000000000000000000000, // '(' + 0b11111110110000000000000000000000, // ')' + 0b11111001000000000000000000000000, // '*' + 0b11111111011000000000000000000000, // '+' + 0b11111010000000000000000000000000, // ',' + 0b01011000000000000000000000000000, // '-' + 0b01011100000000000000000000000000, // '.' + 0b01100000000000000000000000000000, // '/' + 0b00000000000000000000000000000000, // '0' + 0b00001000000000000000000000000000, // '1' + 0b00010000000000000000000000000000, // '2' + 0b01100100000000000000000000000000, // '3' + 0b01101000000000000000000000000000, // '4' + 0b01101100000000000000000000000000, // '5' + 0b01110000000000000000000000000000, // '6' + 0b01110100000000000000000000000000, // '7' + 0b01111000000000000000000000000000, // '8' + 0b01111100000000000000000000000000, // '9' + 0b10111000000000000000000000000000, // ':' + 0b11111011000000000000000000000000, // ';' + 0b11111111111110000000000000000000, // '<' + 0b10000000000000000000000000000000, // '=' + 0b11111111101100000000000000000000, // '>' + 0b11111111000000000000000000000000, // '?' + 0b11111111110100000000000000000000, // '@' + 0b10000100000000000000000000000000, // 'A' + 0b10111010000000000000000000000000, // 'B' + 0b10111100000000000000000000000000, // 'C' + 0b10111110000000000000000000000000, // 'D' + 0b11000000000000000000000000000000, // 'E' + 0b11000010000000000000000000000000, // 'F' + 0b11000100000000000000000000000000, // 'G' + 0b11000110000000000000000000000000, // 'H' + 0b11001000000000000000000000000000, // 'I' + 0b11001010000000000000000000000000, // 'J' + 0b11001100000000000000000000000000, // 'K' + 0b11001110000000000000000000000000, // 'L' + 0b11010000000000000000000000000000, // 'M' + 0b11010010000000000000000000000000, // 'N' + 0b11010100000000000000000000000000, // 'O' + 0b11010110000000000000000000000000, // 'P' + 0b11011000000000000000000000000000, // 'Q' + 0b11011010000000000000000000000000, // 'R' + 0b11011100000000000000000000000000, // 'S' + 0b11011110000000000000000000000000, // 'T' + 0b11100000000000000000000000000000, // 'U' + 0b11100010000000000000000000000000, // 'V' + 0b11100100000000000000000000000000, // 'W' + 0b11111100000000000000000000000000, // 'X' + 0b11100110000000000000000000000000, // 'Y' + 0b11111101000000000000000000000000, // 'Z' + 0b11111111110110000000000000000000, // '[' + 0b11111111111111100000000000000000, // '\\' + 0b11111111111000000000000000000000, // ']' + 0b11111111111100000000000000000000, // '^' + 0b10001000000000000000000000000000, // '_' + 0b11111111111110100000000000000000, // '`' + 0b00011000000000000000000000000000, // 'a' + 0b10001100000000000000000000000000, // 'b' + 0b00100000000000000000000000000000, // 'c' + 0b10010000000000000000000000000000, // 'd' + 0b00101000000000000000000000000000, // 'e' + 0b10010100000000000000000000000000, // 'f' + 0b10011000000000000000000000000000, // 'g' + 0b10011100000000000000000000000000, // 'h' + 0b00110000000000000000000000000000, // 'i' + 0b11101000000000000000000000000000, // 'j' + 0b11101010000000000000000000000000, // 'k' + 0b10100000000000000000000000000000, // 'l' + 0b10100100000000000000000000000000, // 'm' + 0b10101000000000000000000000000000, // 'n' + 0b00111000000000000000000000000000, // 'o' + 0b10101100000000000000000000000000, // 'p' + 0b11101100000000000000000000000000, // 'q' + 0b10110000000000000000000000000000, // 'r' + 0b01000000000000000000000000000000, // 's' + 0b01001000000000000000000000000000, // 't' + 0b10110100000000000000000000000000, // 'u' + 0b11101110000000000000000000000000, // 'v' + 0b11110000000000000000000000000000, // 'w' + 0b11110010000000000000000000000000, // 'x' + 0b11110100000000000000000000000000, // 'y' + 0b11110110000000000000000000000000, // 'z' + 0b11111111111111000000000000000000, // '{' + 0b11111111100000000000000000000000, // '|' + 0b11111111111101000000000000000000, // '}' + 0b11111111111010000000000000000000, // '~' + 0b11111111111111111111111111000000, // 0x7f + 0b11111111111111100110000000000000, // 0x80 + 0b11111111111111110100100000000000, // 0x81 + 0b11111111111111100111000000000000, // 0x82 + 0b11111111111111101000000000000000, // 0x83 + 0b11111111111111110100110000000000, // 0x84 + 0b11111111111111110101000000000000, // 0x85 + 0b11111111111111110101010000000000, // 0x86 + 0b11111111111111111011001000000000, // 0x87 + 0b11111111111111110101100000000000, // 0x88 + 0b11111111111111111011010000000000, // 0x89 + 0b11111111111111111011011000000000, // 0x8a + 0b11111111111111111011100000000000, // 0x8b + 0b11111111111111111011101000000000, // 0x8c + 0b11111111111111111011110000000000, // 0x8d + 0b11111111111111111110101100000000, // 0x8e + 0b11111111111111111011111000000000, // 0x8f + 0b11111111111111111110110000000000, // 0x90 + 0b11111111111111111110110100000000, // 0x91 + 0b11111111111111110101110000000000, // 0x92 + 0b11111111111111111100000000000000, // 0x93 + 0b11111111111111111110111000000000, // 0x94 + 0b11111111111111111100001000000000, // 0x95 + 0b11111111111111111100010000000000, // 0x96 + 0b11111111111111111100011000000000, // 0x97 + 0b11111111111111111100100000000000, // 0x98 + 0b11111111111111101110000000000000, // 0x99 + 0b11111111111111110110000000000000, // 0x9a + 0b11111111111111111100101000000000, // 0x9b + 0b11111111111111110110010000000000, // 0x9c + 0b11111111111111111100110000000000, // 0x9d + 0b11111111111111111100111000000000, // 0x9e + 0b11111111111111111110111100000000, // 0x9f + 0b11111111111111110110100000000000, // 0xa0 + 0b11111111111111101110100000000000, // 0xa1 + 0b11111111111111101001000000000000, // 0xa2 + 0b11111111111111110110110000000000, // 0xa3 + 0b11111111111111110111000000000000, // 0xa4 + 0b11111111111111111101000000000000, // 0xa5 + 0b11111111111111111101001000000000, // 0xa6 + 0b11111111111111101111000000000000, // 0xa7 + 0b11111111111111111101010000000000, // 0xa8 + 0b11111111111111110111010000000000, // 0xa9 + 0b11111111111111110111100000000000, // 0xaa + 0b11111111111111111111000000000000, // 0xab + 0b11111111111111101111100000000000, // 0xac + 0b11111111111111110111110000000000, // 0xad + 0b11111111111111111101011000000000, // 0xae + 0b11111111111111111101100000000000, // 0xaf + 0b11111111111111110000000000000000, // 0xb0 + 0b11111111111111110000100000000000, // 0xb1 + 0b11111111111111111000000000000000, // 0xb2 + 0b11111111111111110001000000000000, // 0xb3 + 0b11111111111111111101101000000000, // 0xb4 + 0b11111111111111111000010000000000, // 0xb5 + 0b11111111111111111101110000000000, // 0xb6 + 0b11111111111111111101111000000000, // 0xb7 + 0b11111111111111101010000000000000, // 0xb8 + 0b11111111111111111000100000000000, // 0xb9 + 0b11111111111111111000110000000000, // 0xba + 0b11111111111111111001000000000000, // 0xbb + 0b11111111111111111110000000000000, // 0xbc + 0b11111111111111111001010000000000, // 0xbd + 0b11111111111111111001100000000000, // 0xbe + 0b11111111111111111110001000000000, // 0xbf + 0b11111111111111111111100000000000, // 0xc0 + 0b11111111111111111111100001000000, // 0xc1 + 0b11111111111111101011000000000000, // 0xc2 + 0b11111111111111100010000000000000, // 0xc3 + 0b11111111111111111001110000000000, // 0xc4 + 0b11111111111111111110010000000000, // 0xc5 + 0b11111111111111111010000000000000, // 0xc6 + 0b11111111111111111111011000000000, // 0xc7 + 0b11111111111111111111100010000000, // 0xc8 + 0b11111111111111111111100011000000, // 0xc9 + 0b11111111111111111111100100000000, // 0xca + 0b11111111111111111111101111000000, // 0xcb + 0b11111111111111111111101111100000, // 0xcc + 0b11111111111111111111100101000000, // 0xcd + 0b11111111111111111111000100000000, // 0xce + 0b11111111111111111111011010000000, // 0xcf + 0b11111111111111100100000000000000, // 0xd0 + 0b11111111111111110001100000000000, // 0xd1 + 0b11111111111111111111100110000000, // 0xd2 + 0b11111111111111111111110000000000, // 0xd3 + 0b11111111111111111111110000100000, // 0xd4 + 0b11111111111111111111100111000000, // 0xd5 + 0b11111111111111111111110001000000, // 0xd6 + 0b11111111111111111111001000000000, // 0xd7 + 0b11111111111111110010000000000000, // 0xd8 + 0b11111111111111110010100000000000, // 0xd9 + 0b11111111111111111111101000000000, // 0xda + 0b11111111111111111111101001000000, // 0xdb + 0b11111111111111111111111111010000, // 0xdc + 0b11111111111111111111110001100000, // 0xdd + 0b11111111111111111111110010000000, // 0xde + 0b11111111111111111111110010100000, // 0xdf + 0b11111111111111101100000000000000, // 0xe0 + 0b11111111111111111111001100000000, // 0xe1 + 0b11111111111111101101000000000000, // 0xe2 + 0b11111111111111110011000000000000, // 0xe3 + 0b11111111111111111010010000000000, // 0xe4 + 0b11111111111111110011100000000000, // 0xe5 + 0b11111111111111110100000000000000, // 0xe6 + 0b11111111111111111110011000000000, // 0xe7 + 0b11111111111111111010100000000000, // 0xe8 + 0b11111111111111111010110000000000, // 0xe9 + 0b11111111111111111111011100000000, // 0xea + 0b11111111111111111111011110000000, // 0xeb + 0b11111111111111111111010000000000, // 0xec + 0b11111111111111111111010100000000, // 0xed + 0b11111111111111111111101010000000, // 0xee + 0b11111111111111111110100000000000, // 0xef + 0b11111111111111111111101011000000, // 0xf0 + 0b11111111111111111111110011000000, // 0xf1 + 0b11111111111111111111101100000000, // 0xf2 + 0b11111111111111111111101101000000, // 0xf3 + 0b11111111111111111111110011100000, // 0xf4 + 0b11111111111111111111110100000000, // 0xf5 + 0b11111111111111111111110100100000, // 0xf6 + 0b11111111111111111111110101000000, // 0xf7 + 0b11111111111111111111110101100000, // 0xf8 + 0b11111111111111111111111111100000, // 0xf9 + 0b11111111111111111111110110000000, // 0xfa + 0b11111111111111111111110110100000, // 0xfb + 0b11111111111111111111110111000000, // 0xfc + 0b11111111111111111111110111100000, // 0xfd + 0b11111111111111111111111000000000, // 0xfe + 0b11111111111111111111101110000000, // 0xff + 0b11111111111111111111111111111100, // 0x100 +}; + +// static +const uint32_t HuffmanSpecTables::kRightCodes[] = { + 0b00000000000000000001111111111000, // 0x00 + 0b00000000011111111111111111011000, // 0x01 + 0b00001111111111111111111111100010, // 0x02 + 0b00001111111111111111111111100011, // 0x03 + 0b00001111111111111111111111100100, // 0x04 + 0b00001111111111111111111111100101, // 0x05 + 0b00001111111111111111111111100110, // 0x06 + 0b00001111111111111111111111100111, // 0x07 + 0b00001111111111111111111111101000, // 0x08 + 0b00000000111111111111111111101010, // 0x09 + 0b00111111111111111111111111111100, // 0x0a + 0b00001111111111111111111111101001, // 0x0b + 0b00001111111111111111111111101010, // 0x0c + 0b00111111111111111111111111111101, // 0x0d + 0b00001111111111111111111111101011, // 0x0e + 0b00001111111111111111111111101100, // 0x0f + 0b00001111111111111111111111101101, // 0x10 + 0b00001111111111111111111111101110, // 0x11 + 0b00001111111111111111111111101111, // 0x12 + 0b00001111111111111111111111110000, // 0x13 + 0b00001111111111111111111111110001, // 0x14 + 0b00001111111111111111111111110010, // 0x15 + 0b00111111111111111111111111111110, // 0x16 + 0b00001111111111111111111111110011, // 0x17 + 0b00001111111111111111111111110100, // 0x18 + 0b00001111111111111111111111110101, // 0x19 + 0b00001111111111111111111111110110, // 0x1a + 0b00001111111111111111111111110111, // 0x1b + 0b00001111111111111111111111111000, // 0x1c + 0b00001111111111111111111111111001, // 0x1d + 0b00001111111111111111111111111010, // 0x1e + 0b00001111111111111111111111111011, // 0x1f + 0b00000000000000000000000000010100, // 0x20 + 0b00000000000000000000001111111000, // '!' + 0b00000000000000000000001111111001, // '\"' + 0b00000000000000000000111111111010, // '#' + 0b00000000000000000001111111111001, // '$' + 0b00000000000000000000000000010101, // '%' + 0b00000000000000000000000011111000, // '&' + 0b00000000000000000000011111111010, // '\'' + 0b00000000000000000000001111111010, // '(' + 0b00000000000000000000001111111011, // ')' + 0b00000000000000000000000011111001, // '*' + 0b00000000000000000000011111111011, // '+' + 0b00000000000000000000000011111010, // ',' + 0b00000000000000000000000000010110, // '-' + 0b00000000000000000000000000010111, // '.' + 0b00000000000000000000000000011000, // '/' + 0b00000000000000000000000000000000, // '0' + 0b00000000000000000000000000000001, // '1' + 0b00000000000000000000000000000010, // '2' + 0b00000000000000000000000000011001, // '3' + 0b00000000000000000000000000011010, // '4' + 0b00000000000000000000000000011011, // '5' + 0b00000000000000000000000000011100, // '6' + 0b00000000000000000000000000011101, // '7' + 0b00000000000000000000000000011110, // '8' + 0b00000000000000000000000000011111, // '9' + 0b00000000000000000000000001011100, // ':' + 0b00000000000000000000000011111011, // ';' + 0b00000000000000000111111111111100, // '<' + 0b00000000000000000000000000100000, // '=' + 0b00000000000000000000111111111011, // '>' + 0b00000000000000000000001111111100, // '?' + 0b00000000000000000001111111111010, // '@' + 0b00000000000000000000000000100001, // 'A' + 0b00000000000000000000000001011101, // 'B' + 0b00000000000000000000000001011110, // 'C' + 0b00000000000000000000000001011111, // 'D' + 0b00000000000000000000000001100000, // 'E' + 0b00000000000000000000000001100001, // 'F' + 0b00000000000000000000000001100010, // 'G' + 0b00000000000000000000000001100011, // 'H' + 0b00000000000000000000000001100100, // 'I' + 0b00000000000000000000000001100101, // 'J' + 0b00000000000000000000000001100110, // 'K' + 0b00000000000000000000000001100111, // 'L' + 0b00000000000000000000000001101000, // 'M' + 0b00000000000000000000000001101001, // 'N' + 0b00000000000000000000000001101010, // 'O' + 0b00000000000000000000000001101011, // 'P' + 0b00000000000000000000000001101100, // 'Q' + 0b00000000000000000000000001101101, // 'R' + 0b00000000000000000000000001101110, // 'S' + 0b00000000000000000000000001101111, // 'T' + 0b00000000000000000000000001110000, // 'U' + 0b00000000000000000000000001110001, // 'V' + 0b00000000000000000000000001110010, // 'W' + 0b00000000000000000000000011111100, // 'X' + 0b00000000000000000000000001110011, // 'Y' + 0b00000000000000000000000011111101, // 'Z' + 0b00000000000000000001111111111011, // '[' + 0b00000000000001111111111111110000, // '\\' + 0b00000000000000000001111111111100, // ']' + 0b00000000000000000011111111111100, // '^' + 0b00000000000000000000000000100010, // '_' + 0b00000000000000000111111111111101, // '`' + 0b00000000000000000000000000000011, // 'a' + 0b00000000000000000000000000100011, // 'b' + 0b00000000000000000000000000000100, // 'c' + 0b00000000000000000000000000100100, // 'd' + 0b00000000000000000000000000000101, // 'e' + 0b00000000000000000000000000100101, // 'f' + 0b00000000000000000000000000100110, // 'g' + 0b00000000000000000000000000100111, // 'h' + 0b00000000000000000000000000000110, // 'i' + 0b00000000000000000000000001110100, // 'j' + 0b00000000000000000000000001110101, // 'k' + 0b00000000000000000000000000101000, // 'l' + 0b00000000000000000000000000101001, // 'm' + 0b00000000000000000000000000101010, // 'n' + 0b00000000000000000000000000000111, // 'o' + 0b00000000000000000000000000101011, // 'p' + 0b00000000000000000000000001110110, // 'q' + 0b00000000000000000000000000101100, // 'r' + 0b00000000000000000000000000001000, // 's' + 0b00000000000000000000000000001001, // 't' + 0b00000000000000000000000000101101, // 'u' + 0b00000000000000000000000001110111, // 'v' + 0b00000000000000000000000001111000, // 'w' + 0b00000000000000000000000001111001, // 'x' + 0b00000000000000000000000001111010, // 'y' + 0b00000000000000000000000001111011, // 'z' + 0b00000000000000000111111111111110, // '{' + 0b00000000000000000000011111111100, // '|' + 0b00000000000000000011111111111101, // '}' + 0b00000000000000000001111111111101, // '~' + 0b00001111111111111111111111111100, // 0x7f + 0b00000000000011111111111111100110, // 0x80 + 0b00000000001111111111111111010010, // 0x81 + 0b00000000000011111111111111100111, // 0x82 + 0b00000000000011111111111111101000, // 0x83 + 0b00000000001111111111111111010011, // 0x84 + 0b00000000001111111111111111010100, // 0x85 + 0b00000000001111111111111111010101, // 0x86 + 0b00000000011111111111111111011001, // 0x87 + 0b00000000001111111111111111010110, // 0x88 + 0b00000000011111111111111111011010, // 0x89 + 0b00000000011111111111111111011011, // 0x8a + 0b00000000011111111111111111011100, // 0x8b + 0b00000000011111111111111111011101, // 0x8c + 0b00000000011111111111111111011110, // 0x8d + 0b00000000111111111111111111101011, // 0x8e + 0b00000000011111111111111111011111, // 0x8f + 0b00000000111111111111111111101100, // 0x90 + 0b00000000111111111111111111101101, // 0x91 + 0b00000000001111111111111111010111, // 0x92 + 0b00000000011111111111111111100000, // 0x93 + 0b00000000111111111111111111101110, // 0x94 + 0b00000000011111111111111111100001, // 0x95 + 0b00000000011111111111111111100010, // 0x96 + 0b00000000011111111111111111100011, // 0x97 + 0b00000000011111111111111111100100, // 0x98 + 0b00000000000111111111111111011100, // 0x99 + 0b00000000001111111111111111011000, // 0x9a + 0b00000000011111111111111111100101, // 0x9b + 0b00000000001111111111111111011001, // 0x9c + 0b00000000011111111111111111100110, // 0x9d + 0b00000000011111111111111111100111, // 0x9e + 0b00000000111111111111111111101111, // 0x9f + 0b00000000001111111111111111011010, // 0xa0 + 0b00000000000111111111111111011101, // 0xa1 + 0b00000000000011111111111111101001, // 0xa2 + 0b00000000001111111111111111011011, // 0xa3 + 0b00000000001111111111111111011100, // 0xa4 + 0b00000000011111111111111111101000, // 0xa5 + 0b00000000011111111111111111101001, // 0xa6 + 0b00000000000111111111111111011110, // 0xa7 + 0b00000000011111111111111111101010, // 0xa8 + 0b00000000001111111111111111011101, // 0xa9 + 0b00000000001111111111111111011110, // 0xaa + 0b00000000111111111111111111110000, // 0xab + 0b00000000000111111111111111011111, // 0xac + 0b00000000001111111111111111011111, // 0xad + 0b00000000011111111111111111101011, // 0xae + 0b00000000011111111111111111101100, // 0xaf + 0b00000000000111111111111111100000, // 0xb0 + 0b00000000000111111111111111100001, // 0xb1 + 0b00000000001111111111111111100000, // 0xb2 + 0b00000000000111111111111111100010, // 0xb3 + 0b00000000011111111111111111101101, // 0xb4 + 0b00000000001111111111111111100001, // 0xb5 + 0b00000000011111111111111111101110, // 0xb6 + 0b00000000011111111111111111101111, // 0xb7 + 0b00000000000011111111111111101010, // 0xb8 + 0b00000000001111111111111111100010, // 0xb9 + 0b00000000001111111111111111100011, // 0xba + 0b00000000001111111111111111100100, // 0xbb + 0b00000000011111111111111111110000, // 0xbc + 0b00000000001111111111111111100101, // 0xbd + 0b00000000001111111111111111100110, // 0xbe + 0b00000000011111111111111111110001, // 0xbf + 0b00000011111111111111111111100000, // 0xc0 + 0b00000011111111111111111111100001, // 0xc1 + 0b00000000000011111111111111101011, // 0xc2 + 0b00000000000001111111111111110001, // 0xc3 + 0b00000000001111111111111111100111, // 0xc4 + 0b00000000011111111111111111110010, // 0xc5 + 0b00000000001111111111111111101000, // 0xc6 + 0b00000001111111111111111111101100, // 0xc7 + 0b00000011111111111111111111100010, // 0xc8 + 0b00000011111111111111111111100011, // 0xc9 + 0b00000011111111111111111111100100, // 0xca + 0b00000111111111111111111111011110, // 0xcb + 0b00000111111111111111111111011111, // 0xcc + 0b00000011111111111111111111100101, // 0xcd + 0b00000000111111111111111111110001, // 0xce + 0b00000001111111111111111111101101, // 0xcf + 0b00000000000001111111111111110010, // 0xd0 + 0b00000000000111111111111111100011, // 0xd1 + 0b00000011111111111111111111100110, // 0xd2 + 0b00000111111111111111111111100000, // 0xd3 + 0b00000111111111111111111111100001, // 0xd4 + 0b00000011111111111111111111100111, // 0xd5 + 0b00000111111111111111111111100010, // 0xd6 + 0b00000000111111111111111111110010, // 0xd7 + 0b00000000000111111111111111100100, // 0xd8 + 0b00000000000111111111111111100101, // 0xd9 + 0b00000011111111111111111111101000, // 0xda + 0b00000011111111111111111111101001, // 0xdb + 0b00001111111111111111111111111101, // 0xdc + 0b00000111111111111111111111100011, // 0xdd + 0b00000111111111111111111111100100, // 0xde + 0b00000111111111111111111111100101, // 0xdf + 0b00000000000011111111111111101100, // 0xe0 + 0b00000000111111111111111111110011, // 0xe1 + 0b00000000000011111111111111101101, // 0xe2 + 0b00000000000111111111111111100110, // 0xe3 + 0b00000000001111111111111111101001, // 0xe4 + 0b00000000000111111111111111100111, // 0xe5 + 0b00000000000111111111111111101000, // 0xe6 + 0b00000000011111111111111111110011, // 0xe7 + 0b00000000001111111111111111101010, // 0xe8 + 0b00000000001111111111111111101011, // 0xe9 + 0b00000001111111111111111111101110, // 0xea + 0b00000001111111111111111111101111, // 0xeb + 0b00000000111111111111111111110100, // 0xec + 0b00000000111111111111111111110101, // 0xed + 0b00000011111111111111111111101010, // 0xee + 0b00000000011111111111111111110100, // 0xef + 0b00000011111111111111111111101011, // 0xf0 + 0b00000111111111111111111111100110, // 0xf1 + 0b00000011111111111111111111101100, // 0xf2 + 0b00000011111111111111111111101101, // 0xf3 + 0b00000111111111111111111111100111, // 0xf4 + 0b00000111111111111111111111101000, // 0xf5 + 0b00000111111111111111111111101001, // 0xf6 + 0b00000111111111111111111111101010, // 0xf7 + 0b00000111111111111111111111101011, // 0xf8 + 0b00001111111111111111111111111110, // 0xf9 + 0b00000111111111111111111111101100, // 0xfa + 0b00000111111111111111111111101101, // 0xfb + 0b00000111111111111111111111101110, // 0xfc + 0b00000111111111111111111111101111, // 0xfd + 0b00000111111111111111111111110000, // 0xfe + 0b00000011111111111111111111101110, // 0xff + 0b00111111111111111111111111111111, // 0x100 +}; +// clang-format off + +} // namespace http2 diff --git a/quiche/http2/hpack/huffman/huffman_spec_tables.h b/quiche/http2/hpack/huffman/huffman_spec_tables.h new file mode 100644 index 000000000000..7e66fef5027c --- /dev/null +++ b/quiche/http2/hpack/huffman/huffman_spec_tables.h @@ -0,0 +1,31 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_HUFFMAN_HUFFMAN_SPEC_TABLES_H_ +#define QUICHE_HTTP2_HPACK_HUFFMAN_HUFFMAN_SPEC_TABLES_H_ + +// Tables describing the Huffman encoding of bytes as specified by RFC7541. + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +struct QUICHE_EXPORT HuffmanSpecTables { + // Number of bits in the encoding of each symbol (byte). + static const uint8_t kCodeLengths[257]; + + // The encoding of each symbol, right justified (as printed), which means that + // the last bit of the encoding is the low-order bit of the uint32. + static const uint32_t kRightCodes[257]; + + // The encoding of each symbol, left justified (as printed), which means that + // the first bit of the encoding is the high-order bit of the uint32. + static const uint32_t kLeftCodes[257]; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_HUFFMAN_HUFFMAN_SPEC_TABLES_H_ diff --git a/quiche/http2/hpack/varint/hpack_varint_decoder.cc b/quiche/http2/hpack/varint/hpack_varint_decoder.cc new file mode 100644 index 000000000000..48f5aa3769e6 --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_decoder.cc @@ -0,0 +1,143 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" + +#include "absl/strings/str_cat.h" + +namespace http2 { + +DecodeStatus HpackVarintDecoder::Start(uint8_t prefix_value, + uint8_t prefix_length, + DecodeBuffer* db) { + QUICHE_DCHECK_LE(3u, prefix_length); + QUICHE_DCHECK_LE(prefix_length, 8u); + + // |prefix_mask| defines the sequence of low-order bits of the first byte + // that encode the prefix of the value. It is also the marker in those bits + // of the first byte indicating that at least one extension byte is needed. + const uint8_t prefix_mask = (1 << prefix_length) - 1; + + // Ignore the bits that aren't a part of the prefix of the varint. + value_ = prefix_value & prefix_mask; + + if (value_ < prefix_mask) { + MarkDone(); + return DecodeStatus::kDecodeDone; + } + + offset_ = 0; + return Resume(db); +} + +DecodeStatus HpackVarintDecoder::StartExtended(uint8_t prefix_length, + DecodeBuffer* db) { + QUICHE_DCHECK_LE(3u, prefix_length); + QUICHE_DCHECK_LE(prefix_length, 8u); + + value_ = (1 << prefix_length) - 1; + offset_ = 0; + return Resume(db); +} + +DecodeStatus HpackVarintDecoder::Resume(DecodeBuffer* db) { + // There can be at most 10 continuation bytes. Offset is zero for the + // first one and increases by 7 for each subsequent one. + const uint8_t kMaxOffset = 63; + CheckNotDone(); + + // Process most extension bytes without the need for overflow checking. + while (offset_ < kMaxOffset) { + if (db->Empty()) { + return DecodeStatus::kDecodeInProgress; + } + + uint8_t byte = db->DecodeUInt8(); + uint64_t summand = byte & 0x7f; + + // Shifting a 7 bit value to the left by at most 56 places can never + // overflow on uint64_t. + QUICHE_DCHECK_LE(offset_, 56); + QUICHE_DCHECK_LE(summand, std::numeric_limits::max() >> offset_); + + summand <<= offset_; + + // At this point, + // |value_| is at most (2^prefix_length - 1) + (2^49 - 1), and + // |summand| is at most 255 << 56 (which is smaller than 2^63), + // so adding them can never overflow on uint64_t. + QUICHE_DCHECK_LE(value_, std::numeric_limits::max() - summand); + + value_ += summand; + + // Decoding ends if continuation flag is not set. + if ((byte & 0x80) == 0) { + MarkDone(); + return DecodeStatus::kDecodeDone; + } + + offset_ += 7; + } + + if (db->Empty()) { + return DecodeStatus::kDecodeInProgress; + } + + QUICHE_DCHECK_EQ(kMaxOffset, offset_); + + uint8_t byte = db->DecodeUInt8(); + // No more extension bytes are allowed after this. + if ((byte & 0x80) == 0) { + uint64_t summand = byte & 0x7f; + // Check for overflow in left shift. + if (summand <= std::numeric_limits::max() >> offset_) { + summand <<= offset_; + // Check for overflow in addition. + if (value_ <= std::numeric_limits::max() - summand) { + value_ += summand; + MarkDone(); + return DecodeStatus::kDecodeDone; + } + } + } + + // Signal error if value is too large or there are too many extension bytes. + QUICHE_DLOG(WARNING) + << "Variable length int encoding is too large or too long. " + << DebugString(); + MarkDone(); + return DecodeStatus::kDecodeError; +} + +uint64_t HpackVarintDecoder::value() const { + CheckDone(); + return value_; +} + +void HpackVarintDecoder::set_value(uint64_t v) { + MarkDone(); + value_ = v; +} + +std::string HpackVarintDecoder::DebugString() const { + return absl::StrCat("HpackVarintDecoder(value=", value_, ", offset=", offset_, + ")"); +} + +DecodeStatus HpackVarintDecoder::StartForTest(uint8_t prefix_value, + uint8_t prefix_length, + DecodeBuffer* db) { + return Start(prefix_value, prefix_length, db); +} + +DecodeStatus HpackVarintDecoder::StartExtendedForTest(uint8_t prefix_length, + DecodeBuffer* db) { + return StartExtended(prefix_length, db); +} + +DecodeStatus HpackVarintDecoder::ResumeForTest(DecodeBuffer* db) { + return Resume(db); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/varint/hpack_varint_decoder.h b/quiche/http2/hpack/varint/hpack_varint_decoder.h new file mode 100644 index 000000000000..af998a2453a3 --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_decoder.h @@ -0,0 +1,128 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// HpackVarintDecoder decodes HPACK variable length unsigned integers. In HPACK, +// these integers are used to identify static or dynamic table index entries, to +// specify string lengths, and to update the size limit of the dynamic table. +// In QPACK, in addition to these uses, these integers also identify streams. +// +// The caller will need to validate that the decoded value is in an acceptable +// range. +// +// For details of the encoding, see: +// http://httpwg.org/specs/rfc7541.html#integer.representation +// +// HpackVarintDecoder supports decoding any integer that can be represented on +// uint64_t, thereby exceeding the requirements for QPACK: "QPACK +// implementations MUST be able to decode integers up to 62 bits long." See +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#rfc.section.5.1.1 +// +// This decoder supports at most 10 extension bytes (bytes following the prefix, +// also called continuation bytes). An encoder is allowed to zero pad the +// encoded integer on the left, thereby increasing the number of extension +// bytes. If an encoder uses so much padding that the number of extension bytes +// exceeds the limit, then this decoder signals an error. + +#ifndef QUICHE_HTTP2_HPACK_VARINT_HPACK_VARINT_DECODER_H_ +#define QUICHE_HTTP2_HPACK_VARINT_HPACK_VARINT_DECODER_H_ + +#include +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +// Sentinel value for |HpackVarintDecoder::offset_| to signify that decoding is +// completed. Only used in debug builds. +#ifndef NDEBUG +const uint8_t kHpackVarintDecoderOffsetDone = + std::numeric_limits::max(); +#endif + +// Decodes an HPACK variable length unsigned integer, in a resumable fashion +// so it can handle running out of input in the DecodeBuffer. Call Start or +// StartExtended the first time (when decoding the byte that contains the +// prefix), then call Resume later if it is necessary to resume. When done, +// call value() to retrieve the decoded value. +// +// No constructor or destructor. Holds no resources, so destruction isn't +// needed. Start and StartExtended handles the initialization of member +// variables. This is necessary in order for HpackVarintDecoder to be part +// of a union. +class QUICHE_EXPORT HpackVarintDecoder { + public: + // |prefix_value| is the first byte of the encoded varint. + // |prefix_length| is number of bits in the first byte that are used for + // encoding the integer. |db| is the rest of the buffer, that is, not + // including the first byte. + DecodeStatus Start(uint8_t prefix_value, uint8_t prefix_length, + DecodeBuffer* db); + + // The caller has already determined that the encoding requires multiple + // bytes, i.e. that the 3 to 8 low-order bits (the number determined by + // |prefix_length|) of the first byte are are all 1. |db| is the rest of the + // buffer, that is, not including the first byte. + DecodeStatus StartExtended(uint8_t prefix_length, DecodeBuffer* db); + + // Resume decoding a variable length integer after an earlier + // call to Start or StartExtended returned kDecodeInProgress. + DecodeStatus Resume(DecodeBuffer* db); + + uint64_t value() const; + + // This supports optimizations for the case of a varint with zero extension + // bytes, where the handling of the prefix is done by the caller. + void set_value(uint64_t v); + + // All the public methods below are for supporting assertions and tests. + + std::string DebugString() const; + + // For benchmarking, these methods ensure the decoder + // is NOT inlined into the caller. + DecodeStatus StartForTest(uint8_t prefix_value, uint8_t prefix_length, + DecodeBuffer* db); + DecodeStatus StartExtendedForTest(uint8_t prefix_length, DecodeBuffer* db); + DecodeStatus ResumeForTest(DecodeBuffer* db); + + private: + // Protection in case Resume is called when it shouldn't be. + void MarkDone() { +#ifndef NDEBUG + offset_ = kHpackVarintDecoderOffsetDone; +#endif + } + void CheckNotDone() const { +#ifndef NDEBUG + QUICHE_DCHECK_NE(kHpackVarintDecoderOffsetDone, offset_); +#endif + } + void CheckDone() const { +#ifndef NDEBUG + QUICHE_DCHECK_EQ(kHpackVarintDecoderOffsetDone, offset_); +#endif + } + + // These fields are initialized just to keep ASAN happy about reading + // them from DebugString(). + + // The encoded integer is being accumulated in |value_|. When decoding is + // complete, |value_| holds the result. + uint64_t value_ = 0; + + // Each extension byte encodes in its lowest 7 bits a segment of the integer. + // |offset_| is the number of places this segment has to be shifted to the + // left for decoding. It is zero for the first extension byte, and increases + // by 7 for each subsequent extension byte. + uint8_t offset_ = 0; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_VARINT_HPACK_VARINT_DECODER_H_ diff --git a/quiche/http2/hpack/varint/hpack_varint_decoder_test.cc b/quiche/http2/hpack/varint/hpack_varint_decoder_test.cc new file mode 100644 index 000000000000..950b01aaaa07 --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_decoder_test.cc @@ -0,0 +1,309 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" + +// Test HpackVarintDecoder against hardcoded data. + +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { +namespace { + +class HpackVarintDecoderTest + : public RandomDecoderTest, + public ::testing::WithParamInterface> { + protected: + HpackVarintDecoderTest() + : high_bits_(std::get<0>(GetParam())), + suffix_(absl::HexStringToBytes(std::get<1>(GetParam()))), + prefix_length_(0) {} + + void DecodeExpectSuccess(absl::string_view data, uint32_t prefix_length, + uint64_t expected_value) { + Validator validator = [expected_value, this]( + const DecodeBuffer& /*db*/, + DecodeStatus /*status*/) -> AssertionResult { + HTTP2_VERIFY_EQ(expected_value, decoder_.value()) + << "Value doesn't match expected: " << decoder_.value() + << " != " << expected_value; + return AssertionSuccess(); + }; + + // First validate that decoding is done and that we've advanced the cursor + // the expected amount. + validator = ValidateDoneAndOffset(/* offset = */ data.size(), validator); + + EXPECT_TRUE(Decode(data, prefix_length, validator)); + + EXPECT_EQ(expected_value, decoder_.value()); + } + + void DecodeExpectError(absl::string_view data, uint32_t prefix_length) { + Validator validator = [](const DecodeBuffer& /*db*/, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(DecodeStatus::kDecodeError, status); + return AssertionSuccess(); + }; + + EXPECT_TRUE(Decode(data, prefix_length, validator)); + } + + private: + AssertionResult Decode(absl::string_view data, uint32_t prefix_length, + const Validator validator) { + prefix_length_ = prefix_length; + + // Copy |data| so that it can be modified. + std::string data_copy(data); + + // Bits of the first byte not part of the prefix should be ignored. + uint8_t high_bits_mask = 0b11111111 << prefix_length_; + data_copy[0] |= (high_bits_mask & high_bits_); + + // Extra bytes appended to the input should be ignored. + data_copy.append(suffix_); + + DecodeBuffer b(data_copy); + + // StartDecoding, above, requires the DecodeBuffer be non-empty so that it + // can call Start with the prefix byte. + bool return_non_zero_on_first = true; + + return DecodeAndValidateSeveralWays(&b, return_non_zero_on_first, + validator); + } + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + QUICHE_CHECK_LT(0u, b->Remaining()); + uint8_t prefix = b->DecodeUInt8(); + return decoder_.Start(prefix, prefix_length_, b); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + return decoder_.Resume(b); + } + + // Bits of the first byte not part of the prefix. + const uint8_t high_bits_; + // Extra bytes appended to the input. + const std::string suffix_; + + HpackVarintDecoder decoder_; + uint8_t prefix_length_; +}; + +INSTANTIATE_TEST_SUITE_P( + HpackVarintDecoderTest, HpackVarintDecoderTest, + ::testing::Combine( + // Bits of the first byte not part of the prefix should be ignored. + ::testing::Values(0b00000000, 0b11111111, 0b10101010), + // Extra bytes appended to the input should be ignored. + ::testing::Values("", "00", "666f6f"))); + +struct { + const char* data; + uint32_t prefix_length; + uint64_t expected_value; +} kSuccessTestData[] = { + // Zero value with different prefix lengths. + {"00", 3, 0}, + {"00", 4, 0}, + {"00", 5, 0}, + {"00", 6, 0}, + {"00", 7, 0}, + {"00", 8, 0}, + // Small values that fit in the prefix. + {"06", 3, 6}, + {"0d", 4, 13}, + {"10", 5, 16}, + {"29", 6, 41}, + {"56", 7, 86}, + {"bf", 8, 191}, + // Values of 2^n-1, which have an all-zero extension byte. + {"0700", 3, 7}, + {"0f00", 4, 15}, + {"1f00", 5, 31}, + {"3f00", 6, 63}, + {"7f00", 7, 127}, + {"ff00", 8, 255}, + // Values of 2^n-1, plus one extra byte of padding. + {"078000", 3, 7}, + {"0f8000", 4, 15}, + {"1f8000", 5, 31}, + {"3f8000", 6, 63}, + {"7f8000", 7, 127}, + {"ff8000", 8, 255}, + // Values requiring one extension byte. + {"0760", 3, 103}, + {"0f2a", 4, 57}, + {"1f7f", 5, 158}, + {"3f02", 6, 65}, + {"7f49", 7, 200}, + {"ff6f", 8, 366}, + // Values requiring one extension byte, plus one byte of padding. + {"07e000", 3, 103}, + {"0faa00", 4, 57}, + {"1fff00", 5, 158}, + {"3f8200", 6, 65}, + {"7fc900", 7, 200}, + {"ffef00", 8, 366}, + // Values requiring one extension byte, plus two bytes of padding. + {"07e08000", 3, 103}, + {"0faa8000", 4, 57}, + {"1fff8000", 5, 158}, + {"3f828000", 6, 65}, + {"7fc98000", 7, 200}, + {"ffef8000", 8, 366}, + // Values requiring one extension byte, plus the maximum amount of padding. + {"07e0808080808080808000", 3, 103}, + {"0faa808080808080808000", 4, 57}, + {"1fff808080808080808000", 5, 158}, + {"3f82808080808080808000", 6, 65}, + {"7fc9808080808080808000", 7, 200}, + {"ffef808080808080808000", 8, 366}, + // Values requiring two extension bytes. + {"07b260", 3, 12345}, + {"0f8a2a", 4, 5401}, + {"1fa87f", 5, 16327}, + {"3fd002", 6, 399}, + {"7fff49", 7, 9598}, + {"ffe32f", 8, 6370}, + // Values requiring two extension bytes, plus one byte of padding. + {"07b2e000", 3, 12345}, + {"0f8aaa00", 4, 5401}, + {"1fa8ff00", 5, 16327}, + {"3fd08200", 6, 399}, + {"7fffc900", 7, 9598}, + {"ffe3af00", 8, 6370}, + // Values requiring two extension bytes, plus the maximum amount of padding. + {"07b2e080808080808000", 3, 12345}, + {"0f8aaa80808080808000", 4, 5401}, + {"1fa8ff80808080808000", 5, 16327}, + {"3fd08280808080808000", 6, 399}, + {"7fffc980808080808000", 7, 9598}, + {"ffe3af80808080808000", 8, 6370}, + // Values requiring three extension bytes. + {"078ab260", 3, 1579281}, + {"0fc18a2a", 4, 689488}, + {"1fada87f", 5, 2085964}, + {"3fa0d002", 6, 43103}, + {"7ffeff49", 7, 1212541}, + {"ff93de23", 8, 585746}, + // Values requiring three extension bytes, plus one byte of padding. + {"078ab2e000", 3, 1579281}, + {"0fc18aaa00", 4, 689488}, + {"1fada8ff00", 5, 2085964}, + {"3fa0d08200", 6, 43103}, + {"7ffeffc900", 7, 1212541}, + {"ff93dea300", 8, 585746}, + // Values requiring four extension bytes. + {"079f8ab260", 3, 202147110}, + {"0fa2c18a2a", 4, 88252593}, + {"1fd0ada87f", 5, 266999535}, + {"3ff9a0d002", 6, 5509304}, + {"7f9efeff49", 7, 155189149}, + {"ffaa82f404", 8, 10289705}, + // Values requiring four extension bytes, plus one byte of padding. + {"079f8ab2e000", 3, 202147110}, + {"0fa2c18aaa00", 4, 88252593}, + {"1fd0ada8ff00", 5, 266999535}, + {"3ff9a0d08200", 6, 5509304}, + {"7f9efeffc900", 7, 155189149}, + {"ffaa82f48400", 8, 10289705}, + // Values requiring six extension bytes. + {"0783aa9f8ab260", 3, 3311978140938}, + {"0ff0b0a2c18a2a", 4, 1445930244223}, + {"1fda84d0ada87f", 5, 4374519874169}, + {"3fb5fbf9a0d002", 6, 90263420404}, + {"7fcff19efeff49", 7, 2542616951118}, + {"ff9fa486bbc327", 8, 1358138807070}, + // Values requiring eight extension bytes. + {"07f19883aa9f8ab260", 3, 54263449861016696}, + {"0f84fdf0b0a2c18a2a", 4, 23690121121119891}, + {"1fa0dfda84d0ada87f", 5, 71672133617889215}, + {"3f9ff0b5fbf9a0d002", 6, 1478875878881374}, + {"7ffbc1cff19efeff49", 7, 41658236125045114}, + {"ff91b6fb85af99c342", 8, 37450237664484368}, + // Values requiring ten extension bytes. + {"0794f1f19883aa9f8ab201", 3, 12832019021693745307u}, + {"0fa08f84fdf0b0a2c18a01", 4, 9980690937382242223u}, + {"1fbfdda0dfda84d0ada801", 5, 12131360551794650846u}, + {"3f9dc79ff0b5fbf9a0d001", 6, 15006530362736632796u}, + {"7f8790fbc1cff19efeff01", 7, 18445754019193211014u}, + {"fffba8c5b8d3fe9f8c8401", 8, 9518498503615141242u}, + // Maximum value: 2^64-1. + {"07f8ffffffffffffffff01", 3, 18446744073709551615u}, + {"0ff0ffffffffffffffff01", 4, 18446744073709551615u}, + {"1fe0ffffffffffffffff01", 5, 18446744073709551615u}, + {"3fc0ffffffffffffffff01", 6, 18446744073709551615u}, + {"7f80ffffffffffffffff01", 7, 18446744073709551615u}, + {"ff80feffffffffffffff01", 8, 18446744073709551615u}, + // Examples from RFC7541 C.1. + {"0a", 5, 10}, + {"1f9a0a", 5, 1337}, +}; + +TEST_P(HpackVarintDecoderTest, Success) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kSuccessTestData); ++i) { + DecodeExpectSuccess(absl::HexStringToBytes(kSuccessTestData[i].data), + kSuccessTestData[i].prefix_length, + kSuccessTestData[i].expected_value); + } +} + +struct { + const char* data; + uint32_t prefix_length; +} kErrorTestData[] = { + // Too many extension bytes, all 0s (except for extension bit in each byte). + {"0780808080808080808080", 3}, + {"0f80808080808080808080", 4}, + {"1f80808080808080808080", 5}, + {"3f80808080808080808080", 6}, + {"7f80808080808080808080", 7}, + {"ff80808080808080808080", 8}, + // Too many extension bytes, all 1s. + {"07ffffffffffffffffffff", 3}, + {"0fffffffffffffffffffff", 4}, + {"1fffffffffffffffffffff", 5}, + {"3fffffffffffffffffffff", 6}, + {"7fffffffffffffffffffff", 7}, + {"ffffffffffffffffffffff", 8}, + // Value of 2^64, one higher than maximum of 2^64-1. + {"07f9ffffffffffffffff01", 3}, + {"0ff1ffffffffffffffff01", 4}, + {"1fe1ffffffffffffffff01", 5}, + {"3fc1ffffffffffffffff01", 6}, + {"7f81ffffffffffffffff01", 7}, + {"ff81feffffffffffffff01", 8}, + // Maximum value: 2^64-1, with one byte of padding. + {"07f8ffffffffffffffff8100", 3}, + {"0ff0ffffffffffffffff8100", 4}, + {"1fe0ffffffffffffffff8100", 5}, + {"3fc0ffffffffffffffff8100", 6}, + {"7f80ffffffffffffffff8100", 7}, + {"ff80feffffffffffffff8100", 8}}; + +TEST_P(HpackVarintDecoderTest, Error) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kErrorTestData); ++i) { + DecodeExpectError(absl::HexStringToBytes(kErrorTestData[i].data), + kErrorTestData[i].prefix_length); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/varint/hpack_varint_encoder.cc b/quiche/http2/hpack/varint/hpack_varint_encoder.cc new file mode 100644 index 000000000000..07c5141916e7 --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_encoder.cc @@ -0,0 +1,47 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/varint/hpack_varint_encoder.h" + +#include + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +// static +void HpackVarintEncoder::Encode(uint8_t high_bits, uint8_t prefix_length, + uint64_t varint, std::string* output) { + QUICHE_DCHECK_LE(1u, prefix_length); + QUICHE_DCHECK_LE(prefix_length, 8u); + + // prefix_mask defines the sequence of low-order bits of the first byte + // that encode the prefix of the value. It is also the marker in those bits + // of the first byte indicating that at least one extension byte is needed. + const uint8_t prefix_mask = (1 << prefix_length) - 1; + QUICHE_DCHECK_EQ(0, high_bits & prefix_mask); + + if (varint < prefix_mask) { + // The integer fits into the prefix in its entirety. + unsigned char first_byte = high_bits | static_cast(varint); + output->push_back(first_byte); + return; + } + + // Extension bytes are needed. + unsigned char first_byte = high_bits | prefix_mask; + output->push_back(first_byte); + + varint -= prefix_mask; + while (varint >= 128) { + // Encode the next seven bits, with continuation bit set to one. + output->push_back(0b10000000 | (varint % 128)); + varint >>= 7; + } + + // Encode final seven bits, with continuation bit set to zero. + output->push_back(varint); +} + +} // namespace http2 diff --git a/quiche/http2/hpack/varint/hpack_varint_encoder.h b/quiche/http2/hpack/varint/hpack_varint_encoder.h new file mode 100644 index 000000000000..0e16009eb00d --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_encoder.h @@ -0,0 +1,29 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HPACK_VARINT_HPACK_VARINT_ENCODER_H_ +#define QUICHE_HTTP2_HPACK_VARINT_HPACK_VARINT_ENCODER_H_ + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// HPACK integer encoder class with single static method implementing variable +// length integer representation defined in RFC7541, Section 5.1: +// https://httpwg.org/specs/rfc7541.html#integer.representation +class QUICHE_EXPORT HpackVarintEncoder { + public: + // Encode |varint|, appending encoded data to |*output|. + // Appends between 1 and 11 bytes in total. + static void Encode(uint8_t high_bits, uint8_t prefix_length, uint64_t varint, + std::string* output); +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_HPACK_VARINT_HPACK_VARINT_ENCODER_H_ diff --git a/quiche/http2/hpack/varint/hpack_varint_encoder_test.cc b/quiche/http2/hpack/varint/hpack_varint_encoder_test.cc new file mode 100644 index 000000000000..bd51606b0792 --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_encoder_test.cc @@ -0,0 +1,161 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/varint/hpack_varint_encoder.h" + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +struct { + uint8_t high_bits; + uint8_t prefix_length; + uint64_t value; + uint8_t expected_encoding; +} kShortTestData[] = {{0b10110010, 1, 0, 0b10110010}, + {0b10101100, 2, 2, 0b10101110}, + {0b10100000, 3, 6, 0b10100110}, + {0b10110000, 4, 13, 0b10111101}, + {0b10100000, 5, 8, 0b10101000}, + {0b11000000, 6, 48, 0b11110000}, + {0b10000000, 7, 99, 0b11100011}, + // Example from RFC7541 C.1. + {0b00000000, 5, 10, 0b00001010}}; + +// Encode integers that fit in the prefix. +TEST(HpackVarintEncoderTest, Short) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kShortTestData); ++i) { + std::string output; + HpackVarintEncoder::Encode(kShortTestData[i].high_bits, + kShortTestData[i].prefix_length, + kShortTestData[i].value, &output); + ASSERT_EQ(1u, output.size()); + EXPECT_EQ(kShortTestData[i].expected_encoding, + static_cast(output[0])); + } +} + +struct { + uint8_t high_bits; + uint8_t prefix_length; + uint64_t value; + const char* expected_encoding; +} kLongTestData[] = { + // One extension byte. + {0b10011000, 3, 103, "9f60"}, + {0b10010000, 4, 57, "9f2a"}, + {0b11000000, 5, 158, "df7f"}, + {0b01000000, 6, 65, "7f02"}, + {0b00000000, 7, 200, "7f49"}, + // Two extension bytes. + {0b10011000, 3, 12345, "9fb260"}, + {0b10010000, 4, 5401, "9f8a2a"}, + {0b11000000, 5, 16327, "dfa87f"}, + {0b01000000, 6, 399, "7fd002"}, + {0b00000000, 7, 9598, "7fff49"}, + // Three extension bytes. + {0b10011000, 3, 1579281, "9f8ab260"}, + {0b10010000, 4, 689488, "9fc18a2a"}, + {0b11000000, 5, 2085964, "dfada87f"}, + {0b01000000, 6, 43103, "7fa0d002"}, + {0b00000000, 7, 1212541, "7ffeff49"}, + // Four extension bytes. + {0b10011000, 3, 202147110, "9f9f8ab260"}, + {0b10010000, 4, 88252593, "9fa2c18a2a"}, + {0b11000000, 5, 266999535, "dfd0ada87f"}, + {0b01000000, 6, 5509304, "7ff9a0d002"}, + {0b00000000, 7, 155189149, "7f9efeff49"}, + // Six extension bytes. + {0b10011000, 3, 3311978140938, "9f83aa9f8ab260"}, + {0b10010000, 4, 1445930244223, "9ff0b0a2c18a2a"}, + {0b11000000, 5, 4374519874169, "dfda84d0ada87f"}, + {0b01000000, 6, 90263420404, "7fb5fbf9a0d002"}, + {0b00000000, 7, 2542616951118, "7fcff19efeff49"}, + // Eight extension bytes. + {0b10011000, 3, 54263449861016696, "9ff19883aa9f8ab260"}, + {0b10010000, 4, 23690121121119891, "9f84fdf0b0a2c18a2a"}, + {0b11000000, 5, 71672133617889215, "dfa0dfda84d0ada87f"}, + {0b01000000, 6, 1478875878881374, "7f9ff0b5fbf9a0d002"}, + {0b00000000, 7, 41658236125045114, "7ffbc1cff19efeff49"}, + // Ten extension bytes. + {0b10011000, 3, 12832019021693745307u, "9f94f1f19883aa9f8ab201"}, + {0b10010000, 4, 9980690937382242223u, "9fa08f84fdf0b0a2c18a01"}, + {0b11000000, 5, 12131360551794650846u, "dfbfdda0dfda84d0ada801"}, + {0b01000000, 6, 15006530362736632796u, "7f9dc79ff0b5fbf9a0d001"}, + {0b00000000, 7, 18445754019193211014u, "7f8790fbc1cff19efeff01"}, + // Maximum value: 2^64-1. + {0b10011000, 3, 18446744073709551615u, "9ff8ffffffffffffffff01"}, + {0b10010000, 4, 18446744073709551615u, "9ff0ffffffffffffffff01"}, + {0b11000000, 5, 18446744073709551615u, "dfe0ffffffffffffffff01"}, + {0b01000000, 6, 18446744073709551615u, "7fc0ffffffffffffffff01"}, + {0b00000000, 7, 18446744073709551615u, "7f80ffffffffffffffff01"}, + // Example from RFC7541 C.1. + {0b00000000, 5, 1337, "1f9a0a"}, +}; + +// Encode integers that do not fit in the prefix. +TEST(HpackVarintEncoderTest, Long) { + // Test encoding byte by byte, also test encoding in + // a single ResumeEncoding() call. + for (size_t i = 0; i < ABSL_ARRAYSIZE(kLongTestData); ++i) { + std::string expected_encoding = + absl::HexStringToBytes(kLongTestData[i].expected_encoding); + + std::string output; + HpackVarintEncoder::Encode(kLongTestData[i].high_bits, + kLongTestData[i].prefix_length, + kLongTestData[i].value, &output); + + EXPECT_EQ(expected_encoding, output); + } +} + +struct { + uint8_t high_bits; + uint8_t prefix_length; + uint64_t value; + uint8_t expected_encoding_first_byte; +} kLastByteIsZeroTestData[] = { + {0b10110010, 1, 1, 0b10110011}, {0b10101100, 2, 3, 0b10101111}, + {0b10101000, 3, 7, 0b10101111}, {0b10110000, 4, 15, 0b10111111}, + {0b10100000, 5, 31, 0b10111111}, {0b11000000, 6, 63, 0b11111111}, + {0b10000000, 7, 127, 0b11111111}, {0b00000000, 8, 255, 0b11111111}}; + +// Make sure that the encoder outputs the last byte even when it is zero. This +// happens exactly when encoding the value 2^prefix_length - 1. +TEST(HpackVarintEncoderTest, LastByteIsZero) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kLastByteIsZeroTestData); ++i) { + std::string output; + HpackVarintEncoder::Encode(kLastByteIsZeroTestData[i].high_bits, + kLastByteIsZeroTestData[i].prefix_length, + kLastByteIsZeroTestData[i].value, &output); + ASSERT_EQ(2u, output.size()); + EXPECT_EQ(kLastByteIsZeroTestData[i].expected_encoding_first_byte, + static_cast(output[0])); + EXPECT_EQ(0b00000000, output[1]); + } +} + +// Test that encoder appends correctly to non-empty string. +TEST(HpackVarintEncoderTest, Append) { + std::string output("foo"); + EXPECT_EQ(absl::HexStringToBytes("666f6f"), output); + + HpackVarintEncoder::Encode(0b10011000, 3, 103, &output); + EXPECT_EQ(absl::HexStringToBytes("666f6f9f60"), output); + + HpackVarintEncoder::Encode(0b10100000, 5, 8, &output); + EXPECT_EQ(absl::HexStringToBytes("666f6f9f60a8"), output); + + HpackVarintEncoder::Encode(0b10011000, 3, 202147110, &output); + EXPECT_EQ(absl::HexStringToBytes("666f6f9f60a89f9f8ab260"), output); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/hpack/varint/hpack_varint_round_trip_test.cc b/quiche/http2/hpack/varint/hpack_varint_round_trip_test.cc new file mode 100644 index 000000000000..d307ea027182 --- /dev/null +++ b/quiche/http2/hpack/varint/hpack_varint_round_trip_test.cc @@ -0,0 +1,417 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" + +// Test HpackVarintDecoder against data encoded via HpackBlockBuilder, +// which uses HpackVarintEncoder under the hood. + +#include + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_text_utils.h" + +using ::testing::AssertionFailure; +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { +namespace { + +// Returns the highest value with the specified number of extension bytes +// and the specified prefix length (bits). +uint64_t HiValueOfExtensionBytes(uint32_t extension_bytes, + uint32_t prefix_length) { + return (1 << prefix_length) - 2 + + (extension_bytes == 0 ? 0 : (1LLU << (extension_bytes * 7))); +} + +class HpackVarintRoundTripTest : public RandomDecoderTest { + protected: + HpackVarintRoundTripTest() : prefix_length_(0) {} + + DecodeStatus StartDecoding(DecodeBuffer* b) override { + QUICHE_CHECK_LT(0u, b->Remaining()); + uint8_t prefix = b->DecodeUInt8(); + return decoder_.Start(prefix, prefix_length_, b); + } + + DecodeStatus ResumeDecoding(DecodeBuffer* b) override { + return decoder_.Resume(b); + } + + void DecodeSeveralWays(uint64_t expected_value, uint32_t expected_offset) { + // The validator is called after each of the several times that the input + // DecodeBuffer is decoded, each with a different segmentation of the input. + // Validate that decoder_.value() matches the expected value. + Validator validator = [expected_value, this]( + const DecodeBuffer& /*db*/, + DecodeStatus /*status*/) -> AssertionResult { + if (decoder_.value() != expected_value) { + return AssertionFailure() + << "Value doesn't match expected: " << decoder_.value() + << " != " << expected_value; + } + return AssertionSuccess(); + }; + + // First validate that decoding is done and that we've advanced the cursor + // the expected amount. + validator = ValidateDoneAndOffset(expected_offset, validator); + + // StartDecoding, above, requires the DecodeBuffer be non-empty so that it + // can call Start with the prefix byte. + bool return_non_zero_on_first = true; + + DecodeBuffer b(buffer_); + EXPECT_TRUE( + DecodeAndValidateSeveralWays(&b, return_non_zero_on_first, validator)); + + EXPECT_EQ(expected_value, decoder_.value()); + EXPECT_EQ(expected_offset, b.Offset()); + } + + void EncodeNoRandom(uint64_t value, uint8_t prefix_length) { + QUICHE_DCHECK_LE(3, prefix_length); + QUICHE_DCHECK_LE(prefix_length, 8); + prefix_length_ = prefix_length; + + HpackBlockBuilder bb; + bb.AppendHighBitsAndVarint(0, prefix_length_, value); + buffer_ = bb.buffer(); + ASSERT_LT(0u, buffer_.size()); + + const uint8_t prefix_mask = (1 << prefix_length_) - 1; + ASSERT_EQ(static_cast(buffer_[0]), + static_cast(buffer_[0]) & prefix_mask); + } + + void Encode(uint64_t value, uint8_t prefix_length) { + EncodeNoRandom(value, prefix_length); + // Add some random bits to the prefix (the first byte) above the mask. + uint8_t prefix = buffer_[0]; + buffer_[0] = prefix | (Random().Rand8() << prefix_length); + const uint8_t prefix_mask = (1 << prefix_length_) - 1; + ASSERT_EQ(prefix, buffer_[0] & prefix_mask); + } + + // This is really a test of HpackBlockBuilder, making sure that the input to + // HpackVarintDecoder is as expected, which also acts as confirmation that + // my thinking about the encodings being used by the tests, i.e. cover the + // range desired. + void ValidateEncoding(uint64_t value, uint64_t minimum, uint64_t maximum, + size_t expected_bytes) { + ASSERT_EQ(expected_bytes, buffer_.size()); + if (expected_bytes > 1) { + const uint8_t prefix_mask = (1 << prefix_length_) - 1; + EXPECT_EQ(prefix_mask, buffer_[0] & prefix_mask); + size_t last = expected_bytes - 1; + for (size_t ndx = 1; ndx < last; ++ndx) { + // Before the last extension byte, we expect the high-bit set. + uint8_t byte = buffer_[ndx]; + if (value == minimum) { + EXPECT_EQ(0x80, byte) << "ndx=" << ndx; + } else if (value == maximum) { + if (expected_bytes < 11) { + EXPECT_EQ(0xff, byte) << "ndx=" << ndx; + } + } else { + EXPECT_EQ(0x80, byte & 0x80) << "ndx=" << ndx; + } + } + // The last extension byte should not have the high-bit set. + uint8_t byte = buffer_[last]; + if (value == minimum) { + if (expected_bytes == 2) { + EXPECT_EQ(0x00, byte); + } else { + EXPECT_EQ(0x01, byte); + } + } else if (value == maximum) { + if (expected_bytes < 11) { + EXPECT_EQ(0x7f, byte); + } + } else { + EXPECT_EQ(0x00, byte & 0x80); + } + } else { + const uint8_t prefix_mask = (1 << prefix_length_) - 1; + EXPECT_EQ(value, static_cast(buffer_[0] & prefix_mask)); + EXPECT_LT(value, prefix_mask); + } + } + + void EncodeAndDecodeValues(const std::set& values, + uint8_t prefix_length, size_t expected_bytes) { + QUICHE_CHECK(!values.empty()); + const uint64_t minimum = *values.begin(); + const uint64_t maximum = *values.rbegin(); + for (const uint64_t value : values) { + Encode(value, prefix_length); // Sets buffer_. + + std::string msg = absl::StrCat("value=", value, " (0x", absl::Hex(value), + "), prefix_length=", prefix_length, + ", expected_bytes=", expected_bytes, "\n", + quiche::QuicheTextUtils::HexDump(buffer_)); + + if (value == minimum) { + QUICHE_LOG(INFO) << "Checking minimum; " << msg; + } else if (value == maximum) { + QUICHE_LOG(INFO) << "Checking maximum; " << msg; + } + + SCOPED_TRACE(msg); + ValidateEncoding(value, minimum, maximum, expected_bytes); + DecodeSeveralWays(value, expected_bytes); + + // Append some random data to the end of buffer_ and repeat. That random + // data should be ignored. + buffer_.append(Random().RandString(1 + Random().Uniform(10))); + DecodeSeveralWays(value, expected_bytes); + + // If possible, add extension bytes that don't change the value. + if (1 < expected_bytes) { + buffer_.resize(expected_bytes); + for (uint8_t total_bytes = expected_bytes + 1; total_bytes <= 6; + ++total_bytes) { + // Mark the current last byte as not being the last one. + EXPECT_EQ(0x00, 0x80 & buffer_.back()); + buffer_.back() |= 0x80; + buffer_.push_back('\0'); + DecodeSeveralWays(value, total_bytes); + } + } + } + } + + // Encode values (all or some of it) in [start, start+range). Check + // that |start| is the smallest value and |start+range-1| is the largest value + // corresponding to |expected_bytes|, except if |expected_bytes| is maximal. + void EncodeAndDecodeValuesInRange(uint64_t start, uint64_t range, + uint8_t prefix_length, + size_t expected_bytes) { + const uint8_t prefix_mask = (1 << prefix_length) - 1; + const uint64_t beyond = start + range; + + QUICHE_LOG(INFO) + << "############################################################"; + QUICHE_LOG(INFO) << "prefix_length=" << static_cast(prefix_length); + QUICHE_LOG(INFO) << "prefix_mask=" << std::hex + << static_cast(prefix_mask); + QUICHE_LOG(INFO) << "start=" << start << " (" << std::hex << start << ")"; + QUICHE_LOG(INFO) << "range=" << range << " (" << std::hex << range << ")"; + QUICHE_LOG(INFO) << "beyond=" << beyond << " (" << std::hex << beyond + << ")"; + QUICHE_LOG(INFO) << "expected_bytes=" << expected_bytes; + + if (expected_bytes < 11) { + // Confirm the claim that beyond requires more bytes. + Encode(beyond, prefix_length); + EXPECT_EQ(expected_bytes + 1, buffer_.size()) + << quiche::QuicheTextUtils::HexDump(buffer_); + } + + std::set values; + if (range < 200) { + // Select all values in the range. + for (uint64_t offset = 0; offset < range; ++offset) { + values.insert(start + offset); + } + } else { + // Select some values in this range, including the minimum and maximum + // values that require exactly |expected_bytes| extension bytes. + values.insert({start, start + 1, beyond - 2, beyond - 1}); + while (values.size() < 100) { + values.insert(Random().UniformInRange(start, beyond - 1)); + } + } + + EncodeAndDecodeValues(values, prefix_length, expected_bytes); + } + + HpackVarintDecoder decoder_; + std::string buffer_; + uint8_t prefix_length_; +}; + +// To help me and future debuggers of varint encodings, this HTTP2_LOGs out the +// transition points where a new extension byte is added. +TEST_F(HpackVarintRoundTripTest, Encode) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t a = HiValueOfExtensionBytes(0, prefix_length); + const uint64_t b = HiValueOfExtensionBytes(1, prefix_length); + const uint64_t c = HiValueOfExtensionBytes(2, prefix_length); + const uint64_t d = HiValueOfExtensionBytes(3, prefix_length); + const uint64_t e = HiValueOfExtensionBytes(4, prefix_length); + const uint64_t f = HiValueOfExtensionBytes(5, prefix_length); + const uint64_t g = HiValueOfExtensionBytes(6, prefix_length); + const uint64_t h = HiValueOfExtensionBytes(7, prefix_length); + const uint64_t i = HiValueOfExtensionBytes(8, prefix_length); + const uint64_t j = HiValueOfExtensionBytes(9, prefix_length); + + QUICHE_LOG(INFO) + << "############################################################"; + QUICHE_LOG(INFO) << "prefix_length=" << prefix_length << " a=" << a + << " b=" << b << " c=" << c << " d=" << d + << " e=" << e << " f=" << f << " g=" << g + << " h=" << h << " i=" << i << " j=" << j; + + std::vector values = { + 0, 1, // Force line break. + a - 1, a, a + 1, a + 2, a + 3, // Force line break. + b - 1, b, b + 1, b + 2, b + 3, // Force line break. + c - 1, c, c + 1, c + 2, c + 3, // Force line break. + d - 1, d, d + 1, d + 2, d + 3, // Force line break. + e - 1, e, e + 1, e + 2, e + 3, // Force line break. + f - 1, f, f + 1, f + 2, f + 3, // Force line break. + g - 1, g, g + 1, g + 2, g + 3, // Force line break. + h - 1, h, h + 1, h + 2, h + 3, // Force line break. + i - 1, i, i + 1, i + 2, i + 3, // Force line break. + j - 1, j, j + 1, j + 2, j + 3, // Force line break. + }; + + for (uint64_t value : values) { + EncodeNoRandom(value, prefix_length); + std::string dump = quiche::QuicheTextUtils::HexDump(buffer_); + QUICHE_LOG(INFO) << absl::StrFormat("%10llu %0#18x ", value, value) + << quiche::QuicheTextUtils::HexDump(buffer_).substr(7); + } + } +} + +TEST_F(HpackVarintRoundTripTest, FromSpec1337) { + DecodeBuffer b(absl::string_view("\x1f\x9a\x0a")); + uint32_t prefix_length = 5; + uint8_t p = b.DecodeUInt8(); + EXPECT_EQ(1u, b.Offset()); + EXPECT_EQ(DecodeStatus::kDecodeDone, decoder_.Start(p, prefix_length, &b)); + EXPECT_EQ(3u, b.Offset()); + EXPECT_EQ(1337u, decoder_.value()); + + EncodeNoRandom(1337, prefix_length); + EXPECT_EQ(3u, buffer_.size()); + EXPECT_EQ('\x1f', buffer_[0]); + EXPECT_EQ('\x9a', buffer_[1]); + EXPECT_EQ('\x0a', buffer_[2]); +} + +// Test all the values that fit into the prefix (one less than the mask). +TEST_F(HpackVarintRoundTripTest, ValidatePrefixOnly) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint8_t prefix_mask = (1 << prefix_length) - 1; + EncodeAndDecodeValuesInRange(0, prefix_mask, prefix_length, 1); + } +} + +// Test all values that require exactly 1 extension byte. +TEST_F(HpackVarintRoundTripTest, ValidateOneExtensionByte) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(0, prefix_length) + 1; + EncodeAndDecodeValuesInRange(start, 128, prefix_length, 2); + } +} + +// Test *some* values that require exactly 2 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateTwoExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(1, prefix_length) + 1; + const uint64_t range = 127 << 7; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 3); + } +} + +// Test *some* values that require 3 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateThreeExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(2, prefix_length) + 1; + const uint64_t range = 127 << 14; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 4); + } +} + +// Test *some* values that require 4 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateFourExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(3, prefix_length) + 1; + const uint64_t range = 127 << 21; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 5); + } +} + +// Test *some* values that require 5 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateFiveExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(4, prefix_length) + 1; + const uint64_t range = 127llu << 28; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 6); + } +} + +// Test *some* values that require 6 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateSixExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(5, prefix_length) + 1; + const uint64_t range = 127llu << 35; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 7); + } +} + +// Test *some* values that require 7 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateSevenExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(6, prefix_length) + 1; + const uint64_t range = 127llu << 42; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 8); + } +} + +// Test *some* values that require 8 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateEightExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(7, prefix_length) + 1; + const uint64_t range = 127llu << 49; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 9); + } +} + +// Test *some* values that require 9 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateNineExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(8, prefix_length) + 1; + const uint64_t range = 127llu << 56; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 10); + } +} + +// Test *some* values that require 10 extension bytes. +TEST_F(HpackVarintRoundTripTest, ValidateTenExtensionBytes) { + for (int prefix_length = 3; prefix_length <= 8; ++prefix_length) { + const uint64_t start = HiValueOfExtensionBytes(9, prefix_length) + 1; + const uint64_t range = std::numeric_limits::max() - start; + + EncodeAndDecodeValuesInRange(start, range, prefix_length, 11); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/http2_constants.cc b/quiche/http2/http2_constants.cc new file mode 100644 index 000000000000..a4b6105bd5aa --- /dev/null +++ b/quiche/http2/http2_constants.cc @@ -0,0 +1,181 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/http2_constants.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/platform/api/quiche_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +std::string Http2FrameTypeToString(Http2FrameType v) { + switch (v) { + case Http2FrameType::DATA: + return "DATA"; + case Http2FrameType::HEADERS: + return "HEADERS"; + case Http2FrameType::PRIORITY: + return "PRIORITY"; + case Http2FrameType::RST_STREAM: + return "RST_STREAM"; + case Http2FrameType::SETTINGS: + return "SETTINGS"; + case Http2FrameType::PUSH_PROMISE: + return "PUSH_PROMISE"; + case Http2FrameType::PING: + return "PING"; + case Http2FrameType::GOAWAY: + return "GOAWAY"; + case Http2FrameType::WINDOW_UPDATE: + return "WINDOW_UPDATE"; + case Http2FrameType::CONTINUATION: + return "CONTINUATION"; + case Http2FrameType::ALTSVC: + return "ALTSVC"; + case Http2FrameType::PRIORITY_UPDATE: + return "PRIORITY_UPDATE"; + } + return absl::StrCat("UnknownFrameType(", static_cast(v), ")"); +} + +std::string Http2FrameTypeToString(uint8_t v) { + return Http2FrameTypeToString(static_cast(v)); +} + +std::string Http2FrameFlagsToString(Http2FrameType type, uint8_t flags) { + std::string s; + // Closure to append flag name |v| to the std::string |s|, + // and to clear |bit| from |flags|. + auto append_and_clear = [&s, &flags](absl::string_view v, uint8_t bit) { + if (!s.empty()) { + s.push_back('|'); + } + absl::StrAppend(&s, v); + flags ^= bit; + }; + if (flags & 0x01) { + if (type == Http2FrameType::DATA || type == Http2FrameType::HEADERS) { + append_and_clear("END_STREAM", Http2FrameFlag::END_STREAM); + } else if (type == Http2FrameType::SETTINGS || + type == Http2FrameType::PING) { + append_and_clear("ACK", Http2FrameFlag::ACK); + } + } + if (flags & 0x04) { + if (type == Http2FrameType::HEADERS || + type == Http2FrameType::PUSH_PROMISE || + type == Http2FrameType::CONTINUATION) { + append_and_clear("END_HEADERS", Http2FrameFlag::END_HEADERS); + } + } + if (flags & 0x08) { + if (type == Http2FrameType::DATA || type == Http2FrameType::HEADERS || + type == Http2FrameType::PUSH_PROMISE) { + append_and_clear("PADDED", Http2FrameFlag::PADDED); + } + } + if (flags & 0x20) { + if (type == Http2FrameType::HEADERS) { + append_and_clear("PRIORITY", Http2FrameFlag::PRIORITY); + } + } + if (flags != 0) { + append_and_clear(absl::StrFormat("0x%02x", flags), flags); + } + QUICHE_DCHECK_EQ(0, flags); + return s; +} +std::string Http2FrameFlagsToString(uint8_t type, uint8_t flags) { + return Http2FrameFlagsToString(static_cast(type), flags); +} + +std::string Http2ErrorCodeToString(uint32_t v) { + switch (v) { + case 0x0: + return "NO_ERROR"; + case 0x1: + return "PROTOCOL_ERROR"; + case 0x2: + return "INTERNAL_ERROR"; + case 0x3: + return "FLOW_CONTROL_ERROR"; + case 0x4: + return "SETTINGS_TIMEOUT"; + case 0x5: + return "STREAM_CLOSED"; + case 0x6: + return "FRAME_SIZE_ERROR"; + case 0x7: + return "REFUSED_STREAM"; + case 0x8: + return "CANCEL"; + case 0x9: + return "COMPRESSION_ERROR"; + case 0xa: + return "CONNECT_ERROR"; + case 0xb: + return "ENHANCE_YOUR_CALM"; + case 0xc: + return "INADEQUATE_SECURITY"; + case 0xd: + return "HTTP_1_1_REQUIRED"; + } + return absl::StrCat("UnknownErrorCode(0x", absl::Hex(v), ")"); +} +std::string Http2ErrorCodeToString(Http2ErrorCode v) { + return Http2ErrorCodeToString(static_cast(v)); +} + +std::string Http2SettingsParameterToString(uint32_t v) { + switch (v) { + case 0x1: + return "HEADER_TABLE_SIZE"; + case 0x2: + return "ENABLE_PUSH"; + case 0x3: + return "MAX_CONCURRENT_STREAMS"; + case 0x4: + return "INITIAL_WINDOW_SIZE"; + case 0x5: + return "MAX_FRAME_SIZE"; + case 0x6: + return "MAX_HEADER_LIST_SIZE"; + } + return absl::StrCat("UnknownSettingsParameter(0x", absl::Hex(v), ")"); +} +std::string Http2SettingsParameterToString(Http2SettingsParameter v) { + return Http2SettingsParameterToString(static_cast(v)); +} + +// Invalid HTTP/2 header names according to +// https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2. +// TODO(b/78024822): Consider adding "upgrade" to this set. +constexpr char const* kHttp2InvalidHeaderNames[] = { + "connection", "host", "keep-alive", "proxy-connection", + "transfer-encoding", "", +}; + +constexpr char const* kHttp2InvalidHeaderNamesOld[] = { + "connection", "host", "keep-alive", "proxy-connection", "transfer-encoding", +}; + +const InvalidHeaderSet& GetInvalidHttp2HeaderSet() { + if (!GetQuicheReloadableFlag(quic, quic_verify_request_headers_2)) { + static const auto* invalid_header_set_old = + new InvalidHeaderSet(std::begin(http2::kHttp2InvalidHeaderNamesOld), + std::end(http2::kHttp2InvalidHeaderNamesOld)); + return *invalid_header_set_old; + } + QUICHE_RELOADABLE_FLAG_COUNT_N(quic_verify_request_headers_2, 3, 3); + static const auto* invalid_header_set = + new InvalidHeaderSet(std::begin(http2::kHttp2InvalidHeaderNames), + std::end(http2::kHttp2InvalidHeaderNames)); + return *invalid_header_set; +} + +} // namespace http2 diff --git a/quiche/http2/http2_constants.h b/quiche/http2/http2_constants.h new file mode 100644 index 000000000000..f4034bda485f --- /dev/null +++ b/quiche/http2/http2_constants.h @@ -0,0 +1,270 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HTTP2_CONSTANTS_H_ +#define QUICHE_HTTP2_HTTP2_CONSTANTS_H_ + +// Constants from the HTTP/2 spec, RFC 7540, and associated helper functions. + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_text_utils.h" + +namespace http2 { + +// TODO(jamessynge): create http2_simple_types for types similar to +// SpdyStreamId, but not for structures like Http2FrameHeader. Then will be +// able to move these stream id functions there. +constexpr uint32_t UInt31Mask() { return 0x7fffffff; } +constexpr uint32_t StreamIdMask() { return UInt31Mask(); } + +// The value used to identify types of frames. Upper case to match the RFC. +// The comments indicate which flags are valid for that frame type. +enum class Http2FrameType : uint8_t { + DATA = 0, // END_STREAM | PADDED + HEADERS = 1, // END_STREAM | END_HEADERS | PADDED | PRIORITY + PRIORITY = 2, // + RST_STREAM = 3, // + SETTINGS = 4, // ACK + PUSH_PROMISE = 5, // END_HEADERS | PADDED + PING = 6, // ACK + GOAWAY = 7, // + WINDOW_UPDATE = 8, // + CONTINUATION = 9, // END_HEADERS + // https://tools.ietf.org/html/rfc7838 + ALTSVC = 10, // no flags + // https://tools.ietf.org/html/draft-ietf-httpbis-priority-02 + PRIORITY_UPDATE = 16, // no flags +}; + +// Is the frame type known/supported? +inline bool IsSupportedHttp2FrameType(uint32_t v) { + return v <= static_cast(Http2FrameType::ALTSVC) || + v == static_cast(Http2FrameType::PRIORITY_UPDATE); +} +inline bool IsSupportedHttp2FrameType(Http2FrameType v) { + return IsSupportedHttp2FrameType(static_cast(v)); +} + +// The return type is 'std::string' so that they can generate a unique string +// for each unsupported value. Since these are just used for debugging/error +// messages, that isn't a cost to we need to worry about. The same applies to +// the functions later in this file. +QUICHE_EXPORT std::string Http2FrameTypeToString(Http2FrameType v); +QUICHE_EXPORT std::string Http2FrameTypeToString(uint8_t v); +QUICHE_EXPORT inline std::ostream& operator<<(std::ostream& out, + Http2FrameType v) { + return out << Http2FrameTypeToString(v); +} + +// Flags that appear in supported frame types. These are treated as bit masks. +// The comments indicate for which frame types the flag is valid. +enum Http2FrameFlag : uint8_t { + END_STREAM = 0x01, // DATA, HEADERS + ACK = 0x01, // SETTINGS, PING + END_HEADERS = 0x04, // HEADERS, PUSH_PROMISE, CONTINUATION + PADDED = 0x08, // DATA, HEADERS, PUSH_PROMISE + PRIORITY = 0x20, // HEADERS +}; + +// Formats zero or more flags for the specified type of frame. Returns an +// empty string if flags==0. +QUICHE_EXPORT std::string Http2FrameFlagsToString(Http2FrameType type, + uint8_t flags); +QUICHE_EXPORT std::string Http2FrameFlagsToString(uint8_t type, uint8_t flags); + +// Error codes for GOAWAY and RST_STREAM frames. +enum class Http2ErrorCode : uint32_t { + // The associated condition is not a result of an error. For example, a GOAWAY + // might include this code to indicate graceful shutdown of a connection. + HTTP2_NO_ERROR = 0x0, + + // The endpoint detected an unspecific protocol error. This error is for use + // when a more specific error code is not available. + PROTOCOL_ERROR = 0x1, + + // The endpoint encountered an unexpected internal error. + INTERNAL_ERROR = 0x2, + + // The endpoint detected that its peer violated the flow-control protocol. + FLOW_CONTROL_ERROR = 0x3, + + // The endpoint sent a SETTINGS frame but did not receive a response in a + // timely manner. See Section 6.5.3 ("Settings Synchronization"). + SETTINGS_TIMEOUT = 0x4, + + // The endpoint received a frame after a stream was half-closed. + STREAM_CLOSED = 0x5, + + // The endpoint received a frame with an invalid size. + FRAME_SIZE_ERROR = 0x6, + + // The endpoint refused the stream prior to performing any application + // processing (see Section 8.1.4 for details). + REFUSED_STREAM = 0x7, + + // Used by the endpoint to indicate that the stream is no longer needed. + CANCEL = 0x8, + + // The endpoint is unable to maintain the header compression context + // for the connection. + COMPRESSION_ERROR = 0x9, + + // The connection established in response to a CONNECT request (Section 8.3) + // was reset or abnormally closed. + CONNECT_ERROR = 0xa, + + // The endpoint detected that its peer is exhibiting a behavior that might + // be generating excessive load. + ENHANCE_YOUR_CALM = 0xb, + + // The underlying transport has properties that do not meet minimum + // security requirements (see Section 9.2). + INADEQUATE_SECURITY = 0xc, + + // The endpoint requires that HTTP/1.1 be used instead of HTTP/2. + HTTP_1_1_REQUIRED = 0xd, +}; + +// Is the error code supported? (So far that means it is in RFC 7540.) +inline bool IsSupportedHttp2ErrorCode(uint32_t v) { + return v <= static_cast(Http2ErrorCode::HTTP_1_1_REQUIRED); +} +inline bool IsSupportedHttp2ErrorCode(Http2ErrorCode v) { + return IsSupportedHttp2ErrorCode(static_cast(v)); +} + +// Format the specified error code. +QUICHE_EXPORT std::string Http2ErrorCodeToString(uint32_t v); +QUICHE_EXPORT std::string Http2ErrorCodeToString(Http2ErrorCode v); +QUICHE_EXPORT inline std::ostream& operator<<(std::ostream& out, + Http2ErrorCode v) { + return out << Http2ErrorCodeToString(v); +} + +// Supported parameters in SETTINGS frames; so far just those in RFC 7540. +enum class Http2SettingsParameter : uint16_t { + // Allows the sender to inform the remote endpoint of the maximum size of the + // header compression table used to decode header blocks, in octets. The + // encoder can select any size equal to or less than this value by using + // signaling specific to the header compression format inside a header block + // (see [COMPRESSION]). The initial value is 4,096 octets. + HEADER_TABLE_SIZE = 0x1, + + // This setting can be used to disable server push (Section 8.2). An endpoint + // MUST NOT send a PUSH_PROMISE frame if it receives this parameter set to a + // value of 0. An endpoint that has both set this parameter to 0 and had it + // acknowledged MUST treat the receipt of a PUSH_PROMISE frame as a connection + // error (Section 5.4.1) of type PROTOCOL_ERROR. + // + // The initial value is 1, which indicates that server push is permitted. Any + // value other than 0 or 1 MUST be treated as a connection error (Section + // 5.4.1) of type PROTOCOL_ERROR. + ENABLE_PUSH = 0x2, + + // Indicates the maximum number of concurrent streams that the sender will + // allow. This limit is directional: it applies to the number of streams that + // the sender permits the receiver to create. Initially, there is no limit to + // this value. It is recommended that this value be no smaller than 100, so as + // to not unnecessarily limit parallelism. + // + // A value of 0 for MAX_CONCURRENT_STREAMS SHOULD NOT be treated as + // special by endpoints. A zero value does prevent the creation of new + // streams; however, this can also happen for any limit that is exhausted with + // active streams. Servers SHOULD only set a zero value for short durations; + // if a server does not wish to accept requests, closing the connection is + // more appropriate. + MAX_CONCURRENT_STREAMS = 0x3, + + // Indicates the sender's initial window size (in octets) for stream-level + // flow control. The initial value is 2^16-1 (65,535) octets. + // + // This setting affects the window size of all streams (see Section 6.9.2). + // + // Values above the maximum flow-control window size of 2^31-1 MUST be treated + // as a connection error (Section 5.4.1) of type FLOW_CONTROL_ERROR. + INITIAL_WINDOW_SIZE = 0x4, + + // Indicates the size of the largest frame payload that the sender is willing + // to receive, in octets. + // + // The initial value is 2^14 (16,384) octets. The value advertised by an + // endpoint MUST be between this initial value and the maximum allowed frame + // size (2^24-1 or 16,777,215 octets), inclusive. Values outside this range + // MUST be treated as a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + MAX_FRAME_SIZE = 0x5, + + // This advisory setting informs a peer of the maximum size of header list + // that the sender is prepared to accept, in octets. The value is based on the + // uncompressed size of header fields, including the length of the name and + // value in octets plus an overhead of 32 octets for each header field. + // + // For any given request, a lower limit than what is advertised MAY be + // enforced. The initial value of this setting is unlimited. + MAX_HEADER_LIST_SIZE = 0x6, +}; + +// Is the settings parameter supported (so far that means it is in RFC 7540)? +inline bool IsSupportedHttp2SettingsParameter(uint32_t v) { + return 0 < v && v <= static_cast( + Http2SettingsParameter::MAX_HEADER_LIST_SIZE); +} +inline bool IsSupportedHttp2SettingsParameter(Http2SettingsParameter v) { + return IsSupportedHttp2SettingsParameter(static_cast(v)); +} + +// Format the specified settings parameter. +QUICHE_EXPORT std::string Http2SettingsParameterToString(uint32_t v); +QUICHE_EXPORT std::string Http2SettingsParameterToString( + Http2SettingsParameter v); +inline std::ostream& operator<<(std::ostream& out, Http2SettingsParameter v) { + return out << Http2SettingsParameterToString(v); +} + +// Information about the initial, minimum and maximum value of settings (not +// applicable to all settings parameters). +class QUICHE_EXPORT Http2SettingsInfo { + public: + // Default value for HEADER_TABLE_SIZE. + static constexpr uint32_t DefaultHeaderTableSize() { return 4096; } + + // Default value for ENABLE_PUSH. + static constexpr bool DefaultEnablePush() { return true; } + + // Default value for INITIAL_WINDOW_SIZE. + static constexpr uint32_t DefaultInitialWindowSize() { return 65535; } + + // Maximum value for INITIAL_WINDOW_SIZE, and for the connection flow control + // window, and for each stream flow control window. + static constexpr uint32_t MaximumWindowSize() { return UInt31Mask(); } + + // Default value for MAX_FRAME_SIZE. + static constexpr uint32_t DefaultMaxFrameSize() { return 16384; } + + // Minimum value for MAX_FRAME_SIZE. + static constexpr uint32_t MinimumMaxFrameSize() { return 16384; } + + // Maximum value for MAX_FRAME_SIZE. + static constexpr uint32_t MaximumMaxFrameSize() { return (1 << 24) - 1; } +}; + +// Http3 early fails upper case request headers, but Http2 still needs case +// insensitive comparison. +using InvalidHeaderSet = + absl::flat_hash_set; + +// Returns all disallowed HTTP/2 headers. +QUICHE_EXPORT const InvalidHeaderSet& GetInvalidHttp2HeaderSet(); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HTTP2_CONSTANTS_H_ diff --git a/quiche/http2/http2_constants_test.cc b/quiche/http2/http2_constants_test.cc new file mode 100644 index 000000000000..8478ae6cd2e9 --- /dev/null +++ b/quiche/http2/http2_constants_test.cc @@ -0,0 +1,271 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/http2_constants.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +class Http2ConstantsTest : public quiche::test::QuicheTest {}; + +TEST(Http2ConstantsTest, Http2FrameType) { + EXPECT_EQ(Http2FrameType::DATA, static_cast(0)); + EXPECT_EQ(Http2FrameType::HEADERS, static_cast(1)); + EXPECT_EQ(Http2FrameType::PRIORITY, static_cast(2)); + EXPECT_EQ(Http2FrameType::RST_STREAM, static_cast(3)); + EXPECT_EQ(Http2FrameType::SETTINGS, static_cast(4)); + EXPECT_EQ(Http2FrameType::PUSH_PROMISE, static_cast(5)); + EXPECT_EQ(Http2FrameType::PING, static_cast(6)); + EXPECT_EQ(Http2FrameType::GOAWAY, static_cast(7)); + EXPECT_EQ(Http2FrameType::WINDOW_UPDATE, static_cast(8)); + EXPECT_EQ(Http2FrameType::CONTINUATION, static_cast(9)); + EXPECT_EQ(Http2FrameType::ALTSVC, static_cast(10)); +} + +TEST(Http2ConstantsTest, Http2FrameTypeToString) { + EXPECT_EQ("DATA", Http2FrameTypeToString(Http2FrameType::DATA)); + EXPECT_EQ("HEADERS", Http2FrameTypeToString(Http2FrameType::HEADERS)); + EXPECT_EQ("PRIORITY", Http2FrameTypeToString(Http2FrameType::PRIORITY)); + EXPECT_EQ("RST_STREAM", Http2FrameTypeToString(Http2FrameType::RST_STREAM)); + EXPECT_EQ("SETTINGS", Http2FrameTypeToString(Http2FrameType::SETTINGS)); + EXPECT_EQ("PUSH_PROMISE", + Http2FrameTypeToString(Http2FrameType::PUSH_PROMISE)); + EXPECT_EQ("PING", Http2FrameTypeToString(Http2FrameType::PING)); + EXPECT_EQ("GOAWAY", Http2FrameTypeToString(Http2FrameType::GOAWAY)); + EXPECT_EQ("WINDOW_UPDATE", + Http2FrameTypeToString(Http2FrameType::WINDOW_UPDATE)); + EXPECT_EQ("CONTINUATION", + Http2FrameTypeToString(Http2FrameType::CONTINUATION)); + EXPECT_EQ("ALTSVC", Http2FrameTypeToString(Http2FrameType::ALTSVC)); + + EXPECT_EQ("DATA", Http2FrameTypeToString(0)); + EXPECT_EQ("HEADERS", Http2FrameTypeToString(1)); + EXPECT_EQ("PRIORITY", Http2FrameTypeToString(2)); + EXPECT_EQ("RST_STREAM", Http2FrameTypeToString(3)); + EXPECT_EQ("SETTINGS", Http2FrameTypeToString(4)); + EXPECT_EQ("PUSH_PROMISE", Http2FrameTypeToString(5)); + EXPECT_EQ("PING", Http2FrameTypeToString(6)); + EXPECT_EQ("GOAWAY", Http2FrameTypeToString(7)); + EXPECT_EQ("WINDOW_UPDATE", Http2FrameTypeToString(8)); + EXPECT_EQ("CONTINUATION", Http2FrameTypeToString(9)); + EXPECT_EQ("ALTSVC", Http2FrameTypeToString(10)); + + EXPECT_EQ("UnknownFrameType(99)", Http2FrameTypeToString(99)); +} + +TEST(Http2ConstantsTest, Http2FrameFlag) { + EXPECT_EQ(Http2FrameFlag::END_STREAM, static_cast(0x01)); + EXPECT_EQ(Http2FrameFlag::ACK, static_cast(0x01)); + EXPECT_EQ(Http2FrameFlag::END_HEADERS, static_cast(0x04)); + EXPECT_EQ(Http2FrameFlag::PADDED, static_cast(0x08)); + EXPECT_EQ(Http2FrameFlag::PRIORITY, static_cast(0x20)); + + EXPECT_EQ(Http2FrameFlag::END_STREAM, 0x01); + EXPECT_EQ(Http2FrameFlag::ACK, 0x01); + EXPECT_EQ(Http2FrameFlag::END_HEADERS, 0x04); + EXPECT_EQ(Http2FrameFlag::PADDED, 0x08); + EXPECT_EQ(Http2FrameFlag::PRIORITY, 0x20); +} + +TEST(Http2ConstantsTest, Http2FrameFlagsToString) { + // Single flags... + + // 0b00000001 + EXPECT_EQ("END_STREAM", Http2FrameFlagsToString(Http2FrameType::DATA, + Http2FrameFlag::END_STREAM)); + EXPECT_EQ("END_STREAM", + Http2FrameFlagsToString(Http2FrameType::HEADERS, 0x01)); + EXPECT_EQ("ACK", Http2FrameFlagsToString(Http2FrameType::SETTINGS, + Http2FrameFlag::ACK)); + EXPECT_EQ("ACK", Http2FrameFlagsToString(Http2FrameType::PING, 0x01)); + + // 0b00000010 + EXPECT_EQ("0x02", Http2FrameFlagsToString(0xff, 0x02)); + + // 0b00000100 + EXPECT_EQ("END_HEADERS", + Http2FrameFlagsToString(Http2FrameType::HEADERS, + Http2FrameFlag::END_HEADERS)); + EXPECT_EQ("END_HEADERS", + Http2FrameFlagsToString(Http2FrameType::PUSH_PROMISE, 0x04)); + EXPECT_EQ("END_HEADERS", Http2FrameFlagsToString(0x09, 0x04)); + EXPECT_EQ("0x04", Http2FrameFlagsToString(0xff, 0x04)); + + // 0b00001000 + EXPECT_EQ("PADDED", Http2FrameFlagsToString(Http2FrameType::DATA, + Http2FrameFlag::PADDED)); + EXPECT_EQ("PADDED", Http2FrameFlagsToString(Http2FrameType::HEADERS, 0x08)); + EXPECT_EQ("PADDED", Http2FrameFlagsToString(0x05, 0x08)); + EXPECT_EQ("0x08", Http2FrameFlagsToString(0xff, Http2FrameFlag::PADDED)); + + // 0b00010000 + EXPECT_EQ("0x10", Http2FrameFlagsToString(Http2FrameType::SETTINGS, 0x10)); + + // 0b00100000 + EXPECT_EQ("PRIORITY", Http2FrameFlagsToString(Http2FrameType::HEADERS, 0x20)); + EXPECT_EQ("0x20", + Http2FrameFlagsToString(Http2FrameType::PUSH_PROMISE, 0x20)); + + // 0b01000000 + EXPECT_EQ("0x40", Http2FrameFlagsToString(0xff, 0x40)); + + // 0b10000000 + EXPECT_EQ("0x80", Http2FrameFlagsToString(0xff, 0x80)); + + // Combined flags... + + EXPECT_EQ("END_STREAM|PADDED|0xf6", + Http2FrameFlagsToString(Http2FrameType::DATA, 0xff)); + EXPECT_EQ("END_STREAM|END_HEADERS|PADDED|PRIORITY|0xd2", + Http2FrameFlagsToString(Http2FrameType::HEADERS, 0xff)); + EXPECT_EQ("0xff", Http2FrameFlagsToString(Http2FrameType::PRIORITY, 0xff)); + EXPECT_EQ("0xff", Http2FrameFlagsToString(Http2FrameType::RST_STREAM, 0xff)); + EXPECT_EQ("ACK|0xfe", + Http2FrameFlagsToString(Http2FrameType::SETTINGS, 0xff)); + EXPECT_EQ("END_HEADERS|PADDED|0xf3", + Http2FrameFlagsToString(Http2FrameType::PUSH_PROMISE, 0xff)); + EXPECT_EQ("ACK|0xfe", Http2FrameFlagsToString(Http2FrameType::PING, 0xff)); + EXPECT_EQ("0xff", Http2FrameFlagsToString(Http2FrameType::GOAWAY, 0xff)); + EXPECT_EQ("0xff", + Http2FrameFlagsToString(Http2FrameType::WINDOW_UPDATE, 0xff)); + EXPECT_EQ("END_HEADERS|0xfb", + Http2FrameFlagsToString(Http2FrameType::CONTINUATION, 0xff)); + EXPECT_EQ("0xff", Http2FrameFlagsToString(Http2FrameType::ALTSVC, 0xff)); + EXPECT_EQ("0xff", Http2FrameFlagsToString(0xff, 0xff)); +} + +TEST(Http2ConstantsTest, Http2ErrorCode) { + EXPECT_EQ(Http2ErrorCode::HTTP2_NO_ERROR, static_cast(0x0)); + EXPECT_EQ(Http2ErrorCode::PROTOCOL_ERROR, static_cast(0x1)); + EXPECT_EQ(Http2ErrorCode::INTERNAL_ERROR, static_cast(0x2)); + EXPECT_EQ(Http2ErrorCode::FLOW_CONTROL_ERROR, + static_cast(0x3)); + EXPECT_EQ(Http2ErrorCode::SETTINGS_TIMEOUT, static_cast(0x4)); + EXPECT_EQ(Http2ErrorCode::STREAM_CLOSED, static_cast(0x5)); + EXPECT_EQ(Http2ErrorCode::FRAME_SIZE_ERROR, static_cast(0x6)); + EXPECT_EQ(Http2ErrorCode::REFUSED_STREAM, static_cast(0x7)); + EXPECT_EQ(Http2ErrorCode::CANCEL, static_cast(0x8)); + EXPECT_EQ(Http2ErrorCode::COMPRESSION_ERROR, + static_cast(0x9)); + EXPECT_EQ(Http2ErrorCode::CONNECT_ERROR, static_cast(0xa)); + EXPECT_EQ(Http2ErrorCode::ENHANCE_YOUR_CALM, + static_cast(0xb)); + EXPECT_EQ(Http2ErrorCode::INADEQUATE_SECURITY, + static_cast(0xc)); + EXPECT_EQ(Http2ErrorCode::HTTP_1_1_REQUIRED, + static_cast(0xd)); +} + +TEST(Http2ConstantsTest, Http2ErrorCodeToString) { + EXPECT_EQ("NO_ERROR", Http2ErrorCodeToString(Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_EQ("NO_ERROR", Http2ErrorCodeToString(0x0)); + EXPECT_EQ("PROTOCOL_ERROR", + Http2ErrorCodeToString(Http2ErrorCode::PROTOCOL_ERROR)); + EXPECT_EQ("PROTOCOL_ERROR", Http2ErrorCodeToString(0x1)); + EXPECT_EQ("INTERNAL_ERROR", + Http2ErrorCodeToString(Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_EQ("INTERNAL_ERROR", Http2ErrorCodeToString(0x2)); + EXPECT_EQ("FLOW_CONTROL_ERROR", + Http2ErrorCodeToString(Http2ErrorCode::FLOW_CONTROL_ERROR)); + EXPECT_EQ("FLOW_CONTROL_ERROR", Http2ErrorCodeToString(0x3)); + EXPECT_EQ("SETTINGS_TIMEOUT", + Http2ErrorCodeToString(Http2ErrorCode::SETTINGS_TIMEOUT)); + EXPECT_EQ("SETTINGS_TIMEOUT", Http2ErrorCodeToString(0x4)); + EXPECT_EQ("STREAM_CLOSED", + Http2ErrorCodeToString(Http2ErrorCode::STREAM_CLOSED)); + EXPECT_EQ("STREAM_CLOSED", Http2ErrorCodeToString(0x5)); + EXPECT_EQ("FRAME_SIZE_ERROR", + Http2ErrorCodeToString(Http2ErrorCode::FRAME_SIZE_ERROR)); + EXPECT_EQ("FRAME_SIZE_ERROR", Http2ErrorCodeToString(0x6)); + EXPECT_EQ("REFUSED_STREAM", + Http2ErrorCodeToString(Http2ErrorCode::REFUSED_STREAM)); + EXPECT_EQ("REFUSED_STREAM", Http2ErrorCodeToString(0x7)); + EXPECT_EQ("CANCEL", Http2ErrorCodeToString(Http2ErrorCode::CANCEL)); + EXPECT_EQ("CANCEL", Http2ErrorCodeToString(0x8)); + EXPECT_EQ("COMPRESSION_ERROR", + Http2ErrorCodeToString(Http2ErrorCode::COMPRESSION_ERROR)); + EXPECT_EQ("COMPRESSION_ERROR", Http2ErrorCodeToString(0x9)); + EXPECT_EQ("CONNECT_ERROR", + Http2ErrorCodeToString(Http2ErrorCode::CONNECT_ERROR)); + EXPECT_EQ("CONNECT_ERROR", Http2ErrorCodeToString(0xa)); + EXPECT_EQ("ENHANCE_YOUR_CALM", + Http2ErrorCodeToString(Http2ErrorCode::ENHANCE_YOUR_CALM)); + EXPECT_EQ("ENHANCE_YOUR_CALM", Http2ErrorCodeToString(0xb)); + EXPECT_EQ("INADEQUATE_SECURITY", + Http2ErrorCodeToString(Http2ErrorCode::INADEQUATE_SECURITY)); + EXPECT_EQ("INADEQUATE_SECURITY", Http2ErrorCodeToString(0xc)); + EXPECT_EQ("HTTP_1_1_REQUIRED", + Http2ErrorCodeToString(Http2ErrorCode::HTTP_1_1_REQUIRED)); + EXPECT_EQ("HTTP_1_1_REQUIRED", Http2ErrorCodeToString(0xd)); + + EXPECT_EQ("UnknownErrorCode(0x123)", Http2ErrorCodeToString(0x123)); +} + +TEST(Http2ConstantsTest, Http2SettingsParameter) { + EXPECT_EQ(Http2SettingsParameter::HEADER_TABLE_SIZE, + static_cast(0x1)); + EXPECT_EQ(Http2SettingsParameter::ENABLE_PUSH, + static_cast(0x2)); + EXPECT_EQ(Http2SettingsParameter::MAX_CONCURRENT_STREAMS, + static_cast(0x3)); + EXPECT_EQ(Http2SettingsParameter::INITIAL_WINDOW_SIZE, + static_cast(0x4)); + EXPECT_EQ(Http2SettingsParameter::MAX_FRAME_SIZE, + static_cast(0x5)); + EXPECT_EQ(Http2SettingsParameter::MAX_HEADER_LIST_SIZE, + static_cast(0x6)); + + EXPECT_TRUE(IsSupportedHttp2SettingsParameter( + Http2SettingsParameter::HEADER_TABLE_SIZE)); + EXPECT_TRUE( + IsSupportedHttp2SettingsParameter(Http2SettingsParameter::ENABLE_PUSH)); + EXPECT_TRUE(IsSupportedHttp2SettingsParameter( + Http2SettingsParameter::MAX_CONCURRENT_STREAMS)); + EXPECT_TRUE(IsSupportedHttp2SettingsParameter( + Http2SettingsParameter::INITIAL_WINDOW_SIZE)); + EXPECT_TRUE(IsSupportedHttp2SettingsParameter( + Http2SettingsParameter::MAX_FRAME_SIZE)); + EXPECT_TRUE(IsSupportedHttp2SettingsParameter( + Http2SettingsParameter::MAX_HEADER_LIST_SIZE)); + + EXPECT_FALSE(IsSupportedHttp2SettingsParameter( + static_cast(0))); + EXPECT_FALSE(IsSupportedHttp2SettingsParameter( + static_cast(7))); +} + +TEST(Http2ConstantsTest, Http2SettingsParameterToString) { + EXPECT_EQ("HEADER_TABLE_SIZE", + Http2SettingsParameterToString( + Http2SettingsParameter::HEADER_TABLE_SIZE)); + EXPECT_EQ("HEADER_TABLE_SIZE", Http2SettingsParameterToString(0x1)); + EXPECT_EQ("ENABLE_PUSH", Http2SettingsParameterToString( + Http2SettingsParameter::ENABLE_PUSH)); + EXPECT_EQ("ENABLE_PUSH", Http2SettingsParameterToString(0x2)); + EXPECT_EQ("MAX_CONCURRENT_STREAMS", + Http2SettingsParameterToString( + Http2SettingsParameter::MAX_CONCURRENT_STREAMS)); + EXPECT_EQ("MAX_CONCURRENT_STREAMS", Http2SettingsParameterToString(0x3)); + EXPECT_EQ("INITIAL_WINDOW_SIZE", + Http2SettingsParameterToString( + Http2SettingsParameter::INITIAL_WINDOW_SIZE)); + EXPECT_EQ("INITIAL_WINDOW_SIZE", Http2SettingsParameterToString(0x4)); + EXPECT_EQ("MAX_FRAME_SIZE", Http2SettingsParameterToString( + Http2SettingsParameter::MAX_FRAME_SIZE)); + EXPECT_EQ("MAX_FRAME_SIZE", Http2SettingsParameterToString(0x5)); + EXPECT_EQ("MAX_HEADER_LIST_SIZE", + Http2SettingsParameterToString( + Http2SettingsParameter::MAX_HEADER_LIST_SIZE)); + EXPECT_EQ("MAX_HEADER_LIST_SIZE", Http2SettingsParameterToString(0x6)); + + EXPECT_EQ("UnknownSettingsParameter(0x123)", + Http2SettingsParameterToString(0x123)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/http2_structures.cc b/quiche/http2/http2_structures.cc new file mode 100644 index 000000000000..c77cfeb46325 --- /dev/null +++ b/quiche/http2/http2_structures.cc @@ -0,0 +1,153 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/http2_structures.h" + +#include // For std::memcmp +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" + +namespace http2 { + +// Http2FrameHeader: + +bool Http2FrameHeader::IsProbableHttpResponse() const { + return (payload_length == 0x485454 && // "HTT" + static_cast(type) == 'P' && // "P" + flags == '/'); // "/" +} + +std::string Http2FrameHeader::ToString() const { + return absl::StrCat("length=", payload_length, + ", type=", Http2FrameTypeToString(type), + ", flags=", FlagsToString(), ", stream=", stream_id); +} + +std::string Http2FrameHeader::FlagsToString() const { + return Http2FrameFlagsToString(type, flags); +} + +bool operator==(const Http2FrameHeader& a, const Http2FrameHeader& b) { + return a.payload_length == b.payload_length && a.stream_id == b.stream_id && + a.type == b.type && a.flags == b.flags; +} + +std::ostream& operator<<(std::ostream& out, const Http2FrameHeader& v) { + return out << v.ToString(); +} + +// Http2PriorityFields: + +bool operator==(const Http2PriorityFields& a, const Http2PriorityFields& b) { + return a.stream_dependency == b.stream_dependency && a.weight == b.weight; +} + +std::string Http2PriorityFields::ToString() const { + std::stringstream ss; + ss << "E=" << (is_exclusive ? "true" : "false") + << ", stream=" << stream_dependency + << ", weight=" << static_cast(weight); + return ss.str(); +} + +std::ostream& operator<<(std::ostream& out, const Http2PriorityFields& v) { + return out << v.ToString(); +} + +// Http2RstStreamFields: + +bool operator==(const Http2RstStreamFields& a, const Http2RstStreamFields& b) { + return a.error_code == b.error_code; +} + +std::ostream& operator<<(std::ostream& out, const Http2RstStreamFields& v) { + return out << "error_code=" << v.error_code; +} + +// Http2SettingFields: + +bool operator==(const Http2SettingFields& a, const Http2SettingFields& b) { + return a.parameter == b.parameter && a.value == b.value; +} +std::ostream& operator<<(std::ostream& out, const Http2SettingFields& v) { + return out << "parameter=" << v.parameter << ", value=" << v.value; +} + +// Http2PushPromiseFields: + +bool operator==(const Http2PushPromiseFields& a, + const Http2PushPromiseFields& b) { + return a.promised_stream_id == b.promised_stream_id; +} + +std::ostream& operator<<(std::ostream& out, const Http2PushPromiseFields& v) { + return out << "promised_stream_id=" << v.promised_stream_id; +} + +// Http2PingFields: + +bool operator==(const Http2PingFields& a, const Http2PingFields& b) { + static_assert((sizeof a.opaque_bytes) == Http2PingFields::EncodedSize(), + "Why not the same size?"); + return 0 == + std::memcmp(a.opaque_bytes, b.opaque_bytes, sizeof a.opaque_bytes); +} + +std::ostream& operator<<(std::ostream& out, const Http2PingFields& v) { + return out << "opaque_bytes=0x" + << absl::BytesToHexString(absl::string_view( + reinterpret_cast(v.opaque_bytes), + sizeof v.opaque_bytes)); +} + +// Http2GoAwayFields: + +bool operator==(const Http2GoAwayFields& a, const Http2GoAwayFields& b) { + return a.last_stream_id == b.last_stream_id && a.error_code == b.error_code; +} +std::ostream& operator<<(std::ostream& out, const Http2GoAwayFields& v) { + return out << "last_stream_id=" << v.last_stream_id + << ", error_code=" << v.error_code; +} + +// Http2WindowUpdateFields: + +bool operator==(const Http2WindowUpdateFields& a, + const Http2WindowUpdateFields& b) { + return a.window_size_increment == b.window_size_increment; +} +std::ostream& operator<<(std::ostream& out, const Http2WindowUpdateFields& v) { + return out << "window_size_increment=" << v.window_size_increment; +} + +// Http2AltSvcFields: + +bool operator==(const Http2AltSvcFields& a, const Http2AltSvcFields& b) { + return a.origin_length == b.origin_length; +} +std::ostream& operator<<(std::ostream& out, const Http2AltSvcFields& v) { + return out << "origin_length=" << v.origin_length; +} + +// Http2PriorityUpdateFields: + +bool operator==(const Http2PriorityUpdateFields& a, + const Http2PriorityUpdateFields& b) { + return a.prioritized_stream_id == b.prioritized_stream_id; +} + +std::string Http2PriorityUpdateFields::ToString() const { + std::stringstream ss; + ss << "prioritized_stream_id=" << prioritized_stream_id; + return ss.str(); +} + +std::ostream& operator<<(std::ostream& out, + const Http2PriorityUpdateFields& v) { + return out << v.ToString(); +} + +} // namespace http2 diff --git a/quiche/http2/http2_structures.h b/quiche/http2/http2_structures.h new file mode 100644 index 000000000000..3e1e4dd4fb99 --- /dev/null +++ b/quiche/http2/http2_structures.h @@ -0,0 +1,347 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_HTTP2_STRUCTURES_H_ +#define QUICHE_HTTP2_HTTP2_STRUCTURES_H_ + +// Defines structs for various fixed sized structures in HTTP/2. +// +// Those structs with multiple fields have constructors that take arguments in +// the same order as their encoding (which may be different from their order +// in the struct). For single field structs, use aggregate initialization if +// desired, e.g.: +// +// Http2RstStreamFields var{Http2ErrorCode::ENHANCE_YOUR_CALM}; +// or: +// SomeFunc(Http2RstStreamFields{Http2ErrorCode::ENHANCE_YOUR_CALM}); +// +// Each struct includes a static method EncodedSize which returns the number +// of bytes of the encoding. +// +// With the exception of Http2FrameHeader, all the types are named +// Http2Fields, where X is the title-case form of the frame which always +// includes the fields; the "always" is to cover the case of the PRIORITY frame; +// its fields optionally appear in the HEADERS frame, but the struct is called +// Http2PriorityFields. + +#include + +#include +#include +#include + +#include "quiche/http2/http2_constants.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { + +struct QUICHE_EXPORT Http2FrameHeader { + Http2FrameHeader() {} + Http2FrameHeader(uint32_t payload_length, Http2FrameType type, uint8_t flags, + uint32_t stream_id) + : payload_length(payload_length), + stream_id(stream_id), + type(type), + flags(flags) { + QUICHE_DCHECK_LT(payload_length, static_cast(1 << 24)) + << "Payload Length is only a 24 bit field\n" + << ToString(); + } + + static constexpr size_t EncodedSize() { return 9; } + + // Keep the current value of those flags that are in + // valid_flags, and clear all the others. + void RetainFlags(uint8_t valid_flags) { flags = (flags & valid_flags); } + + // Returns true if any of the flags in flag_mask are set, + // otherwise false. + bool HasAnyFlags(uint8_t flag_mask) const { return 0 != (flags & flag_mask); } + + // Is the END_STREAM flag set? + bool IsEndStream() const { + QUICHE_DCHECK(type == Http2FrameType::DATA || + type == Http2FrameType::HEADERS) + << ToString(); + return (flags & Http2FrameFlag::END_STREAM) != 0; + } + + // Is the ACK flag set? + bool IsAck() const { + QUICHE_DCHECK(type == Http2FrameType::SETTINGS || + type == Http2FrameType::PING) + << ToString(); + return (flags & Http2FrameFlag::ACK) != 0; + } + + // Is the END_HEADERS flag set? + bool IsEndHeaders() const { + QUICHE_DCHECK(type == Http2FrameType::HEADERS || + type == Http2FrameType::PUSH_PROMISE || + type == Http2FrameType::CONTINUATION) + << ToString(); + return (flags & Http2FrameFlag::END_HEADERS) != 0; + } + + // Is the PADDED flag set? + bool IsPadded() const { + QUICHE_DCHECK(type == Http2FrameType::DATA || + type == Http2FrameType::HEADERS || + type == Http2FrameType::PUSH_PROMISE) + << ToString(); + return (flags & Http2FrameFlag::PADDED) != 0; + } + + // Is the PRIORITY flag set? + bool HasPriority() const { + QUICHE_DCHECK_EQ(type, Http2FrameType::HEADERS) << ToString(); + return (flags & Http2FrameFlag::PRIORITY) != 0; + } + + // Does the encoding of this header start with "HTTP/", indicating that it + // might be from a non-HTTP/2 server. + bool IsProbableHttpResponse() const; + + // Produce strings useful for debugging/logging messages. + std::string ToString() const; + std::string FlagsToString() const; + + // 24 bit length of the payload after the header, including any padding. + // First field in encoding. + uint32_t payload_length; // 24 bits + + // 31 bit stream id, with high bit (32nd bit) reserved (must be zero), + // and is cleared during decoding. + // Fourth field in encoding. + uint32_t stream_id; + + // Type of the frame. + // Second field in encoding. + Http2FrameType type; + + // Flag bits, with interpretations that depend upon the frame type. + // Flag bits not used by the frame type are cleared. + // Third field in encoding. + uint8_t flags; +}; + +QUICHE_EXPORT bool operator==(const Http2FrameHeader& a, + const Http2FrameHeader& b); +QUICHE_EXPORT inline bool operator!=(const Http2FrameHeader& a, + const Http2FrameHeader& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2FrameHeader& v); + +// Http2PriorityFields: + +struct QUICHE_EXPORT Http2PriorityFields { + Http2PriorityFields() {} + Http2PriorityFields(uint32_t stream_dependency, uint32_t weight, + bool is_exclusive) + : stream_dependency(stream_dependency), + weight(weight), + is_exclusive(is_exclusive) { + // Can't have the high-bit set in the stream id because we need to use + // that for the EXCLUSIVE flag bit. + QUICHE_DCHECK_EQ(stream_dependency, stream_dependency & StreamIdMask()) + << "Stream Dependency is only a 31-bit field.\n" + << ToString(); + QUICHE_DCHECK_LE(1u, weight) << "Weight is too small."; + QUICHE_DCHECK_LE(weight, 256u) << "Weight is too large."; + } + static constexpr size_t EncodedSize() { return 5; } + + // Produce strings useful for debugging/logging messages. + std::string ToString() const; + + // A 31-bit stream identifier for the stream that this stream depends on. + uint32_t stream_dependency; + + // Weight (1 to 256) is encoded as a byte in the range 0 to 255, so we + // add one when decoding, and store it in a field larger than a byte. + uint32_t weight; + + // A single-bit flag indicating that the stream dependency is exclusive; + // extracted from high bit of stream dependency field during decoding. + bool is_exclusive; +}; + +QUICHE_EXPORT bool operator==(const Http2PriorityFields& a, + const Http2PriorityFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2PriorityFields& a, + const Http2PriorityFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2PriorityFields& v); + +// Http2RstStreamFields: + +struct QUICHE_EXPORT Http2RstStreamFields { + static constexpr size_t EncodedSize() { return 4; } + bool IsSupportedErrorCode() const { + return IsSupportedHttp2ErrorCode(error_code); + } + + Http2ErrorCode error_code; +}; + +QUICHE_EXPORT bool operator==(const Http2RstStreamFields& a, + const Http2RstStreamFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2RstStreamFields& a, + const Http2RstStreamFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2RstStreamFields& v); + +// Http2SettingFields: + +struct QUICHE_EXPORT Http2SettingFields { + Http2SettingFields() {} + Http2SettingFields(Http2SettingsParameter parameter, uint32_t value) + : parameter(parameter), value(value) {} + static constexpr size_t EncodedSize() { return 6; } + bool IsSupportedParameter() const { + return IsSupportedHttp2SettingsParameter(parameter); + } + + Http2SettingsParameter parameter; + uint32_t value; +}; + +QUICHE_EXPORT bool operator==(const Http2SettingFields& a, + const Http2SettingFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2SettingFields& a, + const Http2SettingFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2SettingFields& v); + +// Http2PushPromiseFields: + +struct QUICHE_EXPORT Http2PushPromiseFields { + static constexpr size_t EncodedSize() { return 4; } + + uint32_t promised_stream_id; +}; + +QUICHE_EXPORT bool operator==(const Http2PushPromiseFields& a, + const Http2PushPromiseFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2PushPromiseFields& a, + const Http2PushPromiseFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2PushPromiseFields& v); + +// Http2PingFields: + +struct QUICHE_EXPORT Http2PingFields { + static constexpr size_t EncodedSize() { return 8; } + + uint8_t opaque_bytes[8]; +}; + +QUICHE_EXPORT bool operator==(const Http2PingFields& a, + const Http2PingFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2PingFields& a, + const Http2PingFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2PingFields& v); + +// Http2GoAwayFields: + +struct QUICHE_EXPORT Http2GoAwayFields { + Http2GoAwayFields() {} + Http2GoAwayFields(uint32_t last_stream_id, Http2ErrorCode error_code) + : last_stream_id(last_stream_id), error_code(error_code) {} + static constexpr size_t EncodedSize() { return 8; } + bool IsSupportedErrorCode() const { + return IsSupportedHttp2ErrorCode(error_code); + } + + uint32_t last_stream_id; + Http2ErrorCode error_code; +}; + +QUICHE_EXPORT bool operator==(const Http2GoAwayFields& a, + const Http2GoAwayFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2GoAwayFields& a, + const Http2GoAwayFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2GoAwayFields& v); + +// Http2WindowUpdateFields: + +struct QUICHE_EXPORT Http2WindowUpdateFields { + static constexpr size_t EncodedSize() { return 4; } + + // 31-bit, unsigned increase in the window size (only positive values are + // allowed). The high-bit is reserved for the future. + uint32_t window_size_increment; +}; + +QUICHE_EXPORT bool operator==(const Http2WindowUpdateFields& a, + const Http2WindowUpdateFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2WindowUpdateFields& a, + const Http2WindowUpdateFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2WindowUpdateFields& v); + +// Http2AltSvcFields: + +struct QUICHE_EXPORT Http2AltSvcFields { + static constexpr size_t EncodedSize() { return 2; } + + // This is the one fixed size portion of the ALTSVC payload. + uint16_t origin_length; +}; + +QUICHE_EXPORT bool operator==(const Http2AltSvcFields& a, + const Http2AltSvcFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2AltSvcFields& a, + const Http2AltSvcFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2AltSvcFields& v); + +// Http2PriorityUpdateFields: + +struct QUICHE_EXPORT Http2PriorityUpdateFields { + Http2PriorityUpdateFields() {} + Http2PriorityUpdateFields(uint32_t prioritized_stream_id) + : prioritized_stream_id(prioritized_stream_id) {} + static constexpr size_t EncodedSize() { return 4; } + + // Produce strings useful for debugging/logging messages. + std::string ToString() const; + + // The 31-bit stream identifier of the stream whose priority is updated. + uint32_t prioritized_stream_id; +}; + +QUICHE_EXPORT bool operator==(const Http2PriorityUpdateFields& a, + const Http2PriorityUpdateFields& b); +QUICHE_EXPORT inline bool operator!=(const Http2PriorityUpdateFields& a, + const Http2PriorityUpdateFields& b) { + return !(a == b); +} +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + const Http2PriorityUpdateFields& v); + +} // namespace http2 + +#endif // QUICHE_HTTP2_HTTP2_STRUCTURES_H_ diff --git a/quiche/http2/http2_structures_test.cc b/quiche/http2/http2_structures_test.cc new file mode 100644 index 000000000000..5907733ef7ef --- /dev/null +++ b/quiche/http2/http2_structures_test.cc @@ -0,0 +1,570 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/http2_structures.h" + +// Tests are focused on Http2FrameHeader because it has by far the most +// methods of any of the structures. +// Note that EXPECT.*DEATH tests are slow (a fork is probably involved). + +// And in case you're wondering, yes, these are ridiculously thorough tests, +// but believe it or not, I've found silly bugs this way. + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; +using ::testing::Combine; +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +using ::testing::Not; +using ::testing::Values; +using ::testing::ValuesIn; + +namespace http2 { +namespace test { +namespace { + +template +E IncrementEnum(E e) { + using I = typename std::underlying_type::type; + return static_cast(1 + static_cast(e)); +} + +template +AssertionResult VerifyRandomCalls() { + T t1; + // Initialize with a stable key, to avoid test flakiness. + Http2Random seq1( + "6d9a61ddf2bc1fc0b8245505a1f28e324559d8b5c9c3268f38b42b1af3287c47"); + Randomize(&t1, &seq1); + + T t2; + Http2Random seq2(seq1.Key()); + Randomize(&t2, &seq2); + + // The two Randomize calls should have made the same number of calls into + // the Http2Random implementations. + HTTP2_VERIFY_EQ(seq1.Rand64(), seq2.Rand64()); + + // And because Http2Random implementation is returning the same sequence, and + // Randomize should have been consistent in applying those results, the two + // Ts should have the same value. + HTTP2_VERIFY_EQ(t1, t2); + + Randomize(&t2, &seq2); + HTTP2_VERIFY_NE(t1, t2); + + Randomize(&t1, &seq1); + HTTP2_VERIFY_EQ(t1, t2); + + HTTP2_VERIFY_EQ(seq1.Rand64(), seq2.Rand64()); + + return AssertionSuccess(); +} + +#if GTEST_HAS_DEATH_TEST && !defined(NDEBUG) +std::vector ValidFrameTypes() { + std::vector valid_types{Http2FrameType::DATA}; + while (valid_types.back() != Http2FrameType::ALTSVC) { + valid_types.push_back(IncrementEnum(valid_types.back())); + } + return valid_types; +} +#endif // GTEST_HAS_DEATH_TEST && !defined(NDEBUG) + +TEST(Http2FrameHeaderTest, Constructor) { + Http2Random random; + uint8_t frame_type = 0; + do { + // Only the payload length is QUICHE_DCHECK'd in the constructor, so we need + // to make sure it is a "uint24". + uint32_t payload_length = random.Rand32() & 0xffffff; + Http2FrameType type = static_cast(frame_type); + uint8_t flags = random.Rand8(); + uint32_t stream_id = random.Rand32(); + + Http2FrameHeader v(payload_length, type, flags, stream_id); + + EXPECT_EQ(payload_length, v.payload_length); + EXPECT_EQ(type, v.type); + EXPECT_EQ(flags, v.flags); + EXPECT_EQ(stream_id, v.stream_id); + } while (frame_type++ != 255); + +#if GTEST_HAS_DEATH_TEST && !defined(NDEBUG) + EXPECT_QUICHE_DEBUG_DEATH( + Http2FrameHeader(0x01000000, Http2FrameType::DATA, 0, 1), + "payload_length"); +#endif // GTEST_HAS_DEATH_TEST && !defined(NDEBUG) +} + +TEST(Http2FrameHeaderTest, Eq) { + Http2Random random; + uint32_t payload_length = random.Rand32() & 0xffffff; + Http2FrameType type = static_cast(random.Rand8()); + + uint8_t flags = random.Rand8(); + uint32_t stream_id = random.Rand32(); + + Http2FrameHeader v(payload_length, type, flags, stream_id); + + EXPECT_EQ(payload_length, v.payload_length); + EXPECT_EQ(type, v.type); + EXPECT_EQ(flags, v.flags); + EXPECT_EQ(stream_id, v.stream_id); + + Http2FrameHeader u(0, type, ~flags, stream_id); + + EXPECT_NE(u, v); + EXPECT_NE(v, u); + EXPECT_FALSE(u == v); + EXPECT_FALSE(v == u); + EXPECT_TRUE(u != v); + EXPECT_TRUE(v != u); + + u = v; + + EXPECT_EQ(u, v); + EXPECT_EQ(v, u); + EXPECT_TRUE(u == v); + EXPECT_TRUE(v == u); + EXPECT_FALSE(u != v); + EXPECT_FALSE(v != u); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +#if GTEST_HAS_DEATH_TEST && !defined(NDEBUG) + +using TestParams = std::tuple; + +std::string TestParamToString(const testing::TestParamInfo& info) { + Http2FrameType type = std::get<0>(info.param); + uint8_t flags = std::get<1>(info.param); + + return absl::StrCat(Http2FrameTypeToString(type), static_cast(flags)); +} + +// The tests of the valid frame types include EXPECT_QUICHE_DEBUG_DEATH, which +// is quite slow, so using value parameterized tests in order to allow sharding. +class Http2FrameHeaderTypeAndFlagTest + : public quiche::test::QuicheTestWithParam { + protected: + Http2FrameHeaderTypeAndFlagTest() + : type_(std::get<0>(GetParam())), flags_(std::get<1>(GetParam())) { + QUICHE_LOG(INFO) << "Frame type: " << type_; + QUICHE_LOG(INFO) << "Frame flags: " + << Http2FrameFlagsToString(type_, flags_); + } + + const Http2FrameType type_; + const uint8_t flags_; +}; + +class IsEndStreamTest : public Http2FrameHeaderTypeAndFlagTest {}; +INSTANTIATE_TEST_SUITE_P(IsEndStream, IsEndStreamTest, + Combine(ValuesIn(ValidFrameTypes()), + Values(~Http2FrameFlag::END_STREAM, 0xff)), + TestParamToString); +TEST_P(IsEndStreamTest, IsEndStream) { + const bool is_set = + (flags_ & Http2FrameFlag::END_STREAM) == Http2FrameFlag::END_STREAM; + std::string flags_string; + Http2FrameHeader v(0, type_, flags_, 0); + switch (type_) { + case Http2FrameType::DATA: + case Http2FrameType::HEADERS: + EXPECT_EQ(is_set, v.IsEndStream()) << v; + flags_string = v.FlagsToString(); + if (is_set) { + EXPECT_THAT(flags_string, MatchesRegex(".*\\|?END_STREAM\\|.*")); + } else { + EXPECT_THAT(flags_string, Not(HasSubstr("END_STREAM"))); + } + v.RetainFlags(Http2FrameFlag::END_STREAM); + EXPECT_EQ(is_set, v.IsEndStream()) << v; + { + std::stringstream s; + s << v; + EXPECT_EQ(v.ToString(), s.str()); + if (is_set) { + EXPECT_THAT(s.str(), HasSubstr("flags=END_STREAM,")); + } else { + EXPECT_THAT(s.str(), HasSubstr("flags=,")); + } + } + break; + default: + EXPECT_QUICHE_DEBUG_DEATH(v.IsEndStream(), "DATA.*HEADERS"); + } +} + +class IsACKTest : public Http2FrameHeaderTypeAndFlagTest {}; +INSTANTIATE_TEST_SUITE_P(IsAck, IsACKTest, + Combine(ValuesIn(ValidFrameTypes()), + Values(~Http2FrameFlag::ACK, 0xff)), + TestParamToString); +TEST_P(IsACKTest, IsAck) { + const bool is_set = (flags_ & Http2FrameFlag::ACK) == Http2FrameFlag::ACK; + std::string flags_string; + Http2FrameHeader v(0, type_, flags_, 0); + switch (type_) { + case Http2FrameType::SETTINGS: + case Http2FrameType::PING: + EXPECT_EQ(is_set, v.IsAck()) << v; + flags_string = v.FlagsToString(); + if (is_set) { + EXPECT_THAT(flags_string, MatchesRegex(".*\\|?ACK\\|.*")); + } else { + EXPECT_THAT(flags_string, Not(HasSubstr("ACK"))); + } + v.RetainFlags(Http2FrameFlag::ACK); + EXPECT_EQ(is_set, v.IsAck()) << v; + { + std::stringstream s; + s << v; + EXPECT_EQ(v.ToString(), s.str()); + if (is_set) { + EXPECT_THAT(s.str(), HasSubstr("flags=ACK,")); + } else { + EXPECT_THAT(s.str(), HasSubstr("flags=,")); + } + } + break; + default: + EXPECT_QUICHE_DEBUG_DEATH(v.IsAck(), "SETTINGS.*PING"); + } +} + +class IsEndHeadersTest : public Http2FrameHeaderTypeAndFlagTest {}; +INSTANTIATE_TEST_SUITE_P(IsEndHeaders, IsEndHeadersTest, + Combine(ValuesIn(ValidFrameTypes()), + Values(~Http2FrameFlag::END_HEADERS, 0xff)), + TestParamToString); +TEST_P(IsEndHeadersTest, IsEndHeaders) { + const bool is_set = + (flags_ & Http2FrameFlag::END_HEADERS) == Http2FrameFlag::END_HEADERS; + std::string flags_string; + Http2FrameHeader v(0, type_, flags_, 0); + switch (type_) { + case Http2FrameType::HEADERS: + case Http2FrameType::PUSH_PROMISE: + case Http2FrameType::CONTINUATION: + EXPECT_EQ(is_set, v.IsEndHeaders()) << v; + flags_string = v.FlagsToString(); + if (is_set) { + EXPECT_THAT(flags_string, MatchesRegex(".*\\|?END_HEADERS\\|.*")); + } else { + EXPECT_THAT(flags_string, Not(HasSubstr("END_HEADERS"))); + } + v.RetainFlags(Http2FrameFlag::END_HEADERS); + EXPECT_EQ(is_set, v.IsEndHeaders()) << v; + { + std::stringstream s; + s << v; + EXPECT_EQ(v.ToString(), s.str()); + if (is_set) { + EXPECT_THAT(s.str(), HasSubstr("flags=END_HEADERS,")); + } else { + EXPECT_THAT(s.str(), HasSubstr("flags=,")); + } + } + break; + default: + EXPECT_QUICHE_DEBUG_DEATH(v.IsEndHeaders(), + "HEADERS.*PUSH_PROMISE.*CONTINUATION"); + } +} + +class IsPaddedTest : public Http2FrameHeaderTypeAndFlagTest {}; +INSTANTIATE_TEST_SUITE_P(IsPadded, IsPaddedTest, + Combine(ValuesIn(ValidFrameTypes()), + Values(~Http2FrameFlag::PADDED, 0xff)), + TestParamToString); +TEST_P(IsPaddedTest, IsPadded) { + const bool is_set = + (flags_ & Http2FrameFlag::PADDED) == Http2FrameFlag::PADDED; + std::string flags_string; + Http2FrameHeader v(0, type_, flags_, 0); + switch (type_) { + case Http2FrameType::DATA: + case Http2FrameType::HEADERS: + case Http2FrameType::PUSH_PROMISE: + EXPECT_EQ(is_set, v.IsPadded()) << v; + flags_string = v.FlagsToString(); + if (is_set) { + EXPECT_THAT(flags_string, MatchesRegex(".*\\|?PADDED\\|.*")); + } else { + EXPECT_THAT(flags_string, Not(HasSubstr("PADDED"))); + } + v.RetainFlags(Http2FrameFlag::PADDED); + EXPECT_EQ(is_set, v.IsPadded()) << v; + { + std::stringstream s; + s << v; + EXPECT_EQ(v.ToString(), s.str()); + if (is_set) { + EXPECT_THAT(s.str(), HasSubstr("flags=PADDED,")); + } else { + EXPECT_THAT(s.str(), HasSubstr("flags=,")); + } + } + break; + default: + EXPECT_QUICHE_DEBUG_DEATH(v.IsPadded(), "DATA.*HEADERS.*PUSH_PROMISE"); + } +} + +class HasPriorityTest : public Http2FrameHeaderTypeAndFlagTest {}; +INSTANTIATE_TEST_SUITE_P(HasPriority, HasPriorityTest, + Combine(ValuesIn(ValidFrameTypes()), + Values(~Http2FrameFlag::PRIORITY, 0xff)), + TestParamToString); +TEST_P(HasPriorityTest, HasPriority) { + const bool is_set = + (flags_ & Http2FrameFlag::PRIORITY) == Http2FrameFlag::PRIORITY; + std::string flags_string; + Http2FrameHeader v(0, type_, flags_, 0); + switch (type_) { + case Http2FrameType::HEADERS: + EXPECT_EQ(is_set, v.HasPriority()) << v; + flags_string = v.FlagsToString(); + if (is_set) { + EXPECT_THAT(flags_string, MatchesRegex(".*\\|?PRIORITY\\|.*")); + } else { + EXPECT_THAT(flags_string, Not(HasSubstr("PRIORITY"))); + } + v.RetainFlags(Http2FrameFlag::PRIORITY); + EXPECT_EQ(is_set, v.HasPriority()) << v; + { + std::stringstream s; + s << v; + EXPECT_EQ(v.ToString(), s.str()); + if (is_set) { + EXPECT_THAT(s.str(), HasSubstr("flags=PRIORITY,")); + } else { + EXPECT_THAT(s.str(), HasSubstr("flags=,")); + } + } + break; + default: + EXPECT_QUICHE_DEBUG_DEATH(v.HasPriority(), "HEADERS"); + } +} + +TEST(Http2PriorityFieldsTest, Constructor) { + Http2Random random; + uint32_t stream_dependency = random.Rand32() & StreamIdMask(); + uint32_t weight = 1 + random.Rand8(); + bool is_exclusive = random.OneIn(2); + + Http2PriorityFields v(stream_dependency, weight, is_exclusive); + + EXPECT_EQ(stream_dependency, v.stream_dependency); + EXPECT_EQ(weight, v.weight); + EXPECT_EQ(is_exclusive, v.is_exclusive); + + // The high-bit must not be set on the stream id. + EXPECT_QUICHE_DEBUG_DEATH( + Http2PriorityFields(stream_dependency | 0x80000000, weight, is_exclusive), + "31-bit"); + + // The weight must be in the range 1-256. + EXPECT_QUICHE_DEBUG_DEATH( + Http2PriorityFields(stream_dependency, 0, is_exclusive), "too small"); + EXPECT_QUICHE_DEBUG_DEATH( + Http2PriorityFields(stream_dependency, weight + 256, is_exclusive), + "too large"); + + EXPECT_TRUE(VerifyRandomCalls()); +} +#endif // GTEST_HAS_DEATH_TEST && !defined(NDEBUG) + +TEST(Http2RstStreamFieldsTest, IsSupported) { + Http2RstStreamFields v{Http2ErrorCode::HTTP2_NO_ERROR}; + EXPECT_TRUE(v.IsSupportedErrorCode()) << v; + + Http2RstStreamFields u{static_cast(~0)}; + EXPECT_FALSE(u.IsSupportedErrorCode()) << v; + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2SettingFieldsTest, Misc) { + Http2Random random; + Http2SettingsParameter parameter = + static_cast(random.Rand16()); + uint32_t value = random.Rand32(); + + Http2SettingFields v(parameter, value); + + EXPECT_EQ(v, v); + EXPECT_EQ(parameter, v.parameter); + EXPECT_EQ(value, v.value); + + if (static_cast(parameter) < 7) { + EXPECT_TRUE(v.IsSupportedParameter()) << v; + } else { + EXPECT_FALSE(v.IsSupportedParameter()) << v; + } + + Http2SettingFields u(parameter, ~value); + EXPECT_NE(v, u); + EXPECT_EQ(v.parameter, u.parameter); + EXPECT_NE(v.value, u.value); + + Http2SettingFields w(IncrementEnum(parameter), value); + EXPECT_NE(v, w); + EXPECT_NE(v.parameter, w.parameter); + EXPECT_EQ(v.value, w.value); + + Http2SettingFields x(Http2SettingsParameter::MAX_FRAME_SIZE, 123); + std::stringstream s; + s << x; + EXPECT_EQ("parameter=MAX_FRAME_SIZE, value=123", s.str()); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2PushPromiseTest, Misc) { + Http2Random random; + uint32_t promised_stream_id = random.Rand32() & StreamIdMask(); + + Http2PushPromiseFields v{promised_stream_id}; + EXPECT_EQ(promised_stream_id, v.promised_stream_id); + EXPECT_EQ(v, v); + + std::stringstream s; + s << v; + EXPECT_EQ(absl::StrCat("promised_stream_id=", promised_stream_id), s.str()); + + // High-bit is reserved, but not used, so we can set it. + promised_stream_id |= 0x80000000; + Http2PushPromiseFields w{promised_stream_id}; + EXPECT_EQ(w, w); + EXPECT_NE(v, w); + + v.promised_stream_id = promised_stream_id; + EXPECT_EQ(v, w); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2PingFieldsTest, Misc) { + Http2PingFields v{{'8', ' ', 'b', 'y', 't', 'e', 's', '\0'}}; + std::stringstream s; + s << v; + EXPECT_EQ("opaque_bytes=0x3820627974657300", s.str()); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2GoAwayFieldsTest, Misc) { + Http2Random random; + uint32_t last_stream_id = random.Rand32() & StreamIdMask(); + Http2ErrorCode error_code = static_cast(random.Rand32()); + + Http2GoAwayFields v(last_stream_id, error_code); + EXPECT_EQ(v, v); + EXPECT_EQ(last_stream_id, v.last_stream_id); + EXPECT_EQ(error_code, v.error_code); + + if (static_cast(error_code) < 14) { + EXPECT_TRUE(v.IsSupportedErrorCode()) << v; + } else { + EXPECT_FALSE(v.IsSupportedErrorCode()) << v; + } + + Http2GoAwayFields u(~last_stream_id, error_code); + EXPECT_NE(v, u); + EXPECT_NE(v.last_stream_id, u.last_stream_id); + EXPECT_EQ(v.error_code, u.error_code); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2WindowUpdateTest, Misc) { + Http2Random random; + uint32_t window_size_increment = random.Rand32() & UInt31Mask(); + + Http2WindowUpdateFields v{window_size_increment}; + EXPECT_EQ(window_size_increment, v.window_size_increment); + EXPECT_EQ(v, v); + + std::stringstream s; + s << v; + EXPECT_EQ(absl::StrCat("window_size_increment=", window_size_increment), + s.str()); + + // High-bit is reserved, but not used, so we can set it. + window_size_increment |= 0x80000000; + Http2WindowUpdateFields w{window_size_increment}; + EXPECT_EQ(w, w); + EXPECT_NE(v, w); + + v.window_size_increment = window_size_increment; + EXPECT_EQ(v, w); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2AltSvcTest, Misc) { + Http2Random random; + uint16_t origin_length = random.Rand16(); + + Http2AltSvcFields v{origin_length}; + EXPECT_EQ(origin_length, v.origin_length); + EXPECT_EQ(v, v); + + std::stringstream s; + s << v; + EXPECT_EQ(absl::StrCat("origin_length=", origin_length), s.str()); + + Http2AltSvcFields w{++origin_length}; + EXPECT_EQ(w, w); + EXPECT_NE(v, w); + + v.origin_length = w.origin_length; + EXPECT_EQ(v, w); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +TEST(Http2PriorityUpdateFieldsTest, Eq) { + Http2PriorityUpdateFields u(/* prioritized_stream_id = */ 1); + Http2PriorityUpdateFields v(/* prioritized_stream_id = */ 3); + + EXPECT_NE(u, v); + EXPECT_FALSE(u == v); + EXPECT_TRUE(u != v); + + u = v; + EXPECT_EQ(u, v); + EXPECT_TRUE(u == v); + EXPECT_FALSE(u != v); +} + +TEST(Http2PriorityUpdateFieldsTest, Misc) { + Http2PriorityUpdateFields u(/* prioritized_stream_id = */ 1); + EXPECT_EQ("prioritized_stream_id=1", u.ToString()); + + EXPECT_TRUE(VerifyRandomCalls()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/frame_decoder_state_test_util.cc b/quiche/http2/test_tools/frame_decoder_state_test_util.cc new file mode 100644 index 000000000000..7de363f3e2d4 --- /dev/null +++ b/quiche/http2/test_tools/frame_decoder_state_test_util.cc @@ -0,0 +1,34 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/frame_decoder_state_test_util.h" + +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/http2_structure_decoder_test_util.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { + +// static +void FrameDecoderStatePeer::Randomize(FrameDecoderState* p, Http2Random* rng) { + QUICHE_VLOG(1) << "FrameDecoderStatePeer::Randomize"; + ::http2::test::Randomize(&p->frame_header_, rng); + p->remaining_payload_ = rng->Rand32(); + p->remaining_padding_ = rng->Rand32(); + Http2StructureDecoderPeer::Randomize(&p->structure_decoder_, rng); +} + +// static +void FrameDecoderStatePeer::set_frame_header(const Http2FrameHeader& header, + FrameDecoderState* p) { + QUICHE_VLOG(1) << "FrameDecoderStatePeer::set_frame_header " << header; + p->frame_header_ = header; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/frame_decoder_state_test_util.h b/quiche/http2/test_tools/frame_decoder_state_test_util.h new file mode 100644 index 000000000000..3c6fde9ecee0 --- /dev/null +++ b/quiche/http2/test_tools/frame_decoder_state_test_util.h @@ -0,0 +1,37 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_FRAME_DECODER_STATE_TEST_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_FRAME_DECODER_STATE_TEST_UTIL_H_ + +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT FrameDecoderStatePeer { + public: + // Randomizes (i.e. corrupts) the fields of the FrameDecoderState. + // PayloadDecoderBaseTest::StartDecoding calls this before passing the first + // decode buffer to the payload decoder, which increases the likelihood of + // detecting any use of prior states of the decoder on the decoding of + // future payloads. + static void Randomize(FrameDecoderState* p, Http2Random* rng); + + // Inject a frame header into the FrameDecoderState. + // PayloadDecoderBaseTest::StartDecoding calls this just after calling + // Randomize (above), to simulate a full frame decoder having just finished + // decoding the common frame header and then calling the appropriate payload + // decoder based on the frame type in that frame header. + static void set_frame_header(const Http2FrameHeader& header, + FrameDecoderState* p); +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_FRAME_DECODER_STATE_TEST_UTIL_H_ diff --git a/quiche/http2/test_tools/frame_parts.cc b/quiche/http2/test_tools/frame_parts.cc new file mode 100644 index 000000000000..4792e27c74fd --- /dev/null +++ b/quiche/http2/test_tools/frame_parts.cc @@ -0,0 +1,554 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/frame_parts.h" + +#include + +#include "absl/strings/escaping.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionFailure; +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; +using ::testing::ContainerEq; + +namespace http2 { +namespace test { +namespace { + +static_assert(std::is_base_of::value && + !std::is_abstract::value, + "FrameParts needs to implement all of the methods of " + "Http2FrameDecoderListener"); + +// Compare two optional variables of the same type. +// TODO(jamessynge): Maybe create a ::testing::Matcher for this. +template +AssertionResult VerifyOptionalEq(const T& opt_a, const T& opt_b) { + if (opt_a) { + if (opt_b) { + HTTP2_VERIFY_EQ(opt_a.value(), opt_b.value()); + } else { + return AssertionFailure() + << "opt_b is not set; opt_a.value()=" << opt_a.value(); + } + } else if (opt_b) { + return AssertionFailure() + << "opt_a is not set; opt_b.value()=" << opt_b.value(); + } + return AssertionSuccess(); +} + +} // namespace + +FrameParts::FrameParts(const Http2FrameHeader& header) : frame_header_(header) { + QUICHE_VLOG(1) << "FrameParts, header: " << frame_header_; +} + +FrameParts::FrameParts(const Http2FrameHeader& header, + absl::string_view payload) + : FrameParts(header) { + QUICHE_VLOG(1) << "FrameParts with payload.size() = " << payload.size(); + this->payload_.append(payload.data(), payload.size()); + opt_payload_length_ = payload.size(); +} +FrameParts::FrameParts(const Http2FrameHeader& header, + absl::string_view payload, size_t total_pad_length) + : FrameParts(header, payload) { + QUICHE_VLOG(1) << "FrameParts with total_pad_length=" << total_pad_length; + SetTotalPadLength(total_pad_length); +} + +FrameParts::FrameParts(const FrameParts& header) = default; + +FrameParts::~FrameParts() = default; + +AssertionResult FrameParts::VerifyEquals(const FrameParts& that) const { +#define COMMON_MESSAGE "\n this: " << *this << "\n that: " << that + + HTTP2_VERIFY_EQ(frame_header_, that.frame_header_) << COMMON_MESSAGE; + HTTP2_VERIFY_EQ(payload_, that.payload_) << COMMON_MESSAGE; + HTTP2_VERIFY_EQ(padding_, that.padding_) << COMMON_MESSAGE; + HTTP2_VERIFY_EQ(altsvc_origin_, that.altsvc_origin_) << COMMON_MESSAGE; + HTTP2_VERIFY_EQ(altsvc_value_, that.altsvc_value_) << COMMON_MESSAGE; + HTTP2_VERIFY_EQ(settings_, that.settings_) << COMMON_MESSAGE; + +#define HTTP2_VERIFY_OPTIONAL_FIELD(field_name) \ + HTTP2_VERIFY_SUCCESS(VerifyOptionalEq(field_name, that.field_name)) + + HTTP2_VERIFY_OPTIONAL_FIELD(opt_altsvc_origin_length_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_altsvc_value_length_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_priority_update_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_goaway_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_missing_length_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_pad_length_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_ping_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_priority_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_push_promise_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_rst_stream_error_code_) << COMMON_MESSAGE; + HTTP2_VERIFY_OPTIONAL_FIELD(opt_window_update_increment_) << COMMON_MESSAGE; + +#undef HTTP2_VERIFY_OPTIONAL_FIELD + + return AssertionSuccess(); +} + +void FrameParts::SetTotalPadLength(size_t total_pad_length) { + opt_pad_length_.reset(); + padding_.clear(); + if (total_pad_length > 0) { + ASSERT_LE(total_pad_length, 256u); + ASSERT_TRUE(frame_header_.IsPadded()); + opt_pad_length_ = total_pad_length - 1; + char zero = 0; + padding_.append(opt_pad_length_.value(), zero); + } + + if (opt_pad_length_) { + QUICHE_VLOG(1) << "SetTotalPadLength: pad_length=" + << opt_pad_length_.value(); + } else { + QUICHE_VLOG(1) << "SetTotalPadLength: has no pad length"; + } +} + +void FrameParts::SetAltSvcExpected(absl::string_view origin, + absl::string_view value) { + altsvc_origin_.append(origin.data(), origin.size()); + altsvc_value_.append(value.data(), value.size()); + opt_altsvc_origin_length_ = origin.size(); + opt_altsvc_value_length_ = value.size(); +} + +bool FrameParts::OnFrameHeader(const Http2FrameHeader& /*header*/) { + ADD_FAILURE() << "OnFrameHeader: " << *this; + return true; +} + +void FrameParts::OnDataStart(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnDataStart: " << header; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::DATA)) << *this; + opt_payload_length_ = header.payload_length; +} + +void FrameParts::OnDataPayload(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnDataPayload: len=" << len + << "; frame_header_: " << frame_header_; + ASSERT_TRUE(InFrameOfType(Http2FrameType::DATA)) << *this; + ASSERT_TRUE(AppendString(absl::string_view(data, len), &payload_, + &opt_payload_length_)); +} + +void FrameParts::OnDataEnd() { + QUICHE_VLOG(1) << "OnDataEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::DATA)) << *this; +} + +void FrameParts::OnHeadersStart(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnHeadersStart: " << header; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::HEADERS)) << *this; + opt_payload_length_ = header.payload_length; +} + +void FrameParts::OnHeadersPriority(const Http2PriorityFields& priority) { + QUICHE_VLOG(1) << "OnHeadersPriority: priority: " << priority + << "; frame_header_: " << frame_header_; + ASSERT_TRUE(InFrameOfType(Http2FrameType::HEADERS)) << *this; + ASSERT_FALSE(opt_priority_); + opt_priority_ = priority; + ASSERT_TRUE(opt_payload_length_); + opt_payload_length_ = + opt_payload_length_.value() - Http2PriorityFields::EncodedSize(); +} + +void FrameParts::OnHpackFragment(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnHpackFragment: len=" << len + << "; frame_header_: " << frame_header_; + ASSERT_TRUE(got_start_callback_); + ASSERT_FALSE(got_end_callback_); + ASSERT_TRUE(FrameCanHaveHpackPayload(frame_header_)) << *this; + ASSERT_TRUE(AppendString(absl::string_view(data, len), &payload_, + &opt_payload_length_)); +} + +void FrameParts::OnHeadersEnd() { + QUICHE_VLOG(1) << "OnHeadersEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::HEADERS)) << *this; +} + +void FrameParts::OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority) { + QUICHE_VLOG(1) << "OnPriorityFrame: " << header << "; priority: " << priority; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::PRIORITY)) << *this; + ASSERT_FALSE(opt_priority_); + opt_priority_ = priority; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::PRIORITY)) << *this; +} + +void FrameParts::OnContinuationStart(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnContinuationStart: " << header; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::CONTINUATION)) << *this; + opt_payload_length_ = header.payload_length; +} + +void FrameParts::OnContinuationEnd() { + QUICHE_VLOG(1) << "OnContinuationEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::CONTINUATION)) << *this; +} + +void FrameParts::OnPadLength(size_t trailing_length) { + QUICHE_VLOG(1) << "OnPadLength: trailing_length=" << trailing_length; + ASSERT_TRUE(InPaddedFrame()) << *this; + ASSERT_FALSE(opt_pad_length_); + ASSERT_TRUE(opt_payload_length_); + size_t total_padding_length = trailing_length + 1; + ASSERT_GE(opt_payload_length_.value(), total_padding_length); + opt_payload_length_ = opt_payload_length_.value() - total_padding_length; + opt_pad_length_ = trailing_length; +} + +void FrameParts::OnPadding(const char* pad, size_t skipped_length) { + QUICHE_VLOG(1) << "OnPadding: skipped_length=" << skipped_length; + ASSERT_TRUE(InPaddedFrame()) << *this; + ASSERT_TRUE(opt_pad_length_); + ASSERT_TRUE(AppendString(absl::string_view(pad, skipped_length), &padding_, + &opt_pad_length_)); +} + +void FrameParts::OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) { + QUICHE_VLOG(1) << "OnRstStream: " << header << "; code=" << error_code; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::RST_STREAM)) << *this; + ASSERT_FALSE(opt_rst_stream_error_code_); + opt_rst_stream_error_code_ = error_code; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::RST_STREAM)) << *this; +} + +void FrameParts::OnSettingsStart(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnSettingsStart: " << header; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::SETTINGS)) << *this; + ASSERT_EQ(0u, settings_.size()); + ASSERT_FALSE(header.IsAck()) << header; +} + +void FrameParts::OnSetting(const Http2SettingFields& setting_fields) { + QUICHE_VLOG(1) << "OnSetting: " << setting_fields; + ASSERT_TRUE(InFrameOfType(Http2FrameType::SETTINGS)) << *this; + settings_.push_back(setting_fields); +} + +void FrameParts::OnSettingsEnd() { + QUICHE_VLOG(1) << "OnSettingsEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::SETTINGS)) << *this; +} + +void FrameParts::OnSettingsAck(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnSettingsAck: " << header; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::SETTINGS)) << *this; + ASSERT_EQ(0u, settings_.size()); + ASSERT_TRUE(header.IsAck()); + ASSERT_TRUE(EndFrameOfType(Http2FrameType::SETTINGS)) << *this; +} + +void FrameParts::OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) { + QUICHE_VLOG(1) << "OnPushPromiseStart header: " << header + << "; promise: " << promise + << "; total_padding_length: " << total_padding_length; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::PUSH_PROMISE)) << *this; + ASSERT_GE(header.payload_length, Http2PushPromiseFields::EncodedSize()); + opt_payload_length_ = + header.payload_length - Http2PushPromiseFields::EncodedSize(); + ASSERT_FALSE(opt_push_promise_); + opt_push_promise_ = promise; + if (total_padding_length > 0) { + ASSERT_GE(opt_payload_length_.value(), total_padding_length); + OnPadLength(total_padding_length - 1); + } else { + ASSERT_FALSE(header.IsPadded()); + } +} + +void FrameParts::OnPushPromiseEnd() { + QUICHE_VLOG(1) << "OnPushPromiseEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::PUSH_PROMISE)) << *this; +} + +void FrameParts::OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_VLOG(1) << "OnPing header: " << header << " ping: " << ping; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::PING)) << *this; + ASSERT_FALSE(header.IsAck()); + ASSERT_FALSE(opt_ping_); + opt_ping_ = ping; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::PING)) << *this; +} + +void FrameParts::OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_VLOG(1) << "OnPingAck header: " << header << " ping: " << ping; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::PING)) << *this; + ASSERT_TRUE(header.IsAck()); + ASSERT_FALSE(opt_ping_); + opt_ping_ = ping; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::PING)) << *this; +} + +void FrameParts::OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) { + QUICHE_VLOG(1) << "OnGoAwayStart: " << goaway; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::GOAWAY)) << *this; + ASSERT_FALSE(opt_goaway_); + opt_goaway_ = goaway; + opt_payload_length_ = + header.payload_length - Http2GoAwayFields::EncodedSize(); +} + +void FrameParts::OnGoAwayOpaqueData(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnGoAwayOpaqueData: len=" << len; + ASSERT_TRUE(InFrameOfType(Http2FrameType::GOAWAY)) << *this; + ASSERT_TRUE(AppendString(absl::string_view(data, len), &payload_, + &opt_payload_length_)); +} + +void FrameParts::OnGoAwayEnd() { + QUICHE_VLOG(1) << "OnGoAwayEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::GOAWAY)) << *this; +} + +void FrameParts::OnWindowUpdate(const Http2FrameHeader& header, + uint32_t increment) { + QUICHE_VLOG(1) << "OnWindowUpdate header: " << header + << " increment=" << increment; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::WINDOW_UPDATE)) << *this; + ASSERT_FALSE(opt_window_update_increment_); + opt_window_update_increment_ = increment; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::WINDOW_UPDATE)) << *this; +} + +void FrameParts::OnAltSvcStart(const Http2FrameHeader& header, + size_t origin_length, size_t value_length) { + QUICHE_VLOG(1) << "OnAltSvcStart: " << header + << " origin_length: " << origin_length + << " value_length: " << value_length; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::ALTSVC)) << *this; + ASSERT_FALSE(opt_altsvc_origin_length_); + opt_altsvc_origin_length_ = origin_length; + ASSERT_FALSE(opt_altsvc_value_length_); + opt_altsvc_value_length_ = value_length; +} + +void FrameParts::OnAltSvcOriginData(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnAltSvcOriginData: len=" << len; + ASSERT_TRUE(InFrameOfType(Http2FrameType::ALTSVC)) << *this; + ASSERT_TRUE(AppendString(absl::string_view(data, len), &altsvc_origin_, + &opt_altsvc_origin_length_)); +} + +void FrameParts::OnAltSvcValueData(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnAltSvcValueData: len=" << len; + ASSERT_TRUE(InFrameOfType(Http2FrameType::ALTSVC)) << *this; + ASSERT_TRUE(AppendString(absl::string_view(data, len), &altsvc_value_, + &opt_altsvc_value_length_)); +} + +void FrameParts::OnAltSvcEnd() { + QUICHE_VLOG(1) << "OnAltSvcEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::ALTSVC)) << *this; +} + +void FrameParts::OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) { + QUICHE_VLOG(1) << "OnPriorityUpdateStart: " << header + << " prioritized_stream_id: " + << priority_update.prioritized_stream_id; + ASSERT_TRUE(StartFrameOfType(header, Http2FrameType::PRIORITY_UPDATE)) + << *this; + ASSERT_FALSE(opt_priority_update_); + opt_priority_update_ = priority_update; + opt_payload_length_ = + header.payload_length - Http2PriorityUpdateFields::EncodedSize(); +} + +void FrameParts::OnPriorityUpdatePayload(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnPriorityUpdatePayload: len=" << len; + ASSERT_TRUE(InFrameOfType(Http2FrameType::PRIORITY_UPDATE)) << *this; + payload_.append(data, len); +} + +void FrameParts::OnPriorityUpdateEnd() { + QUICHE_VLOG(1) << "OnPriorityUpdateEnd; frame_header_: " << frame_header_; + ASSERT_TRUE(EndFrameOfType(Http2FrameType::PRIORITY_UPDATE)) << *this; +} + +void FrameParts::OnUnknownStart(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnUnknownStart: " << header; + ASSERT_FALSE(IsSupportedHttp2FrameType(header.type)) << header; + ASSERT_FALSE(got_start_callback_); + ASSERT_EQ(frame_header_, header); + got_start_callback_ = true; + opt_payload_length_ = header.payload_length; +} + +void FrameParts::OnUnknownPayload(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnUnknownPayload: len=" << len; + ASSERT_FALSE(IsSupportedHttp2FrameType(frame_header_.type)) << *this; + ASSERT_TRUE(got_start_callback_); + ASSERT_FALSE(got_end_callback_); + ASSERT_TRUE(AppendString(absl::string_view(data, len), &payload_, + &opt_payload_length_)); +} + +void FrameParts::OnUnknownEnd() { + QUICHE_VLOG(1) << "OnUnknownEnd; frame_header_: " << frame_header_; + ASSERT_FALSE(IsSupportedHttp2FrameType(frame_header_.type)) << *this; + ASSERT_TRUE(got_start_callback_); + ASSERT_FALSE(got_end_callback_); + got_end_callback_ = true; +} + +void FrameParts::OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) { + QUICHE_VLOG(1) << "OnPaddingTooLong: " << header + << "; missing_length: " << missing_length; + ASSERT_EQ(frame_header_, header); + ASSERT_FALSE(got_end_callback_); + ASSERT_TRUE(FrameIsPadded(header)); + ASSERT_FALSE(opt_pad_length_); + ASSERT_FALSE(opt_missing_length_); + opt_missing_length_ = missing_length; + got_start_callback_ = true; + got_end_callback_ = true; +} + +void FrameParts::OnFrameSizeError(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + ASSERT_EQ(frame_header_, header); + ASSERT_FALSE(got_end_callback_); + ASSERT_FALSE(has_frame_size_error_); + has_frame_size_error_ = true; + got_end_callback_ = true; +} + +void FrameParts::OutputTo(std::ostream& out) const { + out << "FrameParts{\n frame_header_: " << frame_header_ << "\n"; + if (!payload_.empty()) { + out << " payload_=\"" << absl::CHexEscape(payload_) << "\"\n"; + } + if (!padding_.empty()) { + out << " padding_=\"" << absl::CHexEscape(padding_) << "\"\n"; + } + if (!altsvc_origin_.empty()) { + out << " altsvc_origin_=\"" << absl::CHexEscape(altsvc_origin_) << "\"\n"; + } + if (!altsvc_value_.empty()) { + out << " altsvc_value_=\"" << absl::CHexEscape(altsvc_value_) << "\"\n"; + } + if (opt_priority_) { + out << " priority=" << opt_priority_.value() << "\n"; + } + if (opt_rst_stream_error_code_) { + out << " rst_stream=" << opt_rst_stream_error_code_.value() << "\n"; + } + if (opt_push_promise_) { + out << " push_promise=" << opt_push_promise_.value() << "\n"; + } + if (opt_ping_) { + out << " ping=" << opt_ping_.value() << "\n"; + } + if (opt_goaway_) { + out << " goaway=" << opt_goaway_.value() << "\n"; + } + if (opt_window_update_increment_) { + out << " window_update=" << opt_window_update_increment_.value() << "\n"; + } + if (opt_payload_length_) { + out << " payload_length=" << opt_payload_length_.value() << "\n"; + } + if (opt_pad_length_) { + out << " pad_length=" << opt_pad_length_.value() << "\n"; + } + if (opt_missing_length_) { + out << " missing_length=" << opt_missing_length_.value() << "\n"; + } + if (opt_altsvc_origin_length_) { + out << " origin_length=" << opt_altsvc_origin_length_.value() << "\n"; + } + if (opt_altsvc_value_length_) { + out << " value_length=" << opt_altsvc_value_length_.value() << "\n"; + } + if (opt_priority_update_) { + out << " prioritized_stream_id_=" << opt_priority_update_.value() << "\n"; + } + if (has_frame_size_error_) { + out << " has_frame_size_error\n"; + } + if (got_start_callback_) { + out << " got_start_callback\n"; + } + if (got_end_callback_) { + out << " got_end_callback\n"; + } + for (size_t ndx = 0; ndx < settings_.size(); ++ndx) { + out << " setting[" << ndx << "]=" << settings_[ndx]; + } + out << "}"; +} + +AssertionResult FrameParts::StartFrameOfType( + const Http2FrameHeader& header, Http2FrameType expected_frame_type) { + HTTP2_VERIFY_EQ(header.type, expected_frame_type); + HTTP2_VERIFY_FALSE(got_start_callback_); + HTTP2_VERIFY_FALSE(got_end_callback_); + HTTP2_VERIFY_EQ(frame_header_, header); + got_start_callback_ = true; + return AssertionSuccess(); +} + +AssertionResult FrameParts::InFrameOfType(Http2FrameType expected_frame_type) { + HTTP2_VERIFY_TRUE(got_start_callback_); + HTTP2_VERIFY_FALSE(got_end_callback_); + HTTP2_VERIFY_EQ(frame_header_.type, expected_frame_type); + return AssertionSuccess(); +} + +AssertionResult FrameParts::EndFrameOfType(Http2FrameType expected_frame_type) { + HTTP2_VERIFY_SUCCESS(InFrameOfType(expected_frame_type)); + got_end_callback_ = true; + return AssertionSuccess(); +} + +AssertionResult FrameParts::InPaddedFrame() { + HTTP2_VERIFY_TRUE(got_start_callback_); + HTTP2_VERIFY_FALSE(got_end_callback_); + HTTP2_VERIFY_TRUE(FrameIsPadded(frame_header_)); + return AssertionSuccess(); +} + +AssertionResult FrameParts::AppendString(absl::string_view source, + std::string* target, + absl::optional* opt_length) { + target->append(source.data(), source.size()); + if (opt_length != nullptr) { + HTTP2_VERIFY_TRUE(*opt_length) << "Length is not set yet\n" << *this; + HTTP2_VERIFY_LE(target->size(), opt_length->value()) + << "String too large; source.size() = " << source.size() << "\n" + << *this; + } + return ::testing::AssertionSuccess(); +} + +std::ostream& operator<<(std::ostream& out, const FrameParts& v) { + v.OutputTo(out); + return out; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/frame_parts.h b/quiche/http2/test_tools/frame_parts.h new file mode 100644 index 000000000000..d1162284abcd --- /dev/null +++ b/quiche/http2/test_tools/frame_parts.h @@ -0,0 +1,258 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_H_ +#define QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_H_ + +// FrameParts implements Http2FrameDecoderListener, recording the callbacks +// during the decoding of a single frame. It is also used for comparing the +// info that a test expects to be recorded during the decoding of a frame +// with the actual recorded value (i.e. by providing a comparator). + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT FrameParts : public Http2FrameDecoderListener { + public: + // The first callback for every type of frame includes the frame header; this + // is the only constructor used during decoding of a frame. + explicit FrameParts(const Http2FrameHeader& header); + + // For use in tests where the expected frame has a variable size payload. + FrameParts(const Http2FrameHeader& header, absl::string_view payload); + + // For use in tests where the expected frame has a variable size payload + // and may be padded. + FrameParts(const Http2FrameHeader& header, absl::string_view payload, + size_t total_pad_length); + + // Copy constructor. + FrameParts(const FrameParts& header); + + ~FrameParts() override; + + // Returns AssertionSuccess() if they're equal, else AssertionFailure() + // with info about the difference. + ::testing::AssertionResult VerifyEquals(const FrameParts& other) const; + + // Format this FrameParts object. + void OutputTo(std::ostream& out) const; + + // Set the total padding length (0 to 256). + void SetTotalPadLength(size_t total_pad_length); + + // Set the origin and value expected in an ALTSVC frame. + void SetAltSvcExpected(absl::string_view origin, absl::string_view value); + + // Http2FrameDecoderListener methods: + bool OnFrameHeader(const Http2FrameHeader& header) override; + void OnDataStart(const Http2FrameHeader& header) override; + void OnDataPayload(const char* data, size_t len) override; + void OnDataEnd() override; + void OnHeadersStart(const Http2FrameHeader& header) override; + void OnHeadersPriority(const Http2PriorityFields& priority) override; + void OnHpackFragment(const char* data, size_t len) override; + void OnHeadersEnd() override; + void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority) override; + void OnContinuationStart(const Http2FrameHeader& header) override; + void OnContinuationEnd() override; + void OnPadLength(size_t trailing_length) override; + void OnPadding(const char* pad, size_t skipped_length) override; + void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) override; + void OnSettingsStart(const Http2FrameHeader& header) override; + void OnSetting(const Http2SettingFields& setting_fields) override; + void OnSettingsEnd() override; + void OnSettingsAck(const Http2FrameHeader& header) override; + void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) override; + void OnPushPromiseEnd() override; + void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) override; + void OnGoAwayOpaqueData(const char* data, size_t len) override; + void OnGoAwayEnd() override; + void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t increment) override; + void OnAltSvcStart(const Http2FrameHeader& header, size_t origin_length, + size_t value_length) override; + void OnAltSvcOriginData(const char* data, size_t len) override; + void OnAltSvcValueData(const char* data, size_t len) override; + void OnAltSvcEnd() override; + void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) override; + void OnPriorityUpdatePayload(const char* data, size_t len) override; + void OnPriorityUpdateEnd() override; + void OnUnknownStart(const Http2FrameHeader& header) override; + void OnUnknownPayload(const char* data, size_t len) override; + void OnUnknownEnd() override; + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override; + void OnFrameSizeError(const Http2FrameHeader& header) override; + + void AppendSetting(const Http2SettingFields& setting_fields) { + settings_.push_back(setting_fields); + } + + const Http2FrameHeader& GetFrameHeader() const { return frame_header_; } + + absl::optional GetOptPriority() const { + return opt_priority_; + } + absl::optional GetOptRstStreamErrorCode() const { + return opt_rst_stream_error_code_; + } + absl::optional GetOptPushPromise() const { + return opt_push_promise_; + } + absl::optional GetOptPing() const { return opt_ping_; } + absl::optional GetOptGoaway() const { return opt_goaway_; } + absl::optional GetOptPadLength() const { return opt_pad_length_; } + absl::optional GetOptPayloadLength() const { + return opt_payload_length_; + } + absl::optional GetOptMissingLength() const { + return opt_missing_length_; + } + absl::optional GetOptAltsvcOriginLength() const { + return opt_altsvc_origin_length_; + } + absl::optional GetOptAltsvcValueLength() const { + return opt_altsvc_value_length_; + } + absl::optional GetOptWindowUpdateIncrement() const { + return opt_window_update_increment_; + } + bool GetHasFrameSizeError() const { return has_frame_size_error_; } + + void SetOptPriority(absl::optional opt_priority) { + opt_priority_ = opt_priority; + } + void SetOptRstStreamErrorCode( + absl::optional opt_rst_stream_error_code) { + opt_rst_stream_error_code_ = opt_rst_stream_error_code; + } + void SetOptPushPromise( + absl::optional opt_push_promise) { + opt_push_promise_ = opt_push_promise; + } + void SetOptPing(absl::optional opt_ping) { + opt_ping_ = opt_ping; + } + void SetOptGoaway(absl::optional opt_goaway) { + opt_goaway_ = opt_goaway; + } + void SetOptPadLength(absl::optional opt_pad_length) { + opt_pad_length_ = opt_pad_length; + } + void SetOptPayloadLength(absl::optional opt_payload_length) { + opt_payload_length_ = opt_payload_length; + } + void SetOptMissingLength(absl::optional opt_missing_length) { + opt_missing_length_ = opt_missing_length; + } + void SetOptAltsvcOriginLength( + absl::optional opt_altsvc_origin_length) { + opt_altsvc_origin_length_ = opt_altsvc_origin_length; + } + void SetOptAltsvcValueLength(absl::optional opt_altsvc_value_length) { + opt_altsvc_value_length_ = opt_altsvc_value_length; + } + void SetOptWindowUpdateIncrement( + absl::optional opt_window_update_increment) { + opt_window_update_increment_ = opt_window_update_increment; + } + void SetOptPriorityUpdate( + absl::optional priority_update) { + opt_priority_update_ = priority_update; + } + + void SetHasFrameSizeError(bool has_frame_size_error) { + has_frame_size_error_ = has_frame_size_error; + } + + private: + // ASSERT during an On* method that we're handling a frame of type + // expected_frame_type, and have not already received other On* methods + // (i.e. got_start_callback is false). + ::testing::AssertionResult StartFrameOfType( + const Http2FrameHeader& header, Http2FrameType expected_frame_type); + + // ASSERT that StartFrameOfType has already been called with + // expected_frame_type (i.e. got_start_callback has been called), and that + // EndFrameOfType has not yet been called (i.e. got_end_callback is false). + ::testing::AssertionResult InFrameOfType(Http2FrameType expected_frame_type); + + // ASSERT that we're InFrameOfType, and then sets got_end_callback=true. + ::testing::AssertionResult EndFrameOfType(Http2FrameType expected_frame_type); + + // ASSERT that we're in the middle of processing a frame that is padded. + ::testing::AssertionResult InPaddedFrame(); + + // Append source to target. If opt_length is not nullptr, then verifies that + // the optional has a value (i.e. that the necessary On*Start method has been + // called), and that target is not longer than opt_length->value(). + ::testing::AssertionResult AppendString(absl::string_view source, + std::string* target, + absl::optional* opt_length); + + const Http2FrameHeader frame_header_; + + std::string payload_; + std::string padding_; + std::string altsvc_origin_; + std::string altsvc_value_; + + absl::optional opt_priority_; + absl::optional opt_rst_stream_error_code_; + absl::optional opt_push_promise_; + absl::optional opt_ping_; + absl::optional opt_goaway_; + absl::optional opt_priority_update_; + + absl::optional opt_pad_length_; + absl::optional opt_payload_length_; + absl::optional opt_missing_length_; + absl::optional opt_altsvc_origin_length_; + absl::optional opt_altsvc_value_length_; + + absl::optional opt_window_update_increment_; + + bool has_frame_size_error_ = false; + + std::vector settings_; + + // These booleans are not checked by CompareCollectedFrames. + bool got_start_callback_ = false; + bool got_end_callback_ = false; +}; + +QUICHE_NO_EXPORT std::ostream& operator<<(std::ostream& out, + const FrameParts& v); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_H_ diff --git a/quiche/http2/test_tools/frame_parts_collector.cc b/quiche/http2/test_tools/frame_parts_collector.cc new file mode 100644 index 000000000000..2b8f6162c6d1 --- /dev/null +++ b/quiche/http2/test_tools/frame_parts_collector.cc @@ -0,0 +1,112 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/frame_parts_collector.h" + +#include + +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +FramePartsCollector::FramePartsCollector() = default; +FramePartsCollector::~FramePartsCollector() = default; + +void FramePartsCollector::Reset() { + current_frame_.reset(); + collected_frames_.clear(); + expected_header_set_ = false; +} + +const FrameParts* FramePartsCollector::frame(size_t n) const { + if (n < size()) { + return collected_frames_.at(n).get(); + } + QUICHE_CHECK(n == size()); + return current_frame(); +} + +void FramePartsCollector::ExpectFrameHeader(const Http2FrameHeader& header) { + EXPECT_FALSE(IsInProgress()); + EXPECT_FALSE(expected_header_set_) + << "expected_header_: " << expected_header_; + expected_header_ = header; + expected_header_set_ = true; + // OnFrameHeader is called before the flags are scrubbed, but the other + // methods are called after, so scrub the invalid flags from expected_header_. + ScrubFlagsOfHeader(&expected_header_); +} + +void FramePartsCollector::TestExpectedHeader(const Http2FrameHeader& header) { + if (expected_header_set_) { + EXPECT_EQ(header, expected_header_); + expected_header_set_ = false; + } +} + +Http2FrameDecoderListener* FramePartsCollector::StartFrame( + const Http2FrameHeader& header) { + TestExpectedHeader(header); + EXPECT_FALSE(IsInProgress()); + if (current_frame_ == nullptr) { + current_frame_ = std::make_unique(header); + } + return current_frame(); +} + +Http2FrameDecoderListener* FramePartsCollector::StartAndEndFrame( + const Http2FrameHeader& header) { + TestExpectedHeader(header); + EXPECT_FALSE(IsInProgress()); + if (current_frame_ == nullptr) { + current_frame_ = std::make_unique(header); + } + Http2FrameDecoderListener* result = current_frame(); + collected_frames_.push_back(std::move(current_frame_)); + return result; +} + +Http2FrameDecoderListener* FramePartsCollector::CurrentFrame() { + EXPECT_TRUE(IsInProgress()); + if (current_frame_ == nullptr) { + return &failing_listener_; + } + return current_frame(); +} + +Http2FrameDecoderListener* FramePartsCollector::EndFrame() { + EXPECT_TRUE(IsInProgress()); + if (current_frame_ == nullptr) { + return &failing_listener_; + } + Http2FrameDecoderListener* result = current_frame(); + collected_frames_.push_back(std::move(current_frame_)); + return result; +} + +Http2FrameDecoderListener* FramePartsCollector::FrameError( + const Http2FrameHeader& header) { + TestExpectedHeader(header); + if (current_frame_ == nullptr) { + // The decoder may detect an error before making any calls to the listener + // regarding the frame, in which case current_frame_==nullptr and we need + // to create a FrameParts instance. + current_frame_ = std::make_unique(header); + } else { + // Similarly, the decoder may have made calls to the listener regarding the + // frame before detecting the error; for example, the DATA payload decoder + // calls OnDataStart before it can detect padding errors, hence before it + // can call OnPaddingTooLong. + EXPECT_EQ(header, current_frame_->GetFrameHeader()); + } + Http2FrameDecoderListener* result = current_frame(); + collected_frames_.push_back(std::move(current_frame_)); + return result; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/frame_parts_collector.h b/quiche/http2/test_tools/frame_parts_collector.h new file mode 100644 index 000000000000..58adcdb00a97 --- /dev/null +++ b/quiche/http2/test_tools/frame_parts_collector.h @@ -0,0 +1,113 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_COLLECTOR_H_ +#define QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_COLLECTOR_H_ + +// FramePartsCollector is a base class for Http2FrameDecoderListener +// implementations that create one FrameParts instance for each decoded frame. + +#include + +#include +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT FramePartsCollector + : public FailingHttp2FrameDecoderListener { + public: + FramePartsCollector(); + ~FramePartsCollector() override; + + // Toss out the collected data. + void Reset(); + + // Returns true if has started recording the info for a frame and has not yet + // finished doing so. + bool IsInProgress() const { return current_frame_ != nullptr; } + + // Returns the FrameParts instance into which we're currently recording + // callback info if IsInProgress, else nullptr. + const FrameParts* current_frame() const { return current_frame_.get(); } + + // Returns the number of completely collected FrameParts instances. + size_t size() const { return collected_frames_.size(); } + + // Returns the n'th frame, where 0 is the oldest of the collected frames, + // and n==size() is the frame currently being collected, if there is one. + // Returns nullptr if the requested index is not valid. + const FrameParts* frame(size_t n) const; + + protected: + // In support of OnFrameHeader, set the header that we expect to be used in + // the next call. + // TODO(jamessynge): Remove ExpectFrameHeader et al. once done with supporting + // SpdyFramer's exact states. + void ExpectFrameHeader(const Http2FrameHeader& header); + + // For use in implementing On*Start methods of Http2FrameDecoderListener, + // returns a FrameParts instance, which will be newly created if + // IsInProgress==false (which the caller should ensure), else will be the + // current_frame(); never returns nullptr. + // If called when IsInProgress==true, a test failure will be recorded. + Http2FrameDecoderListener* StartFrame(const Http2FrameHeader& header); + + // For use in implementing On* callbacks, such as OnPingAck, that are the only + // call expected for the frame being decoded; not for On*Start methods. + // Returns a FrameParts instance, which will be newly created if + // IsInProgress==false (which the caller should ensure), else will be the + // current_frame(); never returns nullptr. + // If called when IsInProgress==true, a test failure will be recorded. + Http2FrameDecoderListener* StartAndEndFrame(const Http2FrameHeader& header); + + // If IsInProgress==true, returns the FrameParts into which the current + // frame is being recorded; else records a test failure and returns + // failing_listener_, which will record a test failure when any of its + // On* methods is called. + Http2FrameDecoderListener* CurrentFrame(); + + // For use in implementing On*End methods, pushes the current frame onto + // the vector of completed frames, and returns a pointer to it for recording + // the info in the final call. If IsInProgress==false, records a test failure + // and returns failing_listener_, which will record a test failure when any + // of its On* methods is called. + Http2FrameDecoderListener* EndFrame(); + + // For use in implementing OnPaddingTooLong and OnFrameSizeError, is + // equivalent to EndFrame() if IsInProgress==true, else equivalent to + // StartAndEndFrame(). + Http2FrameDecoderListener* FrameError(const Http2FrameHeader& header); + + private: + // Returns the mutable FrameParts instance into which we're currently + // recording callback info if IsInProgress, else nullptr. + FrameParts* current_frame() { return current_frame_.get(); } + + // If expected header is set, verify that it matches the header param. + // TODO(jamessynge): Remove TestExpectedHeader et al. once done + // with supporting SpdyFramer's exact states. + void TestExpectedHeader(const Http2FrameHeader& header); + + std::unique_ptr current_frame_; + std::vector> collected_frames_; + FailingHttp2FrameDecoderListener failing_listener_; + + // TODO(jamessynge): Remove expected_header_ et al. once done with supporting + // SpdyFramer's exact states. + Http2FrameHeader expected_header_; + bool expected_header_set_ = false; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_COLLECTOR_H_ diff --git a/quiche/http2/test_tools/frame_parts_collector_listener.cc b/quiche/http2/test_tools/frame_parts_collector_listener.cc new file mode 100644 index 000000000000..21327a0c3da1 --- /dev/null +++ b/quiche/http2/test_tools/frame_parts_collector_listener.cc @@ -0,0 +1,247 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/frame_parts_collector_listener.h" + +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +bool FramePartsCollectorListener::OnFrameHeader( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnFrameHeader: " << header; + ExpectFrameHeader(header); + return true; +} + +void FramePartsCollectorListener::OnDataStart(const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnDataStart: " << header; + StartFrame(header)->OnDataStart(header); +} + +void FramePartsCollectorListener::OnDataPayload(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnDataPayload: len=" << len; + CurrentFrame()->OnDataPayload(data, len); +} + +void FramePartsCollectorListener::OnDataEnd() { + QUICHE_VLOG(1) << "OnDataEnd"; + EndFrame()->OnDataEnd(); +} + +void FramePartsCollectorListener::OnHeadersStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnHeadersStart: " << header; + StartFrame(header)->OnHeadersStart(header); +} + +void FramePartsCollectorListener::OnHeadersPriority( + const Http2PriorityFields& priority) { + QUICHE_VLOG(1) << "OnHeadersPriority: " << priority; + CurrentFrame()->OnHeadersPriority(priority); +} + +void FramePartsCollectorListener::OnHpackFragment(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnHpackFragment: len=" << len; + CurrentFrame()->OnHpackFragment(data, len); +} + +void FramePartsCollectorListener::OnHeadersEnd() { + QUICHE_VLOG(1) << "OnHeadersEnd"; + EndFrame()->OnHeadersEnd(); +} + +void FramePartsCollectorListener::OnPriorityFrame( + const Http2FrameHeader& header, + const Http2PriorityFields& priority_fields) { + QUICHE_VLOG(1) << "OnPriority: " << header << "; " << priority_fields; + StartAndEndFrame(header)->OnPriorityFrame(header, priority_fields); +} + +void FramePartsCollectorListener::OnContinuationStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnContinuationStart: " << header; + StartFrame(header)->OnContinuationStart(header); +} + +void FramePartsCollectorListener::OnContinuationEnd() { + QUICHE_VLOG(1) << "OnContinuationEnd"; + EndFrame()->OnContinuationEnd(); +} + +void FramePartsCollectorListener::OnPadLength(size_t pad_length) { + QUICHE_VLOG(1) << "OnPadLength: " << pad_length; + CurrentFrame()->OnPadLength(pad_length); +} + +void FramePartsCollectorListener::OnPadding(const char* padding, + size_t skipped_length) { + QUICHE_VLOG(1) << "OnPadding: " << skipped_length; + CurrentFrame()->OnPadding(padding, skipped_length); +} + +void FramePartsCollectorListener::OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) { + QUICHE_VLOG(1) << "OnRstStream: " << header << "; error_code=" << error_code; + StartAndEndFrame(header)->OnRstStream(header, error_code); +} + +void FramePartsCollectorListener::OnSettingsStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnSettingsStart: " << header; + EXPECT_EQ(Http2FrameType::SETTINGS, header.type) << header; + EXPECT_EQ(Http2FrameFlag(), header.flags) << header; + StartFrame(header)->OnSettingsStart(header); +} + +void FramePartsCollectorListener::OnSetting( + const Http2SettingFields& setting_fields) { + QUICHE_VLOG(1) << "Http2SettingFields: setting_fields=" << setting_fields; + CurrentFrame()->OnSetting(setting_fields); +} + +void FramePartsCollectorListener::OnSettingsEnd() { + QUICHE_VLOG(1) << "OnSettingsEnd"; + EndFrame()->OnSettingsEnd(); +} + +void FramePartsCollectorListener::OnSettingsAck( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnSettingsAck: " << header; + StartAndEndFrame(header)->OnSettingsAck(header); +} + +void FramePartsCollectorListener::OnPushPromiseStart( + const Http2FrameHeader& header, const Http2PushPromiseFields& promise, + size_t total_padding_length) { + QUICHE_VLOG(1) << "OnPushPromiseStart header: " << header + << " promise: " << promise + << " total_padding_length: " << total_padding_length; + EXPECT_EQ(Http2FrameType::PUSH_PROMISE, header.type); + StartFrame(header)->OnPushPromiseStart(header, promise, total_padding_length); +} + +void FramePartsCollectorListener::OnPushPromiseEnd() { + QUICHE_VLOG(1) << "OnPushPromiseEnd"; + EndFrame()->OnPushPromiseEnd(); +} + +void FramePartsCollectorListener::OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_VLOG(1) << "OnPing: " << header << "; " << ping; + StartAndEndFrame(header)->OnPing(header, ping); +} + +void FramePartsCollectorListener::OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_VLOG(1) << "OnPingAck: " << header << "; " << ping; + StartAndEndFrame(header)->OnPingAck(header, ping); +} + +void FramePartsCollectorListener::OnGoAwayStart( + const Http2FrameHeader& header, const Http2GoAwayFields& goaway) { + QUICHE_VLOG(1) << "OnGoAwayStart header: " << header + << "; goaway: " << goaway; + StartFrame(header)->OnGoAwayStart(header, goaway); +} + +void FramePartsCollectorListener::OnGoAwayOpaqueData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnGoAwayOpaqueData: len=" << len; + CurrentFrame()->OnGoAwayOpaqueData(data, len); +} + +void FramePartsCollectorListener::OnGoAwayEnd() { + QUICHE_VLOG(1) << "OnGoAwayEnd"; + EndFrame()->OnGoAwayEnd(); +} + +void FramePartsCollectorListener::OnWindowUpdate( + const Http2FrameHeader& header, uint32_t window_size_increment) { + QUICHE_VLOG(1) << "OnWindowUpdate: " << header + << "; window_size_increment=" << window_size_increment; + EXPECT_EQ(Http2FrameType::WINDOW_UPDATE, header.type); + StartAndEndFrame(header)->OnWindowUpdate(header, window_size_increment); +} + +void FramePartsCollectorListener::OnAltSvcStart(const Http2FrameHeader& header, + size_t origin_length, + size_t value_length) { + QUICHE_VLOG(1) << "OnAltSvcStart header: " << header + << "; origin_length=" << origin_length + << "; value_length=" << value_length; + StartFrame(header)->OnAltSvcStart(header, origin_length, value_length); +} + +void FramePartsCollectorListener::OnAltSvcOriginData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnAltSvcOriginData: len=" << len; + CurrentFrame()->OnAltSvcOriginData(data, len); +} + +void FramePartsCollectorListener::OnAltSvcValueData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnAltSvcValueData: len=" << len; + CurrentFrame()->OnAltSvcValueData(data, len); +} + +void FramePartsCollectorListener::OnAltSvcEnd() { + QUICHE_VLOG(1) << "OnAltSvcEnd"; + EndFrame()->OnAltSvcEnd(); +} + +void FramePartsCollectorListener::OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) { + QUICHE_VLOG(1) << "OnPriorityUpdateStart header: " << header + << "; priority_update=" << priority_update; + StartFrame(header)->OnPriorityUpdateStart(header, priority_update); +} + +void FramePartsCollectorListener::OnPriorityUpdatePayload(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnPriorityUpdatePayload: len=" << len; + CurrentFrame()->OnPriorityUpdatePayload(data, len); +} + +void FramePartsCollectorListener::OnPriorityUpdateEnd() { + QUICHE_VLOG(1) << "OnPriorityUpdateEnd"; + EndFrame()->OnPriorityUpdateEnd(); +} + +void FramePartsCollectorListener::OnUnknownStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnUnknownStart: " << header; + StartFrame(header)->OnUnknownStart(header); +} + +void FramePartsCollectorListener::OnUnknownPayload(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnUnknownPayload: len=" << len; + CurrentFrame()->OnUnknownPayload(data, len); +} + +void FramePartsCollectorListener::OnUnknownEnd() { + QUICHE_VLOG(1) << "OnUnknownEnd"; + EndFrame()->OnUnknownEnd(); +} + +void FramePartsCollectorListener::OnPaddingTooLong( + const Http2FrameHeader& header, size_t missing_length) { + QUICHE_VLOG(1) << "OnPaddingTooLong: " << header + << " missing_length: " << missing_length; + EndFrame()->OnPaddingTooLong(header, missing_length); +} + +void FramePartsCollectorListener::OnFrameSizeError( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + FrameError(header)->OnFrameSizeError(header); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/frame_parts_collector_listener.h b/quiche/http2/test_tools/frame_parts_collector_listener.h new file mode 100644 index 000000000000..a7412d48b23c --- /dev/null +++ b/quiche/http2/test_tools/frame_parts_collector_listener.h @@ -0,0 +1,91 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_COLLECTOR_LISTENER_H_ +#define QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_COLLECTOR_LISTENER_H_ + +// FramePartsCollectorListener extends FramePartsCollector with an +// implementation of every method of Http2FrameDecoderListener; it is +// essentially the union of all the Listener classes in the tests of the +// payload decoders (i.e. in ./payload_decoders/*_test.cc files), with the +// addition of the OnFrameHeader method. +// FramePartsCollectorListener supports tests of Http2FrameDecoder. + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/frame_parts_collector.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT FramePartsCollectorListener + : public FramePartsCollector { + public: + FramePartsCollectorListener() {} + ~FramePartsCollectorListener() override {} + + // TODO(jamessynge): Remove OnFrameHeader once done with supporting + // SpdyFramer's exact states. + bool OnFrameHeader(const Http2FrameHeader& header) override; + void OnDataStart(const Http2FrameHeader& header) override; + void OnDataPayload(const char* data, size_t len) override; + void OnDataEnd() override; + void OnHeadersStart(const Http2FrameHeader& header) override; + void OnHeadersPriority(const Http2PriorityFields& priority) override; + void OnHpackFragment(const char* data, size_t len) override; + void OnHeadersEnd() override; + void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority_fields) override; + void OnContinuationStart(const Http2FrameHeader& header) override; + void OnContinuationEnd() override; + void OnPadLength(size_t pad_length) override; + void OnPadding(const char* padding, size_t skipped_length) override; + void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) override; + void OnSettingsStart(const Http2FrameHeader& header) override; + void OnSetting(const Http2SettingFields& setting_fields) override; + void OnSettingsEnd() override; + void OnSettingsAck(const Http2FrameHeader& header) override; + void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) override; + void OnPushPromiseEnd() override; + void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) override; + void OnGoAwayOpaqueData(const char* data, size_t len) override; + void OnGoAwayEnd() override; + void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t window_size_increment) override; + void OnAltSvcStart(const Http2FrameHeader& header, size_t origin_length, + size_t value_length) override; + void OnAltSvcOriginData(const char* data, size_t len) override; + void OnAltSvcValueData(const char* data, size_t len) override; + void OnAltSvcEnd() override; + void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) override; + void OnPriorityUpdatePayload(const char* data, size_t len) override; + void OnPriorityUpdateEnd() override; + void OnUnknownStart(const Http2FrameHeader& header) override; + void OnUnknownPayload(const char* data, size_t len) override; + void OnUnknownEnd() override; + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override; + void OnFrameSizeError(const Http2FrameHeader& header) override; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_FRAME_PARTS_COLLECTOR_LISTENER_H_ diff --git a/quiche/http2/test_tools/hpack_block_builder.cc b/quiche/http2/test_tools/hpack_block_builder.cc new file mode 100644 index 000000000000..bd9119a2fe71 --- /dev/null +++ b/quiche/http2/test_tools/hpack_block_builder.cc @@ -0,0 +1,66 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_block_builder.h" + +#include "quiche/http2/hpack/varint/hpack_varint_encoder.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +void HpackBlockBuilder::AppendHighBitsAndVarint(uint8_t high_bits, + uint8_t prefix_length, + uint64_t varint) { + EXPECT_LE(3, prefix_length); + EXPECT_LE(prefix_length, 8); + + HpackVarintEncoder::Encode(high_bits, prefix_length, varint, &buffer_); +} + +void HpackBlockBuilder::AppendEntryTypeAndVarint(HpackEntryType entry_type, + uint64_t varint) { + uint8_t high_bits; + uint8_t prefix_length; // Bits of the varint prefix in the first byte. + switch (entry_type) { + case HpackEntryType::kIndexedHeader: + high_bits = 0x80; + prefix_length = 7; + break; + case HpackEntryType::kDynamicTableSizeUpdate: + high_bits = 0x20; + prefix_length = 5; + break; + case HpackEntryType::kIndexedLiteralHeader: + high_bits = 0x40; + prefix_length = 6; + break; + case HpackEntryType::kUnindexedLiteralHeader: + high_bits = 0x00; + prefix_length = 4; + break; + case HpackEntryType::kNeverIndexedLiteralHeader: + high_bits = 0x10; + prefix_length = 4; + break; + default: + QUICHE_BUG(http2_bug_110_1) << "Unreached, entry_type=" << entry_type; + high_bits = 0; + prefix_length = 0; + break; + } + AppendHighBitsAndVarint(high_bits, prefix_length, varint); +} + +void HpackBlockBuilder::AppendString(bool is_huffman_encoded, + absl::string_view str) { + uint8_t high_bits = is_huffman_encoded ? 0x80 : 0; + uint8_t prefix_length = 7; + AppendHighBitsAndVarint(high_bits, prefix_length, str.size()); + buffer_.append(str.data(), str.size()); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_block_builder.h b/quiche/http2/test_tools/hpack_block_builder.h new file mode 100644 index 000000000000..5aec637d3c7b --- /dev/null +++ b/quiche/http2/test_tools/hpack_block_builder.h @@ -0,0 +1,97 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HPACK_BLOCK_BUILDER_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HPACK_BLOCK_BUILDER_H_ + +// HpackBlockBuilder builds wire-format HPACK blocks (or fragments thereof) +// from components. + +// Supports very large varints to enable tests to create HPACK blocks with +// values that the decoder should reject. For now, this is only intended for +// use in tests, and thus has EXPECT* in the code. If desired to use it in an +// encoder, it will need optimization work, especially w.r.t memory mgmt, and +// the EXPECT* will need to be removed or replaced with QUICHE_DCHECKs. And of +// course the support for very large varints will not be needed in production +// code. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT HpackBlockBuilder { + public: + explicit HpackBlockBuilder(absl::string_view initial_contents) + : buffer_(initial_contents.data(), initial_contents.size()) {} + HpackBlockBuilder() {} + ~HpackBlockBuilder() {} + + size_t size() const { return buffer_.size(); } + const std::string& buffer() const { return buffer_; } + + //---------------------------------------------------------------------------- + // Methods for appending a valid HPACK entry. + + void AppendIndexedHeader(uint64_t index) { + AppendEntryTypeAndVarint(HpackEntryType::kIndexedHeader, index); + } + + void AppendDynamicTableSizeUpdate(uint64_t size) { + AppendEntryTypeAndVarint(HpackEntryType::kDynamicTableSizeUpdate, size); + } + + void AppendNameIndexAndLiteralValue(HpackEntryType entry_type, + uint64_t name_index, + bool value_is_huffman_encoded, + absl::string_view value) { + // name_index==0 would indicate that the entry includes a literal name. + // Call AppendLiteralNameAndValue in that case. + EXPECT_NE(0u, name_index); + AppendEntryTypeAndVarint(entry_type, name_index); + AppendString(value_is_huffman_encoded, value); + } + + void AppendLiteralNameAndValue(HpackEntryType entry_type, + bool name_is_huffman_encoded, + absl::string_view name, + bool value_is_huffman_encoded, + absl::string_view value) { + AppendEntryTypeAndVarint(entry_type, 0); + AppendString(name_is_huffman_encoded, name); + AppendString(value_is_huffman_encoded, value); + } + + //---------------------------------------------------------------------------- + // Primitive methods that are not guaranteed to write a valid HPACK entry. + + // Appends a varint, with the specified high_bits above the prefix of the + // varint. + void AppendHighBitsAndVarint(uint8_t high_bits, uint8_t prefix_length, + uint64_t varint); + + // Append the start of an HPACK entry for the specified type, with the + // specified varint. + void AppendEntryTypeAndVarint(HpackEntryType entry_type, uint64_t varint); + + // Append a header string (i.e. a header name or value) in HPACK format. + // Does NOT perform Huffman encoding. + void AppendString(bool is_huffman_encoded, absl::string_view str); + + private: + std::string buffer_; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HPACK_BLOCK_BUILDER_H_ diff --git a/quiche/http2/test_tools/hpack_block_builder_test.cc b/quiche/http2/test_tools/hpack_block_builder_test.cc new file mode 100644 index 000000000000..764204972452 --- /dev/null +++ b/quiche/http2/test_tools/hpack_block_builder_test.cc @@ -0,0 +1,169 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_block_builder.h" + +#include "absl/strings/escaping.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { +const bool kUncompressed = false; +const bool kCompressed = true; + +// TODO(jamessynge): Once static table code is checked in, switch to using +// constants from there. +const uint32_t kStaticTableMethodGET = 2; +const uint32_t kStaticTablePathSlash = 4; +const uint32_t kStaticTableSchemeHttp = 6; + +// Tests of encoding per the RFC. See: +// http://httpwg.org/specs/rfc7541.html#header.field.representation.examples +// The expected values have been copied from the RFC. +TEST(HpackBlockBuilderTest, ExamplesFromSpecC2) { + { + HpackBlockBuilder b; + b.AppendLiteralNameAndValue(HpackEntryType::kIndexedLiteralHeader, + kUncompressed, "custom-key", kUncompressed, + "custom-header"); + EXPECT_EQ(26u, b.size()); + + const char kExpected[] = + "\x40" // == Literal indexed == + "\x0a" // Name length (10) + "custom-key" // Name + "\x0d" // Value length (13) + "custom-header"; // Value + EXPECT_EQ(kExpected, b.buffer()); + } + { + HpackBlockBuilder b; + b.AppendNameIndexAndLiteralValue(HpackEntryType::kUnindexedLiteralHeader, 4, + kUncompressed, "/sample/path"); + EXPECT_EQ(14u, b.size()); + + const char kExpected[] = + "\x04" // == Literal unindexed, name index 0x04 == + "\x0c" // Value length (12) + "/sample/path"; // Value + EXPECT_EQ(kExpected, b.buffer()); + } + { + HpackBlockBuilder b; + b.AppendLiteralNameAndValue(HpackEntryType::kNeverIndexedLiteralHeader, + kUncompressed, "password", kUncompressed, + "secret"); + EXPECT_EQ(17u, b.size()); + + const char kExpected[] = + "\x10" // == Literal never indexed == + "\x08" // Name length (8) + "password" // Name + "\x06" // Value length (6) + "secret"; // Value + EXPECT_EQ(kExpected, b.buffer()); + } + { + HpackBlockBuilder b; + b.AppendIndexedHeader(2); + EXPECT_EQ(1u, b.size()); + + const char kExpected[] = "\x82"; // == Indexed (2) == + EXPECT_EQ(kExpected, b.buffer()); + } +} + +// Tests of encoding per the RFC. See: +// http://httpwg.org/specs/rfc7541.html#request.examples.without.huffman.coding +TEST(HpackBlockBuilderTest, ExamplesFromSpecC3) { + { + // Header block to encode: + // :method: GET + // :scheme: http + // :path: / + // :authority: www.example.com + HpackBlockBuilder b; + b.AppendIndexedHeader(2); // :method: GET + b.AppendIndexedHeader(6); // :scheme: http + b.AppendIndexedHeader(4); // :path: / + b.AppendNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, 1, + kUncompressed, "www.example.com"); + EXPECT_EQ(20u, b.size()); + + // Hex dump of encoded data (copied from RFC): + // 0x0000: 8286 8441 0f77 7777 2e65 7861 6d70 6c65 ...A.www.example + // 0x0010: 2e63 6f6d .com + + const std::string expected = + absl::HexStringToBytes("828684410f7777772e6578616d706c652e636f6d"); + EXPECT_EQ(expected, b.buffer()); + } +} + +// Tests of encoding per the RFC. See: +// http://httpwg.org/specs/rfc7541.html#request.examples.with.huffman.coding +TEST(HpackBlockBuilderTest, ExamplesFromSpecC4) { + { + // Header block to encode: + // :method: GET + // :scheme: http + // :path: / + // :authority: www.example.com (Huffman encoded) + HpackBlockBuilder b; + b.AppendIndexedHeader(kStaticTableMethodGET); + b.AppendIndexedHeader(kStaticTableSchemeHttp); + b.AppendIndexedHeader(kStaticTablePathSlash); + const char kHuffmanWwwExampleCom[] = {'\xf1', '\xe3', '\xc2', '\xe5', + '\xf2', '\x3a', '\x6b', '\xa0', + '\xab', '\x90', '\xf4', '\xff'}; + b.AppendNameIndexAndLiteralValue( + HpackEntryType::kIndexedLiteralHeader, 1, kCompressed, + absl::string_view(kHuffmanWwwExampleCom, sizeof kHuffmanWwwExampleCom)); + EXPECT_EQ(17u, b.size()); + + // Hex dump of encoded data (copied from RFC): + // 0x0000: 8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4 ...A......:k.... + // 0x0010: ff . + + const std::string expected = + absl::HexStringToBytes("828684418cf1e3c2e5f23a6ba0ab90f4ff"); + EXPECT_EQ(expected, b.buffer()); + } +} + +TEST(HpackBlockBuilderTest, DynamicTableSizeUpdate) { + { + HpackBlockBuilder b; + b.AppendDynamicTableSizeUpdate(0); + EXPECT_EQ(1u, b.size()); + + const char kData[] = {'\x20'}; + absl::string_view expected(kData, sizeof kData); + EXPECT_EQ(expected, b.buffer()); + } + { + HpackBlockBuilder b; + b.AppendDynamicTableSizeUpdate(4096); // The default size. + EXPECT_EQ(3u, b.size()); + + const char kData[] = {'\x3f', '\xe1', '\x1f'}; + absl::string_view expected(kData, sizeof kData); + EXPECT_EQ(expected, b.buffer()); + } + { + HpackBlockBuilder b; + b.AppendDynamicTableSizeUpdate(1000000000000); // A very large value. + EXPECT_EQ(7u, b.size()); + + const char kData[] = {'\x3f', '\xe1', '\x9f', '\x94', + '\xa5', '\x8d', '\x1d'}; + absl::string_view expected(kData, sizeof kData); + EXPECT_EQ(expected, b.buffer()); + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_block_collector.cc b/quiche/http2/test_tools/hpack_block_collector.cc new file mode 100644 index 000000000000..643806182cfb --- /dev/null +++ b/quiche/http2/test_tools/hpack_block_collector.cc @@ -0,0 +1,143 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_block_collector.h" + +#include +#include + +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" + +using ::testing::AssertionResult; +using ::testing::AssertionSuccess; + +namespace http2 { +namespace test { + +HpackBlockCollector::HpackBlockCollector() = default; +HpackBlockCollector::HpackBlockCollector(const HpackBlockCollector& other) + : pending_entry_(other.pending_entry_), entries_(other.entries_) {} +HpackBlockCollector::~HpackBlockCollector() = default; + +void HpackBlockCollector::OnIndexedHeader(size_t index) { + pending_entry_.OnIndexedHeader(index); + PushPendingEntry(); +} +void HpackBlockCollector::OnDynamicTableSizeUpdate(size_t size) { + pending_entry_.OnDynamicTableSizeUpdate(size); + PushPendingEntry(); +} +void HpackBlockCollector::OnStartLiteralHeader(HpackEntryType header_type, + size_t maybe_name_index) { + pending_entry_.OnStartLiteralHeader(header_type, maybe_name_index); +} +void HpackBlockCollector::OnNameStart(bool huffman_encoded, size_t len) { + pending_entry_.OnNameStart(huffman_encoded, len); +} +void HpackBlockCollector::OnNameData(const char* data, size_t len) { + pending_entry_.OnNameData(data, len); +} +void HpackBlockCollector::OnNameEnd() { pending_entry_.OnNameEnd(); } +void HpackBlockCollector::OnValueStart(bool huffman_encoded, size_t len) { + pending_entry_.OnValueStart(huffman_encoded, len); +} +void HpackBlockCollector::OnValueData(const char* data, size_t len) { + pending_entry_.OnValueData(data, len); +} +void HpackBlockCollector::OnValueEnd() { + pending_entry_.OnValueEnd(); + PushPendingEntry(); +} + +void HpackBlockCollector::PushPendingEntry() { + EXPECT_TRUE(pending_entry_.IsComplete()); + QUICHE_DVLOG(2) << "PushPendingEntry: " << pending_entry_; + entries_.push_back(pending_entry_); + EXPECT_TRUE(entries_.back().IsComplete()); + pending_entry_.Clear(); +} +void HpackBlockCollector::Clear() { + pending_entry_.Clear(); + entries_.clear(); +} + +void HpackBlockCollector::ExpectIndexedHeader(size_t index) { + entries_.push_back( + HpackEntryCollector(HpackEntryType::kIndexedHeader, index)); +} +void HpackBlockCollector::ExpectDynamicTableSizeUpdate(size_t size) { + entries_.push_back( + HpackEntryCollector(HpackEntryType::kDynamicTableSizeUpdate, size)); +} +void HpackBlockCollector::ExpectNameIndexAndLiteralValue( + HpackEntryType type, size_t index, bool value_huffman, + const std::string& value) { + entries_.push_back(HpackEntryCollector(type, index, value_huffman, value)); +} +void HpackBlockCollector::ExpectLiteralNameAndValue(HpackEntryType type, + bool name_huffman, + const std::string& name, + bool value_huffman, + const std::string& value) { + entries_.push_back( + HpackEntryCollector(type, name_huffman, name, value_huffman, value)); +} + +void HpackBlockCollector::ShuffleEntries(Http2Random* rng) { + std::shuffle(entries_.begin(), entries_.end(), *rng); +} + +void HpackBlockCollector::AppendToHpackBlockBuilder( + HpackBlockBuilder* hbb) const { + QUICHE_CHECK(IsNotPending()); + for (const auto& entry : entries_) { + entry.AppendToHpackBlockBuilder(hbb); + } +} + +AssertionResult HpackBlockCollector::ValidateSoleIndexedHeader( + size_t ndx) const { + HTTP2_VERIFY_TRUE(pending_entry_.IsClear()); + HTTP2_VERIFY_EQ(1u, entries_.size()); + HTTP2_VERIFY_TRUE(entries_.front().ValidateIndexedHeader(ndx)); + return AssertionSuccess(); +} +AssertionResult HpackBlockCollector::ValidateSoleLiteralValueHeader( + HpackEntryType expected_type, size_t expected_index, + bool expected_value_huffman, absl::string_view expected_value) const { + HTTP2_VERIFY_TRUE(pending_entry_.IsClear()); + HTTP2_VERIFY_EQ(1u, entries_.size()); + HTTP2_VERIFY_TRUE(entries_.front().ValidateLiteralValueHeader( + expected_type, expected_index, expected_value_huffman, expected_value)); + return AssertionSuccess(); +} +AssertionResult HpackBlockCollector::ValidateSoleLiteralNameValueHeader( + HpackEntryType expected_type, bool expected_name_huffman, + absl::string_view expected_name, bool expected_value_huffman, + absl::string_view expected_value) const { + HTTP2_VERIFY_TRUE(pending_entry_.IsClear()); + HTTP2_VERIFY_EQ(1u, entries_.size()); + HTTP2_VERIFY_TRUE(entries_.front().ValidateLiteralNameValueHeader( + expected_type, expected_name_huffman, expected_name, + expected_value_huffman, expected_value)); + return AssertionSuccess(); +} +AssertionResult HpackBlockCollector::ValidateSoleDynamicTableSizeUpdate( + size_t size) const { + HTTP2_VERIFY_TRUE(pending_entry_.IsClear()); + HTTP2_VERIFY_EQ(1u, entries_.size()); + HTTP2_VERIFY_TRUE(entries_.front().ValidateDynamicTableSizeUpdate(size)); + return AssertionSuccess(); +} + +AssertionResult HpackBlockCollector::VerifyEq( + const HpackBlockCollector& that) const { + HTTP2_VERIFY_EQ(pending_entry_, that.pending_entry_); + HTTP2_VERIFY_EQ(entries_, that.entries_); + return AssertionSuccess(); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_block_collector.h b/quiche/http2/test_tools/hpack_block_collector.h new file mode 100644 index 000000000000..fe49d093cccf --- /dev/null +++ b/quiche/http2/test_tools/hpack_block_collector.h @@ -0,0 +1,122 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HPACK_BLOCK_COLLECTOR_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HPACK_BLOCK_COLLECTOR_H_ + +// HpackBlockCollector implements HpackEntryDecoderListener in order to record +// the calls using HpackEntryCollector instances (one per HPACK entry). This +// supports testing of HpackBlockDecoder, which decodes entire HPACK blocks. +// +// In addition to implementing the callback methods, HpackBlockCollector also +// supports comparing two HpackBlockCollector instances (i.e. an expected and +// an actual), or a sole HPACK entry against an expected value. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/hpack_entry_collector.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT HpackBlockCollector : public HpackEntryDecoderListener { + public: + HpackBlockCollector(); + HpackBlockCollector(const HpackBlockCollector& other); + ~HpackBlockCollector() override; + + // Implementations of HpackEntryDecoderListener, forwarding to pending_entry_, + // an HpackEntryCollector for the "in-progress" HPACK entry. OnIndexedHeader + // and OnDynamicTableSizeUpdate are pending only for that one call, while + // OnStartLiteralHeader is followed by many calls, ending with OnValueEnd. + // Once all the calls for one HPACK entry have been received, PushPendingEntry + // is used to append the pending_entry_ entry to the collected entries_. + void OnIndexedHeader(size_t index) override; + void OnDynamicTableSizeUpdate(size_t size) override; + void OnStartLiteralHeader(HpackEntryType header_type, + size_t maybe_name_index) override; + void OnNameStart(bool huffman_encoded, size_t len) override; + void OnNameData(const char* data, size_t len) override; + void OnNameEnd() override; + void OnValueStart(bool huffman_encoded, size_t len) override; + void OnValueData(const char* data, size_t len) override; + void OnValueEnd() override; + + // Methods for creating a set of expectations (i.e. HPACK entries to compare + // against those collected by another instance of HpackBlockCollector). + + // Add an HPACK entry for an indexed header. + void ExpectIndexedHeader(size_t index); + + // Add an HPACK entry for a dynamic table size update. + void ExpectDynamicTableSizeUpdate(size_t size); + + // Add an HPACK entry for a header entry with an index for the name, and a + // literal value. + void ExpectNameIndexAndLiteralValue(HpackEntryType type, size_t index, + bool value_huffman, + const std::string& value); + + // Add an HPACK entry for a header entry with a literal name and value. + void ExpectLiteralNameAndValue(HpackEntryType type, bool name_huffman, + const std::string& name, bool value_huffman, + const std::string& value); + + // Shuffle the entries, in support of generating an HPACK block of entries + // in some random order. + void ShuffleEntries(Http2Random* rng); + + // Serialize entries_ to the HpackBlockBuilder. + void AppendToHpackBlockBuilder(HpackBlockBuilder* hbb) const; + + // Return AssertionSuccess if there is just one entry, and it is an + // Indexed Header with the specified index. + ::testing::AssertionResult ValidateSoleIndexedHeader(size_t ndx) const; + + // Return AssertionSuccess if there is just one entry, and it is a + // Dynamic Table Size Update with the specified size. + ::testing::AssertionResult ValidateSoleDynamicTableSizeUpdate( + size_t size) const; + + // Return AssertionSuccess if there is just one entry, and it is a Header + // entry with an index for the name and a literal value. + ::testing::AssertionResult ValidateSoleLiteralValueHeader( + HpackEntryType expected_type, size_t expected_index, + bool expected_value_huffman, absl::string_view expected_value) const; + + // Return AssertionSuccess if there is just one entry, and it is a Header + // with a literal name and literal value. + ::testing::AssertionResult ValidateSoleLiteralNameValueHeader( + HpackEntryType expected_type, bool expected_name_huffman, + absl::string_view expected_name, bool expected_value_huffman, + absl::string_view expected_value) const; + + bool IsNotPending() const { return pending_entry_.IsClear(); } + bool IsClear() const { return IsNotPending() && entries_.empty(); } + void Clear(); + + ::testing::AssertionResult VerifyEq(const HpackBlockCollector& that) const; + + private: + // Push the value of pending_entry_ onto entries_, and clear pending_entry_. + // The pending_entry_ must be complete. + void PushPendingEntry(); + + HpackEntryCollector pending_entry_; + std::vector entries_; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HPACK_BLOCK_COLLECTOR_H_ diff --git a/quiche/http2/test_tools/hpack_entry_collector.cc b/quiche/http2/test_tools/hpack_entry_collector.cc new file mode 100644 index 000000000000..77cfb6ecb31a --- /dev/null +++ b/quiche/http2/test_tools/hpack_entry_collector.cc @@ -0,0 +1,293 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_entry_collector.h" + +#include "absl/strings/str_cat.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/test_tools/hpack_string_collector.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; + +namespace http2 { +namespace test { +namespace { + +const HpackEntryType kInvalidHeaderType = static_cast(99); +const size_t kInvalidIndex = 99999999; + +} // namespace + +HpackEntryCollector::HpackEntryCollector() { Clear(); } + +HpackEntryCollector::HpackEntryCollector(const HpackEntryCollector& other) = + default; + +HpackEntryCollector::HpackEntryCollector(HpackEntryType type, + size_t index_or_size) + : header_type_(type), index_(index_or_size), started_(true), ended_(true) {} +HpackEntryCollector::HpackEntryCollector(HpackEntryType type, size_t index, + bool value_huffman, + const std::string& value) + : header_type_(type), + index_(index), + value_(value, value_huffman), + started_(true), + ended_(true) {} +HpackEntryCollector::HpackEntryCollector(HpackEntryType type, bool name_huffman, + const std::string& name, + bool value_huffman, + const std::string& value) + : header_type_(type), + index_(0), + name_(name, name_huffman), + value_(value, value_huffman), + started_(true), + ended_(true) {} + +HpackEntryCollector::~HpackEntryCollector() = default; + +void HpackEntryCollector::OnIndexedHeader(size_t index) { + ASSERT_FALSE(started_); + ASSERT_TRUE(IsClear()) << ToString(); + Init(HpackEntryType::kIndexedHeader, index); + ended_ = true; +} +void HpackEntryCollector::OnStartLiteralHeader(HpackEntryType header_type, + size_t maybe_name_index) { + ASSERT_FALSE(started_); + ASSERT_TRUE(IsClear()) << ToString(); + Init(header_type, maybe_name_index); +} +void HpackEntryCollector::OnNameStart(bool huffman_encoded, size_t len) { + ASSERT_TRUE(started_); + ASSERT_FALSE(ended_); + ASSERT_FALSE(IsClear()); + ASSERT_TRUE(LiteralNameExpected()) << ToString(); + name_.OnStringStart(huffman_encoded, len); +} +void HpackEntryCollector::OnNameData(const char* data, size_t len) { + ASSERT_TRUE(started_); + ASSERT_FALSE(ended_); + ASSERT_TRUE(LiteralNameExpected()) << ToString(); + ASSERT_TRUE(name_.IsInProgress()); + name_.OnStringData(data, len); +} +void HpackEntryCollector::OnNameEnd() { + ASSERT_TRUE(started_); + ASSERT_FALSE(ended_); + ASSERT_TRUE(LiteralNameExpected()) << ToString(); + ASSERT_TRUE(name_.IsInProgress()); + name_.OnStringEnd(); +} +void HpackEntryCollector::OnValueStart(bool huffman_encoded, size_t len) { + ASSERT_TRUE(started_); + ASSERT_FALSE(ended_); + if (LiteralNameExpected()) { + ASSERT_TRUE(name_.HasEnded()); + } + ASSERT_TRUE(LiteralValueExpected()) << ToString(); + ASSERT_TRUE(value_.IsClear()) << value_.ToString(); + value_.OnStringStart(huffman_encoded, len); +} +void HpackEntryCollector::OnValueData(const char* data, size_t len) { + ASSERT_TRUE(started_); + ASSERT_FALSE(ended_); + ASSERT_TRUE(LiteralValueExpected()) << ToString(); + ASSERT_TRUE(value_.IsInProgress()); + value_.OnStringData(data, len); +} +void HpackEntryCollector::OnValueEnd() { + ASSERT_TRUE(started_); + ASSERT_FALSE(ended_); + ASSERT_TRUE(LiteralValueExpected()) << ToString(); + ASSERT_TRUE(value_.IsInProgress()); + value_.OnStringEnd(); + ended_ = true; +} +void HpackEntryCollector::OnDynamicTableSizeUpdate(size_t size) { + ASSERT_FALSE(started_); + ASSERT_TRUE(IsClear()) << ToString(); + Init(HpackEntryType::kDynamicTableSizeUpdate, size); + ended_ = true; +} + +void HpackEntryCollector::Clear() { + header_type_ = kInvalidHeaderType; + index_ = kInvalidIndex; + name_.Clear(); + value_.Clear(); + started_ = ended_ = false; +} +bool HpackEntryCollector::IsClear() const { + return header_type_ == kInvalidHeaderType && index_ == kInvalidIndex && + name_.IsClear() && value_.IsClear() && !started_ && !ended_; +} +bool HpackEntryCollector::IsComplete() const { return started_ && ended_; } +bool HpackEntryCollector::LiteralNameExpected() const { + switch (header_type_) { + case HpackEntryType::kIndexedLiteralHeader: + case HpackEntryType::kUnindexedLiteralHeader: + case HpackEntryType::kNeverIndexedLiteralHeader: + return index_ == 0; + default: + return false; + } +} +bool HpackEntryCollector::LiteralValueExpected() const { + switch (header_type_) { + case HpackEntryType::kIndexedLiteralHeader: + case HpackEntryType::kUnindexedLiteralHeader: + case HpackEntryType::kNeverIndexedLiteralHeader: + return true; + default: + return false; + } +} +AssertionResult HpackEntryCollector::ValidateIndexedHeader( + size_t expected_index) const { + HTTP2_VERIFY_TRUE(started_); + HTTP2_VERIFY_TRUE(ended_); + HTTP2_VERIFY_EQ(HpackEntryType::kIndexedHeader, header_type_); + HTTP2_VERIFY_EQ(expected_index, index_); + return ::testing::AssertionSuccess(); +} +AssertionResult HpackEntryCollector::ValidateLiteralValueHeader( + HpackEntryType expected_type, size_t expected_index, + bool expected_value_huffman, absl::string_view expected_value) const { + HTTP2_VERIFY_TRUE(started_); + HTTP2_VERIFY_TRUE(ended_); + HTTP2_VERIFY_EQ(expected_type, header_type_); + HTTP2_VERIFY_NE(0u, expected_index); + HTTP2_VERIFY_EQ(expected_index, index_); + HTTP2_VERIFY_TRUE(name_.IsClear()); + HTTP2_VERIFY_SUCCESS( + value_.Collected(expected_value, expected_value_huffman)); + return ::testing::AssertionSuccess(); +} +AssertionResult HpackEntryCollector::ValidateLiteralNameValueHeader( + HpackEntryType expected_type, bool expected_name_huffman, + absl::string_view expected_name, bool expected_value_huffman, + absl::string_view expected_value) const { + HTTP2_VERIFY_TRUE(started_); + HTTP2_VERIFY_TRUE(ended_); + HTTP2_VERIFY_EQ(expected_type, header_type_); + HTTP2_VERIFY_EQ(0u, index_); + HTTP2_VERIFY_SUCCESS(name_.Collected(expected_name, expected_name_huffman)); + HTTP2_VERIFY_SUCCESS( + value_.Collected(expected_value, expected_value_huffman)); + return ::testing::AssertionSuccess(); +} +AssertionResult HpackEntryCollector::ValidateDynamicTableSizeUpdate( + size_t size) const { + HTTP2_VERIFY_TRUE(started_); + HTTP2_VERIFY_TRUE(ended_); + HTTP2_VERIFY_EQ(HpackEntryType::kDynamicTableSizeUpdate, header_type_); + HTTP2_VERIFY_EQ(index_, size); + return ::testing::AssertionSuccess(); +} + +void HpackEntryCollector::AppendToHpackBlockBuilder( + HpackBlockBuilder* hbb) const { + ASSERT_TRUE(started_ && ended_) << *this; + switch (header_type_) { + case HpackEntryType::kIndexedHeader: + hbb->AppendIndexedHeader(index_); + return; + + case HpackEntryType::kDynamicTableSizeUpdate: + hbb->AppendDynamicTableSizeUpdate(index_); + return; + + case HpackEntryType::kIndexedLiteralHeader: + case HpackEntryType::kUnindexedLiteralHeader: + case HpackEntryType::kNeverIndexedLiteralHeader: + ASSERT_TRUE(value_.HasEnded()) << *this; + if (index_ != 0) { + QUICHE_CHECK(name_.IsClear()); + hbb->AppendNameIndexAndLiteralValue(header_type_, index_, + value_.huffman_encoded, value_.s); + } else { + QUICHE_CHECK(name_.HasEnded()) << *this; + hbb->AppendLiteralNameAndValue(header_type_, name_.huffman_encoded, + name_.s, value_.huffman_encoded, + value_.s); + } + return; + + default: + ADD_FAILURE() << *this; + } +} + +std::string HpackEntryCollector::ToString() const { + std::string result("Type="); + switch (header_type_) { + case HpackEntryType::kIndexedHeader: + result += "IndexedHeader"; + break; + case HpackEntryType::kDynamicTableSizeUpdate: + result += "DynamicTableSizeUpdate"; + break; + case HpackEntryType::kIndexedLiteralHeader: + result += "IndexedLiteralHeader"; + break; + case HpackEntryType::kUnindexedLiteralHeader: + result += "UnindexedLiteralHeader"; + break; + case HpackEntryType::kNeverIndexedLiteralHeader: + result += "NeverIndexedLiteralHeader"; + break; + default: + if (header_type_ == kInvalidHeaderType) { + result += ""; + } else { + absl::StrAppend(&result, header_type_); + } + } + if (index_ != 0) { + absl::StrAppend(&result, " Index=", index_); + } + if (!name_.IsClear()) { + absl::StrAppend(&result, " Name", name_.ToString()); + } + if (!value_.IsClear()) { + absl::StrAppend(&result, " Value", value_.ToString()); + } + if (!started_) { + EXPECT_FALSE(ended_); + absl::StrAppend(&result, " !started"); + } else if (!ended_) { + absl::StrAppend(&result, " !ended"); + } else { + absl::StrAppend(&result, " Complete"); + } + return result; +} + +void HpackEntryCollector::Init(HpackEntryType type, size_t maybe_index) { + ASSERT_TRUE(IsClear()) << ToString(); + header_type_ = type; + index_ = maybe_index; + started_ = true; +} + +bool operator==(const HpackEntryCollector& a, const HpackEntryCollector& b) { + return a.name() == b.name() && a.value() == b.value() && + a.index() == b.index() && a.header_type() == b.header_type() && + a.started() == b.started() && a.ended() == b.ended(); +} +bool operator!=(const HpackEntryCollector& a, const HpackEntryCollector& b) { + return !(a == b); +} + +std::ostream& operator<<(std::ostream& out, const HpackEntryCollector& v) { + return out << v.ToString(); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_entry_collector.h b/quiche/http2/test_tools/hpack_entry_collector.h new file mode 100644 index 000000000000..93ad860453d7 --- /dev/null +++ b/quiche/http2/test_tools/hpack_entry_collector.h @@ -0,0 +1,151 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HPACK_ENTRY_COLLECTOR_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HPACK_ENTRY_COLLECTOR_H_ + +// HpackEntryCollector records calls to HpackEntryDecoderListener in support +// of tests of HpackEntryDecoder, or which use it. Can only record the callbacks +// for the decoding of a single entry; call Clear() between decoding successive +// entries or use a distinct HpackEntryCollector for each entry. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_entry_decoder_listener.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/hpack_string_collector.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT HpackEntryCollector : public HpackEntryDecoderListener { + public: + HpackEntryCollector(); + HpackEntryCollector(const HpackEntryCollector& other); + + // These next three constructors are intended for use in tests that create + // an HpackEntryCollector "manually", and then compare it against another + // that is populated via calls to the HpackEntryDecoderListener methods. + HpackEntryCollector(HpackEntryType type, size_t index_or_size); + HpackEntryCollector(HpackEntryType type, size_t index, bool value_huffman, + const std::string& value); + HpackEntryCollector(HpackEntryType type, bool name_huffman, + const std::string& name, bool value_huffman, + const std::string& value); + + ~HpackEntryCollector() override; + + // Methods defined by HpackEntryDecoderListener. + void OnIndexedHeader(size_t index) override; + void OnStartLiteralHeader(HpackEntryType header_type, + size_t maybe_name_index) override; + void OnNameStart(bool huffman_encoded, size_t len) override; + void OnNameData(const char* data, size_t len) override; + void OnNameEnd() override; + void OnValueStart(bool huffman_encoded, size_t len) override; + void OnValueData(const char* data, size_t len) override; + void OnValueEnd() override; + void OnDynamicTableSizeUpdate(size_t size) override; + + // Clears the fields of the collector so that it is ready to start collecting + // another HPACK block entry. + void Clear(); + + // Is the collector ready to start collecting another HPACK block entry. + bool IsClear() const; + + // Has a complete entry been collected? + bool IsComplete() const; + + // Based on the HpackEntryType, is a literal name expected? + bool LiteralNameExpected() const; + + // Based on the HpackEntryType, is a literal value expected? + bool LiteralValueExpected() const; + + // Returns success if collected an Indexed Header (i.e. OnIndexedHeader was + // called). + ::testing::AssertionResult ValidateIndexedHeader(size_t expected_index) const; + + // Returns success if collected a Header with an indexed name and literal + // value (i.e. OnStartLiteralHeader was called with a non-zero index for + // the name, which must match expected_index). + ::testing::AssertionResult ValidateLiteralValueHeader( + HpackEntryType expected_type, size_t expected_index, + bool expected_value_huffman, absl::string_view expected_value) const; + + // Returns success if collected a Header with an literal name and literal + // value. + ::testing::AssertionResult ValidateLiteralNameValueHeader( + HpackEntryType expected_type, bool expected_name_huffman, + absl::string_view expected_name, bool expected_value_huffman, + absl::string_view expected_value) const; + + // Returns success if collected a Dynamic Table Size Update, + // with the specified size. + ::testing::AssertionResult ValidateDynamicTableSizeUpdate( + size_t expected_size) const; + + void set_header_type(HpackEntryType v) { header_type_ = v; } + HpackEntryType header_type() const { return header_type_; } + + void set_index(size_t v) { index_ = v; } + size_t index() const { return index_; } + + void set_name(const HpackStringCollector& v) { name_ = v; } + const HpackStringCollector& name() const { return name_; } + + void set_value(const HpackStringCollector& v) { value_ = v; } + const HpackStringCollector& value() const { return value_; } + + void set_started(bool v) { started_ = v; } + bool started() const { return started_; } + + void set_ended(bool v) { ended_ = v; } + bool ended() const { return ended_; } + + void AppendToHpackBlockBuilder(HpackBlockBuilder* hbb) const; + + // Returns a debug string. + std::string ToString() const; + + private: + void Init(HpackEntryType type, size_t maybe_index); + + HpackEntryType header_type_; + size_t index_; + + HpackStringCollector name_; + HpackStringCollector value_; + + // True if has received a call to an HpackEntryDecoderListener method + // indicating the start of decoding an HPACK entry; for example, + // OnIndexedHeader set it true, but OnNameStart does not change it. + bool started_ = false; + + // True if has received a call to an HpackEntryDecoderListener method + // indicating the end of decoding an HPACK entry; for example, + // OnIndexedHeader and OnValueEnd both set it true, but OnNameEnd does + // not change it. + bool ended_ = false; +}; + +QUICHE_NO_EXPORT bool operator==(const HpackEntryCollector& a, + const HpackEntryCollector& b); +QUICHE_NO_EXPORT bool operator!=(const HpackEntryCollector& a, + const HpackEntryCollector& b); +QUICHE_NO_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackEntryCollector& v); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HPACK_ENTRY_COLLECTOR_H_ diff --git a/quiche/http2/test_tools/hpack_example.cc b/quiche/http2/test_tools/hpack_example.cc new file mode 100644 index 000000000000..42d9261879a5 --- /dev/null +++ b/quiche/http2/test_tools/hpack_example.cc @@ -0,0 +1,59 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_example.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { +namespace { + +void HpackExampleToStringOrDie(absl::string_view example, std::string* output) { + while (!example.empty()) { + const char c0 = example[0]; + if (isxdigit(c0)) { + QUICHE_CHECK_GT(example.size(), 1u) << "Truncated hex byte?"; + const char c1 = example[1]; + QUICHE_CHECK(isxdigit(c1)) << "Found half a byte?"; + *output += absl::HexStringToBytes(example.substr(0, 2)); + example.remove_prefix(2); + continue; + } + if (isspace(c0)) { + example.remove_prefix(1); + continue; + } + if (!example.empty() && example[0] == '|') { + // Start of a comment. Skip to end of line or of input. + auto pos = example.find('\n'); + if (pos == absl::string_view::npos) { + // End of input. + break; + } + example.remove_prefix(pos + 1); + continue; + } + QUICHE_BUG(http2_bug_107_1) + << "Can't parse byte " << static_cast(c0) + << absl::StrCat(" (0x", absl::Hex(c0), ")") << "\nExample: " << example; + } + QUICHE_CHECK_LT(0u, output->size()) << "Example is empty."; +} + +} // namespace + +std::string HpackExampleToStringOrDie(absl::string_view example) { + std::string output; + HpackExampleToStringOrDie(example, &output); + return output; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_example.h b/quiche/http2/test_tools/hpack_example.h new file mode 100644 index 000000000000..007b969c7dab --- /dev/null +++ b/quiche/http2/test_tools/hpack_example.h @@ -0,0 +1,32 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HPACK_EXAMPLE_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HPACK_EXAMPLE_H_ + +#include + +#include "absl/strings/string_view.h" + +// Parses HPACK examples in the format seen in the HPACK specification, +// RFC 7541. For example: +// +// 10 | == Literal never indexed == +// 08 | Literal name (len = 8) +// 7061 7373 776f 7264 | password +// 06 | Literal value (len = 6) +// 7365 6372 6574 | secret +// | -> password: secret +// +// (excluding the leading "//"). + +namespace http2 { +namespace test { + +std::string HpackExampleToStringOrDie(absl::string_view example); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HPACK_EXAMPLE_H_ diff --git a/quiche/http2/test_tools/hpack_example_test.cc b/quiche/http2/test_tools/hpack_example_test.cc new file mode 100644 index 000000000000..4dd24bd3fd66 --- /dev/null +++ b/quiche/http2/test_tools/hpack_example_test.cc @@ -0,0 +1,45 @@ +#include "quiche/http2/test_tools/hpack_example.h" + +// Tests of HpackExampleToStringOrDie. + +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +TEST(HpackExampleToStringOrDie, GoodInput) { + std::string bytes = HpackExampleToStringOrDie(R"( + 40 | == Literal never indexed == + | Blank lines are OK in example: + + 08 | Literal name (len = 8) + 7061 7373 776f 7264 | password + 06 | Literal value (len = 6) + 7365 6372 6574 | secret + | -> password: secret + )"); + + // clang-format off + const char kExpected[] = { + 0x40, // Never Indexed, Literal Name and Value + 0x08, // Name Len: 8 + 0x70, 0x61, 0x73, 0x73, // Name: password + 0x77, 0x6f, 0x72, 0x64, // + 0x06, // Value Len: 6 + 0x73, 0x65, 0x63, 0x72, // Value: secret + 0x65, 0x74, // + }; + // clang-format on + EXPECT_EQ(absl::string_view(kExpected, sizeof kExpected), bytes); +} + +TEST(HpackExampleToStringOrDie, InvalidInput) { + EXPECT_DEATH(HpackExampleToStringOrDie("4"), "Truncated"); + EXPECT_DEATH(HpackExampleToStringOrDie("4x"), "half"); + EXPECT_DEATH(HpackExampleToStringOrDie(""), "empty"); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_string_collector.cc b/quiche/http2/test_tools/hpack_string_collector.cc new file mode 100644 index 000000000000..1110c3ae8ad7 --- /dev/null +++ b/quiche/http2/test_tools/hpack_string_collector.cc @@ -0,0 +1,117 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/hpack_string_collector.h" + +#include + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +std::ostream& operator<<(std::ostream& out, + HpackStringCollector::CollectorState v) { + switch (v) { + case HpackStringCollector::CollectorState::kGenesis: + return out << "kGenesis"; + case HpackStringCollector::CollectorState::kStarted: + return out << "kStarted"; + case HpackStringCollector::CollectorState::kEnded: + return out << "kEnded"; + } + return out << "UnknownCollectorState"; +} + +} // namespace + +HpackStringCollector::HpackStringCollector() { Clear(); } + +HpackStringCollector::HpackStringCollector(const std::string& str, bool huffman) + : s(str), len(str.size()), huffman_encoded(huffman), state(kEnded) {} + +void HpackStringCollector::Clear() { + s = ""; + len = 0; + huffman_encoded = false; + state = kGenesis; +} + +bool HpackStringCollector::IsClear() const { + return s.empty() && len == 0 && huffman_encoded == false && state == kGenesis; +} + +bool HpackStringCollector::IsInProgress() const { return state == kStarted; } + +bool HpackStringCollector::HasEnded() const { return state == kEnded; } + +void HpackStringCollector::OnStringStart(bool huffman, size_t length) { + EXPECT_TRUE(IsClear()) << ToString(); + state = kStarted; + huffman_encoded = huffman; + len = length; +} + +void HpackStringCollector::OnStringData(const char* data, size_t length) { + absl::string_view sp(data, length); + EXPECT_TRUE(IsInProgress()) << ToString(); + EXPECT_LE(sp.size(), len) << ToString(); + absl::StrAppend(&s, sp); + EXPECT_LE(s.size(), len) << ToString(); +} + +void HpackStringCollector::OnStringEnd() { + EXPECT_TRUE(IsInProgress()) << ToString(); + EXPECT_EQ(s.size(), len) << ToString(); + state = kEnded; +} + +::testing::AssertionResult HpackStringCollector::Collected( + absl::string_view str, bool is_huffman_encoded) const { + HTTP2_VERIFY_TRUE(HasEnded()); + HTTP2_VERIFY_EQ(str.size(), len); + HTTP2_VERIFY_EQ(is_huffman_encoded, huffman_encoded); + HTTP2_VERIFY_EQ(str, s); + return ::testing::AssertionSuccess(); +} + +std::string HpackStringCollector::ToString() const { + std::stringstream ss; + ss << *this; + return ss.str(); +} + +bool operator==(const HpackStringCollector& a, const HpackStringCollector& b) { + return a.s == b.s && a.len == b.len && + a.huffman_encoded == b.huffman_encoded && a.state == b.state; +} + +bool operator!=(const HpackStringCollector& a, const HpackStringCollector& b) { + return !(a == b); +} + +std::ostream& operator<<(std::ostream& out, const HpackStringCollector& v) { + out << "HpackStringCollector(state=" << v.state; + if (v.state == HpackStringCollector::kGenesis) { + return out << ")"; + } + if (v.huffman_encoded) { + out << ", Huffman Encoded"; + } + out << ", Length=" << v.len; + if (!v.s.empty() && v.len != v.s.size()) { + out << " (" << v.s.size() << ")"; + } + return out << ", String=\"" << absl::CHexEscape(v.s) << "\")"; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/hpack_string_collector.h b/quiche/http2/test_tools/hpack_string_collector.h new file mode 100644 index 000000000000..15dab67ffa2e --- /dev/null +++ b/quiche/http2/test_tools/hpack_string_collector.h @@ -0,0 +1,66 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HPACK_STRING_COLLECTOR_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HPACK_STRING_COLLECTOR_H_ + +// Supports tests of decoding HPACK strings. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_string_decoder_listener.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +// Records the callbacks associated with a decoding a string; must +// call Clear() between decoding successive strings. +struct QUICHE_NO_EXPORT HpackStringCollector + : public HpackStringDecoderListener { + enum CollectorState { + kGenesis, + kStarted, + kEnded, + }; + + HpackStringCollector(); + HpackStringCollector(const std::string& str, bool huffman); + + void Clear(); + bool IsClear() const; + bool IsInProgress() const; + bool HasEnded() const; + + void OnStringStart(bool huffman, size_t length) override; + void OnStringData(const char* data, size_t length) override; + void OnStringEnd() override; + + ::testing::AssertionResult Collected(absl::string_view str, + bool is_huffman_encoded) const; + + std::string ToString() const; + + std::string s; + size_t len; + bool huffman_encoded; + CollectorState state; +}; + +bool operator==(const HpackStringCollector& a, const HpackStringCollector& b); + +bool operator!=(const HpackStringCollector& a, const HpackStringCollector& b); + +QUICHE_NO_EXPORT std::ostream& operator<<(std::ostream& out, + const HpackStringCollector& v); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HPACK_STRING_COLLECTOR_H_ diff --git a/quiche/http2/test_tools/http2_constants_test_util.cc b/quiche/http2/test_tools/http2_constants_test_util.cc new file mode 100644 index 000000000000..ddb5cbdff68a --- /dev/null +++ b/quiche/http2/test_tools/http2_constants_test_util.cc @@ -0,0 +1,84 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/http2_constants_test_util.h" + +namespace http2 { +namespace test { + +std::vector AllHttp2ErrorCodes() { + // clang-format off + return { + Http2ErrorCode::HTTP2_NO_ERROR, + Http2ErrorCode::PROTOCOL_ERROR, + Http2ErrorCode::INTERNAL_ERROR, + Http2ErrorCode::FLOW_CONTROL_ERROR, + Http2ErrorCode::SETTINGS_TIMEOUT, + Http2ErrorCode::STREAM_CLOSED, + Http2ErrorCode::FRAME_SIZE_ERROR, + Http2ErrorCode::REFUSED_STREAM, + Http2ErrorCode::CANCEL, + Http2ErrorCode::COMPRESSION_ERROR, + Http2ErrorCode::CONNECT_ERROR, + Http2ErrorCode::ENHANCE_YOUR_CALM, + Http2ErrorCode::INADEQUATE_SECURITY, + Http2ErrorCode::HTTP_1_1_REQUIRED, + }; + // clang-format on +} + +std::vector AllHttp2SettingsParameters() { + // clang-format off + return { + Http2SettingsParameter::HEADER_TABLE_SIZE, + Http2SettingsParameter::ENABLE_PUSH, + Http2SettingsParameter::MAX_CONCURRENT_STREAMS, + Http2SettingsParameter::INITIAL_WINDOW_SIZE, + Http2SettingsParameter::MAX_FRAME_SIZE, + Http2SettingsParameter::MAX_HEADER_LIST_SIZE, + }; + // clang-format on +} + +// Returns a mask of flags supported for the specified frame type. Returns +// zero for unknown frame types. +uint8_t KnownFlagsMaskForFrameType(Http2FrameType type) { + switch (type) { + case Http2FrameType::DATA: + return Http2FrameFlag::END_STREAM | Http2FrameFlag::PADDED; + case Http2FrameType::HEADERS: + return Http2FrameFlag::END_STREAM | Http2FrameFlag::END_HEADERS | + Http2FrameFlag::PADDED | Http2FrameFlag::PRIORITY; + case Http2FrameType::PRIORITY: + return 0x00; + case Http2FrameType::RST_STREAM: + return 0x00; + case Http2FrameType::SETTINGS: + return Http2FrameFlag::ACK; + case Http2FrameType::PUSH_PROMISE: + return Http2FrameFlag::END_HEADERS | Http2FrameFlag::PADDED; + case Http2FrameType::PING: + return Http2FrameFlag::ACK; + case Http2FrameType::GOAWAY: + return 0x00; + case Http2FrameType::WINDOW_UPDATE: + return 0x00; + case Http2FrameType::CONTINUATION: + return Http2FrameFlag::END_HEADERS; + case Http2FrameType::ALTSVC: + return 0x00; + default: + return 0x00; + } +} + +uint8_t InvalidFlagMaskForFrameType(Http2FrameType type) { + if (IsSupportedHttp2FrameType(type)) { + return ~KnownFlagsMaskForFrameType(type); + } + return 0x00; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_constants_test_util.h b/quiche/http2/test_tools/http2_constants_test_util.h new file mode 100644 index 000000000000..20edc8df7069 --- /dev/null +++ b/quiche/http2/test_tools/http2_constants_test_util.h @@ -0,0 +1,34 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HTTP2_CONSTANTS_TEST_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HTTP2_CONSTANTS_TEST_UTIL_H_ + +#include +#include + +#include "quiche/http2/http2_constants.h" + +namespace http2 { +namespace test { + +// Returns a vector of all supported RST_STREAM and GOAWAY error codes. +std::vector AllHttp2ErrorCodes(); + +// Returns a vector of all supported parameters in SETTINGS frames. +std::vector AllHttp2SettingsParameters(); + +// Returns a mask of flags supported for the specified frame type. Returns +// zero for unknown frame types. +uint8_t KnownFlagsMaskForFrameType(Http2FrameType type); + +// Returns a mask of flag bits known to be invalid for the frame type. +// For unknown frame types, the mask is zero; i.e., we don't know that any +// are invalid. +uint8_t InvalidFlagMaskForFrameType(Http2FrameType type); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HTTP2_CONSTANTS_TEST_UTIL_H_ diff --git a/quiche/http2/test_tools/http2_frame_builder.cc b/quiche/http2/test_tools/http2_frame_builder.cc new file mode 100644 index 000000000000..d31ddc92000c --- /dev/null +++ b/quiche/http2/test_tools/http2_frame_builder.cc @@ -0,0 +1,179 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/http2_frame_builder.h" + +#ifdef WIN32 +#include // for htonl() functions +#else +#include +#include // for htonl, htons +#endif + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +Http2FrameBuilder::Http2FrameBuilder(Http2FrameType type, uint8_t flags, + uint32_t stream_id) { + AppendUInt24(0); // Frame payload length, unknown so far. + Append(type); + AppendUInt8(flags); + AppendUInt31(stream_id); +} + +Http2FrameBuilder::Http2FrameBuilder(const Http2FrameHeader& v) { Append(v); } + +void Http2FrameBuilder::Append(absl::string_view s) { + absl::StrAppend(&buffer_, s); +} + +void Http2FrameBuilder::AppendBytes(const void* data, uint32_t num_bytes) { + Append(absl::string_view(static_cast(data), num_bytes)); +} + +void Http2FrameBuilder::AppendZeroes(size_t num_zero_bytes) { + char zero = 0; + buffer_.append(num_zero_bytes, zero); +} + +void Http2FrameBuilder::AppendUInt8(uint8_t value) { AppendBytes(&value, 1); } + +void Http2FrameBuilder::AppendUInt16(uint16_t value) { + value = htons(value); + AppendBytes(&value, 2); +} + +void Http2FrameBuilder::AppendUInt24(uint32_t value) { + // Doesn't make sense to try to append a larger value, as that doesn't + // simulate something an encoder could do (i.e. the other 8 bits simply aren't + // there to be occupied). + EXPECT_EQ(value, value & 0xffffff); + value = htonl(value); + AppendBytes(reinterpret_cast(&value) + 1, 3); +} + +void Http2FrameBuilder::AppendUInt31(uint32_t value) { + // If you want to test the high-bit being set, call AppendUInt32 instead. + uint32_t tmp = value & StreamIdMask(); + EXPECT_EQ(value, value & StreamIdMask()) + << "High-bit of uint32_t should be clear."; + value = htonl(tmp); + AppendBytes(&value, 4); +} + +void Http2FrameBuilder::AppendUInt32(uint32_t value) { + value = htonl(value); + AppendBytes(&value, sizeof(value)); +} + +void Http2FrameBuilder::Append(Http2ErrorCode error_code) { + AppendUInt32(static_cast(error_code)); +} + +void Http2FrameBuilder::Append(Http2FrameType type) { + AppendUInt8(static_cast(type)); +} + +void Http2FrameBuilder::Append(Http2SettingsParameter parameter) { + AppendUInt16(static_cast(parameter)); +} + +void Http2FrameBuilder::Append(const Http2FrameHeader& v) { + AppendUInt24(v.payload_length); + Append(v.type); + AppendUInt8(v.flags); + AppendUInt31(v.stream_id); +} + +void Http2FrameBuilder::Append(const Http2PriorityFields& v) { + // The EXCLUSIVE flag is the high-bit of the 32-bit stream dependency field. + uint32_t tmp = v.stream_dependency & StreamIdMask(); + EXPECT_EQ(tmp, v.stream_dependency); + if (v.is_exclusive) { + tmp |= 0x80000000; + } + AppendUInt32(tmp); + + // The PRIORITY frame's weight field is logically in the range [1, 256], + // but is encoded as a byte in the range [0, 255]. + ASSERT_LE(1u, v.weight); + ASSERT_LE(v.weight, 256u); + AppendUInt8(v.weight - 1); +} + +void Http2FrameBuilder::Append(const Http2RstStreamFields& v) { + Append(v.error_code); +} + +void Http2FrameBuilder::Append(const Http2SettingFields& v) { + Append(v.parameter); + AppendUInt32(v.value); +} + +void Http2FrameBuilder::Append(const Http2PushPromiseFields& v) { + AppendUInt31(v.promised_stream_id); +} + +void Http2FrameBuilder::Append(const Http2PingFields& v) { + AppendBytes(v.opaque_bytes, sizeof Http2PingFields::opaque_bytes); +} + +void Http2FrameBuilder::Append(const Http2GoAwayFields& v) { + AppendUInt31(v.last_stream_id); + Append(v.error_code); +} + +void Http2FrameBuilder::Append(const Http2WindowUpdateFields& v) { + EXPECT_NE(0u, v.window_size_increment) << "Increment must be non-zero."; + AppendUInt31(v.window_size_increment); +} + +void Http2FrameBuilder::Append(const Http2AltSvcFields& v) { + AppendUInt16(v.origin_length); +} + +void Http2FrameBuilder::Append(const Http2PriorityUpdateFields& v) { + AppendUInt31(v.prioritized_stream_id); +} + +// Methods for changing existing buffer contents. + +void Http2FrameBuilder::WriteAt(absl::string_view s, size_t offset) { + ASSERT_LE(offset, buffer_.size()); + size_t len = offset + s.size(); + if (len > buffer_.size()) { + buffer_.resize(len); + } + for (size_t ndx = 0; ndx < s.size(); ++ndx) { + buffer_[offset + ndx] = s[ndx]; + } +} + +void Http2FrameBuilder::WriteBytesAt(const void* data, uint32_t num_bytes, + size_t offset) { + WriteAt(absl::string_view(static_cast(data), num_bytes), offset); +} + +void Http2FrameBuilder::WriteUInt24At(uint32_t value, size_t offset) { + ASSERT_LT(value, static_cast(1 << 24)); + value = htonl(value); + WriteBytesAt(reinterpret_cast(&value) + 1, sizeof(value) - 1, offset); +} + +void Http2FrameBuilder::SetPayloadLength(uint32_t payload_length) { + WriteUInt24At(payload_length, 0); +} + +size_t Http2FrameBuilder::SetPayloadLength() { + EXPECT_GE(size(), Http2FrameHeader::EncodedSize()); + uint32_t payload_length = size() - Http2FrameHeader::EncodedSize(); + SetPayloadLength(payload_length); + return payload_length; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_frame_builder.h b/quiche/http2/test_tools/http2_frame_builder.h new file mode 100644 index 000000000000..8ff1916d3db5 --- /dev/null +++ b/quiche/http2/test_tools/http2_frame_builder.h @@ -0,0 +1,103 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HTTP2_FRAME_BUILDER_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HTTP2_FRAME_BUILDER_H_ + +// Http2FrameBuilder builds wire-format HTTP/2 frames (or fragments thereof) +// from components. +// +// For now, this is only intended for use in tests, and thus has EXPECT* in the +// code. If desired to use it in an encoder, it will need optimization work, +// especially w.r.t memory mgmt, and the EXPECT* will need to be removed or +// replaced with QUICHE_DCHECKs. + +#include // for size_t + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT Http2FrameBuilder { + public: + Http2FrameBuilder(Http2FrameType type, uint8_t flags, uint32_t stream_id); + explicit Http2FrameBuilder(const Http2FrameHeader& v); + Http2FrameBuilder() {} + ~Http2FrameBuilder() {} + + size_t size() const { return buffer_.size(); } + const std::string& buffer() const { return buffer_; } + + //---------------------------------------------------------------------------- + // Methods for appending to the end of the buffer. + + // Append a sequence of bytes from various sources. + void Append(absl::string_view s); + void AppendBytes(const void* data, uint32_t num_bytes); + + // Append an array of type T[N] to the string. Intended for tests with arrays + // initialized from literals, such as: + // const char kData[] = {0x12, 0x23, ...}; + // builder.AppendBytes(kData); + template + void AppendBytes(T (&buf)[N]) { + AppendBytes(buf, N * sizeof(buf[0])); + } + + // Support for appending padding. Does not read or write the Pad Length field. + void AppendZeroes(size_t num_zero_bytes); + + // Append various sizes of unsigned integers. + void AppendUInt8(uint8_t value); + void AppendUInt16(uint16_t value); + void AppendUInt24(uint32_t value); + void AppendUInt31(uint32_t value); + void AppendUInt32(uint32_t value); + + // Append various enums. + void Append(Http2ErrorCode error_code); + void Append(Http2FrameType type); + void Append(Http2SettingsParameter parameter); + + // Append various structures. + void Append(const Http2FrameHeader& v); + void Append(const Http2PriorityFields& v); + void Append(const Http2RstStreamFields& v); + void Append(const Http2SettingFields& v); + void Append(const Http2PushPromiseFields& v); + void Append(const Http2PingFields& v); + void Append(const Http2GoAwayFields& v); + void Append(const Http2WindowUpdateFields& v); + void Append(const Http2AltSvcFields& v); + void Append(const Http2PriorityUpdateFields& v); + + // Methods for changing existing buffer contents (mostly focused on updating + // the payload length). + + void WriteAt(absl::string_view s, size_t offset); + void WriteBytesAt(const void* data, uint32_t num_bytes, size_t offset); + void WriteUInt24At(uint32_t value, size_t offset); + + // Set the payload length to the specified size. + void SetPayloadLength(uint32_t payload_length); + + // Sets the payload length to the size of the buffer minus the size of + // the frame header. + size_t SetPayloadLength(); + + private: + std::string buffer_; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HTTP2_FRAME_BUILDER_H_ diff --git a/quiche/http2/test_tools/http2_frame_builder_test.cc b/quiche/http2/test_tools/http2_frame_builder_test.cc new file mode 100644 index 000000000000..9097458f5cd4 --- /dev/null +++ b/quiche/http2/test_tools/http2_frame_builder_test.cc @@ -0,0 +1,228 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/http2_frame_builder.h" + +#include "absl/strings/escaping.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +const char kHighBitSetMsg[] = "High-bit of uint32_t should be clear"; + +TEST(Http2FrameBuilderTest, Constructors) { + { + Http2FrameBuilder fb; + EXPECT_EQ(0u, fb.size()); + } + { + Http2FrameBuilder fb(Http2FrameType::DATA, 0, 123); + EXPECT_EQ(9u, fb.size()); + + const std::string kData = absl::HexStringToBytes( + "000000" // Payload length: 0 (unset) + "00" // Frame type: DATA + "00" // Flags: none + "0000007b"); // Stream ID: 123 + EXPECT_EQ(kData, fb.buffer()); + } + { + Http2FrameHeader header; + header.payload_length = (1 << 24) - 1; + header.type = Http2FrameType::HEADERS; + header.flags = Http2FrameFlag::END_HEADERS; + header.stream_id = StreamIdMask(); + Http2FrameBuilder fb(header); + EXPECT_EQ(9u, fb.size()); + + const std::string kData = absl::HexStringToBytes( + "ffffff" // Payload length: 2^24 - 1 (max uint24) + "01" // Frame type: HEADER + "04" // Flags: END_HEADERS + "7fffffff"); // Stream ID: stream id mask + EXPECT_EQ(kData, fb.buffer()); + } +} + +TEST(Http2FrameBuilderTest, SetPayloadLength) { + Http2FrameBuilder fb(Http2FrameType::DATA, PADDED, 20000); + EXPECT_EQ(9u, fb.size()); + + fb.AppendUInt8(50); // Trailing payload length + EXPECT_EQ(10u, fb.size()); + + fb.Append("ten bytes."); + EXPECT_EQ(20u, fb.size()); + + fb.AppendZeroes(50); + EXPECT_EQ(70u, fb.size()); + + fb.SetPayloadLength(); + EXPECT_EQ(70u, fb.size()); + + const std::string kData = absl::HexStringToBytes( + "00003d" // Payload length: 61 + "00" // Frame type: DATA + "08" // Flags: PADDED + "00004e20" // Stream ID: 20000 + "32" // Padding Length: 50 + "74656e2062797465732e" // "ten bytes." + "00000000000000000000" // Padding bytes + "00000000000000000000" // Padding bytes + "00000000000000000000" // Padding bytes + "00000000000000000000" // Padding bytes + "00000000000000000000"); // Padding bytes + EXPECT_EQ(kData, fb.buffer()); +} + +TEST(Http2FrameBuilderTest, Settings) { + Http2FrameBuilder fb(Http2FrameType::SETTINGS, 0, 0); + Http2SettingFields sf; + + sf.parameter = Http2SettingsParameter::HEADER_TABLE_SIZE; + sf.value = 1 << 12; + fb.Append(sf); + + sf.parameter = Http2SettingsParameter::ENABLE_PUSH; + sf.value = 0; + fb.Append(sf); + + sf.parameter = Http2SettingsParameter::MAX_CONCURRENT_STREAMS; + sf.value = ~0; + fb.Append(sf); + + sf.parameter = Http2SettingsParameter::INITIAL_WINDOW_SIZE; + sf.value = 1 << 16; + fb.Append(sf); + + sf.parameter = Http2SettingsParameter::MAX_FRAME_SIZE; + sf.value = 1 << 14; + fb.Append(sf); + + sf.parameter = Http2SettingsParameter::MAX_HEADER_LIST_SIZE; + sf.value = 1 << 10; + fb.Append(sf); + + size_t payload_size = 6 * Http2SettingFields::EncodedSize(); + EXPECT_EQ(Http2FrameHeader::EncodedSize() + payload_size, fb.size()); + + fb.SetPayloadLength(payload_size); + + const std::string kData = absl::HexStringToBytes( + "000024" // Payload length: 36 + "04" // Frame type: SETTINGS + "00" // Flags: none + "00000000" // Stream ID: 0 + "0001" // HEADER_TABLE_SIZE + "00001000" // 4096 + "0002" // ENABLE_PUSH + "00000000" // 0 + "0003" // MAX_CONCURRENT_STREAMS + "ffffffff" // 0xffffffff (max uint32) + "0004" // INITIAL_WINDOW_SIZE + "00010000" // 4096 + "0005" // MAX_FRAME_SIZE + "00004000" // 4096 + "0006" // MAX_HEADER_LIST_SIZE + "00000400"); // 1024 + EXPECT_EQ(kData, fb.buffer()); +} + +TEST(Http2FrameBuilderTest, EnhanceYourCalm) { + const std::string kData = absl::HexStringToBytes("0000000b"); + { + Http2FrameBuilder fb; + fb.Append(Http2ErrorCode::ENHANCE_YOUR_CALM); + EXPECT_EQ(kData, fb.buffer()); + } + { + Http2FrameBuilder fb; + Http2RstStreamFields rsp; + rsp.error_code = Http2ErrorCode::ENHANCE_YOUR_CALM; + fb.Append(rsp); + EXPECT_EQ(kData, fb.buffer()); + } +} + +TEST(Http2FrameBuilderTest, PushPromise) { + const std::string kData = absl::HexStringToBytes("7fffffff"); + { + Http2FrameBuilder fb; + fb.Append(Http2PushPromiseFields{0x7fffffff}); + EXPECT_EQ(kData, fb.buffer()); + } + { + Http2FrameBuilder fb; + // Will generate an error if the high-bit of the stream id is set. + EXPECT_NONFATAL_FAILURE(fb.Append(Http2PushPromiseFields{0xffffffff}), + kHighBitSetMsg); + EXPECT_EQ(kData, fb.buffer()); + } +} + +TEST(Http2FrameBuilderTest, Ping) { + Http2FrameBuilder fb; + Http2PingFields ping{"8 bytes"}; + fb.Append(ping); + + const absl::string_view kData{"8 bytes\0", 8}; + EXPECT_EQ(kData.size(), Http2PingFields::EncodedSize()); + EXPECT_EQ(kData, fb.buffer()); +} + +TEST(Http2FrameBuilderTest, GoAway) { + const std::string kData = absl::HexStringToBytes( + "12345678" // Last Stream Id + "00000001"); // Error code + EXPECT_EQ(kData.size(), Http2GoAwayFields::EncodedSize()); + { + Http2FrameBuilder fb; + Http2GoAwayFields ga(0x12345678, Http2ErrorCode::PROTOCOL_ERROR); + fb.Append(ga); + EXPECT_EQ(kData, fb.buffer()); + } + { + Http2FrameBuilder fb; + // Will generate a test failure if the high-bit of the stream id is set. + Http2GoAwayFields ga(0x92345678, Http2ErrorCode::PROTOCOL_ERROR); + EXPECT_NONFATAL_FAILURE(fb.Append(ga), kHighBitSetMsg); + EXPECT_EQ(kData, fb.buffer()); + } +} + +TEST(Http2FrameBuilderTest, WindowUpdate) { + Http2FrameBuilder fb; + fb.Append(Http2WindowUpdateFields{123456}); + + // Will generate a test failure if the high-bit of the increment is set. + EXPECT_NONFATAL_FAILURE(fb.Append(Http2WindowUpdateFields{0x80000001}), + kHighBitSetMsg); + + // Will generate a test failure if the increment is zero. + EXPECT_NONFATAL_FAILURE(fb.Append(Http2WindowUpdateFields{0}), "non-zero"); + + const std::string kData = absl::HexStringToBytes( + "0001e240" // Valid Window Size Increment + "00000001" // High-bit cleared + "00000000"); // Invalid Window Size Increment + EXPECT_EQ(kData.size(), 3 * Http2WindowUpdateFields::EncodedSize()); + EXPECT_EQ(kData, fb.buffer()); +} + +TEST(Http2FrameBuilderTest, AltSvc) { + Http2FrameBuilder fb; + fb.Append(Http2AltSvcFields{99}); + fb.Append(Http2AltSvcFields{0}); // No optional origin + const std::string kData = absl::HexStringToBytes( + "0063" // Has origin. + "0000"); // Doesn't have origin. + EXPECT_EQ(kData.size(), 2 * Http2AltSvcFields::EncodedSize()); + EXPECT_EQ(kData, fb.buffer()); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.cc b/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.cc new file mode 100644 index 000000000000..fea56f27d43c --- /dev/null +++ b/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.cc @@ -0,0 +1,511 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h" + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { + +FailingHttp2FrameDecoderListener::FailingHttp2FrameDecoderListener() = default; +FailingHttp2FrameDecoderListener::~FailingHttp2FrameDecoderListener() = default; + +bool FailingHttp2FrameDecoderListener::OnFrameHeader( + const Http2FrameHeader& header) { + ADD_FAILURE() << "OnFrameHeader: " << header; + return false; +} + +void FailingHttp2FrameDecoderListener::OnDataStart( + const Http2FrameHeader& header) { + FAIL() << "OnDataStart: " << header; +} + +void FailingHttp2FrameDecoderListener::OnDataPayload(const char* /*data*/, + size_t len) { + FAIL() << "OnDataPayload: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnDataEnd() { FAIL() << "OnDataEnd"; } + +void FailingHttp2FrameDecoderListener::OnHeadersStart( + const Http2FrameHeader& header) { + FAIL() << "OnHeadersStart: " << header; +} + +void FailingHttp2FrameDecoderListener::OnHeadersPriority( + const Http2PriorityFields& priority) { + FAIL() << "OnHeadersPriority: " << priority; +} + +void FailingHttp2FrameDecoderListener::OnHpackFragment(const char* /*data*/, + size_t len) { + FAIL() << "OnHpackFragment: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnHeadersEnd() { + FAIL() << "OnHeadersEnd"; +} + +void FailingHttp2FrameDecoderListener::OnPriorityFrame( + const Http2FrameHeader& header, const Http2PriorityFields& priority) { + FAIL() << "OnPriorityFrame: " << header << "; priority: " << priority; +} + +void FailingHttp2FrameDecoderListener::OnContinuationStart( + const Http2FrameHeader& header) { + FAIL() << "OnContinuationStart: " << header; +} + +void FailingHttp2FrameDecoderListener::OnContinuationEnd() { + FAIL() << "OnContinuationEnd"; +} + +void FailingHttp2FrameDecoderListener::OnPadLength(size_t trailing_length) { + FAIL() << "OnPadLength: trailing_length=" << trailing_length; +} + +void FailingHttp2FrameDecoderListener::OnPadding(const char* /*padding*/, + size_t skipped_length) { + FAIL() << "OnPadding: skipped_length=" << skipped_length; +} + +void FailingHttp2FrameDecoderListener::OnRstStream( + const Http2FrameHeader& header, Http2ErrorCode error_code) { + FAIL() << "OnRstStream: " << header << "; code=" << error_code; +} + +void FailingHttp2FrameDecoderListener::OnSettingsStart( + const Http2FrameHeader& header) { + FAIL() << "OnSettingsStart: " << header; +} + +void FailingHttp2FrameDecoderListener::OnSetting( + const Http2SettingFields& setting_fields) { + FAIL() << "OnSetting: " << setting_fields; +} + +void FailingHttp2FrameDecoderListener::OnSettingsEnd() { + FAIL() << "OnSettingsEnd"; +} + +void FailingHttp2FrameDecoderListener::OnSettingsAck( + const Http2FrameHeader& header) { + FAIL() << "OnSettingsAck: " << header; +} + +void FailingHttp2FrameDecoderListener::OnPushPromiseStart( + const Http2FrameHeader& header, const Http2PushPromiseFields& promise, + size_t total_padding_length) { + FAIL() << "OnPushPromiseStart: " << header << "; promise: " << promise + << "; total_padding_length: " << total_padding_length; +} + +void FailingHttp2FrameDecoderListener::OnPushPromiseEnd() { + FAIL() << "OnPushPromiseEnd"; +} + +void FailingHttp2FrameDecoderListener::OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) { + FAIL() << "OnPing: " << header << "; ping: " << ping; +} + +void FailingHttp2FrameDecoderListener::OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) { + FAIL() << "OnPingAck: " << header << "; ping: " << ping; +} + +void FailingHttp2FrameDecoderListener::OnGoAwayStart( + const Http2FrameHeader& header, const Http2GoAwayFields& goaway) { + FAIL() << "OnGoAwayStart: " << header << "; goaway: " << goaway; +} + +void FailingHttp2FrameDecoderListener::OnGoAwayOpaqueData(const char* /*data*/, + size_t len) { + FAIL() << "OnGoAwayOpaqueData: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnGoAwayEnd() { + FAIL() << "OnGoAwayEnd"; +} + +void FailingHttp2FrameDecoderListener::OnWindowUpdate( + const Http2FrameHeader& header, uint32_t increment) { + FAIL() << "OnWindowUpdate: " << header << "; increment=" << increment; +} + +void FailingHttp2FrameDecoderListener::OnAltSvcStart( + const Http2FrameHeader& header, size_t origin_length, size_t value_length) { + FAIL() << "OnAltSvcStart: " << header << "; origin_length: " << origin_length + << "; value_length: " << value_length; +} + +void FailingHttp2FrameDecoderListener::OnAltSvcOriginData(const char* /*data*/, + size_t len) { + FAIL() << "OnAltSvcOriginData: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnAltSvcValueData(const char* /*data*/, + size_t len) { + FAIL() << "OnAltSvcValueData: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnAltSvcEnd() { + FAIL() << "OnAltSvcEnd"; +} + +void FailingHttp2FrameDecoderListener::OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) { + FAIL() << "OnPriorityUpdateStart: " << header << "; prioritized_stream_id: " + << priority_update.prioritized_stream_id; +} + +void FailingHttp2FrameDecoderListener::OnPriorityUpdatePayload( + const char* /*data*/, size_t len) { + FAIL() << "OnPriorityUpdatePayload: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnPriorityUpdateEnd() { + FAIL() << "OnPriorityUpdateEnd"; +} + +void FailingHttp2FrameDecoderListener::OnUnknownStart( + const Http2FrameHeader& header) { + FAIL() << "OnUnknownStart: " << header; +} + +void FailingHttp2FrameDecoderListener::OnUnknownPayload(const char* /*data*/, + size_t len) { + FAIL() << "OnUnknownPayload: len=" << len; +} + +void FailingHttp2FrameDecoderListener::OnUnknownEnd() { + FAIL() << "OnUnknownEnd"; +} + +void FailingHttp2FrameDecoderListener::OnPaddingTooLong( + const Http2FrameHeader& header, size_t missing_length) { + FAIL() << "OnPaddingTooLong: " << header + << "; missing_length: " << missing_length; +} + +void FailingHttp2FrameDecoderListener::OnFrameSizeError( + const Http2FrameHeader& header) { + FAIL() << "OnFrameSizeError: " << header; +} + +LoggingHttp2FrameDecoderListener::LoggingHttp2FrameDecoderListener() + : wrapped_(nullptr) {} +LoggingHttp2FrameDecoderListener::LoggingHttp2FrameDecoderListener( + Http2FrameDecoderListener* wrapped) + : wrapped_(wrapped) {} +LoggingHttp2FrameDecoderListener::~LoggingHttp2FrameDecoderListener() = default; + +bool LoggingHttp2FrameDecoderListener::OnFrameHeader( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnFrameHeader: " << header; + if (wrapped_ != nullptr) { + return wrapped_->OnFrameHeader(header); + } + return true; +} + +void LoggingHttp2FrameDecoderListener::OnDataStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnDataStart: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnDataStart(header); + } +} + +void LoggingHttp2FrameDecoderListener::OnDataPayload(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnDataPayload: len=" << len; + if (wrapped_ != nullptr) { + wrapped_->OnDataPayload(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnDataEnd() { + QUICHE_VLOG(1) << "OnDataEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnDataEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnHeadersStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnHeadersStart: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnHeadersStart(header); + } +} + +void LoggingHttp2FrameDecoderListener::OnHeadersPriority( + const Http2PriorityFields& priority) { + QUICHE_VLOG(1) << "OnHeadersPriority: " << priority; + if (wrapped_ != nullptr) { + wrapped_->OnHeadersPriority(priority); + } +} + +void LoggingHttp2FrameDecoderListener::OnHpackFragment(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnHpackFragment: len=" << len; + if (wrapped_ != nullptr) { + wrapped_->OnHpackFragment(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnHeadersEnd() { + QUICHE_VLOG(1) << "OnHeadersEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnHeadersEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnPriorityFrame( + const Http2FrameHeader& header, const Http2PriorityFields& priority) { + QUICHE_VLOG(1) << "OnPriorityFrame: " << header << "; priority: " << priority; + if (wrapped_ != nullptr) { + wrapped_->OnPriorityFrame(header, priority); + } +} + +void LoggingHttp2FrameDecoderListener::OnContinuationStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnContinuationStart: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnContinuationStart(header); + } +} + +void LoggingHttp2FrameDecoderListener::OnContinuationEnd() { + QUICHE_VLOG(1) << "OnContinuationEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnContinuationEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnPadLength(size_t trailing_length) { + QUICHE_VLOG(1) << "OnPadLength: trailing_length=" << trailing_length; + if (wrapped_ != nullptr) { + wrapped_->OnPadLength(trailing_length); + } +} + +void LoggingHttp2FrameDecoderListener::OnPadding(const char* padding, + size_t skipped_length) { + QUICHE_VLOG(1) << "OnPadding: skipped_length=" << skipped_length; + if (wrapped_ != nullptr) { + wrapped_->OnPadding(padding, skipped_length); + } +} + +void LoggingHttp2FrameDecoderListener::OnRstStream( + const Http2FrameHeader& header, Http2ErrorCode error_code) { + QUICHE_VLOG(1) << "OnRstStream: " << header << "; code=" << error_code; + if (wrapped_ != nullptr) { + wrapped_->OnRstStream(header, error_code); + } +} + +void LoggingHttp2FrameDecoderListener::OnSettingsStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnSettingsStart: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnSettingsStart(header); + } +} + +void LoggingHttp2FrameDecoderListener::OnSetting( + const Http2SettingFields& setting_fields) { + QUICHE_VLOG(1) << "OnSetting: " << setting_fields; + if (wrapped_ != nullptr) { + wrapped_->OnSetting(setting_fields); + } +} + +void LoggingHttp2FrameDecoderListener::OnSettingsEnd() { + QUICHE_VLOG(1) << "OnSettingsEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnSettingsEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnSettingsAck( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnSettingsAck: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnSettingsAck(header); + } +} + +void LoggingHttp2FrameDecoderListener::OnPushPromiseStart( + const Http2FrameHeader& header, const Http2PushPromiseFields& promise, + size_t total_padding_length) { + QUICHE_VLOG(1) << "OnPushPromiseStart: " << header << "; promise: " << promise + << "; total_padding_length: " << total_padding_length; + if (wrapped_ != nullptr) { + wrapped_->OnPushPromiseStart(header, promise, total_padding_length); + } +} + +void LoggingHttp2FrameDecoderListener::OnPushPromiseEnd() { + QUICHE_VLOG(1) << "OnPushPromiseEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnPushPromiseEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_VLOG(1) << "OnPing: " << header << "; ping: " << ping; + if (wrapped_ != nullptr) { + wrapped_->OnPing(header, ping); + } +} + +void LoggingHttp2FrameDecoderListener::OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_VLOG(1) << "OnPingAck: " << header << "; ping: " << ping; + if (wrapped_ != nullptr) { + wrapped_->OnPingAck(header, ping); + } +} + +void LoggingHttp2FrameDecoderListener::OnGoAwayStart( + const Http2FrameHeader& header, const Http2GoAwayFields& goaway) { + QUICHE_VLOG(1) << "OnGoAwayStart: " << header << "; goaway: " << goaway; + if (wrapped_ != nullptr) { + wrapped_->OnGoAwayStart(header, goaway); + } +} + +void LoggingHttp2FrameDecoderListener::OnGoAwayOpaqueData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnGoAwayOpaqueData: len=" << len; + if (wrapped_ != nullptr) { + wrapped_->OnGoAwayOpaqueData(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnGoAwayEnd() { + QUICHE_VLOG(1) << "OnGoAwayEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnGoAwayEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnWindowUpdate( + const Http2FrameHeader& header, uint32_t increment) { + QUICHE_VLOG(1) << "OnWindowUpdate: " << header << "; increment=" << increment; + if (wrapped_ != nullptr) { + wrapped_->OnWindowUpdate(header, increment); + } +} + +void LoggingHttp2FrameDecoderListener::OnAltSvcStart( + const Http2FrameHeader& header, size_t origin_length, size_t value_length) { + QUICHE_VLOG(1) << "OnAltSvcStart: " << header + << "; origin_length: " << origin_length + << "; value_length: " << value_length; + if (wrapped_ != nullptr) { + wrapped_->OnAltSvcStart(header, origin_length, value_length); + } +} + +void LoggingHttp2FrameDecoderListener::OnAltSvcOriginData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnAltSvcOriginData: len=" << len; + if (wrapped_ != nullptr) { + wrapped_->OnAltSvcOriginData(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnAltSvcValueData(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnAltSvcValueData: len=" << len; + if (wrapped_ != nullptr) { + wrapped_->OnAltSvcValueData(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnAltSvcEnd() { + QUICHE_VLOG(1) << "OnAltSvcEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnAltSvcEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) { + QUICHE_VLOG(1) << "OnPriorityUpdateStart"; + if (wrapped_ != nullptr) { + wrapped_->OnPriorityUpdateStart(header, priority_update); + } +} + +void LoggingHttp2FrameDecoderListener::OnPriorityUpdatePayload(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnPriorityUpdatePayload"; + if (wrapped_ != nullptr) { + wrapped_->OnPriorityUpdatePayload(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnPriorityUpdateEnd() { + QUICHE_VLOG(1) << "OnPriorityUpdateEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnPriorityUpdateEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnUnknownStart( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnUnknownStart: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnUnknownStart(header); + } +} + +void LoggingHttp2FrameDecoderListener::OnUnknownPayload(const char* data, + size_t len) { + QUICHE_VLOG(1) << "OnUnknownPayload: len=" << len; + if (wrapped_ != nullptr) { + wrapped_->OnUnknownPayload(data, len); + } +} + +void LoggingHttp2FrameDecoderListener::OnUnknownEnd() { + QUICHE_VLOG(1) << "OnUnknownEnd"; + if (wrapped_ != nullptr) { + wrapped_->OnUnknownEnd(); + } +} + +void LoggingHttp2FrameDecoderListener::OnPaddingTooLong( + const Http2FrameHeader& header, size_t missing_length) { + QUICHE_VLOG(1) << "OnPaddingTooLong: " << header + << "; missing_length: " << missing_length; + if (wrapped_ != nullptr) { + wrapped_->OnPaddingTooLong(header, missing_length); + } +} + +void LoggingHttp2FrameDecoderListener::OnFrameSizeError( + const Http2FrameHeader& header) { + QUICHE_VLOG(1) << "OnFrameSizeError: " << header; + if (wrapped_ != nullptr) { + wrapped_->OnFrameSizeError(header); + } +} + +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h b/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h new file mode 100644 index 000000000000..db341ce09735 --- /dev/null +++ b/quiche/http2/test_tools/http2_frame_decoder_listener_test_util.h @@ -0,0 +1,154 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HTTP2_FRAME_DECODER_LISTENER_TEST_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HTTP2_FRAME_DECODER_LISTENER_TEST_UTIL_H_ + +#include + +#include + +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { + +// Fail if any of the methods are called. Allows a test to override only the +// expected calls. +class QUICHE_NO_EXPORT FailingHttp2FrameDecoderListener + : public Http2FrameDecoderListener { + public: + FailingHttp2FrameDecoderListener(); + ~FailingHttp2FrameDecoderListener() override; + + // TODO(jamessynge): Remove OnFrameHeader once done with supporting + // SpdyFramer's exact states. + bool OnFrameHeader(const Http2FrameHeader& header) override; + void OnDataStart(const Http2FrameHeader& header) override; + void OnDataPayload(const char* data, size_t len) override; + void OnDataEnd() override; + void OnHeadersStart(const Http2FrameHeader& header) override; + void OnHeadersPriority(const Http2PriorityFields& priority) override; + void OnHpackFragment(const char* data, size_t len) override; + void OnHeadersEnd() override; + void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority) override; + void OnContinuationStart(const Http2FrameHeader& header) override; + void OnContinuationEnd() override; + void OnPadLength(size_t trailing_length) override; + void OnPadding(const char* padding, size_t skipped_length) override; + void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) override; + void OnSettingsStart(const Http2FrameHeader& header) override; + void OnSetting(const Http2SettingFields& setting_fields) override; + void OnSettingsEnd() override; + void OnSettingsAck(const Http2FrameHeader& header) override; + void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) override; + void OnPushPromiseEnd() override; + void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) override; + void OnGoAwayOpaqueData(const char* data, size_t len) override; + void OnGoAwayEnd() override; + void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t increment) override; + void OnAltSvcStart(const Http2FrameHeader& header, size_t origin_length, + size_t value_length) override; + void OnAltSvcOriginData(const char* data, size_t len) override; + void OnAltSvcValueData(const char* data, size_t len) override; + void OnAltSvcEnd() override; + void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) override; + void OnPriorityUpdatePayload(const char* data, size_t len) override; + void OnPriorityUpdateEnd() override; + void OnUnknownStart(const Http2FrameHeader& header) override; + void OnUnknownPayload(const char* data, size_t len) override; + void OnUnknownEnd() override; + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override; + void OnFrameSizeError(const Http2FrameHeader& header) override; + + private: + void EnsureNotAbstract() { FailingHttp2FrameDecoderListener instance; } +}; + +// QUICHE_VLOG's all the calls it receives, and forwards those calls to an +// optional listener. +class QUICHE_NO_EXPORT LoggingHttp2FrameDecoderListener + : public Http2FrameDecoderListener { + public: + LoggingHttp2FrameDecoderListener(); + explicit LoggingHttp2FrameDecoderListener(Http2FrameDecoderListener* wrapped); + ~LoggingHttp2FrameDecoderListener() override; + + // TODO(jamessynge): Remove OnFrameHeader once done with supporting + // SpdyFramer's exact states. + bool OnFrameHeader(const Http2FrameHeader& header) override; + void OnDataStart(const Http2FrameHeader& header) override; + void OnDataPayload(const char* data, size_t len) override; + void OnDataEnd() override; + void OnHeadersStart(const Http2FrameHeader& header) override; + void OnHeadersPriority(const Http2PriorityFields& priority) override; + void OnHpackFragment(const char* data, size_t len) override; + void OnHeadersEnd() override; + void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority) override; + void OnContinuationStart(const Http2FrameHeader& header) override; + void OnContinuationEnd() override; + void OnPadLength(size_t trailing_length) override; + void OnPadding(const char* padding, size_t skipped_length) override; + void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode error_code) override; + void OnSettingsStart(const Http2FrameHeader& header) override; + void OnSetting(const Http2SettingFields& setting_fields) override; + void OnSettingsEnd() override; + void OnSettingsAck(const Http2FrameHeader& header) override; + void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) override; + void OnPushPromiseEnd() override; + void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) override; + void OnGoAwayOpaqueData(const char* data, size_t len) override; + void OnGoAwayEnd() override; + void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t increment) override; + void OnAltSvcStart(const Http2FrameHeader& header, size_t origin_length, + size_t value_length) override; + void OnAltSvcOriginData(const char* data, size_t len) override; + void OnAltSvcValueData(const char* data, size_t len) override; + void OnAltSvcEnd() override; + void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) override; + void OnPriorityUpdatePayload(const char* data, size_t len) override; + void OnPriorityUpdateEnd() override; + void OnUnknownStart(const Http2FrameHeader& header) override; + void OnUnknownPayload(const char* data, size_t len) override; + void OnUnknownEnd() override; + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override; + void OnFrameSizeError(const Http2FrameHeader& header) override; + + private: + void EnsureNotAbstract() { LoggingHttp2FrameDecoderListener instance; } + + Http2FrameDecoderListener* wrapped_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HTTP2_FRAME_DECODER_LISTENER_TEST_UTIL_H_ diff --git a/quiche/http2/test_tools/http2_random.cc b/quiche/http2/test_tools/http2_random.cc new file mode 100644 index 000000000000..8ef8f98563b8 --- /dev/null +++ b/quiche/http2/test_tools/http2_random.cc @@ -0,0 +1,73 @@ +#include "quiche/http2/test_tools/http2_random.h" + +#include "absl/strings/escaping.h" +#include "openssl/chacha.h" +#include "openssl/rand.h" +#include "quiche/common/platform/api/quiche_logging.h" + +static const uint8_t kZeroNonce[12] = {0}; + +namespace http2 { +namespace test { + +Http2Random::Http2Random() { + RAND_bytes(key_, sizeof(key_)); + + QUICHE_LOG(INFO) << "Initialized test RNG with the following key: " << Key(); +} + +Http2Random::Http2Random(absl::string_view key) { + std::string decoded_key = absl::HexStringToBytes(key); + QUICHE_CHECK_EQ(sizeof(key_), decoded_key.size()); + memcpy(key_, decoded_key.data(), sizeof(key_)); +} + +std::string Http2Random::Key() const { + return absl::BytesToHexString( + absl::string_view(reinterpret_cast(key_), sizeof(key_))); +} + +void Http2Random::FillRandom(void* buffer, size_t buffer_size) { + memset(buffer, 0, buffer_size); + uint8_t* buffer_u8 = reinterpret_cast(buffer); + CRYPTO_chacha_20(buffer_u8, buffer_u8, buffer_size, key_, kZeroNonce, + counter_++); +} + +std::string Http2Random::RandString(int length) { + std::string result; + result.resize(length); + FillRandom(&result[0], length); + return result; +} + +uint64_t Http2Random::Rand64() { + union { + uint64_t number; + uint8_t bytes[sizeof(uint64_t)]; + } result; + FillRandom(result.bytes, sizeof(result.bytes)); + return result.number; +} + +double Http2Random::RandDouble() { + union { + double f; + uint64_t i; + } value; + value.i = (1023ull << 52ull) | (Rand64() & 0xfffffffffffffu); + return value.f - 1.0; +} + +std::string Http2Random::RandStringWithAlphabet(int length, + absl::string_view alphabet) { + std::string result; + result.resize(length); + for (int i = 0; i < length; i++) { + result[i] = alphabet[Uniform(alphabet.size())]; + } + return result; +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_random.h b/quiche/http2/test_tools/http2_random.h new file mode 100644 index 000000000000..b4ec61883af8 --- /dev/null +++ b/quiche/http2/test_tools/http2_random.h @@ -0,0 +1,89 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HTTP2_RANDOM_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HTTP2_RANDOM_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +// The random number generator used for unit tests. Since the algorithm is +// deterministic and fixed, this can be used to reproduce flakes in the unit +// tests caused by specific random values. +class QUICHE_NO_EXPORT Http2Random { + public: + Http2Random(); + + Http2Random(const Http2Random&) = delete; + Http2Random& operator=(const Http2Random&) = delete; + + // Reproducible random number generation: by using the same key, the same + // sequence of results is obtained. + explicit Http2Random(absl::string_view key); + std::string Key() const; + + void FillRandom(void* buffer, size_t buffer_size); + std::string RandString(int length); + + // Returns a random 64-bit value. + uint64_t Rand64(); + + // Return a uniformly distrubted random number in [0, n). + uint32_t Uniform(uint32_t n) { return Rand64() % n; } + // Return a uniformly distrubted random number in [lo, hi). + uint64_t UniformInRange(uint64_t lo, uint64_t hi) { + return lo + Rand64() % (hi - lo); + } + // Return an integer of logarithmically random scale. + uint32_t Skewed(uint32_t max_log) { + const uint32_t base = Rand32() % (max_log + 1); + const uint32_t mask = ((base < 32) ? (1u << base) : 0u) - 1u; + return Rand32() & mask; + } + // Return a random number in [0, max] range that skews low. + uint64_t RandomSizeSkewedLow(uint64_t max) { + return std::round(max * std::pow(RandDouble(), 2)); + } + + // Returns a random double between 0 and 1. + double RandDouble(); + float RandFloat() { return RandDouble(); } + + // Has 1/n chance of returning true. + bool OneIn(int n) { return Uniform(n) == 0; } + + uint8_t Rand8() { return Rand64(); } + uint16_t Rand16() { return Rand64(); } + uint32_t Rand32() { return Rand64(); } + + // Return a random string consisting of the characters from the specified + // alphabet. + std::string RandStringWithAlphabet(int length, absl::string_view alphabet); + + // STL UniformRandomNumberGenerator implementation. + using result_type = uint64_t; + static constexpr result_type min() { return 0; } + static constexpr result_type max() { + return std::numeric_limits::max(); + } + result_type operator()() { return Rand64(); } + + private: + uint8_t key_[32]; + uint32_t counter_ = 0; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HTTP2_RANDOM_H_ diff --git a/quiche/http2/test_tools/http2_random_test.cc b/quiche/http2/test_tools/http2_random_test.cc new file mode 100644 index 000000000000..2cf5ba735a88 --- /dev/null +++ b/quiche/http2/test_tools/http2_random_test.cc @@ -0,0 +1,93 @@ +#include "quiche/http2/test_tools/http2_random.h" + +#include + +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { + +TEST(Http2RandomTest, ProducesDifferentNumbers) { + Http2Random random; + uint64_t value1 = random.Rand64(); + uint64_t value2 = random.Rand64(); + uint64_t value3 = random.Rand64(); + + EXPECT_NE(value1, value2); + EXPECT_NE(value2, value3); + EXPECT_NE(value3, value1); +} + +TEST(Http2RandomTest, StartsWithDifferentKeys) { + Http2Random random1; + Http2Random random2; + + EXPECT_NE(random1.Key(), random2.Key()); + EXPECT_NE(random1.Rand64(), random2.Rand64()); + EXPECT_NE(random1.Rand64(), random2.Rand64()); + EXPECT_NE(random1.Rand64(), random2.Rand64()); +} + +TEST(Http2RandomTest, ReproducibleRandom) { + Http2Random random; + uint64_t value1 = random.Rand64(); + uint64_t value2 = random.Rand64(); + + Http2Random clone_random(random.Key()); + EXPECT_EQ(clone_random.Key(), random.Key()); + EXPECT_EQ(value1, clone_random.Rand64()); + EXPECT_EQ(value2, clone_random.Rand64()); +} + +TEST(Http2RandomTest, STLShuffle) { + Http2Random random; + const std::string original = "abcdefghijklmonpqrsuvwxyz"; + + std::string shuffled = original; + std::shuffle(shuffled.begin(), shuffled.end(), random); + EXPECT_NE(original, shuffled); +} + +TEST(Http2RandomTest, RandFloat) { + Http2Random random; + for (int i = 0; i < 10000; i++) { + float value = random.RandFloat(); + ASSERT_GE(value, 0.f); + ASSERT_LE(value, 1.f); + } +} + +TEST(Http2RandomTest, RandStringWithAlphabet) { + Http2Random random; + std::string str = random.RandStringWithAlphabet(1000, "xyz"); + EXPECT_EQ(1000u, str.size()); + + std::set characters(str.begin(), str.end()); + EXPECT_THAT(characters, testing::ElementsAre('x', 'y', 'z')); +} + +TEST(Http2RandomTest, SkewedLow) { + Http2Random random; + constexpr size_t kMax = 1234; + for (int i = 0; i < 10000; i++) { + size_t value = random.RandomSizeSkewedLow(kMax); + ASSERT_GE(value, 0u); + ASSERT_LE(value, kMax); + } +} + +// Checks that SkewedLow() generates full range. This is required, since in +// some unit tests would infinitely loop. +TEST(Http2RandomTest, SkewedLowFullRange) { + Http2Random random; + std::set values; + for (int i = 0; i < 1000; i++) { + values.insert(random.RandomSizeSkewedLow(3)); + } + EXPECT_THAT(values, testing::ElementsAre(0, 1, 2, 3)); +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_structure_decoder_test_util.cc b/quiche/http2/test_tools/http2_structure_decoder_test_util.cc new file mode 100644 index 000000000000..214fd855f3ed --- /dev/null +++ b/quiche/http2/test_tools/http2_structure_decoder_test_util.cc @@ -0,0 +1,22 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/http2_structure_decoder_test_util.h" + +#include + +namespace http2 { +namespace test { + +// static +void Http2StructureDecoderPeer::Randomize(Http2StructureDecoder* p, + Http2Random* rng) { + p->offset_ = rng->Rand32(); + for (size_t i = 0; i < sizeof p->buffer_; ++i) { + p->buffer_[i] = rng->Rand8(); + } +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_structure_decoder_test_util.h b/quiche/http2/test_tools/http2_structure_decoder_test_util.h new file mode 100644 index 000000000000..ad367c399789 --- /dev/null +++ b/quiche/http2/test_tools/http2_structure_decoder_test_util.h @@ -0,0 +1,24 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HTTP2_STRUCTURE_DECODER_TEST_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HTTP2_STRUCTURE_DECODER_TEST_UTIL_H_ + +#include "quiche/http2/decoder/http2_structure_decoder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace test { + +class QUICHE_NO_EXPORT Http2StructureDecoderPeer { + public: + // Overwrite the Http2StructureDecoder instance with random values. + static void Randomize(Http2StructureDecoder* p, Http2Random* rng); +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HTTP2_STRUCTURE_DECODER_TEST_UTIL_H_ diff --git a/quiche/http2/test_tools/http2_structures_test_util.cc b/quiche/http2/test_tools/http2_structures_test_util.cc new file mode 100644 index 000000000000..09bf1eee5366 --- /dev/null +++ b/quiche/http2/test_tools/http2_structures_test_util.cc @@ -0,0 +1,112 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/http2_structures_test_util.h" + +#include + +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/http2_constants_test_util.h" +#include "quiche/http2/test_tools/http2_random.h" + +namespace http2 { +namespace test { + +void Randomize(Http2FrameHeader* out, Http2Random* rng) { + out->payload_length = rng->Rand32() & 0xffffff; + out->type = static_cast(rng->Rand8()); + out->flags = static_cast(rng->Rand8()); + out->stream_id = rng->Rand32() & StreamIdMask(); +} +void Randomize(Http2PriorityFields* out, Http2Random* rng) { + out->stream_dependency = rng->Rand32() & StreamIdMask(); + out->weight = rng->Rand8() + 1; + out->is_exclusive = rng->OneIn(2); +} +void Randomize(Http2RstStreamFields* out, Http2Random* rng) { + out->error_code = static_cast(rng->Rand32()); +} +void Randomize(Http2SettingFields* out, Http2Random* rng) { + out->parameter = static_cast(rng->Rand16()); + out->value = rng->Rand32(); +} +void Randomize(Http2PushPromiseFields* out, Http2Random* rng) { + out->promised_stream_id = rng->Rand32() & StreamIdMask(); +} +void Randomize(Http2PingFields* out, Http2Random* rng) { + for (int ndx = 0; ndx < 8; ++ndx) { + out->opaque_bytes[ndx] = rng->Rand8(); + } +} +void Randomize(Http2GoAwayFields* out, Http2Random* rng) { + out->last_stream_id = rng->Rand32() & StreamIdMask(); + out->error_code = static_cast(rng->Rand32()); +} +void Randomize(Http2WindowUpdateFields* out, Http2Random* rng) { + out->window_size_increment = rng->Rand32() & 0x7fffffff; +} +void Randomize(Http2AltSvcFields* out, Http2Random* rng) { + out->origin_length = rng->Rand16(); +} +void Randomize(Http2PriorityUpdateFields* out, Http2Random* rng) { + out->prioritized_stream_id = rng->Rand32() & StreamIdMask(); +} + +void ScrubFlagsOfHeader(Http2FrameHeader* header) { + uint8_t invalid_mask = InvalidFlagMaskForFrameType(header->type); + uint8_t keep_mask = ~invalid_mask; + header->RetainFlags(keep_mask); +} + +bool FrameIsPadded(const Http2FrameHeader& header) { + switch (header.type) { + case Http2FrameType::DATA: + case Http2FrameType::HEADERS: + case Http2FrameType::PUSH_PROMISE: + return header.IsPadded(); + default: + return false; + } +} + +bool FrameHasPriority(const Http2FrameHeader& header) { + switch (header.type) { + case Http2FrameType::HEADERS: + return header.HasPriority(); + case Http2FrameType::PRIORITY: + return true; + default: + return false; + } +} + +bool FrameCanHavePayload(const Http2FrameHeader& header) { + switch (header.type) { + case Http2FrameType::DATA: + case Http2FrameType::HEADERS: + case Http2FrameType::PUSH_PROMISE: + case Http2FrameType::CONTINUATION: + case Http2FrameType::PING: + case Http2FrameType::GOAWAY: + case Http2FrameType::ALTSVC: + return true; + default: + return false; + } +} + +bool FrameCanHaveHpackPayload(const Http2FrameHeader& header) { + switch (header.type) { + case Http2FrameType::HEADERS: + case Http2FrameType::PUSH_PROMISE: + case Http2FrameType::CONTINUATION: + return true; + default: + return false; + } +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/http2_structures_test_util.h b/quiche/http2/test_tools/http2_structures_test_util.h new file mode 100644 index 000000000000..5a5353e8b3a9 --- /dev/null +++ b/quiche/http2/test_tools/http2_structures_test_util.h @@ -0,0 +1,61 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_HTTP2_STRUCTURES_TEST_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_HTTP2_STRUCTURES_TEST_UTIL_H_ + +#include + +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +template +std::string SerializeStructure(const S& s) { + Http2FrameBuilder fb; + fb.Append(s); + EXPECT_EQ(S::EncodedSize(), fb.size()); + return fb.buffer(); +} + +// Randomize the members of out, in a manner that yields encodeable contents +// (e.g. a "uint24" field has only the low 24 bits set). +void Randomize(Http2FrameHeader* out, Http2Random* rng); +void Randomize(Http2PriorityFields* out, Http2Random* rng); +void Randomize(Http2RstStreamFields* out, Http2Random* rng); +void Randomize(Http2SettingFields* out, Http2Random* rng); +void Randomize(Http2PushPromiseFields* out, Http2Random* rng); +void Randomize(Http2PingFields* out, Http2Random* rng); +void Randomize(Http2GoAwayFields* out, Http2Random* rng); +void Randomize(Http2WindowUpdateFields* out, Http2Random* rng); +void Randomize(Http2AltSvcFields* out, Http2Random* rng); +void Randomize(Http2PriorityUpdateFields* out, Http2Random* rng); + +// Clear bits of header->flags that are known to be invalid for the +// type. For unknown frame types, no change is made. +void ScrubFlagsOfHeader(Http2FrameHeader* header); + +// Is the frame with this header padded? Only true for known/supported frame +// types. +bool FrameIsPadded(const Http2FrameHeader& header); + +// Does the frame with this header have Http2PriorityFields? +bool FrameHasPriority(const Http2FrameHeader& header); + +// Does the frame with this header have a variable length payload (including +// empty) payload (e.g. DATA or HEADERS)? Really a test of the frame type. +bool FrameCanHavePayload(const Http2FrameHeader& header); + +// Does the frame with this header have a variable length HPACK payload +// (including empty) payload (e.g. HEADERS)? Really a test of the frame type. +bool FrameCanHaveHpackPayload(const Http2FrameHeader& header); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_HTTP2_STRUCTURES_TEST_UTIL_H_ diff --git a/quiche/http2/test_tools/payload_decoder_base_test_util.cc b/quiche/http2/test_tools/payload_decoder_base_test_util.cc new file mode 100644 index 000000000000..702c72535e39 --- /dev/null +++ b/quiche/http2/test_tools/payload_decoder_base_test_util.cc @@ -0,0 +1,97 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/payload_decoder_base_test_util.h" + +#include "quiche/http2/test_tools/frame_decoder_state_test_util.h" +#include "quiche/http2/test_tools/http2_structures_test_util.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +PayloadDecoderBaseTest::PayloadDecoderBaseTest() { + // If the test adds more data after the frame payload, + // stop as soon as the payload is decoded. + stop_decode_on_done_ = true; + frame_header_is_set_ = false; + Randomize(&frame_header_, RandomPtr()); +} + +DecodeStatus PayloadDecoderBaseTest::StartDecoding(DecodeBuffer* db) { + QUICHE_DVLOG(2) << "StartDecoding, db->Remaining=" << db->Remaining(); + // Make sure sub-class has set frame_header_ so that we can inject it + // into the payload decoder below. + if (!frame_header_is_set_) { + ADD_FAILURE() << "frame_header_ is not set"; + return DecodeStatus::kDecodeError; + } + // The contract with the payload decoders is that they won't receive a + // decode buffer that extends beyond the end of the frame. + if (db->Remaining() > frame_header_.payload_length) { + ADD_FAILURE() << "DecodeBuffer has too much data: " << db->Remaining() + << " > " << frame_header_.payload_length; + return DecodeStatus::kDecodeError; + } + + // Prepare the payload decoder. + PreparePayloadDecoder(); + + // Reconstruct the FrameDecoderState, prepare the listener, and add it to + // the FrameDecoderState. + frame_decoder_state_ = std::make_unique(); + frame_decoder_state_->set_listener(PrepareListener()); + + // Make sure that a listener was provided. + if (frame_decoder_state_->listener() == nullptr) { + ADD_FAILURE() << "PrepareListener must return a listener."; + return DecodeStatus::kDecodeError; + } + + // Now that nothing in the payload decoder should be valid, inject the + // Http2FrameHeader whose payload we're about to decode. That header is the + // only state that a payload decoder should expect is valid when its Start + // method is called. + FrameDecoderStatePeer::set_frame_header(frame_header_, + frame_decoder_state_.get()); + DecodeStatus status = StartDecodingPayload(db); + if (status != DecodeStatus::kDecodeInProgress) { + // Keep track of this so that a concrete test can verify that both fast + // and slow decoding paths have been tested. + ++fast_decode_count_; + } + return status; +} + +DecodeStatus PayloadDecoderBaseTest::ResumeDecoding(DecodeBuffer* db) { + QUICHE_DVLOG(2) << "ResumeDecoding, db->Remaining=" << db->Remaining(); + DecodeStatus status = ResumeDecodingPayload(db); + if (status != DecodeStatus::kDecodeInProgress) { + // Keep track of this so that a concrete test can verify that both fast + // and slow decoding paths have been tested. + ++slow_decode_count_; + } + return status; +} + +::testing::AssertionResult +PayloadDecoderBaseTest::DecodePayloadAndValidateSeveralWays( + absl::string_view payload, Validator validator) { + HTTP2_VERIFY_TRUE(frame_header_is_set_); + // Cap the payload to be decoded at the declared payload length. This is + // required by the decoders' preconditions; they are designed on the + // assumption that they're never passed more than they're permitted to + // consume. + // Note that it is OK if the payload is too short; the validator may be + // designed to check for that. + if (payload.size() > frame_header_.payload_length) { + payload = absl::string_view(payload.data(), frame_header_.payload_length); + } + DecodeBuffer db(payload); + ResetDecodeSpeedCounters(); + const bool kMayReturnZeroOnFirst = false; + return DecodeAndValidateSeveralWays(&db, kMayReturnZeroOnFirst, validator); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/payload_decoder_base_test_util.h b/quiche/http2/test_tools/payload_decoder_base_test_util.h new file mode 100644 index 000000000000..375fee694e15 --- /dev/null +++ b/quiche/http2/test_tools/payload_decoder_base_test_util.h @@ -0,0 +1,444 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_PAYLOAD_DECODER_BASE_TEST_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_PAYLOAD_DECODER_BASE_TEST_UTIL_H_ + +// Base class for testing concrete payload decoder classes. + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/frame_decoder_state.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/http2/test_tools/frame_parts.h" +#include "quiche/http2/test_tools/http2_constants_test_util.h" +#include "quiche/http2/test_tools/http2_frame_builder.h" +#include "quiche/http2/test_tools/random_decoder_test_base.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace test { + +// Base class for tests of payload decoders. Below this there is a templated +// sub-class that adds a bunch of type specific features. +class QUICHE_NO_EXPORT PayloadDecoderBaseTest : public RandomDecoderTest { + protected: + PayloadDecoderBaseTest(); + + // Virtual functions to be implemented by the test classes for the individual + // payload decoders... + + // Start decoding the payload. + virtual DecodeStatus StartDecodingPayload(DecodeBuffer* db) = 0; + + // Resume decoding the payload. + virtual DecodeStatus ResumeDecodingPayload(DecodeBuffer* db) = 0; + + // In support of ensuring that we're really accessing and updating the + // decoder, prepare the decoder by, for example, overwriting the decoder. + virtual void PreparePayloadDecoder() = 0; + + // Get the listener to be inserted into the FrameDecoderState, ready for + // listening (e.g. reset if it is a FramePartsCollector). + virtual Http2FrameDecoderListener* PrepareListener() = 0; + + // Record a frame header for use on each call to StartDecoding. + void set_frame_header(const Http2FrameHeader& header) { + EXPECT_EQ(0, InvalidFlagMaskForFrameType(header.type) & header.flags); + if (!frame_header_is_set_ || frame_header_ != header) { + QUICHE_VLOG(2) << "set_frame_header: " << frame_header_; + } + frame_header_ = header; + frame_header_is_set_ = true; + } + + FrameDecoderState* mutable_state() { return frame_decoder_state_.get(); } + + // Randomize the payload decoder, sets the payload decoder's frame_header_, + // then start decoding the payload. Called by RandomDecoderTest. This method + // is final so that we can always perform certain actions when + // RandomDecoderTest starts the decoding of a payload, such as randomizing the + // the payload decoder, injecting the frame header and counting fast decoding + // cases. Sub-classes must implement StartDecodingPayload to perform their + // initial decoding of a frame's payload. + DecodeStatus StartDecoding(DecodeBuffer* db) final; + + // Called by RandomDecoderTest. This method is final so that we can always + // perform certain actions when RandomDecoderTest calls it, such as counting + // slow decode cases. Sub-classes must implement ResumeDecodingPayload to + // continue decoding the frame's payload, which must not all be in one buffer. + DecodeStatus ResumeDecoding(DecodeBuffer* db) final; + + // Given the specified payload (without the common frame header), decode + // it with several partitionings of the payload. + ::testing::AssertionResult DecodePayloadAndValidateSeveralWays( + absl::string_view payload, Validator validator); + + // TODO(jamessynge): Add helper method for verifying these are both non-zero, + // and call the new method from tests that expect successful decoding. + void ResetDecodeSpeedCounters() { + fast_decode_count_ = 0; + slow_decode_count_ = 0; + } + + // Count of payloads that are full decoded by StartDecodingPayload, or that + // an error was detected by StartDecodingPayload. + size_t fast_decode_count_ = 0; + + // Count of payloads that require calling ResumeDecodingPayload in order to + // decode them completely (or to detect an error during decoding). + size_t slow_decode_count_ = 0; + + private: + bool frame_header_is_set_ = false; + Http2FrameHeader frame_header_; + std::unique_ptr frame_decoder_state_; +}; + +// Base class for payload decoders of type Decoder, with corresponding test +// peer of type DecoderPeer, and using class Listener as the implementation +// of Http2FrameDecoderListenerInterface to be used during decoding. +// Typically Listener is a sub-class of FramePartsCollector. +// SupportedFrameType is set to false only for UnknownPayloadDecoder. +template +class QUICHE_NO_EXPORT AbstractPayloadDecoderTest + : public PayloadDecoderBaseTest { + protected: + // An ApproveSize function returns true to approve decoding the specified + // size of payload, else false to skip that size. Typically used for negative + // tests; for example, decoding a SETTINGS frame at all sizes except for + // multiples of 6. + typedef std::function ApproveSize; + + AbstractPayloadDecoderTest() {} + + // These tests are in setup rather than the constructor for two reasons: + // 1) Constructors are not allowed to fail, so gUnit documents that EXPECT_* + // and ASSERT_* are not allowed in constructors, and should instead be in + // SetUp if they are needed before the body of the test is executed. + // 2) To allow the sub-class constructor to make any desired modifications to + // the DecoderPeer before these tests are executed; in particular, + // UnknownPayloadDecoderPeer has not got a fixed frame type, but it is + // instead set during the test's constructor. + void SetUp() override { + PayloadDecoderBaseTest::SetUp(); + + // Confirm that DecoderPeer et al returns sensible values. Using auto as the + // variable type so that no (narrowing) conversions take place that hide + // problems; i.e. if someone changes KnownFlagsMaskForFrameType so that it + // doesn't return a uint8, and has bits above the low-order 8 bits set, this + // bit of paranoia should detect the problem before we get too far. + auto frame_type = DecoderPeer::FrameType(); + if (SupportedFrameType) { + EXPECT_TRUE(IsSupportedHttp2FrameType(frame_type)) << frame_type; + } else { + EXPECT_FALSE(IsSupportedHttp2FrameType(frame_type)) << frame_type; + } + + auto known_flags = KnownFlagsMaskForFrameType(frame_type); + EXPECT_EQ(known_flags, known_flags & 0xff); + + auto flags_to_avoid = DecoderPeer::FlagsAffectingPayloadDecoding(); + EXPECT_EQ(flags_to_avoid, flags_to_avoid & known_flags); + } + + void PreparePayloadDecoder() override { + payload_decoder_ = std::make_unique(); + } + + Http2FrameDecoderListener* PrepareListener() override { + listener_.Reset(); + return &listener_; + } + + // Returns random flags, but only those valid for the frame type, yet not + // those that the DecoderPeer says will affect the decoding of the payload + // (e.g. the PRIORTY flag on a HEADERS frame or PADDED on DATA frames). + uint8_t RandFlags() { + return Random().Rand8() & + KnownFlagsMaskForFrameType(DecoderPeer::FrameType()) & + ~DecoderPeer::FlagsAffectingPayloadDecoding(); + } + + // Start decoding the payload. + DecodeStatus StartDecodingPayload(DecodeBuffer* db) override { + QUICHE_DVLOG(2) << "StartDecodingPayload, db->Remaining=" + << db->Remaining(); + return payload_decoder_->StartDecodingPayload(mutable_state(), db); + } + + // Resume decoding the payload. + DecodeStatus ResumeDecodingPayload(DecodeBuffer* db) override { + QUICHE_DVLOG(2) << "ResumeDecodingPayload, db->Remaining=" + << db->Remaining(); + return payload_decoder_->ResumeDecodingPayload(mutable_state(), db); + } + + // Decode one frame's payload and confirm that the listener recorded the + // expected FrameParts instance, and only FrameParts instance. The payload + // will be decoded several times with different partitionings of the payload, + // and after each the validator will be called. + AssertionResult DecodePayloadAndValidateSeveralWays( + absl::string_view payload, const FrameParts& expected) { + auto validator = [&expected, this]() -> AssertionResult { + HTTP2_VERIFY_FALSE(listener_.IsInProgress()); + HTTP2_VERIFY_EQ(1u, listener_.size()); + return expected.VerifyEquals(*listener_.frame(0)); + }; + return PayloadDecoderBaseTest::DecodePayloadAndValidateSeveralWays( + payload, ValidateDoneAndEmpty(validator)); + } + + // Decode one frame's payload, expecting that the final status will be + // kDecodeError, and that OnFrameSizeError will have been called on the + // listener. The payload will be decoded several times with different + // partitionings of the payload. The type WrappedValidator is either + // RandomDecoderTest::Validator, RandomDecoderTest::NoArgValidator or + // std::nullptr_t (not extra validation). + template + ::testing::AssertionResult VerifyDetectsFrameSizeError( + absl::string_view payload, const Http2FrameHeader& header, + WrappedValidator wrapped_validator) { + set_frame_header(header); + // If wrapped_validator is not a RandomDecoderTest::Validator, make it so. + Validator validator = ToValidator(wrapped_validator); + // And wrap that validator in another which will check that we've reached + // the expected state of kDecodeError with OnFrameSizeError having been + // called by the payload decoder. + validator = [header, validator, this]( + const DecodeBuffer& input, + DecodeStatus status) -> ::testing::AssertionResult { + QUICHE_DVLOG(2) << "VerifyDetectsFrameSizeError validator; status=" + << status << "; input.Remaining=" << input.Remaining(); + HTTP2_VERIFY_EQ(DecodeStatus::kDecodeError, status); + HTTP2_VERIFY_FALSE(listener_.IsInProgress()); + HTTP2_VERIFY_EQ(1u, listener_.size()); + const FrameParts* frame = listener_.frame(0); + HTTP2_VERIFY_EQ(header, frame->GetFrameHeader()); + HTTP2_VERIFY_TRUE(frame->GetHasFrameSizeError()); + // Verify did not get OnPaddingTooLong, as we should only ever produce + // one of these two errors for a single frame. + HTTP2_VERIFY_FALSE(frame->GetOptMissingLength()); + return validator(input, status); + }; + return PayloadDecoderBaseTest::DecodePayloadAndValidateSeveralWays( + payload, validator); + } + + // Confirm that we get OnFrameSizeError when trying to decode unpadded_payload + // at all sizes from zero to unpadded_payload.size(), except those sizes not + // approved by approve_size. + // If total_pad_length is greater than zero, then that amount of padding + // is added to the payload (including the Pad Length field). + // The flags will be required_flags, PADDED if total_pad_length > 0, and some + // randomly selected flag bits not excluded by FlagsAffectingPayloadDecoding. + ::testing::AssertionResult VerifyDetectsMultipleFrameSizeErrors( + uint8_t required_flags, absl::string_view unpadded_payload, + ApproveSize approve_size, int total_pad_length) { + // required_flags should come from those that are defined for the frame + // type AND are those that affect the decoding of the payload (otherwise, + // the flag shouldn't be required). + Http2FrameType frame_type = DecoderPeer::FrameType(); + HTTP2_VERIFY_EQ(required_flags, + required_flags & KnownFlagsMaskForFrameType(frame_type)); + HTTP2_VERIFY_EQ( + required_flags, + required_flags & DecoderPeer::FlagsAffectingPayloadDecoding()); + + if (0 != + (Http2FrameFlag::PADDED & KnownFlagsMaskForFrameType(frame_type))) { + // Frame type supports padding. + if (total_pad_length == 0) { + required_flags &= ~Http2FrameFlag::PADDED; + } else { + required_flags |= Http2FrameFlag::PADDED; + } + } else { + HTTP2_VERIFY_EQ(0, total_pad_length); + } + + bool validated = false; + for (size_t real_payload_size = 0; + real_payload_size <= unpadded_payload.size(); ++real_payload_size) { + if (approve_size != nullptr && !approve_size(real_payload_size)) { + continue; + } + QUICHE_VLOG(1) << "real_payload_size=" << real_payload_size; + uint8_t flags = required_flags | RandFlags(); + Http2FrameBuilder fb; + if (total_pad_length > 0) { + // total_pad_length_ includes the size of the Pad Length field, and thus + // ranges from 0 (no PADDED flag) to 256 (Pad Length == 255). + fb.AppendUInt8(total_pad_length - 1); + } + // Append a subset of the unpadded_payload, which the decoder should + // determine is not a valid amount. + fb.Append(unpadded_payload.substr(0, real_payload_size)); + if (total_pad_length > 0) { + fb.AppendZeroes(total_pad_length - 1); + } + // We choose a random stream id because the payload decoders aren't + // checking stream ids. + uint32_t stream_id = RandStreamId(); + Http2FrameHeader header(fb.size(), frame_type, flags, stream_id); + HTTP2_VERIFY_SUCCESS( + VerifyDetectsFrameSizeError(fb.buffer(), header, nullptr)); + validated = true; + } + HTTP2_VERIFY_TRUE(validated); + return ::testing::AssertionSuccess(); + } + + // As above, but for frames without padding. + ::testing::AssertionResult VerifyDetectsFrameSizeError( + uint8_t required_flags, absl::string_view unpadded_payload, + const ApproveSize& approve_size) { + Http2FrameType frame_type = DecoderPeer::FrameType(); + uint8_t known_flags = KnownFlagsMaskForFrameType(frame_type); + HTTP2_VERIFY_EQ(0, known_flags & Http2FrameFlag::PADDED); + HTTP2_VERIFY_EQ(0, required_flags & Http2FrameFlag::PADDED); + return VerifyDetectsMultipleFrameSizeErrors( + required_flags, unpadded_payload, approve_size, 0); + } + + Listener listener_; + std::unique_ptr payload_decoder_; +}; + +// A base class for tests parameterized by the total number of bytes of +// padding, including the Pad Length field (i.e. a total_pad_length of 0 +// means unpadded as there is then no room for the Pad Length field). +// The frame type must support padding. +template +class QUICHE_NO_EXPORT AbstractPaddablePayloadDecoderTest + : public AbstractPayloadDecoderTest, + public ::testing::WithParamInterface { + typedef AbstractPayloadDecoderTest Base; + + protected: + using Base::listener_; + using Base::Random; + using Base::RandStreamId; + using Base::set_frame_header; + typedef typename Base::Validator Validator; + + AbstractPaddablePayloadDecoderTest() : total_pad_length_(GetParam()) { + QUICHE_LOG(INFO) << "total_pad_length_ = " << total_pad_length_; + } + + // Note that total_pad_length_ includes the size of the Pad Length field, + // and thus ranges from 0 (no PADDED flag) to 256 (Pad Length == 255). + bool IsPadded() const { return total_pad_length_ > 0; } + + // Value of the Pad Length field. Only call if IsPadded. + size_t pad_length() const { + EXPECT_TRUE(IsPadded()); + return total_pad_length_ - 1; + } + + // Clear the frame builder and add the Pad Length field if appropriate. + void Reset() { + frame_builder_ = Http2FrameBuilder(); + if (IsPadded()) { + frame_builder_.AppendUInt8(pad_length()); + } + } + + void MaybeAppendTrailingPadding() { + if (IsPadded()) { + frame_builder_.AppendZeroes(pad_length()); + } + } + + uint8_t RandFlags() { + uint8_t flags = Base::RandFlags(); + if (IsPadded()) { + flags |= Http2FrameFlag::PADDED; + } else { + flags &= ~Http2FrameFlag::PADDED; + } + return flags; + } + + // Verify that we get OnPaddingTooLong when decoding payload, and that the + // amount of missing padding is as specified. header.IsPadded must be true, + // and the payload must be empty or the PadLength field must be too large. + ::testing::AssertionResult VerifyDetectsPaddingTooLong( + absl::string_view payload, const Http2FrameHeader& header, + size_t expected_missing_length) { + set_frame_header(header); + auto& listener = listener_; + Validator validator = + [header, expected_missing_length, &listener]( + const DecodeBuffer&, + DecodeStatus status) -> ::testing::AssertionResult { + HTTP2_VERIFY_EQ(DecodeStatus::kDecodeError, status); + HTTP2_VERIFY_FALSE(listener.IsInProgress()); + HTTP2_VERIFY_EQ(1u, listener.size()); + const FrameParts* frame = listener.frame(0); + HTTP2_VERIFY_EQ(header, frame->GetFrameHeader()); + HTTP2_VERIFY_TRUE(frame->GetOptMissingLength()); + HTTP2_VERIFY_EQ(expected_missing_length, + frame->GetOptMissingLength().value()); + // Verify did not get OnFrameSizeError. + HTTP2_VERIFY_FALSE(frame->GetHasFrameSizeError()); + return ::testing::AssertionSuccess(); + }; + return PayloadDecoderBaseTest::DecodePayloadAndValidateSeveralWays( + payload, validator); + } + + // Verifies that we get OnPaddingTooLong for a padded frame payload whose + // (randomly selected) payload length is less than total_pad_length_. + // Flags will be selected at random, except PADDED will be set and + // flags_to_avoid will not be set. The stream id is selected at random. + ::testing::AssertionResult VerifyDetectsPaddingTooLong() { + uint8_t flags = RandFlags() | Http2FrameFlag::PADDED; + + // Create an all padding payload for total_pad_length_. + int payload_length = 0; + Http2FrameBuilder fb; + if (IsPadded()) { + fb.AppendUInt8(pad_length()); + fb.AppendZeroes(pad_length()); + QUICHE_VLOG(1) << "fb.size=" << fb.size(); + // Pick a random length for the payload that is shorter than neccesary. + payload_length = Random().Uniform(fb.size()); + } + + QUICHE_VLOG(1) << "payload_length=" << payload_length; + std::string payload = fb.buffer().substr(0, payload_length); + + // The missing length is the amount we cut off the end, unless + // payload_length is zero, in which case the decoder knows only that 1 + // byte, the Pad Length field, is missing. + size_t missing_length = + payload_length == 0 ? 1 : fb.size() - payload_length; + QUICHE_VLOG(1) << "missing_length=" << missing_length; + + const Http2FrameHeader header(payload_length, DecoderPeer::FrameType(), + flags, RandStreamId()); + return VerifyDetectsPaddingTooLong(payload, header, missing_length); + } + + // total_pad_length_ includes the size of the Pad Length field, and thus + // ranges from 0 (no PADDED flag) to 256 (Pad Length == 255). + const size_t total_pad_length_; + Http2FrameBuilder frame_builder_; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_PAYLOAD_DECODER_BASE_TEST_UTIL_H_ diff --git a/quiche/http2/test_tools/random_decoder_test_base.cc b/quiche/http2/test_tools/random_decoder_test_base.cc new file mode 100644 index 000000000000..b3439e9bb23e --- /dev/null +++ b/quiche/http2/test_tools/random_decoder_test_base.cc @@ -0,0 +1,167 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/random_decoder_test_base.h" + +#include + +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +using ::testing::AssertionResult; + +namespace http2 { +namespace test { + +RandomDecoderTest::RandomDecoderTest() = default; + +bool RandomDecoderTest::StopDecodeOnDone() { return stop_decode_on_done_; } + +DecodeStatus RandomDecoderTest::DecodeSegments(DecodeBuffer* original, + const SelectSize& select_size) { + DecodeStatus status = DecodeStatus::kDecodeInProgress; + bool first = true; + QUICHE_VLOG(2) << "DecodeSegments: input size=" << original->Remaining(); + while (first || original->HasData()) { + size_t remaining = original->Remaining(); + size_t size = + std::min(remaining, select_size(first, original->Offset(), remaining)); + DecodeBuffer db(original->cursor(), size); + QUICHE_VLOG(2) << "Decoding " << size << " bytes of " << remaining + << " remaining"; + if (first) { + first = false; + status = StartDecoding(&db); + } else { + status = ResumeDecoding(&db); + } + // A decoder MUST consume some input (if any is available), else we could + // get stuck in infinite loops. + if (db.Offset() == 0 && db.HasData() && + status != DecodeStatus::kDecodeError) { + ADD_FAILURE() << "Decoder didn't make any progress; db.FullSize=" + << db.FullSize() + << " original.Offset=" << original->Offset(); + return DecodeStatus::kDecodeError; + } + original->AdvanceCursor(db.Offset()); + switch (status) { + case DecodeStatus::kDecodeDone: + if (original->Empty() || StopDecodeOnDone()) { + return DecodeStatus::kDecodeDone; + } + continue; + case DecodeStatus::kDecodeInProgress: + continue; + case DecodeStatus::kDecodeError: + return DecodeStatus::kDecodeError; + } + } + return status; +} + +// Decode |original| multiple times, with different segmentations, validating +// after each decode, returning on the first failure. +AssertionResult RandomDecoderTest::DecodeAndValidateSeveralWays( + DecodeBuffer* original, bool return_non_zero_on_first, + const Validator& validator) { + const uint32_t original_remaining = original->Remaining(); + QUICHE_VLOG(1) << "DecodeAndValidateSeveralWays - Start, remaining = " + << original_remaining; + uint32_t first_consumed; + { + // Fast decode (no stopping unless decoder does so). + DecodeBuffer input(original->cursor(), original_remaining); + QUICHE_VLOG(2) << "DecodeSegmentsAndValidate with SelectRemaining"; + HTTP2_VERIFY_SUCCESS( + DecodeSegmentsAndValidate(&input, SelectRemaining(), validator)) + << "\nFailed with SelectRemaining; input.Offset=" << input.Offset() + << "; input.Remaining=" << input.Remaining(); + first_consumed = input.Offset(); + } + if (original_remaining <= 30) { + // Decode again, one byte at a time. + DecodeBuffer input(original->cursor(), original_remaining); + QUICHE_VLOG(2) << "DecodeSegmentsAndValidate with SelectOne"; + HTTP2_VERIFY_SUCCESS( + DecodeSegmentsAndValidate(&input, SelectOne(), validator)) + << "\nFailed with SelectOne; input.Offset=" << input.Offset() + << "; input.Remaining=" << input.Remaining(); + HTTP2_VERIFY_EQ(first_consumed, input.Offset()) + << "\nFailed with SelectOne"; + } + if (original_remaining <= 20) { + // Decode again, one or zero bytes at a time. + DecodeBuffer input(original->cursor(), original_remaining); + QUICHE_VLOG(2) << "DecodeSegmentsAndValidate with SelectZeroAndOne"; + HTTP2_VERIFY_SUCCESS(DecodeSegmentsAndValidate( + &input, SelectZeroAndOne(return_non_zero_on_first), validator)) + << "\nFailed with SelectZeroAndOne"; + HTTP2_VERIFY_EQ(first_consumed, input.Offset()) + << "\nFailed with SelectZeroAndOne; input.Offset=" << input.Offset() + << "; input.Remaining=" << input.Remaining(); + } + { + // Decode again, with randomly selected segment sizes. + DecodeBuffer input(original->cursor(), original_remaining); + QUICHE_VLOG(2) << "DecodeSegmentsAndValidate with SelectRandom"; + HTTP2_VERIFY_SUCCESS(DecodeSegmentsAndValidate( + &input, SelectRandom(return_non_zero_on_first), validator)) + << "\nFailed with SelectRandom; input.Offset=" << input.Offset() + << "; input.Remaining=" << input.Remaining(); + HTTP2_VERIFY_EQ(first_consumed, input.Offset()) + << "\nFailed with SelectRandom"; + } + HTTP2_VERIFY_EQ(original_remaining, original->Remaining()); + original->AdvanceCursor(first_consumed); + QUICHE_VLOG(1) << "DecodeAndValidateSeveralWays - SUCCESS"; + return ::testing::AssertionSuccess(); +} + +// static +RandomDecoderTest::SelectSize RandomDecoderTest::SelectZeroAndOne( + bool return_non_zero_on_first) { + std::shared_ptr zero_next(new bool); + *zero_next = !return_non_zero_on_first; + return [zero_next](bool /*first*/, size_t /*offset*/, + size_t /*remaining*/) -> size_t { + if (*zero_next) { + *zero_next = false; + return 0; + } else { + *zero_next = true; + return 1; + } + }; +} + +RandomDecoderTest::SelectSize RandomDecoderTest::SelectRandom( + bool return_non_zero_on_first) { + return [this, return_non_zero_on_first](bool first, size_t /*offset*/, + size_t remaining) -> size_t { + uint32_t r = random_.Rand32(); + if (first && return_non_zero_on_first) { + QUICHE_CHECK_LT(0u, remaining); + if (remaining == 1) { + return 1; + } + return 1 + (r % remaining); // size in range [1, remaining). + } + return r % (remaining + 1); // size in range [0, remaining]. + }; +} + +uint32_t RandomDecoderTest::RandStreamId() { + return random_.Rand32() & StreamIdMask(); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/random_decoder_test_base.h b/quiche/http2/test_tools/random_decoder_test_base.h new file mode 100644 index 000000000000..a01ea8de18df --- /dev/null +++ b/quiche/http2/test_tools/random_decoder_test_base.h @@ -0,0 +1,255 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_RANDOM_DECODER_TEST_BASE_H_ +#define QUICHE_HTTP2_TEST_TOOLS_RANDOM_DECODER_TEST_BASE_H_ + +// RandomDecoderTest is a base class for tests of decoding various kinds +// of HTTP/2 and HPACK encodings. + +// TODO(jamessynge): Move more methods into .cc file. + +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/http2/test_tools/verify_macros.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { + +// Some helpers. + +template +absl::string_view ToStringPiece(T (&data)[N]) { + return absl::string_view(reinterpret_cast(data), N * sizeof(T)); +} + +// Overwrite the enum with some random value, probably not a valid value for +// the enum type, but which fits into its storage. +template ::value>::type> +void CorruptEnum(T* out, Http2Random* rng) { + // Per cppreference.com, if the destination type of a static_cast is + // smaller than the source type (i.e. type of r and uint32 below), the + // resulting value is the smallest unsigned value equal to the source value + // modulo 2^n, where n is the number of bits used to represent the + // destination type unsigned U. + using underlying_type_T = typename std::underlying_type::type; + using unsigned_underlying_type_T = + typename std::make_unsigned::type; + auto r = static_cast(rng->Rand32()); + *out = static_cast(r); +} + +// Base class for tests of the ability to decode a sequence of bytes with +// various boundaries between the DecodeBuffers provided to the decoder. +class QUICHE_NO_EXPORT RandomDecoderTest : public quiche::test::QuicheTest { + public: + // SelectSize returns the size of the next DecodeBuffer to be passed to the + // decoder. Note that RandomDecoderTest allows that size to be zero, though + // some decoders can't deal with that on the first byte, hence the |first| + // parameter. + typedef std::function + SelectSize; + + // Validator returns an AssertionResult so test can do: + // EXPECT_THAT(DecodeAndValidate(..., validator)); + typedef ::testing::AssertionResult AssertionResult; + typedef std::function + Validator; + typedef std::function NoArgValidator; + + RandomDecoderTest(); + + protected: + // Start decoding; call allows sub-class to Reset the decoder, or deal with + // the first byte if that is done in a unique fashion. Might be called with + // a zero byte buffer. + virtual DecodeStatus StartDecoding(DecodeBuffer* db) = 0; + + // Resume decoding of the input after a prior call to StartDecoding, and + // possibly many calls to ResumeDecoding. + virtual DecodeStatus ResumeDecoding(DecodeBuffer* db) = 0; + + // Return true if a decode status of kDecodeDone indicates that + // decoding should stop. + virtual bool StopDecodeOnDone(); + + // Decode buffer |original| until we run out of input, or kDecodeDone is + // returned by the decoder AND StopDecodeOnDone() returns true. Segments + // (i.e. cuts up) the original DecodeBuffer into (potentially) smaller buffers + // by calling |select_size| to decide how large each buffer should be. + // We do this to test the ability to deal with arbitrary boundaries, as might + // happen in transport. + // Returns the final DecodeStatus. + DecodeStatus DecodeSegments(DecodeBuffer* original, + const SelectSize& select_size); + + // Decode buffer |original| until we run out of input, or kDecodeDone is + // returned by the decoder AND StopDecodeOnDone() returns true. Segments + // (i.e. cuts up) the original DecodeBuffer into (potentially) smaller buffers + // by calling |select_size| to decide how large each buffer should be. + // We do this to test the ability to deal with arbitrary boundaries, as might + // happen in transport. + // Invokes |validator| with the final decode status and the original decode + // buffer, with the cursor advanced as far as has been consumed by the decoder + // and returns validator's result. + ::testing::AssertionResult DecodeSegmentsAndValidate( + DecodeBuffer* original, const SelectSize& select_size, + const Validator& validator) { + DecodeStatus status = DecodeSegments(original, select_size); + return validator(*original, status); + } + + // Returns a SelectSize function for fast decoding, i.e. passing all that + // is available to the decoder. + static SelectSize SelectRemaining() { + return [](bool /*first*/, size_t /*offset*/, size_t remaining) -> size_t { + return remaining; + }; + } + + // Returns a SelectSize function for decoding a single byte at a time. + static SelectSize SelectOne() { + return [](bool /*first*/, size_t /*offset*/, + size_t /*remaining*/) -> size_t { return 1; }; + } + + // Returns a SelectSize function for decoding a single byte at a time, where + // zero byte buffers are also allowed. Alternates between zero and one. + static SelectSize SelectZeroAndOne(bool return_non_zero_on_first); + + // Returns a SelectSize function for decoding random sized segments. + SelectSize SelectRandom(bool return_non_zero_on_first); + + // Decode |original| multiple times, with different segmentations of the + // decode buffer, validating after each decode, and confirming that they + // each decode the same amount. Returns on the first failure, else returns + // success. + AssertionResult DecodeAndValidateSeveralWays(DecodeBuffer* original, + bool return_non_zero_on_first, + const Validator& validator); + + static Validator ToValidator(std::nullptr_t) { + return [](const DecodeBuffer& /*input*/, DecodeStatus /*status*/) { + return ::testing::AssertionSuccess(); + }; + } + + static Validator ToValidator(const Validator& validator) { + if (validator == nullptr) { + return ToValidator(nullptr); + } + return validator; + } + + static Validator ToValidator(const NoArgValidator& validator) { + if (validator == nullptr) { + return ToValidator(nullptr); + } + return [validator](const DecodeBuffer& /*input*/, DecodeStatus /*status*/) { + return validator(); + }; + } + + // Wraps a validator with another validator + // that first checks that the DecodeStatus is kDecodeDone and + // that the DecodeBuffer is empty. + // TODO(jamessynge): Replace this overload with the next, as using this method + // usually means that the wrapped function doesn't need to be passed the + // DecodeBuffer nor the DecodeStatus. + static Validator ValidateDoneAndEmpty(const Validator& wrapped) { + return [wrapped](const DecodeBuffer& input, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeDone); + HTTP2_VERIFY_EQ(0u, input.Remaining()) << "\nOffset=" << input.Offset(); + if (wrapped) { + return wrapped(input, status); + } + return ::testing::AssertionSuccess(); + }; + } + static Validator ValidateDoneAndEmpty(NoArgValidator wrapped) { + return [wrapped](const DecodeBuffer& input, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeDone); + HTTP2_VERIFY_EQ(0u, input.Remaining()) << "\nOffset=" << input.Offset(); + if (wrapped) { + return wrapped(); + } + return ::testing::AssertionSuccess(); + }; + } + static Validator ValidateDoneAndEmpty() { + NoArgValidator validator; + return ValidateDoneAndEmpty(validator); + } + + // Wraps a validator with another validator + // that first checks that the DecodeStatus is kDecodeDone and + // that the DecodeBuffer has the expected offset. + // TODO(jamessynge): Replace this overload with the next, as using this method + // usually means that the wrapped function doesn't need to be passed the + // DecodeBuffer nor the DecodeStatus. + static Validator ValidateDoneAndOffset(uint32_t offset, + const Validator& wrapped) { + return [wrapped, offset](const DecodeBuffer& input, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeDone); + HTTP2_VERIFY_EQ(offset, input.Offset()) + << "\nRemaining=" << input.Remaining(); + if (wrapped) { + return wrapped(input, status); + } + return ::testing::AssertionSuccess(); + }; + } + static Validator ValidateDoneAndOffset(uint32_t offset, + NoArgValidator wrapped) { + return [wrapped, offset](const DecodeBuffer& input, + DecodeStatus status) -> AssertionResult { + HTTP2_VERIFY_EQ(status, DecodeStatus::kDecodeDone); + HTTP2_VERIFY_EQ(offset, input.Offset()) + << "\nRemaining=" << input.Remaining(); + if (wrapped) { + return wrapped(); + } + return ::testing::AssertionSuccess(); + }; + } + static Validator ValidateDoneAndOffset(uint32_t offset) { + NoArgValidator validator; + return ValidateDoneAndOffset(offset, validator); + } + + // Expose |random_| as Http2Random so callers don't have to care about which + // sub-class of Http2Random is used, nor can they rely on the specific + // sub-class that RandomDecoderTest uses. + Http2Random& Random() { return random_; } + Http2Random* RandomPtr() { return &random_; } + + uint32_t RandStreamId(); + + bool stop_decode_on_done_ = true; + + private: + Http2Random random_; +}; + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_RANDOM_DECODER_TEST_BASE_H_ diff --git a/quiche/http2/test_tools/random_decoder_test_base_test.cc b/quiche/http2/test_tools/random_decoder_test_base_test.cc new file mode 100644 index 000000000000..29f665cb0a05 --- /dev/null +++ b/quiche/http2/test_tools/random_decoder_test_base_test.cc @@ -0,0 +1,327 @@ +#include "quiche/http2/test_tools/random_decoder_test_base.h" + +#include + +#include +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace test { +namespace { +const char kData[]{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; +const bool kReturnNonZeroOnFirst = true; +const bool kMayReturnZeroOnFirst = false; + +// Confirm the behavior of various parts of RandomDecoderTest. +class RandomDecoderTestTest : public RandomDecoderTest { + public: + RandomDecoderTestTest() : data_db_(kData) { + QUICHE_CHECK_EQ(sizeof kData, 8u); + } + + protected: + typedef std::function DecodingFn; + + DecodeStatus StartDecoding(DecodeBuffer* db) override { + ++start_decoding_calls_; + if (start_decoding_fn_) { + return start_decoding_fn_(db); + } + return DecodeStatus::kDecodeError; + } + + DecodeStatus ResumeDecoding(DecodeBuffer* db) override { + ++resume_decoding_calls_; + if (resume_decoding_fn_) { + return resume_decoding_fn_(db); + } + return DecodeStatus::kDecodeError; + } + + bool StopDecodeOnDone() override { + ++stop_decode_on_done_calls_; + if (override_stop_decode_on_done_) { + return sub_stop_decode_on_done_; + } + return RandomDecoderTest::StopDecodeOnDone(); + } + + size_t start_decoding_calls_ = 0; + size_t resume_decoding_calls_ = 0; + size_t stop_decode_on_done_calls_ = 0; + + DecodingFn start_decoding_fn_; + DecodingFn resume_decoding_fn_; + + DecodeBuffer data_db_; + + bool sub_stop_decode_on_done_ = true; + bool override_stop_decode_on_done_ = true; +}; + +// Decode a single byte on the StartDecoding call, then stop. +TEST_F(RandomDecoderTestTest, StopOnStartPartiallyDone) { + start_decoding_fn_ = [this](DecodeBuffer* db) { + EXPECT_EQ(1u, start_decoding_calls_); + // Make sure the correct buffer is being used. + EXPECT_EQ(kData, db->cursor()); + EXPECT_EQ(sizeof kData, db->Remaining()); + db->DecodeUInt8(); + return DecodeStatus::kDecodeDone; + }; + + EXPECT_EQ(DecodeStatus::kDecodeDone, + DecodeSegments(&data_db_, SelectRemaining())); + EXPECT_EQ(1u, data_db_.Offset()); + // StartDecoding should only be called once from each call to DecodeSegments. + EXPECT_EQ(1u, start_decoding_calls_); + EXPECT_EQ(0u, resume_decoding_calls_); + EXPECT_EQ(1u, stop_decode_on_done_calls_); +} + +// Stop decoding upon return from the first ResumeDecoding call. +TEST_F(RandomDecoderTestTest, StopOnResumePartiallyDone) { + start_decoding_fn_ = [this](DecodeBuffer* db) { + EXPECT_EQ(1u, start_decoding_calls_); + db->DecodeUInt8(); + return DecodeStatus::kDecodeInProgress; + }; + resume_decoding_fn_ = [this](DecodeBuffer* db) { + EXPECT_EQ(1u, resume_decoding_calls_); + // Make sure the correct buffer is being used. + EXPECT_EQ(data_db_.cursor(), db->cursor()); + db->DecodeUInt16(); + return DecodeStatus::kDecodeDone; + }; + + // Check that the base class honors it's member variable stop_decode_on_done_. + override_stop_decode_on_done_ = false; + stop_decode_on_done_ = true; + + EXPECT_EQ(DecodeStatus::kDecodeDone, + DecodeSegments(&data_db_, SelectRemaining())); + EXPECT_EQ(3u, data_db_.Offset()); + EXPECT_EQ(1u, start_decoding_calls_); + EXPECT_EQ(1u, resume_decoding_calls_); + EXPECT_EQ(1u, stop_decode_on_done_calls_); +} + +// Decode a random sized chunks, always reporting back kDecodeInProgress. +TEST_F(RandomDecoderTestTest, InProgressWhenEmpty) { + start_decoding_fn_ = [this](DecodeBuffer* db) { + EXPECT_EQ(1u, start_decoding_calls_); + // Consume up to 2 bytes. + if (db->HasData()) { + db->DecodeUInt8(); + if (db->HasData()) { + db->DecodeUInt8(); + } + } + return DecodeStatus::kDecodeInProgress; + }; + resume_decoding_fn_ = [](DecodeBuffer* db) { + // Consume all available bytes. + if (db->HasData()) { + db->AdvanceCursor(db->Remaining()); + } + return DecodeStatus::kDecodeInProgress; + }; + + EXPECT_EQ(DecodeStatus::kDecodeInProgress, + DecodeSegments(&data_db_, SelectRandom(kMayReturnZeroOnFirst))); + EXPECT_TRUE(data_db_.Empty()); + EXPECT_EQ(1u, start_decoding_calls_); + EXPECT_LE(1u, resume_decoding_calls_); + EXPECT_EQ(0u, stop_decode_on_done_calls_); +} + +TEST_F(RandomDecoderTestTest, DoneExactlyAtEnd) { + start_decoding_fn_ = [this](DecodeBuffer* db) { + EXPECT_EQ(1u, start_decoding_calls_); + EXPECT_EQ(1u, db->Remaining()); + EXPECT_EQ(1u, db->FullSize()); + db->DecodeUInt8(); + return DecodeStatus::kDecodeInProgress; + }; + resume_decoding_fn_ = [this](DecodeBuffer* db) { + EXPECT_EQ(1u, db->Remaining()); + EXPECT_EQ(1u, db->FullSize()); + db->DecodeUInt8(); + if (data_db_.Remaining() == 1) { + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; + }; + override_stop_decode_on_done_ = true; + sub_stop_decode_on_done_ = true; + + EXPECT_EQ(DecodeStatus::kDecodeDone, DecodeSegments(&data_db_, SelectOne())); + EXPECT_EQ(0u, data_db_.Remaining()); + EXPECT_EQ(1u, start_decoding_calls_); + EXPECT_EQ((sizeof kData) - 1, resume_decoding_calls_); + // Didn't need to call StopDecodeOnDone because we didn't finish early. + EXPECT_EQ(0u, stop_decode_on_done_calls_); +} + +TEST_F(RandomDecoderTestTest, DecodeSeveralWaysToEnd) { + // Each call to StartDecoding or ResumeDecoding will consume all that is + // available. When all the data has been consumed, returns kDecodeDone. + size_t decoded_since_start = 0; + auto shared_fn = [&decoded_since_start, this](DecodeBuffer* db) { + decoded_since_start += db->Remaining(); + db->AdvanceCursor(db->Remaining()); + EXPECT_EQ(0u, db->Remaining()); + if (decoded_since_start == data_db_.FullSize()) { + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; + }; + + start_decoding_fn_ = [&decoded_since_start, shared_fn](DecodeBuffer* db) { + decoded_since_start = 0; + return shared_fn(db); + }; + resume_decoding_fn_ = shared_fn; + + Validator validator = ValidateDoneAndEmpty(); + + EXPECT_TRUE(DecodeAndValidateSeveralWays(&data_db_, kMayReturnZeroOnFirst, + validator)); + + // We should have reached the end. + EXPECT_EQ(0u, data_db_.Remaining()); + + // We currently have 4 ways of decoding; update this if that changes. + EXPECT_EQ(4u, start_decoding_calls_); + + // Didn't need to call StopDecodeOnDone because we didn't finish early. + EXPECT_EQ(0u, stop_decode_on_done_calls_); +} + +TEST_F(RandomDecoderTestTest, DecodeTwoWaysAndStopEarly) { + // On the second decode, return kDecodeDone before finishing. + size_t decoded_since_start = 0; + auto shared_fn = [&decoded_since_start, this](DecodeBuffer* db) { + uint32_t amount = db->Remaining(); + if (start_decoding_calls_ == 2 && amount > 1) { + amount = 1; + } + decoded_since_start += amount; + db->AdvanceCursor(amount); + if (decoded_since_start == data_db_.FullSize()) { + return DecodeStatus::kDecodeDone; + } + if (decoded_since_start > 1 && start_decoding_calls_ == 2) { + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; + }; + + start_decoding_fn_ = [&decoded_since_start, shared_fn](DecodeBuffer* db) { + decoded_since_start = 0; + return shared_fn(db); + }; + resume_decoding_fn_ = shared_fn; + + // We expect the first and second to succeed, but the second to end at a + // different offset, which DecodeAndValidateSeveralWays should complain about. + Validator validator = [this](const DecodeBuffer& /*input*/, + DecodeStatus status) -> AssertionResult { + if (start_decoding_calls_ <= 2 && status != DecodeStatus::kDecodeDone) { + return ::testing::AssertionFailure() + << "Expected DecodeStatus::kDecodeDone, not " << status; + } + if (start_decoding_calls_ > 2) { + return ::testing::AssertionFailure() + << "How did we get to pass " << start_decoding_calls_; + } + return ::testing::AssertionSuccess(); + }; + + EXPECT_FALSE(DecodeAndValidateSeveralWays(&data_db_, kMayReturnZeroOnFirst, + validator)); + EXPECT_EQ(2u, start_decoding_calls_); + EXPECT_EQ(1u, stop_decode_on_done_calls_); +} + +TEST_F(RandomDecoderTestTest, DecodeThreeWaysAndError) { + // Return kDecodeError from ResumeDecoding on the third decoding pass. + size_t decoded_since_start = 0; + auto shared_fn = [&decoded_since_start, this](DecodeBuffer* db) { + if (start_decoding_calls_ == 3 && decoded_since_start > 0) { + return DecodeStatus::kDecodeError; + } + uint32_t amount = db->Remaining(); + if (start_decoding_calls_ == 3 && amount > 1) { + amount = 1; + } + decoded_since_start += amount; + db->AdvanceCursor(amount); + if (decoded_since_start == data_db_.FullSize()) { + return DecodeStatus::kDecodeDone; + } + return DecodeStatus::kDecodeInProgress; + }; + + start_decoding_fn_ = [&decoded_since_start, shared_fn](DecodeBuffer* db) { + decoded_since_start = 0; + return shared_fn(db); + }; + resume_decoding_fn_ = shared_fn; + + Validator validator = ValidateDoneAndEmpty(); + EXPECT_FALSE(DecodeAndValidateSeveralWays(&data_db_, kReturnNonZeroOnFirst, + validator)); + EXPECT_EQ(3u, start_decoding_calls_); + EXPECT_EQ(0u, stop_decode_on_done_calls_); +} + +// CorruptEnum should produce lots of different values. On the assumption that +// the enum gets at least a byte of storage, we should be able to produce +// 256 distinct values. +TEST(CorruptEnumTest, ManyValues) { + std::set values; + DecodeStatus status; + QUICHE_LOG(INFO) << "sizeof status = " << sizeof status; + Http2Random rng; + for (int ndx = 0; ndx < 256; ++ndx) { + CorruptEnum(&status, &rng); + values.insert(static_cast(status)); + } +} + +// In practice the underlying type is an int, and currently that is 4 bytes. +typedef typename std::underlying_type::type DecodeStatusUT; + +struct CorruptEnumTestStruct { + DecodeStatusUT filler1; + DecodeStatus status; + DecodeStatusUT filler2; +}; + +// CorruptEnum should only overwrite the enum, not any adjacent storage. +TEST(CorruptEnumTest, CorruptsOnlyEnum) { + Http2Random rng; + for (const DecodeStatusUT filler : {DecodeStatusUT(), ~DecodeStatusUT()}) { + QUICHE_LOG(INFO) << "filler=0x" << std::hex << filler; + CorruptEnumTestStruct s; + s.filler1 = filler; + s.filler2 = filler; + for (int ndx = 0; ndx < 256; ++ndx) { + CorruptEnum(&s.status, &rng); + EXPECT_EQ(s.filler1, filler); + EXPECT_EQ(s.filler2, filler); + } + } +} + +} // namespace +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/random_util.cc b/quiche/http2/test_tools/random_util.cc new file mode 100644 index 000000000000..cf171bb6348b --- /dev/null +++ b/quiche/http2/test_tools/random_util.cc @@ -0,0 +1,39 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/http2/test_tools/random_util.h" + +#include + +namespace http2 { +namespace test { + +// Here "word" means something that starts with a lower-case letter, and has +// zero or more additional characters that are numbers or lower-case letters. +std::string GenerateHttp2HeaderName(size_t len, Http2Random* rng) { + absl::string_view alpha_lc = "abcdefghijklmnopqrstuvwxyz"; + // If the name is short, just make it one word. + if (len < 8) { + return rng->RandStringWithAlphabet(len, alpha_lc); + } + // If the name is longer, ensure it starts with a word, and after that may + // have any character in alphanumdash_lc. 4 is arbitrary, could be as low + // as 1. + absl::string_view alphanumdash_lc = "abcdefghijklmnopqrstuvwxyz0123456789-"; + return rng->RandStringWithAlphabet(4, alpha_lc) + + rng->RandStringWithAlphabet(len - 4, alphanumdash_lc); +} + +std::string GenerateWebSafeString(size_t len, Http2Random* rng) { + static const char* kWebsafe64 = + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-_"; + return rng->RandStringWithAlphabet(len, kWebsafe64); +} + +std::string GenerateWebSafeString(size_t lo, size_t hi, Http2Random* rng) { + return GenerateWebSafeString(rng->UniformInRange(lo, hi), rng); +} + +} // namespace test +} // namespace http2 diff --git a/quiche/http2/test_tools/random_util.h b/quiche/http2/test_tools/random_util.h new file mode 100644 index 000000000000..ea4cea3696e9 --- /dev/null +++ b/quiche/http2/test_tools/random_util.h @@ -0,0 +1,30 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_RANDOM_UTIL_H_ +#define QUICHE_HTTP2_TEST_TOOLS_RANDOM_UTIL_H_ + +#include + +#include + +#include "quiche/http2/test_tools/http2_random.h" + +namespace http2 { +namespace test { + +// Generate a string with the allowed character set for HTTP/2 / HPACK header +// names. +std::string GenerateHttp2HeaderName(size_t len, Http2Random* rng); + +// Generate a string with the web-safe string character set of specified len. +std::string GenerateWebSafeString(size_t len, Http2Random* rng); + +// Generate a string with the web-safe string character set of length [lo, hi). +std::string GenerateWebSafeString(size_t lo, size_t hi, Http2Random* rng); + +} // namespace test +} // namespace http2 + +#endif // QUICHE_HTTP2_TEST_TOOLS_RANDOM_UTIL_H_ diff --git a/quiche/http2/test_tools/verify_macros.h b/quiche/http2/test_tools/verify_macros.h new file mode 100644 index 000000000000..fb7ac0ff1638 --- /dev/null +++ b/quiche/http2/test_tools/verify_macros.h @@ -0,0 +1,32 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_HTTP2_TEST_TOOLS_VERIFY_MACROS_H_ +#define QUICHE_HTTP2_TEST_TOOLS_VERIFY_MACROS_H_ + +#include "quiche/common/platform/api/quiche_test.h" + +#define HTTP2_VERIFY_CORE(value, str) \ + if ((value)) \ + ; \ + else \ + return ::testing::AssertionFailure() \ + << __FILE__ << ":" << __LINE__ << " " \ + << "Failed to verify that '" << str << "'" + +#define HTTP2_VERIFY_TRUE(value) HTTP2_VERIFY_CORE(value, #value) +#define HTTP2_VERIFY_FALSE(value) HTTP2_VERIFY_CORE(!value, "!" #value) +#define HTTP2_VERIFY_SUCCESS HTTP2_VERIFY_TRUE +#define HTTP2_VERIFY_EQ(value1, value2) \ + HTTP2_VERIFY_CORE((value1) == (value2), #value1 "==" #value2) +#define HTTP2_VERIFY_NE(value1, value2) \ + HTTP2_VERIFY_CORE((value1) != (value2), #value1 "!=" #value2) +#define HTTP2_VERIFY_LE(value1, value2) \ + HTTP2_VERIFY_CORE((value1) <= (value2), #value1 "<=" #value2) +#define HTTP2_VERIFY_LT(value1, value2) \ + HTTP2_VERIFY_CORE((value1) < (value2), #value1 "<" #value2) +#define HTTP2_VERIFY_GT(value1, value2) \ + HTTP2_VERIFY_CORE((value1) > (value2), #value1 ">" #value2) + +#endif // QUICHE_HTTP2_TEST_TOOLS_VERIFY_MACROS_H_ diff --git a/quiche/oblivious_http/buffers/oblivious_http_integration_test.cc b/quiche/oblivious_http/buffers/oblivious_http_integration_test.cc new file mode 100644 index 000000000000..5d00efc403a2 --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_integration_test.cc @@ -0,0 +1,108 @@ +#include + +#include + +#include "absl/strings/escaping.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/oblivious_http/buffers/oblivious_http_response.h" + +namespace quiche { +namespace { + +struct ObliviousHttpResponseTestStrings { + std::string test_case_name; + uint8_t key_id; + std::string request_plaintext; + std::string response_plaintext; +}; + +std::string GetHpkePrivateKey() { + absl::string_view hpke_key_hex = + "b77431ecfa8f4cfc30d6e467aafa06944dffe28cb9dd1409e33a3045f5adc8a1"; + return absl::HexStringToBytes(hpke_key_hex); +} + +std::string GetHpkePublicKey() { + absl::string_view public_key = + "6d21cfe09fbea5122f9ebc2eb2a69fcc4f06408cd54aac934f012e76fcdcef62"; + return absl::HexStringToBytes(public_key); +} + +const ObliviousHttpHeaderKeyConfig GetOhttpKeyConfig(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) { + auto ohttp_key_config = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + EXPECT_TRUE(ohttp_key_config.ok()); + return std::move(ohttp_key_config.value()); +} + +bssl::UniquePtr ConstructHpkeKey( + absl::string_view hpke_key, + const ObliviousHttpHeaderKeyConfig &ohttp_key_config) { + bssl::UniquePtr bssl_hpke_key(EVP_HPKE_KEY_new()); + EXPECT_NE(bssl_hpke_key, nullptr); + EXPECT_TRUE(EVP_HPKE_KEY_init( + bssl_hpke_key.get(), ohttp_key_config.GetHpkeKem(), + reinterpret_cast(hpke_key.data()), hpke_key.size())); + return bssl_hpke_key; +} +} // namespace + +using ObliviousHttpParameterizedTest = + test::QuicheTestWithParam; + +TEST_P(ObliviousHttpParameterizedTest, TestEndToEndWithOfflineStrings) { + // For each test case, verify end to end request-handling and + // response-handling. + const ObliviousHttpResponseTestStrings &test = GetParam(); + + auto ohttp_key_config = + GetOhttpKeyConfig(test.key_id, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + // Round-trip request flow. + auto client_req_encap = ObliviousHttpRequest::CreateClientObliviousRequest( + test.request_plaintext, GetHpkePublicKey(), ohttp_key_config); + EXPECT_TRUE(client_req_encap.ok()); + ASSERT_FALSE(client_req_encap->EncapsulateAndSerialize().empty()); + auto server_req_decap = ObliviousHttpRequest::CreateServerObliviousRequest( + client_req_encap->EncapsulateAndSerialize(), + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config); + EXPECT_TRUE(server_req_decap.ok()); + EXPECT_EQ(server_req_decap->GetPlaintextData(), test.request_plaintext); + + // Round-trip response flow. + auto server_request_context = + std::move(server_req_decap.value()).ReleaseContext(); + auto server_resp_encap = ObliviousHttpResponse::CreateServerObliviousResponse( + test.response_plaintext, server_request_context); + EXPECT_TRUE(server_resp_encap.ok()); + ASSERT_FALSE(server_resp_encap->EncapsulateAndSerialize().empty()); + auto client_request_context = + std::move(client_req_encap.value()).ReleaseContext(); + auto client_resp_decap = ObliviousHttpResponse::CreateClientObliviousResponse( + server_resp_encap->EncapsulateAndSerialize(), client_request_context); + EXPECT_TRUE(client_resp_decap.ok()); + EXPECT_EQ(client_resp_decap->GetPlaintextData(), test.response_plaintext); +} + +INSTANTIATE_TEST_SUITE_P( + ObliviousHttpParameterizedTests, ObliviousHttpParameterizedTest, + testing::ValuesIn( + {{"test_case_1", 4, "test request 1", "test response 1"}, + {"test_case_2", 6, "test request 2", "test response 2"}, + {"test_case_3", 7, "test request 3", "test response 3"}, + {"test_case_4", 2, "test request 4", "test response 4"}, + {"test_case_5", 1, "test request 5", "test response 5"}, + {"test_case_6", 7, "test request 6", "test response 6"}, + {"test_case_7", 3, "test request 7", "test response 7"}, + {"test_case_8", 9, "test request 8", "test response 8"}, + {"test_case_9", 3, "test request 9", "test response 9"}, + {"test_case_10", 4, "test request 10", "test response 10"}}), + [](const testing::TestParamInfo + &info) { return info.param.test_case_name; }); + +} // namespace quiche diff --git a/quiche/oblivious_http/buffers/oblivious_http_request.cc b/quiche/oblivious_http/buffers/oblivious_http_request.cc new file mode 100644 index 000000000000..7c0b2ea99c47 --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_request.cc @@ -0,0 +1,209 @@ +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" + +#include +#include + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_crypto_logging.h" + +namespace quiche { +// Ctor. +ObliviousHttpRequest::Context::Context( + bssl::UniquePtr hpke_context, std::string encapsulated_key) + : hpke_context_(std::move(hpke_context)), + encapsulated_key_(std::move(encapsulated_key)) {} + +// Ctor. +ObliviousHttpRequest::ObliviousHttpRequest( + bssl::UniquePtr hpke_context, std::string encapsulated_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + std::string req_ciphertext, std::string req_plaintext) + : oblivious_http_request_context_(absl::make_optional( + Context(std::move(hpke_context), std::move(encapsulated_key)))), + key_config_(ohttp_key_config), + request_ciphertext_(std::move(req_ciphertext)), + request_plaintext_(std::move(req_plaintext)) {} + +// Request Decapsulation. +absl::StatusOr +ObliviousHttpRequest::CreateServerObliviousRequest( + absl::string_view encrypted_data, const EVP_HPKE_KEY& gateway_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config) { + if (EVP_HPKE_KEY_kem(&gateway_key) == nullptr) { + return absl::InvalidArgumentError( + "Invalid input param. Failed to import gateway_key."); + } + bssl::UniquePtr gateway_ctx(EVP_HPKE_CTX_new()); + if (gateway_ctx == nullptr) { + return SslErrorAsStatus("Failed to initialize Gateway/Server's Context."); + } + + QuicheDataReader reader(encrypted_data); + + auto is_hdr_ok = ohttp_key_config.ParseOhttpPayloadHeader(reader); + if (!is_hdr_ok.ok()) { + return is_hdr_ok; + } + + size_t enc_key_len = EVP_HPKE_KEM_enc_len(EVP_HPKE_KEY_kem(&gateway_key)); + + absl::string_view enc_key_received; + if (!reader.ReadStringPiece(&enc_key_received, enc_key_len)) { + return absl::FailedPreconditionError(absl::StrCat( + "Failed to extract encapsulation key of expected len=", enc_key_len, + "from payload.")); + } + std::string info = ohttp_key_config.SerializeRecipientContextInfo(); + if (!EVP_HPKE_CTX_setup_recipient( + gateway_ctx.get(), &gateway_key, ohttp_key_config.GetHpkeKdf(), + ohttp_key_config.GetHpkeAead(), + reinterpret_cast(enc_key_received.data()), + enc_key_received.size(), + reinterpret_cast(info.data()), info.size())) { + return SslErrorAsStatus("Failed to setup recipient context"); + } + + absl::string_view ciphertext_received = reader.ReadRemainingPayload(); + // Decrypt the message. + std::string decrypted(ciphertext_received.size(), '\0'); + size_t decrypted_len; + if (!EVP_HPKE_CTX_open( + gateway_ctx.get(), reinterpret_cast(decrypted.data()), + &decrypted_len, decrypted.size(), + reinterpret_cast(ciphertext_received.data()), + ciphertext_received.size(), nullptr, 0)) { + return SslErrorAsStatus("Failed to decrypt.", + absl::StatusCode::kInvalidArgument); + } + decrypted.resize(decrypted_len); + return ObliviousHttpRequest( + std::move(gateway_ctx), std::string(enc_key_received), ohttp_key_config, + std::string(ciphertext_received), std::move(decrypted)); +} + +// Request Encapsulation. +absl::StatusOr +ObliviousHttpRequest::CreateClientObliviousRequest( + std::string plaintext_payload, absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config) { + return EncapsulateWithSeed(std::move(plaintext_payload), hpke_public_key, + ohttp_key_config, ""); +} + +absl::StatusOr +ObliviousHttpRequest::CreateClientWithSeedForTesting( + std::string plaintext_payload, absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + absl::string_view seed) { + return ObliviousHttpRequest::EncapsulateWithSeed( + std::move(plaintext_payload), hpke_public_key, ohttp_key_config, seed); +} + +absl::StatusOr ObliviousHttpRequest::EncapsulateWithSeed( + std::string plaintext_payload, absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + absl::string_view seed) { + if (plaintext_payload.empty() || hpke_public_key.empty()) { + return absl::InvalidArgumentError("Invalid input."); + } + // Initialize HPKE key and context. + bssl::UniquePtr client_key(EVP_HPKE_KEY_new()); + if (client_key == nullptr) { + return SslErrorAsStatus("Failed to initialize HPKE Client Key."); + } + bssl::UniquePtr client_ctx(EVP_HPKE_CTX_new()); + if (client_ctx == nullptr) { + return SslErrorAsStatus("Failed to initialize HPKE Client Context."); + } + // Setup the sender (client) + std::string encapsulated_key(EVP_HPKE_MAX_ENC_LENGTH, '\0'); + size_t enc_len; + std::string info = ohttp_key_config.SerializeRecipientContextInfo(); + if (seed.empty()) { + if (!EVP_HPKE_CTX_setup_sender( + client_ctx.get(), + reinterpret_cast(encapsulated_key.data()), &enc_len, + encapsulated_key.size(), ohttp_key_config.GetHpkeKem(), + ohttp_key_config.GetHpkeKdf(), ohttp_key_config.GetHpkeAead(), + reinterpret_cast(hpke_public_key.data()), + hpke_public_key.size(), + reinterpret_cast(info.data()), info.size())) { + return SslErrorAsStatus( + "Failed to setup HPKE context with given public key param " + "hpke_public_key."); + } + } else { + if (!EVP_HPKE_CTX_setup_sender_with_seed_for_testing( + client_ctx.get(), + reinterpret_cast(encapsulated_key.data()), &enc_len, + encapsulated_key.size(), ohttp_key_config.GetHpkeKem(), + ohttp_key_config.GetHpkeKdf(), ohttp_key_config.GetHpkeAead(), + reinterpret_cast(hpke_public_key.data()), + hpke_public_key.size(), + reinterpret_cast(info.data()), info.size(), + reinterpret_cast(seed.data()), seed.size())) { + return SslErrorAsStatus( + "Failed to setup HPKE context with given public key param " + "hpke_public_key and seed."); + } + } + encapsulated_key.resize(enc_len); + std::string ciphertext( + plaintext_payload.size() + EVP_HPKE_CTX_max_overhead(client_ctx.get()), + '\0'); + size_t ciphertext_len; + if (!EVP_HPKE_CTX_seal( + client_ctx.get(), reinterpret_cast(ciphertext.data()), + &ciphertext_len, ciphertext.size(), + reinterpret_cast(plaintext_payload.data()), + plaintext_payload.size(), nullptr, 0)) { + return SslErrorAsStatus( + "Failed to encrypt plaintext_payload with given public key param " + "hpke_public_key."); + } + ciphertext.resize(ciphertext_len); + if (encapsulated_key.empty() || ciphertext.empty()) { + return absl::InternalError(absl::StrCat( + "Failed to generate required data: ", + (encapsulated_key.empty() ? "encapsulated key is empty" : ""), + (ciphertext.empty() ? "encrypted data is empty" : ""), ".")); + } + + return ObliviousHttpRequest( + std::move(client_ctx), std::move(encapsulated_key), ohttp_key_config, + std::move(ciphertext), std::move(plaintext_payload)); +} + +// Request Serialize. +// Builds request=[hdr, enc, ct]. +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.1-4.5 +std::string ObliviousHttpRequest::EncapsulateAndSerialize() const { + if (!oblivious_http_request_context_.has_value()) { + QUICHE_BUG(ohttp_encapsulate_after_context_extract) + << "EncapsulateAndSerialize cannot be called after ReleaseContext()"; + return ""; + } + return absl::StrCat(key_config_.SerializeOhttpPayloadHeader(), + oblivious_http_request_context_->encapsulated_key_, + request_ciphertext_); +} + +// Returns Decrypted blob in the case of server, and returns plaintext used by +// the client while `CreateClientObliviousRequest`. +absl::string_view ObliviousHttpRequest::GetPlaintextData() const { + return request_plaintext_; +} + +} // namespace quiche diff --git a/quiche/oblivious_http/buffers/oblivious_http_request.h b/quiche/oblivious_http/buffers/oblivious_http_request.h new file mode 100644 index 000000000000..58a555b04b6d --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_request.h @@ -0,0 +1,120 @@ +#ifndef QUICHE_OBLIVIOUS_HTTP_BUFFERS_OBLIVIOUS_HTTP_REQUEST_H_ +#define QUICHE_OBLIVIOUS_HTTP_BUFFERS_OBLIVIOUS_HTTP_REQUEST_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/hpke.h" +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace quiche { +// 1. Handles client side encryption of the payload that will subsequently be +// added to HTTP POST body and passed on to Relay. +// 2. Handles server side decryption of the payload received in HTTP POST body +// from Relay. +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#name-encapsulation-of-requests +class QUICHE_EXPORT ObliviousHttpRequest { + public: + // Holds the HPKE related data received from request. This context is created + // during request processing, and subsequently passed into response handling + // in `ObliviousHttpResponse`. + class QUICHE_EXPORT Context { + public: + ~Context() = default; + + // Movable + Context(Context&& other) = default; + Context& operator=(Context&& other) = default; + + private: + explicit Context(bssl::UniquePtr hpke_context, + std::string encapsulated_key); + + // All accessors must be friends to read `Context`. + friend class ObliviousHttpRequest; + friend class ObliviousHttpResponse; + // Tests which need access. + friend class + ObliviousHttpRequest_TestDecapsulateWithSpecAppendixAExample_Test; + friend class ObliviousHttpRequest_TestEncapsulatedRequestStructure_Test; + friend class + ObliviousHttpRequest_TestEncapsulatedOhttpEncryptedPayload_Test; + friend class ObliviousHttpRequest_TestDeterministicSeededOhttpRequest_Test; + friend class ObliviousHttpResponse_EndToEndTestForResponse_Test; + friend class ObliviousHttpResponse_TestEncapsulateWithQuicheRandom_Test; + + bssl::UniquePtr hpke_context_; + std::string encapsulated_key_; + }; + // Parse the OHTTP request from the given `encrypted_data`. + // On success, returns obj that callers will use to `GetPlaintextData`. + // Generic Usecase : server-side calls this method in the context of Request. + static absl::StatusOr CreateServerObliviousRequest( + absl::string_view encrypted_data, const EVP_HPKE_KEY& gateway_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config); + + // Constructs an OHTTP request for the given `plaintext_payload`. + // On success, returns obj that callers will use to `EncapsulateAndSerialize` + // OHttp request. + static absl::StatusOr CreateClientObliviousRequest( + std::string plaintext_payload, absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config); + + // Same as above but accepts a random number seed for testing. + static absl::StatusOr CreateClientWithSeedForTesting( + std::string plaintext_payload, absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + absl::string_view seed); + + // Movable. + ObliviousHttpRequest(ObliviousHttpRequest&& other) = default; + ObliviousHttpRequest& operator=(ObliviousHttpRequest&& other) = default; + + ~ObliviousHttpRequest() = default; + + // Returns serialized OHTTP request bytestring. + // @note: This method MUST NOT be called after `ReleaseContext()` has been + // called. + std::string EncapsulateAndSerialize() const; + + // Generic Usecase : server-side calls this method after Decapsulation using + // `CreateServerObliviousRequest`. + absl::string_view GetPlaintextData() const; + + // Oblivious HTTP request context is created after successful creation of + // `this` object, and subsequently passed into the `ObliviousHttpResponse` for + // followup response handling. + // @returns: This rvalue reference qualified member function transfers the + // ownership of `Context` to the caller, and further invokes + // ClangTidy:misc-use-after-move warning if callers try to extract `Context` + // twice after the fact that the ownership has already been transferred. + // @note: Callers shouldn't extract the `Context` until you're done with this + // Request and its data. + Context ReleaseContext() && { + return std::move(oblivious_http_request_context_.value()); + } + + private: + explicit ObliviousHttpRequest( + bssl::UniquePtr hpke_context, std::string encapsulated_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + std::string req_ciphertext, std::string req_plaintext); + + static absl::StatusOr EncapsulateWithSeed( + std::string plaintext_payload, absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + absl::string_view seed); + + // This field will be empty after calling `ReleaseContext()`. + absl::optional oblivious_http_request_context_; + ObliviousHttpHeaderKeyConfig key_config_; + std::string request_ciphertext_; + std::string request_plaintext_; +}; + +} // namespace quiche + +#endif // QUICHE_OBLIVIOUS_HTTP_BUFFERS_OBLIVIOUS_HTTP_REQUEST_H_ diff --git a/quiche/oblivious_http/buffers/oblivious_http_request_test.cc b/quiche/oblivious_http/buffers/oblivious_http_request_test.cc new file mode 100644 index 000000000000..b788e5c7b894 --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_request_test.cc @@ -0,0 +1,287 @@ +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" + +#include + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/hkdf.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace quiche { + +namespace { +const uint32_t kHeaderLength = ObliviousHttpHeaderKeyConfig::kHeaderLength; +std::string GetHpkePrivateKey() { + absl::string_view hpke_key_hex = + "b77431ecfa8f4cfc30d6e467aafa06944dffe28cb9dd1409e33a3045f5adc8a1"; + return absl::HexStringToBytes(hpke_key_hex); +} + +std::string GetHpkePublicKey() { + absl::string_view public_key = + "6d21cfe09fbea5122f9ebc2eb2a69fcc4f06408cd54aac934f012e76fcdcef62"; + return absl::HexStringToBytes(public_key); +} + +std::string GetAlternativeHpkePublicKey() { + absl::string_view public_key = + "6d21cfe09fbea5122f9ebc2eb2a69fcc4f06408cd54aac934f012e76fcdcef63"; + return absl::HexStringToBytes(public_key); +} + +std::string GetSeed() { + absl::string_view seed = + "52c4a758a802cd8b936eceea314432798d5baf2d7e9235dc084ab1b9cfa2f736"; + return absl::HexStringToBytes(seed); +} + +std::string GetSeededEncapsulatedKey() { + absl::string_view encapsulated_key = + "37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431"; + return absl::HexStringToBytes(encapsulated_key); +} + +bssl::UniquePtr ConstructHpkeKey( + absl::string_view hpke_key, + const ObliviousHttpHeaderKeyConfig &ohttp_key_config) { + bssl::UniquePtr bssl_hpke_key(EVP_HPKE_KEY_new()); + EXPECT_NE(bssl_hpke_key, nullptr); + EXPECT_TRUE(EVP_HPKE_KEY_init( + bssl_hpke_key.get(), ohttp_key_config.GetHpkeKem(), + reinterpret_cast(hpke_key.data()), hpke_key.size())); + return bssl_hpke_key; +} + +const ObliviousHttpHeaderKeyConfig GetOhttpKeyConfig(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) { + auto ohttp_key_config = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + EXPECT_TRUE(ohttp_key_config.ok()); + return std::move(ohttp_key_config.value()); +} +} // namespace + +// Direct test example from OHttp spec. +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A +TEST(ObliviousHttpRequest, TestDecapsulateWithSpecAppendixAExample) { + auto ohttp_key_config = + GetOhttpKeyConfig(/*key_id=*/1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_128_GCM); + + // X25519 Secret key (priv key). + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-2 + constexpr absl::string_view kX25519SecretKey = + "3c168975674b2fa8e465970b79c8dcf09f1c741626480bd4c6162fc5b6a98e1a"; + + // Encapsulated request. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-14 + constexpr absl::string_view kEncapsulatedRequest = + "010020000100014b28f881333e7c164ffc499ad9796f877f4e1051ee6d31bad19dec96c2" + "08b4726374e469135906992e1268c594d2a10c695d858c40a026e7965e7d86b83dd440b2" + "c0185204b4d63525"; + + // Initialize Request obj to Decapsulate (decrypt). + auto instance = ObliviousHttpRequest::CreateServerObliviousRequest( + absl::HexStringToBytes(kEncapsulatedRequest), + *(ConstructHpkeKey(absl::HexStringToBytes(kX25519SecretKey), + ohttp_key_config)), + ohttp_key_config); + ASSERT_TRUE(instance.ok()); + auto decrypted = instance->GetPlaintextData(); + + // Encapsulated/Ephemeral public key. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-10 + constexpr absl::string_view kExpectedEphemeralPublicKey = + "4b28f881333e7c164ffc499ad9796f877f4e1051ee6d31bad19dec96c208b472"; + auto oblivious_request_context = std::move(instance.value()).ReleaseContext(); + EXPECT_EQ(oblivious_request_context.encapsulated_key_, + absl::HexStringToBytes(kExpectedEphemeralPublicKey)); + + // Binary HTTP message. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-6 + constexpr absl::string_view kExpectedBinaryHTTPMessage = + "00034745540568747470730b6578616d706c652e636f6d012f"; + EXPECT_EQ(decrypted, absl::HexStringToBytes(kExpectedBinaryHTTPMessage)); +} + +TEST(ObliviousHttpRequest, TestEncapsulatedRequestStructure) { + uint8_t test_key_id = 7; + uint16_t test_kem_id = EVP_HPKE_DHKEM_X25519_HKDF_SHA256; + uint16_t test_kdf_id = EVP_HPKE_HKDF_SHA256; + uint16_t test_aead_id = EVP_HPKE_AES_256_GCM; + std::string plaintext = "test"; + auto instance = ObliviousHttpRequest::CreateClientObliviousRequest( + plaintext, GetHpkePublicKey(), + GetOhttpKeyConfig(test_key_id, test_kem_id, test_kdf_id, test_aead_id)); + ASSERT_TRUE(instance.ok()); + auto payload_bytes = instance->EncapsulateAndSerialize(); + EXPECT_GE(payload_bytes.size(), kHeaderLength); + // Parse header. + QuicheDataReader reader(payload_bytes); + uint8_t key_id; + EXPECT_TRUE(reader.ReadUInt8(&key_id)); + EXPECT_EQ(key_id, test_key_id); + uint16_t kem_id; + EXPECT_TRUE(reader.ReadUInt16(&kem_id)); + EXPECT_EQ(kem_id, test_kem_id); + uint16_t kdf_id; + EXPECT_TRUE(reader.ReadUInt16(&kdf_id)); + EXPECT_EQ(kdf_id, test_kdf_id); + uint16_t aead_id; + EXPECT_TRUE(reader.ReadUInt16(&aead_id)); + EXPECT_EQ(aead_id, test_aead_id); + auto client_request_context = std::move(instance.value()).ReleaseContext(); + auto client_encapsulated_key = client_request_context.encapsulated_key_; + EXPECT_EQ(client_encapsulated_key.size(), X25519_PUBLIC_VALUE_LEN); + auto enc_key_plus_ciphertext = payload_bytes.substr(kHeaderLength); + auto packed_encapsulated_key = + enc_key_plus_ciphertext.substr(0, X25519_PUBLIC_VALUE_LEN); + EXPECT_EQ(packed_encapsulated_key, client_encapsulated_key); + auto ciphertext = enc_key_plus_ciphertext.substr(X25519_PUBLIC_VALUE_LEN); + EXPECT_GE(ciphertext.size(), plaintext.size()); +} + +TEST(ObliviousHttpRequest, TestDeterministicSeededOhttpRequest) { + auto ohttp_key_config = + GetOhttpKeyConfig(4, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto encapsulated = ObliviousHttpRequest::CreateClientWithSeedForTesting( + "test", GetHpkePublicKey(), ohttp_key_config, GetSeed()); + ASSERT_TRUE(encapsulated.ok()); + auto encapsulated_request = encapsulated->EncapsulateAndSerialize(); + auto ohttp_request_context = std::move(encapsulated.value()).ReleaseContext(); + EXPECT_EQ(ohttp_request_context.encapsulated_key_, + GetSeededEncapsulatedKey()); + absl::string_view expected_encrypted_request = + "9f37cfed07d0111ecd2c34f794671759bcbd922a"; + EXPECT_NE(ohttp_request_context.hpke_context_, nullptr); + size_t encapsulated_key_len = EVP_HPKE_KEM_enc_len( + EVP_HPKE_CTX_kem(ohttp_request_context.hpke_context_.get())); + int encrypted_payload_offset = kHeaderLength + encapsulated_key_len; + EXPECT_EQ(encapsulated_request.substr(encrypted_payload_offset), + absl::HexStringToBytes(expected_encrypted_request)); +} + +TEST(ObliviousHttpRequest, + TestSeededEncapsulatedKeySamePlaintextsSameCiphertexts) { + auto ohttp_key_config = + GetOhttpKeyConfig(8, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto req_with_same_plaintext_1 = + ObliviousHttpRequest::CreateClientWithSeedForTesting( + "same plaintext", GetHpkePublicKey(), ohttp_key_config, GetSeed()); + ASSERT_TRUE(req_with_same_plaintext_1.ok()); + auto ciphertext_1 = req_with_same_plaintext_1->EncapsulateAndSerialize(); + auto req_with_same_plaintext_2 = + ObliviousHttpRequest::CreateClientWithSeedForTesting( + "same plaintext", GetHpkePublicKey(), ohttp_key_config, GetSeed()); + ASSERT_TRUE(req_with_same_plaintext_2.ok()); + auto ciphertext_2 = req_with_same_plaintext_2->EncapsulateAndSerialize(); + EXPECT_EQ(ciphertext_1, ciphertext_2); +} + +TEST(ObliviousHttpRequest, + TestSeededEncapsulatedKeyDifferentPlaintextsDifferentCiphertexts) { + auto ohttp_key_config = + GetOhttpKeyConfig(8, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto req_with_different_plaintext_1 = + ObliviousHttpRequest::CreateClientWithSeedForTesting( + "different 1", GetHpkePublicKey(), ohttp_key_config, GetSeed()); + ASSERT_TRUE(req_with_different_plaintext_1.ok()); + auto ciphertext_1 = req_with_different_plaintext_1->EncapsulateAndSerialize(); + auto req_with_different_plaintext_2 = + ObliviousHttpRequest::CreateClientWithSeedForTesting( + "different 2", GetHpkePublicKey(), ohttp_key_config, GetSeed()); + ASSERT_TRUE(req_with_different_plaintext_2.ok()); + auto ciphertext_2 = req_with_different_plaintext_2->EncapsulateAndSerialize(); + EXPECT_NE(ciphertext_1, ciphertext_2); +} + +TEST(ObliviousHttpRequest, TestInvalidInputsOnClientSide) { + auto ohttp_key_config = + GetOhttpKeyConfig(30, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + // Empty plaintext. + EXPECT_EQ(ObliviousHttpRequest::CreateClientObliviousRequest( + /*plaintext_payload*/ "", GetHpkePublicKey(), ohttp_key_config) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + // Empty HPKE public key. + EXPECT_EQ(ObliviousHttpRequest::CreateClientObliviousRequest( + "some plaintext", + /*hpke_public_key*/ "", ohttp_key_config) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpRequest, TestInvalidInputsOnServerSide) { + auto ohttp_key_config = + GetOhttpKeyConfig(4, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + // Empty encrypted payload. + EXPECT_EQ(ObliviousHttpRequest::CreateServerObliviousRequest( + /*encrypted_data*/ "", + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + // Empty EVP_HPKE_KEY struct. + EXPECT_EQ(ObliviousHttpRequest::CreateServerObliviousRequest( + absl::StrCat(ohttp_key_config.SerializeOhttpPayloadHeader(), + GetSeededEncapsulatedKey(), + "9f37cfed07d0111ecd2c34f794671759bcbd922a"), + /*gateway_key*/ {}, ohttp_key_config) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpRequest, EndToEndTestForRequest) { + auto ohttp_key_config = + GetOhttpKeyConfig(5, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto encapsulate = ObliviousHttpRequest::CreateClientObliviousRequest( + "test", GetHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(encapsulate.ok()); + auto oblivious_request = encapsulate->EncapsulateAndSerialize(); + auto decapsulate = ObliviousHttpRequest::CreateServerObliviousRequest( + oblivious_request, + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config); + ASSERT_TRUE(decapsulate.ok()); + auto decrypted = decapsulate->GetPlaintextData(); + EXPECT_EQ(decrypted, "test"); +} + +TEST(ObliviousHttpRequest, EndToEndTestForRequestWithWrongKey) { + auto ohttp_key_config = + GetOhttpKeyConfig(5, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto encapsulate = ObliviousHttpRequest::CreateClientObliviousRequest( + "test", GetAlternativeHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(encapsulate.ok()); + auto oblivious_request = encapsulate->EncapsulateAndSerialize(); + auto decapsulate = ObliviousHttpRequest::CreateServerObliviousRequest( + oblivious_request, + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config); + EXPECT_EQ(decapsulate.status().code(), absl::StatusCode::kInvalidArgument); +} +} // namespace quiche diff --git a/quiche/oblivious_http/buffers/oblivious_http_response.cc b/quiche/oblivious_http/buffers/oblivious_http_response.cc new file mode 100644 index 000000000000..ea1936db9e9a --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_response.cc @@ -0,0 +1,353 @@ +#include "quiche/oblivious_http/buffers/oblivious_http_response.h" + +#include +#include + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/aead.h" +#include "openssl/hkdf.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/quiche_crypto_logging.h" +#include "quiche/common/quiche_random.h" +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace quiche { +namespace { +// Generate a random string. +void random(QuicheRandom* quiche_random, char* dest, size_t len) { + if (quiche_random == nullptr) { + quiche_random = QuicheRandom::GetInstance(); + } + quiche_random->RandBytes(dest, len); +} +} // namespace + +// Ctor. +ObliviousHttpResponse::ObliviousHttpResponse(std::string encrypted_data, + std::string resp_plaintext) + : encrypted_data_(std::move(encrypted_data)), + response_plaintext_(std::move(resp_plaintext)) {} + +// Response Decapsulation. +// 1. Extract resp_nonce +// 2. Build prk (pseudorandom key) using HKDF_Extract +// 3. Derive aead_key using HKDF_Labeled_Expand +// 4. Derive aead_nonce using HKDF_Labeled_Expand +// 5. Setup AEAD context and Decrypt. +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-4 +absl::StatusOr +ObliviousHttpResponse::CreateClientObliviousResponse( + std::string encrypted_data, + ObliviousHttpRequest::Context& oblivious_http_request_context) { + if (oblivious_http_request_context.hpke_context_ == nullptr) { + return absl::FailedPreconditionError( + "HPKE context wasn't initialized before proceeding with this Response " + "Decapsulation on Client-side."); + } + size_t expected_key_len = EVP_HPKE_KEM_enc_len( + EVP_HPKE_CTX_kem(oblivious_http_request_context.hpke_context_.get())); + if (oblivious_http_request_context.encapsulated_key_.size() != + expected_key_len) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid len for encapsulated_key arg. Expected:", expected_key_len, + " Actual:", oblivious_http_request_context.encapsulated_key_.size())); + } + if (encrypted_data.empty()) { + return absl::InvalidArgumentError("Empty encrypted_data input param."); + } + + absl::StatusOr aead_params_st = + GetCommonAeadParams(oblivious_http_request_context); + if (!aead_params_st.ok()) { + return aead_params_st.status(); + } + + // secret_len = [max(Nn, Nk)] where Nk and Nn are the length of AEAD + // key and nonce associated with HPKE context. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.1 + size_t secret_len = aead_params_st.value().secret_len; + if (encrypted_data.size() < secret_len) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid input response. Failed to parse required minimum " + "expected_len=", + secret_len, " bytes.")); + } + // Extract response_nonce. Step 2 + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.2 + absl::string_view response_nonce = + absl::string_view(encrypted_data).substr(0, secret_len); + absl::string_view encrypted_response = + absl::string_view(encrypted_data).substr(secret_len); + + // Steps (1, 3 to 5) + AEAD context SetUp before 6th step is performed in + // CommonOperations. + auto common_ops_st = CommonOperationsToEncapDecap( + response_nonce, oblivious_http_request_context, + aead_params_st.value().aead_key_len, + aead_params_st.value().aead_nonce_len, aead_params_st.value().secret_len); + if (!common_ops_st.ok()) { + return common_ops_st.status(); + } + + std::string decrypted(encrypted_response.size(), '\0'); + size_t decrypted_len; + + // Decrypt with initialized AEAD context. + // response, error = Open(aead_key, aead_nonce, "", ct) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-6 + if (!EVP_AEAD_CTX_open( + common_ops_st.value().aead_ctx.get(), + reinterpret_cast(decrypted.data()), &decrypted_len, + decrypted.size(), + reinterpret_cast( + common_ops_st.value().aead_nonce.data()), + aead_params_st.value().aead_nonce_len, + reinterpret_cast(encrypted_response.data()), + encrypted_response.size(), nullptr, 0)) { + return SslErrorAsStatus( + "Failed to decrypt the response with derived AEAD key and nonce."); + } + decrypted.resize(decrypted_len); + ObliviousHttpResponse oblivious_response(std::move(encrypted_data), + std::move(decrypted)); + return oblivious_response; +} + +// Response Encapsulation. +// Follows the Ohttp spec section-4.2 (Encapsulation of Responses) Ref +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2 +// Use HPKE context from BoringSSL to export a secret and use it to Seal (AKA +// encrypt) the response back to the Sender(client) +absl::StatusOr +ObliviousHttpResponse::CreateServerObliviousResponse( + std::string plaintext_payload, + ObliviousHttpRequest::Context& oblivious_http_request_context, + QuicheRandom* quiche_random) { + if (oblivious_http_request_context.hpke_context_ == nullptr) { + return absl::FailedPreconditionError( + "HPKE context wasn't initialized before proceeding with this Response " + "Encapsulation on Server-side."); + } + size_t expected_key_len = EVP_HPKE_KEM_enc_len( + EVP_HPKE_CTX_kem(oblivious_http_request_context.hpke_context_.get())); + if (oblivious_http_request_context.encapsulated_key_.size() != + expected_key_len) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid len for encapsulated_key arg. Expected:", expected_key_len, + " Actual:", oblivious_http_request_context.encapsulated_key_.size())); + } + if (plaintext_payload.empty()) { + return absl::InvalidArgumentError("Empty plaintext_payload input param."); + } + absl::StatusOr aead_params_st = + GetCommonAeadParams(oblivious_http_request_context); + if (!aead_params_st.ok()) { + return aead_params_st.status(); + } + const size_t nonce_size = aead_params_st->secret_len; + const size_t max_encrypted_data_size = + nonce_size + plaintext_payload.size() + + EVP_AEAD_max_overhead(EVP_HPKE_AEAD_aead(EVP_HPKE_CTX_aead( + oblivious_http_request_context.hpke_context_.get()))); + std::string encrypted_data(max_encrypted_data_size, '\0'); + // response_nonce = random(max(Nn, Nk)) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.2 + random(quiche_random, encrypted_data.data(), nonce_size); + absl::string_view response_nonce = + absl::string_view(encrypted_data).substr(0, nonce_size); + + // Steps (1, 3 to 5) + AEAD context SetUp before 6th step is performed in + // CommonOperations. + auto common_ops_st = CommonOperationsToEncapDecap( + response_nonce, oblivious_http_request_context, + aead_params_st.value().aead_key_len, + aead_params_st.value().aead_nonce_len, aead_params_st.value().secret_len); + if (!common_ops_st.ok()) { + return common_ops_st.status(); + } + + // ct = Seal(aead_key, aead_nonce, "", response) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.6 + size_t ciphertext_len; + if (!EVP_AEAD_CTX_seal( + common_ops_st.value().aead_ctx.get(), + reinterpret_cast(encrypted_data.data() + nonce_size), + &ciphertext_len, encrypted_data.size() - nonce_size, + reinterpret_cast( + common_ops_st.value().aead_nonce.data()), + aead_params_st.value().aead_nonce_len, + reinterpret_cast(plaintext_payload.data()), + plaintext_payload.size(), nullptr, 0)) { + return SslErrorAsStatus( + "Failed to encrypt the payload with derived AEAD key."); + } + encrypted_data.resize(nonce_size + ciphertext_len); + if (nonce_size == 0 || ciphertext_len == 0) { + return absl::InternalError(absl::StrCat( + "ObliviousHttpResponse Object wasn't initialized with required fields.", + (nonce_size == 0 ? "Generated nonce is empty." : ""), + (ciphertext_len == 0 ? "Generated Encrypted payload is empty." : ""))); + } + ObliviousHttpResponse oblivious_response(std::move(encrypted_data), + std::move(plaintext_payload)); + return oblivious_response; +} + +// Serialize. +// enc_response = concat(response_nonce, ct) +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-4 +const std::string& ObliviousHttpResponse::EncapsulateAndSerialize() const { + return encrypted_data_; +} + +// Decrypted blob. +const std::string& ObliviousHttpResponse::GetPlaintextData() const { + return response_plaintext_; +} + +// This section mainly deals with common operations performed by both +// Sender(client) and Receiver(gateway) on ObliviousHttpResponse. + +absl::StatusOr +ObliviousHttpResponse::GetCommonAeadParams( + ObliviousHttpRequest::Context& oblivious_http_request_context) { + const EVP_AEAD* evp_hpke_aead = EVP_HPKE_AEAD_aead( + EVP_HPKE_CTX_aead(oblivious_http_request_context.hpke_context_.get())); + if (evp_hpke_aead == nullptr) { + return absl::FailedPreconditionError( + "Key Configuration not supported by HPKE AEADs. Check your key " + "config."); + } + // Nk = [AEAD key len], is determined by BoringSSL. + const size_t aead_key_len = EVP_AEAD_key_length(evp_hpke_aead); + // Nn = [AEAD nonce len], is determined by BoringSSL. + const size_t aead_nonce_len = EVP_AEAD_nonce_length(evp_hpke_aead); + const size_t secret_len = std::max(aead_key_len, aead_nonce_len); + CommonAeadParamsResult result{evp_hpke_aead, aead_key_len, aead_nonce_len, + secret_len}; + return result; +} + +// Common Steps of AEAD key and AEAD nonce derivation common to both +// client(decapsulation) & Gateway(encapsulation) in handling +// Oblivious-Response. Ref Steps (1, 3-to-5, and setting up AEAD context in +// preparation for 6th step's Seal/Open) in spec. +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-4 +absl::StatusOr +ObliviousHttpResponse::CommonOperationsToEncapDecap( + absl::string_view response_nonce, + ObliviousHttpRequest::Context& oblivious_http_request_context, + const size_t aead_key_len, const size_t aead_nonce_len, + const size_t secret_len) { + if (response_nonce.empty()) { + return absl::InvalidArgumentError("Invalid input params."); + } + // secret = context.Export("message/bhttp response", Nk) + // Export secret of len [max(Nn, Nk)] where Nk and Nn are the length of AEAD + // key and nonce associated with context. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.1 + std::string secret(secret_len, '\0'); + absl::string_view resp_label = + ObliviousHttpHeaderKeyConfig::kOhttpResponseLabel; + if (!EVP_HPKE_CTX_export(oblivious_http_request_context.hpke_context_.get(), + reinterpret_cast(secret.data()), + secret.size(), + reinterpret_cast(resp_label.data()), + resp_label.size())) { + return SslErrorAsStatus("Failed to export secret."); + } + + // salt = concat(enc, response_nonce) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.3 + std::string salt = absl::StrCat( + oblivious_http_request_context.encapsulated_key_, response_nonce); + + // prk = Extract(salt, secret) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.3 + std::string pseudorandom_key(EVP_MAX_MD_SIZE, '\0'); + size_t prk_len; + auto evp_md = EVP_HPKE_KDF_hkdf_md( + EVP_HPKE_CTX_kdf(oblivious_http_request_context.hpke_context_.get())); + if (evp_md == nullptr) { + QUICHE_BUG(Invalid Key Configuration + : Unsupported BoringSSL HPKE KDFs) + << "Update KeyConfig to support only BoringSSL HKDFs."; + return absl::FailedPreconditionError( + "Key Configuration not supported by BoringSSL HPKE KDFs. Check your " + "Key " + "Config."); + } + if (!HKDF_extract( + reinterpret_cast(pseudorandom_key.data()), &prk_len, evp_md, + reinterpret_cast(secret.data()), secret_len, + reinterpret_cast(salt.data()), salt.size())) { + return SslErrorAsStatus( + "Failed to derive pesudorandom key from salt and secret."); + } + pseudorandom_key.resize(prk_len); + + // aead_key = Expand(prk, "key", Nk) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.4 + std::string aead_key(aead_key_len, '\0'); + absl::string_view hkdf_info = ObliviousHttpHeaderKeyConfig::kKeyHkdfInfo; + // All currently supported KDFs are HKDF-based. See CheckKdfId in + // `ObliviousHttpHeaderKeyConfig`. + if (!HKDF_expand(reinterpret_cast(aead_key.data()), aead_key_len, + evp_md, + reinterpret_cast(pseudorandom_key.data()), + prk_len, reinterpret_cast(hkdf_info.data()), + hkdf_info.size())) { + return SslErrorAsStatus( + "Failed to expand AEAD key using pseudorandom key(prk)."); + } + + // aead_nonce = Expand(prk, "nonce", Nn) + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.2-2.5 + std::string aead_nonce(aead_nonce_len, '\0'); + hkdf_info = ObliviousHttpHeaderKeyConfig::kNonceHkdfInfo; + // All currently supported KDFs are HKDF-based. See CheckKdfId in + // `ObliviousHttpHeaderKeyConfig`. + if (!HKDF_expand(reinterpret_cast(aead_nonce.data()), + aead_nonce_len, evp_md, + reinterpret_cast(pseudorandom_key.data()), + prk_len, reinterpret_cast(hkdf_info.data()), + hkdf_info.size())) { + return SslErrorAsStatus( + "Failed to expand AEAD nonce using pseudorandom key(prk)."); + } + + const EVP_AEAD* evp_hpke_aead = EVP_HPKE_AEAD_aead( + EVP_HPKE_CTX_aead(oblivious_http_request_context.hpke_context_.get())); + if (evp_hpke_aead == nullptr) { + return absl::FailedPreconditionError( + "Key Configuration not supported by HPKE AEADs. Check your key " + "config."); + } + + // Setup AEAD context for subsequent Seal/Open operation in response handling. + bssl::UniquePtr aead_ctx(EVP_AEAD_CTX_new( + evp_hpke_aead, reinterpret_cast(aead_key.data()), + aead_key.size(), 0)); + if (aead_ctx == nullptr) { + return SslErrorAsStatus("Failed to initialize AEAD context."); + } + if (!EVP_AEAD_CTX_init(aead_ctx.get(), evp_hpke_aead, + reinterpret_cast(aead_key.data()), + aead_key.size(), 0, nullptr)) { + return SslErrorAsStatus( + "Failed to initialize AEAD context with derived key."); + } + CommonOperationsResult result{std::move(aead_ctx), std::move(aead_nonce)}; + return result; +} + +} // namespace quiche diff --git a/quiche/oblivious_http/buffers/oblivious_http_response.h b/quiche/oblivious_http/buffers/oblivious_http_response.h new file mode 100644 index 000000000000..5e3cb0b5f16d --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_response.h @@ -0,0 +1,95 @@ +#ifndef QUICHE_OBLIVIOUS_HTTP_BUFFERS_OBLIVIOUS_HTTP_RESPONSE_H_ +#define QUICHE_OBLIVIOUS_HTTP_BUFFERS_OBLIVIOUS_HTTP_RESPONSE_H_ + +#include + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_random.h" +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" + +namespace quiche { + +class QUICHE_EXPORT ObliviousHttpResponse { + public: + // Parse and decrypt the OHttp response using ObliviousHttpContext context obj + // that was returned from `CreateClientObliviousRequest` method. On success, + // returns obj that callers will use to `GetDecryptedMessage`. + // @params: Note that `oblivious_http_request_context` is required to stay + // alive only for the lifetime of this factory method call. + static absl::StatusOr CreateClientObliviousResponse( + std::string encrypted_data, + ObliviousHttpRequest::Context& oblivious_http_request_context); + + // Encrypt the input param `plaintext_payload` and create OHttp response using + // ObliviousHttpContext context obj that was returned from + // `CreateServerObliviousRequest` method. On success, returns obj that callers + // will use to `Serialize` OHttp response. Generic Usecase : server-side calls + // this method in the context of Response. + // @params: Note that `oblivious_http_request_context` is required to stay + // alive only for the lifetime of this factory method call. + // @params: If callers do not provide `quiche_random`, it will be initialized + // to default supplied `QuicheRandom::GetInstance()`. It's recommended that + // callers initialize `QuicheRandom* quiche_random` as a Singleton instance + // within their code and pass in the same, in order to have optimized random + // string generation. `quiche_random` is required to stay alive only for the + // lifetime of this factory method call. + static absl::StatusOr CreateServerObliviousResponse( + std::string plaintext_payload, + ObliviousHttpRequest::Context& oblivious_http_request_context, + QuicheRandom* quiche_random = nullptr); + + // Copyable. + ObliviousHttpResponse(const ObliviousHttpResponse& other) = default; + ObliviousHttpResponse& operator=(const ObliviousHttpResponse& other) = + default; + + // Movable. + ObliviousHttpResponse(ObliviousHttpResponse&& other) = default; + ObliviousHttpResponse& operator=(ObliviousHttpResponse&& other) = default; + + ~ObliviousHttpResponse() = default; + + // Generic Usecase : server-side calls this method in the context of Response + // to serialize OHTTP response that will be returned to client-side. + // Returns serialized OHTTP response bytestring. + const std::string& EncapsulateAndSerialize() const; + + const std::string& GetPlaintextData() const; + + private: + struct CommonAeadParamsResult { + const EVP_AEAD* evp_hpke_aead; + const size_t aead_key_len; + const size_t aead_nonce_len; + const size_t secret_len; + }; + + struct CommonOperationsResult { + bssl::UniquePtr aead_ctx; + const std::string aead_nonce; + }; + + explicit ObliviousHttpResponse(std::string encrypted_data, + std::string resp_plaintext); + + // Determines AEAD key len(Nk), AEAD nonce len(Nn) based on HPKE context and + // further estimates secret_len = std::max(Nk, Nn) + static absl::StatusOr GetCommonAeadParams( + ObliviousHttpRequest::Context& oblivious_http_request_context); + // Performs operations related to response handling that are common between + // client and server. + static absl::StatusOr CommonOperationsToEncapDecap( + absl::string_view response_nonce, + ObliviousHttpRequest::Context& oblivious_http_request_context, + const size_t aead_key_len, const size_t aead_nonce_len, + const size_t secret_len); + std::string encrypted_data_; + std::string response_plaintext_; +}; + +} // namespace quiche + +#endif // QUICHE_OBLIVIOUS_HTTP_BUFFERS_OBLIVIOUS_HTTP_RESPONSE_H_ diff --git a/quiche/oblivious_http/buffers/oblivious_http_response_test.cc b/quiche/oblivious_http/buffers/oblivious_http_response_test.cc new file mode 100644 index 000000000000..b2147d724817 --- /dev/null +++ b/quiche/oblivious_http/buffers/oblivious_http_response_test.cc @@ -0,0 +1,210 @@ +#include "quiche/oblivious_http/buffers/oblivious_http_response.h" + +#include +#include +#include + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" + +namespace quiche { + +namespace { +std::string GetHpkePrivateKey() { + absl::string_view hpke_key_hex = + "b77431ecfa8f4cfc30d6e467aafa06944dffe28cb9dd1409e33a3045f5adc8a1"; + return absl::HexStringToBytes(hpke_key_hex); +} + +std::string GetHpkePublicKey() { + absl::string_view public_key = + "6d21cfe09fbea5122f9ebc2eb2a69fcc4f06408cd54aac934f012e76fcdcef62"; + return absl::HexStringToBytes(public_key); +} + +std::string GetSeed() { + absl::string_view seed = + "52c4a758a802cd8b936eceea314432798d5baf2d7e9235dc084ab1b9cfa2f736"; + return absl::HexStringToBytes(seed); +} + +std::string GetSeededEncapsulatedKey() { + absl::string_view encapsulated_key = + "37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431"; + return absl::HexStringToBytes(encapsulated_key); +} + +const ObliviousHttpHeaderKeyConfig GetOhttpKeyConfig(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) { + auto ohttp_key_config = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + EXPECT_TRUE(ohttp_key_config.ok()); + return ohttp_key_config.value(); +} + +bssl::UniquePtr GetSeededClientContext(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) { + bssl::UniquePtr client_ctx(EVP_HPKE_CTX_new()); + std::string encapsulated_key(EVP_HPKE_MAX_ENC_LENGTH, '\0'); + size_t enc_len; + std::string info = GetOhttpKeyConfig(key_id, kem_id, kdf_id, aead_id) + .SerializeRecipientContextInfo(); + + EXPECT_TRUE(EVP_HPKE_CTX_setup_sender_with_seed_for_testing( + client_ctx.get(), reinterpret_cast(encapsulated_key.data()), + &enc_len, encapsulated_key.size(), EVP_hpke_x25519_hkdf_sha256(), + EVP_hpke_hkdf_sha256(), EVP_hpke_aes_256_gcm(), + reinterpret_cast(GetHpkePublicKey().data()), + GetHpkePublicKey().size(), reinterpret_cast(info.data()), + info.size(), reinterpret_cast(GetSeed().data()), + GetSeed().size())); + encapsulated_key.resize(enc_len); + EXPECT_EQ(encapsulated_key, GetSeededEncapsulatedKey()); + return client_ctx; +} + +bssl::UniquePtr ConstructHpkeKey( + absl::string_view hpke_key, + const ObliviousHttpHeaderKeyConfig &ohttp_key_config) { + bssl::UniquePtr bssl_hpke_key(EVP_HPKE_KEY_new()); + EXPECT_NE(bssl_hpke_key, nullptr); + EXPECT_TRUE(EVP_HPKE_KEY_init( + bssl_hpke_key.get(), ohttp_key_config.GetHpkeKem(), + reinterpret_cast(hpke_key.data()), hpke_key.size())); + return bssl_hpke_key; +} + +ObliviousHttpRequest SetUpObliviousHttpContext(uint8_t key_id, uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id, + std::string plaintext) { + auto ohttp_key_config = GetOhttpKeyConfig(key_id, kem_id, kdf_id, aead_id); + auto client_request_encapsulate = + ObliviousHttpRequest::CreateClientWithSeedForTesting( + std::move(plaintext), GetHpkePublicKey(), ohttp_key_config, + GetSeed()); + EXPECT_TRUE(client_request_encapsulate.ok()); + auto oblivious_request = + client_request_encapsulate->EncapsulateAndSerialize(); + auto server_request_decapsulate = + ObliviousHttpRequest::CreateServerObliviousRequest( + oblivious_request, + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config); + EXPECT_TRUE(server_request_decapsulate.ok()); + return std::move(server_request_decapsulate.value()); +} + +// QuicheRandom implementation. +// Just fills the buffer with repeated chars that's initialized in seed. +class TestQuicheRandom : public QuicheRandom { + public: + TestQuicheRandom(char seed) : seed_(seed) {} + ~TestQuicheRandom() override {} + + void RandBytes(void *data, size_t len) override { memset(data, seed_, len); } + + uint64_t RandUint64() override { + uint64_t random_int; + memset(&random_int, seed_, sizeof(random_int)); + return random_int; + } + + void InsecureRandBytes(void *data, size_t len) override { + return RandBytes(data, len); + } + uint64_t InsecureRandUint64() override { return RandUint64(); } + + private: + char seed_; +}; + +size_t GetResponseNonceLength(const EVP_HPKE_CTX &hpke_context) { + EXPECT_NE(&hpke_context, nullptr); + const EVP_AEAD *evp_hpke_aead = + EVP_HPKE_AEAD_aead(EVP_HPKE_CTX_aead(&hpke_context)); + EXPECT_NE(evp_hpke_aead, nullptr); + // Nk = [AEAD key len], is determined by BSSL. + const size_t aead_key_len = EVP_AEAD_key_length(evp_hpke_aead); + // Nn = [AEAD nonce len], is determined by BSSL. + const size_t aead_nonce_len = EVP_AEAD_nonce_length(evp_hpke_aead); + const size_t secret_len = std::max(aead_key_len, aead_nonce_len); + return secret_len; +} + +TEST(ObliviousHttpResponse, TestDecapsulateReceivedResponse) { + // Construct encrypted payload with plaintext: "test response" + absl::string_view encrypted_response = + "39d5b03c02c97e216df444e4681007105974d4df1585aae05e7b53f3ccdb55d51f711d48" + "eeefbc1a555d6d928e35df33fd23c23846fa7b083e30692f7b"; + auto oblivious_context = + SetUpObliviousHttpContext(4, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM, + "test") + .ReleaseContext(); + auto decapsulated = ObliviousHttpResponse::CreateClientObliviousResponse( + absl::HexStringToBytes(encrypted_response), oblivious_context); + EXPECT_TRUE(decapsulated.ok()); + auto decrypted = decapsulated->GetPlaintextData(); + EXPECT_EQ(decrypted, "test response"); +} +} // namespace + +TEST(ObliviousHttpResponse, EndToEndTestForResponse) { + auto oblivious_ctx = ObliviousHttpRequest::Context( + GetSeededClientContext(5, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM), + GetSeededEncapsulatedKey()); + auto server_response_encapsulate = + ObliviousHttpResponse::CreateServerObliviousResponse("test response", + oblivious_ctx); + EXPECT_TRUE(server_response_encapsulate.ok()); + auto oblivious_response = + server_response_encapsulate->EncapsulateAndSerialize(); + auto client_response_encapsulate = + ObliviousHttpResponse::CreateClientObliviousResponse(oblivious_response, + oblivious_ctx); + auto decrypted = client_response_encapsulate->GetPlaintextData(); + EXPECT_EQ(decrypted, "test response"); +} + +TEST(ObliviousHttpResponse, TestEncapsulateWithQuicheRandom) { + auto random = TestQuicheRandom('z'); + auto server_seeded_request = SetUpObliviousHttpContext( + 6, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_256_GCM, "test"); + auto server_request_context = + std::move(server_seeded_request).ReleaseContext(); + auto server_response_encapsulate = + ObliviousHttpResponse::CreateServerObliviousResponse( + "test response", server_request_context, &random); + EXPECT_TRUE(server_response_encapsulate.ok()); + std::string response_nonce = + server_response_encapsulate->EncapsulateAndSerialize().substr( + 0, GetResponseNonceLength(*(server_request_context.hpke_context_))); + EXPECT_EQ(response_nonce, + std::string( + GetResponseNonceLength(*(server_request_context.hpke_context_)), + 'z')); + absl::string_view expected_encrypted_response = + "2a3271ac4e6a501f51d0264d3dd7d0bc8a06973b58e89c26d6dac06144"; + EXPECT_EQ( + server_response_encapsulate->EncapsulateAndSerialize().substr( + GetResponseNonceLength(*(server_request_context.hpke_context_))), + absl::HexStringToBytes(expected_encrypted_response)); +} + +} // namespace quiche diff --git a/quiche/oblivious_http/common/oblivious_http_header_key_config.cc b/quiche/oblivious_http/common/oblivious_http_header_key_config.cc new file mode 100644 index 000000000000..b0103c23785b --- /dev/null +++ b/quiche/oblivious_http/common/oblivious_http_header_key_config.cc @@ -0,0 +1,472 @@ +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_data_writer.h" +#include "quiche/common/quiche_endian.h" + +namespace quiche { +namespace { + +// Size of KEM ID is 2 bytes. Refer to OHTTP Key Config in the spec, +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-06.html#name-a-single-key-configuration +constexpr size_t kSizeOfHpkeKemId = 2; + +// Size of Symmetric algorithms is 2 bytes(16 bits) each. +// Refer to HPKE Symmetric Algorithms configuration in the spec, +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-06.html#name-a-single-key-configuration +constexpr size_t kSizeOfSymmetricAlgorithmHpkeKdfId = 2; +constexpr size_t kSizeOfSymmetricAlgorithmHpkeAeadId = 2; + +absl::StatusOr CheckKemId(uint16_t kem_id) { + switch (kem_id) { + case EVP_HPKE_DHKEM_X25519_HKDF_SHA256: + return EVP_hpke_x25519_hkdf_sha256(); + default: + return absl::UnimplementedError("No support for this KEM ID."); + } +} + +absl::StatusOr CheckKdfId(uint16_t kdf_id) { + switch (kdf_id) { + case EVP_HPKE_HKDF_SHA256: + return EVP_hpke_hkdf_sha256(); + default: + return absl::UnimplementedError("No support for this KDF ID."); + } +} + +absl::StatusOr CheckAeadId(uint16_t aead_id) { + switch (aead_id) { + case EVP_HPKE_AES_128_GCM: + return EVP_hpke_aes_128_gcm(); + case EVP_HPKE_AES_256_GCM: + return EVP_hpke_aes_256_gcm(); + case EVP_HPKE_CHACHA20_POLY1305: + return EVP_hpke_chacha20_poly1305(); + default: + return absl::UnimplementedError("No support for this AEAD ID."); + } +} + +} // namespace + +ObliviousHttpHeaderKeyConfig::ObliviousHttpHeaderKeyConfig(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) + : key_id_(key_id), kem_id_(kem_id), kdf_id_(kdf_id), aead_id_(aead_id) {} + +absl::StatusOr +ObliviousHttpHeaderKeyConfig::Create(uint8_t key_id, uint16_t kem_id, + uint16_t kdf_id, uint16_t aead_id) { + ObliviousHttpHeaderKeyConfig instance(key_id, kem_id, kdf_id, aead_id); + auto is_config_ok = instance.ValidateKeyConfig(); + if (!is_config_ok.ok()) { + return is_config_ok; + } + return instance; +} + +absl::Status ObliviousHttpHeaderKeyConfig::ValidateKeyConfig() const { + auto supported_kem = CheckKemId(kem_id_); + if (!supported_kem.ok()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported KEM ID:", kem_id_)); + } + auto supported_kdf = CheckKdfId(kdf_id_); + if (!supported_kdf.ok()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported KDF ID:", kdf_id_)); + } + auto supported_aead = CheckAeadId(aead_id_); + if (!supported_aead.ok()) { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported AEAD ID:", aead_id_)); + } + return absl::OkStatus(); +} + +const EVP_HPKE_KEM* ObliviousHttpHeaderKeyConfig::GetHpkeKem() const { + auto kem = CheckKemId(kem_id_); + QUICHE_CHECK_OK(kem.status()); + return kem.value(); +} +const EVP_HPKE_KDF* ObliviousHttpHeaderKeyConfig::GetHpkeKdf() const { + auto kdf = CheckKdfId(kdf_id_); + QUICHE_CHECK_OK(kdf.status()); + return kdf.value(); +} +const EVP_HPKE_AEAD* ObliviousHttpHeaderKeyConfig::GetHpkeAead() const { + auto aead = CheckAeadId(aead_id_); + QUICHE_CHECK_OK(aead.status()); + return aead.value(); +} + +std::string ObliviousHttpHeaderKeyConfig::SerializeRecipientContextInfo() + const { + uint8_t zero_byte = 0x00; + int buf_len = kOhttpRequestLabel.size() + kHeaderLength + sizeof(zero_byte); + std::string info(buf_len, '\0'); + QuicheDataWriter writer(info.size(), info.data()); + QUICHE_CHECK(writer.WriteStringPiece(kOhttpRequestLabel)); + QUICHE_CHECK(writer.WriteUInt8(zero_byte)); // Zero byte. + QUICHE_CHECK(writer.WriteUInt8(key_id_)); + QUICHE_CHECK(writer.WriteUInt16(kem_id_)); + QUICHE_CHECK(writer.WriteUInt16(kdf_id_)); + QUICHE_CHECK(writer.WriteUInt16(aead_id_)); + return info; +} + +/** + * Follows IETF Ohttp spec, section 4.1 (Encapsulation of Requests). + * https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.1-10 + */ +absl::Status ObliviousHttpHeaderKeyConfig::ParseOhttpPayloadHeader( + absl::string_view payload_bytes) const { + if (payload_bytes.empty()) { + return absl::InvalidArgumentError("Empty request payload."); + } + QuicheDataReader reader(payload_bytes); + return ParseOhttpPayloadHeader(reader); +} + +absl::Status ObliviousHttpHeaderKeyConfig::ParseOhttpPayloadHeader( + QuicheDataReader& reader) const { + uint8_t key_id; + if (!reader.ReadUInt8(&key_id)) { + return absl::InvalidArgumentError("Failed to read key_id from header."); + } + if (key_id != key_id_) { + return absl::InvalidArgumentError( + absl::StrCat("KeyID in request:", static_cast(key_id), + " doesn't match with server's public key " + "configuration KeyID:", + static_cast(key_id_))); + } + uint16_t kem_id; + if (!reader.ReadUInt16(&kem_id)) { + return absl::InvalidArgumentError("Failed to read kem_id from header."); + } + if (kem_id != kem_id_) { + return absl::InvalidArgumentError( + absl::StrCat("Received Invalid kemID:", kem_id, " Expected:", kem_id_)); + } + uint16_t kdf_id; + if (!reader.ReadUInt16(&kdf_id)) { + return absl::InvalidArgumentError("Failed to read kdf_id from header."); + } + if (kdf_id != kdf_id_) { + return absl::InvalidArgumentError( + absl::StrCat("Received Invalid kdfID:", kdf_id, " Expected:", kdf_id_)); + } + uint16_t aead_id; + if (!reader.ReadUInt16(&aead_id)) { + return absl::InvalidArgumentError("Failed to read aead_id from header."); + } + if (aead_id != aead_id_) { + return absl::InvalidArgumentError(absl::StrCat( + "Received Invalid aeadID:", aead_id, " Expected:", aead_id_)); + } + return absl::OkStatus(); +} + +absl::StatusOr +ObliviousHttpHeaderKeyConfig::ParseKeyIdFromObliviousHttpRequestPayload( + absl::string_view payload_bytes) { + if (payload_bytes.empty()) { + return absl::InvalidArgumentError("Empty request payload."); + } + QuicheDataReader reader(payload_bytes); + uint8_t key_id; + if (!reader.ReadUInt8(&key_id)) { + return absl::InvalidArgumentError("Failed to read key_id from payload."); + } + return key_id; +} + +std::string ObliviousHttpHeaderKeyConfig::SerializeOhttpPayloadHeader() const { + int buf_len = + sizeof(key_id_) + sizeof(kem_id_) + sizeof(kdf_id_) + sizeof(aead_id_); + std::string hdr(buf_len, '\0'); + QuicheDataWriter writer(hdr.size(), hdr.data()); + QUICHE_CHECK(writer.WriteUInt8(key_id_)); + QUICHE_CHECK(writer.WriteUInt16(kem_id_)); // kemID + QUICHE_CHECK(writer.WriteUInt16(kdf_id_)); // kdfID + QUICHE_CHECK(writer.WriteUInt16(aead_id_)); // aeadID + return hdr; +} + +namespace { +// https://www.rfc-editor.org/rfc/rfc9180#section-7.1 +absl::StatusOr KeyLength(uint16_t kem_id) { + auto supported_kem = CheckKemId(kem_id); + if (!supported_kem.ok()) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported KEM ID:", kem_id, ". public key length is unknown.")); + } + return EVP_HPKE_KEM_public_key_len(supported_kem.value()); +} + +absl::StatusOr SerializeOhttpKeyWithPublicKey( + uint8_t key_id, absl::string_view public_key, + const std::vector& ohttp_configs) { + auto ohttp_config = ohttp_configs[0]; + // Check if `ohttp_config` match spec's encoding guidelines. + static_assert(sizeof(ohttp_config.GetHpkeKemId()) == kSizeOfHpkeKemId && + sizeof(ohttp_config.GetHpkeKdfId()) == + kSizeOfSymmetricAlgorithmHpkeKdfId && + sizeof(ohttp_config.GetHpkeAeadId()) == + kSizeOfSymmetricAlgorithmHpkeAeadId, + "Size of HPKE IDs should match RFC specification."); + + uint16_t symmetric_algs_length = + ohttp_configs.size() * (kSizeOfSymmetricAlgorithmHpkeKdfId + + kSizeOfSymmetricAlgorithmHpkeAeadId); + int buf_len = sizeof(key_id) + kSizeOfHpkeKemId + public_key.size() + + sizeof(symmetric_algs_length) + symmetric_algs_length; + std::string ohttp_key_configuration(buf_len, '\0'); + QuicheDataWriter writer(ohttp_key_configuration.size(), + ohttp_key_configuration.data()); + if (!writer.WriteUInt8(key_id)) { + return absl::InternalError("Failed to serialize OHTTP key.[key_id]"); + } + if (!writer.WriteUInt16(ohttp_config.GetHpkeKemId())) { + return absl::InternalError( + "Failed to serialize OHTTP key.[kem_id]"); // kemID. + } + if (!writer.WriteStringPiece(public_key)) { + return absl::InternalError( + "Failed to serialize OHTTP key.[public_key]"); // Raw public key. + } + if (!writer.WriteUInt16(symmetric_algs_length)) { + return absl::InternalError( + "Failed to serialize OHTTP key.[symmetric_algs_length]"); + } + for (const auto& item : ohttp_configs) { + // Check if KEM ID is the same for all the configs stored in `this` for + // given `key_id`. + if (item.GetHpkeKemId() != ohttp_config.GetHpkeKemId()) { + QUICHE_BUG(ohttp_key_configs_builder_parser) + << "ObliviousHttpKeyConfigs object cannot hold ConfigMap of " + "different KEM IDs:[ " + << item.GetHpkeKemId() << "," << ohttp_config.GetHpkeKemId() + << " ]for a given key_id:" << static_cast(key_id); + } + if (!writer.WriteUInt16(item.GetHpkeKdfId())) { + return absl::InternalError( + "Failed to serialize OHTTP key.[kdf_id]"); // kdfID. + } + if (!writer.WriteUInt16(item.GetHpkeAeadId())) { + return absl::InternalError( + "Failed to serialize OHTTP key.[aead_id]"); // aeadID. + } + } + QUICHE_DCHECK_EQ(writer.remaining(), 0u); + return ohttp_key_configuration; +} + +std::string GetDebugStringForFailedKeyConfig( + const ObliviousHttpKeyConfigs::OhttpKeyConfig& failed_key_config) { + std::string debug_string = "[ "; + absl::StrAppend(&debug_string, + "key_id:", static_cast(failed_key_config.key_id), + " , kem_id:", failed_key_config.kem_id, + ". Printing HEX formatted public_key:", + absl::BytesToHexString(failed_key_config.public_key)); + absl::StrAppend(&debug_string, ", symmetric_algorithms: { "); + for (const auto& symmetric_config : failed_key_config.symmetric_algorithms) { + absl::StrAppend(&debug_string, "{kdf_id: ", symmetric_config.kdf_id, + ", aead_id:", symmetric_config.aead_id, " }"); + } + absl::StrAppend(&debug_string, " } ]"); + return debug_string; +} + +// Verifies if the `key_config` contains all valid combinations of [kem_id, +// kdf_id, aead_id] that comprises Single Key configuration encoding as +// specified in +// https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#name-a-single-key-configuration. +absl::Status StoreKeyConfigIfValid( + ObliviousHttpKeyConfigs::OhttpKeyConfig key_config, + absl::btree_map, + std::greater>& configs, + absl::flat_hash_map& keys) { + if (!CheckKemId(key_config.kem_id).ok() || + key_config.public_key.size() != KeyLength(key_config.kem_id).value()) { + QUICHE_LOG(ERROR) << "Failed to process: " + << GetDebugStringForFailedKeyConfig(key_config); + return absl::InvalidArgumentError( + absl::StrCat("Invalid key_config! [KEM ID:", key_config.kem_id, "]")); + } + for (const auto& symmetric_config : key_config.symmetric_algorithms) { + if (!CheckKdfId(symmetric_config.kdf_id).ok() || + !CheckAeadId(symmetric_config.aead_id).ok()) { + QUICHE_LOG(ERROR) << "Failed to process: " + << GetDebugStringForFailedKeyConfig(key_config); + return absl::InvalidArgumentError( + absl::StrCat("Invalid key_config! [KDF ID:", symmetric_config.kdf_id, + ", AEAD ID:", symmetric_config.aead_id, "]")); + } + auto ohttp_config = ObliviousHttpHeaderKeyConfig::Create( + key_config.key_id, key_config.kem_id, symmetric_config.kdf_id, + symmetric_config.aead_id); + if (ohttp_config.ok()) { + configs[key_config.key_id].emplace_back(std::move(ohttp_config.value())); + } + } + keys.emplace(key_config.key_id, std::move(key_config.public_key)); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr +ObliviousHttpKeyConfigs::ParseConcatenatedKeys(absl::string_view key_config) { + ConfigMap configs; + PublicKeyMap keys; + auto reader = QuicheDataReader(key_config); + while (!reader.IsDoneReading()) { + absl::Status status = ReadSingleKeyConfig(reader, configs, keys); + if (!status.ok()) return status; + } + return ObliviousHttpKeyConfigs(std::move(configs), std::move(keys)); +} + +absl::StatusOr ObliviousHttpKeyConfigs::Create( + absl::flat_hash_set + ohttp_key_configs) { + if (ohttp_key_configs.empty()) { + return absl::InvalidArgumentError("Empty input."); + } + ConfigMap configs_map; + PublicKeyMap keys_map; + for (auto& ohttp_key_config : ohttp_key_configs) { + auto result = StoreKeyConfigIfValid(std::move(ohttp_key_config), + configs_map, keys_map); + if (!result.ok()) { + return result; + } + } + auto oblivious_configs = + ObliviousHttpKeyConfigs(std::move(configs_map), std::move(keys_map)); + return oblivious_configs; +} + +absl::StatusOr ObliviousHttpKeyConfigs::Create( + const ObliviousHttpHeaderKeyConfig& single_key_config, + absl::string_view public_key) { + if (public_key.empty()) { + return absl::InvalidArgumentError("Empty input."); + } + + if (auto key_length = KeyLength(single_key_config.GetHpkeKemId()); + public_key.size() != key_length.value()) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid key. Key size mismatch. Expected:", key_length.value(), + " Actual:", public_key.size())); + } + + ConfigMap configs; + PublicKeyMap keys; + uint8_t key_id = single_key_config.GetKeyId(); + keys.emplace(key_id, public_key); + configs[key_id].emplace_back(std::move(single_key_config)); + return ObliviousHttpKeyConfigs(std::move(configs), std::move(keys)); +} + +absl::StatusOr ObliviousHttpKeyConfigs::GenerateConcatenatedKeys() + const { + std::string concatenated_keys; + for (const auto& [key_id, ohttp_configs] : configs_) { + auto key = public_keys_.find(key_id); + if (key == public_keys_.end()) { + return absl::InternalError( + "Failed to serialize. No public key found for key_id"); + } + auto serialized = + SerializeOhttpKeyWithPublicKey(key_id, key->second, ohttp_configs); + if (!serialized.ok()) { + return absl::InternalError("Failed to serialize OHTTP key configs."); + } + absl::StrAppend(&concatenated_keys, serialized.value()); + } + return concatenated_keys; +} + +ObliviousHttpHeaderKeyConfig ObliviousHttpKeyConfigs::PreferredConfig() const { + // configs_ is forced to have at least one object during construction. + return configs_.begin()->second.front(); +} + +absl::StatusOr ObliviousHttpKeyConfigs::GetPublicKeyForId( + uint8_t key_id) const { + auto key = public_keys_.find(key_id); + if (key == public_keys_.end()) { + return absl::NotFoundError("No public key found for key_id"); + } + return key->second; +} + +absl::Status ObliviousHttpKeyConfigs::ReadSingleKeyConfig( + QuicheDataReader& reader, ConfigMap& configs, PublicKeyMap& keys) { + uint8_t key_id; + uint16_t kem_id; + // First byte: key_id; next two bytes: kem_id. + if (!reader.ReadUInt8(&key_id) || !reader.ReadUInt16(&kem_id)) { + return absl::InvalidArgumentError("Invalid key_config!"); + } + + // Public key length depends on the kem_id. + auto maybe_key_length = KeyLength(kem_id); + if (!maybe_key_length.ok()) { + return maybe_key_length.status(); + } + const int key_length = maybe_key_length.value(); + std::string key_str(key_length, '\0'); + if (!reader.ReadBytes(key_str.data(), key_length)) { + return absl::InvalidArgumentError("Invalid key_config!"); + } + if (!keys.insert({key_id, std::move(key_str)}).second) { + return absl::InvalidArgumentError("Duplicate key_id's in key_config!"); + } + + // Extract the algorithms for this public key. + absl::string_view alg_bytes; + // Read the 16-bit length, then read that many bytes into alg_bytes. + if (!reader.ReadStringPiece16(&alg_bytes)) { + return absl::InvalidArgumentError("Invalid key_config!"); + } + QuicheDataReader sub_reader(alg_bytes); + while (!sub_reader.IsDoneReading()) { + uint16_t kdf_id; + uint16_t aead_id; + if (!sub_reader.ReadUInt16(&kdf_id) || !sub_reader.ReadUInt16(&aead_id)) { + return absl::InvalidArgumentError("Invalid key_config!"); + } + + absl::StatusOr maybe_cfg = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + if (!maybe_cfg.ok()) { + // TODO(kmg): Add support to ignore key types in the server response that + // aren't supported by the client. + return maybe_cfg.status(); + } + configs[key_id].emplace_back(std::move(maybe_cfg.value())); + } + return absl::OkStatus(); +} + +} // namespace quiche diff --git a/quiche/oblivious_http/common/oblivious_http_header_key_config.h b/quiche/oblivious_http/common/oblivious_http_header_key_config.h new file mode 100644 index 000000000000..488561b55e5a --- /dev/null +++ b/quiche/oblivious_http/common/oblivious_http_header_key_config.h @@ -0,0 +1,220 @@ +#ifndef QUICHE_OBLIVIOUS_HTTP_COMMON_OBLIVIOUS_HTTP_HEADER_KEY_CONFIG_H_ +#define QUICHE_OBLIVIOUS_HTTP_COMMON_OBLIVIOUS_HTTP_HEADER_KEY_CONFIG_H_ + +#include + +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_data_reader.h" + +namespace quiche { + +class QUICHE_EXPORT ObliviousHttpHeaderKeyConfig { + public: + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.1-4.2 + static constexpr absl::string_view kOhttpRequestLabel = + "message/bhttp request"; + static constexpr absl::string_view kOhttpResponseLabel = + "message/bhttp response"; + // Length of the Oblivious HTTP header. + static constexpr uint32_t kHeaderLength = + sizeof(uint8_t) + (3 * sizeof(uint16_t)); + static constexpr absl::string_view kKeyHkdfInfo = "key"; + static constexpr absl::string_view kNonceHkdfInfo = "nonce"; + + static absl::StatusOr Create(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id); + + // Copyable to support stack allocated pass-by-value for trivial data members. + ObliviousHttpHeaderKeyConfig(const ObliviousHttpHeaderKeyConfig& other) = + default; + ObliviousHttpHeaderKeyConfig& operator=( + const ObliviousHttpHeaderKeyConfig& other) = default; + + // Movable. + ObliviousHttpHeaderKeyConfig(ObliviousHttpHeaderKeyConfig&& other) = default; + ObliviousHttpHeaderKeyConfig& operator=( + ObliviousHttpHeaderKeyConfig&& other) = default; + + ~ObliviousHttpHeaderKeyConfig() = default; + + const EVP_HPKE_KEM* GetHpkeKem() const; + const EVP_HPKE_KDF* GetHpkeKdf() const; + const EVP_HPKE_AEAD* GetHpkeAead() const; + + uint8_t GetKeyId() const { return key_id_; } + uint16_t GetHpkeKemId() const { return kem_id_; } + uint16_t GetHpkeKdfId() const { return kdf_id_; } + uint16_t GetHpkeAeadId() const { return aead_id_; } + + // Build HPKE context info ["message/bhttp request", 0x00, keyID(1 byte), + // kemID(2 bytes), kdfID(2 bytes), aeadID(2 bytes)] in network byte order and + // return a sequence of bytes(bytestring). + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.1-10 + std::string SerializeRecipientContextInfo() const; + + // Parses the below Header + // [keyID(1 byte), kemID(2 bytes), kdfID(2 bytes), aeadID(2 bytes)] + // from the payload received in Ohttp Request, and verifies that these values + // match with the info stored in `this` namely [key_id_, kem_id_, kdf_id_, + // aead_id_] + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#section-4.1-7 + absl::Status ParseOhttpPayloadHeader(absl::string_view payload_bytes) const; + + // Parses the Oblivious HTTP header [keyID(1 byte), kemID(2 bytes), kdfID(2 + // bytes), aeadID(2 bytes)] from the buffer initialized within + // `QuicheDataReader`, and verifies these values against instantiated class + // data namely [key_id_, kem_id_, kdf_id_, aead_id_] for a match. On + // success(i.e., if matched successfully), leaves `reader` pointing at the + // first byte after the header. + absl::Status ParseOhttpPayloadHeader(QuicheDataReader& reader) const; + + // Extracts Key ID from the OHTTP Request payload. + static absl::StatusOr ParseKeyIdFromObliviousHttpRequestPayload( + absl::string_view payload_bytes); + + // Build Request header according to network byte order and return string. + std::string SerializeOhttpPayloadHeader() const; + + private: + // Constructor + explicit ObliviousHttpHeaderKeyConfig(uint8_t key_id, uint16_t kem_id, + uint16_t kdf_id, uint16_t aead_id); + + // Helps validate Key configuration for supported schemes. + absl::Status ValidateKeyConfig() const; + + // Public Key configuration hosted by Gateway to facilitate Oblivious HTTP + // HPKE encryption. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#name-key-configuration-encoding + uint8_t key_id_; + uint16_t kem_id_; + uint16_t kdf_id_; + uint16_t aead_id_; +}; + +// Contains multiple ObliviousHttpHeaderKeyConfig objects and associated private +// keys. An ObliviousHttpHeaderKeyConfigs object can be constructed from the +// "Key Configuration" defined in the Oblivious HTTP spec. Multiple key +// configurations maybe be supported by the server. +// +// See https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-04.html#section-3 +// for details of the "Key Configuration" spec. +// +// ObliviousHttpKeyConfigs objects are immutable after construction. +class QUICHE_EXPORT ObliviousHttpKeyConfigs { + public: + // Below two structures follow the Single key configuration spec in OHTTP RFC. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-06.html#name-a-single-key-configuration + struct SymmetricAlgorithmsConfig { + uint16_t kdf_id; + uint16_t aead_id; + + bool operator==(const SymmetricAlgorithmsConfig& other) const { + return kdf_id == other.kdf_id && aead_id == other.aead_id; + } + + template + friend H AbslHashValue(H h, const SymmetricAlgorithmsConfig& sym_alg_cfg) { + return H::combine(std::move(h), sym_alg_cfg.kdf_id, sym_alg_cfg.aead_id); + } + }; + + struct OhttpKeyConfig { + uint8_t key_id; + uint16_t kem_id; + std::string public_key; // Raw byte string. + absl::flat_hash_set symmetric_algorithms; + + bool operator==(const OhttpKeyConfig& other) const { + return key_id == other.key_id && kem_id == other.kem_id && + public_key == other.public_key && + symmetric_algorithms == other.symmetric_algorithms; + } + + template + friend H AbslHashValue(H h, const OhttpKeyConfig& ohttp_key_cfg) { + return H::combine(std::move(h), ohttp_key_cfg.key_id, + ohttp_key_cfg.kem_id, ohttp_key_cfg.public_key, + ohttp_key_cfg.symmetric_algorithms); + } + }; + + // Parses the "application/ohttp-keys" media type, which is a byte string + // formatted according to the spec: + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-04.html#section-3 + static absl::StatusOr ParseConcatenatedKeys( + absl::string_view key_configs); + + // Builds `ObliviousHttpKeyConfigs` with multiple key configurations, each + // made up of Single Key Configuration([{key_id, kem_id, public key}, + // Set]) encoding specified in section 3. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#name-key-configuration-encoding + // @params: Set<{key_id, kem_id, public key, Set>. + // @return: When given all valid configs supported by BoringSSL, builds and + // returns `ObliviousHttpKeyConfigs`. If any one of the input configs are + // invalid or unsupported by BSSL, returns an error. + // @note: Subsequently, To get concatenated keys[contiguous byte string of + // keys], use `GenerateConcatenatedKeys()`. This output can inturn be parsed + // by `ObliviousHttpKeyConfigs::ParseConcatenatedKeys` on client side. + static absl::StatusOr Create( + absl::flat_hash_set ohttp_key_configs); + + // Builds `ObliviousHttpKeyConfigs` with given public_key and Single key + // configuration specified in `ObliviousHttpHeaderKeyConfig` object. After + // successful `Create`, clients can call `GenerateConcatenatedKeys()` to build + // the Single key config. + static absl::StatusOr Create( + const ObliviousHttpHeaderKeyConfig& single_key_config, + absl::string_view public_key); + + // Generates byte string corresponding to "application/ohttp-keys" media type. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-04.html#section-3 + absl::StatusOr GenerateConcatenatedKeys() const; + + int NumKeys() const { return public_keys_.size(); } + + // Returns a preferred config to use. The preferred key is the key with + // the highest key_id. If more than one configuration exists for the + // preferred key any configuration may be returned. + // + // These methods are useful in the (common) case where only one key + // configuration is supported by the server. + ObliviousHttpHeaderKeyConfig PreferredConfig() const; + + absl::StatusOr GetPublicKeyForId(uint8_t key_id) const; + + // TODO(kmg): Add methods to somehow access other non-preferred key + // configurations. + + private: + using PublicKeyMap = absl::flat_hash_map; + using ConfigMap = + absl::btree_map, + std::greater>; + + ObliviousHttpKeyConfigs(ConfigMap cm, PublicKeyMap km) + : configs_(std::move(cm)), public_keys_(std::move(km)) {} + + static absl::Status ReadSingleKeyConfig(QuicheDataReader& reader, + ConfigMap& configs, + PublicKeyMap& keys); + + // A mapping from key_id to ObliviousHttpHeaderKeyConfig objects for that key. + const ConfigMap configs_; + + // A mapping from key_id to the public key for that key_id. + const PublicKeyMap public_keys_; +}; + +} // namespace quiche + +#endif // QUICHE_OBLIVIOUS_HTTP_COMMON_OBLIVIOUS_HTTP_HEADER_KEY_CONFIG_H_ diff --git a/quiche/oblivious_http/common/oblivious_http_header_key_config_test.cc b/quiche/oblivious_http/common/oblivious_http_header_key_config_test.cc new file mode 100644 index 000000000000..73e4c07afc33 --- /dev/null +++ b/quiche/oblivious_http/common/oblivious_http_header_key_config_test.cc @@ -0,0 +1,356 @@ +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_data_writer.h" + +namespace quiche { +namespace { +using ::testing::AllOf; +using ::testing::Property; +using ::testing::StrEq; +using ::testing::UnorderedElementsAre; +using ::testing::UnorderedElementsAreArray; + +/** + * Build Request header. + */ +std::string BuildHeader(uint8_t key_id, uint16_t kem_id, uint16_t kdf_id, + uint16_t aead_id) { + int buf_len = + sizeof(key_id) + sizeof(kem_id) + sizeof(kdf_id) + sizeof(aead_id); + std::string hdr(buf_len, '\0'); + QuicheDataWriter writer(hdr.size(), hdr.data()); + EXPECT_TRUE(writer.WriteUInt8(key_id)); + EXPECT_TRUE(writer.WriteUInt16(kem_id)); // kemID + EXPECT_TRUE(writer.WriteUInt16(kdf_id)); // kdfID + EXPECT_TRUE(writer.WriteUInt16(aead_id)); // aeadID + return hdr; +} + +std::string GetSerializedKeyConfig( + ObliviousHttpKeyConfigs::OhttpKeyConfig& key_config) { + uint16_t symmetric_algs_length = + key_config.symmetric_algorithms.size() * + (sizeof(key_config.symmetric_algorithms.cbegin()->kdf_id) + + sizeof(key_config.symmetric_algorithms.cbegin()->aead_id)); + int buf_len = sizeof(key_config.key_id) + sizeof(key_config.kem_id) + + key_config.public_key.size() + sizeof(symmetric_algs_length) + + symmetric_algs_length; + std::string ohttp_key(buf_len, '\0'); + QuicheDataWriter writer(ohttp_key.size(), ohttp_key.data()); + EXPECT_TRUE(writer.WriteUInt8(key_config.key_id)); + EXPECT_TRUE(writer.WriteUInt16(key_config.kem_id)); + EXPECT_TRUE(writer.WriteStringPiece(key_config.public_key)); + EXPECT_TRUE(writer.WriteUInt16(symmetric_algs_length)); + for (const auto& symmetric_alg : key_config.symmetric_algorithms) { + EXPECT_TRUE(writer.WriteUInt16(symmetric_alg.kdf_id)); + EXPECT_TRUE(writer.WriteUInt16(symmetric_alg.aead_id)); + } + return ohttp_key; +} + +TEST(ObliviousHttpHeaderKeyConfig, TestSerializeRecipientContextInfo) { + uint8_t key_id = 3; + uint16_t kem_id = EVP_HPKE_DHKEM_X25519_HKDF_SHA256; + uint16_t kdf_id = EVP_HPKE_HKDF_SHA256; + uint16_t aead_id = EVP_HPKE_AES_256_GCM; + absl::string_view ohttp_req_label = "message/bhttp request"; + std::string expected(ohttp_req_label); + uint8_t zero_byte = 0x00; + int buf_len = ohttp_req_label.size() + sizeof(zero_byte) + sizeof(key_id) + + sizeof(kem_id) + sizeof(kdf_id) + sizeof(aead_id); + expected.reserve(buf_len); + expected.push_back(zero_byte); + std::string ohttp_cfg(BuildHeader(key_id, kem_id, kdf_id, aead_id)); + expected.insert(expected.end(), ohttp_cfg.begin(), ohttp_cfg.end()); + auto instance = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + ASSERT_TRUE(instance.ok()); + EXPECT_EQ(instance.value().SerializeRecipientContextInfo(), expected); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestValidKeyConfig) { + auto valid_key_config = ObliviousHttpHeaderKeyConfig::Create( + 2, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_256_GCM); + ASSERT_TRUE(valid_key_config.ok()); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestInvalidKeyConfig) { + auto invalid_kem = ObliviousHttpHeaderKeyConfig::Create( + 3, 0, EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + EXPECT_EQ(invalid_kem.status().code(), absl::StatusCode::kInvalidArgument); + auto invalid_kdf = ObliviousHttpHeaderKeyConfig::Create( + 3, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, 0, EVP_HPKE_AES_256_GCM); + EXPECT_EQ(invalid_kdf.status().code(), absl::StatusCode::kInvalidArgument); + auto invalid_aead = ObliviousHttpHeaderKeyConfig::Create( + 3, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, 0); + EXPECT_EQ(invalid_kdf.status().code(), absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestParsingValidHeader) { + auto instance = ObliviousHttpHeaderKeyConfig::Create( + 5, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_256_GCM); + ASSERT_TRUE(instance.ok()); + std::string good_hdr(BuildHeader(5, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)); + ASSERT_TRUE(instance.value().ParseOhttpPayloadHeader(good_hdr).ok()); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestParsingInvalidHeader) { + auto instance = ObliviousHttpHeaderKeyConfig::Create( + 8, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_256_GCM); + ASSERT_TRUE(instance.ok()); + std::string keyid_mismatch_hdr( + BuildHeader(0, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_256_GCM)); + EXPECT_EQ(instance.value().ParseOhttpPayloadHeader(keyid_mismatch_hdr).code(), + absl::StatusCode::kInvalidArgument); + std::string invalid_hpke_hdr(BuildHeader(8, 0, 0, 0)); + EXPECT_EQ(instance.value().ParseOhttpPayloadHeader(invalid_hpke_hdr).code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestParsingKeyIdFromObliviousHttpRequest) { + std::string key_id(sizeof(uint8_t), '\0'); + QuicheDataWriter writer(key_id.size(), key_id.data()); + EXPECT_TRUE(writer.WriteUInt8(99)); + auto parsed_key_id = + ObliviousHttpHeaderKeyConfig::ParseKeyIdFromObliviousHttpRequestPayload( + key_id); + ASSERT_TRUE(parsed_key_id.ok()); + EXPECT_EQ(parsed_key_id.value(), 99); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestCopyable) { + auto obj1 = ObliviousHttpHeaderKeyConfig::Create( + 4, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_256_GCM); + ASSERT_TRUE(obj1.ok()); + auto copy_obj1_to_obj2 = obj1.value(); + EXPECT_EQ(copy_obj1_to_obj2.kHeaderLength, obj1->kHeaderLength); + EXPECT_EQ(copy_obj1_to_obj2.SerializeRecipientContextInfo(), + obj1->SerializeRecipientContextInfo()); +} + +TEST(ObliviousHttpHeaderKeyConfig, TestSerializeOhttpPayloadHeader) { + auto instance = ObliviousHttpHeaderKeyConfig::Create( + 7, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_128_GCM); + ASSERT_TRUE(instance.ok()); + EXPECT_EQ(instance->SerializeOhttpPayloadHeader(), + BuildHeader(7, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_128_GCM)); +} + +MATCHER_P(HasKeyId, id, "") { + *result_listener << "has key_id=" << arg.GetKeyId(); + return arg.GetKeyId() == id; +} +MATCHER_P(HasKemId, id, "") { + *result_listener << "has kem_id=" << arg.GetHpkeKemId(); + return arg.GetHpkeKemId() == id; +} +MATCHER_P(HasKdfId, id, "") { + *result_listener << "has kdf_id=" << arg.GetHpkeKdfId(); + return arg.GetHpkeKdfId() == id; +} +MATCHER_P(HasAeadId, id, "") { + *result_listener << "has aead_id=" << arg.GetHpkeAeadId(); + return arg.GetHpkeAeadId() == id; +} + +TEST(ObliviousHttpKeyConfigs, SingleKeyConfig) { + std::string key = absl::HexStringToBytes( + "4b0020f83e0a17cbdb18d2684dd2a9b087a43e5f3fa3fa27a049bc746a6e97a1e0244b00" + "0400010002"); + auto configs = ObliviousHttpKeyConfigs::ParseConcatenatedKeys(key).value(); + EXPECT_THAT(configs, Property(&ObliviousHttpKeyConfigs::NumKeys, 1)); + EXPECT_THAT( + configs.PreferredConfig(), + AllOf(HasKeyId(0x4b), HasKemId(EVP_HPKE_DHKEM_X25519_HKDF_SHA256), + HasKdfId(EVP_HPKE_HKDF_SHA256), HasAeadId(EVP_HPKE_AES_256_GCM))); + EXPECT_THAT( + configs.GetPublicKeyForId(configs.PreferredConfig().GetKeyId()).value(), + StrEq(absl::HexStringToBytes( + "f83e0a17cbdb18d2684dd2a9b087a43e5f3fa3fa27a049bc746a6e97a1e0244b"))); +} + +TEST(ObliviousHttpKeyConfigs, TwoSimilarKeyConfigs) { + std::string key = absl::HexStringToBytes( + "4b0020f83e0a17cbdb18d2684dd2a9b087a43e5f3fa3fa27a049bc746a6e97a1e0244b00" + "0400010002" // Intentional concatenation + "4f0020f83e0a17cbdb18d2684dd2a9b087a43e5f3fa3fa27a049bc746a6e97a1e0244b00" + "0400010001"); + EXPECT_THAT(ObliviousHttpKeyConfigs::ParseConcatenatedKeys(key).value(), + Property(&ObliviousHttpKeyConfigs::NumKeys, 2)); + EXPECT_THAT( + ObliviousHttpKeyConfigs::ParseConcatenatedKeys(key)->PreferredConfig(), + AllOf(HasKeyId(0x4f), HasKemId(EVP_HPKE_DHKEM_X25519_HKDF_SHA256), + HasKdfId(EVP_HPKE_HKDF_SHA256), HasAeadId(EVP_HPKE_AES_128_GCM))); +} + +TEST(ObliviousHttpKeyConfigs, RFCExample) { + std::string key = absl::HexStringToBytes( + "01002031e1f05a740102115220e9af918f738674aec95f54db6e04eb705aae8e79815500" + "080001000100010003"); + auto configs = ObliviousHttpKeyConfigs::ParseConcatenatedKeys(key).value(); + EXPECT_THAT(configs, Property(&ObliviousHttpKeyConfigs::NumKeys, 1)); + EXPECT_THAT( + configs.PreferredConfig(), + AllOf(HasKeyId(0x01), HasKemId(EVP_HPKE_DHKEM_X25519_HKDF_SHA256), + HasKdfId(EVP_HPKE_HKDF_SHA256), HasAeadId(EVP_HPKE_AES_128_GCM))); + EXPECT_THAT( + configs.GetPublicKeyForId(configs.PreferredConfig().GetKeyId()).value(), + StrEq(absl::HexStringToBytes( + "31e1f05a740102115220e9af918f738674aec95f54db6e04eb705aae8e798155"))); +} + +TEST(ObliviousHttpKeyConfigs, DuplicateKeyId) { + std::string key = absl::HexStringToBytes( + "4b0020f83e0a17cbdb18d2684dd2a9b087a43e5f3fa3fa27a049bc746a6e97a1e0244b00" + "0400010002" // Intentional concatenation + "4b0020f83e0a17cbdb18d2684dd2a9b087a43e5f3fa3fb27a049bc746a6e97a1e0244b00" + "0400010001"); + EXPECT_FALSE(ObliviousHttpKeyConfigs::ParseConcatenatedKeys(key).ok()); +} + +TEST(ObliviousHttpHeaderKeyConfigs, TestCreateWithSingleKeyConfig) { + auto instance = ObliviousHttpHeaderKeyConfig::Create( + 123, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_CHACHA20_POLY1305); + EXPECT_TRUE(instance.ok()); + std::string test_public_key( + EVP_HPKE_KEM_public_key_len(instance->GetHpkeKem()), 'a'); + auto configs = + ObliviousHttpKeyConfigs::Create(instance.value(), test_public_key); + EXPECT_TRUE(configs.ok()); + auto serialized_key = configs->GenerateConcatenatedKeys(); + EXPECT_TRUE(serialized_key.ok()); + auto ohttp_configs = + ObliviousHttpKeyConfigs::ParseConcatenatedKeys(serialized_key.value()); + EXPECT_TRUE(ohttp_configs.ok()); + ASSERT_EQ(ohttp_configs->PreferredConfig().GetKeyId(), 123); + auto parsed_public_key = ohttp_configs->GetPublicKeyForId(123); + EXPECT_TRUE(parsed_public_key.ok()); + EXPECT_EQ(parsed_public_key.value(), test_public_key); +} + +TEST(ObliviousHttpHeaderKeyConfigs, TestCreateWithWithMultipleKeys) { + std::string expected_preferred_public_key(32, 'b'); + ObliviousHttpKeyConfigs::OhttpKeyConfig config1 = { + 100, + EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + std::string(32, 'a'), + {{EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM}}}; + ObliviousHttpKeyConfigs::OhttpKeyConfig config2 = { + 200, + EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + expected_preferred_public_key, + {{EVP_HPKE_HKDF_SHA256, EVP_HPKE_CHACHA20_POLY1305}}}; + auto configs = ObliviousHttpKeyConfigs::Create({config1, config2}); + EXPECT_TRUE(configs.ok()); + auto serialized_key = configs->GenerateConcatenatedKeys(); + EXPECT_TRUE(serialized_key.ok()); + ASSERT_EQ(serialized_key.value(), + absl::StrCat(GetSerializedKeyConfig(config2), + GetSerializedKeyConfig(config1))); + auto ohttp_configs = + ObliviousHttpKeyConfigs::ParseConcatenatedKeys(serialized_key.value()); + EXPECT_TRUE(ohttp_configs.ok()); + ASSERT_EQ(ohttp_configs->NumKeys(), 2); + EXPECT_THAT(configs->PreferredConfig(), + AllOf(HasKeyId(200), HasKemId(EVP_HPKE_DHKEM_X25519_HKDF_SHA256), + HasKdfId(EVP_HPKE_HKDF_SHA256), + HasAeadId(EVP_HPKE_CHACHA20_POLY1305))); + auto parsed_preferred_public_key = ohttp_configs->GetPublicKeyForId( + ohttp_configs->PreferredConfig().GetKeyId()); + EXPECT_TRUE(parsed_preferred_public_key.ok()); + EXPECT_EQ(parsed_preferred_public_key.value(), expected_preferred_public_key); +} + +TEST(ObliviousHttpHeaderKeyConfigs, TestCreateWithInvalidConfigs) { + ASSERT_EQ(ObliviousHttpKeyConfigs::Create({}).status().code(), + absl::StatusCode::kInvalidArgument); + ASSERT_EQ(ObliviousHttpKeyConfigs::Create( + {{100, 2, std::string(32, 'a'), {{2, 3}, {4, 5}}}, + {200, 6, std::string(32, 'b'), {{7, 8}, {9, 10}}}}) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + + EXPECT_EQ( + ObliviousHttpKeyConfigs::Create( + {{123, + EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + "invalid key length" /*expected length for given kem_id is 32*/, + {{EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_128_GCM}}}}) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpHeaderKeyConfigs, + TestCreateSingleKeyConfigWithInvalidConfig) { + const auto sample_ohttp_hdr_config = ObliviousHttpHeaderKeyConfig::Create( + 123, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_128_GCM); + ASSERT_TRUE(sample_ohttp_hdr_config.ok()); + ASSERT_EQ(ObliviousHttpKeyConfigs::Create(sample_ohttp_hdr_config.value(), + "" /*empty public_key*/) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + EXPECT_EQ(ObliviousHttpKeyConfigs::Create( + sample_ohttp_hdr_config.value(), + "invalid key length" /*expected length for given kem_id is 32*/) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpHeaderKeyConfigs, TestHashImplWithObliviousStruct) { + // Insert different symmetric algorithms 50 times. + absl::flat_hash_set + symmetric_algs_set; + for (int i = 0; i < 50; ++i) { + symmetric_algs_set.insert({EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_128_GCM}); + symmetric_algs_set.insert({EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM}); + symmetric_algs_set.insert( + {EVP_HPKE_HKDF_SHA256, EVP_HPKE_CHACHA20_POLY1305}); + } + ASSERT_EQ(symmetric_algs_set.size(), 3); + EXPECT_THAT(symmetric_algs_set, + UnorderedElementsAreArray< + ObliviousHttpKeyConfigs::SymmetricAlgorithmsConfig>({ + {EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_128_GCM}, + {EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM}, + {EVP_HPKE_HKDF_SHA256, EVP_HPKE_CHACHA20_POLY1305}, + })); + + // Insert different Key configs 50 times. + absl::flat_hash_set + ohttp_key_configs_set; + ObliviousHttpKeyConfigs::OhttpKeyConfig expected_key_config{ + 100, + EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + std::string(32, 'c'), + {{EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_128_GCM}, + {EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM}}}; + for (int i = 0; i < 50; ++i) { + ohttp_key_configs_set.insert(expected_key_config); + } + ASSERT_EQ(ohttp_key_configs_set.size(), 1); + EXPECT_THAT(ohttp_key_configs_set, UnorderedElementsAre(expected_key_config)); +} + +} // namespace +} // namespace quiche diff --git a/quiche/oblivious_http/oblivious_http_client.cc b/quiche/oblivious_http/oblivious_http_client.cc new file mode 100644 index 000000000000..8a77c752e0da --- /dev/null +++ b/quiche/oblivious_http/oblivious_http_client.cc @@ -0,0 +1,91 @@ +#include "quiche/oblivious_http/oblivious_http_client.h" + +#include +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_crypto_logging.h" + +namespace quiche { + +namespace { + +// Use BoringSSL's setup_sender API to validate whether the HPKE public key +// input provided by the user is valid. +absl::Status ValidateClientParameters( + absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config) { + // Initialize HPKE client context and check if context can be setup with the + // given public key to verify if the public key is indeed valid. + bssl::UniquePtr client_ctx(EVP_HPKE_CTX_new()); + if (client_ctx == nullptr) { + return SslErrorAsStatus( + "Failed to initialize HPKE ObliviousHttpClient Context."); + } + // Setup the sender (client) + std::string encapsulated_key(EVP_HPKE_MAX_ENC_LENGTH, '\0'); + size_t enc_len; + absl::string_view info = "verify if given HPKE public key is valid"; + if (!EVP_HPKE_CTX_setup_sender( + client_ctx.get(), reinterpret_cast(encapsulated_key.data()), + &enc_len, encapsulated_key.size(), ohttp_key_config.GetHpkeKem(), + ohttp_key_config.GetHpkeKdf(), ohttp_key_config.GetHpkeAead(), + reinterpret_cast(hpke_public_key.data()), + hpke_public_key.size(), reinterpret_cast(info.data()), + info.size())) { + return SslErrorAsStatus( + "Failed to setup HPKE context with given public key param " + "hpke_public_key."); + } + return absl::OkStatus(); +} + +} // namespace + +// Constructor. +ObliviousHttpClient::ObliviousHttpClient( + std::string client_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config) + : hpke_public_key_(std::move(client_public_key)), + ohttp_key_config_(ohttp_key_config) {} + +// Initialize Bssl. +absl::StatusOr ObliviousHttpClient::Create( + absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config) { + if (hpke_public_key.empty()) { + return absl::InvalidArgumentError("Invalid/Empty HPKE public key."); + } + auto is_valid_input = + ValidateClientParameters(hpke_public_key, ohttp_key_config); + if (!is_valid_input.ok()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid input received in method parameters. ", + is_valid_input.message())); + } + return ObliviousHttpClient(std::string(hpke_public_key), ohttp_key_config); +} + +absl::StatusOr +ObliviousHttpClient::CreateObliviousHttpRequest( + std::string plaintext_data) const { + return ObliviousHttpRequest::CreateClientObliviousRequest( + std::move(plaintext_data), hpke_public_key_, ohttp_key_config_); +} + +absl::StatusOr +ObliviousHttpClient::DecryptObliviousHttpResponse( + std::string encrypted_data, + ObliviousHttpRequest::Context& oblivious_http_request_context) const { + return ObliviousHttpResponse::CreateClientObliviousResponse( + std::move(encrypted_data), oblivious_http_request_context); +} + +} // namespace quiche diff --git a/quiche/oblivious_http/oblivious_http_client.h b/quiche/oblivious_http/oblivious_http_client.h new file mode 100644 index 000000000000..9527b68b731a --- /dev/null +++ b/quiche/oblivious_http/oblivious_http_client.h @@ -0,0 +1,80 @@ +#ifndef QUICHE_OBLIVIOUS_HTTP_OBLIVIOUS_HTTP_CLIENT_H_ +#define QUICHE_OBLIVIOUS_HTTP_OBLIVIOUS_HTTP_CLIENT_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" +#include "quiche/oblivious_http/buffers/oblivious_http_response.h" +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace quiche { +// 1. Facilitates client side to intiate OHttp request flow by initializing the +// HPKE public key obtained from server, and subsequently uses it to encrypt the +// Binary HTTP request payload. +// 2. After initializing this class with server's HPKE public key, users can +// call `CreateObliviousHttpRequest` which constructs OHTTP request of the input +// payload(Binary HTTP request). +// 3. Handles decryption of response (that's in the form of encrypted Binary +// HTTP response) that will be sent back from Server-to-Relay and +// Relay-to-client in HTTP POST body. +// 4. Handles BoringSSL HPKE context setup and bookkeeping. + +// This class is immutable (except moves) and thus trivially thread-safe. +class QUICHE_EXPORT ObliviousHttpClient { + public: + static absl::StatusOr Create( + absl::string_view hpke_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config); + + // Copyable. + ObliviousHttpClient(const ObliviousHttpClient& other) = default; + ObliviousHttpClient& operator=(const ObliviousHttpClient& other) = default; + + // Movable. + ObliviousHttpClient(ObliviousHttpClient&& other) = default; + ObliviousHttpClient& operator=(ObliviousHttpClient&& other) = default; + + ~ObliviousHttpClient() = default; + + // After successful `Create`, callers will use the returned object to + // repeatedly call into this method in order to create Oblivious HTTP request + // with the initialized HPKE public key. Call sequence: Create -> + // CreateObliviousHttpRequest -> DecryptObliviousHttpResponse. + // Eg., + // auto ohttp_client_object = ObliviousHttpClient::Create( , ); + // auto encrypted_request1 = + // ohttp_client_object.CreateObliviousHttpRequest("binary http string 1"); + // auto encrypted_request2 = + // ohttp_client_object.CreateObliviousHttpRequest("binary http string 2"); + absl::StatusOr CreateObliviousHttpRequest( + std::string plaintext_data) const; + + // After `CreateObliviousHttpRequest` operation, callers on client-side will + // extract `oblivious_http_request_context` from the returned object + // `ObliviousHttpRequest` and pass in to this method in order to decrypt the + // response that's received from Gateway for the given request at hand. + absl::StatusOr DecryptObliviousHttpResponse( + std::string encrypted_data, + ObliviousHttpRequest::Context& oblivious_http_request_context) const; + + private: + explicit ObliviousHttpClient( + std::string client_public_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config); + std::string hpke_public_key_; + // Holds server's keyID and HPKE related IDs that's published under HPKE + // public Key configuration. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#name-key-configuration + ObliviousHttpHeaderKeyConfig ohttp_key_config_; +}; + +} // namespace quiche + +#endif // QUICHE_OBLIVIOUS_HTTP_OBLIVIOUS_HTTP_CLIENT_H_ diff --git a/quiche/oblivious_http/oblivious_http_client_test.cc b/quiche/oblivious_http/oblivious_http_client_test.cc new file mode 100644 index 000000000000..a2768a519c23 --- /dev/null +++ b/quiche/oblivious_http/oblivious_http_client_test.cc @@ -0,0 +1,252 @@ +#include "quiche/oblivious_http/oblivious_http_client.h" + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/platform/api/quiche_thread.h" + +namespace quiche { + +std::string GetHpkePrivateKey() { + // Dev/Test private key generated using Keystore. + absl::string_view hpke_key_hex = + "b77431ecfa8f4cfc30d6e467aafa06944dffe28cb9dd1409e33a3045f5adc8a1"; + return absl::HexStringToBytes(hpke_key_hex); +} + +std::string GetHpkePublicKey() { + // Dev/Test public key generated using Keystore. + absl::string_view public_key = + "6d21cfe09fbea5122f9ebc2eb2a69fcc4f06408cd54aac934f012e76fcdcef62"; + return absl::HexStringToBytes(public_key); +} + +const ObliviousHttpHeaderKeyConfig GetOhttpKeyConfig(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) { + auto ohttp_key_config = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + EXPECT_TRUE(ohttp_key_config.ok()); + return ohttp_key_config.value(); +} + +bssl::UniquePtr ConstructHpkeKey( + absl::string_view hpke_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config) { + bssl::UniquePtr bssl_hpke_key(EVP_HPKE_KEY_new()); + EXPECT_NE(bssl_hpke_key, nullptr); + EXPECT_TRUE(EVP_HPKE_KEY_init( + bssl_hpke_key.get(), ohttp_key_config.GetHpkeKem(), + reinterpret_cast(hpke_key.data()), hpke_key.size())); + return bssl_hpke_key; +} + +TEST(ObliviousHttpClient, TestEncapsulate) { + auto client = ObliviousHttpClient::Create( + GetHpkePublicKey(), + GetOhttpKeyConfig(8, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)); + ASSERT_TRUE(client.ok()); + auto encrypted_req = client->CreateObliviousHttpRequest("test string 1"); + ASSERT_TRUE(encrypted_req.ok()); + auto serialized_encrypted_req = encrypted_req->EncapsulateAndSerialize(); + ASSERT_FALSE(serialized_encrypted_req.empty()); +} + +TEST(ObliviousHttpClient, TestEncryptingMultipleRequestsWithSingleInstance) { + auto client = ObliviousHttpClient::Create( + GetHpkePublicKey(), + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)); + ASSERT_TRUE(client.ok()); + auto ohttp_req_1 = client->CreateObliviousHttpRequest("test string 1"); + ASSERT_TRUE(ohttp_req_1.ok()); + auto serialized_ohttp_req_1 = ohttp_req_1->EncapsulateAndSerialize(); + ASSERT_FALSE(serialized_ohttp_req_1.empty()); + auto ohttp_req_2 = client->CreateObliviousHttpRequest("test string 2"); + ASSERT_TRUE(ohttp_req_2.ok()); + auto serialized_ohttp_req_2 = ohttp_req_2->EncapsulateAndSerialize(); + ASSERT_FALSE(serialized_ohttp_req_2.empty()); + EXPECT_NE(serialized_ohttp_req_1, serialized_ohttp_req_2); +} + +TEST(ObliviousHttpClient, TestInvalidHPKEKey) { + // Invalid public key. + EXPECT_EQ(ObliviousHttpClient::Create( + "Invalid HPKE key", + GetOhttpKeyConfig(50, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + // Empty public key. + EXPECT_EQ(ObliviousHttpClient::Create( + /*hpke_public_key*/ "", + GetOhttpKeyConfig(50, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpClient, + TestTwoSamePlaintextsWillGenerateDifferentEncryptedPayloads) { + // Due to the nature of the encapsulated_key generated in HPKE being unique + // for every request, expect different encrypted payloads when encrypting same + // plaintexts. + auto client = ObliviousHttpClient::Create( + GetHpkePublicKey(), + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)); + ASSERT_TRUE(client.ok()); + auto encrypted_request_1 = + client->CreateObliviousHttpRequest("same plaintext"); + ASSERT_TRUE(encrypted_request_1.ok()); + auto serialized_encrypted_request_1 = + encrypted_request_1->EncapsulateAndSerialize(); + ASSERT_FALSE(serialized_encrypted_request_1.empty()); + auto encrypted_request_2 = + client->CreateObliviousHttpRequest("same plaintext"); + ASSERT_TRUE(encrypted_request_2.ok()); + auto serialized_encrypted_request_2 = + encrypted_request_2->EncapsulateAndSerialize(); + ASSERT_FALSE(serialized_encrypted_request_2.empty()); + EXPECT_NE(serialized_encrypted_request_1, serialized_encrypted_request_2); +} + +TEST(ObliviousHttpClient, TestObliviousResponseHandling) { + auto ohttp_key_config = + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto encapsulate_req_on_client = + ObliviousHttpRequest::CreateClientObliviousRequest( + "test", GetHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(encapsulate_req_on_client.ok()); + auto decapsulate_req_on_gateway = + ObliviousHttpRequest::CreateServerObliviousRequest( + encapsulate_req_on_client->EncapsulateAndSerialize(), + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config); + ASSERT_TRUE(decapsulate_req_on_gateway.ok()); + auto gateway_request_context = + std::move(decapsulate_req_on_gateway.value()).ReleaseContext(); + auto encapsulate_resp_on_gateway = + ObliviousHttpResponse::CreateServerObliviousResponse( + "test response", gateway_request_context); + ASSERT_TRUE(encapsulate_resp_on_gateway.ok()); + + auto client = + ObliviousHttpClient::Create(GetHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(client.ok()); + auto client_request_context = + std::move(encapsulate_req_on_client.value()).ReleaseContext(); + auto decapsulate_resp_on_client = client->DecryptObliviousHttpResponse( + encapsulate_resp_on_gateway->EncapsulateAndSerialize(), + client_request_context); + ASSERT_TRUE(decapsulate_resp_on_client.ok()); + EXPECT_EQ(decapsulate_resp_on_client->GetPlaintextData(), "test response"); +} + +TEST(ObliviousHttpClient, + DecryptResponseReceivedByTheClientUsingServersObliviousContext) { + auto ohttp_key_config = + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto encapsulate_req_on_client = + ObliviousHttpRequest::CreateClientObliviousRequest( + "test", GetHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(encapsulate_req_on_client.ok()); + auto decapsulate_req_on_gateway = + ObliviousHttpRequest::CreateServerObliviousRequest( + encapsulate_req_on_client->EncapsulateAndSerialize(), + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config)), + ohttp_key_config); + ASSERT_TRUE(decapsulate_req_on_gateway.ok()); + auto gateway_request_context = + std::move(decapsulate_req_on_gateway.value()).ReleaseContext(); + auto encapsulate_resp_on_gateway = + ObliviousHttpResponse::CreateServerObliviousResponse( + "test response", gateway_request_context); + ASSERT_TRUE(encapsulate_resp_on_gateway.ok()); + + auto client = + ObliviousHttpClient::Create(GetHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(client.ok()); + auto decapsulate_resp_on_client = client->DecryptObliviousHttpResponse( + encapsulate_resp_on_gateway->EncapsulateAndSerialize(), + gateway_request_context); + ASSERT_TRUE(decapsulate_resp_on_client.ok()); + EXPECT_EQ(decapsulate_resp_on_client->GetPlaintextData(), "test response"); +} + +TEST(ObliviousHttpClient, TestWithMultipleThreads) { + class TestQuicheThread : public QuicheThread { + public: + TestQuicheThread(const ObliviousHttpClient& client, + std::string request_payload, + ObliviousHttpHeaderKeyConfig ohttp_key_config) + : QuicheThread("client_thread"), + client_(client), + request_payload_(request_payload), + ohttp_key_config_(ohttp_key_config) {} + + protected: + void Run() override { + auto encrypted_request = + client_.CreateObliviousHttpRequest(request_payload_); + ASSERT_TRUE(encrypted_request.ok()); + ASSERT_FALSE(encrypted_request->EncapsulateAndSerialize().empty()); + // Setup recipient and get encrypted response payload. + auto decapsulate_req_on_gateway = + ObliviousHttpRequest::CreateServerObliviousRequest( + encrypted_request->EncapsulateAndSerialize(), + *(ConstructHpkeKey(GetHpkePrivateKey(), ohttp_key_config_)), + ohttp_key_config_); + ASSERT_TRUE(decapsulate_req_on_gateway.ok()); + auto gateway_request_context = + std::move(decapsulate_req_on_gateway.value()).ReleaseContext(); + auto encapsulate_resp_on_gateway = + ObliviousHttpResponse::CreateServerObliviousResponse( + "test response", gateway_request_context); + ASSERT_TRUE(encapsulate_resp_on_gateway.ok()); + ASSERT_FALSE( + encapsulate_resp_on_gateway->EncapsulateAndSerialize().empty()); + auto client_request_context = + std::move(encrypted_request.value()).ReleaseContext(); + auto decrypted_response = client_.DecryptObliviousHttpResponse( + encapsulate_resp_on_gateway->EncapsulateAndSerialize(), + client_request_context); + ASSERT_TRUE(decrypted_response.ok()); + ASSERT_FALSE(decrypted_response->GetPlaintextData().empty()); + } + + private: + const ObliviousHttpClient& client_; + std::string request_payload_; + ObliviousHttpHeaderKeyConfig ohttp_key_config_; + }; + + auto ohttp_key_config = + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto client = + ObliviousHttpClient::Create(GetHpkePublicKey(), ohttp_key_config); + + TestQuicheThread t1(*client, "test request 1", ohttp_key_config); + TestQuicheThread t2(*client, "test request 2", ohttp_key_config); + t1.Start(); + t2.Start(); + t1.Join(); + t2.Join(); +} + +} // namespace quiche diff --git a/quiche/oblivious_http/oblivious_http_gateway.cc b/quiche/oblivious_http/oblivious_http_gateway.cc new file mode 100644 index 000000000000..b2d2e88a7e23 --- /dev/null +++ b/quiche/oblivious_http/oblivious_http_gateway.cc @@ -0,0 +1,68 @@ +#include "quiche/oblivious_http/oblivious_http_gateway.h" + +#include + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/common/quiche_crypto_logging.h" +#include "quiche/common/quiche_random.h" + +namespace quiche { + +// Constructor. +ObliviousHttpGateway::ObliviousHttpGateway( + bssl::UniquePtr recipient_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + QuicheRandom* quiche_random) + : server_hpke_key_(std::move(recipient_key)), + ohttp_key_config_(ohttp_key_config), + quiche_random_(quiche_random) {} + +// Initialize ObliviousHttpGateway(Recipient/Server) context. +absl::StatusOr ObliviousHttpGateway::Create( + absl::string_view hpke_private_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + QuicheRandom* quiche_random) { + if (hpke_private_key.empty()) { + return absl::InvalidArgumentError("Invalid/Empty HPKE private key."); + } + // Initialize HPKE key and context. + bssl::UniquePtr recipient_key(EVP_HPKE_KEY_new()); + if (recipient_key == nullptr) { + return SslErrorAsStatus( + "Failed to initialize ObliviousHttpGateway/Server's Key."); + } + if (!EVP_HPKE_KEY_init( + recipient_key.get(), ohttp_key_config.GetHpkeKem(), + reinterpret_cast(hpke_private_key.data()), + hpke_private_key.size())) { + return SslErrorAsStatus("Failed to import HPKE private key."); + } + if (quiche_random == nullptr) quiche_random = QuicheRandom::GetInstance(); + return ObliviousHttpGateway(std::move(recipient_key), ohttp_key_config, + quiche_random); +} + +absl::StatusOr +ObliviousHttpGateway::DecryptObliviousHttpRequest( + absl::string_view encrypted_data) const { + return ObliviousHttpRequest::CreateServerObliviousRequest( + encrypted_data, *(server_hpke_key_), ohttp_key_config_); +} + +absl::StatusOr +ObliviousHttpGateway::CreateObliviousHttpResponse( + std::string plaintext_data, + ObliviousHttpRequest::Context& oblivious_http_request_context) const { + return ObliviousHttpResponse::CreateServerObliviousResponse( + std::move(plaintext_data), oblivious_http_request_context, + quiche_random_); +} + +} // namespace quiche diff --git a/quiche/oblivious_http/oblivious_http_gateway.h b/quiche/oblivious_http/oblivious_http_gateway.h new file mode 100644 index 000000000000..ae6c746acfb5 --- /dev/null +++ b/quiche/oblivious_http/oblivious_http_gateway.h @@ -0,0 +1,83 @@ +#ifndef QUICHE_OBLIVIOUS_HTTP_OBLIVIOUS_HTTP_GATEWAY_H_ +#define QUICHE_OBLIVIOUS_HTTP_OBLIVIOUS_HTTP_GATEWAY_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "openssl/hpke.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_random.h" +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" +#include "quiche/oblivious_http/buffers/oblivious_http_response.h" +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace quiche { +// 1. Handles server side decryption of the payload received in HTTP POST body +// from Relay. +// 2. Handles server side encryption of response (that's in the form of Binary +// HTTP) that will be sent back to Relay in HTTP POST body. +// 3. Handles BSSL initialization and HPKE context bookkeeping. + +// This class is immutable (except moves) and thus trivially thread-safe, +// assuming the `QuicheRandom* quiche_random` passed in with `Create` is +// thread-safe. Note that default `QuicheRandom::GetInstance()` is thread-safe. +class QUICHE_EXPORT ObliviousHttpGateway { + public: + // @params: If callers would like to pass in their own `QuicheRandom` + // instance, they can make use of the param `quiche_random`. Otherwise, the + // default `QuicheRandom::GetInstance()` will be used. + static absl::StatusOr Create( + absl::string_view hpke_private_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + QuicheRandom* quiche_random = nullptr); + + // only Movable (due to `UniquePtr server_hpke_key_`). + ObliviousHttpGateway(ObliviousHttpGateway&& other) = default; + ObliviousHttpGateway& operator=(ObliviousHttpGateway&& other) = default; + + ~ObliviousHttpGateway() = default; + + // After successful `Create`, callers will use the returned object to + // repeatedly call into this method in order to create Oblivious HTTP request + // with the initialized HPKE private key. Call sequence: Create -> + // DecryptObliviousHttpRequest -> CreateObliviousHttpResponse. + // Eg., + // auto ohttp_server_object = ObliviousHttpGateway::Create( , ); + // auto decrypted_request1 = + // ohttp_server_object.DecryptObliviousHttpRequest(); + // auto decrypted_request2 = + // ohttp_server_object.DecryptObliviousHttpRequest(); + absl::StatusOr DecryptObliviousHttpRequest( + absl::string_view encrypted_data) const; + + // After `DecryptObliviousHttpRequest` operation, callers on server-side will + // extract `oblivious_http_request_context` from the returned object + // `ObliviousHttpRequest` and pass in to this method in order to handle the + // response flow back to the client. + absl::StatusOr CreateObliviousHttpResponse( + std::string plaintext_data, + ObliviousHttpRequest::Context& oblivious_http_request_context) const; + + private: + explicit ObliviousHttpGateway( + bssl::UniquePtr recipient_key, + const ObliviousHttpHeaderKeyConfig& ohttp_key_config, + QuicheRandom* quiche_random); + bssl::UniquePtr server_hpke_key_; + // Holds server's keyID and HPKE related IDs that's published under HPKE + // public Key configuration. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#name-key-configuration + ObliviousHttpHeaderKeyConfig ohttp_key_config_; + QuicheRandom* quiche_random_; +}; + +} // namespace quiche + +#endif // QUICHE_OBLIVIOUS_HTTP_OBLIVIOUS_HTTP_GATEWAY_H_ diff --git a/quiche/oblivious_http/oblivious_http_gateway_test.cc b/quiche/oblivious_http/oblivious_http_gateway_test.cc new file mode 100644 index 000000000000..af03b210c8e7 --- /dev/null +++ b/quiche/oblivious_http/oblivious_http_gateway_test.cc @@ -0,0 +1,227 @@ +#include "quiche/oblivious_http/oblivious_http_gateway.h" + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/platform/api/quiche_thread.h" +#include "quiche/common/quiche_random.h" +#include "quiche/oblivious_http/buffers/oblivious_http_request.h" + +namespace quiche { +namespace { + +std::string GetHpkePrivateKey() { + // Dev/Test private key generated using Keystore. + absl::string_view hpke_key_hex = + "b77431ecfa8f4cfc30d6e467aafa06944dffe28cb9dd1409e33a3045f5adc8a1"; + return absl::HexStringToBytes(hpke_key_hex); +} + +std::string GetHpkePublicKey() { + // Dev/Test public key generated using Keystore. + absl::string_view public_key = + "6d21cfe09fbea5122f9ebc2eb2a69fcc4f06408cd54aac934f012e76fcdcef62"; + return absl::HexStringToBytes(public_key); +} + +const ObliviousHttpHeaderKeyConfig GetOhttpKeyConfig(uint8_t key_id, + uint16_t kem_id, + uint16_t kdf_id, + uint16_t aead_id) { + auto ohttp_key_config = + ObliviousHttpHeaderKeyConfig::Create(key_id, kem_id, kdf_id, aead_id); + EXPECT_TRUE(ohttp_key_config.ok()); + return std::move(ohttp_key_config.value()); +} + +TEST(ObliviousHttpGateway, TestProvisioningKeyAndDecapsulate) { + // X25519 Secret key (priv key). + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-2 + constexpr absl::string_view kX25519SecretKey = + "3c168975674b2fa8e465970b79c8dcf09f1c741626480bd4c6162fc5b6a98e1a"; + + auto instance = ObliviousHttpGateway::Create( + /*hpke_private_key*/ absl::HexStringToBytes(kX25519SecretKey), + /*ohttp_key_config*/ GetOhttpKeyConfig( + /*key_id=*/1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, EVP_HPKE_HKDF_SHA256, + EVP_HPKE_AES_128_GCM)); + + // Encapsulated request. + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-14 + constexpr absl::string_view kEncapsulatedRequest = + "010020000100014b28f881333e7c164ffc499ad9796f877f4e1051ee6d31bad19dec96c2" + "08b4726374e469135906992e1268c594d2a10c695d858c40a026e7965e7d86b83dd440b2" + "c0185204b4d63525"; + + auto decrypted_req = instance->DecryptObliviousHttpRequest( + absl::HexStringToBytes(kEncapsulatedRequest)); + ASSERT_TRUE(decrypted_req.ok()); + ASSERT_FALSE(decrypted_req->GetPlaintextData().empty()); +} + +TEST(ObliviousHttpGateway, TestDecryptingMultipleRequestsWithSingleInstance) { + auto instance = ObliviousHttpGateway::Create( + GetHpkePrivateKey(), + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)); + // plaintext: "test request 1" + absl::string_view encrypted_req_1 = + "010020000100025f20b60306b61ad9ecad389acd752ca75c4e2969469809fe3d84aae137" + "f73e4ccfe9ba71f12831fdce6c8202fbd38a84c5d8a73ac4c8ea6c10592594845f"; + auto decapsulated_req_1 = instance->DecryptObliviousHttpRequest( + absl::HexStringToBytes(encrypted_req_1)); + ASSERT_TRUE(decapsulated_req_1.ok()); + ASSERT_FALSE(decapsulated_req_1->GetPlaintextData().empty()); + + // plaintext: "test request 2" + absl::string_view encrypted_req_2 = + "01002000010002285ebc2fcad72cc91b378050cac29a62feea9cd97829335ee9fc87e672" + "4fa13ff2efdff620423d54225d3099088e7b32a5165f805a5d922918865a0a447a"; + auto decapsulated_req_2 = instance->DecryptObliviousHttpRequest( + absl::HexStringToBytes(encrypted_req_2)); + ASSERT_TRUE(decapsulated_req_2.ok()); + ASSERT_FALSE(decapsulated_req_2->GetPlaintextData().empty()); +} + +TEST(ObliviousHttpGateway, TestInvalidHPKEKey) { + // Invalid private key. + EXPECT_EQ(ObliviousHttpGateway::Create( + "Invalid HPKE key", + GetOhttpKeyConfig(70, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)) + .status() + .code(), + absl::StatusCode::kInternal); + // Empty private key. + EXPECT_EQ(ObliviousHttpGateway::Create( + /*hpke_private_key*/ "", + GetOhttpKeyConfig(70, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM)) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ObliviousHttpGateway, TestObliviousResponseHandling) { + auto ohttp_key_config = + GetOhttpKeyConfig(3, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM); + auto instance = + ObliviousHttpGateway::Create(GetHpkePrivateKey(), ohttp_key_config); + ASSERT_TRUE(instance.ok()); + auto encapsualte_request_on_client = + ObliviousHttpRequest::CreateClientObliviousRequest( + "test", GetHpkePublicKey(), ohttp_key_config); + ASSERT_TRUE(encapsualte_request_on_client.ok()); + // Setup Recipient to allow setting up the HPKE context, and subsequently use + // it to encrypt the response. + auto decapsulated_req_on_server = instance->DecryptObliviousHttpRequest( + encapsualte_request_on_client->EncapsulateAndSerialize()); + ASSERT_TRUE(decapsulated_req_on_server.ok()); + auto server_request_context = + std::move(decapsulated_req_on_server.value()).ReleaseContext(); + auto encapsulate_resp_on_gateway = instance->CreateObliviousHttpResponse( + "some response", server_request_context); + ASSERT_TRUE(encapsulate_resp_on_gateway.ok()); + ASSERT_FALSE(encapsulate_resp_on_gateway->EncapsulateAndSerialize().empty()); +} + +TEST(ObliviousHttpGateway, + TestHandlingMultipleResponsesForMultipleRequestsWithSingleInstance) { + auto instance = ObliviousHttpGateway::Create( + GetHpkePrivateKey(), + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM), + QuicheRandom::GetInstance()); + // Setup contexts first. + auto decrypted_request_1 = instance->DecryptObliviousHttpRequest( + absl::HexStringToBytes("010020000100025f20b60306b61ad9ecad389acd752ca75c4" + "e2969469809fe3d84aae137" + "f73e4ccfe9ba71f12831fdce6c8202fbd38a84c5d8a73ac4c" + "8ea6c10592594845f")); + ASSERT_TRUE(decrypted_request_1.ok()); + auto decrypted_request_2 = instance->DecryptObliviousHttpRequest( + absl::HexStringToBytes("01002000010002285ebc2fcad72cc91b378050cac29a62fee" + "a9cd97829335ee9fc87e672" + "4fa13ff2efdff620423d54225d3099088e7b32a5165f805a5" + "d922918865a0a447a")); + ASSERT_TRUE(decrypted_request_2.ok()); + + // Extract contexts and handle the response for each corresponding request. + auto oblivious_request_context_1 = + std::move(decrypted_request_1.value()).ReleaseContext(); + auto encrypted_response_1 = instance->CreateObliviousHttpResponse( + "test response 1", oblivious_request_context_1); + ASSERT_TRUE(encrypted_response_1.ok()); + ASSERT_FALSE(encrypted_response_1->EncapsulateAndSerialize().empty()); + auto oblivious_request_context_2 = + std::move(decrypted_request_2.value()).ReleaseContext(); + auto encrypted_response_2 = instance->CreateObliviousHttpResponse( + "test response 2", oblivious_request_context_2); + ASSERT_TRUE(encrypted_response_2.ok()); + ASSERT_FALSE(encrypted_response_2->EncapsulateAndSerialize().empty()); +} + +TEST(ObliviousHttpGateway, TestWithMultipleThreads) { + class TestQuicheThread : public QuicheThread { + public: + TestQuicheThread(const ObliviousHttpGateway& gateway_receiver, + std::string request_payload, std::string response_payload) + : QuicheThread("gateway_thread"), + gateway_receiver_(gateway_receiver), + request_payload_(request_payload), + response_payload_(response_payload) {} + + protected: + void Run() override { + auto decrypted_request = + gateway_receiver_.DecryptObliviousHttpRequest(request_payload_); + ASSERT_TRUE(decrypted_request.ok()); + ASSERT_FALSE(decrypted_request->GetPlaintextData().empty()); + auto gateway_request_context = + std::move(decrypted_request.value()).ReleaseContext(); + auto encrypted_response = gateway_receiver_.CreateObliviousHttpResponse( + response_payload_, gateway_request_context); + ASSERT_TRUE(encrypted_response.ok()); + ASSERT_FALSE(encrypted_response->EncapsulateAndSerialize().empty()); + } + + private: + const ObliviousHttpGateway& gateway_receiver_; + std::string request_payload_, response_payload_; + }; + + auto gateway_receiver = ObliviousHttpGateway::Create( + GetHpkePrivateKey(), + GetOhttpKeyConfig(1, EVP_HPKE_DHKEM_X25519_HKDF_SHA256, + EVP_HPKE_HKDF_SHA256, EVP_HPKE_AES_256_GCM), + QuicheRandom::GetInstance()); + + TestQuicheThread t1( + *gateway_receiver, + absl::HexStringToBytes("010020000100025f20b60306b61ad9ecad389acd752ca75c4" + "e2969469809fe3d84aae137" + "f73e4ccfe9ba71f12831fdce6c8202fbd38a84c5d8a73ac4c" + "8ea6c10592594845f"), + "test response 1"); + TestQuicheThread t2( + *gateway_receiver, + absl::HexStringToBytes("01002000010002285ebc2fcad72cc91b378050cac29a62fee" + "a9cd97829335ee9fc87e672" + "4fa13ff2efdff620423d54225d3099088e7b32a5165f805a5" + "d922918865a0a447a"), + "test response 2"); + t1.Start(); + t2.Start(); + t1.Join(); + t2.Join(); +} +} // namespace +} // namespace quiche diff --git a/quiche/quic/bindings/quic_libevent.cc b/quiche/quic/bindings/quic_libevent.cc new file mode 100644 index 000000000000..a053d3512f6b --- /dev/null +++ b/quiche/quic/bindings/quic_libevent.cc @@ -0,0 +1,239 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/bindings/quic_libevent.h" + +#include + +#include "absl/time/time.h" +#include "event2/event.h" +#include "event2/event_struct.h" +#include "event2/thread.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_time.h" + +namespace quic { + +using LibeventEventMask = short; // NOLINT(runtime/int) + +QuicSocketEventMask LibeventEventMaskToQuicEvents(int events) { + return ((events & EV_READ) ? kSocketEventReadable : 0) | + ((events & EV_WRITE) ? kSocketEventWritable : 0); +} + +LibeventEventMask QuicEventsToLibeventEventMask(QuicSocketEventMask events) { + return ((events & kSocketEventReadable) ? EV_READ : 0) | + ((events & kSocketEventWritable) ? EV_WRITE : 0); +} + +class LibeventAlarm : public QuicAlarm { + public: + LibeventAlarm(LibeventQuicEventLoop* loop, + QuicArenaScopedPtr delegate) + : QuicAlarm(std::move(delegate)), clock_(loop->clock()) { + event_.reset(evtimer_new( + loop->base(), + [](evutil_socket_t, LibeventEventMask, void* arg) { + LibeventAlarm* self = reinterpret_cast(arg); + self->Fire(); + }, + this)); + } + + protected: + void SetImpl() override { + absl::Duration timeout = + absl::Microseconds((deadline() - clock_->Now()).ToMicroseconds()); + timeval unix_time = absl::ToTimeval(timeout); + event_add(event_.get(), &unix_time); + } + + void CancelImpl() override { event_del(event_.get()); } + + private: + // While we inline `struct event` elsewhere, it is actually quite large, so + // doing that for the libevent-based QuicAlarm would cause it to not fit into + // the QuicConnectionArena. + struct EventDeleter { + void operator()(event* ev) { event_free(ev); } + }; + std::unique_ptr event_; + QuicClock* clock_; +}; + +LibeventQuicEventLoop::LibeventQuicEventLoop(event_base* base, QuicClock* clock) + : base_(base), + edge_triggered_(event_base_get_features(base) & EV_FEATURE_ET), + clock_(clock) { + QUICHE_CHECK_LE(sizeof(event), event_get_struct_event_size()) + << "libevent ABI mismatch: sizeof(event) is bigger than the one QUICHE " + "has been compiled with"; +} + +bool LibeventQuicEventLoop::RegisterSocket(QuicUdpSocketFd fd, + QuicSocketEventMask events, + QuicSocketEventListener* listener) { + auto [it, success] = + registration_map_.try_emplace(fd, this, fd, events, listener); + return success; +} + +bool LibeventQuicEventLoop::UnregisterSocket(QuicUdpSocketFd fd) { + return registration_map_.erase(fd); +} + +bool LibeventQuicEventLoop::RearmSocket(QuicUdpSocketFd fd, + QuicSocketEventMask events) { + if (edge_triggered_) { + QUICHE_BUG(LibeventQuicEventLoop_RearmSocket_called_on_ET) + << "RearmSocket() called on an edge-triggered event loop"; + return false; + } + auto it = registration_map_.find(fd); + if (it == registration_map_.end()) { + return false; + } + it->second.Rearm(events); + return true; +} + +bool LibeventQuicEventLoop::ArtificiallyNotifyEvent( + QuicUdpSocketFd fd, QuicSocketEventMask events) { + auto it = registration_map_.find(fd); + if (it == registration_map_.end()) { + return false; + } + it->second.ArtificiallyNotify(events); + return true; +} + +void LibeventQuicEventLoop::RunEventLoopOnce(QuicTime::Delta default_timeout) { + timeval timeout = + absl::ToTimeval(absl::Microseconds(default_timeout.ToMicroseconds())); + event_base_loopexit(base_, &timeout); + event_base_loop(base_, EVLOOP_ONCE); +} + +void LibeventQuicEventLoop::WakeUp() { + timeval timeout = absl::ToTimeval(absl::ZeroDuration()); + event_base_loopexit(base_, &timeout); +} + +LibeventQuicEventLoop::Registration::Registration( + LibeventQuicEventLoop* loop, QuicUdpSocketFd fd, QuicSocketEventMask events, + QuicSocketEventListener* listener) + : loop_(loop), listener_(listener) { + event_callback_fn callback = [](evutil_socket_t fd, LibeventEventMask events, + void* arg) { + auto* self = reinterpret_cast(arg); + self->listener_->OnSocketEvent(self->loop_, fd, + LibeventEventMaskToQuicEvents(events)); + }; + + if (loop_->SupportsEdgeTriggered()) { + LibeventEventMask mask = + QuicEventsToLibeventEventMask(events) | EV_PERSIST | EV_ET; + event_assign(&both_events_, loop_->base(), fd, mask, callback, this); + event_add(&both_events_, nullptr); + } else { + event_assign(&read_event_, loop_->base(), fd, EV_READ, callback, this); + event_assign(&write_event_, loop_->base(), fd, EV_WRITE, callback, this); + Rearm(events); + } +} + +LibeventQuicEventLoop::Registration::~Registration() { + if (loop_->SupportsEdgeTriggered()) { + event_del(&both_events_); + } else { + event_del(&read_event_); + event_del(&write_event_); + } +} + +void LibeventQuicEventLoop::Registration::ArtificiallyNotify( + QuicSocketEventMask events) { + if (loop_->SupportsEdgeTriggered()) { + event_active(&both_events_, QuicEventsToLibeventEventMask(events), 0); + return; + } + + if (events & kSocketEventReadable) { + event_active(&read_event_, EV_READ, 0); + } + if (events & kSocketEventWritable) { + event_active(&write_event_, EV_WRITE, 0); + } +} + +void LibeventQuicEventLoop::Registration::Rearm(QuicSocketEventMask events) { + QUICHE_DCHECK(!loop_->SupportsEdgeTriggered()); + if (events & kSocketEventReadable) { + event_add(&read_event_, nullptr); + } + if (events & kSocketEventWritable) { + event_add(&write_event_, nullptr); + } +} + +QuicAlarm* LibeventQuicEventLoop::AlarmFactory::CreateAlarm( + QuicAlarm::Delegate* delegate) { + return new LibeventAlarm(loop_, + QuicArenaScopedPtr(delegate)); +} + +QuicArenaScopedPtr LibeventQuicEventLoop::AlarmFactory::CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) { + if (arena != nullptr) { + return arena->New(loop_, std::move(delegate)); + } + return QuicArenaScopedPtr( + new LibeventAlarm(loop_, std::move(delegate))); +} + +QuicLibeventEventLoopFactory::QuicLibeventEventLoopFactory( + bool force_level_triggered) + : force_level_triggered_(force_level_triggered) { + std::unique_ptr event_loop = Create(QuicDefaultClock::Get()); + name_ = absl::StrFormat( + "libevent(%s)", + event_base_get_method( + static_cast(event_loop.get()) + ->base())); +} + +struct LibeventConfigDeleter { + void operator()(event_config* config) { event_config_free(config); } +}; + +std::unique_ptr +LibeventQuicEventLoopWithOwnership::Create(QuicClock* clock, + bool force_level_triggered) { + // Required for event_base_loopbreak() to actually work. + static int threads_initialized = []() { +#ifdef _WIN32 + return evthread_use_windows_threads(); +#else + return evthread_use_pthreads(); +#endif + }(); + QUICHE_DCHECK_EQ(threads_initialized, 0); + + std::unique_ptr config( + event_config_new()); + if (force_level_triggered) { + // epoll and kqueue are the two only current libevent backends that support + // edge-triggered I/O. + event_config_avoid_method(config.get(), "epoll"); + event_config_avoid_method(config.get(), "kqueue"); + } + return std::make_unique( + event_base_new_with_config(config.get()), clock); +} + +} // namespace quic diff --git a/quiche/quic/bindings/quic_libevent.h b/quiche/quic/bindings/quic_libevent.h new file mode 100644 index 000000000000..82783e9a1cdd --- /dev/null +++ b/quiche/quic/bindings/quic_libevent.h @@ -0,0 +1,155 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_BINDINGS_QUIC_LIBEVENT_H_ +#define QUICHE_QUIC_BINDINGS_QUIC_LIBEVENT_H_ + +#include + +#include "absl/container/node_hash_map.h" +#include "event2/event.h" +#include "event2/event_struct.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_udp_socket.h" + +namespace quic { + +// Provides a libevent-based implementation of QuicEventLoop. Since libevent +// uses relative time for all timeouts, the provided clock does not need to use +// the UNIX time. +class QUICHE_EXPORT LibeventQuicEventLoop : public QuicEventLoop { + public: + explicit LibeventQuicEventLoop(event_base* base, QuicClock* clock); + + // QuicEventLoop implementation. + bool SupportsEdgeTriggered() const override { return edge_triggered_; } + std::unique_ptr CreateAlarmFactory() override { + return std::make_unique(this); + } + bool RegisterSocket(QuicUdpSocketFd fd, QuicSocketEventMask events, + QuicSocketEventListener* listener) override; + bool UnregisterSocket(QuicUdpSocketFd fd) override; + bool RearmSocket(QuicUdpSocketFd fd, QuicSocketEventMask events) override; + bool ArtificiallyNotifyEvent(QuicUdpSocketFd fd, + QuicSocketEventMask events) override; + void RunEventLoopOnce(QuicTime::Delta default_timeout) override; + const QuicClock* GetClock() override { return clock_; } + + // Can be called from another thread to wake up the event loop from a blocking + // RunEventLoopOnce() call. + void WakeUp(); + + event_base* base() { return base_; } + QuicClock* clock() const { return clock_; } + + private: + class AlarmFactory : public QuicAlarmFactory { + public: + AlarmFactory(LibeventQuicEventLoop* loop) : loop_(loop) {} + + // QuicAlarmFactory interface. + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override; + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override; + + private: + LibeventQuicEventLoop* loop_; + }; + + class Registration { + public: + Registration(LibeventQuicEventLoop* loop, QuicUdpSocketFd fd, + QuicSocketEventMask events, QuicSocketEventListener* listener); + ~Registration(); + + void ArtificiallyNotify(QuicSocketEventMask events); + void Rearm(QuicSocketEventMask events); + + private: + LibeventQuicEventLoop* loop_; + QuicSocketEventListener* listener_; + + // Used for edge-triggered backends. + event both_events_; + // Used for level-triggered backends, since we may have to re-arm read + // events and write events separately. + event read_event_; + event write_event_; + }; + + using RegistrationMap = absl::node_hash_map; + + event_base* base_; + const bool edge_triggered_; + QuicClock* clock_; + + RegistrationMap registration_map_; +}; + +// RAII-style wrapper around event_base. +class QUICHE_EXPORT LibeventLoop { + public: + LibeventLoop(struct event_base* base) : event_base_(base) {} + ~LibeventLoop() { event_base_free(event_base_); } + + struct event_base* event_base() { return event_base_; } + + private: + struct event_base* event_base_; +}; + +// A version of LibeventQuicEventLoop that owns the supplied `event_base`. Note +// that the inheritance order here matters, since otherwise the `event_base` in +// question will be deleted before the LibeventQuicEventLoop object referencing +// it. +class QUICHE_EXPORT LibeventQuicEventLoopWithOwnership + : public LibeventLoop, + public LibeventQuicEventLoop { + public: + static std::unique_ptr Create( + QuicClock* clock, bool force_level_triggered = false); + + // Takes ownership of |base|. + explicit LibeventQuicEventLoopWithOwnership(struct event_base* base, + QuicClock* clock) + : LibeventLoop(base), LibeventQuicEventLoop(base, clock) {} +}; + +class QUICHE_EXPORT QuicLibeventEventLoopFactory : public QuicEventLoopFactory { + public: + // Provides the preferred libevent backend. + static QuicLibeventEventLoopFactory* Get() { + static QuicLibeventEventLoopFactory* factory = + new QuicLibeventEventLoopFactory(/*force_level_triggered=*/false); + return factory; + } + + // Provides the libevent backend that does not support edge-triggered + // notifications. Those are useful for tests, since we can test + // level-triggered I/O even on platforms where edge-triggered is the default. + static QuicLibeventEventLoopFactory* GetLevelTriggeredBackendForTests() { + static QuicLibeventEventLoopFactory* factory = + new QuicLibeventEventLoopFactory(/*force_level_triggered=*/true); + return factory; + } + + std::unique_ptr Create(QuicClock* clock) override { + return LibeventQuicEventLoopWithOwnership::Create(clock, + force_level_triggered_); + } + std::string GetName() const override { return name_; } + + private: + explicit QuicLibeventEventLoopFactory(bool force_level_triggered); + + bool force_level_triggered_; + std::string name_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_BINDINGS_QUIC_LIBEVENT_H_ diff --git a/quiche/quic/bindings/quic_libevent_test.cc b/quiche/quic/bindings/quic_libevent_test.cc new file mode 100644 index 000000000000..e6f2427f8ac2 --- /dev/null +++ b/quiche/quic/bindings/quic_libevent_test.cc @@ -0,0 +1,68 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/bindings/quic_libevent.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_thread.h" + +namespace quic::test { +namespace { + +class FailureAlarmDelegate : public QuicAlarm::Delegate { + public: + QuicConnectionContext* GetConnectionContext() override { return nullptr; } + void OnAlarm() override { ADD_FAILURE() << "Test timed out"; } +}; + +class LoopBreakThread : public QuicThread { + public: + LoopBreakThread(LibeventQuicEventLoop* loop) + : QuicThread("LoopBreakThread"), loop_(loop) {} + + void Run() override { + // Make sure the other thread has actually made the blocking poll/epoll/etc + // call before calling WakeUp(). + absl::SleepFor(absl::Milliseconds(250)); + + loop_broken_.store(true); + loop_->WakeUp(); + } + + std::atomic& loop_broken() { return loop_broken_; } + + private: + LibeventQuicEventLoop* loop_; + std::atomic loop_broken_ = 0; +}; + +TEST(QuicLibeventTest, WakeUpFromAnotherThread) { + QuicClock* clock = QuicDefaultClock::Get(); + auto event_loop_owned = QuicLibeventEventLoopFactory::Get()->Create(clock); + LibeventQuicEventLoop* event_loop = + static_cast(event_loop_owned.get()); + std::unique_ptr alarm_factory = + event_loop->CreateAlarmFactory(); + std::unique_ptr timeout_alarm = + absl::WrapUnique(alarm_factory->CreateAlarm(new FailureAlarmDelegate())); + + const QuicTime kTimeoutAt = clock->Now() + QuicTime::Delta::FromSeconds(10); + timeout_alarm->Set(kTimeoutAt); + + LoopBreakThread thread(event_loop); + thread.Start(); + event_loop->RunEventLoopOnce(QuicTime::Delta::FromSeconds(5 * 60)); + EXPECT_TRUE(thread.loop_broken().load()); + thread.Join(); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_base.cc b/quiche/quic/core/batch_writer/quic_batch_writer_base.cc new file mode 100644 index 000000000000..4a94c713c749 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_base.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_batch_writer_base.h" + +#include + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_server_stats.h" + +namespace quic { + +QuicBatchWriterBase::QuicBatchWriterBase( + std::unique_ptr batch_buffer) + : write_blocked_(false), batch_buffer_(std::move(batch_buffer)) {} + +WriteResult QuicBatchWriterBase::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + const WriteResult result = + InternalWritePacket(buffer, buf_len, self_address, peer_address, options); + + if (IsWriteBlockedStatus(result.status)) { + write_blocked_ = true; + } + + return result; +} + +WriteResult QuicBatchWriterBase::InternalWritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + if (buf_len > kMaxOutgoingPacketSize) { + return WriteResult(WRITE_STATUS_MSG_TOO_BIG, EMSGSIZE); + } + + ReleaseTime release_time{0, QuicTime::Delta::Zero()}; + if (SupportsReleaseTime()) { + release_time = GetReleaseTime(options); + if (release_time.release_time_offset >= QuicTime::Delta::Zero()) { + QUIC_SERVER_HISTOGRAM_TIMES( + "batch_writer_positive_release_time_offset", + release_time.release_time_offset.ToMicroseconds(), 1, 100000, 50, + "Duration from ideal release time to actual " + "release time, in microseconds."); + } else { + QUIC_SERVER_HISTOGRAM_TIMES( + "batch_writer_negative_release_time_offset", + -release_time.release_time_offset.ToMicroseconds(), 1, 100000, 50, + "Duration from actual release time to ideal " + "release time, in microseconds."); + } + } + + const CanBatchResult can_batch_result = + CanBatch(buffer, buf_len, self_address, peer_address, options, + release_time.actual_release_time); + + bool buffered = false; + bool flush = can_batch_result.must_flush; + + if (can_batch_result.can_batch) { + QuicBatchWriterBuffer::PushResult push_result = + batch_buffer_->PushBufferedWrite(buffer, buf_len, self_address, + peer_address, options, + release_time.actual_release_time); + if (push_result.succeeded) { + buffered = true; + // If there's no space left after the packet is buffered, force a flush. + flush = flush || (batch_buffer_->GetNextWriteLocation() == nullptr); + } else { + // If there's no space without this packet, force a flush. + flush = true; + } + } + + if (!flush) { + WriteResult result(WRITE_STATUS_OK, 0); + result.send_time_offset = release_time.release_time_offset; + return result; + } + + size_t num_buffered_packets = buffered_writes().size(); + const FlushImplResult flush_result = CheckedFlush(); + WriteResult result = flush_result.write_result; + QUIC_DVLOG(1) << "Internally flushed " << flush_result.num_packets_sent + << " out of " << num_buffered_packets + << " packets. WriteResult=" << result; + + if (result.status != WRITE_STATUS_OK) { + if (IsWriteBlockedStatus(result.status)) { + return WriteResult( + buffered ? WRITE_STATUS_BLOCKED_DATA_BUFFERED : WRITE_STATUS_BLOCKED, + result.error_code); + } + + // Drop all packets, including the one being written. + size_t dropped_packets = + buffered ? buffered_writes().size() : buffered_writes().size() + 1; + + batch_buffer().Clear(); + result.dropped_packets = + dropped_packets > std::numeric_limits::max() + ? std::numeric_limits::max() + : static_cast(dropped_packets); + return result; + } + + if (!buffered) { + QuicBatchWriterBuffer::PushResult push_result = + batch_buffer_->PushBufferedWrite(buffer, buf_len, self_address, + peer_address, options, + release_time.actual_release_time); + buffered = push_result.succeeded; + + // Since buffered_writes has been emptied, this write must have been + // buffered successfully. + QUIC_BUG_IF(quic_bug_10826_1, !buffered) + << "Failed to push to an empty batch buffer." + << " self_addr:" << self_address.ToString() + << ", peer_addr:" << peer_address.ToString() << ", buf_len:" << buf_len; + } + + result.send_time_offset = release_time.release_time_offset; + return result; +} + +QuicBatchWriterBase::FlushImplResult QuicBatchWriterBase::CheckedFlush() { + if (buffered_writes().empty()) { + return FlushImplResult{WriteResult(WRITE_STATUS_OK, 0), + /*num_packets_sent=*/0, /*bytes_written=*/0}; + } + + const FlushImplResult flush_result = FlushImpl(); + + // Either flush_result.write_result.status is not WRITE_STATUS_OK, or it is + // WRITE_STATUS_OK and batch_buffer is empty. + QUICHE_DCHECK(flush_result.write_result.status != WRITE_STATUS_OK || + buffered_writes().empty()); + + // Flush should never return WRITE_STATUS_BLOCKED_DATA_BUFFERED. + QUICHE_DCHECK(flush_result.write_result.status != + WRITE_STATUS_BLOCKED_DATA_BUFFERED); + + return flush_result; +} + +WriteResult QuicBatchWriterBase::Flush() { + size_t num_buffered_packets = buffered_writes().size(); + FlushImplResult flush_result = CheckedFlush(); + QUIC_DVLOG(1) << "Externally flushed " << flush_result.num_packets_sent + << " out of " << num_buffered_packets + << " packets. WriteResult=" << flush_result.write_result; + + if (IsWriteError(flush_result.write_result.status)) { + if (buffered_writes().size() > std::numeric_limits::max()) { + flush_result.write_result.dropped_packets = + std::numeric_limits::max(); + } else { + flush_result.write_result.dropped_packets = + static_cast(buffered_writes().size()); + } + // Treat all errors as non-retryable fatal errors. Drop all buffered packets + // to avoid sending them and getting the same error again. + batch_buffer().Clear(); + } + + if (flush_result.write_result.status == WRITE_STATUS_BLOCKED) { + write_blocked_ = true; + } + return flush_result.write_result; +} + +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_base.h b/quiche/quic/core/batch_writer/quic_batch_writer_base.h new file mode 100644 index 000000000000..a5a1e3877a52 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_base.h @@ -0,0 +1,156 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_BASE_H_ +#define QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_BASE_H_ + +#include + +#include "quiche/quic/core/batch_writer/quic_batch_writer_buffer.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// QuicBatchWriterBase implements logic common to all derived batch writers, +// including maintaining write blockage state and a skeleton implemention of +// WritePacket(). +// A derived batch writer must override the FlushImpl() function to send all +// buffered writes in a batch. It must also override the CanBatch() function +// to control whether/when a WritePacket() call should flush. +class QUIC_EXPORT_PRIVATE QuicBatchWriterBase : public QuicPacketWriter { + public: + explicit QuicBatchWriterBase( + std::unique_ptr batch_buffer); + + // ATTENTION: If this write triggered a flush, and the flush failed, all + // buffered packets will be dropped to allow the next write to work. The + // number of dropped packets can be found in WriteResult.dropped_packets. + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + bool IsWriteBlocked() const final { return write_blocked_; } + + void SetWritable() final { write_blocked_ = false; } + + absl::optional MessageTooBigErrorCode() const override { + return EMSGSIZE; + } + + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const final { + return kMaxOutgoingPacketSize; + } + + bool SupportsReleaseTime() const override { return false; } + + bool IsBatchMode() const final { return true; } + + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) final { + // No need to explicitly delete QuicBatchWriterBuffer. + return {batch_buffer_->GetNextWriteLocation(), nullptr}; + } + + WriteResult Flush() override; + + protected: + const QuicBatchWriterBuffer& batch_buffer() const { return *batch_buffer_; } + QuicBatchWriterBuffer& batch_buffer() { return *batch_buffer_; } + + const quiche::QuicheCircularDeque& buffered_writes() const { + return batch_buffer_->buffered_writes(); + } + + // Given the release delay in |options| and the state of |batch_buffer_|, get + // the absolute release time. + struct QUIC_NO_EXPORT ReleaseTime { + // The actual (absolute) release time. + uint64_t actual_release_time = 0; + // The difference between |actual_release_time| and ideal release time, + // which is (now + |options->release_time_delay|). + QuicTime::Delta release_time_offset = QuicTime::Delta::Zero(); + }; + virtual ReleaseTime GetReleaseTime( + const PerPacketOptions* /*options*/) const { + QUICHE_DCHECK(false) + << "Should not be called since release time is unsupported."; + return ReleaseTime{0, QuicTime::Delta::Zero()}; + } + + struct QUIC_EXPORT_PRIVATE CanBatchResult { + CanBatchResult(bool can_batch, bool must_flush) + : can_batch(can_batch), must_flush(must_flush) {} + // Whether this write can be batched with existing buffered writes. + bool can_batch; + // If |can_batch|, whether the caller must flush after this packet is + // buffered. + // Always true if not |can_batch|. + bool must_flush; + }; + + // Given the existing buffered writes(in buffered_writes()), whether a new + // write(in the arguments) can be batched. + virtual CanBatchResult CanBatch(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + const PerPacketOptions* options, + uint64_t release_time) const = 0; + + struct QUIC_EXPORT_PRIVATE FlushImplResult { + // The return value of the Flush() interface, which is: + // - WriteResult(WRITE_STATUS_OK, ) if all buffered writes + // were sent successfully. + // - WRITE_STATUS_BLOCKED or WRITE_STATUS_ERROR, if the batch write is + // blocked or returned an error while sending. If a portion of buffered + // writes were sent successfully, |FlushImplResult.num_packets_sent| and + // |FlushImplResult.bytes_written| contain the number of successfully sent + // packets and their total bytes. + WriteResult write_result; + int num_packets_sent; + // If write_result.status == WRITE_STATUS_OK, |bytes_written| will be equal + // to write_result.bytes_written. Otherwise |bytes_written| will be the + // number of bytes written before WRITE_BLOCK or WRITE_ERROR happened. + int bytes_written; + }; + + // Send all buffered writes(in buffered_writes()) in a batch. + // buffered_writes() is guaranteed to be non-empty when this function is + // called. + virtual FlushImplResult FlushImpl() = 0; + + private: + WriteResult InternalWritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options); + + // Calls FlushImpl() and check its post condition. + FlushImplResult CheckedFlush(); + + bool write_blocked_; + std::unique_ptr batch_buffer_; +}; + +// QuicUdpBatchWriter is a batch writer backed by a UDP socket. +class QUIC_EXPORT_PRIVATE QuicUdpBatchWriter : public QuicBatchWriterBase { + public: + QuicUdpBatchWriter(std::unique_ptr batch_buffer, + int fd) + : QuicBatchWriterBase(std::move(batch_buffer)), fd_(fd) {} + + int fd() const { return fd_; } + + private: + const int fd_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_BASE_H_ diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc b/quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc new file mode 100644 index 000000000000..ac7ddd793c55 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_buffer.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_batch_writer_buffer.h" + +#include + +namespace quic { + +QuicBatchWriterBuffer::QuicBatchWriterBuffer() { + memset(buffer_, 0, sizeof(buffer_)); +} + +void QuicBatchWriterBuffer::Clear() { buffered_writes_.clear(); } + +std::string QuicBatchWriterBuffer::DebugString() const { + std::ostringstream os; + os << "{ buffer: " << static_cast(buffer_) + << " buffer_end: " << static_cast(buffer_end()) + << " buffered_writes_.size(): " << buffered_writes_.size() + << " next_write_loc: " << static_cast(GetNextWriteLocation()) + << " SizeInUse: " << SizeInUse() << " }"; + return os.str(); +} + +bool QuicBatchWriterBuffer::Invariants() const { + // Buffers in buffered_writes_ should not overlap, and collectively they + // should cover a continuous prefix of buffer_. + const char* next_buffer = buffer_; + for (auto iter = buffered_writes_.begin(); iter != buffered_writes_.end(); + ++iter) { + if ((iter->buffer != next_buffer) || + (iter->buffer + iter->buf_len > buffer_end())) { + return false; + } + next_buffer += iter->buf_len; + } + + return static_cast(next_buffer - buffer_) == SizeInUse(); +} + +char* QuicBatchWriterBuffer::GetNextWriteLocation() const { + const char* next_loc = + buffered_writes_.empty() + ? buffer_ + : buffered_writes_.back().buffer + buffered_writes_.back().buf_len; + if (static_cast(buffer_end() - next_loc) < kMaxOutgoingPacketSize) { + return nullptr; + } + return const_cast(next_loc); +} + +QuicBatchWriterBuffer::PushResult QuicBatchWriterBuffer::PushBufferedWrite( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, const PerPacketOptions* options, + uint64_t release_time) { + QUICHE_DCHECK(Invariants()); + QUICHE_DCHECK_LE(buf_len, kMaxOutgoingPacketSize); + + PushResult result = {/*succeeded=*/false, /*buffer_copied=*/false}; + char* next_write_location = GetNextWriteLocation(); + if (next_write_location == nullptr) { + return result; + } + + if (buffer != next_write_location) { + if (IsExternalBuffer(buffer, buf_len)) { + memcpy(next_write_location, buffer, buf_len); + } else if (IsInternalBuffer(buffer, buf_len)) { + memmove(next_write_location, buffer, buf_len); + } else { + QUIC_BUG(quic_bug_10831_1) + << "Buffer[" << static_cast(buffer) << ", " + << static_cast(buffer + buf_len) + << ") overlaps with internal buffer[" + << static_cast(buffer_) << ", " + << static_cast(buffer_end()) << ")"; + return result; + } + result.buffer_copied = true; + } else { + // In place push, do nothing. + } + buffered_writes_.emplace_back( + next_write_location, buf_len, self_address, peer_address, + options ? options->Clone() : std::unique_ptr(), + release_time); + + QUICHE_DCHECK(Invariants()); + + result.succeeded = true; + return result; +} + +void QuicBatchWriterBuffer::UndoLastPush() { + if (!buffered_writes_.empty()) { + buffered_writes_.pop_back(); + } +} + +QuicBatchWriterBuffer::PopResult QuicBatchWriterBuffer::PopBufferedWrite( + int32_t num_buffered_writes) { + QUICHE_DCHECK(Invariants()); + QUICHE_DCHECK_GE(num_buffered_writes, 0); + QUICHE_DCHECK_LE(static_cast(num_buffered_writes), + buffered_writes_.size()); + + PopResult result = {/*num_buffers_popped=*/0, + /*moved_remaining_buffers=*/false}; + + result.num_buffers_popped = std::max(num_buffered_writes, 0); + result.num_buffers_popped = + std::min(result.num_buffers_popped, buffered_writes_.size()); + buffered_writes_.pop_front_n(result.num_buffers_popped); + + if (!buffered_writes_.empty()) { + // If not all buffered writes are erased, the remaining ones will not cover + // a continuous prefix of buffer_. We'll fix it by moving the remaining + // buffers to the beginning of buffer_ and adjust the buffer pointers in all + // remaining buffered writes. + // This should happen very rarely, about once per write block. + result.moved_remaining_buffers = true; + const char* buffer_before_move = buffered_writes_.front().buffer; + size_t buffer_len_to_move = buffered_writes_.back().buffer + + buffered_writes_.back().buf_len - + buffer_before_move; + memmove(buffer_, buffer_before_move, buffer_len_to_move); + + size_t distance_to_move = buffer_before_move - buffer_; + for (BufferedWrite& buffered_write : buffered_writes_) { + buffered_write.buffer -= distance_to_move; + } + + QUICHE_DCHECK_EQ(buffer_, buffered_writes_.front().buffer); + } + QUICHE_DCHECK(Invariants()); + + return result; +} + +size_t QuicBatchWriterBuffer::SizeInUse() const { + if (buffered_writes_.empty()) { + return 0; + } + + return buffered_writes_.back().buffer + buffered_writes_.back().buf_len - + buffer_; +} + +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_buffer.h b/quiche/quic/core/batch_writer/quic_batch_writer_buffer.h new file mode 100644 index 000000000000..23694019c3d5 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_buffer.h @@ -0,0 +1,94 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_BUFFER_H_ +#define QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_BUFFER_H_ + +#include "absl/base/optimization.h" +#include "quiche/quic/core/quic_linux_socket_utils.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +// QuicBatchWriterBuffer manages an internal buffer to hold data from multiple +// packets. Packet data are placed continuously within the internal buffer such +// that they can be sent by a QuicGsoBatchWriter. +// This class can also be used by a QuicBatchWriter which uses sendmmsg, +// although it is not optimized for that use case. +class QUIC_EXPORT_PRIVATE QuicBatchWriterBuffer { + public: + QuicBatchWriterBuffer(); + + // Clear all buffered writes, but leave the internal buffer intact. + void Clear(); + + char* GetNextWriteLocation() const; + + // Push a buffered write to the back. + struct QUIC_EXPORT_PRIVATE PushResult { + bool succeeded; + // True in one of the following cases: + // 1) The packet buffer is external and copied to the internal buffer, or + // 2) The packet buffer is from the internal buffer and moved within it. + // This only happens if PopBufferedWrite is called in the middle of a + // in-place push. + // Only valid if |succeeded| is true. + bool buffer_copied; + }; + + PushResult PushBufferedWrite(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + const PerPacketOptions* options, + uint64_t release_time); + + void UndoLastPush(); + + // Pop |num_buffered_writes| buffered writes from the front. + // |num_buffered_writes| will be capped to [0, buffered_writes().size()] + // before it is used. + struct QUIC_EXPORT_PRIVATE PopResult { + int32_t num_buffers_popped; + // True if after |num_buffers_popped| buffers are popped from front, the + // remaining buffers are moved to the beginning of the internal buffer. + // This should normally be false. + bool moved_remaining_buffers; + }; + PopResult PopBufferedWrite(int32_t num_buffered_writes); + + const quiche::QuicheCircularDeque& buffered_writes() const { + return buffered_writes_; + } + + bool IsExternalBuffer(const char* buffer, size_t buf_len) const { + return (buffer + buf_len) <= buffer_ || buffer >= buffer_end(); + } + bool IsInternalBuffer(const char* buffer, size_t buf_len) const { + return buffer >= buffer_ && (buffer + buf_len) <= buffer_end(); + } + + // Number of bytes used in |buffer_|. + // PushBufferedWrite() increases this; PopBufferedWrite decreases this. + size_t SizeInUse() const; + + // Rounded up from |kMaxGsoPacketSize|, which is the maximum allowed + // size of a GSO packet. + static const size_t kBufferSize = 64 * 1024; + + std::string DebugString() const; + + protected: + // Whether the invariants of the buffer are upheld. For debug & test only. + bool Invariants() const; + const char* buffer_end() const { return buffer_ + sizeof(buffer_); } + ABSL_CACHELINE_ALIGNED char buffer_[kBufferSize]; + quiche::QuicheCircularDeque buffered_writes_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_BUFFER_H_ diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_buffer_test.cc b/quiche/quic/core/batch_writer/quic_batch_writer_buffer_test.cc new file mode 100644 index 000000000000..e081a798d2dd --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_buffer_test.cc @@ -0,0 +1,281 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_batch_writer_buffer.h" + +#include +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +class QUIC_EXPORT_PRIVATE TestQuicBatchWriterBuffer + : public QuicBatchWriterBuffer { + public: + using QuicBatchWriterBuffer::buffer_; + using QuicBatchWriterBuffer::buffered_writes_; +}; + +static const size_t kBatchBufferSize = QuicBatchWriterBuffer::kBufferSize; + +class QuicBatchWriterBufferTest : public QuicTest { + public: + QuicBatchWriterBufferTest() { SwitchToNewBuffer(); } + + void SwitchToNewBuffer() { + batch_buffer_ = std::make_unique(); + } + + // Fill packet_buffer_ with kMaxOutgoingPacketSize bytes of |c|s. + char* FillPacketBuffer(char c) { + return FillPacketBuffer(c, packet_buffer_, kMaxOutgoingPacketSize); + } + + // Fill |packet_buffer| with kMaxOutgoingPacketSize bytes of |c|s. + char* FillPacketBuffer(char c, char* packet_buffer) { + return FillPacketBuffer(c, packet_buffer, kMaxOutgoingPacketSize); + } + + // Fill |packet_buffer| with |buf_len| bytes of |c|s. + char* FillPacketBuffer(char c, char* packet_buffer, size_t buf_len) { + memset(packet_buffer, c, buf_len); + return packet_buffer; + } + + void CheckBufferedWriteContent(int buffered_write_index, char buffer_content, + size_t buf_len, const QuicIpAddress& self_addr, + const QuicSocketAddress& peer_addr, + const PerPacketOptions* options) { + const BufferedWrite& buffered_write = + batch_buffer_->buffered_writes()[buffered_write_index]; + EXPECT_EQ(buf_len, buffered_write.buf_len); + for (size_t i = 0; i < buf_len; ++i) { + EXPECT_EQ(buffer_content, buffered_write.buffer[i]); + if (buffer_content != buffered_write.buffer[i]) { + break; + } + } + EXPECT_EQ(self_addr, buffered_write.self_address); + EXPECT_EQ(peer_addr, buffered_write.peer_address); + if (options == nullptr) { + EXPECT_EQ(nullptr, buffered_write.options); + } else { + EXPECT_EQ(options->release_time_delay, + buffered_write.options->release_time_delay); + } + } + + protected: + std::unique_ptr batch_buffer_; + QuicIpAddress self_addr_; + QuicSocketAddress peer_addr_; + uint64_t release_time_ = 0; + char packet_buffer_[kMaxOutgoingPacketSize]; +}; + +class BufferSizeSequence { + public: + explicit BufferSizeSequence( + std::vector, size_t>> stages) + : stages_(std::move(stages)), + total_buf_len_(0), + stage_index_(0), + sequence_index_(0) {} + + size_t Next() { + const std::vector& seq = stages_[stage_index_].first; + size_t buf_len = seq[sequence_index_++ % seq.size()]; + total_buf_len_ += buf_len; + if (stages_[stage_index_].second <= total_buf_len_) { + stage_index_ = std::min(stage_index_ + 1, stages_.size() - 1); + } + return buf_len; + } + + private: + const std::vector, size_t>> stages_; + size_t total_buf_len_; + size_t stage_index_; + size_t sequence_index_; +}; + +// Test in-place pushes. A in-place push is a push with a buffer address that is +// equal to the result of GetNextWriteLocation(). +TEST_F(QuicBatchWriterBufferTest, InPlacePushes) { + std::vector buffer_size_sequences = { + // Push large writes until the buffer is near full, then switch to 1-byte + // writes. This covers the edge cases when detecting insufficient buffer. + BufferSizeSequence({{{1350}, kBatchBufferSize - 3000}, {{1}, 1e6}}), + // A sequence that looks real. + BufferSizeSequence({{{1, 39, 97, 150, 1350, 1350, 1350, 1350}, 1e6}}), + }; + + for (auto& buffer_size_sequence : buffer_size_sequences) { + SwitchToNewBuffer(); + int64_t num_push_failures = 0; + + while (batch_buffer_->SizeInUse() < kBatchBufferSize) { + size_t buf_len = buffer_size_sequence.Next(); + const bool has_enough_space = + (kBatchBufferSize - batch_buffer_->SizeInUse() >= + kMaxOutgoingPacketSize); + + char* buffer = batch_buffer_->GetNextWriteLocation(); + + if (has_enough_space) { + EXPECT_EQ(batch_buffer_->buffer_ + batch_buffer_->SizeInUse(), buffer); + } else { + EXPECT_EQ(nullptr, buffer); + } + + SCOPED_TRACE(testing::Message() + << "Before Push: buf_len=" << buf_len + << ", has_enough_space=" << has_enough_space + << ", batch_buffer=" << batch_buffer_->DebugString()); + + auto push_result = batch_buffer_->PushBufferedWrite( + buffer, buf_len, self_addr_, peer_addr_, nullptr, release_time_); + if (!push_result.succeeded) { + ++num_push_failures; + } + EXPECT_EQ(has_enough_space, push_result.succeeded); + EXPECT_FALSE(push_result.buffer_copied); + if (!has_enough_space) { + break; + } + } + // Expect one and only one failure from the final push operation. + EXPECT_EQ(1, num_push_failures); + } +} + +// Test some in-place pushes mixed with pushes with external buffers. +TEST_F(QuicBatchWriterBufferTest, MixedPushes) { + // First, a in-place push. + char* buffer = batch_buffer_->GetNextWriteLocation(); + auto push_result = batch_buffer_->PushBufferedWrite( + FillPacketBuffer('A', buffer), kDefaultMaxPacketSize, self_addr_, + peer_addr_, nullptr, release_time_); + EXPECT_TRUE(push_result.succeeded); + EXPECT_FALSE(push_result.buffer_copied); + CheckBufferedWriteContent(0, 'A', kDefaultMaxPacketSize, self_addr_, + peer_addr_, nullptr); + + // Then a push with external buffer. + push_result = batch_buffer_->PushBufferedWrite( + FillPacketBuffer('B'), kDefaultMaxPacketSize, self_addr_, peer_addr_, + nullptr, release_time_); + EXPECT_TRUE(push_result.succeeded); + EXPECT_TRUE(push_result.buffer_copied); + CheckBufferedWriteContent(1, 'B', kDefaultMaxPacketSize, self_addr_, + peer_addr_, nullptr); + + // Then another in-place push. + buffer = batch_buffer_->GetNextWriteLocation(); + push_result = batch_buffer_->PushBufferedWrite( + FillPacketBuffer('C', buffer), kDefaultMaxPacketSize, self_addr_, + peer_addr_, nullptr, release_time_); + EXPECT_TRUE(push_result.succeeded); + EXPECT_FALSE(push_result.buffer_copied); + CheckBufferedWriteContent(2, 'C', kDefaultMaxPacketSize, self_addr_, + peer_addr_, nullptr); + + // Then another push with external buffer. + push_result = batch_buffer_->PushBufferedWrite( + FillPacketBuffer('D'), kDefaultMaxPacketSize, self_addr_, peer_addr_, + nullptr, release_time_); + EXPECT_TRUE(push_result.succeeded); + EXPECT_TRUE(push_result.buffer_copied); + CheckBufferedWriteContent(3, 'D', kDefaultMaxPacketSize, self_addr_, + peer_addr_, nullptr); +} + +TEST_F(QuicBatchWriterBufferTest, PopAll) { + const int kNumBufferedWrites = 10; + for (int i = 0; i < kNumBufferedWrites; ++i) { + EXPECT_TRUE(batch_buffer_ + ->PushBufferedWrite(packet_buffer_, kDefaultMaxPacketSize, + self_addr_, peer_addr_, nullptr, + release_time_) + .succeeded); + } + EXPECT_EQ(kNumBufferedWrites, + static_cast(batch_buffer_->buffered_writes().size())); + + auto pop_result = batch_buffer_->PopBufferedWrite(kNumBufferedWrites); + EXPECT_EQ(0u, batch_buffer_->buffered_writes().size()); + EXPECT_EQ(kNumBufferedWrites, pop_result.num_buffers_popped); + EXPECT_FALSE(pop_result.moved_remaining_buffers); +} + +TEST_F(QuicBatchWriterBufferTest, PopPartial) { + const int kNumBufferedWrites = 10; + for (int i = 0; i < kNumBufferedWrites; ++i) { + EXPECT_TRUE(batch_buffer_ + ->PushBufferedWrite(FillPacketBuffer('A' + i), + kDefaultMaxPacketSize - i, self_addr_, + peer_addr_, nullptr, release_time_) + .succeeded); + } + + for (size_t i = 0; + i < kNumBufferedWrites && !batch_buffer_->buffered_writes().empty(); + ++i) { + const size_t size_before_pop = batch_buffer_->buffered_writes().size(); + const size_t expect_size_after_pop = + size_before_pop < i ? 0 : size_before_pop - i; + batch_buffer_->PopBufferedWrite(i); + ASSERT_EQ(expect_size_after_pop, batch_buffer_->buffered_writes().size()); + const char first_write_content = + 'A' + kNumBufferedWrites - expect_size_after_pop; + const size_t first_write_len = + kDefaultMaxPacketSize - kNumBufferedWrites + expect_size_after_pop; + for (size_t j = 0; j < expect_size_after_pop; ++j) { + CheckBufferedWriteContent(j, first_write_content + j, first_write_len - j, + self_addr_, peer_addr_, nullptr); + } + } +} + +TEST_F(QuicBatchWriterBufferTest, InPlacePushWithPops) { + // First, a in-place push. + char* buffer = batch_buffer_->GetNextWriteLocation(); + const size_t first_packet_len = 2; + auto push_result = batch_buffer_->PushBufferedWrite( + FillPacketBuffer('A', buffer, first_packet_len), first_packet_len, + self_addr_, peer_addr_, nullptr, release_time_); + EXPECT_TRUE(push_result.succeeded); + EXPECT_FALSE(push_result.buffer_copied); + CheckBufferedWriteContent(0, 'A', first_packet_len, self_addr_, peer_addr_, + nullptr); + + // Simulate the case where the writer wants to do another in-place push, but + // can't do so because it can't be batched with the first buffer. + buffer = batch_buffer_->GetNextWriteLocation(); + const size_t second_packet_len = 1350; + + // Flush the first buffer. + auto pop_result = batch_buffer_->PopBufferedWrite(1); + EXPECT_EQ(1, pop_result.num_buffers_popped); + EXPECT_FALSE(pop_result.moved_remaining_buffers); + + // Now the second push. + push_result = batch_buffer_->PushBufferedWrite( + FillPacketBuffer('B', buffer, second_packet_len), second_packet_len, + self_addr_, peer_addr_, nullptr, release_time_); + EXPECT_TRUE(push_result.succeeded); + EXPECT_TRUE(push_result.buffer_copied); + CheckBufferedWriteContent(0, 'B', second_packet_len, self_addr_, peer_addr_, + nullptr); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_test.cc b/quiche/quic/core/batch_writer/quic_batch_writer_test.cc new file mode 100644 index 000000000000..4a0683088f58 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_batch_writer_test.h" + +#include "quiche/quic/core/batch_writer/quic_gso_batch_writer.h" +#include "quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h" + +namespace quic { +namespace test { +namespace { + +class QuicGsoBatchWriterIOTestDelegate + : public QuicUdpBatchWriterIOTestDelegate { + public: + bool ShouldSkip(const QuicUdpBatchWriterIOTestParams& params) override { + QuicUdpSocketApi socket_api; + int fd = + socket_api.Create(params.address_family, + /*receive_buffer_size=*/kDefaultSocketReceiveBuffer, + /*send_buffer_size=*/kDefaultSocketReceiveBuffer); + if (fd < 0) { + QUIC_LOG(ERROR) << "CreateSocket() failed: " << strerror(errno); + return false; // Let the test fail rather than skip it. + } + const bool gso_not_supported = + QuicLinuxSocketUtils::GetUDPSegmentSize(fd) < 0; + socket_api.Destroy(fd); + + if (gso_not_supported) { + QUIC_LOG(WARNING) << "Test skipped since GSO is not supported."; + return true; + } + + QUIC_LOG(WARNING) << "OK: GSO is supported."; + return false; + } + + void ResetWriter(int fd) override { + writer_ = std::make_unique(fd); + } + + QuicUdpBatchWriter* GetWriter() override { return writer_.get(); } + + private: + std::unique_ptr writer_; +}; + +INSTANTIATE_TEST_SUITE_P( + QuicGsoBatchWriterTest, QuicUdpBatchWriterIOTest, + testing::ValuesIn( + MakeQuicBatchWriterTestParams())); + +class QuicSendmmsgBatchWriterIOTestDelegate + : public QuicUdpBatchWriterIOTestDelegate { + public: + void ResetWriter(int fd) override { + writer_ = std::make_unique( + std::make_unique(), fd); + } + + QuicUdpBatchWriter* GetWriter() override { return writer_.get(); } + + private: + std::unique_ptr writer_; +}; + +INSTANTIATE_TEST_SUITE_P( + QuicSendmmsgBatchWriterTest, QuicUdpBatchWriterIOTest, + testing::ValuesIn(MakeQuicBatchWriterTestParams< + QuicSendmmsgBatchWriterIOTestDelegate>())); + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_batch_writer_test.h b/quiche/quic/core/batch_writer/quic_batch_writer_test.h new file mode 100644 index 000000000000..ebbb9176916b --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_batch_writer_test.h @@ -0,0 +1,286 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_TEST_H_ +#define QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_TEST_H_ + +#include +#include + +#include +#include +#include + +#include "absl/base/optimization.h" +#include "quiche/quic/core/batch_writer/quic_batch_writer_base.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +static bool IsAddressFamilySupported(int address_family) { + static auto check_function = [](int address_family) { + int fd = socket(address_family, SOCK_STREAM, 0); + if (fd < 0) { + QUIC_LOG(ERROR) << "address_family not supported: " << address_family + << ", error: " << strerror(errno); + EXPECT_EQ(EAFNOSUPPORT, errno); + return false; + } + close(fd); + return true; + }; + + if (address_family == AF_INET) { + static const bool ipv4_supported = check_function(AF_INET); + return ipv4_supported; + } + + static const bool ipv6_supported = check_function(AF_INET6); + return ipv6_supported; +} + +static bool CreateSocket(int family, QuicSocketAddress* address, int* fd) { + if (family == AF_INET) { + *address = QuicSocketAddress(QuicIpAddress::Loopback4(), 0); + } else { + QUICHE_DCHECK_EQ(family, AF_INET6); + *address = QuicSocketAddress(QuicIpAddress::Loopback6(), 0); + } + + QuicUdpSocketApi socket_api; + *fd = socket_api.Create(family, + /*receive_buffer_size=*/kDefaultSocketReceiveBuffer, + /*send_buffer_size=*/kDefaultSocketReceiveBuffer); + if (*fd < 0) { + QUIC_LOG(ERROR) << "CreateSocket() failed: " << strerror(errno); + return false; + } + socket_api.EnableDroppedPacketCount(*fd); + + if (!socket_api.Bind(*fd, *address)) { + QUIC_LOG(ERROR) << "Bind failed: " << strerror(errno); + return false; + } + + if (address->FromSocket(*fd) != 0) { + QUIC_LOG(ERROR) << "Unable to get self address. Error: " + << strerror(errno); + return false; + } + return true; +} + +struct QuicUdpBatchWriterIOTestParams; +class QUIC_EXPORT_PRIVATE QuicUdpBatchWriterIOTestDelegate { + public: + virtual ~QuicUdpBatchWriterIOTestDelegate() {} + + virtual bool ShouldSkip(const QuicUdpBatchWriterIOTestParams& /*params*/) { + return false; + } + + virtual void ResetWriter(int fd) = 0; + + virtual QuicUdpBatchWriter* GetWriter() = 0; +}; + +struct QUIC_EXPORT_PRIVATE QuicUdpBatchWriterIOTestParams { + // Use shared_ptr because gtest makes copies of test params. + std::shared_ptr delegate; + int address_family; + int data_size; + int packet_size; + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicUdpBatchWriterIOTestParams& p) { + os << "{ address_family: " << p.address_family + << " data_size: " << p.data_size << " packet_size: " << p.packet_size + << " }"; + return os; + } +}; + +template +static std::vector +MakeQuicBatchWriterTestParams() { + static_assert(std::is_base_of::value, + " needs to derive from " + "QuicUdpBatchWriterIOTestDelegate"); + + std::vector params; + for (int address_family : {AF_INET, AF_INET6}) { + for (int data_size : {1, 150, 1500, 15000, 64000, 512 * 1024}) { + for (int packet_size : {1, 50, 1350, 1452}) { + if (packet_size <= data_size && (data_size / packet_size < 2000)) { + params.push_back( + {std::make_unique(), + address_family, data_size, packet_size}); + } + } + } + } + return params; +} + +// QuicUdpBatchWriterIOTest is a value parameterized test fixture that can be +// used by tests of derived classes of QuicUdpBatchWriter, to verify basic +// packet IO capabilities. +class QUIC_EXPORT_PRIVATE QuicUdpBatchWriterIOTest + : public QuicTestWithParam { + protected: + QuicUdpBatchWriterIOTest() + : address_family_(GetParam().address_family), + data_size_(GetParam().data_size), + packet_size_(GetParam().packet_size), + self_socket_(-1), + peer_socket_(-1) { + QUIC_LOG(INFO) << "QuicUdpBatchWriterIOTestParams: " << GetParam(); + EXPECT_TRUE(address_family_ == AF_INET || address_family_ == AF_INET6); + EXPECT_LE(packet_size_, data_size_); + EXPECT_LE(packet_size_, sizeof(packet_buffer_)); + } + + ~QuicUdpBatchWriterIOTest() override { + if (self_socket_ > 0) { + close(self_socket_); + } + if (peer_socket_ > 0) { + close(peer_socket_); + } + } + + // Whether this test should be skipped. A test is passed if skipped. + // A test can be skipped when e.g. it exercises a kernel feature that is not + // available on the system. + bool ShouldSkip() { + if (!IsAddressFamilySupported(address_family_)) { + QUIC_LOG(WARNING) + << "Test skipped since address_family is not supported."; + return true; + } + + return GetParam().delegate->ShouldSkip(GetParam()); + } + + // Initialize a test. + // To fail the test in Initialize, use ASSERT_xx macros. + void Initialize() { + ASSERT_TRUE(CreateSocket(address_family_, &self_address_, &self_socket_)); + ASSERT_TRUE(CreateSocket(address_family_, &peer_address_, &peer_socket_)); + + QUIC_DLOG(INFO) << "Self address: " << self_address_.ToString() << ", fd " + << self_socket_; + QUIC_DLOG(INFO) << "Peer address: " << peer_address_.ToString() << ", fd " + << peer_socket_; + GetParam().delegate->ResetWriter(self_socket_); + } + + QuicUdpBatchWriter* GetWriter() { return GetParam().delegate->GetWriter(); } + + void ValidateWrite() { + char this_packet_content = '\0'; + int this_packet_size; + int num_writes = 0; + size_t bytes_flushed = 0; + WriteResult result; + + for (size_t bytes_sent = 0; bytes_sent < data_size_; + bytes_sent += this_packet_size, ++this_packet_content) { + this_packet_size = std::min(packet_size_, data_size_ - bytes_sent); + memset(&packet_buffer_[0], this_packet_content, this_packet_size); + + result = GetWriter()->WritePacket(&packet_buffer_[0], this_packet_size, + self_address_.host(), peer_address_, + nullptr); + + ASSERT_EQ(WRITE_STATUS_OK, result.status) << strerror(result.error_code); + bytes_flushed += result.bytes_written; + ++num_writes; + + QUIC_DVLOG(1) << "[write #" << num_writes + << "] this_packet_size: " << this_packet_size + << ", total_bytes_sent: " << bytes_sent + this_packet_size + << ", bytes_flushed: " << bytes_flushed + << ", pkt content:" << std::hex << int(this_packet_content); + } + + result = GetWriter()->Flush(); + ASSERT_EQ(WRITE_STATUS_OK, result.status) << strerror(result.error_code); + bytes_flushed += result.bytes_written; + ASSERT_EQ(data_size_, bytes_flushed); + + QUIC_LOG(INFO) << "Sent " << data_size_ << " bytes in " << num_writes + << " writes."; + } + + void ValidateRead() { + char this_packet_content = '\0'; + int this_packet_size; + int packets_received = 0; + for (size_t bytes_received = 0; bytes_received < data_size_; + bytes_received += this_packet_size, ++this_packet_content) { + this_packet_size = std::min(packet_size_, data_size_ - bytes_received); + SCOPED_TRACE(testing::Message() + << "Before ReadPacket: bytes_received=" << bytes_received + << ", this_packet_size=" << this_packet_size); + + QuicUdpSocketApi::ReadPacketResult result; + result.packet_buffer = {&packet_buffer_[0], sizeof(packet_buffer_)}; + result.control_buffer = {&control_buffer_[0], sizeof(control_buffer_)}; + QuicUdpSocketApi().ReadPacket( + peer_socket_, + quic::BitMask64(QuicUdpPacketInfoBit::V4_SELF_IP, + QuicUdpPacketInfoBit::V6_SELF_IP, + QuicUdpPacketInfoBit::PEER_ADDRESS), + &result); + ASSERT_TRUE(result.ok); + ASSERT_TRUE( + result.packet_info.HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)); + QuicSocketAddress read_peer_address = result.packet_info.peer_address(); + QuicIpAddress read_self_address = read_peer_address.host().IsIPv6() + ? result.packet_info.self_v6_ip() + : result.packet_info.self_v4_ip(); + + EXPECT_EQ(read_self_address, peer_address_.host()); + EXPECT_EQ(read_peer_address, self_address_); + for (int i = 0; i < this_packet_size; ++i) { + EXPECT_EQ(this_packet_content, packet_buffer_[i]); + } + packets_received += this_packet_size; + } + + QUIC_LOG(INFO) << "Received " << data_size_ << " bytes in " + << packets_received << " packets."; + } + + QuicSocketAddress self_address_; + QuicSocketAddress peer_address_; + ABSL_CACHELINE_ALIGNED char packet_buffer_[1500]; + ABSL_CACHELINE_ALIGNED char + control_buffer_[kDefaultUdpPacketControlBufferSize]; + int address_family_; + const size_t data_size_; + const size_t packet_size_; + int self_socket_; + int peer_socket_; +}; + +TEST_P(QuicUdpBatchWriterIOTest, WriteAndRead) { + if (ShouldSkip()) { + return; + } + + Initialize(); + + ValidateWrite(); + ValidateRead(); +} + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_BATCH_WRITER_TEST_H_ diff --git a/quiche/quic/core/batch_writer/quic_gso_batch_writer.cc b/quiche/quic/core/batch_writer/quic_gso_batch_writer.cc new file mode 100644 index 000000000000..b6d6200fb544 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_gso_batch_writer.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_gso_batch_writer.h" + +#include + +#include + +#include "quiche/quic/core/quic_linux_socket_utils.h" +#include "quiche/quic/platform/api/quic_server_stats.h" + +namespace quic { + +// static +std::unique_ptr +QuicGsoBatchWriter::CreateBatchWriterBuffer() { + return std::make_unique(); +} + +QuicGsoBatchWriter::QuicGsoBatchWriter(int fd) + : QuicGsoBatchWriter(fd, CLOCK_MONOTONIC) {} + +QuicGsoBatchWriter::QuicGsoBatchWriter(int fd, + clockid_t clockid_for_release_time) + : QuicUdpBatchWriter(CreateBatchWriterBuffer(), fd), + clockid_for_release_time_(clockid_for_release_time), + supports_release_time_( + GetQuicRestartFlag(quic_support_release_time_for_gso) && + QuicLinuxSocketUtils::EnableReleaseTime(fd, + clockid_for_release_time)) { + if (supports_release_time_) { + QUIC_RESTART_FLAG_COUNT(quic_support_release_time_for_gso); + QUIC_LOG_FIRST_N(INFO, 5) << "Release time is enabled."; + } else { + QUIC_LOG_FIRST_N(INFO, 5) << "Release time is not enabled."; + } +} + +QuicGsoBatchWriter::QuicGsoBatchWriter( + std::unique_ptr batch_buffer, int fd, + clockid_t clockid_for_release_time, ReleaseTimeForceEnabler /*enabler*/) + : QuicUdpBatchWriter(std::move(batch_buffer), fd), + clockid_for_release_time_(clockid_for_release_time), + supports_release_time_(true) { + QUIC_DLOG(INFO) << "Release time forcefully enabled."; +} + +QuicGsoBatchWriter::CanBatchResult QuicGsoBatchWriter::CanBatch( + const char* /*buffer*/, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, const PerPacketOptions* options, + uint64_t release_time) const { + // If there is nothing buffered already, this write will be included in this + // batch. + if (buffered_writes().empty()) { + return CanBatchResult(/*can_batch=*/true, /*must_flush=*/false); + } + + // The new write can be batched if all of the following are true: + // [0] The total number of the GSO segments(one write=one segment, including + // the new write) must not exceed |max_segments|. + // [1] It has the same source and destination addresses as already buffered + // writes. + // [2] It won't cause this batch to exceed kMaxGsoPacketSize. + // [3] Already buffered writes all have the same length. + // [4] Length of already buffered writes must >= length of the new write. + // [5] The new packet can be released without delay, or it has the same + // release time as buffered writes. + const BufferedWrite& first = buffered_writes().front(); + const BufferedWrite& last = buffered_writes().back(); + // Whether this packet can be sent without delay, regardless of release time. + const bool can_burst = !SupportsReleaseTime() || !options || + options->release_time_delay.IsZero() || + options->allow_burst; + size_t max_segments = MaxSegments(first.buf_len); + bool can_batch = + buffered_writes().size() < max_segments && // [0] + last.self_address == self_address && // [1] + last.peer_address == peer_address && // [1] + batch_buffer().SizeInUse() + buf_len <= kMaxGsoPacketSize && // [2] + first.buf_len == last.buf_len && // [3] + first.buf_len >= buf_len && // [4] + (can_burst || first.release_time == release_time); // [5] + + // A flush is required if any of the following is true: + // [a] The new write can't be batched. + // [b] Length of the new write is different from the length of already + // buffered writes. + // [c] The total number of the GSO segments, including the new write, reaches + // |max_segments|. + bool must_flush = (!can_batch) || // [a] + (last.buf_len != buf_len) || // [b] + (buffered_writes().size() + 1 == max_segments); // [c] + return CanBatchResult(can_batch, must_flush); +} + +QuicGsoBatchWriter::ReleaseTime QuicGsoBatchWriter::GetReleaseTime( + const PerPacketOptions* options) const { + QUICHE_DCHECK(SupportsReleaseTime()); + + if (options == nullptr) { + return {0, QuicTime::Delta::Zero()}; + } + + const uint64_t now = NowInNanosForReleaseTime(); + const uint64_t ideal_release_time = + now + options->release_time_delay.ToMicroseconds() * 1000; + + if ((options->release_time_delay.IsZero() || options->allow_burst) && + !buffered_writes().empty() && + // If release time of buffered packets is in the past, flush buffered + // packets and buffer this packet at the ideal release time. + (buffered_writes().back().release_time >= now)) { + // Send as soon as possible, but no sooner than the last buffered packet. + const uint64_t actual_release_time = buffered_writes().back().release_time; + + const int64_t offset_ns = actual_release_time - ideal_release_time; + ReleaseTime result{actual_release_time, + QuicTime::Delta::FromMicroseconds(offset_ns / 1000)}; + + QUIC_DVLOG(1) << "ideal_release_time:" << ideal_release_time + << ", actual_release_time:" << actual_release_time + << ", offset:" << result.release_time_offset; + return result; + } + + // Send according to the release time delay. + return {ideal_release_time, QuicTime::Delta::Zero()}; +} + +uint64_t QuicGsoBatchWriter::NowInNanosForReleaseTime() const { + struct timespec ts; + + if (clock_gettime(clockid_for_release_time_, &ts) != 0) { + return 0; + } + + return ts.tv_sec * (1000ULL * 1000 * 1000) + ts.tv_nsec; +} + +// static +void QuicGsoBatchWriter::BuildCmsg(QuicMsgHdr* hdr, + const QuicIpAddress& self_address, + uint16_t gso_size, uint64_t release_time) { + hdr->SetIpInNextCmsg(self_address); + if (gso_size > 0) { + *hdr->GetNextCmsgData(SOL_UDP, UDP_SEGMENT) = gso_size; + } + if (release_time != 0) { + *hdr->GetNextCmsgData(SOL_SOCKET, SO_TXTIME) = release_time; + } +} + +QuicGsoBatchWriter::FlushImplResult QuicGsoBatchWriter::FlushImpl() { + return InternalFlushImpl(BuildCmsg); +} + +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_gso_batch_writer.h b/quiche/quic/core/batch_writer/quic_gso_batch_writer.h new file mode 100644 index 000000000000..17657fc2be32 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_gso_batch_writer.h @@ -0,0 +1,113 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_GSO_BATCH_WRITER_H_ +#define QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_GSO_BATCH_WRITER_H_ + +#include "quiche/quic/core/batch_writer/quic_batch_writer_base.h" + +namespace quic { + +// QuicGsoBatchWriter sends QUIC packets in batches, using UDP socket's generic +// segmentation offload(GSO) capability. +class QUIC_EXPORT_PRIVATE QuicGsoBatchWriter : public QuicUdpBatchWriter { + public: + explicit QuicGsoBatchWriter(int fd); + + // |clockid_for_release_time|: FQ qdisc requires CLOCK_MONOTONIC, EDF requires + // CLOCK_TAI. + QuicGsoBatchWriter(int fd, clockid_t clockid_for_release_time); + + bool SupportsReleaseTime() const final { return supports_release_time_; } + + CanBatchResult CanBatch(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + const PerPacketOptions* options, + uint64_t release_time) const override; + + FlushImplResult FlushImpl() override; + + protected: + // Test only constructor to forcefully enable release time. + struct QUIC_EXPORT_PRIVATE ReleaseTimeForceEnabler {}; + QuicGsoBatchWriter(std::unique_ptr batch_buffer, + int fd, clockid_t clockid_for_release_time, + ReleaseTimeForceEnabler enabler); + + ReleaseTime GetReleaseTime(const PerPacketOptions* options) const override; + + // Get the current time in nanos from |clockid_for_release_time_|. + virtual uint64_t NowInNanosForReleaseTime() const; + + static size_t MaxSegments(size_t gso_size) { + // Max segments should be the min of UDP_MAX_SEGMENTS(64) and + // (((64KB - sizeof(ip hdr) - sizeof(udp hdr)) / MSS) + 1), in the typical + // case of IPv6 packets with 1500-byte MTU, the result is + // ((64KB - 40 - 8) / (1500 - 48)) + 1 = 46 + // However, due a kernel bug, the limit is much lower for tiny gso_sizes. + return gso_size <= 2 ? 16 : 45; + } + + static const int kCmsgSpace = + kCmsgSpaceForIp + kCmsgSpaceForSegmentSize + kCmsgSpaceForTxTime; + static void BuildCmsg(QuicMsgHdr* hdr, const QuicIpAddress& self_address, + uint16_t gso_size, uint64_t release_time); + + template + FlushImplResult InternalFlushImpl(CmsgBuilderT cmsg_builder) { + QUICHE_DCHECK(!IsWriteBlocked()); + QUICHE_DCHECK(!buffered_writes().empty()); + + FlushImplResult result = {WriteResult(WRITE_STATUS_OK, 0), + /*num_packets_sent=*/0, /*bytes_written=*/0}; + WriteResult& write_result = result.write_result; + + int total_bytes = batch_buffer().SizeInUse(); + const BufferedWrite& first = buffered_writes().front(); + char cbuf[CmsgSpace]; + QuicMsgHdr hdr(first.buffer, total_bytes, first.peer_address, cbuf, + sizeof(cbuf)); + + uint16_t gso_size = buffered_writes().size() > 1 ? first.buf_len : 0; + cmsg_builder(&hdr, first.self_address, gso_size, first.release_time); + + write_result = QuicLinuxSocketUtils::WritePacket(fd(), hdr); + QUIC_DVLOG(1) << "Write GSO packet result: " << write_result + << ", fd: " << fd() + << ", self_address: " << first.self_address.ToString() + << ", peer_address: " << first.peer_address.ToString() + << ", num_segments: " << buffered_writes().size() + << ", total_bytes: " << total_bytes + << ", gso_size: " << gso_size + << ", release_time: " << first.release_time; + + // All segments in a GSO packet share the same fate - if the write failed, + // none of them are sent, and it's not needed to call PopBufferedWrite(). + if (write_result.status != WRITE_STATUS_OK) { + return result; + } + + result.num_packets_sent = buffered_writes().size(); + + write_result.bytes_written = total_bytes; + result.bytes_written = total_bytes; + + batch_buffer().PopBufferedWrite(buffered_writes().size()); + + QUIC_BUG_IF(quic_bug_12544_1, !buffered_writes().empty()) + << "All packets should have been written on a successful return"; + return result; + } + + private: + static std::unique_ptr CreateBatchWriterBuffer(); + + const clockid_t clockid_for_release_time_; + const bool supports_release_time_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_GSO_BATCH_WRITER_H_ diff --git a/quiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc b/quiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc new file mode 100644 index 000000000000..efe4c713897b --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc @@ -0,0 +1,462 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_gso_batch_writer.h" + +#include +#include +#include +#include + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_mock_syscall_wrapper.h" + +using testing::_; +using testing::Invoke; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +size_t PacketLength(const msghdr* msg) { + size_t length = 0; + for (size_t i = 0; i < msg->msg_iovlen; ++i) { + length += msg->msg_iov[i].iov_len; + } + return length; +} + +uint64_t MillisToNanos(uint64_t milliseconds) { return milliseconds * 1000000; } + +class QUIC_EXPORT_PRIVATE TestQuicGsoBatchWriter : public QuicGsoBatchWriter { + public: + using QuicGsoBatchWriter::batch_buffer; + using QuicGsoBatchWriter::buffered_writes; + using QuicGsoBatchWriter::CanBatch; + using QuicGsoBatchWriter::CanBatchResult; + using QuicGsoBatchWriter::GetReleaseTime; + using QuicGsoBatchWriter::MaxSegments; + using QuicGsoBatchWriter::QuicGsoBatchWriter; + using QuicGsoBatchWriter::ReleaseTime; + + static std::unique_ptr + NewInstanceWithReleaseTimeSupport() { + return std::unique_ptr(new TestQuicGsoBatchWriter( + std::make_unique(), + /*fd=*/-1, CLOCK_MONOTONIC, ReleaseTimeForceEnabler())); + } + + uint64_t NowInNanosForReleaseTime() const override { + return MillisToNanos(forced_release_time_ms_); + } + + void ForceReleaseTimeMs(uint64_t forced_release_time_ms) { + forced_release_time_ms_ = forced_release_time_ms; + } + + private: + uint64_t forced_release_time_ms_ = 1; +}; + +struct QUIC_EXPORT_PRIVATE TestPerPacketOptions : public PerPacketOptions { + std::unique_ptr Clone() const override { + return std::make_unique(*this); + } +}; + +// TestBufferedWrite is a copy-constructible BufferedWrite. +struct QUIC_EXPORT_PRIVATE TestBufferedWrite : public BufferedWrite { + using BufferedWrite::BufferedWrite; + TestBufferedWrite(const TestBufferedWrite& other) + : BufferedWrite(other.buffer, other.buf_len, other.self_address, + other.peer_address, + other.options ? other.options->Clone() + : std::unique_ptr(), + other.release_time) {} +}; + +// Pointed to by all instances of |BatchCriteriaTestData|. Content not used. +static char unused_packet_buffer[kMaxOutgoingPacketSize]; + +struct QUIC_EXPORT_PRIVATE BatchCriteriaTestData { + BatchCriteriaTestData(size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + uint64_t release_time, bool can_batch, bool must_flush) + : buffered_write(unused_packet_buffer, buf_len, self_address, + peer_address, std::unique_ptr(), + release_time), + can_batch(can_batch), + must_flush(must_flush) {} + + TestBufferedWrite buffered_write; + // Expected value of CanBatchResult.can_batch when batching |buffered_write|. + bool can_batch; + // Expected value of CanBatchResult.must_flush when batching |buffered_write|. + bool must_flush; +}; + +std::vector BatchCriteriaTestData_SizeDecrease() { + const QuicIpAddress self_addr; + const QuicSocketAddress peer_addr; + std::vector test_data_table = { + // clang-format off + // buf_len self_addr peer_addr t_rel can_batch must_flush + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 0, true, false}, + {39, self_addr, peer_addr, 0, true, true}, + {39, self_addr, peer_addr, 0, false, true}, + {1350, self_addr, peer_addr, 0, false, true}, + // clang-format on + }; + return test_data_table; +} + +std::vector BatchCriteriaTestData_SizeIncrease() { + const QuicIpAddress self_addr; + const QuicSocketAddress peer_addr; + std::vector test_data_table = { + // clang-format off + // buf_len self_addr peer_addr t_rel can_batch must_flush + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 0, true, false}, + {1351, self_addr, peer_addr, 0, false, true}, + // clang-format on + }; + return test_data_table; +} + +std::vector BatchCriteriaTestData_AddressChange() { + const QuicIpAddress self_addr1 = QuicIpAddress::Loopback4(); + const QuicIpAddress self_addr2 = QuicIpAddress::Loopback6(); + const QuicSocketAddress peer_addr1(self_addr1, 666); + const QuicSocketAddress peer_addr2(self_addr1, 777); + const QuicSocketAddress peer_addr3(self_addr2, 666); + const QuicSocketAddress peer_addr4(self_addr2, 777); + std::vector test_data_table = { + // clang-format off + // buf_len self_addr peer_addr t_rel can_batch must_flush + {1350, self_addr1, peer_addr1, 0, true, false}, + {1350, self_addr1, peer_addr1, 0, true, false}, + {1350, self_addr1, peer_addr1, 0, true, false}, + {1350, self_addr2, peer_addr1, 0, false, true}, + {1350, self_addr1, peer_addr2, 0, false, true}, + {1350, self_addr1, peer_addr3, 0, false, true}, + {1350, self_addr1, peer_addr4, 0, false, true}, + {1350, self_addr1, peer_addr4, 0, false, true}, + // clang-format on + }; + return test_data_table; +} + +std::vector BatchCriteriaTestData_ReleaseTime1() { + const QuicIpAddress self_addr; + const QuicSocketAddress peer_addr; + std::vector test_data_table = { + // clang-format off + // buf_len self_addr peer_addr t_rel can_batch must_flush + {1350, self_addr, peer_addr, 5, true, false}, + {1350, self_addr, peer_addr, 5, true, false}, + {1350, self_addr, peer_addr, 5, true, false}, + {1350, self_addr, peer_addr, 9, false, true}, + // clang-format on + }; + return test_data_table; +} + +std::vector BatchCriteriaTestData_ReleaseTime2() { + const QuicIpAddress self_addr; + const QuicSocketAddress peer_addr; + std::vector test_data_table = { + // clang-format off + // buf_len self_addr peer_addr t_rel can_batch must_flush + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 0, true, false}, + {1350, self_addr, peer_addr, 9, false, true}, + // clang-format on + }; + return test_data_table; +} + +std::vector BatchCriteriaTestData_MaxSegments( + size_t gso_size) { + const QuicIpAddress self_addr; + const QuicSocketAddress peer_addr; + std::vector test_data_table; + size_t max_segments = TestQuicGsoBatchWriter::MaxSegments(gso_size); + for (size_t i = 0; i < max_segments; ++i) { + bool is_last_in_batch = (i + 1 == max_segments); + test_data_table.push_back({gso_size, self_addr, peer_addr, + /*release_time=*/0, true, is_last_in_batch}); + } + test_data_table.push_back( + {gso_size, self_addr, peer_addr, /*release_time=*/0, false, true}); + return test_data_table; +} + +class QuicGsoBatchWriterTest : public QuicTest { + protected: + WriteResult WritePacket(QuicGsoBatchWriter* writer, size_t packet_size) { + return writer->WritePacket(&packet_buffer_[0], packet_size, self_address_, + peer_address_, nullptr); + } + + WriteResult WritePacketWithOptions(QuicGsoBatchWriter* writer, + PerPacketOptions* options) { + return writer->WritePacket(&packet_buffer_[0], 1350, self_address_, + peer_address_, options); + } + + QuicIpAddress self_address_ = QuicIpAddress::Any4(); + QuicSocketAddress peer_address_{QuicIpAddress::Any4(), 443}; + char packet_buffer_[1500]; + StrictMock mock_syscalls_; + ScopedGlobalSyscallWrapperOverride syscall_override_{&mock_syscalls_}; +}; + +TEST_F(QuicGsoBatchWriterTest, BatchCriteria) { + std::unique_ptr writer; + + std::vector> test_data_tables; + test_data_tables.emplace_back(BatchCriteriaTestData_SizeDecrease()); + test_data_tables.emplace_back(BatchCriteriaTestData_SizeIncrease()); + test_data_tables.emplace_back(BatchCriteriaTestData_AddressChange()); + test_data_tables.emplace_back(BatchCriteriaTestData_ReleaseTime1()); + test_data_tables.emplace_back(BatchCriteriaTestData_ReleaseTime2()); + test_data_tables.emplace_back(BatchCriteriaTestData_MaxSegments(1)); + test_data_tables.emplace_back(BatchCriteriaTestData_MaxSegments(2)); + test_data_tables.emplace_back(BatchCriteriaTestData_MaxSegments(1350)); + + for (size_t i = 0; i < test_data_tables.size(); ++i) { + writer = TestQuicGsoBatchWriter::NewInstanceWithReleaseTimeSupport(); + + const auto& test_data_table = test_data_tables[i]; + for (size_t j = 0; j < test_data_table.size(); ++j) { + const BatchCriteriaTestData& test_data = test_data_table[j]; + SCOPED_TRACE(testing::Message() << "i=" << i << ", j=" << j); + TestPerPacketOptions options; + options.release_time_delay = QuicTime::Delta::FromMicroseconds( + test_data.buffered_write.release_time); + TestQuicGsoBatchWriter::CanBatchResult result = writer->CanBatch( + test_data.buffered_write.buffer, test_data.buffered_write.buf_len, + test_data.buffered_write.self_address, + test_data.buffered_write.peer_address, &options, + test_data.buffered_write.release_time); + + ASSERT_EQ(test_data.can_batch, result.can_batch); + ASSERT_EQ(test_data.must_flush, result.must_flush); + + if (result.can_batch) { + ASSERT_TRUE(writer->batch_buffer() + .PushBufferedWrite( + test_data.buffered_write.buffer, + test_data.buffered_write.buf_len, + test_data.buffered_write.self_address, + test_data.buffered_write.peer_address, &options, + test_data.buffered_write.release_time) + .succeeded); + } + } + } +} + +TEST_F(QuicGsoBatchWriterTest, WriteSuccess) { + TestQuicGsoBatchWriter writer(/*fd=*/-1); + + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 1000)); + + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(1100u, PacketLength(msg)); + return 1100; + })); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 1100), WritePacket(&writer, 100)); + ASSERT_EQ(0u, writer.batch_buffer().SizeInUse()); + ASSERT_EQ(0u, writer.buffered_writes().size()); +} + +TEST_F(QuicGsoBatchWriterTest, WriteBlockDataNotBuffered) { + TestQuicGsoBatchWriter writer(/*fd=*/-1); + + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(200u, PacketLength(msg)); + errno = EWOULDBLOCK; + return -1; + })); + ASSERT_EQ(WriteResult(WRITE_STATUS_BLOCKED, EWOULDBLOCK), + WritePacket(&writer, 150)); + ASSERT_EQ(200u, writer.batch_buffer().SizeInUse()); + ASSERT_EQ(2u, writer.buffered_writes().size()); +} + +TEST_F(QuicGsoBatchWriterTest, WriteBlockDataBuffered) { + TestQuicGsoBatchWriter writer(/*fd=*/-1); + + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(250u, PacketLength(msg)); + errno = EWOULDBLOCK; + return -1; + })); + ASSERT_EQ(WriteResult(WRITE_STATUS_BLOCKED_DATA_BUFFERED, EWOULDBLOCK), + WritePacket(&writer, 50)); + + EXPECT_TRUE(writer.IsWriteBlocked()); + + ASSERT_EQ(250u, writer.batch_buffer().SizeInUse()); + ASSERT_EQ(3u, writer.buffered_writes().size()); +} + +TEST_F(QuicGsoBatchWriterTest, WriteErrorWithoutDataBuffered) { + TestQuicGsoBatchWriter writer(/*fd=*/-1); + + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(200u, PacketLength(msg)); + errno = EPERM; + return -1; + })); + WriteResult error_result = WritePacket(&writer, 150); + ASSERT_EQ(WriteResult(WRITE_STATUS_ERROR, EPERM), error_result); + + ASSERT_EQ(3u, error_result.dropped_packets); + ASSERT_EQ(0u, writer.batch_buffer().SizeInUse()); + ASSERT_EQ(0u, writer.buffered_writes().size()); +} + +TEST_F(QuicGsoBatchWriterTest, WriteErrorAfterDataBuffered) { + TestQuicGsoBatchWriter writer(/*fd=*/-1); + + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(250u, PacketLength(msg)); + errno = EPERM; + return -1; + })); + WriteResult error_result = WritePacket(&writer, 50); + ASSERT_EQ(WriteResult(WRITE_STATUS_ERROR, EPERM), error_result); + + ASSERT_EQ(3u, error_result.dropped_packets); + ASSERT_EQ(0u, writer.batch_buffer().SizeInUse()); + ASSERT_EQ(0u, writer.buffered_writes().size()); +} + +TEST_F(QuicGsoBatchWriterTest, FlushError) { + TestQuicGsoBatchWriter writer(/*fd=*/-1); + + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 0), WritePacket(&writer, 100)); + + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(200u, PacketLength(msg)); + errno = EINVAL; + return -1; + })); + WriteResult error_result = writer.Flush(); + ASSERT_EQ(WriteResult(WRITE_STATUS_ERROR, EINVAL), error_result); + + ASSERT_EQ(2u, error_result.dropped_packets); + ASSERT_EQ(0u, writer.batch_buffer().SizeInUse()); + ASSERT_EQ(0u, writer.buffered_writes().size()); +} + +TEST_F(QuicGsoBatchWriterTest, ReleaseTimeNullOptions) { + auto writer = TestQuicGsoBatchWriter::NewInstanceWithReleaseTimeSupport(); + EXPECT_EQ(0u, writer->GetReleaseTime(nullptr).actual_release_time); +} + +TEST_F(QuicGsoBatchWriterTest, ReleaseTime) { + const WriteResult write_buffered(WRITE_STATUS_OK, 0); + + auto writer = TestQuicGsoBatchWriter::NewInstanceWithReleaseTimeSupport(); + + TestPerPacketOptions options; + EXPECT_TRUE(options.release_time_delay.IsZero()); + EXPECT_FALSE(options.allow_burst); + EXPECT_EQ(MillisToNanos(1), + writer->GetReleaseTime(&options).actual_release_time); + + // The 1st packet has no delay. + WriteResult result = WritePacketWithOptions(writer.get(), &options); + ASSERT_EQ(write_buffered, result); + EXPECT_EQ(MillisToNanos(1), writer->buffered_writes().back().release_time); + EXPECT_EQ(result.send_time_offset, QuicTime::Delta::Zero()); + + // The 2nd packet has some delay, but allows burst. + options.release_time_delay = QuicTime::Delta::FromMilliseconds(3); + options.allow_burst = true; + result = WritePacketWithOptions(writer.get(), &options); + ASSERT_EQ(write_buffered, result); + EXPECT_EQ(MillisToNanos(1), writer->buffered_writes().back().release_time); + EXPECT_EQ(result.send_time_offset, QuicTime::Delta::FromMilliseconds(-3)); + + // The 3rd packet has more delay and does not allow burst. + // The first 2 packets are flushed due to different release time. + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(2700u, PacketLength(msg)); + errno = 0; + return 0; + })); + options.release_time_delay = QuicTime::Delta::FromMilliseconds(5); + options.allow_burst = false; + result = WritePacketWithOptions(writer.get(), &options); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 2700), result); + EXPECT_EQ(MillisToNanos(6), writer->buffered_writes().back().release_time); + EXPECT_EQ(result.send_time_offset, QuicTime::Delta::Zero()); + + // The 4th packet has same delay, but allows burst. + options.allow_burst = true; + result = WritePacketWithOptions(writer.get(), &options); + ASSERT_EQ(write_buffered, result); + EXPECT_EQ(MillisToNanos(6), writer->buffered_writes().back().release_time); + EXPECT_EQ(result.send_time_offset, QuicTime::Delta::Zero()); + + // The 5th packet has same delay, allows burst, but is shorter. + // Packets 3,4 and 5 are flushed. + EXPECT_CALL(mock_syscalls_, Sendmsg(_, _, _)) + .WillOnce(Invoke([](int /*sockfd*/, const msghdr* msg, int /*flags*/) { + EXPECT_EQ(3000u, PacketLength(msg)); + errno = 0; + return 0; + })); + options.allow_burst = true; + EXPECT_EQ(MillisToNanos(6), + writer->GetReleaseTime(&options).actual_release_time); + ASSERT_EQ(WriteResult(WRITE_STATUS_OK, 3000), + writer->WritePacket(&packet_buffer_[0], 300, self_address_, + peer_address_, &options)); + EXPECT_TRUE(writer->buffered_writes().empty()); + + // Pretend 1ms has elapsed and the 6th packet has 1ms less delay. In other + // words, the release time should still be the same as packets 3-5. + writer->ForceReleaseTimeMs(2); + options.release_time_delay = QuicTime::Delta::FromMilliseconds(4); + result = WritePacketWithOptions(writer.get(), &options); + ASSERT_EQ(write_buffered, result); + EXPECT_EQ(MillisToNanos(6), writer->buffered_writes().back().release_time); + EXPECT_EQ(result.send_time_offset, QuicTime::Delta::Zero()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc b/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc new file mode 100644 index 000000000000..8568e26f7c2b --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h" + +namespace quic { + +QuicSendmmsgBatchWriter::QuicSendmmsgBatchWriter( + std::unique_ptr batch_buffer, int fd) + : QuicUdpBatchWriter(std::move(batch_buffer), fd) {} + +QuicSendmmsgBatchWriter::CanBatchResult QuicSendmmsgBatchWriter::CanBatch( + const char* /*buffer*/, size_t /*buf_len*/, + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, + const PerPacketOptions* /*options*/, uint64_t /*release_time*/) const { + return CanBatchResult(/*can_batch=*/true, /*must_flush=*/false); +} + +QuicSendmmsgBatchWriter::FlushImplResult QuicSendmmsgBatchWriter::FlushImpl() { + return InternalFlushImpl( + kCmsgSpaceForIp, + [](QuicMMsgHdr* mhdr, int i, const BufferedWrite& buffered_write) { + mhdr->SetIpInNextCmsg(i, buffered_write.self_address); + }); +} + +QuicSendmmsgBatchWriter::FlushImplResult +QuicSendmmsgBatchWriter::InternalFlushImpl(size_t cmsg_space, + const CmsgBuilder& cmsg_builder) { + QUICHE_DCHECK(!IsWriteBlocked()); + QUICHE_DCHECK(!buffered_writes().empty()); + + FlushImplResult result = {WriteResult(WRITE_STATUS_OK, 0), + /*num_packets_sent=*/0, /*bytes_written=*/0}; + WriteResult& write_result = result.write_result; + + auto first = buffered_writes().cbegin(); + const auto last = buffered_writes().cend(); + while (first != last) { + QuicMMsgHdr mhdr(first, last, cmsg_space, cmsg_builder); + + int num_packets_sent; + write_result = QuicLinuxSocketUtils::WriteMultiplePackets( + fd(), &mhdr, &num_packets_sent); + QUIC_DVLOG(1) << "WriteMultiplePackets sent " << num_packets_sent + << " out of " << mhdr.num_msgs() + << " packets. WriteResult=" << write_result; + + if (write_result.status != WRITE_STATUS_OK) { + QUICHE_DCHECK_EQ(0, num_packets_sent); + break; + } else if (num_packets_sent == 0) { + QUIC_BUG(quic_bug_10825_1) + << "WriteMultiplePackets returned OK, but no packets were sent."; + write_result = WriteResult(WRITE_STATUS_ERROR, EIO); + break; + } + + first += num_packets_sent; + + result.num_packets_sent += num_packets_sent; + result.bytes_written += write_result.bytes_written; + } + + // Call PopBufferedWrite() even if write_result.status is not WRITE_STATUS_OK, + // to deal with partial writes. + batch_buffer().PopBufferedWrite(result.num_packets_sent); + + if (write_result.status != WRITE_STATUS_OK) { + return result; + } + + QUIC_BUG_IF(quic_bug_12537_1, !buffered_writes().empty()) + << "All packets should have been written on a successful return"; + write_result.bytes_written = result.bytes_written; + return result; +} + +} // namespace quic diff --git a/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h b/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h new file mode 100644 index 000000000000..04a1b28374cb --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_SENDMMSG_BATCH_WRITER_H_ +#define QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_SENDMMSG_BATCH_WRITER_H_ + +#include "quiche/quic/core/batch_writer/quic_batch_writer_base.h" +#include "quiche/quic/core/quic_linux_socket_utils.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicSendmmsgBatchWriter : public QuicUdpBatchWriter { + public: + QuicSendmmsgBatchWriter(std::unique_ptr batch_buffer, + int fd); + + CanBatchResult CanBatch(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + const PerPacketOptions* options, + uint64_t release_time) const override; + + FlushImplResult FlushImpl() override; + + protected: + using CmsgBuilder = QuicMMsgHdr::ControlBufferInitializer; + FlushImplResult InternalFlushImpl(size_t cmsg_space, + const CmsgBuilder& cmsg_builder); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_BATCH_WRITER_QUIC_SENDMMSG_BATCH_WRITER_H_ diff --git a/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc b/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc new file mode 100644 index 000000000000..c8a213a2ab33 --- /dev/null +++ b/quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer_test.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/batch_writer/quic_sendmmsg_batch_writer.h" + +namespace quic { +namespace test { +namespace { + +// Add tests here. + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/chlo_extractor.cc b/quiche/quic/core/chlo_extractor.cc new file mode 100644 index 000000000000..e55df8a976d0 --- /dev/null +++ b/quiche/quic/core/chlo_extractor.cc @@ -0,0 +1,362 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/chlo_extractor.h" + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" + +namespace quic { + +namespace { + +class ChloFramerVisitor : public QuicFramerVisitorInterface, + public CryptoFramerVisitorInterface { + public: + ChloFramerVisitor(QuicFramer* framer, + const QuicTagVector& create_session_tag_indicators, + ChloExtractor::Delegate* delegate); + + ~ChloFramerVisitor() override = default; + + // QuicFramerVisitorInterface implementation + void OnError(QuicFramer* /*framer*/) override {} + bool OnProtocolVersionMismatch(ParsedQuicVersion version) override; + void OnPacket() override {} + void OnPublicResetPacket(const QuicPublicResetPacket& /*packet*/) override {} + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& /*packet*/) override {} + void OnRetryPacket(QuicConnectionId /*original_connection_id*/, + QuicConnectionId /*new_connection_id*/, + absl::string_view /*retry_token*/, + absl::string_view /*retry_integrity_tag*/, + absl::string_view /*retry_without_tag*/) override {} + bool OnUnauthenticatedPublicHeader(const QuicPacketHeader& header) override; + bool OnUnauthenticatedHeader(const QuicPacketHeader& header) override; + void OnDecryptedPacket(size_t /*length*/, + EncryptionLevel /*level*/) override {} + bool OnPacketHeader(const QuicPacketHeader& header) override; + void OnCoalescedPacket(const QuicEncryptedPacket& packet) override; + void OnUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, + bool has_decryption_key) override; + bool OnStreamFrame(const QuicStreamFrame& frame) override; + bool OnCryptoFrame(const QuicCryptoFrame& frame) override; + bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) override; + bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) override; + bool OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) override; + bool OnAckFrameEnd(QuicPacketNumber start, + const absl::optional& ecn_counts) override; + bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) override; + bool OnPingFrame(const QuicPingFrame& frame) override; + bool OnRstStreamFrame(const QuicRstStreamFrame& frame) override; + bool OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override; + bool OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame) override; + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) override; + bool OnNewTokenFrame(const QuicNewTokenFrame& frame) override; + bool OnStopSendingFrame(const QuicStopSendingFrame& frame) override; + bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) override; + bool OnPathResponseFrame(const QuicPathResponseFrame& frame) override; + bool OnGoAwayFrame(const QuicGoAwayFrame& frame) override; + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override; + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override; + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override; + bool OnBlockedFrame(const QuicBlockedFrame& frame) override; + bool OnPaddingFrame(const QuicPaddingFrame& frame) override; + bool OnMessageFrame(const QuicMessageFrame& frame) override; + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) override; + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& farme) override; + void OnPacketComplete() override {} + bool IsValidStatelessResetToken( + const StatelessResetToken& token) const override; + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& /*packet*/) override {} + void OnKeyUpdate(KeyUpdateReason /*reason*/) override; + void OnDecryptedFirstPacketInKeyPhase() override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + + // CryptoFramerVisitorInterface implementation. + void OnError(CryptoFramer* framer) override; + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override; + + // Shared implementation between OnStreamFrame and OnCryptoFrame. + bool OnHandshakeData(absl::string_view data); + + bool found_chlo() { return found_chlo_; } + bool chlo_contains_tags() { return chlo_contains_tags_; } + + private: + QuicFramer* framer_; + const QuicTagVector& create_session_tag_indicators_; + ChloExtractor::Delegate* delegate_; + bool found_chlo_; + bool chlo_contains_tags_; + QuicConnectionId connection_id_; +}; + +ChloFramerVisitor::ChloFramerVisitor( + QuicFramer* framer, const QuicTagVector& create_session_tag_indicators, + ChloExtractor::Delegate* delegate) + : framer_(framer), + create_session_tag_indicators_(create_session_tag_indicators), + delegate_(delegate), + found_chlo_(false), + chlo_contains_tags_(false), + connection_id_(EmptyQuicConnectionId()) {} + +bool ChloFramerVisitor::OnProtocolVersionMismatch(ParsedQuicVersion version) { + if (!framer_->IsSupportedVersion(version)) { + return false; + } + framer_->set_version(version); + return true; +} + +bool ChloFramerVisitor::OnUnauthenticatedPublicHeader( + const QuicPacketHeader& header) { + connection_id_ = header.destination_connection_id; + // QuicFramer creates a NullEncrypter and NullDecrypter at level + // ENCRYPTION_INITIAL. While those are the correct ones to use with some + // versions of QUIC, others use the IETF-style initial crypters, so those need + // to be created and installed. + framer_->SetInitialObfuscators(header.destination_connection_id); + return true; +} +bool ChloFramerVisitor::OnUnauthenticatedHeader( + const QuicPacketHeader& /*header*/) { + return true; +} +bool ChloFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) { + return true; +} + +void ChloFramerVisitor::OnCoalescedPacket( + const QuicEncryptedPacket& /*packet*/) {} + +void ChloFramerVisitor::OnUndecryptablePacket( + const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/, + bool /*has_decryption_key*/) {} + +bool ChloFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) { + if (QuicVersionUsesCryptoFrames(framer_->transport_version())) { + // CHLO will be sent in CRYPTO frames in v47 and above. + return false; + } + absl::string_view data(frame.data_buffer, frame.data_length); + if (QuicUtils::IsCryptoStreamId(framer_->transport_version(), + frame.stream_id) && + frame.offset == 0 && absl::StartsWith(data, "CHLO")) { + return OnHandshakeData(data); + } + return true; +} + +bool ChloFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& frame) { + if (!QuicVersionUsesCryptoFrames(framer_->transport_version())) { + // CHLO will be in stream frames before v47. + return false; + } + absl::string_view data(frame.data_buffer, frame.data_length); + if (frame.offset == 0 && absl::StartsWith(data, "CHLO")) { + return OnHandshakeData(data); + } + return true; +} + +bool ChloFramerVisitor::OnHandshakeData(absl::string_view data) { + CryptoFramer crypto_framer; + crypto_framer.set_visitor(this); + if (!crypto_framer.ProcessInput(data)) { + return false; + } + // Interrogate the crypto framer and see if there are any + // intersecting tags between what we saw in the maybe-CHLO and the + // indicator set. + for (const QuicTag tag : create_session_tag_indicators_) { + if (crypto_framer.HasTag(tag)) { + chlo_contains_tags_ = true; + } + } + if (chlo_contains_tags_ && delegate_) { + // Unfortunately, because this is a partial CHLO, + // OnHandshakeMessage was never called, so the ALPN was never + // extracted. Fake it up a bit and send it to the delegate so that + // the correct dispatch can happen. + crypto_framer.ForceHandshake(); + } + + return true; +} + +bool ChloFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/, + QuicTime::Delta /*ack_delay_time*/) { + return true; +} + +bool ChloFramerVisitor::OnAckRange(QuicPacketNumber /*start*/, + QuicPacketNumber /*end*/) { + return true; +} + +bool ChloFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/, + QuicTime /*timestamp*/) { + return true; +} + +bool ChloFramerVisitor::OnAckFrameEnd( + QuicPacketNumber /*start*/, + const absl::optional& /*ecn_counts*/) { + return true; +} + +bool ChloFramerVisitor::OnStopWaitingFrame( + const QuicStopWaitingFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnConnectionCloseFrame( + const QuicConnectionCloseFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnStopSendingFrame( + const QuicStopSendingFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnPathChallengeFrame( + const QuicPathChallengeFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnPathResponseFrame( + const QuicPathResponseFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnWindowUpdateFrame( + const QuicWindowUpdateFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnHandshakeDoneFrame( + const QuicHandshakeDoneFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnAckFrequencyFrame( + const QuicAckFrequencyFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::IsValidStatelessResetToken( + const StatelessResetToken& /*token*/) const { + return false; +} + +bool ChloFramerVisitor::OnMaxStreamsFrame( + const QuicMaxStreamsFrame& /*frame*/) { + return true; +} + +bool ChloFramerVisitor::OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& /*frame*/) { + return true; +} + +void ChloFramerVisitor::OnKeyUpdate(KeyUpdateReason /*reason*/) {} + +void ChloFramerVisitor::OnDecryptedFirstPacketInKeyPhase() {} + +std::unique_ptr +ChloFramerVisitor::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return nullptr; +} + +std::unique_ptr +ChloFramerVisitor::CreateCurrentOneRttEncrypter() { + return nullptr; +} + +void ChloFramerVisitor::OnError(CryptoFramer* /*framer*/) {} + +void ChloFramerVisitor::OnHandshakeMessage( + const CryptoHandshakeMessage& message) { + if (delegate_ != nullptr) { + delegate_->OnChlo(framer_->transport_version(), connection_id_, message); + } + found_chlo_ = true; +} + +} // namespace + +// static +bool ChloExtractor::Extract(const QuicEncryptedPacket& packet, + ParsedQuicVersion version, + const QuicTagVector& create_session_tag_indicators, + Delegate* delegate, uint8_t connection_id_length) { + QUIC_DVLOG(1) << "Extracting CHLO using version " << version; + QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_SERVER, + connection_id_length); + ChloFramerVisitor visitor(&framer, create_session_tag_indicators, delegate); + framer.set_visitor(&visitor); + if (!framer.ProcessPacket(packet)) { + return false; + } + return visitor.found_chlo() || visitor.chlo_contains_tags(); +} + +} // namespace quic diff --git a/quiche/quic/core/chlo_extractor.h b/quiche/quic/core/chlo_extractor.h new file mode 100644 index 000000000000..fd83233c97ff --- /dev/null +++ b/quiche/quic/core/chlo_extractor.h @@ -0,0 +1,44 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CHLO_EXTRACTOR_H_ +#define QUICHE_QUIC_CORE_CHLO_EXTRACTOR_H_ + +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +// A utility for extracting QUIC Client Hello messages from packets, +// without needing to spin up a full QuicSession. +class QUIC_NO_EXPORT ChloExtractor { + public: + class QUIC_NO_EXPORT Delegate { + public: + virtual ~Delegate() {} + + // Called when a CHLO message is found in the packets. + virtual void OnChlo(QuicTransportVersion version, + QuicConnectionId connection_id, + const CryptoHandshakeMessage& chlo) = 0; + }; + + // Extracts a CHLO message from |packet| and invokes the OnChlo + // method of |delegate|. Return true if a CHLO message was found, + // and false otherwise. If non-empty, + // |create_session_tag_indicators| contains a list of QUIC tags that + // if found will result in the session being created early, to + // enable support for multi-packet CHLOs. + static bool Extract(const QuicEncryptedPacket& packet, + ParsedQuicVersion version, + const QuicTagVector& create_session_tag_indicators, + Delegate* delegate, uint8_t connection_id_length); + + ChloExtractor(const ChloExtractor&) = delete; + ChloExtractor operator=(const ChloExtractor&) = delete; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CHLO_EXTRACTOR_H_ diff --git a/quiche/quic/core/chlo_extractor_test.cc b/quiche/quic/core/chlo_extractor_test.cc new file mode 100644 index 000000000000..6b49fdcab6bd --- /dev/null +++ b/quiche/quic/core/chlo_extractor_test.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/chlo_extractor.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/first_flight.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +class TestDelegate : public ChloExtractor::Delegate { + public: + TestDelegate() = default; + ~TestDelegate() override = default; + + // ChloExtractor::Delegate implementation + void OnChlo(QuicTransportVersion version, QuicConnectionId connection_id, + const CryptoHandshakeMessage& chlo) override { + version_ = version; + connection_id_ = connection_id; + chlo_ = chlo.DebugString(); + absl::string_view alpn_value; + if (chlo.GetStringPiece(kALPN, &alpn_value)) { + alpn_ = std::string(alpn_value); + } + } + + QuicConnectionId connection_id() const { return connection_id_; } + QuicTransportVersion transport_version() const { return version_; } + const std::string& chlo() const { return chlo_; } + const std::string& alpn() const { return alpn_; } + + private: + QuicConnectionId connection_id_; + QuicTransportVersion version_; + std::string chlo_; + std::string alpn_; +}; + +class ChloExtractorTest : public QuicTestWithParam { + public: + ChloExtractorTest() : version_(GetParam()) {} + + void MakePacket(absl::string_view data, bool munge_offset, + bool munge_stream_id) { + QuicPacketHeader header; + header.destination_connection_id = TestConnectionId(); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + header.version_flag = true; + header.version = version_; + header.reset_flag = false; + header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header.packet_number = QuicPacketNumber(1); + if (version_.HasLongHeaderLengths()) { + header.retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + QuicFrames frames; + size_t offset = 0; + if (munge_offset) { + offset++; + } + QuicFramer framer(SupportedVersions(version_), QuicTime::Zero(), + Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength); + framer.SetInitialObfuscators(TestConnectionId()); + if (!version_.UsesCryptoFrames() || munge_stream_id) { + QuicStreamId stream_id = + QuicUtils::GetCryptoStreamId(version_.transport_version); + if (munge_stream_id) { + stream_id++; + } + frames.push_back( + QuicFrame(QuicStreamFrame(stream_id, false, offset, data))); + } else { + frames.push_back( + QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, offset, data))); + } + std::unique_ptr packet( + BuildUnsizedDataPacket(&framer, header, frames)); + EXPECT_TRUE(packet != nullptr); + size_t encrypted_length = + framer.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, *packet, + buffer_, ABSL_ARRAYSIZE(buffer_)); + ASSERT_NE(0u, encrypted_length); + packet_ = std::make_unique(buffer_, encrypted_length); + EXPECT_TRUE(packet_ != nullptr); + DeleteFrames(&frames); + } + + protected: + ParsedQuicVersion version_; + TestDelegate delegate_; + std::unique_ptr packet_; + char buffer_[kMaxOutgoingPacketSize]; +}; + +INSTANTIATE_TEST_SUITE_P( + ChloExtractorTests, ChloExtractorTest, + ::testing::ValuesIn(AllSupportedVersionsWithQuicCrypto()), + ::testing::PrintToStringParamName()); + +TEST_P(ChloExtractorTest, FindsValidChlo) { + CryptoHandshakeMessage client_hello; + client_hello.set_tag(kCHLO); + + std::string client_hello_str(client_hello.GetSerialized().AsStringPiece()); + + MakePacket(client_hello_str, /*munge_offset=*/false, + /*munge_stream_id=*/false); + EXPECT_TRUE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); + EXPECT_EQ(version_.transport_version, delegate_.transport_version()); + EXPECT_EQ(TestConnectionId(), delegate_.connection_id()); + EXPECT_EQ(client_hello.DebugString(), delegate_.chlo()); +} + +TEST_P(ChloExtractorTest, DoesNotFindValidChloOnWrongStream) { + if (version_.UsesCryptoFrames()) { + // When crypto frames are in use we do not use stream frames. + return; + } + CryptoHandshakeMessage client_hello; + client_hello.set_tag(kCHLO); + + std::string client_hello_str(client_hello.GetSerialized().AsStringPiece()); + MakePacket(client_hello_str, + /*munge_offset=*/false, /*munge_stream_id=*/true); + EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); +} + +TEST_P(ChloExtractorTest, DoesNotFindValidChloOnWrongOffset) { + CryptoHandshakeMessage client_hello; + client_hello.set_tag(kCHLO); + + std::string client_hello_str(client_hello.GetSerialized().AsStringPiece()); + MakePacket(client_hello_str, /*munge_offset=*/true, + /*munge_stream_id=*/false); + EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); +} + +TEST_P(ChloExtractorTest, DoesNotFindInvalidChlo) { + MakePacket("foo", /*munge_offset=*/false, + /*munge_stream_id=*/false); + EXPECT_FALSE(ChloExtractor::Extract(*packet_, version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); +} + +TEST_P(ChloExtractorTest, FirstFlight) { + std::vector> packets = + GetFirstFlightOfPackets(version_); + ASSERT_EQ(packets.size(), 1u); + EXPECT_TRUE(ChloExtractor::Extract(*packets[0], version_, {}, &delegate_, + kQuicDefaultConnectionIdLength)); + EXPECT_EQ(version_.transport_version, delegate_.transport_version()); + EXPECT_EQ(TestConnectionId(), delegate_.connection_id()); + EXPECT_EQ(AlpnForVersion(version_), delegate_.alpn()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bandwidth_sampler.cc b/quiche/quic/core/congestion_control/bandwidth_sampler.cc new file mode 100644 index 000000000000..fe42e083e589 --- /dev/null +++ b/quiche/quic/core/congestion_control/bandwidth_sampler.cc @@ -0,0 +1,583 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" + +#include + +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +std::ostream& operator<<(std::ostream& os, const SendTimeState& s) { + os << "{valid:" << s.is_valid << ", app_limited:" << s.is_app_limited + << ", total_sent:" << s.total_bytes_sent + << ", total_acked:" << s.total_bytes_acked + << ", total_lost:" << s.total_bytes_lost + << ", inflight:" << s.bytes_in_flight << "}"; + return os; +} + +QuicByteCount MaxAckHeightTracker::Update( + QuicBandwidth bandwidth_estimate, bool is_new_max_bandwidth, + QuicRoundTripCount round_trip_count, + QuicPacketNumber last_sent_packet_number, + QuicPacketNumber last_acked_packet_number, QuicTime ack_time, + QuicByteCount bytes_acked) { + bool force_new_epoch = false; + + if (reduce_extra_acked_on_bandwidth_increase_ && is_new_max_bandwidth) { + // Save and clear existing entries. + ExtraAckedEvent best = max_ack_height_filter_.GetBest(); + ExtraAckedEvent second_best = max_ack_height_filter_.GetSecondBest(); + ExtraAckedEvent third_best = max_ack_height_filter_.GetThirdBest(); + max_ack_height_filter_.Clear(); + + // Reinsert the heights into the filter after recalculating. + QuicByteCount expected_bytes_acked = bandwidth_estimate * best.time_delta; + if (expected_bytes_acked < best.bytes_acked) { + best.extra_acked = best.bytes_acked - expected_bytes_acked; + max_ack_height_filter_.Update(best, best.round); + } + expected_bytes_acked = bandwidth_estimate * second_best.time_delta; + if (expected_bytes_acked < second_best.bytes_acked) { + QUICHE_DCHECK_LE(best.round, second_best.round); + second_best.extra_acked = second_best.bytes_acked - expected_bytes_acked; + max_ack_height_filter_.Update(second_best, second_best.round); + } + expected_bytes_acked = bandwidth_estimate * third_best.time_delta; + if (expected_bytes_acked < third_best.bytes_acked) { + QUICHE_DCHECK_LE(second_best.round, third_best.round); + third_best.extra_acked = third_best.bytes_acked - expected_bytes_acked; + max_ack_height_filter_.Update(third_best, third_best.round); + } + } + + // If any packet sent after the start of the epoch has been acked, start a new + // epoch. + if (start_new_aggregation_epoch_after_full_round_ && + last_sent_packet_number_before_epoch_.IsInitialized() && + last_acked_packet_number.IsInitialized() && + last_acked_packet_number > last_sent_packet_number_before_epoch_) { + QUIC_DVLOG(3) << "Force starting a new aggregation epoch. " + "last_sent_packet_number_before_epoch_:" + << last_sent_packet_number_before_epoch_ + << ", last_acked_packet_number:" << last_acked_packet_number; + if (reduce_extra_acked_on_bandwidth_increase_) { + QUIC_BUG(quic_bwsampler_46) + << "A full round of aggregation should never " + << "pass with startup_include_extra_acked(B204) enabled."; + } + force_new_epoch = true; + } + if (aggregation_epoch_start_time_ == QuicTime::Zero() || force_new_epoch) { + aggregation_epoch_bytes_ = bytes_acked; + aggregation_epoch_start_time_ = ack_time; + last_sent_packet_number_before_epoch_ = last_sent_packet_number; + ++num_ack_aggregation_epochs_; + return 0; + } + + // Compute how many bytes are expected to be delivered, assuming max bandwidth + // is correct. + QuicTime::Delta aggregation_delta = ack_time - aggregation_epoch_start_time_; + QuicByteCount expected_bytes_acked = bandwidth_estimate * aggregation_delta; + // Reset the current aggregation epoch as soon as the ack arrival rate is less + // than or equal to the max bandwidth. + if (aggregation_epoch_bytes_ <= + ack_aggregation_bandwidth_threshold_ * expected_bytes_acked) { + QUIC_DVLOG(3) << "Starting a new aggregation epoch because " + "aggregation_epoch_bytes_ " + << aggregation_epoch_bytes_ + << " is smaller than expected. " + "ack_aggregation_bandwidth_threshold_:" + << ack_aggregation_bandwidth_threshold_ + << ", expected_bytes_acked:" << expected_bytes_acked + << ", bandwidth_estimate:" << bandwidth_estimate + << ", aggregation_duration:" << aggregation_delta + << ", new_aggregation_epoch:" << ack_time + << ", new_aggregation_bytes_acked:" << bytes_acked; + // Reset to start measuring a new aggregation epoch. + aggregation_epoch_bytes_ = bytes_acked; + aggregation_epoch_start_time_ = ack_time; + last_sent_packet_number_before_epoch_ = last_sent_packet_number; + ++num_ack_aggregation_epochs_; + return 0; + } + + aggregation_epoch_bytes_ += bytes_acked; + + // Compute how many extra bytes were delivered vs max bandwidth. + QuicByteCount extra_bytes_acked = + aggregation_epoch_bytes_ - expected_bytes_acked; + QUIC_DVLOG(3) << "Updating MaxAckHeight. ack_time:" << ack_time + << ", last sent packet:" << last_sent_packet_number + << ", bandwidth_estimate:" << bandwidth_estimate + << ", bytes_acked:" << bytes_acked + << ", expected_bytes_acked:" << expected_bytes_acked + << ", aggregation_epoch_bytes_:" << aggregation_epoch_bytes_ + << ", extra_bytes_acked:" << extra_bytes_acked; + ExtraAckedEvent new_event; + new_event.extra_acked = extra_bytes_acked; + new_event.bytes_acked = aggregation_epoch_bytes_; + new_event.time_delta = aggregation_delta; + max_ack_height_filter_.Update(new_event, round_trip_count); + return extra_bytes_acked; +} + +BandwidthSampler::BandwidthSampler( + const QuicUnackedPacketMap* unacked_packet_map, + QuicRoundTripCount max_height_tracker_window_length) + : total_bytes_sent_(0), + total_bytes_acked_(0), + total_bytes_lost_(0), + total_bytes_neutered_(0), + total_bytes_sent_at_last_acked_packet_(0), + last_acked_packet_sent_time_(QuicTime::Zero()), + last_acked_packet_ack_time_(QuicTime::Zero()), + is_app_limited_(true), + connection_state_map_(), + max_tracked_packets_(GetQuicFlag(quic_max_tracked_packet_count)), + unacked_packet_map_(unacked_packet_map), + max_ack_height_tracker_(max_height_tracker_window_length), + total_bytes_acked_after_last_ack_event_(0), + overestimate_avoidance_(false), + limit_max_ack_height_tracker_by_send_rate_(false) {} + +BandwidthSampler::BandwidthSampler(const BandwidthSampler& other) + : total_bytes_sent_(other.total_bytes_sent_), + total_bytes_acked_(other.total_bytes_acked_), + total_bytes_lost_(other.total_bytes_lost_), + total_bytes_neutered_(other.total_bytes_neutered_), + total_bytes_sent_at_last_acked_packet_( + other.total_bytes_sent_at_last_acked_packet_), + last_acked_packet_sent_time_(other.last_acked_packet_sent_time_), + last_acked_packet_ack_time_(other.last_acked_packet_ack_time_), + last_sent_packet_(other.last_sent_packet_), + last_acked_packet_(other.last_acked_packet_), + is_app_limited_(other.is_app_limited_), + end_of_app_limited_phase_(other.end_of_app_limited_phase_), + connection_state_map_(other.connection_state_map_), + recent_ack_points_(other.recent_ack_points_), + a0_candidates_(other.a0_candidates_), + max_tracked_packets_(other.max_tracked_packets_), + unacked_packet_map_(other.unacked_packet_map_), + max_ack_height_tracker_(other.max_ack_height_tracker_), + total_bytes_acked_after_last_ack_event_( + other.total_bytes_acked_after_last_ack_event_), + overestimate_avoidance_(other.overestimate_avoidance_), + limit_max_ack_height_tracker_by_send_rate_( + other.limit_max_ack_height_tracker_by_send_rate_) {} + +void BandwidthSampler::EnableOverestimateAvoidance() { + if (overestimate_avoidance_) { + return; + } + + overestimate_avoidance_ = true; + // TODO(wub): Change the default value of + // --quic_ack_aggregation_bandwidth_threshold to 2.0. + max_ack_height_tracker_.SetAckAggregationBandwidthThreshold(2.0); +} + +BandwidthSampler::~BandwidthSampler() {} + +void BandwidthSampler::OnPacketSent( + QuicTime sent_time, QuicPacketNumber packet_number, QuicByteCount bytes, + QuicByteCount bytes_in_flight, + HasRetransmittableData has_retransmittable_data) { + last_sent_packet_ = packet_number; + + if (has_retransmittable_data != HAS_RETRANSMITTABLE_DATA) { + return; + } + + total_bytes_sent_ += bytes; + + // If there are no packets in flight, the time at which the new transmission + // opens can be treated as the A_0 point for the purpose of bandwidth + // sampling. This underestimates bandwidth to some extent, and produces some + // artificially low samples for most packets in flight, but it provides with + // samples at important points where we would not have them otherwise, most + // importantly at the beginning of the connection. + if (bytes_in_flight == 0) { + last_acked_packet_ack_time_ = sent_time; + if (overestimate_avoidance_) { + recent_ack_points_.Clear(); + recent_ack_points_.Update(sent_time, total_bytes_acked_); + a0_candidates_.clear(); + a0_candidates_.push_back(recent_ack_points_.MostRecentPoint()); + } + total_bytes_sent_at_last_acked_packet_ = total_bytes_sent_; + + // In this situation ack compression is not a concern, set send rate to + // effectively infinite. + last_acked_packet_sent_time_ = sent_time; + } + + if (!connection_state_map_.IsEmpty() && + packet_number > + connection_state_map_.last_packet() + max_tracked_packets_) { + if (unacked_packet_map_ != nullptr && !unacked_packet_map_->empty()) { + QuicPacketNumber maybe_least_unacked = + unacked_packet_map_->GetLeastUnacked(); + QUIC_BUG(quic_bug_10437_1) + << "BandwidthSampler in-flight packet map has exceeded maximum " + "number of tracked packets(" + << max_tracked_packets_ + << "). First tracked: " << connection_state_map_.first_packet() + << "; last tracked: " << connection_state_map_.last_packet() + << "; entry_slots_used: " << connection_state_map_.entry_slots_used() + << "; number_of_present_entries: " + << connection_state_map_.number_of_present_entries() + << "; packet number: " << packet_number + << "; unacked_map: " << unacked_packet_map_->DebugString() + << "; total_bytes_sent: " << total_bytes_sent_ + << "; total_bytes_acked: " << total_bytes_acked_ + << "; total_bytes_lost: " << total_bytes_lost_ + << "; total_bytes_neutered: " << total_bytes_neutered_ + << "; last_acked_packet_sent_time: " << last_acked_packet_sent_time_ + << "; total_bytes_sent_at_last_acked_packet: " + << total_bytes_sent_at_last_acked_packet_ + << "; least_unacked_packet_info: " + << (unacked_packet_map_->IsUnacked(maybe_least_unacked) + ? unacked_packet_map_ + ->GetTransmissionInfo(maybe_least_unacked) + .DebugString() + : "n/a"); + } else { + QUIC_BUG(quic_bug_10437_2) + << "BandwidthSampler in-flight packet map has exceeded maximum " + "number of tracked packets."; + } + } + + bool success = connection_state_map_.Emplace(packet_number, sent_time, bytes, + bytes_in_flight + bytes, *this); + QUIC_BUG_IF(quic_bug_10437_3, !success) + << "BandwidthSampler failed to insert the packet " + "into the map, most likely because it's already " + "in it."; +} + +void BandwidthSampler::OnPacketNeutered(QuicPacketNumber packet_number) { + connection_state_map_.Remove( + packet_number, [&](const ConnectionStateOnSentPacket& sent_packet) { + QUIC_CODE_COUNT(quic_bandwidth_sampler_packet_neutered); + total_bytes_neutered_ += sent_packet.size; + }); +} + +BandwidthSamplerInterface::CongestionEventSample +BandwidthSampler::OnCongestionEvent(QuicTime ack_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicBandwidth max_bandwidth, + QuicBandwidth est_bandwidth_upper_bound, + QuicRoundTripCount round_trip_count) { + CongestionEventSample event_sample; + + SendTimeState last_lost_packet_send_state; + + for (const LostPacket& packet : lost_packets) { + SendTimeState send_state = + OnPacketLost(packet.packet_number, packet.bytes_lost); + if (send_state.is_valid) { + last_lost_packet_send_state = send_state; + } + } + + if (acked_packets.empty()) { + // Only populate send state for a loss-only event. + event_sample.last_packet_send_state = last_lost_packet_send_state; + return event_sample; + } + + SendTimeState last_acked_packet_send_state; + QuicBandwidth max_send_rate = QuicBandwidth::Zero(); + for (const auto& packet : acked_packets) { + BandwidthSample sample = + OnPacketAcknowledged(ack_time, packet.packet_number); + if (!sample.state_at_send.is_valid) { + continue; + } + + last_acked_packet_send_state = sample.state_at_send; + + if (!sample.rtt.IsZero()) { + event_sample.sample_rtt = std::min(event_sample.sample_rtt, sample.rtt); + } + if (sample.bandwidth > event_sample.sample_max_bandwidth) { + event_sample.sample_max_bandwidth = sample.bandwidth; + event_sample.sample_is_app_limited = sample.state_at_send.is_app_limited; + } + if (!sample.send_rate.IsInfinite()) { + max_send_rate = std::max(max_send_rate, sample.send_rate); + } + const QuicByteCount inflight_sample = + total_bytes_acked() - last_acked_packet_send_state.total_bytes_acked; + if (inflight_sample > event_sample.sample_max_inflight) { + event_sample.sample_max_inflight = inflight_sample; + } + } + + if (!last_lost_packet_send_state.is_valid) { + event_sample.last_packet_send_state = last_acked_packet_send_state; + } else if (!last_acked_packet_send_state.is_valid) { + event_sample.last_packet_send_state = last_lost_packet_send_state; + } else { + // If two packets are inflight and an alarm is armed to lose a packet and it + // wakes up late, then the first of two in flight packets could have been + // acknowledged before the wakeup, which re-evaluates loss detection, and + // could declare the later of the two lost. + event_sample.last_packet_send_state = + lost_packets.back().packet_number > acked_packets.back().packet_number + ? last_lost_packet_send_state + : last_acked_packet_send_state; + } + + bool is_new_max_bandwidth = event_sample.sample_max_bandwidth > max_bandwidth; + max_bandwidth = std::max(max_bandwidth, event_sample.sample_max_bandwidth); + if (limit_max_ack_height_tracker_by_send_rate_) { + max_bandwidth = std::max(max_bandwidth, max_send_rate); + } + // TODO(ianswett): Why is the min being passed in here? + event_sample.extra_acked = + OnAckEventEnd(std::min(est_bandwidth_upper_bound, max_bandwidth), + is_new_max_bandwidth, round_trip_count); + + return event_sample; +} + +QuicByteCount BandwidthSampler::OnAckEventEnd( + QuicBandwidth bandwidth_estimate, bool is_new_max_bandwidth, + QuicRoundTripCount round_trip_count) { + const QuicByteCount newly_acked_bytes = + total_bytes_acked_ - total_bytes_acked_after_last_ack_event_; + + if (newly_acked_bytes == 0) { + return 0; + } + total_bytes_acked_after_last_ack_event_ = total_bytes_acked_; + QuicByteCount extra_acked = max_ack_height_tracker_.Update( + bandwidth_estimate, is_new_max_bandwidth, round_trip_count, + last_sent_packet_, last_acked_packet_, last_acked_packet_ack_time_, + newly_acked_bytes); + // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack + // aggregation epoch, save LessRecentPoint, which is the last ack point of the + // previous epoch, as a A0 candidate. + if (overestimate_avoidance_ && extra_acked == 0) { + a0_candidates_.push_back(recent_ack_points_.LessRecentPoint()); + QUIC_DVLOG(1) << "New a0_candidate:" << a0_candidates_.back(); + } + return extra_acked; +} + +BandwidthSample BandwidthSampler::OnPacketAcknowledged( + QuicTime ack_time, QuicPacketNumber packet_number) { + last_acked_packet_ = packet_number; + ConnectionStateOnSentPacket* sent_packet_pointer = + connection_state_map_.GetEntry(packet_number); + if (sent_packet_pointer == nullptr) { + // See the TODO below. + return BandwidthSample(); + } + BandwidthSample sample = + OnPacketAcknowledgedInner(ack_time, packet_number, *sent_packet_pointer); + return sample; +} + +BandwidthSample BandwidthSampler::OnPacketAcknowledgedInner( + QuicTime ack_time, QuicPacketNumber packet_number, + const ConnectionStateOnSentPacket& sent_packet) { + total_bytes_acked_ += sent_packet.size; + total_bytes_sent_at_last_acked_packet_ = + sent_packet.send_time_state.total_bytes_sent; + last_acked_packet_sent_time_ = sent_packet.sent_time; + last_acked_packet_ack_time_ = ack_time; + if (overestimate_avoidance_) { + recent_ack_points_.Update(ack_time, total_bytes_acked_); + } + + if (is_app_limited_) { + // Exit app-limited phase in two cases: + // (1) end_of_app_limited_phase_ is not initialized, i.e., so far all + // packets are sent while there are buffered packets or pending data. + // (2) The current acked packet is after the sent packet marked as the end + // of the app limit phase. + if (!end_of_app_limited_phase_.IsInitialized() || + packet_number > end_of_app_limited_phase_) { + is_app_limited_ = false; + } + } + + // There might have been no packets acknowledged at the moment when the + // current packet was sent. In that case, there is no bandwidth sample to + // make. + if (sent_packet.last_acked_packet_sent_time == QuicTime::Zero()) { + QUIC_BUG(quic_bug_10437_4) + << "sent_packet.last_acked_packet_sent_time is zero"; + return BandwidthSample(); + } + + // Infinite rate indicates that the sampler is supposed to discard the + // current send rate sample and use only the ack rate. + QuicBandwidth send_rate = QuicBandwidth::Infinite(); + if (sent_packet.sent_time > sent_packet.last_acked_packet_sent_time) { + send_rate = QuicBandwidth::FromBytesAndTimeDelta( + sent_packet.send_time_state.total_bytes_sent - + sent_packet.total_bytes_sent_at_last_acked_packet, + sent_packet.sent_time - sent_packet.last_acked_packet_sent_time); + } + + AckPoint a0; + if (overestimate_avoidance_ && + ChooseA0Point(sent_packet.send_time_state.total_bytes_acked, &a0)) { + QUIC_DVLOG(2) << "Using a0 point: " << a0; + } else { + a0.ack_time = sent_packet.last_acked_packet_ack_time, + a0.total_bytes_acked = sent_packet.send_time_state.total_bytes_acked; + } + + // During the slope calculation, ensure that ack time of the current packet is + // always larger than the time of the previous packet, otherwise division by + // zero or integer underflow can occur. + if (ack_time <= a0.ack_time) { + // TODO(wub): Compare this code count before and after fixing clock jitter + // issue. + if (a0.ack_time == sent_packet.sent_time) { + // This is the 1st packet after quiescense. + QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 1, 2); + } else { + QUIC_CODE_COUNT_N(quic_prev_ack_time_larger_than_current_ack_time, 2, 2); + } + QUIC_LOG_EVERY_N_SEC(ERROR, 60) + << "Time of the previously acked packet:" + << a0.ack_time.ToDebuggingValue() + << " is larger than the ack time of the current packet:" + << ack_time.ToDebuggingValue() + << ". acked packet number:" << packet_number + << ", total_bytes_acked_:" << total_bytes_acked_ + << ", overestimate_avoidance_:" << overestimate_avoidance_ + << ", sent_packet:" << sent_packet; + return BandwidthSample(); + } + QuicBandwidth ack_rate = QuicBandwidth::FromBytesAndTimeDelta( + total_bytes_acked_ - a0.total_bytes_acked, ack_time - a0.ack_time); + + BandwidthSample sample; + sample.bandwidth = std::min(send_rate, ack_rate); + // Note: this sample does not account for delayed acknowledgement time. This + // means that the RTT measurements here can be artificially high, especially + // on low bandwidth connections. + sample.rtt = ack_time - sent_packet.sent_time; + sample.send_rate = send_rate; + SentPacketToSendTimeState(sent_packet, &sample.state_at_send); + + if (sample.bandwidth.IsZero()) { + QUIC_LOG_EVERY_N_SEC(ERROR, 60) + << "ack_rate: " << ack_rate << ", send_rate: " << send_rate + << ". acked packet number:" << packet_number + << ", overestimate_avoidance_:" << overestimate_avoidance_ << "a1:{" + << total_bytes_acked_ << "@" << ack_time << "}, a0:{" + << a0.total_bytes_acked << "@" << a0.ack_time + << "}, sent_packet:" << sent_packet; + } + return sample; +} + +bool BandwidthSampler::ChooseA0Point(QuicByteCount total_bytes_acked, + AckPoint* a0) { + if (a0_candidates_.empty()) { + QUIC_BUG(quic_bug_10437_5) + << "No A0 point candicates. total_bytes_acked:" << total_bytes_acked; + return false; + } + + if (a0_candidates_.size() == 1) { + *a0 = a0_candidates_.front(); + return true; + } + + for (size_t i = 1; i < a0_candidates_.size(); ++i) { + if (a0_candidates_[i].total_bytes_acked > total_bytes_acked) { + *a0 = a0_candidates_[i - 1]; + if (i > 1) { + a0_candidates_.pop_front_n(i - 1); + } + return true; + } + } + + // All candidates' total_bytes_acked is <= |total_bytes_acked|. + *a0 = a0_candidates_.back(); + a0_candidates_.pop_front_n(a0_candidates_.size() - 1); + return true; +} + +SendTimeState BandwidthSampler::OnPacketLost(QuicPacketNumber packet_number, + QuicPacketLength bytes_lost) { + // TODO(vasilvv): see the comment for the case of missing packets in + // BandwidthSampler::OnPacketAcknowledged on why this does not raise a + // QUIC_BUG when removal fails. + SendTimeState send_time_state; + + total_bytes_lost_ += bytes_lost; + ConnectionStateOnSentPacket* sent_packet_pointer = + connection_state_map_.GetEntry(packet_number); + if (sent_packet_pointer != nullptr) { + SentPacketToSendTimeState(*sent_packet_pointer, &send_time_state); + } + + return send_time_state; +} + +void BandwidthSampler::SentPacketToSendTimeState( + const ConnectionStateOnSentPacket& sent_packet, + SendTimeState* send_time_state) const { + *send_time_state = sent_packet.send_time_state; + send_time_state->is_valid = true; +} + +void BandwidthSampler::OnAppLimited() { + is_app_limited_ = true; + end_of_app_limited_phase_ = last_sent_packet_; +} + +void BandwidthSampler::RemoveObsoletePackets(QuicPacketNumber least_unacked) { + // A packet can become obsolete when it is removed from QuicUnackedPacketMap's + // view of inflight before it is acked or marked as lost. For example, when + // QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet, + // the packet is removed from QuicUnackedPacketMap's inflight, but is not + // marked as acked or lost in the BandwidthSampler. + connection_state_map_.RemoveUpTo(least_unacked); +} + +QuicByteCount BandwidthSampler::total_bytes_sent() const { + return total_bytes_sent_; +} + +QuicByteCount BandwidthSampler::total_bytes_acked() const { + return total_bytes_acked_; +} + +QuicByteCount BandwidthSampler::total_bytes_lost() const { + return total_bytes_lost_; +} + +QuicByteCount BandwidthSampler::total_bytes_neutered() const { + return total_bytes_neutered_; +} + +bool BandwidthSampler::is_app_limited() const { return is_app_limited_; } + +QuicPacketNumber BandwidthSampler::end_of_app_limited_phase() const { + return end_of_app_limited_phase_; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bandwidth_sampler.h b/quiche/quic/core/congestion_control/bandwidth_sampler.h new file mode 100644 index 000000000000..3ca09d1c5787 --- /dev/null +++ b/quiche/quic/core/congestion_control/bandwidth_sampler.h @@ -0,0 +1,612 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BANDWIDTH_SAMPLER_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BANDWIDTH_SAMPLER_H_ + +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/congestion_control/windowed_filter.h" +#include "quiche/quic/core/packet_number_indexed_queue.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_unacked_packet_map.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +namespace test { +class BandwidthSamplerPeer; +} // namespace test + +// A subset of BandwidthSampler::ConnectionStateOnSentPacket which is returned +// to the caller when the packet is acked or lost. +struct QUIC_EXPORT_PRIVATE SendTimeState { + SendTimeState() + : is_valid(false), + is_app_limited(false), + total_bytes_sent(0), + total_bytes_acked(0), + total_bytes_lost(0), + bytes_in_flight(0) {} + + SendTimeState(bool is_app_limited, QuicByteCount total_bytes_sent, + QuicByteCount total_bytes_acked, QuicByteCount total_bytes_lost, + QuicByteCount bytes_in_flight) + : is_valid(true), + is_app_limited(is_app_limited), + total_bytes_sent(total_bytes_sent), + total_bytes_acked(total_bytes_acked), + total_bytes_lost(total_bytes_lost), + bytes_in_flight(bytes_in_flight) {} + + SendTimeState(const SendTimeState& other) = default; + SendTimeState& operator=(const SendTimeState& other) = default; + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const SendTimeState& s); + + // Whether other states in this object is valid. + bool is_valid; + + // Whether the sender is app limited at the time the packet was sent. + // App limited bandwidth sample might be artificially low because the sender + // did not have enough data to send in order to saturate the link. + bool is_app_limited; + + // Total number of sent bytes at the time the packet was sent. + // Includes the packet itself. + QuicByteCount total_bytes_sent; + + // Total number of acked bytes at the time the packet was sent. + QuicByteCount total_bytes_acked; + + // Total number of lost bytes at the time the packet was sent. + QuicByteCount total_bytes_lost; + + // Total number of inflight bytes at the time the packet was sent. + // Includes the packet itself. + // It should be equal to |total_bytes_sent| minus the sum of + // |total_bytes_acked|, |total_bytes_lost| and total neutered bytes. + QuicByteCount bytes_in_flight; +}; + +struct QUIC_NO_EXPORT ExtraAckedEvent { + // The excess bytes acknowlwedged in the time delta for this event. + QuicByteCount extra_acked = 0; + + // The bytes acknowledged and time delta from the event. + QuicByteCount bytes_acked = 0; + QuicTime::Delta time_delta = QuicTime::Delta::Zero(); + // The round trip of the event. + QuicRoundTripCount round = 0; + + bool operator>=(const ExtraAckedEvent& other) const { + return extra_acked >= other.extra_acked; + } + bool operator==(const ExtraAckedEvent& other) const { + return extra_acked == other.extra_acked; + } +}; + +struct QUIC_EXPORT_PRIVATE BandwidthSample { + // The bandwidth at that particular sample. Zero if no valid bandwidth sample + // is available. + QuicBandwidth bandwidth = QuicBandwidth::Zero(); + + // The RTT measurement at this particular sample. Zero if no RTT sample is + // available. Does not correct for delayed ack time. + QuicTime::Delta rtt = QuicTime::Delta::Zero(); + + // |send_rate| is computed from the current packet being acked('P') and an + // earlier packet that is acked before P was sent. + QuicBandwidth send_rate = QuicBandwidth::Infinite(); + + // States captured when the packet was sent. + SendTimeState state_at_send; +}; + +// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every +// ack event to keep track the degree of ack aggregation(a.k.a "ack height"). +class QUIC_EXPORT_PRIVATE MaxAckHeightTracker { + public: + explicit MaxAckHeightTracker(QuicRoundTripCount initial_filter_window) + : max_ack_height_filter_(initial_filter_window, ExtraAckedEvent(), 0) {} + + QuicByteCount Get() const { + return max_ack_height_filter_.GetBest().extra_acked; + } + + QuicByteCount Update(QuicBandwidth bandwidth_estimate, + bool is_new_max_bandwidth, + QuicRoundTripCount round_trip_count, + QuicPacketNumber last_sent_packet_number, + QuicPacketNumber last_acked_packet_number, + QuicTime ack_time, QuicByteCount bytes_acked); + + void SetFilterWindowLength(QuicRoundTripCount length) { + max_ack_height_filter_.SetWindowLength(length); + } + + void Reset(QuicByteCount new_height, QuicRoundTripCount new_time) { + ExtraAckedEvent new_event; + new_event.extra_acked = new_height; + new_event.round = new_time; + max_ack_height_filter_.Reset(new_event, new_time); + } + + void SetAckAggregationBandwidthThreshold(double threshold) { + ack_aggregation_bandwidth_threshold_ = threshold; + } + + void SetStartNewAggregationEpochAfterFullRound(bool value) { + start_new_aggregation_epoch_after_full_round_ = value; + } + + void SetReduceExtraAckedOnBandwidthIncrease(bool value) { + reduce_extra_acked_on_bandwidth_increase_ = value; + } + + double ack_aggregation_bandwidth_threshold() const { + return ack_aggregation_bandwidth_threshold_; + } + + uint64_t num_ack_aggregation_epochs() const { + return num_ack_aggregation_epochs_; + } + + private: + // Tracks the maximum number of bytes acked faster than the estimated + // bandwidth. + using MaxAckHeightFilter = + WindowedFilter, + QuicRoundTripCount, QuicRoundTripCount>; + MaxAckHeightFilter max_ack_height_filter_; + + // The time this aggregation started and the number of bytes acked during it. + QuicTime aggregation_epoch_start_time_ = QuicTime::Zero(); + QuicByteCount aggregation_epoch_bytes_ = 0; + // The last sent packet number before the current aggregation epoch started. + QuicPacketNumber last_sent_packet_number_before_epoch_; + // The number of ack aggregation epochs ever started, including the ongoing + // one. Stats only. + uint64_t num_ack_aggregation_epochs_ = 0; + double ack_aggregation_bandwidth_threshold_ = + GetQuicFlag(quic_ack_aggregation_bandwidth_threshold); + bool start_new_aggregation_epoch_after_full_round_ = false; + bool reduce_extra_acked_on_bandwidth_increase_ = false; +}; + +// An interface common to any class that can provide bandwidth samples from the +// information per individual acknowledged packet. +class QUIC_EXPORT_PRIVATE BandwidthSamplerInterface { + public: + virtual ~BandwidthSamplerInterface() {} + + // Inputs the sent packet information into the sampler. Assumes that all + // packets are sent in order. The information about the packet will not be + // released from the sampler until it the packet is either acknowledged or + // declared lost. + virtual void OnPacketSent( + QuicTime sent_time, QuicPacketNumber packet_number, QuicByteCount bytes, + QuicByteCount bytes_in_flight, + HasRetransmittableData has_retransmittable_data) = 0; + + virtual void OnPacketNeutered(QuicPacketNumber packet_number) = 0; + + struct QUIC_NO_EXPORT CongestionEventSample { + // The maximum bandwidth sample from all acked packets. + // QuicBandwidth::Zero() if no samples are available. + QuicBandwidth sample_max_bandwidth = QuicBandwidth::Zero(); + // Whether |sample_max_bandwidth| is from a app-limited sample. + bool sample_is_app_limited = false; + // The minimum rtt sample from all acked packets. + // QuicTime::Delta::Infinite() if no samples are available. + QuicTime::Delta sample_rtt = QuicTime::Delta::Infinite(); + // For each packet p in acked packets, this is the max value of INFLIGHT(p), + // where INFLIGHT(p) is the number of bytes acked while p is inflight. + QuicByteCount sample_max_inflight = 0; + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + SendTimeState last_packet_send_state; + // The number of extra bytes acked from this ack event, compared to what is + // expected from the flow's bandwidth. Larger value means more ack + // aggregation. + QuicByteCount extra_acked = 0; + }; + // Notifies the sampler that at |ack_time|, all packets in |acked_packets| + // have been acked, and all packets in |lost_packets| have been lost. + // See the comments in CongestionEventSample for the return value. + // |max_bandwidth| is the windowed maximum observed bandwidth. + // |est_bandwidth_upper_bound| is an upper bound of estimated bandwidth used + // to calculate extra_acked. + virtual CongestionEventSample OnCongestionEvent( + QuicTime ack_time, const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, QuicBandwidth max_bandwidth, + QuicBandwidth est_bandwidth_upper_bound, + QuicRoundTripCount round_trip_count) = 0; + + // Informs the sampler that the connection is currently app-limited, causing + // the sampler to enter the app-limited phase. The phase will expire by + // itself. + virtual void OnAppLimited() = 0; + + // Remove all the packets lower than the specified packet number. + virtual void RemoveObsoletePackets(QuicPacketNumber least_unacked) = 0; + + // Total number of bytes sent/acked/lost/neutered in the connection. + virtual QuicByteCount total_bytes_sent() const = 0; + virtual QuicByteCount total_bytes_acked() const = 0; + virtual QuicByteCount total_bytes_lost() const = 0; + virtual QuicByteCount total_bytes_neutered() const = 0; + + // Application-limited information exported for debugging. + virtual bool is_app_limited() const = 0; + + virtual QuicPacketNumber end_of_app_limited_phase() const = 0; +}; + +// BandwidthSampler keeps track of sent and acknowledged packets and outputs a +// bandwidth sample for every packet acknowledged. The samples are taken for +// individual packets, and are not filtered; the consumer has to filter the +// bandwidth samples itself. In certain cases, the sampler will locally severely +// underestimate the bandwidth, hence a maximum filter with a size of at least +// one RTT is recommended. +// +// This class bases its samples on the slope of two curves: the number of bytes +// sent over time, and the number of bytes acknowledged as received over time. +// It produces a sample of both slopes for every packet that gets acknowledged, +// based on a slope between two points on each of the corresponding curves. Note +// that due to the packet loss, the number of bytes on each curve might get +// further and further away from each other, meaning that it is not feasible to +// compare byte values coming from different curves with each other. +// +// The obvious points for measuring slope sample are the ones corresponding to +// the packet that was just acknowledged. Let us denote them as S_1 (point at +// which the current packet was sent) and A_1 (point at which the current packet +// was acknowledged). However, taking a slope requires two points on each line, +// so estimating bandwidth requires picking a packet in the past with respect to +// which the slope is measured. +// +// For that purpose, BandwidthSampler always keeps track of the most recently +// acknowledged packet, and records it together with every outgoing packet. +// When a packet gets acknowledged (A_1), it has not only information about when +// it itself was sent (S_1), but also the information about a previously +// acknowledged packet before it was sent (S_0 and A_0). +// +// Based on that data, send and ack rate are estimated as: +// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) +// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) +// +// Here, the ack rate is intuitively the rate we want to treat as bandwidth. +// However, in certain cases (e.g. ack compression) the ack rate at a point may +// end up higher than the rate at which the data was originally sent, which is +// not indicative of the real bandwidth. Hence, we use the send rate as an upper +// bound, and the sample value is +// rate_sample = min(send_rate, ack_rate) +// +// An important edge case handled by the sampler is tracking the app-limited +// samples. There are multiple meaning of "app-limited" used interchangeably, +// hence it is important to understand and to be able to distinguish between +// them. +// +// Meaning 1: connection state. The connection is said to be app-limited when +// there is no outstanding data to send. This means that certain bandwidth +// samples in the future would not be an accurate indication of the link +// capacity, and it is important to inform consumer about that. Whenever +// connection becomes app-limited, the sampler is notified via OnAppLimited() +// method. +// +// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth +// sampler becomes notified about the connection being app-limited, it enters +// app-limited phase. In that phase, all *sent* packets are marked as +// app-limited. Note that the connection itself does not have to be +// app-limited during the app-limited phase, and in fact it will not be +// (otherwise how would it send packets?). The boolean flag below indicates +// whether the sampler is in that phase. +// +// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is +// sent during the app-limited phase, the resulting sample related to the +// packet will be marked as app-limited. +// +// With the terminology issue out of the way, let us consider the question of +// what kind of situation it addresses. +// +// Consider a scenario where we first send packets 1 to 20 at a regular +// bandwidth, and then immediately run out of data. After a few seconds, we send +// packets 21 to 60, and only receive ack for 21 between sending packets 40 and +// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 +// we use to compute the slope is going to be packet 20, a few seconds apart +// from the current packet, hence the resulting estimate would be extremely low +// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, +// meaning that the bandwidth sample would exclude the quiescence. +// +// Based on the analysis of that scenario, we implement the following rule: once +// OnAppLimited() is called, all sent packets will produce app-limited samples +// up until an ack for a packet that was sent after OnAppLimited() was called. +// Note that while the scenario above is not the only scenario when the +// connection is app-limited, the approach works in other cases too. +class QUIC_EXPORT_PRIVATE BandwidthSampler : public BandwidthSamplerInterface { + public: + BandwidthSampler(const QuicUnackedPacketMap* unacked_packet_map, + QuicRoundTripCount max_height_tracker_window_length); + + // Copy states from |other|. This is useful when changing send algorithms in + // the middle of a connection. + BandwidthSampler(const BandwidthSampler& other); + ~BandwidthSampler() override; + + void OnPacketSent(QuicTime sent_time, QuicPacketNumber packet_number, + QuicByteCount bytes, QuicByteCount bytes_in_flight, + HasRetransmittableData has_retransmittable_data) override; + void OnPacketNeutered(QuicPacketNumber packet_number) override; + + CongestionEventSample OnCongestionEvent( + QuicTime ack_time, const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, QuicBandwidth max_bandwidth, + QuicBandwidth est_bandwidth_upper_bound, + QuicRoundTripCount round_trip_count) override; + QuicByteCount OnAckEventEnd(QuicBandwidth bandwidth_estimate, + bool is_new_max_bandwidth, + QuicRoundTripCount round_trip_count); + + void OnAppLimited() override; + + void RemoveObsoletePackets(QuicPacketNumber least_unacked) override; + + QuicByteCount total_bytes_sent() const override; + QuicByteCount total_bytes_acked() const override; + QuicByteCount total_bytes_lost() const override; + QuicByteCount total_bytes_neutered() const override; + + bool is_app_limited() const override; + + QuicPacketNumber end_of_app_limited_phase() const override; + + QuicByteCount max_ack_height() const { return max_ack_height_tracker_.Get(); } + + uint64_t num_ack_aggregation_epochs() const { + return max_ack_height_tracker_.num_ack_aggregation_epochs(); + } + + void SetMaxAckHeightTrackerWindowLength(QuicRoundTripCount length) { + max_ack_height_tracker_.SetFilterWindowLength(length); + } + + void ResetMaxAckHeightTracker(QuicByteCount new_height, + QuicRoundTripCount new_time) { + max_ack_height_tracker_.Reset(new_height, new_time); + } + + void SetStartNewAggregationEpochAfterFullRound(bool value) { + max_ack_height_tracker_.SetStartNewAggregationEpochAfterFullRound(value); + } + + void SetLimitMaxAckHeightTrackerBySendRate(bool value) { + limit_max_ack_height_tracker_by_send_rate_ = value; + } + + void SetReduceExtraAckedOnBandwidthIncrease(bool value) { + max_ack_height_tracker_.SetReduceExtraAckedOnBandwidthIncrease(value); + } + + // AckPoint represents a point on the ack line. + struct QUIC_NO_EXPORT AckPoint { + QuicTime ack_time = QuicTime::Zero(); + QuicByteCount total_bytes_acked = 0; + + friend QUIC_NO_EXPORT std::ostream& operator<<(std::ostream& os, + const AckPoint& ack_point) { + return os << ack_point.ack_time << ":" << ack_point.total_bytes_acked; + } + }; + + // RecentAckPoints maintains the most recent 2 ack points at distinct times. + class QUIC_NO_EXPORT RecentAckPoints { + public: + void Update(QuicTime ack_time, QuicByteCount total_bytes_acked) { + QUICHE_DCHECK_GE(total_bytes_acked, ack_points_[1].total_bytes_acked); + + if (ack_time < ack_points_[1].ack_time) { + // This can only happen when time goes backwards, we use the smaller + // timestamp for the most recent ack point in that case. + // TODO(wub): Add a QUIC_BUG if ack time stops going backwards. + ack_points_[1].ack_time = ack_time; + } else if (ack_time > ack_points_[1].ack_time) { + ack_points_[0] = ack_points_[1]; + ack_points_[1].ack_time = ack_time; + } + + ack_points_[1].total_bytes_acked = total_bytes_acked; + } + + void Clear() { ack_points_[0] = ack_points_[1] = AckPoint(); } + + const AckPoint& MostRecentPoint() const { return ack_points_[1]; } + + const AckPoint& LessRecentPoint() const { + if (ack_points_[0].total_bytes_acked != 0) { + return ack_points_[0]; + } + + return ack_points_[1]; + } + + private: + AckPoint ack_points_[2]; + }; + + void EnableOverestimateAvoidance(); + bool IsOverestimateAvoidanceEnabled() const { + return overestimate_avoidance_; + } + + private: + friend class test::BandwidthSamplerPeer; + + // ConnectionStateOnSentPacket represents the information about a sent packet + // and the state of the connection at the moment the packet was sent, + // specifically the information about the most recently acknowledged packet at + // that moment. + struct QUIC_EXPORT_PRIVATE ConnectionStateOnSentPacket { + // Time at which the packet is sent. + QuicTime sent_time; + + // Size of the packet. + QuicByteCount size; + + // The value of |total_bytes_sent_at_last_acked_packet_| at the time the + // packet was sent. + QuicByteCount total_bytes_sent_at_last_acked_packet; + + // The value of |last_acked_packet_sent_time_| at the time the packet was + // sent. + QuicTime last_acked_packet_sent_time; + + // The value of |last_acked_packet_ack_time_| at the time the packet was + // sent. + QuicTime last_acked_packet_ack_time; + + // Send time states that are returned to the congestion controller when the + // packet is acked or lost. + SendTimeState send_time_state; + + // Snapshot constructor. Records the current state of the bandwidth + // sampler. + // |bytes_in_flight| is the bytes in flight right after the packet is sent. + ConnectionStateOnSentPacket(QuicTime sent_time, QuicByteCount size, + QuicByteCount bytes_in_flight, + const BandwidthSampler& sampler) + : sent_time(sent_time), + size(size), + total_bytes_sent_at_last_acked_packet( + sampler.total_bytes_sent_at_last_acked_packet_), + last_acked_packet_sent_time(sampler.last_acked_packet_sent_time_), + last_acked_packet_ack_time(sampler.last_acked_packet_ack_time_), + send_time_state(sampler.is_app_limited_, sampler.total_bytes_sent_, + sampler.total_bytes_acked_, sampler.total_bytes_lost_, + bytes_in_flight) {} + + // Default constructor. Required to put this structure into + // PacketNumberIndexedQueue. + ConnectionStateOnSentPacket() + : sent_time(QuicTime::Zero()), + size(0), + total_bytes_sent_at_last_acked_packet(0), + last_acked_packet_sent_time(QuicTime::Zero()), + last_acked_packet_ack_time(QuicTime::Zero()) {} + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ConnectionStateOnSentPacket& p) { + os << "{sent_time:" << p.sent_time << ", size:" << p.size + << ", total_bytes_sent_at_last_acked_packet:" + << p.total_bytes_sent_at_last_acked_packet + << ", last_acked_packet_sent_time:" << p.last_acked_packet_sent_time + << ", last_acked_packet_ack_time:" << p.last_acked_packet_ack_time + << ", send_time_state:" << p.send_time_state << "}"; + return os; + } + }; + + BandwidthSample OnPacketAcknowledged(QuicTime ack_time, + QuicPacketNumber packet_number); + + SendTimeState OnPacketLost(QuicPacketNumber packet_number, + QuicPacketLength bytes_lost); + + // Copy a subset of the (private) ConnectionStateOnSentPacket to the (public) + // SendTimeState. Always set send_time_state->is_valid to true. + void SentPacketToSendTimeState(const ConnectionStateOnSentPacket& sent_packet, + SendTimeState* send_time_state) const; + + // Choose the best a0 from |a0_candidates_| to calculate the ack rate. + // |total_bytes_acked| is the total bytes acked when the packet being acked is + // sent. The best a0 is chosen as follows: + // - If there's only one candidate, use it. + // - If there are multiple candidates, let a[n] be the nth candidate, and + // a[n-1].total_bytes_acked <= |total_bytes_acked| < a[n].total_bytes_acked, + // use a[n-1]. + // - If all candidates's total_bytes_acked is > |total_bytes_acked|, use a[0]. + // This may happen when acks are received out of order, and ack[n] caused + // some candidates of ack[n-x] to be removed. + // - If all candidates's total_bytes_acked is <= |total_bytes_acked|, use + // a[a.size()-1]. + bool ChooseA0Point(QuicByteCount total_bytes_acked, AckPoint* a0); + + // The total number of congestion controlled bytes sent during the connection. + QuicByteCount total_bytes_sent_; + + // The total number of congestion controlled bytes which were acknowledged. + QuicByteCount total_bytes_acked_; + + // The total number of congestion controlled bytes which were lost. + QuicByteCount total_bytes_lost_; + + // The total number of congestion controlled bytes which have been neutered. + QuicByteCount total_bytes_neutered_; + + // The value of |total_bytes_sent_| at the time the last acknowledged packet + // was sent. Valid only when |last_acked_packet_sent_time_| is valid. + QuicByteCount total_bytes_sent_at_last_acked_packet_; + + // The time at which the last acknowledged packet was sent. Set to + // QuicTime::Zero() if no valid timestamp is available. + QuicTime last_acked_packet_sent_time_; + + // The time at which the most recent packet was acknowledged. + QuicTime last_acked_packet_ack_time_; + + // The most recently sent packet. + QuicPacketNumber last_sent_packet_; + + // The most recently acked packet. + QuicPacketNumber last_acked_packet_; + + // Indicates whether the bandwidth sampler is currently in an app-limited + // phase. + bool is_app_limited_; + + // The packet that will be acknowledged after this one will cause the sampler + // to exit the app-limited phase. + QuicPacketNumber end_of_app_limited_phase_; + + // Record of the connection state at the point where each packet in flight was + // sent, indexed by the packet number. + PacketNumberIndexedQueue connection_state_map_; + + RecentAckPoints recent_ack_points_; + quiche::QuicheCircularDeque a0_candidates_; + + // Maximum number of tracked packets. + const QuicPacketCount max_tracked_packets_; + + // The main unacked packet map. Used for outputting extra debugging details. + // May be null. + // TODO(vasilvv): remove this once it's no longer useful for debugging. + const QuicUnackedPacketMap* unacked_packet_map_; + + // Handles the actual bandwidth calculations, whereas the outer method handles + // retrieving and removing |sent_packet|. + BandwidthSample OnPacketAcknowledgedInner( + QuicTime ack_time, QuicPacketNumber packet_number, + const ConnectionStateOnSentPacket& sent_packet); + + MaxAckHeightTracker max_ack_height_tracker_; + QuicByteCount total_bytes_acked_after_last_ack_event_; + + // True if connection option 'BSAO' is set. + bool overestimate_avoidance_; + + // True if connection option 'BBRB' is set. + bool limit_max_ack_height_tracker_by_send_rate_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BANDWIDTH_SAMPLER_H_ diff --git a/quiche/quic/core/congestion_control/bandwidth_sampler_test.cc b/quiche/quic/core/congestion_control/bandwidth_sampler_test.cc new file mode 100644 index 000000000000..97a0bf3104d9 --- /dev/null +++ b/quiche/quic/core/congestion_control/bandwidth_sampler_test.cc @@ -0,0 +1,888 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" + +#include +#include + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +class BandwidthSamplerPeer { + public: + static size_t GetNumberOfTrackedPackets(const BandwidthSampler& sampler) { + return sampler.connection_state_map_.number_of_present_entries(); + } + + static QuicByteCount GetPacketSize(const BandwidthSampler& sampler, + QuicPacketNumber packet_number) { + return sampler.connection_state_map_.GetEntry(packet_number)->size; + } +}; + +const QuicByteCount kRegularPacketSize = 1280; +// Enforce divisibility for some of the tests. +static_assert((kRegularPacketSize & 31) == 0, + "kRegularPacketSize has to be five times divisible by 2"); + +struct TestParameters { + bool overestimate_avoidance; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParameters& p) { + return p.overestimate_avoidance ? "enable_overestimate_avoidance" + : "no_enable_overestimate_avoidance"; +} + +// A test fixture with utility methods for BandwidthSampler tests. +class BandwidthSamplerTest : public QuicTestWithParam { + protected: + BandwidthSamplerTest() + : sampler_(nullptr, /*max_height_tracker_window_length=*/0), + sampler_app_limited_at_start_(sampler_.is_app_limited()), + bytes_in_flight_(0), + max_bandwidth_(QuicBandwidth::Zero()), + est_bandwidth_upper_bound_(QuicBandwidth::Infinite()), + round_trip_count_(0) { + // Ensure that the clock does not start at zero. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + if (GetParam().overestimate_avoidance) { + sampler_.EnableOverestimateAvoidance(); + } + } + + MockClock clock_; + BandwidthSampler sampler_; + bool sampler_app_limited_at_start_; + QuicByteCount bytes_in_flight_; + QuicBandwidth max_bandwidth_; // Max observed bandwidth from acks. + QuicBandwidth est_bandwidth_upper_bound_; + QuicRoundTripCount round_trip_count_; // Needed to calculate extra_acked. + + QuicByteCount PacketsToBytes(QuicPacketCount packet_count) { + return packet_count * kRegularPacketSize; + } + + void SendPacketInner(uint64_t packet_number, QuicByteCount bytes, + HasRetransmittableData has_retransmittable_data) { + sampler_.OnPacketSent(clock_.Now(), QuicPacketNumber(packet_number), bytes, + bytes_in_flight_, has_retransmittable_data); + if (has_retransmittable_data == HAS_RETRANSMITTABLE_DATA) { + bytes_in_flight_ += bytes; + } + } + + void SendPacket(uint64_t packet_number) { + SendPacketInner(packet_number, kRegularPacketSize, + HAS_RETRANSMITTABLE_DATA); + } + + BandwidthSample AckPacketInner(uint64_t packet_number) { + QuicByteCount size = BandwidthSamplerPeer::GetPacketSize( + sampler_, QuicPacketNumber(packet_number)); + bytes_in_flight_ -= size; + BandwidthSampler::CongestionEventSample sample = sampler_.OnCongestionEvent( + clock_.Now(), {MakeAckedPacket(packet_number)}, {}, max_bandwidth_, + est_bandwidth_upper_bound_, round_trip_count_); + max_bandwidth_ = std::max(max_bandwidth_, sample.sample_max_bandwidth); + BandwidthSample bandwidth_sample; + bandwidth_sample.bandwidth = sample.sample_max_bandwidth; + bandwidth_sample.rtt = sample.sample_rtt; + bandwidth_sample.state_at_send = sample.last_packet_send_state; + EXPECT_TRUE(bandwidth_sample.state_at_send.is_valid); + return bandwidth_sample; + } + + AckedPacket MakeAckedPacket(uint64_t packet_number) const { + QuicByteCount size = BandwidthSamplerPeer::GetPacketSize( + sampler_, QuicPacketNumber(packet_number)); + return AckedPacket(QuicPacketNumber(packet_number), size, clock_.Now()); + } + + LostPacket MakeLostPacket(uint64_t packet_number) const { + return LostPacket(QuicPacketNumber(packet_number), + BandwidthSamplerPeer::GetPacketSize( + sampler_, QuicPacketNumber(packet_number))); + } + + // Acknowledge receipt of a packet and expect it to be not app-limited. + QuicBandwidth AckPacket(uint64_t packet_number) { + BandwidthSample sample = AckPacketInner(packet_number); + return sample.bandwidth; + } + + BandwidthSampler::CongestionEventSample OnCongestionEvent( + std::set acked_packet_numbers, + std::set lost_packet_numbers) { + AckedPacketVector acked_packets; + for (auto it = acked_packet_numbers.begin(); + it != acked_packet_numbers.end(); ++it) { + acked_packets.push_back(MakeAckedPacket(*it)); + bytes_in_flight_ -= acked_packets.back().bytes_acked; + } + + LostPacketVector lost_packets; + for (auto it = lost_packet_numbers.begin(); it != lost_packet_numbers.end(); + ++it) { + lost_packets.push_back(MakeLostPacket(*it)); + bytes_in_flight_ -= lost_packets.back().bytes_lost; + } + + BandwidthSampler::CongestionEventSample sample = sampler_.OnCongestionEvent( + clock_.Now(), acked_packets, lost_packets, max_bandwidth_, + est_bandwidth_upper_bound_, round_trip_count_); + max_bandwidth_ = std::max(max_bandwidth_, sample.sample_max_bandwidth); + return sample; + } + + SendTimeState LosePacket(uint64_t packet_number) { + QuicByteCount size = BandwidthSamplerPeer::GetPacketSize( + sampler_, QuicPacketNumber(packet_number)); + bytes_in_flight_ -= size; + LostPacket lost_packet(QuicPacketNumber(packet_number), size); + BandwidthSampler::CongestionEventSample sample = sampler_.OnCongestionEvent( + clock_.Now(), {}, {lost_packet}, max_bandwidth_, + est_bandwidth_upper_bound_, round_trip_count_); + EXPECT_TRUE(sample.last_packet_send_state.is_valid); + EXPECT_EQ(sample.sample_max_bandwidth, QuicBandwidth::Zero()); + EXPECT_EQ(sample.sample_rtt, QuicTime::Delta::Infinite()); + return sample.last_packet_send_state; + } + + // Sends one packet and acks it. Then, send 20 packets. Finally, send + // another 20 packets while acknowledging previous 20. + void Send40PacketsAndAckFirst20(QuicTime::Delta time_between_packets) { + // Send 20 packets at a constant inter-packet time. + for (int i = 1; i <= 20; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + } + + // Ack packets 1 to 20, while sending new packets at the same rate as + // before. + for (int i = 1; i <= 20; i++) { + AckPacket(i); + SendPacket(i + 20); + clock_.AdvanceTime(time_between_packets); + } + } +}; + +INSTANTIATE_TEST_SUITE_P( + BandwidthSamplerTests, BandwidthSamplerTest, + testing::Values(TestParameters{/*overestimate_avoidance=*/false}, + TestParameters{/*overestimate_avoidance=*/true}), + testing::PrintToStringParamName()); + +// Test the sampler in a simple stop-and-wait sender setting. +TEST_P(BandwidthSamplerTest, SendAndWait) { + QuicTime::Delta time_between_packets = QuicTime::Delta::FromMilliseconds(10); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromBytesPerSecond(kRegularPacketSize * 100); + + // Send packets at the constant bandwidth. + for (int i = 1; i < 20; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + QuicBandwidth current_sample = AckPacket(i); + EXPECT_EQ(expected_bandwidth, current_sample); + } + + // Send packets at the exponentially decreasing bandwidth. + for (int i = 20; i < 25; i++) { + time_between_packets = time_between_packets * 2; + expected_bandwidth = expected_bandwidth * 0.5; + + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + QuicBandwidth current_sample = AckPacket(i); + EXPECT_EQ(expected_bandwidth, current_sample); + } + sampler_.RemoveObsoletePackets(QuicPacketNumber(25)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +TEST_P(BandwidthSamplerTest, SendTimeState) { + QuicTime::Delta time_between_packets = QuicTime::Delta::FromMilliseconds(10); + + // Send packets 1-5. + for (int i = 1; i <= 5; i++) { + SendPacket(i); + EXPECT_EQ(PacketsToBytes(i), sampler_.total_bytes_sent()); + clock_.AdvanceTime(time_between_packets); + } + + // Ack packet 1. + SendTimeState send_time_state = AckPacketInner(1).state_at_send; + EXPECT_EQ(PacketsToBytes(1), send_time_state.total_bytes_sent); + EXPECT_EQ(0u, send_time_state.total_bytes_acked); + EXPECT_EQ(0u, send_time_state.total_bytes_lost); + EXPECT_EQ(PacketsToBytes(1), sampler_.total_bytes_acked()); + + // Lose packet 2. + send_time_state = LosePacket(2); + EXPECT_EQ(PacketsToBytes(2), send_time_state.total_bytes_sent); + EXPECT_EQ(0u, send_time_state.total_bytes_acked); + EXPECT_EQ(0u, send_time_state.total_bytes_lost); + EXPECT_EQ(PacketsToBytes(1), sampler_.total_bytes_lost()); + + // Lose packet 3. + send_time_state = LosePacket(3); + EXPECT_EQ(PacketsToBytes(3), send_time_state.total_bytes_sent); + EXPECT_EQ(0u, send_time_state.total_bytes_acked); + EXPECT_EQ(0u, send_time_state.total_bytes_lost); + EXPECT_EQ(PacketsToBytes(2), sampler_.total_bytes_lost()); + + // Send packets 6-10. + for (int i = 6; i <= 10; i++) { + SendPacket(i); + EXPECT_EQ(PacketsToBytes(i), sampler_.total_bytes_sent()); + clock_.AdvanceTime(time_between_packets); + } + + // Ack all inflight packets. + QuicPacketCount acked_packet_count = 1; + EXPECT_EQ(PacketsToBytes(acked_packet_count), sampler_.total_bytes_acked()); + for (int i = 4; i <= 10; i++) { + send_time_state = AckPacketInner(i).state_at_send; + ++acked_packet_count; + EXPECT_EQ(PacketsToBytes(acked_packet_count), sampler_.total_bytes_acked()); + EXPECT_EQ(PacketsToBytes(i), send_time_state.total_bytes_sent); + if (i <= 5) { + EXPECT_EQ(0u, send_time_state.total_bytes_acked); + EXPECT_EQ(0u, send_time_state.total_bytes_lost); + } else { + EXPECT_EQ(PacketsToBytes(1), send_time_state.total_bytes_acked); + EXPECT_EQ(PacketsToBytes(2), send_time_state.total_bytes_lost); + } + + // This equation works because there is no neutered bytes. + EXPECT_EQ(send_time_state.total_bytes_sent - + send_time_state.total_bytes_acked - + send_time_state.total_bytes_lost, + send_time_state.bytes_in_flight); + + clock_.AdvanceTime(time_between_packets); + } +} + +// Test the sampler during regular windowed sender scenario with fixed +// CWND of 20. +TEST_P(BandwidthSamplerTest, SendPaced) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromKBytesPerSecond(kRegularPacketSize); + + Send40PacketsAndAckFirst20(time_between_packets); + + // Ack the packets 21 to 40, arriving at the correct bandwidth. + QuicBandwidth last_bandwidth = QuicBandwidth::Zero(); + for (int i = 21; i <= 40; i++) { + last_bandwidth = AckPacket(i); + EXPECT_EQ(expected_bandwidth, last_bandwidth) << "i is " << i; + clock_.AdvanceTime(time_between_packets); + } + sampler_.RemoveObsoletePackets(QuicPacketNumber(41)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +// Test the sampler in a scenario where 50% of packets is consistently lost. +TEST_P(BandwidthSamplerTest, SendWithLosses) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromKBytesPerSecond(kRegularPacketSize) * 0.5; + + // Send 20 packets, each 1 ms apart. + for (int i = 1; i <= 20; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + } + + // Ack packets 1 to 20, losing every even-numbered packet, while sending new + // packets at the same rate as before. + for (int i = 1; i <= 20; i++) { + if (i % 2 == 0) { + AckPacket(i); + } else { + LosePacket(i); + } + SendPacket(i + 20); + clock_.AdvanceTime(time_between_packets); + } + + // Ack the packets 21 to 40 with the same loss pattern. + QuicBandwidth last_bandwidth = QuicBandwidth::Zero(); + for (int i = 21; i <= 40; i++) { + if (i % 2 == 0) { + last_bandwidth = AckPacket(i); + EXPECT_EQ(expected_bandwidth, last_bandwidth); + } else { + LosePacket(i); + } + clock_.AdvanceTime(time_between_packets); + } + sampler_.RemoveObsoletePackets(QuicPacketNumber(41)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +// Test the sampler in a scenario where the 50% of packets are not +// congestion controlled (specifically, non-retransmittable data is not +// congestion controlled). Should be functionally consistent in behavior with +// the SendWithLosses test. +TEST_P(BandwidthSamplerTest, NotCongestionControlled) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromKBytesPerSecond(kRegularPacketSize) * 0.5; + + // Send 20 packets, each 1 ms apart. Every even packet is not congestion + // controlled. + for (int i = 1; i <= 20; i++) { + SendPacketInner( + i, kRegularPacketSize, + i % 2 == 0 ? HAS_RETRANSMITTABLE_DATA : NO_RETRANSMITTABLE_DATA); + clock_.AdvanceTime(time_between_packets); + } + + // Ensure only congestion controlled packets are tracked. + EXPECT_EQ(10u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + + // Ack packets 2 to 21, ignoring every even-numbered packet, while sending new + // packets at the same rate as before. + for (int i = 1; i <= 20; i++) { + if (i % 2 == 0) { + AckPacket(i); + } + SendPacketInner( + i + 20, kRegularPacketSize, + i % 2 == 0 ? HAS_RETRANSMITTABLE_DATA : NO_RETRANSMITTABLE_DATA); + clock_.AdvanceTime(time_between_packets); + } + + // Ack the packets 22 to 41 with the same congestion controlled pattern. + QuicBandwidth last_bandwidth = QuicBandwidth::Zero(); + for (int i = 21; i <= 40; i++) { + if (i % 2 == 0) { + last_bandwidth = AckPacket(i); + EXPECT_EQ(expected_bandwidth, last_bandwidth); + } + clock_.AdvanceTime(time_between_packets); + } + sampler_.RemoveObsoletePackets(QuicPacketNumber(41)); + + // Since only congestion controlled packets are entered into the map, it has + // to be empty at this point. + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +// Simulate a situation where ACKs arrive in burst and earlier than usual, thus +// producing an ACK rate which is higher than the original send rate. +TEST_P(BandwidthSamplerTest, CompressedAck) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromKBytesPerSecond(kRegularPacketSize); + + Send40PacketsAndAckFirst20(time_between_packets); + + // Simulate an RTT somewhat lower than the one for 1-to-21 transmission. + clock_.AdvanceTime(time_between_packets * 15); + + // Ack the packets 21 to 40 almost immediately at once. + QuicBandwidth last_bandwidth = QuicBandwidth::Zero(); + QuicTime::Delta ridiculously_small_time_delta = + QuicTime::Delta::FromMicroseconds(20); + for (int i = 21; i <= 40; i++) { + last_bandwidth = AckPacket(i); + clock_.AdvanceTime(ridiculously_small_time_delta); + } + EXPECT_EQ(expected_bandwidth, last_bandwidth); + + sampler_.RemoveObsoletePackets(QuicPacketNumber(41)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +// Tests receiving ACK packets in the reverse order. +TEST_P(BandwidthSamplerTest, ReorderedAck) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromKBytesPerSecond(kRegularPacketSize); + + Send40PacketsAndAckFirst20(time_between_packets); + + // Ack the packets 21 to 40 in the reverse order, while sending packets 41 to + // 60. + QuicBandwidth last_bandwidth = QuicBandwidth::Zero(); + for (int i = 0; i < 20; i++) { + last_bandwidth = AckPacket(40 - i); + EXPECT_EQ(expected_bandwidth, last_bandwidth); + SendPacket(41 + i); + clock_.AdvanceTime(time_between_packets); + } + + // Ack the packets 41 to 60, now in the regular order. + for (int i = 41; i <= 60; i++) { + last_bandwidth = AckPacket(i); + EXPECT_EQ(expected_bandwidth, last_bandwidth); + clock_.AdvanceTime(time_between_packets); + } + sampler_.RemoveObsoletePackets(QuicPacketNumber(61)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +// Test the app-limited logic. +TEST_P(BandwidthSamplerTest, AppLimited) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + QuicBandwidth expected_bandwidth = + QuicBandwidth::FromKBytesPerSecond(kRegularPacketSize); + + // Send 20 packets at a constant inter-packet time. + for (int i = 1; i <= 20; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + } + + // Ack packets 1 to 20, while sending new packets at the same rate as + // before. + for (int i = 1; i <= 20; i++) { + BandwidthSample sample = AckPacketInner(i); + EXPECT_EQ(sample.state_at_send.is_app_limited, + sampler_app_limited_at_start_); + SendPacket(i + 20); + clock_.AdvanceTime(time_between_packets); + } + + // We are now app-limited. Ack 21 to 40 as usual, but do not send anything for + // now. + sampler_.OnAppLimited(); + for (int i = 21; i <= 40; i++) { + BandwidthSample sample = AckPacketInner(i); + EXPECT_FALSE(sample.state_at_send.is_app_limited); + EXPECT_EQ(expected_bandwidth, sample.bandwidth); + clock_.AdvanceTime(time_between_packets); + } + + // Enter quiescence. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + + // Send packets 41 to 60, all of which would be marked as app-limited. + for (int i = 41; i <= 60; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + } + + // Ack packets 41 to 60, while sending packets 61 to 80. 41 to 60 should be + // app-limited and underestimate the bandwidth due to that. + for (int i = 41; i <= 60; i++) { + BandwidthSample sample = AckPacketInner(i); + EXPECT_TRUE(sample.state_at_send.is_app_limited); + EXPECT_LT(sample.bandwidth, 0.7f * expected_bandwidth); + + SendPacket(i + 20); + clock_.AdvanceTime(time_between_packets); + } + + // Run out of packets, and then ack packet 61 to 80, all of which should have + // correct non-app-limited samples. + for (int i = 61; i <= 80; i++) { + BandwidthSample sample = AckPacketInner(i); + EXPECT_FALSE(sample.state_at_send.is_app_limited); + EXPECT_EQ(sample.bandwidth, expected_bandwidth); + clock_.AdvanceTime(time_between_packets); + } + sampler_.RemoveObsoletePackets(QuicPacketNumber(81)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + EXPECT_EQ(0u, bytes_in_flight_); +} + +// Test the samples taken at the first flight of packets sent. +TEST_P(BandwidthSamplerTest, FirstRoundTrip) { + const QuicTime::Delta time_between_packets = + QuicTime::Delta::FromMilliseconds(1); + const QuicTime::Delta rtt = QuicTime::Delta::FromMilliseconds(800); + const int num_packets = 10; + const QuicByteCount num_bytes = kRegularPacketSize * num_packets; + const QuicBandwidth real_bandwidth = + QuicBandwidth::FromBytesAndTimeDelta(num_bytes, rtt); + + for (int i = 1; i <= 10; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + } + + clock_.AdvanceTime(rtt - num_packets * time_between_packets); + + QuicBandwidth last_sample = QuicBandwidth::Zero(); + for (int i = 1; i <= 10; i++) { + QuicBandwidth sample = AckPacket(i); + EXPECT_GT(sample, last_sample); + last_sample = sample; + clock_.AdvanceTime(time_between_packets); + } + + // The final measured sample for the first flight of sample is expected to be + // smaller than the real bandwidth, yet it should not lose more than 10%. The + // specific value of the error depends on the difference between the RTT and + // the time it takes to exhaust the congestion window (i.e. in the limit when + // all packets are sent simultaneously, last sample would indicate the real + // bandwidth). + EXPECT_LT(last_sample, real_bandwidth); + EXPECT_GT(last_sample, 0.9f * real_bandwidth); +} + +// Test sampler's ability to remove obsolete packets. +TEST_P(BandwidthSamplerTest, RemoveObsoletePackets) { + SendPacket(1); + SendPacket(2); + SendPacket(3); + SendPacket(4); + SendPacket(5); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + + EXPECT_EQ(5u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + sampler_.RemoveObsoletePackets(QuicPacketNumber(4)); + EXPECT_EQ(2u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + LosePacket(4); + sampler_.RemoveObsoletePackets(QuicPacketNumber(5)); + + EXPECT_EQ(1u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); + AckPacket(5); + + sampler_.RemoveObsoletePackets(QuicPacketNumber(6)); + + EXPECT_EQ(0u, BandwidthSamplerPeer::GetNumberOfTrackedPackets(sampler_)); +} + +TEST_P(BandwidthSamplerTest, NeuterPacket) { + SendPacket(1); + EXPECT_EQ(0u, sampler_.total_bytes_neutered()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + sampler_.OnPacketNeutered(QuicPacketNumber(1)); + EXPECT_LT(0u, sampler_.total_bytes_neutered()); + EXPECT_EQ(0u, sampler_.total_bytes_acked()); + + // If packet 1 is acked it should not produce a bandwidth sample. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + BandwidthSampler::CongestionEventSample sample = sampler_.OnCongestionEvent( + clock_.Now(), + {AckedPacket(QuicPacketNumber(1), kRegularPacketSize, clock_.Now())}, {}, + max_bandwidth_, est_bandwidth_upper_bound_, round_trip_count_); + EXPECT_EQ(0u, sampler_.total_bytes_acked()); + EXPECT_EQ(QuicBandwidth::Zero(), sample.sample_max_bandwidth); + EXPECT_FALSE(sample.sample_is_app_limited); + EXPECT_EQ(QuicTime::Delta::Infinite(), sample.sample_rtt); + EXPECT_EQ(0u, sample.sample_max_inflight); + EXPECT_EQ(0u, sample.extra_acked); +} + +TEST_P(BandwidthSamplerTest, CongestionEventSampleDefaultValues) { + // Make sure a default constructed CongestionEventSample has the correct + // initial values for BandwidthSampler::OnCongestionEvent() to work. + BandwidthSampler::CongestionEventSample sample; + + EXPECT_EQ(QuicBandwidth::Zero(), sample.sample_max_bandwidth); + EXPECT_FALSE(sample.sample_is_app_limited); + EXPECT_EQ(QuicTime::Delta::Infinite(), sample.sample_rtt); + EXPECT_EQ(0u, sample.sample_max_inflight); + EXPECT_EQ(0u, sample.extra_acked); +} + +// 1) Send 2 packets, 2) Ack both in 1 event, 3) Repeat. +TEST_P(BandwidthSamplerTest, TwoAckedPacketsPerEvent) { + QuicTime::Delta time_between_packets = QuicTime::Delta::FromMilliseconds(10); + QuicBandwidth sending_rate = QuicBandwidth::FromBytesAndTimeDelta( + kRegularPacketSize, time_between_packets); + + for (uint64_t i = 1; i < 21; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + if (i % 2 != 0) { + continue; + } + + BandwidthSampler::CongestionEventSample sample = + OnCongestionEvent({i - 1, i}, {}); + EXPECT_EQ(sending_rate, sample.sample_max_bandwidth); + EXPECT_EQ(time_between_packets, sample.sample_rtt); + EXPECT_EQ(2 * kRegularPacketSize, sample.sample_max_inflight); + EXPECT_TRUE(sample.last_packet_send_state.is_valid); + EXPECT_EQ(2 * kRegularPacketSize, + sample.last_packet_send_state.bytes_in_flight); + EXPECT_EQ(i * kRegularPacketSize, + sample.last_packet_send_state.total_bytes_sent); + EXPECT_EQ((i - 2) * kRegularPacketSize, + sample.last_packet_send_state.total_bytes_acked); + EXPECT_EQ(0u, sample.last_packet_send_state.total_bytes_lost); + sampler_.RemoveObsoletePackets(QuicPacketNumber(i - 2)); + } +} + +TEST_P(BandwidthSamplerTest, LoseEveryOtherPacket) { + QuicTime::Delta time_between_packets = QuicTime::Delta::FromMilliseconds(10); + QuicBandwidth sending_rate = QuicBandwidth::FromBytesAndTimeDelta( + kRegularPacketSize, time_between_packets); + + for (uint64_t i = 1; i < 21; i++) { + SendPacket(i); + clock_.AdvanceTime(time_between_packets); + if (i % 2 != 0) { + continue; + } + + // Ack packet i and lose i-1. + BandwidthSampler::CongestionEventSample sample = + OnCongestionEvent({i}, {i - 1}); + // Losing 50% packets means sending rate is twice the bandwidth. + EXPECT_EQ(sending_rate, sample.sample_max_bandwidth * 2); + EXPECT_EQ(time_between_packets, sample.sample_rtt); + EXPECT_EQ(kRegularPacketSize, sample.sample_max_inflight); + EXPECT_TRUE(sample.last_packet_send_state.is_valid); + EXPECT_EQ(2 * kRegularPacketSize, + sample.last_packet_send_state.bytes_in_flight); + EXPECT_EQ(i * kRegularPacketSize, + sample.last_packet_send_state.total_bytes_sent); + EXPECT_EQ((i - 2) * kRegularPacketSize / 2, + sample.last_packet_send_state.total_bytes_acked); + EXPECT_EQ((i - 2) * kRegularPacketSize / 2, + sample.last_packet_send_state.total_bytes_lost); + sampler_.RemoveObsoletePackets(QuicPacketNumber(i - 2)); + } +} + +TEST_P(BandwidthSamplerTest, AckHeightRespectBandwidthEstimateUpperBound) { + QuicTime::Delta time_between_packets = QuicTime::Delta::FromMilliseconds(10); + QuicBandwidth first_packet_sending_rate = + QuicBandwidth::FromBytesAndTimeDelta(kRegularPacketSize, + time_between_packets); + + // Send packets 1 to 4 and ack packet 1. + SendPacket(1); + clock_.AdvanceTime(time_between_packets); + SendPacket(2); + SendPacket(3); + SendPacket(4); + BandwidthSampler::CongestionEventSample sample = OnCongestionEvent({1}, {}); + EXPECT_EQ(first_packet_sending_rate, sample.sample_max_bandwidth); + EXPECT_EQ(first_packet_sending_rate, max_bandwidth_); + + // Ack packet 2, 3 and 4, all of which uses S(1) to calculate ack rate since + // there were no acks at the time they were sent. + round_trip_count_++; + est_bandwidth_upper_bound_ = first_packet_sending_rate * 0.3; + clock_.AdvanceTime(time_between_packets); + sample = OnCongestionEvent({2, 3, 4}, {}); + EXPECT_EQ(first_packet_sending_rate * 2, sample.sample_max_bandwidth); + EXPECT_EQ(max_bandwidth_, sample.sample_max_bandwidth); + + EXPECT_LT(2 * kRegularPacketSize, sample.extra_acked); +} + +class MaxAckHeightTrackerTest : public QuicTest { + protected: + MaxAckHeightTrackerTest() : tracker_(/*initial_filter_window=*/10) { + tracker_.SetAckAggregationBandwidthThreshold(1.8); + tracker_.SetStartNewAggregationEpochAfterFullRound(true); + } + + // Run a full aggregation episode, which is one or more aggregated acks, + // followed by a quiet period in which no ack happens. + // After this function returns, the time is set to the earliest point at which + // any ack event will cause tracker_.Update() to start a new aggregation. + void AggregationEpisode(QuicBandwidth aggregation_bandwidth, + QuicTime::Delta aggregation_duration, + QuicByteCount bytes_per_ack, + bool expect_new_aggregation_epoch) { + ASSERT_GE(aggregation_bandwidth, bandwidth_); + const QuicTime start_time = now_; + + const QuicByteCount aggregation_bytes = + aggregation_bandwidth * aggregation_duration; + + const int num_acks = aggregation_bytes / bytes_per_ack; + ASSERT_EQ(aggregation_bytes, num_acks * bytes_per_ack) + << "aggregation_bytes: " << aggregation_bytes << " [" + << aggregation_bandwidth << " in " << aggregation_duration + << "], bytes_per_ack: " << bytes_per_ack; + + const QuicTime::Delta time_between_acks = QuicTime::Delta::FromMicroseconds( + aggregation_duration.ToMicroseconds() / num_acks); + ASSERT_EQ(aggregation_duration, num_acks * time_between_acks) + << "aggregation_bytes: " << aggregation_bytes + << ", num_acks: " << num_acks + << ", time_between_acks: " << time_between_acks; + + // The total duration of aggregation time and quiet period. + const QuicTime::Delta total_duration = QuicTime::Delta::FromMicroseconds( + aggregation_bytes * 8 * 1000000 / bandwidth_.ToBitsPerSecond()); + ASSERT_EQ(aggregation_bytes, total_duration * bandwidth_) + << "total_duration: " << total_duration + << ", bandwidth_: " << bandwidth_; + + QuicByteCount last_extra_acked = 0; + for (QuicByteCount bytes = 0; bytes < aggregation_bytes; + bytes += bytes_per_ack) { + QuicByteCount extra_acked = tracker_.Update( + bandwidth_, true, RoundTripCount(), last_sent_packet_number_, + last_acked_packet_number_, now_, bytes_per_ack); + QUIC_VLOG(1) << "T" << now_ << ": Update after " << bytes_per_ack + << " bytes acked, " << extra_acked << " extra bytes acked"; + // |extra_acked| should be 0 if either + // [1] We are at the beginning of a aggregation epoch(bytes==0) and the + // the current tracker implementation can identify it, or + // [2] We are not really aggregating acks. + if ((bytes == 0 && expect_new_aggregation_epoch) || // [1] + (aggregation_bandwidth == bandwidth_)) { // [2] + EXPECT_EQ(0u, extra_acked); + } else { + EXPECT_LT(last_extra_acked, extra_acked); + } + now_ = now_ + time_between_acks; + last_extra_acked = extra_acked; + } + + // Advance past the quiet period. + const QuicTime time_after_aggregation = now_; + now_ = start_time + total_duration; + QUIC_VLOG(1) << "Advanced time from " << time_after_aggregation << " to " + << now_ << ". Aggregation time[" + << (time_after_aggregation - start_time) << "], Quiet time[" + << (now_ - time_after_aggregation) << "]."; + } + + QuicRoundTripCount RoundTripCount() const { + return (now_ - QuicTime::Zero()).ToMicroseconds() / rtt_.ToMicroseconds(); + } + + MaxAckHeightTracker tracker_; + QuicBandwidth bandwidth_ = QuicBandwidth::FromBytesPerSecond(10 * 1000); + QuicTime now_ = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1); + QuicTime::Delta rtt_ = QuicTime::Delta::FromMilliseconds(60); + QuicPacketNumber last_sent_packet_number_; + QuicPacketNumber last_acked_packet_number_; +}; + +TEST_F(MaxAckHeightTrackerTest, VeryAggregatedLargeAck) { + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), + 1200, true); + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), + 1200, true); + now_ = now_ - QuicTime::Delta::FromMilliseconds(1); + + if (tracker_.ack_aggregation_bandwidth_threshold() > 1.1) { + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), + 1200, true); + EXPECT_EQ(3u, tracker_.num_ack_aggregation_epochs()); + } else { + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), + 1200, false); + EXPECT_EQ(2u, tracker_.num_ack_aggregation_epochs()); + } +} + +TEST_F(MaxAckHeightTrackerTest, VeryAggregatedSmallAcks) { + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), 300, + true); + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), 300, + true); + now_ = now_ - QuicTime::Delta::FromMilliseconds(1); + + if (tracker_.ack_aggregation_bandwidth_threshold() > 1.1) { + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), + 300, true); + EXPECT_EQ(3u, tracker_.num_ack_aggregation_epochs()); + } else { + AggregationEpisode(bandwidth_ * 20, QuicTime::Delta::FromMilliseconds(6), + 300, false); + EXPECT_EQ(2u, tracker_.num_ack_aggregation_epochs()); + } +} + +TEST_F(MaxAckHeightTrackerTest, SomewhatAggregatedLargeAck) { + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), + 1000, true); + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), + 1000, true); + now_ = now_ - QuicTime::Delta::FromMilliseconds(1); + + if (tracker_.ack_aggregation_bandwidth_threshold() > 1.1) { + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), + 1000, true); + EXPECT_EQ(3u, tracker_.num_ack_aggregation_epochs()); + } else { + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), + 1000, false); + EXPECT_EQ(2u, tracker_.num_ack_aggregation_epochs()); + } +} + +TEST_F(MaxAckHeightTrackerTest, SomewhatAggregatedSmallAcks) { + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), 100, + true); + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), 100, + true); + now_ = now_ - QuicTime::Delta::FromMilliseconds(1); + + if (tracker_.ack_aggregation_bandwidth_threshold() > 1.1) { + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), + 100, true); + EXPECT_EQ(3u, tracker_.num_ack_aggregation_epochs()); + } else { + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), + 100, false); + EXPECT_EQ(2u, tracker_.num_ack_aggregation_epochs()); + } +} + +TEST_F(MaxAckHeightTrackerTest, NotAggregated) { + AggregationEpisode(bandwidth_, QuicTime::Delta::FromMilliseconds(100), 100, + true); + EXPECT_LT(2u, tracker_.num_ack_aggregation_epochs()); +} + +TEST_F(MaxAckHeightTrackerTest, StartNewEpochAfterAFullRound) { + last_sent_packet_number_ = QuicPacketNumber(10); + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), 100, + true); + + last_acked_packet_number_ = QuicPacketNumber(11); + // Update with a tiny bandwidth causes a very low expected bytes acked, which + // in turn causes the current epoch to continue if the |tracker_| doesn't + // check the packet numbers. + tracker_.Update(bandwidth_ * 0.1, true, RoundTripCount(), + last_sent_packet_number_, last_acked_packet_number_, now_, + 100); + + EXPECT_EQ(2u, tracker_.num_ack_aggregation_epochs()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_drain.cc b/quiche/quic/core/congestion_control/bbr2_drain.cc new file mode 100644 index 000000000000..c13e9d81e29d --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_drain.cc @@ -0,0 +1,59 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr2_drain.h" + +#include "quiche/quic/core/congestion_control/bbr2_sender.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +Bbr2Mode Bbr2DrainMode::OnCongestionEvent( + QuicByteCount /*prior_in_flight*/, QuicTime /*event_time*/, + const AckedPacketVector& /*acked_packets*/, + const LostPacketVector& /*lost_packets*/, + const Bbr2CongestionEvent& congestion_event) { + model_->set_pacing_gain(Params().drain_pacing_gain); + + // Only STARTUP can transition to DRAIN, both of them use the same cwnd gain. + QUICHE_DCHECK_EQ(model_->cwnd_gain(), Params().drain_cwnd_gain); + model_->set_cwnd_gain(Params().drain_cwnd_gain); + + QuicByteCount drain_target = DrainTarget(); + if (congestion_event.bytes_in_flight <= drain_target) { + QUIC_DVLOG(3) << sender_ << " Exiting DRAIN. bytes_in_flight:" + << congestion_event.bytes_in_flight + << ", bdp:" << model_->BDP() + << ", drain_target:" << drain_target << " @ " + << congestion_event.event_time; + return Bbr2Mode::PROBE_BW; + } + + QUIC_DVLOG(3) << sender_ << " Staying in DRAIN. bytes_in_flight:" + << congestion_event.bytes_in_flight << ", bdp:" << model_->BDP() + << ", drain_target:" << drain_target << " @ " + << congestion_event.event_time; + return Bbr2Mode::DRAIN; +} + +QuicByteCount Bbr2DrainMode::DrainTarget() const { + QuicByteCount bdp = model_->BDP(); + return std::max(bdp, sender_->GetMinimumCongestionWindow()); +} + +Bbr2DrainMode::DebugState Bbr2DrainMode::ExportDebugState() const { + DebugState s; + s.drain_target = DrainTarget(); + return s; +} + +std::ostream& operator<<(std::ostream& os, + const Bbr2DrainMode::DebugState& state) { + os << "[DRAIN] drain_target: " << state.drain_target << "\n"; + return os; +} + +const Bbr2Params& Bbr2DrainMode::Params() const { return sender_->Params(); } + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_drain.h b/quiche/quic/core/congestion_control/bbr2_drain.h new file mode 100644 index 000000000000..1d61a41a933c --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_drain.h @@ -0,0 +1,59 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_DRAIN_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_DRAIN_H_ + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class Bbr2Sender; +class QUIC_EXPORT_PRIVATE Bbr2DrainMode final : public Bbr2ModeBase { + public: + using Bbr2ModeBase::Bbr2ModeBase; + + void Enter(QuicTime /*now*/, + const Bbr2CongestionEvent* /*congestion_event*/) override {} + void Leave(QuicTime /*now*/, + const Bbr2CongestionEvent* /*congestion_event*/) override {} + + Bbr2Mode OnCongestionEvent( + QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + const Bbr2CongestionEvent& congestion_event) override; + + Limits GetCwndLimits() const override { + return NoGreaterThan(model_->inflight_lo()); + } + + bool IsProbingForBandwidth() const override { return false; } + + Bbr2Mode OnExitQuiescence(QuicTime /*now*/, + QuicTime /*quiescence_start_time*/) override { + return Bbr2Mode::DRAIN; + } + + struct QUIC_EXPORT_PRIVATE DebugState { + QuicByteCount drain_target; + }; + + DebugState ExportDebugState() const; + + private: + const Bbr2Params& Params() const; + + QuicByteCount DrainTarget() const; +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const Bbr2DrainMode::DebugState& state); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_DRAIN_H_ diff --git a/quiche/quic/core/congestion_control/bbr2_misc.cc b/quiche/quic/core/congestion_control/bbr2_misc.cc new file mode 100644 index 000000000000..55e63e9e3b5c --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_misc.cc @@ -0,0 +1,460 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +RoundTripCounter::RoundTripCounter() : round_trip_count_(0) {} + +void RoundTripCounter::OnPacketSent(QuicPacketNumber packet_number) { + QUICHE_DCHECK(!last_sent_packet_.IsInitialized() || + last_sent_packet_ < packet_number); + last_sent_packet_ = packet_number; +} + +bool RoundTripCounter::OnPacketsAcked(QuicPacketNumber last_acked_packet) { + if (!end_of_round_trip_.IsInitialized() || + last_acked_packet > end_of_round_trip_) { + round_trip_count_++; + end_of_round_trip_ = last_sent_packet_; + return true; + } + return false; +} + +void RoundTripCounter::RestartRound() { + end_of_round_trip_ = last_sent_packet_; +} + +MinRttFilter::MinRttFilter(QuicTime::Delta initial_min_rtt, + QuicTime initial_min_rtt_timestamp) + : min_rtt_(initial_min_rtt), + min_rtt_timestamp_(initial_min_rtt_timestamp) {} + +void MinRttFilter::Update(QuicTime::Delta sample_rtt, QuicTime now) { + if (sample_rtt < min_rtt_ || min_rtt_timestamp_ == QuicTime::Zero()) { + min_rtt_ = sample_rtt; + min_rtt_timestamp_ = now; + } +} + +void MinRttFilter::ForceUpdate(QuicTime::Delta sample_rtt, QuicTime now) { + min_rtt_ = sample_rtt; + min_rtt_timestamp_ = now; +} + +Bbr2NetworkModel::Bbr2NetworkModel(const Bbr2Params* params, + QuicTime::Delta initial_rtt, + QuicTime initial_rtt_timestamp, + float cwnd_gain, float pacing_gain, + const BandwidthSampler* old_sampler) + : params_(params), + bandwidth_sampler_([](QuicRoundTripCount max_height_tracker_window_length, + const BandwidthSampler* old_sampler) { + if (old_sampler != nullptr) { + return BandwidthSampler(*old_sampler); + } + return BandwidthSampler(/*unacked_packet_map=*/nullptr, + max_height_tracker_window_length); + }(params->initial_max_ack_height_filter_window, old_sampler)), + min_rtt_filter_(initial_rtt, initial_rtt_timestamp), + cwnd_gain_(cwnd_gain), + pacing_gain_(pacing_gain) {} + +void Bbr2NetworkModel::OnPacketSent(QuicTime sent_time, + QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, + QuicByteCount bytes, + HasRetransmittableData is_retransmittable) { + // Updating the min here ensures a more realistic (0) value when flows exit + // quiescence. + if (bytes_in_flight < min_bytes_in_flight_in_round_) { + min_bytes_in_flight_in_round_ = bytes_in_flight; + } + if (bytes_in_flight + bytes >= inflight_hi_) { + inflight_hi_limited_in_round_ = true; + } + round_trip_counter_.OnPacketSent(packet_number); + + bandwidth_sampler_.OnPacketSent(sent_time, packet_number, bytes, + bytes_in_flight, is_retransmittable); +} + +void Bbr2NetworkModel::OnCongestionEventStart( + QuicTime event_time, const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + Bbr2CongestionEvent* congestion_event) { + const QuicByteCount prior_bytes_acked = total_bytes_acked(); + const QuicByteCount prior_bytes_lost = total_bytes_lost(); + + congestion_event->event_time = event_time; + congestion_event->end_of_round_trip = + acked_packets.empty() ? false + : round_trip_counter_.OnPacketsAcked( + acked_packets.rbegin()->packet_number); + + BandwidthSamplerInterface::CongestionEventSample sample = + bandwidth_sampler_.OnCongestionEvent(event_time, acked_packets, + lost_packets, MaxBandwidth(), + bandwidth_lo(), RoundTripCount()); + + if (sample.extra_acked == 0) { + cwnd_limited_before_aggregation_epoch_ = + congestion_event->prior_bytes_in_flight >= congestion_event->prior_cwnd; + } + + if (sample.last_packet_send_state.is_valid) { + congestion_event->last_packet_send_state = sample.last_packet_send_state; + } + + // Avoid updating |max_bandwidth_filter_| if a) this is a loss-only event, or + // b) all packets in |acked_packets| did not generate valid samples. (e.g. ack + // of ack-only packets). In both cases, total_bytes_acked() will not change. + if (prior_bytes_acked != total_bytes_acked()) { + QUIC_LOG_IF(WARNING, sample.sample_max_bandwidth.IsZero()) + << total_bytes_acked() - prior_bytes_acked << " bytes from " + << acked_packets.size() + << " packets have been acked, but sample_max_bandwidth is zero."; + congestion_event->sample_max_bandwidth = sample.sample_max_bandwidth; + if (!sample.sample_is_app_limited || + sample.sample_max_bandwidth > MaxBandwidth()) { + max_bandwidth_filter_.Update(congestion_event->sample_max_bandwidth); + } + } + + if (!sample.sample_rtt.IsInfinite()) { + congestion_event->sample_min_rtt = sample.sample_rtt; + min_rtt_filter_.Update(congestion_event->sample_min_rtt, event_time); + } + + congestion_event->bytes_acked = total_bytes_acked() - prior_bytes_acked; + congestion_event->bytes_lost = total_bytes_lost() - prior_bytes_lost; + + if (congestion_event->prior_bytes_in_flight >= + congestion_event->bytes_acked + congestion_event->bytes_lost) { + congestion_event->bytes_in_flight = + congestion_event->prior_bytes_in_flight - + congestion_event->bytes_acked - congestion_event->bytes_lost; + } else { + QUIC_LOG_FIRST_N(ERROR, 1) + << "prior_bytes_in_flight:" << congestion_event->prior_bytes_in_flight + << " is smaller than the sum of bytes_acked:" + << congestion_event->bytes_acked + << " and bytes_lost:" << congestion_event->bytes_lost; + congestion_event->bytes_in_flight = 0; + } + + if (congestion_event->bytes_lost > 0) { + bytes_lost_in_round_ += congestion_event->bytes_lost; + loss_events_in_round_++; + } + + if (congestion_event->bytes_acked > 0 && + congestion_event->last_packet_send_state.is_valid && + total_bytes_acked() > + congestion_event->last_packet_send_state.total_bytes_acked) { + QuicByteCount bytes_delivered = + total_bytes_acked() - + congestion_event->last_packet_send_state.total_bytes_acked; + max_bytes_delivered_in_round_ = + std::max(max_bytes_delivered_in_round_, bytes_delivered); + } + // TODO(ianswett) Consider treating any bytes lost as decreasing inflight, + // because it's a sign of overutilization, not underutilization. + if (congestion_event->bytes_in_flight < min_bytes_in_flight_in_round_) { + min_bytes_in_flight_in_round_ = congestion_event->bytes_in_flight; + } + + // |bandwidth_latest_| and |inflight_latest_| only increased within a round. + if (sample.sample_max_bandwidth > bandwidth_latest_) { + bandwidth_latest_ = sample.sample_max_bandwidth; + } + + if (sample.sample_max_inflight > inflight_latest_) { + inflight_latest_ = sample.sample_max_inflight; + } + + // Adapt lower bounds(bandwidth_lo and inflight_lo). + AdaptLowerBounds(*congestion_event); + + if (!congestion_event->end_of_round_trip) { + return; + } + + if (!sample.sample_max_bandwidth.IsZero()) { + bandwidth_latest_ = sample.sample_max_bandwidth; + } + + if (sample.sample_max_inflight > 0) { + inflight_latest_ = sample.sample_max_inflight; + } +} + +void Bbr2NetworkModel::AdaptLowerBounds( + const Bbr2CongestionEvent& congestion_event) { + if (Params().bw_lo_mode_ == Bbr2Params::DEFAULT) { + if (!congestion_event.end_of_round_trip || + congestion_event.is_probing_for_bandwidth) { + return; + } + + if (bytes_lost_in_round_ > 0) { + if (bandwidth_lo_.IsInfinite()) { + bandwidth_lo_ = MaxBandwidth(); + } + bandwidth_lo_ = + std::max(bandwidth_latest_, bandwidth_lo_ * (1.0 - Params().beta)); + QUIC_DVLOG(3) << "bandwidth_lo_ updated to " << bandwidth_lo_ + << ", bandwidth_latest_ is " << bandwidth_latest_; + + if (Params().ignore_inflight_lo) { + return; + } + if (inflight_lo_ == inflight_lo_default()) { + inflight_lo_ = congestion_event.prior_cwnd; + } + inflight_lo_ = std::max( + inflight_latest_, inflight_lo_ * (1.0 - Params().beta)); + } + return; + } + + // Params().bw_lo_mode_ != Bbr2Params::DEFAULT + if (congestion_event.bytes_lost == 0) { + return; + } + // Ignore losses from packets sent when probing for more bandwidth in + // STARTUP or PROBE_UP when they're lost in DRAIN or PROBE_DOWN. + if (pacing_gain_ < 1) { + return; + } + // Decrease bandwidth_lo whenever there is loss. + // Set bandwidth_lo_ if it is not yet set. + if (bandwidth_lo_.IsInfinite()) { + bandwidth_lo_ = MaxBandwidth(); + } + // Save bandwidth_lo_ if it hasn't already been saved. + if (prior_bandwidth_lo_.IsZero()) { + prior_bandwidth_lo_ = bandwidth_lo_; + } + switch (Params().bw_lo_mode_) { + case Bbr2Params::MIN_RTT_REDUCTION: + bandwidth_lo_ = + bandwidth_lo_ - QuicBandwidth::FromBytesAndTimeDelta( + congestion_event.bytes_lost, MinRtt()); + break; + case Bbr2Params::INFLIGHT_REDUCTION: { + // Use a max of BDP and inflight to avoid starving app-limited flows. + const QuicByteCount effective_inflight = + std::max(BDP(), congestion_event.prior_bytes_in_flight); + // This could use bytes_lost_in_round if the bandwidth_lo_ was saved + // when entering 'recovery', but this BBRv2 implementation doesn't have + // recovery defined. + bandwidth_lo_ = + bandwidth_lo_ * ((effective_inflight - congestion_event.bytes_lost) / + static_cast(effective_inflight)); + break; + } + case Bbr2Params::CWND_REDUCTION: + bandwidth_lo_ = + bandwidth_lo_ * + ((congestion_event.prior_cwnd - congestion_event.bytes_lost) / + static_cast(congestion_event.prior_cwnd)); + break; + case Bbr2Params::DEFAULT: + QUIC_BUG(quic_bug_10466_1) << "Unreachable case DEFAULT."; + } + QuicBandwidth last_bandwidth = bandwidth_latest_; + // sample_max_bandwidth will be Zero() if the loss is triggered by a timer + // expiring. Ideally we'd use the most recent bandwidth sample, + // but bandwidth_latest is safer than Zero(). + if (!congestion_event.sample_max_bandwidth.IsZero()) { + // bandwidth_latest_ is the max bandwidth for the round, but to allow + // fast, conservation style response to loss, use the last sample. + last_bandwidth = congestion_event.sample_max_bandwidth; + } + if (pacing_gain_ > Params().full_bw_threshold) { + // In STARTUP, pacing_gain_ is applied to bandwidth_lo_ in + // UpdatePacingRate, so this backs that multiplication out to allow the + // pacing rate to decrease, but not below + // last_bandwidth * full_bw_threshold. + // TODO(ianswett): Consider altering pacing_gain_ when in STARTUP instead. + bandwidth_lo_ = + std::max(bandwidth_lo_, + last_bandwidth * (Params().full_bw_threshold / pacing_gain_)); + } else { + // Ensure bandwidth_lo isn't lower than last_bandwidth. + bandwidth_lo_ = std::max(bandwidth_lo_, last_bandwidth); + } + // If it's the end of the round, ensure bandwidth_lo doesn't decrease more + // than beta. + if (congestion_event.end_of_round_trip) { + bandwidth_lo_ = + std::max(bandwidth_lo_, prior_bandwidth_lo_ * (1.0 - Params().beta)); + prior_bandwidth_lo_ = QuicBandwidth::Zero(); + } + // These modes ignore inflight_lo as well. +} + +void Bbr2NetworkModel::OnCongestionEventFinish( + QuicPacketNumber least_unacked_packet, + const Bbr2CongestionEvent& congestion_event) { + if (congestion_event.end_of_round_trip) { + OnNewRound(); + } + + bandwidth_sampler_.RemoveObsoletePackets(least_unacked_packet); +} + +void Bbr2NetworkModel::UpdateNetworkParameters(QuicTime::Delta rtt) { + if (!rtt.IsZero()) { + min_rtt_filter_.Update(rtt, MinRttTimestamp()); + } +} + +bool Bbr2NetworkModel::MaybeExpireMinRtt( + const Bbr2CongestionEvent& congestion_event) { + if (congestion_event.event_time < + (MinRttTimestamp() + Params().probe_rtt_period)) { + return false; + } + if (congestion_event.sample_min_rtt.IsInfinite()) { + return false; + } + QUIC_DVLOG(3) << "Replacing expired min rtt of " << min_rtt_filter_.Get() + << " by " << congestion_event.sample_min_rtt << " @ " + << congestion_event.event_time; + min_rtt_filter_.ForceUpdate(congestion_event.sample_min_rtt, + congestion_event.event_time); + return true; +} + +bool Bbr2NetworkModel::IsInflightTooHigh( + const Bbr2CongestionEvent& congestion_event, + int64_t max_loss_events) const { + const SendTimeState& send_state = congestion_event.last_packet_send_state; + if (!send_state.is_valid) { + // Not enough information. + return false; + } + + if (loss_events_in_round() < max_loss_events) { + return false; + } + + const QuicByteCount inflight_at_send = BytesInFlight(send_state); + // TODO(wub): Consider total_bytes_lost() - send_state.total_bytes_lost, which + // is the total bytes lost when the largest numbered packet was inflight. + // bytes_lost_in_round_, OTOH, is the total bytes lost in the "current" round. + const QuicByteCount bytes_lost_in_round = bytes_lost_in_round_; + + QUIC_DVLOG(3) << "IsInflightTooHigh: loss_events_in_round:" + << loss_events_in_round() + + << " bytes_lost_in_round:" << bytes_lost_in_round + << ", lost_in_round_threshold:" + << inflight_at_send * Params().loss_threshold; + + if (inflight_at_send > 0 && bytes_lost_in_round > 0) { + QuicByteCount lost_in_round_threshold = + inflight_at_send * Params().loss_threshold; + if (bytes_lost_in_round > lost_in_round_threshold) { + return true; + } + } + + return false; +} + +void Bbr2NetworkModel::RestartRoundEarly() { + OnNewRound(); + round_trip_counter_.RestartRound(); + rounds_with_queueing_ = 0; +} + +void Bbr2NetworkModel::OnNewRound() { + bytes_lost_in_round_ = 0; + loss_events_in_round_ = 0; + max_bytes_delivered_in_round_ = 0; + min_bytes_in_flight_in_round_ = std::numeric_limits::max(); + inflight_hi_limited_in_round_ = false; +} + +void Bbr2NetworkModel::cap_inflight_lo(QuicByteCount cap) { + if (Params().ignore_inflight_lo) { + return; + } + if (inflight_lo_ != inflight_lo_default() && inflight_lo_ > cap) { + inflight_lo_ = cap; + } +} + +QuicByteCount Bbr2NetworkModel::inflight_hi_with_headroom() const { + QuicByteCount headroom = inflight_hi_ * Params().inflight_hi_headroom; + + return inflight_hi_ > headroom ? inflight_hi_ - headroom : 0; +} + +bool Bbr2NetworkModel::HasBandwidthGrowth( + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK(!full_bandwidth_reached_); + QUICHE_DCHECK(congestion_event.end_of_round_trip); + + QuicBandwidth threshold = + full_bandwidth_baseline_ * Params().full_bw_threshold; + + if (MaxBandwidth() >= threshold) { + QUIC_DVLOG(3) << " CheckBandwidthGrowth at end of round. max_bandwidth:" + << MaxBandwidth() << ", threshold:" << threshold + << " (Still growing) @ " << congestion_event.event_time; + full_bandwidth_baseline_ = MaxBandwidth(); + rounds_without_bandwidth_growth_ = 0; + return true; + } + ++rounds_without_bandwidth_growth_; + + // full_bandwidth_reached is only set to true when not app-limited, except + // when exit_startup_on_persistent_queue is true. + if (rounds_without_bandwidth_growth_ >= Params().startup_full_bw_rounds && + !congestion_event.last_packet_send_state.is_app_limited) { + full_bandwidth_reached_ = true; + } + QUIC_DVLOG(3) << " CheckBandwidthGrowth at end of round. max_bandwidth:" + << MaxBandwidth() << ", threshold:" << threshold + << " rounds_without_growth:" << rounds_without_bandwidth_growth_ + << " full_bw_reached:" << full_bandwidth_reached_ << " @ " + << congestion_event.event_time; + + return false; +} + +void Bbr2NetworkModel::CheckPersistentQueue( + const Bbr2CongestionEvent& congestion_event, float target_gain) { + QUICHE_DCHECK(congestion_event.end_of_round_trip); + QUICHE_DCHECK_NE(min_bytes_in_flight_in_round_, + std::numeric_limits::max()); + QUICHE_DCHECK_GE(target_gain, Params().full_bw_threshold); + QuicByteCount target = + std::max(static_cast(target_gain * BDP()), + BDP() + QueueingThresholdExtraBytes()); + if (min_bytes_in_flight_in_round_ < target) { + rounds_with_queueing_ = 0; + return; + } + rounds_with_queueing_++; + if (rounds_with_queueing_ >= Params().max_startup_queue_rounds) { + full_bandwidth_reached_ = true; + } +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_misc.h b/quiche/quic/core/congestion_control/bbr2_misc.h new file mode 100644 index 000000000000..8f0f8f69a25d --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_misc.h @@ -0,0 +1,679 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_MISC_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_MISC_H_ + +#include +#include + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/congestion_control/windowed_filter.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +template +class QUIC_EXPORT_PRIVATE Limits { + public: + Limits(T min, T max) : min_(min), max_(max) {} + + // If [min, max] is an empty range, i.e. min > max, this function returns max, + // because typically a value larger than max means "risky". + T ApplyLimits(T raw_value) const { + return std::min(max_, std::max(min_, raw_value)); + } + + T Min() const { return min_; } + T Max() const { return max_; } + + private: + T min_; + T max_; +}; + +template +QUIC_EXPORT_PRIVATE inline Limits MinMax(T min, T max) { + return Limits(min, max); +} + +template +QUIC_EXPORT_PRIVATE inline Limits NoLessThan(T min) { + return Limits(min, std::numeric_limits::max()); +} + +template +QUIC_EXPORT_PRIVATE inline Limits NoGreaterThan(T max) { + return Limits(std::numeric_limits::min(), max); +} + +template +QUIC_EXPORT_PRIVATE inline Limits Unlimited() { + return Limits(std::numeric_limits::min(), + std::numeric_limits::max()); +} + +template +QUIC_EXPORT_PRIVATE inline std::ostream& operator<<(std::ostream& os, + const Limits& limits) { + return os << "[" << limits.Min() << ", " << limits.Max() << "]"; +} + +// Bbr2Params contains all parameters of a Bbr2Sender. +struct QUIC_EXPORT_PRIVATE Bbr2Params { + Bbr2Params(QuicByteCount cwnd_min, QuicByteCount cwnd_max) + : cwnd_limits(cwnd_min, cwnd_max) {} + + /* + * STARTUP parameters. + */ + + // The gain for CWND in startup. + float startup_cwnd_gain = 2.0; + // TODO(wub): Maybe change to the newly derived value of 2.773 (4 * ln(2)). + float startup_pacing_gain = 2.885; + + // STARTUP or PROBE_UP are exited if the total bandwidth growth is less than + // |full_bw_threshold| in the last |startup_full_bw_rounds| round trips. + float full_bw_threshold = 1.25; + + QuicRoundTripCount startup_full_bw_rounds = 3; + + // Number of rounds to stay in STARTUP when there's a sufficient queue that + // bytes_in_flight never drops below the target (1.75 * BDP). 0 indicates the + // feature is disabled and we never exit due to queueing. + QuicRoundTripCount max_startup_queue_rounds = 0; + + // The minimum number of loss marking events to exit STARTUP. + int64_t startup_full_loss_count = + GetQuicFlag(quic_bbr2_default_startup_full_loss_count); + + // If true, always exit STARTUP on loss, even if bandwidth exceeds threshold. + // If false, exit STARTUP on loss only if bandwidth is below threshold. + bool always_exit_startup_on_excess_loss = false; + + // If true, include extra acked during STARTUP and proactively reduce extra + // acked when bandwidth increases. + bool startup_include_extra_acked = false; + + + /* + * DRAIN parameters. + */ + float drain_cwnd_gain = 2.0; + float drain_pacing_gain = 1.0 / 2.885; + + /* + * PROBE_BW parameters. + */ + // Max amount of randomness to inject in round counting for Reno-coexistence. + QuicRoundTripCount probe_bw_max_probe_rand_rounds = 2; + + // Max number of rounds before probing for Reno-coexistence. + uint32_t probe_bw_probe_max_rounds = 63; + + // Multiplier to get Reno-style probe epoch duration as: k * BDP round trips. + // If zero, disables Reno-style BDP-scaled coexistence mechanism. + float probe_bw_probe_reno_gain = 1.0; + + // Minimum duration for BBR-native probes. + QuicTime::Delta probe_bw_probe_base_duration = + QuicTime::Delta::FromMilliseconds( + GetQuicFlag(quic_bbr2_default_probe_bw_base_duration_ms)); + + // The upper bound of the random amount of BBR-native probes. + QuicTime::Delta probe_bw_probe_max_rand_duration = + QuicTime::Delta::FromMilliseconds( + GetQuicFlag(quic_bbr2_default_probe_bw_max_rand_duration_ms)); + + // The minimum number of loss marking events to exit the PROBE_UP phase. + int64_t probe_bw_full_loss_count = + GetQuicFlag(quic_bbr2_default_probe_bw_full_loss_count); + + // Pacing gains. + float probe_bw_probe_up_pacing_gain = 1.25; + float probe_bw_probe_down_pacing_gain = 0.75; + float probe_bw_default_pacing_gain = 1.0; + + float probe_bw_cwnd_gain = 2.0; + + /* + * PROBE_UP parameters. + */ + bool probe_up_ignore_inflight_hi = true; + bool probe_up_simplify_inflight_hi = false; + + // Number of rounds to stay in PROBE_UP when there's a sufficient queue that + // bytes_in_flight never drops below the target. 0 indicates the feature is + // disabled and we never exit due to queueing. + QuicRoundTripCount max_probe_up_queue_rounds = 0; + + /* + * PROBE_RTT parameters. + */ + float probe_rtt_inflight_target_bdp_fraction = 0.5; + QuicTime::Delta probe_rtt_period = QuicTime::Delta::FromMilliseconds( + GetQuicFlag(quic_bbr2_default_probe_rtt_period_ms)); + QuicTime::Delta probe_rtt_duration = QuicTime::Delta::FromMilliseconds(200); + + /* + * Parameters used by multiple modes. + */ + + // The initial value of the max ack height filter's window length. + QuicRoundTripCount initial_max_ack_height_filter_window = + GetQuicFlag(quic_bbr2_default_initial_ack_height_filter_window); + + // Fraction of unutilized headroom to try to leave in path upon high loss. + float inflight_hi_headroom = + GetQuicFlag(quic_bbr2_default_inflight_hi_headroom); + + // Estimate startup/bw probing has gone too far if loss rate exceeds this. + float loss_threshold = GetQuicFlag(quic_bbr2_default_loss_threshold); + + // A common factor for multiplicative decreases. Used for adjusting + // bandwidth_lo, inflight_lo and inflight_hi upon losses. + float beta = 0.3; + + Limits cwnd_limits; + + /* + * Experimental flags from QuicConfig. + */ + + // Can be disabled by connection option 'B2NA'. + bool add_ack_height_to_queueing_threshold = true; + + // Can be disabled by connection option 'B2RP'. + bool avoid_unnecessary_probe_rtt = true; + + // Can be enabled by connection option 'B2LO'. + bool ignore_inflight_lo = false; + + // Can be enabled by connection option 'B2H2'. + bool limit_inflight_hi_by_max_delivered = false; + + // Can be disabled by connection option 'B2SL'. + bool startup_loss_exit_use_max_delivered_for_inflight_hi = true; + + // Can be enabled by connection option 'B2DL'. + bool use_bytes_delivered_for_inflight_hi = false; + + // Can be disabled by connection option 'B2RC'. + bool enable_reno_coexistence = true; + + // For experimentation to improve fast convergence upon loss. + enum QuicBandwidthLoMode : uint8_t { + DEFAULT = 0, + MIN_RTT_REDUCTION = 1, // 'BBQ7' + INFLIGHT_REDUCTION = 2, // 'BBQ8' + CWND_REDUCTION = 3, // 'BBQ9' + }; + + // Different modes change bandwidth_lo_ differently upon loss. + QuicBandwidthLoMode bw_lo_mode_ = QuicBandwidthLoMode::DEFAULT; + + // Set the pacing gain to 25% larger than the recent BW increase in STARTUP. + bool decrease_startup_pacing_at_end_of_round = false; +}; + +class QUIC_EXPORT_PRIVATE RoundTripCounter { + public: + RoundTripCounter(); + + QuicRoundTripCount Count() const { return round_trip_count_; } + + QuicPacketNumber last_sent_packet() const { return last_sent_packet_; } + + // Must be called in ascending packet number order. + void OnPacketSent(QuicPacketNumber packet_number); + + // Return whether a round trip has just completed. + bool OnPacketsAcked(QuicPacketNumber last_acked_packet); + + void RestartRound(); + + private: + QuicRoundTripCount round_trip_count_; + QuicPacketNumber last_sent_packet_; + // The last sent packet number of the current round trip. + QuicPacketNumber end_of_round_trip_; +}; + +class QUIC_EXPORT_PRIVATE MinRttFilter { + public: + MinRttFilter(QuicTime::Delta initial_min_rtt, + QuicTime initial_min_rtt_timestamp); + + void Update(QuicTime::Delta sample_rtt, QuicTime now); + + void ForceUpdate(QuicTime::Delta sample_rtt, QuicTime now); + + QuicTime::Delta Get() const { return min_rtt_; } + + QuicTime GetTimestamp() const { return min_rtt_timestamp_; } + + private: + QuicTime::Delta min_rtt_; + // Time when the current value of |min_rtt_| was assigned. + QuicTime min_rtt_timestamp_; +}; + +class QUIC_EXPORT_PRIVATE Bbr2MaxBandwidthFilter { + public: + void Update(QuicBandwidth sample) { + max_bandwidth_[1] = std::max(sample, max_bandwidth_[1]); + } + + void Advance() { + if (max_bandwidth_[1].IsZero()) { + return; + } + + max_bandwidth_[0] = max_bandwidth_[1]; + max_bandwidth_[1] = QuicBandwidth::Zero(); + } + + QuicBandwidth Get() const { + return std::max(max_bandwidth_[0], max_bandwidth_[1]); + } + + private: + QuicBandwidth max_bandwidth_[2] = {QuicBandwidth::Zero(), + QuicBandwidth::Zero()}; +}; + +// Information that are meaningful only when Bbr2Sender::OnCongestionEvent is +// running. +struct QUIC_EXPORT_PRIVATE Bbr2CongestionEvent { + QuicTime event_time = QuicTime::Zero(); + + // The congestion window prior to the processing of the ack/loss events. + QuicByteCount prior_cwnd; + + // Total bytes inflight before the processing of the ack/loss events. + QuicByteCount prior_bytes_in_flight = 0; + + // Total bytes inflight after the processing of the ack/loss events. + QuicByteCount bytes_in_flight = 0; + + // Total bytes acked from acks in this event. + QuicByteCount bytes_acked = 0; + + // Total bytes lost from losses in this event. + QuicByteCount bytes_lost = 0; + + // Whether acked_packets indicates the end of a round trip. + bool end_of_round_trip = false; + + // When the event happened, whether the sender is probing for bandwidth. + bool is_probing_for_bandwidth = false; + + // Minimum rtt of all bandwidth samples from acked_packets. + // QuicTime::Delta::Infinite() if acked_packets is empty. + QuicTime::Delta sample_min_rtt = QuicTime::Delta::Infinite(); + + // Maximum bandwidth of all bandwidth samples from acked_packets. + // This sample may be app-limited, and will be Zero() if there are no newly + // acknowledged inflight packets. + QuicBandwidth sample_max_bandwidth = QuicBandwidth::Zero(); + + // The send state of the largest packet in acked_packets, unless it is empty. + // If acked_packets is empty, it's the send state of the largest packet in + // lost_packets. + SendTimeState last_packet_send_state; +}; + +// Bbr2NetworkModel takes low level congestion signals(packets sent/acked/lost) +// as input and produces BBRv2 model parameters like inflight_(hi|lo), +// bandwidth_(hi|lo), bandwidth and rtt estimates, etc. +class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { + public: + Bbr2NetworkModel(const Bbr2Params* params, QuicTime::Delta initial_rtt, + QuicTime initial_rtt_timestamp, float cwnd_gain, + float pacing_gain, const BandwidthSampler* old_sampler); + + void OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData is_retransmittable); + + void OnCongestionEventStart(QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + Bbr2CongestionEvent* congestion_event); + + void OnCongestionEventFinish(QuicPacketNumber least_unacked_packet, + const Bbr2CongestionEvent& congestion_event); + + // Update the model without a congestion event. + // Min rtt is updated if |rtt| is non-zero and smaller than existing min rtt. + void UpdateNetworkParameters(QuicTime::Delta rtt); + + // Update inflight/bandwidth short-term lower bounds. + void AdaptLowerBounds(const Bbr2CongestionEvent& congestion_event); + + // Restart the current round trip as if it is starting now. + void RestartRoundEarly(); + + void AdvanceMaxBandwidthFilter() { max_bandwidth_filter_.Advance(); } + + void OnApplicationLimited() { bandwidth_sampler_.OnAppLimited(); } + + // Calculates BDP using the current MaxBandwidth. + QuicByteCount BDP() const { return BDP(MaxBandwidth()); } + + QuicByteCount BDP(QuicBandwidth bandwidth) const { + return bandwidth * MinRtt(); + } + + QuicByteCount BDP(QuicBandwidth bandwidth, float gain) const { + return bandwidth * MinRtt() * gain; + } + + QuicTime::Delta MinRtt() const { return min_rtt_filter_.Get(); } + + QuicTime MinRttTimestamp() const { return min_rtt_filter_.GetTimestamp(); } + + // TODO(wub): If we do this too frequently, we can potentailly postpone + // PROBE_RTT indefinitely. Observe how it works in production and improve it. + void PostponeMinRttTimestamp(QuicTime::Delta duration) { + min_rtt_filter_.ForceUpdate(MinRtt(), MinRttTimestamp() + duration); + } + + QuicBandwidth MaxBandwidth() const { return max_bandwidth_filter_.Get(); } + + QuicByteCount MaxAckHeight() const { + return bandwidth_sampler_.max_ack_height(); + } + + // 2 packets. Used to indicate the typical number of bytes ACKed at once. + QuicByteCount QueueingThresholdExtraBytes() const { + return 2 * kDefaultTCPMSS; + } + + bool cwnd_limited_before_aggregation_epoch() const { + return cwnd_limited_before_aggregation_epoch_; + } + + void EnableOverestimateAvoidance() { + bandwidth_sampler_.EnableOverestimateAvoidance(); + } + + bool IsBandwidthOverestimateAvoidanceEnabled() const { + return bandwidth_sampler_.IsOverestimateAvoidanceEnabled(); + } + + void OnPacketNeutered(QuicPacketNumber packet_number) { + bandwidth_sampler_.OnPacketNeutered(packet_number); + } + + uint64_t num_ack_aggregation_epochs() const { + return bandwidth_sampler_.num_ack_aggregation_epochs(); + } + + void SetStartNewAggregationEpochAfterFullRound(bool value) { + bandwidth_sampler_.SetStartNewAggregationEpochAfterFullRound(value); + } + + void SetLimitMaxAckHeightTrackerBySendRate(bool value) { + bandwidth_sampler_.SetLimitMaxAckHeightTrackerBySendRate(value); + } + + void SetMaxAckHeightTrackerWindowLength(QuicRoundTripCount value) { + bandwidth_sampler_.SetMaxAckHeightTrackerWindowLength(value); + } + + void SetReduceExtraAckedOnBandwidthIncrease(bool value) { + bandwidth_sampler_.SetReduceExtraAckedOnBandwidthIncrease(value); + } + + bool MaybeExpireMinRtt(const Bbr2CongestionEvent& congestion_event); + + QuicBandwidth BandwidthEstimate() const { + return std::min(MaxBandwidth(), bandwidth_lo_); + } + + QuicRoundTripCount RoundTripCount() const { + return round_trip_counter_.Count(); + } + + // Return true if the number of loss events exceeds max_loss_events and + // fraction of bytes lost exceed the loss threshold. + bool IsInflightTooHigh(const Bbr2CongestionEvent& congestion_event, + int64_t max_loss_events) const; + + // Check bandwidth growth in the past round. Must be called at the end of a + // round. Returns true if there was sufficient bandwidth growth and false + // otherwise. If it's been too many rounds without growth, also sets + // |full_bandwidth_reached_| to true. + bool HasBandwidthGrowth(const Bbr2CongestionEvent& congestion_event); + + // Increments rounds_with_queueing_ if the minimum bytes in flight during the + // round is greater than the BDP * |target_gain|. + void CheckPersistentQueue(const Bbr2CongestionEvent& congestion_event, + float target_gain); + + QuicPacketNumber last_sent_packet() const { + return round_trip_counter_.last_sent_packet(); + } + + QuicByteCount total_bytes_acked() const { + return bandwidth_sampler_.total_bytes_acked(); + } + + QuicByteCount total_bytes_lost() const { + return bandwidth_sampler_.total_bytes_lost(); + } + + QuicByteCount total_bytes_sent() const { + return bandwidth_sampler_.total_bytes_sent(); + } + + int64_t loss_events_in_round() const { return loss_events_in_round_; } + + QuicByteCount max_bytes_delivered_in_round() const { + return max_bytes_delivered_in_round_; + } + + QuicByteCount min_bytes_in_flight_in_round() const { + return min_bytes_in_flight_in_round_; + } + + bool inflight_hi_limited_in_round() const { + return inflight_hi_limited_in_round_; + } + + QuicPacketNumber end_of_app_limited_phase() const { + return bandwidth_sampler_.end_of_app_limited_phase(); + } + + QuicBandwidth bandwidth_latest() const { return bandwidth_latest_; } + QuicBandwidth bandwidth_lo() const { return bandwidth_lo_; } + static QuicBandwidth bandwidth_lo_default() { + return QuicBandwidth::Infinite(); + } + void clear_bandwidth_lo() { bandwidth_lo_ = bandwidth_lo_default(); } + + QuicByteCount inflight_latest() const { return inflight_latest_; } + QuicByteCount inflight_lo() const { return inflight_lo_; } + static QuicByteCount inflight_lo_default() { + return std::numeric_limits::max(); + } + void clear_inflight_lo() { inflight_lo_ = inflight_lo_default(); } + void cap_inflight_lo(QuicByteCount cap); + + QuicByteCount inflight_hi_with_headroom() const; + QuicByteCount inflight_hi() const { return inflight_hi_; } + static QuicByteCount inflight_hi_default() { + return std::numeric_limits::max(); + } + void set_inflight_hi(QuicByteCount inflight_hi) { + inflight_hi_ = inflight_hi; + } + + float cwnd_gain() const { return cwnd_gain_; } + void set_cwnd_gain(float cwnd_gain) { cwnd_gain_ = cwnd_gain; } + + float pacing_gain() const { return pacing_gain_; } + void set_pacing_gain(float pacing_gain) { pacing_gain_ = pacing_gain; } + + bool full_bandwidth_reached() const { return full_bandwidth_reached_; } + void set_full_bandwidth_reached() { full_bandwidth_reached_ = true; } + QuicBandwidth full_bandwidth_baseline() const { + return full_bandwidth_baseline_; + } + QuicRoundTripCount rounds_without_bandwidth_growth() const { + return rounds_without_bandwidth_growth_; + } + QuicRoundTripCount rounds_with_queueing() const { + return rounds_with_queueing_; + } + + private: + // Called when a new round trip starts. + void OnNewRound(); + + const Bbr2Params& Params() const { return *params_; } + const Bbr2Params* const params_; + RoundTripCounter round_trip_counter_; + + // Bandwidth sampler provides BBR with the bandwidth measurements at + // individual points. + BandwidthSampler bandwidth_sampler_; + // The filter that tracks the maximum bandwidth over multiple recent round + // trips. + Bbr2MaxBandwidthFilter max_bandwidth_filter_; + MinRttFilter min_rtt_filter_; + + // Bytes lost in the current round. Updated once per congestion event. + QuicByteCount bytes_lost_in_round_ = 0; + // Number of loss marking events in the current round. + int64_t loss_events_in_round_ = 0; + + // A max of bytes delivered among all congestion events in the current round. + // A congestions event's bytes delivered is the total bytes acked between time + // Ts and Ta, which is the time when the largest acked packet(within the + // congestion event) was sent and acked, respectively. + QuicByteCount max_bytes_delivered_in_round_ = 0; + + // The minimum bytes in flight during this round. + QuicByteCount min_bytes_in_flight_in_round_ = + std::numeric_limits::max(); + + // True if sending was limited by inflight_hi anytime in the current round. + bool inflight_hi_limited_in_round_ = false; + + // Max bandwidth in the current round. Updated once per congestion event. + QuicBandwidth bandwidth_latest_ = QuicBandwidth::Zero(); + // Max bandwidth of recent rounds. Updated once per round. + QuicBandwidth bandwidth_lo_ = bandwidth_lo_default(); + // bandwidth_lo_ at the beginning of a round with loss. Only used when the + // bw_lo_mode is non-default. + QuicBandwidth prior_bandwidth_lo_ = QuicBandwidth::Zero(); + + // Max inflight in the current round. Updated once per congestion event. + QuicByteCount inflight_latest_ = 0; + // Max inflight of recent rounds. Updated once per round. + QuicByteCount inflight_lo_ = inflight_lo_default(); + QuicByteCount inflight_hi_ = inflight_hi_default(); + + float cwnd_gain_; + float pacing_gain_; + + // Whether we are cwnd limited prior to the start of the current aggregation + // epoch. + bool cwnd_limited_before_aggregation_epoch_ = false; + + // STARTUP-centric fields which experimentally used by PROBE_UP. + bool full_bandwidth_reached_ = false; + QuicBandwidth full_bandwidth_baseline_ = QuicBandwidth::Zero(); + QuicRoundTripCount rounds_without_bandwidth_growth_ = 0; + + // Used by STARTUP and PROBE_UP to decide when to exit. + QuicRoundTripCount rounds_with_queueing_ = 0; +}; + +enum class Bbr2Mode : uint8_t { + // Startup phase of the connection. + STARTUP, + // After achieving the highest possible bandwidth during the startup, lower + // the pacing rate in order to drain the queue. + DRAIN, + // Cruising mode. + PROBE_BW, + // Temporarily slow down sending in order to empty the buffer and measure + // the real minimum RTT. + PROBE_RTT, +}; + +QUIC_EXPORT_PRIVATE inline std::ostream& operator<<(std::ostream& os, + const Bbr2Mode& mode) { + switch (mode) { + case Bbr2Mode::STARTUP: + return os << "STARTUP"; + case Bbr2Mode::DRAIN: + return os << "DRAIN"; + case Bbr2Mode::PROBE_BW: + return os << "PROBE_BW"; + case Bbr2Mode::PROBE_RTT: + return os << "PROBE_RTT"; + } + return os << ""; +} + +// The base class for all BBRv2 modes. A Bbr2Sender is in one mode at a time, +// this interface is used to implement mode-specific behaviors. +class Bbr2Sender; +class QUIC_EXPORT_PRIVATE Bbr2ModeBase { + public: + Bbr2ModeBase(const Bbr2Sender* sender, Bbr2NetworkModel* model) + : sender_(sender), model_(model) {} + + virtual ~Bbr2ModeBase() = default; + + // Called when entering/leaving this mode. + // congestion_event != nullptr means BBRv2 is switching modes in the context + // of a ack and/or loss. + virtual void Enter(QuicTime now, + const Bbr2CongestionEvent* congestion_event) = 0; + virtual void Leave(QuicTime now, + const Bbr2CongestionEvent* congestion_event) = 0; + + virtual Bbr2Mode OnCongestionEvent( + QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + const Bbr2CongestionEvent& congestion_event) = 0; + + virtual Limits GetCwndLimits() const = 0; + + virtual bool IsProbingForBandwidth() const = 0; + + virtual Bbr2Mode OnExitQuiescence(QuicTime now, + QuicTime quiescence_start_time) = 0; + + protected: + const Bbr2Sender* const sender_; + Bbr2NetworkModel* model_; +}; + +QUIC_EXPORT_PRIVATE inline QuicByteCount BytesInFlight( + const SendTimeState& send_state) { + QUICHE_DCHECK(send_state.is_valid); + if (send_state.bytes_in_flight != 0) { + return send_state.bytes_in_flight; + } + return send_state.total_bytes_sent - send_state.total_bytes_acked - + send_state.total_bytes_lost; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_MISC_H_ diff --git a/quiche/quic/core/congestion_control/bbr2_probe_bw.cc b/quiche/quic/core/congestion_control/bbr2_probe_bw.cc new file mode 100644 index 000000000000..cb07b9af71ed --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_probe_bw.cc @@ -0,0 +1,653 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr2_probe_bw.h" + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/congestion_control/bbr2_sender.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +void Bbr2ProbeBwMode::Enter(QuicTime now, + const Bbr2CongestionEvent* /*congestion_event*/) { + if (cycle_.phase == CyclePhase::PROBE_NOT_STARTED) { + // First time entering PROBE_BW. Start a new probing cycle. + EnterProbeDown(/*probed_too_high=*/false, /*stopped_risky_probe=*/false, + now); + } else { + // Transitioning from PROBE_RTT to PROBE_BW. Re-enter the last phase before + // PROBE_RTT. + QUICHE_DCHECK(cycle_.phase == CyclePhase::PROBE_CRUISE || + cycle_.phase == CyclePhase::PROBE_REFILL); + cycle_.cycle_start_time = now; + if (cycle_.phase == CyclePhase::PROBE_CRUISE) { + EnterProbeCruise(now); + } else if (cycle_.phase == CyclePhase::PROBE_REFILL) { + EnterProbeRefill(cycle_.probe_up_rounds, now); + } + } +} + +Bbr2Mode Bbr2ProbeBwMode::OnCongestionEvent( + QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& /*acked_packets*/, + const LostPacketVector& /*lost_packets*/, + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK_NE(cycle_.phase, CyclePhase::PROBE_NOT_STARTED); + + if (congestion_event.end_of_round_trip) { + if (cycle_.cycle_start_time != event_time) { + ++cycle_.rounds_since_probe; + } + if (cycle_.phase_start_time != event_time) { + ++cycle_.rounds_in_phase; + } + } + + bool switch_to_probe_rtt = false; + + if (cycle_.phase == CyclePhase::PROBE_UP) { + UpdateProbeUp(prior_in_flight, congestion_event); + } else if (cycle_.phase == CyclePhase::PROBE_DOWN) { + UpdateProbeDown(prior_in_flight, congestion_event); + // Maybe transition to PROBE_RTT at the end of this cycle. + if (cycle_.phase != CyclePhase::PROBE_DOWN && + model_->MaybeExpireMinRtt(congestion_event)) { + switch_to_probe_rtt = true; + } + } else if (cycle_.phase == CyclePhase::PROBE_CRUISE) { + UpdateProbeCruise(congestion_event); + } else if (cycle_.phase == CyclePhase::PROBE_REFILL) { + UpdateProbeRefill(congestion_event); + } + + // Do not need to set the gains if switching to PROBE_RTT, they will be set + // when Bbr2ProbeRttMode::Enter is called. + if (!switch_to_probe_rtt) { + model_->set_pacing_gain(PacingGainForPhase(cycle_.phase)); + model_->set_cwnd_gain(Params().probe_bw_cwnd_gain); + } + + return switch_to_probe_rtt ? Bbr2Mode::PROBE_RTT : Bbr2Mode::PROBE_BW; +} + +Limits Bbr2ProbeBwMode::GetCwndLimits() const { + if (cycle_.phase == CyclePhase::PROBE_CRUISE) { + return NoGreaterThan( + std::min(model_->inflight_lo(), model_->inflight_hi_with_headroom())); + } + if (Params().probe_up_ignore_inflight_hi && + cycle_.phase == CyclePhase::PROBE_UP) { + // Similar to STARTUP. + return NoGreaterThan(model_->inflight_lo()); + } + + return NoGreaterThan(std::min(model_->inflight_lo(), model_->inflight_hi())); +} + +bool Bbr2ProbeBwMode::IsProbingForBandwidth() const { + return cycle_.phase == CyclePhase::PROBE_REFILL || + cycle_.phase == CyclePhase::PROBE_UP; +} + +Bbr2Mode Bbr2ProbeBwMode::OnExitQuiescence(QuicTime now, + QuicTime quiescence_start_time) { + QUIC_DVLOG(3) << sender_ << " Postponing min_rtt_timestamp(" + << model_->MinRttTimestamp() << ") by " + << now - quiescence_start_time; + model_->PostponeMinRttTimestamp(now - quiescence_start_time); + return Bbr2Mode::PROBE_BW; +} + +// TODO(ianswett): Remove prior_in_flight from UpdateProbeDown. +void Bbr2ProbeBwMode::UpdateProbeDown( + QuicByteCount prior_in_flight, + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_DOWN); + + if (cycle_.rounds_in_phase == 1 && congestion_event.end_of_round_trip) { + cycle_.is_sample_from_probing = false; + + if (!congestion_event.last_packet_send_state.is_app_limited) { + QUIC_DVLOG(2) + << sender_ + << " Advancing max bw filter after one round in PROBE_DOWN."; + model_->AdvanceMaxBandwidthFilter(); + cycle_.has_advanced_max_bw = true; + } + + if (last_cycle_stopped_risky_probe_ && !last_cycle_probed_too_high_) { + EnterProbeRefill(/*probe_up_rounds=*/0, congestion_event.event_time); + return; + } + } + + MaybeAdaptUpperBounds(congestion_event); + + if (IsTimeToProbeBandwidth(congestion_event)) { + EnterProbeRefill(/*probe_up_rounds=*/0, congestion_event.event_time); + return; + } + + if (HasStayedLongEnoughInProbeDown(congestion_event)) { + QUIC_DVLOG(3) << sender_ << " Proportional time based PROBE_DOWN exit"; + EnterProbeCruise(congestion_event.event_time); + return; + } + + const QuicByteCount inflight_with_headroom = + model_->inflight_hi_with_headroom(); + QUIC_DVLOG(3) + << sender_ + << " Checking if have enough inflight headroom. prior_in_flight:" + << prior_in_flight << " congestion_event.bytes_in_flight:" + << congestion_event.bytes_in_flight + << ", inflight_with_headroom:" << inflight_with_headroom; + QuicByteCount bytes_in_flight = congestion_event.bytes_in_flight; + + if (bytes_in_flight > inflight_with_headroom) { + // Stay in PROBE_DOWN. + return; + } + + // Transition to PROBE_CRUISE iff we've drained to target. + QuicByteCount bdp = model_->BDP(); + QUIC_DVLOG(3) << sender_ << " Checking if drained to target. bytes_in_flight:" + << bytes_in_flight << ", bdp:" << bdp; + if (bytes_in_flight < bdp) { + EnterProbeCruise(congestion_event.event_time); + } +} + +Bbr2ProbeBwMode::AdaptUpperBoundsResult Bbr2ProbeBwMode::MaybeAdaptUpperBounds( + const Bbr2CongestionEvent& congestion_event) { + const SendTimeState& send_state = congestion_event.last_packet_send_state; + if (!send_state.is_valid) { + QUIC_DVLOG(3) << sender_ << " " << cycle_.phase + << ": NOT_ADAPTED_INVALID_SAMPLE"; + return NOT_ADAPTED_INVALID_SAMPLE; + } + + // TODO(ianswett): Rename to bytes_delivered if + // use_bytes_delivered_for_inflight_hi is default enabled. + QuicByteCount inflight_at_send = BytesInFlight(send_state); + if (Params().use_bytes_delivered_for_inflight_hi) { + if (congestion_event.last_packet_send_state.total_bytes_acked <= + model_->total_bytes_acked()) { + inflight_at_send = + model_->total_bytes_acked() - + congestion_event.last_packet_send_state.total_bytes_acked; + } else { + QUIC_BUG(quic_bug_10436_1) + << "Total_bytes_acked(" << model_->total_bytes_acked() + << ") < send_state.total_bytes_acked(" + << congestion_event.last_packet_send_state.total_bytes_acked << ")"; + } + } + // TODO(ianswett): Inflight too high is really checking for loss, not + // inflight. + if (model_->IsInflightTooHigh(congestion_event, + Params().probe_bw_full_loss_count)) { + if (cycle_.is_sample_from_probing) { + cycle_.is_sample_from_probing = false; + if (!send_state.is_app_limited || + Params().max_probe_up_queue_rounds > 0) { + const QuicByteCount inflight_target = + sender_->GetTargetBytesInflight() * (1.0 - Params().beta); + if (inflight_at_send >= inflight_target) { + // The new code does not change behavior. + QUIC_CODE_COUNT(quic_bbr2_cut_inflight_hi_gradually_noop); + } else { + // The new code actually cuts inflight_hi slower than before. + QUIC_CODE_COUNT(quic_bbr2_cut_inflight_hi_gradually_in_effect); + } + if (Params().limit_inflight_hi_by_max_delivered) { + QuicByteCount new_inflight_hi = + std::max(inflight_at_send, inflight_target); + if (new_inflight_hi >= model_->max_bytes_delivered_in_round()) { + QUIC_CODE_COUNT(quic_bbr2_cut_inflight_hi_max_delivered_noop); + } else { + QUIC_CODE_COUNT(quic_bbr2_cut_inflight_hi_max_delivered_in_effect); + new_inflight_hi = model_->max_bytes_delivered_in_round(); + } + QUIC_DVLOG(3) << sender_ + << " Setting inflight_hi due to loss. new_inflight_hi:" + << new_inflight_hi + << ", inflight_at_send:" << inflight_at_send + << ", inflight_target:" << inflight_target + << ", max_bytes_delivered_in_round:" + << model_->max_bytes_delivered_in_round() << " @ " + << congestion_event.event_time; + model_->set_inflight_hi(new_inflight_hi); + } else { + model_->set_inflight_hi(std::max(inflight_at_send, inflight_target)); + } + } + + QUIC_DVLOG(3) << sender_ << " " << cycle_.phase + << ": ADAPTED_PROBED_TOO_HIGH"; + return ADAPTED_PROBED_TOO_HIGH; + } + return ADAPTED_OK; + } + + if (model_->inflight_hi() == model_->inflight_hi_default()) { + QUIC_DVLOG(3) << sender_ << " " << cycle_.phase + << ": NOT_ADAPTED_INFLIGHT_HIGH_NOT_SET"; + return NOT_ADAPTED_INFLIGHT_HIGH_NOT_SET; + } + + // Raise the upper bound for inflight. + if (inflight_at_send > model_->inflight_hi()) { + QUIC_DVLOG(3) + << sender_ << " " << cycle_.phase + << ": Adapting inflight_hi from inflight_at_send. inflight_at_send:" + << inflight_at_send << ", old inflight_hi:" << model_->inflight_hi(); + model_->set_inflight_hi(inflight_at_send); + } + + return ADAPTED_OK; +} + +bool Bbr2ProbeBwMode::IsTimeToProbeBandwidth( + const Bbr2CongestionEvent& congestion_event) const { + if (HasCycleLasted(cycle_.probe_wait_time, congestion_event)) { + return true; + } + + if (IsTimeToProbeForRenoCoexistence(1.0, congestion_event)) { + ++sender_->connection_stats_->bbr_num_short_cycles_for_reno_coexistence; + return true; + } + return false; +} + +// QUIC only. Used to prevent a Bbr2 flow from staying in PROBE_DOWN for too +// long, as seen in some multi-sender simulator tests. +bool Bbr2ProbeBwMode::HasStayedLongEnoughInProbeDown( + const Bbr2CongestionEvent& congestion_event) const { + // Stay in PROBE_DOWN for at most the time of a min rtt, as it is done in + // BBRv1. + // TODO(wub): Consider exit after a full round instead, which typically + // indicates most(if not all) packets sent during PROBE_UP have been acked. + return HasPhaseLasted(model_->MinRtt(), congestion_event); +} + +bool Bbr2ProbeBwMode::HasCycleLasted( + QuicTime::Delta duration, + const Bbr2CongestionEvent& congestion_event) const { + bool result = + (congestion_event.event_time - cycle_.cycle_start_time) > duration; + QUIC_DVLOG(3) << sender_ << " " << cycle_.phase + << ": HasCycleLasted=" << result << ". elapsed:" + << (congestion_event.event_time - cycle_.cycle_start_time) + << ", duration:" << duration; + return result; +} + +bool Bbr2ProbeBwMode::HasPhaseLasted( + QuicTime::Delta duration, + const Bbr2CongestionEvent& congestion_event) const { + bool result = + (congestion_event.event_time - cycle_.phase_start_time) > duration; + QUIC_DVLOG(3) << sender_ << " " << cycle_.phase + << ": HasPhaseLasted=" << result << ". elapsed:" + << (congestion_event.event_time - cycle_.phase_start_time) + << ", duration:" << duration; + return result; +} + +bool Bbr2ProbeBwMode::IsTimeToProbeForRenoCoexistence( + double probe_wait_fraction, + const Bbr2CongestionEvent& /*congestion_event*/) const { + if (!Params().enable_reno_coexistence) { + return false; + } + + uint64_t rounds = Params().probe_bw_probe_max_rounds; + if (Params().probe_bw_probe_reno_gain > 0.0) { + QuicByteCount target_bytes_inflight = sender_->GetTargetBytesInflight(); + uint64_t reno_rounds = Params().probe_bw_probe_reno_gain * + target_bytes_inflight / kDefaultTCPMSS; + rounds = std::min(rounds, reno_rounds); + } + bool result = cycle_.rounds_since_probe >= (rounds * probe_wait_fraction); + QUIC_DVLOG(3) << sender_ << " " << cycle_.phase + << ": IsTimeToProbeForRenoCoexistence=" << result + << ". rounds_since_probe:" << cycle_.rounds_since_probe + << ", rounds:" << rounds + << ", probe_wait_fraction:" << probe_wait_fraction; + return result; +} + +void Bbr2ProbeBwMode::RaiseInflightHighSlope() { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_UP); + uint64_t growth_this_round = 1 << cycle_.probe_up_rounds; + // The number 30 below means |growth_this_round| is capped at 1G and the lower + // bound of |probe_up_bytes| is (practically) 1 mss, at this speed inflight_hi + // grows by approximately 1 packet per packet acked. + cycle_.probe_up_rounds = std::min(cycle_.probe_up_rounds + 1, 30); + uint64_t probe_up_bytes = sender_->GetCongestionWindow() / growth_this_round; + cycle_.probe_up_bytes = + std::max(probe_up_bytes, kDefaultTCPMSS); + QUIC_DVLOG(3) << sender_ << " Rasing inflight_hi slope. probe_up_rounds:" + << cycle_.probe_up_rounds + << ", probe_up_bytes:" << cycle_.probe_up_bytes; +} + +void Bbr2ProbeBwMode::ProbeInflightHighUpward( + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_UP); + if (Params().probe_up_ignore_inflight_hi) { + // When inflight_hi is disabled in PROBE_UP, it increases when + // the number of bytes delivered in a round is larger inflight_hi. + return; + } + if (Params().probe_up_simplify_inflight_hi) { + // Raise inflight_hi exponentially if it was utilized this round. + cycle_.probe_up_acked += congestion_event.bytes_acked; + if (!congestion_event.end_of_round_trip) { + return; + } + if (!model_->inflight_hi_limited_in_round() || + model_->loss_events_in_round() > 0) { + cycle_.probe_up_acked = 0; + return; + } + } else { + if (congestion_event.prior_bytes_in_flight < congestion_event.prior_cwnd) { + QUIC_DVLOG(3) << sender_ + << " Raising inflight_hi early return: Not cwnd limited."; + // Not fully utilizing cwnd, so can't safely grow. + return; + } + + if (congestion_event.prior_cwnd < model_->inflight_hi()) { + QUIC_DVLOG(3) + << sender_ + << " Raising inflight_hi early return: inflight_hi not fully used."; + // Not fully using inflight_hi, so don't grow it. + return; + } + + // Increase inflight_hi by the number of probe_up_bytes within + // probe_up_acked. + cycle_.probe_up_acked += congestion_event.bytes_acked; + } + + if (cycle_.probe_up_acked >= cycle_.probe_up_bytes) { + uint64_t delta = cycle_.probe_up_acked / cycle_.probe_up_bytes; + cycle_.probe_up_acked -= delta * cycle_.probe_up_bytes; + QuicByteCount new_inflight_hi = + model_->inflight_hi() + delta * kDefaultTCPMSS; + if (new_inflight_hi > model_->inflight_hi()) { + QUIC_DVLOG(3) << sender_ << " Raising inflight_hi from " + << model_->inflight_hi() << " to " << new_inflight_hi + << ". probe_up_bytes:" << cycle_.probe_up_bytes + << ", delta:" << delta + << ", (new)probe_up_acked:" << cycle_.probe_up_acked; + + model_->set_inflight_hi(new_inflight_hi); + } else { + QUIC_BUG(quic_bug_10436_2) + << "Not growing inflight_hi due to wrap around. Old value:" + << model_->inflight_hi() << ", new value:" << new_inflight_hi; + } + } + + if (congestion_event.end_of_round_trip) { + RaiseInflightHighSlope(); + } +} + +void Bbr2ProbeBwMode::UpdateProbeCruise( + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_CRUISE); + MaybeAdaptUpperBounds(congestion_event); + QUICHE_DCHECK(!cycle_.is_sample_from_probing); + + if (IsTimeToProbeBandwidth(congestion_event)) { + EnterProbeRefill(/*probe_up_rounds=*/0, congestion_event.event_time); + return; + } +} + +void Bbr2ProbeBwMode::UpdateProbeRefill( + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_REFILL); + MaybeAdaptUpperBounds(congestion_event); + QUICHE_DCHECK(!cycle_.is_sample_from_probing); + + if (cycle_.rounds_in_phase > 0 && congestion_event.end_of_round_trip) { + EnterProbeUp(congestion_event.event_time); + return; + } +} + +void Bbr2ProbeBwMode::UpdateProbeUp( + QuicByteCount prior_in_flight, + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_UP); + if (MaybeAdaptUpperBounds(congestion_event) == ADAPTED_PROBED_TOO_HIGH) { + EnterProbeDown(/*probed_too_high=*/true, /*stopped_risky_probe=*/false, + congestion_event.event_time); + return; + } + + // TODO(wub): Consider exit PROBE_UP after a certain number(e.g. 64) of RTTs. + + ProbeInflightHighUpward(congestion_event); + + bool is_risky = false; + bool is_queuing = false; + if (last_cycle_probed_too_high_ && prior_in_flight >= model_->inflight_hi()) { + is_risky = true; + QUIC_DVLOG(3) << sender_ + << " Probe is too risky. last_cycle_probed_too_high_:" + << last_cycle_probed_too_high_ + << ", prior_in_flight:" << prior_in_flight + << ", inflight_hi:" << model_->inflight_hi(); + // TCP uses min_rtt instead of a full round: + // HasPhaseLasted(model_->MinRtt(), congestion_event) + } else if (cycle_.rounds_in_phase > 0) { + if (Params().max_probe_up_queue_rounds > 0) { + if (congestion_event.end_of_round_trip) { + model_->CheckPersistentQueue(congestion_event, + Params().full_bw_threshold); + if (model_->rounds_with_queueing() >= + Params().max_probe_up_queue_rounds) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_probe_two_rounds, 3, 3); + is_queuing = true; + } + } + } else { + QuicByteCount queuing_threshold_extra_bytes = + model_->QueueingThresholdExtraBytes(); + if (Params().add_ack_height_to_queueing_threshold) { + queuing_threshold_extra_bytes += model_->MaxAckHeight(); + } + QuicByteCount queuing_threshold = + (Params().full_bw_threshold * model_->BDP()) + + queuing_threshold_extra_bytes; + + is_queuing = congestion_event.bytes_in_flight >= queuing_threshold; + + QUIC_DVLOG(3) << sender_ + << " Checking if building up a queue. prior_in_flight:" + << prior_in_flight + << ", post_in_flight:" << congestion_event.bytes_in_flight + << ", threshold:" << queuing_threshold + << ", is_queuing:" << is_queuing + << ", max_bw:" << model_->MaxBandwidth() + << ", min_rtt:" << model_->MinRtt(); + } + } + + if (is_risky || is_queuing) { + EnterProbeDown(/*probed_too_high=*/false, /*stopped_risky_probe=*/is_risky, + congestion_event.event_time); + } +} + +void Bbr2ProbeBwMode::EnterProbeDown(bool probed_too_high, + bool stopped_risky_probe, QuicTime now) { + QUIC_DVLOG(2) << sender_ << " Phase change: " << cycle_.phase << " ==> " + << CyclePhase::PROBE_DOWN << " after " + << now - cycle_.phase_start_time << ", or " + << cycle_.rounds_in_phase + << " rounds. probed_too_high:" << probed_too_high + << ", stopped_risky_probe:" << stopped_risky_probe << " @ " + << now; + last_cycle_probed_too_high_ = probed_too_high; + last_cycle_stopped_risky_probe_ = stopped_risky_probe; + + cycle_.cycle_start_time = now; + cycle_.phase = CyclePhase::PROBE_DOWN; + cycle_.rounds_in_phase = 0; + cycle_.phase_start_time = now; + ++sender_->connection_stats_->bbr_num_cycles; + if (Params().bw_lo_mode_ != Bbr2Params::QuicBandwidthLoMode::DEFAULT) { + // Clear bandwidth lo if it was set in PROBE_UP, because losses in PROBE_UP + // should not permanently change bandwidth_lo. + // It's possible for bandwidth_lo to be set during REFILL, but if that was + // a valid value, it'll quickly be rediscovered. + model_->clear_bandwidth_lo(); + } + + // Pick probe wait time. + cycle_.rounds_since_probe = + sender_->RandomUint64(Params().probe_bw_max_probe_rand_rounds); + cycle_.probe_wait_time = + Params().probe_bw_probe_base_duration + + QuicTime::Delta::FromMicroseconds(sender_->RandomUint64( + Params().probe_bw_probe_max_rand_duration.ToMicroseconds())); + + cycle_.probe_up_bytes = std::numeric_limits::max(); + cycle_.probe_up_app_limited_since_inflight_hi_limited_ = false; + cycle_.has_advanced_max_bw = false; + model_->RestartRoundEarly(); +} + +void Bbr2ProbeBwMode::EnterProbeCruise(QuicTime now) { + if (cycle_.phase == CyclePhase::PROBE_DOWN) { + ExitProbeDown(); + } + QUIC_DVLOG(2) << sender_ << " Phase change: " << cycle_.phase << " ==> " + << CyclePhase::PROBE_CRUISE << " after " + << now - cycle_.phase_start_time << ", or " + << cycle_.rounds_in_phase << " rounds. @ " << now; + + model_->cap_inflight_lo(model_->inflight_hi()); + cycle_.phase = CyclePhase::PROBE_CRUISE; + cycle_.rounds_in_phase = 0; + cycle_.phase_start_time = now; + cycle_.is_sample_from_probing = false; +} + +void Bbr2ProbeBwMode::EnterProbeRefill(uint64_t probe_up_rounds, QuicTime now) { + if (cycle_.phase == CyclePhase::PROBE_DOWN) { + ExitProbeDown(); + } + QUIC_DVLOG(2) << sender_ << " Phase change: " << cycle_.phase << " ==> " + << CyclePhase::PROBE_REFILL << " after " + << now - cycle_.phase_start_time << ", or " + << cycle_.rounds_in_phase + << " rounds. probe_up_rounds:" << probe_up_rounds << " @ " + << now; + cycle_.phase = CyclePhase::PROBE_REFILL; + cycle_.rounds_in_phase = 0; + cycle_.phase_start_time = now; + cycle_.is_sample_from_probing = false; + last_cycle_stopped_risky_probe_ = false; + + model_->clear_bandwidth_lo(); + model_->clear_inflight_lo(); + cycle_.probe_up_rounds = probe_up_rounds; + cycle_.probe_up_acked = 0; + model_->RestartRoundEarly(); +} + +void Bbr2ProbeBwMode::EnterProbeUp(QuicTime now) { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_REFILL); + QUIC_DVLOG(2) << sender_ << " Phase change: " << cycle_.phase << " ==> " + << CyclePhase::PROBE_UP << " after " + << now - cycle_.phase_start_time << ", or " + << cycle_.rounds_in_phase << " rounds. @ " << now; + cycle_.phase = CyclePhase::PROBE_UP; + cycle_.rounds_in_phase = 0; + cycle_.phase_start_time = now; + cycle_.is_sample_from_probing = true; + RaiseInflightHighSlope(); + + model_->RestartRoundEarly(); +} + +void Bbr2ProbeBwMode::ExitProbeDown() { + QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_DOWN); + if (!cycle_.has_advanced_max_bw) { + QUIC_DVLOG(2) << sender_ << " Advancing max bw filter at end of cycle."; + model_->AdvanceMaxBandwidthFilter(); + cycle_.has_advanced_max_bw = true; + } +} + +// static +const char* Bbr2ProbeBwMode::CyclePhaseToString(CyclePhase phase) { + switch (phase) { + case CyclePhase::PROBE_NOT_STARTED: + return "PROBE_NOT_STARTED"; + case CyclePhase::PROBE_UP: + return "PROBE_UP"; + case CyclePhase::PROBE_DOWN: + return "PROBE_DOWN"; + case CyclePhase::PROBE_CRUISE: + return "PROBE_CRUISE"; + case CyclePhase::PROBE_REFILL: + return "PROBE_REFILL"; + default: + break; + } + return ""; +} + +std::ostream& operator<<(std::ostream& os, + const Bbr2ProbeBwMode::CyclePhase phase) { + return os << Bbr2ProbeBwMode::CyclePhaseToString(phase); +} + +Bbr2ProbeBwMode::DebugState Bbr2ProbeBwMode::ExportDebugState() const { + DebugState s; + s.phase = cycle_.phase; + s.cycle_start_time = cycle_.cycle_start_time; + s.phase_start_time = cycle_.phase_start_time; + return s; +} + +std::ostream& operator<<(std::ostream& os, + const Bbr2ProbeBwMode::DebugState& state) { + os << "[PROBE_BW] phase: " << state.phase << "\n"; + os << "[PROBE_BW] cycle_start_time: " << state.cycle_start_time << "\n"; + os << "[PROBE_BW] phase_start_time: " << state.phase_start_time << "\n"; + return os; +} + +const Bbr2Params& Bbr2ProbeBwMode::Params() const { return sender_->Params(); } + +float Bbr2ProbeBwMode::PacingGainForPhase( + Bbr2ProbeBwMode::CyclePhase phase) const { + if (phase == Bbr2ProbeBwMode::CyclePhase::PROBE_UP) { + return Params().probe_bw_probe_up_pacing_gain; + } + if (phase == Bbr2ProbeBwMode::CyclePhase::PROBE_DOWN) { + return Params().probe_bw_probe_down_pacing_gain; + } + return Params().probe_bw_default_pacing_gain; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_probe_bw.h b/quiche/quic/core/congestion_control/bbr2_probe_bw.h new file mode 100644 index 000000000000..7a2448d15534 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_probe_bw.h @@ -0,0 +1,138 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_PROBE_BW_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_PROBE_BW_H_ + +#include + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +class Bbr2Sender; +class QUIC_EXPORT_PRIVATE Bbr2ProbeBwMode final : public Bbr2ModeBase { + public: + using Bbr2ModeBase::Bbr2ModeBase; + + void Enter(QuicTime now, + const Bbr2CongestionEvent* congestion_event) override; + void Leave(QuicTime /*now*/, + const Bbr2CongestionEvent* /*congestion_event*/) override {} + + Bbr2Mode OnCongestionEvent( + QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + const Bbr2CongestionEvent& congestion_event) override; + + Limits GetCwndLimits() const override; + + bool IsProbingForBandwidth() const override; + + Bbr2Mode OnExitQuiescence(QuicTime now, + QuicTime quiescence_start_time) override; + + enum class CyclePhase : uint8_t { + PROBE_NOT_STARTED, + PROBE_UP, + PROBE_DOWN, + PROBE_CRUISE, + PROBE_REFILL, + }; + + static const char* CyclePhaseToString(CyclePhase phase); + + struct QUIC_EXPORT_PRIVATE DebugState { + CyclePhase phase; + QuicTime cycle_start_time = QuicTime::Zero(); + QuicTime phase_start_time = QuicTime::Zero(); + }; + + DebugState ExportDebugState() const; + + private: + const Bbr2Params& Params() const; + float PacingGainForPhase(CyclePhase phase) const; + + void UpdateProbeUp(QuicByteCount prior_in_flight, + const Bbr2CongestionEvent& congestion_event); + void UpdateProbeDown(QuicByteCount prior_in_flight, + const Bbr2CongestionEvent& congestion_event); + void UpdateProbeCruise(const Bbr2CongestionEvent& congestion_event); + void UpdateProbeRefill(const Bbr2CongestionEvent& congestion_event); + + enum AdaptUpperBoundsResult : uint8_t { + ADAPTED_OK, + ADAPTED_PROBED_TOO_HIGH, + NOT_ADAPTED_INFLIGHT_HIGH_NOT_SET, + NOT_ADAPTED_INVALID_SAMPLE, + }; + + // Return whether adapted inflight_hi. If inflight is too high, this function + // will not adapt inflight_hi and will return false. + AdaptUpperBoundsResult MaybeAdaptUpperBounds( + const Bbr2CongestionEvent& congestion_event); + + void EnterProbeDown(bool probed_too_high, bool stopped_risky_probe, + QuicTime now); + void EnterProbeCruise(QuicTime now); + void EnterProbeRefill(uint64_t probe_up_rounds, QuicTime now); + void EnterProbeUp(QuicTime now); + + // Call right before the exit of PROBE_DOWN. + void ExitProbeDown(); + + float PercentTimeElapsedToProbeBandwidth( + const Bbr2CongestionEvent& congestion_event) const; + + bool IsTimeToProbeBandwidth( + const Bbr2CongestionEvent& congestion_event) const; + bool HasStayedLongEnoughInProbeDown( + const Bbr2CongestionEvent& congestion_event) const; + bool HasCycleLasted(QuicTime::Delta duration, + const Bbr2CongestionEvent& congestion_event) const; + bool HasPhaseLasted(QuicTime::Delta duration, + const Bbr2CongestionEvent& congestion_event) const; + bool IsTimeToProbeForRenoCoexistence( + double probe_wait_fraction, + const Bbr2CongestionEvent& congestion_event) const; + + void RaiseInflightHighSlope(); + void ProbeInflightHighUpward(const Bbr2CongestionEvent& congestion_event); + + struct QUIC_EXPORT_PRIVATE Cycle { + QuicTime cycle_start_time = QuicTime::Zero(); + CyclePhase phase = CyclePhase::PROBE_NOT_STARTED; + uint64_t rounds_in_phase = 0; + QuicTime phase_start_time = QuicTime::Zero(); + QuicRoundTripCount rounds_since_probe = 0; + QuicTime::Delta probe_wait_time = QuicTime::Delta::Zero(); + uint64_t probe_up_rounds = 0; + QuicByteCount probe_up_bytes = std::numeric_limits::max(); + QuicByteCount probe_up_acked = 0; + bool probe_up_app_limited_since_inflight_hi_limited_ = false; + // Whether max bandwidth filter window has advanced in this cycle. It is + // advanced once per cycle. + bool has_advanced_max_bw = false; + bool is_sample_from_probing = false; + } cycle_; + + bool last_cycle_probed_too_high_; + bool last_cycle_stopped_risky_probe_; +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const Bbr2ProbeBwMode::DebugState& state); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const Bbr2ProbeBwMode::CyclePhase phase); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_PROBE_BW_H_ diff --git a/quiche/quic/core/congestion_control/bbr2_probe_rtt.cc b/quiche/quic/core/congestion_control/bbr2_probe_rtt.cc new file mode 100644 index 000000000000..f425fd4efac8 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_probe_rtt.cc @@ -0,0 +1,79 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr2_probe_rtt.h" + +#include "quiche/quic/core/congestion_control/bbr2_sender.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +void Bbr2ProbeRttMode::Enter(QuicTime /*now*/, + const Bbr2CongestionEvent* /*congestion_event*/) { + model_->set_pacing_gain(1.0); + model_->set_cwnd_gain(1.0); + exit_time_ = QuicTime::Zero(); +} + +Bbr2Mode Bbr2ProbeRttMode::OnCongestionEvent( + QuicByteCount /*prior_in_flight*/, QuicTime /*event_time*/, + const AckedPacketVector& /*acked_packets*/, + const LostPacketVector& /*lost_packets*/, + const Bbr2CongestionEvent& congestion_event) { + if (exit_time_ == QuicTime::Zero()) { + if (congestion_event.bytes_in_flight <= InflightTarget() || + congestion_event.bytes_in_flight <= + sender_->GetMinimumCongestionWindow()) { + exit_time_ = congestion_event.event_time + Params().probe_rtt_duration; + QUIC_DVLOG(2) << sender_ << " PROBE_RTT exit time set to " << exit_time_ + << ". bytes_inflight:" << congestion_event.bytes_in_flight + << ", inflight_target:" << InflightTarget() + << ", min_congestion_window:" + << sender_->GetMinimumCongestionWindow() << " @ " + << congestion_event.event_time; + } + return Bbr2Mode::PROBE_RTT; + } + + return congestion_event.event_time > exit_time_ ? Bbr2Mode::PROBE_BW + : Bbr2Mode::PROBE_RTT; +} + +QuicByteCount Bbr2ProbeRttMode::InflightTarget() const { + return model_->BDP(model_->MaxBandwidth(), + Params().probe_rtt_inflight_target_bdp_fraction); +} + +Limits Bbr2ProbeRttMode::GetCwndLimits() const { + QuicByteCount inflight_upper_bound = + std::min(model_->inflight_lo(), model_->inflight_hi_with_headroom()); + return NoGreaterThan(std::min(inflight_upper_bound, InflightTarget())); +} + +Bbr2Mode Bbr2ProbeRttMode::OnExitQuiescence( + QuicTime now, QuicTime /*quiescence_start_time*/) { + if (now > exit_time_) { + return Bbr2Mode::PROBE_BW; + } + return Bbr2Mode::PROBE_RTT; +} + +Bbr2ProbeRttMode::DebugState Bbr2ProbeRttMode::ExportDebugState() const { + DebugState s; + s.inflight_target = InflightTarget(); + s.exit_time = exit_time_; + return s; +} + +std::ostream& operator<<(std::ostream& os, + const Bbr2ProbeRttMode::DebugState& state) { + os << "[PROBE_RTT] inflight_target: " << state.inflight_target << "\n"; + os << "[PROBE_RTT] exit_time: " << state.exit_time << "\n"; + return os; +} + +const Bbr2Params& Bbr2ProbeRttMode::Params() const { return sender_->Params(); } + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_probe_rtt.h b/quiche/quic/core/congestion_control/bbr2_probe_rtt.h new file mode 100644 index 000000000000..a6edbb3d5f22 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_probe_rtt.h @@ -0,0 +1,58 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_PROBE_RTT_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_PROBE_RTT_H_ + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class Bbr2Sender; +class QUIC_EXPORT_PRIVATE Bbr2ProbeRttMode final : public Bbr2ModeBase { + public: + using Bbr2ModeBase::Bbr2ModeBase; + + void Enter(QuicTime now, + const Bbr2CongestionEvent* congestion_event) override; + void Leave(QuicTime /*now*/, + const Bbr2CongestionEvent* /*congestion_event*/) override {} + + Bbr2Mode OnCongestionEvent( + QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + const Bbr2CongestionEvent& congestion_event) override; + + Limits GetCwndLimits() const override; + + bool IsProbingForBandwidth() const override { return false; } + + Bbr2Mode OnExitQuiescence(QuicTime now, + QuicTime quiescence_start_time) override; + + struct QUIC_EXPORT_PRIVATE DebugState { + QuicByteCount inflight_target; + QuicTime exit_time = QuicTime::Zero(); + }; + + DebugState ExportDebugState() const; + + private: + const Bbr2Params& Params() const; + + QuicByteCount InflightTarget() const; + + QuicTime exit_time_ = QuicTime::Zero(); +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const Bbr2ProbeRttMode::DebugState& state); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_PROBE_RTT_H_ diff --git a/quiche/quic/core/congestion_control/bbr2_sender.cc b/quiche/quic/core/congestion_control/bbr2_sender.cc new file mode 100644 index 000000000000..6f884d87a935 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_sender.cc @@ -0,0 +1,577 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr2_sender.h" + +#include + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" +#include "quiche/quic/core/congestion_control/bbr2_drain.h" +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_tag.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/print_elements.h" + +namespace quic { + +namespace { +// Constants based on TCP defaults. +// The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. +// Does not inflate the pacing rate. +const QuicByteCount kDefaultMinimumCongestionWindow = 4 * kMaxSegmentSize; + +const float kInitialPacingGain = 2.885f; + +const int kMaxModeChangesPerCongestionEvent = 4; +} // namespace + +// Call |member_function_call| based on the current Bbr2Mode we are in. e.g. +// +// auto result = BBR2_MODE_DISPATCH(Foo()); +// +// is equivalent to: +// +// Bbr2ModeBase& Bbr2Sender::GetCurrentMode() { +// if (mode_ == Bbr2Mode::STARTUP) { return startup_; } +// if (mode_ == Bbr2Mode::DRAIN) { return drain_; } +// ... +// } +// auto result = GetCurrentMode().Foo(); +// +// Except that BBR2_MODE_DISPATCH guarantees the call to Foo() is non-virtual. +// +#define BBR2_MODE_DISPATCH(member_function_call) \ + (mode_ == Bbr2Mode::STARTUP \ + ? (startup_.member_function_call) \ + : (mode_ == Bbr2Mode::PROBE_BW \ + ? (probe_bw_.member_function_call) \ + : (mode_ == Bbr2Mode::DRAIN \ + ? (drain_.member_function_call) \ + : (probe_rtt_or_die().member_function_call)))) + +Bbr2Sender::Bbr2Sender(QuicTime now, const RttStats* rtt_stats, + const QuicUnackedPacketMap* unacked_packets, + QuicPacketCount initial_cwnd_in_packets, + QuicPacketCount max_cwnd_in_packets, QuicRandom* random, + QuicConnectionStats* stats, BbrSender* old_sender) + : mode_(Bbr2Mode::STARTUP), + rtt_stats_(rtt_stats), + unacked_packets_(unacked_packets), + random_(random), + connection_stats_(stats), + params_(kDefaultMinimumCongestionWindow, + max_cwnd_in_packets * kDefaultTCPMSS), + model_(¶ms_, rtt_stats->SmoothedOrInitialRtt(), + rtt_stats->last_update_time(), + /*cwnd_gain=*/1.0, + /*pacing_gain=*/kInitialPacingGain, + old_sender ? &old_sender->sampler_ : nullptr), + initial_cwnd_(cwnd_limits().ApplyLimits( + (old_sender) ? old_sender->GetCongestionWindow() + : (initial_cwnd_in_packets * kDefaultTCPMSS))), + cwnd_(initial_cwnd_), + pacing_rate_(kInitialPacingGain * + QuicBandwidth::FromBytesAndTimeDelta( + cwnd_, rtt_stats->SmoothedOrInitialRtt())), + startup_(this, &model_, now), + drain_(this, &model_), + probe_bw_(this, &model_), + probe_rtt_(this, &model_), + last_sample_is_app_limited_(false) { + QUIC_DVLOG(2) << this << " Initializing Bbr2Sender. mode:" << mode_ + << ", PacingRate:" << pacing_rate_ << ", Cwnd:" << cwnd_ + << ", CwndLimits:" << cwnd_limits() << " @ " << now; + QUICHE_DCHECK_EQ(mode_, Bbr2Mode::STARTUP); +} + +void Bbr2Sender::SetFromConfig(const QuicConfig& config, + Perspective perspective) { + if (config.HasClientRequestedIndependentOption(kB2NA, perspective)) { + params_.add_ack_height_to_queueing_threshold = false; + } + if (config.HasClientRequestedIndependentOption(kB2RP, perspective)) { + params_.avoid_unnecessary_probe_rtt = false; + } + if (config.HasClientRequestedIndependentOption(k1RTT, perspective)) { + params_.startup_full_bw_rounds = 1; + } + if (config.HasClientRequestedIndependentOption(k2RTT, perspective)) { + params_.startup_full_bw_rounds = 2; + } + if (config.HasClientRequestedIndependentOption(kB2HR, perspective)) { + params_.inflight_hi_headroom = 0.15; + } + if (config.HasClientRequestedIndependentOption(kICW1, perspective)) { + max_cwnd_when_network_parameters_adjusted_ = 100 * kDefaultTCPMSS; + } + + ApplyConnectionOptions(config.ClientRequestedIndependentOptions(perspective)); +} + +void Bbr2Sender::ApplyConnectionOptions( + const QuicTagVector& connection_options) { + if (GetQuicReloadableFlag(quic_bbr2_extra_acked_window) && + ContainsQuicTag(connection_options, kBBR4)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_extra_acked_window, 1, 2); + model_.SetMaxAckHeightTrackerWindowLength(20); + } + if (GetQuicReloadableFlag(quic_bbr2_extra_acked_window) && + ContainsQuicTag(connection_options, kBBR5)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_extra_acked_window, 2, 2); + model_.SetMaxAckHeightTrackerWindowLength(40); + } + if (ContainsQuicTag(connection_options, kBBQ1)) { + params_.startup_pacing_gain = 2.773; + params_.drain_pacing_gain = 1.0 / params_.drain_cwnd_gain; + } + if (ContainsQuicTag(connection_options, kBBQ2)) { + params_.startup_cwnd_gain = 2.885; + params_.drain_cwnd_gain = 2.885; + model_.set_cwnd_gain(params_.startup_cwnd_gain); + } + if (ContainsQuicTag(connection_options, kB2LO)) { + params_.ignore_inflight_lo = true; + } + if (ContainsQuicTag(connection_options, kB2NE)) { + params_.always_exit_startup_on_excess_loss = true; + } + if (ContainsQuicTag(connection_options, kB2SL)) { + params_.startup_loss_exit_use_max_delivered_for_inflight_hi = false; + } + if (ContainsQuicTag(connection_options, kB2H2)) { + params_.limit_inflight_hi_by_max_delivered = true; + } + if (ContainsQuicTag(connection_options, kB2DL)) { + params_.use_bytes_delivered_for_inflight_hi = true; + } + if (ContainsQuicTag(connection_options, kB2RC)) { + params_.enable_reno_coexistence = false; + } + if (ContainsQuicTag(connection_options, kBSAO)) { + model_.EnableOverestimateAvoidance(); + } + if (ContainsQuicTag(connection_options, kBBQ6)) { + params_.decrease_startup_pacing_at_end_of_round = true; + } + if (ContainsQuicTag(connection_options, kBBQ7)) { + params_.bw_lo_mode_ = Bbr2Params::QuicBandwidthLoMode::MIN_RTT_REDUCTION; + } + if (ContainsQuicTag(connection_options, kBBQ8)) { + params_.bw_lo_mode_ = Bbr2Params::QuicBandwidthLoMode::INFLIGHT_REDUCTION; + } + if (ContainsQuicTag(connection_options, kBBQ9)) { + params_.bw_lo_mode_ = Bbr2Params::QuicBandwidthLoMode::CWND_REDUCTION; + } + if (ContainsQuicTag(connection_options, kB202)) { + params_.max_probe_up_queue_rounds = 1; + } + if (ContainsQuicTag(connection_options, kB203)) { + params_.probe_up_ignore_inflight_hi = false; + } + if (ContainsQuicTag(connection_options, kB204)) { + model_.SetReduceExtraAckedOnBandwidthIncrease(true); + } + if (ContainsQuicTag(connection_options, kB205)) { + params_.startup_include_extra_acked = true; + } + if (ContainsQuicTag(connection_options, kB207)) { + params_.max_startup_queue_rounds = 1; + } + if (ContainsQuicTag(connection_options, kBBRA)) { + model_.SetStartNewAggregationEpochAfterFullRound(true); + } + if (ContainsQuicTag(connection_options, kBBRB)) { + model_.SetLimitMaxAckHeightTrackerBySendRate(true); + } + if (ContainsQuicTag(connection_options, kB206)) { + params_.startup_full_loss_count = params_.probe_bw_full_loss_count; + } + if (ContainsQuicTag(connection_options, kBBPD)) { + // Derived constant to ensure fairness. + params_.probe_bw_probe_down_pacing_gain = 0.91; + } + if (GetQuicReloadableFlag(quic_bbr2_simplify_inflight_hi) && + ContainsQuicTag(connection_options, kBBHI)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_bbr2_simplify_inflight_hi); + params_.probe_up_simplify_inflight_hi = true; + // Simplify inflight_hi is intended as an alternative to ignoring it, + // so ensure we're not ignoring it. + params_.probe_up_ignore_inflight_hi = false; + } + if (GetQuicReloadableFlag(quic_bbr2_probe_two_rounds) && + ContainsQuicTag(connection_options, kBB2U)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_probe_two_rounds, 1, 3); + params_.max_probe_up_queue_rounds = 2; + } + if (GetQuicReloadableFlag(quic_bbr2_probe_two_rounds) && + ContainsQuicTag(connection_options, kBB2S)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_probe_two_rounds, 2, 3); + params_.max_startup_queue_rounds = 2; + } +} + +Limits Bbr2Sender::GetCwndLimitsByMode() const { + switch (mode_) { + case Bbr2Mode::STARTUP: + return startup_.GetCwndLimits(); + case Bbr2Mode::PROBE_BW: + return probe_bw_.GetCwndLimits(); + case Bbr2Mode::DRAIN: + return drain_.GetCwndLimits(); + case Bbr2Mode::PROBE_RTT: + return probe_rtt_.GetCwndLimits(); + default: + QUICHE_NOTREACHED(); + return Unlimited(); + } +} + +const Limits& Bbr2Sender::cwnd_limits() const { + return params().cwnd_limits; +} + +void Bbr2Sender::AdjustNetworkParameters(const NetworkParams& params) { + model_.UpdateNetworkParameters(params.rtt); + + if (mode_ == Bbr2Mode::STARTUP) { + const QuicByteCount prior_cwnd = cwnd_; + + QuicBandwidth effective_bandwidth = + std::max(params.bandwidth, model_.BandwidthEstimate()); + connection_stats_->cwnd_bootstrapping_rtt_us = + model_.MinRtt().ToMicroseconds(); + + if (params.max_initial_congestion_window > 0) { + max_cwnd_when_network_parameters_adjusted_ = + params.max_initial_congestion_window * kDefaultTCPMSS; + } + cwnd_ = cwnd_limits().ApplyLimits( + std::min(max_cwnd_when_network_parameters_adjusted_, + model_.BDP(effective_bandwidth))); + + if (!params.allow_cwnd_to_decrease) { + cwnd_ = std::max(cwnd_, prior_cwnd); + } + + pacing_rate_ = std::max(pacing_rate_, QuicBandwidth::FromBytesAndTimeDelta( + cwnd_, model_.MinRtt())); + } +} + +void Bbr2Sender::SetInitialCongestionWindowInPackets( + QuicPacketCount congestion_window) { + if (mode_ == Bbr2Mode::STARTUP) { + // The cwnd limits is unchanged and still applies to the new cwnd. + cwnd_ = cwnd_limits().ApplyLimits(congestion_window * kDefaultTCPMSS); + } +} + +void Bbr2Sender::OnCongestionEvent(bool /*rtt_updated*/, + QuicByteCount prior_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount /*num_ect*/, + QuicPacketCount /*num_ce*/) { + QUIC_DVLOG(3) << this + << " OnCongestionEvent. prior_in_flight:" << prior_in_flight + << " prior_cwnd:" << cwnd_ << " @ " << event_time; + Bbr2CongestionEvent congestion_event; + congestion_event.prior_cwnd = cwnd_; + congestion_event.prior_bytes_in_flight = prior_in_flight; + congestion_event.is_probing_for_bandwidth = + BBR2_MODE_DISPATCH(IsProbingForBandwidth()); + + model_.OnCongestionEventStart(event_time, acked_packets, lost_packets, + &congestion_event); + + if (InSlowStart()) { + if (!lost_packets.empty()) { + connection_stats_->slowstart_packets_lost += lost_packets.size(); + connection_stats_->slowstart_bytes_lost += congestion_event.bytes_lost; + } + if (congestion_event.end_of_round_trip) { + ++connection_stats_->slowstart_num_rtts; + } + } + + // Number of mode changes allowed for this congestion event. + int mode_changes_allowed = kMaxModeChangesPerCongestionEvent; + while (true) { + Bbr2Mode next_mode = BBR2_MODE_DISPATCH( + OnCongestionEvent(prior_in_flight, event_time, acked_packets, + lost_packets, congestion_event)); + + if (next_mode == mode_) { + break; + } + + QUIC_DVLOG(2) << this << " Mode change: " << mode_ << " ==> " << next_mode + << " @ " << event_time; + BBR2_MODE_DISPATCH(Leave(event_time, &congestion_event)); + mode_ = next_mode; + BBR2_MODE_DISPATCH(Enter(event_time, &congestion_event)); + --mode_changes_allowed; + if (mode_changes_allowed < 0) { + QUIC_BUG(quic_bug_10443_1) + << "Exceeded max number of mode changes per congestion event."; + break; + } + } + + UpdatePacingRate(congestion_event.bytes_acked); + QUIC_BUG_IF(quic_bug_10443_2, pacing_rate_.IsZero()) + << "Pacing rate must not be zero!"; + + UpdateCongestionWindow(congestion_event.bytes_acked); + QUIC_BUG_IF(quic_bug_10443_3, cwnd_ == 0u) + << "Congestion window must not be zero!"; + + model_.OnCongestionEventFinish(unacked_packets_->GetLeastUnacked(), + congestion_event); + last_sample_is_app_limited_ = + congestion_event.last_packet_send_state.is_app_limited; + if (!last_sample_is_app_limited_) { + has_non_app_limited_sample_ = true; + } + if (congestion_event.bytes_in_flight == 0 && + params().avoid_unnecessary_probe_rtt) { + OnEnterQuiescence(event_time); + } + + QUIC_DVLOG(3) + << this + << " END CongestionEvent(acked:" << quiche::PrintElements(acked_packets) + << ", lost:" << lost_packets.size() << ") " + << ", Mode:" << mode_ << ", RttCount:" << model_.RoundTripCount() + << ", BytesInFlight:" << congestion_event.bytes_in_flight + << ", PacingRate:" << PacingRate(0) << ", CWND:" << GetCongestionWindow() + << ", PacingGain:" << model_.pacing_gain() + << ", CwndGain:" << model_.cwnd_gain() + << ", BandwidthEstimate(kbps):" << BandwidthEstimate().ToKBitsPerSecond() + << ", MinRTT(us):" << model_.MinRtt().ToMicroseconds() + << ", BDP:" << model_.BDP(BandwidthEstimate()) + << ", BandwidthLatest(kbps):" + << model_.bandwidth_latest().ToKBitsPerSecond() + << ", BandwidthLow(kbps):" << model_.bandwidth_lo().ToKBitsPerSecond() + << ", BandwidthHigh(kbps):" << model_.MaxBandwidth().ToKBitsPerSecond() + << ", InflightLatest:" << model_.inflight_latest() + << ", InflightLow:" << model_.inflight_lo() + << ", InflightHigh:" << model_.inflight_hi() + << ", TotalAcked:" << model_.total_bytes_acked() + << ", TotalLost:" << model_.total_bytes_lost() + << ", TotalSent:" << model_.total_bytes_sent() << " @ " << event_time; +} + +void Bbr2Sender::UpdatePacingRate(QuicByteCount bytes_acked) { + if (BandwidthEstimate().IsZero()) { + return; + } + + if (model_.total_bytes_acked() == bytes_acked) { + // After the first ACK, cwnd_ is still the initial congestion window. + pacing_rate_ = QuicBandwidth::FromBytesAndTimeDelta(cwnd_, model_.MinRtt()); + return; + } + + QuicBandwidth target_rate = model_.pacing_gain() * model_.BandwidthEstimate(); + if (model_.full_bandwidth_reached()) { + pacing_rate_ = target_rate; + return; + } + if (params_.decrease_startup_pacing_at_end_of_round && + model_.pacing_gain() < Params().startup_pacing_gain) { + pacing_rate_ = target_rate; + return; + } + if (params_.bw_lo_mode_ != Bbr2Params::DEFAULT && + model_.loss_events_in_round() > 0) { + pacing_rate_ = target_rate; + return; + } + + // By default, the pacing rate never decreases in STARTUP. + if (target_rate > pacing_rate_) { + pacing_rate_ = target_rate; + } +} + +void Bbr2Sender::UpdateCongestionWindow(QuicByteCount bytes_acked) { + QuicByteCount target_cwnd = GetTargetCongestionWindow(model_.cwnd_gain()); + + const QuicByteCount prior_cwnd = cwnd_; + if (model_.full_bandwidth_reached() || Params().startup_include_extra_acked) { + target_cwnd += model_.MaxAckHeight(); + cwnd_ = std::min(prior_cwnd + bytes_acked, target_cwnd); + } else if (prior_cwnd < target_cwnd || prior_cwnd < 2 * initial_cwnd_) { + cwnd_ = prior_cwnd + bytes_acked; + } + const QuicByteCount desired_cwnd = cwnd_; + + cwnd_ = GetCwndLimitsByMode().ApplyLimits(cwnd_); + const QuicByteCount model_limited_cwnd = cwnd_; + + cwnd_ = cwnd_limits().ApplyLimits(cwnd_); + + QUIC_DVLOG(3) << this << " Updating CWND. target_cwnd:" << target_cwnd + << ", max_ack_height:" << model_.MaxAckHeight() + << ", full_bw:" << model_.full_bandwidth_reached() + << ", bytes_acked:" << bytes_acked + << ", inflight_lo:" << model_.inflight_lo() + << ", inflight_hi:" << model_.inflight_hi() << ". (prior_cwnd) " + << prior_cwnd << " => (desired_cwnd) " << desired_cwnd + << " => (model_limited_cwnd) " << model_limited_cwnd + << " => (final_cwnd) " << cwnd_; +} + +QuicByteCount Bbr2Sender::GetTargetCongestionWindow(float gain) const { + return std::max(model_.BDP(model_.BandwidthEstimate(), gain), + cwnd_limits().Min()); +} + +void Bbr2Sender::OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, + QuicByteCount bytes, + HasRetransmittableData is_retransmittable) { + QUIC_DVLOG(3) << this << " OnPacketSent: pkn:" << packet_number + << ", bytes:" << bytes << ", cwnd:" << cwnd_ + << ", inflight:" << bytes_in_flight + bytes + << ", total_sent:" << model_.total_bytes_sent() + bytes + << ", total_acked:" << model_.total_bytes_acked() + << ", total_lost:" << model_.total_bytes_lost() << " @ " + << sent_time; + if (InSlowStart()) { + ++connection_stats_->slowstart_packets_sent; + connection_stats_->slowstart_bytes_sent += bytes; + } + if (bytes_in_flight == 0 && params().avoid_unnecessary_probe_rtt) { + OnExitQuiescence(sent_time); + } + model_.OnPacketSent(sent_time, bytes_in_flight, packet_number, bytes, + is_retransmittable); +} + +void Bbr2Sender::OnPacketNeutered(QuicPacketNumber packet_number) { + model_.OnPacketNeutered(packet_number); +} + +bool Bbr2Sender::CanSend(QuicByteCount bytes_in_flight) { + const bool result = bytes_in_flight < GetCongestionWindow(); + return result; +} + +QuicByteCount Bbr2Sender::GetCongestionWindow() const { + // TODO(wub): Implement Recovery? + return cwnd_; +} + +QuicBandwidth Bbr2Sender::PacingRate(QuicByteCount /*bytes_in_flight*/) const { + return pacing_rate_; +} + +void Bbr2Sender::OnApplicationLimited(QuicByteCount bytes_in_flight) { + if (bytes_in_flight >= GetCongestionWindow()) { + return; + } + + model_.OnApplicationLimited(); + QUIC_DVLOG(2) << this << " Becoming application limited. Last sent packet: " + << model_.last_sent_packet() + << ", CWND: " << GetCongestionWindow(); +} + +QuicByteCount Bbr2Sender::GetTargetBytesInflight() const { + QuicByteCount bdp = model_.BDP(model_.BandwidthEstimate()); + return std::min(bdp, GetCongestionWindow()); +} + +void Bbr2Sender::PopulateConnectionStats(QuicConnectionStats* stats) const { + stats->num_ack_aggregation_epochs = model_.num_ack_aggregation_epochs(); +} + +void Bbr2Sender::OnEnterQuiescence(QuicTime now) { + last_quiescence_start_ = now; +} + +void Bbr2Sender::OnExitQuiescence(QuicTime now) { + if (last_quiescence_start_ != QuicTime::Zero()) { + Bbr2Mode next_mode = BBR2_MODE_DISPATCH( + OnExitQuiescence(now, std::min(now, last_quiescence_start_))); + if (next_mode != mode_) { + BBR2_MODE_DISPATCH(Leave(now, nullptr)); + mode_ = next_mode; + BBR2_MODE_DISPATCH(Enter(now, nullptr)); + } + last_quiescence_start_ = QuicTime::Zero(); + } +} + +std::string Bbr2Sender::GetDebugState() const { + std::ostringstream stream; + stream << ExportDebugState(); + return stream.str(); +} + +Bbr2Sender::DebugState Bbr2Sender::ExportDebugState() const { + DebugState s; + s.mode = mode_; + s.round_trip_count = model_.RoundTripCount(); + s.bandwidth_hi = model_.MaxBandwidth(); + s.bandwidth_lo = model_.bandwidth_lo(); + s.bandwidth_est = BandwidthEstimate(); + s.inflight_hi = model_.inflight_hi(); + s.inflight_lo = model_.inflight_lo(); + s.max_ack_height = model_.MaxAckHeight(); + s.min_rtt = model_.MinRtt(); + s.min_rtt_timestamp = model_.MinRttTimestamp(); + s.congestion_window = cwnd_; + s.pacing_rate = pacing_rate_; + s.last_sample_is_app_limited = last_sample_is_app_limited_; + s.end_of_app_limited_phase = model_.end_of_app_limited_phase(); + + s.startup = startup_.ExportDebugState(); + s.drain = drain_.ExportDebugState(); + s.probe_bw = probe_bw_.ExportDebugState(); + s.probe_rtt = probe_rtt_.ExportDebugState(); + + return s; +} + +std::ostream& operator<<(std::ostream& os, const Bbr2Sender::DebugState& s) { + os << "mode: " << s.mode << "\n"; + os << "round_trip_count: " << s.round_trip_count << "\n"; + os << "bandwidth_hi ~ lo ~ est: " << s.bandwidth_hi << " ~ " << s.bandwidth_lo + << " ~ " << s.bandwidth_est << "\n"; + os << "min_rtt: " << s.min_rtt << "\n"; + os << "min_rtt_timestamp: " << s.min_rtt_timestamp << "\n"; + os << "congestion_window: " << s.congestion_window << "\n"; + os << "pacing_rate: " << s.pacing_rate << "\n"; + os << "last_sample_is_app_limited: " << s.last_sample_is_app_limited << "\n"; + + if (s.mode == Bbr2Mode::STARTUP) { + os << s.startup; + } + + if (s.mode == Bbr2Mode::DRAIN) { + os << s.drain; + } + + if (s.mode == Bbr2Mode::PROBE_BW) { + os << s.probe_bw; + } + + if (s.mode == Bbr2Mode::PROBE_RTT) { + os << s.probe_rtt; + } + + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_sender.h b/quiche/quic/core/congestion_control/bbr2_sender.h new file mode 100644 index 000000000000..171028a05922 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_sender.h @@ -0,0 +1,218 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_SENDER_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_SENDER_H_ + +#include + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" +#include "quiche/quic/core/congestion_control/bbr2_drain.h" +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/congestion_control/bbr2_probe_bw.h" +#include "quiche/quic/core/congestion_control/bbr2_probe_rtt.h" +#include "quiche/quic/core/congestion_control/bbr2_startup.h" +#include "quiche/quic/core/congestion_control/bbr_sender.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/congestion_control/windowed_filter.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE Bbr2Sender final : public SendAlgorithmInterface { + public: + Bbr2Sender(QuicTime now, const RttStats* rtt_stats, + const QuicUnackedPacketMap* unacked_packets, + QuicPacketCount initial_cwnd_in_packets, + QuicPacketCount max_cwnd_in_packets, QuicRandom* random, + QuicConnectionStats* stats, BbrSender* old_sender); + + ~Bbr2Sender() override = default; + + // Start implementation of SendAlgorithmInterface. + bool InSlowStart() const override { return mode_ == Bbr2Mode::STARTUP; } + + bool InRecovery() const override { + // TODO(wub): Implement Recovery. + return false; + } + + void SetFromConfig(const QuicConfig& config, + Perspective perspective) override; + + void ApplyConnectionOptions(const QuicTagVector& connection_options) override; + + void AdjustNetworkParameters(const NetworkParams& params) override; + + void SetInitialCongestionWindowInPackets( + QuicPacketCount congestion_window) override; + + void OnCongestionEvent(bool rtt_updated, QuicByteCount prior_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount num_ect, + QuicPacketCount num_ce) override; + + void OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData is_retransmittable) override; + + void OnPacketNeutered(QuicPacketNumber packet_number) override; + + void OnRetransmissionTimeout(bool /*packets_retransmitted*/) override {} + + void OnConnectionMigration() override {} + + bool CanSend(QuicByteCount bytes_in_flight) override; + + QuicBandwidth PacingRate(QuicByteCount bytes_in_flight) const override; + + QuicBandwidth BandwidthEstimate() const override { + return model_.BandwidthEstimate(); + } + + bool HasGoodBandwidthEstimateForResumption() const override { + return has_non_app_limited_sample_; + } + + QuicByteCount GetCongestionWindow() const override; + + QuicByteCount GetSlowStartThreshold() const override { return 0; } + + CongestionControlType GetCongestionControlType() const override { + return kBBRv2; + } + + std::string GetDebugState() const override; + + void OnApplicationLimited(QuicByteCount bytes_in_flight) override; + + void PopulateConnectionStats(QuicConnectionStats* stats) const override; + + bool SupportsECT0() const override { return false; } + bool SupportsECT1() const override { return false; } + // End implementation of SendAlgorithmInterface. + + const Bbr2Params& Params() const { return params_; } + + QuicByteCount GetMinimumCongestionWindow() const { + return cwnd_limits().Min(); + } + + // Returns the min of BDP and congestion window. + QuicByteCount GetTargetBytesInflight() const; + + bool IsBandwidthOverestimateAvoidanceEnabled() const { + return model_.IsBandwidthOverestimateAvoidanceEnabled(); + } + + struct QUIC_EXPORT_PRIVATE DebugState { + Bbr2Mode mode; + + // Shared states. + QuicRoundTripCount round_trip_count; + QuicBandwidth bandwidth_hi = QuicBandwidth::Zero(); + QuicBandwidth bandwidth_lo = QuicBandwidth::Zero(); + QuicBandwidth bandwidth_est = QuicBandwidth::Zero(); + QuicByteCount inflight_hi; + QuicByteCount inflight_lo; + QuicByteCount max_ack_height; + QuicTime::Delta min_rtt = QuicTime::Delta::Zero(); + QuicTime min_rtt_timestamp = QuicTime::Zero(); + QuicByteCount congestion_window; + QuicBandwidth pacing_rate = QuicBandwidth::Zero(); + bool last_sample_is_app_limited; + QuicPacketNumber end_of_app_limited_phase; + + // Mode-specific debug states. + Bbr2StartupMode::DebugState startup; + Bbr2DrainMode::DebugState drain; + Bbr2ProbeBwMode::DebugState probe_bw; + Bbr2ProbeRttMode::DebugState probe_rtt; + }; + + DebugState ExportDebugState() const; + + private: + void UpdatePacingRate(QuicByteCount bytes_acked); + void UpdateCongestionWindow(QuicByteCount bytes_acked); + QuicByteCount GetTargetCongestionWindow(float gain) const; + void OnEnterQuiescence(QuicTime now); + void OnExitQuiescence(QuicTime now); + + // Helper function for BBR2_MODE_DISPATCH. + Bbr2ProbeRttMode& probe_rtt_or_die() { + QUICHE_DCHECK_EQ(mode_, Bbr2Mode::PROBE_RTT); + return probe_rtt_; + } + + const Bbr2ProbeRttMode& probe_rtt_or_die() const { + QUICHE_DCHECK_EQ(mode_, Bbr2Mode::PROBE_RTT); + return probe_rtt_; + } + + uint64_t RandomUint64(uint64_t max) const { + return random_->RandUint64() % max; + } + + // Cwnd limits imposed by the current Bbr2 mode. + Limits GetCwndLimitsByMode() const; + + // Cwnd limits imposed by caller. + const Limits& cwnd_limits() const; + + const Bbr2Params& params() const { return params_; } + + Bbr2Mode mode_; + + const RttStats* const rtt_stats_; + const QuicUnackedPacketMap* const unacked_packets_; + QuicRandom* random_; + QuicConnectionStats* connection_stats_; + + // Don't use it directly outside of SetFromConfig and ApplyConnectionOptions. + // Instead, use params() to get read-only access. + Bbr2Params params_; + + // Max congestion window when adjusting network parameters. + QuicByteCount max_cwnd_when_network_parameters_adjusted_ = + kMaxInitialCongestionWindow * kDefaultTCPMSS; + + Bbr2NetworkModel model_; + + const QuicByteCount initial_cwnd_; + + // Current cwnd and pacing rate. + QuicByteCount cwnd_; + QuicBandwidth pacing_rate_; + + QuicTime last_quiescence_start_ = QuicTime::Zero(); + + Bbr2StartupMode startup_; + Bbr2DrainMode drain_; + Bbr2ProbeBwMode probe_bw_; + Bbr2ProbeRttMode probe_rtt_; + + bool has_non_app_limited_sample_ = false; + + // Debug only. + bool last_sample_is_app_limited_; + + friend class Bbr2StartupMode; + friend class Bbr2DrainMode; + friend class Bbr2ProbeBwMode; + friend class Bbr2ProbeRttMode; +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const Bbr2Sender::DebugState& state); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_SENDER_H_ diff --git a/quiche/quic/core/congestion_control/bbr2_simulator_test.cc b/quiche/quic/core/congestion_control/bbr2_simulator_test.cc new file mode 100644 index 000000000000..d65adfe52182 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_simulator_test.cc @@ -0,0 +1,2575 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/congestion_control/bbr2_sender.h" +#include "quiche/quic/core/congestion_control/bbr_sender.h" +#include "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/send_algorithm_test_result.pb.h" +#include "quiche/quic/test_tools/send_algorithm_test_utils.h" +#include "quiche/quic/test_tools/simulator/link.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/test_tools/simulator/switch.h" +#include "quiche/quic/test_tools/simulator/traffic_policer.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +using testing::AllOf; +using testing::Ge; +using testing::Le; + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, quic_bbr2_test_regression_mode, "", + "One of a) 'record' to record test result (one file per test), or " + "b) 'regress' to regress against recorded results, or " + "c) for non-regression mode."); + +namespace quic { + +using CyclePhase = Bbr2ProbeBwMode::CyclePhase; + +namespace test { + +// Use the initial CWND of 10, as 32 is too much for the test network. +const uint32_t kDefaultInitialCwndPackets = 10; +const uint32_t kDefaultInitialCwndBytes = + kDefaultInitialCwndPackets * kDefaultTCPMSS; + +struct LinkParams { + LinkParams(int64_t kilo_bits_per_sec, int64_t delay_us) + : bandwidth(QuicBandwidth::FromKBitsPerSecond(kilo_bits_per_sec)), + delay(QuicTime::Delta::FromMicroseconds(delay_us)) {} + QuicBandwidth bandwidth; + QuicTime::Delta delay; +}; + +struct TrafficPolicerParams { + std::string name = "policer"; + QuicByteCount initial_burst_size; + QuicByteCount max_bucket_size; + QuicBandwidth target_bandwidth = QuicBandwidth::Zero(); +}; + +// All Bbr2DefaultTopologyTests uses the default network topology: +// +// Sender +// | +// | <-- local_link +// | +// Network switch +// * <-- the bottleneck queue in the direction +// | of the receiver +// | +// | <-- test_link +// | +// | +// Receiver +class DefaultTopologyParams { + public: + LinkParams local_link = {10000, 2000}; + LinkParams test_link = {4000, 30000}; + + const simulator::SwitchPortNumber switch_port_count = 2; + // Network switch queue capacity, in number of BDPs. + float switch_queue_capacity_in_bdp = 2; + + absl::optional sender_policer_params; + + QuicBandwidth BottleneckBandwidth() const { + return std::min(local_link.bandwidth, test_link.bandwidth); + } + + // Round trip time of a single full size packet. + QuicTime::Delta RTT() const { + return 2 * (local_link.delay + test_link.delay + + local_link.bandwidth.TransferTime(kMaxOutgoingPacketSize) + + test_link.bandwidth.TransferTime(kMaxOutgoingPacketSize)); + } + + QuicByteCount BDP() const { return BottleneckBandwidth() * RTT(); } + + QuicByteCount SwitchQueueCapacity() const { + return switch_queue_capacity_in_bdp * BDP(); + } + + std::string ToString() const { + std::ostringstream os; + os << "{ BottleneckBandwidth: " << BottleneckBandwidth() + << " RTT: " << RTT() << " BDP: " << BDP() + << " BottleneckQueueSize: " << SwitchQueueCapacity() << "}"; + return os.str(); + } +}; + +class Bbr2SimulatorTest : public QuicTest { + protected: + Bbr2SimulatorTest() : simulator_(&random_) { + // Prevent the server(receiver), which only sends acks, from closing + // connection due to too many outstanding packets. + SetQuicFlag(quic_max_tracked_packet_count, 1000000); + } + + void SetUp() override { + if (quiche::GetQuicheCommandLineFlag( + FLAGS_quic_bbr2_test_regression_mode) == "regress") { + SendAlgorithmTestResult expected; + ASSERT_TRUE(LoadSendAlgorithmTestResult(&expected)); + random_seed_ = expected.random_seed(); + } else { + random_seed_ = QuicRandom::GetInstance()->RandUint64(); + } + random_.set_seed(random_seed_); + QUIC_LOG(INFO) << "Using random seed: " << random_seed_; + } + + ~Bbr2SimulatorTest() override { + const std::string regression_mode = + quiche::GetQuicheCommandLineFlag(FLAGS_quic_bbr2_test_regression_mode); + const QuicTime::Delta simulated_duration = + SimulatedNow() - QuicTime::Zero(); + if (regression_mode == "record") { + RecordSendAlgorithmTestResult(random_seed_, + simulated_duration.ToMicroseconds()); + } else if (regression_mode == "regress") { + CompareSendAlgorithmTestResult(simulated_duration.ToMicroseconds()); + } + } + + QuicTime SimulatedNow() const { return simulator_.GetClock()->Now(); } + + uint64_t random_seed_; + SimpleRandom random_; + simulator::Simulator simulator_; +}; + +class Bbr2DefaultTopologyTest : public Bbr2SimulatorTest { + protected: + Bbr2DefaultTopologyTest() + : sender_endpoint_(&simulator_, "Sender", "Receiver", + Perspective::IS_CLIENT, TestConnectionId(42)), + receiver_endpoint_(&simulator_, "Receiver", "Sender", + Perspective::IS_SERVER, TestConnectionId(42)) { + sender_ = SetupBbr2Sender(&sender_endpoint_, /*old_sender=*/nullptr); + } + + ~Bbr2DefaultTopologyTest() { + const auto* test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + const Bbr2Sender::DebugState& debug_state = sender_->ExportDebugState(); + QUIC_LOG(INFO) << "Bbr2DefaultTopologyTest." << test_info->name() + << " completed at simulated time: " + << SimulatedNow().ToDebuggingValue() / 1e6 + << " sec. packet loss:" + << sender_loss_rate_in_packets() * 100 + << "%, bw_hi:" << debug_state.bandwidth_hi; + } + + QuicUnackedPacketMap* GetUnackedMap(QuicConnection* connection) { + return QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(connection)); + } + + Bbr2Sender* SetupBbr2Sender(simulator::QuicEndpoint* endpoint, + BbrSender* old_sender) { + // Ownership of the sender will be overtaken by the endpoint. + Bbr2Sender* sender = new Bbr2Sender( + endpoint->connection()->clock()->Now(), + endpoint->connection()->sent_packet_manager().GetRttStats(), + GetUnackedMap(endpoint->connection()), kDefaultInitialCwndPackets, + GetQuicFlag(quic_max_congestion_window), &random_, + QuicConnectionPeer::GetStats(endpoint->connection()), old_sender); + QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); + const int kTestMaxPacketSize = 1350; + endpoint->connection()->SetMaxPacketLength(kTestMaxPacketSize); + endpoint->RecordTrace(); + return sender; + } + + void CreateNetwork(const DefaultTopologyParams& params) { + QUIC_LOG(INFO) << "CreateNetwork with parameters: " << params.ToString(); + switch_ = std::make_unique(&simulator_, "Switch", + params.switch_port_count, + params.SwitchQueueCapacity()); + + // WARNING: The order to add links to network_links_ matters, because some + // tests adjusts the link bandwidth on the fly. + + // Local link connects sender and port 1. + network_links_.push_back(std::make_unique( + &sender_endpoint_, switch_->port(1), params.local_link.bandwidth, + params.local_link.delay)); + + // Test link connects receiver and port 2. + if (params.sender_policer_params.has_value()) { + const TrafficPolicerParams& policer_params = + params.sender_policer_params.value(); + sender_policer_ = std::make_unique( + &simulator_, policer_params.name, policer_params.initial_burst_size, + policer_params.max_bucket_size, policer_params.target_bandwidth, + switch_->port(2)); + network_links_.push_back(std::make_unique( + &receiver_endpoint_, sender_policer_.get(), + params.test_link.bandwidth, params.test_link.delay)); + } else { + network_links_.push_back(std::make_unique( + &receiver_endpoint_, switch_->port(2), params.test_link.bandwidth, + params.test_link.delay)); + } + } + + simulator::SymmetricLink* TestLink() { return network_links_[1].get(); } + + void DoSimpleTransfer(QuicByteCount transfer_size, QuicTime::Delta timeout) { + sender_endpoint_.AddBytesToTransfer(transfer_size); + // TODO(wub): consider rewriting this to run until the receiver actually + // receives the intended amount of bytes. + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + timeout); + EXPECT_TRUE(simulator_result) + << "Simple transfer failed. Bytes remaining: " + << sender_endpoint_.bytes_to_transfer(); + QUIC_LOG(INFO) << "Simple transfer state: " << sender_->ExportDebugState(); + } + + // Drive the simulator by sending enough data to enter PROBE_BW. + void DriveOutOfStartup(const DefaultTopologyParams& params) { + ASSERT_FALSE(sender_->ExportDebugState().startup.full_bandwidth_reached); + DoSimpleTransfer(1024 * 1024, QuicTime::Delta::FromSeconds(15)); + EXPECT_EQ(Bbr2Mode::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.02f); + } + + // Send |bytes|-sized bursts of data |number_of_bursts| times, waiting for + // |wait_time| between each burst. + void SendBursts(const DefaultTopologyParams& params, size_t number_of_bursts, + QuicByteCount bytes, QuicTime::Delta wait_time) { + ASSERT_EQ(0u, sender_endpoint_.bytes_to_transfer()); + for (size_t i = 0; i < number_of_bursts; i++) { + sender_endpoint_.AddBytesToTransfer(bytes); + + // Transfer data and wait for three seconds between each transfer. + simulator_.RunFor(wait_time); + + // Ensure the connection did not time out. + ASSERT_TRUE(sender_endpoint_.connection()->connected()); + ASSERT_TRUE(receiver_endpoint_.connection()->connected()); + } + + simulator_.RunFor(wait_time + params.RTT()); + ASSERT_EQ(0u, sender_endpoint_.bytes_to_transfer()); + } + + template + bool SendUntilOrTimeout(TerminationPredicate termination_predicate, + QuicTime::Delta timeout) { + EXPECT_EQ(0u, sender_endpoint_.bytes_to_transfer()); + const QuicTime deadline = SimulatedNow() + timeout; + do { + sender_endpoint_.AddBytesToTransfer(4 * kDefaultTCPMSS); + if (simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + deadline - SimulatedNow()) && + termination_predicate()) { + return true; + } + } while (SimulatedNow() < deadline); + return false; + } + + void EnableAggregation(QuicByteCount aggregation_bytes, + QuicTime::Delta aggregation_timeout) { + switch_->port_queue(1)->EnableAggregation(aggregation_bytes, + aggregation_timeout); + } + + void SetConnectionOption(QuicTag option) { + SetConnectionOption(std::move(option), sender_); + } + + void SetConnectionOption(QuicTag option, Bbr2Sender* sender) { + QuicConfig config; + QuicTagVector options; + options.push_back(option); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender->SetFromConfig(config, Perspective::IS_SERVER); + } + + bool Bbr2ModeIsOneOf(const std::vector& expected_modes) const { + const Bbr2Mode mode = sender_->ExportDebugState().mode; + for (Bbr2Mode expected_mode : expected_modes) { + if (mode == expected_mode) { + return true; + } + } + return false; + } + + const RttStats* rtt_stats() { + return sender_endpoint_.connection()->sent_packet_manager().GetRttStats(); + } + + QuicConnection* sender_connection() { return sender_endpoint_.connection(); } + + Bbr2Sender::DebugState sender_debug_state() const { + return sender_->ExportDebugState(); + } + + const QuicConnectionStats& sender_connection_stats() { + return sender_connection()->GetStats(); + } + + QuicUnackedPacketMap* sender_unacked_map() { + return GetUnackedMap(sender_connection()); + } + + float sender_loss_rate_in_packets() { + return static_cast(sender_connection_stats().packets_lost) / + sender_connection_stats().packets_sent; + } + + simulator::QuicEndpoint sender_endpoint_; + simulator::QuicEndpoint receiver_endpoint_; + Bbr2Sender* sender_; + + std::unique_ptr switch_; + std::unique_ptr sender_policer_; + std::vector> network_links_; +}; + +TEST_F(Bbr2DefaultTopologyTest, NormalStartup) { + DefaultTopologyParams params; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw * 1.001 < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(3u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 3u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +TEST_F(Bbr2DefaultTopologyTest, NormalStartupB207) { + SetConnectionOption(kB207); + DefaultTopologyParams params; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(1u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 1u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); +} + +// Add extra_acked to CWND in STARTUP and exit STARTUP on a persistent queue. +TEST_F(Bbr2DefaultTopologyTest, NormalStartupB207andB205) { + SetConnectionOption(kB205); + SetConnectionOption(kB207); + DefaultTopologyParams params; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(1u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 2u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); +} + +// Add extra_acked to CWND in STARTUP and exit STARTUP on a persistent queue. +TEST_F(Bbr2DefaultTopologyTest, NormalStartupBB2S) { + SetQuicReloadableFlag(quic_bbr2_probe_two_rounds, true); + SetConnectionOption(kBB2S); + DefaultTopologyParams params; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw * 1.001 < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + // BB2S reduces 3 rounds without bandwidth growth to 2. + EXPECT_EQ(2u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 2u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); +} + +// Test a simple long data transfer in the default setup. +TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer) { + DefaultTopologyParams params; + CreateNetwork(params); + + // At startup make sure we are at the default. + EXPECT_EQ(kDefaultInitialCwndBytes, sender_->GetCongestionWindow()); + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // And that window is un-affected. + EXPECT_EQ(kDefaultInitialCwndBytes, sender_->GetCongestionWindow()); + + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.885 * kDefaultInitialCwndBytes, rtt_stats()->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + ASSERT_GE(params.BDP(), kDefaultInitialCwndBytes + kDefaultTCPMSS); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // The margin here is quite high, since there exists a possibility that the + // connection just exited high gain cycle. + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->smoothed_rtt(), 1.0f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB2RC) { + SetConnectionOption(kB2RC); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB201) { + SetConnectionOption(kB201); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB206) { + SetConnectionOption(kB206); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB207) { + SetConnectionOption(kB207); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBRB) { + SetConnectionOption(kBBRB); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBR4) { + SetQuicReloadableFlag(quic_bbr2_extra_acked_window, true); + SetConnectionOption(kBBR4); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBR5) { + SetQuicReloadableFlag(quic_bbr2_extra_acked_window, true); + SetConnectionOption(kBBR5); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBQ1) { + SetConnectionOption(kBBQ1); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferSmallBuffer) { + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.02f); + EXPECT_GE(sender_connection_stats().packets_lost, 0u); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferSmallBufferB2H2) { + SetConnectionOption(kB2H2); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.02f); + EXPECT_GE(sender_connection_stats().packets_lost, 0u); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer2RTTAggregationBytes) { + SetConnectionOption(kBSAO); + DefaultTopologyParams params; + CreateNetwork(params); + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_EQ(sender_loss_rate_in_packets(), 0); + // The margin here is high, because both link level aggregation and ack + // decimation can greatly increase smoothed rtt. + EXPECT_GE(params.RTT() * 5, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer2RTTAggregationBytesB201) { + SetConnectionOption(kB201); + DefaultTopologyParams params; + CreateNetwork(params); + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + // TODO(wub): Tighten the error bound once BSAO is default enabled. + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.5f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.01); + // The margin here is high, because both link level aggregation and ack + // decimation can greatly increase smoothed rtt. + EXPECT_GE(params.RTT() * 5, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferAckDecimation) { + SetConnectionOption(kBSAO); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.001); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 3, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.1f); +} + +// Test Bbr2's reaction to a 100x bandwidth decrease during a transfer. +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthDecrease)) { + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(20 * 1024 * 1024); + + // We can transfer ~12MB in the first 10 seconds. The rest ~8MB needs about + // 640 seconds. + simulator_.RunFor(QuicTime::Delta::FromSeconds(10)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth decreasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); + + // Now decrease the bottleneck bandwidth from 10Mbps to 100Kbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(800)); + EXPECT_TRUE(simulator_result); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B203 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB203)) { + SetConnectionOption(kB203); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(20 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBQ0 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseBBQ0)) { + SetConnectionOption(kBBQ0); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBQ0 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseBBQ0Aggregation)) { + SetConnectionOption(kBBQ0); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + // TODO(ianswett) Make these bound tighter once overestimation is reduced. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.6f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.90f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B202 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB202)) { + SetConnectionOption(kB202); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.1f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B202 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB202Aggregation)) { + SetConnectionOption(kB202); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.6f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.92f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer. +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncrease)) { + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer in the +// presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseAggregation)) { + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.60f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.91f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBHI +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseBBHI)) { + SetQuicReloadableFlag(quic_bbr2_simplify_inflight_hi, true); + SetConnectionOption(kBBHI); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBHI +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseBBHIAggregation)) { + SetQuicReloadableFlag(quic_bbr2_simplify_inflight_hi, true); + SetConnectionOption(kBBHI); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.60f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.90f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBHI +// and B202, which changes the exit criteria to be based on +// min_bytes_in_flight_in_round, in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseBBHI_B202Aggregation)) { + SetQuicReloadableFlag(quic_bbr2_simplify_inflight_hi, true); + SetConnectionOption(kBBHI); + SetConnectionOption(kB202); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.60f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 18% of the bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.85f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B204 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB204)) { + SetConnectionOption(kB204); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.25); + EXPECT_LE(sender_->ExportDebugState().max_ack_height, 2000u); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B204 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB204Aggregation)) { + SetConnectionOption(kB204); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, and B204 actually + // is increasing overestimation, which is surprising. + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.60f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + EXPECT_LE(sender_->ExportDebugState().max_ack_height, 10000u); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.95f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B205 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB205)) { + SetConnectionOption(kB205); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.10); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.1f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B205 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB205Aggregation)) { + SetConnectionOption(kB205); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.45f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.15); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 5% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.9f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BB2U +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseBB2U)) { + SetQuicReloadableFlag(quic_bbr2_probe_two_rounds, true); + SetConnectionOption(kBB2U); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.25); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.1f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BB2U +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseBB2UAggregation)) { + SetQuicReloadableFlag(quic_bbr2_probe_two_rounds, true); + SetConnectionOption(kBB2U); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 5MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(5 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.45f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 15% of the full bandwidth is observed. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.85f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BB2U +// and BBHI in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseBB2UandBBHIAggregation)) { + SetQuicReloadableFlag(quic_bbr2_probe_two_rounds, true); + SetConnectionOption(kBB2U); + SetQuicReloadableFlag(quic_bbr2_simplify_inflight_hi, true); + SetConnectionOption(kBBHI); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 5MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(5 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.45f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 15% of the full bandwidth is observed. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.85f); +} + +// Test the number of losses incurred by the startup phase in a situation when +// the buffer is less than BDP. +TEST_F(Bbr2DefaultTopologyTest, PacketLossOnSmallBufferStartup) { + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DriveOutOfStartup(params); + // Packet loss is smaller with a CWND gain of 2 than 2.889. + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); +} + +// Test the number of losses decreases with packet-conservation pacing. +TEST_F(Bbr2DefaultTopologyTest, PacketLossBBQ6SmallBufferStartup) { + SetConnectionOption(kBBQ2); // Increase CWND gain. + SetConnectionOption(kBBQ6); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DriveOutOfStartup(params); + EXPECT_LE(sender_loss_rate_in_packets(), 0.0575); + // bandwidth_lo is cleared exiting STARTUP. + EXPECT_EQ(sender_->ExportDebugState().bandwidth_lo, + QuicBandwidth::Infinite()); +} + +// Test the number of losses decreases with min_rtt packet-conservation pacing. +TEST_F(Bbr2DefaultTopologyTest, PacketLossBBQ7SmallBufferStartup) { + SetConnectionOption(kBBQ2); // Increase CWND gain. + SetConnectionOption(kBBQ7); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DriveOutOfStartup(params); + EXPECT_LE(sender_loss_rate_in_packets(), 0.06); + // bandwidth_lo is cleared exiting STARTUP. + EXPECT_EQ(sender_->ExportDebugState().bandwidth_lo, + QuicBandwidth::Infinite()); +} + +// Test the number of losses decreases with Inflight packet-conservation pacing. +TEST_F(Bbr2DefaultTopologyTest, PacketLossBBQ8SmallBufferStartup) { + SetConnectionOption(kBBQ2); // Increase CWND gain. + SetConnectionOption(kBBQ8); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DriveOutOfStartup(params); + EXPECT_LE(sender_loss_rate_in_packets(), 0.065); + // bandwidth_lo is cleared exiting STARTUP. + EXPECT_EQ(sender_->ExportDebugState().bandwidth_lo, + QuicBandwidth::Infinite()); +} + +// Test the number of losses decreases with CWND packet-conservation pacing. +TEST_F(Bbr2DefaultTopologyTest, PacketLossBBQ9SmallBufferStartup) { + SetConnectionOption(kBBQ2); // Increase CWND gain. + SetConnectionOption(kBBQ9); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + DriveOutOfStartup(params); + EXPECT_LE(sender_loss_rate_in_packets(), 0.065); + // bandwidth_lo is cleared exiting STARTUP. + EXPECT_EQ(sender_->ExportDebugState().bandwidth_lo, + QuicBandwidth::Infinite()); +} + +// Verify the behavior of the algorithm in the case when the connection sends +// small bursts of data after sending continuously for a while. +TEST_F(Bbr2DefaultTopologyTest, ApplicationLimitedBursts) { + DefaultTopologyParams params; + CreateNetwork(params); + + EXPECT_FALSE(sender_->HasGoodBandwidthEstimateForResumption()); + DriveOutOfStartup(params); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + EXPECT_TRUE(sender_->HasGoodBandwidthEstimateForResumption()); + + SendBursts(params, 20, 512, QuicTime::Delta::FromSeconds(3)); + EXPECT_TRUE(sender_->ExportDebugState().last_sample_is_app_limited); + EXPECT_TRUE(sender_->HasGoodBandwidthEstimateForResumption()); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); +} + +// Verify the behavior of the algorithm in the case when the connection sends +// small bursts of data and then starts sending continuously. +TEST_F(Bbr2DefaultTopologyTest, ApplicationLimitedBurstsWithoutPrior) { + DefaultTopologyParams params; + CreateNetwork(params); + + SendBursts(params, 40, 512, QuicTime::Delta::FromSeconds(3)); + EXPECT_TRUE(sender_->ExportDebugState().last_sample_is_app_limited); + + DriveOutOfStartup(params); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +// Verify that the DRAIN phase works correctly. +TEST_F(Bbr2DefaultTopologyTest, Drain) { + DefaultTopologyParams params; + CreateNetwork(params); + + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(10); + // Get the queue at the bottleneck, which is the outgoing queue at the port to + // which the receiver is connected. + const simulator::Queue* queue = switch_->port_queue(2); + bool simulator_result; + + // We have no intention of ever finishing this transfer. + sender_endpoint_.AddBytesToTransfer(100 * 1024 * 1024); + + // Run the startup, and verify that it fills up the queue. + ASSERT_EQ(Bbr2Mode::STARTUP, sender_->ExportDebugState().mode); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().mode != Bbr2Mode::STARTUP; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_APPROX_EQ(sender_->BandwidthEstimate() * (1 / 2.885f), + sender_->PacingRate(0), 0.01f); + + // BBR uses CWND gain of 2 during STARTUP, hence it will fill the buffer with + // approximately 1 BDP. Here, we use 0.95 to give some margin for error. + EXPECT_GE(queue->bytes_queued(), 0.95 * params.BDP()); + + // Observe increased RTT due to bufferbloat. + const QuicTime::Delta queueing_delay = + params.test_link.bandwidth.TransferTime(queue->bytes_queued()); + EXPECT_APPROX_EQ(params.RTT() + queueing_delay, rtt_stats()->latest_rtt(), + 0.1f); + + // Transition to the drain phase and verify that it makes the queue + // have at most a BDP worth of packets. + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().mode != Bbr2Mode::DRAIN; }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(Bbr2Mode::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_LE(queue->bytes_queued(), params.BDP()); + + // Wait for a few round trips and ensure we're in appropriate phase of gain + // cycling before taking an RTT measurement. + const QuicRoundTripCount start_round_trip = + sender_->ExportDebugState().round_trip_count; + simulator_result = simulator_.RunUntilOrTimeout( + [this, start_round_trip]() { + const auto& debug_state = sender_->ExportDebugState(); + QuicRoundTripCount rounds_passed = + debug_state.round_trip_count - start_round_trip; + return rounds_passed >= 4 && debug_state.mode == Bbr2Mode::PROBE_BW && + debug_state.probe_bw.phase == CyclePhase::PROBE_REFILL; + }, + timeout); + ASSERT_TRUE(simulator_result); + + // Observe the bufferbloat go away. + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->smoothed_rtt(), 0.1f); +} + +// Ensure that a connection that is app-limited and is at sufficiently low +// bandwidth will not exit high gain phase, and similarly ensure that the +// connection will exit low gain early if the number of bytes in flight is low. +TEST_F(Bbr2DefaultTopologyTest, InFlightAwareGainCycling) { + DefaultTopologyParams params; + CreateNetwork(params); + DriveOutOfStartup(params); + + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result; + + // Start a few cycles prior to the high gain one. + simulator_result = SendUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().probe_bw.phase == + CyclePhase::PROBE_REFILL; + }, + timeout); + ASSERT_TRUE(simulator_result); + + // Send at 10% of available rate. Run for 3 seconds, checking in the middle + // and at the end. The pacing gain should be high throughout. + QuicBandwidth target_bandwidth = 0.1f * params.BottleneckBandwidth(); + QuicTime::Delta burst_interval = QuicTime::Delta::FromMilliseconds(300); + for (int i = 0; i < 2; i++) { + SendBursts(params, 5, target_bandwidth * burst_interval, burst_interval); + EXPECT_EQ(Bbr2Mode::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(CyclePhase::PROBE_UP, sender_->ExportDebugState().probe_bw.phase); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.02f); + } + + // Now that in-flight is almost zero and the pacing gain is still above 1, + // send approximately 1.4 BDPs worth of data. This should cause the PROBE_BW + // mode to enter low gain cycle(PROBE_DOWN), and exit it earlier than one + // min_rtt due to running out of data to send. + sender_endpoint_.AddBytesToTransfer(1.4 * params.BDP()); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().probe_bw.phase == + CyclePhase::PROBE_DOWN; + }, + timeout); + ASSERT_TRUE(simulator_result); + simulator_.RunFor(0.75 * sender_->ExportDebugState().min_rtt); + EXPECT_EQ(Bbr2Mode::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(CyclePhase::PROBE_CRUISE, + sender_->ExportDebugState().probe_bw.phase); +} + +// Test exiting STARTUP earlier upon loss due to loss. +TEST_F(Bbr2DefaultTopologyTest, ExitStartupDueToLoss) { + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_GE(2u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 1u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_NE(0u, sender_connection_stats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + EXPECT_GT(sender_->ExportDebugState().inflight_hi, 1.2f * params.BDP()); +} + +// Test exiting STARTUP earlier upon loss due to loss when connection option +// B2SL is used. +TEST_F(Bbr2DefaultTopologyTest, ExitStartupDueToLossB2SL) { + SetConnectionOption(kB2SL); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_GE(2u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 1u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_NE(0u, sender_connection_stats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + EXPECT_APPROX_EQ(sender_->ExportDebugState().inflight_hi, params.BDP(), 0.1f); +} + +// Verifies that in STARTUP, if we exceed loss threshold in a round, we exit +// STARTUP at the end of the round even if there's enough bandwidth growth. +TEST_F(Bbr2DefaultTopologyTest, ExitStartupDueToLossB2NE) { + // Set up flags such that any loss will be considered "too high". + SetQuicFlag(quic_bbr2_default_startup_full_loss_count, 0); + SetQuicFlag(quic_bbr2_default_loss_threshold, 0.0); + + sender_ = SetupBbr2Sender(&sender_endpoint_, /*old_sender=*/nullptr); + + SetConnectionOption(kB2NE); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(sender_->ExportDebugState().round_trip_count, max_bw_round); + EXPECT_EQ( + 0u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_NE(0u, sender_connection_stats().packets_lost); +} + +TEST_F(Bbr2DefaultTopologyTest, SenderPoliced) { + DefaultTopologyParams params; + params.sender_policer_params = TrafficPolicerParams(); + params.sender_policer_params->initial_burst_size = 1000 * 10; + params.sender_policer_params->max_bucket_size = 1000 * 100; + params.sender_policer_params->target_bandwidth = + params.BottleneckBandwidth() * 0.25; + + CreateNetwork(params); + + ASSERT_GE(params.BDP(), kDefaultInitialCwndBytes + kDefaultTCPMSS); + + DoSimpleTransfer(3 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + // TODO(wub): Fix (long-term) bandwidth overestimation in policer mode, then + // reduce the loss rate upper bound. + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); +} + +// TODO(wub): Add other slowstart stats to BBRv2. +TEST_F(Bbr2DefaultTopologyTest, StartupStats) { + DefaultTopologyParams params; + CreateNetwork(params); + + DriveOutOfStartup(params); + ASSERT_FALSE(sender_->InSlowStart()); + + const QuicConnectionStats& stats = sender_connection_stats(); + // The test explicitly replaces the default-created send algorithm with the + // one created by the test. slowstart_count increaments every time a BBR + // sender is created. + EXPECT_GE(stats.slowstart_count, 1u); + EXPECT_FALSE(stats.slowstart_duration.IsRunning()); + EXPECT_THAT(stats.slowstart_duration.GetTotalElapsedTime(), + AllOf(Ge(QuicTime::Delta::FromMilliseconds(500)), + Le(QuicTime::Delta::FromMilliseconds(1500)))); + EXPECT_EQ(stats.slowstart_duration.GetTotalElapsedTime(), + QuicConnectionPeer::GetSentPacketManager(sender_connection()) + ->GetSlowStartDuration()); +} + +TEST_F(Bbr2DefaultTopologyTest, ProbeUpAdaptInflightHiGradually) { + DefaultTopologyParams params; + CreateNetwork(params); + + DriveOutOfStartup(params); + + AckedPacketVector acked_packets; + QuicPacketNumber acked_packet_number = + sender_unacked_map()->GetLeastUnacked(); + for (auto& info : *sender_unacked_map()) { + acked_packets.emplace_back(acked_packet_number++, info.bytes_sent, + SimulatedNow()); + } + + // Advance time significantly so the OnCongestionEvent enters PROBE_REFILL. + QuicTime now = SimulatedNow() + QuicTime::Delta::FromSeconds(5); + auto next_packet_number = sender_unacked_map()->largest_sent_packet() + 1; + sender_->OnCongestionEvent( + /*rtt_updated=*/true, sender_unacked_map()->bytes_in_flight(), now, + acked_packets, {}, 0, 0); + ASSERT_EQ(CyclePhase::PROBE_REFILL, + sender_->ExportDebugState().probe_bw.phase); + + // Send and Ack one packet to exit app limited and enter PROBE_UP. + sender_->OnPacketSent(now, /*bytes_in_flight=*/0, next_packet_number++, + kDefaultMaxPacketSize, HAS_RETRANSMITTABLE_DATA); + now = now + params.RTT(); + sender_->OnCongestionEvent( + /*rtt_updated=*/true, kDefaultMaxPacketSize, now, + {AckedPacket(next_packet_number - 1, kDefaultMaxPacketSize, now)}, {}, 0, + 0); + ASSERT_EQ(CyclePhase::PROBE_UP, sender_->ExportDebugState().probe_bw.phase); + + // Send 2 packets and lose the first one(50% loss) to exit PROBE_UP. + for (uint64_t i = 0; i < 2; ++i) { + sender_->OnPacketSent(now, /*bytes_in_flight=*/i * kDefaultMaxPacketSize, + next_packet_number++, kDefaultMaxPacketSize, + HAS_RETRANSMITTABLE_DATA); + } + now = now + params.RTT(); + sender_->OnCongestionEvent( + /*rtt_updated=*/true, kDefaultMaxPacketSize, now, + {AckedPacket(next_packet_number - 1, kDefaultMaxPacketSize, now)}, + {LostPacket(next_packet_number - 2, kDefaultMaxPacketSize)}, 0, 0); + + QuicByteCount inflight_hi = sender_->ExportDebugState().inflight_hi; + EXPECT_LT(2 * kDefaultMaxPacketSize, inflight_hi); +} + +// Ensures bandwidth estimate does not change after a loss only event. +TEST_F(Bbr2DefaultTopologyTest, LossOnlyCongestionEvent) { + DefaultTopologyParams params; + CreateNetwork(params); + + DriveOutOfStartup(params); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // Send some bursts, each burst increments round count by 1, since it only + // generates small, app-limited samples, the max_bandwidth_filter_ will not be + // updated. + SendBursts(params, 20, 512, QuicTime::Delta::FromSeconds(3)); + + // Run until we have something in flight. + sender_endpoint_.AddBytesToTransfer(50 * 1024 * 1024); + bool simulator_result = simulator_.RunUntilOrTimeout( + [&]() { return sender_unacked_map()->bytes_in_flight() > 0; }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + + const QuicBandwidth prior_bandwidth_estimate = sender_->BandwidthEstimate(); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), prior_bandwidth_estimate, + 0.01f); + + // Lose the least unacked packet. + LostPacketVector lost_packets; + lost_packets.emplace_back( + sender_connection()->sent_packet_manager().GetLeastUnacked(), + kDefaultMaxPacketSize); + + QuicTime now = simulator_.GetClock()->Now() + params.RTT() * 0.25; + sender_->OnCongestionEvent(false, sender_unacked_map()->bytes_in_flight(), + now, {}, lost_packets, 0, 0); + + // Bandwidth estimate should not change for the loss only event. + EXPECT_EQ(prior_bandwidth_estimate, sender_->BandwidthEstimate()); +} + +// After quiescence, if the sender is in PROBE_RTT, it should transition to +// PROBE_BW immediately on the first sent packet after quiescence. +TEST_F(Bbr2DefaultTopologyTest, ProbeRttAfterQuiescenceImmediatelyExits) { + DefaultTopologyParams params; + CreateNetwork(params); + + DriveOutOfStartup(params); + + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(15); + bool simulator_result; + + // Keep sending until reach PROBE_RTT. + simulator_result = SendUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().mode == Bbr2Mode::PROBE_RTT; + }, + timeout); + ASSERT_TRUE(simulator_result); + + // Wait for entering a quiescence of 5 seconds. + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { + return sender_unacked_map()->bytes_in_flight() == 0 && + sender_->ExportDebugState().mode == Bbr2Mode::PROBE_RTT; + }, + timeout)); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(5)); + + // Send one packet to exit quiescence. + EXPECT_EQ(sender_->ExportDebugState().mode, Bbr2Mode::PROBE_RTT); + sender_->OnPacketSent(SimulatedNow(), /*bytes_in_flight=*/0, + sender_unacked_map()->largest_sent_packet() + 1, + kDefaultMaxPacketSize, HAS_RETRANSMITTABLE_DATA); + + EXPECT_EQ(sender_->ExportDebugState().mode, Bbr2Mode::PROBE_BW); +} + +TEST_F(Bbr2DefaultTopologyTest, ProbeBwAfterQuiescencePostponeMinRttTimestamp) { + DefaultTopologyParams params; + CreateNetwork(params); + + DriveOutOfStartup(params); + + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result; + + // Keep sending until reach PROBE_REFILL. + simulator_result = SendUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().probe_bw.phase == + CyclePhase::PROBE_REFILL; + }, + timeout); + ASSERT_TRUE(simulator_result); + + const QuicTime min_rtt_timestamp_before_idle = + sender_->ExportDebugState().min_rtt_timestamp; + + // Wait for entering a quiescence of 15 seconds. + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { return sender_unacked_map()->bytes_in_flight() == 0; }, + params.RTT() + timeout)); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + + // Send some data to exit quiescence. + SendBursts(params, 1, kDefaultTCPMSS, QuicTime::Delta::Zero()); + const QuicTime min_rtt_timestamp_after_idle = + sender_->ExportDebugState().min_rtt_timestamp; + + EXPECT_LT(min_rtt_timestamp_before_idle + QuicTime::Delta::FromSeconds(14), + min_rtt_timestamp_after_idle); +} + +TEST_F(Bbr2DefaultTopologyTest, SwitchToBbr2MidConnection) { + QuicTime now = QuicTime::Zero(); + BbrSender old_sender(sender_connection()->clock()->Now(), + sender_connection()->sent_packet_manager().GetRttStats(), + GetUnackedMap(sender_connection()), + kDefaultInitialCwndPackets + 1, + GetQuicFlag(quic_max_congestion_window), &random_, + QuicConnectionPeer::GetStats(sender_connection())); + + QuicPacketNumber next_packet_number(1); + + // Send packets 1-4. + while (next_packet_number < QuicPacketNumber(5)) { + now = now + QuicTime::Delta::FromMilliseconds(10); + + old_sender.OnPacketSent(now, /*bytes_in_flight=*/0, next_packet_number++, + /*bytes=*/1350, HAS_RETRANSMITTABLE_DATA); + } + + // Switch from |old_sender| to |sender_|. + const QuicByteCount old_sender_cwnd = old_sender.GetCongestionWindow(); + sender_ = SetupBbr2Sender(&sender_endpoint_, &old_sender); + EXPECT_EQ(old_sender_cwnd, sender_->GetCongestionWindow()); + + // Send packets 5-7. + now = now + QuicTime::Delta::FromMilliseconds(10); + sender_->OnPacketSent(now, /*bytes_in_flight=*/1350, next_packet_number++, + /*bytes=*/23, NO_RETRANSMITTABLE_DATA); + + now = now + QuicTime::Delta::FromMilliseconds(10); + sender_->OnPacketSent(now, /*bytes_in_flight=*/1350, next_packet_number++, + /*bytes=*/767, HAS_RETRANSMITTABLE_DATA); + + QuicByteCount bytes_in_flight = 767; + while (next_packet_number < QuicPacketNumber(30)) { + now = now + QuicTime::Delta::FromMilliseconds(10); + bytes_in_flight += 1350; + sender_->OnPacketSent(now, bytes_in_flight, next_packet_number++, + /*bytes=*/1350, HAS_RETRANSMITTABLE_DATA); + } + + // Ack 1 & 2. + AckedPacketVector acked = { + AckedPacket(QuicPacketNumber(1), /*bytes_acked=*/0, QuicTime::Zero()), + AckedPacket(QuicPacketNumber(2), /*bytes_acked=*/0, QuicTime::Zero()), + }; + now = now + QuicTime::Delta::FromMilliseconds(2000); + sender_->OnCongestionEvent(true, bytes_in_flight, now, acked, {}, 0, 0); + + // Send 30-41. + while (next_packet_number < QuicPacketNumber(42)) { + now = now + QuicTime::Delta::FromMilliseconds(10); + bytes_in_flight += 1350; + sender_->OnPacketSent(now, bytes_in_flight, next_packet_number++, + /*bytes=*/1350, HAS_RETRANSMITTABLE_DATA); + } + + // Ack 3. + acked = { + AckedPacket(QuicPacketNumber(3), /*bytes_acked=*/0, QuicTime::Zero()), + }; + now = now + QuicTime::Delta::FromMilliseconds(2000); + sender_->OnCongestionEvent(true, bytes_in_flight, now, acked, {}, 0, 0); + + // Send 42. + now = now + QuicTime::Delta::FromMilliseconds(10); + bytes_in_flight += 1350; + sender_->OnPacketSent(now, bytes_in_flight, next_packet_number++, + /*bytes=*/1350, HAS_RETRANSMITTABLE_DATA); + + // Ack 4-7. + acked = { + AckedPacket(QuicPacketNumber(4), /*bytes_acked=*/0, QuicTime::Zero()), + AckedPacket(QuicPacketNumber(5), /*bytes_acked=*/0, QuicTime::Zero()), + AckedPacket(QuicPacketNumber(6), /*bytes_acked=*/767, QuicTime::Zero()), + AckedPacket(QuicPacketNumber(7), /*bytes_acked=*/1350, QuicTime::Zero()), + }; + now = now + QuicTime::Delta::FromMilliseconds(2000); + sender_->OnCongestionEvent(true, bytes_in_flight, now, acked, {}, 0, 0); + EXPECT_FALSE(sender_->BandwidthEstimate().IsZero()); +} + +TEST_F(Bbr2DefaultTopologyTest, AdjustNetworkParameters) { + DefaultTopologyParams params; + CreateNetwork(params); + + QUIC_LOG(INFO) << "Initial cwnd: " << sender_debug_state().congestion_window + << "\nInitial pacing rate: " << sender_->PacingRate(0) + << "\nInitial bandwidth estimate: " + << sender_->BandwidthEstimate() + << "\nInitial rtt: " << sender_debug_state().min_rtt; + + sender_connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(params.BottleneckBandwidth(), + params.RTT(), + /*allow_cwnd_to_decrease=*/false)); + + EXPECT_EQ(params.BDP(), sender_->ExportDebugState().congestion_window); + + EXPECT_EQ(params.BottleneckBandwidth(), + sender_->PacingRate(/*bytes_in_flight=*/0)); + EXPECT_NE(params.BottleneckBandwidth(), sender_->BandwidthEstimate()); + + EXPECT_APPROX_EQ(params.RTT(), sender_->ExportDebugState().min_rtt, 0.01f); + + DriveOutOfStartup(params); +} + +TEST_F(Bbr2DefaultTopologyTest, + 200InitialCongestionWindowWithNetworkParameterAdjusted) { + DefaultTopologyParams params; + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(1 * 1024 * 1024); + + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + sender_connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(1024 * params.BottleneckBandwidth(), + QuicTime::Delta::Zero(), false)); + + // Verify cwnd is capped at 200. + EXPECT_EQ(200 * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + EXPECT_GT(1024 * params.BottleneckBandwidth(), sender_->PacingRate(0)); +} + +TEST_F(Bbr2DefaultTopologyTest, + 100InitialCongestionWindowFromNetworkParameter) { + DefaultTopologyParams params; + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(1 * 1024 * 1024); + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + SendAlgorithmInterface::NetworkParams network_params( + 1024 * params.BottleneckBandwidth(), QuicTime::Delta::Zero(), false); + network_params.max_initial_congestion_window = 100; + sender_connection()->AdjustNetworkParameters(network_params); + + // Verify cwnd is capped at 100. + EXPECT_EQ(100 * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + EXPECT_GT(1024 * params.BottleneckBandwidth(), sender_->PacingRate(0)); +} + +TEST_F(Bbr2DefaultTopologyTest, + 100InitialCongestionWindowWithNetworkParameterAdjusted) { + SetConnectionOption(kICW1); + DefaultTopologyParams params; + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(1 * 1024 * 1024); + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + sender_connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(1024 * params.BottleneckBandwidth(), + QuicTime::Delta::Zero(), false)); + + // Verify cwnd is capped at 100. + EXPECT_EQ(100 * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + EXPECT_GT(1024 * params.BottleneckBandwidth(), sender_->PacingRate(0)); +} + +// All Bbr2MultiSenderTests uses the following network topology: +// +// Sender 0 (A Bbr2Sender) +// | +// | <-- local_links[0] +// | +// | Sender N (1 <= N < kNumLocalLinks) (May or may not be a Bbr2Sender) +// | | +// | | <-- local_links[N] +// | | +// Network switch +// * <-- the bottleneck queue in the direction +// | of the receiver +// | +// | <-- test_link +// | +// | +// Receiver +class MultiSenderTopologyParams { + public: + static constexpr size_t kNumLocalLinks = 8; + std::array local_links = { + LinkParams(10000, 1987), LinkParams(10000, 1993), LinkParams(10000, 1997), + LinkParams(10000, 1999), LinkParams(10000, 2003), LinkParams(10000, 2011), + LinkParams(10000, 2017), LinkParams(10000, 2027), + }; + + LinkParams test_link = LinkParams(4000, 30000); + + const simulator::SwitchPortNumber switch_port_count = kNumLocalLinks + 1; + + // Network switch queue capacity, in number of BDPs. + float switch_queue_capacity_in_bdp = 2; + + QuicBandwidth BottleneckBandwidth() const { + // Make sure all local links have a higher bandwidth than the test link. + for (size_t i = 0; i < local_links.size(); ++i) { + QUICHE_CHECK_GT(local_links[i].bandwidth, test_link.bandwidth); + } + return test_link.bandwidth; + } + + // Sender n's round trip time of a single full size packet. + QuicTime::Delta Rtt(size_t n) const { + return 2 * (local_links[n].delay + test_link.delay + + local_links[n].bandwidth.TransferTime(kMaxOutgoingPacketSize) + + test_link.bandwidth.TransferTime(kMaxOutgoingPacketSize)); + } + + QuicByteCount Bdp(size_t n) const { return BottleneckBandwidth() * Rtt(n); } + + QuicByteCount SwitchQueueCapacity() const { + return switch_queue_capacity_in_bdp * Bdp(1); + } + + std::string ToString() const { + std::ostringstream os; + os << "{ BottleneckBandwidth: " << BottleneckBandwidth(); + for (size_t i = 0; i < local_links.size(); ++i) { + os << " RTT_" << i << ": " << Rtt(i) << " BDP_" << i << ": " << Bdp(i); + } + os << " BottleneckQueueSize: " << SwitchQueueCapacity() << "}"; + return os.str(); + } +}; + +class Bbr2MultiSenderTest : public Bbr2SimulatorTest { + protected: + Bbr2MultiSenderTest() { + uint64_t first_connection_id = 42; + std::vector receiver_endpoint_pointers; + for (size_t i = 0; i < MultiSenderTopologyParams::kNumLocalLinks; ++i) { + std::string sender_name = absl::StrCat("Sender", i + 1); + std::string receiver_name = absl::StrCat("Receiver", i + 1); + sender_endpoints_.push_back(std::make_unique( + &simulator_, sender_name, receiver_name, Perspective::IS_CLIENT, + TestConnectionId(first_connection_id + i))); + receiver_endpoints_.push_back(std::make_unique( + &simulator_, receiver_name, sender_name, Perspective::IS_SERVER, + TestConnectionId(first_connection_id + i))); + receiver_endpoint_pointers.push_back(receiver_endpoints_.back().get()); + } + receiver_multiplexer_ = + std::make_unique( + "Receiver multiplexer", receiver_endpoint_pointers); + sender_0_ = SetupBbr2Sender(sender_endpoints_[0].get()); + } + + ~Bbr2MultiSenderTest() { + const auto* test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + QUIC_LOG(INFO) << "Bbr2MultiSenderTest." << test_info->name() + << " completed at simulated time: " + << SimulatedNow().ToDebuggingValue() / 1e6 + << " sec. Per sender stats:"; + for (size_t i = 0; i < sender_endpoints_.size(); ++i) { + QUIC_LOG(INFO) << "sender[" << i << "]: " + << sender_connection(i) + ->sent_packet_manager() + .GetSendAlgorithm() + ->GetCongestionControlType() + << ", packet_loss:" + << 100.0 * sender_loss_rate_in_packets(i) << "%"; + } + } + + Bbr2Sender* SetupBbr2Sender(simulator::QuicEndpoint* endpoint) { + // Ownership of the sender will be overtaken by the endpoint. + Bbr2Sender* sender = new Bbr2Sender( + endpoint->connection()->clock()->Now(), + endpoint->connection()->sent_packet_manager().GetRttStats(), + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(endpoint->connection())), + kDefaultInitialCwndPackets, GetQuicFlag(quic_max_congestion_window), + &random_, QuicConnectionPeer::GetStats(endpoint->connection()), + nullptr); + // TODO(ianswett): Add dedicated tests for this option until it becomes + // the default behavior. + SetConnectionOption(sender, kBBRA); + + QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); + endpoint->RecordTrace(); + return sender; + } + + BbrSender* SetupBbrSender(simulator::QuicEndpoint* endpoint) { + // Ownership of the sender will be overtaken by the endpoint. + BbrSender* sender = new BbrSender( + endpoint->connection()->clock()->Now(), + endpoint->connection()->sent_packet_manager().GetRttStats(), + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(endpoint->connection())), + kDefaultInitialCwndPackets, GetQuicFlag(quic_max_congestion_window), + &random_, QuicConnectionPeer::GetStats(endpoint->connection())); + QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); + endpoint->RecordTrace(); + return sender; + } + + // reno => Reno. !reno => Cubic. + TcpCubicSenderBytes* SetupTcpSender(simulator::QuicEndpoint* endpoint, + bool reno) { + // Ownership of the sender will be overtaken by the endpoint. + TcpCubicSenderBytes* sender = new TcpCubicSenderBytes( + endpoint->connection()->clock(), + endpoint->connection()->sent_packet_manager().GetRttStats(), reno, + kDefaultInitialCwndPackets, GetQuicFlag(quic_max_congestion_window), + QuicConnectionPeer::GetStats(endpoint->connection())); + QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); + endpoint->RecordTrace(); + return sender; + } + + void SetConnectionOption(SendAlgorithmInterface* sender, QuicTag option) { + QuicConfig config; + QuicTagVector options; + options.push_back(option); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender->SetFromConfig(config, Perspective::IS_SERVER); + } + + void CreateNetwork(const MultiSenderTopologyParams& params) { + QUIC_LOG(INFO) << "CreateNetwork with parameters: " << params.ToString(); + switch_ = std::make_unique(&simulator_, "Switch", + params.switch_port_count, + params.SwitchQueueCapacity()); + + network_links_.push_back(std::make_unique( + receiver_multiplexer_.get(), switch_->port(1), + params.test_link.bandwidth, params.test_link.delay)); + for (size_t i = 0; i < MultiSenderTopologyParams::kNumLocalLinks; ++i) { + simulator::SwitchPortNumber port_number = i + 2; + network_links_.push_back(std::make_unique( + sender_endpoints_[i].get(), switch_->port(port_number), + params.local_links[i].bandwidth, params.local_links[i].delay)); + } + } + + QuicConnection* sender_connection(size_t which) { + return sender_endpoints_[which]->connection(); + } + + const QuicConnectionStats& sender_connection_stats(size_t which) { + return sender_connection(which)->GetStats(); + } + + float sender_loss_rate_in_packets(size_t which) { + return static_cast(sender_connection_stats(which).packets_lost) / + sender_connection_stats(which).packets_sent; + } + + std::vector> sender_endpoints_; + std::vector> receiver_endpoints_; + std::unique_ptr receiver_multiplexer_; + Bbr2Sender* sender_0_; + + std::unique_ptr switch_; + std::vector> network_links_; +}; + +TEST_F(Bbr2MultiSenderTest, Bbr2VsBbr2) { + SetupBbr2Sender(sender_endpoints_[1].get()); + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +TEST_F(Bbr2MultiSenderTest, Bbr2VsBbr2BBPD) { + SetConnectionOption(sender_0_, kBBPD); + Bbr2Sender* sender_1 = SetupBbr2Sender(sender_endpoints_[1].get()); + SetConnectionOption(sender_1, kBBPD); + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +TEST_F(Bbr2MultiSenderTest, QUIC_SLOW_TEST(MultipleBbr2s)) { + const int kTotalNumSenders = 6; + for (int i = 1; i < kTotalNumSenders; ++i) { + SetupBbr2Sender(sender_endpoints_[i].get()); + } + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time + << ". Now: " << SimulatedNow(); + + // Start all transfers. + for (int i = 0; i < kTotalNumSenders; ++i) { + if (i != 0) { + const QuicTime sender_start_time = + SimulatedNow() + QuicTime::Delta::FromSeconds(2); + bool simulator_result = simulator_.RunUntilOrTimeout( + [&]() { return SimulatedNow() >= sender_start_time; }, transfer_time); + ASSERT_TRUE(simulator_result); + } + + sender_endpoints_[i]->AddBytesToTransfer(transfer_size); + } + + // Wait for all transfers to finish. + QuicTime::Delta expected_total_transfer_time_upper_bound = + QuicTime::Delta::FromMicroseconds(kTotalNumSenders * + transfer_time.ToMicroseconds() * 1.1); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + for (int i = 0; i < kTotalNumSenders; ++i) { + if (receiver_endpoints_[i]->bytes_received() < transfer_size) { + return false; + } + } + return true; + }, + expected_total_transfer_time_upper_bound); + ASSERT_TRUE(simulator_result) + << "Expected upper bound: " << expected_total_transfer_time_upper_bound; +} + +/* The first 11 packets are sent at the same time, but the duration between the + * acks of the 1st and the 11th packet is 49 milliseconds, causing very low bw + * samples. This happens for both large and small buffers. + */ +/* +TEST_F(Bbr2MultiSenderTest, Bbr2VsBbr2LargeRttTinyBuffer) { + SetupBbr2Sender(sender_endpoints_[1].get()); + + MultiSenderTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.05; + params.test_link.delay = QuicTime::Delta::FromSeconds(1); + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} +*/ + +TEST_F(Bbr2MultiSenderTest, Bbr2VsBbr1) { + SetupBbrSender(sender_endpoints_[1].get()); + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +TEST_F(Bbr2MultiSenderTest, QUIC_SLOW_TEST(Bbr2VsReno)) { + SetupTcpSender(sender_endpoints_[1].get(), /*reno=*/true); + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +TEST_F(Bbr2MultiSenderTest, QUIC_SLOW_TEST(Bbr2VsRenoB2RC)) { + SetConnectionOption(sender_0_, kB2RC); + SetupTcpSender(sender_endpoints_[1].get(), /*reno=*/true); + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +TEST_F(Bbr2MultiSenderTest, QUIC_SLOW_TEST(Bbr2VsCubic)) { + SetupTcpSender(sender_endpoints_[1].get(), /*reno=*/false); + + MultiSenderTopologyParams params; + CreateNetwork(params); + + const QuicByteCount transfer_size = 50 * 1024 * 1024; + const QuicTime::Delta transfer_time = + params.BottleneckBandwidth().TransferTime(transfer_size); + QUIC_LOG(INFO) << "Single flow transfer time: " << transfer_time; + + // Transfer 10% of data in first transfer. + sender_endpoints_[0]->AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() >= 0.1 * transfer_size; + }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + sender_endpoints_[1]->AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_endpoints_[0]->bytes_received() == transfer_size && + receiver_endpoints_[1]->bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_startup.cc b/quiche/quic/core/congestion_control/bbr2_startup.cc new file mode 100644 index 000000000000..3c84f514be46 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_startup.cc @@ -0,0 +1,154 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr2_startup.h" + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/congestion_control/bbr2_sender.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +Bbr2StartupMode::Bbr2StartupMode(const Bbr2Sender* sender, + Bbr2NetworkModel* model, QuicTime now) + : Bbr2ModeBase(sender, model) { + // Increment, instead of reset startup stats, so we don't lose data recorded + // before QuicConnection switched send algorithm to BBRv2. + ++sender_->connection_stats_->slowstart_count; + if (!sender_->connection_stats_->slowstart_duration.IsRunning()) { + sender_->connection_stats_->slowstart_duration.Start(now); + } + // Enter() is never called for Startup, so the gains needs to be set here. + model_->set_pacing_gain(Params().startup_pacing_gain); + model_->set_cwnd_gain(Params().startup_cwnd_gain); +} + +void Bbr2StartupMode::Enter(QuicTime /*now*/, + const Bbr2CongestionEvent* /*congestion_event*/) { + QUIC_BUG(quic_bug_10463_1) << "Bbr2StartupMode::Enter should not be called"; +} + +void Bbr2StartupMode::Leave(QuicTime now, + const Bbr2CongestionEvent* /*congestion_event*/) { + sender_->connection_stats_->slowstart_duration.Stop(now); + // Clear bandwidth_lo if it's set during STARTUP. + model_->clear_bandwidth_lo(); +} + +Bbr2Mode Bbr2StartupMode::OnCongestionEvent( + QuicByteCount /*prior_in_flight*/, QuicTime /*event_time*/, + const AckedPacketVector& /*acked_packets*/, + const LostPacketVector& /*lost_packets*/, + const Bbr2CongestionEvent& congestion_event) { + if (model_->full_bandwidth_reached()) { + QUIC_BUG() << "In STARTUP, but full_bandwidth_reached is true."; + return Bbr2Mode::DRAIN; + } + if (!congestion_event.end_of_round_trip) { + return Bbr2Mode::STARTUP; + } + bool has_bandwidth_growth = model_->HasBandwidthGrowth(congestion_event); + if (Params().max_startup_queue_rounds > 0 && !has_bandwidth_growth) { + // 1.75 is less than the 2x CWND gain, but substantially more than 1.25x, + // the minimum bandwidth increase expected during STARTUP. + model_->CheckPersistentQueue(congestion_event, 1.75); + } + // TCP BBR always exits upon excessive losses. QUIC BBRv1 does not exit + // upon excessive losses, if enough bandwidth growth is observed or if the + // sample was app limited. + if (Params().always_exit_startup_on_excess_loss || + (!congestion_event.last_packet_send_state.is_app_limited && + !has_bandwidth_growth)) { + CheckExcessiveLosses(congestion_event); + } + + if (Params().decrease_startup_pacing_at_end_of_round) { + QUICHE_DCHECK_GT(model_->pacing_gain(), 0); + if (!congestion_event.last_packet_send_state.is_app_limited) { + // Multiply by startup_pacing_gain, so if the bandwidth doubles, + // the pacing gain will be the full startup_pacing_gain. + if (max_bw_at_round_beginning_ > QuicBandwidth::Zero()) { + const float bandwidth_ratio = + std::max(1., model_->MaxBandwidth().ToBitsPerSecond() / + static_cast( + max_bw_at_round_beginning_.ToBitsPerSecond())); + // Even when bandwidth isn't increasing, use a gain large enough to + // cause a full_bw_threshold increase. + const float new_gain = + ((bandwidth_ratio - 1) * + (Params().startup_pacing_gain - Params().full_bw_threshold)) + + Params().full_bw_threshold; + // Allow the pacing gain to decrease. + model_->set_pacing_gain( + std::min(Params().startup_pacing_gain, new_gain)); + // Clear bandwidth_lo if it's less than the pacing rate. + // This avoids a constantly app-limited flow from having it's pacing + // gain effectively decreased below 1.25. + if (model_->bandwidth_lo() < + model_->MaxBandwidth() * model_->pacing_gain()) { + model_->clear_bandwidth_lo(); + } + } + max_bw_at_round_beginning_ = model_->MaxBandwidth(); + } + } + + // TODO(wub): Maybe implement STARTUP => PROBE_RTT. + return model_->full_bandwidth_reached() ? Bbr2Mode::DRAIN : Bbr2Mode::STARTUP; +} + +void Bbr2StartupMode::CheckExcessiveLosses( + const Bbr2CongestionEvent& congestion_event) { + QUICHE_DCHECK(congestion_event.end_of_round_trip); + + if (model_->full_bandwidth_reached()) { + return; + } + + // At the end of a round trip. Check if loss is too high in this round. + if (model_->IsInflightTooHigh(congestion_event, + Params().startup_full_loss_count)) { + QuicByteCount new_inflight_hi = model_->BDP(); + if (Params().startup_loss_exit_use_max_delivered_for_inflight_hi) { + if (new_inflight_hi < model_->max_bytes_delivered_in_round()) { + new_inflight_hi = model_->max_bytes_delivered_in_round(); + } + } + QUIC_DVLOG(3) << sender_ << " Exiting STARTUP due to loss at round " + << model_->RoundTripCount() + << ". inflight_hi:" << new_inflight_hi; + // TODO(ianswett): Add a shared method to set inflight_hi in the model. + model_->set_inflight_hi(new_inflight_hi); + model_->set_full_bandwidth_reached(); + sender_->connection_stats_->bbr_exit_startup_due_to_loss = true; + } +} + +Bbr2StartupMode::DebugState Bbr2StartupMode::ExportDebugState() const { + DebugState s; + s.full_bandwidth_reached = model_->full_bandwidth_reached(); + s.full_bandwidth_baseline = model_->full_bandwidth_baseline(); + s.round_trips_without_bandwidth_growth = + model_->rounds_without_bandwidth_growth(); + return s; +} + +std::ostream& operator<<(std::ostream& os, + const Bbr2StartupMode::DebugState& state) { + os << "[STARTUP] full_bandwidth_reached: " << state.full_bandwidth_reached + << "\n"; + os << "[STARTUP] full_bandwidth_baseline: " << state.full_bandwidth_baseline + << "\n"; + os << "[STARTUP] round_trips_without_bandwidth_growth: " + << state.round_trips_without_bandwidth_growth << "\n"; + return os; +} + +const Bbr2Params& Bbr2StartupMode::Params() const { return sender_->Params(); } + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr2_startup.h b/quiche/quic/core/congestion_control/bbr2_startup.h new file mode 100644 index 000000000000..b246b3c7d0ce --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr2_startup.h @@ -0,0 +1,68 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_STARTUP_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_STARTUP_H_ + +#include "quiche/quic/core/congestion_control/bbr2_misc.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class Bbr2Sender; +class QUIC_EXPORT_PRIVATE Bbr2StartupMode final : public Bbr2ModeBase { + public: + Bbr2StartupMode(const Bbr2Sender* sender, Bbr2NetworkModel* model, + QuicTime now); + + void Enter(QuicTime now, + const Bbr2CongestionEvent* congestion_event) override; + void Leave(QuicTime now, + const Bbr2CongestionEvent* congestion_event) override; + + Bbr2Mode OnCongestionEvent( + QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + const Bbr2CongestionEvent& congestion_event) override; + + Limits GetCwndLimits() const override { + // Inflight_lo is never set in STARTUP. + QUICHE_DCHECK_EQ(Bbr2NetworkModel::inflight_lo_default(), + model_->inflight_lo()); + return NoGreaterThan(model_->inflight_lo()); + } + + bool IsProbingForBandwidth() const override { return true; } + + Bbr2Mode OnExitQuiescence(QuicTime /*now*/, + QuicTime /*quiescence_start_time*/) override { + return Bbr2Mode::STARTUP; + } + + struct QUIC_EXPORT_PRIVATE DebugState { + bool full_bandwidth_reached; + QuicBandwidth full_bandwidth_baseline = QuicBandwidth::Zero(); + QuicRoundTripCount round_trips_without_bandwidth_growth; + }; + + DebugState ExportDebugState() const; + + private: + const Bbr2Params& Params() const; + + void CheckExcessiveLosses(const Bbr2CongestionEvent& congestion_event); + // Used when the pacing gain can decrease in STARTUP. + QuicBandwidth max_bw_at_round_beginning_ = QuicBandwidth::Zero(); +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const Bbr2StartupMode::DebugState& state); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR2_STARTUP_H_ diff --git a/quiche/quic/core/congestion_control/bbr_sender.cc b/quiche/quic/core/congestion_control/bbr_sender.cc new file mode 100644 index 000000000000..322e7aa5d94c --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr_sender.cc @@ -0,0 +1,896 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr_sender.h" + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_time_accumulator.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { +// Constants based on TCP defaults. +// The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. +// Does not inflate the pacing rate. +const QuicByteCount kDefaultMinimumCongestionWindow = 4 * kMaxSegmentSize; + +// The gain used for the STARTUP, equal to 2/ln(2). +const float kDefaultHighGain = 2.885f; +// The newly derived gain for STARTUP, equal to 4 * ln(2) +const float kDerivedHighGain = 2.773f; +// The newly derived CWND gain for STARTUP, 2. +const float kDerivedHighCWNDGain = 2.0f; +// The cycle of gains used during the PROBE_BW stage. +const float kPacingGain[] = {1.25, 0.75, 1, 1, 1, 1, 1, 1}; + +// The length of the gain cycle. +const size_t kGainCycleLength = sizeof(kPacingGain) / sizeof(kPacingGain[0]); +// The size of the bandwidth filter window, in round-trips. +const QuicRoundTripCount kBandwidthWindowSize = kGainCycleLength + 2; + +// The time after which the current min_rtt value expires. +const QuicTime::Delta kMinRttExpiry = QuicTime::Delta::FromSeconds(10); +// The minimum time the connection can spend in PROBE_RTT mode. +const QuicTime::Delta kProbeRttTime = QuicTime::Delta::FromMilliseconds(200); +// If the bandwidth does not increase by the factor of |kStartupGrowthTarget| +// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection +// will exit the STARTUP mode. +const float kStartupGrowthTarget = 1.25; +const QuicRoundTripCount kRoundTripsWithoutGrowthBeforeExitingStartup = 3; +} // namespace + +BbrSender::DebugState::DebugState(const BbrSender& sender) + : mode(sender.mode_), + max_bandwidth(sender.max_bandwidth_.GetBest()), + round_trip_count(sender.round_trip_count_), + gain_cycle_index(sender.cycle_current_offset_), + congestion_window(sender.congestion_window_), + is_at_full_bandwidth(sender.is_at_full_bandwidth_), + bandwidth_at_last_round(sender.bandwidth_at_last_round_), + rounds_without_bandwidth_gain(sender.rounds_without_bandwidth_gain_), + min_rtt(sender.min_rtt_), + min_rtt_timestamp(sender.min_rtt_timestamp_), + recovery_state(sender.recovery_state_), + recovery_window(sender.recovery_window_), + last_sample_is_app_limited(sender.last_sample_is_app_limited_), + end_of_app_limited_phase(sender.sampler_.end_of_app_limited_phase()) {} + +BbrSender::DebugState::DebugState(const DebugState& state) = default; + +BbrSender::BbrSender(QuicTime now, const RttStats* rtt_stats, + const QuicUnackedPacketMap* unacked_packets, + QuicPacketCount initial_tcp_congestion_window, + QuicPacketCount max_tcp_congestion_window, + QuicRandom* random, QuicConnectionStats* stats) + : rtt_stats_(rtt_stats), + unacked_packets_(unacked_packets), + random_(random), + stats_(stats), + mode_(STARTUP), + sampler_(unacked_packets, kBandwidthWindowSize), + round_trip_count_(0), + num_loss_events_in_round_(0), + bytes_lost_in_round_(0), + max_bandwidth_(kBandwidthWindowSize, QuicBandwidth::Zero(), 0), + min_rtt_(QuicTime::Delta::Zero()), + min_rtt_timestamp_(QuicTime::Zero()), + congestion_window_(initial_tcp_congestion_window * kDefaultTCPMSS), + initial_congestion_window_(initial_tcp_congestion_window * + kDefaultTCPMSS), + max_congestion_window_(max_tcp_congestion_window * kDefaultTCPMSS), + min_congestion_window_(kDefaultMinimumCongestionWindow), + high_gain_(kDefaultHighGain), + high_cwnd_gain_(kDefaultHighGain), + drain_gain_(1.f / kDefaultHighGain), + pacing_rate_(QuicBandwidth::Zero()), + pacing_gain_(1), + congestion_window_gain_(1), + congestion_window_gain_constant_( + static_cast(GetQuicFlag(quic_bbr_cwnd_gain))), + num_startup_rtts_(kRoundTripsWithoutGrowthBeforeExitingStartup), + cycle_current_offset_(0), + last_cycle_start_(QuicTime::Zero()), + is_at_full_bandwidth_(false), + rounds_without_bandwidth_gain_(0), + bandwidth_at_last_round_(QuicBandwidth::Zero()), + exiting_quiescence_(false), + exit_probe_rtt_at_(QuicTime::Zero()), + probe_rtt_round_passed_(false), + last_sample_is_app_limited_(false), + has_non_app_limited_sample_(false), + recovery_state_(NOT_IN_RECOVERY), + recovery_window_(max_congestion_window_), + slower_startup_(false), + rate_based_startup_(false), + enable_ack_aggregation_during_startup_(false), + expire_ack_aggregation_in_startup_(false), + drain_to_target_(false), + detect_overshooting_(false), + bytes_lost_while_detecting_overshooting_(0), + bytes_lost_multiplier_while_detecting_overshooting_(2), + cwnd_to_calculate_min_pacing_rate_(initial_congestion_window_), + max_congestion_window_with_network_parameters_adjusted_( + kMaxInitialCongestionWindow * kDefaultTCPMSS) { + if (stats_) { + // Clear some startup stats if |stats_| has been used by another sender, + // which happens e.g. when QuicConnection switch send algorithms. + stats_->slowstart_count = 0; + stats_->slowstart_duration = QuicTimeAccumulator(); + } + EnterStartupMode(now); + set_high_cwnd_gain(kDerivedHighCWNDGain); +} + +BbrSender::~BbrSender() {} + +void BbrSender::SetInitialCongestionWindowInPackets( + QuicPacketCount congestion_window) { + if (mode_ == STARTUP) { + initial_congestion_window_ = congestion_window * kDefaultTCPMSS; + congestion_window_ = congestion_window * kDefaultTCPMSS; + cwnd_to_calculate_min_pacing_rate_ = std::min( + initial_congestion_window_, cwnd_to_calculate_min_pacing_rate_); + } +} + +bool BbrSender::InSlowStart() const { return mode_ == STARTUP; } + +void BbrSender::OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, + QuicByteCount bytes, + HasRetransmittableData is_retransmittable) { + if (stats_ && InSlowStart()) { + ++stats_->slowstart_packets_sent; + stats_->slowstart_bytes_sent += bytes; + } + + last_sent_packet_ = packet_number; + + if (bytes_in_flight == 0 && sampler_.is_app_limited()) { + exiting_quiescence_ = true; + } + + sampler_.OnPacketSent(sent_time, packet_number, bytes, bytes_in_flight, + is_retransmittable); +} + +void BbrSender::OnPacketNeutered(QuicPacketNumber packet_number) { + sampler_.OnPacketNeutered(packet_number); +} + +bool BbrSender::CanSend(QuicByteCount bytes_in_flight) { + return bytes_in_flight < GetCongestionWindow(); +} + +QuicBandwidth BbrSender::PacingRate(QuicByteCount /*bytes_in_flight*/) const { + if (pacing_rate_.IsZero()) { + return high_gain_ * QuicBandwidth::FromBytesAndTimeDelta( + initial_congestion_window_, GetMinRtt()); + } + return pacing_rate_; +} + +QuicBandwidth BbrSender::BandwidthEstimate() const { + return max_bandwidth_.GetBest(); +} + +QuicByteCount BbrSender::GetCongestionWindow() const { + if (mode_ == PROBE_RTT) { + return ProbeRttCongestionWindow(); + } + + if (InRecovery()) { + return std::min(congestion_window_, recovery_window_); + } + + return congestion_window_; +} + +QuicByteCount BbrSender::GetSlowStartThreshold() const { return 0; } + +bool BbrSender::InRecovery() const { + return recovery_state_ != NOT_IN_RECOVERY; +} + +void BbrSender::SetFromConfig(const QuicConfig& config, + Perspective perspective) { + if (config.HasClientRequestedIndependentOption(k1RTT, perspective)) { + num_startup_rtts_ = 1; + } + if (config.HasClientRequestedIndependentOption(k2RTT, perspective)) { + num_startup_rtts_ = 2; + } + if (config.HasClientRequestedIndependentOption(kBBR3, perspective)) { + drain_to_target_ = true; + } + if (config.HasClientRequestedIndependentOption(kBWM3, perspective)) { + bytes_lost_multiplier_while_detecting_overshooting_ = 3; + } + if (config.HasClientRequestedIndependentOption(kBWM4, perspective)) { + bytes_lost_multiplier_while_detecting_overshooting_ = 4; + } + if (config.HasClientRequestedIndependentOption(kBBR4, perspective)) { + sampler_.SetMaxAckHeightTrackerWindowLength(2 * kBandwidthWindowSize); + } + if (config.HasClientRequestedIndependentOption(kBBR5, perspective)) { + sampler_.SetMaxAckHeightTrackerWindowLength(4 * kBandwidthWindowSize); + } + if (config.HasClientRequestedIndependentOption(kBBQ1, perspective)) { + set_high_gain(kDerivedHighGain); + set_high_cwnd_gain(kDerivedHighGain); + set_drain_gain(1.0 / kDerivedHighCWNDGain); + } + if (config.HasClientRequestedIndependentOption(kBBQ3, perspective)) { + enable_ack_aggregation_during_startup_ = true; + } + if (config.HasClientRequestedIndependentOption(kBBQ5, perspective)) { + expire_ack_aggregation_in_startup_ = true; + } + if (config.HasClientRequestedIndependentOption(kMIN1, perspective)) { + min_congestion_window_ = kMaxSegmentSize; + } + if (config.HasClientRequestedIndependentOption(kICW1, perspective)) { + max_congestion_window_with_network_parameters_adjusted_ = + 100 * kDefaultTCPMSS; + } + if (config.HasClientRequestedIndependentOption(kDTOS, perspective)) { + detect_overshooting_ = true; + // DTOS would allow pacing rate drop to IW 10 / min_rtt if overshooting is + // detected. + cwnd_to_calculate_min_pacing_rate_ = + std::min(initial_congestion_window_, 10 * kDefaultTCPMSS); + } + + ApplyConnectionOptions(config.ClientRequestedIndependentOptions(perspective)); +} + +void BbrSender::ApplyConnectionOptions( + const QuicTagVector& connection_options) { + if (ContainsQuicTag(connection_options, kBSAO)) { + sampler_.EnableOverestimateAvoidance(); + } + if (ContainsQuicTag(connection_options, kBBRA)) { + sampler_.SetStartNewAggregationEpochAfterFullRound(true); + } + if (ContainsQuicTag(connection_options, kBBRB)) { + sampler_.SetLimitMaxAckHeightTrackerBySendRate(true); + } +} + +void BbrSender::AdjustNetworkParameters(const NetworkParams& params) { + const QuicBandwidth& bandwidth = params.bandwidth; + const QuicTime::Delta& rtt = params.rtt; + + if (!rtt.IsZero() && (min_rtt_ > rtt || min_rtt_.IsZero())) { + min_rtt_ = rtt; + } + + if (mode_ == STARTUP) { + if (bandwidth.IsZero()) { + // Ignore bad bandwidth samples. + return; + } + + auto cwnd_bootstrapping_rtt = GetMinRtt(); + if (params.max_initial_congestion_window > 0) { + max_congestion_window_with_network_parameters_adjusted_ = + params.max_initial_congestion_window * kDefaultTCPMSS; + } + const QuicByteCount new_cwnd = std::max( + kMinInitialCongestionWindow * kDefaultTCPMSS, + std::min(max_congestion_window_with_network_parameters_adjusted_, + bandwidth * cwnd_bootstrapping_rtt)); + + stats_->cwnd_bootstrapping_rtt_us = cwnd_bootstrapping_rtt.ToMicroseconds(); + if (!rtt_stats_->smoothed_rtt().IsZero()) { + QUIC_CODE_COUNT(quic_smoothed_rtt_available); + } else if (rtt_stats_->initial_rtt() != + QuicTime::Delta::FromMilliseconds(kInitialRttMs)) { + QUIC_CODE_COUNT(quic_client_initial_rtt_available); + } else { + QUIC_CODE_COUNT(quic_default_initial_rtt); + } + if (new_cwnd < congestion_window_ && !params.allow_cwnd_to_decrease) { + // Only decrease cwnd if allow_cwnd_to_decrease is true. + return; + } + if (GetQuicReloadableFlag(quic_conservative_cwnd_and_pacing_gains)) { + // Decreases cwnd gain and pacing gain. Please note, if pacing_rate_ has + // been calculated, it cannot decrease in STARTUP phase. + QUIC_RELOADABLE_FLAG_COUNT(quic_conservative_cwnd_and_pacing_gains); + set_high_gain(kDerivedHighCWNDGain); + set_high_cwnd_gain(kDerivedHighCWNDGain); + } + congestion_window_ = new_cwnd; + + // Pace at the rate of new_cwnd / RTT. + QuicBandwidth new_pacing_rate = + QuicBandwidth::FromBytesAndTimeDelta(congestion_window_, GetMinRtt()); + pacing_rate_ = std::max(pacing_rate_, new_pacing_rate); + detect_overshooting_ = true; + } +} + +void BbrSender::OnCongestionEvent(bool /*rtt_updated*/, + QuicByteCount prior_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount /*num_ect*/, + QuicPacketCount /*num_ce*/) { + const QuicByteCount total_bytes_acked_before = sampler_.total_bytes_acked(); + const QuicByteCount total_bytes_lost_before = sampler_.total_bytes_lost(); + + bool is_round_start = false; + bool min_rtt_expired = false; + QuicByteCount excess_acked = 0; + QuicByteCount bytes_lost = 0; + + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + SendTimeState last_packet_send_state; + + if (!acked_packets.empty()) { + QuicPacketNumber last_acked_packet = acked_packets.rbegin()->packet_number; + is_round_start = UpdateRoundTripCounter(last_acked_packet); + UpdateRecoveryState(last_acked_packet, !lost_packets.empty(), + is_round_start); + } + + BandwidthSamplerInterface::CongestionEventSample sample = + sampler_.OnCongestionEvent(event_time, acked_packets, lost_packets, + max_bandwidth_.GetBest(), + QuicBandwidth::Infinite(), round_trip_count_); + if (sample.last_packet_send_state.is_valid) { + last_sample_is_app_limited_ = sample.last_packet_send_state.is_app_limited; + has_non_app_limited_sample_ |= !last_sample_is_app_limited_; + if (stats_) { + stats_->has_non_app_limited_sample = has_non_app_limited_sample_; + } + } + // Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all + // packets in |acked_packets| did not generate valid samples. (e.g. ack of + // ack-only packets). In both cases, sampler_.total_bytes_acked() will not + // change. + if (total_bytes_acked_before != sampler_.total_bytes_acked()) { + QUIC_LOG_IF(WARNING, sample.sample_max_bandwidth.IsZero()) + << sampler_.total_bytes_acked() - total_bytes_acked_before + << " bytes from " << acked_packets.size() + << " packets have been acked, but sample_max_bandwidth is zero."; + if (!sample.sample_is_app_limited || + sample.sample_max_bandwidth > max_bandwidth_.GetBest()) { + max_bandwidth_.Update(sample.sample_max_bandwidth, round_trip_count_); + } + } + + if (!sample.sample_rtt.IsInfinite()) { + min_rtt_expired = MaybeUpdateMinRtt(event_time, sample.sample_rtt); + } + bytes_lost = sampler_.total_bytes_lost() - total_bytes_lost_before; + if (mode_ == STARTUP) { + if (stats_) { + stats_->slowstart_packets_lost += lost_packets.size(); + stats_->slowstart_bytes_lost += bytes_lost; + } + } + excess_acked = sample.extra_acked; + last_packet_send_state = sample.last_packet_send_state; + + if (!lost_packets.empty()) { + ++num_loss_events_in_round_; + bytes_lost_in_round_ += bytes_lost; + } + + // Handle logic specific to PROBE_BW mode. + if (mode_ == PROBE_BW) { + UpdateGainCyclePhase(event_time, prior_in_flight, !lost_packets.empty()); + } + + // Handle logic specific to STARTUP and DRAIN modes. + if (is_round_start && !is_at_full_bandwidth_) { + CheckIfFullBandwidthReached(last_packet_send_state); + } + MaybeExitStartupOrDrain(event_time); + + // Handle logic specific to PROBE_RTT. + MaybeEnterOrExitProbeRtt(event_time, is_round_start, min_rtt_expired); + + // Calculate number of packets acked and lost. + QuicByteCount bytes_acked = + sampler_.total_bytes_acked() - total_bytes_acked_before; + + // After the model is updated, recalculate the pacing rate and congestion + // window. + CalculatePacingRate(bytes_lost); + CalculateCongestionWindow(bytes_acked, excess_acked); + CalculateRecoveryWindow(bytes_acked, bytes_lost); + + // Cleanup internal state. + sampler_.RemoveObsoletePackets(unacked_packets_->GetLeastUnacked()); + if (is_round_start) { + num_loss_events_in_round_ = 0; + bytes_lost_in_round_ = 0; + } +} + +CongestionControlType BbrSender::GetCongestionControlType() const { + return kBBR; +} + +QuicTime::Delta BbrSender::GetMinRtt() const { + if (!min_rtt_.IsZero()) { + return min_rtt_; + } + // min_rtt could be available if the handshake packet gets neutered then + // gets acknowledged. This could only happen for QUIC crypto where we do not + // drop keys. + return rtt_stats_->MinOrInitialRtt(); +} + +QuicByteCount BbrSender::GetTargetCongestionWindow(float gain) const { + QuicByteCount bdp = GetMinRtt() * BandwidthEstimate(); + QuicByteCount congestion_window = gain * bdp; + + // BDP estimate will be zero if no bandwidth samples are available yet. + if (congestion_window == 0) { + congestion_window = gain * initial_congestion_window_; + } + + return std::max(congestion_window, min_congestion_window_); +} + +QuicByteCount BbrSender::ProbeRttCongestionWindow() const { + return min_congestion_window_; +} + +void BbrSender::EnterStartupMode(QuicTime now) { + if (stats_) { + ++stats_->slowstart_count; + stats_->slowstart_duration.Start(now); + } + mode_ = STARTUP; + pacing_gain_ = high_gain_; + congestion_window_gain_ = high_cwnd_gain_; +} + +void BbrSender::EnterProbeBandwidthMode(QuicTime now) { + mode_ = PROBE_BW; + congestion_window_gain_ = congestion_window_gain_constant_; + + // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is + // excluded because in that case increased gain and decreased gain would not + // follow each other. + cycle_current_offset_ = random_->RandUint64() % (kGainCycleLength - 1); + if (cycle_current_offset_ >= 1) { + cycle_current_offset_ += 1; + } + + last_cycle_start_ = now; + pacing_gain_ = kPacingGain[cycle_current_offset_]; +} + +bool BbrSender::UpdateRoundTripCounter(QuicPacketNumber last_acked_packet) { + if (!current_round_trip_end_.IsInitialized() || + last_acked_packet > current_round_trip_end_) { + round_trip_count_++; + current_round_trip_end_ = last_sent_packet_; + if (stats_ && InSlowStart()) { + ++stats_->slowstart_num_rtts; + } + return true; + } + + return false; +} + +bool BbrSender::MaybeUpdateMinRtt(QuicTime now, + QuicTime::Delta sample_min_rtt) { + // Do not expire min_rtt if none was ever available. + bool min_rtt_expired = + !min_rtt_.IsZero() && (now > (min_rtt_timestamp_ + kMinRttExpiry)); + + if (min_rtt_expired || sample_min_rtt < min_rtt_ || min_rtt_.IsZero()) { + QUIC_DVLOG(2) << "Min RTT updated, old value: " << min_rtt_ + << ", new value: " << sample_min_rtt + << ", current time: " << now.ToDebuggingValue(); + + min_rtt_ = sample_min_rtt; + min_rtt_timestamp_ = now; + } + QUICHE_DCHECK(!min_rtt_.IsZero()); + + return min_rtt_expired; +} + +void BbrSender::UpdateGainCyclePhase(QuicTime now, + QuicByteCount prior_in_flight, + bool has_losses) { + const QuicByteCount bytes_in_flight = unacked_packets_->bytes_in_flight(); + // In most cases, the cycle is advanced after an RTT passes. + bool should_advance_gain_cycling = now - last_cycle_start_ > GetMinRtt(); + + // If the pacing gain is above 1.0, the connection is trying to probe the + // bandwidth by increasing the number of bytes in flight to at least + // pacing_gain * BDP. Make sure that it actually reaches the target, as long + // as there are no losses suggesting that the buffers are not able to hold + // that much. + if (pacing_gain_ > 1.0 && !has_losses && + prior_in_flight < GetTargetCongestionWindow(pacing_gain_)) { + should_advance_gain_cycling = false; + } + + // If pacing gain is below 1.0, the connection is trying to drain the extra + // queue which could have been incurred by probing prior to it. If the number + // of bytes in flight falls down to the estimated BDP value earlier, conclude + // that the queue has been successfully drained and exit this cycle early. + if (pacing_gain_ < 1.0 && bytes_in_flight <= GetTargetCongestionWindow(1)) { + should_advance_gain_cycling = true; + } + + if (should_advance_gain_cycling) { + cycle_current_offset_ = (cycle_current_offset_ + 1) % kGainCycleLength; + if (cycle_current_offset_ == 0) { + ++stats_->bbr_num_cycles; + } + last_cycle_start_ = now; + // Stay in low gain mode until the target BDP is hit. + // Low gain mode will be exited immediately when the target BDP is achieved. + if (drain_to_target_ && pacing_gain_ < 1 && + kPacingGain[cycle_current_offset_] == 1 && + bytes_in_flight > GetTargetCongestionWindow(1)) { + return; + } + pacing_gain_ = kPacingGain[cycle_current_offset_]; + } +} + +void BbrSender::CheckIfFullBandwidthReached( + const SendTimeState& last_packet_send_state) { + if (last_sample_is_app_limited_) { + return; + } + + QuicBandwidth target = bandwidth_at_last_round_ * kStartupGrowthTarget; + if (BandwidthEstimate() >= target) { + bandwidth_at_last_round_ = BandwidthEstimate(); + rounds_without_bandwidth_gain_ = 0; + if (expire_ack_aggregation_in_startup_) { + // Expire old excess delivery measurements now that bandwidth increased. + sampler_.ResetMaxAckHeightTracker(0, round_trip_count_); + } + return; + } + + rounds_without_bandwidth_gain_++; + if ((rounds_without_bandwidth_gain_ >= num_startup_rtts_) || + ShouldExitStartupDueToLoss(last_packet_send_state)) { + QUICHE_DCHECK(has_non_app_limited_sample_); + is_at_full_bandwidth_ = true; + } +} + +void BbrSender::MaybeExitStartupOrDrain(QuicTime now) { + if (mode_ == STARTUP && is_at_full_bandwidth_) { + OnExitStartup(now); + mode_ = DRAIN; + pacing_gain_ = drain_gain_; + congestion_window_gain_ = high_cwnd_gain_; + } + if (mode_ == DRAIN && + unacked_packets_->bytes_in_flight() <= GetTargetCongestionWindow(1)) { + EnterProbeBandwidthMode(now); + } +} + +void BbrSender::OnExitStartup(QuicTime now) { + QUICHE_DCHECK_EQ(mode_, STARTUP); + if (stats_) { + stats_->slowstart_duration.Stop(now); + } +} + +bool BbrSender::ShouldExitStartupDueToLoss( + const SendTimeState& last_packet_send_state) const { + if (num_loss_events_in_round_ < + GetQuicFlag(quic_bbr2_default_startup_full_loss_count) || + !last_packet_send_state.is_valid) { + return false; + } + + const QuicByteCount inflight_at_send = last_packet_send_state.bytes_in_flight; + + if (inflight_at_send > 0 && bytes_lost_in_round_ > 0) { + if (bytes_lost_in_round_ > + inflight_at_send * GetQuicFlag(quic_bbr2_default_loss_threshold)) { + stats_->bbr_exit_startup_due_to_loss = true; + return true; + } + return false; + } + + return false; +} + +void BbrSender::MaybeEnterOrExitProbeRtt(QuicTime now, bool is_round_start, + bool min_rtt_expired) { + if (min_rtt_expired && !exiting_quiescence_ && mode_ != PROBE_RTT) { + if (InSlowStart()) { + OnExitStartup(now); + } + mode_ = PROBE_RTT; + pacing_gain_ = 1; + // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| + // is at the target small value. + exit_probe_rtt_at_ = QuicTime::Zero(); + } + + if (mode_ == PROBE_RTT) { + sampler_.OnAppLimited(); + + if (exit_probe_rtt_at_ == QuicTime::Zero()) { + // If the window has reached the appropriate size, schedule exiting + // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but + // we allow an extra packet since QUIC checks CWND before sending a + // packet. + if (unacked_packets_->bytes_in_flight() < + ProbeRttCongestionWindow() + kMaxOutgoingPacketSize) { + exit_probe_rtt_at_ = now + kProbeRttTime; + probe_rtt_round_passed_ = false; + } + } else { + if (is_round_start) { + probe_rtt_round_passed_ = true; + } + if (now >= exit_probe_rtt_at_ && probe_rtt_round_passed_) { + min_rtt_timestamp_ = now; + if (!is_at_full_bandwidth_) { + EnterStartupMode(now); + } else { + EnterProbeBandwidthMode(now); + } + } + } + } + + exiting_quiescence_ = false; +} + +void BbrSender::UpdateRecoveryState(QuicPacketNumber last_acked_packet, + bool has_losses, bool is_round_start) { + // Disable recovery in startup, if loss-based exit is enabled. + if (!is_at_full_bandwidth_) { + return; + } + + // Exit recovery when there are no losses for a round. + if (has_losses) { + end_recovery_at_ = last_sent_packet_; + } + + switch (recovery_state_) { + case NOT_IN_RECOVERY: + // Enter conservation on the first loss. + if (has_losses) { + recovery_state_ = CONSERVATION; + // This will cause the |recovery_window_| to be set to the correct + // value in CalculateRecoveryWindow(). + recovery_window_ = 0; + // Since the conservation phase is meant to be lasting for a whole + // round, extend the current round as if it were started right now. + current_round_trip_end_ = last_sent_packet_; + } + break; + + case CONSERVATION: + if (is_round_start) { + recovery_state_ = GROWTH; + } + ABSL_FALLTHROUGH_INTENDED; + + case GROWTH: + // Exit recovery if appropriate. + if (!has_losses && last_acked_packet > end_recovery_at_) { + recovery_state_ = NOT_IN_RECOVERY; + } + + break; + } +} + +void BbrSender::CalculatePacingRate(QuicByteCount bytes_lost) { + if (BandwidthEstimate().IsZero()) { + return; + } + + QuicBandwidth target_rate = pacing_gain_ * BandwidthEstimate(); + if (is_at_full_bandwidth_) { + pacing_rate_ = target_rate; + return; + } + + // Pace at the rate of initial_window / RTT as soon as RTT measurements are + // available. + if (pacing_rate_.IsZero() && !rtt_stats_->min_rtt().IsZero()) { + pacing_rate_ = QuicBandwidth::FromBytesAndTimeDelta( + initial_congestion_window_, rtt_stats_->min_rtt()); + return; + } + + if (detect_overshooting_) { + bytes_lost_while_detecting_overshooting_ += bytes_lost; + // Check for overshooting with network parameters adjusted when pacing rate + // > target_rate and loss has been detected. + if (pacing_rate_ > target_rate && + bytes_lost_while_detecting_overshooting_ > 0) { + if (has_non_app_limited_sample_ || + bytes_lost_while_detecting_overshooting_ * + bytes_lost_multiplier_while_detecting_overshooting_ > + initial_congestion_window_) { + // We are fairly sure overshoot happens if 1) there is at least one + // non app-limited bw sample or 2) half of IW gets lost. Slow pacing + // rate. + pacing_rate_ = std::max( + target_rate, QuicBandwidth::FromBytesAndTimeDelta( + cwnd_to_calculate_min_pacing_rate_, GetMinRtt())); + if (stats_) { + stats_->overshooting_detected_with_network_parameters_adjusted = true; + } + bytes_lost_while_detecting_overshooting_ = 0; + detect_overshooting_ = false; + } + } + } + + // Do not decrease the pacing rate during startup. + pacing_rate_ = std::max(pacing_rate_, target_rate); +} + +void BbrSender::CalculateCongestionWindow(QuicByteCount bytes_acked, + QuicByteCount excess_acked) { + if (mode_ == PROBE_RTT) { + return; + } + + QuicByteCount target_window = + GetTargetCongestionWindow(congestion_window_gain_); + if (is_at_full_bandwidth_) { + // Add the max recently measured ack aggregation to CWND. + target_window += sampler_.max_ack_height(); + } else if (enable_ack_aggregation_during_startup_) { + // Add the most recent excess acked. Because CWND never decreases in + // STARTUP, this will automatically create a very localized max filter. + target_window += excess_acked; + } + + // Instead of immediately setting the target CWND as the new one, BBR grows + // the CWND towards |target_window| by only increasing it |bytes_acked| at a + // time. + if (is_at_full_bandwidth_) { + congestion_window_ = + std::min(target_window, congestion_window_ + bytes_acked); + } else if (congestion_window_ < target_window || + sampler_.total_bytes_acked() < initial_congestion_window_) { + // If the connection is not yet out of startup phase, do not decrease the + // window. + congestion_window_ = congestion_window_ + bytes_acked; + } + + // Enforce the limits on the congestion window. + congestion_window_ = std::max(congestion_window_, min_congestion_window_); + congestion_window_ = std::min(congestion_window_, max_congestion_window_); +} + +void BbrSender::CalculateRecoveryWindow(QuicByteCount bytes_acked, + QuicByteCount bytes_lost) { + if (recovery_state_ == NOT_IN_RECOVERY) { + return; + } + + // Set up the initial recovery window. + if (recovery_window_ == 0) { + recovery_window_ = unacked_packets_->bytes_in_flight() + bytes_acked; + recovery_window_ = std::max(min_congestion_window_, recovery_window_); + return; + } + + // Remove losses from the recovery window, while accounting for a potential + // integer underflow. + recovery_window_ = recovery_window_ >= bytes_lost + ? recovery_window_ - bytes_lost + : kMaxSegmentSize; + + // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, + // release additional |bytes_acked| to achieve a slow-start-like behavior. + if (recovery_state_ == GROWTH) { + recovery_window_ += bytes_acked; + } + + // Always allow sending at least |bytes_acked| in response. + recovery_window_ = std::max( + recovery_window_, unacked_packets_->bytes_in_flight() + bytes_acked); + recovery_window_ = std::max(min_congestion_window_, recovery_window_); +} + +std::string BbrSender::GetDebugState() const { + std::ostringstream stream; + stream << ExportDebugState(); + return stream.str(); +} + +void BbrSender::OnApplicationLimited(QuicByteCount bytes_in_flight) { + if (bytes_in_flight >= GetCongestionWindow()) { + return; + } + + sampler_.OnAppLimited(); + QUIC_DVLOG(2) << "Becoming application limited. Last sent packet: " + << last_sent_packet_ << ", CWND: " << GetCongestionWindow(); +} + +void BbrSender::PopulateConnectionStats(QuicConnectionStats* stats) const { + stats->num_ack_aggregation_epochs = sampler_.num_ack_aggregation_epochs(); +} + +BbrSender::DebugState BbrSender::ExportDebugState() const { + return DebugState(*this); +} + +static std::string ModeToString(BbrSender::Mode mode) { + switch (mode) { + case BbrSender::STARTUP: + return "STARTUP"; + case BbrSender::DRAIN: + return "DRAIN"; + case BbrSender::PROBE_BW: + return "PROBE_BW"; + case BbrSender::PROBE_RTT: + return "PROBE_RTT"; + } + return "???"; +} + +std::ostream& operator<<(std::ostream& os, const BbrSender::Mode& mode) { + os << ModeToString(mode); + return os; +} + +std::ostream& operator<<(std::ostream& os, const BbrSender::DebugState& state) { + os << "Mode: " << ModeToString(state.mode) << std::endl; + os << "Maximum bandwidth: " << state.max_bandwidth << std::endl; + os << "Round trip counter: " << state.round_trip_count << std::endl; + os << "Gain cycle index: " << static_cast(state.gain_cycle_index) + << std::endl; + os << "Congestion window: " << state.congestion_window << " bytes" + << std::endl; + + if (state.mode == BbrSender::STARTUP) { + os << "(startup) Bandwidth at last round: " << state.bandwidth_at_last_round + << std::endl; + os << "(startup) Rounds without gain: " + << state.rounds_without_bandwidth_gain << std::endl; + } + + os << "Minimum RTT: " << state.min_rtt << std::endl; + os << "Minimum RTT timestamp: " << state.min_rtt_timestamp.ToDebuggingValue() + << std::endl; + + os << "Last sample is app-limited: " + << (state.last_sample_is_app_limited ? "yes" : "no"); + + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/bbr_sender.h b/quiche/quic/core/congestion_control/bbr_sender.h new file mode 100644 index 000000000000..150400427cd2 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr_sender.h @@ -0,0 +1,391 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// BBR (Bottleneck Bandwidth and RTT) congestion control algorithm. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR_SENDER_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR_SENDER_H_ + +#include +#include +#include + +#include "quiche/quic/core/congestion_control/bandwidth_sampler.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/congestion_control/windowed_filter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_unacked_packet_map.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +class RttStats; + +// BbrSender implements BBR congestion control algorithm. BBR aims to estimate +// the current available Bottleneck Bandwidth and RTT (hence the name), and +// regulates the pacing rate and the size of the congestion window based on +// those signals. +// +// BBR relies on pacing in order to function properly. Do not use BBR when +// pacing is disabled. +// +// TODO(vasilvv): implement traffic policer (long-term sampling) mode. +class QUIC_EXPORT_PRIVATE BbrSender : public SendAlgorithmInterface { + public: + enum Mode { + // Startup phase of the connection. + STARTUP, + // After achieving the highest possible bandwidth during the startup, lower + // the pacing rate in order to drain the queue. + DRAIN, + // Cruising mode. + PROBE_BW, + // Temporarily slow down sending in order to empty the buffer and measure + // the real minimum RTT. + PROBE_RTT, + }; + + // Indicates how the congestion control limits the amount of bytes in flight. + enum RecoveryState { + // Do not limit. + NOT_IN_RECOVERY, + // Allow an extra outstanding byte for each byte acknowledged. + CONSERVATION, + // Allow two extra outstanding bytes for each byte acknowledged (slow + // start). + GROWTH + }; + + // Debug state can be exported in order to troubleshoot potential congestion + // control issues. + struct QUIC_EXPORT_PRIVATE DebugState { + explicit DebugState(const BbrSender& sender); + DebugState(const DebugState& state); + + Mode mode; + QuicBandwidth max_bandwidth; + QuicRoundTripCount round_trip_count; + int gain_cycle_index; + QuicByteCount congestion_window; + + bool is_at_full_bandwidth; + QuicBandwidth bandwidth_at_last_round; + QuicRoundTripCount rounds_without_bandwidth_gain; + + QuicTime::Delta min_rtt; + QuicTime min_rtt_timestamp; + + RecoveryState recovery_state; + QuicByteCount recovery_window; + + bool last_sample_is_app_limited; + QuicPacketNumber end_of_app_limited_phase; + }; + + BbrSender(QuicTime now, const RttStats* rtt_stats, + const QuicUnackedPacketMap* unacked_packets, + QuicPacketCount initial_tcp_congestion_window, + QuicPacketCount max_tcp_congestion_window, QuicRandom* random, + QuicConnectionStats* stats); + BbrSender(const BbrSender&) = delete; + BbrSender& operator=(const BbrSender&) = delete; + ~BbrSender() override; + + // Start implementation of SendAlgorithmInterface. + bool InSlowStart() const override; + bool InRecovery() const override; + + void SetFromConfig(const QuicConfig& config, + Perspective perspective) override; + void ApplyConnectionOptions(const QuicTagVector& connection_options) override; + + void AdjustNetworkParameters(const NetworkParams& params) override; + void SetInitialCongestionWindowInPackets( + QuicPacketCount congestion_window) override; + void OnCongestionEvent(bool rtt_updated, QuicByteCount prior_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount num_ect, + QuicPacketCount num_ce) override; + void OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData is_retransmittable) override; + void OnPacketNeutered(QuicPacketNumber packet_number) override; + void OnRetransmissionTimeout(bool /*packets_retransmitted*/) override {} + void OnConnectionMigration() override {} + bool CanSend(QuicByteCount bytes_in_flight) override; + QuicBandwidth PacingRate(QuicByteCount bytes_in_flight) const override; + QuicBandwidth BandwidthEstimate() const override; + bool HasGoodBandwidthEstimateForResumption() const override { + return has_non_app_limited_sample(); + } + QuicByteCount GetCongestionWindow() const override; + QuicByteCount GetSlowStartThreshold() const override; + CongestionControlType GetCongestionControlType() const override; + std::string GetDebugState() const override; + void OnApplicationLimited(QuicByteCount bytes_in_flight) override; + void PopulateConnectionStats(QuicConnectionStats* stats) const override; + bool SupportsECT0() const override { return false; } + bool SupportsECT1() const override { return false; } + // End implementation of SendAlgorithmInterface. + + // Gets the number of RTTs BBR remains in STARTUP phase. + QuicRoundTripCount num_startup_rtts() const { return num_startup_rtts_; } + bool has_non_app_limited_sample() const { + return has_non_app_limited_sample_; + } + + // Sets the pacing gain used in STARTUP. Must be greater than 1. + void set_high_gain(float high_gain) { + QUICHE_DCHECK_LT(1.0f, high_gain); + high_gain_ = high_gain; + if (mode_ == STARTUP) { + pacing_gain_ = high_gain; + } + } + + // Sets the CWND gain used in STARTUP. Must be greater than 1. + void set_high_cwnd_gain(float high_cwnd_gain) { + QUICHE_DCHECK_LT(1.0f, high_cwnd_gain); + high_cwnd_gain_ = high_cwnd_gain; + if (mode_ == STARTUP) { + congestion_window_gain_ = high_cwnd_gain; + } + } + + // Sets the gain used in DRAIN. Must be less than 1. + void set_drain_gain(float drain_gain) { + QUICHE_DCHECK_GT(1.0f, drain_gain); + drain_gain_ = drain_gain; + } + + // Returns the current estimate of the RTT of the connection. Outside of the + // edge cases, this is minimum RTT. + QuicTime::Delta GetMinRtt() const; + + DebugState ExportDebugState() const; + + private: + // For switching send algorithm mid connection. + friend class Bbr2Sender; + + using MaxBandwidthFilter = + WindowedFilter, + QuicRoundTripCount, QuicRoundTripCount>; + + using MaxAckHeightFilter = + WindowedFilter, + QuicRoundTripCount, QuicRoundTripCount>; + + // Computes the target congestion window using the specified gain. + QuicByteCount GetTargetCongestionWindow(float gain) const; + // The target congestion window during PROBE_RTT. + QuicByteCount ProbeRttCongestionWindow() const; + bool MaybeUpdateMinRtt(QuicTime now, QuicTime::Delta sample_min_rtt); + + // Enters the STARTUP mode. + void EnterStartupMode(QuicTime now); + // Enters the PROBE_BW mode. + void EnterProbeBandwidthMode(QuicTime now); + + // Updates the round-trip counter if a round-trip has passed. Returns true if + // the counter has been advanced. + bool UpdateRoundTripCounter(QuicPacketNumber last_acked_packet); + + // Updates the current gain used in PROBE_BW mode. + void UpdateGainCyclePhase(QuicTime now, QuicByteCount prior_in_flight, + bool has_losses); + // Tracks for how many round-trips the bandwidth has not increased + // significantly. + void CheckIfFullBandwidthReached(const SendTimeState& last_packet_send_state); + // Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if + // appropriate. + void MaybeExitStartupOrDrain(QuicTime now); + // Decides whether to enter or exit PROBE_RTT. + void MaybeEnterOrExitProbeRtt(QuicTime now, bool is_round_start, + bool min_rtt_expired); + // Determines whether BBR needs to enter, exit or advance state of the + // recovery. + void UpdateRecoveryState(QuicPacketNumber last_acked_packet, bool has_losses, + bool is_round_start); + + // Updates the ack aggregation max filter in bytes. + // Returns the most recent addition to the filter, or |newly_acked_bytes| if + // nothing was fed in to the filter. + QuicByteCount UpdateAckAggregationBytes(QuicTime ack_time, + QuicByteCount newly_acked_bytes); + + // Determines the appropriate pacing rate for the connection. + void CalculatePacingRate(QuicByteCount bytes_lost); + // Determines the appropriate congestion window for the connection. + void CalculateCongestionWindow(QuicByteCount bytes_acked, + QuicByteCount excess_acked); + // Determines the appropriate window that constrains the in-flight during + // recovery. + void CalculateRecoveryWindow(QuicByteCount bytes_acked, + QuicByteCount bytes_lost); + + // Called right before exiting STARTUP. + void OnExitStartup(QuicTime now); + + // Return whether we should exit STARTUP due to excessive loss. + bool ShouldExitStartupDueToLoss( + const SendTimeState& last_packet_send_state) const; + + const RttStats* rtt_stats_; + const QuicUnackedPacketMap* unacked_packets_; + QuicRandom* random_; + QuicConnectionStats* stats_; + + Mode mode_; + + // Bandwidth sampler provides BBR with the bandwidth measurements at + // individual points. + BandwidthSampler sampler_; + + // The number of the round trips that have occurred during the connection. + QuicRoundTripCount round_trip_count_; + + // The packet number of the most recently sent packet. + QuicPacketNumber last_sent_packet_; + // Acknowledgement of any packet after |current_round_trip_end_| will cause + // the round trip counter to advance. + QuicPacketNumber current_round_trip_end_; + + // Number of congestion events with some losses, in the current round. + int64_t num_loss_events_in_round_; + + // Number of total bytes lost in the current round. + QuicByteCount bytes_lost_in_round_; + + // The filter that tracks the maximum bandwidth over the multiple recent + // round-trips. + MaxBandwidthFilter max_bandwidth_; + + // Minimum RTT estimate. Automatically expires within 10 seconds (and + // triggers PROBE_RTT mode) if no new value is sampled during that period. + QuicTime::Delta min_rtt_; + // The time at which the current value of |min_rtt_| was assigned. + QuicTime min_rtt_timestamp_; + + // The maximum allowed number of bytes in flight. + QuicByteCount congestion_window_; + + // The initial value of the |congestion_window_|. + QuicByteCount initial_congestion_window_; + + // The largest value the |congestion_window_| can achieve. + QuicByteCount max_congestion_window_; + + // The smallest value the |congestion_window_| can achieve. + QuicByteCount min_congestion_window_; + + // The pacing gain applied during the STARTUP phase. + float high_gain_; + + // The CWND gain applied during the STARTUP phase. + float high_cwnd_gain_; + + // The pacing gain applied during the DRAIN phase. + float drain_gain_; + + // The current pacing rate of the connection. + QuicBandwidth pacing_rate_; + + // The gain currently applied to the pacing rate. + float pacing_gain_; + // The gain currently applied to the congestion window. + float congestion_window_gain_; + + // The gain used for the congestion window during PROBE_BW. Latched from + // quic_bbr_cwnd_gain flag. + const float congestion_window_gain_constant_; + // The number of RTTs to stay in STARTUP mode. Defaults to 3. + QuicRoundTripCount num_startup_rtts_; + + // Number of round-trips in PROBE_BW mode, used for determining the current + // pacing gain cycle. + int cycle_current_offset_; + // The time at which the last pacing gain cycle was started. + QuicTime last_cycle_start_; + + // Indicates whether the connection has reached the full bandwidth mode. + bool is_at_full_bandwidth_; + // Number of rounds during which there was no significant bandwidth increase. + QuicRoundTripCount rounds_without_bandwidth_gain_; + // The bandwidth compared to which the increase is measured. + QuicBandwidth bandwidth_at_last_round_; + + // Set to true upon exiting quiescence. + bool exiting_quiescence_; + + // Time at which PROBE_RTT has to be exited. Setting it to zero indicates + // that the time is yet unknown as the number of packets in flight has not + // reached the required value. + QuicTime exit_probe_rtt_at_; + // Indicates whether a round-trip has passed since PROBE_RTT became active. + bool probe_rtt_round_passed_; + + // Indicates whether the most recent bandwidth sample was marked as + // app-limited. + bool last_sample_is_app_limited_; + // Indicates whether any non app-limited samples have been recorded. + bool has_non_app_limited_sample_; + + // Current state of recovery. + RecoveryState recovery_state_; + // Receiving acknowledgement of a packet after |end_recovery_at_| will cause + // BBR to exit the recovery mode. A value above zero indicates at least one + // loss has been detected, so it must not be set back to zero. + QuicPacketNumber end_recovery_at_; + // A window used to limit the number of bytes in flight during loss recovery. + QuicByteCount recovery_window_; + // If true, consider all samples in recovery app-limited. + bool is_app_limited_recovery_; + + // When true, pace at 1.5x and disable packet conservation in STARTUP. + bool slower_startup_; + // When true, disables packet conservation in STARTUP. + bool rate_based_startup_; + + // When true, add the most recent ack aggregation measurement during STARTUP. + bool enable_ack_aggregation_during_startup_; + // When true, expire the windowed ack aggregation values in STARTUP when + // bandwidth increases more than 25%. + bool expire_ack_aggregation_in_startup_; + + // If true, will not exit low gain mode until bytes_in_flight drops below BDP + // or it's time for high gain mode. + bool drain_to_target_; + + // If true, slow down pacing rate in STARTUP when overshooting is detected. + bool detect_overshooting_; + // Bytes lost while detect_overshooting_ is true. + QuicByteCount bytes_lost_while_detecting_overshooting_; + // Slow down pacing rate if + // bytes_lost_while_detecting_overshooting_ * + // bytes_lost_multiplier_while_detecting_overshooting_ > IW. + uint8_t bytes_lost_multiplier_while_detecting_overshooting_; + // When overshooting is detected, do not drop pacing_rate_ below this value / + // min_rtt. + QuicByteCount cwnd_to_calculate_min_pacing_rate_; + + // Max congestion window when adjusting network parameters. + QuicByteCount max_congestion_window_with_network_parameters_adjusted_; +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const BbrSender::Mode& mode); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const BbrSender::DebugState& state); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_BBR_SENDER_H_ diff --git a/quiche/quic/core/congestion_control/bbr_sender_test.cc b/quiche/quic/core/congestion_control/bbr_sender_test.cc new file mode 100644 index 000000000000..696e364560d4 --- /dev/null +++ b/quiche/quic/core/congestion_control/bbr_sender_test.cc @@ -0,0 +1,1323 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/bbr_sender.h" + +#include +#include +#include +#include + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/send_algorithm_test_result.pb.h" +#include "quiche/quic/test_tools/send_algorithm_test_utils.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/test_tools/simulator/switch.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +using testing::AllOf; +using testing::Ge; +using testing::Le; + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, quic_bbr_test_regression_mode, "", + "One of a) 'record' to record test result (one file per test), or " + "b) 'regress' to regress against recorded results, or " + "c) for non-regression mode."); + +namespace quic { +namespace test { + +// Use the initial CWND of 10, as 32 is too much for the test network. +const uint32_t kInitialCongestionWindowPackets = 10; +const uint32_t kDefaultWindowTCP = + kInitialCongestionWindowPackets * kDefaultTCPMSS; + +// Test network parameters. Here, the topology of the network is: +// +// BBR sender +// | +// | <-- local link (10 Mbps, 2 ms delay) +// | +// Network switch +// * <-- the bottleneck queue in the direction +// | of the receiver +// | +// | <-- test link (4 Mbps, 30 ms delay) +// | +// | +// Receiver +// +// The reason the bandwidths chosen are relatively low is the fact that the +// connection simulator uses QuicTime for its internal clock, and as such has +// the granularity of 1us, meaning that at bandwidth higher than 20 Mbps the +// packets can start to land on the same timestamp. +const QuicBandwidth kTestLinkBandwidth = + QuicBandwidth::FromKBitsPerSecond(4000); +const QuicBandwidth kLocalLinkBandwidth = + QuicBandwidth::FromKBitsPerSecond(10000); +const QuicTime::Delta kTestPropagationDelay = + QuicTime::Delta::FromMilliseconds(30); +const QuicTime::Delta kLocalPropagationDelay = + QuicTime::Delta::FromMilliseconds(2); +const QuicTime::Delta kTestTransferTime = + kTestLinkBandwidth.TransferTime(kMaxOutgoingPacketSize) + + kLocalLinkBandwidth.TransferTime(kMaxOutgoingPacketSize); +const QuicTime::Delta kTestRtt = + (kTestPropagationDelay + kLocalPropagationDelay + kTestTransferTime) * 2; +const QuicByteCount kTestBdp = kTestRtt * kTestLinkBandwidth; + +class BbrSenderTest : public QuicTest { + protected: + BbrSenderTest() + : simulator_(&random_), + bbr_sender_(&simulator_, "BBR sender", "Receiver", + Perspective::IS_CLIENT, + /*connection_id=*/TestConnectionId(42)), + competing_sender_(&simulator_, "Competing sender", "Competing receiver", + Perspective::IS_CLIENT, + /*connection_id=*/TestConnectionId(43)), + receiver_(&simulator_, "Receiver", "BBR sender", Perspective::IS_SERVER, + /*connection_id=*/TestConnectionId(42)), + competing_receiver_(&simulator_, "Competing receiver", + "Competing sender", Perspective::IS_SERVER, + /*connection_id=*/TestConnectionId(43)), + receiver_multiplexer_("Receiver multiplexer", + {&receiver_, &competing_receiver_}) { + rtt_stats_ = bbr_sender_.connection()->sent_packet_manager().GetRttStats(); + const int kTestMaxPacketSize = 1350; + bbr_sender_.connection()->SetMaxPacketLength(kTestMaxPacketSize); + sender_ = SetupBbrSender(&bbr_sender_); + SetConnectionOption(kBBRA); + clock_ = simulator_.GetClock(); + } + + void SetUp() override { + if (quiche::GetQuicheCommandLineFlag(FLAGS_quic_bbr_test_regression_mode) == + "regress") { + SendAlgorithmTestResult expected; + ASSERT_TRUE(LoadSendAlgorithmTestResult(&expected)); + random_seed_ = expected.random_seed(); + } else { + random_seed_ = QuicRandom::GetInstance()->RandUint64(); + } + random_.set_seed(random_seed_); + QUIC_LOG(INFO) << "BbrSenderTest simulator set up. Seed: " << random_seed_; + } + + ~BbrSenderTest() { + const std::string regression_mode = + quiche::GetQuicheCommandLineFlag(FLAGS_quic_bbr_test_regression_mode); + const QuicTime::Delta simulated_duration = clock_->Now() - QuicTime::Zero(); + if (regression_mode == "record") { + RecordSendAlgorithmTestResult(random_seed_, + simulated_duration.ToMicroseconds()); + } else if (regression_mode == "regress") { + CompareSendAlgorithmTestResult(simulated_duration.ToMicroseconds()); + } + } + + uint64_t random_seed_; + SimpleRandom random_; + simulator::Simulator simulator_; + simulator::QuicEndpoint bbr_sender_; + simulator::QuicEndpoint competing_sender_; + simulator::QuicEndpoint receiver_; + simulator::QuicEndpoint competing_receiver_; + simulator::QuicEndpointMultiplexer receiver_multiplexer_; + std::unique_ptr switch_; + std::unique_ptr bbr_sender_link_; + std::unique_ptr competing_sender_link_; + std::unique_ptr receiver_link_; + + // Owned by different components of the connection. + const QuicClock* clock_; + const RttStats* rtt_stats_; + BbrSender* sender_; + + // Enables BBR on |endpoint| and returns the associated BBR congestion + // controller. + BbrSender* SetupBbrSender(simulator::QuicEndpoint* endpoint) { + const RttStats* rtt_stats = + endpoint->connection()->sent_packet_manager().GetRttStats(); + // Ownership of the sender will be overtaken by the endpoint. + BbrSender* sender = new BbrSender( + endpoint->connection()->clock()->Now(), rtt_stats, + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(endpoint->connection())), + kInitialCongestionWindowPackets, + GetQuicFlag(quic_max_congestion_window), &random_, + QuicConnectionPeer::GetStats(endpoint->connection())); + QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); + endpoint->RecordTrace(); + return sender; + } + + // Creates a default setup, which is a network with a bottleneck between the + // receiver and the switch. The switch has the buffers four times larger than + // the bottleneck BDP, which should guarantee a lack of losses. + void CreateDefaultSetup() { + switch_ = std::make_unique(&simulator_, "Switch", 8, + 2 * kTestBdp); + bbr_sender_link_ = std::make_unique( + &bbr_sender_, switch_->port(1), kLocalLinkBandwidth, + kLocalPropagationDelay); + receiver_link_ = std::make_unique( + &receiver_, switch_->port(2), kTestLinkBandwidth, + kTestPropagationDelay); + } + + // Same as the default setup, except the buffer now is half of the BDP. + void CreateSmallBufferSetup() { + switch_ = std::make_unique(&simulator_, "Switch", 8, + 0.5 * kTestBdp); + bbr_sender_link_ = std::make_unique( + &bbr_sender_, switch_->port(1), kLocalLinkBandwidth, + kLocalPropagationDelay); + receiver_link_ = std::make_unique( + &receiver_, switch_->port(2), kTestLinkBandwidth, + kTestPropagationDelay); + } + + // Creates the variation of the default setup in which there is another sender + // that competes for the same bottleneck link. + void CreateCompetitionSetup() { + switch_ = std::make_unique(&simulator_, "Switch", 8, + 2 * kTestBdp); + + // Add a small offset to the competing link in order to avoid + // synchronization effects. + const QuicTime::Delta small_offset = QuicTime::Delta::FromMicroseconds(3); + bbr_sender_link_ = std::make_unique( + &bbr_sender_, switch_->port(1), kLocalLinkBandwidth, + kLocalPropagationDelay); + competing_sender_link_ = std::make_unique( + &competing_sender_, switch_->port(3), kLocalLinkBandwidth, + kLocalPropagationDelay + small_offset); + receiver_link_ = std::make_unique( + &receiver_multiplexer_, switch_->port(2), kTestLinkBandwidth, + kTestPropagationDelay); + } + + // Creates a BBR vs BBR competition setup. + void CreateBbrVsBbrSetup() { + SetupBbrSender(&competing_sender_); + CreateCompetitionSetup(); + } + + void EnableAggregation(QuicByteCount aggregation_bytes, + QuicTime::Delta aggregation_timeout) { + // Enable aggregation on the path from the receiver to the sender. + switch_->port_queue(1)->EnableAggregation(aggregation_bytes, + aggregation_timeout); + } + + void DoSimpleTransfer(QuicByteCount transfer_size, QuicTime::Delta deadline) { + bbr_sender_.AddBytesToTransfer(transfer_size); + // TODO(vasilvv): consider rewriting this to run until the receiver actually + // receives the intended amount of bytes. + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return bbr_sender_.bytes_to_transfer() == 0; }, deadline); + EXPECT_TRUE(simulator_result) + << "Simple transfer failed. Bytes remaining: " + << bbr_sender_.bytes_to_transfer(); + QUIC_LOG(INFO) << "Simple transfer state: " << sender_->ExportDebugState(); + } + + // Drive the simulator by sending enough data to enter PROBE_BW. + void DriveOutOfStartup() { + ASSERT_FALSE(sender_->ExportDebugState().is_at_full_bandwidth); + DoSimpleTransfer(1024 * 1024, QuicTime::Delta::FromSeconds(15)); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.02f); + } + + // Send |bytes|-sized bursts of data |number_of_bursts| times, waiting for + // |wait_time| between each burst. + void SendBursts(size_t number_of_bursts, QuicByteCount bytes, + QuicTime::Delta wait_time) { + ASSERT_EQ(0u, bbr_sender_.bytes_to_transfer()); + for (size_t i = 0; i < number_of_bursts; i++) { + bbr_sender_.AddBytesToTransfer(bytes); + + // Transfer data and wait for three seconds between each transfer. + simulator_.RunFor(wait_time); + + // Ensure the connection did not time out. + ASSERT_TRUE(bbr_sender_.connection()->connected()); + ASSERT_TRUE(receiver_.connection()->connected()); + } + + simulator_.RunFor(wait_time + kTestRtt); + ASSERT_EQ(0u, bbr_sender_.bytes_to_transfer()); + } + + void SetConnectionOption(QuicTag option) { + QuicConfig config; + QuicTagVector options; + options.push_back(option); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + } +}; + +TEST_F(BbrSenderTest, SetInitialCongestionWindow) { + EXPECT_NE(3u * kDefaultTCPMSS, sender_->GetCongestionWindow()); + sender_->SetInitialCongestionWindowInPackets(3); + EXPECT_EQ(3u * kDefaultTCPMSS, sender_->GetCongestionWindow()); +} + +// Test a simple long data transfer in the default setup. +TEST_F(BbrSenderTest, SimpleTransfer) { + CreateDefaultSetup(); + + // At startup make sure we are at the default. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // And that window is un-affected. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.885 * kDefaultWindowTCP, rtt_stats_->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + ASSERT_GE(kTestBdp, kDefaultWindowTCP + kDefaultTCPMSS); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // The margin here is quite high, since there exists a possibility that the + // connection just exited high gain cycle. + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->smoothed_rtt(), 0.2f); +} + +TEST_F(BbrSenderTest, SimpleTransferBBRB) { + SetConnectionOption(kBBRB); + CreateDefaultSetup(); + + // At startup make sure we are at the default. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // And that window is un-affected. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.885 * kDefaultWindowTCP, rtt_stats_->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + ASSERT_GE(kTestBdp, kDefaultWindowTCP + kDefaultTCPMSS); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // The margin here is quite high, since there exists a possibility that the + // connection just exited high gain cycle. + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->smoothed_rtt(), 0.2f); +} + +// Test a simple transfer in a situation when the buffer is less than BDP. +TEST_F(BbrSenderTest, SimpleTransferSmallBuffer) { + CreateSmallBufferSetup(); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + EXPECT_GE(bbr_sender_.connection()->GetStats().packets_lost, 0u); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // The margin here is quite high, since there exists a possibility that the + // connection just exited high gain cycle. + EXPECT_APPROX_EQ(kTestRtt, sender_->GetMinRtt(), 0.2f); +} + +TEST_F(BbrSenderTest, RemoveBytesLostInRecovery) { + CreateDefaultSetup(); + + DriveOutOfStartup(); + + // Drop a packet to enter recovery. + receiver_.DropNextIncomingPacket(); + ASSERT_TRUE( + simulator_.RunUntilOrTimeout([this]() { return sender_->InRecovery(); }, + QuicTime::Delta::FromSeconds(30))); + + QuicUnackedPacketMap* unacked_packets = + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(bbr_sender_.connection())); + QuicPacketNumber largest_sent = + bbr_sender_.connection()->sent_packet_manager().GetLargestSentPacket(); + // least_inflight is the smallest inflight packet. + QuicPacketNumber least_inflight = + bbr_sender_.connection()->sent_packet_manager().GetLeastUnacked(); + while (!unacked_packets->GetTransmissionInfo(least_inflight).in_flight) { + ASSERT_LE(least_inflight, largest_sent); + least_inflight++; + } + QuicPacketLength least_inflight_packet_size = + unacked_packets->GetTransmissionInfo(least_inflight).bytes_sent; + QuicByteCount prior_recovery_window = + sender_->ExportDebugState().recovery_window; + QuicByteCount prior_inflight = unacked_packets->bytes_in_flight(); + QUIC_LOG(INFO) << "Recovery window:" << prior_recovery_window + << ", least_inflight_packet_size:" + << least_inflight_packet_size + << ", bytes_in_flight:" << prior_inflight; + ASSERT_GT(prior_recovery_window, least_inflight_packet_size); + + // Lose the least inflight packet and expect the recovery window to drop. + unacked_packets->RemoveFromInFlight(least_inflight); + LostPacketVector lost_packets; + lost_packets.emplace_back(least_inflight, least_inflight_packet_size); + sender_->OnCongestionEvent(false, prior_inflight, clock_->Now(), {}, + lost_packets, 0, 0); + EXPECT_EQ(sender_->ExportDebugState().recovery_window, + prior_inflight - least_inflight_packet_size); + EXPECT_LT(sender_->ExportDebugState().recovery_window, prior_recovery_window); +} + +// Test a simple long data transfer with 2 rtts of aggregation. +TEST_F(BbrSenderTest, SimpleTransfer2RTTAggregationBytes) { + SetConnectionOption(kBSAO); + CreateDefaultSetup(); + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * kTestRtt); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(sender_->ExportDebugState().mode == BbrSender::PROBE_BW || + sender_->ExportDebugState().mode == BbrSender::PROBE_RTT); + + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(kTestRtt * 4, rtt_stats_->smoothed_rtt()); + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->min_rtt(), 0.5f); +} + +// Test a simple long data transfer with 2 rtts of aggregation. +TEST_F(BbrSenderTest, SimpleTransferAckDecimation) { + SetConnectionOption(kBSAO); + // Decrease the CWND gain so extra CWND is required with stretch acks. + SetQuicFlag(quic_bbr_cwnd_gain, 1.0); + sender_ = new BbrSender( + bbr_sender_.connection()->clock()->Now(), rtt_stats_, + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(bbr_sender_.connection())), + kInitialCongestionWindowPackets, GetQuicFlag(quic_max_congestion_window), + &random_, QuicConnectionPeer::GetStats(bbr_sender_.connection())); + QuicConnectionPeer::SetSendAlgorithm(bbr_sender_.connection(), sender_); + CreateDefaultSetup(); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + + // TODO(ianswett): Expect 0 packets are lost once BBR no longer measures + // bandwidth higher than the link rate. + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(kTestRtt * 2, rtt_stats_->smoothed_rtt()); + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->min_rtt(), 0.1f); +} + +// Test a simple long data transfer with 2 rtts of aggregation. +// TODO(b/172302465) Re-enable this test. +TEST_F(BbrSenderTest, QUIC_TEST_DISABLED_IN_CHROME( + SimpleTransfer2RTTAggregationBytes20RTTWindow)) { + SetConnectionOption(kBSAO); + CreateDefaultSetup(); + SetConnectionOption(kBBR4); + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * kTestRtt); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(sender_->ExportDebugState().mode == BbrSender::PROBE_BW || + sender_->ExportDebugState().mode == BbrSender::PROBE_RTT); + + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + + // TODO(ianswett): Expect 0 packets are lost once BBR no longer measures + // bandwidth higher than the link rate. + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(kTestRtt * 4, rtt_stats_->smoothed_rtt()); + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->min_rtt(), 0.25f); +} + +// Test a simple long data transfer with 2 rtts of aggregation. +TEST_F(BbrSenderTest, SimpleTransfer2RTTAggregationBytes40RTTWindow) { + SetConnectionOption(kBSAO); + CreateDefaultSetup(); + SetConnectionOption(kBBR5); + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * kTestRtt); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(sender_->ExportDebugState().mode == BbrSender::PROBE_BW || + sender_->ExportDebugState().mode == BbrSender::PROBE_RTT); + + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + + // TODO(ianswett): Expect 0 packets are lost once BBR no longer measures + // bandwidth higher than the link rate. + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(kTestRtt * 4, rtt_stats_->smoothed_rtt()); + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->min_rtt(), 0.25f); +} + +// Test the number of losses incurred by the startup phase in a situation when +// the buffer is less than BDP. +TEST_F(BbrSenderTest, PacketLossOnSmallBufferStartup) { + CreateSmallBufferSetup(); + + DriveOutOfStartup(); + float loss_rate = + static_cast(bbr_sender_.connection()->GetStats().packets_lost) / + bbr_sender_.connection()->GetStats().packets_sent; + EXPECT_LE(loss_rate, 0.31); +} + +// Test the number of losses incurred by the startup phase in a situation when +// the buffer is less than BDP, with a STARTUP CWND gain of 2. +TEST_F(BbrSenderTest, PacketLossOnSmallBufferStartupDerivedCWNDGain) { + CreateSmallBufferSetup(); + + SetConnectionOption(kBBQ2); + DriveOutOfStartup(); + float loss_rate = + static_cast(bbr_sender_.connection()->GetStats().packets_lost) / + bbr_sender_.connection()->GetStats().packets_sent; + EXPECT_LE(loss_rate, 0.1); +} + +// Ensures the code transitions loss recovery states correctly (NOT_IN_RECOVERY +// -> CONSERVATION -> GROWTH -> NOT_IN_RECOVERY). +TEST_F(BbrSenderTest, RecoveryStates) { + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(10); + bool simulator_result; + CreateSmallBufferSetup(); + + bbr_sender_.AddBytesToTransfer(100 * 1024 * 1024); + ASSERT_EQ(BbrSender::NOT_IN_RECOVERY, + sender_->ExportDebugState().recovery_state); + + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().recovery_state != + BbrSender::NOT_IN_RECOVERY; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::CONSERVATION, + sender_->ExportDebugState().recovery_state); + + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().recovery_state != + BbrSender::CONSERVATION; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::GROWTH, sender_->ExportDebugState().recovery_state); + + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().recovery_state != BbrSender::GROWTH; + }, + timeout); + + ASSERT_EQ(BbrSender::NOT_IN_RECOVERY, + sender_->ExportDebugState().recovery_state); + ASSERT_TRUE(simulator_result); +} + +// Verify the behavior of the algorithm in the case when the connection sends +// small bursts of data after sending continuously for a while. +TEST_F(BbrSenderTest, ApplicationLimitedBursts) { + CreateDefaultSetup(); + EXPECT_FALSE(sender_->HasGoodBandwidthEstimateForResumption()); + + DriveOutOfStartup(); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + EXPECT_TRUE(sender_->HasGoodBandwidthEstimateForResumption()); + + SendBursts(20, 512, QuicTime::Delta::FromSeconds(3)); + EXPECT_TRUE(sender_->ExportDebugState().last_sample_is_app_limited); + EXPECT_TRUE(sender_->HasGoodBandwidthEstimateForResumption()); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); +} + +// Verify the behavior of the algorithm in the case when the connection sends +// small bursts of data and then starts sending continuously. +TEST_F(BbrSenderTest, ApplicationLimitedBurstsWithoutPrior) { + CreateDefaultSetup(); + + SendBursts(40, 512, QuicTime::Delta::FromSeconds(3)); + EXPECT_TRUE(sender_->ExportDebugState().last_sample_is_app_limited); + + DriveOutOfStartup(); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +// Verify that the DRAIN phase works correctly. +TEST_F(BbrSenderTest, Drain) { + CreateDefaultSetup(); + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(10); + // Get the queue at the bottleneck, which is the outgoing queue at the port to + // which the receiver is connected. + const simulator::Queue* queue = switch_->port_queue(2); + bool simulator_result; + + // We have no intention of ever finishing this transfer. + bbr_sender_.AddBytesToTransfer(100 * 1024 * 1024); + + // Run the startup, and verify that it fills up the queue. + ASSERT_EQ(BbrSender::STARTUP, sender_->ExportDebugState().mode); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().mode != BbrSender::STARTUP; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_APPROX_EQ(sender_->BandwidthEstimate() * (1 / 2.885f), + sender_->PacingRate(0), 0.01f); + + // BBR uses CWND gain of 2 during STARTUP, hence it will fill the buffer + // with approximately 1 BDP. Here, we use 0.8 to give some margin for + // error. + EXPECT_GE(queue->bytes_queued(), 0.8 * kTestBdp); + + // Observe increased RTT due to bufferbloat. + const QuicTime::Delta queueing_delay = + kTestLinkBandwidth.TransferTime(queue->bytes_queued()); + EXPECT_APPROX_EQ(kTestRtt + queueing_delay, rtt_stats_->latest_rtt(), 0.1f); + + // Transition to the drain phase and verify that it makes the queue + // have at most a BDP worth of packets. + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().mode != BbrSender::DRAIN; }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_LE(queue->bytes_queued(), kTestBdp); + + // Wait for a few round trips and ensure we're in appropriate phase of gain + // cycling before taking an RTT measurement. + const QuicRoundTripCount start_round_trip = + sender_->ExportDebugState().round_trip_count; + simulator_result = simulator_.RunUntilOrTimeout( + [this, start_round_trip]() { + QuicRoundTripCount rounds_passed = + sender_->ExportDebugState().round_trip_count - start_round_trip; + return rounds_passed >= 4 && + sender_->ExportDebugState().gain_cycle_index == 7; + }, + timeout); + ASSERT_TRUE(simulator_result); + + // Observe the bufferbloat go away. + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->smoothed_rtt(), 0.1f); +} + +// TODO(wub): Re-enable this test once default drain_gain changed to 0.75. +// Verify that the DRAIN phase works correctly. +TEST_F(BbrSenderTest, DISABLED_ShallowDrain) { + CreateDefaultSetup(); + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(10); + // Get the queue at the bottleneck, which is the outgoing queue at the port to + // which the receiver is connected. + const simulator::Queue* queue = switch_->port_queue(2); + bool simulator_result; + + // We have no intention of ever finishing this transfer. + bbr_sender_.AddBytesToTransfer(100 * 1024 * 1024); + + // Run the startup, and verify that it fills up the queue. + ASSERT_EQ(BbrSender::STARTUP, sender_->ExportDebugState().mode); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().mode != BbrSender::STARTUP; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(0.75 * sender_->BandwidthEstimate(), sender_->PacingRate(0)); + // BBR uses CWND gain of 2.88 during STARTUP, hence it will fill the buffer + // with approximately 1.88 BDPs. Here, we use 1.5 to give some margin for + // error. + EXPECT_GE(queue->bytes_queued(), 1.5 * kTestBdp); + + // Observe increased RTT due to bufferbloat. + const QuicTime::Delta queueing_delay = + kTestLinkBandwidth.TransferTime(queue->bytes_queued()); + EXPECT_APPROX_EQ(kTestRtt + queueing_delay, rtt_stats_->latest_rtt(), 0.1f); + + // Transition to the drain phase and verify that it makes the queue + // have at most a BDP worth of packets. + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().mode != BbrSender::DRAIN; }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_LE(queue->bytes_queued(), kTestBdp); + + // Wait for a few round trips and ensure we're in appropriate phase of gain + // cycling before taking an RTT measurement. + const QuicRoundTripCount start_round_trip = + sender_->ExportDebugState().round_trip_count; + simulator_result = simulator_.RunUntilOrTimeout( + [this, start_round_trip]() { + QuicRoundTripCount rounds_passed = + sender_->ExportDebugState().round_trip_count - start_round_trip; + return rounds_passed >= 4 && + sender_->ExportDebugState().gain_cycle_index == 7; + }, + timeout); + ASSERT_TRUE(simulator_result); + + // Observe the bufferbloat go away. + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->smoothed_rtt(), 0.1f); +} + +// Verify that the connection enters and exits PROBE_RTT correctly. +TEST_F(BbrSenderTest, ProbeRtt) { + CreateDefaultSetup(); + DriveOutOfStartup(); + + // We have no intention of ever finishing this transfer. + bbr_sender_.AddBytesToTransfer(100 * 1024 * 1024); + + // Wait until the connection enters PROBE_RTT. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(12); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().mode == BbrSender::PROBE_RTT; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::PROBE_RTT, sender_->ExportDebugState().mode); + + // Exit PROBE_RTT. + const QuicTime probe_rtt_start = clock_->Now(); + const QuicTime::Delta time_to_exit_probe_rtt = + kTestRtt + QuicTime::Delta::FromMilliseconds(200); + simulator_.RunFor(1.5 * time_to_exit_probe_rtt); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_GE(sender_->ExportDebugState().min_rtt_timestamp, probe_rtt_start); +} + +// Ensure that a connection that is app-limited and is at sufficiently low +// bandwidth will not exit high gain phase, and similarly ensure that the +// connection will exit low gain early if the number of bytes in flight is low. +// TODO(crbug.com/1145095): Re-enable this test. +TEST_F(BbrSenderTest, QUIC_TEST_DISABLED_IN_CHROME(InFlightAwareGainCycling)) { + CreateDefaultSetup(); + DriveOutOfStartup(); + + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + while (!(sender_->ExportDebugState().gain_cycle_index >= 4 && + bbr_sender_.bytes_to_transfer() == 0)) { + bbr_sender_.AddBytesToTransfer(kTestLinkBandwidth.ToBytesPerSecond()); + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { return bbr_sender_.bytes_to_transfer() == 0; }, timeout)); + } + + // Send at 10% of available rate. Run for 3 seconds, checking in the middle + // and at the end. The pacing gain should be high throughout. + QuicBandwidth target_bandwidth = 0.1f * kTestLinkBandwidth; + QuicTime::Delta burst_interval = QuicTime::Delta::FromMilliseconds(300); + for (int i = 0; i < 2; i++) { + SendBursts(5, target_bandwidth * burst_interval, burst_interval); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(0, sender_->ExportDebugState().gain_cycle_index); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.02f); + } + + // Now that in-flight is almost zero and the pacing gain is still above 1, + // send approximately 1.25 BDPs worth of data. This should cause the + // PROBE_BW mode to enter low gain cycle, and exit it earlier than one min_rtt + // due to running out of data to send. + bbr_sender_.AddBytesToTransfer(1.3 * kTestBdp); + ASSERT_TRUE(simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().gain_cycle_index == 1; }, + timeout)); + + simulator_.RunFor(0.75 * sender_->ExportDebugState().min_rtt); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(2, sender_->ExportDebugState().gain_cycle_index); +} + +// Ensure that the pacing rate does not drop at startup. +TEST_F(BbrSenderTest, NoBandwidthDropOnStartup) { + CreateDefaultSetup(); + + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result; + + QuicBandwidth initial_rate = QuicBandwidth::FromBytesAndTimeDelta( + kInitialCongestionWindowPackets * kDefaultTCPMSS, + rtt_stats_->initial_rtt()); + EXPECT_GE(sender_->PacingRate(0), initial_rate); + + // Send a packet. + bbr_sender_.AddBytesToTransfer(1000); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return receiver_.bytes_received() == 1000; }, timeout); + ASSERT_TRUE(simulator_result); + EXPECT_GE(sender_->PacingRate(0), initial_rate); + + // Wait for a while. + simulator_.RunFor(QuicTime::Delta::FromSeconds(2)); + EXPECT_GE(sender_->PacingRate(0), initial_rate); + + // Send another packet. + bbr_sender_.AddBytesToTransfer(1000); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return receiver_.bytes_received() == 2000; }, timeout); + ASSERT_TRUE(simulator_result); + EXPECT_GE(sender_->PacingRate(0), initial_rate); +} + +// Test exiting STARTUP earlier due to the 1RTT connection option. +TEST_F(BbrSenderTest, SimpleTransfer1RTTStartup) { + CreateDefaultSetup(); + + SetConnectionOption(k1RTT); + EXPECT_EQ(1u, sender_->num_startup_rtts()); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().max_bandwidth) { + max_bw = sender_->ExportDebugState().max_bandwidth; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().is_at_full_bandwidth; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(1u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ(1u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +// Test exiting STARTUP earlier due to the 2RTT connection option. +TEST_F(BbrSenderTest, SimpleTransfer2RTTStartup) { + CreateDefaultSetup(); + + SetConnectionOption(k2RTT); + EXPECT_EQ(2u, sender_->num_startup_rtts()); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw * 1.001 < sender_->ExportDebugState().max_bandwidth) { + max_bw = sender_->ExportDebugState().max_bandwidth; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().is_at_full_bandwidth; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(2u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ(2u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +// Test exiting STARTUP earlier upon loss. +TEST_F(BbrSenderTest, SimpleTransferExitStartupOnLoss) { + CreateDefaultSetup(); + + EXPECT_EQ(3u, sender_->num_startup_rtts()); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw * 1.001 < sender_->ExportDebugState().max_bandwidth) { + max_bw = sender_->ExportDebugState().max_bandwidth; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().is_at_full_bandwidth; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(3u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ(3u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +// Test exiting STARTUP earlier upon loss with a small buffer. +TEST_F(BbrSenderTest, SimpleTransferExitStartupOnLossSmallBuffer) { + CreateSmallBufferSetup(); + + EXPECT_EQ(3u, sender_->num_startup_rtts()); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().max_bandwidth) { + max_bw = sender_->ExportDebugState().max_bandwidth; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().is_at_full_bandwidth; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_GE(2u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ(1u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + EXPECT_NE(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +TEST_F(BbrSenderTest, DerivedPacingGainStartup) { + CreateDefaultSetup(); + + SetConnectionOption(kBBQ1); + EXPECT_EQ(3u, sender_->num_startup_rtts()); + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.773 * kDefaultWindowTCP, rtt_stats_->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().is_at_full_bandwidth; }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(3u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +TEST_F(BbrSenderTest, DerivedCWNDGainStartup) { + CreateSmallBufferSetup(); + + EXPECT_EQ(3u, sender_->num_startup_rtts()); + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.885 * kDefaultWindowTCP, rtt_stats_->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().is_at_full_bandwidth; }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + if (!bbr_sender_.connection()->GetStats().bbr_exit_startup_due_to_loss) { + EXPECT_EQ(3u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + } + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + float loss_rate = + static_cast(bbr_sender_.connection()->GetStats().packets_lost) / + bbr_sender_.connection()->GetStats().packets_sent; + EXPECT_LT(loss_rate, 0.15f); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + // Expect an SRTT less than 2.7 * Min RTT on exit from STARTUP. + EXPECT_GT(kTestRtt * 2.7, rtt_stats_->smoothed_rtt()); +} + +TEST_F(BbrSenderTest, AckAggregationInStartup) { + CreateDefaultSetup(); + + SetConnectionOption(kBBQ3); + EXPECT_EQ(3u, sender_->num_startup_rtts()); + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.885 * kDefaultWindowTCP, rtt_stats_->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + // Run until the full bandwidth is reached and check how many rounds it was. + bbr_sender_.AddBytesToTransfer(12 * 1024 * 1024); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_->ExportDebugState().is_at_full_bandwidth; }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(3u, sender_->ExportDebugState().rounds_without_bandwidth_gain); + EXPECT_APPROX_EQ(kTestLinkBandwidth, + sender_->ExportDebugState().max_bandwidth, 0.01f); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); +} + +// Test that two BBR flows started slightly apart from each other terminate. +TEST_F(BbrSenderTest, SimpleCompetition) { + const QuicByteCount transfer_size = 10 * 1024 * 1024; + const QuicTime::Delta transfer_time = + kTestLinkBandwidth.TransferTime(transfer_size); + CreateBbrVsBbrSetup(); + + // Transfer 10% of data in first transfer. + bbr_sender_.AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return receiver_.bytes_received() >= 0.1 * transfer_size; }, + transfer_time); + ASSERT_TRUE(simulator_result); + + // Start the second transfer and wait until both finish. + competing_sender_.AddBytesToTransfer(transfer_size); + simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return receiver_.bytes_received() == transfer_size && + competing_receiver_.bytes_received() == transfer_size; + }, + 3 * transfer_time); + ASSERT_TRUE(simulator_result); +} + +// Test that BBR can resume bandwidth from cached network parameters. +TEST_F(BbrSenderTest, ResumeConnectionState) { + CreateDefaultSetup(); + + bbr_sender_.connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(kTestLinkBandwidth, kTestRtt, + false)); + EXPECT_EQ(kTestLinkBandwidth * kTestRtt, + sender_->ExportDebugState().congestion_window); + + EXPECT_EQ(kTestLinkBandwidth, sender_->PacingRate(/*bytes_in_flight=*/0)); + + EXPECT_APPROX_EQ(kTestRtt, sender_->ExportDebugState().min_rtt, 0.01f); + + DriveOutOfStartup(); +} + +// Test with a min CWND of 1 instead of 4 packets. +TEST_F(BbrSenderTest, ProbeRTTMinCWND1) { + CreateDefaultSetup(); + SetConnectionOption(kMIN1); + DriveOutOfStartup(); + + // We have no intention of ever finishing this transfer. + bbr_sender_.AddBytesToTransfer(100 * 1024 * 1024); + + // Wait until the connection enters PROBE_RTT. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(12); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { + return sender_->ExportDebugState().mode == BbrSender::PROBE_RTT; + }, + timeout); + ASSERT_TRUE(simulator_result); + ASSERT_EQ(BbrSender::PROBE_RTT, sender_->ExportDebugState().mode); + // The PROBE_RTT CWND should be 1 if the min CWND is 1. + EXPECT_EQ(kDefaultTCPMSS, sender_->GetCongestionWindow()); + + // Exit PROBE_RTT. + const QuicTime probe_rtt_start = clock_->Now(); + const QuicTime::Delta time_to_exit_probe_rtt = + kTestRtt + QuicTime::Delta::FromMilliseconds(200); + simulator_.RunFor(1.5 * time_to_exit_probe_rtt); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_GE(sender_->ExportDebugState().min_rtt_timestamp, probe_rtt_start); +} + +TEST_F(BbrSenderTest, StartupStats) { + CreateDefaultSetup(); + + DriveOutOfStartup(); + ASSERT_FALSE(sender_->InSlowStart()); + + const QuicConnectionStats& stats = bbr_sender_.connection()->GetStats(); + EXPECT_EQ(1u, stats.slowstart_count); + EXPECT_THAT(stats.slowstart_num_rtts, AllOf(Ge(5u), Le(15u))); + EXPECT_THAT(stats.slowstart_packets_sent, AllOf(Ge(100u), Le(1000u))); + EXPECT_THAT(stats.slowstart_bytes_sent, AllOf(Ge(100000u), Le(1000000u))); + EXPECT_LE(stats.slowstart_packets_lost, 10u); + EXPECT_LE(stats.slowstart_bytes_lost, 10000u); + EXPECT_FALSE(stats.slowstart_duration.IsRunning()); + EXPECT_THAT(stats.slowstart_duration.GetTotalElapsedTime(), + AllOf(Ge(QuicTime::Delta::FromMilliseconds(500)), + Le(QuicTime::Delta::FromMilliseconds(1500)))); + EXPECT_EQ(stats.slowstart_duration.GetTotalElapsedTime(), + QuicConnectionPeer::GetSentPacketManager(bbr_sender_.connection()) + ->GetSlowStartDuration()); +} + +// Regression test for b/143540157. +TEST_F(BbrSenderTest, RecalculatePacingRateOnCwndChange1RTT) { + CreateDefaultSetup(); + + bbr_sender_.AddBytesToTransfer(1 * 1024 * 1024); + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + const QuicByteCount previous_cwnd = + sender_->ExportDebugState().congestion_window; + + // Bootstrap cwnd. + bbr_sender_.connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(kTestLinkBandwidth, + QuicTime::Delta::Zero(), false)); + EXPECT_LT(previous_cwnd, sender_->ExportDebugState().congestion_window); + + // Verify pacing rate is re-calculated based on the new cwnd and min_rtt. + EXPECT_APPROX_EQ(QuicBandwidth::FromBytesAndTimeDelta( + sender_->ExportDebugState().congestion_window, + sender_->ExportDebugState().min_rtt), + sender_->PacingRate(/*bytes_in_flight=*/0), 0.01f); +} + +TEST_F(BbrSenderTest, RecalculatePacingRateOnCwndChange0RTT) { + CreateDefaultSetup(); + // Initial RTT is available. + const_cast(rtt_stats_)->set_initial_rtt(kTestRtt); + + // Bootstrap cwnd. + bbr_sender_.connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(kTestLinkBandwidth, + QuicTime::Delta::Zero(), false)); + EXPECT_LT(kInitialCongestionWindowPackets * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + // No Rtt sample is available. + EXPECT_TRUE(sender_->ExportDebugState().min_rtt.IsZero()); + + // Verify pacing rate is re-calculated based on the new cwnd and initial + // RTT. + EXPECT_APPROX_EQ(QuicBandwidth::FromBytesAndTimeDelta( + sender_->ExportDebugState().congestion_window, + rtt_stats_->initial_rtt()), + sender_->PacingRate(/*bytes_in_flight=*/0), 0.01f); +} + +TEST_F(BbrSenderTest, MitigateCwndBootstrappingOvershoot) { + CreateDefaultSetup(); + bbr_sender_.AddBytesToTransfer(1 * 1024 * 1024); + + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + bbr_sender_.connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(8 * kTestLinkBandwidth, + QuicTime::Delta::Zero(), false)); + QuicBandwidth pacing_rate = sender_->PacingRate(0); + EXPECT_EQ(8 * kTestLinkBandwidth, pacing_rate); + + // Wait until pacing_rate decreases. + simulator_result = simulator_.RunUntilOrTimeout( + [this, pacing_rate]() { return sender_->PacingRate(0) < pacing_rate; }, + timeout); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(BbrSender::STARTUP, sender_->ExportDebugState().mode); + if (GetQuicReloadableFlag(quic_conservative_cwnd_and_pacing_gains)) { + EXPECT_APPROX_EQ(2.0f * sender_->BandwidthEstimate(), + sender_->PacingRate(0), 0.01f); + } else { + EXPECT_APPROX_EQ(2.885f * sender_->BandwidthEstimate(), + sender_->PacingRate(0), 0.01f); + } +} + +TEST_F(BbrSenderTest, 200InitialCongestionWindowWithNetworkParameterAdjusted) { + CreateDefaultSetup(); + + bbr_sender_.AddBytesToTransfer(1 * 1024 * 1024); + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + bbr_sender_.connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(1024 * kTestLinkBandwidth, + QuicTime::Delta::Zero(), false)); + // Verify cwnd is capped at 200. + EXPECT_EQ(200 * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + EXPECT_GT(1024 * kTestLinkBandwidth, sender_->PacingRate(0)); +} + +TEST_F(BbrSenderTest, 100InitialCongestionWindowFromNetworkParameter) { + CreateDefaultSetup(); + + bbr_sender_.AddBytesToTransfer(1 * 1024 * 1024); + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + SendAlgorithmInterface::NetworkParams network_params( + 1024 * kTestLinkBandwidth, QuicTime::Delta::Zero(), false); + network_params.max_initial_congestion_window = 100; + bbr_sender_.connection()->AdjustNetworkParameters(network_params); + // Verify cwnd is capped at 100. + EXPECT_EQ(100 * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + EXPECT_GT(1024 * kTestLinkBandwidth, sender_->PacingRate(0)); +} + +TEST_F(BbrSenderTest, 100InitialCongestionWindowWithNetworkParameterAdjusted) { + SetConnectionOption(kICW1); + CreateDefaultSetup(); + + bbr_sender_.AddBytesToTransfer(1 * 1024 * 1024); + // Wait until an ACK comes back. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return !sender_->ExportDebugState().min_rtt.IsZero(); }, + timeout); + ASSERT_TRUE(simulator_result); + + // Bootstrap cwnd by a overly large bandwidth sample. + bbr_sender_.connection()->AdjustNetworkParameters( + SendAlgorithmInterface::NetworkParams(1024 * kTestLinkBandwidth, + QuicTime::Delta::Zero(), false)); + // Verify cwnd is capped at 100. + EXPECT_EQ(100 * kDefaultTCPMSS, + sender_->ExportDebugState().congestion_window); + EXPECT_GT(1024 * kTestLinkBandwidth, sender_->PacingRate(0)); +} + +// Ensures bandwidth estimate does not change after a loss only event. +// Regression test for b/151239871. +TEST_F(BbrSenderTest, LossOnlyCongestionEvent) { + CreateDefaultSetup(); + + DriveOutOfStartup(); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // Send some bursts, each burst increments round count by 1, since it only + // generates small, app-limited samples, the max_bandwidth_ will not be + // updated. At the end of all bursts, all estimates in max_bandwidth_ will + // look very old such that any Update() will reset all estimates. + SendBursts(20, 512, QuicTime::Delta::FromSeconds(3)); + + QuicUnackedPacketMap* unacked_packets = + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager(bbr_sender_.connection())); + // Run until we have something in flight. + bbr_sender_.AddBytesToTransfer(50 * 1024 * 1024); + bool simulator_result = simulator_.RunUntilOrTimeout( + [&]() { return unacked_packets->bytes_in_flight() > 0; }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + + const QuicBandwidth prior_bandwidth_estimate = sender_->BandwidthEstimate(); + EXPECT_APPROX_EQ(kTestLinkBandwidth, prior_bandwidth_estimate, 0.01f); + + // Lose the least unacked packet. + LostPacketVector lost_packets; + lost_packets.emplace_back( + bbr_sender_.connection()->sent_packet_manager().GetLeastUnacked(), + kDefaultMaxPacketSize); + + QuicTime now = simulator_.GetClock()->Now() + kTestRtt * 0.25; + sender_->OnCongestionEvent(false, unacked_packets->bytes_in_flight(), now, {}, + lost_packets, 0, 0); + + // Bandwidth estimate should not change for the loss only event. + EXPECT_EQ(prior_bandwidth_estimate, sender_->BandwidthEstimate()); +} + +TEST_F(BbrSenderTest, EnableOvershootingDetection) { + SetConnectionOption(kDTOS); + CreateSmallBufferSetup(); + // Set a overly large initial cwnd. + sender_->SetInitialCongestionWindowInPackets(200); + const QuicConnectionStats& stats = bbr_sender_.connection()->GetStats(); + EXPECT_FALSE(stats.overshooting_detected_with_network_parameters_adjusted); + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + + // Verify overshooting is detected. + EXPECT_TRUE(stats.overshooting_detected_with_network_parameters_adjusted); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/cubic_bytes.cc b/quiche/quic/core/congestion_control/cubic_bytes.cc new file mode 100644 index 000000000000..489042555f81 --- /dev/null +++ b/quiche/quic/core/congestion_control/cubic_bytes.cc @@ -0,0 +1,189 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/cubic_bytes.h" + +#include +#include +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Constants based on TCP defaults. +// The following constants are in 2^10 fractions of a second instead of ms to +// allow a 10 shift right to divide. +const int kCubeScale = 40; // 1024*1024^3 (first 1024 is from 0.100^3) + // where 0.100 is 100 ms which is the scaling + // round trip time. +const int kCubeCongestionWindowScale = 410; +// The cube factor for packets in bytes. +const uint64_t kCubeFactor = + (UINT64_C(1) << kCubeScale) / kCubeCongestionWindowScale / kDefaultTCPMSS; + +const float kDefaultCubicBackoffFactor = 0.7f; // Default Cubic backoff factor. +// Additional backoff factor when loss occurs in the concave part of the Cubic +// curve. This additional backoff factor is expected to give up bandwidth to +// new concurrent flows and speed up convergence. +const float kBetaLastMax = 0.85f; + +} // namespace + +CubicBytes::CubicBytes(const QuicClock* clock) + : clock_(clock), + num_connections_(kDefaultNumConnections), + epoch_(QuicTime::Zero()) { + ResetCubicState(); +} + +void CubicBytes::SetNumConnections(int num_connections) { + num_connections_ = num_connections; +} + +float CubicBytes::Alpha() const { + // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that + // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. + // We derive the equivalent alpha for an N-connection emulation as: + const float beta = Beta(); + return 3 * num_connections_ * num_connections_ * (1 - beta) / (1 + beta); +} + +float CubicBytes::Beta() const { + // kNConnectionBeta is the backoff factor after loss for our N-connection + // emulation, which emulates the effective backoff of an ensemble of N + // TCP-Reno connections on a single loss event. The effective multiplier is + // computed as: + return (num_connections_ - 1 + kDefaultCubicBackoffFactor) / num_connections_; +} + +float CubicBytes::BetaLastMax() const { + // BetaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (num_connections_ - 1 + kBetaLastMax) / num_connections_; +} + +void CubicBytes::ResetCubicState() { + epoch_ = QuicTime::Zero(); // Reset time. + last_max_congestion_window_ = 0; + acked_bytes_count_ = 0; + estimated_tcp_congestion_window_ = 0; + origin_point_congestion_window_ = 0; + time_to_origin_point_ = 0; + last_target_congestion_window_ = 0; +} + +void CubicBytes::OnApplicationLimited() { + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + epoch_ = QuicTime::Zero(); +} + +QuicByteCount CubicBytes::CongestionWindowAfterPacketLoss( + QuicByteCount current_congestion_window) { + // Since bytes-mode Reno mode slightly under-estimates the cwnd, we + // may never reach precisely the last cwnd over the course of an + // RTT. Do not interpret a slight under-estimation as competing traffic. + if (current_congestion_window + kDefaultTCPMSS < + last_max_congestion_window_) { + // We never reached the old max, so assume we are competing with + // another flow. Use our extra back off factor to allow the other + // flow to go up. + last_max_congestion_window_ = + static_cast(BetaLastMax() * current_congestion_window); + } else { + last_max_congestion_window_ = current_congestion_window; + } + epoch_ = QuicTime::Zero(); // Reset time. + return static_cast(current_congestion_window * Beta()); +} + +QuicByteCount CubicBytes::CongestionWindowAfterAck( + QuicByteCount acked_bytes, QuicByteCount current_congestion_window, + QuicTime::Delta delay_min, QuicTime event_time) { + acked_bytes_count_ += acked_bytes; + + if (!epoch_.IsInitialized()) { + // First ACK after a loss event. + QUIC_DVLOG(1) << "Start of epoch"; + epoch_ = event_time; // Start of epoch. + acked_bytes_count_ = acked_bytes; // Reset count. + // Reset estimated_tcp_congestion_window_ to be in sync with cubic. + estimated_tcp_congestion_window_ = current_congestion_window; + if (last_max_congestion_window_ <= current_congestion_window) { + time_to_origin_point_ = 0; + origin_point_congestion_window_ = current_congestion_window; + } else { + time_to_origin_point_ = static_cast( + cbrt(kCubeFactor * + (last_max_congestion_window_ - current_congestion_window))); + origin_point_congestion_window_ = last_max_congestion_window_; + } + } + // Change the time unit from microseconds to 2^10 fractions per second. Take + // the round trip time in account. This is done to allow us to use shift as a + // divide operator. + int64_t elapsed_time = + ((event_time + delay_min - epoch_).ToMicroseconds() << 10) / + kNumMicrosPerSecond; + + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. + uint64_t offset = std::abs(time_to_origin_point_ - elapsed_time); + + QuicByteCount delta_congestion_window = (kCubeCongestionWindowScale * offset * + offset * offset * kDefaultTCPMSS) >> + kCubeScale; + + const bool add_delta = elapsed_time > time_to_origin_point_; + QUICHE_DCHECK(add_delta || + (origin_point_congestion_window_ > delta_congestion_window)); + QuicByteCount target_congestion_window = + add_delta ? origin_point_congestion_window_ + delta_congestion_window + : origin_point_congestion_window_ - delta_congestion_window; + // Limit the CWND increase to half the acked bytes. + target_congestion_window = + std::min(target_congestion_window, + current_congestion_window + acked_bytes_count_ / 2); + + QUICHE_DCHECK_LT(0u, estimated_tcp_congestion_window_); + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + estimated_tcp_congestion_window_ += acked_bytes_count_ * + (Alpha() * kDefaultTCPMSS) / + estimated_tcp_congestion_window_; + acked_bytes_count_ = 0; + + // We have a new cubic congestion window. + last_target_congestion_window_ = target_congestion_window; + + // Compute target congestion_window based on cubic target and estimated TCP + // congestion_window, use highest (fastest). + if (target_congestion_window < estimated_tcp_congestion_window_) { + target_congestion_window = estimated_tcp_congestion_window_; + } + + QUIC_DVLOG(1) << "Final target congestion_window: " + << target_congestion_window; + return target_congestion_window; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/cubic_bytes.h b/quiche/quic/core/congestion_control/cubic_bytes.h new file mode 100644 index 000000000000..6eb43ab58cd4 --- /dev/null +++ b/quiche/quic/core/congestion_control/cubic_bytes.h @@ -0,0 +1,102 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Cubic algorithm, helper class to TCP cubic. +// For details see http://netsrv.csc.ncsu.edu/export/cubic_a_new_tcp_2008.pdf. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_CUBIC_BYTES_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_CUBIC_BYTES_H_ + +#include + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class CubicBytesTest; +} // namespace test + +class QUIC_EXPORT_PRIVATE CubicBytes { + public: + explicit CubicBytes(const QuicClock* clock); + CubicBytes(const CubicBytes&) = delete; + CubicBytes& operator=(const CubicBytes&) = delete; + + void SetNumConnections(int num_connections); + + // Call after a timeout to reset the cubic state. + void ResetCubicState(); + + // Compute a new congestion window to use after a loss event. + // Returns the new congestion window in packets. The new congestion window is + // a multiplicative decrease of our current window. + QuicByteCount CongestionWindowAfterPacketLoss(QuicPacketCount current); + + // Compute a new congestion window to use after a received ACK. + // Returns the new congestion window in bytes. The new congestion window + // follows a cubic function that depends on the time passed since last packet + // loss. + QuicByteCount CongestionWindowAfterAck(QuicByteCount acked_bytes, + QuicByteCount current, + QuicTime::Delta delay_min, + QuicTime event_time); + + // Call on ack arrival when sender is unable to use the available congestion + // window. Resets Cubic state during quiescence. + void OnApplicationLimited(); + + private: + friend class test::CubicBytesTest; + + static const QuicTime::Delta MaxCubicTimeInterval() { + return QuicTime::Delta::FromMilliseconds(30); + } + + // Compute the TCP Cubic alpha, beta, and beta-last-max based on the + // current number of connections. + float Alpha() const; + float Beta() const; + float BetaLastMax() const; + + QuicByteCount last_max_congestion_window() const { + return last_max_congestion_window_; + } + + const QuicClock* clock_; + + // Number of connections to simulate. + int num_connections_; + + // Time when this cycle started, after last loss event. + QuicTime epoch_; + + // Max congestion window used just before last loss event. + // Note: to improve fairness to other streams an additional back off is + // applied to this value if the new value is below our latest value. + QuicByteCount last_max_congestion_window_; + + // Number of acked bytes since the cycle started (epoch). + QuicByteCount acked_bytes_count_; + + // TCP Reno equivalent congestion window in packets. + QuicByteCount estimated_tcp_congestion_window_; + + // Origin point of cubic function. + QuicByteCount origin_point_congestion_window_; + + // Time to origin point of cubic function in 2^10 fractions of a second. + uint32_t time_to_origin_point_; + + // Last congestion window in packets computed by cubic function. + QuicByteCount last_target_congestion_window_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_CUBIC_BYTES_H_ diff --git a/quiche/quic/core/congestion_control/cubic_bytes_test.cc b/quiche/quic/core/congestion_control/cubic_bytes_test.cc new file mode 100644 index 000000000000..4899d516b1e7 --- /dev/null +++ b/quiche/quic/core/congestion_control/cubic_bytes_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/cubic_bytes.h" + +#include + +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { +namespace { + +const float kBeta = 0.7f; // Default Cubic backoff factor. +const float kBetaLastMax = 0.85f; // Default Cubic backoff factor. +const uint32_t kNumConnections = 2; +const float kNConnectionBeta = (kNumConnections - 1 + kBeta) / kNumConnections; +const float kNConnectionBetaLastMax = + (kNumConnections - 1 + kBetaLastMax) / kNumConnections; +const float kNConnectionAlpha = 3 * kNumConnections * kNumConnections * + (1 - kNConnectionBeta) / (1 + kNConnectionBeta); + +} // namespace + +class CubicBytesTest : public QuicTest { + protected: + CubicBytesTest() + : one_ms_(QuicTime::Delta::FromMilliseconds(1)), + hundred_ms_(QuicTime::Delta::FromMilliseconds(100)), + cubic_(&clock_) {} + + QuicByteCount RenoCwndInBytes(QuicByteCount current_cwnd) { + QuicByteCount reno_estimated_cwnd = + current_cwnd + + kDefaultTCPMSS * (kNConnectionAlpha * kDefaultTCPMSS) / current_cwnd; + return reno_estimated_cwnd; + } + + QuicByteCount ConservativeCwndInBytes(QuicByteCount current_cwnd) { + QuicByteCount conservative_cwnd = current_cwnd + kDefaultTCPMSS / 2; + return conservative_cwnd; + } + + QuicByteCount CubicConvexCwndInBytes(QuicByteCount initial_cwnd, + QuicTime::Delta rtt, + QuicTime::Delta elapsed_time) { + const int64_t offset = + ((elapsed_time + rtt).ToMicroseconds() << 10) / 1000000; + const QuicByteCount delta_congestion_window = + ((410 * offset * offset * offset) * kDefaultTCPMSS >> 40); + const QuicByteCount cubic_cwnd = initial_cwnd + delta_congestion_window; + return cubic_cwnd; + } + + QuicByteCount LastMaxCongestionWindow() { + return cubic_.last_max_congestion_window(); + } + + QuicTime::Delta MaxCubicTimeInterval() { + return cubic_.MaxCubicTimeInterval(); + } + + const QuicTime::Delta one_ms_; + const QuicTime::Delta hundred_ms_; + MockClock clock_; + CubicBytes cubic_; +}; + +// TODO(jokulik): The original "AboveOrigin" test, below, is very +// loose. It's nearly impossible to make the test tighter without +// deploying the fix for convex mode. Once cubic convex is deployed, +// replace "AboveOrigin" with this test. +TEST_F(CubicBytesTest, AboveOriginWithTighterBounds) { + // Convex growth. + const QuicTime::Delta rtt_min = hundred_ms_; + int64_t rtt_min_ms = rtt_min.ToMilliseconds(); + float rtt_min_s = rtt_min_ms / 1000.0; + QuicByteCount current_cwnd = 10 * kDefaultTCPMSS; + const QuicByteCount initial_cwnd = current_cwnd; + + clock_.AdvanceTime(one_ms_); + const QuicTime initial_time = clock_.ApproximateNow(); + const QuicByteCount expected_first_cwnd = RenoCwndInBytes(current_cwnd); + current_cwnd = cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, + rtt_min, initial_time); + ASSERT_EQ(expected_first_cwnd, current_cwnd); + + // Normal TCP phase. + // The maximum number of expected Reno RTTs is calculated by + // finding the point where the cubic curve and the reno curve meet. + const int max_reno_rtts = + std::sqrt(kNConnectionAlpha / (.4 * rtt_min_s * rtt_min_s * rtt_min_s)) - + 2; + for (int i = 0; i < max_reno_rtts; ++i) { + // Alternatively, we expect it to increase by one, every time we + // receive current_cwnd/Alpha acks back. (This is another way of + // saying we expect cwnd to increase by approximately Alpha once + // we receive current_cwnd number ofacks back). + const uint64_t num_acks_this_epoch = + current_cwnd / kDefaultTCPMSS / kNConnectionAlpha; + const QuicByteCount initial_cwnd_this_epoch = current_cwnd; + for (QuicPacketCount n = 0; n < num_acks_this_epoch; ++n) { + // Call once per ACK. + const QuicByteCount expected_next_cwnd = RenoCwndInBytes(current_cwnd); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + ASSERT_EQ(expected_next_cwnd, current_cwnd); + } + // Our byte-wise Reno implementation is an estimate. We expect + // the cwnd to increase by approximately one MSS every + // cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as + // half a packet for smaller values of current_cwnd. + const QuicByteCount cwnd_change_this_epoch = + current_cwnd - initial_cwnd_this_epoch; + ASSERT_NEAR(kDefaultTCPMSS, cwnd_change_this_epoch, kDefaultTCPMSS / 2); + clock_.AdvanceTime(hundred_ms_); + } + + for (int i = 0; i < 54; ++i) { + const uint64_t max_acks_this_epoch = current_cwnd / kDefaultTCPMSS; + const QuicTime::Delta interval = QuicTime::Delta::FromMicroseconds( + hundred_ms_.ToMicroseconds() / max_acks_this_epoch); + for (QuicPacketCount n = 0; n < max_acks_this_epoch; ++n) { + clock_.AdvanceTime(interval); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + const QuicByteCount expected_cwnd = CubicConvexCwndInBytes( + initial_cwnd, rtt_min, (clock_.ApproximateNow() - initial_time)); + // If we allow per-ack updates, every update is a small cubic update. + ASSERT_EQ(expected_cwnd, current_cwnd); + } + } + const QuicByteCount expected_cwnd = CubicConvexCwndInBytes( + initial_cwnd, rtt_min, (clock_.ApproximateNow() - initial_time)); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + ASSERT_EQ(expected_cwnd, current_cwnd); +} + +// TODO(ianswett): This test was disabled when all fixes were enabled, but it +// may be worth fixing. +TEST_F(CubicBytesTest, DISABLED_AboveOrigin) { + // Convex growth. + const QuicTime::Delta rtt_min = hundred_ms_; + QuicByteCount current_cwnd = 10 * kDefaultTCPMSS; + // Without the signed-integer, cubic-convex fix, we start out in the + // wrong mode. + QuicPacketCount expected_cwnd = RenoCwndInBytes(current_cwnd); + // Initialize the state. + clock_.AdvanceTime(one_ms_); + ASSERT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, + rtt_min, clock_.ApproximateNow())); + current_cwnd = expected_cwnd; + const QuicPacketCount initial_cwnd = expected_cwnd; + // Normal TCP phase. + for (int i = 0; i < 48; ++i) { + for (QuicPacketCount n = 1; + n < current_cwnd / kDefaultTCPMSS / kNConnectionAlpha; ++n) { + // Call once per ACK. + ASSERT_NEAR( + current_cwnd, + cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, rtt_min, + clock_.ApproximateNow()), + kDefaultTCPMSS); + } + clock_.AdvanceTime(hundred_ms_); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + // When we fix convex mode and the uint64 arithmetic, we + // increase the expected_cwnd only after after the first 100ms, + // rather than after the initial 1ms. + expected_cwnd += kDefaultTCPMSS; + ASSERT_NEAR(expected_cwnd, current_cwnd, kDefaultTCPMSS); + } + // Cubic phase. + for (int i = 0; i < 52; ++i) { + for (QuicPacketCount n = 1; n < current_cwnd / kDefaultTCPMSS; ++n) { + // Call once per ACK. + ASSERT_NEAR( + current_cwnd, + cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, rtt_min, + clock_.ApproximateNow()), + kDefaultTCPMSS); + } + clock_.AdvanceTime(hundred_ms_); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + } + // Total time elapsed so far; add min_rtt (0.1s) here as well. + float elapsed_time_s = 10.0f + 0.1f; + // |expected_cwnd| is initial value of cwnd + K * t^3, where K = 0.4. + expected_cwnd = + initial_cwnd / kDefaultTCPMSS + + (elapsed_time_s * elapsed_time_s * elapsed_time_s * 410) / 1024; + EXPECT_EQ(expected_cwnd, current_cwnd / kDefaultTCPMSS); +} + +// Constructs an artificial scenario to ensure that cubic-convex +// increases are truly fine-grained: +// +// - After starting the epoch, this test advances the elapsed time +// sufficiently far that cubic will do small increases at less than +// MaxCubicTimeInterval() intervals. +// +// - Sets an artificially large initial cwnd to prevent Reno from the +// convex increases on every ack. +TEST_F(CubicBytesTest, AboveOriginFineGrainedCubing) { + // Start the test with an artificially large cwnd to prevent Reno + // from over-taking cubic. + QuicByteCount current_cwnd = 1000 * kDefaultTCPMSS; + const QuicByteCount initial_cwnd = current_cwnd; + const QuicTime::Delta rtt_min = hundred_ms_; + clock_.AdvanceTime(one_ms_); + QuicTime initial_time = clock_.ApproximateNow(); + + // Start the epoch and then artificially advance the time. + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(600)); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + + // We expect the algorithm to perform only non-zero, fine-grained cubic + // increases on every ack in this case. + for (int i = 0; i < 100; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + const QuicByteCount expected_cwnd = CubicConvexCwndInBytes( + initial_cwnd, rtt_min, (clock_.ApproximateNow() - initial_time)); + const QuicByteCount next_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + // Make sure we are performing cubic increases. + ASSERT_EQ(expected_cwnd, next_cwnd); + // Make sure that these are non-zero, less-than-packet sized + // increases. + ASSERT_GT(next_cwnd, current_cwnd); + const QuicByteCount cwnd_delta = next_cwnd - current_cwnd; + ASSERT_GT(kDefaultTCPMSS * .1, cwnd_delta); + + current_cwnd = next_cwnd; + } +} + +// Constructs an artificial scenario to show what happens when we +// allow per-ack updates, rather than limititing update freqency. In +// this scenario, the first two acks of the epoch produce the same +// cwnd. When we limit per-ack updates, this would cause the +// cessation of cubic updates for 30ms. When we allow per-ack +// updates, the window continues to grow on every ack. +TEST_F(CubicBytesTest, PerAckUpdates) { + // Start the test with a large cwnd and RTT, to force the first + // increase to be a cubic increase. + QuicPacketCount initial_cwnd_packets = 150; + QuicByteCount current_cwnd = initial_cwnd_packets * kDefaultTCPMSS; + const QuicTime::Delta rtt_min = 350 * one_ms_; + + // Initialize the epoch + clock_.AdvanceTime(one_ms_); + // Keep track of the growth of the reno-equivalent cwnd. + QuicByteCount reno_cwnd = RenoCwndInBytes(current_cwnd); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + const QuicByteCount initial_cwnd = current_cwnd; + + // Simulate the return of cwnd packets in less than + // MaxCubicInterval() time. + const QuicPacketCount max_acks = initial_cwnd_packets / kNConnectionAlpha; + const QuicTime::Delta interval = QuicTime::Delta::FromMicroseconds( + MaxCubicTimeInterval().ToMicroseconds() / (max_acks + 1)); + + // In this scenario, the first increase is dictated by the cubic + // equation, but it is less than one byte, so the cwnd doesn't + // change. Normally, without per-ack increases, any cwnd plateau + // will cause the cwnd to be pinned for MaxCubicTimeInterval(). If + // we enable per-ack updates, the cwnd will continue to grow, + // regardless of the temporary plateau. + clock_.AdvanceTime(interval); + reno_cwnd = RenoCwndInBytes(reno_cwnd); + ASSERT_EQ(current_cwnd, + cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, + rtt_min, clock_.ApproximateNow())); + for (QuicPacketCount i = 1; i < max_acks; ++i) { + clock_.AdvanceTime(interval); + const QuicByteCount next_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + reno_cwnd = RenoCwndInBytes(reno_cwnd); + // The window shoud increase on every ack. + ASSERT_LT(current_cwnd, next_cwnd); + ASSERT_EQ(reno_cwnd, next_cwnd); + current_cwnd = next_cwnd; + } + + // After all the acks are returned from the epoch, we expect the + // cwnd to have increased by nearly one packet. (Not exactly one + // packet, because our byte-wise Reno algorithm is always a slight + // under-estimation). Without per-ack updates, the current_cwnd + // would otherwise be unchanged. + const QuicByteCount minimum_expected_increase = kDefaultTCPMSS * .9; + EXPECT_LT(minimum_expected_increase + initial_cwnd, current_cwnd); +} + +TEST_F(CubicBytesTest, LossEvents) { + const QuicTime::Delta rtt_min = hundred_ms_; + QuicByteCount current_cwnd = 422 * kDefaultTCPMSS; + // Without the signed-integer, cubic-convex fix, we mistakenly + // increment cwnd after only one_ms_ and a single ack. + QuicPacketCount expected_cwnd = RenoCwndInBytes(current_cwnd); + // Initialize the state. + clock_.AdvanceTime(one_ms_); + EXPECT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, + rtt_min, clock_.ApproximateNow())); + + // On the first loss, the last max congestion window is set to the + // congestion window before the loss. + QuicByteCount pre_loss_cwnd = current_cwnd; + ASSERT_EQ(0u, LastMaxCongestionWindow()); + expected_cwnd = static_cast(current_cwnd * kNConnectionBeta); + EXPECT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterPacketLoss(current_cwnd)); + ASSERT_EQ(pre_loss_cwnd, LastMaxCongestionWindow()); + current_cwnd = expected_cwnd; + + // On the second loss, the current congestion window has not yet + // reached the last max congestion window. The last max congestion + // window will be reduced by an additional backoff factor to allow + // for competition. + pre_loss_cwnd = current_cwnd; + expected_cwnd = static_cast(current_cwnd * kNConnectionBeta); + ASSERT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterPacketLoss(current_cwnd)); + current_cwnd = expected_cwnd; + EXPECT_GT(pre_loss_cwnd, LastMaxCongestionWindow()); + QuicByteCount expected_last_max = + static_cast(pre_loss_cwnd * kNConnectionBetaLastMax); + EXPECT_EQ(expected_last_max, LastMaxCongestionWindow()); + EXPECT_LT(expected_cwnd, LastMaxCongestionWindow()); + // Simulate an increase, and check that we are below the origin. + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + EXPECT_GT(LastMaxCongestionWindow(), current_cwnd); + + // On the final loss, simulate the condition where the congestion + // window had a chance to grow nearly to the last congestion window. + current_cwnd = LastMaxCongestionWindow() - 1; + pre_loss_cwnd = current_cwnd; + expected_cwnd = static_cast(current_cwnd * kNConnectionBeta); + EXPECT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterPacketLoss(current_cwnd)); + expected_last_max = pre_loss_cwnd; + ASSERT_EQ(expected_last_max, LastMaxCongestionWindow()); +} + +TEST_F(CubicBytesTest, BelowOrigin) { + // Concave growth. + const QuicTime::Delta rtt_min = hundred_ms_; + QuicByteCount current_cwnd = 422 * kDefaultTCPMSS; + // Without the signed-integer, cubic-convex fix, we mistakenly + // increment cwnd after only one_ms_ and a single ack. + QuicPacketCount expected_cwnd = RenoCwndInBytes(current_cwnd); + // Initialize the state. + clock_.AdvanceTime(one_ms_); + EXPECT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterAck(kDefaultTCPMSS, current_cwnd, + rtt_min, clock_.ApproximateNow())); + expected_cwnd = static_cast(current_cwnd * kNConnectionBeta); + EXPECT_EQ(expected_cwnd, + cubic_.CongestionWindowAfterPacketLoss(current_cwnd)); + current_cwnd = expected_cwnd; + // First update after loss to initialize the epoch. + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + // Cubic phase. + for (int i = 0; i < 40; ++i) { + clock_.AdvanceTime(hundred_ms_); + current_cwnd = cubic_.CongestionWindowAfterAck( + kDefaultTCPMSS, current_cwnd, rtt_min, clock_.ApproximateNow()); + } + expected_cwnd = 553632; + EXPECT_EQ(expected_cwnd, current_cwnd); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/general_loss_algorithm.cc b/quiche/quic/core/congestion_control/general_loss_algorithm.cc new file mode 100644 index 000000000000..b92dc09d02a2 --- /dev/null +++ b/quiche/quic/core/congestion_control/general_loss_algorithm.cc @@ -0,0 +1,190 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/general_loss_algorithm.h" + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace { +float DetectionResponseTime(QuicTime::Delta rtt, QuicTime send_time, + QuicTime detection_time) { + if (detection_time <= send_time || rtt.IsZero()) { + // Time skewed, assume a very fast detection where |detection_time| is + // |send_time| + |rtt|. + return 1.0; + } + float send_to_detection_us = (detection_time - send_time).ToMicroseconds(); + return send_to_detection_us / rtt.ToMicroseconds(); +} + +QuicTime::Delta GetMaxRtt(const RttStats& rtt_stats) { + return std::max(kAlarmGranularity, + std::max(rtt_stats.previous_srtt(), rtt_stats.latest_rtt())); +} + +} // namespace + +// Uses nack counts to decide when packets are lost. +LossDetectionInterface::DetectionStats GeneralLossAlgorithm::DetectLosses( + const QuicUnackedPacketMap& unacked_packets, QuicTime time, + const RttStats& rtt_stats, QuicPacketNumber largest_newly_acked, + const AckedPacketVector& packets_acked, LostPacketVector* packets_lost) { + DetectionStats detection_stats; + + loss_detection_timeout_ = QuicTime::Zero(); + if (!packets_acked.empty() && least_in_flight_.IsInitialized() && + packets_acked.front().packet_number == least_in_flight_) { + if (packets_acked.back().packet_number == largest_newly_acked && + least_in_flight_ + packets_acked.size() - 1 == largest_newly_acked) { + // Optimization for the case when no packet is missing. Please note, + // packets_acked can include packets of different packet number space, so + // do not use this optimization if largest_newly_acked is not the largest + // packet in packets_acked. + least_in_flight_ = largest_newly_acked + 1; + return detection_stats; + } + // There is hole in acked_packets, increment least_in_flight_ if possible. + for (const auto& acked : packets_acked) { + if (acked.packet_number != least_in_flight_) { + break; + } + ++least_in_flight_; + } + } + + const QuicTime::Delta max_rtt = GetMaxRtt(rtt_stats); + + QuicPacketNumber packet_number = unacked_packets.GetLeastUnacked(); + auto it = unacked_packets.begin(); + if (least_in_flight_.IsInitialized() && least_in_flight_ >= packet_number) { + if (least_in_flight_ > unacked_packets.largest_sent_packet() + 1) { + QUIC_BUG(quic_bug_10430_1) << "least_in_flight: " << least_in_flight_ + << " is greater than largest_sent_packet + 1: " + << unacked_packets.largest_sent_packet() + 1; + } else { + it += (least_in_flight_ - packet_number); + packet_number = least_in_flight_; + } + } + // Clear least_in_flight_. + least_in_flight_.Clear(); + QUICHE_DCHECK_EQ(packet_number_space_, + unacked_packets.GetPacketNumberSpace(largest_newly_acked)); + for (; it != unacked_packets.end() && packet_number <= largest_newly_acked; + ++it, ++packet_number) { + if (unacked_packets.GetPacketNumberSpace(it->encryption_level) != + packet_number_space_) { + // Skip packets of different packet number space. + continue; + } + + if (!it->in_flight) { + continue; + } + + if (parent_ != nullptr && largest_newly_acked != packet_number) { + parent_->OnReorderingDetected(); + } + + if (largest_newly_acked - packet_number > + detection_stats.sent_packets_max_sequence_reordering) { + detection_stats.sent_packets_max_sequence_reordering = + largest_newly_acked - packet_number; + } + + // Packet threshold loss detection. + // Skip packet threshold loss detection if largest_newly_acked is a runt. + const bool skip_packet_threshold_detection = + !use_packet_threshold_for_runt_packets_ && + it->bytes_sent > + unacked_packets.GetTransmissionInfo(largest_newly_acked).bytes_sent; + if (!skip_packet_threshold_detection && + largest_newly_acked - packet_number >= reordering_threshold_) { + packets_lost->push_back(LostPacket(packet_number, it->bytes_sent)); + detection_stats.total_loss_detection_response_time += + DetectionResponseTime(max_rtt, it->sent_time, time); + continue; + } + + // Time threshold loss detection. + const QuicTime::Delta loss_delay = max_rtt + (max_rtt >> reordering_shift_); + QuicTime when_lost = it->sent_time + loss_delay; + if (time < when_lost) { + if (time >= + it->sent_time + max_rtt + (max_rtt >> (reordering_shift_ + 1))) { + ++detection_stats.sent_packets_num_borderline_time_reorderings; + } + loss_detection_timeout_ = when_lost; + if (!least_in_flight_.IsInitialized()) { + // At this point, packet_number is in flight and not detected as lost. + least_in_flight_ = packet_number; + } + break; + } + packets_lost->push_back(LostPacket(packet_number, it->bytes_sent)); + detection_stats.total_loss_detection_response_time += + DetectionResponseTime(max_rtt, it->sent_time, time); + } + if (!least_in_flight_.IsInitialized()) { + // There is no in flight packet. + least_in_flight_ = largest_newly_acked + 1; + } + + return detection_stats; +} + +QuicTime GeneralLossAlgorithm::GetLossTimeout() const { + return loss_detection_timeout_; +} + +void GeneralLossAlgorithm::SpuriousLossDetected( + const QuicUnackedPacketMap& unacked_packets, const RttStats& rtt_stats, + QuicTime ack_receive_time, QuicPacketNumber packet_number, + QuicPacketNumber previous_largest_acked) { + if (use_adaptive_time_threshold_ && reordering_shift_ > 0) { + // Increase reordering fraction such that the packet would not have been + // declared lost. + QuicTime::Delta time_needed = + ack_receive_time - + unacked_packets.GetTransmissionInfo(packet_number).sent_time; + QuicTime::Delta max_rtt = + std::max(rtt_stats.previous_srtt(), rtt_stats.latest_rtt()); + while (max_rtt + (max_rtt >> reordering_shift_) < time_needed && + reordering_shift_ > 0) { + --reordering_shift_; + } + } + + if (use_adaptive_reordering_threshold_) { + QUICHE_DCHECK_LT(packet_number, previous_largest_acked); + // Increase reordering_threshold_ such that packet_number would not have + // been declared lost. + reordering_threshold_ = std::max( + reordering_threshold_, previous_largest_acked - packet_number + 1); + } +} + +void GeneralLossAlgorithm::Initialize(PacketNumberSpace packet_number_space, + LossDetectionInterface* parent) { + parent_ = parent; + if (packet_number_space_ < NUM_PACKET_NUMBER_SPACES) { + QUIC_BUG(quic_bug_10430_2) << "Cannot switch packet_number_space"; + return; + } + + packet_number_space_ = packet_number_space; +} + +void GeneralLossAlgorithm::Reset() { + loss_detection_timeout_ = QuicTime::Zero(); + least_in_flight_.Clear(); +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/general_loss_algorithm.h b/quiche/quic/core/congestion_control/general_loss_algorithm.h new file mode 100644 index 000000000000..7e586162b623 --- /dev/null +++ b/quiche/quic/core/congestion_control/general_loss_algorithm.h @@ -0,0 +1,137 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_GENERAL_LOSS_ALGORITHM_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_GENERAL_LOSS_ALGORITHM_H_ + +#include +#include + +#include "quiche/quic/core/congestion_control/loss_detection_interface.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_unacked_packet_map.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Class which can be configured to implement's TCP's approach of detecting loss +// when 3 nacks have been received for a packet or with a time threshold. +// Also implements TCP's early retransmit(RFC5827). +class QUIC_EXPORT_PRIVATE GeneralLossAlgorithm : public LossDetectionInterface { + public: + GeneralLossAlgorithm() = default; + GeneralLossAlgorithm(const GeneralLossAlgorithm&) = delete; + GeneralLossAlgorithm& operator=(const GeneralLossAlgorithm&) = delete; + ~GeneralLossAlgorithm() override {} + + void SetFromConfig(const QuicConfig& /*config*/, + Perspective /*perspective*/) override {} + + // Uses |largest_acked| and time to decide when packets are lost. + DetectionStats DetectLosses(const QuicUnackedPacketMap& unacked_packets, + QuicTime time, const RttStats& rtt_stats, + QuicPacketNumber largest_newly_acked, + const AckedPacketVector& packets_acked, + LostPacketVector* packets_lost) override; + + // Returns a non-zero value when the early retransmit timer is active. + QuicTime GetLossTimeout() const override; + + // Called to increases time and/or packet threshold. + void SpuriousLossDetected(const QuicUnackedPacketMap& unacked_packets, + const RttStats& rtt_stats, + QuicTime ack_receive_time, + QuicPacketNumber packet_number, + QuicPacketNumber previous_largest_acked) override; + + void OnConfigNegotiated() override { + QUICHE_DCHECK(false) + << "Unexpected call to GeneralLossAlgorithm::OnConfigNegotiated"; + } + + void OnMinRttAvailable() override { + QUICHE_DCHECK(false) + << "Unexpected call to GeneralLossAlgorithm::OnMinRttAvailable"; + } + + void OnUserAgentIdKnown() override { + QUICHE_DCHECK(false) + << "Unexpected call to GeneralLossAlgorithm::OnUserAgentIdKnown"; + } + + void OnConnectionClosed() override { + QUICHE_DCHECK(false) + << "Unexpected call to GeneralLossAlgorithm::OnConnectionClosed"; + } + + void OnReorderingDetected() override { + QUICHE_DCHECK(false) + << "Unexpected call to GeneralLossAlgorithm::OnReorderingDetected"; + } + + void Initialize(PacketNumberSpace packet_number_space, + LossDetectionInterface* parent); + + void Reset(); + + QuicPacketCount reordering_threshold() const { return reordering_threshold_; } + + int reordering_shift() const { return reordering_shift_; } + + void set_reordering_shift(int reordering_shift) { + reordering_shift_ = reordering_shift; + } + + void set_reordering_threshold(QuicPacketCount reordering_threshold) { + reordering_threshold_ = reordering_threshold; + } + + bool use_adaptive_reordering_threshold() const { + return use_adaptive_reordering_threshold_; + } + + void set_use_adaptive_reordering_threshold(bool value) { + use_adaptive_reordering_threshold_ = value; + } + + bool use_adaptive_time_threshold() const { + return use_adaptive_time_threshold_; + } + + void enable_adaptive_time_threshold() { use_adaptive_time_threshold_ = true; } + + bool use_packet_threshold_for_runt_packets() const { + return use_packet_threshold_for_runt_packets_; + } + + void disable_packet_threshold_for_runt_packets() { + use_packet_threshold_for_runt_packets_ = false; + } + + private: + LossDetectionInterface* parent_ = nullptr; + QuicTime loss_detection_timeout_ = QuicTime::Zero(); + // Fraction of a max(SRTT, latest_rtt) to permit reordering before declaring + // loss. Fraction calculated by shifting max(SRTT, latest_rtt) to the right + // by reordering_shift. + int reordering_shift_ = kDefaultLossDelayShift; + // Reordering threshold for loss detection. + QuicPacketCount reordering_threshold_ = kDefaultPacketReorderingThreshold; + // If true, uses adaptive reordering threshold for loss detection. + bool use_adaptive_reordering_threshold_ = true; + // If true, uses adaptive time threshold for time based loss detection. + bool use_adaptive_time_threshold_ = false; + // If true, uses packet threshold when largest acked is a runt packet. + bool use_packet_threshold_for_runt_packets_ = true; + // The least in flight packet. Loss detection should start from this. Please + // note, least_in_flight_ could be largest packet ever sent + 1. + QuicPacketNumber least_in_flight_{1}; + PacketNumberSpace packet_number_space_ = NUM_PACKET_NUMBER_SPACES; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_GENERAL_LOSS_ALGORITHM_H_ diff --git a/quiche/quic/core/congestion_control/general_loss_algorithm_test.cc b/quiche/quic/core/congestion_control/general_loss_algorithm_test.cc new file mode 100644 index 000000000000..8fa22e0377d1 --- /dev/null +++ b/quiche/quic/core/congestion_control/general_loss_algorithm_test.cc @@ -0,0 +1,488 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/general_loss_algorithm.h" + +#include +#include + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/quic_unacked_packet_map.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { +namespace { + +// Default packet length. +const uint32_t kDefaultLength = 1000; + +class GeneralLossAlgorithmTest : public QuicTest { + protected: + GeneralLossAlgorithmTest() : unacked_packets_(Perspective::IS_CLIENT) { + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), clock_.Now()); + EXPECT_LT(0, rtt_stats_.smoothed_rtt().ToMicroseconds()); + loss_algorithm_.Initialize(HANDSHAKE_DATA, nullptr); + } + + ~GeneralLossAlgorithmTest() override {} + + void SendDataPacket(uint64_t packet_number, + QuicPacketLength encrypted_length) { + QuicStreamFrame frame; + frame.stream_id = QuicUtils::GetFirstBidirectionalStreamId( + CurrentSupportedVersions()[0].transport_version, + Perspective::IS_CLIENT); + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_1BYTE_PACKET_NUMBER, nullptr, + encrypted_length, false, false); + packet.retransmittable_frames.push_back(QuicFrame(frame)); + unacked_packets_.AddSentPacket(&packet, NOT_RETRANSMISSION, clock_.Now(), + true, true, ECN_NOT_ECT); + } + + void SendDataPacket(uint64_t packet_number) { + SendDataPacket(packet_number, kDefaultLength); + } + + void SendAckPacket(uint64_t packet_number) { + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_1BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + true, false); + unacked_packets_.AddSentPacket(&packet, NOT_RETRANSMISSION, clock_.Now(), + false, true, ECN_NOT_ECT); + } + + void VerifyLosses(uint64_t largest_newly_acked, + const AckedPacketVector& packets_acked, + const std::vector& losses_expected) { + return VerifyLosses(largest_newly_acked, packets_acked, losses_expected, + absl::nullopt, absl::nullopt); + } + + void VerifyLosses( + uint64_t largest_newly_acked, const AckedPacketVector& packets_acked, + const std::vector& losses_expected, + absl::optional max_sequence_reordering_expected, + absl::optional + num_borderline_time_reorderings_expected) { + unacked_packets_.MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(largest_newly_acked)); + LostPacketVector lost_packets; + LossDetectionInterface::DetectionStats stats = loss_algorithm_.DetectLosses( + unacked_packets_, clock_.Now(), rtt_stats_, + QuicPacketNumber(largest_newly_acked), packets_acked, &lost_packets); + if (max_sequence_reordering_expected.has_value()) { + EXPECT_EQ(stats.sent_packets_max_sequence_reordering, + max_sequence_reordering_expected.value()); + } + if (num_borderline_time_reorderings_expected.has_value()) { + EXPECT_EQ(stats.sent_packets_num_borderline_time_reorderings, + num_borderline_time_reorderings_expected.value()); + } + ASSERT_EQ(losses_expected.size(), lost_packets.size()); + for (size_t i = 0; i < losses_expected.size(); ++i) { + EXPECT_EQ(lost_packets[i].packet_number, + QuicPacketNumber(losses_expected[i])); + } + } + + QuicUnackedPacketMap unacked_packets_; + GeneralLossAlgorithm loss_algorithm_; + RttStats rtt_stats_; + MockClock clock_; +}; + +TEST_F(GeneralLossAlgorithmTest, NackRetransmit1Packet) { + const size_t kNumSentPackets = 5; + // Transmit 5 packets. + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + } + AckedPacketVector packets_acked; + // No loss on one ack. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}, 1, 0); + packets_acked.clear(); + // No loss on two acks. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(3)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(3), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(3, packets_acked, std::vector{}, 2, 0); + packets_acked.clear(); + // Loss on three acks. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(4), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(4, packets_acked, {1}, 3, 0); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +// A stretch ack is an ack that covers more than 1 packet of previously +// unacknowledged data. +TEST_F(GeneralLossAlgorithmTest, NackRetransmit1PacketWith1StretchAck) { + const size_t kNumSentPackets = 10; + // Transmit 10 packets. + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + } + AckedPacketVector packets_acked; + // Nack the first packet 3 times in a single StretchAck. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(3)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(3), kMaxOutgoingPacketSize, QuicTime::Zero())); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(4), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(4, packets_acked, {1}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +// Ack a packet 3 packets ahead, causing a retransmit. +TEST_F(GeneralLossAlgorithmTest, NackRetransmit1PacketSingleAck) { + const size_t kNumSentPackets = 10; + // Transmit 10 packets. + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + } + AckedPacketVector packets_acked; + // Nack the first packet 3 times in an AckFrame with three missing packets. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(4), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(4, packets_acked, {1}); + EXPECT_EQ(clock_.Now() + 1.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, EarlyRetransmit1Packet) { + const size_t kNumSentPackets = 2; + // Transmit 2 packets. + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + } + AckedPacketVector packets_acked; + // Early retransmit when the final packet gets acked and the first is nacked. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}); + packets_acked.clear(); + EXPECT_EQ(clock_.Now() + 1.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + + clock_.AdvanceTime(1.13 * rtt_stats_.latest_rtt()); + // If reordering_shift increases by one we should have detected a loss. + VerifyLosses(2, packets_acked, {}, /*max_sequence_reordering_expected=*/1, + /*num_borderline_time_reorderings_expected=*/1); + + clock_.AdvanceTime(0.13 * rtt_stats_.latest_rtt()); + VerifyLosses(2, packets_acked, {1}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, EarlyRetransmitAllPackets) { + const size_t kNumSentPackets = 5; + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + // Advance the time 1/4 RTT between 3 and 4. + if (i == 3) { + clock_.AdvanceTime(0.25 * rtt_stats_.smoothed_rtt()); + } + } + AckedPacketVector packets_acked; + // Early retransmit when the final packet gets acked and 1.25 RTTs have + // elapsed since the packets were sent. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(kNumSentPackets)); + packets_acked.push_back(AckedPacket(QuicPacketNumber(kNumSentPackets), + kMaxOutgoingPacketSize, + QuicTime::Zero())); + // This simulates a single ack following multiple missing packets with FACK. + VerifyLosses(kNumSentPackets, packets_acked, {1, 2}); + packets_acked.clear(); + // The time has already advanced 1/4 an RTT, so ensure the timeout is set + // 1.25 RTTs after the earliest pending packet(3), not the last(4). + EXPECT_EQ(clock_.Now() + rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + + clock_.AdvanceTime(rtt_stats_.smoothed_rtt()); + VerifyLosses(kNumSentPackets, packets_acked, {3}); + EXPECT_EQ(clock_.Now() + 0.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + clock_.AdvanceTime(0.25 * rtt_stats_.smoothed_rtt()); + VerifyLosses(kNumSentPackets, packets_acked, {4}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, DontEarlyRetransmitNeuteredPacket) { + const size_t kNumSentPackets = 2; + // Transmit 2 packets. + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + } + AckedPacketVector packets_acked; + // Neuter packet 1. + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(1)); + clock_.AdvanceTime(rtt_stats_.smoothed_rtt()); + + // Early retransmit when the final packet gets acked and the first is nacked. + unacked_packets_.MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}); + EXPECT_EQ(clock_.Now() + 0.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, EarlyRetransmitWithLargerUnackablePackets) { + // Transmit 2 data packets and one ack. + SendDataPacket(1); + SendDataPacket(2); + SendAckPacket(3); + AckedPacketVector packets_acked; + clock_.AdvanceTime(rtt_stats_.smoothed_rtt()); + + // Early retransmit when the final packet gets acked and the first is nacked. + unacked_packets_.MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}); + packets_acked.clear(); + EXPECT_EQ(clock_.Now() + 0.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + + // The packet should be lost once the loss timeout is reached. + clock_.AdvanceTime(0.25 * rtt_stats_.latest_rtt()); + VerifyLosses(2, packets_acked, {1}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, AlwaysLosePacketSent1RTTEarlier) { + // Transmit 1 packet and then wait an rtt plus 1ms. + SendDataPacket(1); + clock_.AdvanceTime(rtt_stats_.smoothed_rtt() + + QuicTime::Delta::FromMilliseconds(1)); + + // Transmit 2 packets. + SendDataPacket(2); + SendDataPacket(3); + AckedPacketVector packets_acked; + // Wait another RTT and ack 2. + clock_.AdvanceTime(rtt_stats_.smoothed_rtt()); + unacked_packets_.MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, {1}); +} + +TEST_F(GeneralLossAlgorithmTest, IncreaseTimeThresholdUponSpuriousLoss) { + loss_algorithm_.enable_adaptive_time_threshold(); + loss_algorithm_.set_reordering_shift(kDefaultLossDelayShift); + EXPECT_EQ(kDefaultLossDelayShift, loss_algorithm_.reordering_shift()); + EXPECT_TRUE(loss_algorithm_.use_adaptive_time_threshold()); + const size_t kNumSentPackets = 10; + // Transmit 2 packets at 1/10th an RTT interval. + for (size_t i = 1; i <= kNumSentPackets; ++i) { + SendDataPacket(i); + clock_.AdvanceTime(0.1 * rtt_stats_.smoothed_rtt()); + } + EXPECT_EQ(QuicTime::Zero() + rtt_stats_.smoothed_rtt(), clock_.Now()); + AckedPacketVector packets_acked; + // Expect the timer to not be set. + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); + // Packet 1 should not be lost until 1/4 RTTs pass. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}); + packets_acked.clear(); + // Expect the timer to be set to 1/4 RTT's in the future. + EXPECT_EQ(rtt_stats_.smoothed_rtt() * (1.0f / 4), + loss_algorithm_.GetLossTimeout() - clock_.Now()); + VerifyLosses(2, packets_acked, std::vector{}); + clock_.AdvanceTime(rtt_stats_.smoothed_rtt() * (1.0f / 4)); + VerifyLosses(2, packets_acked, {1}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); + // Retransmit packet 1 as 11 and 2 as 12. + SendDataPacket(11); + SendDataPacket(12); + + // Advance the time 1/4 RTT and indicate the loss was spurious. + // The new threshold should be 1/2 RTT. + clock_.AdvanceTime(rtt_stats_.smoothed_rtt() * (1.0f / 4)); + loss_algorithm_.SpuriousLossDetected(unacked_packets_, rtt_stats_, + clock_.Now(), QuicPacketNumber(1), + QuicPacketNumber(2)); + EXPECT_EQ(1, loss_algorithm_.reordering_shift()); +} + +TEST_F(GeneralLossAlgorithmTest, IncreaseReorderingThresholdUponSpuriousLoss) { + loss_algorithm_.set_use_adaptive_reordering_threshold(true); + for (size_t i = 1; i <= 4; ++i) { + SendDataPacket(i); + } + // Acking 4 causes 1 detected lost. + AckedPacketVector packets_acked; + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(4), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(4, packets_acked, std::vector{1}); + packets_acked.clear(); + + // Retransmit 1 as 5. + SendDataPacket(5); + + // Acking 1 such that it was detected lost spuriously. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(1)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(1), kMaxOutgoingPacketSize, QuicTime::Zero())); + loss_algorithm_.SpuriousLossDetected(unacked_packets_, rtt_stats_, + clock_.Now(), QuicPacketNumber(1), + QuicPacketNumber(4)); + VerifyLosses(4, packets_acked, std::vector{}); + packets_acked.clear(); + + // Verify acking 5 does not cause 2 detected lost. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(5)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(5), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(5, packets_acked, std::vector{}); + packets_acked.clear(); + + SendDataPacket(6); + + // Acking 6 will causes 2 detected lost. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(6)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(6), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(6, packets_acked, std::vector{2}); + packets_acked.clear(); + + // Retransmit 2 as 7. + SendDataPacket(7); + + // Acking 2 such that it was detected lost spuriously. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + loss_algorithm_.SpuriousLossDetected(unacked_packets_, rtt_stats_, + clock_.Now(), QuicPacketNumber(2), + QuicPacketNumber(6)); + VerifyLosses(6, packets_acked, std::vector{}); + packets_acked.clear(); + + // Acking 7 will not cause 3 as detected lost. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(7)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(7), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(7, packets_acked, std::vector{}); + packets_acked.clear(); +} + +TEST_F(GeneralLossAlgorithmTest, DefaultIetfLossDetection) { + loss_algorithm_.set_reordering_shift(kDefaultIetfLossDelayShift); + for (size_t i = 1; i <= 6; ++i) { + SendDataPacket(i); + } + // Packet threshold loss detection. + AckedPacketVector packets_acked; + // No loss on one ack. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}); + packets_acked.clear(); + // No loss on two acks. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(3)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(3), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(3, packets_acked, std::vector{}); + packets_acked.clear(); + // Loss on three acks. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(4), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(4, packets_acked, {1}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); + packets_acked.clear(); + + SendDataPacket(7); + + // Time threshold loss detection. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(6)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(6), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(6, packets_acked, std::vector{}); + packets_acked.clear(); + EXPECT_EQ(clock_.Now() + rtt_stats_.smoothed_rtt() + + (rtt_stats_.smoothed_rtt() >> 3), + loss_algorithm_.GetLossTimeout()); + clock_.AdvanceTime(rtt_stats_.smoothed_rtt() + + (rtt_stats_.smoothed_rtt() >> 3)); + VerifyLosses(6, packets_acked, {5}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, IetfLossDetectionWithOneFourthRttDelay) { + loss_algorithm_.set_reordering_shift(2); + SendDataPacket(1); + SendDataPacket(2); + + AckedPacketVector packets_acked; + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(2), kMaxOutgoingPacketSize, QuicTime::Zero())); + VerifyLosses(2, packets_acked, std::vector{}); + packets_acked.clear(); + EXPECT_EQ(clock_.Now() + rtt_stats_.smoothed_rtt() + + (rtt_stats_.smoothed_rtt() >> 2), + loss_algorithm_.GetLossTimeout()); + clock_.AdvanceTime(rtt_stats_.smoothed_rtt() + + (rtt_stats_.smoothed_rtt() >> 2)); + VerifyLosses(2, packets_acked, {1}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +TEST_F(GeneralLossAlgorithmTest, NoPacketThresholdForRuntPackets) { + loss_algorithm_.disable_packet_threshold_for_runt_packets(); + for (size_t i = 1; i <= 6; ++i) { + SendDataPacket(i); + } + // Send a small packet. + SendDataPacket(7, /*encrypted_length=*/kDefaultLength / 2); + // No packet threshold for runt packet. + AckedPacketVector packets_acked; + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(7)); + packets_acked.push_back(AckedPacket( + QuicPacketNumber(7), kMaxOutgoingPacketSize, QuicTime::Zero())); + // Verify no packet is detected lost because packet 7 is a runt. + VerifyLosses(7, packets_acked, std::vector{}); + EXPECT_EQ(clock_.Now() + rtt_stats_.smoothed_rtt() + + (rtt_stats_.smoothed_rtt() >> 2), + loss_algorithm_.GetLossTimeout()); + clock_.AdvanceTime(rtt_stats_.smoothed_rtt() + + (rtt_stats_.smoothed_rtt() >> 2)); + // Verify packets are declared lost because time threshold has passed. + VerifyLosses(7, packets_acked, {1, 2, 3, 4, 5, 6}); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/hybrid_slow_start.cc b/quiche/quic/core/congestion_control/hybrid_slow_start.cc new file mode 100644 index 000000000000..1eb99446941c --- /dev/null +++ b/quiche/quic/core/congestion_control/hybrid_slow_start.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/hybrid_slow_start.h" + +#include + +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// Note(pwestin): the magic clamping numbers come from the original code in +// tcp_cubic.c. +const int64_t kHybridStartLowWindow = 16; +// Number of delay samples for detecting the increase of delay. +const uint32_t kHybridStartMinSamples = 8; +// Exit slow start if the min rtt has increased by more than 1/8th. +const int kHybridStartDelayFactorExp = 3; // 2^3 = 8 +// The original paper specifies 2 and 8ms, but those have changed over time. +const int64_t kHybridStartDelayMinThresholdUs = 4000; +const int64_t kHybridStartDelayMaxThresholdUs = 16000; + +HybridSlowStart::HybridSlowStart() + : started_(false), + hystart_found_(NOT_FOUND), + rtt_sample_count_(0), + current_min_rtt_(QuicTime::Delta::Zero()) {} + +void HybridSlowStart::OnPacketAcked(QuicPacketNumber acked_packet_number) { + // OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end + // the round when the final packet of the burst is received and start it on + // the next incoming ack. + if (IsEndOfRound(acked_packet_number)) { + started_ = false; + } +} + +void HybridSlowStart::OnPacketSent(QuicPacketNumber packet_number) { + last_sent_packet_number_ = packet_number; +} + +void HybridSlowStart::Restart() { + started_ = false; + hystart_found_ = NOT_FOUND; +} + +void HybridSlowStart::StartReceiveRound(QuicPacketNumber last_sent) { + QUIC_DVLOG(1) << "Reset hybrid slow start @" << last_sent; + end_packet_number_ = last_sent; + current_min_rtt_ = QuicTime::Delta::Zero(); + rtt_sample_count_ = 0; + started_ = true; +} + +bool HybridSlowStart::IsEndOfRound(QuicPacketNumber ack) const { + return !end_packet_number_.IsInitialized() || end_packet_number_ <= ack; +} + +bool HybridSlowStart::ShouldExitSlowStart(QuicTime::Delta latest_rtt, + QuicTime::Delta min_rtt, + QuicPacketCount congestion_window) { + if (!started_) { + // Time to start the hybrid slow start. + StartReceiveRound(last_sent_packet_number_); + } + if (hystart_found_ != NOT_FOUND) { + return true; + } + // Second detection parameter - delay increase detection. + // Compare the minimum delay (current_min_rtt_) of the current + // burst of packets relative to the minimum delay during the session. + // Note: we only look at the first few(8) packets in each burst, since we + // only want to compare the lowest RTT of the burst relative to previous + // bursts. + rtt_sample_count_++; + if (rtt_sample_count_ <= kHybridStartMinSamples) { + if (current_min_rtt_.IsZero() || current_min_rtt_ > latest_rtt) { + current_min_rtt_ = latest_rtt; + } + } + // We only need to check this once per round. + if (rtt_sample_count_ == kHybridStartMinSamples) { + // Divide min_rtt by 8 to get a rtt increase threshold for exiting. + int64_t min_rtt_increase_threshold_us = + min_rtt.ToMicroseconds() >> kHybridStartDelayFactorExp; + // Ensure the rtt threshold is never less than 2ms or more than 16ms. + min_rtt_increase_threshold_us = std::min(min_rtt_increase_threshold_us, + kHybridStartDelayMaxThresholdUs); + QuicTime::Delta min_rtt_increase_threshold = + QuicTime::Delta::FromMicroseconds(std::max( + min_rtt_increase_threshold_us, kHybridStartDelayMinThresholdUs)); + + if (current_min_rtt_ > min_rtt + min_rtt_increase_threshold) { + hystart_found_ = DELAY; + } + } + // Exit from slow start if the cwnd is greater than 16 and + // increasing delay is found. + return congestion_window >= kHybridStartLowWindow && + hystart_found_ != NOT_FOUND; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/hybrid_slow_start.h b/quiche/quic/core/congestion_control/hybrid_slow_start.h new file mode 100644 index 000000000000..73c7670b6716 --- /dev/null +++ b/quiche/quic/core/congestion_control/hybrid_slow_start.h @@ -0,0 +1,82 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This class is a helper class to TcpCubicSender. +// Slow start is the initial startup phase of TCP, it lasts until first packet +// loss. This class implements hybrid slow start of the TCP cubic send side +// congestion algorithm. The key feaure of hybrid slow start is that it tries to +// avoid running into the wall too hard during the slow start phase, which +// the traditional TCP implementation does. +// This does not implement ack train detection because it interacts poorly with +// pacing. +// http://netsrv.csc.ncsu.edu/export/hybridstart_pfldnet08.pdf +// http://research.csc.ncsu.edu/netsrv/sites/default/files/hystart_techreport_2008.pdf + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_HYBRID_SLOW_START_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_HYBRID_SLOW_START_H_ + +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE HybridSlowStart { + public: + HybridSlowStart(); + HybridSlowStart(const HybridSlowStart&) = delete; + HybridSlowStart& operator=(const HybridSlowStart&) = delete; + + void OnPacketAcked(QuicPacketNumber acked_packet_number); + + void OnPacketSent(QuicPacketNumber packet_number); + + // ShouldExitSlowStart should be called on every new ack frame, since a new + // RTT measurement can be made then. + // rtt: the RTT for this ack packet. + // min_rtt: is the lowest delay (RTT) we have seen during the session. + // congestion_window: the congestion window in packets. + bool ShouldExitSlowStart(QuicTime::Delta rtt, QuicTime::Delta min_rtt, + QuicPacketCount congestion_window); + + // Start a new slow start phase. + void Restart(); + + // TODO(ianswett): The following methods should be private, but that requires + // a follow up CL to update the unit test. + // Returns true if this ack the last packet number of our current slow start + // round. + // Call Reset if this returns true. + bool IsEndOfRound(QuicPacketNumber ack) const; + + // Call for the start of each receive round (burst) in the slow start phase. + void StartReceiveRound(QuicPacketNumber last_sent); + + // Whether slow start has started. + bool started() const { return started_; } + + private: + // Whether a condition for exiting slow start has been found. + enum HystartState { + NOT_FOUND, + DELAY, // Too much increase in the round's min_rtt was observed. + }; + + // Whether the hybrid slow start has been started. + bool started_; + HystartState hystart_found_; + // Last packet number sent which was CWND limited. + QuicPacketNumber last_sent_packet_number_; + + // Variables for tracking acks received during a slow start round. + QuicPacketNumber end_packet_number_; // End of the receive round. + uint32_t rtt_sample_count_; // Number of rtt samples in the current round. + QuicTime::Delta current_min_rtt_; // The minimum rtt of current round. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_HYBRID_SLOW_START_H_ diff --git a/quiche/quic/core/congestion_control/hybrid_slow_start_test.cc b/quiche/quic/core/congestion_control/hybrid_slow_start_test.cc new file mode 100644 index 000000000000..94654fc6b1bc --- /dev/null +++ b/quiche/quic/core/congestion_control/hybrid_slow_start_test.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/hybrid_slow_start.h" + +#include +#include + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class HybridSlowStartTest : public QuicTest { + protected: + HybridSlowStartTest() + : one_ms_(QuicTime::Delta::FromMilliseconds(1)), + rtt_(QuicTime::Delta::FromMilliseconds(60)) {} + void SetUp() override { slow_start_ = std::make_unique(); } + const QuicTime::Delta one_ms_; + const QuicTime::Delta rtt_; + std::unique_ptr slow_start_; +}; + +TEST_F(HybridSlowStartTest, Simple) { + QuicPacketNumber packet_number(1); + QuicPacketNumber end_packet_number(3); + slow_start_->StartReceiveRound(end_packet_number); + + EXPECT_FALSE(slow_start_->IsEndOfRound(packet_number++)); + + // Test duplicates. + EXPECT_FALSE(slow_start_->IsEndOfRound(packet_number)); + + EXPECT_FALSE(slow_start_->IsEndOfRound(packet_number++)); + EXPECT_TRUE(slow_start_->IsEndOfRound(packet_number++)); + + // Test without a new registered end_packet_number; + EXPECT_TRUE(slow_start_->IsEndOfRound(packet_number++)); + + end_packet_number = QuicPacketNumber(20); + slow_start_->StartReceiveRound(end_packet_number); + while (packet_number < end_packet_number) { + EXPECT_FALSE(slow_start_->IsEndOfRound(packet_number++)); + } + EXPECT_TRUE(slow_start_->IsEndOfRound(packet_number++)); +} + +TEST_F(HybridSlowStartTest, Delay) { + // We expect to detect the increase at +1/8 of the RTT; hence at a typical + // RTT of 60ms the detection will happen at 67.5 ms. + const int kHybridStartMinSamples = 8; // Number of acks required to trigger. + + QuicPacketNumber end_packet_number(1); + slow_start_->StartReceiveRound(end_packet_number++); + + // Will not trigger since our lowest RTT in our burst is the same as the long + // term RTT provided. + for (int n = 0; n < kHybridStartMinSamples; ++n) { + EXPECT_FALSE(slow_start_->ShouldExitSlowStart( + rtt_ + QuicTime::Delta::FromMilliseconds(n), rtt_, 100)); + } + slow_start_->StartReceiveRound(end_packet_number++); + for (int n = 1; n < kHybridStartMinSamples; ++n) { + EXPECT_FALSE(slow_start_->ShouldExitSlowStart( + rtt_ + QuicTime::Delta::FromMilliseconds(n + 10), rtt_, 100)); + } + // Expect to trigger since all packets in this burst was above the long term + // RTT provided. + EXPECT_TRUE(slow_start_->ShouldExitSlowStart( + rtt_ + QuicTime::Delta::FromMilliseconds(10), rtt_, 100)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/loss_detection_interface.h b/quiche/quic/core/congestion_control/loss_detection_interface.h new file mode 100644 index 000000000000..c81702ee2cf0 --- /dev/null +++ b/quiche/quic/core/congestion_control/loss_detection_interface.h @@ -0,0 +1,71 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// The pure virtual class for send side loss detection algorithm. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_LOSS_DETECTION_INTERFACE_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_LOSS_DETECTION_INTERFACE_H_ + +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QuicUnackedPacketMap; +class RttStats; + +class QUIC_EXPORT_PRIVATE LossDetectionInterface { + public: + virtual ~LossDetectionInterface() {} + + virtual void SetFromConfig(const QuicConfig& config, + Perspective perspective) = 0; + + struct QUIC_NO_EXPORT DetectionStats { + // Maximum sequence reordering observed in newly acked packets. + QuicPacketCount sent_packets_max_sequence_reordering = 0; + QuicPacketCount sent_packets_num_borderline_time_reorderings = 0; + // Total detection response time for lost packets from this detection. + // See QuicConnectionStats for the definition of detection response time. + float total_loss_detection_response_time = 0.0; + }; + + // Called when a new ack arrives or the loss alarm fires. + virtual DetectionStats DetectLosses( + const QuicUnackedPacketMap& unacked_packets, QuicTime time, + const RttStats& rtt_stats, QuicPacketNumber largest_newly_acked, + const AckedPacketVector& packets_acked, + LostPacketVector* packets_lost) = 0; + + // Get the time the LossDetectionAlgorithm wants to re-evaluate losses. + // Returns QuicTime::Zero if no alarm needs to be set. + virtual QuicTime GetLossTimeout() const = 0; + + // Called when |packet_number| was detected lost but gets acked later. + virtual void SpuriousLossDetected( + const QuicUnackedPacketMap& unacked_packets, const RttStats& rtt_stats, + QuicTime ack_receive_time, QuicPacketNumber packet_number, + QuicPacketNumber previous_largest_acked) = 0; + + virtual void OnConfigNegotiated() = 0; + + virtual void OnMinRttAvailable() = 0; + + virtual void OnUserAgentIdKnown() = 0; + + virtual void OnConnectionClosed() = 0; + + // Called when a reordering is detected by the loss algorithm, but _before_ + // the reordering_shift and reordering_threshold are consulted to see whether + // it is a loss. + virtual void OnReorderingDetected() = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_LOSS_DETECTION_INTERFACE_H_ diff --git a/quiche/quic/core/congestion_control/pacing_sender.cc b/quiche/quic/core/congestion_control/pacing_sender.cc new file mode 100644 index 000000000000..b4b6105443cb --- /dev/null +++ b/quiche/quic/core/congestion_control/pacing_sender.cc @@ -0,0 +1,167 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/pacing_sender.h" + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace { + +// Configured maximum size of the burst coming out of quiescence. The burst +// is never larger than the current CWND in packets. +static const uint32_t kInitialUnpacedBurst = 10; + +} // namespace + +PacingSender::PacingSender() + : sender_(nullptr), + max_pacing_rate_(QuicBandwidth::Zero()), + burst_tokens_(kInitialUnpacedBurst), + ideal_next_packet_send_time_(QuicTime::Zero()), + initial_burst_size_(kInitialUnpacedBurst), + lumpy_tokens_(0), + alarm_granularity_(kAlarmGranularity), + pacing_limited_(false) {} + +PacingSender::~PacingSender() {} + +void PacingSender::set_sender(SendAlgorithmInterface* sender) { + QUICHE_DCHECK(sender != nullptr); + sender_ = sender; +} + +void PacingSender::OnCongestionEvent(bool rtt_updated, + QuicByteCount bytes_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount num_ect, + QuicPacketCount num_ce) { + QUICHE_DCHECK(sender_ != nullptr); + if (!lost_packets.empty()) { + // Clear any burst tokens when entering recovery. + burst_tokens_ = 0; + } + sender_->OnCongestionEvent(rtt_updated, bytes_in_flight, event_time, + acked_packets, lost_packets, num_ect, num_ce); +} + +void PacingSender::OnPacketSent( + QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData has_retransmittable_data) { + QUICHE_DCHECK(sender_ != nullptr); + sender_->OnPacketSent(sent_time, bytes_in_flight, packet_number, bytes, + has_retransmittable_data); + if (has_retransmittable_data != HAS_RETRANSMITTABLE_DATA) { + return; + } + // If in recovery, the connection is not coming out of quiescence. + if (bytes_in_flight == 0 && !sender_->InRecovery()) { + // Add more burst tokens anytime the connection is leaving quiescence, but + // limit it to the equivalent of a single bulk write, not exceeding the + // current CWND in packets. + burst_tokens_ = std::min( + initial_burst_size_, + static_cast(sender_->GetCongestionWindow() / kDefaultTCPMSS)); + } + if (burst_tokens_ > 0) { + --burst_tokens_; + ideal_next_packet_send_time_ = QuicTime::Zero(); + pacing_limited_ = false; + return; + } + // The next packet should be sent as soon as the current packet has been + // transferred. PacingRate is based on bytes in flight including this packet. + QuicTime::Delta delay = + PacingRate(bytes_in_flight + bytes).TransferTime(bytes); + if (!pacing_limited_ || lumpy_tokens_ == 0) { + // Reset lumpy_tokens_ if either application or cwnd throttles sending or + // token runs out. + lumpy_tokens_ = std::max( + 1u, std::min(static_cast(GetQuicFlag(quic_lumpy_pacing_size)), + static_cast( + (sender_->GetCongestionWindow() * + GetQuicFlag(quic_lumpy_pacing_cwnd_fraction)) / + kDefaultTCPMSS))); + if (sender_->BandwidthEstimate() < + QuicBandwidth::FromKBitsPerSecond( + GetQuicFlag(quic_lumpy_pacing_min_bandwidth_kbps))) { + // Below 1.2Mbps, send 1 packet at once, because one full-sized packet + // is about 10ms of queueing. + lumpy_tokens_ = 1u; + } + if ((bytes_in_flight + bytes) >= sender_->GetCongestionWindow()) { + // Don't add lumpy_tokens if the congestion controller is CWND limited. + lumpy_tokens_ = 1u; + } + } + --lumpy_tokens_; + if (pacing_limited_) { + // Make up for lost time since pacing throttles the sending. + ideal_next_packet_send_time_ = ideal_next_packet_send_time_ + delay; + } else { + ideal_next_packet_send_time_ = + std::max(ideal_next_packet_send_time_ + delay, sent_time + delay); + } + // Stop making up for lost time if underlying sender prevents sending. + pacing_limited_ = sender_->CanSend(bytes_in_flight + bytes); +} + +void PacingSender::OnApplicationLimited() { + // The send is application limited, stop making up for lost time. + pacing_limited_ = false; +} + +void PacingSender::SetBurstTokens(uint32_t burst_tokens) { + initial_burst_size_ = burst_tokens; + burst_tokens_ = std::min( + initial_burst_size_, + static_cast(sender_->GetCongestionWindow() / kDefaultTCPMSS)); +} + +QuicTime::Delta PacingSender::TimeUntilSend( + QuicTime now, QuicByteCount bytes_in_flight) const { + QUICHE_DCHECK(sender_ != nullptr); + + if (!sender_->CanSend(bytes_in_flight)) { + // The underlying sender prevents sending. + return QuicTime::Delta::Infinite(); + } + + if (burst_tokens_ > 0 || bytes_in_flight == 0 || lumpy_tokens_ > 0) { + // Don't pace if we have burst tokens available or leaving quiescence. + QUIC_DVLOG(1) << "Sending packet now. burst_tokens:" << burst_tokens_ + << ", bytes_in_flight:" << bytes_in_flight + << ", lumpy_tokens:" << lumpy_tokens_; + return QuicTime::Delta::Zero(); + } + + // If the next send time is within the alarm granularity, send immediately. + if (ideal_next_packet_send_time_ > now + alarm_granularity_) { + QUIC_DVLOG(1) << "Delaying packet: " + << (ideal_next_packet_send_time_ - now).ToMicroseconds(); + return ideal_next_packet_send_time_ - now; + } + + QUIC_DVLOG(1) << "Sending packet now. ideal_next_packet_send_time: " + << ideal_next_packet_send_time_ << ", now: " << now; + return QuicTime::Delta::Zero(); +} + +QuicBandwidth PacingSender::PacingRate(QuicByteCount bytes_in_flight) const { + QUICHE_DCHECK(sender_ != nullptr); + if (!max_pacing_rate_.IsZero()) { + return QuicBandwidth::FromBitsPerSecond( + std::min(max_pacing_rate_.ToBitsPerSecond(), + sender_->PacingRate(bytes_in_flight).ToBitsPerSecond())); + } + return sender_->PacingRate(bytes_in_flight); +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/pacing_sender.h b/quiche/quic/core/congestion_control/pacing_sender.h new file mode 100644 index 000000000000..0e3de0098ff0 --- /dev/null +++ b/quiche/quic/core/congestion_control/pacing_sender.h @@ -0,0 +1,114 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A send algorithm that adds pacing on top of an another send algorithm. +// It uses the underlying sender's pacing rate to schedule packets. +// It also takes into consideration the expected granularity of the underlying +// alarm to ensure that alarms are not set too aggressively, and err towards +// sending packets too early instead of too late. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_PACING_SENDER_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_PACING_SENDER_H_ + +#include +#include +#include + +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicSentPacketManagerPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE PacingSender { + public: + PacingSender(); + PacingSender(const PacingSender&) = delete; + PacingSender& operator=(const PacingSender&) = delete; + ~PacingSender(); + + // Sets the underlying sender. Does not take ownership of |sender|. |sender| + // must not be null. This must be called before any of the + // SendAlgorithmInterface wrapper methods are called. + void set_sender(SendAlgorithmInterface* sender); + + void set_max_pacing_rate(QuicBandwidth max_pacing_rate) { + max_pacing_rate_ = max_pacing_rate; + } + + void set_alarm_granularity(QuicTime::Delta alarm_granularity) { + alarm_granularity_ = alarm_granularity; + } + + QuicBandwidth max_pacing_rate() const { return max_pacing_rate_; } + + void OnCongestionEvent(bool rtt_updated, QuicByteCount bytes_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount num_ect, QuicPacketCount num_ce); + + void OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData has_retransmittable_data); + + // Called when application throttles the sending, so that pacing sender stops + // making up for lost time. + void OnApplicationLimited(); + + // Set burst_tokens_ and initial_burst_size_. + void SetBurstTokens(uint32_t burst_tokens); + + QuicTime::Delta TimeUntilSend(QuicTime now, + QuicByteCount bytes_in_flight) const; + + QuicBandwidth PacingRate(QuicByteCount bytes_in_flight) const; + + NextReleaseTimeResult GetNextReleaseTime() const { + bool allow_burst = (burst_tokens_ > 0 || lumpy_tokens_ > 0); + return {ideal_next_packet_send_time_, allow_burst}; + } + + uint32_t initial_burst_size() const { return initial_burst_size_; } + + protected: + uint32_t lumpy_tokens() const { return lumpy_tokens_; } + + private: + friend class test::QuicSentPacketManagerPeer; + + // Underlying sender. Not owned. + SendAlgorithmInterface* sender_; + // If not QuicBandidth::Zero, the maximum rate the PacingSender will use. + QuicBandwidth max_pacing_rate_; + + // Number of unpaced packets to be sent before packets are delayed. + uint32_t burst_tokens_; + QuicTime ideal_next_packet_send_time_; // When can the next packet be sent. + uint32_t initial_burst_size_; + + // Number of unpaced packets to be sent before packets are delayed. This token + // is consumed after burst_tokens_ ran out. + uint32_t lumpy_tokens_; + + // If the next send time is within alarm_granularity_, send immediately. + // TODO(fayang): Remove alarm_granularity_ when deprecating + // quic_offload_pacing_to_usps2 flag. + QuicTime::Delta alarm_granularity_; + + // Indicates whether pacing throttles the sending. If true, make up for lost + // time. + bool pacing_limited_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_PACING_SENDER_H_ diff --git a/quiche/quic/core/congestion_control/pacing_sender_test.cc b/quiche/quic/core/congestion_control/pacing_sender_test.cc new file mode 100644 index 000000000000..108b12e761a4 --- /dev/null +++ b/quiche/quic/core/congestion_control/pacing_sender_test.cc @@ -0,0 +1,585 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/pacing_sender.h" + +#include +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::AtMost; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { + +const QuicByteCount kBytesInFlight = 1024; +const int kInitialBurstPackets = 10; + +class TestPacingSender : public PacingSender { + public: + using PacingSender::lumpy_tokens; + using PacingSender::PacingSender; + + QuicTime ideal_next_packet_send_time() const { + return GetNextReleaseTime().release_time; + } +}; + +class PacingSenderTest : public QuicTest { + protected: + PacingSenderTest() + : zero_time_(QuicTime::Delta::Zero()), + infinite_time_(QuicTime::Delta::Infinite()), + packet_number_(1), + mock_sender_(new StrictMock()), + pacing_sender_(new TestPacingSender) { + pacing_sender_->set_sender(mock_sender_.get()); + // Pick arbitrary time. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(9)); + } + + ~PacingSenderTest() override {} + + void InitPacingRate(QuicPacketCount burst_size, QuicBandwidth bandwidth) { + mock_sender_ = std::make_unique>(); + pacing_sender_ = std::make_unique(); + pacing_sender_->set_sender(mock_sender_.get()); + EXPECT_CALL(*mock_sender_, PacingRate(_)).WillRepeatedly(Return(bandwidth)); + EXPECT_CALL(*mock_sender_, BandwidthEstimate()) + .WillRepeatedly(Return(bandwidth)); + if (burst_size == 0) { + EXPECT_CALL(*mock_sender_, OnCongestionEvent(_, _, _, _, _, _, _)); + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(1), kMaxOutgoingPacketSize)); + AckedPacketVector empty; + pacing_sender_->OnCongestionEvent(true, 1234, clock_.Now(), empty, + lost_packets, 0, 0); + } else if (burst_size != kInitialBurstPackets) { + QUIC_LOG(FATAL) << "Unsupported burst_size " << burst_size + << " specificied, only 0 and " << kInitialBurstPackets + << " are supported."; + } + } + + void CheckPacketIsSentImmediately(HasRetransmittableData retransmittable_data, + QuicByteCount prior_in_flight, + bool in_recovery, QuicPacketCount cwnd) { + // In order for the packet to be sendable, the underlying sender must + // permit it to be sent immediately. + for (int i = 0; i < 2; ++i) { + EXPECT_CALL(*mock_sender_, CanSend(prior_in_flight)) + .WillOnce(Return(true)); + // Verify that the packet can be sent immediately. + EXPECT_EQ(zero_time_, + pacing_sender_->TimeUntilSend(clock_.Now(), prior_in_flight)); + } + + // Actually send the packet. + if (prior_in_flight == 0) { + EXPECT_CALL(*mock_sender_, InRecovery()).WillOnce(Return(in_recovery)); + } + EXPECT_CALL(*mock_sender_, + OnPacketSent(clock_.Now(), prior_in_flight, packet_number_, + kMaxOutgoingPacketSize, retransmittable_data)); + EXPECT_CALL(*mock_sender_, GetCongestionWindow()) + .WillRepeatedly(Return(cwnd * kDefaultTCPMSS)); + EXPECT_CALL(*mock_sender_, + CanSend(prior_in_flight + kMaxOutgoingPacketSize)) + .Times(AtMost(1)) + .WillRepeatedly(Return((prior_in_flight + kMaxOutgoingPacketSize) < + (cwnd * kDefaultTCPMSS))); + pacing_sender_->OnPacketSent(clock_.Now(), prior_in_flight, + packet_number_++, kMaxOutgoingPacketSize, + retransmittable_data); + } + + void CheckPacketIsSentImmediately() { + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, kBytesInFlight, + false, 10); + } + + void CheckPacketIsDelayed(QuicTime::Delta delay) { + // In order for the packet to be sendable, the underlying sender must + // permit it to be sent immediately. + for (int i = 0; i < 2; ++i) { + EXPECT_CALL(*mock_sender_, CanSend(kBytesInFlight)) + .WillOnce(Return(true)); + // Verify that the packet is delayed. + EXPECT_EQ(delay.ToMicroseconds(), + pacing_sender_->TimeUntilSend(clock_.Now(), kBytesInFlight) + .ToMicroseconds()); + } + } + + void UpdateRtt() { + EXPECT_CALL(*mock_sender_, + OnCongestionEvent(true, kBytesInFlight, _, _, _, _, _)); + AckedPacketVector empty_acked; + LostPacketVector empty_lost; + pacing_sender_->OnCongestionEvent(true, kBytesInFlight, clock_.Now(), + empty_acked, empty_lost, 0, 0); + } + + void OnApplicationLimited() { pacing_sender_->OnApplicationLimited(); } + + const QuicTime::Delta zero_time_; + const QuicTime::Delta infinite_time_; + MockClock clock_; + QuicPacketNumber packet_number_; + std::unique_ptr> mock_sender_; + std::unique_ptr pacing_sender_; +}; + +TEST_F(PacingSenderTest, NoSend) { + for (int i = 0; i < 2; ++i) { + EXPECT_CALL(*mock_sender_, CanSend(kBytesInFlight)).WillOnce(Return(false)); + EXPECT_EQ(infinite_time_, + pacing_sender_->TimeUntilSend(clock_.Now(), kBytesInFlight)); + } +} + +TEST_F(PacingSenderTest, SendNow) { + for (int i = 0; i < 2; ++i) { + EXPECT_CALL(*mock_sender_, CanSend(kBytesInFlight)).WillOnce(Return(true)); + EXPECT_EQ(zero_time_, + pacing_sender_->TimeUntilSend(clock_.Now(), kBytesInFlight)); + } +} + +TEST_F(PacingSenderTest, VariousSending) { + // Configure pacing rate of 1 packet per 1 ms, no initial burst. + InitPacingRate( + 0, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + + // Now update the RTT and verify that packets are actually paced. + UpdateRtt(); + + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + + // The first packet was a "make up", then we sent two packets "into the + // future", so the delay should be 2. + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + // Wake up on time. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(2)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + // Wake up late. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(4)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + // Wake up really late. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(8)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + // Wake up really late again, but application pause partway through. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(8)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + OnApplicationLimited(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + // Wake up early, but after enough time has passed to permit a send. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + CheckPacketIsSentImmediately(); +} + +TEST_F(PacingSenderTest, InitialBurst) { + // Configure pacing rate of 1 packet per 1 ms. + InitPacingRate( + 10, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + + // Update the RTT and verify that the first 10 packets aren't paced. + UpdateRtt(); + + // Send 10 packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + + // The first packet was a "make up", then we sent two packets "into the + // future", so the delay should be 2ms. + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + CheckPacketIsSentImmediately(); + + // Next time TimeUntilSend is called with no bytes in flight, pacing should + // allow a packet to be sent, and when it's sent, the tokens are refilled. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, 10); + for (int i = 0; i < kInitialBurstPackets - 1; ++i) { + CheckPacketIsSentImmediately(); + } + + // The first packet was a "make up", then we sent two packets "into the + // future", so the delay should be 2ms. + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(PacingSenderTest, InitialBurstNoRttMeasurement) { + // Configure pacing rate of 1 packet per 1 ms. + InitPacingRate( + 10, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + + // Send 10 packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + + // The first packet was a "make up", then we sent two packets "into the + // future", so the delay should be 2ms. + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + CheckPacketIsSentImmediately(); + + // Next time TimeUntilSend is called with no bytes in flight, the tokens + // should be refilled and there should be no delay. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, 10); + // Send 10 packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets - 1; ++i) { + CheckPacketIsSentImmediately(); + } + + // The first packet was a "make up", then we sent two packets "into the + // future", so the delay should be 2ms. + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(PacingSenderTest, FastSending) { + // Ensure the pacing sender paces, even when the inter-packet spacing(0.5ms) + // is less than the pacing granularity(1ms). + InitPacingRate(10, QuicBandwidth::FromBytesAndTimeDelta( + 2 * kMaxOutgoingPacketSize, + QuicTime::Delta::FromMilliseconds(1))); + // Update the RTT and verify that the first 10 packets aren't paced. + UpdateRtt(); + + // Send 10 packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + + CheckPacketIsSentImmediately(); // Make up + CheckPacketIsSentImmediately(); // Lumpy token + CheckPacketIsSentImmediately(); // "In the future" but within granularity. + CheckPacketIsSentImmediately(); // Lumpy token + CheckPacketIsDelayed(QuicTime::Delta::FromMicroseconds(2000)); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + CheckPacketIsSentImmediately(); + + // Next time TimeUntilSend is called with no bytes in flight, the tokens + // should be refilled and there should be no delay. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, 10); + for (int i = 0; i < kInitialBurstPackets - 1; ++i) { + CheckPacketIsSentImmediately(); + } + + // The first packet was a "make up", then we sent two packets "into the + // future", so the delay should be 1.5ms. + CheckPacketIsSentImmediately(); // Make up + CheckPacketIsSentImmediately(); // Lumpy token + CheckPacketIsSentImmediately(); // "In the future" but within granularity. + CheckPacketIsSentImmediately(); // Lumpy token + CheckPacketIsDelayed(QuicTime::Delta::FromMicroseconds(2000)); +} + +TEST_F(PacingSenderTest, NoBurstEnteringRecovery) { + // Configure pacing rate of 1 packet per 1 ms with no burst tokens. + InitPacingRate( + 0, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + // Sending a packet will set burst tokens. + CheckPacketIsSentImmediately(); + + // Losing a packet will set clear burst tokens. + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(1), kMaxOutgoingPacketSize)); + AckedPacketVector empty_acked; + EXPECT_CALL(*mock_sender_, OnCongestionEvent(true, kMaxOutgoingPacketSize, _, + testing::IsEmpty(), _, _, _)); + pacing_sender_->OnCongestionEvent(true, kMaxOutgoingPacketSize, clock_.Now(), + empty_acked, lost_packets, 0, 0); + // One packet is sent immediately, because of 1ms pacing granularity. + CheckPacketIsSentImmediately(); + // Ensure packets are immediately paced. + EXPECT_CALL(*mock_sender_, CanSend(kMaxOutgoingPacketSize)) + .WillOnce(Return(true)); + // Verify the next packet is paced and delayed 2ms due to granularity. + EXPECT_EQ( + QuicTime::Delta::FromMilliseconds(2), + pacing_sender_->TimeUntilSend(clock_.Now(), kMaxOutgoingPacketSize)); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(PacingSenderTest, NoBurstInRecovery) { + // Configure pacing rate of 1 packet per 1 ms with no burst tokens. + InitPacingRate( + 0, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + + UpdateRtt(); + + // Ensure only one packet is sent immediately and the rest are paced. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, true, 10); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(PacingSenderTest, CwndLimited) { + // Configure pacing rate of 1 packet per 1 ms, no initial burst. + InitPacingRate( + 0, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + + UpdateRtt(); + + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + // Packet 3 will be delayed 2ms. + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); + + // Wake up on time. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(2)); + // After sending packet 3, cwnd is limited. + // This test is slightly odd because bytes_in_flight is calculated using + // kMaxOutgoingPacketSize and CWND is calculated using kDefaultTCPMSS, + // which is 8 bytes larger, so 3 packets can be sent for a CWND of 2. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 2 * kMaxOutgoingPacketSize, false, 2); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + // Verify pacing sender stops making up for lost time after sending packet 3. + // Packet 6 will be delayed 2ms. + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(PacingSenderTest, LumpyPacingWithInitialBurstToken) { + // Set lumpy size to be 3, and cwnd faction to 0.5 + SetQuicFlag(quic_lumpy_pacing_size, 3); + SetQuicFlag(quic_lumpy_pacing_cwnd_fraction, 0.5f); + // Configure pacing rate of 1 packet per 1 ms. + InitPacingRate( + 10, QuicBandwidth::FromBytesAndTimeDelta( + kMaxOutgoingPacketSize, QuicTime::Delta::FromMilliseconds(1))); + UpdateRtt(); + + // Send 10 packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + // Packet 14 will be delayed 3ms. + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(3)); + + // Wake up on time. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + // Packet 17 will be delayed 3ms. + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(3)); + + // Application throttles sending. + OnApplicationLimited(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + CheckPacketIsSentImmediately(); + // Packet 20 will be delayed 3ms. + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(3)); + + // Wake up on time. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3)); + CheckPacketIsSentImmediately(); + // After sending packet 21, cwnd is limited. + // This test is slightly odd because bytes_in_flight is calculated using + // kMaxOutgoingPacketSize and CWND is calculated using kDefaultTCPMSS, + // which is 8 bytes larger, so 21 packets can be sent for a CWND of 20. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 20 * kMaxOutgoingPacketSize, false, 20); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + // Suppose cwnd size is 5, so that lumpy size becomes 2. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, kBytesInFlight, false, + 5); + CheckPacketIsSentImmediately(); + // Packet 24 will be delayed 2ms. + CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(PacingSenderTest, NoLumpyPacingForLowBandwidthFlows) { + // Set lumpy size to be 3, and cwnd fraction to 0.5 + SetQuicFlag(quic_lumpy_pacing_size, 3); + SetQuicFlag(quic_lumpy_pacing_cwnd_fraction, 0.5f); + + // Configure pacing rate of 1 packet per 100 ms. + QuicTime::Delta inter_packet_delay = QuicTime::Delta::FromMilliseconds(100); + InitPacingRate(kInitialBurstPackets, + QuicBandwidth::FromBytesAndTimeDelta(kMaxOutgoingPacketSize, + inter_packet_delay)); + UpdateRtt(); + + // Send kInitialBurstPackets packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + + // The first packet after burst token exhausted is also sent immediately, + // because ideal_next_packet_send_time has not been set yet. + CheckPacketIsSentImmediately(); + + for (int i = 0; i < 200; ++i) { + CheckPacketIsDelayed(inter_packet_delay); + } +} + +// Regression test for b/184471302 to ensure that ACKs received back-to-back +// don't cause bursts in sending. +TEST_F(PacingSenderTest, NoBurstsForLumpyPacingWithAckAggregation) { + // Configure pacing rate of 1 packet per millisecond. + QuicTime::Delta inter_packet_delay = QuicTime::Delta::FromMilliseconds(1); + InitPacingRate(kInitialBurstPackets, + QuicBandwidth::FromBytesAndTimeDelta(kMaxOutgoingPacketSize, + inter_packet_delay)); + UpdateRtt(); + + // Send kInitialBurstPackets packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + // The last packet of the burst causes the sender to be CWND limited. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 10 * kMaxOutgoingPacketSize, false, 10); + + // The last sent packet made the connection CWND limited, so no lumpy tokens + // should be available. + EXPECT_EQ(0u, pacing_sender_->lumpy_tokens()); + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 10 * kMaxOutgoingPacketSize, false, 10); + EXPECT_EQ(0u, pacing_sender_->lumpy_tokens()); + CheckPacketIsDelayed(2 * inter_packet_delay); +} + +TEST_F(PacingSenderTest, IdealNextPacketSendTimeWithLumpyPacing) { + // Set lumpy size to be 3, and cwnd faction to 0.5 + SetQuicFlag(quic_lumpy_pacing_size, 3); + SetQuicFlag(quic_lumpy_pacing_cwnd_fraction, 0.5f); + + // Configure pacing rate of 1 packet per millisecond. + QuicTime::Delta inter_packet_delay = QuicTime::Delta::FromMilliseconds(1); + InitPacingRate(kInitialBurstPackets, + QuicBandwidth::FromBytesAndTimeDelta(kMaxOutgoingPacketSize, + inter_packet_delay)); + + // Send kInitialBurstPackets packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 2u); + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + 2 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 1u); + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + 3 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 0u); + + CheckPacketIsDelayed(3 * inter_packet_delay); + + // Wake up on time. + clock_.AdvanceTime(3 * inter_packet_delay); + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 2u); + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + 2 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 1u); + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + 3 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 0u); + + CheckPacketIsDelayed(3 * inter_packet_delay); + + // Wake up late. + clock_.AdvanceTime(4.5 * inter_packet_delay); + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() - 0.5 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 2u); + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + 0.5 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 1u); + + CheckPacketIsSentImmediately(); + EXPECT_EQ(pacing_sender_->ideal_next_packet_send_time(), + clock_.Now() + 1.5 * inter_packet_delay); + EXPECT_EQ(pacing_sender_->lumpy_tokens(), 0u); + + CheckPacketIsDelayed(1.5 * inter_packet_delay); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/prr_sender.cc b/quiche/quic/core/congestion_control/prr_sender.cc new file mode 100644 index 000000000000..951f0ba5e7d7 --- /dev/null +++ b/quiche/quic/core/congestion_control/prr_sender.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/prr_sender.h" + +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +PrrSender::PrrSender() + : bytes_sent_since_loss_(0), + bytes_delivered_since_loss_(0), + ack_count_since_loss_(0), + bytes_in_flight_before_loss_(0) {} + +void PrrSender::OnPacketSent(QuicByteCount sent_bytes) { + bytes_sent_since_loss_ += sent_bytes; +} + +void PrrSender::OnPacketLost(QuicByteCount prior_in_flight) { + bytes_sent_since_loss_ = 0; + bytes_in_flight_before_loss_ = prior_in_flight; + bytes_delivered_since_loss_ = 0; + ack_count_since_loss_ = 0; +} + +void PrrSender::OnPacketAcked(QuicByteCount acked_bytes) { + bytes_delivered_since_loss_ += acked_bytes; + ++ack_count_since_loss_; +} + +bool PrrSender::CanSend(QuicByteCount congestion_window, + QuicByteCount bytes_in_flight, + QuicByteCount slowstart_threshold) const { + // Return QuicTime::Zero in order to ensure limited transmit always works. + if (bytes_sent_since_loss_ == 0 || bytes_in_flight < kMaxSegmentSize) { + return true; + } + if (congestion_window > bytes_in_flight) { + // During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead + // of sending the entire available window. This prevents burst retransmits + // when more packets are lost than the CWND reduction. + // limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS + if (bytes_delivered_since_loss_ + ack_count_since_loss_ * kMaxSegmentSize <= + bytes_sent_since_loss_) { + return false; + } + return true; + } + // Implement Proportional Rate Reduction (RFC6937). + // Checks a simplified version of the PRR formula that doesn't use division: + // AvailableSendWindow = + // CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent + if (bytes_delivered_since_loss_ * slowstart_threshold > + bytes_sent_since_loss_ * bytes_in_flight_before_loss_) { + return true; + } + return false; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/prr_sender.h b/quiche/quic/core/congestion_control/prr_sender.h new file mode 100644 index 000000000000..4e178403db6f --- /dev/null +++ b/quiche/quic/core/congestion_control/prr_sender.h @@ -0,0 +1,42 @@ +// Copyright (c) 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Implements Proportional Rate Reduction (PRR) per RFC 6937. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_PRR_SENDER_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_PRR_SENDER_H_ + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE PrrSender { + public: + PrrSender(); + // OnPacketLost should be called on the first loss that triggers a recovery + // period and all other methods in this class should only be called when in + // recovery. + void OnPacketLost(QuicByteCount prior_in_flight); + void OnPacketSent(QuicByteCount sent_bytes); + void OnPacketAcked(QuicByteCount acked_bytes); + bool CanSend(QuicByteCount congestion_window, QuicByteCount bytes_in_flight, + QuicByteCount slowstart_threshold) const; + + private: + // Bytes sent and acked since the last loss event. + // |bytes_sent_since_loss_| is the same as "prr_out_" in RFC 6937, + // and |bytes_delivered_since_loss_| is the same as "prr_delivered_". + QuicByteCount bytes_sent_since_loss_; + QuicByteCount bytes_delivered_since_loss_; + size_t ack_count_since_loss_; + + // The congestion window before the last loss event. + QuicByteCount bytes_in_flight_before_loss_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_PRR_SENDER_H_ diff --git a/quiche/quic/core/congestion_control/prr_sender_test.cc b/quiche/quic/core/congestion_control/prr_sender_test.cc new file mode 100644 index 000000000000..60dd77929c68 --- /dev/null +++ b/quiche/quic/core/congestion_control/prr_sender_test.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/prr_sender.h" + +#include + +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +namespace { +// Constant based on TCP defaults. +const QuicByteCount kMaxSegmentSize = kDefaultTCPMSS; +} // namespace + +class PrrSenderTest : public QuicTest {}; + +TEST_F(PrrSenderTest, SingleLossResultsInSendOnEveryOtherAck) { + PrrSender prr; + QuicPacketCount num_packets_in_flight = 50; + QuicByteCount bytes_in_flight = num_packets_in_flight * kMaxSegmentSize; + const QuicPacketCount ssthresh_after_loss = num_packets_in_flight / 2; + const QuicByteCount congestion_window = ssthresh_after_loss * kMaxSegmentSize; + + prr.OnPacketLost(bytes_in_flight); + // Ack a packet. PRR allows one packet to leave immediately. + prr.OnPacketAcked(kMaxSegmentSize); + bytes_in_flight -= kMaxSegmentSize; + EXPECT_TRUE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + // Send retransmission. + prr.OnPacketSent(kMaxSegmentSize); + // PRR shouldn't allow sending any more packets. + EXPECT_FALSE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + + // One packet is lost, and one ack was consumed above. PRR now paces + // transmissions through the remaining 48 acks. PRR will alternatively + // disallow and allow a packet to be sent in response to an ack. + for (uint64_t i = 0; i < ssthresh_after_loss - 1; ++i) { + // Ack a packet. PRR shouldn't allow sending a packet in response. + prr.OnPacketAcked(kMaxSegmentSize); + bytes_in_flight -= kMaxSegmentSize; + EXPECT_FALSE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + // Ack another packet. PRR should now allow sending a packet in response. + prr.OnPacketAcked(kMaxSegmentSize); + bytes_in_flight -= kMaxSegmentSize; + EXPECT_TRUE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + // Send a packet in response. + prr.OnPacketSent(kMaxSegmentSize); + bytes_in_flight += kMaxSegmentSize; + } + + // Since bytes_in_flight is now equal to congestion_window, PRR now maintains + // packet conservation, allowing one packet to be sent in response to an ack. + EXPECT_EQ(congestion_window, bytes_in_flight); + for (int i = 0; i < 10; ++i) { + // Ack a packet. + prr.OnPacketAcked(kMaxSegmentSize); + bytes_in_flight -= kMaxSegmentSize; + EXPECT_TRUE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + // Send a packet in response, since PRR allows it. + prr.OnPacketSent(kMaxSegmentSize); + bytes_in_flight += kMaxSegmentSize; + + // Since bytes_in_flight is equal to the congestion_window, + // PRR disallows sending. + EXPECT_EQ(congestion_window, bytes_in_flight); + EXPECT_FALSE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + } +} + +TEST_F(PrrSenderTest, BurstLossResultsInSlowStart) { + PrrSender prr; + QuicByteCount bytes_in_flight = 20 * kMaxSegmentSize; + const QuicPacketCount num_packets_lost = 13; + const QuicPacketCount ssthresh_after_loss = 10; + const QuicByteCount congestion_window = ssthresh_after_loss * kMaxSegmentSize; + + // Lose 13 packets. + bytes_in_flight -= num_packets_lost * kMaxSegmentSize; + prr.OnPacketLost(bytes_in_flight); + + // PRR-SSRB will allow the following 3 acks to send up to 2 packets. + for (int i = 0; i < 3; ++i) { + prr.OnPacketAcked(kMaxSegmentSize); + bytes_in_flight -= kMaxSegmentSize; + // PRR-SSRB should allow two packets to be sent. + for (int j = 0; j < 2; ++j) { + EXPECT_TRUE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + // Send a packet in response. + prr.OnPacketSent(kMaxSegmentSize); + bytes_in_flight += kMaxSegmentSize; + } + // PRR should allow no more than 2 packets in response to an ack. + EXPECT_FALSE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + } + + // Out of SSRB mode, PRR allows one send in response to each ack. + for (int i = 0; i < 10; ++i) { + prr.OnPacketAcked(kMaxSegmentSize); + bytes_in_flight -= kMaxSegmentSize; + EXPECT_TRUE(prr.CanSend(congestion_window, bytes_in_flight, + ssthresh_after_loss * kMaxSegmentSize)); + // Send a packet in response. + prr.OnPacketSent(kMaxSegmentSize); + bytes_in_flight += kMaxSegmentSize; + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/rtt_stats.cc b/quiche/quic/core/congestion_control/rtt_stats.cc new file mode 100644 index 000000000000..d679cd63631d --- /dev/null +++ b/quiche/quic/core/congestion_control/rtt_stats.cc @@ -0,0 +1,143 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/rtt_stats.h" + +#include // std::abs + +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +const float kAlpha = 0.125f; +const float kOneMinusAlpha = (1 - kAlpha); +const float kBeta = 0.25f; +const float kOneMinusBeta = (1 - kBeta); + +} // namespace + +RttStats::RttStats() + : latest_rtt_(QuicTime::Delta::Zero()), + min_rtt_(QuicTime::Delta::Zero()), + smoothed_rtt_(QuicTime::Delta::Zero()), + previous_srtt_(QuicTime::Delta::Zero()), + mean_deviation_(QuicTime::Delta::Zero()), + calculate_standard_deviation_(false), + initial_rtt_(QuicTime::Delta::FromMilliseconds(kInitialRttMs)), + last_update_time_(QuicTime::Zero()) {} + +void RttStats::ExpireSmoothedMetrics() { + mean_deviation_ = std::max( + mean_deviation_, QuicTime::Delta::FromMicroseconds(std::abs( + (smoothed_rtt_ - latest_rtt_).ToMicroseconds()))); + smoothed_rtt_ = std::max(smoothed_rtt_, latest_rtt_); +} + +// Updates the RTT based on a new sample. +bool RttStats::UpdateRtt(QuicTime::Delta send_delta, QuicTime::Delta ack_delay, + QuicTime now) { + if (send_delta.IsInfinite() || send_delta <= QuicTime::Delta::Zero()) { + QUIC_LOG_FIRST_N(WARNING, 3) + << "Ignoring measured send_delta, because it's is " + << "either infinite, zero, or negative. send_delta = " + << send_delta.ToMicroseconds(); + return false; + } + + last_update_time_ = now; + + // Update min_rtt_ first. min_rtt_ does not use an rtt_sample corrected for + // ack_delay but the raw observed send_delta, since poor clock granularity at + // the client may cause a high ack_delay to result in underestimation of the + // min_rtt_. + if (min_rtt_.IsZero() || min_rtt_ > send_delta) { + min_rtt_ = send_delta; + } + + QuicTime::Delta rtt_sample(send_delta); + previous_srtt_ = smoothed_rtt_; + // Correct for ack_delay if information received from the peer results in a + // an RTT sample at least as large as min_rtt. Otherwise, only use the + // send_delta. + // TODO(fayang): consider to ignore rtt_sample if rtt_sample < ack_delay and + // ack_delay is relatively large. + if (rtt_sample > ack_delay) { + if (rtt_sample - min_rtt_ >= ack_delay) { + rtt_sample = rtt_sample - ack_delay; + } else { + QUIC_CODE_COUNT(quic_ack_delay_makes_rtt_sample_smaller_than_min_rtt); + } + } else { + QUIC_CODE_COUNT(quic_ack_delay_greater_than_rtt_sample); + } + latest_rtt_ = rtt_sample; + if (calculate_standard_deviation_) { + standard_deviation_calculator_.OnNewRttSample(rtt_sample, smoothed_rtt_); + } + // First time call. + if (smoothed_rtt_.IsZero()) { + smoothed_rtt_ = rtt_sample; + mean_deviation_ = + QuicTime::Delta::FromMicroseconds(rtt_sample.ToMicroseconds() / 2); + } else { + mean_deviation_ = QuicTime::Delta::FromMicroseconds(static_cast( + kOneMinusBeta * mean_deviation_.ToMicroseconds() + + kBeta * std::abs((smoothed_rtt_ - rtt_sample).ToMicroseconds()))); + smoothed_rtt_ = kOneMinusAlpha * smoothed_rtt_ + kAlpha * rtt_sample; + QUIC_DVLOG(1) << " smoothed_rtt(us):" << smoothed_rtt_.ToMicroseconds() + << " mean_deviation(us):" << mean_deviation_.ToMicroseconds(); + } + return true; +} + +void RttStats::OnConnectionMigration() { + latest_rtt_ = QuicTime::Delta::Zero(); + min_rtt_ = QuicTime::Delta::Zero(); + smoothed_rtt_ = QuicTime::Delta::Zero(); + mean_deviation_ = QuicTime::Delta::Zero(); + initial_rtt_ = QuicTime::Delta::FromMilliseconds(kInitialRttMs); +} + +QuicTime::Delta RttStats::GetStandardOrMeanDeviation() const { + QUICHE_DCHECK(calculate_standard_deviation_); + if (!standard_deviation_calculator_.has_valid_standard_deviation) { + return mean_deviation_; + } + return standard_deviation_calculator_.CalculateStandardDeviation(); +} + +void RttStats::StandardDeviationCaculator::OnNewRttSample( + QuicTime::Delta rtt_sample, QuicTime::Delta smoothed_rtt) { + double new_value = rtt_sample.ToMicroseconds(); + if (smoothed_rtt.IsZero()) { + return; + } + has_valid_standard_deviation = true; + const double delta = new_value - smoothed_rtt.ToMicroseconds(); + m2 = kOneMinusBeta * m2 + kBeta * pow(delta, 2); +} + +QuicTime::Delta +RttStats::StandardDeviationCaculator::CalculateStandardDeviation() const { + QUICHE_DCHECK(has_valid_standard_deviation); + return QuicTime::Delta::FromMicroseconds(sqrt(m2)); +} + +void RttStats::CloneFrom(const RttStats& stats) { + latest_rtt_ = stats.latest_rtt_; + min_rtt_ = stats.min_rtt_; + smoothed_rtt_ = stats.smoothed_rtt_; + previous_srtt_ = stats.previous_srtt_; + mean_deviation_ = stats.mean_deviation_; + standard_deviation_calculator_ = stats.standard_deviation_calculator_; + calculate_standard_deviation_ = stats.calculate_standard_deviation_; + initial_rtt_ = stats.initial_rtt_; + last_update_time_ = stats.last_update_time_; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/rtt_stats.h b/quiche/quic/core/congestion_control/rtt_stats.h new file mode 100644 index 000000000000..04a0148433af --- /dev/null +++ b/quiche/quic/core/congestion_control/rtt_stats.h @@ -0,0 +1,131 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A convenience class to store rtt samples and calculate smoothed rtt. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_RTT_STATS_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_RTT_STATS_H_ + +#include +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class RttStatsPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE RttStats { + public: + // Calculates running standard-deviation using Welford's algorithm: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance# + // Welford's_Online_algorithm. + struct QUIC_EXPORT_PRIVATE StandardDeviationCaculator { + StandardDeviationCaculator() {} + + // Called when a new RTT sample is available. + void OnNewRttSample(QuicTime::Delta rtt_sample, + QuicTime::Delta smoothed_rtt); + // Calculates the standard deviation. + QuicTime::Delta CalculateStandardDeviation() const; + + bool has_valid_standard_deviation = false; + + private: + double m2 = 0; + }; + + RttStats(); + RttStats(const RttStats&) = delete; + RttStats& operator=(const RttStats&) = delete; + + // Updates the RTT from an incoming ack which is received |send_delta| after + // the packet is sent and the peer reports the ack being delayed |ack_delay|. + // Returns true if RTT was updated, and false if the sample was ignored. + bool UpdateRtt(QuicTime::Delta send_delta, QuicTime::Delta ack_delay, + QuicTime now); + + // Causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt + // is larger. The mean deviation is increased to the most recent deviation if + // it's larger. + void ExpireSmoothedMetrics(); + + // Called when connection migrates and rtt measurement needs to be reset. + void OnConnectionMigration(); + + // Returns the EWMA smoothed RTT for the connection. + // May return Zero if no valid updates have occurred. + QuicTime::Delta smoothed_rtt() const { return smoothed_rtt_; } + + // Returns the EWMA smoothed RTT prior to the most recent RTT sample. + QuicTime::Delta previous_srtt() const { return previous_srtt_; } + + QuicTime::Delta initial_rtt() const { return initial_rtt_; } + + QuicTime::Delta SmoothedOrInitialRtt() const { + return smoothed_rtt_.IsZero() ? initial_rtt_ : smoothed_rtt_; + } + + QuicTime::Delta MinOrInitialRtt() const { + return min_rtt_.IsZero() ? initial_rtt_ : min_rtt_; + } + + // Sets an initial RTT to be used for SmoothedRtt before any RTT updates. + void set_initial_rtt(QuicTime::Delta initial_rtt) { + if (initial_rtt.ToMicroseconds() <= 0) { + QUIC_BUG(quic_bug_10453_1) << "Attempt to set initial rtt to <= 0."; + return; + } + initial_rtt_ = initial_rtt; + } + + // The most recent rtt measurement. + // May return Zero if no valid updates have occurred. + QuicTime::Delta latest_rtt() const { return latest_rtt_; } + + // Returns the min_rtt for the entire connection. + // May return Zero if no valid updates have occurred. + QuicTime::Delta min_rtt() const { return min_rtt_; } + + QuicTime::Delta mean_deviation() const { return mean_deviation_; } + + // Returns standard deviation if there is a valid one. Otherwise, returns + // mean_deviation_. + QuicTime::Delta GetStandardOrMeanDeviation() const; + + QuicTime last_update_time() const { return last_update_time_; } + + void EnableStandardDeviationCalculation() { + calculate_standard_deviation_ = true; + } + + void CloneFrom(const RttStats& stats); + + private: + friend class test::RttStatsPeer; + + QuicTime::Delta latest_rtt_; + QuicTime::Delta min_rtt_; + QuicTime::Delta smoothed_rtt_; + QuicTime::Delta previous_srtt_; + // Mean RTT deviation during this session. + // Approximation of standard deviation, the error is roughly 1.25 times + // larger than the standard deviation, for a normally distributed signal. + QuicTime::Delta mean_deviation_; + // Standard deviation calculator. Only used calculate_standard_deviation_ is + // true. + StandardDeviationCaculator standard_deviation_calculator_; + bool calculate_standard_deviation_; + QuicTime::Delta initial_rtt_; + QuicTime last_update_time_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_RTT_STATS_H_ diff --git a/quiche/quic/core/congestion_control/rtt_stats_test.cc b/quiche/quic/core/congestion_control/rtt_stats_test.cc new file mode 100644 index 000000000000..11bf11a3cf88 --- /dev/null +++ b/quiche/quic/core/congestion_control/rtt_stats_test.cc @@ -0,0 +1,231 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/rtt_stats.h" + +#include + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::Message; + +namespace quic { +namespace test { + +class RttStatsTest : public QuicTest { + protected: + RttStats rtt_stats_; +}; + +TEST_F(RttStatsTest, DefaultsBeforeUpdate) { + EXPECT_LT(QuicTime::Delta::Zero(), rtt_stats_.initial_rtt()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.min_rtt()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.smoothed_rtt()); +} + +TEST_F(RttStatsTest, SmoothedRtt) { + // Verify that ack_delay is ignored in the first measurement. + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(300), + QuicTime::Delta::FromMilliseconds(100), + QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.smoothed_rtt()); + // Verify that a plausible ack delay increases the max ack delay. + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(400), + QuicTime::Delta::FromMilliseconds(100), + QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.smoothed_rtt()); + // Verify that Smoothed RTT includes max ack delay if it's reasonable. + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(350), + QuicTime::Delta::FromMilliseconds(50), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.smoothed_rtt()); + // Verify that large erroneous ack_delay does not change Smoothed RTT. + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(200), + QuicTime::Delta::FromMilliseconds(300), + QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(287500), + rtt_stats_.smoothed_rtt()); +} + +// Ensure that the potential rounding artifacts in EWMA calculation do not cause +// the SRTT to drift too far from the exact value. +TEST_F(RttStatsTest, SmoothedRttStability) { + for (size_t time = 3; time < 20000; time++) { + RttStats stats; + for (size_t i = 0; i < 100; i++) { + stats.UpdateRtt(QuicTime::Delta::FromMicroseconds(time), + QuicTime::Delta::FromMilliseconds(0), QuicTime::Zero()); + int64_t time_delta_us = stats.smoothed_rtt().ToMicroseconds() - time; + ASSERT_LE(std::abs(time_delta_us), 1); + } + } +} + +TEST_F(RttStatsTest, PreviousSmoothedRtt) { + // Verify that ack_delay is corrected for in Smoothed RTT. + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(200), + QuicTime::Delta::FromMilliseconds(0), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.smoothed_rtt()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.previous_srtt()); + // Ensure the previous SRTT is 200ms after a 100ms sample. + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(100), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(187500).ToMicroseconds(), + rtt_stats_.smoothed_rtt().ToMicroseconds()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.previous_srtt()); +} + +TEST_F(RttStatsTest, MinRtt) { + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(200), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.min_rtt()); + rtt_stats_.UpdateRtt( + QuicTime::Delta::FromMilliseconds(10), QuicTime::Delta::Zero(), + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(10)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), rtt_stats_.min_rtt()); + rtt_stats_.UpdateRtt( + QuicTime::Delta::FromMilliseconds(50), QuicTime::Delta::Zero(), + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(20)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), rtt_stats_.min_rtt()); + rtt_stats_.UpdateRtt( + QuicTime::Delta::FromMilliseconds(50), QuicTime::Delta::Zero(), + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(30)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), rtt_stats_.min_rtt()); + rtt_stats_.UpdateRtt( + QuicTime::Delta::FromMilliseconds(50), QuicTime::Delta::Zero(), + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(40)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), rtt_stats_.min_rtt()); + // Verify that ack_delay does not go into recording of min_rtt_. + rtt_stats_.UpdateRtt( + QuicTime::Delta::FromMilliseconds(7), + QuicTime::Delta::FromMilliseconds(2), + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(50)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(7), rtt_stats_.min_rtt()); +} + +TEST_F(RttStatsTest, ExpireSmoothedMetrics) { + QuicTime::Delta initial_rtt = QuicTime::Delta::FromMilliseconds(10); + rtt_stats_.UpdateRtt(initial_rtt, QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(initial_rtt, rtt_stats_.min_rtt()); + EXPECT_EQ(initial_rtt, rtt_stats_.smoothed_rtt()); + + EXPECT_EQ(0.5 * initial_rtt, rtt_stats_.mean_deviation()); + + // Update once with a 20ms RTT. + QuicTime::Delta doubled_rtt = 2 * initial_rtt; + rtt_stats_.UpdateRtt(doubled_rtt, QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(1.125 * initial_rtt, rtt_stats_.smoothed_rtt()); + + // Expire the smoothed metrics, increasing smoothed rtt and mean deviation. + rtt_stats_.ExpireSmoothedMetrics(); + EXPECT_EQ(doubled_rtt, rtt_stats_.smoothed_rtt()); + EXPECT_EQ(0.875 * initial_rtt, rtt_stats_.mean_deviation()); + + // Now go back down to 5ms and expire the smoothed metrics, and ensure the + // mean deviation increases to 15ms. + QuicTime::Delta half_rtt = 0.5 * initial_rtt; + rtt_stats_.UpdateRtt(half_rtt, QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_GT(doubled_rtt, rtt_stats_.smoothed_rtt()); + EXPECT_LT(initial_rtt, rtt_stats_.mean_deviation()); +} + +TEST_F(RttStatsTest, UpdateRttWithBadSendDeltas) { + QuicTime::Delta initial_rtt = QuicTime::Delta::FromMilliseconds(10); + rtt_stats_.UpdateRtt(initial_rtt, QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(initial_rtt, rtt_stats_.min_rtt()); + EXPECT_EQ(initial_rtt, rtt_stats_.smoothed_rtt()); + + std::vector bad_send_deltas; + bad_send_deltas.push_back(QuicTime::Delta::Zero()); + bad_send_deltas.push_back(QuicTime::Delta::Infinite()); + bad_send_deltas.push_back(QuicTime::Delta::FromMicroseconds(-1000)); + + for (QuicTime::Delta bad_send_delta : bad_send_deltas) { + SCOPED_TRACE(Message() << "bad_send_delta = " + << bad_send_delta.ToMicroseconds()); + EXPECT_FALSE(rtt_stats_.UpdateRtt(bad_send_delta, QuicTime::Delta::Zero(), + QuicTime::Zero())); + EXPECT_EQ(initial_rtt, rtt_stats_.min_rtt()); + EXPECT_EQ(initial_rtt, rtt_stats_.smoothed_rtt()); + } +} + +TEST_F(RttStatsTest, ResetAfterConnectionMigrations) { + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(200), + QuicTime::Delta::FromMilliseconds(0), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.smoothed_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.min_rtt()); + + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(300), + QuicTime::Delta::FromMilliseconds(100), + QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.smoothed_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.min_rtt()); + + // Reset rtt stats on connection migrations. + rtt_stats_.OnConnectionMigration(); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.latest_rtt()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.smoothed_rtt()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.min_rtt()); +} + +TEST_F(RttStatsTest, StandardDeviationCaculatorTest1) { + // All samples are the same. + rtt_stats_.EnableStandardDeviationCalculation(); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(rtt_stats_.mean_deviation(), + rtt_stats_.GetStandardOrMeanDeviation()); + + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_EQ(QuicTime::Delta::Zero(), rtt_stats_.GetStandardOrMeanDeviation()); +} + +TEST_F(RttStatsTest, StandardDeviationCaculatorTest2) { + // Small variance. + rtt_stats_.EnableStandardDeviationCalculation(); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(9), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(11), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_LT(QuicTime::Delta::FromMicroseconds(500), + rtt_stats_.GetStandardOrMeanDeviation()); + EXPECT_GT(QuicTime::Delta::FromMilliseconds(1), + rtt_stats_.GetStandardOrMeanDeviation()); +} + +TEST_F(RttStatsTest, StandardDeviationCaculatorTest3) { + // Some variance. + rtt_stats_.EnableStandardDeviationCalculation(); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(50), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(50), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_APPROX_EQ(rtt_stats_.mean_deviation(), + rtt_stats_.GetStandardOrMeanDeviation(), 0.25f); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/send_algorithm_interface.cc b/quiche/quic/core/congestion_control/send_algorithm_interface.cc new file mode 100644 index 000000000000..89b1ce982518 --- /dev/null +++ b/quiche/quic/core/congestion_control/send_algorithm_interface.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" + +#include "absl/base/attributes.h" +#include "quiche/quic/core/congestion_control/bbr2_sender.h" +#include "quiche/quic/core/congestion_control/bbr_sender.h" +#include "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +class RttStats; + +// Factory for send side congestion control algorithm. +SendAlgorithmInterface* SendAlgorithmInterface::Create( + const QuicClock* clock, const RttStats* rtt_stats, + const QuicUnackedPacketMap* unacked_packets, + CongestionControlType congestion_control_type, QuicRandom* random, + QuicConnectionStats* stats, QuicPacketCount initial_congestion_window, + SendAlgorithmInterface* old_send_algorithm) { + QuicPacketCount max_congestion_window = + GetQuicFlag(quic_max_congestion_window); + switch (congestion_control_type) { + case kGoogCC: // GoogCC is not supported by quic/core, fall back to BBR. + case kBBR: + return new BbrSender(clock->ApproximateNow(), rtt_stats, unacked_packets, + initial_congestion_window, max_congestion_window, + random, stats); + case kBBRv2: + return new Bbr2Sender( + clock->ApproximateNow(), rtt_stats, unacked_packets, + initial_congestion_window, max_congestion_window, random, stats, + old_send_algorithm && + old_send_algorithm->GetCongestionControlType() == kBBR + ? static_cast(old_send_algorithm) + : nullptr); + case kPCC: + // PCC is currently not supported, fall back to CUBIC instead. + ABSL_FALLTHROUGH_INTENDED; + case kCubicBytes: + return new TcpCubicSenderBytes( + clock, rtt_stats, false /* don't use Reno */, + initial_congestion_window, max_congestion_window, stats); + case kRenoBytes: + return new TcpCubicSenderBytes(clock, rtt_stats, true /* use Reno */, + initial_congestion_window, + max_congestion_window, stats); + } + return nullptr; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/send_algorithm_interface.h b/quiche/quic/core/congestion_control/send_algorithm_interface.h new file mode 100644 index 000000000000..e5e8d588e405 --- /dev/null +++ b/quiche/quic/core/congestion_control/send_algorithm_interface.h @@ -0,0 +1,179 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// The pure virtual class for send side congestion control algorithm. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_SEND_ALGORITHM_INTERFACE_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_SEND_ALGORITHM_INTERFACE_H_ + +#include +#include +#include + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_unacked_packet_map.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +using QuicRoundTripCount = uint64_t; + +class CachedNetworkParameters; +class RttStats; + +class QUIC_EXPORT_PRIVATE SendAlgorithmInterface { + public: + // Network Params for AdjustNetworkParameters. + struct QUIC_NO_EXPORT NetworkParams { + NetworkParams() = default; + NetworkParams(const QuicBandwidth& bandwidth, const QuicTime::Delta& rtt, + bool allow_cwnd_to_decrease) + : bandwidth(bandwidth), + rtt(rtt), + allow_cwnd_to_decrease(allow_cwnd_to_decrease) {} + + bool operator==(const NetworkParams& other) const { + return bandwidth == other.bandwidth && rtt == other.rtt && + max_initial_congestion_window == + other.max_initial_congestion_window && + allow_cwnd_to_decrease == other.allow_cwnd_to_decrease && + is_rtt_trusted == other.is_rtt_trusted; + } + + QuicBandwidth bandwidth = QuicBandwidth::Zero(); + QuicTime::Delta rtt = QuicTime::Delta::Zero(); + int max_initial_congestion_window = 0; + bool allow_cwnd_to_decrease = false; + bool is_rtt_trusted = false; + }; + + static SendAlgorithmInterface* Create( + const QuicClock* clock, const RttStats* rtt_stats, + const QuicUnackedPacketMap* unacked_packets, CongestionControlType type, + QuicRandom* random, QuicConnectionStats* stats, + QuicPacketCount initial_congestion_window, + SendAlgorithmInterface* old_send_algorithm); + + virtual ~SendAlgorithmInterface() {} + + virtual void SetFromConfig(const QuicConfig& config, + Perspective perspective) = 0; + + virtual void ApplyConnectionOptions( + const QuicTagVector& connection_options) = 0; + + // Sets the initial congestion window in number of packets. May be ignored + // if called after the initial congestion window is no longer relevant. + virtual void SetInitialCongestionWindowInPackets(QuicPacketCount packets) = 0; + + // Indicates an update to the congestion state, caused either by an incoming + // ack or loss event timeout. |rtt_updated| indicates whether a new + // latest_rtt sample has been taken, |prior_in_flight| the bytes in flight + // prior to the congestion event. |acked_packets| and |lost_packets| are any + // packets considered acked or lost as a result of the congestion event. + // |num_ect| and |num_ce| indicate the number of newly acknowledged packets + // for which the receiver reported the Explicit Congestion Notification (ECN) + // bits were set to ECT(1) or CE, respectively. A sender will not use ECT(0). + // If QUIC determines the peer's feedback is invalid, it will send zero in + // these fields. + virtual void OnCongestionEvent(bool rtt_updated, + QuicByteCount prior_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount num_ect, + QuicPacketCount num_ce) = 0; + + // Inform that we sent |bytes| to the wire, and if the packet is + // retransmittable. |bytes_in_flight| is the number of bytes in flight before + // the packet was sent. + // Note: this function must be called for every packet sent to the wire. + virtual void OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData is_retransmittable) = 0; + + // Inform that |packet_number| has been neutered. + virtual void OnPacketNeutered(QuicPacketNumber packet_number) = 0; + + // Called when the retransmission timeout fires. Neither OnPacketAbandoned + // nor OnPacketLost will be called for these packets. + virtual void OnRetransmissionTimeout(bool packets_retransmitted) = 0; + + // Called when connection migrates and cwnd needs to be reset. + virtual void OnConnectionMigration() = 0; + + // Make decision on whether the sender can send right now. Note that even + // when this method returns true, the sending can be delayed due to pacing. + virtual bool CanSend(QuicByteCount bytes_in_flight) = 0; + + // The pacing rate of the send algorithm. May be zero if the rate is unknown. + virtual QuicBandwidth PacingRate(QuicByteCount bytes_in_flight) const = 0; + + // What's the current estimated bandwidth in bytes per second. + // Returns 0 when it does not have an estimate. + virtual QuicBandwidth BandwidthEstimate() const = 0; + + // Whether BandwidthEstimate returns a good measurement for resumption. + virtual bool HasGoodBandwidthEstimateForResumption() const = 0; + + // Returns the size of the current congestion window in bytes. Note, this is + // not the *available* window. Some send algorithms may not use a congestion + // window and will return 0. + virtual QuicByteCount GetCongestionWindow() const = 0; + + // Whether the send algorithm is currently in slow start. When true, the + // BandwidthEstimate is expected to be too low. + virtual bool InSlowStart() const = 0; + + // Whether the send algorithm is currently in recovery. + virtual bool InRecovery() const = 0; + + // Returns the size of the slow start congestion window in bytes, + // aka ssthresh. Only defined for Cubic and Reno, other algorithms return 0. + virtual QuicByteCount GetSlowStartThreshold() const = 0; + + virtual CongestionControlType GetCongestionControlType() const = 0; + + // Notifies the congestion control algorithm of an external network + // measurement or prediction. Either |bandwidth| or |rtt| may be zero if no + // sample is available. + virtual void AdjustNetworkParameters(const NetworkParams& params) = 0; + + // Retrieves debugging information about the current state of the + // send algorithm. + virtual std::string GetDebugState() const = 0; + + // Called when the connection has no outstanding data to send. Specifically, + // this means that none of the data streams are write-blocked, there are no + // packets in the connection queue, and there are no pending retransmissins, + // i.e. the sender cannot send anything for reasons other than being blocked + // by congestion controller. This includes cases when the connection is + // blocked by the flow controller. + // + // The fact that this method is called does not necessarily imply that the + // connection would not be blocked by the congestion control if it actually + // tried to send data. If the congestion control algorithm needs to exclude + // such cases, it should use the internal state it uses for congestion control + // for that. + virtual void OnApplicationLimited(QuicByteCount bytes_in_flight) = 0; + + // Called before connection close to collect stats. + virtual void PopulateConnectionStats(QuicConnectionStats* stats) const = 0; + + // Returns true if the algorithm will respond to Congestion Experienced (CE) + // indications in accordance with RFC3168 [ECT(0)] or RFC9331 [ECT(1)]. + virtual bool SupportsECT0() const = 0; + virtual bool SupportsECT1() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_SEND_ALGORITHM_INTERFACE_H_ diff --git a/quiche/quic/core/congestion_control/send_algorithm_test.cc b/quiche/quic/core/congestion_control/send_algorithm_test.cc new file mode 100644 index 000000000000..76e61a344bd2 --- /dev/null +++ b/quiche/quic/core/congestion_control/send_algorithm_test.cc @@ -0,0 +1,347 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/test_tools/simulator/switch.h" + +namespace quic { +namespace test { +namespace { + +// Use the initial CWND of 10, as 32 is too much for the test network. +const uint32_t kInitialCongestionWindowPackets = 10; + +// Test network parameters. Here, the topology of the network is: +// +// QUIC Sender +// | +// | <-- local link +// | +// Network switch +// * <-- the bottleneck queue in the direction +// | of the receiver +// | +// | <-- test link +// | +// | +// Receiver +// +// When setting the bandwidth of the local link and test link, choose +// a bandwidth lower than 20Mbps, as the clock-granularity of the +// simulator can only handle a granularity of 1us. + +// Default settings between the switch and the sender. +const QuicBandwidth kLocalLinkBandwidth = + QuicBandwidth::FromKBitsPerSecond(10000); +const QuicTime::Delta kLocalPropagationDelay = + QuicTime::Delta::FromMilliseconds(2); + +// Wired network settings. A typical desktop network setup, a +// high-bandwidth, 30ms test link to the receiver. +const QuicBandwidth kTestLinkWiredBandwidth = + QuicBandwidth::FromKBitsPerSecond(4000); +const QuicTime::Delta kTestLinkWiredPropagationDelay = + QuicTime::Delta::FromMilliseconds(50); +const QuicTime::Delta kTestWiredTransferTime = + kTestLinkWiredBandwidth.TransferTime(kMaxOutgoingPacketSize) + + kLocalLinkBandwidth.TransferTime(kMaxOutgoingPacketSize); +const QuicTime::Delta kTestWiredRtt = + (kTestLinkWiredPropagationDelay + kLocalPropagationDelay + + kTestWiredTransferTime) * + 2; +const QuicByteCount kTestWiredBdp = kTestWiredRtt * kTestLinkWiredBandwidth; + +// Small BDP, Bandwidth-policed network settings. In this scenario, +// the receiver has a low-bandwidth, short propagation-delay link, +// resulting in a small BDP. We model the policer by setting the +// queue size to only one packet. +const QuicBandwidth kTestLinkLowBdpBandwidth = + QuicBandwidth::FromKBitsPerSecond(200); +const QuicTime::Delta kTestLinkLowBdpPropagationDelay = + QuicTime::Delta::FromMilliseconds(50); +const QuicByteCount kTestPolicerQueue = kMaxOutgoingPacketSize; + +// Satellite network settings. In a satellite network, the bottleneck +// buffer is typically sized for non-satellite links , but the +// propagation delay of the test link to the receiver is as much as a +// quarter second. +const QuicTime::Delta kTestSatellitePropagationDelay = + QuicTime::Delta::FromMilliseconds(250); + +// Cellular scenarios. In a cellular network, the bottleneck queue at +// the edge of the network can be as great as 3MB. +const QuicBandwidth kTestLink2GBandwidth = + QuicBandwidth::FromKBitsPerSecond(100); +const QuicBandwidth kTestLink3GBandwidth = + QuicBandwidth::FromKBitsPerSecond(1500); +const QuicByteCount kCellularQueue = 3 * 1024 * 1024; +const QuicTime::Delta kTestCellularPropagationDelay = + QuicTime::Delta::FromMilliseconds(40); + +// Small RTT scenario, below the per-ack-update threshold of 30ms. +const QuicTime::Delta kTestLinkSmallRTTDelay = + QuicTime::Delta::FromMilliseconds(10); + +struct TestParams { + explicit TestParams(CongestionControlType congestion_control_type) + : congestion_control_type(congestion_control_type) {} + + friend std::ostream& operator<<(std::ostream& os, const TestParams& p) { + os << "{ congestion_control_type: " + << CongestionControlTypeToString(p.congestion_control_type); + os << " }"; + return os; + } + + const CongestionControlType congestion_control_type; +}; + +std::string TestParamToString( + const testing::TestParamInfo& params) { + return absl::StrCat( + CongestionControlTypeToString(params.param.congestion_control_type), "_"); +} + +// Constructs various test permutations. +std::vector GetTestParams() { + std::vector params; + for (const CongestionControlType congestion_control_type : + {kBBR, kCubicBytes, kRenoBytes, kPCC}) { + params.push_back(TestParams(congestion_control_type)); + } + return params; +} + +} // namespace + +class SendAlgorithmTest : public QuicTestWithParam { + protected: + SendAlgorithmTest() + : simulator_(), + quic_sender_(&simulator_, "QUIC sender", "Receiver", + Perspective::IS_CLIENT, TestConnectionId()), + receiver_(&simulator_, "Receiver", "QUIC sender", + Perspective::IS_SERVER, TestConnectionId()) { + rtt_stats_ = quic_sender_.connection()->sent_packet_manager().GetRttStats(); + sender_ = SendAlgorithmInterface::Create( + simulator_.GetClock(), rtt_stats_, + QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicConnectionPeer::GetSentPacketManager( + quic_sender_.connection())), + GetParam().congestion_control_type, &random_, &stats_, + kInitialCongestionWindowPackets, nullptr); + quic_sender_.RecordTrace(); + + QuicConnectionPeer::SetSendAlgorithm(quic_sender_.connection(), sender_); + const int kTestMaxPacketSize = 1350; + quic_sender_.connection()->SetMaxPacketLength(kTestMaxPacketSize); + clock_ = simulator_.GetClock(); + simulator_.set_random_generator(&random_); + + uint64_t seed = QuicRandom::GetInstance()->RandUint64(); + random_.set_seed(seed); + QUIC_LOG(INFO) << "SendAlgorithmTest simulator set up. Seed: " << seed; + } + + // Creates a simulated network, with default settings between the + // sender and the switch and the given settings from the switch to + // the receiver. + void CreateSetup(const QuicBandwidth& test_bandwidth, + const QuicTime::Delta& test_link_delay, + QuicByteCount bottleneck_queue_length) { + switch_ = std::make_unique(&simulator_, "Switch", 8, + bottleneck_queue_length); + quic_sender_link_ = std::make_unique( + &quic_sender_, switch_->port(1), kLocalLinkBandwidth, + kLocalPropagationDelay); + receiver_link_ = std::make_unique( + &receiver_, switch_->port(2), test_bandwidth, test_link_delay); + } + + void DoSimpleTransfer(QuicByteCount transfer_size, QuicTime::Delta deadline) { + quic_sender_.AddBytesToTransfer(transfer_size); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return quic_sender_.bytes_to_transfer() == 0; }, deadline); + EXPECT_TRUE(simulator_result) + << "Simple transfer failed. Bytes remaining: " + << quic_sender_.bytes_to_transfer(); + } + + void SendBursts(size_t number_of_bursts, QuicByteCount bytes, + QuicTime::Delta rtt, QuicTime::Delta wait_time) { + ASSERT_EQ(0u, quic_sender_.bytes_to_transfer()); + for (size_t i = 0; i < number_of_bursts; i++) { + quic_sender_.AddBytesToTransfer(bytes); + + // Transfer data and wait for three seconds between each transfer. + simulator_.RunFor(wait_time); + + // Ensure the connection did not time out. + ASSERT_TRUE(quic_sender_.connection()->connected()); + ASSERT_TRUE(receiver_.connection()->connected()); + } + + simulator_.RunFor(wait_time + rtt); + EXPECT_EQ(0u, quic_sender_.bytes_to_transfer()); + } + + // Estimates the elapsed time for a given transfer size, given the + // bottleneck bandwidth and link propagation delay. + QuicTime::Delta EstimatedElapsedTime( + QuicByteCount transfer_size_bytes, QuicBandwidth test_link_bandwidth, + const QuicTime::Delta& test_link_delay) const { + return test_link_bandwidth.TransferTime(transfer_size_bytes) + + 2 * test_link_delay; + } + + QuicTime QuicSenderStartTime() { + return quic_sender_.connection()->GetStats().connection_creation_time; + } + + void PrintTransferStats() { + const QuicConnectionStats& stats = quic_sender_.connection()->GetStats(); + QUIC_LOG(INFO) << "Summary for scenario " << GetParam(); + QUIC_LOG(INFO) << "Sender stats is " << stats; + const double rtx_rate = + static_cast(stats.bytes_retransmitted) / stats.bytes_sent; + QUIC_LOG(INFO) << "Retransmit rate (num_rtx/num_total_sent): " << rtx_rate; + QUIC_LOG(INFO) << "Connection elapsed time: " + << (clock_->Now() - QuicSenderStartTime()).ToMilliseconds() + << " (ms)"; + } + + simulator::Simulator simulator_; + simulator::QuicEndpoint quic_sender_; + simulator::QuicEndpoint receiver_; + std::unique_ptr switch_; + std::unique_ptr quic_sender_link_; + std::unique_ptr receiver_link_; + QuicConnectionStats stats_; + + SimpleRandom random_; + + // Owned by different components of the connection. + const QuicClock* clock_; + const RttStats* rtt_stats_; + SendAlgorithmInterface* sender_; +}; + +INSTANTIATE_TEST_SUITE_P(SendAlgorithmTests, SendAlgorithmTest, + ::testing::ValuesIn(GetTestParams()), + TestParamToString); + +// Test a simple long data transfer in the default setup. +TEST_P(SendAlgorithmTest, SimpleWiredNetworkTransfer) { + CreateSetup(kTestLinkWiredBandwidth, kTestLinkWiredPropagationDelay, + kTestWiredBdp); + const QuicByteCount kTransferSizeBytes = 12 * 1024 * 1024; + const QuicTime::Delta maximum_elapsed_time = + EstimatedElapsedTime(kTransferSizeBytes, kTestLinkWiredBandwidth, + kTestLinkWiredPropagationDelay) * + 1.2; + DoSimpleTransfer(kTransferSizeBytes, maximum_elapsed_time); + PrintTransferStats(); +} + +TEST_P(SendAlgorithmTest, LowBdpPolicedNetworkTransfer) { + CreateSetup(kTestLinkLowBdpBandwidth, kTestLinkLowBdpPropagationDelay, + kTestPolicerQueue); + const QuicByteCount kTransferSizeBytes = 5 * 1024 * 1024; + const QuicTime::Delta maximum_elapsed_time = + EstimatedElapsedTime(kTransferSizeBytes, kTestLinkLowBdpBandwidth, + kTestLinkLowBdpPropagationDelay) * + 1.2; + DoSimpleTransfer(kTransferSizeBytes, maximum_elapsed_time); + PrintTransferStats(); +} + +TEST_P(SendAlgorithmTest, AppLimitedBurstsOverWiredNetwork) { + CreateSetup(kTestLinkWiredBandwidth, kTestLinkWiredPropagationDelay, + kTestWiredBdp); + const QuicByteCount kBurstSizeBytes = 512; + const int kNumBursts = 20; + const QuicTime::Delta kWaitTime = QuicTime::Delta::FromSeconds(3); + SendBursts(kNumBursts, kBurstSizeBytes, kTestWiredRtt, kWaitTime); + PrintTransferStats(); + + const QuicTime::Delta estimated_burst_time = + EstimatedElapsedTime(kBurstSizeBytes, kTestLinkWiredBandwidth, + kTestLinkWiredPropagationDelay) + + kWaitTime; + const QuicTime::Delta max_elapsed_time = + kNumBursts * estimated_burst_time + kWaitTime; + const QuicTime::Delta actual_elapsed_time = + clock_->Now() - QuicSenderStartTime(); + EXPECT_GE(max_elapsed_time, actual_elapsed_time); +} + +TEST_P(SendAlgorithmTest, SatelliteNetworkTransfer) { + CreateSetup(kTestLinkWiredBandwidth, kTestSatellitePropagationDelay, + kTestWiredBdp); + const QuicByteCount kTransferSizeBytes = 12 * 1024 * 1024; + const QuicTime::Delta maximum_elapsed_time = + EstimatedElapsedTime(kTransferSizeBytes, kTestLinkWiredBandwidth, + kTestSatellitePropagationDelay) * + 1.25; + DoSimpleTransfer(kTransferSizeBytes, maximum_elapsed_time); + PrintTransferStats(); +} + +TEST_P(SendAlgorithmTest, 2GNetworkTransfer) { + CreateSetup(kTestLink2GBandwidth, kTestCellularPropagationDelay, + kCellularQueue); + const QuicByteCount kTransferSizeBytes = 1024 * 1024; + const QuicTime::Delta maximum_elapsed_time = + EstimatedElapsedTime(kTransferSizeBytes, kTestLink2GBandwidth, + kTestCellularPropagationDelay) * + 1.2; + DoSimpleTransfer(kTransferSizeBytes, maximum_elapsed_time); + PrintTransferStats(); +} + +TEST_P(SendAlgorithmTest, 3GNetworkTransfer) { + CreateSetup(kTestLink3GBandwidth, kTestCellularPropagationDelay, + kCellularQueue); + const QuicByteCount kTransferSizeBytes = 5 * 1024 * 1024; + const QuicTime::Delta maximum_elapsed_time = + EstimatedElapsedTime(kTransferSizeBytes, kTestLink3GBandwidth, + kTestCellularPropagationDelay) * + 1.2; + DoSimpleTransfer(kTransferSizeBytes, maximum_elapsed_time); + PrintTransferStats(); +} + +TEST_P(SendAlgorithmTest, LowRTTTransfer) { + CreateSetup(kTestLinkWiredBandwidth, kTestLinkSmallRTTDelay, kCellularQueue); + + const QuicByteCount kTransferSizeBytes = 12 * 1024 * 1024; + const QuicTime::Delta maximum_elapsed_time = + EstimatedElapsedTime(kTransferSizeBytes, kTestLinkWiredBandwidth, + kTestLinkSmallRTTDelay) * + 1.2; + DoSimpleTransfer(kTransferSizeBytes, maximum_elapsed_time); + PrintTransferStats(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc b/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc new file mode 100644 index 000000000000..610bce71cb70 --- /dev/null +++ b/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc @@ -0,0 +1,387 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h" + +#include +#include +#include + +#include "quiche/quic/core/congestion_control/prr_sender.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { +// Constants based on TCP defaults. +const QuicByteCount kMaxBurstBytes = 3 * kDefaultTCPMSS; +const float kRenoBeta = 0.7f; // Reno backoff factor. +// The minimum cwnd based on RFC 3782 (TCP NewReno) for cwnd reductions on a +// fast retransmission. +const QuicByteCount kDefaultMinimumCongestionWindow = 2 * kDefaultTCPMSS; +} // namespace + +TcpCubicSenderBytes::TcpCubicSenderBytes( + const QuicClock* clock, const RttStats* rtt_stats, bool reno, + QuicPacketCount initial_tcp_congestion_window, + QuicPacketCount max_congestion_window, QuicConnectionStats* stats) + : rtt_stats_(rtt_stats), + stats_(stats), + reno_(reno), + num_connections_(kDefaultNumConnections), + min4_mode_(false), + last_cutback_exited_slowstart_(false), + slow_start_large_reduction_(false), + no_prr_(false), + cubic_(clock), + num_acked_packets_(0), + congestion_window_(initial_tcp_congestion_window * kDefaultTCPMSS), + min_congestion_window_(kDefaultMinimumCongestionWindow), + max_congestion_window_(max_congestion_window * kDefaultTCPMSS), + slowstart_threshold_(max_congestion_window * kDefaultTCPMSS), + initial_tcp_congestion_window_(initial_tcp_congestion_window * + kDefaultTCPMSS), + initial_max_tcp_congestion_window_(max_congestion_window * + kDefaultTCPMSS), + min_slow_start_exit_window_(min_congestion_window_) {} + +TcpCubicSenderBytes::~TcpCubicSenderBytes() {} + +void TcpCubicSenderBytes::SetFromConfig(const QuicConfig& config, + Perspective perspective) { + if (perspective == Perspective::IS_SERVER && + config.HasReceivedConnectionOptions()) { + if (ContainsQuicTag(config.ReceivedConnectionOptions(), kMIN4)) { + // Min CWND of 4 experiment. + min4_mode_ = true; + SetMinCongestionWindowInPackets(1); + } + if (ContainsQuicTag(config.ReceivedConnectionOptions(), kSSLR)) { + // Slow Start Fast Exit experiment. + slow_start_large_reduction_ = true; + } + if (ContainsQuicTag(config.ReceivedConnectionOptions(), kNPRR)) { + // Use unity pacing instead of PRR. + no_prr_ = true; + } + } +} + +void TcpCubicSenderBytes::AdjustNetworkParameters(const NetworkParams& params) { + if (params.bandwidth.IsZero() || params.rtt.IsZero()) { + return; + } + SetCongestionWindowFromBandwidthAndRtt(params.bandwidth, params.rtt); +} + +float TcpCubicSenderBytes::RenoBeta() const { + // kNConnectionBeta is the backoff factor after loss for our N-connection + // emulation, which emulates the effective backoff of an ensemble of N + // TCP-Reno connections on a single loss event. The effective multiplier is + // computed as: + return (num_connections_ - 1 + kRenoBeta) / num_connections_; +} + +void TcpCubicSenderBytes::OnCongestionEvent( + bool rtt_updated, QuicByteCount prior_in_flight, QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, QuicPacketCount /*num_ect*/, + QuicPacketCount /*num_ce*/) { + if (rtt_updated && InSlowStart() && + hybrid_slow_start_.ShouldExitSlowStart( + rtt_stats_->latest_rtt(), rtt_stats_->min_rtt(), + GetCongestionWindow() / kDefaultTCPMSS)) { + ExitSlowstart(); + } + for (const LostPacket& lost_packet : lost_packets) { + OnPacketLost(lost_packet.packet_number, lost_packet.bytes_lost, + prior_in_flight); + } + for (const AckedPacket& acked_packet : acked_packets) { + OnPacketAcked(acked_packet.packet_number, acked_packet.bytes_acked, + prior_in_flight, event_time); + } +} + +void TcpCubicSenderBytes::OnPacketAcked(QuicPacketNumber acked_packet_number, + QuicByteCount acked_bytes, + QuicByteCount prior_in_flight, + QuicTime event_time) { + largest_acked_packet_number_.UpdateMax(acked_packet_number); + if (InRecovery()) { + if (!no_prr_) { + // PRR is used when in recovery. + prr_.OnPacketAcked(acked_bytes); + } + return; + } + MaybeIncreaseCwnd(acked_packet_number, acked_bytes, prior_in_flight, + event_time); + if (InSlowStart()) { + hybrid_slow_start_.OnPacketAcked(acked_packet_number); + } +} + +void TcpCubicSenderBytes::OnPacketSent( + QuicTime /*sent_time*/, QuicByteCount /*bytes_in_flight*/, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData is_retransmittable) { + if (InSlowStart()) { + ++(stats_->slowstart_packets_sent); + } + + if (is_retransmittable != HAS_RETRANSMITTABLE_DATA) { + return; + } + if (InRecovery()) { + // PRR is used when in recovery. + prr_.OnPacketSent(bytes); + } + QUICHE_DCHECK(!largest_sent_packet_number_.IsInitialized() || + largest_sent_packet_number_ < packet_number); + largest_sent_packet_number_ = packet_number; + hybrid_slow_start_.OnPacketSent(packet_number); +} + +bool TcpCubicSenderBytes::CanSend(QuicByteCount bytes_in_flight) { + if (!no_prr_ && InRecovery()) { + // PRR is used when in recovery. + return prr_.CanSend(GetCongestionWindow(), bytes_in_flight, + GetSlowStartThreshold()); + } + if (GetCongestionWindow() > bytes_in_flight) { + return true; + } + if (min4_mode_ && bytes_in_flight < 4 * kDefaultTCPMSS) { + return true; + } + return false; +} + +QuicBandwidth TcpCubicSenderBytes::PacingRate( + QuicByteCount /* bytes_in_flight */) const { + // We pace at twice the rate of the underlying sender's bandwidth estimate + // during slow start and 1.25x during congestion avoidance to ensure pacing + // doesn't prevent us from filling the window. + QuicTime::Delta srtt = rtt_stats_->SmoothedOrInitialRtt(); + const QuicBandwidth bandwidth = + QuicBandwidth::FromBytesAndTimeDelta(GetCongestionWindow(), srtt); + return bandwidth * (InSlowStart() ? 2 : (no_prr_ && InRecovery() ? 1 : 1.25)); +} + +QuicBandwidth TcpCubicSenderBytes::BandwidthEstimate() const { + QuicTime::Delta srtt = rtt_stats_->smoothed_rtt(); + if (srtt.IsZero()) { + // If we haven't measured an rtt, the bandwidth estimate is unknown. + return QuicBandwidth::Zero(); + } + return QuicBandwidth::FromBytesAndTimeDelta(GetCongestionWindow(), srtt); +} + +bool TcpCubicSenderBytes::InSlowStart() const { + return GetCongestionWindow() < GetSlowStartThreshold(); +} + +bool TcpCubicSenderBytes::IsCwndLimited(QuicByteCount bytes_in_flight) const { + const QuicByteCount congestion_window = GetCongestionWindow(); + if (bytes_in_flight >= congestion_window) { + return true; + } + const QuicByteCount available_bytes = congestion_window - bytes_in_flight; + const bool slow_start_limited = + InSlowStart() && bytes_in_flight > congestion_window / 2; + return slow_start_limited || available_bytes <= kMaxBurstBytes; +} + +bool TcpCubicSenderBytes::InRecovery() const { + return largest_acked_packet_number_.IsInitialized() && + largest_sent_at_last_cutback_.IsInitialized() && + largest_acked_packet_number_ <= largest_sent_at_last_cutback_; +} + +void TcpCubicSenderBytes::OnRetransmissionTimeout(bool packets_retransmitted) { + largest_sent_at_last_cutback_.Clear(); + if (!packets_retransmitted) { + return; + } + hybrid_slow_start_.Restart(); + HandleRetransmissionTimeout(); +} + +std::string TcpCubicSenderBytes::GetDebugState() const { return ""; } + +void TcpCubicSenderBytes::OnApplicationLimited( + QuicByteCount /*bytes_in_flight*/) {} + +void TcpCubicSenderBytes::SetCongestionWindowFromBandwidthAndRtt( + QuicBandwidth bandwidth, QuicTime::Delta rtt) { + QuicByteCount new_congestion_window = bandwidth.ToBytesPerPeriod(rtt); + // Limit new CWND if needed. + congestion_window_ = + std::max(min_congestion_window_, + std::min(new_congestion_window, + kMaxResumptionCongestionWindow * kDefaultTCPMSS)); +} + +void TcpCubicSenderBytes::SetInitialCongestionWindowInPackets( + QuicPacketCount congestion_window) { + congestion_window_ = congestion_window * kDefaultTCPMSS; +} + +void TcpCubicSenderBytes::SetMinCongestionWindowInPackets( + QuicPacketCount congestion_window) { + min_congestion_window_ = congestion_window * kDefaultTCPMSS; +} + +void TcpCubicSenderBytes::SetNumEmulatedConnections(int num_connections) { + num_connections_ = std::max(1, num_connections); + cubic_.SetNumConnections(num_connections_); +} + +void TcpCubicSenderBytes::ExitSlowstart() { + slowstart_threshold_ = congestion_window_; +} + +void TcpCubicSenderBytes::OnPacketLost(QuicPacketNumber packet_number, + QuicByteCount lost_bytes, + QuicByteCount prior_in_flight) { + // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets + // already sent should be treated as a single loss event, since it's expected. + if (largest_sent_at_last_cutback_.IsInitialized() && + packet_number <= largest_sent_at_last_cutback_) { + if (last_cutback_exited_slowstart_) { + ++stats_->slowstart_packets_lost; + stats_->slowstart_bytes_lost += lost_bytes; + if (slow_start_large_reduction_) { + // Reduce congestion window by lost_bytes for every loss. + congestion_window_ = std::max(congestion_window_ - lost_bytes, + min_slow_start_exit_window_); + slowstart_threshold_ = congestion_window_; + } + } + QUIC_DVLOG(1) << "Ignoring loss for largest_missing:" << packet_number + << " because it was sent prior to the last CWND cutback."; + return; + } + ++stats_->tcp_loss_events; + last_cutback_exited_slowstart_ = InSlowStart(); + if (InSlowStart()) { + ++stats_->slowstart_packets_lost; + } + + if (!no_prr_) { + prr_.OnPacketLost(prior_in_flight); + } + + // TODO(b/77268641): Separate out all of slow start into a separate class. + if (slow_start_large_reduction_ && InSlowStart()) { + QUICHE_DCHECK_LT(kDefaultTCPMSS, congestion_window_); + if (congestion_window_ >= 2 * initial_tcp_congestion_window_) { + min_slow_start_exit_window_ = congestion_window_ / 2; + } + congestion_window_ = congestion_window_ - kDefaultTCPMSS; + } else if (reno_) { + congestion_window_ = congestion_window_ * RenoBeta(); + } else { + congestion_window_ = + cubic_.CongestionWindowAfterPacketLoss(congestion_window_); + } + if (congestion_window_ < min_congestion_window_) { + congestion_window_ = min_congestion_window_; + } + slowstart_threshold_ = congestion_window_; + largest_sent_at_last_cutback_ = largest_sent_packet_number_; + // Reset packet count from congestion avoidance mode. We start counting again + // when we're out of recovery. + num_acked_packets_ = 0; + QUIC_DVLOG(1) << "Incoming loss; congestion window: " << congestion_window_ + << " slowstart threshold: " << slowstart_threshold_; +} + +QuicByteCount TcpCubicSenderBytes::GetCongestionWindow() const { + return congestion_window_; +} + +QuicByteCount TcpCubicSenderBytes::GetSlowStartThreshold() const { + return slowstart_threshold_; +} + +// Called when we receive an ack. Normal TCP tracks how many packets one ack +// represents, but quic has a separate ack for each packet. +void TcpCubicSenderBytes::MaybeIncreaseCwnd( + QuicPacketNumber /*acked_packet_number*/, QuicByteCount acked_bytes, + QuicByteCount prior_in_flight, QuicTime event_time) { + QUIC_BUG_IF(quic_bug_10439_1, InRecovery()) + << "Never increase the CWND during recovery."; + // Do not increase the congestion window unless the sender is close to using + // the current window. + if (!IsCwndLimited(prior_in_flight)) { + cubic_.OnApplicationLimited(); + return; + } + if (congestion_window_ >= max_congestion_window_) { + return; + } + if (InSlowStart()) { + // TCP slow start, exponential growth, increase by one for each ACK. + congestion_window_ += kDefaultTCPMSS; + QUIC_DVLOG(1) << "Slow start; congestion window: " << congestion_window_ + << " slowstart threshold: " << slowstart_threshold_; + return; + } + // Congestion avoidance. + if (reno_) { + // Classic Reno congestion avoidance. + ++num_acked_packets_; + // Divide by num_connections to smoothly increase the CWND at a faster rate + // than conventional Reno. + if (num_acked_packets_ * num_connections_ >= + congestion_window_ / kDefaultTCPMSS) { + congestion_window_ += kDefaultTCPMSS; + num_acked_packets_ = 0; + } + + QUIC_DVLOG(1) << "Reno; congestion window: " << congestion_window_ + << " slowstart threshold: " << slowstart_threshold_ + << " congestion window count: " << num_acked_packets_; + } else { + congestion_window_ = std::min( + max_congestion_window_, + cubic_.CongestionWindowAfterAck(acked_bytes, congestion_window_, + rtt_stats_->min_rtt(), event_time)); + QUIC_DVLOG(1) << "Cubic; congestion window: " << congestion_window_ + << " slowstart threshold: " << slowstart_threshold_; + } +} + +void TcpCubicSenderBytes::HandleRetransmissionTimeout() { + cubic_.ResetCubicState(); + slowstart_threshold_ = congestion_window_ / 2; + congestion_window_ = min_congestion_window_; +} + +void TcpCubicSenderBytes::OnConnectionMigration() { + hybrid_slow_start_.Restart(); + prr_ = PrrSender(); + largest_sent_packet_number_.Clear(); + largest_acked_packet_number_.Clear(); + largest_sent_at_last_cutback_.Clear(); + last_cutback_exited_slowstart_ = false; + cubic_.ResetCubicState(); + num_acked_packets_ = 0; + congestion_window_ = initial_tcp_congestion_window_; + max_congestion_window_ = initial_max_tcp_congestion_window_; + slowstart_threshold_ = initial_max_tcp_congestion_window_; +} + +CongestionControlType TcpCubicSenderBytes::GetCongestionControlType() const { + return reno_ ? kRenoBytes : kCubicBytes; +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h b/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h new file mode 100644 index 000000000000..184357d1a294 --- /dev/null +++ b/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h @@ -0,0 +1,171 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// TCP cubic send side congestion algorithm, emulates the behavior of TCP cubic. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_TCP_CUBIC_SENDER_BYTES_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_TCP_CUBIC_SENDER_BYTES_H_ + +#include +#include + +#include "quiche/quic/core/congestion_control/cubic_bytes.h" +#include "quiche/quic/core/congestion_control/hybrid_slow_start.h" +#include "quiche/quic/core/congestion_control/prr_sender.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class RttStats; + +// Maximum window to allow when doing bandwidth resumption. +const QuicPacketCount kMaxResumptionCongestionWindow = 200; + +namespace test { +class TcpCubicSenderBytesPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE TcpCubicSenderBytes : public SendAlgorithmInterface { + public: + TcpCubicSenderBytes(const QuicClock* clock, const RttStats* rtt_stats, + bool reno, QuicPacketCount initial_tcp_congestion_window, + QuicPacketCount max_congestion_window, + QuicConnectionStats* stats); + TcpCubicSenderBytes(const TcpCubicSenderBytes&) = delete; + TcpCubicSenderBytes& operator=(const TcpCubicSenderBytes&) = delete; + ~TcpCubicSenderBytes() override; + + // Start implementation of SendAlgorithmInterface. + void SetFromConfig(const QuicConfig& config, + Perspective perspective) override; + void ApplyConnectionOptions( + const QuicTagVector& /*connection_options*/) override {} + void AdjustNetworkParameters(const NetworkParams& params) override; + void SetNumEmulatedConnections(int num_connections); + void SetInitialCongestionWindowInPackets( + QuicPacketCount congestion_window) override; + void OnConnectionMigration() override; + void OnCongestionEvent(bool rtt_updated, QuicByteCount prior_in_flight, + QuicTime event_time, + const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, + QuicPacketCount num_ect, + QuicPacketCount num_ce) override; + void OnPacketSent(QuicTime sent_time, QuicByteCount bytes_in_flight, + QuicPacketNumber packet_number, QuicByteCount bytes, + HasRetransmittableData is_retransmittable) override; + void OnPacketNeutered(QuicPacketNumber /*packet_number*/) override {} + void OnRetransmissionTimeout(bool packets_retransmitted) override; + bool CanSend(QuicByteCount bytes_in_flight) override; + QuicBandwidth PacingRate(QuicByteCount bytes_in_flight) const override; + QuicBandwidth BandwidthEstimate() const override; + bool HasGoodBandwidthEstimateForResumption() const override { return false; } + QuicByteCount GetCongestionWindow() const override; + QuicByteCount GetSlowStartThreshold() const override; + CongestionControlType GetCongestionControlType() const override; + bool InSlowStart() const override; + bool InRecovery() const override; + std::string GetDebugState() const override; + void OnApplicationLimited(QuicByteCount bytes_in_flight) override; + void PopulateConnectionStats(QuicConnectionStats* /*stats*/) const override {} + bool SupportsECT0() const override { return false; } + bool SupportsECT1() const override { return false; } + // End implementation of SendAlgorithmInterface. + + QuicByteCount min_congestion_window() const { return min_congestion_window_; } + + protected: + // Compute the TCP Reno beta based on the current number of connections. + float RenoBeta() const; + + bool IsCwndLimited(QuicByteCount bytes_in_flight) const; + + // TODO(ianswett): Remove these and migrate to OnCongestionEvent. + void OnPacketAcked(QuicPacketNumber acked_packet_number, + QuicByteCount acked_bytes, QuicByteCount prior_in_flight, + QuicTime event_time); + void SetCongestionWindowFromBandwidthAndRtt(QuicBandwidth bandwidth, + QuicTime::Delta rtt); + void SetMinCongestionWindowInPackets(QuicPacketCount congestion_window); + void ExitSlowstart(); + void OnPacketLost(QuicPacketNumber packet_number, QuicByteCount lost_bytes, + QuicByteCount prior_in_flight); + void MaybeIncreaseCwnd(QuicPacketNumber acked_packet_number, + QuicByteCount acked_bytes, + QuicByteCount prior_in_flight, QuicTime event_time); + void HandleRetransmissionTimeout(); + + private: + friend class test::TcpCubicSenderBytesPeer; + + HybridSlowStart hybrid_slow_start_; + PrrSender prr_; + const RttStats* rtt_stats_; + QuicConnectionStats* stats_; + + // If true, Reno congestion control is used instead of Cubic. + const bool reno_; + + // Number of connections to simulate. + uint32_t num_connections_; + + // Track the largest packet that has been sent. + QuicPacketNumber largest_sent_packet_number_; + + // Track the largest packet that has been acked. + QuicPacketNumber largest_acked_packet_number_; + + // Track the largest packet number outstanding when a CWND cutback occurs. + QuicPacketNumber largest_sent_at_last_cutback_; + + // Whether to use 4 packets as the actual min, but pace lower. + bool min4_mode_; + + // Whether the last loss event caused us to exit slowstart. + // Used for stats collection of slowstart_packets_lost + bool last_cutback_exited_slowstart_; + + // When true, exit slow start with large cutback of congestion window. + bool slow_start_large_reduction_; + + // When true, use unity pacing instead of PRR. + bool no_prr_; + + CubicBytes cubic_; + + // ACK counter for the Reno implementation. + uint64_t num_acked_packets_; + + // Congestion window in bytes. + QuicByteCount congestion_window_; + + // Minimum congestion window in bytes. + QuicByteCount min_congestion_window_; + + // Maximum congestion window in bytes. + QuicByteCount max_congestion_window_; + + // Slow start congestion window in bytes, aka ssthresh. + QuicByteCount slowstart_threshold_; + + // Initial TCP congestion window in bytes. This variable can only be set when + // this algorithm is created. + const QuicByteCount initial_tcp_congestion_window_; + + // Initial maximum TCP congestion window in bytes. This variable can only be + // set when this algorithm is created. + const QuicByteCount initial_max_tcp_congestion_window_; + + // The minimum window when exiting slow start with large reduction. + QuicByteCount min_slow_start_exit_window_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_TCP_CUBIC_SENDER_BYTES_H_ diff --git a/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc b/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc new file mode 100644 index 000000000000..919233a49698 --- /dev/null +++ b/quiche/quic/core/congestion_control/tcp_cubic_sender_bytes_test.cc @@ -0,0 +1,841 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h" + +#include +#include +#include +#include + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_config_peer.h" + +namespace quic { +namespace test { + +// TODO(ianswett): A number of theses tests were written with the assumption of +// an initial CWND of 10. They have carefully calculated values which should be +// updated to be based on kInitialCongestionWindow. +const uint32_t kInitialCongestionWindowPackets = 10; +const uint32_t kMaxCongestionWindowPackets = 200; +const uint32_t kDefaultWindowTCP = + kInitialCongestionWindowPackets * kDefaultTCPMSS; +const float kRenoBeta = 0.7f; // Reno backoff factor. + +class TcpCubicSenderBytesPeer : public TcpCubicSenderBytes { + public: + TcpCubicSenderBytesPeer(const QuicClock* clock, bool reno) + : TcpCubicSenderBytes(clock, &rtt_stats_, reno, + kInitialCongestionWindowPackets, + kMaxCongestionWindowPackets, &stats_) {} + + const HybridSlowStart& hybrid_slow_start() const { + return hybrid_slow_start_; + } + + float GetRenoBeta() const { return RenoBeta(); } + + RttStats rtt_stats_; + QuicConnectionStats stats_; +}; + +class TcpCubicSenderBytesTest : public QuicTest { + protected: + TcpCubicSenderBytesTest() + : one_ms_(QuicTime::Delta::FromMilliseconds(1)), + sender_(new TcpCubicSenderBytesPeer(&clock_, true)), + packet_number_(1), + acked_packet_number_(0), + bytes_in_flight_(0) {} + + int SendAvailableSendWindow() { + return SendAvailableSendWindow(kDefaultTCPMSS); + } + + int SendAvailableSendWindow(QuicPacketLength /*packet_length*/) { + // Send as long as TimeUntilSend returns Zero. + int packets_sent = 0; + bool can_send = sender_->CanSend(bytes_in_flight_); + while (can_send) { + sender_->OnPacketSent(clock_.Now(), bytes_in_flight_, + QuicPacketNumber(packet_number_++), kDefaultTCPMSS, + HAS_RETRANSMITTABLE_DATA); + ++packets_sent; + bytes_in_flight_ += kDefaultTCPMSS; + can_send = sender_->CanSend(bytes_in_flight_); + } + return packets_sent; + } + + // Normal is that TCP acks every other segment. + void AckNPackets(int n) { + sender_->rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(60), + QuicTime::Delta::Zero(), clock_.Now()); + AckedPacketVector acked_packets; + LostPacketVector lost_packets; + for (int i = 0; i < n; ++i) { + ++acked_packet_number_; + acked_packets.push_back( + AckedPacket(QuicPacketNumber(acked_packet_number_), kDefaultTCPMSS, + QuicTime::Zero())); + } + sender_->OnCongestionEvent(true, bytes_in_flight_, clock_.Now(), + acked_packets, lost_packets, 0, 0); + bytes_in_flight_ -= n * kDefaultTCPMSS; + clock_.AdvanceTime(one_ms_); + } + + void LoseNPackets(int n) { LoseNPackets(n, kDefaultTCPMSS); } + + void LoseNPackets(int n, QuicPacketLength packet_length) { + AckedPacketVector acked_packets; + LostPacketVector lost_packets; + for (int i = 0; i < n; ++i) { + ++acked_packet_number_; + lost_packets.push_back( + LostPacket(QuicPacketNumber(acked_packet_number_), packet_length)); + } + sender_->OnCongestionEvent(false, bytes_in_flight_, clock_.Now(), + acked_packets, lost_packets, 0, 0); + bytes_in_flight_ -= n * packet_length; + } + + // Does not increment acked_packet_number_. + void LosePacket(uint64_t packet_number) { + AckedPacketVector acked_packets; + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(packet_number), kDefaultTCPMSS)); + sender_->OnCongestionEvent(false, bytes_in_flight_, clock_.Now(), + acked_packets, lost_packets, 0, 0); + bytes_in_flight_ -= kDefaultTCPMSS; + } + + const QuicTime::Delta one_ms_; + MockClock clock_; + std::unique_ptr sender_; + uint64_t packet_number_; + uint64_t acked_packet_number_; + QuicByteCount bytes_in_flight_; +}; + +TEST_F(TcpCubicSenderBytesTest, SimpleSender) { + // At startup make sure we are at the default. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // Make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // And that window is un-affected. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + + // Fill the send window with data, then verify that we can't send. + SendAvailableSendWindow(); + EXPECT_FALSE(sender_->CanSend(sender_->GetCongestionWindow())); +} + +TEST_F(TcpCubicSenderBytesTest, ApplicationLimitedSlowStart) { + // Send exactly 10 packets and ensure the CWND ends at 14 packets. + const int kNumberOfAcks = 5; + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // Make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + + SendAvailableSendWindow(); + for (int i = 0; i < kNumberOfAcks; ++i) { + AckNPackets(2); + } + QuicByteCount bytes_to_send = sender_->GetCongestionWindow(); + // It's expected 2 acks will arrive when the bytes_in_flight are greater than + // half the CWND. + EXPECT_EQ(kDefaultWindowTCP + kDefaultTCPMSS * 2 * 2, bytes_to_send); +} + +TEST_F(TcpCubicSenderBytesTest, ExponentialSlowStart) { + const int kNumberOfAcks = 20; + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + EXPECT_EQ(QuicBandwidth::Zero(), sender_->BandwidthEstimate()); + // Make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + const QuicByteCount cwnd = sender_->GetCongestionWindow(); + EXPECT_EQ(kDefaultWindowTCP + kDefaultTCPMSS * 2 * kNumberOfAcks, cwnd); + EXPECT_EQ(QuicBandwidth::FromBytesAndTimeDelta( + cwnd, sender_->rtt_stats_.smoothed_rtt()), + sender_->BandwidthEstimate()); +} + +TEST_F(TcpCubicSenderBytesTest, SlowStartPacketLoss) { + sender_->SetNumEmulatedConnections(1); + const int kNumberOfAcks = 10; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose a packet to exit slow start. + LoseNPackets(1); + size_t packets_in_recovery_window = expected_send_window / kDefaultTCPMSS; + + // We should now have fallen out of slow start with a reduced window. + expected_send_window *= kRenoBeta; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Recovery phase. We need to ack every packet in the recovery window before + // we exit recovery. + size_t number_of_packets_in_window = expected_send_window / kDefaultTCPMSS; + QUIC_DLOG(INFO) << "number_packets: " << number_of_packets_in_window; + AckNPackets(packets_in_recovery_window); + SendAvailableSendWindow(); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // We need to ack an entire window before we increase CWND by 1. + AckNPackets(number_of_packets_in_window - 2); + SendAvailableSendWindow(); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Next ack should increase cwnd by 1. + AckNPackets(1); + expected_send_window += kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Now RTO and ensure slow start gets reset. + EXPECT_TRUE(sender_->hybrid_slow_start().started()); + sender_->OnRetransmissionTimeout(true); + EXPECT_FALSE(sender_->hybrid_slow_start().started()); +} + +TEST_F(TcpCubicSenderBytesTest, SlowStartPacketLossWithLargeReduction) { + QuicConfig config; + QuicTagVector options; + options.push_back(kSSLR); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + + sender_->SetNumEmulatedConnections(1); + const int kNumberOfAcks = (kDefaultWindowTCP / (2 * kDefaultTCPMSS)) - 1; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose a packet to exit slow start. We should now have fallen out of + // slow start with a window reduced by 1. + LoseNPackets(1); + expected_send_window -= kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose 5 packets in recovery and verify that congestion window is reduced + // further. + LoseNPackets(5); + expected_send_window -= 5 * kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + // Lose another 10 packets and ensure it reduces below half the peak CWND, + // because we never acked the full IW. + LoseNPackets(10); + expected_send_window -= 10 * kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + size_t packets_in_recovery_window = expected_send_window / kDefaultTCPMSS; + + // Recovery phase. We need to ack every packet in the recovery window before + // we exit recovery. + size_t number_of_packets_in_window = expected_send_window / kDefaultTCPMSS; + QUIC_DLOG(INFO) << "number_packets: " << number_of_packets_in_window; + AckNPackets(packets_in_recovery_window); + SendAvailableSendWindow(); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // We need to ack an entire window before we increase CWND by 1. + AckNPackets(number_of_packets_in_window - 1); + SendAvailableSendWindow(); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Next ack should increase cwnd by 1. + AckNPackets(1); + expected_send_window += kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Now RTO and ensure slow start gets reset. + EXPECT_TRUE(sender_->hybrid_slow_start().started()); + sender_->OnRetransmissionTimeout(true); + EXPECT_FALSE(sender_->hybrid_slow_start().started()); +} + +TEST_F(TcpCubicSenderBytesTest, SlowStartHalfPacketLossWithLargeReduction) { + QuicConfig config; + QuicTagVector options; + options.push_back(kSSLR); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + + sender_->SetNumEmulatedConnections(1); + const int kNumberOfAcks = 10; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window in half sized packets. + SendAvailableSendWindow(kDefaultTCPMSS / 2); + AckNPackets(2); + } + SendAvailableSendWindow(kDefaultTCPMSS / 2); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose a packet to exit slow start. We should now have fallen out of + // slow start with a window reduced by 1. + LoseNPackets(1); + expected_send_window -= kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose 10 packets in recovery and verify that congestion window is reduced + // by 5 packets. + LoseNPackets(10, kDefaultTCPMSS / 2); + expected_send_window -= 5 * kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, SlowStartPacketLossWithMaxHalfReduction) { + QuicConfig config; + QuicTagVector options; + options.push_back(kSSLR); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + + sender_->SetNumEmulatedConnections(1); + const int kNumberOfAcks = kInitialCongestionWindowPackets / 2; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose a packet to exit slow start. We should now have fallen out of + // slow start with a window reduced by 1. + LoseNPackets(1); + expected_send_window -= kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose half the outstanding packets in recovery and verify the congestion + // window is only reduced by a max of half. + LoseNPackets(kNumberOfAcks * 2); + expected_send_window -= (kNumberOfAcks * 2 - 1) * kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + LoseNPackets(5); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, NoPRRWhenLessThanOnePacketInFlight) { + SendAvailableSendWindow(); + LoseNPackets(kInitialCongestionWindowPackets - 1); + AckNPackets(1); + // PRR will allow 2 packets for every ack during recovery. + EXPECT_EQ(2, SendAvailableSendWindow()); + // Simulate abandoning all packets by supplying a bytes_in_flight of 0. + // PRR should now allow a packet to be sent, even though prr's state variables + // believe it has sent enough packets. + EXPECT_TRUE(sender_->CanSend(0)); +} + +TEST_F(TcpCubicSenderBytesTest, SlowStartPacketLossPRR) { + sender_->SetNumEmulatedConnections(1); + // Test based on the first example in RFC6937. + // Ack 10 packets in 5 acks to raise the CWND to 20, as in the example. + const int kNumberOfAcks = 5; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + LoseNPackets(1); + + // We should now have fallen out of slow start with a reduced window. + size_t send_window_before_loss = expected_send_window; + expected_send_window *= kRenoBeta; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Testing TCP proportional rate reduction. + // We should send packets paced over the received acks for the remaining + // outstanding packets. The number of packets before we exit recovery is the + // original CWND minus the packet that has been lost and the one which + // triggered the loss. + size_t remaining_packets_in_recovery = + send_window_before_loss / kDefaultTCPMSS - 2; + + for (size_t i = 0; i < remaining_packets_in_recovery; ++i) { + AckNPackets(1); + SendAvailableSendWindow(); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + } + + // We need to ack another window before we increase CWND by 1. + size_t number_of_packets_in_window = expected_send_window / kDefaultTCPMSS; + for (size_t i = 0; i < number_of_packets_in_window; ++i) { + AckNPackets(1); + EXPECT_EQ(1, SendAvailableSendWindow()); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + } + + AckNPackets(1); + expected_send_window += kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, SlowStartBurstPacketLossPRR) { + sender_->SetNumEmulatedConnections(1); + // Test based on the second example in RFC6937, though we also implement + // forward acknowledgements, so the first two incoming acks will trigger + // PRR immediately. + // Ack 20 packets in 10 acks to raise the CWND to 30. + const int kNumberOfAcks = 10; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Lose one more than the congestion window reduction, so that after loss, + // bytes_in_flight is lesser than the congestion window. + size_t send_window_after_loss = kRenoBeta * expected_send_window; + size_t num_packets_to_lose = + (expected_send_window - send_window_after_loss) / kDefaultTCPMSS + 1; + LoseNPackets(num_packets_to_lose); + // Immediately after the loss, ensure at least one packet can be sent. + // Losses without subsequent acks can occur with timer based loss detection. + EXPECT_TRUE(sender_->CanSend(bytes_in_flight_)); + AckNPackets(1); + + // We should now have fallen out of slow start with a reduced window. + expected_send_window *= kRenoBeta; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Only 2 packets should be allowed to be sent, per PRR-SSRB. + EXPECT_EQ(2, SendAvailableSendWindow()); + + // Ack the next packet, which triggers another loss. + LoseNPackets(1); + AckNPackets(1); + + // Send 2 packets to simulate PRR-SSRB. + EXPECT_EQ(2, SendAvailableSendWindow()); + + // Ack the next packet, which triggers another loss. + LoseNPackets(1); + AckNPackets(1); + + // Send 2 packets to simulate PRR-SSRB. + EXPECT_EQ(2, SendAvailableSendWindow()); + + // Exit recovery and return to sending at the new rate. + for (int i = 0; i < kNumberOfAcks; ++i) { + AckNPackets(1); + EXPECT_EQ(1, SendAvailableSendWindow()); + } +} + +TEST_F(TcpCubicSenderBytesTest, RTOCongestionWindow) { + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + // Expect the window to decrease to the minimum once the RTO fires and slow + // start threshold to be set to 1/2 of the CWND. + sender_->OnRetransmissionTimeout(true); + EXPECT_EQ(2 * kDefaultTCPMSS, sender_->GetCongestionWindow()); + EXPECT_EQ(5u * kDefaultTCPMSS, sender_->GetSlowStartThreshold()); +} + +TEST_F(TcpCubicSenderBytesTest, RTOCongestionWindowNoRetransmission) { + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + + // Expect the window to remain unchanged if the RTO fires but no packets are + // retransmitted. + sender_->OnRetransmissionTimeout(false); + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, TcpCubicResetEpochOnQuiescence) { + const int kMaxCongestionWindow = 50; + const QuicByteCount kMaxCongestionWindowBytes = + kMaxCongestionWindow * kDefaultTCPMSS; + int num_sent = SendAvailableSendWindow(); + + // Make sure we fall out of slow start. + QuicByteCount saved_cwnd = sender_->GetCongestionWindow(); + LoseNPackets(1); + EXPECT_GT(saved_cwnd, sender_->GetCongestionWindow()); + + // Ack the rest of the outstanding packets to get out of recovery. + for (int i = 1; i < num_sent; ++i) { + AckNPackets(1); + } + EXPECT_EQ(0u, bytes_in_flight_); + + // Send a new window of data and ack all; cubic growth should occur. + saved_cwnd = sender_->GetCongestionWindow(); + num_sent = SendAvailableSendWindow(); + for (int i = 0; i < num_sent; ++i) { + AckNPackets(1); + } + EXPECT_LT(saved_cwnd, sender_->GetCongestionWindow()); + EXPECT_GT(kMaxCongestionWindowBytes, sender_->GetCongestionWindow()); + EXPECT_EQ(0u, bytes_in_flight_); + + // Quiescent time of 100 seconds + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100000)); + + // Send new window of data and ack one packet. Cubic epoch should have + // been reset; ensure cwnd increase is not dramatic. + saved_cwnd = sender_->GetCongestionWindow(); + SendAvailableSendWindow(); + AckNPackets(1); + EXPECT_NEAR(saved_cwnd, sender_->GetCongestionWindow(), kDefaultTCPMSS); + EXPECT_GT(kMaxCongestionWindowBytes, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, MultipleLossesInOneWindow) { + SendAvailableSendWindow(); + const QuicByteCount initial_window = sender_->GetCongestionWindow(); + LosePacket(acked_packet_number_ + 1); + const QuicByteCount post_loss_window = sender_->GetCongestionWindow(); + EXPECT_GT(initial_window, post_loss_window); + LosePacket(acked_packet_number_ + 3); + EXPECT_EQ(post_loss_window, sender_->GetCongestionWindow()); + LosePacket(packet_number_ - 1); + EXPECT_EQ(post_loss_window, sender_->GetCongestionWindow()); + + // Lose a later packet and ensure the window decreases. + LosePacket(packet_number_); + EXPECT_GT(post_loss_window, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, ConfigureMaxInitialWindow) { + QuicConfig config; + + // Verify that kCOPT: kIW10 forces the congestion window to the default of 10. + QuicTagVector options; + options.push_back(kIW10); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + EXPECT_EQ(10u * kDefaultTCPMSS, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, SetInitialCongestionWindow) { + EXPECT_NE(3u * kDefaultTCPMSS, sender_->GetCongestionWindow()); + sender_->SetInitialCongestionWindowInPackets(3); + EXPECT_EQ(3u * kDefaultTCPMSS, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, 2ConnectionCongestionAvoidanceAtEndOfRecovery) { + sender_->SetNumEmulatedConnections(2); + // Ack 10 packets in 5 acks to raise the CWND to 20. + const int kNumberOfAcks = 5; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + LoseNPackets(1); + + // We should now have fallen out of slow start with a reduced window. + expected_send_window = expected_send_window * sender_->GetRenoBeta(); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // No congestion window growth should occur in recovery phase, i.e., until the + // currently outstanding 20 packets are acked. + for (int i = 0; i < 10; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + EXPECT_TRUE(sender_->InRecovery()); + AckNPackets(2); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + } + EXPECT_FALSE(sender_->InRecovery()); + + // Out of recovery now. Congestion window should not grow for half an RTT. + size_t packets_in_send_window = expected_send_window / kDefaultTCPMSS; + SendAvailableSendWindow(); + AckNPackets(packets_in_send_window / 2 - 2); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Next ack should increase congestion window by 1MSS. + SendAvailableSendWindow(); + AckNPackets(2); + expected_send_window += kDefaultTCPMSS; + packets_in_send_window += 1; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Congestion window should remain steady again for half an RTT. + SendAvailableSendWindow(); + AckNPackets(packets_in_send_window / 2 - 1); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Next ack should cause congestion window to grow by 1MSS. + SendAvailableSendWindow(); + AckNPackets(2); + expected_send_window += kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, 1ConnectionCongestionAvoidanceAtEndOfRecovery) { + sender_->SetNumEmulatedConnections(1); + // Ack 10 packets in 5 acks to raise the CWND to 20. + const int kNumberOfAcks = 5; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + LoseNPackets(1); + + // We should now have fallen out of slow start with a reduced window. + expected_send_window *= kRenoBeta; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // No congestion window growth should occur in recovery phase, i.e., until the + // currently outstanding 20 packets are acked. + for (int i = 0; i < 10; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + EXPECT_TRUE(sender_->InRecovery()); + AckNPackets(2); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + } + EXPECT_FALSE(sender_->InRecovery()); + + // Out of recovery now. Congestion window should not grow during RTT. + for (uint64_t i = 0; i < expected_send_window / kDefaultTCPMSS - 2; i += 2) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + } + + // Next ack should cause congestion window to grow by 1MSS. + SendAvailableSendWindow(); + AckNPackets(2); + expected_send_window += kDefaultTCPMSS; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, BandwidthResumption) { + // Test that when provided with CachedNetworkParameters and opted in to the + // bandwidth resumption experiment, that the TcpCubicSenderPackets sets + // initial CWND appropriately. + + // Set some common values. + const QuicPacketCount kNumberOfPackets = 123; + const QuicBandwidth kBandwidthEstimate = + QuicBandwidth::FromBytesPerSecond(kNumberOfPackets * kDefaultTCPMSS); + const QuicTime::Delta kRttEstimate = QuicTime::Delta::FromSeconds(1); + + SendAlgorithmInterface::NetworkParams network_param; + network_param.bandwidth = kBandwidthEstimate; + network_param.rtt = kRttEstimate; + sender_->AdjustNetworkParameters(network_param); + EXPECT_EQ(kNumberOfPackets * kDefaultTCPMSS, sender_->GetCongestionWindow()); + + // Resume with an illegal value of 0 and verify the server ignores it. + SendAlgorithmInterface::NetworkParams network_param_no_bandwidth; + network_param_no_bandwidth.bandwidth = QuicBandwidth::Zero(); + network_param_no_bandwidth.rtt = kRttEstimate; + sender_->AdjustNetworkParameters(network_param_no_bandwidth); + EXPECT_EQ(kNumberOfPackets * kDefaultTCPMSS, sender_->GetCongestionWindow()); + + // Resumed CWND is limited to be in a sensible range. + const QuicBandwidth kUnreasonableBandwidth = + QuicBandwidth::FromBytesPerSecond((kMaxResumptionCongestionWindow + 1) * + kDefaultTCPMSS); + SendAlgorithmInterface::NetworkParams network_param_large_bandwidth; + network_param_large_bandwidth.bandwidth = kUnreasonableBandwidth; + network_param_large_bandwidth.rtt = QuicTime::Delta::FromSeconds(1); + sender_->AdjustNetworkParameters(network_param_large_bandwidth); + EXPECT_EQ(kMaxResumptionCongestionWindow * kDefaultTCPMSS, + sender_->GetCongestionWindow()); +} + +TEST_F(TcpCubicSenderBytesTest, PaceBelowCWND) { + QuicConfig config; + + // Verify that kCOPT: kMIN4 forces the min CWND to 1 packet, but allows up + // to 4 to be sent. + QuicTagVector options; + options.push_back(kMIN4); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + sender_->OnRetransmissionTimeout(true); + EXPECT_EQ(kDefaultTCPMSS, sender_->GetCongestionWindow()); + EXPECT_TRUE(sender_->CanSend(kDefaultTCPMSS)); + EXPECT_TRUE(sender_->CanSend(2 * kDefaultTCPMSS)); + EXPECT_TRUE(sender_->CanSend(3 * kDefaultTCPMSS)); + EXPECT_FALSE(sender_->CanSend(4 * kDefaultTCPMSS)); +} + +TEST_F(TcpCubicSenderBytesTest, NoPRR) { + QuicTime::Delta rtt = QuicTime::Delta::FromMilliseconds(100); + sender_->rtt_stats_.UpdateRtt(rtt, QuicTime::Delta::Zero(), QuicTime::Zero()); + + sender_->SetNumEmulatedConnections(1); + // Verify that kCOPT: kNPRR allows all packets to be sent, even if only one + // ack has been received. + QuicTagVector options; + options.push_back(kNPRR); + QuicConfig config; + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + sender_->SetFromConfig(config, Perspective::IS_SERVER); + SendAvailableSendWindow(); + LoseNPackets(9); + AckNPackets(1); + + // We should now have fallen out of slow start with a reduced window. + EXPECT_EQ(kRenoBeta * kDefaultWindowTCP, sender_->GetCongestionWindow()); + const QuicPacketCount window_in_packets = + kRenoBeta * kDefaultWindowTCP / kDefaultTCPMSS; + const QuicBandwidth expected_pacing_rate = + QuicBandwidth::FromBytesAndTimeDelta(kRenoBeta * kDefaultWindowTCP, + sender_->rtt_stats_.smoothed_rtt()); + EXPECT_EQ(expected_pacing_rate, sender_->PacingRate(0)); + EXPECT_EQ(window_in_packets, + static_cast(SendAvailableSendWindow())); + EXPECT_EQ(expected_pacing_rate, + sender_->PacingRate(kRenoBeta * kDefaultWindowTCP)); +} + +TEST_F(TcpCubicSenderBytesTest, ResetAfterConnectionMigration) { + // Starts from slow start. + sender_->SetNumEmulatedConnections(1); + const int kNumberOfAcks = 10; + for (int i = 0; i < kNumberOfAcks; ++i) { + // Send our full send window. + SendAvailableSendWindow(); + AckNPackets(2); + } + SendAvailableSendWindow(); + QuicByteCount expected_send_window = + kDefaultWindowTCP + (kDefaultTCPMSS * 2 * kNumberOfAcks); + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + + // Loses a packet to exit slow start. + LoseNPackets(1); + + // We should now have fallen out of slow start with a reduced window. Slow + // start threshold is also updated. + expected_send_window *= kRenoBeta; + EXPECT_EQ(expected_send_window, sender_->GetCongestionWindow()); + EXPECT_EQ(expected_send_window, sender_->GetSlowStartThreshold()); + + // Resets cwnd and slow start threshold on connection migrations. + sender_->OnConnectionMigration(); + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + EXPECT_EQ(kMaxCongestionWindowPackets * kDefaultTCPMSS, + sender_->GetSlowStartThreshold()); + EXPECT_FALSE(sender_->hybrid_slow_start().started()); +} + +TEST_F(TcpCubicSenderBytesTest, DefaultMaxCwnd) { + RttStats rtt_stats; + QuicConnectionStats stats; + std::unique_ptr sender(SendAlgorithmInterface::Create( + &clock_, &rtt_stats, /*unacked_packets=*/nullptr, kCubicBytes, + QuicRandom::GetInstance(), &stats, kInitialCongestionWindow, nullptr)); + + AckedPacketVector acked_packets; + LostPacketVector missing_packets; + QuicPacketCount max_congestion_window = + GetQuicFlag(quic_max_congestion_window); + for (uint64_t i = 1; i < max_congestion_window; ++i) { + acked_packets.clear(); + acked_packets.push_back( + AckedPacket(QuicPacketNumber(i), 1350, QuicTime::Zero())); + sender->OnCongestionEvent(true, sender->GetCongestionWindow(), clock_.Now(), + acked_packets, missing_packets, 0, 0); + } + EXPECT_EQ(max_congestion_window, + sender->GetCongestionWindow() / kDefaultTCPMSS); +} + +TEST_F(TcpCubicSenderBytesTest, LimitCwndIncreaseInCongestionAvoidance) { + // Enable Cubic. + sender_ = std::make_unique(&clock_, false); + + int num_sent = SendAvailableSendWindow(); + + // Make sure we fall out of slow start. + QuicByteCount saved_cwnd = sender_->GetCongestionWindow(); + LoseNPackets(1); + EXPECT_GT(saved_cwnd, sender_->GetCongestionWindow()); + + // Ack the rest of the outstanding packets to get out of recovery. + for (int i = 1; i < num_sent; ++i) { + AckNPackets(1); + } + EXPECT_EQ(0u, bytes_in_flight_); + // Send a new window of data and ack all; cubic growth should occur. + saved_cwnd = sender_->GetCongestionWindow(); + num_sent = SendAvailableSendWindow(); + + // Ack packets until the CWND increases. + while (sender_->GetCongestionWindow() == saved_cwnd) { + AckNPackets(1); + SendAvailableSendWindow(); + } + // Bytes in flight may be larger than the CWND if the CWND isn't an exact + // multiple of the packet sizes being sent. + EXPECT_GE(bytes_in_flight_, sender_->GetCongestionWindow()); + saved_cwnd = sender_->GetCongestionWindow(); + + // Advance time 2 seconds waiting for an ack. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(2000)); + + // Ack two packets. The CWND should increase by only one packet. + AckNPackets(2); + EXPECT_EQ(saved_cwnd + kDefaultTCPMSS, sender_->GetCongestionWindow()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/uber_loss_algorithm.cc b/quiche/quic/core/congestion_control/uber_loss_algorithm.cc new file mode 100644 index 000000000000..cdd8547a5788 --- /dev/null +++ b/quiche/quic/core/congestion_control/uber_loss_algorithm.cc @@ -0,0 +1,210 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/uber_loss_algorithm.h" + +#include + +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +UberLossAlgorithm::UberLossAlgorithm() { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].Initialize(static_cast(i), + this); + } +} + +void UberLossAlgorithm::SetFromConfig(const QuicConfig& config, + Perspective perspective) { + if (config.HasClientRequestedIndependentOption(kELDT, perspective) && + tuner_ != nullptr) { + tuning_configured_ = true; + MaybeStartTuning(); + } +} + +LossDetectionInterface::DetectionStats UberLossAlgorithm::DetectLosses( + const QuicUnackedPacketMap& unacked_packets, QuicTime time, + const RttStats& rtt_stats, QuicPacketNumber /*largest_newly_acked*/, + const AckedPacketVector& packets_acked, LostPacketVector* packets_lost) { + DetectionStats overall_stats; + + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + const QuicPacketNumber largest_acked = + unacked_packets.GetLargestAckedOfPacketNumberSpace( + static_cast(i)); + if (!largest_acked.IsInitialized() || + unacked_packets.GetLeastUnacked() > largest_acked) { + // Skip detecting losses if no packet has been received for this packet + // number space or the least_unacked is greater than largest_acked. + continue; + } + + DetectionStats stats = general_loss_algorithms_[i].DetectLosses( + unacked_packets, time, rtt_stats, largest_acked, packets_acked, + packets_lost); + + overall_stats.sent_packets_max_sequence_reordering = + std::max(overall_stats.sent_packets_max_sequence_reordering, + stats.sent_packets_max_sequence_reordering); + overall_stats.sent_packets_num_borderline_time_reorderings += + stats.sent_packets_num_borderline_time_reorderings; + overall_stats.total_loss_detection_response_time += + stats.total_loss_detection_response_time; + } + + return overall_stats; +} + +QuicTime UberLossAlgorithm::GetLossTimeout() const { + QuicTime loss_timeout = QuicTime::Zero(); + // Returns the earliest non-zero loss timeout. + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + const QuicTime timeout = general_loss_algorithms_[i].GetLossTimeout(); + if (!loss_timeout.IsInitialized()) { + loss_timeout = timeout; + continue; + } + if (timeout.IsInitialized()) { + loss_timeout = std::min(loss_timeout, timeout); + } + } + return loss_timeout; +} + +void UberLossAlgorithm::SpuriousLossDetected( + const QuicUnackedPacketMap& unacked_packets, const RttStats& rtt_stats, + QuicTime ack_receive_time, QuicPacketNumber packet_number, + QuicPacketNumber previous_largest_acked) { + general_loss_algorithms_[unacked_packets.GetPacketNumberSpace(packet_number)] + .SpuriousLossDetected(unacked_packets, rtt_stats, ack_receive_time, + packet_number, previous_largest_acked); +} + +void UberLossAlgorithm::SetLossDetectionTuner( + std::unique_ptr tuner) { + if (tuner_ != nullptr) { + QUIC_BUG(quic_bug_10469_1) + << "LossDetectionTuner can only be set once when session begins."; + return; + } + tuner_ = std::move(tuner); +} + +void UberLossAlgorithm::MaybeStartTuning() { + if (tuner_started_ || !tuning_configured_ || !min_rtt_available_ || + !user_agent_known_ || !reorder_happened_) { + return; + } + + tuner_started_ = tuner_->Start(&tuned_parameters_); + if (!tuner_started_) { + return; + } + + if (tuned_parameters_.reordering_shift.has_value() && + tuned_parameters_.reordering_threshold.has_value()) { + QUIC_DLOG(INFO) << "Setting reordering shift to " + << *tuned_parameters_.reordering_shift + << ", and reordering threshold to " + << *tuned_parameters_.reordering_threshold; + SetReorderingShift(*tuned_parameters_.reordering_shift); + SetReorderingThreshold(*tuned_parameters_.reordering_threshold); + } else { + QUIC_BUG(quic_bug_10469_2) + << "Tuner started but some parameters are missing"; + } +} + +void UberLossAlgorithm::OnConfigNegotiated() {} + +void UberLossAlgorithm::OnMinRttAvailable() { + min_rtt_available_ = true; + MaybeStartTuning(); +} + +void UberLossAlgorithm::OnUserAgentIdKnown() { + user_agent_known_ = true; + MaybeStartTuning(); +} + +void UberLossAlgorithm::OnConnectionClosed() { + if (tuner_ != nullptr && tuner_started_) { + tuner_->Finish(tuned_parameters_); + } +} + +void UberLossAlgorithm::OnReorderingDetected() { + const bool tuner_started_before = tuner_started_; + const bool reorder_happened_before = reorder_happened_; + + reorder_happened_ = true; + MaybeStartTuning(); + + if (!tuner_started_before && tuner_started_) { + if (reorder_happened_before) { + QUIC_CODE_COUNT(quic_loss_tuner_started_after_first_reorder); + } else { + QUIC_CODE_COUNT(quic_loss_tuner_started_on_first_reorder); + } + } +} + +void UberLossAlgorithm::SetReorderingShift(int reordering_shift) { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].set_reordering_shift(reordering_shift); + } +} + +void UberLossAlgorithm::SetReorderingThreshold( + QuicPacketCount reordering_threshold) { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].set_reordering_threshold(reordering_threshold); + } +} + +void UberLossAlgorithm::EnableAdaptiveReorderingThreshold() { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].set_use_adaptive_reordering_threshold(true); + } +} + +void UberLossAlgorithm::DisableAdaptiveReorderingThreshold() { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].set_use_adaptive_reordering_threshold(false); + } +} + +void UberLossAlgorithm::EnableAdaptiveTimeThreshold() { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].enable_adaptive_time_threshold(); + } +} + +QuicPacketCount UberLossAlgorithm::GetPacketReorderingThreshold() const { + return general_loss_algorithms_[APPLICATION_DATA].reordering_threshold(); +} + +int UberLossAlgorithm::GetPacketReorderingShift() const { + return general_loss_algorithms_[APPLICATION_DATA].reordering_shift(); +} + +void UberLossAlgorithm::DisablePacketThresholdForRuntPackets() { + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + general_loss_algorithms_[i].disable_packet_threshold_for_runt_packets(); + } +} + +void UberLossAlgorithm::ResetLossDetection(PacketNumberSpace space) { + if (space >= NUM_PACKET_NUMBER_SPACES) { + QUIC_BUG(quic_bug_10469_3) << "Invalid packet number space: " << space; + return; + } + general_loss_algorithms_[space].Reset(); +} + +} // namespace quic diff --git a/quiche/quic/core/congestion_control/uber_loss_algorithm.h b/quiche/quic/core/congestion_control/uber_loss_algorithm.h new file mode 100644 index 000000000000..febea73e5e92 --- /dev/null +++ b/quiche/quic/core/congestion_control/uber_loss_algorithm.h @@ -0,0 +1,139 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_UBER_LOSS_ALGORITHM_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_UBER_LOSS_ALGORITHM_H_ + +#include "absl/types/optional.h" +#include "quiche/quic/core/congestion_control/general_loss_algorithm.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace test { + +class QuicSentPacketManagerPeer; + +} // namespace test + +struct QUIC_EXPORT_PRIVATE LossDetectionParameters { + // See GeneralLossAlgorithm for the meaning of reordering_(shift|threshold). + absl::optional reordering_shift; + absl::optional reordering_threshold; +}; + +class QUIC_EXPORT_PRIVATE LossDetectionTunerInterface { + public: + virtual ~LossDetectionTunerInterface() {} + + // Start the tuning by choosing parameters and saving them into |*params|. + // Called near the start of a QUIC session, see the .cc file for exactly + // where. + virtual bool Start(LossDetectionParameters* params) = 0; + + // Finish tuning. The tuner is expected to use the actual loss detection + // performance(for its definition of performance) to improve the parameter + // selection for future QUIC sessions. + // Called when a QUIC session closes. + virtual void Finish(const LossDetectionParameters& params) = 0; +}; + +// This class comprises multiple loss algorithms, each per packet number space. +class QUIC_EXPORT_PRIVATE UberLossAlgorithm : public LossDetectionInterface { + public: + UberLossAlgorithm(); + UberLossAlgorithm(const UberLossAlgorithm&) = delete; + UberLossAlgorithm& operator=(const UberLossAlgorithm&) = delete; + ~UberLossAlgorithm() override {} + + void SetFromConfig(const QuicConfig& config, + Perspective perspective) override; + + // Detects lost packets. + DetectionStats DetectLosses(const QuicUnackedPacketMap& unacked_packets, + QuicTime time, const RttStats& rtt_stats, + QuicPacketNumber largest_newly_acked, + const AckedPacketVector& packets_acked, + LostPacketVector* packets_lost) override; + + // Returns the earliest time the early retransmit timer should be active. + QuicTime GetLossTimeout() const override; + + // Called to increases time or packet threshold. + void SpuriousLossDetected(const QuicUnackedPacketMap& unacked_packets, + const RttStats& rtt_stats, + QuicTime ack_receive_time, + QuicPacketNumber packet_number, + QuicPacketNumber previous_largest_acked) override; + + void SetLossDetectionTuner( + std::unique_ptr tuner); + void OnConfigNegotiated() override; + void OnMinRttAvailable() override; + void OnUserAgentIdKnown() override; + void OnConnectionClosed() override; + void OnReorderingDetected() override; + + // Sets reordering_shift for all packet number spaces. + void SetReorderingShift(int reordering_shift); + + // Sets reordering_threshold for all packet number spaces. + void SetReorderingThreshold(QuicPacketCount reordering_threshold); + + // Enable adaptive reordering threshold of all packet number spaces. + void EnableAdaptiveReorderingThreshold(); + + // Disable adaptive reordering threshold of all packet number spaces. + void DisableAdaptiveReorderingThreshold(); + + // Enable adaptive time threshold of all packet number spaces. + void EnableAdaptiveTimeThreshold(); + + // Get the packet reordering threshold from the APPLICATION_DATA PN space. + // Always 3 when adaptive reordering is not enabled. + QuicPacketCount GetPacketReorderingThreshold() const; + + // Get the packet reordering shift from the APPLICATION_DATA PN space. + int GetPacketReorderingShift() const; + + // Disable packet threshold loss detection for *runt* packets. + void DisablePacketThresholdForRuntPackets(); + + // Called to reset loss detection of |space|. + void ResetLossDetection(PacketNumberSpace space); + + bool use_adaptive_reordering_threshold() const { + return general_loss_algorithms_[APPLICATION_DATA] + .use_adaptive_reordering_threshold(); + } + + bool use_adaptive_time_threshold() const { + return general_loss_algorithms_[APPLICATION_DATA] + .use_adaptive_time_threshold(); + } + + private: + friend class test::QuicSentPacketManagerPeer; + + void MaybeStartTuning(); + + // One loss algorithm per packet number space. + GeneralLossAlgorithm general_loss_algorithms_[NUM_PACKET_NUMBER_SPACES]; + + // Used to tune reordering_shift and reordering_threshold. + std::unique_ptr tuner_; + LossDetectionParameters tuned_parameters_; + bool tuner_started_ = false; + bool min_rtt_available_ = false; + // Whether user agent is known to the session. + bool user_agent_known_ = false; + // Whether tuning is configured in QuicConfig. + bool tuning_configured_ = false; + bool reorder_happened_ = false; // Whether any reordered packet is observed. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_UBER_LOSS_ALGORITHM_H_ diff --git a/quiche/quic/core/congestion_control/uber_loss_algorithm_test.cc b/quiche/quic/core/congestion_control/uber_loss_algorithm_test.cc new file mode 100644 index 000000000000..f84084777c31 --- /dev/null +++ b/quiche/quic/core/congestion_control/uber_loss_algorithm_test.cc @@ -0,0 +1,360 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/uber_loss_algorithm.h" + +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_unacked_packet_map_peer.h" + +namespace quic { +namespace test { +namespace { + +// Default packet length. +const uint32_t kDefaultLength = 1000; + +class UberLossAlgorithmTest : public QuicTest { + protected: + UberLossAlgorithmTest() { + unacked_packets_ = + std::make_unique(Perspective::IS_CLIENT); + rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), clock_.Now()); + EXPECT_LT(0, rtt_stats_.smoothed_rtt().ToMicroseconds()); + } + + void SendPacket(uint64_t packet_number, EncryptionLevel encryption_level) { + QuicStreamFrame frame; + QuicTransportVersion version = + CurrentSupportedVersions()[0].transport_version; + frame.stream_id = QuicUtils::GetFirstBidirectionalStreamId( + version, Perspective::IS_CLIENT); + if (encryption_level == ENCRYPTION_INITIAL) { + if (QuicVersionUsesCryptoFrames(version)) { + frame.stream_id = QuicUtils::GetFirstBidirectionalStreamId( + version, Perspective::IS_CLIENT); + } else { + frame.stream_id = QuicUtils::GetCryptoStreamId(version); + } + } + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_1BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + false, false); + packet.encryption_level = encryption_level; + packet.retransmittable_frames.push_back(QuicFrame(frame)); + unacked_packets_->AddSentPacket(&packet, NOT_RETRANSMISSION, clock_.Now(), + true, true, ECN_NOT_ECT); + } + + void AckPackets(const std::vector& packets_acked) { + packets_acked_.clear(); + for (uint64_t acked : packets_acked) { + unacked_packets_->RemoveFromInFlight(QuicPacketNumber(acked)); + packets_acked_.push_back(AckedPacket( + QuicPacketNumber(acked), kMaxOutgoingPacketSize, QuicTime::Zero())); + } + } + + void VerifyLosses(uint64_t largest_newly_acked, + const AckedPacketVector& packets_acked, + const std::vector& losses_expected) { + return VerifyLosses(largest_newly_acked, packets_acked, losses_expected, + absl::nullopt); + } + + void VerifyLosses( + uint64_t largest_newly_acked, const AckedPacketVector& packets_acked, + const std::vector& losses_expected, + absl::optional max_sequence_reordering_expected) { + LostPacketVector lost_packets; + LossDetectionInterface::DetectionStats stats = loss_algorithm_.DetectLosses( + *unacked_packets_, clock_.Now(), rtt_stats_, + QuicPacketNumber(largest_newly_acked), packets_acked, &lost_packets); + if (max_sequence_reordering_expected.has_value()) { + EXPECT_EQ(stats.sent_packets_max_sequence_reordering, + max_sequence_reordering_expected.value()); + } + ASSERT_EQ(losses_expected.size(), lost_packets.size()); + for (size_t i = 0; i < losses_expected.size(); ++i) { + EXPECT_EQ(lost_packets[i].packet_number, + QuicPacketNumber(losses_expected[i])); + } + } + + MockClock clock_; + std::unique_ptr unacked_packets_; + RttStats rtt_stats_; + UberLossAlgorithm loss_algorithm_; + AckedPacketVector packets_acked_; +}; + +TEST_F(UberLossAlgorithmTest, ScenarioA) { + // This test mimics a scenario: client sends 1-CHLO, 2-0RTT, 3-0RTT, + // timeout and retransmits 4-CHLO. Server acks packet 1 (ack gets lost). + // Server receives and buffers packets 2 and 3. Server receives packet 4 and + // processes handshake asynchronously, so server acks 4 and cannot process + // packets 2 and 3. + SendPacket(1, ENCRYPTION_INITIAL); + SendPacket(2, ENCRYPTION_ZERO_RTT); + SendPacket(3, ENCRYPTION_ZERO_RTT); + unacked_packets_->RemoveFromInFlight(QuicPacketNumber(1)); + SendPacket(4, ENCRYPTION_INITIAL); + + AckPackets({1, 4}); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + HANDSHAKE_DATA, QuicPacketNumber(4)); + // Verify no packet is detected lost. + VerifyLosses(4, packets_acked_, std::vector{}, 0); + EXPECT_EQ(QuicTime::Zero(), loss_algorithm_.GetLossTimeout()); +} + +TEST_F(UberLossAlgorithmTest, ScenarioB) { + // This test mimics a scenario: client sends 3-0RTT, 4-0RTT, receives SHLO, + // sends 5-1RTT, 6-1RTT. + SendPacket(3, ENCRYPTION_ZERO_RTT); + SendPacket(4, ENCRYPTION_ZERO_RTT); + SendPacket(5, ENCRYPTION_FORWARD_SECURE); + SendPacket(6, ENCRYPTION_FORWARD_SECURE); + + AckPackets({4}); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(4)); + // No packet loss by acking 4. + VerifyLosses(4, packets_acked_, std::vector{}, 1); + EXPECT_EQ(clock_.Now() + 1.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + + // Acking 6 causes 3 to be detected loss. + AckPackets({6}); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(6)); + VerifyLosses(6, packets_acked_, std::vector{3}, 3); + EXPECT_EQ(clock_.Now() + 1.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + packets_acked_.clear(); + + clock_.AdvanceTime(1.25 * rtt_stats_.latest_rtt()); + // Verify 5 will be early retransmitted. + VerifyLosses(6, packets_acked_, {5}, 1); +} + +TEST_F(UberLossAlgorithmTest, ScenarioC) { + // This test mimics a scenario: server sends 1-SHLO, 2-1RTT, 3-1RTT, 4-1RTT + // and retransmit 4-SHLO. Client receives and buffers packet 4. Client + // receives packet 5 and processes 4. + QuicUnackedPacketMapPeer::SetPerspective(unacked_packets_.get(), + Perspective::IS_SERVER); + SendPacket(1, ENCRYPTION_ZERO_RTT); + SendPacket(2, ENCRYPTION_FORWARD_SECURE); + SendPacket(3, ENCRYPTION_FORWARD_SECURE); + SendPacket(4, ENCRYPTION_FORWARD_SECURE); + unacked_packets_->RemoveFromInFlight(QuicPacketNumber(1)); + SendPacket(5, ENCRYPTION_ZERO_RTT); + + AckPackets({4, 5}); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(4)); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + HANDSHAKE_DATA, QuicPacketNumber(5)); + // No packet loss by acking 5. + VerifyLosses(5, packets_acked_, std::vector{}, 2); + EXPECT_EQ(clock_.Now() + 1.25 * rtt_stats_.smoothed_rtt(), + loss_algorithm_.GetLossTimeout()); + packets_acked_.clear(); + + clock_.AdvanceTime(1.25 * rtt_stats_.latest_rtt()); + // Verify 2 and 3 will be early retransmitted. + VerifyLosses(5, packets_acked_, std::vector{2, 3}, 2); +} + +// Regression test for b/133771183. +TEST_F(UberLossAlgorithmTest, PacketInLimbo) { + // This test mimics a scenario: server sends 1-SHLO, 2-1RTT, 3-1RTT, + // 4-retransmit SHLO. Client receives and ACKs packets 1, 3 and 4. + QuicUnackedPacketMapPeer::SetPerspective(unacked_packets_.get(), + Perspective::IS_SERVER); + + SendPacket(1, ENCRYPTION_ZERO_RTT); + SendPacket(2, ENCRYPTION_FORWARD_SECURE); + SendPacket(3, ENCRYPTION_FORWARD_SECURE); + SendPacket(4, ENCRYPTION_ZERO_RTT); + + SendPacket(5, ENCRYPTION_FORWARD_SECURE); + AckPackets({1, 3, 4}); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(3)); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + HANDSHAKE_DATA, QuicPacketNumber(4)); + // No packet loss detected. + VerifyLosses(4, packets_acked_, std::vector{}); + + SendPacket(6, ENCRYPTION_FORWARD_SECURE); + AckPackets({5, 6}); + unacked_packets_->MaybeUpdateLargestAckedOfPacketNumberSpace( + APPLICATION_DATA, QuicPacketNumber(6)); + // Verify packet 2 is detected lost. + VerifyLosses(6, packets_acked_, std::vector{2}); +} + +class TestLossTuner : public LossDetectionTunerInterface { + public: + TestLossTuner(bool forced_start_result, + LossDetectionParameters forced_parameters) + : forced_start_result_(forced_start_result), + forced_parameters_(std::move(forced_parameters)) {} + + ~TestLossTuner() override = default; + + bool Start(LossDetectionParameters* params) override { + start_called_ = true; + *params = forced_parameters_; + return forced_start_result_; + } + + void Finish(const LossDetectionParameters& /*params*/) override {} + + bool start_called() const { return start_called_; } + + private: + bool forced_start_result_; + LossDetectionParameters forced_parameters_; + bool start_called_ = false; +}; + +// Verify the parameters are changed if first call SetFromConfig(), then call +// OnMinRttAvailable(). +TEST_F(UberLossAlgorithmTest, LossDetectionTuning_SetFromConfigFirst) { + const int old_reordering_shift = loss_algorithm_.GetPacketReorderingShift(); + const QuicPacketCount old_reordering_threshold = + loss_algorithm_.GetPacketReorderingThreshold(); + + loss_algorithm_.OnUserAgentIdKnown(); + + // Not owned. + TestLossTuner* test_tuner = new TestLossTuner( + /*forced_start_result=*/true, + LossDetectionParameters{ + /*reordering_shift=*/old_reordering_shift + 1, + /*reordering_threshold=*/old_reordering_threshold * 2}); + loss_algorithm_.SetLossDetectionTuner( + std::unique_ptr(test_tuner)); + + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kELDT); + config.SetInitialReceivedConnectionOptions(connection_options); + loss_algorithm_.SetFromConfig(config, Perspective::IS_SERVER); + + // MinRtt was not available when SetFromConfig was called. + EXPECT_FALSE(test_tuner->start_called()); + EXPECT_EQ(old_reordering_shift, loss_algorithm_.GetPacketReorderingShift()); + EXPECT_EQ(old_reordering_threshold, + loss_algorithm_.GetPacketReorderingThreshold()); + + // MinRtt available. Tuner should not start yet because no reordering yet. + loss_algorithm_.OnMinRttAvailable(); + EXPECT_FALSE(test_tuner->start_called()); + + // Reordering happened. Tuner should start now. + loss_algorithm_.OnReorderingDetected(); + EXPECT_TRUE(test_tuner->start_called()); + EXPECT_NE(old_reordering_shift, loss_algorithm_.GetPacketReorderingShift()); + EXPECT_NE(old_reordering_threshold, + loss_algorithm_.GetPacketReorderingThreshold()); +} + +// Verify the parameters are changed if first call OnMinRttAvailable(), then +// call SetFromConfig(). +TEST_F(UberLossAlgorithmTest, LossDetectionTuning_OnMinRttAvailableFirst) { + const int old_reordering_shift = loss_algorithm_.GetPacketReorderingShift(); + const QuicPacketCount old_reordering_threshold = + loss_algorithm_.GetPacketReorderingThreshold(); + + loss_algorithm_.OnUserAgentIdKnown(); + + // Not owned. + TestLossTuner* test_tuner = new TestLossTuner( + /*forced_start_result=*/true, + LossDetectionParameters{ + /*reordering_shift=*/old_reordering_shift + 1, + /*reordering_threshold=*/old_reordering_threshold * 2}); + loss_algorithm_.SetLossDetectionTuner( + std::unique_ptr(test_tuner)); + + loss_algorithm_.OnMinRttAvailable(); + EXPECT_FALSE(test_tuner->start_called()); + EXPECT_EQ(old_reordering_shift, loss_algorithm_.GetPacketReorderingShift()); + EXPECT_EQ(old_reordering_threshold, + loss_algorithm_.GetPacketReorderingThreshold()); + + // Pretend a reodering has happened. + loss_algorithm_.OnReorderingDetected(); + EXPECT_FALSE(test_tuner->start_called()); + + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kELDT); + config.SetInitialReceivedConnectionOptions(connection_options); + // Should start tuning since MinRtt is available. + loss_algorithm_.SetFromConfig(config, Perspective::IS_SERVER); + + EXPECT_TRUE(test_tuner->start_called()); + EXPECT_NE(old_reordering_shift, loss_algorithm_.GetPacketReorderingShift()); + EXPECT_NE(old_reordering_threshold, + loss_algorithm_.GetPacketReorderingThreshold()); +} + +// Verify the parameters are not changed if Tuner.Start() returns false. +TEST_F(UberLossAlgorithmTest, LossDetectionTuning_StartFailed) { + const int old_reordering_shift = loss_algorithm_.GetPacketReorderingShift(); + const QuicPacketCount old_reordering_threshold = + loss_algorithm_.GetPacketReorderingThreshold(); + + loss_algorithm_.OnUserAgentIdKnown(); + + // Not owned. + TestLossTuner* test_tuner = new TestLossTuner( + /*forced_start_result=*/false, + LossDetectionParameters{ + /*reordering_shift=*/old_reordering_shift + 1, + /*reordering_threshold=*/old_reordering_threshold * 2}); + loss_algorithm_.SetLossDetectionTuner( + std::unique_ptr(test_tuner)); + + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kELDT); + config.SetInitialReceivedConnectionOptions(connection_options); + loss_algorithm_.SetFromConfig(config, Perspective::IS_SERVER); + + // MinRtt was not available when SetFromConfig was called. + EXPECT_FALSE(test_tuner->start_called()); + EXPECT_EQ(old_reordering_shift, loss_algorithm_.GetPacketReorderingShift()); + EXPECT_EQ(old_reordering_threshold, + loss_algorithm_.GetPacketReorderingThreshold()); + + // Pretend a reodering has happened. + loss_algorithm_.OnReorderingDetected(); + EXPECT_FALSE(test_tuner->start_called()); + + // Parameters should not change since test_tuner->Start() returns false. + loss_algorithm_.OnMinRttAvailable(); + EXPECT_TRUE(test_tuner->start_called()); + EXPECT_EQ(old_reordering_shift, loss_algorithm_.GetPacketReorderingShift()); + EXPECT_EQ(old_reordering_threshold, + loss_algorithm_.GetPacketReorderingThreshold()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/congestion_control/windowed_filter.h b/quiche/quic/core/congestion_control/windowed_filter.h new file mode 100644 index 000000000000..9326a08d85b8 --- /dev/null +++ b/quiche/quic/core/congestion_control/windowed_filter.h @@ -0,0 +1,164 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONGESTION_CONTROL_WINDOWED_FILTER_H_ +#define QUICHE_QUIC_CORE_CONGESTION_CONTROL_WINDOWED_FILTER_H_ + +// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum) +// estimate of a stream of samples over some fixed time interval. (E.g., +// the minimum RTT over the past five minutes.) The algorithm keeps track of +// the best, second best, and third best min (or max) estimates, maintaining an +// invariant that the measurement time of the n'th best >= n-1'th best. + +// The algorithm works as follows. On a reset, all three estimates are set to +// the same sample. The second best estimate is then recorded in the second +// quarter of the window, and a third best estimate is recorded in the second +// half of the window, bounding the worst case error when the true min is +// monotonically increasing (or true max is monotonically decreasing) over the +// window. +// +// A new best sample replaces all three estimates, since the new best is lower +// (or higher) than everything else in the window and it is the most recent. +// The window thus effectively gets reset on every new min. The same property +// holds true for second best and third best estimates. Specifically, when a +// sample arrives that is better than the second best but not better than the +// best, it replaces the second and third best estimates but not the best +// estimate. Similarly, a sample that is better than the third best estimate +// but not the other estimates replaces only the third best estimate. +// +// Finally, when the best expires, it is replaced by the second best, which in +// turn is replaced by the third best. The newest sample replaces the third +// best. + +#include "quiche/quic/core/quic_time.h" + +namespace quic { + +// Compares two values and returns true if the first is less than or equal +// to the second. +template +struct QUIC_EXPORT_PRIVATE MinFilter { + bool operator()(const T& lhs, const T& rhs) const { return lhs <= rhs; } +}; + +// Compares two values and returns true if the first is greater than or equal +// to the second. +template +struct QUIC_EXPORT_PRIVATE MaxFilter { + bool operator()(const T& lhs, const T& rhs) const { return lhs >= rhs; } +}; + +// Use the following to construct a windowed filter object of type T. +// For example, a min filter using QuicTime as the time type: +// WindowedFilter, QuicTime, QuicTime::Delta> ObjectName; +// A max filter using 64-bit integers as the time type: +// WindowedFilter, uint64_t, int64_t> ObjectName; +// Specifically, this template takes four arguments: +// 1. T -- type of the measurement that is being filtered. +// 2. Compare -- MinFilter or MaxFilter, depending on the type of filter +// desired. +// 3. TimeT -- the type used to represent timestamps. +// 4. TimeDeltaT -- the type used to represent continuous time intervals between +// two timestamps. Has to be the type of (a - b) if both |a| and |b| are +// of type TimeT. +template +class QUIC_EXPORT_PRIVATE WindowedFilter { + public: + // |window_length| is the period after which a best estimate expires. + // |zero_value| is used as the uninitialized value for objects of T. + // Importantly, |zero_value| should be an invalid value for a true sample. + WindowedFilter(TimeDeltaT window_length, T zero_value, TimeT zero_time) + : window_length_(window_length), + zero_value_(zero_value), + zero_time_(zero_time), + estimates_{Sample(zero_value_, zero_time), + Sample(zero_value_, zero_time), + Sample(zero_value_, zero_time)} {} + + // Changes the window length. Does not update any current samples. + void SetWindowLength(TimeDeltaT window_length) { + window_length_ = window_length; + } + + // Updates best estimates with |sample|, and expires and updates best + // estimates as necessary. + void Update(T new_sample, TimeT new_time) { + // Reset all estimates if they have not yet been initialized, if new sample + // is a new best, or if the newest recorded estimate is too old. + if (estimates_[0].sample == zero_value_ || + Compare()(new_sample, estimates_[0].sample) || + new_time - estimates_[2].time > window_length_) { + Reset(new_sample, new_time); + return; + } + + if (Compare()(new_sample, estimates_[1].sample)) { + estimates_[1] = Sample(new_sample, new_time); + estimates_[2] = estimates_[1]; + } else if (Compare()(new_sample, estimates_[2].sample)) { + estimates_[2] = Sample(new_sample, new_time); + } + + // Expire and update estimates as necessary. + if (new_time - estimates_[0].time > window_length_) { + // The best estimate hasn't been updated for an entire window, so promote + // second and third best estimates. + estimates_[0] = estimates_[1]; + estimates_[1] = estimates_[2]; + estimates_[2] = Sample(new_sample, new_time); + // Need to iterate one more time. Check if the new best estimate is + // outside the window as well, since it may also have been recorded a + // long time ago. Don't need to iterate once more since we cover that + // case at the beginning of the method. + if (new_time - estimates_[0].time > window_length_) { + estimates_[0] = estimates_[1]; + estimates_[1] = estimates_[2]; + } + return; + } + if (estimates_[1].sample == estimates_[0].sample && + new_time - estimates_[1].time > window_length_ >> 2) { + // A quarter of the window has passed without a better sample, so the + // second-best estimate is taken from the second quarter of the window. + estimates_[2] = estimates_[1] = Sample(new_sample, new_time); + return; + } + + if (estimates_[2].sample == estimates_[1].sample && + new_time - estimates_[2].time > window_length_ >> 1) { + // We've passed a half of the window without a better estimate, so take + // a third-best estimate from the second half of the window. + estimates_[2] = Sample(new_sample, new_time); + } + } + + // Resets all estimates to new sample. + void Reset(T new_sample, TimeT new_time) { + estimates_[0] = estimates_[1] = estimates_[2] = + Sample(new_sample, new_time); + } + + void Clear() { Reset(zero_value_, zero_time_); } + + T GetBest() const { return estimates_[0].sample; } + T GetSecondBest() const { return estimates_[1].sample; } + T GetThirdBest() const { return estimates_[2].sample; } + + private: + struct QUIC_EXPORT_PRIVATE Sample { + T sample; + TimeT time; + Sample(T init_sample, TimeT init_time) + : sample(init_sample), time(init_time) {} + }; + + TimeDeltaT window_length_; // Time length of window. + T zero_value_; // Uninitialized value of T. + TimeT zero_time_; // Uninitialized value of TimeT. + Sample estimates_[3]; // Best estimate is element 0. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONGESTION_CONTROL_WINDOWED_FILTER_H_ diff --git a/quiche/quic/core/congestion_control/windowed_filter_test.cc b/quiche/quic/core/congestion_control/windowed_filter_test.cc new file mode 100644 index 000000000000..e984cf8f691c --- /dev/null +++ b/quiche/quic/core/congestion_control/windowed_filter_test.cc @@ -0,0 +1,381 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/congestion_control/windowed_filter.h" + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class WindowedFilterTest : public QuicTest { + public: + // Set the window to 99ms, so 25ms is more than a quarter rtt. + WindowedFilterTest() + : windowed_min_rtt_(QuicTime::Delta::FromMilliseconds(99), + QuicTime::Delta::Zero(), QuicTime::Zero()), + windowed_max_bw_(QuicTime::Delta::FromMilliseconds(99), + QuicBandwidth::Zero(), QuicTime::Zero()) {} + + // Sets up windowed_min_rtt_ to have the following values: + // Best = 20ms, recorded at 25ms + // Second best = 40ms, recorded at 75ms + // Third best = 50ms, recorded at 100ms + void InitializeMinFilter() { + QuicTime now = QuicTime::Zero(); + QuicTime::Delta rtt_sample = QuicTime::Delta::FromMilliseconds(10); + for (int i = 0; i < 5; ++i) { + windowed_min_rtt_.Update(rtt_sample, now); + QUIC_VLOG(1) << "i: " << i << " sample: " << rtt_sample.ToMilliseconds() + << " mins: " + << " " << windowed_min_rtt_.GetBest().ToMilliseconds() << " " + << windowed_min_rtt_.GetSecondBest().ToMilliseconds() << " " + << windowed_min_rtt_.GetThirdBest().ToMilliseconds(); + now = now + QuicTime::Delta::FromMilliseconds(25); + rtt_sample = rtt_sample + QuicTime::Delta::FromMilliseconds(10); + } + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), + windowed_min_rtt_.GetBest()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(40), + windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(50), + windowed_min_rtt_.GetThirdBest()); + } + + // Sets up windowed_max_bw_ to have the following values: + // Best = 900 bps, recorded at 25ms + // Second best = 700 bps, recorded at 75ms + // Third best = 600 bps, recorded at 100ms + void InitializeMaxFilter() { + QuicTime now = QuicTime::Zero(); + QuicBandwidth bw_sample = QuicBandwidth::FromBitsPerSecond(1000); + for (int i = 0; i < 5; ++i) { + windowed_max_bw_.Update(bw_sample, now); + QUIC_VLOG(1) << "i: " << i << " sample: " << bw_sample.ToBitsPerSecond() + << " maxs: " + << " " << windowed_max_bw_.GetBest().ToBitsPerSecond() << " " + << windowed_max_bw_.GetSecondBest().ToBitsPerSecond() << " " + << windowed_max_bw_.GetThirdBest().ToBitsPerSecond(); + now = now + QuicTime::Delta::FromMilliseconds(25); + bw_sample = bw_sample - QuicBandwidth::FromBitsPerSecond(100); + } + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(900), + windowed_max_bw_.GetBest()); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(700), + windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(600), + windowed_max_bw_.GetThirdBest()); + } + + protected: + WindowedFilter, QuicTime, + QuicTime::Delta> + windowed_min_rtt_; + WindowedFilter, QuicTime, + QuicTime::Delta> + windowed_max_bw_; +}; + +namespace { +// Test helper function: updates the filter with a lot of small values in order +// to ensure that it is not susceptible to noise. +void UpdateWithIrrelevantSamples( + WindowedFilter, uint64_t, uint64_t>* filter, + uint64_t max_value, uint64_t time) { + for (uint64_t i = 0; i < 1000; i++) { + filter->Update(i % max_value, time); + } +} +} // namespace + +TEST_F(WindowedFilterTest, UninitializedEstimates) { + EXPECT_EQ(QuicTime::Delta::Zero(), windowed_min_rtt_.GetBest()); + EXPECT_EQ(QuicTime::Delta::Zero(), windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(QuicTime::Delta::Zero(), windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(QuicBandwidth::Zero(), windowed_max_bw_.GetBest()); + EXPECT_EQ(QuicBandwidth::Zero(), windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(QuicBandwidth::Zero(), windowed_max_bw_.GetThirdBest()); +} + +TEST_F(WindowedFilterTest, MonotonicallyIncreasingMin) { + QuicTime now = QuicTime::Zero(); + QuicTime::Delta rtt_sample = QuicTime::Delta::FromMilliseconds(10); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), windowed_min_rtt_.GetBest()); + + // Gradually increase the rtt samples and ensure the windowed min rtt starts + // rising. + for (int i = 0; i < 6; ++i) { + now = now + QuicTime::Delta::FromMilliseconds(25); + rtt_sample = rtt_sample + QuicTime::Delta::FromMilliseconds(10); + windowed_min_rtt_.Update(rtt_sample, now); + QUIC_VLOG(1) << "i: " << i << " sample: " << rtt_sample.ToMilliseconds() + << " mins: " + << " " << windowed_min_rtt_.GetBest().ToMilliseconds() << " " + << windowed_min_rtt_.GetSecondBest().ToMilliseconds() << " " + << windowed_min_rtt_.GetThirdBest().ToMilliseconds(); + if (i < 3) { + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), + windowed_min_rtt_.GetBest()); + } else if (i == 3) { + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), + windowed_min_rtt_.GetBest()); + } else if (i < 6) { + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(40), + windowed_min_rtt_.GetBest()); + } + } +} + +TEST_F(WindowedFilterTest, MonotonicallyDecreasingMax) { + QuicTime now = QuicTime::Zero(); + QuicBandwidth bw_sample = QuicBandwidth::FromBitsPerSecond(1000); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(1000), windowed_max_bw_.GetBest()); + + // Gradually decrease the bw samples and ensure the windowed max bw starts + // decreasing. + for (int i = 0; i < 6; ++i) { + now = now + QuicTime::Delta::FromMilliseconds(25); + bw_sample = bw_sample - QuicBandwidth::FromBitsPerSecond(100); + windowed_max_bw_.Update(bw_sample, now); + QUIC_VLOG(1) << "i: " << i << " sample: " << bw_sample.ToBitsPerSecond() + << " maxs: " + << " " << windowed_max_bw_.GetBest().ToBitsPerSecond() << " " + << windowed_max_bw_.GetSecondBest().ToBitsPerSecond() << " " + << windowed_max_bw_.GetThirdBest().ToBitsPerSecond(); + if (i < 3) { + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(1000), + windowed_max_bw_.GetBest()); + } else if (i == 3) { + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(900), + windowed_max_bw_.GetBest()); + } else if (i < 6) { + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(700), + windowed_max_bw_.GetBest()); + } + } +} + +TEST_F(WindowedFilterTest, SampleChangesThirdBestMin) { + InitializeMinFilter(); + // RTT sample lower than the third-choice min-rtt sets that, but nothing else. + QuicTime::Delta rtt_sample = + windowed_min_rtt_.GetThirdBest() - QuicTime::Delta::FromMilliseconds(5); + // This assert is necessary to avoid triggering -Wstrict-overflow + // See crbug/616957 + ASSERT_GT(windowed_min_rtt_.GetThirdBest(), + QuicTime::Delta::FromMilliseconds(5)); + // Latest sample was recorded at 100ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(101); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(40), + windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), windowed_min_rtt_.GetBest()); +} + +TEST_F(WindowedFilterTest, SampleChangesThirdBestMax) { + InitializeMaxFilter(); + // BW sample higher than the third-choice max sets that, but nothing else. + QuicBandwidth bw_sample = + windowed_max_bw_.GetThirdBest() + QuicBandwidth::FromBitsPerSecond(50); + // Latest sample was recorded at 100ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(101); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetThirdBest()); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(700), + windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(900), windowed_max_bw_.GetBest()); +} + +TEST_F(WindowedFilterTest, SampleChangesSecondBestMin) { + InitializeMinFilter(); + // RTT sample lower than the second-choice min sets that and also + // the third-choice min. + QuicTime::Delta rtt_sample = + windowed_min_rtt_.GetSecondBest() - QuicTime::Delta::FromMilliseconds(5); + // This assert is necessary to avoid triggering -Wstrict-overflow + // See crbug/616957 + ASSERT_GT(windowed_min_rtt_.GetSecondBest(), + QuicTime::Delta::FromMilliseconds(5)); + // Latest sample was recorded at 100ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(101); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), windowed_min_rtt_.GetBest()); +} + +TEST_F(WindowedFilterTest, SampleChangesSecondBestMax) { + InitializeMaxFilter(); + // BW sample higher than the second-choice max sets that and also + // the third-choice max. + QuicBandwidth bw_sample = + windowed_max_bw_.GetSecondBest() + QuicBandwidth::FromBitsPerSecond(50); + // Latest sample was recorded at 100ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(101); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetThirdBest()); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(900), windowed_max_bw_.GetBest()); +} + +TEST_F(WindowedFilterTest, SampleChangesAllMins) { + InitializeMinFilter(); + // RTT sample lower than the first-choice min-rtt sets that and also + // the second and third-choice mins. + QuicTime::Delta rtt_sample = + windowed_min_rtt_.GetBest() - QuicTime::Delta::FromMilliseconds(5); + // This assert is necessary to avoid triggering -Wstrict-overflow + // See crbug/616957 + ASSERT_GT(windowed_min_rtt_.GetBest(), QuicTime::Delta::FromMilliseconds(5)); + // Latest sample was recorded at 100ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(101); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetBest()); +} + +TEST_F(WindowedFilterTest, SampleChangesAllMaxs) { + InitializeMaxFilter(); + // BW sample higher than the first-choice max sets that and also + // the second and third-choice maxs. + QuicBandwidth bw_sample = + windowed_max_bw_.GetBest() + QuicBandwidth::FromBitsPerSecond(50); + // Latest sample was recorded at 100ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(101); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetThirdBest()); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetBest()); +} + +TEST_F(WindowedFilterTest, ExpireBestMin) { + InitializeMinFilter(); + QuicTime::Delta old_third_best = windowed_min_rtt_.GetThirdBest(); + QuicTime::Delta old_second_best = windowed_min_rtt_.GetSecondBest(); + QuicTime::Delta rtt_sample = + old_third_best + QuicTime::Delta::FromMilliseconds(5); + // Best min sample was recorded at 25ms, so expiry time is 124ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(125); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(old_third_best, windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(old_second_best, windowed_min_rtt_.GetBest()); +} + +TEST_F(WindowedFilterTest, ExpireBestMax) { + InitializeMaxFilter(); + QuicBandwidth old_third_best = windowed_max_bw_.GetThirdBest(); + QuicBandwidth old_second_best = windowed_max_bw_.GetSecondBest(); + QuicBandwidth bw_sample = + old_third_best - QuicBandwidth::FromBitsPerSecond(50); + // Best max sample was recorded at 25ms, so expiry time is 124ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(125); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetThirdBest()); + EXPECT_EQ(old_third_best, windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(old_second_best, windowed_max_bw_.GetBest()); +} + +TEST_F(WindowedFilterTest, ExpireSecondBestMin) { + InitializeMinFilter(); + QuicTime::Delta old_third_best = windowed_min_rtt_.GetThirdBest(); + QuicTime::Delta rtt_sample = + old_third_best + QuicTime::Delta::FromMilliseconds(5); + // Second best min sample was recorded at 75ms, so expiry time is 174ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(175); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(old_third_best, windowed_min_rtt_.GetBest()); +} + +TEST_F(WindowedFilterTest, ExpireSecondBestMax) { + InitializeMaxFilter(); + QuicBandwidth old_third_best = windowed_max_bw_.GetThirdBest(); + QuicBandwidth bw_sample = + old_third_best - QuicBandwidth::FromBitsPerSecond(50); + // Second best max sample was recorded at 75ms, so expiry time is 174ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(175); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetThirdBest()); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(old_third_best, windowed_max_bw_.GetBest()); +} + +TEST_F(WindowedFilterTest, ExpireAllMins) { + InitializeMinFilter(); + QuicTime::Delta rtt_sample = + windowed_min_rtt_.GetThirdBest() + QuicTime::Delta::FromMilliseconds(5); + // This assert is necessary to avoid triggering -Wstrict-overflow + // See crbug/616957 + ASSERT_LT(windowed_min_rtt_.GetThirdBest(), + QuicTime::Delta::Infinite() - QuicTime::Delta::FromMilliseconds(5)); + // Third best min sample was recorded at 100ms, so expiry time is 199ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(200); + windowed_min_rtt_.Update(rtt_sample, now); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetThirdBest()); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetSecondBest()); + EXPECT_EQ(rtt_sample, windowed_min_rtt_.GetBest()); +} + +TEST_F(WindowedFilterTest, ExpireAllMaxs) { + InitializeMaxFilter(); + QuicBandwidth bw_sample = + windowed_max_bw_.GetThirdBest() - QuicBandwidth::FromBitsPerSecond(50); + // Third best max sample was recorded at 100ms, so expiry time is 199ms. + QuicTime now = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(200); + windowed_max_bw_.Update(bw_sample, now); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetThirdBest()); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetSecondBest()); + EXPECT_EQ(bw_sample, windowed_max_bw_.GetBest()); +} + +// Test the windowed filter where the time used is an exact counter instead of a +// timestamp. This is useful if, for example, the time is measured in round +// trips. +TEST_F(WindowedFilterTest, ExpireCounterBasedMax) { + // Create a window which starts at t = 0 and expires after two cycles. + WindowedFilter, uint64_t, uint64_t> max_filter( + 2, 0, 0); + + const uint64_t kBest = 50000; + // Insert 50000 at t = 1. + max_filter.Update(50000, 1); + EXPECT_EQ(kBest, max_filter.GetBest()); + UpdateWithIrrelevantSamples(&max_filter, 20, 1); + EXPECT_EQ(kBest, max_filter.GetBest()); + + // Insert 40000 at t = 2. Nothing is expected to expire. + max_filter.Update(40000, 2); + EXPECT_EQ(kBest, max_filter.GetBest()); + UpdateWithIrrelevantSamples(&max_filter, 20, 2); + EXPECT_EQ(kBest, max_filter.GetBest()); + + // Insert 30000 at t = 3. Nothing is expected to expire yet. + max_filter.Update(30000, 3); + EXPECT_EQ(kBest, max_filter.GetBest()); + UpdateWithIrrelevantSamples(&max_filter, 20, 3); + EXPECT_EQ(kBest, max_filter.GetBest()); + QUIC_VLOG(0) << max_filter.GetSecondBest(); + QUIC_VLOG(0) << max_filter.GetThirdBest(); + + // Insert 20000 at t = 4. 50000 at t = 1 expires, so 40000 becomes the new + // maximum. + const uint64_t kNewBest = 40000; + max_filter.Update(20000, 4); + EXPECT_EQ(kNewBest, max_filter.GetBest()); + UpdateWithIrrelevantSamples(&max_filter, 20, 4); + EXPECT_EQ(kNewBest, max_filter.GetBest()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/connecting_client_socket.h b/quiche/quic/core/connecting_client_socket.h new file mode 100644 index 000000000000..0670d919b4ca --- /dev/null +++ b/quiche/quic/core/connecting_client_socket.h @@ -0,0 +1,111 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONNECTING_CLIENT_SOCKET_H_ +#define QUICHE_QUIC_CORE_CONNECTING_CLIENT_SOCKET_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quic { + +// A client socket that provides connection-based send/receive. In the case of +// protocols like UDP, may only be a pseudo-connection that doesn't actually +// affect the underlying network protocol. +// +// Must not destroy a connected/connecting socket. If connected or connecting, +// must call Disconnect() to disconnect or cancel the connection before +// destruction. +// +// Warning regarding blocking calls: Code in the QUICHE library typically +// handles IO on a single thread, so if making calls from that typical +// environment, it would be problematic to make a blocking call and block that +// single thread. +class QUICHE_EXPORT ConnectingClientSocket { + public: + class AsyncVisitor { + public: + virtual ~AsyncVisitor() = default; + + virtual void ConnectComplete(absl::Status status) = 0; + + // If the operation completed without error, `data` is set to the received + // data. + virtual void ReceiveComplete( + absl::StatusOr data) = 0; + + virtual void SendComplete(absl::Status status) = 0; + }; + + virtual ~ConnectingClientSocket() = default; + + // Establishes a connection synchronously. Should not be called if socket has + // already been successfully connected without first calling Disconnect(). + // + // After calling, the socket must not be destroyed until Disconnect() is + // called. + virtual absl::Status ConnectBlocking() = 0; + + // Establishes a connection asynchronously. On completion, calls + // ConnectComplete() on the visitor, potentially before return from + // ConnectAsync(). Should not be called if socket has already been + // successfully connected without first calling Disconnect(). + // + // After calling, the socket must not be destroyed until Disconnect() is + // called. + virtual void ConnectAsync() = 0; + + // Disconnects a connected socket or cancels an in-progress ConnectAsync(), + // invoking the `ConnectComplete(absl::CancelledError())` on the visitor. + // After success, it is possible to call ConnectBlocking() or ConnectAsync() + // again to establish a new connection. Cancels any pending read or write + // operations, calling visitor completion methods with + // `absl::CancelledError()`. + // + // Typically implemented via a call to ::close(), which for TCP can result in + // either FIN or RST, depending on socket/platform state and undefined + // platform behavior. + virtual void Disconnect() = 0; + + // Gets the address assigned to a connected socket. + virtual absl::StatusOr GetLocalAddress() = 0; + + // Blocking read. Receives and returns a buffer of up to `max_size` bytes from + // socket. Returns status on error. + virtual absl::StatusOr ReceiveBlocking( + QuicByteCount max_size) = 0; + + // Asynchronous read. Receives up to `max_size` bytes from socket. If + // no data is synchronously available to be read, waits until some data is + // available or the socket is closed. On completion, calls ReceiveComplete() + // on the visitor, potentially before return from ReceiveAsync(). + // + // After calling, the socket must not be destroyed until ReceiveComplete() is + // called. + virtual void ReceiveAsync(QuicByteCount max_size) = 0; + + // Blocking write. Sends all of `data` (potentially via multiple underlying + // socket sends). + virtual absl::Status SendBlocking(std::string data) = 0; + virtual absl::Status SendBlocking(quiche::QuicheMemSlice data) = 0; + + // Asynchronous write. Sends all of `data` (potentially via multiple + // underlying socket sends). On completion, calls SendComplete() on the + // visitor, potentially before return from SendAsync(). + // + // After calling, the socket must not be destroyed until SendComplete() is + // called. + virtual void SendAsync(std::string data) = 0; + virtual void SendAsync(quiche::QuicheMemSlice data) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONNECTING_CLIENT_SOCKET_H_ diff --git a/quiche/quic/core/connection_id_generator.h b/quiche/quic/core/connection_id_generator.h new file mode 100644 index 000000000000..d27b44f06431 --- /dev/null +++ b/quiche/quic/core/connection_id_generator.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CONNECTION_ID_GENERATOR_H_ +#define QUICHE_QUIC_CORE_CONNECTION_ID_GENERATOR_H_ + +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_versions.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE ConnectionIdGeneratorInterface { + // Interface which is responsible for generating new connection IDs from an + // existing connection ID. + public: + // Generate a new connection ID for a given connection ID. Returns the new + // connection ID. If it cannot be generated for some reason, returns + // empty. + virtual absl::optional GenerateNextConnectionId( + const QuicConnectionId& original) = 0; + // Consider the client-generated server connection ID in the quic handshake + // and consider replacing it. Returns empty if not replaced. + virtual absl::optional MaybeReplaceConnectionId( + const QuicConnectionId& original, const ParsedQuicVersion& version) = 0; + // Returns the length of a connection ID generated by this generator with the + // specified first byte. + virtual uint8_t ConnectionIdLength(uint8_t first_byte) const = 0; + virtual ~ConnectionIdGeneratorInterface() = default; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CONNECTION_ID_GENERATOR_H_ diff --git a/quiche/quic/core/crypto/aead_base_decrypter.cc b/quiche/quic/core/crypto/aead_base_decrypter.cc new file mode 100644 index 000000000000..b6a3c4f68c3f --- /dev/null +++ b/quiche/quic/core/crypto/aead_base_decrypter.cc @@ -0,0 +1,190 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aead_base_decrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "openssl/crypto.h" +#include "openssl/err.h" +#include "openssl/evp.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_crypto_logging.h" + +namespace quic { +using ::quiche::ClearOpenSslErrors; +using ::quiche::DLogOpenSslErrors; +namespace { + +const EVP_AEAD* InitAndCall(const EVP_AEAD* (*aead_getter)()) { + // Ensure BoringSSL is initialized before calling |aead_getter|. In Chromium, + // the static initializer is disabled. + CRYPTO_library_init(); + return aead_getter(); +} + +} // namespace + +AeadBaseDecrypter::AeadBaseDecrypter(const EVP_AEAD* (*aead_getter)(), + size_t key_size, size_t auth_tag_size, + size_t nonce_size, + bool use_ietf_nonce_construction) + : aead_alg_(InitAndCall(aead_getter)), + key_size_(key_size), + auth_tag_size_(auth_tag_size), + nonce_size_(nonce_size), + use_ietf_nonce_construction_(use_ietf_nonce_construction), + have_preliminary_key_(false) { + QUICHE_DCHECK_GT(256u, key_size); + QUICHE_DCHECK_GT(256u, auth_tag_size); + QUICHE_DCHECK_GT(256u, nonce_size); + QUICHE_DCHECK_LE(key_size_, sizeof(key_)); + QUICHE_DCHECK_LE(nonce_size_, sizeof(iv_)); +} + +AeadBaseDecrypter::~AeadBaseDecrypter() {} + +bool AeadBaseDecrypter::SetKey(absl::string_view key) { + QUICHE_DCHECK_EQ(key.size(), key_size_); + if (key.size() != key_size_) { + return false; + } + memcpy(key_, key.data(), key.size()); + + EVP_AEAD_CTX_cleanup(ctx_.get()); + if (!EVP_AEAD_CTX_init(ctx_.get(), aead_alg_, key_, key_size_, auth_tag_size_, + nullptr)) { + DLogOpenSslErrors(); + return false; + } + + return true; +} + +bool AeadBaseDecrypter::SetNoncePrefix(absl::string_view nonce_prefix) { + if (use_ietf_nonce_construction_) { + QUIC_BUG(quic_bug_10709_1) + << "Attempted to set nonce prefix on IETF QUIC crypter"; + return false; + } + QUICHE_DCHECK_EQ(nonce_prefix.size(), nonce_size_ - sizeof(QuicPacketNumber)); + if (nonce_prefix.size() != nonce_size_ - sizeof(QuicPacketNumber)) { + return false; + } + memcpy(iv_, nonce_prefix.data(), nonce_prefix.size()); + return true; +} + +bool AeadBaseDecrypter::SetIV(absl::string_view iv) { + if (!use_ietf_nonce_construction_) { + QUIC_BUG(quic_bug_10709_2) << "Attempted to set IV on Google QUIC crypter"; + return false; + } + QUICHE_DCHECK_EQ(iv.size(), nonce_size_); + if (iv.size() != nonce_size_) { + return false; + } + memcpy(iv_, iv.data(), iv.size()); + return true; +} + +bool AeadBaseDecrypter::SetPreliminaryKey(absl::string_view key) { + QUICHE_DCHECK(!have_preliminary_key_); + SetKey(key); + have_preliminary_key_ = true; + + return true; +} + +bool AeadBaseDecrypter::SetDiversificationNonce( + const DiversificationNonce& nonce) { + if (!have_preliminary_key_) { + return true; + } + + std::string key, nonce_prefix; + size_t prefix_size = nonce_size_; + if (!use_ietf_nonce_construction_) { + prefix_size -= sizeof(QuicPacketNumber); + } + DiversifyPreliminaryKey( + absl::string_view(reinterpret_cast(key_), key_size_), + absl::string_view(reinterpret_cast(iv_), prefix_size), nonce, + key_size_, prefix_size, &key, &nonce_prefix); + + if (!SetKey(key) || + (!use_ietf_nonce_construction_ && !SetNoncePrefix(nonce_prefix)) || + (use_ietf_nonce_construction_ && !SetIV(nonce_prefix))) { + QUICHE_DCHECK(false); + return false; + } + + have_preliminary_key_ = false; + return true; +} + +bool AeadBaseDecrypter::DecryptPacket(uint64_t packet_number, + absl::string_view associated_data, + absl::string_view ciphertext, + char* output, size_t* output_length, + size_t max_output_length) { + if (ciphertext.length() < auth_tag_size_) { + return false; + } + + if (have_preliminary_key_) { + QUIC_BUG(quic_bug_10709_3) + << "Unable to decrypt while key diversification is pending"; + return false; + } + + uint8_t nonce[kMaxNonceSize]; + memcpy(nonce, iv_, nonce_size_); + size_t prefix_len = nonce_size_ - sizeof(packet_number); + if (use_ietf_nonce_construction_) { + for (size_t i = 0; i < sizeof(packet_number); ++i) { + nonce[prefix_len + i] ^= + (packet_number >> ((sizeof(packet_number) - i - 1) * 8)) & 0xff; + } + } else { + memcpy(nonce + prefix_len, &packet_number, sizeof(packet_number)); + } + if (!EVP_AEAD_CTX_open( + ctx_.get(), reinterpret_cast(output), output_length, + max_output_length, reinterpret_cast(nonce), + nonce_size_, reinterpret_cast(ciphertext.data()), + ciphertext.size(), + reinterpret_cast(associated_data.data()), + associated_data.size())) { + // Because QuicFramer does trial decryption, decryption errors are expected + // when encryption level changes. So we don't log decryption errors. + ClearOpenSslErrors(); + return false; + } + return true; +} + +size_t AeadBaseDecrypter::GetKeySize() const { return key_size_; } + +size_t AeadBaseDecrypter::GetNoncePrefixSize() const { + return nonce_size_ - sizeof(QuicPacketNumber); +} + +size_t AeadBaseDecrypter::GetIVSize() const { return nonce_size_; } + +absl::string_view AeadBaseDecrypter::GetKey() const { + return absl::string_view(reinterpret_cast(key_), key_size_); +} + +absl::string_view AeadBaseDecrypter::GetNoncePrefix() const { + return absl::string_view(reinterpret_cast(iv_), + nonce_size_ - sizeof(QuicPacketNumber)); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aead_base_decrypter.h b/quiche/quic/core/crypto/aead_base_decrypter.h new file mode 100644 index 000000000000..b123b13e387d --- /dev/null +++ b/quiche/quic/core/crypto/aead_base_decrypter.h @@ -0,0 +1,69 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AEAD_BASE_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AEAD_BASE_DECRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/aead.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// AeadBaseDecrypter is the base class of AEAD QuicDecrypter subclasses. +class QUIC_EXPORT_PRIVATE AeadBaseDecrypter : public QuicDecrypter { + public: + // This takes the function pointer rather than the EVP_AEAD itself so + // subclasses do not need to call CRYPTO_library_init. + AeadBaseDecrypter(const EVP_AEAD* (*aead_getter)(), size_t key_size, + size_t auth_tag_size, size_t nonce_size, + bool use_ietf_nonce_construction); + AeadBaseDecrypter(const AeadBaseDecrypter&) = delete; + AeadBaseDecrypter& operator=(const AeadBaseDecrypter&) = delete; + ~AeadBaseDecrypter() override; + + // QuicDecrypter implementation + bool SetKey(absl::string_view key) override; + bool SetNoncePrefix(absl::string_view nonce_prefix) override; + bool SetIV(absl::string_view iv) override; + bool SetPreliminaryKey(absl::string_view key) override; + bool SetDiversificationNonce(const DiversificationNonce& nonce) override; + bool DecryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, size_t max_output_length) override; + size_t GetKeySize() const override; + size_t GetNoncePrefixSize() const override; + size_t GetIVSize() const override; + absl::string_view GetKey() const override; + absl::string_view GetNoncePrefix() const override; + + protected: + // Make these constants available to the subclasses so that the subclasses + // can assert at compile time their key_size_ and nonce_size_ do not + // exceed the maximum. + static const size_t kMaxKeySize = 32; + static const size_t kMaxNonceSize = 12; + + private: + const EVP_AEAD* const aead_alg_; + const size_t key_size_; + const size_t auth_tag_size_; + const size_t nonce_size_; + const bool use_ietf_nonce_construction_; + bool have_preliminary_key_; + + // The key. + unsigned char key_[kMaxKeySize]; + // The IV used to construct the nonce. + unsigned char iv_[kMaxNonceSize]; + + bssl::ScopedEVP_AEAD_CTX ctx_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AEAD_BASE_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/aead_base_encrypter.cc b/quiche/quic/core/crypto/aead_base_encrypter.cc new file mode 100644 index 000000000000..481eaa970a20 --- /dev/null +++ b/quiche/quic/core/crypto/aead_base_encrypter.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aead_base_encrypter.h" + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "openssl/crypto.h" +#include "openssl/err.h" +#include "openssl/evp.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_crypto_logging.h" + +namespace quic { +using ::quiche::DLogOpenSslErrors; +namespace { + +const EVP_AEAD* InitAndCall(const EVP_AEAD* (*aead_getter)()) { + // Ensure BoringSSL is initialized before calling |aead_getter|. In Chromium, + // the static initializer is disabled. + CRYPTO_library_init(); + return aead_getter(); +} + +} // namespace + +AeadBaseEncrypter::AeadBaseEncrypter(const EVP_AEAD* (*aead_getter)(), + size_t key_size, size_t auth_tag_size, + size_t nonce_size, + bool use_ietf_nonce_construction) + : aead_alg_(InitAndCall(aead_getter)), + key_size_(key_size), + auth_tag_size_(auth_tag_size), + nonce_size_(nonce_size), + use_ietf_nonce_construction_(use_ietf_nonce_construction) { + QUICHE_DCHECK_LE(key_size_, sizeof(key_)); + QUICHE_DCHECK_LE(nonce_size_, sizeof(iv_)); + QUICHE_DCHECK_GE(kMaxNonceSize, nonce_size_); +} + +AeadBaseEncrypter::~AeadBaseEncrypter() {} + +bool AeadBaseEncrypter::SetKey(absl::string_view key) { + QUICHE_DCHECK_EQ(key.size(), key_size_); + if (key.size() != key_size_) { + return false; + } + memcpy(key_, key.data(), key.size()); + + EVP_AEAD_CTX_cleanup(ctx_.get()); + + if (!EVP_AEAD_CTX_init(ctx_.get(), aead_alg_, key_, key_size_, auth_tag_size_, + nullptr)) { + DLogOpenSslErrors(); + return false; + } + + return true; +} + +bool AeadBaseEncrypter::SetNoncePrefix(absl::string_view nonce_prefix) { + if (use_ietf_nonce_construction_) { + QUIC_BUG(quic_bug_10634_1) + << "Attempted to set nonce prefix on IETF QUIC crypter"; + return false; + } + QUICHE_DCHECK_EQ(nonce_prefix.size(), nonce_size_ - sizeof(QuicPacketNumber)); + if (nonce_prefix.size() != nonce_size_ - sizeof(QuicPacketNumber)) { + return false; + } + memcpy(iv_, nonce_prefix.data(), nonce_prefix.size()); + return true; +} + +bool AeadBaseEncrypter::SetIV(absl::string_view iv) { + if (!use_ietf_nonce_construction_) { + QUIC_BUG(quic_bug_10634_2) << "Attempted to set IV on Google QUIC crypter"; + return false; + } + QUICHE_DCHECK_EQ(iv.size(), nonce_size_); + if (iv.size() != nonce_size_) { + return false; + } + memcpy(iv_, iv.data(), iv.size()); + return true; +} + +bool AeadBaseEncrypter::Encrypt(absl::string_view nonce, + absl::string_view associated_data, + absl::string_view plaintext, + unsigned char* output) { + QUICHE_DCHECK_EQ(nonce.size(), nonce_size_); + + size_t ciphertext_len; + if (!EVP_AEAD_CTX_seal( + ctx_.get(), output, &ciphertext_len, + plaintext.size() + auth_tag_size_, + reinterpret_cast(nonce.data()), nonce.size(), + reinterpret_cast(plaintext.data()), plaintext.size(), + reinterpret_cast(associated_data.data()), + associated_data.size())) { + DLogOpenSslErrors(); + return false; + } + + return true; +} + +bool AeadBaseEncrypter::EncryptPacket(uint64_t packet_number, + absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, + size_t max_output_length) { + size_t ciphertext_size = GetCiphertextSize(plaintext.length()); + if (max_output_length < ciphertext_size) { + return false; + } + // TODO(ianswett): Introduce a check to ensure that we don't encrypt with the + // same packet number twice. + alignas(4) char nonce_buffer[kMaxNonceSize]; + memcpy(nonce_buffer, iv_, nonce_size_); + size_t prefix_len = nonce_size_ - sizeof(packet_number); + if (use_ietf_nonce_construction_) { + for (size_t i = 0; i < sizeof(packet_number); ++i) { + nonce_buffer[prefix_len + i] ^= + (packet_number >> ((sizeof(packet_number) - i - 1) * 8)) & 0xff; + } + } else { + memcpy(nonce_buffer + prefix_len, &packet_number, sizeof(packet_number)); + } + + if (!Encrypt(absl::string_view(nonce_buffer, nonce_size_), associated_data, + plaintext, reinterpret_cast(output))) { + return false; + } + *output_length = ciphertext_size; + return true; +} + +size_t AeadBaseEncrypter::GetKeySize() const { return key_size_; } + +size_t AeadBaseEncrypter::GetNoncePrefixSize() const { + return nonce_size_ - sizeof(QuicPacketNumber); +} + +size_t AeadBaseEncrypter::GetIVSize() const { return nonce_size_; } + +size_t AeadBaseEncrypter::GetMaxPlaintextSize(size_t ciphertext_size) const { + return ciphertext_size - std::min(ciphertext_size, auth_tag_size_); +} + +size_t AeadBaseEncrypter::GetCiphertextSize(size_t plaintext_size) const { + return plaintext_size + auth_tag_size_; +} + +absl::string_view AeadBaseEncrypter::GetKey() const { + return absl::string_view(reinterpret_cast(key_), key_size_); +} + +absl::string_view AeadBaseEncrypter::GetNoncePrefix() const { + return absl::string_view(reinterpret_cast(iv_), + GetNoncePrefixSize()); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aead_base_encrypter.h b/quiche/quic/core/crypto/aead_base_encrypter.h new file mode 100644 index 000000000000..205b23265d1f --- /dev/null +++ b/quiche/quic/core/crypto/aead_base_encrypter.h @@ -0,0 +1,73 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AEAD_BASE_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AEAD_BASE_ENCRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/aead.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// AeadBaseEncrypter is the base class of AEAD QuicEncrypter subclasses. +class QUIC_EXPORT_PRIVATE AeadBaseEncrypter : public QuicEncrypter { + public: + // This takes the function pointer rather than the EVP_AEAD itself so + // subclasses do not need to call CRYPTO_library_init. + AeadBaseEncrypter(const EVP_AEAD* (*aead_getter)(), size_t key_size, + size_t auth_tag_size, size_t nonce_size, + bool use_ietf_nonce_construction); + AeadBaseEncrypter(const AeadBaseEncrypter&) = delete; + AeadBaseEncrypter& operator=(const AeadBaseEncrypter&) = delete; + ~AeadBaseEncrypter() override; + + // QuicEncrypter implementation + bool SetKey(absl::string_view key) override; + bool SetNoncePrefix(absl::string_view nonce_prefix) override; + bool SetIV(absl::string_view iv) override; + bool EncryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, size_t max_output_length) override; + size_t GetKeySize() const override; + size_t GetNoncePrefixSize() const override; + size_t GetIVSize() const override; + size_t GetMaxPlaintextSize(size_t ciphertext_size) const override; + size_t GetCiphertextSize(size_t plaintext_size) const override; + absl::string_view GetKey() const override; + absl::string_view GetNoncePrefix() const override; + + // Necessary so unit tests can explicitly specify a nonce, instead of an IV + // (or nonce prefix) and packet number. + bool Encrypt(absl::string_view nonce, absl::string_view associated_data, + absl::string_view plaintext, unsigned char* output); + + protected: + // Make these constants available to the subclasses so that the subclasses + // can assert at compile time their key_size_ and nonce_size_ do not + // exceed the maximum. + static const size_t kMaxKeySize = 32; + enum : size_t { kMaxNonceSize = 12 }; + + private: + const EVP_AEAD* const aead_alg_; + const size_t key_size_; + const size_t auth_tag_size_; + const size_t nonce_size_; + const bool use_ietf_nonce_construction_; + + // The key. + unsigned char key_[kMaxKeySize]; + // The IV used to construct the nonce. + unsigned char iv_[kMaxNonceSize]; + + bssl::ScopedEVP_AEAD_CTX ctx_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AEAD_BASE_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.cc b/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.cc new file mode 100644 index 000000000000..66f2ad2da09a --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h" + +#include "openssl/aead.h" +#include "openssl/tls1.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 16; +const size_t kNonceSize = 12; + +} // namespace + +Aes128Gcm12Decrypter::Aes128Gcm12Decrypter() + : AesBaseDecrypter(EVP_aead_aes_128_gcm, kKeySize, kAuthTagSize, kNonceSize, + /* use_ietf_nonce_construction */ false) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +Aes128Gcm12Decrypter::~Aes128Gcm12Decrypter() {} + +uint32_t Aes128Gcm12Decrypter::cipher_id() const { + return TLS1_CK_AES_128_GCM_SHA256; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h b/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h new file mode 100644 index 000000000000..38a941991e0a --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h @@ -0,0 +1,38 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_12_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_12_DECRYPTER_H_ + +#include + +#include "quiche/quic/core/crypto/aes_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An Aes128Gcm12Decrypter is a QuicDecrypter that implements the +// AEAD_AES_128_GCM_12 algorithm specified in RFC 5282. Create an instance by +// calling QuicDecrypter::Create(kAESG). +// +// It uses an authentication tag of 12 bytes (96 bits). The fixed prefix +// of the nonce is four bytes. +class QUIC_EXPORT_PRIVATE Aes128Gcm12Decrypter : public AesBaseDecrypter { + public: + enum { + // Authentication tags are truncated to 96 bits. + kAuthTagSize = 12, + }; + + Aes128Gcm12Decrypter(); + Aes128Gcm12Decrypter(const Aes128Gcm12Decrypter&) = delete; + Aes128Gcm12Decrypter& operator=(const Aes128Gcm12Decrypter&) = delete; + ~Aes128Gcm12Decrypter() override; + + uint32_t cipher_id() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_12_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc b/quiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc new file mode 100644 index 000000000000..64e11eb24a5e --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc @@ -0,0 +1,288 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The AES GCM test vectors come from the file gcmDecrypt128.rsp +// downloaded from http://csrc.nist.gov/groups/STM/cavp/index.html on +// 2013-02-01. The test vectors in that file look like this: +// +// [Keylen = 128] +// [IVlen = 96] +// [PTlen = 0] +// [AADlen = 0] +// [Taglen = 128] +// +// Count = 0 +// Key = cf063a34d4a9a76c2c86787d3f96db71 +// IV = 113b9785971864c83b01c787 +// CT = +// AAD = +// Tag = 72ac8493e3a5228b5d130a69d2510e42 +// PT = +// +// Count = 1 +// Key = a49a5e26a2f8cb63d05546c2a62f5343 +// IV = 907763b19b9b4ab6bd4f0281 +// CT = +// AAD = +// Tag = a2be08210d8c470a8df6e8fbd79ec5cf +// FAIL +// +// ... +// +// The gcmDecrypt128.rsp file is huge (2.6 MB), so I selected just a +// few test vectors for this unit test. + +// Describes a group of test vectors that all have a given key length, IV +// length, plaintext length, AAD length, and tag length. +struct TestGroupInfo { + size_t key_len; + size_t iv_len; + size_t pt_len; + size_t aad_len; + size_t tag_len; +}; + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + // Input: + const char* key; + const char* iv; + const char* ct; + const char* aad; + const char* tag; + + // Expected output: + const char* pt; // An empty string "" means decryption succeeded and + // the plaintext is zero-length. nullptr means decryption + // failed. +}; + +const TestGroupInfo test_group_info[] = { + {128, 96, 0, 0, 128}, {128, 96, 0, 128, 128}, {128, 96, 128, 0, 128}, + {128, 96, 408, 160, 128}, {128, 96, 408, 720, 128}, {128, 96, 104, 0, 128}, +}; + +const TestVector test_group_0[] = { + {"cf063a34d4a9a76c2c86787d3f96db71", "113b9785971864c83b01c787", "", "", + "72ac8493e3a5228b5d130a69d2510e42", ""}, + { + "a49a5e26a2f8cb63d05546c2a62f5343", "907763b19b9b4ab6bd4f0281", "", "", + "a2be08210d8c470a8df6e8fbd79ec5cf", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_1[] = { + { + "d1f6af919cde85661208bdce0c27cb22", "898c6929b435017bf031c3c5", "", + "7c5faa40e636bbc91107e68010c92b9f", "ae45f11777540a2caeb128be8092468a", + nullptr // FAIL + }, + {"2370e320d4344208e0ff5683f243b213", "04dbb82f044d30831c441228", "", + "d43a8e5089eea0d026c03a85178b27da", "2a049c049d25aa95969b451d93c31c6e", + ""}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_2[] = { + {"e98b72a9881a84ca6b76e0f43e68647a", "8b23299fde174053f3d652ba", + "5a3c1cf1985dbb8bed818036fdd5ab42", "", "23c7ab0f952b7091cd324835043b5eb5", + "28286a321293253c3e0aa2704a278032"}, + {"33240636cd3236165f1a553b773e728e", "17c4d61493ecdc8f31700b12", + "47bb7e23f7bdfe05a8091ac90e4f8b2e", "", "b723c70e931d9785f40fd4ab1d612dc9", + "95695a5b12f2870b9cc5fdc8f218a97d"}, + { + "5164df856f1e9cac04a79b808dc5be39", "e76925d5355e0584ce871b2b", + "0216c899c88d6e32c958c7e553daa5bc", "", + "a145319896329c96df291f64efbe0e3a", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_3[] = { + {"af57f42c60c0fc5a09adb81ab86ca1c3", "a2dc01871f37025dc0fc9a79", + "b9a535864f48ea7b6b1367914978f9bfa087d854bb0e269bed8d279d2eea1210e48947" + "338b22f9bad09093276a331e9c79c7f4", + "41dc38988945fcb44faf2ef72d0061289ef8efd8", + "4f71e72bde0018f555c5adcce062e005", + "3803a0727eeb0ade441e0ec107161ded2d425ec0d102f21f51bf2cf9947c7ec4aa7279" + "5b2f69b041596e8817d0a3c16f8fadeb"}, + {"ebc753e5422b377d3cb64b58ffa41b61", "2e1821efaced9acf1f241c9b", + "069567190554e9ab2b50a4e1fbf9c147340a5025fdbd201929834eaf6532325899ccb9" + "f401823e04b05817243d2142a3589878", + "b9673412fd4f88ba0e920f46dd6438ff791d8eef", + "534d9234d2351cf30e565de47baece0b", + "39077edb35e9c5a4b1e4c2a6b9bb1fce77f00f5023af40333d6d699014c2bcf4209c18" + "353a18017f5b36bfc00b1f6dcb7ed485"}, + { + "52bdbbf9cf477f187ec010589cb39d58", "d3be36d3393134951d324b31", + "700188da144fa692cf46e4a8499510a53d90903c967f7f13e8a1bd8151a74adc4fe63e" + "32b992760b3a5f99e9a47838867000a9", + "93c4fc6a4135f54d640b0c976bf755a06a292c33", + "8ca4e38aa3dfa6b1d0297021ccf3ea5f", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_4[] = { + {"da2bb7d581493d692380c77105590201", "44aa3e7856ca279d2eb020c6", + "9290d430c9e89c37f0446dbd620c9a6b34b1274aeb6f911f75867efcf95b6feda69f1a" + "f4ee16c761b3c9aeac3da03aa9889c88", + "4cd171b23bddb3a53cdf959d5c1710b481eb3785a90eb20a2345ee00d0bb7868c367ab" + "12e6f4dd1dee72af4eee1d197777d1d6499cc541f34edbf45cda6ef90b3c024f9272d7" + "2ec1909fb8fba7db88a4d6f7d3d925980f9f9f72", + "9e3ac938d3eb0cadd6f5c9e35d22ba38", + "9bbf4c1a2742f6ac80cb4e8a052e4a8f4f07c43602361355b717381edf9fabd4cb7e3a" + "d65dbd1378b196ac270588dd0621f642"}, + {"d74e4958717a9d5c0e235b76a926cae8", "0b7471141e0c70b1995fd7b1", + "e701c57d2330bf066f9ff8cf3ca4343cafe4894651cd199bdaaa681ba486b4a65c5a22" + "b0f1420be29ea547d42c713bc6af66aa", + "4a42b7aae8c245c6f1598a395316e4b8484dbd6e64648d5e302021b1d3fa0a38f46e22" + "bd9c8080b863dc0016482538a8562a4bd0ba84edbe2697c76fd039527ac179ec5506cf" + "34a6039312774cedebf4961f3978b14a26509f96", + "e192c23cb036f0b31592989119eed55d", + "840d9fb95e32559fb3602e48590280a172ca36d9b49ab69510f5bd552bfab7a306f85f" + "f0a34bc305b88b804c60b90add594a17"}, + { + "1986310c725ac94ecfe6422e75fc3ee7", "93ec4214fa8e6dc4e3afc775", + "b178ec72f85a311ac4168f42a4b2c23113fbea4b85f4b9dabb74e143eb1b8b0a361e02" + "43edfd365b90d5b325950df0ada058f9", + "e80b88e62c49c958b5e0b8b54f532d9ff6aa84c8a40132e93e55b59fc24e8decf28463" + "139f155d1e8ce4ee76aaeefcd245baa0fc519f83a5fb9ad9aa40c4b21126013f576c42" + "72c2cb136c8fd091cc4539877a5d1e72d607f960", + "8b347853f11d75e81e8a95010be81f17", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_5[] = { + {"387218b246c1a8257748b56980e50c94", "dd7e014198672be39f95b69d", + "cdba9e73eaf3d38eceb2b04a8d", "", "ecf90f4a47c9c626d6fb2c765d201556", + "48f5b426baca03064554cc2b30"}, + {"294de463721e359863887c820524b3d4", "3338b35c9d57a5d28190e8c9", + "2f46634e74b8e4c89812ac83b9", "", "dabd506764e68b82a7e720aa18da0abe", + "46a2e55c8e264df211bd112685"}, + {"28ead7fd2179e0d12aa6d5d88c58c2dc", "5055347f18b4d5add0ae5c41", + "142d8210c3fb84774cdbd0447a", "", "5fd321d9cdb01952dc85f034736c2a7d", + "3b95b981086ee73cc4d0cc1422"}, + { + "7d7b6c988137b8d470c57bf674a09c87", "9edf2aa970d016ac962e1fd8", + "a85b66c3cb5eab91d5bdc8bc0e", "", "dc054efc01f3afd21d9c2484819f569a", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector* const test_group_array[] = { + test_group_0, test_group_1, test_group_2, + test_group_3, test_group_4, test_group_5, +}; + +} // namespace + +namespace quic { +namespace test { + +// DecryptWithNonce wraps the |Decrypt| method of |decrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the plaintext. +QuicData* DecryptWithNonce(Aes128Gcm12Decrypter* decrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view ciphertext) { + uint64_t packet_number; + absl::string_view nonce_prefix(nonce.data(), + nonce.size() - sizeof(packet_number)); + decrypter->SetNoncePrefix(nonce_prefix); + memcpy(&packet_number, nonce.data() + nonce_prefix.size(), + sizeof(packet_number)); + std::unique_ptr output(new char[ciphertext.length()]); + size_t output_length = 0; + const bool success = decrypter->DecryptPacket( + packet_number, associated_data, ciphertext, output.get(), &output_length, + ciphertext.length()); + if (!success) { + return nullptr; + } + return new QuicData(output.release(), output_length, true); +} + +class Aes128Gcm12DecrypterTest : public QuicTest {}; + +TEST_F(Aes128Gcm12DecrypterTest, Decrypt) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_group_array); i++) { + SCOPED_TRACE(i); + const TestVector* test_vectors = test_group_array[i]; + const TestGroupInfo& test_info = test_group_info[i]; + for (size_t j = 0; test_vectors[j].key != nullptr; j++) { + // If not present then decryption is expected to fail. + bool has_pt = test_vectors[j].pt; + + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[j].key); + std::string iv = absl::HexStringToBytes(test_vectors[j].iv); + std::string ct = absl::HexStringToBytes(test_vectors[j].ct); + std::string aad = absl::HexStringToBytes(test_vectors[j].aad); + std::string tag = absl::HexStringToBytes(test_vectors[j].tag); + std::string pt; + if (has_pt) { + pt = absl::HexStringToBytes(test_vectors[j].pt); + } + + // The test vector's lengths should look sane. Note that the lengths + // in |test_info| are in bits. + EXPECT_EQ(test_info.key_len, key.length() * 8); + EXPECT_EQ(test_info.iv_len, iv.length() * 8); + EXPECT_EQ(test_info.pt_len, ct.length() * 8); + EXPECT_EQ(test_info.aad_len, aad.length() * 8); + EXPECT_EQ(test_info.tag_len, tag.length() * 8); + if (has_pt) { + EXPECT_EQ(test_info.pt_len, pt.length() * 8); + } + + // The test vectors have 16 byte authenticators but this code only uses + // the first 12. + ASSERT_LE(static_cast(Aes128Gcm12Decrypter::kAuthTagSize), + tag.length()); + tag.resize(Aes128Gcm12Decrypter::kAuthTagSize); + std::string ciphertext = ct + tag; + + Aes128Gcm12Decrypter decrypter; + ASSERT_TRUE(decrypter.SetKey(key)); + + std::unique_ptr decrypted(DecryptWithNonce( + &decrypter, iv, + // This deliberately tests that the decrypter can + // handle an AAD that is set to nullptr, as opposed + // to a zero-length, non-nullptr pointer. + aad.length() ? aad : absl::string_view(), ciphertext)); + if (!decrypted) { + EXPECT_FALSE(has_pt); + continue; + } + EXPECT_TRUE(has_pt); + + ASSERT_EQ(pt.length(), decrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "plaintext", decrypted->data(), pt.length(), pt.data(), pt.length()); + } + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.cc b/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.cc new file mode 100644 index 000000000000..5bbaeba079d6 --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" + +#include "openssl/evp.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 16; +const size_t kNonceSize = 12; + +} // namespace + +Aes128Gcm12Encrypter::Aes128Gcm12Encrypter() + : AesBaseEncrypter(EVP_aead_aes_128_gcm, kKeySize, kAuthTagSize, kNonceSize, + /* use_ietf_nonce_construction */ false) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +Aes128Gcm12Encrypter::~Aes128Gcm12Encrypter() {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h b/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h new file mode 100644 index 000000000000..64f0292f26a4 --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h @@ -0,0 +1,34 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_12_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_12_ENCRYPTER_H_ + +#include "quiche/quic/core/crypto/aes_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An Aes128Gcm12Encrypter is a QuicEncrypter that implements the +// AEAD_AES_128_GCM_12 algorithm specified in RFC 5282. Create an instance by +// calling QuicEncrypter::Create(kAESG). +// +// It uses an authentication tag of 12 bytes (96 bits). The fixed prefix +// of the nonce is four bytes. +class QUIC_EXPORT_PRIVATE Aes128Gcm12Encrypter : public AesBaseEncrypter { + public: + enum { + // Authentication tags are truncated to 96 bits. + kAuthTagSize = 12, + }; + + Aes128Gcm12Encrypter(); + Aes128Gcm12Encrypter(const Aes128Gcm12Encrypter&) = delete; + Aes128Gcm12Encrypter& operator=(const Aes128Gcm12Encrypter&) = delete; + ~Aes128Gcm12Encrypter() override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_12_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc b/quiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc new file mode 100644 index 000000000000..47dbd67e8b59 --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc @@ -0,0 +1,244 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The AES GCM test vectors come from the file gcmEncryptExtIV128.rsp +// downloaded from http://csrc.nist.gov/groups/STM/cavp/index.html on +// 2013-02-01. The test vectors in that file look like this: +// +// [Keylen = 128] +// [IVlen = 96] +// [PTlen = 0] +// [AADlen = 0] +// [Taglen = 128] +// +// Count = 0 +// Key = 11754cd72aec309bf52f7687212e8957 +// IV = 3c819d9a9bed087615030b65 +// PT = +// AAD = +// CT = +// Tag = 250327c674aaf477aef2675748cf6971 +// +// Count = 1 +// Key = ca47248ac0b6f8372a97ac43508308ed +// IV = ffd2b598feabc9019262d2be +// PT = +// AAD = +// CT = +// Tag = 60d20404af527d248d893ae495707d1a +// +// ... +// +// The gcmEncryptExtIV128.rsp file is huge (2.8 MB), so I selected just a +// few test vectors for this unit test. + +// Describes a group of test vectors that all have a given key length, IV +// length, plaintext length, AAD length, and tag length. +struct TestGroupInfo { + size_t key_len; + size_t iv_len; + size_t pt_len; + size_t aad_len; + size_t tag_len; +}; + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + const char* key; + const char* iv; + const char* pt; + const char* aad; + const char* ct; + const char* tag; +}; + +const TestGroupInfo test_group_info[] = { + {128, 96, 0, 0, 128}, {128, 96, 0, 128, 128}, {128, 96, 128, 0, 128}, + {128, 96, 408, 160, 128}, {128, 96, 408, 720, 128}, {128, 96, 104, 0, 128}, +}; + +const TestVector test_group_0[] = { + {"11754cd72aec309bf52f7687212e8957", "3c819d9a9bed087615030b65", "", "", "", + "250327c674aaf477aef2675748cf6971"}, + {"ca47248ac0b6f8372a97ac43508308ed", "ffd2b598feabc9019262d2be", "", "", "", + "60d20404af527d248d893ae495707d1a"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_1[] = { + {"77be63708971c4e240d1cb79e8d77feb", "e0e00f19fed7ba0136a797f3", "", + "7a43ec1d9c0a5a78a0b16533a6213cab", "", + "209fcc8d3675ed938e9c7166709dd946"}, + {"7680c5d3ca6154758e510f4d25b98820", "f8f105f9c3df4965780321f8", "", + "c94c410194c765e3dcc7964379758ed3", "", + "94dca8edfcf90bb74b153c8d48a17930"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_2[] = { + {"7fddb57453c241d03efbed3ac44e371c", "ee283a3fc75575e33efd4887", + "d5de42b461646c255c87bd2962d3b9a2", "", "2ccda4a5415cb91e135c2a0f78c9b2fd", + "b36d1df9b9d5e596f83e8b7f52971cb3"}, + {"ab72c77b97cb5fe9a382d9fe81ffdbed", "54cc7dc2c37ec006bcc6d1da", + "007c5e5b3e59df24a7c355584fc1518d", "", "0e1bde206a07a9c2c1b65300f8c64997", + "2b4401346697138c7a4891ee59867d0c"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_3[] = { + {"fe47fcce5fc32665d2ae399e4eec72ba", "5adb9609dbaeb58cbd6e7275", + "7c0e88c88899a779228465074797cd4c2e1498d259b54390b85e3eef1c02df60e743f1" + "b840382c4bccaf3bafb4ca8429bea063", + "88319d6e1d3ffa5f987199166c8a9b56c2aeba5a", + "98f4826f05a265e6dd2be82db241c0fbbbf9ffb1c173aa83964b7cf539304373636525" + "3ddbc5db8778371495da76d269e5db3e", + "291ef1982e4defedaa2249f898556b47"}, + {"ec0c2ba17aa95cd6afffe949da9cc3a8", "296bce5b50b7d66096d627ef", + "b85b3753535b825cbe5f632c0b843c741351f18aa484281aebec2f45bb9eea2d79d987" + "b764b9611f6c0f8641843d5d58f3a242", + "f8d00f05d22bf68599bcdeb131292ad6e2df5d14", + "a7443d31c26bdf2a1c945e29ee4bd344a99cfaf3aa71f8b3f191f83c2adfc7a0716299" + "5506fde6309ffc19e716eddf1a828c5a", + "890147971946b627c40016da1ecf3e77"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_4[] = { + {"2c1f21cf0f6fb3661943155c3e3d8492", "23cb5ff362e22426984d1907", + "42f758836986954db44bf37c6ef5e4ac0adaf38f27252a1b82d02ea949c8a1a2dbc0d6" + "8b5615ba7c1220ff6510e259f06655d8", + "5d3624879d35e46849953e45a32a624d6a6c536ed9857c613b572b0333e701557a713e" + "3f010ecdf9a6bd6c9e3e44b065208645aff4aabee611b391528514170084ccf587177f" + "4488f33cfb5e979e42b6e1cfc0a60238982a7aec", + "81824f0e0d523db30d3da369fdc0d60894c7a0a20646dd015073ad2732bd989b14a222" + "b6ad57af43e1895df9dca2a5344a62cc", + "57a3ee28136e94c74838997ae9823f3a"}, + {"d9f7d2411091f947b4d6f1e2d1f0fb2e", "e1934f5db57cc983e6b180e7", + "73ed042327f70fe9c572a61545eda8b2a0c6e1d6c291ef19248e973aee6c312012f490" + "c2c6f6166f4a59431e182663fcaea05a", + "0a8a18a7150e940c3d87b38e73baee9a5c049ee21795663e264b694a949822b639092d" + "0e67015e86363583fcf0ca645af9f43375f05fdb4ce84f411dcbca73c2220dea03a201" + "15d2e51398344b16bee1ed7c499b353d6c597af8", + "aaadbd5c92e9151ce3db7210b8714126b73e43436d242677afa50384f2149b831f1d57" + "3c7891c2a91fbc48db29967ec9542b23", + "21b51ca862cb637cdd03b99a0f93b134"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_5[] = { + {"fe9bb47deb3a61e423c2231841cfd1fb", "4d328eb776f500a2f7fb47aa", + "f1cc3818e421876bb6b8bbd6c9", "", "b88c5c1977b35b517b0aeae967", + "43fd4727fe5cdb4b5b42818dea7ef8c9"}, + {"6703df3701a7f54911ca72e24dca046a", "12823ab601c350ea4bc2488c", + "793cd125b0b84a043e3ac67717", "", "b2051c80014f42f08735a7b0cd", + "38e6bcd29962e5f2c13626b85a877101"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector* const test_group_array[] = { + test_group_0, test_group_1, test_group_2, + test_group_3, test_group_4, test_group_5, +}; + +} // namespace + +namespace quic { +namespace test { + +// EncryptWithNonce wraps the |Encrypt| method of |encrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the ciphertext. +QuicData* EncryptWithNonce(Aes128Gcm12Encrypter* encrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view plaintext) { + size_t ciphertext_size = encrypter->GetCiphertextSize(plaintext.length()); + std::unique_ptr ciphertext(new char[ciphertext_size]); + + if (!encrypter->Encrypt(nonce, associated_data, plaintext, + reinterpret_cast(ciphertext.get()))) { + return nullptr; + } + + return new QuicData(ciphertext.release(), ciphertext_size, true); +} + +class Aes128Gcm12EncrypterTest : public QuicTest {}; + +TEST_F(Aes128Gcm12EncrypterTest, Encrypt) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_group_array); i++) { + SCOPED_TRACE(i); + const TestVector* test_vectors = test_group_array[i]; + const TestGroupInfo& test_info = test_group_info[i]; + for (size_t j = 0; test_vectors[j].key != nullptr; j++) { + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[j].key); + std::string iv = absl::HexStringToBytes(test_vectors[j].iv); + std::string pt = absl::HexStringToBytes(test_vectors[j].pt); + std::string aad = absl::HexStringToBytes(test_vectors[j].aad); + std::string ct = absl::HexStringToBytes(test_vectors[j].ct); + std::string tag = absl::HexStringToBytes(test_vectors[j].tag); + + // The test vector's lengths should look sane. Note that the lengths + // in |test_info| are in bits. + EXPECT_EQ(test_info.key_len, key.length() * 8); + EXPECT_EQ(test_info.iv_len, iv.length() * 8); + EXPECT_EQ(test_info.pt_len, pt.length() * 8); + EXPECT_EQ(test_info.aad_len, aad.length() * 8); + EXPECT_EQ(test_info.pt_len, ct.length() * 8); + EXPECT_EQ(test_info.tag_len, tag.length() * 8); + + Aes128Gcm12Encrypter encrypter; + ASSERT_TRUE(encrypter.SetKey(key)); + std::unique_ptr encrypted( + EncryptWithNonce(&encrypter, iv, + // This deliberately tests that the encrypter can + // handle an AAD that is set to nullptr, as opposed + // to a zero-length, non-nullptr pointer. + aad.length() ? aad : absl::string_view(), pt)); + ASSERT_TRUE(encrypted.get()); + + // The test vectors have 16 byte authenticators but this code only uses + // the first 12. + ASSERT_LE(static_cast(Aes128Gcm12Encrypter::kAuthTagSize), + tag.length()); + tag.resize(Aes128Gcm12Encrypter::kAuthTagSize); + + ASSERT_EQ(ct.length() + tag.length(), encrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "ciphertext", encrypted->data(), ct.length(), ct.data(), ct.length()); + quiche::test::CompareCharArraysWithHexError( + "authentication tag", encrypted->data() + ct.length(), tag.length(), + tag.data(), tag.length()); + } + } +} + +TEST_F(Aes128Gcm12EncrypterTest, GetMaxPlaintextSize) { + Aes128Gcm12Encrypter encrypter; + EXPECT_EQ(1000u, encrypter.GetMaxPlaintextSize(1012)); + EXPECT_EQ(100u, encrypter.GetMaxPlaintextSize(112)); + EXPECT_EQ(10u, encrypter.GetMaxPlaintextSize(22)); + EXPECT_EQ(0u, encrypter.GetMaxPlaintextSize(11)); +} + +TEST_F(Aes128Gcm12EncrypterTest, GetCiphertextSize) { + Aes128Gcm12Encrypter encrypter; + EXPECT_EQ(1012u, encrypter.GetCiphertextSize(1000)); + EXPECT_EQ(112u, encrypter.GetCiphertextSize(100)); + EXPECT_EQ(22u, encrypter.GetCiphertextSize(10)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_decrypter.cc b/quiche/quic/core/crypto/aes_128_gcm_decrypter.cc new file mode 100644 index 000000000000..c43123bb4458 --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_decrypter.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_decrypter.h" + +#include "openssl/aead.h" +#include "openssl/tls1.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 16; +const size_t kNonceSize = 12; + +} // namespace + +Aes128GcmDecrypter::Aes128GcmDecrypter() + : AesBaseDecrypter(EVP_aead_aes_128_gcm, kKeySize, kAuthTagSize, kNonceSize, + /* use_ietf_nonce_construction */ true) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +Aes128GcmDecrypter::~Aes128GcmDecrypter() {} + +uint32_t Aes128GcmDecrypter::cipher_id() const { + return TLS1_CK_AES_128_GCM_SHA256; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_decrypter.h b/quiche/quic/core/crypto/aes_128_gcm_decrypter.h new file mode 100644 index 000000000000..c5f4d17de149 --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_decrypter.h @@ -0,0 +1,36 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_DECRYPTER_H_ + +#include + +#include "quiche/quic/core/crypto/aes_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An Aes128GcmDecrypter is a QuicDecrypter that implements the +// AEAD_AES_128_GCM algorithm specified in RFC 5116 for use in IETF QUIC. +// +// It uses an authentication tag of 16 bytes (128 bits). It uses a 12 byte IV +// that is XOR'd with the packet number to compute the nonce. +class QUIC_EXPORT_PRIVATE Aes128GcmDecrypter : public AesBaseDecrypter { + public: + enum { + kAuthTagSize = 16, + }; + + Aes128GcmDecrypter(); + Aes128GcmDecrypter(const Aes128GcmDecrypter&) = delete; + Aes128GcmDecrypter& operator=(const Aes128GcmDecrypter&) = delete; + ~Aes128GcmDecrypter() override; + + uint32_t cipher_id() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc b/quiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc new file mode 100644 index 000000000000..e02b433d22fe --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc @@ -0,0 +1,291 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_decrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The AES GCM test vectors come from the file gcmDecrypt128.rsp +// downloaded from http://csrc.nist.gov/groups/STM/cavp/index.html on +// 2013-02-01. The test vectors in that file look like this: +// +// [Keylen = 128] +// [IVlen = 96] +// [PTlen = 0] +// [AADlen = 0] +// [Taglen = 128] +// +// Count = 0 +// Key = cf063a34d4a9a76c2c86787d3f96db71 +// IV = 113b9785971864c83b01c787 +// CT = +// AAD = +// Tag = 72ac8493e3a5228b5d130a69d2510e42 +// PT = +// +// Count = 1 +// Key = a49a5e26a2f8cb63d05546c2a62f5343 +// IV = 907763b19b9b4ab6bd4f0281 +// CT = +// AAD = +// Tag = a2be08210d8c470a8df6e8fbd79ec5cf +// FAIL +// +// ... +// +// The gcmDecrypt128.rsp file is huge (2.6 MB), so I selected just a +// few test vectors for this unit test. + +// Describes a group of test vectors that all have a given key length, IV +// length, plaintext length, AAD length, and tag length. +struct TestGroupInfo { + size_t key_len; + size_t iv_len; + size_t pt_len; + size_t aad_len; + size_t tag_len; +}; + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + // Input: + const char* key; + const char* iv; + const char* ct; + const char* aad; + const char* tag; + + // Expected output: + const char* pt; // An empty string "" means decryption succeeded and + // the plaintext is zero-length. nullptr means decryption + // failed. +}; + +const TestGroupInfo test_group_info[] = { + {128, 96, 0, 0, 128}, {128, 96, 0, 128, 128}, {128, 96, 128, 0, 128}, + {128, 96, 408, 160, 128}, {128, 96, 408, 720, 128}, {128, 96, 104, 0, 128}, +}; + +const TestVector test_group_0[] = { + {"cf063a34d4a9a76c2c86787d3f96db71", "113b9785971864c83b01c787", "", "", + "72ac8493e3a5228b5d130a69d2510e42", ""}, + { + "a49a5e26a2f8cb63d05546c2a62f5343", "907763b19b9b4ab6bd4f0281", "", "", + "a2be08210d8c470a8df6e8fbd79ec5cf", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_1[] = { + { + "d1f6af919cde85661208bdce0c27cb22", "898c6929b435017bf031c3c5", "", + "7c5faa40e636bbc91107e68010c92b9f", "ae45f11777540a2caeb128be8092468a", + nullptr // FAIL + }, + {"2370e320d4344208e0ff5683f243b213", "04dbb82f044d30831c441228", "", + "d43a8e5089eea0d026c03a85178b27da", "2a049c049d25aa95969b451d93c31c6e", + ""}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_2[] = { + {"e98b72a9881a84ca6b76e0f43e68647a", "8b23299fde174053f3d652ba", + "5a3c1cf1985dbb8bed818036fdd5ab42", "", "23c7ab0f952b7091cd324835043b5eb5", + "28286a321293253c3e0aa2704a278032"}, + {"33240636cd3236165f1a553b773e728e", "17c4d61493ecdc8f31700b12", + "47bb7e23f7bdfe05a8091ac90e4f8b2e", "", "b723c70e931d9785f40fd4ab1d612dc9", + "95695a5b12f2870b9cc5fdc8f218a97d"}, + { + "5164df856f1e9cac04a79b808dc5be39", "e76925d5355e0584ce871b2b", + "0216c899c88d6e32c958c7e553daa5bc", "", + "a145319896329c96df291f64efbe0e3a", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_3[] = { + {"af57f42c60c0fc5a09adb81ab86ca1c3", "a2dc01871f37025dc0fc9a79", + "b9a535864f48ea7b6b1367914978f9bfa087d854bb0e269bed8d279d2eea1210e48947" + "338b22f9bad09093276a331e9c79c7f4", + "41dc38988945fcb44faf2ef72d0061289ef8efd8", + "4f71e72bde0018f555c5adcce062e005", + "3803a0727eeb0ade441e0ec107161ded2d425ec0d102f21f51bf2cf9947c7ec4aa7279" + "5b2f69b041596e8817d0a3c16f8fadeb"}, + {"ebc753e5422b377d3cb64b58ffa41b61", "2e1821efaced9acf1f241c9b", + "069567190554e9ab2b50a4e1fbf9c147340a5025fdbd201929834eaf6532325899ccb9" + "f401823e04b05817243d2142a3589878", + "b9673412fd4f88ba0e920f46dd6438ff791d8eef", + "534d9234d2351cf30e565de47baece0b", + "39077edb35e9c5a4b1e4c2a6b9bb1fce77f00f5023af40333d6d699014c2bcf4209c18" + "353a18017f5b36bfc00b1f6dcb7ed485"}, + { + "52bdbbf9cf477f187ec010589cb39d58", "d3be36d3393134951d324b31", + "700188da144fa692cf46e4a8499510a53d90903c967f7f13e8a1bd8151a74adc4fe63e" + "32b992760b3a5f99e9a47838867000a9", + "93c4fc6a4135f54d640b0c976bf755a06a292c33", + "8ca4e38aa3dfa6b1d0297021ccf3ea5f", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_4[] = { + {"da2bb7d581493d692380c77105590201", "44aa3e7856ca279d2eb020c6", + "9290d430c9e89c37f0446dbd620c9a6b34b1274aeb6f911f75867efcf95b6feda69f1a" + "f4ee16c761b3c9aeac3da03aa9889c88", + "4cd171b23bddb3a53cdf959d5c1710b481eb3785a90eb20a2345ee00d0bb7868c367ab" + "12e6f4dd1dee72af4eee1d197777d1d6499cc541f34edbf45cda6ef90b3c024f9272d7" + "2ec1909fb8fba7db88a4d6f7d3d925980f9f9f72", + "9e3ac938d3eb0cadd6f5c9e35d22ba38", + "9bbf4c1a2742f6ac80cb4e8a052e4a8f4f07c43602361355b717381edf9fabd4cb7e3a" + "d65dbd1378b196ac270588dd0621f642"}, + {"d74e4958717a9d5c0e235b76a926cae8", "0b7471141e0c70b1995fd7b1", + "e701c57d2330bf066f9ff8cf3ca4343cafe4894651cd199bdaaa681ba486b4a65c5a22" + "b0f1420be29ea547d42c713bc6af66aa", + "4a42b7aae8c245c6f1598a395316e4b8484dbd6e64648d5e302021b1d3fa0a38f46e22" + "bd9c8080b863dc0016482538a8562a4bd0ba84edbe2697c76fd039527ac179ec5506cf" + "34a6039312774cedebf4961f3978b14a26509f96", + "e192c23cb036f0b31592989119eed55d", + "840d9fb95e32559fb3602e48590280a172ca36d9b49ab69510f5bd552bfab7a306f85f" + "f0a34bc305b88b804c60b90add594a17"}, + { + "1986310c725ac94ecfe6422e75fc3ee7", "93ec4214fa8e6dc4e3afc775", + "b178ec72f85a311ac4168f42a4b2c23113fbea4b85f4b9dabb74e143eb1b8b0a361e02" + "43edfd365b90d5b325950df0ada058f9", + "e80b88e62c49c958b5e0b8b54f532d9ff6aa84c8a40132e93e55b59fc24e8decf28463" + "139f155d1e8ce4ee76aaeefcd245baa0fc519f83a5fb9ad9aa40c4b21126013f576c42" + "72c2cb136c8fd091cc4539877a5d1e72d607f960", + "8b347853f11d75e81e8a95010be81f17", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_5[] = { + {"387218b246c1a8257748b56980e50c94", "dd7e014198672be39f95b69d", + "cdba9e73eaf3d38eceb2b04a8d", "", "ecf90f4a47c9c626d6fb2c765d201556", + "48f5b426baca03064554cc2b30"}, + {"294de463721e359863887c820524b3d4", "3338b35c9d57a5d28190e8c9", + "2f46634e74b8e4c89812ac83b9", "", "dabd506764e68b82a7e720aa18da0abe", + "46a2e55c8e264df211bd112685"}, + {"28ead7fd2179e0d12aa6d5d88c58c2dc", "5055347f18b4d5add0ae5c41", + "142d8210c3fb84774cdbd0447a", "", "5fd321d9cdb01952dc85f034736c2a7d", + "3b95b981086ee73cc4d0cc1422"}, + { + "7d7b6c988137b8d470c57bf674a09c87", "9edf2aa970d016ac962e1fd8", + "a85b66c3cb5eab91d5bdc8bc0e", "", "dc054efc01f3afd21d9c2484819f569a", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector* const test_group_array[] = { + test_group_0, test_group_1, test_group_2, + test_group_3, test_group_4, test_group_5, +}; + +} // namespace + +namespace quic { +namespace test { + +// DecryptWithNonce wraps the |Decrypt| method of |decrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the plaintext. +QuicData* DecryptWithNonce(Aes128GcmDecrypter* decrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view ciphertext) { + decrypter->SetIV(nonce); + std::unique_ptr output(new char[ciphertext.length()]); + size_t output_length = 0; + const bool success = + decrypter->DecryptPacket(0, associated_data, ciphertext, output.get(), + &output_length, ciphertext.length()); + if (!success) { + return nullptr; + } + return new QuicData(output.release(), output_length, true); +} + +class Aes128GcmDecrypterTest : public QuicTest {}; + +TEST_F(Aes128GcmDecrypterTest, Decrypt) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_group_array); i++) { + SCOPED_TRACE(i); + const TestVector* test_vectors = test_group_array[i]; + const TestGroupInfo& test_info = test_group_info[i]; + for (size_t j = 0; test_vectors[j].key != nullptr; j++) { + // If not present then decryption is expected to fail. + bool has_pt = test_vectors[j].pt; + + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[j].key); + std::string iv = absl::HexStringToBytes(test_vectors[j].iv); + std::string ct = absl::HexStringToBytes(test_vectors[j].ct); + std::string aad = absl::HexStringToBytes(test_vectors[j].aad); + std::string tag = absl::HexStringToBytes(test_vectors[j].tag); + std::string pt; + if (has_pt) { + pt = absl::HexStringToBytes(test_vectors[j].pt); + } + + // The test vector's lengths should look sane. Note that the lengths + // in |test_info| are in bits. + EXPECT_EQ(test_info.key_len, key.length() * 8); + EXPECT_EQ(test_info.iv_len, iv.length() * 8); + EXPECT_EQ(test_info.pt_len, ct.length() * 8); + EXPECT_EQ(test_info.aad_len, aad.length() * 8); + EXPECT_EQ(test_info.tag_len, tag.length() * 8); + if (has_pt) { + EXPECT_EQ(test_info.pt_len, pt.length() * 8); + } + std::string ciphertext = ct + tag; + + Aes128GcmDecrypter decrypter; + ASSERT_TRUE(decrypter.SetKey(key)); + + std::unique_ptr decrypted(DecryptWithNonce( + &decrypter, iv, + // This deliberately tests that the decrypter can + // handle an AAD that is set to nullptr, as opposed + // to a zero-length, non-nullptr pointer. + aad.length() ? aad : absl::string_view(), ciphertext)); + if (!decrypted) { + EXPECT_FALSE(has_pt); + continue; + } + EXPECT_TRUE(has_pt); + + ASSERT_EQ(pt.length(), decrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "plaintext", decrypted->data(), pt.length(), pt.data(), pt.length()); + } + } +} + +TEST_F(Aes128GcmDecrypterTest, GenerateHeaderProtectionMask) { + Aes128GcmDecrypter decrypter; + std::string key = absl::HexStringToBytes("d9132370cb18476ab833649cf080d970"); + std::string sample = + absl::HexStringToBytes("d1d7998068517adb769b48b924a32c47"); + QuicDataReader sample_reader(sample.data(), sample.size()); + ASSERT_TRUE(decrypter.SetHeaderProtectionKey(key)); + std::string mask = decrypter.GenerateHeaderProtectionMask(&sample_reader); + std::string expected_mask = + absl::HexStringToBytes("b132c37d6164da4ea4dc9b763aceec27"); + quiche::test::CompareCharArraysWithHexError( + "header protection mask", mask.data(), mask.size(), expected_mask.data(), + expected_mask.size()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_encrypter.cc b/quiche/quic/core/crypto/aes_128_gcm_encrypter.cc new file mode 100644 index 000000000000..22f9b2a2168d --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_encrypter.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_encrypter.h" + +#include "openssl/evp.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 16; +const size_t kNonceSize = 12; + +} // namespace + +Aes128GcmEncrypter::Aes128GcmEncrypter() + : AesBaseEncrypter(EVP_aead_aes_128_gcm, kKeySize, kAuthTagSize, kNonceSize, + /* use_ietf_nonce_construction */ true) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +Aes128GcmEncrypter::~Aes128GcmEncrypter() {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_128_gcm_encrypter.h b/quiche/quic/core/crypto/aes_128_gcm_encrypter.h new file mode 100644 index 000000000000..a40735c9e37e --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_encrypter.h @@ -0,0 +1,32 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_ENCRYPTER_H_ + +#include "quiche/quic/core/crypto/aes_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An Aes128GcmEncrypter is a QuicEncrypter that implements the +// AEAD_AES_128_GCM algorithm specified in RFC 5116 for use in IETF QUIC. +// +// It uses an authentication tag of 16 bytes (128 bits). It uses a 12 byte IV +// that is XOR'd with the packet number to compute the nonce. +class QUIC_EXPORT_PRIVATE Aes128GcmEncrypter : public AesBaseEncrypter { + public: + enum { + kAuthTagSize = 16, + }; + + Aes128GcmEncrypter(); + Aes128GcmEncrypter(const Aes128GcmEncrypter&) = delete; + Aes128GcmEncrypter& operator=(const Aes128GcmEncrypter&) = delete; + ~Aes128GcmEncrypter() override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_128_GCM_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc b/quiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc new file mode 100644 index 000000000000..70860944b55c --- /dev/null +++ b/quiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc @@ -0,0 +1,273 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_128_gcm_encrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The AES GCM test vectors come from the file gcmEncryptExtIV128.rsp +// downloaded from http://csrc.nist.gov/groups/STM/cavp/index.html on +// 2013-02-01. The test vectors in that file look like this: +// +// [Keylen = 128] +// [IVlen = 96] +// [PTlen = 0] +// [AADlen = 0] +// [Taglen = 128] +// +// Count = 0 +// Key = 11754cd72aec309bf52f7687212e8957 +// IV = 3c819d9a9bed087615030b65 +// PT = +// AAD = +// CT = +// Tag = 250327c674aaf477aef2675748cf6971 +// +// Count = 1 +// Key = ca47248ac0b6f8372a97ac43508308ed +// IV = ffd2b598feabc9019262d2be +// PT = +// AAD = +// CT = +// Tag = 60d20404af527d248d893ae495707d1a +// +// ... +// +// The gcmEncryptExtIV128.rsp file is huge (2.8 MB), so I selected just a +// few test vectors for this unit test. + +// Describes a group of test vectors that all have a given key length, IV +// length, plaintext length, AAD length, and tag length. +struct TestGroupInfo { + size_t key_len; + size_t iv_len; + size_t pt_len; + size_t aad_len; + size_t tag_len; +}; + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + const char* key; + const char* iv; + const char* pt; + const char* aad; + const char* ct; + const char* tag; +}; + +const TestGroupInfo test_group_info[] = { + {128, 96, 0, 0, 128}, {128, 96, 0, 128, 128}, {128, 96, 128, 0, 128}, + {128, 96, 408, 160, 128}, {128, 96, 408, 720, 128}, {128, 96, 104, 0, 128}, +}; + +const TestVector test_group_0[] = { + {"11754cd72aec309bf52f7687212e8957", "3c819d9a9bed087615030b65", "", "", "", + "250327c674aaf477aef2675748cf6971"}, + {"ca47248ac0b6f8372a97ac43508308ed", "ffd2b598feabc9019262d2be", "", "", "", + "60d20404af527d248d893ae495707d1a"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_1[] = { + {"77be63708971c4e240d1cb79e8d77feb", "e0e00f19fed7ba0136a797f3", "", + "7a43ec1d9c0a5a78a0b16533a6213cab", "", + "209fcc8d3675ed938e9c7166709dd946"}, + {"7680c5d3ca6154758e510f4d25b98820", "f8f105f9c3df4965780321f8", "", + "c94c410194c765e3dcc7964379758ed3", "", + "94dca8edfcf90bb74b153c8d48a17930"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_2[] = { + {"7fddb57453c241d03efbed3ac44e371c", "ee283a3fc75575e33efd4887", + "d5de42b461646c255c87bd2962d3b9a2", "", "2ccda4a5415cb91e135c2a0f78c9b2fd", + "b36d1df9b9d5e596f83e8b7f52971cb3"}, + {"ab72c77b97cb5fe9a382d9fe81ffdbed", "54cc7dc2c37ec006bcc6d1da", + "007c5e5b3e59df24a7c355584fc1518d", "", "0e1bde206a07a9c2c1b65300f8c64997", + "2b4401346697138c7a4891ee59867d0c"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_3[] = { + {"fe47fcce5fc32665d2ae399e4eec72ba", "5adb9609dbaeb58cbd6e7275", + "7c0e88c88899a779228465074797cd4c2e1498d259b54390b85e3eef1c02df60e743f1" + "b840382c4bccaf3bafb4ca8429bea063", + "88319d6e1d3ffa5f987199166c8a9b56c2aeba5a", + "98f4826f05a265e6dd2be82db241c0fbbbf9ffb1c173aa83964b7cf539304373636525" + "3ddbc5db8778371495da76d269e5db3e", + "291ef1982e4defedaa2249f898556b47"}, + {"ec0c2ba17aa95cd6afffe949da9cc3a8", "296bce5b50b7d66096d627ef", + "b85b3753535b825cbe5f632c0b843c741351f18aa484281aebec2f45bb9eea2d79d987" + "b764b9611f6c0f8641843d5d58f3a242", + "f8d00f05d22bf68599bcdeb131292ad6e2df5d14", + "a7443d31c26bdf2a1c945e29ee4bd344a99cfaf3aa71f8b3f191f83c2adfc7a0716299" + "5506fde6309ffc19e716eddf1a828c5a", + "890147971946b627c40016da1ecf3e77"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_4[] = { + {"2c1f21cf0f6fb3661943155c3e3d8492", "23cb5ff362e22426984d1907", + "42f758836986954db44bf37c6ef5e4ac0adaf38f27252a1b82d02ea949c8a1a2dbc0d6" + "8b5615ba7c1220ff6510e259f06655d8", + "5d3624879d35e46849953e45a32a624d6a6c536ed9857c613b572b0333e701557a713e" + "3f010ecdf9a6bd6c9e3e44b065208645aff4aabee611b391528514170084ccf587177f" + "4488f33cfb5e979e42b6e1cfc0a60238982a7aec", + "81824f0e0d523db30d3da369fdc0d60894c7a0a20646dd015073ad2732bd989b14a222" + "b6ad57af43e1895df9dca2a5344a62cc", + "57a3ee28136e94c74838997ae9823f3a"}, + {"d9f7d2411091f947b4d6f1e2d1f0fb2e", "e1934f5db57cc983e6b180e7", + "73ed042327f70fe9c572a61545eda8b2a0c6e1d6c291ef19248e973aee6c312012f490" + "c2c6f6166f4a59431e182663fcaea05a", + "0a8a18a7150e940c3d87b38e73baee9a5c049ee21795663e264b694a949822b639092d" + "0e67015e86363583fcf0ca645af9f43375f05fdb4ce84f411dcbca73c2220dea03a201" + "15d2e51398344b16bee1ed7c499b353d6c597af8", + "aaadbd5c92e9151ce3db7210b8714126b73e43436d242677afa50384f2149b831f1d57" + "3c7891c2a91fbc48db29967ec9542b23", + "21b51ca862cb637cdd03b99a0f93b134"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_5[] = { + {"fe9bb47deb3a61e423c2231841cfd1fb", "4d328eb776f500a2f7fb47aa", + "f1cc3818e421876bb6b8bbd6c9", "", "b88c5c1977b35b517b0aeae967", + "43fd4727fe5cdb4b5b42818dea7ef8c9"}, + {"6703df3701a7f54911ca72e24dca046a", "12823ab601c350ea4bc2488c", + "793cd125b0b84a043e3ac67717", "", "b2051c80014f42f08735a7b0cd", + "38e6bcd29962e5f2c13626b85a877101"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector* const test_group_array[] = { + test_group_0, test_group_1, test_group_2, + test_group_3, test_group_4, test_group_5, +}; + +} // namespace + +namespace quic { +namespace test { + +// EncryptWithNonce wraps the |Encrypt| method of |encrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the ciphertext. +QuicData* EncryptWithNonce(Aes128GcmEncrypter* encrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view plaintext) { + size_t ciphertext_size = encrypter->GetCiphertextSize(plaintext.length()); + std::unique_ptr ciphertext(new char[ciphertext_size]); + + if (!encrypter->Encrypt(nonce, associated_data, plaintext, + reinterpret_cast(ciphertext.get()))) { + return nullptr; + } + + return new QuicData(ciphertext.release(), ciphertext_size, true); +} + +class Aes128GcmEncrypterTest : public QuicTest {}; + +TEST_F(Aes128GcmEncrypterTest, Encrypt) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_group_array); i++) { + SCOPED_TRACE(i); + const TestVector* test_vectors = test_group_array[i]; + const TestGroupInfo& test_info = test_group_info[i]; + for (size_t j = 0; test_vectors[j].key != nullptr; j++) { + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[j].key); + std::string iv = absl::HexStringToBytes(test_vectors[j].iv); + std::string pt = absl::HexStringToBytes(test_vectors[j].pt); + std::string aad = absl::HexStringToBytes(test_vectors[j].aad); + std::string ct = absl::HexStringToBytes(test_vectors[j].ct); + std::string tag = absl::HexStringToBytes(test_vectors[j].tag); + + // The test vector's lengths should look sane. Note that the lengths + // in |test_info| are in bits. + EXPECT_EQ(test_info.key_len, key.length() * 8); + EXPECT_EQ(test_info.iv_len, iv.length() * 8); + EXPECT_EQ(test_info.pt_len, pt.length() * 8); + EXPECT_EQ(test_info.aad_len, aad.length() * 8); + EXPECT_EQ(test_info.pt_len, ct.length() * 8); + EXPECT_EQ(test_info.tag_len, tag.length() * 8); + + Aes128GcmEncrypter encrypter; + ASSERT_TRUE(encrypter.SetKey(key)); + std::unique_ptr encrypted( + EncryptWithNonce(&encrypter, iv, + // This deliberately tests that the encrypter can + // handle an AAD that is set to nullptr, as opposed + // to a zero-length, non-nullptr pointer. + aad.length() ? aad : absl::string_view(), pt)); + ASSERT_TRUE(encrypted.get()); + + ASSERT_EQ(ct.length() + tag.length(), encrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "ciphertext", encrypted->data(), ct.length(), ct.data(), ct.length()); + quiche::test::CompareCharArraysWithHexError( + "authentication tag", encrypted->data() + ct.length(), tag.length(), + tag.data(), tag.length()); + } + } +} + +TEST_F(Aes128GcmEncrypterTest, EncryptPacket) { + std::string key = absl::HexStringToBytes("d95a145250826c25a77b6a84fd4d34fc"); + std::string iv = absl::HexStringToBytes("50c4431ebb18283448e276e2"); + uint64_t packet_num = 0x13278f44; + std::string aad = + absl::HexStringToBytes("875d49f64a70c9cbe713278f44ff000005"); + std::string pt = absl::HexStringToBytes("aa0003a250bd000000000001"); + std::string ct = absl::HexStringToBytes( + "7dd4708b989ee7d38a013e3656e9b37beefd05808fe1ab41e3b4f2c0"); + + std::vector out(ct.size()); + size_t out_size; + + Aes128GcmEncrypter encrypter; + ASSERT_TRUE(encrypter.SetKey(key)); + ASSERT_TRUE(encrypter.SetIV(iv)); + ASSERT_TRUE(encrypter.EncryptPacket(packet_num, aad, pt, out.data(), + &out_size, out.size())); + EXPECT_EQ(out_size, out.size()); + quiche::test::CompareCharArraysWithHexError("ciphertext", out.data(), + out.size(), ct.data(), ct.size()); +} + +TEST_F(Aes128GcmEncrypterTest, GetMaxPlaintextSize) { + Aes128GcmEncrypter encrypter; + EXPECT_EQ(1000u, encrypter.GetMaxPlaintextSize(1016)); + EXPECT_EQ(100u, encrypter.GetMaxPlaintextSize(116)); + EXPECT_EQ(10u, encrypter.GetMaxPlaintextSize(26)); +} + +TEST_F(Aes128GcmEncrypterTest, GetCiphertextSize) { + Aes128GcmEncrypter encrypter; + EXPECT_EQ(1016u, encrypter.GetCiphertextSize(1000)); + EXPECT_EQ(116u, encrypter.GetCiphertextSize(100)); + EXPECT_EQ(26u, encrypter.GetCiphertextSize(10)); +} + +TEST_F(Aes128GcmEncrypterTest, GenerateHeaderProtectionMask) { + Aes128GcmEncrypter encrypter; + std::string key = absl::HexStringToBytes("d9132370cb18476ab833649cf080d970"); + std::string sample = + absl::HexStringToBytes("d1d7998068517adb769b48b924a32c47"); + ASSERT_TRUE(encrypter.SetHeaderProtectionKey(key)); + std::string mask = encrypter.GenerateHeaderProtectionMask(sample); + std::string expected_mask = + absl::HexStringToBytes("b132c37d6164da4ea4dc9b763aceec27"); + quiche::test::CompareCharArraysWithHexError( + "header protection mask", mask.data(), mask.size(), expected_mask.data(), + expected_mask.size()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_256_gcm_decrypter.cc b/quiche/quic/core/crypto/aes_256_gcm_decrypter.cc new file mode 100644 index 000000000000..58d4e3c2cf30 --- /dev/null +++ b/quiche/quic/core/crypto/aes_256_gcm_decrypter.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_256_gcm_decrypter.h" + +#include "openssl/aead.h" +#include "openssl/tls1.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 32; +const size_t kNonceSize = 12; + +} // namespace + +Aes256GcmDecrypter::Aes256GcmDecrypter() + : AesBaseDecrypter(EVP_aead_aes_256_gcm, kKeySize, kAuthTagSize, kNonceSize, + /* use_ietf_nonce_construction */ true) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +Aes256GcmDecrypter::~Aes256GcmDecrypter() {} + +uint32_t Aes256GcmDecrypter::cipher_id() const { + return TLS1_CK_AES_256_GCM_SHA384; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_256_gcm_decrypter.h b/quiche/quic/core/crypto/aes_256_gcm_decrypter.h new file mode 100644 index 000000000000..dc4f8c08486f --- /dev/null +++ b/quiche/quic/core/crypto/aes_256_gcm_decrypter.h @@ -0,0 +1,36 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_256_GCM_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_256_GCM_DECRYPTER_H_ + +#include + +#include "quiche/quic/core/crypto/aes_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An Aes256GcmDecrypter is a QuicDecrypter that implements the +// AEAD_AES_256_GCM algorithm specified in RFC 5116 for use in IETF QUIC. +// +// It uses an authentication tag of 16 bytes (128 bits). It uses a 12 byte IV +// that is XOR'd with the packet number to compute the nonce. +class QUIC_EXPORT_PRIVATE Aes256GcmDecrypter : public AesBaseDecrypter { + public: + enum { + kAuthTagSize = 16, + }; + + Aes256GcmDecrypter(); + Aes256GcmDecrypter(const Aes256GcmDecrypter&) = delete; + Aes256GcmDecrypter& operator=(const Aes256GcmDecrypter&) = delete; + ~Aes256GcmDecrypter() override; + + uint32_t cipher_id() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_256_GCM_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc b/quiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc new file mode 100644 index 000000000000..7c48c8cd87e9 --- /dev/null +++ b/quiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc @@ -0,0 +1,297 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_256_gcm_decrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The AES GCM test vectors come from the file gcmDecrypt256.rsp +// downloaded from +// https://csrc.nist.gov/Projects/Cryptographic-Algorithm-Validation-Program/CAVP-TESTING-BLOCK-CIPHER-MODES#GCMVS +// on 2017-09-27. The test vectors in that file look like this: +// +// [Keylen = 256] +// [IVlen = 96] +// [PTlen = 0] +// [AADlen = 0] +// [Taglen = 128] +// +// Count = 0 +// Key = f5a2b27c74355872eb3ef6c5feafaa740e6ae990d9d48c3bd9bb8235e589f010 +// IV = 58d2240f580a31c1d24948e9 +// CT = +// AAD = +// Tag = 15e051a5e4a5f5da6cea92e2ebee5bac +// PT = +// +// Count = 1 +// Key = e5a8123f2e2e007d4e379ba114a2fb66e6613f57c72d4e4f024964053028a831 +// IV = 51e43385bf533e168427e1ad +// CT = +// AAD = +// Tag = 38fe845c66e66bdd884c2aecafd280e6 +// FAIL +// +// ... +// +// The gcmDecrypt256.rsp file is huge (3.0 MB), so a few test vectors were +// selected for this unit test. + +// Describes a group of test vectors that all have a given key length, IV +// length, plaintext length, AAD length, and tag length. +struct TestGroupInfo { + size_t key_len; + size_t iv_len; + size_t pt_len; + size_t aad_len; + size_t tag_len; +}; + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + // Input: + const char* key; + const char* iv; + const char* ct; + const char* aad; + const char* tag; + + // Expected output: + const char* pt; // An empty string "" means decryption succeeded and + // the plaintext is zero-length. nullptr means decryption + // failed. +}; + +const TestGroupInfo test_group_info[] = { + {256, 96, 0, 0, 128}, {256, 96, 0, 128, 128}, {256, 96, 128, 0, 128}, + {256, 96, 408, 160, 128}, {256, 96, 408, 720, 128}, {256, 96, 104, 0, 128}, +}; + +const TestVector test_group_0[] = { + {"f5a2b27c74355872eb3ef6c5feafaa740e6ae990d9d48c3bd9bb8235e589f010", + "58d2240f580a31c1d24948e9", "", "", "15e051a5e4a5f5da6cea92e2ebee5bac", + ""}, + { + "e5a8123f2e2e007d4e379ba114a2fb66e6613f57c72d4e4f024964053028a831", + "51e43385bf533e168427e1ad", "", "", "38fe845c66e66bdd884c2aecafd280e6", + nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_1[] = { + {"6dfdafd6703c285c01f14fd10a6012862b2af950d4733abb403b2e745b26945d", + "3749d0b3d5bacb71be06ade6", "", "c0d249871992e70302ae008193d1e89f", + "4aa4cc69f84ee6ac16d9bfb4e05de500", ""}, + { + "2c392a5eb1a9c705371beda3a901c7c61dca4d93b4291de1dd0dd15ec11ffc45", + "0723fb84a08f4ea09841f32a", "", "140be561b6171eab942c486a94d33d43", + "aa0e1c9b57975bfc91aa137231977d2c", nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_2[] = { + {"4c8ebfe1444ec1b2d503c6986659af2c94fafe945f72c1e8486a5acfedb8a0f8", + "473360e0ad24889959858995", "d2c78110ac7e8f107c0df0570bd7c90c", "", + "c26a379b6d98ef2852ead8ce83a833a7", "7789b41cb3ee548814ca0b388c10b343"}, + {"3934f363fd9f771352c4c7a060682ed03c2864223a1573b3af997e2ababd60ab", + "efe2656d878c586e41c539c4", "e0de64302ac2d04048d65a87d2ad09fe", "", + "33cbd8d2fb8a3a03e30c1eb1b53c1d99", "697aff2d6b77e5ed6232770e400c1ead"}, + { + "c997768e2d14e3d38259667a6649079de77beb4543589771e5068e6cd7cd0b14", + "835090aed9552dbdd45277e2", "9f6607d68e22ccf21928db0986be126e", "", + "f32617f67c574fd9f44ef76ff880ab9f", nullptr // FAIL + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_3[] = { + { + "e9d381a9c413bee66175d5586a189836e5c20f5583535ab4d3f3e612dc21700e", + "23e81571da1c7821c681c7ca", + "a25f3f580306cd5065d22a6b7e9660110af7204bb77d370f7f34bee547feeff7b32a59" + "6fce29c9040e68b1589aad48da881990", + "6f39c9ae7b8e8a58a95f0dd8ea6a9087cbccdfd6", + "5b6dcd70eefb0892fab1539298b92a4b", + nullptr // FAIL + }, + {"6450d4501b1e6cfbe172c4c8570363e96b496591b842661c28c2f6c908379cad", + "7e4262035e0bf3d60e91668a", + "5a99b336fd3cfd82f10fb08f7045012415f0d9a06bb92dcf59c6f0dbe62d433671aacb8a1" + "c52ce7bbf6aea372bf51e2ba79406", + "f1c522f026e4c5d43851da516a1b78768ab18171", + "fe93b01636f7bb0458041f213e98de65", + "17449e236ef5858f6d891412495ead4607bfae2a2d735182a2a0242f9d52fc5345ef912db" + "e16f3bb4576fe3bcafe336dee6085"}, + {"90f2e71ccb1148979cb742efc8f921de95457d898c84ce28edeed701650d3a26", + "aba58ad60047ba553f6e4c98", + "3fc77a5fe9203d091c7916587c9763cf2e4d0d53ca20b078b851716f1dab4873fe342b7b3" + "01402f015d00263bf3f77c58a99d6", + "2abe465df6e5be47f05b92c9a93d76ae3611fac5", + "9cb3d04637048bc0bddef803ffbb56cf", + "1d21639640e11638a2769e3fab78778f84be3f4a8ce28dfd99cb2e75171e05ea8e94e30aa" + "78b54bb402b39d613616a8ed951dc"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_4[] = { + { + "e36aca93414b13f5313e76a7244588ee116551d1f34c32859166f2eb0ac1a9b7", + "e9e701b1ccef6bddd03391d8", + "5b059ac6733b6de0e8cf5b88b7301c02c993426f71bb12abf692e9deeacfac1ff1644c" + "87d4df130028f515f0feda636309a24d", + "6a08fe6e55a08f283cec4c4b37676e770f402af6102f548ad473ec6236da764f7076ff" + "d41bbd9611b439362d899682b7b0f839fc5a68d9df54afd1e2b3c4e7d072454ee27111" + "d52193d28b9c4f925d2a8b451675af39191a2cba", + "43c7c9c93cc265fc8e192000e0417b5b", + nullptr // FAIL + }, + {"5f72046245d3f4a0877e50a86554bfd57d1c5e073d1ed3b5451f6d0fc2a8507a", + "ea6f5b391e44b751b26bce6f", + "0e6e0b2114c40769c15958d965a14dcf50b680e0185a4409d77d894ca15b1e698dd83b353" + "6b18c05d8cd0873d1edce8150ecb5", + "9b3a68c941d42744673fb60fea49075eae77322e7e70e34502c115b6495ebfc796d629080" + "7653c6b53cd84281bd0311656d0013f44619d2748177e99e8f8347c989a7b59f9d8dcf00f" + "31db0684a4a83e037e8777bae55f799b0d", + "fdaaff86ceb937502cd9012d03585800", + "b0a881b751cc1eb0c912a4cf9bd971983707dbd2411725664503455c55db25cdb19bc669c" + "2654a3a8011de6bf7eff3f9f07834"}, + {"ab639bae205547607506522bd3cdca7861369e2b42ef175ff135f6ba435d5a8e", + "5fbb63eb44bd59fee458d8f6", + "9a34c62bed0972285503a32812877187a54dedbd55d2317fed89282bf1af4ba0b6bb9f9e1" + "6dd86da3b441deb7841262bc6bd63", + "1ef2b1768b805587935ffaf754a11bd2a305076d6374f1f5098b1284444b78f55408a786d" + "a37e1b7f1401c330d3585ef56f3e4d35eaaac92e1381d636477dc4f4beaf559735e902d6b" + "e58723257d4ac1ed9bd213de387f35f3c4", + "e0299e079bff46fd12e36d1c60e41434", + "e5a3ce804a8516cdd12122c091256b789076576040dbf3c55e8be3c016025896b8a72532b" + "fd51196cc82efca47aa0fd8e2e0dc"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_5[] = { + { + "8b37c4b8cf634704920059866ad96c49e9da502c63fca4a3a7a4dcec74cb0610", + "cb59344d2b06c4ae57cd0ea4", "66ab935c93555e786b775637a3", "", + "d8733acbb564d8afaa99d7ca2e2f92a9", nullptr // FAIL + }, + {"a71dac1377a3bf5d7fb1b5e36bee70d2e01de2a84a1c1009ba7448f7f26131dc", + "c5b60dda3f333b1146e9da7c", "43af49ec1ae3738a20755034d6", "", + "6f80b6ef2d8830a55eb63680a8dff9e0", "5b87141335f2becac1a559e05f"}, + {"dc1f64681014be221b00793bbcf5a5bc675b968eb7a3a3d5aa5978ef4fa45ecc", + "056ae9a1a69e38af603924fe", "33013a48d9ea0df2911d583271", "", + "5b8f9cc22303e979cd1524187e9f70fe", "2a7e05612191c8bce2f529dca9"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector* const test_group_array[] = { + test_group_0, test_group_1, test_group_2, + test_group_3, test_group_4, test_group_5, +}; + +} // namespace + +namespace quic { +namespace test { + +// DecryptWithNonce wraps the |Decrypt| method of |decrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the plaintext. +QuicData* DecryptWithNonce(Aes256GcmDecrypter* decrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view ciphertext) { + decrypter->SetIV(nonce); + std::unique_ptr output(new char[ciphertext.length()]); + size_t output_length = 0; + const bool success = + decrypter->DecryptPacket(0, associated_data, ciphertext, output.get(), + &output_length, ciphertext.length()); + if (!success) { + return nullptr; + } + return new QuicData(output.release(), output_length, true); +} + +class Aes256GcmDecrypterTest : public QuicTest {}; + +TEST_F(Aes256GcmDecrypterTest, Decrypt) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_group_array); i++) { + SCOPED_TRACE(i); + const TestVector* test_vectors = test_group_array[i]; + const TestGroupInfo& test_info = test_group_info[i]; + for (size_t j = 0; test_vectors[j].key != nullptr; j++) { + // If not present then decryption is expected to fail. + bool has_pt = test_vectors[j].pt; + + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[j].key); + std::string iv = absl::HexStringToBytes(test_vectors[j].iv); + std::string ct = absl::HexStringToBytes(test_vectors[j].ct); + std::string aad = absl::HexStringToBytes(test_vectors[j].aad); + std::string tag = absl::HexStringToBytes(test_vectors[j].tag); + std::string pt; + if (has_pt) { + pt = absl::HexStringToBytes(test_vectors[j].pt); + } + + // The test vector's lengths should look sane. Note that the lengths + // in |test_info| are in bits. + EXPECT_EQ(test_info.key_len, key.length() * 8); + EXPECT_EQ(test_info.iv_len, iv.length() * 8); + EXPECT_EQ(test_info.pt_len, ct.length() * 8); + EXPECT_EQ(test_info.aad_len, aad.length() * 8); + EXPECT_EQ(test_info.tag_len, tag.length() * 8); + if (has_pt) { + EXPECT_EQ(test_info.pt_len, pt.length() * 8); + } + std::string ciphertext = ct + tag; + + Aes256GcmDecrypter decrypter; + ASSERT_TRUE(decrypter.SetKey(key)); + + std::unique_ptr decrypted(DecryptWithNonce( + &decrypter, iv, + // This deliberately tests that the decrypter can + // handle an AAD that is set to nullptr, as opposed + // to a zero-length, non-nullptr pointer. + aad.length() ? aad : absl::string_view(), ciphertext)); + if (!decrypted) { + EXPECT_FALSE(has_pt); + continue; + } + EXPECT_TRUE(has_pt); + + ASSERT_EQ(pt.length(), decrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "plaintext", decrypted->data(), pt.length(), pt.data(), pt.length()); + } + } +} + +TEST_F(Aes256GcmDecrypterTest, GenerateHeaderProtectionMask) { + Aes256GcmDecrypter decrypter; + std::string key = absl::HexStringToBytes( + "ed23ecbf54d426def5c52c3dcfc84434e62e57781d3125bb21ed91b7d3e07788"); + std::string sample = + absl::HexStringToBytes("4d190c474be2b8babafb49ec4e38e810"); + QuicDataReader sample_reader(sample.data(), sample.size()); + ASSERT_TRUE(decrypter.SetHeaderProtectionKey(key)); + std::string mask = decrypter.GenerateHeaderProtectionMask(&sample_reader); + std::string expected_mask = + absl::HexStringToBytes("db9ed4e6ccd033af2eae01407199c56e"); + quiche::test::CompareCharArraysWithHexError( + "header protection mask", mask.data(), mask.size(), expected_mask.data(), + expected_mask.size()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_256_gcm_encrypter.cc b/quiche/quic/core/crypto/aes_256_gcm_encrypter.cc new file mode 100644 index 000000000000..802ff992c9ba --- /dev/null +++ b/quiche/quic/core/crypto/aes_256_gcm_encrypter.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_256_gcm_encrypter.h" + +#include "openssl/evp.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 32; +const size_t kNonceSize = 12; + +} // namespace + +Aes256GcmEncrypter::Aes256GcmEncrypter() + : AesBaseEncrypter(EVP_aead_aes_256_gcm, kKeySize, kAuthTagSize, kNonceSize, + /* use_ietf_nonce_construction */ true) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +Aes256GcmEncrypter::~Aes256GcmEncrypter() {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_256_gcm_encrypter.h b/quiche/quic/core/crypto/aes_256_gcm_encrypter.h new file mode 100644 index 000000000000..9ba47f12130c --- /dev/null +++ b/quiche/quic/core/crypto/aes_256_gcm_encrypter.h @@ -0,0 +1,32 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_256_GCM_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_256_GCM_ENCRYPTER_H_ + +#include "quiche/quic/core/crypto/aes_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An Aes256GcmEncrypter is a QuicEncrypter that implements the +// AEAD_AES_256_GCM algorithm specified in RFC 5116 for use in IETF QUIC. +// +// It uses an authentication tag of 16 bytes (128 bits). It uses a 12 byte IV +// that is XOR'd with the packet number to compute the nonce. +class QUIC_EXPORT_PRIVATE Aes256GcmEncrypter : public AesBaseEncrypter { + public: + enum { + kAuthTagSize = 16, + }; + + Aes256GcmEncrypter(); + Aes256GcmEncrypter(const Aes256GcmEncrypter&) = delete; + Aes256GcmEncrypter& operator=(const Aes256GcmEncrypter&) = delete; + ~Aes256GcmEncrypter() override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_256_GCM_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc b/quiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc new file mode 100644 index 000000000000..6389fdb53eff --- /dev/null +++ b/quiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc @@ -0,0 +1,259 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_256_gcm_encrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The AES GCM test vectors come from the file gcmEncryptExtIV256.rsp +// downloaded from +// https://csrc.nist.gov/Projects/Cryptographic-Algorithm-Validation-Program/CAVP-TESTING-BLOCK-CIPHER-MODES#GCMVS +// on 2017-09-27. The test vectors in that file look like this: +// +// [Keylen = 256] +// [IVlen = 96] +// [PTlen = 0] +// [AADlen = 0] +// [Taglen = 128] +// +// Count = 0 +// Key = b52c505a37d78eda5dd34f20c22540ea1b58963cf8e5bf8ffa85f9f2492505b4 +// IV = 516c33929df5a3284ff463d7 +// PT = +// AAD = +// CT = +// Tag = bdc1ac884d332457a1d2664f168c76f0 +// +// Count = 1 +// Key = 5fe0861cdc2690ce69b3658c7f26f8458eec1c9243c5ba0845305d897e96ca0f +// IV = 770ac1a5a3d476d5d96944a1 +// PT = +// AAD = +// CT = +// Tag = 196d691e1047093ca4b3d2ef4baba216 +// +// ... +// +// The gcmEncryptExtIV256.rsp file is huge (3.2 MB), so a few test vectors were +// selected for this unit test. + +// Describes a group of test vectors that all have a given key length, IV +// length, plaintext length, AAD length, and tag length. +struct TestGroupInfo { + size_t key_len; + size_t iv_len; + size_t pt_len; + size_t aad_len; + size_t tag_len; +}; + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + const char* key; + const char* iv; + const char* pt; + const char* aad; + const char* ct; + const char* tag; +}; + +const TestGroupInfo test_group_info[] = { + {256, 96, 0, 0, 128}, {256, 96, 0, 128, 128}, {256, 96, 128, 0, 128}, + {256, 96, 408, 160, 128}, {256, 96, 408, 720, 128}, {256, 96, 104, 0, 128}, +}; + +const TestVector test_group_0[] = { + {"b52c505a37d78eda5dd34f20c22540ea1b58963cf8e5bf8ffa85f9f2492505b4", + "516c33929df5a3284ff463d7", "", "", "", + "bdc1ac884d332457a1d2664f168c76f0"}, + {"5fe0861cdc2690ce69b3658c7f26f8458eec1c9243c5ba0845305d897e96ca0f", + "770ac1a5a3d476d5d96944a1", "", "", "", + "196d691e1047093ca4b3d2ef4baba216"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_1[] = { + {"78dc4e0aaf52d935c3c01eea57428f00ca1fd475f5da86a49c8dd73d68c8e223", + "d79cf22d504cc793c3fb6c8a", "", "b96baa8c1c75a671bfb2d08d06be5f36", "", + "3e5d486aa2e30b22e040b85723a06e76"}, + {"4457ff33683cca6ca493878bdc00373893a9763412eef8cddb54f91318e0da88", + "699d1f29d7b8c55300bb1fd2", "", "6749daeea367d0e9809e2dc2f309e6e3", "", + "d60c74d2517fde4a74e0cd4709ed43a9"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_2[] = { + {"31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22", + "0d18e06c7c725ac9e362e1ce", "2db5168e932556f8089a0622981d017d", "", + "fa4362189661d163fcd6a56d8bf0405a", "d636ac1bbedd5cc3ee727dc2ab4a9489"}, + {"460fc864972261c2560e1eb88761ff1c992b982497bd2ac36c04071cbb8e5d99", + "8a4a16b9e210eb68bcb6f58d", "99e4e926ffe927f691893fb79a96b067", "", + "133fc15751621b5f325c7ff71ce08324", "ec4e87e0cf74a13618d0b68636ba9fa7"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_3[] = { + {"24501ad384e473963d476edcfe08205237acfd49b5b8f33857f8114e863fec7f", + "9ff18563b978ec281b3f2794", + "27f348f9cdc0c5bd5e66b1ccb63ad920ff2219d14e8d631b3872265cf117ee86757accb15" + "8bd9abb3868fdc0d0b074b5f01b2c", + "adb5ec720ccf9898500028bf34afccbcaca126ef", + "eb7cb754c824e8d96f7c6d9b76c7d26fb874ffbf1d65c6f64a698d839b0b06145dae82057" + "ad55994cf59ad7f67c0fa5e85fab8", + "bc95c532fecc594c36d1550286a7a3f0"}, + {"fb43f5ab4a1738a30c1e053d484a94254125d55dccee1ad67c368bc1a985d235", + "9fbb5f8252db0bca21f1c230", + "34b797bb82250e23c5e796db2c37e488b3b99d1b981cea5e5b0c61a0b39adb6bd6ef1f507" + "22e2e4f81115cfcf53f842e2a6c08", + "98f8ae1735c39f732e2cbee1156dabeb854ec7a2", + "871cd53d95a8b806bd4821e6c4456204d27fd704ba3d07ce25872dc604ea5c5ea13322186" + "b7489db4fa060c1fd4159692612c8", + "07b48e4a32fac47e115d7ac7445d8330"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_4[] = { + {"148579a3cbca86d5520d66c0ec71ca5f7e41ba78e56dc6eebd566fed547fe691", + "b08a5ea1927499c6ecbfd4e0", + "9d0b15fdf1bd595f91f8b3abc0f7dec927dfd4799935a1795d9ce00c9b879434420fe42c2" + "75a7cd7b39d638fb81ca52b49dc41", + "e4f963f015ffbb99ee3349bbaf7e8e8e6c2a71c230a48f9d59860a29091d2747e01a5ca57" + "2347e247d25f56ba7ae8e05cde2be3c97931292c02370208ecd097ef692687fecf2f419d3" + "200162a6480a57dad408a0dfeb492e2c5d", + "2097e372950a5e9383c675e89eea1c314f999159f5611344b298cda45e62843716f215f82" + "ee663919c64002a5c198d7878fd3f", + "adbecdb0d5c2224d804d2886ff9a5760"}, + {"e49af19182faef0ebeeba9f2d3be044e77b1212358366e4ef59e008aebcd9788", + "e7f37d79a6a487a5a703edbb", + "461cd0caf7427a3d44408d825ed719237272ecd503b9094d1f62c97d63ed83a0b50bdc804" + "ffdd7991da7a5b6dcf48d4bcd2cbc", + "19a9a1cfc647346781bef51ed9070d05f99a0e0192a223c5cd2522dbdf97d9739dd39fb17" + "8ade3339e68774b058aa03e9a20a9a205bc05f32381df4d63396ef691fefd5a71b49a2ad8" + "2d5ea428778ca47ee1398792762413cff4", + "32ca3588e3e56eb4c8301b009d8b84b8a900b2b88ca3c21944205e9dd7311757b51394ae9" + "0d8bb3807b471677614f4198af909", + "3e403d035c71d88f1be1a256c89ba6ad"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector test_group_5[] = { + {"82c4f12eeec3b2d3d157b0f992d292b237478d2cecc1d5f161389b97f999057a", + "7b40b20f5f397177990ef2d1", "982a296ee1cd7086afad976945", "", + "ec8e05a0471d6b43a59ca5335f", "113ddeafc62373cac2f5951bb9165249"}, + {"db4340af2f835a6c6d7ea0ca9d83ca81ba02c29b7410f221cb6071114e393240", + "40e438357dd80a85cac3349e", "8ddb3397bd42853193cb0f80c9", "", + "b694118c85c41abf69e229cb0f", "c07f1b8aafbd152f697eb67f2a85fe45"}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +const TestVector* const test_group_array[] = { + test_group_0, test_group_1, test_group_2, + test_group_3, test_group_4, test_group_5, +}; + +} // namespace + +namespace quic { +namespace test { + +// EncryptWithNonce wraps the |Encrypt| method of |encrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the ciphertext. +QuicData* EncryptWithNonce(Aes256GcmEncrypter* encrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view plaintext) { + size_t ciphertext_size = encrypter->GetCiphertextSize(plaintext.length()); + std::unique_ptr ciphertext(new char[ciphertext_size]); + + if (!encrypter->Encrypt(nonce, associated_data, plaintext, + reinterpret_cast(ciphertext.get()))) { + return nullptr; + } + + return new QuicData(ciphertext.release(), ciphertext_size, true); +} + +class Aes256GcmEncrypterTest : public QuicTest {}; + +TEST_F(Aes256GcmEncrypterTest, Encrypt) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_group_array); i++) { + SCOPED_TRACE(i); + const TestVector* test_vectors = test_group_array[i]; + const TestGroupInfo& test_info = test_group_info[i]; + for (size_t j = 0; test_vectors[j].key != nullptr; j++) { + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[j].key); + std::string iv = absl::HexStringToBytes(test_vectors[j].iv); + std::string pt = absl::HexStringToBytes(test_vectors[j].pt); + std::string aad = absl::HexStringToBytes(test_vectors[j].aad); + std::string ct = absl::HexStringToBytes(test_vectors[j].ct); + std::string tag = absl::HexStringToBytes(test_vectors[j].tag); + + // The test vector's lengths should look sane. Note that the lengths + // in |test_info| are in bits. + EXPECT_EQ(test_info.key_len, key.length() * 8); + EXPECT_EQ(test_info.iv_len, iv.length() * 8); + EXPECT_EQ(test_info.pt_len, pt.length() * 8); + EXPECT_EQ(test_info.aad_len, aad.length() * 8); + EXPECT_EQ(test_info.pt_len, ct.length() * 8); + EXPECT_EQ(test_info.tag_len, tag.length() * 8); + + Aes256GcmEncrypter encrypter; + ASSERT_TRUE(encrypter.SetKey(key)); + std::unique_ptr encrypted( + EncryptWithNonce(&encrypter, iv, + // This deliberately tests that the encrypter can + // handle an AAD that is set to nullptr, as opposed + // to a zero-length, non-nullptr pointer. + aad.length() ? aad : absl::string_view(), pt)); + ASSERT_TRUE(encrypted.get()); + + ASSERT_EQ(ct.length() + tag.length(), encrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "ciphertext", encrypted->data(), ct.length(), ct.data(), ct.length()); + quiche::test::CompareCharArraysWithHexError( + "authentication tag", encrypted->data() + ct.length(), tag.length(), + tag.data(), tag.length()); + } + } +} + +TEST_F(Aes256GcmEncrypterTest, GetMaxPlaintextSize) { + Aes256GcmEncrypter encrypter; + EXPECT_EQ(1000u, encrypter.GetMaxPlaintextSize(1016)); + EXPECT_EQ(100u, encrypter.GetMaxPlaintextSize(116)); + EXPECT_EQ(10u, encrypter.GetMaxPlaintextSize(26)); +} + +TEST_F(Aes256GcmEncrypterTest, GetCiphertextSize) { + Aes256GcmEncrypter encrypter; + EXPECT_EQ(1016u, encrypter.GetCiphertextSize(1000)); + EXPECT_EQ(116u, encrypter.GetCiphertextSize(100)); + EXPECT_EQ(26u, encrypter.GetCiphertextSize(10)); +} + +TEST_F(Aes256GcmEncrypterTest, GenerateHeaderProtectionMask) { + Aes256GcmEncrypter encrypter; + std::string key = absl::HexStringToBytes( + "ed23ecbf54d426def5c52c3dcfc84434e62e57781d3125bb21ed91b7d3e07788"); + std::string sample = + absl::HexStringToBytes("4d190c474be2b8babafb49ec4e38e810"); + ASSERT_TRUE(encrypter.SetHeaderProtectionKey(key)); + std::string mask = encrypter.GenerateHeaderProtectionMask(sample); + std::string expected_mask = + absl::HexStringToBytes("db9ed4e6ccd033af2eae01407199c56e"); + quiche::test::CompareCharArraysWithHexError( + "header protection mask", mask.data(), mask.size(), expected_mask.data(), + expected_mask.size()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_base_decrypter.cc b/quiche/quic/core/crypto/aes_base_decrypter.cc new file mode 100644 index 000000000000..2962854c1750 --- /dev/null +++ b/quiche/quic/core/crypto/aes_base_decrypter.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_base_decrypter.h" + +#include "absl/strings/string_view.h" +#include "openssl/aes.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +bool AesBaseDecrypter::SetHeaderProtectionKey(absl::string_view key) { + if (key.size() != GetKeySize()) { + QUIC_BUG(quic_bug_10649_1) << "Invalid key size for header protection"; + return false; + } + if (AES_set_encrypt_key(reinterpret_cast(key.data()), + key.size() * 8, &pne_key_) != 0) { + QUIC_BUG(quic_bug_10649_2) << "Unexpected failure of AES_set_encrypt_key"; + return false; + } + return true; +} + +std::string AesBaseDecrypter::GenerateHeaderProtectionMask( + QuicDataReader* sample_reader) { + absl::string_view sample; + if (!sample_reader->ReadStringPiece(&sample, AES_BLOCK_SIZE)) { + return std::string(); + } + std::string out(AES_BLOCK_SIZE, 0); + AES_encrypt(reinterpret_cast(sample.data()), + reinterpret_cast(const_cast(out.data())), + &pne_key_); + return out; +} + +QuicPacketCount AesBaseDecrypter::GetIntegrityLimit() const { + // For AEAD_AES_128_GCM ... endpoints that do not attempt to remove + // protection from packets larger than 2^11 bytes can attempt to remove + // protection from at most 2^57 packets. + // For AEAD_AES_256_GCM [the limit] is substantially larger than the limit for + // AEAD_AES_128_GCM. However, this document recommends that the same limit be + // applied to both functions as either limit is acceptably large. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-integrity-limit + static_assert(kMaxIncomingPacketSize <= 2048, + "This key limit requires limits on decryption payload sizes"); + return 144115188075855872U; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_base_decrypter.h b/quiche/quic/core/crypto/aes_base_decrypter.h new file mode 100644 index 000000000000..9fa35cfd6666 --- /dev/null +++ b/quiche/quic/core/crypto/aes_base_decrypter.h @@ -0,0 +1,33 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_BASE_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_BASE_DECRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/aes.h" +#include "quiche/quic/core/crypto/aead_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE AesBaseDecrypter : public AeadBaseDecrypter { + public: + using AeadBaseDecrypter::AeadBaseDecrypter; + + bool SetHeaderProtectionKey(absl::string_view key) override; + std::string GenerateHeaderProtectionMask( + QuicDataReader* sample_reader) override; + QuicPacketCount GetIntegrityLimit() const override; + + private: + // The key used for packet number encryption. + AES_KEY pne_key_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_BASE_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/aes_base_encrypter.cc b/quiche/quic/core/crypto/aes_base_encrypter.cc new file mode 100644 index 000000000000..89ab64566416 --- /dev/null +++ b/quiche/quic/core/crypto/aes_base_encrypter.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/aes_base_encrypter.h" + +#include "absl/strings/string_view.h" +#include "openssl/aes.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +bool AesBaseEncrypter::SetHeaderProtectionKey(absl::string_view key) { + if (key.size() != GetKeySize()) { + QUIC_BUG(quic_bug_10726_1) + << "Invalid key size for header protection: " << key.size(); + return false; + } + if (AES_set_encrypt_key(reinterpret_cast(key.data()), + key.size() * 8, &pne_key_) != 0) { + QUIC_BUG(quic_bug_10726_2) << "Unexpected failure of AES_set_encrypt_key"; + return false; + } + return true; +} + +std::string AesBaseEncrypter::GenerateHeaderProtectionMask( + absl::string_view sample) { + if (sample.size() != AES_BLOCK_SIZE) { + return std::string(); + } + std::string out(AES_BLOCK_SIZE, 0); + AES_encrypt(reinterpret_cast(sample.data()), + reinterpret_cast(const_cast(out.data())), + &pne_key_); + return out; +} + +QuicPacketCount AesBaseEncrypter::GetConfidentialityLimit() const { + // For AEAD_AES_128_GCM and AEAD_AES_256_GCM ... endpoints that do not send + // packets larger than 2^11 bytes cannot protect more than 2^28 packets. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-confidentiality-limit + static_assert(kMaxOutgoingPacketSize <= 2048, + "This key limit requires limits on encryption payload sizes"); + return 268435456U; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/aes_base_encrypter.h b/quiche/quic/core/crypto/aes_base_encrypter.h new file mode 100644 index 000000000000..c4fdb86ecab1 --- /dev/null +++ b/quiche/quic/core/crypto/aes_base_encrypter.h @@ -0,0 +1,32 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_AES_BASE_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_AES_BASE_ENCRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/aes.h" +#include "quiche/quic/core/crypto/aead_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE AesBaseEncrypter : public AeadBaseEncrypter { + public: + using AeadBaseEncrypter::AeadBaseEncrypter; + + bool SetHeaderProtectionKey(absl::string_view key) override; + std::string GenerateHeaderProtectionMask(absl::string_view sample) override; + QuicPacketCount GetConfidentialityLimit() const override; + + private: + // The key used for packet number encryption. + AES_KEY pne_key_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_AES_BASE_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/boring_utils.h b/quiche/quic/core/crypto/boring_utils.h new file mode 100644 index 000000000000..d6a1dd413daf --- /dev/null +++ b/quiche/quic/core/crypto/boring_utils.h @@ -0,0 +1,34 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_BORING_UTILS_H_ +#define QUICHE_QUIC_CORE_CRYPTO_BORING_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "openssl/bytestring.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +inline QUIC_EXPORT_PRIVATE absl::string_view CbsToStringPiece(CBS cbs) { + return absl::string_view(reinterpret_cast(CBS_data(&cbs)), + CBS_len(&cbs)); +} + +inline QUIC_EXPORT_PRIVATE CBS StringPieceToCbs(absl::string_view piece) { + CBS result; + CBS_init(&result, reinterpret_cast(piece.data()), + piece.size()); + return result; +} + +inline QUIC_EXPORT_PRIVATE bool AddStringToCbb(CBB* cbb, + absl::string_view piece) { + return 1 == CBB_add_bytes(cbb, reinterpret_cast(piece.data()), + piece.size()); +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_BORING_UTILS_H_ diff --git a/quiche/quic/core/crypto/cert_compressor.cc b/quiche/quic/core/crypto/cert_compressor.cc new file mode 100644 index 000000000000..4357b9c7c826 --- /dev/null +++ b/quiche/quic/core/crypto/cert_compressor.cc @@ -0,0 +1,598 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/cert_compressor.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "zlib.h" + +namespace quic { + +namespace { + +// kCommonCertSubstrings contains ~1500 bytes of common certificate substrings +// in order to help zlib. This was generated via a fairly dumb algorithm from +// the Alexa Top 5000 set - we could probably do better. +static const unsigned char kCommonCertSubstrings[] = { + 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, + 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, + 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30, + 0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01, + 0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07, + 0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, + 0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34, + 0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, + 0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e, + 0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73, + 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65, + 0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, + 0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, + 0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06, + 0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d, + 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, + 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55, + 0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, + 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, + 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2, + 0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e, + 0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, + 0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, + 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03, + 0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09, + 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30, + 0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08, + 0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30, + 0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74, + 0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03, + 0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79, + 0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33, + 0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74, + 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, + 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, + 0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03, + 0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53, + 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, + 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55, + 0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37, + 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, + 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c, + 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00, + 0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, + 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, + 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, + 0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, + 0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, + 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d, + 0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86, + 0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e, + 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08, + 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74, + 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65, + 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, + 0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17, + 0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01, + 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a, + 0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01, + 0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04, + 0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, + 0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, + 0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17, + 0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72, + 0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, + 0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65, + 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74, + 0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39, + 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73, + 0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68, + 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, + 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e, + 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11, + 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11, + 0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01, + 0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04, + 0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, + 0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50, + 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30, + 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56, + 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31, + 0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65, + 0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63, + 0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53, + 0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41, + 0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, + 0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72, + 0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a, + 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a, + 0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72, + 0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72, + 0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e, + 0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, + 0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63, + 0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72, + 0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, + 0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b, + 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74, + 0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64, + 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06, + 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68, + 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, + 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, + 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, + 0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72, + 0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, + 0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee, + 0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27, + 0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30, + 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73, +}; + +// CertEntry represents a certificate in compressed form. Each entry is one of +// the three types enumerated in |Type|. +struct CertEntry { + public: + enum Type { + // Type 0 is reserved to mean "end of list" in the wire format. + + // COMPRESSED means that the certificate is included in the trailing zlib + // data. + COMPRESSED = 1, + // CACHED means that the certificate is already known to the peer and will + // be replaced by its 64-bit hash (in |hash|). + CACHED = 2, + }; + + Type type; + uint64_t hash; + uint64_t set_hash; + uint32_t index; +}; + +// MatchCerts returns a vector of CertEntries describing how to most +// efficiently represent |certs| to a peer who has cached the certificates +// with the 64-bit, FNV-1a hashes in |client_cached_cert_hashes|. +std::vector MatchCerts(const std::vector& certs, + absl::string_view client_cached_cert_hashes) { + std::vector entries; + entries.reserve(certs.size()); + + const bool cached_valid = + client_cached_cert_hashes.size() % sizeof(uint64_t) == 0 && + !client_cached_cert_hashes.empty(); + + for (auto i = certs.begin(); i != certs.end(); ++i) { + CertEntry entry; + + if (cached_valid) { + bool cached = false; + + uint64_t hash = QuicUtils::FNV1a_64_Hash(*i); + // This assumes that the machine is little-endian. + for (size_t j = 0; j < client_cached_cert_hashes.size(); + j += sizeof(uint64_t)) { + uint64_t cached_hash; + memcpy(&cached_hash, client_cached_cert_hashes.data() + j, + sizeof(uint64_t)); + if (hash != cached_hash) { + continue; + } + + entry.type = CertEntry::CACHED; + entry.hash = hash; + entries.push_back(entry); + cached = true; + break; + } + + if (cached) { + continue; + } + } + + entry.type = CertEntry::COMPRESSED; + entries.push_back(entry); + } + + return entries; +} + +// CertEntriesSize returns the size, in bytes, of the serialised form of +// |entries|. +size_t CertEntriesSize(const std::vector& entries) { + size_t entries_size = 0; + + for (auto i = entries.begin(); i != entries.end(); ++i) { + entries_size++; + switch (i->type) { + case CertEntry::COMPRESSED: + break; + case CertEntry::CACHED: + entries_size += sizeof(uint64_t); + break; + } + } + + entries_size++; // for end marker + + return entries_size; +} + +// SerializeCertEntries serialises |entries| to |out|, which must have enough +// space to contain them. +void SerializeCertEntries(uint8_t* out, const std::vector& entries) { + for (auto i = entries.begin(); i != entries.end(); ++i) { + *out++ = static_cast(i->type); + switch (i->type) { + case CertEntry::COMPRESSED: + break; + case CertEntry::CACHED: + memcpy(out, &i->hash, sizeof(i->hash)); + out += sizeof(uint64_t); + break; + } + } + + *out++ = 0; // end marker +} + +// ZlibDictForEntries returns a string that contains the zlib pre-shared +// dictionary to use in order to decompress a zlib block following |entries|. +// |certs| is one-to-one with |entries| and contains the certificates for those +// entries that are CACHED. +std::string ZlibDictForEntries(const std::vector& entries, + const std::vector& certs) { + std::string zlib_dict; + + // The dictionary starts with the cached certs in reverse order. + size_t zlib_dict_size = 0; + for (size_t i = certs.size() - 1; i < certs.size(); i--) { + if (entries[i].type != CertEntry::COMPRESSED) { + zlib_dict_size += certs[i].size(); + } + } + + // At the end of the dictionary is a block of common certificate substrings. + zlib_dict_size += sizeof(kCommonCertSubstrings); + + zlib_dict.reserve(zlib_dict_size); + + for (size_t i = certs.size() - 1; i < certs.size(); i--) { + if (entries[i].type != CertEntry::COMPRESSED) { + zlib_dict += certs[i]; + } + } + + zlib_dict += std::string(reinterpret_cast(kCommonCertSubstrings), + sizeof(kCommonCertSubstrings)); + + QUICHE_DCHECK_EQ(zlib_dict.size(), zlib_dict_size); + + return zlib_dict; +} + +// HashCerts returns the FNV-1a hashes of |certs|. +std::vector HashCerts(const std::vector& certs) { + std::vector ret; + ret.reserve(certs.size()); + + for (auto i = certs.begin(); i != certs.end(); ++i) { + ret.push_back(QuicUtils::FNV1a_64_Hash(*i)); + } + + return ret; +} + +// ParseEntries parses the serialised form of a vector of CertEntries from +// |in_out| and writes them to |out_entries|. CACHED entries are resolved using +// |cached_certs| and written to |out_certs|. |in_out| is updated to contain +// the trailing data. +bool ParseEntries(absl::string_view* in_out, + const std::vector& cached_certs, + std::vector* out_entries, + std::vector* out_certs) { + absl::string_view in = *in_out; + std::vector cached_hashes; + + out_entries->clear(); + out_certs->clear(); + + for (;;) { + if (in.empty()) { + return false; + } + CertEntry entry; + const uint8_t type_byte = in[0]; + in.remove_prefix(1); + + if (type_byte == 0) { + break; + } + + entry.type = static_cast(type_byte); + + switch (entry.type) { + case CertEntry::COMPRESSED: + out_certs->push_back(std::string()); + break; + case CertEntry::CACHED: { + if (in.size() < sizeof(uint64_t)) { + return false; + } + memcpy(&entry.hash, in.data(), sizeof(uint64_t)); + in.remove_prefix(sizeof(uint64_t)); + + if (cached_hashes.size() != cached_certs.size()) { + cached_hashes = HashCerts(cached_certs); + } + bool found = false; + for (size_t i = 0; i < cached_hashes.size(); i++) { + if (cached_hashes[i] == entry.hash) { + out_certs->push_back(cached_certs[i]); + found = true; + break; + } + } + if (!found) { + return false; + } + break; + } + + default: + return false; + } + out_entries->push_back(entry); + } + + *in_out = in; + return true; +} + +// ScopedZLib deals with the automatic destruction of a zlib context. +class ScopedZLib { + public: + enum Type { + INFLATE, + DEFLATE, + }; + + explicit ScopedZLib(Type type) : z_(nullptr), type_(type) {} + + void reset(z_stream* z) { + Clear(); + z_ = z; + } + + ~ScopedZLib() { Clear(); } + + private: + void Clear() { + if (!z_) { + return; + } + + if (type_ == DEFLATE) { + deflateEnd(z_); + } else { + inflateEnd(z_); + } + z_ = nullptr; + } + + z_stream* z_; + const Type type_; +}; + +} // anonymous namespace + +// static +std::string CertCompressor::CompressChain( + const std::vector& certs, + absl::string_view client_cached_cert_hashes) { + const std::vector entries = + MatchCerts(certs, client_cached_cert_hashes); + QUICHE_DCHECK_EQ(entries.size(), certs.size()); + + size_t uncompressed_size = 0; + for (size_t i = 0; i < entries.size(); i++) { + if (entries[i].type == CertEntry::COMPRESSED) { + uncompressed_size += 4 /* uint32_t length */ + certs[i].size(); + } + } + + size_t compressed_size = 0; + z_stream z; + ScopedZLib scoped_z(ScopedZLib::DEFLATE); + + if (uncompressed_size > 0) { + memset(&z, 0, sizeof(z)); + int rv = deflateInit(&z, Z_DEFAULT_COMPRESSION); + QUICHE_DCHECK_EQ(Z_OK, rv); + if (rv != Z_OK) { + return ""; + } + scoped_z.reset(&z); + + std::string zlib_dict = ZlibDictForEntries(entries, certs); + + rv = deflateSetDictionary( + &z, reinterpret_cast(&zlib_dict[0]), zlib_dict.size()); + QUICHE_DCHECK_EQ(Z_OK, rv); + if (rv != Z_OK) { + return ""; + } + + compressed_size = deflateBound(&z, uncompressed_size); + } + + const size_t entries_size = CertEntriesSize(entries); + + std::string result; + result.resize(entries_size + (uncompressed_size > 0 ? 4 : 0) + + compressed_size); + + uint8_t* j = reinterpret_cast(&result[0]); + SerializeCertEntries(j, entries); + j += entries_size; + + if (uncompressed_size == 0) { + return result; + } + + uint32_t uncompressed_size_32 = uncompressed_size; + memcpy(j, &uncompressed_size_32, sizeof(uint32_t)); + j += sizeof(uint32_t); + + int rv; + + z.next_out = j; + z.avail_out = compressed_size; + + for (size_t i = 0; i < certs.size(); i++) { + if (entries[i].type != CertEntry::COMPRESSED) { + continue; + } + + uint32_t length32 = certs[i].size(); + z.next_in = reinterpret_cast(&length32); + z.avail_in = sizeof(length32); + rv = deflate(&z, Z_NO_FLUSH); + QUICHE_DCHECK_EQ(Z_OK, rv); + QUICHE_DCHECK_EQ(0u, z.avail_in); + if (rv != Z_OK || z.avail_in) { + return ""; + } + + z.next_in = + const_cast(reinterpret_cast(certs[i].data())); + z.avail_in = certs[i].size(); + rv = deflate(&z, Z_NO_FLUSH); + QUICHE_DCHECK_EQ(Z_OK, rv); + QUICHE_DCHECK_EQ(0u, z.avail_in); + if (rv != Z_OK || z.avail_in) { + return ""; + } + } + + z.avail_in = 0; + rv = deflate(&z, Z_FINISH); + QUICHE_DCHECK_EQ(Z_STREAM_END, rv); + if (rv != Z_STREAM_END) { + return ""; + } + + result.resize(result.size() - z.avail_out); + return result; +} + +// static +bool CertCompressor::DecompressChain( + absl::string_view in, const std::vector& cached_certs, + std::vector* out_certs) { + std::vector entries; + if (!ParseEntries(&in, cached_certs, &entries, out_certs)) { + return false; + } + QUICHE_DCHECK_EQ(entries.size(), out_certs->size()); + + std::unique_ptr uncompressed_data; + absl::string_view uncompressed; + + if (!in.empty()) { + if (in.size() < sizeof(uint32_t)) { + return false; + } + + uint32_t uncompressed_size; + memcpy(&uncompressed_size, in.data(), sizeof(uncompressed_size)); + in.remove_prefix(sizeof(uint32_t)); + + if (uncompressed_size > 128 * 1024) { + return false; + } + + uncompressed_data = std::make_unique(uncompressed_size); + z_stream z; + ScopedZLib scoped_z(ScopedZLib::INFLATE); + + memset(&z, 0, sizeof(z)); + z.next_out = uncompressed_data.get(); + z.avail_out = uncompressed_size; + z.next_in = + const_cast(reinterpret_cast(in.data())); + z.avail_in = in.size(); + + if (Z_OK != inflateInit(&z)) { + return false; + } + scoped_z.reset(&z); + + int rv = inflate(&z, Z_FINISH); + if (rv == Z_NEED_DICT) { + std::string zlib_dict = ZlibDictForEntries(entries, *out_certs); + const uint8_t* dict = reinterpret_cast(zlib_dict.data()); + if (Z_OK != inflateSetDictionary(&z, dict, zlib_dict.size())) { + return false; + } + rv = inflate(&z, Z_FINISH); + } + + if (Z_STREAM_END != rv || z.avail_out > 0 || z.avail_in > 0) { + return false; + } + + uncompressed = absl::string_view( + reinterpret_cast(uncompressed_data.get()), uncompressed_size); + } + + for (size_t i = 0; i < entries.size(); i++) { + switch (entries[i].type) { + case CertEntry::COMPRESSED: + if (uncompressed.size() < sizeof(uint32_t)) { + return false; + } + uint32_t cert_len; + memcpy(&cert_len, uncompressed.data(), sizeof(cert_len)); + uncompressed.remove_prefix(sizeof(uint32_t)); + if (uncompressed.size() < cert_len) { + return false; + } + (*out_certs)[i] = std::string(uncompressed.substr(0, cert_len)); + uncompressed.remove_prefix(cert_len); + break; + case CertEntry::CACHED: + break; + } + } + + if (!uncompressed.empty()) { + return false; + } + + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/cert_compressor.h b/quiche/quic/core/crypto/cert_compressor.h new file mode 100644 index 000000000000..9509ccd0f5de --- /dev/null +++ b/quiche/quic/core/crypto/cert_compressor.h @@ -0,0 +1,45 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CERT_COMPRESSOR_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CERT_COMPRESSOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// CertCompressor provides functions for compressing and decompressing +// certificate chains using two techniquies: +// 1) The peer may provide a list of a 64-bit, FNV-1a hashes of certificates +// that they already have. In the event that one of them is to be +// compressed, it can be replaced with just the hash. +// 2) Otherwise the certificates are compressed with zlib using a pre-shared +// dictionary that consists of the certificates handled with the above +// methods and a small chunk of common substrings. +class QUIC_EXPORT_PRIVATE CertCompressor { + public: + CertCompressor() = delete; + + // CompressChain compresses the certificates in |certs| and returns a + // compressed representation. client_cached_cert_hashes| contains + // 64-bit, FNV-1a hashes of certificates that the peer already possesses. + static std::string CompressChain(const std::vector& certs, + absl::string_view client_cached_cert_hashes); + + // DecompressChain decompresses the result of |CompressChain|, given in |in|, + // into a series of certificates that are written to |out_certs|. + // |cached_certs| contains certificates that the peer may have omitted. + static bool DecompressChain(absl::string_view in, + const std::vector& cached_certs, + std::vector* out_certs); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CERT_COMPRESSOR_H_ diff --git a/quiche/quic/core/crypto/cert_compressor_test.cc b/quiche/quic/core/crypto/cert_compressor_test.cc new file mode 100644 index 000000000000..d98f4c770f54 --- /dev/null +++ b/quiche/quic/core/crypto/cert_compressor_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/cert_compressor.h" + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" + +namespace quic { +namespace test { + +class CertCompressorTest : public QuicTest {}; + +TEST_F(CertCompressorTest, EmptyChain) { + std::vector chain; + const std::string compressed = + CertCompressor::CompressChain(chain, absl::string_view()); + EXPECT_EQ("00", absl::BytesToHexString(compressed)); + + std::vector chain2, cached_certs; + ASSERT_TRUE( + CertCompressor::DecompressChain(compressed, cached_certs, &chain2)); + EXPECT_EQ(chain.size(), chain2.size()); +} + +TEST_F(CertCompressorTest, Compressed) { + std::vector chain; + chain.push_back("testcert"); + const std::string compressed = + CertCompressor::CompressChain(chain, absl::string_view()); + ASSERT_GE(compressed.size(), 2u); + EXPECT_EQ("0100", absl::BytesToHexString(compressed.substr(0, 2))); + + std::vector chain2, cached_certs; + ASSERT_TRUE( + CertCompressor::DecompressChain(compressed, cached_certs, &chain2)); + EXPECT_EQ(chain.size(), chain2.size()); + EXPECT_EQ(chain[0], chain2[0]); +} + +TEST_F(CertCompressorTest, Common) { + std::vector chain; + chain.push_back("testcert"); + static const uint64_t set_hash = 42; + const std::string compressed = CertCompressor::CompressChain( + chain, absl::string_view(reinterpret_cast(&set_hash), + sizeof(set_hash))); + ASSERT_GE(compressed.size(), 2u); + // 01 is the prefix for a zlib "compressed" cert not common or cached. + EXPECT_EQ("0100", absl::BytesToHexString(compressed.substr(0, 2))); + + std::vector chain2, cached_certs; + ASSERT_TRUE( + CertCompressor::DecompressChain(compressed, cached_certs, &chain2)); + EXPECT_EQ(chain.size(), chain2.size()); + EXPECT_EQ(chain[0], chain2[0]); +} + +TEST_F(CertCompressorTest, Cached) { + std::vector chain; + chain.push_back("testcert"); + uint64_t hash = QuicUtils::FNV1a_64_Hash(chain[0]); + absl::string_view hash_bytes(reinterpret_cast(&hash), sizeof(hash)); + const std::string compressed = + CertCompressor::CompressChain(chain, hash_bytes); + + EXPECT_EQ("02" /* cached */ + absl::BytesToHexString(hash_bytes) + + "00" /* end of list */, + absl::BytesToHexString(compressed)); + + std::vector cached_certs, chain2; + cached_certs.push_back(chain[0]); + ASSERT_TRUE( + CertCompressor::DecompressChain(compressed, cached_certs, &chain2)); + EXPECT_EQ(chain.size(), chain2.size()); + EXPECT_EQ(chain[0], chain2[0]); +} + +TEST_F(CertCompressorTest, BadInputs) { + std::vector cached_certs, chain; + + EXPECT_FALSE(CertCompressor::DecompressChain( + absl::BytesToHexString("04") /* bad entry type */, cached_certs, &chain)); + + EXPECT_FALSE(CertCompressor::DecompressChain( + absl::BytesToHexString("01") /* no terminator */, cached_certs, &chain)); + + EXPECT_FALSE(CertCompressor::DecompressChain( + absl::BytesToHexString("0200") /* hash truncated */, cached_certs, + &chain)); + + EXPECT_FALSE(CertCompressor::DecompressChain( + absl::BytesToHexString("0300") /* hash and index truncated */, + cached_certs, &chain)); + + /* without a CommonCertSets */ + EXPECT_FALSE( + CertCompressor::DecompressChain(absl::BytesToHexString("03" + "0000000000000000" + "00000000"), + cached_certs, &chain)); + + /* incorrect hash and index */ + EXPECT_FALSE( + CertCompressor::DecompressChain(absl::BytesToHexString("03" + "a200000000000000" + "00000000"), + cached_certs, &chain)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/certificate_util.cc b/quiche/quic/core/crypto/certificate_util.cc new file mode 100644 index 000000000000..1f2ce870eb7b --- /dev/null +++ b/quiche/quic/core/crypto/certificate_util.cc @@ -0,0 +1,280 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/certificate_util.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "openssl/bn.h" +#include "openssl/bytestring.h" +#include "openssl/digest.h" +#include "openssl/ec_key.h" +#include "openssl/mem.h" +#include "openssl/pkcs7.h" +#include "openssl/pool.h" +#include "openssl/rsa.h" +#include "openssl/stack.h" +#include "quiche/quic/core/crypto/boring_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace { +bool AddEcdsa256SignatureAlgorithm(CBB* cbb) { + // See RFC 5758. This is the encoding of OID 1.2.840.10045.4.3.2. + static const uint8_t kEcdsaWithSha256[] = {0x2a, 0x86, 0x48, 0xce, + 0x3d, 0x04, 0x03, 0x02}; + + // An AlgorithmIdentifier is described in RFC 5280, 4.1.1.2. + CBB sequence, oid; + if (!CBB_add_asn1(cbb, &sequence, CBS_ASN1_SEQUENCE) || + !CBB_add_asn1(&sequence, &oid, CBS_ASN1_OBJECT)) { + return false; + } + + if (!CBB_add_bytes(&oid, kEcdsaWithSha256, sizeof(kEcdsaWithSha256))) { + return false; + } + + // RFC 5758, section 3.2: ecdsa-with-sha256 MUST omit the parameters field. + return CBB_flush(cbb); +} + +// Adds an X.509 Name with the specified distinguished name to |cbb|. +bool AddName(CBB* cbb, absl::string_view name) { + // See RFC 4519. + static const uint8_t kCommonName[] = {0x55, 0x04, 0x03}; + static const uint8_t kCountryName[] = {0x55, 0x04, 0x06}; + static const uint8_t kOrganizationName[] = {0x55, 0x04, 0x0a}; + static const uint8_t kOrganizationalUnitName[] = {0x55, 0x04, 0x0b}; + + std::vector attributes = + absl::StrSplit(name, ',', absl::SkipEmpty()); + + if (attributes.empty()) { + QUIC_LOG(ERROR) << "Missing DN or wrong format"; + return false; + } + + // See RFC 5280, section 4.1.2.4. + CBB rdns; + if (!CBB_add_asn1(cbb, &rdns, CBS_ASN1_SEQUENCE)) { + return false; + } + + for (const std::string& attribute : attributes) { + std::vector parts = + absl::StrSplit(absl::StripAsciiWhitespace(attribute), '='); + if (parts.size() != 2) { + QUIC_LOG(ERROR) << "Wrong DN format at " + attribute; + return false; + } + + const std::string& type_string = parts[0]; + const std::string& value_string = parts[1]; + absl::Span type_bytes; + if (type_string == "CN") { + type_bytes = kCommonName; + } else if (type_string == "C") { + type_bytes = kCountryName; + } else if (type_string == "O") { + type_bytes = kOrganizationName; + } else if (type_string == "OU") { + type_bytes = kOrganizationalUnitName; + } else { + QUIC_LOG(ERROR) << "Unrecognized type " + type_string; + return false; + } + + CBB rdn, attr, type, value; + if (!CBB_add_asn1(&rdns, &rdn, CBS_ASN1_SET) || + !CBB_add_asn1(&rdn, &attr, CBS_ASN1_SEQUENCE) || + !CBB_add_asn1(&attr, &type, CBS_ASN1_OBJECT) || + !CBB_add_bytes(&type, type_bytes.data(), type_bytes.size()) || + !CBB_add_asn1(&attr, &value, + type_string == "C" ? CBS_ASN1_PRINTABLESTRING + : CBS_ASN1_UTF8STRING) || + !AddStringToCbb(&value, value_string) || !CBB_flush(&rdns)) { + return false; + } + } + if (!CBB_flush(cbb)) { + return false; + } + return true; +} + +bool CBBAddTime(CBB* cbb, const CertificateTimestamp& timestamp) { + CBB child; + std::string formatted_time; + + // Per RFC 5280, 4.1.2.5, times which fit in UTCTime must be encoded as + // UTCTime rather than GeneralizedTime. + const bool is_utc_time = (1950 <= timestamp.year && timestamp.year < 2050); + if (is_utc_time) { + uint16_t year = timestamp.year - 1900; + if (year >= 100) { + year -= 100; + } + formatted_time = absl::StrFormat("%02d", year); + if (!CBB_add_asn1(cbb, &child, CBS_ASN1_UTCTIME)) { + return false; + } + } else { + formatted_time = absl::StrFormat("%04d", timestamp.year); + if (!CBB_add_asn1(cbb, &child, CBS_ASN1_GENERALIZEDTIME)) { + return false; + } + } + + absl::StrAppendFormat(&formatted_time, "%02d%02d%02d%02d%02dZ", + timestamp.month, timestamp.day, timestamp.hour, + timestamp.minute, timestamp.second); + + static const size_t kGeneralizedTimeLength = 15; + static const size_t kUTCTimeLength = 13; + QUICHE_DCHECK_EQ(formatted_time.size(), + is_utc_time ? kUTCTimeLength : kGeneralizedTimeLength); + + return AddStringToCbb(&child, formatted_time) && CBB_flush(cbb); +} + +bool CBBAddExtension(CBB* extensions, absl::Span oid, + bool critical, absl::Span contents) { + CBB extension, cbb_oid, cbb_contents; + if (!CBB_add_asn1(extensions, &extension, CBS_ASN1_SEQUENCE) || + !CBB_add_asn1(&extension, &cbb_oid, CBS_ASN1_OBJECT) || + !CBB_add_bytes(&cbb_oid, oid.data(), oid.size()) || + (critical && !CBB_add_asn1_bool(&extension, 1)) || + !CBB_add_asn1(&extension, &cbb_contents, CBS_ASN1_OCTETSTRING) || + !CBB_add_bytes(&cbb_contents, contents.data(), contents.size()) || + !CBB_flush(extensions)) { + return false; + } + + return true; +} + +bool IsEcdsa256Key(const EVP_PKEY& evp_key) { + if (EVP_PKEY_id(&evp_key) != EVP_PKEY_EC) { + return false; + } + const EC_KEY* key = EVP_PKEY_get0_EC_KEY(&evp_key); + if (key == nullptr) { + return false; + } + const EC_GROUP* group = EC_KEY_get0_group(key); + if (group == nullptr) { + return false; + } + return EC_GROUP_get_curve_name(group) == NID_X9_62_prime256v1; +} + +} // namespace + +bssl::UniquePtr MakeKeyPairForSelfSignedCertificate() { + bssl::UniquePtr context( + EVP_PKEY_CTX_new_id(EVP_PKEY_EC, nullptr)); + if (!context) { + return nullptr; + } + if (EVP_PKEY_keygen_init(context.get()) != 1) { + return nullptr; + } + if (EVP_PKEY_CTX_set_ec_paramgen_curve_nid(context.get(), + NID_X9_62_prime256v1) != 1) { + return nullptr; + } + EVP_PKEY* raw_key = nullptr; + if (EVP_PKEY_keygen(context.get(), &raw_key) != 1) { + return nullptr; + } + return bssl::UniquePtr(raw_key); +} + +std::string CreateSelfSignedCertificate(EVP_PKEY& key, + const CertificateOptions& options) { + std::string error; + if (!IsEcdsa256Key(key)) { + QUIC_LOG(ERROR) << "CreateSelfSignedCert only accepts ECDSA P-256 keys"; + return error; + } + + // See RFC 5280, section 4.1. First, construct the TBSCertificate. + bssl::ScopedCBB cbb; + CBB tbs_cert, version, validity; + uint8_t* tbs_cert_bytes; + size_t tbs_cert_len; + + if (!CBB_init(cbb.get(), 64) || + !CBB_add_asn1(cbb.get(), &tbs_cert, CBS_ASN1_SEQUENCE) || + !CBB_add_asn1(&tbs_cert, &version, + CBS_ASN1_CONTEXT_SPECIFIC | CBS_ASN1_CONSTRUCTED | 0) || + !CBB_add_asn1_uint64(&version, 2) || // X.509 version 3 + !CBB_add_asn1_uint64(&tbs_cert, options.serial_number) || + !AddEcdsa256SignatureAlgorithm(&tbs_cert) || // signature algorithm + !AddName(&tbs_cert, options.subject) || // issuer + !CBB_add_asn1(&tbs_cert, &validity, CBS_ASN1_SEQUENCE) || + !CBBAddTime(&validity, options.validity_start) || + !CBBAddTime(&validity, options.validity_end) || + !AddName(&tbs_cert, options.subject) || // subject + !EVP_marshal_public_key(&tbs_cert, &key)) { // subjectPublicKeyInfo + return error; + } + + CBB outer_extensions, extensions; + if (!CBB_add_asn1(&tbs_cert, &outer_extensions, + 3 | CBS_ASN1_CONTEXT_SPECIFIC | CBS_ASN1_CONSTRUCTED) || + !CBB_add_asn1(&outer_extensions, &extensions, CBS_ASN1_SEQUENCE)) { + return error; + } + + // Key Usage + constexpr uint8_t kKeyUsageOid[] = {0x55, 0x1d, 0x0f}; + constexpr uint8_t kKeyUsageContent[] = { + 0x3, // BIT STRING + 0x2, // Length + 0x0, // Unused bits + 0x80, // bit(0): digitalSignature + }; + CBBAddExtension(&extensions, kKeyUsageOid, true, kKeyUsageContent); + + // TODO(wub): Add more extensions here if needed. + + if (!CBB_finish(cbb.get(), &tbs_cert_bytes, &tbs_cert_len)) { + return error; + } + + bssl::UniquePtr delete_tbs_cert_bytes(tbs_cert_bytes); + + // Sign the TBSCertificate and write the entire certificate. + CBB cert, signature; + bssl::ScopedEVP_MD_CTX ctx; + uint8_t* sig_out; + size_t sig_len; + uint8_t* cert_bytes; + size_t cert_len; + if (!CBB_init(cbb.get(), tbs_cert_len) || + !CBB_add_asn1(cbb.get(), &cert, CBS_ASN1_SEQUENCE) || + !CBB_add_bytes(&cert, tbs_cert_bytes, tbs_cert_len) || + !AddEcdsa256SignatureAlgorithm(&cert) || + !CBB_add_asn1(&cert, &signature, CBS_ASN1_BITSTRING) || + !CBB_add_u8(&signature, 0 /* no unused bits */) || + !EVP_DigestSignInit(ctx.get(), nullptr, EVP_sha256(), nullptr, &key) || + // Compute the maximum signature length. + !EVP_DigestSign(ctx.get(), nullptr, &sig_len, tbs_cert_bytes, + tbs_cert_len) || + !CBB_reserve(&signature, &sig_out, sig_len) || + // Actually sign the TBSCertificate. + !EVP_DigestSign(ctx.get(), sig_out, &sig_len, tbs_cert_bytes, + tbs_cert_len) || + !CBB_did_write(&signature, sig_len) || + !CBB_finish(cbb.get(), &cert_bytes, &cert_len)) { + return error; + } + bssl::UniquePtr delete_cert_bytes(cert_bytes); + return std::string(reinterpret_cast(cert_bytes), cert_len); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/certificate_util.h b/quiche/quic/core/crypto/certificate_util.h new file mode 100644 index 000000000000..35bb5611e064 --- /dev/null +++ b/quiche/quic/core/crypto/certificate_util.h @@ -0,0 +1,46 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CERTIFICATE_UTIL_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CERTIFICATE_UTIL_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/evp.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +struct QUIC_NO_EXPORT CertificateTimestamp { + uint16_t year; + uint8_t month; + uint8_t day; + uint8_t hour; + uint8_t minute; + uint8_t second; +}; + +struct QUIC_NO_EXPORT CertificateOptions { + absl::string_view subject; + uint64_t serial_number; + CertificateTimestamp validity_start; // a.k.a not_valid_before + CertificateTimestamp validity_end; // a.k.a not_valid_after +}; + +// Creates a ECDSA P-256 key pair. +QUIC_EXPORT_PRIVATE bssl::UniquePtr +MakeKeyPairForSelfSignedCertificate(); + +// Creates a self-signed, DER-encoded X.509 certificate. +// |key| must be a ECDSA P-256 key. +// This is mostly stolen from Chromium's net/cert/x509_util.h, with +// modifications to make it work in QUICHE. +QUIC_EXPORT_PRIVATE std::string CreateSelfSignedCertificate( + EVP_PKEY& key, const CertificateOptions& options); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CERTIFICATE_UTIL_H_ diff --git a/quiche/quic/core/crypto/certificate_util_test.cc b/quiche/quic/core/crypto/certificate_util_test.cc new file mode 100644 index 000000000000..4c98d7cab812 --- /dev/null +++ b/quiche/quic/core/crypto/certificate_util_test.cc @@ -0,0 +1,49 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/certificate_util.h" + +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_output.h" + +namespace quic { +namespace test { +namespace { + +TEST(CertificateUtilTest, CreateSelfSignedCertificate) { + bssl::UniquePtr key = MakeKeyPairForSelfSignedCertificate(); + ASSERT_NE(key, nullptr); + + CertificatePrivateKey cert_key(std::move(key)); + + CertificateOptions options; + options.subject = "CN=subject"; + options.serial_number = 0x12345678; + options.validity_start = {2020, 1, 1, 0, 0, 0}; + options.validity_end = {2049, 12, 31, 0, 0, 0}; + std::string der_cert = + CreateSelfSignedCertificate(*cert_key.private_key(), options); + ASSERT_FALSE(der_cert.empty()); + + QuicSaveTestOutput("CertificateUtilTest_CreateSelfSignedCert.crt", der_cert); + + std::unique_ptr cert_view = + CertificateView::ParseSingleCertificate(der_cert); + ASSERT_NE(cert_view, nullptr); + EXPECT_EQ(cert_view->public_key_type(), PublicKeyType::kP256); + + absl::optional subject = cert_view->GetHumanReadableSubject(); + ASSERT_TRUE(subject.has_value()); + EXPECT_EQ(*subject, options.subject); + + EXPECT_TRUE( + cert_key.ValidForSignatureAlgorithm(SSL_SIGN_ECDSA_SECP256R1_SHA256)); + EXPECT_TRUE(cert_key.MatchesPublicKey(*cert_view)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/certificate_view.cc b/quiche/quic/core/crypto/certificate_view.cc new file mode 100644 index 000000000000..c3b187ce7665 --- /dev/null +++ b/quiche/quic/core/crypto/certificate_view.cc @@ -0,0 +1,664 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/certificate_view.h" + +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/base.h" +#include "openssl/bytestring.h" +#include "openssl/digest.h" +#include "openssl/ec.h" +#include "openssl/ec_key.h" +#include "openssl/evp.h" +#include "openssl/nid.h" +#include "openssl/rsa.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/boring_utils.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_time_utils.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { + +using ::quiche::QuicheTextUtils; + +// The literals below were encoded using `ascii2der | xxd -i`. The comments +// above the literals are the contents in the der2ascii syntax. + +// X.509 version 3 (version numbering starts with zero). +// INTEGER { 2 } +constexpr uint8_t kX509Version[] = {0x02, 0x01, 0x02}; + +// 2.5.29.17 +constexpr uint8_t kSubjectAltNameOid[] = {0x55, 0x1d, 0x11}; + +PublicKeyType PublicKeyTypeFromKey(EVP_PKEY* public_key) { + switch (EVP_PKEY_id(public_key)) { + case EVP_PKEY_RSA: + return PublicKeyType::kRsa; + case EVP_PKEY_EC: { + const EC_KEY* key = EVP_PKEY_get0_EC_KEY(public_key); + if (key == nullptr) { + return PublicKeyType::kUnknown; + } + const EC_GROUP* group = EC_KEY_get0_group(key); + if (group == nullptr) { + return PublicKeyType::kUnknown; + } + const int curve_nid = EC_GROUP_get_curve_name(group); + switch (curve_nid) { + case NID_X9_62_prime256v1: + return PublicKeyType::kP256; + case NID_secp384r1: + return PublicKeyType::kP384; + default: + return PublicKeyType::kUnknown; + } + } + case EVP_PKEY_ED25519: + return PublicKeyType::kEd25519; + default: + return PublicKeyType::kUnknown; + } +} + +} // namespace + +PublicKeyType PublicKeyTypeFromSignatureAlgorithm( + uint16_t signature_algorithm) { + // This should be kept in sync with the list in + // SupportedSignatureAlgorithmsForQuic(). + switch (signature_algorithm) { + case SSL_SIGN_RSA_PSS_RSAE_SHA256: + return PublicKeyType::kRsa; + case SSL_SIGN_ECDSA_SECP256R1_SHA256: + return PublicKeyType::kP256; + case SSL_SIGN_ECDSA_SECP384R1_SHA384: + return PublicKeyType::kP384; + case SSL_SIGN_ED25519: + return PublicKeyType::kEd25519; + default: + return PublicKeyType::kUnknown; + } +} + +QUIC_EXPORT_PRIVATE QuicSignatureAlgorithmVector +SupportedSignatureAlgorithmsForQuic() { + // This should be kept in sync with the list in + // PublicKeyTypeFromSignatureAlgorithm(). + return QuicSignatureAlgorithmVector{ + SSL_SIGN_ED25519, SSL_SIGN_ECDSA_SECP256R1_SHA256, + SSL_SIGN_ECDSA_SECP384R1_SHA384, SSL_SIGN_RSA_PSS_RSAE_SHA256}; +} + +namespace { + +std::string AttributeNameToString(const CBS& oid_cbs) { + absl::string_view oid = CbsToStringPiece(oid_cbs); + + // We only handle OIDs of form 2.5.4.N, which have binary encoding of + // "55 04 0N". + if (oid.length() == 3 && absl::StartsWith(oid, "\x55\x04")) { + // clang-format off + switch (oid[2]) { + case '\x3': return "CN"; + case '\x7': return "L"; + case '\x8': return "ST"; + case '\xa': return "O"; + case '\xb': return "OU"; + case '\x6': return "C"; + } + // clang-format on + } + + bssl::UniquePtr oid_representation(CBS_asn1_oid_to_text(&oid_cbs)); + if (oid_representation == nullptr) { + return absl::StrCat("(", absl::BytesToHexString(oid), ")"); + } + return std::string(oid_representation.get()); +} + +} // namespace + +absl::optional X509NameAttributeToString(CBS input) { + CBS name, value; + unsigned value_tag; + if (!CBS_get_asn1(&input, &name, CBS_ASN1_OBJECT) || + !CBS_get_any_asn1(&input, &value, &value_tag) || CBS_len(&input) != 0) { + return absl::nullopt; + } + // Note that this does not process encoding of |input| in any way. This works + // fine for the most cases. + return absl::StrCat(AttributeNameToString(name), "=", + absl::CHexEscape(CbsToStringPiece(value))); +} + +namespace { + +template (*parser)(CBS)> +absl::optional ParseAndJoin(CBS input) { + std::vector pieces; + while (CBS_len(&input) != 0) { + CBS attribute; + if (!CBS_get_asn1(&input, &attribute, inner_tag)) { + return absl::nullopt; + } + absl::optional formatted = parser(attribute); + if (!formatted.has_value()) { + return absl::nullopt; + } + pieces.push_back(*formatted); + } + + return absl::StrJoin(pieces, std::string({separator})); +} + +absl::optional RelativeDistinguishedNameToString(CBS input) { + return ParseAndJoin(input); +} + +absl::optional DistinguishedNameToString(CBS input) { + return ParseAndJoin( + input); +} + +} // namespace + +std::string PublicKeyTypeToString(PublicKeyType type) { + switch (type) { + case PublicKeyType::kRsa: + return "RSA"; + case PublicKeyType::kP256: + return "ECDSA P-256"; + case PublicKeyType::kP384: + return "ECDSA P-384"; + case PublicKeyType::kEd25519: + return "Ed25519"; + case PublicKeyType::kUnknown: + return "unknown"; + } + return ""; +} + +absl::optional ParseDerTime(unsigned tag, + absl::string_view payload) { + if (tag != CBS_ASN1_GENERALIZEDTIME && tag != CBS_ASN1_UTCTIME) { + QUIC_DLOG(WARNING) << "Invalid tag supplied for a DER timestamp"; + return absl::nullopt; + } + + const size_t year_length = tag == CBS_ASN1_GENERALIZEDTIME ? 4 : 2; + uint64_t year, month, day, hour, minute, second; + quiche::QuicheDataReader reader(payload); + if (!reader.ReadDecimal64(year_length, &year) || + !reader.ReadDecimal64(2, &month) || !reader.ReadDecimal64(2, &day) || + !reader.ReadDecimal64(2, &hour) || !reader.ReadDecimal64(2, &minute) || + !reader.ReadDecimal64(2, &second) || + reader.ReadRemainingPayload() != "Z") { + QUIC_DLOG(WARNING) << "Failed to parse the DER timestamp"; + return absl::nullopt; + } + + if (tag == CBS_ASN1_UTCTIME) { + QUICHE_DCHECK_LE(year, 100u); + year += (year >= 50) ? 1900 : 2000; + } + + const absl::optional unix_time = + quiche::QuicheUtcDateTimeToUnixSeconds(year, month, day, hour, minute, + second); + if (!unix_time.has_value() || *unix_time < 0) { + return absl::nullopt; + } + return QuicWallTime::FromUNIXSeconds(*unix_time); +} + +PemReadResult ReadNextPemMessage(std::istream* input) { + constexpr absl::string_view kPemBegin = "-----BEGIN "; + constexpr absl::string_view kPemEnd = "-----END "; + constexpr absl::string_view kPemDashes = "-----"; + + std::string line_buffer, encoded_message_contents, expected_end; + bool pending_message = false; + PemReadResult result; + while (std::getline(*input, line_buffer)) { + absl::string_view line(line_buffer); + QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&line); + + // Handle BEGIN lines. + if (!pending_message && absl::StartsWith(line, kPemBegin) && + absl::EndsWith(line, kPemDashes)) { + result.type = std::string( + line.substr(kPemBegin.size(), + line.size() - kPemDashes.size() - kPemBegin.size())); + expected_end = absl::StrCat(kPemEnd, result.type, kPemDashes); + pending_message = true; + continue; + } + + // Handle END lines. + if (pending_message && line == expected_end) { + absl::optional data = + QuicheTextUtils::Base64Decode(encoded_message_contents); + if (data.has_value()) { + result.status = PemReadResult::kOk; + result.contents = data.value(); + } else { + result.status = PemReadResult::kError; + } + return result; + } + + if (pending_message) { + encoded_message_contents.append(std::string(line)); + } + } + bool eof_reached = input->eof() && !pending_message; + return PemReadResult{ + (eof_reached ? PemReadResult::kEof : PemReadResult::kError), "", ""}; +} + +std::unique_ptr CertificateView::ParseSingleCertificate( + absl::string_view certificate) { + std::unique_ptr result(new CertificateView()); + CBS top = StringPieceToCbs(certificate); + + CBS top_certificate, tbs_certificate, signature_algorithm, signature; + if (!CBS_get_asn1(&top, &top_certificate, CBS_ASN1_SEQUENCE) || + CBS_len(&top) != 0) { + return nullptr; + } + + // Certificate ::= SEQUENCE { + if ( + // tbsCertificate TBSCertificate, + !CBS_get_asn1(&top_certificate, &tbs_certificate, CBS_ASN1_SEQUENCE) || + + // signatureAlgorithm AlgorithmIdentifier, + !CBS_get_asn1(&top_certificate, &signature_algorithm, + CBS_ASN1_SEQUENCE) || + + // signature BIT STRING } + !CBS_get_asn1(&top_certificate, &signature, CBS_ASN1_BITSTRING) || + CBS_len(&top_certificate) != 0) { + return nullptr; + } + + int has_version, has_extensions; + CBS version, serial, signature_algorithm_inner, issuer, validity, subject, + spki, issuer_id, subject_id, extensions_outer; + // TBSCertificate ::= SEQUENCE { + if ( + // version [0] Version DEFAULT v1, + !CBS_get_optional_asn1( + &tbs_certificate, &version, &has_version, + CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 0) || + + // serialNumber CertificateSerialNumber, + !CBS_get_asn1(&tbs_certificate, &serial, CBS_ASN1_INTEGER) || + + // signature AlgorithmIdentifier, + !CBS_get_asn1(&tbs_certificate, &signature_algorithm_inner, + CBS_ASN1_SEQUENCE) || + + // issuer Name, + !CBS_get_asn1(&tbs_certificate, &issuer, CBS_ASN1_SEQUENCE) || + + // validity Validity, + !CBS_get_asn1(&tbs_certificate, &validity, CBS_ASN1_SEQUENCE) || + + // subject Name, + !CBS_get_asn1(&tbs_certificate, &subject, CBS_ASN1_SEQUENCE) || + + // subjectPublicKeyInfo SubjectPublicKeyInfo, + !CBS_get_asn1_element(&tbs_certificate, &spki, CBS_ASN1_SEQUENCE) || + + // issuerUniqueID [1] IMPLICIT UniqueIdentifier OPTIONAL, + // -- If present, version MUST be v2 or v3 + !CBS_get_optional_asn1(&tbs_certificate, &issuer_id, nullptr, + CBS_ASN1_CONTEXT_SPECIFIC | 1) || + + // subjectUniqueID [2] IMPLICIT UniqueIdentifier OPTIONAL, + // -- If present, version MUST be v2 or v3 + !CBS_get_optional_asn1(&tbs_certificate, &subject_id, nullptr, + CBS_ASN1_CONTEXT_SPECIFIC | 2) || + + // extensions [3] Extensions OPTIONAL + // -- If present, version MUST be v3 -- } + !CBS_get_optional_asn1( + &tbs_certificate, &extensions_outer, &has_extensions, + CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 3) || + + CBS_len(&tbs_certificate) != 0) { + return nullptr; + } + + result->subject_der_ = CbsToStringPiece(subject); + + unsigned not_before_tag, not_after_tag; + CBS not_before, not_after; + if (!CBS_get_any_asn1(&validity, ¬_before, ¬_before_tag) || + !CBS_get_any_asn1(&validity, ¬_after, ¬_after_tag) || + CBS_len(&validity) != 0) { + QUIC_DLOG(WARNING) << "Failed to extract the validity dates"; + return nullptr; + } + absl::optional not_before_parsed = + ParseDerTime(not_before_tag, CbsToStringPiece(not_before)); + absl::optional not_after_parsed = + ParseDerTime(not_after_tag, CbsToStringPiece(not_after)); + if (!not_before_parsed.has_value() || !not_after_parsed.has_value()) { + QUIC_DLOG(WARNING) << "Failed to parse validity dates"; + return nullptr; + } + result->validity_start_ = *not_before_parsed; + result->validity_end_ = *not_after_parsed; + + result->public_key_.reset(EVP_parse_public_key(&spki)); + if (result->public_key_ == nullptr) { + QUIC_DLOG(WARNING) << "Failed to parse the public key"; + return nullptr; + } + if (!result->ValidatePublicKeyParameters()) { + QUIC_DLOG(WARNING) << "Public key has invalid parameters"; + return nullptr; + } + + // Only support X.509v3. + if (!has_version || + !CBS_mem_equal(&version, kX509Version, sizeof(kX509Version))) { + QUIC_DLOG(WARNING) << "Bad X.509 version"; + return nullptr; + } + + if (!has_extensions) { + return nullptr; + } + + CBS extensions; + if (!CBS_get_asn1(&extensions_outer, &extensions, CBS_ASN1_SEQUENCE) || + CBS_len(&extensions_outer) != 0) { + QUIC_DLOG(WARNING) << "Failed to extract the extension sequence"; + return nullptr; + } + if (!result->ParseExtensions(extensions)) { + QUIC_DLOG(WARNING) << "Failed to parse extensions"; + return nullptr; + } + + return result; +} + +bool CertificateView::ParseExtensions(CBS extensions) { + while (CBS_len(&extensions) != 0) { + CBS extension, oid, critical, payload; + if ( + // Extension ::= SEQUENCE { + !CBS_get_asn1(&extensions, &extension, CBS_ASN1_SEQUENCE) || + // extnID OBJECT IDENTIFIER, + !CBS_get_asn1(&extension, &oid, CBS_ASN1_OBJECT) || + // critical BOOLEAN DEFAULT FALSE, + !CBS_get_optional_asn1(&extension, &critical, nullptr, + CBS_ASN1_BOOLEAN) || + // extnValue OCTET STRING + // -- contains the DER encoding of an ASN.1 value + // -- corresponding to the extension type identified + // -- by extnID + !CBS_get_asn1(&extension, &payload, CBS_ASN1_OCTETSTRING) || + CBS_len(&extension) != 0) { + QUIC_DLOG(WARNING) << "Bad extension entry"; + return false; + } + + if (CBS_mem_equal(&oid, kSubjectAltNameOid, sizeof(kSubjectAltNameOid))) { + CBS alt_names; + if (!CBS_get_asn1(&payload, &alt_names, CBS_ASN1_SEQUENCE) || + CBS_len(&payload) != 0) { + QUIC_DLOG(WARNING) << "Failed to parse subjectAltName"; + return false; + } + while (CBS_len(&alt_names) != 0) { + CBS alt_name_cbs; + unsigned int alt_name_tag; + if (!CBS_get_any_asn1(&alt_names, &alt_name_cbs, &alt_name_tag)) { + QUIC_DLOG(WARNING) << "Failed to parse subjectAltName"; + return false; + } + + absl::string_view alt_name = CbsToStringPiece(alt_name_cbs); + QuicIpAddress ip_address; + // GeneralName ::= CHOICE { + switch (alt_name_tag) { + // dNSName [2] IA5String, + case CBS_ASN1_CONTEXT_SPECIFIC | 2: + subject_alt_name_domains_.push_back(alt_name); + break; + + // iPAddress [7] OCTET STRING, + case CBS_ASN1_CONTEXT_SPECIFIC | 7: + if (!ip_address.FromPackedString(alt_name.data(), + alt_name.size())) { + QUIC_DLOG(WARNING) << "Failed to parse subjectAltName IP address"; + return false; + } + subject_alt_name_ips_.push_back(ip_address); + break; + + default: + QUIC_DLOG(INFO) << "Unknown subjectAltName tag " << alt_name_tag; + continue; + } + } + } + } + + return true; +} + +std::vector CertificateView::LoadPemFromStream( + std::istream* input) { + std::vector result; + for (;;) { + PemReadResult read_result = ReadNextPemMessage(input); + if (read_result.status == PemReadResult::kEof) { + return result; + } + if (read_result.status != PemReadResult::kOk) { + return std::vector(); + } + if (read_result.type != "CERTIFICATE") { + continue; + } + result.emplace_back(std::move(read_result.contents)); + } +} + +PublicKeyType CertificateView::public_key_type() const { + return PublicKeyTypeFromKey(public_key_.get()); +} + +bool CertificateView::ValidatePublicKeyParameters() { + // The profile here affects what certificates can be used when QUIC is used as + // a server library without any custom certificate provider logic. + // The goal is to allow at minimum any certificate that would be allowed on a + // regular Web session over TLS 1.3 while ensuring we do not expose any + // algorithms we don't want to support long-term. + PublicKeyType key_type = PublicKeyTypeFromKey(public_key_.get()); + switch (key_type) { + case PublicKeyType::kRsa: + return EVP_PKEY_bits(public_key_.get()) >= 2048; + case PublicKeyType::kP256: + case PublicKeyType::kP384: + case PublicKeyType::kEd25519: + return true; + default: + return false; + } +} + +bool CertificateView::VerifySignature(absl::string_view data, + absl::string_view signature, + uint16_t signature_algorithm) const { + if (PublicKeyTypeFromSignatureAlgorithm(signature_algorithm) != + PublicKeyTypeFromKey(public_key_.get())) { + QUIC_BUG(quic_bug_10640_1) + << "Mismatch between the requested signature algorithm and the " + "type of the public key."; + return false; + } + + bssl::ScopedEVP_MD_CTX md_ctx; + EVP_PKEY_CTX* pctx; + if (!EVP_DigestVerifyInit( + md_ctx.get(), &pctx, + SSL_get_signature_algorithm_digest(signature_algorithm), nullptr, + public_key_.get())) { + return false; + } + if (SSL_is_signature_algorithm_rsa_pss(signature_algorithm)) { + if (!EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) || + !EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1)) { + return false; + } + } + return EVP_DigestVerify( + md_ctx.get(), reinterpret_cast(signature.data()), + signature.size(), reinterpret_cast(data.data()), + data.size()); +} + +absl::optional CertificateView::GetHumanReadableSubject() const { + CBS input = StringPieceToCbs(subject_der_); + return DistinguishedNameToString(input); +} + +std::unique_ptr CertificatePrivateKey::LoadFromDer( + absl::string_view private_key) { + std::unique_ptr result(new CertificatePrivateKey()); + CBS private_key_cbs = StringPieceToCbs(private_key); + result->private_key_.reset(EVP_parse_private_key(&private_key_cbs)); + if (result->private_key_ == nullptr || CBS_len(&private_key_cbs) != 0) { + return nullptr; + } + return result; +} + +std::unique_ptr CertificatePrivateKey::LoadPemFromStream( + std::istream* input) { +skip: + PemReadResult result = ReadNextPemMessage(input); + if (result.status != PemReadResult::kOk) { + return nullptr; + } + // RFC 5958 OneAsymmetricKey message. + if (result.type == "PRIVATE KEY") { + return LoadFromDer(result.contents); + } + // Legacy OpenSSL format: PKCS#1 (RFC 8017) RSAPrivateKey message. + if (result.type == "RSA PRIVATE KEY") { + CBS private_key_cbs = StringPieceToCbs(result.contents); + bssl::UniquePtr rsa(RSA_parse_private_key(&private_key_cbs)); + if (rsa == nullptr || CBS_len(&private_key_cbs) != 0) { + return nullptr; + } + + std::unique_ptr key(new CertificatePrivateKey()); + key->private_key_.reset(EVP_PKEY_new()); + EVP_PKEY_assign_RSA(key->private_key_.get(), rsa.release()); + return key; + } + // EC keys are sometimes generated with "openssl ecparam -genkey". If the user + // forgets -noout, OpenSSL will output a redundant copy of the EC parameters. + // Skip those. + if (result.type == "EC PARAMETERS") { + goto skip; + } + // Legacy OpenSSL format: RFC 5915 ECPrivateKey message. + if (result.type == "EC PRIVATE KEY") { + CBS private_key_cbs = StringPieceToCbs(result.contents); + bssl::UniquePtr ec_key( + EC_KEY_parse_private_key(&private_key_cbs, /*group=*/nullptr)); + if (ec_key == nullptr || CBS_len(&private_key_cbs) != 0) { + return nullptr; + } + + std::unique_ptr key(new CertificatePrivateKey()); + key->private_key_.reset(EVP_PKEY_new()); + EVP_PKEY_assign_EC_KEY(key->private_key_.get(), ec_key.release()); + return key; + } + // Unknown format. + return nullptr; +} + +std::string CertificatePrivateKey::Sign(absl::string_view input, + uint16_t signature_algorithm) const { + if (!ValidForSignatureAlgorithm(signature_algorithm)) { + QUIC_BUG(quic_bug_10640_2) + << "Mismatch between the requested signature algorithm and the " + "type of the private key."; + return ""; + } + + bssl::ScopedEVP_MD_CTX md_ctx; + EVP_PKEY_CTX* pctx; + if (!EVP_DigestSignInit( + md_ctx.get(), &pctx, + SSL_get_signature_algorithm_digest(signature_algorithm), + /*e=*/nullptr, private_key_.get())) { + return ""; + } + if (SSL_is_signature_algorithm_rsa_pss(signature_algorithm)) { + if (!EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) || + !EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1)) { + return ""; + } + } + + std::string output; + size_t output_size; + if (!EVP_DigestSign(md_ctx.get(), /*out_sig=*/nullptr, &output_size, + reinterpret_cast(input.data()), + input.size())) { + return ""; + } + output.resize(output_size); + if (!EVP_DigestSign( + md_ctx.get(), reinterpret_cast(&output[0]), &output_size, + reinterpret_cast(input.data()), input.size())) { + return ""; + } + output.resize(output_size); + return output; +} + +bool CertificatePrivateKey::MatchesPublicKey( + const CertificateView& view) const { + return EVP_PKEY_cmp(view.public_key(), private_key_.get()) == 1; +} + +bool CertificatePrivateKey::ValidForSignatureAlgorithm( + uint16_t signature_algorithm) const { + return PublicKeyTypeFromSignatureAlgorithm(signature_algorithm) == + PublicKeyTypeFromKey(private_key_.get()); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/certificate_view.h b/quiche/quic/core/crypto/certificate_view.h new file mode 100644 index 000000000000..5c2aafc1af32 --- /dev/null +++ b/quiche/quic/core/crypto/certificate_view.h @@ -0,0 +1,156 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CERTIFICATE_VIEW_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CERTIFICATE_VIEW_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/base.h" +#include "openssl/bytestring.h" +#include "openssl/evp.h" +#include "quiche/quic/core/crypto/boring_utils.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_ip_address.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE PemReadResult { + enum Status { kOk, kEof, kError }; + Status status; + std::string contents; + // The type of the PEM message (e.g., if the message starts with + // "-----BEGIN CERTIFICATE-----", the |type| would be "CERTIFICATE"). + std::string type; +}; + +// Reads |input| line-by-line and returns the next available PEM message. +QUIC_EXPORT_PRIVATE PemReadResult ReadNextPemMessage(std::istream* input); + +// Cryptograhpic algorithms recognized in X.509. +enum class PublicKeyType { + kRsa, + kP256, + kP384, + kEd25519, + kUnknown, +}; +QUIC_EXPORT_PRIVATE std::string PublicKeyTypeToString(PublicKeyType type); +QUIC_EXPORT_PRIVATE PublicKeyType +PublicKeyTypeFromSignatureAlgorithm(uint16_t signature_algorithm); + +// Returns the list of the signature algorithms that can be processed by +// CertificateView::VerifySignature() and CertificatePrivateKey::Sign(). +QUIC_EXPORT_PRIVATE QuicSignatureAlgorithmVector +SupportedSignatureAlgorithmsForQuic(); + +// CertificateView represents a parsed version of a single X.509 certificate. As +// the word "view" implies, it does not take ownership of the underlying strings +// and consists primarily of pointers into the certificate that is passed into +// the parser. +class QUIC_EXPORT_PRIVATE CertificateView { + public: + // Parses a single DER-encoded X.509 certificate. Returns nullptr on parse + // error. + static std::unique_ptr ParseSingleCertificate( + absl::string_view certificate); + + // Loads all PEM-encoded X.509 certificates found in the |input| stream + // without parsing them. Returns an empty vector if any parsing error occurs. + static std::vector LoadPemFromStream(std::istream* input); + + QuicWallTime validity_start() const { return validity_start_; } + QuicWallTime validity_end() const { return validity_end_; } + const EVP_PKEY* public_key() const { return public_key_.get(); } + + const std::vector& subject_alt_name_domains() const { + return subject_alt_name_domains_; + } + const std::vector& subject_alt_name_ips() const { + return subject_alt_name_ips_; + } + + // Returns a human-readable representation of the Subject field. The format + // is similar to RFC 2253, but does not match it exactly. + absl::optional GetHumanReadableSubject() const; + + // |signature_algorithm| is a TLS signature algorithm ID. + bool VerifySignature(absl::string_view data, absl::string_view signature, + uint16_t signature_algorithm) const; + + // Returns the type of the key used in the certificate's SPKI. + PublicKeyType public_key_type() const; + + private: + CertificateView() = default; + + QuicWallTime validity_start_ = QuicWallTime::Zero(); + QuicWallTime validity_end_ = QuicWallTime::Zero(); + absl::string_view subject_der_; + + // Public key parsed from SPKI. + bssl::UniquePtr public_key_; + + // SubjectAltName, https://tools.ietf.org/html/rfc5280#section-4.2.1.6 + std::vector subject_alt_name_domains_; + std::vector subject_alt_name_ips_; + + // Called from ParseSingleCertificate(). + bool ParseExtensions(CBS extensions); + bool ValidatePublicKeyParameters(); +}; + +// CertificatePrivateKey represents a private key that can be used with an X.509 +// certificate. +class QUIC_EXPORT_PRIVATE CertificatePrivateKey { + public: + explicit CertificatePrivateKey(bssl::UniquePtr private_key) + : private_key_(std::move(private_key)) {} + + // Loads a DER-encoded PrivateKeyInfo structure (RFC 5958) as a private key. + static std::unique_ptr LoadFromDer( + absl::string_view private_key); + + // Loads a private key from a PEM file formatted according to RFC 7468. Also + // supports legacy OpenSSL RSA key format ("BEGIN RSA PRIVATE KEY"). + static std::unique_ptr LoadPemFromStream( + std::istream* input); + + // |signature_algorithm| is a TLS signature algorithm ID. + std::string Sign(absl::string_view input, uint16_t signature_algorithm) const; + + // Verifies that the private key in question matches the public key of the + // certificate |view|. + bool MatchesPublicKey(const CertificateView& view) const; + + // Verifies that the private key can be used with the specified TLS signature + // algorithm. + bool ValidForSignatureAlgorithm(uint16_t signature_algorithm) const; + + EVP_PKEY* private_key() const { return private_key_.get(); } + + private: + CertificatePrivateKey() = default; + + bssl::UniquePtr private_key_; +}; + +// Parses a DER-encoded X.509 NameAttribute. Exposed primarily for testing. +QUIC_EXPORT_PRIVATE absl::optional X509NameAttributeToString( + CBS input); + +// Parses a DER time based on the specified ASN.1 tag. Exposed primarily for +// testing. +QUIC_EXPORT_PRIVATE absl::optional ParseDerTime( + unsigned tag, absl::string_view payload); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CERTIFICATE_VIEW_H_ diff --git a/quiche/quic/core/crypto/certificate_view_der_fuzzer.cc b/quiche/quic/core/crypto/certificate_view_der_fuzzer.cc new file mode 100644 index 000000000000..81c91eb943b8 --- /dev/null +++ b/quiche/quic/core/crypto/certificate_view_der_fuzzer.cc @@ -0,0 +1,19 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include "quiche/quic/core/crypto/certificate_view.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + std::string input(reinterpret_cast(data), size); + + std::unique_ptr view = + quic::CertificateView::ParseSingleCertificate(input); + if (view != nullptr) { + view->GetHumanReadableSubject(); + } + quic::CertificatePrivateKey::LoadFromDer(input); + return 0; +} diff --git a/quiche/quic/core/crypto/certificate_view_pem_fuzzer.cc b/quiche/quic/core/crypto/certificate_view_pem_fuzzer.cc new file mode 100644 index 000000000000..e6d70e51218b --- /dev/null +++ b/quiche/quic/core/crypto/certificate_view_pem_fuzzer.cc @@ -0,0 +1,18 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include + +#include "quiche/quic/core/crypto/certificate_view.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + std::string input(reinterpret_cast(data), size); + std::stringstream stream(input); + + quic::CertificateView::LoadPemFromStream(&stream); + stream.seekg(0); + quic::CertificatePrivateKey::LoadPemFromStream(&stream); + return 0; +} diff --git a/quiche/quic/core/crypto/certificate_view_test.cc b/quiche/quic/core/crypto/certificate_view_test.cc new file mode 100644 index 000000000000..d142ae45215d --- /dev/null +++ b/quiche/quic/core/crypto/certificate_view_test.cc @@ -0,0 +1,230 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/certificate_view.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "openssl/bytestring.h" +#include "openssl/evp.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/boring_utils.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/test_certificates.h" +#include "quiche/common/platform/api/quiche_time_utils.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Optional; + +TEST(CertificateViewTest, PemParser) { + std::stringstream stream(kTestCertificatePem); + PemReadResult result = ReadNextPemMessage(&stream); + EXPECT_EQ(result.status, PemReadResult::kOk); + EXPECT_EQ(result.type, "CERTIFICATE"); + EXPECT_EQ(result.contents, kTestCertificate); + + result = ReadNextPemMessage(&stream); + EXPECT_EQ(result.status, PemReadResult::kEof); +} + +TEST(CertificateViewTest, Parse) { + std::unique_ptr view = + CertificateView::ParseSingleCertificate(kTestCertificate); + ASSERT_TRUE(view != nullptr); + + EXPECT_THAT(view->subject_alt_name_domains(), + ElementsAre(absl::string_view("www.example.org"), + absl::string_view("mail.example.org"), + absl::string_view("mail.example.com"))); + EXPECT_THAT(view->subject_alt_name_ips(), + ElementsAre(QuicIpAddress::Loopback4())); + EXPECT_EQ(EVP_PKEY_id(view->public_key()), EVP_PKEY_RSA); + + const QuicWallTime validity_start = QuicWallTime::FromUNIXSeconds( + *quiche::QuicheUtcDateTimeToUnixSeconds(2020, 1, 30, 18, 13, 59)); + EXPECT_EQ(view->validity_start(), validity_start); + const QuicWallTime validity_end = QuicWallTime::FromUNIXSeconds( + *quiche::QuicheUtcDateTimeToUnixSeconds(2020, 2, 2, 18, 13, 59)); + EXPECT_EQ(view->validity_end(), validity_end); + EXPECT_EQ(view->public_key_type(), PublicKeyType::kRsa); + EXPECT_EQ(PublicKeyTypeToString(view->public_key_type()), "RSA"); + + EXPECT_EQ("C=US,ST=California,L=Mountain View,O=QUIC Server,CN=127.0.0.1", + view->GetHumanReadableSubject()); +} + +TEST(CertificateViewTest, ParseCertWithUnknownSanType) { + std::stringstream stream(kTestCertWithUnknownSanTypePem); + PemReadResult result = ReadNextPemMessage(&stream); + EXPECT_EQ(result.status, PemReadResult::kOk); + EXPECT_EQ(result.type, "CERTIFICATE"); + + std::unique_ptr view = + CertificateView::ParseSingleCertificate(result.contents); + EXPECT_TRUE(view != nullptr); +} + +TEST(CertificateViewTest, PemSingleCertificate) { + std::stringstream pem_stream(kTestCertificatePem); + std::vector chain = + CertificateView::LoadPemFromStream(&pem_stream); + EXPECT_THAT(chain, ElementsAre(kTestCertificate)); +} + +TEST(CertificateViewTest, PemMultipleCertificates) { + std::stringstream pem_stream(kTestCertificateChainPem); + std::vector chain = + CertificateView::LoadPemFromStream(&pem_stream); + EXPECT_THAT(chain, + ElementsAre(kTestCertificate, HasSubstr("QUIC Server Root CA"))); +} + +TEST(CertificateViewTest, PemNoCertificates) { + std::stringstream pem_stream("one\ntwo\nthree\n"); + std::vector chain = + CertificateView::LoadPemFromStream(&pem_stream); + EXPECT_TRUE(chain.empty()); +} + +TEST(CertificateViewTest, SignAndVerify) { + std::unique_ptr key = + CertificatePrivateKey::LoadFromDer(kTestCertificatePrivateKey); + ASSERT_TRUE(key != nullptr); + + std::string data = "A really important message"; + std::string signature = key->Sign(data, SSL_SIGN_RSA_PSS_RSAE_SHA256); + ASSERT_FALSE(signature.empty()); + + std::unique_ptr view = + CertificateView::ParseSingleCertificate(kTestCertificate); + ASSERT_TRUE(view != nullptr); + EXPECT_TRUE(key->MatchesPublicKey(*view)); + + EXPECT_TRUE( + view->VerifySignature(data, signature, SSL_SIGN_RSA_PSS_RSAE_SHA256)); + EXPECT_FALSE(view->VerifySignature("An unimportant message", signature, + SSL_SIGN_RSA_PSS_RSAE_SHA256)); + EXPECT_FALSE(view->VerifySignature(data, "Not a signature", + SSL_SIGN_RSA_PSS_RSAE_SHA256)); +} + +TEST(CertificateViewTest, PrivateKeyPem) { + std::unique_ptr view = + CertificateView::ParseSingleCertificate(kTestCertificate); + ASSERT_TRUE(view != nullptr); + + std::stringstream pem_stream(kTestCertificatePrivateKeyPem); + std::unique_ptr pem_key = + CertificatePrivateKey::LoadPemFromStream(&pem_stream); + ASSERT_TRUE(pem_key != nullptr); + EXPECT_TRUE(pem_key->MatchesPublicKey(*view)); + + std::stringstream legacy_stream(kTestCertificatePrivateKeyLegacyPem); + std::unique_ptr legacy_key = + CertificatePrivateKey::LoadPemFromStream(&legacy_stream); + ASSERT_TRUE(legacy_key != nullptr); + EXPECT_TRUE(legacy_key->MatchesPublicKey(*view)); +} + +TEST(CertificateViewTest, PrivateKeyEcdsaPem) { + std::stringstream pem_stream(kTestEcPrivateKeyLegacyPem); + std::unique_ptr key = + CertificatePrivateKey::LoadPemFromStream(&pem_stream); + ASSERT_TRUE(key != nullptr); + EXPECT_TRUE(key->ValidForSignatureAlgorithm(SSL_SIGN_ECDSA_SECP256R1_SHA256)); +} + +TEST(CertificateViewTest, DerTime) { + EXPECT_THAT(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024Z"), + Optional(QuicWallTime::FromUNIXSeconds(24))); + EXPECT_THAT(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19710101000024Z"), + Optional(QuicWallTime::FromUNIXSeconds(365 * 86400 + 24))); + EXPECT_THAT(ParseDerTime(CBS_ASN1_UTCTIME, "700101000024Z"), + Optional(QuicWallTime::FromUNIXSeconds(24))); + EXPECT_TRUE(ParseDerTime(CBS_ASN1_UTCTIME, "200101000024Z").has_value()); + + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, ""), absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024.001Z"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024Q"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024-0500"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "700101000024ZZ"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024.00Z"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024.Z"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "197O0101000024Z"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101000024.0O1Z"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "-9700101000024Z"), + absl::nullopt); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "1970-101000024Z"), + absl::nullopt); + + EXPECT_TRUE(ParseDerTime(CBS_ASN1_UTCTIME, "490101000024Z").has_value()); + // This should parse as 1950, which predates UNIX epoch. + EXPECT_FALSE(ParseDerTime(CBS_ASN1_UTCTIME, "500101000024Z").has_value()); + + EXPECT_THAT(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101230000Z"), + Optional(QuicWallTime::FromUNIXSeconds(23 * 3600))); + EXPECT_EQ(ParseDerTime(CBS_ASN1_GENERALIZEDTIME, "19700101240000Z"), + absl::nullopt); +} + +TEST(CertificateViewTest, NameAttribute) { + // OBJECT_IDENTIFIER { 1.2.840.113554.4.1.112411 } + // UTF8String { "Test" } + std::string unknown_oid = + absl::HexStringToBytes("060b2a864886f712040186ee1b0c0454657374"); + EXPECT_EQ("1.2.840.113554.4.1.112411=Test", + X509NameAttributeToString(StringPieceToCbs(unknown_oid))); + + // OBJECT_IDENTIFIER { 2.5.4.3 } + // UTF8String { "Bell: \x07" } + std::string non_printable = + absl::HexStringToBytes("06035504030c0742656c6c3a2007"); + EXPECT_EQ(R"(CN=Bell: \x07)", + X509NameAttributeToString(StringPieceToCbs(non_printable))); + + // OBJECT_IDENTIFIER { "\x55\x80" } + // UTF8String { "Test" } + std::string invalid_oid = absl::HexStringToBytes("060255800c0454657374"); + EXPECT_EQ("(5580)=Test", + X509NameAttributeToString(StringPieceToCbs(invalid_oid))); +} + +TEST(CertificateViewTest, SupportedSignatureAlgorithmsForQuicIsUpToDate) { + QuicSignatureAlgorithmVector supported = + SupportedSignatureAlgorithmsForQuic(); + for (int i = 0; i < std::numeric_limits::max(); i++) { + uint16_t sigalg = static_cast(i); + PublicKeyType key_type = PublicKeyTypeFromSignatureAlgorithm(sigalg); + if (absl::c_find(supported, sigalg) == supported.end()) { + EXPECT_EQ(key_type, PublicKeyType::kUnknown); + } else { + EXPECT_NE(key_type, PublicKeyType::kUnknown); + } + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_decrypter.cc b/quiche/quic/core/crypto/chacha20_poly1305_decrypter.cc new file mode 100644 index 000000000000..31758b432457 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_decrypter.cc @@ -0,0 +1,41 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_decrypter.h" + +#include "openssl/aead.h" +#include "openssl/tls1.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 32; +const size_t kNonceSize = 12; + +} // namespace + +ChaCha20Poly1305Decrypter::ChaCha20Poly1305Decrypter() + : ChaChaBaseDecrypter(EVP_aead_chacha20_poly1305, kKeySize, kAuthTagSize, + kNonceSize, + /* use_ietf_nonce_construction */ false) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +ChaCha20Poly1305Decrypter::~ChaCha20Poly1305Decrypter() {} + +uint32_t ChaCha20Poly1305Decrypter::cipher_id() const { + return TLS1_CK_CHACHA20_POLY1305_SHA256; +} + +QuicPacketCount ChaCha20Poly1305Decrypter::GetIntegrityLimit() const { + // For AEAD_CHACHA20_POLY1305, the integrity limit is 2^36 invalid packets. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-limits-on-aead-usage + static_assert(kMaxIncomingPacketSize < 16384, + "This key limit requires limits on decryption payload sizes"); + return 68719476736U; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_decrypter.h b/quiche/quic/core/crypto/chacha20_poly1305_decrypter.h new file mode 100644 index 000000000000..6eb6c87f0a74 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_decrypter.h @@ -0,0 +1,41 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_DECRYPTER_H_ + +#include + +#include "quiche/quic/core/crypto/chacha_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A ChaCha20Poly1305Decrypter is a QuicDecrypter that implements the +// AEAD_CHACHA20_POLY1305 algorithm specified in RFC 7539, except that +// it truncates the Poly1305 authenticator to 12 bytes. Create an instance +// by calling QuicDecrypter::Create(kCC20). +// +// It uses an authentication tag of 12 bytes (96 bits). The fixed prefix of the +// nonce is four bytes. +class QUIC_EXPORT_PRIVATE ChaCha20Poly1305Decrypter + : public ChaChaBaseDecrypter { + public: + enum { + kAuthTagSize = 12, + }; + + ChaCha20Poly1305Decrypter(); + ChaCha20Poly1305Decrypter(const ChaCha20Poly1305Decrypter&) = delete; + ChaCha20Poly1305Decrypter& operator=(const ChaCha20Poly1305Decrypter&) = + delete; + ~ChaCha20Poly1305Decrypter() override; + + uint32_t cipher_id() const override; + QuicPacketCount GetIntegrityLimit() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc b/quiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc new file mode 100644 index 000000000000..019c56b59029 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc @@ -0,0 +1,178 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_decrypter.h" + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The test vectors come from RFC 7539 Section 2.8.2. + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + // Input: + const char* key; + const char* iv; + const char* fixed; + const char* aad; + const char* ct; + + // Expected output: + const char* pt; // An empty string "" means decryption succeeded and + // the plaintext is zero-length. nullptr means decryption + // failed. +}; + +const TestVector test_vectors[] = { + {"808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4041424344454647", + + "07000000", + + "50515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecb", // "d0600691" truncated + + "4c616469657320616e642047656e746c" + "656d656e206f662074686520636c6173" + "73206f66202739393a20496620492063" + "6f756c64206f6666657220796f75206f" + "6e6c79206f6e652074697020666f7220" + "746865206675747572652c2073756e73" + "637265656e20776f756c642062652069" + "742e"}, + // Modify the ciphertext (Poly1305 authenticator). + {"808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4041424344454647", + + "07000000", + + "50515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecc", // "d0600691" truncated + + nullptr}, + // Modify the associated data. + {"808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4041424344454647", + + "07000000", + + "60515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecb", // "d0600691" truncated + + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +} // namespace + +namespace quic { +namespace test { + +// DecryptWithNonce wraps the |Decrypt| method of |decrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the plaintext. +QuicData* DecryptWithNonce(ChaCha20Poly1305Decrypter* decrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view ciphertext) { + uint64_t packet_number; + absl::string_view nonce_prefix(nonce.data(), + nonce.size() - sizeof(packet_number)); + decrypter->SetNoncePrefix(nonce_prefix); + memcpy(&packet_number, nonce.data() + nonce_prefix.size(), + sizeof(packet_number)); + std::unique_ptr output(new char[ciphertext.length()]); + size_t output_length = 0; + const bool success = decrypter->DecryptPacket( + packet_number, associated_data, ciphertext, output.get(), &output_length, + ciphertext.length()); + if (!success) { + return nullptr; + } + return new QuicData(output.release(), output_length, true); +} + +class ChaCha20Poly1305DecrypterTest : public QuicTest {}; + +TEST_F(ChaCha20Poly1305DecrypterTest, Decrypt) { + for (size_t i = 0; test_vectors[i].key != nullptr; i++) { + // If not present then decryption is expected to fail. + bool has_pt = test_vectors[i].pt; + + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[i].key); + std::string iv = absl::HexStringToBytes(test_vectors[i].iv); + std::string fixed = absl::HexStringToBytes(test_vectors[i].fixed); + std::string aad = absl::HexStringToBytes(test_vectors[i].aad); + std::string ct = absl::HexStringToBytes(test_vectors[i].ct); + std::string pt; + if (has_pt) { + pt = absl::HexStringToBytes(test_vectors[i].pt); + } + + ChaCha20Poly1305Decrypter decrypter; + ASSERT_TRUE(decrypter.SetKey(key)); + std::unique_ptr decrypted(DecryptWithNonce( + &decrypter, fixed + iv, + // This deliberately tests that the decrypter can handle an AAD that + // is set to nullptr, as opposed to a zero-length, non-nullptr pointer. + absl::string_view(aad.length() ? aad.data() : nullptr, aad.length()), + ct)); + if (!decrypted) { + EXPECT_FALSE(has_pt); + continue; + } + EXPECT_TRUE(has_pt); + + EXPECT_EQ(12u, ct.size() - decrypted->length()); + ASSERT_EQ(pt.length(), decrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "plaintext", decrypted->data(), pt.length(), pt.data(), pt.length()); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_encrypter.cc b/quiche/quic/core/crypto/chacha20_poly1305_encrypter.cc new file mode 100644 index 000000000000..1adad076d9b2 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_encrypter.cc @@ -0,0 +1,35 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h" + +#include "openssl/evp.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 32; +const size_t kNonceSize = 12; + +} // namespace + +ChaCha20Poly1305Encrypter::ChaCha20Poly1305Encrypter() + : ChaChaBaseEncrypter(EVP_aead_chacha20_poly1305, kKeySize, kAuthTagSize, + kNonceSize, + /* use_ietf_nonce_construction */ false) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +ChaCha20Poly1305Encrypter::~ChaCha20Poly1305Encrypter() {} + +QuicPacketCount ChaCha20Poly1305Encrypter::GetConfidentialityLimit() const { + // For AEAD_CHACHA20_POLY1305, the confidentiality limit is greater than the + // number of possible packets (2^62) and so can be disregarded. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-limits-on-aead-usage + return std::numeric_limits::max(); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_encrypter.h b/quiche/quic/core/crypto/chacha20_poly1305_encrypter.h new file mode 100644 index 000000000000..d37f26c8a968 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_encrypter.h @@ -0,0 +1,38 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_ENCRYPTER_H_ + +#include "quiche/quic/core/crypto/chacha_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A ChaCha20Poly1305Encrypter is a QuicEncrypter that implements the +// AEAD_CHACHA20_POLY1305 algorithm specified in RFC 7539, except that +// it truncates the Poly1305 authenticator to 12 bytes. Create an instance +// by calling QuicEncrypter::Create(kCC20). +// +// It uses an authentication tag of 12 bytes (96 bits). The fixed prefix of the +// nonce is four bytes. +class QUIC_EXPORT_PRIVATE ChaCha20Poly1305Encrypter + : public ChaChaBaseEncrypter { + public: + enum { + kAuthTagSize = 12, + }; + + ChaCha20Poly1305Encrypter(); + ChaCha20Poly1305Encrypter(const ChaCha20Poly1305Encrypter&) = delete; + ChaCha20Poly1305Encrypter& operator=(const ChaCha20Poly1305Encrypter&) = + delete; + ~ChaCha20Poly1305Encrypter() override; + + QuicPacketCount GetConfidentialityLimit() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc b/quiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc new file mode 100644 index 000000000000..9ae728a430a2 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc @@ -0,0 +1,159 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_decrypter.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The test vectors come from RFC 7539 Section 2.8.2. + +// Each test vector consists of five strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + const char* key; + const char* pt; + const char* iv; + const char* fixed; + const char* aad; + const char* ct; +}; + +const TestVector test_vectors[] = { + { + "808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4c616469657320616e642047656e746c" + "656d656e206f662074686520636c6173" + "73206f66202739393a20496620492063" + "6f756c64206f6666657220796f75206f" + "6e6c79206f6e652074697020666f7220" + "746865206675747572652c2073756e73" + "637265656e20776f756c642062652069" + "742e", + + "4041424344454647", + + "07000000", + + "50515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecb", // "d0600691" truncated + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +} // namespace + +namespace quic { +namespace test { + +// EncryptWithNonce wraps the |Encrypt| method of |encrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the ciphertext. +QuicData* EncryptWithNonce(ChaCha20Poly1305Encrypter* encrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view plaintext) { + size_t ciphertext_size = encrypter->GetCiphertextSize(plaintext.length()); + std::unique_ptr ciphertext(new char[ciphertext_size]); + + if (!encrypter->Encrypt(nonce, associated_data, plaintext, + reinterpret_cast(ciphertext.get()))) { + return nullptr; + } + + return new QuicData(ciphertext.release(), ciphertext_size, true); +} + +class ChaCha20Poly1305EncrypterTest : public QuicTest {}; + +TEST_F(ChaCha20Poly1305EncrypterTest, EncryptThenDecrypt) { + ChaCha20Poly1305Encrypter encrypter; + ChaCha20Poly1305Decrypter decrypter; + + std::string key = absl::HexStringToBytes(test_vectors[0].key); + ASSERT_TRUE(encrypter.SetKey(key)); + ASSERT_TRUE(decrypter.SetKey(key)); + ASSERT_TRUE(encrypter.SetNoncePrefix("abcd")); + ASSERT_TRUE(decrypter.SetNoncePrefix("abcd")); + + uint64_t packet_number = UINT64_C(0x123456789ABC); + std::string associated_data = "associated_data"; + std::string plaintext = "plaintext"; + char encrypted[1024]; + size_t len; + ASSERT_TRUE(encrypter.EncryptPacket(packet_number, associated_data, plaintext, + encrypted, &len, + ABSL_ARRAYSIZE(encrypted))); + absl::string_view ciphertext(encrypted, len); + char decrypted[1024]; + ASSERT_TRUE(decrypter.DecryptPacket(packet_number, associated_data, + ciphertext, decrypted, &len, + ABSL_ARRAYSIZE(decrypted))); +} + +TEST_F(ChaCha20Poly1305EncrypterTest, Encrypt) { + for (size_t i = 0; test_vectors[i].key != nullptr; i++) { + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[i].key); + std::string pt = absl::HexStringToBytes(test_vectors[i].pt); + std::string iv = absl::HexStringToBytes(test_vectors[i].iv); + std::string fixed = absl::HexStringToBytes(test_vectors[i].fixed); + std::string aad = absl::HexStringToBytes(test_vectors[i].aad); + std::string ct = absl::HexStringToBytes(test_vectors[i].ct); + + ChaCha20Poly1305Encrypter encrypter; + ASSERT_TRUE(encrypter.SetKey(key)); + std::unique_ptr encrypted(EncryptWithNonce( + &encrypter, fixed + iv, + // This deliberately tests that the encrypter can handle an AAD that + // is set to nullptr, as opposed to a zero-length, non-nullptr pointer. + absl::string_view(aad.length() ? aad.data() : nullptr, aad.length()), + pt)); + ASSERT_TRUE(encrypted.get()); + EXPECT_EQ(12u, ct.size() - pt.size()); + EXPECT_EQ(12u, encrypted->length() - pt.size()); + + quiche::test::CompareCharArraysWithHexError("ciphertext", encrypted->data(), + encrypted->length(), ct.data(), + ct.length()); + } +} + +TEST_F(ChaCha20Poly1305EncrypterTest, GetMaxPlaintextSize) { + ChaCha20Poly1305Encrypter encrypter; + EXPECT_EQ(1000u, encrypter.GetMaxPlaintextSize(1012)); + EXPECT_EQ(100u, encrypter.GetMaxPlaintextSize(112)); + EXPECT_EQ(10u, encrypter.GetMaxPlaintextSize(22)); +} + +TEST_F(ChaCha20Poly1305EncrypterTest, GetCiphertextSize) { + ChaCha20Poly1305Encrypter encrypter; + EXPECT_EQ(1012u, encrypter.GetCiphertextSize(1000)); + EXPECT_EQ(112u, encrypter.GetCiphertextSize(100)); + EXPECT_EQ(22u, encrypter.GetCiphertextSize(10)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc b/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc new file mode 100644 index 000000000000..93b099352c19 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc @@ -0,0 +1,43 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h" + +#include "openssl/aead.h" +#include "openssl/tls1.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 32; +const size_t kNonceSize = 12; + +} // namespace + +ChaCha20Poly1305TlsDecrypter::ChaCha20Poly1305TlsDecrypter() + : ChaChaBaseDecrypter(EVP_aead_chacha20_poly1305, kKeySize, kAuthTagSize, + kNonceSize, + /* use_ietf_nonce_construction */ true) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +ChaCha20Poly1305TlsDecrypter::~ChaCha20Poly1305TlsDecrypter() {} + +uint32_t ChaCha20Poly1305TlsDecrypter::cipher_id() const { + return TLS1_CK_CHACHA20_POLY1305_SHA256; +} + +QuicPacketCount ChaCha20Poly1305TlsDecrypter::GetIntegrityLimit() const { + // For AEAD_CHACHA20_POLY1305, the integrity limit is 2^36 invalid packets. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-limits-on-aead-usage + static_assert(kMaxIncomingPacketSize < 16384, + "This key limit requires limits on decryption payload sizes"); + return 68719476736U; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h b/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h new file mode 100644 index 000000000000..f8108f2662a1 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h @@ -0,0 +1,39 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_TLS_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_TLS_DECRYPTER_H_ + +#include + +#include "quiche/quic/core/crypto/chacha_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A ChaCha20Poly1305TlsDecrypter is a QuicDecrypter that implements the +// AEAD_CHACHA20_POLY1305 algorithm specified in RFC 7539 for use in IETF QUIC. +// +// It uses an authentication tag of 16 bytes (128 bits). It uses a 12 bytes IV +// that is XOR'd with the packet number to compute the nonce. +class QUIC_EXPORT_PRIVATE ChaCha20Poly1305TlsDecrypter + : public ChaChaBaseDecrypter { + public: + enum { + kAuthTagSize = 16, + }; + + ChaCha20Poly1305TlsDecrypter(); + ChaCha20Poly1305TlsDecrypter(const ChaCha20Poly1305TlsDecrypter&) = delete; + ChaCha20Poly1305TlsDecrypter& operator=(const ChaCha20Poly1305TlsDecrypter&) = + delete; + ~ChaCha20Poly1305TlsDecrypter() override; + + uint32_t cipher_id() const override; + QuicPacketCount GetIntegrityLimit() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_TLS_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc b/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc new file mode 100644 index 000000000000..00686367dde0 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc @@ -0,0 +1,188 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h" + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The test vectors come from RFC 7539 Section 2.8.2. + +// Each test vector consists of six strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + // Input: + const char* key; + const char* iv; + const char* fixed; + const char* aad; + const char* ct; + + // Expected output: + const char* pt; // An empty string "" means decryption succeeded and + // the plaintext is zero-length. nullptr means decryption + // failed. +}; + +const TestVector test_vectors[] = { + {"808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4041424344454647", + + "07000000", + + "50515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecbd0600691", + + "4c616469657320616e642047656e746c" + "656d656e206f662074686520636c6173" + "73206f66202739393a20496620492063" + "6f756c64206f6666657220796f75206f" + "6e6c79206f6e652074697020666f7220" + "746865206675747572652c2073756e73" + "637265656e20776f756c642062652069" + "742e"}, + // Modify the ciphertext (Poly1305 authenticator). + {"808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4041424344454647", + + "07000000", + + "50515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902eccd0600691", + + nullptr}, + // Modify the associated data. + {"808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4041424344454647", + + "07000000", + + "60515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecbd0600691", + + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +} // namespace + +namespace quic { +namespace test { + +// DecryptWithNonce wraps the |Decrypt| method of |decrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the plaintext. +QuicData* DecryptWithNonce(ChaCha20Poly1305TlsDecrypter* decrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view ciphertext) { + decrypter->SetIV(nonce); + std::unique_ptr output(new char[ciphertext.length()]); + size_t output_length = 0; + const bool success = + decrypter->DecryptPacket(0, associated_data, ciphertext, output.get(), + &output_length, ciphertext.length()); + if (!success) { + return nullptr; + } + return new QuicData(output.release(), output_length, true); +} + +class ChaCha20Poly1305TlsDecrypterTest : public QuicTest {}; + +TEST_F(ChaCha20Poly1305TlsDecrypterTest, Decrypt) { + for (size_t i = 0; test_vectors[i].key != nullptr; i++) { + // If not present then decryption is expected to fail. + bool has_pt = test_vectors[i].pt; + + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[i].key); + std::string iv = absl::HexStringToBytes(test_vectors[i].iv); + std::string fixed = absl::HexStringToBytes(test_vectors[i].fixed); + std::string aad = absl::HexStringToBytes(test_vectors[i].aad); + std::string ct = absl::HexStringToBytes(test_vectors[i].ct); + std::string pt; + if (has_pt) { + pt = absl::HexStringToBytes(test_vectors[i].pt); + } + + ChaCha20Poly1305TlsDecrypter decrypter; + ASSERT_TRUE(decrypter.SetKey(key)); + std::unique_ptr decrypted(DecryptWithNonce( + &decrypter, fixed + iv, + // This deliberately tests that the decrypter can handle an AAD that + // is set to nullptr, as opposed to a zero-length, non-nullptr pointer. + absl::string_view(aad.length() ? aad.data() : nullptr, aad.length()), + ct)); + if (!decrypted) { + EXPECT_FALSE(has_pt); + continue; + } + EXPECT_TRUE(has_pt); + + EXPECT_EQ(16u, ct.size() - decrypted->length()); + ASSERT_EQ(pt.length(), decrypted->length()); + quiche::test::CompareCharArraysWithHexError( + "plaintext", decrypted->data(), pt.length(), pt.data(), pt.length()); + } +} + +TEST_F(ChaCha20Poly1305TlsDecrypterTest, GenerateHeaderProtectionMask) { + ChaCha20Poly1305TlsDecrypter decrypter; + std::string key = absl::HexStringToBytes( + "6a067f432787bd6034dd3f08f07fc9703a27e58c70e2d88d948b7f6489923cc7"); + std::string sample = + absl::HexStringToBytes("1210d91cceb45c716b023f492c29e612"); + QuicDataReader sample_reader(sample.data(), sample.size()); + ASSERT_TRUE(decrypter.SetHeaderProtectionKey(key)); + std::string mask = decrypter.GenerateHeaderProtectionMask(&sample_reader); + std::string expected_mask = absl::HexStringToBytes("1cc2cd98dc"); + quiche::test::CompareCharArraysWithHexError( + "header protection mask", mask.data(), mask.size(), expected_mask.data(), + expected_mask.size()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc b/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc new file mode 100644 index 000000000000..fe6f6b44ac9c --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc @@ -0,0 +1,35 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h" + +#include "openssl/evp.h" + +namespace quic { + +namespace { + +const size_t kKeySize = 32; +const size_t kNonceSize = 12; + +} // namespace + +ChaCha20Poly1305TlsEncrypter::ChaCha20Poly1305TlsEncrypter() + : ChaChaBaseEncrypter(EVP_aead_chacha20_poly1305, kKeySize, kAuthTagSize, + kNonceSize, + /* use_ietf_nonce_construction */ true) { + static_assert(kKeySize <= kMaxKeySize, "key size too big"); + static_assert(kNonceSize <= kMaxNonceSize, "nonce size too big"); +} + +ChaCha20Poly1305TlsEncrypter::~ChaCha20Poly1305TlsEncrypter() {} + +QuicPacketCount ChaCha20Poly1305TlsEncrypter::GetConfidentialityLimit() const { + // For AEAD_CHACHA20_POLY1305, the confidentiality limit is greater than the + // number of possible packets (2^62) and so can be disregarded. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-limits-on-aead-usage + return std::numeric_limits::max(); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h b/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h new file mode 100644 index 000000000000..e5d8f378ec5c --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h @@ -0,0 +1,36 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_TLS_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_TLS_ENCRYPTER_H_ + +#include "quiche/quic/core/crypto/chacha_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A ChaCha20Poly1305Encrypter is a QuicEncrypter that implements the +// AEAD_CHACHA20_POLY1305 algorithm specified in RFC 7539 for use in IETF QUIC. +// +// It uses an authentication tag of 16 bytes (128 bits). It uses a 12 byte IV +// that is XOR'd with the packet number to compute the nonce. +class QUIC_EXPORT_PRIVATE ChaCha20Poly1305TlsEncrypter + : public ChaChaBaseEncrypter { + public: + enum { + kAuthTagSize = 16, + }; + + ChaCha20Poly1305TlsEncrypter(); + ChaCha20Poly1305TlsEncrypter(const ChaCha20Poly1305TlsEncrypter&) = delete; + ChaCha20Poly1305TlsEncrypter& operator=(const ChaCha20Poly1305TlsEncrypter&) = + delete; + ~ChaCha20Poly1305TlsEncrypter() override; + + QuicPacketCount GetConfidentialityLimit() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHACHA20_POLY1305_TLS_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc b/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc new file mode 100644 index 000000000000..322651b846d6 --- /dev/null +++ b/quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc @@ -0,0 +1,173 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace { + +// The test vectors come from RFC 7539 Section 2.8.2. + +// Each test vector consists of five strings of lowercase hexadecimal digits. +// The strings may be empty (zero length). A test vector with a nullptr |key| +// marks the end of an array of test vectors. +struct TestVector { + const char* key; + const char* pt; + const char* iv; + const char* fixed; + const char* aad; + const char* ct; +}; + +const TestVector test_vectors[] = { + { + "808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f", + + "4c616469657320616e642047656e746c" + "656d656e206f662074686520636c6173" + "73206f66202739393a20496620492063" + "6f756c64206f6666657220796f75206f" + "6e6c79206f6e652074697020666f7220" + "746865206675747572652c2073756e73" + "637265656e20776f756c642062652069" + "742e", + + "4041424344454647", + + "07000000", + + "50515253c0c1c2c3c4c5c6c7", + + "d31a8d34648e60db7b86afbc53ef7ec2" + "a4aded51296e08fea9e2b5a736ee62d6" + "3dbea45e8ca9671282fafb69da92728b" + "1a71de0a9e060b2905d6a5b67ecd3b36" + "92ddbd7f2d778b8c9803aee328091b58" + "fab324e4fad675945585808b4831d7bc" + "3ff4def08e4b7a9de576d26586cec64b" + "6116" + "1ae10b594f09e26a7e902ecbd0600691", + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}}; + +} // namespace + +namespace quic { +namespace test { + +// EncryptWithNonce wraps the |Encrypt| method of |encrypter| to allow passing +// in an nonce and also to allocate the buffer needed for the ciphertext. +QuicData* EncryptWithNonce(ChaCha20Poly1305TlsEncrypter* encrypter, + absl::string_view nonce, + absl::string_view associated_data, + absl::string_view plaintext) { + size_t ciphertext_size = encrypter->GetCiphertextSize(plaintext.length()); + std::unique_ptr ciphertext(new char[ciphertext_size]); + + if (!encrypter->Encrypt(nonce, associated_data, plaintext, + reinterpret_cast(ciphertext.get()))) { + return nullptr; + } + + return new QuicData(ciphertext.release(), ciphertext_size, true); +} + +class ChaCha20Poly1305TlsEncrypterTest : public QuicTest {}; + +TEST_F(ChaCha20Poly1305TlsEncrypterTest, EncryptThenDecrypt) { + ChaCha20Poly1305TlsEncrypter encrypter; + ChaCha20Poly1305TlsDecrypter decrypter; + + std::string key = absl::HexStringToBytes(test_vectors[0].key); + ASSERT_TRUE(encrypter.SetKey(key)); + ASSERT_TRUE(decrypter.SetKey(key)); + ASSERT_TRUE(encrypter.SetIV("abcdefghijkl")); + ASSERT_TRUE(decrypter.SetIV("abcdefghijkl")); + + uint64_t packet_number = UINT64_C(0x123456789ABC); + std::string associated_data = "associated_data"; + std::string plaintext = "plaintext"; + char encrypted[1024]; + size_t len; + ASSERT_TRUE(encrypter.EncryptPacket(packet_number, associated_data, plaintext, + encrypted, &len, + ABSL_ARRAYSIZE(encrypted))); + absl::string_view ciphertext(encrypted, len); + char decrypted[1024]; + ASSERT_TRUE(decrypter.DecryptPacket(packet_number, associated_data, + ciphertext, decrypted, &len, + ABSL_ARRAYSIZE(decrypted))); +} + +TEST_F(ChaCha20Poly1305TlsEncrypterTest, Encrypt) { + for (size_t i = 0; test_vectors[i].key != nullptr; i++) { + // Decode the test vector. + std::string key = absl::HexStringToBytes(test_vectors[i].key); + std::string pt = absl::HexStringToBytes(test_vectors[i].pt); + std::string iv = absl::HexStringToBytes(test_vectors[i].iv); + std::string fixed = absl::HexStringToBytes(test_vectors[i].fixed); + std::string aad = absl::HexStringToBytes(test_vectors[i].aad); + std::string ct = absl::HexStringToBytes(test_vectors[i].ct); + + ChaCha20Poly1305TlsEncrypter encrypter; + ASSERT_TRUE(encrypter.SetKey(key)); + std::unique_ptr encrypted(EncryptWithNonce( + &encrypter, fixed + iv, + // This deliberately tests that the encrypter can handle an AAD that + // is set to nullptr, as opposed to a zero-length, non-nullptr pointer. + absl::string_view(aad.length() ? aad.data() : nullptr, aad.length()), + pt)); + ASSERT_TRUE(encrypted.get()); + EXPECT_EQ(16u, ct.size() - pt.size()); + EXPECT_EQ(16u, encrypted->length() - pt.size()); + + quiche::test::CompareCharArraysWithHexError("ciphertext", encrypted->data(), + encrypted->length(), ct.data(), + ct.length()); + } +} + +TEST_F(ChaCha20Poly1305TlsEncrypterTest, GetMaxPlaintextSize) { + ChaCha20Poly1305TlsEncrypter encrypter; + EXPECT_EQ(1000u, encrypter.GetMaxPlaintextSize(1016)); + EXPECT_EQ(100u, encrypter.GetMaxPlaintextSize(116)); + EXPECT_EQ(10u, encrypter.GetMaxPlaintextSize(26)); +} + +TEST_F(ChaCha20Poly1305TlsEncrypterTest, GetCiphertextSize) { + ChaCha20Poly1305TlsEncrypter encrypter; + EXPECT_EQ(1016u, encrypter.GetCiphertextSize(1000)); + EXPECT_EQ(116u, encrypter.GetCiphertextSize(100)); + EXPECT_EQ(26u, encrypter.GetCiphertextSize(10)); +} + +TEST_F(ChaCha20Poly1305TlsEncrypterTest, GenerateHeaderProtectionMask) { + ChaCha20Poly1305TlsEncrypter encrypter; + std::string key = absl::HexStringToBytes( + "6a067f432787bd6034dd3f08f07fc9703a27e58c70e2d88d948b7f6489923cc7"); + std::string sample = + absl::HexStringToBytes("1210d91cceb45c716b023f492c29e612"); + ASSERT_TRUE(encrypter.SetHeaderProtectionKey(key)); + std::string mask = encrypter.GenerateHeaderProtectionMask(sample); + std::string expected_mask = absl::HexStringToBytes("1cc2cd98dc"); + quiche::test::CompareCharArraysWithHexError( + "header protection mask", mask.data(), mask.size(), expected_mask.data(), + expected_mask.size()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha_base_decrypter.cc b/quiche/quic/core/crypto/chacha_base_decrypter.cc new file mode 100644 index 000000000000..a90c9eff98ea --- /dev/null +++ b/quiche/quic/core/crypto/chacha_base_decrypter.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha_base_decrypter.h" + +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "openssl/chacha.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +bool ChaChaBaseDecrypter::SetHeaderProtectionKey(absl::string_view key) { + if (key.size() != GetKeySize()) { + QUIC_BUG(quic_bug_10620_1) << "Invalid key size for header protection"; + return false; + } + memcpy(pne_key_, key.data(), key.size()); + return true; +} + +std::string ChaChaBaseDecrypter::GenerateHeaderProtectionMask( + QuicDataReader* sample_reader) { + absl::string_view sample; + if (!sample_reader->ReadStringPiece(&sample, 16)) { + return std::string(); + } + const uint8_t* nonce = reinterpret_cast(sample.data()) + 4; + uint32_t counter; + QuicDataReader(sample.data(), 4, quiche::HOST_BYTE_ORDER) + .ReadUInt32(&counter); + const uint8_t zeroes[] = {0, 0, 0, 0, 0}; + std::string out(ABSL_ARRAYSIZE(zeroes), 0); + CRYPTO_chacha_20(reinterpret_cast(const_cast(out.data())), + zeroes, ABSL_ARRAYSIZE(zeroes), pne_key_, nonce, counter); + return out; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha_base_decrypter.h b/quiche/quic/core/crypto/chacha_base_decrypter.h new file mode 100644 index 000000000000..5cd08c74cf63 --- /dev/null +++ b/quiche/quic/core/crypto/chacha_base_decrypter.h @@ -0,0 +1,31 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHACHA_BASE_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHACHA_BASE_DECRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/aead_base_decrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE ChaChaBaseDecrypter : public AeadBaseDecrypter { + public: + using AeadBaseDecrypter::AeadBaseDecrypter; + + bool SetHeaderProtectionKey(absl::string_view key) override; + std::string GenerateHeaderProtectionMask( + QuicDataReader* sample_reader) override; + + private: + // The key used for packet number encryption. + unsigned char pne_key_[kMaxKeySize]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHACHA_BASE_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/chacha_base_encrypter.cc b/quiche/quic/core/crypto/chacha_base_encrypter.cc new file mode 100644 index 000000000000..847345130b8c --- /dev/null +++ b/quiche/quic/core/crypto/chacha_base_encrypter.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/chacha_base_encrypter.h" + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "openssl/chacha.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +bool ChaChaBaseEncrypter::SetHeaderProtectionKey(absl::string_view key) { + if (key.size() != GetKeySize()) { + QUIC_BUG(quic_bug_10656_1) << "Invalid key size for header protection"; + return false; + } + memcpy(pne_key_, key.data(), key.size()); + return true; +} + +std::string ChaChaBaseEncrypter::GenerateHeaderProtectionMask( + absl::string_view sample) { + if (sample.size() != 16) { + return std::string(); + } + const uint8_t* nonce = reinterpret_cast(sample.data()) + 4; + uint32_t counter; + QuicDataReader(sample.data(), 4, quiche::HOST_BYTE_ORDER) + .ReadUInt32(&counter); + const uint8_t zeroes[] = {0, 0, 0, 0, 0}; + std::string out(ABSL_ARRAYSIZE(zeroes), 0); + CRYPTO_chacha_20(reinterpret_cast(const_cast(out.data())), + zeroes, ABSL_ARRAYSIZE(zeroes), pne_key_, nonce, counter); + return out; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/chacha_base_encrypter.h b/quiche/quic/core/crypto/chacha_base_encrypter.h new file mode 100644 index 000000000000..14773ec1cd7b --- /dev/null +++ b/quiche/quic/core/crypto/chacha_base_encrypter.h @@ -0,0 +1,30 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHACHA_BASE_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHACHA_BASE_ENCRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/aead_base_encrypter.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE ChaChaBaseEncrypter : public AeadBaseEncrypter { + public: + using AeadBaseEncrypter::AeadBaseEncrypter; + + bool SetHeaderProtectionKey(absl::string_view key) override; + std::string GenerateHeaderProtectionMask(absl::string_view sample) override; + + private: + // The key used for packet number encryption. + unsigned char pne_key_[kMaxKeySize]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHACHA_BASE_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/channel_id.cc b/quiche/quic/core/crypto/channel_id.cc new file mode 100644 index 000000000000..77288dd52eea --- /dev/null +++ b/quiche/quic/core/crypto/channel_id.cc @@ -0,0 +1,90 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/channel_id.h" + +#include + +#include "absl/strings/string_view.h" +#include "openssl/bn.h" +#include "openssl/ec.h" +#include "openssl/ecdsa.h" +#include "openssl/nid.h" +#include "openssl/sha.h" + +namespace quic { + +// static +const char ChannelIDVerifier::kContextStr[] = "QUIC ChannelID"; +// static +const char ChannelIDVerifier::kClientToServerStr[] = "client -> server"; + +// static +bool ChannelIDVerifier::Verify(absl::string_view key, + absl::string_view signed_data, + absl::string_view signature) { + return VerifyRaw(key, signed_data, signature, true); +} + +// static +bool ChannelIDVerifier::VerifyRaw(absl::string_view key, + absl::string_view signed_data, + absl::string_view signature, + bool is_channel_id_signature) { + if (key.size() != 32 * 2 || signature.size() != 32 * 2) { + return false; + } + + bssl::UniquePtr p256( + EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)); + if (p256.get() == nullptr) { + return false; + } + + bssl::UniquePtr x(BN_new()), y(BN_new()), r(BN_new()), s(BN_new()); + + ECDSA_SIG sig; + sig.r = r.get(); + sig.s = s.get(); + + const uint8_t* key_bytes = reinterpret_cast(key.data()); + const uint8_t* signature_bytes = + reinterpret_cast(signature.data()); + + if (BN_bin2bn(key_bytes + 0, 32, x.get()) == nullptr || + BN_bin2bn(key_bytes + 32, 32, y.get()) == nullptr || + BN_bin2bn(signature_bytes + 0, 32, sig.r) == nullptr || + BN_bin2bn(signature_bytes + 32, 32, sig.s) == nullptr) { + return false; + } + + bssl::UniquePtr point(EC_POINT_new(p256.get())); + if (point.get() == nullptr || + !EC_POINT_set_affine_coordinates_GFp(p256.get(), point.get(), x.get(), + y.get(), nullptr)) { + return false; + } + + bssl::UniquePtr ecdsa_key(EC_KEY_new()); + if (ecdsa_key.get() == nullptr || + !EC_KEY_set_group(ecdsa_key.get(), p256.get()) || + !EC_KEY_set_public_key(ecdsa_key.get(), point.get())) { + return false; + } + + SHA256_CTX sha256; + SHA256_Init(&sha256); + if (is_channel_id_signature) { + SHA256_Update(&sha256, kContextStr, strlen(kContextStr) + 1); + SHA256_Update(&sha256, kClientToServerStr, strlen(kClientToServerStr) + 1); + } + SHA256_Update(&sha256, signed_data.data(), signed_data.size()); + + unsigned char digest[SHA256_DIGEST_LENGTH]; + SHA256_Final(digest, &sha256); + + return ECDSA_do_verify(digest, sizeof(digest), &sig, ecdsa_key.get()) == 1; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/channel_id.h b/quiche/quic/core/crypto/channel_id.h new file mode 100644 index 000000000000..2f8a52781e5f --- /dev/null +++ b/quiche/quic/core/crypto/channel_id.h @@ -0,0 +1,47 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CHANNEL_ID_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CHANNEL_ID_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// ChannelIDVerifier verifies ChannelID signatures. +class QUIC_EXPORT_PRIVATE ChannelIDVerifier { + public: + ChannelIDVerifier() = delete; + + // kContextStr is prepended to the data to be signed in order to ensure that + // a ChannelID signature cannot be used in a different context. (The + // terminating NUL byte is inclued.) + static const char kContextStr[]; + // kClientToServerStr follows kContextStr to specify that the ChannelID is + // being used in the client to server direction. (The terminating NUL byte is + // included.) + static const char kClientToServerStr[]; + + // Verify returns true iff |signature| is a valid signature of |signed_data| + // by |key|. + static bool Verify(absl::string_view key, absl::string_view signed_data, + absl::string_view signature); + + // FOR TESTING ONLY: VerifyRaw returns true iff |signature| is a valid + // signature of |signed_data| by |key|. |is_channel_id_signature| indicates + // whether |signature| is a ChannelID signature (with kContextStr prepended + // to the data to be signed). + static bool VerifyRaw(absl::string_view key, absl::string_view signed_data, + absl::string_view signature, + bool is_channel_id_signature); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CHANNEL_ID_H_ diff --git a/quiche/quic/core/crypto/channel_id_test.cc b/quiche/quic/core/crypto/channel_id_test.cc new file mode 100644 index 000000000000..ff3d73d159ce --- /dev/null +++ b/quiche/quic/core/crypto/channel_id_test.cc @@ -0,0 +1,285 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/channel_id.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" + +namespace quic { +namespace test { + +namespace { + +// The following ECDSA signature verification test vectors for P-256,SHA-256 +// come from the SigVer.rsp file in +// http://csrc.nist.gov/groups/STM/cavp/documents/dss/186-3ecdsatestvectors.zip +// downloaded on 2013-06-11. +struct TestVector { + // Input: + const char* msg; + const char* qx; + const char* qy; + const char* r; + const char* s; + + // Expected output: + bool result; // true means "P", false means "F" +}; + +const TestVector test_vector[] = { + { + "e4796db5f785f207aa30d311693b3702821dff1168fd2e04c0836825aefd850d" + "9aa60326d88cde1a23c7745351392ca2288d632c264f197d05cd424a30336c19" + "fd09bb229654f0222fcb881a4b35c290a093ac159ce13409111ff0358411133c" + "24f5b8e2090d6db6558afc36f06ca1f6ef779785adba68db27a409859fc4c4a0", + "87f8f2b218f49845f6f10eec3877136269f5c1a54736dbdf69f89940cad41555", + "e15f369036f49842fac7a86c8a2b0557609776814448b8f5e84aa9f4395205e9", + "d19ff48b324915576416097d2544f7cbdf8768b1454ad20e0baac50e211f23b0", + "a3e81e59311cdfff2d4784949f7a2cb50ba6c3a91fa54710568e61aca3e847c6", + false // F (3 - S changed) + }, + { + "069a6e6b93dfee6df6ef6997cd80dd2182c36653cef10c655d524585655462d6" + "83877f95ecc6d6c81623d8fac4e900ed0019964094e7de91f1481989ae187300" + "4565789cbf5dc56c62aedc63f62f3b894c9c6f7788c8ecaadc9bd0e81ad91b2b" + "3569ea12260e93924fdddd3972af5273198f5efda0746219475017557616170e", + "5cf02a00d205bdfee2016f7421807fc38ae69e6b7ccd064ee689fc1a94a9f7d2", + "ec530ce3cc5c9d1af463f264d685afe2b4db4b5828d7e61b748930f3ce622a85", + "dc23d130c6117fb5751201455e99f36f59aba1a6a21cf2d0e7481a97451d6693", + "d6ce7708c18dbf35d4f8aa7240922dc6823f2e7058cbc1484fcad1599db5018c", + false // F (2 - R changed) + }, + { + "df04a346cf4d0e331a6db78cca2d456d31b0a000aa51441defdb97bbeb20b94d" + "8d746429a393ba88840d661615e07def615a342abedfa4ce912e562af7149598" + "96858af817317a840dcff85a057bb91a3c2bf90105500362754a6dd321cdd861" + "28cfc5f04667b57aa78c112411e42da304f1012d48cd6a7052d7de44ebcc01de", + "2ddfd145767883ffbb0ac003ab4a44346d08fa2570b3120dcce94562422244cb", + "5f70c7d11ac2b7a435ccfbbae02c3df1ea6b532cc0e9db74f93fffca7c6f9a64", + "9913111cff6f20c5bf453a99cd2c2019a4e749a49724a08774d14e4c113edda8", + "9467cd4cd21ecb56b0cab0a9a453b43386845459127a952421f5c6382866c5cc", + false // F (4 - Q changed) + }, + { + "e1130af6a38ccb412a9c8d13e15dbfc9e69a16385af3c3f1e5da954fd5e7c45f" + "d75e2b8c36699228e92840c0562fbf3772f07e17f1add56588dd45f7450e1217" + "ad239922dd9c32695dc71ff2424ca0dec1321aa47064a044b7fe3c2b97d03ce4" + "70a592304c5ef21eed9f93da56bb232d1eeb0035f9bf0dfafdcc4606272b20a3", + "e424dc61d4bb3cb7ef4344a7f8957a0c5134e16f7a67c074f82e6e12f49abf3c", + "970eed7aa2bc48651545949de1dddaf0127e5965ac85d1243d6f60e7dfaee927", + "bf96b99aa49c705c910be33142017c642ff540c76349b9dab72f981fd9347f4f", + "17c55095819089c2e03b9cd415abdf12444e323075d98f31920b9e0f57ec871c", + true // P (0 ) + }, + { + "73c5f6a67456ae48209b5f85d1e7de7758bf235300c6ae2bdceb1dcb27a7730f" + "b68c950b7fcada0ecc4661d3578230f225a875e69aaa17f1e71c6be5c831f226" + "63bac63d0c7a9635edb0043ff8c6f26470f02a7bc56556f1437f06dfa27b487a" + "6c4290d8bad38d4879b334e341ba092dde4e4ae694a9c09302e2dbf443581c08", + "e0fc6a6f50e1c57475673ee54e3a57f9a49f3328e743bf52f335e3eeaa3d2864", + "7f59d689c91e463607d9194d99faf316e25432870816dde63f5d4b373f12f22a", + "1d75830cd36f4c9aa181b2c4221e87f176b7f05b7c87824e82e396c88315c407", + "cb2acb01dac96efc53a32d4a0d85d0c2e48955214783ecf50a4f0414a319c05a", + true // P (0 ) + }, + { + "666036d9b4a2426ed6585a4e0fd931a8761451d29ab04bd7dc6d0c5b9e38e6c2" + "b263ff6cb837bd04399de3d757c6c7005f6d7a987063cf6d7e8cb38a4bf0d74a" + "282572bd01d0f41e3fd066e3021575f0fa04f27b700d5b7ddddf50965993c3f9" + "c7118ed78888da7cb221849b3260592b8e632d7c51e935a0ceae15207bedd548", + "a849bef575cac3c6920fbce675c3b787136209f855de19ffe2e8d29b31a5ad86", + "bf5fe4f7858f9b805bd8dcc05ad5e7fb889de2f822f3d8b41694e6c55c16b471", + "25acc3aa9d9e84c7abf08f73fa4195acc506491d6fc37cb9074528a7db87b9d6", + "9b21d5b5259ed3f2ef07dfec6cc90d3a37855d1ce122a85ba6a333f307d31537", + false // F (2 - R changed) + }, + { + "7e80436bce57339ce8da1b5660149a20240b146d108deef3ec5da4ae256f8f89" + "4edcbbc57b34ce37089c0daa17f0c46cd82b5a1599314fd79d2fd2f446bd5a25" + "b8e32fcf05b76d644573a6df4ad1dfea707b479d97237a346f1ec632ea5660ef" + "b57e8717a8628d7f82af50a4e84b11f21bdff6839196a880ae20b2a0918d58cd", + "3dfb6f40f2471b29b77fdccba72d37c21bba019efa40c1c8f91ec405d7dcc5df", + "f22f953f1e395a52ead7f3ae3fc47451b438117b1e04d613bc8555b7d6e6d1bb", + "548886278e5ec26bed811dbb72db1e154b6f17be70deb1b210107decb1ec2a5a", + "e93bfebd2f14f3d827ca32b464be6e69187f5edbd52def4f96599c37d58eee75", + false // F (4 - Q changed) + }, + { + "1669bfb657fdc62c3ddd63269787fc1c969f1850fb04c933dda063ef74a56ce1" + "3e3a649700820f0061efabf849a85d474326c8a541d99830eea8131eaea584f2" + "2d88c353965dabcdc4bf6b55949fd529507dfb803ab6b480cd73ca0ba00ca19c" + "438849e2cea262a1c57d8f81cd257fb58e19dec7904da97d8386e87b84948169", + "69b7667056e1e11d6caf6e45643f8b21e7a4bebda463c7fdbc13bc98efbd0214", + "d3f9b12eb46c7c6fda0da3fc85bc1fd831557f9abc902a3be3cb3e8be7d1aa2f", + "288f7a1cd391842cce21f00e6f15471c04dc182fe4b14d92dc18910879799790", + "247b3c4e89a3bcadfea73c7bfd361def43715fa382b8c3edf4ae15d6e55e9979", + false // F (1 - Message changed) + }, + { + "3fe60dd9ad6caccf5a6f583b3ae65953563446c4510b70da115ffaa0ba04c076" + "115c7043ab8733403cd69c7d14c212c655c07b43a7c71b9a4cffe22c2684788e" + "c6870dc2013f269172c822256f9e7cc674791bf2d8486c0f5684283e1649576e" + "fc982ede17c7b74b214754d70402fb4bb45ad086cf2cf76b3d63f7fce39ac970", + "bf02cbcf6d8cc26e91766d8af0b164fc5968535e84c158eb3bc4e2d79c3cc682", + "069ba6cb06b49d60812066afa16ecf7b51352f2c03bd93ec220822b1f3dfba03", + "f5acb06c59c2b4927fb852faa07faf4b1852bbb5d06840935e849c4d293d1bad", + "049dab79c89cc02f1484c437f523e080a75f134917fda752f2d5ca397addfe5d", + false // F (3 - S changed) + }, + { + "983a71b9994d95e876d84d28946a041f8f0a3f544cfcc055496580f1dfd4e312" + "a2ad418fe69dbc61db230cc0c0ed97e360abab7d6ff4b81ee970a7e97466acfd" + "9644f828ffec538abc383d0e92326d1c88c55e1f46a668a039beaa1be631a891" + "29938c00a81a3ae46d4aecbf9707f764dbaccea3ef7665e4c4307fa0b0a3075c", + "224a4d65b958f6d6afb2904863efd2a734b31798884801fcab5a590f4d6da9de", + "178d51fddada62806f097aa615d33b8f2404e6b1479f5fd4859d595734d6d2b9", + "87b93ee2fecfda54deb8dff8e426f3c72c8864991f8ec2b3205bb3b416de93d2", + "4044a24df85be0cc76f21a4430b75b8e77b932a87f51e4eccbc45c263ebf8f66", + false // F (2 - R changed) + }, + { + "4a8c071ac4fd0d52faa407b0fe5dab759f7394a5832127f2a3498f34aac28733" + "9e043b4ffa79528faf199dc917f7b066ad65505dab0e11e6948515052ce20cfd" + "b892ffb8aa9bf3f1aa5be30a5bbe85823bddf70b39fd7ebd4a93a2f75472c1d4" + "f606247a9821f1a8c45a6cb80545de2e0c6c0174e2392088c754e9c8443eb5af", + "43691c7795a57ead8c5c68536fe934538d46f12889680a9cb6d055a066228369", + "f8790110b3c3b281aa1eae037d4f1234aff587d903d93ba3af225c27ddc9ccac", + "8acd62e8c262fa50dd9840480969f4ef70f218ebf8ef9584f199031132c6b1ce", + "cfca7ed3d4347fb2a29e526b43c348ae1ce6c60d44f3191b6d8ea3a2d9c92154", + false // F (3 - S changed) + }, + { + "0a3a12c3084c865daf1d302c78215d39bfe0b8bf28272b3c0b74beb4b7409db0" + "718239de700785581514321c6440a4bbaea4c76fa47401e151e68cb6c29017f0" + "bce4631290af5ea5e2bf3ed742ae110b04ade83a5dbd7358f29a85938e23d87a" + "c8233072b79c94670ff0959f9c7f4517862ff829452096c78f5f2e9a7e4e9216", + "9157dbfcf8cf385f5bb1568ad5c6e2a8652ba6dfc63bc1753edf5268cb7eb596", + "972570f4313d47fc96f7c02d5594d77d46f91e949808825b3d31f029e8296405", + "dfaea6f297fa320b707866125c2a7d5d515b51a503bee817de9faa343cc48eeb", + "8f780ad713f9c3e5a4f7fa4c519833dfefc6a7432389b1e4af463961f09764f2", + false // F (1 - Message changed) + }, + { + "785d07a3c54f63dca11f5d1a5f496ee2c2f9288e55007e666c78b007d95cc285" + "81dce51f490b30fa73dc9e2d45d075d7e3a95fb8a9e1465ad191904124160b7c" + "60fa720ef4ef1c5d2998f40570ae2a870ef3e894c2bc617d8a1dc85c3c557749" + "28c38789b4e661349d3f84d2441a3b856a76949b9f1f80bc161648a1cad5588e", + "072b10c081a4c1713a294f248aef850e297991aca47fa96a7470abe3b8acfdda", + "9581145cca04a0fb94cedce752c8f0370861916d2a94e7c647c5373ce6a4c8f5", + "09f5483eccec80f9d104815a1be9cc1a8e5b12b6eb482a65c6907b7480cf4f19", + "a4f90e560c5e4eb8696cb276e5165b6a9d486345dedfb094a76e8442d026378d", + false // F (4 - Q changed) + }, + { + "76f987ec5448dd72219bd30bf6b66b0775c80b394851a43ff1f537f140a6e722" + "9ef8cd72ad58b1d2d20298539d6347dd5598812bc65323aceaf05228f738b5ad" + "3e8d9fe4100fd767c2f098c77cb99c2992843ba3eed91d32444f3b6db6cd212d" + "d4e5609548f4bb62812a920f6e2bf1581be1ebeebdd06ec4e971862cc42055ca", + "09308ea5bfad6e5adf408634b3d5ce9240d35442f7fe116452aaec0d25be8c24", + "f40c93e023ef494b1c3079b2d10ef67f3170740495ce2cc57f8ee4b0618b8ee5", + "5cc8aa7c35743ec0c23dde88dabd5e4fcd0192d2116f6926fef788cddb754e73", + "9c9c045ebaa1b828c32f82ace0d18daebf5e156eb7cbfdc1eff4399a8a900ae7", + false // F (1 - Message changed) + }, + { + "60cd64b2cd2be6c33859b94875120361a24085f3765cb8b2bf11e026fa9d8855" + "dbe435acf7882e84f3c7857f96e2baab4d9afe4588e4a82e17a78827bfdb5ddb" + "d1c211fbc2e6d884cddd7cb9d90d5bf4a7311b83f352508033812c776a0e00c0" + "03c7e0d628e50736c7512df0acfa9f2320bd102229f46495ae6d0857cc452a84", + "2d98ea01f754d34bbc3003df5050200abf445ec728556d7ed7d5c54c55552b6d", + "9b52672742d637a32add056dfd6d8792f2a33c2e69dafabea09b960bc61e230a", + "06108e525f845d0155bf60193222b3219c98e3d49424c2fb2a0987f825c17959", + "62b5cdd591e5b507e560167ba8f6f7cda74673eb315680cb89ccbc4eec477dce", + true // P (0 ) + }, + {nullptr, nullptr, nullptr, nullptr, nullptr, false}}; + +// Returns true if |ch| is a lowercase hexadecimal digit. +bool IsHexDigit(char ch) { + return ('0' <= ch && ch <= '9') || ('a' <= ch && ch <= 'f'); +} + +// Converts a lowercase hexadecimal digit to its integer value. +int HexDigitToInt(char ch) { + if ('0' <= ch && ch <= '9') { + return ch - '0'; + } + return ch - 'a' + 10; +} + +// |in| is a string consisting of lowercase hexadecimal digits, where +// every two digits represent one byte. |out| is a buffer of size |max_len|. +// Converts |in| to bytes and stores the bytes in the |out| buffer. The +// number of bytes converted is returned in |*out_len|. Returns true on +// success, false on failure. +bool DecodeHexString(const char* in, char* out, size_t* out_len, + size_t max_len) { + if (!in) { + *out_len = static_cast(-1); + return true; + } + *out_len = 0; + while (*in != '\0') { + if (!IsHexDigit(*in) || !IsHexDigit(*(in + 1))) { + return false; + } + if (*out_len >= max_len) { + return false; + } + out[*out_len] = HexDigitToInt(*in) * 16 + HexDigitToInt(*(in + 1)); + (*out_len)++; + in += 2; + } + return true; +} + +} // namespace + +class ChannelIDTest : public QuicTest {}; + +// A known answer test for ChannelIDVerifier. +TEST_F(ChannelIDTest, VerifyKnownAnswerTest) { + char msg[1024]; + size_t msg_len; + char key[64]; + size_t qx_len; + size_t qy_len; + char signature[64]; + size_t r_len; + size_t s_len; + + for (size_t i = 0; test_vector[i].msg != nullptr; i++) { + SCOPED_TRACE(i); + // Decode the test vector. + ASSERT_TRUE( + DecodeHexString(test_vector[i].msg, msg, &msg_len, sizeof(msg))); + ASSERT_TRUE(DecodeHexString(test_vector[i].qx, key, &qx_len, sizeof(key))); + ASSERT_TRUE(DecodeHexString(test_vector[i].qy, key + qx_len, &qy_len, + sizeof(key) - qx_len)); + ASSERT_TRUE(DecodeHexString(test_vector[i].r, signature, &r_len, + sizeof(signature))); + ASSERT_TRUE(DecodeHexString(test_vector[i].s, signature + r_len, &s_len, + sizeof(signature) - r_len)); + + // The test vector's lengths should look sane. + EXPECT_EQ(sizeof(key) / 2, qx_len); + EXPECT_EQ(sizeof(key) / 2, qy_len); + EXPECT_EQ(sizeof(signature) / 2, r_len); + EXPECT_EQ(sizeof(signature) / 2, s_len); + + EXPECT_EQ(test_vector[i].result, + ChannelIDVerifier::VerifyRaw( + absl::string_view(key, sizeof(key)), + absl::string_view(msg, msg_len), + absl::string_view(signature, sizeof(signature)), false)); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/client_proof_source.cc b/quiche/quic/core/crypto/client_proof_source.cc new file mode 100644 index 000000000000..9d4795ca5d36 --- /dev/null +++ b/quiche/quic/core/crypto/client_proof_source.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/client_proof_source.h" + +#include "absl/strings/match.h" + +namespace quic { + +bool DefaultClientProofSource::AddCertAndKey( + std::vector server_hostnames, + quiche::QuicheReferenceCountedPointer chain, + CertificatePrivateKey private_key) { + if (!ValidateCertAndKey(chain, private_key)) { + return false; + } + + auto cert_and_key = + std::make_shared(std::move(chain), std::move(private_key)); + for (const std::string& domain : server_hostnames) { + cert_and_keys_[domain] = cert_and_key; + } + return true; +} + +const ClientProofSource::CertAndKey* DefaultClientProofSource::GetCertAndKey( + absl::string_view hostname) const { + const CertAndKey* result = LookupExact(hostname); + if (result != nullptr || hostname == "*") { + return result; + } + + // Either a full or a wildcard domain lookup failed. In the former case, + // derive the wildcard domain and look it up. + if (hostname.size() > 1 && !absl::StartsWith(hostname, "*.")) { + auto dot_pos = hostname.find('.'); + if (dot_pos != std::string::npos) { + std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos)); + const CertAndKey* result = LookupExact(wildcard); + if (result != nullptr) { + return result; + } + } + } + + // Return default cert, if any. + return LookupExact("*"); +} + +const ClientProofSource::CertAndKey* DefaultClientProofSource::LookupExact( + absl::string_view map_key) const { + const auto it = cert_and_keys_.find(map_key); + QUIC_DVLOG(1) << "LookupExact(" << map_key + << ") found:" << (it != cert_and_keys_.end()); + if (it != cert_and_keys_.end()) { + return it->second.get(); + } + return nullptr; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/client_proof_source.h b/quiche/quic/core/crypto/client_proof_source.h new file mode 100644 index 000000000000..d1450f7bedb9 --- /dev/null +++ b/quiche/quic/core/crypto/client_proof_source.h @@ -0,0 +1,70 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CLIENT_PROOF_SOURCE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CLIENT_PROOF_SOURCE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/proof_source.h" + +namespace quic { + +// ClientProofSource is the interface for a QUIC client to provide client certs +// and keys based on server hostname. It is only used by TLS handshakes. +class QUIC_EXPORT_PRIVATE ClientProofSource { + public: + using Chain = ProofSource::Chain; + + virtual ~ClientProofSource() {} + + struct QUIC_EXPORT_PRIVATE CertAndKey { + CertAndKey(quiche::QuicheReferenceCountedPointer chain, + CertificatePrivateKey private_key) + : chain(std::move(chain)), private_key(std::move(private_key)) {} + + quiche::QuicheReferenceCountedPointer chain; + CertificatePrivateKey private_key; + }; + + // Get the client certificate to be sent to the server with |server_hostname| + // and its corresponding private key. It returns nullptr if the cert and key + // can not be found. + // + // |server_hostname| is typically a full domain name(www.foo.com), but it + // could also be a wildcard domain(*.foo.com), or a "*" which will return the + // default cert. + virtual const CertAndKey* GetCertAndKey( + absl::string_view server_hostname) const = 0; +}; + +// DefaultClientProofSource is an implementation that simply keeps an in memory +// map of server hostnames to certs. +class QUIC_EXPORT_PRIVATE DefaultClientProofSource : public ClientProofSource { + public: + ~DefaultClientProofSource() override {} + + // Associate all hostnames in |server_hostnames| with {|chain|,|private_key|}. + // Elements of |server_hostnames| can be full domain names(www.foo.com), + // wildcard domains(*.foo.com), or "*" which means the given cert chain is the + // default one. + // If any element of |server_hostnames| is already associated with a cert + // chain, it will be updated to be associated with the new cert chain. + bool AddCertAndKey(std::vector server_hostnames, + quiche::QuicheReferenceCountedPointer chain, + CertificatePrivateKey private_key); + + // ClientProofSource implementation + const CertAndKey* GetCertAndKey(absl::string_view hostname) const override; + + private: + const CertAndKey* LookupExact(absl::string_view map_key) const; + absl::flat_hash_map> cert_and_keys_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CLIENT_PROOF_SOURCE_H_ diff --git a/quiche/quic/core/crypto/client_proof_source_test.cc b/quiche/quic/core/crypto/client_proof_source_test.cc new file mode 100644 index 000000000000..a35e0aa87a27 --- /dev/null +++ b/quiche/quic/core/crypto/client_proof_source_test.cc @@ -0,0 +1,215 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/client_proof_source.h" + +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/test_certificates.h" + +namespace quic { +namespace test { + +quiche::QuicheReferenceCountedPointer +TestCertChain() { + return quiche::QuicheReferenceCountedPointer( + new ClientProofSource::Chain({std::string(kTestCertificate)})); +} + +CertificatePrivateKey TestPrivateKey() { + CBS private_key_cbs; + CBS_init(&private_key_cbs, + reinterpret_cast(kTestCertificatePrivateKey.data()), + kTestCertificatePrivateKey.size()); + + return CertificatePrivateKey( + bssl::UniquePtr(EVP_parse_private_key(&private_key_cbs))); +} + +const ClientProofSource::CertAndKey* TestCertAndKey() { + static const ClientProofSource::CertAndKey cert_and_key(TestCertChain(), + TestPrivateKey()); + return &cert_and_key; +} + +quiche::QuicheReferenceCountedPointer +NullCertChain() { + return quiche::QuicheReferenceCountedPointer(); +} + +quiche::QuicheReferenceCountedPointer +EmptyCertChain() { + return quiche::QuicheReferenceCountedPointer( + new ClientProofSource::Chain(std::vector())); +} + +quiche::QuicheReferenceCountedPointer BadCertChain() { + return quiche::QuicheReferenceCountedPointer( + new ClientProofSource::Chain({"This is the content of a bad cert."})); +} + +CertificatePrivateKey EmptyPrivateKey() { + return CertificatePrivateKey(bssl::UniquePtr(EVP_PKEY_new())); +} + +#define VERIFY_CERT_AND_KEY_MATCHES(lhs, rhs) \ + do { \ + SCOPED_TRACE(testing::Message()); \ + VerifyCertAndKeyMatches(lhs, rhs); \ + } while (0) + +void VerifyCertAndKeyMatches(const ClientProofSource::CertAndKey* lhs, + const ClientProofSource::CertAndKey* rhs) { + if (lhs == rhs) { + return; + } + + if (lhs == nullptr) { + ADD_FAILURE() << "lhs is nullptr, but rhs is not"; + return; + } + + if (rhs == nullptr) { + ADD_FAILURE() << "rhs is nullptr, but lhs is not"; + return; + } + + if (1 != EVP_PKEY_cmp(lhs->private_key.private_key(), + rhs->private_key.private_key())) { + ADD_FAILURE() << "Private keys mismatch"; + return; + } + + const ClientProofSource::Chain* lhs_chain = lhs->chain.get(); + const ClientProofSource::Chain* rhs_chain = rhs->chain.get(); + + if (lhs_chain == rhs_chain) { + return; + } + + if (lhs_chain == nullptr) { + ADD_FAILURE() << "lhs->chain is nullptr, but rhs->chain is not"; + return; + } + + if (rhs_chain == nullptr) { + ADD_FAILURE() << "rhs->chain is nullptr, but lhs->chain is not"; + return; + } + + if (lhs_chain->certs.size() != rhs_chain->certs.size()) { + ADD_FAILURE() << "Cert chain length differ. lhs:" << lhs_chain->certs.size() + << ", rhs:" << rhs_chain->certs.size(); + return; + } + + for (size_t i = 0; i < lhs_chain->certs.size(); ++i) { + if (lhs_chain->certs[i] != rhs_chain->certs[i]) { + ADD_FAILURE() << "The " << i << "-th certs differ."; + return; + } + } + + // All good. +} + +TEST(DefaultClientProofSource, FullDomain) { + DefaultClientProofSource proof_source; + ASSERT_TRUE(proof_source.AddCertAndKey({"www.google.com"}, TestCertChain(), + TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + EXPECT_EQ(proof_source.GetCertAndKey("*.google.com"), nullptr); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, WildcardDomain) { + DefaultClientProofSource proof_source; + ASSERT_TRUE(proof_source.AddCertAndKey({"*.google.com"}, TestCertChain(), + TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*.google.com"), + TestCertAndKey()); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, DefaultDomain) { + DefaultClientProofSource proof_source; + ASSERT_TRUE( + proof_source.AddCertAndKey({"*"}, TestCertChain(), TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*"), + TestCertAndKey()); +} + +TEST(DefaultClientProofSource, FullAndWildcard) { + DefaultClientProofSource proof_source; + ASSERT_TRUE(proof_source.AddCertAndKey({"www.google.com", "*.google.com"}, + TestCertChain(), TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("foo.google.com"), + TestCertAndKey()); + EXPECT_EQ(proof_source.GetCertAndKey("www.example.com"), nullptr); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, FullWildcardAndDefault) { + DefaultClientProofSource proof_source; + ASSERT_TRUE( + proof_source.AddCertAndKey({"www.google.com", "*.google.com", "*"}, + TestCertChain(), TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("foo.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.example.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*"), + TestCertAndKey()); +} + +TEST(DefaultClientProofSource, EmptyCerts) { + DefaultClientProofSource proof_source; + bool ok; + EXPECT_QUIC_BUG( + ok = proof_source.AddCertAndKey({"*"}, NullCertChain(), TestPrivateKey()), + "Certificate chain is empty"); + ASSERT_FALSE(ok); + + EXPECT_QUIC_BUG(ok = proof_source.AddCertAndKey({"*"}, EmptyCertChain(), + TestPrivateKey()), + "Certificate chain is empty"); + ASSERT_FALSE(ok); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, BadCerts) { + DefaultClientProofSource proof_source; + bool ok; + EXPECT_QUIC_BUG( + ok = proof_source.AddCertAndKey({"*"}, BadCertChain(), TestPrivateKey()), + "Unabled to parse leaf certificate"); + ASSERT_FALSE(ok); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, KeyMismatch) { + DefaultClientProofSource proof_source; + bool ok; + EXPECT_QUIC_BUG(ok = proof_source.AddCertAndKey( + {"www.google.com"}, TestCertChain(), EmptyPrivateKey()), + "Private key does not match the leaf certificate"); + ASSERT_FALSE(ok); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_framer.cc b/quiche/quic/core/crypto/crypto_framer.cc new file mode 100644 index 000000000000..57a949ed6d7f --- /dev/null +++ b/quiche/quic/core/crypto/crypto_framer.cc @@ -0,0 +1,351 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_framer.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +namespace { + +const size_t kQuicTagSize = sizeof(QuicTag); +const size_t kCryptoEndOffsetSize = sizeof(uint32_t); +const size_t kNumEntriesSize = sizeof(uint16_t); + +// OneShotVisitor is a framer visitor that records a single handshake message. +class OneShotVisitor : public CryptoFramerVisitorInterface { + public: + OneShotVisitor() : error_(false) {} + + void OnError(CryptoFramer* /*framer*/) override { error_ = true; } + + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override { + out_ = std::make_unique(message); + } + + bool error() const { return error_; } + + std::unique_ptr release() { return std::move(out_); } + + private: + std::unique_ptr out_; + bool error_; +}; + +} // namespace + +CryptoFramer::CryptoFramer() + : visitor_(nullptr), + error_detail_(""), + num_entries_(0), + values_len_(0), + process_truncated_messages_(false) { + Clear(); +} + +CryptoFramer::~CryptoFramer() {} + +// static +std::unique_ptr CryptoFramer::ParseMessage( + absl::string_view in) { + OneShotVisitor visitor; + CryptoFramer framer; + + framer.set_visitor(&visitor); + if (!framer.ProcessInput(in) || visitor.error() || + framer.InputBytesRemaining()) { + return nullptr; + } + + return visitor.release(); +} + +QuicErrorCode CryptoFramer::error() const { return error_; } + +const std::string& CryptoFramer::error_detail() const { return error_detail_; } + +bool CryptoFramer::ProcessInput(absl::string_view input, + EncryptionLevel /*level*/) { + return ProcessInput(input); +} + +bool CryptoFramer::ProcessInput(absl::string_view input) { + QUICHE_DCHECK_EQ(QUIC_NO_ERROR, error_); + if (error_ != QUIC_NO_ERROR) { + return false; + } + error_ = Process(input); + if (error_ != QUIC_NO_ERROR) { + QUICHE_DCHECK(!error_detail_.empty()); + visitor_->OnError(this); + return false; + } + + return true; +} + +size_t CryptoFramer::InputBytesRemaining() const { return buffer_.length(); } + +bool CryptoFramer::HasTag(QuicTag tag) const { + if (state_ != STATE_READING_VALUES) { + return false; + } + for (const auto& it : tags_and_lengths_) { + if (it.first == tag) { + return true; + } + } + return false; +} + +void CryptoFramer::ForceHandshake() { + QuicDataReader reader(buffer_.data(), buffer_.length(), + quiche::HOST_BYTE_ORDER); + for (const std::pair& item : tags_and_lengths_) { + absl::string_view value; + if (reader.BytesRemaining() < item.second) { + break; + } + reader.ReadStringPiece(&value, item.second); + message_.SetStringPiece(item.first, value); + } + visitor_->OnHandshakeMessage(message_); +} + +// static +std::unique_ptr CryptoFramer::ConstructHandshakeMessage( + const CryptoHandshakeMessage& message) { + size_t num_entries = message.tag_value_map().size(); + size_t pad_length = 0; + bool need_pad_tag = false; + bool need_pad_value = false; + + size_t len = message.size(); + if (len < message.minimum_size()) { + need_pad_tag = true; + need_pad_value = true; + num_entries++; + + size_t delta = message.minimum_size() - len; + const size_t overhead = kQuicTagSize + kCryptoEndOffsetSize; + if (delta > overhead) { + pad_length = delta - overhead; + } + len += overhead + pad_length; + } + + if (num_entries > kMaxEntries) { + return nullptr; + } + + std::unique_ptr buffer(new char[len]); + QuicDataWriter writer(len, buffer.get(), quiche::HOST_BYTE_ORDER); + if (!writer.WriteTag(message.tag())) { + QUICHE_DCHECK(false) << "Failed to write message tag."; + return nullptr; + } + if (!writer.WriteUInt16(static_cast(num_entries))) { + QUICHE_DCHECK(false) << "Failed to write size."; + return nullptr; + } + if (!writer.WriteUInt16(0)) { + QUICHE_DCHECK(false) << "Failed to write padding."; + return nullptr; + } + + uint32_t end_offset = 0; + // Tags and offsets + for (auto it = message.tag_value_map().begin(); + it != message.tag_value_map().end(); ++it) { + if (it->first == kPAD && need_pad_tag) { + // Existing PAD tags are only checked when padding needs to be added + // because parts of the code may need to reserialize received messages + // and those messages may, legitimately include padding. + QUICHE_DCHECK(false) + << "Message needed padding but already contained a PAD tag"; + return nullptr; + } + + if (it->first > kPAD && need_pad_tag) { + need_pad_tag = false; + if (!WritePadTag(&writer, pad_length, &end_offset)) { + return nullptr; + } + } + + if (!writer.WriteTag(it->first)) { + QUICHE_DCHECK(false) << "Failed to write tag."; + return nullptr; + } + end_offset += it->second.length(); + if (!writer.WriteUInt32(end_offset)) { + QUICHE_DCHECK(false) << "Failed to write end offset."; + return nullptr; + } + } + + if (need_pad_tag) { + if (!WritePadTag(&writer, pad_length, &end_offset)) { + return nullptr; + } + } + + // Values + for (auto it = message.tag_value_map().begin(); + it != message.tag_value_map().end(); ++it) { + if (it->first > kPAD && need_pad_value) { + need_pad_value = false; + if (!writer.WriteRepeatedByte('-', pad_length)) { + QUICHE_DCHECK(false) << "Failed to write padding."; + return nullptr; + } + } + + if (!writer.WriteBytes(it->second.data(), it->second.length())) { + QUICHE_DCHECK(false) << "Failed to write value."; + return nullptr; + } + } + + if (need_pad_value) { + if (!writer.WriteRepeatedByte('-', pad_length)) { + QUICHE_DCHECK(false) << "Failed to write padding."; + return nullptr; + } + } + + return std::make_unique(buffer.release(), len, true); +} + +void CryptoFramer::Clear() { + message_.Clear(); + tags_and_lengths_.clear(); + error_ = QUIC_NO_ERROR; + error_detail_ = ""; + state_ = STATE_READING_TAG; +} + +QuicErrorCode CryptoFramer::Process(absl::string_view input) { + // Add this data to the buffer. + buffer_.append(input.data(), input.length()); + QuicDataReader reader(buffer_.data(), buffer_.length(), + quiche::HOST_BYTE_ORDER); + + switch (state_) { + case STATE_READING_TAG: + if (reader.BytesRemaining() < kQuicTagSize) { + break; + } + QuicTag message_tag; + reader.ReadTag(&message_tag); + message_.set_tag(message_tag); + state_ = STATE_READING_NUM_ENTRIES; + ABSL_FALLTHROUGH_INTENDED; + case STATE_READING_NUM_ENTRIES: + if (reader.BytesRemaining() < kNumEntriesSize + sizeof(uint16_t)) { + break; + } + reader.ReadUInt16(&num_entries_); + if (num_entries_ > kMaxEntries) { + error_detail_ = absl::StrCat(num_entries_, " entries"); + return QUIC_CRYPTO_TOO_MANY_ENTRIES; + } + uint16_t padding; + reader.ReadUInt16(&padding); + + tags_and_lengths_.reserve(num_entries_); + state_ = STATE_READING_TAGS_AND_LENGTHS; + values_len_ = 0; + ABSL_FALLTHROUGH_INTENDED; + case STATE_READING_TAGS_AND_LENGTHS: { + if (reader.BytesRemaining() < + num_entries_ * (kQuicTagSize + kCryptoEndOffsetSize)) { + break; + } + + uint32_t last_end_offset = 0; + for (unsigned i = 0; i < num_entries_; ++i) { + QuicTag tag; + reader.ReadTag(&tag); + if (i > 0 && tag <= tags_and_lengths_[i - 1].first) { + if (tag == tags_and_lengths_[i - 1].first) { + error_detail_ = absl::StrCat("Duplicate tag:", tag); + return QUIC_CRYPTO_DUPLICATE_TAG; + } + error_detail_ = absl::StrCat("Tag ", tag, " out of order"); + return QUIC_CRYPTO_TAGS_OUT_OF_ORDER; + } + + uint32_t end_offset; + reader.ReadUInt32(&end_offset); + + if (end_offset < last_end_offset) { + error_detail_ = + absl::StrCat("End offset: ", end_offset, " vs ", last_end_offset); + return QUIC_CRYPTO_TAGS_OUT_OF_ORDER; + } + tags_and_lengths_.push_back(std::make_pair( + tag, static_cast(end_offset - last_end_offset))); + last_end_offset = end_offset; + } + values_len_ = last_end_offset; + state_ = STATE_READING_VALUES; + ABSL_FALLTHROUGH_INTENDED; + } + case STATE_READING_VALUES: + if (reader.BytesRemaining() < values_len_) { + if (!process_truncated_messages_) { + break; + } + QUIC_LOG(ERROR) << "Trunacted message. Missing " + << values_len_ - reader.BytesRemaining() << " bytes."; + } + for (const std::pair& item : tags_and_lengths_) { + absl::string_view value; + if (!reader.ReadStringPiece(&value, item.second)) { + QUICHE_DCHECK(process_truncated_messages_); + // Store an empty value. + message_.SetStringPiece(item.first, ""); + continue; + } + message_.SetStringPiece(item.first, value); + } + visitor_->OnHandshakeMessage(message_); + Clear(); + state_ = STATE_READING_TAG; + break; + } + // Save any remaining data. + buffer_ = std::string(reader.PeekRemainingPayload()); + return QUIC_NO_ERROR; +} + +// static +bool CryptoFramer::WritePadTag(QuicDataWriter* writer, size_t pad_length, + uint32_t* end_offset) { + if (!writer->WriteTag(kPAD)) { + QUICHE_DCHECK(false) << "Failed to write tag."; + return false; + } + *end_offset += pad_length; + if (!writer->WriteUInt32(*end_offset)) { + QUICHE_DCHECK(false) << "Failed to write end offset."; + return false; + } + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_framer.h b/quiche/quic/core/crypto/crypto_framer.h new file mode 100644 index 000000000000..8d20bdcdae34 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_framer.h @@ -0,0 +1,136 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_FRAMER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_FRAMER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_message_parser.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class CryptoFramer; +class QuicData; +class QuicDataWriter; + +class QUIC_EXPORT_PRIVATE CryptoFramerVisitorInterface { + public: + virtual ~CryptoFramerVisitorInterface() {} + + // Called if an error is detected. + virtual void OnError(CryptoFramer* framer) = 0; + + // Called when a complete handshake message has been parsed. + virtual void OnHandshakeMessage(const CryptoHandshakeMessage& message) = 0; +}; + +// A class for framing the crypto messages that are exchanged in a QUIC +// session. +class QUIC_EXPORT_PRIVATE CryptoFramer : public CryptoMessageParser { + public: + CryptoFramer(); + + ~CryptoFramer() override; + + // ParseMessage parses exactly one message from the given + // absl::string_view. If there is an error, the message is truncated, + // or the message has trailing garbage then nullptr will be returned. + static std::unique_ptr ParseMessage( + absl::string_view in); + + // Set callbacks to be called from the framer. A visitor must be set, or + // else the framer will crash. It is acceptable for the visitor to do + // nothing. If this is called multiple times, only the last visitor + // will be used. |visitor| will be owned by the framer. + void set_visitor(CryptoFramerVisitorInterface* visitor) { + visitor_ = visitor; + } + + QuicErrorCode error() const override; + const std::string& error_detail() const override; + + // Processes input data, which must be delivered in order. Returns + // false if there was an error, and true otherwise. ProcessInput optionally + // takes an EncryptionLevel, but it is ignored. The variant with the + // EncryptionLevel is provided to match the CryptoMessageParser interface. + bool ProcessInput(absl::string_view input, EncryptionLevel level) override; + bool ProcessInput(absl::string_view input); + + // Returns the number of bytes of buffered input data remaining to be + // parsed. + size_t InputBytesRemaining() const override; + + // Checks if the specified tag has been seen. Returns |true| if it + // has, and |false| if it has not or a CHLO has not been seen. + bool HasTag(QuicTag tag) const; + + // Even if the CHLO has not been fully received, force processing of + // the handshake message. This is dangerous and should not be used + // except as a mechanism of last resort. + void ForceHandshake(); + + // Returns a new QuicData owned by the caller that contains a serialized + // |message|, or nullptr if there was an error. + static std::unique_ptr ConstructHandshakeMessage( + const CryptoHandshakeMessage& message); + + // Debug only method which permits processing truncated messages. + void set_process_truncated_messages(bool process_truncated_messages) { + process_truncated_messages_ = process_truncated_messages; + } + + private: + // Clears per-message state. Does not clear the visitor. + void Clear(); + + // Process does does the work of |ProcessInput|, but returns an error code, + // doesn't set error_ and doesn't call |visitor_->OnError()|. + QuicErrorCode Process(absl::string_view input); + + static bool WritePadTag(QuicDataWriter* writer, size_t pad_length, + uint32_t* end_offset); + + // Represents the current state of the parsing state machine. + enum CryptoFramerState { + STATE_READING_TAG, + STATE_READING_NUM_ENTRIES, + STATE_READING_TAGS_AND_LENGTHS, + STATE_READING_VALUES + }; + + // Visitor to invoke when messages are parsed. + CryptoFramerVisitorInterface* visitor_; + // Last error. + QuicErrorCode error_; + // Remaining unparsed data. + std::string buffer_; + // Current state of the parsing. + CryptoFramerState state_; + // The message currently being parsed. + CryptoHandshakeMessage message_; + // The issue which caused |error_| + std::string error_detail_; + // Number of entires in the message currently being parsed. + uint16_t num_entries_; + // tags_and_lengths_ contains the tags that are currently being parsed and + // their lengths. + std::vector> tags_and_lengths_; + // Cumulative length of all values in the message currently being parsed. + size_t values_len_; + // Set to true to allow of processing of truncated messages for debugging. + bool process_truncated_messages_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_FRAMER_H_ diff --git a/quiche/quic/core/crypto/crypto_framer_test.cc b/quiche/quic/core/crypto/crypto_framer_test.cc new file mode 100644 index 000000000000..5f79f640bcad --- /dev/null +++ b/quiche/quic/core/crypto/crypto_framer_test.cc @@ -0,0 +1,442 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_framer.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { +namespace { + +char* AsChars(unsigned char* data) { return reinterpret_cast(data); } + +class TestCryptoVisitor : public CryptoFramerVisitorInterface { + public: + TestCryptoVisitor() : error_count_(0) {} + + void OnError(CryptoFramer* framer) override { + QUIC_DLOG(ERROR) << "CryptoFramer Error: " << framer->error(); + ++error_count_; + } + + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override { + messages_.push_back(message); + } + + // Counters from the visitor callbacks. + int error_count_; + + std::vector messages_; +}; + +TEST(CryptoFramerTest, ConstructHandshakeMessage) { + CryptoHandshakeMessage message; + message.set_tag(0xFFAA7733); + message.SetStringPiece(0x12345678, "abcdef"); + message.SetStringPiece(0x12345679, "ghijk"); + message.SetStringPiece(0x1234567A, "lmnopqr"); + + unsigned char packet[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x03, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x06, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x0b, 0x00, 0x00, 0x00, + // tag 3 + 0x7A, 0x56, 0x34, 0x12, + // end offset 3 + 0x12, 0x00, 0x00, 0x00, + // value 1 + 'a', 'b', 'c', 'd', 'e', 'f', + // value 2 + 'g', 'h', 'i', 'j', 'k', + // value 3 + 'l', 'm', 'n', 'o', 'p', 'q', 'r'}; + + CryptoFramer framer; + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST(CryptoFramerTest, ConstructHandshakeMessageWithTwoKeys) { + CryptoHandshakeMessage message; + message.set_tag(0xFFAA7733); + message.SetStringPiece(0x12345678, "abcdef"); + message.SetStringPiece(0x12345679, "ghijk"); + + unsigned char packet[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x06, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x0b, 0x00, 0x00, 0x00, + // value 1 + 'a', 'b', 'c', 'd', 'e', 'f', + // value 2 + 'g', 'h', 'i', 'j', 'k'}; + + CryptoFramer framer; + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST(CryptoFramerTest, ConstructHandshakeMessageZeroLength) { + CryptoHandshakeMessage message; + message.set_tag(0xFFAA7733); + message.SetStringPiece(0x12345678, ""); + + unsigned char packet[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x01, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x00, 0x00, 0x00, 0x00}; + + CryptoFramer framer; + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST(CryptoFramerTest, ConstructHandshakeMessageTooManyEntries) { + CryptoHandshakeMessage message; + message.set_tag(0xFFAA7733); + for (uint32_t key = 1; key <= kMaxEntries + 1; ++key) { + message.SetStringPiece(key, "abcdef"); + } + + CryptoFramer framer; + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + EXPECT_TRUE(data == nullptr); +} + +TEST(CryptoFramerTest, ConstructHandshakeMessageMinimumSize) { + CryptoHandshakeMessage message; + message.set_tag(0xFFAA7733); + message.SetStringPiece(0x01020304, "test"); + message.set_minimum_size(64); + + unsigned char packet[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 'P', 'A', 'D', 0, + // end offset 1 + 0x24, 0x00, 0x00, 0x00, + // tag 2 + 0x04, 0x03, 0x02, 0x01, + // end offset 2 + 0x28, 0x00, 0x00, 0x00, + // 36 bytes of padding. + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', + '-', '-', '-', '-', '-', '-', + // value 2 + 't', 'e', 's', 't'}; + + CryptoFramer framer; + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST(CryptoFramerTest, ConstructHandshakeMessageMinimumSizePadLast) { + CryptoHandshakeMessage message; + message.set_tag(0xFFAA7733); + message.SetStringPiece(1, ""); + message.set_minimum_size(64); + + unsigned char packet[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x01, 0x00, 0x00, 0x00, + // end offset 1 + 0x00, 0x00, 0x00, 0x00, + // tag 2 + 'P', 'A', 'D', 0, + // end offset 2 + 0x28, 0x00, 0x00, 0x00, + // 40 bytes of padding. + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', + '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'}; + + CryptoFramer framer; + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST(CryptoFramerTest, ProcessInput) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x06, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x0b, 0x00, 0x00, 0x00, + // value 1 + 'a', 'b', 'c', 'd', 'e', 'f', + // value 2 + 'g', 'h', 'i', 'j', 'k'}; + + EXPECT_TRUE(framer.ProcessInput( + absl::string_view(AsChars(input), ABSL_ARRAYSIZE(input)))); + EXPECT_EQ(0u, framer.InputBytesRemaining()); + EXPECT_EQ(0, visitor.error_count_); + ASSERT_EQ(1u, visitor.messages_.size()); + const CryptoHandshakeMessage& message = visitor.messages_[0]; + EXPECT_EQ(0xFFAA7733, message.tag()); + EXPECT_EQ(2u, message.tag_value_map().size()); + EXPECT_EQ("abcdef", crypto_test_utils::GetValueForTag(message, 0x12345678)); + EXPECT_EQ("ghijk", crypto_test_utils::GetValueForTag(message, 0x12345679)); +} + +TEST(CryptoFramerTest, ProcessInputWithThreeKeys) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x03, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x06, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x0b, 0x00, 0x00, 0x00, + // tag 3 + 0x7A, 0x56, 0x34, 0x12, + // end offset 3 + 0x12, 0x00, 0x00, 0x00, + // value 1 + 'a', 'b', 'c', 'd', 'e', 'f', + // value 2 + 'g', 'h', 'i', 'j', 'k', + // value 3 + 'l', 'm', 'n', 'o', 'p', 'q', 'r'}; + + EXPECT_TRUE(framer.ProcessInput( + absl::string_view(AsChars(input), ABSL_ARRAYSIZE(input)))); + EXPECT_EQ(0u, framer.InputBytesRemaining()); + EXPECT_EQ(0, visitor.error_count_); + ASSERT_EQ(1u, visitor.messages_.size()); + const CryptoHandshakeMessage& message = visitor.messages_[0]; + EXPECT_EQ(0xFFAA7733, message.tag()); + EXPECT_EQ(3u, message.tag_value_map().size()); + EXPECT_EQ("abcdef", crypto_test_utils::GetValueForTag(message, 0x12345678)); + EXPECT_EQ("ghijk", crypto_test_utils::GetValueForTag(message, 0x12345679)); + EXPECT_EQ("lmnopqr", crypto_test_utils::GetValueForTag(message, 0x1234567A)); +} + +TEST(CryptoFramerTest, ProcessInputIncrementally) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x06, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x0b, 0x00, 0x00, 0x00, + // value 1 + 'a', 'b', 'c', 'd', 'e', 'f', + // value 2 + 'g', 'h', 'i', 'j', 'k'}; + + for (size_t i = 0; i < ABSL_ARRAYSIZE(input); i++) { + EXPECT_TRUE(framer.ProcessInput(absl::string_view(AsChars(input) + i, 1))); + } + EXPECT_EQ(0u, framer.InputBytesRemaining()); + ASSERT_EQ(1u, visitor.messages_.size()); + const CryptoHandshakeMessage& message = visitor.messages_[0]; + EXPECT_EQ(0xFFAA7733, message.tag()); + EXPECT_EQ(2u, message.tag_value_map().size()); + EXPECT_EQ("abcdef", crypto_test_utils::GetValueForTag(message, 0x12345678)); + EXPECT_EQ("ghijk", crypto_test_utils::GetValueForTag(message, 0x12345679)); +} + +TEST(CryptoFramerTest, ProcessInputTagsOutOfOrder) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x13, + // end offset 1 + 0x01, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x02, 0x00, 0x00, 0x00}; + + EXPECT_FALSE(framer.ProcessInput( + absl::string_view(AsChars(input), ABSL_ARRAYSIZE(input)))); + EXPECT_THAT(framer.error(), IsError(QUIC_CRYPTO_TAGS_OUT_OF_ORDER)); + EXPECT_EQ(1, visitor.error_count_); +} + +TEST(CryptoFramerTest, ProcessEndOffsetsOutOfOrder) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x79, 0x56, 0x34, 0x12, + // end offset 1 + 0x01, 0x00, 0x00, 0x00, + // tag 2 + 0x78, 0x56, 0x34, 0x13, + // end offset 2 + 0x00, 0x00, 0x00, 0x00}; + + EXPECT_FALSE(framer.ProcessInput( + absl::string_view(AsChars(input), ABSL_ARRAYSIZE(input)))); + EXPECT_THAT(framer.error(), IsError(QUIC_CRYPTO_TAGS_OUT_OF_ORDER)); + EXPECT_EQ(1, visitor.error_count_); +} + +TEST(CryptoFramerTest, ProcessInputTooManyEntries) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0xA0, 0x00, + // padding + 0x00, 0x00}; + + EXPECT_FALSE(framer.ProcessInput( + absl::string_view(AsChars(input), ABSL_ARRAYSIZE(input)))); + EXPECT_THAT(framer.error(), IsError(QUIC_CRYPTO_TOO_MANY_ENTRIES)); + EXPECT_EQ(1, visitor.error_count_); +} + +TEST(CryptoFramerTest, ProcessInputZeroLength) { + test::TestCryptoVisitor visitor; + CryptoFramer framer; + framer.set_visitor(&visitor); + + unsigned char input[] = {// tag + 0x33, 0x77, 0xAA, 0xFF, + // num entries + 0x02, 0x00, + // padding + 0x00, 0x00, + // tag 1 + 0x78, 0x56, 0x34, 0x12, + // end offset 1 + 0x00, 0x00, 0x00, 0x00, + // tag 2 + 0x79, 0x56, 0x34, 0x12, + // end offset 2 + 0x05, 0x00, 0x00, 0x00}; + + EXPECT_TRUE(framer.ProcessInput( + absl::string_view(AsChars(input), ABSL_ARRAYSIZE(input)))); + EXPECT_EQ(0, visitor.error_count_); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_handshake.cc b/quiche/quic/core/crypto/crypto_handshake.cc new file mode 100644 index 000000000000..bc17a97de513 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_handshake.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_handshake.h" + +#include "quiche/quic/core/crypto/key_exchange.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" + +namespace quic { + +QuicCryptoNegotiatedParameters::QuicCryptoNegotiatedParameters() + : key_exchange(0), + aead(0), + token_binding_key_param(0), + sct_supported_by_client(false) {} + +QuicCryptoNegotiatedParameters::~QuicCryptoNegotiatedParameters() {} + +CrypterPair::CrypterPair() {} + +CrypterPair::~CrypterPair() {} + +// static +const char QuicCryptoConfig::kInitialLabel[] = "QUIC key expansion"; + +// static +const char QuicCryptoConfig::kCETVLabel[] = "QUIC CETV block"; + +// static +const char QuicCryptoConfig::kForwardSecureLabel[] = + "QUIC forward secure key expansion"; + +QuicCryptoConfig::QuicCryptoConfig() = default; + +QuicCryptoConfig::~QuicCryptoConfig() = default; + +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_handshake.h b/quiche/quic/core/crypto/crypto_handshake.h new file mode 100644 index 000000000000..6a4b274f8104 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_handshake.h @@ -0,0 +1,190 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_HANDSHAKE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_HANDSHAKE_H_ + +#include +#include +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class SynchronousKeyExchange; +class QuicDecrypter; +class QuicEncrypter; + +// HandshakeFailureReason enum values are uploaded to UMA, they cannot be +// changed. +enum HandshakeFailureReason { + HANDSHAKE_OK = 0, + + // Failure reasons for an invalid client nonce in CHLO. + // + // The default error value for nonce verification failures from strike + // register (covers old strike registers and unknown failures). + CLIENT_NONCE_UNKNOWN_FAILURE = 1, + // Client nonce had incorrect length. + CLIENT_NONCE_INVALID_FAILURE = 2, + // Client nonce is not unique. + CLIENT_NONCE_NOT_UNIQUE_FAILURE = 3, + // Client orbit is invalid or incorrect. + CLIENT_NONCE_INVALID_ORBIT_FAILURE = 4, + // Client nonce's timestamp is not in the strike register's valid time range. + CLIENT_NONCE_INVALID_TIME_FAILURE = 5, + // Strike register's RPC call timed out, client nonce couldn't be verified. + CLIENT_NONCE_STRIKE_REGISTER_TIMEOUT = 6, + // Strike register is down, client nonce couldn't be verified. + CLIENT_NONCE_STRIKE_REGISTER_FAILURE = 7, + + // Failure reasons for an invalid server nonce in CHLO. + // + // Unbox of server nonce failed. + SERVER_NONCE_DECRYPTION_FAILURE = 8, + // Decrypted server nonce had incorrect length. + SERVER_NONCE_INVALID_FAILURE = 9, + // Server nonce is not unique. + SERVER_NONCE_NOT_UNIQUE_FAILURE = 10, + // Server nonce's timestamp is not in the strike register's valid time range. + SERVER_NONCE_INVALID_TIME_FAILURE = 11, + // The server requires handshake confirmation. + SERVER_NONCE_REQUIRED_FAILURE = 20, + + // Failure reasons for an invalid server config in CHLO. + // + // Missing Server config id (kSCID) tag. + SERVER_CONFIG_INCHOATE_HELLO_FAILURE = 12, + // Couldn't find the Server config id (kSCID). + SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE = 13, + + // Failure reasons for an invalid source-address token. + // + // Missing Source-address token (kSourceAddressTokenTag) tag. + SOURCE_ADDRESS_TOKEN_INVALID_FAILURE = 14, + // Unbox of Source-address token failed. + SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE = 15, + // Couldn't parse the unbox'ed Source-address token. + SOURCE_ADDRESS_TOKEN_PARSE_FAILURE = 16, + // Source-address token is for a different IP address. + SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE = 17, + // The source-address token has a timestamp in the future. + SOURCE_ADDRESS_TOKEN_CLOCK_SKEW_FAILURE = 18, + // The source-address token has expired. + SOURCE_ADDRESS_TOKEN_EXPIRED_FAILURE = 19, + + // The expected leaf certificate hash could not be validated. + INVALID_EXPECTED_LEAF_CERTIFICATE = 21, + + MAX_FAILURE_REASON = 22, +}; + +// These errors will be packed into an uint32_t and we don't want to set the +// most significant bit, which may be misinterpreted as the sign bit. +static_assert(MAX_FAILURE_REASON <= 32, "failure reason out of sync"); + +// A CrypterPair contains the encrypter and decrypter for an encryption level. +struct QUIC_EXPORT_PRIVATE CrypterPair { + CrypterPair(); + CrypterPair(CrypterPair&&) = default; + ~CrypterPair(); + + std::unique_ptr encrypter; + std::unique_ptr decrypter; +}; + +// Parameters negotiated by the crypto handshake. +struct QUIC_EXPORT_PRIVATE QuicCryptoNegotiatedParameters + : public quiche::QuicheReferenceCounted { + // Initializes the members to 0 or empty values. + QuicCryptoNegotiatedParameters(); + + QuicTag key_exchange; + QuicTag aead; + std::string initial_premaster_secret; + std::string forward_secure_premaster_secret; + // initial_subkey_secret is used as the PRK input to the HKDF used when + // performing key extraction that needs to happen before forward-secure keys + // are available. + std::string initial_subkey_secret; + // subkey_secret is used as the PRK input to the HKDF used for key extraction. + std::string subkey_secret; + CrypterPair initial_crypters; + CrypterPair forward_secure_crypters; + // Normalized SNI: converted to lower case and trailing '.' removed. + std::string sni; + std::string client_nonce; + std::string server_nonce; + // hkdf_input_suffix contains the HKDF input following the label: the + // ConnectionId, client hello and server config. This is only populated in the + // client because only the client needs to derive the forward secure keys at a + // later time from the initial keys. + std::string hkdf_input_suffix; + // cached_certs contains the cached certificates that a client used when + // sending a client hello. + std::vector cached_certs; + // client_key_exchange is used by clients to store the ephemeral KeyExchange + // for the connection. + std::unique_ptr client_key_exchange; + // channel_id is set by servers to a ChannelID key when the client correctly + // proves possession of the corresponding private key. It consists of 32 + // bytes of x coordinate, followed by 32 bytes of y coordinate. Both values + // are big-endian and the pair is a P-256 public key. + std::string channel_id; + QuicTag token_binding_key_param; + + // Used when generating proof signature when sending server config updates. + + // Used to generate cert chain when sending server config updates. + std::string client_cached_cert_hashes; + + // Default to false; set to true if the client indicates that it supports sct + // by sending CSCT tag with an empty value in client hello. + bool sct_supported_by_client; + + // Parameters only populated for TLS handshakes. These will be 0 for + // connections not using TLS, or if the TLS handshake is not finished yet. + uint16_t cipher_suite = 0; + uint16_t key_exchange_group = 0; + uint16_t peer_signature_algorithm = 0; + bool encrypted_client_hello = false; + + protected: + ~QuicCryptoNegotiatedParameters() override; +}; + +// QuicCryptoConfig contains common configuration between clients and servers. +class QUIC_EXPORT_PRIVATE QuicCryptoConfig { + public: + // kInitialLabel is a constant that is used when deriving the initial + // (non-forward secure) keys for the connection in order to tie the resulting + // key to this protocol. + static const char kInitialLabel[]; + + // kCETVLabel is a constant that is used when deriving the keys for the + // encrypted tag/value block in the client hello. + static const char kCETVLabel[]; + + // kForwardSecureLabel is a constant that is used when deriving the forward + // secure keys for the connection in order to tie the resulting key to this + // protocol. + static const char kForwardSecureLabel[]; + + QuicCryptoConfig(); + QuicCryptoConfig(const QuicCryptoConfig&) = delete; + QuicCryptoConfig& operator=(const QuicCryptoConfig&) = delete; + ~QuicCryptoConfig(); + + // Key exchange methods. The following two members' values correspond by + // index. + QuicTagVector kexs; + // Authenticated encryption with associated data (AEAD) algorithms. + QuicTagVector aead; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_HANDSHAKE_H_ diff --git a/quiche/quic/core/crypto/crypto_handshake_message.cc b/quiche/quic/core/crypto/crypto_handshake_message.cc new file mode 100644 index 000000000000..9fa4dfa5319c --- /dev/null +++ b/quiche/quic/core/crypto/crypto_handshake_message.cc @@ -0,0 +1,368 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_handshake_message.h" + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/quic_socket_address_coder.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +CryptoHandshakeMessage::CryptoHandshakeMessage() : tag_(0), minimum_size_(0) {} + +CryptoHandshakeMessage::CryptoHandshakeMessage( + const CryptoHandshakeMessage& other) + : tag_(other.tag_), + tag_value_map_(other.tag_value_map_), + minimum_size_(other.minimum_size_) { + // Don't copy serialized_. unique_ptr doesn't have a copy constructor. + // The new object can lazily reconstruct serialized_. +} + +CryptoHandshakeMessage::CryptoHandshakeMessage(CryptoHandshakeMessage&& other) = + default; + +CryptoHandshakeMessage::~CryptoHandshakeMessage() {} + +CryptoHandshakeMessage& CryptoHandshakeMessage::operator=( + const CryptoHandshakeMessage& other) { + tag_ = other.tag_; + tag_value_map_ = other.tag_value_map_; + // Don't copy serialized_. unique_ptr doesn't have an assignment operator. + // However, invalidate serialized_. + serialized_.reset(); + minimum_size_ = other.minimum_size_; + return *this; +} + +CryptoHandshakeMessage& CryptoHandshakeMessage::operator=( + CryptoHandshakeMessage&& other) = default; + +bool CryptoHandshakeMessage::operator==( + const CryptoHandshakeMessage& rhs) const { + return tag_ == rhs.tag_ && tag_value_map_ == rhs.tag_value_map_ && + minimum_size_ == rhs.minimum_size_; +} + +bool CryptoHandshakeMessage::operator!=( + const CryptoHandshakeMessage& rhs) const { + return !(*this == rhs); +} + +void CryptoHandshakeMessage::Clear() { + tag_ = 0; + tag_value_map_.clear(); + minimum_size_ = 0; + serialized_.reset(); +} + +const QuicData& CryptoHandshakeMessage::GetSerialized() const { + if (!serialized_) { + serialized_ = CryptoFramer::ConstructHandshakeMessage(*this); + } + return *serialized_; +} + +void CryptoHandshakeMessage::MarkDirty() { serialized_.reset(); } + +void CryptoHandshakeMessage::SetVersionVector( + QuicTag tag, ParsedQuicVersionVector versions) { + QuicVersionLabelVector version_labels; + for (const ParsedQuicVersion& version : versions) { + version_labels.push_back( + quiche::QuicheEndian::HostToNet32(CreateQuicVersionLabel(version))); + } + SetVector(tag, version_labels); +} + +void CryptoHandshakeMessage::SetVersion(QuicTag tag, + ParsedQuicVersion version) { + SetValue(tag, + quiche::QuicheEndian::HostToNet32(CreateQuicVersionLabel(version))); +} + +void CryptoHandshakeMessage::SetStringPiece(QuicTag tag, + absl::string_view value) { + tag_value_map_[tag] = std::string(value); +} + +void CryptoHandshakeMessage::Erase(QuicTag tag) { tag_value_map_.erase(tag); } + +QuicErrorCode CryptoHandshakeMessage::GetTaglist( + QuicTag tag, QuicTagVector* out_tags) const { + auto it = tag_value_map_.find(tag); + QuicErrorCode ret = QUIC_NO_ERROR; + + if (it == tag_value_map_.end()) { + ret = QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } else if (it->second.size() % sizeof(QuicTag) != 0) { + ret = QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + if (ret != QUIC_NO_ERROR) { + out_tags->clear(); + return ret; + } + + size_t num_tags = it->second.size() / sizeof(QuicTag); + out_tags->resize(num_tags); + for (size_t i = 0; i < num_tags; ++i) { + QuicTag tag; + memcpy(&tag, it->second.data() + i * sizeof(tag), sizeof(tag)); + (*out_tags)[i] = tag; + } + return ret; +} + +QuicErrorCode CryptoHandshakeMessage::GetVersionLabelList( + QuicTag tag, QuicVersionLabelVector* out) const { + QuicErrorCode error = GetTaglist(tag, out); + if (error != QUIC_NO_ERROR) { + return error; + } + + for (size_t i = 0; i < out->size(); ++i) { + (*out)[i] = quiche::QuicheEndian::HostToNet32((*out)[i]); + } + + return QUIC_NO_ERROR; +} + +QuicErrorCode CryptoHandshakeMessage::GetVersionLabel( + QuicTag tag, QuicVersionLabel* out) const { + QuicErrorCode error = GetUint32(tag, out); + if (error != QUIC_NO_ERROR) { + return error; + } + + *out = quiche::QuicheEndian::HostToNet32(*out); + return QUIC_NO_ERROR; +} + +bool CryptoHandshakeMessage::GetStringPiece(QuicTag tag, + absl::string_view* out) const { + auto it = tag_value_map_.find(tag); + if (it == tag_value_map_.end()) { + return false; + } + *out = it->second; + return true; +} + +bool CryptoHandshakeMessage::HasStringPiece(QuicTag tag) const { + return tag_value_map_.find(tag) != tag_value_map_.end(); +} + +QuicErrorCode CryptoHandshakeMessage::GetNthValue24( + QuicTag tag, unsigned index, absl::string_view* out) const { + absl::string_view value; + if (!GetStringPiece(tag, &value)) { + return QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } + + for (unsigned i = 0;; i++) { + if (value.empty()) { + return QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND; + } + if (value.size() < 3) { + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + const unsigned char* data = + reinterpret_cast(value.data()); + size_t size = static_cast(data[0]) | + (static_cast(data[1]) << 8) | + (static_cast(data[2]) << 16); + value.remove_prefix(3); + + if (value.size() < size) { + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + if (i == index) { + *out = absl::string_view(value.data(), size); + return QUIC_NO_ERROR; + } + + value.remove_prefix(size); + } +} + +QuicErrorCode CryptoHandshakeMessage::GetUint32(QuicTag tag, + uint32_t* out) const { + return GetPOD(tag, out, sizeof(uint32_t)); +} + +QuicErrorCode CryptoHandshakeMessage::GetUint64(QuicTag tag, + uint64_t* out) const { + return GetPOD(tag, out, sizeof(uint64_t)); +} + +QuicErrorCode CryptoHandshakeMessage::GetStatelessResetToken( + QuicTag tag, StatelessResetToken* out) const { + return GetPOD(tag, out, kStatelessResetTokenLength); +} + +size_t CryptoHandshakeMessage::size() const { + size_t ret = sizeof(QuicTag) + sizeof(uint16_t) /* number of entries */ + + sizeof(uint16_t) /* padding */; + ret += (sizeof(QuicTag) + sizeof(uint32_t) /* end offset */) * + tag_value_map_.size(); + for (auto i = tag_value_map_.begin(); i != tag_value_map_.end(); ++i) { + ret += i->second.size(); + } + + return ret; +} + +void CryptoHandshakeMessage::set_minimum_size(size_t min_bytes) { + if (min_bytes == minimum_size_) { + return; + } + serialized_.reset(); + minimum_size_ = min_bytes; +} + +size_t CryptoHandshakeMessage::minimum_size() const { return minimum_size_; } + +std::string CryptoHandshakeMessage::DebugString() const { + return DebugStringInternal(0); +} + +QuicErrorCode CryptoHandshakeMessage::GetPOD(QuicTag tag, void* out, + size_t len) const { + auto it = tag_value_map_.find(tag); + QuicErrorCode ret = QUIC_NO_ERROR; + + if (it == tag_value_map_.end()) { + ret = QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } else if (it->second.size() != len) { + ret = QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + if (ret != QUIC_NO_ERROR) { + memset(out, 0, len); + return ret; + } + + memcpy(out, it->second.data(), len); + return ret; +} + +std::string CryptoHandshakeMessage::DebugStringInternal(size_t indent) const { + std::string ret = + std::string(2 * indent, ' ') + QuicTagToString(tag_) + "<\n"; + ++indent; + for (auto it = tag_value_map_.begin(); it != tag_value_map_.end(); ++it) { + ret += std::string(2 * indent, ' ') + QuicTagToString(it->first) + ": "; + + bool done = false; + switch (it->first) { + case kICSL: + case kCFCW: + case kSFCW: + case kIRTT: + case kMIUS: + case kMIBS: + case kTCID: + case kMAD: + // uint32_t value + if (it->second.size() == 4) { + uint32_t value; + memcpy(&value, it->second.data(), sizeof(value)); + absl::StrAppend(&ret, value); + done = true; + } + break; + case kKEXS: + case kAEAD: + case kCOPT: + case kPDMD: + case kVER: + // tag lists + if (it->second.size() % sizeof(QuicTag) == 0) { + for (size_t j = 0; j < it->second.size(); j += sizeof(QuicTag)) { + QuicTag tag; + memcpy(&tag, it->second.data() + j, sizeof(tag)); + if (j > 0) { + ret += ","; + } + ret += "'" + QuicTagToString(tag) + "'"; + } + done = true; + } + break; + case kRREJ: + // uint32_t lists + if (it->second.size() % sizeof(uint32_t) == 0) { + for (size_t j = 0; j < it->second.size(); j += sizeof(uint32_t)) { + uint32_t value; + memcpy(&value, it->second.data() + j, sizeof(value)); + if (j > 0) { + ret += ","; + } + ret += CryptoUtils::HandshakeFailureReasonToString( + static_cast(value)); + } + done = true; + } + break; + case kCADR: + // IP address and port + if (!it->second.empty()) { + QuicSocketAddressCoder decoder; + if (decoder.Decode(it->second.data(), it->second.size())) { + ret += QuicSocketAddress(decoder.ip(), decoder.port()).ToString(); + done = true; + } + } + break; + case kSCFG: + // nested messages. + if (!it->second.empty()) { + std::unique_ptr msg( + CryptoFramer::ParseMessage(it->second)); + if (msg) { + ret += "\n"; + ret += msg->DebugStringInternal(indent + 1); + + done = true; + } + } + break; + case kPAD: + ret += absl::StrFormat("(%d bytes of padding)", it->second.size()); + done = true; + break; + case kSNI: + case kUAID: + ret += "\"" + it->second + "\""; + done = true; + break; + } + + if (!done) { + // If there's no specific format for this tag, or the value is invalid, + // then just use hex. + ret += "0x" + absl::BytesToHexString(it->second); + } + ret += "\n"; + } + --indent; + ret += std::string(2 * indent, ' ') + ">"; + return ret; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_handshake_message.h b/quiche/quic/core/crypto/crypto_handshake_message.h new file mode 100644 index 000000000000..b50c559d2ea1 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_handshake_message.h @@ -0,0 +1,159 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_HANDSHAKE_MESSAGE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_HANDSHAKE_MESSAGE_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An intermediate format of a handshake message that's convenient for a +// CryptoFramer to serialize from or parse into. +class QUIC_EXPORT_PRIVATE CryptoHandshakeMessage { + public: + CryptoHandshakeMessage(); + CryptoHandshakeMessage(const CryptoHandshakeMessage& other); + CryptoHandshakeMessage(CryptoHandshakeMessage&& other); + ~CryptoHandshakeMessage(); + + CryptoHandshakeMessage& operator=(const CryptoHandshakeMessage& other); + CryptoHandshakeMessage& operator=(CryptoHandshakeMessage&& other); + + bool operator==(const CryptoHandshakeMessage& rhs) const; + bool operator!=(const CryptoHandshakeMessage& rhs) const; + + // Clears state. + void Clear(); + + // GetSerialized returns the serialized form of this message and caches the + // result. Subsequently altering the message does not invalidate the cache. + const QuicData& GetSerialized() const; + + // MarkDirty invalidates the cache created by |GetSerialized|. + void MarkDirty(); + + // SetValue sets an element with the given tag to the raw, memory contents of + // |v|. + template + void SetValue(QuicTag tag, const T& v) { + tag_value_map_[tag] = + std::string(reinterpret_cast(&v), sizeof(v)); + } + + // SetVector sets an element with the given tag to the raw contents of an + // array of elements in |v|. + template + void SetVector(QuicTag tag, const std::vector& v) { + if (v.empty()) { + tag_value_map_[tag] = std::string(); + } else { + tag_value_map_[tag] = std::string(reinterpret_cast(&v[0]), + v.size() * sizeof(T)); + } + } + + // Sets an element with the given tag to the on-the-wire representation of + // |version|. + void SetVersion(QuicTag tag, ParsedQuicVersion version); + + // Sets an element with the given tag to the on-the-wire representation of + // the elements in |versions|. + void SetVersionVector(QuicTag tag, ParsedQuicVersionVector versions); + + // Returns the message tag. + QuicTag tag() const { return tag_; } + // Sets the message tag. + void set_tag(QuicTag tag) { tag_ = tag; } + + const QuicTagValueMap& tag_value_map() const { return tag_value_map_; } + + void SetStringPiece(QuicTag tag, absl::string_view value); + + // Erase removes a tag/value, if present, from the message. + void Erase(QuicTag tag); + + // GetTaglist finds an element with the given tag containing zero or more + // tags. If such a tag doesn't exist, it returns an error code. Otherwise it + // populates |out_tags| with the tags and returns QUIC_NO_ERROR. + QuicErrorCode GetTaglist(QuicTag tag, QuicTagVector* out_tags) const; + + // GetVersionLabelList finds an element with the given tag containing zero or + // more version labels. If such a tag doesn't exist, it returns an error code. + // Otherwise it populates |out| with the labels and returns QUIC_NO_ERROR. + QuicErrorCode GetVersionLabelList(QuicTag tag, + QuicVersionLabelVector* out) const; + + // GetVersionLabel finds an element with the given tag containing a single + // version label. If such a tag doesn't exist, it returns an error code. + // Otherwise it populates |out| with the label and returns QUIC_NO_ERROR. + QuicErrorCode GetVersionLabel(QuicTag tag, QuicVersionLabel* out) const; + + bool GetStringPiece(QuicTag tag, absl::string_view* out) const; + bool HasStringPiece(QuicTag tag) const; + + // GetNthValue24 interprets the value with the given tag to be a series of + // 24-bit, length prefixed values and it returns the subvalue with the given + // index. + QuicErrorCode GetNthValue24(QuicTag tag, unsigned index, + absl::string_view* out) const; + QuicErrorCode GetUint32(QuicTag tag, uint32_t* out) const; + QuicErrorCode GetUint64(QuicTag tag, uint64_t* out) const; + + QuicErrorCode GetStatelessResetToken(QuicTag tag, + StatelessResetToken* out) const; + + // size returns 4 (message tag) + 2 (uint16_t, number of entries) + + // (4 (tag) + 4 (end offset))*tag_value_map_.size() + ∑ value sizes. + size_t size() const; + + // set_minimum_size sets the minimum number of bytes that the message should + // consume. The CryptoFramer will add a PAD tag as needed when serializing in + // order to ensure this. Setting a value of 0 disables padding. + // + // Padding is useful in order to ensure that messages are a minimum size. A + // QUIC server can require a minimum size in order to reduce the + // amplification factor of any mirror DoS attack. + void set_minimum_size(size_t min_bytes); + + size_t minimum_size() const; + + // DebugString returns a multi-line, string representation of the message + // suitable for including in debug output. + std::string DebugString() const; + + private: + // GetPOD is a utility function for extracting a plain-old-data value. If + // |tag| exists in the message, and has a value of exactly |len| bytes then + // it copies |len| bytes of data into |out|. Otherwise |len| bytes at |out| + // are zeroed out. + // + // If used to copy integers then this assumes that the machine is + // little-endian. + QuicErrorCode GetPOD(QuicTag tag, void* out, size_t len) const; + + std::string DebugStringInternal(size_t indent) const; + + QuicTag tag_; + QuicTagValueMap tag_value_map_; + + size_t minimum_size_; + + // The serialized form of the handshake message. This member is constructed + // lazily. + mutable std::unique_ptr serialized_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_HANDSHAKE_MESSAGE_H_ diff --git a/quiche/quic/core/crypto/crypto_handshake_message_test.cc b/quiche/quic/core/crypto/crypto_handshake_message_test.cc new file mode 100644 index 000000000000..bdc051c2bd9a --- /dev/null +++ b/quiche/quic/core/crypto/crypto_handshake_message_test.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_handshake_message.h" + +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { +namespace test { +namespace { + +TEST(CryptoHandshakeMessageTest, DebugString) { + const char* str = "SHLO<\n>"; + + CryptoHandshakeMessage message; + message.set_tag(kSHLO); + EXPECT_EQ(str, message.DebugString()); + + // Test copy + CryptoHandshakeMessage message2(message); + EXPECT_EQ(str, message2.DebugString()); + + // Test move + CryptoHandshakeMessage message3(std::move(message)); + EXPECT_EQ(str, message3.DebugString()); + + // Test assign + CryptoHandshakeMessage message4 = message3; + EXPECT_EQ(str, message4.DebugString()); + + // Test move-assign + CryptoHandshakeMessage message5 = std::move(message3); + EXPECT_EQ(str, message5.DebugString()); +} + +TEST(CryptoHandshakeMessageTest, DebugStringWithUintVector) { + const char* str = + "REJ <\n RREJ: " + "SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE," + "CLIENT_NONCE_NOT_UNIQUE_FAILURE\n>"; + + CryptoHandshakeMessage message; + message.set_tag(kREJ); + std::vector reasons = { + SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE, + CLIENT_NONCE_NOT_UNIQUE_FAILURE}; + message.SetVector(kRREJ, reasons); + EXPECT_EQ(str, message.DebugString()); + + // Test copy + CryptoHandshakeMessage message2(message); + EXPECT_EQ(str, message2.DebugString()); + + // Test move + CryptoHandshakeMessage message3(std::move(message)); + EXPECT_EQ(str, message3.DebugString()); + + // Test assign + CryptoHandshakeMessage message4 = message3; + EXPECT_EQ(str, message4.DebugString()); + + // Test move-assign + CryptoHandshakeMessage message5 = std::move(message3); + EXPECT_EQ(str, message5.DebugString()); +} + +TEST(CryptoHandshakeMessageTest, DebugStringWithTagVector) { + const char* str = "CHLO<\n COPT: 'TBBR','PAD ','BYTE'\n>"; + + CryptoHandshakeMessage message; + message.set_tag(kCHLO); + message.SetVector(kCOPT, QuicTagVector{kTBBR, kPAD, kBYTE}); + EXPECT_EQ(str, message.DebugString()); + + // Test copy + CryptoHandshakeMessage message2(message); + EXPECT_EQ(str, message2.DebugString()); + + // Test move + CryptoHandshakeMessage message3(std::move(message)); + EXPECT_EQ(str, message3.DebugString()); + + // Test assign + CryptoHandshakeMessage message4 = message3; + EXPECT_EQ(str, message4.DebugString()); + + // Test move-assign + CryptoHandshakeMessage message5 = std::move(message3); + EXPECT_EQ(str, message5.DebugString()); +} + +TEST(CryptoHandshakeMessageTest, HasStringPiece) { + CryptoHandshakeMessage message; + EXPECT_FALSE(message.HasStringPiece(kALPN)); + message.SetStringPiece(kALPN, "foo"); + EXPECT_TRUE(message.HasStringPiece(kALPN)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_message_parser.h b/quiche/quic/core/crypto/crypto_message_parser.h new file mode 100644 index 000000000000..f819209bbb91 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_message_parser.h @@ -0,0 +1,35 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_MESSAGE_PARSER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_MESSAGE_PARSER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE CryptoMessageParser { + public: + virtual ~CryptoMessageParser() {} + + virtual QuicErrorCode error() const = 0; + virtual const std::string& error_detail() const = 0; + + // Processes input data, which must be delivered in order. The input data + // being processed was received at encryption level |level|. Returns + // false if there was an error, and true otherwise. + virtual bool ProcessInput(absl::string_view input, EncryptionLevel level) = 0; + + // Returns the number of bytes of buffered input data remaining to be + // parsed. + virtual size_t InputBytesRemaining() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_MESSAGE_PARSER_H_ diff --git a/quiche/quic/core/crypto/crypto_protocol.h b/quiche/quic/core/crypto/crypto_protocol.h new file mode 100644 index 000000000000..31937bc44b78 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_protocol.h @@ -0,0 +1,516 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_PROTOCOL_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_PROTOCOL_H_ + +#include +#include + +#include "quiche/quic/core/quic_tag.h" + +// Version and Crypto tags are written to the wire with a big-endian +// representation of the name of the tag. For example +// the client hello tag (CHLO) will be written as the +// following 4 bytes: 'C' 'H' 'L' 'O'. Since it is +// stored in memory as a little endian uint32_t, we need +// to reverse the order of the bytes. +// +// We use a macro to ensure that no static initialisers are created. Use the +// MakeQuicTag function in normal code. +#define TAG(a, b, c, d) \ + static_cast((d << 24) + (c << 16) + (b << 8) + a) + +namespace quic { + +using ServerConfigID = std::string; + +// The following tags have been deprecated and should not be reused: +// "1CON", "BBQ4", "NCON", "RCID", "SREJ", "TBKP", "TB10", "SCLS", "SMHL", +// "QNZR", "B2HI", "H2PR", "FIFO", "LIFO", "RRWS", "QNSP", "B2CL", "CHSP", +// "BPTE", "ACKD", "AKD2", "AKD4", "MAD1", "MAD4", "MAD5", "ACD0", "ACKQ", +// "TLPR", "CCS\0", "PDP4", "NCHP", "NBPE", "2RTO", "3RTO", "4RTO", "6RTO", +// "PDP1", "PDP2", "PDP3", "PDP5" "QLVE" + +// clang-format off +const QuicTag kCHLO = TAG('C', 'H', 'L', 'O'); // Client hello +const QuicTag kSHLO = TAG('S', 'H', 'L', 'O'); // Server hello +const QuicTag kSCFG = TAG('S', 'C', 'F', 'G'); // Server config +const QuicTag kREJ = TAG('R', 'E', 'J', '\0'); // Reject +const QuicTag kCETV = TAG('C', 'E', 'T', 'V'); // Client encrypted tag-value + // pairs +const QuicTag kPRST = TAG('P', 'R', 'S', 'T'); // Public reset +const QuicTag kSCUP = TAG('S', 'C', 'U', 'P'); // Server config update +const QuicTag kALPN = TAG('A', 'L', 'P', 'N'); // Application-layer protocol + +// Key exchange methods +const QuicTag kP256 = TAG('P', '2', '5', '6'); // ECDH, Curve P-256 +const QuicTag kC255 = TAG('C', '2', '5', '5'); // ECDH, Curve25519 + +// AEAD algorithms +const QuicTag kAESG = TAG('A', 'E', 'S', 'G'); // AES128 + GCM-12 +const QuicTag kCC20 = TAG('C', 'C', '2', '0'); // ChaCha20 + Poly1305 RFC7539 + +// Congestion control feedback types +const QuicTag kQBIC = TAG('Q', 'B', 'I', 'C'); // TCP cubic + +// Connection options (COPT) values +const QuicTag kAFCW = TAG('A', 'F', 'C', 'W'); // Auto-tune flow control + // receive windows. +const QuicTag kIFW5 = TAG('I', 'F', 'W', '5'); // Set initial size + // of stream flow control + // receive window to + // 32KB. (2^5 KB). +const QuicTag kIFW6 = TAG('I', 'F', 'W', '6'); // Set initial size + // of stream flow control + // receive window to + // 64KB. (2^6 KB). +const QuicTag kIFW7 = TAG('I', 'F', 'W', '7'); // Set initial size + // of stream flow control + // receive window to + // 128KB. (2^7 KB). +const QuicTag kIFW8 = TAG('I', 'F', 'W', '8'); // Set initial size + // of stream flow control + // receive window to + // 256KB. (2^8 KB). +const QuicTag kIFW9 = TAG('I', 'F', 'W', '9'); // Set initial size + // of stream flow control + // receive window to + // 512KB. (2^9 KB). +const QuicTag kIFWA = TAG('I', 'F', 'W', 'a'); // Set initial size + // of stream flow control + // receive window to + // 1MB. (2^0xa KB). +const QuicTag kTBBR = TAG('T', 'B', 'B', 'R'); // Reduced Buffer Bloat TCP +const QuicTag k1RTT = TAG('1', 'R', 'T', 'T'); // STARTUP in BBR for 1 RTT +const QuicTag k2RTT = TAG('2', 'R', 'T', 'T'); // STARTUP in BBR for 2 RTTs +const QuicTag kLRTT = TAG('L', 'R', 'T', 'T'); // Exit STARTUP in BBR on loss +const QuicTag kBBS1 = TAG('B', 'B', 'S', '1'); // DEPRECATED +const QuicTag kBBS2 = TAG('B', 'B', 'S', '2'); // More aggressive packet + // conservation in BBR STARTUP +const QuicTag kBBS3 = TAG('B', 'B', 'S', '3'); // Slowstart packet + // conservation in BBR STARTUP +const QuicTag kBBS4 = TAG('B', 'B', 'S', '4'); // DEPRECATED +const QuicTag kBBS5 = TAG('B', 'B', 'S', '5'); // DEPRECATED +const QuicTag kBBRR = TAG('B', 'B', 'R', 'R'); // Rate-based recovery in BBR +const QuicTag kBBR1 = TAG('B', 'B', 'R', '1'); // DEPRECATED +const QuicTag kBBR2 = TAG('B', 'B', 'R', '2'); // DEPRECATED +const QuicTag kBBR3 = TAG('B', 'B', 'R', '3'); // Fully drain the queue once + // per cycle +const QuicTag kBBR4 = TAG('B', 'B', 'R', '4'); // 20 RTT ack aggregation +const QuicTag kBBR5 = TAG('B', 'B', 'R', '5'); // 40 RTT ack aggregation +const QuicTag kBBR9 = TAG('B', 'B', 'R', '9'); // DEPRECATED +const QuicTag kBBRA = TAG('B', 'B', 'R', 'A'); // Starts a new ack aggregation + // epoch if a full round has + // passed +const QuicTag kBBRB = TAG('B', 'B', 'R', 'B'); // Use send rate in BBR's + // MaxAckHeightTracker +const QuicTag kBBRS = TAG('B', 'B', 'R', 'S'); // DEPRECATED +const QuicTag kBBQ1 = TAG('B', 'B', 'Q', '1'); // DEPRECATED +const QuicTag kBBQ2 = TAG('B', 'B', 'Q', '2'); // BBRv2 with 2.885 STARTUP and + // DRAIN CWND gain. +const QuicTag kBBQ3 = TAG('B', 'B', 'Q', '3'); // BBR with ack aggregation + // compensation in STARTUP. +const QuicTag kBBQ5 = TAG('B', 'B', 'Q', '5'); // Expire ack aggregation upon + // bandwidth increase in + // STARTUP. +const QuicTag kBBQ6 = TAG('B', 'B', 'Q', '6'); // Reduce STARTUP gain to 25% + // more than BW increase. +const QuicTag kBBQ7 = TAG('B', 'B', 'Q', '7'); // Reduce bw_lo by + // bytes_lost/min_rtt. +const QuicTag kBBQ8 = TAG('B', 'B', 'Q', '8'); // Reduce bw_lo by + // bw_lo * bytes_lost/inflight +const QuicTag kBBQ9 = TAG('B', 'B', 'Q', '9'); // Reduce bw_lo by + // bw_lo * bytes_lost/cwnd +const QuicTag kBBQ0 = TAG('B', 'B', 'Q', '0'); // Increase bytes_acked in + // PROBE_UP when app limited. +const QuicTag kBBPD = TAG('B', 'B', 'P', 'D'); // Use 0.91 PROBE_DOWN gain. +const QuicTag kBBHI = TAG('B', 'B', 'H', 'I'); // Increase inflight_hi in + // PROBE_UP if ever inflight_hi + // limited in round +const QuicTag kRENO = TAG('R', 'E', 'N', 'O'); // Reno Congestion Control +const QuicTag kTPCC = TAG('P', 'C', 'C', '\0'); // Performance-Oriented + // Congestion Control +const QuicTag kBYTE = TAG('B', 'Y', 'T', 'E'); // TCP cubic or reno in bytes +const QuicTag kIW03 = TAG('I', 'W', '0', '3'); // Force ICWND to 3 +const QuicTag kIW10 = TAG('I', 'W', '1', '0'); // Force ICWND to 10 +const QuicTag kIW20 = TAG('I', 'W', '2', '0'); // Force ICWND to 20 +const QuicTag kIW50 = TAG('I', 'W', '5', '0'); // Force ICWND to 50 +const QuicTag kB2ON = TAG('B', '2', 'O', 'N'); // Enable BBRv2 +const QuicTag kB2NA = TAG('B', '2', 'N', 'A'); // For BBRv2, do not add ack + // height to queueing threshold +const QuicTag kB2NE = TAG('B', '2', 'N', 'E'); // For BBRv2, always exit + // STARTUP on loss, even if + // bandwidth growth exceeds + // threshold. +const QuicTag kB2RP = TAG('B', '2', 'R', 'P'); // For BBRv2, run PROBE_RTT on + // the regular schedule +const QuicTag kB2LO = TAG('B', '2', 'L', 'O'); // Ignore inflight_lo in BBR2 +const QuicTag kB2HR = TAG('B', '2', 'H', 'R'); // 15% inflight_hi headroom. +const QuicTag kB2SL = TAG('B', '2', 'S', 'L'); // When exiting STARTUP due to + // loss, set inflight_hi to the + // max of bdp and max bytes + // delivered in round. +const QuicTag kB2H2 = TAG('B', '2', 'H', '2'); // When exiting PROBE_UP due to + // loss, set inflight_hi to the + // max of inflight@send and max + // bytes delivered in round. +const QuicTag kB2RC = TAG('B', '2', 'R', 'C'); // Disable Reno-coexistence for + // BBR2. +const QuicTag kBSAO = TAG('B', 'S', 'A', 'O'); // Avoid Overestimation in + // Bandwidth Sampler with ack + // aggregation +const QuicTag kB2DL = TAG('B', '2', 'D', 'L'); // Increase inflight_hi based + // on delievered, not inflight. +const QuicTag kB201 = TAG('B', '2', '0', '1'); // DEPRECATED +const QuicTag kB202 = TAG('B', '2', '0', '2'); // Do not exit PROBE_UP if + // inflight dips below 1.25*BW. +const QuicTag kB203 = TAG('B', '2', '0', '3'); // Ignore inflight_hi until + // PROBE_UP is exited. +const QuicTag kB204 = TAG('B', '2', '0', '4'); // Reduce extra acked when + // MaxBW incrases. +const QuicTag kB205 = TAG('B', '2', '0', '5'); // Add extra acked to CWND in + // STARTUP. +const QuicTag kB206 = TAG('B', '2', '0', '6'); // Exit STARTUP after 2 losses. +const QuicTag kB207 = TAG('B', '2', '0', '7'); // Exit STARTUP on persistent + // queue +const QuicTag kBB2U = TAG('B', 'B', '2', 'U'); // Exit PROBE_UP on + // min_bytes_in_flight for two + // rounds in a row. +const QuicTag kBB2S = TAG('B', 'B', '2', 'S'); // Exit STARTUP on + // min_bytes_in_flight for two + // rounds in a row. +const QuicTag kNTLP = TAG('N', 'T', 'L', 'P'); // No tail loss probe +const QuicTag k1TLP = TAG('1', 'T', 'L', 'P'); // 1 tail loss probe +const QuicTag k1RTO = TAG('1', 'R', 'T', 'O'); // Send 1 packet upon RTO +const QuicTag kNRTO = TAG('N', 'R', 'T', 'O'); // CWND reduction on loss +const QuicTag kTIME = TAG('T', 'I', 'M', 'E'); // Time based loss detection +const QuicTag kATIM = TAG('A', 'T', 'I', 'M'); // Adaptive time loss detection +const QuicTag kMIN1 = TAG('M', 'I', 'N', '1'); // Min CWND of 1 packet +const QuicTag kMIN4 = TAG('M', 'I', 'N', '4'); // Min CWND of 4 packets, + // with a min rate of 1 BDP. +const QuicTag kMAD0 = TAG('M', 'A', 'D', '0'); // Ignore ack delay +const QuicTag kMAD2 = TAG('M', 'A', 'D', '2'); // No min TLP +const QuicTag kMAD3 = TAG('M', 'A', 'D', '3'); // No min RTO +const QuicTag k1ACK = TAG('1', 'A', 'C', 'K'); // 1 fast ack for reordering +const QuicTag kAKD3 = TAG('A', 'K', 'D', '3'); // Ack decimation style acking + // with 1/8 RTT acks. +const QuicTag kAKDU = TAG('A', 'K', 'D', 'U'); // Unlimited number of packets + // received before acking +const QuicTag kAFFE = TAG('A', 'F', 'F', 'E'); // Enable client receiving + // AckFrequencyFrame. +const QuicTag kAFF1 = TAG('A', 'F', 'F', '1'); // Use SRTT in building + // AckFrequencyFrame. +const QuicTag kAFF2 = TAG('A', 'F', 'F', '2'); // Send AckFrequencyFrame upon + // handshake completion. +const QuicTag kSSLR = TAG('S', 'S', 'L', 'R'); // Slow Start Large Reduction. +const QuicTag kNPRR = TAG('N', 'P', 'R', 'R'); // Pace at unity instead of PRR +const QuicTag k5RTO = TAG('5', 'R', 'T', 'O'); // Close connection on 5 RTOs +const QuicTag kCBHD = TAG('C', 'B', 'H', 'D'); // Client only blackhole + // detection. +const QuicTag kNBHD = TAG('N', 'B', 'H', 'D'); // No blackhole detection. +const QuicTag kCONH = TAG('C', 'O', 'N', 'H'); // Conservative Handshake + // Retransmissions. +const QuicTag kLFAK = TAG('L', 'F', 'A', 'K'); // Don't invoke FACK on the + // first ack. +const QuicTag kSTMP = TAG('S', 'T', 'M', 'P'); // DEPRECATED +const QuicTag kEACK = TAG('E', 'A', 'C', 'K'); // Bundle ack-eliciting frame + // with an ACK after PTO/RTO + +const QuicTag kILD0 = TAG('I', 'L', 'D', '0'); // IETF style loss detection + // (default with 1/8 RTT time + // threshold) +const QuicTag kILD1 = TAG('I', 'L', 'D', '1'); // IETF style loss detection + // with 1/4 RTT time threshold +const QuicTag kILD2 = TAG('I', 'L', 'D', '2'); // IETF style loss detection + // with adaptive packet + // threshold +const QuicTag kILD3 = TAG('I', 'L', 'D', '3'); // IETF style loss detection + // with 1/4 RTT time threshold + // and adaptive packet + // threshold +const QuicTag kILD4 = TAG('I', 'L', 'D', '4'); // IETF style loss detection + // with both adaptive time + // threshold (default 1/4 RTT) + // and adaptive packet + // threshold +const QuicTag kRUNT = TAG('R', 'U', 'N', 'T'); // No packet threshold loss + // detection for "runt" packet. +const QuicTag kNSTP = TAG('N', 'S', 'T', 'P'); // No stop waiting frames. +const QuicTag kNRTT = TAG('N', 'R', 'T', 'T'); // Ignore initial RTT + +const QuicTag k1PTO = TAG('1', 'P', 'T', 'O'); // Send 1 packet upon PTO. +const QuicTag k2PTO = TAG('2', 'P', 'T', 'O'); // Send 2 packets upon PTO. + +const QuicTag k6PTO = TAG('6', 'P', 'T', 'O'); // Closes connection on 6 + // consecutive PTOs. +const QuicTag k7PTO = TAG('7', 'P', 'T', 'O'); // Closes connection on 7 + // consecutive PTOs. +const QuicTag k8PTO = TAG('8', 'P', 'T', 'O'); // Closes connection on 8 + // consecutive PTOs. +const QuicTag kPTOS = TAG('P', 'T', 'O', 'S'); // Skip packet number before + // sending the last PTO. +const QuicTag kPTOA = TAG('P', 'T', 'O', 'A'); // Do not add max ack delay + // when computing PTO timeout + // if an immediate ACK is + // expected. +const QuicTag kPEB1 = TAG('P', 'E', 'B', '1'); // Start exponential backoff + // since 1st PTO. +const QuicTag kPEB2 = TAG('P', 'E', 'B', '2'); // Start exponential backoff + // since 2nd PTO. +const QuicTag kPVS1 = TAG('P', 'V', 'S', '1'); // Use 2 * rttvar when + // calculating PTO timeout. +const QuicTag kPAG1 = TAG('P', 'A', 'G', '1'); // Make 1st PTO more aggressive +const QuicTag kPAG2 = TAG('P', 'A', 'G', '2'); // Make first 2 PTOs more + // aggressive +const QuicTag kPSDA = TAG('P', 'S', 'D', 'A'); // Use standard deviation when + // calculating PTO timeout. +const QuicTag kPLE1 = TAG('P', 'L', 'E', '1'); // Arm the 1st PTO with + // earliest in flight sent time + // and at least 0.5*srtt from + // last sent packet. +const QuicTag kPLE2 = TAG('P', 'L', 'E', '2'); // Arm the 1st PTO with + // earliest in flight sent time + // and at least 1.5*srtt from + // last sent packet. +const QuicTag kAPTO = TAG('A', 'P', 'T', 'O'); // Use 1.5 * initial RTT before + // any RTT sample is available. + +const QuicTag kELDT = TAG('E', 'L', 'D', 'T'); // Enable Loss Detection Tuning + +// TODO(haoyuewang) Remove RVCM option once +// --quic_remove_connection_migration_connection_option_v2 is deprecated. +const QuicTag kRVCM = TAG('R', 'V', 'C', 'M'); // Validate the new address + // upon client address change. + +const QuicTag kSPAD = TAG('S', 'P', 'A', 'D'); // Use server preferred address +const QuicTag kSPA2 = TAG('S', 'P', 'A', '2'); // Start validating server + // preferred address once it is + // received. Send all coalesced + // packets to both addresses. + +// Optional support of truncated Connection IDs. If sent by a peer, the value +// is the minimum number of bytes allowed for the connection ID sent to the +// peer. +const QuicTag kTCID = TAG('T', 'C', 'I', 'D'); // Connection ID truncation. + +// Multipath option. +const QuicTag kMPTH = TAG('M', 'P', 'T', 'H'); // Enable multipath. + +const QuicTag kNCMR = TAG('N', 'C', 'M', 'R'); // Do not attempt connection + // migration. + +// Allows disabling defer_send_in_response_to_packets in QuicConnection. +const QuicTag kDFER = TAG('D', 'F', 'E', 'R'); // Do not defer sending. + +// Disable Pacing offload option. +const QuicTag kNPCO = TAG('N', 'P', 'C', 'O'); // No pacing offload. + +// Enable bandwidth resumption experiment. +const QuicTag kBWRE = TAG('B', 'W', 'R', 'E'); // Bandwidth resumption. +const QuicTag kBWMX = TAG('B', 'W', 'M', 'X'); // Max bandwidth resumption. +const QuicTag kBWID = TAG('B', 'W', 'I', 'D'); // Send bandwidth when idle. +const QuicTag kBWI1 = TAG('B', 'W', 'I', '1'); // Resume bandwidth experiment 1 +const QuicTag kBWRS = TAG('B', 'W', 'R', 'S'); // Server bandwidth resumption. +const QuicTag kBWS2 = TAG('B', 'W', 'S', '2'); // Server bw resumption v2. +const QuicTag kBWS3 = TAG('B', 'W', 'S', '3'); // QUIC Initial CWND - Control. +const QuicTag kBWS4 = TAG('B', 'W', 'S', '4'); // QUIC Initial CWND - Enabled. +const QuicTag kBWS5 = TAG('B', 'W', 'S', '5'); // QUIC Initial CWND up and down +const QuicTag kBWS6 = TAG('B', 'W', 'S', '6'); // QUIC Initial CWND - Enabled + // with 0.5 * default + // multiplier. +const QuicTag kBWP0 = TAG('B', 'W', 'P', '0'); // QUIC Initial CWND - SPDY + // priority 0. +const QuicTag kBWP1 = TAG('B', 'W', 'P', '1'); // QUIC Initial CWND - SPDY + // priorities 0 and 1. +const QuicTag kBWP2 = TAG('B', 'W', 'P', '2'); // QUIC Initial CWND - SPDY + // priorities 0, 1 and 2. +const QuicTag kBWP3 = TAG('B', 'W', 'P', '3'); // QUIC Initial CWND - SPDY + // priorities 0, 1, 2 and 3. +const QuicTag kBWP4 = TAG('B', 'W', 'P', '4'); // QUIC Initial CWND - SPDY + // priorities >= 0, 1, 2, 3 and + // 4. +const QuicTag kBWG4 = TAG('B', 'W', 'G', '4'); // QUIC Initial CWND - + // Bandwidth model 1. +const QuicTag kBWG7 = TAG('B', 'W', 'G', '7'); // QUIC Initial CWND - + // Bandwidth model 2. +const QuicTag kBWG8 = TAG('B', 'W', 'G', '8'); // QUIC Initial CWND - + // Bandwidth model 3. +const QuicTag kBWS7 = TAG('B', 'W', 'S', '7'); // QUIC Initial CWND - Enabled + // with 0.75 * default + // multiplier. +const QuicTag kBWM3 = TAG('B', 'W', 'M', '3'); // Consider overshooting if + // bytes lost after bandwidth + // resumption * 3 > IW. +const QuicTag kBWM4 = TAG('B', 'W', 'M', '4'); // Consider overshooting if + // bytes lost after bandwidth + // resumption * 4 > IW. +const QuicTag kICW1 = TAG('I', 'C', 'W', '1'); // Max initial congestion window + // 100. +const QuicTag kDTOS = TAG('D', 'T', 'O', 'S'); // Enable overshooting + // detection. + +const QuicTag kFIDT = TAG('F', 'I', 'D', 'T'); // Extend idle timer by PTO + // instead of the whole idle + // timeout. + +const QuicTag k3AFF = TAG('3', 'A', 'F', 'F'); // 3 anti amplification factor. +const QuicTag k10AF = TAG('1', '0', 'A', 'F'); // 10 anti amplification factor. + +// Enable path MTU discovery experiment. +const QuicTag kMTUH = TAG('M', 'T', 'U', 'H'); // High-target MTU discovery. +const QuicTag kMTUL = TAG('M', 'T', 'U', 'L'); // Low-target MTU discovery. + +const QuicTag kNSLC = TAG('N', 'S', 'L', 'C'); // Always send connection close + // for idle timeout. + +// Proof types (i.e. certificate types) +// NOTE: although it would be silly to do so, specifying both kX509 and kX59R +// is allowed and is equivalent to specifying only kX509. +const QuicTag kX509 = TAG('X', '5', '0', '9'); // X.509 certificate, all key + // types +const QuicTag kX59R = TAG('X', '5', '9', 'R'); // X.509 certificate, RSA keys + // only +const QuicTag kCHID = TAG('C', 'H', 'I', 'D'); // Channel ID. + +// Client hello tags +const QuicTag kVER = TAG('V', 'E', 'R', '\0'); // Version +const QuicTag kNONC = TAG('N', 'O', 'N', 'C'); // The client's nonce +const QuicTag kNONP = TAG('N', 'O', 'N', 'P'); // The client's proof nonce +const QuicTag kKEXS = TAG('K', 'E', 'X', 'S'); // Key exchange methods +const QuicTag kAEAD = TAG('A', 'E', 'A', 'D'); // Authenticated + // encryption algorithms +const QuicTag kCOPT = TAG('C', 'O', 'P', 'T'); // Connection options +const QuicTag kCLOP = TAG('C', 'L', 'O', 'P'); // Client connection options +const QuicTag kICSL = TAG('I', 'C', 'S', 'L'); // Idle network timeout +const QuicTag kMIBS = TAG('M', 'I', 'D', 'S'); // Max incoming bidi streams +const QuicTag kMIUS = TAG('M', 'I', 'U', 'S'); // Max incoming unidi streams +const QuicTag kADE = TAG('A', 'D', 'E', 0); // Ack Delay Exponent (IETF + // QUIC ACK Frame Only). +const QuicTag kIRTT = TAG('I', 'R', 'T', 'T'); // Estimated initial RTT in us. +const QuicTag kTRTT = TAG('T', 'R', 'T', 'T'); // If server receives an rtt + // from an address token, set + // it as the initial rtt. +const QuicTag kSNI = TAG('S', 'N', 'I', '\0'); // Server name + // indication +const QuicTag kPUBS = TAG('P', 'U', 'B', 'S'); // Public key values +const QuicTag kSCID = TAG('S', 'C', 'I', 'D'); // Server config id +const QuicTag kORBT = TAG('O', 'B', 'I', 'T'); // Server orbit. +const QuicTag kPDMD = TAG('P', 'D', 'M', 'D'); // Proof demand. +const QuicTag kPROF = TAG('P', 'R', 'O', 'F'); // Proof (signature). +const QuicTag kCCRT = TAG('C', 'C', 'R', 'T'); // Cached certificate +const QuicTag kEXPY = TAG('E', 'X', 'P', 'Y'); // Expiry +const QuicTag kSTTL = TAG('S', 'T', 'T', 'L'); // Server Config TTL +const QuicTag kSFCW = TAG('S', 'F', 'C', 'W'); // Initial stream flow control + // receive window. +const QuicTag kCFCW = TAG('C', 'F', 'C', 'W'); // Initial session/connection + // flow control receive window. +const QuicTag kUAID = TAG('U', 'A', 'I', 'D'); // Client's User Agent ID. +const QuicTag kXLCT = TAG('X', 'L', 'C', 'T'); // Expected leaf certificate. + +const QuicTag kQNZ2 = TAG('Q', 'N', 'Z', '2'); // Turn off QUIC crypto 0-RTT. + +const QuicTag kMAD = TAG('M', 'A', 'D', 0); // Max Ack Delay (IETF QUIC) + +const QuicTag kIGNP = TAG('I', 'G', 'N', 'P'); // Do not use PING only packet + // for RTT measure or + // congestion control. + +const QuicTag kSRWP = TAG('S', 'R', 'W', 'P'); // Enable retransmittable on + // wire PING (ROWP) on the + // server side. +const QuicTag kROWF = TAG('R', 'O', 'W', 'F'); // Send first 1-RTT packet on + // ROWP timeout. +const QuicTag kROWR = TAG('R', 'O', 'W', 'R'); // Send random bytes on ROWP + // timeout. +// Selective Resumption variants. +const QuicTag kGSR0 = TAG('G', 'S', 'R', '0'); +const QuicTag kGSR1 = TAG('G', 'S', 'R', '1'); +const QuicTag kGSR2 = TAG('G', 'S', 'R', '2'); +const QuicTag kGSR3 = TAG('G', 'S', 'R', '3'); + +const QuicTag kNRES = TAG('N', 'R', 'E', 'S'); // No resumption + +const QuicTag kINVC = TAG('I', 'N', 'V', 'C'); // Send connection close for + // INVALID_VERSION + +const QuicTag kMPQC = TAG('M', 'P', 'Q', 'C'); // Multi-port QUIC connection + +// Client Hints triggers. +const QuicTag kGWCH = TAG('G', 'W', 'C', 'H'); +const QuicTag kYTCH = TAG('Y', 'T', 'C', 'H'); +const QuicTag kACH0 = TAG('A', 'C', 'H', '0'); + +// Rejection tags +const QuicTag kRREJ = TAG('R', 'R', 'E', 'J'); // Reasons for server sending + +// Server hello tags +const QuicTag kCADR = TAG('C', 'A', 'D', 'R'); // Client IP address and port +const QuicTag kASAD = TAG('A', 'S', 'A', 'D'); // Alternate Server IP address + // and port. +const QuicTag kSRST = TAG('S', 'R', 'S', 'T'); // Stateless reset token used + // in IETF public reset packet + +// CETV tags +const QuicTag kCIDK = TAG('C', 'I', 'D', 'K'); // ChannelID key +const QuicTag kCIDS = TAG('C', 'I', 'D', 'S'); // ChannelID signature + +// Public reset tags +const QuicTag kRNON = TAG('R', 'N', 'O', 'N'); // Public reset nonce proof +const QuicTag kRSEQ = TAG('R', 'S', 'E', 'Q'); // Rejected packet number + +// Universal tags +const QuicTag kPAD = TAG('P', 'A', 'D', '\0'); // Padding + +// Stats collection tags +const QuicTag kEPID = TAG('E', 'P', 'I', 'D'); // Endpoint identifier. + +// clang-format on + +// These tags have a special form so that they appear either at the beginning +// or the end of a handshake message. Since handshake messages are sorted by +// tag value, the tags with 0 at the end will sort first and those with 255 at +// the end will sort last. +// +// The certificate chain should have a tag that will cause it to be sorted at +// the end of any handshake messages because it's likely to be large and the +// client might be able to get everything that it needs from the small values at +// the beginning. +// +// Likewise tags with random values should be towards the beginning of the +// message because the server mightn't hold state for a rejected client hello +// and therefore the client may have issues reassembling the rejection message +// in the event that it sent two client hellos. +const QuicTag kServerNonceTag = TAG('S', 'N', 'O', 0); // The server's nonce +const QuicTag kSourceAddressTokenTag = + TAG('S', 'T', 'K', 0); // Source-address token +const QuicTag kCertificateTag = TAG('C', 'R', 'T', 255); // Certificate chain +const QuicTag kCertificateSCTTag = + TAG('C', 'S', 'C', 'T'); // Signed cert timestamp (RFC6962) of leaf cert. + +#undef TAG + +const size_t kMaxEntries = 128; // Max number of entries in a message. + +const size_t kNonceSize = 32; // Size in bytes of the connection nonce. + +const size_t kOrbitSize = 8; // Number of bytes in an orbit value. + +// kProofSignatureLabel is prepended to the CHLO hash and server configs before +// signing to avoid any cross-protocol attacks on the signature. +const char kProofSignatureLabel[] = "QUIC CHLO and server config signature"; + +// kClientHelloMinimumSize is the minimum size of a client hello. Client hellos +// will have PAD tags added in order to ensure this minimum is met and client +// hellos smaller than this will be an error. This minimum size reduces the +// amplification factor of any mirror DoS attack. +// +// A client may pad an inchoate client hello to a size larger than +// kClientHelloMinimumSize to make it more likely to receive a complete +// rejection message. +const size_t kClientHelloMinimumSize = 1024; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_PROTOCOL_H_ diff --git a/quiche/quic/core/crypto/crypto_secret_boxer.cc b/quiche/quic/core/crypto/crypto_secret_boxer.cc new file mode 100644 index 000000000000..2be495d92cc9 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_secret_boxer.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_secret_boxer.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/aead.h" +#include "openssl/err.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// kSIVNonceSize contains the number of bytes of nonce in each AES-GCM-SIV box. +// AES-GCM-SIV takes a 12-byte nonce and, since the messages are so small, each +// key is good for more than 2^64 source-address tokens. See table 1 of +// https://eprint.iacr.org/2017/168.pdf +static const size_t kSIVNonceSize = 12; + +// AES-GCM-SIV comes in AES-128 and AES-256 flavours. The AES-256 version is +// used here so that the key size matches the 256-bit XSalsa20 keys that we +// used to use. +static const size_t kBoxKeySize = 32; + +struct CryptoSecretBoxer::State { + // ctxs are the initialised AEAD contexts. These objects contain the + // scheduled AES state for each of the keys. + std::vector> ctxs; +}; + +CryptoSecretBoxer::CryptoSecretBoxer() {} + +CryptoSecretBoxer::~CryptoSecretBoxer() {} + +// static +size_t CryptoSecretBoxer::GetKeySize() { return kBoxKeySize; } + +// kAEAD is the AEAD used for boxing: AES-256-GCM-SIV. +static const EVP_AEAD* (*const kAEAD)() = EVP_aead_aes_256_gcm_siv; + +bool CryptoSecretBoxer::SetKeys(const std::vector& keys) { + if (keys.empty()) { + QUIC_LOG(DFATAL) << "No keys supplied!"; + return false; + } + const EVP_AEAD* const aead = kAEAD(); + std::unique_ptr new_state(new State); + + for (const std::string& key : keys) { + QUICHE_DCHECK_EQ(kBoxKeySize, key.size()); + bssl::UniquePtr ctx( + EVP_AEAD_CTX_new(aead, reinterpret_cast(key.data()), + key.size(), EVP_AEAD_DEFAULT_TAG_LENGTH)); + if (!ctx) { + ERR_clear_error(); + QUIC_LOG(DFATAL) << "EVP_AEAD_CTX_init failed"; + return false; + } + + new_state->ctxs.push_back(std::move(ctx)); + } + + QuicWriterMutexLock l(&lock_); + state_ = std::move(new_state); + return true; +} + +std::string CryptoSecretBoxer::Box(QuicRandom* rand, + absl::string_view plaintext) const { + // The box is formatted as: + // 12 bytes of random nonce + // n bytes of ciphertext + // 16 bytes of authenticator + size_t out_len = + kSIVNonceSize + plaintext.size() + EVP_AEAD_max_overhead(kAEAD()); + + std::string ret; + ret.resize(out_len); + uint8_t* out = reinterpret_cast(const_cast(ret.data())); + + // Write kSIVNonceSize bytes of random nonce to the beginning of the output + // buffer. + rand->RandBytes(out, kSIVNonceSize); + const uint8_t* const nonce = out; + out += kSIVNonceSize; + out_len -= kSIVNonceSize; + + size_t bytes_written; + { + QuicReaderMutexLock l(&lock_); + if (!EVP_AEAD_CTX_seal(state_->ctxs[0].get(), out, &bytes_written, out_len, + nonce, kSIVNonceSize, + reinterpret_cast(plaintext.data()), + plaintext.size(), nullptr, 0)) { + ERR_clear_error(); + QUIC_LOG(DFATAL) << "EVP_AEAD_CTX_seal failed"; + return ""; + } + } + + QUICHE_DCHECK_EQ(out_len, bytes_written); + return ret; +} + +bool CryptoSecretBoxer::Unbox(absl::string_view in_ciphertext, + std::string* out_storage, + absl::string_view* out) const { + if (in_ciphertext.size() < kSIVNonceSize) { + return false; + } + + const uint8_t* const nonce = + reinterpret_cast(in_ciphertext.data()); + const uint8_t* const ciphertext = nonce + kSIVNonceSize; + const size_t ciphertext_len = in_ciphertext.size() - kSIVNonceSize; + + out_storage->resize(ciphertext_len); + + bool ok = false; + { + QuicReaderMutexLock l(&lock_); + for (const bssl::UniquePtr& ctx : state_->ctxs) { + size_t bytes_written; + if (EVP_AEAD_CTX_open(ctx.get(), + reinterpret_cast( + const_cast(out_storage->data())), + &bytes_written, ciphertext_len, nonce, + kSIVNonceSize, ciphertext, ciphertext_len, nullptr, + 0)) { + ok = true; + *out = absl::string_view(out_storage->data(), bytes_written); + break; + } + + ERR_clear_error(); + } + } + + return ok; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_secret_boxer.h b/quiche/quic/core/crypto/crypto_secret_boxer.h new file mode 100644 index 000000000000..5d334c3770d5 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_secret_boxer.h @@ -0,0 +1,67 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_SECRET_BOXER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_SECRET_BOXER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_mutex.h" + +namespace quic { + +// CryptoSecretBoxer encrypts small chunks of plaintext (called 'boxing') and +// then, later, can authenticate+decrypt the resulting boxes. This object is +// thread-safe. +class QUIC_EXPORT_PRIVATE CryptoSecretBoxer { + public: + CryptoSecretBoxer(); + CryptoSecretBoxer(const CryptoSecretBoxer&) = delete; + CryptoSecretBoxer& operator=(const CryptoSecretBoxer&) = delete; + ~CryptoSecretBoxer(); + + // GetKeySize returns the number of bytes in a key. + static size_t GetKeySize(); + + // SetKeys sets a list of encryption keys. The first key in the list will be + // used by |Box|, but all supplied keys will be tried by |Unbox|, to handle + // key skew across the fleet. This must be called before |Box| or |Unbox|. + // Keys must be |GetKeySize()| bytes long. No change is made if any key is + // invalid, or if there are no keys supplied. + bool SetKeys(const std::vector& keys); + + // Box encrypts |plaintext| using a random nonce generated from |rand| and + // returns the resulting ciphertext. Since an authenticator and nonce are + // included, the result will be slightly larger than |plaintext|. The first + // key in the vector supplied to |SetKeys| will be used. |SetKeys| must be + // called before calling this method. + std::string Box(QuicRandom* rand, absl::string_view plaintext) const; + + // Unbox takes the result of a previous call to |Box| in |ciphertext| and + // authenticates+decrypts it. If |ciphertext| cannot be decrypted with any of + // the supplied keys, the function returns false. Otherwise, |out_storage| is + // used to store the result and |out| is set to point into |out_storage| and + // contains the original plaintext. + bool Unbox(absl::string_view ciphertext, std::string* out_storage, + absl::string_view* out) const; + + private: + struct State; + + mutable QuicMutex lock_; + + // state_ is an opaque pointer to whatever additional state the concrete + // implementation of CryptoSecretBoxer requires. + std::unique_ptr state_ QUIC_GUARDED_BY(lock_); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_SECRET_BOXER_H_ diff --git a/quiche/quic/core/crypto/crypto_secret_boxer_test.cc b/quiche/quic/core/crypto/crypto_secret_boxer_test.cc new file mode 100644 index 000000000000..2c499fc5332c --- /dev/null +++ b/quiche/quic/core/crypto/crypto_secret_boxer_test.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_secret_boxer.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class CryptoSecretBoxerTest : public QuicTest {}; + +TEST_F(CryptoSecretBoxerTest, BoxAndUnbox) { + absl::string_view message("hello world"); + + CryptoSecretBoxer boxer; + boxer.SetKeys({std::string(CryptoSecretBoxer::GetKeySize(), 0x11)}); + + const std::string box = boxer.Box(QuicRandom::GetInstance(), message); + + std::string storage; + absl::string_view result; + EXPECT_TRUE(boxer.Unbox(box, &storage, &result)); + EXPECT_EQ(result, message); + + EXPECT_FALSE(boxer.Unbox(std::string(1, 'X') + box, &storage, &result)); + EXPECT_FALSE( + boxer.Unbox(box.substr(1, std::string::npos), &storage, &result)); + EXPECT_FALSE(boxer.Unbox(std::string(), &storage, &result)); + EXPECT_FALSE(boxer.Unbox( + std::string(1, box[0] ^ 0x80) + box.substr(1, std::string::npos), + &storage, &result)); +} + +// Helper function to test whether one boxer can decode the output of another. +static bool CanDecode(const CryptoSecretBoxer& decoder, + const CryptoSecretBoxer& encoder) { + absl::string_view message("hello world"); + const std::string boxed = encoder.Box(QuicRandom::GetInstance(), message); + std::string storage; + absl::string_view result; + bool ok = decoder.Unbox(boxed, &storage, &result); + if (ok) { + EXPECT_EQ(result, message); + } + return ok; +} + +TEST_F(CryptoSecretBoxerTest, MultipleKeys) { + std::string key_11(CryptoSecretBoxer::GetKeySize(), 0x11); + std::string key_12(CryptoSecretBoxer::GetKeySize(), 0x12); + + CryptoSecretBoxer boxer_11, boxer_12, boxer; + EXPECT_TRUE(boxer_11.SetKeys({key_11})); + EXPECT_TRUE(boxer_12.SetKeys({key_12})); + EXPECT_TRUE(boxer.SetKeys({key_12, key_11})); + + // Neither single-key boxer can decode the other's tokens. + EXPECT_FALSE(CanDecode(boxer_11, boxer_12)); + EXPECT_FALSE(CanDecode(boxer_12, boxer_11)); + + // |boxer| encodes with the first key, which is key_12. + EXPECT_TRUE(CanDecode(boxer_12, boxer)); + EXPECT_FALSE(CanDecode(boxer_11, boxer)); + + // The boxer with both keys can decode tokens from either single-key boxer. + EXPECT_TRUE(CanDecode(boxer, boxer_11)); + EXPECT_TRUE(CanDecode(boxer, boxer_12)); + + // After we flush key_11 from |boxer|, it can no longer decode tokens from + // |boxer_11|. + EXPECT_TRUE(boxer.SetKeys({key_12})); + EXPECT_FALSE(CanDecode(boxer, boxer_11)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_server_test.cc b/quiche/quic/core/crypto/crypto_server_test.cc new file mode 100644 index 000000000000..56f72836b9dd --- /dev/null +++ b/quiche/quic/core/crypto/crypto_server_test.cc @@ -0,0 +1,1122 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/sha.h" +#include "quiche/quic/core/crypto/cert_compressor.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/core/quic_socket_address_coder.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/failing_proof_source.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_crypto_server_config_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { +namespace test { + +namespace { + +class DummyProofVerifierCallback : public ProofVerifierCallback { + public: + DummyProofVerifierCallback() {} + ~DummyProofVerifierCallback() override {} + + void Run(bool /*ok*/, const std::string& /*error_details*/, + std::unique_ptr* /*details*/) override { + QUICHE_DCHECK(false); + } +}; + +const char kOldConfigId[] = "old-config-id"; + +} // namespace + +struct TestParams { + friend std::ostream& operator<<(std::ostream& os, const TestParams& p) { + os << " versions: " + << ParsedQuicVersionVectorToString(p.supported_versions) << " }"; + return os; + } + + // Versions supported by client and server. + ParsedQuicVersionVector supported_versions; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + std::string rv = ParsedQuicVersionVectorToString(p.supported_versions); + std::replace(rv.begin(), rv.end(), ',', '_'); + return rv; +} + +// Constructs various test permutations. +std::vector GetTestParams() { + std::vector params; + + // Start with all versions, remove highest on each iteration. + ParsedQuicVersionVector supported_versions = AllSupportedVersions(); + while (!supported_versions.empty()) { + params.push_back({supported_versions}); + supported_versions.erase(supported_versions.begin()); + } + + return params; +} + +class CryptoServerTest : public QuicTestWithParam { + public: + CryptoServerTest() + : rand_(QuicRandom::GetInstance()), + client_address_(QuicIpAddress::Loopback4(), 1234), + client_version_(UnsupportedQuicVersion()), + config_(QuicCryptoServerConfig::TESTING, rand_, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()), + peer_(&config_), + compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), + params_(new QuicCryptoNegotiatedParameters), + signed_config_(new QuicSignedServerConfig), + chlo_packet_size_(kDefaultMaxPacketSize) { + supported_versions_ = GetParam().supported_versions; + config_.set_enable_serving_sct(true); + + client_version_ = supported_versions_.front(); + client_version_label_ = CreateQuicVersionLabel(client_version_); + client_version_string_ = + std::string(reinterpret_cast(&client_version_label_), + sizeof(client_version_label_)); + } + + void SetUp() override { + QuicCryptoServerConfig::ConfigOptions old_config_options; + old_config_options.id = kOldConfigId; + config_.AddDefaultConfig(rand_, &clock_, old_config_options); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1000)); + QuicServerConfigProtobuf primary_config = + config_.GenerateConfig(rand_, &clock_, config_options_); + primary_config.set_primary_time(clock_.WallNow().ToUNIXSeconds()); + std::unique_ptr msg( + config_.AddConfig(primary_config, clock_.WallNow())); + + absl::string_view orbit; + QUICHE_CHECK(msg->GetStringPiece(kORBT, &orbit)); + QUICHE_CHECK_EQ(sizeof(orbit_), orbit.size()); + memcpy(orbit_, orbit.data(), orbit.size()); + + char public_value[32]; + memset(public_value, 42, sizeof(public_value)); + + nonce_hex_ = "#" + absl::BytesToHexString(GenerateNonce()); + pub_hex_ = "#" + absl::BytesToHexString( + absl::string_view(public_value, sizeof(public_value))); + + CryptoHandshakeMessage client_hello = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"CSCT", ""}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + ShouldSucceed(client_hello); + // The message should be rejected because the source-address token is + // missing. + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + + absl::string_view srct; + ASSERT_TRUE(out_.GetStringPiece(kSourceAddressTokenTag, &srct)); + srct_hex_ = "#" + absl::BytesToHexString(srct); + + absl::string_view scfg; + ASSERT_TRUE(out_.GetStringPiece(kSCFG, &scfg)); + server_config_ = CryptoFramer::ParseMessage(scfg); + + absl::string_view scid; + ASSERT_TRUE(server_config_->GetStringPiece(kSCID, &scid)); + scid_hex_ = "#" + absl::BytesToHexString(scid); + + signed_config_ = + quiche::QuicheReferenceCountedPointer( + new QuicSignedServerConfig()); + QUICHE_DCHECK(signed_config_->chain.get() == nullptr); + } + + // Helper used to accept the result of ValidateClientHello and pass + // it on to ProcessClientHello. + class ValidateCallback : public ValidateClientHelloResultCallback { + public: + ValidateCallback(CryptoServerTest* test, bool should_succeed, + const char* error_substr, bool* called) + : test_(test), + should_succeed_(should_succeed), + error_substr_(error_substr), + called_(called) { + *called_ = false; + } + + void Run(quiche::QuicheReferenceCountedPointer result, + std::unique_ptr /* details */) override { + ASSERT_FALSE(*called_); + test_->ProcessValidationResult(std::move(result), should_succeed_, + error_substr_); + *called_ = true; + } + + private: + CryptoServerTest* test_; + const bool should_succeed_; + const char* const error_substr_; + bool* called_; + }; + + void CheckServerHello(const CryptoHandshakeMessage& server_hello) { + QuicVersionLabelVector versions; + server_hello.GetVersionLabelList(kVER, &versions); + ASSERT_EQ(supported_versions_.size(), versions.size()); + for (size_t i = 0; i < versions.size(); ++i) { + EXPECT_EQ(CreateQuicVersionLabel(supported_versions_[i]), versions[i]); + } + + absl::string_view address; + ASSERT_TRUE(server_hello.GetStringPiece(kCADR, &address)); + QuicSocketAddressCoder decoder; + ASSERT_TRUE(decoder.Decode(address.data(), address.size())); + EXPECT_EQ(client_address_.host(), decoder.ip()); + EXPECT_EQ(client_address_.port(), decoder.port()); + } + + void ShouldSucceed(const CryptoHandshakeMessage& message) { + bool called = false; + QuicSocketAddress server_address(QuicIpAddress::Any4(), 5); + config_.ValidateClientHello( + message, client_address_, server_address, + supported_versions_.front().transport_version, &clock_, signed_config_, + std::make_unique(this, true, "", &called)); + EXPECT_TRUE(called); + } + + void ShouldFailMentioning(const char* error_substr, + const CryptoHandshakeMessage& message) { + bool called = false; + ShouldFailMentioning(error_substr, message, &called); + EXPECT_TRUE(called); + } + + void ShouldFailMentioning(const char* error_substr, + const CryptoHandshakeMessage& message, + bool* called) { + QuicSocketAddress server_address(QuicIpAddress::Any4(), 5); + config_.ValidateClientHello( + message, client_address_, server_address, + supported_versions_.front().transport_version, &clock_, signed_config_, + std::make_unique(this, false, error_substr, called)); + } + + class ProcessCallback : public ProcessClientHelloResultCallback { + public: + ProcessCallback( + quiche::QuicheReferenceCountedPointer result, + bool should_succeed, const char* error_substr, bool* called, + CryptoHandshakeMessage* out) + : result_(std::move(result)), + should_succeed_(should_succeed), + error_substr_(error_substr), + called_(called), + out_(out) { + *called_ = false; + } + + void Run(QuicErrorCode error, const std::string& error_details, + std::unique_ptr message, + std::unique_ptr /*diversification_nonce*/, + std::unique_ptr /*proof_source_details*/) + override { + if (should_succeed_) { + ASSERT_EQ(error, QUIC_NO_ERROR) + << "Message failed with error " << error_details << ": " + << result_->client_hello.DebugString(); + } else { + ASSERT_NE(error, QUIC_NO_ERROR) + << "Message didn't fail: " << result_->client_hello.DebugString(); + EXPECT_TRUE(absl::StrContains(error_details, error_substr_)) + << error_substr_ << " not in " << error_details; + } + if (message != nullptr) { + *out_ = *message; + } + *called_ = true; + } + + private: + const quiche::QuicheReferenceCountedPointer + result_; + const bool should_succeed_; + const char* const error_substr_; + bool* called_; + CryptoHandshakeMessage* out_; + }; + + void ProcessValidationResult( + quiche::QuicheReferenceCountedPointer result, + bool should_succeed, const char* error_substr) { + QuicSocketAddress server_address(QuicIpAddress::Any4(), 5); + bool called; + config_.ProcessClientHello( + result, /*reject_only=*/false, + /*connection_id=*/TestConnectionId(1), server_address, client_address_, + supported_versions_.front(), supported_versions_, &clock_, rand_, + &compressed_certs_cache_, params_, signed_config_, + /*total_framing_overhead=*/50, chlo_packet_size_, + std::make_unique(result, should_succeed, error_substr, + &called, &out_)); + EXPECT_TRUE(called); + } + + std::string GenerateNonce() { + std::string nonce; + CryptoUtils::GenerateNonce( + clock_.WallNow(), rand_, + absl::string_view(reinterpret_cast(orbit_), + sizeof(orbit_)), + &nonce); + return nonce; + } + + void CheckRejectReasons( + const HandshakeFailureReason* expected_handshake_failures, + size_t expected_count) { + QuicTagVector reject_reasons; + static_assert(sizeof(QuicTag) == sizeof(uint32_t), "header out of sync"); + QuicErrorCode error_code = out_.GetTaglist(kRREJ, &reject_reasons); + ASSERT_THAT(error_code, IsQuicNoError()); + + EXPECT_EQ(expected_count, reject_reasons.size()); + for (size_t i = 0; i < reject_reasons.size(); ++i) { + EXPECT_EQ(static_cast(expected_handshake_failures[i]), + reject_reasons[i]); + } + } + + void CheckRejectTag() { + ASSERT_EQ(kREJ, out_.tag()) << QuicTagToString(out_.tag()); + } + + std::string XlctHexString() { + uint64_t xlct = crypto_test_utils::LeafCertHashForTesting(); + return "#" + absl::BytesToHexString(absl::string_view( + reinterpret_cast(&xlct), sizeof(xlct))); + } + + protected: + QuicRandom* const rand_; + MockRandom rand_for_id_generation_; + MockClock clock_; + QuicSocketAddress client_address_; + ParsedQuicVersionVector supported_versions_; + ParsedQuicVersion client_version_; + QuicVersionLabel client_version_label_; + std::string client_version_string_; + QuicCryptoServerConfig config_; + QuicCryptoServerConfigPeer peer_; + QuicCompressedCertsCache compressed_certs_cache_; + QuicCryptoServerConfig::ConfigOptions config_options_; + quiche::QuicheReferenceCountedPointer params_; + quiche::QuicheReferenceCountedPointer signed_config_; + CryptoHandshakeMessage out_; + uint8_t orbit_[kOrbitSize]; + size_t chlo_packet_size_; + + // These strings contain hex escaped values from the server suitable for using + // when constructing client hello messages. + std::string nonce_hex_, pub_hex_, srct_hex_, scid_hex_; + std::unique_ptr server_config_; +}; + +INSTANTIATE_TEST_SUITE_P(CryptoServerTests, CryptoServerTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(CryptoServerTest, BadSNI) { + // clang-format off + std::vector badSNIs = { + "", + "#00", + "#ff00", + "127.0.0.1", + "ffee::1", + }; + // clang-format on + + for (const std::string& bad_sni : badSNIs) { + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"SNI", bad_sni}, {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + ShouldFailMentioning("SNI", msg); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + } + + // Check that SNIs without dots are allowed + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"SNI", "foo"}, {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + ShouldSucceed(msg); +} + +TEST_P(CryptoServerTest, DefaultCert) { + // Check that the server replies with a default certificate when no SNI is + // specified. The CHLO is constructed to generate a REJ with certs, so must + // not contain a valid STK, and must include PDMD. + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"PDMD", "X509"}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + absl::string_view cert, proof, cert_sct; + EXPECT_TRUE(out_.GetStringPiece(kCertificateTag, &cert)); + EXPECT_TRUE(out_.GetStringPiece(kPROF, &proof)); + EXPECT_TRUE(out_.GetStringPiece(kCertificateSCTTag, &cert_sct)); + EXPECT_NE(0u, cert.size()); + EXPECT_NE(0u, proof.size()); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + EXPECT_LT(0u, cert_sct.size()); +} + +TEST_P(CryptoServerTest, RejectTooLarge) { + // Check that the server replies with no certificate when a CHLO is + // constructed with a PDMD but no SKT when the REJ would be too large. + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"PDMD", "X509"}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + // The REJ will be larger than the CHLO so no PROF or CRT will be sent. + config_.set_chlo_multiplier(1); + + ShouldSucceed(msg); + absl::string_view cert, proof, cert_sct; + EXPECT_FALSE(out_.GetStringPiece(kCertificateTag, &cert)); + EXPECT_FALSE(out_.GetStringPiece(kPROF, &proof)); + EXPECT_FALSE(out_.GetStringPiece(kCertificateSCTTag, &cert_sct)); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, RejectNotTooLarge) { + // When the CHLO packet is large enough, ensure that a full REJ is sent. + chlo_packet_size_ *= 5; + + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"PDMD", "X509"}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + // The REJ will be larger than the CHLO so no PROF or CRT will be sent. + config_.set_chlo_multiplier(1); + + ShouldSucceed(msg); + absl::string_view cert, proof, cert_sct; + EXPECT_TRUE(out_.GetStringPiece(kCertificateTag, &cert)); + EXPECT_TRUE(out_.GetStringPiece(kPROF, &proof)); + EXPECT_TRUE(out_.GetStringPiece(kCertificateSCTTag, &cert_sct)); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, RejectTooLargeButValidSTK) { + // Check that the server replies with no certificate when a CHLO is + // constructed with a PDMD but no SKT when the REJ would be too large. + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"#004b5453", srct_hex_}, + {"PDMD", "X509"}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + // The REJ will be larger than the CHLO so no PROF or CRT will be sent. + config_.set_chlo_multiplier(1); + + ShouldSucceed(msg); + absl::string_view cert, proof, cert_sct; + EXPECT_TRUE(out_.GetStringPiece(kCertificateTag, &cert)); + EXPECT_TRUE(out_.GetStringPiece(kPROF, &proof)); + EXPECT_TRUE(out_.GetStringPiece(kCertificateSCTTag, &cert_sct)); + EXPECT_NE(0u, cert.size()); + EXPECT_NE(0u, proof.size()); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, BadSourceAddressToken) { + // Invalid source-address tokens should be ignored. + // clang-format off + static const char* const kBadSourceAddressTokens[] = { + "", + "foo", + "#0000", + "#0000000000000000000000000000000000000000", + }; + // clang-format on + + for (size_t i = 0; i < ABSL_ARRAYSIZE(kBadSourceAddressTokens); i++) { + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"STK", kBadSourceAddressTokens[i]}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + ShouldSucceed(msg); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + } +} + +TEST_P(CryptoServerTest, BadClientNonce) { + // clang-format off + static const char* const kBadNonces[] = { + "", + "#0000", + "#0000000000000000000000000000000000000000", + }; + // clang-format on + + for (size_t i = 0; i < ABSL_ARRAYSIZE(kBadNonces); i++) { + // Invalid nonces should be ignored, in an inchoate CHLO. + + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"NONC", kBadNonces[i]}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + + // Invalid nonces should result in CLIENT_NONCE_INVALID_FAILURE. + CryptoHandshakeMessage msg1 = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", kBadNonces[i]}, + {"NONP", kBadNonces[i]}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg1); + + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons1[] = { + CLIENT_NONCE_INVALID_FAILURE}; + CheckRejectReasons(kRejectReasons1, ABSL_ARRAYSIZE(kRejectReasons1)); + } +} + +TEST_P(CryptoServerTest, NoClientNonce) { + // No client nonces should result in INCHOATE_HELLO_FAILURE. + + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + + CryptoHandshakeMessage msg1 = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg1); + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons1[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons1, ABSL_ARRAYSIZE(kRejectReasons1)); +} + +TEST_P(CryptoServerTest, DowngradeAttack) { + if (supported_versions_.size() == 1) { + // No downgrade attack is possible if the server only supports one version. + return; + } + // Set the client's preferred version to a supported version that + // is not the "current" version (supported_versions_.front()). + std::string bad_version = + ParsedQuicVersionToString(supported_versions_.back()); + + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"VER\0", bad_version}}, kClientHelloMinimumSize); + + ShouldFailMentioning("Downgrade", msg); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, CorruptServerConfig) { + // This tests corrupted server config. + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", (std::string(1, 'X') + scid_hex_)}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, CorruptSourceAddressToken) { + // This tests corrupted source address token. + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", (std::string(1, 'X') + srct_hex_)}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons[] = { + SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, CorruptSourceAddressTokenIsStillAccepted) { + // This tests corrupted source address token. + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", (std::string(1, 'X') + srct_hex_)}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + config_.set_validate_source_address_token(false); + + ShouldSucceed(msg); + EXPECT_EQ(kSHLO, out_.tag()); +} + +TEST_P(CryptoServerTest, CorruptClientNonceAndSourceAddressToken) { + // This test corrupts client nonce and source address token. + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", (std::string(1, 'X') + srct_hex_)}, + {"PUBS", pub_hex_}, + {"NONC", (std::string(1, 'X') + nonce_hex_)}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons[] = { + SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE, CLIENT_NONCE_INVALID_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, CorruptMultipleTags) { + // This test corrupts client nonce, server nonce and source address token. + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", (std::string(1, 'X') + srct_hex_)}, + {"PUBS", pub_hex_}, + {"NONC", (std::string(1, 'X') + nonce_hex_)}, + {"NONP", (std::string(1, 'X') + nonce_hex_)}, + {"SNO\0", (std::string(1, 'X') + nonce_hex_)}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + CheckRejectTag(); + + const HandshakeFailureReason kRejectReasons[] = { + SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE, CLIENT_NONCE_INVALID_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, NoServerNonce) { + // When no server nonce is present and no strike register is configured, + // the CHLO should be rejected. + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"NONP", nonce_hex_}, + {"XLCT", XlctHexString()}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + + // Even without a server nonce, this ClientHello should be accepted in + // version 33. + ASSERT_EQ(kSHLO, out_.tag()); + CheckServerHello(out_); +} + +TEST_P(CryptoServerTest, ProofForSuppliedServerConfig) { + client_address_ = QuicSocketAddress(QuicIpAddress::Loopback6(), 1234); + + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PDMD", "X509"}, + {"SCID", kOldConfigId}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"NONP", "123456789012345678901234567890"}, + {"VER\0", client_version_string_}, + {"XLCT", XlctHexString()}}, + kClientHelloMinimumSize); + + ShouldSucceed(msg); + // The message should be rejected because the source-address token is no + // longer valid. + CheckRejectTag(); + const HandshakeFailureReason kRejectReasons[] = { + SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); + + absl::string_view cert, proof, scfg_str; + EXPECT_TRUE(out_.GetStringPiece(kCertificateTag, &cert)); + EXPECT_TRUE(out_.GetStringPiece(kPROF, &proof)); + EXPECT_TRUE(out_.GetStringPiece(kSCFG, &scfg_str)); + std::unique_ptr scfg( + CryptoFramer::ParseMessage(scfg_str)); + absl::string_view scid; + EXPECT_TRUE(scfg->GetStringPiece(kSCID, &scid)); + EXPECT_NE(scid, kOldConfigId); + + // Get certs from compressed certs. + std::vector cached_certs; + + std::vector certs; + ASSERT_TRUE(CertCompressor::DecompressChain(cert, cached_certs, &certs)); + + // Check that the proof in the REJ message is valid. + std::unique_ptr proof_verifier( + crypto_test_utils::ProofVerifierForTesting()); + std::unique_ptr verify_context( + crypto_test_utils::ProofVerifyContextForTesting()); + std::unique_ptr details; + std::string error_details; + std::unique_ptr callback( + new DummyProofVerifierCallback()); + const std::string chlo_hash = + CryptoUtils::HashHandshakeMessage(msg, Perspective::IS_SERVER); + EXPECT_EQ(QUIC_SUCCESS, + proof_verifier->VerifyProof( + "test.example.com", 443, (std::string(scfg_str)), + client_version_.transport_version, chlo_hash, certs, "", + (std::string(proof)), verify_context.get(), &error_details, + &details, std::move(callback))); +} + +TEST_P(CryptoServerTest, RejectInvalidXlct) { + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"VER\0", client_version_string_}, + {"XLCT", "#0102030405060708"}}, + kClientHelloMinimumSize); + + // If replay protection isn't disabled, then + // QuicCryptoServerConfig::EvaluateClientHello will leave info.unique as false + // and cause ProcessClientHello to exit early (and generate a REJ message). + config_.set_replay_protection(false); + + ShouldSucceed(msg); + + const HandshakeFailureReason kRejectReasons[] = { + INVALID_EXPECTED_LEAF_CERTIFICATE}; + + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +TEST_P(CryptoServerTest, ValidXlct) { + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"VER\0", client_version_string_}, + {"XLCT", XlctHexString()}}, + kClientHelloMinimumSize); + + // If replay protection isn't disabled, then + // QuicCryptoServerConfig::EvaluateClientHello will leave info.unique as false + // and cause ProcessClientHello to exit early (and generate a REJ message). + config_.set_replay_protection(false); + + ShouldSucceed(msg); + EXPECT_EQ(kSHLO, out_.tag()); +} + +TEST_P(CryptoServerTest, NonceInSHLO) { + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"VER\0", client_version_string_}, + {"XLCT", XlctHexString()}}, + kClientHelloMinimumSize); + + // If replay protection isn't disabled, then + // QuicCryptoServerConfig::EvaluateClientHello will leave info.unique as false + // and cause ProcessClientHello to exit early (and generate a REJ message). + config_.set_replay_protection(false); + + ShouldSucceed(msg); + EXPECT_EQ(kSHLO, out_.tag()); + + absl::string_view nonce; + EXPECT_TRUE(out_.GetStringPiece(kServerNonceTag, &nonce)); +} + +TEST_P(CryptoServerTest, ProofSourceFailure) { + // Install a ProofSource which will unconditionally fail + peer_.ResetProofSource(std::unique_ptr(new FailingProofSource)); + + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"PDMD", "X509"}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + // Just ensure that we don't crash as occurred in b/33916924. + ShouldFailMentioning("", msg); +} + +// Regression test for crbug.com/723604 +// For 2RTT, if the first CHLO from the client contains hashes of cached +// certs (stored in CCRT tag) but the second CHLO does not, then the second REJ +// from the server should not contain hashes of cached certs. +TEST_P(CryptoServerTest, TwoRttServerDropCachedCerts) { + // Send inchoate CHLO to get cert chain from server. This CHLO is only for + // the purpose of getting the server's certs; it is not part of the 2RTT + // handshake. + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + ShouldSucceed(msg); + + // Decompress cert chain from server to individual certs. + absl::string_view certs_compressed; + ASSERT_TRUE(out_.GetStringPiece(kCertificateTag, &certs_compressed)); + ASSERT_NE(0u, certs_compressed.size()); + std::vector certs; + ASSERT_TRUE(CertCompressor::DecompressChain(certs_compressed, + /*cached_certs=*/{}, &certs)); + + // Start 2-RTT. Client sends CHLO with bad source-address token and hashes of + // the certs, which tells the server that the client has cached those certs. + config_.set_chlo_multiplier(1); + const char kBadSourceAddressToken[] = ""; + msg.SetStringPiece(kSourceAddressTokenTag, kBadSourceAddressToken); + std::vector hashes(certs.size()); + for (size_t i = 0; i < certs.size(); ++i) { + hashes[i] = QuicUtils::QuicUtils::FNV1a_64_Hash(certs[i]); + } + msg.SetVector(kCCRT, hashes); + ShouldSucceed(msg); + + // Server responds with inchoate REJ containing valid source-address token. + absl::string_view srct; + ASSERT_TRUE(out_.GetStringPiece(kSourceAddressTokenTag, &srct)); + + // Client now drops cached certs; sends CHLO with updated source-address + // token but no hashes of certs. + msg.SetStringPiece(kSourceAddressTokenTag, srct); + msg.Erase(kCCRT); + ShouldSucceed(msg); + + // Server response's cert chain should not contain hashes of + // previously-cached certs. + ASSERT_TRUE(out_.GetStringPiece(kCertificateTag, &certs_compressed)); + ASSERT_NE(0u, certs_compressed.size()); + ASSERT_TRUE(CertCompressor::DecompressChain(certs_compressed, + /*cached_certs=*/{}, &certs)); +} + +class CryptoServerConfigGenerationTest : public QuicTest {}; + +TEST_F(CryptoServerConfigGenerationTest, Determinism) { + // Test that using a deterministic PRNG causes the server-config to be + // deterministic. + + MockRandom rand_a, rand_b; + const QuicCryptoServerConfig::ConfigOptions options; + MockClock clock; + + QuicCryptoServerConfig a(QuicCryptoServerConfig::TESTING, &rand_a, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + QuicCryptoServerConfig b(QuicCryptoServerConfig::TESTING, &rand_b, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + std::unique_ptr scfg_a( + a.AddDefaultConfig(&rand_a, &clock, options)); + std::unique_ptr scfg_b( + b.AddDefaultConfig(&rand_b, &clock, options)); + + ASSERT_EQ(scfg_a->DebugString(), scfg_b->DebugString()); +} + +TEST_F(CryptoServerConfigGenerationTest, SCIDVaries) { + // This test ensures that the server config ID varies for different server + // configs. + + MockRandom rand_a, rand_b; + const QuicCryptoServerConfig::ConfigOptions options; + MockClock clock; + + QuicCryptoServerConfig a(QuicCryptoServerConfig::TESTING, &rand_a, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + rand_b.ChangeValue(); + QuicCryptoServerConfig b(QuicCryptoServerConfig::TESTING, &rand_b, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + std::unique_ptr scfg_a( + a.AddDefaultConfig(&rand_a, &clock, options)); + std::unique_ptr scfg_b( + b.AddDefaultConfig(&rand_b, &clock, options)); + + absl::string_view scid_a, scid_b; + EXPECT_TRUE(scfg_a->GetStringPiece(kSCID, &scid_a)); + EXPECT_TRUE(scfg_b->GetStringPiece(kSCID, &scid_b)); + + EXPECT_NE(scid_a, scid_b); +} + +TEST_F(CryptoServerConfigGenerationTest, SCIDIsHashOfServerConfig) { + MockRandom rand_a; + const QuicCryptoServerConfig::ConfigOptions options; + MockClock clock; + + QuicCryptoServerConfig a(QuicCryptoServerConfig::TESTING, &rand_a, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + std::unique_ptr scfg( + a.AddDefaultConfig(&rand_a, &clock, options)); + + absl::string_view scid; + EXPECT_TRUE(scfg->GetStringPiece(kSCID, &scid)); + // Need to take a copy of |scid| has we're about to call |Erase|. + const std::string scid_str(scid); + + scfg->Erase(kSCID); + scfg->MarkDirty(); + const QuicData& serialized(scfg->GetSerialized()); + + uint8_t digest[SHA256_DIGEST_LENGTH]; + SHA256(reinterpret_cast(serialized.data()), + serialized.length(), digest); + + // scid is a SHA-256 hash, truncated to 16 bytes. + ASSERT_EQ(scid.size(), 16u); + EXPECT_EQ(0, memcmp(digest, scid_str.c_str(), scid.size())); +} + +// Those tests were declared incorrectly and thus never ran in first place. +// TODO(b/147891553): figure out if we should fix or delete those. +#if 0 + +class CryptoServerTestNoConfig : public CryptoServerTest { + public: + void SetUp() override { + // Deliberately don't add a config so that we can test this situation. + } +}; + +INSTANTIATE_TEST_SUITE_P(CryptoServerTestsNoConfig, + CryptoServerTestNoConfig, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(CryptoServerTestNoConfig, DontCrash) { + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + ShouldFailMentioning("No config", msg); + + const HandshakeFailureReason kRejectReasons[] = { + SERVER_CONFIG_INCHOATE_HELLO_FAILURE}; + CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); +} + +class CryptoServerTestOldVersion : public CryptoServerTest { + public: + void SetUp() override { + client_version_ = supported_versions_.back(); + client_version_string_ = ParsedQuicVersionToString(client_version_); + CryptoServerTest::SetUp(); + } +}; + +INSTANTIATE_TEST_SUITE_P(CryptoServerTestsOldVersion, + CryptoServerTestOldVersion, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(CryptoServerTestOldVersion, ServerIgnoresXlct) { + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"VER\0", client_version_string_}, + {"XLCT", "#0100000000000000"}}, + kClientHelloMinimumSize); + + // If replay protection isn't disabled, then + // QuicCryptoServerConfig::EvaluateClientHello will leave info.unique as false + // and cause ProcessClientHello to exit early (and generate a REJ message). + config_.set_replay_protection(false); + + ShouldSucceed(msg); + EXPECT_EQ(kSHLO, out_.tag()); +} + +TEST_P(CryptoServerTestOldVersion, XlctNotRequired) { + CryptoHandshakeMessage msg = + crypto_test_utils::CreateCHLO({{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"SCID", scid_hex_}, + {"#004b5453", srct_hex_}, + {"PUBS", pub_hex_}, + {"NONC", nonce_hex_}, + {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + + // If replay protection isn't disabled, then + // QuicCryptoServerConfig::EvaluateClientHello will leave info.unique as false + // and cause ProcessClientHello to exit early (and generate a REJ message). + config_.set_replay_protection(false); + + ShouldSucceed(msg); + EXPECT_EQ(kSHLO, out_.tag()); +} + +#endif // 0 + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_utils.cc b/quiche/quic/core/crypto/crypto_utils.cc new file mode 100644 index 000000000000..01fb68196902 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_utils.cc @@ -0,0 +1,812 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_utils.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/bytestring.h" +#include "openssl/hkdf.h" +#include "openssl/mem.h" +#include "openssl/sha.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" +#include "quiche/quic/core/crypto/aes_128_gcm_decrypter.h" +#include "quiche/quic/core/crypto/aes_128_gcm_encrypter.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_hkdf.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +namespace { + +// Implements the HKDF-Expand-Label function as defined in section 7.1 of RFC +// 8446. The HKDF-Expand-Label function takes 4 explicit arguments (Secret, +// Label, Context, and Length), as well as implicit PRF which is the hash +// function negotiated by TLS. Its use in QUIC (as needed by the QUIC stack, +// instead of as used internally by the TLS stack) is only for deriving initial +// secrets for obfuscation, for calculating packet protection keys and IVs from +// the corresponding packet protection secret and key update in the same quic +// session. None of these uses need a Context (a zero-length context is +// provided), so this argument is omitted here. +// +// The implicit PRF is explicitly passed into HkdfExpandLabel as |prf|; the +// Secret, Label, and Length are passed in as |secret|, |label|, and +// |out_len|, respectively. The resulting expanded secret is returned. +std::vector HkdfExpandLabel(const EVP_MD* prf, + absl::Span secret, + const std::string& label, size_t out_len) { + bssl::ScopedCBB quic_hkdf_label; + CBB inner_label; + const char label_prefix[] = "tls13 "; + // 20 = size(u16) + size(u8) + len("tls13 ") + + // max_len("client in", "server in", "quicv2 key", ... ) + + // size(u8); + static const size_t max_quic_hkdf_label_length = 20; + if (!CBB_init(quic_hkdf_label.get(), max_quic_hkdf_label_length) || + !CBB_add_u16(quic_hkdf_label.get(), out_len) || + !CBB_add_u8_length_prefixed(quic_hkdf_label.get(), &inner_label) || + !CBB_add_bytes(&inner_label, + reinterpret_cast(label_prefix), + ABSL_ARRAYSIZE(label_prefix) - 1) || + !CBB_add_bytes(&inner_label, + reinterpret_cast(label.data()), + label.size()) || + // Zero length |Context|. + !CBB_add_u8(quic_hkdf_label.get(), 0) || + !CBB_flush(quic_hkdf_label.get())) { + QUIC_LOG(ERROR) << "Building HKDF label failed"; + return std::vector(); + } + std::vector out; + out.resize(out_len); + if (!HKDF_expand(out.data(), out_len, prf, secret.data(), secret.size(), + CBB_data(quic_hkdf_label.get()), + CBB_len(quic_hkdf_label.get()))) { + QUIC_LOG(ERROR) << "Running HKDF-Expand-Label failed"; + return std::vector(); + } + return out; +} + +} // namespace + +const std::string getLabelForVersion(const ParsedQuicVersion& version, + const absl::string_view& predicate) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync with HKDF labels"); + if (version == ParsedQuicVersion::V2Draft08()) { + return absl::StrCat("quicv2 ", predicate); + } else { + return absl::StrCat("quic ", predicate); + } +} + +void CryptoUtils::InitializeCrypterSecrets( + const EVP_MD* prf, const std::vector& pp_secret, + const ParsedQuicVersion& version, QuicCrypter* crypter) { + SetKeyAndIV(prf, pp_secret, version, crypter); + std::vector header_protection_key = GenerateHeaderProtectionKey( + prf, pp_secret, version, crypter->GetKeySize()); + crypter->SetHeaderProtectionKey( + absl::string_view(reinterpret_cast(header_protection_key.data()), + header_protection_key.size())); +} + +void CryptoUtils::SetKeyAndIV(const EVP_MD* prf, + absl::Span pp_secret, + const ParsedQuicVersion& version, + QuicCrypter* crypter) { + std::vector key = + HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "key"), + crypter->GetKeySize()); + std::vector iv = HkdfExpandLabel( + prf, pp_secret, getLabelForVersion(version, "iv"), crypter->GetIVSize()); + crypter->SetKey( + absl::string_view(reinterpret_cast(key.data()), key.size())); + crypter->SetIV( + absl::string_view(reinterpret_cast(iv.data()), iv.size())); +} + +std::vector CryptoUtils::GenerateHeaderProtectionKey( + const EVP_MD* prf, absl::Span pp_secret, + const ParsedQuicVersion& version, size_t out_len) { + return HkdfExpandLabel(prf, pp_secret, getLabelForVersion(version, "hp"), + out_len); +} + +std::vector CryptoUtils::GenerateNextKeyPhaseSecret( + const EVP_MD* prf, const ParsedQuicVersion& version, + const std::vector& current_secret) { + return HkdfExpandLabel(prf, current_secret, getLabelForVersion(version, "ku"), + current_secret.size()); +} + +namespace { + +// Salt from https://tools.ietf.org/html/draft-ietf-quic-tls-29#section-5.2 +const uint8_t kDraft29InitialSalt[] = {0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, + 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, + 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}; +const uint8_t kRFCv1InitialSalt[] = {0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, + 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, + 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}; +const uint8_t kV2Draft08InitialSalt[] = { + 0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, + 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9, +}; + +// Salts used by deployed versions of QUIC. When introducing a new version, +// generate a new salt by running `openssl rand -hex 20`. + +// Salt to use for initial obfuscators in version Q050. +const uint8_t kQ050Salt[] = {0x50, 0x45, 0x74, 0xef, 0xd0, 0x66, 0xfe, + 0x2f, 0x9d, 0x94, 0x5c, 0xfc, 0xdb, 0xd3, + 0xa7, 0xf0, 0xd3, 0xb5, 0x6b, 0x45}; +// Salt to use for initial obfuscators in +// ParsedQuicVersion::ReservedForNegotiation(). +const uint8_t kReservedForNegotiationSalt[] = { + 0xf9, 0x64, 0xbf, 0x45, 0x3a, 0x1f, 0x1b, 0x80, 0xa5, 0xf8, + 0x82, 0x03, 0x77, 0xd4, 0xaf, 0xca, 0x58, 0x0e, 0xe7, 0x43}; + +const uint8_t* InitialSaltForVersion(const ParsedQuicVersion& version, + size_t* out_len) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync with initial encryption salts"); + if (version == ParsedQuicVersion::V2Draft08()) { + *out_len = ABSL_ARRAYSIZE(kV2Draft08InitialSalt); + return kV2Draft08InitialSalt; + } else if (version == ParsedQuicVersion::RFCv1()) { + *out_len = ABSL_ARRAYSIZE(kRFCv1InitialSalt); + return kRFCv1InitialSalt; + } else if (version == ParsedQuicVersion::Draft29()) { + *out_len = ABSL_ARRAYSIZE(kDraft29InitialSalt); + return kDraft29InitialSalt; + } else if (version == ParsedQuicVersion::Q050()) { + *out_len = ABSL_ARRAYSIZE(kQ050Salt); + return kQ050Salt; + } else if (version == ParsedQuicVersion::ReservedForNegotiation()) { + *out_len = ABSL_ARRAYSIZE(kReservedForNegotiationSalt); + return kReservedForNegotiationSalt; + } + QUIC_BUG(quic_bug_10699_1) + << "No initial obfuscation salt for version " << version; + *out_len = ABSL_ARRAYSIZE(kReservedForNegotiationSalt); + return kReservedForNegotiationSalt; +} + +const char kPreSharedKeyLabel[] = "QUIC PSK"; + +// Retry Integrity Protection Keys and Nonces. +// https://tools.ietf.org/html/draft-ietf-quic-tls-29#section-5.8 +// When introducing a new Google version, generate a new key by running +// `openssl rand -hex 16`. +const uint8_t kDraft29RetryIntegrityKey[] = {0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, + 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, + 0x6c, 0xb9, 0x6b, 0xe1}; +const uint8_t kDraft29RetryIntegrityNonce[] = { + 0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c}; +const uint8_t kRFCv1RetryIntegrityKey[] = {0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, + 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, + 0xe3, 0x68, 0xc8, 0x4e}; +const uint8_t kRFCv1RetryIntegrityNonce[] = { + 0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}; +const uint8_t kV2Draft08RetryIntegrityKey[] = { + 0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, + 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92}; +const uint8_t kV2Draft08RetryIntegrityNonce[] = { + 0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}; +// Retry integrity key used by ParsedQuicVersion::ReservedForNegotiation(). +const uint8_t kReservedForNegotiationRetryIntegrityKey[] = { + 0xf2, 0xcd, 0x8f, 0xe0, 0x36, 0xd0, 0x25, 0x35, + 0x03, 0xe6, 0x7c, 0x7b, 0xd2, 0x44, 0xca, 0xd9}; +// When introducing a new Google version, generate a new nonce by running +// `openssl rand -hex 12`. +// Retry integrity nonce used by ParsedQuicVersion::ReservedForNegotiation(). +const uint8_t kReservedForNegotiationRetryIntegrityNonce[] = { + 0x35, 0x9f, 0x16, 0xd1, 0xed, 0x80, 0x90, 0x8e, 0xec, 0x85, 0xc4, 0xd6}; + +bool RetryIntegrityKeysForVersion(const ParsedQuicVersion& version, + absl::string_view* key, + absl::string_view* nonce) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync with retry integrity keys"); + if (!version.UsesTls()) { + QUIC_BUG(quic_bug_10699_2) + << "Attempted to get retry integrity keys for invalid version " + << version; + return false; + } else if (version == ParsedQuicVersion::V2Draft08()) { + *key = absl::string_view( + reinterpret_cast(kV2Draft08RetryIntegrityKey), + ABSL_ARRAYSIZE(kV2Draft08RetryIntegrityKey)); + *nonce = absl::string_view( + reinterpret_cast(kV2Draft08RetryIntegrityNonce), + ABSL_ARRAYSIZE(kV2Draft08RetryIntegrityNonce)); + return true; + } else if (version == ParsedQuicVersion::RFCv1()) { + *key = absl::string_view( + reinterpret_cast(kRFCv1RetryIntegrityKey), + ABSL_ARRAYSIZE(kRFCv1RetryIntegrityKey)); + *nonce = absl::string_view( + reinterpret_cast(kRFCv1RetryIntegrityNonce), + ABSL_ARRAYSIZE(kRFCv1RetryIntegrityNonce)); + return true; + } else if (version == ParsedQuicVersion::Draft29()) { + *key = absl::string_view( + reinterpret_cast(kDraft29RetryIntegrityKey), + ABSL_ARRAYSIZE(kDraft29RetryIntegrityKey)); + *nonce = absl::string_view( + reinterpret_cast(kDraft29RetryIntegrityNonce), + ABSL_ARRAYSIZE(kDraft29RetryIntegrityNonce)); + return true; + } else if (version == ParsedQuicVersion::ReservedForNegotiation()) { + *key = absl::string_view( + reinterpret_cast(kReservedForNegotiationRetryIntegrityKey), + ABSL_ARRAYSIZE(kReservedForNegotiationRetryIntegrityKey)); + *nonce = absl::string_view( + reinterpret_cast( + kReservedForNegotiationRetryIntegrityNonce), + ABSL_ARRAYSIZE(kReservedForNegotiationRetryIntegrityNonce)); + return true; + } + QUIC_BUG(quic_bug_10699_3) + << "Attempted to get retry integrity keys for version " << version; + return false; +} + +} // namespace + +// static +void CryptoUtils::CreateInitialObfuscators(Perspective perspective, + ParsedQuicVersion version, + QuicConnectionId connection_id, + CrypterPair* crypters) { + QUIC_DLOG(INFO) << "Creating " + << (perspective == Perspective::IS_CLIENT ? "client" + : "server") + << " crypters for version " << version << " with CID " + << connection_id; + if (!version.UsesInitialObfuscators()) { + crypters->encrypter = std::make_unique(perspective); + crypters->decrypter = std::make_unique(perspective); + return; + } + QUIC_BUG_IF(quic_bug_12871_1, !QuicUtils::IsConnectionIdValidForVersion( + connection_id, version.transport_version)) + << "CreateTlsInitialCrypters: attempted to use connection ID " + << connection_id << " which is invalid with version " << version; + const EVP_MD* hash = EVP_sha256(); + + size_t salt_len; + const uint8_t* salt = InitialSaltForVersion(version, &salt_len); + std::vector handshake_secret; + handshake_secret.resize(EVP_MAX_MD_SIZE); + size_t handshake_secret_len; + const bool hkdf_extract_success = + HKDF_extract(handshake_secret.data(), &handshake_secret_len, hash, + reinterpret_cast(connection_id.data()), + connection_id.length(), salt, salt_len); + QUIC_BUG_IF(quic_bug_12871_2, !hkdf_extract_success) + << "HKDF_extract failed when creating initial crypters"; + handshake_secret.resize(handshake_secret_len); + + const std::string client_label = "client in"; + const std::string server_label = "server in"; + std::string encryption_label, decryption_label; + if (perspective == Perspective::IS_CLIENT) { + encryption_label = client_label; + decryption_label = server_label; + } else { + encryption_label = server_label; + decryption_label = client_label; + } + std::vector encryption_secret = HkdfExpandLabel( + hash, handshake_secret, encryption_label, EVP_MD_size(hash)); + crypters->encrypter = std::make_unique(); + InitializeCrypterSecrets(hash, encryption_secret, version, + crypters->encrypter.get()); + + std::vector decryption_secret = HkdfExpandLabel( + hash, handshake_secret, decryption_label, EVP_MD_size(hash)); + crypters->decrypter = std::make_unique(); + InitializeCrypterSecrets(hash, decryption_secret, version, + crypters->decrypter.get()); +} + +// static +bool CryptoUtils::ValidateRetryIntegrityTag( + ParsedQuicVersion version, QuicConnectionId original_connection_id, + absl::string_view retry_without_tag, absl::string_view integrity_tag) { + unsigned char computed_integrity_tag[kRetryIntegrityTagLength]; + if (integrity_tag.length() != ABSL_ARRAYSIZE(computed_integrity_tag)) { + QUIC_BUG(quic_bug_10699_4) + << "Invalid retry integrity tag length " << integrity_tag.length(); + return false; + } + char retry_pseudo_packet[kMaxIncomingPacketSize + 256]; + QuicDataWriter writer(ABSL_ARRAYSIZE(retry_pseudo_packet), + retry_pseudo_packet); + if (!writer.WriteLengthPrefixedConnectionId(original_connection_id)) { + QUIC_BUG(quic_bug_10699_5) + << "Failed to write original connection ID in retry pseudo packet"; + return false; + } + if (!writer.WriteStringPiece(retry_without_tag)) { + QUIC_BUG(quic_bug_10699_6) + << "Failed to write retry without tag in retry pseudo packet"; + return false; + } + absl::string_view key; + absl::string_view nonce; + if (!RetryIntegrityKeysForVersion(version, &key, &nonce)) { + // RetryIntegrityKeysForVersion already logs failures. + return false; + } + Aes128GcmEncrypter crypter; + crypter.SetKey(key); + absl::string_view associated_data(writer.data(), writer.length()); + absl::string_view plaintext; // Plaintext is empty. + if (!crypter.Encrypt(nonce, associated_data, plaintext, + computed_integrity_tag)) { + QUIC_BUG(quic_bug_10699_7) << "Failed to compute retry integrity tag"; + return false; + } + if (CRYPTO_memcmp(computed_integrity_tag, integrity_tag.data(), + ABSL_ARRAYSIZE(computed_integrity_tag)) != 0) { + QUIC_DLOG(ERROR) << "Failed to validate retry integrity tag"; + return false; + } + return true; +} + +// static +void CryptoUtils::GenerateNonce(QuicWallTime now, QuicRandom* random_generator, + absl::string_view orbit, std::string* nonce) { + // a 4-byte timestamp + 28 random bytes. + nonce->reserve(kNonceSize); + nonce->resize(kNonceSize); + + uint32_t gmt_unix_time = static_cast(now.ToUNIXSeconds()); + // The time in the nonce must be encoded in big-endian because the + // strike-register depends on the nonces being ordered by time. + (*nonce)[0] = static_cast(gmt_unix_time >> 24); + (*nonce)[1] = static_cast(gmt_unix_time >> 16); + (*nonce)[2] = static_cast(gmt_unix_time >> 8); + (*nonce)[3] = static_cast(gmt_unix_time); + size_t bytes_written = 4; + + if (orbit.size() == 8) { + memcpy(&(*nonce)[bytes_written], orbit.data(), orbit.size()); + bytes_written += orbit.size(); + } + + random_generator->RandBytes(&(*nonce)[bytes_written], + kNonceSize - bytes_written); +} + +// static +bool CryptoUtils::DeriveKeys( + const ParsedQuicVersion& version, absl::string_view premaster_secret, + QuicTag aead, absl::string_view client_nonce, + absl::string_view server_nonce, absl::string_view pre_shared_key, + const std::string& hkdf_input, Perspective perspective, + Diversification diversification, CrypterPair* crypters, + std::string* subkey_secret) { + // If the connection is using PSK, concatenate it with the pre-master secret. + std::unique_ptr psk_premaster_secret; + if (!pre_shared_key.empty()) { + const absl::string_view label(kPreSharedKeyLabel); + const size_t psk_premaster_secret_size = label.size() + 1 + + pre_shared_key.size() + 8 + + premaster_secret.size() + 8; + + psk_premaster_secret = std::make_unique(psk_premaster_secret_size); + QuicDataWriter writer(psk_premaster_secret_size, psk_premaster_secret.get(), + quiche::HOST_BYTE_ORDER); + + if (!writer.WriteStringPiece(label) || !writer.WriteUInt8(0) || + !writer.WriteStringPiece(pre_shared_key) || + !writer.WriteUInt64(pre_shared_key.size()) || + !writer.WriteStringPiece(premaster_secret) || + !writer.WriteUInt64(premaster_secret.size()) || + writer.remaining() != 0) { + return false; + } + + premaster_secret = absl::string_view(psk_premaster_secret.get(), + psk_premaster_secret_size); + } + + crypters->encrypter = QuicEncrypter::Create(version, aead); + crypters->decrypter = QuicDecrypter::Create(version, aead); + + size_t key_bytes = crypters->encrypter->GetKeySize(); + size_t nonce_prefix_bytes = crypters->encrypter->GetNoncePrefixSize(); + if (version.UsesInitialObfuscators()) { + nonce_prefix_bytes = crypters->encrypter->GetIVSize(); + } + size_t subkey_secret_bytes = + subkey_secret == nullptr ? 0 : premaster_secret.length(); + + absl::string_view nonce = client_nonce; + std::string nonce_storage; + if (!server_nonce.empty()) { + nonce_storage = std::string(client_nonce) + std::string(server_nonce); + nonce = nonce_storage; + } + + QuicHKDF hkdf(premaster_secret, nonce, hkdf_input, key_bytes, + nonce_prefix_bytes, subkey_secret_bytes); + + // Key derivation depends on the key diversification method being employed. + // both the client and the server support never doing key diversification. + // The server also supports immediate diversification, and the client + // supports pending diversification. + switch (diversification.mode()) { + case Diversification::NEVER: { + if (perspective == Perspective::IS_SERVER) { + if (!crypters->encrypter->SetKey(hkdf.server_write_key()) || + !crypters->encrypter->SetNoncePrefixOrIV(version, + hkdf.server_write_iv()) || + !crypters->encrypter->SetHeaderProtectionKey( + hkdf.server_hp_key()) || + !crypters->decrypter->SetKey(hkdf.client_write_key()) || + !crypters->decrypter->SetNoncePrefixOrIV(version, + hkdf.client_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey( + hkdf.client_hp_key())) { + return false; + } + } else { + if (!crypters->encrypter->SetKey(hkdf.client_write_key()) || + !crypters->encrypter->SetNoncePrefixOrIV(version, + hkdf.client_write_iv()) || + !crypters->encrypter->SetHeaderProtectionKey( + hkdf.client_hp_key()) || + !crypters->decrypter->SetKey(hkdf.server_write_key()) || + !crypters->decrypter->SetNoncePrefixOrIV(version, + hkdf.server_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey( + hkdf.server_hp_key())) { + return false; + } + } + break; + } + case Diversification::PENDING: { + if (perspective == Perspective::IS_SERVER) { + QUIC_BUG(quic_bug_10699_8) + << "Pending diversification is only for clients."; + return false; + } + + if (!crypters->encrypter->SetKey(hkdf.client_write_key()) || + !crypters->encrypter->SetNoncePrefixOrIV(version, + hkdf.client_write_iv()) || + !crypters->encrypter->SetHeaderProtectionKey(hkdf.client_hp_key()) || + !crypters->decrypter->SetPreliminaryKey(hkdf.server_write_key()) || + !crypters->decrypter->SetNoncePrefixOrIV(version, + hkdf.server_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey(hkdf.server_hp_key())) { + return false; + } + break; + } + case Diversification::NOW: { + if (perspective == Perspective::IS_CLIENT) { + QUIC_BUG(quic_bug_10699_9) + << "Immediate diversification is only for servers."; + return false; + } + + std::string key, nonce_prefix; + QuicDecrypter::DiversifyPreliminaryKey( + hkdf.server_write_key(), hkdf.server_write_iv(), + *diversification.nonce(), key_bytes, nonce_prefix_bytes, &key, + &nonce_prefix); + if (!crypters->decrypter->SetKey(hkdf.client_write_key()) || + !crypters->decrypter->SetNoncePrefixOrIV(version, + hkdf.client_write_iv()) || + !crypters->decrypter->SetHeaderProtectionKey(hkdf.client_hp_key()) || + !crypters->encrypter->SetKey(key) || + !crypters->encrypter->SetNoncePrefixOrIV(version, nonce_prefix) || + !crypters->encrypter->SetHeaderProtectionKey(hkdf.server_hp_key())) { + return false; + } + break; + } + default: + QUICHE_DCHECK(false); + } + + if (subkey_secret != nullptr) { + *subkey_secret = std::string(hkdf.subkey_secret()); + } + + return true; +} + +// static +uint64_t CryptoUtils::ComputeLeafCertHash(absl::string_view cert) { + return QuicUtils::FNV1a_64_Hash(cert); +} + +QuicErrorCode CryptoUtils::ValidateServerHello( + const CryptoHandshakeMessage& server_hello, + const ParsedQuicVersionVector& negotiated_versions, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + + if (server_hello.tag() != kSHLO) { + *error_details = "Bad tag"; + return QUIC_INVALID_CRYPTO_MESSAGE_TYPE; + } + + QuicVersionLabelVector supported_version_labels; + if (server_hello.GetVersionLabelList(kVER, &supported_version_labels) != + QUIC_NO_ERROR) { + *error_details = "server hello missing version list"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + return ValidateServerHelloVersions(supported_version_labels, + negotiated_versions, error_details); +} + +QuicErrorCode CryptoUtils::ValidateServerHelloVersions( + const QuicVersionLabelVector& server_versions, + const ParsedQuicVersionVector& negotiated_versions, + std::string* error_details) { + if (!negotiated_versions.empty()) { + bool mismatch = server_versions.size() != negotiated_versions.size(); + for (size_t i = 0; i < server_versions.size() && !mismatch; ++i) { + mismatch = + server_versions[i] != CreateQuicVersionLabel(negotiated_versions[i]); + } + // The server sent a list of supported versions, and the connection + // reports that there was a version negotiation during the handshake. + // Ensure that these two lists are identical. + if (mismatch) { + *error_details = absl::StrCat( + "Downgrade attack detected: ServerVersions(", server_versions.size(), + ")[", QuicVersionLabelVectorToString(server_versions, ",", 30), + "] NegotiatedVersions(", negotiated_versions.size(), ")[", + ParsedQuicVersionVectorToString(negotiated_versions, ",", 30), "]"); + return QUIC_VERSION_NEGOTIATION_MISMATCH; + } + } + return QUIC_NO_ERROR; +} + +QuicErrorCode CryptoUtils::ValidateClientHello( + const CryptoHandshakeMessage& client_hello, ParsedQuicVersion version, + const ParsedQuicVersionVector& supported_versions, + std::string* error_details) { + if (client_hello.tag() != kCHLO) { + *error_details = "Bad tag"; + return QUIC_INVALID_CRYPTO_MESSAGE_TYPE; + } + + // If the client's preferred version is not the version we are currently + // speaking, then the client went through a version negotiation. In this + // case, we need to make sure that we actually do not support this version + // and that it wasn't a downgrade attack. + QuicVersionLabel client_version_label; + if (client_hello.GetVersionLabel(kVER, &client_version_label) != + QUIC_NO_ERROR) { + *error_details = "client hello missing version list"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + return ValidateClientHelloVersion(client_version_label, version, + supported_versions, error_details); +} + +QuicErrorCode CryptoUtils::ValidateClientHelloVersion( + QuicVersionLabel client_version, ParsedQuicVersion connection_version, + const ParsedQuicVersionVector& supported_versions, + std::string* error_details) { + if (client_version != CreateQuicVersionLabel(connection_version)) { + // Check to see if |client_version| is actually on the supported versions + // list. If not, the server doesn't support that version and it's not a + // downgrade attack. + for (size_t i = 0; i < supported_versions.size(); ++i) { + if (client_version == CreateQuicVersionLabel(supported_versions[i])) { + *error_details = absl::StrCat( + "Downgrade attack detected: ClientVersion[", + QuicVersionLabelToString(client_version), "] ConnectionVersion[", + ParsedQuicVersionToString(connection_version), + "] SupportedVersions(", supported_versions.size(), ")[", + ParsedQuicVersionVectorToString(supported_versions, ",", 30), "]"); + return QUIC_VERSION_NEGOTIATION_MISMATCH; + } + } + } + return QUIC_NO_ERROR; +} + +// static +bool CryptoUtils::ValidateChosenVersion( + const QuicVersionLabel& version_information_chosen_version, + const ParsedQuicVersion& session_version, std::string* error_details) { + if (version_information_chosen_version != + CreateQuicVersionLabel(session_version)) { + *error_details = absl::StrCat( + "Detected version mismatch: version_information contained ", + QuicVersionLabelToString(version_information_chosen_version), + " instead of ", ParsedQuicVersionToString(session_version)); + return false; + } + return true; +} + +// static +bool CryptoUtils::ValidateServerVersions( + const QuicVersionLabelVector& version_information_other_versions, + const ParsedQuicVersion& session_version, + const ParsedQuicVersionVector& client_original_supported_versions, + std::string* error_details) { + if (client_original_supported_versions.empty()) { + // We did not receive a version negotiation packet. + return true; + } + // Parse the server's other versions. + ParsedQuicVersionVector parsed_other_versions = + ParseQuicVersionLabelVector(version_information_other_versions); + // Find the first version that we originally supported that is listed in the + // server's other versions. + ParsedQuicVersion expected_version = ParsedQuicVersion::Unsupported(); + for (const ParsedQuicVersion& client_version : + client_original_supported_versions) { + if (std::find(parsed_other_versions.begin(), parsed_other_versions.end(), + client_version) != parsed_other_versions.end()) { + expected_version = client_version; + break; + } + } + if (expected_version != session_version) { + *error_details = absl::StrCat( + "Downgrade attack detected: used ", + ParsedQuicVersionToString(session_version), " but ServerVersions(", + version_information_other_versions.size(), ")[", + QuicVersionLabelVectorToString(version_information_other_versions, ",", + 30), + "] ClientOriginalVersions(", client_original_supported_versions.size(), + ")[", + ParsedQuicVersionVectorToString(client_original_supported_versions, ",", + 30), + "]"); + return false; + } + return true; +} + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x + +// Returns the name of the HandshakeFailureReason as a char* +// static +const char* CryptoUtils::HandshakeFailureReasonToString( + HandshakeFailureReason reason) { + switch (reason) { + RETURN_STRING_LITERAL(HANDSHAKE_OK); + RETURN_STRING_LITERAL(CLIENT_NONCE_UNKNOWN_FAILURE); + RETURN_STRING_LITERAL(CLIENT_NONCE_INVALID_FAILURE); + RETURN_STRING_LITERAL(CLIENT_NONCE_NOT_UNIQUE_FAILURE); + RETURN_STRING_LITERAL(CLIENT_NONCE_INVALID_ORBIT_FAILURE); + RETURN_STRING_LITERAL(CLIENT_NONCE_INVALID_TIME_FAILURE); + RETURN_STRING_LITERAL(CLIENT_NONCE_STRIKE_REGISTER_TIMEOUT); + RETURN_STRING_LITERAL(CLIENT_NONCE_STRIKE_REGISTER_FAILURE); + + RETURN_STRING_LITERAL(SERVER_NONCE_DECRYPTION_FAILURE); + RETURN_STRING_LITERAL(SERVER_NONCE_INVALID_FAILURE); + RETURN_STRING_LITERAL(SERVER_NONCE_NOT_UNIQUE_FAILURE); + RETURN_STRING_LITERAL(SERVER_NONCE_INVALID_TIME_FAILURE); + RETURN_STRING_LITERAL(SERVER_NONCE_REQUIRED_FAILURE); + + RETURN_STRING_LITERAL(SERVER_CONFIG_INCHOATE_HELLO_FAILURE); + RETURN_STRING_LITERAL(SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE); + + RETURN_STRING_LITERAL(SOURCE_ADDRESS_TOKEN_INVALID_FAILURE); + RETURN_STRING_LITERAL(SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE); + RETURN_STRING_LITERAL(SOURCE_ADDRESS_TOKEN_PARSE_FAILURE); + RETURN_STRING_LITERAL(SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE); + RETURN_STRING_LITERAL(SOURCE_ADDRESS_TOKEN_CLOCK_SKEW_FAILURE); + RETURN_STRING_LITERAL(SOURCE_ADDRESS_TOKEN_EXPIRED_FAILURE); + + RETURN_STRING_LITERAL(INVALID_EXPECTED_LEAF_CERTIFICATE); + RETURN_STRING_LITERAL(MAX_FAILURE_REASON); + } + // Return a default value so that we return this when |reason| doesn't match + // any HandshakeFailureReason.. This can happen when the message by the peer + // (attacker) has invalid reason. + return "INVALID_HANDSHAKE_FAILURE_REASON"; +} + +#undef RETURN_STRING_LITERAL // undef for jumbo builds + +// static +std::string CryptoUtils::EarlyDataReasonToString( + ssl_early_data_reason_t reason) { + const char* reason_string = SSL_early_data_reason_string(reason); + if (reason_string != nullptr) { + return std::string("ssl_early_data_") + reason_string; + } + QUIC_BUG_IF(quic_bug_12871_3, + reason < 0 || reason > ssl_early_data_reason_max_value) + << "Unknown ssl_early_data_reason_t " << reason; + return "unknown ssl_early_data_reason_t"; +} + +// static +std::string CryptoUtils::HashHandshakeMessage( + const CryptoHandshakeMessage& message, Perspective /*perspective*/) { + std::string output; + const QuicData& serialized = message.GetSerialized(); + uint8_t digest[SHA256_DIGEST_LENGTH]; + SHA256(reinterpret_cast(serialized.data()), + serialized.length(), digest); + output.assign(reinterpret_cast(digest), sizeof(digest)); + return output; +} + +// static +bool CryptoUtils::GetSSLCapabilities(const SSL* ssl, + bssl::UniquePtr* capabilities, + size_t* capabilities_len) { + uint8_t* buffer; + bssl::ScopedCBB cbb; + + if (!CBB_init(cbb.get(), 128) || + !SSL_serialize_capabilities(ssl, cbb.get()) || + !CBB_finish(cbb.get(), &buffer, capabilities_len)) { + return false; + } + + *capabilities = bssl::UniquePtr(buffer); + return true; +} + +// static +absl::optional CryptoUtils::GenerateProofPayloadToBeSigned( + absl::string_view chlo_hash, absl::string_view server_config) { + size_t payload_size = sizeof(kProofSignatureLabel) + sizeof(uint32_t) + + chlo_hash.size() + server_config.size(); + std::string payload; + payload.resize(payload_size); + QuicDataWriter payload_writer(payload_size, payload.data(), + quiche::Endianness::HOST_BYTE_ORDER); + bool success = payload_writer.WriteBytes(kProofSignatureLabel, + sizeof(kProofSignatureLabel)) && + payload_writer.WriteUInt32(chlo_hash.size()) && + payload_writer.WriteStringPiece(chlo_hash) && + payload_writer.WriteStringPiece(server_config); + if (!success) { + return absl::nullopt; + } + return payload; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/crypto_utils.h b/quiche/quic/core/crypto/crypto_utils.h new file mode 100644 index 000000000000..adc818dbda52 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_utils.h @@ -0,0 +1,259 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Some helpers for quic crypto + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CRYPTO_UTILS_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CRYPTO_UTILS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/evp.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/quic_crypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE CryptoUtils { + public: + CryptoUtils() = delete; + + // Diversification is a utility class that's used to act like a union type. + // Values can be created by calling the functions like |NoDiversification|, + // below. + class QUIC_EXPORT_PRIVATE Diversification { + public: + enum Mode { + NEVER, // Key diversification will never be used. Forward secure + // crypters will always use this mode. + + PENDING, // Key diversification will happen when a nonce is later + // received. This should only be used by clients initial + // decrypters which are waiting on the divesification nonce + // from the server. + + NOW, // Key diversification will happen immediate based on the nonce. + // This should only be used by servers initial encrypters. + }; + + Diversification(const Diversification& diversification) = default; + + static Diversification Never() { return Diversification(NEVER, nullptr); } + static Diversification Pending() { + return Diversification(PENDING, nullptr); + } + static Diversification Now(DiversificationNonce* nonce) { + return Diversification(NOW, nonce); + } + + Mode mode() const { return mode_; } + DiversificationNonce* nonce() const { + QUICHE_DCHECK_EQ(mode_, NOW); + return nonce_; + } + + private: + Diversification(Mode mode, DiversificationNonce* nonce) + : mode_(mode), nonce_(nonce) {} + + Mode mode_; + DiversificationNonce* nonce_; + }; + + // InitializeCrypterSecrets derives the key and IV and header protection key + // from the given packet protection secret |pp_secret| and sets those fields + // on the given QuicCrypter |*crypter|. + // This follows the derivation described in section 7.3 of RFC 8446, except + // with the label prefix in HKDF-Expand-Label changed from "tls13 " to "quic " + // as described in draft-ietf-quic-tls-14, section 5.1, or "quicv2 " as + // described in draft-ietf-quic-v2-01. + static void InitializeCrypterSecrets(const EVP_MD* prf, + const std::vector& pp_secret, + const ParsedQuicVersion& version, + QuicCrypter* crypter); + + // Derives the key and IV from the packet protection secret and sets those + // fields on the given QuicCrypter |*crypter|, but does not set the header + // protection key. GenerateHeaderProtectionKey/SetHeaderProtectionKey must be + // called before using |crypter|. + static void SetKeyAndIV(const EVP_MD* prf, + absl::Span pp_secret, + const ParsedQuicVersion& version, + QuicCrypter* crypter); + + // Derives the header protection key from the packet protection secret. + static std::vector GenerateHeaderProtectionKey( + const EVP_MD* prf, absl::Span pp_secret, + const ParsedQuicVersion& version, size_t out_len); + + // Given a secret for key phase n, return the secret for phase n+1. + static std::vector GenerateNextKeyPhaseSecret( + const EVP_MD* prf, const ParsedQuicVersion& version, + const std::vector& current_secret); + + // IETF QUIC encrypts ENCRYPTION_INITIAL messages with a version-specific key + // (to prevent network observers that are not aware of that QUIC version from + // making decisions based on the TLS handshake). This packet protection secret + // is derived from the connection ID in the client's Initial packet. + // + // This function takes that |connection_id| and creates the encrypter and + // decrypter (put in |*crypters|) to use for this packet protection, as well + // as setting the key and IV on those crypters. For older versions of QUIC + // that do not use the new IETF style ENCRYPTION_INITIAL obfuscators, this + // function puts a NullEncrypter and NullDecrypter in |*crypters|. + static void CreateInitialObfuscators(Perspective perspective, + ParsedQuicVersion version, + QuicConnectionId connection_id, + CrypterPair* crypters); + + // IETF QUIC Retry packets carry a retry integrity tag to detect packet + // corruption and make it harder for an attacker to spoof. This function + // checks whether a given retry packet is valid. + static bool ValidateRetryIntegrityTag(ParsedQuicVersion version, + QuicConnectionId original_connection_id, + absl::string_view retry_without_tag, + absl::string_view integrity_tag); + + // Generates the connection nonce. The nonce is formed as: + // <4 bytes> current time + // <8 bytes> |orbit| (or random if |orbit| is empty) + // <20 bytes> random + static void GenerateNonce(QuicWallTime now, QuicRandom* random_generator, + absl::string_view orbit, std::string* nonce); + + // DeriveKeys populates |crypters->encrypter|, |crypters->decrypter|, and + // |subkey_secret| (optional -- may be null) given the contents of + // |premaster_secret|, |client_nonce|, |server_nonce| and |hkdf_input|. |aead| + // determines which cipher will be used. |perspective| controls whether the + // server's keys are assigned to |encrypter| or |decrypter|. |server_nonce| is + // optional and, if non-empty, is mixed into the key derivation. + // |subkey_secret| will have the same length as |premaster_secret|. + // + // If |pre_shared_key| is non-empty, it is incorporated into the key + // derivation parameters. If it is empty, the key derivation is unaltered. + // + // If the mode of |diversification| is NEVER, the the crypters will be + // configured to never perform key diversification. If the mode is + // NOW (which is only for servers, then the encrypter will be keyed via a + // two-step process that uses the nonce from |diversification|. + // If the mode is PENDING (which is only for servres), then the + // decrypter will only be keyed to a preliminary state: a call to + // |SetDiversificationNonce| with a diversification nonce will be needed to + // complete keying. + static bool DeriveKeys(const ParsedQuicVersion& version, + absl::string_view premaster_secret, QuicTag aead, + absl::string_view client_nonce, + absl::string_view server_nonce, + absl::string_view pre_shared_key, + const std::string& hkdf_input, Perspective perspective, + Diversification diversification, CrypterPair* crypters, + std::string* subkey_secret); + + // Computes the FNV-1a hash of the provided DER-encoded cert for use in the + // XLCT tag. + static uint64_t ComputeLeafCertHash(absl::string_view cert); + + // Validates that |server_hello| is actually an SHLO message and that it is + // not part of a downgrade attack. + // + // Returns QUIC_NO_ERROR if this is the case or returns the appropriate error + // code and sets |error_details|. + static QuicErrorCode ValidateServerHello( + const CryptoHandshakeMessage& server_hello, + const ParsedQuicVersionVector& negotiated_versions, + std::string* error_details); + + // Validates that the |server_versions| received do not indicate that the + // ServerHello is part of a downgrade attack. |negotiated_versions| must + // contain the list of versions received in the server's version negotiation + // packet (or be empty if no such packet was received). + // + // Returns QUIC_NO_ERROR if this is the case or returns the appropriate error + // code and sets |error_details|. + static QuicErrorCode ValidateServerHelloVersions( + const QuicVersionLabelVector& server_versions, + const ParsedQuicVersionVector& negotiated_versions, + std::string* error_details); + + // Validates that |client_hello| is actually a CHLO and that this is not part + // of a downgrade attack. + // This includes verifiying versions and detecting downgrade attacks. + // + // Returns QUIC_NO_ERROR if this is the case or returns the appropriate error + // code and sets |error_details|. + static QuicErrorCode ValidateClientHello( + const CryptoHandshakeMessage& client_hello, ParsedQuicVersion version, + const ParsedQuicVersionVector& supported_versions, + std::string* error_details); + + // Validates that the |client_version| received does not indicate that a + // downgrade attack has occurred. |connection_version| is the version of the + // QuicConnection, and |supported_versions| is all versions that that + // QuicConnection supports. + // + // Returns QUIC_NO_ERROR if this is the case or returns the appropriate error + // code and sets |error_details|. + static QuicErrorCode ValidateClientHelloVersion( + QuicVersionLabel client_version, ParsedQuicVersion connection_version, + const ParsedQuicVersionVector& supported_versions, + std::string* error_details); + + // Validates that the chosen version from the version_information matches the + // version from the session. Returns true if they match, otherwise returns + // false and fills in |error_details|. + static bool ValidateChosenVersion( + const QuicVersionLabel& version_information_chosen_version, + const ParsedQuicVersion& session_version, std::string* error_details); + + // Validates that there was no downgrade attack involving a version + // negotiation packet. This verifies that if the client was initially + // configured with |client_original_supported_versions| and it had received a + // version negotiation packet with |version_information_other_versions|, then + // it would have selected |session_version|. Returns true if they match (or if + // |client_original_supported_versions| is empty indicating no version + // negotiation packet was received), otherwise returns + // false and fills in |error_details|. + static bool ValidateServerVersions( + const QuicVersionLabelVector& version_information_other_versions, + const ParsedQuicVersion& session_version, + const ParsedQuicVersionVector& client_original_supported_versions, + std::string* error_details); + + // Returns the name of the HandshakeFailureReason as a char* + static const char* HandshakeFailureReasonToString( + HandshakeFailureReason reason); + + // Returns the name of an ssl_early_data_reason_t as a char* + static std::string EarlyDataReasonToString(ssl_early_data_reason_t reason); + + // Returns a hash of the serialized |message|. + static std::string HashHandshakeMessage(const CryptoHandshakeMessage& message, + Perspective perspective); + + // Wraps SSL_serialize_capabilities. Return nullptr if failed. + static bool GetSSLCapabilities(const SSL* ssl, + bssl::UniquePtr* capabilities, + size_t* capabilities_len); + + // Computes the contents of a binary message that is signed inside QUIC Crypto + // protocol using the certificate key. + static absl::optional GenerateProofPayloadToBeSigned( + absl::string_view chlo_hash, absl::string_view server_config); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CRYPTO_UTILS_H_ diff --git a/quiche/quic/core/crypto/crypto_utils_test.cc b/quiche/quic/core/crypto/crypto_utils_test.cc new file mode 100644 index 000000000000..6bc17cf4a677 --- /dev/null +++ b/quiche/quic/core/crypto/crypto_utils_test.cc @@ -0,0 +1,262 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/crypto_utils.h" + +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { +namespace { + +class CryptoUtilsTest : public QuicTest {}; + +TEST_F(CryptoUtilsTest, HandshakeFailureReasonToString) { + EXPECT_STREQ("HANDSHAKE_OK", + CryptoUtils::HandshakeFailureReasonToString(HANDSHAKE_OK)); + EXPECT_STREQ("CLIENT_NONCE_UNKNOWN_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_UNKNOWN_FAILURE)); + EXPECT_STREQ("CLIENT_NONCE_INVALID_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_INVALID_FAILURE)); + EXPECT_STREQ("CLIENT_NONCE_NOT_UNIQUE_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_NOT_UNIQUE_FAILURE)); + EXPECT_STREQ("CLIENT_NONCE_INVALID_ORBIT_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_INVALID_ORBIT_FAILURE)); + EXPECT_STREQ("CLIENT_NONCE_INVALID_TIME_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_INVALID_TIME_FAILURE)); + EXPECT_STREQ("CLIENT_NONCE_STRIKE_REGISTER_TIMEOUT", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_STRIKE_REGISTER_TIMEOUT)); + EXPECT_STREQ("CLIENT_NONCE_STRIKE_REGISTER_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + CLIENT_NONCE_STRIKE_REGISTER_FAILURE)); + EXPECT_STREQ("SERVER_NONCE_DECRYPTION_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_NONCE_DECRYPTION_FAILURE)); + EXPECT_STREQ("SERVER_NONCE_INVALID_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_NONCE_INVALID_FAILURE)); + EXPECT_STREQ("SERVER_NONCE_NOT_UNIQUE_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_NONCE_NOT_UNIQUE_FAILURE)); + EXPECT_STREQ("SERVER_NONCE_INVALID_TIME_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_NONCE_INVALID_TIME_FAILURE)); + EXPECT_STREQ("SERVER_NONCE_REQUIRED_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_NONCE_REQUIRED_FAILURE)); + EXPECT_STREQ("SERVER_CONFIG_INCHOATE_HELLO_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_CONFIG_INCHOATE_HELLO_FAILURE)); + EXPECT_STREQ("SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE)); + EXPECT_STREQ("SOURCE_ADDRESS_TOKEN_INVALID_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SOURCE_ADDRESS_TOKEN_INVALID_FAILURE)); + EXPECT_STREQ("SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE)); + EXPECT_STREQ("SOURCE_ADDRESS_TOKEN_PARSE_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SOURCE_ADDRESS_TOKEN_PARSE_FAILURE)); + EXPECT_STREQ("SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE)); + EXPECT_STREQ("SOURCE_ADDRESS_TOKEN_CLOCK_SKEW_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SOURCE_ADDRESS_TOKEN_CLOCK_SKEW_FAILURE)); + EXPECT_STREQ("SOURCE_ADDRESS_TOKEN_EXPIRED_FAILURE", + CryptoUtils::HandshakeFailureReasonToString( + SOURCE_ADDRESS_TOKEN_EXPIRED_FAILURE)); + EXPECT_STREQ("INVALID_EXPECTED_LEAF_CERTIFICATE", + CryptoUtils::HandshakeFailureReasonToString( + INVALID_EXPECTED_LEAF_CERTIFICATE)); + EXPECT_STREQ("MAX_FAILURE_REASON", + CryptoUtils::HandshakeFailureReasonToString(MAX_FAILURE_REASON)); + EXPECT_STREQ( + "INVALID_HANDSHAKE_FAILURE_REASON", + CryptoUtils::HandshakeFailureReasonToString( + static_cast(MAX_FAILURE_REASON + 1))); +} + +TEST_F(CryptoUtilsTest, AuthTagLengths) { + for (const auto& version : AllSupportedVersions()) { + for (QuicTag algo : {kAESG, kCC20}) { + SCOPED_TRACE(version); + std::unique_ptr encrypter( + QuicEncrypter::Create(version, algo)); + size_t auth_tag_size = 12; + if (version.UsesInitialObfuscators()) { + auth_tag_size = 16; + } + EXPECT_EQ(encrypter->GetCiphertextSize(0), auth_tag_size); + } + } +} + +TEST_F(CryptoUtilsTest, ValidateChosenVersion) { + for (const ParsedQuicVersion& v1 : AllSupportedVersions()) { + for (const ParsedQuicVersion& v2 : AllSupportedVersions()) { + std::string error_details; + bool success = CryptoUtils::ValidateChosenVersion( + CreateQuicVersionLabel(v1), v2, &error_details); + EXPECT_EQ(success, v1 == v2); + EXPECT_EQ(success, error_details.empty()); + } + } +} + +TEST_F(CryptoUtilsTest, ValidateServerVersionsNoVersionNegotiation) { + QuicVersionLabelVector version_information_other_versions; + ParsedQuicVersionVector client_original_supported_versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + std::string error_details; + EXPECT_TRUE(CryptoUtils::ValidateServerVersions( + version_information_other_versions, version, + client_original_supported_versions, &error_details)); + EXPECT_TRUE(error_details.empty()); + } +} + +TEST_F(CryptoUtilsTest, ValidateServerVersionsWithVersionNegotiation) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + QuicVersionLabelVector version_information_other_versions{ + CreateQuicVersionLabel(version)}; + ParsedQuicVersionVector client_original_supported_versions{ + ParsedQuicVersion::ReservedForNegotiation(), version}; + std::string error_details; + EXPECT_TRUE(CryptoUtils::ValidateServerVersions( + version_information_other_versions, version, + client_original_supported_versions, &error_details)); + EXPECT_TRUE(error_details.empty()); + } +} + +TEST_F(CryptoUtilsTest, ValidateServerVersionsWithDowngrade) { + if (AllSupportedVersions().size() <= 1) { + // We are not vulnerable to downgrade if we only support one version. + return; + } + ParsedQuicVersion client_version = AllSupportedVersions().front(); + ParsedQuicVersion server_version = AllSupportedVersions().back(); + ASSERT_NE(client_version, server_version); + QuicVersionLabelVector version_information_other_versions{ + CreateQuicVersionLabel(client_version)}; + ParsedQuicVersionVector client_original_supported_versions{ + ParsedQuicVersion::ReservedForNegotiation(), server_version}; + std::string error_details; + EXPECT_FALSE(CryptoUtils::ValidateServerVersions( + version_information_other_versions, server_version, + client_original_supported_versions, &error_details)); + EXPECT_FALSE(error_details.empty()); +} + +// Test that the library is using the correct labels for each version, and +// therefore generating correct obfuscators, using the test vectors in appendix +// A of each RFC or internet-draft. +TEST_F(CryptoUtilsTest, ValidateCryptoLabels) { + // if the number of HTTP/3 QUIC versions has changed, we need to change the + // expected_keys hardcoded into this test. Regrettably, this is not a + // compile-time constant. + EXPECT_EQ(AllSupportedVersionsWithTls().size(), 3u); + const char draft_29_key[] = {// test vector from draft-ietf-quic-tls-29, A.1 + 0x14, + static_cast(0x9d), + 0x0b, + 0x16, + 0x62, + static_cast(0xab), + static_cast(0x87), + 0x1f, + static_cast(0xbe), + 0x63, + static_cast(0xc4), + static_cast(0x9b), + 0x5e, + 0x65, + 0x5a, + 0x5d}; + const char v1_key[] = {// test vector from RFC 9001, A.1 + static_cast(0xcf), + 0x3a, + 0x53, + 0x31, + 0x65, + 0x3c, + 0x36, + 0x4c, + static_cast(0x88), + static_cast(0xf0), + static_cast(0xf3), + 0x79, + static_cast(0xb6), + 0x06, + 0x7e, + 0x37}; + const char v2_08_key[] = {// test vector from draft-ietf-quic-v2-08 + static_cast(0x82), + static_cast(0xdb), + static_cast(0x63), + static_cast(0x78), + static_cast(0x61), + static_cast(0xd5), + static_cast(0x5e), + 0x1d, + static_cast(0x01), + static_cast(0x1f), + 0x19, + static_cast(0xea), + 0x71, + static_cast(0xd5), + static_cast(0xd2), + static_cast(0xa7)}; + const char connection_id[] = // test vector from both docs + {static_cast(0x83), + static_cast(0x94), + static_cast(0xc8), + static_cast(0xf0), + 0x3e, + 0x51, + 0x57, + 0x08}; + const QuicConnectionId cid(connection_id, sizeof(connection_id)); + const char* key_str; + size_t key_size; + for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) { + if (version == ParsedQuicVersion::Draft29()) { + key_str = draft_29_key; + key_size = sizeof(draft_29_key); + } else if (version == ParsedQuicVersion::RFCv1()) { + key_str = v1_key; + key_size = sizeof(v1_key); + } else { // draft-ietf-quic-v2-01 + key_str = v2_08_key; + key_size = sizeof(v2_08_key); + } + const absl::string_view expected_key{key_str, key_size}; + + CrypterPair crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_SERVER, version, cid, + &crypters); + EXPECT_EQ(crypters.encrypter->GetKey(), expected_key); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/curve25519_key_exchange.cc b/quiche/quic/core/crypto/curve25519_key_exchange.cc new file mode 100644 index 000000000000..5340b41107cb --- /dev/null +++ b/quiche/quic/core/crypto/curve25519_key_exchange.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/curve25519_key_exchange.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "openssl/curve25519.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +Curve25519KeyExchange::Curve25519KeyExchange() {} + +Curve25519KeyExchange::~Curve25519KeyExchange() {} + +// static +std::unique_ptr Curve25519KeyExchange::New( + QuicRandom* rand) { + std::unique_ptr result = + New(Curve25519KeyExchange::NewPrivateKey(rand)); + QUIC_BUG_IF(quic_bug_12891_1, result == nullptr); + return result; +} + +// static +std::unique_ptr Curve25519KeyExchange::New( + absl::string_view private_key) { + // We don't want to #include the BoringSSL headers in the public header file, + // so we use literals for the sizes of private_key_ and public_key_. Here we + // assert that those values are equal to the values from the BoringSSL + // header. + static_assert( + sizeof(Curve25519KeyExchange::private_key_) == X25519_PRIVATE_KEY_LEN, + "header out of sync"); + static_assert( + sizeof(Curve25519KeyExchange::public_key_) == X25519_PUBLIC_VALUE_LEN, + "header out of sync"); + + if (private_key.size() != X25519_PRIVATE_KEY_LEN) { + return nullptr; + } + + // Use absl::WrapUnique(new) instead of std::make_unique because + // Curve25519KeyExchange has a private constructor. + auto ka = absl::WrapUnique(new Curve25519KeyExchange); + memcpy(ka->private_key_, private_key.data(), X25519_PRIVATE_KEY_LEN); + X25519_public_from_private(ka->public_key_, ka->private_key_); + return ka; +} + +// static +std::string Curve25519KeyExchange::NewPrivateKey(QuicRandom* rand) { + uint8_t private_key[X25519_PRIVATE_KEY_LEN]; + rand->RandBytes(private_key, sizeof(private_key)); + return std::string(reinterpret_cast(private_key), sizeof(private_key)); +} + +bool Curve25519KeyExchange::CalculateSharedKeySync( + absl::string_view peer_public_value, std::string* shared_key) const { + if (peer_public_value.size() != X25519_PUBLIC_VALUE_LEN) { + return false; + } + + uint8_t result[X25519_PUBLIC_VALUE_LEN]; + if (!X25519(result, private_key_, + reinterpret_cast(peer_public_value.data()))) { + return false; + } + + shared_key->assign(reinterpret_cast(result), sizeof(result)); + return true; +} + +absl::string_view Curve25519KeyExchange::public_value() const { + return absl::string_view(reinterpret_cast(public_key_), + sizeof(public_key_)); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/curve25519_key_exchange.h b/quiche/quic/core/crypto/curve25519_key_exchange.h new file mode 100644 index 000000000000..b6e06f384728 --- /dev/null +++ b/quiche/quic/core/crypto/curve25519_key_exchange.h @@ -0,0 +1,52 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CURVE25519_KEY_EXCHANGE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CURVE25519_KEY_EXCHANGE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/key_exchange.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Curve25519KeyExchange implements a SynchronousKeyExchange using +// elliptic-curve Diffie-Hellman on curve25519. See http://cr.yp.to/ecdh.html +class QUIC_EXPORT_PRIVATE Curve25519KeyExchange + : public SynchronousKeyExchange { + public: + ~Curve25519KeyExchange() override; + + // New generates a private key and then creates new key-exchange object. + static std::unique_ptr New(QuicRandom* rand); + + // New creates a new key-exchange object from a private key. If |private_key| + // is invalid, nullptr is returned. + static std::unique_ptr New( + absl::string_view private_key); + + // NewPrivateKey returns a private key, generated from |rand|, suitable for + // passing to |New|. + static std::string NewPrivateKey(QuicRandom* rand); + + // SynchronousKeyExchange interface. + bool CalculateSharedKeySync(absl::string_view peer_public_value, + std::string* shared_key) const override; + absl::string_view public_value() const override; + QuicTag type() const override { return kC255; } + + private: + Curve25519KeyExchange(); + + uint8_t private_key_[32]; + uint8_t public_key_[32]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CURVE25519_KEY_EXCHANGE_H_ diff --git a/quiche/quic/core/crypto/curve25519_key_exchange_test.cc b/quiche/quic/core/crypto/curve25519_key_exchange_test.cc new file mode 100644 index 000000000000..551ee0e1bc97 --- /dev/null +++ b/quiche/quic/core/crypto/curve25519_key_exchange_test.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/curve25519_key_exchange.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class Curve25519KeyExchangeTest : public QuicTest { + public: + // Holds the result of a key exchange callback. + class TestCallbackResult { + public: + void set_ok(bool ok) { ok_ = ok; } + bool ok() { return ok_; } + + private: + bool ok_ = false; + }; + + // Key exchange callback which sets the result into the specified + // TestCallbackResult. + class TestCallback : public AsynchronousKeyExchange::Callback { + public: + TestCallback(TestCallbackResult* result) : result_(result) {} + virtual ~TestCallback() = default; + + void Run(bool ok) { result_->set_ok(ok); } + + private: + TestCallbackResult* result_; + }; +}; + +// SharedKey just tests that the basic key exchange identity holds: that both +// parties end up with the same key. +TEST_F(Curve25519KeyExchangeTest, SharedKey) { + QuicRandom* const rand = QuicRandom::GetInstance(); + + for (int i = 0; i < 5; i++) { + const std::string alice_key(Curve25519KeyExchange::NewPrivateKey(rand)); + const std::string bob_key(Curve25519KeyExchange::NewPrivateKey(rand)); + + std::unique_ptr alice( + Curve25519KeyExchange::New(alice_key)); + std::unique_ptr bob( + Curve25519KeyExchange::New(bob_key)); + + const absl::string_view alice_public(alice->public_value()); + const absl::string_view bob_public(bob->public_value()); + + std::string alice_shared, bob_shared; + ASSERT_TRUE(alice->CalculateSharedKeySync(bob_public, &alice_shared)); + ASSERT_TRUE(bob->CalculateSharedKeySync(alice_public, &bob_shared)); + ASSERT_EQ(alice_shared, bob_shared); + } +} + +// SharedKeyAsync just tests that the basic asynchronous key exchange identity +// holds: that both parties end up with the same key. +TEST_F(Curve25519KeyExchangeTest, SharedKeyAsync) { + QuicRandom* const rand = QuicRandom::GetInstance(); + + for (int i = 0; i < 5; i++) { + const std::string alice_key(Curve25519KeyExchange::NewPrivateKey(rand)); + const std::string bob_key(Curve25519KeyExchange::NewPrivateKey(rand)); + + std::unique_ptr alice( + Curve25519KeyExchange::New(alice_key)); + std::unique_ptr bob( + Curve25519KeyExchange::New(bob_key)); + + const absl::string_view alice_public(alice->public_value()); + const absl::string_view bob_public(bob->public_value()); + + std::string alice_shared, bob_shared; + TestCallbackResult alice_result; + ASSERT_FALSE(alice_result.ok()); + alice->CalculateSharedKeyAsync( + bob_public, &alice_shared, + std::make_unique(&alice_result)); + ASSERT_TRUE(alice_result.ok()); + TestCallbackResult bob_result; + ASSERT_FALSE(bob_result.ok()); + bob->CalculateSharedKeyAsync(alice_public, &bob_shared, + std::make_unique(&bob_result)); + ASSERT_TRUE(bob_result.ok()); + ASSERT_EQ(alice_shared, bob_shared); + ASSERT_NE(0u, alice_shared.length()); + ASSERT_NE(0u, bob_shared.length()); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/key_exchange.cc b/quiche/quic/core/crypto/key_exchange.cc new file mode 100644 index 000000000000..38dea001f1e0 --- /dev/null +++ b/quiche/quic/core/crypto/key_exchange.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/key_exchange.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/curve25519_key_exchange.h" +#include "quiche/quic/core/crypto/p256_key_exchange.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +std::unique_ptr CreateLocalSynchronousKeyExchange( + QuicTag type, absl::string_view private_key) { + switch (type) { + case kC255: + return Curve25519KeyExchange::New(private_key); + case kP256: + return P256KeyExchange::New(private_key); + default: + QUIC_BUG(quic_bug_10712_1) + << "Unknown key exchange method: " << QuicTagToString(type); + return nullptr; + } +} + +std::unique_ptr CreateLocalSynchronousKeyExchange( + QuicTag type, QuicRandom* rand) { + switch (type) { + case kC255: + return Curve25519KeyExchange::New(rand); + case kP256: + return P256KeyExchange::New(); + default: + QUIC_BUG(quic_bug_10712_2) + << "Unknown key exchange method: " << QuicTagToString(type); + return nullptr; + } +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/key_exchange.h b/quiche/quic/core/crypto/key_exchange.h new file mode 100644 index 000000000000..573b289ddfaf --- /dev/null +++ b/quiche/quic/core/crypto/key_exchange.h @@ -0,0 +1,101 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_KEY_EXCHANGE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_KEY_EXCHANGE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Interface for a Diffie-Hellman key exchange with an asynchronous interface. +// This allows for implementations which hold the private key locally, as well +// as ones which make an RPC to an external key-exchange service. +class QUIC_EXPORT_PRIVATE AsynchronousKeyExchange { + public: + virtual ~AsynchronousKeyExchange() = default; + + // Callback base class for receiving the results of an async call to + // CalculateSharedKeys. + class QUIC_EXPORT_PRIVATE Callback { + public: + Callback() = default; + virtual ~Callback() = default; + + // Invoked upon completion of CalculateSharedKeysAsync. + // + // |ok| indicates whether the operation completed successfully. If false, + // then the value pointed to by |shared_key| passed in to + // CalculateSharedKeyAsync is undefined. + virtual void Run(bool ok) = 0; + + private: + Callback(const Callback&) = delete; + Callback& operator=(const Callback&) = delete; + }; + + // CalculateSharedKey computes the shared key between a private key which is + // conceptually owned by this object (though it may not be physically located + // in this process) and a public value from the peer. Callers should expect + // that |callback| might be invoked synchronously. Results will be written + // into |*shared_key|. + virtual void CalculateSharedKeyAsync( + absl::string_view peer_public_value, std::string* shared_key, + std::unique_ptr callback) const = 0; + + // Tag indicating the key-exchange algorithm this object will use. + virtual QuicTag type() const = 0; +}; + +// Interface for a Diffie-Hellman key exchange with both synchronous and +// asynchronous interfaces. Only implementations which hold the private key +// locally should implement this interface. +class QUIC_EXPORT_PRIVATE SynchronousKeyExchange + : public AsynchronousKeyExchange { + public: + virtual ~SynchronousKeyExchange() = default; + + // AyncKeyExchange API. Note that this method is marked 'final.' Subclasses + // should implement CalculateSharedKeySync only. + void CalculateSharedKeyAsync(absl::string_view peer_public_value, + std::string* shared_key, + std::unique_ptr callback) const final { + const bool ok = CalculateSharedKeySync(peer_public_value, shared_key); + callback->Run(ok); + } + + // CalculateSharedKey computes the shared key between a local private key and + // a public value from the peer. Results will be written into |*shared_key|. + virtual bool CalculateSharedKeySync(absl::string_view peer_public_value, + std::string* shared_key) const = 0; + + // public_value returns the local public key which can be sent to a peer in + // order to complete a key exchange. The returned absl::string_view is + // a reference to a member of this object and is only valid for as long as it + // exists. + virtual absl::string_view public_value() const = 0; +}; + +// Create a SynchronousKeyExchange object which will use a keypair generated +// from |private_key|, and a key-exchange algorithm specified by |type|, which +// must be one of {kC255, kC256}. Returns nullptr if |private_key| or |type| is +// invalid. +std::unique_ptr CreateLocalSynchronousKeyExchange( + QuicTag type, absl::string_view private_key); + +// Create a SynchronousKeyExchange object which will use a keypair generated +// from |rand|, and a key-exchange algorithm specified by |type|, which must be +// one of {kC255, kC256}. Returns nullptr if |type| is invalid. +std::unique_ptr CreateLocalSynchronousKeyExchange( + QuicTag type, QuicRandom* rand); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_KEY_EXCHANGE_H_ diff --git a/quiche/quic/core/crypto/null_decrypter.cc b/quiche/quic/core/crypto/null_decrypter.cc new file mode 100644 index 000000000000..af0c44476907 --- /dev/null +++ b/quiche/quic/core/crypto/null_decrypter.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/null_decrypter.h" + +#include + +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +NullDecrypter::NullDecrypter(Perspective perspective) + : perspective_(perspective) {} + +bool NullDecrypter::SetKey(absl::string_view key) { return key.empty(); } + +bool NullDecrypter::SetNoncePrefix(absl::string_view nonce_prefix) { + return nonce_prefix.empty(); +} + +bool NullDecrypter::SetIV(absl::string_view iv) { return iv.empty(); } + +bool NullDecrypter::SetHeaderProtectionKey(absl::string_view key) { + return key.empty(); +} + +bool NullDecrypter::SetPreliminaryKey(absl::string_view /*key*/) { + QUIC_BUG(quic_bug_10652_1) << "Should not be called"; + return false; +} + +bool NullDecrypter::SetDiversificationNonce( + const DiversificationNonce& /*nonce*/) { + QUIC_BUG(quic_bug_10652_2) << "Should not be called"; + return true; +} + +bool NullDecrypter::DecryptPacket(uint64_t /*packet_number*/, + absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, + size_t max_output_length) { + QuicDataReader reader(ciphertext.data(), ciphertext.length(), + quiche::HOST_BYTE_ORDER); + absl::uint128 hash; + + if (!ReadHash(&reader, &hash)) { + return false; + } + + absl::string_view plaintext = reader.ReadRemainingPayload(); + if (plaintext.length() > max_output_length) { + QUIC_BUG(quic_bug_10652_3) + << "Output buffer must be larger than the plaintext."; + return false; + } + if (hash != ComputeHash(associated_data, plaintext)) { + return false; + } + // Copy the plaintext to output. + memcpy(output, plaintext.data(), plaintext.length()); + *output_length = plaintext.length(); + return true; +} + +std::string NullDecrypter::GenerateHeaderProtectionMask( + QuicDataReader* /*sample_reader*/) { + return std::string(5, 0); +} + +size_t NullDecrypter::GetKeySize() const { return 0; } + +size_t NullDecrypter::GetNoncePrefixSize() const { return 0; } + +size_t NullDecrypter::GetIVSize() const { return 0; } + +absl::string_view NullDecrypter::GetKey() const { return absl::string_view(); } + +absl::string_view NullDecrypter::GetNoncePrefix() const { + return absl::string_view(); +} + +uint32_t NullDecrypter::cipher_id() const { return 0; } + +QuicPacketCount NullDecrypter::GetIntegrityLimit() const { + return std::numeric_limits::max(); +} + +bool NullDecrypter::ReadHash(QuicDataReader* reader, absl::uint128* hash) { + uint64_t lo; + uint32_t hi; + if (!reader->ReadUInt64(&lo) || !reader->ReadUInt32(&hi)) { + return false; + } + *hash = absl::MakeUint128(hi, lo); + return true; +} + +absl::uint128 NullDecrypter::ComputeHash(const absl::string_view data1, + const absl::string_view data2) const { + absl::uint128 correct_hash; + if (perspective_ == Perspective::IS_CLIENT) { + // Peer is a server. + correct_hash = QuicUtils::FNV1a_128_Hash_Three(data1, data2, "Server"); + } else { + // Peer is a client. + correct_hash = QuicUtils::FNV1a_128_Hash_Three(data1, data2, "Client"); + } + absl::uint128 mask = absl::MakeUint128(UINT64_C(0x0), UINT64_C(0xffffffff)); + mask <<= 96; + correct_hash &= ~mask; + return correct_hash; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/null_decrypter.h b/quiche/quic/core/crypto/null_decrypter.h new file mode 100644 index 000000000000..9b6fb4501ff6 --- /dev/null +++ b/quiche/quic/core/crypto/null_decrypter.h @@ -0,0 +1,62 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_NULL_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_NULL_DECRYPTER_H_ + +#include +#include + +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QuicDataReader; + +// A NullDecrypter is a QuicDecrypter used before a crypto negotiation +// has occurred. It does not actually decrypt the payload, but does +// verify a hash (fnv128) over both the payload and associated data. +class QUIC_EXPORT_PRIVATE NullDecrypter : public QuicDecrypter { + public: + explicit NullDecrypter(Perspective perspective); + NullDecrypter(const NullDecrypter&) = delete; + NullDecrypter& operator=(const NullDecrypter&) = delete; + ~NullDecrypter() override {} + + // QuicDecrypter implementation + bool SetKey(absl::string_view key) override; + bool SetNoncePrefix(absl::string_view nonce_prefix) override; + bool SetIV(absl::string_view iv) override; + bool SetHeaderProtectionKey(absl::string_view key) override; + bool SetPreliminaryKey(absl::string_view key) override; + bool SetDiversificationNonce(const DiversificationNonce& nonce) override; + bool DecryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, size_t max_output_length) override; + std::string GenerateHeaderProtectionMask( + QuicDataReader* sample_reader) override; + size_t GetKeySize() const override; + size_t GetNoncePrefixSize() const override; + size_t GetIVSize() const override; + absl::string_view GetKey() const override; + absl::string_view GetNoncePrefix() const override; + + uint32_t cipher_id() const override; + QuicPacketCount GetIntegrityLimit() const override; + + private: + bool ReadHash(QuicDataReader* reader, absl::uint128* hash); + absl::uint128 ComputeHash(absl::string_view data1, + absl::string_view data2) const; + + Perspective perspective_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_NULL_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/null_decrypter_test.cc b/quiche/quic/core/crypto/null_decrypter_test.cc new file mode 100644 index 000000000000..e71ed01ba8f2 --- /dev/null +++ b/quiche/quic/core/crypto/null_decrypter_test.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/null_decrypter.h" + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +class NullDecrypterTest : public QuicTestWithParam {}; + +TEST_F(NullDecrypterTest, DecryptClient) { + unsigned char expected[] = { + // fnv hash + 0x97, + 0xdc, + 0x27, + 0x2f, + 0x18, + 0xa8, + 0x56, + 0x73, + 0xdf, + 0x8d, + 0x1d, + 0xd0, + // payload + 'g', + 'o', + 'o', + 'd', + 'b', + 'y', + 'e', + '!', + }; + const char* data = reinterpret_cast(expected); + size_t len = ABSL_ARRAYSIZE(expected); + NullDecrypter decrypter(Perspective::IS_SERVER); + char buffer[256]; + size_t length = 0; + ASSERT_TRUE(decrypter.DecryptPacket( + 0, "hello world!", absl::string_view(data, len), buffer, &length, 256)); + EXPECT_LT(0u, length); + EXPECT_EQ("goodbye!", absl::string_view(buffer, length)); +} + +TEST_F(NullDecrypterTest, DecryptServer) { + unsigned char expected[] = { + // fnv hash + 0x63, + 0x5e, + 0x08, + 0x03, + 0x32, + 0x80, + 0x8f, + 0x73, + 0xdf, + 0x8d, + 0x1d, + 0x1a, + // payload + 'g', + 'o', + 'o', + 'd', + 'b', + 'y', + 'e', + '!', + }; + const char* data = reinterpret_cast(expected); + size_t len = ABSL_ARRAYSIZE(expected); + NullDecrypter decrypter(Perspective::IS_CLIENT); + char buffer[256]; + size_t length = 0; + ASSERT_TRUE(decrypter.DecryptPacket( + 0, "hello world!", absl::string_view(data, len), buffer, &length, 256)); + EXPECT_LT(0u, length); + EXPECT_EQ("goodbye!", absl::string_view(buffer, length)); +} + +TEST_F(NullDecrypterTest, BadHash) { + unsigned char expected[] = { + // fnv hash + 0x46, + 0x11, + 0xea, + 0x5f, + 0xcf, + 0x1d, + 0x66, + 0x5b, + 0xba, + 0xf0, + 0xbc, + 0xfd, + // payload + 'g', + 'o', + 'o', + 'd', + 'b', + 'y', + 'e', + '!', + }; + const char* data = reinterpret_cast(expected); + size_t len = ABSL_ARRAYSIZE(expected); + NullDecrypter decrypter(Perspective::IS_CLIENT); + char buffer[256]; + size_t length = 0; + ASSERT_FALSE(decrypter.DecryptPacket( + 0, "hello world!", absl::string_view(data, len), buffer, &length, 256)); +} + +TEST_F(NullDecrypterTest, ShortInput) { + unsigned char expected[] = { + // fnv hash (truncated) + 0x46, 0x11, 0xea, 0x5f, 0xcf, 0x1d, 0x66, 0x5b, 0xba, 0xf0, 0xbc, + }; + const char* data = reinterpret_cast(expected); + size_t len = ABSL_ARRAYSIZE(expected); + NullDecrypter decrypter(Perspective::IS_CLIENT); + char buffer[256]; + size_t length = 0; + ASSERT_FALSE(decrypter.DecryptPacket( + 0, "hello world!", absl::string_view(data, len), buffer, &length, 256)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/null_encrypter.cc b/quiche/quic/core/crypto/null_encrypter.cc new file mode 100644 index 000000000000..87a3f32ac498 --- /dev/null +++ b/quiche/quic/core/crypto/null_encrypter.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/null_encrypter.h" + +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_utils.h" + +namespace quic { + +const size_t kHashSizeShort = 12; // size of uint128 serialized short + +NullEncrypter::NullEncrypter(Perspective perspective) + : perspective_(perspective) {} + +bool NullEncrypter::SetKey(absl::string_view key) { return key.empty(); } + +bool NullEncrypter::SetNoncePrefix(absl::string_view nonce_prefix) { + return nonce_prefix.empty(); +} + +bool NullEncrypter::SetIV(absl::string_view iv) { return iv.empty(); } + +bool NullEncrypter::SetHeaderProtectionKey(absl::string_view key) { + return key.empty(); +} + +bool NullEncrypter::EncryptPacket(uint64_t /*packet_number*/, + absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, + size_t max_output_length) { + const size_t len = plaintext.size() + GetHashLength(); + if (max_output_length < len) { + return false; + } + absl::uint128 hash; + if (perspective_ == Perspective::IS_SERVER) { + hash = + QuicUtils::FNV1a_128_Hash_Three(associated_data, plaintext, "Server"); + } else { + hash = + QuicUtils::FNV1a_128_Hash_Three(associated_data, plaintext, "Client"); + } + // TODO(ianswett): memmove required for in place encryption. Placing the + // hash at the end would allow use of memcpy, doing nothing for in place. + memmove(output + GetHashLength(), plaintext.data(), plaintext.length()); + QuicUtils::SerializeUint128Short(hash, + reinterpret_cast(output)); + *output_length = len; + return true; +} + +std::string NullEncrypter::GenerateHeaderProtectionMask( + absl::string_view /*sample*/) { + return std::string(5, 0); +} + +size_t NullEncrypter::GetKeySize() const { return 0; } + +size_t NullEncrypter::GetNoncePrefixSize() const { return 0; } + +size_t NullEncrypter::GetIVSize() const { return 0; } + +size_t NullEncrypter::GetMaxPlaintextSize(size_t ciphertext_size) const { + return ciphertext_size - std::min(ciphertext_size, GetHashLength()); +} + +size_t NullEncrypter::GetCiphertextSize(size_t plaintext_size) const { + return plaintext_size + GetHashLength(); +} + +QuicPacketCount NullEncrypter::GetConfidentialityLimit() const { + return std::numeric_limits::max(); +} + +absl::string_view NullEncrypter::GetKey() const { return absl::string_view(); } + +absl::string_view NullEncrypter::GetNoncePrefix() const { + return absl::string_view(); +} + +size_t NullEncrypter::GetHashLength() const { return kHashSizeShort; } + +} // namespace quic diff --git a/quiche/quic/core/crypto/null_encrypter.h b/quiche/quic/core/crypto/null_encrypter.h new file mode 100644 index 000000000000..c5e599f51998 --- /dev/null +++ b/quiche/quic/core/crypto/null_encrypter.h @@ -0,0 +1,53 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_NULL_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_NULL_ENCRYPTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A NullEncrypter is a QuicEncrypter used before a crypto negotiation +// has occurred. It does not actually encrypt the payload, but does +// generate a MAC (fnv128) over both the payload and associated data. +class QUIC_EXPORT_PRIVATE NullEncrypter : public QuicEncrypter { + public: + explicit NullEncrypter(Perspective perspective); + NullEncrypter(const NullEncrypter&) = delete; + NullEncrypter& operator=(const NullEncrypter&) = delete; + ~NullEncrypter() override {} + + // QuicEncrypter implementation + bool SetKey(absl::string_view key) override; + bool SetNoncePrefix(absl::string_view nonce_prefix) override; + bool SetIV(absl::string_view iv) override; + bool SetHeaderProtectionKey(absl::string_view key) override; + bool EncryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, size_t max_output_length) override; + std::string GenerateHeaderProtectionMask(absl::string_view sample) override; + size_t GetKeySize() const override; + size_t GetNoncePrefixSize() const override; + size_t GetIVSize() const override; + size_t GetMaxPlaintextSize(size_t ciphertext_size) const override; + size_t GetCiphertextSize(size_t plaintext_size) const override; + QuicPacketCount GetConfidentialityLimit() const override; + absl::string_view GetKey() const override; + absl::string_view GetNoncePrefix() const override; + + private: + size_t GetHashLength() const; + + Perspective perspective_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_NULL_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/null_encrypter_test.cc b/quiche/quic/core/crypto/null_encrypter_test.cc new file mode 100644 index 000000000000..85a30115a189 --- /dev/null +++ b/quiche/quic/core/crypto/null_encrypter_test.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/null_encrypter.h" + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { + +class NullEncrypterTest : public QuicTestWithParam {}; + +TEST_F(NullEncrypterTest, EncryptClient) { + unsigned char expected[] = { + // fnv hash + 0x97, + 0xdc, + 0x27, + 0x2f, + 0x18, + 0xa8, + 0x56, + 0x73, + 0xdf, + 0x8d, + 0x1d, + 0xd0, + // payload + 'g', + 'o', + 'o', + 'd', + 'b', + 'y', + 'e', + '!', + }; + char encrypted[256]; + size_t encrypted_len = 0; + NullEncrypter encrypter(Perspective::IS_CLIENT); + ASSERT_TRUE(encrypter.EncryptPacket(0, "hello world!", "goodbye!", encrypted, + &encrypted_len, 256)); + quiche::test::CompareCharArraysWithHexError( + "encrypted data", encrypted, encrypted_len, + reinterpret_cast(expected), ABSL_ARRAYSIZE(expected)); +} + +TEST_F(NullEncrypterTest, EncryptServer) { + unsigned char expected[] = { + // fnv hash + 0x63, + 0x5e, + 0x08, + 0x03, + 0x32, + 0x80, + 0x8f, + 0x73, + 0xdf, + 0x8d, + 0x1d, + 0x1a, + // payload + 'g', + 'o', + 'o', + 'd', + 'b', + 'y', + 'e', + '!', + }; + char encrypted[256]; + size_t encrypted_len = 0; + NullEncrypter encrypter(Perspective::IS_SERVER); + ASSERT_TRUE(encrypter.EncryptPacket(0, "hello world!", "goodbye!", encrypted, + &encrypted_len, 256)); + quiche::test::CompareCharArraysWithHexError( + "encrypted data", encrypted, encrypted_len, + reinterpret_cast(expected), ABSL_ARRAYSIZE(expected)); +} + +TEST_F(NullEncrypterTest, GetMaxPlaintextSize) { + NullEncrypter encrypter(Perspective::IS_CLIENT); + EXPECT_EQ(1000u, encrypter.GetMaxPlaintextSize(1012)); + EXPECT_EQ(100u, encrypter.GetMaxPlaintextSize(112)); + EXPECT_EQ(10u, encrypter.GetMaxPlaintextSize(22)); + EXPECT_EQ(0u, encrypter.GetMaxPlaintextSize(11)); +} + +TEST_F(NullEncrypterTest, GetCiphertextSize) { + NullEncrypter encrypter(Perspective::IS_CLIENT); + EXPECT_EQ(1012u, encrypter.GetCiphertextSize(1000)); + EXPECT_EQ(112u, encrypter.GetCiphertextSize(100)); + EXPECT_EQ(22u, encrypter.GetCiphertextSize(10)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/p256_key_exchange.cc b/quiche/quic/core/crypto/p256_key_exchange.cc new file mode 100644 index 000000000000..6e8e53988c5e --- /dev/null +++ b/quiche/quic/core/crypto/p256_key_exchange.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/p256_key_exchange.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "openssl/ec.h" +#include "openssl/ecdh.h" +#include "openssl/err.h" +#include "openssl/evp.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +P256KeyExchange::P256KeyExchange(bssl::UniquePtr private_key, + const uint8_t* public_key) + : private_key_(std::move(private_key)) { + memcpy(public_key_, public_key, sizeof(public_key_)); +} + +P256KeyExchange::~P256KeyExchange() {} + +// static +std::unique_ptr P256KeyExchange::New() { + return New(P256KeyExchange::NewPrivateKey()); +} + +// static +std::unique_ptr P256KeyExchange::New(absl::string_view key) { + if (key.empty()) { + QUIC_DLOG(INFO) << "Private key is empty"; + return nullptr; + } + + const uint8_t* keyp = reinterpret_cast(key.data()); + bssl::UniquePtr private_key( + d2i_ECPrivateKey(nullptr, &keyp, key.size())); + if (!private_key.get() || !EC_KEY_check_key(private_key.get())) { + QUIC_DLOG(INFO) << "Private key is invalid."; + return nullptr; + } + + uint8_t public_key[kUncompressedP256PointBytes]; + if (EC_POINT_point2oct(EC_KEY_get0_group(private_key.get()), + EC_KEY_get0_public_key(private_key.get()), + POINT_CONVERSION_UNCOMPRESSED, public_key, + sizeof(public_key), nullptr) != sizeof(public_key)) { + QUIC_DLOG(INFO) << "Can't get public key."; + return nullptr; + } + + return absl::WrapUnique( + new P256KeyExchange(std::move(private_key), public_key)); +} + +// static +std::string P256KeyExchange::NewPrivateKey() { + bssl::UniquePtr key(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1)); + if (!key.get() || !EC_KEY_generate_key(key.get())) { + QUIC_DLOG(INFO) << "Can't generate a new private key."; + return std::string(); + } + + int key_len = i2d_ECPrivateKey(key.get(), nullptr); + if (key_len <= 0) { + QUIC_DLOG(INFO) << "Can't convert private key to string"; + return std::string(); + } + std::unique_ptr private_key(new uint8_t[key_len]); + uint8_t* keyp = private_key.get(); + if (!i2d_ECPrivateKey(key.get(), &keyp)) { + QUIC_DLOG(INFO) << "Can't convert private key to string."; + return std::string(); + } + return std::string(reinterpret_cast(private_key.get()), key_len); +} + +bool P256KeyExchange::CalculateSharedKeySync( + absl::string_view peer_public_value, std::string* shared_key) const { + if (peer_public_value.size() != kUncompressedP256PointBytes) { + QUIC_DLOG(INFO) << "Peer public value is invalid"; + return false; + } + + bssl::UniquePtr point( + EC_POINT_new(EC_KEY_get0_group(private_key_.get()))); + if (!point.get() || + !EC_POINT_oct2point(/* also test if point is on curve */ + EC_KEY_get0_group(private_key_.get()), point.get(), + reinterpret_cast( + peer_public_value.data()), + peer_public_value.size(), nullptr)) { + QUIC_DLOG(INFO) << "Can't convert peer public value to curve point."; + return false; + } + + uint8_t result[kP256FieldBytes]; + if (ECDH_compute_key(result, sizeof(result), point.get(), private_key_.get(), + nullptr) != sizeof(result)) { + QUIC_DLOG(INFO) << "Can't compute ECDH shared key."; + return false; + } + + shared_key->assign(reinterpret_cast(result), sizeof(result)); + return true; +} + +absl::string_view P256KeyExchange::public_value() const { + return absl::string_view(reinterpret_cast(public_key_), + sizeof(public_key_)); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/p256_key_exchange.h b/quiche/quic/core/crypto/p256_key_exchange.h new file mode 100644 index 000000000000..1341331f5841 --- /dev/null +++ b/quiche/quic/core/crypto/p256_key_exchange.h @@ -0,0 +1,68 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_P256_KEY_EXCHANGE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_P256_KEY_EXCHANGE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "quiche/quic/core/crypto/key_exchange.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// P256KeyExchange implements a SynchronousKeyExchange using elliptic-curve +// Diffie-Hellman on NIST P-256. +class QUIC_EXPORT_PRIVATE P256KeyExchange : public SynchronousKeyExchange { + public: + ~P256KeyExchange() override; + + // New generates a private key and then creates new key-exchange object. + static std::unique_ptr New(); + + // New creates a new key-exchange object from a private key. If |private_key| + // is invalid, nullptr is returned. + static std::unique_ptr New(absl::string_view private_key); + + // NewPrivateKey returns a private key, suitable for passing to |New|. + // If |NewPrivateKey| can't generate a private key, it returns an empty + // string. + static std::string NewPrivateKey(); + + // SynchronousKeyExchange interface. + bool CalculateSharedKeySync(absl::string_view peer_public_value, + std::string* shared_key) const override; + absl::string_view public_value() const override; + QuicTag type() const override { return kP256; } + + private: + enum { + // A P-256 field element consists of 32 bytes. + kP256FieldBytes = 32, + // A P-256 point in uncompressed form consists of 0x04 (to denote + // that the point is uncompressed) followed by two, 32-byte field + // elements. + kUncompressedP256PointBytes = 1 + 2 * kP256FieldBytes, + // The first byte in an uncompressed P-256 point. + kUncompressedECPointForm = 0x04, + }; + + // P256KeyExchange wraps |private_key|, and expects |public_key| consists of + // |kUncompressedP256PointBytes| bytes. + P256KeyExchange(bssl::UniquePtr private_key, + const uint8_t* public_key); + P256KeyExchange(const P256KeyExchange&) = delete; + P256KeyExchange& operator=(const P256KeyExchange&) = delete; + + bssl::UniquePtr private_key_; + // The public key stored as an uncompressed P-256 point. + uint8_t public_key_[kUncompressedP256PointBytes]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_P256_KEY_EXCHANGE_H_ diff --git a/quiche/quic/core/crypto/p256_key_exchange_test.cc b/quiche/quic/core/crypto/p256_key_exchange_test.cc new file mode 100644 index 000000000000..c9bc7d3f0944 --- /dev/null +++ b/quiche/quic/core/crypto/p256_key_exchange_test.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/p256_key_exchange.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class P256KeyExchangeTest : public QuicTest { + public: + // Holds the result of a key exchange callback. + class TestCallbackResult { + public: + void set_ok(bool ok) { ok_ = ok; } + bool ok() { return ok_; } + + private: + bool ok_ = false; + }; + + // Key exchange callback which sets the result into the specified + // TestCallbackResult. + class TestCallback : public AsynchronousKeyExchange::Callback { + public: + TestCallback(TestCallbackResult* result) : result_(result) {} + virtual ~TestCallback() = default; + + void Run(bool ok) { result_->set_ok(ok); } + + private: + TestCallbackResult* result_; + }; +}; + +// SharedKeyAsync just tests that the basic asynchronous key exchange identity +// holds: that both parties end up with the same key. +TEST_F(P256KeyExchangeTest, SharedKey) { + for (int i = 0; i < 5; i++) { + std::string alice_private(P256KeyExchange::NewPrivateKey()); + std::string bob_private(P256KeyExchange::NewPrivateKey()); + + ASSERT_FALSE(alice_private.empty()); + ASSERT_FALSE(bob_private.empty()); + ASSERT_NE(alice_private, bob_private); + + std::unique_ptr alice(P256KeyExchange::New(alice_private)); + std::unique_ptr bob(P256KeyExchange::New(bob_private)); + + ASSERT_TRUE(alice != nullptr); + ASSERT_TRUE(bob != nullptr); + + const absl::string_view alice_public(alice->public_value()); + const absl::string_view bob_public(bob->public_value()); + + std::string alice_shared, bob_shared; + ASSERT_TRUE(alice->CalculateSharedKeySync(bob_public, &alice_shared)); + ASSERT_TRUE(bob->CalculateSharedKeySync(alice_public, &bob_shared)); + ASSERT_EQ(alice_shared, bob_shared); + } +} + +// SharedKey just tests that the basic key exchange identity holds: that both +// parties end up with the same key. +TEST_F(P256KeyExchangeTest, AsyncSharedKey) { + for (int i = 0; i < 5; i++) { + std::string alice_private(P256KeyExchange::NewPrivateKey()); + std::string bob_private(P256KeyExchange::NewPrivateKey()); + + ASSERT_FALSE(alice_private.empty()); + ASSERT_FALSE(bob_private.empty()); + ASSERT_NE(alice_private, bob_private); + + std::unique_ptr alice(P256KeyExchange::New(alice_private)); + std::unique_ptr bob(P256KeyExchange::New(bob_private)); + + ASSERT_TRUE(alice != nullptr); + ASSERT_TRUE(bob != nullptr); + + const absl::string_view alice_public(alice->public_value()); + const absl::string_view bob_public(bob->public_value()); + + std::string alice_shared, bob_shared; + TestCallbackResult alice_result; + ASSERT_FALSE(alice_result.ok()); + alice->CalculateSharedKeyAsync( + bob_public, &alice_shared, + std::make_unique(&alice_result)); + ASSERT_TRUE(alice_result.ok()); + TestCallbackResult bob_result; + ASSERT_FALSE(bob_result.ok()); + bob->CalculateSharedKeyAsync(alice_public, &bob_shared, + std::make_unique(&bob_result)); + ASSERT_TRUE(bob_result.ok()); + ASSERT_EQ(alice_shared, bob_shared); + ASSERT_NE(0u, alice_shared.length()); + ASSERT_NE(0u, bob_shared.length()); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/proof_source.cc b/quiche/quic/core/crypto/proof_source.cc new file mode 100644 index 000000000000..b340bc546d8b --- /dev/null +++ b/quiche/quic/core/crypto/proof_source.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/proof_source.h" + +#include + +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +CryptoBuffers::~CryptoBuffers() { + for (size_t i = 0; i < value.size(); i++) { + CRYPTO_BUFFER_free(value[i]); + } +} + +ProofSource::Chain::Chain(const std::vector& certs) + : certs(certs) {} + +ProofSource::Chain::~Chain() {} + +CryptoBuffers ProofSource::Chain::ToCryptoBuffers() const { + CryptoBuffers crypto_buffers; + crypto_buffers.value.reserve(certs.size()); + for (size_t i = 0; i < certs.size(); i++) { + crypto_buffers.value.push_back( + CRYPTO_BUFFER_new(reinterpret_cast(certs[i].data()), + certs[i].length(), nullptr)); + } + return crypto_buffers; +} + +bool ValidateCertAndKey( + const quiche::QuicheReferenceCountedPointer& chain, + const CertificatePrivateKey& key) { + if (chain.get() == nullptr || chain->certs.empty()) { + QUIC_BUG(quic_proof_source_empty_chain) << "Certificate chain is empty"; + return false; + } + + std::unique_ptr leaf = + CertificateView::ParseSingleCertificate(chain->certs[0]); + if (leaf == nullptr) { + QUIC_BUG(quic_proof_source_unparsable_leaf_cert) + << "Unabled to parse leaf certificate"; + return false; + } + + if (!key.MatchesPublicKey(*leaf)) { + QUIC_BUG(quic_proof_source_key_mismatch) + << "Private key does not match the leaf certificate"; + return false; + } + return true; +} + +void ProofSource::OnNewSslCtx(SSL_CTX*) {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/proof_source.h b/quiche/quic/core/crypto/proof_source.h new file mode 100644 index 000000000000..7721554a8bd5 --- /dev/null +++ b/quiche/quic/core/crypto/proof_source.h @@ -0,0 +1,354 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/quic_crypto_proof.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" + +namespace quic { + +namespace test { +class FakeProofSourceHandle; +} // namespace test + +// CryptoBuffers is a RAII class to own a std::vector and the +// buffers the elements point to. +struct QUIC_EXPORT_PRIVATE CryptoBuffers { + CryptoBuffers() = default; + CryptoBuffers(const CryptoBuffers&) = delete; + CryptoBuffers(CryptoBuffers&&) = default; + ~CryptoBuffers(); + + std::vector value; +}; + +// ProofSource is an interface by which a QUIC server can obtain certificate +// chains and signatures that prove its identity. +class QUIC_EXPORT_PRIVATE ProofSource { + public: + // Chain is a reference-counted wrapper for a vector of stringified + // certificates. + struct QUIC_EXPORT_PRIVATE Chain : public quiche::QuicheReferenceCounted { + explicit Chain(const std::vector& certs); + Chain(const Chain&) = delete; + Chain& operator=(const Chain&) = delete; + + CryptoBuffers ToCryptoBuffers() const; + + const std::vector certs; + + protected: + ~Chain() override; + }; + + // Details is an abstract class which acts as a container for any + // implementation-specific details that a ProofSource wants to return. + class QUIC_EXPORT_PRIVATE Details { + public: + virtual ~Details() {} + }; + + // Callback base class for receiving the results of an async call to GetProof. + class QUIC_EXPORT_PRIVATE Callback { + public: + Callback() {} + virtual ~Callback() {} + + // Invoked upon completion of GetProof. + // + // |ok| indicates whether the operation completed successfully. If false, + // the values of the remaining three arguments are undefined. + // + // |chain| is a reference-counted pointer to an object representing the + // certificate chain. + // + // |signature| contains the signature of the server config. + // + // |leaf_cert_sct| holds the signed timestamp (RFC6962) of the leaf cert. + // + // |details| holds a pointer to an object representing the statistics, if + // any, gathered during the operation of GetProof. If no stats are + // available, this will be nullptr. + virtual void Run(bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const QuicCryptoProof& proof, + std::unique_ptr
details) = 0; + + private: + Callback(const Callback&) = delete; + Callback& operator=(const Callback&) = delete; + }; + + // Base class for signalling the completion of a call to ComputeTlsSignature. + class QUIC_EXPORT_PRIVATE SignatureCallback { + public: + SignatureCallback() {} + virtual ~SignatureCallback() = default; + + // Invoked upon completion of ComputeTlsSignature. + // + // |ok| indicates whether the operation completed successfully. + // + // |signature| contains the signature of the data provided to + // ComputeTlsSignature. Its value is undefined if |ok| is false. + // + // |details| holds a pointer to an object representing the statistics, if + // any, gathered during the operation of ComputeTlsSignature. If no stats + // are available, this will be nullptr. + virtual void Run(bool ok, std::string signature, + std::unique_ptr
details) = 0; + + private: + SignatureCallback(const SignatureCallback&) = delete; + SignatureCallback& operator=(const SignatureCallback&) = delete; + }; + + virtual ~ProofSource() {} + + // OnNewSslCtx changes SSL parameters if required by ProofSource + // implementation. It is called when new SSL_CTX is created for a listener. + // Default implementation does nothing. + // + // This function may be called concurrently. + virtual void OnNewSslCtx(SSL_CTX* ssl_ctx); + + // GetProof finds a certificate chain for |hostname| (in leaf-first order), + // and calculates a signature of |server_config| using that chain. + // + // The signature uses SHA-256 as the hash function and PSS padding when the + // key is RSA. + // + // The signature uses SHA-256 as the hash function when the key is ECDSA. + // The signature may use an ECDSA key. + // + // The signature depends on |chlo_hash| which means that the signature can not + // be cached. + // + // |hostname| may be empty to signify that a default certificate should be + // used. + // + // This function may be called concurrently. + // + // Callers should expect that |callback| might be invoked synchronously. + virtual void GetProof(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + const std::string& server_config, + QuicTransportVersion transport_version, + absl::string_view chlo_hash, + std::unique_ptr callback) = 0; + + // Returns the certificate chain for |hostname| in leaf-first order. + // + // Sets *cert_matched_sni to true if the certificate matched the given + // hostname, false if a default cert not matching the hostname was used. + virtual quiche::QuicheReferenceCountedPointer GetCertChain( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) = 0; + + // Computes a signature using the private key of the certificate for + // |hostname|. The value in |in| is signed using the algorithm specified by + // |signature_algorithm|, which is an |SSL_SIGN_*| value (as defined in TLS + // 1.3). Implementations can only assume that |in| is valid during the call to + // ComputeTlsSignature - an implementation computing signatures asynchronously + // must copy it if the value to be signed is used outside of this function. + // + // Callers should expect that |callback| might be invoked synchronously. + virtual void ComputeTlsSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) = 0; + + // Return the list of TLS signature algorithms that is acceptable by the + // ComputeTlsSignature method. If the entire BoringSSL's default list of + // supported signature algorithms are acceptable, return an empty list. + // + // If returns a non-empty list, ComputeTlsSignature will only be called with a + // algorithm in the list. + virtual QuicSignatureAlgorithmVector SupportedTlsSignatureAlgorithms() + const = 0; + + class QUIC_EXPORT_PRIVATE DecryptCallback { + public: + DecryptCallback() = default; + virtual ~DecryptCallback() = default; + + virtual void Run(std::vector plaintext) = 0; + + private: + DecryptCallback(const Callback&) = delete; + DecryptCallback& operator=(const Callback&) = delete; + }; + + // TicketCrypter is an interface for managing encryption and decryption of TLS + // session tickets. A TicketCrypter gets used as an + // SSL_CTX_set_ticket_aead_method in BoringSSL, which has a synchronous + // Encrypt/Seal operation and a potentially asynchronous Decrypt/Open + // operation. This interface allows for ticket decryptions to be performed on + // a remote service. + class QUIC_EXPORT_PRIVATE TicketCrypter { + public: + TicketCrypter() = default; + virtual ~TicketCrypter() = default; + + // MaxOverhead returns the maximum number of bytes of overhead that may get + // added when encrypting the ticket. + virtual size_t MaxOverhead() = 0; + + // Encrypt takes a serialized TLS session ticket in |in|, encrypts it, and + // returns the encrypted ticket. The resulting value must not be larger than + // MaxOverhead bytes larger than |in|. If encryption fails, this method + // returns an empty vector. + // + // If |encryption_key| is nonempty, this method should use it for minting + // TLS resumption tickets. If it is empty, this method may use an + // internally cached encryption key, if available. + virtual std::vector Encrypt(absl::string_view in, + absl::string_view encryption_key) = 0; + + // Decrypt takes an encrypted ticket |in|, decrypts it, and calls + // |callback->Run| with the decrypted ticket, which must not be larger than + // |in|. If decryption fails, the callback is invoked with an empty + // vector. + virtual void Decrypt(absl::string_view in, + std::shared_ptr callback) = 0; + }; + + // Returns the TicketCrypter used for encrypting and decrypting TLS + // session tickets, or nullptr if that functionality is not supported. The + // TicketCrypter returned (if not nullptr) must be valid for the lifetime of + // the ProofSource, and the caller does not take ownership of said + // TicketCrypter. + virtual TicketCrypter* GetTicketCrypter() = 0; +}; + +// ProofSourceHandleCallback is an interface that contains the callbacks when +// the operations in ProofSourceHandle completes. +// TODO(wub): Consider deprecating ProofSource by moving all functionalities of +// ProofSource into ProofSourceHandle. +class QUIC_EXPORT_PRIVATE ProofSourceHandleCallback { + public: + virtual ~ProofSourceHandleCallback() = default; + + // Called when a ProofSourceHandle::SelectCertificate operation completes. + // |ok| indicates whether the operation was successful. + // |is_sync| indicates whether the operation completed synchronously, i.e. + // whether it is completed before ProofSourceHandle::SelectCertificate + // returned. + // |chain| the certificate chain in leaf-first order. + // |handshake_hints| (optional) handshake hints that can be used by + // SSL_set_handshake_hints. + // |ticket_encryption_key| (optional) encryption key to be used for minting + // TLS resumption tickets. + // |cert_matched_sni| is true if the certificate matched the SNI hostname, + // false if a non-matching default cert was used. + // |delayed_ssl_config| contains SSL configs to be applied on the SSL object. + // + // When called asynchronously(is_sync=false), this method will be responsible + // to continue the handshake from where it left off. + virtual void OnSelectCertificateDone( + bool ok, bool is_sync, const ProofSource::Chain* chain, + absl::string_view handshake_hints, + absl::string_view ticket_encryption_key, bool cert_matched_sni, + QuicDelayedSSLConfig delayed_ssl_config) = 0; + + // Called when a ProofSourceHandle::ComputeSignature operation completes. + virtual void OnComputeSignatureDone( + bool ok, bool is_sync, std::string signature, + std::unique_ptr details) = 0; + + // Return true iff ProofSourceHandle::ComputeSignature won't be called later. + // The handle can use this function to release resources promptly. + virtual bool WillNotCallComputeSignature() const = 0; +}; + +// ProofSourceHandle is an interface by which a TlsServerHandshaker can obtain +// certificate chains and signatures that prove its identity. +// The operations this interface supports are similar to those in ProofSource, +// the main difference is that ProofSourceHandle is per-handshaker, so +// an implementation can have states that are shared by multiple calls on the +// same handle. +// +// A handle object is owned by a TlsServerHandshaker. Since there might be an +// async operation pending when the handle destructs, an implementation must +// ensure when such operations finish, their corresponding callback method won't +// be invoked. +// +// A handle will have at most one async operation pending at a time. +class QUIC_EXPORT_PRIVATE ProofSourceHandle { + public: + virtual ~ProofSourceHandle() = default; + + // Close the handle. Cancel the pending operation, if any. + // Once called, any completion method on |callback()| won't be invoked, and + // future SelectCertificate and ComputeSignature calls should return failure. + virtual void CloseHandle() = 0; + + // Starts a select certificate operation. If the operation is not cancelled + // when it completes, callback()->OnSelectCertificateDone will be invoked. + // + // server_address and client_address should be normalized by the caller before + // sending down to this function. + // + // If the operation is handled synchronously: + // - QUIC_SUCCESS or QUIC_FAILURE will be returned. + // - callback()->OnSelectCertificateDone should be invoked before the function + // returns. + // + // If the operation is handled asynchronously: + // - QUIC_PENDING will be returned. + // - When the operation is done, callback()->OnSelectCertificateDone should be + // invoked. + virtual QuicAsyncStatus SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const QuicConnectionId& original_connection_id, + absl::string_view ssl_capabilities, const std::string& hostname, + absl::string_view client_hello, const std::string& alpn, + absl::optional alps, + const std::vector& quic_transport_params, + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) = 0; + + // Starts a compute signature operation. If the operation is not cancelled + // when it completes, callback()->OnComputeSignatureDone will be invoked. + // + // See the comments of SelectCertificate for sync vs. async operations. + virtual QuicAsyncStatus ComputeSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + size_t max_signature_size) = 0; + + protected: + // Returns the object that will be notified when an operation completes. + virtual ProofSourceHandleCallback* callback() = 0; + + private: + friend class test::FakeProofSourceHandle; +}; + +// Returns true if |chain| contains a parsable DER-encoded X.509 leaf cert and +// it matches with |key|. +QUIC_EXPORT_PRIVATE bool ValidateCertAndKey( + const quiche::QuicheReferenceCountedPointer& chain, + const CertificatePrivateKey& key); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_H_ diff --git a/quiche/quic/core/crypto/proof_source_x509.cc b/quiche/quic/core/crypto/proof_source_x509.cc new file mode 100644 index 000000000000..a86c78bf81fa --- /dev/null +++ b/quiche/quic/core/crypto/proof_source_x509.cc @@ -0,0 +1,169 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/proof_source_x509.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +ProofSourceX509::ProofSourceX509( + quiche::QuicheReferenceCountedPointer default_chain, + CertificatePrivateKey default_key) { + if (!AddCertificateChain(default_chain, std::move(default_key))) { + return; + } + default_certificate_ = &certificates_.front(); +} + +std::unique_ptr ProofSourceX509::Create( + quiche::QuicheReferenceCountedPointer default_chain, + CertificatePrivateKey default_key) { + std::unique_ptr result( + new ProofSourceX509(default_chain, std::move(default_key))); + if (!result->valid()) { + return nullptr; + } + return result; +} + +void ProofSourceX509::GetProof( + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, const std::string& hostname, + const std::string& server_config, + QuicTransportVersion /*transport_version*/, absl::string_view chlo_hash, + std::unique_ptr callback) { + QuicCryptoProof proof; + + if (!valid()) { + QUIC_BUG(ProofSourceX509::GetProof called in invalid state) + << "ProofSourceX509::GetProof called while the object is not valid"; + callback->Run(/*ok=*/false, nullptr, proof, nullptr); + return; + } + + absl::optional payload = + CryptoUtils::GenerateProofPayloadToBeSigned(chlo_hash, server_config); + if (!payload.has_value()) { + callback->Run(/*ok=*/false, nullptr, proof, nullptr); + return; + } + + Certificate* certificate = GetCertificate(hostname, &proof.cert_matched_sni); + proof.signature = + certificate->key.Sign(*payload, SSL_SIGN_RSA_PSS_RSAE_SHA256); + MaybeAddSctsForHostname(hostname, proof.leaf_cert_scts); + callback->Run(/*ok=*/!proof.signature.empty(), certificate->chain, proof, + nullptr); +} + +quiche::QuicheReferenceCountedPointer +ProofSourceX509::GetCertChain(const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + const std::string& hostname, + bool* cert_matched_sni) { + if (!valid()) { + QUIC_BUG(ProofSourceX509::GetCertChain called in invalid state) + << "ProofSourceX509::GetCertChain called while the object is not " + "valid"; + return nullptr; + } + + return GetCertificate(hostname, cert_matched_sni)->chain; +} + +void ProofSourceX509::ComputeTlsSignature( + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) { + if (!valid()) { + QUIC_BUG(ProofSourceX509::ComputeTlsSignature called in invalid state) + << "ProofSourceX509::ComputeTlsSignature called while the object is " + "not valid"; + callback->Run(/*ok=*/false, "", nullptr); + return; + } + + bool cert_matched_sni; + std::string signature = GetCertificate(hostname, &cert_matched_sni) + ->key.Sign(in, signature_algorithm); + callback->Run(/*ok=*/!signature.empty(), signature, nullptr); +} + +QuicSignatureAlgorithmVector ProofSourceX509::SupportedTlsSignatureAlgorithms() + const { + return SupportedSignatureAlgorithmsForQuic(); +} + +ProofSource::TicketCrypter* ProofSourceX509::GetTicketCrypter() { + return nullptr; +} + +bool ProofSourceX509::AddCertificateChain( + quiche::QuicheReferenceCountedPointer chain, + CertificatePrivateKey key) { + if (chain->certs.empty()) { + QUIC_BUG(quic_bug_10644_1) << "Empty certificate chain supplied."; + return false; + } + + std::unique_ptr leaf = + CertificateView::ParseSingleCertificate(chain->certs[0]); + if (leaf == nullptr) { + QUIC_BUG(quic_bug_10644_2) + << "Unable to parse X.509 leaf certificate in the supplied chain."; + return false; + } + if (!key.MatchesPublicKey(*leaf)) { + QUIC_BUG(quic_bug_10644_3) + << "Private key does not match the leaf certificate."; + return false; + } + + certificates_.push_front(Certificate{ + chain, + std::move(key), + }); + Certificate* certificate = &certificates_.front(); + + for (absl::string_view host : leaf->subject_alt_name_domains()) { + certificate_map_[std::string(host)] = certificate; + } + return true; +} + +ProofSourceX509::Certificate* ProofSourceX509::GetCertificate( + const std::string& hostname, bool* cert_matched_sni) const { + QUICHE_DCHECK(valid()); + auto it = certificate_map_.find(hostname); + if (it != certificate_map_.end()) { + *cert_matched_sni = true; + return it->second; + } + auto dot_pos = hostname.find('.'); + if (dot_pos != std::string::npos) { + std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos)); + it = certificate_map_.find(wildcard); + if (it != certificate_map_.end()) { + *cert_matched_sni = true; + return it->second; + } + } + *cert_matched_sni = false; + return default_certificate_; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/proof_source_x509.h b/quiche/quic/core/crypto/proof_source_x509.h new file mode 100644 index 000000000000..fa62bbf90d34 --- /dev/null +++ b/quiche/quic/core/crypto/proof_source_x509.h @@ -0,0 +1,84 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_X509_H_ +#define QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_X509_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/node_hash_map.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/quic_crypto_proof.h" + +namespace quic { + +// ProofSourceX509 accepts X.509 certificates with private keys and picks a +// certificate internally based on its SubjectAltName value. +class QUIC_EXPORT_PRIVATE ProofSourceX509 : public ProofSource { + public: + // Creates a proof source that uses |default_chain| when no SubjectAltName + // value matches. Returns nullptr if |default_chain| is invalid. + static std::unique_ptr Create( + quiche::QuicheReferenceCountedPointer default_chain, + CertificatePrivateKey default_key); + + // ProofSource implementation. + void GetProof(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, const std::string& server_config, + QuicTransportVersion transport_version, + absl::string_view chlo_hash, + std::unique_ptr callback) override; + quiche::QuicheReferenceCountedPointer GetCertChain( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; + void ComputeTlsSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) override; + QuicSignatureAlgorithmVector SupportedTlsSignatureAlgorithms() const override; + TicketCrypter* GetTicketCrypter() override; + + // Adds a certificate chain to the verifier. Returns false if the chain is + // not valid. Newer certificates will override older certificates with the + // same SubjectAltName value. + ABSL_MUST_USE_RESULT bool AddCertificateChain( + quiche::QuicheReferenceCountedPointer chain, + CertificatePrivateKey key); + + protected: + ProofSourceX509(quiche::QuicheReferenceCountedPointer default_chain, + CertificatePrivateKey default_key); + bool valid() const { return default_certificate_ != nullptr; } + + // Gives an opportunity for the subclass proof source to provide SCTs for a + // given hostname. + virtual void MaybeAddSctsForHostname(absl::string_view /*hostname*/, + std::string& /*leaf_cert_scts*/) {} + + private: + struct QUIC_EXPORT_PRIVATE Certificate { + quiche::QuicheReferenceCountedPointer chain; + CertificatePrivateKey key; + }; + + // Looks up certficiate for hostname, returns the default if no certificate is + // found. + Certificate* GetCertificate(const std::string& hostname, + bool* cert_matched_sni) const; + + std::forward_list certificates_; + Certificate* default_certificate_ = nullptr; + absl::node_hash_map certificate_map_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_X509_H_ diff --git a/quiche/quic/core/crypto/proof_source_x509_test.cc b/quiche/quic/core/crypto/proof_source_x509_test.cc new file mode 100644 index 000000000000..6db9c75ca1b1 --- /dev/null +++ b/quiche/quic/core/crypto/proof_source_x509_test.cc @@ -0,0 +1,142 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/proof_source_x509.h" + +#include + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/test_certificates.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" + +namespace quic { +namespace test { +namespace { + +quiche::QuicheReferenceCountedPointer MakeChain( + absl::string_view cert) { + return quiche::QuicheReferenceCountedPointer( + new ProofSource::Chain(std::vector{std::string(cert)})); +} + +class ProofSourceX509Test : public QuicTest { + public: + ProofSourceX509Test() + : test_chain_(MakeChain(kTestCertificate)), + wildcard_chain_(MakeChain(kWildcardCertificate)), + test_key_( + CertificatePrivateKey::LoadFromDer(kTestCertificatePrivateKey)), + wildcard_key_(CertificatePrivateKey::LoadFromDer( + kWildcardCertificatePrivateKey)) { + QUICHE_CHECK(test_key_ != nullptr); + QUICHE_CHECK(wildcard_key_ != nullptr); + } + + protected: + quiche::QuicheReferenceCountedPointer test_chain_, + wildcard_chain_; + std::unique_ptr test_key_, wildcard_key_; +}; + +TEST_F(ProofSourceX509Test, AddCertificates) { + std::unique_ptr proof_source = + ProofSourceX509::Create(test_chain_, std::move(*test_key_)); + ASSERT_TRUE(proof_source != nullptr); + EXPECT_TRUE(proof_source->AddCertificateChain(wildcard_chain_, + std::move(*wildcard_key_))); +} + +TEST_F(ProofSourceX509Test, AddCertificateKeyMismatch) { + std::unique_ptr proof_source = + ProofSourceX509::Create(test_chain_, std::move(*test_key_)); + ASSERT_TRUE(proof_source != nullptr); + test_key_ = CertificatePrivateKey::LoadFromDer(kTestCertificatePrivateKey); + EXPECT_QUIC_BUG((void)proof_source->AddCertificateChain( + wildcard_chain_, std::move(*test_key_)), + "Private key does not match"); +} + +TEST_F(ProofSourceX509Test, CertificateSelection) { + std::unique_ptr proof_source = + ProofSourceX509::Create(test_chain_, std::move(*test_key_)); + ASSERT_TRUE(proof_source != nullptr); + ASSERT_TRUE(proof_source->AddCertificateChain(wildcard_chain_, + std::move(*wildcard_key_))); + + // Default certificate. + bool cert_matched_sni; + EXPECT_EQ(proof_source + ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), + "unknown.test", &cert_matched_sni) + ->certs[0], + kTestCertificate); + EXPECT_FALSE(cert_matched_sni); + // mail.example.org is explicitly a SubjectAltName in kTestCertificate. + EXPECT_EQ(proof_source + ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), + "mail.example.org", &cert_matched_sni) + ->certs[0], + kTestCertificate); + EXPECT_TRUE(cert_matched_sni); + // www.foo.test is in kWildcardCertificate. + EXPECT_EQ(proof_source + ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), + "www.foo.test", &cert_matched_sni) + ->certs[0], + kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); + // *.wildcard.test is in kWildcardCertificate. + EXPECT_EQ(proof_source + ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), + "www.wildcard.test", &cert_matched_sni) + ->certs[0], + kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); + EXPECT_EQ(proof_source + ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), + "etc.wildcard.test", &cert_matched_sni) + ->certs[0], + kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); + // wildcard.test itself is not in kWildcardCertificate. + EXPECT_EQ(proof_source + ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), + "wildcard.test", &cert_matched_sni) + ->certs[0], + kTestCertificate); + EXPECT_FALSE(cert_matched_sni); +} + +TEST_F(ProofSourceX509Test, TlsSignature) { + class Callback : public ProofSource::SignatureCallback { + public: + void Run(bool ok, std::string signature, + std::unique_ptr /*details*/) override { + ASSERT_TRUE(ok); + std::unique_ptr view = + CertificateView::ParseSingleCertificate(kTestCertificate); + EXPECT_TRUE(view->VerifySignature("Test data", signature, + SSL_SIGN_RSA_PSS_RSAE_SHA256)); + } + }; + + std::unique_ptr proof_source = + ProofSourceX509::Create(test_chain_, std::move(*test_key_)); + ASSERT_TRUE(proof_source != nullptr); + + proof_source->ComputeTlsSignature(QuicSocketAddress(), QuicSocketAddress(), + "example.com", SSL_SIGN_RSA_PSS_RSAE_SHA256, + "Test data", std::make_unique()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/proof_verifier.h b/quiche/quic/core/crypto/proof_verifier.h new file mode 100644 index 000000000000..e9c6f0353652 --- /dev/null +++ b/quiche/quic/core/crypto/proof_verifier.h @@ -0,0 +1,117 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_PROOF_VERIFIER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_PROOF_VERIFIER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// ProofVerifyDetails is an abstract class that acts as a container for any +// implementation specific details that a ProofVerifier wishes to return. These +// details are saved in the CachedState for the origin in question. +class QUIC_EXPORT_PRIVATE ProofVerifyDetails { + public: + virtual ~ProofVerifyDetails() {} + + // Returns an new ProofVerifyDetails object with the same contents + // as this one. + virtual ProofVerifyDetails* Clone() const = 0; +}; + +// ProofVerifyContext is an abstract class that acts as a container for any +// implementation specific context that a ProofVerifier needs. +class QUIC_EXPORT_PRIVATE ProofVerifyContext { + public: + virtual ~ProofVerifyContext() {} +}; + +// ProofVerifierCallback provides a generic mechanism for a ProofVerifier to +// call back after an asynchronous verification. +class QUIC_EXPORT_PRIVATE ProofVerifierCallback { + public: + virtual ~ProofVerifierCallback() {} + + // Run is called on the original thread to mark the completion of an + // asynchonous verification. If |ok| is true then the certificate is valid + // and |error_details| is unused. Otherwise, |error_details| contains a + // description of the error. |details| contains implementation-specific + // details of the verification. |Run| may take ownership of |details| by + // calling |release| on it. + virtual void Run(bool ok, const std::string& error_details, + std::unique_ptr* details) = 0; +}; + +// A ProofVerifier checks the signature on a server config, and the certificate +// chain that backs the public key. +class QUIC_EXPORT_PRIVATE ProofVerifier { + public: + virtual ~ProofVerifier() {} + + // VerifyProof checks that |signature| is a valid signature of + // |server_config| by the public key in the leaf certificate of |certs|, and + // that |certs| is a valid chain for |hostname|. On success, it returns + // QUIC_SUCCESS. On failure, it returns QUIC_FAILURE and sets |*error_details| + // to a description of the problem. In either case it may set |*details|, + // which the caller takes ownership of. + // + // |context| specifies an implementation specific struct (which may be nullptr + // for some implementations) that provides useful information for the + // verifier, e.g. logging handles. + // + // This function may also return QUIC_PENDING, in which case the ProofVerifier + // will call back, on the original thread, via |callback| when complete. + // + // The signature uses SHA-256 as the hash function and PSS padding in the + // case of RSA. + virtual QuicAsyncStatus VerifyProof( + const std::string& hostname, const uint16_t port, + const std::string& server_config, QuicTransportVersion transport_version, + absl::string_view chlo_hash, const std::vector& certs, + const std::string& cert_sct, const std::string& signature, + const ProofVerifyContext* context, std::string* error_details, + std::unique_ptr* details, + std::unique_ptr callback) = 0; + + // VerifyCertChain checks that |certs| is a valid chain for |hostname|. On + // success, it returns QUIC_SUCCESS. On failure, it returns QUIC_FAILURE and + // sets |*error_details| to a description of the problem. In either case it + // may set |*details|, which the caller takes ownership of. + // + // |context| specifies an implementation specific struct (which may be nullptr + // for some implementations) that provides useful information for the + // verifier, e.g. logging handles. + // + // If certificate verification fails, a TLS alert will be sent when closing + // the connection. This alert defaults to certificate_unknown. By setting + // |*out_alert|, a different alert can be sent to provide a more specific + // reason why verification failed. + // + // This function may also return QUIC_PENDING, in which case the ProofVerifier + // will call back, on the original thread, via |callback| when complete. + // In this case, the ProofVerifier will take ownership of |callback|. + virtual QuicAsyncStatus VerifyCertChain( + const std::string& hostname, const uint16_t port, + const std::vector& certs, const std::string& ocsp_response, + const std::string& cert_sct, const ProofVerifyContext* context, + std::string* error_details, std::unique_ptr* details, + uint8_t* out_alert, std::unique_ptr callback) = 0; + + // Returns a ProofVerifyContext instance which can be use for subsequent + // verifications. Applications may chose create a different context and + // supply it for verifications instead. + virtual std::unique_ptr CreateDefaultContext() = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_PROOF_VERIFIER_H_ diff --git a/quiche/quic/core/crypto/quic_client_session_cache.cc b/quiche/quic/core/crypto/quic_client_session_cache.cc new file mode 100644 index 000000000000..32f115dca75e --- /dev/null +++ b/quiche/quic/core/crypto/quic_client_session_cache.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_client_session_cache.h" + +#include "quiche/quic/core/quic_clock.h" + +namespace quic { + +namespace { + +const size_t kDefaultMaxEntries = 1024; +// Returns false if the SSL |session| doesn't exist or it is expired at |now|. +bool IsValid(SSL_SESSION* session, uint64_t now) { + if (!session) return false; + + // now_u64 may be slightly behind because of differences in how + // time is calculated at this layer versus BoringSSL. + // Add a second of wiggle room to account for this. + return !(now + 1 < SSL_SESSION_get_time(session) || + now >= SSL_SESSION_get_time(session) + + SSL_SESSION_get_timeout(session)); +} + +bool DoApplicationStatesMatch(const ApplicationState* state, + ApplicationState* other) { + if ((state && !other) || (!state && other)) return false; + if ((!state && !other) || *state == *other) return true; + return false; +} + +} // namespace + +QuicClientSessionCache::QuicClientSessionCache() + : QuicClientSessionCache(kDefaultMaxEntries) {} + +QuicClientSessionCache::QuicClientSessionCache(size_t max_entries) + : cache_(max_entries) {} + +QuicClientSessionCache::~QuicClientSessionCache() { Clear(); } + +void QuicClientSessionCache::Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) { + QUICHE_DCHECK(session) << "TLS session is not inserted into client cache."; + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) { + CreateAndInsertEntry(server_id, std::move(session), params, + application_state); + return; + } + + QUICHE_DCHECK(iter->second->params); + // The states are both the same, so only need to insert sessions. + if (params == *iter->second->params && + DoApplicationStatesMatch(application_state, + iter->second->application_state.get())) { + iter->second->PushSession(std::move(session)); + return; + } + // Erase the existing entry because this Insert call must come from a + // different QUIC session. + cache_.Erase(iter); + CreateAndInsertEntry(server_id, std::move(session), params, + application_state); +} + +std::unique_ptr QuicClientSessionCache::Lookup( + const QuicServerId& server_id, QuicWallTime now, const SSL_CTX* /*ctx*/) { + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) return nullptr; + + if (!IsValid(iter->second->PeekSession(), now.ToUNIXSeconds())) { + QUIC_DLOG(INFO) << "TLS Session expired for host:" << server_id.host(); + cache_.Erase(iter); + return nullptr; + } + auto state = std::make_unique(); + state->tls_session = iter->second->PopSession(); + if (iter->second->params != nullptr) { + state->transport_params = + std::make_unique(*iter->second->params); + } + if (iter->second->application_state != nullptr) { + state->application_state = + std::make_unique(*iter->second->application_state); + } + if (!iter->second->token.empty()) { + state->token = iter->second->token; + // Clear token after use. + iter->second->token.clear(); + } + + return state; +} + +void QuicClientSessionCache::ClearEarlyData(const QuicServerId& server_id) { + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) return; + for (auto& session : iter->second->sessions) { + if (session) { + QUIC_DLOG(INFO) << "Clear early data for for host: " << server_id.host(); + session.reset(SSL_SESSION_copy_without_early_data(session.get())); + } + } +} + +void QuicClientSessionCache::OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) { + if (token.empty()) { + return; + } + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) { + return; + } + iter->second->token = std::string(token); +} + +void QuicClientSessionCache::RemoveExpiredEntries(QuicWallTime now) { + auto iter = cache_.begin(); + while (iter != cache_.end()) { + if (!IsValid(iter->second->PeekSession(), now.ToUNIXSeconds())) { + iter = cache_.Erase(iter); + } else { + ++iter; + } + } +} + +void QuicClientSessionCache::Clear() { cache_.Clear(); } + +void QuicClientSessionCache::CreateAndInsertEntry( + const QuicServerId& server_id, bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) { + auto entry = std::make_unique(); + entry->PushSession(std::move(session)); + entry->params = std::make_unique(params); + if (application_state) { + entry->application_state = + std::make_unique(*application_state); + } + cache_.Insert(server_id, std::move(entry)); +} + +QuicClientSessionCache::Entry::Entry() = default; +QuicClientSessionCache::Entry::Entry(Entry&&) = default; +QuicClientSessionCache::Entry::~Entry() = default; + +void QuicClientSessionCache::Entry::PushSession( + bssl::UniquePtr session) { + if (sessions[0] != nullptr) { + sessions[1] = std::move(sessions[0]); + } + sessions[0] = std::move(session); +} + +bssl::UniquePtr QuicClientSessionCache::Entry::PopSession() { + if (sessions[0] == nullptr) return nullptr; + bssl::UniquePtr session = std::move(sessions[0]); + sessions[0] = std::move(sessions[1]); + sessions[1] = nullptr; + return session; +} + +SSL_SESSION* QuicClientSessionCache::Entry::PeekSession() { + return sessions[0].get(); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_client_session_cache.h b/quiche/quic/core/crypto/quic_client_session_cache.h new file mode 100644 index 000000000000..e568db67bbb9 --- /dev/null +++ b/quiche/quic/core/crypto/quic_client_session_cache.h @@ -0,0 +1,82 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_CLIENT_SESSION_CACHE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_CLIENT_SESSION_CACHE_H_ + +#include + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/quic_lru_cache.h" +#include "quiche/quic/core/quic_server_id.h" + +namespace quic { + +namespace test { +class QuicClientSessionCachePeer; +} // namespace test + +// QuicClientSessionCache maps from QuicServerId to information used to resume +// TLS sessions for that server. +class QUIC_EXPORT_PRIVATE QuicClientSessionCache : public SessionCache { + public: + QuicClientSessionCache(); + explicit QuicClientSessionCache(size_t max_entries); + ~QuicClientSessionCache() override; + + void Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) override; + + std::unique_ptr Lookup(const QuicServerId& server_id, + QuicWallTime now, + const SSL_CTX* ctx) override; + + void ClearEarlyData(const QuicServerId& server_id) override; + + void OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) override; + + void RemoveExpiredEntries(QuicWallTime now) override; + + void Clear() override; + + size_t size() const { return cache_.Size(); } + + private: + friend class test::QuicClientSessionCachePeer; + + struct QUIC_EXPORT_PRIVATE Entry { + Entry(); + Entry(Entry&&); + ~Entry(); + + // Adds a new |session| onto sessions, dropping the oldest one if two are + // already stored. + void PushSession(bssl::UniquePtr session); + + // Retrieves the latest session from the entry, meanwhile removing it. + bssl::UniquePtr PopSession(); + + SSL_SESSION* PeekSession(); + + bssl::UniquePtr sessions[2]; + std::unique_ptr params; + std::unique_ptr application_state; + std::string token; // An opaque string received in NEW_TOKEN frame. + }; + + // Creates a new entry and insert into |cache_|. + void CreateAndInsertEntry(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state); + + QuicLRUCache cache_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_CLIENT_SESSION_CACHE_H_ diff --git a/quiche/quic/core/crypto/quic_client_session_cache_test.cc b/quiche/quic/core/crypto/quic_client_session_cache_test.cc new file mode 100644 index 000000000000..880770a77433 --- /dev/null +++ b/quiche/quic/core/crypto/quic_client_session_cache_test.cc @@ -0,0 +1,440 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_client_session_cache.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace test { +namespace { + +const QuicTime::Delta kTimeout = QuicTime::Delta::FromSeconds(1000); +const QuicVersionLabel kFakeVersionLabel = 0x01234567; +const QuicVersionLabel kFakeVersionLabel2 = 0x89ABCDEF; +const uint64_t kFakeIdleTimeoutMilliseconds = 12012; +const uint8_t kFakeStatelessResetTokenData[16] = { + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F}; +const uint64_t kFakeMaxPacketSize = 9001; +const uint64_t kFakeInitialMaxData = 101; +const bool kFakeDisableMigration = true; +const auto kCustomParameter1 = + static_cast(0xffcd); +const char* kCustomParameter1Value = "foo"; +const auto kCustomParameter2 = + static_cast(0xff34); +const char* kCustomParameter2Value = "bar"; + +std::vector CreateFakeStatelessResetToken() { + return std::vector( + kFakeStatelessResetTokenData, + kFakeStatelessResetTokenData + sizeof(kFakeStatelessResetTokenData)); +} + +TransportParameters::LegacyVersionInformation +CreateFakeLegacyVersionInformation() { + TransportParameters::LegacyVersionInformation legacy_version_information; + legacy_version_information.version = kFakeVersionLabel; + legacy_version_information.supported_versions.push_back(kFakeVersionLabel); + legacy_version_information.supported_versions.push_back(kFakeVersionLabel2); + return legacy_version_information; +} + +TransportParameters::VersionInformation CreateFakeVersionInformation() { + TransportParameters::VersionInformation version_information; + version_information.chosen_version = kFakeVersionLabel; + version_information.other_versions.push_back(kFakeVersionLabel); + return version_information; +} + +// Make a TransportParameters that has a few fields set to help test comparison. +std::unique_ptr MakeFakeTransportParams() { + auto params = std::make_unique(); + params->perspective = Perspective::IS_CLIENT; + params->legacy_version_information = CreateFakeLegacyVersionInformation(); + params->version_information = CreateFakeVersionInformation(); + params->max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + params->stateless_reset_token = CreateFakeStatelessResetToken(); + params->max_udp_payload_size.set_value(kFakeMaxPacketSize); + params->initial_max_data.set_value(kFakeInitialMaxData); + params->disable_active_migration = kFakeDisableMigration; + params->custom_parameters[kCustomParameter1] = kCustomParameter1Value; + params->custom_parameters[kCustomParameter2] = kCustomParameter2Value; + return params; +} + +// Generated by running TlsClientHandshakerTest.ZeroRttResumption and in +// TlsClientHandshaker::InsertSession calling SSL_SESSION_to_bytes to serialize +// the received 0-RTT capable ticket. +static const char kCachedSession[] = + "30820ad7020101020203040402130104206594ce84e61a866b56163c4ba09079aebf1d4f" + "6cbcbd38dc9d7066a38a76c9cf0420ec9062063582a4cc0a44f9ff93256a195153ba6032" + "0cf3c9189990932d838adaa10602046196f7b9a205020302a300a382039f3082039b3082" + "0183a00302010202021001300d06092a864886f70d010105050030623111300f06035504" + "030c08426f677573204941310b300906035504080c024d41310b30090603550406130255" + "533121301f06092a864886f70d0109011612626f67757340626f6775732d69612e636f6d" + "3110300e060355040a0c07426f6775734941301e170d3231303132383136323030315a17" + "0d3331303132363136323030315a3069311d301b06035504030c14746573745f6563632e" + "6578616d706c652e636f6d310b300906035504080c024d41310b30090603550406130255" + "53311e301c06092a864886f70d010901160f626f67757340626f6775732e636f6d310e30" + "0c060355040a0c05426f6775733059301306072a8648ce3d020106082a8648ce3d030107" + "034200041ba5e2b6f24e64990b9f24ae6d23473d8c77fbcfb7f554f36559529a69a57170" + "a10a81b7fe4a36ebf37b0a8c5e467a8443d8b8c002892aa5c1194bd843f42c9aa31f301d" + "301b0603551d11041430128210746573742e6578616d706c652e636f6d300d06092a8648" + "86f70d0101050500038202010019921d54ac06948763d609215f64f5d6540e3da886c6c9" + "61bc737a437719b4621416ef1229f39282d7d3234e1a5d57535473066233bd246eec8e96" + "1e0633cf4fe014c800e62599981820ec33d92e74ded0fa2953db1d81e19cb6890b6305b6" + "3ede8d3e9fcf3c09f3f57283acf08aa57be4ee9a68d00bb3e2ded5920c619b5d83e5194a" + "adb77ae5d61ed3e0a5670f0ae61cc3197329f0e71e3364dcab0405e9e4a6646adef8f022" + "6415ec16c8046307b1769029fe780bd576114dde2fa9b4a32aa70bc436549a24ee4907a9" + "045f6457ce8dfd8d62cc65315afe798ae1a948eefd70b035d415e73569c48fb20085de1a" + "87de039e6b0b9a5fcb4069df27f3a7a1409e72d1ac739c72f29ef786134207e61c79855f" + "c22e3ee5f6ad59a7b1ff0f18d79776f1c95efaebbebe381664132a58a1e7ff689945b7e0" + "88634b0872feeefbf6be020884b994c6a7ff435f2b3f609077ff97cb509cfa17ff479b34" + "e633e4b5bc46b20c5f27c80a2e2943f795a928acd5a3fc43c3af8425ad600c048b41d87e" + "6361bc72fc4e5e44680a3d325674ba6ffa760d2fc7d9e4847a8e0dd9d35a543324e18b94" + "2d42af6391ed1dd54a39e3f4a4c6b32486eb4ba72815dbd89c56fc053743a0b0483ce676" + "15defce6800c629b99d0cbc56da162487f475b7c246099eaf1e6d10a022b2f49c6af1da3" + "e8ed66096f267c4a76976b9572db7456ef90278330a4020400aa81b60481b3494e534543" + "55524500f3439e548c21d2ad6e5634cc1cc0045730819702010102020304040213010400" + "0420ec9062063582a4cc0a44f9ff93256a195153ba60320cf3c9189990932d838adaa106" + "02046196f7b9a205020302a300a4020400b20302011db5060404130800cdb807020500ff" + "ffffffb9050203093a80ba0404026833bb030101ffbc23042100d27d985bfce04833f02d" + "38366b219f4def42bc4ba1b01844d1778db11731487dbd020400be020400b20302011db3" + "8205da308205d6308203bea00302010202021000300d06092a864886f70d010105050030" + "62310b3009060355040613025553310b300906035504080c024d413110300e060355040a" + "0c07426f67757343413111300f06035504030c08426f6775732043413121301f06092a86" + "4886f70d0109011612626f67757340626f6775732d63612e636f6d3020170d3231303132" + "383136313935385a180f32303730303531313136313935385a30623111300f0603550403" + "0c08426f677573204941310b300906035504080c024d41310b3009060355040613025553" + "3121301f06092a864886f70d0109011612626f67757340626f6775732d69612e636f6d31" + "10300e060355040a0c07426f677573494130820222300d06092a864886f70d0101010500" + "0382020f003082020a028202010096c03a0ffc61bcedcd5ec9bf6f848b8a066b43f08377" + "3af518a6a0044f22e666e24d2ae741954e344302c4be04612185bd53bcd848eb322bf900" + "724eb0848047d647033ffbddb00f01d1de7c1cdb684f83c9bf5fd18ff60afad5a53b0d7d" + "2c2a50abc38df019cd7f50194d05bc4597a1ef8570ea04069a2c36d74496af126573ca18" + "8e470009b56250fadf2a04e837ee3837b36b1f08b7a0cfe2533d05f26484ce4e30203d01" + "517fffd3da63d0341079ddce16e9ab4dbf9d4049e5cc52326031e645dd682fe6220d9e0e" + "95451f5a82f3e1720dc13e8499466426a0bdbea9f6a76b3c9228dd3c79ab4dcc4c145ef0" + "e78d1ee8bfd4650692d7e28a54bed809d8f7b37fe24c586be59cc46638531cb291c8c156" + "8f08d67e768e51563e95a639c1f138b275ffad6a6a2a042ba9e26ad63c2ce63b600013f0" + "a6f0703ee51c4f457f7bab0391c2fc4c5bb3213742c9cf9941bff68cc2e1cc96139d35ed" + "1885244ddde0bf658416c486701841b81f7b17503d08c59a4db08a2a80755e007aa3b6c7" + "eadcaa9e07c8325f3689f100de23970b12c9d9f6d0a8fb35ba0fd75c64410318db4a13ac" + "3972ad16cdf6408af37013c7bcd7c42f20d6d04c3e39436c7531e8dafa219dd04b784ef0" + "3c70ee5a4782b33cafa925aa3deca62a14aed704f179b932efabc2b0c5c15a8a99bfc9e6" + "189dce7da50ea303594b6af9c933dd54b6e9d17c472d0203010001a38193308190300f06" + "03551d130101ff040530030101ff301d0603551d0e041604141a98e80029a80992b7e5e0" + "068ab9b3486cd839d6301f0603551d23041830168014780beeefe2fa419c48a438bdb30b" + "e37ef0b7a94e300b0603551d0f0404030202a430130603551d25040c300a06082b060105" + "05070301301b0603551d11041430128207426f67757343418207426f6775734941300d06" + "092a864886f70d010105050003820201009e822ed8064b1aabaddf1340010ea147f68c06" + "5a5a599ea305349f1b0e545a00817d6e55c7bf85560fab429ca72186c4d520b52f5cc121" + "abd068b06f3111494431d2522efa54642f907059e7db80b73bb5ecf621377195b8700bba" + "df798cece8c67a9571548d0e6592e81ae5d934877cb170aef18d3b97f635600fe0890d98" + "f88b33fe3d1fd34c1c915beae4e5c0b133f476c40b21d220f16ce9cdd9e8f97a36a31723" + "68875f052c9271648d9cb54687c6fdc3ea96f2908003bc5e5e79de00a21da7b8429f8b08" + "af4c4d34641e386d72eabf5f01f106363f2ffd18969bf0bb9a4d17627c6427ff772c4308" + "83c276feef5fc6dba9582c22fdbe9df7e8dfca375695f028ed588df54f3c86462dbf4c07" + "91d80ca738988a1419c86bb4dd8d738b746921f01f39422e5ffd488b6f00195b996e6392" + "3a820a32cd78b5989f339c0fcf4f269103964a30a16347d0ffdc8df1f3653ddc1515fa09" + "22c7aef1af1fbcb23e93ae7622ab1ee11fcfa98319bad4c37c091cad46bd0337b3cc78b5" + "5b9f1ea7994acc1f89c49a0b4cb540d2137e266fd43e56a9b5b778217b6f77df530e1eaf" + "b3417262b5ddb86d3c6c5ac51e3f326c650dcc2434473973b7182c66220d1f3871bde7ee" + "47d3f359d3d4c5bdd61baa684c03db4c75f9d6690c9e6e3abe6eaf5fa2c33c4daf26b373" + "d85a1e8a7d671ac4a0a97b14e36e81280de4593bbb12da7695b5060404130800cdb60301" + "0100b70402020403b807020500ffffffffb9050203093a80ba0404026833bb030101ffbd" + "020400be020400"; + +class QuicClientSessionCacheTest : public QuicTest { + public: + QuicClientSessionCacheTest() : ssl_ctx_(SSL_CTX_new(TLS_method())) { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + + protected: + bssl::UniquePtr NewSSLSession() { + std::string cached_session = + absl::HexStringToBytes(absl::string_view(kCachedSession)); + SSL_SESSION* session = SSL_SESSION_from_bytes( + reinterpret_cast(cached_session.data()), + cached_session.size(), ssl_ctx_.get()); + QUICHE_DCHECK(session); + return bssl::UniquePtr(session); + } + + bssl::UniquePtr MakeTestSession( + QuicTime::Delta timeout = kTimeout) { + bssl::UniquePtr session = NewSSLSession(); + SSL_SESSION_set_time(session.get(), clock_.WallNow().ToUNIXSeconds()); + SSL_SESSION_set_timeout(session.get(), timeout.ToSeconds()); + return session; + } + + bssl::UniquePtr ssl_ctx_; + MockClock clock_; +}; + +// Tests that simple insertion and lookup work correctly. +TEST_F(QuicClientSessionCacheTest, SingleSession) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + + auto params2 = MakeFakeTransportParams(); + auto session2 = MakeTestSession(); + SSL_SESSION* unowned2 = session2.get(); + QuicServerId id2("b.com", 443); + + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(nullptr, cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(0u, cache.size()); + + cache.Insert(id1, std::move(session), *params, nullptr); + EXPECT_EQ(1u, cache.size()); + EXPECT_EQ( + *params, + *(cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())->transport_params)); + EXPECT_EQ(nullptr, cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())); + // No session is available for id1, even though the entry exists. + EXPECT_EQ(1u, cache.size()); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + // Lookup() will trigger a deletion of invalid entry. + EXPECT_EQ(0u, cache.size()); + + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + QuicServerId id3("c.com", 443); + cache.Insert(id3, std::move(session3), *params, nullptr); + cache.Insert(id2, std::move(session2), *params2, nullptr); + EXPECT_EQ(2u, cache.size()); + EXPECT_EQ( + unowned2, + cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ( + unowned3, + cache.Lookup(id3, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + + // Verify that the cache is cleared after Lookups. + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(nullptr, cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(nullptr, cache.Lookup(id3, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(0u, cache.size()); +} + +TEST_F(QuicClientSessionCacheTest, MultipleSessions) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + SSL_SESSION* unowned2 = session2.get(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id1, std::move(session2), *params, nullptr); + cache.Insert(id1, std::move(session3), *params, nullptr); + // The latest session is popped first. + EXPECT_EQ( + unowned3, + cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ( + unowned2, + cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + // Only two sessions are cached. + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +// Test that when a different TransportParameter is inserted for +// the same server id, the existing entry is removed. +TEST_F(QuicClientSessionCacheTest, DifferentTransportParams) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id1, std::move(session2), *params, nullptr); + // tweak the transport parameters a little bit. + params->perspective = Perspective::IS_SERVER; + cache.Insert(id1, std::move(session3), *params, nullptr); + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_EQ(unowned3, resumption_state->tls_session.get()); + EXPECT_EQ(*params.get(), *resumption_state->transport_params); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +TEST_F(QuicClientSessionCacheTest, DifferentApplicationState) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + ApplicationState state; + state.push_back('a'); + + cache.Insert(id1, std::move(session), *params, &state); + cache.Insert(id1, std::move(session2), *params, &state); + cache.Insert(id1, std::move(session3), *params, nullptr); + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_EQ(unowned3, resumption_state->tls_session.get()); + EXPECT_EQ(nullptr, resumption_state->application_state); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +TEST_F(QuicClientSessionCacheTest, BothStatesDifferent) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + ApplicationState state; + state.push_back('a'); + + cache.Insert(id1, std::move(session), *params, &state); + cache.Insert(id1, std::move(session2), *params, &state); + params->perspective = Perspective::IS_SERVER; + cache.Insert(id1, std::move(session3), *params, nullptr); + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_EQ(unowned3, resumption_state->tls_session.get()); + EXPECT_EQ(*params.get(), *resumption_state->transport_params); + EXPECT_EQ(nullptr, resumption_state->application_state); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +// When the size limit is exceeded, the oldest entry should be erased. +TEST_F(QuicClientSessionCacheTest, SizeLimit) { + QuicClientSessionCache cache(2); + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + + auto session2 = MakeTestSession(); + SSL_SESSION* unowned2 = session2.get(); + QuicServerId id2("b.com", 443); + + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + QuicServerId id3("c.com", 443); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id2, std::move(session2), *params, nullptr); + cache.Insert(id3, std::move(session3), *params, nullptr); + + EXPECT_EQ(2u, cache.size()); + EXPECT_EQ( + unowned2, + cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ( + unowned3, + cache.Lookup(id3, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +TEST_F(QuicClientSessionCacheTest, ClearEarlyData) { + QuicClientSessionCache cache; + SSL_CTX_set_early_data_enabled(ssl_ctx_.get(), 1); + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + + EXPECT_TRUE(SSL_SESSION_early_data_capable(session.get())); + EXPECT_TRUE(SSL_SESSION_early_data_capable(session2.get())); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id1, std::move(session2), *params, nullptr); + + cache.ClearEarlyData(id1); + + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_FALSE( + SSL_SESSION_early_data_capable(resumption_state->tls_session.get())); + resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_FALSE( + SSL_SESSION_early_data_capable(resumption_state->tls_session.get())); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +// Expired session isn't considered valid and nullptr will be returned upon +// Lookup. +TEST_F(QuicClientSessionCacheTest, Expiration) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + + auto session2 = MakeTestSession(3 * kTimeout); + SSL_SESSION* unowned2 = session2.get(); + QuicServerId id2("b.com", 443); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id2, std::move(session2), *params, nullptr); + + EXPECT_EQ(2u, cache.size()); + // Expire the session. + clock_.AdvanceTime(kTimeout * 2); + // The entry has not been removed yet. + EXPECT_EQ(2u, cache.size()); + + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(1u, cache.size()); + EXPECT_EQ( + unowned2, + cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ(1u, cache.size()); +} + +TEST_F(QuicClientSessionCacheTest, RemoveExpiredEntriesAndClear) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + quic::QuicServerId id1("a.com", 443); + + auto session2 = MakeTestSession(3 * kTimeout); + quic::QuicServerId id2("b.com", 443); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id2, std::move(session2), *params, nullptr); + + EXPECT_EQ(2u, cache.size()); + // Expire the session. + clock_.AdvanceTime(kTimeout * 2); + // The entry has not been removed yet. + EXPECT_EQ(2u, cache.size()); + + // Flush expired sessions. + cache.RemoveExpiredEntries(clock_.WallNow()); + + // session is expired and should be flushed. + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(1u, cache.size()); + + cache.Clear(); + EXPECT_EQ(0u, cache.size()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_compressed_certs_cache.cc b/quiche/quic/core/crypto/quic_compressed_certs_cache.cc new file mode 100644 index 000000000000..dabbf2402a27 --- /dev/null +++ b/quiche/quic/core/crypto/quic_compressed_certs_cache.cc @@ -0,0 +1,114 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h" + +#include + +namespace quic { + +namespace { + +// Inline helper function for extending a 64-bit |seed| in-place with a 64-bit +// |value|. Based on Boost's hash_combine function. +inline void hash_combine(uint64_t* seed, const uint64_t& val) { + (*seed) ^= val + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); +} + +} // namespace + +const size_t QuicCompressedCertsCache::kQuicCompressedCertsCacheSize = 225; + +QuicCompressedCertsCache::UncompressedCerts::UncompressedCerts() + : chain(nullptr), client_cached_cert_hashes(nullptr) {} + +QuicCompressedCertsCache::UncompressedCerts::UncompressedCerts( + const quiche::QuicheReferenceCountedPointer& chain, + const std::string* client_cached_cert_hashes) + : chain(chain), client_cached_cert_hashes(client_cached_cert_hashes) {} + +QuicCompressedCertsCache::UncompressedCerts::~UncompressedCerts() {} + +QuicCompressedCertsCache::CachedCerts::CachedCerts() {} + +QuicCompressedCertsCache::CachedCerts::CachedCerts( + const UncompressedCerts& uncompressed_certs, + const std::string& compressed_cert) + : chain_(uncompressed_certs.chain), + client_cached_cert_hashes_(*uncompressed_certs.client_cached_cert_hashes), + compressed_cert_(compressed_cert) {} + +QuicCompressedCertsCache::CachedCerts::CachedCerts(const CachedCerts& other) = + default; + +QuicCompressedCertsCache::CachedCerts::~CachedCerts() {} + +bool QuicCompressedCertsCache::CachedCerts::MatchesUncompressedCerts( + const UncompressedCerts& uncompressed_certs) const { + return (client_cached_cert_hashes_ == + *uncompressed_certs.client_cached_cert_hashes && + chain_ == uncompressed_certs.chain); +} + +const std::string* QuicCompressedCertsCache::CachedCerts::compressed_cert() + const { + return &compressed_cert_; +} + +QuicCompressedCertsCache::QuicCompressedCertsCache(int64_t max_num_certs) + : certs_cache_(max_num_certs) {} + +QuicCompressedCertsCache::~QuicCompressedCertsCache() { + // Underlying cache must be cleared before destruction. + certs_cache_.Clear(); +} + +const std::string* QuicCompressedCertsCache::GetCompressedCert( + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes) { + UncompressedCerts uncompressed_certs(chain, &client_cached_cert_hashes); + + uint64_t key = ComputeUncompressedCertsHash(uncompressed_certs); + + CachedCerts* cached_value = nullptr; + auto iter = certs_cache_.Lookup(key); + if (iter != certs_cache_.end()) { + cached_value = iter->second.get(); + } + if (cached_value != nullptr && + cached_value->MatchesUncompressedCerts(uncompressed_certs)) { + return cached_value->compressed_cert(); + } + return nullptr; +} + +void QuicCompressedCertsCache::Insert( + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes, + const std::string& compressed_cert) { + UncompressedCerts uncompressed_certs(chain, &client_cached_cert_hashes); + + uint64_t key = ComputeUncompressedCertsHash(uncompressed_certs); + + // Insert one unit to the cache. + std::unique_ptr cached_certs( + new CachedCerts(uncompressed_certs, compressed_cert)); + certs_cache_.Insert(key, std::move(cached_certs)); +} + +size_t QuicCompressedCertsCache::MaxSize() { return certs_cache_.MaxSize(); } + +size_t QuicCompressedCertsCache::Size() { return certs_cache_.Size(); } + +uint64_t QuicCompressedCertsCache::ComputeUncompressedCertsHash( + const UncompressedCerts& uncompressed_certs) { + uint64_t hash = + std::hash()(*uncompressed_certs.client_cached_cert_hashes); + + hash_combine(&hash, + reinterpret_cast(uncompressed_certs.chain.get())); + return hash; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_compressed_certs_cache.h b/quiche/quic/core/crypto/quic_compressed_certs_cache.h new file mode 100644 index 000000000000..918981e717d6 --- /dev/null +++ b/quiche/quic/core/crypto/quic_compressed_certs_cache.h @@ -0,0 +1,103 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_COMPRESSED_CERTS_CACHE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_COMPRESSED_CERTS_CACHE_H_ + +#include +#include + +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/quic_lru_cache.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QuicCompressedCertsCache is a cache to track most recently compressed certs. +class QUIC_EXPORT_PRIVATE QuicCompressedCertsCache { + public: + explicit QuicCompressedCertsCache(int64_t max_num_certs); + ~QuicCompressedCertsCache(); + + // Returns the pointer to the cached compressed cert if + // |chain, client_cached_cert_hashes| hits cache. + // Otherwise, return nullptr. + // Returned pointer might become invalid on the next call to Insert(). + const std::string* GetCompressedCert( + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes); + + // Inserts the specified + // |chain, client_cached_cert_hashes, compressed_cert| tuple to the cache. + // If the insertion causes the cache to become overfull, entries will + // be deleted in an LRU order to make room. + void Insert( + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes, + const std::string& compressed_cert); + + // Returns max number of cache entries the cache can carry. + size_t MaxSize(); + + // Returns current number of cache entries in the cache. + size_t Size(); + + // Default size of the QuicCompressedCertsCache per server side investigation. + static const size_t kQuicCompressedCertsCacheSize; + + private: + // A wrapper of the tuple: + // |chain, client_cached_cert_hashes| + // to identify uncompressed representation of certs. + struct QUIC_EXPORT_PRIVATE UncompressedCerts { + UncompressedCerts(); + UncompressedCerts( + const quiche::QuicheReferenceCountedPointer& chain, + const std::string* client_cached_cert_hashes); + ~UncompressedCerts(); + + const quiche::QuicheReferenceCountedPointer chain; + const std::string* client_cached_cert_hashes; + }; + + // Certs stored by QuicCompressedCertsCache where uncompressed certs data is + // used to identify the uncompressed representation of certs and + // |compressed_cert| is the cached compressed representation. + class QUIC_EXPORT_PRIVATE CachedCerts { + public: + CachedCerts(); + CachedCerts(const UncompressedCerts& uncompressed_certs, + const std::string& compressed_cert); + CachedCerts(const CachedCerts& other); + ~CachedCerts(); + + // Returns true if the |uncompressed_certs| matches uncompressed + // representation of this cert. + bool MatchesUncompressedCerts( + const UncompressedCerts& uncompressed_certs) const; + + const std::string* compressed_cert() const; + + private: + // Uncompressed certs data. + quiche::QuicheReferenceCountedPointer chain_; + const std::string client_cached_cert_hashes_; + + // Cached compressed representation derived from uncompressed certs. + const std::string compressed_cert_; + }; + + // Computes a uint64_t hash for |uncompressed_certs|. + uint64_t ComputeUncompressedCertsHash( + const UncompressedCerts& uncompressed_certs); + + // Key is a unit64_t hash for UncompressedCerts. Stored associated value is + // CachedCerts which has both original uncompressed certs data and the + // compressed representation of the certs. + QuicLRUCache certs_cache_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_COMPRESSED_CERTS_CACHE_H_ diff --git a/quiche/quic/core/crypto/quic_compressed_certs_cache_test.cc b/quiche/quic/core/crypto/quic_compressed_certs_cache_test.cc new file mode 100644 index 000000000000..b98f9f2cbb80 --- /dev/null +++ b/quiche/quic/core/crypto/quic_compressed_certs_cache_test.cc @@ -0,0 +1,91 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/cert_compressor.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" + +namespace quic { + +namespace test { + +namespace { + +class QuicCompressedCertsCacheTest : public QuicTest { + public: + QuicCompressedCertsCacheTest() + : certs_cache_(QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) {} + + protected: + QuicCompressedCertsCache certs_cache_; +}; + +TEST_F(QuicCompressedCertsCacheTest, CacheHit) { + std::vector certs = {"leaf cert", "intermediate cert", + "root cert"}; + quiche::QuicheReferenceCountedPointer chain( + new ProofSource::Chain(certs)); + std::string cached_certs = "cached certs"; + std::string compressed = "compressed cert"; + + certs_cache_.Insert(chain, cached_certs, compressed); + + const std::string* cached_value = + certs_cache_.GetCompressedCert(chain, cached_certs); + ASSERT_NE(nullptr, cached_value); + EXPECT_EQ(*cached_value, compressed); +} + +TEST_F(QuicCompressedCertsCacheTest, CacheMiss) { + std::vector certs = {"leaf cert", "intermediate cert", + "root cert"}; + quiche::QuicheReferenceCountedPointer chain( + new ProofSource::Chain(certs)); + + std::string cached_certs = "cached certs"; + std::string compressed = "compressed cert"; + + certs_cache_.Insert(chain, cached_certs, compressed); + + EXPECT_EQ(nullptr, + certs_cache_.GetCompressedCert(chain, "mismatched cached certs")); + + // A different chain though with equivalent certs should get a cache miss. + quiche::QuicheReferenceCountedPointer chain2( + new ProofSource::Chain(certs)); + EXPECT_EQ(nullptr, certs_cache_.GetCompressedCert(chain2, cached_certs)); +} + +TEST_F(QuicCompressedCertsCacheTest, CacheMissDueToEviction) { + // Test cache returns a miss when a queried uncompressed certs was cached but + // then evicted. + std::vector certs = {"leaf cert", "intermediate cert", + "root cert"}; + quiche::QuicheReferenceCountedPointer chain( + new ProofSource::Chain(certs)); + + std::string cached_certs = "cached certs"; + std::string compressed = "compressed cert"; + certs_cache_.Insert(chain, cached_certs, compressed); + + // Insert another kQuicCompressedCertsCacheSize certs to evict the first + // cached cert. + for (unsigned int i = 0; + i < QuicCompressedCertsCache::kQuicCompressedCertsCacheSize; i++) { + EXPECT_EQ(certs_cache_.Size(), i + 1); + certs_cache_.Insert(chain, absl::StrCat(i), absl::StrCat(i)); + } + EXPECT_EQ(certs_cache_.MaxSize(), certs_cache_.Size()); + + EXPECT_EQ(nullptr, certs_cache_.GetCompressedCert(chain, cached_certs)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_crypter.cc b/quiche/quic/core/crypto/quic_crypter.cc new file mode 100644 index 000000000000..965023604099 --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypter.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_crypter.h" + +#include "absl/strings/string_view.h" + +namespace quic { + +bool QuicCrypter::SetNoncePrefixOrIV(const ParsedQuicVersion& version, + absl::string_view nonce_prefix_or_iv) { + if (version.UsesInitialObfuscators()) { + return SetIV(nonce_prefix_or_iv); + } + return SetNoncePrefix(nonce_prefix_or_iv); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_crypter.h b/quiche/quic/core/crypto/quic_crypter.h new file mode 100644 index 000000000000..d57a8e63bb93 --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypter.h @@ -0,0 +1,94 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QuicCrypter is the parent class for QuicEncrypter and QuicDecrypter. +// Its purpose is to provide an interface for using methods that are common to +// both classes when operations are being done that apply to both encrypters and +// decrypters. +class QUIC_EXPORT_PRIVATE QuicCrypter { + public: + virtual ~QuicCrypter() {} + + // Sets the symmetric encryption/decryption key. Returns true on success, + // false on failure. + // + // NOTE: The key is the client_write_key or server_write_key derived from + // the master secret. + virtual bool SetKey(absl::string_view key) = 0; + + // Sets the fixed initial bytes of the nonce. Returns true on success, + // false on failure. This method must only be used with Google QUIC crypters. + // + // NOTE: The nonce prefix is the client_write_iv or server_write_iv + // derived from the master secret. A 64-bit packet number will + // be appended to form the nonce. + // + // <------------ 64 bits -----------> + // +---------------------+----------------------------------+ + // | Fixed prefix | packet number | + // +---------------------+----------------------------------+ + // Nonce format + // + // The security of the nonce format requires that QUIC never reuse a + // packet number, even when retransmitting a lost packet. + virtual bool SetNoncePrefix(absl::string_view nonce_prefix) = 0; + + // Sets |iv| as the initialization vector to use when constructing the nonce. + // Returns true on success, false on failure. This method must only be used + // with IETF QUIC crypters. + // + // Google QUIC and IETF QUIC use different nonce constructions. This method + // must be used when using IETF QUIC; SetNoncePrefix must be used when using + // Google QUIC. + // + // The nonce is constructed as follows (draft-ietf-quic-tls-14 section 5.2): + // + // <---------------- max(8, N_MIN) bytes -----------------> + // +--------------------------------------------------------+ + // | packet protection IV | + // +--------------------------------------------------------+ + // XOR + // <------------ 64 bits -----------> + // +---------------------+----------------------------------+ + // | zeroes | reconstructed packet number | + // +---------------------+----------------------------------+ + // + // The nonce is the packet protection IV (|iv|) XOR'd with the left-padded + // reconstructed packet number. + // + // The security of the nonce format requires that QUIC never reuse a + // packet number, even when retransmitting a lost packet. + virtual bool SetIV(absl::string_view iv) = 0; + + // Calls SetNoncePrefix or SetIV depending on whether |version| uses the + // Google QUIC crypto or IETF QUIC nonce construction. + virtual bool SetNoncePrefixOrIV(const ParsedQuicVersion& version, + absl::string_view nonce_prefix_or_iv); + + // Sets the key to use for header protection. + virtual bool SetHeaderProtectionKey(absl::string_view key) = 0; + + // GetKeySize, GetIVSize, and GetNoncePrefixSize are used to know how many + // bytes of key material needs to be derived from the master secret. + + // Returns the size in bytes of a key for the algorithm. + virtual size_t GetKeySize() const = 0; + // Returns the size in bytes of an IV to use with the algorithm. + virtual size_t GetIVSize() const = 0; + // Returns the size in bytes of the fixed initial part of the nonce. + virtual size_t GetNoncePrefixSize() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTER_H_ diff --git a/quiche/quic/core/crypto/quic_crypto_client_config.cc b/quiche/quic/core/crypto/quic_crypto_client_config.cc new file mode 100644 index 000000000000..c9137de70d6f --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_client_config.cc @@ -0,0 +1,842 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/cert_compressor.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/curve25519_key_exchange.h" +#include "quiche/quic/core/crypto/key_exchange.h" +#include "quiche/quic/core/crypto/p256_key_exchange.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/crypto/tls_client_connection.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_client_stats.h" +#include "quiche/quic/platform/api/quic_hostname_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Tracks the reason (the state of the server config) for sending inchoate +// ClientHello to the server. +void RecordInchoateClientHelloReason( + QuicCryptoClientConfig::CachedState::ServerConfigState state) { + QUIC_CLIENT_HISTOGRAM_ENUM( + "QuicInchoateClientHelloReason", state, + QuicCryptoClientConfig::CachedState::SERVER_CONFIG_COUNT, ""); +} + +// Tracks the state of the QUIC server information loaded from the disk cache. +void RecordDiskCacheServerConfigState( + QuicCryptoClientConfig::CachedState::ServerConfigState state) { + QUIC_CLIENT_HISTOGRAM_ENUM( + "QuicServerInfo.DiskCacheState", state, + QuicCryptoClientConfig::CachedState::SERVER_CONFIG_COUNT, ""); +} + +} // namespace + +QuicCryptoClientConfig::QuicCryptoClientConfig( + std::unique_ptr proof_verifier) + : QuicCryptoClientConfig(std::move(proof_verifier), nullptr) {} + +QuicCryptoClientConfig::QuicCryptoClientConfig( + std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : proof_verifier_(std::move(proof_verifier)), + session_cache_(std::move(session_cache)), + ssl_ctx_(TlsClientConnection::CreateSslCtx( + !GetQuicFlag(quic_disable_client_tls_zero_rtt))) { + QUICHE_DCHECK(proof_verifier_.get()); + SetDefaults(); +} + +QuicCryptoClientConfig::~QuicCryptoClientConfig() {} + +QuicCryptoClientConfig::CachedState::CachedState() + : server_config_valid_(false), + expiration_time_(QuicWallTime::Zero()), + generation_counter_(0) {} + +QuicCryptoClientConfig::CachedState::~CachedState() {} + +bool QuicCryptoClientConfig::CachedState::IsComplete(QuicWallTime now) const { + if (server_config_.empty()) { + RecordInchoateClientHelloReason(SERVER_CONFIG_EMPTY); + return false; + } + + if (!server_config_valid_) { + RecordInchoateClientHelloReason(SERVER_CONFIG_INVALID); + return false; + } + + const CryptoHandshakeMessage* scfg = GetServerConfig(); + if (!scfg) { + // Should be impossible short of cache corruption. + RecordInchoateClientHelloReason(SERVER_CONFIG_CORRUPTED); + QUICHE_DCHECK(false); + return false; + } + + if (now.IsBefore(expiration_time_)) { + return true; + } + + QUIC_CLIENT_HISTOGRAM_TIMES( + "QuicClientHelloServerConfig.InvalidDuration", + QuicTime::Delta::FromSeconds(now.ToUNIXSeconds() - + expiration_time_.ToUNIXSeconds()), + QuicTime::Delta::FromSeconds(60), // 1 min. + QuicTime::Delta::FromSeconds(20 * 24 * 3600), // 20 days. + 50, ""); + RecordInchoateClientHelloReason(SERVER_CONFIG_EXPIRED); + return false; +} + +bool QuicCryptoClientConfig::CachedState::IsEmpty() const { + return server_config_.empty(); +} + +const CryptoHandshakeMessage* +QuicCryptoClientConfig::CachedState::GetServerConfig() const { + if (server_config_.empty()) { + return nullptr; + } + + if (!scfg_) { + scfg_ = CryptoFramer::ParseMessage(server_config_); + QUICHE_DCHECK(scfg_.get()); + } + return scfg_.get(); +} + +QuicCryptoClientConfig::CachedState::ServerConfigState +QuicCryptoClientConfig::CachedState::SetServerConfig( + absl::string_view server_config, QuicWallTime now, QuicWallTime expiry_time, + std::string* error_details) { + const bool matches_existing = server_config == server_config_; + + // Even if the new server config matches the existing one, we still wish to + // reject it if it has expired. + std::unique_ptr new_scfg_storage; + const CryptoHandshakeMessage* new_scfg; + + if (!matches_existing) { + new_scfg_storage = CryptoFramer::ParseMessage(server_config); + new_scfg = new_scfg_storage.get(); + } else { + new_scfg = GetServerConfig(); + } + + if (!new_scfg) { + *error_details = "SCFG invalid"; + return SERVER_CONFIG_INVALID; + } + + if (expiry_time.IsZero()) { + uint64_t expiry_seconds; + if (new_scfg->GetUint64(kEXPY, &expiry_seconds) != QUIC_NO_ERROR) { + *error_details = "SCFG missing EXPY"; + return SERVER_CONFIG_INVALID_EXPIRY; + } + expiration_time_ = QuicWallTime::FromUNIXSeconds(expiry_seconds); + } else { + expiration_time_ = expiry_time; + } + + if (now.IsAfter(expiration_time_)) { + *error_details = "SCFG has expired"; + return SERVER_CONFIG_EXPIRED; + } + + if (!matches_existing) { + server_config_ = std::string(server_config); + SetProofInvalid(); + scfg_ = std::move(new_scfg_storage); + } + return SERVER_CONFIG_VALID; +} + +void QuicCryptoClientConfig::CachedState::InvalidateServerConfig() { + server_config_.clear(); + scfg_.reset(); + SetProofInvalid(); +} + +void QuicCryptoClientConfig::CachedState::SetProof( + const std::vector& certs, absl::string_view cert_sct, + absl::string_view chlo_hash, absl::string_view signature) { + bool has_changed = signature != server_config_sig_ || + chlo_hash != chlo_hash_ || certs_.size() != certs.size(); + + if (!has_changed) { + for (size_t i = 0; i < certs_.size(); i++) { + if (certs_[i] != certs[i]) { + has_changed = true; + break; + } + } + } + + if (!has_changed) { + return; + } + + // If the proof has changed then it needs to be revalidated. + SetProofInvalid(); + certs_ = certs; + cert_sct_ = std::string(cert_sct); + chlo_hash_ = std::string(chlo_hash); + server_config_sig_ = std::string(signature); +} + +void QuicCryptoClientConfig::CachedState::Clear() { + server_config_.clear(); + source_address_token_.clear(); + certs_.clear(); + cert_sct_.clear(); + chlo_hash_.clear(); + server_config_sig_.clear(); + server_config_valid_ = false; + proof_verify_details_.reset(); + scfg_.reset(); + ++generation_counter_; +} + +void QuicCryptoClientConfig::CachedState::ClearProof() { + SetProofInvalid(); + certs_.clear(); + cert_sct_.clear(); + chlo_hash_.clear(); + server_config_sig_.clear(); +} + +void QuicCryptoClientConfig::CachedState::SetProofValid() { + server_config_valid_ = true; +} + +void QuicCryptoClientConfig::CachedState::SetProofInvalid() { + server_config_valid_ = false; + ++generation_counter_; +} + +bool QuicCryptoClientConfig::CachedState::Initialize( + absl::string_view server_config, absl::string_view source_address_token, + const std::vector& certs, const std::string& cert_sct, + absl::string_view chlo_hash, absl::string_view signature, QuicWallTime now, + QuicWallTime expiration_time) { + QUICHE_DCHECK(server_config_.empty()); + + if (server_config.empty()) { + RecordDiskCacheServerConfigState(SERVER_CONFIG_EMPTY); + return false; + } + + std::string error_details; + ServerConfigState state = + SetServerConfig(server_config, now, expiration_time, &error_details); + RecordDiskCacheServerConfigState(state); + if (state != SERVER_CONFIG_VALID) { + QUIC_DVLOG(1) << "SetServerConfig failed with " << error_details; + return false; + } + + chlo_hash_.assign(chlo_hash.data(), chlo_hash.size()); + server_config_sig_.assign(signature.data(), signature.size()); + source_address_token_.assign(source_address_token.data(), + source_address_token.size()); + certs_ = certs; + cert_sct_ = cert_sct; + return true; +} + +const std::string& QuicCryptoClientConfig::CachedState::server_config() const { + return server_config_; +} + +const std::string& QuicCryptoClientConfig::CachedState::source_address_token() + const { + return source_address_token_; +} + +const std::vector& QuicCryptoClientConfig::CachedState::certs() + const { + return certs_; +} + +const std::string& QuicCryptoClientConfig::CachedState::cert_sct() const { + return cert_sct_; +} + +const std::string& QuicCryptoClientConfig::CachedState::chlo_hash() const { + return chlo_hash_; +} + +const std::string& QuicCryptoClientConfig::CachedState::signature() const { + return server_config_sig_; +} + +bool QuicCryptoClientConfig::CachedState::proof_valid() const { + return server_config_valid_; +} + +uint64_t QuicCryptoClientConfig::CachedState::generation_counter() const { + return generation_counter_; +} + +const ProofVerifyDetails* +QuicCryptoClientConfig::CachedState::proof_verify_details() const { + return proof_verify_details_.get(); +} + +void QuicCryptoClientConfig::CachedState::set_source_address_token( + absl::string_view token) { + source_address_token_ = std::string(token); +} + +void QuicCryptoClientConfig::CachedState::set_cert_sct( + absl::string_view cert_sct) { + cert_sct_ = std::string(cert_sct); +} + +void QuicCryptoClientConfig::CachedState::SetProofVerifyDetails( + ProofVerifyDetails* details) { + proof_verify_details_.reset(details); +} + +void QuicCryptoClientConfig::CachedState::InitializeFrom( + const QuicCryptoClientConfig::CachedState& other) { + QUICHE_DCHECK(server_config_.empty()); + QUICHE_DCHECK(!server_config_valid_); + server_config_ = other.server_config_; + source_address_token_ = other.source_address_token_; + certs_ = other.certs_; + cert_sct_ = other.cert_sct_; + chlo_hash_ = other.chlo_hash_; + server_config_sig_ = other.server_config_sig_; + server_config_valid_ = other.server_config_valid_; + expiration_time_ = other.expiration_time_; + if (other.proof_verify_details_ != nullptr) { + proof_verify_details_.reset(other.proof_verify_details_->Clone()); + } + ++generation_counter_; +} + +void QuicCryptoClientConfig::SetDefaults() { + // Key exchange methods. + kexs = {kC255, kP256}; + + // Authenticated encryption algorithms. Prefer AES-GCM if hardware-supported + // fast implementation is available. + if (EVP_has_aes_hardware() == 1) { + aead = {kAESG, kCC20}; + } else { + aead = {kCC20, kAESG}; + } +} + +QuicCryptoClientConfig::CachedState* QuicCryptoClientConfig::LookupOrCreate( + const QuicServerId& server_id) { + auto it = cached_states_.find(server_id); + if (it != cached_states_.end()) { + return it->second.get(); + } + + CachedState* cached = new CachedState; + cached_states_.insert(std::make_pair(server_id, absl::WrapUnique(cached))); + bool cache_populated = PopulateFromCanonicalConfig(server_id, cached); + QUIC_CLIENT_HISTOGRAM_BOOL( + "QuicCryptoClientConfig.PopulatedFromCanonicalConfig", cache_populated, + ""); + return cached; +} + +void QuicCryptoClientConfig::ClearCachedStates(const ServerIdFilter& filter) { + for (auto it = cached_states_.begin(); it != cached_states_.end(); ++it) { + if (filter.Matches(it->first)) it->second->Clear(); + } +} + +void QuicCryptoClientConfig::FillInchoateClientHello( + const QuicServerId& server_id, const ParsedQuicVersion preferred_version, + const CachedState* cached, QuicRandom* rand, bool demand_x509_proof, + quiche::QuicheReferenceCountedPointer + out_params, + CryptoHandshakeMessage* out) const { + out->set_tag(kCHLO); + out->set_minimum_size(1); + + // Server name indication. We only send SNI if it's a valid domain name, as + // per the spec. + if (QuicHostnameUtils::IsValidSNI(server_id.host())) { + out->SetStringPiece(kSNI, server_id.host()); + } + out->SetVersion(kVER, preferred_version); + + if (!user_agent_id_.empty()) { + out->SetStringPiece(kUAID, user_agent_id_); + } + + if (!alpn_.empty()) { + out->SetStringPiece(kALPN, alpn_); + } + + // Even though this is an inchoate CHLO, send the SCID so that + // the STK can be validated by the server. + const CryptoHandshakeMessage* scfg = cached->GetServerConfig(); + if (scfg != nullptr) { + absl::string_view scid; + if (scfg->GetStringPiece(kSCID, &scid)) { + out->SetStringPiece(kSCID, scid); + } + } + + if (!cached->source_address_token().empty()) { + out->SetStringPiece(kSourceAddressTokenTag, cached->source_address_token()); + } + + if (!demand_x509_proof) { + return; + } + + char proof_nonce[32]; + rand->RandBytes(proof_nonce, ABSL_ARRAYSIZE(proof_nonce)); + out->SetStringPiece( + kNONP, absl::string_view(proof_nonce, ABSL_ARRAYSIZE(proof_nonce))); + + out->SetVector(kPDMD, QuicTagVector{kX509}); + + out->SetStringPiece(kCertificateSCTTag, ""); + + const std::vector& certs = cached->certs(); + // We save |certs| in the QuicCryptoNegotiatedParameters so that, if the + // client config is being used for multiple connections, another connection + // doesn't update the cached certificates and cause us to be unable to + // process the server's compressed certificate chain. + out_params->cached_certs = certs; + if (!certs.empty()) { + std::vector hashes; + hashes.reserve(certs.size()); + for (auto i = certs.begin(); i != certs.end(); ++i) { + hashes.push_back(QuicUtils::FNV1a_64_Hash(*i)); + } + out->SetVector(kCCRT, hashes); + } +} + +QuicErrorCode QuicCryptoClientConfig::FillClientHello( + const QuicServerId& server_id, QuicConnectionId connection_id, + const ParsedQuicVersion preferred_version, + const ParsedQuicVersion actual_version, const CachedState* cached, + QuicWallTime now, QuicRandom* rand, + quiche::QuicheReferenceCountedPointer + out_params, + CryptoHandshakeMessage* out, std::string* error_details) const { + QUICHE_DCHECK(error_details != nullptr); + QUIC_BUG_IF(quic_bug_12943_2, + !QuicUtils::IsConnectionIdValidForVersion( + connection_id, preferred_version.transport_version)) + << "FillClientHello: attempted to use connection ID " << connection_id + << " which is invalid with version " << preferred_version; + + FillInchoateClientHello(server_id, preferred_version, cached, rand, + /* demand_x509_proof= */ true, out_params, out); + + out->set_minimum_size(1); + + const CryptoHandshakeMessage* scfg = cached->GetServerConfig(); + if (!scfg) { + // This should never happen as our caller should have checked + // cached->IsComplete() before calling this function. + *error_details = "Handshake not ready"; + return QUIC_CRYPTO_INTERNAL_ERROR; + } + + absl::string_view scid; + if (!scfg->GetStringPiece(kSCID, &scid)) { + *error_details = "SCFG missing SCID"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + out->SetStringPiece(kSCID, scid); + + out->SetStringPiece(kCertificateSCTTag, ""); + + QuicTagVector their_aeads; + QuicTagVector their_key_exchanges; + if (scfg->GetTaglist(kAEAD, &their_aeads) != QUIC_NO_ERROR || + scfg->GetTaglist(kKEXS, &their_key_exchanges) != QUIC_NO_ERROR) { + *error_details = "Missing AEAD or KEXS"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + // AEAD: the work loads on the client and server are symmetric. Since the + // client is more likely to be CPU-constrained, break the tie by favoring + // the client's preference. + // Key exchange: the client does more work than the server, so favor the + // client's preference. + size_t key_exchange_index; + if (!FindMutualQuicTag(aead, their_aeads, &out_params->aead, nullptr) || + !FindMutualQuicTag(kexs, their_key_exchanges, &out_params->key_exchange, + &key_exchange_index)) { + *error_details = "Unsupported AEAD or KEXS"; + return QUIC_CRYPTO_NO_SUPPORT; + } + out->SetVector(kAEAD, QuicTagVector{out_params->aead}); + out->SetVector(kKEXS, QuicTagVector{out_params->key_exchange}); + + absl::string_view public_value; + if (scfg->GetNthValue24(kPUBS, key_exchange_index, &public_value) != + QUIC_NO_ERROR) { + *error_details = "Missing public value"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + absl::string_view orbit; + if (!scfg->GetStringPiece(kORBT, &orbit) || orbit.size() != kOrbitSize) { + *error_details = "SCFG missing OBIT"; + return QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } + + CryptoUtils::GenerateNonce(now, rand, orbit, &out_params->client_nonce); + out->SetStringPiece(kNONC, out_params->client_nonce); + if (!out_params->server_nonce.empty()) { + out->SetStringPiece(kServerNonceTag, out_params->server_nonce); + } + + switch (out_params->key_exchange) { + case kC255: + out_params->client_key_exchange = Curve25519KeyExchange::New( + Curve25519KeyExchange::NewPrivateKey(rand)); + break; + case kP256: + out_params->client_key_exchange = + P256KeyExchange::New(P256KeyExchange::NewPrivateKey()); + break; + default: + QUICHE_DCHECK(false); + *error_details = "Configured to support an unknown key exchange"; + return QUIC_CRYPTO_INTERNAL_ERROR; + } + + if (!out_params->client_key_exchange->CalculateSharedKeySync( + public_value, &out_params->initial_premaster_secret)) { + *error_details = "Key exchange failure"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + out->SetStringPiece(kPUBS, out_params->client_key_exchange->public_value()); + + const std::vector& certs = cached->certs(); + if (certs.empty()) { + *error_details = "No certs to calculate XLCT"; + return QUIC_CRYPTO_INTERNAL_ERROR; + } + out->SetValue(kXLCT, CryptoUtils::ComputeLeafCertHash(certs[0])); + + // Derive the symmetric keys and set up the encrypters and decrypters. + // Set the following members of out_params: + // out_params->hkdf_input_suffix + // out_params->initial_crypters + out_params->hkdf_input_suffix.clear(); + out_params->hkdf_input_suffix.append(connection_id.data(), + connection_id.length()); + const QuicData& client_hello_serialized = out->GetSerialized(); + out_params->hkdf_input_suffix.append(client_hello_serialized.data(), + client_hello_serialized.length()); + out_params->hkdf_input_suffix.append(cached->server_config()); + if (certs.empty()) { + *error_details = "No certs found to include in KDF"; + return QUIC_CRYPTO_INTERNAL_ERROR; + } + out_params->hkdf_input_suffix.append(certs[0]); + + std::string hkdf_input; + const size_t label_len = strlen(QuicCryptoConfig::kInitialLabel) + 1; + hkdf_input.reserve(label_len + out_params->hkdf_input_suffix.size()); + hkdf_input.append(QuicCryptoConfig::kInitialLabel, label_len); + hkdf_input.append(out_params->hkdf_input_suffix); + + std::string* subkey_secret = &out_params->initial_subkey_secret; + + if (!CryptoUtils::DeriveKeys( + actual_version, out_params->initial_premaster_secret, + out_params->aead, out_params->client_nonce, out_params->server_nonce, + pre_shared_key_, hkdf_input, Perspective::IS_CLIENT, + CryptoUtils::Diversification::Pending(), + &out_params->initial_crypters, subkey_secret)) { + *error_details = "Symmetric key setup failed"; + return QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED; + } + + return QUIC_NO_ERROR; +} + +QuicErrorCode QuicCryptoClientConfig::CacheNewServerConfig( + const CryptoHandshakeMessage& message, QuicWallTime now, + QuicTransportVersion /*version*/, absl::string_view chlo_hash, + const std::vector& cached_certs, CachedState* cached, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + + absl::string_view scfg; + if (!message.GetStringPiece(kSCFG, &scfg)) { + *error_details = "Missing SCFG"; + return QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } + + QuicWallTime expiration_time = QuicWallTime::Zero(); + uint64_t expiry_seconds; + if (message.GetUint64(kSTTL, &expiry_seconds) == QUIC_NO_ERROR) { + // Only cache configs for a maximum of 1 week. + expiration_time = now.Add(QuicTime::Delta::FromSeconds( + std::min(expiry_seconds, kNumSecondsPerWeek))); + } + + CachedState::ServerConfigState state = + cached->SetServerConfig(scfg, now, expiration_time, error_details); + if (state == CachedState::SERVER_CONFIG_EXPIRED) { + return QUIC_CRYPTO_SERVER_CONFIG_EXPIRED; + } + // TODO(rtenneti): Return more specific error code than returning + // QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER. + if (state != CachedState::SERVER_CONFIG_VALID) { + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + absl::string_view token; + if (message.GetStringPiece(kSourceAddressTokenTag, &token)) { + cached->set_source_address_token(token); + } + + absl::string_view proof, cert_bytes, cert_sct; + bool has_proof = message.GetStringPiece(kPROF, &proof); + bool has_cert = message.GetStringPiece(kCertificateTag, &cert_bytes); + if (has_proof && has_cert) { + std::vector certs; + if (!CertCompressor::DecompressChain(cert_bytes, cached_certs, &certs)) { + *error_details = "Certificate data invalid"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + message.GetStringPiece(kCertificateSCTTag, &cert_sct); + cached->SetProof(certs, cert_sct, chlo_hash, proof); + } else { + // Secure QUIC: clear existing proof as we have been sent a new SCFG + // without matching proof/certs. + cached->ClearProof(); + + if (has_proof && !has_cert) { + *error_details = "Certificate missing"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + if (!has_proof && has_cert) { + *error_details = "Proof missing"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + } + + return QUIC_NO_ERROR; +} + +QuicErrorCode QuicCryptoClientConfig::ProcessRejection( + const CryptoHandshakeMessage& rej, QuicWallTime now, + const QuicTransportVersion version, absl::string_view chlo_hash, + CachedState* cached, + quiche::QuicheReferenceCountedPointer + out_params, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + + if (rej.tag() != kREJ) { + *error_details = "Message is not REJ"; + return QUIC_CRYPTO_INTERNAL_ERROR; + } + + QuicErrorCode error = + CacheNewServerConfig(rej, now, version, chlo_hash, + out_params->cached_certs, cached, error_details); + if (error != QUIC_NO_ERROR) { + return error; + } + + absl::string_view nonce; + if (rej.GetStringPiece(kServerNonceTag, &nonce)) { + out_params->server_nonce = std::string(nonce); + } + + return QUIC_NO_ERROR; +} + +QuicErrorCode QuicCryptoClientConfig::ProcessServerHello( + const CryptoHandshakeMessage& server_hello, + QuicConnectionId /*connection_id*/, ParsedQuicVersion version, + const ParsedQuicVersionVector& negotiated_versions, CachedState* cached, + quiche::QuicheReferenceCountedPointer + out_params, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + + QuicErrorCode valid = CryptoUtils::ValidateServerHello( + server_hello, negotiated_versions, error_details); + if (valid != QUIC_NO_ERROR) { + return valid; + } + + // Learn about updated source address tokens. + absl::string_view token; + if (server_hello.GetStringPiece(kSourceAddressTokenTag, &token)) { + cached->set_source_address_token(token); + } + + absl::string_view shlo_nonce; + if (!server_hello.GetStringPiece(kServerNonceTag, &shlo_nonce)) { + *error_details = "server hello missing server nonce"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + // TODO(agl): + // learn about updated SCFGs. + + absl::string_view public_value; + if (!server_hello.GetStringPiece(kPUBS, &public_value)) { + *error_details = "server hello missing forward secure public value"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + if (!out_params->client_key_exchange->CalculateSharedKeySync( + public_value, &out_params->forward_secure_premaster_secret)) { + *error_details = "Key exchange failure"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + std::string hkdf_input; + const size_t label_len = strlen(QuicCryptoConfig::kForwardSecureLabel) + 1; + hkdf_input.reserve(label_len + out_params->hkdf_input_suffix.size()); + hkdf_input.append(QuicCryptoConfig::kForwardSecureLabel, label_len); + hkdf_input.append(out_params->hkdf_input_suffix); + + if (!CryptoUtils::DeriveKeys( + version, out_params->forward_secure_premaster_secret, + out_params->aead, out_params->client_nonce, + shlo_nonce.empty() ? out_params->server_nonce : shlo_nonce, + pre_shared_key_, hkdf_input, Perspective::IS_CLIENT, + CryptoUtils::Diversification::Never(), + &out_params->forward_secure_crypters, &out_params->subkey_secret)) { + *error_details = "Symmetric key setup failed"; + return QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED; + } + + return QUIC_NO_ERROR; +} + +QuicErrorCode QuicCryptoClientConfig::ProcessServerConfigUpdate( + const CryptoHandshakeMessage& server_config_update, QuicWallTime now, + const QuicTransportVersion version, absl::string_view chlo_hash, + CachedState* cached, + quiche::QuicheReferenceCountedPointer + out_params, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + + if (server_config_update.tag() != kSCUP) { + *error_details = "ServerConfigUpdate must have kSCUP tag."; + return QUIC_INVALID_CRYPTO_MESSAGE_TYPE; + } + return CacheNewServerConfig(server_config_update, now, version, chlo_hash, + out_params->cached_certs, cached, error_details); +} + +ProofVerifier* QuicCryptoClientConfig::proof_verifier() const { + return proof_verifier_.get(); +} + +SessionCache* QuicCryptoClientConfig::session_cache() const { + return session_cache_.get(); +} + +ClientProofSource* QuicCryptoClientConfig::proof_source() const { + return proof_source_.get(); +} + +void QuicCryptoClientConfig::set_proof_source( + std::unique_ptr proof_source) { + proof_source_ = std::move(proof_source); +} + +SSL_CTX* QuicCryptoClientConfig::ssl_ctx() const { return ssl_ctx_.get(); } + +void QuicCryptoClientConfig::InitializeFrom( + const QuicServerId& server_id, const QuicServerId& canonical_server_id, + QuicCryptoClientConfig* canonical_crypto_config) { + CachedState* canonical_cached = + canonical_crypto_config->LookupOrCreate(canonical_server_id); + if (!canonical_cached->proof_valid()) { + return; + } + CachedState* cached = LookupOrCreate(server_id); + cached->InitializeFrom(*canonical_cached); +} + +void QuicCryptoClientConfig::AddCanonicalSuffix(const std::string& suffix) { + canonical_suffixes_.push_back(suffix); +} + +bool QuicCryptoClientConfig::PopulateFromCanonicalConfig( + const QuicServerId& server_id, CachedState* cached) { + QUICHE_DCHECK(cached->IsEmpty()); + size_t i = 0; + for (; i < canonical_suffixes_.size(); ++i) { + if (absl::EndsWithIgnoreCase(server_id.host(), canonical_suffixes_[i])) { + break; + } + } + if (i == canonical_suffixes_.size()) { + return false; + } + + QuicServerId suffix_server_id(canonical_suffixes_[i], server_id.port(), + server_id.privacy_mode_enabled()); + auto it = canonical_server_map_.lower_bound(suffix_server_id); + if (it == canonical_server_map_.end() || it->first != suffix_server_id) { + // This is the first host we've seen which matches the suffix, so make it + // canonical. Use |it| as position hint for faster insertion. + canonical_server_map_.insert( + it, std::make_pair(std::move(suffix_server_id), std::move(server_id))); + return false; + } + + const QuicServerId& canonical_server_id = it->second; + CachedState* canonical_state = cached_states_[canonical_server_id].get(); + if (!canonical_state->proof_valid()) { + return false; + } + + // Update canonical version to point at the "most recent" entry. + it->second = server_id; + + cached->InitializeFrom(*canonical_state); + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_crypto_client_config.h b/quiche/quic/core/crypto/quic_crypto_client_config.h new file mode 100644 index 000000000000..546f4d13b7c4 --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_client_config.h @@ -0,0 +1,467 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_CLIENT_CONFIG_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_CLIENT_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/client_proof_source.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" + +namespace quic { + +class CryptoHandshakeMessage; +class ProofVerifier; +class ProofVerifyDetails; + +// QuicResumptionState stores the state a client needs for performing connection +// resumption. +struct QUIC_EXPORT_PRIVATE QuicResumptionState { + // |tls_session| holds the cryptographic state necessary for a resumption. It + // includes the ALPN negotiated on the connection where the ticket was + // received. + bssl::UniquePtr tls_session; + + // If the application using QUIC doesn't support 0-RTT handshakes or the + // client didn't receive a 0-RTT capable session ticket from the server, + // |transport_params| will be null. Otherwise, it will contain the transport + // parameters received from the server on the original connection. + std::unique_ptr transport_params = nullptr; + + // If |transport_params| is null, then |application_state| is ignored and + // should be empty. |application_state| contains serialized state that the + // client received from the server at the application layer that the client + // needs to remember when performing a 0-RTT handshake. + std::unique_ptr application_state = nullptr; + + // Opaque token received in NEW_TOKEN frame if any. + std::string token; +}; + +// SessionCache is an interface for managing storing and retrieving +// QuicResumptionState structs. +class QUIC_EXPORT_PRIVATE SessionCache { + public: + virtual ~SessionCache() {} + + // Inserts |session|, |params|, and |application_states| into the cache, keyed + // by |server_id|. Insert is first called after all three values are present. + // The ownership of |session| is transferred to the cache, while other two are + // copied. Multiple sessions might need to be inserted for a connection. + // SessionCache implementations should support storing + // multiple entries per server ID. + virtual void Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) = 0; + + // Lookup is called once at the beginning of each TLS handshake to potentially + // provide the saved state both for the TLS handshake and for sending 0-RTT + // data (if supported). Lookup may return a nullptr. Implementations should + // delete cache entries after returning them in Lookup so that session tickets + // are used only once. + virtual std::unique_ptr Lookup( + const QuicServerId& server_id, QuicWallTime now, const SSL_CTX* ctx) = 0; + + // Called when 0-RTT is rejected. Disables early data for all the TLS tickets + // associated with |server_id|. + virtual void ClearEarlyData(const QuicServerId& server_id) = 0; + + // Called when NEW_TOKEN frame is received. + virtual void OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) = 0; + + // Called to remove expired entries. + virtual void RemoveExpiredEntries(QuicWallTime now) = 0; + + // Clear the session cache. + virtual void Clear() = 0; +}; + +// QuicCryptoClientConfig contains crypto-related configuration settings for a +// client. Note that this object isn't thread-safe. It's designed to be used on +// a single thread at a time. +class QUIC_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { + public: + // A CachedState contains the information that the client needs in order to + // perform a 0-RTT handshake with a server. This information can be reused + // over several connections to the same server. + class QUIC_EXPORT_PRIVATE CachedState { + public: + // Enum to track if the server config is valid or not. If it is not valid, + // it specifies why it is invalid. + enum ServerConfigState { + // WARNING: Do not change the numerical values of any of server config + // state. Do not remove deprecated server config states - just comment + // them as deprecated. + SERVER_CONFIG_EMPTY = 0, + SERVER_CONFIG_INVALID = 1, + SERVER_CONFIG_CORRUPTED = 2, + SERVER_CONFIG_EXPIRED = 3, + SERVER_CONFIG_INVALID_EXPIRY = 4, + SERVER_CONFIG_VALID = 5, + // NOTE: Add new server config states only immediately above this line. + // Make sure to update the QuicServerConfigState enum in + // tools/metrics/histograms/histograms.xml accordingly. + SERVER_CONFIG_COUNT + }; + + CachedState(); + CachedState(const CachedState&) = delete; + CachedState& operator=(const CachedState&) = delete; + ~CachedState(); + + // IsComplete returns true if this object contains enough information to + // perform a handshake with the server. |now| is used to judge whether any + // cached server config has expired. + bool IsComplete(QuicWallTime now) const; + + // IsEmpty returns true if |server_config_| is empty. + bool IsEmpty() const; + + // GetServerConfig returns the parsed contents of |server_config|, or + // nullptr if |server_config| is empty. The return value is owned by this + // object and is destroyed when this object is. + const CryptoHandshakeMessage* GetServerConfig() const; + + // SetServerConfig checks that |server_config| parses correctly and stores + // it in |server_config_|. |now| is used to judge whether |server_config| + // has expired. + ServerConfigState SetServerConfig(absl::string_view server_config, + QuicWallTime now, + QuicWallTime expiry_time, + std::string* error_details); + + // InvalidateServerConfig clears the cached server config (if any). + void InvalidateServerConfig(); + + // SetProof stores a cert chain, cert signed timestamp and signature. + void SetProof(const std::vector& certs, + absl::string_view cert_sct, absl::string_view chlo_hash, + absl::string_view signature); + + // Clears all the data. + void Clear(); + + // Clears the certificate chain and signature and invalidates the proof. + void ClearProof(); + + // SetProofValid records that the certificate chain and signature have been + // validated and that it's safe to assume that the server is legitimate. + // (Note: this does not check the chain or signature.) + void SetProofValid(); + + // If the server config or the proof has changed then it needs to be + // revalidated. Helper function to keep server_config_valid_ and + // generation_counter_ in sync. + void SetProofInvalid(); + + const std::string& server_config() const; + const std::string& source_address_token() const; + const std::vector& certs() const; + const std::string& cert_sct() const; + const std::string& chlo_hash() const; + const std::string& signature() const; + bool proof_valid() const; + uint64_t generation_counter() const; + const ProofVerifyDetails* proof_verify_details() const; + + void set_source_address_token(absl::string_view token); + + void set_cert_sct(absl::string_view cert_sct); + + // SetProofVerifyDetails takes ownership of |details|. + void SetProofVerifyDetails(ProofVerifyDetails* details); + + // Copy the |server_config_|, |source_address_token_|, |certs_|, + // |expiration_time_|, |cert_sct_|, |chlo_hash_| and |server_config_sig_| + // from the |other|. The remaining fields, |generation_counter_|, + // |proof_verify_details_|, and |scfg_| remain unchanged. + void InitializeFrom(const CachedState& other); + + // Initializes this cached state based on the arguments provided. + // Returns false if there is a problem parsing the server config. + bool Initialize(absl::string_view server_config, + absl::string_view source_address_token, + const std::vector& certs, + const std::string& cert_sct, absl::string_view chlo_hash, + absl::string_view signature, QuicWallTime now, + QuicWallTime expiration_time); + + private: + std::string server_config_; // A serialized handshake message. + std::string source_address_token_; // An opaque proof of IP ownership. + std::vector certs_; // A list of certificates in leaf-first + // order. + std::string cert_sct_; // Signed timestamp of the leaf cert. + std::string chlo_hash_; // Hash of the CHLO message. + std::string server_config_sig_; // A signature of |server_config_|. + bool server_config_valid_; // True if |server_config_| is correctly + // signed and |certs_| has been validated. + QuicWallTime expiration_time_; // Time when the config is no longer valid. + // Generation counter associated with the |server_config_|, |certs_| and + // |server_config_sig_| combination. It is incremented whenever we set + // server_config_valid_ to false. + uint64_t generation_counter_; + + std::unique_ptr proof_verify_details_; + + // scfg contains the cached, parsed value of |server_config|. + mutable std::unique_ptr scfg_; + }; + + // Used to filter server ids for partial config deletion. + class QUIC_EXPORT_PRIVATE ServerIdFilter { + public: + virtual ~ServerIdFilter() {} + + // Returns true if |server_id| matches the filter. + virtual bool Matches(const QuicServerId& server_id) const = 0; + }; + + // DEPRECATED: Use the constructor below instead. + explicit QuicCryptoClientConfig( + std::unique_ptr proof_verifier); + QuicCryptoClientConfig(std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + QuicCryptoClientConfig(const QuicCryptoClientConfig&) = delete; + QuicCryptoClientConfig& operator=(const QuicCryptoClientConfig&) = delete; + ~QuicCryptoClientConfig(); + + // LookupOrCreate returns a CachedState for the given |server_id|. If no such + // CachedState currently exists, it will be created and cached. + CachedState* LookupOrCreate(const QuicServerId& server_id); + + // Delete CachedState objects whose server ids match |filter| from + // cached_states. + void ClearCachedStates(const ServerIdFilter& filter); + + // FillInchoateClientHello sets |out| to be a CHLO message that elicits a + // source-address token or SCFG from a server. If |cached| is non-nullptr, the + // source-address token will be taken from it. |out_params| is used in order + // to store the cached certs that were sent as hints to the server in + // |out_params->cached_certs|. |preferred_version| is the version of the + // QUIC protocol that this client chose to use initially. This allows the + // server to detect downgrade attacks. If |demand_x509_proof| is true, + // then |out| will include an X509 proof demand, and the associated + // certificate related fields. + void FillInchoateClientHello( + const QuicServerId& server_id, const ParsedQuicVersion preferred_version, + const CachedState* cached, QuicRandom* rand, bool demand_x509_proof, + quiche::QuicheReferenceCountedPointer + out_params, + CryptoHandshakeMessage* out) const; + + // FillClientHello sets |out| to be a CHLO message based on the configuration + // of this object. This object must have cached enough information about + // the server's hostname in order to perform a handshake. This can be checked + // with the |IsComplete| member of |CachedState|. + // + // |now| and |rand| are used to generate the nonce and |out_params| is + // filled with the results of the handshake that the server is expected to + // accept. |preferred_version| is the version of the QUIC protocol that this + // client chose to use initially. This allows the server to detect downgrade + // attacks. + // + // If |channel_id_key| is not null, it is used to sign a secret value derived + // from the client and server's keys, and the Channel ID public key and the + // signature are placed in the CETV value of the CHLO. + QuicErrorCode FillClientHello( + const QuicServerId& server_id, QuicConnectionId connection_id, + const ParsedQuicVersion preferred_version, + const ParsedQuicVersion actual_version, const CachedState* cached, + QuicWallTime now, QuicRandom* rand, + quiche::QuicheReferenceCountedPointer + out_params, + CryptoHandshakeMessage* out, std::string* error_details) const; + + // ProcessRejection processes a REJ message from a server and updates the + // cached information about that server. After this, |IsComplete| may return + // true for that server's CachedState. If the rejection message contains state + // about a future handshake (i.e. an nonce value from the server), then it + // will be saved in |out_params|. |now| is used to judge whether the server + // config in the rejection message has expired. + QuicErrorCode ProcessRejection( + const CryptoHandshakeMessage& rej, QuicWallTime now, + QuicTransportVersion version, absl::string_view chlo_hash, + CachedState* cached, + quiche::QuicheReferenceCountedPointer + out_params, + std::string* error_details); + + // ProcessServerHello processes the message in |server_hello|, updates the + // cached information about that server, writes the negotiated parameters to + // |out_params| and returns QUIC_NO_ERROR. If |server_hello| is unacceptable + // then it puts an error message in |error_details| and returns an error + // code. |version| is the QUIC version for the current connection. + // |negotiated_versions| contains the list of version, if any, that were + // present in a version negotiation packet previously received from the + // server. The contents of this list will be compared against the list of + // versions provided in the VER tag of the server hello. + QuicErrorCode ProcessServerHello( + const CryptoHandshakeMessage& server_hello, + QuicConnectionId connection_id, ParsedQuicVersion version, + const ParsedQuicVersionVector& negotiated_versions, CachedState* cached, + quiche::QuicheReferenceCountedPointer + out_params, + std::string* error_details); + + // Processes the message in |server_config_update|, updating the cached source + // address token, and server config. + // If |server_config_update| is invalid then |error_details| will contain an + // error message, and an error code will be returned. If all has gone well + // QUIC_NO_ERROR is returned. + QuicErrorCode ProcessServerConfigUpdate( + const CryptoHandshakeMessage& server_config_update, QuicWallTime now, + const QuicTransportVersion version, absl::string_view chlo_hash, + CachedState* cached, + quiche::QuicheReferenceCountedPointer + out_params, + std::string* error_details); + + ProofVerifier* proof_verifier() const; + SessionCache* session_cache() const; + ClientProofSource* proof_source() const; + void set_proof_source(std::unique_ptr proof_source); + SSL_CTX* ssl_ctx() const; + + // Initialize the CachedState from |canonical_crypto_config| for the + // |canonical_server_id| as the initial CachedState for |server_id|. We will + // copy config data only if |canonical_crypto_config| has valid proof. + void InitializeFrom(const QuicServerId& server_id, + const QuicServerId& canonical_server_id, + QuicCryptoClientConfig* canonical_crypto_config); + + // Adds |suffix| as a domain suffix for which the server's crypto config + // is expected to be shared among servers with the domain suffix. If a server + // matches this suffix, then the server config from another server with the + // suffix will be used to initialize the cached state for this server. + void AddCanonicalSuffix(const std::string& suffix); + + // Saves the |user_agent_id| that will be passed in QUIC's CHLO message. + void set_user_agent_id(const std::string& user_agent_id) { + user_agent_id_ = user_agent_id; + } + + // Returns the user_agent_id that will be provided in the client hello + // handshake message. + const std::string& user_agent_id() const { return user_agent_id_; } + + void set_tls_signature_algorithms(std::string signature_algorithms) { + tls_signature_algorithms_ = std::move(signature_algorithms); + } + + const absl::optional& tls_signature_algorithms() const { + return tls_signature_algorithms_; + } + + // Saves the |alpn| that will be passed in QUIC's CHLO message. + void set_alpn(const std::string& alpn) { alpn_ = alpn; } + + // Saves the pre-shared key used during the handshake. + void set_pre_shared_key(absl::string_view psk) { + pre_shared_key_ = std::string(psk); + } + + // Returns the pre-shared key used during the handshake. + const std::string& pre_shared_key() const { return pre_shared_key_; } + + bool pad_inchoate_hello() const { return pad_inchoate_hello_; } + void set_pad_inchoate_hello(bool new_value) { + pad_inchoate_hello_ = new_value; + } + + bool pad_full_hello() const { return pad_full_hello_; } + void set_pad_full_hello(bool new_value) { pad_full_hello_ = new_value; } + + SessionCache* mutable_session_cache() { return session_cache_.get(); } + + private: + // Sets the members to reasonable, default values. + void SetDefaults(); + + // CacheNewServerConfig checks for SCFG, STK, PROF, and CRT tags in |message|, + // verifies them, and stores them in the cached state if they validate. + // This is used on receipt of a REJ from a server, or when a server sends + // updated server config during a connection. + QuicErrorCode CacheNewServerConfig( + const CryptoHandshakeMessage& message, QuicWallTime now, + QuicTransportVersion version, absl::string_view chlo_hash, + const std::vector& cached_certs, CachedState* cached, + std::string* error_details); + + // If the suffix of the hostname in |server_id| is in |canonical_suffixes_|, + // then populate |cached| with the canonical cached state from + // |canonical_server_map_| for that suffix. Returns true if |cached| is + // initialized with canonical cached state. + bool PopulateFromCanonicalConfig(const QuicServerId& server_id, + CachedState* cached); + + // cached_states_ maps from the server_id to the cached information about + // that server. + std::map> cached_states_; + + // Contains a map of servers which could share the same server config. Map + // from a canonical host suffix/port/scheme to a representative server with + // the canonical suffix, which has a plausible set of initial certificates + // (or at least server public key). + std::map canonical_server_map_; + + // Contains list of suffixes (for exmaple ".c.youtube.com", + // ".googlevideo.com") of canonical hostnames. + std::vector canonical_suffixes_; + + std::unique_ptr proof_verifier_; + std::unique_ptr session_cache_; + std::unique_ptr proof_source_; + + bssl::UniquePtr ssl_ctx_; + + // The |user_agent_id_| passed in QUIC's CHLO message. + std::string user_agent_id_; + + // The |alpn_| passed in QUIC's CHLO message. + std::string alpn_; + + // If non-empty, the client will operate in the pre-shared key mode by + // incorporating |pre_shared_key_| into the key schedule. + std::string pre_shared_key_; + + // If set, configure the client to use the specified signature algorithms, via + // SSL_set1_sigalgs_list. TLS only. + absl::optional tls_signature_algorithms_; + + // In QUIC, technically, client hello should be fully padded. + // However, fully padding on slow network connection (e.g. 50kbps) can add + // 150ms latency to one roundtrip. Therefore, you can disable padding of + // individual messages. It is recommend to leave at least one message in + // each direction fully padded (e.g. full CHLO and SHLO), but if you know + // the lower-bound MTU, you don't need to pad all of them (keep in mind that + // it's not OK to do it according to the standard). + // + // Also, if you disable padding, you must disable (change) the + // anti-amplification protection. You should only do so if you have some + // other means of verifying the client. + bool pad_inchoate_hello_ = true; + bool pad_full_hello_ = true; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_CLIENT_CONFIG_H_ diff --git a/quiche/quic/core/crypto/quic_crypto_client_config_test.cc b/quiche/quic/core/crypto/quic_crypto_client_config_test.cc new file mode 100644 index 000000000000..7556592f1c66 --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_client_config_test.cc @@ -0,0 +1,550 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::StartsWith; + +namespace quic { +namespace test { +namespace { + +class TestProofVerifyDetails : public ProofVerifyDetails { + ~TestProofVerifyDetails() override {} + + // ProofVerifyDetails implementation + ProofVerifyDetails* Clone() const override { + return new TestProofVerifyDetails; + } +}; + +class OneServerIdFilter : public QuicCryptoClientConfig::ServerIdFilter { + public: + explicit OneServerIdFilter(const QuicServerId* server_id) + : server_id_(*server_id) {} + + bool Matches(const QuicServerId& server_id) const override { + return server_id == server_id_; + } + + private: + const QuicServerId server_id_; +}; + +class AllServerIdsFilter : public QuicCryptoClientConfig::ServerIdFilter { + public: + bool Matches(const QuicServerId& /*server_id*/) const override { + return true; + } +}; + +} // namespace + +class QuicCryptoClientConfigTest : public QuicTest {}; + +TEST_F(QuicCryptoClientConfigTest, CachedState_IsEmpty) { + QuicCryptoClientConfig::CachedState state; + EXPECT_TRUE(state.IsEmpty()); +} + +TEST_F(QuicCryptoClientConfigTest, CachedState_IsComplete) { + QuicCryptoClientConfig::CachedState state; + EXPECT_FALSE(state.IsComplete(QuicWallTime::FromUNIXSeconds(0))); +} + +TEST_F(QuicCryptoClientConfigTest, CachedState_GenerationCounter) { + QuicCryptoClientConfig::CachedState state; + EXPECT_EQ(0u, state.generation_counter()); + state.SetProofInvalid(); + EXPECT_EQ(1u, state.generation_counter()); +} + +TEST_F(QuicCryptoClientConfigTest, CachedState_SetProofVerifyDetails) { + QuicCryptoClientConfig::CachedState state; + EXPECT_TRUE(state.proof_verify_details() == nullptr); + ProofVerifyDetails* details = new TestProofVerifyDetails; + state.SetProofVerifyDetails(details); + EXPECT_EQ(details, state.proof_verify_details()); +} + +TEST_F(QuicCryptoClientConfigTest, CachedState_InitializeFrom) { + QuicCryptoClientConfig::CachedState state; + QuicCryptoClientConfig::CachedState other; + state.set_source_address_token("TOKEN"); + // TODO(rch): Populate other fields of |state|. + other.InitializeFrom(state); + EXPECT_EQ(state.server_config(), other.server_config()); + EXPECT_EQ(state.source_address_token(), other.source_address_token()); + EXPECT_EQ(state.certs(), other.certs()); + EXPECT_EQ(1u, other.generation_counter()); +} + +TEST_F(QuicCryptoClientConfigTest, InchoateChlo) { + QuicCryptoClientConfig::CachedState state; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.set_user_agent_id("quic-tester"); + config.set_alpn("hq"); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + CryptoHandshakeMessage msg; + QuicServerId server_id("www.google.com", 443, false); + MockRandom rand; + config.FillInchoateClientHello(server_id, QuicVersionMax(), &state, &rand, + /* demand_x509_proof= */ true, params, &msg); + + QuicVersionLabel cver; + EXPECT_THAT(msg.GetVersionLabel(kVER, &cver), IsQuicNoError()); + EXPECT_EQ(CreateQuicVersionLabel(QuicVersionMax()), cver); + absl::string_view proof_nonce; + EXPECT_TRUE(msg.GetStringPiece(kNONP, &proof_nonce)); + EXPECT_EQ(std::string(32, 'r'), proof_nonce); + absl::string_view user_agent_id; + EXPECT_TRUE(msg.GetStringPiece(kUAID, &user_agent_id)); + EXPECT_EQ("quic-tester", user_agent_id); + absl::string_view alpn; + EXPECT_TRUE(msg.GetStringPiece(kALPN, &alpn)); + EXPECT_EQ("hq", alpn); + EXPECT_EQ(msg.minimum_size(), 1u); +} + +TEST_F(QuicCryptoClientConfigTest, InchoateChloIsNotPadded) { + QuicCryptoClientConfig::CachedState state; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.set_pad_inchoate_hello(false); + config.set_user_agent_id("quic-tester"); + config.set_alpn("hq"); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + CryptoHandshakeMessage msg; + QuicServerId server_id("www.google.com", 443, false); + MockRandom rand; + config.FillInchoateClientHello(server_id, QuicVersionMax(), &state, &rand, + /* demand_x509_proof= */ true, params, &msg); + + EXPECT_EQ(msg.minimum_size(), 1u); +} + +// Make sure AES-GCM is the preferred encryption algorithm if it has hardware +// acceleration. +TEST_F(QuicCryptoClientConfigTest, PreferAesGcm) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + if (EVP_has_aes_hardware() == 1) { + EXPECT_EQ(kAESG, config.aead[0]); + } else { + EXPECT_EQ(kCC20, config.aead[0]); + } +} + +TEST_F(QuicCryptoClientConfigTest, InchoateChloSecure) { + QuicCryptoClientConfig::CachedState state; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + CryptoHandshakeMessage msg; + QuicServerId server_id("www.google.com", 443, false); + MockRandom rand; + config.FillInchoateClientHello(server_id, QuicVersionMax(), &state, &rand, + /* demand_x509_proof= */ true, params, &msg); + + QuicTag pdmd; + EXPECT_THAT(msg.GetUint32(kPDMD, &pdmd), IsQuicNoError()); + EXPECT_EQ(kX509, pdmd); + absl::string_view scid; + EXPECT_FALSE(msg.GetStringPiece(kSCID, &scid)); +} + +TEST_F(QuicCryptoClientConfigTest, InchoateChloSecureWithSCIDNoEXPY) { + // Test that a config with no EXPY is still valid when a non-zero + // expiry time is passed in. + QuicCryptoClientConfig::CachedState state; + CryptoHandshakeMessage scfg; + scfg.set_tag(kSCFG); + scfg.SetStringPiece(kSCID, "12345678"); + std::string details; + QuicWallTime now = QuicWallTime::FromUNIXSeconds(1); + QuicWallTime expiry = QuicWallTime::FromUNIXSeconds(2); + state.SetServerConfig(scfg.GetSerialized().AsStringPiece(), now, expiry, + &details); + EXPECT_FALSE(state.IsEmpty()); + + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + CryptoHandshakeMessage msg; + QuicServerId server_id("www.google.com", 443, false); + MockRandom rand; + config.FillInchoateClientHello(server_id, QuicVersionMax(), &state, &rand, + /* demand_x509_proof= */ true, params, &msg); + + absl::string_view scid; + EXPECT_TRUE(msg.GetStringPiece(kSCID, &scid)); + EXPECT_EQ("12345678", scid); +} + +TEST_F(QuicCryptoClientConfigTest, InchoateChloSecureWithSCID) { + QuicCryptoClientConfig::CachedState state; + CryptoHandshakeMessage scfg; + scfg.set_tag(kSCFG); + uint64_t future = 1; + scfg.SetValue(kEXPY, future); + scfg.SetStringPiece(kSCID, "12345678"); + std::string details; + state.SetServerConfig(scfg.GetSerialized().AsStringPiece(), + QuicWallTime::FromUNIXSeconds(1), + QuicWallTime::FromUNIXSeconds(0), &details); + EXPECT_FALSE(state.IsEmpty()); + + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + CryptoHandshakeMessage msg; + QuicServerId server_id("www.google.com", 443, false); + MockRandom rand; + config.FillInchoateClientHello(server_id, QuicVersionMax(), &state, &rand, + /* demand_x509_proof= */ true, params, &msg); + + absl::string_view scid; + EXPECT_TRUE(msg.GetStringPiece(kSCID, &scid)); + EXPECT_EQ("12345678", scid); +} + +TEST_F(QuicCryptoClientConfigTest, FillClientHello) { + QuicCryptoClientConfig::CachedState state; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + QuicConnectionId kConnectionId = TestConnectionId(1234); + std::string error_details; + MockRandom rand; + CryptoHandshakeMessage chlo; + QuicServerId server_id("www.google.com", 443, false); + config.FillClientHello(server_id, kConnectionId, QuicVersionMax(), + QuicVersionMax(), &state, QuicWallTime::Zero(), &rand, + params, &chlo, &error_details); + + // Verify that the version label has been set correctly in the CHLO. + QuicVersionLabel cver; + EXPECT_THAT(chlo.GetVersionLabel(kVER, &cver), IsQuicNoError()); + EXPECT_EQ(CreateQuicVersionLabel(QuicVersionMax()), cver); +} + +TEST_F(QuicCryptoClientConfigTest, FillClientHelloNoPadding) { + QuicCryptoClientConfig::CachedState state; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.set_pad_full_hello(false); + quiche::QuicheReferenceCountedPointer params( + new QuicCryptoNegotiatedParameters); + QuicConnectionId kConnectionId = TestConnectionId(1234); + std::string error_details; + MockRandom rand; + CryptoHandshakeMessage chlo; + QuicServerId server_id("www.google.com", 443, false); + config.FillClientHello(server_id, kConnectionId, QuicVersionMax(), + QuicVersionMax(), &state, QuicWallTime::Zero(), &rand, + params, &chlo, &error_details); + + // Verify that the version label has been set correctly in the CHLO. + QuicVersionLabel cver; + EXPECT_THAT(chlo.GetVersionLabel(kVER, &cver), IsQuicNoError()); + EXPECT_EQ(CreateQuicVersionLabel(QuicVersionMax()), cver); + EXPECT_EQ(chlo.minimum_size(), 1u); +} + +TEST_F(QuicCryptoClientConfigTest, ProcessServerDowngradeAttack) { + ParsedQuicVersionVector supported_versions = AllSupportedVersions(); + if (supported_versions.size() == 1) { + // No downgrade attack is possible if the client only supports one version. + return; + } + + ParsedQuicVersionVector supported_version_vector; + for (size_t i = supported_versions.size(); i > 0; --i) { + supported_version_vector.push_back(supported_versions[i - 1]); + } + + CryptoHandshakeMessage msg; + msg.set_tag(kSHLO); + msg.SetVersionVector(kVER, supported_version_vector); + + QuicCryptoClientConfig::CachedState cached; + quiche::QuicheReferenceCountedPointer + out_params(new QuicCryptoNegotiatedParameters); + std::string error; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + EXPECT_THAT(config.ProcessServerHello( + msg, EmptyQuicConnectionId(), supported_versions.front(), + supported_versions, &cached, out_params, &error), + IsError(QUIC_VERSION_NEGOTIATION_MISMATCH)); + EXPECT_THAT(error, StartsWith("Downgrade attack detected: ServerVersions")); +} + +TEST_F(QuicCryptoClientConfigTest, InitializeFrom) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + QuicServerId canonical_server_id("www.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state = + config.LookupOrCreate(canonical_server_id); + // TODO(rch): Populate other fields of |state|. + state->set_source_address_token("TOKEN"); + state->SetProofValid(); + + QuicServerId other_server_id("mail.google.com", 443, false); + config.InitializeFrom(other_server_id, canonical_server_id, &config); + QuicCryptoClientConfig::CachedState* other = + config.LookupOrCreate(other_server_id); + + EXPECT_EQ(state->server_config(), other->server_config()); + EXPECT_EQ(state->source_address_token(), other->source_address_token()); + EXPECT_EQ(state->certs(), other->certs()); + EXPECT_EQ(1u, other->generation_counter()); +} + +TEST_F(QuicCryptoClientConfigTest, Canonical) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.AddCanonicalSuffix(".google.com"); + QuicServerId canonical_id1("www.google.com", 443, false); + QuicServerId canonical_id2("mail.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state = + config.LookupOrCreate(canonical_id1); + // TODO(rch): Populate other fields of |state|. + state->set_source_address_token("TOKEN"); + state->SetProofValid(); + + QuicCryptoClientConfig::CachedState* other = + config.LookupOrCreate(canonical_id2); + + EXPECT_TRUE(state->IsEmpty()); + EXPECT_EQ(state->server_config(), other->server_config()); + EXPECT_EQ(state->source_address_token(), other->source_address_token()); + EXPECT_EQ(state->certs(), other->certs()); + EXPECT_EQ(1u, other->generation_counter()); + + QuicServerId different_id("mail.google.org", 443, false); + EXPECT_TRUE(config.LookupOrCreate(different_id)->IsEmpty()); +} + +TEST_F(QuicCryptoClientConfigTest, CanonicalNotUsedIfNotValid) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.AddCanonicalSuffix(".google.com"); + QuicServerId canonical_id1("www.google.com", 443, false); + QuicServerId canonical_id2("mail.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state = + config.LookupOrCreate(canonical_id1); + // TODO(rch): Populate other fields of |state|. + state->set_source_address_token("TOKEN"); + + // Do not set the proof as valid, and check that it is not used + // as a canonical entry. + EXPECT_TRUE(config.LookupOrCreate(canonical_id2)->IsEmpty()); +} + +TEST_F(QuicCryptoClientConfigTest, ClearCachedStates) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + + // Create two states on different origins. + struct TestCase { + TestCase(const std::string& host, QuicCryptoClientConfig* config) + : server_id(host, 443, false), + state(config->LookupOrCreate(server_id)) { + // TODO(rch): Populate other fields of |state|. + CryptoHandshakeMessage scfg; + scfg.set_tag(kSCFG); + uint64_t future = 1; + scfg.SetValue(kEXPY, future); + scfg.SetStringPiece(kSCID, "12345678"); + std::string details; + state->SetServerConfig(scfg.GetSerialized().AsStringPiece(), + QuicWallTime::FromUNIXSeconds(0), + QuicWallTime::FromUNIXSeconds(future), &details); + + std::vector certs(1); + certs[0] = "Hello Cert for " + host; + state->SetProof(certs, "cert_sct", "chlo_hash", "signature"); + state->set_source_address_token("TOKEN"); + state->SetProofValid(); + + // The generation counter starts at 2, because proof has been once + // invalidated in SetServerConfig(). + EXPECT_EQ(2u, state->generation_counter()); + } + + QuicServerId server_id; + QuicCryptoClientConfig::CachedState* state; + } test_cases[] = {TestCase("www.google.com", &config), + TestCase("www.example.com", &config)}; + + // Verify LookupOrCreate returns the same data. + for (const TestCase& test_case : test_cases) { + QuicCryptoClientConfig::CachedState* other = + config.LookupOrCreate(test_case.server_id); + EXPECT_EQ(test_case.state, other); + EXPECT_EQ(2u, other->generation_counter()); + } + + // Clear the cached state for www.google.com. + OneServerIdFilter google_com_filter(&test_cases[0].server_id); + config.ClearCachedStates(google_com_filter); + + // Verify LookupOrCreate doesn't have any data for google.com. + QuicCryptoClientConfig::CachedState* cleared_cache = + config.LookupOrCreate(test_cases[0].server_id); + + EXPECT_EQ(test_cases[0].state, cleared_cache); + EXPECT_FALSE(cleared_cache->proof_valid()); + EXPECT_TRUE(cleared_cache->server_config().empty()); + EXPECT_TRUE(cleared_cache->certs().empty()); + EXPECT_TRUE(cleared_cache->cert_sct().empty()); + EXPECT_TRUE(cleared_cache->signature().empty()); + EXPECT_EQ(3u, cleared_cache->generation_counter()); + + // But it still does for www.example.com. + QuicCryptoClientConfig::CachedState* existing_cache = + config.LookupOrCreate(test_cases[1].server_id); + + EXPECT_EQ(test_cases[1].state, existing_cache); + EXPECT_TRUE(existing_cache->proof_valid()); + EXPECT_FALSE(existing_cache->server_config().empty()); + EXPECT_FALSE(existing_cache->certs().empty()); + EXPECT_FALSE(existing_cache->cert_sct().empty()); + EXPECT_FALSE(existing_cache->signature().empty()); + EXPECT_EQ(2u, existing_cache->generation_counter()); + + // Clear all cached states. + AllServerIdsFilter all_server_ids; + config.ClearCachedStates(all_server_ids); + + // The data for www.example.com should now be cleared as well. + cleared_cache = config.LookupOrCreate(test_cases[1].server_id); + + EXPECT_EQ(test_cases[1].state, cleared_cache); + EXPECT_FALSE(cleared_cache->proof_valid()); + EXPECT_TRUE(cleared_cache->server_config().empty()); + EXPECT_TRUE(cleared_cache->certs().empty()); + EXPECT_TRUE(cleared_cache->cert_sct().empty()); + EXPECT_TRUE(cleared_cache->signature().empty()); + EXPECT_EQ(3u, cleared_cache->generation_counter()); +} + +TEST_F(QuicCryptoClientConfigTest, ProcessReject) { + CryptoHandshakeMessage rej; + crypto_test_utils::FillInDummyReject(&rej); + + // Now process the rejection. + QuicCryptoClientConfig::CachedState cached; + quiche::QuicheReferenceCountedPointer + out_params(new QuicCryptoNegotiatedParameters); + std::string error; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + EXPECT_THAT( + config.ProcessRejection( + rej, QuicWallTime::FromUNIXSeconds(0), + AllSupportedVersionsWithQuicCrypto().front().transport_version, "", + &cached, out_params, &error), + IsQuicNoError()); +} + +TEST_F(QuicCryptoClientConfigTest, ProcessRejectWithLongTTL) { + CryptoHandshakeMessage rej; + crypto_test_utils::FillInDummyReject(&rej); + QuicTime::Delta one_week = QuicTime::Delta::FromSeconds(kNumSecondsPerWeek); + int64_t long_ttl = 3 * one_week.ToSeconds(); + rej.SetValue(kSTTL, long_ttl); + + // Now process the rejection. + QuicCryptoClientConfig::CachedState cached; + quiche::QuicheReferenceCountedPointer + out_params(new QuicCryptoNegotiatedParameters); + std::string error; + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + EXPECT_THAT( + config.ProcessRejection( + rej, QuicWallTime::FromUNIXSeconds(0), + AllSupportedVersionsWithQuicCrypto().front().transport_version, "", + &cached, out_params, &error), + IsQuicNoError()); + cached.SetProofValid(); + EXPECT_FALSE(cached.IsComplete(QuicWallTime::FromUNIXSeconds(long_ttl))); + EXPECT_FALSE( + cached.IsComplete(QuicWallTime::FromUNIXSeconds(one_week.ToSeconds()))); + EXPECT_TRUE(cached.IsComplete( + QuicWallTime::FromUNIXSeconds(one_week.ToSeconds() - 1))); +} + +TEST_F(QuicCryptoClientConfigTest, ServerNonceinSHLO) { + // Test that the server must include a nonce in the SHLO. + CryptoHandshakeMessage msg; + msg.set_tag(kSHLO); + // Choose the latest version. + ParsedQuicVersionVector supported_versions; + ParsedQuicVersion version = AllSupportedVersions().front(); + supported_versions.push_back(version); + msg.SetVersionVector(kVER, supported_versions); + + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + QuicCryptoClientConfig::CachedState cached; + quiche::QuicheReferenceCountedPointer + out_params(new QuicCryptoNegotiatedParameters); + std::string error_details; + EXPECT_THAT(config.ProcessServerHello(msg, EmptyQuicConnectionId(), version, + supported_versions, &cached, out_params, + &error_details), + IsError(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER)); + EXPECT_EQ("server hello missing server nonce", error_details); +} + +// Test that PopulateFromCanonicalConfig() handles the case of multiple entries +// in |canonical_server_map_|. +TEST_F(QuicCryptoClientConfigTest, MultipleCanonicalEntries) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.AddCanonicalSuffix(".google.com"); + QuicServerId canonical_server_id1("www.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state1 = + config.LookupOrCreate(canonical_server_id1); + + CryptoHandshakeMessage scfg; + scfg.set_tag(kSCFG); + scfg.SetStringPiece(kSCID, "12345678"); + std::string details; + QuicWallTime now = QuicWallTime::FromUNIXSeconds(1); + QuicWallTime expiry = QuicWallTime::FromUNIXSeconds(2); + state1->SetServerConfig(scfg.GetSerialized().AsStringPiece(), now, expiry, + &details); + state1->set_source_address_token("TOKEN"); + state1->SetProofValid(); + EXPECT_FALSE(state1->IsEmpty()); + + // This will have the same |suffix_server_id| as |canonical_server_id1|, + // therefore |*state2| will be initialized from |*state1|. + QuicServerId canonical_server_id2("mail.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state2 = + config.LookupOrCreate(canonical_server_id2); + EXPECT_FALSE(state2->IsEmpty()); + const CryptoHandshakeMessage* const scfg2 = state2->GetServerConfig(); + ASSERT_TRUE(scfg2); + EXPECT_EQ(kSCFG, scfg2->tag()); + + // With a different |suffix_server_id|, this will return an empty CachedState. + config.AddCanonicalSuffix(".example.com"); + QuicServerId canonical_server_id3("www.example.com", 443, false); + QuicCryptoClientConfig::CachedState* state3 = + config.LookupOrCreate(canonical_server_id3); + EXPECT_TRUE(state3->IsEmpty()); + const CryptoHandshakeMessage* const scfg3 = state3->GetServerConfig(); + EXPECT_FALSE(scfg3); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_crypto_proof.cc b/quiche/quic/core/crypto/quic_crypto_proof.cc new file mode 100644 index 000000000000..a449c267edb3 --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_proof.cc @@ -0,0 +1,12 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_crypto_proof.h" + +namespace quic { + +QuicCryptoProof::QuicCryptoProof() + : send_expect_ct_header(false), cert_matched_sni(false) {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_crypto_proof.h b/quiche/quic/core/crypto/quic_crypto_proof.h new file mode 100644 index 000000000000..579d2640c81a --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_proof.h @@ -0,0 +1,32 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_PROOF_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_PROOF_H_ + +#include + +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Contains the crypto-related data provided by ProofSource +struct QUIC_EXPORT_PRIVATE QuicCryptoProof { + QuicCryptoProof(); + + // Signature generated by ProofSource + std::string signature; + // SCTList (RFC6962) to be sent to the client, if it supports receiving it. + std::string leaf_cert_scts; + // Should the Expect-CT header be sent on the connection where the + // certificate is used. + bool send_expect_ct_header; + // Did the selected leaf certificate contain a SubjectAltName that included + // the requested SNI. + bool cert_matched_sni; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_PROOF_H_ diff --git a/quiche/quic/core/crypto/quic_crypto_server_config.cc b/quiche/quic/core/crypto/quic_crypto_server_config.cc new file mode 100644 index 000000000000..17cc94f0abef --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_server_config.cc @@ -0,0 +1,1896 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/sha.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" +#include "quiche/quic/core/crypto/cert_compressor.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h" +#include "quiche/quic/core/crypto/channel_id.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/curve25519_key_exchange.h" +#include "quiche/quic/core/crypto/key_exchange.h" +#include "quiche/quic/core/crypto/p256_key_exchange.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_hkdf.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/crypto/tls_server_connection.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/core/proto/source_address_token_proto.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_socket_address_coder.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_hostname_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_testvalue.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" + +namespace quic { + +namespace { + +// kMultiplier is the multiple of the CHLO message size that a REJ message +// must stay under when the client doesn't present a valid source-address +// token. This is used to protect QUIC from amplification attacks. +// TODO(rch): Reduce this to 2 again once b/25933682 is fixed. +const size_t kMultiplier = 3; + +const int kMaxTokenAddresses = 4; + +std::string DeriveSourceAddressTokenKey( + absl::string_view source_address_token_secret) { + QuicHKDF hkdf(source_address_token_secret, absl::string_view() /* no salt */, + "QUIC source address token key", + CryptoSecretBoxer::GetKeySize(), 0 /* no fixed IV needed */, + 0 /* no subkey secret */); + return std::string(hkdf.server_write_key()); +} + +// Default source for creating KeyExchange objects. +class DefaultKeyExchangeSource : public KeyExchangeSource { + public: + DefaultKeyExchangeSource() = default; + ~DefaultKeyExchangeSource() override = default; + + std::unique_ptr Create( + std::string /*server_config_id*/, bool /* is_fallback */, QuicTag type, + absl::string_view private_key) override { + if (private_key.empty()) { + QUIC_LOG(WARNING) << "Server config contains key exchange method without " + "corresponding private key of type " + << QuicTagToString(type); + return nullptr; + } + + std::unique_ptr ka = + CreateLocalSynchronousKeyExchange(type, private_key); + if (!ka) { + QUIC_LOG(WARNING) << "Failed to create key exchange method of type " + << QuicTagToString(type); + } + return ka; + } +}; + +// Returns true if the PDMD field from the client hello demands an X509 +// certificate. +bool ClientDemandsX509Proof(const CryptoHandshakeMessage& client_hello) { + QuicTagVector their_proof_demands; + + if (client_hello.GetTaglist(kPDMD, &their_proof_demands) != QUIC_NO_ERROR) { + return false; + } + + for (const QuicTag tag : their_proof_demands) { + if (tag == kX509) { + return true; + } + } + return false; +} + +std::string FormatCryptoHandshakeMessageForTrace( + const CryptoHandshakeMessage* message) { + if (message == nullptr) { + return ""; + } + + std::string s = QuicTagToString(message->tag()); + + // Append the reasons for REJ. + if (const auto it = message->tag_value_map().find(kRREJ); + it != message->tag_value_map().end()) { + const std::string& value = it->second; + // The value is a vector of uint32_t(s). + if (value.size() % sizeof(uint32_t) == 0) { + absl::StrAppend(&s, " RREJ:["); + // Append comma-separated list of reasons to |s|. + for (size_t j = 0; j < value.size(); j += sizeof(uint32_t)) { + uint32_t reason; + memcpy(&reason, value.data() + j, sizeof(reason)); + if (j > 0) { + absl::StrAppend(&s, ","); + } + absl::StrAppend(&s, CryptoUtils::HandshakeFailureReasonToString( + static_cast(reason))); + } + absl::StrAppend(&s, "]"); + } else { + absl::StrAppendFormat(&s, " RREJ:[unexpected length:%u]", value.size()); + } + } + + return s; +} + +} // namespace + +// static +std::unique_ptr KeyExchangeSource::Default() { + return std::make_unique(); +} + +class ValidateClientHelloHelper { + public: + // Note: stores a pointer to a unique_ptr, and std::moves the unique_ptr when + // ValidationComplete is called. + ValidateClientHelloHelper( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr* done_cb) + : result_(std::move(result)), done_cb_(done_cb) {} + ValidateClientHelloHelper(const ValidateClientHelloHelper&) = delete; + ValidateClientHelloHelper& operator=(const ValidateClientHelloHelper&) = + delete; + + ~ValidateClientHelloHelper() { + QUIC_BUG_IF(quic_bug_12963_1, done_cb_ != nullptr) + << "Deleting ValidateClientHelloHelper with a pending callback."; + } + + void ValidationComplete( + QuicErrorCode error_code, const char* error_details, + std::unique_ptr proof_source_details) { + result_->error_code = error_code; + result_->error_details = error_details; + (*done_cb_)->Run(std::move(result_), std::move(proof_source_details)); + DetachCallback(); + } + + void DetachCallback() { + QUIC_BUG_IF(quic_bug_10630_1, done_cb_ == nullptr) + << "Callback already detached."; + done_cb_ = nullptr; + } + + private: + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result_; + std::unique_ptr* done_cb_; +}; + +// static +const char QuicCryptoServerConfig::TESTING[] = "secret string for testing"; + +ClientHelloInfo::ClientHelloInfo(const QuicIpAddress& in_client_ip, + QuicWallTime in_now) + : client_ip(in_client_ip), now(in_now), valid_source_address_token(false) {} + +ClientHelloInfo::ClientHelloInfo(const ClientHelloInfo& other) = default; + +ClientHelloInfo::~ClientHelloInfo() {} + +PrimaryConfigChangedCallback::PrimaryConfigChangedCallback() {} + +PrimaryConfigChangedCallback::~PrimaryConfigChangedCallback() {} + +ValidateClientHelloResultCallback::Result::Result( + const CryptoHandshakeMessage& in_client_hello, QuicIpAddress in_client_ip, + QuicWallTime in_now) + : client_hello(in_client_hello), + info(in_client_ip, in_now), + error_code(QUIC_NO_ERROR) {} + +ValidateClientHelloResultCallback::Result::~Result() {} + +ValidateClientHelloResultCallback::ValidateClientHelloResultCallback() {} + +ValidateClientHelloResultCallback::~ValidateClientHelloResultCallback() {} + +ProcessClientHelloResultCallback::ProcessClientHelloResultCallback() {} + +ProcessClientHelloResultCallback::~ProcessClientHelloResultCallback() {} + +QuicCryptoServerConfig::ConfigOptions::ConfigOptions() + : expiry_time(QuicWallTime::Zero()), + channel_id_enabled(false), + p256(false) {} + +QuicCryptoServerConfig::ConfigOptions::ConfigOptions( + const ConfigOptions& other) = default; + +QuicCryptoServerConfig::ConfigOptions::~ConfigOptions() {} + +QuicCryptoServerConfig::ProcessClientHelloContext:: + ~ProcessClientHelloContext() { + if (done_cb_ != nullptr) { + QUIC_LOG(WARNING) + << "Deleting ProcessClientHelloContext with a pending callback."; + } +} + +void QuicCryptoServerConfig::ProcessClientHelloContext::Fail( + QuicErrorCode error, const std::string& error_details) { + QUIC_TRACEPRINTF("ProcessClientHello failed: error=%s, details=%s", + QuicErrorCodeToString(error), error_details); + done_cb_->Run(error, error_details, nullptr, nullptr, nullptr); + done_cb_ = nullptr; +} + +void QuicCryptoServerConfig::ProcessClientHelloContext::Succeed( + std::unique_ptr message, + std::unique_ptr diversification_nonce, + std::unique_ptr proof_source_details) { + QUIC_TRACEPRINTF("ProcessClientHello succeeded: %s", + FormatCryptoHandshakeMessageForTrace(message.get())); + + done_cb_->Run(QUIC_NO_ERROR, std::string(), std::move(message), + std::move(diversification_nonce), + std::move(proof_source_details)); + done_cb_ = nullptr; +} + +QuicCryptoServerConfig::QuicCryptoServerConfig( + absl::string_view source_address_token_secret, + QuicRandom* server_nonce_entropy, std::unique_ptr proof_source, + std::unique_ptr key_exchange_source) + : replay_protection_(true), + chlo_multiplier_(kMultiplier), + configs_lock_(), + primary_config_(nullptr), + next_config_promotion_time_(QuicWallTime::Zero()), + proof_source_(std::move(proof_source)), + key_exchange_source_(std::move(key_exchange_source)), + ssl_ctx_(TlsServerConnection::CreateSslCtx(proof_source_.get())), + source_address_token_future_secs_(3600), + source_address_token_lifetime_secs_(86400), + enable_serving_sct_(false), + rejection_observer_(nullptr), + pad_rej_(true), + pad_shlo_(true), + validate_chlo_size_(true), + validate_source_address_token_(true) { + QUICHE_DCHECK(proof_source_.get()); + source_address_token_boxer_.SetKeys( + {DeriveSourceAddressTokenKey(source_address_token_secret)}); + + // Generate a random key and orbit for server nonces. + server_nonce_entropy->RandBytes(server_nonce_orbit_, + sizeof(server_nonce_orbit_)); + const size_t key_size = server_nonce_boxer_.GetKeySize(); + std::unique_ptr key_bytes(new uint8_t[key_size]); + server_nonce_entropy->RandBytes(key_bytes.get(), key_size); + + server_nonce_boxer_.SetKeys( + {std::string(reinterpret_cast(key_bytes.get()), key_size)}); +} + +QuicCryptoServerConfig::~QuicCryptoServerConfig() {} + +// static +QuicServerConfigProtobuf QuicCryptoServerConfig::GenerateConfig( + QuicRandom* rand, const QuicClock* clock, const ConfigOptions& options) { + CryptoHandshakeMessage msg; + + const std::string curve25519_private_key = + Curve25519KeyExchange::NewPrivateKey(rand); + std::unique_ptr curve25519 = + Curve25519KeyExchange::New(curve25519_private_key); + absl::string_view curve25519_public_value = curve25519->public_value(); + + std::string encoded_public_values; + // First three bytes encode the length of the public value. + QUICHE_DCHECK_LT(curve25519_public_value.size(), (1U << 24)); + encoded_public_values.push_back( + static_cast(curve25519_public_value.size())); + encoded_public_values.push_back( + static_cast(curve25519_public_value.size() >> 8)); + encoded_public_values.push_back( + static_cast(curve25519_public_value.size() >> 16)); + encoded_public_values.append(curve25519_public_value.data(), + curve25519_public_value.size()); + + std::string p256_private_key; + if (options.p256) { + p256_private_key = P256KeyExchange::NewPrivateKey(); + std::unique_ptr p256( + P256KeyExchange::New(p256_private_key)); + absl::string_view p256_public_value = p256->public_value(); + + QUICHE_DCHECK_LT(p256_public_value.size(), (1U << 24)); + encoded_public_values.push_back( + static_cast(p256_public_value.size())); + encoded_public_values.push_back( + static_cast(p256_public_value.size() >> 8)); + encoded_public_values.push_back( + static_cast(p256_public_value.size() >> 16)); + encoded_public_values.append(p256_public_value.data(), + p256_public_value.size()); + } + + msg.set_tag(kSCFG); + if (options.p256) { + msg.SetVector(kKEXS, QuicTagVector{kC255, kP256}); + } else { + msg.SetVector(kKEXS, QuicTagVector{kC255}); + } + msg.SetVector(kAEAD, QuicTagVector{kAESG, kCC20}); + msg.SetStringPiece(kPUBS, encoded_public_values); + + if (options.expiry_time.IsZero()) { + const QuicWallTime now = clock->WallNow(); + const QuicWallTime expiry = now.Add(QuicTime::Delta::FromSeconds( + 60 * 60 * 24 * 180 /* 180 days, ~six months */)); + const uint64_t expiry_seconds = expiry.ToUNIXSeconds(); + msg.SetValue(kEXPY, expiry_seconds); + } else { + msg.SetValue(kEXPY, options.expiry_time.ToUNIXSeconds()); + } + + char orbit_bytes[kOrbitSize]; + if (options.orbit.size() == sizeof(orbit_bytes)) { + memcpy(orbit_bytes, options.orbit.data(), sizeof(orbit_bytes)); + } else { + QUICHE_DCHECK(options.orbit.empty()); + rand->RandBytes(orbit_bytes, sizeof(orbit_bytes)); + } + msg.SetStringPiece(kORBT, + absl::string_view(orbit_bytes, sizeof(orbit_bytes))); + + if (options.channel_id_enabled) { + msg.SetVector(kPDMD, QuicTagVector{kCHID}); + } + + if (options.id.empty()) { + // We need to ensure that the SCID changes whenever the server config does + // thus we make it a hash of the rest of the server config. + std::unique_ptr serialized = + CryptoFramer::ConstructHandshakeMessage(msg); + + uint8_t scid_bytes[SHA256_DIGEST_LENGTH]; + SHA256(reinterpret_cast(serialized->data()), + serialized->length(), scid_bytes); + // The SCID is a truncated SHA-256 digest. + static_assert(16 <= SHA256_DIGEST_LENGTH, "SCID length too high."); + msg.SetStringPiece( + kSCID, + absl::string_view(reinterpret_cast(scid_bytes), 16)); + } else { + msg.SetStringPiece(kSCID, options.id); + } + // Don't put new tags below this point. The SCID generation should hash over + // everything but itself and so extra tags should be added prior to the + // preceding if block. + + std::unique_ptr serialized = + CryptoFramer::ConstructHandshakeMessage(msg); + + QuicServerConfigProtobuf config; + config.set_config(std::string(serialized->AsStringPiece())); + QuicServerConfigProtobuf::PrivateKey* curve25519_key = config.add_key(); + curve25519_key->set_tag(kC255); + curve25519_key->set_private_key(curve25519_private_key); + + if (options.p256) { + QuicServerConfigProtobuf::PrivateKey* p256_key = config.add_key(); + p256_key->set_tag(kP256); + p256_key->set_private_key(p256_private_key); + } + + return config; +} + +std::unique_ptr QuicCryptoServerConfig::AddConfig( + const QuicServerConfigProtobuf& protobuf, const QuicWallTime now) { + std::unique_ptr msg = + CryptoFramer::ParseMessage(protobuf.config()); + + if (!msg) { + QUIC_LOG(WARNING) << "Failed to parse server config message"; + return nullptr; + } + + quiche::QuicheReferenceCountedPointer config = + ParseConfigProtobuf(protobuf, /* is_fallback = */ false); + if (!config) { + QUIC_LOG(WARNING) << "Failed to parse server config message"; + return nullptr; + } + + { + QuicWriterMutexLock locked(&configs_lock_); + if (configs_.find(config->id) != configs_.end()) { + QUIC_LOG(WARNING) << "Failed to add config because another with the same " + "server config id already exists: " + << absl::BytesToHexString(config->id); + return nullptr; + } + + configs_[config->id] = config; + SelectNewPrimaryConfig(now); + QUICHE_DCHECK(primary_config_.get()); + QUICHE_DCHECK_EQ(configs_.find(primary_config_->id)->second.get(), + primary_config_.get()); + } + + return msg; +} + +std::unique_ptr +QuicCryptoServerConfig::AddDefaultConfig(QuicRandom* rand, + const QuicClock* clock, + const ConfigOptions& options) { + return AddConfig(GenerateConfig(rand, clock, options), clock->WallNow()); +} + +bool QuicCryptoServerConfig::SetConfigs( + const std::vector& protobufs, + const QuicServerConfigProtobuf* fallback_protobuf, const QuicWallTime now) { + std::vector> parsed_configs; + for (auto& protobuf : protobufs) { + quiche::QuicheReferenceCountedPointer config = + ParseConfigProtobuf(protobuf, /* is_fallback = */ false); + if (!config) { + QUIC_LOG(WARNING) << "Rejecting QUIC configs because of above errors"; + return false; + } + + parsed_configs.push_back(config); + } + + quiche::QuicheReferenceCountedPointer fallback_config; + if (fallback_protobuf != nullptr) { + fallback_config = + ParseConfigProtobuf(*fallback_protobuf, /* is_fallback = */ true); + if (!fallback_config) { + QUIC_LOG(WARNING) << "Rejecting QUIC configs because of above errors"; + return false; + } + QUIC_LOG(INFO) << "Fallback config has scid " + << absl::BytesToHexString(fallback_config->id); + parsed_configs.push_back(fallback_config); + } else { + QUIC_LOG(INFO) << "No fallback config provided"; + } + + if (parsed_configs.empty()) { + QUIC_LOG(WARNING) + << "Rejecting QUIC configs because new config list is empty."; + return false; + } + + QUIC_LOG(INFO) << "Updating configs:"; + + QuicWriterMutexLock locked(&configs_lock_); + ConfigMap new_configs; + + for (const quiche::QuicheReferenceCountedPointer& config : + parsed_configs) { + auto it = configs_.find(config->id); + if (it != configs_.end()) { + QUIC_LOG(INFO) << "Keeping scid: " << absl::BytesToHexString(config->id) + << " orbit: " + << absl::BytesToHexString(absl::string_view( + reinterpret_cast(config->orbit), + kOrbitSize)) + << " new primary_time " + << config->primary_time.ToUNIXSeconds() + << " old primary_time " + << it->second->primary_time.ToUNIXSeconds() + << " new priority " << config->priority << " old priority " + << it->second->priority; + // Update primary_time and priority. + it->second->primary_time = config->primary_time; + it->second->priority = config->priority; + new_configs.insert(*it); + } else { + QUIC_LOG(INFO) << "Adding scid: " << absl::BytesToHexString(config->id) + << " orbit: " + << absl::BytesToHexString(absl::string_view( + reinterpret_cast(config->orbit), + kOrbitSize)) + << " primary_time " << config->primary_time.ToUNIXSeconds() + << " priority " << config->priority; + new_configs.emplace(config->id, config); + } + } + + configs_ = std::move(new_configs); + fallback_config_ = fallback_config; + SelectNewPrimaryConfig(now); + QUICHE_DCHECK(primary_config_.get()); + QUICHE_DCHECK_EQ(configs_.find(primary_config_->id)->second.get(), + primary_config_.get()); + + return true; +} + +void QuicCryptoServerConfig::SetSourceAddressTokenKeys( + const std::vector& keys) { + // TODO(b/208866709) + source_address_token_boxer_.SetKeys(keys); +} + +std::vector QuicCryptoServerConfig::GetConfigIds() const { + QuicReaderMutexLock locked(&configs_lock_); + std::vector scids; + for (auto it = configs_.begin(); it != configs_.end(); ++it) { + scids.push_back(it->first); + } + return scids; +} + +void QuicCryptoServerConfig::ValidateClientHello( + const CryptoHandshakeMessage& client_hello, + const QuicSocketAddress& client_address, + const QuicSocketAddress& server_address, QuicTransportVersion version, + const QuicClock* clock, + quiche::QuicheReferenceCountedPointer signed_config, + std::unique_ptr done_cb) const { + const QuicWallTime now(clock->WallNow()); + + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result(new ValidateClientHelloResultCallback::Result( + client_hello, client_address.host(), now)); + + absl::string_view requested_scid; + // We ignore here the return value from GetStringPiece. If there is no SCID + // tag, EvaluateClientHello will discover that because GetCurrentConfigs will + // not have found the requested config (i.e. because none of the configs will + // have an empty string as its id). + client_hello.GetStringPiece(kSCID, &requested_scid); + Configs configs; + if (!GetCurrentConfigs(now, requested_scid, + /* old_primary_config = */ nullptr, &configs)) { + result->error_code = QUIC_CRYPTO_INTERNAL_ERROR; + result->error_details = "No configurations loaded"; + } + signed_config->config = configs.primary; + + if (result->error_code == QUIC_NO_ERROR) { + // QUIC requires a new proof for each CHLO so clear any existing proof. + signed_config->chain = nullptr; + signed_config->proof.signature = ""; + signed_config->proof.leaf_cert_scts = ""; + EvaluateClientHello(server_address, client_address, version, configs, + result, std::move(done_cb)); + } else { + done_cb->Run(result, /* details = */ nullptr); + } +} + +class QuicCryptoServerConfig::ProcessClientHelloCallback + : public ProofSource::Callback { + public: + ProcessClientHelloCallback(const QuicCryptoServerConfig* config, + std::unique_ptr context, + const Configs& configs) + : config_(config), context_(std::move(context)), configs_(configs) {} + + void Run( + bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const QuicCryptoProof& proof, + std::unique_ptr details) override { + if (ok) { + context_->signed_config()->chain = chain; + context_->signed_config()->proof = proof; + } + config_->ProcessClientHelloAfterGetProof(!ok, std::move(details), + std::move(context_), configs_); + } + + private: + const QuicCryptoServerConfig* config_; + std::unique_ptr context_; + const Configs configs_; +}; + +class QuicCryptoServerConfig::ProcessClientHelloAfterGetProofCallback + : public AsynchronousKeyExchange::Callback { + public: + ProcessClientHelloAfterGetProofCallback( + const QuicCryptoServerConfig* config, + std::unique_ptr proof_source_details, + QuicTag key_exchange_type, std::unique_ptr out, + absl::string_view public_value, + std::unique_ptr context, + const Configs& configs) + : config_(config), + proof_source_details_(std::move(proof_source_details)), + key_exchange_type_(key_exchange_type), + out_(std::move(out)), + public_value_(public_value), + context_(std::move(context)), + configs_(configs) {} + + void Run(bool ok) override { + config_->ProcessClientHelloAfterCalculateSharedKeys( + !ok, std::move(proof_source_details_), key_exchange_type_, + std::move(out_), public_value_, std::move(context_), configs_); + } + + private: + const QuicCryptoServerConfig* config_; + std::unique_ptr proof_source_details_; + const QuicTag key_exchange_type_; + std::unique_ptr out_; + const std::string public_value_; + std::unique_ptr context_; + const Configs configs_; + std::unique_ptr done_cb_; +}; + +class QuicCryptoServerConfig::SendRejectWithFallbackConfigCallback + : public ProofSource::Callback { + public: + SendRejectWithFallbackConfigCallback( + const QuicCryptoServerConfig* config, + std::unique_ptr context, + quiche::QuicheReferenceCountedPointer fallback_config) + : config_(config), + context_(std::move(context)), + fallback_config_(fallback_config) {} + + // Capture |chain| and |proof| into the signed config, and then invoke + // SendRejectWithFallbackConfigAfterGetProof. + void Run( + bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const QuicCryptoProof& proof, + std::unique_ptr details) override { + if (ok) { + context_->signed_config()->chain = chain; + context_->signed_config()->proof = proof; + } + config_->SendRejectWithFallbackConfigAfterGetProof( + !ok, std::move(details), std::move(context_), fallback_config_); + } + + private: + const QuicCryptoServerConfig* config_; + std::unique_ptr context_; + quiche::QuicheReferenceCountedPointer fallback_config_; +}; + +void QuicCryptoServerConfig::ProcessClientHello( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + validate_chlo_result, + bool reject_only, QuicConnectionId connection_id, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, ParsedQuicVersion version, + const ParsedQuicVersionVector& supported_versions, const QuicClock* clock, + QuicRandom* rand, QuicCompressedCertsCache* compressed_certs_cache, + quiche::QuicheReferenceCountedPointer + params, + quiche::QuicheReferenceCountedPointer signed_config, + QuicByteCount total_framing_overhead, QuicByteCount chlo_packet_size, + std::shared_ptr done_cb) const { + QUICHE_DCHECK(done_cb); + auto context = std::make_unique( + validate_chlo_result, reject_only, connection_id, server_address, + client_address, version, supported_versions, clock, rand, + compressed_certs_cache, params, signed_config, total_framing_overhead, + chlo_packet_size, std::move(done_cb)); + + // Verify that various parts of the CHLO are valid + std::string error_details; + QuicErrorCode valid = CryptoUtils::ValidateClientHello( + context->client_hello(), context->version(), + context->supported_versions(), &error_details); + if (valid != QUIC_NO_ERROR) { + context->Fail(valid, error_details); + return; + } + + absl::string_view requested_scid; + context->client_hello().GetStringPiece(kSCID, &requested_scid); + Configs configs; + if (!GetCurrentConfigs(context->clock()->WallNow(), requested_scid, + signed_config->config, &configs)) { + context->Fail(QUIC_CRYPTO_INTERNAL_ERROR, "No configurations loaded"); + return; + } + + if (context->validate_chlo_result()->error_code != QUIC_NO_ERROR) { + context->Fail(context->validate_chlo_result()->error_code, + context->validate_chlo_result()->error_details); + return; + } + + if (!ClientDemandsX509Proof(context->client_hello())) { + context->Fail(QUIC_UNSUPPORTED_PROOF_DEMAND, "Missing or invalid PDMD"); + return; + } + + // No need to get a new proof if one was already generated. + if (!context->signed_config()->chain) { + const std::string chlo_hash = CryptoUtils::HashHandshakeMessage( + context->client_hello(), Perspective::IS_SERVER); + const QuicSocketAddress server_address = context->server_address(); + const std::string sni = std::string(context->info().sni); + const QuicTransportVersion transport_version = context->transport_version(); + + auto cb = std::make_unique( + this, std::move(context), configs); + + QUICHE_DCHECK(proof_source_.get()); + proof_source_->GetProof(server_address, client_address, sni, + configs.primary->serialized, transport_version, + chlo_hash, std::move(cb)); + return; + } + + ProcessClientHelloAfterGetProof( + /* found_error = */ false, /* proof_source_details = */ nullptr, + std::move(context), configs); +} + +void QuicCryptoServerConfig::ProcessClientHelloAfterGetProof( + bool found_error, + std::unique_ptr proof_source_details, + std::unique_ptr context, + const Configs& configs) const { + QUIC_BUG_IF(quic_bug_12963_2, + !QuicUtils::IsConnectionIdValidForVersion( + context->connection_id(), context->transport_version())) + << "ProcessClientHelloAfterGetProof: attempted to use connection ID " + << context->connection_id() << " which is invalid with version " + << context->version(); + + if (context->info().reject_reasons.empty()) { + if (!context->signed_config() || !context->signed_config()->chain) { + // No chain. + context->validate_chlo_result()->info.reject_reasons.push_back( + SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE); + } else if (!ValidateExpectedLeafCertificate( + context->client_hello(), + context->signed_config()->chain->certs)) { + // Has chain but leaf is invalid. + context->validate_chlo_result()->info.reject_reasons.push_back( + INVALID_EXPECTED_LEAF_CERTIFICATE); + } + } + + if (found_error) { + context->Fail(QUIC_HANDSHAKE_FAILED, "Failed to get proof"); + return; + } + + auto out_diversification_nonce = std::make_unique(); + + absl::string_view cert_sct; + if (context->client_hello().GetStringPiece(kCertificateSCTTag, &cert_sct) && + cert_sct.empty()) { + context->params()->sct_supported_by_client = true; + } + + auto out = std::make_unique(); + if (!context->info().reject_reasons.empty() || !configs.requested) { + BuildRejectionAndRecordStats(*context, *configs.primary, + context->info().reject_reasons, out.get()); + context->Succeed(std::move(out), std::move(out_diversification_nonce), + std::move(proof_source_details)); + return; + } + + if (context->reject_only()) { + context->Succeed(std::move(out), std::move(out_diversification_nonce), + std::move(proof_source_details)); + return; + } + + QuicTagVector their_aeads; + QuicTagVector their_key_exchanges; + if (context->client_hello().GetTaglist(kAEAD, &their_aeads) != + QUIC_NO_ERROR || + context->client_hello().GetTaglist(kKEXS, &their_key_exchanges) != + QUIC_NO_ERROR || + their_aeads.size() != 1 || their_key_exchanges.size() != 1) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "Missing or invalid AEAD or KEXS"); + return; + } + + size_t key_exchange_index; + if (!FindMutualQuicTag(configs.requested->aead, their_aeads, + &context->params()->aead, nullptr) || + !FindMutualQuicTag(configs.requested->kexs, their_key_exchanges, + &context->params()->key_exchange, + &key_exchange_index)) { + context->Fail(QUIC_CRYPTO_NO_SUPPORT, "Unsupported AEAD or KEXS"); + return; + } + + absl::string_view public_value; + if (!context->client_hello().GetStringPiece(kPUBS, &public_value)) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "Missing public value"); + return; + } + + // Allow testing a specific adversarial case in which a client sends a public + // value of incorrect size. + AdjustTestValue("quic::QuicCryptoServerConfig::public_value_adjust", + &public_value); + + const AsynchronousKeyExchange* key_exchange = + configs.requested->key_exchanges[key_exchange_index].get(); + std::string* initial_premaster_secret = + &context->params()->initial_premaster_secret; + auto cb = std::make_unique( + this, std::move(proof_source_details), key_exchange->type(), + std::move(out), public_value, std::move(context), configs); + key_exchange->CalculateSharedKeyAsync(public_value, initial_premaster_secret, + std::move(cb)); +} + +void QuicCryptoServerConfig::ProcessClientHelloAfterCalculateSharedKeys( + bool found_error, + std::unique_ptr proof_source_details, + QuicTag key_exchange_type, std::unique_ptr out, + absl::string_view public_value, + std::unique_ptr context, + const Configs& configs) const { + QUIC_BUG_IF(quic_bug_12963_3, + !QuicUtils::IsConnectionIdValidForVersion( + context->connection_id(), context->transport_version())) + << "ProcessClientHelloAfterCalculateSharedKeys:" + " attempted to use connection ID " + << context->connection_id() << " which is invalid with version " + << context->version(); + + if (found_error) { + // If we are already using the fallback config, or there is no fallback + // config to use, just bail out of the handshake. + if (configs.fallback == nullptr || + context->signed_config()->config == configs.fallback) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "Failed to calculate shared key"); + } else { + SendRejectWithFallbackConfig(std::move(context), configs.fallback); + } + return; + } + + if (!context->info().sni.empty()) { + context->params()->sni = + QuicHostnameUtils::NormalizeHostname(context->info().sni); + } + + std::string hkdf_suffix; + const QuicData& client_hello_serialized = + context->client_hello().GetSerialized(); + hkdf_suffix.reserve(context->connection_id().length() + + client_hello_serialized.length() + + configs.requested->serialized.size()); + hkdf_suffix.append(context->connection_id().data(), + context->connection_id().length()); + hkdf_suffix.append(client_hello_serialized.data(), + client_hello_serialized.length()); + hkdf_suffix.append(configs.requested->serialized); + QUICHE_DCHECK(proof_source_.get()); + if (context->signed_config()->chain->certs.empty()) { + context->Fail(QUIC_CRYPTO_INTERNAL_ERROR, "Failed to get certs"); + return; + } + hkdf_suffix.append(context->signed_config()->chain->certs.at(0)); + + absl::string_view cetv_ciphertext; + if (configs.requested->channel_id_enabled && + context->client_hello().GetStringPiece(kCETV, &cetv_ciphertext)) { + CryptoHandshakeMessage client_hello_copy(context->client_hello()); + client_hello_copy.Erase(kCETV); + client_hello_copy.Erase(kPAD); + + const QuicData& client_hello_copy_serialized = + client_hello_copy.GetSerialized(); + std::string hkdf_input; + hkdf_input.append(QuicCryptoConfig::kCETVLabel, + strlen(QuicCryptoConfig::kCETVLabel) + 1); + hkdf_input.append(context->connection_id().data(), + context->connection_id().length()); + hkdf_input.append(client_hello_copy_serialized.data(), + client_hello_copy_serialized.length()); + hkdf_input.append(configs.requested->serialized); + + CrypterPair crypters; + if (!CryptoUtils::DeriveKeys( + context->version(), context->params()->initial_premaster_secret, + context->params()->aead, context->info().client_nonce, + context->info().server_nonce, pre_shared_key_, hkdf_input, + Perspective::IS_SERVER, CryptoUtils::Diversification::Never(), + &crypters, nullptr /* subkey secret */)) { + context->Fail(QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED, + "Symmetric key setup failed"); + return; + } + + char plaintext[kMaxOutgoingPacketSize]; + size_t plaintext_length = 0; + const bool success = crypters.decrypter->DecryptPacket( + 0 /* packet number */, absl::string_view() /* associated data */, + cetv_ciphertext, plaintext, &plaintext_length, kMaxOutgoingPacketSize); + if (!success) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "CETV decryption failure"); + return; + } + std::unique_ptr cetv(CryptoFramer::ParseMessage( + absl::string_view(plaintext, plaintext_length))); + if (!cetv) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, "CETV parse error"); + return; + } + + absl::string_view key, signature; + if (cetv->GetStringPiece(kCIDK, &key) && + cetv->GetStringPiece(kCIDS, &signature)) { + if (!ChannelIDVerifier::Verify(key, hkdf_input, signature)) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "ChannelID signature failure"); + return; + } + + context->params()->channel_id = std::string(key); + } + } + + std::string hkdf_input; + size_t label_len = strlen(QuicCryptoConfig::kInitialLabel) + 1; + hkdf_input.reserve(label_len + hkdf_suffix.size()); + hkdf_input.append(QuicCryptoConfig::kInitialLabel, label_len); + hkdf_input.append(hkdf_suffix); + + auto out_diversification_nonce = std::make_unique(); + context->rand()->RandBytes(out_diversification_nonce->data(), + out_diversification_nonce->size()); + CryptoUtils::Diversification diversification = + CryptoUtils::Diversification::Now(out_diversification_nonce.get()); + if (!CryptoUtils::DeriveKeys( + context->version(), context->params()->initial_premaster_secret, + context->params()->aead, context->info().client_nonce, + context->info().server_nonce, pre_shared_key_, hkdf_input, + Perspective::IS_SERVER, diversification, + &context->params()->initial_crypters, + &context->params()->initial_subkey_secret)) { + context->Fail(QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED, + "Symmetric key setup failed"); + return; + } + + std::string forward_secure_public_value; + std::unique_ptr forward_secure_key_exchange = + CreateLocalSynchronousKeyExchange(key_exchange_type, context->rand()); + if (!forward_secure_key_exchange) { + QUIC_DLOG(WARNING) << "Failed to create keypair"; + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "Failed to create keypair"); + return; + } + + forward_secure_public_value = + std::string(forward_secure_key_exchange->public_value()); + if (!forward_secure_key_exchange->CalculateSharedKeySync( + public_value, &context->params()->forward_secure_premaster_secret)) { + context->Fail(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "Invalid public value"); + return; + } + + std::string forward_secure_hkdf_input; + label_len = strlen(QuicCryptoConfig::kForwardSecureLabel) + 1; + forward_secure_hkdf_input.reserve(label_len + hkdf_suffix.size()); + forward_secure_hkdf_input.append(QuicCryptoConfig::kForwardSecureLabel, + label_len); + forward_secure_hkdf_input.append(hkdf_suffix); + + std::string shlo_nonce; + shlo_nonce = NewServerNonce(context->rand(), context->info().now); + out->SetStringPiece(kServerNonceTag, shlo_nonce); + + if (!CryptoUtils::DeriveKeys( + context->version(), + context->params()->forward_secure_premaster_secret, + context->params()->aead, context->info().client_nonce, + shlo_nonce.empty() ? context->info().server_nonce : shlo_nonce, + pre_shared_key_, forward_secure_hkdf_input, Perspective::IS_SERVER, + CryptoUtils::Diversification::Never(), + &context->params()->forward_secure_crypters, + &context->params()->subkey_secret)) { + context->Fail(QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED, + "Symmetric key setup failed"); + return; + } + + out->set_tag(kSHLO); + out->SetVersionVector(kVER, context->supported_versions()); + out->SetStringPiece( + kSourceAddressTokenTag, + NewSourceAddressToken(*configs.requested->source_address_token_boxer, + context->info().source_address_tokens, + context->client_address().host(), context->rand(), + context->info().now, nullptr)); + QuicSocketAddressCoder address_coder(context->client_address()); + out->SetStringPiece(kCADR, address_coder.Encode()); + out->SetStringPiece(kPUBS, forward_secure_public_value); + + context->Succeed(std::move(out), std::move(out_diversification_nonce), + std::move(proof_source_details)); +} + +void QuicCryptoServerConfig::SendRejectWithFallbackConfig( + std::unique_ptr context, + quiche::QuicheReferenceCountedPointer fallback_config) const { + // We failed to calculate a shared initial key, likely because we tried to use + // a remote key-exchange service which could not be reached. We want to send + // a REJ which tells the client to use a different ServerConfig which + // corresponds to a local keypair. To generate the REJ we need to request a + // new proof. + const std::string chlo_hash = CryptoUtils::HashHandshakeMessage( + context->client_hello(), Perspective::IS_SERVER); + const QuicSocketAddress server_address = context->server_address(); + const std::string sni(context->info().sni); + const QuicTransportVersion transport_version = context->transport_version(); + + const QuicSocketAddress& client_address = context->client_address(); + auto cb = std::make_unique( + this, std::move(context), fallback_config); + proof_source_->GetProof(server_address, client_address, sni, + fallback_config->serialized, transport_version, + chlo_hash, std::move(cb)); +} + +void QuicCryptoServerConfig::SendRejectWithFallbackConfigAfterGetProof( + bool found_error, + std::unique_ptr proof_source_details, + std::unique_ptr context, + quiche::QuicheReferenceCountedPointer fallback_config) const { + if (found_error) { + context->Fail(QUIC_HANDSHAKE_FAILED, "Failed to get proof"); + return; + } + + auto out = std::make_unique(); + BuildRejectionAndRecordStats(*context, *fallback_config, + {SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE}, + out.get()); + + context->Succeed(std::move(out), std::make_unique(), + std::move(proof_source_details)); +} + +quiche::QuicheReferenceCountedPointer +QuicCryptoServerConfig::GetConfigWithScid( + absl::string_view requested_scid) const { + configs_lock_.AssertReaderHeld(); + + if (!requested_scid.empty()) { + auto it = configs_.find((std::string(requested_scid))); + if (it != configs_.end()) { + // We'll use the config that the client requested in order to do + // key-agreement. + return quiche::QuicheReferenceCountedPointer(it->second); + } + } + + return quiche::QuicheReferenceCountedPointer(); +} + +bool QuicCryptoServerConfig::GetCurrentConfigs( + const QuicWallTime& now, absl::string_view requested_scid, + quiche::QuicheReferenceCountedPointer old_primary_config, + Configs* configs) const { + QuicReaderMutexLock locked(&configs_lock_); + + if (!primary_config_) { + return false; + } + + if (IsNextConfigReady(now)) { + configs_lock_.ReaderUnlock(); + configs_lock_.WriterLock(); + SelectNewPrimaryConfig(now); + QUICHE_DCHECK(primary_config_.get()); + QUICHE_DCHECK_EQ(configs_.find(primary_config_->id)->second.get(), + primary_config_.get()); + configs_lock_.WriterUnlock(); + configs_lock_.ReaderLock(); + } + + if (old_primary_config != nullptr) { + configs->primary = old_primary_config; + } else { + configs->primary = primary_config_; + } + configs->requested = GetConfigWithScid(requested_scid); + configs->fallback = fallback_config_; + + return true; +} + +// ConfigPrimaryTimeLessThan is a comparator that implements "less than" for +// Config's based on their primary_time. +// static +bool QuicCryptoServerConfig::ConfigPrimaryTimeLessThan( + const quiche::QuicheReferenceCountedPointer& a, + const quiche::QuicheReferenceCountedPointer& b) { + if (a->primary_time.IsBefore(b->primary_time) || + b->primary_time.IsBefore(a->primary_time)) { + // Primary times differ. + return a->primary_time.IsBefore(b->primary_time); + } else if (a->priority != b->priority) { + // Primary times are equal, sort backwards by priority. + return a->priority < b->priority; + } else { + // Primary times and priorities are equal, sort by config id. + return a->id < b->id; + } +} + +void QuicCryptoServerConfig::SelectNewPrimaryConfig( + const QuicWallTime now) const { + std::vector> configs; + configs.reserve(configs_.size()); + + for (auto it = configs_.begin(); it != configs_.end(); ++it) { + // TODO(avd) Exclude expired configs? + configs.push_back(it->second); + } + + if (configs.empty()) { + if (primary_config_ != nullptr) { + QUIC_BUG(quic_bug_10630_2) + << "No valid QUIC server config. Keeping the current config."; + } else { + QUIC_BUG(quic_bug_10630_3) << "No valid QUIC server config."; + } + return; + } + + std::sort(configs.begin(), configs.end(), ConfigPrimaryTimeLessThan); + + quiche::QuicheReferenceCountedPointer best_candidate = configs[0]; + + for (size_t i = 0; i < configs.size(); ++i) { + const quiche::QuicheReferenceCountedPointer config(configs[i]); + if (!config->primary_time.IsAfter(now)) { + if (config->primary_time.IsAfter(best_candidate->primary_time)) { + best_candidate = config; + } + continue; + } + + // This is the first config with a primary_time in the future. Thus the + // previous Config should be the primary and this one should determine the + // next_config_promotion_time_. + quiche::QuicheReferenceCountedPointer new_primary = best_candidate; + if (i == 0) { + // We need the primary_time of the next config. + if (configs.size() > 1) { + next_config_promotion_time_ = configs[1]->primary_time; + } else { + next_config_promotion_time_ = QuicWallTime::Zero(); + } + } else { + next_config_promotion_time_ = config->primary_time; + } + + if (primary_config_) { + primary_config_->is_primary = false; + } + primary_config_ = new_primary; + new_primary->is_primary = true; + QUIC_DLOG(INFO) << "New primary config. orbit: " + << absl::BytesToHexString( + absl::string_view(reinterpret_cast( + primary_config_->orbit), + kOrbitSize)); + if (primary_config_changed_cb_ != nullptr) { + primary_config_changed_cb_->Run(primary_config_->id); + } + + return; + } + + // All config's primary times are in the past. We should make the most recent + // and highest priority candidate primary. + quiche::QuicheReferenceCountedPointer new_primary = best_candidate; + if (primary_config_) { + primary_config_->is_primary = false; + } + primary_config_ = new_primary; + new_primary->is_primary = true; + QUIC_DLOG(INFO) << "New primary config. orbit: " + << absl::BytesToHexString(absl::string_view( + reinterpret_cast(primary_config_->orbit), + kOrbitSize)) + << " scid: " << absl::BytesToHexString(primary_config_->id); + next_config_promotion_time_ = QuicWallTime::Zero(); + if (primary_config_changed_cb_ != nullptr) { + primary_config_changed_cb_->Run(primary_config_->id); + } +} + +void QuicCryptoServerConfig::EvaluateClientHello( + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + QuicTransportVersion /*version*/, const Configs& configs, + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + client_hello_state, + std::unique_ptr done_cb) const { + ValidateClientHelloHelper helper(client_hello_state, &done_cb); + + const CryptoHandshakeMessage& client_hello = client_hello_state->client_hello; + ClientHelloInfo* info = &(client_hello_state->info); + + if (client_hello.GetStringPiece(kSNI, &info->sni) && + !QuicHostnameUtils::IsValidSNI(info->sni)) { + helper.ValidationComplete(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + "Invalid SNI name", nullptr); + return; + } + + client_hello.GetStringPiece(kUAID, &info->user_agent_id); + + HandshakeFailureReason source_address_token_error = MAX_FAILURE_REASON; + if (validate_source_address_token_) { + absl::string_view srct; + if (client_hello.GetStringPiece(kSourceAddressTokenTag, &srct)) { + Config& config = + configs.requested != nullptr ? *configs.requested : *configs.primary; + source_address_token_error = + ParseSourceAddressToken(*config.source_address_token_boxer, srct, + info->source_address_tokens); + + if (source_address_token_error == HANDSHAKE_OK) { + source_address_token_error = ValidateSourceAddressTokens( + info->source_address_tokens, info->client_ip, info->now, + &client_hello_state->cached_network_params); + } + info->valid_source_address_token = + (source_address_token_error == HANDSHAKE_OK); + } else { + source_address_token_error = SOURCE_ADDRESS_TOKEN_INVALID_FAILURE; + } + } else { + source_address_token_error = HANDSHAKE_OK; + info->valid_source_address_token = true; + } + + if (!configs.requested) { + absl::string_view requested_scid; + if (client_hello.GetStringPiece(kSCID, &requested_scid)) { + info->reject_reasons.push_back(SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE); + } else { + info->reject_reasons.push_back(SERVER_CONFIG_INCHOATE_HELLO_FAILURE); + } + // No server config with the requested ID. + helper.ValidationComplete(QUIC_NO_ERROR, "", nullptr); + return; + } + + if (!client_hello.GetStringPiece(kNONC, &info->client_nonce)) { + info->reject_reasons.push_back(SERVER_CONFIG_INCHOATE_HELLO_FAILURE); + // Report no client nonce as INCHOATE_HELLO_FAILURE. + helper.ValidationComplete(QUIC_NO_ERROR, "", nullptr); + return; + } + + if (source_address_token_error != HANDSHAKE_OK) { + info->reject_reasons.push_back(source_address_token_error); + // No valid source address token. + } + + if (info->client_nonce.size() != kNonceSize) { + info->reject_reasons.push_back(CLIENT_NONCE_INVALID_FAILURE); + // Invalid client nonce. + QUIC_LOG_FIRST_N(ERROR, 2) + << "Invalid client nonce: " << client_hello.DebugString(); + QUIC_DLOG(INFO) << "Invalid client nonce."; + } + + // Server nonce is optional, and used for key derivation if present. + client_hello.GetStringPiece(kServerNonceTag, &info->server_nonce); + + // If the server nonce is empty and we're requiring handshake confirmation + // for DoS reasons then we must reject the CHLO. + if (GetQuicReloadableFlag(quic_require_handshake_confirmation) && + info->server_nonce.empty()) { + info->reject_reasons.push_back(SERVER_NONCE_REQUIRED_FAILURE); + } + helper.ValidationComplete(QUIC_NO_ERROR, "", + std::unique_ptr()); +} + +void QuicCryptoServerConfig::BuildServerConfigUpdateMessage( + QuicTransportVersion version, absl::string_view chlo_hash, + const SourceAddressTokens& previous_source_address_tokens, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const QuicClock* clock, + QuicRandom* rand, QuicCompressedCertsCache* compressed_certs_cache, + const QuicCryptoNegotiatedParameters& params, + const CachedNetworkParameters* cached_network_params, + std::unique_ptr cb) const { + std::string serialized; + std::string source_address_token; + { + QuicReaderMutexLock locked(&configs_lock_); + serialized = primary_config_->serialized; + source_address_token = NewSourceAddressToken( + *primary_config_->source_address_token_boxer, + previous_source_address_tokens, client_address.host(), rand, + clock->WallNow(), cached_network_params); + } + + CryptoHandshakeMessage message; + message.set_tag(kSCUP); + message.SetStringPiece(kSCFG, serialized); + message.SetStringPiece(kSourceAddressTokenTag, source_address_token); + + auto proof_source_cb = + std::make_unique( + this, compressed_certs_cache, params, std::move(message), + std::move(cb)); + + proof_source_->GetProof(server_address, client_address, params.sni, + serialized, version, chlo_hash, + std::move(proof_source_cb)); +} + +QuicCryptoServerConfig::BuildServerConfigUpdateMessageProofSourceCallback:: + ~BuildServerConfigUpdateMessageProofSourceCallback() {} + +QuicCryptoServerConfig::BuildServerConfigUpdateMessageProofSourceCallback:: + BuildServerConfigUpdateMessageProofSourceCallback( + const QuicCryptoServerConfig* config, + QuicCompressedCertsCache* compressed_certs_cache, + const QuicCryptoNegotiatedParameters& params, + CryptoHandshakeMessage message, + std::unique_ptr cb) + : config_(config), + compressed_certs_cache_(compressed_certs_cache), + client_cached_cert_hashes_(params.client_cached_cert_hashes), + sct_supported_by_client_(params.sct_supported_by_client), + sni_(params.sni), + message_(std::move(message)), + cb_(std::move(cb)) {} + +void QuicCryptoServerConfig::BuildServerConfigUpdateMessageProofSourceCallback:: + Run(bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const QuicCryptoProof& proof, + std::unique_ptr details) { + config_->FinishBuildServerConfigUpdateMessage( + compressed_certs_cache_, client_cached_cert_hashes_, + sct_supported_by_client_, sni_, ok, chain, proof.signature, + proof.leaf_cert_scts, std::move(details), std::move(message_), + std::move(cb_)); +} + +void QuicCryptoServerConfig::FinishBuildServerConfigUpdateMessage( + QuicCompressedCertsCache* compressed_certs_cache, + const std::string& client_cached_cert_hashes, bool sct_supported_by_client, + const std::string& sni, bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& signature, const std::string& leaf_cert_sct, + std::unique_ptr /*details*/, + CryptoHandshakeMessage message, + std::unique_ptr cb) const { + if (!ok) { + cb->Run(false, message); + return; + } + + const std::string compressed = + CompressChain(compressed_certs_cache, chain, client_cached_cert_hashes); + + message.SetStringPiece(kCertificateTag, compressed); + message.SetStringPiece(kPROF, signature); + if (sct_supported_by_client && enable_serving_sct_) { + if (leaf_cert_sct.empty()) { + QUIC_LOG_EVERY_N_SEC(WARNING, 60) + << "SCT is expected but it is empty. SNI: " << sni; + } else { + message.SetStringPiece(kCertificateSCTTag, leaf_cert_sct); + } + } + + cb->Run(true, message); +} + +void QuicCryptoServerConfig::BuildRejectionAndRecordStats( + const ProcessClientHelloContext& context, const Config& config, + const std::vector& reject_reasons, + CryptoHandshakeMessage* out) const { + BuildRejection(context, config, reject_reasons, out); + if (rejection_observer_ != nullptr) { + rejection_observer_->OnRejectionBuilt(reject_reasons, out); + } +} + +void QuicCryptoServerConfig::BuildRejection( + const ProcessClientHelloContext& context, const Config& config, + const std::vector& reject_reasons, + CryptoHandshakeMessage* out) const { + const QuicWallTime now = context.clock()->WallNow(); + + out->set_tag(kREJ); + out->SetStringPiece(kSCFG, config.serialized); + out->SetStringPiece( + kSourceAddressTokenTag, + NewSourceAddressToken( + *config.source_address_token_boxer, + context.info().source_address_tokens, context.info().client_ip, + context.rand(), context.info().now, + &context.validate_chlo_result()->cached_network_params)); + out->SetValue(kSTTL, config.expiry_time.AbsoluteDifference(now).ToSeconds()); + if (replay_protection_) { + out->SetStringPiece(kServerNonceTag, + NewServerNonce(context.rand(), context.info().now)); + } + + // Send client the reject reason for debugging purposes. + QUICHE_DCHECK_LT(0u, reject_reasons.size()); + out->SetVector(kRREJ, reject_reasons); + + // The client may have requested a certificate chain. + if (!ClientDemandsX509Proof(context.client_hello())) { + QUIC_BUG(quic_bug_10630_4) + << "x509 certificates not supported in proof demand"; + return; + } + + absl::string_view client_cached_cert_hashes; + if (context.client_hello().GetStringPiece(kCCRT, + &client_cached_cert_hashes)) { + context.params()->client_cached_cert_hashes = + std::string(client_cached_cert_hashes); + } else { + context.params()->client_cached_cert_hashes.clear(); + } + + const std::string compressed = CompressChain( + context.compressed_certs_cache(), context.signed_config()->chain, + context.params()->client_cached_cert_hashes); + + QUICHE_DCHECK_GT(context.chlo_packet_size(), context.client_hello().size()); + // kREJOverheadBytes is a very rough estimate of how much of a REJ + // message is taken up by things other than the certificates. + // STK: 56 bytes + // SNO: 56 bytes + // SCFG + // SCID: 16 bytes + // PUBS: 38 bytes + const size_t kREJOverheadBytes = 166; + // max_unverified_size is the number of bytes that the certificate chain, + // signature, and (optionally) signed certificate timestamp can consume before + // we will demand a valid source-address token. + const size_t max_unverified_size = + chlo_multiplier_ * + (context.chlo_packet_size() - context.total_framing_overhead()) - + kREJOverheadBytes; + static_assert(kClientHelloMinimumSize * kMultiplier >= kREJOverheadBytes, + "overhead calculation may underflow"); + bool should_return_sct = + context.params()->sct_supported_by_client && enable_serving_sct_; + const std::string& cert_sct = context.signed_config()->proof.leaf_cert_scts; + const size_t sct_size = should_return_sct ? cert_sct.size() : 0; + const size_t total_size = context.signed_config()->proof.signature.size() + + compressed.size() + sct_size; + if (context.info().valid_source_address_token || + total_size < max_unverified_size) { + out->SetStringPiece(kCertificateTag, compressed); + out->SetStringPiece(kPROF, context.signed_config()->proof.signature); + if (should_return_sct) { + if (cert_sct.empty()) { + // Log SNI and subject name for the leaf cert if its SCT is empty. + // This is for debugging b/28342827. + const std::vector& certs = + context.signed_config()->chain->certs; + std::string ca_subject; + if (!certs.empty()) { + std::unique_ptr view = + CertificateView::ParseSingleCertificate(certs[0]); + if (view != nullptr) { + absl::optional maybe_ca_subject = + view->GetHumanReadableSubject(); + if (maybe_ca_subject.has_value()) { + ca_subject = *maybe_ca_subject; + } + } + } + QUIC_LOG_EVERY_N_SEC(WARNING, 60) + << "SCT is expected but it is empty. sni: '" + << context.params()->sni << "' cert subject: '" << ca_subject + << "'"; + } else { + out->SetStringPiece(kCertificateSCTTag, cert_sct); + } + } + } else { + QUIC_LOG_EVERY_N_SEC(WARNING, 60) + << "Sending inchoate REJ for hostname: " << context.info().sni + << " signature: " << context.signed_config()->proof.signature.size() + << " cert: " << compressed.size() << " sct:" << sct_size + << " total: " << total_size << " max: " << max_unverified_size; + } +} + +std::string QuicCryptoServerConfig::CompressChain( + QuicCompressedCertsCache* compressed_certs_cache, + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes) { + // Check whether the compressed certs is available in the cache. + QUICHE_DCHECK(compressed_certs_cache); + const std::string* cached_value = compressed_certs_cache->GetCompressedCert( + chain, client_cached_cert_hashes); + if (cached_value) { + return *cached_value; + } + std::string compressed = + CertCompressor::CompressChain(chain->certs, client_cached_cert_hashes); + // Insert the newly compressed cert to cache. + compressed_certs_cache->Insert(chain, client_cached_cert_hashes, compressed); + return compressed; +} + +quiche::QuicheReferenceCountedPointer +QuicCryptoServerConfig::ParseConfigProtobuf( + const QuicServerConfigProtobuf& protobuf, bool is_fallback) const { + std::unique_ptr msg = + CryptoFramer::ParseMessage(protobuf.config()); + + if (!msg) { + QUIC_LOG(WARNING) << "Failed to parse server config message"; + return nullptr; + } + + if (msg->tag() != kSCFG) { + QUIC_LOG(WARNING) << "Server config message has tag " << msg->tag() + << ", but expected " << kSCFG; + return nullptr; + } + + quiche::QuicheReferenceCountedPointer config(new Config); + config->serialized = protobuf.config(); + config->source_address_token_boxer = &source_address_token_boxer_; + + if (protobuf.has_primary_time()) { + config->primary_time = + QuicWallTime::FromUNIXSeconds(protobuf.primary_time()); + } + + config->priority = protobuf.priority(); + + absl::string_view scid; + if (!msg->GetStringPiece(kSCID, &scid)) { + QUIC_LOG(WARNING) << "Server config message is missing SCID"; + return nullptr; + } + if (scid.empty()) { + QUIC_LOG(WARNING) << "Server config message contains an empty SCID"; + return nullptr; + } + config->id = std::string(scid); + + if (msg->GetTaglist(kAEAD, &config->aead) != QUIC_NO_ERROR) { + QUIC_LOG(WARNING) << "Server config message is missing AEAD"; + return nullptr; + } + + QuicTagVector kexs_tags; + if (msg->GetTaglist(kKEXS, &kexs_tags) != QUIC_NO_ERROR) { + QUIC_LOG(WARNING) << "Server config message is missing KEXS"; + return nullptr; + } + + absl::string_view orbit; + if (!msg->GetStringPiece(kORBT, &orbit)) { + QUIC_LOG(WARNING) << "Server config message is missing ORBT"; + return nullptr; + } + + if (orbit.size() != kOrbitSize) { + QUIC_LOG(WARNING) << "Orbit value in server config is the wrong length." + " Got " + << orbit.size() << " want " << kOrbitSize; + return nullptr; + } + static_assert(sizeof(config->orbit) == kOrbitSize, "incorrect orbit size"); + memcpy(config->orbit, orbit.data(), sizeof(config->orbit)); + + QuicTagVector proof_demand_tags; + if (msg->GetTaglist(kPDMD, &proof_demand_tags) == QUIC_NO_ERROR) { + for (QuicTag tag : proof_demand_tags) { + if (tag == kCHID) { + config->channel_id_enabled = true; + break; + } + } + } + + for (size_t i = 0; i < kexs_tags.size(); i++) { + const QuicTag tag = kexs_tags[i]; + std::string private_key; + + config->kexs.push_back(tag); + + for (int j = 0; j < protobuf.key_size(); j++) { + const QuicServerConfigProtobuf::PrivateKey& key = protobuf.key(i); + if (key.tag() == tag) { + private_key = key.private_key(); + break; + } + } + + std::unique_ptr ka = + key_exchange_source_->Create(config->id, is_fallback, tag, private_key); + if (!ka) { + return nullptr; + } + for (const auto& key_exchange : config->key_exchanges) { + if (key_exchange->type() == tag) { + QUIC_LOG(WARNING) << "Duplicate key exchange in config: " << tag; + return nullptr; + } + } + + config->key_exchanges.push_back(std::move(ka)); + } + + uint64_t expiry_seconds; + if (msg->GetUint64(kEXPY, &expiry_seconds) != QUIC_NO_ERROR) { + QUIC_LOG(WARNING) << "Server config message is missing EXPY"; + return nullptr; + } + config->expiry_time = QuicWallTime::FromUNIXSeconds(expiry_seconds); + + return config; +} + +void QuicCryptoServerConfig::set_replay_protection(bool on) { + replay_protection_ = on; +} + +void QuicCryptoServerConfig::set_chlo_multiplier(size_t multiplier) { + chlo_multiplier_ = multiplier; +} + +void QuicCryptoServerConfig::set_source_address_token_future_secs( + uint32_t future_secs) { + source_address_token_future_secs_ = future_secs; +} + +void QuicCryptoServerConfig::set_source_address_token_lifetime_secs( + uint32_t lifetime_secs) { + source_address_token_lifetime_secs_ = lifetime_secs; +} + +void QuicCryptoServerConfig::set_enable_serving_sct(bool enable_serving_sct) { + enable_serving_sct_ = enable_serving_sct; +} + +void QuicCryptoServerConfig::AcquirePrimaryConfigChangedCb( + std::unique_ptr cb) { + QuicWriterMutexLock locked(&configs_lock_); + primary_config_changed_cb_ = std::move(cb); +} + +std::string QuicCryptoServerConfig::NewSourceAddressToken( + const CryptoSecretBoxer& crypto_secret_boxer, + const SourceAddressTokens& previous_tokens, const QuicIpAddress& ip, + QuicRandom* rand, QuicWallTime now, + const CachedNetworkParameters* cached_network_params) const { + SourceAddressTokens source_address_tokens; + SourceAddressToken* source_address_token = source_address_tokens.add_tokens(); + source_address_token->set_ip(ip.DualStacked().ToPackedString()); + source_address_token->set_timestamp(now.ToUNIXSeconds()); + if (cached_network_params != nullptr) { + *(source_address_token->mutable_cached_network_parameters()) = + *cached_network_params; + } + + // Append previous tokens. + for (const SourceAddressToken& token : previous_tokens.tokens()) { + if (source_address_tokens.tokens_size() > kMaxTokenAddresses) { + break; + } + + if (token.ip() == source_address_token->ip()) { + // It's for the same IP address. + continue; + } + + if (ValidateSourceAddressTokenTimestamp(token, now) != HANDSHAKE_OK) { + continue; + } + + *(source_address_tokens.add_tokens()) = token; + } + + return crypto_secret_boxer.Box(rand, + source_address_tokens.SerializeAsString()); +} + +int QuicCryptoServerConfig::NumberOfConfigs() const { + QuicReaderMutexLock locked(&configs_lock_); + return configs_.size(); +} + +ProofSource* QuicCryptoServerConfig::proof_source() const { + return proof_source_.get(); +} + +SSL_CTX* QuicCryptoServerConfig::ssl_ctx() const { return ssl_ctx_.get(); } + +HandshakeFailureReason QuicCryptoServerConfig::ParseSourceAddressToken( + const CryptoSecretBoxer& crypto_secret_boxer, absl::string_view token, + SourceAddressTokens& tokens) const { + std::string storage; + absl::string_view plaintext; + if (!crypto_secret_boxer.Unbox(token, &storage, &plaintext)) { + return SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE; + } + + if (!tokens.ParseFromArray(plaintext.data(), plaintext.size())) { + // Some clients might still be using the old source token format so + // attempt to parse that format. + // TODO(rch): remove this code once the new format is ubiquitous. + SourceAddressToken token; + if (!token.ParseFromArray(plaintext.data(), plaintext.size())) { + return SOURCE_ADDRESS_TOKEN_PARSE_FAILURE; + } + *tokens.add_tokens() = token; + } + + return HANDSHAKE_OK; +} + +HandshakeFailureReason QuicCryptoServerConfig::ValidateSourceAddressTokens( + const SourceAddressTokens& source_address_tokens, const QuicIpAddress& ip, + QuicWallTime now, CachedNetworkParameters* cached_network_params) const { + HandshakeFailureReason reason = + SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE; + for (const SourceAddressToken& token : source_address_tokens.tokens()) { + reason = ValidateSingleSourceAddressToken(token, ip, now); + if (reason == HANDSHAKE_OK) { + if (cached_network_params != nullptr && + token.has_cached_network_parameters()) { + *cached_network_params = token.cached_network_parameters(); + } + break; + } + } + return reason; +} + +HandshakeFailureReason QuicCryptoServerConfig::ValidateSingleSourceAddressToken( + const SourceAddressToken& source_address_token, const QuicIpAddress& ip, + QuicWallTime now) const { + if (source_address_token.ip() != ip.DualStacked().ToPackedString()) { + // It's for a different IP address. + return SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE; + } + + return ValidateSourceAddressTokenTimestamp(source_address_token, now); +} + +HandshakeFailureReason +QuicCryptoServerConfig::ValidateSourceAddressTokenTimestamp( + const SourceAddressToken& source_address_token, QuicWallTime now) const { + const QuicWallTime timestamp( + QuicWallTime::FromUNIXSeconds(source_address_token.timestamp())); + const QuicTime::Delta delta(now.AbsoluteDifference(timestamp)); + + if (now.IsBefore(timestamp) && + delta.ToSeconds() > source_address_token_future_secs_) { + return SOURCE_ADDRESS_TOKEN_CLOCK_SKEW_FAILURE; + } + + if (now.IsAfter(timestamp) && + delta.ToSeconds() > source_address_token_lifetime_secs_) { + return SOURCE_ADDRESS_TOKEN_EXPIRED_FAILURE; + } + + return HANDSHAKE_OK; +} + +// kServerNoncePlaintextSize is the number of bytes in an unencrypted server +// nonce. +static const size_t kServerNoncePlaintextSize = + 4 /* timestamp */ + 20 /* random bytes */; + +std::string QuicCryptoServerConfig::NewServerNonce(QuicRandom* rand, + QuicWallTime now) const { + const uint32_t timestamp = static_cast(now.ToUNIXSeconds()); + + uint8_t server_nonce[kServerNoncePlaintextSize]; + static_assert(sizeof(server_nonce) > sizeof(timestamp), "nonce too small"); + server_nonce[0] = static_cast(timestamp >> 24); + server_nonce[1] = static_cast(timestamp >> 16); + server_nonce[2] = static_cast(timestamp >> 8); + server_nonce[3] = static_cast(timestamp); + rand->RandBytes(&server_nonce[sizeof(timestamp)], + sizeof(server_nonce) - sizeof(timestamp)); + + return server_nonce_boxer_.Box( + rand, absl::string_view(reinterpret_cast(server_nonce), + sizeof(server_nonce))); +} + +bool QuicCryptoServerConfig::ValidateExpectedLeafCertificate( + const CryptoHandshakeMessage& client_hello, + const std::vector& certs) const { + if (certs.empty()) { + return false; + } + + uint64_t hash_from_client; + if (client_hello.GetUint64(kXLCT, &hash_from_client) != QUIC_NO_ERROR) { + return false; + } + return CryptoUtils::ComputeLeafCertHash(certs.at(0)) == hash_from_client; +} + +bool QuicCryptoServerConfig::IsNextConfigReady(QuicWallTime now) const { + return !next_config_promotion_time_.IsZero() && + !next_config_promotion_time_.IsAfter(now); +} + +QuicCryptoServerConfig::Config::Config() + : channel_id_enabled(false), + is_primary(false), + primary_time(QuicWallTime::Zero()), + expiry_time(QuicWallTime::Zero()), + priority(0), + source_address_token_boxer(nullptr) {} + +QuicCryptoServerConfig::Config::~Config() {} + +QuicSignedServerConfig::QuicSignedServerConfig() {} +QuicSignedServerConfig::~QuicSignedServerConfig() {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_crypto_server_config.h b/quiche/quic/core/crypto/quic_crypto_server_config.h new file mode 100644 index 000000000000..77be1f470ff6 --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_server_config.h @@ -0,0 +1,948 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_SERVER_CONFIG_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_SERVER_CONFIG_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_secret_boxer.h" +#include "quiche/quic/core/crypto/key_exchange.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h" +#include "quiche/quic/core/crypto/quic_crypto_proof.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/proto/source_address_token_proto.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" + +namespace quic { + +class CryptoHandshakeMessage; +class ProofSource; +class QuicClock; +class QuicServerConfigProtobuf; +struct QuicSignedServerConfig; + +// ClientHelloInfo contains information about a client hello message that is +// only kept for as long as it's being processed. +struct QUIC_EXPORT_PRIVATE ClientHelloInfo { + ClientHelloInfo(const QuicIpAddress& in_client_ip, QuicWallTime in_now); + ClientHelloInfo(const ClientHelloInfo& other); + ~ClientHelloInfo(); + + // Inputs to EvaluateClientHello. + const QuicIpAddress client_ip; + const QuicWallTime now; + + // Outputs from EvaluateClientHello. + bool valid_source_address_token; + absl::string_view sni; + absl::string_view client_nonce; + absl::string_view server_nonce; + absl::string_view user_agent_id; + SourceAddressTokens source_address_tokens; + + // Errors from EvaluateClientHello. + std::vector reject_reasons; + static_assert(sizeof(QuicTag) == sizeof(uint32_t), "header out of sync"); +}; + +namespace test { +class QuicCryptoServerConfigPeer; +} // namespace test + +// Hook that allows application code to subscribe to primary config changes. +class QUIC_EXPORT_PRIVATE PrimaryConfigChangedCallback { + public: + PrimaryConfigChangedCallback(); + PrimaryConfigChangedCallback(const PrimaryConfigChangedCallback&) = delete; + PrimaryConfigChangedCallback& operator=(const PrimaryConfigChangedCallback&) = + delete; + virtual ~PrimaryConfigChangedCallback(); + virtual void Run(const std::string& scid) = 0; +}; + +// Callback used to accept the result of the |client_hello| validation step. +class QUIC_EXPORT_PRIVATE ValidateClientHelloResultCallback { + public: + // Opaque token that holds information about the client_hello and + // its validity. Can be interpreted by calling ProcessClientHello. + struct QUIC_EXPORT_PRIVATE Result : public quiche::QuicheReferenceCounted { + Result(const CryptoHandshakeMessage& in_client_hello, + QuicIpAddress in_client_ip, QuicWallTime in_now); + + CryptoHandshakeMessage client_hello; + ClientHelloInfo info; + QuicErrorCode error_code; + std::string error_details; + + // Populated if the CHLO STK contained a CachedNetworkParameters proto. + CachedNetworkParameters cached_network_params; + + protected: + ~Result() override; + }; + + ValidateClientHelloResultCallback(); + ValidateClientHelloResultCallback(const ValidateClientHelloResultCallback&) = + delete; + ValidateClientHelloResultCallback& operator=( + const ValidateClientHelloResultCallback&) = delete; + virtual ~ValidateClientHelloResultCallback(); + virtual void Run(quiche::QuicheReferenceCountedPointer result, + std::unique_ptr details) = 0; +}; + +// Callback used to accept the result of the ProcessClientHello method. +class QUIC_EXPORT_PRIVATE ProcessClientHelloResultCallback { + public: + ProcessClientHelloResultCallback(); + ProcessClientHelloResultCallback(const ProcessClientHelloResultCallback&) = + delete; + ProcessClientHelloResultCallback& operator=( + const ProcessClientHelloResultCallback&) = delete; + virtual ~ProcessClientHelloResultCallback(); + virtual void Run(QuicErrorCode error, const std::string& error_details, + std::unique_ptr message, + std::unique_ptr diversification_nonce, + std::unique_ptr details) = 0; +}; + +// Callback used to receive the results of a call to +// BuildServerConfigUpdateMessage. +class QUIC_EXPORT_PRIVATE BuildServerConfigUpdateMessageResultCallback { + public: + BuildServerConfigUpdateMessageResultCallback() = default; + virtual ~BuildServerConfigUpdateMessageResultCallback() {} + BuildServerConfigUpdateMessageResultCallback( + const BuildServerConfigUpdateMessageResultCallback&) = delete; + BuildServerConfigUpdateMessageResultCallback& operator=( + const BuildServerConfigUpdateMessageResultCallback&) = delete; + virtual void Run(bool ok, const CryptoHandshakeMessage& message) = 0; +}; + +// Object that is interested in built rejections (which include REJ, SREJ and +// cheap SREJ). +class QUIC_EXPORT_PRIVATE RejectionObserver { + public: + RejectionObserver() = default; + virtual ~RejectionObserver() {} + RejectionObserver(const RejectionObserver&) = delete; + RejectionObserver& operator=(const RejectionObserver&) = delete; + // Called after a rejection is built. + virtual void OnRejectionBuilt(const std::vector& reasons, + CryptoHandshakeMessage* out) const = 0; +}; + +// Factory for creating KeyExchange objects. +class QUIC_EXPORT_PRIVATE KeyExchangeSource { + public: + virtual ~KeyExchangeSource() = default; + + // Returns the default KeyExchangeSource. + static std::unique_ptr Default(); + + // Create a new KeyExchange using the curve specified by |type| using the + // specified private key. |private_key| may be empty for key-exchange + // mechanisms which do not hold the private key in-process. If |is_fallback| + // is set, |private_key| is required to be set, and a local key-exchange + // object should be returned. + virtual std::unique_ptr Create( + std::string server_config_id, bool is_fallback, QuicTag type, + absl::string_view private_key) = 0; +}; + +// QuicCryptoServerConfig contains the crypto configuration of a QUIC server. +// Unlike a client, a QUIC server can have multiple configurations active in +// order to support clients resuming with a previous configuration. +// TODO(agl): when adding configurations at runtime is added, this object will +// need to consider locking. +class QUIC_EXPORT_PRIVATE QuicCryptoServerConfig { + public: + // ConfigOptions contains options for generating server configs. + struct QUIC_EXPORT_PRIVATE ConfigOptions { + ConfigOptions(); + ConfigOptions(const ConfigOptions& other); + ~ConfigOptions(); + + // expiry_time is the time, in UNIX seconds, when the server config will + // expire. If unset, it defaults to the current time plus six months. + QuicWallTime expiry_time; + // channel_id_enabled controls whether the server config will indicate + // support for ChannelIDs. + bool channel_id_enabled; + // id contains the server config id for the resulting config. If empty, a + // random id is generated. + std::string id; + // orbit contains the kOrbitSize bytes of the orbit value for the server + // config. If |orbit| is empty then a random orbit is generated. + std::string orbit; + // p256 determines whether a P-256 public key will be included in the + // server config. Note that this breaks deterministic server-config + // generation since P-256 key generation doesn't use the QuicRandom given + // to GenerateConfig(). + bool p256; + }; + + // |source_address_token_secret|: secret key material used for encrypting and + // decrypting source address tokens. It can be of any length as it is fed + // into a KDF before use. In tests, use TESTING. + // |server_nonce_entropy|: an entropy source used to generate the orbit and + // key for server nonces, which are always local to a given instance of a + // server. Not owned. + // |proof_source|: provides certificate chains and signatures. + // |key_exchange_source|: provides key-exchange functionality. + QuicCryptoServerConfig( + absl::string_view source_address_token_secret, + QuicRandom* server_nonce_entropy, + std::unique_ptr proof_source, + std::unique_ptr key_exchange_source); + QuicCryptoServerConfig(const QuicCryptoServerConfig&) = delete; + QuicCryptoServerConfig& operator=(const QuicCryptoServerConfig&) = delete; + ~QuicCryptoServerConfig(); + + // TESTING is a magic parameter for passing to the constructor in tests. + static const char TESTING[]; + + // Generates a QuicServerConfigProtobuf protobuf suitable for + // AddConfig and SetConfigs. + static QuicServerConfigProtobuf GenerateConfig(QuicRandom* rand, + const QuicClock* clock, + const ConfigOptions& options); + + // AddConfig adds a QuicServerConfigProtobuf to the available configurations. + // It returns the SCFG message from the config if successful. |now| is used in + // conjunction with |protobuf->primary_time()| to determine whether the + // config should be made primary. + std::unique_ptr AddConfig( + const QuicServerConfigProtobuf& protobuf, QuicWallTime now); + + // AddDefaultConfig calls GenerateConfig to create a config and then calls + // AddConfig to add it. See the comment for |GenerateConfig| for details of + // the arguments. + std::unique_ptr AddDefaultConfig( + QuicRandom* rand, const QuicClock* clock, const ConfigOptions& options); + + // SetConfigs takes a vector of config protobufs and the current time. + // Configs are assumed to be uniquely identified by their server config ID. + // Previously unknown configs are added and possibly made the primary config + // depending on their |primary_time| and the value of |now|. Configs that are + // known, but are missing from the protobufs are deleted, unless they are + // currently the primary config. SetConfigs returns false if any errors were + // encountered and no changes to the QuicCryptoServerConfig will occur. + bool SetConfigs(const std::vector& protobufs, + const QuicServerConfigProtobuf* fallback_protobuf, + QuicWallTime now); + + // SetSourceAddressTokenKeys sets the keys to be tried, in order, when + // decrypting a source address token. Note that these keys are used *without* + // passing them through a KDF, in contradistinction to the + // |source_address_token_secret| argument to the constructor. + void SetSourceAddressTokenKeys(const std::vector& keys); + + // Get the server config ids for all known configs. + std::vector GetConfigIds() const; + + // Checks |client_hello| for gross errors and determines whether it can be + // shown to be fresh (i.e. not a replay). The result of the validation step + // must be interpreted by calling QuicCryptoServerConfig::ProcessClientHello + // from the done_cb. + // + // ValidateClientHello may invoke the done_cb before unrolling the + // stack if it is able to assess the validity of the client_nonce + // without asynchronous operations. + // + // client_hello: the incoming client hello message. + // client_ip: the IP address of the client, which is used to generate and + // validate source-address tokens. + // server_address: the IP address and port of the server. The IP address and + // port may be used for certificate selection. + // version: protocol version used for this connection. + // clock: used to validate client nonces and ephemeral keys. + // signed_config: in/out parameter to which will be written the crypto proof + // used in reply to a proof demand. The pointed-to-object must live until + // the callback is invoked. + // done_cb: single-use callback that accepts an opaque + // ValidatedClientHelloMsg token that holds information about + // the client hello. The callback will always be called exactly + // once, either under the current call stack, or after the + // completion of an asynchronous operation. + void ValidateClientHello( + const CryptoHandshakeMessage& client_hello, + const QuicSocketAddress& client_address, + const QuicSocketAddress& server_address, QuicTransportVersion version, + const QuicClock* clock, + quiche::QuicheReferenceCountedPointer + signed_config, + std::unique_ptr done_cb) const; + + // ProcessClientHello processes |client_hello| and decides whether to accept + // or reject the connection. If the connection is to be accepted, |done_cb| is + // invoked with the contents of the ServerHello and QUIC_NO_ERROR. Otherwise + // |done_cb| is called with a REJ or SREJ message and QUIC_NO_ERROR. + // + // validate_chlo_result: Output from the asynchronous call to + // ValidateClientHello. Contains the client hello message and + // information about it. + // reject_only: Only generate rejections, not server hello messages. + // connection_id: the ConnectionId for the connection, which is used in key + // derivation. + // server_ip: the IP address of the server. The IP address may be used for + // certificate selection. + // client_address: the IP address and port of the client. The IP address is + // used to generate and validate source-address tokens. + // version: version of the QUIC protocol in use for this connection + // supported_versions: versions of the QUIC protocol that this server + // supports. + // clock: used to validate client nonces and ephemeral keys. + // rand: an entropy source + // compressed_certs_cache: the cache that caches a set of most recently used + // certs. Owned by QuicDispatcher. + // params: the state of the handshake. This may be updated with a server + // nonce when we send a rejection. + // signed_config: output structure containing the crypto proof used in reply + // to a proof demand. + // total_framing_overhead: the total per-packet overhead for a stream frame + // chlo_packet_size: the size, in bytes, of the CHLO packet + // done_cb: the callback invoked on completion + void ProcessClientHello( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + validate_chlo_result, + bool reject_only, QuicConnectionId connection_id, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, ParsedQuicVersion version, + const ParsedQuicVersionVector& supported_versions, const QuicClock* clock, + QuicRandom* rand, QuicCompressedCertsCache* compressed_certs_cache, + quiche::QuicheReferenceCountedPointer + params, + quiche::QuicheReferenceCountedPointer + signed_config, + QuicByteCount total_framing_overhead, QuicByteCount chlo_packet_size, + std::shared_ptr done_cb) const; + + // BuildServerConfigUpdateMessage invokes |cb| with a SCUP message containing + // the current primary config, an up to date source-address token, and cert + // chain and proof in the case of secure QUIC. Passes true to |cb| if the + // message was generated successfully, and false otherwise. This method + // assumes ownership of |cb|. + // + // |cached_network_params| is optional, and can be nullptr. + void BuildServerConfigUpdateMessage( + QuicTransportVersion version, absl::string_view chlo_hash, + const SourceAddressTokens& previous_source_address_tokens, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const QuicClock* clock, + QuicRandom* rand, QuicCompressedCertsCache* compressed_certs_cache, + const QuicCryptoNegotiatedParameters& params, + const CachedNetworkParameters* cached_network_params, + std::unique_ptr cb) const; + + // set_replay_protection controls whether replay protection is enabled. If + // replay protection is disabled then no strike registers are needed and + // frontends can share an orbit value without a shared strike-register. + // However, an attacker can duplicate a handshake and cause a client's + // request to be processed twice. + void set_replay_protection(bool on); + + // set_chlo_multiplier specifies the multiple of the CHLO message size + // that a REJ message must stay under when the client doesn't present a + // valid source-address token. + void set_chlo_multiplier(size_t multiplier); + + // When sender is allowed to not pad client hello (not standards compliant), + // we need to disable the client hello check. + void set_validate_chlo_size(bool new_value) { + validate_chlo_size_ = new_value; + } + + // Returns whether the sender is allowed to not pad the client hello. + bool validate_chlo_size() const { return validate_chlo_size_; } + + // When QUIC is tunneled through some other mechanism, source token validation + // may be disabled. Do not disable it if you are not providing other + // protection. (|true| protects against UDP amplification attack.). + void set_validate_source_address_token(bool new_value) { + validate_source_address_token_ = new_value; + } + + // set_source_address_token_future_secs sets the number of seconds into the + // future that source-address tokens will be accepted from. Since + // source-address tokens are authenticated, this should only happen if + // another, valid server has clock-skew. + void set_source_address_token_future_secs(uint32_t future_secs); + + // set_source_address_token_lifetime_secs sets the number of seconds that a + // source-address token will be valid for. + void set_source_address_token_lifetime_secs(uint32_t lifetime_secs); + + // set_enable_serving_sct enables or disables serving signed cert timestamp + // (RFC6962) in server hello. + void set_enable_serving_sct(bool enable_serving_sct); + + // Set and take ownership of the callback to invoke on primary config changes. + void AcquirePrimaryConfigChangedCb( + std::unique_ptr cb); + + // Returns the number of configs this object owns. + int NumberOfConfigs() const; + + // NewSourceAddressToken returns a fresh source address token for the given + // IP address. |previous_tokens| is the received tokens, and can be empty. + // |cached_network_params| is optional, and can be nullptr. + std::string NewSourceAddressToken( + const CryptoSecretBoxer& crypto_secret_boxer, + const SourceAddressTokens& previous_tokens, const QuicIpAddress& ip, + QuicRandom* rand, QuicWallTime now, + const CachedNetworkParameters* cached_network_params) const; + + // ParseSourceAddressToken parses the source address tokens contained in + // the encrypted |token|, and populates |tokens| with the parsed tokens. + // Returns HANDSHAKE_OK if |token| could be parsed, or the reason for the + // failure. + HandshakeFailureReason ParseSourceAddressToken( + const CryptoSecretBoxer& crypto_secret_boxer, absl::string_view token, + SourceAddressTokens& tokens) const; + + // ValidateSourceAddressTokens returns HANDSHAKE_OK if the source address + // tokens in |tokens| contain a valid and timely token for the IP address + // |ip| given that the current time is |now|. Otherwise it returns the + // reason for failure. |cached_network_params| is populated if the valid + // token contains a CachedNetworkParameters proto. + HandshakeFailureReason ValidateSourceAddressTokens( + const SourceAddressTokens& tokens, const QuicIpAddress& ip, + QuicWallTime now, CachedNetworkParameters* cached_network_params) const; + + // Callers retain the ownership of |rejection_observer| which must outlive the + // config. + void set_rejection_observer(RejectionObserver* rejection_observer) { + rejection_observer_ = rejection_observer; + } + + ProofSource* proof_source() const; + + SSL_CTX* ssl_ctx() const; + + // Pre-shared key used during the handshake. + const std::string& pre_shared_key() const { return pre_shared_key_; } + void set_pre_shared_key(absl::string_view psk) { + pre_shared_key_ = std::string(psk); + } + + bool pad_rej() const { return pad_rej_; } + void set_pad_rej(bool new_value) { pad_rej_ = new_value; } + + bool pad_shlo() const { return pad_shlo_; } + void set_pad_shlo(bool new_value) { pad_shlo_ = new_value; } + + const CryptoSecretBoxer& source_address_token_boxer() const { + return source_address_token_boxer_; + } + + private: + friend class test::QuicCryptoServerConfigPeer; + friend struct QuicSignedServerConfig; + + // Config represents a server config: a collection of preferences and + // Diffie-Hellman public values. + class QUIC_EXPORT_PRIVATE Config : public QuicCryptoConfig, + public quiche::QuicheReferenceCounted { + public: + Config(); + Config(const Config&) = delete; + Config& operator=(const Config&) = delete; + + // TODO(rtenneti): since this is a class, we should probably do + // getters/setters here. + // |serialized| contains the bytes of this server config, suitable for + // sending on the wire. + std::string serialized; + // id contains the SCID of this server config. + std::string id; + // orbit contains the orbit value for this config: an opaque identifier + // used to identify clusters of server frontends. + unsigned char orbit[kOrbitSize]; + + // key_exchanges contains key exchange objects. The values correspond, + // one-to-one, with the tags in |kexs| from the parent class. + std::vector> key_exchanges; + + // channel_id_enabled is true if the config in |serialized| specifies that + // ChannelIDs are supported. + bool channel_id_enabled; + + // is_primary is true if this config is the one that we'll give out to + // clients as the current one. + bool is_primary; + + // primary_time contains the timestamp when this config should become the + // primary config. A value of QuicWallTime::Zero() means that this config + // will not be promoted at a specific time. + QuicWallTime primary_time; + + // expiry_time contains the timestamp when this config expires. + QuicWallTime expiry_time; + + // Secondary sort key for use when selecting primary configs and + // there are multiple configs with the same primary time. + // Smaller numbers mean higher priority. + uint64_t priority; + + // source_address_token_boxer_ is used to protect the + // source-address tokens that are given to clients. + // Points to either source_address_token_boxer_storage or the + // default boxer provided by QuicCryptoServerConfig. + const CryptoSecretBoxer* source_address_token_boxer; + + // Holds the override source_address_token_boxer instance if the + // Config is not using the default source address token boxer + // instance provided by QuicCryptoServerConfig. + std::unique_ptr source_address_token_boxer_storage; + + private: + ~Config() override; + }; + + using ConfigMap = + std::map>; + + // Get a ref to the config with a given server config id. + quiche::QuicheReferenceCountedPointer GetConfigWithScid( + absl::string_view requested_scid) const + QUIC_SHARED_LOCKS_REQUIRED(configs_lock_); + + // A snapshot of the configs associated with an in-progress handshake. + struct QUIC_EXPORT_PRIVATE Configs { + quiche::QuicheReferenceCountedPointer requested; + quiche::QuicheReferenceCountedPointer primary; + quiche::QuicheReferenceCountedPointer fallback; + }; + + // Get a snapshot of the current configs associated with a handshake. If this + // method was called earlier in this handshake |old_primary_config| should be + // set to the primary config returned from that invocation, otherwise nullptr. + // + // Returns true if any configs are loaded. If false is returned, |configs| is + // not modified. + bool GetCurrentConfigs( + const QuicWallTime& now, absl::string_view requested_scid, + quiche::QuicheReferenceCountedPointer old_primary_config, + Configs* configs) const; + + // ConfigPrimaryTimeLessThan returns true if a->primary_time < + // b->primary_time. + static bool ConfigPrimaryTimeLessThan( + const quiche::QuicheReferenceCountedPointer& a, + const quiche::QuicheReferenceCountedPointer& b); + + // SelectNewPrimaryConfig reevaluates the primary config based on the + // "primary_time" deadlines contained in each. + void SelectNewPrimaryConfig(QuicWallTime now) const + QUIC_EXCLUSIVE_LOCKS_REQUIRED(configs_lock_); + + // EvaluateClientHello checks |client_hello_state->client_hello| for gross + // errors and determines whether it is fresh (i.e. not a replay). The results + // are written to |client_hello_state->info|. + void EvaluateClientHello( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, QuicTransportVersion version, + const Configs& configs, + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + client_hello_state, + std::unique_ptr done_cb) const; + + // Convenience class which carries the arguments passed to + // |ProcessClientHellp| along. + class QUIC_EXPORT_PRIVATE ProcessClientHelloContext { + public: + ProcessClientHelloContext( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + validate_chlo_result, + bool reject_only, QuicConnectionId connection_id, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, ParsedQuicVersion version, + const ParsedQuicVersionVector& supported_versions, + const QuicClock* clock, QuicRandom* rand, + QuicCompressedCertsCache* compressed_certs_cache, + quiche::QuicheReferenceCountedPointer + params, + quiche::QuicheReferenceCountedPointer + signed_config, + QuicByteCount total_framing_overhead, QuicByteCount chlo_packet_size, + std::shared_ptr done_cb) + : validate_chlo_result_(validate_chlo_result), + reject_only_(reject_only), + connection_id_(connection_id), + server_address_(server_address), + client_address_(client_address), + version_(version), + supported_versions_(supported_versions), + clock_(clock), + rand_(rand), + compressed_certs_cache_(compressed_certs_cache), + params_(params), + signed_config_(signed_config), + total_framing_overhead_(total_framing_overhead), + chlo_packet_size_(chlo_packet_size), + done_cb_(std::move(done_cb)) {} + + ~ProcessClientHelloContext(); + + // Invoke |done_cb_| with an error status + void Fail(QuicErrorCode error, const std::string& error_details); + + // Invoke |done_cb_| with a success status + void Succeed(std::unique_ptr message, + std::unique_ptr diversification_nonce, + std::unique_ptr proof_source_details); + + // Member accessors + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + validate_chlo_result() const { + return validate_chlo_result_; + } + bool reject_only() const { return reject_only_; } + QuicConnectionId connection_id() const { return connection_id_; } + QuicSocketAddress server_address() const { return server_address_; } + QuicSocketAddress client_address() const { return client_address_; } + ParsedQuicVersion version() const { return version_; } + ParsedQuicVersionVector supported_versions() const { + return supported_versions_; + } + const QuicClock* clock() const { return clock_; } + QuicRandom* rand() const { return rand_; } // NOLINT + QuicCompressedCertsCache* compressed_certs_cache() const { + return compressed_certs_cache_; + } + quiche::QuicheReferenceCountedPointer + params() const { + return params_; + } + quiche::QuicheReferenceCountedPointer + signed_config() const { + return signed_config_; + } + QuicByteCount total_framing_overhead() const { + return total_framing_overhead_; + } + QuicByteCount chlo_packet_size() const { return chlo_packet_size_; } + + // Derived value accessors + const CryptoHandshakeMessage& client_hello() const { + return validate_chlo_result()->client_hello; + } + const ClientHelloInfo& info() const { return validate_chlo_result()->info; } + QuicTransportVersion transport_version() const { + return version().transport_version; + } + + private: + const quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + validate_chlo_result_; + const bool reject_only_; + const QuicConnectionId connection_id_; + const QuicSocketAddress server_address_; + const QuicSocketAddress client_address_; + const ParsedQuicVersion version_; + const ParsedQuicVersionVector supported_versions_; + const QuicClock* const clock_; + QuicRandom* const rand_; + QuicCompressedCertsCache* const compressed_certs_cache_; + const quiche::QuicheReferenceCountedPointer + params_; + const quiche::QuicheReferenceCountedPointer + signed_config_; + const QuicByteCount total_framing_overhead_; + const QuicByteCount chlo_packet_size_; + std::shared_ptr done_cb_; + }; + + // Callback class for bridging between ProcessClientHello and + // ProcessClientHelloAfterGetProof. + class ProcessClientHelloCallback; + friend class ProcessClientHelloCallback; + + // Portion of ProcessClientHello which executes after GetProof. + void ProcessClientHelloAfterGetProof( + bool found_error, + std::unique_ptr proof_source_details, + std::unique_ptr context, + const Configs& configs) const; + + // Callback class for bridging between ProcessClientHelloAfterGetProof and + // ProcessClientHelloAfterCalculateSharedKeys. + class ProcessClientHelloAfterGetProofCallback; + friend class ProcessClientHelloAfterGetProofCallback; + + // Portion of ProcessClientHello which executes after CalculateSharedKeys. + void ProcessClientHelloAfterCalculateSharedKeys( + bool found_error, + std::unique_ptr proof_source_details, + QuicTag key_exchange_type, std::unique_ptr out, + absl::string_view public_value, + std::unique_ptr context, + const Configs& configs) const; + + // Send a REJ which contains a different ServerConfig than the one the client + // originally used. This is necessary in cases where we discover in the + // middle of the handshake that the private key for the ServerConfig the + // client used is not accessible. + void SendRejectWithFallbackConfig( + std::unique_ptr context, + quiche::QuicheReferenceCountedPointer fallback_config) const; + + // Callback class for bridging between SendRejectWithFallbackConfig and + // SendRejectWithFallbackConfigAfterGetProof. + class SendRejectWithFallbackConfigCallback; + friend class SendRejectWithFallbackConfigCallback; + + // Portion of ProcessClientHello which executes after GetProof in the case + // where we have received a CHLO but need to reject it due to the ServerConfig + // private keys being inaccessible. + void SendRejectWithFallbackConfigAfterGetProof( + bool found_error, + std::unique_ptr proof_source_details, + std::unique_ptr context, + quiche::QuicheReferenceCountedPointer fallback_config) const; + + // BuildRejectionAndRecordStats calls |BuildRejection| below and also informs + // the RejectionObserver. + void BuildRejectionAndRecordStats(const ProcessClientHelloContext& context, + const Config& config, + const std::vector& reject_reasons, + CryptoHandshakeMessage* out) const; + + // BuildRejection sets |out| to be a REJ message in reply to |client_hello|. + void BuildRejection(const ProcessClientHelloContext& context, + const Config& config, + const std::vector& reject_reasons, + CryptoHandshakeMessage* out) const; + + // CompressChain compresses the certificates in |chain->certs| and returns a + // compressed representation. |client_cached_cert_hashes| contains + // 64-bit, FNV-1a hashes of certificates that the peer already possesses. + static std::string CompressChain( + QuicCompressedCertsCache* compressed_certs_cache, + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes); + + // ParseConfigProtobuf parses the given config protobuf and returns a + // quiche::QuicheReferenceCountedPointer if successful. The caller + // adopts the reference to the Config. On error, ParseConfigProtobuf returns + // nullptr. + quiche::QuicheReferenceCountedPointer ParseConfigProtobuf( + const QuicServerConfigProtobuf& protobuf, bool is_fallback) const; + + // ValidateSingleSourceAddressToken returns HANDSHAKE_OK if the source + // address token in |token| is a timely token for the IP address |ip| + // given that the current time is |now|. Otherwise it returns the reason + // for failure. + HandshakeFailureReason ValidateSingleSourceAddressToken( + const SourceAddressToken& token, const QuicIpAddress& ip, + QuicWallTime now) const; + + // Returns HANDSHAKE_OK if the source address token in |token| is a timely + // token given that the current time is |now|. Otherwise it returns the + // reason for failure. + HandshakeFailureReason ValidateSourceAddressTokenTimestamp( + const SourceAddressToken& token, QuicWallTime now) const; + + // NewServerNonce generates and encrypts a random nonce. + std::string NewServerNonce(QuicRandom* rand, QuicWallTime now) const; + + // ValidateExpectedLeafCertificate checks the |client_hello| to see if it has + // an XLCT tag, and if so, verifies that its value matches the hash of the + // server's leaf certificate. |certs| is used to compare against the XLCT + // value. This method returns true if the XLCT tag is not present, or if the + // XLCT tag is present and valid. It returns false otherwise. + bool ValidateExpectedLeafCertificate( + const CryptoHandshakeMessage& client_hello, + const std::vector& certs) const; + + // Callback to receive the results of ProofSource::GetProof. Note: this + // callback has no cancellation support, since the lifetime of the ProofSource + // is controlled by this object via unique ownership. If that ownership + // stricture changes, this decision may need to be revisited. + class BuildServerConfigUpdateMessageProofSourceCallback + : public ProofSource::Callback { + public: + BuildServerConfigUpdateMessageProofSourceCallback( + const BuildServerConfigUpdateMessageProofSourceCallback&) = delete; + ~BuildServerConfigUpdateMessageProofSourceCallback() override; + void operator=(const BuildServerConfigUpdateMessageProofSourceCallback&) = + delete; + BuildServerConfigUpdateMessageProofSourceCallback( + const QuicCryptoServerConfig* config, + QuicCompressedCertsCache* compressed_certs_cache, + const QuicCryptoNegotiatedParameters& params, + CryptoHandshakeMessage message, + std::unique_ptr cb); + + void Run( + bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const QuicCryptoProof& proof, + std::unique_ptr details) override; + + private: + const QuicCryptoServerConfig* config_; + QuicCompressedCertsCache* compressed_certs_cache_; + const std::string client_cached_cert_hashes_; + const bool sct_supported_by_client_; + const std::string sni_; + CryptoHandshakeMessage message_; + std::unique_ptr cb_; + }; + + // Invoked by BuildServerConfigUpdateMessageProofSourceCallback::Run once + // the proof has been acquired. Finishes building the server config update + // message and invokes |cb|. + void FinishBuildServerConfigUpdateMessage( + QuicCompressedCertsCache* compressed_certs_cache, + const std::string& client_cached_cert_hashes, + bool sct_supported_by_client, const std::string& sni, bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& signature, const std::string& leaf_cert_sct, + std::unique_ptr details, + CryptoHandshakeMessage message, + std::unique_ptr cb) const; + + // Returns true if the next config promotion should happen now. + bool IsNextConfigReady(QuicWallTime now) const + QUIC_SHARED_LOCKS_REQUIRED(configs_lock_); + + // replay_protection_ controls whether the server enforces that handshakes + // aren't replays. + bool replay_protection_; + + // The multiple of the CHLO message size that a REJ message must stay under + // when the client doesn't present a valid source-address token. This is + // used to protect QUIC from amplification attacks. + size_t chlo_multiplier_; + + // configs_ satisfies the following invariants: + // 1) configs_.empty() <-> primary_config_ == nullptr + // 2) primary_config_ != nullptr -> primary_config_->is_primary + // 3) ∀ c∈configs_, c->is_primary <-> c == primary_config_ + mutable QuicMutex configs_lock_; + + // configs_ contains all active server configs. It's expected that there are + // about half-a-dozen configs active at any one time. + ConfigMap configs_ QUIC_GUARDED_BY(configs_lock_); + + // primary_config_ points to a Config (which is also in |configs_|) which is + // the primary config - i.e. the one that we'll give out to new clients. + mutable quiche::QuicheReferenceCountedPointer primary_config_ + QUIC_GUARDED_BY(configs_lock_); + + // fallback_config_ points to a Config (which is also in |configs_|) which is + // the fallback config, which will be used if the other configs are unuseable + // for some reason. + // + // TODO(b/112548056): This is currently always nullptr. + quiche::QuicheReferenceCountedPointer fallback_config_ + QUIC_GUARDED_BY(configs_lock_); + + // next_config_promotion_time_ contains the nearest, future time when an + // active config will be promoted to primary. + mutable QuicWallTime next_config_promotion_time_ + QUIC_GUARDED_BY(configs_lock_); + + // Callback to invoke when the primary config changes. + std::unique_ptr primary_config_changed_cb_ + QUIC_GUARDED_BY(configs_lock_); + + // Used to protect the source-address tokens that are given to clients. + CryptoSecretBoxer source_address_token_boxer_; + + // server_nonce_boxer_ is used to encrypt and validate suggested server + // nonces. + CryptoSecretBoxer server_nonce_boxer_; + + // server_nonce_orbit_ contains the random, per-server orbit values that this + // server will use to generate server nonces (the moral equivalent of a SYN + // cookies). + uint8_t server_nonce_orbit_[8]; + + // proof_source_ contains an object that can provide certificate chains and + // signatures. + std::unique_ptr proof_source_; + + // key_exchange_source_ contains an object that can provide key exchange + // objects. + std::unique_ptr key_exchange_source_; + + // ssl_ctx_ contains the server configuration for doing TLS handshakes. + bssl::UniquePtr ssl_ctx_; + + // These fields store configuration values. See the comments for their + // respective setter functions. + uint32_t source_address_token_future_secs_; + uint32_t source_address_token_lifetime_secs_; + + // Enable serving SCT or not. + bool enable_serving_sct_; + + // Does not own this observer. + RejectionObserver* rejection_observer_; + + // If non-empty, the server will operate in the pre-shared key mode by + // incorporating |pre_shared_key_| into the key schedule. + std::string pre_shared_key_; + + // Whether REJ message should be padded to max packet size. + bool pad_rej_; + + // Whether SHLO message should be padded to max packet size. + bool pad_shlo_; + + // If client is allowed to send a small client hello (by disabling padding), + // server MUST not check for the client hello size. + // DO NOT disable this unless you have some other way of validating client. + // (e.g. in realtime scenarios, where quic is tunneled through ICE, ICE will + // do its own peer validation using STUN pings with ufrag/upass). + bool validate_chlo_size_; + + // When source address is validated by some other means (e.g. when using ICE), + // source address token validation may be disabled. + bool validate_source_address_token_; +}; + +struct QUIC_EXPORT_PRIVATE QuicSignedServerConfig + : public quiche::QuicheReferenceCounted { + QuicSignedServerConfig(); + + QuicCryptoProof proof; + quiche::QuicheReferenceCountedPointer chain; + // The server config that is used for this proof (and the rest of the + // request). + quiche::QuicheReferenceCountedPointer config; + std::string primary_scid; + + protected: + ~QuicSignedServerConfig() override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_CRYPTO_SERVER_CONFIG_H_ diff --git a/quiche/quic/core/crypto/quic_crypto_server_config_test.cc b/quiche/quic/core/crypto/quic_crypto_server_config_test.cc new file mode 100644 index 000000000000..ed7ffdb981af --- /dev/null +++ b/quiche/quic/core/crypto/quic_crypto_server_config_test.cc @@ -0,0 +1,494 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" + +#include + +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/cert_compressor.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_secret_boxer.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_crypto_server_config_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +using ::testing::Not; + +// NOTE: This matcher depends on the wire format of serialzied protocol buffers, +// which may change in the future. +// Switch to ::testing::EqualsProto once it is available in Chromium. +MATCHER_P(SerializedProtoEquals, message, "") { + std::string expected_serialized, actual_serialized; + message.SerializeToString(&expected_serialized); + arg.SerializeToString(&actual_serialized); + return expected_serialized == actual_serialized; +} + +class QuicCryptoServerConfigTest : public QuicTest {}; + +TEST_F(QuicCryptoServerConfigTest, ServerConfig) { + QuicRandom* rand = QuicRandom::GetInstance(); + QuicCryptoServerConfig server(QuicCryptoServerConfig::TESTING, rand, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + MockClock clock; + + std::unique_ptr message(server.AddDefaultConfig( + rand, &clock, QuicCryptoServerConfig::ConfigOptions())); + + // The default configuration should have AES-GCM and at least one ChaCha20 + // cipher. + QuicTagVector aead; + ASSERT_THAT(message->GetTaglist(kAEAD, &aead), IsQuicNoError()); + EXPECT_THAT(aead, ::testing::Contains(kAESG)); + EXPECT_LE(1u, aead.size()); +} + +TEST_F(QuicCryptoServerConfigTest, CompressCerts) { + QuicCompressedCertsCache compressed_certs_cache( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize); + + QuicRandom* rand = QuicRandom::GetInstance(); + QuicCryptoServerConfig server(QuicCryptoServerConfig::TESTING, rand, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + QuicCryptoServerConfigPeer peer(&server); + + std::vector certs = {"testcert"}; + quiche::QuicheReferenceCountedPointer chain( + new ProofSource::Chain(certs)); + + std::string compressed = QuicCryptoServerConfigPeer::CompressChain( + &compressed_certs_cache, chain, ""); + + EXPECT_EQ(compressed_certs_cache.Size(), 1u); +} + +TEST_F(QuicCryptoServerConfigTest, CompressSameCertsTwice) { + QuicCompressedCertsCache compressed_certs_cache( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize); + + QuicRandom* rand = QuicRandom::GetInstance(); + QuicCryptoServerConfig server(QuicCryptoServerConfig::TESTING, rand, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + QuicCryptoServerConfigPeer peer(&server); + + // Compress the certs for the first time. + std::vector certs = {"testcert"}; + quiche::QuicheReferenceCountedPointer chain( + new ProofSource::Chain(certs)); + std::string cached_certs = ""; + + std::string compressed = QuicCryptoServerConfigPeer::CompressChain( + &compressed_certs_cache, chain, cached_certs); + EXPECT_EQ(compressed_certs_cache.Size(), 1u); + + // Compress the same certs, should use cache if available. + std::string compressed2 = QuicCryptoServerConfigPeer::CompressChain( + &compressed_certs_cache, chain, cached_certs); + EXPECT_EQ(compressed, compressed2); + EXPECT_EQ(compressed_certs_cache.Size(), 1u); +} + +TEST_F(QuicCryptoServerConfigTest, CompressDifferentCerts) { + // This test compresses a set of similar but not identical certs. Cache if + // used should return cache miss and add all the compressed certs. + QuicCompressedCertsCache compressed_certs_cache( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize); + + QuicRandom* rand = QuicRandom::GetInstance(); + QuicCryptoServerConfig server(QuicCryptoServerConfig::TESTING, rand, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()); + QuicCryptoServerConfigPeer peer(&server); + + std::vector certs = {"testcert"}; + quiche::QuicheReferenceCountedPointer chain( + new ProofSource::Chain(certs)); + std::string cached_certs = ""; + + std::string compressed = QuicCryptoServerConfigPeer::CompressChain( + &compressed_certs_cache, chain, cached_certs); + EXPECT_EQ(compressed_certs_cache.Size(), 1u); + + // Compress a similar certs which only differs in the chain. + quiche::QuicheReferenceCountedPointer chain2( + new ProofSource::Chain(certs)); + + std::string compressed2 = QuicCryptoServerConfigPeer::CompressChain( + &compressed_certs_cache, chain2, cached_certs); + EXPECT_EQ(compressed_certs_cache.Size(), 2u); +} + +class SourceAddressTokenTest : public QuicTest { + public: + SourceAddressTokenTest() + : ip4_(QuicIpAddress::Loopback4()), + ip4_dual_(ip4_.DualStacked()), + ip6_(QuicIpAddress::Loopback6()), + original_time_(QuicWallTime::Zero()), + rand_(QuicRandom::GetInstance()), + server_(QuicCryptoServerConfig::TESTING, rand_, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()), + peer_(&server_) { + // Advance the clock to some non-zero time. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1000000)); + original_time_ = clock_.WallNow(); + + primary_config_ = server_.AddDefaultConfig( + rand_, &clock_, QuicCryptoServerConfig::ConfigOptions()); + } + + std::string NewSourceAddressToken(std::string config_id, + const QuicIpAddress& ip) { + return NewSourceAddressToken(config_id, ip, nullptr); + } + + std::string NewSourceAddressToken( + std::string config_id, const QuicIpAddress& ip, + const SourceAddressTokens& previous_tokens) { + return peer_.NewSourceAddressToken(config_id, previous_tokens, ip, rand_, + clock_.WallNow(), nullptr); + } + + std::string NewSourceAddressToken( + std::string config_id, const QuicIpAddress& ip, + CachedNetworkParameters* cached_network_params) { + SourceAddressTokens previous_tokens; + return peer_.NewSourceAddressToken(config_id, previous_tokens, ip, rand_, + clock_.WallNow(), cached_network_params); + } + + HandshakeFailureReason ValidateSourceAddressTokens(std::string config_id, + absl::string_view srct, + const QuicIpAddress& ip) { + return ValidateSourceAddressTokens(config_id, srct, ip, nullptr); + } + + HandshakeFailureReason ValidateSourceAddressTokens( + std::string config_id, absl::string_view srct, const QuicIpAddress& ip, + CachedNetworkParameters* cached_network_params) { + return peer_.ValidateSourceAddressTokens( + config_id, srct, ip, clock_.WallNow(), cached_network_params); + } + + const std::string kPrimary = ""; + const std::string kOverride = "Config with custom source address token key"; + + QuicIpAddress ip4_; + QuicIpAddress ip4_dual_; + QuicIpAddress ip6_; + + MockClock clock_; + QuicWallTime original_time_; + QuicRandom* rand_ = QuicRandom::GetInstance(); + QuicCryptoServerConfig server_; + QuicCryptoServerConfigPeer peer_; + // Stores the primary config. + std::unique_ptr primary_config_; + std::unique_ptr override_config_protobuf_; +}; + +// Test basic behavior of source address tokens including being specific +// to a single IP address and server config. +TEST_F(SourceAddressTokenTest, SourceAddressToken) { + // Primary config generates configs that validate successfully. + const std::string token4 = NewSourceAddressToken(kPrimary, ip4_); + const std::string token4d = NewSourceAddressToken(kPrimary, ip4_dual_); + const std::string token6 = NewSourceAddressToken(kPrimary, ip6_); + EXPECT_EQ(HANDSHAKE_OK, ValidateSourceAddressTokens(kPrimary, token4, ip4_)); + ASSERT_EQ(HANDSHAKE_OK, + ValidateSourceAddressTokens(kPrimary, token4, ip4_dual_)); + ASSERT_EQ(SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE, + ValidateSourceAddressTokens(kPrimary, token4, ip6_)); + ASSERT_EQ(HANDSHAKE_OK, ValidateSourceAddressTokens(kPrimary, token4d, ip4_)); + ASSERT_EQ(HANDSHAKE_OK, + ValidateSourceAddressTokens(kPrimary, token4d, ip4_dual_)); + ASSERT_EQ(SOURCE_ADDRESS_TOKEN_DIFFERENT_IP_ADDRESS_FAILURE, + ValidateSourceAddressTokens(kPrimary, token4d, ip6_)); + ASSERT_EQ(HANDSHAKE_OK, ValidateSourceAddressTokens(kPrimary, token6, ip6_)); +} + +TEST_F(SourceAddressTokenTest, SourceAddressTokenExpiration) { + const std::string token = NewSourceAddressToken(kPrimary, ip4_); + + // Validation fails if the token is from the future. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(-3600 * 2)); + ASSERT_EQ(SOURCE_ADDRESS_TOKEN_CLOCK_SKEW_FAILURE, + ValidateSourceAddressTokens(kPrimary, token, ip4_)); + + // Validation fails after tokens expire. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(86400 * 7)); + ASSERT_EQ(SOURCE_ADDRESS_TOKEN_EXPIRED_FAILURE, + ValidateSourceAddressTokens(kPrimary, token, ip4_)); +} + +TEST_F(SourceAddressTokenTest, SourceAddressTokenWithNetworkParams) { + // Make sure that if the source address token contains CachedNetworkParameters + // that this gets written to ValidateSourceAddressToken output argument. + CachedNetworkParameters cached_network_params_input; + cached_network_params_input.set_bandwidth_estimate_bytes_per_second(1234); + const std::string token4_with_cached_network_params = + NewSourceAddressToken(kPrimary, ip4_, &cached_network_params_input); + + CachedNetworkParameters cached_network_params_output; + EXPECT_THAT(cached_network_params_output, + Not(SerializedProtoEquals(cached_network_params_input))); + ValidateSourceAddressTokens(kPrimary, token4_with_cached_network_params, ip4_, + &cached_network_params_output); + EXPECT_THAT(cached_network_params_output, + SerializedProtoEquals(cached_network_params_input)); +} + +// Test the ability for a source address token to be valid for multiple +// addresses. +TEST_F(SourceAddressTokenTest, SourceAddressTokenMultipleAddresses) { + QuicWallTime now = clock_.WallNow(); + + // Now create a token which is usable for both addresses. + SourceAddressToken previous_token; + previous_token.set_ip(ip6_.DualStacked().ToPackedString()); + previous_token.set_timestamp(now.ToUNIXSeconds()); + SourceAddressTokens previous_tokens; + (*previous_tokens.add_tokens()) = previous_token; + const std::string token4or6 = + NewSourceAddressToken(kPrimary, ip4_, previous_tokens); + + EXPECT_EQ(HANDSHAKE_OK, + ValidateSourceAddressTokens(kPrimary, token4or6, ip4_)); + ASSERT_EQ(HANDSHAKE_OK, + ValidateSourceAddressTokens(kPrimary, token4or6, ip6_)); +} + +class CryptoServerConfigsTest : public QuicTest { + public: + CryptoServerConfigsTest() + : rand_(QuicRandom::GetInstance()), + config_(QuicCryptoServerConfig::TESTING, rand_, + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()), + test_peer_(&config_) {} + + void SetUp() override { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1000)); + } + + // SetConfigs constructs suitable config protobufs and calls SetConfigs on + // |config_|. + // Each struct in the input vector contains 3 elements. + // The first is the server config ID of a Config. The second is + // the |primary_time| of that Config, given in epoch seconds. (Although note + // that, in these tests, time is set to 1000 seconds since the epoch.). + // The third is the priority. + // + // For example: + // SetConfigs(std::vector()); // calls + // |config_.SetConfigs| with no protobufs. + // + // // Calls |config_.SetConfigs| with two protobufs: one for a Config with + // // a |primary_time| of 900 and priority 1, and another with + // // a |primary_time| of 1000 and priority 2. + + // CheckConfigs( + // {{"id1", 900, 1}, + // {"id2", 1000, 2}}); + // + // If the server config id starts with "INVALID" then the generated protobuf + // will be invalid. + struct ServerConfigIDWithTimeAndPriority { + ServerConfigID server_config_id; + int primary_time; + int priority; + }; + void SetConfigs(std::vector configs) { + const char kOrbit[] = "12345678"; + + bool has_invalid = false; + + std::vector protobufs; + for (const auto& config : configs) { + const ServerConfigID& server_config_id = config.server_config_id; + const int primary_time = config.primary_time; + const int priority = config.priority; + + QuicCryptoServerConfig::ConfigOptions options; + options.id = server_config_id; + options.orbit = kOrbit; + QuicServerConfigProtobuf protobuf = + QuicCryptoServerConfig::GenerateConfig(rand_, &clock_, options); + protobuf.set_primary_time(primary_time); + protobuf.set_priority(priority); + if (absl::StartsWith(std::string(server_config_id), "INVALID")) { + protobuf.clear_key(); + has_invalid = true; + } + protobufs.push_back(std::move(protobuf)); + } + + ASSERT_EQ(!has_invalid && !configs.empty(), + config_.SetConfigs(protobufs, /* fallback_protobuf = */ nullptr, + clock_.WallNow())); + } + + protected: + QuicRandom* const rand_; + MockClock clock_; + QuicCryptoServerConfig config_; + QuicCryptoServerConfigPeer test_peer_; +}; + +TEST_F(CryptoServerConfigsTest, NoConfigs) { + test_peer_.CheckConfigs(std::vector>()); +} + +TEST_F(CryptoServerConfigsTest, MakePrimaryFirst) { + // Make sure that "b" is primary even though "a" comes first. + SetConfigs({{"a", 1100, 1}, {"b", 900, 1}}); + test_peer_.CheckConfigs({{"a", false}, {"b", true}}); +} + +TEST_F(CryptoServerConfigsTest, MakePrimarySecond) { + // Make sure that a remains primary after b is added. + SetConfigs({{"a", 900, 1}, {"b", 1100, 1}}); + test_peer_.CheckConfigs({{"a", true}, {"b", false}}); +} + +TEST_F(CryptoServerConfigsTest, Delete) { + // Ensure that configs get deleted when removed. + SetConfigs({{"a", 800, 1}, {"b", 900, 1}, {"c", 1100, 1}}); + test_peer_.CheckConfigs({{"a", false}, {"b", true}, {"c", false}}); + SetConfigs({{"b", 900, 1}, {"c", 1100, 1}}); + test_peer_.CheckConfigs({{"b", true}, {"c", false}}); +} + +TEST_F(CryptoServerConfigsTest, DeletePrimary) { + // Ensure that deleting the primary config works. + SetConfigs({{"a", 800, 1}, {"b", 900, 1}, {"c", 1100, 1}}); + test_peer_.CheckConfigs({{"a", false}, {"b", true}, {"c", false}}); + SetConfigs({{"a", 800, 1}, {"c", 1100, 1}}); + test_peer_.CheckConfigs({{"a", true}, {"c", false}}); +} + +TEST_F(CryptoServerConfigsTest, FailIfDeletingAllConfigs) { + // Ensure that configs get deleted when removed. + SetConfigs({{"a", 800, 1}, {"b", 900, 1}}); + test_peer_.CheckConfigs({{"a", false}, {"b", true}}); + SetConfigs(std::vector()); + // Config change is rejected, still using old configs. + test_peer_.CheckConfigs({{"a", false}, {"b", true}}); +} + +TEST_F(CryptoServerConfigsTest, ChangePrimaryTime) { + // Check that updates to primary time get picked up. + SetConfigs({{"a", 400, 1}, {"b", 800, 1}, {"c", 1200, 1}}); + test_peer_.SelectNewPrimaryConfig(500); + test_peer_.CheckConfigs({{"a", true}, {"b", false}, {"c", false}}); + SetConfigs({{"a", 1200, 1}, {"b", 800, 1}, {"c", 400, 1}}); + test_peer_.SelectNewPrimaryConfig(500); + test_peer_.CheckConfigs({{"a", false}, {"b", false}, {"c", true}}); +} + +TEST_F(CryptoServerConfigsTest, AllConfigsInThePast) { + // Check that the most recent config is selected. + SetConfigs({{"a", 400, 1}, {"b", 800, 1}, {"c", 1200, 1}}); + test_peer_.SelectNewPrimaryConfig(1500); + test_peer_.CheckConfigs({{"a", false}, {"b", false}, {"c", true}}); +} + +TEST_F(CryptoServerConfigsTest, AllConfigsInTheFuture) { + // Check that the first config is selected. + SetConfigs({{"a", 400, 1}, {"b", 800, 1}, {"c", 1200, 1}}); + test_peer_.SelectNewPrimaryConfig(100); + test_peer_.CheckConfigs({{"a", true}, {"b", false}, {"c", false}}); +} + +TEST_F(CryptoServerConfigsTest, SortByPriority) { + // Check that priority is used to decide on a primary config when + // configs have the same primary time. + SetConfigs({{"a", 900, 1}, {"b", 900, 2}, {"c", 900, 3}}); + test_peer_.CheckConfigs({{"a", true}, {"b", false}, {"c", false}}); + test_peer_.SelectNewPrimaryConfig(800); + test_peer_.CheckConfigs({{"a", true}, {"b", false}, {"c", false}}); + test_peer_.SelectNewPrimaryConfig(1000); + test_peer_.CheckConfigs({{"a", true}, {"b", false}, {"c", false}}); + + // Change priorities and expect sort order to change. + SetConfigs({{"a", 900, 2}, {"b", 900, 1}, {"c", 900, 0}}); + test_peer_.CheckConfigs({{"a", false}, {"b", false}, {"c", true}}); + test_peer_.SelectNewPrimaryConfig(800); + test_peer_.CheckConfigs({{"a", false}, {"b", false}, {"c", true}}); + test_peer_.SelectNewPrimaryConfig(1000); + test_peer_.CheckConfigs({{"a", false}, {"b", false}, {"c", true}}); +} + +TEST_F(CryptoServerConfigsTest, AdvancePrimary) { + // Check that a new primary config is enabled at the right time. + SetConfigs({{"a", 900, 1}, {"b", 1100, 1}}); + test_peer_.SelectNewPrimaryConfig(1000); + test_peer_.CheckConfigs({{"a", true}, {"b", false}}); + test_peer_.SelectNewPrimaryConfig(1101); + test_peer_.CheckConfigs({{"a", false}, {"b", true}}); +} + +class ValidateCallback : public ValidateClientHelloResultCallback { + public: + void Run(quiche::QuicheReferenceCountedPointer /*result*/, + std::unique_ptr /*details*/) override {} +}; + +TEST_F(CryptoServerConfigsTest, AdvancePrimaryViaValidate) { + // Check that a new primary config is enabled at the right time. + SetConfigs({{"a", 900, 1}, {"b", 1100, 1}}); + test_peer_.SelectNewPrimaryConfig(1000); + test_peer_.CheckConfigs({{"a", true}, {"b", false}}); + CryptoHandshakeMessage client_hello; + QuicSocketAddress client_address; + QuicSocketAddress server_address; + QuicTransportVersion transport_version = QUIC_VERSION_UNSUPPORTED; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + transport_version = version.transport_version; + break; + } + } + ASSERT_NE(transport_version, QUIC_VERSION_UNSUPPORTED); + MockClock clock; + quiche::QuicheReferenceCountedPointer signed_config( + new QuicSignedServerConfig); + std::unique_ptr done_cb( + new ValidateCallback); + clock.AdvanceTime(QuicTime::Delta::FromSeconds(1100)); + config_.ValidateClientHello(client_hello, client_address, server_address, + transport_version, &clock, signed_config, + std::move(done_cb)); + test_peer_.CheckConfigs({{"a", false}, {"b", true}}); +} + +TEST_F(CryptoServerConfigsTest, InvalidConfigs) { + // Ensure that invalid configs don't change anything. + SetConfigs({{"a", 800, 1}, {"b", 900, 1}, {"c", 1100, 1}}); + test_peer_.CheckConfigs({{"a", false}, {"b", true}, {"c", false}}); + SetConfigs({{"a", 800, 1}, {"c", 1100, 1}, {"INVALID1", 1000, 1}}); + test_peer_.CheckConfigs({{"a", false}, {"b", true}, {"c", false}}); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_decrypter.cc b/quiche/quic/core/crypto/quic_decrypter.cc new file mode 100644 index 000000000000..da0e809acb49 --- /dev/null +++ b/quiche/quic/core/crypto/quic_decrypter.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_decrypter.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/tls1.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_decrypter.h" +#include "quiche/quic/core/crypto/aes_128_gcm_decrypter.h" +#include "quiche/quic/core/crypto/aes_256_gcm_decrypter.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_decrypter.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_decrypter.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/quic_hkdf.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// static +std::unique_ptr QuicDecrypter::Create( + const ParsedQuicVersion& version, QuicTag algorithm) { + switch (algorithm) { + case kAESG: + if (version.UsesInitialObfuscators()) { + return std::make_unique(); + } else { + return std::make_unique(); + } + case kCC20: + if (version.UsesInitialObfuscators()) { + return std::make_unique(); + } else { + return std::make_unique(); + } + default: + QUIC_LOG(FATAL) << "Unsupported algorithm: " << algorithm; + return nullptr; + } +} + +// static +std::unique_ptr QuicDecrypter::CreateFromCipherSuite( + uint32_t cipher_suite) { + switch (cipher_suite) { + case TLS1_CK_AES_128_GCM_SHA256: + return std::make_unique(); + case TLS1_CK_AES_256_GCM_SHA384: + return std::make_unique(); + case TLS1_CK_CHACHA20_POLY1305_SHA256: + return std::make_unique(); + default: + QUIC_BUG(quic_bug_10660_1) << "TLS cipher suite is unknown to QUIC"; + return nullptr; + } +} + +// static +void QuicDecrypter::DiversifyPreliminaryKey(absl::string_view preliminary_key, + absl::string_view nonce_prefix, + const DiversificationNonce& nonce, + size_t key_size, + size_t nonce_prefix_size, + std::string* out_key, + std::string* out_nonce_prefix) { + QuicHKDF hkdf((std::string(preliminary_key)) + (std::string(nonce_prefix)), + absl::string_view(nonce.data(), nonce.size()), + "QUIC key diversification", 0, key_size, 0, nonce_prefix_size, + 0); + *out_key = std::string(hkdf.server_write_key()); + *out_nonce_prefix = std::string(hkdf.server_write_iv()); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_decrypter.h b/quiche/quic/core/crypto/quic_decrypter.h new file mode 100644 index 000000000000..8e3c754a08d3 --- /dev/null +++ b/quiche/quic/core/crypto/quic_decrypter.h @@ -0,0 +1,93 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_DECRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_DECRYPTER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_crypter.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicDecrypter : public QuicCrypter { + public: + virtual ~QuicDecrypter() {} + + static std::unique_ptr Create(const ParsedQuicVersion& version, + QuicTag algorithm); + + // Creates an IETF QuicDecrypter based on |cipher_suite| which must be an id + // returned by SSL_CIPHER_get_id. The caller is responsible for taking + // ownership of the new QuicDecrypter. + static std::unique_ptr CreateFromCipherSuite( + uint32_t cipher_suite); + + // Sets the encryption key. Returns true on success, false on failure. + // |DecryptPacket| may not be called until |SetDiversificationNonce| is + // called and the preliminary keying material will be combined with that + // nonce in order to create the actual key and nonce-prefix. + // + // If this function is called, neither |SetKey| nor |SetNoncePrefix| may be + // called. + virtual bool SetPreliminaryKey(absl::string_view key) = 0; + + // SetDiversificationNonce uses |nonce| to derive final keys based on the + // input keying material given by calling |SetPreliminaryKey|. + // + // Calling this function is a no-op if |SetPreliminaryKey| hasn't been + // called. + virtual bool SetDiversificationNonce(const DiversificationNonce& nonce) = 0; + + // Populates |output| with the decrypted |ciphertext| and populates + // |output_length| with the length. Returns 0 if there is an error. + // |output| size is specified by |max_output_length| and must be + // at least as large as the ciphertext. |packet_number| is + // appended to the |nonce_prefix| value provided in SetNoncePrefix() + // to form the nonce. + // TODO(wtc): add a way for DecryptPacket to report decryption failure due + // to non-authentic inputs, as opposed to other reasons for failure. + virtual bool DecryptPacket(uint64_t packet_number, + absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, + size_t max_output_length) = 0; + + // Reads a sample of ciphertext from |sample_reader| and uses the header + // protection key to generate a mask to use for header protection. If + // successful, this function returns this mask, which is at least 5 bytes + // long. Callers can detect failure by checking if the output string is empty. + virtual std::string GenerateHeaderProtectionMask( + QuicDataReader* sample_reader) = 0; + + // The ID of the cipher. Return 0x03000000 ORed with the 'cryptographic suite + // selector'. + virtual uint32_t cipher_id() const = 0; + + // Returns the maximum number of packets that can safely fail decryption with + // this decrypter. + virtual QuicPacketCount GetIntegrityLimit() const = 0; + + // For use by unit tests only. + virtual absl::string_view GetKey() const = 0; + virtual absl::string_view GetNoncePrefix() const = 0; + + static void DiversifyPreliminaryKey(absl::string_view preliminary_key, + absl::string_view nonce_prefix, + const DiversificationNonce& nonce, + size_t key_size, size_t nonce_prefix_size, + std::string* out_key, + std::string* out_nonce_prefix); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_DECRYPTER_H_ diff --git a/quiche/quic/core/crypto/quic_encrypter.cc b/quiche/quic/core/crypto/quic_encrypter.cc new file mode 100644 index 000000000000..151b8d058c2a --- /dev/null +++ b/quiche/quic/core/crypto/quic_encrypter.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_encrypter.h" + +#include + +#include "openssl/tls1.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" +#include "quiche/quic/core/crypto/aes_128_gcm_encrypter.h" +#include "quiche/quic/core/crypto/aes_256_gcm_encrypter.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_encrypter.h" +#include "quiche/quic/core/crypto/chacha20_poly1305_tls_encrypter.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// static +std::unique_ptr QuicEncrypter::Create( + const ParsedQuicVersion& version, QuicTag algorithm) { + switch (algorithm) { + case kAESG: + if (version.UsesInitialObfuscators()) { + return std::make_unique(); + } else { + return std::make_unique(); + } + case kCC20: + if (version.UsesInitialObfuscators()) { + return std::make_unique(); + } else { + return std::make_unique(); + } + default: + QUIC_LOG(FATAL) << "Unsupported algorithm: " << algorithm; + return nullptr; + } +} + +// static +std::unique_ptr QuicEncrypter::CreateFromCipherSuite( + uint32_t cipher_suite) { + switch (cipher_suite) { + case TLS1_CK_AES_128_GCM_SHA256: + return std::make_unique(); + case TLS1_CK_AES_256_GCM_SHA384: + return std::make_unique(); + case TLS1_CK_CHACHA20_POLY1305_SHA256: + return std::make_unique(); + default: + QUIC_BUG(quic_bug_10711_1) << "TLS cipher suite is unknown to QUIC"; + return nullptr; + } +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_encrypter.h b/quiche/quic/core/crypto/quic_encrypter.h new file mode 100644 index 000000000000..7b8c4fa1a454 --- /dev/null +++ b/quiche/quic/core/crypto/quic_encrypter.h @@ -0,0 +1,70 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_ENCRYPTER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_ENCRYPTER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_crypter.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicEncrypter : public QuicCrypter { + public: + virtual ~QuicEncrypter() {} + + static std::unique_ptr Create(const ParsedQuicVersion& version, + QuicTag algorithm); + + // Creates an IETF QuicEncrypter based on |cipher_suite| which must be an id + // returned by SSL_CIPHER_get_id. The caller is responsible for taking + // ownership of the new QuicEncrypter. + static std::unique_ptr CreateFromCipherSuite( + uint32_t cipher_suite); + + // Writes encrypted |plaintext| and a MAC over |plaintext| and + // |associated_data| into output. Sets |output_length| to the number of + // bytes written. Returns true on success or false if there was an error. + // |packet_number| is appended to the |nonce_prefix| value provided in + // SetNoncePrefix() to form the nonce. |output| must not overlap with + // |associated_data|. If |output| overlaps with |plaintext| then + // |plaintext| must be <= |output|. + virtual bool EncryptPacket(uint64_t packet_number, + absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, + size_t max_output_length) = 0; + + // Takes a |sample| of ciphertext and uses the header protection key to + // generate a mask to use for header protection, and returns that mask. On + // success, the mask will be at least 5 bytes long; on failure the string will + // be empty. + virtual std::string GenerateHeaderProtectionMask( + absl::string_view sample) = 0; + + // Returns the maximum length of plaintext that can be encrypted + // to ciphertext no larger than |ciphertext_size|. + virtual size_t GetMaxPlaintextSize(size_t ciphertext_size) const = 0; + + // Returns the length of the ciphertext that would be generated by encrypting + // to plaintext of size |plaintext_size|. + virtual size_t GetCiphertextSize(size_t plaintext_size) const = 0; + + // Returns the maximum number of packets that can be safely encrypted with + // this encrypter. + virtual QuicPacketCount GetConfidentialityLimit() const = 0; + + // For use by unit tests only. + virtual absl::string_view GetKey() const = 0; + virtual absl::string_view GetNoncePrefix() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_ENCRYPTER_H_ diff --git a/quiche/quic/core/crypto/quic_hkdf.cc b/quiche/quic/core/crypto/quic_hkdf.cc new file mode 100644 index 000000000000..14dab76a32ce --- /dev/null +++ b/quiche/quic/core/crypto/quic_hkdf.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_hkdf.h" + +#include + +#include "absl/strings/string_view.h" +#include "openssl/digest.h" +#include "openssl/hkdf.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +const size_t kSHA256HashLength = 32; +const size_t kMaxKeyMaterialSize = kSHA256HashLength * 256; + +QuicHKDF::QuicHKDF(absl::string_view secret, absl::string_view salt, + absl::string_view info, size_t key_bytes_to_generate, + size_t iv_bytes_to_generate, + size_t subkey_secret_bytes_to_generate) + : QuicHKDF(secret, salt, info, key_bytes_to_generate, key_bytes_to_generate, + iv_bytes_to_generate, iv_bytes_to_generate, + subkey_secret_bytes_to_generate) {} + +QuicHKDF::QuicHKDF(absl::string_view secret, absl::string_view salt, + absl::string_view info, size_t client_key_bytes_to_generate, + size_t server_key_bytes_to_generate, + size_t client_iv_bytes_to_generate, + size_t server_iv_bytes_to_generate, + size_t subkey_secret_bytes_to_generate) { + const size_t material_length = + 2 * client_key_bytes_to_generate + client_iv_bytes_to_generate + + 2 * server_key_bytes_to_generate + server_iv_bytes_to_generate + + subkey_secret_bytes_to_generate; + QUICHE_DCHECK_LT(material_length, kMaxKeyMaterialSize); + + output_.resize(material_length); + // On Windows, when the size of output_ is zero, dereference of 0'th element + // results in a crash. C++11 solves this problem by adding a data() getter + // method to std::vector. + if (output_.empty()) { + return; + } + + ::HKDF(&output_[0], output_.size(), ::EVP_sha256(), + reinterpret_cast(secret.data()), secret.size(), + reinterpret_cast(salt.data()), salt.size(), + reinterpret_cast(info.data()), info.size()); + + size_t j = 0; + if (client_key_bytes_to_generate) { + client_write_key_ = absl::string_view(reinterpret_cast(&output_[j]), + client_key_bytes_to_generate); + j += client_key_bytes_to_generate; + } + + if (server_key_bytes_to_generate) { + server_write_key_ = absl::string_view(reinterpret_cast(&output_[j]), + server_key_bytes_to_generate); + j += server_key_bytes_to_generate; + } + + if (client_iv_bytes_to_generate) { + client_write_iv_ = absl::string_view(reinterpret_cast(&output_[j]), + client_iv_bytes_to_generate); + j += client_iv_bytes_to_generate; + } + + if (server_iv_bytes_to_generate) { + server_write_iv_ = absl::string_view(reinterpret_cast(&output_[j]), + server_iv_bytes_to_generate); + j += server_iv_bytes_to_generate; + } + + if (subkey_secret_bytes_to_generate) { + subkey_secret_ = absl::string_view(reinterpret_cast(&output_[j]), + subkey_secret_bytes_to_generate); + j += subkey_secret_bytes_to_generate; + } + // Repeat client and server key bytes for header protection keys. + if (client_key_bytes_to_generate) { + client_hp_key_ = absl::string_view(reinterpret_cast(&output_[j]), + client_key_bytes_to_generate); + j += client_key_bytes_to_generate; + } + + if (server_key_bytes_to_generate) { + server_hp_key_ = absl::string_view(reinterpret_cast(&output_[j]), + server_key_bytes_to_generate); + j += server_key_bytes_to_generate; + } +} + +QuicHKDF::~QuicHKDF() {} + +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_hkdf.h b/quiche/quic/core/crypto/quic_hkdf.h new file mode 100644 index 000000000000..3e30f1cb02fc --- /dev/null +++ b/quiche/quic/core/crypto/quic_hkdf.h @@ -0,0 +1,72 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_HKDF_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_HKDF_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QuicHKDF implements the key derivation function specified in RFC 5869 +// (using SHA-256) and outputs key material, as needed by QUIC. +// See https://tools.ietf.org/html/rfc5869 for details. +class QUIC_EXPORT_PRIVATE QuicHKDF { + public: + // |secret|: the input shared secret (or, from RFC 5869, the IKM). + // |salt|: an (optional) public salt / non-secret random value. While + // optional, callers are strongly recommended to provide a salt. There is no + // added security value in making this larger than the SHA-256 block size of + // 64 bytes. + // |info|: an (optional) label to distinguish different uses of HKDF. It is + // optional context and application specific information (can be a zero-length + // string). + // |key_bytes_to_generate|: the number of bytes of key material to generate + // for both client and server. + // |iv_bytes_to_generate|: the number of bytes of IV to generate for both + // client and server. + // |subkey_secret_bytes_to_generate|: the number of bytes of subkey secret to + // generate, shared between client and server. + QuicHKDF(absl::string_view secret, absl::string_view salt, + absl::string_view info, size_t key_bytes_to_generate, + size_t iv_bytes_to_generate, size_t subkey_secret_bytes_to_generate); + + // An alternative constructor that allows the client and server key/IV + // lengths to be different. + QuicHKDF(absl::string_view secret, absl::string_view salt, + absl::string_view info, size_t client_key_bytes_to_generate, + size_t server_key_bytes_to_generate, + size_t client_iv_bytes_to_generate, + size_t server_iv_bytes_to_generate, + size_t subkey_secret_bytes_to_generate); + + ~QuicHKDF(); + + absl::string_view client_write_key() const { return client_write_key_; } + absl::string_view client_write_iv() const { return client_write_iv_; } + absl::string_view server_write_key() const { return server_write_key_; } + absl::string_view server_write_iv() const { return server_write_iv_; } + absl::string_view subkey_secret() const { return subkey_secret_; } + absl::string_view client_hp_key() const { return client_hp_key_; } + absl::string_view server_hp_key() const { return server_hp_key_; } + + private: + std::vector output_; + + absl::string_view client_write_key_; + absl::string_view server_write_key_; + absl::string_view client_write_iv_; + absl::string_view server_write_iv_; + absl::string_view subkey_secret_; + absl::string_view client_hp_key_; + absl::string_view server_hp_key_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_HKDF_H_ diff --git a/quiche/quic/core/crypto/quic_hkdf_test.cc b/quiche/quic/core/crypto/quic_hkdf_test.cc new file mode 100644 index 000000000000..48f041f1f823 --- /dev/null +++ b/quiche/quic/core/crypto/quic_hkdf_test.cc @@ -0,0 +1,91 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/quic_hkdf.h" + +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +struct HKDFInput { + const char* key_hex; + const char* salt_hex; + const char* info_hex; + const char* output_hex; +}; + +// These test cases are taken from +// https://tools.ietf.org/html/rfc5869#appendix-A. +static const HKDFInput kHKDFInputs[] = { + { + "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b", + "000102030405060708090a0b0c", + "f0f1f2f3f4f5f6f7f8f9", + "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf340072" + "08d5" + "b887185865", + }, + { + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122" + "2324" + "25262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647" + "4849" + "4a4b4c4d4e4f", + "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182" + "8384" + "85868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7" + "a8a9" + "aaabacadaeaf", + "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2" + "d3d4" + "d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7" + "f8f9" + "fafbfcfdfeff", + "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a" + "99ca" + "c7827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87" + "c14c" + "01d5c1f3434f1d87", + }, + { + "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b", + "", + "", + "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d2013" + "95fa" + "a4b61a96c8", + }, +}; + +class QuicHKDFTest : public QuicTest {}; + +TEST_F(QuicHKDFTest, HKDF) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kHKDFInputs); i++) { + const HKDFInput& test(kHKDFInputs[i]); + SCOPED_TRACE(i); + + const std::string key = absl::HexStringToBytes(test.key_hex); + const std::string salt = absl::HexStringToBytes(test.salt_hex); + const std::string info = absl::HexStringToBytes(test.info_hex); + const std::string expected = absl::HexStringToBytes(test.output_hex); + + // We set the key_length to the length of the expected output and then take + // the result from the first key, which is the client write key. + QuicHKDF hkdf(key, salt, info, expected.size(), 0, 0); + + ASSERT_EQ(expected.size(), hkdf.client_write_key().size()); + EXPECT_EQ(0, memcmp(expected.data(), hkdf.client_write_key().data(), + expected.size())); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/quic_random.h b/quiche/quic/core/crypto/quic_random.h new file mode 100644 index 000000000000..9f7c21626ca6 --- /dev/null +++ b/quiche/quic/core/crypto/quic_random.h @@ -0,0 +1,16 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_RANDOM_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_RANDOM_H_ + +#include "quiche/common/quiche_random.h" + +namespace quic { + +using QuicRandom = quiche::QuicheRandom; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_RANDOM_H_ diff --git a/quiche/quic/core/crypto/tls_client_connection.cc b/quiche/quic/core/crypto/tls_client_connection.cc new file mode 100644 index 000000000000..7436b23b9b11 --- /dev/null +++ b/quiche/quic/core/crypto/tls_client_connection.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/tls_client_connection.h" + +namespace quic { + +TlsClientConnection::TlsClientConnection(SSL_CTX* ssl_ctx, Delegate* delegate, + QuicSSLConfig ssl_config) + : TlsConnection(ssl_ctx, delegate->ConnectionDelegate(), + std::move(ssl_config)), + delegate_(delegate) {} + +// static +bssl::UniquePtr TlsClientConnection::CreateSslCtx( + bool enable_early_data) { + bssl::UniquePtr ssl_ctx = TlsConnection::CreateSslCtx(); + // Configure certificate verification. + SSL_CTX_set_custom_verify(ssl_ctx.get(), SSL_VERIFY_PEER, &VerifyCallback); + int reverify_on_resume_enabled = 1; + SSL_CTX_set_reverify_on_resume(ssl_ctx.get(), reverify_on_resume_enabled); + + // Configure session caching. + SSL_CTX_set_session_cache_mode( + ssl_ctx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL); + SSL_CTX_sess_set_new_cb(ssl_ctx.get(), NewSessionCallback); + + // TODO(wub): Always enable early data on the SSL_CTX, but allow it to be + // overridden on the SSL object, via QuicSSLConfig. + SSL_CTX_set_early_data_enabled(ssl_ctx.get(), enable_early_data); + return ssl_ctx; +} + +void TlsClientConnection::SetCertChain( + const std::vector& cert_chain, EVP_PKEY* privkey) { + SSL_set_chain_and_key(ssl(), cert_chain.data(), cert_chain.size(), privkey, + /*privkey_method=*/nullptr); +} + +// static +int TlsClientConnection::NewSessionCallback(SSL* ssl, SSL_SESSION* session) { + static_cast(ConnectionFromSsl(ssl)) + ->delegate_->InsertSession(bssl::UniquePtr(session)); + return 1; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/tls_client_connection.h b/quiche/quic/core/crypto/tls_client_connection.h new file mode 100644 index 000000000000..3bf35ce0ecd2 --- /dev/null +++ b/quiche/quic/core/crypto/tls_client_connection.h @@ -0,0 +1,54 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_TLS_CLIENT_CONNECTION_H_ +#define QUICHE_QUIC_CORE_CRYPTO_TLS_CLIENT_CONNECTION_H_ + +#include "quiche/quic/core/crypto/tls_connection.h" + +namespace quic { + +// TlsClientConnection receives calls for client-specific BoringSSL callbacks +// and calls its Delegate for the implementation of those callbacks. +class QUIC_EXPORT_PRIVATE TlsClientConnection : public TlsConnection { + public: + // A TlsClientConnection::Delegate implements the client-specific methods that + // are set as callbacks for an SSL object. + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + protected: + // Called when a NewSessionTicket is received from the server. + virtual void InsertSession(bssl::UniquePtr session) = 0; + + // Provides the delegate for callbacks that are shared between client and + // server. + virtual TlsConnection::Delegate* ConnectionDelegate() = 0; + + friend class TlsClientConnection; + }; + + TlsClientConnection(SSL_CTX* ssl_ctx, Delegate* delegate, + QuicSSLConfig ssl_config); + + // Creates and configures an SSL_CTX that is appropriate for clients to use. + static bssl::UniquePtr CreateSslCtx(bool enable_early_data); + + // Set the client cert and private key to be used on this connection, if + // requested by the server. + void SetCertChain(const std::vector& cert_chain, + EVP_PKEY* privkey); + + private: + // Registered as the callback for SSL_CTX_sess_set_new_cb, which calls + // Delegate::InsertSession. + static int NewSessionCallback(SSL* ssl, SSL_SESSION* session); + + Delegate* delegate_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_TLS_CLIENT_CONNECTION_H_ diff --git a/quiche/quic/core/crypto/tls_connection.cc b/quiche/quic/core/crypto/tls_connection.cc new file mode 100644 index 000000000000..1b54b14645f3 --- /dev/null +++ b/quiche/quic/core/crypto/tls_connection.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/tls_connection.h" + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +namespace { + +// BoringSSL allows storing extra data off of some of its data structures, +// including the SSL struct. To allow for multiple callers to store data, each +// caller can use a different index for setting and getting data. These indices +// are globals handed out by calling SSL_get_ex_new_index. +// +// SslIndexSingleton calls SSL_get_ex_new_index on its construction, and then +// provides this index to be used in calls to SSL_get_ex_data/SSL_set_ex_data. +// This is used to store in the SSL struct a pointer to the TlsConnection which +// owns it. +class SslIndexSingleton { + public: + static SslIndexSingleton* GetInstance() { + static SslIndexSingleton* instance = new SslIndexSingleton(); + return instance; + } + + int ssl_ex_data_index_connection() const { + return ssl_ex_data_index_connection_; + } + + private: + SslIndexSingleton() { + CRYPTO_library_init(); + ssl_ex_data_index_connection_ = + SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + QUICHE_CHECK_LE(0, ssl_ex_data_index_connection_); + } + + SslIndexSingleton(const SslIndexSingleton&) = delete; + SslIndexSingleton& operator=(const SslIndexSingleton&) = delete; + + // The index to supply to SSL_get_ex_data/SSL_set_ex_data for getting/setting + // the TlsConnection pointer. + int ssl_ex_data_index_connection_; +}; + +} // namespace + +// static +EncryptionLevel TlsConnection::QuicEncryptionLevel( + enum ssl_encryption_level_t level) { + switch (level) { + case ssl_encryption_initial: + return ENCRYPTION_INITIAL; + case ssl_encryption_early_data: + return ENCRYPTION_ZERO_RTT; + case ssl_encryption_handshake: + return ENCRYPTION_HANDSHAKE; + case ssl_encryption_application: + return ENCRYPTION_FORWARD_SECURE; + default: + QUIC_BUG(quic_bug_10698_1) + << "Invalid ssl_encryption_level_t " << static_cast(level); + return ENCRYPTION_INITIAL; + } +} + +// static +enum ssl_encryption_level_t TlsConnection::BoringEncryptionLevel( + EncryptionLevel level) { + switch (level) { + case ENCRYPTION_INITIAL: + return ssl_encryption_initial; + case ENCRYPTION_HANDSHAKE: + return ssl_encryption_handshake; + case ENCRYPTION_ZERO_RTT: + return ssl_encryption_early_data; + case ENCRYPTION_FORWARD_SECURE: + return ssl_encryption_application; + default: + QUIC_BUG(quic_bug_10698_2) + << "Invalid encryption level " << static_cast(level); + return ssl_encryption_initial; + } +} + +TlsConnection::TlsConnection(SSL_CTX* ssl_ctx, + TlsConnection::Delegate* delegate, + QuicSSLConfig ssl_config) + : delegate_(delegate), + ssl_(SSL_new(ssl_ctx)), + ssl_config_(std::move(ssl_config)) { + SSL_set_ex_data( + ssl(), SslIndexSingleton::GetInstance()->ssl_ex_data_index_connection(), + this); + if (ssl_config_.early_data_enabled.has_value()) { + const int early_data_enabled = *ssl_config_.early_data_enabled ? 1 : 0; + SSL_set_early_data_enabled(ssl(), early_data_enabled); + } + if (ssl_config_.signing_algorithm_prefs.has_value()) { + SSL_set_signing_algorithm_prefs( + ssl(), ssl_config_.signing_algorithm_prefs->data(), + ssl_config_.signing_algorithm_prefs->size()); + } + if (ssl_config_.disable_ticket_support.has_value()) { + if (*ssl_config_.disable_ticket_support) { + SSL_set_options(ssl(), SSL_OP_NO_TICKET); + } + } +} + +void TlsConnection::EnableInfoCallback() { + SSL_set_info_callback( + ssl(), +[](const SSL* ssl, int type, int value) { + ConnectionFromSsl(ssl)->delegate_->InfoCallback(type, value); + }); +} + +void TlsConnection::DisableTicketSupport() { + ssl_config_.disable_ticket_support = true; + SSL_set_options(ssl(), SSL_OP_NO_TICKET); +} + +// static +bssl::UniquePtr TlsConnection::CreateSslCtx() { + CRYPTO_library_init(); + bssl::UniquePtr ssl_ctx(SSL_CTX_new(TLS_with_buffers_method())); + SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION); + SSL_CTX_set_quic_method(ssl_ctx.get(), &kSslQuicMethod); + return ssl_ctx; +} + +// static +TlsConnection* TlsConnection::ConnectionFromSsl(const SSL* ssl) { + return reinterpret_cast(SSL_get_ex_data( + ssl, SslIndexSingleton::GetInstance()->ssl_ex_data_index_connection())); +} + +// static +enum ssl_verify_result_t TlsConnection::VerifyCallback(SSL* ssl, + uint8_t* out_alert) { + return ConnectionFromSsl(ssl)->delegate_->VerifyCert(out_alert); +} + +const SSL_QUIC_METHOD TlsConnection::kSslQuicMethod{ + TlsConnection::SetReadSecretCallback, TlsConnection::SetWriteSecretCallback, + TlsConnection::WriteMessageCallback, TlsConnection::FlushFlightCallback, + TlsConnection::SendAlertCallback}; + +// static +int TlsConnection::SetReadSecretCallback(SSL* ssl, + enum ssl_encryption_level_t level, + const SSL_CIPHER* cipher, + const uint8_t* secret, + size_t secret_length) { + TlsConnection::Delegate* delegate = ConnectionFromSsl(ssl)->delegate_; + if (!delegate->SetReadSecret(QuicEncryptionLevel(level), cipher, + absl::MakeSpan(secret, secret_length))) { + return 0; + } + return 1; +} + +// static +int TlsConnection::SetWriteSecretCallback(SSL* ssl, + enum ssl_encryption_level_t level, + const SSL_CIPHER* cipher, + const uint8_t* secret, + size_t secret_length) { + TlsConnection::Delegate* delegate = ConnectionFromSsl(ssl)->delegate_; + delegate->SetWriteSecret(QuicEncryptionLevel(level), cipher, + absl::MakeSpan(secret, secret_length)); + return 1; +} + +// static +int TlsConnection::WriteMessageCallback(SSL* ssl, + enum ssl_encryption_level_t level, + const uint8_t* data, size_t len) { + ConnectionFromSsl(ssl)->delegate_->WriteMessage( + QuicEncryptionLevel(level), + absl::string_view(reinterpret_cast(data), len)); + return 1; +} + +// static +int TlsConnection::FlushFlightCallback(SSL* ssl) { + ConnectionFromSsl(ssl)->delegate_->FlushFlight(); + return 1; +} + +// static +int TlsConnection::SendAlertCallback(SSL* ssl, + enum ssl_encryption_level_t level, + uint8_t desc) { + ConnectionFromSsl(ssl)->delegate_->SendAlert(QuicEncryptionLevel(level), + desc); + return 1; +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/tls_connection.h b/quiche/quic/core/crypto/tls_connection.h new file mode 100644 index 000000000000..5c4e8b8884a1 --- /dev/null +++ b/quiche/quic/core/crypto/tls_connection.h @@ -0,0 +1,153 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_TLS_CONNECTION_H_ +#define QUICHE_QUIC_CORE_CRYPTO_TLS_CONNECTION_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +// TlsConnection wraps BoringSSL's SSL object which represents a single TLS +// connection. Callbacks set in BoringSSL which are called with an SSL* argument +// will get dispatched to the TlsConnection object owning that SSL. In turn, the +// TlsConnection will delegate the implementation of that callback to its +// Delegate. +// +// The owner of the TlsConnection is responsible for driving the TLS handshake +// (and other interactions with the SSL*). This class only handles mapping +// callbacks to the correct instance. +class QUIC_EXPORT_PRIVATE TlsConnection { + public: + // A TlsConnection::Delegate implements the methods that are set as callbacks + // of TlsConnection. + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + protected: + // Certificate management functions: + + // Verifies the peer's certificate chain. It may use + // SSL_get0_peer_certificates to get the cert chain. This method returns + // ssl_verify_ok if the cert is valid, ssl_verify_invalid if it is invalid, + // or ssl_verify_retry if verification is happening asynchronously. + virtual enum ssl_verify_result_t VerifyCert(uint8_t* out_alert) = 0; + + // QUIC-TLS interface functions: + + // SetWriteSecret provides the encryption secret used to encrypt messages at + // encryption level |level|. The secret provided here is the one from the + // TLS 1.3 key schedule (RFC 8446 section 7.1), in particular the handshake + // traffic secrets and application traffic secrets. The provided write + // secret must be used with the provided cipher suite |cipher|. + virtual void SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span write_secret) = 0; + + // SetReadSecret is similar to SetWriteSecret, except that it is used for + // decrypting messages. SetReadSecret at a particular level is always called + // after SetWriteSecret for that level, except for ENCRYPTION_ZERO_RTT, + // where the EncryptionLevel for SetWriteSecret is + // ENCRYPTION_FORWARD_SECURE. + virtual bool SetReadSecret(EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span read_secret) = 0; + + // WriteMessage is called when there is |data| from the TLS stack ready for + // the QUIC stack to write in a crypto frame. The data must be transmitted + // at encryption level |level|. + virtual void WriteMessage(EncryptionLevel level, + absl::string_view data) = 0; + + // FlushFlight is called to signal that the current flight of messages have + // all been written (via calls to WriteMessage) and can be flushed to the + // underlying transport. + virtual void FlushFlight() = 0; + + // SendAlert causes this TlsConnection to close the QUIC connection with an + // error code corersponding to the TLS alert description |desc| sent at + // level |level|. + virtual void SendAlert(EncryptionLevel level, uint8_t desc) = 0; + + // Informational callback from BoringSSL. This callback is disabled by + // default, but can be enabled by TlsConnection::EnableInfoCallback. + // + // See |SSL_CTX_set_info_callback| for the meaning of |type| and |value|. + virtual void InfoCallback(int type, int value) = 0; + + friend class TlsConnection; + }; + + TlsConnection(const TlsConnection&) = delete; + TlsConnection& operator=(const TlsConnection&) = delete; + + // Configure the SSL such that delegate_->InfoCallback will be called. + void EnableInfoCallback(); + + // Configure the SSL to disable session ticket support. Note that, this + // function simply sets the |SSL_OP_NO_TICKET| option on the SSL object, it + // does not check whether it is too late to do so. + void DisableTicketSupport(); + + // Functions to convert between BoringSSL's enum ssl_encryption_level_t and + // QUIC's EncryptionLevel. + static EncryptionLevel QuicEncryptionLevel(enum ssl_encryption_level_t level); + static enum ssl_encryption_level_t BoringEncryptionLevel( + EncryptionLevel level); + + SSL* ssl() const { return ssl_.get(); } + + const QuicSSLConfig& ssl_config() const { return ssl_config_; } + + protected: + // TlsConnection does not take ownership of |ssl_ctx| or |delegate|; they must + // outlive the TlsConnection object. + TlsConnection(SSL_CTX* ssl_ctx, Delegate* delegate, QuicSSLConfig ssl_config); + + // Creates an SSL_CTX and configures it with the options that are appropriate + // for both client and server. The caller is responsible for ownership of the + // newly created struct. + static bssl::UniquePtr CreateSslCtx(); + + // From a given SSL* |ssl|, returns a pointer to the TlsConnection that it + // belongs to. This helper method allows the callbacks set in BoringSSL to be + // dispatched to the correct TlsConnection from the SSL* passed into the + // callback. + static TlsConnection* ConnectionFromSsl(const SSL* ssl); + + // Registered as the callback for SSL(_CTX)_set_custom_verify. The + // implementation is delegated to Delegate::VerifyCert. + static enum ssl_verify_result_t VerifyCallback(SSL* ssl, uint8_t* out_alert); + + QuicSSLConfig& mutable_ssl_config() { return ssl_config_; } + + private: + // TlsConnection implements SSL_QUIC_METHOD, which provides the interface + // between BoringSSL's TLS stack and a QUIC implementation. + static const SSL_QUIC_METHOD kSslQuicMethod; + + // The following static functions make up the members of kSslQuicMethod: + static int SetReadSecretCallback(SSL* ssl, enum ssl_encryption_level_t level, + const SSL_CIPHER* cipher, + const uint8_t* secret, size_t secret_len); + static int SetWriteSecretCallback(SSL* ssl, enum ssl_encryption_level_t level, + const SSL_CIPHER* cipher, + const uint8_t* secret, size_t secret_len); + static int WriteMessageCallback(SSL* ssl, enum ssl_encryption_level_t level, + const uint8_t* data, size_t len); + static int FlushFlightCallback(SSL* ssl); + static int SendAlertCallback(SSL* ssl, enum ssl_encryption_level_t level, + uint8_t desc); + + Delegate* delegate_; + bssl::UniquePtr ssl_; + QuicSSLConfig ssl_config_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_TLS_CONNECTION_H_ diff --git a/quiche/quic/core/crypto/tls_server_connection.cc b/quiche/quic/core/crypto/tls_server_connection.cc new file mode 100644 index 000000000000..51311bcfd44d --- /dev/null +++ b/quiche/quic/core/crypto/tls_server_connection.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/tls_server_connection.h" + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +TlsServerConnection::TlsServerConnection(SSL_CTX* ssl_ctx, Delegate* delegate, + QuicSSLConfig ssl_config) + : TlsConnection(ssl_ctx, delegate->ConnectionDelegate(), + std::move(ssl_config)), + delegate_(delegate) { + // By default, cert verify callback is not installed on ssl(), so only need to + // UpdateCertVerifyCallback() if client_cert_mode is not kNone. + if (TlsConnection::ssl_config().client_cert_mode != ClientCertMode::kNone) { + UpdateCertVerifyCallback(); + } +} + +// static +bssl::UniquePtr TlsServerConnection::CreateSslCtx( + ProofSource* proof_source) { + bssl::UniquePtr ssl_ctx = TlsConnection::CreateSslCtx(); + + // Server does not request/verify client certs by default. Individual server + // connections may call SSL_set_custom_verify on their SSL object to request + // client certs. + + SSL_CTX_set_tlsext_servername_callback(ssl_ctx.get(), + &TlsExtServernameCallback); + SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), &SelectAlpnCallback, nullptr); + // We don't actually need the TicketCrypter here, but we need to know + // whether it's set. + if (proof_source->GetTicketCrypter()) { + QUIC_CODE_COUNT(quic_session_tickets_enabled); + SSL_CTX_set_ticket_aead_method(ssl_ctx.get(), + &TlsServerConnection::kSessionTicketMethod); + } else { + QUIC_CODE_COUNT(quic_session_tickets_disabled); + } + + SSL_CTX_set_early_data_enabled(ssl_ctx.get(), 1); + + SSL_CTX_set_select_certificate_cb( + ssl_ctx.get(), &TlsServerConnection::EarlySelectCertCallback); + SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_CIPHER_SERVER_PREFERENCE); + + // Allow ProofSource to change SSL_CTX settings. + proof_source->OnNewSslCtx(ssl_ctx.get()); + + return ssl_ctx; +} + +void TlsServerConnection::SetCertChain( + const std::vector& cert_chain) { + SSL_set_chain_and_key(ssl(), cert_chain.data(), cert_chain.size(), nullptr, + &TlsServerConnection::kPrivateKeyMethod); +} + +void TlsServerConnection::SetClientCertMode(ClientCertMode client_cert_mode) { + if (ssl_config().client_cert_mode == client_cert_mode) { + return; + } + + mutable_ssl_config().client_cert_mode = client_cert_mode; + UpdateCertVerifyCallback(); +} + +void TlsServerConnection::UpdateCertVerifyCallback() { + const ClientCertMode client_cert_mode = ssl_config().client_cert_mode; + if (client_cert_mode == ClientCertMode::kNone) { + SSL_set_custom_verify(ssl(), SSL_VERIFY_NONE, nullptr); + return; + } + + int mode = SSL_VERIFY_PEER; + if (client_cert_mode == ClientCertMode::kRequire) { + mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + } else { + QUICHE_DCHECK_EQ(client_cert_mode, ClientCertMode::kRequest); + } + SSL_set_custom_verify(ssl(), mode, &VerifyCallback); +} + +const SSL_PRIVATE_KEY_METHOD TlsServerConnection::kPrivateKeyMethod{ + &TlsServerConnection::PrivateKeySign, + nullptr, // decrypt + &TlsServerConnection::PrivateKeyComplete, +}; + +// static +TlsServerConnection* TlsServerConnection::ConnectionFromSsl(SSL* ssl) { + return static_cast( + TlsConnection::ConnectionFromSsl(ssl)); +} + +// static +ssl_select_cert_result_t TlsServerConnection::EarlySelectCertCallback( + const SSL_CLIENT_HELLO* client_hello) { + return ConnectionFromSsl(client_hello->ssl) + ->delegate_->EarlySelectCertCallback(client_hello); +} + +// static +int TlsServerConnection::TlsExtServernameCallback(SSL* ssl, int* out_alert, + void* /*arg*/) { + return ConnectionFromSsl(ssl)->delegate_->TlsExtServernameCallback(out_alert); +} + +// static +int TlsServerConnection::SelectAlpnCallback(SSL* ssl, const uint8_t** out, + uint8_t* out_len, const uint8_t* in, + unsigned in_len, void* /*arg*/) { + return ConnectionFromSsl(ssl)->delegate_->SelectAlpn(out, out_len, in, + in_len); +} + +// static +ssl_private_key_result_t TlsServerConnection::PrivateKeySign( + SSL* ssl, uint8_t* out, size_t* out_len, size_t max_out, uint16_t sig_alg, + const uint8_t* in, size_t in_len) { + return ConnectionFromSsl(ssl)->delegate_->PrivateKeySign( + out, out_len, max_out, sig_alg, + absl::string_view(reinterpret_cast(in), in_len)); +} + +// static +ssl_private_key_result_t TlsServerConnection::PrivateKeyComplete( + SSL* ssl, uint8_t* out, size_t* out_len, size_t max_out) { + return ConnectionFromSsl(ssl)->delegate_->PrivateKeyComplete(out, out_len, + max_out); +} + +// static +const SSL_TICKET_AEAD_METHOD TlsServerConnection::kSessionTicketMethod{ + TlsServerConnection::SessionTicketMaxOverhead, + TlsServerConnection::SessionTicketSeal, + TlsServerConnection::SessionTicketOpen, +}; + +// static +size_t TlsServerConnection::SessionTicketMaxOverhead(SSL* ssl) { + return ConnectionFromSsl(ssl)->delegate_->SessionTicketMaxOverhead(); +} + +// static +int TlsServerConnection::SessionTicketSeal(SSL* ssl, uint8_t* out, + size_t* out_len, size_t max_out_len, + const uint8_t* in, size_t in_len) { + return ConnectionFromSsl(ssl)->delegate_->SessionTicketSeal( + out, out_len, max_out_len, + absl::string_view(reinterpret_cast(in), in_len)); +} + +// static +enum ssl_ticket_aead_result_t TlsServerConnection::SessionTicketOpen( + SSL* ssl, uint8_t* out, size_t* out_len, size_t max_out_len, + const uint8_t* in, size_t in_len) { + return ConnectionFromSsl(ssl)->delegate_->SessionTicketOpen( + out, out_len, max_out_len, + absl::string_view(reinterpret_cast(in), in_len)); +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/tls_server_connection.h b/quiche/quic/core/crypto/tls_server_connection.h new file mode 100644 index 000000000000..3c0eb7e768c3 --- /dev/null +++ b/quiche/quic/core/crypto/tls_server_connection.h @@ -0,0 +1,180 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_TLS_SERVER_CONNECTION_H_ +#define QUICHE_QUIC_CORE_CRYPTO_TLS_SERVER_CONNECTION_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/tls_connection.h" + +namespace quic { + +// TlsServerConnection receives calls for client-specific BoringSSL callbacks +// and calls its Delegate for the implementation of those callbacks. +class QUIC_EXPORT_PRIVATE TlsServerConnection : public TlsConnection { + public: + // A TlsServerConnection::Delegate implement the server-specific methods that + // are set as callbacks for an SSL object. + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + protected: + // Called from BoringSSL right after SNI is extracted, which is very early + // in the handshake process. + virtual ssl_select_cert_result_t EarlySelectCertCallback( + const SSL_CLIENT_HELLO* client_hello) = 0; + + // Called after the ClientHello extensions have been successfully parsed. + // Returns an SSL_TLSEXT_ERR_* value (see + // https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_CTX_set_tlsext_servername_callback). + // + // On success, return SSL_TLSEXT_ERR_OK causes the server_name extension to + // be acknowledged in the ServerHello, or return SSL_TLSEXT_ERR_NOACK which + // causes it to be not acknowledged. + // + // If the function returns SSL_TLSEXT_ERR_ALERT_FATAL, then it puts in + // |*out_alert| the TLS alert value that the server will send. + // + virtual int TlsExtServernameCallback(int* out_alert) = 0; + + // Selects which ALPN to use based on the list sent by the client. + virtual int SelectAlpn(const uint8_t** out, uint8_t* out_len, + const uint8_t* in, unsigned in_len) = 0; + + // Signs |in| using the signature algorithm specified by |sig_alg| (an + // SSL_SIGN_* value). If the signing operation cannot be completed + // synchronously, ssl_private_key_retry is returned. If there is an error + // signing, or if the signature is longer than |max_out|, then + // ssl_private_key_failure is returned. Otherwise, ssl_private_key_success + // is returned with the signature put in |*out| and the length in + // |*out_len|. + virtual ssl_private_key_result_t PrivateKeySign(uint8_t* out, + size_t* out_len, + size_t max_out, + uint16_t sig_alg, + absl::string_view in) = 0; + + // When PrivateKeySign returns ssl_private_key_retry, PrivateKeyComplete + // will be called after the async sign operation has completed. + // PrivateKeyComplete puts the resulting signature in |*out| and length in + // |*out_len|. If the length is greater than |max_out| or if there was an + // error in signing, then ssl_private_key_failure is returned. Otherwise, + // ssl_private_key_success is returned. + virtual ssl_private_key_result_t PrivateKeyComplete(uint8_t* out, + size_t* out_len, + size_t max_out) = 0; + + // The following functions are used to implement an SSL_TICKET_AEAD_METHOD. + // See + // https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_ticket_aead_result_t + // for details on the BoringSSL API. + + // SessionTicketMaxOverhead returns the maximum number of bytes of overhead + // that SessionTicketSeal may add when encrypting a session ticket. + virtual size_t SessionTicketMaxOverhead() = 0; + + // SessionTicketSeal encrypts the session ticket in |in|, putting the + // resulting encrypted ticket in |out|, writing the length of the bytes + // written to |*out_len|, which is no larger than |max_out_len|. It returns + // 1 on success and 0 on error. + virtual int SessionTicketSeal(uint8_t* out, size_t* out_len, + size_t max_out_len, absl::string_view in) = 0; + + // SessionTicketOpen is called when BoringSSL has an encrypted session + // ticket |in| and wants the ticket decrypted. This decryption operation can + // happen synchronously or asynchronously. + // + // If the decrypted ticket is not available at the time of the function + // call, this function returns ssl_ticket_aead_retry. If this function + // returns ssl_ticket_aead_retry, then SSL_do_handshake will return + // SSL_ERROR_PENDING_TICKET. Once the pending ticket decryption has + // completed, SSL_do_handshake needs to be called again. + // + // When this function is called and the decrypted ticket is available + // (either the ticket was decrypted synchronously, or an asynchronous + // operation has completed and SSL_do_handshake has been called again), the + // decrypted ticket is put in |out|, and the length of that output is + // written to |*out_len|, not to exceed |max_out_len|, and + // ssl_ticket_aead_success is returned. If the ticket cannot be decrypted + // and should be ignored, this function returns + // ssl_ticket_aead_ignore_ticket and a full handshake will be performed + // instead. If a fatal error occurs, ssl_ticket_aead_error can be returned + // which will terminate the handshake. + virtual enum ssl_ticket_aead_result_t SessionTicketOpen( + uint8_t* out, size_t* out_len, size_t max_out_len, + absl::string_view in) = 0; + + // Provides the delegate for callbacks that are shared between client and + // server. + virtual TlsConnection::Delegate* ConnectionDelegate() = 0; + + friend class TlsServerConnection; + }; + + TlsServerConnection(SSL_CTX* ssl_ctx, Delegate* delegate, + QuicSSLConfig ssl_config); + + // Creates and configures an SSL_CTX that is appropriate for servers to use. + static bssl::UniquePtr CreateSslCtx(ProofSource* proof_source); + + void SetCertChain(const std::vector& cert_chain); + + // Set the client cert mode to be used on this connection. This should be + // called right after cert selection at the latest, otherwise it is too late + // to has an effect. + void SetClientCertMode(ClientCertMode client_cert_mode); + + private: + // Specialization of TlsConnection::ConnectionFromSsl. + static TlsServerConnection* ConnectionFromSsl(SSL* ssl); + + static ssl_select_cert_result_t EarlySelectCertCallback( + const SSL_CLIENT_HELLO* client_hello); + + // These functions are registered as callbacks in BoringSSL and delegate their + // implementation to the matching methods in Delegate above. + static int TlsExtServernameCallback(SSL* ssl, int* out_alert, void* arg); + static int SelectAlpnCallback(SSL* ssl, const uint8_t** out, uint8_t* out_len, + const uint8_t* in, unsigned in_len, void* arg); + + // |kPrivateKeyMethod| is a vtable pointing to PrivateKeySign and + // PrivateKeyComplete used by the TLS stack to compute the signature for the + // CertificateVerify message (using the server's private key). + static const SSL_PRIVATE_KEY_METHOD kPrivateKeyMethod; + + // The following functions make up the contents of |kPrivateKeyMethod|. + static ssl_private_key_result_t PrivateKeySign( + SSL* ssl, uint8_t* out, size_t* out_len, size_t max_out, uint16_t sig_alg, + const uint8_t* in, size_t in_len); + static ssl_private_key_result_t PrivateKeyComplete(SSL* ssl, uint8_t* out, + size_t* out_len, + size_t max_out); + + // Implementation of SSL_TICKET_AEAD_METHOD which delegates to corresponding + // methods in TlsServerConnection::Delegate (a.k.a. TlsServerHandshaker). + static const SSL_TICKET_AEAD_METHOD kSessionTicketMethod; + + // The following functions make up the contents of |kSessionTicketMethod|. + static size_t SessionTicketMaxOverhead(SSL* ssl); + static int SessionTicketSeal(SSL* ssl, uint8_t* out, size_t* out_len, + size_t max_out_len, const uint8_t* in, + size_t in_len); + static enum ssl_ticket_aead_result_t SessionTicketOpen(SSL* ssl, uint8_t* out, + size_t* out_len, + size_t max_out_len, + const uint8_t* in, + size_t in_len); + + // Install custom verify callback on ssl() if |ssl_config().client_cert_mode| + // is not ClientCertMode::kNone. Uninstall otherwise. + void UpdateCertVerifyCallback(); + + Delegate* delegate_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_TLS_SERVER_CONNECTION_H_ diff --git a/quiche/quic/core/crypto/transport_parameters.cc b/quiche/quic/core/crypto/transport_parameters.cc new file mode 100644 index 000000000000..ebaa07000f3e --- /dev/null +++ b/quiche/quic/core/crypto/transport_parameters.cc @@ -0,0 +1,1655 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/transport_parameters.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/digest.h" +#include "openssl/sha.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_ip_address.h" + +namespace quic { + +// Values of the TransportParameterId enum as defined in the +// "Transport Parameter Encoding" section of draft-ietf-quic-transport. +// When parameters are encoded, one of these enum values is used to indicate +// which parameter is encoded. The supported draft version is noted in +// transport_parameters.h. +enum TransportParameters::TransportParameterId : uint64_t { + kOriginalDestinationConnectionId = 0, + kMaxIdleTimeout = 1, + kStatelessResetToken = 2, + kMaxPacketSize = 3, + kInitialMaxData = 4, + kInitialMaxStreamDataBidiLocal = 5, + kInitialMaxStreamDataBidiRemote = 6, + kInitialMaxStreamDataUni = 7, + kInitialMaxStreamsBidi = 8, + kInitialMaxStreamsUni = 9, + kAckDelayExponent = 0xa, + kMaxAckDelay = 0xb, + kDisableActiveMigration = 0xc, + kPreferredAddress = 0xd, + kActiveConnectionIdLimit = 0xe, + kInitialSourceConnectionId = 0xf, + kRetrySourceConnectionId = 0x10, + + kMaxDatagramFrameSize = 0x20, + + kGoogleHandshakeMessage = 0x26ab, + + kInitialRoundTripTime = 0x3127, + kGoogleConnectionOptions = 0x3128, + // 0x3129 was used to convey the user agent string. + // 0x312A was used only in T050 to indicate support for HANDSHAKE_DONE. + // 0x312B was used to indicate that QUIC+TLS key updates were not supported. + // 0x4751 was used for non-standard Google-specific parameters encoded as a + // Google QUIC_CRYPTO CHLO, it has been replaced by individual parameters. + kGoogleQuicVersion = + 0x4752, // Used to transmit version and supported_versions. + + kMinAckDelay = 0xDE1A, // draft-iyengar-quic-delayed-ack. + kVersionInformation = 0xFF73DB, // draft-ietf-quic-version-negotiation. +}; + +namespace { + +constexpr QuicVersionLabel kReservedVersionMask = 0x0f0f0f0f; +constexpr QuicVersionLabel kReservedVersionBits = 0x0a0a0a0a; + +// The following constants define minimum and maximum allowed values for some of +// the parameters. These come from the "Transport Parameter Definitions" +// section of draft-ietf-quic-transport. +constexpr uint64_t kMinMaxPacketSizeTransportParam = 1200; +constexpr uint64_t kMaxAckDelayExponentTransportParam = 20; +constexpr uint64_t kDefaultAckDelayExponentTransportParam = 3; +constexpr uint64_t kMaxMaxAckDelayTransportParam = 16383; +constexpr uint64_t kDefaultMaxAckDelayTransportParam = 25; +constexpr uint64_t kMinActiveConnectionIdLimitTransportParam = 2; +constexpr uint64_t kDefaultActiveConnectionIdLimitTransportParam = 2; + +std::string TransportParameterIdToString( + TransportParameters::TransportParameterId param_id) { + switch (param_id) { + case TransportParameters::kOriginalDestinationConnectionId: + return "original_destination_connection_id"; + case TransportParameters::kMaxIdleTimeout: + return "max_idle_timeout"; + case TransportParameters::kStatelessResetToken: + return "stateless_reset_token"; + case TransportParameters::kMaxPacketSize: + return "max_udp_payload_size"; + case TransportParameters::kInitialMaxData: + return "initial_max_data"; + case TransportParameters::kInitialMaxStreamDataBidiLocal: + return "initial_max_stream_data_bidi_local"; + case TransportParameters::kInitialMaxStreamDataBidiRemote: + return "initial_max_stream_data_bidi_remote"; + case TransportParameters::kInitialMaxStreamDataUni: + return "initial_max_stream_data_uni"; + case TransportParameters::kInitialMaxStreamsBidi: + return "initial_max_streams_bidi"; + case TransportParameters::kInitialMaxStreamsUni: + return "initial_max_streams_uni"; + case TransportParameters::kAckDelayExponent: + return "ack_delay_exponent"; + case TransportParameters::kMaxAckDelay: + return "max_ack_delay"; + case TransportParameters::kDisableActiveMigration: + return "disable_active_migration"; + case TransportParameters::kPreferredAddress: + return "preferred_address"; + case TransportParameters::kActiveConnectionIdLimit: + return "active_connection_id_limit"; + case TransportParameters::kInitialSourceConnectionId: + return "initial_source_connection_id"; + case TransportParameters::kRetrySourceConnectionId: + return "retry_source_connection_id"; + case TransportParameters::kMaxDatagramFrameSize: + return "max_datagram_frame_size"; + case TransportParameters::kGoogleHandshakeMessage: + return "google_handshake_message"; + case TransportParameters::kInitialRoundTripTime: + return "initial_round_trip_time"; + case TransportParameters::kGoogleConnectionOptions: + return "google_connection_options"; + case TransportParameters::kGoogleQuicVersion: + return "google-version"; + case TransportParameters::kMinAckDelay: + return "min_ack_delay_us"; + case TransportParameters::kVersionInformation: + return "version_information"; + } + return absl::StrCat("Unknown(", param_id, ")"); +} + +bool TransportParameterIdIsKnown( + TransportParameters::TransportParameterId param_id) { + switch (param_id) { + case TransportParameters::kOriginalDestinationConnectionId: + case TransportParameters::kMaxIdleTimeout: + case TransportParameters::kStatelessResetToken: + case TransportParameters::kMaxPacketSize: + case TransportParameters::kInitialMaxData: + case TransportParameters::kInitialMaxStreamDataBidiLocal: + case TransportParameters::kInitialMaxStreamDataBidiRemote: + case TransportParameters::kInitialMaxStreamDataUni: + case TransportParameters::kInitialMaxStreamsBidi: + case TransportParameters::kInitialMaxStreamsUni: + case TransportParameters::kAckDelayExponent: + case TransportParameters::kMaxAckDelay: + case TransportParameters::kDisableActiveMigration: + case TransportParameters::kPreferredAddress: + case TransportParameters::kActiveConnectionIdLimit: + case TransportParameters::kInitialSourceConnectionId: + case TransportParameters::kRetrySourceConnectionId: + case TransportParameters::kMaxDatagramFrameSize: + case TransportParameters::kGoogleHandshakeMessage: + case TransportParameters::kInitialRoundTripTime: + case TransportParameters::kGoogleConnectionOptions: + case TransportParameters::kGoogleQuicVersion: + case TransportParameters::kMinAckDelay: + case TransportParameters::kVersionInformation: + return true; + } + return false; +} + +} // namespace + +TransportParameters::IntegerParameter::IntegerParameter( + TransportParameters::TransportParameterId param_id, uint64_t default_value, + uint64_t min_value, uint64_t max_value) + : param_id_(param_id), + value_(default_value), + default_value_(default_value), + min_value_(min_value), + max_value_(max_value), + has_been_read_(false) { + QUICHE_DCHECK_LE(min_value, default_value); + QUICHE_DCHECK_LE(default_value, max_value); + QUICHE_DCHECK_LE(max_value, quiche::kVarInt62MaxValue); +} + +TransportParameters::IntegerParameter::IntegerParameter( + TransportParameters::TransportParameterId param_id) + : TransportParameters::IntegerParameter::IntegerParameter( + param_id, 0, 0, quiche::kVarInt62MaxValue) {} + +void TransportParameters::IntegerParameter::set_value(uint64_t value) { + value_ = value; +} + +uint64_t TransportParameters::IntegerParameter::value() const { return value_; } + +bool TransportParameters::IntegerParameter::IsValid() const { + return min_value_ <= value_ && value_ <= max_value_; +} + +bool TransportParameters::IntegerParameter::Write( + QuicDataWriter* writer) const { + QUICHE_DCHECK(IsValid()); + if (value_ == default_value_) { + // Do not write if the value is default. + return true; + } + if (!writer->WriteVarInt62(param_id_)) { + QUIC_BUG(quic_bug_10743_1) << "Failed to write param_id for " << *this; + return false; + } + const quiche::QuicheVariableLengthIntegerLength value_length = + QuicDataWriter::GetVarInt62Len(value_); + if (!writer->WriteVarInt62(value_length)) { + QUIC_BUG(quic_bug_10743_2) << "Failed to write value_length for " << *this; + return false; + } + if (!writer->WriteVarInt62WithForcedLength(value_, value_length)) { + QUIC_BUG(quic_bug_10743_3) << "Failed to write value for " << *this; + return false; + } + return true; +} + +bool TransportParameters::IntegerParameter::Read(QuicDataReader* reader, + std::string* error_details) { + if (has_been_read_) { + *error_details = + "Received a second " + TransportParameterIdToString(param_id_); + return false; + } + has_been_read_ = true; + + if (!reader->ReadVarInt62(&value_)) { + *error_details = + "Failed to parse value for " + TransportParameterIdToString(param_id_); + return false; + } + if (!reader->IsDoneReading()) { + *error_details = + absl::StrCat("Received unexpected ", reader->BytesRemaining(), + " bytes after parsing ", this->ToString(false)); + return false; + } + return true; +} + +std::string TransportParameters::IntegerParameter::ToString( + bool for_use_in_list) const { + if (for_use_in_list && value_ == default_value_) { + return ""; + } + std::string rv = for_use_in_list ? " " : ""; + absl::StrAppend(&rv, TransportParameterIdToString(param_id_), " ", value_); + if (!IsValid()) { + rv += " (Invalid)"; + } + return rv; +} + +std::ostream& operator<<(std::ostream& os, + const TransportParameters::IntegerParameter& param) { + os << param.ToString(/*for_use_in_list=*/false); + return os; +} + +TransportParameters::PreferredAddress::PreferredAddress() + : ipv4_socket_address(QuicIpAddress::Any4(), 0), + ipv6_socket_address(QuicIpAddress::Any6(), 0), + connection_id(EmptyQuicConnectionId()), + stateless_reset_token(kStatelessResetTokenLength, 0) {} + +TransportParameters::PreferredAddress::~PreferredAddress() {} + +bool TransportParameters::PreferredAddress::operator==( + const PreferredAddress& rhs) const { + return ipv4_socket_address == rhs.ipv4_socket_address && + ipv6_socket_address == rhs.ipv6_socket_address && + connection_id == rhs.connection_id && + stateless_reset_token == rhs.stateless_reset_token; +} + +bool TransportParameters::PreferredAddress::operator!=( + const PreferredAddress& rhs) const { + return !(*this == rhs); +} + +std::ostream& operator<<( + std::ostream& os, + const TransportParameters::PreferredAddress& preferred_address) { + os << preferred_address.ToString(); + return os; +} + +std::string TransportParameters::PreferredAddress::ToString() const { + return "[" + ipv4_socket_address.ToString() + " " + + ipv6_socket_address.ToString() + " connection_id " + + connection_id.ToString() + " stateless_reset_token " + + absl::BytesToHexString(absl::string_view( + reinterpret_cast(stateless_reset_token.data()), + stateless_reset_token.size())) + + "]"; +} + +TransportParameters::LegacyVersionInformation::LegacyVersionInformation() + : version(0) {} + +bool TransportParameters::LegacyVersionInformation::operator==( + const LegacyVersionInformation& rhs) const { + return version == rhs.version && supported_versions == rhs.supported_versions; +} + +bool TransportParameters::LegacyVersionInformation::operator!=( + const LegacyVersionInformation& rhs) const { + return !(*this == rhs); +} + +std::string TransportParameters::LegacyVersionInformation::ToString() const { + std::string rv = + absl::StrCat("legacy[version ", QuicVersionLabelToString(version)); + if (!supported_versions.empty()) { + absl::StrAppend(&rv, + " supported_versions " + + QuicVersionLabelVectorToString(supported_versions)); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::ostream& operator<<(std::ostream& os, + const TransportParameters::LegacyVersionInformation& + legacy_version_information) { + os << legacy_version_information.ToString(); + return os; +} + +TransportParameters::VersionInformation::VersionInformation() + : chosen_version(0) {} + +bool TransportParameters::VersionInformation::operator==( + const VersionInformation& rhs) const { + return chosen_version == rhs.chosen_version && + other_versions == rhs.other_versions; +} + +bool TransportParameters::VersionInformation::operator!=( + const VersionInformation& rhs) const { + return !(*this == rhs); +} + +std::string TransportParameters::VersionInformation::ToString() const { + std::string rv = absl::StrCat("[chosen_version ", + QuicVersionLabelToString(chosen_version)); + if (!other_versions.empty()) { + absl::StrAppend(&rv, " other_versions " + + QuicVersionLabelVectorToString(other_versions)); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::ostream& operator<<( + std::ostream& os, + const TransportParameters::VersionInformation& version_information) { + os << version_information.ToString(); + return os; +} + +std::ostream& operator<<(std::ostream& os, const TransportParameters& params) { + os << params.ToString(); + return os; +} + +std::string TransportParameters::ToString() const { + std::string rv = "["; + if (perspective == Perspective::IS_SERVER) { + rv += "Server"; + } else { + rv += "Client"; + } + if (legacy_version_information.has_value()) { + rv += " " + legacy_version_information.value().ToString(); + } + if (version_information.has_value()) { + rv += " " + version_information.value().ToString(); + } + if (original_destination_connection_id.has_value()) { + rv += " " + TransportParameterIdToString(kOriginalDestinationConnectionId) + + " " + original_destination_connection_id.value().ToString(); + } + rv += max_idle_timeout_ms.ToString(/*for_use_in_list=*/true); + if (!stateless_reset_token.empty()) { + rv += " " + TransportParameterIdToString(kStatelessResetToken) + " " + + absl::BytesToHexString(absl::string_view( + reinterpret_cast(stateless_reset_token.data()), + stateless_reset_token.size())); + } + rv += max_udp_payload_size.ToString(/*for_use_in_list=*/true); + rv += initial_max_data.ToString(/*for_use_in_list=*/true); + rv += initial_max_stream_data_bidi_local.ToString(/*for_use_in_list=*/true); + rv += initial_max_stream_data_bidi_remote.ToString(/*for_use_in_list=*/true); + rv += initial_max_stream_data_uni.ToString(/*for_use_in_list=*/true); + rv += initial_max_streams_bidi.ToString(/*for_use_in_list=*/true); + rv += initial_max_streams_uni.ToString(/*for_use_in_list=*/true); + rv += ack_delay_exponent.ToString(/*for_use_in_list=*/true); + rv += max_ack_delay.ToString(/*for_use_in_list=*/true); + rv += min_ack_delay_us.ToString(/*for_use_in_list=*/true); + if (disable_active_migration) { + rv += " " + TransportParameterIdToString(kDisableActiveMigration); + } + if (preferred_address) { + rv += " " + TransportParameterIdToString(kPreferredAddress) + " " + + preferred_address->ToString(); + } + rv += active_connection_id_limit.ToString(/*for_use_in_list=*/true); + if (initial_source_connection_id.has_value()) { + rv += " " + TransportParameterIdToString(kInitialSourceConnectionId) + " " + + initial_source_connection_id.value().ToString(); + } + if (retry_source_connection_id.has_value()) { + rv += " " + TransportParameterIdToString(kRetrySourceConnectionId) + " " + + retry_source_connection_id.value().ToString(); + } + rv += max_datagram_frame_size.ToString(/*for_use_in_list=*/true); + if (google_handshake_message.has_value()) { + absl::StrAppend(&rv, " ", + TransportParameterIdToString(kGoogleHandshakeMessage), + " length: ", google_handshake_message.value().length()); + } + rv += initial_round_trip_time_us.ToString(/*for_use_in_list=*/true); + if (google_connection_options.has_value()) { + rv += " " + TransportParameterIdToString(kGoogleConnectionOptions) + " "; + bool first = true; + for (const QuicTag& connection_option : google_connection_options.value()) { + if (first) { + first = false; + } else { + rv += ","; + } + rv += QuicTagToString(connection_option); + } + } + for (const auto& kv : custom_parameters) { + absl::StrAppend(&rv, " 0x", absl::Hex(static_cast(kv.first)), + "="); + static constexpr size_t kMaxPrintableLength = 32; + if (kv.second.length() <= kMaxPrintableLength) { + rv += absl::BytesToHexString(kv.second); + } else { + absl::string_view truncated(kv.second.data(), kMaxPrintableLength); + rv += absl::StrCat(absl::BytesToHexString(truncated), "...(length ", + kv.second.length(), ")"); + } + } + rv += "]"; + return rv; +} + +TransportParameters::TransportParameters() + : max_idle_timeout_ms(kMaxIdleTimeout), + max_udp_payload_size(kMaxPacketSize, kDefaultMaxPacketSizeTransportParam, + kMinMaxPacketSizeTransportParam, + quiche::kVarInt62MaxValue), + initial_max_data(kInitialMaxData), + initial_max_stream_data_bidi_local(kInitialMaxStreamDataBidiLocal), + initial_max_stream_data_bidi_remote(kInitialMaxStreamDataBidiRemote), + initial_max_stream_data_uni(kInitialMaxStreamDataUni), + initial_max_streams_bidi(kInitialMaxStreamsBidi), + initial_max_streams_uni(kInitialMaxStreamsUni), + ack_delay_exponent(kAckDelayExponent, + kDefaultAckDelayExponentTransportParam, 0, + kMaxAckDelayExponentTransportParam), + max_ack_delay(kMaxAckDelay, kDefaultMaxAckDelayTransportParam, 0, + kMaxMaxAckDelayTransportParam), + min_ack_delay_us(kMinAckDelay, 0, 0, + kMaxMaxAckDelayTransportParam * kNumMicrosPerMilli), + disable_active_migration(false), + active_connection_id_limit(kActiveConnectionIdLimit, + kDefaultActiveConnectionIdLimitTransportParam, + kMinActiveConnectionIdLimitTransportParam, + quiche::kVarInt62MaxValue), + max_datagram_frame_size(kMaxDatagramFrameSize), + initial_round_trip_time_us(kInitialRoundTripTime) +// Important note: any new transport parameters must be added +// to TransportParameters::AreValid, SerializeTransportParameters and +// ParseTransportParameters, TransportParameters's custom copy constructor, the +// operator==, and TransportParametersTest.Comparator. +{} + +TransportParameters::TransportParameters(const TransportParameters& other) + : perspective(other.perspective), + legacy_version_information(other.legacy_version_information), + version_information(other.version_information), + original_destination_connection_id( + other.original_destination_connection_id), + max_idle_timeout_ms(other.max_idle_timeout_ms), + stateless_reset_token(other.stateless_reset_token), + max_udp_payload_size(other.max_udp_payload_size), + initial_max_data(other.initial_max_data), + initial_max_stream_data_bidi_local( + other.initial_max_stream_data_bidi_local), + initial_max_stream_data_bidi_remote( + other.initial_max_stream_data_bidi_remote), + initial_max_stream_data_uni(other.initial_max_stream_data_uni), + initial_max_streams_bidi(other.initial_max_streams_bidi), + initial_max_streams_uni(other.initial_max_streams_uni), + ack_delay_exponent(other.ack_delay_exponent), + max_ack_delay(other.max_ack_delay), + min_ack_delay_us(other.min_ack_delay_us), + disable_active_migration(other.disable_active_migration), + active_connection_id_limit(other.active_connection_id_limit), + initial_source_connection_id(other.initial_source_connection_id), + retry_source_connection_id(other.retry_source_connection_id), + max_datagram_frame_size(other.max_datagram_frame_size), + initial_round_trip_time_us(other.initial_round_trip_time_us), + google_handshake_message(other.google_handshake_message), + google_connection_options(other.google_connection_options), + custom_parameters(other.custom_parameters) { + if (other.preferred_address) { + preferred_address = std::make_unique( + *other.preferred_address); + } +} + +bool TransportParameters::operator==(const TransportParameters& rhs) const { + if (!(perspective == rhs.perspective && + legacy_version_information == rhs.legacy_version_information && + version_information == rhs.version_information && + original_destination_connection_id == + rhs.original_destination_connection_id && + max_idle_timeout_ms.value() == rhs.max_idle_timeout_ms.value() && + stateless_reset_token == rhs.stateless_reset_token && + max_udp_payload_size.value() == rhs.max_udp_payload_size.value() && + initial_max_data.value() == rhs.initial_max_data.value() && + initial_max_stream_data_bidi_local.value() == + rhs.initial_max_stream_data_bidi_local.value() && + initial_max_stream_data_bidi_remote.value() == + rhs.initial_max_stream_data_bidi_remote.value() && + initial_max_stream_data_uni.value() == + rhs.initial_max_stream_data_uni.value() && + initial_max_streams_bidi.value() == + rhs.initial_max_streams_bidi.value() && + initial_max_streams_uni.value() == + rhs.initial_max_streams_uni.value() && + ack_delay_exponent.value() == rhs.ack_delay_exponent.value() && + max_ack_delay.value() == rhs.max_ack_delay.value() && + min_ack_delay_us.value() == rhs.min_ack_delay_us.value() && + disable_active_migration == rhs.disable_active_migration && + active_connection_id_limit.value() == + rhs.active_connection_id_limit.value() && + initial_source_connection_id == rhs.initial_source_connection_id && + retry_source_connection_id == rhs.retry_source_connection_id && + max_datagram_frame_size.value() == + rhs.max_datagram_frame_size.value() && + initial_round_trip_time_us.value() == + rhs.initial_round_trip_time_us.value() && + google_handshake_message == rhs.google_handshake_message && + google_connection_options == rhs.google_connection_options && + custom_parameters == rhs.custom_parameters)) { + return false; + } + + if ((!preferred_address && rhs.preferred_address) || + (preferred_address && !rhs.preferred_address)) { + return false; + } + if (preferred_address && rhs.preferred_address && + *preferred_address != *rhs.preferred_address) { + return false; + } + + return true; +} + +bool TransportParameters::operator!=(const TransportParameters& rhs) const { + return !(*this == rhs); +} + +bool TransportParameters::AreValid(std::string* error_details) const { + QUICHE_DCHECK(perspective == Perspective::IS_CLIENT || + perspective == Perspective::IS_SERVER); + if (perspective == Perspective::IS_CLIENT && !stateless_reset_token.empty()) { + *error_details = "Client cannot send stateless reset token"; + return false; + } + if (perspective == Perspective::IS_CLIENT && + original_destination_connection_id.has_value()) { + *error_details = "Client cannot send original_destination_connection_id"; + return false; + } + if (!stateless_reset_token.empty() && + stateless_reset_token.size() != kStatelessResetTokenLength) { + *error_details = absl::StrCat("Stateless reset token has bad length ", + stateless_reset_token.size()); + return false; + } + if (perspective == Perspective::IS_CLIENT && preferred_address) { + *error_details = "Client cannot send preferred address"; + return false; + } + if (preferred_address && preferred_address->stateless_reset_token.size() != + kStatelessResetTokenLength) { + *error_details = + absl::StrCat("Preferred address stateless reset token has bad length ", + preferred_address->stateless_reset_token.size()); + return false; + } + if (preferred_address && + (!preferred_address->ipv4_socket_address.host().IsIPv4() || + !preferred_address->ipv6_socket_address.host().IsIPv6())) { + QUIC_BUG(quic_bug_10743_4) << "Preferred address family failure"; + *error_details = "Internal preferred address family failure"; + return false; + } + if (perspective == Perspective::IS_CLIENT && + retry_source_connection_id.has_value()) { + *error_details = "Client cannot send retry_source_connection_id"; + return false; + } + for (const auto& kv : custom_parameters) { + if (TransportParameterIdIsKnown(kv.first)) { + *error_details = absl::StrCat("Using custom_parameters with known ID ", + TransportParameterIdToString(kv.first), + " is not allowed"); + return false; + } + } + if (perspective == Perspective::IS_SERVER && + google_handshake_message.has_value()) { + *error_details = "Server cannot send google_handshake_message"; + return false; + } + if (perspective == Perspective::IS_SERVER && + initial_round_trip_time_us.value() > 0) { + *error_details = "Server cannot send initial round trip time"; + return false; + } + if (version_information.has_value()) { + const QuicVersionLabel& chosen_version = + version_information.value().chosen_version; + const QuicVersionLabelVector& other_versions = + version_information.value().other_versions; + if (chosen_version == 0) { + *error_details = "Invalid chosen version"; + return false; + } + if (perspective == Perspective::IS_CLIENT && + std::find(other_versions.begin(), other_versions.end(), + chosen_version) == other_versions.end()) { + // When sent by the client, chosen_version needs to be present in + // other_versions because other_versions lists the compatible versions and + // the chosen version is part of that list. When sent by the server, + // other_version contains the list of fully-deployed versions which is + // generally equal to the list of supported versions but can slightly + // differ during removal of versions across a server fleet. See + // draft-ietf-quic-version-negotiation for details. + *error_details = "Client chosen version not in other versions"; + return false; + } + } + const bool ok = + max_idle_timeout_ms.IsValid() && max_udp_payload_size.IsValid() && + initial_max_data.IsValid() && + initial_max_stream_data_bidi_local.IsValid() && + initial_max_stream_data_bidi_remote.IsValid() && + initial_max_stream_data_uni.IsValid() && + initial_max_streams_bidi.IsValid() && initial_max_streams_uni.IsValid() && + ack_delay_exponent.IsValid() && max_ack_delay.IsValid() && + min_ack_delay_us.IsValid() && active_connection_id_limit.IsValid() && + max_datagram_frame_size.IsValid() && initial_round_trip_time_us.IsValid(); + if (!ok) { + *error_details = "Invalid transport parameters " + this->ToString(); + } + return ok; +} + +TransportParameters::~TransportParameters() = default; + +bool SerializeTransportParameters(const TransportParameters& in, + std::vector* out) { + std::string error_details; + if (!in.AreValid(&error_details)) { + QUIC_BUG(invalid transport parameters) + << "Not serializing invalid transport parameters: " << error_details; + return false; + } + if (!in.legacy_version_information.has_value() || + in.legacy_version_information.value().version == 0 || + (in.perspective == Perspective::IS_SERVER && + in.legacy_version_information.value().supported_versions.empty())) { + QUIC_BUG(missing versions) << "Refusing to serialize without versions"; + return false; + } + TransportParameters::ParameterMap custom_parameters = in.custom_parameters; + for (const auto& kv : custom_parameters) { + if (kv.first % 31 == 27) { + // See the "Reserved Transport Parameters" section of RFC 9000. + QUIC_BUG(custom_parameters with GREASE) + << "Serializing custom_parameters with GREASE ID " << kv.first + << " is not allowed"; + return false; + } + } + + // Maximum length of the GREASE transport parameter (see below). + static constexpr size_t kMaxGreaseLength = 16; + + // Empirically transport parameters generally fit within 128 bytes, but we + // need to allocate the size up front. Integer transport parameters + // have a maximum encoded length of 24 bytes (3 variable length integers), + // other transport parameters have a length of 16 + the maximum value length. + static constexpr size_t kTypeAndValueLength = 2 * sizeof(uint64_t); + static constexpr size_t kIntegerParameterLength = + kTypeAndValueLength + sizeof(uint64_t); + static constexpr size_t kStatelessResetParameterLength = + kTypeAndValueLength + 16 /* stateless reset token length */; + static constexpr size_t kConnectionIdParameterLength = + kTypeAndValueLength + 255 /* maximum connection ID length */; + static constexpr size_t kPreferredAddressParameterLength = + kTypeAndValueLength + 4 /*IPv4 address */ + 2 /* IPv4 port */ + + 16 /* IPv6 address */ + 1 /* Connection ID length */ + + 255 /* maximum connection ID length */ + 16 /* stateless reset token */; + static constexpr size_t kKnownTransportParamLength = + kConnectionIdParameterLength + // original_destination_connection_id + kIntegerParameterLength + // max_idle_timeout + kStatelessResetParameterLength + // stateless_reset_token + kIntegerParameterLength + // max_udp_payload_size + kIntegerParameterLength + // initial_max_data + kIntegerParameterLength + // initial_max_stream_data_bidi_local + kIntegerParameterLength + // initial_max_stream_data_bidi_remote + kIntegerParameterLength + // initial_max_stream_data_uni + kIntegerParameterLength + // initial_max_streams_bidi + kIntegerParameterLength + // initial_max_streams_uni + kIntegerParameterLength + // ack_delay_exponent + kIntegerParameterLength + // max_ack_delay + kIntegerParameterLength + // min_ack_delay_us + kTypeAndValueLength + // disable_active_migration + kPreferredAddressParameterLength + // preferred_address + kIntegerParameterLength + // active_connection_id_limit + kConnectionIdParameterLength + // initial_source_connection_id + kConnectionIdParameterLength + // retry_source_connection_id + kIntegerParameterLength + // max_datagram_frame_size + kIntegerParameterLength + // initial_round_trip_time_us + kTypeAndValueLength + // google_handshake_message + kTypeAndValueLength + // google_connection_options + kTypeAndValueLength; // google-version + + std::vector parameter_ids = { + TransportParameters::kOriginalDestinationConnectionId, + TransportParameters::kMaxIdleTimeout, + TransportParameters::kStatelessResetToken, + TransportParameters::kMaxPacketSize, + TransportParameters::kInitialMaxData, + TransportParameters::kInitialMaxStreamDataBidiLocal, + TransportParameters::kInitialMaxStreamDataBidiRemote, + TransportParameters::kInitialMaxStreamDataUni, + TransportParameters::kInitialMaxStreamsBidi, + TransportParameters::kInitialMaxStreamsUni, + TransportParameters::kAckDelayExponent, + TransportParameters::kMaxAckDelay, + TransportParameters::kMinAckDelay, + TransportParameters::kActiveConnectionIdLimit, + TransportParameters::kMaxDatagramFrameSize, + TransportParameters::kGoogleHandshakeMessage, + TransportParameters::kInitialRoundTripTime, + TransportParameters::kDisableActiveMigration, + TransportParameters::kPreferredAddress, + TransportParameters::kInitialSourceConnectionId, + TransportParameters::kRetrySourceConnectionId, + TransportParameters::kGoogleConnectionOptions, + TransportParameters::kGoogleQuicVersion, + TransportParameters::kVersionInformation, + }; + + size_t max_transport_param_length = kKnownTransportParamLength; + // google_connection_options. + if (in.google_connection_options.has_value()) { + max_transport_param_length += + in.google_connection_options.value().size() * sizeof(QuicTag); + } + // Google-specific version extension. + if (in.legacy_version_information.has_value()) { + max_transport_param_length += + sizeof(in.legacy_version_information.value().version) + + 1 /* versions length */ + + in.legacy_version_information.value().supported_versions.size() * + sizeof(QuicVersionLabel); + } + // version_information. + if (in.version_information.has_value()) { + max_transport_param_length += + sizeof(in.version_information.value().chosen_version) + + // Add one for the added GREASE version. + (in.version_information.value().other_versions.size() + 1) * + sizeof(QuicVersionLabel); + } + // google_handshake_message. + if (in.google_handshake_message.has_value()) { + max_transport_param_length += in.google_handshake_message.value().length(); + } + + // Add a random GREASE transport parameter, as defined in the + // "Reserved Transport Parameters" section of RFC 9000. + // This forces receivers to support unexpected input. + QuicRandom* random = QuicRandom::GetInstance(); + // Transport parameter identifiers are 62 bits long so we need to + // ensure that the output of the computation below fits in 62 bits. + uint64_t grease_id64 = random->RandUint64() % ((1ULL << 62) - 31); + // Make sure grease_id % 31 == 27. Note that this is not uniformely + // distributed but is acceptable since no security depends on this + // randomness. + grease_id64 = (grease_id64 / 31) * 31 + 27; + TransportParameters::TransportParameterId grease_id = + static_cast(grease_id64); + const size_t grease_length = random->RandUint64() % kMaxGreaseLength; + QUICHE_DCHECK_GE(kMaxGreaseLength, grease_length); + char grease_contents[kMaxGreaseLength]; + random->RandBytes(grease_contents, grease_length); + custom_parameters[grease_id] = std::string(grease_contents, grease_length); + + // Custom parameters. + for (const auto& kv : custom_parameters) { + max_transport_param_length += kTypeAndValueLength + kv.second.length(); + parameter_ids.push_back(kv.first); + } + + // Randomize order of sent transport parameters by walking the array + // backwards and swapping each element with a random earlier one. + for (size_t i = parameter_ids.size() - 1; i > 0; i--) { + std::swap(parameter_ids[i], + parameter_ids[random->InsecureRandUint64() % (i + 1)]); + } + + out->resize(max_transport_param_length); + QuicDataWriter writer(out->size(), reinterpret_cast(out->data())); + + for (TransportParameters::TransportParameterId parameter_id : parameter_ids) { + switch (parameter_id) { + // original_destination_connection_id + case TransportParameters::kOriginalDestinationConnectionId: { + if (in.original_destination_connection_id.has_value()) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); + QuicConnectionId original_destination_connection_id = + in.original_destination_connection_id.value(); + if (!writer.WriteVarInt62( + TransportParameters::kOriginalDestinationConnectionId) || + !writer.WriteStringPieceVarInt62(absl::string_view( + original_destination_connection_id.data(), + original_destination_connection_id.length()))) { + QUIC_BUG(Failed to write original_destination_connection_id) + << "Failed to write original_destination_connection_id " + << original_destination_connection_id << " for " << in; + return false; + } + } + } break; + // max_idle_timeout + case TransportParameters::kMaxIdleTimeout: { + if (!in.max_idle_timeout_ms.Write(&writer)) { + QUIC_BUG(Failed to write idle_timeout) + << "Failed to write idle_timeout for " << in; + return false; + } + } break; + // stateless_reset_token + case TransportParameters::kStatelessResetToken: { + if (!in.stateless_reset_token.empty()) { + QUICHE_DCHECK_EQ(kStatelessResetTokenLength, + in.stateless_reset_token.size()); + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); + if (!writer.WriteVarInt62( + TransportParameters::kStatelessResetToken) || + !writer.WriteStringPieceVarInt62( + absl::string_view(reinterpret_cast( + in.stateless_reset_token.data()), + in.stateless_reset_token.size()))) { + QUIC_BUG(Failed to write stateless_reset_token) + << "Failed to write stateless_reset_token of length " + << in.stateless_reset_token.size() << " for " << in; + return false; + } + } + } break; + // max_udp_payload_size + case TransportParameters::kMaxPacketSize: { + if (!in.max_udp_payload_size.Write(&writer)) { + QUIC_BUG(Failed to write max_udp_payload_size) + << "Failed to write max_udp_payload_size for " << in; + return false; + } + } break; + // initial_max_data + case TransportParameters::kInitialMaxData: { + if (!in.initial_max_data.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_data) + << "Failed to write initial_max_data for " << in; + return false; + } + } break; + // initial_max_stream_data_bidi_local + case TransportParameters::kInitialMaxStreamDataBidiLocal: { + if (!in.initial_max_stream_data_bidi_local.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_stream_data_bidi_local) + << "Failed to write initial_max_stream_data_bidi_local for " + << in; + return false; + } + } break; + // initial_max_stream_data_bidi_remote + case TransportParameters::kInitialMaxStreamDataBidiRemote: { + if (!in.initial_max_stream_data_bidi_remote.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_stream_data_bidi_remote) + << "Failed to write initial_max_stream_data_bidi_remote for " + << in; + return false; + } + } break; + // initial_max_stream_data_uni + case TransportParameters::kInitialMaxStreamDataUni: { + if (!in.initial_max_stream_data_uni.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_stream_data_uni) + << "Failed to write initial_max_stream_data_uni for " << in; + return false; + } + } break; + // initial_max_streams_bidi + case TransportParameters::kInitialMaxStreamsBidi: { + if (!in.initial_max_streams_bidi.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_streams_bidi) + << "Failed to write initial_max_streams_bidi for " << in; + return false; + } + } break; + // initial_max_streams_uni + case TransportParameters::kInitialMaxStreamsUni: { + if (!in.initial_max_streams_uni.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_streams_uni) + << "Failed to write initial_max_streams_uni for " << in; + return false; + } + } break; + // ack_delay_exponent + case TransportParameters::kAckDelayExponent: { + if (!in.ack_delay_exponent.Write(&writer)) { + QUIC_BUG(Failed to write ack_delay_exponent) + << "Failed to write ack_delay_exponent for " << in; + return false; + } + } break; + // max_ack_delay + case TransportParameters::kMaxAckDelay: { + if (!in.max_ack_delay.Write(&writer)) { + QUIC_BUG(Failed to write max_ack_delay) + << "Failed to write max_ack_delay for " << in; + return false; + } + } break; + // min_ack_delay_us + case TransportParameters::kMinAckDelay: { + if (!in.min_ack_delay_us.Write(&writer)) { + QUIC_BUG(Failed to write min_ack_delay_us) + << "Failed to write min_ack_delay_us for " << in; + return false; + } + } break; + // active_connection_id_limit + case TransportParameters::kActiveConnectionIdLimit: { + if (!in.active_connection_id_limit.Write(&writer)) { + QUIC_BUG(Failed to write active_connection_id_limit) + << "Failed to write active_connection_id_limit for " << in; + return false; + } + } break; + // max_datagram_frame_size + case TransportParameters::kMaxDatagramFrameSize: { + if (!in.max_datagram_frame_size.Write(&writer)) { + QUIC_BUG(Failed to write max_datagram_frame_size) + << "Failed to write max_datagram_frame_size for " << in; + return false; + } + } break; + // google_handshake_message + case TransportParameters::kGoogleHandshakeMessage: { + if (in.google_handshake_message.has_value()) { + if (!writer.WriteVarInt62( + TransportParameters::kGoogleHandshakeMessage) || + !writer.WriteStringPieceVarInt62( + in.google_handshake_message.value())) { + QUIC_BUG(Failed to write google_handshake_message) + << "Failed to write google_handshake_message: " + << in.google_handshake_message.value() << " for " << in; + return false; + } + } + } break; + // initial_round_trip_time_us + case TransportParameters::kInitialRoundTripTime: { + if (!in.initial_round_trip_time_us.Write(&writer)) { + QUIC_BUG(Failed to write initial_round_trip_time_us) + << "Failed to write initial_round_trip_time_us for " << in; + return false; + } + } break; + // disable_active_migration + case TransportParameters::kDisableActiveMigration: { + if (in.disable_active_migration) { + if (!writer.WriteVarInt62( + TransportParameters::kDisableActiveMigration) || + !writer.WriteVarInt62(/* transport parameter length */ 0)) { + QUIC_BUG(Failed to write disable_active_migration) + << "Failed to write disable_active_migration for " << in; + return false; + } + } + } break; + // preferred_address + case TransportParameters::kPreferredAddress: { + if (in.preferred_address) { + std::string v4_address_bytes = + in.preferred_address->ipv4_socket_address.host().ToPackedString(); + std::string v6_address_bytes = + in.preferred_address->ipv6_socket_address.host().ToPackedString(); + if (v4_address_bytes.length() != 4 || + v6_address_bytes.length() != 16 || + in.preferred_address->stateless_reset_token.size() != + kStatelessResetTokenLength) { + QUIC_BUG(quic_bug_10743_12) + << "Bad lengths " << *in.preferred_address; + return false; + } + const uint64_t preferred_address_length = + v4_address_bytes.length() + /* IPv4 port */ sizeof(uint16_t) + + v6_address_bytes.length() + /* IPv6 port */ sizeof(uint16_t) + + /* connection ID length byte */ sizeof(uint8_t) + + in.preferred_address->connection_id.length() + + in.preferred_address->stateless_reset_token.size(); + if (!writer.WriteVarInt62(TransportParameters::kPreferredAddress) || + !writer.WriteVarInt62( + /* transport parameter length */ preferred_address_length) || + !writer.WriteStringPiece(v4_address_bytes) || + !writer.WriteUInt16( + in.preferred_address->ipv4_socket_address.port()) || + !writer.WriteStringPiece(v6_address_bytes) || + !writer.WriteUInt16( + in.preferred_address->ipv6_socket_address.port()) || + !writer.WriteUInt8( + in.preferred_address->connection_id.length()) || + !writer.WriteBytes( + in.preferred_address->connection_id.data(), + in.preferred_address->connection_id.length()) || + !writer.WriteBytes( + in.preferred_address->stateless_reset_token.data(), + in.preferred_address->stateless_reset_token.size())) { + QUIC_BUG(Failed to write preferred_address) + << "Failed to write preferred_address for " << in; + return false; + } + } + } break; + // initial_source_connection_id + case TransportParameters::kInitialSourceConnectionId: { + if (in.initial_source_connection_id.has_value()) { + QuicConnectionId initial_source_connection_id = + in.initial_source_connection_id.value(); + if (!writer.WriteVarInt62( + TransportParameters::kInitialSourceConnectionId) || + !writer.WriteStringPieceVarInt62( + absl::string_view(initial_source_connection_id.data(), + initial_source_connection_id.length()))) { + QUIC_BUG(Failed to write initial_source_connection_id) + << "Failed to write initial_source_connection_id " + << initial_source_connection_id << " for " << in; + return false; + } + } + } break; + // retry_source_connection_id + case TransportParameters::kRetrySourceConnectionId: { + if (in.retry_source_connection_id.has_value()) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); + QuicConnectionId retry_source_connection_id = + in.retry_source_connection_id.value(); + if (!writer.WriteVarInt62( + TransportParameters::kRetrySourceConnectionId) || + !writer.WriteStringPieceVarInt62( + absl::string_view(retry_source_connection_id.data(), + retry_source_connection_id.length()))) { + QUIC_BUG(Failed to write retry_source_connection_id) + << "Failed to write retry_source_connection_id " + << retry_source_connection_id << " for " << in; + return false; + } + } + } break; + // Google-specific connection options. + case TransportParameters::kGoogleConnectionOptions: { + if (in.google_connection_options.has_value()) { + static_assert( + sizeof(in.google_connection_options.value().front()) == 4, + "bad size"); + uint64_t connection_options_length = + in.google_connection_options.value().size() * 4; + if (!writer.WriteVarInt62( + TransportParameters::kGoogleConnectionOptions) || + !writer.WriteVarInt62( + /* transport parameter length */ connection_options_length)) { + QUIC_BUG(Failed to write google_connection_options) + << "Failed to write google_connection_options of length " + << connection_options_length << " for " << in; + return false; + } + for (const QuicTag& connection_option : + in.google_connection_options.value()) { + if (!writer.WriteTag(connection_option)) { + QUIC_BUG(Failed to write google_connection_option) + << "Failed to write google_connection_option " + << QuicTagToString(connection_option) << " for " << in; + return false; + } + } + } + } break; + // Google-specific version extension. + case TransportParameters::kGoogleQuicVersion: { + if (!in.legacy_version_information.has_value()) { + break; + } + static_assert(sizeof(QuicVersionLabel) == sizeof(uint32_t), + "bad length"); + uint64_t google_version_length = + sizeof(in.legacy_version_information.value().version); + if (in.perspective == Perspective::IS_SERVER) { + google_version_length += + /* versions length */ sizeof(uint8_t) + + sizeof(QuicVersionLabel) * in.legacy_version_information.value() + .supported_versions.size(); + } + if (!writer.WriteVarInt62(TransportParameters::kGoogleQuicVersion) || + !writer.WriteVarInt62( + /* transport parameter length */ google_version_length) || + !writer.WriteUInt32( + in.legacy_version_information.value().version)) { + QUIC_BUG(Failed to write Google version extension) + << "Failed to write Google version extension for " << in; + return false; + } + if (in.perspective == Perspective::IS_SERVER) { + if (!writer.WriteUInt8(sizeof(QuicVersionLabel) * + in.legacy_version_information.value() + .supported_versions.size())) { + QUIC_BUG(Failed to write versions length) + << "Failed to write versions length for " << in; + return false; + } + for (QuicVersionLabel version_label : + in.legacy_version_information.value().supported_versions) { + if (!writer.WriteUInt32(version_label)) { + QUIC_BUG(Failed to write supported version) + << "Failed to write supported version for " << in; + return false; + } + } + } + } break; + // version_information. + case TransportParameters::kVersionInformation: { + if (!in.version_information.has_value()) { + break; + } + static_assert(sizeof(QuicVersionLabel) == sizeof(uint32_t), + "bad length"); + QuicVersionLabelVector other_versions = + in.version_information.value().other_versions; + // Insert one GREASE version at a random index. + const size_t grease_index = + random->InsecureRandUint64() % (other_versions.size() + 1); + other_versions.insert( + other_versions.begin() + grease_index, + CreateQuicVersionLabel(QuicVersionReservedForNegotiation())); + const uint64_t version_information_length = + sizeof(in.version_information.value().chosen_version) + + sizeof(QuicVersionLabel) * other_versions.size(); + if (!writer.WriteVarInt62(TransportParameters::kVersionInformation) || + !writer.WriteVarInt62( + /* transport parameter length */ version_information_length) || + !writer.WriteUInt32( + in.version_information.value().chosen_version)) { + QUIC_BUG(Failed to write chosen version) + << "Failed to write chosen version for " << in; + return false; + } + for (QuicVersionLabel version_label : other_versions) { + if (!writer.WriteUInt32(version_label)) { + QUIC_BUG(Failed to write other version) + << "Failed to write other version for " << in; + return false; + } + } + } break; + // Custom parameters and GREASE. + default: { + auto it = custom_parameters.find(parameter_id); + if (it == custom_parameters.end()) { + QUIC_BUG(Unknown parameter) << "Unknown parameter " << parameter_id; + return false; + } + if (!writer.WriteVarInt62(parameter_id) || + !writer.WriteStringPieceVarInt62(it->second)) { + QUIC_BUG(Failed to write custom parameter) + << "Failed to write custom parameter " << parameter_id; + return false; + } + } break; + } + } + + out->resize(writer.length()); + + QUIC_DLOG(INFO) << "Serialized " << in << " as " << writer.length() + << " bytes"; + + return true; +} + +bool ParseTransportParameters(ParsedQuicVersion version, + Perspective perspective, const uint8_t* in, + size_t in_len, TransportParameters* out, + std::string* error_details) { + out->perspective = perspective; + QuicDataReader reader(reinterpret_cast(in), in_len); + + while (!reader.IsDoneReading()) { + uint64_t param_id64; + if (!reader.ReadVarInt62(¶m_id64)) { + *error_details = "Failed to parse transport parameter ID"; + return false; + } + TransportParameters::TransportParameterId param_id = + static_cast(param_id64); + absl::string_view value; + if (!reader.ReadStringPieceVarInt62(&value)) { + *error_details = + "Failed to read length and value of transport parameter " + + TransportParameterIdToString(param_id); + return false; + } + QuicDataReader value_reader(value); + bool parse_success = true; + switch (param_id) { + case TransportParameters::kOriginalDestinationConnectionId: { + if (out->original_destination_connection_id.has_value()) { + *error_details = + "Received a second original_destination_connection_id"; + return false; + } + const size_t connection_id_length = value_reader.BytesRemaining(); + if (!QuicUtils::IsConnectionIdLengthValidForVersion( + connection_id_length, version.transport_version)) { + *error_details = absl::StrCat( + "Received original_destination_connection_id of invalid length ", + connection_id_length); + return false; + } + QuicConnectionId original_destination_connection_id; + if (!value_reader.ReadConnectionId(&original_destination_connection_id, + connection_id_length)) { + *error_details = "Failed to read original_destination_connection_id"; + return false; + } + out->original_destination_connection_id = + original_destination_connection_id; + } break; + case TransportParameters::kMaxIdleTimeout: + parse_success = + out->max_idle_timeout_ms.Read(&value_reader, error_details); + break; + case TransportParameters::kStatelessResetToken: { + if (!out->stateless_reset_token.empty()) { + *error_details = "Received a second stateless_reset_token"; + return false; + } + absl::string_view stateless_reset_token = + value_reader.ReadRemainingPayload(); + if (stateless_reset_token.length() != kStatelessResetTokenLength) { + *error_details = + absl::StrCat("Received stateless_reset_token of invalid length ", + stateless_reset_token.length()); + return false; + } + out->stateless_reset_token.assign( + stateless_reset_token.data(), + stateless_reset_token.data() + stateless_reset_token.length()); + } break; + case TransportParameters::kMaxPacketSize: + parse_success = + out->max_udp_payload_size.Read(&value_reader, error_details); + break; + case TransportParameters::kInitialMaxData: + parse_success = + out->initial_max_data.Read(&value_reader, error_details); + break; + case TransportParameters::kInitialMaxStreamDataBidiLocal: + parse_success = out->initial_max_stream_data_bidi_local.Read( + &value_reader, error_details); + break; + case TransportParameters::kInitialMaxStreamDataBidiRemote: + parse_success = out->initial_max_stream_data_bidi_remote.Read( + &value_reader, error_details); + break; + case TransportParameters::kInitialMaxStreamDataUni: + parse_success = + out->initial_max_stream_data_uni.Read(&value_reader, error_details); + break; + case TransportParameters::kInitialMaxStreamsBidi: + parse_success = + out->initial_max_streams_bidi.Read(&value_reader, error_details); + break; + case TransportParameters::kInitialMaxStreamsUni: + parse_success = + out->initial_max_streams_uni.Read(&value_reader, error_details); + break; + case TransportParameters::kAckDelayExponent: + parse_success = + out->ack_delay_exponent.Read(&value_reader, error_details); + break; + case TransportParameters::kMaxAckDelay: + parse_success = out->max_ack_delay.Read(&value_reader, error_details); + break; + case TransportParameters::kDisableActiveMigration: + if (out->disable_active_migration) { + *error_details = "Received a second disable_active_migration"; + return false; + } + out->disable_active_migration = true; + break; + case TransportParameters::kPreferredAddress: { + TransportParameters::PreferredAddress preferred_address; + uint16_t ipv4_port, ipv6_port; + in_addr ipv4_address; + in6_addr ipv6_address; + preferred_address.stateless_reset_token.resize( + kStatelessResetTokenLength); + if (!value_reader.ReadBytes(&ipv4_address, sizeof(ipv4_address)) || + !value_reader.ReadUInt16(&ipv4_port) || + !value_reader.ReadBytes(&ipv6_address, sizeof(ipv6_address)) || + !value_reader.ReadUInt16(&ipv6_port) || + !value_reader.ReadLengthPrefixedConnectionId( + &preferred_address.connection_id) || + !value_reader.ReadBytes(&preferred_address.stateless_reset_token[0], + kStatelessResetTokenLength)) { + *error_details = "Failed to read preferred_address"; + return false; + } + preferred_address.ipv4_socket_address = + QuicSocketAddress(QuicIpAddress(ipv4_address), ipv4_port); + preferred_address.ipv6_socket_address = + QuicSocketAddress(QuicIpAddress(ipv6_address), ipv6_port); + if (!preferred_address.ipv4_socket_address.host().IsIPv4() || + !preferred_address.ipv6_socket_address.host().IsIPv6()) { + *error_details = "Received preferred_address of bad families " + + preferred_address.ToString(); + return false; + } + if (!QuicUtils::IsConnectionIdValidForVersion( + preferred_address.connection_id, version.transport_version)) { + *error_details = "Received invalid preferred_address connection ID " + + preferred_address.ToString(); + return false; + } + out->preferred_address = + std::make_unique( + preferred_address); + } break; + case TransportParameters::kActiveConnectionIdLimit: + parse_success = + out->active_connection_id_limit.Read(&value_reader, error_details); + break; + case TransportParameters::kInitialSourceConnectionId: { + if (out->initial_source_connection_id.has_value()) { + *error_details = "Received a second initial_source_connection_id"; + return false; + } + const size_t connection_id_length = value_reader.BytesRemaining(); + if (!QuicUtils::IsConnectionIdLengthValidForVersion( + connection_id_length, version.transport_version)) { + *error_details = absl::StrCat( + "Received initial_source_connection_id of invalid length ", + connection_id_length); + return false; + } + QuicConnectionId initial_source_connection_id; + if (!value_reader.ReadConnectionId(&initial_source_connection_id, + connection_id_length)) { + *error_details = "Failed to read initial_source_connection_id"; + return false; + } + out->initial_source_connection_id = initial_source_connection_id; + } break; + case TransportParameters::kRetrySourceConnectionId: { + if (out->retry_source_connection_id.has_value()) { + *error_details = "Received a second retry_source_connection_id"; + return false; + } + const size_t connection_id_length = value_reader.BytesRemaining(); + if (!QuicUtils::IsConnectionIdLengthValidForVersion( + connection_id_length, version.transport_version)) { + *error_details = absl::StrCat( + "Received retry_source_connection_id of invalid length ", + connection_id_length); + return false; + } + QuicConnectionId retry_source_connection_id; + if (!value_reader.ReadConnectionId(&retry_source_connection_id, + connection_id_length)) { + *error_details = "Failed to read retry_source_connection_id"; + return false; + } + out->retry_source_connection_id = retry_source_connection_id; + } break; + case TransportParameters::kMaxDatagramFrameSize: + parse_success = + out->max_datagram_frame_size.Read(&value_reader, error_details); + break; + case TransportParameters::kGoogleHandshakeMessage: + if (out->google_handshake_message.has_value()) { + *error_details = "Received a second google_handshake_message"; + return false; + } + out->google_handshake_message = + std::string(value_reader.ReadRemainingPayload()); + break; + case TransportParameters::kInitialRoundTripTime: + parse_success = + out->initial_round_trip_time_us.Read(&value_reader, error_details); + break; + case TransportParameters::kGoogleConnectionOptions: { + if (out->google_connection_options.has_value()) { + *error_details = "Received a second google_connection_options"; + return false; + } + out->google_connection_options = QuicTagVector{}; + while (!value_reader.IsDoneReading()) { + QuicTag connection_option; + if (!value_reader.ReadTag(&connection_option)) { + *error_details = "Failed to read a google_connection_options"; + return false; + } + out->google_connection_options.value().push_back(connection_option); + } + } break; + case TransportParameters::kGoogleQuicVersion: { + if (!out->legacy_version_information.has_value()) { + out->legacy_version_information = + TransportParameters::LegacyVersionInformation(); + } + if (!value_reader.ReadUInt32( + &out->legacy_version_information.value().version)) { + *error_details = "Failed to read Google version extension version"; + return false; + } + if (perspective == Perspective::IS_SERVER) { + uint8_t versions_length; + if (!value_reader.ReadUInt8(&versions_length)) { + *error_details = "Failed to parse Google supported versions length"; + return false; + } + const uint8_t num_versions = versions_length / sizeof(uint32_t); + for (uint8_t i = 0; i < num_versions; ++i) { + QuicVersionLabel version; + if (!value_reader.ReadUInt32(&version)) { + *error_details = "Failed to parse Google supported version"; + return false; + } + out->legacy_version_information.value() + .supported_versions.push_back(version); + } + } + } break; + case TransportParameters::kVersionInformation: { + if (out->version_information.has_value()) { + *error_details = "Received a second version_information"; + return false; + } + out->version_information = TransportParameters::VersionInformation(); + if (!value_reader.ReadUInt32( + &out->version_information.value().chosen_version)) { + *error_details = "Failed to read chosen version"; + return false; + } + while (!value_reader.IsDoneReading()) { + QuicVersionLabel other_version; + if (!value_reader.ReadUInt32(&other_version)) { + *error_details = "Failed to parse other version"; + return false; + } + out->version_information.value().other_versions.push_back( + other_version); + } + } break; + case TransportParameters::kMinAckDelay: + parse_success = + out->min_ack_delay_us.Read(&value_reader, error_details); + break; + default: + if (out->custom_parameters.find(param_id) != + out->custom_parameters.end()) { + *error_details = "Received a second unknown parameter" + + TransportParameterIdToString(param_id); + return false; + } + out->custom_parameters[param_id] = + std::string(value_reader.ReadRemainingPayload()); + break; + } + if (!parse_success) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + if (!value_reader.IsDoneReading()) { + *error_details = absl::StrCat( + "Received unexpected ", value_reader.BytesRemaining(), + " bytes after parsing ", TransportParameterIdToString(param_id)); + return false; + } + } + + if (!out->AreValid(error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + + QUIC_DLOG(INFO) << "Parsed transport parameters " << *out << " from " + << in_len << " bytes"; + + return true; +} + +namespace { + +bool DigestUpdateIntegerParam( + EVP_MD_CTX* hash_ctx, const TransportParameters::IntegerParameter& param) { + uint64_t value = param.value(); + return EVP_DigestUpdate(hash_ctx, &value, sizeof(value)); +} + +} // namespace + +bool SerializeTransportParametersForTicket( + const TransportParameters& in, const std::vector& application_data, + std::vector* out) { + std::string error_details; + if (!in.AreValid(&error_details)) { + QUIC_BUG(quic_bug_10743_26) + << "Not serializing invalid transport parameters: " << error_details; + return false; + } + + out->resize(SHA256_DIGEST_LENGTH + 1); + const uint8_t serialization_version = 0; + (*out)[0] = serialization_version; + + bssl::ScopedEVP_MD_CTX hash_ctx; + // Write application data: + uint64_t app_data_len = application_data.size(); + const uint64_t parameter_version = 0; + // The format of the input to the hash function is as follows: + // - The application data, prefixed with a 64-bit length field. + // - Transport parameters: + // - A 64-bit version field indicating which version of encoding is used + // for transport parameters. + // - A list of 64-bit integers representing the relevant parameters. + // + // When changing which parameters are included, additional parameters can be + // added to the end of the list without changing the version field. New + // parameters that are variable length must be length prefixed. If + // parameters are removed from the list, the version field must be + // incremented. + // + // Integers happen to be written in host byte order, not network byte order. + if (!EVP_DigestInit(hash_ctx.get(), EVP_sha256()) || + !EVP_DigestUpdate(hash_ctx.get(), &app_data_len, sizeof(app_data_len)) || + !EVP_DigestUpdate(hash_ctx.get(), application_data.data(), + application_data.size()) || + !EVP_DigestUpdate(hash_ctx.get(), ¶meter_version, + sizeof(parameter_version))) { + QUIC_BUG(quic_bug_10743_27) + << "Unexpected failure of EVP_Digest functions when hashing " + "Transport Parameters for ticket"; + return false; + } + + // Write transport parameters specified by draft-ietf-quic-transport-28, + // section 7.4.1, that are remembered for 0-RTT. + if (!DigestUpdateIntegerParam(hash_ctx.get(), in.initial_max_data) || + !DigestUpdateIntegerParam(hash_ctx.get(), + in.initial_max_stream_data_bidi_local) || + !DigestUpdateIntegerParam(hash_ctx.get(), + in.initial_max_stream_data_bidi_remote) || + !DigestUpdateIntegerParam(hash_ctx.get(), + in.initial_max_stream_data_uni) || + !DigestUpdateIntegerParam(hash_ctx.get(), in.initial_max_streams_bidi) || + !DigestUpdateIntegerParam(hash_ctx.get(), in.initial_max_streams_uni) || + !DigestUpdateIntegerParam(hash_ctx.get(), + in.active_connection_id_limit)) { + QUIC_BUG(quic_bug_10743_28) + << "Unexpected failure of EVP_Digest functions when hashing " + "Transport Parameters for ticket"; + return false; + } + uint8_t disable_active_migration = in.disable_active_migration ? 1 : 0; + if (!EVP_DigestUpdate(hash_ctx.get(), &disable_active_migration, + sizeof(disable_active_migration)) || + !EVP_DigestFinal(hash_ctx.get(), out->data() + 1, nullptr)) { + QUIC_BUG(quic_bug_10743_29) + << "Unexpected failure of EVP_Digest functions when hashing " + "Transport Parameters for ticket"; + return false; + } + return true; +} + +void DegreaseTransportParameters(TransportParameters& parameters) { + // Strip GREASE from custom parameters. + for (auto it = parameters.custom_parameters.begin(); + it != parameters.custom_parameters.end(); + /**/) { + // See the "Reserved Transport Parameters" section of RFC 9000. + if (it->first % 31 == 27) { + parameters.custom_parameters.erase(it++); + } else { + ++it; + } + } + + // Strip GREASE from versions. + if (parameters.version_information.has_value()) { + QuicVersionLabelVector clean_versions; + for (QuicVersionLabel version : + parameters.version_information->other_versions) { + // See the "Versions" section of RFC 9000. + if ((version & kReservedVersionMask) != kReservedVersionBits) { + clean_versions.push_back(version); + } + } + + parameters.version_information->other_versions = std::move(clean_versions); + } +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/transport_parameters.h b/quiche/quic/core/crypto/transport_parameters.h new file mode 100644 index 000000000000..78c2202e837a --- /dev/null +++ b/quiche/quic/core/crypto/transport_parameters.h @@ -0,0 +1,311 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_TRANSPORT_PARAMETERS_H_ +#define QUICHE_QUIC_CORE_CRYPTO_TRANSPORT_PARAMETERS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_tag.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// TransportParameters contains parameters for QUIC's transport layer that are +// exchanged during the TLS handshake. This struct is a mirror of the struct in +// the "Transport Parameter Encoding" section of draft-ietf-quic-transport. +// This struct currently uses the values from draft 29. +struct QUIC_EXPORT_PRIVATE TransportParameters { + // The identifier used to differentiate transport parameters. + enum TransportParameterId : uint64_t; + // A map used to specify custom parameters. + using ParameterMap = absl::flat_hash_map; + // Represents an individual QUIC transport parameter that only encodes a + // variable length integer. Can only be created inside the constructor for + // TransportParameters. + class QUIC_EXPORT_PRIVATE IntegerParameter { + public: + // Forbid constructing and copying apart from TransportParameters. + IntegerParameter() = delete; + IntegerParameter& operator=(const IntegerParameter&) = delete; + // Sets the value of this transport parameter. + void set_value(uint64_t value); + // Gets the value of this transport parameter. + uint64_t value() const; + // Validates whether the current value is valid. + bool IsValid() const; + // Writes to a crypto byte buffer, used during serialization. Does not write + // anything if the value is equal to the parameter's default value. + // Returns whether the write was successful. + bool Write(QuicDataWriter* writer) const; + // Reads from a crypto byte string, used during parsing. + // Returns whether the read was successful. + // On failure, this method will write a human-readable error message to + // |error_details|. + bool Read(QuicDataReader* reader, std::string* error_details); + // operator<< allows easily logging integer transport parameters. + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const IntegerParameter& param); + + private: + friend struct TransportParameters; + // Constructors for initial setup used by TransportParameters only. + // This constructor sets |default_value| and |min_value| to 0, and + // |max_value| to quiche::kVarInt62MaxValue. + explicit IntegerParameter(TransportParameterId param_id); + IntegerParameter(TransportParameterId param_id, uint64_t default_value, + uint64_t min_value, uint64_t max_value); + IntegerParameter(const IntegerParameter& other) = default; + IntegerParameter(IntegerParameter&& other) = default; + // Human-readable string representation. + std::string ToString(bool for_use_in_list) const; + + // Number used to indicate this transport parameter. + TransportParameterId param_id_; + // Current value of the transport parameter. + uint64_t value_; + // Default value of this transport parameter, as per IETF specification. + const uint64_t default_value_; + // Minimum value of this transport parameter, as per IETF specification. + const uint64_t min_value_; + // Maximum value of this transport parameter, as per IETF specification. + const uint64_t max_value_; + // Ensures this parameter is not parsed twice in the same message. + bool has_been_read_; + }; + + // Represents the preferred_address transport parameter that a server can + // send to clients. + struct QUIC_EXPORT_PRIVATE PreferredAddress { + PreferredAddress(); + PreferredAddress(const PreferredAddress& other) = default; + PreferredAddress(PreferredAddress&& other) = default; + ~PreferredAddress(); + bool operator==(const PreferredAddress& rhs) const; + bool operator!=(const PreferredAddress& rhs) const; + + QuicSocketAddress ipv4_socket_address; + QuicSocketAddress ipv6_socket_address; + QuicConnectionId connection_id; + std::vector stateless_reset_token; + + // Allows easily logging. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const TransportParameters& params); + }; + + // LegacyVersionInformation represents the Google QUIC downgrade prevention + // mechanism ported to QUIC+TLS. It is exchanged using transport parameter ID + // 0x4752 and will eventually be deprecated in favor of + // draft-ietf-quic-version-negotiation. + struct QUIC_EXPORT_PRIVATE LegacyVersionInformation { + LegacyVersionInformation(); + LegacyVersionInformation(const LegacyVersionInformation& other) = default; + LegacyVersionInformation& operator=(const LegacyVersionInformation& other) = + default; + LegacyVersionInformation& operator=(LegacyVersionInformation&& other) = + default; + LegacyVersionInformation(LegacyVersionInformation&& other) = default; + ~LegacyVersionInformation() = default; + bool operator==(const LegacyVersionInformation& rhs) const; + bool operator!=(const LegacyVersionInformation& rhs) const; + // When sent by the client, |version| is the initial version offered by the + // client (before any version negotiation packets) for this connection. When + // sent by the server, |version| is the version that is in use. + QuicVersionLabel version; + + // When sent by the server, |supported_versions| contains a list of all + // versions that the server would send in a version negotiation packet. When + // sent by the client, this is empty. + QuicVersionLabelVector supported_versions; + + // Allows easily logging. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, + const LegacyVersionInformation& legacy_version_information); + }; + + // Version information used for version downgrade prevention and compatible + // version negotiation. See draft-ietf-quic-version-negotiation-05. + struct QUIC_EXPORT_PRIVATE VersionInformation { + VersionInformation(); + VersionInformation(const VersionInformation& other) = default; + VersionInformation& operator=(const VersionInformation& other) = default; + VersionInformation& operator=(VersionInformation&& other) = default; + VersionInformation(VersionInformation&& other) = default; + ~VersionInformation() = default; + bool operator==(const VersionInformation& rhs) const; + bool operator!=(const VersionInformation& rhs) const; + + // Version that the sender has chosen to use on this connection. + QuicVersionLabel chosen_version; + + // When sent by the client, |other_versions| contains all the versions that + // this first flight is compatible with. When sent by the server, + // |other_versions| contains all of the versions supported by the server. + QuicVersionLabelVector other_versions; + + // Allows easily logging. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const VersionInformation& version_information); + }; + + TransportParameters(); + TransportParameters(const TransportParameters& other); + ~TransportParameters(); + bool operator==(const TransportParameters& rhs) const; + bool operator!=(const TransportParameters& rhs) const; + + // Represents the sender of the transport parameters. When |perspective| is + // Perspective::IS_CLIENT, this struct is being used in the client_hello + // handshake message; when it is Perspective::IS_SERVER, it is being used in + // the encrypted_extensions handshake message. + Perspective perspective; + + // Google QUIC downgrade prevention mechanism sent over QUIC+TLS. + absl::optional legacy_version_information; + + // IETF downgrade prevention and compatible version negotiation, see + // draft-ietf-quic-version-negotiation. + absl::optional version_information; + + // The value of the Destination Connection ID field from the first + // Initial packet sent by the client. + absl::optional original_destination_connection_id; + + // Maximum idle timeout expressed in milliseconds. + IntegerParameter max_idle_timeout_ms; + + // Stateless reset token used in verifying stateless resets. + std::vector stateless_reset_token; + + // Limits the size of packets that the endpoint is willing to receive. + // This indicates that packets larger than this limit will be dropped. + IntegerParameter max_udp_payload_size; + + // Contains the initial value for the maximum amount of data that can + // be sent on the connection. + IntegerParameter initial_max_data; + + // Initial flow control limit for locally-initiated bidirectional streams. + IntegerParameter initial_max_stream_data_bidi_local; + + // Initial flow control limit for peer-initiated bidirectional streams. + IntegerParameter initial_max_stream_data_bidi_remote; + + // Initial flow control limit for unidirectional streams. + IntegerParameter initial_max_stream_data_uni; + + // Initial maximum number of bidirectional streams the peer may initiate. + IntegerParameter initial_max_streams_bidi; + + // Initial maximum number of unidirectional streams the peer may initiate. + IntegerParameter initial_max_streams_uni; + + // Exponent used to decode the ACK Delay field in ACK frames. + IntegerParameter ack_delay_exponent; + + // Maximum amount of time in milliseconds by which the endpoint will + // delay sending acknowledgments. + IntegerParameter max_ack_delay; + + // Minimum amount of time in microseconds by which the endpoint will + // delay sending acknowledgments. Used to enable sender control of ack delay. + IntegerParameter min_ack_delay_us; + + // Indicates lack of support for connection migration. + bool disable_active_migration; + + // Used to effect a change in server address at the end of the handshake. + std::unique_ptr preferred_address; + + // Maximum number of connection IDs from the peer that an endpoint is willing + // to store. + IntegerParameter active_connection_id_limit; + + // The value that the endpoint included in the Source Connection ID field of + // the first Initial packet it sent. + absl::optional initial_source_connection_id; + + // The value that the server included in the Source Connection ID field of a + // Retry packet it sent. + absl::optional retry_source_connection_id; + + // Indicates support for the DATAGRAM frame and the maximum frame size that + // the sender accepts. See draft-ietf-quic-datagram. + IntegerParameter max_datagram_frame_size; + + // Google-specific transport parameter that carries an estimate of the + // initial round-trip time in microseconds. + IntegerParameter initial_round_trip_time_us; + + // Google internal handshake message. + absl::optional google_handshake_message; + + // Google-specific connection options. + absl::optional google_connection_options; + + // Validates whether transport parameters are valid according to + // the specification. If the transport parameters are not valid, this method + // will write a human-readable error message to |error_details|. + bool AreValid(std::string* error_details) const; + + // Custom parameters that may be specific to application protocol. + ParameterMap custom_parameters; + + // Allows easily logging transport parameters. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const TransportParameters& params); +}; + +// Serializes a TransportParameters struct into the format for sending it in a +// TLS extension. The serialized bytes are written to |*out|. Returns if the +// parameters are valid and serialization succeeded. +QUIC_EXPORT_PRIVATE bool SerializeTransportParameters( + const TransportParameters& in, std::vector* out); + +// Parses bytes from the quic_transport_parameters TLS extension and writes the +// parsed parameters into |*out|. Input is read from |in| for |in_len| bytes. +// |perspective| indicates whether the input came from a client or a server. +// This method returns true if the input was successfully parsed. +// On failure, this method will write a human-readable error message to +// |error_details|. +QUIC_EXPORT_PRIVATE bool ParseTransportParameters( + ParsedQuicVersion version, Perspective perspective, const uint8_t* in, + size_t in_len, TransportParameters* out, std::string* error_details); + +// Serializes |in| and |application_data| in a deterministic format so that +// multiple calls to SerializeTransportParametersForTicket with the same inputs +// will generate the same output, and if the inputs differ, then the output will +// differ. The output of this function is used by the server in +// SSL_set_quic_early_data_context to determine whether early data should be +// accepted: Early data will only be accepted if the inputs to this function +// match what they were on the connection that issued an early data capable +// ticket. +QUIC_EXPORT_PRIVATE bool SerializeTransportParametersForTicket( + const TransportParameters& in, const std::vector& application_data, + std::vector* out); + +// Removes reserved values from custom_parameters and versions. +// The resulting value can be reliably compared with an original or other +// deserialized value. +QUIC_EXPORT_PRIVATE void DegreaseTransportParameters( + TransportParameters& parameters); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_TRANSPORT_PARAMETERS_H_ diff --git a/quiche/quic/core/crypto/transport_parameters_test.cc b/quiche/quic/core/crypto/transport_parameters_test.cc new file mode 100644 index 000000000000..6b782c459cb7 --- /dev/null +++ b/quiche/quic/core/crypto/transport_parameters_test.cc @@ -0,0 +1,1192 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/transport_parameters.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_tag.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +const QuicVersionLabel kFakeVersionLabel = 0x01234567; +const QuicVersionLabel kFakeVersionLabel2 = 0x89ABCDEF; +const uint64_t kFakeIdleTimeoutMilliseconds = 12012; +const uint64_t kFakeInitialMaxData = 101; +const uint64_t kFakeInitialMaxStreamDataBidiLocal = 2001; +const uint64_t kFakeInitialMaxStreamDataBidiRemote = 2002; +const uint64_t kFakeInitialMaxStreamDataUni = 3000; +const uint64_t kFakeInitialMaxStreamsBidi = 21; +const uint64_t kFakeInitialMaxStreamsUni = 22; +const bool kFakeDisableMigration = true; +const uint64_t kFakeInitialRoundTripTime = 53; +const uint8_t kFakePreferredStatelessResetTokenData[16] = { + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8A, 0x8B, 0x8C, 0x8D, 0x8E, 0x8F}; + +const auto kCustomParameter1 = + static_cast(0xffcd); +const char* kCustomParameter1Value = "foo"; +const auto kCustomParameter2 = + static_cast(0xff34); +const char* kCustomParameter2Value = "bar"; + +const char kFakeGoogleHandshakeMessage[] = + "01000106030392655f5230270d4964a4f99b15bbad220736d972aea97bf9ac494ead62e6"; + +QuicConnectionId CreateFakeOriginalDestinationConnectionId() { + return TestConnectionId(0x1337); +} + +QuicConnectionId CreateFakeInitialSourceConnectionId() { + return TestConnectionId(0x2345); +} + +QuicConnectionId CreateFakeRetrySourceConnectionId() { + return TestConnectionId(0x9876); +} + +QuicConnectionId CreateFakePreferredConnectionId() { + return TestConnectionId(0xBEEF); +} + +std::vector CreateFakePreferredStatelessResetToken() { + return std::vector( + kFakePreferredStatelessResetTokenData, + kFakePreferredStatelessResetTokenData + + sizeof(kFakePreferredStatelessResetTokenData)); +} + +QuicSocketAddress CreateFakeV4SocketAddress() { + QuicIpAddress ipv4_address; + if (!ipv4_address.FromString("65.66.67.68")) { // 0x41, 0x42, 0x43, 0x44 + QUIC_LOG(FATAL) << "Failed to create IPv4 address"; + return QuicSocketAddress(); + } + return QuicSocketAddress(ipv4_address, 0x4884); +} + +QuicSocketAddress CreateFakeV6SocketAddress() { + QuicIpAddress ipv6_address; + if (!ipv6_address.FromString("6061:6263:6465:6667:6869:6A6B:6C6D:6E6F")) { + QUIC_LOG(FATAL) << "Failed to create IPv6 address"; + return QuicSocketAddress(); + } + return QuicSocketAddress(ipv6_address, 0x6336); +} + +std::unique_ptr +CreateFakePreferredAddress() { + TransportParameters::PreferredAddress preferred_address; + preferred_address.ipv4_socket_address = CreateFakeV4SocketAddress(); + preferred_address.ipv6_socket_address = CreateFakeV6SocketAddress(); + preferred_address.connection_id = CreateFakePreferredConnectionId(); + preferred_address.stateless_reset_token = + CreateFakePreferredStatelessResetToken(); + return std::make_unique( + preferred_address); +} + +TransportParameters::LegacyVersionInformation +CreateFakeLegacyVersionInformationClient() { + TransportParameters::LegacyVersionInformation legacy_version_information; + legacy_version_information.version = kFakeVersionLabel; + return legacy_version_information; +} + +TransportParameters::LegacyVersionInformation +CreateFakeLegacyVersionInformationServer() { + TransportParameters::LegacyVersionInformation legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + legacy_version_information.supported_versions.push_back(kFakeVersionLabel); + legacy_version_information.supported_versions.push_back(kFakeVersionLabel2); + return legacy_version_information; +} + +TransportParameters::VersionInformation CreateFakeVersionInformation() { + TransportParameters::VersionInformation version_information; + version_information.chosen_version = kFakeVersionLabel; + version_information.other_versions.push_back(kFakeVersionLabel); + version_information.other_versions.push_back(kFakeVersionLabel2); + return version_information; +} + +QuicTagVector CreateFakeGoogleConnectionOptions() { + return {kALPN, MakeQuicTag('E', 'F', 'G', 0x00), + MakeQuicTag('H', 'I', 'J', 0xff)}; +} + +void RemoveGreaseParameters(TransportParameters* params) { + std::vector grease_params; + for (const auto& kv : params->custom_parameters) { + if (kv.first % 31 == 27) { + grease_params.push_back(kv.first); + } + } + EXPECT_EQ(grease_params.size(), 1u); + for (TransportParameters::TransportParameterId param_id : grease_params) { + params->custom_parameters.erase(param_id); + } + // Remove all GREASE versions from version_information.other_versions. + if (params->version_information.has_value()) { + QuicVersionLabelVector& other_versions = + params->version_information.value().other_versions; + for (auto it = other_versions.begin(); it != other_versions.end();) { + if ((*it & 0x0f0f0f0f) == 0x0a0a0a0a) { + it = other_versions.erase(it); + } else { + ++it; + } + } + } +} + +} // namespace + +class TransportParametersTest : public QuicTestWithParam { + protected: + TransportParametersTest() : version_(GetParam()) {} + + ParsedQuicVersion version_; +}; + +INSTANTIATE_TEST_SUITE_P(TransportParametersTests, TransportParametersTest, + ::testing::ValuesIn(AllSupportedVersionsWithTls()), + ::testing::PrintToStringParamName()); + +TEST_P(TransportParametersTest, Comparator) { + TransportParameters orig_params; + TransportParameters new_params; + // Test comparison on primitive members. + orig_params.perspective = Perspective::IS_CLIENT; + new_params.perspective = Perspective::IS_SERVER; + EXPECT_NE(orig_params, new_params); + EXPECT_FALSE(orig_params == new_params); + EXPECT_TRUE(orig_params != new_params); + new_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + new_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.version_information = CreateFakeVersionInformation(); + new_params.version_information = CreateFakeVersionInformation(); + orig_params.disable_active_migration = true; + new_params.disable_active_migration = true; + EXPECT_EQ(orig_params, new_params); + EXPECT_TRUE(orig_params == new_params); + EXPECT_FALSE(orig_params != new_params); + + // Test comparison on vectors. + orig_params.legacy_version_information.value().supported_versions.push_back( + kFakeVersionLabel); + new_params.legacy_version_information.value().supported_versions.push_back( + kFakeVersionLabel2); + EXPECT_NE(orig_params, new_params); + EXPECT_FALSE(orig_params == new_params); + EXPECT_TRUE(orig_params != new_params); + new_params.legacy_version_information.value().supported_versions.pop_back(); + new_params.legacy_version_information.value().supported_versions.push_back( + kFakeVersionLabel); + orig_params.stateless_reset_token = CreateStatelessResetTokenForTest(); + new_params.stateless_reset_token = CreateStatelessResetTokenForTest(); + EXPECT_EQ(orig_params, new_params); + EXPECT_TRUE(orig_params == new_params); + EXPECT_FALSE(orig_params != new_params); + + // Test comparison on IntegerParameters. + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + new_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest + 1); + EXPECT_NE(orig_params, new_params); + EXPECT_FALSE(orig_params == new_params); + EXPECT_TRUE(orig_params != new_params); + new_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + EXPECT_EQ(orig_params, new_params); + EXPECT_TRUE(orig_params == new_params); + EXPECT_FALSE(orig_params != new_params); + + // Test comparison on PreferredAddress + orig_params.preferred_address = CreateFakePreferredAddress(); + EXPECT_NE(orig_params, new_params); + EXPECT_FALSE(orig_params == new_params); + EXPECT_TRUE(orig_params != new_params); + new_params.preferred_address = CreateFakePreferredAddress(); + EXPECT_EQ(orig_params, new_params); + EXPECT_TRUE(orig_params == new_params); + EXPECT_FALSE(orig_params != new_params); + + // Test comparison on CustomMap + orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + + new_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + new_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + EXPECT_EQ(orig_params, new_params); + EXPECT_TRUE(orig_params == new_params); + EXPECT_FALSE(orig_params != new_params); + + // Test comparison on connection IDs. + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + new_params.initial_source_connection_id = absl::nullopt; + EXPECT_NE(orig_params, new_params); + EXPECT_FALSE(orig_params == new_params); + EXPECT_TRUE(orig_params != new_params); + new_params.initial_source_connection_id = TestConnectionId(0xbadbad); + EXPECT_NE(orig_params, new_params); + EXPECT_FALSE(orig_params == new_params); + EXPECT_TRUE(orig_params != new_params); + new_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + EXPECT_EQ(orig_params, new_params); + EXPECT_TRUE(orig_params == new_params); + EXPECT_FALSE(orig_params != new_params); +} + +TEST_P(TransportParametersTest, CopyConstructor) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.version_information = CreateFakeVersionInformation(); + orig_params.original_destination_connection_id = + CreateFakeOriginalDestinationConnectionId(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.stateless_reset_token = CreateStatelessResetTokenForTest(); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + orig_params.initial_max_data.set_value(kFakeInitialMaxData); + orig_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + orig_params.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + orig_params.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + orig_params.initial_max_streams_bidi.set_value(kFakeInitialMaxStreamsBidi); + orig_params.initial_max_streams_uni.set_value(kFakeInitialMaxStreamsUni); + orig_params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + orig_params.max_ack_delay.set_value(kMaxAckDelayForTest); + orig_params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + orig_params.disable_active_migration = kFakeDisableMigration; + orig_params.preferred_address = CreateFakePreferredAddress(); + orig_params.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + orig_params.retry_source_connection_id = CreateFakeRetrySourceConnectionId(); + orig_params.initial_round_trip_time_us.set_value(kFakeInitialRoundTripTime); + orig_params.google_handshake_message = + absl::HexStringToBytes(kFakeGoogleHandshakeMessage); + orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); + orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + + TransportParameters new_params(orig_params); + EXPECT_EQ(new_params, orig_params); +} + +TEST_P(TransportParametersTest, RoundTripClient) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.version_information = CreateFakeVersionInformation(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + orig_params.initial_max_data.set_value(kFakeInitialMaxData); + orig_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + orig_params.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + orig_params.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + orig_params.initial_max_streams_bidi.set_value(kFakeInitialMaxStreamsBidi); + orig_params.initial_max_streams_uni.set_value(kFakeInitialMaxStreamsUni); + orig_params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + orig_params.max_ack_delay.set_value(kMaxAckDelayForTest); + orig_params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + orig_params.disable_active_migration = kFakeDisableMigration; + orig_params.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + orig_params.initial_round_trip_time_us.set_value(kFakeInitialRoundTripTime); + orig_params.google_handshake_message = + absl::HexStringToBytes(kFakeGoogleHandshakeMessage); + orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); + orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + + TransportParameters new_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + serialized.data(), serialized.size(), + &new_params, &error_details)) + << error_details; + EXPECT_TRUE(error_details.empty()); + RemoveGreaseParameters(&new_params); + EXPECT_EQ(new_params, orig_params); +} + +TEST_P(TransportParametersTest, RoundTripServer) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_SERVER; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationServer(); + orig_params.version_information = CreateFakeVersionInformation(); + orig_params.original_destination_connection_id = + CreateFakeOriginalDestinationConnectionId(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.stateless_reset_token = CreateStatelessResetTokenForTest(); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + orig_params.initial_max_data.set_value(kFakeInitialMaxData); + orig_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + orig_params.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + orig_params.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + orig_params.initial_max_streams_bidi.set_value(kFakeInitialMaxStreamsBidi); + orig_params.initial_max_streams_uni.set_value(kFakeInitialMaxStreamsUni); + orig_params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + orig_params.max_ack_delay.set_value(kMaxAckDelayForTest); + orig_params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + orig_params.disable_active_migration = kFakeDisableMigration; + orig_params.preferred_address = CreateFakePreferredAddress(); + orig_params.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + orig_params.retry_source_connection_id = CreateFakeRetrySourceConnectionId(); + orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + + TransportParameters new_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_SERVER, + serialized.data(), serialized.size(), + &new_params, &error_details)) + << error_details; + EXPECT_TRUE(error_details.empty()); + RemoveGreaseParameters(&new_params); + EXPECT_EQ(new_params, orig_params); +} + +TEST_P(TransportParametersTest, AreValid) { + { + TransportParameters params; + std::string error_details; + params.perspective = Perspective::IS_CLIENT; + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + } + { + TransportParameters params; + std::string error_details; + params.perspective = Perspective::IS_CLIENT; + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.max_idle_timeout_ms.set_value(601000); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + } + { + TransportParameters params; + std::string error_details; + params.perspective = Perspective::IS_CLIENT; + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.max_udp_payload_size.set_value(1200); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.max_udp_payload_size.set_value(65535); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.max_udp_payload_size.set_value(9999999); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.max_udp_payload_size.set_value(0); + error_details = ""; + EXPECT_FALSE(params.AreValid(&error_details)); + EXPECT_EQ(error_details, + "Invalid transport parameters [Client max_udp_payload_size 0 " + "(Invalid)]"); + params.max_udp_payload_size.set_value(1199); + error_details = ""; + EXPECT_FALSE(params.AreValid(&error_details)); + EXPECT_EQ(error_details, + "Invalid transport parameters [Client max_udp_payload_size 1199 " + "(Invalid)]"); + } + { + TransportParameters params; + std::string error_details; + params.perspective = Perspective::IS_CLIENT; + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.ack_delay_exponent.set_value(0); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.ack_delay_exponent.set_value(20); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.ack_delay_exponent.set_value(21); + EXPECT_FALSE(params.AreValid(&error_details)); + EXPECT_EQ(error_details, + "Invalid transport parameters [Client ack_delay_exponent 21 " + "(Invalid)]"); + } + { + TransportParameters params; + std::string error_details; + params.perspective = Perspective::IS_CLIENT; + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.active_connection_id_limit.set_value(2); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.active_connection_id_limit.set_value(999999); + EXPECT_TRUE(params.AreValid(&error_details)); + EXPECT_TRUE(error_details.empty()); + params.active_connection_id_limit.set_value(1); + EXPECT_FALSE(params.AreValid(&error_details)); + EXPECT_EQ(error_details, + "Invalid transport parameters [Client active_connection_id_limit" + " 1 (Invalid)]"); + params.active_connection_id_limit.set_value(0); + EXPECT_FALSE(params.AreValid(&error_details)); + EXPECT_EQ(error_details, + "Invalid transport parameters [Client active_connection_id_limit" + " 0 (Invalid)]"); + } +} + +TEST_P(TransportParametersTest, NoClientParamsWithStatelessResetToken) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.stateless_reset_token = CreateStatelessResetTokenForTest(); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + + std::vector out; + EXPECT_QUIC_BUG( + EXPECT_FALSE(SerializeTransportParameters(orig_params, &out)), + "Not serializing invalid transport parameters: Client cannot send " + "stateless reset token"); +} + +TEST_P(TransportParametersTest, ParseClientParams) { + // clang-format off + const uint8_t kClientParams[] = { + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // max_udp_payload_size + 0x03, // parameter id + 0x02, // length + 0x63, 0x29, // value + // initial_max_data + 0x04, // parameter id + 0x02, // length + 0x40, 0x65, // value + // initial_max_stream_data_bidi_local + 0x05, // parameter id + 0x02, // length + 0x47, 0xD1, // value + // initial_max_stream_data_bidi_remote + 0x06, // parameter id + 0x02, // length + 0x47, 0xD2, // value + // initial_max_stream_data_uni + 0x07, // parameter id + 0x02, // length + 0x4B, 0xB8, // value + // initial_max_streams_bidi + 0x08, // parameter id + 0x01, // length + 0x15, // value + // initial_max_streams_uni + 0x09, // parameter id + 0x01, // length + 0x16, // value + // ack_delay_exponent + 0x0a, // parameter id + 0x01, // length + 0x0a, // value + // max_ack_delay + 0x0b, // parameter id + 0x01, // length + 0x33, // value + // min_ack_delay_us + 0x80, 0x00, 0xde, 0x1a, // parameter id + 0x02, // length + 0x43, 0xe8, // value + // disable_active_migration + 0x0c, // parameter id + 0x00, // length + // active_connection_id_limit + 0x0e, // parameter id + 0x01, // length + 0x34, // value + // initial_source_connection_id + 0x0f, // parameter id + 0x08, // length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x45, + // google_handshake_message + 0x66, 0xab, // parameter id + 0x24, // length + 0x01, 0x00, 0x01, 0x06, 0x03, 0x03, 0x92, 0x65, 0x5f, 0x52, 0x30, 0x27, + 0x0d, 0x49, 0x64, 0xa4, 0xf9, 0x9b, 0x15, 0xbb, 0xad, 0x22, 0x07, 0x36, + 0xd9, 0x72, 0xae, 0xa9, 0x7b, 0xf9, 0xac, 0x49, 0x4e, 0xad, 0x62, 0xe6, + // initial_round_trip_time_us + 0x71, 0x27, // parameter id + 0x01, // length + 0x35, // value + // google_connection_options + 0x71, 0x28, // parameter id + 0x0c, // length + 'A', 'L', 'P', 'N', // value + 'E', 'F', 'G', 0x00, + 'H', 'I', 'J', 0xff, + // Google version extension + 0x80, 0x00, 0x47, 0x52, // parameter id + 0x04, // length + 0x01, 0x23, 0x45, 0x67, // initial version + // version_information + 0x80, 0xFF, 0x73, 0xDB, // parameter id + 0x0C, // length + 0x01, 0x23, 0x45, 0x67, // chosen version + 0x01, 0x23, 0x45, 0x67, // other version 1 + 0x89, 0xab, 0xcd, 0xef, // other version 2 + }; + // clang-format on + const uint8_t* client_params = + reinterpret_cast(kClientParams); + size_t client_params_length = ABSL_ARRAYSIZE(kClientParams); + TransportParameters new_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + client_params, client_params_length, + &new_params, &error_details)) + << error_details; + EXPECT_TRUE(error_details.empty()); + EXPECT_EQ(Perspective::IS_CLIENT, new_params.perspective); + ASSERT_TRUE(new_params.legacy_version_information.has_value()); + EXPECT_EQ(kFakeVersionLabel, + new_params.legacy_version_information.value().version); + EXPECT_TRUE( + new_params.legacy_version_information.value().supported_versions.empty()); + ASSERT_TRUE(new_params.version_information.has_value()); + EXPECT_EQ(new_params.version_information.value(), + CreateFakeVersionInformation()); + EXPECT_FALSE(new_params.original_destination_connection_id.has_value()); + EXPECT_EQ(kFakeIdleTimeoutMilliseconds, + new_params.max_idle_timeout_ms.value()); + EXPECT_TRUE(new_params.stateless_reset_token.empty()); + EXPECT_EQ(kMaxPacketSizeForTest, new_params.max_udp_payload_size.value()); + EXPECT_EQ(kFakeInitialMaxData, new_params.initial_max_data.value()); + EXPECT_EQ(kFakeInitialMaxStreamDataBidiLocal, + new_params.initial_max_stream_data_bidi_local.value()); + EXPECT_EQ(kFakeInitialMaxStreamDataBidiRemote, + new_params.initial_max_stream_data_bidi_remote.value()); + EXPECT_EQ(kFakeInitialMaxStreamDataUni, + new_params.initial_max_stream_data_uni.value()); + EXPECT_EQ(kFakeInitialMaxStreamsBidi, + new_params.initial_max_streams_bidi.value()); + EXPECT_EQ(kFakeInitialMaxStreamsUni, + new_params.initial_max_streams_uni.value()); + EXPECT_EQ(kAckDelayExponentForTest, new_params.ack_delay_exponent.value()); + EXPECT_EQ(kMaxAckDelayForTest, new_params.max_ack_delay.value()); + EXPECT_EQ(kMinAckDelayUsForTest, new_params.min_ack_delay_us.value()); + EXPECT_EQ(kFakeDisableMigration, new_params.disable_active_migration); + EXPECT_EQ(kActiveConnectionIdLimitForTest, + new_params.active_connection_id_limit.value()); + ASSERT_TRUE(new_params.initial_source_connection_id.has_value()); + EXPECT_EQ(CreateFakeInitialSourceConnectionId(), + new_params.initial_source_connection_id.value()); + EXPECT_FALSE(new_params.retry_source_connection_id.has_value()); + EXPECT_EQ(kFakeInitialRoundTripTime, + new_params.initial_round_trip_time_us.value()); + ASSERT_TRUE(new_params.google_connection_options.has_value()); + EXPECT_EQ(CreateFakeGoogleConnectionOptions(), + new_params.google_connection_options.value()); + EXPECT_EQ(absl::HexStringToBytes(kFakeGoogleHandshakeMessage), + new_params.google_handshake_message); +} + +TEST_P(TransportParametersTest, + ParseClientParamsFailsWithFullStatelessResetToken) { + // clang-format off + const uint8_t kClientParamsWithFullToken[] = { + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // stateless_reset_token + 0x02, // parameter id + 0x10, // length + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F, + // max_udp_payload_size + 0x03, // parameter id + 0x02, // length + 0x63, 0x29, // value + // initial_max_data + 0x04, // parameter id + 0x02, // length + 0x40, 0x65, // value + }; + // clang-format on + const uint8_t* client_params = + reinterpret_cast(kClientParamsWithFullToken); + size_t client_params_length = ABSL_ARRAYSIZE(kClientParamsWithFullToken); + TransportParameters out_params; + std::string error_details; + EXPECT_FALSE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + client_params, client_params_length, + &out_params, &error_details)); + EXPECT_EQ(error_details, "Client cannot send stateless reset token"); +} + +TEST_P(TransportParametersTest, + ParseClientParamsFailsWithEmptyStatelessResetToken) { + // clang-format off + const uint8_t kClientParamsWithEmptyToken[] = { + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // stateless_reset_token + 0x02, // parameter id + 0x00, // length + // max_udp_payload_size + 0x03, // parameter id + 0x02, // length + 0x63, 0x29, // value + // initial_max_data + 0x04, // parameter id + 0x02, // length + 0x40, 0x65, // value + }; + // clang-format on + const uint8_t* client_params = + reinterpret_cast(kClientParamsWithEmptyToken); + size_t client_params_length = ABSL_ARRAYSIZE(kClientParamsWithEmptyToken); + TransportParameters out_params; + std::string error_details; + EXPECT_FALSE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + client_params, client_params_length, + &out_params, &error_details)); + EXPECT_EQ(error_details, + "Received stateless_reset_token of invalid length 0"); +} + +TEST_P(TransportParametersTest, ParseClientParametersRepeated) { + // clang-format off + const uint8_t kClientParamsRepeated[] = { + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // max_udp_payload_size + 0x03, // parameter id + 0x02, // length + 0x63, 0x29, // value + // max_idle_timeout (repeated) + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + }; + // clang-format on + const uint8_t* client_params = + reinterpret_cast(kClientParamsRepeated); + size_t client_params_length = ABSL_ARRAYSIZE(kClientParamsRepeated); + TransportParameters out_params; + std::string error_details; + EXPECT_FALSE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + client_params, client_params_length, + &out_params, &error_details)); + EXPECT_EQ(error_details, "Received a second max_idle_timeout"); +} + +TEST_P(TransportParametersTest, ParseServerParams) { + // clang-format off + const uint8_t kServerParams[] = { + // original_destination_connection_id + 0x00, // parameter id + 0x08, // length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x37, + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // stateless_reset_token + 0x02, // parameter id + 0x10, // length + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F, + // max_udp_payload_size + 0x03, // parameter id + 0x02, // length + 0x63, 0x29, // value + // initial_max_data + 0x04, // parameter id + 0x02, // length + 0x40, 0x65, // value + // initial_max_stream_data_bidi_local + 0x05, // parameter id + 0x02, // length + 0x47, 0xD1, // value + // initial_max_stream_data_bidi_remote + 0x06, // parameter id + 0x02, // length + 0x47, 0xD2, // value + // initial_max_stream_data_uni + 0x07, // parameter id + 0x02, // length + 0x4B, 0xB8, // value + // initial_max_streams_bidi + 0x08, // parameter id + 0x01, // length + 0x15, // value + // initial_max_streams_uni + 0x09, // parameter id + 0x01, // length + 0x16, // value + // ack_delay_exponent + 0x0a, // parameter id + 0x01, // length + 0x0a, // value + // max_ack_delay + 0x0b, // parameter id + 0x01, // length + 0x33, // value + // min_ack_delay_us + 0x80, 0x00, 0xde, 0x1a, // parameter id + 0x02, // length + 0x43, 0xe8, // value + // disable_active_migration + 0x0c, // parameter id + 0x00, // length + // preferred_address + 0x0d, // parameter id + 0x31, // length + 0x41, 0x42, 0x43, 0x44, // IPv4 address + 0x48, 0x84, // IPv4 port + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, // IPv6 address + 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x63, 0x36, // IPv6 port + 0x08, // connection ID length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xBE, 0xEF, // connection ID + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, // stateless reset token + 0x88, 0x89, 0x8A, 0x8B, 0x8C, 0x8D, 0x8E, 0x8F, + // active_connection_id_limit + 0x0e, // parameter id + 0x01, // length + 0x34, // value + // initial_source_connection_id + 0x0f, // parameter id + 0x08, // length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x23, 0x45, + // retry_source_connection_id + 0x10, // parameter id + 0x08, // length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x98, 0x76, + // google_connection_options + 0x71, 0x28, // parameter id + 0x0c, // length + 'A', 'L', 'P', 'N', // value + 'E', 'F', 'G', 0x00, + 'H', 'I', 'J', 0xff, + // Google version extension + 0x80, 0x00, 0x47, 0x52, // parameter id + 0x0d, // length + 0x01, 0x23, 0x45, 0x67, // negotiated_version + 0x08, // length of supported versions array + 0x01, 0x23, 0x45, 0x67, + 0x89, 0xab, 0xcd, 0xef, + // version_information + 0x80, 0xFF, 0x73, 0xDB, // parameter id + 0x0C, // length + 0x01, 0x23, 0x45, 0x67, // chosen version + 0x01, 0x23, 0x45, 0x67, // other version 1 + 0x89, 0xab, 0xcd, 0xef, // other version 2 + }; + // clang-format on + const uint8_t* server_params = + reinterpret_cast(kServerParams); + size_t server_params_length = ABSL_ARRAYSIZE(kServerParams); + TransportParameters new_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_SERVER, + server_params, server_params_length, + &new_params, &error_details)) + << error_details; + EXPECT_TRUE(error_details.empty()); + EXPECT_EQ(Perspective::IS_SERVER, new_params.perspective); + ASSERT_TRUE(new_params.legacy_version_information.has_value()); + EXPECT_EQ(kFakeVersionLabel, + new_params.legacy_version_information.value().version); + ASSERT_EQ( + 2u, + new_params.legacy_version_information.value().supported_versions.size()); + EXPECT_EQ( + kFakeVersionLabel, + new_params.legacy_version_information.value().supported_versions[0]); + EXPECT_EQ( + kFakeVersionLabel2, + new_params.legacy_version_information.value().supported_versions[1]); + ASSERT_TRUE(new_params.version_information.has_value()); + EXPECT_EQ(new_params.version_information.value(), + CreateFakeVersionInformation()); + ASSERT_TRUE(new_params.original_destination_connection_id.has_value()); + EXPECT_EQ(CreateFakeOriginalDestinationConnectionId(), + new_params.original_destination_connection_id.value()); + EXPECT_EQ(kFakeIdleTimeoutMilliseconds, + new_params.max_idle_timeout_ms.value()); + EXPECT_EQ(CreateStatelessResetTokenForTest(), + new_params.stateless_reset_token); + EXPECT_EQ(kMaxPacketSizeForTest, new_params.max_udp_payload_size.value()); + EXPECT_EQ(kFakeInitialMaxData, new_params.initial_max_data.value()); + EXPECT_EQ(kFakeInitialMaxStreamDataBidiLocal, + new_params.initial_max_stream_data_bidi_local.value()); + EXPECT_EQ(kFakeInitialMaxStreamDataBidiRemote, + new_params.initial_max_stream_data_bidi_remote.value()); + EXPECT_EQ(kFakeInitialMaxStreamDataUni, + new_params.initial_max_stream_data_uni.value()); + EXPECT_EQ(kFakeInitialMaxStreamsBidi, + new_params.initial_max_streams_bidi.value()); + EXPECT_EQ(kFakeInitialMaxStreamsUni, + new_params.initial_max_streams_uni.value()); + EXPECT_EQ(kAckDelayExponentForTest, new_params.ack_delay_exponent.value()); + EXPECT_EQ(kMaxAckDelayForTest, new_params.max_ack_delay.value()); + EXPECT_EQ(kMinAckDelayUsForTest, new_params.min_ack_delay_us.value()); + EXPECT_EQ(kFakeDisableMigration, new_params.disable_active_migration); + ASSERT_NE(nullptr, new_params.preferred_address.get()); + EXPECT_EQ(CreateFakeV4SocketAddress(), + new_params.preferred_address->ipv4_socket_address); + EXPECT_EQ(CreateFakeV6SocketAddress(), + new_params.preferred_address->ipv6_socket_address); + EXPECT_EQ(CreateFakePreferredConnectionId(), + new_params.preferred_address->connection_id); + EXPECT_EQ(CreateFakePreferredStatelessResetToken(), + new_params.preferred_address->stateless_reset_token); + EXPECT_EQ(kActiveConnectionIdLimitForTest, + new_params.active_connection_id_limit.value()); + ASSERT_TRUE(new_params.initial_source_connection_id.has_value()); + EXPECT_EQ(CreateFakeInitialSourceConnectionId(), + new_params.initial_source_connection_id.value()); + ASSERT_TRUE(new_params.retry_source_connection_id.has_value()); + EXPECT_EQ(CreateFakeRetrySourceConnectionId(), + new_params.retry_source_connection_id.value()); + ASSERT_TRUE(new_params.google_connection_options.has_value()); + EXPECT_EQ(CreateFakeGoogleConnectionOptions(), + new_params.google_connection_options.value()); +} + +TEST_P(TransportParametersTest, ParseServerParametersRepeated) { + // clang-format off + const uint8_t kServerParamsRepeated[] = { + // original_destination_connection_id + 0x00, // parameter id + 0x08, // length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x37, + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // stateless_reset_token + 0x02, // parameter id + 0x10, // length + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + // max_idle_timeout (repeated) + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + }; + // clang-format on + const uint8_t* server_params = + reinterpret_cast(kServerParamsRepeated); + size_t server_params_length = ABSL_ARRAYSIZE(kServerParamsRepeated); + TransportParameters out_params; + std::string error_details; + EXPECT_FALSE(ParseTransportParameters(version_, Perspective::IS_SERVER, + server_params, server_params_length, + &out_params, &error_details)); + EXPECT_EQ(error_details, "Received a second max_idle_timeout"); +} + +TEST_P(TransportParametersTest, + ParseServerParametersEmptyOriginalConnectionId) { + // clang-format off + const uint8_t kServerParamsEmptyOriginalConnectionId[] = { + // original_destination_connection_id + 0x00, // parameter id + 0x00, // length + // max_idle_timeout + 0x01, // parameter id + 0x02, // length + 0x6e, 0xec, // value + // stateless_reset_token + 0x02, // parameter id + 0x10, // length + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + }; + // clang-format on + const uint8_t* server_params = + reinterpret_cast(kServerParamsEmptyOriginalConnectionId); + size_t server_params_length = + ABSL_ARRAYSIZE(kServerParamsEmptyOriginalConnectionId); + TransportParameters out_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_SERVER, + server_params, server_params_length, + &out_params, &error_details)) + << error_details; + ASSERT_TRUE(out_params.original_destination_connection_id.has_value()); + EXPECT_EQ(out_params.original_destination_connection_id.value(), + EmptyQuicConnectionId()); +} + +TEST_P(TransportParametersTest, VeryLongCustomParameter) { + // Ensure we can handle a 70KB custom parameter on both send and receive. + std::string custom_value(70000, '?'); + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.custom_parameters[kCustomParameter1] = custom_value; + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + + TransportParameters new_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + serialized.data(), serialized.size(), + &new_params, &error_details)) + << error_details; + EXPECT_TRUE(error_details.empty()); + RemoveGreaseParameters(&new_params); + EXPECT_EQ(new_params, orig_params); +} + +TEST_P(TransportParametersTest, SerializationOrderIsRandom) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + orig_params.initial_max_data.set_value(kFakeInitialMaxData); + orig_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + orig_params.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + orig_params.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + orig_params.initial_max_streams_bidi.set_value(kFakeInitialMaxStreamsBidi); + orig_params.initial_max_streams_uni.set_value(kFakeInitialMaxStreamsUni); + orig_params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + orig_params.max_ack_delay.set_value(kMaxAckDelayForTest); + orig_params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + orig_params.disable_active_migration = kFakeDisableMigration; + orig_params.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + orig_params.initial_round_trip_time_us.set_value(kFakeInitialRoundTripTime); + orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); + orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + + std::vector first_serialized; + ASSERT_TRUE(SerializeTransportParameters(orig_params, &first_serialized)); + // Test that a subsequent serialization is different from the first. + // Run in a loop to avoid a failure in the unlikely event that randomization + // produces the same result multiple times. + for (int i = 0; i < 1000; i++) { + std::vector serialized; + ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + if (serialized != first_serialized) { + return; + } + } +} + +TEST_P(TransportParametersTest, Degrease) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.version_information = CreateFakeVersionInformation(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + orig_params.initial_max_data.set_value(kFakeInitialMaxData); + orig_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + orig_params.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + orig_params.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + orig_params.initial_max_streams_bidi.set_value(kFakeInitialMaxStreamsBidi); + orig_params.initial_max_streams_uni.set_value(kFakeInitialMaxStreamsUni); + orig_params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + orig_params.max_ack_delay.set_value(kMaxAckDelayForTest); + orig_params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + orig_params.disable_active_migration = kFakeDisableMigration; + orig_params.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + orig_params.initial_round_trip_time_us.set_value(kFakeInitialRoundTripTime); + orig_params.google_handshake_message = + absl::HexStringToBytes(kFakeGoogleHandshakeMessage); + orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); + orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParameters(orig_params, &serialized)); + + TransportParameters new_params; + std::string error_details; + ASSERT_TRUE(ParseTransportParameters(version_, Perspective::IS_CLIENT, + serialized.data(), serialized.size(), + &new_params, &error_details)) + << error_details; + EXPECT_TRUE(error_details.empty()); + + // Deserialized parameters have grease added. + EXPECT_NE(new_params, orig_params); + + DegreaseTransportParameters(new_params); + EXPECT_EQ(new_params, orig_params); +} + +class TransportParametersTicketSerializationTest : public QuicTest { + protected: + void SetUp() override { + original_params_.perspective = Perspective::IS_SERVER; + original_params_.legacy_version_information = + CreateFakeLegacyVersionInformationServer(); + original_params_.original_destination_connection_id = + CreateFakeOriginalDestinationConnectionId(); + original_params_.max_idle_timeout_ms.set_value( + kFakeIdleTimeoutMilliseconds); + original_params_.stateless_reset_token = CreateStatelessResetTokenForTest(); + original_params_.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + original_params_.initial_max_data.set_value(kFakeInitialMaxData); + original_params_.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + original_params_.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + original_params_.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + original_params_.initial_max_streams_bidi.set_value( + kFakeInitialMaxStreamsBidi); + original_params_.initial_max_streams_uni.set_value( + kFakeInitialMaxStreamsUni); + original_params_.ack_delay_exponent.set_value(kAckDelayExponentForTest); + original_params_.max_ack_delay.set_value(kMaxAckDelayForTest); + original_params_.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + original_params_.disable_active_migration = kFakeDisableMigration; + original_params_.preferred_address = CreateFakePreferredAddress(); + original_params_.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + original_params_.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + original_params_.retry_source_connection_id = + CreateFakeRetrySourceConnectionId(); + original_params_.google_connection_options = + CreateFakeGoogleConnectionOptions(); + + ASSERT_TRUE(SerializeTransportParametersForTicket( + original_params_, application_state_, &original_serialized_params_)); + } + + TransportParameters original_params_; + std::vector application_state_ = {0, 1}; + std::vector original_serialized_params_; +}; + +TEST_F(TransportParametersTicketSerializationTest, + StatelessResetTokenDoesntChangeOutput) { + // Test that changing the stateless reset token doesn't change the ticket + // serialization. + TransportParameters new_params = original_params_; + new_params.stateless_reset_token = CreateFakePreferredStatelessResetToken(); + EXPECT_NE(new_params, original_params_); + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParametersForTicket( + new_params, application_state_, &serialized)); + EXPECT_EQ(original_serialized_params_, serialized); +} + +TEST_F(TransportParametersTicketSerializationTest, + ConnectionIDDoesntChangeOutput) { + // Changing original destination CID doesn't change serialization. + TransportParameters new_params = original_params_; + new_params.original_destination_connection_id = TestConnectionId(0xCAFE); + EXPECT_NE(new_params, original_params_); + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParametersForTicket( + new_params, application_state_, &serialized)); + EXPECT_EQ(original_serialized_params_, serialized); +} + +TEST_F(TransportParametersTicketSerializationTest, StreamLimitChangesOutput) { + // Changing a stream limit does change the serialization. + TransportParameters new_params = original_params_; + new_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal + 1); + EXPECT_NE(new_params, original_params_); + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParametersForTicket( + new_params, application_state_, &serialized)); + EXPECT_NE(original_serialized_params_, serialized); +} + +TEST_F(TransportParametersTicketSerializationTest, + ApplicationStateChangesOutput) { + // Changing the application state changes the serialization. + std::vector new_application_state = {0}; + EXPECT_NE(new_application_state, application_state_); + + std::vector serialized; + ASSERT_TRUE(SerializeTransportParametersForTicket( + original_params_, new_application_state, &serialized)); + EXPECT_NE(original_serialized_params_, serialized); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.cc b/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.cc new file mode 100644 index 000000000000..167e4efc456e --- /dev/null +++ b/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.cc @@ -0,0 +1,231 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h" + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "openssl/sha.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { + +constexpr size_t kFingerprintLength = SHA256_DIGEST_LENGTH * 3 - 1; + +// Assumes that the character is normalized to lowercase beforehand. +bool IsNormalizedHexDigit(char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'); +} + +void NormalizeFingerprint(CertificateFingerprint& fingerprint) { + fingerprint.fingerprint = + quiche::QuicheTextUtils::ToLower(fingerprint.fingerprint); +} + +} // namespace + +constexpr char CertificateFingerprint::kSha256[]; +constexpr char WebTransportHash::kSha256[]; + +ProofVerifyDetails* WebTransportFingerprintProofVerifier::Details::Clone() + const { + return new Details(*this); +} + +WebTransportFingerprintProofVerifier::WebTransportFingerprintProofVerifier( + const QuicClock* clock, int max_validity_days) + : clock_(clock), + max_validity_days_(max_validity_days), + // Add an extra second to max validity to accomodate various edge cases. + max_validity_( + QuicTime::Delta::FromSeconds(max_validity_days * 86400 + 1)) {} + +bool WebTransportFingerprintProofVerifier::AddFingerprint( + CertificateFingerprint fingerprint) { + NormalizeFingerprint(fingerprint); + if (!absl::EqualsIgnoreCase(fingerprint.algorithm, + CertificateFingerprint::kSha256)) { + QUIC_DLOG(WARNING) << "Algorithms other than SHA-256 are not supported"; + return false; + } + if (fingerprint.fingerprint.size() != kFingerprintLength) { + QUIC_DLOG(WARNING) << "Invalid fingerprint length"; + return false; + } + for (size_t i = 0; i < fingerprint.fingerprint.size(); i++) { + char current = fingerprint.fingerprint[i]; + if (i % 3 == 2) { + if (current != ':') { + QUIC_DLOG(WARNING) + << "Missing colon separator between the bytes of the hash"; + return false; + } + } else { + if (!IsNormalizedHexDigit(current)) { + QUIC_DLOG(WARNING) << "Fingerprint must be in hexadecimal"; + return false; + } + } + } + + std::string normalized = + absl::StrReplaceAll(fingerprint.fingerprint, {{":", ""}}); + hashes_.push_back(WebTransportHash{fingerprint.algorithm, + absl::HexStringToBytes(normalized)}); + return true; +} + +bool WebTransportFingerprintProofVerifier::AddFingerprint( + WebTransportHash hash) { + if (hash.algorithm != CertificateFingerprint::kSha256) { + QUIC_DLOG(WARNING) << "Algorithms other than SHA-256 are not supported"; + return false; + } + if (hash.value.size() != SHA256_DIGEST_LENGTH) { + QUIC_DLOG(WARNING) << "Invalid fingerprint length"; + return false; + } + hashes_.push_back(std::move(hash)); + return true; +} + +QuicAsyncStatus WebTransportFingerprintProofVerifier::VerifyProof( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::string& /*server_config*/, + QuicTransportVersion /*transport_version*/, absl::string_view /*chlo_hash*/, + const std::vector& /*certs*/, const std::string& /*cert_sct*/, + const std::string& /*signature*/, const ProofVerifyContext* /*context*/, + std::string* error_details, std::unique_ptr* details, + std::unique_ptr /*callback*/) { + *error_details = + "QUIC crypto certificate verification is not supported in " + "WebTransportFingerprintProofVerifier"; + QUIC_BUG(quic_bug_10879_1) << *error_details; + *details = std::make_unique
(Status::kInternalError); + return QUIC_FAILURE; +} + +QuicAsyncStatus WebTransportFingerprintProofVerifier::VerifyCertChain( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::vector& certs, const std::string& /*ocsp_response*/, + const std::string& /*cert_sct*/, const ProofVerifyContext* /*context*/, + std::string* error_details, std::unique_ptr* details, + uint8_t* /*out_alert*/, + std::unique_ptr /*callback*/) { + if (certs.empty()) { + *details = std::make_unique
(Status::kInternalError); + *error_details = "No certificates provided"; + return QUIC_FAILURE; + } + + if (!HasKnownFingerprint(certs[0])) { + *details = std::make_unique
(Status::kUnknownFingerprint); + *error_details = "Certificate does not match any fingerprint"; + return QUIC_FAILURE; + } + + std::unique_ptr view = + CertificateView::ParseSingleCertificate(certs[0]); + if (view == nullptr) { + *details = std::make_unique
(Status::kCertificateParseFailure); + *error_details = "Failed to parse the certificate"; + return QUIC_FAILURE; + } + + if (!HasValidExpiry(*view)) { + *details = std::make_unique
(Status::kExpiryTooLong); + *error_details = + absl::StrCat("Certificate expiry exceeds the configured limit of ", + max_validity_days_, " days"); + return QUIC_FAILURE; + } + + if (!IsWithinValidityPeriod(*view)) { + *details = std::make_unique
(Status::kExpired); + *error_details = + "Certificate has expired or has validity listed in the future"; + return QUIC_FAILURE; + } + + if (!IsKeyTypeAllowedByPolicy(*view)) { + *details = std::make_unique
(Status::kDisallowedKeyAlgorithm); + *error_details = + absl::StrCat("Certificate uses a disallowed public key type (", + PublicKeyTypeToString(view->public_key_type()), ")"); + return QUIC_FAILURE; + } + + *details = std::make_unique
(Status::kValidCertificate); + return QUIC_SUCCESS; +} + +std::unique_ptr +WebTransportFingerprintProofVerifier::CreateDefaultContext() { + return nullptr; +} + +bool WebTransportFingerprintProofVerifier::HasKnownFingerprint( + absl::string_view der_certificate) { + // https://w3c.github.io/webtransport/#verify-a-certificate-hash + const std::string hash = RawSha256(der_certificate); + for (const WebTransportHash& reference : hashes_) { + if (reference.algorithm != WebTransportHash::kSha256) { + QUIC_BUG(quic_bug_10879_2) << "Unexpected non-SHA-256 hash"; + continue; + } + if (hash == reference.value) { + return true; + } + } + return false; +} + +bool WebTransportFingerprintProofVerifier::HasValidExpiry( + const CertificateView& certificate) { + if (!certificate.validity_start().IsBefore(certificate.validity_end())) { + return false; + } + + const QuicTime::Delta duration_seconds = + certificate.validity_end() - certificate.validity_start(); + return duration_seconds <= max_validity_; +} + +bool WebTransportFingerprintProofVerifier::IsWithinValidityPeriod( + const CertificateView& certificate) { + QuicWallTime now = clock_->WallNow(); + return now.IsAfter(certificate.validity_start()) && + now.IsBefore(certificate.validity_end()); +} + +bool WebTransportFingerprintProofVerifier::IsKeyTypeAllowedByPolicy( + const CertificateView& certificate) { + switch (certificate.public_key_type()) { + // https://github.com/w3c/webtransport/pull/375 defines P-256 as an MTI + // algorithm, and prohibits RSA. We also allow P-384 and Ed25519. + case PublicKeyType::kP256: + case PublicKeyType::kP384: + case PublicKeyType::kEd25519: + return true; + case PublicKeyType::kRsa: + // TODO(b/213614428): this should be false by default. + return true; + default: + return false; + } +} + +} // namespace quic diff --git a/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h b/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h new file mode 100644 index 000000000000..ea1908d27c55 --- /dev/null +++ b/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h @@ -0,0 +1,126 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_WEB_TRANSPORT_FINGERPRINT_PROOF_VERIFIER_H_ +#define QUICHE_QUIC_CORE_CRYPTO_WEB_TRANSPORT_FINGERPRINT_PROOF_VERIFIER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Represents a fingerprint of an X.509 certificate in a format based on +// https://w3c.github.io/webrtc-pc/#dom-rtcdtlsfingerprint. +// TODO(vasilvv): remove this once all consumers of this API use +// WebTransportHash. +struct QUIC_EXPORT_PRIVATE CertificateFingerprint { + static constexpr char kSha256[] = "sha-256"; + + // An algorithm described by one of the names in + // https://www.iana.org/assignments/hash-function-text-names/hash-function-text-names.xhtml + std::string algorithm; + // Hex-encoded, colon-separated fingerprint of the certificate. For example, + // "12:3d:5b:71:8c:54:df:85:7e:bd:e3:7c:66:da:f9:db:6a:94:8f:85:cb:6e:44:7f:09:3e:05:f2:dd:d4:f7:86" + std::string fingerprint; +}; + +// Represents a fingerprint of an X.509 certificate in a format based on +// https://w3c.github.io/webtransport/#dictdef-webtransporthash. +struct QUIC_EXPORT_PRIVATE WebTransportHash { + static constexpr char kSha256[] = "sha-256"; + + // An algorithm described by one of the names in + // https://www.iana.org/assignments/hash-function-text-names/hash-function-text-names.xhtml + std::string algorithm; + // Raw bytes of the hash. + std::string value; +}; + +// WebTransportFingerprintProofVerifier verifies the server leaf certificate +// against a supplied list of certificate fingerprints following the procedure +// described in the WebTransport specification. The certificate is deemed +// trusted if it matches a fingerprint in the list, has expiry dates that are +// not too long and has not expired. Only the leaf is checked, the rest of the +// chain is ignored. Reference specification: +// https://wicg.github.io/web-transport/#dom-quictransportconfiguration-server_certificate_fingerprints +class QUIC_EXPORT_PRIVATE WebTransportFingerprintProofVerifier + : public ProofVerifier { + public: + // Note: the entries in this list may be logged into a UMA histogram, and thus + // should not be renumbered. + enum class Status { + kValidCertificate = 0, + kUnknownFingerprint = 1, + kCertificateParseFailure = 2, + kExpiryTooLong = 3, + kExpired = 4, + kInternalError = 5, + kDisallowedKeyAlgorithm = 6, + + kMaxValue = kDisallowedKeyAlgorithm, + }; + + class QUIC_EXPORT_PRIVATE Details : public ProofVerifyDetails { + public: + explicit Details(Status status) : status_(status) {} + Status status() const { return status_; } + + ProofVerifyDetails* Clone() const override; + + private: + const Status status_; + }; + + // |clock| is used to check if the certificate has expired. It is not owned + // and must outlive the object. |max_validity_days| is the maximum time for + // which the certificate is allowed to be valid. + WebTransportFingerprintProofVerifier(const QuicClock* clock, + int max_validity_days); + + // Adds a certificate fingerprint to be trusted. The fingerprints are + // case-insensitive and are validated internally; the function returns true if + // the validation passes. + bool AddFingerprint(CertificateFingerprint fingerprint); + bool AddFingerprint(WebTransportHash hash); + + // ProofVerifier implementation. + QuicAsyncStatus VerifyProof( + const std::string& hostname, const uint16_t port, + const std::string& server_config, QuicTransportVersion transport_version, + absl::string_view chlo_hash, const std::vector& certs, + const std::string& cert_sct, const std::string& signature, + const ProofVerifyContext* context, std::string* error_details, + std::unique_ptr* details, + std::unique_ptr callback) override; + QuicAsyncStatus VerifyCertChain( + const std::string& hostname, const uint16_t port, + const std::vector& certs, const std::string& ocsp_response, + const std::string& cert_sct, const ProofVerifyContext* context, + std::string* error_details, std::unique_ptr* details, + uint8_t* out_alert, + std::unique_ptr callback) override; + std::unique_ptr CreateDefaultContext() override; + + protected: + virtual bool IsKeyTypeAllowedByPolicy(const CertificateView& certificate); + + private: + bool HasKnownFingerprint(absl::string_view der_certificate); + bool HasValidExpiry(const CertificateView& certificate); + bool IsWithinValidityPeriod(const CertificateView& certificate); + + const QuicClock* clock_; // Unowned. + const int max_validity_days_; + const QuicTime::Delta max_validity_; + std::vector hashes_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_WEB_TRANSPORT_FINGERPRINT_PROOF_VERIFIER_H_ diff --git a/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc b/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc new file mode 100644 index 000000000000..11c769d76ecd --- /dev/null +++ b/quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier_test.cc @@ -0,0 +1,183 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/crypto/web_transport_fingerprint_proof_verifier.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/test_certificates.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::HasSubstr; + +// 2020-02-01 12:35:56 UTC +constexpr QuicTime::Delta kValidTime = QuicTime::Delta::FromSeconds(1580560556); + +struct VerifyResult { + QuicAsyncStatus status; + WebTransportFingerprintProofVerifier::Status detailed_status; + std::string error; +}; + +class WebTransportFingerprintProofVerifierTest : public QuicTest { + public: + WebTransportFingerprintProofVerifierTest() { + clock_.AdvanceTime(kValidTime); + verifier_ = std::make_unique( + &clock_, /*max_validity_days=*/365); + AddTestCertificate(); + } + + protected: + VerifyResult Verify(absl::string_view certificate) { + VerifyResult result; + std::unique_ptr details; + uint8_t tls_alert; + result.status = verifier_->VerifyCertChain( + /*hostname=*/"", /*port=*/0, + std::vector{std::string(certificate)}, + /*ocsp_response=*/"", + /*cert_sct=*/"", + /*context=*/nullptr, &result.error, &details, &tls_alert, + /*callback=*/nullptr); + result.detailed_status = + static_cast( + details.get()) + ->status(); + return result; + } + + void AddTestCertificate() { + EXPECT_TRUE(verifier_->AddFingerprint(WebTransportHash{ + WebTransportHash::kSha256, RawSha256(kTestCertificate)})); + } + + MockClock clock_; + std::unique_ptr verifier_; +}; + +TEST_F(WebTransportFingerprintProofVerifierTest, Sha256Fingerprint) { + // Computed using `openssl x509 -fingerprint -sha256`. + EXPECT_EQ(absl::BytesToHexString(RawSha256(kTestCertificate)), + "f2e5465e2bf7ecd6f63066a5a37511734aa0eb7c4701" + "0e86d6758ed4f4fa1b0f"); +} + +TEST_F(WebTransportFingerprintProofVerifierTest, SimpleFingerprint) { + VerifyResult result = Verify(kTestCertificate); + EXPECT_EQ(result.status, QUIC_SUCCESS); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kValidCertificate); + + result = Verify(kWildcardCertificate); + EXPECT_EQ(result.status, QUIC_FAILURE); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kUnknownFingerprint); + + result = Verify("Some random text"); + EXPECT_EQ(result.status, QUIC_FAILURE); +} + +TEST_F(WebTransportFingerprintProofVerifierTest, Validity) { + // Validity periods of kTestCertificate, according to `openssl x509 -text`: + // Not Before: Jan 30 18:13:59 2020 GMT + // Not After : Feb 2 18:13:59 2020 GMT + + // 2020-01-29 19:00:00 UTC + constexpr QuicTime::Delta kStartTime = + QuicTime::Delta::FromSeconds(1580324400); + clock_.Reset(); + clock_.AdvanceTime(kStartTime); + + VerifyResult result = Verify(kTestCertificate); + EXPECT_EQ(result.status, QUIC_FAILURE); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kExpired); + + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(86400)); + result = Verify(kTestCertificate); + EXPECT_EQ(result.status, QUIC_SUCCESS); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kValidCertificate); + + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(4 * 86400)); + result = Verify(kTestCertificate); + EXPECT_EQ(result.status, QUIC_FAILURE); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kExpired); +} + +TEST_F(WebTransportFingerprintProofVerifierTest, MaxValidity) { + verifier_ = std::make_unique( + &clock_, /*max_validity_days=*/2); + AddTestCertificate(); + VerifyResult result = Verify(kTestCertificate); + EXPECT_EQ(result.status, QUIC_FAILURE); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kExpiryTooLong); + EXPECT_THAT(result.error, HasSubstr("limit of 2 days")); + + // kTestCertificate is valid for exactly four days. + verifier_ = std::make_unique( + &clock_, /*max_validity_days=*/4); + AddTestCertificate(); + result = Verify(kTestCertificate); + EXPECT_EQ(result.status, QUIC_SUCCESS); + EXPECT_EQ(result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kValidCertificate); +} + +TEST_F(WebTransportFingerprintProofVerifierTest, InvalidCertificate) { + constexpr absl::string_view kInvalidCertificate = "Hello, world!"; + ASSERT_TRUE(verifier_->AddFingerprint(WebTransportHash{ + WebTransportHash::kSha256, RawSha256(kInvalidCertificate)})); + + VerifyResult result = Verify(kInvalidCertificate); + EXPECT_EQ(result.status, QUIC_FAILURE); + EXPECT_EQ( + result.detailed_status, + WebTransportFingerprintProofVerifier::Status::kCertificateParseFailure); +} + +TEST_F(WebTransportFingerprintProofVerifierTest, AddCertificate) { + // Accept all-uppercase fingerprints. + verifier_ = std::make_unique( + &clock_, /*max_validity_days=*/365); + EXPECT_TRUE(verifier_->AddFingerprint(CertificateFingerprint{ + CertificateFingerprint::kSha256, + "F2:E5:46:5E:2B:F7:EC:D6:F6:30:66:A5:A3:75:11:73:4A:A0:EB:" + "7C:47:01:0E:86:D6:75:8E:D4:F4:FA:1B:0F"})); + EXPECT_EQ(Verify(kTestCertificate).detailed_status, + WebTransportFingerprintProofVerifier::Status::kValidCertificate); + + // Reject unknown hash algorithms. + EXPECT_FALSE(verifier_->AddFingerprint(CertificateFingerprint{ + "sha-1", "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"})); + // Reject invalid length. + EXPECT_FALSE(verifier_->AddFingerprint( + CertificateFingerprint{CertificateFingerprint::kSha256, "00:00:00:00"})); + // Reject missing colons. + EXPECT_FALSE(verifier_->AddFingerprint(CertificateFingerprint{ + CertificateFingerprint::kSha256, + "00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00." + "00.00.00.00.00.00.00.00.00.00.00.00.00"})); + // Reject non-hex symbols. + EXPECT_FALSE(verifier_->AddFingerprint(CertificateFingerprint{ + CertificateFingerprint::kSha256, + "zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:" + "zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz"})); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/deterministic_connection_id_generator.cc b/quiche/quic/core/deterministic_connection_id_generator.cc new file mode 100644 index 000000000000..fd86dc7bdee6 --- /dev/null +++ b/quiche/quic/core/deterministic_connection_id_generator.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/deterministic_connection_id_generator.h" + +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +DeterministicConnectionIdGenerator::DeterministicConnectionIdGenerator( + uint8_t expected_connection_id_length) + : expected_connection_id_length_(expected_connection_id_length) { + if (expected_connection_id_length_ > + kQuicMaxConnectionIdWithLengthPrefixLength) { + QUIC_BUG(quic_bug_465151159_01) + << "Issuing connection IDs longer than allowed in RFC9000"; + } +} + +absl::optional +DeterministicConnectionIdGenerator::GenerateNextConnectionId( + const QuicConnectionId& original) { + if (expected_connection_id_length_ == 0) { + return EmptyQuicConnectionId(); + } + const uint64_t connection_id_hash64 = QuicUtils::FNV1a_64_Hash( + absl::string_view(original.data(), original.length())); + if (expected_connection_id_length_ <= sizeof(uint64_t)) { + return QuicConnectionId( + reinterpret_cast(&connection_id_hash64), + expected_connection_id_length_); + } + char new_connection_id_data[255] = {}; + const absl::uint128 connection_id_hash128 = QuicUtils::FNV1a_128_Hash( + absl::string_view(original.data(), original.length())); + static_assert(sizeof(connection_id_hash64) + sizeof(connection_id_hash128) <= + sizeof(new_connection_id_data), + "bad size"); + memcpy(new_connection_id_data, &connection_id_hash64, + sizeof(connection_id_hash64)); + // TODO(martinduke): We don't have any test coverage of the line below. In + // particular, if the memcpy somehow misses a byte, a test could check if one + // byte position in generated connection IDs is always the same. + memcpy(new_connection_id_data + sizeof(connection_id_hash64), + &connection_id_hash128, sizeof(connection_id_hash128)); + return QuicConnectionId(new_connection_id_data, + expected_connection_id_length_); +} + +absl::optional +DeterministicConnectionIdGenerator::MaybeReplaceConnectionId( + const QuicConnectionId& original, const ParsedQuicVersion& version) { + if (original.length() == expected_connection_id_length_) { + return absl::optional(); + } + QUICHE_DCHECK(version.AllowsVariableLengthConnectionIds()); + absl::optional new_connection_id = + GenerateNextConnectionId(original); + // Verify that ReplaceShortServerConnectionId is deterministic. + QUICHE_DCHECK(new_connection_id.has_value()); + QUICHE_DCHECK_EQ( + *new_connection_id, + static_cast(*GenerateNextConnectionId(original))); + QUICHE_DCHECK_EQ(expected_connection_id_length_, new_connection_id->length()); + QUIC_DLOG(INFO) << "Replacing incoming connection ID " << original << " with " + << new_connection_id.value(); + return new_connection_id; +} + +} // namespace quic diff --git a/quiche/quic/core/deterministic_connection_id_generator.h b/quiche/quic/core/deterministic_connection_id_generator.h new file mode 100644 index 000000000000..74d42d9159c1 --- /dev/null +++ b/quiche/quic/core/deterministic_connection_id_generator.h @@ -0,0 +1,40 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A Connection ID generator that generates deterministic connection IDs for +// QUIC servers. + +#ifndef QUICHE_QUIC_CORE_CONNECTION_ID_GENERATOR_DETERMINISTIC_H_ +#define QUICHE_QUIC_CORE_CONNECTION_ID_GENERATOR_DETERMINISTIC_H_ + +#include "quiche/quic/core/connection_id_generator.h" + +namespace quic { + +// Generates connection IDs deterministically from the provided original +// connection ID. +class QUIC_EXPORT_PRIVATE DeterministicConnectionIdGenerator + : public ConnectionIdGeneratorInterface { + public: + DeterministicConnectionIdGenerator(uint8_t expected_connection_id_length); + + // Hashes |original| to create a new connection ID. + absl::optional GenerateNextConnectionId( + const QuicConnectionId& original) override; + // Replace the connection ID if and only if |original| is not of the expected + // length. + absl::optional MaybeReplaceConnectionId( + const QuicConnectionId& original, + const ParsedQuicVersion& version) override; + uint8_t ConnectionIdLength(uint8_t /*first_byte*/) const override { + return expected_connection_id_length_; + } + + private: + const uint8_t expected_connection_id_length_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE__CONNECTION_ID_GENERATOR_DETERMINISTIC_H_ diff --git a/quiche/quic/core/deterministic_connection_id_generator_test.cc b/quiche/quic/core/deterministic_connection_id_generator_test.cc new file mode 100644 index 000000000000..6c3ee21d49f3 --- /dev/null +++ b/quiche/quic/core/deterministic_connection_id_generator_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/deterministic_connection_id_generator.h" + +#include + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +struct TestParams { + TestParams(int connection_id_length) + : connection_id_length_(connection_id_length) {} + TestParams() : TestParams(kQuicDefaultConnectionIdLength) {} + + friend std::ostream& operator<<(std::ostream& os, const TestParams& p) { + os << "{ connection ID length: " << p.connection_id_length_ << " }"; + return os; + } + + int connection_id_length_; +}; + +// Constructs various test permutations. +std::vector GetTestParams() { + std::vector params; + std::vector connection_id_lengths{7, 8, 9, 16, 20}; + for (int connection_id_length : connection_id_lengths) { + params.push_back(TestParams(connection_id_length)); + } + return params; +} + +class DeterministicConnectionIdGeneratorTest + : public QuicTestWithParam { + public: + DeterministicConnectionIdGeneratorTest() + : connection_id_length_(GetParam().connection_id_length_), + generator_(DeterministicConnectionIdGenerator(connection_id_length_)), + version_(ParsedQuicVersion::RFCv1()) {} + + protected: + int connection_id_length_; + DeterministicConnectionIdGenerator generator_; + ParsedQuicVersion version_; +}; + +INSTANTIATE_TEST_SUITE_P(DeterministicConnectionIdGeneratorTests, + DeterministicConnectionIdGeneratorTest, + ::testing::ValuesIn(GetTestParams())); + +TEST_P(DeterministicConnectionIdGeneratorTest, + NextConnectionIdIsDeterministic) { + // Verify that two equal connection IDs get the same replacement. + QuicConnectionId connection_id64a = TestConnectionId(33); + QuicConnectionId connection_id64b = TestConnectionId(33); + EXPECT_EQ(connection_id64a, connection_id64b); + EXPECT_EQ(*generator_.GenerateNextConnectionId(connection_id64a), + *generator_.GenerateNextConnectionId(connection_id64b)); + QuicConnectionId connection_id72a = TestConnectionIdNineBytesLong(42); + QuicConnectionId connection_id72b = TestConnectionIdNineBytesLong(42); + EXPECT_EQ(connection_id72a, connection_id72b); + EXPECT_EQ(*generator_.GenerateNextConnectionId(connection_id72a), + *generator_.GenerateNextConnectionId(connection_id72b)); +} + +TEST_P(DeterministicConnectionIdGeneratorTest, + NextConnectionIdLengthIsCorrect) { + // Verify that all generated IDs are of the correct length. + const char connection_id_bytes[255] = {}; + for (uint8_t i = 0; i < sizeof(connection_id_bytes) - 1; ++i) { + QuicConnectionId connection_id(connection_id_bytes, i); + absl::optional replacement_connection_id = + generator_.GenerateNextConnectionId(connection_id); + ASSERT_TRUE(replacement_connection_id.has_value()); + EXPECT_EQ(connection_id_length_, replacement_connection_id->length()); + } +} + +TEST_P(DeterministicConnectionIdGeneratorTest, NextConnectionIdHasEntropy) { + // Make sure all these test connection IDs have different replacements. + for (uint64_t i = 0; i < 256; ++i) { + QuicConnectionId connection_id_i = TestConnectionId(i); + absl::optional new_i = + generator_.GenerateNextConnectionId(connection_id_i); + ASSERT_TRUE(new_i.has_value()); + EXPECT_NE(connection_id_i, *new_i); + for (uint64_t j = i + 1; j <= 256; ++j) { + QuicConnectionId connection_id_j = TestConnectionId(j); + EXPECT_NE(connection_id_i, connection_id_j); + absl::optional new_j = + generator_.GenerateNextConnectionId(connection_id_j); + ASSERT_TRUE(new_j.has_value()); + EXPECT_NE(*new_i, *new_j); + } + } +} + +TEST_P(DeterministicConnectionIdGeneratorTest, + OnlyReplaceConnectionIdWithWrongLength) { + const char connection_id_input[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14}; + for (int i = 0; i < kQuicMaxConnectionIdWithLengthPrefixLength; i++) { + QuicConnectionId input = QuicConnectionId(connection_id_input, i); + absl::optional output = + generator_.MaybeReplaceConnectionId(input, version_); + if (i == connection_id_length_) { + EXPECT_FALSE(output.has_value()); + } else { + ASSERT_TRUE(output.has_value()); + EXPECT_EQ(*output, generator_.GenerateNextConnectionId(input)); + } + } +} + +TEST_P(DeterministicConnectionIdGeneratorTest, ReturnLength) { + EXPECT_EQ(generator_.ConnectionIdLength(0x01), connection_id_length_); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/frames/quic_ack_frame.cc b/quiche/quic/core/frames/quic_ack_frame.cc new file mode 100644 index 000000000000..1e42b7dfcd65 --- /dev/null +++ b/quiche/quic/core/frames/quic_ack_frame.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_ack_frame.h" + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" + +namespace quic { + +namespace { + +const QuicPacketCount kMaxPrintRange = 128; + +} // namespace + +bool IsAwaitingPacket(const QuicAckFrame& ack_frame, + QuicPacketNumber packet_number, + QuicPacketNumber peer_least_packet_awaiting_ack) { + QUICHE_DCHECK(packet_number.IsInitialized()); + return (!peer_least_packet_awaiting_ack.IsInitialized() || + packet_number >= peer_least_packet_awaiting_ack) && + !ack_frame.packets.Contains(packet_number); +} + +QuicAckFrame::QuicAckFrame() = default; + +QuicAckFrame::QuicAckFrame(const QuicAckFrame& other) = default; + +QuicAckFrame::~QuicAckFrame() {} + +std::ostream& operator<<(std::ostream& os, const QuicAckFrame& ack_frame) { + os << "{ largest_acked: " << LargestAcked(ack_frame) + << ", ack_delay_time: " << ack_frame.ack_delay_time.ToMicroseconds() + << ", packets: [ " << ack_frame.packets << " ]" + << ", received_packets: [ "; + for (const std::pair& p : + ack_frame.received_packet_times) { + os << p.first << " at " << p.second.ToDebuggingValue() << " "; + } + os << " ]"; + os << ", ecn_counters_populated: " << ack_frame.ecn_counters.has_value(); + if (ack_frame.ecn_counters.has_value()) { + os << ", ect_0_count: " << ack_frame.ecn_counters->ect0 + << ", ect_1_count: " << ack_frame.ecn_counters->ect1 + << ", ecn_ce_count: " << ack_frame.ecn_counters->ce; + } + + os << " }\n"; + return os; +} + +void QuicAckFrame::Clear() { + largest_acked.Clear(); + ack_delay_time = QuicTime::Delta::Infinite(); + received_packet_times.clear(); + packets.Clear(); +} + +PacketNumberQueue::PacketNumberQueue() {} +PacketNumberQueue::PacketNumberQueue(const PacketNumberQueue& other) = default; +PacketNumberQueue::PacketNumberQueue(PacketNumberQueue&& other) = default; +PacketNumberQueue::~PacketNumberQueue() {} + +PacketNumberQueue& PacketNumberQueue::operator=( + const PacketNumberQueue& other) = default; +PacketNumberQueue& PacketNumberQueue::operator=(PacketNumberQueue&& other) = + default; + +void PacketNumberQueue::Add(QuicPacketNumber packet_number) { + if (!packet_number.IsInitialized()) { + return; + } + packet_number_intervals_.AddOptimizedForAppend(packet_number, + packet_number + 1); +} + +void PacketNumberQueue::AddRange(QuicPacketNumber lower, + QuicPacketNumber higher) { + if (!lower.IsInitialized() || !higher.IsInitialized() || lower >= higher) { + return; + } + + packet_number_intervals_.AddOptimizedForAppend(lower, higher); +} + +bool PacketNumberQueue::RemoveUpTo(QuicPacketNumber higher) { + if (!higher.IsInitialized() || Empty()) { + return false; + } + return packet_number_intervals_.TrimLessThan(higher); +} + +void PacketNumberQueue::RemoveSmallestInterval() { + // TODO(wub): Move this QUIC_BUG to upper level. + QUIC_BUG_IF(quic_bug_12614_1, packet_number_intervals_.Size() < 2) + << (Empty() ? "No intervals to remove." + : "Can't remove the last interval."); + packet_number_intervals_.PopFront(); +} + +void PacketNumberQueue::Clear() { packet_number_intervals_.Clear(); } + +bool PacketNumberQueue::Contains(QuicPacketNumber packet_number) const { + if (!packet_number.IsInitialized()) { + return false; + } + return packet_number_intervals_.Contains(packet_number); +} + +bool PacketNumberQueue::Empty() const { + return packet_number_intervals_.Empty(); +} + +QuicPacketNumber PacketNumberQueue::Min() const { + QUICHE_DCHECK(!Empty()); + return packet_number_intervals_.begin()->min(); +} + +QuicPacketNumber PacketNumberQueue::Max() const { + QUICHE_DCHECK(!Empty()); + return packet_number_intervals_.rbegin()->max() - 1; +} + +QuicPacketCount PacketNumberQueue::NumPacketsSlow() const { + QuicPacketCount n_packets = 0; + for (const auto& interval : packet_number_intervals_) { + n_packets += interval.Length(); + } + return n_packets; +} + +size_t PacketNumberQueue::NumIntervals() const { + return packet_number_intervals_.Size(); +} + +PacketNumberQueue::const_iterator PacketNumberQueue::begin() const { + return packet_number_intervals_.begin(); +} + +PacketNumberQueue::const_iterator PacketNumberQueue::end() const { + return packet_number_intervals_.end(); +} + +PacketNumberQueue::const_reverse_iterator PacketNumberQueue::rbegin() const { + return packet_number_intervals_.rbegin(); +} + +PacketNumberQueue::const_reverse_iterator PacketNumberQueue::rend() const { + return packet_number_intervals_.rend(); +} + +QuicPacketCount PacketNumberQueue::LastIntervalLength() const { + QUICHE_DCHECK(!Empty()); + return packet_number_intervals_.rbegin()->Length(); +} + +// Largest min...max range for packet numbers where we print the numbers +// explicitly. If bigger than this, we print as a range [a,d] rather +// than [a b c d] + +std::ostream& operator<<(std::ostream& os, const PacketNumberQueue& q) { + for (const QuicInterval& interval : q) { + // Print as a range if there is a pathological condition. + if ((interval.min() >= interval.max()) || + (interval.max() - interval.min() > kMaxPrintRange)) { + // If min>max, it's really a bug, so QUIC_BUG it to + // catch it in development. + QUIC_BUG_IF(quic_bug_12614_2, interval.min() >= interval.max()) + << "Ack Range minimum (" << interval.min() << "Not less than max (" + << interval.max() << ")"; + // print range as min...max rather than full list. + // in the event of a bug, the list could be very big. + os << interval.min() << "..." << (interval.max() - 1) << " "; + } else { + for (QuicPacketNumber packet_number = interval.min(); + packet_number < interval.max(); ++packet_number) { + os << packet_number << " "; + } + } + } + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_ack_frame.h b/quiche/quic/core/frames/quic_ack_frame.h new file mode 100644 index 000000000000..6828a8c6e9ce --- /dev/null +++ b/quiche/quic/core/frames/quic_ack_frame.h @@ -0,0 +1,140 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_ACK_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_ACK_FRAME_H_ + +#include + +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/core/quic_interval_set.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +// A sequence of packet numbers where each number is unique. Intended to be used +// in a sliding window fashion, where smaller old packet numbers are removed and +// larger new packet numbers are added, with the occasional random access. +class QUIC_EXPORT_PRIVATE PacketNumberQueue { + public: + PacketNumberQueue(); + PacketNumberQueue(const PacketNumberQueue& other); + PacketNumberQueue(PacketNumberQueue&& other); + ~PacketNumberQueue(); + + PacketNumberQueue& operator=(const PacketNumberQueue& other); + PacketNumberQueue& operator=(PacketNumberQueue&& other); + + using const_iterator = QuicIntervalSet::const_iterator; + using const_reverse_iterator = + QuicIntervalSet::const_reverse_iterator; + + // Adds |packet_number| to the set of packets in the queue. + void Add(QuicPacketNumber packet_number); + + // Adds packets between [lower, higher) to the set of packets in the queue. + // No-op if |higher| < |lower|. + // NOTE(wub): Only used in tests as of Nov 2019. + void AddRange(QuicPacketNumber lower, QuicPacketNumber higher); + + // Removes packets with values less than |higher| from the set of packets in + // the queue. Returns true if packets were removed. + bool RemoveUpTo(QuicPacketNumber higher); + + // Removes the smallest interval in the queue. + void RemoveSmallestInterval(); + + // Clear this packet number queue. + void Clear(); + + // Returns true if the queue contains |packet_number|. + bool Contains(QuicPacketNumber packet_number) const; + + // Returns true if the queue is empty. + bool Empty() const; + + // Returns the minimum packet number stored in the queue. It is undefined + // behavior to call this if the queue is empty. + QuicPacketNumber Min() const; + + // Returns the maximum packet number stored in the queue. It is undefined + // behavior to call this if the queue is empty. + QuicPacketNumber Max() const; + + // Returns the number of unique packets stored in the queue. Inefficient; only + // exposed for testing. + QuicPacketCount NumPacketsSlow() const; + + // Returns the number of disjoint packet number intervals contained in the + // queue. + size_t NumIntervals() const; + + // Returns the length of last interval. + QuicPacketCount LastIntervalLength() const; + + // Returns iterators over the packet number intervals. + const_iterator begin() const; + const_iterator end() const; + const_reverse_iterator rbegin() const; + const_reverse_iterator rend() const; + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const PacketNumberQueue& q); + + private: + QuicIntervalSet packet_number_intervals_; +}; + +struct QUIC_EXPORT_PRIVATE QuicAckFrame { + QuicAckFrame(); + QuicAckFrame(const QuicAckFrame& other); + ~QuicAckFrame(); + + void Clear(); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicAckFrame& ack_frame); + + // The highest packet number we've observed from the peer. When |packets| is + // not empty, it should always be equal to packets.Max(). The |LargestAcked| + // function ensures this invariant in debug mode. + QuicPacketNumber largest_acked; + + // Time elapsed since largest_observed() was received until this Ack frame was + // sent. + QuicTime::Delta ack_delay_time = QuicTime::Delta::Infinite(); + + // Vector of for when packets arrived. + // For IETF versions, packet numbers and timestamps in this vector are both in + // ascending orders. Packets received out of order are not saved here. + PacketTimeVector received_packet_times; + + // Set of packets. + PacketNumberQueue packets; + + // ECN counters. + absl::optional ecn_counters; +}; + +// The highest acked packet number we've observed from the peer. If no packets +// have been observed, return 0. +inline QUIC_EXPORT_PRIVATE QuicPacketNumber +LargestAcked(const QuicAckFrame& frame) { + QUICHE_DCHECK(frame.packets.Empty() || + frame.packets.Max() == frame.largest_acked); + return frame.largest_acked; +} + +// True if the packet number is greater than largest_observed or is listed +// as missing. +// Always returns false for packet numbers less than least_unacked. +QUIC_EXPORT_PRIVATE bool IsAwaitingPacket( + const QuicAckFrame& ack_frame, QuicPacketNumber packet_number, + QuicPacketNumber peer_least_packet_awaiting_ack); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_ACK_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_ack_frequency_frame.cc b/quiche/quic/core/frames/quic_ack_frequency_frame.cc new file mode 100644 index 000000000000..9d2fc31dad18 --- /dev/null +++ b/quiche/quic/core/frames/quic_ack_frequency_frame.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" + +#include +#include + +namespace quic { + +QuicAckFrequencyFrame::QuicAckFrequencyFrame( + QuicControlFrameId control_frame_id, uint64_t sequence_number, + uint64_t packet_tolerance, QuicTime::Delta max_ack_delay) + : control_frame_id(control_frame_id), + sequence_number(sequence_number), + packet_tolerance(packet_tolerance), + max_ack_delay(max_ack_delay) {} + +std::ostream& operator<<(std::ostream& os, const QuicAckFrequencyFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id + << ", sequence_number: " << frame.sequence_number + << ", packet_tolerance: " << frame.packet_tolerance + << ", max_ack_delay_ms: " << frame.max_ack_delay.ToMilliseconds() + << ", ignore_order: " << frame.ignore_order << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_ack_frequency_frame.h b/quiche/quic/core/frames/quic_ack_frequency_frame.h new file mode 100644 index 000000000000..c9b3519acb1a --- /dev/null +++ b/quiche/quic/core/frames/quic_ack_frequency_frame.h @@ -0,0 +1,50 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_ACK_FREQUENCY_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_ACK_FREQUENCY_FRAME_H_ + +#include +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A frame that allows sender control of acknowledgement delays. +struct QUIC_EXPORT_PRIVATE QuicAckFrequencyFrame { + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicAckFrequencyFrame& ack_frequency_frame); + + QuicAckFrequencyFrame() = default; + QuicAckFrequencyFrame(QuicControlFrameId control_frame_id, + uint64_t sequence_number, uint64_t packet_tolerance, + QuicTime::Delta max_ack_delay); + + // A unique identifier of this control frame. 0 when this frame is + // received, and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + // If true, do not ack immediately upon observeation of packet reordering. + bool ignore_order = false; + + // Sequence number assigned to the ACK_FREQUENCY frame by the sender to allow + // receivers to ignore obsolete frames. + uint64_t sequence_number = 0; + + // The maximum number of ack-eliciting packets after which the receiver sends + // an acknowledgement. Invald if == 0. + uint64_t packet_tolerance = 2; + + // The maximum time that ack packets can be delayed. + QuicTime::Delta max_ack_delay = + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_ACK_FREQUENCY_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_blocked_frame.cc b/quiche/quic/core/frames/quic_blocked_frame.cc new file mode 100644 index 000000000000..594de278e0c8 --- /dev/null +++ b/quiche/quic/core/frames/quic_blocked_frame.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_blocked_frame.h" + +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +QuicBlockedFrame::QuicBlockedFrame() : QuicInlinedFrame(BLOCKED_FRAME) {} + +QuicBlockedFrame::QuicBlockedFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicStreamOffset offset) + : QuicInlinedFrame(BLOCKED_FRAME), + control_frame_id(control_frame_id), + stream_id(stream_id), + offset(offset) {} + +std::ostream& operator<<(std::ostream& os, + const QuicBlockedFrame& blocked_frame) { + os << "{ control_frame_id: " << blocked_frame.control_frame_id + << ", stream_id: " << blocked_frame.stream_id + << ", offset: " << blocked_frame.offset << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_blocked_frame.h b/quiche/quic/core/frames/quic_blocked_frame.h new file mode 100644 index 000000000000..982e1e8ca8ce --- /dev/null +++ b/quiche/quic/core/frames/quic_blocked_frame.h @@ -0,0 +1,49 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_BLOCKED_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_BLOCKED_FRAME_H_ + +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +// The BLOCKED frame is used to indicate to the remote endpoint that this +// endpoint believes itself to be flow-control blocked but otherwise ready to +// send data. The BLOCKED frame is purely advisory and optional. +// Based on SPDY's BLOCKED frame (undocumented as of 2014-01-28). +struct QUIC_EXPORT_PRIVATE QuicBlockedFrame + : public QuicInlinedFrame { + QuicBlockedFrame(); + QuicBlockedFrame(QuicControlFrameId control_frame_id, QuicStreamId stream_id, + QuicStreamOffset offset); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicBlockedFrame& b); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + // 0 is a special case meaning the connection is blocked, rather than a + // stream. So stream_id 0 corresponds to a BLOCKED frame and non-0 + // corresponds to a STREAM_BLOCKED. + // TODO(fkastenholz): This should be converted to use + // QuicUtils::GetInvalidStreamId to get the correct invalid stream id value + // and not rely on 0. + QuicStreamId stream_id = 0; + + // For Google QUIC, the offset is ignored. + QuicStreamOffset offset = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_BLOCKED_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_connection_close_frame.cc b/quiche/quic/core/frames/quic_connection_close_frame.cc new file mode 100644 index 000000000000..640101e68253 --- /dev/null +++ b/quiche/quic/core/frames/quic_connection_close_frame.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_connection_close_frame.h" + +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +QuicConnectionCloseFrame::QuicConnectionCloseFrame( + QuicTransportVersion transport_version, QuicErrorCode error_code, + QuicIetfTransportErrorCodes ietf_error, std::string error_phrase, + uint64_t frame_type) + : quic_error_code(error_code), error_details(error_phrase) { + if (!VersionHasIetfQuicFrames(transport_version)) { + close_type = GOOGLE_QUIC_CONNECTION_CLOSE; + wire_error_code = error_code; + transport_close_frame_type = 0; + return; + } + QuicErrorCodeToIetfMapping mapping = + QuicErrorCodeToTransportErrorCode(error_code); + if (ietf_error != NO_IETF_QUIC_ERROR) { + wire_error_code = ietf_error; + } else { + wire_error_code = mapping.error_code; + } + if (mapping.is_transport_close) { + // Maps to a transport close + close_type = IETF_QUIC_TRANSPORT_CONNECTION_CLOSE; + transport_close_frame_type = frame_type; + return; + } + // Maps to an application close. + close_type = IETF_QUIC_APPLICATION_CONNECTION_CLOSE; + transport_close_frame_type = 0; +} + +std::ostream& operator<<( + std::ostream& os, const QuicConnectionCloseFrame& connection_close_frame) { + os << "{ Close type: " << connection_close_frame.close_type; + switch (connection_close_frame.close_type) { + case IETF_QUIC_TRANSPORT_CONNECTION_CLOSE: + os << ", wire_error_code: " + << static_cast( + connection_close_frame.wire_error_code); + break; + case IETF_QUIC_APPLICATION_CONNECTION_CLOSE: + os << ", wire_error_code: " << connection_close_frame.wire_error_code; + break; + case GOOGLE_QUIC_CONNECTION_CLOSE: + // Do not log, value same as |quic_error_code|. + break; + } + os << ", quic_error_code: " + << QuicErrorCodeToString(connection_close_frame.quic_error_code) + << ", error_details: '" << connection_close_frame.error_details << "'"; + if (connection_close_frame.close_type == + IETF_QUIC_TRANSPORT_CONNECTION_CLOSE) { + os << ", frame_type: " + << static_cast( + connection_close_frame.transport_close_frame_type); + } + os << "}\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_connection_close_frame.h b/quiche/quic/core/frames/quic_connection_close_frame.h new file mode 100644 index 000000000000..cf780dd5e033 --- /dev/null +++ b/quiche/quic/core/frames/quic_connection_close_frame.h @@ -0,0 +1,68 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_CONNECTION_CLOSE_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_CONNECTION_CLOSE_FRAME_H_ + +#include +#include + +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicConnectionCloseFrame { + QuicConnectionCloseFrame() = default; + + // Builds a connection close frame based on the transport version + // and the mapping of error_code. THIS IS THE PREFERRED C'TOR + // TO USE IF YOU NEED TO CREATE A CONNECTION-CLOSE-FRAME AND + // HAVE IT BE CORRECT FOR THE VERSION AND CODE MAPPINGS. + // |ietf_error| may optionally be be used to directly specify the wire + // error code. Otherwise if |ietf_error| is NO_IETF_QUIC_ERROR, the + // QuicErrorCodeToTransportErrorCode mapping of |error_code| will be used. + QuicConnectionCloseFrame(QuicTransportVersion transport_version, + QuicErrorCode error_code, + QuicIetfTransportErrorCodes ietf_error, + std::string error_phrase, + uint64_t transport_close_frame_type); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicConnectionCloseFrame& c); + + // Indicates whether the the frame is a Google QUIC CONNECTION_CLOSE frame, + // an IETF QUIC CONNECTION_CLOSE frame with transport error code, + // or an IETF QUIC CONNECTION_CLOSE frame with application error code. + QuicConnectionCloseType close_type = GOOGLE_QUIC_CONNECTION_CLOSE; + + // The error code on the wire. For Google QUIC frames, this has the same + // value as |quic_error_code|. + uint64_t wire_error_code = QUIC_NO_ERROR; + + // The underlying error. For Google QUIC frames, this has the same value as + // |wire_error_code|. For sent IETF QUIC frames, this is the error that + // triggered the closure of the connection. For received IETF QUIC frames, + // this is parsed from the Reason Phrase field of the CONNECTION_CLOSE frame, + // or QUIC_IETF_GQUIC_ERROR_MISSING. + QuicErrorCode quic_error_code = QUIC_NO_ERROR; + + // String with additional error details. |quic_error_code| and a colon will be + // prepended to the error details when sending IETF QUIC frames, and parsed + // into |quic_error_code| upon receipt, when present. + std::string error_details; + + // The frame type present in the IETF transport connection close frame. + // Not populated for the Google QUIC or application connection close frames. + // Contains the type of frame that triggered the connection close. Made a + // uint64, as opposed to the QuicIetfFrameType, to support possible + // extensions as well as reporting invalid frame types received from the peer. + uint64_t transport_close_frame_type = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_CONNECTION_CLOSE_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_crypto_frame.cc b/quiche/quic/core/frames/quic_crypto_frame.cc new file mode 100644 index 000000000000..11ccf6832e6e --- /dev/null +++ b/quiche/quic/core/frames/quic_crypto_frame.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_crypto_frame.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicCryptoFrame::QuicCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + QuicPacketLength data_length) + : QuicCryptoFrame(level, offset, nullptr, data_length) {} + +QuicCryptoFrame::QuicCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + absl::string_view data) + : QuicCryptoFrame(level, offset, data.data(), data.length()) {} + +QuicCryptoFrame::QuicCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + const char* data_buffer, + QuicPacketLength data_length) + : level(level), + data_length(data_length), + data_buffer(data_buffer), + offset(offset) {} + +QuicCryptoFrame::~QuicCryptoFrame() {} + +std::ostream& operator<<(std::ostream& os, + const QuicCryptoFrame& stream_frame) { + os << "{ level: " << stream_frame.level << ", offset: " << stream_frame.offset + << ", length: " << stream_frame.data_length << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_crypto_frame.h b/quiche/quic/core/frames/quic_crypto_frame.h new file mode 100644 index 000000000000..19fb5793610f --- /dev/null +++ b/quiche/quic/core/frames/quic_crypto_frame.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_CRYPTO_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_CRYPTO_FRAME_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicCryptoFrame { + QuicCryptoFrame() = default; + QuicCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + QuicPacketLength data_length); + QuicCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + absl::string_view data); + ~QuicCryptoFrame(); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const QuicCryptoFrame& s); + + // TODO(haoyuewang) Consider replace the EncryptionLevel here with + // PacketNumberSpace. + // When writing a crypto frame to a packet, the packet must be encrypted at + // |level|. When a crypto frame is read, the encryption level of the packet it + // was received in is put in |level|. + EncryptionLevel level = ENCRYPTION_INITIAL; + QuicPacketLength data_length = 0; + // When reading, |data_buffer| points to the data that was received in the + // frame. |data_buffer| is not used when writing. + const char* data_buffer = nullptr; + QuicStreamOffset offset = 0; // Location of this data in the stream. + + QuicCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + const char* data_buffer, QuicPacketLength data_length); +}; +static_assert(sizeof(QuicCryptoFrame) <= 64, + "Keep the QuicCryptoFrame size to a cacheline."); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_CRYPTO_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_frame.cc b/quiche/quic/core/frames/quic_frame.cc new file mode 100644 index 000000000000..2ed05509fc50 --- /dev/null +++ b/quiche/quic/core/frames/quic_frame.cc @@ -0,0 +1,531 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_frame.h" + +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quic { + +QuicFrame::QuicFrame() {} + +QuicFrame::QuicFrame(QuicPaddingFrame padding_frame) + : padding_frame(padding_frame) {} + +QuicFrame::QuicFrame(QuicStreamFrame stream_frame) + : stream_frame(stream_frame) {} + +QuicFrame::QuicFrame(QuicHandshakeDoneFrame handshake_done_frame) + : handshake_done_frame(handshake_done_frame) {} + +QuicFrame::QuicFrame(QuicCryptoFrame* crypto_frame) + : type(CRYPTO_FRAME), crypto_frame(crypto_frame) {} + +QuicFrame::QuicFrame(QuicAckFrame* frame) : type(ACK_FRAME), ack_frame(frame) {} + +QuicFrame::QuicFrame(QuicMtuDiscoveryFrame frame) + : mtu_discovery_frame(frame) {} + +QuicFrame::QuicFrame(QuicStopWaitingFrame frame) : stop_waiting_frame(frame) {} + +QuicFrame::QuicFrame(QuicPingFrame frame) : ping_frame(frame) {} + +QuicFrame::QuicFrame(QuicRstStreamFrame* frame) + : type(RST_STREAM_FRAME), rst_stream_frame(frame) {} + +QuicFrame::QuicFrame(QuicConnectionCloseFrame* frame) + : type(CONNECTION_CLOSE_FRAME), connection_close_frame(frame) {} + +QuicFrame::QuicFrame(QuicGoAwayFrame* frame) + : type(GOAWAY_FRAME), goaway_frame(frame) {} + +QuicFrame::QuicFrame(QuicWindowUpdateFrame frame) + : window_update_frame(frame) {} + +QuicFrame::QuicFrame(QuicBlockedFrame frame) : blocked_frame(frame) {} + +QuicFrame::QuicFrame(QuicNewConnectionIdFrame* frame) + : type(NEW_CONNECTION_ID_FRAME), new_connection_id_frame(frame) {} + +QuicFrame::QuicFrame(QuicRetireConnectionIdFrame* frame) + : type(RETIRE_CONNECTION_ID_FRAME), retire_connection_id_frame(frame) {} + +QuicFrame::QuicFrame(QuicMaxStreamsFrame frame) : max_streams_frame(frame) {} + +QuicFrame::QuicFrame(QuicStreamsBlockedFrame frame) + : streams_blocked_frame(frame) {} + +QuicFrame::QuicFrame(QuicPathResponseFrame frame) + : path_response_frame(frame) {} + +QuicFrame::QuicFrame(QuicPathChallengeFrame frame) + : path_challenge_frame(frame) {} + +QuicFrame::QuicFrame(QuicStopSendingFrame frame) : stop_sending_frame(frame) {} + +QuicFrame::QuicFrame(QuicMessageFrame* frame) + : type(MESSAGE_FRAME), message_frame(frame) {} + +QuicFrame::QuicFrame(QuicNewTokenFrame* frame) + : type(NEW_TOKEN_FRAME), new_token_frame(frame) {} + +QuicFrame::QuicFrame(QuicAckFrequencyFrame* frame) + : type(ACK_FREQUENCY_FRAME), ack_frequency_frame(frame) {} + +void DeleteFrames(QuicFrames* frames) { + for (QuicFrame& frame : *frames) { + DeleteFrame(&frame); + } + frames->clear(); +} + +void DeleteFrame(QuicFrame* frame) { +#if QUIC_FRAME_DEBUG + // If the frame is not inlined, check that it can be safely deleted. + if (frame->type != PADDING_FRAME && frame->type != MTU_DISCOVERY_FRAME && + frame->type != PING_FRAME && frame->type != MAX_STREAMS_FRAME && + frame->type != STOP_WAITING_FRAME && + frame->type != STREAMS_BLOCKED_FRAME && frame->type != STREAM_FRAME && + frame->type != HANDSHAKE_DONE_FRAME && + frame->type != WINDOW_UPDATE_FRAME && frame->type != BLOCKED_FRAME && + frame->type != STOP_SENDING_FRAME && + frame->type != PATH_CHALLENGE_FRAME && + frame->type != PATH_RESPONSE_FRAME) { + QUICHE_CHECK(!frame->delete_forbidden) << *frame; + } +#endif // QUIC_FRAME_DEBUG + switch (frame->type) { + // Frames smaller than a pointer are inlined, so don't need to be deleted. + case PADDING_FRAME: + case MTU_DISCOVERY_FRAME: + case PING_FRAME: + case MAX_STREAMS_FRAME: + case STOP_WAITING_FRAME: + case STREAMS_BLOCKED_FRAME: + case STREAM_FRAME: + case HANDSHAKE_DONE_FRAME: + case WINDOW_UPDATE_FRAME: + case BLOCKED_FRAME: + case STOP_SENDING_FRAME: + case PATH_CHALLENGE_FRAME: + case PATH_RESPONSE_FRAME: + break; + case ACK_FRAME: + delete frame->ack_frame; + break; + case RST_STREAM_FRAME: + delete frame->rst_stream_frame; + break; + case CONNECTION_CLOSE_FRAME: + delete frame->connection_close_frame; + break; + case GOAWAY_FRAME: + delete frame->goaway_frame; + break; + case NEW_CONNECTION_ID_FRAME: + delete frame->new_connection_id_frame; + break; + case RETIRE_CONNECTION_ID_FRAME: + delete frame->retire_connection_id_frame; + break; + case MESSAGE_FRAME: + delete frame->message_frame; + break; + case CRYPTO_FRAME: + delete frame->crypto_frame; + break; + case NEW_TOKEN_FRAME: + delete frame->new_token_frame; + break; + case ACK_FREQUENCY_FRAME: + delete frame->ack_frequency_frame; + break; + case NUM_FRAME_TYPES: + QUICHE_DCHECK(false) << "Cannot delete type: " << frame->type; + } +} + +void RemoveFramesForStream(QuicFrames* frames, QuicStreamId stream_id) { + auto it = frames->begin(); + while (it != frames->end()) { + if (it->type != STREAM_FRAME || it->stream_frame.stream_id != stream_id) { + ++it; + continue; + } + it = frames->erase(it); + } +} + +bool IsControlFrame(QuicFrameType type) { + switch (type) { + case RST_STREAM_FRAME: + case GOAWAY_FRAME: + case WINDOW_UPDATE_FRAME: + case BLOCKED_FRAME: + case STREAMS_BLOCKED_FRAME: + case MAX_STREAMS_FRAME: + case PING_FRAME: + case STOP_SENDING_FRAME: + case NEW_CONNECTION_ID_FRAME: + case RETIRE_CONNECTION_ID_FRAME: + case HANDSHAKE_DONE_FRAME: + case ACK_FREQUENCY_FRAME: + case NEW_TOKEN_FRAME: + return true; + default: + return false; + } +} + +QuicControlFrameId GetControlFrameId(const QuicFrame& frame) { + switch (frame.type) { + case RST_STREAM_FRAME: + return frame.rst_stream_frame->control_frame_id; + case GOAWAY_FRAME: + return frame.goaway_frame->control_frame_id; + case WINDOW_UPDATE_FRAME: + return frame.window_update_frame.control_frame_id; + case BLOCKED_FRAME: + return frame.blocked_frame.control_frame_id; + case STREAMS_BLOCKED_FRAME: + return frame.streams_blocked_frame.control_frame_id; + case MAX_STREAMS_FRAME: + return frame.max_streams_frame.control_frame_id; + case PING_FRAME: + return frame.ping_frame.control_frame_id; + case STOP_SENDING_FRAME: + return frame.stop_sending_frame.control_frame_id; + case NEW_CONNECTION_ID_FRAME: + return frame.new_connection_id_frame->control_frame_id; + case RETIRE_CONNECTION_ID_FRAME: + return frame.retire_connection_id_frame->control_frame_id; + case HANDSHAKE_DONE_FRAME: + return frame.handshake_done_frame.control_frame_id; + case ACK_FREQUENCY_FRAME: + return frame.ack_frequency_frame->control_frame_id; + case NEW_TOKEN_FRAME: + return frame.new_token_frame->control_frame_id; + default: + return kInvalidControlFrameId; + } +} + +void SetControlFrameId(QuicControlFrameId control_frame_id, QuicFrame* frame) { + switch (frame->type) { + case RST_STREAM_FRAME: + frame->rst_stream_frame->control_frame_id = control_frame_id; + return; + case GOAWAY_FRAME: + frame->goaway_frame->control_frame_id = control_frame_id; + return; + case WINDOW_UPDATE_FRAME: + frame->window_update_frame.control_frame_id = control_frame_id; + return; + case BLOCKED_FRAME: + frame->blocked_frame.control_frame_id = control_frame_id; + return; + case PING_FRAME: + frame->ping_frame.control_frame_id = control_frame_id; + return; + case STREAMS_BLOCKED_FRAME: + frame->streams_blocked_frame.control_frame_id = control_frame_id; + return; + case MAX_STREAMS_FRAME: + frame->max_streams_frame.control_frame_id = control_frame_id; + return; + case STOP_SENDING_FRAME: + frame->stop_sending_frame.control_frame_id = control_frame_id; + return; + case NEW_CONNECTION_ID_FRAME: + frame->new_connection_id_frame->control_frame_id = control_frame_id; + return; + case RETIRE_CONNECTION_ID_FRAME: + frame->retire_connection_id_frame->control_frame_id = control_frame_id; + return; + case HANDSHAKE_DONE_FRAME: + frame->handshake_done_frame.control_frame_id = control_frame_id; + return; + case ACK_FREQUENCY_FRAME: + frame->ack_frequency_frame->control_frame_id = control_frame_id; + return; + case NEW_TOKEN_FRAME: + frame->new_token_frame->control_frame_id = control_frame_id; + return; + default: + QUIC_BUG(quic_bug_12594_1) + << "Try to set control frame id of a frame without control frame id"; + } +} + +QuicFrame CopyRetransmittableControlFrame(const QuicFrame& frame) { + QuicFrame copy; + switch (frame.type) { + case RST_STREAM_FRAME: + copy = QuicFrame(new QuicRstStreamFrame(*frame.rst_stream_frame)); + break; + case GOAWAY_FRAME: + copy = QuicFrame(new QuicGoAwayFrame(*frame.goaway_frame)); + break; + case WINDOW_UPDATE_FRAME: + copy = QuicFrame(QuicWindowUpdateFrame(frame.window_update_frame)); + break; + case BLOCKED_FRAME: + copy = QuicFrame(QuicBlockedFrame(frame.blocked_frame)); + break; + case PING_FRAME: + copy = QuicFrame(QuicPingFrame(frame.ping_frame.control_frame_id)); + break; + case STOP_SENDING_FRAME: + copy = QuicFrame(QuicStopSendingFrame(frame.stop_sending_frame)); + break; + case NEW_CONNECTION_ID_FRAME: + copy = QuicFrame( + new QuicNewConnectionIdFrame(*frame.new_connection_id_frame)); + break; + case RETIRE_CONNECTION_ID_FRAME: + copy = QuicFrame( + new QuicRetireConnectionIdFrame(*frame.retire_connection_id_frame)); + break; + case STREAMS_BLOCKED_FRAME: + copy = QuicFrame(QuicStreamsBlockedFrame(frame.streams_blocked_frame)); + break; + case MAX_STREAMS_FRAME: + copy = QuicFrame(QuicMaxStreamsFrame(frame.max_streams_frame)); + break; + case HANDSHAKE_DONE_FRAME: + copy = QuicFrame( + QuicHandshakeDoneFrame(frame.handshake_done_frame.control_frame_id)); + break; + case ACK_FREQUENCY_FRAME: + copy = QuicFrame(new QuicAckFrequencyFrame(*frame.ack_frequency_frame)); + break; + case NEW_TOKEN_FRAME: + copy = QuicFrame(new QuicNewTokenFrame(*frame.new_token_frame)); + break; + default: + QUIC_BUG(quic_bug_10533_1) + << "Try to copy a non-retransmittable control frame: " << frame; + copy = QuicFrame(QuicPingFrame(kInvalidControlFrameId)); + break; + } + return copy; +} + +QuicFrame CopyQuicFrame(quiche::QuicheBufferAllocator* allocator, + const QuicFrame& frame) { + QuicFrame copy; + switch (frame.type) { + case PADDING_FRAME: + copy = QuicFrame(QuicPaddingFrame(frame.padding_frame)); + break; + case RST_STREAM_FRAME: + copy = QuicFrame(new QuicRstStreamFrame(*frame.rst_stream_frame)); + break; + case CONNECTION_CLOSE_FRAME: + copy = QuicFrame( + new QuicConnectionCloseFrame(*frame.connection_close_frame)); + break; + case GOAWAY_FRAME: + copy = QuicFrame(new QuicGoAwayFrame(*frame.goaway_frame)); + break; + case WINDOW_UPDATE_FRAME: + copy = QuicFrame(QuicWindowUpdateFrame(frame.window_update_frame)); + break; + case BLOCKED_FRAME: + copy = QuicFrame(QuicBlockedFrame(frame.blocked_frame)); + break; + case STOP_WAITING_FRAME: + copy = QuicFrame(QuicStopWaitingFrame(frame.stop_waiting_frame)); + break; + case PING_FRAME: + copy = QuicFrame(QuicPingFrame(frame.ping_frame.control_frame_id)); + break; + case CRYPTO_FRAME: + copy = QuicFrame(new QuicCryptoFrame(*frame.crypto_frame)); + break; + case STREAM_FRAME: + copy = QuicFrame(QuicStreamFrame(frame.stream_frame)); + break; + case ACK_FRAME: + copy = QuicFrame(new QuicAckFrame(*frame.ack_frame)); + break; + case MTU_DISCOVERY_FRAME: + copy = QuicFrame(QuicMtuDiscoveryFrame(frame.mtu_discovery_frame)); + break; + case NEW_CONNECTION_ID_FRAME: + copy = QuicFrame( + new QuicNewConnectionIdFrame(*frame.new_connection_id_frame)); + break; + case MAX_STREAMS_FRAME: + copy = QuicFrame(QuicMaxStreamsFrame(frame.max_streams_frame)); + break; + case STREAMS_BLOCKED_FRAME: + copy = QuicFrame(QuicStreamsBlockedFrame(frame.streams_blocked_frame)); + break; + case PATH_RESPONSE_FRAME: + copy = QuicFrame(QuicPathResponseFrame(frame.path_response_frame)); + break; + case PATH_CHALLENGE_FRAME: + copy = QuicFrame(QuicPathChallengeFrame(frame.path_challenge_frame)); + break; + case STOP_SENDING_FRAME: + copy = QuicFrame(QuicStopSendingFrame(frame.stop_sending_frame)); + break; + case MESSAGE_FRAME: + copy = QuicFrame(new QuicMessageFrame(frame.message_frame->message_id)); + copy.message_frame->data = frame.message_frame->data; + copy.message_frame->message_length = frame.message_frame->message_length; + for (const auto& slice : frame.message_frame->message_data) { + quiche::QuicheBuffer buffer = + quiche::QuicheBuffer::Copy(allocator, slice.AsStringView()); + copy.message_frame->message_data.push_back( + quiche::QuicheMemSlice(std::move(buffer))); + } + break; + case NEW_TOKEN_FRAME: + copy = QuicFrame(new QuicNewTokenFrame(*frame.new_token_frame)); + break; + case RETIRE_CONNECTION_ID_FRAME: + copy = QuicFrame( + new QuicRetireConnectionIdFrame(*frame.retire_connection_id_frame)); + break; + case HANDSHAKE_DONE_FRAME: + copy = QuicFrame( + QuicHandshakeDoneFrame(frame.handshake_done_frame.control_frame_id)); + break; + case ACK_FREQUENCY_FRAME: + copy = QuicFrame(new QuicAckFrequencyFrame(*frame.ack_frequency_frame)); + break; + default: + QUIC_BUG(quic_bug_10533_2) << "Cannot copy frame: " << frame; + copy = QuicFrame(QuicPingFrame(kInvalidControlFrameId)); + break; + } + return copy; +} + +QuicFrames CopyQuicFrames(quiche::QuicheBufferAllocator* allocator, + const QuicFrames& frames) { + QuicFrames copy; + for (const auto& frame : frames) { + copy.push_back(CopyQuicFrame(allocator, frame)); + } + return copy; +} + +std::ostream& operator<<(std::ostream& os, const QuicFrame& frame) { + switch (frame.type) { + case PADDING_FRAME: { + os << "type { PADDING_FRAME } " << frame.padding_frame; + break; + } + case RST_STREAM_FRAME: { + os << "type { RST_STREAM_FRAME } " << *(frame.rst_stream_frame); + break; + } + case CONNECTION_CLOSE_FRAME: { + os << "type { CONNECTION_CLOSE_FRAME } " + << *(frame.connection_close_frame); + break; + } + case GOAWAY_FRAME: { + os << "type { GOAWAY_FRAME } " << *(frame.goaway_frame); + break; + } + case WINDOW_UPDATE_FRAME: { + os << "type { WINDOW_UPDATE_FRAME } " << frame.window_update_frame; + break; + } + case BLOCKED_FRAME: { + os << "type { BLOCKED_FRAME } " << frame.blocked_frame; + break; + } + case STREAM_FRAME: { + os << "type { STREAM_FRAME } " << frame.stream_frame; + break; + } + case ACK_FRAME: { + os << "type { ACK_FRAME } " << *(frame.ack_frame); + break; + } + case STOP_WAITING_FRAME: { + os << "type { STOP_WAITING_FRAME } " << frame.stop_waiting_frame; + break; + } + case PING_FRAME: { + os << "type { PING_FRAME } " << frame.ping_frame; + break; + } + case CRYPTO_FRAME: { + os << "type { CRYPTO_FRAME } " << *(frame.crypto_frame); + break; + } + case MTU_DISCOVERY_FRAME: { + os << "type { MTU_DISCOVERY_FRAME } "; + break; + } + case NEW_CONNECTION_ID_FRAME: + os << "type { NEW_CONNECTION_ID } " << *(frame.new_connection_id_frame); + break; + case RETIRE_CONNECTION_ID_FRAME: + os << "type { RETIRE_CONNECTION_ID } " + << *(frame.retire_connection_id_frame); + break; + case MAX_STREAMS_FRAME: + os << "type { MAX_STREAMS } " << frame.max_streams_frame; + break; + case STREAMS_BLOCKED_FRAME: + os << "type { STREAMS_BLOCKED } " << frame.streams_blocked_frame; + break; + case PATH_RESPONSE_FRAME: + os << "type { PATH_RESPONSE } " << frame.path_response_frame; + break; + case PATH_CHALLENGE_FRAME: + os << "type { PATH_CHALLENGE } " << frame.path_challenge_frame; + break; + case STOP_SENDING_FRAME: + os << "type { STOP_SENDING } " << frame.stop_sending_frame; + break; + case MESSAGE_FRAME: + os << "type { MESSAGE_FRAME }" << *(frame.message_frame); + break; + case NEW_TOKEN_FRAME: + os << "type { NEW_TOKEN_FRAME }" << *(frame.new_token_frame); + break; + case HANDSHAKE_DONE_FRAME: + os << "type { HANDSHAKE_DONE_FRAME } " << frame.handshake_done_frame; + break; + case ACK_FREQUENCY_FRAME: + os << "type { ACK_FREQUENCY_FRAME } " << *(frame.ack_frequency_frame); + break; + default: { + QUIC_LOG(ERROR) << "Unknown frame type: " << frame.type; + break; + } + } + return os; +} + +QUIC_EXPORT_PRIVATE std::string QuicFrameToString(const QuicFrame& frame) { + std::ostringstream os; + os << frame; + return os.str(); +} + +std::string QuicFramesToString(const QuicFrames& frames) { + std::ostringstream os; + for (const QuicFrame& frame : frames) { + os << frame; + } + return os.str(); +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_frame.h b/quiche/quic/core/frames/quic_frame.h new file mode 100644 index 000000000000..3b7196ec6f7c --- /dev/null +++ b/quiche/quic/core/frames/quic_frame.h @@ -0,0 +1,174 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_FRAME_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "quiche/quic/core/frames/quic_ack_frame.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/frames/quic_blocked_frame.h" +#include "quiche/quic/core/frames/quic_connection_close_frame.h" +#include "quiche/quic/core/frames/quic_crypto_frame.h" +#include "quiche/quic/core/frames/quic_goaway_frame.h" +#include "quiche/quic/core/frames/quic_handshake_done_frame.h" +#include "quiche/quic/core/frames/quic_max_streams_frame.h" +#include "quiche/quic/core/frames/quic_message_frame.h" +#include "quiche/quic/core/frames/quic_mtu_discovery_frame.h" +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_new_token_frame.h" +#include "quiche/quic/core/frames/quic_padding_frame.h" +#include "quiche/quic/core/frames/quic_path_challenge_frame.h" +#include "quiche/quic/core/frames/quic_path_response_frame.h" +#include "quiche/quic/core/frames/quic_ping_frame.h" +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_rst_stream_frame.h" +#include "quiche/quic/core/frames/quic_stop_sending_frame.h" +#include "quiche/quic/core/frames/quic_stop_waiting_frame.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/frames/quic_streams_blocked_frame.h" +#include "quiche/quic/core/frames/quic_window_update_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +#ifndef QUIC_FRAME_DEBUG +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) +#define QUIC_FRAME_DEBUG 1 +#else // !defined(NDEBUG) || defined(ADDRESS_SANITIZER) +#define QUIC_FRAME_DEBUG 0 +#endif // !defined(NDEBUG) || defined(ADDRESS_SANITIZER) +#endif // QUIC_FRAME_DEBUG + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicFrame { + QuicFrame(); + // Please keep the constructors in the same order as the union below. + explicit QuicFrame(QuicPaddingFrame padding_frame); + explicit QuicFrame(QuicMtuDiscoveryFrame frame); + explicit QuicFrame(QuicPingFrame frame); + explicit QuicFrame(QuicMaxStreamsFrame frame); + explicit QuicFrame(QuicStopWaitingFrame frame); + explicit QuicFrame(QuicStreamsBlockedFrame frame); + explicit QuicFrame(QuicStreamFrame stream_frame); + explicit QuicFrame(QuicHandshakeDoneFrame handshake_done_frame); + explicit QuicFrame(QuicWindowUpdateFrame frame); + explicit QuicFrame(QuicBlockedFrame frame); + explicit QuicFrame(QuicStopSendingFrame frame); + explicit QuicFrame(QuicPathChallengeFrame frame); + explicit QuicFrame(QuicPathResponseFrame frame); + + explicit QuicFrame(QuicAckFrame* frame); + explicit QuicFrame(QuicRstStreamFrame* frame); + explicit QuicFrame(QuicConnectionCloseFrame* frame); + explicit QuicFrame(QuicGoAwayFrame* frame); + explicit QuicFrame(QuicNewConnectionIdFrame* frame); + explicit QuicFrame(QuicRetireConnectionIdFrame* frame); + explicit QuicFrame(QuicNewTokenFrame* frame); + explicit QuicFrame(QuicMessageFrame* message_frame); + explicit QuicFrame(QuicCryptoFrame* crypto_frame); + explicit QuicFrame(QuicAckFrequencyFrame* ack_frequency_frame); + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<(std::ostream& os, + const QuicFrame& frame); + + union { + // Inlined frames. + // Overlapping inlined frames have a |type| field at the same 0 offset as + // QuicFrame does for out of line frames below, allowing use of the + // remaining 7 bytes after offset for frame-type specific fields. + QuicPaddingFrame padding_frame; + QuicMtuDiscoveryFrame mtu_discovery_frame; + QuicPingFrame ping_frame; + QuicMaxStreamsFrame max_streams_frame; + QuicStopWaitingFrame stop_waiting_frame; + QuicStreamsBlockedFrame streams_blocked_frame; + QuicStreamFrame stream_frame; + QuicHandshakeDoneFrame handshake_done_frame; + QuicWindowUpdateFrame window_update_frame; + QuicBlockedFrame blocked_frame; + QuicStopSendingFrame stop_sending_frame; + QuicPathChallengeFrame path_challenge_frame; + QuicPathResponseFrame path_response_frame; + + // Out of line frames. + struct { + QuicFrameType type; + +#if QUIC_FRAME_DEBUG + bool delete_forbidden = false; +#endif // QUIC_FRAME_DEBUG + + union { + QuicAckFrame* ack_frame; + QuicRstStreamFrame* rst_stream_frame; + QuicConnectionCloseFrame* connection_close_frame; + QuicGoAwayFrame* goaway_frame; + QuicNewConnectionIdFrame* new_connection_id_frame; + QuicRetireConnectionIdFrame* retire_connection_id_frame; + QuicMessageFrame* message_frame; + QuicCryptoFrame* crypto_frame; + QuicAckFrequencyFrame* ack_frequency_frame; + QuicNewTokenFrame* new_token_frame; + }; + }; + }; +}; + +static_assert(std::is_standard_layout::value, + "QuicFrame must have a standard layout"); +static_assert(sizeof(QuicFrame) <= 24, + "Frames larger than 24 bytes should be referenced by pointer."); +static_assert(offsetof(QuicStreamFrame, type) == offsetof(QuicFrame, type), + "Offset of |type| must match in QuicFrame and QuicStreamFrame"); + +// A inline size of 1 is chosen to optimize the typical use case of +// 1-stream-frame in QuicTransmissionInfo.retransmittable_frames. +using QuicFrames = absl::InlinedVector; + +// Deletes all the sub-frames contained in |frames|. +QUIC_EXPORT_PRIVATE void DeleteFrames(QuicFrames* frames); + +// Delete the sub-frame contained in |frame|. +QUIC_EXPORT_PRIVATE void DeleteFrame(QuicFrame* frame); + +// Deletes all the QuicStreamFrames for the specified |stream_id|. +QUIC_EXPORT_PRIVATE void RemoveFramesForStream(QuicFrames* frames, + QuicStreamId stream_id); + +// Returns true if |type| is a retransmittable control frame. +QUIC_EXPORT_PRIVATE bool IsControlFrame(QuicFrameType type); + +// Returns control_frame_id of |frame|. Returns kInvalidControlFrameId if +// |frame| does not have a valid control_frame_id. +QUIC_EXPORT_PRIVATE QuicControlFrameId +GetControlFrameId(const QuicFrame& frame); + +// Sets control_frame_id of |frame| to |control_frame_id|. +QUIC_EXPORT_PRIVATE void SetControlFrameId(QuicControlFrameId control_frame_id, + QuicFrame* frame); + +// Returns a copy of |frame|. +QUIC_EXPORT_PRIVATE QuicFrame +CopyRetransmittableControlFrame(const QuicFrame& frame); + +// Returns a copy of |frame|. +QUIC_EXPORT_PRIVATE QuicFrame +CopyQuicFrame(quiche::QuicheBufferAllocator* allocator, const QuicFrame& frame); + +// Returns a copy of |frames|. +QUIC_EXPORT_PRIVATE QuicFrames CopyQuicFrames( + quiche::QuicheBufferAllocator* allocator, const QuicFrames& frames); + +// Human-readable description suitable for logging. +QUIC_EXPORT_PRIVATE std::string QuicFrameToString(const QuicFrame& frame); +QUIC_EXPORT_PRIVATE std::string QuicFramesToString(const QuicFrames& frames); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_frames_test.cc b/quiche/quic/core/frames/quic_frames_test.cc new file mode 100644 index 000000000000..671e7724014d --- /dev/null +++ b/quiche/quic/core/frames/quic_frames_test.cc @@ -0,0 +1,846 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_ack_frame.h" +#include "quiche/quic/core/frames/quic_blocked_frame.h" +#include "quiche/quic/core/frames/quic_connection_close_frame.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/frames/quic_goaway_frame.h" +#include "quiche/quic/core/frames/quic_mtu_discovery_frame.h" +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_padding_frame.h" +#include "quiche/quic/core/frames/quic_ping_frame.h" +#include "quiche/quic/core/frames/quic_rst_stream_frame.h" +#include "quiche/quic/core/frames/quic_stop_waiting_frame.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/frames/quic_window_update_frame.h" +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +class QuicFramesTest : public QuicTest {}; + +TEST_F(QuicFramesTest, AckFrameToString) { + QuicAckFrame frame; + frame.largest_acked = QuicPacketNumber(5); + frame.ack_delay_time = QuicTime::Delta::FromMicroseconds(3); + frame.packets.Add(QuicPacketNumber(4)); + frame.packets.Add(QuicPacketNumber(5)); + frame.received_packet_times = { + {QuicPacketNumber(6), + QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(7)}}; + std::ostringstream stream; + stream << frame; + EXPECT_EQ( + "{ largest_acked: 5, ack_delay_time: 3, packets: [ 4 5 ], " + "received_packets: [ 6 at 7 ], ecn_counters_populated: 0 }\n", + stream.str()); + QuicFrame quic_frame(&frame); + EXPECT_FALSE(IsControlFrame(quic_frame.type)); +} + +TEST_F(QuicFramesTest, BigAckFrameToString) { + QuicAckFrame frame; + frame.largest_acked = QuicPacketNumber(500); + frame.ack_delay_time = QuicTime::Delta::FromMicroseconds(3); + frame.packets.AddRange(QuicPacketNumber(4), QuicPacketNumber(501)); + frame.received_packet_times = { + {QuicPacketNumber(500), + QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(7)}}; + std::ostringstream stream; + stream << frame; + EXPECT_EQ( + "{ largest_acked: 500, ack_delay_time: 3, packets: [ 4...500 ], " + "received_packets: [ 500 at 7 ], ecn_counters_populated: 0 }\n", + stream.str()); + QuicFrame quic_frame(&frame); + EXPECT_FALSE(IsControlFrame(quic_frame.type)); +} + +TEST_F(QuicFramesTest, PaddingFrameToString) { + QuicPaddingFrame frame; + frame.num_padding_bytes = 1; + std::ostringstream stream; + stream << frame; + EXPECT_EQ("{ num_padding_bytes: 1 }\n", stream.str()); + QuicFrame quic_frame(frame); + EXPECT_FALSE(IsControlFrame(quic_frame.type)); +} + +TEST_F(QuicFramesTest, RstStreamFrameToString) { + QuicRstStreamFrame rst_stream; + QuicFrame frame(&rst_stream); + SetControlFrameId(1, &frame); + EXPECT_EQ(1u, GetControlFrameId(frame)); + rst_stream.stream_id = 1; + rst_stream.byte_offset = 3; + rst_stream.error_code = QUIC_STREAM_CANCELLED; + std::ostringstream stream; + stream << rst_stream; + EXPECT_EQ( + "{ control_frame_id: 1, stream_id: 1, byte_offset: 3, error_code: 6, " + "ietf_error_code: 0 }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, StopSendingFrameToString) { + QuicFrame frame((QuicStopSendingFrame())); + SetControlFrameId(1, &frame); + EXPECT_EQ(1u, GetControlFrameId(frame)); + frame.stop_sending_frame.stream_id = 321; + frame.stop_sending_frame.error_code = QUIC_STREAM_CANCELLED; + frame.stop_sending_frame.ietf_error_code = + static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED); + std::ostringstream stream; + stream << frame.stop_sending_frame; + EXPECT_EQ( + "{ control_frame_id: 1, stream_id: 321, error_code: 6, ietf_error_code: " + "268 }\n", + stream.str()); +} + +TEST_F(QuicFramesTest, NewConnectionIdFrameToString) { + QuicNewConnectionIdFrame new_connection_id_frame; + QuicFrame frame(&new_connection_id_frame); + SetControlFrameId(1, &frame); + QuicFrame frame_copy = CopyRetransmittableControlFrame(frame); + EXPECT_EQ(1u, GetControlFrameId(frame_copy)); + new_connection_id_frame.connection_id = TestConnectionId(2); + new_connection_id_frame.sequence_number = 2u; + new_connection_id_frame.retire_prior_to = 1u; + new_connection_id_frame.stateless_reset_token = + StatelessResetToken{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + std::ostringstream stream; + stream << new_connection_id_frame; + EXPECT_EQ( + "{ control_frame_id: 1, connection_id: 0000000000000002, " + "sequence_number: 2, retire_prior_to: 1 }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame_copy.type)); + DeleteFrame(&frame_copy); +} + +TEST_F(QuicFramesTest, RetireConnectionIdFrameToString) { + QuicRetireConnectionIdFrame retire_connection_id_frame; + QuicFrame frame(&retire_connection_id_frame); + SetControlFrameId(1, &frame); + QuicFrame frame_copy = CopyRetransmittableControlFrame(frame); + EXPECT_EQ(1u, GetControlFrameId(frame_copy)); + retire_connection_id_frame.sequence_number = 1u; + std::ostringstream stream; + stream << retire_connection_id_frame; + EXPECT_EQ("{ control_frame_id: 1, sequence_number: 1 }\n", stream.str()); + EXPECT_TRUE(IsControlFrame(frame_copy.type)); + DeleteFrame(&frame_copy); +} + +TEST_F(QuicFramesTest, StreamsBlockedFrameToString) { + QuicStreamsBlockedFrame streams_blocked; + QuicFrame frame(streams_blocked); + SetControlFrameId(1, &frame); + EXPECT_EQ(1u, GetControlFrameId(frame)); + // QuicStreamsBlocked is copied into a QuicFrame (as opposed to putting a + // pointer to it into QuicFrame) so need to work with the copy in |frame| and + // not the original one, streams_blocked. + frame.streams_blocked_frame.stream_count = 321; + frame.streams_blocked_frame.unidirectional = false; + std::ostringstream stream; + stream << frame.streams_blocked_frame; + EXPECT_EQ("{ control_frame_id: 1, stream count: 321, bidirectional }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, MaxStreamsFrameToString) { + QuicMaxStreamsFrame max_streams; + QuicFrame frame(max_streams); + SetControlFrameId(1, &frame); + EXPECT_EQ(1u, GetControlFrameId(frame)); + // QuicMaxStreams is copied into a QuicFrame (as opposed to putting a + // pointer to it into QuicFrame) so need to work with the copy in |frame| and + // not the original one, max_streams. + frame.max_streams_frame.stream_count = 321; + frame.max_streams_frame.unidirectional = true; + std::ostringstream stream; + stream << frame.max_streams_frame; + EXPECT_EQ("{ control_frame_id: 1, stream_count: 321, unidirectional }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, ConnectionCloseFrameToString) { + QuicConnectionCloseFrame frame; + frame.quic_error_code = QUIC_NETWORK_IDLE_TIMEOUT; + frame.error_details = "No recent network activity."; + std::ostringstream stream; + stream << frame; + // Note that "extracted_error_code: 122" is QUIC_IETF_GQUIC_ERROR_MISSING, + // indicating that, in fact, no extended error code was available from the + // underlying frame. + EXPECT_EQ( + "{ Close type: GOOGLE_QUIC_CONNECTION_CLOSE, " + "quic_error_code: QUIC_NETWORK_IDLE_TIMEOUT, " + "error_details: 'No recent network activity.'}\n", + stream.str()); + QuicFrame quic_frame(&frame); + EXPECT_FALSE(IsControlFrame(quic_frame.type)); +} + +TEST_F(QuicFramesTest, TransportConnectionCloseFrameToString) { + QuicConnectionCloseFrame frame; + frame.close_type = IETF_QUIC_TRANSPORT_CONNECTION_CLOSE; + frame.wire_error_code = FINAL_SIZE_ERROR; + frame.quic_error_code = QUIC_NETWORK_IDLE_TIMEOUT; + frame.error_details = "No recent network activity."; + frame.transport_close_frame_type = IETF_STREAM; + std::ostringstream stream; + stream << frame; + EXPECT_EQ( + "{ Close type: IETF_QUIC_TRANSPORT_CONNECTION_CLOSE, " + "wire_error_code: FINAL_SIZE_ERROR, " + "quic_error_code: QUIC_NETWORK_IDLE_TIMEOUT, " + "error_details: 'No recent " + "network activity.', " + "frame_type: IETF_STREAM" + "}\n", + stream.str()); + QuicFrame quic_frame(&frame); + EXPECT_FALSE(IsControlFrame(quic_frame.type)); +} + +TEST_F(QuicFramesTest, GoAwayFrameToString) { + QuicGoAwayFrame goaway_frame; + QuicFrame frame(&goaway_frame); + SetControlFrameId(2, &frame); + EXPECT_EQ(2u, GetControlFrameId(frame)); + goaway_frame.error_code = QUIC_NETWORK_IDLE_TIMEOUT; + goaway_frame.last_good_stream_id = 2; + goaway_frame.reason_phrase = "Reason"; + std::ostringstream stream; + stream << goaway_frame; + EXPECT_EQ( + "{ control_frame_id: 2, error_code: 25, last_good_stream_id: 2, " + "reason_phrase: " + "'Reason' }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, WindowUpdateFrameToString) { + QuicFrame frame((QuicWindowUpdateFrame())); + SetControlFrameId(3, &frame); + EXPECT_EQ(3u, GetControlFrameId(frame)); + std::ostringstream stream; + frame.window_update_frame.stream_id = 1; + frame.window_update_frame.max_data = 2; + stream << frame.window_update_frame; + EXPECT_EQ("{ control_frame_id: 3, stream_id: 1, max_data: 2 }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, BlockedFrameToString) { + QuicFrame frame((QuicBlockedFrame())); + SetControlFrameId(4, &frame); + EXPECT_EQ(4u, GetControlFrameId(frame)); + frame.blocked_frame.stream_id = 1; + frame.blocked_frame.offset = 2; + std::ostringstream stream; + stream << frame.blocked_frame; + EXPECT_EQ("{ control_frame_id: 4, stream_id: 1, offset: 2 }\n", stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, PingFrameToString) { + QuicPingFrame ping; + QuicFrame frame(ping); + SetControlFrameId(5, &frame); + EXPECT_EQ(5u, GetControlFrameId(frame)); + std::ostringstream stream; + stream << frame.ping_frame; + EXPECT_EQ("{ control_frame_id: 5 }\n", stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, HandshakeDoneFrameToString) { + QuicHandshakeDoneFrame handshake_done; + QuicFrame frame(handshake_done); + SetControlFrameId(6, &frame); + EXPECT_EQ(6u, GetControlFrameId(frame)); + std::ostringstream stream; + stream << frame.handshake_done_frame; + EXPECT_EQ("{ control_frame_id: 6 }\n", stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, QuicAckFreuqncyFrameToString) { + QuicAckFrequencyFrame ack_frequency_frame; + ack_frequency_frame.sequence_number = 1; + ack_frequency_frame.packet_tolerance = 2; + ack_frequency_frame.max_ack_delay = QuicTime::Delta::FromMilliseconds(25); + ack_frequency_frame.ignore_order = false; + QuicFrame frame(&ack_frequency_frame); + ASSERT_EQ(ACK_FREQUENCY_FRAME, frame.type); + SetControlFrameId(6, &frame); + EXPECT_EQ(6u, GetControlFrameId(frame)); + std::ostringstream stream; + stream << *frame.ack_frequency_frame; + EXPECT_EQ( + "{ control_frame_id: 6, sequence_number: 1, packet_tolerance: 2, " + "max_ack_delay_ms: 25, ignore_order: 0 }\n", + stream.str()); + EXPECT_TRUE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, StreamFrameToString) { + QuicStreamFrame frame; + frame.stream_id = 1; + frame.fin = false; + frame.offset = 2; + frame.data_length = 3; + std::ostringstream stream; + stream << frame; + EXPECT_EQ("{ stream_id: 1, fin: 0, offset: 2, length: 3 }\n", stream.str()); + EXPECT_FALSE(IsControlFrame(frame.type)); +} + +TEST_F(QuicFramesTest, StopWaitingFrameToString) { + QuicStopWaitingFrame frame; + frame.least_unacked = QuicPacketNumber(2); + std::ostringstream stream; + stream << frame; + EXPECT_EQ("{ least_unacked: 2 }\n", stream.str()); + QuicFrame quic_frame(frame); + EXPECT_FALSE(IsControlFrame(quic_frame.type)); +} + +TEST_F(QuicFramesTest, IsAwaitingPacket) { + QuicAckFrame ack_frame1; + ack_frame1.largest_acked = QuicPacketNumber(10u); + ack_frame1.packets.AddRange(QuicPacketNumber(1), QuicPacketNumber(11)); + EXPECT_TRUE( + IsAwaitingPacket(ack_frame1, QuicPacketNumber(11u), QuicPacketNumber())); + EXPECT_FALSE( + IsAwaitingPacket(ack_frame1, QuicPacketNumber(1u), QuicPacketNumber())); + + ack_frame1.packets.Add(QuicPacketNumber(12)); + EXPECT_TRUE( + IsAwaitingPacket(ack_frame1, QuicPacketNumber(11u), QuicPacketNumber())); + + QuicAckFrame ack_frame2; + ack_frame2.largest_acked = QuicPacketNumber(100u); + ack_frame2.packets.AddRange(QuicPacketNumber(21), QuicPacketNumber(100)); + EXPECT_FALSE(IsAwaitingPacket(ack_frame2, QuicPacketNumber(11u), + QuicPacketNumber(20u))); + EXPECT_FALSE(IsAwaitingPacket(ack_frame2, QuicPacketNumber(80u), + QuicPacketNumber(20u))); + EXPECT_TRUE(IsAwaitingPacket(ack_frame2, QuicPacketNumber(101u), + QuicPacketNumber(20u))); + + ack_frame2.packets.AddRange(QuicPacketNumber(102), QuicPacketNumber(200)); + EXPECT_TRUE(IsAwaitingPacket(ack_frame2, QuicPacketNumber(101u), + QuicPacketNumber(20u))); +} + +TEST_F(QuicFramesTest, AddPacket) { + QuicAckFrame ack_frame1; + ack_frame1.packets.Add(QuicPacketNumber(1)); + ack_frame1.packets.Add(QuicPacketNumber(99)); + + EXPECT_EQ(2u, ack_frame1.packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(1u), ack_frame1.packets.Min()); + EXPECT_EQ(QuicPacketNumber(99u), ack_frame1.packets.Max()); + + std::vector> expected_intervals; + expected_intervals.emplace_back( + QuicInterval(QuicPacketNumber(1), QuicPacketNumber(2))); + expected_intervals.emplace_back(QuicInterval( + QuicPacketNumber(99), QuicPacketNumber(100))); + + const std::vector> actual_intervals( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + EXPECT_EQ(expected_intervals, actual_intervals); + + ack_frame1.packets.Add(QuicPacketNumber(20)); + const std::vector> actual_intervals2( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + std::vector> expected_intervals2; + expected_intervals2.emplace_back( + QuicInterval(QuicPacketNumber(1), QuicPacketNumber(2))); + expected_intervals2.emplace_back(QuicInterval( + QuicPacketNumber(20), QuicPacketNumber(21))); + expected_intervals2.emplace_back(QuicInterval( + QuicPacketNumber(99), QuicPacketNumber(100))); + + EXPECT_EQ(3u, ack_frame1.packets.NumIntervals()); + EXPECT_EQ(expected_intervals2, actual_intervals2); + + ack_frame1.packets.Add(QuicPacketNumber(19)); + ack_frame1.packets.Add(QuicPacketNumber(21)); + + const std::vector> actual_intervals3( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + std::vector> expected_intervals3; + expected_intervals3.emplace_back( + QuicInterval(QuicPacketNumber(1), QuicPacketNumber(2))); + expected_intervals3.emplace_back(QuicInterval( + QuicPacketNumber(19), QuicPacketNumber(22))); + expected_intervals3.emplace_back(QuicInterval( + QuicPacketNumber(99), QuicPacketNumber(100))); + + EXPECT_EQ(expected_intervals3, actual_intervals3); + + ack_frame1.packets.Add(QuicPacketNumber(20)); + + const std::vector> actual_intervals4( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + EXPECT_EQ(expected_intervals3, actual_intervals4); + + QuicAckFrame ack_frame2; + ack_frame2.packets.Add(QuicPacketNumber(20)); + ack_frame2.packets.Add(QuicPacketNumber(40)); + ack_frame2.packets.Add(QuicPacketNumber(60)); + ack_frame2.packets.Add(QuicPacketNumber(10)); + ack_frame2.packets.Add(QuicPacketNumber(80)); + + const std::vector> actual_intervals5( + ack_frame2.packets.begin(), ack_frame2.packets.end()); + + std::vector> expected_intervals5; + expected_intervals5.emplace_back(QuicInterval( + QuicPacketNumber(10), QuicPacketNumber(11))); + expected_intervals5.emplace_back(QuicInterval( + QuicPacketNumber(20), QuicPacketNumber(21))); + expected_intervals5.emplace_back(QuicInterval( + QuicPacketNumber(40), QuicPacketNumber(41))); + expected_intervals5.emplace_back(QuicInterval( + QuicPacketNumber(60), QuicPacketNumber(61))); + expected_intervals5.emplace_back(QuicInterval( + QuicPacketNumber(80), QuicPacketNumber(81))); + + EXPECT_EQ(expected_intervals5, actual_intervals5); +} + +TEST_F(QuicFramesTest, AddInterval) { + QuicAckFrame ack_frame1; + ack_frame1.packets.AddRange(QuicPacketNumber(1), QuicPacketNumber(10)); + ack_frame1.packets.AddRange(QuicPacketNumber(50), QuicPacketNumber(100)); + + EXPECT_EQ(2u, ack_frame1.packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(1u), ack_frame1.packets.Min()); + EXPECT_EQ(QuicPacketNumber(99u), ack_frame1.packets.Max()); + + std::vector> expected_intervals{ + {QuicPacketNumber(1), QuicPacketNumber(10)}, + {QuicPacketNumber(50), QuicPacketNumber(100)}, + }; + + const std::vector> actual_intervals( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + EXPECT_EQ(expected_intervals, actual_intervals); + + // Add a range in the middle. + ack_frame1.packets.AddRange(QuicPacketNumber(20), QuicPacketNumber(30)); + + const std::vector> actual_intervals2( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + std::vector> expected_intervals2{ + {QuicPacketNumber(1), QuicPacketNumber(10)}, + {QuicPacketNumber(20), QuicPacketNumber(30)}, + {QuicPacketNumber(50), QuicPacketNumber(100)}, + }; + + EXPECT_EQ(expected_intervals2.size(), ack_frame1.packets.NumIntervals()); + EXPECT_EQ(expected_intervals2, actual_intervals2); + + // Add ranges at both ends. + QuicAckFrame ack_frame2; + ack_frame2.packets.AddRange(QuicPacketNumber(20), QuicPacketNumber(25)); + ack_frame2.packets.AddRange(QuicPacketNumber(40), QuicPacketNumber(45)); + ack_frame2.packets.AddRange(QuicPacketNumber(60), QuicPacketNumber(65)); + ack_frame2.packets.AddRange(QuicPacketNumber(10), QuicPacketNumber(15)); + ack_frame2.packets.AddRange(QuicPacketNumber(80), QuicPacketNumber(85)); + + const std::vector> actual_intervals8( + ack_frame2.packets.begin(), ack_frame2.packets.end()); + + std::vector> expected_intervals8{ + {QuicPacketNumber(10), QuicPacketNumber(15)}, + {QuicPacketNumber(20), QuicPacketNumber(25)}, + {QuicPacketNumber(40), QuicPacketNumber(45)}, + {QuicPacketNumber(60), QuicPacketNumber(65)}, + {QuicPacketNumber(80), QuicPacketNumber(85)}, + }; + + EXPECT_EQ(expected_intervals8, actual_intervals8); +} + +TEST_F(QuicFramesTest, AddAdjacentForward) { + QuicAckFrame ack_frame1; + ack_frame1.packets.Add(QuicPacketNumber(49)); + ack_frame1.packets.AddRange(QuicPacketNumber(50), QuicPacketNumber(60)); + ack_frame1.packets.AddRange(QuicPacketNumber(60), QuicPacketNumber(70)); + ack_frame1.packets.AddRange(QuicPacketNumber(70), QuicPacketNumber(100)); + + std::vector> expected_intervals; + expected_intervals.emplace_back(QuicInterval( + QuicPacketNumber(49), QuicPacketNumber(100))); + + const std::vector> actual_intervals( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + EXPECT_EQ(expected_intervals, actual_intervals); +} + +TEST_F(QuicFramesTest, AddAdjacentReverse) { + QuicAckFrame ack_frame1; + ack_frame1.packets.AddRange(QuicPacketNumber(70), QuicPacketNumber(100)); + ack_frame1.packets.AddRange(QuicPacketNumber(60), QuicPacketNumber(70)); + ack_frame1.packets.AddRange(QuicPacketNumber(50), QuicPacketNumber(60)); + ack_frame1.packets.Add(QuicPacketNumber(49)); + + std::vector> expected_intervals; + expected_intervals.emplace_back(QuicInterval( + QuicPacketNumber(49), QuicPacketNumber(100))); + + const std::vector> actual_intervals( + ack_frame1.packets.begin(), ack_frame1.packets.end()); + + EXPECT_EQ(expected_intervals, actual_intervals); +} + +TEST_F(QuicFramesTest, RemoveSmallestInterval) { + QuicAckFrame ack_frame1; + ack_frame1.largest_acked = QuicPacketNumber(100u); + ack_frame1.packets.AddRange(QuicPacketNumber(51), QuicPacketNumber(60)); + ack_frame1.packets.AddRange(QuicPacketNumber(71), QuicPacketNumber(80)); + ack_frame1.packets.AddRange(QuicPacketNumber(91), QuicPacketNumber(100)); + ack_frame1.packets.RemoveSmallestInterval(); + EXPECT_EQ(2u, ack_frame1.packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(71u), ack_frame1.packets.Min()); + EXPECT_EQ(QuicPacketNumber(99u), ack_frame1.packets.Max()); + + ack_frame1.packets.RemoveSmallestInterval(); + EXPECT_EQ(1u, ack_frame1.packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(91u), ack_frame1.packets.Min()); + EXPECT_EQ(QuicPacketNumber(99u), ack_frame1.packets.Max()); +} + +TEST_F(QuicFramesTest, CopyQuicFrames) { + QuicFrames frames; + QuicMessageFrame* message_frame = + new QuicMessageFrame(1, MemSliceFromString("message")); + // Construct a frame list. + for (uint8_t i = 0; i < NUM_FRAME_TYPES; ++i) { + switch (i) { + case PADDING_FRAME: + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + break; + case RST_STREAM_FRAME: + frames.push_back(QuicFrame(new QuicRstStreamFrame())); + break; + case CONNECTION_CLOSE_FRAME: + frames.push_back(QuicFrame(new QuicConnectionCloseFrame())); + break; + case GOAWAY_FRAME: + frames.push_back(QuicFrame(new QuicGoAwayFrame())); + break; + case WINDOW_UPDATE_FRAME: + frames.push_back(QuicFrame(QuicWindowUpdateFrame())); + break; + case BLOCKED_FRAME: + frames.push_back(QuicFrame(QuicBlockedFrame())); + break; + case STOP_WAITING_FRAME: + frames.push_back(QuicFrame(QuicStopWaitingFrame())); + break; + case PING_FRAME: + frames.push_back(QuicFrame(QuicPingFrame())); + break; + case CRYPTO_FRAME: + frames.push_back(QuicFrame(new QuicCryptoFrame())); + break; + case STREAM_FRAME: + frames.push_back(QuicFrame(QuicStreamFrame())); + break; + case ACK_FRAME: + frames.push_back(QuicFrame(new QuicAckFrame())); + break; + case MTU_DISCOVERY_FRAME: + frames.push_back(QuicFrame(QuicMtuDiscoveryFrame())); + break; + case NEW_CONNECTION_ID_FRAME: + frames.push_back(QuicFrame(new QuicNewConnectionIdFrame())); + break; + case MAX_STREAMS_FRAME: + frames.push_back(QuicFrame(QuicMaxStreamsFrame())); + break; + case STREAMS_BLOCKED_FRAME: + frames.push_back(QuicFrame(QuicStreamsBlockedFrame())); + break; + case PATH_RESPONSE_FRAME: + frames.push_back(QuicFrame(QuicPathResponseFrame())); + break; + case PATH_CHALLENGE_FRAME: + frames.push_back(QuicFrame(QuicPathChallengeFrame())); + break; + case STOP_SENDING_FRAME: + frames.push_back(QuicFrame(QuicStopSendingFrame())); + break; + case MESSAGE_FRAME: + frames.push_back(QuicFrame(message_frame)); + break; + case NEW_TOKEN_FRAME: + frames.push_back(QuicFrame(new QuicNewTokenFrame())); + break; + case RETIRE_CONNECTION_ID_FRAME: + frames.push_back(QuicFrame(new QuicRetireConnectionIdFrame())); + break; + case HANDSHAKE_DONE_FRAME: + frames.push_back(QuicFrame(QuicHandshakeDoneFrame())); + break; + case ACK_FREQUENCY_FRAME: + frames.push_back(QuicFrame(new QuicAckFrequencyFrame())); + break; + default: + ASSERT_TRUE(false) + << "Please fix CopyQuicFrames if a new frame type is added."; + break; + } + } + + QuicFrames copy = + CopyQuicFrames(quiche::SimpleBufferAllocator::Get(), frames); + ASSERT_EQ(NUM_FRAME_TYPES, copy.size()); + for (uint8_t i = 0; i < NUM_FRAME_TYPES; ++i) { + EXPECT_EQ(i, copy[i].type); + if (i == MESSAGE_FRAME) { + // Verify message frame is correctly copied. + EXPECT_EQ(1u, copy[i].message_frame->message_id); + EXPECT_EQ(nullptr, copy[i].message_frame->data); + EXPECT_EQ(7u, copy[i].message_frame->message_length); + ASSERT_EQ(1u, copy[i].message_frame->message_data.size()); + EXPECT_EQ(0, memcmp(copy[i].message_frame->message_data[0].data(), + frames[i].message_frame->message_data[0].data(), 7)); + } else if (i == PATH_CHALLENGE_FRAME) { + EXPECT_EQ(copy[i].path_challenge_frame.control_frame_id, + frames[i].path_challenge_frame.control_frame_id); + EXPECT_EQ(memcmp(©[i].path_challenge_frame.data_buffer, + &frames[i].path_challenge_frame.data_buffer, + copy[i].path_challenge_frame.data_buffer.size()), + 0); + } else if (i == PATH_RESPONSE_FRAME) { + EXPECT_EQ(copy[i].path_response_frame.control_frame_id, + frames[i].path_response_frame.control_frame_id); + EXPECT_EQ(memcmp(©[i].path_response_frame.data_buffer, + &frames[i].path_response_frame.data_buffer, + copy[i].path_response_frame.data_buffer.size()), + 0); + } + } + DeleteFrames(&frames); + DeleteFrames(©); +} + +class PacketNumberQueueTest : public QuicTest {}; + +// Tests that a queue contains the expected data after calls to Add(). +TEST_F(PacketNumberQueueTest, AddRange) { + PacketNumberQueue queue; + queue.AddRange(QuicPacketNumber(1), QuicPacketNumber(51)); + queue.Add(QuicPacketNumber(53)); + + EXPECT_FALSE(queue.Contains(QuicPacketNumber())); + for (int i = 1; i < 51; ++i) { + EXPECT_TRUE(queue.Contains(QuicPacketNumber(i))); + } + EXPECT_FALSE(queue.Contains(QuicPacketNumber(51))); + EXPECT_FALSE(queue.Contains(QuicPacketNumber(52))); + EXPECT_TRUE(queue.Contains(QuicPacketNumber(53))); + EXPECT_FALSE(queue.Contains(QuicPacketNumber(54))); + EXPECT_EQ(51u, queue.NumPacketsSlow()); + EXPECT_EQ(QuicPacketNumber(1u), queue.Min()); + EXPECT_EQ(QuicPacketNumber(53u), queue.Max()); + + queue.Add(QuicPacketNumber(70)); + EXPECT_EQ(QuicPacketNumber(70u), queue.Max()); +} + +// Tests Contains function +TEST_F(PacketNumberQueueTest, Contains) { + PacketNumberQueue queue; + EXPECT_FALSE(queue.Contains(QuicPacketNumber())); + queue.AddRange(QuicPacketNumber(5), QuicPacketNumber(10)); + queue.Add(QuicPacketNumber(20)); + + for (int i = 1; i < 5; ++i) { + EXPECT_FALSE(queue.Contains(QuicPacketNumber(i))); + } + + for (int i = 5; i < 10; ++i) { + EXPECT_TRUE(queue.Contains(QuicPacketNumber(i))); + } + for (int i = 10; i < 20; ++i) { + EXPECT_FALSE(queue.Contains(QuicPacketNumber(i))); + } + EXPECT_TRUE(queue.Contains(QuicPacketNumber(20))); + EXPECT_FALSE(queue.Contains(QuicPacketNumber(21))); + + PacketNumberQueue queue2; + EXPECT_FALSE(queue2.Contains(QuicPacketNumber(1))); + for (int i = 1; i < 51; ++i) { + queue2.Add(QuicPacketNumber(2 * i)); + } + EXPECT_FALSE(queue2.Contains(QuicPacketNumber())); + for (int i = 1; i < 51; ++i) { + if (i % 2 == 0) { + EXPECT_TRUE(queue2.Contains(QuicPacketNumber(i))); + } else { + EXPECT_FALSE(queue2.Contains(QuicPacketNumber(i))); + } + } + EXPECT_FALSE(queue2.Contains(QuicPacketNumber(101))); +} + +// Tests that a queue contains the expected data after calls to RemoveUpTo(). +TEST_F(PacketNumberQueueTest, Removal) { + PacketNumberQueue queue; + EXPECT_FALSE(queue.Contains(QuicPacketNumber(51))); + queue.AddRange(QuicPacketNumber(1), QuicPacketNumber(100)); + + EXPECT_TRUE(queue.RemoveUpTo(QuicPacketNumber(51))); + EXPECT_FALSE(queue.RemoveUpTo(QuicPacketNumber(51))); + + EXPECT_FALSE(queue.Contains(QuicPacketNumber())); + for (int i = 1; i < 51; ++i) { + EXPECT_FALSE(queue.Contains(QuicPacketNumber(i))); + } + for (int i = 51; i < 100; ++i) { + EXPECT_TRUE(queue.Contains(QuicPacketNumber(i))); + } + EXPECT_EQ(49u, queue.NumPacketsSlow()); + EXPECT_EQ(QuicPacketNumber(51u), queue.Min()); + EXPECT_EQ(QuicPacketNumber(99u), queue.Max()); + + PacketNumberQueue queue2; + queue2.AddRange(QuicPacketNumber(1), QuicPacketNumber(5)); + EXPECT_TRUE(queue2.RemoveUpTo(QuicPacketNumber(3))); + EXPECT_TRUE(queue2.RemoveUpTo(QuicPacketNumber(50))); + EXPECT_TRUE(queue2.Empty()); +} + +// Tests that a queue is empty when all of its elements are removed. +TEST_F(PacketNumberQueueTest, Empty) { + PacketNumberQueue queue; + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(0u, queue.NumPacketsSlow()); + + queue.AddRange(QuicPacketNumber(1), QuicPacketNumber(100)); + EXPECT_TRUE(queue.RemoveUpTo(QuicPacketNumber(100))); + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(0u, queue.NumPacketsSlow()); +} + +// Tests that logging the state of a PacketNumberQueue does not crash. +TEST_F(PacketNumberQueueTest, LogDoesNotCrash) { + std::ostringstream oss; + PacketNumberQueue queue; + oss << queue; + + queue.Add(QuicPacketNumber(1)); + queue.AddRange(QuicPacketNumber(50), QuicPacketNumber(100)); + oss << queue; +} + +// Tests that the iterators returned from a packet queue iterate over the queue. +TEST_F(PacketNumberQueueTest, Iterators) { + PacketNumberQueue queue; + queue.AddRange(QuicPacketNumber(1), QuicPacketNumber(100)); + + const std::vector> actual_intervals( + queue.begin(), queue.end()); + + PacketNumberQueue queue2; + for (int i = 1; i < 100; i++) { + queue2.AddRange(QuicPacketNumber(i), QuicPacketNumber(i + 1)); + } + + const std::vector> actual_intervals2( + queue2.begin(), queue2.end()); + + std::vector> expected_intervals; + expected_intervals.emplace_back(QuicInterval( + QuicPacketNumber(1), QuicPacketNumber(100))); + EXPECT_EQ(expected_intervals, actual_intervals); + EXPECT_EQ(expected_intervals, actual_intervals2); + EXPECT_EQ(actual_intervals, actual_intervals2); +} + +TEST_F(PacketNumberQueueTest, ReversedIterators) { + PacketNumberQueue queue; + queue.AddRange(QuicPacketNumber(1), QuicPacketNumber(100)); + PacketNumberQueue queue2; + for (int i = 1; i < 100; i++) { + queue2.AddRange(QuicPacketNumber(i), QuicPacketNumber(i + 1)); + } + const std::vector> actual_intervals( + queue.rbegin(), queue.rend()); + const std::vector> actual_intervals2( + queue2.rbegin(), queue2.rend()); + + std::vector> expected_intervals; + expected_intervals.emplace_back(QuicInterval( + QuicPacketNumber(1), QuicPacketNumber(100))); + + EXPECT_EQ(expected_intervals, actual_intervals); + EXPECT_EQ(expected_intervals, actual_intervals2); + EXPECT_EQ(actual_intervals, actual_intervals2); + + PacketNumberQueue queue3; + for (int i = 1; i < 20; i++) { + queue3.Add(QuicPacketNumber(2 * i)); + } + + auto begin = queue3.begin(); + auto end = queue3.end(); + --end; + auto rbegin = queue3.rbegin(); + auto rend = queue3.rend(); + --rend; + + EXPECT_EQ(*begin, *rend); + EXPECT_EQ(*rbegin, *end); +} + +TEST_F(PacketNumberQueueTest, IntervalLengthAndRemoveInterval) { + PacketNumberQueue queue; + queue.AddRange(QuicPacketNumber(1), QuicPacketNumber(10)); + queue.AddRange(QuicPacketNumber(20), QuicPacketNumber(30)); + queue.AddRange(QuicPacketNumber(40), QuicPacketNumber(50)); + EXPECT_EQ(3u, queue.NumIntervals()); + EXPECT_EQ(10u, queue.LastIntervalLength()); + + EXPECT_TRUE(queue.RemoveUpTo(QuicPacketNumber(25))); + EXPECT_EQ(2u, queue.NumIntervals()); + EXPECT_EQ(10u, queue.LastIntervalLength()); + EXPECT_EQ(QuicPacketNumber(25u), queue.Min()); + EXPECT_EQ(QuicPacketNumber(49u), queue.Max()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/frames/quic_goaway_frame.cc b/quiche/quic/core/frames/quic_goaway_frame.cc new file mode 100644 index 000000000000..c2394a7bb525 --- /dev/null +++ b/quiche/quic/core/frames/quic_goaway_frame.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_goaway_frame.h" + +#include + +namespace quic { + +QuicGoAwayFrame::QuicGoAwayFrame(QuicControlFrameId control_frame_id, + QuicErrorCode error_code, + QuicStreamId last_good_stream_id, + const std::string& reason) + : control_frame_id(control_frame_id), + error_code(error_code), + last_good_stream_id(last_good_stream_id), + reason_phrase(reason) {} + +std::ostream& operator<<(std::ostream& os, + const QuicGoAwayFrame& goaway_frame) { + os << "{ control_frame_id: " << goaway_frame.control_frame_id + << ", error_code: " << goaway_frame.error_code + << ", last_good_stream_id: " << goaway_frame.last_good_stream_id + << ", reason_phrase: '" << goaway_frame.reason_phrase << "' }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_goaway_frame.h b/quiche/quic/core/frames/quic_goaway_frame.h new file mode 100644 index 000000000000..c09488959020 --- /dev/null +++ b/quiche/quic/core/frames/quic_goaway_frame.h @@ -0,0 +1,35 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_GOAWAY_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_GOAWAY_FRAME_H_ + +#include +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicGoAwayFrame { + QuicGoAwayFrame() = default; + QuicGoAwayFrame(QuicControlFrameId control_frame_id, QuicErrorCode error_code, + QuicStreamId last_good_stream_id, const std::string& reason); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const QuicGoAwayFrame& g); + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + QuicErrorCode error_code = QUIC_NO_ERROR; + QuicStreamId last_good_stream_id = 0; + std::string reason_phrase; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_GOAWAY_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_handshake_done_frame.cc b/quiche/quic/core/frames/quic_handshake_done_frame.cc new file mode 100644 index 000000000000..e8a7110d7df8 --- /dev/null +++ b/quiche/quic/core/frames/quic_handshake_done_frame.cc @@ -0,0 +1,24 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_handshake_done_frame.h" + +namespace quic { + +QuicHandshakeDoneFrame::QuicHandshakeDoneFrame() + : QuicInlinedFrame(HANDSHAKE_DONE_FRAME) {} + +QuicHandshakeDoneFrame::QuicHandshakeDoneFrame( + QuicControlFrameId control_frame_id) + : QuicInlinedFrame(HANDSHAKE_DONE_FRAME), + control_frame_id(control_frame_id) {} + +std::ostream& operator<<(std::ostream& os, + const QuicHandshakeDoneFrame& handshake_done_frame) { + os << "{ control_frame_id: " << handshake_done_frame.control_frame_id + << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_handshake_done_frame.h b/quiche/quic/core/frames/quic_handshake_done_frame.h new file mode 100644 index 000000000000..aa5cbfc2ec30 --- /dev/null +++ b/quiche/quic/core/frames/quic_handshake_done_frame.h @@ -0,0 +1,34 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_HANDSHAKE_DONE_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_HANDSHAKE_DONE_FRAME_H_ + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A HANDSHAKE_DONE frame contains no payload, and it is retransmittable, +// and ACK'd just like other normal frames. +struct QUIC_EXPORT_PRIVATE QuicHandshakeDoneFrame + : public QuicInlinedFrame { + QuicHandshakeDoneFrame(); + explicit QuicHandshakeDoneFrame(QuicControlFrameId control_frame_id); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicHandshakeDoneFrame& handshake_done_frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_HANDSHAKE_DONE_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_inlined_frame.h b/quiche/quic/core/frames/quic_inlined_frame.h new file mode 100644 index 000000000000..8cab15865daf --- /dev/null +++ b/quiche/quic/core/frames/quic_inlined_frame.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_INLINED_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_INLINED_FRAME_H_ + +#include + +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QuicInlinedFrame is the base class of all frame types that is inlined in the +// QuicFrame class. It gurantees all inlined frame types contain a 'type' field +// at offset 0, such that QuicFrame.type can get the correct frame type for both +// inline and out-of-line frame types. +template +struct QUIC_EXPORT_PRIVATE QuicInlinedFrame { + QuicInlinedFrame(QuicFrameType type) { + static_cast(this)->type = type; + static_assert(std::is_standard_layout::value, + "Inlined frame must have a standard layout"); + static_assert(offsetof(DerivedT, type) == 0, + "type must be the first field."); + static_assert(sizeof(DerivedT) <= 24, + "Frames larger than 24 bytes should not be inlined."); + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_INLINED_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_max_streams_frame.cc b/quiche/quic/core/frames/quic_max_streams_frame.cc new file mode 100644 index 000000000000..594224b2fe09 --- /dev/null +++ b/quiche/quic/core/frames/quic_max_streams_frame.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_max_streams_frame.h" + +namespace quic { + +QuicMaxStreamsFrame::QuicMaxStreamsFrame() + : QuicInlinedFrame(MAX_STREAMS_FRAME) {} + +QuicMaxStreamsFrame::QuicMaxStreamsFrame(QuicControlFrameId control_frame_id, + QuicStreamCount stream_count, + bool unidirectional) + : QuicInlinedFrame(MAX_STREAMS_FRAME), + control_frame_id(control_frame_id), + stream_count(stream_count), + unidirectional(unidirectional) {} + +std::ostream& operator<<(std::ostream& os, const QuicMaxStreamsFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id + << ", stream_count: " << frame.stream_count + << ((frame.unidirectional) ? ", unidirectional }\n" + : ", bidirectional }\n"); + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_max_streams_frame.h b/quiche/quic/core/frames/quic_max_streams_frame.h new file mode 100644 index 000000000000..eaee69a4839e --- /dev/null +++ b/quiche/quic/core/frames/quic_max_streams_frame.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_MAX_STREAMS_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_MAX_STREAMS_FRAME_H_ + +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// IETF format MAX_STREAMS frame. +// This frame is used by the sender to inform the peer of the number of +// streams that the peer may open and that the sender will accept. +struct QUIC_EXPORT_PRIVATE QuicMaxStreamsFrame + : public QuicInlinedFrame { + QuicMaxStreamsFrame(); + QuicMaxStreamsFrame(QuicControlFrameId control_frame_id, + QuicStreamCount stream_count, bool unidirectional); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicMaxStreamsFrame& frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + // The number of streams that may be opened. + QuicStreamCount stream_count = 0; + // Whether uni- or bi-directional streams + bool unidirectional = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_MAX_STREAMS_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_message_frame.cc b/quiche/quic/core/frames/quic_message_frame.cc new file mode 100644 index 000000000000..935d7ce39cb0 --- /dev/null +++ b/quiche/quic/core/frames/quic_message_frame.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_message_frame.h" + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quic { + +QuicMessageFrame::QuicMessageFrame(QuicMessageId message_id) + : message_id(message_id), data(nullptr), message_length(0) {} + +QuicMessageFrame::QuicMessageFrame(QuicMessageId message_id, + absl::Span span) + : message_id(message_id), data(nullptr), message_length(0) { + for (quiche::QuicheMemSlice& slice : span) { + if (slice.empty()) { + continue; + } + message_length += slice.length(); + message_data.push_back(std::move(slice)); + } +} +QuicMessageFrame::QuicMessageFrame(QuicMessageId message_id, + quiche::QuicheMemSlice slice) + : QuicMessageFrame(message_id, absl::MakeSpan(&slice, 1)) {} + +QuicMessageFrame::QuicMessageFrame(const char* data, QuicPacketLength length) + : message_id(0), data(data), message_length(length) {} + +QuicMessageFrame::~QuicMessageFrame() {} + +std::ostream& operator<<(std::ostream& os, const QuicMessageFrame& s) { + os << " message_id: " << s.message_id + << ", message_length: " << s.message_length << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_message_frame.h b/quiche/quic/core/frames/quic_message_frame.h new file mode 100644 index 000000000000..91f73646f63f --- /dev/null +++ b/quiche/quic/core/frames/quic_message_frame.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_MESSAGE_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_MESSAGE_FRAME_H_ + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quic { + +using QuicMessageData = absl::InlinedVector; + +struct QUIC_EXPORT_PRIVATE QuicMessageFrame { + QuicMessageFrame() = default; + explicit QuicMessageFrame(QuicMessageId message_id); + QuicMessageFrame(QuicMessageId message_id, + absl::Span span); + QuicMessageFrame(QuicMessageId message_id, quiche::QuicheMemSlice slice); + QuicMessageFrame(const char* data, QuicPacketLength length); + + QuicMessageFrame(const QuicMessageFrame& other) = delete; + QuicMessageFrame& operator=(const QuicMessageFrame& other) = delete; + + QuicMessageFrame(QuicMessageFrame&& other) = default; + QuicMessageFrame& operator=(QuicMessageFrame&& other) = default; + + ~QuicMessageFrame(); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicMessageFrame& s); + + // message_id is only used on the sender side and does not get serialized on + // wire. + QuicMessageId message_id = 0; + // Not owned, only used on read path. + const char* data = nullptr; + // Total length of message_data, must be fit into one packet. + QuicPacketLength message_length = 0; + + // The actual message data which is reference counted, used on write path. + QuicMessageData message_data; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_MESSAGE_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_mtu_discovery_frame.h b/quiche/quic/core/frames/quic_mtu_discovery_frame.h new file mode 100644 index 000000000000..7189463ca395 --- /dev/null +++ b/quiche/quic/core/frames/quic_mtu_discovery_frame.h @@ -0,0 +1,25 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_MTU_DISCOVERY_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_MTU_DISCOVERY_FRAME_H_ + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A path MTU discovery frame contains no payload and is serialized as a ping +// frame. +struct QUIC_EXPORT_PRIVATE QuicMtuDiscoveryFrame + : public QuicInlinedFrame { + QuicMtuDiscoveryFrame() : QuicInlinedFrame(MTU_DISCOVERY_FRAME) {} + + QuicFrameType type; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_MTU_DISCOVERY_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_new_connection_id_frame.cc b/quiche/quic/core/frames/quic_new_connection_id_frame.cc new file mode 100644 index 000000000000..2d3746051e75 --- /dev/null +++ b/quiche/quic/core/frames/quic_new_connection_id_frame.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" + +namespace quic { + +QuicNewConnectionIdFrame::QuicNewConnectionIdFrame( + QuicControlFrameId control_frame_id, QuicConnectionId connection_id, + QuicConnectionIdSequenceNumber sequence_number, + StatelessResetToken stateless_reset_token, uint64_t retire_prior_to) + : control_frame_id(control_frame_id), + connection_id(connection_id), + sequence_number(sequence_number), + stateless_reset_token(stateless_reset_token), + retire_prior_to(retire_prior_to) { + QUICHE_DCHECK(retire_prior_to <= sequence_number); +} + +std::ostream& operator<<(std::ostream& os, + const QuicNewConnectionIdFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id + << ", connection_id: " << frame.connection_id + << ", sequence_number: " << frame.sequence_number + << ", retire_prior_to: " << frame.retire_prior_to << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_new_connection_id_frame.h b/quiche/quic/core/frames/quic_new_connection_id_frame.h new file mode 100644 index 000000000000..8f5e9ba4a11e --- /dev/null +++ b/quiche/quic/core/frames/quic_new_connection_id_frame.h @@ -0,0 +1,39 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_NEW_CONNECTION_ID_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_NEW_CONNECTION_ID_FRAME_H_ + +#include + +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicNewConnectionIdFrame { + QuicNewConnectionIdFrame() = default; + QuicNewConnectionIdFrame(QuicControlFrameId control_frame_id, + QuicConnectionId connection_id, + QuicConnectionIdSequenceNumber sequence_number, + StatelessResetToken stateless_reset_token, + uint64_t retire_prior_to); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicNewConnectionIdFrame& frame); + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + QuicConnectionId connection_id = EmptyQuicConnectionId(); + QuicConnectionIdSequenceNumber sequence_number = 0; + StatelessResetToken stateless_reset_token; + uint64_t retire_prior_to = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_NEW_CONNECTION_ID_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_new_token_frame.cc b/quiche/quic/core/frames/quic_new_token_frame.cc new file mode 100644 index 000000000000..7b5190d225a4 --- /dev/null +++ b/quiche/quic/core/frames/quic_new_token_frame.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_new_token_frame.h" + +#include "absl/strings/escaping.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicNewTokenFrame::QuicNewTokenFrame(QuicControlFrameId control_frame_id, + absl::string_view token) + : control_frame_id(control_frame_id), + token(std::string(token.data(), token.length())) {} + +std::ostream& operator<<(std::ostream& os, const QuicNewTokenFrame& s) { + os << "{ control_frame_id: " << s.control_frame_id + << ", token: " << absl::BytesToHexString(s.token) << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_new_token_frame.h b/quiche/quic/core/frames/quic_new_token_frame.h new file mode 100644 index 000000000000..9761ed0117fc --- /dev/null +++ b/quiche/quic/core/frames/quic_new_token_frame.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_NEW_TOKEN_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_NEW_TOKEN_FRAME_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicNewTokenFrame { + QuicNewTokenFrame() = default; + QuicNewTokenFrame(QuicControlFrameId control_frame_id, + absl::string_view token); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicNewTokenFrame& s); + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + std::string token; +}; +static_assert(sizeof(QuicNewTokenFrame) <= 64, + "Keep the QuicNewTokenFrame size to a cacheline."); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_NEW_TOKEN_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_padding_frame.cc b/quiche/quic/core/frames/quic_padding_frame.cc new file mode 100644 index 000000000000..2170835cfaf8 --- /dev/null +++ b/quiche/quic/core/frames/quic_padding_frame.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_padding_frame.h" + +namespace quic { + +std::ostream& operator<<(std::ostream& os, + const QuicPaddingFrame& padding_frame) { + os << "{ num_padding_bytes: " << padding_frame.num_padding_bytes << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_padding_frame.h b/quiche/quic/core/frames/quic_padding_frame.h new file mode 100644 index 000000000000..a903c5d949c4 --- /dev/null +++ b/quiche/quic/core/frames/quic_padding_frame.h @@ -0,0 +1,36 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_PADDING_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_PADDING_FRAME_H_ + +#include +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A padding frame contains no payload. +struct QUIC_EXPORT_PRIVATE QuicPaddingFrame + : public QuicInlinedFrame { + QuicPaddingFrame() : QuicInlinedFrame(PADDING_FRAME) {} + explicit QuicPaddingFrame(int num_padding_bytes) + : QuicInlinedFrame(PADDING_FRAME), num_padding_bytes(num_padding_bytes) {} + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicPaddingFrame& padding_frame); + + QuicFrameType type; + + // -1: full padding to the end of a max-sized packet + // otherwise: only pad up to num_padding_bytes bytes + int num_padding_bytes = -1; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_PADDING_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_path_challenge_frame.cc b/quiche/quic/core/frames/quic_path_challenge_frame.cc new file mode 100644 index 000000000000..5f4f57b76229 --- /dev/null +++ b/quiche/quic/core/frames/quic_path_challenge_frame.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_path_challenge_frame.h" + +#include "absl/strings/escaping.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +QuicPathChallengeFrame::QuicPathChallengeFrame() + : QuicInlinedFrame(PATH_CHALLENGE_FRAME) {} + +QuicPathChallengeFrame::QuicPathChallengeFrame( + QuicControlFrameId control_frame_id, const QuicPathFrameBuffer& data_buff) + : QuicInlinedFrame(PATH_CHALLENGE_FRAME), + control_frame_id(control_frame_id) { + memcpy(data_buffer.data(), data_buff.data(), data_buffer.size()); +} + +std::ostream& operator<<(std::ostream& os, + const QuicPathChallengeFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id << ", data: " + << absl::BytesToHexString(absl::string_view( + reinterpret_cast(frame.data_buffer.data()), + frame.data_buffer.size())) + << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_path_challenge_frame.h b/quiche/quic/core/frames/quic_path_challenge_frame.h new file mode 100644 index 000000000000..34dd40b4ae5b --- /dev/null +++ b/quiche/quic/core/frames/quic_path_challenge_frame.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_PATH_CHALLENGE_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_PATH_CHALLENGE_FRAME_H_ + +#include +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicPathChallengeFrame + : public QuicInlinedFrame { + QuicPathChallengeFrame(); + QuicPathChallengeFrame(QuicControlFrameId control_frame_id, + const QuicPathFrameBuffer& data_buff); + ~QuicPathChallengeFrame() = default; + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicPathChallengeFrame& frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + QuicPathFrameBuffer data_buffer{}; +}; +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_PATH_CHALLENGE_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_path_response_frame.cc b/quiche/quic/core/frames/quic_path_response_frame.cc new file mode 100644 index 000000000000..0f7a41219601 --- /dev/null +++ b/quiche/quic/core/frames/quic_path_response_frame.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_path_response_frame.h" + +#include "absl/strings/escaping.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +QuicPathResponseFrame::QuicPathResponseFrame() + : QuicInlinedFrame(PATH_RESPONSE_FRAME) {} + +QuicPathResponseFrame::QuicPathResponseFrame( + QuicControlFrameId control_frame_id, const QuicPathFrameBuffer& data_buff) + : QuicInlinedFrame(PATH_RESPONSE_FRAME), + control_frame_id(control_frame_id) { + memcpy(data_buffer.data(), data_buff.data(), data_buffer.size()); +} + +std::ostream& operator<<(std::ostream& os, const QuicPathResponseFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id << ", data: " + << absl::BytesToHexString(absl::string_view( + reinterpret_cast(frame.data_buffer.data()), + frame.data_buffer.size())) + << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_path_response_frame.h b/quiche/quic/core/frames/quic_path_response_frame.h new file mode 100644 index 000000000000..5c6a6673bd58 --- /dev/null +++ b/quiche/quic/core/frames/quic_path_response_frame.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_PATH_RESPONSE_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_PATH_RESPONSE_FRAME_H_ + +#include +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicPathResponseFrame + : public QuicInlinedFrame { + QuicPathResponseFrame(); + QuicPathResponseFrame(QuicControlFrameId control_frame_id, + const QuicPathFrameBuffer& data_buff); + ~QuicPathResponseFrame() = default; + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicPathResponseFrame& frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + QuicPathFrameBuffer data_buffer{}; +}; +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_PATH_RESPONSE_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_ping_frame.cc b/quiche/quic/core/frames/quic_ping_frame.cc new file mode 100644 index 000000000000..c28e671fec5a --- /dev/null +++ b/quiche/quic/core/frames/quic_ping_frame.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_ping_frame.h" + +namespace quic { + +QuicPingFrame::QuicPingFrame() : QuicInlinedFrame(PING_FRAME) {} + +QuicPingFrame::QuicPingFrame(QuicControlFrameId control_frame_id) + : QuicInlinedFrame(PING_FRAME), control_frame_id(control_frame_id) {} + +std::ostream& operator<<(std::ostream& os, const QuicPingFrame& ping_frame) { + os << "{ control_frame_id: " << ping_frame.control_frame_id << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_ping_frame.h b/quiche/quic/core/frames/quic_ping_frame.h new file mode 100644 index 000000000000..2603b652c3b9 --- /dev/null +++ b/quiche/quic/core/frames/quic_ping_frame.h @@ -0,0 +1,34 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_PING_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_PING_FRAME_H_ + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A ping frame contains no payload, though it is retransmittable, +// and ACK'd just like other normal frames. +struct QUIC_EXPORT_PRIVATE QuicPingFrame + : public QuicInlinedFrame { + QuicPingFrame(); + explicit QuicPingFrame(QuicControlFrameId control_frame_id); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicPingFrame& ping_frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_PING_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_retire_connection_id_frame.cc b/quiche/quic/core/frames/quic_retire_connection_id_frame.cc new file mode 100644 index 000000000000..93e7e49dda7e --- /dev/null +++ b/quiche/quic/core/frames/quic_retire_connection_id_frame.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" + +namespace quic { + +QuicRetireConnectionIdFrame::QuicRetireConnectionIdFrame( + QuicControlFrameId control_frame_id, + QuicConnectionIdSequenceNumber sequence_number) + : control_frame_id(control_frame_id), sequence_number(sequence_number) {} + +std::ostream& operator<<(std::ostream& os, + const QuicRetireConnectionIdFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id + << ", sequence_number: " << frame.sequence_number << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_retire_connection_id_frame.h b/quiche/quic/core/frames/quic_retire_connection_id_frame.h new file mode 100644 index 000000000000..fcff6969b43f --- /dev/null +++ b/quiche/quic/core/frames/quic_retire_connection_id_frame.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_RETIRE_CONNECTION_ID_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_RETIRE_CONNECTION_ID_FRAME_H_ + +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicRetireConnectionIdFrame { + QuicRetireConnectionIdFrame() = default; + QuicRetireConnectionIdFrame(QuicControlFrameId control_frame_id, + QuicConnectionIdSequenceNumber sequence_number); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicRetireConnectionIdFrame& frame); + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + QuicConnectionIdSequenceNumber sequence_number = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_RETIRE_CONNECTION_ID_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_rst_stream_frame.cc b/quiche/quic/core/frames/quic_rst_stream_frame.cc new file mode 100644 index 000000000000..a6d30524278e --- /dev/null +++ b/quiche/quic/core/frames/quic_rst_stream_frame.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_rst_stream_frame.h" + +#include "quiche/quic/core/quic_error_codes.h" + +namespace quic { + +QuicRstStreamFrame::QuicRstStreamFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicRstStreamErrorCode error_code, + QuicStreamOffset bytes_written) + : control_frame_id(control_frame_id), + stream_id(stream_id), + error_code(error_code), + ietf_error_code(RstStreamErrorCodeToIetfResetStreamErrorCode(error_code)), + byte_offset(bytes_written) {} + +QuicRstStreamFrame::QuicRstStreamFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicResetStreamError error, + QuicStreamOffset bytes_written) + : control_frame_id(control_frame_id), + stream_id(stream_id), + error_code(error.internal_code()), + ietf_error_code(error.ietf_application_code()), + byte_offset(bytes_written) {} + +std::ostream& operator<<(std::ostream& os, + const QuicRstStreamFrame& rst_frame) { + os << "{ control_frame_id: " << rst_frame.control_frame_id + << ", stream_id: " << rst_frame.stream_id + << ", byte_offset: " << rst_frame.byte_offset + << ", error_code: " << rst_frame.error_code + << ", ietf_error_code: " << rst_frame.ietf_error_code << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_rst_stream_frame.h b/quiche/quic/core/frames/quic_rst_stream_frame.h new file mode 100644 index 000000000000..c346aff37de1 --- /dev/null +++ b/quiche/quic/core/frames/quic_rst_stream_frame.h @@ -0,0 +1,58 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_RST_STREAM_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_RST_STREAM_FRAME_H_ + +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicRstStreamFrame { + QuicRstStreamFrame() = default; + QuicRstStreamFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, QuicRstStreamErrorCode error_code, + QuicStreamOffset bytes_written); + QuicRstStreamFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicRstStreamFrame& r); + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + QuicStreamId stream_id = 0; + + // When using Google QUIC: the RST_STREAM error code on the wire. + // When using IETF QUIC: for an outgoing RESET_STREAM frame, the error code + // generated by the application that determines |ietf_error_code| to be sent + // on the wire; for an incoming RESET_STREAM frame, the error code inferred + // from the |ietf_error_code| received on the wire. + QuicRstStreamErrorCode error_code = QUIC_STREAM_NO_ERROR; + + // Application error code of RESET_STREAM frame. Used for IETF QUIC only. + uint64_t ietf_error_code = 0; + + // Used to update flow control windows. On termination of a stream, both + // endpoints must inform the peer of the number of bytes they have sent on + // that stream. This can be done through normal termination (data packet with + // FIN) or through a RST. + QuicStreamOffset byte_offset = 0; + + // Returns a tuple of both |error_code| and |ietf_error_code|. + QuicResetStreamError error() const { + return QuicResetStreamError(error_code, ietf_error_code); + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_RST_STREAM_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_stop_sending_frame.cc b/quiche/quic/core/frames/quic_stop_sending_frame.cc new file mode 100644 index 000000000000..f70b7852578f --- /dev/null +++ b/quiche/quic/core/frames/quic_stop_sending_frame.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_stop_sending_frame.h" + +#include "quiche/quic/core/quic_error_codes.h" + +namespace quic { + +QuicStopSendingFrame::QuicStopSendingFrame() + : QuicInlinedFrame(STOP_SENDING_FRAME) {} + +QuicStopSendingFrame::QuicStopSendingFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicRstStreamErrorCode error_code) + : QuicStopSendingFrame(control_frame_id, stream_id, + QuicResetStreamError::FromInternal(error_code)) {} + +QuicStopSendingFrame::QuicStopSendingFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicResetStreamError error) + : QuicInlinedFrame(STOP_SENDING_FRAME), + control_frame_id(control_frame_id), + stream_id(stream_id), + error_code(error.internal_code()), + ietf_error_code(error.ietf_application_code()) {} + +std::ostream& operator<<(std::ostream& os, const QuicStopSendingFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id + << ", stream_id: " << frame.stream_id + << ", error_code: " << frame.error_code + << ", ietf_error_code: " << frame.ietf_error_code << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_stop_sending_frame.h b/quiche/quic/core/frames/quic_stop_sending_frame.h new file mode 100644 index 000000000000..8a7b8c6a1d71 --- /dev/null +++ b/quiche/quic/core/frames/quic_stop_sending_frame.h @@ -0,0 +1,52 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_STOP_SENDING_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_STOP_SENDING_FRAME_H_ + +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicStopSendingFrame + : public QuicInlinedFrame { + QuicStopSendingFrame(); + QuicStopSendingFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicRstStreamErrorCode error_code); + QuicStopSendingFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, QuicResetStreamError error); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicStopSendingFrame& frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + QuicStreamId stream_id = 0; + + // For an outgoing frame, the error code generated by the application that + // determines |ietf_error_code| to be sent on the wire; for an incoming frame, + // the error code inferred from |ietf_error_code| received on the wire. + QuicRstStreamErrorCode error_code = QUIC_STREAM_NO_ERROR; + + // On-the-wire application error code of the frame. + uint64_t ietf_error_code = 0; + + // Returns a tuple of both |error_code| and |ietf_error_code|. + QuicResetStreamError error() const { + return QuicResetStreamError(error_code, ietf_error_code); + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_STOP_SENDING_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_stop_waiting_frame.cc b/quiche/quic/core/frames/quic_stop_waiting_frame.cc new file mode 100644 index 000000000000..32941aadd2e0 --- /dev/null +++ b/quiche/quic/core/frames/quic_stop_waiting_frame.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_stop_waiting_frame.h" + +#include "quiche/quic/core/quic_constants.h" + +namespace quic { + +QuicStopWaitingFrame::QuicStopWaitingFrame() + : QuicInlinedFrame(STOP_WAITING_FRAME) {} + +std::ostream& operator<<(std::ostream& os, + const QuicStopWaitingFrame& sent_info) { + os << "{ least_unacked: " << sent_info.least_unacked << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_stop_waiting_frame.h b/quiche/quic/core/frames/quic_stop_waiting_frame.h new file mode 100644 index 000000000000..526ea01007b8 --- /dev/null +++ b/quiche/quic/core/frames/quic_stop_waiting_frame.h @@ -0,0 +1,31 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_STOP_WAITING_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_STOP_WAITING_FRAME_H_ + +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicStopWaitingFrame + : public QuicInlinedFrame { + QuicStopWaitingFrame(); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicStopWaitingFrame& s); + + QuicFrameType type; + + // The lowest packet we've sent which is unacked, and we expect an ack for. + QuicPacketNumber least_unacked; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_STOP_WAITING_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_stream_frame.cc b/quiche/quic/core/frames/quic_stream_frame.cc new file mode 100644 index 000000000000..c6988a0e8db1 --- /dev/null +++ b/quiche/quic/core/frames/quic_stream_frame.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_stream_frame.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicStreamFrame::QuicStreamFrame() : QuicInlinedFrame(STREAM_FRAME) {} + +QuicStreamFrame::QuicStreamFrame(QuicStreamId stream_id, bool fin, + QuicStreamOffset offset, + absl::string_view data) + : QuicStreamFrame(stream_id, fin, offset, data.data(), data.length()) {} + +QuicStreamFrame::QuicStreamFrame(QuicStreamId stream_id, bool fin, + QuicStreamOffset offset, + QuicPacketLength data_length) + : QuicStreamFrame(stream_id, fin, offset, nullptr, data_length) {} + +QuicStreamFrame::QuicStreamFrame(QuicStreamId stream_id, bool fin, + QuicStreamOffset offset, + const char* data_buffer, + QuicPacketLength data_length) + : QuicInlinedFrame(STREAM_FRAME), + fin(fin), + data_length(data_length), + stream_id(stream_id), + data_buffer(data_buffer), + offset(offset) {} + +std::ostream& operator<<(std::ostream& os, + const QuicStreamFrame& stream_frame) { + os << "{ stream_id: " << stream_frame.stream_id + << ", fin: " << stream_frame.fin << ", offset: " << stream_frame.offset + << ", length: " << stream_frame.data_length << " }\n"; + return os; +} + +bool QuicStreamFrame::operator==(const QuicStreamFrame& rhs) const { + return fin == rhs.fin && data_length == rhs.data_length && + stream_id == rhs.stream_id && data_buffer == rhs.data_buffer && + offset == rhs.offset; +} + +bool QuicStreamFrame::operator!=(const QuicStreamFrame& rhs) const { + return !(*this == rhs); +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_stream_frame.h b/quiche/quic/core/frames/quic_stream_frame.h new file mode 100644 index 000000000000..a6b965a49d95 --- /dev/null +++ b/quiche/quic/core/frames/quic_stream_frame.h @@ -0,0 +1,50 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_STREAM_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_STREAM_FRAME_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +struct QUIC_EXPORT_PRIVATE QuicStreamFrame + : public QuicInlinedFrame { + QuicStreamFrame(); + QuicStreamFrame(QuicStreamId stream_id, bool fin, QuicStreamOffset offset, + absl::string_view data); + QuicStreamFrame(QuicStreamId stream_id, bool fin, QuicStreamOffset offset, + QuicPacketLength data_length); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const QuicStreamFrame& s); + + bool operator==(const QuicStreamFrame& rhs) const; + + bool operator!=(const QuicStreamFrame& rhs) const; + + QuicFrameType type; + bool fin = false; + QuicPacketLength data_length = 0; + // TODO(wub): Change to a QuicUtils::GetInvalidStreamId when it is not version + // dependent. + QuicStreamId stream_id = -1; + const char* data_buffer = nullptr; // Not owned. + QuicStreamOffset offset = 0; // Location of this data in the stream. + + QuicStreamFrame(QuicStreamId stream_id, bool fin, QuicStreamOffset offset, + const char* data_buffer, QuicPacketLength data_length); +}; +static_assert(sizeof(QuicStreamFrame) <= 64, + "Keep the QuicStreamFrame size to a cacheline."); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_STREAM_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_streams_blocked_frame.cc b/quiche/quic/core/frames/quic_streams_blocked_frame.cc new file mode 100644 index 000000000000..6d6a6d2f6e28 --- /dev/null +++ b/quiche/quic/core/frames/quic_streams_blocked_frame.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_streams_blocked_frame.h" + +namespace quic { + +QuicStreamsBlockedFrame::QuicStreamsBlockedFrame() + : QuicInlinedFrame(STREAMS_BLOCKED_FRAME) {} + +QuicStreamsBlockedFrame::QuicStreamsBlockedFrame( + QuicControlFrameId control_frame_id, QuicStreamCount stream_count, + bool unidirectional) + : QuicInlinedFrame(STREAMS_BLOCKED_FRAME), + control_frame_id(control_frame_id), + stream_count(stream_count), + unidirectional(unidirectional) {} + +std::ostream& operator<<(std::ostream& os, + const QuicStreamsBlockedFrame& frame) { + os << "{ control_frame_id: " << frame.control_frame_id + << ", stream count: " << frame.stream_count + << ((frame.unidirectional) ? ", unidirectional }\n" + : ", bidirectional }\n"); + + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_streams_blocked_frame.h b/quiche/quic/core/frames/quic_streams_blocked_frame.h new file mode 100644 index 000000000000..cc60f31294ed --- /dev/null +++ b/quiche/quic/core/frames/quic_streams_blocked_frame.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_STREAMS_BLOCKED_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_STREAMS_BLOCKED_FRAME_H_ + +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// IETF format STREAMS_BLOCKED frame. +// The sender uses this to inform the peer that the sender wished to +// open a new stream, exceeding the limit on the number of streams. +struct QUIC_EXPORT_PRIVATE QuicStreamsBlockedFrame + : public QuicInlinedFrame { + QuicStreamsBlockedFrame(); + QuicStreamsBlockedFrame(QuicControlFrameId control_frame_id, + QuicStreamCount stream_count, bool unidirectional); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicStreamsBlockedFrame& frame); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + // The number of streams that the sender wishes to exceed + QuicStreamCount stream_count = 0; + + // Whether uni- or bi-directional streams + bool unidirectional = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_STREAMS_BLOCKED_FRAME_H_ diff --git a/quiche/quic/core/frames/quic_window_update_frame.cc b/quiche/quic/core/frames/quic_window_update_frame.cc new file mode 100644 index 000000000000..3c134544ae78 --- /dev/null +++ b/quiche/quic/core/frames/quic_window_update_frame.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/frames/quic_window_update_frame.h" + +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +QuicWindowUpdateFrame::QuicWindowUpdateFrame() + : QuicInlinedFrame(WINDOW_UPDATE_FRAME) {} + +QuicWindowUpdateFrame::QuicWindowUpdateFrame( + QuicControlFrameId control_frame_id, QuicStreamId stream_id, + QuicByteCount max_data) + : QuicInlinedFrame(WINDOW_UPDATE_FRAME), + control_frame_id(control_frame_id), + stream_id(stream_id), + max_data(max_data) {} + +std::ostream& operator<<(std::ostream& os, + const QuicWindowUpdateFrame& window_update_frame) { + os << "{ control_frame_id: " << window_update_frame.control_frame_id + << ", stream_id: " << window_update_frame.stream_id + << ", max_data: " << window_update_frame.max_data << " }\n"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/frames/quic_window_update_frame.h b/quiche/quic/core/frames/quic_window_update_frame.h new file mode 100644 index 000000000000..1cbd7ffac58b --- /dev/null +++ b/quiche/quic/core/frames/quic_window_update_frame.h @@ -0,0 +1,45 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_WINDOW_UPDATE_FRAME_H_ +#define QUICHE_QUIC_CORE_FRAMES_QUIC_WINDOW_UPDATE_FRAME_H_ + +#include + +#include "quiche/quic/core/frames/quic_inlined_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +// Flow control updates per-stream and at the connection level. +// Based on SPDY's WINDOW_UPDATE frame, but uses an absolute max data bytes +// rather than a window delta. +struct QUIC_EXPORT_PRIVATE QuicWindowUpdateFrame + : public QuicInlinedFrame { + QuicWindowUpdateFrame(); + QuicWindowUpdateFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, QuicByteCount max_data); + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicWindowUpdateFrame& w); + + QuicFrameType type; + + // A unique identifier of this control frame. 0 when this frame is received, + // and non-zero when sent. + QuicControlFrameId control_frame_id = kInvalidControlFrameId; + + // The stream this frame applies to. 0 is a special case meaning the overall + // connection rather than a specific stream. + QuicStreamId stream_id = 0; + + // Maximum data allowed in the stream or connection. The receiver of this + // frame must not send data which would exceedes this restriction. + QuicByteCount max_data = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_FRAMES_QUIC_WINDOW_UPDATE_FRAME_H_ diff --git a/quiche/quic/core/handshaker_delegate_interface.h b/quiche/quic/core/handshaker_delegate_interface.h new file mode 100644 index 000000000000..d6e250100378 --- /dev/null +++ b/quiche/quic/core/handshaker_delegate_interface.h @@ -0,0 +1,85 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HANDSHAKER_DELEGATE_INTERFACE_H_ +#define QUICHE_QUIC_CORE_HANDSHAKER_DELEGATE_INTERFACE_H_ + +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" + +namespace quic { + +class QuicDecrypter; +class QuicEncrypter; + +// Pure virtual class to get notified when particular handshake events occurred. +class QUIC_EXPORT_PRIVATE HandshakerDelegateInterface { + public: + virtual ~HandshakerDelegateInterface() {} + + // Called when new decryption key of |level| is available. Returns true if + // decrypter is set successfully, otherwise, returns false. + virtual bool OnNewDecryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr decrypter, + bool set_alternative_decrypter, bool latch_once_used) = 0; + + // Called when new encryption key of |level| is available. + virtual void OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) = 0; + + // Called to set default encryption level to |level|. Only used in QUIC + // crypto. + virtual void SetDefaultEncryptionLevel(EncryptionLevel level) = 0; + + // Called when both 1-RTT read and write keys are available. Only used in TLS + // handshake. + virtual void OnTlsHandshakeComplete() = 0; + + // Called to discard old decryption keys to stop processing packets of + // encryption |level|. + virtual void DiscardOldDecryptionKey(EncryptionLevel level) = 0; + + // Called to discard old encryption keys (and neuter obsolete data). + // TODO(fayang): consider to combine this with DiscardOldDecryptionKey. + virtual void DiscardOldEncryptionKey(EncryptionLevel level) = 0; + + // Called to neuter ENCRYPTION_INITIAL data (without discarding initial keys). + virtual void NeuterUnencryptedData() = 0; + + // Called to neuter data of HANDSHAKE_DATA packet number space. Only used in + // QUIC crypto. This is called 1) when a client switches to forward secure + // encryption level and 2) a server successfully processes a forward secure + // packet. + virtual void NeuterHandshakeData() = 0; + + // Called when 0-RTT data is rejected by the server. This is only called in + // TLS handshakes and only called on clients. + virtual void OnZeroRttRejected(int reason) = 0; + + // Fills in |params| with values from the delegate's QuicConfig. + // Returns whether the operation succeeded. + virtual bool FillTransportParameters(TransportParameters* params) = 0; + + // Read |params| and apply the values to the delegate's QuicConfig. + // On failure, returns a QuicErrorCode and saves a detailed error in + // |error_details|. + virtual QuicErrorCode ProcessTransportParameters( + const TransportParameters& params, bool is_resumption, + std::string* error_details) = 0; + + // Called at the end of an handshake operation callback. + virtual void OnHandshakeCallbackDone() = 0; + + // Whether a packet flusher is currently attached. + virtual bool PacketFlusherAttached() const = 0; + + // Get the QUIC version currently in use. tls_handshaker needs this to pass + // to crypto_utils to apply version-dependent HKDF labels. + virtual ParsedQuicVersion parsed_version() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HANDSHAKER_DELEGATE_INTERFACE_H_ diff --git a/quiche/quic/core/http/end_to_end_test.cc b/quiche/quic/core/http/end_to_end_test.cc new file mode 100644 index 000000000000..bf8e903e93e4 --- /dev/null +++ b/quiche/quic/core/http/end_to_end_test.cc @@ -0,0 +1,7427 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_client_session_cache.h" +#include "quiche/quic/core/frames/quic_blocked_frame.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_packet_writer_wrapper.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/bad_packet_writer.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/packet_dropping_test_writer.h" +#include "quiche/quic/test_tools/packet_reordering_writer.h" +#include "quiche/quic/test_tools/qpack/qpack_encoder_peer.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/quic/test_tools/quic_client_session_cache_peer.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_dispatcher_peer.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_server_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_id_manager_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_sequencer_peer.h" +#include "quiche/quic/test_tools/quic_test_backend.h" +#include "quiche/quic/test_tools/quic_test_client.h" +#include "quiche/quic/test_tools/quic_test_server.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/server_thread.h" +#include "quiche/quic/test_tools/web_transport_test_tools.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/quic/tools/quic_server.h" +#include "quiche/quic/tools/quic_simple_client_stream.h" +#include "quiche/quic/tools/quic_simple_server_stream.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_stream.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +using spdy::Http2HeaderBlock; +using spdy::kV3LowestPriority; +using spdy::SpdyFramer; +using spdy::SpdySerializedFrame; +using spdy::SpdySettingsIR; +using ::testing::_; +using ::testing::Assign; +using ::testing::Invoke; +using ::testing::NiceMock; +using ::testing::UnorderedElementsAreArray; + +namespace quic { +namespace test { +namespace { + +const char kFooResponseBody[] = "Artichoke hearts make me happy."; +const char kBarResponseBody[] = "Palm hearts are pretty delicious, also."; +const char kTestUserAgentId[] = "quic/core/http/end_to_end_test.cc"; +const float kSessionToStreamRatio = 1.5; +const int kLongConnectionIdLength = 16; + +// Run all tests with the cross products of all versions. +struct TestParams { + TestParams(const ParsedQuicVersion& version, QuicTag congestion_control_tag, + QuicEventLoopFactory* event_loop, + int override_server_connection_id_length) + : version(version), + congestion_control_tag(congestion_control_tag), + event_loop(event_loop), + override_server_connection_id_length( + override_server_connection_id_length) {} + + friend std::ostream& operator<<(std::ostream& os, const TestParams& p) { + os << "{ version: " << ParsedQuicVersionToString(p.version); + os << " congestion_control_tag: " + << QuicTagToString(p.congestion_control_tag) + << " event loop: " << p.event_loop->GetName() + << " connection ID length: " << p.override_server_connection_id_length + << " }"; + return os; + } + + ParsedQuicVersion version; + QuicTag congestion_control_tag; + QuicEventLoopFactory* event_loop; + int override_server_connection_id_length; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + std::string rv = absl::StrCat( + ParsedQuicVersionToString(p.version), "_", + QuicTagToString(p.congestion_control_tag), "_", p.event_loop->GetName(), + "_", + std::to_string((p.override_server_connection_id_length == -1) + ? static_cast(kQuicDefaultConnectionIdLength) + : p.override_server_connection_id_length)); + return EscapeTestParamName(rv); +} + +// Constructs various test permutations. +std::vector GetTestParams() { + std::vector params; + std::vector connection_id_lengths{-1, kLongConnectionIdLength}; + for (auto connection_id_length : connection_id_lengths) { + for (const QuicTag congestion_control_tag : {kTBBR, kQBIC, kB2ON}) { + if (!GetQuicReloadableFlag(quic_allow_client_enabled_bbr_v2) && + congestion_control_tag == kB2ON) { + continue; + } + for (const ParsedQuicVersion& version : CurrentSupportedVersions()) { + // TODO(b/232269029): Q050 should be able to handle 0-RTT when the + // initial connection ID is > 8 bytes, but it cannot. This is an + // invasive fix that has no impact as long as gQUIC clients always use + // 8B server connection IDs. If this bug is fixed, we can change + // 'UsesTls' to 'AllowsVariableLengthConnectionIds()' below to test + // qQUIC as well. + if (connection_id_length == -1 || version.UsesTls()) { + params.push_back(TestParams(version, congestion_control_tag, + GetDefaultEventLoop(), + connection_id_length)); + } + } // End of outer version loop. + } // End of congestion_control_tag loop. + } // End of connection_id_length loop. + + // Only run every event loop implementation for one fixed configuration. + for (QuicEventLoopFactory* event_loop : GetAllSupportedEventLoops()) { + if (event_loop == GetDefaultEventLoop()) { + continue; + } + params.push_back( + TestParams(ParsedQuicVersion::RFCv1(), kTBBR, event_loop, -1)); + } + + return params; +} + +void WriteHeadersOnStream(QuicSpdyStream* stream) { + // Since QuicSpdyStream uses QuicHeaderList::empty() to detect too large + // headers, it also fails when receiving empty headers. + Http2HeaderBlock headers; + headers[":authority"] = "test.example.com:443"; + headers[":path"] = "/path"; + headers[":method"] = "GET"; + headers[":scheme"] = "https"; + stream->WriteHeaders(std::move(headers), /* fin = */ false, nullptr); +} + +class ServerDelegate : public PacketDroppingTestWriter::Delegate { + public: + explicit ServerDelegate(QuicDispatcher* dispatcher) + : dispatcher_(dispatcher) {} + ~ServerDelegate() override = default; + void OnCanWrite() override { dispatcher_->OnCanWrite(); } + + private: + QuicDispatcher* dispatcher_; +}; + +class ClientDelegate : public PacketDroppingTestWriter::Delegate { + public: + explicit ClientDelegate(QuicDefaultClient* client) : client_(client) {} + ~ClientDelegate() override = default; + void OnCanWrite() override { + client_->default_network_helper()->OnSocketEvent( + nullptr, client_->GetLatestFD(), kSocketEventWritable); + } + + private: + QuicDefaultClient* client_; +}; + +class EndToEndTest : public QuicTestWithParam { + protected: + EndToEndTest() + : initialized_(false), + connect_to_server_on_initialize_(true), + server_address_(QuicSocketAddress(TestLoopback(), 0)), + server_hostname_("test.example.com"), + fd_(kQuicInvalidSocketFd), + client_writer_(nullptr), + server_writer_(nullptr), + version_(GetParam().version), + client_supported_versions_({version_}), + server_supported_versions_(CurrentSupportedVersions()), + chlo_multiplier_(0), + stream_factory_(nullptr), + override_server_connection_id_length_( + GetParam().override_server_connection_id_length), + expected_server_connection_id_length_(kQuicDefaultConnectionIdLength) { + QUIC_LOG(INFO) << "Using Configuration: " << GetParam(); + + // Use different flow control windows for client/server. + client_config_.SetInitialStreamFlowControlWindowToSend( + 2 * kInitialStreamFlowControlWindowForTest); + client_config_.SetInitialSessionFlowControlWindowToSend( + 2 * kInitialSessionFlowControlWindowForTest); + server_config_.SetInitialStreamFlowControlWindowToSend( + 3 * kInitialStreamFlowControlWindowForTest); + server_config_.SetInitialSessionFlowControlWindowToSend( + 3 * kInitialSessionFlowControlWindowForTest); + + // The default idle timeouts can be too strict when running on a busy + // machine. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(30); + client_config_.set_max_time_before_crypto_handshake(timeout); + client_config_.set_max_idle_time_before_crypto_handshake(timeout); + server_config_.set_max_time_before_crypto_handshake(timeout); + server_config_.set_max_idle_time_before_crypto_handshake(timeout); + + AddToCache("/foo", 200, kFooResponseBody); + AddToCache("/bar", 200, kBarResponseBody); + // Enable fixes for bugs found in tests and prod. + } + + virtual void CreateClientWithWriter() { + client_.reset(CreateQuicClient(client_writer_)); + } + + QuicTestClient* CreateQuicClient(QuicPacketWriterWrapper* writer) { + QuicTestClient* client = new QuicTestClient( + server_address_, server_hostname_, client_config_, + client_supported_versions_, + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique(), + GetParam().event_loop->Create(QuicDefaultClock::Get())); + client->SetUserAgentID(kTestUserAgentId); + client->UseWriter(writer); + if (!pre_shared_key_client_.empty()) { + client->client()->SetPreSharedKey(pre_shared_key_client_); + } + if (override_server_connection_id_length_ >= 0) { + client->UseConnectionIdLength(override_server_connection_id_length_); + } + if (override_client_connection_id_length_ >= 0) { + client->UseClientConnectionIdLength( + override_client_connection_id_length_); + } + client->client()->set_connection_debug_visitor(connection_debug_visitor_); + client->client()->set_enable_web_transport(enable_web_transport_); + client->Connect(); + return client; + } + + void set_smaller_flow_control_receive_window() { + const uint32_t kClientIFCW = 64 * 1024; + const uint32_t kServerIFCW = 1024 * 1024; + set_client_initial_stream_flow_control_receive_window(kClientIFCW); + set_client_initial_session_flow_control_receive_window( + kSessionToStreamRatio * kClientIFCW); + set_server_initial_stream_flow_control_receive_window(kServerIFCW); + set_server_initial_session_flow_control_receive_window( + kSessionToStreamRatio * kServerIFCW); + } + + void set_client_initial_stream_flow_control_receive_window(uint32_t window) { + ASSERT_TRUE(client_ == nullptr); + QUIC_DLOG(INFO) << "Setting client initial stream flow control window: " + << window; + client_config_.SetInitialStreamFlowControlWindowToSend(window); + } + + void set_client_initial_session_flow_control_receive_window(uint32_t window) { + ASSERT_TRUE(client_ == nullptr); + QUIC_DLOG(INFO) << "Setting client initial session flow control window: " + << window; + client_config_.SetInitialSessionFlowControlWindowToSend(window); + } + + void set_client_initial_max_stream_data_incoming_bidirectional( + uint32_t window) { + ASSERT_TRUE(client_ == nullptr); + QUIC_DLOG(INFO) + << "Setting client initial max stream data incoming bidirectional: " + << window; + client_config_.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + window); + } + + void set_server_initial_max_stream_data_outgoing_bidirectional( + uint32_t window) { + ASSERT_TRUE(client_ == nullptr); + QUIC_DLOG(INFO) + << "Setting server initial max stream data outgoing bidirectional: " + << window; + server_config_.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend( + window); + } + + void set_server_initial_stream_flow_control_receive_window(uint32_t window) { + ASSERT_TRUE(server_thread_ == nullptr); + QUIC_DLOG(INFO) << "Setting server initial stream flow control window: " + << window; + server_config_.SetInitialStreamFlowControlWindowToSend(window); + } + + void set_server_initial_session_flow_control_receive_window(uint32_t window) { + ASSERT_TRUE(server_thread_ == nullptr); + QUIC_DLOG(INFO) << "Setting server initial session flow control window: " + << window; + server_config_.SetInitialSessionFlowControlWindowToSend(window); + } + + const QuicSentPacketManager* GetSentPacketManagerFromFirstServerSession() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection == nullptr) { + ADD_FAILURE() << "Missing server connection"; + return nullptr; + } + return &server_connection->sent_packet_manager(); + } + + const QuicSentPacketManager* GetSentPacketManagerFromClientSession() { + QuicConnection* client_connection = GetClientConnection(); + if (client_connection == nullptr) { + ADD_FAILURE() << "Missing client connection"; + return nullptr; + } + return &client_connection->sent_packet_manager(); + } + + QuicSpdyClientSession* GetClientSession() { + if (!client_) { + ADD_FAILURE() << "Missing QuicTestClient"; + return nullptr; + } + if (client_->client() == nullptr) { + ADD_FAILURE() << "Missing MockableQuicClient"; + return nullptr; + } + return client_->client()->client_session(); + } + + QuicConnection* GetClientConnection() { + QuicSpdyClientSession* client_session = GetClientSession(); + if (client_session == nullptr) { + ADD_FAILURE() << "Missing client session"; + return nullptr; + } + return client_session->connection(); + } + + QuicConnection* GetServerConnection() { + QuicSpdySession* server_session = GetServerSession(); + if (server_session == nullptr) { + ADD_FAILURE() << "Missing server session"; + return nullptr; + } + return server_session->connection(); + } + + QuicSpdySession* GetServerSession() { + if (!server_thread_) { + ADD_FAILURE() << "Missing server thread"; + return nullptr; + } + QuicServer* quic_server = server_thread_->server(); + if (quic_server == nullptr) { + ADD_FAILURE() << "Missing server"; + return nullptr; + } + QuicDispatcher* dispatcher = QuicServerPeer::GetDispatcher(quic_server); + if (dispatcher == nullptr) { + ADD_FAILURE() << "Missing dispatcher"; + return nullptr; + } + if (dispatcher->NumSessions() == 0) { + ADD_FAILURE() << "Empty dispatcher session map"; + return nullptr; + } + EXPECT_EQ(1u, dispatcher->NumSessions()); + return static_cast( + QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher)); + } + + bool Initialize() { + if (enable_web_transport_) { + memory_cache_backend_.set_enable_webtransport(true); + } + + QuicTagVector copt; + server_config_.SetConnectionOptionsToSend(copt); + copt = client_extra_copts_; + + // TODO(nimia): Consider setting the congestion control algorithm for the + // client as well according to the test parameter. + copt.push_back(GetParam().congestion_control_tag); + copt.push_back(k2PTO); + if (version_.HasIetfQuicFrames()) { + copt.push_back(kILD0); + } + copt.push_back(kPLE1); + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + copt.push_back(kRVCM); + } + client_config_.SetConnectionOptionsToSend(copt); + + // Start the server first, because CreateQuicClient() attempts + // to connect to the server. + StartServer(); + + if (use_preferred_address_) { + SetQuicReloadableFlag(quic_use_received_client_addresses_cache, true); + // At this point, the server has an ephemeral port to listen on. Restart + // the server with the preferred address. + StopServer(); + // server_address_ now contains the random listening port. + server_preferred_address_ = + QuicSocketAddress(TestLoopback(2), server_address_.port()); + if (server_preferred_address_ == server_address_) { + ADD_FAILURE() << "Preferred address and server address are the same " + << server_address_; + return false; + } + // Send server preferred address and let server listen on Any. + if (server_preferred_address_.host().IsIPv4()) { + server_listening_address_ = + QuicSocketAddress(QuicIpAddress::Any4(), server_address_.port()); + server_config_.SetIPv4AlternateServerAddressToSend( + server_preferred_address_); + } else { + server_listening_address_ = + QuicSocketAddress(QuicIpAddress::Any6(), server_address_.port()); + server_config_.SetIPv6AlternateServerAddressToSend( + server_preferred_address_); + } + // Server restarts. + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + client_config_.SetConnectionOptionsToSend(QuicTagVector{kRVCM, kSPAD}); + } + + if (!connect_to_server_on_initialize_) { + initialized_ = true; + return true; + } + + CreateClientWithWriter(); + if (!client_) { + ADD_FAILURE() << "Missing QuicTestClient"; + return false; + } + MockableQuicClient* client = client_->client(); + if (client == nullptr) { + ADD_FAILURE() << "Missing MockableQuicClient"; + return false; + } + if (client_writer_ != nullptr) { + QuicConnection* client_connection = GetClientConnection(); + if (client_connection == nullptr) { + ADD_FAILURE() << "Missing client connection"; + return false; + } + client_writer_->Initialize( + QuicConnectionPeer::GetHelper(client_connection), + QuicConnectionPeer::GetAlarmFactory(client_connection), + std::make_unique(client)); + } + initialized_ = true; + return client->connected(); + } + + void SetUp() override { + // The ownership of these gets transferred to the QuicPacketWriterWrapper + // when Initialize() is executed. + client_writer_ = new PacketDroppingTestWriter(); + server_writer_ = new PacketDroppingTestWriter(); + } + + void TearDown() override { + EXPECT_TRUE(initialized_) << "You must call Initialize() in every test " + << "case. Otherwise, your test will leak memory."; + QuicConnection* client_connection = GetClientConnection(); + if (client_connection != nullptr) { + client_connection->set_debug_visitor(nullptr); + } else { + ADD_FAILURE() << "Missing client connection"; + } + StopServer(/*will_restart=*/false); + if (fd_ != kQuicInvalidSocketFd) { + // Every test should follow StopServer(true) with StartServer(), so we + // should never get here. + QuicUdpSocketApi socket_api; + socket_api.Destroy(fd_); + fd_ = kQuicInvalidSocketFd; + } + } + + void StartServer() { + if (fd_ != kQuicInvalidSocketFd) { + // We previously called StopServer to reserve the ephemeral port. Close + // the socket so that it's available below. + QuicUdpSocketApi socket_api; + socket_api.Destroy(fd_); + fd_ = kQuicInvalidSocketFd; + } + auto test_server = std::make_unique( + crypto_test_utils::ProofSourceForTesting(), server_config_, + server_supported_versions_, &memory_cache_backend_, + expected_server_connection_id_length_); + test_server->SetEventLoopFactory(GetParam().event_loop); + const QuicSocketAddress server_listening_address = + server_listening_address_.has_value() ? *server_listening_address_ + : server_address_; + server_thread_ = std::make_unique(std::move(test_server), + server_listening_address); + if (chlo_multiplier_ != 0) { + server_thread_->server()->SetChloMultiplier(chlo_multiplier_); + } + if (!pre_shared_key_server_.empty()) { + server_thread_->server()->SetPreSharedKey(pre_shared_key_server_); + } + server_thread_->Initialize(); + server_address_ = + QuicSocketAddress(server_address_.host(), server_thread_->GetPort()); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + ASSERT_TRUE(dispatcher != nullptr); + QuicDispatcherPeer::UseWriter(dispatcher, server_writer_); + + server_writer_->Initialize(QuicDispatcherPeer::GetHelper(dispatcher), + QuicDispatcherPeer::GetAlarmFactory(dispatcher), + std::make_unique(dispatcher)); + if (stream_factory_ != nullptr) { + static_cast(server_thread_->server()) + ->SetSpdyStreamFactory(stream_factory_); + } + + server_thread_->Start(); + } + + void StopServer(bool will_restart = true) { + if (server_thread_) { + server_thread_->Quit(); + server_thread_->Join(); + } + if (will_restart) { + // server_address_ now contains the random listening port. Since many + // tests will attempt to re-bind the socket, claim it so that the kernel + // doesn't give away the ephemeral port. + QuicUdpSocketApi socket_api; + fd_ = socket_api.Create( + server_address_.host().AddressFamilyToInt(), + /*receive_buffer_size =*/kDefaultSocketReceiveBuffer, + /*send_buffer_size =*/kDefaultSocketReceiveBuffer); + if (fd_ == kQuicInvalidSocketFd) { + QUIC_LOG(ERROR) << "CreateSocket() failed: " << strerror(errno); + return; + } + int rc = socket_api.Bind(fd_, server_address_); + if (rc < 0) { + QUIC_LOG(ERROR) << "Bind failed: " << strerror(errno); + return; + } + } + } + + void AddToCache(absl::string_view path, int response_code, + absl::string_view body) { + memory_cache_backend_.AddSimpleResponse(server_hostname_, path, + response_code, body); + } + + void SetPacketLossPercentage(int32_t loss) { + client_writer_->set_fake_packet_loss_percentage(loss); + server_writer_->set_fake_packet_loss_percentage(loss); + } + + void SetPacketSendDelay(QuicTime::Delta delay) { + client_writer_->set_fake_packet_delay(delay); + server_writer_->set_fake_packet_delay(delay); + } + + void SetReorderPercentage(int32_t reorder) { + client_writer_->set_fake_reorder_percentage(reorder); + server_writer_->set_fake_reorder_percentage(reorder); + } + + // Verifies that the client and server connections were both free of packets + // being discarded, based on connection stats. + // Calls server_thread_ Pause() and Resume(), which may only be called once + // per test. + void VerifyCleanConnection(bool had_packet_loss) { + QuicConnection* client_connection = GetClientConnection(); + if (client_connection == nullptr) { + ADD_FAILURE() << "Missing client connection"; + return; + } + QuicConnectionStats client_stats = client_connection->GetStats(); + // TODO(ianswett): Determine why this becomes even more flaky with BBR + // enabled. b/62141144 + if (!had_packet_loss && !GetQuicReloadableFlag(quic_default_to_bbr)) { + EXPECT_EQ(0u, client_stats.packets_lost); + } + EXPECT_EQ(0u, client_stats.packets_discarded); + // When client starts with an unsupported version, the version negotiation + // packet sent by server for the old connection (respond for the connection + // close packet) will be dropped by the client. + if (!ServerSendsVersionNegotiation()) { + EXPECT_EQ(0u, client_stats.packets_dropped); + } + if (!version_.UsesTls()) { + // Only enforce this for QUIC crypto because accounting of number of + // packets received, processed gets complicated with packets coalescing + // and key dropping. For example, a received undecryptable coalesced + // packet can be processed later and each sub-packet increases + // packets_processed. + EXPECT_EQ(client_stats.packets_received, client_stats.packets_processed); + } + + if (!server_thread_) { + ADD_FAILURE() << "Missing server thread"; + return; + } + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + if (server_session != nullptr) { + QuicConnection* server_connection = server_session->connection(); + if (server_connection != nullptr) { + QuicConnectionStats server_stats = server_connection->GetStats(); + if (!had_packet_loss) { + EXPECT_EQ(0u, server_stats.packets_lost); + } + EXPECT_EQ(0u, server_stats.packets_discarded); + } else { + ADD_FAILURE() << "Missing server connection"; + } + } else { + ADD_FAILURE() << "Missing server session"; + } + // TODO(ianswett): Restore the check for packets_dropped equals 0. + // The expect for packets received is equal to packets processed fails + // due to version negotiation packets. + server_thread_->Resume(); + } + + // Returns true when client starts with an unsupported version, and client + // closes connection when version negotiation is received. + bool ServerSendsVersionNegotiation() { + return client_supported_versions_[0] != version_; + } + + bool SupportsIetfQuicWithTls(ParsedQuicVersion version) { + return version.HasIetfInvariantHeader() && + version.handshake_protocol == PROTOCOL_TLS1_3; + } + + static void ExpectFlowControlsSynced(QuicSession* client, + QuicSession* server) { + EXPECT_EQ( + QuicFlowControllerPeer::SendWindowSize(client->flow_controller()), + QuicFlowControllerPeer::ReceiveWindowSize(server->flow_controller())); + EXPECT_EQ( + QuicFlowControllerPeer::ReceiveWindowSize(client->flow_controller()), + QuicFlowControllerPeer::SendWindowSize(server->flow_controller())); + } + + static void ExpectFlowControlsSynced(QuicStream* client, QuicStream* server) { + EXPECT_EQ(QuicStreamPeer::SendWindowSize(client), + QuicStreamPeer::ReceiveWindowSize(server)); + EXPECT_EQ(QuicStreamPeer::ReceiveWindowSize(client), + QuicStreamPeer::SendWindowSize(server)); + } + + // Must be called before Initialize to have effect. + void SetSpdyStreamFactory(QuicTestServer::StreamFactory* factory) { + stream_factory_ = factory; + } + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return GetNthClientInitiatedBidirectionalStreamId( + version_.transport_version, n); + } + + QuicStreamId GetNthServerInitiatedBidirectionalId(int n) { + return GetNthServerInitiatedBidirectionalStreamId( + version_.transport_version, n); + } + + bool CheckResponseHeaders(QuicTestClient* client, + const std::string& expected_status) { + const spdy::Http2HeaderBlock* response_headers = client->response_headers(); + auto it = response_headers->find(":status"); + if (it == response_headers->end()) { + ADD_FAILURE() << "Did not find :status header in response"; + return false; + } + if (it->second != expected_status) { + ADD_FAILURE() << "Got bad :status response: \"" << it->second << "\""; + return false; + } + return true; + } + + bool CheckResponseHeaders(QuicTestClient* client) { + return CheckResponseHeaders(client, "200"); + } + + bool CheckResponseHeaders(const std::string& expected_status) { + return CheckResponseHeaders(client_.get(), expected_status); + } + + bool CheckResponseHeaders() { return CheckResponseHeaders(client_.get()); } + + bool CheckResponse(QuicTestClient* client, + const std::string& received_response, + const std::string& expected_response) { + EXPECT_THAT(client_->stream_error(), IsQuicStreamNoError()); + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); + + if (received_response.empty() && !expected_response.empty()) { + ADD_FAILURE() << "Failed to get any response for request"; + return false; + } + if (received_response != expected_response) { + ADD_FAILURE() << "Got wrong response: \"" << received_response << "\""; + return false; + } + return CheckResponseHeaders(client); + } + + bool SendSynchronousRequestAndCheckResponse( + QuicTestClient* client, const std::string& request, + const std::string& expected_response) { + std::string received_response = client->SendSynchronousRequest(request); + return CheckResponse(client, received_response, expected_response); + } + + bool SendSynchronousRequestAndCheckResponse( + const std::string& request, const std::string& expected_response) { + return SendSynchronousRequestAndCheckResponse(client_.get(), request, + expected_response); + } + + bool SendSynchronousFooRequestAndCheckResponse(QuicTestClient* client) { + return SendSynchronousRequestAndCheckResponse(client, "/foo", + kFooResponseBody); + } + + bool SendSynchronousFooRequestAndCheckResponse() { + return SendSynchronousFooRequestAndCheckResponse(client_.get()); + } + + bool SendSynchronousBarRequestAndCheckResponse() { + std::string received_response = client_->SendSynchronousRequest("/bar"); + return CheckResponse(client_.get(), received_response, kBarResponseBody); + } + + bool WaitForFooResponseAndCheckIt(QuicTestClient* client) { + client->WaitForResponse(); + std::string received_response = client->response_body(); + return CheckResponse(client_.get(), received_response, kFooResponseBody); + } + + bool WaitForFooResponseAndCheckIt() { + return WaitForFooResponseAndCheckIt(client_.get()); + } + + WebTransportHttp3* CreateWebTransportSession( + const std::string& path, bool wait_for_server_response, + QuicSpdyStream** connect_stream_out = nullptr) { + // Wait until we receive the settings from the server indicating + // WebTransport support. + client_->WaitUntil( + 2000, [this]() { return GetClientSession()->SupportsWebTransport(); }); + if (!GetClientSession()->SupportsWebTransport()) { + return nullptr; + } + + spdy::Http2HeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":path"] = path; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "webtransport"; + + client_->SendMessage(headers, "", /*fin=*/false); + QuicSpdyStream* stream = client_->latest_created_stream(); + if (stream->web_transport() == nullptr) { + return nullptr; + } + WebTransportSessionId id = client_->latest_created_stream()->id(); + QuicSpdySession* client_session = GetClientSession(); + if (client_session->GetWebTransportSession(id) == nullptr) { + return nullptr; + } + WebTransportHttp3* session = client_session->GetWebTransportSession(id); + if (wait_for_server_response) { + client_->WaitUntil(-1, + [stream]() { return stream->headers_decompressed(); }); + EXPECT_TRUE(session->ready()); + } + if (connect_stream_out != nullptr) { + *connect_stream_out = stream; + } + return session; + } + + NiceMock& SetupWebTransportVisitor( + WebTransportHttp3* session) { + auto visitor_owned = + std::make_unique>(); + NiceMock& visitor = *visitor_owned; + session->SetVisitor(std::move(visitor_owned)); + return visitor; + } + + std::string ReadDataFromWebTransportStreamUntilFin( + WebTransportStream* stream, + MockWebTransportStreamVisitor* visitor = nullptr) { + QuicStreamId id = stream->GetStreamId(); + std::string buffer; + + // Try reading data if immediately available. + WebTransportStream::ReadResult result = stream->Read(&buffer); + if (result.fin) { + return buffer; + } + + while (true) { + bool can_read = false; + if (visitor == nullptr) { + auto visitor_owned = std::make_unique(); + visitor = visitor_owned.get(); + stream->SetVisitor(std::move(visitor_owned)); + } + EXPECT_CALL(*visitor, OnCanRead()) + .WillRepeatedly(Assign(&can_read, true)); + client_->WaitUntil(5000 /*ms*/, [&can_read]() { return can_read; }); + if (!can_read) { + ADD_FAILURE() << "Waiting for readable data on stream " << id + << " timed out"; + return buffer; + } + if (GetClientSession()->GetOrCreateSpdyDataStream(id) == nullptr) { + ADD_FAILURE() << "Stream " << id + << " was deleted while waiting for incoming data"; + return buffer; + } + + result = stream->Read(&buffer); + if (result.fin) { + return buffer; + } + if (result.bytes_read == 0) { + ADD_FAILURE() << "No progress made while reading from stream " + << stream->GetStreamId(); + return buffer; + } + } + } + + void ReadAllIncomingWebTransportUnidirectionalStreams( + WebTransportSession* session) { + while (true) { + WebTransportStream* received_stream = + session->AcceptIncomingUnidirectionalStream(); + if (received_stream == nullptr) { + break; + } + received_webtransport_unidirectional_streams_.push_back( + ReadDataFromWebTransportStreamUntilFin(received_stream)); + } + } + + void WaitForNewConnectionIds() { + // Wait until a new server CID is available for another migration. + const auto* client_connection = GetClientConnection(); + while (!QuicConnectionPeer::HasUnusedPeerIssuedConnectionId( + client_connection) || + (!client_connection->client_connection_id().IsEmpty() && + !QuicConnectionPeer::HasSelfIssuedConnectionIdToConsume( + client_connection))) { + client_->client()->WaitForEvents(); + } + } + + quiche::test::ScopedEnvironmentForThreads environment_; + bool initialized_; + // If true, the Initialize() function will create |client_| and starts to + // connect to the server. + // Default is true. + bool connect_to_server_on_initialize_; + QuicSocketAddress server_address_; + absl::optional server_listening_address_; + std::string server_hostname_; + QuicTestBackend memory_cache_backend_; + std::unique_ptr server_thread_; + // This socket keeps the ephemeral port reserved so that the kernel doesn't + // give it away while the server is shut down. + QuicUdpSocketFd fd_; + std::unique_ptr client_; + QuicConnectionDebugVisitor* connection_debug_visitor_ = nullptr; + PacketDroppingTestWriter* client_writer_; + PacketDroppingTestWriter* server_writer_; + QuicConfig client_config_; + QuicConfig server_config_; + ParsedQuicVersion version_; + ParsedQuicVersionVector client_supported_versions_; + ParsedQuicVersionVector server_supported_versions_; + QuicTagVector client_extra_copts_; + size_t chlo_multiplier_; + QuicTestServer::StreamFactory* stream_factory_; + std::string pre_shared_key_client_; + std::string pre_shared_key_server_; + int override_server_connection_id_length_; + int override_client_connection_id_length_ = -1; + uint8_t expected_server_connection_id_length_; + bool enable_web_transport_ = false; + std::vector received_webtransport_unidirectional_streams_; + bool use_preferred_address_ = false; + QuicSocketAddress server_preferred_address_; +}; + +// Run all end to end tests with all supported versions. +INSTANTIATE_TEST_SUITE_P(EndToEndTests, EndToEndTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(EndToEndTest, HandshakeSuccessful) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(server_thread_); + server_thread_->WaitForCryptoHandshakeConfirmed(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicCryptoStream* client_crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(client_session); + ASSERT_TRUE(client_crypto_stream); + QuicStreamSequencer* client_sequencer = + QuicStreamPeer::sequencer(client_crypto_stream); + ASSERT_TRUE(client_sequencer); + EXPECT_FALSE( + QuicStreamSequencerPeer::IsUnderlyingBufferAllocated(client_sequencer)); + + // We've had bugs in the past where the connections could end up on the wrong + // version. This was never diagnosed but could have been due to in-connection + // version negotiation back when that existed. At this point in time, our test + // setup ensures that connections here always use |version_|, but we add this + // sanity check out of paranoia to catch a regression of this type. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(client_connection->version(), version_); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + QuicConnection* server_connection = nullptr; + QuicCryptoStream* server_crypto_stream = nullptr; + QuicStreamSequencer* server_sequencer = nullptr; + if (server_session != nullptr) { + server_connection = server_session->connection(); + server_crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(server_session); + } else { + ADD_FAILURE() << "Missing server session"; + } + if (server_crypto_stream != nullptr) { + server_sequencer = QuicStreamPeer::sequencer(server_crypto_stream); + } else { + ADD_FAILURE() << "Missing server crypto stream"; + } + if (server_sequencer != nullptr) { + EXPECT_FALSE( + QuicStreamSequencerPeer::IsUnderlyingBufferAllocated(server_sequencer)); + } else { + ADD_FAILURE() << "Missing server sequencer"; + } + if (server_connection != nullptr) { + EXPECT_EQ(server_connection->version(), version_); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ExportKeyingMaterial) { + ASSERT_TRUE(Initialize()); + if (!version_.UsesTls()) { + return; + } + const char* kExportLabel = "label"; + const int kExportLen = 30; + std::string client_keying_material_export, server_keying_material_export; + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(server_thread_); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + QuicCryptoStream* server_crypto_stream = nullptr; + if (server_session != nullptr) { + server_crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(server_session); + } else { + ADD_FAILURE() << "Missing server session"; + } + if (server_crypto_stream != nullptr) { + ASSERT_TRUE(server_crypto_stream->ExportKeyingMaterial( + kExportLabel, /*context=*/"", kExportLen, + &server_keying_material_export)); + + } else { + ADD_FAILURE() << "Missing server crypto stream"; + } + server_thread_->Resume(); + + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicCryptoStream* client_crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(client_session); + ASSERT_TRUE(client_crypto_stream); + ASSERT_TRUE(client_crypto_stream->ExportKeyingMaterial( + kExportLabel, /*context=*/"", kExportLen, + &client_keying_material_export)); + ASSERT_EQ(client_keying_material_export.size(), + static_cast(kExportLen)); + EXPECT_EQ(client_keying_material_export, server_keying_material_export); +} + +TEST_P(EndToEndTest, SimpleRequestResponse) { + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + if (version_.UsesHttp3()) { + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(QuicSpdySessionPeer::GetSendControlStream(client_session)); + EXPECT_TRUE(QuicSpdySessionPeer::GetReceiveControlStream(client_session)); + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + if (server_session != nullptr) { + EXPECT_TRUE(QuicSpdySessionPeer::GetSendControlStream(server_session)); + EXPECT_TRUE(QuicSpdySessionPeer::GetReceiveControlStream(server_session)); + } else { + ADD_FAILURE() << "Missing server session"; + } + server_thread_->Resume(); + } + QuicConnectionStats client_stats = GetClientConnection()->GetStats(); + EXPECT_TRUE(client_stats.handshake_completion_time.IsInitialized()); +} + +TEST_P(EndToEndTest, HandshakeConfirmed) { + ASSERT_TRUE(Initialize()); + if (!version_.UsesTls()) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + // Verify handshake state. + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_EQ(HANDSHAKE_CONFIRMED, client_session->GetHandshakeState()); + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + if (server_session != nullptr) { + EXPECT_EQ(HANDSHAKE_CONFIRMED, server_session->GetHandshakeState()); + } else { + ADD_FAILURE() << "Missing server session"; + } + server_thread_->Resume(); + client_->Disconnect(); +} + +TEST_P(EndToEndTest, SendAndReceiveCoalescedPackets) { + ASSERT_TRUE(Initialize()); + if (!version_.CanSendCoalescedPackets()) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + // Verify client successfully processes coalesced packets. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionStats client_stats = client_connection->GetStats(); + EXPECT_LT(0u, client_stats.num_coalesced_packets_received); + EXPECT_EQ(client_stats.num_coalesced_packets_processed, + client_stats.num_coalesced_packets_received); + // TODO(fayang): verify server successfully processes coalesced packets. +} + +// Simple transaction, but set a non-default ack delay at the client +// and ensure it gets to the server. +TEST_P(EndToEndTest, SimpleRequestResponseWithAckDelayChange) { + // Force the ACK delay to be something other than the default. + constexpr uint32_t kClientMaxAckDelay = kDefaultDelayedAckTimeMs + 100u; + client_config_.SetMaxAckDelayToSendMs(kClientMaxAckDelay); + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + server_thread_->Pause(); + const QuicSentPacketManager* server_sent_packet_manager = + GetSentPacketManagerFromFirstServerSession(); + if (server_sent_packet_manager != nullptr) { + EXPECT_EQ( + kClientMaxAckDelay, + server_sent_packet_manager->peer_max_ack_delay().ToMilliseconds()); + } else { + ADD_FAILURE() << "Missing server sent packet manager"; + } + server_thread_->Resume(); +} + +// Simple transaction, but set a non-default ack exponent at the client +// and ensure it gets to the server. +TEST_P(EndToEndTest, SimpleRequestResponseWithAckExponentChange) { + const uint32_t kClientAckDelayExponent = 19; + EXPECT_NE(kClientAckDelayExponent, kDefaultAckDelayExponent); + // Force the ACK exponent to be something other than the default. + // Note that it is sent only with QUIC+TLS. + client_config_.SetAckDelayExponentToSend(kClientAckDelayExponent); + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + if (version_.UsesTls()) { + // Should be only sent with QUIC+TLS. + EXPECT_EQ(kClientAckDelayExponent, + server_connection->framer().peer_ack_delay_exponent()); + } else { + // No change for QUIC_CRYPTO. + EXPECT_EQ(kDefaultAckDelayExponent, + server_connection->framer().peer_ack_delay_exponent()); + } + // No change, regardless of version. + EXPECT_EQ(kDefaultAckDelayExponent, + server_connection->framer().local_ack_delay_exponent()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, SimpleRequestResponseForcedVersionNegotiation) { + client_supported_versions_.insert(client_supported_versions_.begin(), + QuicVersionReservedForNegotiation()); + NiceMock visitor; + connection_debug_visitor_ = &visitor; + EXPECT_CALL(visitor, OnVersionNegotiationPacket(_)).Times(1); + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(ServerSendsVersionNegotiation()); + + SendSynchronousFooRequestAndCheckResponse(); + + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); +} + +TEST_P(EndToEndTest, ForcedVersionNegotiation) { + client_supported_versions_.insert(client_supported_versions_.begin(), + QuicVersionReservedForNegotiation()); + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(ServerSendsVersionNegotiation()); + + SendSynchronousFooRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, SimpleRequestResponseZeroConnectionID) { + if (!version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ > -1) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 0; + expected_server_connection_id_length_ = 0; + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(client_connection->connection_id(), + QuicUtils::CreateZeroConnectionId(version_.transport_version)); +} + +TEST_P(EndToEndTest, ZeroConnectionID) { + if (!version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ > -1) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 0; + expected_server_connection_id_length_ = 0; + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(client_connection->connection_id(), + QuicUtils::CreateZeroConnectionId(version_.transport_version)); +} + +TEST_P(EndToEndTest, BadConnectionIdLength) { + if (!version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ > -1) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 9; + ASSERT_TRUE(Initialize()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client() + ->client_session() + ->connection() + ->connection_id() + .length()); +} + +TEST_P(EndToEndTest, ClientConnectionId) { + if (!version_.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; + ASSERT_TRUE(Initialize()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(override_client_connection_id_length_, client_->client() + ->client_session() + ->connection() + ->client_connection_id() + .length()); +} + +TEST_P(EndToEndTest, ForcedVersionNegotiationAndClientConnectionId) { + if (!version_.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + client_supported_versions_.insert(client_supported_versions_.begin(), + QuicVersionReservedForNegotiation()); + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(ServerSendsVersionNegotiation()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(override_client_connection_id_length_, client_->client() + ->client_session() + ->connection() + ->client_connection_id() + .length()); +} + +TEST_P(EndToEndTest, ForcedVersionNegotiationAndBadConnectionIdLength) { + if (!version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ > -1) { + ASSERT_TRUE(Initialize()); + return; + } + client_supported_versions_.insert(client_supported_versions_.begin(), + QuicVersionReservedForNegotiation()); + override_server_connection_id_length_ = 9; + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(ServerSendsVersionNegotiation()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client() + ->client_session() + ->connection() + ->connection_id() + .length()); +} + +// Forced Version Negotiation with a client connection ID and a long +// connection ID. +TEST_P(EndToEndTest, ForcedVersNegoAndClientCIDAndLongCID) { + if (!version_.SupportsClientConnectionIds() || + !version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ != kLongConnectionIdLength) { + ASSERT_TRUE(Initialize()); + return; + } + client_supported_versions_.insert(client_supported_versions_.begin(), + QuicVersionReservedForNegotiation()); + override_client_connection_id_length_ = 18; + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(ServerSendsVersionNegotiation()); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client() + ->client_session() + ->connection() + ->connection_id() + .length()); + EXPECT_EQ(override_client_connection_id_length_, client_->client() + ->client_session() + ->connection() + ->client_connection_id() + .length()); +} + +TEST_P(EndToEndTest, MixGoodAndBadConnectionIdLengths) { + if (!version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ > -1) { + ASSERT_TRUE(Initialize()); + return; + } + + // Start client_ which will use a bad connection ID length. + override_server_connection_id_length_ = 9; + ASSERT_TRUE(Initialize()); + override_server_connection_id_length_ = -1; + + // Start client2 which will use a good connection ID length. + std::unique_ptr client2(CreateQuicClient(nullptr)); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = "3"; + client2->SendMessage(headers, "", /*fin=*/false); + client2->SendData("eep", true); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(kQuicDefaultConnectionIdLength, client_->client() + ->client_session() + ->connection() + ->connection_id() + .length()); + + WaitForFooResponseAndCheckIt(client2.get()); + EXPECT_EQ(kQuicDefaultConnectionIdLength, client2->client() + ->client_session() + ->connection() + ->connection_id() + .length()); +} + +TEST_P(EndToEndTest, SimpleRequestResponseWithIetfDraftSupport) { + if (!version_.HasIetfQuicFrames()) { + ASSERT_TRUE(Initialize()); + return; + } + QuicVersionInitializeSupportForIetfDraft(); + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, SimpleRequestResponseWithLargeReject) { + chlo_multiplier_ = 1; + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + if (version_.UsesTls()) { + // REJ messages are a QUIC crypto feature, so TLS always returns false. + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + } else { + EXPECT_TRUE(client_->client()->ReceivedInchoateReject()); + } +} + +TEST_P(EndToEndTest, SimpleRequestResponsev6) { + server_address_ = + QuicSocketAddress(QuicIpAddress::Loopback6(), server_address_.port()); + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, + ClientDoesNotAllowServerDataOnServerInitiatedBidirectionalStreams) { + set_client_initial_max_stream_data_incoming_bidirectional(0); + ASSERT_TRUE(Initialize()); + SendSynchronousFooRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, + ServerDoesNotAllowClientDataOnServerInitiatedBidirectionalStreams) { + set_server_initial_max_stream_data_outgoing_bidirectional(0); + ASSERT_TRUE(Initialize()); + SendSynchronousFooRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, + BothEndpointsDisallowDataOnServerInitiatedBidirectionalStreams) { + set_client_initial_max_stream_data_incoming_bidirectional(0); + set_server_initial_max_stream_data_outgoing_bidirectional(0); + ASSERT_TRUE(Initialize()); + SendSynchronousFooRequestAndCheckResponse(); +} + +// Regression test for a bug where we would always fail to decrypt the first +// initial packet. Undecryptable packets can be seen after the handshake +// is complete due to dropping the initial keys at that point, so we only test +// for undecryptable packets before then. +TEST_P(EndToEndTest, NoUndecryptablePacketsBeforeHandshakeComplete) { + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionStats client_stats = client_connection->GetStats(); + EXPECT_EQ( + 0u, + client_stats.undecryptable_packets_received_before_handshake_complete); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ( + 0u, + server_stats.undecryptable_packets_received_before_handshake_complete); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, SeparateFinPacket) { + ASSERT_TRUE(Initialize()); + + // Send a request in two parts: the request and then an empty packet with FIN. + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + client_->SendMessage(headers, "", /*fin=*/false); + client_->SendData("", true); + WaitForFooResponseAndCheckIt(); + + // Now do the same thing but with a content length. + headers["content-length"] = "3"; + client_->SendMessage(headers, "", /*fin=*/false); + client_->SendData("foo", true); + WaitForFooResponseAndCheckIt(); +} + +TEST_P(EndToEndTest, MultipleRequestResponse) { + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + SendSynchronousBarRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, MultipleRequestResponseZeroConnectionID) { + if (!version_.AllowsVariableLengthConnectionIds() || + override_server_connection_id_length_ > -1) { + ASSERT_TRUE(Initialize()); + return; + } + override_server_connection_id_length_ = 0; + expected_server_connection_id_length_ = 0; + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + SendSynchronousBarRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, MultipleStreams) { + // Verifies quic_test_client can track responses of all active streams. + ASSERT_TRUE(Initialize()); + + const int kNumRequests = 10; + + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = "3"; + + for (int i = 0; i < kNumRequests; ++i) { + client_->SendMessage(headers, "bar", /*fin=*/true); + } + + while (kNumRequests > client_->num_responses()) { + client_->ClearPerRequestState(); + ASSERT_TRUE(WaitForFooResponseAndCheckIt()); + } +} + +TEST_P(EndToEndTest, MultipleClients) { + ASSERT_TRUE(Initialize()); + std::unique_ptr client2(CreateQuicClient(nullptr)); + + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = "3"; + + client_->SendMessage(headers, "", /*fin=*/false); + client2->SendMessage(headers, "", /*fin=*/false); + + client_->SendData("bar", true); + WaitForFooResponseAndCheckIt(); + + client2->SendData("eep", true); + WaitForFooResponseAndCheckIt(client2.get()); +} + +TEST_P(EndToEndTest, RequestOverMultiplePackets) { + // Send a large enough request to guarantee fragmentation. + std::string huge_request = + "/some/path?query=" + std::string(kMaxOutgoingPacketSize, '.'); + AddToCache(huge_request, 200, kBarResponseBody); + + ASSERT_TRUE(Initialize()); + + SendSynchronousRequestAndCheckResponse(huge_request, kBarResponseBody); +} + +TEST_P(EndToEndTest, MultiplePacketsRandomOrder) { + // Send a large enough request to guarantee fragmentation. + std::string huge_request = + "/some/path?query=" + std::string(kMaxOutgoingPacketSize, '.'); + AddToCache(huge_request, 200, kBarResponseBody); + + ASSERT_TRUE(Initialize()); + SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); + SetReorderPercentage(50); + + SendSynchronousRequestAndCheckResponse(huge_request, kBarResponseBody); +} + +TEST_P(EndToEndTest, PostMissingBytes) { + ASSERT_TRUE(Initialize()); + + // Add a content length header with no body. + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = "3"; + + // This should be detected as stream fin without complete request, + // triggering an error response. + client_->SendCustomSynchronousRequest(headers, ""); + EXPECT_EQ(QuicSimpleServerStream::kErrorResponseBody, + client_->response_body()); + CheckResponseHeaders("500"); +} + +TEST_P(EndToEndTest, LargePostNoPacketLoss) { + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // 1 MB body. + std::string body(1024 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + // TODO(ianswett): There should not be packet loss in this test, but on some + // platforms the receive buffer overflows. + VerifyCleanConnection(true); +} + +// Marked as slow since this adds a real-clock one second of delay. +TEST_P(EndToEndTest, QUICHE_SLOW_TEST(LargePostNoPacketLoss1sRTT)) { + ASSERT_TRUE(Initialize()); + SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(1000)); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // 100 KB body. + std::string body(100 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + VerifyCleanConnection(false); +} + +TEST_P(EndToEndTest, LargePostWithPacketLoss) { + // Connect with lower fake packet loss than we'd like to test. + // Until b/10126687 is fixed, losing handshake packets is pretty + // brutal. + // Disable blackhole detection as this test is testing loss recovery. + client_extra_copts_.push_back(kNBHD); + SetPacketLossPercentage(5); + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + SetPacketLossPercentage(30); + + // 10 KB body. + std::string body(1024 * 10, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + if (override_server_connection_id_length_ == -1) { + // If the client sends a longer connection ID, we can end up with dropped + // packets. The packets_dropped counter increments whenever a packet arrives + // with a new server connection ID that is not INITIAL, RETRY, or 1-RTT. + // With packet losses, we could easily lose a server INITIAL and have the + // first observed server packet be HANDSHAKE. + VerifyCleanConnection(true); + } +} + +// Regression test for b/80090281. +TEST_P(EndToEndTest, LargePostWithPacketLossAndAlwaysBundleWindowUpdates) { + // Disable blackhole detection as this test is testing loss recovery. + client_extra_copts_.push_back(kNBHD); + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Normally server only bundles a retransmittable frame once every other + // kMaxConsecutiveNonRetransmittablePackets ack-only packets. Setting the max + // to 0 to reliably reproduce b/80090281. + server_thread_->Schedule([this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + QuicConnectionPeer:: + SetMaxConsecutiveNumPacketsWithNoRetransmittableFrames( + server_connection, 0); + } else { + ADD_FAILURE() << "Missing server connection"; + } + }); + + SetPacketLossPercentage(30); + + // 10 KB body. + std::string body(1024 * 10, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + VerifyCleanConnection(true); +} + +TEST_P(EndToEndTest, LargePostWithPacketLossAndBlockedSocket) { + // Connect with lower fake packet loss than we'd like to test. Until + // b/10126687 is fixed, losing handshake packets is pretty brutal. + // Disable blackhole detection as this test is testing loss recovery. + client_extra_copts_.push_back(kNBHD); + SetPacketLossPercentage(5); + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + SetPacketLossPercentage(10); + client_writer_->set_fake_blocked_socket_percentage(10); + + // 10 KB body. + std::string body(1024 * 10, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); +} + +TEST_P(EndToEndTest, LargePostNoPacketLossWithDelayAndReordering) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + // Both of these must be called when the writer is not actively used. + SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); + SetReorderPercentage(30); + + // 1 MB body. + std::string body(1024 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); +} + +// TODO(b/214587920): make this test not rely on timeouts. +TEST_P(EndToEndTest, QUICHE_SLOW_TEST(AddressToken)) { + client_config_.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(3)); + client_config_.set_max_idle_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + + client_extra_copts_.push_back(kTRTT); + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + + SendSynchronousFooRequestAndCheckResponse(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + ASSERT_TRUE(client_->client()->connected()); + SendSynchronousFooRequestAndCheckResponse(); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + QuicConnection* server_connection = GetServerConnection(); + if (server_session != nullptr && server_connection != nullptr) { + // Verify address is validated via validating token received in INITIAL + // packet. + EXPECT_FALSE( + server_connection->GetStats().address_validated_via_decrypting_packet); + EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); + + // Verify the server received a cached min_rtt from the token and used it as + // the initial rtt. + const CachedNetworkParameters* server_received_network_params = + static_cast( + server_session->GetCryptoStream()) + ->PreviousCachedNetworkParams(); + + ASSERT_NE(server_received_network_params, nullptr); + // QuicSentPacketManager::SetInitialRtt clamps the initial_rtt to between + // [min_initial_rtt, max_initial_rtt]. + const QuicTime::Delta min_initial_rtt = + QuicTime::Delta::FromMicroseconds(kMinTrustedInitialRoundTripTimeUs); + const QuicTime::Delta max_initial_rtt = + QuicTime::Delta::FromMicroseconds(kMaxInitialRoundTripTimeUs); + const QuicTime::Delta expected_initial_rtt = + std::max(min_initial_rtt, + std::min(max_initial_rtt, + QuicTime::Delta::FromMilliseconds( + server_received_network_params->min_rtt_ms()))); + EXPECT_EQ( + server_connection->sent_packet_manager().GetRttStats()->initial_rtt(), + expected_initial_rtt); + } else { + ADD_FAILURE() << "Missing server connection"; + } + + server_thread_->Resume(); + + client_->Disconnect(); + + // Regression test for b/206087883. + // Mock server crash. + StopServer(); + + // The handshake fails due to idle timeout. + client_->Connect(); + ASSERT_FALSE(client_->client()->WaitForOneRttKeysAvailable()); + client_->WaitForWriteToFlush(); + client_->WaitForResponse(); + ASSERT_FALSE(client_->client()->connected()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_NETWORK_IDLE_TIMEOUT)); + + // Server restarts. + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + // Client re-connect. + client_->Connect(); + ASSERT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + client_->WaitForWriteToFlush(); + client_->WaitForResponse(); + ASSERT_TRUE(client_->client()->connected()); + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + server_thread_->Pause(); + server_session = GetServerSession(); + server_connection = GetServerConnection(); + // Verify address token is only used once. + if (server_session != nullptr && server_connection != nullptr) { + // Verify address is validated via decrypting packet. + EXPECT_TRUE( + server_connection->GetStats().address_validated_via_decrypting_packet); + EXPECT_FALSE(server_connection->GetStats().address_validated_via_token); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); + + client_->Disconnect(); +} + +// Verify that client does not reuse a source address token. +// TODO(b/214587920): make this test not rely on timeouts. +TEST_P(EndToEndTest, QUICHE_SLOW_TEST(AddressTokenNotReusedByClient)) { + client_config_.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(3)); + client_config_.set_max_idle_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + + QuicCryptoClientConfig* client_crypto_config = + client_->client()->crypto_config(); + QuicServerId server_id = client_->client()->server_id(); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(GetClientSession()->EarlyDataAccepted()); + + client_->Disconnect(); + + QuicClientSessionCache* session_cache = static_cast( + client_crypto_config->mutable_session_cache()); + ASSERT_TRUE( + !QuicClientSessionCachePeer::GetToken(session_cache, server_id).empty()); + + // Pause the server thread again to blackhole packets from client. + server_thread_->Pause(); + client_->Connect(); + EXPECT_FALSE(client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_FALSE(client_->client()->connected()); + + // Verify address token gets cleared. + ASSERT_TRUE( + QuicClientSessionCachePeer::GetToken(session_cache, server_id).empty()); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, LargePostZeroRTTFailure) { + // Send a request and then disconnect. This prepares the client to attempt + // a 0-RTT handshake for the next request. + ASSERT_TRUE(Initialize()); + + std::string body(20480, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + client_->Disconnect(); + + // Restart the server so that the 0-RTT handshake will take 1 RTT. + StopServer(); + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + VerifyCleanConnection(false); +} + +// Regression test for b/168020146. +TEST_P(EndToEndTest, MultipleZeroRtt) { + ASSERT_TRUE(Initialize()); + + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + client_->Disconnect(); + + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + client_->Disconnect(); +} + +TEST_P(EndToEndTest, SynchronousRequestZeroRTTFailure) { + // Send a request and then disconnect. This prepares the client to attempt + // a 0-RTT handshake for the next request. + ASSERT_TRUE(Initialize()); + + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + client_->Disconnect(); + + // Restart the server so that the 0-RTT handshake will take 1 RTT. + StopServer(); + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + VerifyCleanConnection(false); +} + +TEST_P(EndToEndTest, LargePostSynchronousRequest) { + // Send a request and then disconnect. This prepares the client to attempt + // a 0-RTT handshake for the next request. + ASSERT_TRUE(Initialize()); + + std::string body(20480, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + client_->Disconnect(); + + // Restart the server so that the 0-RTT handshake will take 1 RTT. + StopServer(); + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + VerifyCleanConnection(false); +} + +TEST_P(EndToEndTest, DisableResumption) { + client_extra_copts_.push_back(kNRES); + ASSERT_TRUE(Initialize()); + if (!version_.UsesTls()) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_EQ(client_session->GetCryptoStream()->EarlyDataReason(), + ssl_early_data_no_session_offered); + client_->Disconnect(); + + SendSynchronousFooRequestAndCheckResponse(); + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + if (GetQuicReloadableFlag(quic_enable_disable_resumption)) { + EXPECT_EQ(client_session->GetCryptoStream()->EarlyDataReason(), + ssl_early_data_session_not_resumed); + } else { + EXPECT_EQ(client_session->GetCryptoStream()->EarlyDataReason(), + ssl_early_data_accepted); + } +} + +// This is a regression test for b/162595387 +TEST_P(EndToEndTest, PostZeroRTTRequestDuringHandshake) { + if (!version_.UsesTls()) { + // This test is TLS specific. + ASSERT_TRUE(Initialize()); + return; + } + // Send a request and then disconnect. This prepares the client to attempt + // a 0-RTT handshake for the next request. + NiceMock visitor; + connection_debug_visitor_ = &visitor; + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + ON_CALL(visitor, OnCryptoFrame(_)) + .WillByDefault(Invoke([this](const QuicCryptoFrame& frame) { + if (frame.level != ENCRYPTION_HANDSHAKE) { + return; + } + // At this point in the handshake, the client should have derived + // ENCRYPTION_ZERO_RTT keys (thus set encryption_established). It + // should also have set ENCRYPTION_HANDSHAKE keys after receiving + // the server's ENCRYPTION_INITIAL flight. + EXPECT_TRUE( + GetClientSession()->GetCryptoStream()->encryption_established()); + EXPECT_TRUE( + GetClientConnection()->framer().HasEncrypterOfEncryptionLevel( + ENCRYPTION_HANDSHAKE)); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + EXPECT_GT( + client_->SendMessage(headers, "", /*fin*/ true, /*flush*/ false), + 0); + })); + client_->Connect(); + ASSERT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + client_->WaitForWriteToFlush(); + client_->WaitForResponse(); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->response_body()); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); +} + +// Regression test for b/166836136. +TEST_P(EndToEndTest, RetransmissionAfterZeroRTTRejectBeforeOneRtt) { + if (!version_.UsesTls()) { + // This test is TLS specific. + ASSERT_TRUE(Initialize()); + return; + } + // Send a request and then disconnect. This prepares the client to attempt + // a 0-RTT handshake for the next request. + NiceMock visitor; + connection_debug_visitor_ = &visitor; + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + client_->Disconnect(); + + // Restart the server so that the 0-RTT handshake will take 1 RTT. + StopServer(); + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + ON_CALL(visitor, OnZeroRttRejected(_)).WillByDefault(Invoke([this]() { + EXPECT_FALSE(GetClientSession()->IsEncryptionEstablished()); + })); + + // The 0-RTT handshake should fail. + client_->Connect(); + ASSERT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + client_->WaitForWriteToFlush(); + client_->WaitForResponse(); + ASSERT_TRUE(client_->client()->connected()); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); +} + +TEST_P(EndToEndTest, RejectWithPacketLoss) { + // In this test, we intentionally drop the first packet from the + // server, which corresponds with the initial REJ response from + // the server. + server_writer_->set_fake_drop_first_n_packets(1); + ASSERT_TRUE(Initialize()); +} + +TEST_P(EndToEndTest, SetInitialReceivedConnectionOptions) { + QuicTagVector initial_received_options; + initial_received_options.push_back(kTBBR); + initial_received_options.push_back(kIW10); + initial_received_options.push_back(kPRST); + EXPECT_TRUE(server_config_.SetInitialReceivedConnectionOptions( + initial_received_options)); + + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + EXPECT_FALSE(server_config_.SetInitialReceivedConnectionOptions( + initial_received_options)); + + // Verify that server's configuration is correct. + server_thread_->Pause(); + EXPECT_TRUE(server_config_.HasReceivedConnectionOptions()); + EXPECT_TRUE( + ContainsQuicTag(server_config_.ReceivedConnectionOptions(), kTBBR)); + EXPECT_TRUE( + ContainsQuicTag(server_config_.ReceivedConnectionOptions(), kIW10)); + EXPECT_TRUE( + ContainsQuicTag(server_config_.ReceivedConnectionOptions(), kPRST)); +} + +TEST_P(EndToEndTest, LargePostSmallBandwidthLargeBuffer) { + ASSERT_TRUE(Initialize()); + SetPacketSendDelay(QuicTime::Delta::FromMicroseconds(1)); + // 256KB per second with a 256KB buffer from server to client. Wireless + // clients commonly have larger buffers, but our max CWND is 200. + server_writer_->set_max_bandwidth_and_buffer_size( + QuicBandwidth::FromBytesPerSecond(256 * 1024), 256 * 1024); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // 1 MB body. + std::string body(1024 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + // This connection may drop packets, because the buffer is smaller than the + // max CWND. + VerifyCleanConnection(true); +} + +TEST_P(EndToEndTest, DoNotSetSendAlarmIfConnectionFlowControlBlocked) { + // Regression test for b/14677858. + // Test that the resume write alarm is not set in QuicConnection::OnCanWrite + // if currently connection level flow control blocked. If set, this results in + // an infinite loop in the EventLoop, as the alarm fires and is immediately + // rescheduled. + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // Ensure both stream and connection level are flow control blocked by setting + // the send window offset to 0. + const uint64_t flow_control_window = + server_config_.GetInitialStreamFlowControlWindowToSend(); + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + QuicSession* session = GetClientSession(); + ASSERT_TRUE(session); + QuicStreamPeer::SetSendWindowOffset(stream, 0); + QuicFlowControllerPeer::SetSendWindowOffset(session->flow_controller(), 0); + EXPECT_TRUE(stream->IsFlowControlBlocked()); + EXPECT_TRUE(session->flow_controller()->IsBlocked()); + + // Make sure that the stream has data pending so that it will be marked as + // write blocked when it receives a stream level WINDOW_UPDATE. + stream->WriteOrBufferBody("hello", false); + + // The stream now attempts to write, fails because it is still connection + // level flow control blocked, and is added to the write blocked list. + QuicWindowUpdateFrame window_update(kInvalidControlFrameId, stream->id(), + 2 * flow_control_window); + stream->OnWindowUpdateFrame(window_update); + + // Prior to fixing b/14677858 this call would result in an infinite loop in + // Chromium. As a proxy for detecting this, we now check whether the + // send alarm is set after OnCanWrite. It should not be, as the + // connection is still flow control blocked. + session->connection()->OnCanWrite(); + + QuicAlarm* send_alarm = + QuicConnectionPeer::GetSendAlarm(session->connection()); + EXPECT_FALSE(send_alarm->IsSet()); +} + +TEST_P(EndToEndTest, InvalidStream) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + std::string body(kMaxOutgoingPacketSize, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + // Force the client to write with a stream ID belonging to a nonexistent + // server-side stream. + QuicSpdySession* session = GetClientSession(); + ASSERT_TRUE(session); + QuicSessionPeer::SetNextOutgoingBidirectionalStreamId( + session, GetNthServerInitiatedBidirectionalId(0)); + + client_->SendCustomSynchronousRequest(headers, body); + EXPECT_THAT(client_->stream_error(), + IsStreamError(QUIC_STREAM_CONNECTION_ERROR)); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_INVALID_STREAM_ID)); +} + +// Test that the server resets the stream if the client sends a request +// with overly large headers. +TEST_P(EndToEndTest, LargeHeaders) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + std::string body(kMaxOutgoingPacketSize, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["key1"] = std::string(15 * 1024, 'a'); + headers["key2"] = std::string(15 * 1024, 'a'); + headers["key3"] = std::string(15 * 1024, 'a'); + + client_->SendCustomSynchronousRequest(headers, body); + + if (version_.UsesHttp3()) { + // QuicSpdyStream::OnHeadersTooLarge() resets the stream with + // QUIC_HEADERS_TOO_LARGE. This is sent as H3_EXCESSIVE_LOAD, the closest + // HTTP/3 error code, and translated back to QUIC_STREAM_EXCESSIVE_LOAD on + // the receiving side. + EXPECT_THAT(client_->stream_error(), + IsStreamError(QUIC_STREAM_EXCESSIVE_LOAD)); + } else { + EXPECT_THAT(client_->stream_error(), IsStreamError(QUIC_HEADERS_TOO_LARGE)); + } + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); +} + +TEST_P(EndToEndTest, EarlyResponseWithQuicStreamNoError) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + std::string large_body(1024 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + // Insert an invalid content_length field in request to trigger an early + // response from server. + headers["content-length"] = "-3"; + + client_->SendCustomSynchronousRequest(headers, large_body); + EXPECT_EQ("bad", client_->response_body()); + CheckResponseHeaders("500"); + EXPECT_THAT(client_->stream_error(), IsQuicStreamNoError()); + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); +} + +// TODO(rch): this test seems to cause net_unittests timeouts :| +TEST_P(EndToEndTest, QUIC_TEST_DISABLED_IN_CHROME(MultipleTermination)) { + ASSERT_TRUE(Initialize()); + + // Set the offset so we won't frame. Otherwise when we pick up termination + // before HTTP framing is complete, we send an error and close the stream, + // and the second write is picked up as writing on a closed stream. + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamPeer::SetStreamBytesWritten(3, stream); + + client_->SendData("bar", true); + client_->WaitForWriteToFlush(); + + // By default the stream protects itself from writes after terminte is set. + // Override this to test the server handling buggy clients. + QuicStreamPeer::SetWriteSideClosed(false, client_->GetOrCreateStream()); + + EXPECT_QUIC_BUG(client_->SendData("eep", true), "Fin already buffered"); +} + +TEST_P(EndToEndTest, Timeout) { + client_config_.SetIdleNetworkTimeout(QuicTime::Delta::FromMicroseconds(500)); + // Note: we do NOT ASSERT_TRUE: we may time out during initial handshake: + // that's enough to validate timeout in this case. + Initialize(); + while (client_->client()->connected()) { + client_->client()->WaitForEvents(); + } +} + +TEST_P(EndToEndTest, MaxDynamicStreamsLimitRespected) { + // Set a limit on maximum number of incoming dynamic streams. + // Make sure the limit is respected by the peer. + const uint32_t kServerMaxDynamicStreams = 1; + server_config_.SetMaxBidirectionalStreamsToSend(kServerMaxDynamicStreams); + ASSERT_TRUE(Initialize()); + if (version_.HasIetfQuicFrames()) { + // Do not run this test for /IETF QUIC. This test relies on the fact that + // Google QUIC allows a small number of additional streams beyond the + // negotiated limit, which is not supported in IETF QUIC. Note that the test + // needs to be here, after calling Initialize(), because all tests end up + // calling EndToEndTest::TearDown(), which asserts that Initialize has been + // called and then proceeds to tear things down -- which fails if they are + // not properly set up. + return; + } + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // Make the client misbehave after negotiation. + const int kServerMaxStreams = kMaxStreamsMinimumIncrement + 1; + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicSessionPeer::SetMaxOpenOutgoingStreams(client_session, + kServerMaxStreams + 1); + + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = "3"; + + // The server supports a small number of additional streams beyond the + // negotiated limit. Open enough streams to go beyond that limit. + for (int i = 0; i < kServerMaxStreams + 1; ++i) { + client_->SendMessage(headers, "", /*fin=*/false); + } + client_->WaitForResponse(); + + EXPECT_TRUE(client_->connected()); + EXPECT_THAT(client_->stream_error(), IsStreamError(QUIC_REFUSED_STREAM)); + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); +} + +TEST_P(EndToEndTest, SetIndependentMaxDynamicStreamsLimits) { + // Each endpoint can set max dynamic streams independently. + const uint32_t kClientMaxDynamicStreams = 4; + const uint32_t kServerMaxDynamicStreams = 3; + client_config_.SetMaxBidirectionalStreamsToSend(kClientMaxDynamicStreams); + server_config_.SetMaxBidirectionalStreamsToSend(kServerMaxDynamicStreams); + client_config_.SetMaxUnidirectionalStreamsToSend(kClientMaxDynamicStreams); + server_config_.SetMaxUnidirectionalStreamsToSend(kServerMaxDynamicStreams); + + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // The client has received the server's limit and vice versa. + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + // The value returned by max_allowed... includes the Crypto and Header + // stream (created as a part of initialization). The config. values, + // above, are treated as "number of requests/responses" - that is, they do + // not include the static Crypto and Header streams. Reduce the value + // returned by max_allowed... by 2 to remove the static streams from the + // count. + size_t client_max_open_outgoing_bidirectional_streams = + version_.HasIetfQuicFrames() + ? QuicSessionPeer::ietf_streamid_manager(client_session) + ->max_outgoing_bidirectional_streams() + : QuicSessionPeer::GetStreamIdManager(client_session) + ->max_open_outgoing_streams(); + size_t client_max_open_outgoing_unidirectional_streams = + version_.HasIetfQuicFrames() + ? QuicSessionPeer::ietf_streamid_manager(client_session) + ->max_outgoing_unidirectional_streams() - + kHttp3StaticUnidirectionalStreamCount + : QuicSessionPeer::GetStreamIdManager(client_session) + ->max_open_outgoing_streams(); + EXPECT_EQ(kServerMaxDynamicStreams, + client_max_open_outgoing_bidirectional_streams); + EXPECT_EQ(kServerMaxDynamicStreams, + client_max_open_outgoing_unidirectional_streams); + server_thread_->Pause(); + QuicSession* server_session = GetServerSession(); + if (server_session != nullptr) { + size_t server_max_open_outgoing_bidirectional_streams = + version_.HasIetfQuicFrames() + ? QuicSessionPeer::ietf_streamid_manager(server_session) + ->max_outgoing_bidirectional_streams() + : QuicSessionPeer::GetStreamIdManager(server_session) + ->max_open_outgoing_streams(); + size_t server_max_open_outgoing_unidirectional_streams = + version_.HasIetfQuicFrames() + ? QuicSessionPeer::ietf_streamid_manager(server_session) + ->max_outgoing_unidirectional_streams() - + kHttp3StaticUnidirectionalStreamCount + : QuicSessionPeer::GetStreamIdManager(server_session) + ->max_open_outgoing_streams(); + EXPECT_EQ(kClientMaxDynamicStreams, + server_max_open_outgoing_bidirectional_streams); + EXPECT_EQ(kClientMaxDynamicStreams, + server_max_open_outgoing_unidirectional_streams); + } else { + ADD_FAILURE() << "Missing server session"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, NegotiateCongestionControl) { + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + CongestionControlType expected_congestion_control_type = kRenoBytes; + switch (GetParam().congestion_control_tag) { + case kRENO: + expected_congestion_control_type = kRenoBytes; + break; + case kTBBR: + expected_congestion_control_type = kBBR; + break; + case kQBIC: + expected_congestion_control_type = kCubicBytes; + break; + case kB2ON: + expected_congestion_control_type = kBBRv2; + break; + default: + QUIC_DLOG(FATAL) << "Unexpected congestion control tag"; + } + + server_thread_->Pause(); + const QuicSentPacketManager* server_sent_packet_manager = + GetSentPacketManagerFromFirstServerSession(); + if (server_sent_packet_manager != nullptr) { + EXPECT_EQ( + expected_congestion_control_type, + QuicSentPacketManagerPeer::GetSendAlgorithm(*server_sent_packet_manager) + ->GetCongestionControlType()); + } else { + ADD_FAILURE() << "Missing server sent packet manager"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ClientSuggestsRTT) { + // Client suggests initial RTT, verify it is used. + const QuicTime::Delta kInitialRTT = QuicTime::Delta::FromMicroseconds(20000); + client_config_.SetInitialRoundTripTimeUsToSend(kInitialRTT.ToMicroseconds()); + + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(server_thread_); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Pause the server so we can access the server's internals without races. + server_thread_->Pause(); + const QuicSentPacketManager* client_sent_packet_manager = + GetSentPacketManagerFromClientSession(); + const QuicSentPacketManager* server_sent_packet_manager = + GetSentPacketManagerFromFirstServerSession(); + if (client_sent_packet_manager != nullptr && + server_sent_packet_manager != nullptr) { + EXPECT_EQ(kInitialRTT, + client_sent_packet_manager->GetRttStats()->initial_rtt()); + EXPECT_EQ(kInitialRTT, + server_sent_packet_manager->GetRttStats()->initial_rtt()); + } else { + ADD_FAILURE() << "Missing sent packet manager"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ClientSuggestsIgnoredRTT) { + // Client suggests initial RTT, but also specifies NRTT, so it's not used. + const QuicTime::Delta kInitialRTT = QuicTime::Delta::FromMicroseconds(20000); + client_config_.SetInitialRoundTripTimeUsToSend(kInitialRTT.ToMicroseconds()); + QuicTagVector options; + options.push_back(kNRTT); + client_config_.SetConnectionOptionsToSend(options); + + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(server_thread_); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Pause the server so we can access the server's internals without races. + server_thread_->Pause(); + const QuicSentPacketManager* client_sent_packet_manager = + GetSentPacketManagerFromClientSession(); + const QuicSentPacketManager* server_sent_packet_manager = + GetSentPacketManagerFromFirstServerSession(); + if (client_sent_packet_manager != nullptr && + server_sent_packet_manager != nullptr) { + EXPECT_EQ(kInitialRTT, + client_sent_packet_manager->GetRttStats()->initial_rtt()); + EXPECT_EQ(kInitialRTT, + server_sent_packet_manager->GetRttStats()->initial_rtt()); + } else { + ADD_FAILURE() << "Missing sent packet manager"; + } + server_thread_->Resume(); +} + +// Regression test for b/171378845 +TEST_P(EndToEndTest, ClientDisablesGQuicZeroRtt) { + if (version_.UsesTls()) { + // This feature is gQUIC only. + ASSERT_TRUE(Initialize()); + return; + } + QuicTagVector options; + options.push_back(kQNZ2); + client_config_.SetClientConnectionOptions(options); + + ASSERT_TRUE(Initialize()); + + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_session->ReceivedInchoateReject()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->ReceivedInchoateReject()); + + client_->Disconnect(); + + // Make sure that the request succeeds but 0-RTT was not used. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); +} + +TEST_P(EndToEndTest, MaxInitialRTT) { + // Client tries to suggest twice the server's max initial rtt and the server + // uses the max. + client_config_.SetInitialRoundTripTimeUsToSend(2 * + kMaxInitialRoundTripTimeUs); + + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(server_thread_); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Pause the server so we can access the server's internals without races. + server_thread_->Pause(); + const QuicSentPacketManager* client_sent_packet_manager = + GetSentPacketManagerFromClientSession(); + const QuicSentPacketManager* server_sent_packet_manager = + GetSentPacketManagerFromFirstServerSession(); + if (client_sent_packet_manager != nullptr && + server_sent_packet_manager != nullptr) { + // Now that acks have been exchanged, the RTT estimate has decreased on the + // server and is not infinite on the client. + EXPECT_FALSE( + client_sent_packet_manager->GetRttStats()->smoothed_rtt().IsInfinite()); + const RttStats* server_rtt_stats = + server_sent_packet_manager->GetRttStats(); + EXPECT_EQ(static_cast(kMaxInitialRoundTripTimeUs), + server_rtt_stats->initial_rtt().ToMicroseconds()); + EXPECT_GE(static_cast(kMaxInitialRoundTripTimeUs), + server_rtt_stats->smoothed_rtt().ToMicroseconds()); + } else { + ADD_FAILURE() << "Missing sent packet manager"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, MinInitialRTT) { + // Client tries to suggest 0 and the server uses the default. + client_config_.SetInitialRoundTripTimeUsToSend(0); + + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Pause the server so we can access the server's internals without races. + server_thread_->Pause(); + const QuicSentPacketManager* client_sent_packet_manager = + GetSentPacketManagerFromClientSession(); + const QuicSentPacketManager* server_sent_packet_manager = + GetSentPacketManagerFromFirstServerSession(); + if (client_sent_packet_manager != nullptr && + server_sent_packet_manager != nullptr) { + // Now that acks have been exchanged, the RTT estimate has decreased on the + // server and is not infinite on the client. + EXPECT_FALSE( + client_sent_packet_manager->GetRttStats()->smoothed_rtt().IsInfinite()); + // Expect the default rtt of 100ms. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(100), + server_sent_packet_manager->GetRttStats()->initial_rtt()); + // Ensure the bandwidth is valid. + client_sent_packet_manager->BandwidthEstimate(); + server_sent_packet_manager->BandwidthEstimate(); + } else { + ADD_FAILURE() << "Missing sent packet manager"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, 0ByteConnectionId) { + if (version_.HasIetfInvariantHeader()) { + // SetBytesForConnectionIdToSend only applies to Google QUIC encoding. + ASSERT_TRUE(Initialize()); + return; + } + client_config_.SetBytesForConnectionIdToSend(0); + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicPacketHeader* header = + QuicConnectionPeer::GetLastHeader(client_connection); + EXPECT_EQ(CONNECTION_ID_ABSENT, header->source_connection_id_included); +} + +TEST_P(EndToEndTest, 8ByteConnectionId) { + if (version_.HasIetfInvariantHeader()) { + // SetBytesForConnectionIdToSend only applies to Google QUIC encoding. + ASSERT_TRUE(Initialize()); + return; + } + client_config_.SetBytesForConnectionIdToSend(8); + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicPacketHeader* header = + QuicConnectionPeer::GetLastHeader(client_connection); + EXPECT_EQ(CONNECTION_ID_PRESENT, header->destination_connection_id_included); +} + +TEST_P(EndToEndTest, 15ByteConnectionId) { + if (version_.HasIetfInvariantHeader()) { + // SetBytesForConnectionIdToSend only applies to Google QUIC encoding. + ASSERT_TRUE(Initialize()); + return; + } + client_config_.SetBytesForConnectionIdToSend(15); + ASSERT_TRUE(Initialize()); + + // Our server is permissive and allows for out of bounds values. + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicPacketHeader* header = + QuicConnectionPeer::GetLastHeader(client_connection); + EXPECT_EQ(CONNECTION_ID_PRESENT, header->destination_connection_id_included); +} + +TEST_P(EndToEndTest, ResetConnection) { + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + client_->ResetConnection(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + SendSynchronousBarRequestAndCheckResponse(); +} + +// Regression test for b/180737158. +TEST_P( + EndToEndTest, + HalfRttResponseBlocksShloRetransmissionWithoutTokenBasedAddressValidation) { + // Turn off token based address validation to make the server get constrained + // by amplification factor during handshake. + SetQuicFlag(quic_reject_retry_token_in_initial_packet, true); + ASSERT_TRUE(Initialize()); + if (!version_.SupportsAntiAmplificationLimit()) { + return; + } + // Perform a full 1-RTT handshake to get the new session ticket such that the + // next connection will perform a 0-RTT handshake. + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + client_->Disconnect(); + + server_thread_->Pause(); + // Drop the 1st server packet which is the coalesced INITIAL + HANDSHAKE + + // 1RTT. + PacketDroppingTestWriter* writer = new PacketDroppingTestWriter(); + writer->set_fake_drop_first_n_packets(1); + QuicDispatcherPeer::UseWriter( + QuicServerPeer::GetDispatcher(server_thread_->server()), writer); + server_thread_->Resume(); + + // Large response (100KB) for 0-RTT request. + std::string large_body(102400, 'a'); + AddToCache("/large_response", 200, large_body); + SendSynchronousRequestAndCheckResponse(client_.get(), "/large_response", + large_body); +} + +TEST_P(EndToEndTest, MaxStreamsUberTest) { + // Connect with lower fake packet loss than we'd like to test. Until + // b/10126687 is fixed, losing handshake packets is pretty brutal. + SetPacketLossPercentage(1); + ASSERT_TRUE(Initialize()); + std::string large_body(10240, 'a'); + int max_streams = 100; + + AddToCache("/large_response", 200, large_body); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + SetPacketLossPercentage(10); + + for (int i = 0; i < max_streams; ++i) { + EXPECT_LT(0, client_->SendRequest("/large_response")); + } + + // WaitForEvents waits 50ms and returns true if there are outstanding + // requests. + while (client_->client()->WaitForEvents()) { + ASSERT_TRUE(client_->connected()); + } +} + +TEST_P(EndToEndTest, StreamCancelErrorTest) { + ASSERT_TRUE(Initialize()); + std::string small_body(256, 'a'); + + AddToCache("/small_response", 200, small_body); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + QuicSession* session = GetClientSession(); + ASSERT_TRUE(session); + // Lose the request. + SetPacketLossPercentage(100); + EXPECT_LT(0, client_->SendRequest("/small_response")); + client_->client()->WaitForEvents(); + // Transmit the cancel, and ensure the connection is torn down properly. + SetPacketLossPercentage(0); + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + const QuicPacketCount packets_sent_before = + client_connection->GetStats().packets_sent; + session->ResetStream(stream_id, QUIC_STREAM_CANCELLED); + const QuicPacketCount packets_sent_now = + client_connection->GetStats().packets_sent; + + if (version_.UsesHttp3()) { + // Make sure 2 packets were sent, one for QPACK instructions, another for + // RESET_STREAM and STOP_SENDING. + EXPECT_EQ(packets_sent_before + 2, packets_sent_now); + } + + // WaitForEvents waits 50ms and returns true if there are outstanding + // requests. + while (client_->client()->WaitForEvents()) { + ASSERT_TRUE(client_->connected()); + } + // It should be completely fine to RST a stream before any data has been + // received for that stream. + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); +} + +TEST_P(EndToEndTest, ConnectionMigrationClientIPChanged) { + ASSERT_TRUE(Initialize()); + if (GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress old_host = + client_->client()->network_helper()->GetLatestClientAddress().host(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_host, new_host); + ASSERT_TRUE(client_->client()->MigrateSocket(new_host)); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + + if (!version_.HasIetfQuicFrames() || + !client_->client()->session()->connection()->validate_client_address()) { + return; + } + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + + // Send another request. + SendSynchronousBarRequestAndCheckResponse(); + // By the time the 2nd request is completed, the PATH_RESPONSE must have been + // received by the server. + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_FALSE(server_connection->HasPendingPathValidation()); + EXPECT_EQ(1u, server_connection->GetStats().num_validated_peer_migration); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, IetfConnectionMigrationClientIPChangedMultipleTimes) { + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress host0 = + client_->client()->network_helper()->GetLatestClientAddress().host(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection != nullptr); + + // Migrate socket to a new IP address. + QuicIpAddress host1 = TestLoopback(2); + EXPECT_NE(host0, host1); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + QuicConnectionId server_cid0 = client_connection->connection_id(); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid1 = client_connection->connection_id(); + EXPECT_FALSE(server_cid1.IsEmpty()); + EXPECT_NE(server_cid0, server_cid1); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + + // Send another request and wait for response making sure path response is + // received at server. + SendSynchronousBarRequestAndCheckResponse(); + + // Migrate socket to a new IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + QuicIpAddress host2 = TestLoopback(3); + EXPECT_NE(host0, host2); + EXPECT_NE(host1, host2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host2)); + QuicConnectionId server_cid2 = client_connection->connection_id(); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid0, server_cid2); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request using the new socket and wait for response making sure + // path response is received at server. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(2u, + client_connection->GetStats().num_connectivity_probing_received); + + // Migrate socket back to an old IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid3 = client_connection->connection_id(); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid0, server_cid3); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + const auto* client_packet_creator = + QuicConnectionPeer::GetPacketCreator(client_connection); + EXPECT_TRUE(client_packet_creator->GetClientConnectionId().IsEmpty()); + EXPECT_EQ(server_cid3, client_packet_creator->GetServerConnectionId()); + + // Send another request using the new socket and wait for response making sure + // path response is received at server. + SendSynchronousBarRequestAndCheckResponse(); + // Even this is an old path, server has forgotten about it and thus needs to + // validate the path again. + EXPECT_EQ(3u, + client_connection->GetStats().num_connectivity_probing_received); + + WaitForNewConnectionIds(); + EXPECT_EQ(3u, client_connection->GetStats().num_retire_connection_id_sent); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + // By the time the 2nd request is completed, the PATH_RESPONSE must have been + // received by the server. + EXPECT_FALSE(server_connection->HasPendingPathValidation()); + EXPECT_EQ(3u, server_connection->GetStats().num_validated_peer_migration); + EXPECT_EQ(server_cid3, server_connection->connection_id()); + const auto* server_packet_creator = + QuicConnectionPeer::GetPacketCreator(server_connection); + EXPECT_EQ(server_cid3, server_packet_creator->GetServerConnectionId()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + server_connection) + .IsEmpty()); + EXPECT_EQ(4u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, + ConnectionMigrationWithNonZeroConnectionIDClientIPChangedMultipleTimes) { + if (!version_.SupportsClientConnectionIds() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + ASSERT_TRUE(Initialize()); + return; + } + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress host0 = + client_->client()->network_helper()->GetLatestClientAddress().host(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection != nullptr); + + // Migrate socket to a new IP address. + QuicIpAddress host1 = TestLoopback(2); + EXPECT_NE(host0, host1); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + QuicConnectionId server_cid0 = client_connection->connection_id(); + QuicConnectionId client_cid0 = client_connection->client_connection_id(); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid1 = client_connection->connection_id(); + QuicConnectionId client_cid1 = client_connection->client_connection_id(); + EXPECT_FALSE(server_cid1.IsEmpty()); + EXPECT_FALSE(client_cid1.IsEmpty()); + EXPECT_NE(server_cid0, server_cid1); + EXPECT_NE(client_cid0, client_cid1); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + + // Migrate socket to a new IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, client_connection->GetStats().num_new_connection_id_sent); + QuicIpAddress host2 = TestLoopback(3); + EXPECT_NE(host0, host2); + EXPECT_NE(host1, host2); + EXPECT_TRUE(client_->client()->MigrateSocket(host2)); + QuicConnectionId server_cid2 = client_connection->connection_id(); + QuicConnectionId client_cid2 = client_connection->client_connection_id(); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid0, server_cid2); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_FALSE(client_cid2.IsEmpty()); + EXPECT_NE(client_cid0, client_cid2); + EXPECT_NE(client_cid1, client_cid2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(2u, + client_connection->GetStats().num_connectivity_probing_received); + + // Migrate socket back to an old IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(3u, client_connection->GetStats().num_new_connection_id_sent); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid3 = client_connection->connection_id(); + QuicConnectionId client_cid3 = client_connection->client_connection_id(); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid0, server_cid3); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_FALSE(client_cid3.IsEmpty()); + EXPECT_NE(client_cid0, client_cid3); + EXPECT_NE(client_cid1, client_cid3); + EXPECT_NE(client_cid2, client_cid3); + const auto* client_packet_creator = + QuicConnectionPeer::GetPacketCreator(client_connection); + EXPECT_EQ(client_cid3, client_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid3, client_packet_creator->GetServerConnectionId()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. + SendSynchronousBarRequestAndCheckResponse(); + // Even this is an old path, server has forgotten about it and thus needs to + // validate the path again. + EXPECT_EQ(3u, + client_connection->GetStats().num_connectivity_probing_received); + + WaitForNewConnectionIds(); + EXPECT_EQ(3u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(4u, client_connection->GetStats().num_new_connection_id_sent); + + server_thread_->Pause(); + // By the time the 2nd request is completed, the PATH_RESPONSE must have been + // received by the server. + QuicConnection* server_connection = GetServerConnection(); + EXPECT_FALSE(server_connection->HasPendingPathValidation()); + EXPECT_EQ(3u, server_connection->GetStats().num_validated_peer_migration); + EXPECT_EQ(server_cid3, server_connection->connection_id()); + EXPECT_EQ(client_cid3, server_connection->client_connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + server_connection) + .IsEmpty()); + const auto* server_packet_creator = + QuicConnectionPeer::GetPacketCreator(server_connection); + EXPECT_EQ(client_cid3, server_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid3, server_packet_creator->GetServerConnectionId()); + EXPECT_EQ(3u, server_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(4u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ConnectionMigrationNewTokenForNewIp) { + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames() || + !client_->client()->session()->connection()->validate_client_address() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress old_host = + client_->client()->network_helper()->GetLatestClientAddress().host(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_host, new_host); + ASSERT_TRUE(client_->client()->MigrateSocket(new_host)); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. + SendSynchronousBarRequestAndCheckResponse(); + + client_->Disconnect(); + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + SendSynchronousFooRequestAndCheckResponse(); + + EXPECT_TRUE(GetClientSession()->EarlyDataAccepted()); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + // Verify address is validated via validating token received in INITIAL + // packet. + EXPECT_FALSE( + server_connection->GetStats().address_validated_via_decrypting_packet); + EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); + client_->Disconnect(); +} + +// A writer which copies the packet and send the copy with a specified self +// address and then send the same packet with the original self address. +class DuplicatePacketWithSpoofedSelfAddressWriter + : public QuicPacketWriterWrapper { + public: + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override { + if (self_address_to_overwrite_.IsInitialized()) { + // Send the same packet on the overwriting address before sending on the + // actual self address. + QuicPacketWriterWrapper::WritePacket( + buffer, buf_len, self_address_to_overwrite_, peer_address, options); + } + return QuicPacketWriterWrapper::WritePacket(buffer, buf_len, self_address, + peer_address, options); + } + + void set_self_address_to_overwrite(const QuicIpAddress& self_address) { + self_address_to_overwrite_ = self_address; + } + + private: + QuicIpAddress self_address_to_overwrite_; +}; + +TEST_P(EndToEndTest, ClientAddressSpoofedForSomePeriod) { + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + auto writer = new DuplicatePacketWithSpoofedSelfAddressWriter(); + client_.reset(CreateQuicClient(writer)); + + // Make sure client has unused peer connection ID before migration. + SendSynchronousFooRequestAndCheckResponse(); + ASSERT_TRUE(QuicConnectionPeer::HasUnusedPeerIssuedConnectionId( + GetClientConnection())); + + QuicIpAddress real_host = + client_->client()->session()->connection()->self_address().host(); + ASSERT_TRUE(client_->MigrateSocket(real_host)); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ( + 0u, GetClientConnection()->GetStats().num_connectivity_probing_received); + EXPECT_EQ( + real_host, + client_->client()->network_helper()->GetLatestClientAddress().host()); + client_->WaitForDelayedAcks(); + + std::string large_body(10240, 'a'); + AddToCache("/large_response", 200, large_body); + + QuicIpAddress spoofed_host = TestLoopback(2); + writer->set_self_address_to_overwrite(spoofed_host); + + client_->SendRequest("/large_response"); + QuicConnection* client_connection = GetClientConnection(); + QuicPacketCount num_packets_received = + client_connection->GetStats().packets_received; + + while (client_->client()->WaitForEvents() && client_->connected()) { + if (client_connection->GetStats().packets_received > num_packets_received) { + // Ideally the client won't receive any packets till the server finds out + // the new client address is not working. But there are 2 corner cases: + // 1) Before the server received the packet from spoofed address, it might + // send packets to the real client address. So the client will immediately + // switch back to use the original address; + // 2) Between the server fails reverse path validation and the client + // receives packets again, the client might sent some packets with the + // spoofed address and triggers another migration. + // In both corner cases, the attempted migration should fail and fall back + // to the working path. + writer->set_self_address_to_overwrite(QuicIpAddress()); + } + } + client_->WaitForResponse(); + EXPECT_EQ(large_body, client_->response_body()); +} + +TEST_P(EndToEndTest, + AsynchronousConnectionMigrationClientIPChangedMultipleTimes) { + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_.reset(CreateQuicClient(nullptr)); + + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress host0 = + client_->client()->network_helper()->GetLatestClientAddress().host(); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionId server_cid0 = client_connection->connection_id(); + // Server should have one new connection ID upon handshake completion. + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + + // Migrate socket to new IP address #1. + QuicIpAddress host1 = TestLoopback(2); + EXPECT_NE(host0, host1); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host1)); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host1, client_->client()->session()->self_address().host()); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId server_cid1 = client_connection->connection_id(); + EXPECT_NE(server_cid0, server_cid1); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + + // Migrate socket to new IP address #2. + WaitForNewConnectionIds(); + QuicIpAddress host2 = TestLoopback(3); + EXPECT_NE(host0, host1); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host2)); + + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host2, client_->client()->session()->self_address().host()); + EXPECT_EQ(2u, + client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId server_cid2 = client_connection->connection_id(); + EXPECT_NE(server_cid0, server_cid2); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + + // Migrate socket back to IP address #1. + WaitForNewConnectionIds(); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host1)); + + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host1, client_->client()->session()->self_address().host()); + EXPECT_EQ(3u, + client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId server_cid3 = client_connection->connection_id(); + EXPECT_NE(server_cid0, server_cid3); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + server_thread_->Pause(); + const QuicConnection* server_connection = GetServerConnection(); + EXPECT_EQ(server_connection->connection_id(), server_cid3); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + server_connection) + .IsEmpty()); + server_thread_->Resume(); + + // There should be 1 new connection ID issued by the server. + WaitForNewConnectionIds(); +} + +TEST_P(EndToEndTest, + AsynchronousConnectionMigrationClientIPChangedWithNonEmptyClientCID) { + if (!version_.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_.reset(CreateQuicClient(nullptr)); + + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress old_host = + client_->client()->network_helper()->GetLatestClientAddress().host(); + auto* client_connection = GetClientConnection(); + QuicConnectionId client_cid0 = client_connection->client_connection_id(); + QuicConnectionId server_cid0 = client_connection->connection_id(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_host, new_host); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(new_host)); + + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(new_host, client_->client()->session()->self_address().host()); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId client_cid1 = client_connection->client_connection_id(); + QuicConnectionId server_cid1 = client_connection->connection_id(); + const auto* client_packet_creator = + QuicConnectionPeer::GetPacketCreator(client_connection); + EXPECT_EQ(client_cid1, client_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid1, client_packet_creator->GetServerConnectionId()); + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + EXPECT_EQ(client_cid1, server_connection->client_connection_id()); + EXPECT_EQ(server_cid1, server_connection->connection_id()); + const auto* server_packet_creator = + QuicConnectionPeer::GetPacketCreator(server_connection); + EXPECT_EQ(client_cid1, server_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid1, server_packet_creator->GetServerConnectionId()); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ConnectionMigrationClientPortChanged) { + // Tests that the client's port can change during an established QUIC + // connection, and that doing so does not result in the connection being + // closed by the server. + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client address which was used to send the first request. + QuicSocketAddress old_address = + client_->client()->network_helper()->GetLatestClientAddress(); + int old_fd = client_->client()->GetLatestFD(); + + // Create a new socket before closing the old one, which will result in a new + // ephemeral port. + client_->client()->network_helper()->CreateUDPSocketAndBind( + client_->client()->server_address(), client_->client()->bind_to_address(), + client_->client()->local_port()); + + // Stop listening and close the old FD. + client_->client()->default_network_helper()->CleanUpUDPSocket(old_fd); + + // The packet writer needs to be updated to use the new FD. + client_->client()->network_helper()->CreateQuicPacketWriter(); + + // Change the internal state of the client and connection to use the new port, + // this is done because in a real NAT rebinding the client wouldn't see any + // port change, and so expects no change to incoming port. + // This is kind of ugly, but needed as we are simply swapping out the client + // FD rather than any more complex NAT rebinding simulation. + int new_port = + client_->client()->network_helper()->GetLatestClientAddress().port(); + client_->client()->default_network_helper()->SetClientPort(new_port); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionPeer::SetSelfAddress( + client_connection, + QuicSocketAddress(client_connection->self_address().host(), new_port)); + + // Send a second request, using the new FD. + SendSynchronousBarRequestAndCheckResponse(); + + // Verify that the client's ephemeral port is different. + QuicSocketAddress new_address = + client_->client()->network_helper()->GetLatestClientAddress(); + EXPECT_EQ(old_address.host(), new_address.host()); + EXPECT_NE(old_address.port(), new_address.port()); + + if (!version_.HasIetfQuicFrames() || + !GetClientConnection()->validate_client_address()) { + return; + } + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_FALSE(server_connection->HasPendingPathValidation()); + EXPECT_EQ(1u, server_connection->GetStats().num_validated_peer_migration); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, NegotiatedInitialCongestionWindow) { + client_extra_copts_.push_back(kIW03); + + ASSERT_TRUE(Initialize()); + + // Values are exchanged during crypto handshake, so wait for that to finish. + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + QuicPacketCount cwnd = + server_connection->sent_packet_manager().initial_congestion_window(); + EXPECT_EQ(3u, cwnd); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, DifferentFlowControlWindows) { + // Client and server can set different initial flow control receive windows. + // These are sent in CHLO/SHLO. Tests that these values are exchanged properly + // in the crypto handshake. + const uint32_t kClientStreamIFCW = 123456; + const uint32_t kClientSessionIFCW = 234567; + set_client_initial_stream_flow_control_receive_window(kClientStreamIFCW); + set_client_initial_session_flow_control_receive_window(kClientSessionIFCW); + + uint32_t kServerStreamIFCW = 32 * 1024; + uint32_t kServerSessionIFCW = 48 * 1024; + set_server_initial_stream_flow_control_receive_window(kServerStreamIFCW); + set_server_initial_session_flow_control_receive_window(kServerSessionIFCW); + + ASSERT_TRUE(Initialize()); + + // Values are exchanged during crypto handshake, so wait for that to finish. + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Open a data stream to make sure the stream level flow control is updated. + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + WriteHeadersOnStream(stream); + stream->WriteOrBufferBody("hello", false); + + if (!version_.UsesTls()) { + // IFWA only exists with QUIC_CRYPTO. + // Client should have the right values for server's receive window. + ASSERT_TRUE(client_->client() + ->client_session() + ->config() + ->HasReceivedInitialStreamFlowControlWindowBytes()); + EXPECT_EQ(kServerStreamIFCW, + client_->client() + ->client_session() + ->config() + ->ReceivedInitialStreamFlowControlWindowBytes()); + ASSERT_TRUE(client_->client() + ->client_session() + ->config() + ->HasReceivedInitialSessionFlowControlWindowBytes()); + EXPECT_EQ(kServerSessionIFCW, + client_->client() + ->client_session() + ->config() + ->ReceivedInitialSessionFlowControlWindowBytes()); + } + EXPECT_EQ(kServerStreamIFCW, QuicStreamPeer::SendWindowOffset(stream)); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_EQ(kServerSessionIFCW, QuicFlowControllerPeer::SendWindowOffset( + client_session->flow_controller())); + + // Server should have the right values for client's receive window. + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + if (server_session == nullptr) { + ADD_FAILURE() << "Missing server session"; + server_thread_->Resume(); + return; + } + QuicConfig server_config = *server_session->config(); + EXPECT_EQ(kClientSessionIFCW, QuicFlowControllerPeer::SendWindowOffset( + server_session->flow_controller())); + server_thread_->Resume(); + if (version_.UsesTls()) { + // IFWA only exists with QUIC_CRYPTO. + return; + } + ASSERT_TRUE(server_config.HasReceivedInitialStreamFlowControlWindowBytes()); + EXPECT_EQ(kClientStreamIFCW, + server_config.ReceivedInitialStreamFlowControlWindowBytes()); + ASSERT_TRUE(server_config.HasReceivedInitialSessionFlowControlWindowBytes()); + EXPECT_EQ(kClientSessionIFCW, + server_config.ReceivedInitialSessionFlowControlWindowBytes()); +} + +// Test negotiation of IFWA connection option. +TEST_P(EndToEndTest, NegotiatedServerInitialFlowControlWindow) { + const uint32_t kClientStreamIFCW = 123456; + const uint32_t kClientSessionIFCW = 234567; + set_client_initial_stream_flow_control_receive_window(kClientStreamIFCW); + set_client_initial_session_flow_control_receive_window(kClientSessionIFCW); + + uint32_t kServerStreamIFCW = 32 * 1024; + uint32_t kServerSessionIFCW = 48 * 1024; + set_server_initial_stream_flow_control_receive_window(kServerStreamIFCW); + set_server_initial_session_flow_control_receive_window(kServerSessionIFCW); + + // Bump the window. + const uint32_t kExpectedStreamIFCW = 1024 * 1024; + const uint32_t kExpectedSessionIFCW = 1.5 * 1024 * 1024; + client_extra_copts_.push_back(kIFWA); + + ASSERT_TRUE(Initialize()); + + // Values are exchanged during crypto handshake, so wait for that to finish. + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + // Open a data stream to make sure the stream level flow control is updated. + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + WriteHeadersOnStream(stream); + stream->WriteOrBufferBody("hello", false); + + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + + if (!version_.UsesTls()) { + // IFWA only exists with QUIC_CRYPTO. + // Client should have the right values for server's receive window. + ASSERT_TRUE(client_session->config() + ->HasReceivedInitialStreamFlowControlWindowBytes()); + EXPECT_EQ(kExpectedStreamIFCW, + client_session->config() + ->ReceivedInitialStreamFlowControlWindowBytes()); + ASSERT_TRUE(client_session->config() + ->HasReceivedInitialSessionFlowControlWindowBytes()); + EXPECT_EQ(kExpectedSessionIFCW, + client_session->config() + ->ReceivedInitialSessionFlowControlWindowBytes()); + } + EXPECT_EQ(kExpectedStreamIFCW, QuicStreamPeer::SendWindowOffset(stream)); + EXPECT_EQ(kExpectedSessionIFCW, QuicFlowControllerPeer::SendWindowOffset( + client_session->flow_controller())); +} + +TEST_P(EndToEndTest, HeadersAndCryptoStreamsNoConnectionFlowControl) { + // The special headers and crypto streams should be subject to per-stream flow + // control limits, but should not be subject to connection level flow control + const uint32_t kStreamIFCW = 32 * 1024; + const uint32_t kSessionIFCW = 48 * 1024; + set_client_initial_stream_flow_control_receive_window(kStreamIFCW); + set_client_initial_session_flow_control_receive_window(kSessionIFCW); + set_server_initial_stream_flow_control_receive_window(kStreamIFCW); + set_server_initial_session_flow_control_receive_window(kSessionIFCW); + + ASSERT_TRUE(Initialize()); + + // Wait for crypto handshake to finish. This should have contributed to the + // crypto stream flow control window, but not affected the session flow + // control window. + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicCryptoStream* crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(client_session); + ASSERT_TRUE(crypto_stream); + // In v47 and later, the crypto handshake (sent in CRYPTO frames) is not + // subject to flow control. + if (!version_.UsesCryptoFrames()) { + EXPECT_LT(QuicStreamPeer::SendWindowSize(crypto_stream), kStreamIFCW); + } + // When stream type is enabled, control streams will send settings and + // contribute to flow control windows, so this expectation is no longer valid. + if (!version_.UsesHttp3()) { + EXPECT_EQ(kSessionIFCW, QuicFlowControllerPeer::SendWindowSize( + client_session->flow_controller())); + } + + // Send a request with no body, and verify that the connection level window + // has not been affected. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // No headers stream in IETF QUIC. + if (version_.UsesHttp3()) { + return; + } + + QuicHeadersStream* headers_stream = + QuicSpdySessionPeer::GetHeadersStream(client_session); + ASSERT_TRUE(headers_stream); + EXPECT_LT(QuicStreamPeer::SendWindowSize(headers_stream), kStreamIFCW); + EXPECT_EQ(kSessionIFCW, QuicFlowControllerPeer::SendWindowSize( + client_session->flow_controller())); + + // Server should be in a similar state: connection flow control window should + // not have any bytes marked as received. + server_thread_->Pause(); + QuicSession* server_session = GetServerSession(); + if (server_session != nullptr) { + QuicFlowController* server_connection_flow_controller = + server_session->flow_controller(); + EXPECT_EQ(kSessionIFCW, QuicFlowControllerPeer::ReceiveWindowSize( + server_connection_flow_controller)); + } else { + ADD_FAILURE() << "Missing server session"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, FlowControlsSynced) { + set_smaller_flow_control_receive_window(); + + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + QuicSpdySession* const client_session = GetClientSession(); + ASSERT_TRUE(client_session); + + if (version_.UsesHttp3()) { + // Make sure that the client has received the initial SETTINGS frame, which + // is sent in the first packet on the control stream. + while (!QuicSpdySessionPeer::GetReceiveControlStream(client_session)) { + client_->client()->WaitForEvents(); + ASSERT_TRUE(client_->connected()); + } + } + + // Make sure that all data sent by the client has been received by the server + // (and the ack received by the client). + while (client_session->HasUnackedStreamData()) { + client_->client()->WaitForEvents(); + ASSERT_TRUE(client_->connected()); + } + + server_thread_->Pause(); + + QuicSpdySession* const server_session = GetServerSession(); + if (server_session == nullptr) { + ADD_FAILURE() << "Missing server session"; + server_thread_->Resume(); + return; + } + ExpectFlowControlsSynced(client_session, server_session); + + // Check control streams. + if (version_.UsesHttp3()) { + ExpectFlowControlsSynced( + QuicSpdySessionPeer::GetReceiveControlStream(client_session), + QuicSpdySessionPeer::GetSendControlStream(server_session)); + ExpectFlowControlsSynced( + QuicSpdySessionPeer::GetSendControlStream(client_session), + QuicSpdySessionPeer::GetReceiveControlStream(server_session)); + } + + // Check crypto stream. + if (!version_.UsesCryptoFrames()) { + ExpectFlowControlsSynced( + QuicSessionPeer::GetMutableCryptoStream(client_session), + QuicSessionPeer::GetMutableCryptoStream(server_session)); + } + + // Check headers stream. + if (!version_.UsesHttp3()) { + SpdyFramer spdy_framer(SpdyFramer::ENABLE_COMPRESSION); + SpdySettingsIR settings_frame; + settings_frame.AddSetting(spdy::SETTINGS_MAX_HEADER_LIST_SIZE, + kDefaultMaxUncompressedHeaderSize); + SpdySerializedFrame frame(spdy_framer.SerializeFrame(settings_frame)); + + QuicHeadersStream* client_header_stream = + QuicSpdySessionPeer::GetHeadersStream(client_session); + QuicHeadersStream* server_header_stream = + QuicSpdySessionPeer::GetHeadersStream(server_session); + // Both client and server are sending this SETTINGS frame, and the send + // window is consumed. But because of timing issue, the server may send or + // not send the frame, and the client may send/ not send / receive / not + // receive the frame. + // TODO(fayang): Rewrite this part because it is hacky. + QuicByteCount win_difference1 = + QuicStreamPeer::ReceiveWindowSize(server_header_stream) - + QuicStreamPeer::SendWindowSize(client_header_stream); + if (win_difference1 != 0) { + EXPECT_EQ(frame.size(), win_difference1); + } + + QuicByteCount win_difference2 = + QuicStreamPeer::ReceiveWindowSize(client_header_stream) - + QuicStreamPeer::SendWindowSize(server_header_stream); + if (win_difference2 != 0) { + EXPECT_EQ(frame.size(), win_difference2); + } + + // Client *may* have received the SETTINGs frame. + // TODO(fayang): Rewrite this part because it is hacky. + float ratio1 = static_cast(QuicFlowControllerPeer::ReceiveWindowSize( + client_session->flow_controller())) / + QuicStreamPeer::ReceiveWindowSize( + QuicSpdySessionPeer::GetHeadersStream(client_session)); + float ratio2 = static_cast(QuicFlowControllerPeer::ReceiveWindowSize( + client_session->flow_controller())) / + (QuicStreamPeer::ReceiveWindowSize( + QuicSpdySessionPeer::GetHeadersStream(client_session)) + + frame.size()); + EXPECT_TRUE(ratio1 == kSessionToStreamRatio || + ratio2 == kSessionToStreamRatio); + } + + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, RequestWithNoBodyWillNeverSendStreamFrameWithFIN) { + // A stream created on receipt of a simple request with no body will never get + // a stream frame with a FIN. Verify that we don't keep track of the stream in + // the locally closed streams map: it will never be removed if so. + ASSERT_TRUE(Initialize()); + + // Send a simple headers only request, and receive response. + SendSynchronousFooRequestAndCheckResponse(); + + // Now verify that the server is not waiting for a final FIN or RST. + server_thread_->Pause(); + QuicSession* server_session = GetServerSession(); + if (server_session != nullptr) { + EXPECT_EQ(0u, QuicSessionPeer::GetLocallyClosedStreamsHighestOffset( + server_session) + .size()); + } else { + ADD_FAILURE() << "Missing server session"; + } + server_thread_->Resume(); +} + +// TestAckListener counts how many bytes are acked during its lifetime. +class TestAckListener : public QuicAckListenerInterface { + public: + TestAckListener() {} + + void OnPacketAcked(int acked_bytes, + QuicTime::Delta /*delta_largest_observed*/) override { + total_bytes_acked_ += acked_bytes; + } + + void OnPacketRetransmitted(int /*retransmitted_bytes*/) override {} + + int total_bytes_acked() const { return total_bytes_acked_; } + + protected: + // Object is ref counted. + ~TestAckListener() override {} + + private: + int total_bytes_acked_ = 0; +}; + +class TestResponseListener : public QuicSpdyClientBase::ResponseListener { + public: + void OnCompleteResponse(QuicStreamId id, + const Http2HeaderBlock& response_headers, + absl::string_view response_body) override { + QUIC_DVLOG(1) << "response for stream " << id << " " + << response_headers.DebugString() << "\n" + << response_body; + } +}; + +TEST_P(EndToEndTest, AckNotifierWithPacketLossAndBlockedSocket) { + // Verify that even in the presence of packet loss and occasionally blocked + // socket, an AckNotifierDelegate will get informed that the data it is + // interested in has been ACKed. This tests end-to-end ACK notification, and + // demonstrates that retransmissions do not break this functionality. + // Disable blackhole detection as this test is testing loss recovery. + client_extra_copts_.push_back(kNBHD); + SetPacketLossPercentage(5); + ASSERT_TRUE(Initialize()); + // Wait for the server SHLO before upping the packet loss. + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + SetPacketLossPercentage(30); + client_writer_->set_fake_blocked_socket_percentage(10); + + // Wait for SETTINGS frame from server that sets QPACK dynamic table capacity + // to make sure request headers will be compressed using the dynamic table. + if (version_.UsesHttp3()) { + while (true) { + // Waits for up to 50 ms. + client_->client()->WaitForEvents(); + ASSERT_TRUE(client_->connected()); + QuicSpdyClientSession* client_session = GetClientSession(); + if (client_session == nullptr) { + ADD_FAILURE() << "Missing client session"; + return; + } + QpackEncoder* qpack_encoder = client_session->qpack_encoder(); + if (qpack_encoder == nullptr) { + ADD_FAILURE() << "Missing QPACK encoder"; + return; + } + QpackEncoderHeaderTable* header_table = + QpackEncoderPeer::header_table(qpack_encoder); + if (header_table == nullptr) { + ADD_FAILURE() << "Missing header table"; + return; + } + if (header_table->dynamic_table_capacity() > 0) { + break; + } + } + } + + // Create a POST request and send the headers only. + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + // Here, we have to specify flush=false, otherwise we risk a race condition in + // which the headers are sent and acknowledged before the ack notifier is + // installed. + client_->SendMessage(headers, "", /*fin=*/false, /*flush=*/false); + + // Size of headers on the request stream. This is zero if headers are sent on + // the header stream. + size_t header_size = 0; + if (version_.UsesHttp3()) { + // Determine size of headers after QPACK compression. + NoopDecoderStreamErrorDelegate decoder_stream_error_delegate; + NoopQpackStreamSenderDelegate encoder_stream_sender_delegate; + QpackEncoder qpack_encoder(&decoder_stream_error_delegate); + qpack_encoder.set_qpack_stream_sender_delegate( + &encoder_stream_sender_delegate); + + qpack_encoder.SetMaximumDynamicTableCapacity( + kDefaultQpackMaxDynamicTableCapacity); + qpack_encoder.SetDynamicTableCapacity(kDefaultQpackMaxDynamicTableCapacity); + qpack_encoder.SetMaximumBlockedStreams(kDefaultMaximumBlockedStreams); + + std::string encoded_headers = qpack_encoder.EncodeHeaderList( + /* stream_id = */ 0, headers, nullptr); + header_size = encoded_headers.size(); + } + + // Test the AckNotifier's ability to track multiple packets by making the + // request body exceed the size of a single packet. + std::string request_string = "a request body bigger than one packet" + + std::string(kMaxOutgoingPacketSize, '.'); + + const int expected_bytes_acked = header_size + request_string.length(); + + // The TestAckListener will cause a failure if not notified. + quiche::QuicheReferenceCountedPointer ack_listener( + new TestAckListener()); + + // Send the request, and register the delegate for ACKs. + client_->SendData(request_string, true, ack_listener); + WaitForFooResponseAndCheckIt(); + + // Send another request to flush out any pending ACKs on the server. + SendSynchronousBarRequestAndCheckResponse(); + + // Make sure the delegate does get the notification it expects. + int attempts = 0; + constexpr int kMaxAttempts = 20; + while (ack_listener->total_bytes_acked() < expected_bytes_acked) { + // Waits for up to 50 ms. + client_->client()->WaitForEvents(); + ASSERT_TRUE(client_->connected()); + if (++attempts >= kMaxAttempts) { + break; + } + } + EXPECT_EQ(ack_listener->total_bytes_acked(), expected_bytes_acked) + << " header_size " << header_size << " request length " + << request_string.length(); +} + +// Send a public reset from the server. +TEST_P(EndToEndTest, ServerSendPublicReset) { + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + QuicSpdySession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicConfig* config = client_session->config(); + ASSERT_TRUE(config); + EXPECT_TRUE(config->HasReceivedStatelessResetToken()); + StatelessResetToken stateless_reset_token = + config->ReceivedStatelessResetToken(); + + // Send the public reset. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionId connection_id = client_connection->connection_id(); + QuicPublicResetPacket header; + header.connection_id = connection_id; + QuicFramer framer(server_supported_versions_, QuicTime::Zero(), + Perspective::IS_SERVER, kQuicDefaultConnectionIdLength); + std::unique_ptr packet; + if (version_.HasIetfInvariantHeader()) { + packet = framer.BuildIetfStatelessResetPacket( + connection_id, /*received_packet_length=*/100, stateless_reset_token); + } else { + packet = framer.BuildPublicResetPacket(header); + } + // We must pause the server's thread in order to call WritePacket without + // race conditions. + server_thread_->Pause(); + auto client_address = client_connection->self_address(); + server_writer_->WritePacket(packet->data(), packet->length(), + server_address_.host(), client_address, nullptr); + server_thread_->Resume(); + + // The request should fail. + EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); + EXPECT_TRUE(client_->response_headers()->empty()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_PUBLIC_RESET)); +} + +// Send a public reset from the server for a different connection ID. +// It should be ignored. +TEST_P(EndToEndTest, ServerSendPublicResetWithDifferentConnectionId) { + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + QuicSpdySession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicConfig* config = client_session->config(); + ASSERT_TRUE(config); + EXPECT_TRUE(config->HasReceivedStatelessResetToken()); + StatelessResetToken stateless_reset_token = + config->ReceivedStatelessResetToken(); + // Send the public reset. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionId incorrect_connection_id = TestConnectionId( + TestConnectionIdToUInt64(client_connection->connection_id()) + 1); + QuicPublicResetPacket header; + header.connection_id = incorrect_connection_id; + QuicFramer framer(server_supported_versions_, QuicTime::Zero(), + Perspective::IS_SERVER, kQuicDefaultConnectionIdLength); + std::unique_ptr packet; + NiceMock visitor; + client_connection->set_debug_visitor(&visitor); + if (version_.HasIetfInvariantHeader()) { + packet = framer.BuildIetfStatelessResetPacket( + incorrect_connection_id, /*received_packet_length=*/100, + stateless_reset_token); + EXPECT_CALL(visitor, OnIncorrectConnectionId(incorrect_connection_id)) + .Times(0); + } else { + packet = framer.BuildPublicResetPacket(header); + EXPECT_CALL(visitor, OnIncorrectConnectionId(incorrect_connection_id)) + .Times(1); + } + // We must pause the server's thread in order to call WritePacket without + // race conditions. + server_thread_->Pause(); + auto client_address = client_connection->self_address(); + server_writer_->WritePacket(packet->data(), packet->length(), + server_address_.host(), client_address, nullptr); + server_thread_->Resume(); + + if (version_.HasIetfInvariantHeader()) { + // The request should fail. IETF stateless reset does not include connection + // ID. + EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); + EXPECT_TRUE(client_->response_headers()->empty()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_PUBLIC_RESET)); + } else { + // The connection should be unaffected. + SendSynchronousFooRequestAndCheckResponse(); + } + + client_connection->set_debug_visitor(nullptr); +} + +TEST_P(EndToEndTest, InduceStatelessResetFromServer) { + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + SetPacketLossPercentage(100); // Block PEER_GOING_AWAY message from server. + StopServer(true); + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + SetPacketLossPercentage(0); + // The request should generate a public reset. + EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); + EXPECT_TRUE(client_->response_headers()->empty()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_PUBLIC_RESET)); + EXPECT_FALSE(client_->connected()); +} + +// Send a public reset from the client for a different connection ID. +// It should be ignored. +TEST_P(EndToEndTest, ClientSendPublicResetWithDifferentConnectionId) { + ASSERT_TRUE(Initialize()); + + // Send the public reset. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionId incorrect_connection_id = TestConnectionId( + TestConnectionIdToUInt64(client_connection->connection_id()) + 1); + QuicPublicResetPacket header; + header.connection_id = incorrect_connection_id; + QuicFramer framer(server_supported_versions_, QuicTime::Zero(), + Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength); + std::unique_ptr packet( + framer.BuildPublicResetPacket(header)); + client_writer_->WritePacket( + packet->data(), packet->length(), + client_->client()->network_helper()->GetLatestClientAddress().host(), + server_address_, nullptr); + + // The connection should be unaffected. + SendSynchronousFooRequestAndCheckResponse(); +} + +// Send a version negotiation packet from the server for a different +// connection ID. It should be ignored. +TEST_P(EndToEndTest, ServerSendVersionNegotiationWithDifferentConnectionId) { + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // Send the version negotiation packet. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionId incorrect_connection_id = TestConnectionId( + TestConnectionIdToUInt64(client_connection->connection_id()) + 1); + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + incorrect_connection_id, EmptyQuicConnectionId(), + version_.HasIetfInvariantHeader(), + version_.HasLengthPrefixedConnectionIds(), + server_supported_versions_)); + NiceMock visitor; + client_connection->set_debug_visitor(&visitor); + EXPECT_CALL(visitor, OnIncorrectConnectionId(incorrect_connection_id)) + .Times(1); + // We must pause the server's thread in order to call WritePacket without + // race conditions. + server_thread_->Pause(); + server_writer_->WritePacket( + packet->data(), packet->length(), server_address_.host(), + client_->client()->network_helper()->GetLatestClientAddress(), nullptr); + server_thread_->Resume(); + + // The connection should be unaffected. + SendSynchronousFooRequestAndCheckResponse(); + + client_connection->set_debug_visitor(nullptr); +} + +// DowngradePacketWriter is a client writer which will intercept all the client +// writes for |target_version| and reply to them with version negotiation +// packets to attempt a version downgrade attack. Once the client has downgraded +// to a different version, the writer stops intercepting. |server_thread| must +// start off paused, and will be resumed once interception is done. +class DowngradePacketWriter : public PacketDroppingTestWriter { + public: + explicit DowngradePacketWriter( + const ParsedQuicVersion& target_version, + const ParsedQuicVersionVector& supported_versions, QuicTestClient* client, + QuicPacketWriter* server_writer, ServerThread* server_thread) + : target_version_(target_version), + supported_versions_(supported_versions), + client_(client), + server_writer_(server_writer), + server_thread_(server_thread) {} + ~DowngradePacketWriter() override {} + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + quic::PerPacketOptions* options) override { + if (!intercept_enabled_) { + return PacketDroppingTestWriter::WritePacket( + buffer, buf_len, self_address, peer_address, options); + } + PacketHeaderFormat format; + QuicLongHeaderType long_packet_type; + bool version_present, has_length_prefix; + QuicVersionLabel version_label; + ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); + QuicConnectionId destination_connection_id, source_connection_id; + absl::optional retry_token; + std::string detailed_error; + if (QuicFramer::ParsePublicHeaderDispatcher( + QuicEncryptedPacket(buffer, buf_len), + kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_present, &has_length_prefix, &version_label, + &parsed_version, &destination_connection_id, &source_connection_id, + &retry_token, &detailed_error) != QUIC_NO_ERROR) { + ADD_FAILURE() << "Failed to parse our own packet: " << detailed_error; + return WriteResult(WRITE_STATUS_ERROR, 0); + } + if (!version_present || parsed_version != target_version_) { + // Client is sending with another version, the attack has succeeded so we + // can stop intercepting. + intercept_enabled_ = false; + server_thread_->Resume(); + // Pass the client-sent packet through. + return WritePacket(buffer, buf_len, self_address, peer_address, options); + } + // Send a version negotiation packet. + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + destination_connection_id, source_connection_id, + parsed_version.HasIetfInvariantHeader(), has_length_prefix, + supported_versions_)); + server_writer_->WritePacket( + packet->data(), packet->length(), peer_address.host(), + client_->client()->network_helper()->GetLatestClientAddress(), nullptr); + // Drop the client-sent packet but pretend it was sent. + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + private: + bool intercept_enabled_ = true; + ParsedQuicVersion target_version_; + ParsedQuicVersionVector supported_versions_; + QuicTestClient* client_; // Unowned. + QuicPacketWriter* server_writer_; // Unowned. + ServerThread* server_thread_; // Unowned. +}; + +TEST_P(EndToEndTest, VersionNegotiationDowngradeAttackIsDetected) { + ParsedQuicVersion target_version = server_supported_versions_.back(); + if (!version_.UsesTls() || target_version == version_) { + ASSERT_TRUE(Initialize()); + return; + } + connect_to_server_on_initialize_ = false; + client_supported_versions_.insert(client_supported_versions_.begin(), + target_version); + ParsedQuicVersionVector downgrade_versions{version_}; + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(server_thread_); + // Pause the server thread to allow our DowngradePacketWriter to write version + // negotiation packets in a thread-safe manner. It will be resumed by the + // DowngradePacketWriter. + server_thread_->Pause(); + client_.reset(new QuicTestClient(server_address_, server_hostname_, + client_config_, client_supported_versions_, + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique())); + delete client_writer_; + client_writer_ = new DowngradePacketWriter(target_version, downgrade_versions, + client_.get(), server_writer_, + server_thread_.get()); + client_->UseWriter(client_writer_); + // Have the client attempt to send a request. + client_->Connect(); + EXPECT_TRUE(client_->SendSynchronousRequest("/foo").empty()); + // Make sure the downgrade is detected and the handshake fails. + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_FAILED)); +} + +// A bad header shouldn't tear down the connection, because the receiver can't +// tell the connection ID. +TEST_P(EndToEndTest, BadPacketHeaderTruncated) { + ASSERT_TRUE(Initialize()); + + // Start the connection. + SendSynchronousFooRequestAndCheckResponse(); + + // Packet with invalid public flags. + char packet[] = {// public flags (8 byte connection_id) + 0x3C, + // truncated connection ID + 0x11}; + client_writer_->WritePacket( + &packet[0], sizeof(packet), + client_->client()->network_helper()->GetLatestClientAddress().host(), + server_address_, nullptr); + EXPECT_TRUE(server_thread_->WaitUntil( + [&] { + return QuicDispatcherPeer::GetAndClearLastError( + QuicServerPeer::GetDispatcher(server_thread_->server())) == + QUIC_INVALID_PACKET_HEADER; + }, + QuicTime::Delta::FromSeconds(5))); + + // The connection should not be terminated. + SendSynchronousFooRequestAndCheckResponse(); +} + +// A bad header shouldn't tear down the connection, because the receiver can't +// tell the connection ID. +TEST_P(EndToEndTest, BadPacketHeaderFlags) { + ASSERT_TRUE(Initialize()); + + // Start the connection. + SendSynchronousFooRequestAndCheckResponse(); + + // Packet with invalid public flags. + uint8_t packet[] = { + // invalid public flags + 0xFF, + // connection_id + 0x10, + 0x32, + 0x54, + 0x76, + 0x98, + 0xBA, + 0xDC, + 0xFE, + // packet sequence number + 0xBC, + 0x9A, + 0x78, + 0x56, + 0x34, + 0x12, + // private flags + 0x00, + }; + client_writer_->WritePacket( + reinterpret_cast(packet), sizeof(packet), + client_->client()->network_helper()->GetLatestClientAddress().host(), + server_address_, nullptr); + + EXPECT_TRUE(server_thread_->WaitUntil( + [&] { + return QuicDispatcherPeer::GetAndClearLastError( + QuicServerPeer::GetDispatcher(server_thread_->server())) == + QUIC_INVALID_PACKET_HEADER; + }, + QuicTime::Delta::FromSeconds(5))); + + // The connection should not be terminated. + SendSynchronousFooRequestAndCheckResponse(); +} + +// Send a packet from the client with bad encrypted data. The server should not +// tear down the connection. +// Marked as slow since it calls absl::SleepFor(). +TEST_P(EndToEndTest, QUICHE_SLOW_TEST(BadEncryptedData)) { + ASSERT_TRUE(Initialize()); + + // Start the connection. + SendSynchronousFooRequestAndCheckResponse(); + + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + std::unique_ptr packet(ConstructEncryptedPacket( + client_connection->connection_id(), EmptyQuicConnectionId(), false, false, + 1, "At least 20 characters.", CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT, + PACKET_4BYTE_PACKET_NUMBER)); + // Damage the encrypted data. + std::string damaged_packet(packet->data(), packet->length()); + damaged_packet[30] ^= 0x01; + QUIC_DLOG(INFO) << "Sending bad packet."; + client_writer_->WritePacket( + damaged_packet.data(), damaged_packet.length(), + client_->client()->network_helper()->GetLatestClientAddress().host(), + server_address_, nullptr); + // Give the server time to process the packet. + absl::SleepFor(absl::Seconds(1)); + // This error is sent to the connection's OnError (which ignores it), so the + // dispatcher doesn't see it. + // Pause the server so we can access the server's internals without races. + server_thread_->Pause(); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + if (dispatcher != nullptr) { + EXPECT_THAT(QuicDispatcherPeer::GetAndClearLastError(dispatcher), + IsQuicNoError()); + } else { + ADD_FAILURE() << "Missing dispatcher"; + } + server_thread_->Resume(); + + // The connection should not be terminated. + SendSynchronousFooRequestAndCheckResponse(); +} + +TEST_P(EndToEndTest, CanceledStreamDoesNotBecomeZombie) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + // Lose the request. + SetPacketLossPercentage(100); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + client_->SendMessage(headers, "test_body", /*fin=*/false); + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + + // Cancel the stream. + stream->Reset(QUIC_STREAM_CANCELLED); + QuicSession* session = GetClientSession(); + ASSERT_TRUE(session); + // Verify canceled stream does not become zombie. + EXPECT_EQ(1u, QuicSessionPeer::closed_streams(session).size()); +} + +// A test stream that gives |response_body_| as an error response body. +class ServerStreamWithErrorResponseBody : public QuicSimpleServerStream { + public: + ServerStreamWithErrorResponseBody( + QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend, + std::string response_body) + : QuicSimpleServerStream(id, session, BIDIRECTIONAL, + quic_simple_server_backend), + response_body_(std::move(response_body)) {} + + ~ServerStreamWithErrorResponseBody() override = default; + + protected: + void SendErrorResponse() override { + QUIC_DLOG(INFO) << "Sending error response for stream " << id(); + Http2HeaderBlock headers; + headers[":status"] = "500"; + headers["content-length"] = absl::StrCat(response_body_.size()); + // This method must call CloseReadSide to cause the test case, StopReading + // is not sufficient. + QuicStreamPeer::CloseReadSide(this); + SendHeadersAndBody(std::move(headers), response_body_); + } + + std::string response_body_; +}; + +class StreamWithErrorFactory : public QuicTestServer::StreamFactory { + public: + explicit StreamWithErrorFactory(std::string response_body) + : response_body_(std::move(response_body)) {} + + ~StreamWithErrorFactory() override = default; + + QuicSimpleServerStream* CreateStream( + QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) override { + return new ServerStreamWithErrorResponseBody( + id, session, quic_simple_server_backend, response_body_); + } + + QuicSimpleServerStream* CreateStream( + PendingStream* /*pending*/, QuicSpdySession* /*session*/, + QuicSimpleServerBackend* /*response_cache*/) override { + return nullptr; + } + + private: + std::string response_body_; +}; + +// A test server stream that drops all received body. +class ServerStreamThatDropsBody : public QuicSimpleServerStream { + public: + ServerStreamThatDropsBody(QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSimpleServerStream(id, session, BIDIRECTIONAL, + quic_simple_server_backend) {} + + ~ServerStreamThatDropsBody() override = default; + + protected: + void OnBodyAvailable() override { + while (HasBytesToRead()) { + struct iovec iov; + if (GetReadableRegions(&iov, 1) == 0) { + // No more data to read. + break; + } + QUIC_DVLOG(1) << "Processed " << iov.iov_len << " bytes for stream " + << id(); + MarkConsumed(iov.iov_len); + } + + if (!sequencer()->IsClosed()) { + sequencer()->SetUnblocked(); + return; + } + + // If the sequencer is closed, then all the body, including the fin, has + // been consumed. + OnFinRead(); + + if (write_side_closed() || fin_buffered()) { + return; + } + + SendResponse(); + } +}; + +class ServerStreamThatDropsBodyFactory : public QuicTestServer::StreamFactory { + public: + ServerStreamThatDropsBodyFactory() = default; + + ~ServerStreamThatDropsBodyFactory() override = default; + + QuicSimpleServerStream* CreateStream( + QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) override { + return new ServerStreamThatDropsBody(id, session, + quic_simple_server_backend); + } + + QuicSimpleServerStream* CreateStream( + PendingStream* /*pending*/, QuicSpdySession* /*session*/, + QuicSimpleServerBackend* /*response_cache*/) override { + return nullptr; + } +}; + +// A test server stream that sends response with body size greater than 4GB. +class ServerStreamThatSendsHugeResponse : public QuicSimpleServerStream { + public: + ServerStreamThatSendsHugeResponse( + QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend, int64_t body_bytes) + : QuicSimpleServerStream(id, session, BIDIRECTIONAL, + quic_simple_server_backend), + body_bytes_(body_bytes) {} + + ~ServerStreamThatSendsHugeResponse() override = default; + + protected: + void SendResponse() override { + QuicBackendResponse response; + std::string body(body_bytes_, 'a'); + response.set_body(body); + SendHeadersAndBodyAndTrailers(response.headers().Clone(), response.body(), + response.trailers().Clone()); + } + + private: + // Use a explicit int64_t rather than size_t to simulate a 64-bit server + // talking to a 32-bit client. + int64_t body_bytes_; +}; + +class ServerStreamThatSendsHugeResponseFactory + : public QuicTestServer::StreamFactory { + public: + explicit ServerStreamThatSendsHugeResponseFactory(int64_t body_bytes) + : body_bytes_(body_bytes) {} + + ~ServerStreamThatSendsHugeResponseFactory() override = default; + + QuicSimpleServerStream* CreateStream( + QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) override { + return new ServerStreamThatSendsHugeResponse( + id, session, quic_simple_server_backend, body_bytes_); + } + + QuicSimpleServerStream* CreateStream( + PendingStream* /*pending*/, QuicSpdySession* /*session*/, + QuicSimpleServerBackend* /*response_cache*/) override { + return nullptr; + } + + int64_t body_bytes_; +}; + +class BlockedFrameObserver : public QuicConnectionDebugVisitor { + public: + std::vector blocked_frames() const { + return blocked_frames_; + } + + void OnBlockedFrame(const QuicBlockedFrame& frame) override { + blocked_frames_.push_back(frame); + } + + private: + std::vector blocked_frames_; +}; + +TEST_P(EndToEndTest, BlockedFrameIncludesOffset) { + if (!version_.HasIetfQuicFrames()) { + // For Google QUIC, the BLOCKED frame offset is ignored. + Initialize(); + return; + } + + set_smaller_flow_control_receive_window(); + ASSERT_TRUE(Initialize()); + + // Observe the connection for BLOCKED frames. + BlockedFrameObserver observer; + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + client_connection->set_debug_visitor(&observer); + + // Set the response body larger than the flow control window so the server + // must receive a window update from the client before it can finish sending + // it (hence, causing the server to send a BLOCKED frame) + uint32_t response_body_size = + client_config_.GetInitialSessionFlowControlWindowToSend() + 10; + std::string response_body(response_body_size, 'a'); + AddToCache("/blocked", 200, response_body); + SendSynchronousRequestAndCheckResponse("/blocked", response_body); + client_->Disconnect(); + + ASSERT_GE(observer.blocked_frames().size(), static_cast(0)); + for (const QuicBlockedFrame& frame : observer.blocked_frames()) { + if (frame.stream_id == + QuicUtils::GetInvalidStreamId(version_.transport_version)) { + // connection-level BLOCKED frame + ASSERT_EQ(frame.offset, + client_config_.GetInitialSessionFlowControlWindowToSend()); + } else { + // stream-level BLOCKED frame + ASSERT_EQ(frame.offset, + client_config_.GetInitialStreamFlowControlWindowToSend()); + } + } + + client_connection->set_debug_visitor(nullptr); +} + +TEST_P(EndToEndTest, EarlyResponseFinRecording) { + set_smaller_flow_control_receive_window(); + + // Verify that an incoming FIN is recorded in a stream object even if the read + // side has been closed. This prevents an entry from being made in + // locally_close_streams_highest_offset_ (which will never be deleted). + // To set up the test condition, the server must do the following in order: + // start sending the response and call CloseReadSide + // receive the FIN of the request + // send the FIN of the response + + // The response body must be larger than the flow control window so the server + // must receive a window update from the client before it can finish sending + // it. + uint32_t response_body_size = + 2 * client_config_.GetInitialStreamFlowControlWindowToSend(); + std::string response_body(response_body_size, 'a'); + + StreamWithErrorFactory stream_factory(response_body); + SetSpdyStreamFactory(&stream_factory); + + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // A POST that gets an early error response, after the headers are received + // and before the body is received, due to invalid content-length. + // Set an invalid content-length, so the request will receive an early 500 + // response. + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/garbage"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = "-1"; + + // The body must be large enough that the FIN will be in a different packet + // than the end of the headers, but short enough to not require a flow control + // update. This allows headers processing to trigger the error response + // before the request FIN is processed but receive the request FIN before the + // response is sent completely. + const uint32_t kRequestBodySize = kMaxOutgoingPacketSize + 10; + std::string request_body(kRequestBodySize, 'a'); + + // Send the request. + client_->SendMessage(headers, request_body); + client_->WaitForResponse(); + CheckResponseHeaders("500"); + + // Pause the server so we can access the server's internals without races. + server_thread_->Pause(); + + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + QuicSession* server_session = + QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher); + EXPECT_TRUE(server_session != nullptr); + + // The stream is not waiting for the arrival of the peer's final offset. + EXPECT_EQ( + 0u, QuicSessionPeer::GetLocallyClosedStreamsHighestOffset(server_session) + .size()); + + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, Trailers) { + // Test sending and receiving HTTP/2 Trailers (trailing HEADERS frames). + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + // Set reordering to ensure that Trailers arriving before body is ok. + SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); + SetReorderPercentage(30); + + // Add a response with headers, body, and trailers. + const std::string kBody = "body content"; + + Http2HeaderBlock headers; + headers[":status"] = "200"; + headers["content-length"] = absl::StrCat(kBody.size()); + + Http2HeaderBlock trailers; + trailers["some-trailing-header"] = "trailing-header-value"; + + memory_cache_backend_.AddResponse(server_hostname_, "/trailer_url", + std::move(headers), kBody, + trailers.Clone()); + + SendSynchronousRequestAndCheckResponse("/trailer_url", kBody); + EXPECT_EQ(trailers, client_->response_trailers()); +} + +// TODO(fayang): this test seems to cause net_unittests timeouts :| +TEST_P(EndToEndTest, DISABLED_TestHugePostWithPacketLoss) { + // This test tests a huge post with introduced packet loss from client to + // server and body size greater than 4GB, making sure QUIC code does not break + // for 32-bit builds. + ServerStreamThatDropsBodyFactory stream_factory; + SetSpdyStreamFactory(&stream_factory); + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + SetPacketLossPercentage(1); + // To avoid storing the whole request body in memory, use a loop to repeatedly + // send body size of kSizeBytes until the whole request body size is reached. + const int kSizeBytes = 128 * 1024; + // Request body size is 4G plus one more kSizeBytes. + int64_t request_body_size_bytes = pow(2, 32) + kSizeBytes; + ASSERT_LT(INT64_C(4294967296), request_body_size_bytes); + std::string body(kSizeBytes, 'a'); + + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["content-length"] = absl::StrCat(request_body_size_bytes); + + client_->SendMessage(headers, "", /*fin=*/false); + + for (int i = 0; i < request_body_size_bytes / kSizeBytes; ++i) { + bool fin = (i == request_body_size_bytes - 1); + client_->SendData(std::string(body.data(), kSizeBytes), fin); + client_->client()->WaitForEvents(); + } + VerifyCleanConnection(true); +} + +// TODO(fayang): this test seems to cause net_unittests timeouts :| +TEST_P(EndToEndTest, DISABLED_TestHugeResponseWithPacketLoss) { + // This test tests a huge response with introduced loss from server to client + // and body size greater than 4GB, making sure QUIC code does not break for + // 32-bit builds. + const int kSizeBytes = 128 * 1024; + int64_t response_body_size_bytes = pow(2, 32) + kSizeBytes; + ASSERT_LT(4294967296, response_body_size_bytes); + ServerStreamThatSendsHugeResponseFactory stream_factory( + response_body_size_bytes); + SetSpdyStreamFactory(&stream_factory); + + StartServer(); + + // Use a quic client that drops received body. + QuicTestClient* client = + new QuicTestClient(server_address_, server_hostname_, client_config_, + client_supported_versions_); + client->client()->set_drop_response_body(true); + client->UseWriter(client_writer_); + client->Connect(); + client_.reset(client); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + client_writer_->Initialize( + QuicConnectionPeer::GetHelper(client_connection), + QuicConnectionPeer::GetAlarmFactory(client_connection), + std::make_unique(client_->client())); + initialized_ = true; + ASSERT_TRUE(client_->client()->connected()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + SetPacketLossPercentage(1); + client_->SendRequest("/huge_response"); + client_->WaitForResponse(); + VerifyCleanConnection(true); +} + +// Regression test for b/111515567 +TEST_P(EndToEndTest, AgreeOnStopWaiting) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + // Verify client and server connections agree on the value of + // no_stop_waiting_frames. + EXPECT_EQ(QuicConnectionPeer::GetNoStopWaitingFrames(client_connection), + QuicConnectionPeer::GetNoStopWaitingFrames(server_connection)); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +// Regression test for b/111515567 +TEST_P(EndToEndTest, AgreeOnStopWaitingWithNoStopWaitingOption) { + QuicTagVector options; + options.push_back(kNSTP); + client_config_.SetConnectionOptionsToSend(options); + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + // Verify client and server connections agree on the value of + // no_stop_waiting_frames. + EXPECT_EQ(QuicConnectionPeer::GetNoStopWaitingFrames(client_connection), + QuicConnectionPeer::GetNoStopWaitingFrames(server_connection)); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ReleaseHeadersStreamBufferWhenIdle) { + // Tests that when client side has no active request and no waiting + // PUSH_PROMISE, its headers stream's sequencer buffer should be released. + ASSERT_TRUE(Initialize()); + client_->SendSynchronousRequest("/foo"); + if (version_.UsesHttp3()) { + return; + } + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicHeadersStream* headers_stream = + QuicSpdySessionPeer::GetHeadersStream(client_session); + ASSERT_TRUE(headers_stream); + QuicStreamSequencer* sequencer = QuicStreamPeer::sequencer(headers_stream); + ASSERT_TRUE(sequencer); + EXPECT_FALSE(QuicStreamSequencerPeer::IsUnderlyingBufferAllocated(sequencer)); +} + +// A single large header value causes a different error than the total size of +// headers exceeding a smaller limit, tested at EndToEndTest.LargeHeaders. +TEST_P(EndToEndTest, WayTooLongRequestHeaders) { + ASSERT_TRUE(Initialize()); + + Http2HeaderBlock headers; + headers[":method"] = "GET"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + headers["key"] = std::string(2 * 1024 * 1024, 'a'); + + client_->SendMessage(headers, ""); + client_->WaitForResponse(); + if (version_.UsesHttp3()) { + EXPECT_THAT(client_->connection_error(), + IsError(QUIC_QPACK_DECOMPRESSION_FAILED)); + } else { + EXPECT_THAT(client_->connection_error(), + IsError(QUIC_HPACK_VALUE_TOO_LONG)); + } +} + +class WindowUpdateObserver : public QuicConnectionDebugVisitor { + public: + WindowUpdateObserver() : num_window_update_frames_(0), num_ping_frames_(0) {} + + size_t num_window_update_frames() const { return num_window_update_frames_; } + + size_t num_ping_frames() const { return num_ping_frames_; } + + void OnWindowUpdateFrame(const QuicWindowUpdateFrame& /*frame*/, + const QuicTime& /*receive_time*/) override { + ++num_window_update_frames_; + } + + void OnPingFrame(const QuicPingFrame& /*frame*/, + const QuicTime::Delta /*ping_received_delay*/) override { + ++num_ping_frames_; + } + + private: + size_t num_window_update_frames_; + size_t num_ping_frames_; +}; + +TEST_P(EndToEndTest, WindowUpdateInAck) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + WindowUpdateObserver observer; + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + client_connection->set_debug_visitor(&observer); + // 100KB body. + std::string body(100 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + EXPECT_EQ(kFooResponseBody, + client_->SendCustomSynchronousRequest(headers, body)); + client_->Disconnect(); + EXPECT_LT(0u, observer.num_window_update_frames()); + EXPECT_EQ(0u, observer.num_ping_frames()); + client_connection->set_debug_visitor(nullptr); +} + +TEST_P(EndToEndTest, SendStatelessResetTokenInShlo) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicConfig* config = client_session->config(); + ASSERT_TRUE(config); + EXPECT_TRUE(config->HasReceivedStatelessResetToken()); + QuicConnection* client_connection = client_session->connection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(QuicUtils::GenerateStatelessResetToken( + client_connection->connection_id()), + config->ReceivedStatelessResetToken()); + client_->Disconnect(); +} + +// Regression test for b/116200989. +TEST_P(EndToEndTest, + SendStatelessResetIfServerConnectionClosedLocallyDuringHandshake) { + connect_to_server_on_initialize_ = false; + ASSERT_TRUE(Initialize()); + + ASSERT_TRUE(server_thread_); + server_thread_->Pause(); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + if (dispatcher == nullptr) { + ADD_FAILURE() << "Missing dispatcher"; + server_thread_->Resume(); + return; + } + if (dispatcher->NumSessions() > 0) { + ADD_FAILURE() << "Dispatcher session map not empty"; + server_thread_->Resume(); + return; + } + // Note: this writer will only used by the server connection, not the time + // wait list. + QuicDispatcherPeer::UseWriter( + dispatcher, + // This cause the first server-sent packet, a.k.a REJ, to fail. + new BadPacketWriter(/*packet_causing_write_error=*/0, EPERM)); + server_thread_->Resume(); + + client_.reset(CreateQuicClient(client_writer_)); + EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_FAILED)); +} + +// Regression test for b/116200989. +TEST_P(EndToEndTest, + SendStatelessResetIfServerConnectionClosedLocallyAfterHandshake) { + // Prevent the connection from expiring in the time wait list. + SetQuicFlag(quic_time_wait_list_seconds, 10000); + connect_to_server_on_initialize_ = false; + ASSERT_TRUE(Initialize()); + + // big_response_body is 64K, which is about 48 full-sized packets. + const size_t kBigResponseBodySize = 65536; + QuicData big_response_body(new char[kBigResponseBodySize](), + kBigResponseBodySize, /*owns_buffer=*/true); + AddToCache("/big_response", 200, big_response_body.AsStringPiece()); + + ASSERT_TRUE(server_thread_); + server_thread_->Pause(); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + if (dispatcher == nullptr) { + ADD_FAILURE() << "Missing dispatcher"; + server_thread_->Resume(); + return; + } + if (dispatcher->NumSessions() > 0) { + ADD_FAILURE() << "Dispatcher session map not empty"; + server_thread_->Resume(); + return; + } + QuicDispatcherPeer::UseWriter( + dispatcher, + // This will cause an server write error with EPERM, while sending the + // response for /big_response. + new BadPacketWriter(/*packet_causing_write_error=*/20, EPERM)); + server_thread_->Resume(); + + client_.reset(CreateQuicClient(client_writer_)); + + // First, a /foo request with small response should succeed. + SendSynchronousFooRequestAndCheckResponse(); + + // Second, a /big_response request with big response should fail. + EXPECT_LT(client_->SendSynchronousRequest("/big_response").length(), + kBigResponseBodySize); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_PUBLIC_RESET)); +} + +// Regression test of b/70782529. +TEST_P(EndToEndTest, DoNotCrashOnPacketWriteError) { + ASSERT_TRUE(Initialize()); + BadPacketWriter* bad_writer = + new BadPacketWriter(/*packet_causing_write_error=*/5, + /*error_code=*/90); + std::unique_ptr client(CreateQuicClient(bad_writer)); + + // 1 MB body. + std::string body(1024 * 1024, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + client->SendCustomSynchronousRequest(headers, body); +} + +// Regression test for b/71711996. This test sends a connectivity probing packet +// as its last sent packet, and makes sure the server's ACK of that packet does +// not cause the client to fail. +TEST_P(EndToEndTest, LastPacketSentIsConnectivityProbing) { + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + + // Wait for the client's ACK (of the response) to be received by the server. + client_->WaitForDelayedAcks(); + + // We are sending a connectivity probing packet from an unchanged client + // address, so the server will not respond to us with a connectivity probing + // packet, however the server should send an ack-only packet to us. + client_->SendConnectivityProbing(); + + // Wait for the server's last ACK to be received by the client. + client_->WaitForDelayedAcks(); +} + +TEST_P(EndToEndTest, PreSharedKey) { + client_config_.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(5)); + client_config_.set_max_idle_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(5)); + pre_shared_key_client_ = "foobar"; + pre_shared_key_server_ = "foobar"; + + if (version_.UsesTls()) { + // TODO(b/154162689) add PSK support to QUIC+TLS. + bool ok = true; + EXPECT_QUIC_BUG(ok = Initialize(), + "QUIC client pre-shared keys not yet supported with TLS"); + EXPECT_FALSE(ok); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); +} + +// TODO: reenable once we have a way to make this run faster. +TEST_P(EndToEndTest, QUIC_TEST_DISABLED_IN_CHROME(PreSharedKeyMismatch)) { + client_config_.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + client_config_.set_max_idle_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + pre_shared_key_client_ = "foo"; + pre_shared_key_server_ = "bar"; + + if (version_.UsesTls()) { + // TODO(b/154162689) add PSK support to QUIC+TLS. + bool ok = true; + EXPECT_QUIC_BUG(ok = Initialize(), + "QUIC client pre-shared keys not yet supported with TLS"); + EXPECT_FALSE(ok); + return; + } + + // One of two things happens when Initialize() returns: + // 1. Crypto handshake has completed, and it is unsuccessful. Initialize() + // returns false. + // 2. Crypto handshake has not completed, Initialize() returns true. The call + // to WaitForCryptoHandshakeConfirmed() will wait for the handshake and + // return whether it is successful. + ASSERT_FALSE(Initialize() && client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_TIMEOUT)); +} + +// TODO: reenable once we have a way to make this run faster. +TEST_P(EndToEndTest, QUIC_TEST_DISABLED_IN_CHROME(PreSharedKeyNoClient)) { + client_config_.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + client_config_.set_max_idle_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + pre_shared_key_server_ = "foobar"; + + if (version_.UsesTls()) { + // TODO(b/154162689) add PSK support to QUIC+TLS. + bool ok = true; + EXPECT_QUIC_BUG(ok = Initialize(), + "QUIC server pre-shared keys not yet supported with TLS"); + EXPECT_FALSE(ok); + return; + } + + ASSERT_FALSE(Initialize() && client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_TIMEOUT)); +} + +// TODO: reenable once we have a way to make this run faster. +TEST_P(EndToEndTest, QUIC_TEST_DISABLED_IN_CHROME(PreSharedKeyNoServer)) { + client_config_.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + client_config_.set_max_idle_time_before_crypto_handshake( + QuicTime::Delta::FromSeconds(1)); + pre_shared_key_client_ = "foobar"; + + if (version_.UsesTls()) { + // TODO(b/154162689) add PSK support to QUIC+TLS. + bool ok = true; + EXPECT_QUIC_BUG(ok = Initialize(), + "QUIC client pre-shared keys not yet supported with TLS"); + EXPECT_FALSE(ok); + return; + } + + ASSERT_FALSE(Initialize() && client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_TIMEOUT)); +} + +TEST_P(EndToEndTest, RequestAndStreamRstInOnePacket) { + // Regression test for b/80234898. + ASSERT_TRUE(Initialize()); + + // INCOMPLETE_RESPONSE will cause the server to not to send the trailer + // (and the FIN) after the response body. + std::string response_body(1305, 'a'); + Http2HeaderBlock response_headers; + response_headers[":status"] = absl::StrCat(200); + response_headers["content-length"] = absl::StrCat(response_body.length()); + memory_cache_backend_.AddSpecialResponse( + server_hostname_, "/test_url", std::move(response_headers), response_body, + QuicBackendResponse::INCOMPLETE_RESPONSE); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + client_->WaitForDelayedAcks(); + + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + const QuicPacketCount packets_sent_before = + client_connection->GetStats().packets_sent; + + client_->SendRequestAndRstTogether("/test_url"); + + // Expect exactly one packet is sent from the block above. + ASSERT_EQ(packets_sent_before + 1, + client_connection->GetStats().packets_sent); + + // Wait for the connection to become idle. + client_->WaitForDelayedAcks(); + + // The real expectation is the test does not crash or timeout. + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); +} + +TEST_P(EndToEndTest, ResetStreamOnTtlExpires) { + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + SetPacketLossPercentage(30); + + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + // Set a TTL which expires immediately. + stream->MaybeSetTtl(QuicTime::Delta::FromMicroseconds(1)); + + WriteHeadersOnStream(stream); + // 1 MB body. + std::string body(1024 * 1024, 'a'); + stream->WriteOrBufferBody(body, true); + client_->WaitForResponse(); + EXPECT_THAT(client_->stream_error(), IsStreamError(QUIC_STREAM_TTL_EXPIRED)); +} + +TEST_P(EndToEndTest, SendMessages) { + if (!version_.SupportsMessageFrames()) { + Initialize(); + return; + } + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + QuicSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicConnection* client_connection = client_session->connection(); + ASSERT_TRUE(client_connection); + + SetPacketLossPercentage(30); + ASSERT_GT(kMaxOutgoingPacketSize, + client_session->GetCurrentLargestMessagePayload()); + ASSERT_LT(0, client_session->GetCurrentLargestMessagePayload()); + + std::string message_string(kMaxOutgoingPacketSize, 'a'); + QuicRandom* random = + QuicConnectionPeer::GetHelper(client_connection)->GetRandomGenerator(); + { + QuicConnection::ScopedPacketFlusher flusher(client_session->connection()); + // Verify the largest message gets successfully sent. + EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, 1), + client_session->SendMessage(MemSliceFromString(absl::string_view( + message_string.data(), + client_session->GetCurrentLargestMessagePayload())))); + // Send more messages with size (0, largest_payload] until connection is + // write blocked. + const int kTestMaxNumberOfMessages = 100; + for (size_t i = 2; i <= kTestMaxNumberOfMessages; ++i) { + size_t message_length = + random->RandUint64() % + client_session->GetGuaranteedLargestMessagePayload() + + 1; + MessageResult result = client_session->SendMessage(MemSliceFromString( + absl::string_view(message_string.data(), message_length))); + if (result.status == MESSAGE_STATUS_BLOCKED) { + // Connection is write blocked. + break; + } + EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, i), result); + } + } + + client_->WaitForDelayedAcks(); + EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, + client_session + ->SendMessage(MemSliceFromString(absl::string_view( + message_string.data(), + client_session->GetCurrentLargestMessagePayload() + 1))) + .status); + EXPECT_THAT(client_->connection_error(), IsQuicNoError()); +} + +class EndToEndPacketReorderingTest : public EndToEndTest { + public: + void CreateClientWithWriter() override { + QUIC_LOG(ERROR) << "create client with reorder_writer_"; + reorder_writer_ = new PacketReorderingWriter(); + client_.reset(EndToEndTest::CreateQuicClient(reorder_writer_)); + } + + void SetUp() override { + // Don't initialize client writer in base class. + server_writer_ = new PacketDroppingTestWriter(); + } + + protected: + PacketReorderingWriter* reorder_writer_; +}; + +INSTANTIATE_TEST_SUITE_P(EndToEndPacketReorderingTests, + EndToEndPacketReorderingTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(EndToEndPacketReorderingTest, ReorderedConnectivityProbing) { + ASSERT_TRUE(Initialize()); + if (version_.HasIetfQuicFrames()) { + return; + } + + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress old_addr = + client_->client()->network_helper()->GetLatestClientAddress(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_addr.host(), new_host); + ASSERT_TRUE(client_->client()->MigrateSocket(new_host)); + + // Write a connectivity probing after the next /foo request. + reorder_writer_->SetDelay(1); + client_->SendConnectivityProbing(); + + ASSERT_TRUE(client_->MigrateSocketWithSpecifiedPort(old_addr.host(), + old_addr.port())); + + // The (delayed) connectivity probing will be sent after this request. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Send yet another request after the connectivity probing, when this request + // returns, the probing is guaranteed to have been received by the server, and + // the server's response to probing is guaranteed to have been received by the + // client. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(1u, + server_connection->GetStats().num_connectivity_probing_received); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); + + // Server definitely responded to the connectivity probing. Sometime it also + // sends a padded ping that is not a connectivity probing, which is recognized + // as connectivity probing because client's self address is ANY. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_LE(1u, + client_connection->GetStats().num_connectivity_probing_received); +} + +// A writer which holds the next packet to be sent till ReleasePacket() is +// called. +class PacketHoldingWriter : public QuicPacketWriterWrapper { + public: + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override { + if (!hold_next_packet_) { + return QuicPacketWriterWrapper::WritePacket(buffer, buf_len, self_address, + peer_address, options); + } + QUIC_DLOG(INFO) << "Packet is held by the writer"; + packet_content_ = std::string(buffer, buf_len); + self_address_ = self_address; + peer_address_ = peer_address; + options_ = (options == nullptr ? nullptr : options->Clone()); + hold_next_packet_ = false; + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + void HoldNextPacket() { + QUICHE_DCHECK(packet_content_.empty()) + << "There is already one packet on hold."; + hold_next_packet_ = true; + } + + void ReleasePacket() { + QUIC_DLOG(INFO) << "Release packet"; + ASSERT_EQ(WRITE_STATUS_OK, + QuicPacketWriterWrapper::WritePacket( + packet_content_.data(), packet_content_.length(), + self_address_, peer_address_, options_.release()) + .status); + packet_content_.clear(); + } + + private: + bool hold_next_packet_{false}; + std::string packet_content_; + QuicIpAddress self_address_; + QuicSocketAddress peer_address_; + std::unique_ptr options_; +}; + +TEST_P(EndToEndTest, ClientValidateNewNetwork) { + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames() || + !GetClientConnection()->validate_client_address()) { + return; + } + client_.reset(EndToEndTest::CreateQuicClient(nullptr)); + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress old_host = + client_->client()->network_helper()->GetLatestClientAddress().host(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_host, new_host); + + client_->client()->ValidateNewNetwork(new_host); + // Send a request using the old socket. + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + // Client should have received a PATH_CHALLENGE. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + + // Send another request to make sure THE server will receive PATH_RESPONSE. + client_->SendSynchronousRequest("/eep"); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(1u, + server_connection->GetStats().num_connectivity_probing_received); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ClientMultiPortConnection) { + client_config_.SetClientConnectionOptions(QuicTagVector{kMPQC}); + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_.reset(EndToEndTest::CreateQuicClient(nullptr)); + QuicConnection* client_connection = GetClientConnection(); + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + ASSERT_TRUE(stream); + // Increase the probing frequency to speed up this test. + client_connection->SetMultiPortProbingInterval( + QuicTime::Delta::FromMilliseconds(100)); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return 1u == client_connection->GetStats().num_path_response_received; + })); + // Verify that the alternative path keeps sending probes periodically. + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return 2u == client_connection->GetStats().num_path_response_received; + })); + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + // Verify that no migration has happened. + if (server_connection != nullptr) { + EXPECT_EQ(0u, server_connection->GetStats() + .num_peer_migration_to_proactively_validated_address); + } + server_thread_->Resume(); + + // This will cause the next periodic probing to fail. + server_writer_->set_fake_packet_loss_percentage(100); + EXPECT_TRUE(client_->WaitUntil( + 1000, [&]() { return client_->client()->HasPendingPathValidation(); })); + // Now wait for path validation to timeout. + EXPECT_TRUE(client_->WaitUntil( + 2000, [&]() { return !client_->client()->HasPendingPathValidation(); })); + server_writer_->set_fake_packet_loss_percentage(0); + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return 3u == client_connection->GetStats().num_path_response_received; + })); + // Verify that the previous path was retired. + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + stream->Reset(QuicRstStreamErrorCode::QUIC_STREAM_NO_ERROR); +} + +TEST_P(EndToEndTest, ClientMultiPortMigrationOnPathDegrading) { + client_config_.SetClientConnectionOptions(QuicTagVector{kMPQC}); + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_.reset(EndToEndTest::CreateQuicClient(nullptr)); + QuicConnection* client_connection = GetClientConnection(); + QuicSpdyClientStream* stream = client_->GetOrCreateStream(); + ASSERT_TRUE(stream); + // Increase the probing frequency to speed up this test. + client_connection->SetMultiPortProbingInterval( + QuicTime::Delta::FromMilliseconds(100)); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return 1u == client_connection->GetStats().num_path_response_received; + })); + // Verify that the alternative path keeps sending probes periodically. + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return 2u == client_connection->GetStats().num_path_response_received; + })); + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + // Verify that no migration has happened. + if (server_connection != nullptr) { + EXPECT_EQ(0u, server_connection->GetStats() + .num_peer_migration_to_proactively_validated_address); + } + server_thread_->Resume(); + + auto original_self_addr = client_connection->self_address(); + // Trigger client side path degrading + client_connection->OnPathDegradingDetected(); + EXPECT_NE(original_self_addr, client_connection->self_address()); + + // Send another request to trigger connection id retirement. + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + auto new_alt_path = QuicConnectionPeer::GetAlternativePath(client_connection); + EXPECT_NE(client_connection->self_address(), new_alt_path->self_address); + + stream->Reset(QuicRstStreamErrorCode::QUIC_STREAM_NO_ERROR); +} + +TEST_P(EndToEndTest, SimpleServerPreferredAddressTest) { + use_preferred_address_ = true; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_.reset(CreateQuicClient(nullptr)); + QuicConnection* client_connection = GetClientConnection(); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + EXPECT_EQ(server_address_, client_connection->effective_peer_address()); + EXPECT_EQ(server_address_, client_connection->peer_address()); + EXPECT_TRUE(client_->client()->HasPendingPathValidation()); + QuicConnectionId server_cid1 = client_connection->connection_id(); + + SendSynchronousFooRequestAndCheckResponse(); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(server_preferred_address_, + client_connection->effective_peer_address()); + EXPECT_EQ(server_preferred_address_, client_connection->peer_address()); + EXPECT_NE(server_cid1, client_connection->connection_id()); + + const auto client_stats = GetClientConnection()->GetStats(); + EXPECT_TRUE(client_stats.server_preferred_address_validated); + EXPECT_FALSE(client_stats.failed_to_validate_server_preferred_address); +} + +TEST_P(EndToEndTest, OptimizedServerPreferredAddress) { + use_preferred_address_ = true; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_config_.SetClientConnectionOptions(QuicTagVector{kSPA2}); + client_.reset(CreateQuicClient(nullptr)); + QuicConnection* client_connection = GetClientConnection(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_EQ(server_address_, client_connection->effective_peer_address()); + EXPECT_EQ(server_address_, client_connection->peer_address()); + EXPECT_TRUE(client_->client()->HasPendingPathValidation()); + SendSynchronousFooRequestAndCheckResponse(); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + + const auto client_stats = GetClientConnection()->GetStats(); + EXPECT_TRUE(client_stats.server_preferred_address_validated); + EXPECT_FALSE(client_stats.failed_to_validate_server_preferred_address); +} + +TEST_P(EndToEndPacketReorderingTest, ReorderedPathChallenge) { + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + client_.reset(EndToEndTest::CreateQuicClient(nullptr)); + + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress old_addr = + client_->client()->network_helper()->GetLatestClientAddress(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_addr.host(), new_host); + + // Setup writer wrapper to hold the probing packet. + auto holding_writer = new PacketHoldingWriter(); + client_->UseWriter(holding_writer); + // Write a connectivity probing after the next /foo request. + holding_writer->HoldNextPacket(); + + // A packet with PATH_CHALLENGE will be held in the writer. + client_->client()->ValidateNewNetwork(new_host); + + // Send (on-hold) PATH_CHALLENGE after this request. + client_->SendRequest("/foo"); + holding_writer->ReleasePacket(); + + client_->WaitForResponse(); + + EXPECT_EQ(kFooResponseBody, client_->response_body()); + // Send yet another request after the PATH_CHALLENGE, when this request + // returns, the probing is guaranteed to have been received by the server, and + // the server's response to probing is guaranteed to have been received by the + // client. + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + // Client should have received a PATH_CHALLENGE. + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(client_connection->validate_client_address() ? 1u : 0, + client_connection->GetStats().num_connectivity_probing_received); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(1u, + server_connection->GetStats().num_connectivity_probing_received); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndPacketReorderingTest, PathValidationFailure) { + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + + client_.reset(CreateQuicClient(nullptr)); + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress old_addr = client_->client()->session()->self_address(); + + // Migrate socket to the new IP address. + QuicIpAddress new_host = TestLoopback(2); + EXPECT_NE(old_addr.host(), new_host); + + // Drop PATH_RESPONSE packets to timeout the path validation. + server_writer_->set_fake_packet_loss_percentage(100); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(new_host)); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(old_addr, client_->client()->session()->self_address()); + server_writer_->set_fake_packet_loss_percentage(0); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(3u, + server_connection->GetStats().num_connectivity_probing_received); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndPacketReorderingTest, MigrateAgainAfterPathValidationFailure) { + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + + client_.reset(CreateQuicClient(nullptr)); + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress addr1 = client_->client()->session()->self_address(); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionId server_cid1 = client_connection->connection_id(); + + // Migrate socket to the new IP address. + QuicIpAddress host2 = TestLoopback(2); + EXPECT_NE(addr1.host(), host2); + + // Drop PATH_RESPONSE packets to timeout the path validation. + server_writer_->set_fake_packet_loss_percentage(100); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host2)); + + QuicConnectionId server_cid2 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid2, server_cid1); + // Wait until path validation fails at the client. + while (client_->client()->HasPendingPathValidation()) { + EXPECT_EQ(server_cid2, + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection)); + client_->client()->WaitForEvents(); + } + EXPECT_EQ(addr1, client_->client()->session()->self_address()); + EXPECT_EQ(server_cid1, GetClientConnection()->connection_id()); + + server_writer_->set_fake_packet_loss_percentage(0); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(0u, client_connection->GetStats().num_new_connection_id_sent); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + // Server has received 3 path challenges. + EXPECT_EQ(3u, + server_connection->GetStats().num_connectivity_probing_received); + EXPECT_EQ(server_cid1, server_connection->connection_id()); + EXPECT_EQ(0u, server_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); + + // Migrate socket to a new IP address again. + QuicIpAddress host3 = TestLoopback(3); + EXPECT_NE(addr1.host(), host3); + EXPECT_NE(host2, host3); + + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(0u, client_connection->GetStats().num_new_connection_id_sent); + + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host3)); + QuicConnectionId server_cid3 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host3, client_->client()->session()->self_address().host()); + EXPECT_EQ(server_cid3, GetClientConnection()->connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + // Server should send a new connection ID to client. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(0u, client_connection->GetStats().num_new_connection_id_sent); +} + +TEST_P(EndToEndPacketReorderingTest, + MigrateAgainAfterPathValidationFailureWithNonZeroClientCid) { + if (!version_.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + SetQuicReloadableFlag(quic_retire_cid_on_reverse_path_validation_failure, + true); + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + + client_.reset(CreateQuicClient(nullptr)); + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress addr1 = client_->client()->session()->self_address(); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionId server_cid1 = client_connection->connection_id(); + QuicConnectionId client_cid1 = client_connection->client_connection_id(); + + // Migrate socket to the new IP address. + QuicIpAddress host2 = TestLoopback(2); + EXPECT_NE(addr1.host(), host2); + + // Drop PATH_RESPONSE packets to timeout the path validation. + server_writer_->set_fake_packet_loss_percentage(100); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host2)); + QuicConnectionId server_cid2 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid2, server_cid1); + QuicConnectionId client_cid2 = + QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(client_cid2.IsEmpty()); + EXPECT_NE(client_cid2, client_cid1); + while (client_->client()->HasPendingPathValidation()) { + EXPECT_EQ(server_cid2, + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection)); + client_->client()->WaitForEvents(); + } + EXPECT_EQ(addr1, client_->client()->session()->self_address()); + EXPECT_EQ(server_cid1, GetClientConnection()->connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + server_writer_->set_fake_packet_loss_percentage(0); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, client_connection->GetStats().num_new_connection_id_sent); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(3u, + server_connection->GetStats().num_connectivity_probing_received); + EXPECT_EQ(server_cid1, server_connection->connection_id()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + EXPECT_EQ(1u, server_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); + + // Migrate socket to a new IP address again. + QuicIpAddress host3 = TestLoopback(3); + EXPECT_NE(addr1.host(), host3); + EXPECT_NE(host2, host3); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host3)); + + QuicConnectionId server_cid3 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + QuicConnectionId client_cid3 = + QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection); + EXPECT_NE(client_cid1, client_cid3); + EXPECT_NE(client_cid2, client_cid3); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host3, client_->client()->session()->self_address().host()); + EXPECT_EQ(server_cid3, GetClientConnection()->connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + // Server should send new server connection ID to client and retires old + // client connection ID. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(3u, client_connection->GetStats().num_new_connection_id_sent); +} + +TEST_P(EndToEndPacketReorderingTest, Buffer0RttRequest) { + ASSERT_TRUE(Initialize()); + // Finish one request to make sure handshake established. + client_->SendSynchronousRequest("/foo"); + // Disconnect for next 0-rtt request. + client_->Disconnect(); + + // Client has valid Session Ticket now. Do a 0-RTT request. + // Buffer a CHLO till the request is sent out. HTTP/3 sends two packets: a + // SETTINGS frame and a request. + reorder_writer_->SetDelay(version_.UsesHttp3() ? 2 : 1); + // Only send out a CHLO. + client_->client()->Initialize(); + + // Send a request before handshake finishes. + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/bar"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + client_->SendMessage(headers, ""); + client_->WaitForResponse(); + EXPECT_EQ(kBarResponseBody, client_->response_body()); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + QuicConnectionStats client_stats = client_connection->GetStats(); + EXPECT_EQ(0u, client_stats.packets_lost); + EXPECT_TRUE(client_->client()->EarlyDataAccepted()); +} + +TEST_P(EndToEndTest, SimpleStopSendingRstStreamTest) { + ASSERT_TRUE(Initialize()); + + // Send a request without a fin, to keep the stream open + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + client_->SendMessage(headers, "", /*fin=*/false); + // Stream should be open + ASSERT_NE(nullptr, client_->latest_created_stream()); + EXPECT_FALSE(client_->latest_created_stream()->write_side_closed()); + EXPECT_FALSE( + QuicStreamPeer::read_side_closed(client_->latest_created_stream())); + + // Send a RST_STREAM+STOP_SENDING on the stream + // Code is not important. + client_->latest_created_stream()->Reset(QUIC_BAD_APPLICATION_PAYLOAD); + client_->WaitForResponse(); + + // Stream should be gone. + ASSERT_EQ(nullptr, client_->latest_created_stream()); +} + +class BadShloPacketWriter : public QuicPacketWriterWrapper { + public: + BadShloPacketWriter(ParsedQuicVersion version) + : error_returned_(false), version_(version) {} + ~BadShloPacketWriter() override {} + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + quic::PerPacketOptions* options) override { + const WriteResult result = QuicPacketWriterWrapper::WritePacket( + buffer, buf_len, self_address, peer_address, options); + const uint8_t type_byte = buffer[0]; + if (!error_returned_ && (type_byte & FLAGS_LONG_HEADER) && + TypeByteIsServerHello(type_byte)) { + QUIC_DVLOG(1) << "Return write error for packet containing ServerHello"; + error_returned_ = true; + return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode()); + } + return result; + } + + bool TypeByteIsServerHello(uint8_t type_byte) { + if (version_.UsesV2PacketTypes()) { + return ((type_byte & 0x30) >> 4) == 3; + } + if (version_.UsesQuicCrypto()) { + // ENCRYPTION_ZERO_RTT packet. + return ((type_byte & 0x30) >> 4) == 1; + } + // ENCRYPTION_HANDSHAKE packet. + return ((type_byte & 0x30) >> 4) == 2; + } + + private: + bool error_returned_; + ParsedQuicVersion version_; +}; + +TEST_P(EndToEndTest, ConnectionCloseBeforeHandshakeComplete) { + if (!version_.HasIetfInvariantHeader()) { + // Only runs for IETF QUIC header. + Initialize(); + return; + } + // This test ensures ZERO_RTT_PROTECTED connection close could close a client + // which has switched to forward secure. + connect_to_server_on_initialize_ = false; + ASSERT_TRUE(Initialize()); + server_thread_->Pause(); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + if (dispatcher == nullptr) { + ADD_FAILURE() << "Missing dispatcher"; + server_thread_->Resume(); + return; + } + if (dispatcher->NumSessions() > 0) { + ADD_FAILURE() << "Dispatcher session map not empty"; + server_thread_->Resume(); + return; + } + // Note: this writer will only used by the server connection, not the time + // wait list. + QuicDispatcherPeer::UseWriter( + dispatcher, + // This causes the first server sent ZERO_RTT_PROTECTED packet (i.e., + // SHLO) to be sent, but WRITE_ERROR is returned. Such that a + // ZERO_RTT_PROTECTED connection close would be sent to a client with + // encryption level FORWARD_SECURE. + new BadShloPacketWriter(version_)); + server_thread_->Resume(); + + client_.reset(CreateQuicClient(client_writer_)); + EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); + // Verify ZERO_RTT_PROTECTED connection close is successfully processed by + // client which switches to FORWARD_SECURE. + EXPECT_THAT(client_->connection_error(), IsError(QUIC_PACKET_WRITE_ERROR)); +} + +class BadShloPacketWriter2 : public QuicPacketWriterWrapper { + public: + BadShloPacketWriter2(ParsedQuicVersion version) + : error_returned_(false), version_(version) {} + ~BadShloPacketWriter2() override {} + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + quic::PerPacketOptions* options) override { + const uint8_t type_byte = buffer[0]; + + if (type_byte & FLAGS_LONG_HEADER) { + if (((type_byte & 0x30 >> 4) == (version_.UsesV2PacketTypes() ? 2 : 1)) || + ((type_byte & 0x7F) == 0x7C)) { + QUIC_DVLOG(1) << "Dropping ZERO_RTT_PACKET packet"; + return WriteResult(WRITE_STATUS_OK, buf_len); + } + } else if (!error_returned_) { + QUIC_DVLOG(1) << "Return write error for short header packet"; + error_returned_ = true; + return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode()); + } + return QuicPacketWriterWrapper::WritePacket(buffer, buf_len, self_address, + peer_address, options); + } + + private: + bool error_returned_; + ParsedQuicVersion version_; +}; + +TEST_P(EndToEndTest, ForwardSecureConnectionClose) { + // This test ensures ZERO_RTT_PROTECTED connection close is sent to a client + // which has ZERO_RTT_PROTECTED encryption level. + connect_to_server_on_initialize_ = !version_.HasIetfInvariantHeader(); + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfInvariantHeader()) { + // Only runs for IETF QUIC header. + return; + } + server_thread_->Pause(); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + if (dispatcher == nullptr) { + ADD_FAILURE() << "Missing dispatcher"; + server_thread_->Resume(); + return; + } + if (dispatcher->NumSessions() > 0) { + ADD_FAILURE() << "Dispatcher session map not empty"; + server_thread_->Resume(); + return; + } + // Note: this writer will only used by the server connection, not the time + // wait list. + QuicDispatcherPeer::UseWriter( + dispatcher, + // This causes the all server sent ZERO_RTT_PROTECTED packets to be + // dropped, and first short header packet causes write error. + new BadShloPacketWriter2(version_)); + server_thread_->Resume(); + client_.reset(CreateQuicClient(client_writer_)); + EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); + // Verify ZERO_RTT_PROTECTED connection close is successfully processed by + // client. + EXPECT_THAT(client_->connection_error(), IsError(QUIC_PACKET_WRITE_ERROR)); +} + +// Test that the stream id manager closes the connection if a stream +// in excess of the allowed maximum. +TEST_P(EndToEndTest, TooBigStreamIdClosesConnection) { + // Has to be before version test, see EndToEndTest::TearDown() + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + // Only runs for IETF QUIC. + return; + } + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + std::string body(kMaxOutgoingPacketSize, 'a'); + Http2HeaderBlock headers; + headers[":method"] = "POST"; + headers[":path"] = "/foo"; + headers[":scheme"] = "https"; + headers[":authority"] = server_hostname_; + + // Force the client to write with a stream ID that exceeds the limit. + QuicSpdySession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicStreamIdManager* stream_id_manager = + QuicSessionPeer::ietf_bidirectional_stream_id_manager(client_session); + ASSERT_TRUE(stream_id_manager); + QuicStreamCount max_number_of_streams = + stream_id_manager->outgoing_max_streams(); + QuicSessionPeer::SetNextOutgoingBidirectionalStreamId( + client_session, + GetNthClientInitiatedBidirectionalId(max_number_of_streams + 1)); + client_->SendCustomSynchronousRequest(headers, body); + EXPECT_THAT(client_->stream_error(), + IsStreamError(QUIC_STREAM_CONNECTION_ERROR)); + EXPECT_THAT(client_session->error(), IsError(QUIC_INVALID_STREAM_ID)); + EXPECT_EQ(IETF_QUIC_TRANSPORT_CONNECTION_CLOSE, client_session->close_type()); + EXPECT_TRUE( + IS_IETF_STREAM_FRAME(client_session->transport_close_frame_type())); +} + +TEST_P(EndToEndTest, CustomTransportParameters) { + if (!version_.UsesTls()) { + // Custom transport parameters are only supported with TLS. + ASSERT_TRUE(Initialize()); + return; + } + constexpr auto kCustomParameter = + static_cast(0xff34); + client_config_.custom_transport_parameters_to_send()[kCustomParameter] = + "test"; + NiceMock visitor; + connection_debug_visitor_ = &visitor; + EXPECT_CALL(visitor, OnTransportParametersSent(_)) + .WillOnce(Invoke([kCustomParameter]( + const TransportParameters& transport_parameters) { + ASSERT_NE(transport_parameters.custom_parameters.find(kCustomParameter), + transport_parameters.custom_parameters.end()); + EXPECT_EQ(transport_parameters.custom_parameters.at(kCustomParameter), + "test"); + })); + EXPECT_CALL(visitor, OnTransportParametersReceived(_)).Times(1); + ASSERT_TRUE(Initialize()); + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + QuicConfig* server_config = nullptr; + if (server_session != nullptr) { + server_config = server_session->config(); + } else { + ADD_FAILURE() << "Missing server session"; + } + if (server_config != nullptr) { + if (server_config->received_custom_transport_parameters().find( + kCustomParameter) != + server_config->received_custom_transport_parameters().end()) { + EXPECT_EQ(server_config->received_custom_transport_parameters().at( + kCustomParameter), + "test"); + } else { + ADD_FAILURE() << "Did not find custom parameter"; + } + } else { + ADD_FAILURE() << "Missing server config"; + } + server_thread_->Resume(); +} + +// Testing packet writer that makes a copy of the first sent packets before +// sending them. Useful for tests that need access to sent packets. +class CopyingPacketWriter : public PacketDroppingTestWriter { + public: + explicit CopyingPacketWriter(int num_packets_to_copy) + : num_packets_to_copy_(num_packets_to_copy) {} + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override { + if (num_packets_to_copy_ > 0) { + num_packets_to_copy_--; + packets_.push_back( + QuicEncryptedPacket(buffer, buf_len, /*owns_buffer=*/false).Clone()); + } + return PacketDroppingTestWriter::WritePacket(buffer, buf_len, self_address, + peer_address, options); + } + + std::vector>& packets() { + return packets_; + } + + private: + int num_packets_to_copy_; + std::vector> packets_; +}; + +TEST_P(EndToEndTest, KeyUpdateInitiatedByClient) { + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(0u, client_connection->GetStats().key_update_count); + + EXPECT_TRUE( + client_connection->InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + EXPECT_TRUE( + client_connection->InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(2u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ(2u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, KeyUpdateInitiatedByServer) { + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(0u, client_connection->GetStats().key_update_count); + + // Use WaitUntil to ensure the server had executed the key update predicate + // before sending the Foo request, otherwise the test can be flaky if it + // receives the Foo request before executing the key update. + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + if (!server_connection->IsKeyUpdateAllowed()) { + // Server may not have received ack from client yet for the current + // key phase, wait a bit and try again. + return false; + } + EXPECT_TRUE(server_connection->InitiateKeyUpdate( + KeyUpdateReason::kLocalForTests)); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + if (!server_connection->IsKeyUpdateAllowed()) { + return false; + } + EXPECT_TRUE(server_connection->InitiateKeyUpdate( + KeyUpdateReason::kLocalForTests)); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(2u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ(2u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, KeyUpdateInitiatedByBoth) { + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + SendSynchronousFooRequestAndCheckResponse(); + + // Use WaitUntil to ensure the server had executed the key update predicate + // before the client sends the Foo request, otherwise the Foo request from + // the client could trigger the server key update before the server can + // initiate the key update locally. That would mean the test is no longer + // hitting the intended test state of both sides locally initiating a key + // update before receiving a packet in the new key phase from the other side. + // Additionally the test would fail since InitiateKeyUpdate() would not allow + // to do another key update yet and return false. + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + if (!server_connection->IsKeyUpdateAllowed()) { + // Server may not have received ack from client yet for the current + // key phase, wait a bit and try again. + return false; + } + EXPECT_TRUE(server_connection->InitiateKeyUpdate( + KeyUpdateReason::kLocalForTests)); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_TRUE( + client_connection->InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(1u, client_connection->GetStats().key_update_count); + + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + if (!server_connection->IsKeyUpdateAllowed()) { + return false; + } + EXPECT_TRUE(server_connection->InitiateKeyUpdate( + KeyUpdateReason::kLocalForTests)); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + EXPECT_TRUE( + client_connection->InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_EQ(2u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_EQ(2u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, KeyUpdateInitiatedByConfidentialityLimit) { + SetQuicFlag(quic_key_update_confidentiality_limit, 16U); + + if (!version_.UsesTls()) { + // Key Update is only supported in TLS handshake. + ASSERT_TRUE(Initialize()); + return; + } + + ASSERT_TRUE(Initialize()); + + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection); + EXPECT_EQ(0u, client_connection->GetStats().key_update_count); + + server_thread_->WaitUntil( + [this]() { + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(0u, server_connection->GetStats().key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + return true; + }, + QuicTime::Delta::FromSeconds(5)); + + for (uint64_t i = 0; i < GetQuicFlag(quic_key_update_confidentiality_limit); + ++i) { + SendSynchronousFooRequestAndCheckResponse(); + } + + // Don't know exactly how many packets will be sent in each request/response, + // so just test that at least one key update occurred. + EXPECT_LE(1u, client_connection->GetStats().key_update_count); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection) { + QuicConnectionStats server_stats = server_connection->GetStats(); + EXPECT_LE(1u, server_stats.key_update_count); + } else { + ADD_FAILURE() << "Missing server connection"; + } + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, TlsResumptionEnabledOnTheFly) { + SetQuicFlag(quic_disable_server_tls_resumption, true); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesTls()) { + // This test is TLS specific. + return; + } + + // Send the first request. Client should not have a resumption ticket. + SendSynchronousFooRequestAndCheckResponse(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_EQ(client_session->GetCryptoStream()->EarlyDataReason(), + ssl_early_data_no_session_offered); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + client_->Disconnect(); + + SetQuicFlag(quic_disable_server_tls_resumption, false); + + // Send the second request. Client should still have no resumption ticket, but + // it will receive one which can be used by the next request. + client_->Connect(); + SendSynchronousFooRequestAndCheckResponse(); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_EQ(client_session->GetCryptoStream()->EarlyDataReason(), + ssl_early_data_no_session_offered); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + client_->Disconnect(); + + // Send the third request in 0RTT. + client_->Connect(); + SendSynchronousFooRequestAndCheckResponse(); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + client_->Disconnect(); +} + +TEST_P(EndToEndTest, TlsResumptionDisabledOnTheFly) { + SetQuicFlag(quic_disable_server_tls_resumption, false); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesTls()) { + // This test is TLS specific. + return; + } + + // Send the first request and then disconnect. + SendSynchronousFooRequestAndCheckResponse(); + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + client_->Disconnect(); + + // Send the second request in 0RTT. + client_->Connect(); + SendSynchronousFooRequestAndCheckResponse(); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_TRUE(client_session->EarlyDataAccepted()); + client_->Disconnect(); + + SetQuicFlag(quic_disable_server_tls_resumption, true); + + // Send the third request. The client should try resumption but server should + // decline it. + client_->Connect(); + SendSynchronousFooRequestAndCheckResponse(); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_EQ(client_session->GetCryptoStream()->EarlyDataReason(), + ssl_early_data_session_not_resumed); + client_->Disconnect(); + + // Keep sending until the client runs out of resumption tickets. + for (int i = 0; i < 10; ++i) { + client_->Connect(); + SendSynchronousFooRequestAndCheckResponse(); + + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + const auto early_data_reason = + client_session->GetCryptoStream()->EarlyDataReason(); + client_->Disconnect(); + + if (early_data_reason != ssl_early_data_session_not_resumed) { + EXPECT_EQ(early_data_reason, ssl_early_data_unsupported_for_session); + return; + } + } + + ADD_FAILURE() << "Client should not have 10 resumption tickets."; +} + +TEST_P(EndToEndTest, WebTransportSessionSetup) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* web_transport = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_NE(web_transport, nullptr); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + EXPECT_TRUE(server_session->GetWebTransportSession(web_transport->id()) != + nullptr); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, WebTransportSessionSetupWithEchoWithSuffix) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + // "/echoFoo" should be accepted as "echo" with "set-header" query. + WebTransportHttp3* web_transport = CreateWebTransportSession( + "/echoFoo?set-header=bar:baz", /*wait_for_server_response=*/true); + ASSERT_NE(web_transport, nullptr); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + EXPECT_TRUE(server_session->GetWebTransportSession(web_transport->id()) != + nullptr); + server_thread_->Resume(); + const spdy::Http2HeaderBlock* response_headers = client_->response_headers(); + auto it = response_headers->find("bar"); + EXPECT_NE(it, response_headers->end()); + EXPECT_EQ(it->second, "baz"); +} + +TEST_P(EndToEndTest, WebTransportSessionWithLoss) { + enable_web_transport_ = true; + // Enable loss to verify all permutations of receiving SETTINGS and + // request/response data. + SetPacketLossPercentage(30); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* web_transport = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_NE(web_transport, nullptr); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + EXPECT_TRUE(server_session->GetWebTransportSession(web_transport->id()) != + nullptr); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, WebTransportSessionUnidirectionalStream) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + WebTransportStream* outgoing_stream = + session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(outgoing_stream != nullptr); + EXPECT_EQ(outgoing_stream, + session->GetStreamById(outgoing_stream->GetStreamId())); + + auto stream_visitor = + std::make_unique>(); + bool data_acknowledged = false; + EXPECT_CALL(*stream_visitor, OnWriteSideInDataRecvdState()) + .WillOnce(Assign(&data_acknowledged, true)); + outgoing_stream->SetVisitor(std::move(stream_visitor)); + + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*outgoing_stream, "test")); + EXPECT_TRUE(outgoing_stream->SendFin()); + + bool stream_received = false; + EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable()) + .WillOnce(Assign(&stream_received, true)); + client_->WaitUntil(2000, [&stream_received]() { return stream_received; }); + EXPECT_TRUE(stream_received); + WebTransportStream* received_stream = + session->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(received_stream != nullptr); + EXPECT_EQ(received_stream, + session->GetStreamById(received_stream->GetStreamId())); + std::string received_data; + WebTransportStream::ReadResult result = received_stream->Read(&received_data); + EXPECT_EQ(received_data, "test"); + EXPECT_TRUE(result.fin); + + client_->WaitUntil(2000, + [&data_acknowledged]() { return data_acknowledged; }); + EXPECT_TRUE(data_acknowledged); +} + +TEST_P(EndToEndTest, WebTransportSessionUnidirectionalStreamSentEarly) { + enable_web_transport_ = true; + SetPacketLossPercentage(30); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/false); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + WebTransportStream* outgoing_stream = + session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(outgoing_stream != nullptr); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*outgoing_stream, "test")); + EXPECT_TRUE(outgoing_stream->SendFin()); + + bool stream_received = false; + EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable()) + .WillOnce(Assign(&stream_received, true)); + client_->WaitUntil(5000, [&stream_received]() { return stream_received; }); + EXPECT_TRUE(stream_received); + WebTransportStream* received_stream = + session->AcceptIncomingUnidirectionalStream(); + ASSERT_TRUE(received_stream != nullptr); + std::string received_data; + WebTransportStream::ReadResult result = received_stream->Read(&received_data); + EXPECT_EQ(received_data, "test"); + EXPECT_TRUE(result.fin); +} + +TEST_P(EndToEndTest, WebTransportSessionBidirectionalStream) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_EQ(stream, session->GetStreamById(stream->GetStreamId())); + + auto stream_visitor_owned = + std::make_unique>(); + MockWebTransportStreamVisitor* stream_visitor = stream_visitor_owned.get(); + bool data_acknowledged = false; + EXPECT_CALL(*stream_visitor, OnWriteSideInDataRecvdState()) + .WillOnce(Assign(&data_acknowledged, true)); + stream->SetVisitor(std::move(stream_visitor_owned)); + + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + EXPECT_TRUE(stream->SendFin()); + + std::string received_data = + ReadDataFromWebTransportStreamUntilFin(stream, stream_visitor); + EXPECT_EQ(received_data, "test"); + + client_->WaitUntil(2000, + [&data_acknowledged]() { return data_acknowledged; }); + EXPECT_TRUE(data_acknowledged); +} + +TEST_P(EndToEndTest, WebTransportSessionBidirectionalStreamWithBuffering) { + enable_web_transport_ = true; + SetPacketLossPercentage(30); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/false); + ASSERT_TRUE(session != nullptr); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + EXPECT_TRUE(stream->SendFin()); + + std::string received_data = ReadDataFromWebTransportStreamUntilFin(stream); + EXPECT_EQ(received_data, "test"); +} + +TEST_P(EndToEndTest, WebTransportSessionServerBidirectionalStream) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/false); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + bool stream_received = false; + EXPECT_CALL(visitor, OnIncomingBidirectionalStreamAvailable()) + .WillOnce(Assign(&stream_received, true)); + client_->WaitUntil(5000, [&stream_received]() { return stream_received; }); + EXPECT_TRUE(stream_received); + + WebTransportStream* stream = session->AcceptIncomingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + // Test the full Writev() API. + const std::string kLongString = std::string(16 * 1024, 'a'); + std::vector write_vector = {"foo", "bar", "test", + kLongString}; + quiche::StreamWriteOptions options; + options.set_send_fin(true); + QUICHE_EXPECT_OK(stream->Writev(absl::MakeConstSpan(write_vector), options)); + + std::string received_data = ReadDataFromWebTransportStreamUntilFin(stream); + EXPECT_EQ(received_data, absl::StrCat("foobartest", kLongString)); +} + +TEST_P(EndToEndTest, WebTransportDatagrams) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + quiche::SimpleBufferAllocator allocator; + for (int i = 0; i < 10; i++) { + session->SendOrQueueDatagram("test"); + } + + int received = 0; + EXPECT_CALL(visitor, OnDatagramReceived(_)).WillRepeatedly([&received]() { + received++; + }); + client_->WaitUntil(5000, [&received]() { return received > 0; }); + EXPECT_GT(received, 0); +} + +TEST_P(EndToEndTest, WebTransportSessionClose) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + // Keep stream open. + + bool close_received = false; + EXPECT_CALL(visitor, OnSessionClosed(42, "test error")) + .WillOnce(Assign(&close_received, true)); + session->CloseSession(42, "test error"); + client_->WaitUntil(2000, [&]() { return close_received; }); + EXPECT_TRUE(close_received); + + QuicSpdyStream* spdy_stream = + GetClientSession()->GetOrCreateSpdyDataStream(stream_id); + EXPECT_TRUE(spdy_stream == nullptr); +} + +TEST_P(EndToEndTest, WebTransportSessionCloseWithoutCapsule) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + // Keep stream open. + + bool close_received = false; + EXPECT_CALL(visitor, OnSessionClosed(0, "")) + .WillOnce(Assign(&close_received, true)); + session->CloseSessionWithFinOnlyForTests(); + client_->WaitUntil(2000, [&]() { return close_received; }); + EXPECT_TRUE(close_received); + + QuicSpdyStream* spdy_stream = + GetClientSession()->GetOrCreateSpdyDataStream(stream_id); + EXPECT_TRUE(spdy_stream == nullptr); +} + +TEST_P(EndToEndTest, WebTransportSessionReceiveClose) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = CreateWebTransportSession( + "/session-close", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = + SetupWebTransportVisitor(session); + + WebTransportStream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "42 test error")); + EXPECT_TRUE(stream->SendFin()); + + // Have some other streams open pending, to ensure they are closed properly. + stream = session->OpenOutgoingUnidirectionalStream(); + stream = session->OpenOutgoingBidirectionalStream(); + + bool close_received = false; + EXPECT_CALL(visitor, OnSessionClosed(42, "test error")) + .WillOnce(Assign(&close_received, true)); + client_->WaitUntil(2000, [&]() { return close_received; }); + EXPECT_TRUE(close_received); + + QuicSpdyStream* spdy_stream = + GetClientSession()->GetOrCreateSpdyDataStream(stream_id); + EXPECT_TRUE(spdy_stream == nullptr); +} + +TEST_P(EndToEndTest, WebTransportSessionStreamTermination) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/resets", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + + NiceMock& visitor = + SetupWebTransportVisitor(session); + EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable()) + .WillRepeatedly([this, session]() { + ReadAllIncomingWebTransportUnidirectionalStreams(session); + }); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + QuicStreamId id1 = stream->GetStreamId(); + ASSERT_TRUE(stream != nullptr); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + stream->ResetWithUserCode(42); + + // This read fails if the stream is closed in both directions, since that + // results in stream object being deleted. + std::string received_data = ReadDataFromWebTransportStreamUntilFin(stream); + EXPECT_LE(received_data.size(), 4u); + + stream = session->OpenOutgoingBidirectionalStream(); + QuicStreamId id2 = stream->GetStreamId(); + ASSERT_TRUE(stream != nullptr); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + stream->SendStopSending(24); + + std::array expected_log = { + absl::StrCat("Received reset for stream ", id1, " with error code 42"), + absl::StrCat("Received stop sending for stream ", id2, + " with error code 24"), + }; + client_->WaitUntil(2000, [this, &expected_log]() { + return received_webtransport_unidirectional_streams_.size() >= + expected_log.size(); + }); + EXPECT_THAT(received_webtransport_unidirectional_streams_, + UnorderedElementsAreArray(expected_log)); + + // Since we closed the read side, cleanly closing the write side should result + // in the stream getting deleted. + ASSERT_TRUE(GetClientSession()->GetOrCreateSpdyDataStream(id2) != nullptr); + EXPECT_TRUE(stream->SendFin()); + EXPECT_TRUE(client_->WaitUntil(2000, [this, id2]() { + return GetClientSession()->GetOrCreateSpdyDataStream(id2) == nullptr; + })); +} + +// This test currently does not pass; we need support for +// https://datatracker.ietf.org/doc/draft-seemann-quic-reliable-stream-reset/ in +// order to make this work. +TEST_P(EndToEndTest, DISABLED_WebTransportSessionResetReliability) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + SetPacketLossPercentage(30); + + WebTransportHttp3* session = + CreateWebTransportSession("/resets", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + + NiceMock& visitor = + SetupWebTransportVisitor(session); + EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable()) + .WillRepeatedly([this, session]() { + ReadAllIncomingWebTransportUnidirectionalStreams(session); + }); + + std::vector expected_log; + constexpr int kStreamsToCreate = 10; + for (int i = 0; i < kStreamsToCreate; i++) { + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + QuicStreamId id = stream->GetStreamId(); + ASSERT_TRUE(stream != nullptr); + stream->ResetWithUserCode(42); + + expected_log.push_back( + absl::StrCat("Received reset for stream ", id, " with error code 42")); + } + client_->WaitUntil(2000, [this, &expected_log]() { + return received_webtransport_unidirectional_streams_.size() >= + expected_log.size(); + }); + EXPECT_THAT(received_webtransport_unidirectional_streams_, + UnorderedElementsAreArray(expected_log)); +} + +TEST_P(EndToEndTest, WebTransportSession404) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = CreateWebTransportSession( + "/does-not-exist", /*wait_for_server_response=*/false); + ASSERT_TRUE(session != nullptr); + QuicSpdyStream* connect_stream = client_->latest_created_stream(); + QuicStreamId connect_stream_id = connect_stream->id(); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QUICHE_EXPECT_OK(quiche::WriteIntoStream(*stream, "test")); + EXPECT_TRUE(stream->SendFin()); + + EXPECT_TRUE(client_->WaitUntil(-1, [this, connect_stream_id]() { + return GetClientSession()->GetOrCreateSpdyDataStream(connect_stream_id) == + nullptr; + })); +} + +TEST_P(EndToEndTest, InvalidExtendedConnect) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + // Missing :path header. + spdy::Http2HeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "webtransport"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + // An early response should be received. + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, RejectExtendedConnect) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // Disable extended CONNECT. + memory_cache_backend_.set_enable_extended_connect(false); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + // This extended CONNECT should be rejected. + spdy::Http2HeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "CONNECT"; + headers[":path"] = "/echo"; + headers[":protocol"] = "webtransport"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); + + // Vanilla CONNECT should be sent to backend. + spdy::Http2HeaderBlock headers2; + headers2[":authority"] = "localhost"; + headers2[":method"] = "CONNECT"; + + // Backend not configured/implemented to fully handle CONNECT requests, so + // expect it to send a 405. + client_->SendMessage(headers2, "body", /*fin=*/true); + client_->WaitForResponse(); + CheckResponseHeaders("405"); +} + +TEST_P(EndToEndTest, RejectInvalidRequestHeader) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + spdy::Http2HeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "GET"; + headers[":path"] = "/echo"; + // transfer-encoding header is not allowed. + headers["transfer-encoding"] = "chunk"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, RejectTransferEncodingResponse) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + // Add a response with transfer-encoding headers. + Http2HeaderBlock headers; + headers[":status"] = "200"; + headers["transfer-encoding"] = "gzip"; + + Http2HeaderBlock trailers; + trailers["some-trailing-header"] = "trailing-header-value"; + + memory_cache_backend_.AddResponse(server_hostname_, "/eep", + std::move(headers), "", trailers.Clone()); + + std::string received_response = client_->SendSynchronousRequest("/eep"); + EXPECT_THAT(client_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(EndToEndTest, RejectUpperCaseRequest) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + spdy::Http2HeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "GET"; + headers[":path"] = "/echo"; + headers["UpperCaseHeader"] = "foo"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, RejectRequestWithInvalidToken) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + spdy::Http2HeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "GET"; + headers[":path"] = "/echo"; + headers["invalid,header"] = "foo"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, OriginalConnectionIdClearedFromMap) { + connect_to_server_on_initialize_ = false; + ASSERT_TRUE(Initialize()); + if (override_client_connection_id_length_ != kLongConnectionIdLength) { + // There might not be an original connection ID. + CreateClientWithWriter(); + return; + } + + server_thread_->Pause(); + QuicDispatcher* dispatcher = + QuicServerPeer::GetDispatcher(server_thread_->server()); + EXPECT_EQ(QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher), nullptr); + server_thread_->Resume(); + + CreateClientWithWriter(); // Also connects. + EXPECT_NE(client_, nullptr); + + server_thread_->Pause(); + EXPECT_NE(QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher), nullptr); + EXPECT_EQ(dispatcher->NumSessions(), 1); + auto ids = GetServerConnection()->GetActiveServerConnectionIds(); + ASSERT_EQ(ids.size(), 2); + for (QuicConnectionId id : ids) { + EXPECT_NE(QuicDispatcherPeer::FindSession(dispatcher, id), nullptr); + } + QuicConnectionId original = ids[1]; + server_thread_->Resume(); + + client_->SendSynchronousRequest("/foo"); + client_->Disconnect(); + + server_thread_->Pause(); + EXPECT_EQ(QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher), nullptr); + EXPECT_EQ(QuicDispatcherPeer::FindSession(dispatcher, original), nullptr); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, ServerReportsEcn) { + // Client connects using not-ECT. + ASSERT_TRUE(Initialize()); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionPeer::DisableEcnCodepointValidation(client_connection); + QuicEcnCounts* ecn = QuicSentPacketManagerPeer::GetPeerEcnCounts( + QuicConnectionPeer::GetSentPacketManager(client_connection), + APPLICATION_DATA); + EXPECT_EQ(ecn->ect0, 0); + EXPECT_EQ(ecn->ect1, 0); + EXPECT_EQ(ecn->ce, 0); + QuicPacketCount ect0 = 0, ect1 = 0; + TestPerPacketOptions options; + client_connection->set_per_packet_options(&options); + for (QuicEcnCodepoint codepoint : {ECN_NOT_ECT, ECN_ECT0, ECN_ECT1, ECN_CE}) { + options.ecn_codepoint = codepoint; + client_->SendSynchronousRequest("/foo"); + if (!GetQuicRestartFlag(quic_receive_ecn) || + !GetQuicRestartFlag(quic_quiche_ecn_sockets) || + !VersionHasIetfQuicFrames(version_.transport_version) || + codepoint == ECN_NOT_ECT) { + EXPECT_EQ(ecn->ect0, 0); + EXPECT_EQ(ecn->ect1, 0); + EXPECT_EQ(ecn->ce, 0); + continue; + } + EXPECT_GT(ecn->ect0, 0); + if (codepoint == ECN_CE) { + EXPECT_EQ(ect0, ecn->ect0); // No more ECT(0) arriving + EXPECT_GE(ecn->ect1, ect1); // Late-arriving ECT(1) control packets + EXPECT_GT(ecn->ce, 0); + continue; + } + EXPECT_EQ(ecn->ce, 0); + if (codepoint == ECN_ECT1) { + EXPECT_GE(ecn->ect0, ect0); // Late-arriving ECT(0) control packets + ect0 = ecn->ect0; + ect1 = ecn->ect1; + EXPECT_GT(ect1, 0); + continue; + } + // codepoint == ECN_ECT0 + ect0 = ecn->ect0; + EXPECT_EQ(ecn->ect1, 0); + } + client_->Disconnect(); +} + +TEST_P(EndToEndTest, ClientReportsEcn) { + ASSERT_TRUE(Initialize()); + // Wait for handshake to complete, so that we can manipulate the server + // connection without race conditions. + server_thread_->WaitForCryptoHandshakeConfirmed(); + QuicConnection* server_connection = GetServerConnection(); + QuicConnectionPeer::DisableEcnCodepointValidation(server_connection); + QuicEcnCounts* ecn = QuicSentPacketManagerPeer::GetPeerEcnCounts( + QuicConnectionPeer::GetSentPacketManager(server_connection), + APPLICATION_DATA); + TestPerPacketOptions options; + options.ecn_codepoint = ECN_ECT1; + server_connection->set_per_packet_options(&options); + client_->SendSynchronousRequest("/foo"); + // A second request provides a packet for the client ACKs to go with. + client_->SendSynchronousRequest("/foo"); + server_thread_->Pause(); + EXPECT_EQ(ecn->ect0, 0); + EXPECT_EQ(ecn->ce, 0); + if (!GetQuicRestartFlag(quic_receive_ecn) || + !GetQuicRestartFlag(quic_quiche_ecn_sockets) || + !VersionHasIetfQuicFrames(version_.transport_version)) { + EXPECT_EQ(ecn->ect1, 0); + } else { + EXPECT_GT(ecn->ect1, 0); + } + server_connection->set_per_packet_options(nullptr); + server_thread_->Resume(); + client_->Disconnect(); +} + +TEST_P(EndToEndTest, ClientMigrationAfterHalfwayServerMigration) { + use_preferred_address_ = true; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + client_.reset(EndToEndTest::CreateQuicClient(nullptr)); + QuicConnection* client_connection = GetClientConnection(); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + EXPECT_EQ(server_address_, client_connection->effective_peer_address()); + EXPECT_EQ(server_address_, client_connection->peer_address()); + EXPECT_TRUE(client_->client()->HasPendingPathValidation()); + QuicConnectionId server_cid1 = client_connection->connection_id(); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_TRUE(client_->WaitUntil( + 1000, [&]() { return !client_->client()->HasPendingPathValidation(); })); + EXPECT_EQ(server_preferred_address_, + client_connection->effective_peer_address()); + EXPECT_EQ(server_preferred_address_, client_connection->peer_address()); + EXPECT_NE(server_cid1, client_connection->connection_id()); + EXPECT_EQ(0u, + client_connection->GetStats().num_connectivity_probing_received); + const auto client_stats = GetClientConnection()->GetStats(); + EXPECT_TRUE(client_stats.server_preferred_address_validated); + EXPECT_FALSE(client_stats.failed_to_validate_server_preferred_address); + + WaitForNewConnectionIds(); + // Migrate socket to a new IP address. + QuicIpAddress host = TestLoopback(2); + ASSERT_NE( + client_->client()->network_helper()->GetLatestClientAddress().host(), + host); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host)); + EXPECT_TRUE(client_->WaitUntil( + 1000, [&]() { return !client_->client()->HasPendingPathValidation(); })); + EXPECT_EQ(host, client_->client()->session()->self_address().host()); + + SendSynchronousBarRequestAndCheckResponse(); + + // Wait for the PATH_CHALLENGE. + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return client_connection->GetStats().num_connectivity_probing_received >= 1; + })); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. + SendSynchronousBarRequestAndCheckResponse(); + // By the time the above request is completed, the PATH_RESPONSE must have + // been received by the server. Check server stats. + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + EXPECT_FALSE(server_connection->HasPendingPathValidation()); + EXPECT_EQ(2u, server_connection->GetStats().num_validated_peer_migration); + EXPECT_EQ(2u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, MultiPortCreationFollowingServerMigration) { + use_preferred_address_ = true; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + + client_config_.SetClientConnectionOptions(QuicTagVector{kMPQC}); + client_.reset(EndToEndTest::CreateQuicClient(nullptr)); + QuicConnection* client_connection = GetClientConnection(); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + EXPECT_EQ(server_address_, client_connection->effective_peer_address()); + EXPECT_EQ(server_address_, client_connection->peer_address()); + QuicConnectionId server_cid1 = client_connection->connection_id(); + EXPECT_TRUE(client_connection->IsValidatingServerPreferredAddress()); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return !client_connection->IsValidatingServerPreferredAddress(); + })); + EXPECT_EQ(server_preferred_address_, + client_connection->effective_peer_address()); + EXPECT_EQ(server_preferred_address_, client_connection->peer_address()); + const auto client_stats = GetClientConnection()->GetStats(); + EXPECT_TRUE(client_stats.server_preferred_address_validated); + EXPECT_FALSE(client_stats.failed_to_validate_server_preferred_address); + + QuicConnectionId server_cid2 = client_connection->connection_id(); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_TRUE(client_->WaitUntil(1000, [&]() { + return client_connection->GetStats().num_path_response_received == 2; + })); + EXPECT_TRUE( + QuicConnectionPeer::IsAlternativePathValidated(client_connection)); + QuicConnectionId server_cid3 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_NE(server_cid1, server_cid3); +} + +TEST_P(EndToEndTest, DoNotAdvertisePreferredAddressWithoutSPAD) { + if (!version_.HasIetfQuicFrames()) { + ASSERT_TRUE(Initialize()); + return; + } + server_config_.SetIPv4AlternateServerAddressToSend( + QuicSocketAddress(QuicIpAddress::Any4(), 12345)); + server_config_.SetIPv6AlternateServerAddressToSend( + QuicSocketAddress(QuicIpAddress::Any6(), 12345)); + NiceMock visitor; + connection_debug_visitor_ = &visitor; + EXPECT_CALL(visitor, OnTransportParametersReceived(_)) + .WillOnce(Invoke([](const TransportParameters& transport_parameters) { + EXPECT_EQ(nullptr, transport_parameters.preferred_address); + })); + ASSERT_TRUE(Initialize()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/http_constants.cc b/quiche/quic/core/http/http_constants.cc new file mode 100644 index 000000000000..e5855350599e --- /dev/null +++ b/quiche/quic/core/http/http_constants.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/http_constants.h" + +#include "absl/strings/str_cat.h" + +namespace quic { + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x; + +std::string H3SettingsToString(Http3AndQpackSettingsIdentifiers identifier) { + switch (identifier) { + RETURN_STRING_LITERAL(SETTINGS_QPACK_MAX_TABLE_CAPACITY); + RETURN_STRING_LITERAL(SETTINGS_MAX_FIELD_SECTION_SIZE); + RETURN_STRING_LITERAL(SETTINGS_QPACK_BLOCKED_STREAMS); + RETURN_STRING_LITERAL(SETTINGS_H3_DATAGRAM_DRAFT04); + RETURN_STRING_LITERAL(SETTINGS_H3_DATAGRAM); + RETURN_STRING_LITERAL(SETTINGS_WEBTRANS_DRAFT00); + RETURN_STRING_LITERAL(SETTINGS_ENABLE_CONNECT_PROTOCOL); + RETURN_STRING_LITERAL(SETTINGS_ENABLE_METADATA); + } + return absl::StrCat("UNSUPPORTED_SETTINGS_TYPE(", identifier, ")"); +} + +ABSL_CONST_INIT const absl::string_view kUserAgentHeaderName = "user-agent"; + +#undef RETURN_STRING_LITERAL // undef for jumbo builds + +} // namespace quic diff --git a/quiche/quic/core/http/http_constants.h b/quiche/quic/core/http/http_constants.h new file mode 100644 index 000000000000..feec52192386 --- /dev/null +++ b/quiche/quic/core/http/http_constants.h @@ -0,0 +1,77 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_HTTP_CONSTANTS_H_ +#define QUICHE_QUIC_CORE_HTTP_HTTP_CONSTANTS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Unidirectional stream types. +enum : uint64_t { + // https://quicwg.org/base-drafts/draft-ietf-quic-http.html#unidirectional-streams + kControlStream = 0x00, + kServerPushStream = 0x01, + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#enc-dec-stream-def + kQpackEncoderStream = 0x02, + kQpackDecoderStream = 0x03, + // https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#name-unidirectional-streams + kWebTransportUnidirectionalStream = 0x54, +}; + +// This includes control stream, QPACK encoder stream, and QPACK decoder stream. +enum : QuicStreamCount { kHttp3StaticUnidirectionalStreamCount = 3 }; + +// HTTP/3 and QPACK settings identifiers. +// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#settings-parameters +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#configuration +enum Http3AndQpackSettingsIdentifiers : uint64_t { + // Same value as spdy::SETTINGS_HEADER_TABLE_SIZE. + SETTINGS_QPACK_MAX_TABLE_CAPACITY = 0x01, + // Same value as spdy::SETTINGS_MAX_HEADER_LIST_SIZE. + SETTINGS_MAX_FIELD_SECTION_SIZE = 0x06, + SETTINGS_QPACK_BLOCKED_STREAMS = 0x07, + // draft-ietf-masque-h3-datagram-04. + SETTINGS_H3_DATAGRAM_DRAFT04 = 0xffd277, + // RFC 9297. + SETTINGS_H3_DATAGRAM = 0x33, + // draft-ietf-webtrans-http3-00 + SETTINGS_WEBTRANS_DRAFT00 = 0x2b603742, + // draft-ietf-httpbis-h3-websockets + SETTINGS_ENABLE_CONNECT_PROTOCOL = 0x08, + SETTINGS_ENABLE_METADATA = 0x4d44, +}; + +// Returns HTTP/3 SETTINGS identifier as a string. +QUIC_EXPORT std::string H3SettingsToString( + Http3AndQpackSettingsIdentifiers identifier); + +// Default maximum dynamic table capacity, communicated via +// SETTINGS_QPACK_MAX_TABLE_CAPACITY. +enum : QuicByteCount { + kDefaultQpackMaxDynamicTableCapacity = 64 * 1024 // 64 KB +}; + +// Default limit on the size of uncompressed headers, +// communicated via SETTINGS_MAX_HEADER_LIST_SIZE. +enum : QuicByteCount { + kDefaultMaxUncompressedHeaderSize = 16 * 1024 // 16 KB +}; + +// Default limit on number of blocked streams, communicated via +// SETTINGS_QPACK_BLOCKED_STREAMS. +enum : uint64_t { kDefaultMaximumBlockedStreams = 100 }; + +ABSL_CONST_INIT QUIC_EXPORT_PRIVATE extern const absl::string_view + kUserAgentHeaderName; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_HTTP_CONSTANTS_H_ diff --git a/quiche/quic/core/http/http_decoder.cc b/quiche/quic/core/http/http_decoder.cc new file mode 100644 index 000000000000..5facad8d4f30 --- /dev/null +++ b/quiche/quic/core/http/http_decoder.cc @@ -0,0 +1,683 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/http_decoder.h" + +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Limit on the payload length for frames that are buffered by HttpDecoder. +// If a frame header indicating a payload length exceeding this limit is +// received, HttpDecoder closes the connection. Does not apply to frames that +// are not buffered here but each payload fragment is immediately passed to +// Visitor, like HEADERS, DATA, and unknown frames. +constexpr QuicByteCount kPayloadLengthLimit = 1024 * 1024; + +} // anonymous namespace + +HttpDecoder::HttpDecoder(Visitor* visitor) : HttpDecoder(visitor, Options()) {} +HttpDecoder::HttpDecoder(Visitor* visitor, Options options) + : visitor_(visitor), + allow_web_transport_stream_(options.allow_web_transport_stream), + state_(STATE_READING_FRAME_TYPE), + current_frame_type_(0), + current_length_field_length_(0), + remaining_length_field_length_(0), + current_frame_length_(0), + remaining_frame_length_(0), + current_type_field_length_(0), + remaining_type_field_length_(0), + error_(QUIC_NO_ERROR), + error_detail_("") { + QUICHE_DCHECK(visitor_); +} + +HttpDecoder::~HttpDecoder() {} + +// static +bool HttpDecoder::DecodeSettings(const char* data, QuicByteCount len, + SettingsFrame* frame) { + QuicDataReader reader(data, len); + uint64_t frame_type; + if (!reader.ReadVarInt62(&frame_type)) { + QUIC_DLOG(ERROR) << "Unable to read frame type."; + return false; + } + + if (frame_type != static_cast(HttpFrameType::SETTINGS)) { + QUIC_DLOG(ERROR) << "Invalid frame type " << frame_type; + return false; + } + + absl::string_view frame_contents; + if (!reader.ReadStringPieceVarInt62(&frame_contents)) { + QUIC_DLOG(ERROR) << "Failed to read SETTINGS frame contents"; + return false; + } + + QuicDataReader frame_reader(frame_contents); + + while (!frame_reader.IsDoneReading()) { + uint64_t id; + if (!frame_reader.ReadVarInt62(&id)) { + QUIC_DLOG(ERROR) << "Unable to read setting identifier."; + return false; + } + uint64_t content; + if (!frame_reader.ReadVarInt62(&content)) { + QUIC_DLOG(ERROR) << "Unable to read setting value."; + return false; + } + auto result = frame->values.insert({id, content}); + if (!result.second) { + QUIC_DLOG(ERROR) << "Duplicate setting identifier."; + return false; + } + } + return true; +} + +QuicByteCount HttpDecoder::ProcessInput(const char* data, QuicByteCount len) { + QUICHE_DCHECK_EQ(QUIC_NO_ERROR, error_); + QUICHE_DCHECK_NE(STATE_ERROR, state_); + + QuicDataReader reader(data, len); + bool continue_processing = true; + // BufferOrParsePayload() and FinishParsing() may need to be called even if + // there is no more data so that they can finish processing the current frame. + while (continue_processing && (reader.BytesRemaining() != 0 || + state_ == STATE_BUFFER_OR_PARSE_PAYLOAD || + state_ == STATE_FINISH_PARSING)) { + // |continue_processing| must have been set to false upon error. + QUICHE_DCHECK_EQ(QUIC_NO_ERROR, error_); + QUICHE_DCHECK_NE(STATE_ERROR, state_); + + switch (state_) { + case STATE_READING_FRAME_TYPE: + continue_processing = ReadFrameType(&reader); + break; + case STATE_READING_FRAME_LENGTH: + continue_processing = ReadFrameLength(&reader); + break; + case STATE_BUFFER_OR_PARSE_PAYLOAD: + continue_processing = BufferOrParsePayload(&reader); + break; + case STATE_READING_FRAME_PAYLOAD: + continue_processing = ReadFramePayload(&reader); + break; + case STATE_FINISH_PARSING: + continue_processing = FinishParsing(); + break; + case STATE_PARSING_NO_LONGER_POSSIBLE: + continue_processing = false; + QUIC_BUG(HttpDecoder PARSING_NO_LONGER_POSSIBLE) + << "HttpDecoder called after an indefinite-length frame has been " + "received"; + RaiseError(QUIC_INTERNAL_ERROR, + "HttpDecoder called after an indefinite-length frame has " + "been received"); + break; + case STATE_ERROR: + break; + default: + QUIC_BUG(quic_bug_10411_1) << "Invalid state: " << state_; + } + } + + return len - reader.BytesRemaining(); +} + +bool HttpDecoder::ReadFrameType(QuicDataReader* reader) { + QUICHE_DCHECK_NE(0u, reader->BytesRemaining()); + if (current_type_field_length_ == 0) { + // A new frame is coming. + current_type_field_length_ = reader->PeekVarInt62Length(); + QUICHE_DCHECK_NE(0u, current_type_field_length_); + if (current_type_field_length_ > reader->BytesRemaining()) { + // Buffer a new type field. + remaining_type_field_length_ = current_type_field_length_; + BufferFrameType(reader); + return true; + } + // The reader has all type data needed, so no need to buffer. + bool success = reader->ReadVarInt62(¤t_frame_type_); + QUICHE_DCHECK(success); + } else { + // Buffer the existing type field. + BufferFrameType(reader); + // The frame is still not buffered completely. + if (remaining_type_field_length_ != 0) { + return true; + } + QuicDataReader type_reader(type_buffer_.data(), current_type_field_length_); + bool success = type_reader.ReadVarInt62(¤t_frame_type_); + QUICHE_DCHECK(success); + } + + // https://tools.ietf.org/html/draft-ietf-quic-http-31#section-7.2.8 + // specifies that the following frames are treated as errors. + if (current_frame_type_ == + static_cast(http2::Http2FrameType::PRIORITY) || + current_frame_type_ == + static_cast(http2::Http2FrameType::PING) || + current_frame_type_ == + static_cast(http2::Http2FrameType::WINDOW_UPDATE) || + current_frame_type_ == + static_cast(http2::Http2FrameType::CONTINUATION)) { + RaiseError(QUIC_HTTP_RECEIVE_SPDY_FRAME, + absl::StrCat("HTTP/2 frame received in a HTTP/3 connection: ", + current_frame_type_)); + return false; + } + + if (current_frame_type_ == + static_cast(HttpFrameType::CANCEL_PUSH)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "CANCEL_PUSH frame received."); + return false; + } + if (current_frame_type_ == + static_cast(HttpFrameType::PUSH_PROMISE)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "PUSH_PROMISE frame received."); + return false; + } + + state_ = STATE_READING_FRAME_LENGTH; + return true; +} + +bool HttpDecoder::ReadFrameLength(QuicDataReader* reader) { + QUICHE_DCHECK_NE(0u, reader->BytesRemaining()); + if (current_length_field_length_ == 0) { + // A new frame is coming. + current_length_field_length_ = reader->PeekVarInt62Length(); + QUICHE_DCHECK_NE(0u, current_length_field_length_); + if (current_length_field_length_ > reader->BytesRemaining()) { + // Buffer a new length field. + remaining_length_field_length_ = current_length_field_length_; + BufferFrameLength(reader); + return true; + } + // The reader has all length data needed, so no need to buffer. + bool success = reader->ReadVarInt62(¤t_frame_length_); + QUICHE_DCHECK(success); + } else { + // Buffer the existing length field. + BufferFrameLength(reader); + // The frame is still not buffered completely. + if (remaining_length_field_length_ != 0) { + return true; + } + QuicDataReader length_reader(length_buffer_.data(), + current_length_field_length_); + bool success = length_reader.ReadVarInt62(¤t_frame_length_); + QUICHE_DCHECK(success); + } + + // WEBTRANSPORT_STREAM frames are indefinitely long, and thus require + // special handling; the number after the frame type is actually the + // WebTransport session ID, and not the length. + if (allow_web_transport_stream_ && + current_frame_type_ == + static_cast(HttpFrameType::WEBTRANSPORT_STREAM)) { + visitor_->OnWebTransportStreamFrameType( + current_length_field_length_ + current_type_field_length_, + current_frame_length_); + state_ = STATE_PARSING_NO_LONGER_POSSIBLE; + return false; + } + + if (IsFrameBuffered() && + current_frame_length_ > MaxFrameLength(current_frame_type_)) { + RaiseError(QUIC_HTTP_FRAME_TOO_LARGE, "Frame is too large."); + return false; + } + + // Calling the following visitor methods does not require parsing of any + // frame payload. + bool continue_processing = true; + const QuicByteCount header_length = + current_length_field_length_ + current_type_field_length_; + + switch (current_frame_type_) { + case static_cast(HttpFrameType::DATA): + continue_processing = + visitor_->OnDataFrameStart(header_length, current_frame_length_); + break; + case static_cast(HttpFrameType::HEADERS): + continue_processing = + visitor_->OnHeadersFrameStart(header_length, current_frame_length_); + break; + case static_cast(HttpFrameType::CANCEL_PUSH): + QUICHE_NOTREACHED(); + break; + case static_cast(HttpFrameType::SETTINGS): + continue_processing = visitor_->OnSettingsFrameStart(header_length); + break; + case static_cast(HttpFrameType::PUSH_PROMISE): + QUICHE_NOTREACHED(); + break; + case static_cast(HttpFrameType::GOAWAY): + break; + case static_cast(HttpFrameType::MAX_PUSH_ID): + break; + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): + continue_processing = visitor_->OnPriorityUpdateFrameStart(header_length); + break; + case static_cast(HttpFrameType::ACCEPT_CH): + continue_processing = visitor_->OnAcceptChFrameStart(header_length); + break; + default: + continue_processing = visitor_->OnUnknownFrameStart( + current_frame_type_, header_length, current_frame_length_); + break; + } + + remaining_frame_length_ = current_frame_length_; + + if (IsFrameBuffered()) { + state_ = STATE_BUFFER_OR_PARSE_PAYLOAD; + return continue_processing; + } + + state_ = (remaining_frame_length_ == 0) ? STATE_FINISH_PARSING + : STATE_READING_FRAME_PAYLOAD; + return continue_processing; +} + +bool HttpDecoder::IsFrameBuffered() { + switch (current_frame_type_) { + case static_cast(HttpFrameType::SETTINGS): + return true; + case static_cast(HttpFrameType::GOAWAY): + return true; + case static_cast(HttpFrameType::MAX_PUSH_ID): + return true; + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): + return true; + case static_cast(HttpFrameType::ACCEPT_CH): + return true; + } + + // Other defined frame types as well as unknown frames are not buffered. + return false; +} + +bool HttpDecoder::ReadFramePayload(QuicDataReader* reader) { + QUICHE_DCHECK(!IsFrameBuffered()); + QUICHE_DCHECK_NE(0u, reader->BytesRemaining()); + QUICHE_DCHECK_NE(0u, remaining_frame_length_); + + bool continue_processing = true; + + switch (current_frame_type_) { + case static_cast(HttpFrameType::DATA): { + QuicByteCount bytes_to_read = std::min( + remaining_frame_length_, reader->BytesRemaining()); + absl::string_view payload; + bool success = reader->ReadStringPiece(&payload, bytes_to_read); + QUICHE_DCHECK(success); + QUICHE_DCHECK(!payload.empty()); + continue_processing = visitor_->OnDataFramePayload(payload); + remaining_frame_length_ -= payload.length(); + break; + } + case static_cast(HttpFrameType::HEADERS): { + QuicByteCount bytes_to_read = std::min( + remaining_frame_length_, reader->BytesRemaining()); + absl::string_view payload; + bool success = reader->ReadStringPiece(&payload, bytes_to_read); + QUICHE_DCHECK(success); + QUICHE_DCHECK(!payload.empty()); + continue_processing = visitor_->OnHeadersFramePayload(payload); + remaining_frame_length_ -= payload.length(); + break; + } + case static_cast(HttpFrameType::CANCEL_PUSH): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::SETTINGS): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::PUSH_PROMISE): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::GOAWAY): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::MAX_PUSH_ID): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::ACCEPT_CH): { + QUICHE_NOTREACHED(); + break; + } + default: { + continue_processing = HandleUnknownFramePayload(reader); + break; + } + } + + if (remaining_frame_length_ == 0) { + state_ = STATE_FINISH_PARSING; + } + + return continue_processing; +} + +bool HttpDecoder::FinishParsing() { + QUICHE_DCHECK(!IsFrameBuffered()); + QUICHE_DCHECK_EQ(0u, remaining_frame_length_); + + bool continue_processing = true; + + switch (current_frame_type_) { + case static_cast(HttpFrameType::DATA): { + continue_processing = visitor_->OnDataFrameEnd(); + break; + } + case static_cast(HttpFrameType::HEADERS): { + continue_processing = visitor_->OnHeadersFrameEnd(); + break; + } + case static_cast(HttpFrameType::CANCEL_PUSH): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::SETTINGS): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::PUSH_PROMISE): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::GOAWAY): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::MAX_PUSH_ID): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): { + QUICHE_NOTREACHED(); + break; + } + case static_cast(HttpFrameType::ACCEPT_CH): { + QUICHE_NOTREACHED(); + break; + } + default: + continue_processing = visitor_->OnUnknownFrameEnd(); + } + + ResetForNextFrame(); + return continue_processing; +} + +void HttpDecoder::ResetForNextFrame() { + current_length_field_length_ = 0; + current_type_field_length_ = 0; + state_ = STATE_READING_FRAME_TYPE; +} + +bool HttpDecoder::HandleUnknownFramePayload(QuicDataReader* reader) { + QuicByteCount bytes_to_read = std::min( + remaining_frame_length_, reader->BytesRemaining()); + absl::string_view payload; + bool success = reader->ReadStringPiece(&payload, bytes_to_read); + QUICHE_DCHECK(success); + QUICHE_DCHECK(!payload.empty()); + remaining_frame_length_ -= payload.length(); + return visitor_->OnUnknownFramePayload(payload); +} + +bool HttpDecoder::BufferOrParsePayload(QuicDataReader* reader) { + QUICHE_DCHECK(IsFrameBuffered()); + QUICHE_DCHECK_EQ(current_frame_length_, + buffer_.size() + remaining_frame_length_); + + if (buffer_.empty() && reader->BytesRemaining() >= current_frame_length_) { + // |*reader| contains entire payload, which might be empty. + remaining_frame_length_ = 0; + QuicDataReader current_payload_reader(reader->PeekRemainingPayload().data(), + current_frame_length_); + bool continue_processing = ParseEntirePayload(¤t_payload_reader); + + reader->Seek(current_frame_length_); + ResetForNextFrame(); + return continue_processing; + } + + // Buffer as much of the payload as |*reader| contains. + QuicByteCount bytes_to_read = std::min( + remaining_frame_length_, reader->BytesRemaining()); + absl::StrAppend(&buffer_, reader->PeekRemainingPayload().substr( + /* pos = */ 0, bytes_to_read)); + reader->Seek(bytes_to_read); + remaining_frame_length_ -= bytes_to_read; + + QUICHE_DCHECK_EQ(current_frame_length_, + buffer_.size() + remaining_frame_length_); + + if (remaining_frame_length_ > 0) { + QUICHE_DCHECK(reader->IsDoneReading()); + return false; + } + + QuicDataReader buffer_reader(buffer_); + bool continue_processing = ParseEntirePayload(&buffer_reader); + buffer_.clear(); + + ResetForNextFrame(); + return continue_processing; +} + +bool HttpDecoder::ParseEntirePayload(QuicDataReader* reader) { + QUICHE_DCHECK(IsFrameBuffered()); + QUICHE_DCHECK_EQ(current_frame_length_, reader->BytesRemaining()); + QUICHE_DCHECK_EQ(0u, remaining_frame_length_); + + switch (current_frame_type_) { + case static_cast(HttpFrameType::CANCEL_PUSH): { + QUICHE_NOTREACHED(); + return false; + } + case static_cast(HttpFrameType::SETTINGS): { + SettingsFrame frame; + if (!ParseSettingsFrame(reader, &frame)) { + return false; + } + return visitor_->OnSettingsFrame(frame); + } + case static_cast(HttpFrameType::GOAWAY): { + GoAwayFrame frame; + if (!reader->ReadVarInt62(&frame.id)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read GOAWAY ID."); + return false; + } + if (!reader->IsDoneReading()) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Superfluous data in GOAWAY frame."); + return false; + } + return visitor_->OnGoAwayFrame(frame); + } + case static_cast(HttpFrameType::MAX_PUSH_ID): { + uint64_t unused; + if (!reader->ReadVarInt62(&unused)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, + "Unable to read MAX_PUSH_ID push_id."); + return false; + } + if (!reader->IsDoneReading()) { + RaiseError(QUIC_HTTP_FRAME_ERROR, + "Superfluous data in MAX_PUSH_ID frame."); + return false; + } + return visitor_->OnMaxPushIdFrame(); + } + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): { + PriorityUpdateFrame frame; + if (!ParsePriorityUpdateFrame(reader, &frame)) { + return false; + } + return visitor_->OnPriorityUpdateFrame(frame); + } + case static_cast(HttpFrameType::ACCEPT_CH): { + AcceptChFrame frame; + if (!ParseAcceptChFrame(reader, &frame)) { + return false; + } + return visitor_->OnAcceptChFrame(frame); + } + default: + // Only above frame types are parsed by ParseEntirePayload(). + QUICHE_NOTREACHED(); + return false; + } +} + +void HttpDecoder::BufferFrameLength(QuicDataReader* reader) { + QuicByteCount bytes_to_read = std::min( + remaining_length_field_length_, reader->BytesRemaining()); + bool success = + reader->ReadBytes(length_buffer_.data() + current_length_field_length_ - + remaining_length_field_length_, + bytes_to_read); + QUICHE_DCHECK(success); + remaining_length_field_length_ -= bytes_to_read; +} + +void HttpDecoder::BufferFrameType(QuicDataReader* reader) { + QuicByteCount bytes_to_read = std::min( + remaining_type_field_length_, reader->BytesRemaining()); + bool success = + reader->ReadBytes(type_buffer_.data() + current_type_field_length_ - + remaining_type_field_length_, + bytes_to_read); + QUICHE_DCHECK(success); + remaining_type_field_length_ -= bytes_to_read; +} + +void HttpDecoder::RaiseError(QuicErrorCode error, std::string error_detail) { + state_ = STATE_ERROR; + error_ = error; + error_detail_ = std::move(error_detail); + visitor_->OnError(this); +} + +bool HttpDecoder::ParseSettingsFrame(QuicDataReader* reader, + SettingsFrame* frame) { + while (!reader->IsDoneReading()) { + uint64_t id; + if (!reader->ReadVarInt62(&id)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read setting identifier."); + return false; + } + uint64_t content; + if (!reader->ReadVarInt62(&content)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read setting value."); + return false; + } + auto result = frame->values.insert({id, content}); + if (!result.second) { + RaiseError(QUIC_HTTP_DUPLICATE_SETTING_IDENTIFIER, + "Duplicate setting identifier."); + return false; + } + } + return true; +} + +bool HttpDecoder::ParsePriorityUpdateFrame(QuicDataReader* reader, + PriorityUpdateFrame* frame) { + if (!reader->ReadVarInt62(&frame->prioritized_element_id)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read prioritized element id."); + return false; + } + + absl::string_view priority_field_value = reader->ReadRemainingPayload(); + frame->priority_field_value = + std::string(priority_field_value.data(), priority_field_value.size()); + + return true; +} + +bool HttpDecoder::ParseAcceptChFrame(QuicDataReader* reader, + AcceptChFrame* frame) { + absl::string_view origin; + absl::string_view value; + while (!reader->IsDoneReading()) { + if (!reader->ReadStringPieceVarInt62(&origin)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read ACCEPT_CH origin."); + return false; + } + if (!reader->ReadStringPieceVarInt62(&value)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read ACCEPT_CH value."); + return false; + } + // Copy data. + frame->entries.push_back({std::string(origin.data(), origin.size()), + std::string(value.data(), value.size())}); + } + return true; +} + +QuicByteCount HttpDecoder::MaxFrameLength(uint64_t frame_type) { + QUICHE_DCHECK(IsFrameBuffered()); + + switch (frame_type) { + case static_cast(HttpFrameType::SETTINGS): + return kPayloadLengthLimit; + case static_cast(HttpFrameType::GOAWAY): + return quiche::VARIABLE_LENGTH_INTEGER_LENGTH_8; + case static_cast(HttpFrameType::MAX_PUSH_ID): + return quiche::VARIABLE_LENGTH_INTEGER_LENGTH_8; + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): + return kPayloadLengthLimit; + case static_cast(HttpFrameType::ACCEPT_CH): + return kPayloadLengthLimit; + default: + QUICHE_NOTREACHED(); + return 0; + } +} + +std::string HttpDecoder::DebugString() const { + return absl::StrCat( + "HttpDecoder:", "\n state: ", state_, "\n error: ", error_, + "\n current_frame_type: ", current_frame_type_, + "\n current_length_field_length: ", current_length_field_length_, + "\n remaining_length_field_length: ", remaining_length_field_length_, + "\n current_frame_length: ", current_frame_length_, + "\n remaining_frame_length: ", remaining_frame_length_, + "\n current_type_field_length: ", current_type_field_length_, + "\n remaining_type_field_length: ", remaining_type_field_length_); +} + +} // namespace quic diff --git a/quiche/quic/core/http/http_decoder.h b/quiche/quic/core/http/http_decoder.h new file mode 100644 index 000000000000..0e49c3f23817 --- /dev/null +++ b/quiche/quic/core/http/http_decoder.h @@ -0,0 +1,278 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_HTTP_DECODER_H_ +#define QUICHE_QUIC_CORE_HTTP_HTTP_DECODER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { + +class HttpDecoderPeer; + +} // namespace test + +class QuicDataReader; + +// A class for decoding the HTTP frames that are exchanged in an HTTP over QUIC +// session. +class QUIC_EXPORT_PRIVATE HttpDecoder { + public: + struct QUIC_EXPORT_PRIVATE Options { + // Indicates that WEBTRANSPORT_STREAM should be parsed. + bool allow_web_transport_stream = false; + }; + + class QUIC_EXPORT_PRIVATE Visitor { + public: + virtual ~Visitor() {} + + // Called if an error is detected. + virtual void OnError(HttpDecoder* decoder) = 0; + + // All the following methods return true to continue decoding, + // and false to pause it. + // On*FrameStart() methods are called after the frame header is completely + // processed. At that point it is safe to consume |header_length| bytes. + + // Called when a MAX_PUSH_ID frame has been successfully parsed. + virtual bool OnMaxPushIdFrame() = 0; + + // Called when a GOAWAY frame has been successfully parsed. + virtual bool OnGoAwayFrame(const GoAwayFrame& frame) = 0; + + // Called when a SETTINGS frame has been received. + virtual bool OnSettingsFrameStart(QuicByteCount header_length) = 0; + + // Called when a SETTINGS frame has been successfully parsed. + virtual bool OnSettingsFrame(const SettingsFrame& frame) = 0; + + // Called when a DATA frame has been received. + // |header_length| and |payload_length| are the length of DATA frame header + // and payload, respectively. + virtual bool OnDataFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) = 0; + // Called when part of the payload of a DATA frame has been read. May be + // called multiple times for a single frame. |payload| is guaranteed to be + // non-empty. + virtual bool OnDataFramePayload(absl::string_view payload) = 0; + // Called when a DATA frame has been completely processed. + virtual bool OnDataFrameEnd() = 0; + + // Called when a HEADERS frame has been received. + // |header_length| and |payload_length| are the length of HEADERS frame + // header and payload, respectively. + virtual bool OnHeadersFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) = 0; + // Called when part of the payload of a HEADERS frame has been read. May be + // called multiple times for a single frame. |payload| is guaranteed to be + // non-empty. + virtual bool OnHeadersFramePayload(absl::string_view payload) = 0; + // Called when a HEADERS frame has been completely processed. + virtual bool OnHeadersFrameEnd() = 0; + + // Called when a PRIORITY_UPDATE frame has been received. + // |header_length| contains PRIORITY_UPDATE frame length and payload length. + virtual bool OnPriorityUpdateFrameStart(QuicByteCount header_length) = 0; + + // Called when a PRIORITY_UPDATE frame has been successfully parsed. + virtual bool OnPriorityUpdateFrame(const PriorityUpdateFrame& frame) = 0; + + // Called when an ACCEPT_CH frame has been received. + // |header_length| contains ACCEPT_CH frame length and payload length. + virtual bool OnAcceptChFrameStart(QuicByteCount header_length) = 0; + + // Called when an ACCEPT_CH frame has been successfully parsed. + virtual bool OnAcceptChFrame(const AcceptChFrame& frame) = 0; + + // Called when a WEBTRANSPORT_STREAM frame type and the session ID varint + // immediately following it has been received. Any further parsing should + // be done by the stream itself, and not the parser. Note that this does not + // return bool, because WEBTRANSPORT_STREAM always causes the parsing + // process to cease. + virtual void OnWebTransportStreamFrameType( + QuicByteCount header_length, WebTransportSessionId session_id) = 0; + + // Called when a frame of unknown type |frame_type| has been received. + // Frame type might be reserved, Visitor must make sure to ignore. + // |header_length| and |payload_length| are the length of the frame header + // and payload, respectively. + virtual bool OnUnknownFrameStart(uint64_t frame_type, + QuicByteCount header_length, + QuicByteCount payload_length) = 0; + // Called when part of the payload of the unknown frame has been read. May + // be called multiple times for a single frame. |payload| is guaranteed to + // be non-empty. + virtual bool OnUnknownFramePayload(absl::string_view payload) = 0; + // Called when the unknown frame has been completely processed. + virtual bool OnUnknownFrameEnd() = 0; + }; + + // |visitor| must be non-null, and must outlive HttpDecoder. + explicit HttpDecoder(Visitor* visitor); + explicit HttpDecoder(Visitor* visitor, Options options); + + ~HttpDecoder(); + + // Processes the input and invokes the appropriate visitor methods, until a + // visitor method returns false or an error occurs. Returns the number of + // bytes processed. Does not process any input if called after an error. + // Paused processing can be resumed by calling ProcessInput() again with the + // unprocessed portion of data. Must not be called after an error has + // occurred. + QuicByteCount ProcessInput(const char* data, QuicByteCount len); + + // Decode settings frame from |data|. + // Upon successful decoding, |frame| will be populated, and returns true. + // This method is not used for regular processing of incoming data. + static bool DecodeSettings(const char* data, QuicByteCount len, + SettingsFrame* frame); + + // Returns an error code other than QUIC_NO_ERROR if and only if + // Visitor::OnError() has been called. + QuicErrorCode error() const { return error_; } + + const std::string& error_detail() const { return error_detail_; } + + // Returns true if input data processed so far ends on a frame boundary. + bool AtFrameBoundary() const { return state_ == STATE_READING_FRAME_TYPE; } + + std::string DebugString() const; + + private: + friend test::HttpDecoderPeer; + + // Represents the current state of the parsing state machine. + enum HttpDecoderState { + STATE_READING_FRAME_LENGTH, + STATE_READING_FRAME_TYPE, + + // States used for buffered frame types + STATE_BUFFER_OR_PARSE_PAYLOAD, + + // States used for non-buffered frame types + STATE_READING_FRAME_PAYLOAD, + STATE_FINISH_PARSING, + + STATE_PARSING_NO_LONGER_POSSIBLE, + STATE_ERROR + }; + + // Reads the type of a frame from |reader|. Sets error_ and error_detail_ + // if there are any errors. Also calls OnDataFrameStart() or + // OnHeadersFrameStart() for appropriate frame types. Returns whether the + // processing should continue. + bool ReadFrameType(QuicDataReader* reader); + + // Reads the length of a frame from |reader|. Sets error_ and error_detail_ + // if there are any errors. Returns whether processing should continue. + bool ReadFrameLength(QuicDataReader* reader); + + // Returns whether the current frame is of a buffered type. + // The payload of buffered frames is buffered by HttpDecoder, and parsed by + // HttpDecoder after the entire frame has been received. (Copying to the + // buffer is skipped if the ProcessInput() call covers the entire payload.) + // Frames that are not buffered have every payload fragment synchronously + // passed to the Visitor without buffering. + bool IsFrameBuffered(); + + // For buffered frame types, calls BufferOrParsePayload(). For other frame + // types, reads the payload of the current frame from |reader| and calls + // visitor methods. Returns whether processing should continue. + bool ReadFramePayload(QuicDataReader* reader); + + // For buffered frame types, this method is only called if frame payload is + // empty, and it calls BufferOrParsePayload(). For other frame types, this + // method directly calls visitor methods to signal that frame had been + // received completely. Returns whether processing should continue. + bool FinishParsing(); + + // Reset internal fields to prepare for reading next frame. + void ResetForNextFrame(); + + // Read payload of unknown frame from |reader| and call + // Visitor::OnUnknownFramePayload(). Returns true decoding should continue, + // false if it should be paused. + bool HandleUnknownFramePayload(QuicDataReader* reader); + + // Buffers any remaining frame payload from |*reader| into |buffer_| if + // necessary. Parses the frame payload if complete. Parses out of |*reader| + // without unnecessary copy if |*reader| contains entire payload. + // Returns whether processing should continue. + // Must only be called when current frame type is buffered. + bool BufferOrParsePayload(QuicDataReader* reader); + + // Parses the entire payload of certain kinds of frames that are parsed in a + // single pass. |reader| must have at least |current_frame_length_| bytes. + // Returns whether processing should continue. + // Must only be called when current frame type is buffered. + bool ParseEntirePayload(QuicDataReader* reader); + + // Buffers any remaining frame length field from |reader| into + // |length_buffer_|. + void BufferFrameLength(QuicDataReader* reader); + + // Buffers any remaining frame type field from |reader| into |type_buffer_|. + void BufferFrameType(QuicDataReader* reader); + + // Sets |error_| and |error_detail_| accordingly. + void RaiseError(QuicErrorCode error, std::string error_detail); + + // Parses the payload of a SETTINGS frame from |reader| into |frame|. + bool ParseSettingsFrame(QuicDataReader* reader, SettingsFrame* frame); + + // Parses the payload of a PRIORITY_UPDATE frame (draft-02, type 0xf0700) + // from |reader| into |frame|. + bool ParsePriorityUpdateFrame(QuicDataReader* reader, + PriorityUpdateFrame* frame); + + // Parses the payload of an ACCEPT_CH frame from |reader| into |frame|. + bool ParseAcceptChFrame(QuicDataReader* reader, AcceptChFrame* frame); + + // Returns the max frame size of a given |frame_type|. + QuicByteCount MaxFrameLength(uint64_t frame_type); + + // Visitor to invoke when messages are parsed. + Visitor* const visitor_; // Unowned. + // Whether WEBTRANSPORT_STREAM should be parsed. + bool allow_web_transport_stream_; + // Current state of the parsing. + HttpDecoderState state_; + // Type of the frame currently being parsed. + uint64_t current_frame_type_; + // Size of the frame's length field. + QuicByteCount current_length_field_length_; + // Remaining length that's needed for the frame's length field. + QuicByteCount remaining_length_field_length_; + // Length of the payload of the frame currently being parsed. + QuicByteCount current_frame_length_; + // Remaining payload bytes to be parsed. + QuicByteCount remaining_frame_length_; + // Length of the frame's type field. + QuicByteCount current_type_field_length_; + // Remaining length that's needed for the frame's type field. + QuicByteCount remaining_type_field_length_; + // Last error. + QuicErrorCode error_; + // The issue which caused |error_| + std::string error_detail_; + // Remaining unparsed data. + std::string buffer_; + // Remaining unparsed length field data. + std::array length_buffer_; + // Remaining unparsed type field data. + std::array type_buffer_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_HTTP_DECODER_H_ diff --git a/quiche/quic/core/http/http_decoder_test.cc b/quiche/quic/core/http/http_decoder_test.cc new file mode 100644 index 000000000000..79544b3cf196 --- /dev/null +++ b/quiche/quic/core/http/http_decoder_test.cc @@ -0,0 +1,1067 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/http_decoder.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Eq; +using ::testing::InSequence; +using ::testing::Return; + +namespace quic { +namespace test { + +class HttpDecoderPeer { + public: + static uint64_t current_frame_type(HttpDecoder* decoder) { + return decoder->current_frame_type_; + } +}; + +namespace { + +class HttpDecoderTest : public QuicTest { + public: + HttpDecoderTest() : decoder_(&visitor_) { + ON_CALL(visitor_, OnMaxPushIdFrame()).WillByDefault(Return(true)); + ON_CALL(visitor_, OnGoAwayFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnSettingsFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnSettingsFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDataFrameStart(_, _)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDataFramePayload(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnDataFrameEnd()).WillByDefault(Return(true)); + ON_CALL(visitor_, OnHeadersFrameStart(_, _)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnHeadersFramePayload(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnHeadersFrameEnd()).WillByDefault(Return(true)); + ON_CALL(visitor_, OnPriorityUpdateFrameStart(_)) + .WillByDefault(Return(true)); + ON_CALL(visitor_, OnPriorityUpdateFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnAcceptChFrameStart(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnAcceptChFrame(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnUnknownFrameStart(_, _, _)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnUnknownFramePayload(_)).WillByDefault(Return(true)); + ON_CALL(visitor_, OnUnknownFrameEnd()).WillByDefault(Return(true)); + } + ~HttpDecoderTest() override = default; + + uint64_t current_frame_type() { + return HttpDecoderPeer::current_frame_type(&decoder_); + } + + // Process |input| in a single call to HttpDecoder::ProcessInput(). + QuicByteCount ProcessInput(absl::string_view input) { + return decoder_.ProcessInput(input.data(), input.size()); + } + + // Feed |input| to |decoder_| one character at a time, + // verifying that each character gets processed. + void ProcessInputCharByChar(absl::string_view input) { + for (char c : input) { + EXPECT_EQ(1u, decoder_.ProcessInput(&c, 1)); + } + } + + // Append garbage to |input|, then process it in a single call to + // HttpDecoder::ProcessInput(). Verify that garbage is not read. + QuicByteCount ProcessInputWithGarbageAppended(absl::string_view input) { + std::string input_with_garbage_appended = absl::StrCat(input, "blahblah"); + QuicByteCount processed_bytes = ProcessInput(input_with_garbage_appended); + + // Guaranteed by HttpDecoder::ProcessInput() contract. + QUICHE_DCHECK_LE(processed_bytes, input_with_garbage_appended.size()); + + // Caller should set up visitor to pause decoding + // before HttpDecoder would read garbage. + EXPECT_LE(processed_bytes, input.size()); + + return processed_bytes; + } + + testing::StrictMock visitor_; + HttpDecoder decoder_; +}; + +TEST_F(HttpDecoderTest, InitialState) { + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, UnknownFrame) { + std::unique_ptr input; + + const QuicByteCount payload_lengths[] = {0, 14, 100}; + const uint64_t frame_types[] = { + 0x21, 0x40, 0x5f, 0x7e, 0x9d, // some reserved frame types + 0x6f, 0x14 // some unknown, not reserved frame types + }; + + for (auto payload_length : payload_lengths) { + std::string data(payload_length, 'a'); + + for (auto frame_type : frame_types) { + const QuicByteCount total_length = + QuicDataWriter::GetVarInt62Len(frame_type) + + QuicDataWriter::GetVarInt62Len(payload_length) + payload_length; + input = std::make_unique(total_length); + + QuicDataWriter writer(total_length, input.get()); + writer.WriteVarInt62(frame_type); + writer.WriteVarInt62(payload_length); + const QuicByteCount header_length = writer.length(); + if (payload_length > 0) { + writer.WriteStringPiece(data); + } + + EXPECT_CALL(visitor_, OnUnknownFrameStart(frame_type, header_length, + payload_length)); + if (payload_length > 0) { + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq(data))); + } + EXPECT_CALL(visitor_, OnUnknownFrameEnd()); + + EXPECT_EQ(total_length, decoder_.ProcessInput(input.get(), total_length)); + + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + ASSERT_EQ("", decoder_.error_detail()); + EXPECT_EQ(frame_type, current_frame_type()); + } + } +} + +TEST_F(HttpDecoderTest, CancelPush) { + InSequence s; + std::string input = absl::HexStringToBytes( + "03" // type (CANCEL_PUSH) + "01" // length + "01"); // Push Id + + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(1u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ("CANCEL_PUSH frame received.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, PushPromiseFrame) { + InSequence s; + std::string input = + absl::StrCat(absl::HexStringToBytes("05" // type (PUSH PROMISE) + "08" // length + "1f"), // push id 31 + "Headers"); // headers + + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(1u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ("PUSH_PROMISE frame received.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, MaxPushId) { + InSequence s; + std::string input = absl::HexStringToBytes( + "0D" // type (MAX_PUSH_ID) + "01" // length + "01"); // Push Id + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnMaxPushIdFrame()).WillOnce(Return(false)); + EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnMaxPushIdFrame()); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnMaxPushIdFrame()); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, SettingsFrame) { + InSequence s; + std::string input = absl::HexStringToBytes( + "04" // type (SETTINGS) + "07" // length + "01" // identifier (SETTINGS_QPACK_MAX_TABLE_CAPACITY) + "02" // content + "06" // identifier (SETTINGS_MAX_HEADER_LIST_SIZE) + "05" // content + "4100" // identifier, encoded on 2 bytes (0x40), value is 256 (0x100) + "04"); // content + + SettingsFrame frame; + frame.values[1] = 2; + frame.values[6] = 5; + frame.values[256] = 4; + + // Visitor pauses processing. + absl::string_view remaining_input(input); + EXPECT_CALL(visitor_, OnSettingsFrameStart(2)).WillOnce(Return(false)); + QuicByteCount processed_bytes = + ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(2u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnSettingsFrame(frame)).WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnSettingsFrameStart(2)); + EXPECT_CALL(visitor_, OnSettingsFrame(frame)); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnSettingsFrameStart(2)); + EXPECT_CALL(visitor_, OnSettingsFrame(frame)); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, CorruptSettingsFrame) { + const char* const kPayload = + "\x42\x11" // two-byte id + "\x80\x22\x33\x44" // four-byte value + "\x58\x39" // two-byte id + "\xf0\x22\x33\x44\x55\x66\x77\x88"; // eight-byte value + struct { + size_t payload_length; + const char* const error_message; + } kTestData[] = { + {1, "Unable to read setting identifier."}, + {5, "Unable to read setting value."}, + {7, "Unable to read setting identifier."}, + {12, "Unable to read setting value."}, + }; + + for (const auto& test_data : kTestData) { + std::string input; + input.push_back(4u); // type SETTINGS + input.push_back(test_data.payload_length); + const size_t header_length = input.size(); + input.append(kPayload, test_data.payload_length); + + HttpDecoder decoder(&visitor_); + EXPECT_CALL(visitor_, OnSettingsFrameStart(header_length)); + EXPECT_CALL(visitor_, OnError(&decoder)); + + QuicByteCount processed_bytes = + decoder.ProcessInput(input.data(), input.size()); + EXPECT_EQ(input.size(), processed_bytes); + EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ(test_data.error_message, decoder.error_detail()); + } +} + +TEST_F(HttpDecoderTest, DuplicateSettingsIdentifier) { + std::string input = absl::HexStringToBytes( + "04" // type (SETTINGS) + "04" // length + "01" // identifier + "01" // content + "01" // identifier + "02"); // content + + EXPECT_CALL(visitor_, OnSettingsFrameStart(2)); + EXPECT_CALL(visitor_, OnError(&decoder_)); + + EXPECT_EQ(input.size(), ProcessInput(input)); + + EXPECT_THAT(decoder_.error(), + IsError(QUIC_HTTP_DUPLICATE_SETTING_IDENTIFIER)); + EXPECT_EQ("Duplicate setting identifier.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, DataFrame) { + InSequence s; + std::string input = absl::StrCat(absl::HexStringToBytes("00" // type (DATA) + "05"), // length + "Data!"); // data + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 5)).WillOnce(Return(false)); + absl::string_view remaining_input(input); + QuicByteCount processed_bytes = + ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(2u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("Data!"))) + .WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + + EXPECT_CALL(visitor_, OnDataFrameEnd()).WillOnce(Return(false)); + EXPECT_EQ(0u, ProcessInputWithGarbageAppended("")); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 5)); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("Data!"))); + EXPECT_CALL(visitor_, OnDataFrameEnd()); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 5)); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("D"))); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("a"))); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("t"))); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("a"))); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("!"))); + EXPECT_CALL(visitor_, OnDataFrameEnd()); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, FrameHeaderPartialDelivery) { + InSequence s; + // A large input that will occupy more than 1 byte in the length field. + std::string input(2048, 'x'); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + input.length(), quiche::SimpleBufferAllocator::Get()); + // Partially send only 1 byte of the header to process. + EXPECT_EQ(1u, decoder_.ProcessInput(header.data(), 1)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Send the rest of the header. + EXPECT_CALL(visitor_, OnDataFrameStart(3, input.length())); + EXPECT_EQ(header.size() - 1, + decoder_.ProcessInput(header.data() + 1, header.size() - 1)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Send data. + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view(input))); + EXPECT_CALL(visitor_, OnDataFrameEnd()); + EXPECT_EQ(2048u, decoder_.ProcessInput(input.data(), 2048)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, PartialDeliveryOfLargeFrameType) { + // Use a reserved type that takes four bytes as a varint. + const uint64_t frame_type = 0x1f * 0x222 + 0x21; + const QuicByteCount payload_length = 0; + const QuicByteCount header_length = + QuicDataWriter::GetVarInt62Len(frame_type) + + QuicDataWriter::GetVarInt62Len(payload_length); + + auto input = std::make_unique(header_length); + QuicDataWriter writer(header_length, input.get()); + writer.WriteVarInt62(frame_type); + writer.WriteVarInt62(payload_length); + + EXPECT_CALL(visitor_, + OnUnknownFrameStart(frame_type, header_length, payload_length)); + EXPECT_CALL(visitor_, OnUnknownFrameEnd()); + + auto raw_input = input.get(); + for (uint64_t i = 0; i < header_length; ++i) { + char c = raw_input[i]; + EXPECT_EQ(1u, decoder_.ProcessInput(&c, 1)); + } + + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + EXPECT_EQ(frame_type, current_frame_type()); +} + +TEST_F(HttpDecoderTest, GoAway) { + InSequence s; + std::string input = absl::HexStringToBytes( + "07" // type (GOAWAY) + "01" // length + "01"); // ID + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnGoAwayFrame(GoAwayFrame({1}))) + .WillOnce(Return(false)); + EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnGoAwayFrame(GoAwayFrame({1}))); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnGoAwayFrame(GoAwayFrame({1}))); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, HeadersFrame) { + InSequence s; + std::string input = + absl::StrCat(absl::HexStringToBytes("01" // type (HEADERS) + "07"), // length + "Headers"); // headers + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 7)).WillOnce(Return(false)); + absl::string_view remaining_input(input); + QuicByteCount processed_bytes = + ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(2u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("Headers"))) + .WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + + EXPECT_CALL(visitor_, OnHeadersFrameEnd()).WillOnce(Return(false)); + EXPECT_EQ(0u, ProcessInputWithGarbageAppended("")); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 7)); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("Headers"))); + EXPECT_CALL(visitor_, OnHeadersFrameEnd()); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 7)); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("H"))); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("e"))); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("a"))); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("d"))); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("e"))); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("r"))); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("s"))); + EXPECT_CALL(visitor_, OnHeadersFrameEnd()); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, EmptyDataFrame) { + InSequence s; + std::string input = absl::HexStringToBytes( + "00" // type (DATA) + "00"); // length + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 0)).WillOnce(Return(false)); + EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); + + EXPECT_CALL(visitor_, OnDataFrameEnd()).WillOnce(Return(false)); + EXPECT_EQ(0u, ProcessInputWithGarbageAppended("")); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 0)); + EXPECT_CALL(visitor_, OnDataFrameEnd()); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 0)); + EXPECT_CALL(visitor_, OnDataFrameEnd()); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, EmptyHeadersFrame) { + InSequence s; + std::string input = absl::HexStringToBytes( + "01" // type (HEADERS) + "00"); // length + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 0)).WillOnce(Return(false)); + EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); + + EXPECT_CALL(visitor_, OnHeadersFrameEnd()).WillOnce(Return(false)); + EXPECT_EQ(0u, ProcessInputWithGarbageAppended("")); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 0)); + EXPECT_CALL(visitor_, OnHeadersFrameEnd()); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 0)); + EXPECT_CALL(visitor_, OnHeadersFrameEnd()); + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, GoawayWithOverlyLargePayload) { + std::string input = absl::HexStringToBytes( + "07" // type (GOAWAY) + "10"); // length exceeding the maximum possible length for GOAWAY frame + // Process all data at once. + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(2u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_TOO_LARGE)); + EXPECT_EQ("Frame is too large.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, MaxPushIdWithOverlyLargePayload) { + std::string input = absl::HexStringToBytes( + "0d" // type (MAX_PUSH_ID) + "10"); // length exceeding the maximum possible length for MAX_PUSH_ID + // frame + // Process all data at once. + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(2u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_TOO_LARGE)); + EXPECT_EQ("Frame is too large.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, FrameWithOverlyLargePayload) { + // Regression test for b/193919867: Ensure that reading frames with incredibly + // large payload lengths does not lead to allocating unbounded memory. + constexpr size_t max_input_length = + /*max frame type varint length*/ sizeof(uint64_t) + + /*max frame length varint length*/ sizeof(uint64_t) + + /*one byte of payload*/ sizeof(uint8_t); + char input[max_input_length]; + for (uint64_t frame_type = 0; frame_type < 1025; frame_type++) { + ::testing::NiceMock visitor; + HttpDecoder decoder(&visitor); + QuicDataWriter writer(max_input_length, input); + ASSERT_TRUE(writer.WriteVarInt62(frame_type)); // frame type. + ASSERT_TRUE( + writer.WriteVarInt62(quiche::kVarInt62MaxValue)); // frame length. + ASSERT_TRUE(writer.WriteUInt8(0x00)); // one byte of payload. + EXPECT_NE(decoder.ProcessInput(input, writer.length()), 0u) << frame_type; + } +} + +TEST_F(HttpDecoderTest, MalformedSettingsFrame) { + char input[30]; + QuicDataWriter writer(30, input); + // Write type SETTINGS. + writer.WriteUInt8(0x04); + // Write length. + writer.WriteVarInt62(2048 * 1024); + + writer.WriteStringPiece("Malformed payload"); + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(5u, decoder_.ProcessInput(input, ABSL_ARRAYSIZE(input))); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_TOO_LARGE)); + EXPECT_EQ("Frame is too large.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, Http2Frame) { + std::string input = absl::HexStringToBytes( + "06" // PING in HTTP/2 but not supported in HTTP/3. + "05" // length + "15"); // random payload + + // Process the full frame. + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(1u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_RECEIVE_SPDY_FRAME)); + EXPECT_EQ("HTTP/2 frame received in a HTTP/3 connection: 6", + decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, HeadersPausedThenData) { + InSequence s; + std::string input = + absl::StrCat(absl::HexStringToBytes("01" // type (HEADERS) + "07"), // length + "Headers", // headers + absl::HexStringToBytes("00" // type (DATA) + "05"), // length + "Data!"); // data + + // Visitor pauses processing, maybe because header decompression is blocked. + EXPECT_CALL(visitor_, OnHeadersFrameStart(2, 7)); + EXPECT_CALL(visitor_, OnHeadersFramePayload(absl::string_view("Headers"))); + EXPECT_CALL(visitor_, OnHeadersFrameEnd()).WillOnce(Return(false)); + absl::string_view remaining_input(input); + QuicByteCount processed_bytes = + ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(9u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + // Process DATA frame. + EXPECT_CALL(visitor_, OnDataFrameStart(2, 5)); + EXPECT_CALL(visitor_, OnDataFramePayload(absl::string_view("Data!"))); + EXPECT_CALL(visitor_, OnDataFrameEnd()); + + processed_bytes = ProcessInput(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, CorruptFrame) { + InSequence s; + + struct { + const char* const input; + const char* const error_message; + } kTestData[] = {{"\x0D" // type (MAX_PUSH_ID) + "\x01" // length + "\x40", // first byte of two-byte varint push id + "Unable to read MAX_PUSH_ID push_id."}, + {"\x0D" // type (MAX_PUSH_ID) + "\x04" // length + "\x05" // valid push id + "foo", // superfluous data + "Superfluous data in MAX_PUSH_ID frame."}, + {"\x07" // type (GOAWAY) + "\x01" // length + "\x40", // first byte of two-byte varint stream id + "Unable to read GOAWAY ID."}, + {"\x07" // type (GOAWAY) + "\x04" // length + "\x05" // valid stream id + "foo", // superfluous data + "Superfluous data in GOAWAY frame."}, + {"\x40\x89" // type (ACCEPT_CH) + "\x01" // length + "\x40", // first byte of two-byte varint origin length + "Unable to read ACCEPT_CH origin."}, + {"\x40\x89" // type (ACCEPT_CH) + "\x01" // length + "\x05", // valid origin length but no origin string + "Unable to read ACCEPT_CH origin."}, + {"\x40\x89" // type (ACCEPT_CH) + "\x04" // length + "\x05" // valid origin length + "foo", // payload ends before origin ends + "Unable to read ACCEPT_CH origin."}, + {"\x40\x89" // type (ACCEPT_CH) + "\x04" // length + "\x03" // valid origin length + "foo", // payload ends at end of origin: no value + "Unable to read ACCEPT_CH value."}, + {"\x40\x89" // type (ACCEPT_CH) + "\x05" // length + "\x03" // valid origin length + "foo" // payload ends at end of origin: no value + "\x40", // first byte of two-byte varint value length + "Unable to read ACCEPT_CH value."}, + {"\x40\x89" // type (ACCEPT_CH) + "\x08" // length + "\x03" // valid origin length + "foo" // origin + "\x05" // valid value length + "bar", // payload ends before value ends + "Unable to read ACCEPT_CH value."}}; + + for (const auto& test_data : kTestData) { + { + HttpDecoder decoder(&visitor_); + EXPECT_CALL(visitor_, OnAcceptChFrameStart(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnError(&decoder)); + + absl::string_view input(test_data.input); + decoder.ProcessInput(input.data(), input.size()); + EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ(test_data.error_message, decoder.error_detail()); + } + { + HttpDecoder decoder(&visitor_); + EXPECT_CALL(visitor_, OnAcceptChFrameStart(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnError(&decoder)); + + absl::string_view input(test_data.input); + for (auto c : input) { + decoder.ProcessInput(&c, 1); + } + EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ(test_data.error_message, decoder.error_detail()); + } + } +} + +TEST_F(HttpDecoderTest, EmptySettingsFrame) { + std::string input = absl::HexStringToBytes( + "04" // type (SETTINGS) + "00"); // frame length + + EXPECT_CALL(visitor_, OnSettingsFrameStart(2)); + + SettingsFrame empty_frame; + EXPECT_CALL(visitor_, OnSettingsFrame(empty_frame)); + + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, EmptyGoAwayFrame) { + std::string input = absl::HexStringToBytes( + "07" // type (GOAWAY) + "00"); // frame length + + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ("Unable to read GOAWAY ID.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, EmptyMaxPushIdFrame) { + std::string input = absl::HexStringToBytes( + "0d" // type (MAX_PUSH_ID) + "00"); // frame length + + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(input.size(), ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ("Unable to read MAX_PUSH_ID push_id.", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, LargeStreamIdInGoAway) { + GoAwayFrame frame; + frame.id = 1ull << 60; + std::string goaway = HttpEncoder::SerializeGoAwayFrame(frame); + EXPECT_CALL(visitor_, OnGoAwayFrame(frame)); + EXPECT_GT(goaway.length(), 0u); + EXPECT_EQ(goaway.length(), + decoder_.ProcessInput(goaway.data(), goaway.length())); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +// Old PRIORITY_UPDATE frame is parsed as unknown frame. +TEST_F(HttpDecoderTest, ObsoletePriorityUpdateFrame) { + const QuicByteCount header_length = 2; + const QuicByteCount payload_length = 3; + InSequence s; + std::string input = absl::HexStringToBytes( + "0f" // type (obsolete PRIORITY_UPDATE) + "03" // length + "666f6f"); // payload "foo" + + // Process frame as a whole. + EXPECT_CALL(visitor_, + OnUnknownFrameStart(0x0f, header_length, payload_length)); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("foo"))); + EXPECT_CALL(visitor_, OnUnknownFrameEnd()).WillOnce(Return(false)); + + EXPECT_EQ(header_length + payload_length, + ProcessInputWithGarbageAppended(input)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process frame byte by byte. + EXPECT_CALL(visitor_, + OnUnknownFrameStart(0x0f, header_length, payload_length)); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("f"))); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("o"))); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("o"))); + EXPECT_CALL(visitor_, OnUnknownFrameEnd()); + + ProcessInputCharByChar(input); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, PriorityUpdateFrame) { + InSequence s; + std::string input1 = absl::HexStringToBytes( + "800f0700" // type (PRIORITY_UPDATE) + "01" // length + "03"); // prioritized element id + + PriorityUpdateFrame priority_update1; + priority_update1.prioritized_element_id = 0x03; + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(5)).WillOnce(Return(false)); + absl::string_view remaining_input(input1); + QuicByteCount processed_bytes = + ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(5u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update1)) + .WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(5)); + EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update1)); + EXPECT_EQ(input1.size(), ProcessInput(input1)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(5)); + EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update1)); + ProcessInputCharByChar(input1); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + std::string input2 = absl::HexStringToBytes( + "800f0700" // type (PRIORITY_UPDATE) + "04" // length + "05" // prioritized element id + "666f6f"); // priority field value: "foo" + + PriorityUpdateFrame priority_update2; + priority_update2.prioritized_element_id = 0x05; + priority_update2.priority_field_value = "foo"; + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(5)).WillOnce(Return(false)); + remaining_input = input2; + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(5u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update2)) + .WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(5)); + EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update2)); + EXPECT_EQ(input2.size(), ProcessInput(input2)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(5)); + EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update2)); + ProcessInputCharByChar(input2); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, CorruptPriorityUpdateFrame) { + std::string payload = + absl::HexStringToBytes("4005"); // prioritized element id + struct { + size_t payload_length; + const char* const error_message; + } kTestData[] = { + {0, "Unable to read prioritized element id."}, + {1, "Unable to read prioritized element id."}, + }; + + for (const auto& test_data : kTestData) { + std::string input = + absl::HexStringToBytes("800f0700"); // type PRIORITY_UPDATE + input.push_back(test_data.payload_length); + size_t header_length = input.size(); + input.append(payload.data(), test_data.payload_length); + + HttpDecoder decoder(&visitor_); + EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(header_length)); + EXPECT_CALL(visitor_, OnError(&decoder)); + + QuicByteCount processed_bytes = + decoder.ProcessInput(input.data(), input.size()); + EXPECT_EQ(input.size(), processed_bytes); + EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ(test_data.error_message, decoder.error_detail()); + } +} + +TEST_F(HttpDecoderTest, AcceptChFrame) { + InSequence s; + std::string input1 = absl::HexStringToBytes( + "4089" // type (ACCEPT_CH) + "00"); // length + + AcceptChFrame accept_ch1; + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnAcceptChFrameStart(3)).WillOnce(Return(false)); + absl::string_view remaining_input(input1); + QuicByteCount processed_bytes = + ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(3u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnAcceptChFrame(accept_ch1)).WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnAcceptChFrameStart(3)); + EXPECT_CALL(visitor_, OnAcceptChFrame(accept_ch1)); + EXPECT_EQ(input1.size(), ProcessInput(input1)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnAcceptChFrameStart(3)); + EXPECT_CALL(visitor_, OnAcceptChFrame(accept_ch1)); + ProcessInputCharByChar(input1); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + std::string input2 = absl::HexStringToBytes( + "4089" // type (ACCEPT_CH) + "08" // length + "03" // length of origin + "666f6f" // origin "foo" + "03" // length of value + "626172"); // value "bar" + + AcceptChFrame accept_ch2; + accept_ch2.entries.push_back({"foo", "bar"}); + + // Visitor pauses processing. + EXPECT_CALL(visitor_, OnAcceptChFrameStart(3)).WillOnce(Return(false)); + remaining_input = input2; + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(3u, processed_bytes); + remaining_input = remaining_input.substr(processed_bytes); + + EXPECT_CALL(visitor_, OnAcceptChFrame(accept_ch2)).WillOnce(Return(false)); + processed_bytes = ProcessInputWithGarbageAppended(remaining_input); + EXPECT_EQ(remaining_input.size(), processed_bytes); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the full frame. + EXPECT_CALL(visitor_, OnAcceptChFrameStart(3)); + EXPECT_CALL(visitor_, OnAcceptChFrame(accept_ch2)); + EXPECT_EQ(input2.size(), ProcessInput(input2)); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); + + // Process the frame incrementally. + EXPECT_CALL(visitor_, OnAcceptChFrameStart(3)); + EXPECT_CALL(visitor_, OnAcceptChFrame(accept_ch2)); + ProcessInputCharByChar(input2); + EXPECT_THAT(decoder_.error(), IsQuicNoError()); + EXPECT_EQ("", decoder_.error_detail()); +} + +TEST_F(HttpDecoderTest, WebTransportStreamDisabled) { + InSequence s; + + // Unknown frame of type 0x41 and length 0x104. + std::string input = absl::HexStringToBytes("40414104"); + EXPECT_CALL(visitor_, OnUnknownFrameStart(0x41, input.size(), 0x104)); + EXPECT_EQ(ProcessInput(input), input.size()); +} + +TEST(HttpDecoderTestNoFixture, WebTransportStream) { + HttpDecoder::Options options; + options.allow_web_transport_stream = true; + testing::StrictMock visitor; + HttpDecoder decoder(&visitor, options); + + // WebTransport stream for session ID 0x104, with four bytes of extra data. + std::string input = absl::HexStringToBytes("40414104ffffffff"); + EXPECT_CALL(visitor, OnWebTransportStreamFrameType(4, 0x104)); + QuicByteCount bytes = decoder.ProcessInput(input.data(), input.size()); + EXPECT_EQ(bytes, 4u); +} + +TEST(HttpDecoderTestNoFixture, WebTransportStreamError) { + HttpDecoder::Options options; + options.allow_web_transport_stream = true; + testing::StrictMock visitor; + HttpDecoder decoder(&visitor, options); + + std::string input = absl::HexStringToBytes("404100"); + EXPECT_CALL(visitor, OnWebTransportStreamFrameType(_, _)); + decoder.ProcessInput(input.data(), input.size()); + + EXPECT_QUIC_BUG( + { + EXPECT_CALL(visitor, OnError(_)); + decoder.ProcessInput(input.data(), input.size()); + }, + "HttpDecoder called after an indefinite-length frame"); +} + +TEST_F(HttpDecoderTest, DecodeSettings) { + std::string input = absl::HexStringToBytes( + "04" // type (SETTINGS) + "07" // length + "01" // identifier (SETTINGS_QPACK_MAX_TABLE_CAPACITY) + "02" // content + "06" // identifier (SETTINGS_MAX_HEADER_LIST_SIZE) + "05" // content + "4100" // identifier, encoded on 2 bytes (0x40), value is 256 (0x100) + "04"); // content + + SettingsFrame frame; + frame.values[1] = 2; + frame.values[6] = 5; + frame.values[256] = 4; + + SettingsFrame out; + EXPECT_TRUE(HttpDecoder::DecodeSettings(input.data(), input.size(), &out)); + EXPECT_EQ(frame, out); + + // non-settings frame. + input = absl::HexStringToBytes( + "0D" // type (MAX_PUSH_ID) + "01" // length + "01"); // Push Id + + EXPECT_FALSE(HttpDecoder::DecodeSettings(input.data(), input.size(), &out)); + + // Corrupt SETTINGS. + input = absl::HexStringToBytes( + "04" // type (SETTINGS) + "01" // length + "42"); // First byte of setting identifier, indicating a 2-byte varint62. + + EXPECT_FALSE(HttpDecoder::DecodeSettings(input.data(), input.size(), &out)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/http_encoder.cc b/quiche/quic/core/http/http_encoder.cc new file mode 100644 index 000000000000..de40d4af486a --- /dev/null +++ b/quiche/quic/core/http/http_encoder.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/http_encoder.h" + +#include +#include + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +bool WriteFrameHeader(QuicByteCount length, HttpFrameType type, + QuicDataWriter* writer) { + return writer->WriteVarInt62(static_cast(type)) && + writer->WriteVarInt62(length); +} + +QuicByteCount GetTotalLength(QuicByteCount payload_length, HttpFrameType type) { + return QuicDataWriter::GetVarInt62Len(payload_length) + + QuicDataWriter::GetVarInt62Len(static_cast(type)) + + payload_length; +} + +} // namespace + +QuicByteCount HttpEncoder::GetDataFrameHeaderLength( + QuicByteCount payload_length) { + QUICHE_DCHECK_NE(0u, payload_length); + return QuicDataWriter::GetVarInt62Len(payload_length) + + QuicDataWriter::GetVarInt62Len( + static_cast(HttpFrameType::DATA)); +} + +quiche::QuicheBuffer HttpEncoder::SerializeDataFrameHeader( + QuicByteCount payload_length, quiche::QuicheBufferAllocator* allocator) { + QUICHE_DCHECK_NE(0u, payload_length); + QuicByteCount header_length = GetDataFrameHeaderLength(payload_length); + + quiche::QuicheBuffer header(allocator, header_length); + QuicDataWriter writer(header.size(), header.data()); + + if (WriteFrameHeader(payload_length, HttpFrameType::DATA, &writer)) { + return header; + } + QUIC_DLOG(ERROR) + << "Http encoder failed when attempting to serialize data frame header."; + return quiche::QuicheBuffer(); +} + +std::string HttpEncoder::SerializeHeadersFrameHeader( + QuicByteCount payload_length) { + QUICHE_DCHECK_NE(0u, payload_length); + QuicByteCount header_length = + QuicDataWriter::GetVarInt62Len(payload_length) + + QuicDataWriter::GetVarInt62Len( + static_cast(HttpFrameType::HEADERS)); + + std::string frame; + frame.resize(header_length); + QuicDataWriter writer(header_length, frame.data()); + + if (WriteFrameHeader(payload_length, HttpFrameType::HEADERS, &writer)) { + return frame; + } + QUIC_DLOG(ERROR) + << "Http encoder failed when attempting to serialize headers " + "frame header."; + return {}; +} + +std::string HttpEncoder::SerializeSettingsFrame(const SettingsFrame& settings) { + QuicByteCount payload_length = 0; + std::vector> ordered_settings{ + settings.values.begin(), settings.values.end()}; + std::sort(ordered_settings.begin(), ordered_settings.end()); + // Calculate the payload length. + for (const auto& p : ordered_settings) { + payload_length += QuicDataWriter::GetVarInt62Len(p.first); + payload_length += QuicDataWriter::GetVarInt62Len(p.second); + } + + QuicByteCount total_length = + GetTotalLength(payload_length, HttpFrameType::SETTINGS); + + std::string frame; + frame.resize(total_length); + QuicDataWriter writer(total_length, frame.data()); + + if (!WriteFrameHeader(payload_length, HttpFrameType::SETTINGS, &writer)) { + QUIC_DLOG(ERROR) << "Http encoder failed when attempting to serialize " + "settings frame header."; + return {}; + } + + for (const auto& p : ordered_settings) { + if (!writer.WriteVarInt62(p.first) || !writer.WriteVarInt62(p.second)) { + QUIC_DLOG(ERROR) << "Http encoder failed when attempting to serialize " + "settings frame payload."; + return {}; + } + } + + return frame; +} + +std::string HttpEncoder::SerializeGoAwayFrame(const GoAwayFrame& goaway) { + QuicByteCount payload_length = QuicDataWriter::GetVarInt62Len(goaway.id); + QuicByteCount total_length = + GetTotalLength(payload_length, HttpFrameType::GOAWAY); + + std::string frame; + frame.resize(total_length); + QuicDataWriter writer(total_length, frame.data()); + + if (WriteFrameHeader(payload_length, HttpFrameType::GOAWAY, &writer) && + writer.WriteVarInt62(goaway.id)) { + return frame; + } + QUIC_DLOG(ERROR) + << "Http encoder failed when attempting to serialize goaway frame."; + return {}; +} + +std::string HttpEncoder::SerializePriorityUpdateFrame( + const PriorityUpdateFrame& priority_update) { + QuicByteCount payload_length = + QuicDataWriter::GetVarInt62Len(priority_update.prioritized_element_id) + + priority_update.priority_field_value.size(); + QuicByteCount total_length = GetTotalLength( + payload_length, HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM); + + std::string frame; + frame.resize(total_length); + QuicDataWriter writer(total_length, frame.data()); + + if (WriteFrameHeader(payload_length, + HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM, + &writer) && + writer.WriteVarInt62(priority_update.prioritized_element_id) && + writer.WriteBytes(priority_update.priority_field_value.data(), + priority_update.priority_field_value.size())) { + return frame; + } + + QUIC_DLOG(ERROR) << "Http encoder failed when attempting to serialize " + "PRIORITY_UPDATE frame."; + return {}; +} + +std::string HttpEncoder::SerializeAcceptChFrame( + const AcceptChFrame& accept_ch) { + QuicByteCount payload_length = 0; + for (const auto& entry : accept_ch.entries) { + payload_length += QuicDataWriter::GetVarInt62Len(entry.origin.size()); + payload_length += entry.origin.size(); + payload_length += QuicDataWriter::GetVarInt62Len(entry.value.size()); + payload_length += entry.value.size(); + } + + QuicByteCount total_length = + GetTotalLength(payload_length, HttpFrameType::ACCEPT_CH); + + std::string frame; + frame.resize(total_length); + QuicDataWriter writer(total_length, frame.data()); + + if (!WriteFrameHeader(payload_length, HttpFrameType::ACCEPT_CH, &writer)) { + QUIC_DLOG(ERROR) + << "Http encoder failed to serialize ACCEPT_CH frame header."; + return {}; + } + + for (const auto& entry : accept_ch.entries) { + if (!writer.WriteStringPieceVarInt62(entry.origin) || + !writer.WriteStringPieceVarInt62(entry.value)) { + QUIC_DLOG(ERROR) + << "Http encoder failed to serialize ACCEPT_CH frame payload."; + return {}; + } + } + + return frame; +} + +std::string HttpEncoder::SerializeGreasingFrame() { + uint64_t frame_type; + QuicByteCount payload_length; + std::string payload; + if (!GetQuicFlag(quic_enable_http3_grease_randomness)) { + frame_type = 0x40; + payload_length = 1; + payload = "a"; + } else { + uint32_t result; + QuicRandom::GetInstance()->RandBytes(&result, sizeof(result)); + frame_type = 0x1fULL * static_cast(result) + 0x21ULL; + + // The payload length is random but within [0, 3]; + payload_length = result % 4; + + if (payload_length > 0) { + payload.resize(payload_length); + QuicRandom::GetInstance()->RandBytes(payload.data(), payload_length); + } + } + QuicByteCount total_length = QuicDataWriter::GetVarInt62Len(frame_type) + + QuicDataWriter::GetVarInt62Len(payload_length) + + payload_length; + + std::string frame; + frame.resize(total_length); + QuicDataWriter writer(total_length, frame.data()); + + bool success = + writer.WriteVarInt62(frame_type) && writer.WriteVarInt62(payload_length); + + if (payload_length > 0) { + success &= writer.WriteBytes(payload.data(), payload_length); + } + + if (success) { + return frame; + } + + QUIC_DLOG(ERROR) << "Http encoder failed when attempting to serialize " + "greasing frame."; + return {}; +} + +std::string HttpEncoder::SerializeWebTransportStreamFrameHeader( + WebTransportSessionId session_id) { + uint64_t stream_type = + static_cast(HttpFrameType::WEBTRANSPORT_STREAM); + QuicByteCount header_length = QuicDataWriter::GetVarInt62Len(stream_type) + + QuicDataWriter::GetVarInt62Len(session_id); + + std::string frame; + frame.resize(header_length); + QuicDataWriter writer(header_length, frame.data()); + + bool success = + writer.WriteVarInt62(stream_type) && writer.WriteVarInt62(session_id); + if (success && writer.remaining() == 0) { + return frame; + } + + QUIC_DLOG(ERROR) << "Http encoder failed when attempting to serialize " + "WEBTRANSPORT_STREAM frame header."; + return {}; +} + +std::string HttpEncoder::SerializeMetadataFrameHeader( + QuicByteCount payload_length) { + QUICHE_DCHECK_NE(0u, payload_length); + QuicByteCount header_length = + QuicDataWriter::GetVarInt62Len(payload_length) + + QuicDataWriter::GetVarInt62Len( + static_cast(HttpFrameType::METADATA)); + + std::string frame; + frame.resize(header_length); + QuicDataWriter writer(header_length, frame.data()); + + if (WriteFrameHeader(payload_length, HttpFrameType::METADATA, &writer)) { + return frame; + } + QUIC_DLOG(ERROR) + << "Http encoder failed when attempting to serialize METADATA " + "frame header."; + return {}; +} + +} // namespace quic diff --git a/quiche/quic/core/http/http_encoder.h b/quiche/quic/core/http/http_encoder.h new file mode 100644 index 000000000000..28a8a2cd100a --- /dev/null +++ b/quiche/quic/core/http/http_encoder.h @@ -0,0 +1,65 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_HTTP_ENCODER_H_ +#define QUICHE_QUIC_CORE_HTTP_HTTP_ENCODER_H_ + +#include + +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quic { + +class QuicDataWriter; + +// A class for encoding the HTTP frames that are exchanged in an HTTP over QUIC +// session. +class QUIC_EXPORT_PRIVATE HttpEncoder { + public: + HttpEncoder() = delete; + + // Returns the length of the header for a DATA frame. + static QuicByteCount GetDataFrameHeaderLength(QuicByteCount payload_length); + + // Serializes a DATA frame header into a QuicheBuffer; returns said + // QuicheBuffer on success, empty buffer otherwise. + static quiche::QuicheBuffer SerializeDataFrameHeader( + QuicByteCount payload_length, quiche::QuicheBufferAllocator* allocator); + + // Serializes a HEADERS frame header. + static std::string SerializeHeadersFrameHeader(QuicByteCount payload_length); + + // Serializes a SETTINGS frame. + static std::string SerializeSettingsFrame(const SettingsFrame& settings); + + // Serializes a GOAWAY frame. + static std::string SerializeGoAwayFrame(const GoAwayFrame& goaway); + + // Serializes a PRIORITY_UPDATE frame. + static std::string SerializePriorityUpdateFrame( + const PriorityUpdateFrame& priority_update); + + // Serializes an ACCEPT_CH frame. + static std::string SerializeAcceptChFrame(const AcceptChFrame& accept_ch); + + // Serializes a frame with reserved frame type specified in + // https://tools.ietf.org/html/draft-ietf-quic-http-25#section-7.2.9. + static std::string SerializeGreasingFrame(); + + // Serializes a WEBTRANSPORT_STREAM frame header as specified in + // https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-00.html#name-client-initiated-bidirectio + static std::string SerializeWebTransportStreamFrameHeader( + WebTransportSessionId session_id); + + // Serializes a METADATA frame header. + static std::string SerializeMetadataFrameHeader(QuicByteCount payload_length); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_HTTP_ENCODER_H_ diff --git a/quiche/quic/core/http/http_encoder_test.cc b/quiche/quic/core/http/http_encoder_test.cc new file mode 100644 index 000000000000..648799232655 --- /dev/null +++ b/quiche/quic/core/http/http_encoder_test.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/http_encoder.h" + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { + +TEST(HttpEncoderTest, SerializeDataFrameHeader) { + quiche::QuicheBuffer buffer = HttpEncoder::SerializeDataFrameHeader( + /* payload_length = */ 5, quiche::SimpleBufferAllocator::Get()); + char output[] = {0x00, // type (DATA) + 0x05}; // length + EXPECT_EQ(ABSL_ARRAYSIZE(output), buffer.size()); + quiche::test::CompareCharArraysWithHexError( + "DATA", buffer.data(), buffer.size(), output, ABSL_ARRAYSIZE(output)); +} + +TEST(HttpEncoderTest, SerializeHeadersFrameHeader) { + std::string header = + HttpEncoder::SerializeHeadersFrameHeader(/* payload_length = */ 7); + char output[] = {0x01, // type (HEADERS) + 0x07}; // length + quiche::test::CompareCharArraysWithHexError("HEADERS", header.data(), + header.length(), output, + ABSL_ARRAYSIZE(output)); +} + +TEST(HttpEncoderTest, SerializeSettingsFrame) { + SettingsFrame settings; + settings.values[1] = 2; + settings.values[6] = 5; + settings.values[256] = 4; + char output[] = {0x04, // type (SETTINGS) + 0x07, // length + 0x01, // identifier (SETTINGS_QPACK_MAX_TABLE_CAPACITY) + 0x02, // content + 0x06, // identifier (SETTINGS_MAX_HEADER_LIST_SIZE) + 0x05, // content + 0x41, 0x00, // identifier 0x100, varint encoded + 0x04}; // content + std::string frame = HttpEncoder::SerializeSettingsFrame(settings); + quiche::test::CompareCharArraysWithHexError( + "SETTINGS", frame.data(), frame.length(), output, ABSL_ARRAYSIZE(output)); +} + +TEST(HttpEncoderTest, SerializeGoAwayFrame) { + GoAwayFrame goaway; + goaway.id = 0x1; + char output[] = {0x07, // type (GOAWAY) + 0x1, // length + 0x01}; // ID + std::string frame = HttpEncoder::SerializeGoAwayFrame(goaway); + quiche::test::CompareCharArraysWithHexError( + "GOAWAY", frame.data(), frame.length(), output, ABSL_ARRAYSIZE(output)); +} + +TEST(HttpEncoderTest, SerializePriorityUpdateFrame) { + PriorityUpdateFrame priority_update1; + priority_update1.prioritized_element_id = 0x03; + uint8_t output1[] = {0x80, 0x0f, 0x07, 0x00, // type (PRIORITY_UPDATE) + 0x01, // length + 0x03}; // prioritized element id + + std::string frame1 = + HttpEncoder::SerializePriorityUpdateFrame(priority_update1); + quiche::test::CompareCharArraysWithHexError( + "PRIORITY_UPDATE", frame1.data(), frame1.length(), + reinterpret_cast(output1), ABSL_ARRAYSIZE(output1)); + + PriorityUpdateFrame priority_update2; + priority_update2.prioritized_element_id = 0x05; + priority_update2.priority_field_value = "foo"; + + uint8_t output2[] = {0x80, 0x0f, 0x07, 0x00, // type (PRIORITY_UPDATE) + 0x04, // length + 0x05, // prioritized element id + 0x66, 0x6f, 0x6f}; // priority field value: "foo" + + std::string frame2 = + HttpEncoder::SerializePriorityUpdateFrame(priority_update2); + quiche::test::CompareCharArraysWithHexError( + "PRIORITY_UPDATE", frame2.data(), frame2.length(), + reinterpret_cast(output2), ABSL_ARRAYSIZE(output2)); +} + +TEST(HttpEncoderTest, SerializeAcceptChFrame) { + AcceptChFrame accept_ch; + uint8_t output1[] = {0x40, 0x89, // type (ACCEPT_CH) + 0x00}; // length + + std::string frame1 = HttpEncoder::SerializeAcceptChFrame(accept_ch); + quiche::test::CompareCharArraysWithHexError( + "ACCEPT_CH", frame1.data(), frame1.length(), + reinterpret_cast(output1), ABSL_ARRAYSIZE(output1)); + + accept_ch.entries.push_back({"foo", "bar"}); + uint8_t output2[] = {0x40, 0x89, // type (ACCEPT_CH) + 0x08, // payload length + 0x03, 0x66, 0x6f, 0x6f, // length of "foo"; "foo" + 0x03, 0x62, 0x61, 0x72}; // length of "bar"; "bar" + + std::string frame2 = HttpEncoder::SerializeAcceptChFrame(accept_ch); + quiche::test::CompareCharArraysWithHexError( + "ACCEPT_CH", frame2.data(), frame2.length(), + reinterpret_cast(output2), ABSL_ARRAYSIZE(output2)); +} + +TEST(HttpEncoderTest, SerializeWebTransportStreamFrameHeader) { + WebTransportSessionId session_id = 0x17; + char output[] = {0x40, 0x41, // type (WEBTRANSPORT_STREAM) + 0x17}; // session ID + + std::string frame = + HttpEncoder::SerializeWebTransportStreamFrameHeader(session_id); + quiche::test::CompareCharArraysWithHexError("WEBTRANSPORT_STREAM", + frame.data(), frame.length(), + output, sizeof(output)); +} + +TEST(HttpEncoderTest, SerializeMetadataFrameHeader) { + std::string frame = HttpEncoder::SerializeMetadataFrameHeader( + /* payload_length = */ 7); + char output[] = {0x40, 0x4d, // type (METADATA, 0x4d, varint encoded) + 0x07}; // length + quiche::test::CompareCharArraysWithHexError( + "METADATA", frame.data(), frame.length(), output, ABSL_ARRAYSIZE(output)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/http_frames.h b/quiche/quic/core/http/http_frames.h new file mode 100644 index 000000000000..56caca89fb9b --- /dev/null +++ b/quiche/quic/core/http/http_frames.h @@ -0,0 +1,163 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_HTTP_FRAMES_H_ +#define QUICHE_QUIC_CORE_HTTP_HTTP_FRAMES_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace quic { + +enum class HttpFrameType { + DATA = 0x0, + HEADERS = 0x1, + CANCEL_PUSH = 0X3, + SETTINGS = 0x4, + PUSH_PROMISE = 0x5, + GOAWAY = 0x7, + MAX_PUSH_ID = 0xD, + // https://tools.ietf.org/html/draft-davidben-http-client-hint-reliability-02 + ACCEPT_CH = 0x89, + // https://tools.ietf.org/html/draft-ietf-httpbis-priority-03 + PRIORITY_UPDATE_REQUEST_STREAM = 0xF0700, + // https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-00.html + WEBTRANSPORT_STREAM = 0x41, + METADATA = 0x4d, +}; + +// 7.2.1. DATA +// +// DATA frames (type=0x0) convey arbitrary, variable-length sequences of +// octets associated with an HTTP request or response payload. +struct QUIC_EXPORT_PRIVATE DataFrame { + absl::string_view data; +}; + +// 7.2.2. HEADERS +// +// The HEADERS frame (type=0x1) is used to carry a header block, +// compressed using QPACK. +struct QUIC_EXPORT_PRIVATE HeadersFrame { + absl::string_view headers; +}; + +// 7.2.4. SETTINGS +// +// The SETTINGS frame (type=0x4) conveys configuration parameters that +// affect how endpoints communicate, such as preferences and constraints +// on peer behavior + +using SettingsMap = absl::flat_hash_map; + +struct QUIC_EXPORT_PRIVATE SettingsFrame { + SettingsMap values; + + bool operator==(const SettingsFrame& rhs) const { + return values == rhs.values; + } + + std::string ToString() const { + std::string s; + for (auto it : values) { + std::string setting = absl::StrCat( + H3SettingsToString( + static_cast(it.first)), + " = ", it.second, "; "); + absl::StrAppend(&s, setting); + } + return s; + } + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const SettingsFrame& s) { + os << s.ToString(); + return os; + } +}; + +// 7.2.6. GOAWAY +// +// The GOAWAY frame (type=0x7) is used to initiate shutdown of a connection by +// either endpoint. +struct QUIC_EXPORT_PRIVATE GoAwayFrame { + // When sent from server to client, |id| is a stream ID that should refer to + // a client-initiated bidirectional stream. + // When sent from client to server, |id| is a push ID. + uint64_t id; + + bool operator==(const GoAwayFrame& rhs) const { return id == rhs.id; } +}; + +// https://httpwg.org/http-extensions/draft-ietf-httpbis-priority.html +// +// The PRIORITY_UPDATE frame specifies the sender-advised priority of a stream. +// Frame type 0xf0700 (called PRIORITY_UPDATE_REQUEST_STREAM in the +// implementation) is used for for request streams. +// Frame type 0xf0701 would be used for push streams but it is not implemented; +// incoming 0xf0701 frames are treated as frames of unknown type. + +// Length of a priority frame's first byte. +const QuicByteCount kPriorityFirstByteLength = 1; + +struct QUIC_EXPORT_PRIVATE PriorityUpdateFrame { + uint64_t prioritized_element_id = 0; + std::string priority_field_value; + + bool operator==(const PriorityUpdateFrame& rhs) const { + return std::tie(prioritized_element_id, priority_field_value) == + std::tie(rhs.prioritized_element_id, rhs.priority_field_value); + } + std::string ToString() const { + return absl::StrCat( + "Priority Frame : {prioritized_element_id: ", prioritized_element_id, + ", priority_field_value: ", priority_field_value, "}"); + } + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const PriorityUpdateFrame& s) { + os << s.ToString(); + return os; + } +}; + +// ACCEPT_CH +// https://tools.ietf.org/html/draft-davidben-http-client-hint-reliability-02 +// +struct QUIC_EXPORT_PRIVATE AcceptChFrame { + std::vector entries; + + bool operator==(const AcceptChFrame& rhs) const { + return entries.size() == rhs.entries.size() && + std::equal(entries.begin(), entries.end(), rhs.entries.begin()); + } + + std::string ToString() const { + std::stringstream s; + s << *this; + return s.str(); + } + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const AcceptChFrame& frame) { + os << "ACCEPT_CH frame with " << frame.entries.size() << " entries: "; + for (auto& entry : frame.entries) { + os << "origin: " << entry.origin << "; value: " << entry.value; + } + return os; + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_HTTP_FRAMES_H_ diff --git a/quiche/quic/core/http/http_frames_test.cc b/quiche/quic/core/http/http_frames_test.cc new file mode 100644 index 000000000000..ba633717df6a --- /dev/null +++ b/quiche/quic/core/http/http_frames_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/http_frames.h" + +#include + +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +TEST(HttpFramesTest, SettingsFrame) { + SettingsFrame a; + EXPECT_TRUE(a == a); + EXPECT_EQ("", a.ToString()); + + SettingsFrame b; + b.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 1; + EXPECT_FALSE(a == b); + EXPECT_TRUE(b == b); + + a.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 2; + EXPECT_FALSE(a == b); + a.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 1; + EXPECT_TRUE(a == b); + + EXPECT_EQ("SETTINGS_QPACK_MAX_TABLE_CAPACITY = 1; ", b.ToString()); + std::stringstream s; + s << b; + EXPECT_EQ("SETTINGS_QPACK_MAX_TABLE_CAPACITY = 1; ", s.str()); +} + +TEST(HttpFramesTest, GoAwayFrame) { + GoAwayFrame a{1}; + EXPECT_TRUE(a == a); + + GoAwayFrame b{2}; + EXPECT_FALSE(a == b); + + b.id = 1; + EXPECT_TRUE(a == b); +} + +TEST(HttpFramesTest, PriorityUpdateFrame) { + PriorityUpdateFrame a{0, ""}; + EXPECT_TRUE(a == a); + PriorityUpdateFrame b{4, ""}; + EXPECT_FALSE(a == b); + a.prioritized_element_id = 4; + EXPECT_TRUE(a == b); + + a.priority_field_value = "foo"; + EXPECT_FALSE(a == b); + + EXPECT_EQ( + "Priority Frame : {prioritized_element_id: 4, priority_field_value: foo}", + a.ToString()); + std::stringstream s; + s << a; + EXPECT_EQ( + "Priority Frame : {prioritized_element_id: 4, priority_field_value: foo}", + s.str()); +} + +TEST(HttpFramesTest, AcceptChFrame) { + AcceptChFrame a; + EXPECT_TRUE(a == a); + EXPECT_EQ("ACCEPT_CH frame with 0 entries: ", a.ToString()); + + AcceptChFrame b{{{"foo", "bar"}}}; + EXPECT_FALSE(a == b); + + a.entries.push_back({"foo", "bar"}); + EXPECT_TRUE(a == b); + + EXPECT_EQ("ACCEPT_CH frame with 1 entries: origin: foo; value: bar", + a.ToString()); + std::stringstream s; + s << a; + EXPECT_EQ("ACCEPT_CH frame with 1 entries: origin: foo; value: bar", s.str()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_client_promised_info.cc b/quiche/quic/core/http/quic_client_promised_info.cc new file mode 100644 index 000000000000..db53585adbad --- /dev/null +++ b/quiche/quic/core/http/quic_client_promised_info.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_client_promised_info.h" + +#include +#include + +#include "quiche/quic/core/http/spdy_server_push_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/spdy/core/spdy_protocol.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +QuicClientPromisedInfo::QuicClientPromisedInfo( + QuicSpdyClientSessionBase* session, QuicStreamId id, std::string url) + : session_(session), + id_(id), + url_(std::move(url)), + client_request_delegate_(nullptr) {} + +QuicClientPromisedInfo::~QuicClientPromisedInfo() { + if (cleanup_alarm_ != nullptr) { + cleanup_alarm_->PermanentCancel(); + } +} + +void QuicClientPromisedInfo::CleanupAlarm::OnAlarm() { + QUIC_DVLOG(1) << "self GC alarm for stream " << promised_->id_; + promised_->session()->OnPushStreamTimedOut(promised_->id_); + promised_->Reset(QUIC_PUSH_STREAM_TIMED_OUT); +} + +void QuicClientPromisedInfo::Init() { + cleanup_alarm_.reset(session_->connection()->alarm_factory()->CreateAlarm( + new QuicClientPromisedInfo::CleanupAlarm(this))); + cleanup_alarm_->Set( + session_->connection()->helper()->GetClock()->ApproximateNow() + + QuicTime::Delta::FromSeconds(kPushPromiseTimeoutSecs)); +} + +bool QuicClientPromisedInfo::OnPromiseHeaders(const Http2HeaderBlock& headers) { + // RFC7540, Section 8.2, requests MUST be safe [RFC7231], Section + // 4.2.1. GET and HEAD are the methods that are safe and required. + Http2HeaderBlock::const_iterator it = headers.find(spdy::kHttp2MethodHeader); + if (it == headers.end()) { + QUIC_DVLOG(1) << "Promise for stream " << id_ << " has no method"; + Reset(QUIC_INVALID_PROMISE_METHOD); + return false; + } + if (!(it->second == "GET" || it->second == "HEAD")) { + QUIC_DVLOG(1) << "Promise for stream " << id_ << " has invalid method " + << it->second; + Reset(QUIC_INVALID_PROMISE_METHOD); + return false; + } + if (!SpdyServerPushUtils::PromisedUrlIsValid(headers)) { + QUIC_DVLOG(1) << "Promise for stream " << id_ << " has invalid URL " + << url_; + Reset(QUIC_INVALID_PROMISE_URL); + return false; + } + if (!session_->IsAuthorized( + SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers))) { + Reset(QUIC_UNAUTHORIZED_PROMISE_URL); + return false; + } + request_headers_ = headers.Clone(); + return true; +} + +void QuicClientPromisedInfo::OnResponseHeaders( + const Http2HeaderBlock& headers) { + response_headers_ = std::make_unique(headers.Clone()); + if (client_request_delegate_) { + // We already have a client request waiting. + FinalValidation(); + } +} + +void QuicClientPromisedInfo::Reset(QuicRstStreamErrorCode error_code) { + QuicClientPushPromiseIndex::Delegate* delegate = client_request_delegate_; + session_->ResetPromised(id_, error_code); + session_->DeletePromised(this); + if (delegate) { + delegate->OnRendezvousResult(nullptr); + } +} + +QuicAsyncStatus QuicClientPromisedInfo::FinalValidation() { + if (!client_request_delegate_->CheckVary( + client_request_headers_, request_headers_, *response_headers_)) { + Reset(QUIC_PROMISE_VARY_MISMATCH); + return QUIC_FAILURE; + } + QuicSpdyStream* stream = session_->GetPromisedStream(id_); + if (!stream) { + // This shouldn't be possible, as |ClientRequest| guards against + // closed stream for the synchronous case. And in the + // asynchronous case, a RST can only be caught by |OnAlarm()|. + QUIC_BUG(quic_bug_10378_1) << "missing promised stream" << id_; + } + QuicClientPushPromiseIndex::Delegate* delegate = client_request_delegate_; + session_->DeletePromised(this); + // Stream can start draining now + if (delegate) { + delegate->OnRendezvousResult(stream); + } + return QUIC_SUCCESS; +} + +QuicAsyncStatus QuicClientPromisedInfo::HandleClientRequest( + const Http2HeaderBlock& request_headers, + QuicClientPushPromiseIndex::Delegate* delegate) { + if (session_->IsClosedStream(id_)) { + // There was a RST on the response stream. + session_->DeletePromised(this); + return QUIC_FAILURE; + } + + if (is_validating()) { + // The push promise has already been matched to another request though + // pending for validation. Returns QUIC_FAILURE to the caller as it couldn't + // match a new request any more. This will not affect the validation of the + // other request. + return QUIC_FAILURE; + } + + client_request_delegate_ = delegate; + client_request_headers_ = request_headers.Clone(); + if (response_headers_ == nullptr) { + return QUIC_PENDING; + } + return FinalValidation(); +} + +void QuicClientPromisedInfo::Cancel() { + // Don't fire OnRendezvousResult() for client initiated cancel. + client_request_delegate_ = nullptr; + Reset(QUIC_STREAM_CANCELLED); +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_client_promised_info.h b/quiche/quic/core/http/quic_client_promised_info.h new file mode 100644 index 000000000000..7b4f460c2ff6 --- /dev/null +++ b/quiche/quic/core/http/quic_client_promised_info.h @@ -0,0 +1,115 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_CLIENT_PROMISED_INFO_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_CLIENT_PROMISED_INFO_H_ + +#include +#include + +#include "quiche/quic/core/http/quic_client_push_promise_index.h" +#include "quiche/quic/core/http/quic_spdy_client_session_base.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +namespace test { +class QuicClientPromisedInfoPeer; +} // namespace test + +// QuicClientPromisedInfo tracks the client state of a server push +// stream from the time a PUSH_PROMISE is received until rendezvous +// between the promised response and the corresponding client request +// is complete. +class QUIC_EXPORT_PRIVATE QuicClientPromisedInfo + : public QuicClientPushPromiseIndex::TryHandle { + public: + // Interface to QuicSpdyClientStream + QuicClientPromisedInfo(QuicSpdyClientSessionBase* session, QuicStreamId id, + std::string url); + QuicClientPromisedInfo(const QuicClientPromisedInfo&) = delete; + QuicClientPromisedInfo& operator=(const QuicClientPromisedInfo&) = delete; + virtual ~QuicClientPromisedInfo(); + + void Init(); + + // Validate promise headers etc. Returns true if headers are valid. + bool OnPromiseHeaders(const spdy::Http2HeaderBlock& headers); + + // Store response, possibly proceed with final validation. + void OnResponseHeaders(const spdy::Http2HeaderBlock& headers); + + // Rendezvous between this promised stream and a client request that + // has a matching URL. + virtual QuicAsyncStatus HandleClientRequest( + const spdy::Http2HeaderBlock& headers, + QuicClientPushPromiseIndex::Delegate* delegate); + + void Cancel() override; + + void Reset(QuicRstStreamErrorCode error_code); + + // Client requests are initially associated to promises by matching + // URL in the client request against the URL in the promise headers, + // uing the |promised_by_url| map. The push can be cross-origin, so + // the client should validate that the session is authoritative for + // the promised URL. If not, it should call |RejectUnauthorized|. + QuicSpdyClientSessionBase* session() { return session_; } + + // If the promised response contains Vary header, then the fields + // specified by Vary must match between the client request header + // and the promise headers (see https://crbug.com//554220). Vary + // validation requires the response headers (for the actual Vary + // field list), the promise headers (taking the role of the "cached" + // request), and the client request headers. + spdy::Http2HeaderBlock* request_headers() { return &request_headers_; } + + spdy::Http2HeaderBlock* response_headers() { return response_headers_.get(); } + + // After validation, client will use this to access the pushed stream. + + QuicStreamId id() const { return id_; } + + const std::string url() const { return url_; } + + // Return true if there's a request pending matching this push promise. + bool is_validating() const { return client_request_delegate_ != nullptr; } + + private: + friend class test::QuicClientPromisedInfoPeer; + + class QUIC_EXPORT_PRIVATE CleanupAlarm + : public QuicAlarm::DelegateWithoutContext { + public: + explicit CleanupAlarm(QuicClientPromisedInfo* promised) + : promised_(promised) {} + + void OnAlarm() override; + + QuicClientPromisedInfo* promised_; + }; + + QuicAsyncStatus FinalValidation(); + + QuicSpdyClientSessionBase* session_; + QuicStreamId id_; + std::string url_; + spdy::Http2HeaderBlock request_headers_; + std::unique_ptr response_headers_; + spdy::Http2HeaderBlock client_request_headers_; + QuicClientPushPromiseIndex::Delegate* client_request_delegate_; + + // The promise will commit suicide eventually if it is not claimed by a GET + // first. + std::unique_ptr cleanup_alarm_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_CLIENT_PROMISED_INFO_H_ diff --git a/quiche/quic/core/http/quic_client_promised_info_test.cc b/quiche/quic/core/http/quic_client_promised_info_test.cc new file mode 100644 index 000000000000..469c0c2384f6 --- /dev/null +++ b/quiche/quic/core/http/quic_client_promised_info_test.cc @@ -0,0 +1,350 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_client_promised_info.h" + +#include +#include +#include + +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/http/spdy_server_push_utils.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_client_promised_info_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using spdy::Http2HeaderBlock; +using testing::_; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class MockQuicSpdyClientSession : public QuicSpdyClientSession { + public: + explicit MockQuicSpdyClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicClientPushPromiseIndex* push_promise_index) + : QuicSpdyClientSession(DefaultQuicConfig(), supported_versions, + connection, + QuicServerId("example.com", 443, false), + &crypto_config_, push_promise_index), + crypto_config_(crypto_test_utils::ProofVerifierForTesting()), + authorized_(true) {} + MockQuicSpdyClientSession(const MockQuicSpdyClientSession&) = delete; + MockQuicSpdyClientSession& operator=(const MockQuicSpdyClientSession&) = + delete; + ~MockQuicSpdyClientSession() override {} + + bool IsAuthorized(const std::string& /*authority*/) override { + return authorized_; + } + + void set_authorized(bool authorized) { authorized_ = authorized; } + + MOCK_METHOD(bool, WriteControlFrame, + (const QuicFrame& frame, TransmissionType type), (override)); + + private: + QuicCryptoClientConfig crypto_config_; + + bool authorized_; +}; + +class QuicClientPromisedInfoTest : public QuicTest { + public: + class StreamVisitor; + + QuicClientPromisedInfoTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_CLIENT)), + session_(connection_->supported_versions(), connection_, + &push_promise_index_), + body_("hello world"), + promise_id_( + QuicUtils::GetInvalidStreamId(connection_->transport_version())) { + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + session_.Initialize(); + + headers_[":status"] = "200"; + headers_["content-length"] = "11"; + + stream_ = std::make_unique( + GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 0), + &session_, BIDIRECTIONAL); + stream_visitor_ = std::make_unique(); + stream_->set_visitor(stream_visitor_.get()); + + push_promise_[":path"] = "/bar"; + push_promise_[":authority"] = "www.google.com"; + push_promise_[":method"] = "GET"; + push_promise_[":scheme"] = "https"; + + promise_url_ = + SpdyServerPushUtils::GetPromisedUrlFromHeaders(push_promise_); + + client_request_ = push_promise_.Clone(); + promise_id_ = GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), 0); + } + + class StreamVisitor : public QuicSpdyClientStream::Visitor { + void OnClose(QuicSpdyStream* stream) override { + QUIC_DVLOG(1) << "stream " << stream->id(); + } + }; + + void ReceivePromise(QuicStreamId id) { + auto headers = AsHeaderList(push_promise_); + stream_->OnPromiseHeaderList(id, headers.uncompressed_header_bytes(), + headers); + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + QuicClientPushPromiseIndex push_promise_index_; + + MockQuicSpdyClientSession session_; + std::unique_ptr stream_; + std::unique_ptr stream_visitor_; + std::unique_ptr promised_stream_; + Http2HeaderBlock headers_; + std::string body_; + Http2HeaderBlock push_promise_; + QuicStreamId promise_id_; + std::string promise_url_; + Http2HeaderBlock client_request_; +}; + +TEST_F(QuicClientPromisedInfoTest, PushPromise) { + ReceivePromise(promise_id_); + + // Verify that the promise is in the unclaimed streams map. + EXPECT_NE(session_.GetPromisedById(promise_id_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseCleanupAlarm) { + ReceivePromise(promise_id_); + + // Verify that the promise is in the unclaimed streams map. + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + ASSERT_NE(promised, nullptr); + + // Fire the alarm that will cancel the promised stream. + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_PUSH_STREAM_TIMED_OUT)); + alarm_factory_.FireAlarm(QuicClientPromisedInfoPeer::GetAlarm(promised)); + + // Verify that the promise is gone after the alarm fires. + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); + EXPECT_EQ(session_.GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseInvalidMethod) { + // Promise with an unsafe method + push_promise_[":method"] = "PUT"; + + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_INVALID_PROMISE_METHOD)); + ReceivePromise(promise_id_); + + // Verify that the promise headers were ignored + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); + EXPECT_EQ(session_.GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseMissingMethod) { + // Promise with a missing method + push_promise_.erase(":method"); + + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_INVALID_PROMISE_METHOD)); + ReceivePromise(promise_id_); + + // Verify that the promise headers were ignored + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); + EXPECT_EQ(session_.GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseInvalidUrl) { + // Remove required header field to make URL invalid + push_promise_.erase(":authority"); + + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_INVALID_PROMISE_URL)); + ReceivePromise(promise_id_); + + // Verify that the promise headers were ignored + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); + EXPECT_EQ(session_.GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseUnauthorizedUrl) { + session_.set_authorized(false); + + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_UNAUTHORIZED_PROMISE_URL)); + + ReceivePromise(promise_id_); + + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + ASSERT_EQ(promised, nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseMismatch) { + ReceivePromise(promise_id_); + + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + ASSERT_NE(promised, nullptr); + + // Need to send the promised response headers and initiate the + // rendezvous for secondary validation to proceed. + QuicSpdyClientStream* promise_stream = static_cast( + session_.GetOrCreateStream(promise_id_)); + auto headers = AsHeaderList(headers_); + promise_stream->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + + TestPushPromiseDelegate delegate(/*match=*/false); + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_PROMISE_VARY_MISMATCH)); + + promised->HandleClientRequest(client_request_, &delegate); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseVaryWaits) { + ReceivePromise(promise_id_); + + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + EXPECT_FALSE(promised->is_validating()); + ASSERT_NE(promised, nullptr); + + // Now initiate rendezvous. + TestPushPromiseDelegate delegate(/*match=*/true); + promised->HandleClientRequest(client_request_, &delegate); + EXPECT_TRUE(promised->is_validating()); + + // Promise is still there, waiting for response. + EXPECT_NE(session_.GetPromisedById(promise_id_), nullptr); + + // Send Response, should trigger promise validation and complete rendezvous + QuicSpdyClientStream* promise_stream = static_cast( + session_.GetOrCreateStream(promise_id_)); + ASSERT_NE(promise_stream, nullptr); + auto headers = AsHeaderList(headers_); + promise_stream->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + + // Promise is gone + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseVaryNoWait) { + ReceivePromise(promise_id_); + + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + ASSERT_NE(promised, nullptr); + + QuicSpdyClientStream* promise_stream = static_cast( + session_.GetOrCreateStream(promise_id_)); + ASSERT_NE(promise_stream, nullptr); + + // Send Response, should trigger promise validation and complete rendezvous + auto headers = AsHeaderList(headers_); + promise_stream->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + + // Now initiate rendezvous. + TestPushPromiseDelegate delegate(/*match=*/true); + promised->HandleClientRequest(client_request_, &delegate); + + // Promise is gone + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); + // Have a push stream + EXPECT_TRUE(delegate.rendezvous_fired()); + + EXPECT_NE(delegate.rendezvous_stream(), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseWaitCancels) { + ReceivePromise(promise_id_); + + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + ASSERT_NE(promised, nullptr); + + // Now initiate rendezvous. + TestPushPromiseDelegate delegate(/*match=*/true); + promised->HandleClientRequest(client_request_, &delegate); + + // Promise is still there, waiting for response. + EXPECT_NE(session_.GetPromisedById(promise_id_), nullptr); + + // Create response stream, but no data yet. + session_.GetOrCreateStream(promise_id_); + + // Cancel the promised stream. + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, OnStreamReset(promise_id_, QUIC_STREAM_CANCELLED)); + promised->Cancel(); + + // Promise is gone + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); +} + +TEST_F(QuicClientPromisedInfoTest, PushPromiseDataClosed) { + ReceivePromise(promise_id_); + + QuicClientPromisedInfo* promised = session_.GetPromisedById(promise_id_); + ASSERT_NE(promised, nullptr); + + QuicSpdyClientStream* promise_stream = static_cast( + session_.GetOrCreateStream(promise_id_)); + ASSERT_NE(promise_stream, nullptr); + + // Send response, rendezvous will be able to finish synchronously. + auto headers = AsHeaderList(headers_); + promise_stream->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(promise_id_, QUIC_STREAM_PEER_GOING_AWAY)); + session_.ResetStream(promise_id_, QUIC_STREAM_PEER_GOING_AWAY); + + // Now initiate rendezvous. + TestPushPromiseDelegate delegate(/*match=*/true); + EXPECT_EQ(promised->HandleClientRequest(client_request_, &delegate), + QUIC_FAILURE); + + // Got an indication of the stream failure, client should retry + // request. + EXPECT_FALSE(delegate.rendezvous_fired()); + EXPECT_EQ(delegate.rendezvous_stream(), nullptr); + + // Promise is gone + EXPECT_EQ(session_.GetPromisedById(promise_id_), nullptr); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_client_push_promise_index.cc b/quiche/quic/core/http/quic_client_push_promise_index.cc new file mode 100644 index 000000000000..00ac30a3f8b4 --- /dev/null +++ b/quiche/quic/core/http/quic_client_push_promise_index.cc @@ -0,0 +1,45 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_client_push_promise_index.h" + +#include + +#include "quiche/quic/core/http/quic_client_promised_info.h" +#include "quiche/quic/core/http/spdy_server_push_utils.h" + +namespace quic { + +QuicClientPushPromiseIndex::QuicClientPushPromiseIndex() {} + +QuicClientPushPromiseIndex::~QuicClientPushPromiseIndex() {} + +QuicClientPushPromiseIndex::TryHandle::~TryHandle() {} + +QuicClientPromisedInfo* QuicClientPushPromiseIndex::GetPromised( + const std::string& url) { + auto it = promised_by_url_.find(url); + if (it == promised_by_url_.end()) { + return nullptr; + } + return it->second; +} + +QuicAsyncStatus QuicClientPushPromiseIndex::Try( + const spdy::Http2HeaderBlock& request, + QuicClientPushPromiseIndex::Delegate* delegate, TryHandle** handle) { + std::string url(SpdyServerPushUtils::GetPromisedUrlFromHeaders(request)); + auto it = promised_by_url_.find(url); + if (it != promised_by_url_.end()) { + QuicClientPromisedInfo* promised = it->second; + QuicAsyncStatus rv = promised->HandleClientRequest(request, delegate); + if (rv == QUIC_PENDING) { + *handle = promised; + } + return rv; + } + return QUIC_FAILURE; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_client_push_promise_index.h b/quiche/quic/core/http/quic_client_push_promise_index.h new file mode 100644 index 000000000000..c00a17eac540 --- /dev/null +++ b/quiche/quic/core/http/quic_client_push_promise_index.h @@ -0,0 +1,99 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_CLIENT_PUSH_PROMISE_INDEX_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_CLIENT_PUSH_PROMISE_INDEX_H_ + +#include + +#include "quiche/quic/core/http/quic_spdy_client_session_base.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// QuicClientPushPromiseIndex is the interface to support rendezvous +// between client requests and resources delivered via server push. +// The same index can be shared across multiple sessions (e.g. for the +// same browser users profile), since cross-origin pushes are allowed +// (subject to authority constraints). + +class QUIC_EXPORT_PRIVATE QuicClientPushPromiseIndex { + public: + // Delegate is used to complete the rendezvous that began with + // |Try()|. + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + // The primary lookup matched request with push promise by URL. A + // secondary match is necessary to ensure Vary (RFC 2616, 14.14) + // is honored. If Vary is not present, return true. If Vary is + // present, return whether designated header fields of + // |promise_request| and |client_request| match. + virtual bool CheckVary(const spdy::Http2HeaderBlock& client_request, + const spdy::Http2HeaderBlock& promise_request, + const spdy::Http2HeaderBlock& promise_response) = 0; + + // On rendezvous success, provides the promised |stream|. Callee + // does not inherit ownership of |stream|. On rendezvous failure, + // |stream| is |nullptr| and the client should retry the request. + // Rendezvous can fail due to promise validation failure or RST on + // promised stream. |url| will have been removed from the index + // before |OnRendezvousResult()| is invoked, so a recursive call to + // |Try()| will return |QUIC_FAILURE|, which may be convenient for + // retry purposes. + virtual void OnRendezvousResult(QuicSpdyStream* stream) = 0; + }; + + class QUIC_EXPORT_PRIVATE TryHandle { + public: + // Cancel the request. + virtual void Cancel() = 0; + + protected: + TryHandle() {} + TryHandle(const TryHandle&) = delete; + TryHandle& operator=(const TryHandle&) = delete; + ~TryHandle(); + }; + + QuicClientPushPromiseIndex(); + QuicClientPushPromiseIndex(const QuicClientPushPromiseIndex&) = delete; + QuicClientPushPromiseIndex& operator=(const QuicClientPushPromiseIndex&) = + delete; + virtual ~QuicClientPushPromiseIndex(); + + // Called by client code, used to enforce affinity between requests + // for promised streams and the session the promise came from. + QuicClientPromisedInfo* GetPromised(const std::string& url); + + // Called by client code, to initiate rendezvous between a request + // and a server push stream. If |request|'s url is in the index, + // rendezvous will be attempted and may complete immediately or + // asynchronously. If the matching promise and response headers + // have already arrived, the delegate's methods will fire + // recursively from within |Try()|. Returns |QUIC_SUCCESS| if the + // rendezvous was a success. Returns |QUIC_FAILURE| if there was no + // matching promise, or if there was but the rendezvous has failed. + // Returns QUIC_PENDING if a matching promise was found, but the + // rendezvous needs to complete asynchronously because the promised + // response headers are not yet available. If result is + // QUIC_PENDING, then |*handle| will set so that the caller may + // cancel the request if need be. The caller does not inherit + // ownership of |*handle|, and it ceases to be valid if the caller + // invokes |handle->Cancel()| or if |delegate->OnReponse()| fires. + QuicAsyncStatus Try(const spdy::Http2HeaderBlock& request, Delegate* delegate, + TryHandle** handle); + + QuicPromisedByUrlMap* promised_by_url() { return &promised_by_url_; } + + private: + QuicPromisedByUrlMap promised_by_url_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_CLIENT_PUSH_PROMISE_INDEX_H_ diff --git a/quiche/quic/core/http/quic_client_push_promise_index_test.cc b/quiche/quic/core/http/quic_client_push_promise_index_test.cc new file mode 100644 index 000000000000..58104fd2bd8b --- /dev/null +++ b/quiche/quic/core/http/quic_client_push_promise_index_test.cc @@ -0,0 +1,109 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_client_push_promise_index.h" + +#include + +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/http/spdy_server_push_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_quic_client_promised_info.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class MockQuicSpdyClientSession : public QuicSpdyClientSession { + public: + explicit MockQuicSpdyClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicClientPushPromiseIndex* push_promise_index) + : QuicSpdyClientSession(DefaultQuicConfig(), supported_versions, + connection, + QuicServerId("example.com", 443, false), + &crypto_config_, push_promise_index), + crypto_config_(crypto_test_utils::ProofVerifierForTesting()) {} + MockQuicSpdyClientSession(const MockQuicSpdyClientSession&) = delete; + MockQuicSpdyClientSession& operator=(const MockQuicSpdyClientSession&) = + delete; + ~MockQuicSpdyClientSession() override {} + + private: + QuicCryptoClientConfig crypto_config_; +}; + +class QuicClientPushPromiseIndexTest : public QuicTest { + public: + QuicClientPushPromiseIndexTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_CLIENT)), + session_(connection_->supported_versions(), connection_, &index_), + promised_(&session_, + GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), 0), + url_) { + request_[":path"] = "/bar"; + request_[":authority"] = "www.google.com"; + request_[":method"] = "GET"; + request_[":scheme"] = "https"; + url_ = SpdyServerPushUtils::GetPromisedUrlFromHeaders(request_); + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + MockQuicSpdyClientSession session_; + QuicClientPushPromiseIndex index_; + spdy::Http2HeaderBlock request_; + std::string url_; + MockQuicClientPromisedInfo promised_; + QuicClientPushPromiseIndex::TryHandle* handle_; +}; + +TEST_F(QuicClientPushPromiseIndexTest, TryRequestSuccess) { + (*index_.promised_by_url())[url_] = &promised_; + EXPECT_CALL(promised_, HandleClientRequest(_, _)) + .WillOnce(Return(QUIC_SUCCESS)); + EXPECT_EQ(index_.Try(request_, nullptr, &handle_), QUIC_SUCCESS); +} + +TEST_F(QuicClientPushPromiseIndexTest, TryRequestPending) { + (*index_.promised_by_url())[url_] = &promised_; + EXPECT_CALL(promised_, HandleClientRequest(_, _)) + .WillOnce(Return(QUIC_PENDING)); + EXPECT_EQ(index_.Try(request_, nullptr, &handle_), QUIC_PENDING); +} + +TEST_F(QuicClientPushPromiseIndexTest, TryRequestFailure) { + (*index_.promised_by_url())[url_] = &promised_; + EXPECT_CALL(promised_, HandleClientRequest(_, _)) + .WillOnce(Return(QUIC_FAILURE)); + EXPECT_EQ(index_.Try(request_, nullptr, &handle_), QUIC_FAILURE); +} + +TEST_F(QuicClientPushPromiseIndexTest, TryNoPromise) { + EXPECT_EQ(index_.Try(request_, nullptr, &handle_), QUIC_FAILURE); +} + +TEST_F(QuicClientPushPromiseIndexTest, GetNoPromise) { + EXPECT_EQ(index_.GetPromised(url_), nullptr); +} + +TEST_F(QuicClientPushPromiseIndexTest, GetPromise) { + (*index_.promised_by_url())[url_] = &promised_; + EXPECT_EQ(index_.GetPromised(url_), &promised_); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_header_list.cc b/quiche/quic/core/http/quic_header_list.cc new file mode 100644 index 000000000000..771e61752281 --- /dev/null +++ b/quiche/quic/core/http/quic_header_list.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_header_list.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +QuicHeaderList::QuicHeaderList() + : max_header_list_size_(std::numeric_limits::max()), + current_header_list_size_(0), + uncompressed_header_bytes_(0), + compressed_header_bytes_(0) {} + +QuicHeaderList::QuicHeaderList(QuicHeaderList&& other) = default; + +QuicHeaderList::QuicHeaderList(const QuicHeaderList& other) = default; + +QuicHeaderList& QuicHeaderList::operator=(const QuicHeaderList& other) = + default; + +QuicHeaderList& QuicHeaderList::operator=(QuicHeaderList&& other) = default; + +QuicHeaderList::~QuicHeaderList() {} + +void QuicHeaderList::OnHeaderBlockStart() { + QUIC_BUG_IF(quic_bug_12518_1, current_header_list_size_ != 0) + << "OnHeaderBlockStart called more than once!"; +} + +void QuicHeaderList::OnHeader(absl::string_view name, absl::string_view value) { + // Avoid infinite buffering of headers. No longer store headers + // once the current headers are over the limit. + if (current_header_list_size_ < max_header_list_size_) { + current_header_list_size_ += name.size(); + current_header_list_size_ += value.size(); + current_header_list_size_ += kQpackEntrySizeOverhead; + header_list_.emplace_back(std::string(name), std::string(value)); + } +} + +void QuicHeaderList::OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t compressed_header_bytes) { + uncompressed_header_bytes_ = uncompressed_header_bytes; + compressed_header_bytes_ = compressed_header_bytes; + if (current_header_list_size_ > max_header_list_size_) { + Clear(); + } +} + +void QuicHeaderList::Clear() { + header_list_.clear(); + current_header_list_size_ = 0; + uncompressed_header_bytes_ = 0; + compressed_header_bytes_ = 0; +} + +std::string QuicHeaderList::DebugString() const { + std::string s = "{ "; + for (const auto& p : *this) { + s.append(p.first + "=" + p.second + ", "); + } + s.append("}"); + return s; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_header_list.h b/quiche/quic/core/http/quic_header_list.h new file mode 100644 index 000000000000..ac25f89e008c --- /dev/null +++ b/quiche/quic/core/http/quic_header_list.h @@ -0,0 +1,88 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_HEADER_LIST_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_HEADER_LIST_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" + +namespace quic { + +// A simple class that accumulates header pairs +class QUIC_EXPORT_PRIVATE QuicHeaderList + : public spdy::SpdyHeadersHandlerInterface { + public: + using ListType = + quiche::QuicheCircularDeque>; + using value_type = ListType::value_type; + using const_iterator = ListType::const_iterator; + + QuicHeaderList(); + QuicHeaderList(QuicHeaderList&& other); + QuicHeaderList(const QuicHeaderList& other); + QuicHeaderList& operator=(QuicHeaderList&& other); + QuicHeaderList& operator=(const QuicHeaderList& other); + ~QuicHeaderList() override; + + // From SpdyHeadersHandlerInteface. + void OnHeaderBlockStart() override; + void OnHeader(absl::string_view name, absl::string_view value) override; + void OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t compressed_header_bytes) override; + + void Clear(); + + const_iterator begin() const { return header_list_.begin(); } + const_iterator end() const { return header_list_.end(); } + + bool empty() const { return header_list_.empty(); } + size_t uncompressed_header_bytes() const { + return uncompressed_header_bytes_; + } + size_t compressed_header_bytes() const { return compressed_header_bytes_; } + + // Deprecated. TODO(b/145909215): remove. + void set_max_header_list_size(size_t max_header_list_size) { + max_header_list_size_ = max_header_list_size; + } + + std::string DebugString() const; + + private: + quiche::QuicheCircularDeque> header_list_; + + // The limit on the size of the header list (defined by spec as name + value + + // overhead for each header field). Headers over this limit will not be + // buffered, and the list will be cleared upon OnHeaderBlockEnd. + size_t max_header_list_size_; + + // Defined per the spec as the size of all header fields with an additional + // overhead for each field. + size_t current_header_list_size_; + + // TODO(dahollings) Are these fields necessary? + size_t uncompressed_header_bytes_; + size_t compressed_header_bytes_; +}; + +inline bool operator==(const QuicHeaderList& l1, const QuicHeaderList& l2) { + auto pred = [](const std::pair& p1, + const std::pair& p2) { + return p1.first == p2.first && p1.second == p2.second; + }; + return std::equal(l1.begin(), l1.end(), l2.begin(), pred); +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_HEADER_LIST_H_ diff --git a/quiche/quic/core/http/quic_header_list_test.cc b/quiche/quic/core/http/quic_header_list_test.cc new file mode 100644 index 000000000000..573aae5ee66d --- /dev/null +++ b/quiche/quic/core/http/quic_header_list_test.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_header_list.h" + +#include + +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" + +using ::testing::ElementsAre; +using ::testing::Pair; + +namespace quic::test { + +class QuicHeaderListTest : public QuicTest {}; + +// This test verifies that QuicHeaderList accumulates header pairs in order. +TEST_F(QuicHeaderListTest, OnHeader) { + QuicHeaderList headers; + headers.OnHeader("foo", "bar"); + headers.OnHeader("april", "fools"); + headers.OnHeader("beep", ""); + + EXPECT_THAT(headers, ElementsAre(Pair("foo", "bar"), Pair("april", "fools"), + Pair("beep", ""))); +} + +TEST_F(QuicHeaderListTest, DebugString) { + QuicHeaderList headers; + headers.OnHeader("foo", "bar"); + headers.OnHeader("april", "fools"); + headers.OnHeader("beep", ""); + + EXPECT_EQ("{ foo=bar, april=fools, beep=, }", headers.DebugString()); +} + +TEST_F(QuicHeaderListTest, TooLarge) { + const size_t kMaxHeaderListSize = 256; + + QuicHeaderList headers; + headers.set_max_header_list_size(kMaxHeaderListSize); + std::string key = "key"; + std::string value(kMaxHeaderListSize, '1'); + // Send a header that exceeds max_header_list_size. + headers.OnHeader(key, value); + // Send a second header exceeding max_header_list_size. + headers.OnHeader(key + "2", value); + // We should not allocate more memory after exceeding max_header_list_size. + EXPECT_LT(headers.DebugString().size(), 2 * value.size()); + size_t total_bytes = 2 * (key.size() + value.size()) + 1; + headers.OnHeaderBlockEnd(total_bytes, total_bytes); + + EXPECT_TRUE(headers.empty()); + EXPECT_EQ("{ }", headers.DebugString()); +} + +TEST_F(QuicHeaderListTest, NotTooLarge) { + QuicHeaderList headers; + headers.set_max_header_list_size(1 << 20); + std::string key = "key"; + std::string value(1 << 18, '1'); + headers.OnHeader(key, value); + size_t total_bytes = key.size() + value.size(); + headers.OnHeaderBlockEnd(total_bytes, total_bytes); + EXPECT_FALSE(headers.empty()); +} + +// This test verifies that QuicHeaderList is copyable and assignable. +TEST_F(QuicHeaderListTest, IsCopyableAndAssignable) { + QuicHeaderList headers; + headers.OnHeader("foo", "bar"); + headers.OnHeader("april", "fools"); + headers.OnHeader("beep", ""); + + QuicHeaderList headers2(headers); + QuicHeaderList headers3 = headers; + + EXPECT_THAT(headers2, ElementsAre(Pair("foo", "bar"), Pair("april", "fools"), + Pair("beep", ""))); + EXPECT_THAT(headers3, ElementsAre(Pair("foo", "bar"), Pair("april", "fools"), + Pair("beep", ""))); +} + +} // namespace quic::test diff --git a/quiche/quic/core/http/quic_headers_stream.cc b/quiche/quic/core/http/quic_headers_stream.cc new file mode 100644 index 000000000000..9a53e66ca088 --- /dev/null +++ b/quiche/quic/core/http/quic_headers_stream.cc @@ -0,0 +1,163 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_headers_stream.h" + +#include "absl/base/macros.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +QuicHeadersStream::CompressedHeaderInfo::CompressedHeaderInfo( + QuicStreamOffset headers_stream_offset, QuicStreamOffset full_length, + quiche::QuicheReferenceCountedPointer + ack_listener) + : headers_stream_offset(headers_stream_offset), + full_length(full_length), + unacked_length(full_length), + ack_listener(std::move(ack_listener)) {} + +QuicHeadersStream::CompressedHeaderInfo::CompressedHeaderInfo( + const CompressedHeaderInfo& other) = default; + +QuicHeadersStream::CompressedHeaderInfo::~CompressedHeaderInfo() {} + +QuicHeadersStream::QuicHeadersStream(QuicSpdySession* session) + : QuicStream(QuicUtils::GetHeadersStreamId(session->transport_version()), + session, + /*is_static=*/true, BIDIRECTIONAL), + spdy_session_(session) { + // The headers stream is exempt from connection level flow control. + DisableConnectionFlowControlForThisStream(); +} + +QuicHeadersStream::~QuicHeadersStream() {} + +void QuicHeadersStream::OnDataAvailable() { + struct iovec iov; + while (sequencer()->GetReadableRegion(&iov)) { + if (spdy_session_->ProcessHeaderData(iov) != iov.iov_len) { + // Error processing data. + return; + } + sequencer()->MarkConsumed(iov.iov_len); + MaybeReleaseSequencerBuffer(); + } +} + +void QuicHeadersStream::MaybeReleaseSequencerBuffer() { + if (spdy_session_->ShouldReleaseHeadersStreamSequencerBuffer()) { + sequencer()->ReleaseBufferIfEmpty(); + } +} + +bool QuicHeadersStream::OnStreamFrameAcked(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_acked, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp, + QuicByteCount* newly_acked_length) { + QuicIntervalSet newly_acked(offset, offset + data_length); + newly_acked.Difference(bytes_acked()); + for (const auto& acked : newly_acked) { + QuicStreamOffset acked_offset = acked.min(); + QuicByteCount acked_length = acked.max() - acked.min(); + for (CompressedHeaderInfo& header : unacked_headers_) { + if (acked_offset < header.headers_stream_offset) { + // This header frame offset belongs to headers with smaller offset, stop + // processing. + break; + } + + if (acked_offset >= header.headers_stream_offset + header.full_length) { + // This header frame belongs to headers with larger offset. + continue; + } + + QuicByteCount header_offset = acked_offset - header.headers_stream_offset; + QuicByteCount header_length = + std::min(acked_length, header.full_length - header_offset); + + if (header.unacked_length < header_length) { + QUIC_BUG(quic_bug_10416_1) + << "Unsent stream data is acked. unacked_length: " + << header.unacked_length << " acked_length: " << header_length; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Unsent stream data is acked"); + return false; + } + if (header.ack_listener != nullptr && header_length > 0) { + header.ack_listener->OnPacketAcked(header_length, ack_delay_time); + } + header.unacked_length -= header_length; + acked_offset += header_length; + acked_length -= header_length; + } + } + // Remove headers which are fully acked. Please note, header frames can be + // acked out of order, but unacked_headers_ is cleaned up in order. + while (!unacked_headers_.empty() && + unacked_headers_.front().unacked_length == 0) { + unacked_headers_.pop_front(); + } + return QuicStream::OnStreamFrameAcked(offset, data_length, fin_acked, + ack_delay_time, receive_timestamp, + newly_acked_length); +} + +void QuicHeadersStream::OnStreamFrameRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length, + bool /*fin_retransmitted*/) { + QuicStream::OnStreamFrameRetransmitted(offset, data_length, false); + for (CompressedHeaderInfo& header : unacked_headers_) { + if (offset < header.headers_stream_offset) { + // This header frame offset belongs to headers with smaller offset, stop + // processing. + break; + } + + if (offset >= header.headers_stream_offset + header.full_length) { + // This header frame belongs to headers with larger offset. + continue; + } + + QuicByteCount header_offset = offset - header.headers_stream_offset; + QuicByteCount retransmitted_length = + std::min(data_length, header.full_length - header_offset); + if (header.ack_listener != nullptr && retransmitted_length > 0) { + header.ack_listener->OnPacketRetransmitted(retransmitted_length); + } + offset += retransmitted_length; + data_length -= retransmitted_length; + } +} + +void QuicHeadersStream::OnDataBuffered( + QuicStreamOffset offset, QuicByteCount data_length, + const quiche::QuicheReferenceCountedPointer& + ack_listener) { + // Populate unacked_headers_. + if (!unacked_headers_.empty() && + (offset == unacked_headers_.back().headers_stream_offset + + unacked_headers_.back().full_length) && + ack_listener == unacked_headers_.back().ack_listener) { + // Try to combine with latest inserted entry if they belong to the same + // header (i.e., having contiguous offset and the same ack listener). + unacked_headers_.back().full_length += data_length; + unacked_headers_.back().unacked_length += data_length; + } else { + unacked_headers_.push_back( + CompressedHeaderInfo(offset, data_length, ack_listener)); + } +} + +void QuicHeadersStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { + stream_delegate()->OnStreamError(QUIC_INVALID_STREAM_ID, + "Attempt to reset headers stream"); +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_headers_stream.h b/quiche/quic/core/http/quic_headers_stream.h new file mode 100644 index 000000000000..ba3a27b38507 --- /dev/null +++ b/quiche/quic/core/http/quic_headers_stream.h @@ -0,0 +1,96 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_HEADERS_STREAM_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_HEADERS_STREAM_H_ + +#include +#include + +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +class QuicSpdySession; + +namespace test { +class QuicHeadersStreamPeer; +} // namespace test + +// Headers in QUIC are sent as HTTP/2 HEADERS or PUSH_PROMISE frames over a +// reserved stream with the id 3. Each endpoint (client and server) will +// allocate an instance of QuicHeadersStream to send and receive headers. +class QUIC_EXPORT_PRIVATE QuicHeadersStream : public QuicStream { + public: + explicit QuicHeadersStream(QuicSpdySession* session); + QuicHeadersStream(const QuicHeadersStream&) = delete; + QuicHeadersStream& operator=(const QuicHeadersStream&) = delete; + ~QuicHeadersStream() override; + + // QuicStream implementation + void OnDataAvailable() override; + + // Release underlying buffer if allowed. + void MaybeReleaseSequencerBuffer(); + + bool OnStreamFrameAcked(QuicStreamOffset offset, QuicByteCount data_length, + bool fin_acked, QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp, + QuicByteCount* newly_acked_length) override; + + void OnStreamFrameRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_retransmitted) override; + + void OnStreamReset(const QuicRstStreamFrame& frame) override; + + private: + friend class test::QuicHeadersStreamPeer; + + // CompressedHeaderInfo includes simple information of a header, including + // offset in headers stream, unacked length and ack listener of this header. + struct QUIC_EXPORT_PRIVATE CompressedHeaderInfo { + CompressedHeaderInfo( + QuicStreamOffset headers_stream_offset, QuicStreamOffset full_length, + quiche::QuicheReferenceCountedPointer + ack_listener); + CompressedHeaderInfo(const CompressedHeaderInfo& other); + ~CompressedHeaderInfo(); + + // Offset the header was sent on the headers stream. + QuicStreamOffset headers_stream_offset; + // The full length of the header. + QuicByteCount full_length; + // The remaining bytes to be acked. + QuicByteCount unacked_length; + // Ack listener of this header, and it is notified once any of the bytes has + // been acked or retransmitted. + quiche::QuicheReferenceCountedPointer + ack_listener; + }; + + // Returns true if the session is still connected. + bool IsConnected(); + + // Override to store mapping from offset, length to ack_listener. This + // ack_listener is notified once data within [offset, offset + length] is + // acked or retransmitted. + void OnDataBuffered( + QuicStreamOffset offset, QuicByteCount data_length, + const quiche::QuicheReferenceCountedPointer& + ack_listener) override; + + QuicSpdySession* spdy_session_; + + // Headers that have not been fully acked. + quiche::QuicheCircularDeque unacked_headers_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_HEADERS_STREAM_H_ diff --git a/quiche/quic/core/http/quic_headers_stream_test.cc b/quiche/quic/core/http/quic_headers_stream_test.cc new file mode 100644 index 000000000000..f4c48c0669ae --- /dev/null +++ b/quiche/quic/core/http/quic_headers_stream_test.cc @@ -0,0 +1,936 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_headers_stream.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/recording_headers_handler.h" +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +using spdy::ERROR_CODE_PROTOCOL_ERROR; +using spdy::Http2HeaderBlock; +using spdy::RecordingHeadersHandler; +using spdy::SETTINGS_ENABLE_PUSH; +using spdy::SETTINGS_HEADER_TABLE_SIZE; +using spdy::SETTINGS_INITIAL_WINDOW_SIZE; +using spdy::SETTINGS_MAX_CONCURRENT_STREAMS; +using spdy::SETTINGS_MAX_FRAME_SIZE; +using spdy::Spdy3PriorityToHttp2Weight; +using spdy::SpdyAltSvcWireFormat; +using spdy::SpdyDataIR; +using spdy::SpdyErrorCode; +using spdy::SpdyFramer; +using spdy::SpdyFramerVisitorInterface; +using spdy::SpdyGoAwayIR; +using spdy::SpdyHeadersHandlerInterface; +using spdy::SpdyHeadersIR; +using spdy::SpdyPingId; +using spdy::SpdyPingIR; +using spdy::SpdyPriority; +using spdy::SpdyPriorityIR; +using spdy::SpdyPushPromiseIR; +using spdy::SpdyRstStreamIR; +using spdy::SpdySerializedFrame; +using spdy::SpdySettingsId; +using spdy::SpdySettingsIR; +using spdy::SpdyStreamId; +using spdy::SpdyWindowUpdateIR; +using testing::_; +using testing::AnyNumber; +using testing::AtLeast; +using testing::InSequence; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; +using testing::WithArgs; + +namespace quic { +namespace test { +namespace { + +class MockVisitor : public SpdyFramerVisitorInterface { + public: + MOCK_METHOD(void, OnError, + (http2::Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error), + (override)); + MOCK_METHOD(void, OnDataFrameHeader, + (SpdyStreamId stream_id, size_t length, bool fin), (override)); + MOCK_METHOD(void, OnStreamFrameData, + (SpdyStreamId stream_id, const char*, size_t len), (override)); + MOCK_METHOD(void, OnStreamEnd, (SpdyStreamId stream_id), (override)); + MOCK_METHOD(void, OnStreamPadding, (SpdyStreamId stream_id, size_t len), + (override)); + MOCK_METHOD(SpdyHeadersHandlerInterface*, OnHeaderFrameStart, + (SpdyStreamId stream_id), (override)); + MOCK_METHOD(void, OnHeaderFrameEnd, (SpdyStreamId stream_id), (override)); + MOCK_METHOD(void, OnRstStream, + (SpdyStreamId stream_id, SpdyErrorCode error_code), (override)); + MOCK_METHOD(void, OnSettings, (), (override)); + MOCK_METHOD(void, OnSetting, (SpdySettingsId id, uint32_t value), (override)); + MOCK_METHOD(void, OnSettingsAck, (), (override)); + MOCK_METHOD(void, OnSettingsEnd, (), (override)); + MOCK_METHOD(void, OnPing, (SpdyPingId unique_id, bool is_ack), (override)); + MOCK_METHOD(void, OnGoAway, + (SpdyStreamId last_accepted_stream_id, SpdyErrorCode error_code), + (override)); + MOCK_METHOD(void, OnHeaders, + (SpdyStreamId stream_id, size_t payload_length, bool has_priority, + int weight, SpdyStreamId parent_stream_id, bool exclusive, + bool fin, bool end), + (override)); + MOCK_METHOD(void, OnWindowUpdate, + (SpdyStreamId stream_id, int delta_window_size), (override)); + MOCK_METHOD(void, OnPushPromise, + (SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + bool end), + (override)); + MOCK_METHOD(void, OnContinuation, + (SpdyStreamId stream_id, size_t payload_size, bool end), + (override)); + MOCK_METHOD( + void, OnAltSvc, + (SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector), + (override)); + MOCK_METHOD(void, OnPriority, + (SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive), + (override)); + MOCK_METHOD(void, OnPriorityUpdate, + (SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value), + (override)); + MOCK_METHOD(bool, OnUnknownFrame, + (SpdyStreamId stream_id, uint8_t frame_type), (override)); + MOCK_METHOD(void, OnUnknownFrameStart, + (SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags), + (override)); + MOCK_METHOD(void, OnUnknownFramePayload, + (SpdyStreamId stream_id, absl::string_view payload), (override)); +}; + +struct TestParams { + TestParams(const ParsedQuicVersion& version, Perspective perspective) + : version(version), perspective(perspective) { + QUIC_LOG(INFO) << "TestParams: " << *this; + } + + TestParams(const TestParams& other) + : version(other.version), perspective(other.perspective) {} + + friend std::ostream& operator<<(std::ostream& os, const TestParams& tp) { + os << "{ version: " << ParsedQuicVersionToString(tp.version) + << ", perspective: " + << (tp.perspective == Perspective::IS_CLIENT ? "client" : "server") + << "}"; + return os; + } + + ParsedQuicVersion version; + Perspective perspective; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& tp) { + return absl::StrCat( + ParsedQuicVersionToString(tp.version), "_", + (tp.perspective == Perspective::IS_CLIENT ? "client" : "server")); +} + +std::vector GetTestParams() { + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (size_t i = 0; i < all_supported_versions.size(); ++i) { + if (VersionUsesHttp3(all_supported_versions[i].transport_version)) { + continue; + } + for (Perspective p : {Perspective::IS_SERVER, Perspective::IS_CLIENT}) { + params.emplace_back(all_supported_versions[i], p); + } + } + return params; +} + +class QuicHeadersStreamTest : public QuicTestWithParam { + public: + QuicHeadersStreamTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective(), GetVersion())), + session_(connection_), + body_("hello world"), + stream_frame_( + QuicUtils::GetHeadersStreamId(connection_->transport_version()), + /*fin=*/false, + /*offset=*/0, ""), + next_promised_stream_id_(2) { + QuicSpdySessionPeer::SetMaxInboundHeaderListSize(&session_, 256 * 1024); + EXPECT_CALL(session_, OnCongestionWindowChange(_)).Times(AnyNumber()); + session_.Initialize(); + connection_->SetEncrypter( + quic::ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + headers_stream_ = QuicSpdySessionPeer::GetHeadersStream(&session_); + headers_[":status"] = "200 Ok"; + headers_["content-length"] = "11"; + framer_ = std::unique_ptr( + new SpdyFramer(SpdyFramer::ENABLE_COMPRESSION)); + deframer_ = std::unique_ptr( + new http2::Http2DecoderAdapter()); + deframer_->set_visitor(&visitor_); + EXPECT_EQ(transport_version(), session_.transport_version()); + EXPECT_TRUE(headers_stream_ != nullptr); + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + client_id_1_ = GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 0); + client_id_2_ = GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 1); + client_id_3_ = GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 2); + next_stream_id_ = + QuicUtils::StreamIdDelta(connection_->transport_version()); + } + + QuicStreamId GetNthClientInitiatedId(int n) { + return GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), n); + } + + QuicConsumedData SaveIov(size_t write_length) { + char* buf = new char[write_length]; + QuicDataWriter writer(write_length, buf, quiche::NETWORK_BYTE_ORDER); + headers_stream_->WriteStreamData(headers_stream_->stream_bytes_written(), + write_length, &writer); + saved_data_.append(buf, write_length); + delete[] buf; + return QuicConsumedData(write_length, false); + } + + void SavePayload(const char* data, size_t len) { + saved_payloads_.append(data, len); + } + + bool SaveHeaderData(const char* data, int len) { + saved_header_data_.append(data, len); + return true; + } + + void SaveHeaderDataStringPiece(absl::string_view data) { + saved_header_data_.append(data.data(), data.length()); + } + + void SavePromiseHeaderList(QuicStreamId /* stream_id */, + QuicStreamId /* promised_stream_id */, size_t size, + const QuicHeaderList& header_list) { + SaveToHandler(size, header_list); + } + + void SaveHeaderList(QuicStreamId /* stream_id */, bool /* fin */, size_t size, + const QuicHeaderList& header_list) { + SaveToHandler(size, header_list); + } + + void SaveToHandler(size_t size, const QuicHeaderList& header_list) { + headers_handler_ = std::make_unique(); + headers_handler_->OnHeaderBlockStart(); + for (const auto& p : header_list) { + headers_handler_->OnHeader(p.first, p.second); + } + headers_handler_->OnHeaderBlockEnd(size, size); + } + + void WriteAndExpectRequestHeaders(QuicStreamId stream_id, bool fin, + SpdyPriority priority) { + WriteHeadersAndCheckData(stream_id, fin, priority, true /*is_request*/); + } + + void WriteAndExpectResponseHeaders(QuicStreamId stream_id, bool fin) { + WriteHeadersAndCheckData(stream_id, fin, 0, false /*is_request*/); + } + + void WriteHeadersAndCheckData(QuicStreamId stream_id, bool fin, + SpdyPriority priority, bool is_request) { + // Write the headers and capture the outgoing data + EXPECT_CALL(session_, WritevData(QuicUtils::GetHeadersStreamId( + connection_->transport_version()), + _, _, NO_FIN, _, _)) + .WillOnce(WithArgs<1>(Invoke(this, &QuicHeadersStreamTest::SaveIov))); + QuicSpdySessionPeer::WriteHeadersOnHeadersStream( + &session_, stream_id, headers_.Clone(), fin, + spdy::SpdyStreamPrecedence(priority), nullptr); + + // Parse the outgoing data and check that it matches was was written. + if (is_request) { + EXPECT_CALL( + visitor_, + OnHeaders(stream_id, saved_data_.length() - spdy::kFrameHeaderSize, + kHasPriority, Spdy3PriorityToHttp2Weight(priority), + /*parent_stream_id=*/0, + /*exclusive=*/false, fin, kFrameComplete)); + } else { + EXPECT_CALL( + visitor_, + OnHeaders(stream_id, saved_data_.length() - spdy::kFrameHeaderSize, + !kHasPriority, + /*weight=*/0, + /*parent_stream_id=*/0, + /*exclusive=*/false, fin, kFrameComplete)); + } + headers_handler_ = std::make_unique(); + EXPECT_CALL(visitor_, OnHeaderFrameStart(stream_id)) + .WillOnce(Return(headers_handler_.get())); + EXPECT_CALL(visitor_, OnHeaderFrameEnd(stream_id)).Times(1); + if (fin) { + EXPECT_CALL(visitor_, OnStreamEnd(stream_id)); + } + deframer_->ProcessInput(saved_data_.data(), saved_data_.length()); + EXPECT_FALSE(deframer_->HasError()) + << http2::Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + + CheckHeaders(); + saved_data_.clear(); + } + + void CheckHeaders() { + ASSERT_TRUE(headers_handler_); + EXPECT_EQ(headers_, headers_handler_->decoded_block()); + headers_handler_.reset(); + } + + Perspective perspective() const { return GetParam().perspective; } + + QuicTransportVersion transport_version() const { + return GetParam().version.transport_version; + } + + ParsedQuicVersionVector GetVersion() { + ParsedQuicVersionVector versions; + versions.push_back(GetParam().version); + return versions; + } + + void TearDownLocalConnectionState() { + QuicConnectionPeer::TearDownLocalConnectionState(connection_); + } + + QuicStreamId NextPromisedStreamId() { + return next_promised_stream_id_ += next_stream_id_; + } + + static constexpr bool kFrameComplete = true; + static constexpr bool kHasPriority = true; + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + StrictMock session_; + QuicHeadersStream* headers_stream_; + Http2HeaderBlock headers_; + std::unique_ptr headers_handler_; + std::string body_; + std::string saved_data_; + std::string saved_header_data_; + std::string saved_payloads_; + std::unique_ptr framer_; + std::unique_ptr deframer_; + StrictMock visitor_; + QuicStreamFrame stream_frame_; + QuicStreamId next_promised_stream_id_; + QuicStreamId client_id_1_; + QuicStreamId client_id_2_; + QuicStreamId client_id_3_; + QuicStreamId next_stream_id_; +}; + +// Run all tests with each version and perspective (client or server). +INSTANTIATE_TEST_SUITE_P(Tests, QuicHeadersStreamTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicHeadersStreamTest, StreamId) { + EXPECT_EQ(QuicUtils::GetHeadersStreamId(connection_->transport_version()), + headers_stream_->id()); +} + +TEST_P(QuicHeadersStreamTest, WriteHeaders) { + for (QuicStreamId stream_id = client_id_1_; stream_id < client_id_3_; + stream_id += next_stream_id_) { + for (bool fin : {false, true}) { + if (perspective() == Perspective::IS_SERVER) { + WriteAndExpectResponseHeaders(stream_id, fin); + } else { + for (SpdyPriority priority = 0; priority < 7; ++priority) { + // TODO(rch): implement priorities correctly. + WriteAndExpectRequestHeaders(stream_id, fin, 0); + } + } + } + } +} + +TEST_P(QuicHeadersStreamTest, WritePushPromises) { + for (QuicStreamId stream_id = client_id_1_; stream_id < client_id_3_; + stream_id += next_stream_id_) { + QuicStreamId promised_stream_id = NextPromisedStreamId(); + if (perspective() == Perspective::IS_SERVER) { + // Write the headers and capture the outgoing data + EXPECT_CALL(session_, WritevData(QuicUtils::GetHeadersStreamId( + connection_->transport_version()), + _, _, NO_FIN, _, _)) + .WillOnce(WithArgs<1>(Invoke(this, &QuicHeadersStreamTest::SaveIov))); + session_.WritePushPromise(stream_id, promised_stream_id, + headers_.Clone()); + + // Parse the outgoing data and check that it matches was was written. + EXPECT_CALL(visitor_, + OnPushPromise(stream_id, promised_stream_id, kFrameComplete)); + headers_handler_ = std::make_unique(); + EXPECT_CALL(visitor_, OnHeaderFrameStart(stream_id)) + .WillOnce(Return(headers_handler_.get())); + EXPECT_CALL(visitor_, OnHeaderFrameEnd(stream_id)).Times(1); + deframer_->ProcessInput(saved_data_.data(), saved_data_.length()); + EXPECT_FALSE(deframer_->HasError()) + << http2::Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + CheckHeaders(); + saved_data_.clear(); + } else { + EXPECT_QUIC_BUG(session_.WritePushPromise(stream_id, promised_stream_id, + headers_.Clone()), + "Client shouldn't send PUSH_PROMISE"); + } + } +} + +TEST_P(QuicHeadersStreamTest, ProcessRawData) { + for (QuicStreamId stream_id = client_id_1_; stream_id < client_id_3_; + stream_id += next_stream_id_) { + for (bool fin : {false, true}) { + for (SpdyPriority priority = 0; priority < 7; ++priority) { + // Replace with "WriteHeadersAndSaveData" + SpdySerializedFrame frame; + if (perspective() == Perspective::IS_SERVER) { + SpdyHeadersIR headers_frame(stream_id, headers_.Clone()); + headers_frame.set_fin(fin); + headers_frame.set_has_priority(true); + headers_frame.set_weight(Spdy3PriorityToHttp2Weight(0)); + frame = framer_->SerializeFrame(headers_frame); + EXPECT_CALL(session_, OnStreamHeadersPriority( + stream_id, spdy::SpdyStreamPrecedence(0))); + } else { + SpdyHeadersIR headers_frame(stream_id, headers_.Clone()); + headers_frame.set_fin(fin); + frame = framer_->SerializeFrame(headers_frame); + } + EXPECT_CALL(session_, + OnStreamHeaderList(stream_id, fin, frame.size(), _)) + .WillOnce(Invoke(this, &QuicHeadersStreamTest::SaveHeaderList)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); + stream_frame_.offset += frame.size(); + CheckHeaders(); + } + } + } +} + +TEST_P(QuicHeadersStreamTest, ProcessPushPromise) { + if (perspective() == Perspective::IS_SERVER) { + return; + } + for (QuicStreamId stream_id = client_id_1_; stream_id < client_id_3_; + stream_id += next_stream_id_) { + QuicStreamId promised_stream_id = NextPromisedStreamId(); + SpdyPushPromiseIR push_promise(stream_id, promised_stream_id, + headers_.Clone()); + SpdySerializedFrame frame(framer_->SerializeFrame(push_promise)); + bool connection_closed = false; + if (perspective() == Perspective::IS_SERVER) { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "PUSH_PROMISE not supported.", _)) + .WillRepeatedly(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + } else { + ON_CALL(*connection_, CloseConnection(_, _, _)) + .WillByDefault(testing::Assign(&connection_closed, true)); + EXPECT_CALL(session_, OnPromiseHeaderList(stream_id, promised_stream_id, + frame.size(), _)) + .WillOnce( + Invoke(this, &QuicHeadersStreamTest::SavePromiseHeaderList)); + } + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); + if (perspective() == Perspective::IS_CLIENT) { + stream_frame_.offset += frame.size(); + // CheckHeaders crashes if the connection is closed so this ensures we + // fail the test instead of crashing. + ASSERT_FALSE(connection_closed); + CheckHeaders(); + } + } +} + +TEST_P(QuicHeadersStreamTest, ProcessPriorityFrame) { + QuicStreamId parent_stream_id = 0; + for (SpdyPriority priority = 0; priority < 7; ++priority) { + for (QuicStreamId stream_id = client_id_1_; stream_id < client_id_3_; + stream_id += next_stream_id_) { + int weight = Spdy3PriorityToHttp2Weight(priority); + SpdyPriorityIR priority_frame(stream_id, parent_stream_id, weight, true); + SpdySerializedFrame frame(framer_->SerializeFrame(priority_frame)); + parent_stream_id = stream_id; + if (perspective() == Perspective::IS_CLIENT) { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "Server must not send PRIORITY frames.", _)) + .WillRepeatedly(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + } else { + EXPECT_CALL( + session_, + OnPriorityFrame(stream_id, spdy::SpdyStreamPrecedence(priority))) + .Times(1); + } + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); + stream_frame_.offset += frame.size(); + } + } +} + +TEST_P(QuicHeadersStreamTest, ProcessPushPromiseDisabledSetting) { + if (perspective() != Perspective::IS_CLIENT) { + return; + } + + session_.OnConfigNegotiated(); + SpdySettingsIR data; + // Respect supported settings frames SETTINGS_ENABLE_PUSH. + data.AddSetting(SETTINGS_ENABLE_PUSH, 0); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "Unsupported field of HTTP/2 SETTINGS frame: 2", _)); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, ProcessLargeRawData) { + // We want to create a frame that is more than the SPDY Framer's max control + // frame size, which is 16K, but less than the HPACK decoders max decode + // buffer size, which is 32K. + headers_["key0"] = std::string(1 << 13, '.'); + headers_["key1"] = std::string(1 << 13, '.'); + headers_["key2"] = std::string(1 << 13, '.'); + for (QuicStreamId stream_id = client_id_1_; stream_id < client_id_3_; + stream_id += next_stream_id_) { + for (bool fin : {false, true}) { + for (SpdyPriority priority = 0; priority < 7; ++priority) { + // Replace with "WriteHeadersAndSaveData" + SpdySerializedFrame frame; + if (perspective() == Perspective::IS_SERVER) { + SpdyHeadersIR headers_frame(stream_id, headers_.Clone()); + headers_frame.set_fin(fin); + headers_frame.set_has_priority(true); + headers_frame.set_weight(Spdy3PriorityToHttp2Weight(0)); + frame = framer_->SerializeFrame(headers_frame); + EXPECT_CALL(session_, OnStreamHeadersPriority( + stream_id, spdy::SpdyStreamPrecedence(0))); + } else { + SpdyHeadersIR headers_frame(stream_id, headers_.Clone()); + headers_frame.set_fin(fin); + frame = framer_->SerializeFrame(headers_frame); + } + EXPECT_CALL(session_, + OnStreamHeaderList(stream_id, fin, frame.size(), _)) + .WillOnce(Invoke(this, &QuicHeadersStreamTest::SaveHeaderList)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); + stream_frame_.offset += frame.size(); + CheckHeaders(); + } + } + } +} + +TEST_P(QuicHeadersStreamTest, ProcessBadData) { + const char kBadData[] = "blah blah blah"; + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, _, _)) + .Times(::testing::AnyNumber()); + stream_frame_.data_buffer = kBadData; + stream_frame_.data_length = strlen(kBadData); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, ProcessSpdyDataFrame) { + SpdyDataIR data(/* stream_id = */ 2, "ping"); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "SPDY DATA frame received.", _)) + .WillOnce(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, ProcessSpdyRstStreamFrame) { + SpdyRstStreamIR data(/* stream_id = */ 2, ERROR_CODE_PROTOCOL_ERROR); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "SPDY RST_STREAM frame received.", _)) + .WillOnce(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, RespectHttp2SettingsFrameSupportedFields) { + const uint32_t kTestHeaderTableSize = 1000; + SpdySettingsIR data; + // Respect supported settings frames SETTINGS_HEADER_TABLE_SIZE, + // SETTINGS_MAX_HEADER_LIST_SIZE. + data.AddSetting(SETTINGS_HEADER_TABLE_SIZE, kTestHeaderTableSize); + data.AddSetting(spdy::SETTINGS_MAX_HEADER_LIST_SIZE, 2000); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); + EXPECT_EQ(kTestHeaderTableSize, QuicSpdySessionPeer::GetSpdyFramer(&session_) + ->header_encoder_table_size()); +} + +// Regression test for b/208997000. +TEST_P(QuicHeadersStreamTest, LimitEncoderDynamicTableSize) { + const uint32_t kVeryLargeTableSizeLimit = 1024 * 1024 * 1024; + SpdySettingsIR data; + data.AddSetting(SETTINGS_HEADER_TABLE_SIZE, kVeryLargeTableSizeLimit); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); + EXPECT_EQ(16384u, QuicSpdySessionPeer::GetSpdyFramer(&session_) + ->header_encoder_table_size()); +} + +TEST_P(QuicHeadersStreamTest, RespectHttp2SettingsFrameUnsupportedFields) { + SpdySettingsIR data; + // Does not support SETTINGS_MAX_CONCURRENT_STREAMS, + // SETTINGS_INITIAL_WINDOW_SIZE, SETTINGS_ENABLE_PUSH and + // SETTINGS_MAX_FRAME_SIZE. + data.AddSetting(SETTINGS_MAX_CONCURRENT_STREAMS, 100); + data.AddSetting(SETTINGS_INITIAL_WINDOW_SIZE, 100); + data.AddSetting(SETTINGS_ENABLE_PUSH, 1); + data.AddSetting(SETTINGS_MAX_FRAME_SIZE, 1250); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Unsupported field of HTTP/2 SETTINGS frame: ", + SETTINGS_MAX_CONCURRENT_STREAMS), + _)); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Unsupported field of HTTP/2 SETTINGS frame: ", + SETTINGS_INITIAL_WINDOW_SIZE), + _)); + if (session_.perspective() == Perspective::IS_CLIENT) { + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Unsupported field of HTTP/2 SETTINGS frame: ", + SETTINGS_ENABLE_PUSH), + _)); + } + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Unsupported field of HTTP/2 SETTINGS frame: ", + SETTINGS_MAX_FRAME_SIZE), + _)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, ProcessSpdyPingFrame) { + SpdyPingIR data(1); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "SPDY PING frame received.", _)) + .WillOnce(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, ProcessSpdyGoAwayFrame) { + SpdyGoAwayIR data(/* last_good_stream_id = */ 1, ERROR_CODE_PROTOCOL_ERROR, + "go away"); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "SPDY GOAWAY frame received.", _)) + .WillOnce(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, ProcessSpdyWindowUpdateFrame) { + SpdyWindowUpdateIR data(/* stream_id = */ 1, /* delta = */ 1); + SpdySerializedFrame frame(framer_->SerializeFrame(data)); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "SPDY WINDOW_UPDATE frame received.", _)) + .WillOnce(InvokeWithoutArgs( + this, &QuicHeadersStreamTest::TearDownLocalConnectionState)); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + headers_stream_->OnStreamFrame(stream_frame_); +} + +TEST_P(QuicHeadersStreamTest, NoConnectionLevelFlowControl) { + EXPECT_FALSE(QuicStreamPeer::StreamContributesToConnectionFlowControl( + headers_stream_)); +} + +TEST_P(QuicHeadersStreamTest, AckSentData) { + EXPECT_CALL(session_, WritevData(QuicUtils::GetHeadersStreamId( + connection_->transport_version()), + _, _, NO_FIN, _, _)) + .WillRepeatedly(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + InSequence s; + quiche::QuicheReferenceCountedPointer ack_listener1( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener2( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener3( + new MockAckListener()); + + // Packet 1. + headers_stream_->WriteOrBufferData("Header5", false, ack_listener1); + headers_stream_->WriteOrBufferData("Header5", false, ack_listener1); + headers_stream_->WriteOrBufferData("Header7", false, ack_listener2); + + // Packet 2. + headers_stream_->WriteOrBufferData("Header9", false, ack_listener3); + headers_stream_->WriteOrBufferData("Header7", false, ack_listener2); + + // Packet 3. + headers_stream_->WriteOrBufferData("Header9", false, ack_listener3); + + // Packet 2 gets retransmitted. + EXPECT_CALL(*ack_listener3, OnPacketRetransmitted(7)).Times(1); + EXPECT_CALL(*ack_listener2, OnPacketRetransmitted(7)).Times(1); + headers_stream_->OnStreamFrameRetransmitted(21, 7, false); + headers_stream_->OnStreamFrameRetransmitted(28, 7, false); + + // Packets are acked in order: 2, 3, 1. + QuicByteCount newly_acked_length = 0; + EXPECT_CALL(*ack_listener3, OnPacketAcked(7, _)); + EXPECT_CALL(*ack_listener2, OnPacketAcked(7, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 21, 7, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(7u, newly_acked_length); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 28, 7, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(7u, newly_acked_length); + + EXPECT_CALL(*ack_listener3, OnPacketAcked(7, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 35, 7, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(7u, newly_acked_length); + + EXPECT_CALL(*ack_listener1, OnPacketAcked(7, _)); + EXPECT_CALL(*ack_listener1, OnPacketAcked(7, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 0, 7, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(7u, newly_acked_length); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 7, 7, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(7u, newly_acked_length); + // Unsent data is acked. + EXPECT_CALL(*ack_listener2, OnPacketAcked(7, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 14, 10, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(7u, newly_acked_length); +} + +TEST_P(QuicHeadersStreamTest, FrameContainsMultipleHeaders) { + // In this test, a stream frame can contain multiple headers. + EXPECT_CALL(session_, WritevData(QuicUtils::GetHeadersStreamId( + connection_->transport_version()), + _, _, NO_FIN, _, _)) + .WillRepeatedly(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + InSequence s; + quiche::QuicheReferenceCountedPointer ack_listener1( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener2( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener3( + new MockAckListener()); + + headers_stream_->WriteOrBufferData("Header5", false, ack_listener1); + headers_stream_->WriteOrBufferData("Header5", false, ack_listener1); + headers_stream_->WriteOrBufferData("Header7", false, ack_listener2); + headers_stream_->WriteOrBufferData("Header9", false, ack_listener3); + headers_stream_->WriteOrBufferData("Header7", false, ack_listener2); + headers_stream_->WriteOrBufferData("Header9", false, ack_listener3); + + // Frame 1 is retransmitted. + EXPECT_CALL(*ack_listener1, OnPacketRetransmitted(14)); + EXPECT_CALL(*ack_listener2, OnPacketRetransmitted(3)); + headers_stream_->OnStreamFrameRetransmitted(0, 17, false); + + // Frames are acked in order: 2, 3, 1. + QuicByteCount newly_acked_length = 0; + EXPECT_CALL(*ack_listener2, OnPacketAcked(4, _)); + EXPECT_CALL(*ack_listener3, OnPacketAcked(7, _)); + EXPECT_CALL(*ack_listener2, OnPacketAcked(2, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 17, 13, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(13u, newly_acked_length); + + EXPECT_CALL(*ack_listener2, OnPacketAcked(5, _)); + EXPECT_CALL(*ack_listener3, OnPacketAcked(7, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 30, 12, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(12u, newly_acked_length); + + EXPECT_CALL(*ack_listener1, OnPacketAcked(14, _)); + EXPECT_CALL(*ack_listener2, OnPacketAcked(3, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 0, 17, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(17u, newly_acked_length); +} + +TEST_P(QuicHeadersStreamTest, HeadersGetAckedMultipleTimes) { + EXPECT_CALL(session_, WritevData(QuicUtils::GetHeadersStreamId( + connection_->transport_version()), + _, _, NO_FIN, _, _)) + .WillRepeatedly(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + InSequence s; + quiche::QuicheReferenceCountedPointer ack_listener1( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener2( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener3( + new MockAckListener()); + + // Send [0, 42). + headers_stream_->WriteOrBufferData("Header5", false, ack_listener1); + headers_stream_->WriteOrBufferData("Header5", false, ack_listener1); + headers_stream_->WriteOrBufferData("Header7", false, ack_listener2); + headers_stream_->WriteOrBufferData("Header9", false, ack_listener3); + headers_stream_->WriteOrBufferData("Header7", false, ack_listener2); + headers_stream_->WriteOrBufferData("Header9", false, ack_listener3); + + // Ack [15, 20), [5, 25), [10, 17), [0, 12) and [22, 42). + QuicByteCount newly_acked_length = 0; + EXPECT_CALL(*ack_listener2, OnPacketAcked(5, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 15, 5, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(5u, newly_acked_length); + + EXPECT_CALL(*ack_listener1, OnPacketAcked(9, _)); + EXPECT_CALL(*ack_listener2, OnPacketAcked(1, _)); + EXPECT_CALL(*ack_listener2, OnPacketAcked(1, _)); + EXPECT_CALL(*ack_listener3, OnPacketAcked(4, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 5, 20, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(15u, newly_acked_length); + + // Duplicate ack. + EXPECT_FALSE(headers_stream_->OnStreamFrameAcked( + 10, 7, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(0u, newly_acked_length); + + EXPECT_CALL(*ack_listener1, OnPacketAcked(5, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 0, 12, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(5u, newly_acked_length); + + EXPECT_CALL(*ack_listener3, OnPacketAcked(3, _)); + EXPECT_CALL(*ack_listener2, OnPacketAcked(7, _)); + EXPECT_CALL(*ack_listener3, OnPacketAcked(7, _)); + EXPECT_TRUE(headers_stream_->OnStreamFrameAcked( + 22, 20, false, QuicTime::Delta::Zero(), QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(17u, newly_acked_length); +} + +TEST_P(QuicHeadersStreamTest, CloseOnPushPromiseToServer) { + if (perspective() == Perspective::IS_CLIENT) { + return; + } + QuicStreamId promised_id = 1; + SpdyPushPromiseIR push_promise(client_id_1_, promised_id, headers_.Clone()); + SpdySerializedFrame frame = framer_->SerializeFrame(push_promise); + stream_frame_.data_buffer = frame.data(); + stream_frame_.data_length = frame.size(); + EXPECT_CALL(session_, OnStreamHeaderList(_, _, _, _)); + // TODO(lassey): Check for HTTP_WRONG_STREAM error code. + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "PUSH_PROMISE not supported.", _)); + headers_stream_->OnStreamFrame(stream_frame_); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_receive_control_stream.cc b/quiche/quic/core/http/quic_receive_control_stream.cc new file mode 100644 index 000000000000..4be1536b989d --- /dev/null +++ b/quiche/quic/core/http/quic_receive_control_stream.cc @@ -0,0 +1,234 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_receive_control_stream.h" + +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/http_decoder.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +QuicReceiveControlStream::QuicReceiveControlStream( + PendingStream* pending, QuicSpdySession* spdy_session) + : QuicStream(pending, spdy_session, + /*is_static=*/true), + settings_frame_received_(false), + decoder_(this), + spdy_session_(spdy_session) { + sequencer()->set_level_triggered(true); +} + +QuicReceiveControlStream::~QuicReceiveControlStream() {} + +void QuicReceiveControlStream::OnStreamReset( + const QuicRstStreamFrame& /*frame*/) { + stream_delegate()->OnStreamError( + QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "RESET_STREAM received for receive control stream"); +} + +void QuicReceiveControlStream::OnDataAvailable() { + iovec iov; + while (!reading_stopped() && decoder_.error() == QUIC_NO_ERROR && + sequencer()->GetReadableRegion(&iov)) { + QUICHE_DCHECK(!sequencer()->IsClosed()); + + QuicByteCount processed_bytes = decoder_.ProcessInput( + reinterpret_cast(iov.iov_base), iov.iov_len); + sequencer()->MarkConsumed(processed_bytes); + + if (!session()->connection()->connected()) { + return; + } + + // The only reason QuicReceiveControlStream pauses HttpDecoder is an error, + // in which case the connection would have already been closed. + QUICHE_DCHECK_EQ(iov.iov_len, processed_bytes); + } +} + +void QuicReceiveControlStream::OnError(HttpDecoder* decoder) { + stream_delegate()->OnStreamError(decoder->error(), decoder->error_detail()); +} + +bool QuicReceiveControlStream::OnMaxPushIdFrame() { + return ValidateFrameType(HttpFrameType::MAX_PUSH_ID); +} + +bool QuicReceiveControlStream::OnGoAwayFrame(const GoAwayFrame& frame) { + if (spdy_session()->debug_visitor()) { + spdy_session()->debug_visitor()->OnGoAwayFrameReceived(frame); + } + + if (!ValidateFrameType(HttpFrameType::GOAWAY)) { + return false; + } + + spdy_session()->OnHttp3GoAway(frame.id); + return true; +} + +bool QuicReceiveControlStream::OnSettingsFrameStart( + QuicByteCount /*header_length*/) { + return ValidateFrameType(HttpFrameType::SETTINGS); +} + +bool QuicReceiveControlStream::OnSettingsFrame(const SettingsFrame& frame) { + QUIC_DVLOG(1) << "Control Stream " << id() + << " received settings frame: " << frame; + return spdy_session_->OnSettingsFrame(frame); +} + +bool QuicReceiveControlStream::OnDataFrameStart(QuicByteCount /*header_length*/, + QuicByteCount + /*payload_length*/) { + return ValidateFrameType(HttpFrameType::DATA); +} + +bool QuicReceiveControlStream::OnDataFramePayload( + absl::string_view /*payload*/) { + QUICHE_NOTREACHED(); + return false; +} + +bool QuicReceiveControlStream::OnDataFrameEnd() { + QUICHE_NOTREACHED(); + return false; +} + +bool QuicReceiveControlStream::OnHeadersFrameStart( + QuicByteCount /*header_length*/, QuicByteCount + /*payload_length*/) { + return ValidateFrameType(HttpFrameType::HEADERS); +} + +bool QuicReceiveControlStream::OnHeadersFramePayload( + absl::string_view /*payload*/) { + QUICHE_NOTREACHED(); + return false; +} + +bool QuicReceiveControlStream::OnHeadersFrameEnd() { + QUICHE_NOTREACHED(); + return false; +} + +bool QuicReceiveControlStream::OnPriorityUpdateFrameStart( + QuicByteCount /*header_length*/) { + return ValidateFrameType(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM); +} + +bool QuicReceiveControlStream::OnPriorityUpdateFrame( + const PriorityUpdateFrame& frame) { + if (spdy_session()->debug_visitor()) { + spdy_session()->debug_visitor()->OnPriorityUpdateFrameReceived(frame); + } + + absl::optional priority = + ParsePriorityFieldValue(frame.priority_field_value); + + if (!priority.has_value()) { + stream_delegate()->OnStreamError(QUIC_INVALID_PRIORITY_UPDATE, + "Invalid PRIORITY_UPDATE frame payload."); + return false; + } + + const QuicStreamId stream_id = frame.prioritized_element_id; + return spdy_session_->OnPriorityUpdateForRequestStream(stream_id, *priority); +} + +bool QuicReceiveControlStream::OnAcceptChFrameStart( + QuicByteCount /* header_length */) { + return ValidateFrameType(HttpFrameType::ACCEPT_CH); +} + +bool QuicReceiveControlStream::OnAcceptChFrame(const AcceptChFrame& frame) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, spdy_session()->perspective()); + + if (spdy_session()->debug_visitor()) { + spdy_session()->debug_visitor()->OnAcceptChFrameReceived(frame); + } + + spdy_session()->OnAcceptChFrame(frame); + return true; +} + +void QuicReceiveControlStream::OnWebTransportStreamFrameType( + QuicByteCount /*header_length*/, WebTransportSessionId /*session_id*/) { + QUIC_BUG(WEBTRANSPORT_STREAM on Control Stream) + << "Parsed WEBTRANSPORT_STREAM on a control stream."; +} + +bool QuicReceiveControlStream::OnUnknownFrameStart( + uint64_t frame_type, QuicByteCount /*header_length*/, + QuicByteCount payload_length) { + if (spdy_session()->debug_visitor()) { + spdy_session()->debug_visitor()->OnUnknownFrameReceived(id(), frame_type, + payload_length); + } + + return ValidateFrameType(static_cast(frame_type)); +} + +bool QuicReceiveControlStream::OnUnknownFramePayload( + absl::string_view /*payload*/) { + // Ignore unknown frame types. + return true; +} + +bool QuicReceiveControlStream::OnUnknownFrameEnd() { + // Ignore unknown frame types. + return true; +} + +bool QuicReceiveControlStream::ValidateFrameType(HttpFrameType frame_type) { + // Certain frame types are forbidden. + if (frame_type == HttpFrameType::DATA || + frame_type == HttpFrameType::HEADERS || + (spdy_session()->perspective() == Perspective::IS_CLIENT && + frame_type == HttpFrameType::MAX_PUSH_ID) || + (spdy_session()->perspective() == Perspective::IS_SERVER && + frame_type == HttpFrameType::ACCEPT_CH)) { + stream_delegate()->OnStreamError( + QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM, + absl::StrCat("Invalid frame type ", static_cast(frame_type), + " received on control stream.")); + return false; + } + + if (settings_frame_received_) { + if (frame_type == HttpFrameType::SETTINGS) { + // SETTINGS frame may only be the first frame on the control stream. + stream_delegate()->OnStreamError( + QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_CONTROL_STREAM, + "SETTINGS frame can only be received once."); + return false; + } + return true; + } + + if (frame_type == HttpFrameType::SETTINGS) { + settings_frame_received_ = true; + return true; + } + stream_delegate()->OnStreamError( + QUIC_HTTP_MISSING_SETTINGS_FRAME, + absl::StrCat("First frame received on control stream is type ", + static_cast(frame_type), ", but it must be SETTINGS.")); + return false; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_receive_control_stream.h b/quiche/quic/core/http/quic_receive_control_stream.h new file mode 100644 index 000000000000..71d0bf2fa62b --- /dev/null +++ b/quiche/quic/core/http/quic_receive_control_stream.h @@ -0,0 +1,78 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_RECEIVE_CONTROL_STREAM_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_RECEIVE_CONTROL_STREAM_H_ + +#include "quiche/quic/core/http/http_decoder.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QuicSpdySession; + +// 3.2.1 Control Stream. +// The receive control stream is peer initiated and is read only. +class QUIC_EXPORT_PRIVATE QuicReceiveControlStream + : public QuicStream, + public HttpDecoder::Visitor { + public: + explicit QuicReceiveControlStream(PendingStream* pending, + QuicSpdySession* spdy_session); + QuicReceiveControlStream(const QuicReceiveControlStream&) = delete; + QuicReceiveControlStream& operator=(const QuicReceiveControlStream&) = delete; + ~QuicReceiveControlStream() override; + + // Overriding QuicStream::OnStreamReset to make sure control stream is never + // closed before connection. + void OnStreamReset(const QuicRstStreamFrame& frame) override; + + // Implementation of QuicStream. + void OnDataAvailable() override; + + // HttpDecoder::Visitor implementation. + void OnError(HttpDecoder* decoder) override; + bool OnMaxPushIdFrame() override; + bool OnGoAwayFrame(const GoAwayFrame& frame) override; + bool OnSettingsFrameStart(QuicByteCount header_length) override; + bool OnSettingsFrame(const SettingsFrame& frame) override; + bool OnDataFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) override; + bool OnDataFramePayload(absl::string_view payload) override; + bool OnDataFrameEnd() override; + bool OnHeadersFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) override; + bool OnHeadersFramePayload(absl::string_view payload) override; + bool OnHeadersFrameEnd() override; + bool OnPriorityUpdateFrameStart(QuicByteCount header_length) override; + bool OnPriorityUpdateFrame(const PriorityUpdateFrame& frame) override; + bool OnAcceptChFrameStart(QuicByteCount header_length) override; + bool OnAcceptChFrame(const AcceptChFrame& frame) override; + void OnWebTransportStreamFrameType(QuicByteCount header_length, + WebTransportSessionId session_id) override; + bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length, + QuicByteCount payload_length) override; + bool OnUnknownFramePayload(absl::string_view payload) override; + bool OnUnknownFrameEnd() override; + + QuicSpdySession* spdy_session() { return spdy_session_; } + + private: + // Called when a frame of allowed type is received. Returns true if the frame + // is allowed in this position. Returns false and resets the stream + // otherwise. + bool ValidateFrameType(HttpFrameType frame_type); + + // False until a SETTINGS frame is received. + bool settings_frame_received_; + + HttpDecoder decoder_; + QuicSpdySession* const spdy_session_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_RECEIVE_CONTROL_STREAM_H_ diff --git a/quiche/quic/core/http/quic_receive_control_stream_test.cc b/quiche/quic/core/http/quic_receive_control_stream_test.cc new file mode 100644 index 000000000000..af7e5d43bc68 --- /dev/null +++ b/quiche/quic/core/http/quic_receive_control_stream_test.cc @@ -0,0 +1,461 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_receive_control_stream.h" + +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/test_tools/qpack/qpack_encoder_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { + +class QpackEncoder; + +namespace test { + +namespace { +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::StrictMock; + +struct TestParams { + TestParams(const ParsedQuicVersion& version, Perspective perspective) + : version(version), perspective(perspective) { + QUIC_LOG(INFO) << "TestParams: " << *this; + } + + TestParams(const TestParams& other) + : version(other.version), perspective(other.perspective) {} + + friend std::ostream& operator<<(std::ostream& os, const TestParams& tp) { + os << "{ version: " << ParsedQuicVersionToString(tp.version) + << ", perspective: " + << (tp.perspective == Perspective::IS_CLIENT ? "client" : "server") + << "}"; + return os; + } + + ParsedQuicVersion version; + Perspective perspective; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& tp) { + return absl::StrCat( + ParsedQuicVersionToString(tp.version), "_", + (tp.perspective == Perspective::IS_CLIENT ? "client" : "server")); +} + +std::vector GetTestParams() { + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (const auto& version : AllSupportedVersions()) { + if (!VersionUsesHttp3(version.transport_version)) { + continue; + } + for (Perspective p : {Perspective::IS_SERVER, Perspective::IS_CLIENT}) { + params.emplace_back(version, p); + } + } + return params; +} + +class TestStream : public QuicSpdyStream { + public: + TestStream(QuicStreamId id, QuicSpdySession* session) + : QuicSpdyStream(id, session, BIDIRECTIONAL) {} + ~TestStream() override = default; + + void OnBodyAvailable() override {} +}; + +class QuicReceiveControlStreamTest : public QuicTestWithParam { + public: + QuicReceiveControlStreamTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective(), + SupportedVersions(GetParam().version))), + session_(connection_) { + EXPECT_CALL(session_, OnCongestionWindowChange(_)).Times(AnyNumber()); + session_.Initialize(); + QuicStreamId id = perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId( + session_.transport_version(), 3) + : GetNthServerInitiatedUnidirectionalStreamId( + session_.transport_version(), 3); + char type[] = {kControlStream}; + + QuicStreamFrame data1(id, false, 0, absl::string_view(type, 1)); + session_.OnStreamFrame(data1); + + receive_control_stream_ = + QuicSpdySessionPeer::GetReceiveControlStream(&session_); + + stream_ = new TestStream(GetNthClientInitiatedBidirectionalStreamId( + GetParam().version.transport_version, 0), + &session_); + session_.ActivateStream(absl::WrapUnique(stream_)); + } + + Perspective perspective() const { return GetParam().perspective; } + + QuicStreamOffset NumBytesConsumed() { + return QuicStreamPeer::sequencer(receive_control_stream_) + ->NumBytesConsumed(); + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + StrictMock session_; + QuicReceiveControlStream* receive_control_stream_; + TestStream* stream_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicReceiveControlStreamTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicReceiveControlStreamTest, ResetControlStream) { + EXPECT_TRUE(receive_control_stream_->is_static()); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, + receive_control_stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, _, _)); + receive_control_stream_->OnStreamReset(rst_frame); +} + +TEST_P(QuicReceiveControlStreamTest, ReceiveSettings) { + SettingsFrame settings; + settings.values[10] = 2; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + settings.values[SETTINGS_QPACK_BLOCKED_STREAMS] = 12; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 37; + std::string data = HttpEncoder::SerializeSettingsFrame(settings); + QuicStreamFrame frame(receive_control_stream_->id(), false, 1, data); + + QpackEncoder* qpack_encoder = session_.qpack_encoder(); + QpackEncoderHeaderTable* header_table = + QpackEncoderPeer::header_table(qpack_encoder); + EXPECT_EQ(std::numeric_limits::max(), + session_.max_outbound_header_list_size()); + EXPECT_EQ(0u, QpackEncoderPeer::maximum_blocked_streams(qpack_encoder)); + EXPECT_EQ(0u, header_table->maximum_dynamic_table_capacity()); + + receive_control_stream_->OnStreamFrame(frame); + + EXPECT_EQ(5u, session_.max_outbound_header_list_size()); + EXPECT_EQ(12u, QpackEncoderPeer::maximum_blocked_streams(qpack_encoder)); + EXPECT_EQ(37u, header_table->maximum_dynamic_table_capacity()); +} + +// Regression test for https://crbug.com/982648. +// QuicReceiveControlStream::OnDataAvailable() must stop processing input as +// soon as OnSettingsFrameStart() is called by HttpDecoder for the second frame. +TEST_P(QuicReceiveControlStreamTest, ReceiveSettingsTwice) { + SettingsFrame settings; + // Reserved identifiers, must be ignored. + settings.values[0x21] = 100; + settings.values[0x40] = 200; + + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(settings); + + QuicStreamOffset offset = 1; + EXPECT_EQ(offset, NumBytesConsumed()); + + // Receive first SETTINGS frame. + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, offset, + settings_frame)); + offset += settings_frame.length(); + + // First SETTINGS frame is consumed. + EXPECT_EQ(offset, NumBytesConsumed()); + + // Second SETTINGS frame causes the connection to be closed. + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_CONTROL_STREAM, + "SETTINGS frame can only be received once.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + + // Receive second SETTINGS frame. + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, offset, + settings_frame)); + + // Frame header of second SETTINGS frame is consumed, but not frame payload. + QuicByteCount settings_frame_header_length = 2; + EXPECT_EQ(offset + settings_frame_header_length, NumBytesConsumed()); +} + +TEST_P(QuicReceiveControlStreamTest, ReceiveSettingsFragments) { + SettingsFrame settings; + settings.values[10] = 2; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + std::string data = HttpEncoder::SerializeSettingsFrame(settings); + std::string data1 = data.substr(0, 1); + std::string data2 = data.substr(1, data.length() - 1); + + QuicStreamFrame frame(receive_control_stream_->id(), false, 1, data1); + QuicStreamFrame frame2(receive_control_stream_->id(), false, 2, data2); + EXPECT_NE(5u, session_.max_outbound_header_list_size()); + receive_control_stream_->OnStreamFrame(frame); + receive_control_stream_->OnStreamFrame(frame2); + EXPECT_EQ(5u, session_.max_outbound_header_list_size()); +} + +TEST_P(QuicReceiveControlStreamTest, ReceiveWrongFrame) { + // DATA frame header without payload. + quiche::QuicheBuffer data = HttpEncoder::SerializeDataFrameHeader( + /* payload_length = */ 2, quiche::SimpleBufferAllocator::Get()); + + QuicStreamFrame frame(receive_control_stream_->id(), false, 1, + data.AsStringView()); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM, _, _)); + receive_control_stream_->OnStreamFrame(frame); +} + +TEST_P(QuicReceiveControlStreamTest, + ReceivePriorityUpdateFrameBeforeSettingsFrame) { + std::string serialized_frame = HttpEncoder::SerializePriorityUpdateFrame({}); + QuicStreamFrame data(receive_control_stream_->id(), /* fin = */ false, + /* offset = */ 1, serialized_frame); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_MISSING_SETTINGS_FRAME, + "First frame received on control stream is type " + "984832, but it must be SETTINGS.", + _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + + receive_control_stream_->OnStreamFrame(data); +} + +TEST_P(QuicReceiveControlStreamTest, ReceiveGoAwayFrame) { + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + QuicStreamOffset offset = 1; + + // Receive SETTINGS frame. + SettingsFrame settings; + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(settings); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(settings)); + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, offset, + settings_frame)); + offset += settings_frame.length(); + + GoAwayFrame goaway{/* id = */ 0}; + std::string goaway_frame = HttpEncoder::SerializeGoAwayFrame(goaway); + QuicStreamFrame frame(receive_control_stream_->id(), false, offset, + goaway_frame); + + EXPECT_FALSE(session_.goaway_received()); + + EXPECT_CALL(debug_visitor, OnGoAwayFrameReceived(goaway)); + receive_control_stream_->OnStreamFrame(frame); + + EXPECT_TRUE(session_.goaway_received()); +} + +TEST_P(QuicReceiveControlStreamTest, PushPromiseOnControlStreamShouldClose) { + std::string push_promise_frame = absl::HexStringToBytes( + "05" // PUSH_PROMISE + "01" // length + "00"); // push ID + QuicStreamFrame frame(receive_control_stream_->id(), false, 1, + push_promise_frame); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, _, _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + receive_control_stream_->OnStreamFrame(frame); +} + +// Regression test for b/137554973: unknown frames should be consumed. +TEST_P(QuicReceiveControlStreamTest, ConsumeUnknownFrame) { + EXPECT_EQ(1u, NumBytesConsumed()); + + QuicStreamOffset offset = 1; + + // Receive SETTINGS frame. + std::string settings_frame = HttpEncoder::SerializeSettingsFrame({}); + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, offset, + settings_frame)); + offset += settings_frame.length(); + + // SETTINGS frame is consumed. + EXPECT_EQ(offset, NumBytesConsumed()); + + // Receive unknown frame. + std::string unknown_frame = absl::HexStringToBytes( + "21" // reserved frame type + "03" // payload length + "666f6f"); // payload "foo" + + receive_control_stream_->OnStreamFrame(QuicStreamFrame( + receive_control_stream_->id(), /* fin = */ false, offset, unknown_frame)); + offset += unknown_frame.size(); + + // Unknown frame is consumed. + EXPECT_EQ(offset, NumBytesConsumed()); +} + +TEST_P(QuicReceiveControlStreamTest, ReceiveUnknownFrame) { + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + const QuicStreamId id = receive_control_stream_->id(); + QuicStreamOffset offset = 1; + + // Receive SETTINGS frame. + SettingsFrame settings; + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(settings); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(settings)); + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(id, /* fin = */ false, offset, settings_frame)); + offset += settings_frame.length(); + + // Receive unknown frame. + std::string unknown_frame = absl::HexStringToBytes( + "21" // reserved frame type + "03" // payload length + "666f6f"); // payload "foo" + + EXPECT_CALL(debug_visitor, OnUnknownFrameReceived(id, /* frame_type = */ 0x21, + /* payload_length = */ 3)); + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(id, /* fin = */ false, offset, unknown_frame)); +} + +TEST_P(QuicReceiveControlStreamTest, CancelPushFrameBeforeSettings) { + std::string cancel_push_frame = absl::HexStringToBytes( + "03" // type CANCEL_PUSH + "01" // payload length + "01"); // push ID + + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, + "CANCEL_PUSH frame received.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, + /* offset = */ 1, cancel_push_frame)); +} + +TEST_P(QuicReceiveControlStreamTest, AcceptChFrameBeforeSettings) { + std::string accept_ch_frame = absl::HexStringToBytes( + "4089" // type (ACCEPT_CH) + "00"); // length + + if (perspective() == Perspective::IS_SERVER) { + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM, + "Invalid frame type 137 received on control stream.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + } else { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_MISSING_SETTINGS_FRAME, + "First frame received on control stream is " + "type 137, but it must be SETTINGS.", + _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + } + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, + /* offset = */ 1, accept_ch_frame)); +} + +TEST_P(QuicReceiveControlStreamTest, ReceiveAcceptChFrame) { + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + const QuicStreamId id = receive_control_stream_->id(); + QuicStreamOffset offset = 1; + + // Receive SETTINGS frame. + SettingsFrame settings; + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(settings); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(settings)); + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(id, /* fin = */ false, offset, settings_frame)); + offset += settings_frame.length(); + + // Receive ACCEPT_CH frame. + std::string accept_ch_frame = absl::HexStringToBytes( + "4089" // type (ACCEPT_CH) + "00"); // length + + if (perspective() == Perspective::IS_CLIENT) { + EXPECT_CALL(debug_visitor, OnAcceptChFrameReceived(_)); + } else { + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM, + "Invalid frame type 137 received on control stream.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + } + + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(id, /* fin = */ false, offset, accept_ch_frame)); +} + +TEST_P(QuicReceiveControlStreamTest, UnknownFrameBeforeSettings) { + std::string unknown_frame = absl::HexStringToBytes( + "21" // reserved frame type + "03" // payload length + "666f6f"); // payload "foo" + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_MISSING_SETTINGS_FRAME, + "First frame received on control stream is type " + "33, but it must be SETTINGS.", + _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(session_, OnConnectionClosed(_, _)); + + receive_control_stream_->OnStreamFrame( + QuicStreamFrame(receive_control_stream_->id(), /* fin = */ false, + /* offset = */ 1, unknown_frame)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_send_control_stream.cc b/quiche/quic/core/http/quic_send_control_stream.cc new file mode 100644 index 000000000000..e9b06edb78dd --- /dev/null +++ b/quiche/quic/core/http/quic_send_control_stream.cc @@ -0,0 +1,121 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_send_control_stream.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace { + +} // anonymous namespace + +QuicSendControlStream::QuicSendControlStream(QuicStreamId id, + QuicSpdySession* spdy_session, + const SettingsFrame& settings) + : QuicStream(id, spdy_session, /*is_static = */ true, WRITE_UNIDIRECTIONAL), + settings_sent_(false), + settings_(settings), + spdy_session_(spdy_session) {} + +void QuicSendControlStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { + QUIC_BUG(quic_bug_10382_1) + << "OnStreamReset() called for write unidirectional stream."; +} + +bool QuicSendControlStream::OnStopSending(QuicResetStreamError /* code */) { + stream_delegate()->OnStreamError( + QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "STOP_SENDING received for send control stream"); + return false; +} + +void QuicSendControlStream::MaybeSendSettingsFrame() { + if (settings_sent_) { + return; + } + + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); + // Send the stream type on so the peer knows about this stream. + char data[sizeof(kControlStream)]; + QuicDataWriter writer(ABSL_ARRAYSIZE(data), data); + writer.WriteVarInt62(kControlStream); + WriteOrBufferData(absl::string_view(writer.data(), writer.length()), false, + nullptr); + + SettingsFrame settings = settings_; + // https://tools.ietf.org/html/draft-ietf-quic-http-25#section-7.2.4.1 + // specifies that setting identifiers of 0x1f * N + 0x21 are reserved and + // greasing should be attempted. + if (!GetQuicFlag(quic_enable_http3_grease_randomness)) { + settings.values[0x40] = 20; + } else { + uint32_t result; + QuicRandom::GetInstance()->RandBytes(&result, sizeof(result)); + uint64_t setting_id = 0x1fULL * static_cast(result) + 0x21ULL; + QuicRandom::GetInstance()->RandBytes(&result, sizeof(result)); + settings.values[setting_id] = result; + } + + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(settings); + QUIC_DVLOG(1) << "Control stream " << id() << " is writing settings frame " + << settings; + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnSettingsFrameSent(settings); + } + WriteOrBufferData(settings_frame, /*fin = */ false, nullptr); + settings_sent_ = true; + + // https://tools.ietf.org/html/draft-ietf-quic-http-25#section-7.2.9 + // specifies that a reserved frame type has no semantic meaning and should be + // discarded. A greasing frame is added here. + WriteOrBufferData(HttpEncoder::SerializeGreasingFrame(), /*fin = */ false, + nullptr); +} + +void QuicSendControlStream::WritePriorityUpdate(QuicStreamId stream_id, + HttpStreamPriority priority) { + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); + MaybeSendSettingsFrame(); + + const std::string priority_field_value = + SerializePriorityFieldValue(priority); + PriorityUpdateFrame priority_update_frame{stream_id, priority_field_value}; + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnPriorityUpdateFrameSent( + priority_update_frame); + } + + std::string frame = + HttpEncoder::SerializePriorityUpdateFrame(priority_update_frame); + QUIC_DVLOG(1) << "Control Stream " << id() << " is writing " + << priority_update_frame; + WriteOrBufferData(frame, false, nullptr); +} + +void QuicSendControlStream::SendGoAway(QuicStreamId id) { + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); + MaybeSendSettingsFrame(); + + GoAwayFrame frame; + frame.id = id; + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnGoAwayFrameSent(id); + } + + WriteOrBufferData(HttpEncoder::SerializeGoAwayFrame(frame), false, nullptr); +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_send_control_stream.h b/quiche/quic/core/http/quic_send_control_stream.h new file mode 100644 index 000000000000..fa8b96a94422 --- /dev/null +++ b/quiche/quic/core/http/quic_send_control_stream.h @@ -0,0 +1,65 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SEND_CONTROL_STREAM_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SEND_CONTROL_STREAM_H_ + +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +class QuicSpdySession; + +// 6.2.1 Control Stream. +// The send control stream is self initiated and is write only. +class QUIC_EXPORT_PRIVATE QuicSendControlStream : public QuicStream { + public: + // |session| can't be nullptr, and the ownership is not passed. The stream can + // only be accessed through the session. + QuicSendControlStream(QuicStreamId id, QuicSpdySession* session, + const SettingsFrame& settings); + QuicSendControlStream(const QuicSendControlStream&) = delete; + QuicSendControlStream& operator=(const QuicSendControlStream&) = delete; + ~QuicSendControlStream() override = default; + + // Overriding QuicStream::OnStopSending() to make sure control stream is never + // closed before connection. + void OnStreamReset(const QuicRstStreamFrame& frame) override; + bool OnStopSending(QuicResetStreamError code) override; + + // Send SETTINGS frame if it hasn't been sent yet. Settings frame must be the + // first frame sent on this stream. + void MaybeSendSettingsFrame(); + + // Send a PRIORITY_UPDATE frame on this stream, and a SETTINGS frame + // beforehand if one has not been already sent. + void WritePriorityUpdate(QuicStreamId stream_id, HttpStreamPriority priority); + + // Send a GOAWAY frame on this stream, and a SETTINGS frame beforehand if one + // has not been already sent. + void SendGoAway(QuicStreamId id); + + // The send control stream is write unidirectional, so this method should + // never be called. + void OnDataAvailable() override { QUICHE_NOTREACHED(); } + + private: + // Track if a settings frame is already sent. + bool settings_sent_; + + // SETTINGS values to send. + const SettingsFrame settings_; + + QuicSpdySession* const spdy_session_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SEND_CONTROL_STREAM_H_ diff --git a/quiche/quic/core/http/quic_send_control_stream_test.cc b/quiche/quic/core/http/quic_send_control_stream_test.cc new file mode 100644 index 000000000000..cd3a2745bb5f --- /dev/null +++ b/quiche/quic/core/http/quic_send_control_stream_test.cc @@ -0,0 +1,301 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_send_control_stream.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { + +namespace { + +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Invoke; +using ::testing::StrictMock; + +struct TestParams { + TestParams(const ParsedQuicVersion& version, Perspective perspective) + : version(version), perspective(perspective) { + QUIC_LOG(INFO) << "TestParams: " << *this; + } + + TestParams(const TestParams& other) + : version(other.version), perspective(other.perspective) {} + + friend std::ostream& operator<<(std::ostream& os, const TestParams& tp) { + os << "{ version: " << ParsedQuicVersionToString(tp.version) + << ", perspective: " + << (tp.perspective == Perspective::IS_CLIENT ? "client" : "server") + << "}"; + return os; + } + + ParsedQuicVersion version; + Perspective perspective; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& tp) { + return absl::StrCat( + ParsedQuicVersionToString(tp.version), "_", + (tp.perspective == Perspective::IS_CLIENT ? "client" : "server")); +} + +std::vector GetTestParams() { + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (const auto& version : AllSupportedVersions()) { + if (!VersionUsesHttp3(version.transport_version)) { + continue; + } + for (Perspective p : {Perspective::IS_SERVER, Perspective::IS_CLIENT}) { + params.emplace_back(version, p); + } + } + return params; +} + +class QuicSendControlStreamTest : public QuicTestWithParam { + public: + QuicSendControlStreamTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective(), + SupportedVersions(GetParam().version))), + session_(connection_) { + ON_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillByDefault(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + } + + void Initialize() { + EXPECT_CALL(session_, OnCongestionWindowChange(_)).Times(AnyNumber()); + session_.Initialize(); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + send_control_stream_ = QuicSpdySessionPeer::GetSendControlStream(&session_); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(session_.config(), 3); + session_.OnConfigNegotiated(); + } + + Perspective perspective() const { return GetParam().perspective; } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + StrictMock session_; + QuicSendControlStream* send_control_stream_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSendControlStreamTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSendControlStreamTest, WriteSettings) { + SetQuicFlag(quic_enable_http3_grease_randomness, false); + session_.set_qpack_maximum_dynamic_table_capacity(255); + session_.set_qpack_maximum_blocked_streams(16); + session_.set_max_inbound_header_list_size(1024); + + Initialize(); + testing::InSequence s; + + std::string expected_write_data = absl::HexStringToBytes( + "00" // stream type: control stream + "04" // frame type: SETTINGS frame + "0b" // frame length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "40ff" // 255 + "06" // SETTINGS_MAX_HEADER_LIST_SIZE + "4400" // 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "10" // 16 + "4040" // 0x40 as the reserved settings id + "14" // 20 + "4040" // 0x40 as the reserved frame type + "01" // 1 byte frame length + "61"); // payload "a" + if ((!GetQuicReloadableFlag(quic_verify_request_headers_2) || + perspective() == Perspective::IS_CLIENT) && + QuicSpdySessionPeer::LocalHttpDatagramSupport(&session_) == + HttpDatagramSupport::kDraft04) { + expected_write_data = absl::HexStringToBytes( + "00" // stream type: control stream + "04" // frame type: SETTINGS frame + "0b" // frame length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "40ff" // 255 + "06" // SETTINGS_MAX_HEADER_LIST_SIZE + "4400" // 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "10" // 16 + "4040" // 0x40 as the reserved settings id + "14" // 20 + "800ffd277" // SETTINGS_H3_DATAGRAM_DRAFT04 + "01" // 1 + "4040" // 0x40 as the reserved frame type + "01" // 1 byte frame length + "61"); // payload "a" + } + if (GetQuicReloadableFlag(quic_verify_request_headers_2) && + perspective() == Perspective::IS_SERVER && + QuicSpdySessionPeer::LocalHttpDatagramSupport(&session_) == + HttpDatagramSupport::kNone) { + expected_write_data = absl::HexStringToBytes( + "00" // stream type: control stream + "04" // frame type: SETTINGS frame + "0d" // frame length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "40ff" // 255 + "06" // SETTINGS_MAX_HEADER_LIST_SIZE + "4400" // 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "10" // 16 + "08" // SETTINGS_ENABLE_CONNECT_PROTOCOL + "01" // 1 + "4040" // 0x40 as the reserved settings id + "14" // 20 + "4040" // 0x40 as the reserved frame type + "01" // 1 byte frame length + "61"); // payload "a" + } + if (GetQuicReloadableFlag(quic_verify_request_headers_2) && + perspective() == Perspective::IS_SERVER && + QuicSpdySessionPeer::LocalHttpDatagramSupport(&session_) != + HttpDatagramSupport::kNone) { + expected_write_data = absl::HexStringToBytes( + "00" // stream type: control stream + "04" // frame type: SETTINGS frame + "0e" // frame length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "40ff" // 255 + "06" // SETTINGS_MAX_HEADER_LIST_SIZE + "4400" // 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "10" // 16 + "08" // SETTINGS_ENABLE_CONNECT_PROTOCOL + "01" // 1 + "4040" // 0x40 as the reserved settings id + "14" // 20 + "800ffd277" // SETTINGS_H3_DATAGRAM_DRAFT04 + "01" // 1 + "4040" // 0x40 as the reserved frame type + "01" // 1 byte frame length + "61"); // payload "a" + } + + auto buffer = std::make_unique(expected_write_data.size()); + QuicDataWriter writer(expected_write_data.size(), buffer.get()); + + // A lambda to save and consume stream data when QuicSession::WritevData() is + // called. + auto save_write_data = + [&writer, this](QuicStreamId /*id*/, size_t write_length, + QuicStreamOffset offset, StreamSendingState /*state*/, + TransmissionType /*type*/, + absl::optional /*level*/) { + send_control_stream_->WriteStreamData(offset, write_length, &writer); + return QuicConsumedData(/* bytes_consumed = */ write_length, + /* fin_consumed = */ false); + }; + + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), 1, _, _, _, _)) + .WillOnce(Invoke(save_write_data)); + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), + expected_write_data.size() - 5, _, _, _, _)) + .WillOnce(Invoke(save_write_data)); + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), 4, _, _, _, _)) + .WillOnce(Invoke(save_write_data)); + + send_control_stream_->MaybeSendSettingsFrame(); + quiche::test::CompareCharArraysWithHexError( + "settings", writer.data(), writer.length(), expected_write_data.data(), + expected_write_data.length()); +} + +TEST_P(QuicSendControlStreamTest, WriteSettingsOnlyOnce) { + Initialize(); + testing::InSequence s; + + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), 1, _, _, _, _)); + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), _, _, _, _, _)) + .Times(2); + send_control_stream_->MaybeSendSettingsFrame(); + + // No data should be written the second time MaybeSendSettingsFrame() is + // called. + send_control_stream_->MaybeSendSettingsFrame(); +} + +// Send stream type and SETTINGS frame if WritePriorityUpdate() is called first. +TEST_P(QuicSendControlStreamTest, WritePriorityBeforeSettings) { + Initialize(); + testing::InSequence s; + + // The first write will trigger the control stream to write stream type, a + // SETTINGS frame, and a greased frame before the PRIORITY_UPDATE frame. + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), _, _, _, _, _)) + .Times(4); + send_control_stream_->WritePriorityUpdate( + /* stream_id = */ 0, + HttpStreamPriority{/* urgency = */ 3, /* incremental = */ false}); + + EXPECT_TRUE(testing::Mock::VerifyAndClearExpectations(&session_)); + + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), _, _, _, _, _)); + send_control_stream_->WritePriorityUpdate( + /* stream_id = */ 0, + HttpStreamPriority{/* urgency = */ 3, /* incremental = */ false}); +} + +TEST_P(QuicSendControlStreamTest, CloseControlStream) { + Initialize(); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, _, _)); + send_control_stream_->OnStopSending( + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED)); +} + +TEST_P(QuicSendControlStreamTest, ReceiveDataOnSendControlStream) { + Initialize(); + QuicStreamFrame frame(send_control_stream_->id(), false, 0, "test"); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM, _, _)); + send_control_stream_->OnStreamFrame(frame); +} + +TEST_P(QuicSendControlStreamTest, SendGoAway) { + Initialize(); + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + QuicStreamId stream_id = 4; + + EXPECT_CALL(session_, WritevData(send_control_stream_->id(), _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(debug_visitor, OnSettingsFrameSent(_)); + EXPECT_CALL(debug_visitor, OnGoAwayFrameSent(stream_id)); + + send_control_stream_->SendGoAway(stream_id); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_server_initiated_spdy_stream.cc b/quiche/quic/core/http/quic_server_initiated_spdy_stream.cc new file mode 100644 index 000000000000..c036c2c236c4 --- /dev/null +++ b/quiche/quic/core/http/quic_server_initiated_spdy_stream.cc @@ -0,0 +1,42 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_server_initiated_spdy_stream.h" + +#include "quiche/quic/core/quic_error_codes.h" + +namespace quic { + +void QuicServerInitiatedSpdyStream::OnBodyAvailable() { + QUIC_BUG(Body received in QuicServerInitiatedSpdyStream) + << "Received body data in QuicServerInitiatedSpdyStream."; + OnUnrecoverableError( + QUIC_INTERNAL_ERROR, + "Received HTTP/3 body data in a server-initiated bidirectional stream"); +} + +size_t QuicServerInitiatedSpdyStream::WriteHeaders( + spdy::Http2HeaderBlock /*header_block*/, bool /*fin*/, + quiche::QuicheReferenceCountedPointer< + QuicAckListenerInterface> /*ack_listener*/) { + QUIC_BUG(Writing headers in QuicServerInitiatedSpdyStream) + << "Attempting to write headers in QuicServerInitiatedSpdyStream"; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Attempted to send HTTP/3 headers in a server-initiated " + "bidirectional stream"); + return 0; +} + +void QuicServerInitiatedSpdyStream::OnInitialHeadersComplete( + bool /*fin*/, size_t /*frame_len*/, const QuicHeaderList& /*header_list*/) { + QUIC_PEER_BUG(Reading headers in QuicServerInitiatedSpdyStream) + << "Attempting to receive headers in QuicServerInitiatedSpdyStream"; + + OnUnrecoverableError(IETF_QUIC_PROTOCOL_VIOLATION, + "Received HTTP/3 headers in a server-initiated " + "bidirectional stream without an extension setting " + "explicitly allowing those"); +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_server_initiated_spdy_stream.h b/quiche/quic/core/http/quic_server_initiated_spdy_stream.h new file mode 100644 index 000000000000..a47a712434cf --- /dev/null +++ b/quiche/quic/core/http/quic_server_initiated_spdy_stream.h @@ -0,0 +1,32 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SERVER_INITIATED_SPDY_STREAM_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SERVER_INITIATED_SPDY_STREAM_H_ + +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// QuicServerInitiatedSpdyStream is a subclass of QuicSpdyStream meant to handle +// WebTransport traffic on server-initiated bidirectional streams. Receiving or +// sending any other traffic on this stream will result in a CONNECTION_CLOSE. +class QUIC_EXPORT_PRIVATE QuicServerInitiatedSpdyStream + : public QuicSpdyStream { + public: + using QuicSpdyStream::QuicSpdyStream; + + void OnBodyAvailable() override; + size_t WriteHeaders( + spdy::Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener) override; + void OnInitialHeadersComplete(bool fin, size_t frame_len, + const QuicHeaderList& header_list) override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SERVER_INITIATED_SPDY_STREAM_H_ diff --git a/quiche/quic/core/http/quic_server_session_base.cc b/quiche/quic/core/http/quic_server_session_base.cc new file mode 100644 index 000000000000..4fc2aa90025c --- /dev/null +++ b/quiche/quic/core/http/quic_server_session_base.cc @@ -0,0 +1,425 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_server_session_base.h" + +#include + +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_tag.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +QuicServerSessionBase::QuicServerSessionBase( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) + : QuicSpdySession(connection, visitor, config, supported_versions), + crypto_config_(crypto_config), + compressed_certs_cache_(compressed_certs_cache), + helper_(helper), + bandwidth_resumption_enabled_(false), + bandwidth_estimate_sent_to_client_(QuicBandwidth::Zero()), + last_scup_time_(QuicTime::Zero()) {} + +QuicServerSessionBase::~QuicServerSessionBase() {} + +void QuicServerSessionBase::Initialize() { + crypto_stream_ = + CreateQuicCryptoServerStream(crypto_config_, compressed_certs_cache_); + QuicSpdySession::Initialize(); + SendSettingsToCryptoStream(); +} + +void QuicServerSessionBase::OnConfigNegotiated() { + QuicSpdySession::OnConfigNegotiated(); + + const CachedNetworkParameters* cached_network_params = + crypto_stream_->PreviousCachedNetworkParams(); + + // Set the initial rtt from cached_network_params.min_rtt_ms, which comes from + // a validated address token. This will override the initial rtt that may have + // been set by the transport parameters. + if (version().UsesTls() && cached_network_params != nullptr) { + if (cached_network_params->serving_region() == serving_region_) { + QUIC_CODE_COUNT(quic_server_received_network_params_at_same_region); + if (config()->HasReceivedConnectionOptions() && + ContainsQuicTag(config()->ReceivedConnectionOptions(), kTRTT)) { + QUIC_DLOG(INFO) + << "Server: Setting initial rtt to " + << cached_network_params->min_rtt_ms() + << "ms which is received from a validated address token"; + connection()->sent_packet_manager().SetInitialRtt( + QuicTime::Delta::FromMilliseconds( + cached_network_params->min_rtt_ms()), + /*trusted=*/true); + } + } else { + QUIC_CODE_COUNT(quic_server_received_network_params_at_different_region); + } + } + + if (!config()->HasReceivedConnectionOptions()) { + return; + } + + if (GetQuicReloadableFlag(quic_enable_disable_resumption) && + version().UsesTls() && + ContainsQuicTag(config()->ReceivedConnectionOptions(), kNRES) && + crypto_stream_->ResumptionAttempted()) { + QUIC_RELOADABLE_FLAG_COUNT(quic_enable_disable_resumption); + const bool disabled = crypto_stream_->DisableResumption(); + QUIC_BUG_IF(quic_failed_to_disable_resumption, !disabled) + << "Failed to disable resumption"; + } + + enable_sending_bandwidth_estimate_when_network_idle_ = + GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2) && + version().HasIetfQuicFrames() && + ContainsQuicTag(config()->ReceivedConnectionOptions(), kBWID); + + // Enable bandwidth resumption if peer sent correct connection options. + const bool last_bandwidth_resumption = + ContainsQuicTag(config()->ReceivedConnectionOptions(), kBWRE); + const bool max_bandwidth_resumption = + ContainsQuicTag(config()->ReceivedConnectionOptions(), kBWMX); + bandwidth_resumption_enabled_ = + last_bandwidth_resumption || max_bandwidth_resumption; + + // If the client has provided a bandwidth estimate from the same serving + // region as this server, then decide whether to use the data for bandwidth + // resumption. + if (cached_network_params != nullptr && + cached_network_params->serving_region() == serving_region_) { + if (!version().UsesTls()) { + // Log the received connection parameters, regardless of how they + // get used for bandwidth resumption. + connection()->OnReceiveConnectionState(*cached_network_params); + } + + if (bandwidth_resumption_enabled_) { + // Only do bandwidth resumption if estimate is recent enough. + const uint64_t seconds_since_estimate = + connection()->clock()->WallNow().ToUNIXSeconds() - + cached_network_params->timestamp(); + if (seconds_since_estimate <= kNumSecondsPerHour) { + connection()->ResumeConnectionState(*cached_network_params, + max_bandwidth_resumption); + } + } + } +} + +void QuicServerSessionBase::OnConnectionClosed( + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { + QuicSession::OnConnectionClosed(frame, source); + // In the unlikely event we get a connection close while doing an asynchronous + // crypto event, make sure we cancel the callback. + if (crypto_stream_ != nullptr) { + crypto_stream_->CancelOutstandingCallbacks(); + } +} + +void QuicServerSessionBase::OnBandwidthUpdateTimeout() { + if (!enable_sending_bandwidth_estimate_when_network_idle_) { + return; + } + QUIC_DVLOG(1) << "Bandwidth update timed out."; + const SendAlgorithmInterface* send_algorithm = + connection()->sent_packet_manager().GetSendAlgorithm(); + if (send_algorithm != nullptr && + send_algorithm->HasGoodBandwidthEstimateForResumption()) { + const bool success = MaybeSendAddressToken(); + QUIC_BUG_IF(QUIC_BUG_25522, !success) << "Failed to send address token."; + QUIC_RESTART_FLAG_COUNT_N( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2, 2, 3); + } +} + +void QuicServerSessionBase::OnCongestionWindowChange(QuicTime now) { + // Sending bandwidth is no longer conditioned on if session does bandwidth + // resumption. + if (GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2)) { + QUIC_RESTART_FLAG_COUNT_N( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2, 3, 3); + return; + } + if (!bandwidth_resumption_enabled_) { + return; + } + // Only send updates when the application has no data to write. + if (HasDataToWrite()) { + return; + } + + // If not enough time has passed since the last time we sent an update to the + // client, or not enough packets have been sent, then return early. + const QuicSentPacketManager& sent_packet_manager = + connection()->sent_packet_manager(); + int64_t srtt_ms = + sent_packet_manager.GetRttStats()->smoothed_rtt().ToMilliseconds(); + int64_t now_ms = (now - last_scup_time_).ToMilliseconds(); + int64_t packets_since_last_scup = 0; + const QuicPacketNumber largest_sent_packet = + connection()->sent_packet_manager().GetLargestSentPacket(); + if (largest_sent_packet.IsInitialized()) { + packets_since_last_scup = + last_scup_packet_number_.IsInitialized() + ? largest_sent_packet - last_scup_packet_number_ + : largest_sent_packet.ToUint64(); + } + if (now_ms < (kMinIntervalBetweenServerConfigUpdatesRTTs * srtt_ms) || + now_ms < kMinIntervalBetweenServerConfigUpdatesMs || + packets_since_last_scup < kMinPacketsBetweenServerConfigUpdates) { + return; + } + + // If the bandwidth recorder does not have a valid estimate, return early. + const QuicSustainedBandwidthRecorder* bandwidth_recorder = + sent_packet_manager.SustainedBandwidthRecorder(); + if (bandwidth_recorder == nullptr || !bandwidth_recorder->HasEstimate()) { + return; + } + + // The bandwidth recorder has recorded at least one sustained bandwidth + // estimate. Check that it's substantially different from the last one that + // we sent to the client, and if so, send the new one. + QuicBandwidth new_bandwidth_estimate = + bandwidth_recorder->BandwidthEstimate(); + + int64_t bandwidth_delta = + std::abs(new_bandwidth_estimate.ToBitsPerSecond() - + bandwidth_estimate_sent_to_client_.ToBitsPerSecond()); + + // Define "substantial" difference as a 50% increase or decrease from the + // last estimate. + bool substantial_difference = + bandwidth_delta > + 0.5 * bandwidth_estimate_sent_to_client_.ToBitsPerSecond(); + if (!substantial_difference) { + return; + } + + if (version().UsesTls()) { + if (version().HasIetfQuicFrames() && MaybeSendAddressToken()) { + bandwidth_estimate_sent_to_client_ = new_bandwidth_estimate; + } + } else { + absl::optional cached_network_params = + GenerateCachedNetworkParameters(); + + if (cached_network_params.has_value()) { + bandwidth_estimate_sent_to_client_ = new_bandwidth_estimate; + QUIC_DVLOG(1) << "Server: sending new bandwidth estimate (KBytes/s): " + << bandwidth_estimate_sent_to_client_.ToKBytesPerSecond(); + + QUICHE_DCHECK_EQ( + BandwidthToCachedParameterBytesPerSecond( + bandwidth_estimate_sent_to_client_), + cached_network_params->bandwidth_estimate_bytes_per_second()); + + crypto_stream_->SendServerConfigUpdate(&cached_network_params.value()); + + connection()->OnSendConnectionState(*cached_network_params); + } + } + + last_scup_time_ = now; + last_scup_packet_number_ = + connection()->sent_packet_manager().GetLargestSentPacket(); +} + +bool QuicServerSessionBase::ShouldCreateIncomingStream(QuicStreamId id) { + if (!connection()->connected()) { + QUIC_BUG(quic_bug_10393_2) + << "ShouldCreateIncomingStream called when disconnected"; + return false; + } + + if (QuicUtils::IsServerInitiatedStreamId(transport_version(), id)) { + QUIC_BUG(quic_bug_10393_3) + << "ShouldCreateIncomingStream called with server initiated " + "stream ID."; + return false; + } + + if (QuicUtils::IsServerInitiatedStreamId(transport_version(), id)) { + QUIC_DLOG(INFO) << "Invalid incoming even stream_id:" << id; + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Client created even numbered stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + return true; +} + +bool QuicServerSessionBase::ShouldCreateOutgoingBidirectionalStream() { + if (!connection()->connected()) { + QUIC_BUG(quic_bug_12513_2) + << "ShouldCreateOutgoingBidirectionalStream called when disconnected"; + return false; + } + if (!crypto_stream_->encryption_established()) { + QUIC_BUG(quic_bug_10393_4) + << "Encryption not established so no outgoing stream created."; + return false; + } + + return CanOpenNextOutgoingBidirectionalStream(); +} + +bool QuicServerSessionBase::ShouldCreateOutgoingUnidirectionalStream() { + if (!connection()->connected()) { + QUIC_BUG(quic_bug_12513_3) + << "ShouldCreateOutgoingUnidirectionalStream called when disconnected"; + return false; + } + if (!crypto_stream_->encryption_established()) { + QUIC_BUG(quic_bug_10393_5) + << "Encryption not established so no outgoing stream created."; + return false; + } + + return CanOpenNextOutgoingUnidirectionalStream(); +} + +QuicCryptoServerStreamBase* QuicServerSessionBase::GetMutableCryptoStream() { + return crypto_stream_.get(); +} + +const QuicCryptoServerStreamBase* QuicServerSessionBase::GetCryptoStream() + const { + return crypto_stream_.get(); +} + +int32_t QuicServerSessionBase::BandwidthToCachedParameterBytesPerSecond( + const QuicBandwidth& bandwidth) const { + return static_cast(std::min( + bandwidth.ToBytesPerSecond(), std::numeric_limits::max())); +} + +void QuicServerSessionBase::SendSettingsToCryptoStream() { + if (!version().UsesTls()) { + return; + } + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(settings()); + + std::unique_ptr serialized_settings = + std::make_unique( + settings_frame.data(), + settings_frame.data() + settings_frame.length()); + GetMutableCryptoStream()->SetServerApplicationStateForResumption( + std::move(serialized_settings)); +} + +QuicSSLConfig QuicServerSessionBase::GetSSLConfig() const { + QUICHE_DCHECK(crypto_config_ && crypto_config_->proof_source()); + + QuicSSLConfig ssl_config = QuicSpdySession::GetSSLConfig(); + + ssl_config.disable_ticket_support = + GetQuicFlag(quic_disable_server_tls_resumption); + + if (!crypto_config_ || !crypto_config_->proof_source()) { + return ssl_config; + } + + absl::InlinedVector signature_algorithms = + crypto_config_->proof_source()->SupportedTlsSignatureAlgorithms(); + if (!signature_algorithms.empty()) { + ssl_config.signing_algorithm_prefs = std::move(signature_algorithms); + } + + return ssl_config; +} + +absl::optional +QuicServerSessionBase::GenerateCachedNetworkParameters() const { + const QuicSentPacketManager& sent_packet_manager = + connection()->sent_packet_manager(); + const QuicSustainedBandwidthRecorder* bandwidth_recorder = + sent_packet_manager.SustainedBandwidthRecorder(); + + CachedNetworkParameters cached_network_params; + cached_network_params.set_timestamp( + connection()->clock()->WallNow().ToUNIXSeconds()); + + if (!sent_packet_manager.GetRttStats()->min_rtt().IsZero()) { + cached_network_params.set_min_rtt_ms( + sent_packet_manager.GetRttStats()->min_rtt().ToMilliseconds()); + } + + if (enable_sending_bandwidth_estimate_when_network_idle_) { + const SendAlgorithmInterface* send_algorithm = + sent_packet_manager.GetSendAlgorithm(); + if (send_algorithm != nullptr && + send_algorithm->HasGoodBandwidthEstimateForResumption()) { + cached_network_params.set_bandwidth_estimate_bytes_per_second( + BandwidthToCachedParameterBytesPerSecond( + send_algorithm->BandwidthEstimate())); + QUIC_CODE_COUNT(quic_send_measured_bandwidth_in_token); + } else { + const quic::CachedNetworkParameters* previous_cached_network_params = + crypto_stream()->PreviousCachedNetworkParams(); + if (previous_cached_network_params != nullptr && + previous_cached_network_params + ->bandwidth_estimate_bytes_per_second() > 0) { + cached_network_params.set_bandwidth_estimate_bytes_per_second( + previous_cached_network_params + ->bandwidth_estimate_bytes_per_second()); + QUIC_CODE_COUNT(quic_send_previous_bandwidth_in_token); + } else { + QUIC_CODE_COUNT(quic_not_send_bandwidth_in_token); + } + } + } else { + // Populate bandwidth estimates if any. + if (bandwidth_recorder != nullptr && bandwidth_recorder->HasEstimate()) { + const int32_t bw_estimate_bytes_per_second = + BandwidthToCachedParameterBytesPerSecond( + bandwidth_recorder->BandwidthEstimate()); + const int32_t max_bw_estimate_bytes_per_second = + BandwidthToCachedParameterBytesPerSecond( + bandwidth_recorder->MaxBandwidthEstimate()); + QUIC_BUG_IF(quic_bug_12513_1, max_bw_estimate_bytes_per_second < 0) + << max_bw_estimate_bytes_per_second; + QUIC_BUG_IF(quic_bug_10393_1, bw_estimate_bytes_per_second < 0) + << bw_estimate_bytes_per_second; + + cached_network_params.set_bandwidth_estimate_bytes_per_second( + bw_estimate_bytes_per_second); + cached_network_params.set_max_bandwidth_estimate_bytes_per_second( + max_bw_estimate_bytes_per_second); + cached_network_params.set_max_bandwidth_timestamp_seconds( + bandwidth_recorder->MaxBandwidthTimestamp()); + + cached_network_params.set_previous_connection_state( + bandwidth_recorder->EstimateRecordedDuringSlowStart() + ? CachedNetworkParameters::SLOW_START + : CachedNetworkParameters::CONGESTION_AVOIDANCE); + } + } + + if (!serving_region_.empty()) { + cached_network_params.set_serving_region(serving_region_); + } + + return cached_network_params; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_server_session_base.h b/quiche/quic/core/http/quic_server_session_base.h new file mode 100644 index 000000000000..fcfda0e89f5a --- /dev/null +++ b/quiche/quic/core/http/quic_server_session_base.h @@ -0,0 +1,161 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A server specific QuicSession subclass. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SERVER_SESSION_BASE_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SERVER_SESSION_BASE_H_ + +#include +#include +#include +#include +#include + +#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QuicConfig; +class QuicConnection; +class QuicCryptoServerConfig; + +namespace test { +class QuicServerSessionBasePeer; +class QuicSimpleServerSessionPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE QuicServerSessionBase : public QuicSpdySession { + public: + // Does not take ownership of |connection|. |crypto_config| must outlive the + // session. |helper| must outlive any created crypto streams. + QuicServerSessionBase(const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache); + QuicServerSessionBase(const QuicServerSessionBase&) = delete; + QuicServerSessionBase& operator=(const QuicServerSessionBase&) = delete; + + // Override the base class to cancel any ongoing asychronous crypto. + void OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) override; + + // Override to send bandwidth estimate. + void OnBandwidthUpdateTimeout() override; + + // Sends a server config update to the client, containing new bandwidth + // estimate. + void OnCongestionWindowChange(QuicTime now) override; + + ~QuicServerSessionBase() override; + + void Initialize() override; + + const QuicCryptoServerStreamBase* crypto_stream() const { + return crypto_stream_.get(); + } + + // Override base class to process bandwidth related config received from + // client. + void OnConfigNegotiated() override; + + void set_serving_region(const std::string& serving_region) { + serving_region_ = serving_region; + } + + const std::string& serving_region() const { return serving_region_; } + + QuicSSLConfig GetSSLConfig() const override; + + bool enable_sending_bandwidth_estimate_when_network_idle() const { + return enable_sending_bandwidth_estimate_when_network_idle_; + } + + protected: + // QuicSession methods(override them with return type of QuicSpdyStream*): + QuicCryptoServerStreamBase* GetMutableCryptoStream() override; + + const QuicCryptoServerStreamBase* GetCryptoStream() const override; + + absl::optional GenerateCachedNetworkParameters() + const override; + + // If an outgoing stream can be created, return true. + // Return false when connection is closed or forward secure encryption hasn't + // established yet or number of server initiated streams already reaches the + // upper limit. + bool ShouldCreateOutgoingBidirectionalStream() override; + bool ShouldCreateOutgoingUnidirectionalStream() override; + + // If we should create an incoming stream, returns true. Otherwise + // does error handling, including communicating the error to the client and + // possibly closing the connection, and returns false. + bool ShouldCreateIncomingStream(QuicStreamId id) override; + + virtual std::unique_ptr + CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) = 0; + + const QuicCryptoServerConfig* crypto_config() { return crypto_config_; } + + QuicCryptoServerStreamBase::Helper* stream_helper() { return helper_; } + + private: + friend class test::QuicServerSessionBasePeer; + friend class test::QuicSimpleServerSessionPeer; + + // Informs the QuicCryptoStream of the SETTINGS that will be used on this + // connection, so that the server crypto stream knows whether to accept 0-RTT + // data. + void SendSettingsToCryptoStream(); + + const QuicCryptoServerConfig* crypto_config_; + + // The cache which contains most recently compressed certs. + // Owned by QuicDispatcher. + QuicCompressedCertsCache* compressed_certs_cache_; + + std::unique_ptr crypto_stream_; + + // Pointer to the helper used to create crypto server streams. Must outlive + // streams created via CreateQuicCryptoServerStream. + QuicCryptoServerStreamBase::Helper* helper_; + + // Whether bandwidth resumption is enabled for this connection. + bool bandwidth_resumption_enabled_; + + // The most recent bandwidth estimate sent to the client. + QuicBandwidth bandwidth_estimate_sent_to_client_; + + // Text describing server location. Sent to the client as part of the + // bandwidth estimate in the source-address token. Optional, can be left + // empty. + std::string serving_region_; + + // Time at which we send the last SCUP to the client. + QuicTime last_scup_time_; + + // Number of packets sent to the peer, at the time we last sent a SCUP. + QuicPacketNumber last_scup_packet_number_; + + // Converts QuicBandwidth to an int32 bytes/second that can be + // stored in CachedNetworkParameters. TODO(jokulik): This function + // should go away once we fix http://b//27897982 + int32_t BandwidthToCachedParameterBytesPerSecond( + const QuicBandwidth& bandwidth) const; + + bool enable_sending_bandwidth_estimate_when_network_idle_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SERVER_SESSION_BASE_H_ diff --git a/quiche/quic/core/http/quic_server_session_base_test.cc b/quiche/quic/core/http/quic_server_session_base_test.cc new file mode 100644 index 000000000000..058ae5ba287a --- /dev/null +++ b/quiche/quic/core/http/quic_server_session_base_test.cc @@ -0,0 +1,801 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_server_session_base.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_crypto_server_stream.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/tls_server_handshaker.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/fake_proof_source.h" +#include "quiche/quic/test_tools/mock_quic_session_visitor.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_crypto_server_config_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_server_session_base_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_id_manager_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/quic/tools/quic_simple_server_stream.h" + +using testing::_; +using testing::StrictMock; + +using testing::AtLeast; + +namespace quic { +namespace test { +namespace { + +// Data to be sent on a request stream. In Google QUIC, this is interpreted as +// DATA payload (there is no framing on request streams). In IETF QUIC, this is +// interpreted as HEADERS frame (type 0x1) with payload length 122 ('z'). Since +// no payload is included, QPACK decoder will not be invoked. +const char* const kStreamData = "\1z"; + +class TestServerSession : public QuicServerSessionBase { + public: + TestServerSession(const QuicConfig& config, QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicServerSessionBase(config, CurrentSupportedVersions(), connection, + visitor, helper, crypto_config, + compressed_certs_cache), + quic_simple_server_backend_(quic_simple_server_backend) {} + + ~TestServerSession() override { DeleteConnection(); } + + MOCK_METHOD(bool, WriteControlFrame, + (const QuicFrame& frame, TransmissionType type), (override)); + + protected: + QuicSpdyStream* CreateIncomingStream(QuicStreamId id) override { + if (!ShouldCreateIncomingStream(id)) { + return nullptr; + } + QuicSpdyStream* stream = new QuicSimpleServerStream( + id, this, BIDIRECTIONAL, quic_simple_server_backend_); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + QuicSpdyStream* CreateIncomingStream(PendingStream* pending) override { + QuicSpdyStream* stream = + new QuicSimpleServerStream(pending, this, quic_simple_server_backend_); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + QuicSpdyStream* CreateOutgoingBidirectionalStream() override { + QUICHE_DCHECK(false); + return nullptr; + } + + QuicSpdyStream* CreateOutgoingUnidirectionalStream() override { + if (!ShouldCreateOutgoingUnidirectionalStream()) { + return nullptr; + } + + QuicSpdyStream* stream = new QuicSimpleServerStream( + GetNextOutgoingUnidirectionalStreamId(), this, WRITE_UNIDIRECTIONAL, + quic_simple_server_backend_); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + std::unique_ptr CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) override { + return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this, + stream_helper()); + } + + private: + QuicSimpleServerBackend* + quic_simple_server_backend_; // Owned by QuicServerSessionBaseTest +}; + +const size_t kMaxStreamsForTest = 10; + +class QuicServerSessionBaseTest : public QuicTestWithParam { + protected: + QuicServerSessionBaseTest() + : QuicServerSessionBaseTest(crypto_test_utils::ProofSourceForTesting()) {} + + explicit QuicServerSessionBaseTest(std::unique_ptr proof_source) + : crypto_config_(QuicCryptoServerConfig::TESTING, + QuicRandom::GetInstance(), std::move(proof_source), + KeyExchangeSource::Default()), + compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) { + config_.SetMaxBidirectionalStreamsToSend(kMaxStreamsForTest); + config_.SetMaxUnidirectionalStreamsToSend(kMaxStreamsForTest); + QuicConfigPeer::SetReceivedMaxBidirectionalStreams(&config_, + kMaxStreamsForTest); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(&config_, + kMaxStreamsForTest); + config_.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + config_.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + + ParsedQuicVersionVector supported_versions = SupportedVersions(version()); + connection_ = new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_SERVER, supported_versions); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + session_ = std::make_unique( + config_, connection_, &owner_, &stream_helper_, &crypto_config_, + &compressed_certs_cache_, &memory_cache_backend_); + MockClock clock; + handshake_message_ = crypto_config_.AddDefaultConfig( + QuicRandom::GetInstance(), &clock, + QuicCryptoServerConfig::ConfigOptions()); + session_->Initialize(); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_->config(), kMinimumFlowControlSendWindow); + session_->OnConfigNegotiated(); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(connection_); + } + } + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return GetNthClientInitiatedBidirectionalStreamId(transport_version(), n); + } + + QuicStreamId GetNthServerInitiatedUnidirectionalId(int n) { + return quic::test::GetNthServerInitiatedUnidirectionalStreamId( + transport_version(), n); + } + + ParsedQuicVersion version() const { return GetParam(); } + + QuicTransportVersion transport_version() const { + return version().transport_version; + } + + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes a + // one-way close. This method can be used to inject a STOP_SENDING, which + // would cause a close in the opposite direction. This allows tests to do the + // extra work to get a two-way (full) close where desired. Also sets up + // expects needed to ensure that the STOP_SENDING worked as expected. + void InjectStopSendingFrame(QuicStreamId stream_id) { + if (!VersionHasIetfQuicFrames(transport_version())) { + // Only needed for version 99/IETF QUIC. Noop otherwise. + return; + } + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream_id, + QUIC_ERROR_PROCESSING_STREAM); + EXPECT_CALL(owner_, OnStopSendingReceived(_)).Times(1); + // Expect the RESET_STREAM that is generated in response to receiving a + // STOP_SENDING. + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(stream_id, QUIC_ERROR_PROCESSING_STREAM)); + session_->OnStopSendingFrame(stop_sending); + } + + StrictMock owner_; + StrictMock stream_helper_; + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + QuicConfig config_; + QuicCryptoServerConfig crypto_config_; + QuicCompressedCertsCache compressed_certs_cache_; + QuicMemoryCacheBackend memory_cache_backend_; + std::unique_ptr session_; + std::unique_ptr handshake_message_; +}; + +// Compares CachedNetworkParameters. +MATCHER_P(EqualsProto, network_params, "") { + CachedNetworkParameters reference(network_params); + return (arg->bandwidth_estimate_bytes_per_second() == + reference.bandwidth_estimate_bytes_per_second() && + arg->bandwidth_estimate_bytes_per_second() == + reference.bandwidth_estimate_bytes_per_second() && + arg->max_bandwidth_estimate_bytes_per_second() == + reference.max_bandwidth_estimate_bytes_per_second() && + arg->max_bandwidth_timestamp_seconds() == + reference.max_bandwidth_timestamp_seconds() && + arg->min_rtt_ms() == reference.min_rtt_ms() && + arg->previous_connection_state() == + reference.previous_connection_state()); +} + +INSTANTIATE_TEST_SUITE_P(Tests, QuicServerSessionBaseTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicServerSessionBaseTest, CloseStreamDueToReset) { + // Send some data open a stream, then reset it. + QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, + kStreamData); + session_->OnStreamFrame(data1); + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Send a reset (and expect the peer to send a RST in response). + QuicRstStreamFrame rst1(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); + if (!VersionHasIetfQuicFrames(transport_version())) { + // For non-version 99, the RESET_STREAM will do the full close. + // Set up expects accordingly. + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_RST_ACKNOWLEDGEMENT)); + } + session_->OnRstStream(rst1); + + // For version-99 will create and receive a stop-sending, completing + // the full-close expected by this test. + InjectStopSendingFrame(GetNthClientInitiatedBidirectionalId(0)); + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + // Send the same two bytes of payload in a new packet. + session_->OnStreamFrame(data1); + + // The stream should not be re-opened. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicServerSessionBaseTest, NeverOpenStreamDueToReset) { + // Send a reset (and expect the peer to send a RST in response). + QuicRstStreamFrame rst1(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); + if (!VersionHasIetfQuicFrames(transport_version())) { + // For non-version 99, the RESET_STREAM will do the full close. + // Set up expects accordingly. + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_RST_ACKNOWLEDGEMENT)); + } + session_->OnRstStream(rst1); + + // For version-99 will create and receive a stop-sending, completing + // the full-close expected by this test. + InjectStopSendingFrame(GetNthClientInitiatedBidirectionalId(0)); + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, + kStreamData); + session_->OnStreamFrame(data1); + + // The stream should never be opened, now that the reset is received. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicServerSessionBaseTest, AcceptClosedStream) { + // Send some data to open two streams. + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + kStreamData); + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(1), false, 0, + kStreamData); + session_->OnStreamFrame(frame1); + session_->OnStreamFrame(frame2); + EXPECT_EQ(2u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Send a reset (and expect the peer to send a RST in response). + QuicRstStreamFrame rst(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); + if (!VersionHasIetfQuicFrames(transport_version())) { + // For non-version 99, the RESET_STREAM will do the full close. + // Set up expects accordingly. + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_RST_ACKNOWLEDGEMENT)); + } + session_->OnRstStream(rst); + + // For version-99 will create and receive a stop-sending, completing + // the full-close expected by this test. + InjectStopSendingFrame(GetNthClientInitiatedBidirectionalId(0)); + + // If we were tracking, we'd probably want to reject this because it's data + // past the reset point of stream 3. As it's a closed stream we just drop the + // data on the floor, but accept the packet because it has data for stream 5. + QuicStreamFrame frame3(GetNthClientInitiatedBidirectionalId(0), false, 2, + kStreamData); + QuicStreamFrame frame4(GetNthClientInitiatedBidirectionalId(1), false, 2, + kStreamData); + session_->OnStreamFrame(frame3); + session_->OnStreamFrame(frame4); + // The stream should never be opened, now that the reset is received. + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicServerSessionBaseTest, MaxOpenStreams) { + // Test that the server refuses if a client attempts to open too many data + // streams. For versions other than version 99, the server accepts slightly + // more than the negotiated stream limit to deal with rare cases where a + // client FIN/RST is lost. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_->OnConfigNegotiated(); + if (!VersionHasIetfQuicFrames(transport_version())) { + // The slightly increased stream limit is set during config negotiation. It + // is either an increase of 10 over negotiated limit, or a fixed percentage + // scaling, whichever is larger. Test both before continuing. + EXPECT_LT(kMaxStreamsMultiplier * kMaxStreamsForTest, + kMaxStreamsForTest + kMaxStreamsMinimumIncrement); + EXPECT_EQ(kMaxStreamsForTest + kMaxStreamsMinimumIncrement, + session_->max_open_incoming_bidirectional_streams()); + } + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + // Open the max configured number of streams, should be no problem. + for (size_t i = 0; i < kMaxStreamsForTest; ++i) { + EXPECT_TRUE(QuicServerSessionBasePeer::GetOrCreateStream(session_.get(), + stream_id)); + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } + + if (!VersionHasIetfQuicFrames(transport_version())) { + // Open more streams: server should accept slightly more than the limit. + // Excess streams are for non-version-99 only. + for (size_t i = 0; i < kMaxStreamsMinimumIncrement; ++i) { + EXPECT_TRUE(QuicServerSessionBasePeer::GetOrCreateStream(session_.get(), + stream_id)); + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } + } + // Now violate the server's internal stream limit. + stream_id += QuicUtils::StreamIdDelta(transport_version()); + + if (!VersionHasIetfQuicFrames(transport_version())) { + // For non-version 99, QUIC responds to an attempt to exceed the stream + // limit by resetting the stream. + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, OnStreamReset(stream_id, QUIC_REFUSED_STREAM)); + } else { + // In version 99 QUIC responds to an attempt to exceed the stream limit by + // closing the connection. + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(1); + } + // Even if the connection remains open, the stream creation should fail. + EXPECT_FALSE( + QuicServerSessionBasePeer::GetOrCreateStream(session_.get(), stream_id)); +} + +TEST_P(QuicServerSessionBaseTest, MaxAvailableBidirectionalStreams) { + // Test that the server closes the connection if a client makes too many data + // streams available. The server accepts slightly more than the negotiated + // stream limit to deal with rare cases where a client FIN/RST is lost. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_->OnConfigNegotiated(); + const size_t kAvailableStreamLimit = + session_->MaxAvailableBidirectionalStreams(); + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(QuicServerSessionBasePeer::GetOrCreateStream( + session_.get(), GetNthClientInitiatedBidirectionalId(0))); + + // Establish available streams up to the server's limit. + QuicStreamId next_id = QuicUtils::StreamIdDelta(transport_version()); + const int kLimitingStreamId = + GetNthClientInitiatedBidirectionalId(kAvailableStreamLimit + 1); + if (!VersionHasIetfQuicFrames(transport_version())) { + // This exceeds the stream limit. In versions other than 99 + // this is allowed. Version 99 hews to the IETF spec and does + // not allow it. + EXPECT_TRUE(QuicServerSessionBasePeer::GetOrCreateStream( + session_.get(), kLimitingStreamId)); + // A further available stream will result in connection close. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_TOO_MANY_AVAILABLE_STREAMS, _, _)); + } else { + // A further available stream will result in connection close. + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_STREAM_ID, _, _)); + } + + // This forces stream kLimitingStreamId + 2 to become available, which + // violates the quota. + EXPECT_FALSE(QuicServerSessionBasePeer::GetOrCreateStream( + session_.get(), kLimitingStreamId + 2 * next_id)); +} + +TEST_P(QuicServerSessionBaseTest, GetEvenIncomingError) { + // Incoming streams on the server session must be odd. + const QuicErrorCode expected_error = + VersionHasIetfQuicFrames(transport_version()) + ? QUIC_HTTP_STREAM_WRONG_DIRECTION + : QUIC_INVALID_STREAM_ID; + EXPECT_CALL(*connection_, CloseConnection(expected_error, _, _)); + EXPECT_EQ(nullptr, QuicServerSessionBasePeer::GetOrCreateStream( + session_.get(), + session_->next_outgoing_unidirectional_stream_id())); +} + +TEST_P(QuicServerSessionBaseTest, GetStreamDisconnected) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (version() != AllSupportedVersions()[0]) { + return; + } + + // Don't create new streams if the connection is disconnected. + QuicConnectionPeer::TearDownLocalConnectionState(connection_); + EXPECT_QUIC_BUG(QuicServerSessionBasePeer::GetOrCreateStream( + session_.get(), GetNthClientInitiatedBidirectionalId(0)), + "ShouldCreateIncomingStream called when disconnected"); +} + +class MockQuicCryptoServerStream : public QuicCryptoServerStream { + public: + explicit MockQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicServerSessionBase* session, + QuicCryptoServerStreamBase::Helper* helper) + : QuicCryptoServerStream(crypto_config, compressed_certs_cache, session, + helper) {} + MockQuicCryptoServerStream(const MockQuicCryptoServerStream&) = delete; + MockQuicCryptoServerStream& operator=(const MockQuicCryptoServerStream&) = + delete; + ~MockQuicCryptoServerStream() override {} + + MOCK_METHOD(void, SendServerConfigUpdate, (const CachedNetworkParameters*), + (override)); +}; + +class MockTlsServerHandshaker : public TlsServerHandshaker { + public: + explicit MockTlsServerHandshaker(QuicServerSessionBase* session, + const QuicCryptoServerConfig* crypto_config) + : TlsServerHandshaker(session, crypto_config) {} + MockTlsServerHandshaker(const MockTlsServerHandshaker&) = delete; + MockTlsServerHandshaker& operator=(const MockTlsServerHandshaker&) = delete; + ~MockTlsServerHandshaker() override {} + + MOCK_METHOD(void, SendServerConfigUpdate, (const CachedNetworkParameters*), + (override)); + + MOCK_METHOD(std::string, GetAddressToken, (const CachedNetworkParameters*), + (const, override)); +}; + +TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) { + if (version().UsesTls() && !version().HasIetfQuicFrames()) { + // Skip the Txxx versions. + return; + } + + // Test that bandwidth estimate updates are sent to the client, only when + // bandwidth resumption is enabled, the bandwidth estimate has changed + // sufficiently, enough time has passed, + // and we don't have any other data to write. + + // Client has sent kBWRE connection option to trigger bandwidth resumption. + QuicTagVector copt; + copt.push_back(kBWRE); + copt.push_back(kBWID); + QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_->OnConfigNegotiated(); + EXPECT_TRUE( + QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get())); + + int32_t bandwidth_estimate_kbytes_per_second = 123; + int32_t max_bandwidth_estimate_kbytes_per_second = 134; + int32_t max_bandwidth_estimate_timestamp = 1122334455; + const std::string serving_region = "not a real region"; + session_->set_serving_region(serving_region); + + if (!VersionUsesHttp3(transport_version())) { + session_->UnregisterStreamPriority( + QuicUtils::GetHeadersStreamId(transport_version())); + } + QuicServerSessionBasePeer::SetCryptoStream(session_.get(), nullptr); + MockQuicCryptoServerStream* quic_crypto_stream = nullptr; + MockTlsServerHandshaker* tls_server_stream = nullptr; + if (version().handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + quic_crypto_stream = new MockQuicCryptoServerStream( + &crypto_config_, &compressed_certs_cache_, session_.get(), + &stream_helper_); + QuicServerSessionBasePeer::SetCryptoStream(session_.get(), + quic_crypto_stream); + } else { + tls_server_stream = + new MockTlsServerHandshaker(session_.get(), &crypto_config_); + QuicServerSessionBasePeer::SetCryptoStream(session_.get(), + tls_server_stream); + } + if (!VersionUsesHttp3(transport_version())) { + session_->RegisterStreamPriority( + QuicUtils::GetHeadersStreamId(transport_version()), + /*is_static=*/true, + QuicStreamPriority::Default(session_->priority_type())); + } + + // Set some initial bandwidth values. + QuicSentPacketManager* sent_packet_manager = + QuicConnectionPeer::GetSentPacketManager(session_->connection()); + QuicSustainedBandwidthRecorder& bandwidth_recorder = + QuicSentPacketManagerPeer::GetBandwidthRecorder(sent_packet_manager); + // Seed an rtt measurement equal to the initial default rtt. + RttStats* rtt_stats = + const_cast(sent_packet_manager->GetRttStats()); + rtt_stats->UpdateRtt(rtt_stats->initial_rtt(), QuicTime::Delta::Zero(), + QuicTime::Zero()); + QuicSustainedBandwidthRecorderPeer::SetBandwidthEstimate( + &bandwidth_recorder, bandwidth_estimate_kbytes_per_second); + QuicSustainedBandwidthRecorderPeer::SetMaxBandwidthEstimate( + &bandwidth_recorder, max_bandwidth_estimate_kbytes_per_second, + max_bandwidth_estimate_timestamp); + // Queue up some pending data. + if (!VersionUsesHttp3(transport_version())) { + session_->MarkConnectionLevelWriteBlocked( + QuicUtils::GetHeadersStreamId(transport_version())); + } else { + session_->MarkConnectionLevelWriteBlocked( + QuicUtils::GetFirstUnidirectionalStreamId(transport_version(), + Perspective::IS_SERVER)); + } + EXPECT_TRUE(session_->HasDataToWrite()); + + // There will be no update sent yet - not enough time has passed. + QuicTime now = QuicTime::Zero(); + session_->OnCongestionWindowChange(now); + + // Bandwidth estimate has now changed sufficiently but not enough time has + // passed to send a Server Config Update. + bandwidth_estimate_kbytes_per_second = + bandwidth_estimate_kbytes_per_second * 1.6; + session_->OnCongestionWindowChange(now); + + // Bandwidth estimate has now changed sufficiently and enough time has passed, + // but not enough packets have been sent. + int64_t srtt_ms = + sent_packet_manager->GetRttStats()->smoothed_rtt().ToMilliseconds(); + now = now + QuicTime::Delta::FromMilliseconds( + kMinIntervalBetweenServerConfigUpdatesRTTs * srtt_ms); + session_->OnCongestionWindowChange(now); + + // The connection no longer has pending data to be written. + session_->OnCanWrite(); + EXPECT_FALSE(session_->HasDataToWrite()); + session_->OnCongestionWindowChange(now); + + // Bandwidth estimate has now changed sufficiently, enough time has passed, + // and enough packets have been sent. + SerializedPacket packet( + QuicPacketNumber(1) + kMinPacketsBetweenServerConfigUpdates, + PACKET_4BYTE_PACKET_NUMBER, nullptr, 1000, false, false); + sent_packet_manager->OnPacketSent(&packet, now, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, + ECN_NOT_ECT); + + if (GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2)) { + EXPECT_CALL(*connection_, OnSendConnectionState(_)).Times(0); + } else { + // Verify that the proto has exactly the values we expect. + CachedNetworkParameters expected_network_params; + expected_network_params.set_bandwidth_estimate_bytes_per_second( + bandwidth_recorder.BandwidthEstimate().ToBytesPerSecond()); + expected_network_params.set_max_bandwidth_estimate_bytes_per_second( + bandwidth_recorder.MaxBandwidthEstimate().ToBytesPerSecond()); + expected_network_params.set_max_bandwidth_timestamp_seconds( + bandwidth_recorder.MaxBandwidthTimestamp()); + expected_network_params.set_min_rtt_ms(session_->connection() + ->sent_packet_manager() + .GetRttStats() + ->min_rtt() + .ToMilliseconds()); + expected_network_params.set_previous_connection_state( + CachedNetworkParameters::CONGESTION_AVOIDANCE); + expected_network_params.set_timestamp( + session_->connection()->clock()->WallNow().ToUNIXSeconds()); + expected_network_params.set_serving_region(serving_region); + + if (quic_crypto_stream) { + EXPECT_CALL(*quic_crypto_stream, + SendServerConfigUpdate(EqualsProto(expected_network_params))) + .Times(1); + } else { + EXPECT_CALL(*tls_server_stream, + GetAddressToken(EqualsProto(expected_network_params))) + .WillOnce(testing::Return("Test address token")); + } + EXPECT_CALL(*connection_, OnSendConnectionState(_)).Times(1); + } + session_->OnCongestionWindowChange(now); +} + +TEST_P(QuicServerSessionBaseTest, BandwidthResumptionExperiment) { + if (version().UsesTls()) { + if (!version().HasIetfQuicFrames()) { + // Skip the Txxx versions. + return; + } + // Avoid a QUIC_BUG in QuicSession::OnConfigNegotiated. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + + // Test that if a client provides a CachedNetworkParameters with the same + // serving region as the current server, and which was made within an hour of + // now, that this data is passed down to the send algorithm. + + // Client has sent kBWRE connection option to trigger bandwidth resumption. + QuicTagVector copt; + copt.push_back(kBWRE); + QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt); + + const std::string kTestServingRegion = "a serving region"; + session_->set_serving_region(kTestServingRegion); + + // Set the time to be one hour + one second from the 0 baseline. + connection_->AdvanceTime( + QuicTime::Delta::FromSeconds(kNumSecondsPerHour + 1)); + + QuicCryptoServerStreamBase* crypto_stream = + static_cast( + QuicSessionPeer::GetMutableCryptoStream(session_.get())); + + // No effect if no CachedNetworkParameters provided. + EXPECT_CALL(*connection_, ResumeConnectionState(_, _)).Times(0); + session_->OnConfigNegotiated(); + + // No effect if CachedNetworkParameters provided, but different serving + // regions. + CachedNetworkParameters cached_network_params; + cached_network_params.set_bandwidth_estimate_bytes_per_second(1); + cached_network_params.set_serving_region("different serving region"); + crypto_stream->SetPreviousCachedNetworkParams(cached_network_params); + EXPECT_CALL(*connection_, ResumeConnectionState(_, _)).Times(0); + session_->OnConfigNegotiated(); + + // Same serving region, but timestamp is too old, should have no effect. + cached_network_params.set_serving_region(kTestServingRegion); + cached_network_params.set_timestamp(0); + crypto_stream->SetPreviousCachedNetworkParams(cached_network_params); + EXPECT_CALL(*connection_, ResumeConnectionState(_, _)).Times(0); + session_->OnConfigNegotiated(); + + // Same serving region, and timestamp is recent: estimate is stored. + cached_network_params.set_timestamp( + connection_->clock()->WallNow().ToUNIXSeconds()); + crypto_stream->SetPreviousCachedNetworkParams(cached_network_params); + EXPECT_CALL(*connection_, ResumeConnectionState(_, _)).Times(1); + session_->OnConfigNegotiated(); +} + +TEST_P(QuicServerSessionBaseTest, BandwidthMaxEnablesResumption) { + EXPECT_FALSE( + QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get())); + + // Client has sent kBWMX connection option to trigger bandwidth resumption. + QuicTagVector copt; + copt.push_back(kBWMX); + QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_->OnConfigNegotiated(); + EXPECT_TRUE( + QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get())); +} + +TEST_P(QuicServerSessionBaseTest, NoBandwidthResumptionByDefault) { + EXPECT_FALSE( + QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get())); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_->OnConfigNegotiated(); + EXPECT_FALSE( + QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get())); +} + +// Tests which check the lifetime management of data members of +// QuicCryptoServerStream objects when async GetProof is in use. +class StreamMemberLifetimeTest : public QuicServerSessionBaseTest { + public: + StreamMemberLifetimeTest() + : QuicServerSessionBaseTest( + std::unique_ptr(new FakeProofSource())), + crypto_config_peer_(&crypto_config_) { + GetFakeProofSource()->Activate(); + } + + FakeProofSource* GetFakeProofSource() const { + return static_cast(crypto_config_peer_.GetProofSource()); + } + + private: + QuicCryptoServerConfigPeer crypto_config_peer_; +}; + +INSTANTIATE_TEST_SUITE_P(StreamMemberLifetimeTests, StreamMemberLifetimeTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +// Trigger an operation which causes an async invocation of +// ProofSource::GetProof. Delay the completion of the operation until after the +// stream has been destroyed, and verify that there are no memory bugs. +TEST_P(StreamMemberLifetimeTest, Basic) { + if (version().handshake_protocol == PROTOCOL_TLS1_3) { + // This test depends on the QUIC crypto protocol, so it is disabled for the + // TLS handshake. + // TODO(nharper): Fix this test so it doesn't rely on QUIC crypto. + return; + } + + const QuicClock* clock = helper_.GetClock(); + CryptoHandshakeMessage chlo = crypto_test_utils::GenerateDefaultInchoateCHLO( + clock, transport_version(), &crypto_config_); + chlo.SetVector(kCOPT, QuicTagVector{kREJ}); + std::vector packet_version_list = {version()}; + std::unique_ptr packet(ConstructEncryptedPacket( + TestConnectionId(1), EmptyQuicConnectionId(), true, false, 1, + std::string(chlo.GetSerialized().AsStringPiece()), CONNECTION_ID_PRESENT, + CONNECTION_ID_ABSENT, PACKET_4BYTE_PACKET_NUMBER, &packet_version_list)); + + EXPECT_CALL(stream_helper_, CanAcceptClientHello(_, _, _, _, _)) + .WillOnce(testing::Return(true)); + + // Set the current packet + QuicConnectionPeer::SetCurrentPacket(session_->connection(), + packet->AsStringPiece()); + + // Yes, this is horrible. But it's the easiest way to trigger the behavior we + // need to exercise. + QuicCryptoServerStreamBase* crypto_stream = + const_cast(session_->crypto_stream()); + + // Feed the CHLO into the crypto stream, which will trigger a call to + // ProofSource::GetProof + crypto_test_utils::SendHandshakeMessageToStream(crypto_stream, chlo, + Perspective::IS_CLIENT); + ASSERT_EQ(GetFakeProofSource()->NumPendingCallbacks(), 1); + + // Destroy the stream + session_.reset(); + + // Allow the async ProofSource::GetProof call to complete. Verify (under + // memory access checkers) that this does not result in accesses to any + // freed memory from the session or its subobjects. + GetFakeProofSource()->InvokePendingCallback(0); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_client_session.cc b/quiche/quic/core/http/quic_spdy_client_session.cc new file mode 100644 index 000000000000..a30e10ce7fe4 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_session.cc @@ -0,0 +1,214 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_client_session.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/http/quic_server_initiated_spdy_stream.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicSpdyClientSession::QuicSpdyClientSession( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index) + : QuicSpdyClientSession(config, supported_versions, connection, nullptr, + server_id, crypto_config, push_promise_index) {} + +QuicSpdyClientSession::QuicSpdyClientSession( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index) + : QuicSpdyClientSessionBase(connection, visitor, push_promise_index, config, + supported_versions), + server_id_(server_id), + crypto_config_(crypto_config), + respect_goaway_(true) {} + +QuicSpdyClientSession::~QuicSpdyClientSession() = default; + +void QuicSpdyClientSession::Initialize() { + crypto_stream_ = CreateQuicCryptoStream(); + QuicSpdyClientSessionBase::Initialize(); +} + +void QuicSpdyClientSession::OnProofValid( + const QuicCryptoClientConfig::CachedState& /*cached*/) {} + +void QuicSpdyClientSession::OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& /*verify_details*/) {} + +bool QuicSpdyClientSession::ShouldCreateOutgoingBidirectionalStream() { + if (!crypto_stream_->encryption_established()) { + QUIC_DLOG(INFO) << "Encryption not active so no outgoing stream created."; + QUIC_CODE_COUNT( + quic_client_fails_to_create_stream_encryption_not_established); + return false; + } + if (goaway_received() && respect_goaway_) { + QUIC_DLOG(INFO) << "Failed to create a new outgoing stream. " + << "Already received goaway."; + QUIC_CODE_COUNT(quic_client_fails_to_create_stream_goaway_received); + return false; + } + return CanOpenNextOutgoingBidirectionalStream(); +} + +bool QuicSpdyClientSession::ShouldCreateOutgoingUnidirectionalStream() { + QUIC_BUG(quic_bug_10396_1) + << "Try to create outgoing unidirectional client data streams"; + return false; +} + +QuicSpdyClientStream* +QuicSpdyClientSession::CreateOutgoingBidirectionalStream() { + if (!ShouldCreateOutgoingBidirectionalStream()) { + return nullptr; + } + std::unique_ptr stream = CreateClientStream(); + QuicSpdyClientStream* stream_ptr = stream.get(); + ActivateStream(std::move(stream)); + return stream_ptr; +} + +QuicSpdyClientStream* +QuicSpdyClientSession::CreateOutgoingUnidirectionalStream() { + QUIC_BUG(quic_bug_10396_2) + << "Try to create outgoing unidirectional client data streams"; + return nullptr; +} + +std::unique_ptr +QuicSpdyClientSession::CreateClientStream() { + return std::make_unique( + GetNextOutgoingBidirectionalStreamId(), this, BIDIRECTIONAL); +} + +QuicCryptoClientStreamBase* QuicSpdyClientSession::GetMutableCryptoStream() { + return crypto_stream_.get(); +} + +const QuicCryptoClientStreamBase* QuicSpdyClientSession::GetCryptoStream() + const { + return crypto_stream_.get(); +} + +void QuicSpdyClientSession::CryptoConnect() { + QUICHE_DCHECK(flow_controller()); + crypto_stream_->CryptoConnect(); +} + +int QuicSpdyClientSession::GetNumSentClientHellos() const { + return crypto_stream_->num_sent_client_hellos(); +} + +bool QuicSpdyClientSession::IsResumption() const { + return crypto_stream_->IsResumption(); +} + +bool QuicSpdyClientSession::EarlyDataAccepted() const { + return crypto_stream_->EarlyDataAccepted(); +} + +bool QuicSpdyClientSession::ReceivedInchoateReject() const { + return crypto_stream_->ReceivedInchoateReject(); +} + +int QuicSpdyClientSession::GetNumReceivedServerConfigUpdates() const { + return crypto_stream_->num_scup_messages_received(); +} + +bool QuicSpdyClientSession::ShouldCreateIncomingStream(QuicStreamId id) { + if (!connection()->connected()) { + QUIC_BUG(quic_bug_10396_3) + << "ShouldCreateIncomingStream called when disconnected"; + return false; + } + if (goaway_received() && respect_goaway_) { + QUIC_DLOG(INFO) << "Failed to create a new outgoing stream. " + << "Already received goaway."; + return false; + } + + if (QuicUtils::IsClientInitiatedStreamId(transport_version(), id)) { + QUIC_BUG(quic_bug_10396_4) + << "ShouldCreateIncomingStream called with client initiated " + "stream ID."; + return false; + } + + if (QuicUtils::IsClientInitiatedStreamId(transport_version(), id)) { + QUIC_LOG(WARNING) << "Received invalid push stream id " << id; + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, + "Server created non write unidirectional stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + if (VersionHasIetfQuicFrames(transport_version()) && + QuicUtils::IsBidirectionalStreamId(id, version()) && + !WillNegotiateWebTransport()) { + connection()->CloseConnection( + QUIC_HTTP_SERVER_INITIATED_BIDIRECTIONAL_STREAM, + "Server created bidirectional stream.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + return true; +} + +QuicSpdyStream* QuicSpdyClientSession::CreateIncomingStream( + PendingStream* pending) { + QuicSpdyStream* stream = new QuicSpdyClientStream(pending, this); + ActivateStream(absl::WrapUnique(stream)); + return stream; +} + +QuicSpdyStream* QuicSpdyClientSession::CreateIncomingStream(QuicStreamId id) { + if (!ShouldCreateIncomingStream(id)) { + return nullptr; + } + QuicSpdyStream* stream; + if (version().UsesHttp3() && + QuicUtils::IsBidirectionalStreamId(id, version())) { + QUIC_BUG_IF(QuicServerInitiatedSpdyStream but no WebTransport support, + !WillNegotiateWebTransport()) + << "QuicServerInitiatedSpdyStream created but no WebTransport support"; + stream = new QuicServerInitiatedSpdyStream(id, this, BIDIRECTIONAL); + } else { + stream = new QuicSpdyClientStream(id, this, READ_UNIDIRECTIONAL); + } + ActivateStream(absl::WrapUnique(stream)); + return stream; +} + +std::unique_ptr +QuicSpdyClientSession::CreateQuicCryptoStream() { + return std::make_unique( + server_id_, this, + crypto_config_->proof_verifier()->CreateDefaultContext(), crypto_config_, + this, /*has_application_state = */ version().UsesHttp3()); +} + +bool QuicSpdyClientSession::IsAuthorized(const std::string& /*authority*/) { + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_client_session.h b/quiche/quic/core/http/quic_spdy_client_session.h new file mode 100644 index 000000000000..083baba962f5 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_session.h @@ -0,0 +1,131 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A client specific QuicSession subclass. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_SESSION_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_SESSION_H_ + +#include +#include + +#include "quiche/quic/core/http/quic_spdy_client_session_base.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +class QuicConnection; +class QuicServerId; + +class QUIC_EXPORT_PRIVATE QuicSpdyClientSession + : public QuicSpdyClientSessionBase { + public: + // Takes ownership of |connection|. Caller retains ownership of + // |promised_by_url|. + QuicSpdyClientSession(const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index); + + QuicSpdyClientSession(const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicSession::Visitor* visitor, + const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index); + + QuicSpdyClientSession(const QuicSpdyClientSession&) = delete; + QuicSpdyClientSession& operator=(const QuicSpdyClientSession&) = delete; + ~QuicSpdyClientSession() override; + // Set up the QuicSpdyClientSession. Must be called prior to use. + void Initialize() override; + + // QuicSession methods: + QuicSpdyClientStream* CreateOutgoingBidirectionalStream() override; + QuicSpdyClientStream* CreateOutgoingUnidirectionalStream() override; + QuicCryptoClientStreamBase* GetMutableCryptoStream() override; + const QuicCryptoClientStreamBase* GetCryptoStream() const override; + + bool IsAuthorized(const std::string& authority) override; + + // QuicSpdyClientSessionBase methods: + void OnProofValid(const QuicCryptoClientConfig::CachedState& cached) override; + void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) override; + + // Performs a crypto handshake with the server. + virtual void CryptoConnect(); + + // Returns the number of client hello messages that have been sent on the + // crypto stream. If the handshake has completed then this is one greater + // than the number of round-trips needed for the handshake. + int GetNumSentClientHellos() const; + + // Return true if the handshake performed is a TLS resumption. + // Always return false for QUIC Crypto. + bool IsResumption() const; + + // Returns true if early data (0-RTT data) was sent and the server accepted + // it. + bool EarlyDataAccepted() const; + + // Returns true if the handshake was delayed one round trip by the server + // because the server wanted proof the client controls its source address + // before progressing further. In Google QUIC, this would be due to an + // inchoate REJ in the QUIC Crypto handshake; in IETF QUIC this would be due + // to a Retry packet. + // TODO(nharper): Consider a better name for this method. + bool ReceivedInchoateReject() const; + + int GetNumReceivedServerConfigUpdates() const; + + using QuicSession::CanOpenNextOutgoingBidirectionalStream; + + void set_respect_goaway(bool respect_goaway) { + respect_goaway_ = respect_goaway; + } + + protected: + // QuicSession methods: + QuicSpdyStream* CreateIncomingStream(QuicStreamId id) override; + QuicSpdyStream* CreateIncomingStream(PendingStream* pending) override; + // If an outgoing stream can be created, return true. + bool ShouldCreateOutgoingBidirectionalStream() override; + bool ShouldCreateOutgoingUnidirectionalStream() override; + + // If an incoming stream can be created, return true. + // TODO(fayang): move this up to QuicSpdyClientSessionBase. + bool ShouldCreateIncomingStream(QuicStreamId id) override; + + // Create the crypto stream. Called by Initialize(). + virtual std::unique_ptr CreateQuicCryptoStream(); + + // Unlike CreateOutgoingBidirectionalStream, which applies a bunch of + // sanity checks, this simply returns a new QuicSpdyClientStream. This may be + // used by subclasses which want to use a subclass of QuicSpdyClientStream for + // streams but wish to use the sanity checks in + // CreateOutgoingBidirectionalStream. + virtual std::unique_ptr CreateClientStream(); + + const QuicServerId& server_id() const { return server_id_; } + QuicCryptoClientConfig* crypto_config() { return crypto_config_; } + + private: + std::unique_ptr crypto_stream_; + QuicServerId server_id_; + QuicCryptoClientConfig* crypto_config_; + + // If this is set to false, the client will ignore server GOAWAYs and allow + // the creation of streams regardless of the high chance they will fail. + bool respect_goaway_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_SESSION_H_ diff --git a/quiche/quic/core/http/quic_spdy_client_session_base.cc b/quiche/quic/core/http/quic_spdy_client_session_base.cc new file mode 100644 index 000000000000..adc334828b6e --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_session_base.cc @@ -0,0 +1,271 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_client_session_base.h" + +#include + +#include "quiche/quic/core/http/quic_client_promised_info.h" +#include "quiche/quic/core/http/spdy_server_push_utils.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +QuicSpdyClientSessionBase::QuicSpdyClientSessionBase( + QuicConnection* connection, QuicSession::Visitor* visitor, + QuicClientPushPromiseIndex* push_promise_index, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions) + : QuicSpdySession(connection, visitor, config, supported_versions), + push_promise_index_(push_promise_index), + largest_promised_stream_id_( + QuicUtils::GetInvalidStreamId(connection->transport_version())) {} + +QuicSpdyClientSessionBase::~QuicSpdyClientSessionBase() { + // all promised streams for this session + for (auto& it : promised_by_id_) { + QUIC_DVLOG(1) << "erase stream " << it.first << " url " << it.second->url(); + push_promise_index_->promised_by_url()->erase(it.second->url()); + } + DeleteConnection(); +} + +void QuicSpdyClientSessionBase::OnConfigNegotiated() { + QuicSpdySession::OnConfigNegotiated(); +} + +void QuicSpdyClientSessionBase::OnInitialHeadersComplete( + QuicStreamId stream_id, const Http2HeaderBlock& response_headers) { + // Note that the strong ordering of the headers stream means that + // QuicSpdyClientStream::OnPromiseHeadersComplete must have already + // been called (on the associated stream) if this is a promised + // stream. However, this stream may not have existed at this time, + // hence the need to query the session. + QuicClientPromisedInfo* promised = GetPromisedById(stream_id); + if (!promised) return; + + promised->OnResponseHeaders(response_headers); +} + +void QuicSpdyClientSessionBase::OnPromiseHeaderList( + QuicStreamId stream_id, QuicStreamId promised_stream_id, size_t frame_len, + const QuicHeaderList& header_list) { + if (IsStaticStream(stream_id)) { + connection()->CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, "stream is static", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + // In HTTP3, push promises are received on individual streams, so they could + // be arrive out of order. + if (!VersionUsesHttp3(transport_version()) && + promised_stream_id != + QuicUtils::GetInvalidStreamId(transport_version()) && + largest_promised_stream_id_ != + QuicUtils::GetInvalidStreamId(transport_version()) && + promised_stream_id <= largest_promised_stream_id_) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, + "Received push stream id lesser or equal to the" + " last accepted before", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + if (!IsIncomingStream(promised_stream_id)) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Received push stream id for outgoing stream.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (VersionUsesHttp3(transport_version())) { + // Received push stream id is higher than MAX_PUSH_ID + // because no MAX_PUSH_ID frame is ever sent. + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, + "Received push stream id higher than MAX_PUSH_ID.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + largest_promised_stream_id_ = promised_stream_id; + + QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); + if (!stream) { + // It's quite possible to receive headers after a stream has been reset. + return; + } + stream->OnPromiseHeaderList(promised_stream_id, frame_len, header_list); +} + +bool QuicSpdyClientSessionBase::HandlePromised( + QuicStreamId /* associated_id */, QuicStreamId promised_id, + const Http2HeaderBlock& headers) { + // TODO(b/136295430): Do not treat |promised_id| as a stream ID when using + // IETF QUIC. + // Due to pathalogical packet re-ordering, it is possible that + // frames for the promised stream have already arrived, and the + // promised stream could be active or closed. + if (IsClosedStream(promised_id)) { + // There was a RST on the data stream already, perhaps + // QUIC_REFUSED_STREAM? + QUIC_DVLOG(1) << "Promise ignored for stream " << promised_id + << " that is already closed"; + return false; + } + + if (push_promise_index_->promised_by_url()->size() >= get_max_promises()) { + QUIC_DVLOG(1) << "Too many promises, rejecting promise for stream " + << promised_id; + ResetPromised(promised_id, QUIC_REFUSED_STREAM); + return false; + } + + const std::string url = + SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers); + QuicClientPromisedInfo* old_promised = GetPromisedByUrl(url); + if (old_promised) { + QUIC_DVLOG(1) << "Promise for stream " << promised_id + << " is duplicate URL " << url + << " of previous promise for stream " << old_promised->id(); + ResetPromised(promised_id, QUIC_DUPLICATE_PROMISE_URL); + return false; + } + + if (GetPromisedById(promised_id)) { + // OnPromiseHeadersComplete() would have closed the connection if + // promised id is a duplicate. + QUIC_BUG(quic_bug_10412_1) << "Duplicate promise for id " << promised_id; + return false; + } + + QuicClientPromisedInfo* promised = + new QuicClientPromisedInfo(this, promised_id, url); + std::unique_ptr promised_owner(promised); + promised->Init(); + QUIC_DVLOG(1) << "stream " << promised_id << " emplace url " << url; + (*push_promise_index_->promised_by_url())[url] = promised; + promised_by_id_[promised_id] = std::move(promised_owner); + bool result = promised->OnPromiseHeaders(headers); + if (result) { + QUICHE_DCHECK(promised_by_id_.find(promised_id) != promised_by_id_.end()); + } + return result; +} + +QuicClientPromisedInfo* QuicSpdyClientSessionBase::GetPromisedByUrl( + const std::string& url) { + auto it = push_promise_index_->promised_by_url()->find(url); + if (it != push_promise_index_->promised_by_url()->end()) { + return it->second; + } + return nullptr; +} + +QuicClientPromisedInfo* QuicSpdyClientSessionBase::GetPromisedById( + const QuicStreamId id) { + auto it = promised_by_id_.find(id); + if (it != promised_by_id_.end()) { + return it->second.get(); + } + return nullptr; +} + +QuicSpdyStream* QuicSpdyClientSessionBase::GetPromisedStream( + const QuicStreamId id) { + QuicStream* stream = GetActiveStream(id); + if (stream != nullptr) { + return static_cast(stream); + } + return nullptr; +} + +void QuicSpdyClientSessionBase::DeletePromised( + QuicClientPromisedInfo* promised) { + push_promise_index_->promised_by_url()->erase(promised->url()); + // Since promised_by_id_ contains the unique_ptr, this will destroy + // promised. + // ToDo: Consider implementing logic to send a new MAX_PUSH_ID frame to allow + // another stream to be promised. + promised_by_id_.erase(promised->id()); + if (!VersionUsesHttp3(transport_version())) { + headers_stream()->MaybeReleaseSequencerBuffer(); + } +} + +void QuicSpdyClientSessionBase::OnPushStreamTimedOut( + QuicStreamId /*stream_id*/) {} + +void QuicSpdyClientSessionBase::ResetPromised( + QuicStreamId id, QuicRstStreamErrorCode error_code) { + QUICHE_DCHECK(QuicUtils::IsServerInitiatedStreamId(transport_version(), id)); + ResetStream(id, error_code); + if (!IsOpenStream(id) && !IsClosedStream(id)) { + MaybeIncreaseLargestPeerStreamId(id); + } +} + +void QuicSpdyClientSessionBase::OnStreamClosed(QuicStreamId stream_id) { + QuicSpdySession::OnStreamClosed(stream_id); + if (!VersionUsesHttp3(transport_version())) { + headers_stream()->MaybeReleaseSequencerBuffer(); + } +} + +bool QuicSpdyClientSessionBase::ShouldReleaseHeadersStreamSequencerBuffer() { + return !HasActiveRequestStreams() && promised_by_id_.empty(); +} + +bool QuicSpdyClientSessionBase::ShouldKeepConnectionAlive() const { + return QuicSpdySession::ShouldKeepConnectionAlive() || + num_outgoing_draining_streams() > 0; +} + +bool QuicSpdyClientSessionBase::OnSettingsFrame(const SettingsFrame& frame) { + if (!was_zero_rtt_rejected()) { + if (max_outbound_header_list_size() != std::numeric_limits::max() && + frame.values.find(SETTINGS_MAX_FIELD_SECTION_SIZE) == + frame.values.end()) { + CloseConnectionWithDetails( + QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + "Server accepted 0-RTT but omitted non-default " + "SETTINGS_MAX_FIELD_SECTION_SIZE"); + return false; + } + + if (qpack_encoder()->maximum_blocked_streams() != 0 && + frame.values.find(SETTINGS_QPACK_BLOCKED_STREAMS) == + frame.values.end()) { + CloseConnectionWithDetails( + QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + "Server accepted 0-RTT but omitted non-default " + "SETTINGS_QPACK_BLOCKED_STREAMS"); + return false; + } + + if (qpack_encoder()->MaximumDynamicTableCapacity() != 0 && + frame.values.find(SETTINGS_QPACK_MAX_TABLE_CAPACITY) == + frame.values.end()) { + CloseConnectionWithDetails( + QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + "Server accepted 0-RTT but omitted non-default " + "SETTINGS_QPACK_MAX_TABLE_CAPACITY"); + return false; + } + } + + if (!QuicSpdySession::OnSettingsFrame(frame)) { + return false; + } + std::string settings_frame = HttpEncoder::SerializeSettingsFrame(frame); + auto serialized_data = std::make_unique( + settings_frame.data(), settings_frame.data() + settings_frame.length()); + GetMutableCryptoStream()->SetServerApplicationStateForResumption( + std::move(serialized_data)); + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_client_session_base.h b/quiche/quic/core/http/quic_spdy_client_session_base.h new file mode 100644 index 000000000000..39746b62230e --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_session_base.h @@ -0,0 +1,146 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_SESSION_BASE_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_SESSION_BASE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +class QuicClientPromisedInfo; +class QuicClientPushPromiseIndex; +class QuicSpdyClientStream; + +// For client/http layer code. Lookup promised streams based on +// matching promised request url. The same map can be shared across +// multiple sessions, since cross-origin pushes are allowed (subject +// to authority constraints). Clients should use this map to enforce +// session affinity for requests corresponding to cross-origin push +// promised streams. +using QuicPromisedByUrlMap = + absl::flat_hash_map; + +// The maximum time a promises stream can be reserved without being +// claimed by a client request. +const int64_t kPushPromiseTimeoutSecs = 60; + +// Base class for all client-specific QuicSession subclasses. +class QUIC_EXPORT_PRIVATE QuicSpdyClientSessionBase + : public QuicSpdySession, + public QuicCryptoClientStream::ProofHandler { + public: + // Takes ownership of |connection|. Caller retains ownership of + // |promised_by_url|. + QuicSpdyClientSessionBase(QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicClientPushPromiseIndex* push_promise_index, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions); + QuicSpdyClientSessionBase(const QuicSpdyClientSessionBase&) = delete; + QuicSpdyClientSessionBase& operator=(const QuicSpdyClientSessionBase&) = + delete; + + ~QuicSpdyClientSessionBase() override; + + void OnConfigNegotiated() override; + + // Called by |headers_stream_| when push promise headers have been + // completely received. + void OnPromiseHeaderList(QuicStreamId stream_id, + QuicStreamId promised_stream_id, size_t frame_len, + const QuicHeaderList& header_list) override; + + // Called by |QuicSpdyClientStream| on receipt of response headers, + // needed to detect promised server push streams, as part of + // client-request to push-stream rendezvous. + void OnInitialHeadersComplete(QuicStreamId stream_id, + const spdy::Http2HeaderBlock& response_headers); + + // Called by |QuicSpdyClientStream| on receipt of PUSH_PROMISE, does + // some session level validation and creates the + // |QuicClientPromisedInfo| inserting into maps by (promised) id and + // url. Returns true if a new push promise is accepted. Resets the promised + // stream and returns false otherwise. + virtual bool HandlePromised(QuicStreamId associated_id, + QuicStreamId promised_id, + const spdy::Http2HeaderBlock& headers); + + // For cross-origin server push, this should verify the server is + // authoritative per [RFC2818], Section 3. Roughly, subjectAltName + // list in the certificate should contain a matching DNS name, or IP + // address. |hostname| is derived from the ":authority" header field of + // the PUSH_PROMISE frame, port if present there will be dropped. + virtual bool IsAuthorized(const std::string& hostname) = 0; + + // Session retains ownership. + QuicClientPromisedInfo* GetPromisedByUrl(const std::string& url); + // Session retains ownership. + QuicClientPromisedInfo* GetPromisedById(const QuicStreamId id); + + // + QuicSpdyStream* GetPromisedStream(const QuicStreamId id); + + // Removes |promised| from the maps by url. + void ErasePromisedByUrl(QuicClientPromisedInfo* promised); + + // Removes |promised| from the maps by url and id and destroys + // promised. + virtual void DeletePromised(QuicClientPromisedInfo* promised); + + virtual void OnPushStreamTimedOut(QuicStreamId stream_id); + + // Sends Rst for the stream, and makes sure that future calls to + // IsClosedStream(id) return true, which ensures that any subsequent + // frames related to this stream will be ignored (modulo flow + // control accounting). + void ResetPromised(QuicStreamId id, QuicRstStreamErrorCode error_code); + + // Release headers stream's sequencer buffer if it's empty. + void OnStreamClosed(QuicStreamId stream_id) override; + + // Returns true if there are no active requests and no promised streams. + bool ShouldReleaseHeadersStreamSequencerBuffer() override; + + // Override to wait for all received responses to be consumed by application. + bool ShouldKeepConnectionAlive() const override; + + size_t get_max_promises() const { + return max_open_incoming_unidirectional_streams() * + kMaxPromisedStreamsMultiplier; + } + + QuicClientPushPromiseIndex* push_promise_index() { + return push_promise_index_; + } + + // Override to serialize the settings and pass it down to the handshaker. + bool OnSettingsFrame(const SettingsFrame& frame) override; + + private: + // For QuicSpdyClientStream to detect that a response corresponds to a + // promise. + using QuicPromisedByIdMap = + absl::flat_hash_map>; + + // As per rfc7540, section 10.5: track promise streams in "reserved + // (remote)". The primary key is URL from the promise request + // headers. The promised stream id is a secondary key used to get + // promise info when the response headers of the promised stream + // arrive. + QuicClientPushPromiseIndex* push_promise_index_; + QuicPromisedByIdMap promised_by_id_; + QuicStreamId largest_promised_stream_id_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_SESSION_BASE_H_ diff --git a/quiche/quic/core/http/quic_spdy_client_session_test.cc b/quiche/quic/core/http/quic_spdy_client_session_test.cc new file mode 100644 index 000000000000..2696f194a799 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_session_test.cc @@ -0,0 +1,1339 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_client_session.h" + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/http/spdy_server_push_utils.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/tls_client_handshaker.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_quic_spdy_client_stream.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_packet_creator_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_session_cache.h" +#include "quiche/spdy/core/http2_header_block.h" + +using spdy::Http2HeaderBlock; +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::AtLeast; +using ::testing::AtMost; +using ::testing::Invoke; +using ::testing::StrictMock; +using ::testing::Truly; + +namespace quic { +namespace test { +namespace { + +const char kServerHostname[] = "test.example.com"; +const uint16_t kPort = 443; + +class TestQuicSpdyClientSession : public QuicSpdyClientSession { + public: + explicit TestQuicSpdyClientSession( + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index) + : QuicSpdyClientSession(config, supported_versions, connection, server_id, + crypto_config, push_promise_index) {} + + std::unique_ptr CreateClientStream() override { + return std::make_unique( + GetNextOutgoingBidirectionalStreamId(), this, BIDIRECTIONAL); + } + + MockQuicSpdyClientStream* CreateIncomingStream(QuicStreamId id) override { + if (!ShouldCreateIncomingStream(id)) { + return nullptr; + } + MockQuicSpdyClientStream* stream = + new MockQuicSpdyClientStream(id, this, READ_UNIDIRECTIONAL); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } +}; + +class QuicSpdyClientSessionTest : public QuicTestWithParam { + protected: + QuicSpdyClientSessionTest() + : promised_stream_id_( + QuicUtils::GetInvalidStreamId(GetParam().transport_version)), + associated_stream_id_( + QuicUtils::GetInvalidStreamId(GetParam().transport_version)) { + auto client_cache = std::make_unique(); + client_session_cache_ = client_cache.get(); + client_crypto_config_ = std::make_unique( + crypto_test_utils::ProofVerifierForTesting(), std::move(client_cache)); + server_crypto_config_ = crypto_test_utils::CryptoServerConfigForTesting(); + Initialize(); + // Advance the time, because timers do not like uninitialized times. + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + + ~QuicSpdyClientSessionTest() override { + // Session must be destroyed before promised_by_url_ + session_.reset(nullptr); + } + + void Initialize() { + session_.reset(); + connection_ = new ::testing::NiceMock( + &helper_, &alarm_factory_, Perspective::IS_CLIENT, + SupportedVersions(GetParam())); + session_ = std::make_unique( + DefaultQuicConfig(), SupportedVersions(GetParam()), connection_, + QuicServerId(kServerHostname, kPort, false), + client_crypto_config_.get(), &push_promise_index_); + session_->Initialize(); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + crypto_stream_ = static_cast( + session_->GetMutableCryptoStream()); + push_promise_[":path"] = "/bar"; + push_promise_[":authority"] = "www.google.com"; + push_promise_[":method"] = "GET"; + push_promise_[":scheme"] = "https"; + promise_url_ = + SpdyServerPushUtils::GetPromisedUrlFromHeaders(push_promise_); + promised_stream_id_ = GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), 0); + associated_stream_id_ = GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 0); + } + + // The function ensures that A) the MAX_STREAMS frames get properly deleted + // (since the test uses a 'did we leak memory' check ... if we just lose the + // frame, the test fails) and B) returns true (instead of the default, false) + // which ensures that the rest of the system thinks that the frame actually + // was transmitted. + bool ClearMaxStreamsControlFrame(const QuicFrame& frame) { + if (frame.type == MAX_STREAMS_FRAME) { + DeleteFrame(&const_cast(frame)); + return true; + } + return false; + } + + public: + bool ClearStreamsBlockedControlFrame(const QuicFrame& frame) { + if (frame.type == STREAMS_BLOCKED_FRAME) { + DeleteFrame(&const_cast(frame)); + return true; + } + return false; + } + + protected: + void CompleteCryptoHandshake() { + CompleteCryptoHandshake(kDefaultMaxStreamsPerConnection); + } + + void CompleteCryptoHandshake(uint32_t server_max_incoming_streams) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(::testing::AnyNumber()) + .WillRepeatedly(Invoke( + this, &QuicSpdyClientSessionTest::ClearMaxStreamsControlFrame)); + } + session_->CryptoConnect(); + QuicConfig config = DefaultQuicConfig(); + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + config.SetMaxUnidirectionalStreamsToSend(server_max_incoming_streams); + config.SetMaxBidirectionalStreamsToSend(server_max_incoming_streams); + } else { + config.SetMaxBidirectionalStreamsToSend(server_max_incoming_streams); + } + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); + } + + void CreateConnection() { + connection_ = new ::testing::NiceMock( + &helper_, &alarm_factory_, Perspective::IS_CLIENT, + SupportedVersions(GetParam())); + // Advance the time, because timers do not like uninitialized times. + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + session_ = std::make_unique( + DefaultQuicConfig(), SupportedVersions(GetParam()), connection_, + QuicServerId(kServerHostname, kPort, false), + client_crypto_config_.get(), &push_promise_index_); + session_->Initialize(); + crypto_stream_ = static_cast( + session_->GetMutableCryptoStream()); + } + + void CompleteFirstConnection() { + CompleteCryptoHandshake(); + EXPECT_FALSE(session_->GetCryptoStream()->IsResumption()); + if (session_->version().UsesHttp3()) { + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 2; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + settings.values[256] = 4; // unknown setting + session_->OnSettingsFrame(settings); + } + } + + // Owned by |session_|. + QuicCryptoClientStream* crypto_stream_; + std::unique_ptr server_crypto_config_; + std::unique_ptr client_crypto_config_; + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + ::testing::NiceMock* connection_; + std::unique_ptr session_; + QuicClientPushPromiseIndex push_promise_index_; + Http2HeaderBlock push_promise_; + std::string promise_url_; + QuicStreamId promised_stream_id_; + QuicStreamId associated_stream_id_; + test::SimpleSessionCache* client_session_cache_; +}; + +std::string ParamNameFormatter( + const testing::TestParamInfo& info) { + return ParsedQuicVersionToString(info.param); +} + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdyClientSessionTest, + ::testing::ValuesIn(AllSupportedVersions()), + ParamNameFormatter); + +TEST_P(QuicSpdyClientSessionTest, CryptoConnect) { CompleteCryptoHandshake(); } + +TEST_P(QuicSpdyClientSessionTest, NoEncryptionAfterInitialEncryption) { + if (GetParam().handshake_protocol == PROTOCOL_TLS1_3) { + // This test relies on resumption and is QUIC crypto specific, so it is + // disabled for TLS. + return; + } + // Complete a handshake in order to prime the crypto config for 0-RTT. + CompleteCryptoHandshake(); + + // Now create a second session using the same crypto config. + Initialize(); + + // Starting the handshake should move immediately to encryption + // established and will allow streams to be created. + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_FALSE(QuicUtils::IsCryptoStreamId(connection_->transport_version(), + stream->id())); + + // Process an "inchoate" REJ from the server which will cause + // an inchoate CHLO to be sent and will leave the encryption level + // at NONE. + CryptoHandshakeMessage rej; + crypto_test_utils::FillInDummyReject(&rej); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + crypto_test_utils::SendHandshakeMessageToStream( + session_->GetMutableCryptoStream(), rej, Perspective::IS_CLIENT); + EXPECT_FALSE(session_->IsEncryptionEstablished()); + EXPECT_EQ(ENCRYPTION_INITIAL, + QuicPacketCreatorPeer::GetEncryptionLevel( + QuicConnectionPeer::GetPacketCreator(connection_))); + // Verify that no new streams may be created. + EXPECT_TRUE(session_->CreateOutgoingBidirectionalStream() == nullptr); + // Verify that no data may be send on existing streams. + char data[] = "hello world"; + QuicConsumedData consumed = + session_->WritevData(stream->id(), ABSL_ARRAYSIZE(data), 0, NO_FIN, + NOT_RETRANSMISSION, ENCRYPTION_INITIAL); + EXPECT_EQ(0u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); +} + +TEST_P(QuicSpdyClientSessionTest, MaxNumStreamsWithNoFinOrRst) { + uint32_t kServerMaxIncomingStreams = 1; + CompleteCryptoHandshake(kServerMaxIncomingStreams); + + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_TRUE(stream); + EXPECT_FALSE(session_->CreateOutgoingBidirectionalStream()); + + // Close the stream, but without having received a FIN or a RST_STREAM + // or MAX_STREAMS (IETF QUIC) and check that a new one can not be created. + session_->ResetStream(stream->id(), QUIC_STREAM_CANCELLED); + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + stream = session_->CreateOutgoingBidirectionalStream(); + EXPECT_FALSE(stream); +} + +TEST_P(QuicSpdyClientSessionTest, MaxNumStreamsWithRst) { + uint32_t kServerMaxIncomingStreams = 1; + CompleteCryptoHandshake(kServerMaxIncomingStreams); + + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_NE(nullptr, stream); + EXPECT_EQ(nullptr, session_->CreateOutgoingBidirectionalStream()); + + // Close the stream and receive an RST frame to remove the unfinished stream + session_->ResetStream(stream->id(), QUIC_STREAM_CANCELLED); + session_->OnRstStream(QuicRstStreamFrame(kInvalidControlFrameId, stream->id(), + QUIC_RST_ACKNOWLEDGEMENT, 0)); + // Check that a new one can be created. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + if (VersionHasIetfQuicFrames(GetParam().transport_version)) { + // In IETF QUIC the stream limit increases only if we get a MAX_STREAMS + // frame; pretend we got one. + + QuicMaxStreamsFrame frame(0, 2, + /*unidirectional=*/false); + session_->OnMaxStreamsFrame(frame); + } + stream = session_->CreateOutgoingBidirectionalStream(); + EXPECT_NE(nullptr, stream); + if (VersionHasIetfQuicFrames(GetParam().transport_version)) { + // Ensure that we have 2 total streams, 1 open and 1 closed. + QuicStreamCount expected_stream_count = 2; + EXPECT_EQ(expected_stream_count, + QuicSessionPeer::ietf_bidirectional_stream_id_manager(&*session_) + ->outgoing_stream_count()); + } +} + +TEST_P(QuicSpdyClientSessionTest, ResetAndTrailers) { + // Tests the situation in which the client sends a RST at the same time that + // the server sends trailing headers (trailers). Receipt of the trailers by + // the client should result in all outstanding stream state being tidied up + // (including flow control, and number of available outgoing streams). + uint32_t kServerMaxIncomingStreams = 1; + CompleteCryptoHandshake(kServerMaxIncomingStreams); + + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_NE(nullptr, stream); + + if (VersionHasIetfQuicFrames(GetParam().transport_version)) { + // For IETF QUIC, trying to open a stream and failing due to lack + // of stream ids will result in a STREAMS_BLOCKED. Make + // sure we get one. Also clear out the frame because if it's + // left sitting, the later SendRstStream will not actually + // transmit the RST_STREAM because the connection will be in write-blocked + // state. This means that the SendControlFrame that is expected w.r.t. the + // RST_STREAM, below, will not be satisfied. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke( + this, &QuicSpdyClientSessionTest::ClearStreamsBlockedControlFrame)); + } + + EXPECT_EQ(nullptr, session_->CreateOutgoingBidirectionalStream()); + + QuicStreamId stream_id = stream->id(); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(AtLeast(1)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(_, _)).Times(1); + session_->ResetStream(stream_id, QUIC_STREAM_PEER_GOING_AWAY); + + // A new stream cannot be created as the reset stream still counts as an open + // outgoing stream until closed by the server. + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + stream = session_->CreateOutgoingBidirectionalStream(); + EXPECT_EQ(nullptr, stream); + + // The stream receives trailers with final byte offset: this is one of three + // ways that a peer can signal the end of a stream (the others being RST, + // stream data + FIN). + QuicHeaderList trailers; + trailers.OnHeaderBlockStart(); + trailers.OnHeader(kFinalOffsetHeaderKey, "0"); + trailers.OnHeaderBlockEnd(0, 0); + session_->OnStreamHeaderList(stream_id, /*fin=*/false, 0, trailers); + + // The stream is now complete from the client's perspective, and it should + // be able to create a new outgoing stream. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + if (VersionHasIetfQuicFrames(GetParam().transport_version)) { + QuicMaxStreamsFrame frame(0, 2, + /*unidirectional=*/false); + + session_->OnMaxStreamsFrame(frame); + } + stream = session_->CreateOutgoingBidirectionalStream(); + EXPECT_NE(nullptr, stream); + if (VersionHasIetfQuicFrames(GetParam().transport_version)) { + // Ensure that we have 2 open streams. + QuicStreamCount expected_stream_count = 2; + EXPECT_EQ(expected_stream_count, + QuicSessionPeer::ietf_bidirectional_stream_id_manager(&*session_) + ->outgoing_stream_count()); + } +} + +TEST_P(QuicSpdyClientSessionTest, ReceivedMalformedTrailersAfterSendingRst) { + // Tests the situation where the client has sent a RST to the server, and has + // received trailing headers with a malformed final byte offset value. + CompleteCryptoHandshake(); + + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_NE(nullptr, stream); + + // Send the RST, which results in the stream being closed locally (but some + // state remains while the client waits for a response from the server). + QuicStreamId stream_id = stream->id(); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(AtLeast(1)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(_, _)).Times(1); + session_->ResetStream(stream_id, QUIC_STREAM_PEER_GOING_AWAY); + + // The stream receives trailers with final byte offset, but the header value + // is non-numeric and should be treated as malformed. + QuicHeaderList trailers; + trailers.OnHeaderBlockStart(); + trailers.OnHeader(kFinalOffsetHeaderKey, "invalid non-numeric value"); + trailers.OnHeaderBlockEnd(0, 0); + + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(1); + session_->OnStreamHeaderList(stream_id, /*fin=*/false, 0, trailers); +} + +TEST_P(QuicSpdyClientSessionTest, OnStreamHeaderListWithStaticStream) { + // Test situation where OnStreamHeaderList is called by stream with static id. + CompleteCryptoHandshake(); + + QuicHeaderList trailers; + trailers.OnHeaderBlockStart(); + trailers.OnHeader(kFinalOffsetHeaderKey, "0"); + trailers.OnHeaderBlockEnd(0, 0); + + // Initialize H/3 control stream. + QuicStreamId id; + if (VersionUsesHttp3(connection_->transport_version())) { + id = GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), 3); + char type[] = {0x00}; + + QuicStreamFrame data1(id, false, 0, absl::string_view(type, 1)); + session_->OnStreamFrame(data1); + } else { + id = QuicUtils::GetHeadersStreamId(connection_->transport_version()); + } + + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "stream is static", _)) + .Times(1); + session_->OnStreamHeaderList(id, + /*fin=*/false, 0, trailers); +} + +TEST_P(QuicSpdyClientSessionTest, OnPromiseHeaderListWithStaticStream) { + // Test situation where OnPromiseHeaderList is called by stream with static + // id. + CompleteCryptoHandshake(); + + QuicHeaderList trailers; + trailers.OnHeaderBlockStart(); + trailers.OnHeader(kFinalOffsetHeaderKey, "0"); + trailers.OnHeaderBlockEnd(0, 0); + + // Initialize H/3 control stream. + QuicStreamId id; + if (VersionUsesHttp3(connection_->transport_version())) { + id = GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), 3); + char type[] = {0x00}; + + QuicStreamFrame data1(id, false, 0, absl::string_view(type, 1)); + session_->OnStreamFrame(data1); + } else { + id = QuicUtils::GetHeadersStreamId(connection_->transport_version()); + } + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "stream is static", _)) + .Times(1); + session_->OnPromiseHeaderList(id, promised_stream_id_, 0, trailers); +} + +TEST_P(QuicSpdyClientSessionTest, GoAwayReceived) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + return; + } + CompleteCryptoHandshake(); + + // After receiving a GoAway, I should no longer be able to create outgoing + // streams. + session_->connection()->OnGoAwayFrame(QuicGoAwayFrame( + kInvalidControlFrameId, QUIC_PEER_GOING_AWAY, 1u, "Going away.")); + EXPECT_EQ(nullptr, session_->CreateOutgoingBidirectionalStream()); +} + +static bool CheckForDecryptionError(QuicFramer* framer) { + return framer->error() == QUIC_DECRYPTION_FAILURE; +} + +// Various sorts of invalid packets that should not cause a connection +// to be closed. +TEST_P(QuicSpdyClientSessionTest, InvalidPacketReceived) { + QuicSocketAddress server_address(TestPeerIPAddress(), kTestPort); + QuicSocketAddress client_address(TestPeerIPAddress(), kTestPort); + + EXPECT_CALL(*connection_, ProcessUdpPacket(server_address, client_address, _)) + .WillRepeatedly(Invoke(static_cast(connection_), + &MockQuicConnection::ReallyProcessUdpPacket)); + EXPECT_CALL(*connection_, OnCanWrite()).Times(AnyNumber()); + EXPECT_CALL(*connection_, OnError(_)).Times(1); + + // Verify that empty packets don't close the connection. + QuicReceivedPacket zero_length_packet(nullptr, 0, QuicTime::Zero(), false); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_->ProcessUdpPacket(client_address, server_address, + zero_length_packet); + + // Verifiy that small, invalid packets don't close the connection. + char buf[2] = {0x00, 0x01}; + QuicConnectionId connection_id = session_->connection()->connection_id(); + QuicReceivedPacket valid_packet(buf, 2, QuicTime::Zero(), false); + // Close connection shouldn't be called. + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*connection_, OnError(_)).Times(AtMost(1)); + session_->ProcessUdpPacket(client_address, server_address, valid_packet); + + // Verify that a non-decryptable packet doesn't close the connection. + QuicFramerPeer::SetLastSerializedServerConnectionId( + QuicConnectionPeer::GetFramer(connection_), connection_id); + ParsedQuicVersionVector versions = SupportedVersions(GetParam()); + QuicConnectionId destination_connection_id = EmptyQuicConnectionId(); + QuicConnectionId source_connection_id = connection_id; + std::unique_ptr packet(ConstructEncryptedPacket( + destination_connection_id, source_connection_id, false, false, 100, + "data", true, CONNECTION_ID_ABSENT, CONNECTION_ID_ABSENT, + PACKET_4BYTE_PACKET_NUMBER, &versions, Perspective::IS_SERVER)); + std::unique_ptr received( + ConstructReceivedPacket(*packet, QuicTime::Zero())); + // Change the last byte of the encrypted data. + *(const_cast(received->data() + received->length() - 1)) += 1; + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*connection_, OnError(Truly(CheckForDecryptionError))).Times(1); + session_->ProcessUdpPacket(client_address, server_address, *received); +} + +// A packet with invalid framing should cause a connection to be closed. +TEST_P(QuicSpdyClientSessionTest, InvalidFramedPacketReceived) { + const ParsedQuicVersion version = GetParam(); + QuicSocketAddress server_address(TestPeerIPAddress(), kTestPort); + QuicSocketAddress client_address(TestPeerIPAddress(), kTestPort); + if (version.KnowsWhichDecrypterToUse()) { + connection_->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + } else { + connection_->SetAlternativeDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE), + false); + } + + EXPECT_CALL(*connection_, ProcessUdpPacket(server_address, client_address, _)) + .WillRepeatedly(Invoke(static_cast(connection_), + &MockQuicConnection::ReallyProcessUdpPacket)); + EXPECT_CALL(*connection_, OnError(_)).Times(1); + + // Verify that a decryptable packet with bad frames does close the connection. + QuicConnectionId destination_connection_id = + session_->connection()->connection_id(); + QuicConnectionId source_connection_id = EmptyQuicConnectionId(); + QuicFramerPeer::SetLastSerializedServerConnectionId( + QuicConnectionPeer::GetFramer(connection_), destination_connection_id); + bool version_flag = false; + QuicConnectionIdIncluded scid_included = CONNECTION_ID_ABSENT; + if (version.HasIetfInvariantHeader()) { + version_flag = true; + source_connection_id = destination_connection_id; + scid_included = CONNECTION_ID_PRESENT; + } + std::unique_ptr packet(ConstructMisFramedEncryptedPacket( + destination_connection_id, source_connection_id, version_flag, false, 100, + "data", CONNECTION_ID_ABSENT, scid_included, PACKET_4BYTE_PACKET_NUMBER, + version, Perspective::IS_SERVER)); + std::unique_ptr received( + ConstructReceivedPacket(*packet, QuicTime::Zero())); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(1); + session_->ProcessUdpPacket(client_address, server_address, *received); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseOnPromiseHeaders) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + return; + } + + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + MockQuicSpdyClientStream* stream = static_cast( + session_->CreateOutgoingBidirectionalStream()); + + EXPECT_CALL(*stream, OnPromiseHeaderList(_, _, _)); + session_->OnPromiseHeaderList(associated_stream_id_, promised_stream_id_, 0, + QuicHeaderList()); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseStreamIdTooHigh) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + return; + } + + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + QuicStreamId stream_id = + QuicSessionPeer::GetNextOutgoingBidirectionalStreamId(session_.get()); + QuicSessionPeer::ActivateStream( + session_.get(), std::make_unique( + stream_id, session_.get(), BIDIRECTIONAL)); + + QuicHeaderList headers; + headers.OnHeaderBlockStart(); + headers.OnHeader(":path", "/bar"); + headers.OnHeader(":authority", "www.google.com"); + headers.OnHeader(":method", "GET"); + headers.OnHeader(":scheme", "https"); + headers.OnHeaderBlockEnd(0, 0); + + const QuicStreamId promise_id = GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), 11); + session_->OnPromiseHeaderList(stream_id, promise_id, 0, headers); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseOnPromiseHeadersAlreadyClosed) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->CreateOutgoingBidirectionalStream(); + + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_REFUSED_STREAM)); + session_->ResetPromised(promised_stream_id_, QUIC_REFUSED_STREAM); + + session_->OnPromiseHeaderList(associated_stream_id_, promised_stream_id_, 0, + QuicHeaderList()); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseOutOfOrder) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + return; + } + + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + MockQuicSpdyClientStream* stream = static_cast( + session_->CreateOutgoingBidirectionalStream()); + + EXPECT_CALL(*stream, OnPromiseHeaderList(promised_stream_id_, _, _)); + session_->OnPromiseHeaderList(associated_stream_id_, promised_stream_id_, 0, + QuicHeaderList()); + associated_stream_id_ += + QuicUtils::StreamIdDelta(connection_->transport_version()); + if (!VersionUsesHttp3(session_->transport_version())) { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received push stream id lesser or equal to the" + " last accepted before", + _)); + } + session_->OnPromiseHeaderList(associated_stream_id_, promised_stream_id_, 0, + QuicHeaderList()); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseOutgoingStreamId) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + MockQuicSpdyClientStream* stream = static_cast( + session_->CreateOutgoingBidirectionalStream()); + + // Promise an illegal (outgoing) stream id. + promised_stream_id_ = GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 0); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received push stream id for outgoing stream.", _)); + + session_->OnPromiseHeaderList(stream->id(), promised_stream_id_, 0, + QuicHeaderList()); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseHandlePromise) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->CreateOutgoingBidirectionalStream(); + + EXPECT_TRUE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + + EXPECT_NE(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_NE(session_->GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseAlreadyClosed) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->CreateOutgoingBidirectionalStream(); + session_->GetOrCreateStream(promised_stream_id_); + + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_REFUSED_STREAM)); + + session_->ResetPromised(promised_stream_id_, QUIC_REFUSED_STREAM); + Http2HeaderBlock promise_headers; + EXPECT_FALSE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, promise_headers)); + + // Verify that the promise was not created. + EXPECT_EQ(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_EQ(session_->GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseDuplicateUrl) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->CreateOutgoingBidirectionalStream(); + + EXPECT_TRUE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + + EXPECT_NE(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_NE(session_->GetPromisedByUrl(promise_url_), nullptr); + + promised_stream_id_ += + QuicUtils::StreamIdDelta(connection_->transport_version()); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_DUPLICATE_PROMISE_URL)); + + EXPECT_FALSE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + + // Verify that the promise was not created. + EXPECT_EQ(session_->GetPromisedById(promised_stream_id_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, ReceivingPromiseEnhanceYourCalm) { + CompleteCryptoHandshake(); + for (size_t i = 0u; i < session_->get_max_promises(); i++) { + push_promise_[":path"] = absl::StrCat("/bar", i); + + QuicStreamId id = + promised_stream_id_ + + i * QuicUtils::StreamIdDelta(connection_->transport_version()); + + EXPECT_TRUE( + session_->HandlePromised(associated_stream_id_, id, push_promise_)); + + // Verify that the promise is in the unclaimed streams map. + std::string promise_url( + SpdyServerPushUtils::GetPromisedUrlFromHeaders(push_promise_)); + EXPECT_NE(session_->GetPromisedByUrl(promise_url), nullptr); + EXPECT_NE(session_->GetPromisedById(id), nullptr); + } + + // One more promise, this should be refused. + int i = session_->get_max_promises(); + push_promise_[":path"] = absl::StrCat("/bar", i); + + QuicStreamId id = + promised_stream_id_ + + i * QuicUtils::StreamIdDelta(connection_->transport_version()); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(id, QUIC_REFUSED_STREAM)); + EXPECT_FALSE( + session_->HandlePromised(associated_stream_id_, id, push_promise_)); + + // Verify that the promise was not created. + std::string promise_url( + SpdyServerPushUtils::GetPromisedUrlFromHeaders(push_promise_)); + EXPECT_EQ(session_->GetPromisedById(id), nullptr); + EXPECT_EQ(session_->GetPromisedByUrl(promise_url), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, IsClosedTrueAfterResetPromisedAlreadyOpen) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->GetOrCreateStream(promised_stream_id_); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_REFUSED_STREAM)); + session_->ResetPromised(promised_stream_id_, QUIC_REFUSED_STREAM); + EXPECT_TRUE(session_->IsClosedStream(promised_stream_id_)); +} + +TEST_P(QuicSpdyClientSessionTest, IsClosedTrueAfterResetPromisedNonexistant) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_REFUSED_STREAM)); + session_->ResetPromised(promised_stream_id_, QUIC_REFUSED_STREAM); + EXPECT_TRUE(session_->IsClosedStream(promised_stream_id_)); +} + +TEST_P(QuicSpdyClientSessionTest, OnInitialHeadersCompleteIsPush) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + session_->GetOrCreateStream(promised_stream_id_); + EXPECT_TRUE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + EXPECT_NE(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_NE(session_->GetPromisedStream(promised_stream_id_), nullptr); + EXPECT_NE(session_->GetPromisedByUrl(promise_url_), nullptr); + + session_->OnInitialHeadersComplete(promised_stream_id_, Http2HeaderBlock()); +} + +TEST_P(QuicSpdyClientSessionTest, OnInitialHeadersCompleteIsNotPush) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + session_->CreateOutgoingBidirectionalStream(); + session_->OnInitialHeadersComplete(promised_stream_id_, Http2HeaderBlock()); +} + +TEST_P(QuicSpdyClientSessionTest, DeletePromised) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + session_->GetOrCreateStream(promised_stream_id_); + EXPECT_TRUE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + QuicClientPromisedInfo* promised = + session_->GetPromisedById(promised_stream_id_); + EXPECT_NE(promised, nullptr); + EXPECT_NE(session_->GetPromisedStream(promised_stream_id_), nullptr); + EXPECT_NE(session_->GetPromisedByUrl(promise_url_), nullptr); + + session_->DeletePromised(promised); + EXPECT_EQ(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_EQ(session_->GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, ResetPromised) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + session_->GetOrCreateStream(promised_stream_id_); + EXPECT_TRUE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_STREAM_PEER_GOING_AWAY)); + session_->ResetStream(promised_stream_id_, QUIC_STREAM_PEER_GOING_AWAY); + QuicClientPromisedInfo* promised = + session_->GetPromisedById(promised_stream_id_); + EXPECT_NE(promised, nullptr); + EXPECT_NE(session_->GetPromisedByUrl(promise_url_), nullptr); + EXPECT_EQ(session_->GetPromisedStream(promised_stream_id_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseInvalidMethod) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->CreateOutgoingBidirectionalStream(); + + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_INVALID_PROMISE_METHOD)); + + push_promise_[":method"] = "POST"; + EXPECT_FALSE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + + EXPECT_EQ(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_EQ(session_->GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, PushPromiseInvalidHost) { + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + + session_->CreateOutgoingBidirectionalStream(); + + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(promised_stream_id_, QUIC_INVALID_PROMISE_URL)); + + push_promise_[":authority"] = ""; + EXPECT_FALSE(session_->HandlePromised(associated_stream_id_, + promised_stream_id_, push_promise_)); + + EXPECT_EQ(session_->GetPromisedById(promised_stream_id_), nullptr); + EXPECT_EQ(session_->GetPromisedByUrl(promise_url_), nullptr); +} + +TEST_P(QuicSpdyClientSessionTest, + TryToCreateServerInitiatedBidirectionalStream) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_SERVER_INITIATED_BIDIRECTIONAL_STREAM, _, _)); + } else { + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + } + session_->GetOrCreateStream(GetNthServerInitiatedBidirectionalStreamId( + connection_->transport_version(), 0)); +} + +TEST_P(QuicSpdyClientSessionTest, TooManyPushPromises) { + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + return; + } + + // Initialize crypto before the client session will create a stream. + CompleteCryptoHandshake(); + QuicStreamId stream_id = + QuicSessionPeer::GetNextOutgoingBidirectionalStreamId(session_.get()); + QuicSessionPeer::ActivateStream( + session_.get(), std::make_unique( + stream_id, session_.get(), BIDIRECTIONAL)); + + EXPECT_CALL(*connection_, OnStreamReset(_, QUIC_REFUSED_STREAM)); + + for (size_t promise_count = 0; promise_count <= session_->get_max_promises(); + promise_count++) { + auto promise_id = GetNthServerInitiatedUnidirectionalStreamId( + connection_->transport_version(), promise_count); + auto headers = QuicHeaderList(); + headers.OnHeaderBlockStart(); + headers.OnHeader(":path", absl::StrCat("/", promise_count)); + headers.OnHeader(":authority", "www.google.com"); + headers.OnHeader(":method", "GET"); + headers.OnHeader(":scheme", "https"); + headers.OnHeaderBlockEnd(0, 0); + session_->OnPromiseHeaderList(stream_id, promise_id, 0, headers); + } +} + +// Test that upon receiving HTTP/3 SETTINGS, the settings are serialized and +// stored into client session cache. +TEST_P(QuicSpdyClientSessionTest, OnSettingsFrame) { + // This feature is HTTP/3 only + if (!VersionUsesHttp3(session_->transport_version())) { + return; + } + CompleteCryptoHandshake(); + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 2; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + settings.values[256] = 4; // unknown setting + char application_state[] = {// type (SETTINGS) + 0x04, + // length + 0x07, + // identifier (SETTINGS_QPACK_MAX_TABLE_CAPACITY) + 0x01, + // content + 0x02, + // identifier (SETTINGS_MAX_FIELD_SECTION_SIZE) + 0x06, + // content + 0x05, + // identifier (256 in variable length integer) + 0x40 + 0x01, 0x00, + // content + 0x04}; + ApplicationState expected(std::begin(application_state), + std::end(application_state)); + session_->OnSettingsFrame(settings); + EXPECT_EQ(expected, *client_session_cache_ + ->Lookup(QuicServerId(kServerHostname, kPort, false), + session_->GetClock()->WallNow(), nullptr) + ->application_state); +} + +TEST_P(QuicSpdyClientSessionTest, IetfZeroRttSetup) { + // This feature is TLS-only. + if (session_->version().UsesQuicCrypto()) { + return; + } + + CompleteFirstConnection(); + + CreateConnection(); + // Session configs should be in initial state. + if (session_->version().UsesHttp3()) { + EXPECT_EQ(0u, session_->flow_controller()->send_window_offset()); + EXPECT_EQ(std::numeric_limits::max(), + session_->max_outbound_header_list_size()); + } else { + EXPECT_EQ(kMinimumFlowControlSendWindow, + session_->flow_controller()->send_window_offset()); + } + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, session_->connection()->encryption_level()); + + // The client session should have a basic setup ready before the handshake + // succeeds. + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + session_->flow_controller()->send_window_offset()); + if (session_->version().UsesHttp3()) { + auto* id_manager = QuicSessionPeer::ietf_streamid_manager(session_.get()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + id_manager->max_outgoing_bidirectional_streams()); + EXPECT_EQ( + kDefaultMaxStreamsPerConnection + kHttp3StaticUnidirectionalStreamCount, + id_manager->max_outgoing_unidirectional_streams()); + auto* control_stream = + QuicSpdySessionPeer::GetSendControlStream(session_.get()); + EXPECT_EQ(kInitialStreamFlowControlWindowForTest, + QuicStreamPeer::SendWindowOffset(control_stream)); + EXPECT_EQ(5u, session_->max_outbound_header_list_size()); + } else { + auto* id_manager = QuicSessionPeer::GetStreamIdManager(session_.get()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + id_manager->max_open_outgoing_streams()); + } + + // Complete the handshake with a different config. + QuicConfig config = DefaultQuicConfig(); + config.SetInitialMaxStreamDataBytesUnidirectionalToSend( + kInitialStreamFlowControlWindowForTest + 1); + config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest + 1); + config.SetMaxBidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection + 1); + config.SetMaxUnidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection + 1); + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); + + EXPECT_TRUE(session_->GetCryptoStream()->IsResumption()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest + 1, + session_->flow_controller()->send_window_offset()); + if (session_->version().UsesHttp3()) { + auto* id_manager = QuicSessionPeer::ietf_streamid_manager(session_.get()); + auto* control_stream = + QuicSpdySessionPeer::GetSendControlStream(session_.get()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection + 1, + id_manager->max_outgoing_bidirectional_streams()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection + + kHttp3StaticUnidirectionalStreamCount + 1, + id_manager->max_outgoing_unidirectional_streams()); + EXPECT_EQ(kInitialStreamFlowControlWindowForTest + 1, + QuicStreamPeer::SendWindowOffset(control_stream)); + } else { + auto* id_manager = QuicSessionPeer::GetStreamIdManager(session_.get()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection + 1, + id_manager->max_open_outgoing_streams()); + } + + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + // Let the session receive a new SETTINGS frame to complete the second + // connection. + if (session_->version().UsesHttp3()) { + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 2; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + settings.values[256] = 4; // unknown setting + session_->OnSettingsFrame(settings); + } +} + +// Regression test for b/159168475 +TEST_P(QuicSpdyClientSessionTest, RetransmitDataOnZeroRttReject) { + // This feature is TLS-only. + if (session_->version().UsesQuicCrypto()) { + return; + } + + CompleteFirstConnection(); + + // Create a second connection, but disable 0-RTT on the server. + CreateConnection(); + ON_CALL(*connection_, OnCanWrite()) + .WillByDefault( + testing::Invoke(connection_, &MockQuicConnection::ReallyOnCanWrite)); + EXPECT_CALL(*connection_, OnCanWrite()).Times(0); + + QuicConfig config = DefaultQuicConfig(); + config.SetMaxUnidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection); + config.SetMaxBidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection); + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + + // Packets will be written: CHLO, HTTP/3 SETTINGS (H/3 only), and request + // data. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_INITIAL, NOT_RETRANSMISSION)); + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_ZERO_RTT, NOT_RETRANSMISSION)) + .Times(session_->version().UsesHttp3() ? 2 : 1); + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, session_->connection()->encryption_level()); + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_TRUE(stream); + stream->WriteOrBufferData("hello", true, nullptr); + + // When handshake is done, the client sends 2 packet: HANDSHAKE FINISHED, and + // coalesced retransmission of HTTP/3 SETTINGS and request data. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_HANDSHAKE, NOT_RETRANSMISSION)); + // TODO(b/158027651): change transmission type to ALL_ZERO_RTT_RETRANSMISSION. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_FORWARD_SECURE, LOSS_RETRANSMISSION)); + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); + EXPECT_TRUE(session_->GetCryptoStream()->IsResumption()); +} + +// When IETF QUIC 0-RTT is rejected, a server-sent fresh transport params is +// available. If the new transport params reduces stream/flow control limit to +// lower than what the client has already used, connection will be closed. +TEST_P(QuicSpdyClientSessionTest, ZeroRttRejectReducesStreamLimitTooMuch) { + // This feature is TLS-only. + if (session_->version().UsesQuicCrypto()) { + return; + } + + CompleteFirstConnection(); + + // Create a second connection, but disable 0-RTT on the server. + CreateConnection(); + QuicConfig config = DefaultQuicConfig(); + // Server doesn't allow any bidirectional streams. + config.SetMaxBidirectionalStreamsToSend(0); + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_TRUE(stream); + + if (session_->version().UsesHttp3()) { + EXPECT_CALL( + *connection_, + CloseConnection( + QUIC_ZERO_RTT_UNRETRANSMITTABLE, + "Server rejected 0-RTT, aborting because new bidirectional initial " + "stream limit 0 is less than current open streams: 1", + _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + } else { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INTERNAL_ERROR, + "Server rejected 0-RTT, aborting because new stream " + "limit 0 is less than current open streams: 1", + _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + } + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); + + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); +} + +TEST_P(QuicSpdyClientSessionTest, + ZeroRttRejectReducesStreamFlowControlTooMuch) { + // This feature is TLS-only. + if (session_->version().UsesQuicCrypto()) { + return; + } + + CompleteFirstConnection(); + + // Create a second connection, but disable 0-RTT on the server. + CreateConnection(); + QuicConfig config = DefaultQuicConfig(); + // Server doesn't allow any outgoing streams. + config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(2); + config.SetInitialMaxStreamDataBytesUnidirectionalToSend(1); + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_TRUE(stream); + // Let the stream write more than 1 byte of data. + stream->WriteOrBufferData("hello", true, nullptr); + + if (session_->version().UsesHttp3()) { + // Both control stream and the request stream will report errors. + // Open question: should both streams be closed with the same error code? + EXPECT_CALL(*connection_, CloseConnection(_, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_ZERO_RTT_UNRETRANSMITTABLE, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)) + .RetiresOnSaturation(); + } else { + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_ZERO_RTT_UNRETRANSMITTABLE, + "Server rejected 0-RTT, aborting because new stream max " + "data 2 for stream 3 is less than currently used: 5", + _)) + .Times(1) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + } + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); + + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); +} + +TEST_P(QuicSpdyClientSessionTest, + ZeroRttRejectReducesSessionFlowControlTooMuch) { + // This feature is TLS-only. + if (session_->version().UsesQuicCrypto()) { + return; + } + + CompleteFirstConnection(); + + // Create a second connection, but disable 0-RTT on the server. + CreateConnection(); + QuicConfig config = DefaultQuicConfig(); + // Server doesn't allow minimum data in session. + config.SetInitialSessionFlowControlWindowToSend( + kMinimumFlowControlSendWindow); + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + QuicSpdyClientStream* stream = session_->CreateOutgoingBidirectionalStream(); + ASSERT_TRUE(stream); + std::string data_to_send(kMinimumFlowControlSendWindow + 1, 'x'); + // Let the stream write some data. + stream->WriteOrBufferData(data_to_send, true, nullptr); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_ZERO_RTT_UNRETRANSMITTABLE, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); + + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); +} + +TEST_P(QuicSpdyClientSessionTest, BadSettingsInZeroRttResumption) { + if (!session_->version().UsesHttp3()) { + return; + } + + CompleteFirstConnection(); + + CreateConnection(); + CompleteCryptoHandshake(); + EXPECT_TRUE(session_->GetCryptoStream()->EarlyDataAccepted()); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + // Let the session receive a different SETTINGS frame. + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 1; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + settings.values[256] = 4; // unknown setting + session_->OnSettingsFrame(settings); +} + +TEST_P(QuicSpdyClientSessionTest, BadSettingsInZeroRttRejection) { + if (!session_->version().UsesHttp3()) { + return; + } + + CompleteFirstConnection(); + + CreateConnection(); + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + session_->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); + QuicConfig config = DefaultQuicConfig(); + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &helper_, &alarm_factory_, + connection_, crypto_stream_, AlpnForVersion(connection_->version())); + EXPECT_FALSE(session_->GetCryptoStream()->EarlyDataAccepted()); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + // Let the session receive a different SETTINGS frame. + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 2; + // setting on SETTINGS_MAX_FIELD_SECTION_SIZE is reduced. + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 4; + settings.values[256] = 4; // unknown setting + session_->OnSettingsFrame(settings); +} + +TEST_P(QuicSpdyClientSessionTest, ServerAcceptsZeroRttButOmitSetting) { + if (!session_->version().UsesHttp3()) { + return; + } + + CompleteFirstConnection(); + + CreateConnection(); + CompleteCryptoHandshake(); + EXPECT_TRUE(session_->GetMutableCryptoStream()->EarlyDataAccepted()); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + // Let the session receive a different SETTINGS frame. + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 1; + // Intentionally omit SETTINGS_MAX_FIELD_SECTION_SIZE which was previously + // sent with a non-zero value. + settings.values[256] = 4; // unknown setting + session_->OnSettingsFrame(settings); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_client_stream.cc b/quiche/quic/core/http/quic_spdy_client_stream.cc new file mode 100644 index 000000000000..1661c94bdfd5 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_stream.cc @@ -0,0 +1,227 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_client_stream.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_client_promised_info.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/spdy_protocol.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +QuicSpdyClientStream::QuicSpdyClientStream(QuicStreamId id, + QuicSpdyClientSession* session, + StreamType type) + : QuicSpdyStream(id, session, type), + content_length_(-1), + response_code_(0), + header_bytes_read_(0), + header_bytes_written_(0), + session_(session), + has_preliminary_headers_(false) {} + +QuicSpdyClientStream::QuicSpdyClientStream(PendingStream* pending, + QuicSpdyClientSession* session) + : QuicSpdyStream(pending, session), + content_length_(-1), + response_code_(0), + header_bytes_read_(0), + header_bytes_written_(0), + session_(session), + has_preliminary_headers_(false) {} + +QuicSpdyClientStream::~QuicSpdyClientStream() = default; + +bool QuicSpdyClientStream::CopyAndValidateHeaders( + const QuicHeaderList& header_list, int64_t& content_length, + spdy::Http2HeaderBlock& headers) { + return SpdyUtils::CopyAndValidateHeaders(header_list, &content_length, + &headers); +} + +bool QuicSpdyClientStream::ParseAndValidateStatusCode() { + if (!ParseHeaderStatusCode(response_headers_, &response_code_)) { + QUIC_DLOG(ERROR) << "Received invalid response code: " + << response_headers_[":status"].as_string() + << " on stream " << id(); + Reset(QUIC_BAD_APPLICATION_PAYLOAD); + return false; + } + + if (response_code_ == 101) { + // 101 "Switching Protocols" is forbidden in HTTP/3 as per the + // "HTTP Upgrade" section of draft-ietf-quic-http. + QUIC_DLOG(ERROR) << "Received forbidden 101 response code" + << " on stream " << id(); + Reset(QUIC_BAD_APPLICATION_PAYLOAD); + return false; + } + + if (response_code_ >= 100 && response_code_ < 200) { + // These are Informational 1xx headers, not the actual response headers. + QUIC_DLOG(INFO) << "Received informational response code: " + << response_headers_[":status"].as_string() << " on stream " + << id(); + set_headers_decompressed(false); + if (response_code_ == 100 && !has_preliminary_headers_) { + // This is 100 Continue, save it to enable "Expect: 100-continue". + has_preliminary_headers_ = true; + preliminary_headers_ = std::move(response_headers_); + } else { + response_headers_.clear(); + } + } + + return true; +} + +void QuicSpdyClientStream::OnInitialHeadersComplete( + bool fin, size_t frame_len, const QuicHeaderList& header_list) { + QuicSpdyStream::OnInitialHeadersComplete(fin, frame_len, header_list); + + QUICHE_DCHECK(headers_decompressed()); + header_bytes_read_ += frame_len; + if (rst_sent()) { + // QuicSpdyStream::OnInitialHeadersComplete already rejected invalid + // response header. + return; + } + + if (!CopyAndValidateHeaders(header_list, content_length_, + response_headers_)) { + QUIC_DLOG(ERROR) << "Failed to parse header list: " + << header_list.DebugString() << " on stream " << id(); + Reset(QUIC_BAD_APPLICATION_PAYLOAD); + return; + } + + if (web_transport() != nullptr) { + web_transport()->HeadersReceived(response_headers_); + if (!web_transport()->ready()) { + // The request was rejected by WebTransport, typically due to not having a + // 2xx status. The reason we're using Reset() here rather than closing + // cleanly is that even if the server attempts to send us any form of body + // with a 4xx request, we've already set up the capsule parser, and we + // don't have any way to process anything from the response body in + // question. + Reset(QUIC_STREAM_CANCELLED); + return; + } + } + + if (!ParseAndValidateStatusCode()) { + return; + } + + ConsumeHeaderList(); + QUIC_DVLOG(1) << "headers complete for stream " << id(); + + session_->OnInitialHeadersComplete(id(), response_headers_); +} + +void QuicSpdyClientStream::OnTrailingHeadersComplete( + bool fin, size_t frame_len, const QuicHeaderList& header_list) { + QuicSpdyStream::OnTrailingHeadersComplete(fin, frame_len, header_list); + MarkTrailersConsumed(); +} + +void QuicSpdyClientStream::OnPromiseHeaderList( + QuicStreamId promised_id, size_t frame_len, + const QuicHeaderList& header_list) { + header_bytes_read_ += frame_len; + int64_t content_length = -1; + Http2HeaderBlock promise_headers; + if (!SpdyUtils::CopyAndValidateHeaders(header_list, &content_length, + &promise_headers)) { + QUIC_DLOG(ERROR) << "Failed to parse promise headers: " + << header_list.DebugString(); + Reset(QUIC_BAD_APPLICATION_PAYLOAD); + return; + } + + session_->HandlePromised(id(), promised_id, promise_headers); + if (visitor() != nullptr) { + visitor()->OnPromiseHeadersComplete(promised_id, frame_len); + } +} + +void QuicSpdyClientStream::OnBodyAvailable() { + // For push streams, visitor will not be set until the rendezvous + // between server promise and client request is complete. + if (visitor() == nullptr) return; + + while (HasBytesToRead()) { + struct iovec iov; + if (GetReadableRegions(&iov, 1) == 0) { + // No more data to read. + break; + } + QUIC_DVLOG(1) << "Client processed " << iov.iov_len << " bytes for stream " + << id(); + data_.append(static_cast(iov.iov_base), iov.iov_len); + + if (content_length_ >= 0 && + data_.size() > static_cast(content_length_)) { + QUIC_DLOG(ERROR) << "Invalid content length (" << content_length_ + << ") with data of size " << data_.size(); + Reset(QUIC_BAD_APPLICATION_PAYLOAD); + return; + } + MarkConsumed(iov.iov_len); + } + if (sequencer()->IsClosed()) { + OnFinRead(); + } else { + sequencer()->SetUnblocked(); + } +} + +size_t QuicSpdyClientStream::SendRequest(Http2HeaderBlock headers, + absl::string_view body, bool fin) { + QuicConnection::ScopedPacketFlusher flusher(session_->connection()); + bool send_fin_with_headers = fin && body.empty(); + size_t bytes_sent = body.size(); + header_bytes_written_ = + WriteHeaders(std::move(headers), send_fin_with_headers, nullptr); + bytes_sent += header_bytes_written_; + + if (!body.empty()) { + WriteOrBufferBody(body, fin); + } + + return bytes_sent; +} + +bool QuicSpdyClientStream::AreHeadersValid( + const QuicHeaderList& header_list) const { + if (!GetQuicReloadableFlag(quic_verify_request_headers_2)) { + return true; + } + if (!QuicSpdyStream::AreHeadersValid(header_list)) { + return false; + } + // Verify the presence of :status header. + bool saw_status = false; + for (const std::pair& pair : header_list) { + if (pair.first == ":status") { + saw_status = true; + } else if (absl::StrContains(pair.first, ":")) { + QUIC_DLOG(ERROR) << "Unexpected ':' in header " << pair.first << "."; + return false; + } + } + return saw_status; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_client_stream.h b/quiche/quic/core/http/quic_spdy_client_stream.h new file mode 100644 index 000000000000..9a952aa23e91 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_stream.h @@ -0,0 +1,108 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_STREAM_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_STREAM_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +class QuicSpdyClientSession; + +// All this does right now is send an SPDY request, and aggregate the +// SPDY response. +class QUIC_EXPORT_PRIVATE QuicSpdyClientStream : public QuicSpdyStream { + public: + QuicSpdyClientStream(QuicStreamId id, QuicSpdyClientSession* session, + StreamType type); + QuicSpdyClientStream(PendingStream* pending, + QuicSpdyClientSession* spdy_session); + QuicSpdyClientStream(const QuicSpdyClientStream&) = delete; + QuicSpdyClientStream& operator=(const QuicSpdyClientStream&) = delete; + ~QuicSpdyClientStream() override; + + // Override the base class to parse and store headers. + void OnInitialHeadersComplete(bool fin, size_t frame_len, + const QuicHeaderList& header_list) override; + + // Override the base class to parse and store trailers. + void OnTrailingHeadersComplete(bool fin, size_t frame_len, + const QuicHeaderList& header_list) override; + + // Override the base class to handle creation of the push stream. + void OnPromiseHeaderList(QuicStreamId promised_id, size_t frame_len, + const QuicHeaderList& header_list) override; + + // QuicStream implementation called by the session when there's data for us. + void OnBodyAvailable() override; + + // Serializes the headers and body, sends it to the server, and + // returns the number of bytes sent. + size_t SendRequest(spdy::Http2HeaderBlock headers, absl::string_view body, + bool fin); + + // Returns the response data. + absl::string_view data() const { return data_; } + + // Returns whatever headers have been received for this stream. + const spdy::Http2HeaderBlock& response_headers() { return response_headers_; } + + const spdy::Http2HeaderBlock& preliminary_headers() { + return preliminary_headers_; + } + + size_t header_bytes_read() const { return header_bytes_read_; } + + size_t header_bytes_written() const { return header_bytes_written_; } + + int response_code() const { return response_code_; } + + // While the server's SetPriority shouldn't be called externally, the creator + // of client-side streams should be able to set the priority. + using QuicSpdyStream::SetPriority; + + protected: + bool AreHeadersValid(const QuicHeaderList& header_list) const override; + + // Called by OnInitialHeadersComplete to set response_header_. Returns false + // on error. + virtual bool CopyAndValidateHeaders(const QuicHeaderList& header_list, + int64_t& content_length, + spdy::Http2HeaderBlock& headers); + + // Called by OnInitialHeadersComplete to set response_code_ based on + // response_header_. Returns false on error. + virtual bool ParseAndValidateStatusCode(); + + private: + // The parsed headers received from the server. + spdy::Http2HeaderBlock response_headers_; + + // The parsed content-length, or -1 if none is specified. + int64_t content_length_; + int response_code_; + std::string data_; + size_t header_bytes_read_; + size_t header_bytes_written_; + + QuicSpdyClientSession* session_; + + // These preliminary headers are used for the 100 Continue headers + // that may arrive before the response headers when the request has + // Expect: 100-continue. + bool has_preliminary_headers_; + spdy::Http2HeaderBlock preliminary_headers_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_CLIENT_STREAM_H_ diff --git a/quiche/quic/core/http/quic_spdy_client_stream_test.cc b/quiche/quic/core/http/quic_spdy_client_stream_test.cc new file mode 100644 index 000000000000..d5b41374ebef --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_client_stream_test.cc @@ -0,0 +1,316 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_client_stream.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/simple_buffer_allocator.h" + +using spdy::Http2HeaderBlock; +using testing::_; +using testing::StrictMock; + +namespace quic { +namespace test { + +namespace { + +class MockQuicSpdyClientSession : public QuicSpdyClientSession { + public: + explicit MockQuicSpdyClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicClientPushPromiseIndex* push_promise_index) + : QuicSpdyClientSession(DefaultQuicConfig(), supported_versions, + connection, + QuicServerId("example.com", 443, false), + &crypto_config_, push_promise_index), + crypto_config_(crypto_test_utils::ProofVerifierForTesting()) {} + MockQuicSpdyClientSession(const MockQuicSpdyClientSession&) = delete; + MockQuicSpdyClientSession& operator=(const MockQuicSpdyClientSession&) = + delete; + ~MockQuicSpdyClientSession() override = default; + + MOCK_METHOD(bool, WriteControlFrame, + (const QuicFrame& frame, TransmissionType type), (override)); + + using QuicSession::ActivateStream; + + private: + QuicCryptoClientConfig crypto_config_; +}; + +class QuicSpdyClientStreamTest : public QuicTestWithParam { + public: + class StreamVisitor; + + QuicSpdyClientStreamTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_CLIENT, + SupportedVersions(GetParam()))), + session_(connection_->supported_versions(), connection_, + &push_promise_index_), + body_("hello world") { + session_.Initialize(); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + headers_[":status"] = "200"; + headers_["content-length"] = "11"; + + auto stream = std::make_unique( + GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 0), + &session_, BIDIRECTIONAL); + stream_ = stream.get(); + session_.ActivateStream(std::move(stream)); + + stream_visitor_ = std::make_unique(); + stream_->set_visitor(stream_visitor_.get()); + } + + class StreamVisitor : public QuicSpdyClientStream::Visitor { + void OnClose(QuicSpdyStream* stream) override { + QUIC_DVLOG(1) << "stream " << stream->id(); + } + }; + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + QuicClientPushPromiseIndex push_promise_index_; + + MockQuicSpdyClientSession session_; + QuicSpdyClientStream* stream_; + std::unique_ptr stream_visitor_; + Http2HeaderBlock headers_; + std::string body_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdyClientStreamTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSpdyClientStreamTest, TestReceivingIllegalResponseStatusCode) { + headers_[":status"] = "200 ok"; + + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_THAT(stream_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(QuicSpdyClientStreamTest, InvalidResponseHeader) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + auto headers = AsHeaderList(std::vector>{ + {":status", "200"}, {":path", "/foo"}}); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_THAT(stream_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(QuicSpdyClientStreamTest, MissingStatusCode) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + auto headers = AsHeaderList( + std::vector>{{"key", "value"}}); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_THAT(stream_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(QuicSpdyClientStreamTest, TestFraming) { + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + EXPECT_EQ("200", stream_->response_headers().find(":status")->second); + EXPECT_EQ(200, stream_->response_code()); + EXPECT_EQ(body_, stream_->data()); +} + +TEST_P(QuicSpdyClientStreamTest, Test100ContinueBeforeSuccessful) { + // First send 100 Continue. + headers_[":status"] = "100"; + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_EQ("100", stream_->preliminary_headers().find(":status")->second); + EXPECT_EQ(0u, stream_->response_headers().size()); + EXPECT_EQ(100, stream_->response_code()); + EXPECT_EQ("", stream_->data()); + // Then send 200 OK. + headers_[":status"] = "200"; + headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + // Make sure the 200 response got parsed correctly. + EXPECT_EQ("200", stream_->response_headers().find(":status")->second); + EXPECT_EQ(200, stream_->response_code()); + EXPECT_EQ(body_, stream_->data()); + // Make sure the 100 response is still available. + EXPECT_EQ("100", stream_->preliminary_headers().find(":status")->second); +} + +TEST_P(QuicSpdyClientStreamTest, TestUnknownInformationalBeforeSuccessful) { + // First send 199, an unknown Informational (1XX). + headers_[":status"] = "199"; + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_EQ(0u, stream_->response_headers().size()); + EXPECT_EQ(199, stream_->response_code()); + EXPECT_EQ("", stream_->data()); + // Then send 200 OK. + headers_[":status"] = "200"; + headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + // Make sure the 200 response got parsed correctly. + EXPECT_EQ("200", stream_->response_headers().find(":status")->second); + EXPECT_EQ(200, stream_->response_code()); + EXPECT_EQ(body_, stream_->data()); +} + +TEST_P(QuicSpdyClientStreamTest, TestReceiving101) { + // 101 "Switching Protocols" is forbidden in HTTP/3 as per the + // "HTTP Upgrade" section of draft-ietf-quic-http. + headers_[":status"] = "101"; + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_THAT(stream_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(QuicSpdyClientStreamTest, TestFramingOnePacket) { + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + EXPECT_EQ("200", stream_->response_headers().find(":status")->second); + EXPECT_EQ(200, stream_->response_code()); + EXPECT_EQ(body_, stream_->data()); +} + +TEST_P(QuicSpdyClientStreamTest, + QUIC_TEST_DISABLED_IN_CHROME(TestFramingExtraData)) { + std::string large_body = "hello world!!!!!!"; + + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + // The headers should parse successfully. + EXPECT_THAT(stream_->stream_error(), IsQuicStreamNoError()); + EXPECT_EQ("200", stream_->response_headers().find(":status")->second); + EXPECT_EQ(200, stream_->response_code()); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + large_body.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), large_body) + : large_body; + EXPECT_CALL(session_, WriteControlFrame(_, _)); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + + EXPECT_NE(QUIC_STREAM_NO_ERROR, stream_->stream_error()); +} + +// Test that receiving trailing headers (on the headers stream), containing a +// final offset, results in the stream being closed at that byte offset. +TEST_P(QuicSpdyClientStreamTest, ReceivingTrailers) { + // There is no kFinalOffsetHeaderKey if trailers are sent on the + // request/response stream. + if (VersionUsesHttp3(connection_->transport_version())) { + return; + } + + // Send headers as usual. + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + + // Send trailers before sending the body. Even though a FIN has been received + // the stream should not be closed, as it does not yet have all the data bytes + // promised by the final offset field. + Http2HeaderBlock trailer_block; + trailer_block["trailer key"] = "trailer value"; + trailer_block[kFinalOffsetHeaderKey] = absl::StrCat(body_.size()); + auto trailers = AsHeaderList(trailer_block); + stream_->OnStreamHeaderList(true, trailers.uncompressed_header_bytes(), + trailers); + + // Now send the body, which should close the stream as the FIN has been + // received, as well as all data. + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + EXPECT_TRUE(stream_->reading_stopped()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_server_stream_base.cc b/quiche/quic/core/http/quic_spdy_server_stream_base.cc new file mode 100644 index 000000000000..9852302a21a2 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_server_stream_base.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_server_stream_base.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +QuicSpdyServerStreamBase::QuicSpdyServerStreamBase(QuicStreamId id, + QuicSpdySession* session, + StreamType type) + : QuicSpdyStream(id, session, type) {} + +QuicSpdyServerStreamBase::QuicSpdyServerStreamBase(PendingStream* pending, + QuicSpdySession* session) + : QuicSpdyStream(pending, session) {} + +void QuicSpdyServerStreamBase::CloseWriteSide() { + if (!fin_received() && !rst_received() && sequencer()->ignore_read_data() && + !rst_sent()) { + // Early cancel the stream if it has stopped reading before receiving FIN + // or RST. + QUICHE_DCHECK(fin_sent() || !session()->connection()->connected()); + // Tell the peer to stop sending further data. + QUIC_DVLOG(1) << " Server: Send QUIC_STREAM_NO_ERROR on stream " << id(); + MaybeSendStopSending(QUIC_STREAM_NO_ERROR); + } + + QuicSpdyStream::CloseWriteSide(); +} + +void QuicSpdyServerStreamBase::StopReading() { + if (!fin_received() && !rst_received() && write_side_closed() && + !rst_sent()) { + QUICHE_DCHECK(fin_sent()); + // Tell the peer to stop sending further data. + QUIC_DVLOG(1) << " Server: Send QUIC_STREAM_NO_ERROR on stream " << id(); + MaybeSendStopSending(QUIC_STREAM_NO_ERROR); + } + QuicSpdyStream::StopReading(); +} + +bool QuicSpdyServerStreamBase::AreHeadersValid( + const QuicHeaderList& header_list) const { + if (!GetQuicReloadableFlag(quic_verify_request_headers_2)) { + return true; + } + QUIC_RELOADABLE_FLAG_COUNT_N(quic_verify_request_headers_2, 2, 3); + if (!QuicSpdyStream::AreHeadersValid(header_list)) { + return false; + } + + bool saw_connect = false; + bool saw_protocol = false; + bool saw_path = false; + bool saw_scheme = false; + bool saw_method = false; + bool saw_authority = false; + bool is_extended_connect = false; + // Check if it is missing any required headers and if there is any disallowed + // ones. + for (const std::pair& pair : header_list) { + if (pair.first == ":method") { + saw_method = true; + if (pair.second == "CONNECT") { + saw_connect = true; + if (saw_protocol) { + is_extended_connect = true; + } + } + } else if (pair.first == ":protocol") { + saw_protocol = true; + if (saw_connect) { + is_extended_connect = true; + } + } else if (pair.first == ":scheme") { + saw_scheme = true; + } else if (pair.first == ":path") { + saw_path = true; + } else if (pair.first == ":authority") { + saw_authority = true; + } else if (absl::StrContains(pair.first, ":")) { + QUIC_DLOG(ERROR) << "Unexpected ':' in header " << pair.first << "."; + return false; + } + if (is_extended_connect) { + if (!spdy_session()->allow_extended_connect()) { + QUIC_DLOG(ERROR) + << "Received extended-CONNECT request while it is disabled."; + return false; + } + } else if (saw_method && !saw_connect) { + if (saw_protocol) { + QUIC_DLOG(ERROR) << "Receive non-CONNECT request with :protocol."; + return false; + } + } + } + + if (is_extended_connect) { + if (saw_scheme && saw_path && saw_authority) { + // Saw all the required pseudo headers. + return true; + } + QUIC_DLOG(ERROR) << "Missing required pseudo headers for extended-CONNECT."; + return false; + } + // This is a vanilla CONNECT or non-CONNECT request. + if (saw_connect) { + // Check vanilla CONNECT. + if (saw_path || saw_scheme) { + QUIC_DLOG(ERROR) + << "Received invalid CONNECT request with disallowed pseudo header."; + return false; + } + return true; + } + // Check non-CONNECT request. + if (saw_method && saw_authority && saw_path && saw_scheme) { + return true; + } + QUIC_LOG(ERROR) << "Missing required pseudo headers."; + return false; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_server_stream_base.h b/quiche/quic/core/http/quic_spdy_server_stream_base.h new file mode 100644 index 000000000000..dd3423b7a044 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_server_stream_base.h @@ -0,0 +1,31 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SERVER_STREAM_BASE_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SERVER_STREAM_BASE_H_ + +#include "quiche/quic/core/http/quic_spdy_stream.h" + +namespace quic { + +class QUIC_NO_EXPORT QuicSpdyServerStreamBase : public QuicSpdyStream { + public: + QuicSpdyServerStreamBase(QuicStreamId id, QuicSpdySession* session, + StreamType type); + QuicSpdyServerStreamBase(PendingStream* pending, QuicSpdySession* session); + QuicSpdyServerStreamBase(const QuicSpdyServerStreamBase&) = delete; + QuicSpdyServerStreamBase& operator=(const QuicSpdyServerStreamBase&) = delete; + + // Override the base class to send QUIC_STREAM_NO_ERROR to the peer + // when the stream has not received all the data. + void CloseWriteSide() override; + void StopReading() override; + + protected: + bool AreHeadersValid(const QuicHeaderList& header_list) const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SERVER_STREAM_BASE_H_ diff --git a/quiche/quic/core/http/quic_spdy_server_stream_base_test.cc b/quiche/quic/core/http/quic_spdy_server_stream_base_test.cc new file mode 100644 index 000000000000..8ebf712d62e8 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_server_stream_base_test.cc @@ -0,0 +1,336 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_server_stream_base.h" + +#include "absl/memory/memory.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +using testing::_; + +namespace quic { +namespace test { +namespace { + +class TestQuicSpdyServerStream : public QuicSpdyServerStreamBase { + public: + TestQuicSpdyServerStream(QuicStreamId id, QuicSpdySession* session, + StreamType type) + : QuicSpdyServerStreamBase(id, session, type) {} + + void OnBodyAvailable() override {} +}; + +class QuicSpdyServerStreamBaseTest : public QuicTest { + protected: + QuicSpdyServerStreamBaseTest() + : session_(new MockQuicConnection(&helper_, &alarm_factory_, + Perspective::IS_SERVER)) { + session_.Initialize(); + session_.connection()->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(session_.perspective())); + stream_ = + new TestQuicSpdyServerStream(GetNthClientInitiatedBidirectionalStreamId( + session_.transport_version(), 0), + &session_, BIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream_)); + helper_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + + QuicSpdyServerStreamBase* stream_ = nullptr; + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicSpdySession session_; +}; + +TEST_F(QuicSpdyServerStreamBaseTest, + SendQuicRstStreamNoErrorWithEarlyResponse) { + stream_->StopReading(); + + if (session_.version().UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))) + .Times(1); + } else { + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_NO_ERROR), _)) + .Times(1); + } + QuicStreamPeer::SetFinSent(stream_); + stream_->CloseWriteSide(); +} + +TEST_F(QuicSpdyServerStreamBaseTest, + DoNotSendQuicRstStreamNoErrorWithRstReceived) { + EXPECT_FALSE(stream_->reading_stopped()); + + EXPECT_CALL(session_, + MaybeSendRstStreamFrame( + _, + QuicResetStreamError::FromInternal( + VersionHasIetfQuicFrames(session_.transport_version()) + ? QUIC_STREAM_CANCELLED + : QUIC_RST_ACKNOWLEDGEMENT), + _)) + .Times(1); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + if (VersionHasIetfQuicFrames(session_.transport_version())) { + // Create and inject a STOP SENDING frame to complete the close + // of the stream. This is only needed for version 99/IETF QUIC. + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED); + session_.OnStopSendingFrame(stop_sending); + } + + EXPECT_TRUE(stream_->reading_stopped()); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, AllowExtendedConnect) { + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_EQ(GetQuicReloadableFlag(quic_verify_request_headers_2) && + GetQuicReloadableFlag(quic_act_upon_invalid_header) && + !session_.allow_extended_connect(), + stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, AllowExtendedConnectProtocolFirst) { + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_EQ(GetQuicReloadableFlag(quic_verify_request_headers_2) && + GetQuicReloadableFlag(quic_act_upon_invalid_header) && + !session_.allow_extended_connect(), + stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidExtendedConnect) { + if (!session_.version().UsesHttp3()) { + return; + } + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, VanillaConnectAllowed) { + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_FALSE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidVanillaConnect) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidNonConnectWithProtocol) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "GET"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutScheme) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :scheme should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "GET"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutAuthority) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :authority should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":method", "GET"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutMethod) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :method should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutPath) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :path should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":method", "POST"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestHeader) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :path should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":method", "POST"); + header_list.OnHeader("invalid:header", "value"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, EmptyHeaders) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + spdy::Http2HeaderBlock empty_header; + quic::test::NoopQpackStreamSenderDelegate encoder_stream_sender_delegate; + NoopDecoderStreamErrorDelegate decoder_stream_error_delegate; + auto qpack_encoder = + std::make_unique(&decoder_stream_error_delegate); + qpack_encoder->set_qpack_stream_sender_delegate( + &encoder_stream_sender_delegate); + std::string payload = + qpack_encoder->EncodeHeaderList(stream_->id(), empty_header, nullptr); + std::string headers_frame_header = + quic::HttpEncoder::SerializeHeadersFrameHeader(payload.length()); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamFrame(QuicStreamFrame( + stream_->id(), true, 0, absl::StrCat(headers_frame_header, payload))); + EXPECT_TRUE(stream_->rst_sent()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_session.cc b/quiche/quic/core/http/quic_spdy_session.cc new file mode 100644 index 000000000000..e024213be78f --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_session.cc @@ -0,0 +1,1862 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_session.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/http_decoder.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/http/quic_headers_stream.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_exported_stats.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" + +using http2::Http2DecoderAdapter; +using spdy::Http2HeaderBlock; +using spdy::Http2WeightToSpdy3Priority; +using spdy::Spdy3PriorityToHttp2Weight; +using spdy::SpdyErrorCode; +using spdy::SpdyFramer; +using spdy::SpdyFramerDebugVisitorInterface; +using spdy::SpdyFramerVisitorInterface; +using spdy::SpdyFrameType; +using spdy::SpdyHeadersHandlerInterface; +using spdy::SpdyHeadersIR; +using spdy::SpdyPingId; +using spdy::SpdyPriority; +using spdy::SpdyPriorityIR; +using spdy::SpdyPushPromiseIR; +using spdy::SpdySerializedFrame; +using spdy::SpdySettingsId; +using spdy::SpdyStreamId; + +namespace quic { + +ABSL_CONST_INIT const size_t kMaxUnassociatedWebTransportStreams = 24; + +namespace { + +// Limit on HPACK encoder dynamic table size. +// Only used for Google QUIC, not IETF QUIC. +constexpr uint64_t kHpackEncoderDynamicTableSizeLimit = 16384; + +#define ENDPOINT \ + (perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") + +// Class to forward ACCEPT_CH frame to QuicSpdySession, +// and ignore every other frame. +class AlpsFrameDecoder : public HttpDecoder::Visitor { + public: + explicit AlpsFrameDecoder(QuicSpdySession* session) : session_(session) {} + ~AlpsFrameDecoder() override = default; + + // HttpDecoder::Visitor implementation. + void OnError(HttpDecoder* /*decoder*/) override {} + bool OnMaxPushIdFrame() override { + error_detail_ = "MAX_PUSH_ID frame forbidden"; + return false; + } + bool OnGoAwayFrame(const GoAwayFrame& /*frame*/) override { + error_detail_ = "GOAWAY frame forbidden"; + return false; + } + bool OnSettingsFrameStart(QuicByteCount /*header_length*/) override { + return true; + } + bool OnSettingsFrame(const SettingsFrame& frame) override { + if (settings_frame_received_via_alps_) { + error_detail_ = "multiple SETTINGS frames"; + return false; + } + + settings_frame_received_via_alps_ = true; + + error_detail_ = session_->OnSettingsFrameViaAlps(frame); + return !error_detail_; + } + bool OnDataFrameStart(QuicByteCount /*header_length*/, QuicByteCount + /*payload_length*/) override { + error_detail_ = "DATA frame forbidden"; + return false; + } + bool OnDataFramePayload(absl::string_view /*payload*/) override { + QUICHE_NOTREACHED(); + return false; + } + bool OnDataFrameEnd() override { + QUICHE_NOTREACHED(); + return false; + } + bool OnHeadersFrameStart(QuicByteCount /*header_length*/, + QuicByteCount /*payload_length*/) override { + error_detail_ = "HEADERS frame forbidden"; + return false; + } + bool OnHeadersFramePayload(absl::string_view /*payload*/) override { + QUICHE_NOTREACHED(); + return false; + } + bool OnHeadersFrameEnd() override { + QUICHE_NOTREACHED(); + return false; + } + bool OnPriorityUpdateFrameStart(QuicByteCount /*header_length*/) override { + error_detail_ = "PRIORITY_UPDATE frame forbidden"; + return false; + } + bool OnPriorityUpdateFrame(const PriorityUpdateFrame& /*frame*/) override { + QUICHE_NOTREACHED(); + return false; + } + bool OnAcceptChFrameStart(QuicByteCount /*header_length*/) override { + return true; + } + bool OnAcceptChFrame(const AcceptChFrame& frame) override { + session_->OnAcceptChFrameReceivedViaAlps(frame); + return true; + } + void OnWebTransportStreamFrameType( + QuicByteCount /*header_length*/, + WebTransportSessionId /*session_id*/) override { + QUICHE_NOTREACHED(); + } + bool OnUnknownFrameStart(uint64_t /*frame_type*/, + QuicByteCount + /*header_length*/, + QuicByteCount /*payload_length*/) override { + return true; + } + bool OnUnknownFramePayload(absl::string_view /*payload*/) override { + return true; + } + bool OnUnknownFrameEnd() override { return true; } + + const absl::optional& error_detail() const { + return error_detail_; + } + + private: + QuicSpdySession* const session_; + absl::optional error_detail_; + + // True if SETTINGS frame has been received via ALPS. + bool settings_frame_received_via_alps_ = false; +}; + +} // namespace + +// A SpdyFramerVisitor that passes HEADERS frames to the QuicSpdyStream, and +// closes the connection if any unexpected frames are received. +class QuicSpdySession::SpdyFramerVisitor + : public SpdyFramerVisitorInterface, + public SpdyFramerDebugVisitorInterface { + public: + explicit SpdyFramerVisitor(QuicSpdySession* session) : session_(session) {} + SpdyFramerVisitor(const SpdyFramerVisitor&) = delete; + SpdyFramerVisitor& operator=(const SpdyFramerVisitor&) = delete; + + SpdyHeadersHandlerInterface* OnHeaderFrameStart( + SpdyStreamId /* stream_id */) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + return &header_list_; + } + + void OnHeaderFrameEnd(SpdyStreamId /* stream_id */) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + + LogHeaderCompressionRatioHistogram( + /* using_qpack = */ false, + /* is_sent = */ false, header_list_.compressed_header_bytes(), + header_list_.uncompressed_header_bytes()); + + if (session_->IsConnected()) { + session_->OnHeaderList(header_list_); + } + header_list_.Clear(); + } + + void OnStreamFrameData(SpdyStreamId /*stream_id*/, const char* /*data*/, + size_t /*len*/) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + CloseConnection("SPDY DATA frame received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnStreamEnd(SpdyStreamId /*stream_id*/) override { + // The framer invokes OnStreamEnd after processing a frame that had the fin + // bit set. + } + + void OnStreamPadding(SpdyStreamId /*stream_id*/, size_t /*len*/) override { + CloseConnection("SPDY frame padding received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) override { + QuicErrorCode code; + switch (error) { + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_INDEX_VARINT_ERROR: + code = QUIC_HPACK_INDEX_VARINT_ERROR; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_NAME_LENGTH_VARINT_ERROR: + code = QUIC_HPACK_NAME_LENGTH_VARINT_ERROR; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_VALUE_LENGTH_VARINT_ERROR: + code = QUIC_HPACK_VALUE_LENGTH_VARINT_ERROR; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_NAME_TOO_LONG: + code = QUIC_HPACK_NAME_TOO_LONG; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_VALUE_TOO_LONG: + code = QUIC_HPACK_VALUE_TOO_LONG; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_NAME_HUFFMAN_ERROR: + code = QUIC_HPACK_NAME_HUFFMAN_ERROR; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_VALUE_HUFFMAN_ERROR: + code = QUIC_HPACK_VALUE_HUFFMAN_ERROR; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE: + code = QUIC_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_INVALID_INDEX: + code = QUIC_HPACK_INVALID_INDEX; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_INVALID_NAME_INDEX: + code = QUIC_HPACK_INVALID_NAME_INDEX; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED: + code = QUIC_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK: + code = QUIC_HPACK_INITIAL_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING: + code = QUIC_HPACK_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_TRUNCATED_BLOCK: + code = QUIC_HPACK_TRUNCATED_BLOCK; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_FRAGMENT_TOO_LONG: + code = QUIC_HPACK_FRAGMENT_TOO_LONG; + break; + case Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT: + code = QUIC_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT; + break; + case Http2DecoderAdapter::SpdyFramerError::SPDY_DECOMPRESS_FAILURE: + code = QUIC_HEADERS_STREAM_DATA_DECOMPRESS_FAILURE; + break; + default: + code = QUIC_INVALID_HEADERS_STREAM_DATA; + } + CloseConnection( + absl::StrCat("SPDY framing error: ", detailed_error, + Http2DecoderAdapter::SpdyFramerErrorToString(error)), + code); + } + + void OnDataFrameHeader(SpdyStreamId /*stream_id*/, size_t /*length*/, + bool /*fin*/) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + CloseConnection("SPDY DATA frame received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnRstStream(SpdyStreamId /*stream_id*/, + SpdyErrorCode /*error_code*/) override { + CloseConnection("SPDY RST_STREAM frame received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnSetting(SpdySettingsId id, uint32_t value) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + session_->OnSetting(id, value); + } + + void OnSettingsEnd() override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + } + + void OnPing(SpdyPingId /*unique_id*/, bool /*is_ack*/) override { + CloseConnection("SPDY PING frame received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnGoAway(SpdyStreamId /*last_accepted_stream_id*/, + SpdyErrorCode /*error_code*/) override { + CloseConnection("SPDY GOAWAY frame received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnHeaders(SpdyStreamId stream_id, size_t /*payload_length*/, + bool has_priority, int weight, + SpdyStreamId /*parent_stream_id*/, bool /*exclusive*/, + bool fin, bool /*end*/) override { + if (!session_->IsConnected()) { + return; + } + + if (VersionUsesHttp3(session_->transport_version())) { + CloseConnection("HEADERS frame not allowed on headers stream.", + QUIC_INVALID_HEADERS_STREAM_DATA); + return; + } + + QUIC_BUG_IF(quic_bug_12477_1, + session_->destruction_indicator() != 123456789) + << "QuicSpdyStream use after free. " + << session_->destruction_indicator() << QuicStackTrace(); + + SpdyPriority priority = + has_priority ? Http2WeightToSpdy3Priority(weight) : 0; + session_->OnHeaders(stream_id, has_priority, + spdy::SpdyStreamPrecedence(priority), fin); + } + + void OnWindowUpdate(SpdyStreamId /*stream_id*/, + int /*delta_window_size*/) override { + CloseConnection("SPDY WINDOW_UPDATE frame received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + } + + void OnPushPromise(SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + bool /*end*/) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + if (session_->perspective() != Perspective::IS_CLIENT) { + CloseConnection("PUSH_PROMISE not supported.", + QUIC_INVALID_HEADERS_STREAM_DATA); + return; + } + if (!session_->IsConnected()) { + return; + } + session_->OnPushPromise(stream_id, promised_stream_id); + } + + void OnContinuation(SpdyStreamId /*stream_id*/, size_t /*payload_size*/, + bool /*end*/) override {} + + void OnPriority(SpdyStreamId stream_id, SpdyStreamId /* parent_id */, + int weight, bool /* exclusive */) override { + QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); + if (!session_->IsConnected()) { + return; + } + SpdyPriority priority = Http2WeightToSpdy3Priority(weight); + session_->OnPriority(stream_id, spdy::SpdyStreamPrecedence(priority)); + } + + void OnPriorityUpdate(SpdyStreamId /*prioritized_stream_id*/, + absl::string_view /*priority_field_value*/) override {} + + bool OnUnknownFrame(SpdyStreamId /*stream_id*/, + uint8_t /*frame_type*/) override { + CloseConnection("Unknown frame type received.", + QUIC_INVALID_HEADERS_STREAM_DATA); + return false; + } + + void OnUnknownFrameStart(SpdyStreamId /*stream_id*/, size_t /*length*/, + uint8_t /*type*/, uint8_t /*flags*/) override {} + + void OnUnknownFramePayload(SpdyStreamId /*stream_id*/, + absl::string_view /*payload*/) override {} + + // SpdyFramerDebugVisitorInterface implementation + void OnSendCompressedFrame(SpdyStreamId /*stream_id*/, SpdyFrameType /*type*/, + size_t payload_len, size_t frame_len) override { + if (payload_len == 0) { + QUIC_BUG(quic_bug_10360_1) << "Zero payload length."; + return; + } + int compression_pct = 100 - (100 * frame_len) / payload_len; + QUIC_DVLOG(1) << "Net.QuicHpackCompressionPercentage: " << compression_pct; + } + + void OnReceiveCompressedFrame(SpdyStreamId /*stream_id*/, + SpdyFrameType /*type*/, + size_t frame_len) override { + if (session_->IsConnected()) { + session_->OnCompressedFrameSize(frame_len); + } + } + + void set_max_header_list_size(size_t max_header_list_size) { + header_list_.set_max_header_list_size(max_header_list_size); + } + + private: + void CloseConnection(const std::string& details, QuicErrorCode code) { + if (session_->IsConnected()) { + session_->CloseConnectionWithDetails(code, details); + } + } + + QuicSpdySession* session_; + QuicHeaderList header_list_; +}; + +Http3DebugVisitor::Http3DebugVisitor() {} + +Http3DebugVisitor::~Http3DebugVisitor() {} + +// Expected unidirectional static streams Requirement can be found at +// https://tools.ietf.org/html/draft-ietf-quic-http-22#section-6.2. +QuicSpdySession::QuicSpdySession( + QuicConnection* connection, QuicSession::Visitor* visitor, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions) + : QuicSession(connection, visitor, config, supported_versions, + /*num_expected_unidirectional_static_streams = */ + VersionUsesHttp3(connection->transport_version()) + ? static_cast( + kHttp3StaticUnidirectionalStreamCount) + : 0u, + std::make_unique(this)), + send_control_stream_(nullptr), + receive_control_stream_(nullptr), + qpack_encoder_receive_stream_(nullptr), + qpack_decoder_receive_stream_(nullptr), + qpack_encoder_send_stream_(nullptr), + qpack_decoder_send_stream_(nullptr), + qpack_maximum_dynamic_table_capacity_( + kDefaultQpackMaxDynamicTableCapacity), + qpack_maximum_blocked_streams_(kDefaultMaximumBlockedStreams), + max_inbound_header_list_size_(kDefaultMaxUncompressedHeaderSize), + max_outbound_header_list_size_(std::numeric_limits::max()), + stream_id_( + QuicUtils::GetInvalidStreamId(connection->transport_version())), + promised_stream_id_( + QuicUtils::GetInvalidStreamId(connection->transport_version())), + frame_len_(0), + fin_(false), + spdy_framer_(SpdyFramer::ENABLE_COMPRESSION), + spdy_framer_visitor_(new SpdyFramerVisitor(this)), + debug_visitor_(nullptr), + destruction_indicator_(123456789), + allow_extended_connect_( + GetQuicReloadableFlag(quic_verify_request_headers_2) && + perspective() == Perspective::IS_SERVER && + VersionUsesHttp3(transport_version())) { + h2_deframer_.set_visitor(spdy_framer_visitor_.get()); + h2_deframer_.set_debug_visitor(spdy_framer_visitor_.get()); + spdy_framer_.set_debug_visitor(spdy_framer_visitor_.get()); +} + +QuicSpdySession::~QuicSpdySession() { + QUIC_BUG_IF(quic_bug_12477_2, destruction_indicator_ != 123456789) + << "QuicSpdySession use after free. " << destruction_indicator_ + << QuicStackTrace(); + destruction_indicator_ = 987654321; +} + +void QuicSpdySession::Initialize() { + QuicSession::Initialize(); + + FillSettingsFrame(); + if (!VersionUsesHttp3(transport_version())) { + if (perspective() == Perspective::IS_SERVER) { + set_largest_peer_created_stream_id( + QuicUtils::GetHeadersStreamId(transport_version())); + } else { + QuicStreamId headers_stream_id = GetNextOutgoingBidirectionalStreamId(); + QUICHE_DCHECK_EQ(headers_stream_id, + QuicUtils::GetHeadersStreamId(transport_version())); + } + auto headers_stream = std::make_unique((this)); + QUICHE_DCHECK_EQ(QuicUtils::GetHeadersStreamId(transport_version()), + headers_stream->id()); + + headers_stream_ = headers_stream.get(); + ActivateStream(std::move(headers_stream)); + } else { + qpack_encoder_ = std::make_unique(this); + qpack_decoder_ = + std::make_unique(qpack_maximum_dynamic_table_capacity_, + qpack_maximum_blocked_streams_, this); + MaybeInitializeHttp3UnidirectionalStreams(); + } + + spdy_framer_visitor_->set_max_header_list_size(max_inbound_header_list_size_); + + // Limit HPACK buffering to 2x header list size limit. + h2_deframer_.GetHpackDecoder()->set_max_decode_buffer_size_bytes( + 2 * max_inbound_header_list_size_); +} + +void QuicSpdySession::FillSettingsFrame() { + settings_.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = + qpack_maximum_dynamic_table_capacity_; + settings_.values[SETTINGS_QPACK_BLOCKED_STREAMS] = + qpack_maximum_blocked_streams_; + settings_.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = + max_inbound_header_list_size_; + if (version().UsesHttp3()) { + switch (LocalHttpDatagramSupport()) { + case HttpDatagramSupport::kNone: + break; + case HttpDatagramSupport::kDraft04: + settings_.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + break; + case HttpDatagramSupport::kRfc: + settings_.values[SETTINGS_H3_DATAGRAM] = 1; + break; + case HttpDatagramSupport::kRfcAndDraft04: + settings_.values[SETTINGS_H3_DATAGRAM] = 1; + settings_.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + break; + } + } + if (WillNegotiateWebTransport()) { + settings_.values[SETTINGS_WEBTRANS_DRAFT00] = 1; + } + if (allow_extended_connect()) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_verify_request_headers_2, 1, 3); + settings_.values[SETTINGS_ENABLE_CONNECT_PROTOCOL] = 1; + } +} + +void QuicSpdySession::OnDecoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + CloseConnectionWithDetails( + error_code, absl::StrCat("Decoder stream error: ", error_message)); +} + +void QuicSpdySession::OnEncoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + CloseConnectionWithDetails( + error_code, absl::StrCat("Encoder stream error: ", error_message)); +} + +void QuicSpdySession::OnStreamHeadersPriority( + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence) { + QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); + if (!stream) { + // It's quite possible to receive headers after a stream has been reset. + return; + } + stream->OnStreamHeadersPriority(precedence); +} + +void QuicSpdySession::OnStreamHeaderList(QuicStreamId stream_id, bool fin, + size_t frame_len, + const QuicHeaderList& header_list) { + if (IsStaticStream(stream_id)) { + connection()->CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, "stream is static", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); + if (stream == nullptr) { + // The stream no longer exists, but trailing headers may contain the final + // byte offset necessary for flow control and open stream accounting. + size_t final_byte_offset = 0; + for (const auto& header : header_list) { + const std::string& header_key = header.first; + const std::string& header_value = header.second; + if (header_key == kFinalOffsetHeaderKey) { + if (!absl::SimpleAtoi(header_value, &final_byte_offset)) { + connection()->CloseConnection( + QUIC_INVALID_HEADERS_STREAM_DATA, + "Trailers are malformed (no final offset)", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QUIC_DVLOG(1) << ENDPOINT + << "Received final byte offset in trailers for stream " + << stream_id << ", which no longer exists."; + OnFinalByteOffsetReceived(stream_id, final_byte_offset); + } + } + + // It's quite possible to receive headers after a stream has been reset. + return; + } + stream->OnStreamHeaderList(fin, frame_len, header_list); +} + +void QuicSpdySession::OnPriorityFrame( + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence) { + QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); + if (!stream) { + // It's quite possible to receive a PRIORITY frame after a stream has been + // reset. + return; + } + stream->OnPriorityFrame(precedence); +} + +bool QuicSpdySession::OnPriorityUpdateForRequestStream( + QuicStreamId stream_id, HttpStreamPriority priority) { + if (perspective() == Perspective::IS_CLIENT || + !QuicUtils::IsBidirectionalStreamId(stream_id, version()) || + !QuicUtils::IsClientInitiatedStreamId(transport_version(), stream_id)) { + return true; + } + + QuicStreamCount advertised_max_incoming_bidirectional_streams = + GetAdvertisedMaxIncomingBidirectionalStreams(); + if (advertised_max_incoming_bidirectional_streams == 0 || + stream_id > QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(transport_version()) * + (advertised_max_incoming_bidirectional_streams - 1)) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, + "PRIORITY_UPDATE frame received for invalid stream.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + if (MaybeSetStreamPriority(stream_id, QuicStreamPriority(priority))) { + return true; + } + + if (IsClosedStream(stream_id)) { + return true; + } + + buffered_stream_priorities_[stream_id] = priority; + + if (buffered_stream_priorities_.size() > + 10 * max_open_incoming_bidirectional_streams()) { + // This should never happen, because |buffered_stream_priorities_| should + // only contain entries for streams that are allowed to be open by the peer + // but have not been opened yet. + std::string error_message = + absl::StrCat("Too many stream priority values buffered: ", + buffered_stream_priorities_.size(), + ", which should not exceed the incoming stream limit of ", + max_open_incoming_bidirectional_streams()); + QUIC_BUG(quic_bug_10360_2) << error_message; + connection()->CloseConnection( + QUIC_INTERNAL_ERROR, error_message, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + return true; +} + +size_t QuicSpdySession::ProcessHeaderData(const struct iovec& iov) { + QUIC_BUG_IF(quic_bug_12477_4, destruction_indicator_ != 123456789) + << "QuicSpdyStream use after free. " << destruction_indicator_ + << QuicStackTrace(); + return h2_deframer_.ProcessInput(static_cast(iov.iov_base), + iov.iov_len); +} + +size_t QuicSpdySession::WriteHeadersOnHeadersStream( + QuicStreamId id, Http2HeaderBlock headers, bool fin, + const spdy::SpdyStreamPrecedence& precedence, + quiche::QuicheReferenceCountedPointer + ack_listener) { + QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); + + return WriteHeadersOnHeadersStreamImpl( + id, std::move(headers), fin, + /* parent_stream_id = */ 0, + Spdy3PriorityToHttp2Weight(precedence.spdy3_priority()), + /* exclusive = */ false, std::move(ack_listener)); +} + +size_t QuicSpdySession::WritePriority(QuicStreamId stream_id, + QuicStreamId parent_stream_id, int weight, + bool exclusive) { + QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); + SpdyPriorityIR priority_frame(stream_id, parent_stream_id, weight, exclusive); + SpdySerializedFrame frame(spdy_framer_.SerializeFrame(priority_frame)); + headers_stream()->WriteOrBufferData( + absl::string_view(frame.data(), frame.size()), false, nullptr); + return frame.size(); +} + +void QuicSpdySession::WriteHttp3PriorityUpdate(QuicStreamId stream_id, + HttpStreamPriority priority) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + send_control_stream_->WritePriorityUpdate(stream_id, priority); +} + +void QuicSpdySession::OnHttp3GoAway(uint64_t id) { + QUIC_BUG_IF(quic_bug_12477_5, !version().UsesHttp3()) + << "HTTP/3 GOAWAY received on version " << version(); + + if (last_received_http3_goaway_id_.has_value() && + id > last_received_http3_goaway_id_.value()) { + CloseConnectionWithDetails( + QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS, + absl::StrCat("GOAWAY received with ID ", id, + " greater than previously received ID ", + last_received_http3_goaway_id_.value())); + return; + } + last_received_http3_goaway_id_ = id; + + if (perspective() == Perspective::IS_SERVER) { + // TODO(b/151749109): Cancel server pushes with push ID larger than |id|. + return; + } + + // QuicStreamId is uint32_t. Casting to this narrower type is well-defined + // and preserves the lower 32 bits. Both IsBidirectionalStreamId() and + // IsIncomingStream() give correct results, because their return value is + // determined by the least significant two bits. + QuicStreamId stream_id = static_cast(id); + if (!QuicUtils::IsBidirectionalStreamId(stream_id, version()) || + IsIncomingStream(stream_id)) { + CloseConnectionWithDetails(QUIC_HTTP_GOAWAY_INVALID_STREAM_ID, + "GOAWAY with invalid stream ID"); + return; + } + + // TODO(b/161252736): Cancel client requests with ID larger than |id|. + // If |id| is larger than numeric_limits::max(), then use + // max() instead of downcast value. +} + +bool QuicSpdySession::OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& frame) { + if (!QuicSession::OnStreamsBlockedFrame(frame)) { + return false; + } + + // The peer asked for stream space more than this implementation has. Send + // goaway. + if (perspective() == Perspective::IS_SERVER && + frame.stream_count >= QuicUtils::GetMaxStreamCount()) { + QUICHE_DCHECK_EQ(frame.stream_count, QuicUtils::GetMaxStreamCount()); + SendHttp3GoAway(QUIC_PEER_GOING_AWAY, "stream count too large"); + } + return true; +} + +void QuicSpdySession::SendHttp3GoAway(QuicErrorCode error_code, + const std::string& reason) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + if (!IsEncryptionEstablished()) { + QUIC_CODE_COUNT(quic_h3_goaway_before_encryption_established); + connection()->CloseConnection( + error_code, reason, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QuicStreamId stream_id; + + stream_id = QuicUtils::GetMaxClientInitiatedBidirectionalStreamId( + transport_version()); + if (last_sent_http3_goaway_id_.has_value() && + last_sent_http3_goaway_id_.value() <= stream_id) { + // Do not send GOAWAY frame with a higher id, because it is forbidden. + // Do not send one with same stream id as before, since frames on the + // control stream are guaranteed to be processed in order. + return; + } + + send_control_stream_->SendGoAway(stream_id); + last_sent_http3_goaway_id_ = stream_id; +} + +void QuicSpdySession::WritePushPromise(QuicStreamId original_stream_id, + QuicStreamId promised_stream_id, + Http2HeaderBlock headers) { + if (perspective() == Perspective::IS_CLIENT) { + QUIC_BUG(quic_bug_10360_4) << "Client shouldn't send PUSH_PROMISE"; + return; + } + + if (VersionUsesHttp3(transport_version())) { + QUIC_BUG(quic_bug_12477_6) + << "Support for server push over HTTP/3 has been removed."; + return; + } + + SpdyPushPromiseIR push_promise(original_stream_id, promised_stream_id, + std::move(headers)); + // PUSH_PROMISE must not be the last frame sent out, at least followed by + // response headers. + push_promise.set_fin(false); + + SpdySerializedFrame frame(spdy_framer_.SerializeFrame(push_promise)); + headers_stream()->WriteOrBufferData( + absl::string_view(frame.data(), frame.size()), false, nullptr); +} + +void QuicSpdySession::SendInitialData() { + if (!VersionUsesHttp3(transport_version())) { + return; + } + QuicConnection::ScopedPacketFlusher flusher(connection()); + send_control_stream_->MaybeSendSettingsFrame(); +} + +QpackEncoder* QuicSpdySession::qpack_encoder() { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + return qpack_encoder_.get(); +} + +QpackDecoder* QuicSpdySession::qpack_decoder() { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + return qpack_decoder_.get(); +} + +void QuicSpdySession::OnStreamCreated(QuicSpdyStream* stream) { + auto it = buffered_stream_priorities_.find(stream->id()); + if (it == buffered_stream_priorities_.end()) { + return; + } + + stream->SetPriority(QuicStreamPriority(it->second)); + buffered_stream_priorities_.erase(it); +} + +QuicSpdyStream* QuicSpdySession::GetOrCreateSpdyDataStream( + const QuicStreamId stream_id) { + QuicStream* stream = GetOrCreateStream(stream_id); + if (stream && stream->is_static()) { + QUIC_BUG(quic_bug_10360_5) + << "GetOrCreateSpdyDataStream returns static stream " << stream_id + << " in version " << transport_version() << "\n" + << QuicStackTrace(); + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, + absl::StrCat("stream ", stream_id, " is static"), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return nullptr; + } + return static_cast(stream); +} + +void QuicSpdySession::OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) { + QuicSession::OnNewEncryptionKeyAvailable(level, std::move(encrypter)); + if (IsEncryptionEstablished()) { + // Send H3 SETTINGs once encryption is established. + SendInitialData(); + } +} + +bool QuicSpdySession::ShouldNegotiateWebTransport() { return false; } + +bool QuicSpdySession::ShouldValidateWebTransportVersion() const { return true; } + +bool QuicSpdySession::WillNegotiateWebTransport() { + return LocalHttpDatagramSupport() != HttpDatagramSupport::kNone && + version().UsesHttp3() && ShouldNegotiateWebTransport(); +} + +// True if there are open HTTP requests. +bool QuicSpdySession::ShouldKeepConnectionAlive() const { + QUICHE_DCHECK(VersionUsesHttp3(transport_version()) || + 0u == pending_streams_size()); + return GetNumActiveStreams() + pending_streams_size() > 0; +} + +bool QuicSpdySession::UsesPendingStreamForFrame(QuicFrameType type, + QuicStreamId stream_id) const { + // Pending streams can only be used to handle unidirectional stream with + // STREAM & RESET_STREAM frames in IETF QUIC. + return VersionUsesHttp3(transport_version()) && + (type == STREAM_FRAME || type == RST_STREAM_FRAME) && + QuicUtils::GetStreamType(stream_id, perspective(), + IsIncomingStream(stream_id), + version()) == READ_UNIDIRECTIONAL; +} + +size_t QuicSpdySession::WriteHeadersOnHeadersStreamImpl( + QuicStreamId id, spdy::Http2HeaderBlock headers, bool fin, + QuicStreamId parent_stream_id, int weight, bool exclusive, + quiche::QuicheReferenceCountedPointer + ack_listener) { + QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); + + const QuicByteCount uncompressed_size = headers.TotalBytesUsed(); + SpdyHeadersIR headers_frame(id, std::move(headers)); + headers_frame.set_fin(fin); + if (perspective() == Perspective::IS_CLIENT) { + headers_frame.set_has_priority(true); + headers_frame.set_parent_stream_id(parent_stream_id); + headers_frame.set_weight(weight); + headers_frame.set_exclusive(exclusive); + } + SpdySerializedFrame frame(spdy_framer_.SerializeFrame(headers_frame)); + headers_stream()->WriteOrBufferData( + absl::string_view(frame.data(), frame.size()), false, + std::move(ack_listener)); + + // Calculate compressed header block size without framing overhead. + QuicByteCount compressed_size = frame.size(); + compressed_size -= spdy::kFrameHeaderSize; + if (perspective() == Perspective::IS_CLIENT) { + // Exclusive bit and Stream Dependency are four bytes, weight is one more. + compressed_size -= 5; + } + + LogHeaderCompressionRatioHistogram( + /* using_qpack = */ false, + /* is_sent = */ true, compressed_size, uncompressed_size); + + return frame.size(); +} + +void QuicSpdySession::OnPromiseHeaderList( + QuicStreamId /*stream_id*/, QuicStreamId /*promised_stream_id*/, + size_t /*frame_len*/, const QuicHeaderList& /*header_list*/) { + std::string error = + "OnPromiseHeaderList should be overridden in client code."; + QUIC_BUG(quic_bug_10360_6) << error; + connection()->CloseConnection(QUIC_INTERNAL_ERROR, error, + ConnectionCloseBehavior::SILENT_CLOSE); +} + +bool QuicSpdySession::ResumeApplicationState(ApplicationState* cached_state) { + QUICHE_DCHECK_EQ(perspective(), Perspective::IS_CLIENT); + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + SettingsFrame out; + if (!HttpDecoder::DecodeSettings( + reinterpret_cast(cached_state->data()), cached_state->size(), + &out)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnSettingsFrameResumed(out); + } + QUICHE_DCHECK(streams_waiting_for_settings_.empty()); + for (const auto& setting : out.values) { + OnSetting(setting.first, setting.second); + } + return true; +} + +absl::optional QuicSpdySession::OnAlpsData( + const uint8_t* alps_data, size_t alps_length) { + AlpsFrameDecoder alps_frame_decoder(this); + HttpDecoder decoder(&alps_frame_decoder); + decoder.ProcessInput(reinterpret_cast(alps_data), alps_length); + if (alps_frame_decoder.error_detail()) { + return alps_frame_decoder.error_detail(); + } + + if (decoder.error() != QUIC_NO_ERROR) { + return decoder.error_detail(); + } + + if (!decoder.AtFrameBoundary()) { + return "incomplete HTTP/3 frame"; + } + + return absl::nullopt; +} + +void QuicSpdySession::OnAcceptChFrameReceivedViaAlps( + const AcceptChFrame& frame) { + if (debug_visitor_) { + debug_visitor_->OnAcceptChFrameReceivedViaAlps(frame); + } +} + +bool QuicSpdySession::OnSettingsFrame(const SettingsFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + if (debug_visitor_ != nullptr) { + debug_visitor_->OnSettingsFrameReceived(frame); + } + for (const auto& setting : frame.values) { + if (!OnSetting(setting.first, setting.second)) { + return false; + } + } + for (QuicStreamId stream_id : streams_waiting_for_settings_) { + QUICHE_DCHECK(ShouldBufferRequestsUntilSettings()); + QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); + if (stream == nullptr) { + // The stream may no longer exist, since it is possible for a stream to + // get reset while waiting for the SETTINGS frame. + continue; + } + stream->OnDataAvailable(); + } + streams_waiting_for_settings_.clear(); + return true; +} + +absl::optional QuicSpdySession::OnSettingsFrameViaAlps( + const SettingsFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnSettingsFrameReceivedViaAlps(frame); + } + for (const auto& setting : frame.values) { + if (!OnSetting(setting.first, setting.second)) { + // Do not bother adding the setting identifier or value to the error + // message, because OnSetting() already closed the connection, therefore + // the error message will be ignored. + return "error parsing setting"; + } + } + return absl::nullopt; +} + +bool QuicSpdySession::VerifySettingIsZeroOrOne(uint64_t id, uint64_t value) { + if (value == 0 || value == 1) { + return true; + } + std::string error_details = absl::StrCat( + "Received ", + H3SettingsToString(static_cast(id)), + " with invalid value ", value); + QUIC_PEER_BUG(bad received setting) << ENDPOINT << error_details; + CloseConnectionWithDetails(QUIC_HTTP_INVALID_SETTING_VALUE, error_details); + return false; +} + +bool QuicSpdySession::OnSetting(uint64_t id, uint64_t value) { + any_settings_received_ = true; + + if (VersionUsesHttp3(transport_version())) { + // SETTINGS frame received on the control stream. + switch (id) { + case SETTINGS_QPACK_MAX_TABLE_CAPACITY: { + QUIC_DVLOG(1) + << ENDPOINT + << "SETTINGS_QPACK_MAX_TABLE_CAPACITY received with value " + << value; + // Communicate |value| to encoder, because it is used for encoding + // Required Insert Count. + if (!qpack_encoder_->SetMaximumDynamicTableCapacity(value)) { + CloseConnectionWithDetails( + was_zero_rtt_rejected() + ? QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH + : QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + absl::StrCat(was_zero_rtt_rejected() + ? "Server rejected 0-RTT, aborting because " + : "", + "Server sent an SETTINGS_QPACK_MAX_TABLE_CAPACITY: ", + value, " while current value is: ", + qpack_encoder_->MaximumDynamicTableCapacity())); + return false; + } + // However, limit the dynamic table capacity to + // |qpack_maximum_dynamic_table_capacity_|. + qpack_encoder_->SetDynamicTableCapacity( + std::min(value, qpack_maximum_dynamic_table_capacity_)); + break; + } + case SETTINGS_MAX_FIELD_SECTION_SIZE: + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_MAX_FIELD_SECTION_SIZE received with value " + << value; + if (max_outbound_header_list_size_ != + std::numeric_limits::max() && + max_outbound_header_list_size_ > value) { + CloseConnectionWithDetails( + was_zero_rtt_rejected() + ? QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH + : QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + absl::StrCat(was_zero_rtt_rejected() + ? "Server rejected 0-RTT, aborting because " + : "", + "Server sent an SETTINGS_MAX_FIELD_SECTION_SIZE: ", + value, " which reduces current value: ", + max_outbound_header_list_size_)); + return false; + } + max_outbound_header_list_size_ = value; + break; + case SETTINGS_QPACK_BLOCKED_STREAMS: { + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_QPACK_BLOCKED_STREAMS received with value " + << value; + if (!qpack_encoder_->SetMaximumBlockedStreams(value)) { + CloseConnectionWithDetails( + was_zero_rtt_rejected() + ? QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH + : QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + absl::StrCat(was_zero_rtt_rejected() + ? "Server rejected 0-RTT, aborting because " + : "", + "Server sent an SETTINGS_QPACK_BLOCKED_STREAMS: ", + value, " which reduces current value: ", + qpack_encoder_->maximum_blocked_streams())); + return false; + } + break; + } + case SETTINGS_ENABLE_CONNECT_PROTOCOL: { + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_ENABLE_CONNECT_PROTOCOL received with value " + << value; + if (!VerifySettingIsZeroOrOne(id, value)) { + return false; + } + if (perspective() == Perspective::IS_CLIENT) { + allow_extended_connect_ = value != 0; + } + break; + } + case spdy::SETTINGS_ENABLE_PUSH: + ABSL_FALLTHROUGH_INTENDED; + case spdy::SETTINGS_MAX_CONCURRENT_STREAMS: + ABSL_FALLTHROUGH_INTENDED; + case spdy::SETTINGS_INITIAL_WINDOW_SIZE: + ABSL_FALLTHROUGH_INTENDED; + case spdy::SETTINGS_MAX_FRAME_SIZE: + CloseConnectionWithDetails( + QUIC_HTTP_RECEIVE_SPDY_SETTING, + absl::StrCat("received HTTP/2 specific setting in HTTP/3 session: ", + id)); + return false; + case SETTINGS_H3_DATAGRAM_DRAFT04: { + HttpDatagramSupport local_http_datagram_support = + LocalHttpDatagramSupport(); + if (local_http_datagram_support != HttpDatagramSupport::kDraft04 && + local_http_datagram_support != + HttpDatagramSupport::kRfcAndDraft04) { + break; + } + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_H3_DATAGRAM_DRAFT04 received with value " + << value; + if (!version().UsesHttp3()) { + break; + } + if (!VerifySettingIsZeroOrOne(id, value)) { + return false; + } + if (value && http_datagram_support_ != HttpDatagramSupport::kRfc) { + // If both RFC 9297 and draft-04 are supported, we use the RFC. This + // is implemented by ignoring SETTINGS_H3_DATAGRAM_DRAFT04 when we've + // already parsed SETTINGS_H3_DATAGRAM. + http_datagram_support_ = HttpDatagramSupport::kDraft04; + } + break; + } + case SETTINGS_H3_DATAGRAM: { + HttpDatagramSupport local_http_datagram_support = + LocalHttpDatagramSupport(); + if (local_http_datagram_support != HttpDatagramSupport::kRfc && + local_http_datagram_support != + HttpDatagramSupport::kRfcAndDraft04) { + break; + } + QUIC_DVLOG(1) << ENDPOINT << "SETTINGS_H3_DATAGRAM received with value " + << value; + if (!version().UsesHttp3()) { + break; + } + if (!VerifySettingIsZeroOrOne(id, value)) { + return false; + } + if (value) { + http_datagram_support_ = HttpDatagramSupport::kRfc; + } + break; + } + case SETTINGS_WEBTRANS_DRAFT00: + if (!WillNegotiateWebTransport()) { + break; + } + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_ENABLE_WEBTRANSPORT received with value " + << value; + if (!VerifySettingIsZeroOrOne(id, value)) { + return false; + } + peer_supports_webtransport_ = (value == 1); + if (perspective() == Perspective::IS_CLIENT && value == 1) { + allow_extended_connect_ = true; + } + break; + default: + QUIC_DVLOG(1) << ENDPOINT << "Unknown setting identifier " << id + << " received with value " << value; + // Ignore unknown settings. + break; + } + return true; + } + + // SETTINGS frame received on the headers stream. + switch (id) { + case spdy::SETTINGS_HEADER_TABLE_SIZE: + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_HEADER_TABLE_SIZE received with value " + << value; + spdy_framer_.UpdateHeaderEncoderTableSize( + std::min(value, kHpackEncoderDynamicTableSizeLimit)); + break; + case spdy::SETTINGS_ENABLE_PUSH: + if (perspective() == Perspective::IS_SERVER) { + // See rfc7540, Section 6.5.2. + if (value > 1) { + QUIC_DLOG(ERROR) << ENDPOINT << "Invalid value " << value + << " received for SETTINGS_ENABLE_PUSH."; + if (IsConnected()) { + CloseConnectionWithDetails( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Invalid value for SETTINGS_ENABLE_PUSH: ", + value)); + } + return true; + } + QUIC_DVLOG(1) << ENDPOINT << "SETTINGS_ENABLE_PUSH received with value " + << value << ", ignoring."; + break; + } else { + QUIC_DLOG(ERROR) + << ENDPOINT + << "Invalid SETTINGS_ENABLE_PUSH received by client with value " + << value; + if (IsConnected()) { + CloseConnectionWithDetails( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Unsupported field of HTTP/2 SETTINGS frame: ", id)); + } + } + break; + case spdy::SETTINGS_MAX_HEADER_LIST_SIZE: + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_MAX_HEADER_LIST_SIZE received with value " + << value; + max_outbound_header_list_size_ = value; + break; + default: + QUIC_DLOG(ERROR) << ENDPOINT << "Unknown setting identifier " << id + << " received with value " << value; + if (IsConnected()) { + CloseConnectionWithDetails( + QUIC_INVALID_HEADERS_STREAM_DATA, + absl::StrCat("Unsupported field of HTTP/2 SETTINGS frame: ", id)); + } + } + return true; +} + +bool QuicSpdySession::ShouldReleaseHeadersStreamSequencerBuffer() { + return false; +} + +void QuicSpdySession::OnHeaders(SpdyStreamId stream_id, bool has_priority, + const spdy::SpdyStreamPrecedence& precedence, + bool fin) { + if (has_priority) { + if (perspective() == Perspective::IS_CLIENT) { + CloseConnectionWithDetails(QUIC_INVALID_HEADERS_STREAM_DATA, + "Server must not send priorities."); + return; + } + OnStreamHeadersPriority(stream_id, precedence); + } else { + if (perspective() == Perspective::IS_SERVER) { + CloseConnectionWithDetails(QUIC_INVALID_HEADERS_STREAM_DATA, + "Client must send priorities."); + return; + } + } + QUICHE_DCHECK_EQ(QuicUtils::GetInvalidStreamId(transport_version()), + stream_id_); + QUICHE_DCHECK_EQ(QuicUtils::GetInvalidStreamId(transport_version()), + promised_stream_id_); + stream_id_ = stream_id; + fin_ = fin; +} + +void QuicSpdySession::OnPushPromise(SpdyStreamId stream_id, + SpdyStreamId promised_stream_id) { + QUICHE_DCHECK_EQ(QuicUtils::GetInvalidStreamId(transport_version()), + stream_id_); + QUICHE_DCHECK_EQ(QuicUtils::GetInvalidStreamId(transport_version()), + promised_stream_id_); + stream_id_ = stream_id; + promised_stream_id_ = promised_stream_id; +} + +// TODO (wangyix): Why is SpdyStreamId used instead of QuicStreamId? +// This occurs in many places in this file. +void QuicSpdySession::OnPriority(SpdyStreamId stream_id, + const spdy::SpdyStreamPrecedence& precedence) { + if (perspective() == Perspective::IS_CLIENT) { + CloseConnectionWithDetails(QUIC_INVALID_HEADERS_STREAM_DATA, + "Server must not send PRIORITY frames."); + return; + } + OnPriorityFrame(stream_id, precedence); +} + +void QuicSpdySession::OnHeaderList(const QuicHeaderList& header_list) { + QUIC_DVLOG(1) << ENDPOINT << "Received header list for stream " << stream_id_ + << ": " << header_list.DebugString(); + // This code path is only executed for push promise in IETF QUIC. + if (VersionUsesHttp3(transport_version())) { + QUICHE_DCHECK(promised_stream_id_ != + QuicUtils::GetInvalidStreamId(transport_version())); + } + if (promised_stream_id_ == + QuicUtils::GetInvalidStreamId(transport_version())) { + OnStreamHeaderList(stream_id_, fin_, frame_len_, header_list); + } else { + OnPromiseHeaderList(stream_id_, promised_stream_id_, frame_len_, + header_list); + } + // Reset state for the next frame. + promised_stream_id_ = QuicUtils::GetInvalidStreamId(transport_version()); + stream_id_ = QuicUtils::GetInvalidStreamId(transport_version()); + fin_ = false; + frame_len_ = 0; +} + +void QuicSpdySession::OnCompressedFrameSize(size_t frame_len) { + frame_len_ += frame_len; +} + +void QuicSpdySession::CloseConnectionWithDetails(QuicErrorCode error, + const std::string& details) { + connection()->CloseConnection( + error, details, ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +bool QuicSpdySession::HasActiveRequestStreams() const { + return GetNumActiveStreams() + num_draining_streams() > 0; +} + +QuicStream* QuicSpdySession::ProcessPendingStream(PendingStream* pending) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + QUICHE_DCHECK(connection()->connected()); + struct iovec iov; + if (!pending->sequencer()->GetReadableRegion(&iov)) { + // The first byte hasn't been received yet. + return nullptr; + } + + QuicDataReader reader(static_cast(iov.iov_base), iov.iov_len); + uint8_t stream_type_length = reader.PeekVarInt62Length(); + uint64_t stream_type = 0; + if (!reader.ReadVarInt62(&stream_type)) { + if (pending->sequencer()->NumBytesBuffered() == + pending->sequencer()->close_offset()) { + // Stream received FIN but there are not enough bytes for stream type. + // Mark all bytes consumed in order to close stream. + pending->MarkConsumed(pending->sequencer()->close_offset()); + } + return nullptr; + } + pending->MarkConsumed(stream_type_length); + + switch (stream_type) { + case kControlStream: { // HTTP/3 control stream. + if (receive_control_stream_) { + CloseConnectionOnDuplicateHttp3UnidirectionalStreams("Control"); + return nullptr; + } + auto receive_stream = + std::make_unique(pending, this); + receive_control_stream_ = receive_stream.get(); + ActivateStream(std::move(receive_stream)); + QUIC_DVLOG(1) << ENDPOINT << "Receive Control stream is created"; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPeerControlStreamCreated( + receive_control_stream_->id()); + } + return receive_control_stream_; + } + case kServerPushStream: { // Push Stream. + CloseConnectionWithDetails(QUIC_HTTP_RECEIVE_SERVER_PUSH, + "Received server push stream"); + return nullptr; + } + case kQpackEncoderStream: { // QPACK encoder stream. + if (qpack_encoder_receive_stream_) { + CloseConnectionOnDuplicateHttp3UnidirectionalStreams("QPACK encoder"); + return nullptr; + } + auto encoder_receive = std::make_unique( + pending, this, qpack_decoder_->encoder_stream_receiver()); + qpack_encoder_receive_stream_ = encoder_receive.get(); + ActivateStream(std::move(encoder_receive)); + QUIC_DVLOG(1) << ENDPOINT << "Receive QPACK Encoder stream is created"; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPeerQpackEncoderStreamCreated( + qpack_encoder_receive_stream_->id()); + } + return qpack_encoder_receive_stream_; + } + case kQpackDecoderStream: { // QPACK decoder stream. + if (qpack_decoder_receive_stream_) { + CloseConnectionOnDuplicateHttp3UnidirectionalStreams("QPACK decoder"); + return nullptr; + } + auto decoder_receive = std::make_unique( + pending, this, qpack_encoder_->decoder_stream_receiver()); + qpack_decoder_receive_stream_ = decoder_receive.get(); + ActivateStream(std::move(decoder_receive)); + QUIC_DVLOG(1) << ENDPOINT << "Receive QPACK Decoder stream is created"; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPeerQpackDecoderStreamCreated( + qpack_decoder_receive_stream_->id()); + } + return qpack_decoder_receive_stream_; + } + case kWebTransportUnidirectionalStream: { + // Note that this checks whether WebTransport is enabled on the receiver + // side, as we may receive WebTransport streams before peer's SETTINGS are + // received. + // TODO(b/184156476): consider whether this means we should drop buffered + // streams if we don't receive indication of WebTransport support. + if (!WillNegotiateWebTransport()) { + // Treat as unknown stream type. + break; + } + QUIC_DVLOG(1) << ENDPOINT << "Created an incoming WebTransport stream " + << pending->id(); + auto stream_owned = + std::make_unique(pending, + this); + WebTransportHttp3UnidirectionalStream* stream = stream_owned.get(); + ActivateStream(std::move(stream_owned)); + return stream; + } + default: + break; + } + MaybeSendStopSendingFrame( + pending->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_STREAM_CREATION_ERROR)); + pending->StopReading(); + return nullptr; +} + +void QuicSpdySession::MaybeInitializeHttp3UnidirectionalStreams() { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + if (!send_control_stream_ && CanOpenNextOutgoingUnidirectionalStream()) { + auto send_control = std::make_unique( + GetNextOutgoingUnidirectionalStreamId(), this, settings_); + send_control_stream_ = send_control.get(); + ActivateStream(std::move(send_control)); + if (debug_visitor_) { + debug_visitor_->OnControlStreamCreated(send_control_stream_->id()); + } + } + + if (!qpack_decoder_send_stream_ && + CanOpenNextOutgoingUnidirectionalStream()) { + auto decoder_send = std::make_unique( + GetNextOutgoingUnidirectionalStreamId(), this, kQpackDecoderStream); + qpack_decoder_send_stream_ = decoder_send.get(); + ActivateStream(std::move(decoder_send)); + qpack_decoder_->set_qpack_stream_sender_delegate( + qpack_decoder_send_stream_); + if (debug_visitor_) { + debug_visitor_->OnQpackDecoderStreamCreated( + qpack_decoder_send_stream_->id()); + } + } + + if (!qpack_encoder_send_stream_ && + CanOpenNextOutgoingUnidirectionalStream()) { + auto encoder_send = std::make_unique( + GetNextOutgoingUnidirectionalStreamId(), this, kQpackEncoderStream); + qpack_encoder_send_stream_ = encoder_send.get(); + ActivateStream(std::move(encoder_send)); + qpack_encoder_->set_qpack_stream_sender_delegate( + qpack_encoder_send_stream_); + if (debug_visitor_) { + debug_visitor_->OnQpackEncoderStreamCreated( + qpack_encoder_send_stream_->id()); + } + } +} + +void QuicSpdySession::BeforeConnectionCloseSent() { + if (!VersionUsesHttp3(transport_version()) || !IsEncryptionEstablished()) { + return; + } + + QUICHE_DCHECK_EQ(perspective(), Perspective::IS_SERVER); + + QuicStreamId stream_id = + GetLargestPeerCreatedStreamId(/*unidirectional = */ false); + + if (stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { + // No client-initiated bidirectional streams received yet. + // Send 0 to let client know that all requests can be retried. + stream_id = 0; + } else { + // Tell client that streams starting with the next after the largest + // received one can be retried. + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } + if (last_sent_http3_goaway_id_.has_value() && + last_sent_http3_goaway_id_.value() <= stream_id) { + // Do not send GOAWAY frame with a higher id, because it is forbidden. + // Do not send one with same stream id as before, since frames on the + // control stream are guaranteed to be processed in order. + return; + } + + send_control_stream_->SendGoAway(stream_id); + last_sent_http3_goaway_id_ = stream_id; +} + +void QuicSpdySession::OnCanCreateNewOutgoingStream(bool unidirectional) { + if (unidirectional && VersionUsesHttp3(transport_version())) { + MaybeInitializeHttp3UnidirectionalStreams(); + } +} + +bool QuicSpdySession::goaway_received() const { + return VersionUsesHttp3(transport_version()) + ? last_received_http3_goaway_id_.has_value() + : transport_goaway_received(); +} + +bool QuicSpdySession::goaway_sent() const { + return VersionUsesHttp3(transport_version()) + ? last_sent_http3_goaway_id_.has_value() + : transport_goaway_sent(); +} + +void QuicSpdySession::CloseConnectionOnDuplicateHttp3UnidirectionalStreams( + absl::string_view type) { + QUIC_PEER_BUG(quic_peer_bug_10360_9) << absl::StrCat( + "Received a duplicate ", type, " stream: Closing connection."); + CloseConnectionWithDetails(QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM, + absl::StrCat(type, " stream is received twice.")); +} + +// static +void QuicSpdySession::LogHeaderCompressionRatioHistogram( + bool using_qpack, bool is_sent, QuicByteCount compressed, + QuicByteCount uncompressed) { + if (compressed <= 0 || uncompressed <= 0) { + return; + } + + int ratio = 100 * (compressed) / (uncompressed); + if (ratio < 1) { + ratio = 1; + } else if (ratio > 200) { + ratio = 200; + } + + // Note that when using histogram macros in Chromium, the histogram name must + // be the same across calls for any given call site. + if (using_qpack) { + if (is_sent) { + QUIC_HISTOGRAM_COUNTS("QuicSession.HeaderCompressionRatioQpackSent", + ratio, 1, 200, 200, + "Header compression ratio as percentage for sent " + "headers using QPACK."); + } else { + QUIC_HISTOGRAM_COUNTS("QuicSession.HeaderCompressionRatioQpackReceived", + ratio, 1, 200, 200, + "Header compression ratio as percentage for " + "received headers using QPACK."); + } + } else { + if (is_sent) { + QUIC_HISTOGRAM_COUNTS("QuicSession.HeaderCompressionRatioHpackSent", + ratio, 1, 200, 200, + "Header compression ratio as percentage for sent " + "headers using HPACK."); + } else { + QUIC_HISTOGRAM_COUNTS("QuicSession.HeaderCompressionRatioHpackReceived", + ratio, 1, 200, 200, + "Header compression ratio as percentage for " + "received headers using HPACK."); + } + } +} + +MessageStatus QuicSpdySession::SendHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) { + if (!SupportsH3Datagram()) { + QUIC_BUG(send http datagram too early) + << "Refusing to send HTTP Datagram before SETTINGS received"; + return MESSAGE_STATUS_INTERNAL_ERROR; + } + // Stream ID is sent divided by four as per the specification. + uint64_t stream_id_to_write = stream_id / kHttpDatagramStreamIdDivisor; + size_t slice_length = + QuicDataWriter::GetVarInt62Len(stream_id_to_write) + payload.length(); + quiche::QuicheBuffer buffer( + connection()->helper()->GetStreamSendBufferAllocator(), slice_length); + QuicDataWriter writer(slice_length, buffer.data()); + if (!writer.WriteVarInt62(stream_id_to_write)) { + QUIC_BUG(h3 datagram stream ID write fail) + << "Failed to write HTTP/3 datagram stream ID"; + return MESSAGE_STATUS_INTERNAL_ERROR; + } + if (!writer.WriteBytes(payload.data(), payload.length())) { + QUIC_BUG(h3 datagram payload write fail) + << "Failed to write HTTP/3 datagram payload"; + return MESSAGE_STATUS_INTERNAL_ERROR; + } + + quiche::QuicheMemSlice slice(std::move(buffer)); + return datagram_queue()->SendOrQueueDatagram(std::move(slice)); +} + +void QuicSpdySession::SetMaxDatagramTimeInQueueForStreamId( + QuicStreamId /*stream_id*/, QuicTime::Delta max_time_in_queue) { + // TODO(b/184598230): implement this in a way that works for multiple sessions + // on a same connection. + datagram_queue()->SetMaxTimeInQueue(max_time_in_queue); +} + +void QuicSpdySession::OnMessageReceived(absl::string_view message) { + QuicSession::OnMessageReceived(message); + if (!SupportsH3Datagram()) { + QUIC_DLOG(INFO) << "Ignoring unexpected received HTTP/3 datagram"; + return; + } + QuicDataReader reader(message); + uint64_t stream_id64; + if (!reader.ReadVarInt62(&stream_id64)) { + QUIC_DLOG(ERROR) << "Failed to parse stream ID in received HTTP/3 datagram"; + return; + } + // Stream ID is sent divided by four as per the specification. + if (stream_id64 > + std::numeric_limits::max() / kHttpDatagramStreamIdDivisor) { + CloseConnectionWithDetails( + QUIC_HTTP_FRAME_ERROR, + absl::StrCat("Received HTTP Datagram with invalid quarter stream ID ", + stream_id64)); + return; + } + stream_id64 *= kHttpDatagramStreamIdDivisor; + QuicStreamId stream_id = static_cast(stream_id64); + QuicSpdyStream* stream = + static_cast(GetActiveStream(stream_id)); + if (stream == nullptr) { + QUIC_DLOG(INFO) << "Received HTTP/3 datagram for unknown stream ID " + << stream_id; + // TODO(b/181256914) buffer HTTP/3 datagrams with unknown stream IDs for a + // short period of time in case they were reordered. + return; + } + stream->OnDatagramReceived(&reader); +} + +bool QuicSpdySession::SupportsWebTransport() { + return WillNegotiateWebTransport() && SupportsH3Datagram() && + peer_supports_webtransport_ && + (!GetQuicReloadableFlag(quic_verify_request_headers_2) || + allow_extended_connect_); +} + +bool QuicSpdySession::SupportsH3Datagram() const { + return http_datagram_support_ != HttpDatagramSupport::kNone; +} + +WebTransportHttp3* QuicSpdySession::GetWebTransportSession( + WebTransportSessionId id) { + if (!SupportsWebTransport()) { + return nullptr; + } + if (!IsValidWebTransportSessionId(id, version())) { + return nullptr; + } + QuicSpdyStream* connect_stream = GetOrCreateSpdyDataStream(id); + if (connect_stream == nullptr) { + return nullptr; + } + return connect_stream->web_transport(); +} + +bool QuicSpdySession::ShouldProcessIncomingRequests() { + if (!ShouldBufferRequestsUntilSettings()) { + return true; + } + + return any_settings_received_; +} + +void QuicSpdySession::OnStreamWaitingForClientSettings(QuicStreamId id) { + QUICHE_DCHECK(ShouldBufferRequestsUntilSettings()); + QUICHE_DCHECK(QuicUtils::IsBidirectionalStreamId(id, version())); + streams_waiting_for_settings_.insert(id); +} + +void QuicSpdySession::AssociateIncomingWebTransportStreamWithSession( + WebTransportSessionId session_id, QuicStreamId stream_id) { + if (QuicUtils::IsOutgoingStreamId(version(), stream_id, perspective())) { + QUIC_BUG(AssociateIncomingWebTransportStreamWithSession got outgoing stream) + << ENDPOINT + << "AssociateIncomingWebTransportStreamWithSession() got an outgoing " + "stream ID: " + << stream_id; + return; + } + WebTransportHttp3* session = GetWebTransportSession(session_id); + if (session != nullptr) { + QUIC_DVLOG(1) << ENDPOINT + << "Successfully associated incoming WebTransport stream " + << stream_id << " with session ID " << session_id; + + session->AssociateStream(stream_id); + return; + } + // Evict the oldest streams until we are under the limit. + while (buffered_streams_.size() >= kMaxUnassociatedWebTransportStreams) { + QUIC_DVLOG(1) << ENDPOINT << "Removing stream " + << buffered_streams_.front().stream_id + << " from buffered streams as the queue is full."; + ResetStream(buffered_streams_.front().stream_id, + QUIC_STREAM_WEBTRANSPORT_BUFFERED_STREAMS_LIMIT_EXCEEDED); + buffered_streams_.pop_front(); + } + QUIC_DVLOG(1) << ENDPOINT << "Received a WebTransport stream " << stream_id + << " for session ID " << session_id + << " but cannot associate it; buffering instead."; + buffered_streams_.push_back( + BufferedWebTransportStream{session_id, stream_id}); +} + +void QuicSpdySession::ProcessBufferedWebTransportStreamsForSession( + WebTransportHttp3* session) { + const WebTransportSessionId session_id = session->id(); + QUIC_DVLOG(1) << "Processing buffered WebTransport streams for " + << session_id; + auto it = buffered_streams_.begin(); + while (it != buffered_streams_.end()) { + if (it->session_id == session_id) { + QUIC_DVLOG(1) << "Unbuffered and associated WebTransport stream " + << it->stream_id << " with session " << it->session_id; + session->AssociateStream(it->stream_id); + it = buffered_streams_.erase(it); + } else { + it++; + } + } +} + +WebTransportHttp3UnidirectionalStream* +QuicSpdySession::CreateOutgoingUnidirectionalWebTransportStream( + WebTransportHttp3* session) { + if (!CanOpenNextOutgoingUnidirectionalStream()) { + return nullptr; + } + + QuicStreamId stream_id = GetNextOutgoingUnidirectionalStreamId(); + auto stream_owned = std::make_unique( + stream_id, this, session->id()); + WebTransportHttp3UnidirectionalStream* stream = stream_owned.get(); + ActivateStream(std::move(stream_owned)); + stream->WritePreamble(); + session->AssociateStream(stream_id); + return stream; +} + +QuicSpdyStream* QuicSpdySession::CreateOutgoingBidirectionalWebTransportStream( + WebTransportHttp3* session) { + QuicSpdyStream* stream = CreateOutgoingBidirectionalStream(); + if (stream == nullptr) { + return nullptr; + } + QuicStreamId stream_id = stream->id(); + stream->ConvertToWebTransportDataStream(session->id()); + if (stream->web_transport_stream() == nullptr) { + // An error in ConvertToWebTransportDataStream() would result in + // CONNECTION_CLOSE, thus we don't need to do anything here. + return nullptr; + } + session->AssociateStream(stream_id); + return stream; +} + +void QuicSpdySession::OnDatagramProcessed( + absl::optional /*status*/) { + // TODO(b/184598230): make this work with multiple datagram flows. +} + +void QuicSpdySession::DatagramObserver::OnDatagramProcessed( + absl::optional status) { + session_->OnDatagramProcessed(status); +} + +HttpDatagramSupport QuicSpdySession::LocalHttpDatagramSupport() { + return HttpDatagramSupport::kNone; +} + +std::string HttpDatagramSupportToString( + HttpDatagramSupport http_datagram_support) { + switch (http_datagram_support) { + case HttpDatagramSupport::kNone: + return "None"; + case HttpDatagramSupport::kDraft04: + return "Draft04"; + case HttpDatagramSupport::kRfc: + return "Rfc"; + case HttpDatagramSupport::kRfcAndDraft04: + return "RfcAndDraft04"; + } + return absl::StrCat("Unknown(", static_cast(http_datagram_support), ")"); +} + +std::ostream& operator<<(std::ostream& os, + const HttpDatagramSupport& http_datagram_support) { + os << HttpDatagramSupportToString(http_datagram_support); + return os; +} + +// Must not be called after Initialize(). +void QuicSpdySession::set_allow_extended_connect(bool allow_extended_connect) { + QUIC_BUG_IF(extended connect wrong version, + !GetQuicReloadableFlag(quic_verify_request_headers_2) || + !VersionUsesHttp3(transport_version())) + << "Try to enable/disable extended CONNECT in Google QUIC"; + QUIC_BUG_IF(extended connect on client, + !GetQuicReloadableFlag(quic_verify_request_headers_2) || + perspective() == Perspective::IS_CLIENT) + << "Enabling/disabling extended CONNECT on the client side has no effect"; + if (ShouldNegotiateWebTransport()) { + QUIC_BUG_IF(disable extended connect, !allow_extended_connect) + << "Disabling extended CONNECT with web transport enabled has no " + "effect."; + return; + } + allow_extended_connect_ = allow_extended_connect; +} + +#undef ENDPOINT // undef for jumbo builds + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_session.h b/quiche/quic/core/http/quic_spdy_session.h new file mode 100644 index 000000000000..0102787784a8 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_session.h @@ -0,0 +1,674 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SESSION_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SESSION_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/http/quic_headers_stream.h" +#include "quiche/quic/core/http/quic_receive_control_stream.h" +#include "quiche/quic/core/http/quic_send_control_stream.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/qpack/qpack_encoder.h" +#include "quiche/quic/core/qpack/qpack_receive_stream.h" +#include "quiche/quic/core/qpack/qpack_send_stream.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace test { +class QuicSpdySessionPeer; +} // namespace test + +class WebTransportHttp3UnidirectionalStream; + +QUIC_EXPORT_PRIVATE extern const size_t kMaxUnassociatedWebTransportStreams; + +class QUIC_EXPORT_PRIVATE Http3DebugVisitor { + public: + Http3DebugVisitor(); + Http3DebugVisitor(const Http3DebugVisitor&) = delete; + Http3DebugVisitor& operator=(const Http3DebugVisitor&) = delete; + + virtual ~Http3DebugVisitor(); + + // TODO(https://crbug.com/1062700): Remove default implementation of all + // methods after Chrome's QuicHttp3Logger has overrides. This is to make sure + // QUICHE merge is not blocked on having to add those overrides, they can + // happen asynchronously. + + // Creation of unidirectional streams. + + // Called when locally-initiated control stream is created. + virtual void OnControlStreamCreated(QuicStreamId /*stream_id*/) = 0; + // Called when locally-initiated QPACK encoder stream is created. + virtual void OnQpackEncoderStreamCreated(QuicStreamId /*stream_id*/) = 0; + // Called when locally-initiated QPACK decoder stream is created. + virtual void OnQpackDecoderStreamCreated(QuicStreamId /*stream_id*/) = 0; + // Called when peer's control stream type is received. + virtual void OnPeerControlStreamCreated(QuicStreamId /*stream_id*/) = 0; + // Called when peer's QPACK encoder stream type is received. + virtual void OnPeerQpackEncoderStreamCreated(QuicStreamId /*stream_id*/) = 0; + // Called when peer's QPACK decoder stream type is received. + virtual void OnPeerQpackDecoderStreamCreated(QuicStreamId /*stream_id*/) = 0; + + // Incoming HTTP/3 frames in ALPS TLS extension. + virtual void OnSettingsFrameReceivedViaAlps(const SettingsFrame& /*frame*/) {} + virtual void OnAcceptChFrameReceivedViaAlps(const AcceptChFrame& /*frame*/) {} + + // Incoming HTTP/3 frames on the control stream. + virtual void OnSettingsFrameReceived(const SettingsFrame& /*frame*/) = 0; + virtual void OnGoAwayFrameReceived(const GoAwayFrame& /*frame*/) = 0; + virtual void OnPriorityUpdateFrameReceived( + const PriorityUpdateFrame& /*frame*/) = 0; + virtual void OnAcceptChFrameReceived(const AcceptChFrame& /*frame*/) {} + + // Incoming HTTP/3 frames on request or push streams. + virtual void OnDataFrameReceived(QuicStreamId /*stream_id*/, + QuicByteCount /*payload_length*/) = 0; + virtual void OnHeadersFrameReceived( + QuicStreamId /*stream_id*/, + QuicByteCount /*compressed_headers_length*/) = 0; + virtual void OnHeadersDecoded(QuicStreamId /*stream_id*/, + QuicHeaderList /*headers*/) = 0; + + // Incoming HTTP/3 frames of unknown type on any stream. + virtual void OnUnknownFrameReceived(QuicStreamId /*stream_id*/, + uint64_t /*frame_type*/, + QuicByteCount /*payload_length*/) = 0; + + // Outgoing HTTP/3 frames on the control stream. + virtual void OnSettingsFrameSent(const SettingsFrame& /*frame*/) = 0; + virtual void OnGoAwayFrameSent(QuicStreamId /*stream_id*/) = 0; + virtual void OnPriorityUpdateFrameSent( + const PriorityUpdateFrame& /*frame*/) = 0; + + // Outgoing HTTP/3 frames on request or push streams. + virtual void OnDataFrameSent(QuicStreamId /*stream_id*/, + QuicByteCount /*payload_length*/) = 0; + virtual void OnHeadersFrameSent( + QuicStreamId /*stream_id*/, + const spdy::Http2HeaderBlock& /*header_block*/) = 0; + + // 0-RTT related events. + virtual void OnSettingsFrameResumed(const SettingsFrame& /*frame*/) = 0; +}; + +// Whether HTTP Datagrams are supported on this session and if so which version +// is currently in use. +enum class HttpDatagramSupport : uint8_t { + kNone, // HTTP Datagrams are not supported for this session. + kDraft04, + kRfc, + kRfcAndDraft04, // Only used locally for sending, we only negotiate one + // version. +}; + +QUIC_EXPORT_PRIVATE std::string HttpDatagramSupportToString( + HttpDatagramSupport http_datagram_support); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const HttpDatagramSupport& http_datagram_support); + +// A QUIC session for HTTP. +class QUIC_EXPORT_PRIVATE QuicSpdySession + : public QuicSession, + public QpackEncoder::DecoderStreamErrorDelegate, + public QpackDecoder::EncoderStreamErrorDelegate { + public: + // Does not take ownership of |connection| or |visitor|. + QuicSpdySession(QuicConnection* connection, QuicSession::Visitor* visitor, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions); + QuicSpdySession(const QuicSpdySession&) = delete; + QuicSpdySession& operator=(const QuicSpdySession&) = delete; + + ~QuicSpdySession() override; + + void Initialize() override; + + // QpackEncoder::DecoderStreamErrorDelegate implementation. + void OnDecoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) override; + + // QpackDecoder::EncoderStreamErrorDelegate implementation. + void OnEncoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) override; + + // Called by |headers_stream_| when headers with a priority have been + // received for a stream. This method will only be called for server streams. + virtual void OnStreamHeadersPriority( + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence); + + // Called by |headers_stream_| when headers have been completely received + // for a stream. |fin| will be true if the fin flag was set in the headers + // frame. + virtual void OnStreamHeaderList(QuicStreamId stream_id, bool fin, + size_t frame_len, + const QuicHeaderList& header_list); + + // Called by |headers_stream_| when push promise headers have been + // completely received. |fin| will be true if the fin flag was set + // in the headers. + virtual void OnPromiseHeaderList(QuicStreamId stream_id, + QuicStreamId promised_stream_id, + size_t frame_len, + const QuicHeaderList& header_list); + + // Called by |headers_stream_| when a PRIORITY frame has been received for a + // stream. This method will only be called for server streams. + virtual void OnPriorityFrame(QuicStreamId stream_id, + const spdy::SpdyStreamPrecedence& precedence); + + // Called when an HTTP/3 PRIORITY_UPDATE frame has been received for a request + // stream. Returns false and closes connection if |stream_id| is invalid. + bool OnPriorityUpdateForRequestStream(QuicStreamId stream_id, + HttpStreamPriority priority); + + // Called when an HTTP/3 ACCEPT_CH frame has been received. + // This method will only be called for client sessions. + virtual void OnAcceptChFrame(const AcceptChFrame& /*frame*/) {} + + // Called when an HTTP/3 frame of unknown type has been received. + virtual void OnUnknownFrameStart(QuicStreamId /*stream_id*/, + uint64_t /*frame_type*/, + QuicByteCount /*header_length*/, + QuicByteCount /*payload_length*/) {} + virtual void OnUnknownFramePayload(QuicStreamId /*stream_id*/, + absl::string_view /*payload*/) {} + + // Sends contents of |iov| to h2_deframer_, returns number of bytes processed. + size_t ProcessHeaderData(const struct iovec& iov); + + // Writes |headers| for the stream |id| to the dedicated headers stream. + // If |fin| is true, then no more data will be sent for the stream |id|. + // If provided, |ack_notifier_delegate| will be registered to be notified when + // we have seen ACKs for all packets resulting from this call. + virtual size_t WriteHeadersOnHeadersStream( + QuicStreamId id, spdy::Http2HeaderBlock headers, bool fin, + const spdy::SpdyStreamPrecedence& precedence, + quiche::QuicheReferenceCountedPointer + ack_listener); + + // Writes an HTTP/2 PRIORITY frame the to peer. Returns the size in bytes of + // the resulting PRIORITY frame. + size_t WritePriority(QuicStreamId stream_id, QuicStreamId parent_stream_id, + int weight, bool exclusive); + + // Writes an HTTP/3 PRIORITY_UPDATE frame to the peer. + void WriteHttp3PriorityUpdate(QuicStreamId stream_id, + HttpStreamPriority priority); + + // Process received HTTP/3 GOAWAY frame. When sent from server to client, + // |id| is a stream ID. When sent from client to server, |id| is a push ID. + virtual void OnHttp3GoAway(uint64_t id); + + // Send GOAWAY if the peer is blocked on the implementation max. + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override; + + // Write GOAWAY frame with maximum stream ID on the control stream. Called to + // initite graceful connection shutdown. Do not use smaller stream ID, in + // case client does not implement retry on GOAWAY. Do not send GOAWAY if one + // has already been sent. Send connection close with |error_code| and |reason| + // before encryption gets established. + void SendHttp3GoAway(QuicErrorCode error_code, const std::string& reason); + + // Write |headers| for |promised_stream_id| on |original_stream_id| in a + // PUSH_PROMISE frame to peer. + virtual void WritePushPromise(QuicStreamId original_stream_id, + QuicStreamId promised_stream_id, + spdy::Http2HeaderBlock headers); + + QpackEncoder* qpack_encoder(); + QpackDecoder* qpack_decoder(); + QuicHeadersStream* headers_stream() { return headers_stream_; } + + const QuicHeadersStream* headers_stream() const { return headers_stream_; } + + // Called when the control stream receives HTTP/3 SETTINGS. + // Returns false in case of 0-RTT if received settings are incompatible with + // cached values, true otherwise. + virtual bool OnSettingsFrame(const SettingsFrame& frame); + + // Called when an HTTP/3 SETTINGS frame is received via ALPS. + // Returns an error message if an error has occurred, or nullopt otherwise. + // May or may not close the connection on error. + absl::optional OnSettingsFrameViaAlps( + const SettingsFrame& frame); + + // Called when a setting is parsed from a SETTINGS frame received on the + // control stream or from cached application state. + // Returns true on success. + // Returns false if received setting is incompatible with cached value (in + // case of 0-RTT) or with previously received value (in case of ALPS). + // Also closes the connection on error. + bool OnSetting(uint64_t id, uint64_t value); + + // Return true if this session wants to release headers stream's buffer + // aggressively. + virtual bool ShouldReleaseHeadersStreamSequencerBuffer(); + + void CloseConnectionWithDetails(QuicErrorCode error, + const std::string& details); + + // Must not be called after Initialize(). + // TODO(bnc): Move to constructor argument. + void set_qpack_maximum_dynamic_table_capacity( + uint64_t qpack_maximum_dynamic_table_capacity) { + qpack_maximum_dynamic_table_capacity_ = + qpack_maximum_dynamic_table_capacity; + } + + // Must not be called after Initialize(). + // TODO(bnc): Move to constructor argument. + void set_qpack_maximum_blocked_streams( + uint64_t qpack_maximum_blocked_streams) { + qpack_maximum_blocked_streams_ = qpack_maximum_blocked_streams; + } + + // Should only be used by IETF QUIC server side. + // Must not be called after Initialize(). + // TODO(bnc): Move to constructor argument. + void set_max_inbound_header_list_size(size_t max_inbound_header_list_size) { + max_inbound_header_list_size_ = max_inbound_header_list_size; + } + + // Must not be called after Initialize(). + void set_allow_extended_connect(bool allow_extended_connect); + + size_t max_outbound_header_list_size() const { + return max_outbound_header_list_size_; + } + + size_t max_inbound_header_list_size() const { + return max_inbound_header_list_size_; + } + + bool allow_extended_connect() const { return allow_extended_connect_; } + + // Returns true if the session has active request streams. + bool HasActiveRequestStreams() const; + + // Called when the size of the compressed frame payload is available. + void OnCompressedFrameSize(size_t frame_len); + + // Called when a PUSH_PROMISE frame has been received. + // TODO(b/171463363): Remove. + void OnPushPromise(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId promised_stream_id); + + // Called when the complete list of headers is available. + void OnHeaderList(const QuicHeaderList& header_list); + + QuicStreamId promised_stream_id() const { return promised_stream_id_; } + + // Initialze HTTP/3 unidirectional streams if |unidirectional| is true and + // those streams are not initialized yet. + void OnCanCreateNewOutgoingStream(bool unidirectional) override; + + int32_t destruction_indicator() const { return destruction_indicator_; } + + void set_debug_visitor(Http3DebugVisitor* debug_visitor) { + debug_visitor_ = debug_visitor; + } + + Http3DebugVisitor* debug_visitor() { return debug_visitor_; } + + // When using Google QUIC, return whether a transport layer GOAWAY frame has + // been received or sent. + // When using IETF QUIC, return whether an HTTP/3 GOAWAY frame has been + // received or sent. + bool goaway_received() const; + bool goaway_sent() const; + + // Log header compression ratio histogram. + // |using_qpack| is true for QPACK, false for HPACK. + // |is_sent| is true for sent headers, false for received ones. + // Ratio is recorded as percentage. Smaller value means more efficient + // compression. Compressed size might be larger than uncompressed size, but + // recorded ratio is trunckated at 200%. + // Uncompressed size can be zero for an empty header list, and compressed size + // can be zero for an empty header list when using HPACK. (QPACK always emits + // a header block prefix of at least two bytes.) This method records nothing + // if either |compressed| or |uncompressed| is not positive. + // In order for measurements for different protocol to be comparable, the + // caller must ensure that uncompressed size is the total length of header + // names and values without any overhead. + static void LogHeaderCompressionRatioHistogram(bool using_qpack, bool is_sent, + QuicByteCount compressed, + QuicByteCount uncompressed); + + // True if any dynamic table entries have been referenced from either a sent + // or received header block. Used for stats. + bool dynamic_table_entry_referenced() const { + return (qpack_encoder_ && + qpack_encoder_->dynamic_table_entry_referenced()) || + (qpack_decoder_ && qpack_decoder_->dynamic_table_entry_referenced()); + } + + void OnStreamCreated(QuicSpdyStream* stream); + + // Decode SETTINGS from |cached_state| and apply it to the session. + bool ResumeApplicationState(ApplicationState* cached_state) override; + + absl::optional OnAlpsData(const uint8_t* alps_data, + size_t alps_length) override; + + // Called when ACCEPT_CH frame is parsed out of data received in TLS ALPS + // extension. + virtual void OnAcceptChFrameReceivedViaAlps(const AcceptChFrame& /*frame*/); + + // Whether HTTP datagrams are supported on this session and which draft is in + // use, based on received SETTINGS. + HttpDatagramSupport http_datagram_support() const { + return http_datagram_support_; + } + + // This must not be used except by QuicSpdyStream::SendHttp3Datagram. + MessageStatus SendHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload); + // This must not be used except by QuicSpdyStream::SetMaxDatagramTimeInQueue. + void SetMaxDatagramTimeInQueueForStreamId(QuicStreamId stream_id, + QuicTime::Delta max_time_in_queue); + + // Override from QuicSession to support HTTP/3 datagrams. + void OnMessageReceived(absl::string_view message) override; + + // Indicates whether the HTTP/3 session supports WebTransport. + bool SupportsWebTransport(); + + // Indicates whether both the peer and us support HTTP/3 Datagrams. + bool SupportsH3Datagram() const; + + // Indicates whether the HTTP/3 session will indicate WebTransport support to + // the peer. + bool WillNegotiateWebTransport(); + + // Returns a WebTransport session by its session ID. Returns nullptr if no + // session is associated with the given ID. + WebTransportHttp3* GetWebTransportSession(WebTransportSessionId id); + + // If true, no data on bidirectional streams will be processed by the server + // until the SETTINGS are received. Only works for HTTP/3. This is currently + // required either (1) for WebTransport because WebTransport needs settings to + // correctly parse requests or (2) when multiple versions of HTTP Datagrams + // are supported to ensure we know which one is used. The HTTP Datagram check + // will be removed once we drop support for draft04. + bool ShouldBufferRequestsUntilSettings() { + return version().UsesHttp3() && perspective() == Perspective::IS_SERVER && + (ShouldNegotiateWebTransport() || + LocalHttpDatagramSupport() == HttpDatagramSupport::kRfcAndDraft04); + } + + // Returns if the incoming bidirectional streams should process data. This is + // usually true, but in certain cases we would want to wait until the settings + // are received. + bool ShouldProcessIncomingRequests(); + + void OnStreamWaitingForClientSettings(QuicStreamId id); + + // Links the specified stream with a WebTransport session. If the session is + // not present, it is buffered until a corresponding stream is found. + void AssociateIncomingWebTransportStreamWithSession( + WebTransportSessionId session_id, QuicStreamId stream_id); + + void ProcessBufferedWebTransportStreamsForSession(WebTransportHttp3* session); + + bool CanOpenOutgoingUnidirectionalWebTransportStream( + WebTransportSessionId /*id*/) { + return CanOpenNextOutgoingUnidirectionalStream(); + } + bool CanOpenOutgoingBidirectionalWebTransportStream( + WebTransportSessionId /*id*/) { + return CanOpenNextOutgoingBidirectionalStream(); + } + + // Creates an outgoing unidirectional WebTransport stream. Returns nullptr if + // the stream cannot be created due to flow control or some other reason. + WebTransportHttp3UnidirectionalStream* + CreateOutgoingUnidirectionalWebTransportStream(WebTransportHttp3* session); + + // Creates an outgoing bidirectional WebTransport stream. Returns nullptr if + // the stream cannot be created due to flow control or some other reason. + QuicSpdyStream* CreateOutgoingBidirectionalWebTransportStream( + WebTransportHttp3* session); + + QuicSpdyStream* GetOrCreateSpdyDataStream(const QuicStreamId stream_id); + + // Indicates whether the client should check that the + // `Sec-Webtransport-Http3-Draft` header is valid. + // TODO(vasilvv): remove this once this is enabled in Chromium. + virtual bool ShouldValidateWebTransportVersion() const; + + protected: + // Override CreateIncomingStream(), CreateOutgoingBidirectionalStream() and + // CreateOutgoingUnidirectionalStream() with QuicSpdyStream return type to + // make sure that all data streams are QuicSpdyStreams. + QuicSpdyStream* CreateIncomingStream(QuicStreamId id) override = 0; + QuicSpdyStream* CreateIncomingStream(PendingStream* pending) override = 0; + virtual QuicSpdyStream* CreateOutgoingBidirectionalStream() = 0; + virtual QuicSpdyStream* CreateOutgoingUnidirectionalStream() = 0; + + // If an incoming stream can be created, return true. + virtual bool ShouldCreateIncomingStream(QuicStreamId id) = 0; + + // If an outgoing bidirectional/unidirectional stream can be created, return + // true. + virtual bool ShouldCreateOutgoingBidirectionalStream() = 0; + virtual bool ShouldCreateOutgoingUnidirectionalStream() = 0; + + // Indicates whether the underlying backend can accept and process + // WebTransport sessions over HTTP/3. + virtual bool ShouldNegotiateWebTransport(); + + // Returns true if there are open HTTP requests. + bool ShouldKeepConnectionAlive() const override; + + // Overridden to buffer incoming unidirectional streams for version 99. + bool UsesPendingStreamForFrame(QuicFrameType type, + QuicStreamId stream_id) const override; + + // Processes incoming unidirectional streams; parses the stream type, and + // creates a new stream of the corresponding type. Returns the pointer to the + // newly created stream, or nullptr if the stream type is not yet available. + QuicStream* ProcessPendingStream(PendingStream* pending) override; + + size_t WriteHeadersOnHeadersStreamImpl( + QuicStreamId id, spdy::Http2HeaderBlock headers, bool fin, + QuicStreamId parent_stream_id, int weight, bool exclusive, + quiche::QuicheReferenceCountedPointer + ack_listener); + + void OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) override; + + // Sets the maximum size of the header compression table spdy_framer_ is + // willing to use to encode header blocks. + void UpdateHeaderEncoderTableSize(uint32_t value); + + bool IsConnected() { return connection()->connected(); } + + const QuicReceiveControlStream* receive_control_stream() const { + return receive_control_stream_; + } + + const SettingsFrame& settings() const { return settings_; } + + // Initializes HTTP/3 unidirectional streams if not yet initialzed. + virtual void MaybeInitializeHttp3UnidirectionalStreams(); + + // QuicConnectionVisitorInterface method. + void BeforeConnectionCloseSent() override; + + // Called whenever a datagram is dequeued or dropped from datagram_queue(). + virtual void OnDatagramProcessed(absl::optional status); + + // Returns which version of the HTTP/3 datagram extension we should advertise + // in settings and accept remote settings for. + virtual HttpDatagramSupport LocalHttpDatagramSupport(); + + // Sends any data which should be sent at the start of a connection, including + // the initial SETTINGS frame. When using 0-RTT, this method is called twice: + // once when encryption is established, and again when 1-RTT keys are + // available. + void SendInitialData(); + + private: + friend class test::QuicSpdySessionPeer; + + class SpdyFramerVisitor; + + // Proxies OnDatagramProcessed() calls to the session. + class QUIC_EXPORT_PRIVATE DatagramObserver + : public QuicDatagramQueue::Observer { + public: + explicit DatagramObserver(QuicSpdySession* session) : session_(session) {} + void OnDatagramProcessed(absl::optional status) override; + + private: + QuicSpdySession* session_; // not owned + }; + + struct QUIC_EXPORT_PRIVATE BufferedWebTransportStream { + WebTransportSessionId session_id; + QuicStreamId stream_id; + }; + + // The following methods are called by the SimpleVisitor. + + // Called when a HEADERS frame has been received. + void OnHeaders(spdy::SpdyStreamId stream_id, bool has_priority, + const spdy::SpdyStreamPrecedence& precedence, bool fin); + + // Called when a PRIORITY frame has been received. + void OnPriority(spdy::SpdyStreamId stream_id, + const spdy::SpdyStreamPrecedence& precedence); + + void CloseConnectionOnDuplicateHttp3UnidirectionalStreams( + absl::string_view type); + + void FillSettingsFrame(); + + bool VerifySettingIsZeroOrOne(uint64_t id, uint64_t value); + + std::unique_ptr qpack_encoder_; + std::unique_ptr qpack_decoder_; + + // Pointer to the header stream in stream_map_. + QuicHeadersStream* headers_stream_; + + // HTTP/3 control streams. They are owned by QuicSession inside + // stream map, and can be accessed by those unowned pointers below. + QuicSendControlStream* send_control_stream_; + QuicReceiveControlStream* receive_control_stream_; + + // Pointers to HTTP/3 QPACK streams in stream map. + QpackReceiveStream* qpack_encoder_receive_stream_; + QpackReceiveStream* qpack_decoder_receive_stream_; + QpackSendStream* qpack_encoder_send_stream_; + QpackSendStream* qpack_decoder_send_stream_; + + SettingsFrame settings_; + + // Maximum dynamic table capacity as defined at + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#maximum-dynamic-table-capacity + // for the decoding context. Value will be sent via + // SETTINGS_QPACK_MAX_TABLE_CAPACITY. + // |qpack_maximum_dynamic_table_capacity_| also serves as an upper bound for + // the dynamic table capacity of the encoding context, to limit memory usage + // if a larger SETTINGS_QPACK_MAX_TABLE_CAPACITY value is received. + uint64_t qpack_maximum_dynamic_table_capacity_; + + // Maximum number of blocked streams as defined at + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#blocked-streams + // for the decoding context. Value will be sent via + // SETTINGS_QPACK_BLOCKED_STREAMS. + uint64_t qpack_maximum_blocked_streams_; + + // The maximum size of a header block that will be accepted from the peer, + // defined per spec as key + value + overhead per field (uncompressed). + // Value will be sent via SETTINGS_MAX_HEADER_LIST_SIZE. + size_t max_inbound_header_list_size_; + + // The maximum size of a header block that can be sent to the peer. This field + // is informed and set by the peer via SETTINGS frame. + // TODO(b/148616439): Honor this field when sending headers. + size_t max_outbound_header_list_size_; + + // Data about the stream whose headers are being processed. + QuicStreamId stream_id_; + QuicStreamId promised_stream_id_; + size_t frame_len_; + bool fin_; + + spdy::SpdyFramer spdy_framer_; + http2::Http2DecoderAdapter h2_deframer_; + std::unique_ptr spdy_framer_visitor_; + + // Not owned by the session. + Http3DebugVisitor* debug_visitor_; + + // Priority values received in PRIORITY_UPDATE frames for streams that are not + // open yet. + absl::flat_hash_map + buffered_stream_priorities_; + + // An integer used for live check. The indicator is assigned a value in + // constructor. As long as it is not the assigned value, that would indicate + // an use-after-free. + int32_t destruction_indicator_; + + // The identifier in the most recently received GOAWAY frame. Unset if no + // GOAWAY frame has been received yet. + absl::optional last_received_http3_goaway_id_; + // The identifier in the most recently sent GOAWAY frame. Unset if no GOAWAY + // frame has been sent yet. + absl::optional last_sent_http3_goaway_id_; + + // Whether both this endpoint and our peer support HTTP datagrams and which + // draft is in use for this session. + HttpDatagramSupport http_datagram_support_ = HttpDatagramSupport::kNone; + + // Whether the peer has indicated WebTransport support. + bool peer_supports_webtransport_ = false; + + // Whether any settings have been received, either from the peer or from a + // session ticket. + bool any_settings_received_ = false; + + // If ShouldBufferRequestsUntilSettings() is true, all streams that are + // blocked by that are tracked here. + absl::flat_hash_set streams_waiting_for_settings_; + + // WebTransport streams that do not have a session associated with them. + // Limited to kMaxUnassociatedWebTransportStreams; when the list is full, + // oldest streams are evicated first. + std::list buffered_streams_; + + // On the server side, if true, advertise and accept extended CONNECT method. + // On the client side, true if the peer advertised extended CONNECT. + bool allow_extended_connect_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SESSION_H_ diff --git a/quiche/quic/core/http/quic_spdy_session_test.cc b/quiche/quic/core/http/quic_spdy_session_test.cc new file mode 100644 index 000000000000..966b23191451 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_session_test.cc @@ -0,0 +1,3785 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_session.h" + +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/frames/quic_streams_blocked_frame.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_encoder_peer.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_send_buffer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/common/test_tools/quiche_test_utils.h" +#include "quiche/spdy/core/spdy_framer.h" + +using spdy::Http2HeaderBlock; +using spdy::kV3HighestPriority; +using spdy::Spdy3PriorityToHttp2Weight; +using spdy::SpdyFramer; +using spdy::SpdyPriority; +using spdy::SpdyPriorityIR; +using spdy::SpdySerializedFrame; +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::AtLeast; +using ::testing::ElementsAre; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +bool VerifyAndClearStopSendingFrame(const QuicFrame& frame) { + EXPECT_EQ(STOP_SENDING_FRAME, frame.type); + return ClearControlFrame(frame); +} + +class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { + public: + explicit TestCryptoStream(QuicSession* session) + : QuicCryptoStream(session), + QuicCryptoHandshaker(this, session), + encryption_established_(false), + one_rtt_keys_available_(false), + params_(new QuicCryptoNegotiatedParameters) { + // Simulate a negotiated cipher_suite with a fake value. + params_->cipher_suite = 1; + } + + void EstablishZeroRttEncryption() { + encryption_established_ = true; + session()->connection()->SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + } + + void OnHandshakeMessage(const CryptoHandshakeMessage& /*message*/) override { + encryption_established_ = true; + one_rtt_keys_available_ = true; + QuicErrorCode error; + std::string error_details; + session()->config()->SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + session()->config()->SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + if (session()->version().UsesTls()) { + if (session()->perspective() == Perspective::IS_CLIENT) { + session()->config()->SetOriginalConnectionIdToSend( + session()->connection()->connection_id()); + session()->config()->SetInitialSourceConnectionIdToSend( + session()->connection()->connection_id()); + } else { + session()->config()->SetInitialSourceConnectionIdToSend( + session()->connection()->client_connection_id()); + } + TransportParameters transport_parameters; + EXPECT_TRUE( + session()->config()->FillTransportParameters(&transport_parameters)); + error = session()->config()->ProcessTransportParameters( + transport_parameters, /* is_resumption = */ false, &error_details); + } else { + CryptoHandshakeMessage msg; + session()->config()->ToHandshakeMessage(&msg, transport_version()); + error = + session()->config()->ProcessPeerHello(msg, CLIENT, &error_details); + } + EXPECT_THAT(error, IsQuicNoError()); + session()->OnNewEncryptionKeyAvailable( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + session()->OnConfigNegotiated(); + if (session()->connection()->version().handshake_protocol == + PROTOCOL_TLS1_3) { + session()->OnTlsHandshakeComplete(); + } else { + session()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + session()->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); + } + + // QuicCryptoStream implementation + ssl_early_data_reason_t EarlyDataReason() const override { + return ssl_early_data_unknown; + } + bool encryption_established() const override { + return encryption_established_; + } + bool one_rtt_keys_available() const override { + return one_rtt_keys_available_; + } + HandshakeState GetHandshakeState() const override { + return one_rtt_keys_available() ? HANDSHAKE_COMPLETE : HANDSHAKE_START; + } + void SetServerApplicationStateForResumption( + std::unique_ptr /*application_state*/) override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override { + return *params_; + } + CryptoMessageParser* crypto_message_parser() override { + return QuicCryptoHandshaker::crypto_message_parser(); + } + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnHandshakeDoneReceived() override {} + void OnNewTokenReceived(absl::string_view /*token*/) override {} + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_params*/) const override { + return ""; + } + bool ValidateAddressToken(absl::string_view /*token*/) const override { + return true; + } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} + + MOCK_METHOD(void, OnCanWrite, (), (override)); + + bool HasPendingCryptoRetransmission() const override { return false; } + + MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); + + void OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) override {} + SSL* GetSsl() const override { return nullptr; } + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override { + return level != ENCRYPTION_ZERO_RTT; + } + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } + } + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, std::string* + /*result*/) override { + return false; + } + + private: + using QuicCryptoStream::session; + + bool encryption_established_; + bool one_rtt_keys_available_; + quiche::QuicheReferenceCountedPointer params_; +}; + +class TestHeadersStream : public QuicHeadersStream { + public: + explicit TestHeadersStream(QuicSpdySession* session) + : QuicHeadersStream(session) {} + + MOCK_METHOD(void, OnCanWrite, (), (override)); +}; + +class TestStream : public QuicSpdyStream { + public: + TestStream(QuicStreamId id, QuicSpdySession* session, StreamType type) + : QuicSpdyStream(id, session, type) {} + + TestStream(PendingStream* pending, QuicSpdySession* session) + : QuicSpdyStream(pending, session) {} + + using QuicStream::CloseWriteSide; + + void OnBodyAvailable() override {} + + MOCK_METHOD(void, OnCanWrite, (), (override)); + MOCK_METHOD(bool, RetransmitStreamData, + (QuicStreamOffset, QuicByteCount, bool, TransmissionType), + (override)); + + MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); + + protected: + bool AreHeadersValid(const QuicHeaderList& /*header_list*/) const override { + return true; + } +}; + +class TestSession : public QuicSpdySession { + public: + explicit TestSession(QuicConnection* connection) + : QuicSpdySession(connection, nullptr, DefaultQuicConfig(), + CurrentSupportedVersions()), + crypto_stream_(this), + writev_consumes_all_data_(false) { + this->connection()->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + if (this->connection()->version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(this->connection()); + } + } + + ~TestSession() override { DeleteConnection(); } + + TestCryptoStream* GetMutableCryptoStream() override { + return &crypto_stream_; + } + + const TestCryptoStream* GetCryptoStream() const override { + return &crypto_stream_; + } + + TestStream* CreateOutgoingBidirectionalStream() override { + TestStream* stream = new TestStream(GetNextOutgoingBidirectionalStreamId(), + this, BIDIRECTIONAL); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + TestStream* CreateOutgoingUnidirectionalStream() override { + TestStream* stream = new TestStream(GetNextOutgoingUnidirectionalStreamId(), + this, WRITE_UNIDIRECTIONAL); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + TestStream* CreateIncomingStream(QuicStreamId id) override { + // Enforce the limit on the number of open streams. + if (!VersionHasIetfQuicFrames(connection()->transport_version()) && + stream_id_manager().num_open_incoming_streams() + 1 > + max_open_incoming_bidirectional_streams()) { + connection()->CloseConnection( + QUIC_TOO_MANY_OPEN_STREAMS, "Too many streams!", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return nullptr; + } else { + TestStream* stream = new TestStream( + id, this, + DetermineStreamType(id, connection()->version(), perspective(), + /*is_incoming=*/true, BIDIRECTIONAL)); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + } + + TestStream* CreateIncomingStream(PendingStream* pending) override { + TestStream* stream = new TestStream(pending, this); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + bool ShouldCreateIncomingStream(QuicStreamId /*id*/) override { return true; } + + bool ShouldCreateOutgoingBidirectionalStream() override { return true; } + bool ShouldCreateOutgoingUnidirectionalStream() override { return true; } + + bool IsClosedStream(QuicStreamId id) { + return QuicSession::IsClosedStream(id); + } + + QuicStream* GetOrCreateStream(QuicStreamId stream_id) { + return QuicSpdySession::GetOrCreateStream(stream_id); + } + + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, + TransmissionType type, + EncryptionLevel level) override { + bool fin = state != NO_FIN; + QuicConsumedData consumed(write_length, fin); + if (!writev_consumes_all_data_) { + consumed = + QuicSession::WritevData(id, write_length, offset, state, type, level); + } + QuicSessionPeer::GetWriteBlockedStreams(this)->UpdateBytesForStream( + id, consumed.bytes_consumed); + return consumed; + } + + void set_writev_consumes_all_data(bool val) { + writev_consumes_all_data_ = val; + } + + QuicConsumedData SendStreamData(QuicStream* stream) { + if (!QuicUtils::IsCryptoStreamId(connection()->transport_version(), + stream->id()) && + connection()->encryption_level() != ENCRYPTION_FORWARD_SECURE) { + this->connection()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + QuicStreamPeer::SendBuffer(stream).SaveStreamData("not empty"); + QuicConsumedData consumed = + WritevData(stream->id(), 9, 0, FIN, NOT_RETRANSMISSION, + GetEncryptionLevelToSendApplicationData()); + QuicStreamPeer::SendBuffer(stream).OnStreamDataConsumed( + consumed.bytes_consumed); + return consumed; + } + + QuicConsumedData SendLargeFakeData(QuicStream* stream, int bytes) { + QUICHE_DCHECK(writev_consumes_all_data_); + return WritevData(stream->id(), bytes, 0, FIN, NOT_RETRANSMISSION, + GetEncryptionLevelToSendApplicationData()); + } + + bool ShouldNegotiateWebTransport() override { return supports_webtransport_; } + void set_supports_webtransport(bool value) { supports_webtransport_ = value; } + + HttpDatagramSupport LocalHttpDatagramSupport() override { + return local_http_datagram_support_; + } + void set_local_http_datagram_support(HttpDatagramSupport value) { + local_http_datagram_support_ = value; + } + + MOCK_METHOD(void, OnAcceptChFrame, (const AcceptChFrame&), (override)); + + using QuicSession::closed_streams; + using QuicSession::ShouldKeepConnectionAlive; + using QuicSpdySession::ProcessPendingStream; + using QuicSpdySession::UsesPendingStreamForFrame; + + private: + StrictMock crypto_stream_; + + bool writev_consumes_all_data_; + bool supports_webtransport_ = false; + HttpDatagramSupport local_http_datagram_support_ = HttpDatagramSupport::kNone; +}; + +class QuicSpdySessionTestBase : public QuicTestWithParam { + public: + bool ClearMaxStreamsControlFrame(const QuicFrame& frame) { + if (frame.type == MAX_STREAMS_FRAME) { + DeleteFrame(&const_cast(frame)); + return true; + } + return false; + } + + protected: + explicit QuicSpdySessionTestBase(Perspective perspective, + bool allow_extended_connect) + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective, + SupportedVersions(GetParam()))), + session_(connection_) { + if (perspective == Perspective::IS_SERVER && + VersionUsesHttp3(transport_version()) && + GetQuicReloadableFlag(quic_verify_request_headers_2)) { + session_.set_allow_extended_connect(allow_extended_connect); + } + session_.Initialize(); + session_.config()->SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + session_.config()->SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + if (VersionUsesHttp3(transport_version())) { + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams( + session_.config(), kHttp3StaticUnidirectionalStreamCount); + } + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + session_.OnConfigNegotiated(); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()) + .Times(testing::AnyNumber()); + writer_ = static_cast( + QuicConnectionPeer::GetWriter(session_.connection())); + } + + void CheckClosedStreams() { + QuicStreamId first_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + if (!QuicVersionUsesCryptoFrames(transport_version())) { + first_stream_id = QuicUtils::GetCryptoStreamId(transport_version()); + } + for (QuicStreamId i = first_stream_id; i < 100; i++) { + if (closed_streams_.find(i) == closed_streams_.end()) { + EXPECT_FALSE(session_.IsClosedStream(i)) << " stream id: " << i; + } else { + EXPECT_TRUE(session_.IsClosedStream(i)) << " stream id: " << i; + } + } + } + + void CloseStream(QuicStreamId id) { + if (!VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + } else { + // IETF QUIC has two frames, RST_STREAM and STOP_SENDING + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrame)); + } + EXPECT_CALL(*connection_, OnStreamReset(id, _)); + + // QPACK streams might write data upon stream reset. Let the test session + // handle the data. + session_.set_writev_consumes_all_data(true); + + session_.ResetStream(id, QUIC_STREAM_CANCELLED); + closed_streams_.insert(id); + } + + ParsedQuicVersion version() const { return connection_->version(); } + + QuicTransportVersion transport_version() const { + return connection_->transport_version(); + } + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return GetNthClientInitiatedBidirectionalStreamId(transport_version(), n); + } + + QuicStreamId GetNthServerInitiatedBidirectionalId(int n) { + return GetNthServerInitiatedBidirectionalStreamId(transport_version(), n); + } + + QuicStreamId IdDelta() { + return QuicUtils::StreamIdDelta(transport_version()); + } + + QuicStreamId StreamCountToId(QuicStreamCount stream_count, + Perspective perspective, bool bidirectional) { + // Calculate and build up stream ID rather than use + // GetFirst... because the test that relies on this method + // needs to do the stream count where #1 is 0/1/2/3, and not + // take into account that stream 0 is special. + QuicStreamId id = + ((stream_count - 1) * QuicUtils::StreamIdDelta(transport_version())); + if (!bidirectional) { + id |= 0x2; + } + if (perspective == Perspective::IS_SERVER) { + id |= 0x1; + } + return id; + } + + void CompleteHandshake() { + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + } + if (connection_->version().UsesTls() && + connection_->perspective() == Perspective::IS_SERVER) { + // HANDSHAKE_DONE frame. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + } + + CryptoHandshakeMessage message; + session_.GetMutableCryptoStream()->OnHandshakeMessage(message); + testing::Mock::VerifyAndClearExpectations(writer_); + testing::Mock::VerifyAndClearExpectations(connection_); + } + + void ReceiveWebTransportSettings() { + SettingsFrame settings; + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + settings.values[SETTINGS_WEBTRANS_DRAFT00] = 1; + settings.values[SETTINGS_ENABLE_CONNECT_PROTOCOL] = 1; + std::string data = std::string(1, kControlStream) + + HttpEncoder::SerializeSettingsFrame(settings); + QuicStreamId control_stream_id = + session_.perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId(transport_version(), + 3) + : GetNthServerInitiatedUnidirectionalStreamId(transport_version(), + 3); + QuicStreamFrame frame(control_stream_id, /*fin=*/false, /*offset=*/0, data); + session_.OnStreamFrame(frame); + } + + void ReceiveWebTransportSession(WebTransportSessionId session_id) { + QuicStreamFrame frame(session_id, /*fin=*/false, /*offset=*/0, + absl::string_view()); + session_.OnStreamFrame(frame); + QuicSpdyStream* stream = + static_cast(session_.GetOrCreateStream(session_id)); + QuicHeaderList headers; + headers.OnHeaderBlockStart(); + headers.OnHeader(":method", "CONNECT"); + headers.OnHeader(":protocol", "webtransport"); + headers.OnHeader("sec-webtransport-http3-draft02", "1"); + stream->OnStreamHeaderList(/*fin=*/true, 0, headers); + WebTransportHttp3* web_transport = + session_.GetWebTransportSession(session_id); + ASSERT_TRUE(web_transport != nullptr); + spdy::Http2HeaderBlock header_block; + web_transport->HeadersReceived(header_block); + } + + void ReceiveWebTransportUnidirectionalStream(WebTransportSessionId session_id, + QuicStreamId stream_id) { + char buffer[256]; + QuicDataWriter data_writer(sizeof(buffer), buffer); + ASSERT_TRUE(data_writer.WriteVarInt62(kWebTransportUnidirectionalStream)); + ASSERT_TRUE(data_writer.WriteVarInt62(session_id)); + ASSERT_TRUE(data_writer.WriteStringPiece("test data")); + std::string data(buffer, data_writer.length()); + QuicStreamFrame frame(stream_id, /*fin=*/false, /*offset=*/0, data); + session_.OnStreamFrame(frame); + } + + void TestHttpDatagramSetting(HttpDatagramSupport local_support, + HttpDatagramSupport remote_support, + HttpDatagramSupport expected_support, + bool expected_datagram_supported); + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + TestSession session_; + std::set closed_streams_; + MockPacketWriter* writer_; +}; + +class QuicSpdySessionTestServer : public QuicSpdySessionTestBase { + protected: + QuicSpdySessionTestServer() + : QuicSpdySessionTestBase(Perspective::IS_SERVER, true) {} +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdySessionTestServer, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSpdySessionTestServer, UsesPendingStreamsForFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + STOP_SENDING_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); +} + +TEST_P(QuicSpdySessionTestServer, PeerAddress) { + EXPECT_EQ(QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort), + session_.peer_address()); +} + +TEST_P(QuicSpdySessionTestServer, SelfAddress) { + EXPECT_TRUE(session_.self_address().IsInitialized()); +} + +TEST_P(QuicSpdySessionTestServer, OneRttKeysAvailable) { + EXPECT_FALSE(session_.OneRttKeysAvailable()); + CompleteHandshake(); + EXPECT_TRUE(session_.OneRttKeysAvailable()); +} + +TEST_P(QuicSpdySessionTestServer, IsClosedStreamDefault) { + // Ensure that no streams are initially closed. + QuicStreamId first_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + if (!QuicVersionUsesCryptoFrames(transport_version())) { + first_stream_id = QuicUtils::GetCryptoStreamId(transport_version()); + } + for (QuicStreamId i = first_stream_id; i < 100; i++) { + EXPECT_FALSE(session_.IsClosedStream(i)) << "stream id: " << i; + } +} + +TEST_P(QuicSpdySessionTestServer, AvailableStreams) { + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(2)) != nullptr); + // Both client initiated streams with smaller stream IDs are available. + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedBidirectionalId(0))); + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedBidirectionalId(1))); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(1)) != nullptr); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(0)) != nullptr); +} + +TEST_P(QuicSpdySessionTestServer, IsClosedStreamLocallyCreated) { + CompleteHandshake(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_EQ(GetNthServerInitiatedBidirectionalId(0), stream2->id()); + QuicSpdyStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_EQ(GetNthServerInitiatedBidirectionalId(1), stream4->id()); + + CheckClosedStreams(); + CloseStream(GetNthServerInitiatedBidirectionalId(0)); + CheckClosedStreams(); + CloseStream(GetNthServerInitiatedBidirectionalId(1)); + CheckClosedStreams(); +} + +TEST_P(QuicSpdySessionTestServer, IsClosedStreamPeerCreated) { + CompleteHandshake(); + QuicStreamId stream_id1 = GetNthClientInitiatedBidirectionalId(0); + QuicStreamId stream_id2 = GetNthClientInitiatedBidirectionalId(1); + session_.GetOrCreateStream(stream_id1); + session_.GetOrCreateStream(stream_id2); + + CheckClosedStreams(); + CloseStream(stream_id1); + CheckClosedStreams(); + CloseStream(stream_id2); + // Create a stream, and make another available. + QuicStream* stream3 = session_.GetOrCreateStream(stream_id2 + 4); + CheckClosedStreams(); + // Close one, but make sure the other is still not closed + CloseStream(stream3->id()); + CheckClosedStreams(); +} + +TEST_P(QuicSpdySessionTestServer, MaximumAvailableOpenedStreams) { + if (VersionHasIetfQuicFrames(transport_version())) { + // For IETF QUIC, we should be able to obtain the max allowed + // stream ID, the next ID should fail. Since the actual limit + // is not the number of open streams, we allocate the max and the max+2. + // Get the max allowed stream ID, this should succeed. + QuicStreamId stream_id = StreamCountToId( + QuicSessionPeer::ietf_streamid_manager(&session_) + ->max_incoming_bidirectional_streams(), + Perspective::IS_CLIENT, // Client initates stream, allocs stream id. + /*bidirectional=*/true); + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id)); + stream_id = + StreamCountToId(QuicSessionPeer::ietf_streamid_manager(&session_) + ->max_incoming_unidirectional_streams(), + Perspective::IS_CLIENT, + /*bidirectional=*/false); + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id)); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(2); + // Get the (max allowed stream ID)++. These should all fail. + stream_id = + StreamCountToId(QuicSessionPeer::ietf_streamid_manager(&session_) + ->max_incoming_bidirectional_streams() + + 1, + Perspective::IS_CLIENT, + /*bidirectional=*/true); + EXPECT_EQ(nullptr, session_.GetOrCreateStream(stream_id)); + + stream_id = + StreamCountToId(QuicSessionPeer::ietf_streamid_manager(&session_) + ->max_incoming_unidirectional_streams() + + 1, + Perspective::IS_CLIENT, + /*bidirectional=*/false); + EXPECT_EQ(nullptr, session_.GetOrCreateStream(stream_id)); + } else { + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + session_.GetOrCreateStream(stream_id); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_NE( + nullptr, + session_.GetOrCreateStream( + stream_id + + IdDelta() * + (session_.max_open_incoming_bidirectional_streams() - 1))); + } +} + +TEST_P(QuicSpdySessionTestServer, TooManyAvailableStreams) { + QuicStreamId stream_id1 = GetNthClientInitiatedBidirectionalId(0); + QuicStreamId stream_id2; + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id1)); + // A stream ID which is too large to create. + stream_id2 = GetNthClientInitiatedBidirectionalId( + 2 * session_.MaxAvailableBidirectionalStreams() + 4); + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_STREAM_ID, _, _)); + } else { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_TOO_MANY_AVAILABLE_STREAMS, _, _)); + } + EXPECT_EQ(nullptr, session_.GetOrCreateStream(stream_id2)); +} + +TEST_P(QuicSpdySessionTestServer, ManyAvailableStreams) { + // When max_open_streams_ is 200, should be able to create 200 streams + // out-of-order, that is, creating the one with the largest stream ID first. + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, 200); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, 200); + } + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + // Create one stream. + session_.GetOrCreateStream(stream_id); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + // Stream count is 200, GetNth... starts counting at 0, so the 200'th stream + // is 199. BUT actually we need to do 198 because the crypto stream (Stream + // ID 0) has not been registered, but GetNth... assumes that it has. + EXPECT_NE(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(198))); +} + +TEST_P(QuicSpdySessionTestServer, + DebugDFatalIfMarkingClosedStreamWriteBlocked) { + CompleteHandshake(); + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 0))); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId closed_stream_id = stream2->id(); + // Close the stream. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(closed_stream_id, _)); + stream2->Reset(QUIC_BAD_APPLICATION_PAYLOAD); + std::string msg = + absl::StrCat("Marking unknown stream ", closed_stream_id, " blocked."); + EXPECT_QUIC_BUG(session_.MarkConnectionLevelWriteBlocked(closed_stream_id), + msg); +} + +TEST_P(QuicSpdySessionTestServer, TooLargeStreamBlocked) { + // STREAMS_BLOCKED frame is IETF QUIC only. + if (!VersionUsesHttp3(transport_version())) { + return; + } + + CompleteHandshake(); + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Simualte the situation where the incoming stream count is at its limit and + // the peer is blocked. + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams( + static_cast(&session_), QuicUtils::GetMaxStreamCount()); + QuicStreamsBlockedFrame frame; + frame.stream_count = QuicUtils::GetMaxStreamCount(); + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + EXPECT_CALL(debug_visitor, OnGoAwayFrameSent(_)); + session_.OnStreamsBlockedFrame(frame); +} + +TEST_P(QuicSpdySessionTestServer, OnCanWriteBundlesStreams) { + // Encryption needs to be established before data can be sent. + CompleteHandshake(); + + // Drive congestion control manually. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm, GetCongestionWindow()) + .WillRepeatedly(Return(kMaxOutgoingPacketSize * 10)); + EXPECT_CALL(*send_algorithm, InRecovery()).WillRepeatedly(Return(false)); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + EXPECT_CALL(*stream6, OnCanWrite()).WillOnce(Invoke([this, stream6]() { + session_.SendStreamData(stream6); + })); + + // Expect that we only send one packet, the writes from different streams + // should be bundled together. + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + EXPECT_CALL(*send_algorithm, OnPacketSent(_, _, _, _, _)); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, OnCanWriteCongestionControlBlocks) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + InSequence s; + + // Drive congestion control manually. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*send_algorithm, GetCongestionWindow()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream6, OnCanWrite()).WillOnce(Invoke([this, stream6]() { + session_.SendStreamData(stream6); + })); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(false)); + // stream4->OnCanWrite is not called. + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // Still congestion-control blocked. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(false)); + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // stream4->OnCanWrite is called once the connection stops being + // congestion-control blocked. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, OnCanWriteWriterBlocks) { + CompleteHandshake(); + // Drive congestion control manually in order to ensure that + // application-limited signaling is handled correctly. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(true)); + + // Drive packet writer manually. + EXPECT_CALL(*writer_, IsWriteBlocked()).WillRepeatedly(Return(true)); + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)).Times(0); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + + EXPECT_CALL(*stream2, OnCanWrite()).Times(0); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)).Times(0); + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, BufferedHandshake) { + // This tests prioritization of the crypto stream when flow control limits are + // reached. When CRYPTO frames are in use, there is no flow control for the + // crypto handshake, so this test is irrelevant. + if (QuicVersionUsesCryptoFrames(transport_version())) { + return; + } + session_.set_writev_consumes_all_data(true); + EXPECT_FALSE(session_.HasPendingHandshake()); // Default value. + + // Test that blocking other streams does not change our status. + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + EXPECT_FALSE(session_.HasPendingHandshake()); + + TestStream* stream3 = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream3->id()); + EXPECT_FALSE(session_.HasPendingHandshake()); + + // Blocking (due to buffering of) the Crypto stream is detected. + session_.MarkConnectionLevelWriteBlocked( + QuicUtils::GetCryptoStreamId(transport_version())); + EXPECT_TRUE(session_.HasPendingHandshake()); + + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + EXPECT_TRUE(session_.HasPendingHandshake()); + + InSequence s; + // Force most streams to re-register, which is common scenario when we block + // the Crypto stream, and only the crypto stream can "really" write. + + // Due to prioritization, we *should* be asked to write the crypto stream + // first. + // Don't re-register the crypto stream (which signals complete writing). + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, OnCanWrite()); + + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream3, OnCanWrite()).WillOnce(Invoke([this, stream3]() { + session_.SendStreamData(stream3); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + })); + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + EXPECT_FALSE(session_.HasPendingHandshake()); // Crypto stream wrote. +} + +TEST_P(QuicSpdySessionTestServer, OnCanWriteWithClosedStream) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + CloseStream(stream6->id()); + + InSequence s; + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, + OnCanWriteLimitsNumWritesIfFlowControlBlocked) { + CompleteHandshake(); + // Drive congestion control manually in order to ensure that + // application-limited signaling is handled correctly. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(true)); + + // Ensure connection level flow control blockage. + QuicFlowControllerPeer::SetSendWindowOffset(session_.flow_controller(), 0); + EXPECT_TRUE(session_.flow_controller()->IsBlocked()); + EXPECT_TRUE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + + // Mark the crypto and headers streams as write blocked, we expect them to be + // allowed to write later. + if (!QuicVersionUsesCryptoFrames(transport_version())) { + session_.MarkConnectionLevelWriteBlocked( + QuicUtils::GetCryptoStreamId(transport_version())); + } + + // Create a data stream, and although it is write blocked we never expect it + // to be allowed to write as we are connection level flow control blocked. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream->id()); + EXPECT_CALL(*stream, OnCanWrite()).Times(0); + + // The crypto and headers streams should be called even though we are + // connection flow control blocked. + if (!QuicVersionUsesCryptoFrames(transport_version())) { + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, OnCanWrite()); + } + + if (!VersionUsesHttp3(transport_version())) { + TestHeadersStream* headers_stream; + QuicSpdySessionPeer::SetHeadersStream(&session_, nullptr); + headers_stream = new TestHeadersStream(&session_); + QuicSpdySessionPeer::SetHeadersStream(&session_, headers_stream); + session_.MarkConnectionLevelWriteBlocked( + QuicUtils::GetHeadersStreamId(transport_version())); + EXPECT_CALL(*headers_stream, OnCanWrite()); + } + + // After the crypto and header streams perform a write, the connection will be + // blocked by the flow control, hence it should become application-limited. + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, SendGoAway) { + CompleteHandshake(); + if (VersionHasIetfQuicFrames(transport_version())) { + // HTTP/3 GOAWAY has different semantic and thus has its own test. + return; + } + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallySendControlFrame)); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); + EXPECT_TRUE(session_.goaway_sent()); + + const QuicStreamId kTestStreamId = 5u; + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(0); + EXPECT_CALL(*connection_, + OnStreamReset(kTestStreamId, QUIC_STREAM_PEER_GOING_AWAY)) + .Times(0); + EXPECT_TRUE(session_.GetOrCreateStream(kTestStreamId)); +} + +TEST_P(QuicSpdySessionTestServer, SendGoAwayWithoutEncryption) { + if (VersionHasIetfQuicFrames(transport_version())) { + // HTTP/3 GOAWAY has different semantic and thus has its own test. + return; + } + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_PEER_GOING_AWAY, "Going Away.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(0); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); + EXPECT_FALSE(session_.goaway_sent()); +} + +TEST_P(QuicSpdySessionTestServer, SendHttp3GoAway) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + CompleteHandshake(); + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + // Send max stream id (currently 32 bits). + EXPECT_CALL(debug_visitor, OnGoAwayFrameSent(/* stream_id = */ 0xfffffffc)); + session_.SendHttp3GoAway(QUIC_PEER_GOING_AWAY, "Goaway"); + EXPECT_TRUE(session_.goaway_sent()); + + // New incoming stream is not reset. + const QuicStreamId kTestStreamId = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 0); + EXPECT_CALL(*connection_, OnStreamReset(kTestStreamId, _)).Times(0); + EXPECT_TRUE(session_.GetOrCreateStream(kTestStreamId)); + + // No more GOAWAY frames are sent because they could not convey new + // information to the client. + session_.SendHttp3GoAway(QUIC_PEER_GOING_AWAY, "Goaway"); +} + +TEST_P(QuicSpdySessionTestServer, SendHttp3GoAwayWithoutEncryption) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_PEER_GOING_AWAY, "Goaway", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.SendHttp3GoAway(QUIC_PEER_GOING_AWAY, "Goaway"); + EXPECT_FALSE(session_.goaway_sent()); +} + +TEST_P(QuicSpdySessionTestServer, SendHttp3GoAwayAfterStreamIsCreated) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + CompleteHandshake(); + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + const QuicStreamId kTestStreamId = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 0); + EXPECT_TRUE(session_.GetOrCreateStream(kTestStreamId)); + + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + // Send max stream id (currently 32 bits). + EXPECT_CALL(debug_visitor, OnGoAwayFrameSent(/* stream_id = */ 0xfffffffc)); + session_.SendHttp3GoAway(QUIC_PEER_GOING_AWAY, "Goaway"); + EXPECT_TRUE(session_.goaway_sent()); + + // No more GOAWAY frames are sent because they could not convey new + // information to the client. + session_.SendHttp3GoAway(QUIC_PEER_GOING_AWAY, "Goaway"); +} + +TEST_P(QuicSpdySessionTestServer, DoNotSendGoAwayTwice) { + CompleteHandshake(); + if (VersionHasIetfQuicFrames(transport_version())) { + // HTTP/3 GOAWAY doesn't have such restriction. + return; + } + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); + EXPECT_TRUE(session_.goaway_sent()); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); +} + +TEST_P(QuicSpdySessionTestServer, InvalidGoAway) { + if (VersionHasIetfQuicFrames(transport_version())) { + // HTTP/3 GOAWAY has different semantics and thus has its own test. + return; + } + QuicGoAwayFrame go_away(kInvalidControlFrameId, QUIC_PEER_GOING_AWAY, + session_.next_outgoing_bidirectional_stream_id(), ""); + session_.OnGoAway(go_away); +} + +TEST_P(QuicSpdySessionTestServer, Http3GoAwayLargerIdThanBefore) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + EXPECT_FALSE(session_.goaway_received()); + session_.OnHttp3GoAway(/* id = */ 0); + EXPECT_TRUE(session_.goaway_received()); + + EXPECT_CALL( + *connection_, + CloseConnection( + QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS, + "GOAWAY received with ID 1 greater than previously received ID 0", + _)); + session_.OnHttp3GoAway(/* id = */ 1); +} + +// Test that server session will send a connectivity probe in response to a +// connectivity probe on the same path. +TEST_P(QuicSpdySessionTestServer, ServerReplyToConnecitivityProbe) { + if (VersionHasIetfQuicFrames(transport_version())) { + return; + } + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicSocketAddress old_peer_address = + QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort); + EXPECT_EQ(old_peer_address, session_.peer_address()); + + QuicSocketAddress new_peer_address = + QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort + 1); + + EXPECT_CALL(*connection_, + SendConnectivityProbingPacket(nullptr, new_peer_address)); + + if (VersionHasIetfQuicFrames(transport_version())) { + // Need to explicitly do this to emulate the reception of a PathChallenge, + // which stores its payload for use in generating the response. + connection_->OnPathChallengeFrame( + QuicPathChallengeFrame(0, {{0, 1, 2, 3, 4, 5, 6, 7}})); + } + session_.OnPacketReceived(session_.self_address(), new_peer_address, + /*is_connectivity_probe=*/true); + EXPECT_EQ(old_peer_address, session_.peer_address()); +} + +TEST_P(QuicSpdySessionTestServer, IncreasedTimeoutAfterCryptoHandshake) { + EXPECT_EQ(kInitialIdleTimeoutSecs + 3, + QuicConnectionPeer::GetNetworkTimeout(connection_).ToSeconds()); + CompleteHandshake(); + EXPECT_EQ(kMaximumIdleTimeoutSecs + 3, + QuicConnectionPeer::GetNetworkTimeout(connection_).ToSeconds()); +} + +TEST_P(QuicSpdySessionTestServer, RstStreamBeforeHeadersDecompressed) { + CompleteHandshake(); + // Send two bytes of payload. + QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + if (!VersionHasIetfQuicFrames(transport_version())) { + // For version99, OnStreamReset gets called because of the STOP_SENDING, + // below. EXPECT the call there. + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), _)); + } + + // In HTTP/3, Qpack stream will send data on stream reset and cause packet to + // be flushed. + if (VersionUsesHttp3(transport_version())) { + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + } + EXPECT_CALL(*connection_, SendControlFrame(_)); + QuicRstStreamFrame rst1(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + session_.OnRstStream(rst1); + + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes a + // one-way close. + if (VersionHasIetfQuicFrames(transport_version())) { + // Only needed for version 99/IETF QUIC. + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM); + // Expect the RESET_STREAM that is generated in response to receiving a + // STOP_SENDING. + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM)); + session_.OnStopSendingFrame(stop_sending); + } + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + // Connection should remain alive. + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicSpdySessionTestServer, OnStreamFrameFinStaticStreamId) { + QuicStreamId id; + // Initialize HTTP/3 control stream. + if (VersionUsesHttp3(transport_version())) { + id = GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + + QuicStreamFrame data1(id, false, 0, absl::string_view(type, 1)); + session_.OnStreamFrame(data1); + } else { + id = QuicUtils::GetHeadersStreamId(transport_version()); + } + + // Send two bytes of payload. + QuicStreamFrame data1(id, true, 0, absl::string_view("HT")); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_STREAM_ID, "Attempt to close a static stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSpdySessionTestServer, OnRstStreamStaticStreamId) { + QuicStreamId id; + QuicErrorCode expected_error; + std::string error_message; + // Initialize HTTP/3 control stream. + if (VersionUsesHttp3(transport_version())) { + id = GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + + QuicStreamFrame data1(id, false, 0, absl::string_view(type, 1)); + session_.OnStreamFrame(data1); + expected_error = QUIC_HTTP_CLOSED_CRITICAL_STREAM; + error_message = "RESET_STREAM received for receive control stream"; + } else { + id = QuicUtils::GetHeadersStreamId(transport_version()); + expected_error = QUIC_INVALID_STREAM_ID; + error_message = "Attempt to reset headers stream"; + } + + // Send two bytes of payload. + QuicRstStreamFrame rst1(kInvalidControlFrameId, id, + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL( + *connection_, + CloseConnection(expected_error, error_message, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnRstStream(rst1); +} + +TEST_P(QuicSpdySessionTestServer, OnStreamFrameInvalidStreamId) { + // Send two bytes of payload. + QuicStreamFrame data1(QuicUtils::GetInvalidStreamId(transport_version()), + true, 0, absl::string_view("HT")); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_STREAM_ID, "Received data for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSpdySessionTestServer, OnRstStreamInvalidStreamId) { + // Send two bytes of payload. + QuicRstStreamFrame rst1(kInvalidControlFrameId, + QuicUtils::GetInvalidStreamId(transport_version()), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_STREAM_ID, "Received data for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnRstStream(rst1); +} + +TEST_P(QuicSpdySessionTestServer, HandshakeUnblocksFlowControlBlockedStream) { + if (connection_->version().handshake_protocol == PROTOCOL_TLS1_3) { + // This test requires Google QUIC crypto because it assumes streams start + // off unblocked. + return; + } + // Test that if a stream is flow control blocked, then on receipt of the SHLO + // containing a suitable send window offset, the stream becomes unblocked. + + // Ensure that Writev consumes all the data it is given (simulate no socket + // blocking). + session_.GetMutableCryptoStream()->EstablishZeroRttEncryption(); + session_.set_writev_consumes_all_data(true); + + // Create a stream, and send enough data to make it flow control blocked. + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + std::string body(kMinimumFlowControlSendWindow, '.'); + EXPECT_FALSE(stream2->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(AtLeast(1)); + stream2->WriteOrBufferBody(body, false); + EXPECT_TRUE(stream2->IsFlowControlBlocked()); + EXPECT_TRUE(session_.IsConnectionFlowControlBlocked()); + EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); + + // Now complete the crypto handshake, resulting in an increased flow control + // send window. + CompleteHandshake(); + EXPECT_TRUE(QuicSessionPeer::IsStreamWriteBlocked(&session_, stream2->id())); + // Stream is now unblocked. + EXPECT_FALSE(stream2->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); +} + +#if !defined(OS_IOS) +// This test is failing flakily for iOS bots. +// http://crbug.com/425050 +// NOTE: It's not possible to use the standard MAYBE_ convention to disable +// this test on iOS because when this test gets instantiated it ends up with +// various names that are dependent on the parameters passed. +TEST_P(QuicSpdySessionTestServer, + HandshakeUnblocksFlowControlBlockedHeadersStream) { + // This test depends on stream-level flow control for the crypto stream, which + // doesn't exist when CRYPTO frames are used. + if (QuicVersionUsesCryptoFrames(transport_version())) { + return; + } + + // This test depends on the headers stream, which does not exist when QPACK is + // used. + if (VersionUsesHttp3(transport_version())) { + return; + } + + // Test that if the header stream is flow control blocked, then if the SHLO + // contains a larger send window offset, the stream becomes unblocked. + session_.GetMutableCryptoStream()->EstablishZeroRttEncryption(); + session_.set_writev_consumes_all_data(true); + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_FALSE(crypto_stream->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + QuicHeadersStream* headers_stream = + QuicSpdySessionPeer::GetHeadersStream(&session_); + EXPECT_FALSE(headers_stream->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + QuicStreamId stream_id = 5; + // Write until the header stream is flow control blocked. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + Http2HeaderBlock headers; + SimpleRandom random; + while (!headers_stream->IsFlowControlBlocked() && stream_id < 2000) { + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + headers["header"] = absl::StrCat(random.RandUint64(), random.RandUint64(), + random.RandUint64()); + session_.WriteHeadersOnHeadersStream(stream_id, headers.Clone(), true, + spdy::SpdyStreamPrecedence(0), + nullptr); + stream_id += IdDelta(); + } + // Write once more to ensure that the headers stream has buffered data. The + // random headers may have exactly filled the flow control window. + session_.WriteHeadersOnHeadersStream(stream_id, std::move(headers), true, + spdy::SpdyStreamPrecedence(0), nullptr); + EXPECT_TRUE(headers_stream->HasBufferedData()); + + EXPECT_TRUE(headers_stream->IsFlowControlBlocked()); + EXPECT_FALSE(crypto_stream->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); + EXPECT_FALSE(session_.HasDataToWrite()); + + // Now complete the crypto handshake, resulting in an increased flow control + // send window. + CompleteHandshake(); + + // Stream is now unblocked and will no longer have buffered data. + EXPECT_FALSE(headers_stream->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + EXPECT_TRUE(headers_stream->HasBufferedData()); + EXPECT_TRUE(QuicSessionPeer::IsStreamWriteBlocked( + &session_, QuicUtils::GetHeadersStreamId(transport_version()))); +} +#endif // !defined(OS_IOS) + +TEST_P(QuicSpdySessionTestServer, + ConnectionFlowControlAccountingRstOutOfOrder) { + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + CompleteHandshake(); + // Test that when we receive an out of order stream RST we correctly adjust + // our connection level flow control receive window. + // On close, the stream should mark as consumed all bytes between the highest + // byte consumed so far and the final byte offset from the RST frame. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + + const QuicStreamOffset kByteOffset = + 1 + kInitialSessionFlowControlWindowForTest / 2; + + if (!VersionHasIetfQuicFrames(transport_version())) { + // For version99 the call to OnStreamReset happens as a result of receiving + // the STOP_SENDING, so set up the EXPECT there. + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + EXPECT_CALL(*connection_, SendControlFrame(_)); + } else { + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + } + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream->id(), + QUIC_STREAM_CANCELLED, kByteOffset); + session_.OnRstStream(rst_frame); + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes a + // one-way close. + if (VersionHasIetfQuicFrames(transport_version())) { + // Only needed for version 99/IETF QUIC. + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream->id(), + QUIC_STREAM_CANCELLED); + // Expect the RESET_STREAM that is generated in response to receiving a + // STOP_SENDING. + EXPECT_CALL(*connection_, + OnStreamReset(stream->id(), QUIC_STREAM_CANCELLED)); + EXPECT_CALL(*connection_, SendControlFrame(_)); + session_.OnStopSendingFrame(stop_sending); + } + + EXPECT_EQ(kByteOffset, session_.flow_controller()->bytes_consumed()); +} + +TEST_P(QuicSpdySessionTestServer, InvalidStreamFlowControlWindowInHandshake) { + if (GetParam().handshake_protocol == PROTOCOL_TLS1_3) { + // IETF Quic doesn't require a minimum flow control window. + return; + } + // Test that receipt of an invalid (< default) stream flow control window from + // the peer results in the connection being torn down. + const uint32_t kInvalidWindow = kMinimumFlowControlSendWindow - 1; + QuicConfigPeer::SetReceivedInitialStreamFlowControlWindow(session_.config(), + kInvalidWindow); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_INVALID_WINDOW, _, _)); + session_.OnConfigNegotiated(); +} + +TEST_P(QuicSpdySessionTestServer, TooLowUnidirectionalStreamLimitHttp3) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + session_.GetMutableCryptoStream()->EstablishZeroRttEncryption(); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(session_.config(), 2u); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + EXPECT_CALL( + *connection_, + CloseConnection( + _, "new unidirectional limit 2 decreases the current limit: 3", _)); + session_.OnConfigNegotiated(); +} + +// Test negotiation of custom server initial flow control window. +TEST_P(QuicSpdySessionTestServer, CustomFlowControlWindow) { + QuicTagVector copt; + copt.push_back(kIFW7); + QuicConfigPeer::SetReceivedConnectionOptions(session_.config(), copt); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); + EXPECT_EQ(192 * 1024u, QuicFlowControllerPeer::ReceiveWindowSize( + session_.flow_controller())); +} + +TEST_P(QuicSpdySessionTestServer, WindowUpdateUnblocksHeadersStream) { + if (VersionUsesHttp3(transport_version())) { + // The test relies on headers stream, which no longer exists in IETF QUIC. + return; + } + + // Test that a flow control blocked headers stream gets unblocked on recipt of + // a WINDOW_UPDATE frame. + + // Set the headers stream to be flow control blocked. + QuicHeadersStream* headers_stream = + QuicSpdySessionPeer::GetHeadersStream(&session_); + QuicStreamPeer::SetSendWindowOffset(headers_stream, 0); + EXPECT_TRUE(headers_stream->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); + + // Unblock the headers stream by supplying a WINDOW_UPDATE. + QuicWindowUpdateFrame window_update_frame(kInvalidControlFrameId, + headers_stream->id(), + 2 * kMinimumFlowControlSendWindow); + session_.OnWindowUpdateFrame(window_update_frame); + EXPECT_FALSE(headers_stream->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); +} + +TEST_P(QuicSpdySessionTestServer, + TooManyUnfinishedStreamsCauseServerRejectStream) { + // If a buggy/malicious peer creates too many streams that are not ended + // with a FIN or RST then we send an RST to refuse streams for versions other + // than version 99. In version 99 the connection gets closed. + CompleteHandshake(); + const QuicStreamId kMaxStreams = 5; + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, + kMaxStreams); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, kMaxStreams); + } + // GetNth assumes that both the crypto and header streams have been + // open, but the stream id manager, using GetFirstBidirectional... only + // assumes that the crypto stream is open. This means that GetNth...(0) + // Will return stream ID == 8 (with id ==0 for crypto and id==4 for headers). + // It also means that GetNth(kMax..=5) returns 28 (streams 0/1/2/3/4 are ids + // 8, 12, 16, 20, 24, respectively, so stream#5 is stream id 28). + // However, the stream ID manager does not assume stream 4 is for headers. + // The ID manager would assume that stream#5 is streamid 24. + // In order to make this all work out properly, kFinalStreamId will + // be set to GetNth...(kMaxStreams-1)... but only for IETF QUIC + const QuicStreamId kFirstStreamId = GetNthClientInitiatedBidirectionalId(0); + const QuicStreamId kFinalStreamId = + GetNthClientInitiatedBidirectionalId(kMaxStreams); + // Create kMaxStreams data streams, and close them all without receiving a + // FIN or a RST_STREAM from the client. + const QuicStreamId kNextId = QuicUtils::StreamIdDelta(transport_version()); + for (QuicStreamId i = kFirstStreamId; i < kFinalStreamId; i += kNextId) { + QuicStreamFrame data1(i, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); + CloseStream(i); + } + // Try and open a stream that exceeds the limit. + if (!VersionHasIetfQuicFrames(transport_version())) { + // On versions other than 99, opening such a stream results in a + // RST_STREAM. + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(1); + EXPECT_CALL(*connection_, + OnStreamReset(kFinalStreamId, QUIC_REFUSED_STREAM)) + .Times(1); + } else { + // On version 99 opening such a stream results in a connection close. + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + testing::MatchesRegex( + "Stream id \\d+ would exceed stream count limit 5"), + _)); + } + // Create one more data streams to exceed limit of open stream. + QuicStreamFrame data1(kFinalStreamId, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSpdySessionTestServer, DrainingStreamsDoNotCountAsOpened) { + // Verify that a draining stream (which has received a FIN but not consumed + // it) does not count against the open quota (because it is closed from the + // protocol point of view). + CompleteHandshake(); + if (VersionHasIetfQuicFrames(transport_version())) { + // Simulate receiving a config. so that MAX_STREAMS/etc frames may + // be transmitted + QuicSessionPeer::set_is_configured(&session_, true); + // Version 99 will result in a MAX_STREAMS frame as streams are consumed + // (via the OnStreamFrame call) and then released (via + // StreamDraining). Eventually this node will believe that the peer is + // running low on available stream ids and then send a MAX_STREAMS frame, + // caught by this EXPECT_CALL. + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(1); + } else { + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(0); + } + EXPECT_CALL(*connection_, OnStreamReset(_, QUIC_REFUSED_STREAM)).Times(0); + const QuicStreamId kMaxStreams = 5; + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, + kMaxStreams); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, kMaxStreams); + } + + // Create kMaxStreams + 1 data streams, and mark them draining. + const QuicStreamId kFirstStreamId = GetNthClientInitiatedBidirectionalId(0); + const QuicStreamId kFinalStreamId = + GetNthClientInitiatedBidirectionalId(kMaxStreams + 1); + for (QuicStreamId i = kFirstStreamId; i < kFinalStreamId; i += IdDelta()) { + QuicStreamFrame data1(i, true, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + session_.StreamDraining(i, /*unidirectional=*/false); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + } +} + +class QuicSpdySessionTestClient : public QuicSpdySessionTestBase { + protected: + QuicSpdySessionTestClient() + : QuicSpdySessionTestBase(Perspective::IS_CLIENT, false) {} +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdySessionTestClient, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSpdySessionTestClient, UsesPendingStreamsForFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + STOP_SENDING_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); +} + +// Regression test for crbug.com/977581. +TEST_P(QuicSpdySessionTestClient, BadStreamFramePendingStream) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + QuicStreamId stream_id1 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + // A bad stream frame with no data and no fin. + QuicStreamFrame data1(stream_id1, false, 0, 0); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSpdySessionTestClient, PendingStreamKeepsConnectionAlive) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER); + + QuicStreamFrame frame(stream_id, false, 1, "test"); + EXPECT_FALSE(session_.ShouldKeepConnectionAlive()); + session_.OnStreamFrame(frame); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_TRUE(session_.ShouldKeepConnectionAlive()); +} + +TEST_P(QuicSpdySessionTestClient, AvailableStreamsClient) { + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedBidirectionalId(2)) != nullptr); + // Both server initiated streams with smaller stream IDs should be available. + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthServerInitiatedBidirectionalId(0))); + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthServerInitiatedBidirectionalId(1))); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedBidirectionalId(0)) != nullptr); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedBidirectionalId(1)) != nullptr); + // And client initiated stream ID should be not available. + EXPECT_FALSE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedBidirectionalId(0))); +} + +// Regression test for b/130740258 and https://crbug.com/971779. +// If headers that are too large or empty are received (these cases are handled +// the same way, as QuicHeaderList clears itself when headers exceed the limit), +// then the stream is reset. No more frames must be sent in this case. +TEST_P(QuicSpdySessionTestClient, TooLargeHeadersMustNotCauseWriteAfterReset) { + // In IETF QUIC, HEADERS do not carry FIN flag, and OnStreamHeaderList() is + // never called after an error, including too large headers. + if (VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + // Write headers with FIN set to close write side of stream. + // Header block does not matter. + stream->WriteHeaders(Http2HeaderBlock(), /* fin = */ true, nullptr); + + // Receive headers that are too large or empty, with FIN set. + // This causes the stream to be reset. No frames must be written after this. + QuicHeaderList headers; + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, + OnStreamReset(stream->id(), QUIC_HEADERS_TOO_LARGE)); + stream->OnStreamHeaderList(/* fin = */ true, + headers.uncompressed_header_bytes(), headers); +} + +TEST_P(QuicSpdySessionTestClient, RecordFinAfterReadSideClosed) { + // Verify that an incoming FIN is recorded in a stream object even if the read + // side has been closed. This prevents an entry from being made in + // locally_closed_streams_highest_offset_ (which will never be deleted). + CompleteHandshake(); + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId stream_id = stream->id(); + + // Close the read side manually. + QuicStreamPeer::CloseReadSide(stream); + + // Receive a stream data frame with FIN. + QuicStreamFrame frame(stream_id, true, 0, absl::string_view()); + session_.OnStreamFrame(frame); + EXPECT_TRUE(stream->fin_received()); + + // Reset stream locally. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + stream->Reset(QUIC_STREAM_CANCELLED); + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream)); + + EXPECT_TRUE(connection_->connected()); + EXPECT_TRUE(QuicSessionPeer::IsStreamClosed(&session_, stream_id)); + EXPECT_FALSE(QuicSessionPeer::IsStreamCreated(&session_, stream_id)); + + // The stream is not waiting for the arrival of the peer's final offset as it + // was received with the FIN earlier. + EXPECT_EQ( + 0u, + QuicSessionPeer::GetLocallyClosedStreamsHighestOffset(&session_).size()); +} + +TEST_P(QuicSpdySessionTestClient, WritePriority) { + if (VersionUsesHttp3(transport_version())) { + // IETF QUIC currently doesn't support PRIORITY. + return; + } + CompleteHandshake(); + + TestHeadersStream* headers_stream; + QuicSpdySessionPeer::SetHeadersStream(&session_, nullptr); + headers_stream = new TestHeadersStream(&session_); + QuicSpdySessionPeer::SetHeadersStream(&session_, headers_stream); + + // Make packet writer blocked so |headers_stream| will buffer its write data. + EXPECT_CALL(*writer_, IsWriteBlocked()).WillRepeatedly(Return(true)); + + const QuicStreamId id = 4; + const QuicStreamId parent_stream_id = 9; + const SpdyPriority priority = kV3HighestPriority; + const bool exclusive = true; + session_.WritePriority(id, parent_stream_id, + Spdy3PriorityToHttp2Weight(priority), exclusive); + + QuicStreamSendBuffer& send_buffer = + QuicStreamPeer::SendBuffer(headers_stream); + ASSERT_EQ(1u, send_buffer.size()); + + SpdyPriorityIR priority_frame( + id, parent_stream_id, Spdy3PriorityToHttp2Weight(priority), exclusive); + SpdyFramer spdy_framer(SpdyFramer::ENABLE_COMPRESSION); + SpdySerializedFrame frame = spdy_framer.SerializeFrame(priority_frame); + + const quiche::QuicheMemSlice& slice = + QuicStreamSendBufferPeer::CurrentWriteSlice(&send_buffer)->slice; + EXPECT_EQ(absl::string_view(frame.data(), frame.size()), + absl::string_view(slice.data(), slice.length())); +} + +TEST_P(QuicSpdySessionTestClient, Http3ServerPush) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + // Push unidirectional stream is type 0x01. + std::string frame_type1 = absl::HexStringToBytes("01"); + QuicStreamId stream_id1 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_RECEIVE_SERVER_PUSH, _, _)) + .Times(1); + session_.OnStreamFrame(QuicStreamFrame(stream_id1, /* fin = */ false, + /* offset = */ 0, frame_type1)); +} + +TEST_P(QuicSpdySessionTestClient, Http3ServerPushOutofOrderFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + // Push unidirectional stream is type 0x01. + std::string frame_type = absl::HexStringToBytes("01"); + // The first field of a push stream is the Push ID. + std::string push_id = absl::HexStringToBytes("4000"); + + QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + + QuicStreamFrame data1(stream_id, + /* fin = */ false, /* offset = */ 0, frame_type); + QuicStreamFrame data2(stream_id, + /* fin = */ false, /* offset = */ frame_type.size(), + push_id); + + // Receiving some stream data without stream type does not open the stream. + session_.OnStreamFrame(data2); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_RECEIVE_SERVER_PUSH, _, _)) + .Times(1); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSpdySessionTestServer, OnStreamFrameLost) { + CompleteHandshake(); + InSequence s; + + // Drive congestion control manually. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + + QuicStreamFrame frame2(stream2->id(), false, 0, 9); + QuicStreamFrame frame3(stream4->id(), false, 0, 9); + + // Lost data on cryption stream, streams 2 and 4. + EXPECT_CALL(*stream4, HasPendingRetransmission()).WillOnce(Return(true)); + if (!QuicVersionUsesCryptoFrames(transport_version())) { + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()) + .WillOnce(Return(true)); + } + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(true)); + session_.OnFrameLost(QuicFrame(frame3)); + if (!QuicVersionUsesCryptoFrames(transport_version())) { + QuicStreamFrame frame1(QuicUtils::GetCryptoStreamId(transport_version()), + false, 0, 1300); + session_.OnFrameLost(QuicFrame(frame1)); + } else { + QuicCryptoFrame crypto_frame(ENCRYPTION_INITIAL, 0, 1300); + session_.OnFrameLost(QuicFrame(&crypto_frame)); + } + session_.OnFrameLost(QuicFrame(frame2)); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // Mark streams 2 and 4 write blocked. + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + // Lost data is retransmitted before new data, and retransmissions for crypto + // stream go first. + // Do not check congestion window when crypto stream has lost data. + EXPECT_CALL(*send_algorithm, CanSend(_)).Times(0); + if (!QuicVersionUsesCryptoFrames(transport_version())) { + EXPECT_CALL(*crypto_stream, OnCanWrite()); + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()) + .WillOnce(Return(false)); + } + // Check congestion window for non crypto streams. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream4, OnCanWrite()); + EXPECT_CALL(*stream4, HasPendingRetransmission()).WillOnce(Return(false)); + // Connection is blocked. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(false)); + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // Unblock connection. + // Stream 2 retransmits lost data. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(false)); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + // Stream 2 sends new data. + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream4, OnCanWrite()); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, DonotRetransmitDataOfClosedStreams) { + // Resetting a stream will send a QPACK Stream Cancellation instruction on the + // decoder stream. For simplicity, ignore writes on this stream. + CompleteHandshake(); + NoopQpackStreamSenderDelegate qpack_stream_sender_delegate; + if (VersionUsesHttp3(transport_version())) { + session_.qpack_decoder()->set_qpack_stream_sender_delegate( + &qpack_stream_sender_delegate); + } + + InSequence s; + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + QuicStreamFrame frame1(stream2->id(), false, 0, 9); + QuicStreamFrame frame2(stream4->id(), false, 0, 9); + QuicStreamFrame frame3(stream6->id(), false, 0, 9); + + EXPECT_CALL(*stream6, HasPendingRetransmission()).WillOnce(Return(true)); + EXPECT_CALL(*stream4, HasPendingRetransmission()).WillOnce(Return(true)); + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(true)); + session_.OnFrameLost(QuicFrame(frame3)); + session_.OnFrameLost(QuicFrame(frame2)); + session_.OnFrameLost(QuicFrame(frame1)); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + + // Reset stream 4 locally. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream4->id(), _)); + stream4->Reset(QUIC_STREAM_CANCELLED); + + // Verify stream 4 is removed from streams with lost data list. + EXPECT_CALL(*stream6, OnCanWrite()); + EXPECT_CALL(*stream6, HasPendingRetransmission()).WillOnce(Return(false)); + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(false)); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*stream6, OnCanWrite()); + session_.OnCanWrite(); +} + +TEST_P(QuicSpdySessionTestServer, RetransmitFrames) { + CompleteHandshake(); + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + InSequence s; + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + session_.SendWindowUpdate(stream2->id(), 9); + + QuicStreamFrame frame1(stream2->id(), false, 0, 9); + QuicStreamFrame frame2(stream4->id(), false, 0, 9); + QuicStreamFrame frame3(stream6->id(), false, 0, 9); + QuicWindowUpdateFrame window_update(1, stream2->id(), 9); + QuicFrames frames; + frames.push_back(QuicFrame(frame1)); + frames.push_back(QuicFrame(window_update)); + frames.push_back(QuicFrame(frame2)); + frames.push_back(QuicFrame(frame3)); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); + + EXPECT_CALL(*stream2, RetransmitStreamData(_, _, _, _)) + .WillOnce(Return(true)); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*stream4, RetransmitStreamData(_, _, _, _)) + .WillOnce(Return(true)); + EXPECT_CALL(*stream6, RetransmitStreamData(_, _, _, _)) + .WillOnce(Return(true)); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + session_.RetransmitFrames(frames, PTO_RETRANSMISSION); +} + +TEST_P(QuicSpdySessionTestServer, OnPriorityFrame) { + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + TestStream* stream = session_.CreateIncomingStream(stream_id); + session_.OnPriorityFrame(stream_id, + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + + EXPECT_EQ((QuicStreamPriority(HttpStreamPriority{ + kV3HighestPriority, HttpStreamPriority::kDefaultIncremental})), + stream->priority()); +} + +TEST_P(QuicSpdySessionTestServer, OnPriorityUpdateFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Create control stream. + QuicStreamId receive_control_stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + absl::string_view stream_type(type, 1); + QuicStreamOffset offset = 0; + QuicStreamFrame data1(receive_control_stream_id, false, offset, stream_type); + offset += stream_type.length(); + EXPECT_CALL(debug_visitor, + OnPeerControlStreamCreated(receive_control_stream_id)); + session_.OnStreamFrame(data1); + EXPECT_EQ(receive_control_stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + // Send SETTINGS frame. + std::string serialized_settings = HttpEncoder::SerializeSettingsFrame({}); + QuicStreamFrame data2(receive_control_stream_id, false, offset, + serialized_settings); + offset += serialized_settings.length(); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + session_.OnStreamFrame(data2); + + // PRIORITY_UPDATE frame for first request stream. + const QuicStreamId stream_id1 = GetNthClientInitiatedBidirectionalId(0); + PriorityUpdateFrame priority_update1{stream_id1, "u=2"}; + std::string serialized_priority_update1 = + HttpEncoder::SerializePriorityUpdateFrame(priority_update1); + QuicStreamFrame data3(receive_control_stream_id, + /* fin = */ false, offset, serialized_priority_update1); + offset += serialized_priority_update1.size(); + + // PRIORITY_UPDATE frame arrives after stream creation. + TestStream* stream1 = session_.CreateIncomingStream(stream_id1); + EXPECT_EQ(QuicStreamPriority( + HttpStreamPriority{HttpStreamPriority::kDefaultUrgency, + HttpStreamPriority::kDefaultIncremental}), + stream1->priority()); + EXPECT_CALL(debug_visitor, OnPriorityUpdateFrameReceived(priority_update1)); + session_.OnStreamFrame(data3); + EXPECT_EQ(QuicStreamPriority(HttpStreamPriority{ + 2u, HttpStreamPriority::kDefaultIncremental}), + stream1->priority()); + + // PRIORITY_UPDATE frame for second request stream. + const QuicStreamId stream_id2 = GetNthClientInitiatedBidirectionalId(1); + PriorityUpdateFrame priority_update2{stream_id2, "u=5, i"}; + std::string serialized_priority_update2 = + HttpEncoder::SerializePriorityUpdateFrame(priority_update2); + QuicStreamFrame stream_frame3(receive_control_stream_id, + /* fin = */ false, offset, + serialized_priority_update2); + + // PRIORITY_UPDATE frame arrives before stream creation, + // priority value is buffered. + EXPECT_CALL(debug_visitor, OnPriorityUpdateFrameReceived(priority_update2)); + session_.OnStreamFrame(stream_frame3); + // Priority is applied upon stream construction. + TestStream* stream2 = session_.CreateIncomingStream(stream_id2); + EXPECT_EQ(QuicStreamPriority(HttpStreamPriority{5u, true}), + stream2->priority()); +} + +TEST_P(QuicSpdySessionTestServer, OnInvalidPriorityUpdateFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Create control stream. + QuicStreamId receive_control_stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + absl::string_view stream_type(type, 1); + QuicStreamOffset offset = 0; + QuicStreamFrame data1(receive_control_stream_id, false, offset, stream_type); + offset += stream_type.length(); + EXPECT_CALL(debug_visitor, + OnPeerControlStreamCreated(receive_control_stream_id)); + session_.OnStreamFrame(data1); + EXPECT_EQ(receive_control_stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + // Send SETTINGS frame. + std::string serialized_settings = HttpEncoder::SerializeSettingsFrame({}); + QuicStreamFrame data2(receive_control_stream_id, false, offset, + serialized_settings); + offset += serialized_settings.length(); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + session_.OnStreamFrame(data2); + + // PRIORITY_UPDATE frame with Priority Field Value that is not valid + // Structured Headers. + const QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + PriorityUpdateFrame priority_update{stream_id, "00"}; + + EXPECT_CALL(debug_visitor, OnPriorityUpdateFrameReceived(priority_update)); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_PRIORITY_UPDATE, + "Invalid PRIORITY_UPDATE frame payload.", _)); + + std::string serialized_priority_update = + HttpEncoder::SerializePriorityUpdateFrame(priority_update); + QuicStreamFrame data3(receive_control_stream_id, + /* fin = */ false, offset, serialized_priority_update); + session_.OnStreamFrame(data3); +} + +TEST_P(QuicSpdySessionTestServer, OnPriorityUpdateFrameOutOfBoundsUrgency) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Create control stream. + QuicStreamId receive_control_stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + absl::string_view stream_type(type, 1); + QuicStreamOffset offset = 0; + QuicStreamFrame data1(receive_control_stream_id, false, offset, stream_type); + offset += stream_type.length(); + EXPECT_CALL(debug_visitor, + OnPeerControlStreamCreated(receive_control_stream_id)); + session_.OnStreamFrame(data1); + EXPECT_EQ(receive_control_stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + // Send SETTINGS frame. + std::string serialized_settings = HttpEncoder::SerializeSettingsFrame({}); + QuicStreamFrame data2(receive_control_stream_id, false, offset, + serialized_settings); + offset += serialized_settings.length(); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + session_.OnStreamFrame(data2); + + // PRIORITY_UPDATE frame with urgency not in [0,7]. + const QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + PriorityUpdateFrame priority_update{stream_id, "u=9"}; + + EXPECT_CALL(debug_visitor, OnPriorityUpdateFrameReceived(priority_update)); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + + std::string serialized_priority_update = + HttpEncoder::SerializePriorityUpdateFrame(priority_update); + QuicStreamFrame data3(receive_control_stream_id, + /* fin = */ false, offset, serialized_priority_update); + session_.OnStreamFrame(data3); +} + +TEST_P(QuicSpdySessionTestServer, SimplePendingStreamType) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + char input[] = {0x04, // type + 'a', 'b', 'c'}; // data + absl::string_view payload(input, ABSL_ARRAYSIZE(input)); + + // This is a server test with a client-initiated unidirectional stream. + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + + for (bool fin : {true, false}) { + QuicStreamFrame frame(stream_id, fin, /* offset = */ 0, payload); + + // A STOP_SENDING frame is sent in response to the unknown stream type. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke([stream_id](const QuicFrame& frame) { + EXPECT_EQ(STOP_SENDING_FRAME, frame.type); + + const QuicStopSendingFrame& stop_sending = frame.stop_sending_frame; + EXPECT_EQ(stream_id, stop_sending.stream_id); + EXPECT_EQ(QUIC_STREAM_STREAM_CREATION_ERROR, stop_sending.error_code); + EXPECT_EQ( + static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR), + stop_sending.ietf_error_code); + + return ClearControlFrame(frame); + })); + session_.OnStreamFrame(frame); + + PendingStream* pending = + QuicSessionPeer::GetPendingStream(&session_, stream_id); + if (fin) { + // Stream is closed if FIN is received. + EXPECT_FALSE(pending); + } else { + ASSERT_TRUE(pending); + // The pending stream must ignore read data. + EXPECT_TRUE(pending->sequencer()->ignore_read_data()); + } + + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } +} + +TEST_P(QuicSpdySessionTestServer, SimplePendingStreamTypeOutOfOrderDelivery) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + char input[] = {0x04, // type + 'a', 'b', 'c'}; // data + absl::string_view payload(input, ABSL_ARRAYSIZE(input)); + + // This is a server test with a client-initiated unidirectional stream. + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + + for (bool fin : {true, false}) { + QuicStreamFrame frame1(stream_id, /* fin = */ false, /* offset = */ 0, + payload.substr(0, 1)); + QuicStreamFrame frame2(stream_id, fin, /* offset = */ 1, payload.substr(1)); + + // Deliver frames out of order. + session_.OnStreamFrame(frame2); + // A STOP_SENDING frame is sent in response to the unknown stream type. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&VerifyAndClearStopSendingFrame)); + session_.OnStreamFrame(frame1); + + PendingStream* pending = + QuicSessionPeer::GetPendingStream(&session_, stream_id); + if (fin) { + // Stream is closed if FIN is received. + EXPECT_FALSE(pending); + } else { + ASSERT_TRUE(pending); + // The pending stream must ignore read data. + EXPECT_TRUE(pending->sequencer()->ignore_read_data()); + } + + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } +} + +TEST_P(QuicSpdySessionTestServer, + MultipleBytesPendingStreamTypeOutOfOrderDelivery) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + char input[] = {0x41, 0x00, // type (256) + 'a', 'b', 'c'}; // data + absl::string_view payload(input, ABSL_ARRAYSIZE(input)); + + // This is a server test with a client-initiated unidirectional stream. + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + + for (bool fin : {true, false}) { + QuicStreamFrame frame1(stream_id, /* fin = */ false, /* offset = */ 0, + payload.substr(0, 1)); + QuicStreamFrame frame2(stream_id, /* fin = */ false, /* offset = */ 1, + payload.substr(1, 1)); + QuicStreamFrame frame3(stream_id, fin, /* offset = */ 2, payload.substr(2)); + + // Deliver frames out of order. + session_.OnStreamFrame(frame3); + // The first byte does not contain the entire type varint. + session_.OnStreamFrame(frame1); + // A STOP_SENDING frame is sent in response to the unknown stream type. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&VerifyAndClearStopSendingFrame)); + session_.OnStreamFrame(frame2); + + PendingStream* pending = + QuicSessionPeer::GetPendingStream(&session_, stream_id); + if (fin) { + // Stream is closed if FIN is received. + EXPECT_FALSE(pending); + } else { + ASSERT_TRUE(pending); + // The pending stream must ignore read data. + EXPECT_TRUE(pending->sequencer()->ignore_read_data()); + } + + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } +} + +TEST_P(QuicSpdySessionTestServer, ReceiveControlStream) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + CompleteHandshake(); + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Use an arbitrary stream id. + QuicStreamId stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + + QuicStreamFrame data1(stream_id, false, 0, absl::string_view(type, 1)); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(stream_id)); + session_.OnStreamFrame(data1); + EXPECT_EQ(stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + SettingsFrame settings; + settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = 512; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + settings.values[SETTINGS_QPACK_BLOCKED_STREAMS] = 42; + std::string data = HttpEncoder::SerializeSettingsFrame(settings); + QuicStreamFrame frame(stream_id, false, 1, data); + + QpackEncoder* qpack_encoder = session_.qpack_encoder(); + QpackEncoderHeaderTable* header_table = + QpackEncoderPeer::header_table(qpack_encoder); + + EXPECT_NE(512u, header_table->maximum_dynamic_table_capacity()); + EXPECT_NE(5u, session_.max_outbound_header_list_size()); + EXPECT_NE(42u, QpackEncoderPeer::maximum_blocked_streams(qpack_encoder)); + + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(settings)); + session_.OnStreamFrame(frame); + + EXPECT_EQ(512u, header_table->maximum_dynamic_table_capacity()); + EXPECT_EQ(5u, session_.max_outbound_header_list_size()); + EXPECT_EQ(42u, QpackEncoderPeer::maximum_blocked_streams(qpack_encoder)); +} + +TEST_P(QuicSpdySessionTestServer, ReceiveControlStreamOutOfOrderDelivery) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + // Use an arbitrary stream id. + QuicStreamId stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + SettingsFrame settings; + settings.values[10] = 2; + settings.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + std::string data = HttpEncoder::SerializeSettingsFrame(settings); + + QuicStreamFrame data1(stream_id, false, 1, data); + QuicStreamFrame data2(stream_id, false, 0, absl::string_view(type, 1)); + + session_.OnStreamFrame(data1); + EXPECT_NE(5u, session_.max_outbound_header_list_size()); + session_.OnStreamFrame(data2); + EXPECT_EQ(5u, session_.max_outbound_header_list_size()); +} + +// Regression test for https://crbug.com/1009551. +TEST_P(QuicSpdySessionTestServer, StreamClosedWhileHeaderDecodingBlocked) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + session_.qpack_decoder()->OnSetDynamicTableCapacity(1024); + + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + TestStream* stream = session_.CreateIncomingStream(stream_id); + + // HEADERS frame referencing first dynamic table entry. + std::string headers_frame_payload = absl::HexStringToBytes("020080"); + std::string headers_frame_header = + HttpEncoder::SerializeHeadersFrameHeader(headers_frame_payload.length()); + std::string headers_frame = + absl::StrCat(headers_frame_header, headers_frame_payload); + stream->OnStreamFrame(QuicStreamFrame(stream_id, false, 0, headers_frame)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream->headers_decompressed()); + + // Stream is closed and destroyed. + CloseStream(stream_id); + session_.CleanUpClosedStreams(); + + // Dynamic table entry arrived on the decoder stream. + // The destroyed stream object must not be referenced. + session_.qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); +} + +// Regression test for https://crbug.com/1011294. +TEST_P(QuicSpdySessionTestServer, SessionDestroyedWhileHeaderDecodingBlocked) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.qpack_decoder()->OnSetDynamicTableCapacity(1024); + + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + TestStream* stream = session_.CreateIncomingStream(stream_id); + + // HEADERS frame referencing first dynamic table entry. + std::string headers_frame_payload = absl::HexStringToBytes("020080"); + std::string headers_frame_header = + HttpEncoder::SerializeHeadersFrameHeader(headers_frame_payload.length()); + std::string headers_frame = + absl::StrCat(headers_frame_header, headers_frame_payload); + stream->OnStreamFrame(QuicStreamFrame(stream_id, false, 0, headers_frame)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream->headers_decompressed()); + + // |session_| gets destoyed. That destroys QpackDecoder, a member of + // QuicSpdySession (derived class), which destroys QpackDecoderHeaderTable. + // Then |*stream|, owned by QuicSession (base class) get destroyed, which + // destroys QpackProgessiveDecoder, a registered Observer of + // QpackDecoderHeaderTable. This must not cause a crash. +} + +TEST_P(QuicSpdySessionTestClient, ResetAfterInvalidIncomingStreamType) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + + const QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); + + // Payload consists of two bytes. The first byte is an unknown unidirectional + // stream type. The second one would be the type of a push stream, but it + // must not be interpreted as stream type. + std::string payload = absl::HexStringToBytes("3f01"); + QuicStreamFrame frame(stream_id, /* fin = */ false, /* offset = */ 0, + payload); + + // A STOP_SENDING frame is sent in response to the unknown stream type. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&VerifyAndClearStopSendingFrame)); + session_.OnStreamFrame(frame); + + // There are no active streams. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + // The pending stream is still around, because it did not receive a FIN. + PendingStream* pending = + QuicSessionPeer::GetPendingStream(&session_, stream_id); + ASSERT_TRUE(pending); + + // The pending stream must ignore read data. + EXPECT_TRUE(pending->sequencer()->ignore_read_data()); + + // If the stream frame is received again, it should be ignored. + session_.OnStreamFrame(frame); + + // Receive RESET_STREAM. + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED, + /* bytes_written = */ payload.size()); + + session_.OnRstStream(rst_frame); + + // The stream is closed. + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); +} + +TEST_P(QuicSpdySessionTestClient, FinAfterInvalidIncomingStreamType) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + CompleteHandshake(); + + const QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); + + // Payload consists of two bytes. The first byte is an unknown unidirectional + // stream type. The second one would be the type of a push stream, but it + // must not be interpreted as stream type. + std::string payload = absl::HexStringToBytes("3f01"); + QuicStreamFrame frame(stream_id, /* fin = */ false, /* offset = */ 0, + payload); + + // A STOP_SENDING frame is sent in response to the unknown stream type. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&VerifyAndClearStopSendingFrame)); + session_.OnStreamFrame(frame); + + // The pending stream is still around, because it did not receive a FIN. + PendingStream* pending = + QuicSessionPeer::GetPendingStream(&session_, stream_id); + EXPECT_TRUE(pending); + + // The pending stream must ignore read data. + EXPECT_TRUE(pending->sequencer()->ignore_read_data()); + + // If the stream frame is received again, it should be ignored. + session_.OnStreamFrame(frame); + + // Receive FIN. + session_.OnStreamFrame(QuicStreamFrame(stream_id, /* fin = */ true, + /* offset = */ payload.size(), "")); + + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); +} + +TEST_P(QuicSpdySessionTestClient, ResetInMiddleOfStreamType) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + const QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); + + // Payload is the first byte of a two byte varint encoding. + std::string payload = absl::HexStringToBytes("40"); + QuicStreamFrame frame(stream_id, /* fin = */ false, /* offset = */ 0, + payload); + + session_.OnStreamFrame(frame); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + + // Receive RESET_STREAM. + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED, + /* bytes_written = */ payload.size()); + + session_.OnRstStream(rst_frame); + + // The stream is closed. + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); +} + +TEST_P(QuicSpdySessionTestClient, FinInMiddleOfStreamType) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + const QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); + + // Payload is the first byte of a two byte varint encoding with a FIN. + std::string payload = absl::HexStringToBytes("40"); + QuicStreamFrame frame(stream_id, /* fin = */ true, /* offset = */ 0, payload); + + session_.OnStreamFrame(frame); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); +} + +TEST_P(QuicSpdySessionTestClient, DuplicateHttp3UnidirectionalStreams) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + QuicStreamId id1 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + char type1[] = {kControlStream}; + + QuicStreamFrame data1(id1, false, 0, absl::string_view(type1, 1)); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(id1)); + session_.OnStreamFrame(data1); + QuicStreamId id2 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 1); + QuicStreamFrame data2(id2, false, 0, absl::string_view(type1, 1)); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(id2)).Times(0); + EXPECT_QUIC_PEER_BUG( + { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM, + "Control stream is received twice.", _)); + session_.OnStreamFrame(data2); + }, + "Received a duplicate Control stream: Closing connection."); + + QuicStreamId id3 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 2); + char type2[]{kQpackEncoderStream}; + + QuicStreamFrame data3(id3, false, 0, absl::string_view(type2, 1)); + EXPECT_CALL(debug_visitor, OnPeerQpackEncoderStreamCreated(id3)); + session_.OnStreamFrame(data3); + + QuicStreamId id4 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame data4(id4, false, 0, absl::string_view(type2, 1)); + EXPECT_CALL(debug_visitor, OnPeerQpackEncoderStreamCreated(id4)).Times(0); + EXPECT_QUIC_PEER_BUG( + { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM, + "QPACK encoder stream is received twice.", _)); + session_.OnStreamFrame(data4); + }, + "Received a duplicate QPACK encoder stream: Closing connection."); + + QuicStreamId id5 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 4); + char type3[]{kQpackDecoderStream}; + + QuicStreamFrame data5(id5, false, 0, absl::string_view(type3, 1)); + EXPECT_CALL(debug_visitor, OnPeerQpackDecoderStreamCreated(id5)); + session_.OnStreamFrame(data5); + + QuicStreamId id6 = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 5); + QuicStreamFrame data6(id6, false, 0, absl::string_view(type3, 1)); + EXPECT_CALL(debug_visitor, OnPeerQpackDecoderStreamCreated(id6)).Times(0); + EXPECT_QUIC_PEER_BUG( + { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM, + "QPACK decoder stream is received twice.", _)); + session_.OnStreamFrame(data6); + }, + "Received a duplicate QPACK decoder stream: Closing connection."); +} + +TEST_P(QuicSpdySessionTestClient, EncoderStreamError) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + std::string data = absl::HexStringToBytes( + "02" // Encoder stream. + "00"); // Duplicate entry 0, but no entries exist. + + QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + + QuicStreamFrame frame(stream_id, /* fin = */ false, /* offset = */ 0, data); + + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_QPACK_ENCODER_STREAM_DUPLICATE_INVALID_RELATIVE_INDEX, + "Encoder stream error: Invalid relative index.", _)); + session_.OnStreamFrame(frame); +} + +TEST_P(QuicSpdySessionTestClient, DecoderStreamError) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + std::string data = absl::HexStringToBytes( + "03" // Decoder stream. + "00"); // Insert Count Increment with forbidden increment value of zero. + + QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + + QuicStreamFrame frame(stream_id, /* fin = */ false, /* offset = */ 0, data); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_QPACK_DECODER_STREAM_INVALID_ZERO_INCREMENT, + "Decoder stream error: Invalid increment value 0.", _)); + session_.OnStreamFrame(frame); +} + +TEST_P(QuicSpdySessionTestClient, InvalidHttp3GoAway) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_GOAWAY_INVALID_STREAM_ID, + "GOAWAY with invalid stream ID", _)); + QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + session_.OnHttp3GoAway(stream_id); +} + +TEST_P(QuicSpdySessionTestClient, Http3GoAwayLargerIdThanBefore) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + EXPECT_FALSE(session_.goaway_received()); + QuicStreamId stream_id1 = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 0); + session_.OnHttp3GoAway(stream_id1); + EXPECT_TRUE(session_.goaway_received()); + + EXPECT_CALL( + *connection_, + CloseConnection( + QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS, + "GOAWAY received with ID 4 greater than previously received ID 0", + _)); + QuicStreamId stream_id2 = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 1); + session_.OnHttp3GoAway(stream_id2); +} + +TEST_P(QuicSpdySessionTestClient, CloseConnectionOnCancelPush) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Create control stream. + QuicStreamId receive_control_stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + absl::string_view stream_type(type, 1); + QuicStreamOffset offset = 0; + QuicStreamFrame data1(receive_control_stream_id, /* fin = */ false, offset, + stream_type); + offset += stream_type.length(); + EXPECT_CALL(debug_visitor, + OnPeerControlStreamCreated(receive_control_stream_id)); + session_.OnStreamFrame(data1); + EXPECT_EQ(receive_control_stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + // First frame has to be SETTINGS. + std::string serialized_settings = HttpEncoder::SerializeSettingsFrame({}); + QuicStreamFrame data2(receive_control_stream_id, /* fin = */ false, offset, + serialized_settings); + offset += serialized_settings.length(); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + session_.OnStreamFrame(data2); + + std::string cancel_push_frame = absl::HexStringToBytes( + "03" // CANCEL_PUSH + "01" // length + "00"); // push ID + QuicStreamFrame data3(receive_control_stream_id, /* fin = */ false, offset, + cancel_push_frame); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, + "CANCEL_PUSH frame received.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, + SendConnectionClosePacket(QUIC_HTTP_FRAME_ERROR, _, + "CANCEL_PUSH frame received.")); + session_.OnStreamFrame(data3); +} + +TEST_P(QuicSpdySessionTestServer, OnSetting) { + CompleteHandshake(); + if (VersionUsesHttp3(transport_version())) { + EXPECT_EQ(std::numeric_limits::max(), + session_.max_outbound_header_list_size()); + session_.OnSetting(SETTINGS_MAX_FIELD_SECTION_SIZE, 5); + EXPECT_EQ(5u, session_.max_outbound_header_list_size()); + + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 0))); + QpackEncoder* qpack_encoder = session_.qpack_encoder(); + EXPECT_EQ(0u, QpackEncoderPeer::maximum_blocked_streams(qpack_encoder)); + session_.OnSetting(SETTINGS_QPACK_BLOCKED_STREAMS, 12); + EXPECT_EQ(12u, QpackEncoderPeer::maximum_blocked_streams(qpack_encoder)); + + QpackEncoderHeaderTable* header_table = + QpackEncoderPeer::header_table(qpack_encoder); + EXPECT_EQ(0u, header_table->maximum_dynamic_table_capacity()); + session_.OnSetting(SETTINGS_QPACK_MAX_TABLE_CAPACITY, 37); + EXPECT_EQ(37u, header_table->maximum_dynamic_table_capacity()); + + return; + } + + EXPECT_EQ(std::numeric_limits::max(), + session_.max_outbound_header_list_size()); + session_.OnSetting(SETTINGS_MAX_FIELD_SECTION_SIZE, 5); + EXPECT_EQ(5u, session_.max_outbound_header_list_size()); + + spdy::HpackEncoder* hpack_encoder = + QuicSpdySessionPeer::GetSpdyFramer(&session_)->GetHpackEncoder(); + EXPECT_EQ(4096u, hpack_encoder->CurrentHeaderTableSizeSetting()); + session_.OnSetting(spdy::SETTINGS_HEADER_TABLE_SIZE, 59); + EXPECT_EQ(59u, hpack_encoder->CurrentHeaderTableSizeSetting()); +} + +TEST_P(QuicSpdySessionTestServer, FineGrainedHpackErrorCodes) { + if (VersionUsesHttp3(transport_version())) { + // HPACK is not used in HTTP/3. + return; + } + + QuicStreamId request_stream_id = 5; + session_.CreateIncomingStream(request_stream_id); + + // Index 126 does not exist (static table has 61 entries and dynamic table is + // empty). + std::string headers_frame = absl::HexStringToBytes( + "000006" // length + "01" // type + "24" // flags: PRIORITY | END_HEADERS + "00000005" // stream_id + "00000000" // stream dependency + "10" // weight + "fe"); // payload: reference to index 126. + QuicStreamId headers_stream_id = + QuicUtils::GetHeadersStreamId(transport_version()); + QuicStreamFrame data(headers_stream_id, false, 0, headers_frame); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HPACK_INVALID_INDEX, + "SPDY framing error: HPACK_INVALID_INDEX", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnStreamFrame(data); +} + +TEST_P(QuicSpdySessionTestServer, PeerClosesCriticalReceiveStream) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + struct { + char type; + const char* error_details; + } kTestData[] = { + {kControlStream, "RESET_STREAM received for receive control stream"}, + {kQpackEncoderStream, "RESET_STREAM received for QPACK receive stream"}, + {kQpackDecoderStream, "RESET_STREAM received for QPACK receive stream"}, + }; + for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestData); ++i) { + QuicStreamId stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), i + 1); + const QuicByteCount data_length = 1; + QuicStreamFrame data(stream_id, false, 0, + absl::string_view(&kTestData[i].type, data_length)); + session_.OnStreamFrame(data); + + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, + kTestData[i].error_details, _)); + + QuicRstStreamFrame rst(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED, data_length); + session_.OnRstStream(rst); + } +} + +TEST_P(QuicSpdySessionTestServer, + H3ControlStreamsLimitedByConnectionFlowControl) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + // Ensure connection level flow control blockage. + QuicFlowControllerPeer::SetSendWindowOffset(session_.flow_controller(), 0); + EXPECT_TRUE(session_.IsConnectionFlowControlBlocked()); + + QuicSendControlStream* send_control_stream = + QuicSpdySessionPeer::GetSendControlStream(&session_); + // Mark send_control stream write blocked. + session_.MarkConnectionLevelWriteBlocked(send_control_stream->id()); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSpdySessionTestServer, PeerClosesCriticalSendStream) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + QuicSendControlStream* control_stream = + QuicSpdySessionPeer::GetSendControlStream(&session_); + ASSERT_TRUE(control_stream); + + QuicStopSendingFrame stop_sending_control_stream( + kInvalidControlFrameId, control_stream->id(), QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "STOP_SENDING received for send control stream", _)); + session_.OnStopSendingFrame(stop_sending_control_stream); + + QpackSendStream* decoder_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(&session_); + ASSERT_TRUE(decoder_stream); + + QuicStopSendingFrame stop_sending_decoder_stream( + kInvalidControlFrameId, decoder_stream->id(), QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "STOP_SENDING received for QPACK send stream", _)); + session_.OnStopSendingFrame(stop_sending_decoder_stream); + + QpackSendStream* encoder_stream = + QuicSpdySessionPeer::GetQpackEncoderSendStream(&session_); + ASSERT_TRUE(encoder_stream); + + QuicStopSendingFrame stop_sending_encoder_stream( + kInvalidControlFrameId, encoder_stream->id(), QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "STOP_SENDING received for QPACK send stream", _)); + session_.OnStopSendingFrame(stop_sending_encoder_stream); +} + +TEST_P(QuicSpdySessionTestServer, CloseConnectionOnCancelPush) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Create control stream. + QuicStreamId receive_control_stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + absl::string_view stream_type(type, 1); + QuicStreamOffset offset = 0; + QuicStreamFrame data1(receive_control_stream_id, /* fin = */ false, offset, + stream_type); + offset += stream_type.length(); + EXPECT_CALL(debug_visitor, + OnPeerControlStreamCreated(receive_control_stream_id)); + session_.OnStreamFrame(data1); + EXPECT_EQ(receive_control_stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + // First frame has to be SETTINGS. + std::string serialized_settings = HttpEncoder::SerializeSettingsFrame({}); + QuicStreamFrame data2(receive_control_stream_id, /* fin = */ false, offset, + serialized_settings); + offset += serialized_settings.length(); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + session_.OnStreamFrame(data2); + + std::string cancel_push_frame = absl::HexStringToBytes( + "03" // CANCEL_PUSH + "01" // length + "00"); // push ID + QuicStreamFrame data3(receive_control_stream_id, /* fin = */ false, offset, + cancel_push_frame); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, + "CANCEL_PUSH frame received.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, + SendConnectionClosePacket(QUIC_HTTP_FRAME_ERROR, _, + "CANCEL_PUSH frame received.")); + session_.OnStreamFrame(data3); +} + +TEST_P(QuicSpdySessionTestServer, Http3GoAwayWhenClosingConnection) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + EXPECT_CALL(debug_visitor, OnSettingsFrameSent(_)); + CompleteHandshake(); + + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + + // Create stream by receiving some data (CreateIncomingStream() would not + // update the session's largest peer created stream ID). + const QuicByteCount headers_payload_length = 10; + std::string headers_frame_header = + HttpEncoder::SerializeHeadersFrameHeader(headers_payload_length); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_id, headers_payload_length)); + session_.OnStreamFrame( + QuicStreamFrame(stream_id, false, 0, headers_frame_header)); + + EXPECT_EQ(stream_id, QuicSessionPeer::GetLargestPeerCreatedStreamId( + &session_, /*unidirectional = */ false)); + + // Stream with stream_id is already received and potentially processed, + // therefore a GOAWAY frame is sent with the next stream ID. + EXPECT_CALL(debug_visitor, + OnGoAwayFrameSent(stream_id + + QuicUtils::StreamIdDelta(transport_version()))); + + // Close connection. + EXPECT_CALL(*writer_, WritePacket(_, _, _, _, _)) + .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 0))); + EXPECT_CALL(*connection_, CloseConnection(QUIC_NO_ERROR, _, _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(QUIC_NO_ERROR, _, _)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::ReallySendConnectionClosePacket)); + connection_->CloseConnection( + QUIC_NO_ERROR, "closing connection", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +TEST_P(QuicSpdySessionTestClient, DoNotSendInitialMaxPushIdIfNotSet) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + InSequence s; + EXPECT_CALL(debug_visitor, OnSettingsFrameSent(_)); + + CompleteHandshake(); +} + +TEST_P(QuicSpdySessionTestClient, ReceiveSpdySettingInHttp3) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + SettingsFrame frame; + frame.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = 5; + // https://datatracker.ietf.org/doc/html/draft-ietf-quic-http-30#section-7.2.4.1 + // specifies the presence of HTTP/2 setting as error. + frame.values[spdy::SETTINGS_INITIAL_WINDOW_SIZE] = 100; + + CompleteHandshake(); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_RECEIVE_SPDY_SETTING, _, _)); + session_.OnSettingsFrame(frame); +} + +TEST_P(QuicSpdySessionTestClient, ReceiveAcceptChFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + // Create control stream. + QuicStreamId receive_control_stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + char type[] = {kControlStream}; + absl::string_view stream_type(type, 1); + QuicStreamOffset offset = 0; + QuicStreamFrame data1(receive_control_stream_id, /* fin = */ false, offset, + stream_type); + offset += stream_type.length(); + EXPECT_CALL(debug_visitor, + OnPeerControlStreamCreated(receive_control_stream_id)); + + session_.OnStreamFrame(data1); + EXPECT_EQ(receive_control_stream_id, + QuicSpdySessionPeer::GetReceiveControlStream(&session_)->id()); + + // First frame has to be SETTINGS. + std::string serialized_settings = HttpEncoder::SerializeSettingsFrame({}); + QuicStreamFrame data2(receive_control_stream_id, /* fin = */ false, offset, + serialized_settings); + offset += serialized_settings.length(); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + + session_.OnStreamFrame(data2); + + // Receive ACCEPT_CH frame. + AcceptChFrame accept_ch; + accept_ch.entries.push_back({"foo", "bar"}); + std::string accept_ch_frame = HttpEncoder::SerializeAcceptChFrame(accept_ch); + QuicStreamFrame data3(receive_control_stream_id, /* fin = */ false, offset, + accept_ch_frame); + + EXPECT_CALL(debug_visitor, OnAcceptChFrameReceived(accept_ch)); + EXPECT_CALL(session_, OnAcceptChFrame(accept_ch)); + + session_.OnStreamFrame(data3); +} + +TEST_P(QuicSpdySessionTestClient, AcceptChViaAlps) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + std::string serialized_accept_ch_frame = absl::HexStringToBytes( + "4089" // type (ACCEPT_CH) + "08" // length + "03" // length of origin + "666f6f" // origin "foo" + "03" // length of value + "626172"); // value "bar" + + AcceptChFrame expected_accept_ch_frame{{{"foo", "bar"}}}; + EXPECT_CALL(debug_visitor, + OnAcceptChFrameReceivedViaAlps(expected_accept_ch_frame)); + + auto error = session_.OnAlpsData( + reinterpret_cast(serialized_accept_ch_frame.data()), + serialized_accept_ch_frame.size()); + EXPECT_FALSE(error); +} + +TEST_P(QuicSpdySessionTestClient, AlpsForbiddenFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + std::string forbidden_frame = absl::HexStringToBytes( + "00" // type (DATA) + "03" // length + "66666f"); // "foo" + + auto error = session_.OnAlpsData( + reinterpret_cast(forbidden_frame.data()), + forbidden_frame.size()); + ASSERT_TRUE(error); + EXPECT_EQ("DATA frame forbidden", error.value()); +} + +TEST_P(QuicSpdySessionTestClient, AlpsIncompleteFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + std::string incomplete_frame = absl::HexStringToBytes( + "04" // type (SETTINGS) + "03"); // non-zero length but empty payload + + auto error = session_.OnAlpsData( + reinterpret_cast(incomplete_frame.data()), + incomplete_frame.size()); + ASSERT_TRUE(error); + EXPECT_EQ("incomplete HTTP/3 frame", error.value()); +} + +// After receiving a SETTINGS frame via ALPS, +// another SETTINGS frame is still allowed on control frame. +TEST_P(QuicSpdySessionTestClient, SettingsViaAlpsThenOnControlStream) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + QpackEncoder* qpack_encoder = session_.qpack_encoder(); + EXPECT_EQ(0u, qpack_encoder->MaximumDynamicTableCapacity()); + EXPECT_EQ(0u, qpack_encoder->maximum_blocked_streams()); + + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + + std::string serialized_settings_frame1 = absl::HexStringToBytes( + "04" // type (SETTINGS) + "05" // length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "4400" // 0x0400 = 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "20"); // 0x20 = 32 + + SettingsFrame expected_settings_frame1{ + {{SETTINGS_QPACK_MAX_TABLE_CAPACITY, 1024}, + {SETTINGS_QPACK_BLOCKED_STREAMS, 32}}}; + EXPECT_CALL(debug_visitor, + OnSettingsFrameReceivedViaAlps(expected_settings_frame1)); + + auto error = session_.OnAlpsData( + reinterpret_cast(serialized_settings_frame1.data()), + serialized_settings_frame1.size()); + EXPECT_FALSE(error); + + EXPECT_EQ(1024u, qpack_encoder->MaximumDynamicTableCapacity()); + EXPECT_EQ(32u, qpack_encoder->maximum_blocked_streams()); + + const QuicStreamId control_stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(control_stream_id)); + + std::string stream_type = absl::HexStringToBytes("00"); + session_.OnStreamFrame(QuicStreamFrame(control_stream_id, /* fin = */ false, + /* offset = */ 0, stream_type)); + + // SETTINGS_QPACK_MAX_TABLE_CAPACITY, if advertised again, MUST have identical + // value. + // SETTINGS_QPACK_BLOCKED_STREAMS is a limit. Limits MUST NOT be reduced, but + // increasing is okay. + SettingsFrame expected_settings_frame2{ + {{SETTINGS_QPACK_MAX_TABLE_CAPACITY, 1024}, + {SETTINGS_QPACK_BLOCKED_STREAMS, 48}}}; + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(expected_settings_frame2)); + std::string serialized_settings_frame2 = absl::HexStringToBytes( + "04" // type (SETTINGS) + "05" // length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "4400" // 0x0400 = 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "30"); // 0x30 = 48 + session_.OnStreamFrame(QuicStreamFrame(control_stream_id, /* fin = */ false, + /* offset = */ stream_type.length(), + serialized_settings_frame2)); + + EXPECT_EQ(1024u, qpack_encoder->MaximumDynamicTableCapacity()); + EXPECT_EQ(48u, qpack_encoder->maximum_blocked_streams()); +} + +// A SETTINGS frame received via ALPS and another one on the control stream +// cannot have conflicting values. +TEST_P(QuicSpdySessionTestClient, + SettingsViaAlpsConflictsSettingsViaControlStream) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + QpackEncoder* qpack_encoder = session_.qpack_encoder(); + EXPECT_EQ(0u, qpack_encoder->MaximumDynamicTableCapacity()); + + std::string serialized_settings_frame1 = absl::HexStringToBytes( + "04" // type (SETTINGS) + "03" // length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "4400"); // 0x0400 = 1024 + + auto error = session_.OnAlpsData( + reinterpret_cast(serialized_settings_frame1.data()), + serialized_settings_frame1.size()); + EXPECT_FALSE(error); + + EXPECT_EQ(1024u, qpack_encoder->MaximumDynamicTableCapacity()); + + const QuicStreamId control_stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + + std::string stream_type = absl::HexStringToBytes("00"); + session_.OnStreamFrame(QuicStreamFrame(control_stream_id, /* fin = */ false, + /* offset = */ 0, stream_type)); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH, + "Server sent an SETTINGS_QPACK_MAX_TABLE_CAPACITY: " + "32 while current value is: 1024", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + std::string serialized_settings_frame2 = absl::HexStringToBytes( + "04" // type (SETTINGS) + "02" // length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "20"); // 0x20 = 32 + session_.OnStreamFrame(QuicStreamFrame(control_stream_id, /* fin = */ false, + /* offset = */ stream_type.length(), + serialized_settings_frame2)); +} + +TEST_P(QuicSpdySessionTestClient, AlpsTwoSettingsFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + std::string banned_frame = absl::HexStringToBytes( + "04" // type (SETTINGS) + "00" // length + "04" // type (SETTINGS) + "00"); // length + + auto error = + session_.OnAlpsData(reinterpret_cast(banned_frame.data()), + banned_frame.size()); + ASSERT_TRUE(error); + EXPECT_EQ("multiple SETTINGS frames", error.value()); +} + +void QuicSpdySessionTestBase::TestHttpDatagramSetting( + HttpDatagramSupport local_support, HttpDatagramSupport remote_support, + HttpDatagramSupport expected_support, bool expected_datagram_supported) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(local_support); + // HTTP/3 datagrams aren't supported before SETTINGS are received. + EXPECT_FALSE(session_.SupportsH3Datagram()); + EXPECT_EQ(session_.http_datagram_support(), HttpDatagramSupport::kNone); + // Receive SETTINGS. + SettingsFrame settings; + switch (remote_support) { + case HttpDatagramSupport::kNone: + break; + case HttpDatagramSupport::kDraft04: + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + break; + case HttpDatagramSupport::kRfc: + settings.values[SETTINGS_H3_DATAGRAM] = 1; + break; + case HttpDatagramSupport::kRfcAndDraft04: + settings.values[SETTINGS_H3_DATAGRAM] = 1; + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + break; + } + std::string data = std::string(1, kControlStream) + + HttpEncoder::SerializeSettingsFrame(settings); + QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame frame(stream_id, /*fin=*/false, /*offset=*/0, data); + StrictMock debug_visitor; + session_.set_debug_visitor(&debug_visitor); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(stream_id)); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(settings)); + session_.OnStreamFrame(frame); + EXPECT_EQ(session_.http_datagram_support(), expected_support); + EXPECT_EQ(session_.SupportsH3Datagram(), expected_datagram_supported); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04Remote04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft04, + /*remote_support=*/HttpDatagramSupport::kDraft04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04Remote09) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft04, + /*remote_support=*/HttpDatagramSupport::kRfc, + /*expected_support=*/HttpDatagramSupport::kNone, + /*expected_datagram_supported=*/false); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04Remote04And09) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft04, + /*remote_support=*/HttpDatagramSupport::kRfcAndDraft04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal09Remote04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kRfc, + /*remote_support=*/HttpDatagramSupport::kDraft04, + /*expected_support=*/HttpDatagramSupport::kNone, + /*expected_datagram_supported=*/false); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal09Remote09) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kRfc, + /*remote_support=*/HttpDatagramSupport::kRfc, + /*expected_support=*/HttpDatagramSupport::kRfc, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal09Remote04And09) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kRfc, + /*remote_support=*/HttpDatagramSupport::kRfcAndDraft04, + /*expected_support=*/HttpDatagramSupport::kRfc, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04And09Remote04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kRfcAndDraft04, + /*remote_support=*/HttpDatagramSupport::kDraft04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04And09Remote09) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kRfcAndDraft04, + /*remote_support=*/HttpDatagramSupport::kRfc, + /*expected_support=*/HttpDatagramSupport::kRfc, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, + HttpDatagramSettingLocal04And09Remote04And09) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kRfcAndDraft04, + /*remote_support=*/HttpDatagramSupport::kRfcAndDraft04, + /*expected_support=*/HttpDatagramSupport::kRfc, + /*expected_datagram_supported=*/true); +} +TEST_P(QuicSpdySessionTestClient, WebTransportSetting) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + + StrictMock debug_visitor; + // Note that this does not actually fill out correct settings because the + // settings are filled in at the construction time. + EXPECT_CALL(debug_visitor, OnSettingsFrameSent(_)); + session_.set_debug_visitor(&debug_visitor); + CompleteHandshake(); + + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(_)); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + ReceiveWebTransportSettings(); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); + EXPECT_TRUE(session_.SupportsWebTransport()); +} + +TEST_P(QuicSpdySessionTestClient, WebTransportSettingSetToZero) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + + StrictMock debug_visitor; + // Note that this does not actually fill out correct settings because the + // settings are filled in at the construction time. + EXPECT_CALL(debug_visitor, OnSettingsFrameSent(_)); + session_.set_debug_visitor(&debug_visitor); + CompleteHandshake(); + + SettingsFrame server_settings; + server_settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + server_settings.values[SETTINGS_WEBTRANS_DRAFT00] = 0; + std::string data = std::string(1, kControlStream) + + HttpEncoder::SerializeSettingsFrame(server_settings); + QuicStreamId stream_id = + GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame frame(stream_id, /*fin=*/false, /*offset=*/0, data); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(stream_id)); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(server_settings)); + session_.OnStreamFrame(frame); + EXPECT_FALSE(session_.SupportsWebTransport()); +} + +TEST_P(QuicSpdySessionTestServer, WebTransportSetting) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_FALSE(session_.ShouldProcessIncomingRequests()); + + CompleteHandshake(); + + ReceiveWebTransportSettings(); + EXPECT_TRUE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); +} + +TEST_P(QuicSpdySessionTestServer, BufferingIncomingStreams) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + CompleteHandshake(); + QuicStreamId session_id = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 1); + + QuicStreamId data_stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 4); + ReceiveWebTransportUnidirectionalStream(session_id, data_stream_id); + + ReceiveWebTransportSettings(); + + ReceiveWebTransportSession(session_id); + WebTransportHttp3* web_transport = + session_.GetWebTransportSession(session_id); + ASSERT_TRUE(web_transport != nullptr); + + EXPECT_EQ(web_transport->NumberOfAssociatedStreams(), 1u); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(session_id, _)); + EXPECT_CALL( + *connection_, + OnStreamReset(data_stream_id, QUIC_STREAM_WEBTRANSPORT_SESSION_GONE)); + session_.ResetStream(session_id, QUIC_STREAM_INTERNAL_ERROR); +} + +TEST_P(QuicSpdySessionTestServer, BufferingIncomingStreamsLimit) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + CompleteHandshake(); + QuicStreamId session_id = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 1); + + const int streams_to_send = kMaxUnassociatedWebTransportStreams + 4; + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, + OnStreamReset( + _, QUIC_STREAM_WEBTRANSPORT_BUFFERED_STREAMS_LIMIT_EXCEEDED)) + .Times(4); + for (int i = 0; i < streams_to_send; i++) { + QuicStreamId data_stream_id = + GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 4 + i); + ReceiveWebTransportUnidirectionalStream(session_id, data_stream_id); + } + + ReceiveWebTransportSettings(); + + ReceiveWebTransportSession(session_id); + WebTransportHttp3* web_transport = + session_.GetWebTransportSession(session_id); + ASSERT_TRUE(web_transport != nullptr); + + EXPECT_EQ(web_transport->NumberOfAssociatedStreams(), + kMaxUnassociatedWebTransportStreams); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(_, _)) + .Times(kMaxUnassociatedWebTransportStreams + 1); + session_.ResetStream(session_id, QUIC_STREAM_INTERNAL_ERROR); +} + +TEST_P(QuicSpdySessionTestServer, ResetOutgoingWebTransportStreams) { + if (!version().UsesHttp3()) { + return; + } + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + CompleteHandshake(); + QuicStreamId session_id = + GetNthClientInitiatedBidirectionalStreamId(transport_version(), 1); + + ReceiveWebTransportSettings(); + ReceiveWebTransportSession(session_id); + WebTransportHttp3* web_transport = + session_.GetWebTransportSession(session_id); + ASSERT_TRUE(web_transport != nullptr); + + session_.set_writev_consumes_all_data(true); + EXPECT_TRUE(web_transport->CanOpenNextOutgoingUnidirectionalStream()); + EXPECT_EQ(web_transport->NumberOfAssociatedStreams(), 0u); + WebTransportStream* stream = + web_transport->OpenOutgoingUnidirectionalStream(); + EXPECT_EQ(web_transport->NumberOfAssociatedStreams(), 1u); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(session_id, _)); + EXPECT_CALL(*connection_, + OnStreamReset(stream_id, QUIC_STREAM_WEBTRANSPORT_SESSION_GONE)); + session_.ResetStream(session_id, QUIC_STREAM_INTERNAL_ERROR); + EXPECT_EQ(web_transport->NumberOfAssociatedStreams(), 0u); +} + +TEST_P(QuicSpdySessionTestClient, WebTransportWithoutExtendedConnect) { + if (!version().UsesHttp3()) { + return; + } + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + session_.set_supports_webtransport(true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + CompleteHandshake(); + + SettingsFrame settings; + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + settings.values[SETTINGS_WEBTRANS_DRAFT00] = 1; + // No SETTINGS_ENABLE_CONNECT_PROTOCOL here. + std::string data = std::string(1, kControlStream) + + HttpEncoder::SerializeSettingsFrame(settings); + QuicStreamId control_stream_id = + session_.perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3) + : GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame frame(control_stream_id, /*fin=*/false, /*offset=*/0, data); + session_.OnStreamFrame(frame); + + EXPECT_TRUE(session_.SupportsWebTransport()); +} + +// Regression test for b/208997000. +TEST_P(QuicSpdySessionTestClient, LimitEncoderDynamicTableSize) { + if (version().UsesHttp3()) { + return; + } + CompleteHandshake(); + + QuicSpdySessionPeer::SetHeadersStream(&session_, nullptr); + TestHeadersStream* headers_stream = + new StrictMock(&session_); + QuicSpdySessionPeer::SetHeadersStream(&session_, headers_stream); + session_.MarkConnectionLevelWriteBlocked(headers_stream->id()); + + // Peer sends very large value. + session_.OnSetting(spdy::SETTINGS_HEADER_TABLE_SIZE, 1024 * 1024 * 1024); + + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + EXPECT_CALL(*writer_, IsWriteBlocked()).WillRepeatedly(Return(true)); + Http2HeaderBlock headers; + headers[":method"] = "GET"; // entry with index 2 in HPACK static table + stream->WriteHeaders(std::move(headers), /* fin = */ true, nullptr); + + EXPECT_TRUE(headers_stream->HasBufferedData()); + QuicStreamSendBuffer& send_buffer = + QuicStreamPeer::SendBuffer(headers_stream); + ASSERT_EQ(1u, send_buffer.size()); + + const quiche::QuicheMemSlice& slice = + QuicStreamSendBufferPeer::CurrentWriteSlice(&send_buffer)->slice; + absl::string_view stream_data(slice.data(), slice.length()); + + EXPECT_EQ(absl::HexStringToBytes( + "000009" // frame length + "01" // frame type HEADERS + "25"), // flags END_STREAM | END_HEADERS | PRIORITY + stream_data.substr(0, 5)); + stream_data.remove_prefix(5); + + // Ignore stream ID as it might differ between QUIC versions. + stream_data.remove_prefix(4); + + EXPECT_EQ(absl::HexStringToBytes("00000000" // stream dependency + "92"), // stream weight + stream_data.substr(0, 5)); + stream_data.remove_prefix(5); + + EXPECT_EQ(absl::HexStringToBytes( + "3fe17f" // Dynamic Table Size Update to 16384 + "82"), // Indexed Header Field Representation with index 2 + stream_data); +} + +class QuicSpdySessionTestServerNoExtendedConnect + : public QuicSpdySessionTestBase { + public: + QuicSpdySessionTestServerNoExtendedConnect() + : QuicSpdySessionTestBase(Perspective::IS_SERVER, false) {} +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdySessionTestServerNoExtendedConnect, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +// Tests that receiving SETTINGS_ENABLE_CONNECT_PROTOCOL = 1 doesn't enable +// server session to support extended CONNECT. +TEST_P(QuicSpdySessionTestServerNoExtendedConnect, + WebTransportSettingNoEffect) { + if (!version().UsesHttp3()) { + return; + } + + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); + + CompleteHandshake(); + + ReceiveWebTransportSettings(); + EXPECT_FALSE(session_.allow_extended_connect()); + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); +} + +TEST_P(QuicSpdySessionTestServerNoExtendedConnect, BadExtendedConnectSetting) { + if (!version().UsesHttp3()) { + return; + } + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); + + CompleteHandshake(); + + // ENABLE_CONNECT_PROTOCOL setting value has to be 1 or 0; + SettingsFrame settings; + settings.values[SETTINGS_ENABLE_CONNECT_PROTOCOL] = 2; + std::string data = std::string(1, kControlStream) + + HttpEncoder::SerializeSettingsFrame(settings); + QuicStreamId control_stream_id = + session_.perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3) + : GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame frame(control_stream_id, /*fin=*/false, /*offset=*/0, data); + EXPECT_QUIC_PEER_BUG( + { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_INVALID_SETTING_VALUE, _, _)); + session_.OnStreamFrame(frame); + }, + "Received SETTINGS_ENABLE_CONNECT_PROTOCOL with invalid value"); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_stream.cc b/quiche/quic/core/http/quic_spdy_stream.cc new file mode 100644 index 000000000000..84f57e31744f --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_stream.cc @@ -0,0 +1,1673 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_stream.h" + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/http_decoder.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/qpack/qpack_encoder.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/quic_write_blocked_list.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/capsule.h" +#include "quiche/common/quiche_mem_slice_storage.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/spdy_protocol.h" + +using ::quiche::Capsule; +using ::quiche::CapsuleType; +using ::spdy::Http2HeaderBlock; + +namespace quic { + +// Visitor of HttpDecoder that passes data frame to QuicSpdyStream and closes +// the connection on unexpected frames. +class QuicSpdyStream::HttpDecoderVisitor : public HttpDecoder::Visitor { + public: + explicit HttpDecoderVisitor(QuicSpdyStream* stream) : stream_(stream) {} + HttpDecoderVisitor(const HttpDecoderVisitor&) = delete; + HttpDecoderVisitor& operator=(const HttpDecoderVisitor&) = delete; + + void OnError(HttpDecoder* decoder) override { + stream_->OnUnrecoverableError(decoder->error(), decoder->error_detail()); + } + + bool OnMaxPushIdFrame() override { + CloseConnectionOnWrongFrame("Max Push Id"); + return false; + } + + bool OnGoAwayFrame(const GoAwayFrame& /*frame*/) override { + CloseConnectionOnWrongFrame("Goaway"); + return false; + } + + bool OnSettingsFrameStart(QuicByteCount /*header_length*/) override { + CloseConnectionOnWrongFrame("Settings"); + return false; + } + + bool OnSettingsFrame(const SettingsFrame& /*frame*/) override { + CloseConnectionOnWrongFrame("Settings"); + return false; + } + + bool OnDataFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) override { + return stream_->OnDataFrameStart(header_length, payload_length); + } + + bool OnDataFramePayload(absl::string_view payload) override { + QUICHE_DCHECK(!payload.empty()); + return stream_->OnDataFramePayload(payload); + } + + bool OnDataFrameEnd() override { return stream_->OnDataFrameEnd(); } + + bool OnHeadersFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) override { + if (!VersionUsesHttp3(stream_->transport_version())) { + CloseConnectionOnWrongFrame("Headers"); + return false; + } + return stream_->OnHeadersFrameStart(header_length, payload_length); + } + + bool OnHeadersFramePayload(absl::string_view payload) override { + QUICHE_DCHECK(!payload.empty()); + if (!VersionUsesHttp3(stream_->transport_version())) { + CloseConnectionOnWrongFrame("Headers"); + return false; + } + return stream_->OnHeadersFramePayload(payload); + } + + bool OnHeadersFrameEnd() override { + if (!VersionUsesHttp3(stream_->transport_version())) { + CloseConnectionOnWrongFrame("Headers"); + return false; + } + return stream_->OnHeadersFrameEnd(); + } + + bool OnPriorityUpdateFrameStart(QuicByteCount /*header_length*/) override { + CloseConnectionOnWrongFrame("Priority update"); + return false; + } + + bool OnPriorityUpdateFrame(const PriorityUpdateFrame& /*frame*/) override { + CloseConnectionOnWrongFrame("Priority update"); + return false; + } + + bool OnAcceptChFrameStart(QuicByteCount /*header_length*/) override { + CloseConnectionOnWrongFrame("ACCEPT_CH"); + return false; + } + + bool OnAcceptChFrame(const AcceptChFrame& /*frame*/) override { + CloseConnectionOnWrongFrame("ACCEPT_CH"); + return false; + } + + void OnWebTransportStreamFrameType( + QuicByteCount header_length, WebTransportSessionId session_id) override { + stream_->OnWebTransportStreamFrameType(header_length, session_id); + } + + bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length, + QuicByteCount payload_length) override { + return stream_->OnUnknownFrameStart(frame_type, header_length, + payload_length); + } + + bool OnUnknownFramePayload(absl::string_view payload) override { + return stream_->OnUnknownFramePayload(payload); + } + + bool OnUnknownFrameEnd() override { return stream_->OnUnknownFrameEnd(); } + + private: + void CloseConnectionOnWrongFrame(absl::string_view frame_type) { + stream_->OnUnrecoverableError( + QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM, + absl::StrCat(frame_type, " frame received on data stream")); + } + + QuicSpdyStream* stream_; +}; + +#define ENDPOINT \ + (session()->perspective() == Perspective::IS_SERVER ? "Server: " \ + : "Client:" \ + " ") + +namespace { +HttpDecoder::Options HttpDecoderOptionsForBidiStream( + QuicSpdySession* spdy_session) { + HttpDecoder::Options options; + options.allow_web_transport_stream = + spdy_session->WillNegotiateWebTransport(); + return options; +} +} // namespace + +QuicSpdyStream::QuicSpdyStream(QuicStreamId id, QuicSpdySession* spdy_session, + StreamType type) + : QuicStream(id, spdy_session, /*is_static=*/false, type), + spdy_session_(spdy_session), + on_body_available_called_because_sequencer_is_closed_(false), + visitor_(nullptr), + blocked_on_decoding_headers_(false), + headers_decompressed_(false), + header_list_size_limit_exceeded_(false), + headers_payload_length_(0), + trailers_decompressed_(false), + trailers_consumed_(false), + http_decoder_visitor_(std::make_unique(this)), + decoder_(http_decoder_visitor_.get(), + HttpDecoderOptionsForBidiStream(spdy_session)), + sequencer_offset_(0), + is_decoder_processing_input_(false), + ack_listener_(nullptr), + last_sent_priority_( + QuicStreamPriority::Default(spdy_session->priority_type())) { + QUICHE_DCHECK_EQ(session()->connection(), spdy_session->connection()); + QUICHE_DCHECK_EQ(transport_version(), spdy_session->transport_version()); + QUICHE_DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id)); + QUICHE_DCHECK_EQ(0u, sequencer()->NumBytesConsumed()); + // If headers are sent on the headers stream, then do not receive any + // callbacks from the sequencer until headers are complete. + if (!VersionUsesHttp3(transport_version())) { + sequencer()->SetBlockedUntilFlush(); + } + + if (VersionUsesHttp3(transport_version())) { + sequencer()->set_level_triggered(true); + } + + spdy_session_->OnStreamCreated(this); +} + +QuicSpdyStream::QuicSpdyStream(PendingStream* pending, + QuicSpdySession* spdy_session) + : QuicStream(pending, spdy_session, /*is_static=*/false), + spdy_session_(spdy_session), + on_body_available_called_because_sequencer_is_closed_(false), + visitor_(nullptr), + blocked_on_decoding_headers_(false), + headers_decompressed_(false), + header_list_size_limit_exceeded_(false), + headers_payload_length_(0), + trailers_decompressed_(false), + trailers_consumed_(false), + http_decoder_visitor_(std::make_unique(this)), + decoder_(http_decoder_visitor_.get()), + sequencer_offset_(sequencer()->NumBytesConsumed()), + is_decoder_processing_input_(false), + ack_listener_(nullptr), + last_sent_priority_( + QuicStreamPriority::Default(spdy_session->priority_type())) { + QUICHE_DCHECK_EQ(session()->connection(), spdy_session->connection()); + QUICHE_DCHECK_EQ(transport_version(), spdy_session->transport_version()); + QUICHE_DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id())); + // If headers are sent on the headers stream, then do not receive any + // callbacks from the sequencer until headers are complete. + if (!VersionUsesHttp3(transport_version())) { + sequencer()->SetBlockedUntilFlush(); + } + + if (VersionUsesHttp3(transport_version())) { + sequencer()->set_level_triggered(true); + } + + spdy_session_->OnStreamCreated(this); +} + +QuicSpdyStream::~QuicSpdyStream() {} + +size_t QuicSpdyStream::WriteHeaders( + Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener) { + if (!AssertNotWebTransportDataStream("writing headers")) { + return 0; + } + + QuicConnection::ScopedPacketFlusher flusher(spdy_session_->connection()); + + MaybeProcessSentWebTransportHeaders(header_block); + + if (web_transport_ != nullptr && + spdy_session_->perspective() == Perspective::IS_SERVER) { + header_block["sec-webtransport-http3-draft"] = "draft02"; + } + + size_t bytes_written = + WriteHeadersImpl(std::move(header_block), fin, std::move(ack_listener)); + if (!VersionUsesHttp3(transport_version()) && fin) { + // If HEADERS are sent on the headers stream, then |fin_sent_| needs to be + // set and write side needs to be closed without actually sending a FIN on + // this stream. + // TODO(rch): Add test to ensure fin_sent_ is set whenever a fin is sent. + SetFinSent(); + CloseWriteSide(); + } + + if (web_transport_ != nullptr && + session()->perspective() == Perspective::IS_CLIENT) { + WriteGreaseCapsule(); + if (spdy_session_->http_datagram_support() == + HttpDatagramSupport::kDraft04) { + // Send a REGISTER_DATAGRAM_NO_CONTEXT capsule to support servers that + // are running draft-ietf-masque-h3-datagram-04 or -05. + uint64_t capsule_type = 0xff37a2; // REGISTER_DATAGRAM_NO_CONTEXT + constexpr unsigned char capsule_data[4] = { + 0x80, 0xff, 0x7c, 0x00, // WEBTRANSPORT datagram format type + }; + WriteCapsule(Capsule::Unknown( + capsule_type, + absl::string_view(reinterpret_cast(capsule_data), + sizeof(capsule_data)))); + WriteGreaseCapsule(); + } + } + + if (connect_ip_visitor_ != nullptr) { + connect_ip_visitor_->OnHeadersWritten(); + } + + return bytes_written; +} + +void QuicSpdyStream::WriteOrBufferBody(absl::string_view data, bool fin) { + if (!AssertNotWebTransportDataStream("writing body data")) { + return; + } + if (!VersionUsesHttp3(transport_version()) || data.length() == 0) { + WriteOrBufferData(data, fin, nullptr); + return; + } + QuicConnection::ScopedPacketFlusher flusher(spdy_session_->connection()); + + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnDataFrameSent(id(), data.length()); + } + + const bool success = + WriteDataFrameHeader(data.length(), /*force_write=*/true); + QUICHE_DCHECK(success); + + // Write body. + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() + << " is writing DATA frame payload of length " + << data.length() << " with fin " << fin; + WriteOrBufferData(data, fin, nullptr); +} + +size_t QuicSpdyStream::WriteTrailers( + Http2HeaderBlock trailer_block, + quiche::QuicheReferenceCountedPointer + ack_listener) { + if (fin_sent()) { + QUIC_BUG(quic_bug_10410_1) + << "Trailers cannot be sent after a FIN, on stream " << id(); + return 0; + } + + if (!VersionUsesHttp3(transport_version())) { + // The header block must contain the final offset for this stream, as the + // trailers may be processed out of order at the peer. + const QuicStreamOffset final_offset = + stream_bytes_written() + BufferedDataBytes(); + QUIC_DLOG(INFO) << ENDPOINT << "Inserting trailer: (" + << kFinalOffsetHeaderKey << ", " << final_offset << ")"; + trailer_block.insert( + std::make_pair(kFinalOffsetHeaderKey, absl::StrCat(final_offset))); + } + + // Write the trailing headers with a FIN, and close stream for writing: + // trailers are the last thing to be sent on a stream. + const bool kFin = true; + size_t bytes_written = + WriteHeadersImpl(std::move(trailer_block), kFin, std::move(ack_listener)); + + // If trailers are sent on the headers stream, then |fin_sent_| needs to be + // set without actually sending a FIN on this stream. + if (!VersionUsesHttp3(transport_version())) { + SetFinSent(); + + // Also, write side of this stream needs to be closed. However, only do + // this if there is no more buffered data, otherwise it will never be sent. + if (BufferedDataBytes() == 0) { + CloseWriteSide(); + } + } + + return bytes_written; +} + +QuicConsumedData QuicSpdyStream::WritevBody(const struct iovec* iov, int count, + bool fin) { + quiche::QuicheMemSliceStorage storage( + iov, count, + session()->connection()->helper()->GetStreamSendBufferAllocator(), + GetQuicFlag(quic_send_buffer_max_data_slice_size)); + return WriteBodySlices(storage.ToSpan(), fin); +} + +bool QuicSpdyStream::WriteDataFrameHeader(QuicByteCount data_length, + bool force_write) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + QUICHE_DCHECK_GT(data_length, 0u); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + data_length, + spdy_session_->connection()->helper()->GetStreamSendBufferAllocator()); + const bool can_write = CanWriteNewDataAfterData(header.size()); + if (!can_write && !force_write) { + return false; + } + + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnDataFrameSent(id(), data_length); + } + + unacked_frame_headers_offsets_.Add( + send_buffer().stream_offset(), + send_buffer().stream_offset() + header.size()); + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() + << " is writing DATA frame header of length " + << header.size(); + if (can_write) { + // Save one copy and allocation if send buffer can accomodate the header. + quiche::QuicheMemSlice header_slice(std::move(header)); + WriteMemSlices(absl::MakeSpan(&header_slice, 1), false); + } else { + QUICHE_DCHECK(force_write); + WriteOrBufferData(header.AsStringView(), false, nullptr); + } + return true; +} + +QuicConsumedData QuicSpdyStream::WriteBodySlices( + absl::Span slices, bool fin) { + if (!VersionUsesHttp3(transport_version()) || slices.empty()) { + return WriteMemSlices(slices, fin); + } + + QuicConnection::ScopedPacketFlusher flusher(spdy_session_->connection()); + const QuicByteCount data_size = MemSliceSpanTotalSize(slices); + if (!WriteDataFrameHeader(data_size, /*force_write=*/false)) { + return {0, false}; + } + + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() + << " is writing DATA frame payload of length " << data_size; + return WriteMemSlices(slices, fin); +} + +size_t QuicSpdyStream::Readv(const struct iovec* iov, size_t iov_len) { + QUICHE_DCHECK(FinishedReadingHeaders()); + if (!VersionUsesHttp3(transport_version())) { + return sequencer()->Readv(iov, iov_len); + } + size_t bytes_read = 0; + sequencer()->MarkConsumed(body_manager_.ReadBody(iov, iov_len, &bytes_read)); + + return bytes_read; +} + +int QuicSpdyStream::GetReadableRegions(iovec* iov, size_t iov_len) const { + QUICHE_DCHECK(FinishedReadingHeaders()); + if (!VersionUsesHttp3(transport_version())) { + return sequencer()->GetReadableRegions(iov, iov_len); + } + return body_manager_.PeekBody(iov, iov_len); +} + +void QuicSpdyStream::MarkConsumed(size_t num_bytes) { + QUICHE_DCHECK(FinishedReadingHeaders()); + if (!VersionUsesHttp3(transport_version())) { + sequencer()->MarkConsumed(num_bytes); + return; + } + + sequencer()->MarkConsumed(body_manager_.OnBodyConsumed(num_bytes)); +} + +bool QuicSpdyStream::IsDoneReading() const { + bool done_reading_headers = FinishedReadingHeaders(); + bool done_reading_body = sequencer()->IsClosed(); + bool done_reading_trailers = FinishedReadingTrailers(); + return done_reading_headers && done_reading_body && done_reading_trailers; +} + +bool QuicSpdyStream::HasBytesToRead() const { + if (!VersionUsesHttp3(transport_version())) { + return sequencer()->HasBytesToRead(); + } + return body_manager_.HasBytesToRead(); +} + +void QuicSpdyStream::MarkTrailersConsumed() { trailers_consumed_ = true; } + +uint64_t QuicSpdyStream::total_body_bytes_read() const { + if (VersionUsesHttp3(transport_version())) { + return body_manager_.total_body_bytes_received(); + } + return sequencer()->NumBytesConsumed(); +} + +void QuicSpdyStream::ConsumeHeaderList() { + header_list_.Clear(); + + if (!FinishedReadingHeaders()) { + return; + } + + if (!VersionUsesHttp3(transport_version())) { + sequencer()->SetUnblocked(); + return; + } + + if (body_manager_.HasBytesToRead()) { + HandleBodyAvailable(); + return; + } + + if (sequencer()->IsClosed() && + !on_body_available_called_because_sequencer_is_closed_) { + on_body_available_called_because_sequencer_is_closed_ = true; + HandleBodyAvailable(); + } +} + +void QuicSpdyStream::OnStreamHeadersPriority( + const spdy::SpdyStreamPrecedence& precedence) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, + session()->connection()->perspective()); + if (session()->priority_type() != QuicPriorityType::kHttp) { + return; + } + SetPriority(QuicStreamPriority(HttpStreamPriority{ + precedence.spdy3_priority(), HttpStreamPriority::kDefaultIncremental})); +} + +void QuicSpdyStream::OnStreamHeaderList(bool fin, size_t frame_len, + const QuicHeaderList& header_list) { + if (!spdy_session()->user_agent_id().has_value()) { + std::string uaid; + for (const auto& kv : header_list) { + if (quiche::QuicheTextUtils::ToLower(kv.first) == kUserAgentHeaderName) { + uaid = kv.second; + break; + } + } + spdy_session()->SetUserAgentId(std::move(uaid)); + } + + // TODO(b/134706391): remove |fin| argument. + // When using Google QUIC, an empty header list indicates that the size limit + // has been exceeded. + // When using IETF QUIC, there is an explicit signal from + // QpackDecodedHeadersAccumulator. + if ((VersionUsesHttp3(transport_version()) && + header_list_size_limit_exceeded_) || + (!VersionUsesHttp3(transport_version()) && header_list.empty())) { + OnHeadersTooLarge(); + if (IsDoneReading()) { + return; + } + } + if (!headers_decompressed_) { + OnInitialHeadersComplete(fin, frame_len, header_list); + } else { + OnTrailingHeadersComplete(fin, frame_len, header_list); + } +} + +void QuicSpdyStream::OnHeadersDecoded(QuicHeaderList headers, + bool header_list_size_limit_exceeded) { + header_list_size_limit_exceeded_ = header_list_size_limit_exceeded; + qpack_decoded_headers_accumulator_.reset(); + + QuicSpdySession::LogHeaderCompressionRatioHistogram( + /* using_qpack = */ true, + /* is_sent = */ false, headers.compressed_header_bytes(), + headers.uncompressed_header_bytes()); + + const QuicStreamId promised_stream_id = spdy_session()->promised_stream_id(); + Http3DebugVisitor* const debug_visitor = spdy_session()->debug_visitor(); + if (promised_stream_id == + QuicUtils::GetInvalidStreamId(transport_version())) { + if (debug_visitor) { + debug_visitor->OnHeadersDecoded(id(), headers); + } + + OnStreamHeaderList(/* fin = */ false, headers_payload_length_, headers); + } else { + spdy_session_->OnHeaderList(headers); + } + + if (blocked_on_decoding_headers_) { + blocked_on_decoding_headers_ = false; + // Continue decoding HTTP/3 frames. + OnDataAvailable(); + } +} + +void QuicSpdyStream::OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) { + qpack_decoded_headers_accumulator_.reset(); + + std::string connection_close_error_message = absl::StrCat( + "Error decoding ", headers_decompressed_ ? "trailers" : "headers", + " on stream ", id(), ": ", error_message); + OnUnrecoverableError(error_code, connection_close_error_message); +} + +void QuicSpdyStream::MaybeSendPriorityUpdateFrame() { + if (!VersionUsesHttp3(transport_version()) || + session()->perspective() != Perspective::IS_CLIENT) { + return; + } + if (spdy_session_->priority_type() != QuicPriorityType::kHttp) { + return; + } + + if (last_sent_priority_ == priority()) { + return; + } + last_sent_priority_ = priority(); + + spdy_session_->WriteHttp3PriorityUpdate(id(), priority().http()); +} + +void QuicSpdyStream::OnHeadersTooLarge() { Reset(QUIC_HEADERS_TOO_LARGE); } + +void QuicSpdyStream::OnInitialHeadersComplete( + bool fin, size_t /*frame_len*/, const QuicHeaderList& header_list) { + // TODO(b/134706391): remove |fin| argument. + headers_decompressed_ = true; + header_list_ = header_list; + bool header_too_large = VersionUsesHttp3(transport_version()) + ? header_list_size_limit_exceeded_ + : header_list.empty(); + if (!AreHeaderFieldValuesValid(header_list)) { + OnInvalidHeaders(); + return; + } + // Validate request headers if it did not exceed size limit. If it did, + // OnHeadersTooLarge() should have already handled it previously. + if (!header_too_large && !AreHeadersValid(header_list)) { + QUIC_CODE_COUNT_N(quic_validate_request_header, 1, 2); + if (GetQuicReloadableFlag(quic_act_upon_invalid_header)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_act_upon_invalid_header); + OnInvalidHeaders(); + return; + } + } + QUIC_CODE_COUNT_N(quic_validate_request_header, 2, 2); + + if (!GetQuicReloadableFlag(quic_verify_request_headers_2) || + !header_too_large) { + MaybeProcessReceivedWebTransportHeaders(); + } + + if (VersionUsesHttp3(transport_version())) { + if (fin) { + OnStreamFrame(QuicStreamFrame(id(), /* fin = */ true, + highest_received_byte_offset(), + absl::string_view())); + } + return; + } + + if (fin && !rst_sent()) { + OnStreamFrame( + QuicStreamFrame(id(), fin, /* offset = */ 0, absl::string_view())); + } + if (FinishedReadingHeaders()) { + sequencer()->SetUnblocked(); + } +} + +void QuicSpdyStream::OnPromiseHeaderList( + QuicStreamId /* promised_id */, size_t /* frame_len */, + const QuicHeaderList& /*header_list */) { + // To be overridden in QuicSpdyClientStream. Not supported on + // server side. + stream_delegate()->OnStreamError(QUIC_INVALID_HEADERS_STREAM_DATA, + "Promise headers received by server"); +} + +bool QuicSpdyStream::CopyAndValidateTrailers(const QuicHeaderList& header_list, + bool expect_final_byte_offset, + size_t* final_byte_offset, + spdy::Http2HeaderBlock* trailers) { + return SpdyUtils::CopyAndValidateTrailers( + header_list, expect_final_byte_offset, final_byte_offset, trailers); +} + +void QuicSpdyStream::OnTrailingHeadersComplete( + bool fin, size_t /*frame_len*/, const QuicHeaderList& header_list) { + // TODO(b/134706391): remove |fin| argument. + QUICHE_DCHECK(!trailers_decompressed_); + if (!VersionUsesHttp3(transport_version()) && fin_received()) { + QUIC_DLOG(INFO) << ENDPOINT + << "Received Trailers after FIN, on stream: " << id(); + stream_delegate()->OnStreamError(QUIC_INVALID_HEADERS_STREAM_DATA, + "Trailers after fin"); + return; + } + + if (!VersionUsesHttp3(transport_version()) && !fin) { + QUIC_DLOG(INFO) << ENDPOINT + << "Trailers must have FIN set, on stream: " << id(); + stream_delegate()->OnStreamError(QUIC_INVALID_HEADERS_STREAM_DATA, + "Fin missing from trailers"); + return; + } + + size_t final_byte_offset = 0; + const bool expect_final_byte_offset = !VersionUsesHttp3(transport_version()); + if (!CopyAndValidateTrailers(header_list, expect_final_byte_offset, + &final_byte_offset, &received_trailers_)) { + QUIC_DLOG(ERROR) << ENDPOINT << "Trailers for stream " << id() + << " are malformed."; + stream_delegate()->OnStreamError(QUIC_INVALID_HEADERS_STREAM_DATA, + "Trailers are malformed"); + return; + } + trailers_decompressed_ = true; + if (fin) { + const QuicStreamOffset offset = VersionUsesHttp3(transport_version()) + ? highest_received_byte_offset() + : final_byte_offset; + OnStreamFrame(QuicStreamFrame(id(), fin, offset, absl::string_view())); + } +} + +void QuicSpdyStream::OnPriorityFrame( + const spdy::SpdyStreamPrecedence& precedence) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, + session()->connection()->perspective()); + if (session()->priority_type() != QuicPriorityType::kHttp) { + return; + } + SetPriority(QuicStreamPriority(HttpStreamPriority{ + precedence.spdy3_priority(), HttpStreamPriority::kDefaultIncremental})); +} + +void QuicSpdyStream::OnStreamReset(const QuicRstStreamFrame& frame) { + if (web_transport_data_ != nullptr) { + WebTransportStreamVisitor* webtransport_visitor = + web_transport_data_->adapter.visitor(); + if (webtransport_visitor != nullptr) { + webtransport_visitor->OnResetStreamReceived( + Http3ErrorToWebTransportOrDefault(frame.ietf_error_code)); + } + QuicStream::OnStreamReset(frame); + return; + } + + if (VersionUsesHttp3(transport_version()) && !fin_received() && + spdy_session_->qpack_decoder()) { + spdy_session_->qpack_decoder()->OnStreamReset(id()); + qpack_decoded_headers_accumulator_.reset(); + } + + if (VersionUsesHttp3(transport_version()) || + frame.error_code != QUIC_STREAM_NO_ERROR) { + QuicStream::OnStreamReset(frame); + return; + } + + QUIC_DVLOG(1) << ENDPOINT + << "Received QUIC_STREAM_NO_ERROR, not discarding response"; + set_rst_received(true); + MaybeIncreaseHighestReceivedOffset(frame.byte_offset); + set_stream_error(frame.error()); + CloseWriteSide(); +} + +void QuicSpdyStream::ResetWithError(QuicResetStreamError error) { + if (VersionUsesHttp3(transport_version()) && !fin_received() && + spdy_session_->qpack_decoder() && web_transport_data_ == nullptr) { + spdy_session_->qpack_decoder()->OnStreamReset(id()); + qpack_decoded_headers_accumulator_.reset(); + } + + QuicStream::ResetWithError(error); +} + +bool QuicSpdyStream::OnStopSending(QuicResetStreamError error) { + if (web_transport_data_ != nullptr) { + WebTransportStreamVisitor* visitor = web_transport_data_->adapter.visitor(); + if (visitor != nullptr) { + visitor->OnStopSendingReceived( + Http3ErrorToWebTransportOrDefault(error.ietf_application_code())); + } + } + + return QuicStream::OnStopSending(error); +} + +void QuicSpdyStream::OnWriteSideInDataRecvdState() { + if (web_transport_data_ != nullptr) { + WebTransportStreamVisitor* visitor = web_transport_data_->adapter.visitor(); + if (visitor != nullptr) { + visitor->OnWriteSideInDataRecvdState(); + } + } + + QuicStream::OnWriteSideInDataRecvdState(); +} + +void QuicSpdyStream::OnDataAvailable() { + if (!VersionUsesHttp3(transport_version())) { + // Sequencer must be blocked until headers are consumed. + QUICHE_DCHECK(FinishedReadingHeaders()); + } + + if (!VersionUsesHttp3(transport_version())) { + HandleBodyAvailable(); + return; + } + + if (web_transport_data_ != nullptr) { + web_transport_data_->adapter.OnDataAvailable(); + return; + } + + if (!spdy_session()->ShouldProcessIncomingRequests()) { + spdy_session()->OnStreamWaitingForClientSettings(id()); + return; + } + + if (is_decoder_processing_input_) { + // Let the outermost nested OnDataAvailable() call do the work. + return; + } + + if (blocked_on_decoding_headers_) { + return; + } + + iovec iov; + while (session()->connection()->connected() && !reading_stopped() && + decoder_.error() == QUIC_NO_ERROR) { + QUICHE_DCHECK_GE(sequencer_offset_, sequencer()->NumBytesConsumed()); + if (!sequencer()->PeekRegion(sequencer_offset_, &iov)) { + break; + } + + QUICHE_DCHECK(!sequencer()->IsClosed()); + is_decoder_processing_input_ = true; + QuicByteCount processed_bytes = decoder_.ProcessInput( + reinterpret_cast(iov.iov_base), iov.iov_len); + is_decoder_processing_input_ = false; + if (!session()->connection()->connected()) { + return; + } + sequencer_offset_ += processed_bytes; + if (blocked_on_decoding_headers_) { + return; + } + if (web_transport_data_ != nullptr) { + return; + } + } + + // Do not call HandleBodyAvailable() until headers are consumed. + if (!FinishedReadingHeaders()) { + return; + } + + if (body_manager_.HasBytesToRead()) { + HandleBodyAvailable(); + return; + } + + if (sequencer()->IsClosed() && + !on_body_available_called_because_sequencer_is_closed_) { + on_body_available_called_because_sequencer_is_closed_ = true; + HandleBodyAvailable(); + } +} + +void QuicSpdyStream::OnClose() { + QuicStream::OnClose(); + + qpack_decoded_headers_accumulator_.reset(); + + if (visitor_) { + Visitor* visitor = visitor_; + // Calling Visitor::OnClose() may result the destruction of the visitor, + // so we need to ensure we don't call it again. + visitor_ = nullptr; + visitor->OnClose(this); + } + + if (web_transport_ != nullptr) { + web_transport_->OnConnectStreamClosing(); + } + if (web_transport_data_ != nullptr) { + WebTransportHttp3* web_transport = + spdy_session_->GetWebTransportSession(web_transport_data_->session_id); + if (web_transport == nullptr) { + // Since there is no guaranteed destruction order for streams, the session + // could be already removed from the stream map by the time we reach here. + QUIC_DLOG(WARNING) << ENDPOINT << "WebTransport stream " << id() + << " attempted to notify parent session " + << web_transport_data_->session_id + << ", but the session could not be found."; + return; + } + web_transport->OnStreamClosed(id()); + } +} + +void QuicSpdyStream::OnCanWrite() { + QuicStream::OnCanWrite(); + + // Trailers (and hence a FIN) may have been sent ahead of queued body bytes. + if (!HasBufferedData() && fin_sent()) { + CloseWriteSide(); + } +} + +bool QuicSpdyStream::FinishedReadingHeaders() const { + return headers_decompressed_ && header_list_.empty(); +} + +bool QuicSpdyStream::ParseHeaderStatusCode(const Http2HeaderBlock& header, + int* status_code) { + Http2HeaderBlock::const_iterator it = header.find(spdy::kHttp2StatusHeader); + if (it == header.end()) { + return false; + } + const absl::string_view status(it->second); + return ParseHeaderStatusCode(status, status_code); +} + +bool QuicSpdyStream::ParseHeaderStatusCode(absl::string_view status, + int* status_code) { + if (status.size() != 3) { + return false; + } + // First character must be an integer in range [1,5]. + if (status[0] < '1' || status[0] > '5') { + return false; + } + // The remaining two characters must be integers. + if (!isdigit(status[1]) || !isdigit(status[2])) { + return false; + } + return absl::SimpleAtoi(status, status_code); +} + +bool QuicSpdyStream::FinishedReadingTrailers() const { + // If no further trailing headers are expected, and the decompressed trailers + // (if any) have been consumed, then reading of trailers is finished. + if (!fin_received()) { + return false; + } else if (!trailers_decompressed_) { + return true; + } else { + return trailers_consumed_; + } +} + +bool QuicSpdyStream::OnDataFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnDataFrameReceived(id(), payload_length); + } + + if (!headers_decompressed_ || trailers_decompressed_) { + stream_delegate()->OnStreamError( + QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM, + "Unexpected DATA frame received."); + return false; + } + + sequencer()->MarkConsumed(body_manager_.OnNonBody(header_length)); + + return true; +} + +bool QuicSpdyStream::OnDataFramePayload(absl::string_view payload) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + body_manager_.OnBody(payload); + + return true; +} + +bool QuicSpdyStream::OnDataFrameEnd() { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + QUIC_DVLOG(1) << ENDPOINT + << "Reaches the end of a data frame. Total bytes received are " + << body_manager_.total_body_bytes_received(); + return true; +} + +bool QuicSpdyStream::OnStreamFrameAcked(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_acked, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp, + QuicByteCount* newly_acked_length) { + const bool new_data_acked = QuicStream::OnStreamFrameAcked( + offset, data_length, fin_acked, ack_delay_time, receive_timestamp, + newly_acked_length); + + const QuicByteCount newly_acked_header_length = + GetNumFrameHeadersInInterval(offset, data_length); + QUICHE_DCHECK_LE(newly_acked_header_length, *newly_acked_length); + unacked_frame_headers_offsets_.Difference(offset, offset + data_length); + if (ack_listener_ != nullptr && new_data_acked) { + ack_listener_->OnPacketAcked( + *newly_acked_length - newly_acked_header_length, ack_delay_time); + } + return new_data_acked; +} + +void QuicSpdyStream::OnStreamFrameRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_retransmitted) { + QuicStream::OnStreamFrameRetransmitted(offset, data_length, + fin_retransmitted); + + const QuicByteCount retransmitted_header_length = + GetNumFrameHeadersInInterval(offset, data_length); + QUICHE_DCHECK_LE(retransmitted_header_length, data_length); + + if (ack_listener_ != nullptr) { + ack_listener_->OnPacketRetransmitted(data_length - + retransmitted_header_length); + } +} + +QuicByteCount QuicSpdyStream::GetNumFrameHeadersInInterval( + QuicStreamOffset offset, QuicByteCount data_length) const { + QuicByteCount header_acked_length = 0; + QuicIntervalSet newly_acked(offset, offset + data_length); + newly_acked.Intersection(unacked_frame_headers_offsets_); + for (const auto& interval : newly_acked) { + header_acked_length += interval.Length(); + } + return header_acked_length; +} + +bool QuicSpdyStream::OnHeadersFrameStart(QuicByteCount header_length, + QuicByteCount payload_length) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + QUICHE_DCHECK(!qpack_decoded_headers_accumulator_); + + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnHeadersFrameReceived(id(), + payload_length); + } + + headers_payload_length_ = payload_length; + + if (trailers_decompressed_) { + stream_delegate()->OnStreamError( + QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM, + "HEADERS frame received after trailing HEADERS."); + return false; + } + + sequencer()->MarkConsumed(body_manager_.OnNonBody(header_length)); + + qpack_decoded_headers_accumulator_ = + std::make_unique( + id(), spdy_session_->qpack_decoder(), this, + spdy_session_->max_inbound_header_list_size()); + + return true; +} + +bool QuicSpdyStream::OnHeadersFramePayload(absl::string_view payload) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + if (!qpack_decoded_headers_accumulator_) { + QUIC_BUG(b215142466_OnHeadersFramePayload); + OnHeaderDecodingError(QUIC_INTERNAL_ERROR, + "qpack_decoded_headers_accumulator_ is nullptr"); + return false; + } + + qpack_decoded_headers_accumulator_->Decode(payload); + + // |qpack_decoded_headers_accumulator_| is reset if an error is detected. + if (!qpack_decoded_headers_accumulator_) { + return false; + } + + sequencer()->MarkConsumed(body_manager_.OnNonBody(payload.size())); + return true; +} + +bool QuicSpdyStream::OnHeadersFrameEnd() { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + + if (!qpack_decoded_headers_accumulator_) { + QUIC_BUG(b215142466_OnHeadersFrameEnd); + OnHeaderDecodingError(QUIC_INTERNAL_ERROR, + "qpack_decoded_headers_accumulator_ is nullptr"); + return false; + } + + qpack_decoded_headers_accumulator_->EndHeaderBlock(); + + // If decoding is complete or an error is detected, then + // |qpack_decoded_headers_accumulator_| is already reset. + if (qpack_decoded_headers_accumulator_) { + blocked_on_decoding_headers_ = true; + return false; + } + + return !sequencer()->IsClosed() && !reading_stopped(); +} + +void QuicSpdyStream::OnWebTransportStreamFrameType( + QuicByteCount header_length, WebTransportSessionId session_id) { + QUIC_DVLOG(1) << ENDPOINT << " Received WEBTRANSPORT_STREAM on stream " + << id() << " for session " << session_id; + sequencer()->MarkConsumed(header_length); + + if (headers_payload_length_ > 0 || headers_decompressed_) { + QUIC_PEER_BUG(WEBTRANSPORT_STREAM received on HTTP request) + << ENDPOINT << "Stream " << id() + << " tried to convert to WebTransport, but it already " + "has HTTP data on it"; + Reset(QUIC_STREAM_FRAME_UNEXPECTED); + } + if (QuicUtils::IsOutgoingStreamId(spdy_session_->version(), id(), + spdy_session_->perspective())) { + QUIC_PEER_BUG(WEBTRANSPORT_STREAM received on outgoing request) + << ENDPOINT << "Stream " << id() + << " tried to convert to WebTransport, but only the " + "initiator of the stream can do it."; + Reset(QUIC_STREAM_FRAME_UNEXPECTED); + } + + QUICHE_DCHECK(web_transport_ == nullptr); + web_transport_data_ = + std::make_unique(this, session_id); + spdy_session_->AssociateIncomingWebTransportStreamWithSession(session_id, + id()); +} + +bool QuicSpdyStream::OnUnknownFrameStart(uint64_t frame_type, + QuicByteCount header_length, + QuicByteCount payload_length) { + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnUnknownFrameReceived(id(), frame_type, + payload_length); + } + spdy_session_->OnUnknownFrameStart(id(), frame_type, header_length, + payload_length); + + // Consume the frame header. + QUIC_DVLOG(1) << ENDPOINT << "Consuming " << header_length + << " byte long frame header of frame of unknown type " + << frame_type << "."; + sequencer()->MarkConsumed(body_manager_.OnNonBody(header_length)); + return true; +} + +bool QuicSpdyStream::OnUnknownFramePayload(absl::string_view payload) { + spdy_session_->OnUnknownFramePayload(id(), payload); + + // Consume the frame payload. + QUIC_DVLOG(1) << ENDPOINT << "Consuming " << payload.size() + << " bytes of payload of frame of unknown type."; + sequencer()->MarkConsumed(body_manager_.OnNonBody(payload.size())); + return true; +} + +bool QuicSpdyStream::OnUnknownFrameEnd() { return true; } + +size_t QuicSpdyStream::WriteHeadersImpl( + spdy::Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener) { + if (!VersionUsesHttp3(transport_version())) { + return spdy_session_->WriteHeadersOnHeadersStream( + id(), std::move(header_block), fin, + spdy::SpdyStreamPrecedence(priority().http().urgency), + std::move(ack_listener)); + } + + // Encode header list. + QuicByteCount encoder_stream_sent_byte_count; + std::string encoded_headers = + spdy_session_->qpack_encoder()->EncodeHeaderList( + id(), header_block, &encoder_stream_sent_byte_count); + + if (spdy_session_->debug_visitor()) { + spdy_session_->debug_visitor()->OnHeadersFrameSent(id(), header_block); + } + + // Write HEADERS frame. + std::string headers_frame_header = + HttpEncoder::SerializeHeadersFrameHeader(encoded_headers.size()); + unacked_frame_headers_offsets_.Add( + send_buffer().stream_offset(), + send_buffer().stream_offset() + headers_frame_header.length()); + + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() + << " is writing HEADERS frame header of length " + << headers_frame_header.length() << ", and payload of length " + << encoded_headers.length() << " with fin " << fin; + WriteOrBufferData(absl::StrCat(headers_frame_header, encoded_headers), fin, + /*ack_listener=*/nullptr); + + QuicSpdySession::LogHeaderCompressionRatioHistogram( + /* using_qpack = */ true, + /* is_sent = */ true, + encoded_headers.size() + encoder_stream_sent_byte_count, + header_block.TotalBytesUsed()); + + return encoded_headers.size(); +} + +bool QuicSpdyStream::CanWriteNewBodyData(QuicByteCount write_size) const { + QUICHE_DCHECK_NE(0u, write_size); + if (!VersionUsesHttp3(transport_version())) { + return CanWriteNewData(); + } + + return CanWriteNewDataAfterData( + HttpEncoder::GetDataFrameHeaderLength(write_size)); +} + +void QuicSpdyStream::MaybeProcessReceivedWebTransportHeaders() { + if (!spdy_session_->SupportsWebTransport()) { + return; + } + if (session()->perspective() != Perspective::IS_SERVER) { + return; + } + QUICHE_DCHECK(IsValidWebTransportSessionId(id(), version())); + + std::string method; + std::string protocol; + for (const auto& [header_name, header_value] : header_list_) { + if (header_name == ":method") { + if (!method.empty() || header_value.empty()) { + return; + } + method = header_value; + } + if (header_name == ":protocol") { + if (!protocol.empty() || header_value.empty()) { + return; + } + protocol = header_value; + } + if (header_name == "datagram-flow-id") { + QUIC_DLOG(ERROR) << ENDPOINT + << "Rejecting WebTransport due to unexpected " + "Datagram-Flow-Id header"; + return; + } + if (header_name == "sec-webtransport-http3-draft02") { + if (header_value != "1") { + QUIC_DLOG(ERROR) << ENDPOINT + << "Rejecting WebTransport due to invalid value of " + "Sec-Webtransport-Http3-Draft02 header"; + return; + } + } + } + + if (method != "CONNECT" || protocol != "webtransport") { + return; + } + + web_transport_ = + std::make_unique(spdy_session_, this, id()); +} + +void QuicSpdyStream::MaybeProcessSentWebTransportHeaders( + spdy::Http2HeaderBlock& headers) { + if (!spdy_session_->SupportsWebTransport()) { + return; + } + if (session()->perspective() != Perspective::IS_CLIENT) { + return; + } + QUICHE_DCHECK(IsValidWebTransportSessionId(id(), version())); + + const auto method_it = headers.find(":method"); + const auto protocol_it = headers.find(":protocol"); + if (method_it == headers.end() || protocol_it == headers.end()) { + return; + } + if (method_it->second != "CONNECT" && protocol_it->second != "webtransport") { + return; + } + + headers["sec-webtransport-http3-draft02"] = "1"; + + web_transport_ = + std::make_unique(spdy_session_, this, id()); +} + +void QuicSpdyStream::OnCanWriteNewData() { + if (web_transport_data_ != nullptr) { + web_transport_data_->adapter.OnCanWriteNewData(); + } +} + +bool QuicSpdyStream::AssertNotWebTransportDataStream( + absl::string_view operation) { + if (web_transport_data_ != nullptr) { + QUIC_BUG(Invalid operation on WebTransport stream) + << "Attempted to " << operation << " on WebTransport data stream " + << id() << " associated with session " + << web_transport_data_->session_id; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + absl::StrCat("Attempted to ", operation, + " on WebTransport data stream")); + return false; + } + return true; +} + +void QuicSpdyStream::ConvertToWebTransportDataStream( + WebTransportSessionId session_id) { + if (send_buffer().stream_offset() != 0) { + QUIC_BUG(Sending WEBTRANSPORT_STREAM when data already sent) + << "Attempted to send a WEBTRANSPORT_STREAM frame when other data has " + "already been sent on the stream."; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Attempted to send a WEBTRANSPORT_STREAM frame when " + "other data has already been sent on the stream."); + return; + } + + std::string header = + HttpEncoder::SerializeWebTransportStreamFrameHeader(session_id); + if (header.empty()) { + QUIC_BUG(Failed to serialize WEBTRANSPORT_STREAM) + << "Failed to serialize a WEBTRANSPORT_STREAM frame."; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Failed to serialize a WEBTRANSPORT_STREAM frame."); + return; + } + + WriteOrBufferData(header, /*fin=*/false, nullptr); + web_transport_data_ = + std::make_unique(this, session_id); + QUIC_DVLOG(1) << ENDPOINT << "Successfully opened WebTransport data stream " + << id() << " for session " << session_id; +} + +QuicSpdyStream::WebTransportDataStream::WebTransportDataStream( + QuicSpdyStream* stream, WebTransportSessionId session_id) + : session_id(session_id), + adapter(stream->spdy_session_, stream, stream->sequencer()) {} + +void QuicSpdyStream::HandleReceivedDatagram(absl::string_view payload) { + if (datagram_visitor_ == nullptr) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received datagram without any visitor"; + return; + } + datagram_visitor_->OnHttp3Datagram(id(), payload); +} + +bool QuicSpdyStream::OnCapsule(const Capsule& capsule) { + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() << " received capsule " + << capsule; + if (!headers_decompressed_) { + QUIC_PEER_BUG(capsule before headers) + << ENDPOINT << "Stream " << id() << " received capsule " << capsule + << " before headers"; + return false; + } + if (web_transport_ != nullptr && web_transport_->close_received()) { + QUIC_PEER_BUG(capsule after close) + << ENDPOINT << "Stream " << id() << " received capsule " << capsule + << " after CLOSE_WEBTRANSPORT_SESSION."; + return false; + } + switch (capsule.capsule_type()) { + case CapsuleType::DATAGRAM: + HandleReceivedDatagram(capsule.datagram_capsule().http_datagram_payload); + return true; + case CapsuleType::LEGACY_DATAGRAM: + HandleReceivedDatagram( + capsule.legacy_datagram_capsule().http_datagram_payload); + return true; + case CapsuleType::LEGACY_DATAGRAM_WITHOUT_CONTEXT: + HandleReceivedDatagram(capsule.legacy_datagram_without_context_capsule() + .http_datagram_payload); + return true; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + if (web_transport_ == nullptr) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received capsule " << capsule + << " for a non-WebTransport stream."; + return false; + } + web_transport_->OnCloseReceived( + capsule.close_web_transport_session_capsule().error_code, + capsule.close_web_transport_session_capsule().error_message); + return true; + case CapsuleType::ADDRESS_ASSIGN: + if (connect_ip_visitor_ == nullptr) { + return true; + } + return connect_ip_visitor_->OnAddressAssignCapsule( + capsule.address_assign_capsule()); + case CapsuleType::ADDRESS_REQUEST: + if (connect_ip_visitor_ == nullptr) { + return true; + } + return connect_ip_visitor_->OnAddressRequestCapsule( + capsule.address_request_capsule()); + case CapsuleType::ROUTE_ADVERTISEMENT: + if (connect_ip_visitor_ == nullptr) { + return true; + } + return connect_ip_visitor_->OnRouteAdvertisementCapsule( + capsule.route_advertisement_capsule()); + + // Ignore WebTransport over HTTP/2 capsules. + case CapsuleType::WT_RESET_STREAM: + case CapsuleType::WT_STOP_SENDING: + case CapsuleType::WT_STREAM: + case CapsuleType::WT_STREAM_WITH_FIN: + case CapsuleType::WT_MAX_STREAM_DATA: + case CapsuleType::WT_MAX_STREAMS_BIDI: + case CapsuleType::WT_MAX_STREAMS_UNIDI: + return true; + } + if (datagram_visitor_) { + datagram_visitor_->OnUnknownCapsule(id(), capsule.unknown_capsule()); + } + return true; +} + +void QuicSpdyStream::OnCapsuleParseFailure(absl::string_view error_message) { + QUIC_DLOG(ERROR) << ENDPOINT << "Capsule parse failure: " << error_message; + Reset(QUIC_BAD_APPLICATION_PAYLOAD); +} + +void QuicSpdyStream::WriteCapsule(const Capsule& capsule, bool fin) { + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() << " sending capsule " + << capsule; + quiche::QuicheBuffer serialized_capsule = SerializeCapsule( + capsule, + spdy_session_->connection()->helper()->GetStreamSendBufferAllocator()); + QUICHE_DCHECK_GT(serialized_capsule.size(), 0u); + WriteOrBufferBody(serialized_capsule.AsStringView(), /*fin=*/fin); +} + +void QuicSpdyStream::WriteGreaseCapsule() { + // GREASE capsulde IDs have a form of 41 * N + 23. + QuicRandom* random = spdy_session_->connection()->random_generator(); + uint64_t type = random->InsecureRandUint64() >> 4; + type = (type / 41) * 41 + 23; + QUICHE_DCHECK_EQ((type - 23) % 41, 0u); + + constexpr size_t kMaxLength = 64; + size_t length = random->InsecureRandUint64() % kMaxLength; + std::string bytes(length, '\0'); + random->InsecureRandBytes(&bytes[0], bytes.size()); + Capsule capsule = Capsule::Unknown(type, bytes); + WriteCapsule(capsule, /*fin=*/false); +} + +MessageStatus QuicSpdyStream::SendHttp3Datagram(absl::string_view payload) { + return spdy_session_->SendHttp3Datagram(id(), payload); +} + +void QuicSpdyStream::RegisterHttp3DatagramVisitor( + Http3DatagramVisitor* visitor) { + if (visitor == nullptr) { + QUIC_BUG(null datagram visitor) + << ENDPOINT << "Null datagram visitor for stream ID " << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Registering datagram visitor with stream ID " + << id(); + + if (datagram_visitor_ != nullptr) { + QUIC_BUG(h3 datagram double registration) + << ENDPOINT + << "Attempted to doubly register HTTP/3 datagram with stream ID " + << id(); + return; + } + datagram_visitor_ = visitor; + QUICHE_DCHECK(!capsule_parser_); + capsule_parser_ = std::make_unique(this); +} + +void QuicSpdyStream::UnregisterHttp3DatagramVisitor() { + if (datagram_visitor_ == nullptr) { + QUIC_BUG(datagram visitor empty during unregistration) + << ENDPOINT << "Cannot unregister datagram visitor for stream ID " + << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Unregistering datagram visitor for stream ID " + << id(); + datagram_visitor_ = nullptr; +} + +void QuicSpdyStream::ReplaceHttp3DatagramVisitor( + Http3DatagramVisitor* visitor) { + QUIC_BUG_IF(h3 datagram unknown move, datagram_visitor_ == nullptr) + << "Attempted to move missing datagram visitor on HTTP/3 stream ID " + << id(); + datagram_visitor_ = visitor; +} + +void QuicSpdyStream::RegisterConnectIpVisitor(ConnectIpVisitor* visitor) { + if (visitor == nullptr) { + QUIC_BUG(null connect - ip visitor) + << ENDPOINT << "Null connect-ip visitor for stream ID " << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT + << "Registering CONNECT-IP visitor with stream ID " << id(); + + if (connect_ip_visitor_ != nullptr) { + QUIC_BUG(connect - ip double registration) + << ENDPOINT << "Attempted to doubly register CONNECT-IP with stream ID " + << id(); + return; + } + connect_ip_visitor_ = visitor; +} + +void QuicSpdyStream::UnregisterConnectIpVisitor() { + if (connect_ip_visitor_ == nullptr) { + QUIC_BUG(connect - ip visitor empty during unregistration) + << ENDPOINT << "Cannot unregister CONNECT-IP visitor for stream ID " + << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT + << "Unregistering CONNECT-IP visitor for stream ID " << id(); + connect_ip_visitor_ = nullptr; +} + +void QuicSpdyStream::ReplaceConnectIpVisitor(ConnectIpVisitor* visitor) { + QUIC_BUG_IF(connect - ip unknown move, connect_ip_visitor_ == nullptr) + << "Attempted to move missing CONNECT-IP visitor on HTTP/3 stream ID " + << id(); + connect_ip_visitor_ = visitor; +} + +void QuicSpdyStream::SetMaxDatagramTimeInQueue( + QuicTime::Delta max_time_in_queue) { + spdy_session_->SetMaxDatagramTimeInQueueForStreamId(id(), max_time_in_queue); +} + +void QuicSpdyStream::OnDatagramReceived(QuicDataReader* reader) { + if (!headers_decompressed_) { + QUIC_DLOG(INFO) << "Dropping datagram received before headers on stream ID " + << id(); + return; + } + HandleReceivedDatagram(reader->ReadRemainingPayload()); +} + +QuicByteCount QuicSpdyStream::GetMaxDatagramSize() const { + QuicByteCount prefix_size = 0; + switch (spdy_session_->http_datagram_support()) { + case HttpDatagramSupport::kDraft04: + case HttpDatagramSupport::kRfc: + prefix_size = + QuicDataWriter::GetVarInt62Len(id() / kHttpDatagramStreamIdDivisor); + break; + case HttpDatagramSupport::kNone: + case HttpDatagramSupport::kRfcAndDraft04: + QUIC_BUG(GetMaxDatagramSize called with no datagram support) + << "GetMaxDatagramSize() called when no HTTP/3 datagram support has " + "been negotiated. Support value: " + << spdy_session_->http_datagram_support(); + break; + } + // If the logic above fails, use the largest possible value as the safe one. + if (prefix_size == 0) { + prefix_size = 8; + } + + QuicByteCount max_datagram_size = + session()->GetGuaranteedLargestMessagePayload(); + if (max_datagram_size < prefix_size) { + QUIC_BUG(max_datagram_size smaller than prefix_size) + << "GetGuaranteedLargestMessagePayload() returned a datagram size that " + "is not sufficient to fit stream ID into it."; + return 0; + } + return max_datagram_size - prefix_size; +} + +void QuicSpdyStream::HandleBodyAvailable() { + if (!capsule_parser_) { + OnBodyAvailable(); + return; + } + while (body_manager_.HasBytesToRead()) { + iovec iov; + int num_iov = GetReadableRegions(&iov, /*iov_len=*/1); + if (num_iov == 0) { + break; + } + if (!capsule_parser_->IngestCapsuleFragment(absl::string_view( + reinterpret_cast(iov.iov_base), iov.iov_len))) { + break; + } + MarkConsumed(iov.iov_len); + } + // If we received a FIN, make sure that there isn't a partial capsule buffered + // in the capsule parser. + if (sequencer()->IsClosed()) { + capsule_parser_->ErrorIfThereIsRemainingBufferedData(); + if (web_transport_ != nullptr) { + web_transport_->OnConnectStreamFinReceived(); + } + OnFinRead(); + } +} + +namespace { +// Return true if |c| is not allowed in an HTTP/3 wire-encoded header and +// pseudo-header names according to +// https://datatracker.ietf.org/doc/html/draft-ietf-quic-http#section-4.1.1 and +// https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-semantics-19#section-5.6.2 +constexpr bool isInvalidHeaderNameCharacter(unsigned char c) { + if (c == '!' || c == '|' || c == '~' || c == '*' || c == '+' || c == '-' || + c == '.' || + // #, $, %, &, ' + (c >= '#' && c <= '\'') || + // [0-9], : + (c >= '0' && c <= ':') || + // ^, _, `, [a-z] + (c >= '^' && c <= 'z')) { + return false; + } + return true; +} +} // namespace + +bool QuicSpdyStream::AreHeadersValid(const QuicHeaderList& header_list) const { + QUICHE_DCHECK(GetQuicReloadableFlag(quic_verify_request_headers_2)); + for (const std::pair& pair : header_list) { + const std::string& name = pair.first; + if (std::any_of(name.begin(), name.end(), isInvalidHeaderNameCharacter)) { + QUIC_DLOG(ERROR) << "Invalid request header " << name; + return false; + } + if (http2::GetInvalidHttp2HeaderSet().contains(name)) { + QUIC_DLOG(ERROR) << name << " header is not allowed"; + return false; + } + } + return true; +} + +bool QuicSpdyStream::AreHeaderFieldValuesValid( + const QuicHeaderList& header_list) const { + if (!VersionUsesHttp3(transport_version())) { + return true; + } + // According to https://www.rfc-editor.org/rfc/rfc9114.html#section-10.3 + // "[...] HTTP/3 can transport field values that are not valid. While most + // values that can be encoded will not alter field parsing, carriage return + // (ASCII 0x0d), line feed (ASCII 0x0a), and the null character (ASCII 0x00) + // might be exploited by an attacker if they are translated verbatim. Any + // request or response that contains a character not permitted in a field + // value MUST be treated as malformed. + // [...]" + for (const std::pair& pair : header_list) { + const std::string& value = pair.second; + for (const auto c : value) { + if (c == '\0' || c == '\n' || c == '\r') { + return false; + } + } + } + return true; +} + +void QuicSpdyStream::OnInvalidHeaders() { Reset(QUIC_BAD_APPLICATION_PAYLOAD); } + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_stream.h b/quiche/quic/core/http/quic_spdy_stream.h new file mode 100644 index 000000000000..301df2457fdc --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_stream.h @@ -0,0 +1,500 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// The base class for streams which deliver data to/from an application. +// In each direction, the data on such a stream first contains compressed +// headers then body data. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_STREAM_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_STREAM_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/core/http/http_decoder.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/http/quic_spdy_stream_body_manager.h" +#include "quiche/quic/core/http/web_transport_stream_adapter.h" +#include "quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_stream_sequencer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/capsule.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +namespace test { +class QuicSpdyStreamPeer; +class QuicStreamPeer; +} // namespace test + +class QuicSpdySession; +class WebTransportHttp3; + +// A QUIC stream that can send and receive HTTP2 (SPDY) headers. +class QUIC_EXPORT_PRIVATE QuicSpdyStream + : public QuicStream, + public quiche::CapsuleParser::Visitor, + public QpackDecodedHeadersAccumulator::Visitor { + public: + // Visitor receives callbacks from the stream. + class QUIC_EXPORT_PRIVATE Visitor { + public: + Visitor() {} + Visitor(const Visitor&) = delete; + Visitor& operator=(const Visitor&) = delete; + + // Called when the stream is closed. + virtual void OnClose(QuicSpdyStream* stream) = 0; + + // Allows subclasses to override and do work. + virtual void OnPromiseHeadersComplete(QuicStreamId /*promised_id*/, + size_t /*frame_len*/) {} + + protected: + virtual ~Visitor() {} + }; + + QuicSpdyStream(QuicStreamId id, QuicSpdySession* spdy_session, + StreamType type); + QuicSpdyStream(PendingStream* pending, QuicSpdySession* spdy_session); + QuicSpdyStream(const QuicSpdyStream&) = delete; + QuicSpdyStream& operator=(const QuicSpdyStream&) = delete; + ~QuicSpdyStream() override; + + // QuicStream implementation + void OnClose() override; + + // Override to maybe close the write side after writing. + void OnCanWrite() override; + + // Called by the session when headers with a priority have been received + // for this stream. This method will only be called for server streams. + virtual void OnStreamHeadersPriority( + const spdy::SpdyStreamPrecedence& precedence); + + // Called by the session when decompressed headers have been completely + // delivered to this stream. If |fin| is true, then this stream + // should be closed; no more data will be sent by the peer. + virtual void OnStreamHeaderList(bool fin, size_t frame_len, + const QuicHeaderList& header_list); + + // Called by the session when decompressed push promise headers have + // been completely delivered to this stream. + virtual void OnPromiseHeaderList(QuicStreamId promised_id, size_t frame_len, + const QuicHeaderList& header_list); + + // Called by the session when a PRIORITY frame has been been received for this + // stream. This method will only be called for server streams. + void OnPriorityFrame(const spdy::SpdyStreamPrecedence& precedence); + + // Override the base class to not discard response when receiving + // QUIC_STREAM_NO_ERROR. + void OnStreamReset(const QuicRstStreamFrame& frame) override; + void ResetWithError(QuicResetStreamError error) override; + bool OnStopSending(QuicResetStreamError error) override; + + // Called by the sequencer when new data is available. Decodes the data and + // calls OnBodyAvailable() to pass to the upper layer. + void OnDataAvailable() override; + + // Called in OnDataAvailable() after it finishes the decoding job. + virtual void OnBodyAvailable() = 0; + + // Writes the headers contained in |header_block| on the dedicated headers + // stream or on this stream, depending on VersionUsesHttp3(). Returns the + // number of bytes sent, including data sent on the encoder stream when using + // QPACK. + virtual size_t WriteHeaders( + spdy::Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener); + + // Sends |data| to the peer, or buffers if it can't be sent immediately. + virtual void WriteOrBufferBody(absl::string_view data, bool fin); + + // Writes the trailers contained in |trailer_block| on the dedicated headers + // stream or on this stream, depending on VersionUsesHttp3(). Trailers will + // always have the FIN flag set. Returns the number of bytes sent, including + // data sent on the encoder stream when using QPACK. + virtual size_t WriteTrailers( + spdy::Http2HeaderBlock trailer_block, + quiche::QuicheReferenceCountedPointer + ack_listener); + + // Override to report newly acked bytes via ack_listener_. + bool OnStreamFrameAcked(QuicStreamOffset offset, QuicByteCount data_length, + bool fin_acked, QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp, + QuicByteCount* newly_acked_length) override; + + // Override to report bytes retransmitted via ack_listener_. + void OnStreamFrameRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_retransmitted) override; + + // Does the same thing as WriteOrBufferBody except this method takes iovec + // as the data input. Right now it only calls WritevData. + QuicConsumedData WritevBody(const struct iovec* iov, int count, bool fin); + + // Does the same thing as WriteOrBufferBody except this method takes + // memslicespan as the data input. Right now it only calls WriteMemSlices. + QuicConsumedData WriteBodySlices(absl::Span slices, + bool fin); + + // Marks the trailers as consumed. This applies to the case where this object + // receives headers and trailers as QuicHeaderLists via calls to + // OnStreamHeaderList(). Trailer data will be consumed from the sequencer only + // once all body data has been consumed. + void MarkTrailersConsumed(); + + // Clears |header_list_|. + void ConsumeHeaderList(); + + // This block of functions wraps the sequencer's functions of the same + // name. These methods return uncompressed data until that has + // been fully processed. Then they simply delegate to the sequencer. + virtual size_t Readv(const struct iovec* iov, size_t iov_len); + virtual int GetReadableRegions(iovec* iov, size_t iov_len) const; + void MarkConsumed(size_t num_bytes); + + // Returns true if header contains a valid 3-digit status and parse the status + // code to |status_code|. + static bool ParseHeaderStatusCode(const spdy::Http2HeaderBlock& header, + int* status_code); + // Returns true if status_value (associated with :status) contains a valid + // 3-digit status and parse the status code to |status_code|. + static bool ParseHeaderStatusCode(absl::string_view status_value, + int* status_code); + + // Returns true when all data from the peer has been read and consumed, + // including the fin. + bool IsDoneReading() const; + bool HasBytesToRead() const; + + void set_visitor(Visitor* visitor) { visitor_ = visitor; } + + bool headers_decompressed() const { return headers_decompressed_; } + + // Returns total amount of body bytes that have been read. + uint64_t total_body_bytes_read() const; + + const QuicHeaderList& header_list() const { return header_list_; } + + bool trailers_decompressed() const { return trailers_decompressed_; } + + // Returns whatever trailers have been received for this stream. + const spdy::Http2HeaderBlock& received_trailers() const { + return received_trailers_; + } + + // Returns true if headers have been fully read and consumed. + bool FinishedReadingHeaders() const; + + // Returns true if FIN has been received and either trailers have been fully + // read and consumed or there are no trailers. + bool FinishedReadingTrailers() const; + + // Returns true if the sequencer has delivered the FIN, and no more body bytes + // will be available. + bool IsSequencerClosed() { return sequencer()->IsClosed(); } + + // QpackDecodedHeadersAccumulator::Visitor implementation. + void OnHeadersDecoded(QuicHeaderList headers, + bool header_list_size_limit_exceeded) override; + void OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) override; + + QuicSpdySession* spdy_session() const { return spdy_session_; } + + // Send PRIORITY_UPDATE frame and update |last_sent_priority_| if + // |last_sent_priority_| is different from current priority. + void MaybeSendPriorityUpdateFrame() override; + + // Returns the WebTransport session owned by this stream, if one exists. + WebTransportHttp3* web_transport() { return web_transport_.get(); } + + // Returns the WebTransport data stream associated with this QUIC stream, or + // null if this is not a WebTransport data stream. + WebTransportStream* web_transport_stream() { + if (web_transport_data_ == nullptr) { + return nullptr; + } + return &web_transport_data_->adapter; + } + + // Sends a WEBTRANSPORT_STREAM frame and sets up the appropriate metadata. + void ConvertToWebTransportDataStream(WebTransportSessionId session_id); + + void OnCanWriteNewData() override; + + // If this stream is a WebTransport data stream, closes the connection with an + // error, and returns false. + bool AssertNotWebTransportDataStream(absl::string_view operation); + + // Indicates whether a call to WriteBodySlices will be successful and not + // rejected due to buffer being full. |write_size| must be non-zero. + bool CanWriteNewBodyData(QuicByteCount write_size) const; + + // From CapsuleParser::Visitor. + bool OnCapsule(const quiche::Capsule& capsule) override; + void OnCapsuleParseFailure(absl::string_view error_message) override; + + // Sends an HTTP/3 datagram. The stream ID is not part of |payload|. Virtual + // to allow mocking in tests. + virtual MessageStatus SendHttp3Datagram(absl::string_view payload); + + class QUIC_EXPORT_PRIVATE Http3DatagramVisitor { + public: + virtual ~Http3DatagramVisitor() {} + + // Called when an HTTP/3 datagram is received. |payload| does not contain + // the stream ID. + virtual void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) = 0; + + // Called when a Capsule with an unknown type is received. + virtual void OnUnknownCapsule(QuicStreamId stream_id, + const quiche::UnknownCapsule& capsule) = 0; + }; + + // Registers |visitor| to receive HTTP/3 datagrams and enables Capsule + // Protocol by registering a CapsuleParser. |visitor| must be valid until a + // corresponding call to UnregisterHttp3DatagramVisitor. + void RegisterHttp3DatagramVisitor(Http3DatagramVisitor* visitor); + + // Unregisters an HTTP/3 datagram visitor. Must only be called after a call to + // RegisterHttp3DatagramVisitor. + void UnregisterHttp3DatagramVisitor(); + + // Replaces the current HTTP/3 datagram visitor with a different visitor. + // Mainly meant to be used by the visitors' move operators. + void ReplaceHttp3DatagramVisitor(Http3DatagramVisitor* visitor); + + class QUIC_EXPORT_PRIVATE ConnectIpVisitor { + public: + virtual ~ConnectIpVisitor() {} + + virtual bool OnAddressAssignCapsule( + const quiche::AddressAssignCapsule& capsule) = 0; + virtual bool OnAddressRequestCapsule( + const quiche::AddressRequestCapsule& capsule) = 0; + virtual bool OnRouteAdvertisementCapsule( + const quiche::RouteAdvertisementCapsule& capsule) = 0; + virtual void OnHeadersWritten() = 0; + }; + + // Registers |visitor| to receive CONNECT-IP capsules. |visitor| must be + // valid until a corresponding call to UnregisterConnectIpVisitor. + void RegisterConnectIpVisitor(ConnectIpVisitor* visitor); + + // Unregisters a CONNECT-IP visitor. Must only be called after a call to + // RegisterConnectIpVisitor. + void UnregisterConnectIpVisitor(); + + // Replaces the current CONNECT-IP visitor with a different visitor. + // Mainly meant to be used by the visitors' move operators. + void ReplaceConnectIpVisitor(ConnectIpVisitor* visitor); + + // Sets max datagram time in queue. + void SetMaxDatagramTimeInQueue(QuicTime::Delta max_time_in_queue); + + void OnDatagramReceived(QuicDataReader* reader); + + QuicByteCount GetMaxDatagramSize() const; + + // Writes |capsule| onto the DATA stream. + void WriteCapsule(const quiche::Capsule& capsule, bool fin = false); + + void WriteGreaseCapsule(); + + protected: + // Called when the received headers are too large. By default this will + // reset the stream. + virtual void OnHeadersTooLarge(); + + virtual void OnInitialHeadersComplete(bool fin, size_t frame_len, + const QuicHeaderList& header_list); + virtual void OnTrailingHeadersComplete(bool fin, size_t frame_len, + const QuicHeaderList& header_list); + virtual size_t WriteHeadersImpl( + spdy::Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener); + + virtual bool CopyAndValidateTrailers(const QuicHeaderList& header_list, + bool expect_final_byte_offset, + size_t* final_byte_offset, + spdy::Http2HeaderBlock* trailers); + + Visitor* visitor() { return visitor_; } + + void set_headers_decompressed(bool val) { headers_decompressed_ = val; } + + void set_ack_listener( + quiche::QuicheReferenceCountedPointer + ack_listener) { + ack_listener_ = std::move(ack_listener); + } + + void OnWriteSideInDataRecvdState() override; + + virtual bool AreHeadersValid(const QuicHeaderList& header_list) const; + // TODO(b/202433856) Merge AreHeaderFieldValueValid into AreHeadersValid once + // all flags guarding the behavior of AreHeadersValid has been rolled out. + virtual bool AreHeaderFieldValuesValid( + const QuicHeaderList& header_list) const; + + // Reset stream upon invalid request headers. + virtual void OnInvalidHeaders(); + + private: + friend class test::QuicSpdyStreamPeer; + friend class test::QuicStreamPeer; + friend class QuicStreamUtils; + class HttpDecoderVisitor; + + struct QUIC_EXPORT_PRIVATE WebTransportDataStream { + WebTransportDataStream(QuicSpdyStream* stream, + WebTransportSessionId session_id); + + WebTransportSessionId session_id; + WebTransportStreamAdapter adapter; + }; + + // Called by HttpDecoderVisitor. + bool OnDataFrameStart(QuicByteCount header_length, + QuicByteCount payload_length); + bool OnDataFramePayload(absl::string_view payload); + bool OnDataFrameEnd(); + bool OnHeadersFrameStart(QuicByteCount header_length, + QuicByteCount payload_length); + bool OnHeadersFramePayload(absl::string_view payload); + bool OnHeadersFrameEnd(); + void OnWebTransportStreamFrameType(QuicByteCount header_length, + WebTransportSessionId session_id); + bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length, + QuicByteCount payload_length); + bool OnUnknownFramePayload(absl::string_view payload); + bool OnUnknownFrameEnd(); + + // Given the interval marked by [|offset|, |offset| + |data_length|), return + // the number of frame header bytes contained in it. + QuicByteCount GetNumFrameHeadersInInterval(QuicStreamOffset offset, + QuicByteCount data_length) const; + + void MaybeProcessSentWebTransportHeaders(spdy::Http2HeaderBlock& headers); + void MaybeProcessReceivedWebTransportHeaders(); + + // Writes HTTP/3 DATA frame header. If |force_write| is true, use + // WriteOrBufferData if send buffer cannot accomodate the header + data. + ABSL_MUST_USE_RESULT bool WriteDataFrameHeader(QuicByteCount data_length, + bool force_write); + + // Simply calls OnBodyAvailable() unless capsules are in use, in which case + // pass the capsule fragments to the capsule manager. + void HandleBodyAvailable(); + + // Called when a datagram frame or capsule is received. + void HandleReceivedDatagram(absl::string_view payload); + + QuicSpdySession* spdy_session_; + + bool on_body_available_called_because_sequencer_is_closed_; + + Visitor* visitor_; + + // True if read side processing is blocked while waiting for callback from + // QPACK decoder. + bool blocked_on_decoding_headers_; + // True if the headers have been completely decompressed. + bool headers_decompressed_; + // True if uncompressed headers or trailers exceed maximum allowed size + // advertised to peer via SETTINGS_MAX_HEADER_LIST_SIZE. + bool header_list_size_limit_exceeded_; + // Contains a copy of the decompressed header (name, value) pairs until they + // are consumed via Readv. + QuicHeaderList header_list_; + // Length of most recently received HEADERS frame payload. + QuicByteCount headers_payload_length_; + + // True if the trailers have been completely decompressed. + bool trailers_decompressed_; + // True if the trailers have been consumed. + bool trailers_consumed_; + + // The parsed trailers received from the peer. + spdy::Http2HeaderBlock received_trailers_; + + // Headers accumulator for decoding HEADERS frame payload. + std::unique_ptr + qpack_decoded_headers_accumulator_; + // Visitor of the HttpDecoder. + std::unique_ptr http_decoder_visitor_; + // HttpDecoder for processing raw incoming stream frames. + HttpDecoder decoder_; + // Object that manages references to DATA frame payload fragments buffered by + // the sequencer and calculates how much data should be marked consumed with + // the sequencer each time new stream data is processed. + QuicSpdyStreamBodyManager body_manager_; + + std::unique_ptr capsule_parser_; + + // Sequencer offset keeping track of how much data HttpDecoder has processed. + // Initial value is zero for fresh streams, or sequencer()->NumBytesConsumed() + // at time of construction if a PendingStream is converted to account for the + // length of the unidirectional stream type at the beginning of the stream. + QuicStreamOffset sequencer_offset_; + + // True when inside an HttpDecoder::ProcessInput() call. + // Used for detecting reentrancy. + bool is_decoder_processing_input_; + + // Ack listener of this stream, and it is notified when any of written bytes + // are acked or retransmitted. + quiche::QuicheReferenceCountedPointer ack_listener_; + + // Offset of unacked frame headers. + QuicIntervalSet unacked_frame_headers_offsets_; + + // Priority parameters sent in the last PRIORITY_UPDATE frame, or default + // values defined by RFC9218 if no PRIORITY_UPDATE frame has been sent. + QuicStreamPriority last_sent_priority_; + + // If this stream is a WebTransport extended CONNECT stream, contains the + // WebTransport session associated with this stream. + std::unique_ptr web_transport_; + + // If this stream is a WebTransport data stream, |web_transport_data_| + // contains all of the associated metadata. + std::unique_ptr web_transport_data_; + + // HTTP/3 Datagram support. + Http3DatagramVisitor* datagram_visitor_ = nullptr; + // CONNECT-IP support. + ConnectIpVisitor* connect_ip_visitor_ = nullptr; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_STREAM_H_ diff --git a/quiche/quic/core/http/quic_spdy_stream_body_manager.cc b/quiche/quic/core/http/quic_spdy_stream_body_manager.cc new file mode 100644 index 000000000000..efb32c374949 --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_stream_body_manager.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_stream_body_manager.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicSpdyStreamBodyManager::QuicSpdyStreamBodyManager() + : total_body_bytes_received_(0) {} + +size_t QuicSpdyStreamBodyManager::OnNonBody(QuicByteCount length) { + QUICHE_DCHECK_NE(0u, length); + + if (fragments_.empty()) { + // Non-body bytes can be consumed immediately, because all previously + // received body bytes have been read. + return length; + } + + // Non-body bytes will be consumed after last body fragment is read. + fragments_.back().trailing_non_body_byte_count += length; + return 0; +} + +void QuicSpdyStreamBodyManager::OnBody(absl::string_view body) { + QUICHE_DCHECK(!body.empty()); + + fragments_.push_back({body, 0}); + total_body_bytes_received_ += body.length(); +} + +size_t QuicSpdyStreamBodyManager::OnBodyConsumed(size_t num_bytes) { + QuicByteCount bytes_to_consume = 0; + size_t remaining_bytes = num_bytes; + + while (remaining_bytes > 0) { + if (fragments_.empty()) { + QUIC_BUG(quic_bug_10394_1) << "Not enough available body to consume."; + return 0; + } + + Fragment& fragment = fragments_.front(); + const absl::string_view body = fragment.body; + + if (body.length() > remaining_bytes) { + // Consume leading |remaining_bytes| bytes of body. + bytes_to_consume += remaining_bytes; + fragment.body = body.substr(remaining_bytes); + return bytes_to_consume; + } + + // Consume entire fragment and the following + // |trailing_non_body_byte_count| bytes. + remaining_bytes -= body.length(); + bytes_to_consume += body.length() + fragment.trailing_non_body_byte_count; + fragments_.pop_front(); + } + + return bytes_to_consume; +} + +int QuicSpdyStreamBodyManager::PeekBody(iovec* iov, size_t iov_len) const { + QUICHE_DCHECK(iov); + QUICHE_DCHECK_GT(iov_len, 0u); + + // TODO(bnc): Is this really necessary? + if (fragments_.empty()) { + iov[0].iov_base = nullptr; + iov[0].iov_len = 0; + return 0; + } + + size_t iov_filled = 0; + while (iov_filled < fragments_.size() && iov_filled < iov_len) { + absl::string_view body = fragments_[iov_filled].body; + iov[iov_filled].iov_base = const_cast(body.data()); + iov[iov_filled].iov_len = body.size(); + iov_filled++; + } + + return iov_filled; +} + +size_t QuicSpdyStreamBodyManager::ReadBody(const struct iovec* iov, + size_t iov_len, + size_t* total_bytes_read) { + *total_bytes_read = 0; + QuicByteCount bytes_to_consume = 0; + + // The index of iovec to write to. + size_t index = 0; + // Address to write to within current iovec. + char* dest = reinterpret_cast(iov[index].iov_base); + // Remaining space in current iovec. + size_t dest_remaining = iov[index].iov_len; + + while (!fragments_.empty()) { + Fragment& fragment = fragments_.front(); + const absl::string_view body = fragment.body; + + const size_t bytes_to_copy = + std::min(body.length(), dest_remaining); + + // According to Section 7.1.4 of the C11 standard (ISO/IEC 9899:2011), null + // pointers should not be passed to standard library functions. + if (bytes_to_copy > 0) { + memcpy(dest, body.data(), bytes_to_copy); + } + + bytes_to_consume += bytes_to_copy; + *total_bytes_read += bytes_to_copy; + + if (bytes_to_copy == body.length()) { + // Entire fragment read. + bytes_to_consume += fragment.trailing_non_body_byte_count; + fragments_.pop_front(); + } else { + // Consume leading |bytes_to_copy| bytes of body. + fragment.body = body.substr(bytes_to_copy); + } + + if (bytes_to_copy == dest_remaining) { + // Current iovec full. + ++index; + if (index == iov_len) { + break; + } + dest = reinterpret_cast(iov[index].iov_base); + dest_remaining = iov[index].iov_len; + } else { + // Advance destination parameters within this iovec. + dest += bytes_to_copy; + dest_remaining -= bytes_to_copy; + } + } + + return bytes_to_consume; +} + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_stream_body_manager.h b/quiche/quic/core/http/quic_spdy_stream_body_manager.h new file mode 100644 index 000000000000..34bf4461243b --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_stream_body_manager.h @@ -0,0 +1,93 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_STREAM_BODY_MANAGER_H_ +#define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_STREAM_BODY_MANAGER_H_ + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_iovec.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +// All data that a request stream receives falls into one of two categories: +// * "body", that is, DATA frame payload, which the QuicStreamSequencer must +// buffer until it is read; +// * everything else, which QuicSpdyStream immediately processes and thus could +// be marked as consumed with QuicStreamSequencer, unless there is some piece +// of body received prior that still needs to be buffered. +// QuicSpdyStreamBodyManager does two things: it keeps references to body +// fragments (owned by QuicStreamSequencer) and offers methods to read them; and +// it calculates the total number of bytes (including non-body bytes) the caller +// needs to mark consumed (with QuicStreamSequencer) when non-body bytes are +// received or when body is consumed. +class QUIC_EXPORT_PRIVATE QuicSpdyStreamBodyManager { + public: + QuicSpdyStreamBodyManager(); + ~QuicSpdyStreamBodyManager() = default; + + // One of the following two methods must be called every time data is received + // on the request stream. + + // Called when data that could immediately be marked consumed with the + // sequencer (provided that all previous body fragments are consumed) is + // received. |length| must be positive. Returns number of bytes the caller + // must mark consumed, which might be zero. + ABSL_MUST_USE_RESULT size_t OnNonBody(QuicByteCount length); + + // Called when body is received. |body| is added to |fragments_|. The data + // pointed to by |body| must be kept alive until an OnBodyConsumed() or + // ReadBody() call consumes it. |body| must not be empty. + void OnBody(absl::string_view body); + + // Internally marks |num_bytes| of body consumed. |num_bytes| might be zero. + // Returns the number of bytes that the caller should mark consumed with the + // sequencer, which is the sum of |num_bytes| for body, and the number of any + // interleaving or immediately trailing non-body bytes. + ABSL_MUST_USE_RESULT size_t OnBodyConsumed(size_t num_bytes); + + // Set up to |iov_len| elements of iov[] to point to available bodies: each + // iov[i].iov_base will point to a body fragment, and iov[i].iov_len will be + // set to its length. No data is copied, no data is consumed. Returns the + // number of iov set. + int PeekBody(iovec* iov, size_t iov_len) const; + + // Copies data from available bodies into at most |iov_len| elements of iov[]. + // Internally consumes copied body bytes as well as all interleaving and + // immediately trailing non-body bytes. |iov.iov_base| and |iov.iov_len| are + // preassigned and will not be changed. Returns the total number of bytes the + // caller shall mark consumed. Sets |*total_bytes_read| to the total number + // of body bytes read. + ABSL_MUST_USE_RESULT size_t ReadBody(const struct iovec* iov, size_t iov_len, + size_t* total_bytes_read); + + bool HasBytesToRead() const { return !fragments_.empty(); } + + uint64_t total_body_bytes_received() const { + return total_body_bytes_received_; + } + + private: + // A Fragment instance represents a body fragment with a count of bytes + // received afterwards but before the next body fragment that can be marked + // consumed as soon as all of the body fragment is read. + struct QUIC_EXPORT_PRIVATE Fragment { + // |body| must not be empty. + absl::string_view body; + // Might be zero. + QuicByteCount trailing_non_body_byte_count; + }; + // Queue of body fragments and trailing non-body byte counts. + quiche::QuicheCircularDeque fragments_; + // Total body bytes received. + QuicByteCount total_body_bytes_received_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_STREAM_BODY_MANAGER_H_ diff --git a/quiche/quic/core/http/quic_spdy_stream_body_manager_test.cc b/quiche/quic/core/http/quic_spdy_stream_body_manager_test.cc new file mode 100644 index 000000000000..7eca7314920f --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_stream_body_manager_test.cc @@ -0,0 +1,286 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_stream_body_manager.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { + +namespace test { + +namespace { + +class QuicSpdyStreamBodyManagerTest : public QuicTest { + protected: + QuicSpdyStreamBodyManager body_manager_; +}; + +TEST_F(QuicSpdyStreamBodyManagerTest, HasBytesToRead) { + EXPECT_FALSE(body_manager_.HasBytesToRead()); + EXPECT_EQ(0u, body_manager_.total_body_bytes_received()); + + const QuicByteCount header_length = 3; + EXPECT_EQ(header_length, body_manager_.OnNonBody(header_length)); + + EXPECT_FALSE(body_manager_.HasBytesToRead()); + EXPECT_EQ(0u, body_manager_.total_body_bytes_received()); + + std::string body(1024, 'a'); + body_manager_.OnBody(body); + + EXPECT_TRUE(body_manager_.HasBytesToRead()); + EXPECT_EQ(1024u, body_manager_.total_body_bytes_received()); +} + +TEST_F(QuicSpdyStreamBodyManagerTest, ConsumeMoreThanAvailable) { + std::string body(1024, 'a'); + body_manager_.OnBody(body); + size_t bytes_to_consume = 0; + EXPECT_QUIC_BUG(bytes_to_consume = body_manager_.OnBodyConsumed(2048), + "Not enough available body to consume."); + EXPECT_EQ(0u, bytes_to_consume); +} + +TEST_F(QuicSpdyStreamBodyManagerTest, OnBodyConsumed) { + struct { + std::vector frame_header_lengths; + std::vector frame_payloads; + std::vector body_bytes_to_read; + std::vector expected_return_values; + } const kOnBodyConsumedTestData[] = { + // One frame consumed in one call. + {{2}, {"foobar"}, {6}, {6}}, + // Two frames consumed in one call. + {{3, 5}, {"foobar", "baz"}, {9}, {14}}, + // One frame consumed in two calls. + {{2}, {"foobar"}, {4, 2}, {4, 2}}, + // Two frames consumed in two calls matching frame boundaries. + {{3, 5}, {"foobar", "baz"}, {6, 3}, {11, 3}}, + // Two frames consumed in two calls, + // the first call only consuming part of the first frame. + {{3, 5}, {"foobar", "baz"}, {5, 4}, {5, 9}}, + // Two frames consumed in two calls, + // the first call consuming the entire first frame and part of the second. + {{3, 5}, {"foobar", "baz"}, {7, 2}, {12, 2}}, + }; + + for (size_t test_case_index = 0; + test_case_index < ABSL_ARRAYSIZE(kOnBodyConsumedTestData); + ++test_case_index) { + const std::vector& frame_header_lengths = + kOnBodyConsumedTestData[test_case_index].frame_header_lengths; + const std::vector& frame_payloads = + kOnBodyConsumedTestData[test_case_index].frame_payloads; + const std::vector& body_bytes_to_read = + kOnBodyConsumedTestData[test_case_index].body_bytes_to_read; + const std::vector& expected_return_values = + kOnBodyConsumedTestData[test_case_index].expected_return_values; + + for (size_t frame_index = 0; frame_index < frame_header_lengths.size(); + ++frame_index) { + // Frame header of first frame can immediately be consumed, but not the + // other frames. Each test case start with an empty + // QuicSpdyStreamBodyManager. + EXPECT_EQ(frame_index == 0 ? frame_header_lengths[frame_index] : 0u, + body_manager_.OnNonBody(frame_header_lengths[frame_index])); + body_manager_.OnBody(frame_payloads[frame_index]); + } + + for (size_t call_index = 0; call_index < body_bytes_to_read.size(); + ++call_index) { + EXPECT_EQ(expected_return_values[call_index], + body_manager_.OnBodyConsumed(body_bytes_to_read[call_index])); + } + + EXPECT_FALSE(body_manager_.HasBytesToRead()); + } +} + +TEST_F(QuicSpdyStreamBodyManagerTest, PeekBody) { + struct { + std::vector frame_header_lengths; + std::vector frame_payloads; + size_t iov_len; + } const kPeekBodyTestData[] = { + // No frames, more iovecs than frames. + {{}, {}, 1}, + // One frame, same number of iovecs. + {{3}, {"foobar"}, 1}, + // One frame, more iovecs than frames. + {{3}, {"foobar"}, 2}, + // Two frames, fewer iovecs than frames. + {{3, 5}, {"foobar", "baz"}, 1}, + // Two frames, same number of iovecs. + {{3, 5}, {"foobar", "baz"}, 2}, + // Two frames, more iovecs than frames. + {{3, 5}, {"foobar", "baz"}, 3}, + }; + + for (size_t test_case_index = 0; + test_case_index < ABSL_ARRAYSIZE(kPeekBodyTestData); ++test_case_index) { + const std::vector& frame_header_lengths = + kPeekBodyTestData[test_case_index].frame_header_lengths; + const std::vector& frame_payloads = + kPeekBodyTestData[test_case_index].frame_payloads; + size_t iov_len = kPeekBodyTestData[test_case_index].iov_len; + + QuicSpdyStreamBodyManager body_manager; + + for (size_t frame_index = 0; frame_index < frame_header_lengths.size(); + ++frame_index) { + // Frame header of first frame can immediately be consumed, but not the + // other frames. Each test case uses a new QuicSpdyStreamBodyManager + // instance. + EXPECT_EQ(frame_index == 0 ? frame_header_lengths[frame_index] : 0u, + body_manager.OnNonBody(frame_header_lengths[frame_index])); + body_manager.OnBody(frame_payloads[frame_index]); + } + + std::vector iovecs; + iovecs.resize(iov_len); + size_t iovs_filled = std::min(frame_payloads.size(), iov_len); + ASSERT_EQ(iovs_filled, + static_cast(body_manager.PeekBody(&iovecs[0], iov_len))); + for (size_t iovec_index = 0; iovec_index < iovs_filled; ++iovec_index) { + EXPECT_EQ(frame_payloads[iovec_index], + absl::string_view( + static_cast(iovecs[iovec_index].iov_base), + iovecs[iovec_index].iov_len)); + } + } +} + +TEST_F(QuicSpdyStreamBodyManagerTest, ReadBody) { + struct { + std::vector frame_header_lengths; + std::vector frame_payloads; + std::vector> iov_lengths; + std::vector expected_total_bytes_read; + std::vector expected_return_values; + } const kReadBodyTestData[] = { + // One frame, one read with smaller iovec. + {{4}, {"foo"}, {{2}}, {2}, {2}}, + // One frame, one read with same size iovec. + {{4}, {"foo"}, {{3}}, {3}, {3}}, + // One frame, one read with larger iovec. + {{4}, {"foo"}, {{5}}, {3}, {3}}, + // One frame, one read with two iovecs, smaller total size. + {{4}, {"foobar"}, {{2, 3}}, {5}, {5}}, + // One frame, one read with two iovecs, same total size. + {{4}, {"foobar"}, {{2, 4}}, {6}, {6}}, + // One frame, one read with two iovecs, larger total size in last iovec. + {{4}, {"foobar"}, {{2, 6}}, {6}, {6}}, + // One frame, one read with extra iovecs, body ends at iovec boundary. + {{4}, {"foobar"}, {{2, 4, 4, 3}}, {6}, {6}}, + // One frame, one read with extra iovecs, body ends not at iovec boundary. + {{4}, {"foobar"}, {{2, 7, 4, 3}}, {6}, {6}}, + // One frame, two reads with two iovecs each, smaller total size. + {{4}, {"foobarbaz"}, {{2, 1}, {3, 2}}, {3, 5}, {3, 5}}, + // One frame, two reads with two iovecs each, same total size. + {{4}, {"foobarbaz"}, {{2, 1}, {4, 2}}, {3, 6}, {3, 6}}, + // One frame, two reads with two iovecs each, larger total size. + {{4}, {"foobarbaz"}, {{2, 1}, {4, 10}}, {3, 6}, {3, 6}}, + // Two frames, one read with smaller iovec. + {{4, 3}, {"foobar", "baz"}, {{8}}, {8}, {11}}, + // Two frames, one read with same size iovec. + {{4, 3}, {"foobar", "baz"}, {{9}}, {9}, {12}}, + // Two frames, one read with larger iovec. + {{4, 3}, {"foobar", "baz"}, {{10}}, {9}, {12}}, + // Two frames, one read with two iovecs, smaller total size. + {{4, 3}, {"foobar", "baz"}, {{4, 3}}, {7}, {10}}, + // Two frames, one read with two iovecs, same total size. + {{4, 3}, {"foobar", "baz"}, {{4, 5}}, {9}, {12}}, + // Two frames, one read with two iovecs, larger total size in last iovec. + {{4, 3}, {"foobar", "baz"}, {{4, 6}}, {9}, {12}}, + // Two frames, one read with extra iovecs, body ends at iovec boundary. + {{4, 3}, {"foobar", "baz"}, {{4, 6, 4, 3}}, {9}, {12}}, + // Two frames, one read with extra iovecs, body ends not at iovec + // boundary. + {{4, 3}, {"foobar", "baz"}, {{4, 7, 4, 3}}, {9}, {12}}, + // Two frames, two reads with two iovecs each, reads end on frame + // boundary. + {{4, 3}, {"foobar", "baz"}, {{2, 4}, {2, 1}}, {6, 3}, {9, 3}}, + // Three frames, three reads, extra iovecs, no iovec ends on frame + // boundary. + {{4, 3, 6}, + {"foobar", "bazquux", "qux"}, + {{4, 3}, {2, 3}, {5, 3}}, + {7, 5, 4}, + {10, 5, 10}}, + }; + + for (size_t test_case_index = 0; + test_case_index < ABSL_ARRAYSIZE(kReadBodyTestData); ++test_case_index) { + const std::vector& frame_header_lengths = + kReadBodyTestData[test_case_index].frame_header_lengths; + const std::vector& frame_payloads = + kReadBodyTestData[test_case_index].frame_payloads; + const std::vector>& iov_lengths = + kReadBodyTestData[test_case_index].iov_lengths; + const std::vector& expected_total_bytes_read = + kReadBodyTestData[test_case_index].expected_total_bytes_read; + const std::vector& expected_return_values = + kReadBodyTestData[test_case_index].expected_return_values; + + QuicSpdyStreamBodyManager body_manager; + + std::string received_body; + + for (size_t frame_index = 0; frame_index < frame_header_lengths.size(); + ++frame_index) { + // Frame header of first frame can immediately be consumed, but not the + // other frames. Each test case uses a new QuicSpdyStreamBodyManager + // instance. + EXPECT_EQ(frame_index == 0 ? frame_header_lengths[frame_index] : 0u, + body_manager.OnNonBody(frame_header_lengths[frame_index])); + body_manager.OnBody(frame_payloads[frame_index]); + received_body.append(frame_payloads[frame_index]); + } + + std::string read_body; + + for (size_t call_index = 0; call_index < iov_lengths.size(); ++call_index) { + // Allocate single buffer for iovecs. + size_t total_iov_length = std::accumulate(iov_lengths[call_index].begin(), + iov_lengths[call_index].end(), + static_cast(0)); + std::string buffer(total_iov_length, 'z'); + + // Construct iovecs pointing to contiguous areas in the buffer. + std::vector iovecs; + size_t offset = 0; + for (size_t iov_length : iov_lengths[call_index]) { + QUICHE_CHECK(offset + iov_length <= buffer.size()); + iovecs.push_back({&buffer[offset], iov_length}); + offset += iov_length; + } + + // Make sure |total_bytes_read| differs from |expected_total_bytes_read|. + size_t total_bytes_read = expected_total_bytes_read[call_index] + 12; + EXPECT_EQ( + expected_return_values[call_index], + body_manager.ReadBody(&iovecs[0], iovecs.size(), &total_bytes_read)); + read_body.append(buffer.substr(0, total_bytes_read)); + } + + EXPECT_EQ(received_body.substr(0, read_body.size()), read_body); + EXPECT_EQ(read_body.size() < received_body.size(), + body_manager.HasBytesToRead()); + } +} + +} // anonymous namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/core/http/quic_spdy_stream_test.cc b/quiche/quic/core/http/quic_spdy_stream_test.cc new file mode 100644 index 000000000000..2b966b27095b --- /dev/null +++ b/quiche/quic/core/http/quic_spdy_stream_test.cc @@ -0,0 +1,3275 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/quic_spdy_stream.h" + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_stream_sequencer_buffer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/quic_write_blocked_list.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/capsule.h" +#include "quiche/common/quiche_ip_address.h" +#include "quiche/common/quiche_mem_slice_storage.h" +#include "quiche/common/simple_buffer_allocator.h" + +using quiche::Capsule; +using quiche::IpAddressRange; +using spdy::Http2HeaderBlock; +using spdy::kV3HighestPriority; +using spdy::kV3LowestPriority; +using testing::_; +using testing::AnyNumber; +using testing::AtLeast; +using testing::DoAll; +using testing::ElementsAre; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::MatchesRegex; +using testing::Pair; +using testing::Return; +using testing::SaveArg; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +const bool kShouldProcessData = true; +const char kDataFramePayload[] = "some data"; + +class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { + public: + explicit TestCryptoStream(QuicSession* session) + : QuicCryptoStream(session), + QuicCryptoHandshaker(this, session), + encryption_established_(false), + one_rtt_keys_available_(false), + params_(new QuicCryptoNegotiatedParameters) { + // Simulate a negotiated cipher_suite with a fake value. + params_->cipher_suite = 1; + } + + void OnHandshakeMessage(const CryptoHandshakeMessage& /*message*/) override { + encryption_established_ = true; + one_rtt_keys_available_ = true; + QuicErrorCode error; + std::string error_details; + session()->config()->SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + session()->config()->SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + if (session()->version().UsesTls()) { + if (session()->perspective() == Perspective::IS_CLIENT) { + session()->config()->SetOriginalConnectionIdToSend( + session()->connection()->connection_id()); + session()->config()->SetInitialSourceConnectionIdToSend( + session()->connection()->connection_id()); + } else { + session()->config()->SetInitialSourceConnectionIdToSend( + session()->connection()->client_connection_id()); + } + TransportParameters transport_parameters; + EXPECT_TRUE( + session()->config()->FillTransportParameters(&transport_parameters)); + error = session()->config()->ProcessTransportParameters( + transport_parameters, /* is_resumption = */ false, &error_details); + } else { + CryptoHandshakeMessage msg; + session()->config()->ToHandshakeMessage(&msg, transport_version()); + error = + session()->config()->ProcessPeerHello(msg, CLIENT, &error_details); + } + EXPECT_THAT(error, IsQuicNoError()); + session()->OnNewEncryptionKeyAvailable( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(session()->perspective())); + session()->OnConfigNegotiated(); + if (session()->version().UsesTls()) { + session()->OnTlsHandshakeComplete(); + } else { + session()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + if (session()->version().UsesTls()) { + // HANDSHAKE_DONE frame. + EXPECT_CALL(*this, HasPendingRetransmission()); + } + session()->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); + } + + // QuicCryptoStream implementation + ssl_early_data_reason_t EarlyDataReason() const override { + return ssl_early_data_unknown; + } + bool encryption_established() const override { + return encryption_established_; + } + bool one_rtt_keys_available() const override { + return one_rtt_keys_available_; + } + HandshakeState GetHandshakeState() const override { + return one_rtt_keys_available() ? HANDSHAKE_COMPLETE : HANDSHAKE_START; + } + void SetServerApplicationStateForResumption( + std::unique_ptr /*application_state*/) override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override { + return *params_; + } + CryptoMessageParser* crypto_message_parser() override { + return QuicCryptoHandshaker::crypto_message_parser(); + } + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) override {} + void OnHandshakeDoneReceived() override {} + void OnNewTokenReceived(absl::string_view /*token*/) override {} + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } + bool ValidateAddressToken(absl::string_view /*token*/) const override { + return true; + } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} + + MOCK_METHOD(void, OnCanWrite, (), (override)); + + bool HasPendingCryptoRetransmission() const override { return false; } + + MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + + SSL* GetSsl() const override { return nullptr; } + + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override { + return level != ENCRYPTION_ZERO_RTT; + } + + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } + } + + private: + using QuicCryptoStream::session; + + bool encryption_established_; + bool one_rtt_keys_available_; + quiche::QuicheReferenceCountedPointer params_; +}; + +class TestStream : public QuicSpdyStream { + public: + TestStream(QuicStreamId id, QuicSpdySession* session, + bool should_process_data) + : QuicSpdyStream(id, session, BIDIRECTIONAL), + should_process_data_(should_process_data), + headers_payload_length_(0) {} + ~TestStream() override = default; + + using QuicSpdyStream::set_ack_listener; + using QuicStream::CloseWriteSide; + using QuicStream::WriteOrBufferData; + + void OnBodyAvailable() override { + if (!should_process_data_) { + return; + } + char buffer[2048]; + struct iovec vec; + vec.iov_base = buffer; + vec.iov_len = ABSL_ARRAYSIZE(buffer); + size_t bytes_read = Readv(&vec, 1); + data_ += std::string(buffer, bytes_read); + } + + MOCK_METHOD(void, WriteHeadersMock, (bool fin), ()); + + size_t WriteHeadersImpl( + spdy::Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + /*ack_listener*/) override { + saved_headers_ = std::move(header_block); + WriteHeadersMock(fin); + if (VersionUsesHttp3(transport_version())) { + // In this case, call QuicSpdyStream::WriteHeadersImpl() that does the + // actual work of closing the stream. + return QuicSpdyStream::WriteHeadersImpl(saved_headers_.Clone(), fin, + nullptr); + } + return 0; + } + + const std::string& data() const { return data_; } + const spdy::Http2HeaderBlock& saved_headers() const { return saved_headers_; } + + // Expose protected accessor. + const QuicStreamSequencer* sequencer() const { + return QuicStream::sequencer(); + } + + void OnStreamHeaderList(bool fin, size_t frame_len, + const QuicHeaderList& header_list) override { + headers_payload_length_ = frame_len; + QuicSpdyStream::OnStreamHeaderList(fin, frame_len, header_list); + } + + size_t headers_payload_length() const { return headers_payload_length_; } + + bool AreHeadersValid(const QuicHeaderList& header_list) const override { + return !GetQuicReloadableFlag(quic_verify_request_headers_2) || + QuicSpdyStream::AreHeadersValid(header_list); + } + + private: + bool should_process_data_; + spdy::Http2HeaderBlock saved_headers_; + std::string data_; + size_t headers_payload_length_; +}; + +class TestSession : public MockQuicSpdySession { + public: + explicit TestSession(QuicConnection* connection) + : MockQuicSpdySession(connection, /*create_mock_crypto_stream=*/false), + crypto_stream_(this) {} + + TestCryptoStream* GetMutableCryptoStream() override { + return &crypto_stream_; + } + + const TestCryptoStream* GetCryptoStream() const override { + return &crypto_stream_; + } + + bool ShouldNegotiateWebTransport() override { return enable_webtransport_; } + void EnableWebTransport() { enable_webtransport_ = true; } + + HttpDatagramSupport LocalHttpDatagramSupport() override { + return local_http_datagram_support_; + } + void set_local_http_datagram_support(HttpDatagramSupport value) { + local_http_datagram_support_ = value; + } + + private: + bool enable_webtransport_ = false; + HttpDatagramSupport local_http_datagram_support_ = HttpDatagramSupport::kNone; + StrictMock crypto_stream_; +}; + +class TestMockUpdateStreamSession : public MockQuicSpdySession { + public: + explicit TestMockUpdateStreamSession(QuicConnection* connection) + : MockQuicSpdySession(connection) {} + + void UpdateStreamPriority(QuicStreamId id, + const QuicStreamPriority& new_priority) override { + EXPECT_EQ(id, expected_stream_->id()); + EXPECT_EQ(expected_priority_, new_priority.http()); + EXPECT_EQ(QuicStreamPriority(expected_priority_), + expected_stream_->priority()); + } + + void SetExpectedStream(QuicSpdyStream* stream) { expected_stream_ = stream; } + void SetExpectedPriority(const HttpStreamPriority& priority) { + expected_priority_ = priority; + } + + private: + QuicSpdyStream* expected_stream_; + HttpStreamPriority expected_priority_; +}; + +class QuicSpdyStreamTest : public QuicTestWithParam { + protected: + QuicSpdyStreamTest() { + headers_[":host"] = "www.google.com"; + headers_[":path"] = "/index.hml"; + headers_[":scheme"] = "https"; + headers_["cookie"] = + "__utma=208381060.1228362404.1372200928.1372200928.1372200928.1; " + "__utmc=160408618; " + "GX=DQAAAOEAAACWJYdewdE9rIrW6qw3PtVi2-d729qaa-74KqOsM1NVQblK4VhX" + "hoALMsy6HOdDad2Sz0flUByv7etmo3mLMidGrBoljqO9hSVA40SLqpG_iuKKSHX" + "RW3Np4bq0F0SDGDNsW0DSmTS9ufMRrlpARJDS7qAI6M3bghqJp4eABKZiRqebHT" + "pMU-RXvTI5D5oCF1vYxYofH_l1Kviuiy3oQ1kS1enqWgbhJ2t61_SNdv-1XJIS0" + "O3YeHLmVCs62O6zp89QwakfAWK9d3IDQvVSJzCQsvxvNIvaZFa567MawWlXg0Rh" + "1zFMi5vzcns38-8_Sns; " + "GA=v*2%2Fmem*57968640*47239936%2Fmem*57968640*47114716%2Fno-nm-" + "yj*15%2Fno-cc-yj*5%2Fpc-ch*133685%2Fpc-s-cr*133947%2Fpc-s-t*1339" + "47%2Fno-nm-yj*4%2Fno-cc-yj*1%2Fceft-as*1%2Fceft-nqas*0%2Fad-ra-c" + "v_p%2Fad-nr-cv_p-f*1%2Fad-v-cv_p*859%2Fad-ns-cv_p-f*1%2Ffn-v-ad%" + "2Fpc-t*250%2Fpc-cm*461%2Fpc-s-cr*722%2Fpc-s-t*722%2Fau_p*4" + "SICAID=AJKiYcHdKgxum7KMXG0ei2t1-W4OD1uW-ecNsCqC0wDuAXiDGIcT_HA2o1" + "3Rs1UKCuBAF9g8rWNOFbxt8PSNSHFuIhOo2t6bJAVpCsMU5Laa6lewuTMYI8MzdQP" + "ARHKyW-koxuhMZHUnGBJAM1gJODe0cATO_KGoX4pbbFxxJ5IicRxOrWK_5rU3cdy6" + "edlR9FsEdH6iujMcHkbE5l18ehJDwTWmBKBzVD87naobhMMrF6VvnDGxQVGp9Ir_b" + "Rgj3RWUoPumQVCxtSOBdX0GlJOEcDTNCzQIm9BSfetog_eP_TfYubKudt5eMsXmN6" + "QnyXHeGeK2UINUzJ-D30AFcpqYgH9_1BvYSpi7fc7_ydBU8TaD8ZRxvtnzXqj0RfG" + "tuHghmv3aD-uzSYJ75XDdzKdizZ86IG6Fbn1XFhYZM-fbHhm3mVEXnyRW4ZuNOLFk" + "Fas6LMcVC6Q8QLlHYbXBpdNFuGbuZGUnav5C-2I_-46lL0NGg3GewxGKGHvHEfoyn" + "EFFlEYHsBQ98rXImL8ySDycdLEFvBPdtctPmWCfTxwmoSMLHU2SCVDhbqMWU5b0yr" + "JBCScs_ejbKaqBDoB7ZGxTvqlrB__2ZmnHHjCr8RgMRtKNtIeuZAo "; + } + + ~QuicSpdyStreamTest() override = default; + + // Return QPACK-encoded header block without using the dynamic table. + std::string EncodeQpackHeaders( + std::vector> headers) { + Http2HeaderBlock header_block; + for (const auto& header_field : headers) { + header_block.AppendValueOrAddHeader(header_field.first, + header_field.second); + } + + return EncodeQpackHeaders(header_block); + } + + // Return QPACK-encoded header block without using the dynamic table. + std::string EncodeQpackHeaders(const Http2HeaderBlock& header) { + NoopQpackStreamSenderDelegate encoder_stream_sender_delegate; + auto qpack_encoder = std::make_unique(session_.get()); + qpack_encoder->set_qpack_stream_sender_delegate( + &encoder_stream_sender_delegate); + // QpackEncoder does not use the dynamic table by default, + // therefore the value of |stream_id| does not matter. + return qpack_encoder->EncodeHeaderList(/* stream_id = */ 0, header, + nullptr); + } + + void Initialize(bool stream_should_process_data) { + InitializeWithPerspective(stream_should_process_data, + Perspective::IS_SERVER); + } + + void InitializeWithPerspective(bool stream_should_process_data, + Perspective perspective) { + connection_ = new StrictMock( + &helper_, &alarm_factory_, perspective, SupportedVersions(GetParam())); + session_ = std::make_unique>(connection_); + EXPECT_CALL(*session_, OnCongestionWindowChange(_)).Times(AnyNumber()); + session_->Initialize(); + if (connection_->version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(connection_); + } + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + ON_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillByDefault( + Invoke(session_.get(), &MockQuicSpdySession::ConsumeData)); + + stream_ = + new StrictMock(GetNthClientInitiatedBidirectionalId(0), + session_.get(), stream_should_process_data); + session_->ActivateStream(absl::WrapUnique(stream_)); + stream2_ = + new StrictMock(GetNthClientInitiatedBidirectionalId(1), + session_.get(), stream_should_process_data); + session_->ActivateStream(absl::WrapUnique(stream2_)); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(session_->config(), 10); + session_->OnConfigNegotiated(); + if (UsesHttp3()) { + // The control stream will write the stream type, a greased frame, and + // SETTINGS frame. + int num_control_stream_writes = 3; + auto send_control_stream = + QuicSpdySessionPeer::GetSendControlStream(session_.get()); + EXPECT_CALL(*session_, + WritevData(send_control_stream->id(), _, _, _, _, _)) + .Times(num_control_stream_writes); + } + TestCryptoStream* crypto_stream = session_->GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()).Times(AnyNumber()); + + if (connection_->version().UsesTls() && + session_->perspective() == Perspective::IS_SERVER) { + // HANDSHAKE_DONE frame. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + } + CryptoHandshakeMessage message; + session_->GetMutableCryptoStream()->OnHandshakeMessage(message); + } + + QuicHeaderList ProcessHeaders(bool fin, const Http2HeaderBlock& headers) { + QuicHeaderList h = AsHeaderList(headers); + stream_->OnStreamHeaderList(fin, h.uncompressed_header_bytes(), h); + return h; + } + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), n); + } + + bool UsesHttp3() const { + return VersionUsesHttp3(GetParam().transport_version); + } + + // Construct HEADERS frame with QPACK-encoded |headers| without using the + // dynamic table. + std::string HeadersFrame( + std::vector> headers) { + return HeadersFrame(EncodeQpackHeaders(headers)); + } + + // Construct HEADERS frame with QPACK-encoded |headers| without using the + // dynamic table. + std::string HeadersFrame(const Http2HeaderBlock& headers) { + return HeadersFrame(EncodeQpackHeaders(headers)); + } + + // Construct HEADERS frame with given payload. + std::string HeadersFrame(absl::string_view payload) { + std::string headers_frame_header = + HttpEncoder::SerializeHeadersFrameHeader(payload.length()); + return absl::StrCat(headers_frame_header, payload); + } + + std::string DataFrame(absl::string_view payload) { + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + payload.length(), quiche::SimpleBufferAllocator::Get()); + return absl::StrCat(header.AsStringView(), payload); + } + + std::string UnknownFrame(uint64_t frame_type, absl::string_view payload) { + std::string frame; + const size_t length = QuicDataWriter::GetVarInt62Len(frame_type) + + QuicDataWriter::GetVarInt62Len(payload.size()) + + payload.size(); + frame.resize(length); + + QuicDataWriter writer(length, const_cast(frame.data())); + writer.WriteVarInt62(frame_type); + writer.WriteStringPieceVarInt62(payload); + // Even though integers can be encoded with different lengths, + // QuicDataWriter is expected to produce an encoding in Write*() of length + // promised in GetVarInt62Len(). + QUICHE_DCHECK_EQ(length, writer.length()); + + return frame; + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnection* connection_; + std::unique_ptr session_; + + // Owned by the |session_|. + TestStream* stream_; + TestStream* stream2_; + + Http2HeaderBlock headers_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdyStreamTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSpdyStreamTest, ProcessHeaderList) { + Initialize(kShouldProcessData); + + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + ProcessHeaders(false, headers_); + EXPECT_EQ("", stream_->data()); + EXPECT_FALSE(stream_->header_list().empty()); + EXPECT_FALSE(stream_->IsDoneReading()); +} + +TEST_P(QuicSpdyStreamTest, ProcessTooLargeHeaderList) { + Initialize(kShouldProcessData); + + if (!UsesHttp3()) { + QuicHeaderList headers; + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_HEADERS_TOO_LARGE), 0)); + stream_->OnStreamHeaderList(false, 1 << 20, headers); + + EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_HEADERS_TOO_LARGE)); + + return; + } + + // Header list size includes 32 bytes for overhead per header field. + session_->set_max_inbound_header_list_size(40); + std::string headers = + HeadersFrame({std::make_pair("foo", "too long headers")}); + + QuicStreamFrame frame(stream_->id(), false, 0, headers); + + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + stream_->id(), QuicResetStreamError::FromInternal( + QUIC_HEADERS_TOO_LARGE))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_HEADERS_TOO_LARGE), 0)); + + auto qpack_decoder_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + // Stream type and stream cancellation. + EXPECT_CALL(*session_, + WritevData(qpack_decoder_stream->id(), _, _, NO_FIN, _, _)) + .Times(2); + + stream_->OnStreamFrame(frame); + EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_HEADERS_TOO_LARGE)); +} + +TEST_P(QuicSpdyStreamTest, QpackProcessLargeHeaderListDiscountOverhead) { + if (!UsesHttp3()) { + return; + } + // Setting this flag to false causes no per-entry overhead to be included + // in the header size. + SetQuicFlag(quic_header_size_limit_includes_overhead, false); + Initialize(kShouldProcessData); + session_->set_max_inbound_header_list_size(40); + std::string headers = + HeadersFrame({std::make_pair("foo", "too long headers")}); + + QuicStreamFrame frame(stream_->id(), false, 0, headers); + stream_->OnStreamFrame(frame); + EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_STREAM_NO_ERROR)); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeaderListWithFin) { + Initialize(kShouldProcessData); + + size_t total_bytes = 0; + QuicHeaderList headers; + for (auto p : headers_) { + headers.OnHeader(p.first, p.second); + total_bytes += p.first.size() + p.second.size(); + } + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + stream_->OnStreamHeaderList(true, total_bytes, headers); + EXPECT_EQ("", stream_->data()); + EXPECT_FALSE(stream_->header_list().empty()); + EXPECT_FALSE(stream_->IsDoneReading()); + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); +} + +// A valid status code should be 3-digit integer. The first digit should be in +// the range of [1, 5]. All the others are invalid. +TEST_P(QuicSpdyStreamTest, ParseHeaderStatusCode) { + Initialize(kShouldProcessData); + int status_code = 0; + + // Valid status codes. + headers_[":status"] = "404"; + EXPECT_TRUE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + EXPECT_EQ(404, status_code); + + headers_[":status"] = "100"; + EXPECT_TRUE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + EXPECT_EQ(100, status_code); + + headers_[":status"] = "599"; + EXPECT_TRUE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + EXPECT_EQ(599, status_code); + + // Invalid status codes. + headers_[":status"] = "010"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "600"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "200 ok"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "2000"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "+200"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "+20"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "-10"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "-100"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + // Leading or trailing spaces are also invalid. + headers_[":status"] = " 200"; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = "200 "; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = " 200 "; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); + + headers_[":status"] = " "; + EXPECT_FALSE(stream_->ParseHeaderStatusCode(headers_, &status_code)); +} + +TEST_P(QuicSpdyStreamTest, MarkHeadersConsumed) { + Initialize(kShouldProcessData); + + std::string body = "this is the body"; + QuicHeaderList headers = ProcessHeaders(false, headers_); + EXPECT_EQ(headers, stream_->header_list()); + + stream_->ConsumeHeaderList(); + EXPECT_EQ(QuicHeaderList(), stream_->header_list()); +} + +TEST_P(QuicSpdyStreamTest, ProcessWrongFramesOnSpdyStream) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + GoAwayFrame goaway; + goaway.id = 0x1; + std::string goaway_frame = HttpEncoder::SerializeGoAwayFrame(goaway); + + EXPECT_EQ("", stream_->data()); + QuicHeaderList headers = ProcessHeaders(false, headers_); + EXPECT_EQ(headers, stream_->header_list()); + stream_->ConsumeHeaderList(); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + goaway_frame); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM, _, _)) + .WillOnce( + (Invoke([this](QuicErrorCode error, const std::string& error_details, + ConnectionCloseBehavior connection_close_behavior) { + connection_->ReallyCloseConnection(error, error_details, + connection_close_behavior); + }))); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(*session_, OnConnectionClosed(_, _)) + .WillOnce(Invoke([this](const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) { + session_->ReallyOnConnectionClosed(frame, source); + })); + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(_, _, _)).Times(2); + + stream_->OnStreamFrame(frame); +} + +TEST_P(QuicSpdyStreamTest, Http3FrameError) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // PUSH_PROMISE frame with empty payload is considered invalid. + std::string invalid_http3_frame = absl::HexStringToBytes("0500"); + QuicStreamFrame stream_frame(stream_->id(), /* fin = */ false, + /* offset = */ 0, invalid_http3_frame); + + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, _, _)); + stream_->OnStreamFrame(stream_frame); +} + +TEST_P(QuicSpdyStreamTest, UnexpectedHttp3Frame) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // SETTINGS frame with empty payload. + std::string settings = absl::HexStringToBytes("0400"); + QuicStreamFrame stream_frame(stream_->id(), /* fin = */ false, + /* offset = */ 0, settings); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM, _, _)); + stream_->OnStreamFrame(stream_frame); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndBody) { + Initialize(kShouldProcessData); + + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + EXPECT_EQ("", stream_->data()); + QuicHeaderList headers = ProcessHeaders(false, headers_); + EXPECT_EQ(headers, stream_->header_list()); + stream_->ConsumeHeaderList(); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame); + EXPECT_EQ(QuicHeaderList(), stream_->header_list()); + EXPECT_EQ(body, stream_->data()); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndBodyFragments) { + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + for (size_t fragment_size = 1; fragment_size < data.size(); ++fragment_size) { + Initialize(kShouldProcessData); + QuicHeaderList headers = ProcessHeaders(false, headers_); + ASSERT_EQ(headers, stream_->header_list()); + stream_->ConsumeHeaderList(); + for (size_t offset = 0; offset < data.size(); offset += fragment_size) { + size_t remaining_data = data.size() - offset; + absl::string_view fragment(data.data() + offset, + std::min(fragment_size, remaining_data)); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, + offset, absl::string_view(fragment)); + stream_->OnStreamFrame(frame); + } + ASSERT_EQ(body, stream_->data()) << "fragment_size: " << fragment_size; + } +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndBodyFragmentsSplit) { + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + for (size_t split_point = 1; split_point < data.size() - 1; ++split_point) { + Initialize(kShouldProcessData); + QuicHeaderList headers = ProcessHeaders(false, headers_); + ASSERT_EQ(headers, stream_->header_list()); + stream_->ConsumeHeaderList(); + + absl::string_view fragment1(data.data(), split_point); + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(fragment1)); + stream_->OnStreamFrame(frame1); + + absl::string_view fragment2(data.data() + split_point, + data.size() - split_point); + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(0), false, + split_point, absl::string_view(fragment2)); + stream_->OnStreamFrame(frame2); + + ASSERT_EQ(body, stream_->data()) << "split_point: " << split_point; + } +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndBodyReadv) { + Initialize(!kShouldProcessData); + + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + ProcessHeaders(false, headers_); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame); + stream_->ConsumeHeaderList(); + + char buffer[2048]; + ASSERT_LT(data.length(), ABSL_ARRAYSIZE(buffer)); + struct iovec vec; + vec.iov_base = buffer; + vec.iov_len = ABSL_ARRAYSIZE(buffer); + + size_t bytes_read = stream_->Readv(&vec, 1); + QuicStreamPeer::CloseReadSide(stream_); + EXPECT_EQ(body.length(), bytes_read); + EXPECT_EQ(body, std::string(buffer, bytes_read)); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndLargeBodySmallReadv) { + Initialize(kShouldProcessData); + std::string body(12 * 1024, 'a'); + std::string data = UsesHttp3() ? DataFrame(body) : body; + + ProcessHeaders(false, headers_); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame); + stream_->ConsumeHeaderList(); + char buffer[2048]; + char buffer2[2048]; + struct iovec vec[2]; + vec[0].iov_base = buffer; + vec[0].iov_len = ABSL_ARRAYSIZE(buffer); + vec[1].iov_base = buffer2; + vec[1].iov_len = ABSL_ARRAYSIZE(buffer2); + size_t bytes_read = stream_->Readv(vec, 2); + EXPECT_EQ(2048u * 2, bytes_read); + EXPECT_EQ(body.substr(0, 2048), std::string(buffer, 2048)); + EXPECT_EQ(body.substr(2048, 2048), std::string(buffer2, 2048)); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndBodyMarkConsumed) { + Initialize(!kShouldProcessData); + + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + ProcessHeaders(false, headers_); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame); + stream_->ConsumeHeaderList(); + + struct iovec vec; + + EXPECT_EQ(1, stream_->GetReadableRegions(&vec, 1)); + EXPECT_EQ(body.length(), vec.iov_len); + EXPECT_EQ(body, std::string(static_cast(vec.iov_base), vec.iov_len)); + + stream_->MarkConsumed(body.length()); + EXPECT_EQ(data.length(), QuicStreamPeer::bytes_consumed(stream_)); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndConsumeMultipleBody) { + Initialize(!kShouldProcessData); + std::string body1 = "this is body 1"; + std::string data1 = UsesHttp3() ? DataFrame(body1) : body1; + std::string body2 = "body 2"; + std::string data2 = UsesHttp3() ? DataFrame(body2) : body2; + + ProcessHeaders(false, headers_); + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data1)); + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(0), false, + data1.length(), absl::string_view(data2)); + stream_->OnStreamFrame(frame1); + stream_->OnStreamFrame(frame2); + stream_->ConsumeHeaderList(); + + stream_->MarkConsumed(body1.length() + body2.length()); + EXPECT_EQ(data1.length() + data2.length(), + QuicStreamPeer::bytes_consumed(stream_)); +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersAndBodyIncrementalReadv) { + Initialize(!kShouldProcessData); + + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + ProcessHeaders(false, headers_); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame); + stream_->ConsumeHeaderList(); + + char buffer[1]; + struct iovec vec; + vec.iov_base = buffer; + vec.iov_len = ABSL_ARRAYSIZE(buffer); + + for (size_t i = 0; i < body.length(); ++i) { + size_t bytes_read = stream_->Readv(&vec, 1); + ASSERT_EQ(1u, bytes_read); + EXPECT_EQ(body.data()[i], buffer[0]); + } +} + +TEST_P(QuicSpdyStreamTest, ProcessHeadersUsingReadvWithMultipleIovecs) { + Initialize(!kShouldProcessData); + + std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + ProcessHeaders(false, headers_); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame); + stream_->ConsumeHeaderList(); + + char buffer1[1]; + char buffer2[1]; + struct iovec vec[2]; + vec[0].iov_base = buffer1; + vec[0].iov_len = ABSL_ARRAYSIZE(buffer1); + vec[1].iov_base = buffer2; + vec[1].iov_len = ABSL_ARRAYSIZE(buffer2); + + for (size_t i = 0; i < body.length(); i += 2) { + size_t bytes_read = stream_->Readv(vec, 2); + ASSERT_EQ(2u, bytes_read) << i; + ASSERT_EQ(body.data()[i], buffer1[0]) << i; + ASSERT_EQ(body.data()[i + 1], buffer2[0]) << i; + } +} + +// Tests that we send a BLOCKED frame to the peer when we attempt to write, but +// are flow control blocked. +TEST_P(QuicSpdyStreamTest, StreamFlowControlBlocked) { + Initialize(kShouldProcessData); + testing::InSequence seq; + + // Set a small flow control limit. + const uint64_t kWindow = 36; + QuicStreamPeer::SetSendWindowOffset(stream_, kWindow); + EXPECT_EQ(kWindow, QuicStreamPeer::SendWindowOffset(stream_)); + + // Try to send more data than the flow control limit allows. + const uint64_t kOverflow = 15; + std::string body(kWindow + kOverflow, 'a'); + + const uint64_t kHeaderLength = UsesHttp3() ? 2 : 0; + if (UsesHttp3()) { + EXPECT_CALL(*session_, WritevData(_, kHeaderLength, _, NO_FIN, _, _)); + } + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Return(QuicConsumedData(kWindow - kHeaderLength, true))); + EXPECT_CALL(*session_, SendBlocked(_, _)); + EXPECT_CALL(*connection_, SendControlFrame(_)); + stream_->WriteOrBufferBody(body, false); + + // Should have sent as much as possible, resulting in no send window left. + EXPECT_EQ(0u, QuicStreamPeer::SendWindowSize(stream_)); + + // And we should have queued the overflowed data. + EXPECT_EQ(kOverflow + kHeaderLength, stream_->BufferedDataBytes()); +} + +// The flow control receive window decreases whenever we add new bytes to the +// sequencer, whether they are consumed immediately or buffered. However we only +// send WINDOW_UPDATE frames based on increasing number of bytes consumed. +TEST_P(QuicSpdyStreamTest, StreamFlowControlNoWindowUpdateIfNotConsumed) { + // Don't process data - it will be buffered instead. + Initialize(!kShouldProcessData); + + // Expect no WINDOW_UPDATE frames to be sent. + EXPECT_CALL(*session_, SendWindowUpdate(_, _)).Times(0); + + // Set a small flow control receive window. + const uint64_t kWindow = 36; + QuicStreamPeer::SetReceiveWindowOffset(stream_, kWindow); + QuicStreamPeer::SetMaxReceiveWindow(stream_, kWindow); + + // Stream receives enough data to fill a fraction of the receive window. + std::string body(kWindow / 3, 'a'); + QuicByteCount header_length = 0; + std::string data; + + if (UsesHttp3()) { + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + data = absl::StrCat(header.AsStringView(), body); + header_length = header.size(); + } else { + data = body; + } + + ProcessHeaders(false, headers_); + + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame1); + EXPECT_EQ(kWindow - (kWindow / 3) - header_length, + QuicStreamPeer::ReceiveWindowSize(stream_)); + + // Now receive another frame which results in the receive window being over + // half full. This should all be buffered, decreasing the receive window but + // not sending WINDOW_UPDATE. + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(0), false, + kWindow / 3 + header_length, absl::string_view(data)); + stream_->OnStreamFrame(frame2); + EXPECT_EQ(kWindow - (2 * kWindow / 3) - 2 * header_length, + QuicStreamPeer::ReceiveWindowSize(stream_)); +} + +// Tests that on receipt of data, the stream updates its receive window offset +// appropriately, and sends WINDOW_UPDATE frames when its receive window drops +// too low. +TEST_P(QuicSpdyStreamTest, StreamFlowControlWindowUpdate) { + Initialize(kShouldProcessData); + + // Set a small flow control limit. + const uint64_t kWindow = 36; + QuicStreamPeer::SetReceiveWindowOffset(stream_, kWindow); + QuicStreamPeer::SetMaxReceiveWindow(stream_, kWindow); + + // Stream receives enough data to fill a fraction of the receive window. + std::string body(kWindow / 3, 'a'); + QuicByteCount header_length = 0; + std::string data; + + if (UsesHttp3()) { + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + data = absl::StrCat(header.AsStringView(), body); + header_length = header.size(); + } else { + data = body; + } + + ProcessHeaders(false, headers_); + stream_->ConsumeHeaderList(); + + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame1); + EXPECT_EQ(kWindow - (kWindow / 3) - header_length, + QuicStreamPeer::ReceiveWindowSize(stream_)); + + // Now receive another frame which results in the receive window being over + // half full. This will trigger the stream to increase its receive window + // offset and send a WINDOW_UPDATE. The result will be again an available + // window of kWindow bytes. + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(0), false, + kWindow / 3 + header_length, absl::string_view(data)); + EXPECT_CALL(*session_, SendWindowUpdate(_, _)); + EXPECT_CALL(*connection_, SendControlFrame(_)); + stream_->OnStreamFrame(frame2); + EXPECT_EQ(kWindow, QuicStreamPeer::ReceiveWindowSize(stream_)); +} + +// Tests that on receipt of data, the connection updates its receive window +// offset appropriately, and sends WINDOW_UPDATE frames when its receive window +// drops too low. +TEST_P(QuicSpdyStreamTest, ConnectionFlowControlWindowUpdate) { + Initialize(kShouldProcessData); + + // Set a small flow control limit for streams and connection. + const uint64_t kWindow = 36; + QuicStreamPeer::SetReceiveWindowOffset(stream_, kWindow); + QuicStreamPeer::SetMaxReceiveWindow(stream_, kWindow); + QuicStreamPeer::SetReceiveWindowOffset(stream2_, kWindow); + QuicStreamPeer::SetMaxReceiveWindow(stream2_, kWindow); + QuicFlowControllerPeer::SetReceiveWindowOffset(session_->flow_controller(), + kWindow); + QuicFlowControllerPeer::SetMaxReceiveWindow(session_->flow_controller(), + kWindow); + + // Supply headers to both streams so that they are happy to receive data. + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + stream_->ConsumeHeaderList(); + stream2_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + stream2_->ConsumeHeaderList(); + + // Each stream gets a quarter window of data. This should not trigger a + // WINDOW_UPDATE for either stream, nor for the connection. + QuicByteCount header_length = 0; + std::string body; + std::string data; + std::string data2; + std::string body2(1, 'a'); + + if (UsesHttp3()) { + body = std::string(kWindow / 4 - 2, 'a'); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + data = absl::StrCat(header.AsStringView(), body); + header_length = header.size(); + quiche::QuicheBuffer header2 = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + data2 = absl::StrCat(header2.AsStringView(), body2); + } else { + body = std::string(kWindow / 4, 'a'); + data = body; + data2 = body2; + } + + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + stream_->OnStreamFrame(frame1); + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(1), false, 0, + absl::string_view(data)); + stream2_->OnStreamFrame(frame2); + + // Now receive a further single byte on one stream - again this does not + // trigger a stream WINDOW_UPDATE, but now the connection flow control window + // is over half full and thus a connection WINDOW_UPDATE is sent. + EXPECT_CALL(*session_, SendWindowUpdate(_, _)); + EXPECT_CALL(*connection_, SendControlFrame(_)); + QuicStreamFrame frame3(GetNthClientInitiatedBidirectionalId(0), false, + body.length() + header_length, + absl::string_view(data2)); + stream_->OnStreamFrame(frame3); +} + +// Tests that on if the peer sends too much data (i.e. violates the flow control +// protocol), then we terminate the connection. +TEST_P(QuicSpdyStreamTest, StreamFlowControlViolation) { + // Stream should not process data, so that data gets buffered in the + // sequencer, triggering flow control limits. + Initialize(!kShouldProcessData); + + // Set a small flow control limit. + const uint64_t kWindow = 50; + QuicStreamPeer::SetReceiveWindowOffset(stream_, kWindow); + + ProcessHeaders(false, headers_); + + // Receive data to overflow the window, violating flow control. + std::string body(kWindow + 1, 'a'); + std::string data = UsesHttp3() ? DataFrame(body) : body; + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + stream_->OnStreamFrame(frame); +} + +TEST_P(QuicSpdyStreamTest, TestHandlingQuicRstStreamNoError) { + Initialize(kShouldProcessData); + ProcessHeaders(false, headers_); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AnyNumber()); + + stream_->OnStreamReset(QuicRstStreamFrame( + kInvalidControlFrameId, stream_->id(), QUIC_STREAM_NO_ERROR, 0)); + + if (UsesHttp3()) { + // RESET_STREAM should close the read side but not the write side. + EXPECT_TRUE(stream_->read_side_closed()); + EXPECT_FALSE(stream_->write_side_closed()); + } else { + EXPECT_TRUE(stream_->write_side_closed()); + EXPECT_FALSE(stream_->reading_stopped()); + } +} + +// Tests that on if the peer sends too much data (i.e. violates the flow control +// protocol), at the connection level (rather than the stream level) then we +// terminate the connection. +TEST_P(QuicSpdyStreamTest, ConnectionFlowControlViolation) { + // Stream should not process data, so that data gets buffered in the + // sequencer, triggering flow control limits. + Initialize(!kShouldProcessData); + + // Set a small flow control window on streams, and connection. + const uint64_t kStreamWindow = 50; + const uint64_t kConnectionWindow = 10; + QuicStreamPeer::SetReceiveWindowOffset(stream_, kStreamWindow); + QuicFlowControllerPeer::SetReceiveWindowOffset(session_->flow_controller(), + kConnectionWindow); + + ProcessHeaders(false, headers_); + + // Send enough data to overflow the connection level flow control window. + std::string body(kConnectionWindow + 1, 'a'); + std::string data = UsesHttp3() ? DataFrame(body) : body; + + EXPECT_LT(data.size(), kStreamWindow); + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), false, 0, + absl::string_view(data)); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + stream_->OnStreamFrame(frame); +} + +// An attempt to write a FIN with no data should not be flow control blocked, +// even if the send window is 0. +TEST_P(QuicSpdyStreamTest, StreamFlowControlFinNotBlocked) { + Initialize(kShouldProcessData); + + // Set a flow control limit of zero. + QuicStreamPeer::SetReceiveWindowOffset(stream_, 0); + + // Send a frame with a FIN but no data. This should not be blocked. + std::string body = ""; + bool fin = true; + + EXPECT_CALL(*session_, + SendBlocked(GetNthClientInitiatedBidirectionalId(0), _)) + .Times(0); + EXPECT_CALL(*session_, WritevData(_, 0, _, FIN, _, _)); + + stream_->WriteOrBufferBody(body, fin); +} + +// Test that receiving trailing headers from the peer via OnStreamHeaderList() +// works, and can be read from the stream and consumed. +TEST_P(QuicSpdyStreamTest, ReceivingTrailersViaHeaderList) { + Initialize(kShouldProcessData); + + // Receive initial headers. + size_t total_bytes = 0; + QuicHeaderList headers; + for (const auto& p : headers_) { + headers.OnHeader(p.first, p.second); + total_bytes += p.first.size() + p.second.size(); + } + + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + stream_->OnStreamHeaderList(/*fin=*/false, total_bytes, headers); + stream_->ConsumeHeaderList(); + + // Receive trailing headers. + Http2HeaderBlock trailers_block; + trailers_block["key1"] = "value1"; + trailers_block["key2"] = "value2"; + trailers_block["key3"] = "value3"; + Http2HeaderBlock trailers_block_with_final_offset = trailers_block.Clone(); + if (!UsesHttp3()) { + // :final-offset pseudo-header is only added if trailers are sent + // on the headers stream. + trailers_block_with_final_offset[kFinalOffsetHeaderKey] = "0"; + } + total_bytes = 0; + QuicHeaderList trailers; + for (const auto& p : trailers_block_with_final_offset) { + trailers.OnHeader(p.first, p.second); + total_bytes += p.first.size() + p.second.size(); + } + stream_->OnStreamHeaderList(/*fin=*/true, total_bytes, trailers); + + // The trailers should be decompressed, and readable from the stream. + EXPECT_TRUE(stream_->trailers_decompressed()); + EXPECT_EQ(trailers_block, stream_->received_trailers()); + + // IsDoneReading() returns false until trailers marked consumed. + EXPECT_FALSE(stream_->IsDoneReading()); + stream_->MarkTrailersConsumed(); + EXPECT_TRUE(stream_->IsDoneReading()); +} + +// Test that when receiving trailing headers with an offset before response +// body, stream is closed at the right offset. +TEST_P(QuicSpdyStreamTest, ReceivingTrailersWithOffset) { + // kFinalOffsetHeaderKey is not used when HEADERS are sent on the + // request/response stream. + if (UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Receive initial headers. + QuicHeaderList headers = ProcessHeaders(false, headers_); + stream_->ConsumeHeaderList(); + + const std::string body = "this is the body"; + std::string data = UsesHttp3() ? DataFrame(body) : body; + + // Receive trailing headers. + Http2HeaderBlock trailers_block; + trailers_block["key1"] = "value1"; + trailers_block["key2"] = "value2"; + trailers_block["key3"] = "value3"; + trailers_block[kFinalOffsetHeaderKey] = absl::StrCat(data.size()); + + QuicHeaderList trailers = ProcessHeaders(true, trailers_block); + + // The trailers should be decompressed, and readable from the stream. + EXPECT_TRUE(stream_->trailers_decompressed()); + + // The final offset trailer will be consumed by QUIC. + trailers_block.erase(kFinalOffsetHeaderKey); + EXPECT_EQ(trailers_block, stream_->received_trailers()); + + // Consuming the trailers erases them from the stream. + stream_->MarkTrailersConsumed(); + EXPECT_TRUE(stream_->FinishedReadingTrailers()); + + EXPECT_FALSE(stream_->IsDoneReading()); + // Receive and consume body. + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), /*fin=*/false, + 0, data); + stream_->OnStreamFrame(frame); + EXPECT_EQ(body, stream_->data()); + EXPECT_TRUE(stream_->IsDoneReading()); +} + +// Test that receiving trailers without a final offset field is an error. +TEST_P(QuicSpdyStreamTest, ReceivingTrailersWithoutOffset) { + // kFinalOffsetHeaderKey is not used when HEADERS are sent on the + // request/response stream. + if (UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Receive initial headers. + ProcessHeaders(false, headers_); + stream_->ConsumeHeaderList(); + + // Receive trailing headers, without kFinalOffsetHeaderKey. + Http2HeaderBlock trailers_block; + trailers_block["key1"] = "value1"; + trailers_block["key2"] = "value2"; + trailers_block["key3"] = "value3"; + auto trailers = AsHeaderList(trailers_block); + + // Verify that the trailers block didn't contain a final offset. + EXPECT_EQ("", trailers_block[kFinalOffsetHeaderKey].as_string()); + + // Receipt of the malformed trailers will close the connection. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, _, _)) + .Times(1); + stream_->OnStreamHeaderList(/*fin=*/true, + trailers.uncompressed_header_bytes(), trailers); +} + +// Test that received Trailers must always have the FIN set. +TEST_P(QuicSpdyStreamTest, ReceivingTrailersWithoutFin) { + // In IETF QUIC, there is no such thing as FIN flag on HTTP/3 frames like the + // HEADERS frame. + if (UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Receive initial headers. + auto headers = AsHeaderList(headers_); + stream_->OnStreamHeaderList(/*fin=*/false, + headers.uncompressed_header_bytes(), headers); + stream_->ConsumeHeaderList(); + + // Receive trailing headers with FIN deliberately set to false. + Http2HeaderBlock trailers_block; + trailers_block["foo"] = "bar"; + auto trailers = AsHeaderList(trailers_block); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, _, _)) + .Times(1); + stream_->OnStreamHeaderList(/*fin=*/false, + trailers.uncompressed_header_bytes(), trailers); +} + +TEST_P(QuicSpdyStreamTest, ReceivingTrailersAfterHeadersWithFin) { + // If headers are received with a FIN, no trailers should then arrive. + Initialize(kShouldProcessData); + + // If HEADERS frames are sent on the request/response stream, then the + // sequencer will signal an error if any stream data arrives after a FIN, + // so QuicSpdyStream does not need to. + if (UsesHttp3()) { + return; + } + + // Receive initial headers with FIN set. + ProcessHeaders(true, headers_); + stream_->ConsumeHeaderList(); + + // Receive trailing headers after FIN already received. + Http2HeaderBlock trailers_block; + trailers_block["foo"] = "bar"; + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, _, _)) + .Times(1); + ProcessHeaders(true, trailers_block); +} + +// If body data are received with a FIN, no trailers should then arrive. +TEST_P(QuicSpdyStreamTest, ReceivingTrailersAfterBodyWithFin) { + // If HEADERS frames are sent on the request/response stream, + // then the sequencer will block them from reaching QuicSpdyStream + // after the stream is closed. + if (UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Receive initial headers without FIN set. + ProcessHeaders(false, headers_); + stream_->ConsumeHeaderList(); + + // Receive body data, with FIN. + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), /*fin=*/true, + 0, "body"); + stream_->OnStreamFrame(frame); + + // Receive trailing headers after FIN already received. + Http2HeaderBlock trailers_block; + trailers_block["foo"] = "bar"; + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, _, _)) + .Times(1); + ProcessHeaders(true, trailers_block); +} + +TEST_P(QuicSpdyStreamTest, ClosingStreamWithNoTrailers) { + // Verify that a stream receiving headers, body, and no trailers is correctly + // marked as done reading on consumption of headers and body. + Initialize(kShouldProcessData); + + // Receive and consume initial headers with FIN not set. + auto h = AsHeaderList(headers_); + stream_->OnStreamHeaderList(/*fin=*/false, h.uncompressed_header_bytes(), h); + stream_->ConsumeHeaderList(); + + // Receive and consume body with FIN set, and no trailers. + std::string body(1024, 'x'); + std::string data = UsesHttp3() ? DataFrame(body) : body; + + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(0), /*fin=*/true, + 0, data); + stream_->OnStreamFrame(frame); + + EXPECT_TRUE(stream_->IsDoneReading()); +} + +// Test that writing trailers will send a FIN, as Trailers are the last thing to +// be sent on a stream. +TEST_P(QuicSpdyStreamTest, WritingTrailersSendsAFin) { + Initialize(kShouldProcessData); + + if (UsesHttp3()) { + // In this case, TestStream::WriteHeadersImpl() does not prevent writes. + // Four writes on the request stream: HEADERS frame header and payload both + // for headers and trailers. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)).Times(2); + } + + // Write the initial headers, without a FIN. + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/false, nullptr); + + // Writing trailers implicitly sends a FIN. + Http2HeaderBlock trailers; + trailers["trailer key"] = "trailer value"; + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + stream_->WriteTrailers(std::move(trailers), nullptr); + EXPECT_TRUE(stream_->fin_sent()); +} + +TEST_P(QuicSpdyStreamTest, DoNotSendPriorityUpdateWithDefaultUrgency) { + if (!UsesHttp3()) { + return; + } + + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + // Four writes on the request stream: HEADERS frame header and payload both + // for headers and trailers. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)).Times(2); + + // No PRIORITY_UPDATE frames on the control stream, + // because the stream has default priority. + auto send_control_stream = + QuicSpdySessionPeer::GetSendControlStream(session_.get()); + EXPECT_CALL(*session_, WritevData(send_control_stream->id(), _, _, _, _, _)) + .Times(0); + + // Write the initial headers, without a FIN. + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(debug_visitor, OnHeadersFrameSent(stream_->id(), _)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/false, nullptr); + + // Writing trailers implicitly sends a FIN. + Http2HeaderBlock trailers; + trailers["trailer key"] = "trailer value"; + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + EXPECT_CALL(debug_visitor, OnHeadersFrameSent(stream_->id(), _)); + stream_->WriteTrailers(std::move(trailers), nullptr); + EXPECT_TRUE(stream_->fin_sent()); +} + +TEST_P(QuicSpdyStreamTest, ChangePriority) { + if (!UsesHttp3()) { + return; + } + + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)).Times(1); + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(debug_visitor, OnHeadersFrameSent(stream_->id(), _)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/false, nullptr); + testing::Mock::VerifyAndClearExpectations(&debug_visitor); + + // PRIORITY_UPDATE frame on the control stream. + auto send_control_stream = + QuicSpdySessionPeer::GetSendControlStream(session_.get()); + EXPECT_CALL(*session_, WritevData(send_control_stream->id(), _, _, _, _, _)); + PriorityUpdateFrame priority_update1{stream_->id(), "u=0"}; + EXPECT_CALL(debug_visitor, OnPriorityUpdateFrameSent(priority_update1)); + const HttpStreamPriority priority1{kV3HighestPriority, + HttpStreamPriority::kDefaultIncremental}; + stream_->SetPriority(QuicStreamPriority(priority1)); + testing::Mock::VerifyAndClearExpectations(&debug_visitor); + + // Send another PRIORITY_UPDATE frame with incremental flag set to true. + EXPECT_CALL(*session_, WritevData(send_control_stream->id(), _, _, _, _, _)); + PriorityUpdateFrame priority_update2{stream_->id(), "u=2, i"}; + EXPECT_CALL(debug_visitor, OnPriorityUpdateFrameSent(priority_update2)); + const HttpStreamPriority priority2{2, true}; + stream_->SetPriority(QuicStreamPriority(priority2)); + testing::Mock::VerifyAndClearExpectations(&debug_visitor); + + // Calling SetPriority() with the same priority does not trigger sending + // another PRIORITY_UPDATE frame. + stream_->SetPriority(QuicStreamPriority(priority2)); +} + +TEST_P(QuicSpdyStreamTest, ChangePriorityBeforeWritingHeaders) { + if (!UsesHttp3()) { + return; + } + + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + + // PRIORITY_UPDATE frame sent on the control stream as soon as SetPriority() + // is called, before HEADERS frame is sent. + auto send_control_stream = + QuicSpdySessionPeer::GetSendControlStream(session_.get()); + EXPECT_CALL(*session_, WritevData(send_control_stream->id(), _, _, _, _, _)); + + stream_->SetPriority(QuicStreamPriority(HttpStreamPriority{ + kV3HighestPriority, HttpStreamPriority::kDefaultIncremental})); + testing::Mock::VerifyAndClearExpectations(session_.get()); + + // Two writes on the request stream: HEADERS frame header and payload. + // PRIORITY_UPDATE frame is not sent this time, because one is already sent. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)).Times(1); + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/true, nullptr); +} + +// Test that when writing trailers, the trailers that are actually sent to the +// peer contain the final offset field indicating last byte of data. +TEST_P(QuicSpdyStreamTest, WritingTrailersFinalOffset) { + Initialize(kShouldProcessData); + + if (UsesHttp3()) { + // In this case, TestStream::WriteHeadersImpl() does not prevent writes. + // HEADERS frame header and payload on the request stream. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)).Times(1); + } + + // Write the initial headers. + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/false, nullptr); + + // Write non-zero body data to force a non-zero final offset. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + std::string body(1024, 'x'); // 1 kB + QuicByteCount header_length = 0; + if (UsesHttp3()) { + header_length = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()) + .size(); + } + + stream_->WriteOrBufferBody(body, false); + + // The final offset field in the trailing headers is populated with the + // number of body bytes written (including queued bytes). + Http2HeaderBlock trailers; + trailers["trailer key"] = "trailer value"; + + Http2HeaderBlock expected_trailers(trailers.Clone()); + // :final-offset pseudo-header is only added if trailers are sent + // on the headers stream. + if (!UsesHttp3()) { + expected_trailers[kFinalOffsetHeaderKey] = + absl::StrCat(body.length() + header_length); + } + + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + stream_->WriteTrailers(std::move(trailers), nullptr); + EXPECT_EQ(expected_trailers, stream_->saved_headers()); +} + +// Test that if trailers are written after all other data has been written +// (headers and body), that this closes the stream for writing. +TEST_P(QuicSpdyStreamTest, WritingTrailersClosesWriteSide) { + Initialize(kShouldProcessData); + + // Expect data being written on the stream. In addition to that, headers are + // also written on the stream in case of IETF QUIC. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(AtLeast(1)); + + // Write the initial headers. + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/false, nullptr); + + // Write non-zero body data. + const int kBodySize = 1 * 1024; // 1 kB + stream_->WriteOrBufferBody(std::string(kBodySize, 'x'), false); + EXPECT_EQ(0u, stream_->BufferedDataBytes()); + + // Headers and body have been fully written, there is no queued data. Writing + // trailers marks the end of this stream, and thus the write side is closed. + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + stream_->WriteTrailers(Http2HeaderBlock(), nullptr); + EXPECT_TRUE(stream_->write_side_closed()); +} + +// Test that the stream is not closed for writing when trailers are sent while +// there are still body bytes queued. +TEST_P(QuicSpdyStreamTest, WritingTrailersWithQueuedBytes) { + // This test exercises sending trailers on the headers stream while data is + // still queued on the response/request stream. In IETF QUIC, data and + // trailers are sent on the same stream, so this test does not apply. + if (UsesHttp3()) { + return; + } + + testing::InSequence seq; + Initialize(kShouldProcessData); + + // Write the initial headers. + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/false, nullptr); + + // Write non-zero body data, but only consume partially, ensuring queueing. + const int kBodySize = 1 * 1024; // 1 kB + if (UsesHttp3()) { + EXPECT_CALL(*session_, WritevData(_, 3, _, NO_FIN, _, _)); + } + EXPECT_CALL(*session_, WritevData(_, kBodySize, _, NO_FIN, _, _)) + .WillOnce(Return(QuicConsumedData(kBodySize - 1, false))); + stream_->WriteOrBufferBody(std::string(kBodySize, 'x'), false); + EXPECT_EQ(1u, stream_->BufferedDataBytes()); + + // Writing trailers will send a FIN, but not close the write side of the + // stream as there are queued bytes. + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + stream_->WriteTrailers(Http2HeaderBlock(), nullptr); + EXPECT_TRUE(stream_->fin_sent()); + EXPECT_FALSE(stream_->write_side_closed()); + + // Writing the queued bytes will close the write side of the stream. + EXPECT_CALL(*session_, WritevData(_, 1, _, NO_FIN, _, _)); + stream_->OnCanWrite(); + EXPECT_TRUE(stream_->write_side_closed()); +} + +// Test that it is not possible to write Trailers after a FIN has been sent. +TEST_P(QuicSpdyStreamTest, WritingTrailersAfterFIN) { + // In IETF QUIC, there is no such thing as FIN flag on HTTP/3 frames like the + // HEADERS frame. + if (UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Write the initial headers, with a FIN. + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + stream_->WriteHeaders(Http2HeaderBlock(), /*fin=*/true, nullptr); + EXPECT_TRUE(stream_->fin_sent()); + + // Writing Trailers should fail, as the FIN has already been sent. + // populated with the number of body bytes written. + EXPECT_QUIC_BUG(stream_->WriteTrailers(Http2HeaderBlock(), nullptr), + "Trailers cannot be sent after a FIN"); +} + +TEST_P(QuicSpdyStreamTest, HeaderStreamNotiferCorrespondingSpdyStream) { + // There is no headers stream if QPACK is used. + if (UsesHttp3()) { + return; + } + + const char kHeader1[] = "Header1"; + const char kHeader2[] = "Header2"; + const char kBody1[] = "Test1"; + const char kBody2[] = "Test2"; + + Initialize(kShouldProcessData); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + testing::InSequence s; + quiche::QuicheReferenceCountedPointer ack_listener1( + new MockAckListener()); + quiche::QuicheReferenceCountedPointer ack_listener2( + new MockAckListener()); + stream_->set_ack_listener(ack_listener1); + stream2_->set_ack_listener(ack_listener2); + + session_->headers_stream()->WriteOrBufferData(kHeader1, false, ack_listener1); + stream_->WriteOrBufferBody(kBody1, true); + + session_->headers_stream()->WriteOrBufferData(kHeader2, false, ack_listener2); + stream2_->WriteOrBufferBody(kBody2, false); + + QuicStreamFrame frame1( + QuicUtils::GetHeadersStreamId(connection_->transport_version()), false, 0, + kHeader1); + + std::string data1 = UsesHttp3() ? DataFrame(kBody1) : kBody1; + QuicStreamFrame frame2(stream_->id(), true, 0, data1); + QuicStreamFrame frame3( + QuicUtils::GetHeadersStreamId(connection_->transport_version()), false, 7, + kHeader2); + std::string data2 = UsesHttp3() ? DataFrame(kBody2) : kBody2; + QuicStreamFrame frame4(stream2_->id(), false, 0, data2); + + EXPECT_CALL(*ack_listener1, OnPacketRetransmitted(7)); + session_->OnStreamFrameRetransmitted(frame1); + + EXPECT_CALL(*ack_listener1, OnPacketAcked(7, _)); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame1), QuicTime::Delta::Zero(), + QuicTime::Zero())); + EXPECT_CALL(*ack_listener1, OnPacketAcked(5, _)); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame2), QuicTime::Delta::Zero(), + QuicTime::Zero())); + EXPECT_CALL(*ack_listener2, OnPacketAcked(7, _)); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame3), QuicTime::Delta::Zero(), + QuicTime::Zero())); + EXPECT_CALL(*ack_listener2, OnPacketAcked(5, _)); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame4), QuicTime::Delta::Zero(), + QuicTime::Zero())); +} + +TEST_P(QuicSpdyStreamTest, OnPriorityFrame) { + Initialize(kShouldProcessData); + stream_->OnPriorityFrame(spdy::SpdyStreamPrecedence(kV3HighestPriority)); + EXPECT_EQ(QuicStreamPriority(HttpStreamPriority{ + kV3HighestPriority, HttpStreamPriority::kDefaultIncremental}), + stream_->priority()); +} + +TEST_P(QuicSpdyStreamTest, OnPriorityFrameAfterSendingData) { + Initialize(kShouldProcessData); + testing::InSequence seq; + + if (UsesHttp3()) { + EXPECT_CALL(*session_, WritevData(_, 2, _, NO_FIN, _, _)); + } + EXPECT_CALL(*session_, WritevData(_, 4, _, FIN, _, _)); + stream_->WriteOrBufferBody("data", true); + stream_->OnPriorityFrame(spdy::SpdyStreamPrecedence(kV3HighestPriority)); + EXPECT_EQ(QuicStreamPriority(HttpStreamPriority{ + kV3HighestPriority, HttpStreamPriority::kDefaultIncremental}), + stream_->priority()); +} + +TEST_P(QuicSpdyStreamTest, SetPriorityBeforeUpdateStreamPriority) { + MockQuicConnection* connection = new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_SERVER, + SupportedVersions(GetParam())); + std::unique_ptr session( + new StrictMock(connection)); + auto stream = + new StrictMock(GetNthClientInitiatedBidirectionalStreamId( + session->transport_version(), 0), + session.get(), + /*should_process_data=*/true); + session->ActivateStream(absl::WrapUnique(stream)); + + // QuicSpdyStream::SetPriority() should eventually call UpdateStreamPriority() + // on the session. Make sure stream->priority() returns the updated priority + // if called within UpdateStreamPriority(). This expectation is enforced in + // TestMockUpdateStreamSession::UpdateStreamPriority(). + session->SetExpectedStream(stream); + session->SetExpectedPriority(HttpStreamPriority{kV3HighestPriority}); + stream->SetPriority( + QuicStreamPriority(HttpStreamPriority{kV3HighestPriority})); + + session->SetExpectedPriority(HttpStreamPriority{kV3LowestPriority}); + stream->SetPriority( + QuicStreamPriority(HttpStreamPriority{kV3LowestPriority})); +} + +TEST_P(QuicSpdyStreamTest, StreamWaitsForAcks) { + Initialize(kShouldProcessData); + quiche::QuicheReferenceCountedPointer mock_ack_listener( + new StrictMock); + stream_->set_ack_listener(mock_ack_listener); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + // Stream is not waiting for acks initially. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + // Send kData1. + stream_->WriteOrBufferData("FooAndBar", false, nullptr); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(9, _)); + QuicByteCount newly_acked_length = 0; + EXPECT_TRUE(stream_->OnStreamFrameAcked(0, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + // Stream is not waiting for acks as all sent data is acked. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + // Send kData2. + stream_->WriteOrBufferData("FooAndBar", false, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Send FIN. + stream_->WriteOrBufferData("", true, nullptr); + // Fin only frame is not stored in send buffer. + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + + // kData2 is retransmitted. + EXPECT_CALL(*mock_ack_listener, OnPacketRetransmitted(9)); + stream_->OnStreamFrameRetransmitted(9, 9, false); + + // kData2 is acked. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(9, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(9, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + // Stream is waiting for acks as FIN is not acked. + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + // FIN is acked. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(0, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(18, 0, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); +} + +TEST_P(QuicSpdyStreamTest, StreamDataGetAckedMultipleTimes) { + Initialize(kShouldProcessData); + quiche::QuicheReferenceCountedPointer mock_ack_listener( + new StrictMock); + stream_->set_ack_listener(mock_ack_listener); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + // Send [0, 27) and fin. + stream_->WriteOrBufferData("FooAndBar", false, nullptr); + stream_->WriteOrBufferData("FooAndBar", false, nullptr); + stream_->WriteOrBufferData("FooAndBar", true, nullptr); + + // Ack [0, 9), [5, 22) and [18, 26) + // Verify [0, 9) 9 bytes are acked. + QuicByteCount newly_acked_length = 0; + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(9, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(0, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(2u, QuicStreamPeer::SendBuffer(stream_).size()); + // Verify [9, 22) 13 bytes are acked. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(13, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(5, 17, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Verify [22, 26) 4 bytes are acked. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(4, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(18, 8, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + + // Ack [0, 27). + // Verify [26, 27) 1 byte is acked. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(1, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(26, 1, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + + // Ack Fin. Verify OnPacketAcked is called. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(0, _)); + EXPECT_TRUE(stream_->OnStreamFrameAcked(27, 0, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + + // Ack [10, 27) and fin. + // No new data is acked, verify OnPacketAcked is not called. + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(_, _)).Times(0); + EXPECT_FALSE( + stream_->OnStreamFrameAcked(10, 17, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), &newly_acked_length)); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_FALSE(stream_->IsWaitingForAcks()); +} + +// HTTP/3 only. +TEST_P(QuicSpdyStreamTest, HeadersAckNotReportedWriteOrBufferBody) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + quiche::QuicheReferenceCountedPointer mock_ack_listener( + new StrictMock); + stream_->set_ack_listener(mock_ack_listener); + std::string body = "Test1"; + std::string body2(100, 'x'); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + stream_->WriteOrBufferBody(body, false); + stream_->WriteOrBufferBody(body2, true); + + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + quiche::QuicheBuffer header2 = HttpEncoder::SerializeDataFrameHeader( + body2.length(), quiche::SimpleBufferAllocator::Get()); + + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(body.length(), _)); + QuicStreamFrame frame(stream_->id(), false, 0, + absl::StrCat(header.AsStringView(), body)); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame), QuicTime::Delta::Zero(), + QuicTime::Zero())); + + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(0, _)); + QuicStreamFrame frame2(stream_->id(), false, header.size() + body.length(), + header2.AsStringView()); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame2), QuicTime::Delta::Zero(), + QuicTime::Zero())); + + EXPECT_CALL(*mock_ack_listener, OnPacketAcked(body2.length(), _)); + QuicStreamFrame frame3(stream_->id(), true, + header.size() + body.length() + header2.size(), body2); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame3), QuicTime::Delta::Zero(), + QuicTime::Zero())); + + EXPECT_TRUE( + QuicSpdyStreamPeer::unacked_frame_headers_offsets(stream_).Empty()); +} + +// HTTP/3 only. +TEST_P(QuicSpdyStreamTest, HeadersAckNotReportedWriteBodySlices) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + quiche::QuicheReferenceCountedPointer mock_ack_listener( + new StrictMock); + stream_->set_ack_listener(mock_ack_listener); + std::string body1 = "Test1"; + std::string body2(100, 'x'); + struct iovec body1_iov = {const_cast(body1.data()), body1.length()}; + struct iovec body2_iov = {const_cast(body2.data()), body2.length()}; + quiche::QuicheMemSliceStorage storage( + &body1_iov, 1, helper_.GetStreamSendBufferAllocator(), 1024); + quiche::QuicheMemSliceStorage storage2( + &body2_iov, 1, helper_.GetStreamSendBufferAllocator(), 1024); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + stream_->WriteBodySlices(storage.ToSpan(), false); + stream_->WriteBodySlices(storage2.ToSpan(), true); + + std::string data1 = DataFrame(body1); + std::string data2 = DataFrame(body2); + + EXPECT_CALL(*mock_ack_listener, + OnPacketAcked(body1.length() + body2.length(), _)); + QuicStreamFrame frame(stream_->id(), true, 0, data1 + data2); + EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame), QuicTime::Delta::Zero(), + QuicTime::Zero())); + + EXPECT_TRUE( + QuicSpdyStreamPeer::unacked_frame_headers_offsets(stream_).Empty()); +} + +// HTTP/3 only. +TEST_P(QuicSpdyStreamTest, HeaderBytesNotReportedOnRetransmission) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + quiche::QuicheReferenceCountedPointer mock_ack_listener( + new StrictMock); + stream_->set_ack_listener(mock_ack_listener); + std::string body1 = "Test1"; + std::string body2(100, 'x'); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AtLeast(1)); + stream_->WriteOrBufferBody(body1, false); + stream_->WriteOrBufferBody(body2, true); + + std::string data1 = DataFrame(body1); + std::string data2 = DataFrame(body2); + + EXPECT_CALL(*mock_ack_listener, OnPacketRetransmitted(body1.length())); + QuicStreamFrame frame(stream_->id(), false, 0, data1); + session_->OnStreamFrameRetransmitted(frame); + + EXPECT_CALL(*mock_ack_listener, OnPacketRetransmitted(body2.length())); + QuicStreamFrame frame2(stream_->id(), true, data1.length(), data2); + session_->OnStreamFrameRetransmitted(frame2); + + EXPECT_FALSE( + QuicSpdyStreamPeer::unacked_frame_headers_offsets(stream_).Empty()); +} + +TEST_P(QuicSpdyStreamTest, HeadersFrameOnRequestStream) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + std::string data = DataFrame(kDataFramePayload); + std::string trailers = + HeadersFrame({std::make_pair("custom-key", "custom-value")}); + + std::string stream_frame_payload = absl::StrCat(headers, data, trailers); + QuicStreamFrame frame(stream_->id(), false, 0, stream_frame_payload); + stream_->OnStreamFrame(frame); + + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + + // QuicSpdyStream only calls OnBodyAvailable() + // after the header list has been consumed. + EXPECT_EQ("", stream_->data()); + stream_->ConsumeHeaderList(); + EXPECT_EQ(kDataFramePayload, stream_->data()); + + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("custom-key", "custom-value"))); +} + +TEST_P(QuicSpdyStreamTest, ProcessBodyAfterTrailers) { + if (!UsesHttp3()) { + return; + } + + Initialize(!kShouldProcessData); + + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + std::string data = DataFrame(kDataFramePayload); + + // A header block that will take more than one block of sequencer buffer. + // This ensures that when the trailers are consumed, some buffer buckets will + // be freed. + Http2HeaderBlock trailers_block; + trailers_block["key1"] = std::string(10000, 'x'); + std::string trailers = HeadersFrame(trailers_block); + + // Feed all three HTTP/3 frames in a single stream frame. + std::string stream_frame_payload = absl::StrCat(headers, data, trailers); + QuicStreamFrame frame(stream_->id(), false, 0, stream_frame_payload); + stream_->OnStreamFrame(frame); + + stream_->ConsumeHeaderList(); + stream_->MarkTrailersConsumed(); + + EXPECT_TRUE(stream_->trailers_decompressed()); + EXPECT_EQ(trailers_block, stream_->received_trailers()); + + EXPECT_TRUE(stream_->HasBytesToRead()); + + // Consume data. + char buffer[2048]; + struct iovec vec; + vec.iov_base = buffer; + vec.iov_len = ABSL_ARRAYSIZE(buffer); + size_t bytes_read = stream_->Readv(&vec, 1); + EXPECT_EQ(kDataFramePayload, absl::string_view(buffer, bytes_read)); + + EXPECT_FALSE(stream_->HasBytesToRead()); +} + +// The test stream will receive a stream frame containing malformed headers and +// normal body. Make sure the http decoder stops processing body after the +// connection shuts down. +TEST_P(QuicSpdyStreamTest, MalformedHeadersStopHttpDecoder) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + + // Random bad headers. + std::string headers = + HeadersFrame(absl::HexStringToBytes("00002a94e7036261")); + std::string data = DataFrame(kDataFramePayload); + + std::string stream_frame_payload = absl::StrCat(headers, data); + QuicStreamFrame frame(stream_->id(), false, 0, stream_frame_payload); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_QPACK_DECOMPRESSION_FAILED, + MatchesRegex("Error decoding headers on stream \\d+: " + "Incomplete header block."), + _)) + .WillOnce( + (Invoke([this](QuicErrorCode error, const std::string& error_details, + ConnectionCloseBehavior connection_close_behavior) { + connection_->ReallyCloseConnection(error, error_details, + connection_close_behavior); + }))); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(*session_, OnConnectionClosed(_, _)) + .WillOnce(Invoke([this](const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) { + session_->ReallyOnConnectionClosed(frame, source); + })); + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(_, _, _)).Times(2); + stream_->OnStreamFrame(frame); +} + +// Regression test for https://crbug.com/1027895: a HEADERS frame triggers an +// error in QuicSpdyStream::OnHeadersFramePayload(). This closes the +// connection, freeing the buffer of QuicStreamSequencer. Therefore +// QuicStreamSequencer::MarkConsumed() must not be called from +// QuicSpdyStream::OnHeadersFramePayload(). +TEST_P(QuicSpdyStreamTest, DoNotMarkConsumedAfterQpackDecodingError) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + + { + testing::InSequence s; + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_QPACK_DECOMPRESSION_FAILED, + MatchesRegex("Error decoding headers on stream \\d+: " + "Invalid relative index."), + _)) + .WillOnce(( + Invoke([this](QuicErrorCode error, const std::string& error_details, + ConnectionCloseBehavior connection_close_behavior) { + connection_->ReallyCloseConnection(error, error_details, + connection_close_behavior); + }))); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(*session_, OnConnectionClosed(_, _)) + .WillOnce(Invoke([this](const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) { + session_->ReallyOnConnectionClosed(frame, source); + })); + } + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(stream_->id(), _, _)); + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(stream2_->id(), _, _)); + + // Invalid headers: Required Insert Count is zero, but the header block + // contains a dynamic table reference. + std::string headers = HeadersFrame(absl::HexStringToBytes("000080")); + QuicStreamFrame frame(stream_->id(), false, 0, headers); + stream_->OnStreamFrame(frame); +} + +TEST_P(QuicSpdyStreamTest, ImmediateHeaderDecodingWithDynamicTableEntries) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + auto decoder_send_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + + // Deliver dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); + + // HEADERS frame referencing first dynamic table entry. + std::string encoded_headers = absl::HexStringToBytes("020080"); + std::string headers = HeadersFrame(encoded_headers); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_->id(), encoded_headers.length())); + // Decoder stream type. + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + // Header acknowledgement. + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + EXPECT_CALL(debug_visitor, OnHeadersDecoded(stream_->id(), _)); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Headers can be decoded immediately. + EXPECT_TRUE(stream_->headers_decompressed()); + + // Verify headers. + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // DATA frame. + std::string data = DataFrame(kDataFramePayload); + EXPECT_CALL(debug_visitor, + OnDataFrameReceived(stream_->id(), strlen(kDataFramePayload))); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, /* offset = */ + headers.length(), data)); + EXPECT_EQ(kDataFramePayload, stream_->data()); + + // Deliver second dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("trailing", "foobar"); + + // Trailing HEADERS frame referencing second dynamic table entry. + std::string encoded_trailers = absl::HexStringToBytes("030080"); + std::string trailers = HeadersFrame(encoded_trailers); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_->id(), encoded_trailers.length())); + // Header acknowledgement. + EXPECT_CALL(*session_, WritevData(decoder_send_stream->id(), _, _, _, _, _)); + EXPECT_CALL(debug_visitor, OnHeadersDecoded(stream_->id(), _)); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), true, /* offset = */ + headers.length() + data.length(), + trailers)); + + // Trailers can be decoded immediately. + EXPECT_TRUE(stream_->trailers_decompressed()); + + // Verify trailers. + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("trailing", "foobar"))); + stream_->MarkTrailersConsumed(); +} + +TEST_P(QuicSpdyStreamTest, BlockedHeaderDecoding) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + // HEADERS frame referencing first dynamic table entry. + std::string encoded_headers = absl::HexStringToBytes("020080"); + std::string headers = HeadersFrame(encoded_headers); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_->id(), encoded_headers.length())); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream_->headers_decompressed()); + + auto decoder_send_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + + // Decoder stream type. + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + // Header acknowledgement. + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + EXPECT_CALL(debug_visitor, OnHeadersDecoded(stream_->id(), _)); + // Deliver dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); + EXPECT_TRUE(stream_->headers_decompressed()); + + // Verify headers. + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // DATA frame. + std::string data = DataFrame(kDataFramePayload); + EXPECT_CALL(debug_visitor, + OnDataFrameReceived(stream_->id(), strlen(kDataFramePayload))); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, /* offset = */ + headers.length(), data)); + EXPECT_EQ(kDataFramePayload, stream_->data()); + + // Trailing HEADERS frame referencing second dynamic table entry. + std::string encoded_trailers = absl::HexStringToBytes("030080"); + std::string trailers = HeadersFrame(encoded_trailers); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_->id(), encoded_trailers.length())); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), true, /* offset = */ + headers.length() + data.length(), + trailers)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream_->trailers_decompressed()); + + // Header acknowledgement. + EXPECT_CALL(*session_, WritevData(decoder_send_stream->id(), _, _, _, _, _)); + EXPECT_CALL(debug_visitor, OnHeadersDecoded(stream_->id(), _)); + // Deliver second dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("trailing", "foobar"); + EXPECT_TRUE(stream_->trailers_decompressed()); + + // Verify trailers. + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("trailing", "foobar"))); + stream_->MarkTrailersConsumed(); +} + +TEST_P(QuicSpdyStreamTest, AsyncErrorDecodingHeaders) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + + // HEADERS frame only referencing entry with absolute index 0 but with + // Required Insert Count = 2, which is incorrect. + std::string headers = HeadersFrame(absl::HexStringToBytes("030081")); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Even though entire header block is received and every referenced entry is + // available, decoding is blocked until insert count reaches the Required + // Insert Count value advertised in the header block prefix. + EXPECT_FALSE(stream_->headers_decompressed()); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_QPACK_DECOMPRESSION_FAILED, + MatchesRegex("Error decoding headers on stream \\d+: " + "Required Insert Count too large."), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + + // Deliver two dynamic table entries to decoder + // to trigger decoding of header block. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); +} + +// Regression test for https://crbug.com/1024263 and for +// https://crbug.com/1025209#c11. +TEST_P(QuicSpdyStreamTest, BlockedHeaderDecodingUnblockedWithBufferedError) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + + // Relative index 2 is invalid because it is larger than or equal to the Base. + std::string headers = HeadersFrame(absl::HexStringToBytes("020082")); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Decoding is blocked. + EXPECT_FALSE(stream_->headers_decompressed()); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_QPACK_DECOMPRESSION_FAILED, + MatchesRegex("Error decoding headers on stream \\d+: " + "Invalid relative index."), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + + // Deliver one dynamic table entry to decoder + // to trigger decoding of header block. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); +} + +TEST_P(QuicSpdyStreamTest, AsyncErrorDecodingTrailers) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + + // HEADERS frame referencing first dynamic table entry. + std::string headers = HeadersFrame(absl::HexStringToBytes("020080")); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream_->headers_decompressed()); + + auto decoder_send_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + + // Decoder stream type. + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + // Header acknowledgement. + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + // Deliver dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); + EXPECT_TRUE(stream_->headers_decompressed()); + + // Verify headers. + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // DATA frame. + std::string data = DataFrame(kDataFramePayload); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, /* offset = */ + headers.length(), data)); + EXPECT_EQ(kDataFramePayload, stream_->data()); + + // Trailing HEADERS frame only referencing entry with absolute index 0 but + // with Required Insert Count = 2, which is incorrect. + std::string trailers = HeadersFrame(absl::HexStringToBytes("030081")); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), true, /* offset = */ + headers.length() + data.length(), + trailers)); + + // Even though entire header block is received and every referenced entry is + // available, decoding is blocked until insert count reaches the Required + // Insert Count value advertised in the header block prefix. + EXPECT_FALSE(stream_->trailers_decompressed()); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_QPACK_DECOMPRESSION_FAILED, + MatchesRegex("Error decoding trailers on stream \\d+: " + "Required Insert Count too large."), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + + // Deliver second dynamic table entry to decoder + // to trigger decoding of trailing header block. + session_->qpack_decoder()->OnInsertWithoutNameReference("trailing", "foobar"); +} + +// Regression test for b/132603592: QPACK decoding unblocked after stream is +// closed. +TEST_P(QuicSpdyStreamTest, HeaderDecodingUnblockedAfterStreamClosed) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + // HEADERS frame referencing first dynamic table entry. + std::string encoded_headers = absl::HexStringToBytes("020080"); + std::string headers = HeadersFrame(encoded_headers); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_->id(), encoded_headers.length())); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream_->headers_decompressed()); + + // Decoder stream type and stream cancellation instruction. + auto decoder_send_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + + // Reset stream by this endpoint, for example, due to stream cancellation. + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + stream_->id(), QuicResetStreamError::FromInternal( + QUIC_STREAM_CANCELLED))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), _)); + stream_->Reset(QUIC_STREAM_CANCELLED); + + // Deliver dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); + + EXPECT_FALSE(stream_->headers_decompressed()); +} + +TEST_P(QuicSpdyStreamTest, HeaderDecodingUnblockedAfterResetReceived) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + session_->qpack_decoder()->OnSetDynamicTableCapacity(1024); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + // HEADERS frame referencing first dynamic table entry. + std::string encoded_headers = absl::HexStringToBytes("020080"); + std::string headers = HeadersFrame(encoded_headers); + EXPECT_CALL(debug_visitor, + OnHeadersFrameReceived(stream_->id(), encoded_headers.length())); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Decoding is blocked because dynamic table entry has not been received yet. + EXPECT_FALSE(stream_->headers_decompressed()); + + // Decoder stream type and stream cancellation instruction. + auto decoder_send_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + EXPECT_CALL(*session_, + WritevData(decoder_send_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + + // OnStreamReset() is called when RESET_STREAM frame is received from peer. + // This aborts header decompression. + stream_->OnStreamReset(QuicRstStreamFrame( + kInvalidControlFrameId, stream_->id(), QUIC_STREAM_CANCELLED, 0)); + + // Deliver dynamic table entry to decoder. + session_->qpack_decoder()->OnInsertWithoutNameReference("foo", "bar"); + EXPECT_FALSE(stream_->headers_decompressed()); +} + +class QuicSpdyStreamIncrementalConsumptionTest : public QuicSpdyStreamTest { + protected: + QuicSpdyStreamIncrementalConsumptionTest() : offset_(0), consumed_bytes_(0) {} + ~QuicSpdyStreamIncrementalConsumptionTest() override = default; + + // Create QuicStreamFrame with |payload| + // and pass it to stream_->OnStreamFrame(). + void OnStreamFrame(absl::string_view payload) { + QuicStreamFrame frame(stream_->id(), /* fin = */ false, offset_, payload); + stream_->OnStreamFrame(frame); + offset_ += payload.size(); + } + + // Return number of bytes marked consumed with sequencer + // since last NewlyConsumedBytes() call. + QuicStreamOffset NewlyConsumedBytes() { + QuicStreamOffset previously_consumed_bytes = consumed_bytes_; + consumed_bytes_ = stream_->sequencer()->NumBytesConsumed(); + return consumed_bytes_ - previously_consumed_bytes; + } + + // Read |size| bytes from the stream. + std::string ReadFromStream(QuicByteCount size) { + std::string buffer; + buffer.resize(size); + + struct iovec vec; + vec.iov_base = const_cast(buffer.data()); + vec.iov_len = size; + + size_t bytes_read = stream_->Readv(&vec, 1); + EXPECT_EQ(bytes_read, size); + + return buffer; + } + + private: + QuicStreamOffset offset_; + QuicStreamOffset consumed_bytes_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdyStreamIncrementalConsumptionTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +// Test that stream bytes are consumed (by calling +// sequencer()->MarkConsumed()) incrementally, as soon as possible. +TEST_P(QuicSpdyStreamIncrementalConsumptionTest, OnlyKnownFrames) { + if (!UsesHttp3()) { + return; + } + + Initialize(!kShouldProcessData); + + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + + // All HEADERS frame bytes are consumed even if the frame is not received + // completely. + OnStreamFrame(absl::string_view(headers).substr(0, headers.size() - 1)); + EXPECT_EQ(headers.size() - 1, NewlyConsumedBytes()); + + // The rest of the HEADERS frame is also consumed immediately. + OnStreamFrame(absl::string_view(headers).substr(headers.size() - 1)); + EXPECT_EQ(1u, NewlyConsumedBytes()); + + // Verify headers. + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // DATA frame. + absl::string_view data_payload(kDataFramePayload); + std::string data_frame = DataFrame(data_payload); + QuicByteCount data_frame_header_length = + data_frame.size() - data_payload.size(); + + // DATA frame header is consumed. + // DATA frame payload is not consumed because payload has to be buffered. + OnStreamFrame(data_frame); + EXPECT_EQ(data_frame_header_length, NewlyConsumedBytes()); + + // Consume all but last byte of data. + EXPECT_EQ(data_payload.substr(0, data_payload.size() - 1), + ReadFromStream(data_payload.size() - 1)); + EXPECT_EQ(data_payload.size() - 1, NewlyConsumedBytes()); + + std::string trailers = + HeadersFrame({std::make_pair("custom-key", "custom-value")}); + + // No bytes are consumed, because last byte of DATA payload is still buffered. + OnStreamFrame(absl::string_view(trailers).substr(0, trailers.size() - 1)); + EXPECT_EQ(0u, NewlyConsumedBytes()); + + // Reading last byte of DATA payload triggers consumption of all data received + // so far, even though last HEADERS frame has not been received completely. + EXPECT_EQ(data_payload.substr(data_payload.size() - 1), ReadFromStream(1)); + EXPECT_EQ(1 + trailers.size() - 1, NewlyConsumedBytes()); + + // Last byte of trailers is immediately consumed. + OnStreamFrame(absl::string_view(trailers).substr(trailers.size() - 1)); + EXPECT_EQ(1u, NewlyConsumedBytes()); + + // Verify trailers. + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("custom-key", "custom-value"))); +} + +TEST_P(QuicSpdyStreamIncrementalConsumptionTest, ReceiveUnknownFrame) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + StrictMock debug_visitor; + session_->set_debug_visitor(&debug_visitor); + + EXPECT_CALL(debug_visitor, + OnUnknownFrameReceived(stream_->id(), /* frame_type = */ 0x21, + /* payload_length = */ 3)); + std::string unknown_frame = UnknownFrame(0x21, "foo"); + OnStreamFrame(unknown_frame); +} + +TEST_P(QuicSpdyStreamIncrementalConsumptionTest, UnknownFramesInterleaved) { + if (!UsesHttp3()) { + return; + } + + Initialize(!kShouldProcessData); + + // Unknown frame of reserved type before HEADERS is consumed immediately. + std::string unknown_frame1 = UnknownFrame(0x21, "foo"); + OnStreamFrame(unknown_frame1); + EXPECT_EQ(unknown_frame1.size(), NewlyConsumedBytes()); + + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + + // All HEADERS frame bytes are consumed even if the frame is not received + // completely. + OnStreamFrame(absl::string_view(headers).substr(0, headers.size() - 1)); + EXPECT_EQ(headers.size() - 1, NewlyConsumedBytes()); + + // The rest of the HEADERS frame is also consumed immediately. + OnStreamFrame(absl::string_view(headers).substr(headers.size() - 1)); + EXPECT_EQ(1u, NewlyConsumedBytes()); + + // Verify headers. + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // Frame of unknown, not reserved type between HEADERS and DATA is consumed + // immediately. + std::string unknown_frame2 = UnknownFrame(0x3a, ""); + OnStreamFrame(unknown_frame2); + EXPECT_EQ(unknown_frame2.size(), NewlyConsumedBytes()); + + // DATA frame. + absl::string_view data_payload(kDataFramePayload); + std::string data_frame = DataFrame(data_payload); + QuicByteCount data_frame_header_length = + data_frame.size() - data_payload.size(); + + // DATA frame header is consumed. + // DATA frame payload is not consumed because payload has to be buffered. + OnStreamFrame(data_frame); + EXPECT_EQ(data_frame_header_length, NewlyConsumedBytes()); + + // Frame of unknown, not reserved type is not consumed because DATA payload is + // still buffered. + std::string unknown_frame3 = UnknownFrame(0x39, "bar"); + OnStreamFrame(unknown_frame3); + EXPECT_EQ(0u, NewlyConsumedBytes()); + + // Consume all but last byte of data. + EXPECT_EQ(data_payload.substr(0, data_payload.size() - 1), + ReadFromStream(data_payload.size() - 1)); + EXPECT_EQ(data_payload.size() - 1, NewlyConsumedBytes()); + + std::string trailers = + HeadersFrame({std::make_pair("custom-key", "custom-value")}); + + // No bytes are consumed, because last byte of DATA payload is still buffered. + OnStreamFrame(absl::string_view(trailers).substr(0, trailers.size() - 1)); + EXPECT_EQ(0u, NewlyConsumedBytes()); + + // Reading last byte of DATA payload triggers consumption of all data received + // so far, even though last HEADERS frame has not been received completely. + EXPECT_EQ(data_payload.substr(data_payload.size() - 1), ReadFromStream(1)); + EXPECT_EQ(1 + unknown_frame3.size() + trailers.size() - 1, + NewlyConsumedBytes()); + + // Last byte of trailers is immediately consumed. + OnStreamFrame(absl::string_view(trailers).substr(trailers.size() - 1)); + EXPECT_EQ(1u, NewlyConsumedBytes()); + + // Verify trailers. + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("custom-key", "custom-value"))); + + // Unknown frame of reserved type after trailers is consumed immediately. + std::string unknown_frame4 = UnknownFrame(0x40, ""); + OnStreamFrame(unknown_frame4); + EXPECT_EQ(unknown_frame4.size(), NewlyConsumedBytes()); +} + +// Close connection if a DATA frame is received before a HEADERS frame. +TEST_P(QuicSpdyStreamTest, DataBeforeHeaders) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Closing the connection is mocked out in tests. Instead, simply stop + // reading data at the stream level to prevent QuicSpdyStream from blowing up. + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM, + "Unexpected DATA frame received.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)) + .WillOnce(InvokeWithoutArgs([this]() { stream_->StopReading(); })); + + std::string data = DataFrame(kDataFramePayload); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, data)); +} + +// Close connection if a HEADERS frame is received after the trailing HEADERS. +TEST_P(QuicSpdyStreamTest, TrailersAfterTrailers) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Receive and consume headers. + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + QuicStreamOffset offset = 0; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), false, offset, headers)); + offset += headers.size(); + + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // Receive data. It is consumed by TestStream. + std::string data = DataFrame(kDataFramePayload); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, offset, data)); + offset += data.size(); + + EXPECT_EQ(kDataFramePayload, stream_->data()); + + // Receive and consume trailers. + std::string trailers1 = + HeadersFrame({std::make_pair("custom-key", "custom-value")}); + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), false, offset, trailers1)); + offset += trailers1.size(); + + EXPECT_TRUE(stream_->trailers_decompressed()); + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("custom-key", "custom-value"))); + + // Closing the connection is mocked out in tests. Instead, simply stop + // reading data at the stream level to prevent QuicSpdyStream from blowing up. + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM, + "HEADERS frame received after trailing HEADERS.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)) + .WillOnce(InvokeWithoutArgs([this]() { stream_->StopReading(); })); + + // Receive another HEADERS frame, with no header fields. + std::string trailers2 = HeadersFrame(Http2HeaderBlock()); + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), false, offset, trailers2)); +} + +// Regression test for https://crbug.com/978733. +// Close connection if a DATA frame is received after the trailing HEADERS. +TEST_P(QuicSpdyStreamTest, DataAfterTrailers) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Receive and consume headers. + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + QuicStreamOffset offset = 0; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), false, offset, headers)); + offset += headers.size(); + + EXPECT_THAT(stream_->header_list(), ElementsAre(Pair("foo", "bar"))); + stream_->ConsumeHeaderList(); + + // Receive data. It is consumed by TestStream. + std::string data1 = DataFrame(kDataFramePayload); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, offset, data1)); + offset += data1.size(); + EXPECT_EQ(kDataFramePayload, stream_->data()); + + // Receive trailers, with single header field "custom-key: custom-value". + std::string trailers = + HeadersFrame({std::make_pair("custom-key", "custom-value")}); + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), false, offset, trailers)); + offset += trailers.size(); + + EXPECT_THAT(stream_->received_trailers(), + ElementsAre(Pair("custom-key", "custom-value"))); + + // Closing the connection is mocked out in tests. Instead, simply stop + // reading data at the stream level to prevent QuicSpdyStream from blowing up. + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM, + "Unexpected DATA frame received.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)) + .WillOnce(InvokeWithoutArgs([this]() { stream_->StopReading(); })); + + // Receive more data. + std::string data2 = DataFrame("This payload should not be proccessed."); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, offset, data2)); +} + +// SETTINGS frames are invalid on bidirectional streams. If one is received, +// the connection is closed. No more data should be processed. +TEST_P(QuicSpdyStreamTest, StopProcessingIfConnectionClosed) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // SETTINGS frame with empty payload. + std::string settings = absl::HexStringToBytes("0400"); + + // HEADERS frame. + // Since it arrives after a SETTINGS frame, it should never be read. + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + + // Combine the two frames to make sure they are processed in a single + // QuicSpdyStream::OnDataAvailable() call. + std::string frames = absl::StrCat(settings, headers); + + EXPECT_EQ(0u, stream_->sequencer()->NumBytesConsumed()); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM, _, _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + EXPECT_CALL(*session_, OnConnectionClosed(_, _)); + + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), /* fin = */ false, + /* offset = */ 0, frames)); + + EXPECT_EQ(0u, stream_->sequencer()->NumBytesConsumed()); +} + +// Stream Cancellation instruction is sent on QPACK decoder stream +// when stream is reset. +TEST_P(QuicSpdyStreamTest, StreamCancellationWhenStreamReset) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + auto qpack_decoder_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + // Stream type. + EXPECT_CALL(*session_, + WritevData(qpack_decoder_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + // Stream cancellation. + EXPECT_CALL(*session_, + WritevData(qpack_decoder_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + stream_->id(), QuicResetStreamError::FromInternal( + QUIC_STREAM_CANCELLED))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), _)); + + stream_->Reset(QUIC_STREAM_CANCELLED); +} + +// Stream Cancellation instruction is sent on QPACK decoder stream +// when RESET_STREAM frame is received. +TEST_P(QuicSpdyStreamTest, StreamCancellationOnResetReceived) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + auto qpack_decoder_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); + // Stream type. + EXPECT_CALL(*session_, + WritevData(qpack_decoder_stream->id(), /* write_length = */ 1, + /* offset = */ 0, _, _, _)); + // Stream cancellation. + EXPECT_CALL(*session_, + WritevData(qpack_decoder_stream->id(), /* write_length = */ 1, + /* offset = */ 1, _, _, _)); + + stream_->OnStreamReset(QuicRstStreamFrame( + kInvalidControlFrameId, stream_->id(), QUIC_STREAM_CANCELLED, 0)); +} + +TEST_P(QuicSpdyStreamTest, WriteHeadersReturnValue) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + testing::InSequence s; + + // Enable QPACK dynamic table. + session_->OnSetting(SETTINGS_QPACK_MAX_TABLE_CAPACITY, 1024); + session_->OnSetting(SETTINGS_QPACK_BLOCKED_STREAMS, 1); + + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + + QpackSendStream* encoder_stream = + QuicSpdySessionPeer::GetQpackEncoderSendStream(session_.get()); + EXPECT_CALL(*session_, WritevData(encoder_stream->id(), _, _, _, _, _)) + .Times(AnyNumber()); + + size_t bytes_written = 0; + EXPECT_CALL(*session_, + WritevData(stream_->id(), _, /* offset = */ 0, _, _, _)) + .WillOnce( + DoAll(SaveArg<1>(&bytes_written), + Invoke(session_.get(), &MockQuicSpdySession::ConsumeData))); + + Http2HeaderBlock request_headers; + request_headers["foo"] = "bar"; + size_t write_headers_return_value = + stream_->WriteHeaders(std::move(request_headers), /*fin=*/true, nullptr); + EXPECT_TRUE(stream_->fin_sent()); + // bytes_written includes HEADERS frame header. + EXPECT_GT(bytes_written, write_headers_return_value); +} + +// Regression test for https://crbug.com/1177662. +// RESET_STREAM with QUIC_STREAM_NO_ERROR should not be treated in a special +// way: it should close the read side but not the write side. +TEST_P(QuicSpdyStreamTest, TwoResetStreamFrames) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(AnyNumber()); + + QuicRstStreamFrame rst_frame1(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, /* bytes_written = */ 0); + stream_->OnStreamReset(rst_frame1); + EXPECT_TRUE(stream_->read_side_closed()); + EXPECT_FALSE(stream_->write_side_closed()); + + QuicRstStreamFrame rst_frame2(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_NO_ERROR, /* bytes_written = */ 0); + stream_->OnStreamReset(rst_frame2); + EXPECT_TRUE(stream_->read_side_closed()); + EXPECT_FALSE(stream_->write_side_closed()); +} + +TEST_P(QuicSpdyStreamTest, ProcessOutgoingWebTransportHeaders) { + if (!UsesHttp3()) { + return; + } + + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc); + session_->EnableWebTransport(); + session_->OnSetting(SETTINGS_ENABLE_CONNECT_PROTOCOL, 1); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kRfc); + + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(AnyNumber()); + + spdy::Http2HeaderBlock headers; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "webtransport"; + stream_->WriteHeaders(std::move(headers), /*fin=*/false, nullptr); + ASSERT_TRUE(stream_->web_transport() != nullptr); + EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); +} + +TEST_P(QuicSpdyStreamTest, ProcessIncomingWebTransportHeaders) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc); + session_->EnableWebTransport(); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kRfc); + + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + headers_["sec-webtransport-http3-draft02"] = "1"; + + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + ProcessHeaders(false, headers_); + EXPECT_EQ("", stream_->data()); + EXPECT_FALSE(stream_->header_list().empty()); + EXPECT_FALSE(stream_->IsDoneReading()); + ASSERT_TRUE(stream_->web_transport() != nullptr); + EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); +} + +TEST_P(QuicSpdyStreamTest, ReceiveHttpDatagram) { + if (!UsesHttp3()) { + return; + } + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kRfc); + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + ProcessHeaders(false, headers_); + SavingHttp3DatagramVisitor h3_datagram_visitor; + ASSERT_EQ(QuicDataWriter::GetVarInt62Len(stream_->id()), 1); + std::array datagram; + datagram[0] = stream_->id(); + for (size_t i = 1; i < datagram.size(); i++) { + datagram[i] = i; + } + + stream_->RegisterHttp3DatagramVisitor(&h3_datagram_visitor); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT( + h3_datagram_visitor.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), std::string(&datagram[1], datagram.size() - 1)})); + // Test move. + SavingHttp3DatagramVisitor h3_datagram_visitor2; + stream_->ReplaceHttp3DatagramVisitor(&h3_datagram_visitor2); + EXPECT_TRUE(h3_datagram_visitor2.received_h3_datagrams().empty()); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT( + h3_datagram_visitor2.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), std::string(&datagram[1], datagram.size() - 1)})); + // Cleanup. + stream_->UnregisterHttp3DatagramVisitor(); +} + +TEST_P(QuicSpdyStreamTest, SendHttpDatagram) { + if (!UsesHttp3()) { + return; + } + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kRfc); + std::string http_datagram_payload = {1, 2, 3, 4, 5, 6}; + EXPECT_CALL(*connection_, SendMessage(1, _, false)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + EXPECT_EQ(stream_->SendHttp3Datagram(http_datagram_payload), + MESSAGE_STATUS_SUCCESS); +} + +TEST_P(QuicSpdyStreamTest, GetMaxDatagramSize) { + if (!UsesHttp3()) { + return; + } + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kRfc); + EXPECT_GT(stream_->GetMaxDatagramSize(), 512u); +} + +TEST_P(QuicSpdyStreamTest, Capsules) { + if (!UsesHttp3()) { + return; + } + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kRfc); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kRfc); + SavingHttp3DatagramVisitor h3_datagram_visitor; + stream_->RegisterHttp3DatagramVisitor(&h3_datagram_visitor); + SavingConnectIpVisitor connect_ip_visitor; + stream_->RegisterConnectIpVisitor(&connect_ip_visitor); + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "fake-capsule-protocol"; + ProcessHeaders(/*fin=*/false, headers_); + // Datagram capsule. + std::string http_datagram_payload = {1, 2, 3, 4, 5, 6}; + stream_->OnCapsule(Capsule::Datagram(http_datagram_payload)); + EXPECT_THAT(h3_datagram_visitor.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), http_datagram_payload})); + // Address assign capsule. + quiche::PrefixWithId ip_prefix_with_id; + ip_prefix_with_id.request_id = 1; + quiche::QuicheIpAddress ip_address; + ip_address.FromString("::"); + ip_prefix_with_id.ip_prefix = + quiche::QuicheIpPrefix(ip_address, /*prefix_length=*/96); + Capsule address_assign_capsule = Capsule::AddressAssign(); + address_assign_capsule.address_assign_capsule().assigned_addresses.push_back( + ip_prefix_with_id); + stream_->OnCapsule(address_assign_capsule); + EXPECT_THAT(connect_ip_visitor.received_address_assign_capsules(), + ElementsAre(address_assign_capsule.address_assign_capsule())); + // Address request capsule. + Capsule address_request_capsule = Capsule::AddressRequest(); + address_request_capsule.address_request_capsule() + .requested_addresses.push_back(ip_prefix_with_id); + stream_->OnCapsule(address_request_capsule); + EXPECT_THAT(connect_ip_visitor.received_address_request_capsules(), + ElementsAre(address_request_capsule.address_request_capsule())); + // Route advertisement capsule. + Capsule route_advertisement_capsule = Capsule::RouteAdvertisement(); + IpAddressRange ip_address_range; + ip_address_range.start_ip_address.FromString("192.0.2.24"); + ip_address_range.end_ip_address.FromString("192.0.2.42"); + ip_address_range.ip_protocol = 0; + route_advertisement_capsule.route_advertisement_capsule() + .ip_address_ranges.push_back(ip_address_range); + stream_->OnCapsule(route_advertisement_capsule); + EXPECT_THAT( + connect_ip_visitor.received_route_advertisement_capsules(), + ElementsAre(route_advertisement_capsule.route_advertisement_capsule())); + // Unknown capsule. + uint64_t capsule_type = 0x17u; + std::string capsule_payload = {1, 2, 3, 4}; + Capsule unknown_capsule = Capsule::Unknown(capsule_type, capsule_payload); + stream_->OnCapsule(unknown_capsule); + EXPECT_THAT(h3_datagram_visitor.received_unknown_capsules(), + ElementsAre(SavingHttp3DatagramVisitor::SavedUnknownCapsule{ + stream_->id(), capsule_type, capsule_payload})); + // Cleanup. + stream_->UnregisterHttp3DatagramVisitor(); + stream_->UnregisterConnectIpVisitor(); +} + +TEST_P(QuicSpdyStreamTest, + QUIC_TEST_DISABLED_IN_CHROME(HeadersAccumulatorNullptr)) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + + // Creates QpackDecodedHeadersAccumulator in + // `qpack_decoded_headers_accumulator_`. + std::string headers = HeadersFrame({std::make_pair("foo", "bar")}); + stream_->OnStreamFrame(QuicStreamFrame(stream_->id(), false, 0, headers)); + + // Resets `qpack_decoded_headers_accumulator_`. + stream_->OnHeadersDecoded({}, false); + + EXPECT_QUIC_BUG( + { + EXPECT_CALL(*connection_, CloseConnection(_, _, _)); + // This private method should never be called when + // `qpack_decoded_headers_accumulator_` is nullptr. + EXPECT_FALSE(QuicSpdyStreamPeer::OnHeadersFrameEnd(stream_)); + }, + "b215142466_OnHeadersFrameEnd"); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/spdy_server_push_utils.cc b/quiche/quic/core/http/spdy_server_push_utils.cc new file mode 100644 index 000000000000..22dc13fa9f80 --- /dev/null +++ b/quiche/quic/core/http/spdy_server_push_utils.cc @@ -0,0 +1,215 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/spdy_server_push_utils.h" + +#include "absl/strings/string_view.h" +#include "url/gurl.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +// static +std::string SpdyServerPushUtils::GetPromisedUrlFromHeaders( + const Http2HeaderBlock& headers) { + // RFC 7540, Section 8.1.2.3: All HTTP/2 requests MUST include exactly + // one valid value for the ":method", ":scheme", and ":path" pseudo-header + // fields, unless it is a CONNECT request. + + // RFC 7540, Section 8.2.1: The header fields in PUSH_PROMISE and any + // subsequent CONTINUATION frames MUST be a valid and complete set of request + // header fields (Section 8.1.2.3). The server MUST include a method in the + // ":method" pseudo-header field that is safe and cacheable. + // + // RFC 7231, Section 4.2.1: Of the request methods defined by this + // specification, the GET, HEAD, OPTIONS, and TRACE methods are defined to be + // safe. + // + // RFC 7231, Section 4.2.1: ... this specification defines GET, HEAD, and + // POST as cacheable, ... + // + // So the only methods allowed in a PUSH_PROMISE are GET and HEAD. + Http2HeaderBlock::const_iterator it = headers.find(":method"); + if (it == headers.end() || (it->second != "GET" && it->second != "HEAD")) { + return std::string(); + } + + it = headers.find(":scheme"); + if (it == headers.end() || it->second.empty()) { + return std::string(); + } + absl::string_view scheme = it->second; + + // RFC 7540, Section 8.2: The server MUST include a value in the + // ":authority" pseudo-header field for which the server is authoritative + // (see Section 10.1). + it = headers.find(":authority"); + if (it == headers.end() || it->second.empty()) { + return std::string(); + } + absl::string_view authority = it->second; + + // RFC 7540, Section 8.1.2.3 requires that the ":path" pseudo-header MUST + // NOT be empty for "http" or "https" URIs; + // + // However, to ensure the scheme is consistently canonicalized, that check + // is deferred to implementations in QuicUrlUtils::GetPushPromiseUrl(). + it = headers.find(":path"); + if (it == headers.end()) { + return std::string(); + } + absl::string_view path = it->second; + + return GetPushPromiseUrl(scheme, authority, path); +} + +// static +std::string SpdyServerPushUtils::GetPromisedHostNameFromHeaders( + const Http2HeaderBlock& headers) { + // TODO(fayang): Consider just checking out the value of the ":authority" key + // in headers. + return GURL(GetPromisedUrlFromHeaders(headers)).host(); +} + +// static +bool SpdyServerPushUtils::PromisedUrlIsValid(const Http2HeaderBlock& headers) { + std::string url(GetPromisedUrlFromHeaders(headers)); + return !url.empty() && GURL(url).is_valid(); +} + +// static +std::string SpdyServerPushUtils::GetPushPromiseUrl(absl::string_view scheme, + absl::string_view authority, + absl::string_view path) { + // RFC 7540, Section 8.1.2.3: The ":path" pseudo-header field includes the + // path and query parts of the target URI (the "path-absolute" production + // and optionally a '?' character followed by the "query" production (see + // Sections 3.3 and 3.4 of RFC3986). A request in asterisk form includes the + // value '*' for the ":path" pseudo-header field. + // + // This pseudo-header field MUST NOT be empty for "http" or "https" URIs; + // "http" or "https" URIs that do not contain a path MUST include a value of + // '/'. The exception to this rule is an OPTIONS request for an "http" or + // "https" URI that does not include a path component; these MUST include a + // ":path" pseudo-header with a value of '*' (see RFC7230, Section 5.3.4). + // + // In addition to the above restriction from RFC 7540, note that RFC3986 + // defines the "path-absolute" construction as starting with "/" but not "//". + // + // RFC 7540, Section 8.2.1: The header fields in PUSH_PROMISE and any + // subsequent CONTINUATION frames MUST be a valid and complete set of request + // header fields (Section 8.1.2.3). The server MUST include a method in the + // ":method" pseudo-header field that is safe and cacheable. + // + // RFC 7231, Section 4.2.1: + // ... this specification defines GET, HEAD, and POST as cacheable, ... + // + // Since the OPTIONS method is not cacheable, it cannot be the method of a + // PUSH_PROMISE. Therefore, the exception mentioned in RFC 7540, Section + // 8.1.2.3 about OPTIONS requests does not apply here (i.e. ":path" cannot be + // "*"). + if (path.empty() || path[0] != '/' || (path.size() >= 2 && path[1] == '/')) { + return std::string(); + } + + // Validate the scheme; this is to ensure a scheme of "foo://bar" is not + // parsed as a URL of "foo://bar://baz" when combined with a host of "baz". + std::string canonical_scheme; + url::StdStringCanonOutput canon_scheme_output(&canonical_scheme); + url::Component canon_component; + url::Component scheme_component(0, scheme.size()); + + if (!url::CanonicalizeScheme(scheme.data(), scheme_component, + &canon_scheme_output, &canon_component) || + !canon_component.is_nonempty() || canon_component.begin != 0) { + return std::string(); + } + canonical_scheme.resize(canon_component.len + 1); + + // Validate the authority; this is to ensure an authority such as + // "host/path" is not accepted, as when combined with a scheme like + // "http://", could result in a URL of "http://host/path". + url::Component auth_component(0, authority.size()); + url::Component username_component; + url::Component password_component; + url::Component host_component; + url::Component port_component; + + url::ParseAuthority(authority.data(), auth_component, &username_component, + &password_component, &host_component, &port_component); + + // RFC 7540, Section 8.1.2.3: The authority MUST NOT include the deprecated + // "userinfo" subcomponent for "http" or "https" schemed URIs. + // + // Note: Although |canonical_scheme| has not yet been checked for that, as + // it is performed later in processing, only "http" and "https" schemed + // URIs are supported for PUSH. + if (username_component.is_valid() || password_component.is_valid()) { + return std::string(); + } + + // Failed parsing or no host present. ParseAuthority() will ensure that + // host_component + port_component cover the entire string, if + // username_component and password_component are not present. + if (!host_component.is_nonempty()) { + return std::string(); + } + + // Validate the port (if present; it's optional). + int parsed_port_number = url::PORT_INVALID; + if (port_component.is_nonempty()) { + parsed_port_number = url::ParsePort(authority.data(), port_component); + if (parsed_port_number < 0 && parsed_port_number != url::PORT_UNSPECIFIED) { + return std::string(); + } + } + + // Validate the host by attempting to canonicalize it. Invalid characters + // will result in a canonicalization failure (e.g. '/') + std::string canon_host; + url::StdStringCanonOutput canon_host_output(&canon_host); + canon_component.reset(); + if (!url::CanonicalizeHost(authority.data(), host_component, + &canon_host_output, &canon_component) || + !canon_component.is_nonempty() || canon_component.begin != 0) { + return std::string(); + } + + // At this point, "authority" has been validated to either be of the form + // 'host:port' or 'host', with 'host' being a valid domain or IP address, + // and 'port' (if present), being a valid port. Attempt to construct a + // URL of just the (scheme, host, port), which should be safe and will not + // result in ambiguous parsing. + // + // This also enforces that all PUSHed URLs are either HTTP or HTTPS-schemed + // URIs, consistent with the other restrictions enforced above. + // + // Note: url::CanonicalizeScheme() will have added the ':' to + // |canonical_scheme|. + GURL origin_url(canonical_scheme + "//" + std::string(authority)); + if (!origin_url.is_valid() || !origin_url.SchemeIsHTTPOrHTTPS() || + // The following checks are merely defense in depth. + origin_url.has_username() || origin_url.has_password() || + (origin_url.has_path() && origin_url.path_piece() != "/") || + origin_url.has_query() || origin_url.has_ref()) { + return std::string(); + } + + // Attempt to parse the path. + std::string spec = origin_url.GetWithEmptyPath().spec(); + spec.pop_back(); // Remove the '/', as ":path" must contain it. + spec.append(std::string(path)); + + // Attempt to parse the full URL, with the path as well. Ensure there is no + // fragment to the query. + GURL full_url(spec); + if (!full_url.is_valid() || full_url.has_ref()) { + return std::string(); + } + + return full_url.spec(); +} + +} // namespace quic diff --git a/quiche/quic/core/http/spdy_server_push_utils.h b/quiche/quic/core/http/spdy_server_push_utils.h new file mode 100644 index 000000000000..a924158bf34f --- /dev/null +++ b/quiche/quic/core/http/spdy_server_push_utils.h @@ -0,0 +1,43 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_SPDY_SERVER_PUSH_UTILS_H_ +#define QUICHE_QUIC_CORE_HTTP_SPDY_SERVER_PUSH_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE SpdyServerPushUtils { + public: + SpdyServerPushUtils() = delete; + + // Returns a canonicalized URL composed from the :scheme, :authority, and + // :path headers of a PUSH_PROMISE. Returns empty string if the headers do not + // conform to HTTP/2 spec or if the ":method" header contains a forbidden + // method for PUSH_PROMISE. + static std::string GetPromisedUrlFromHeaders( + const spdy::Http2HeaderBlock& headers); + + // Returns hostname, or empty string if missing. + static std::string GetPromisedHostNameFromHeaders( + const spdy::Http2HeaderBlock& headers); + + // Returns true if result of |GetPromisedUrlFromHeaders()| is non-empty + // and is a well-formed URL. + static bool PromisedUrlIsValid(const spdy::Http2HeaderBlock& headers); + + // Returns a canonical, valid URL for a PUSH_PROMISE with the specified + // ":scheme", ":authority", and ":path" header fields, or an empty + // string if the resulting URL is not valid or supported. + static std::string GetPushPromiseUrl(absl::string_view scheme, + absl::string_view authority, + absl::string_view path); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_SPDY_SERVER_PUSH_UTILS_H_ diff --git a/quiche/quic/core/http/spdy_server_push_utils_test.cc b/quiche/quic/core/http/spdy_server_push_utils_test.cc new file mode 100644 index 000000000000..f2d1855107bc --- /dev/null +++ b/quiche/quic/core/http/spdy_server_push_utils_test.cc @@ -0,0 +1,221 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/spdy_server_push_utils.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +using spdy::Http2HeaderBlock; + +namespace quic { +namespace test { + +using GetPromisedUrlFromHeaders = QuicTest; + +TEST_F(GetPromisedUrlFromHeaders, Basic) { + Http2HeaderBlock headers; + headers[":method"] = "GET"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); + headers[":scheme"] = "https"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); + headers[":authority"] = "www.google.com"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); + headers[":path"] = "/index.html"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), + "https://www.google.com/index.html"); + headers["key1"] = "value1"; + headers["key2"] = "value2"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), + "https://www.google.com/index.html"); +} + +TEST_F(GetPromisedUrlFromHeaders, Connect) { + Http2HeaderBlock headers; + headers[":method"] = "CONNECT"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); + headers[":authority"] = "www.google.com"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); + headers[":scheme"] = "https"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); + headers[":path"] = "https"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); +} + +TEST_F(GetPromisedUrlFromHeaders, InvalidUserinfo) { + Http2HeaderBlock headers; + headers[":method"] = "GET"; + headers[":authority"] = "user@www.google.com"; + headers[":scheme"] = "https"; + headers[":path"] = "/"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); +} + +TEST_F(GetPromisedUrlFromHeaders, InvalidPath) { + Http2HeaderBlock headers; + headers[":method"] = "GET"; + headers[":authority"] = "www.google.com"; + headers[":scheme"] = "https"; + headers[":path"] = ""; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedUrlFromHeaders(headers), ""); +} + +using GetPromisedHostNameFromHeaders = QuicTest; + +TEST_F(GetPromisedHostNameFromHeaders, NormalUsage) { + Http2HeaderBlock headers; + headers[":method"] = "GET"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), ""); + headers[":scheme"] = "https"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), ""); + headers[":authority"] = "www.google.com"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), ""); + headers[":path"] = "/index.html"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), + "www.google.com"); + headers["key1"] = "value1"; + headers["key2"] = "value2"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), + "www.google.com"); + headers[":authority"] = "www.google.com:6666"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), + "www.google.com"); + headers[":authority"] = "192.168.1.1"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), + "192.168.1.1"); + headers[":authority"] = "192.168.1.1:6666"; + EXPECT_EQ(SpdyServerPushUtils::GetPromisedHostNameFromHeaders(headers), + "192.168.1.1"); +} + +using PushPromiseUrlTest = QuicTest; + +TEST_F(PushPromiseUrlTest, GetPushPromiseUrl) { + // Test rejection of various inputs. + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("file", "localhost", + "/etc/password")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl( + "file", "", "/C:/Windows/System32/Config/")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl( + "", "https://www.google.com", "/")); + + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https://www.google.com", + "www.google.com", "/")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https://", + "www.google.com", "/")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https", "", "/")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https", "", + "www.google.com/")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https", + "www.google.com/", "/")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https", + "www.google.com", "")); + EXPECT_EQ("", SpdyServerPushUtils::GetPushPromiseUrl("https", "www.google", + ".com/")); + + // Test acception/rejection of various input combinations. + // |input_headers| is an array of pairs. The first value of each pair is a + // string that will be used as one of the inputs of GetPushPromiseUrl(). The + // second value of each pair is a bitfield where the lowest 3 bits indicate + // for which headers that string is valid (in a PUSH_PROMISE). For example, + // the string "http" would be valid for both the ":scheme" and ":authority" + // headers, so the bitfield paired with it is set to SCHEME | AUTH. + const unsigned char SCHEME = (1u << 0); + const unsigned char AUTH = (1u << 1); + const unsigned char PATH = (1u << 2); + std::vector> input_headers = { + {"http", SCHEME | AUTH}, + {"https", SCHEME | AUTH}, + {"hTtP", SCHEME | AUTH}, + {"HTTPS", SCHEME | AUTH}, + {"www.google.com", AUTH}, + {"90af90e0", AUTH}, + {"12foo%20-bar:00001233", AUTH}, + {"192.168.0.5", AUTH}, + {"[::ffff:192.168.0.1.]", AUTH}, + {"http:", AUTH}, + {"bife l", AUTH}, + {"/", PATH}, + {"/foo/bar/baz", PATH}, + {"/%20-2DVdkj.cie/foe_.iif/", PATH}, + {"http://", 0}, + {":443", 0}, + {":80/eddd", 0}, + {"google.com:-0", 0}, + {"google.com:65536", 0}, + {"http://google.com", 0}, + {"http://google.com:39", 0}, + {"//google.com/foo", 0}, + {".com/", 0}, + {"http://www.google.com/", 0}, + {"http://foo:439", 0}, + {"[::ffff:192.168", 0}, + {"]/", 0}, + {"//", 0}}; + if (quiche::test::GoogleUrlSupportsIdnaForTest()) { + input_headers.push_back({"GOO\u200b\u2060\ufeffgoo", AUTH}); + } + for (size_t i = 0; i < input_headers.size(); ++i) { + bool should_accept = (input_headers[i].second & SCHEME); + for (size_t j = 0; j < input_headers.size(); ++j) { + bool should_accept_2 = should_accept && (input_headers[j].second & AUTH); + for (size_t k = 0; k < input_headers.size(); ++k) { + // |should_accept_3| indicates whether or not GetPushPromiseUrl() is + // expected to accept this input combination. + bool should_accept_3 = + should_accept_2 && (input_headers[k].second & PATH); + + std::string url = SpdyServerPushUtils::GetPushPromiseUrl( + input_headers[i].first, input_headers[j].first, + input_headers[k].first); + + ::testing::AssertionResult result = ::testing::AssertionSuccess(); + if (url.empty() == should_accept_3) { + result = ::testing::AssertionFailure() + << "GetPushPromiseUrl() accepted/rejected the inputs when " + "it shouldn't have." + << std::endl + << " scheme: " << input_headers[i].first << std::endl + << " authority: " << input_headers[j].first << std::endl + << " path: " << input_headers[k].first << std::endl + << "Output: " << url << std::endl; + } + ASSERT_TRUE(result); + } + } + } + + // Test canonicalization of various valid inputs. + EXPECT_EQ("http://www.google.com/", SpdyServerPushUtils::GetPushPromiseUrl( + "http", "www.google.com", "/")); + EXPECT_EQ("https://www.goo-gle.com/fOOo/baRR", + SpdyServerPushUtils::GetPushPromiseUrl("hTtPs", "wWw.gOo-gLE.cOm", + "/fOOo/baRR")); + EXPECT_EQ("https://www.goo-gle.com:3278/pAth/To/reSOurce", + SpdyServerPushUtils::GetPushPromiseUrl( + "hTtPs", "Www.gOo-Gle.Com:000003278", "/pAth/To/reSOurce")); + EXPECT_EQ("https://foo%20bar/foo/bar/baz", + SpdyServerPushUtils::GetPushPromiseUrl("https", "foo bar", + "/foo/bar/baz")); + EXPECT_EQ("http://foo.com:70/e/", SpdyServerPushUtils::GetPushPromiseUrl( + "http", "foo.com:0000070", "/e/")); + EXPECT_EQ("http://192.168.0.1:70/e/", + SpdyServerPushUtils::GetPushPromiseUrl( + "http", "0300.0250.00.01:0070", "/e/")); + EXPECT_EQ("http://192.168.0.1/e/", SpdyServerPushUtils::GetPushPromiseUrl( + "http", "0xC0a80001", "/e/")); + EXPECT_EQ("http://[::c0a8:1]/", SpdyServerPushUtils::GetPushPromiseUrl( + "http", "[::192.168.0.1]", "/")); + EXPECT_EQ("https://[::ffff:c0a8:1]/", + SpdyServerPushUtils::GetPushPromiseUrl( + "https", "[::ffff:0xC0.0Xa8.0x0.0x1]", "/")); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/spdy_utils.cc b/quiche/quic/core/http/spdy_utils.cc new file mode 100644 index 000000000000..873d39c9ff6c --- /dev/null +++ b/quiche/quic/core/http/spdy_utils.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/spdy_utils.h" + +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/spdy_protocol.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +// static +bool SpdyUtils::ExtractContentLengthFromHeaders(int64_t* content_length, + Http2HeaderBlock* headers) { + auto it = headers->find("content-length"); + if (it == headers->end()) { + return false; + } else { + // Check whether multiple values are consistent. + absl::string_view content_length_header = it->second; + std::vector values = + absl::StrSplit(content_length_header, '\0'); + for (const absl::string_view& value : values) { + uint64_t new_value; + if (!absl::SimpleAtoi(value, &new_value) || + !quiche::QuicheTextUtils::IsAllDigits(value)) { + QUIC_DLOG(ERROR) + << "Content length was either unparseable or negative."; + return false; + } + if (*content_length < 0) { + *content_length = new_value; + continue; + } + if (new_value != static_cast(*content_length)) { + QUIC_DLOG(ERROR) + << "Parsed content length " << new_value << " is " + << "inconsistent with previously detected content length " + << *content_length; + return false; + } + } + return true; + } +} + +bool SpdyUtils::CopyAndValidateHeaders(const QuicHeaderList& header_list, + int64_t* content_length, + Http2HeaderBlock* headers) { + for (const auto& p : header_list) { + const std::string& name = p.first; + if (name.empty()) { + QUIC_DLOG(ERROR) << "Header name must not be empty."; + return false; + } + + if (quiche::QuicheTextUtils::ContainsUpperCase(name)) { + QUIC_DLOG(ERROR) << "Malformed header: Header name " << name + << " contains upper-case characters."; + return false; + } + + headers->AppendValueOrAddHeader(name, p.second); + } + + if (headers->contains("content-length") && + !ExtractContentLengthFromHeaders(content_length, headers)) { + return false; + } + + QUIC_DVLOG(1) << "Successfully parsed headers: " << headers->DebugString(); + return true; +} + +bool SpdyUtils::CopyAndValidateTrailers(const QuicHeaderList& header_list, + bool expect_final_byte_offset, + size_t* final_byte_offset, + Http2HeaderBlock* trailers) { + bool found_final_byte_offset = false; + for (const auto& p : header_list) { + const std::string& name = p.first; + + // Pull out the final offset pseudo header which indicates the number of + // response body bytes expected. + if (expect_final_byte_offset && !found_final_byte_offset && + name == kFinalOffsetHeaderKey && + absl::SimpleAtoi(p.second, final_byte_offset)) { + found_final_byte_offset = true; + continue; + } + + if (name.empty() || name[0] == ':') { + QUIC_DLOG(ERROR) + << "Trailers must not be empty, and must not contain pseudo-" + << "headers. Found: '" << name << "'"; + return false; + } + + if (quiche::QuicheTextUtils::ContainsUpperCase(name)) { + QUIC_DLOG(ERROR) << "Malformed header: Header name " << name + << " contains upper-case characters."; + return false; + } + + trailers->AppendValueOrAddHeader(name, p.second); + } + + if (expect_final_byte_offset && !found_final_byte_offset) { + QUIC_DLOG(ERROR) << "Required key '" << kFinalOffsetHeaderKey + << "' not present"; + return false; + } + + // TODO(rjshade): Check for other forbidden keys, following the HTTP/2 spec. + + QUIC_DVLOG(1) << "Successfully parsed Trailers: " << trailers->DebugString(); + return true; +} + +// static +// TODO(danzh): Move it to quic/tools/ and switch to use GURL. +bool SpdyUtils::PopulateHeaderBlockFromUrl(const std::string url, + Http2HeaderBlock* headers) { + (*headers)[":method"] = "GET"; + size_t pos = url.find("://"); + if (pos == std::string::npos) { + return false; + } + (*headers)[":scheme"] = url.substr(0, pos); + size_t start = pos + 3; + pos = url.find('/', start); + if (pos == std::string::npos) { + (*headers)[":authority"] = url.substr(start); + (*headers)[":path"] = "/"; + return true; + } + (*headers)[":authority"] = url.substr(start, pos - start); + (*headers)[":path"] = url.substr(pos); + return true; +} + +// static +ParsedQuicVersion SpdyUtils::ExtractQuicVersionFromAltSvcEntry( + const spdy::SpdyAltSvcWireFormat::AlternativeService& + alternative_service_entry, + const ParsedQuicVersionVector& supported_versions) { + for (const ParsedQuicVersion& version : supported_versions) { + if (version.AlpnDeferToRFCv1()) { + // Versions with share an ALPN with v1 are currently unable to be + // advertised with Alt-Svc. + continue; + } + if (AlpnForVersion(version) == alternative_service_entry.protocol_id) { + return version; + } + } + + return ParsedQuicVersion::Unsupported(); +} + +} // namespace quic diff --git a/quiche/quic/core/http/spdy_utils.h b/quiche/quic/core/http/spdy_utils.h new file mode 100644 index 000000000000..05a237d3ec3f --- /dev/null +++ b/quiche/quic/core/http/spdy_utils.h @@ -0,0 +1,68 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_SPDY_UTILS_H_ +#define QUICHE_QUIC_CORE_HTTP_SPDY_UTILS_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE SpdyUtils { + public: + SpdyUtils() = delete; + + // Populate |content length| with the value of the content-length header. + // Returns true on success, false if parsing fails or content-length header is + // missing. + static bool ExtractContentLengthFromHeaders(int64_t* content_length, + spdy::Http2HeaderBlock* headers); + + // Copies a list of headers to a Http2HeaderBlock. + static bool CopyAndValidateHeaders(const QuicHeaderList& header_list, + int64_t* content_length, + spdy::Http2HeaderBlock* headers); + + // Copies a list of headers to a Http2HeaderBlock. + // If |expect_final_byte_offset| is true, requires exactly one header field + // with key kFinalOffsetHeaderKey and an integer value. + // If |expect_final_byte_offset| is false, no kFinalOffsetHeaderKey may be + // present. + // Returns true if parsing is successful. Returns false if the presence of + // kFinalOffsetHeaderKey does not match the value of + // |expect_final_byte_offset|, the kFinalOffsetHeaderKey value cannot be + // parsed, any other pseudo-header is present, an empty header key is present, + // or a header key contains an uppercase character. + static bool CopyAndValidateTrailers(const QuicHeaderList& header_list, + bool expect_final_byte_offset, + size_t* final_byte_offset, + spdy::Http2HeaderBlock* trailers); + + // Populates the fields of |headers| to make a GET request of |url|, + // which must be fully-qualified. + static bool PopulateHeaderBlockFromUrl(const std::string url, + spdy::Http2HeaderBlock* headers); + + // Returns the advertised QUIC version from the specified alternative service + // advertisement, or ParsedQuicVersion::Unsupported() if no supported version + // is advertised. + static ParsedQuicVersion ExtractQuicVersionFromAltSvcEntry( + const spdy::SpdyAltSvcWireFormat::AlternativeService& + alternative_service_entry, + const ParsedQuicVersionVector& supported_versions); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_SPDY_UTILS_H_ diff --git a/quiche/quic/core/http/spdy_utils_test.cc b/quiche/quic/core/http/spdy_utils_test.cc new file mode 100644 index 000000000000..43177ba05f63 --- /dev/null +++ b/quiche/quic/core/http/spdy_utils_test.cc @@ -0,0 +1,410 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/spdy_utils.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_test.h" + +using spdy::Http2HeaderBlock; +using testing::Pair; +using testing::UnorderedElementsAre; + +namespace quic { +namespace test { +namespace { + +const bool kExpectFinalByteOffset = true; +const bool kDoNotExpectFinalByteOffset = false; + +static std::unique_ptr FromList( + const QuicHeaderList::ListType& src) { + std::unique_ptr headers(new QuicHeaderList); + headers->OnHeaderBlockStart(); + for (const auto& p : src) { + headers->OnHeader(p.first, p.second); + } + headers->OnHeaderBlockEnd(0, 0); + return headers; +} + +} // anonymous namespace + +using CopyAndValidateHeaders = QuicTest; + +TEST_F(CopyAndValidateHeaders, NormalUsage) { + auto headers = FromList({// All cookie crumbs are joined. + {"cookie", " part 1"}, + {"cookie", "part 2 "}, + {"cookie", "part3"}, + + // Already-delimited headers are passed through. + {"passed-through", std::string("foo\0baz", 7)}, + + // Other headers are joined on \0. + {"joined", "value 1"}, + {"joined", "value 2"}, + + // Empty headers remain empty. + {"empty", ""}, + + // Joined empty headers work as expected. + {"empty-joined", ""}, + {"empty-joined", "foo"}, + {"empty-joined", ""}, + {"empty-joined", ""}, + + // Non-continguous cookie crumb. + {"cookie", " fin!"}}); + + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, + UnorderedElementsAre( + Pair("cookie", " part 1; part 2 ; part3; fin!"), + Pair("passed-through", absl::string_view("foo\0baz", 7)), + Pair("joined", absl::string_view("value 1\0value 2", 15)), + Pair("empty", ""), + Pair("empty-joined", absl::string_view("\0foo\0\0", 6)))); + EXPECT_EQ(-1, content_length); +} + +TEST_F(CopyAndValidateHeaders, EmptyName) { + auto headers = FromList({{"foo", "foovalue"}, {"", "barvalue"}, {"baz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_FALSE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); +} + +TEST_F(CopyAndValidateHeaders, UpperCaseName) { + auto headers = + FromList({{"foo", "foovalue"}, {"bar", "barvalue"}, {"bAz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_FALSE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); +} + +TEST_F(CopyAndValidateHeaders, MultipleContentLengths) { + auto headers = FromList({{"content-length", "9"}, + {"foo", "foovalue"}, + {"content-length", "9"}, + {"bar", "barvalue"}, + {"baz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, UnorderedElementsAre( + Pair("foo", "foovalue"), Pair("bar", "barvalue"), + Pair("content-length", absl::string_view("9\09", 3)), + Pair("baz", ""))); + EXPECT_EQ(9, content_length); +} + +TEST_F(CopyAndValidateHeaders, InconsistentContentLengths) { + auto headers = FromList({{"content-length", "9"}, + {"foo", "foovalue"}, + {"content-length", "8"}, + {"bar", "barvalue"}, + {"baz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_FALSE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); +} + +TEST_F(CopyAndValidateHeaders, LargeContentLength) { + auto headers = FromList({{"content-length", "9000000000"}, + {"foo", "foovalue"}, + {"bar", "barvalue"}, + {"baz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, + UnorderedElementsAre( + Pair("foo", "foovalue"), Pair("bar", "barvalue"), + Pair("content-length", absl::string_view("9000000000")), + Pair("baz", ""))); + EXPECT_EQ(9000000000, content_length); +} + +TEST_F(CopyAndValidateHeaders, NonDigitContentLength) { + // Section 3.3.2 of RFC 7230 defines content-length as being only digits. + // Number parsers might accept symbols like a leading plus; test that this + // fails to parse. + auto headers = FromList({{"content-length", "+123"}, + {"foo", "foovalue"}, + {"bar", "barvalue"}, + {"baz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + EXPECT_FALSE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); +} + +TEST_F(CopyAndValidateHeaders, MultipleValues) { + auto headers = FromList({{"foo", "foovalue"}, + {"bar", "barvalue"}, + {"baz", ""}, + {"foo", "boo"}, + {"baz", "buzz"}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, UnorderedElementsAre( + Pair("foo", absl::string_view("foovalue\0boo", 12)), + Pair("bar", "barvalue"), + Pair("baz", absl::string_view("\0buzz", 5)))); + EXPECT_EQ(-1, content_length); +} + +TEST_F(CopyAndValidateHeaders, MoreThanTwoValues) { + auto headers = FromList({{"set-cookie", "value1"}, + {"set-cookie", "value2"}, + {"set-cookie", "value3"}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, UnorderedElementsAre(Pair( + "set-cookie", + absl::string_view("value1\0value2\0value3", 20)))); + EXPECT_EQ(-1, content_length); +} + +TEST_F(CopyAndValidateHeaders, Cookie) { + auto headers = FromList({{"foo", "foovalue"}, + {"bar", "barvalue"}, + {"cookie", "value1"}, + {"baz", ""}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, UnorderedElementsAre( + Pair("foo", "foovalue"), Pair("bar", "barvalue"), + Pair("cookie", "value1"), Pair("baz", ""))); + EXPECT_EQ(-1, content_length); +} + +TEST_F(CopyAndValidateHeaders, MultipleCookies) { + auto headers = FromList({{"foo", "foovalue"}, + {"bar", "barvalue"}, + {"cookie", "value1"}, + {"baz", ""}, + {"cookie", "value2"}}); + int64_t content_length = -1; + Http2HeaderBlock block; + ASSERT_TRUE( + SpdyUtils::CopyAndValidateHeaders(*headers, &content_length, &block)); + EXPECT_THAT(block, UnorderedElementsAre( + Pair("foo", "foovalue"), Pair("bar", "barvalue"), + Pair("cookie", "value1; value2"), Pair("baz", ""))); + EXPECT_EQ(-1, content_length); +} + +using CopyAndValidateTrailers = QuicTest; + +TEST_F(CopyAndValidateTrailers, SimplestValidList) { + // Verify that the simplest trailers are valid: just a final byte offset that + // gets parsed successfully. + auto trailers = FromList({{kFinalOffsetHeaderKey, "1234"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_TRUE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kExpectFinalByteOffset, &final_byte_offset, &block)); + EXPECT_EQ(1234u, final_byte_offset); +} + +TEST_F(CopyAndValidateTrailers, EmptyTrailerListWithFinalByteOffsetExpected) { + // An empty trailer list will fail as expected key kFinalOffsetHeaderKey is + // not present. + QuicHeaderList trailers; + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_FALSE(SpdyUtils::CopyAndValidateTrailers( + trailers, kExpectFinalByteOffset, &final_byte_offset, &block)); +} + +TEST_F(CopyAndValidateTrailers, + EmptyTrailerListWithFinalByteOffsetNotExpected) { + // An empty trailer list will pass successfully if kFinalOffsetHeaderKey is + // not expected. + QuicHeaderList trailers; + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_TRUE(SpdyUtils::CopyAndValidateTrailers( + trailers, kDoNotExpectFinalByteOffset, &final_byte_offset, &block)); + EXPECT_TRUE(block.empty()); +} + +TEST_F(CopyAndValidateTrailers, FinalByteOffsetExpectedButNotPresent) { + // Validation fails if expected kFinalOffsetHeaderKey is not present, even if + // the rest of the header block is valid. + auto trailers = FromList({{"key", "value"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_FALSE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kExpectFinalByteOffset, &final_byte_offset, &block)); +} + +TEST_F(CopyAndValidateTrailers, FinalByteOffsetNotExpectedButPresent) { + // Validation fails if kFinalOffsetHeaderKey is present but should not be, + // even if the rest of the header block is valid. + auto trailers = FromList({{"key", "value"}, {kFinalOffsetHeaderKey, "1234"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_FALSE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kDoNotExpectFinalByteOffset, &final_byte_offset, &block)); +} + +TEST_F(CopyAndValidateTrailers, FinalByteOffsetNotExpectedAndNotPresent) { + // Validation succeeds if kFinalOffsetHeaderKey is not expected and not + // present. + auto trailers = FromList({{"key", "value"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_TRUE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kDoNotExpectFinalByteOffset, &final_byte_offset, &block)); + EXPECT_THAT(block, UnorderedElementsAre(Pair("key", "value"))); +} + +TEST_F(CopyAndValidateTrailers, EmptyName) { + // Trailer validation will fail with an empty header key, in an otherwise + // valid block of trailers. + auto trailers = FromList({{"", "value"}, {kFinalOffsetHeaderKey, "1234"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_FALSE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kExpectFinalByteOffset, &final_byte_offset, &block)); +} + +TEST_F(CopyAndValidateTrailers, PseudoHeaderInTrailers) { + // Pseudo headers are illegal in trailers. + auto trailers = + FromList({{":pseudo_key", "value"}, {kFinalOffsetHeaderKey, "1234"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_FALSE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kExpectFinalByteOffset, &final_byte_offset, &block)); +} + +TEST_F(CopyAndValidateTrailers, DuplicateTrailers) { + // Duplicate trailers are allowed, and their values are concatenated into a + // single string delimted with '\0'. Some of the duplicate headers + // deliberately have an empty value. + auto trailers = FromList({{"key", "value0"}, + {"key", "value1"}, + {"key", ""}, + {"key", ""}, + {"key", "value2"}, + {"key", ""}, + {kFinalOffsetHeaderKey, "1234"}, + {"other_key", "value"}, + {"key", "non_contiguous_duplicate"}}); + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_TRUE(SpdyUtils::CopyAndValidateTrailers( + *trailers, kExpectFinalByteOffset, &final_byte_offset, &block)); + EXPECT_THAT( + block, + UnorderedElementsAre( + Pair("key", + absl::string_view( + "value0\0value1\0\0\0value2\0\0non_contiguous_duplicate", + 48)), + Pair("other_key", "value"))); +} + +TEST_F(CopyAndValidateTrailers, DuplicateCookies) { + // Duplicate cookie headers in trailers should be concatenated into a single + // "; " delimted string. + auto headers = FromList({{"cookie", " part 1"}, + {"cookie", "part 2 "}, + {"cookie", "part3"}, + {"key", "value"}, + {kFinalOffsetHeaderKey, "1234"}, + {"cookie", " non_contiguous_cookie!"}}); + + size_t final_byte_offset = 0; + Http2HeaderBlock block; + EXPECT_TRUE(SpdyUtils::CopyAndValidateTrailers( + *headers, kExpectFinalByteOffset, &final_byte_offset, &block)); + EXPECT_THAT( + block, + UnorderedElementsAre( + Pair("cookie", " part 1; part 2 ; part3; non_contiguous_cookie!"), + Pair("key", "value"))); +} + +using PopulateHeaderBlockFromUrl = QuicTest; + +TEST_F(PopulateHeaderBlockFromUrl, NormalUsage) { + std::string url = "https://www.google.com/index.html"; + Http2HeaderBlock headers; + EXPECT_TRUE(SpdyUtils::PopulateHeaderBlockFromUrl(url, &headers)); + EXPECT_EQ("https", headers[":scheme"].as_string()); + EXPECT_EQ("www.google.com", headers[":authority"].as_string()); + EXPECT_EQ("/index.html", headers[":path"].as_string()); +} + +TEST_F(PopulateHeaderBlockFromUrl, UrlWithNoPath) { + std::string url = "https://www.google.com"; + Http2HeaderBlock headers; + EXPECT_TRUE(SpdyUtils::PopulateHeaderBlockFromUrl(url, &headers)); + EXPECT_EQ("https", headers[":scheme"].as_string()); + EXPECT_EQ("www.google.com", headers[":authority"].as_string()); + EXPECT_EQ("/", headers[":path"].as_string()); +} + +TEST_F(PopulateHeaderBlockFromUrl, Failure) { + Http2HeaderBlock headers; + EXPECT_FALSE(SpdyUtils::PopulateHeaderBlockFromUrl("/", &headers)); + EXPECT_FALSE(SpdyUtils::PopulateHeaderBlockFromUrl("/index.html", &headers)); + EXPECT_FALSE( + SpdyUtils::PopulateHeaderBlockFromUrl("www.google.com/", &headers)); +} + +using ExtractQuicVersionFromAltSvcEntry = QuicTest; + +TEST_F(ExtractQuicVersionFromAltSvcEntry, SupportedVersion) { + ParsedQuicVersionVector supported_versions = AllSupportedVersions(); + spdy::SpdyAltSvcWireFormat::AlternativeService entry; + for (const ParsedQuicVersion& version : supported_versions) { + entry.protocol_id = AlpnForVersion(version); + ParsedQuicVersion expected_version = version; + // Versions with share an ALPN with v1 are currently unable to be + // advertised with Alt-Svc. + if (entry.protocol_id == AlpnForVersion(ParsedQuicVersion::RFCv1()) && + version != ParsedQuicVersion::RFCv1()) { + expected_version = ParsedQuicVersion::RFCv1(); + } + EXPECT_EQ(expected_version, SpdyUtils::ExtractQuicVersionFromAltSvcEntry( + entry, supported_versions)) + << "version: " << version; + } +} + +TEST_F(ExtractQuicVersionFromAltSvcEntry, UnsupportedVersion) { + spdy::SpdyAltSvcWireFormat::AlternativeService entry; + entry.protocol_id = "quic"; + EXPECT_EQ(ParsedQuicVersion::Unsupported(), + SpdyUtils::ExtractQuicVersionFromAltSvcEntry( + entry, AllSupportedVersions())); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/http/web_transport_http3.cc b/quiche/quic/core/http/web_transport_http3.cc new file mode 100644 index 000000000000..7a5fd07f1051 --- /dev/null +++ b/quiche/quic/core/http/web_transport_http3.cc @@ -0,0 +1,474 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/web_transport_http3.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/capsule.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/web_transport/web_transport.h" + +#define ENDPOINT \ + (session_->perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") + +namespace quic { + +namespace { +class QUIC_NO_EXPORT NoopWebTransportVisitor : public WebTransportVisitor { + void OnSessionReady(const spdy::Http2HeaderBlock&) override {} + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + void OnIncomingBidirectionalStreamAvailable() override {} + void OnIncomingUnidirectionalStreamAvailable() override {} + void OnDatagramReceived(absl::string_view /*datagram*/) override {} + void OnCanCreateNewOutgoingBidirectionalStream() override {} + void OnCanCreateNewOutgoingUnidirectionalStream() override {} +}; +} // namespace + +WebTransportHttp3::WebTransportHttp3(QuicSpdySession* session, + QuicSpdyStream* connect_stream, + WebTransportSessionId id) + : session_(session), + connect_stream_(connect_stream), + id_(id), + visitor_(std::make_unique()) { + QUICHE_DCHECK(session_->SupportsWebTransport()); + QUICHE_DCHECK(IsValidWebTransportSessionId(id, session_->version())); + QUICHE_DCHECK_EQ(connect_stream_->id(), id); + connect_stream_->RegisterHttp3DatagramVisitor(this); +} + +void WebTransportHttp3::AssociateStream(QuicStreamId stream_id) { + streams_.insert(stream_id); + + ParsedQuicVersion version = session_->version(); + if (QuicUtils::IsOutgoingStreamId(version, stream_id, + session_->perspective())) { + return; + } + if (QuicUtils::IsBidirectionalStreamId(stream_id, version)) { + incoming_bidirectional_streams_.push_back(stream_id); + visitor_->OnIncomingBidirectionalStreamAvailable(); + } else { + incoming_unidirectional_streams_.push_back(stream_id); + visitor_->OnIncomingUnidirectionalStreamAvailable(); + } +} + +void WebTransportHttp3::OnConnectStreamClosing() { + // Copy the stream list before iterating over it, as calls to ResetStream() + // can potentially mutate the |session_| list. + std::vector streams(streams_.begin(), streams_.end()); + streams_.clear(); + for (QuicStreamId id : streams) { + session_->ResetStream(id, QUIC_STREAM_WEBTRANSPORT_SESSION_GONE); + } + connect_stream_->UnregisterHttp3DatagramVisitor(); + + MaybeNotifyClose(); +} + +void WebTransportHttp3::CloseSession(WebTransportSessionError error_code, + absl::string_view error_message) { + if (close_sent_) { + QUIC_BUG(WebTransportHttp3 close sent twice) + << "Calling WebTransportHttp3::CloseSession() more than once is not " + "allowed."; + return; + } + close_sent_ = true; + + // There can be a race between us trying to send our close and peer sending + // one. If we received a close, however, we cannot send ours since we already + // closed the stream in response. + if (close_received_) { + QUIC_DLOG(INFO) << "Not sending CLOSE_WEBTRANSPORT_SESSION as we've " + "already sent one from peer."; + return; + } + + error_code_ = error_code; + error_message_ = std::string(error_message); + QuicConnection::ScopedPacketFlusher flusher( + connect_stream_->spdy_session()->connection()); + connect_stream_->WriteCapsule( + quiche::Capsule::CloseWebTransportSession(error_code, error_message), + /*fin=*/true); +} + +void WebTransportHttp3::OnCloseReceived(WebTransportSessionError error_code, + absl::string_view error_message) { + if (close_received_) { + QUIC_BUG(WebTransportHttp3 notified of close received twice) + << "WebTransportHttp3::OnCloseReceived() may be only called once."; + } + close_received_ = true; + + // If the peer has sent a close after we sent our own, keep the local error. + if (close_sent_) { + QUIC_DLOG(INFO) << "Ignoring received CLOSE_WEBTRANSPORT_SESSION as we've " + "already sent our own."; + return; + } + + error_code_ = error_code; + error_message_ = std::string(error_message); + connect_stream_->WriteOrBufferBody("", /*fin=*/true); + MaybeNotifyClose(); +} + +void WebTransportHttp3::OnConnectStreamFinReceived() { + // If we already received a CLOSE_WEBTRANSPORT_SESSION capsule, we don't need + // to do anything about receiving a FIN, since we already sent one in + // response. + if (close_received_) { + return; + } + close_received_ = true; + if (close_sent_) { + QUIC_DLOG(INFO) << "Ignoring received FIN as we've already sent our close."; + return; + } + + connect_stream_->WriteOrBufferBody("", /*fin=*/true); + MaybeNotifyClose(); +} + +void WebTransportHttp3::CloseSessionWithFinOnlyForTests() { + QUICHE_DCHECK(!close_sent_); + close_sent_ = true; + if (close_received_) { + return; + } + + connect_stream_->WriteOrBufferBody("", /*fin=*/true); +} + +void WebTransportHttp3::HeadersReceived(const spdy::Http2HeaderBlock& headers) { + if (session_->perspective() == Perspective::IS_CLIENT) { + int status_code; + if (!QuicSpdyStream::ParseHeaderStatusCode(headers, &status_code)) { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server without " + "a valid status code, rejecting."; + rejection_reason_ = WebTransportHttp3RejectionReason::kNoStatusCode; + return; + } + bool valid_status = status_code >= 200 && status_code <= 299; + if (!valid_status) { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server with " + "status code " + << status_code << ", rejecting."; + rejection_reason_ = WebTransportHttp3RejectionReason::kWrongStatusCode; + return; + } + bool should_validate_version = + session_->ShouldValidateWebTransportVersion(); + if (should_validate_version) { + auto draft_version_it = headers.find("sec-webtransport-http3-draft"); + if (draft_version_it == headers.end()) { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server without " + "a draft version, rejecting."; + rejection_reason_ = + WebTransportHttp3RejectionReason::kMissingDraftVersion; + return; + } + if (draft_version_it->second != "draft02") { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server with " + "an unknown draft version (" + << draft_version_it->second << "), rejecting."; + rejection_reason_ = + WebTransportHttp3RejectionReason::kUnsupportedDraftVersion; + return; + } + } + } + + QUIC_DVLOG(1) << ENDPOINT << "WebTransport session " << id_ << " ready."; + ready_ = true; + visitor_->OnSessionReady(headers); + session_->ProcessBufferedWebTransportStreamsForSession(this); +} + +WebTransportStream* WebTransportHttp3::AcceptIncomingBidirectionalStream() { + while (!incoming_bidirectional_streams_.empty()) { + QuicStreamId id = incoming_bidirectional_streams_.front(); + incoming_bidirectional_streams_.pop_front(); + QuicSpdyStream* stream = session_->GetOrCreateSpdyDataStream(id); + if (stream == nullptr) { + // Skip the streams that were reset in between the time they were + // receieved and the time the client has polled for them. + continue; + } + return stream->web_transport_stream(); + } + return nullptr; +} + +WebTransportStream* WebTransportHttp3::AcceptIncomingUnidirectionalStream() { + while (!incoming_unidirectional_streams_.empty()) { + QuicStreamId id = incoming_unidirectional_streams_.front(); + incoming_unidirectional_streams_.pop_front(); + QuicStream* stream = session_->GetOrCreateStream(id); + if (stream == nullptr) { + // Skip the streams that were reset in between the time they were + // receieved and the time the client has polled for them. + continue; + } + return static_cast(stream) + ->interface(); + } + return nullptr; +} + +bool WebTransportHttp3::CanOpenNextOutgoingBidirectionalStream() { + return session_->CanOpenOutgoingBidirectionalWebTransportStream(id_); +} +bool WebTransportHttp3::CanOpenNextOutgoingUnidirectionalStream() { + return session_->CanOpenOutgoingUnidirectionalWebTransportStream(id_); +} +WebTransportStream* WebTransportHttp3::OpenOutgoingBidirectionalStream() { + QuicSpdyStream* stream = + session_->CreateOutgoingBidirectionalWebTransportStream(this); + if (stream == nullptr) { + // If stream cannot be created due to flow control or other errors, return + // nullptr. + return nullptr; + } + return stream->web_transport_stream(); +} + +WebTransportStream* WebTransportHttp3::OpenOutgoingUnidirectionalStream() { + WebTransportHttp3UnidirectionalStream* stream = + session_->CreateOutgoingUnidirectionalWebTransportStream(this); + if (stream == nullptr) { + // If stream cannot be created due to flow control, return nullptr. + return nullptr; + } + return stream->interface(); +} + +webtransport::Stream* WebTransportHttp3::GetStreamById( + webtransport::StreamId id) { + if (!streams_.contains(id)) { + return nullptr; + } + QuicStream* stream = session_->GetActiveStream(id); + const bool bidi = QuicUtils::IsBidirectionalStreamId( + id, ParsedQuicVersion::RFCv1()); // Assume IETF QUIC for WebTransport + if (bidi) { + return static_cast(stream)->web_transport_stream(); + } else { + return static_cast(stream) + ->interface(); + } +} + +webtransport::DatagramStatus WebTransportHttp3::SendOrQueueDatagram( + absl::string_view datagram) { + return MessageStatusToWebTransportStatus( + connect_stream_->SendHttp3Datagram(datagram)); +} + +QuicByteCount WebTransportHttp3::GetMaxDatagramSize() const { + return connect_stream_->GetMaxDatagramSize(); +} + +void WebTransportHttp3::SetDatagramMaxTimeInQueue( + absl::Duration max_time_in_queue) { + connect_stream_->SetMaxDatagramTimeInQueue(QuicTimeDelta(max_time_in_queue)); +} + +void WebTransportHttp3::OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) { + QUICHE_DCHECK_EQ(stream_id, connect_stream_->id()); + visitor_->OnDatagramReceived(payload); +} + +void WebTransportHttp3::MaybeNotifyClose() { + if (close_notified_) { + return; + } + close_notified_ = true; + visitor_->OnSessionClosed(error_code_, error_message_); +} + +WebTransportHttp3UnidirectionalStream::WebTransportHttp3UnidirectionalStream( + PendingStream* pending, QuicSpdySession* session) + : QuicStream(pending, session, /*is_static=*/false), + session_(session), + adapter_(session, this, sequencer()), + needs_to_send_preamble_(false) { + sequencer()->set_level_triggered(true); +} + +WebTransportHttp3UnidirectionalStream::WebTransportHttp3UnidirectionalStream( + QuicStreamId id, QuicSpdySession* session, WebTransportSessionId session_id) + : QuicStream(id, session, /*is_static=*/false, WRITE_UNIDIRECTIONAL), + session_(session), + adapter_(session, this, sequencer()), + session_id_(session_id), + needs_to_send_preamble_(true) {} + +void WebTransportHttp3UnidirectionalStream::WritePreamble() { + if (!needs_to_send_preamble_ || !session_id_.has_value()) { + QUIC_BUG(WebTransportHttp3UnidirectionalStream duplicate preamble) + << ENDPOINT << "Sending preamble on stream ID " << id() + << " at the wrong time."; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Attempting to send a WebTransport unidirectional " + "stream preamble at the wrong time."); + return; + } + + QuicConnection::ScopedPacketFlusher flusher(session_->connection()); + char buffer[sizeof(uint64_t) * 2]; // varint62, varint62 + QuicDataWriter writer(sizeof(buffer), buffer); + bool success = true; + success = success && writer.WriteVarInt62(kWebTransportUnidirectionalStream); + success = success && writer.WriteVarInt62(*session_id_); + QUICHE_DCHECK(success); + WriteOrBufferData(absl::string_view(buffer, writer.length()), /*fin=*/false, + /*ack_listener=*/nullptr); + QUIC_DVLOG(1) << ENDPOINT << "Sent stream type and session ID (" + << *session_id_ << ") on WebTransport stream " << id(); + needs_to_send_preamble_ = false; +} + +bool WebTransportHttp3UnidirectionalStream::ReadSessionId() { + iovec iov; + if (!sequencer()->GetReadableRegion(&iov)) { + return false; + } + QuicDataReader reader(static_cast(iov.iov_base), iov.iov_len); + WebTransportSessionId session_id; + uint8_t session_id_length = reader.PeekVarInt62Length(); + if (!reader.ReadVarInt62(&session_id)) { + // If all of the data has been received, and we still cannot associate the + // stream with a session, consume all of the data so that the stream can + // be closed. + if (sequencer()->IsAllDataAvailable()) { + QUIC_DLOG(WARNING) + << ENDPOINT << "Failed to associate WebTransport stream " << id() + << " with a session because the stream ended prematurely."; + sequencer()->MarkConsumed(sequencer()->NumBytesBuffered()); + } + return false; + } + sequencer()->MarkConsumed(session_id_length); + session_id_ = session_id; + session_->AssociateIncomingWebTransportStreamWithSession(session_id, id()); + return true; +} + +void WebTransportHttp3UnidirectionalStream::OnDataAvailable() { + if (!session_id_.has_value()) { + if (!ReadSessionId()) { + return; + } + } + + adapter_.OnDataAvailable(); +} + +void WebTransportHttp3UnidirectionalStream::OnCanWriteNewData() { + adapter_.OnCanWriteNewData(); +} + +void WebTransportHttp3UnidirectionalStream::OnClose() { + QuicStream::OnClose(); + + if (!session_id_.has_value()) { + return; + } + WebTransportHttp3* session = session_->GetWebTransportSession(*session_id_); + if (session == nullptr) { + QUIC_DLOG(WARNING) << ENDPOINT << "WebTransport stream " << id() + << " attempted to notify parent session " << *session_id_ + << ", but the session could not be found."; + return; + } + session->OnStreamClosed(id()); +} + +void WebTransportHttp3UnidirectionalStream::OnStreamReset( + const QuicRstStreamFrame& frame) { + if (adapter_.visitor() != nullptr) { + adapter_.visitor()->OnResetStreamReceived( + Http3ErrorToWebTransportOrDefault(frame.ietf_error_code)); + } + QuicStream::OnStreamReset(frame); +} +bool WebTransportHttp3UnidirectionalStream::OnStopSending( + QuicResetStreamError error) { + if (adapter_.visitor() != nullptr) { + adapter_.visitor()->OnStopSendingReceived( + Http3ErrorToWebTransportOrDefault(error.ietf_application_code())); + } + return QuicStream::OnStopSending(error); +} +void WebTransportHttp3UnidirectionalStream::OnWriteSideInDataRecvdState() { + if (adapter_.visitor() != nullptr) { + adapter_.visitor()->OnWriteSideInDataRecvdState(); + } + + QuicStream::OnWriteSideInDataRecvdState(); +} + +namespace { +constexpr uint64_t kWebTransportMappedErrorCodeFirst = 0x52e4a40fa8db; +constexpr uint64_t kWebTransportMappedErrorCodeLast = 0x52e4a40fa9e2; +constexpr WebTransportStreamError kDefaultWebTransportError = 0; +} // namespace + +absl::optional Http3ErrorToWebTransport( + uint64_t http3_error_code) { + // Ensure the code is within the valid range. + if (http3_error_code < kWebTransportMappedErrorCodeFirst || + http3_error_code > kWebTransportMappedErrorCodeLast) { + return absl::nullopt; + } + // Exclude GREASE codepoints. + if ((http3_error_code - 0x21) % 0x1f == 0) { + return absl::nullopt; + } + + uint64_t shifted = http3_error_code - kWebTransportMappedErrorCodeFirst; + uint64_t result = shifted - shifted / 0x1f; + QUICHE_DCHECK_LE(result, std::numeric_limits::max()); + return static_cast(result); +} + +WebTransportStreamError Http3ErrorToWebTransportOrDefault( + uint64_t http3_error_code) { + absl::optional result = + Http3ErrorToWebTransport(http3_error_code); + return result.has_value() ? *result : kDefaultWebTransportError; +} + +uint64_t WebTransportErrorToHttp3( + WebTransportStreamError webtransport_error_code) { + return kWebTransportMappedErrorCodeFirst + webtransport_error_code + + webtransport_error_code / 0x1e; +} + +} // namespace quic diff --git a/quiche/quic/core/http/web_transport_http3.h b/quiche/quic/core/http/web_transport_http3.h new file mode 100644 index 000000000000..8a6f35a3e863 --- /dev/null +++ b/quiche/quic/core/http/web_transport_http3.h @@ -0,0 +1,182 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_HTTP3_H_ +#define QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_HTTP3_H_ + +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_set.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/http/web_transport_stream_adapter.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/web_transport/web_transport.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +class QuicSpdySession; +class QuicSpdyStream; + +enum class WebTransportHttp3RejectionReason { + kNone, + kNoStatusCode, + kWrongStatusCode, + kMissingDraftVersion, + kUnsupportedDraftVersion, +}; + +// A session of WebTransport over HTTP/3. The session is owned by +// QuicSpdyStream object for the CONNECT stream that established it. +// +// WebTransport over HTTP/3 specification: +// +class QUIC_EXPORT_PRIVATE WebTransportHttp3 + : public WebTransportSession, + public QuicSpdyStream::Http3DatagramVisitor { + public: + WebTransportHttp3(QuicSpdySession* session, QuicSpdyStream* connect_stream, + WebTransportSessionId id); + + void HeadersReceived(const spdy::Http2HeaderBlock& headers); + void SetVisitor(std::unique_ptr visitor) { + visitor_ = std::move(visitor); + } + + WebTransportSessionId id() { return id_; } + bool ready() { return ready_; } + + void AssociateStream(QuicStreamId stream_id); + void OnStreamClosed(QuicStreamId stream_id) { streams_.erase(stream_id); } + void OnConnectStreamClosing(); + + size_t NumberOfAssociatedStreams() { return streams_.size(); } + + void CloseSession(WebTransportSessionError error_code, + absl::string_view error_message) override; + void OnCloseReceived(WebTransportSessionError error_code, + absl::string_view error_message); + void OnConnectStreamFinReceived(); + + // It is legal for WebTransport to be closed without a + // CLOSE_WEBTRANSPORT_SESSION capsule. We always send a capsule, but we still + // need to ensure we handle this case correctly. + void CloseSessionWithFinOnlyForTests(); + + // Return the earliest incoming stream that has been received by the session + // but has not been accepted. Returns nullptr if there are no incoming + // streams. + WebTransportStream* AcceptIncomingBidirectionalStream() override; + WebTransportStream* AcceptIncomingUnidirectionalStream() override; + + bool CanOpenNextOutgoingBidirectionalStream() override; + bool CanOpenNextOutgoingUnidirectionalStream() override; + WebTransportStream* OpenOutgoingBidirectionalStream() override; + WebTransportStream* OpenOutgoingUnidirectionalStream() override; + + webtransport::Stream* GetStreamById(webtransport::StreamId id) override; + + webtransport::DatagramStatus SendOrQueueDatagram( + absl::string_view datagram) override; + QuicByteCount GetMaxDatagramSize() const override; + void SetDatagramMaxTimeInQueue(absl::Duration max_time_in_queue) override; + + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + void OnUnknownCapsule(QuicStreamId /*stream_id*/, + const quiche::UnknownCapsule& /*capsule*/) override {} + + bool close_received() const { return close_received_; } + WebTransportHttp3RejectionReason rejection_reason() const { + return rejection_reason_; + } + + private: + // Notifies the visitor that the connection has been closed. Ensures that the + // visitor is only ever called once. + void MaybeNotifyClose(); + + QuicSpdySession* const session_; // Unowned. + QuicSpdyStream* const connect_stream_; // Unowned. + const WebTransportSessionId id_; + // |ready_| is set to true when the peer has seen both sets of headers. + bool ready_ = false; + std::unique_ptr visitor_; + absl::flat_hash_set streams_; + quiche::QuicheCircularDeque incoming_bidirectional_streams_; + quiche::QuicheCircularDeque incoming_unidirectional_streams_; + + bool close_sent_ = false; + bool close_received_ = false; + bool close_notified_ = false; + + WebTransportHttp3RejectionReason rejection_reason_ = + WebTransportHttp3RejectionReason::kNone; + // Those are set to default values, which are used if the session is not + // closed cleanly using an appropriate capsule. + WebTransportSessionError error_code_ = 0; + std::string error_message_ = ""; +}; + +class QUIC_EXPORT_PRIVATE WebTransportHttp3UnidirectionalStream + : public QuicStream { + public: + // Incoming stream. + WebTransportHttp3UnidirectionalStream(PendingStream* pending, + QuicSpdySession* session); + // Outgoing stream. + WebTransportHttp3UnidirectionalStream(QuicStreamId id, + QuicSpdySession* session, + WebTransportSessionId session_id); + + // Sends the stream type and the session ID on the stream. + void WritePreamble(); + + // Implementation of QuicStream. + void OnDataAvailable() override; + void OnCanWriteNewData() override; + void OnClose() override; + void OnStreamReset(const QuicRstStreamFrame& frame) override; + bool OnStopSending(QuicResetStreamError error) override; + void OnWriteSideInDataRecvdState() override; + + WebTransportStream* interface() { return &adapter_; } + void SetUnblocked() { sequencer()->SetUnblocked(); } + + private: + QuicSpdySession* session_; + WebTransportStreamAdapter adapter_; + absl::optional session_id_; + bool needs_to_send_preamble_; + + bool ReadSessionId(); + // Closes the stream if all of the data has been received. + void MaybeCloseIncompleteStream(); +}; + +// Remaps HTTP/3 error code into a WebTransport error code. Returns nullopt if +// the provided code is outside of valid range. +QUIC_EXPORT_PRIVATE absl::optional +Http3ErrorToWebTransport(uint64_t http3_error_code); + +// Same as above, but returns default error value (zero) when none could be +// mapped. +QUIC_EXPORT_PRIVATE WebTransportStreamError +Http3ErrorToWebTransportOrDefault(uint64_t http3_error_code); + +// Remaps WebTransport error code into an HTTP/3 error code. +QUIC_EXPORT_PRIVATE uint64_t +WebTransportErrorToHttp3(WebTransportStreamError webtransport_error_code); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_HTTP3_H_ diff --git a/quiche/quic/core/http/web_transport_http3_test.cc b/quiche/quic/core/http/web_transport_http3_test.cc new file mode 100644 index 000000000000..87cd0d379d8b --- /dev/null +++ b/quiche/quic/core/http/web_transport_http3_test.cc @@ -0,0 +1,52 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/web_transport_http3.h" + +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace { + +using ::testing::Optional; + +TEST(WebTransportHttp3Test, ErrorCodesToHttp3) { + EXPECT_EQ(0x52e4a40fa8dbu, WebTransportErrorToHttp3(0x00)); + EXPECT_EQ(0x52e4a40fa9e2u, WebTransportErrorToHttp3(0xff)); + + EXPECT_EQ(0x52e4a40fa8f7u, WebTransportErrorToHttp3(0x1c)); + EXPECT_EQ(0x52e4a40fa8f8u, WebTransportErrorToHttp3(0x1d)); + // 0x52e4a40fa8f9 is a GREASE codepoint + EXPECT_EQ(0x52e4a40fa8fau, WebTransportErrorToHttp3(0x1e)); +} + +TEST(WebTransportHttp3Test, ErrorCodesToWebTransport) { + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8db), Optional(0x00)); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa9e2), Optional(0xff)); + + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8f7), Optional(0x1cu)); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8f8), Optional(0x1du)); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8f9), absl::nullopt); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8fa), Optional(0x1eu)); + + EXPECT_EQ(Http3ErrorToWebTransport(0), absl::nullopt); + EXPECT_EQ(Http3ErrorToWebTransport(std::numeric_limits::max()), + absl::nullopt); +} + +TEST(WebTransportHttp3Test, ErrorCodeRoundTrip) { + for (int error = 0; error < 256; error++) { + uint64_t http_error = WebTransportErrorToHttp3(error); + absl::optional mapped_back = + quic::Http3ErrorToWebTransport(http_error); + EXPECT_THAT(mapped_back, Optional(error)); + } +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/core/http/web_transport_stream_adapter.cc b/quiche/quic/core/http/web_transport_stream_adapter.cc new file mode 100644 index 000000000000..f484249e12c8 --- /dev/null +++ b/quiche/quic/core/http/web_transport_stream_adapter.cc @@ -0,0 +1,156 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/http/web_transport_stream_adapter.h" + +#include "absl/status/status.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_mem_slice_storage.h" +#include "quiche/web_transport/web_transport.h" + +namespace quic { + +WebTransportStreamAdapter::WebTransportStreamAdapter( + QuicSession* session, QuicStream* stream, QuicStreamSequencer* sequencer) + : session_(session), stream_(stream), sequencer_(sequencer) {} + +WebTransportStream::ReadResult WebTransportStreamAdapter::Read( + absl::Span buffer) { + iovec iov; + iov.iov_base = buffer.data(); + iov.iov_len = buffer.size(); + const size_t result = sequencer_->Readv(&iov, 1); + if (!fin_read_ && sequencer_->IsClosed()) { + fin_read_ = true; + stream_->OnFinRead(); + } + return ReadResult{result, sequencer_->IsClosed()}; +} + +WebTransportStream::ReadResult WebTransportStreamAdapter::Read( + std::string* output) { + const size_t old_size = output->size(); + const size_t bytes_to_read = ReadableBytes(); + output->resize(old_size + bytes_to_read); + ReadResult result = + Read(absl::Span(&(*output)[old_size], bytes_to_read)); + QUICHE_DCHECK_EQ(bytes_to_read, result.bytes_read); + output->resize(old_size + result.bytes_read); + return result; +} + +absl::Status WebTransportStreamAdapter::Writev( + absl::Span data, + const quiche::StreamWriteOptions& options) { + if (data.empty() && !options.send_fin()) { + return absl::InvalidArgumentError( + "Writev() called without any data or a FIN"); + } + const absl::Status initial_check_status = CheckBeforeStreamWrite(); + if (!initial_check_status.ok()) { + return initial_check_status; + } + + std::vector iovecs; + size_t total_size = 0; + iovecs.resize(data.size()); + for (size_t i = 0; i < data.size(); i++) { + // QuicheMemSliceStorage only reads iovec, thus this is safe. + iovecs[i].iov_base = const_cast(data[i].data()); + iovecs[i].iov_len = data[i].size(); + total_size += data[i].size(); + } + quiche::QuicheMemSliceStorage storage( + iovecs.data(), iovecs.size(), + session_->connection()->helper()->GetStreamSendBufferAllocator(), + GetQuicFlag(quic_send_buffer_max_data_slice_size)); + QuicConsumedData consumed = + stream_->WriteMemSlices(storage.ToSpan(), /*fin=*/options.send_fin()); + + if (consumed.bytes_consumed == total_size) { + return absl::OkStatus(); + } + if (consumed.bytes_consumed == 0) { + return absl::UnavailableError("Stream write-blocked"); + } + // WebTransportStream::Write() is an all-or-nothing write API. To achieve + // that property, it relies on WriteMemSlices() being an all-or-nothing API. + // If WriteMemSlices() fails to provide that guarantee, we have no way to + // communicate a partial write to the caller, and thus it's safer to just + // close the connection. + constexpr absl::string_view kErrorMessage = + "WriteMemSlices() unexpectedly partially consumed the input data"; + QUIC_BUG(WebTransportStreamAdapter partial write) + << kErrorMessage << ", provided: " << total_size + << ", written: " << consumed.bytes_consumed; + stream_->OnUnrecoverableError(QUIC_INTERNAL_ERROR, + std::string(kErrorMessage)); + return absl::InternalError(kErrorMessage); +} + +absl::Status WebTransportStreamAdapter::CheckBeforeStreamWrite() const { + if (stream_->write_side_closed() || stream_->fin_buffered()) { + return absl::FailedPreconditionError("Stream write side is closed"); + } + if (!stream_->CanWriteNewData()) { + return absl::UnavailableError("Stream write-blocked"); + } + return absl::OkStatus(); +} + +bool WebTransportStreamAdapter::CanWrite() const { + return CheckBeforeStreamWrite().ok(); +} + +void WebTransportStreamAdapter::AbruptlyTerminate(absl::Status error) { + QUIC_DLOG(WARNING) << (session_->perspective() == Perspective::IS_CLIENT + ? "Client: " + : "Server: ") + << "Abruptly terminating stream " << stream_->id() + << " due to the following error: " << error; + ResetDueToInternalError(); +} + +size_t WebTransportStreamAdapter::ReadableBytes() const { + return sequencer_->ReadableBytes(); +} + +void WebTransportStreamAdapter::OnDataAvailable() { + if (visitor_ == nullptr) { + return; + } + const bool fin_readable = sequencer_->IsClosed() && !fin_read_; + if (ReadableBytes() == 0 && !fin_readable) { + return; + } + visitor_->OnCanRead(); +} + +void WebTransportStreamAdapter::OnCanWriteNewData() { + // Ensure the origin check has been completed, as the stream can be notified + // about being writable before that. + if (!CanWrite()) { + return; + } + if (visitor_ != nullptr) { + visitor_->OnCanWrite(); + } +} + +void WebTransportStreamAdapter::ResetWithUserCode( + WebTransportStreamError error) { + stream_->ResetWriteSide(QuicResetStreamError( + QUIC_STREAM_CANCELLED, WebTransportErrorToHttp3(error))); +} + +void WebTransportStreamAdapter::SendStopSending(WebTransportStreamError error) { + stream_->SendStopSending(QuicResetStreamError( + QUIC_STREAM_CANCELLED, WebTransportErrorToHttp3(error))); +} + +} // namespace quic diff --git a/quiche/quic/core/http/web_transport_stream_adapter.h b/quiche/quic/core/http/web_transport_stream_adapter.h new file mode 100644 index 000000000000..c664347a52c8 --- /dev/null +++ b/quiche/quic/core/http/web_transport_stream_adapter.h @@ -0,0 +1,68 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_STREAM_ADAPTER_H_ +#define QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_STREAM_ADAPTER_H_ + +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_stream_sequencer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/web_transport/web_transport.h" + +namespace quic { + +// Converts WebTransportStream API calls into QuicStream API calls. The users +// of this class can either subclass it, or wrap around it. +class QUIC_EXPORT_PRIVATE WebTransportStreamAdapter + : public WebTransportStream { + public: + WebTransportStreamAdapter(QuicSession* session, QuicStream* stream, + QuicStreamSequencer* sequencer); + + // WebTransportStream implementation. + ABSL_MUST_USE_RESULT ReadResult Read(absl::Span output) override; + ABSL_MUST_USE_RESULT ReadResult Read(std::string* output) override; + absl::Status Writev(absl::Span data, + const quiche::StreamWriteOptions& options) override; + bool CanWrite() const override; + void AbruptlyTerminate(absl::Status error) override; + size_t ReadableBytes() const override; + void SetVisitor(std::unique_ptr visitor) override { + visitor_ = std::move(visitor); + } + QuicStreamId GetStreamId() const override { return stream_->id(); } + + void ResetWithUserCode(WebTransportStreamError error) override; + void ResetDueToInternalError() override { + stream_->Reset(QUIC_STREAM_INTERNAL_ERROR); + } + void SendStopSending(WebTransportStreamError error) override; + void MaybeResetDueToStreamObjectGone() override { + if (stream_->write_side_closed() && stream_->read_side_closed()) { + return; + } + stream_->Reset(QUIC_STREAM_CANCELLED); + } + + WebTransportStreamVisitor* visitor() override { return visitor_.get(); } + + // Calls that need to be passed from the corresponding QuicStream methods. + void OnDataAvailable(); + void OnCanWriteNewData(); + + private: + absl::Status CheckBeforeStreamWrite() const; + + QuicSession* session_; // Unowned. + QuicStream* stream_; // Unowned. + QuicStreamSequencer* sequencer_; // Unowned. + std::unique_ptr visitor_; + bool fin_read_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_STREAM_ADAPTER_H_ diff --git a/quiche/quic/core/io/event_loop_connecting_client_socket.cc b/quiche/quic/core/io/event_loop_connecting_client_socket.cc new file mode 100644 index 000000000000..aefa353de570 --- /dev/null +++ b/quiche/quic/core/io/event_loop_connecting_client_socket.cc @@ -0,0 +1,621 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/event_loop_connecting_client_socket.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quic { + +EventLoopConnectingClientSocket::EventLoopConnectingClientSocket( + socket_api::SocketProtocol protocol, + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + QuicEventLoop* event_loop, quiche::QuicheBufferAllocator* buffer_allocator, + AsyncVisitor* async_visitor) + : protocol_(protocol), + peer_address_(peer_address), + receive_buffer_size_(receive_buffer_size), + send_buffer_size_(send_buffer_size), + event_loop_(event_loop), + buffer_allocator_(buffer_allocator), + async_visitor_(async_visitor) { + QUICHE_DCHECK(event_loop_); + QUICHE_DCHECK(buffer_allocator_); +} + +EventLoopConnectingClientSocket::~EventLoopConnectingClientSocket() { + // Connected socket must be closed via Disconnect() before destruction. Cannot + // safely recover if state indicates caller may be expecting async callbacks. + QUICHE_DCHECK(connect_status_ != ConnectStatus::kConnecting); + QUICHE_DCHECK(!receive_max_size_.has_value()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + if (descriptor_ != kInvalidSocketFd) { + QUICHE_BUG(quic_event_loop_connecting_socket_invalid_destruction) + << "Must call Disconnect() on connected socket before destruction."; + Close(); + } + + QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected); + QUICHE_DCHECK(send_remaining_.empty()); +} + +absl::Status EventLoopConnectingClientSocket::ConnectBlocking() { + QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected); + QUICHE_DCHECK(!receive_max_size_.has_value()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + absl::Status status = Open(); + if (!status.ok()) { + return status; + } + + status = socket_api::SetSocketBlocking(descriptor_, /*blocking=*/true); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to set socket to address: " << peer_address_.ToString() + << " as blocking for connect with error: " << status; + Close(); + return status; + } + + status = DoInitialConnect(); + + if (absl::IsUnavailable(status)) { + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Non-blocking connect to should-be blocking socket to address:" + << peer_address_.ToString() << "."; + Close(); + connect_status_ = ConnectStatus::kNotConnected; + return status; + } else if (!status.ok()) { + // DoInitialConnect() closes the socket on failures. + QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected); + return status; + } + + status = socket_api::SetSocketBlocking(descriptor_, /*blocking=*/false); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to return socket to address: " << peer_address_.ToString() + << " to non-blocking after connect with error: " << status; + Close(); + connect_status_ = ConnectStatus::kNotConnected; + } + + QUICHE_DCHECK(connect_status_ != ConnectStatus::kConnecting); + return status; +} + +void EventLoopConnectingClientSocket::ConnectAsync() { + QUICHE_DCHECK(async_visitor_); + QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected); + QUICHE_DCHECK(!receive_max_size_.has_value()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + absl::Status status = Open(); + if (!status.ok()) { + async_visitor_->ConnectComplete(status); + return; + } + + FinishOrRearmAsyncConnect(DoInitialConnect()); +} + +void EventLoopConnectingClientSocket::Disconnect() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ != ConnectStatus::kNotConnected); + + Close(); + QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd); + + // Reset all state before invoking any callbacks. + bool require_connect_callback = connect_status_ == ConnectStatus::kConnecting; + connect_status_ = ConnectStatus::kNotConnected; + bool require_receive_callback = receive_max_size_.has_value(); + receive_max_size_.reset(); + bool require_send_callback = + !absl::holds_alternative(send_data_); + send_data_ = absl::monostate(); + send_remaining_ = ""; + + if (require_connect_callback) { + QUICHE_DCHECK(async_visitor_); + async_visitor_->ConnectComplete(absl::CancelledError()); + } + if (require_receive_callback) { + QUICHE_DCHECK(async_visitor_); + async_visitor_->ReceiveComplete(absl::CancelledError()); + } + if (require_send_callback) { + QUICHE_DCHECK(async_visitor_); + async_visitor_->SendComplete(absl::CancelledError()); + } +} + +absl::StatusOr +EventLoopConnectingClientSocket::GetLocalAddress() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + + return socket_api::GetSocketAddress(descriptor_); +} + +absl::StatusOr +EventLoopConnectingClientSocket::ReceiveBlocking(QuicByteCount max_size) { + QUICHE_DCHECK_GT(max_size, 0u); + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + QUICHE_DCHECK(!receive_max_size_.has_value()); + + absl::Status status = + socket_api::SetSocketBlocking(descriptor_, /*blocking=*/true); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to set socket to address: " << peer_address_.ToString() + << " as blocking for receive with error: " << status; + return status; + } + + receive_max_size_ = max_size; + absl::StatusOr buffer = ReceiveInternal(); + + if (!buffer.ok() && absl::IsUnavailable(buffer.status())) { + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Non-blocking receive from should-be blocking socket to address:" + << peer_address_.ToString() << "."; + receive_max_size_.reset(); + } else { + QUICHE_DCHECK(!receive_max_size_.has_value()); + } + + absl::Status set_non_blocking_status = + socket_api::SetSocketBlocking(descriptor_, /*blocking=*/false); + if (!set_non_blocking_status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to return socket to address: " << peer_address_.ToString() + << " to non-blocking after receive with error: " + << set_non_blocking_status; + return set_non_blocking_status; + } + + return buffer; +} + +void EventLoopConnectingClientSocket::ReceiveAsync(QuicByteCount max_size) { + QUICHE_DCHECK(async_visitor_); + QUICHE_DCHECK_GT(max_size, 0u); + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + QUICHE_DCHECK(!receive_max_size_.has_value()); + + receive_max_size_ = max_size; + + FinishOrRearmAsyncReceive(ReceiveInternal()); +} + +absl::Status EventLoopConnectingClientSocket::SendBlocking(std::string data) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + send_data_ = std::move(data); + return SendBlockingInternal(); +} + +absl::Status EventLoopConnectingClientSocket::SendBlocking( + quiche::QuicheMemSlice data) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + send_data_ = std::move(data); + return SendBlockingInternal(); +} + +void EventLoopConnectingClientSocket::SendAsync(std::string data) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + send_data_ = std::move(data); + send_remaining_ = absl::get(send_data_); + + FinishOrRearmAsyncSend(SendInternal()); +} + +void EventLoopConnectingClientSocket::SendAsync(quiche::QuicheMemSlice data) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + send_data_ = std::move(data); + send_remaining_ = + absl::get(send_data_).AsStringView(); + + FinishOrRearmAsyncSend(SendInternal()); +} + +void EventLoopConnectingClientSocket::OnSocketEvent( + QuicEventLoop* event_loop, SocketFd fd, QuicSocketEventMask events) { + QUICHE_DCHECK_EQ(event_loop, event_loop_); + QUICHE_DCHECK_EQ(fd, descriptor_); + + if (connect_status_ == ConnectStatus::kConnecting && + (events & (kSocketEventWritable | kSocketEventError))) { + FinishOrRearmAsyncConnect(GetConnectResult()); + return; + } + + if (receive_max_size_.has_value() && + (events & (kSocketEventReadable | kSocketEventError))) { + FinishOrRearmAsyncReceive(ReceiveInternal()); + } + if (!send_remaining_.empty() && + (events & (kSocketEventWritable | kSocketEventError))) { + FinishOrRearmAsyncSend(SendInternal()); + } +} + +absl::Status EventLoopConnectingClientSocket::Open() { + QUICHE_DCHECK_EQ(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected); + QUICHE_DCHECK(!receive_max_size_.has_value()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + QUICHE_DCHECK(send_remaining_.empty()); + + absl::StatusOr descriptor = + socket_api::CreateSocket(peer_address_.host().address_family(), protocol_, + /*blocking=*/false); + if (!descriptor.ok()) { + QUICHE_DVLOG(1) << "Failed to open socket for connection to address: " + << peer_address_.ToString() + << " with error: " << descriptor.status(); + return descriptor.status(); + } + QUICHE_DCHECK_NE(descriptor.value(), kInvalidSocketFd); + + descriptor_ = descriptor.value(); + + if (async_visitor_) { + bool registered; + if (event_loop_->SupportsEdgeTriggered()) { + registered = event_loop_->RegisterSocket( + descriptor_, + kSocketEventReadable | kSocketEventWritable | kSocketEventError, + this); + } else { + // Just register the socket without any armed events for now. Will rearm + // with specific events as needed. Registering now before events are + // needed makes it easier to ensure the socket is registered only once + // and can always be unregistered on socket close. + registered = event_loop_->RegisterSocket(descriptor_, /*events=*/0, this); + } + QUICHE_DCHECK(registered); + } + + if (receive_buffer_size_ != 0) { + absl::Status status = + socket_api::SetReceiveBufferSize(descriptor_, receive_buffer_size_); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to set receive buffer size to: " << receive_buffer_size_ + << " for socket to address: " << peer_address_.ToString() + << " with error: " << status; + Close(); + return status; + } + } + + if (send_buffer_size_ != 0) { + absl::Status status = + socket_api::SetSendBufferSize(descriptor_, send_buffer_size_); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to set send buffer size to: " << send_buffer_size_ + << " for socket to address: " << peer_address_.ToString() + << " with error: " << status; + Close(); + return status; + } + } + + return absl::OkStatus(); +} + +void EventLoopConnectingClientSocket::Close() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + + bool unregistered = event_loop_->UnregisterSocket(descriptor_); + QUICHE_DCHECK_EQ(unregistered, !!async_visitor_); + + absl::Status status = socket_api::Close(descriptor_); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Could not close socket to address: " << peer_address_.ToString() + << " with error: " << status; + } + + descriptor_ = kInvalidSocketFd; +} + +absl::Status EventLoopConnectingClientSocket::DoInitialConnect() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kNotConnected); + QUICHE_DCHECK(!receive_max_size_.has_value()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + absl::Status connect_result = socket_api::Connect(descriptor_, peer_address_); + + if (connect_result.ok()) { + connect_status_ = ConnectStatus::kConnected; + } else if (absl::IsUnavailable(connect_result)) { + connect_status_ = ConnectStatus::kConnecting; + } else { + QUICHE_DVLOG(1) << "Synchronously failed to connect socket to address: " + << peer_address_.ToString() + << " with error: " << connect_result; + Close(); + connect_status_ = ConnectStatus::kNotConnected; + } + + return connect_result; +} + +absl::Status EventLoopConnectingClientSocket::GetConnectResult() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnecting); + QUICHE_DCHECK(!receive_max_size_.has_value()); + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + + absl::Status error = socket_api::GetSocketError(descriptor_); + + if (!error.ok()) { + QUICHE_DVLOG(1) << "Asynchronously failed to connect socket to address: " + << peer_address_.ToString() << " with error: " << error; + Close(); + connect_status_ = ConnectStatus::kNotConnected; + return error; + } + + // Peek at one byte to confirm the connection is actually alive. Motivation: + // 1) Plausibly could have a lot of cases where the connection operation + // itself technically succeeds but the socket then quickly fails. Don't + // want to claim connection success here if, by the time this code is + // running after event triggers and such, the socket has already failed. + // Lot of undefined room around whether or not such errors would be saved + // into SO_ERROR and returned by socket_api::GetSocketError(). + // 2) With the various platforms and event systems involved, less than 100% + // trust that it's impossible to end up in this method before the async + // connect has completed/errored. Given that Connect() and GetSocketError() + // does not difinitevely differentiate between success and + // still-in-progress, and given that there's a very simple and performant + // way to positively confirm the socket is connected (peek), do that here. + // (Could consider making the not-connected case a QUIC_BUG if a way is + // found to differentiate it from (1).) + absl::StatusOr peek_data = OneBytePeek(); + if (peek_data.ok() || absl::IsUnavailable(peek_data.status())) { + connect_status_ = ConnectStatus::kConnected; + } else { + error = peek_data.status(); + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Socket to address: " << peer_address_.ToString() + << " signalled writable after connect and no connect error found, " + "but socket does not appear connected with error: " + << error; + Close(); + connect_status_ = ConnectStatus::kNotConnected; + } + + return error; +} + +void EventLoopConnectingClientSocket::FinishOrRearmAsyncConnect( + absl::Status status) { + if (absl::IsUnavailable(status)) { + if (!event_loop_->SupportsEdgeTriggered()) { + bool result = event_loop_->RearmSocket( + descriptor_, kSocketEventWritable | kSocketEventError); + QUICHE_DCHECK(result); + } + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnecting); + } else { + QUICHE_DCHECK(connect_status_ != ConnectStatus::kConnecting); + async_visitor_->ConnectComplete(status); + } +} + +absl::StatusOr +EventLoopConnectingClientSocket::ReceiveInternal() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + QUICHE_CHECK(receive_max_size_.has_value()); + QUICHE_DCHECK_GE(receive_max_size_.value(), 1u); + QUICHE_DCHECK_LE(receive_max_size_.value(), + std::numeric_limits::max()); + + // Before allocating a buffer, do a 1-byte peek to determine if needed. + if (receive_max_size_.value() > 1) { + absl::StatusOr peek_data = OneBytePeek(); + if (!peek_data.ok()) { + if (!absl::IsUnavailable(peek_data.status())) { + receive_max_size_.reset(); + } + return peek_data.status(); + } else if (!peek_data.value()) { + receive_max_size_.reset(); + return quiche::QuicheMemSlice(); + } + } + + quiche::QuicheBuffer buffer(buffer_allocator_, receive_max_size_.value()); + absl::StatusOr> received = socket_api::Receive( + descriptor_, absl::MakeSpan(buffer.data(), buffer.size())); + + if (received.ok()) { + QUICHE_DCHECK_LE(received.value().size(), buffer.size()); + QUICHE_DCHECK_EQ(received.value().data(), buffer.data()); + + receive_max_size_.reset(); + return quiche::QuicheMemSlice( + quiche::QuicheBuffer(buffer.Release(), received.value().size())); + } else { + if (!absl::IsUnavailable(received.status())) { + QUICHE_DVLOG(1) << "Failed to receive from socket to address: " + << peer_address_.ToString() + << " with error: " << received.status(); + receive_max_size_.reset(); + } + return received.status(); + } +} + +void EventLoopConnectingClientSocket::FinishOrRearmAsyncReceive( + absl::StatusOr buffer) { + QUICHE_DCHECK(async_visitor_); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + + if (!buffer.ok() && absl::IsUnavailable(buffer.status())) { + if (!event_loop_->SupportsEdgeTriggered()) { + bool result = event_loop_->RearmSocket( + descriptor_, kSocketEventReadable | kSocketEventError); + QUICHE_DCHECK(result); + } + QUICHE_DCHECK(receive_max_size_.has_value()); + } else { + QUICHE_DCHECK(!receive_max_size_.has_value()); + async_visitor_->ReceiveComplete(std::move(buffer)); + } +} + +absl::StatusOr EventLoopConnectingClientSocket::OneBytePeek() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + + char peek_buffer; + absl::StatusOr> peek_received = socket_api::Receive( + descriptor_, absl::MakeSpan(&peek_buffer, /*size=*/1), /*peek=*/true); + if (!peek_received.ok()) { + return peek_received.status(); + } else { + return !peek_received.value().empty(); + } +} + +absl::Status EventLoopConnectingClientSocket::SendBlockingInternal() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + QUICHE_DCHECK(!absl::holds_alternative(send_data_)); + QUICHE_DCHECK(send_remaining_.empty()); + + absl::Status status = + socket_api::SetSocketBlocking(descriptor_, /*blocking=*/true); + if (!status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to set socket to address: " << peer_address_.ToString() + << " as blocking for send with error: " << status; + send_data_ = absl::monostate(); + return status; + } + + if (absl::holds_alternative(send_data_)) { + send_remaining_ = absl::get(send_data_); + } else { + send_remaining_ = + absl::get(send_data_).AsStringView(); + } + + status = SendInternal(); + if (absl::IsUnavailable(status)) { + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Non-blocking send for should-be blocking socket to address:" + << peer_address_.ToString(); + send_data_ = absl::monostate(); + send_remaining_ = ""; + } else { + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + QUICHE_DCHECK(send_remaining_.empty()); + } + + absl::Status set_non_blocking_status = + socket_api::SetSocketBlocking(descriptor_, /*blocking=*/false); + if (!set_non_blocking_status.ok()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Failed to return socket to address: " << peer_address_.ToString() + << " to non-blocking after send with error: " + << set_non_blocking_status; + return set_non_blocking_status; + } + + return status; +} + +absl::Status EventLoopConnectingClientSocket::SendInternal() { + QUICHE_DCHECK_NE(descriptor_, kInvalidSocketFd); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + QUICHE_DCHECK(!absl::holds_alternative(send_data_)); + QUICHE_DCHECK(!send_remaining_.empty()); + + // Repeat send until all data sent, unavailable, or error. + while (!send_remaining_.empty()) { + absl::StatusOr remainder = + socket_api::Send(descriptor_, send_remaining_); + + if (remainder.ok()) { + QUICHE_DCHECK(remainder.value().empty() || + (remainder.value().data() >= send_remaining_.data() && + remainder.value().data() < + send_remaining_.data() + send_remaining_.size())); + QUICHE_DCHECK(remainder.value().empty() || + (remainder.value().data() + remainder.value().size() == + send_remaining_.data() + send_remaining_.size())); + send_remaining_ = remainder.value(); + } else { + if (!absl::IsUnavailable(remainder.status())) { + QUICHE_DVLOG(1) << "Failed to send to socket to address: " + << peer_address_.ToString() + << " with error: " << remainder.status(); + send_data_ = absl::monostate(); + send_remaining_ = ""; + } + return remainder.status(); + } + } + + send_data_ = absl::monostate(); + return absl::OkStatus(); +} + +void EventLoopConnectingClientSocket::FinishOrRearmAsyncSend( + absl::Status status) { + QUICHE_DCHECK(async_visitor_); + QUICHE_DCHECK(connect_status_ == ConnectStatus::kConnected); + + if (absl::IsUnavailable(status)) { + if (!event_loop_->SupportsEdgeTriggered()) { + bool result = event_loop_->RearmSocket( + descriptor_, kSocketEventWritable | kSocketEventError); + QUICHE_DCHECK(result); + } + QUICHE_DCHECK(!absl::holds_alternative(send_data_)); + QUICHE_DCHECK(!send_remaining_.empty()); + } else { + QUICHE_DCHECK(absl::holds_alternative(send_data_)); + QUICHE_DCHECK(send_remaining_.empty()); + async_visitor_->SendComplete(status); + } +} + +} // namespace quic diff --git a/quiche/quic/core/io/event_loop_connecting_client_socket.h b/quiche/quic/core/io/event_loop_connecting_client_socket.h new file mode 100644 index 000000000000..a7fb32c8dce1 --- /dev/null +++ b/quiche/quic/core/io/event_loop_connecting_client_socket.h @@ -0,0 +1,106 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_IO_EVENT_LOOP_CONNECTING_CLIENT_SOCKET_H_ +#define QUICHE_QUIC_CORE_IO_EVENT_LOOP_CONNECTING_CLIENT_SOCKET_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quic { + +// A connection-based client socket implemented using an underlying +// QuicEventLoop. +class QUICHE_EXPORT EventLoopConnectingClientSocket + : public ConnectingClientSocket, + public QuicSocketEventListener { + public: + // Will use platform default buffer size if `receive_buffer_size` or + // `send_buffer_size` is zero. `async_visitor` may be null if no async + // operations will be requested. `event_loop`, `buffer_allocator`, and + // `async_visitor` (if non-null) must outlive the created socket. + EventLoopConnectingClientSocket( + socket_api::SocketProtocol protocol, + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + QuicEventLoop* event_loop, + quiche::QuicheBufferAllocator* buffer_allocator, + AsyncVisitor* async_visitor); + + ~EventLoopConnectingClientSocket() override; + + // ConnectingClientSocket: + absl::Status ConnectBlocking() override; + void ConnectAsync() override; + void Disconnect() override; + absl::StatusOr GetLocalAddress() override; + absl::StatusOr ReceiveBlocking( + QuicByteCount max_size) override; + void ReceiveAsync(QuicByteCount max_size) override; + absl::Status SendBlocking(std::string data) override; + absl::Status SendBlocking(quiche::QuicheMemSlice data) override; + void SendAsync(std::string data) override; + void SendAsync(quiche::QuicheMemSlice data) override; + + // QuicSocketEventListener: + void OnSocketEvent(QuicEventLoop* event_loop, SocketFd fd, + QuicSocketEventMask events) override; + + private: + enum class ConnectStatus { + kNotConnected, + kConnecting, + kConnected, + }; + + absl::Status Open(); + void Close(); + absl::Status DoInitialConnect(); + absl::Status GetConnectResult(); + void FinishOrRearmAsyncConnect(absl::Status status); + absl::StatusOr ReceiveInternal(); + void FinishOrRearmAsyncReceive(absl::StatusOr buffer); + // Returns `true` if a byte received, or `false` if successfully received + // empty data. + absl::StatusOr OneBytePeek(); + absl::Status SendBlockingInternal(); + absl::Status SendInternal(); + void FinishOrRearmAsyncSend(absl::Status status); + + const socket_api::SocketProtocol protocol_; + const QuicSocketAddress peer_address_; + const QuicByteCount receive_buffer_size_; + const QuicByteCount send_buffer_size_; + QuicEventLoop* const event_loop_; // unowned + quiche::QuicheBufferAllocator* buffer_allocator_; // unowned + AsyncVisitor* const async_visitor_; // unowned, potentially null + + SocketFd descriptor_ = kInvalidSocketFd; + ConnectStatus connect_status_ = ConnectStatus::kNotConnected; + + // Only set while receive in progress or pending, otherwise nullopt. + absl::optional receive_max_size_; + + // Only contains data while send in progress or pending, otherwise monostate. + absl::variant + send_data_; + // Points to the unsent portion of `send_data_` while send in progress or + // pending, otherwise empty. + absl::string_view send_remaining_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_IO_EVENT_LOOP_CONNECTING_CLIENT_SOCKET_H_ diff --git a/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc b/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc new file mode 100644 index 000000000000..37cb607fdaf2 --- /dev/null +++ b/quiche/quic/core/io/event_loop_connecting_client_socket_test.cc @@ -0,0 +1,700 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/event_loop_connecting_client_socket.h" + +#include +#include +#include +#include +#include + +#include "absl/functional/bind_front.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/io/event_loop_socket_factory.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_mutex.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/platform/api/quiche_test_loopback.h" +#include "quiche/common/platform/api/quiche_thread.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic::test { +namespace { + +using ::testing::Combine; +using ::testing::Values; +using ::testing::ValuesIn; + +class TestServerSocketRunner : public quiche::QuicheThread { + public: + using SocketBehavior = std::function; + + TestServerSocketRunner(SocketFd server_socket_descriptor, + SocketBehavior behavior) + : QuicheThread("TestServerSocketRunner"), + server_socket_descriptor_(server_socket_descriptor), + behavior_(std::move(behavior)) {} + ~TestServerSocketRunner() override { WaitForCompletion(); } + + void WaitForCompletion() { completion_notification_.WaitForNotification(); } + + protected: + SocketFd server_socket_descriptor() const { + return server_socket_descriptor_; + } + + const SocketBehavior& behavior() const { return behavior_; } + + quiche::QuicheNotification& completion_notification() { + return completion_notification_; + } + + private: + const SocketFd server_socket_descriptor_; + const SocketBehavior behavior_; + + quiche::QuicheNotification completion_notification_; +}; + +class TestTcpServerSocketRunner : public TestServerSocketRunner { + public: + // On construction, spins a separate thread to accept a connection from + // `server_socket_descriptor`, runs `behavior` with that connection, and then + // closes the accepted connection socket. + TestTcpServerSocketRunner(SocketFd server_socket_descriptor, + SocketBehavior behavior) + : TestServerSocketRunner(server_socket_descriptor, behavior) { + Start(); + } + + ~TestTcpServerSocketRunner() override { Join(); } + + protected: + void Run() override { + AcceptSocket(); + behavior()(connection_socket_descriptor_, socket_api::SocketProtocol::kTcp); + CloseSocket(); + + completion_notification().Notify(); + } + + private: + void AcceptSocket() { + absl::StatusOr connection_socket = + socket_api::Accept(server_socket_descriptor(), /*blocking=*/true); + QUICHE_CHECK(connection_socket.ok()); + connection_socket_descriptor_ = connection_socket.value().fd; + } + + void CloseSocket() { + QUICHE_CHECK(socket_api::Close(connection_socket_descriptor_).ok()); + QUICHE_CHECK(socket_api::Close(server_socket_descriptor()).ok()); + } + + SocketFd connection_socket_descriptor_ = kInvalidSocketFd; +}; + +class TestUdpServerSocketRunner : public TestServerSocketRunner { + public: + // On construction, spins a separate thread to connect + // `server_socket_descriptor` to `client_socket_address`, runs `behavior` with + // that connection, and then disconnects the socket. + TestUdpServerSocketRunner(SocketFd server_socket_descriptor, + SocketBehavior behavior, + QuicSocketAddress client_socket_address) + : TestServerSocketRunner(server_socket_descriptor, behavior), + client_socket_address_(std::move(client_socket_address)) { + Start(); + } + + ~TestUdpServerSocketRunner() override { Join(); } + + protected: + void Run() override { + ConnectSocket(); + behavior()(server_socket_descriptor(), socket_api::SocketProtocol::kUdp); + DisconnectSocket(); + + completion_notification().Notify(); + } + + private: + void ConnectSocket() { + QUICHE_CHECK( + socket_api::Connect(server_socket_descriptor(), client_socket_address_) + .ok()); + } + + void DisconnectSocket() { + QUICHE_CHECK(socket_api::Close(server_socket_descriptor()).ok()); + } + + QuicSocketAddress client_socket_address_; +}; + +class EventLoopConnectingClientSocketTest + : public quiche::test::QuicheTestWithParam< + std::tuple>, + public ConnectingClientSocket::AsyncVisitor { + public: + void SetUp() override { + QuicEventLoopFactory* event_loop_factory; + std::tie(protocol_, event_loop_factory) = GetParam(); + + event_loop_ = event_loop_factory->Create(&clock_); + socket_factory_ = std::make_unique( + event_loop_.get(), quiche::SimpleBufferAllocator::Get()); + + QUICHE_CHECK(CreateListeningServerSocket()); + } + + void TearDown() override { + if (server_socket_descriptor_ != kInvalidSocketFd) { + QUICHE_CHECK(socket_api::Close(server_socket_descriptor_).ok()); + } + } + + void ConnectComplete(absl::Status status) override { + QUICHE_CHECK(!connect_result_.has_value()); + connect_result_ = std::move(status); + } + + void ReceiveComplete(absl::StatusOr data) override { + QUICHE_CHECK(!receive_result_.has_value()); + receive_result_ = std::move(data); + } + + void SendComplete(absl::Status status) override { + QUICHE_CHECK(!send_result_.has_value()); + send_result_ = std::move(status); + } + + protected: + std::unique_ptr CreateSocket( + const quic::QuicSocketAddress& peer_address, + ConnectingClientSocket::AsyncVisitor* async_visitor) { + switch (protocol_) { + case socket_api::SocketProtocol::kUdp: + return socket_factory_->CreateConnectingUdpClientSocket( + peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/0, + async_visitor); + case socket_api::SocketProtocol::kTcp: + return socket_factory_->CreateTcpClientSocket( + peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/0, + async_visitor); + } + } + + std::unique_ptr CreateSocketToEncourageDelayedSend( + const quic::QuicSocketAddress& peer_address, + ConnectingClientSocket::AsyncVisitor* async_visitor) { + switch (protocol_) { + case socket_api::SocketProtocol::kUdp: + // Nothing special for UDP since UDP does not gaurantee packets will be + // sent once send buffers are full. + return socket_factory_->CreateConnectingUdpClientSocket( + peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/0, + async_visitor); + case socket_api::SocketProtocol::kTcp: + // For TCP, set a very small send buffer to encourage sends to be + // delayed. + return socket_factory_->CreateTcpClientSocket( + peer_address, /*receive_buffer_size=*/0, /*send_buffer_size=*/4, + async_visitor); + } + } + + bool CreateListeningServerSocket() { + absl::StatusOr socket = socket_api::CreateSocket( + quiche::TestLoopback().address_family(), protocol_, + /*blocking=*/true); + QUICHE_CHECK(socket.ok()); + + // For TCP, set an extremely small receive buffer size to increase the odds + // of buffers filling up when testing asynchronous writes. + if (protocol_ == socket_api::SocketProtocol::kTcp) { + static const QuicByteCount kReceiveBufferSize = 2; + absl::Status result = + socket_api::SetReceiveBufferSize(socket.value(), kReceiveBufferSize); + QUICHE_CHECK(result.ok()); + } + + QuicSocketAddress bind_address(quiche::TestLoopback(), /*port=*/0); + absl::Status result = socket_api::Bind(socket.value(), bind_address); + QUICHE_CHECK(result.ok()); + + absl::StatusOr socket_address = + socket_api::GetSocketAddress(socket.value()); + QUICHE_CHECK(socket_address.ok()); + + // TCP sockets need to listen for connections. UDP sockets are ready to + // receive. + if (protocol_ == socket_api::SocketProtocol::kTcp) { + result = socket_api::Listen(socket.value(), /*backlog=*/1); + QUICHE_CHECK(result.ok()); + } + + server_socket_descriptor_ = socket.value(); + server_socket_address_ = std::move(socket_address).value(); + return true; + } + + std::unique_ptr CreateServerSocketRunner( + TestServerSocketRunner::SocketBehavior behavior, + ConnectingClientSocket* client_socket) { + std::unique_ptr runner; + switch (protocol_) { + case socket_api::SocketProtocol::kUdp: { + absl::StatusOr client_socket_address = + client_socket->GetLocalAddress(); + QUICHE_CHECK(client_socket_address.ok()); + runner = std::make_unique( + server_socket_descriptor_, std::move(behavior), + std::move(client_socket_address).value()); + break; + } + case socket_api::SocketProtocol::kTcp: + runner = std::make_unique( + server_socket_descriptor_, std::move(behavior)); + break; + } + + // Runner takes responsibility for closing server socket. + server_socket_descriptor_ = kInvalidSocketFd; + + return runner; + } + + socket_api::SocketProtocol protocol_; + + SocketFd server_socket_descriptor_ = kInvalidSocketFd; + QuicSocketAddress server_socket_address_; + + MockClock clock_; + std::unique_ptr event_loop_; + std::unique_ptr socket_factory_; + + absl::optional connect_result_; + absl::optional> receive_result_; + absl::optional send_result_; +}; + +std::string GetTestParamName( + ::testing::TestParamInfo< + std::tuple> + info) { + auto [protocol, event_loop_factory] = info.param; + + return EscapeTestParamName(absl::StrCat(socket_api::GetProtocolName(protocol), + "_", event_loop_factory->GetName())); +} + +INSTANTIATE_TEST_SUITE_P(EventLoopConnectingClientSocketTests, + EventLoopConnectingClientSocketTest, + Combine(Values(socket_api::SocketProtocol::kUdp, + socket_api::SocketProtocol::kTcp), + ValuesIn(GetAllSupportedEventLoops())), + &GetTestParamName); + +TEST_P(EventLoopConnectingClientSocketTest, ConnectBlocking) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/nullptr); + + // No socket runner to accept the connection for the server, but that is not + // expected to be necessary for the connection to complete from the client for + // TCP or UDP. + EXPECT_TRUE(socket->ConnectBlocking().ok()); + + socket->Disconnect(); +} + +TEST_P(EventLoopConnectingClientSocketTest, ConnectAsync) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/this); + + socket->ConnectAsync(); + + // TCP connection typically completes asynchronously and UDP connection + // typically completes before ConnectAsync returns, but there is no simple way + // to ensure either behaves one way or the other. If connecting is + // asynchronous, expect completion once signalled by the event loop. + if (!connect_result_.has_value()) { + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + ASSERT_TRUE(connect_result_.has_value()); + } + EXPECT_TRUE(connect_result_.value().ok()); + + connect_result_.reset(); + socket->Disconnect(); + EXPECT_FALSE(connect_result_.has_value()); +} + +TEST_P(EventLoopConnectingClientSocketTest, ErrorBeforeConnectAsync) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/this); + + // Close the server socket. + EXPECT_TRUE(socket_api::Close(server_socket_descriptor_).ok()); + server_socket_descriptor_ = kInvalidSocketFd; + + socket->ConnectAsync(); + if (!connect_result_.has_value()) { + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + ASSERT_TRUE(connect_result_.has_value()); + } + + switch (protocol_) { + case socket_api::SocketProtocol::kTcp: + // Expect an error because server socket was closed before connection. + EXPECT_FALSE(connect_result_.value().ok()); + break; + case socket_api::SocketProtocol::kUdp: + // No error for UDP because UDP connection success does not rely on the + // server. + EXPECT_TRUE(connect_result_.value().ok()); + socket->Disconnect(); + break; + } +} + +TEST_P(EventLoopConnectingClientSocketTest, ErrorDuringConnectAsync) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/this); + + socket->ConnectAsync(); + + if (connect_result_.has_value()) { + // UDP typically completes connection immediately before this test has a + // chance to actually attempt the error. TCP typically completes + // asynchronously, but no simple way to ensure that always happens. + EXPECT_TRUE(connect_result_.value().ok()); + socket->Disconnect(); + return; + } + + // Close the server socket. + EXPECT_TRUE(socket_api::Close(server_socket_descriptor_).ok()); + server_socket_descriptor_ = kInvalidSocketFd; + + EXPECT_FALSE(connect_result_.has_value()); + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + ASSERT_TRUE(connect_result_.has_value()); + + switch (protocol_) { + case socket_api::SocketProtocol::kTcp: + EXPECT_FALSE(connect_result_.value().ok()); + break; + case socket_api::SocketProtocol::kUdp: + // No error for UDP because UDP connection success does not rely on the + // server. + EXPECT_TRUE(connect_result_.value().ok()); + break; + } +} + +TEST_P(EventLoopConnectingClientSocketTest, Disconnect) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/nullptr); + + ASSERT_TRUE(socket->ConnectBlocking().ok()); + socket->Disconnect(); +} + +TEST_P(EventLoopConnectingClientSocketTest, DisconnectCancelsConnectAsync) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/this); + + socket->ConnectAsync(); + + bool expect_canceled = true; + if (connect_result_.has_value()) { + // UDP typically completes connection immediately before this test has a + // chance to actually attempt the disconnect. TCP typically completes + // asynchronously, but no simple way to ensure that always happens. + EXPECT_TRUE(connect_result_.value().ok()); + expect_canceled = false; + } + + socket->Disconnect(); + + if (expect_canceled) { + // Expect immediate cancelled error. + ASSERT_TRUE(connect_result_.has_value()); + EXPECT_TRUE(absl::IsCancelled(connect_result_.value())); + } +} + +TEST_P(EventLoopConnectingClientSocketTest, ConnectAndReconnect) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/nullptr); + + ASSERT_TRUE(socket->ConnectBlocking().ok()); + socket->Disconnect(); + + // Expect `socket` can reconnect now that it has been disconnected. + EXPECT_TRUE(socket->ConnectBlocking().ok()); + socket->Disconnect(); +} + +TEST_P(EventLoopConnectingClientSocketTest, GetLocalAddress) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/nullptr); + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + absl::StatusOr address = socket->GetLocalAddress(); + ASSERT_TRUE(address.ok()); + EXPECT_TRUE(address.value().IsInitialized()); + + socket->Disconnect(); +} + +void SendDataOnSocket(absl::string_view data, SocketFd connected_socket, + socket_api::SocketProtocol protocol) { + QUICHE_CHECK(!data.empty()); + + // May attempt to send in pieces for TCP. For UDP, expect failure if `data` + // cannot be sent in a single packet. + do { + absl::StatusOr remainder = + socket_api::Send(connected_socket, data); + if (!remainder.ok()) { + return; + } + data = remainder.value(); + } while (protocol == socket_api::SocketProtocol::kTcp && !data.empty()); + + QUICHE_CHECK(data.empty()); +} + +TEST_P(EventLoopConnectingClientSocketTest, ReceiveBlocking) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/nullptr); + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + std::string expected = {1, 2, 3, 4, 5, 6, 7, 8}; + std::unique_ptr runner = CreateServerSocketRunner( + absl::bind_front(&SendDataOnSocket, expected), socket.get()); + + std::string received; + absl::StatusOr data; + + // Expect exactly one packet for UDP, and at least two receives (data + FIN) + // for TCP. + do { + data = socket->ReceiveBlocking(100); + ASSERT_TRUE(data.ok()); + received.append(data.value().data(), data.value().length()); + } while (protocol_ == socket_api::SocketProtocol::kTcp && + !data.value().empty()); + + EXPECT_EQ(received, expected); + + socket->Disconnect(); +} + +TEST_P(EventLoopConnectingClientSocketTest, ReceiveAsync) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/this); + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + // Start an async receive. Expect no immediate results because runner not + // yet setup to send. + socket->ReceiveAsync(100); + EXPECT_FALSE(receive_result_.has_value()); + + // Send data from server. + std::string expected = {1, 2, 3, 4, 5, 6, 7, 8}; + std::unique_ptr runner = CreateServerSocketRunner( + absl::bind_front(&SendDataOnSocket, expected), socket.get()); + + EXPECT_FALSE(receive_result_.has_value()); + for (int i = 0; i < 5 && !receive_result_.has_value(); ++i) { + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + } + + // Expect to receive at least some of the sent data. + ASSERT_TRUE(receive_result_.has_value()); + ASSERT_TRUE(receive_result_.value().ok()); + EXPECT_FALSE(receive_result_.value().value().empty()); + std::string received(receive_result_.value().value().data(), + receive_result_.value().value().length()); + + // For TCP, expect at least one more receive for the FIN. + if (protocol_ == socket_api::SocketProtocol::kTcp) { + absl::StatusOr data; + do { + data = socket->ReceiveBlocking(100); + ASSERT_TRUE(data.ok()); + received.append(data.value().data(), data.value().length()); + } while (!data.value().empty()); + } + + EXPECT_EQ(received, expected); + + receive_result_.reset(); + socket->Disconnect(); + EXPECT_FALSE(receive_result_.has_value()); +} + +TEST_P(EventLoopConnectingClientSocketTest, DisconnectCancelsReceiveAsync) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/this); + + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + // Start an asynchronous read, expecting no completion because server never + // sends any data. + socket->ReceiveAsync(100); + EXPECT_FALSE(receive_result_.has_value()); + + // Disconnect and expect an immediate cancelled error. + socket->Disconnect(); + ASSERT_TRUE(receive_result_.has_value()); + ASSERT_FALSE(receive_result_.value().ok()); + EXPECT_TRUE(absl::IsCancelled(receive_result_.value().status())); +} + +// Receive from `connected_socket` until connection is closed, writing +// received data to `out_received`. +void ReceiveDataFromSocket(std::string* out_received, SocketFd connected_socket, + socket_api::SocketProtocol protocol) { + out_received->clear(); + + std::string buffer(100, 0); + absl::StatusOr> received; + + // Expect exactly one packet for UDP, and at least two receives (data + FIN) + // for TCP. + do { + received = socket_api::Receive(connected_socket, absl::MakeSpan(buffer)); + QUICHE_CHECK(received.ok()); + out_received->insert(out_received->end(), received.value().begin(), + received.value().end()); + } while (protocol == socket_api::SocketProtocol::kTcp && + !received.value().empty()); + QUICHE_CHECK(!out_received->empty()); +} + +TEST_P(EventLoopConnectingClientSocketTest, SendBlocking) { + std::unique_ptr socket = + CreateSocket(server_socket_address_, + /*async_visitor=*/nullptr); + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + std::string sent; + std::unique_ptr runner = CreateServerSocketRunner( + absl::bind_front(&ReceiveDataFromSocket, &sent), socket.get()); + + std::string expected = {1, 2, 3, 4, 5, 6, 7, 8}; + EXPECT_TRUE(socket->SendBlocking(expected).ok()); + socket->Disconnect(); + + runner->WaitForCompletion(); + EXPECT_EQ(sent, expected); +} + +TEST_P(EventLoopConnectingClientSocketTest, SendAsync) { + std::unique_ptr socket = + CreateSocketToEncourageDelayedSend(server_socket_address_, + /*async_visitor=*/this); + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + std::string data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + std::string expected; + + std::unique_ptr runner; + std::string sent; + switch (protocol_) { + case socket_api::SocketProtocol::kTcp: + // Repeatedly write to socket until it does not complete synchronously. + do { + expected.insert(expected.end(), data.begin(), data.end()); + send_result_.reset(); + socket->SendAsync(data); + ASSERT_TRUE(!send_result_.has_value() || send_result_.value().ok()); + } while (send_result_.has_value()); + + // Begin receiving from server and expect more data to send. + runner = CreateServerSocketRunner( + absl::bind_front(&ReceiveDataFromSocket, &sent), socket.get()); + EXPECT_FALSE(send_result_.has_value()); + for (int i = 0; i < 5 && !send_result_.has_value(); ++i) { + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + } + break; + + case socket_api::SocketProtocol::kUdp: + // Expect UDP send to always send immediately. + runner = CreateServerSocketRunner( + absl::bind_front(&ReceiveDataFromSocket, &sent), socket.get()); + socket->SendAsync(data); + expected = data; + break; + } + ASSERT_TRUE(send_result_.has_value()); + EXPECT_TRUE(send_result_.value().ok()); + + send_result_.reset(); + socket->Disconnect(); + EXPECT_FALSE(send_result_.has_value()); + + runner->WaitForCompletion(); + EXPECT_EQ(sent, expected); +} + +TEST_P(EventLoopConnectingClientSocketTest, DisconnectCancelsSendAsync) { + if (protocol_ == socket_api::SocketProtocol::kUdp) { + // UDP sends are always immediate, so cannot disconect mid-send. + return; + } + + std::unique_ptr socket = + CreateSocketToEncourageDelayedSend(server_socket_address_, + /*async_visitor=*/this); + ASSERT_TRUE(socket->ConnectBlocking().ok()); + + std::string data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + + // Repeatedly write to socket until it does not complete synchronously. + do { + send_result_.reset(); + socket->SendAsync(data); + ASSERT_TRUE(!send_result_.has_value() || send_result_.value().ok()); + } while (send_result_.has_value()); + + // Disconnect and expect immediate cancelled error. + socket->Disconnect(); + ASSERT_TRUE(send_result_.has_value()); + EXPECT_TRUE(absl::IsCancelled(send_result_.value())); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/io/event_loop_socket_factory.cc b/quiche/quic/core/io/event_loop_socket_factory.cc new file mode 100644 index 000000000000..b1aaec7866cb --- /dev/null +++ b/quiche/quic/core/io/event_loop_socket_factory.cc @@ -0,0 +1,47 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/event_loop_socket_factory.h" + +#include + +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/io/event_loop_connecting_client_socket.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quic { + +EventLoopSocketFactory::EventLoopSocketFactory( + QuicEventLoop* event_loop, quiche::QuicheBufferAllocator* buffer_allocator) + : event_loop_(event_loop), buffer_allocator_(buffer_allocator) { + QUICHE_DCHECK(event_loop_); + QUICHE_DCHECK(buffer_allocator_); +} + +std::unique_ptr +EventLoopSocketFactory::CreateTcpClientSocket( + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor) { + return std::make_unique( + socket_api::SocketProtocol::kTcp, peer_address, receive_buffer_size, + send_buffer_size, event_loop_, buffer_allocator_, async_visitor); +} + +std::unique_ptr +EventLoopSocketFactory::CreateConnectingUdpClientSocket( + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor) { + return std::make_unique( + socket_api::SocketProtocol::kUdp, peer_address, receive_buffer_size, + send_buffer_size, event_loop_, buffer_allocator_, async_visitor); +} + +} // namespace quic diff --git a/quiche/quic/core/io/event_loop_socket_factory.h b/quiche/quic/core/io/event_loop_socket_factory.h new file mode 100644 index 000000000000..8edf020bd169 --- /dev/null +++ b/quiche/quic/core/io/event_loop_socket_factory.h @@ -0,0 +1,45 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_IO_EVENT_LOOP_SOCKET_FACTORY_H_ +#define QUICHE_QUIC_CORE_IO_EVENT_LOOP_SOCKET_FACTORY_H_ + +#include + +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quic { + +// A socket factory that creates sockets implemented using an underlying +// QuicEventLoop. +class QUICHE_EXPORT EventLoopSocketFactory : public SocketFactory { + public: + // `event_loop` and `buffer_allocator` must outlive the created factory. + EventLoopSocketFactory(QuicEventLoop* event_loop, + quiche::QuicheBufferAllocator* buffer_allocator); + + // SocketFactory: + std::unique_ptr CreateTcpClientSocket( + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor) override; + std::unique_ptr CreateConnectingUdpClientSocket( + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor) override; + + private: + QuicEventLoop* const event_loop_; // unowned + quiche::QuicheBufferAllocator* buffer_allocator_; // unowned +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_IO_EVENT_LOOP_SOCKET_FACTORY_H_ diff --git a/quiche/quic/core/io/quic_all_event_loops_test.cc b/quiche/quic/core/io/quic_all_event_loops_test.cc new file mode 100644 index 000000000000..ed6282476265 --- /dev/null +++ b/quiche/quic/core/io/quic_all_event_loops_test.cc @@ -0,0 +1,440 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A universal test for all event loops supported by the build of QUICHE in +// question. +// +// This test is very similar to QuicPollEventLoopTest, however, there are some +// notable differences: +// (1) This test uses the real clock, since the event loop implementation may +// not support accepting a mock clock. +// (2) This test covers both level-triggered and edge-triggered event loops. + +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic::test { +namespace { + +using testing::_; +using testing::AtMost; + +MATCHER_P(HasFlagSet, value, "Checks a flag in a bit mask") { + return (arg & value) != 0; +} + +constexpr QuicSocketEventMask kAllEvents = + kSocketEventReadable | kSocketEventWritable | kSocketEventError; + +class MockQuicSocketEventListener : public QuicSocketEventListener { + public: + MOCK_METHOD(void, OnSocketEvent, + (QuicEventLoop* /*event_loop*/, QuicUdpSocketFd /*fd*/, + QuicSocketEventMask /*events*/), + (override)); +}; + +class MockDelegate : public QuicAlarm::Delegate { + public: + QuicConnectionContext* GetConnectionContext() override { return nullptr; } + MOCK_METHOD(void, OnAlarm, (), (override)); +}; + +void SetNonBlocking(int fd) { + QUICHE_CHECK(::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) | O_NONBLOCK) == 0) + << "Failed to mark FD non-blocking, errno: " << errno; +} + +class QuicEventLoopFactoryTest + : public QuicTestWithParam { + public: + QuicEventLoopFactoryTest() + : loop_(GetParam()->Create(&clock_)), + factory_(loop_->CreateAlarmFactory()) { + int fds[2]; + int result = ::pipe(fds); + QUICHE_CHECK(result >= 0) << "Failed to create a pipe, errno: " << errno; + read_fd_ = fds[0]; + write_fd_ = fds[1]; + + SetNonBlocking(read_fd_); + SetNonBlocking(write_fd_); + } + + ~QuicEventLoopFactoryTest() { + close(read_fd_); + close(write_fd_); + } + + std::pair, MockDelegate*> CreateAlarm() { + auto delegate = std::make_unique>(); + MockDelegate* delegate_unowned = delegate.get(); + auto alarm = absl::WrapUnique(factory_->CreateAlarm(delegate.release())); + return std::make_pair(std::move(alarm), delegate_unowned); + } + + template + void RunEventLoopUntil(Condition condition, QuicTime::Delta timeout) { + const QuicTime end = clock_.Now() + timeout; + while (!condition() && clock_.Now() < end) { + loop_->RunEventLoopOnce(end - clock_.Now()); + } + } + + protected: + QuicDefaultClock clock_; + std::unique_ptr loop_; + std::unique_ptr factory_; + int read_fd_; + int write_fd_; +}; + +std::string GetTestParamName( + ::testing::TestParamInfo info) { + return EscapeTestParamName(info.param->GetName()); +} + +INSTANTIATE_TEST_SUITE_P(QuicEventLoopFactoryTests, QuicEventLoopFactoryTest, + ::testing::ValuesIn(GetAllSupportedEventLoops()), + GetTestParamName); + +TEST_P(QuicEventLoopFactoryTest, NothingHappens) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(read_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + // Attempt double-registration. + EXPECT_FALSE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(4)); + // Expect no further calls. + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(5)); +} + +TEST_P(QuicEventLoopFactoryTest, RearmWriter) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + if (loop_->SupportsEdgeTriggered()) { + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)) + .Times(1); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + } else { + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)) + .Times(2); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + ASSERT_TRUE(loop_->RearmSocket(write_fd_, kSocketEventWritable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + } +} + +TEST_P(QuicEventLoopFactoryTest, Readable) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(read_fd_, kAllEvents, &listener)); + + ASSERT_EQ(4, write(write_fd_, "test", 4)); + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + // Expect no further calls. + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +// A common pattern: read a limited amount of data from an FD, and expect to +// read the remainder on the next operation. +TEST_P(QuicEventLoopFactoryTest, ArtificialNotifyFromCallback) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(read_fd_, kSocketEventReadable, &listener)); + + constexpr absl::string_view kData = "test test test test test test test "; + constexpr size_t kTimes = kData.size() / 5; + ASSERT_EQ(kData.size(), write(write_fd_, kData.data(), kData.size())); + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)) + .Times(loop_->SupportsEdgeTriggered() ? (kTimes + 1) : kTimes) + .WillRepeatedly([&]() { + char buf[5]; + int read_result = read(read_fd_, buf, sizeof(buf)); + if (read_result > 0) { + ASSERT_EQ(read_result, 5); + if (loop_->SupportsEdgeTriggered()) { + EXPECT_TRUE( + loop_->ArtificiallyNotifyEvent(read_fd_, kSocketEventReadable)); + } else { + EXPECT_TRUE(loop_->RearmSocket(read_fd_, kSocketEventReadable)); + } + } else { + EXPECT_EQ(errno, EAGAIN); + } + }); + for (size_t i = 0; i < kTimes + 2; i++) { + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + } +} + +TEST_P(QuicEventLoopFactoryTest, WriterUnblocked) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + + int io_result; + std::string data(2048, 'a'); + do { + io_result = write(write_fd_, data.data(), data.size()); + } while (io_result > 0); + ASSERT_EQ(errno, EAGAIN); + + // Rearm if necessary and expect no immediate calls. + if (!loop_->SupportsEdgeTriggered()) { + ASSERT_TRUE(loop_->RearmSocket(write_fd_, kSocketEventWritable)); + } + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + do { + io_result = read(read_fd_, data.data(), data.size()); + } while (io_result > 0); + ASSERT_EQ(errno, EAGAIN); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_P(QuicEventLoopFactoryTest, ArtificialEvent) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(read_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + ASSERT_TRUE(loop_->ArtificiallyNotifyEvent(read_fd_, kSocketEventReadable)); + + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)); + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_P(QuicEventLoopFactoryTest, Unregister) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_->UnregisterSocket(write_fd_)); + + // Expect nothing to happen. + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + + EXPECT_FALSE(loop_->UnregisterSocket(write_fd_)); + if (!loop_->SupportsEdgeTriggered()) { + EXPECT_FALSE(loop_->RearmSocket(write_fd_, kSocketEventWritable)); + } + EXPECT_FALSE(loop_->ArtificiallyNotifyEvent(write_fd_, kSocketEventWritable)); +} + +TEST_P(QuicEventLoopFactoryTest, UnregisterInsideEventHandler) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(read_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + // We are not guaranteed the order in which those events will happen, so we + // try to accommodate both possibilities. + int total_called = 0; + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)) + .Times(AtMost(1)) + .WillOnce([&]() { + ++total_called; + ASSERT_TRUE(loop_->UnregisterSocket(write_fd_)); + }); + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)) + .Times(AtMost(1)) + .WillOnce([&]() { + ++total_called; + ASSERT_TRUE(loop_->UnregisterSocket(read_fd_)); + }); + ASSERT_TRUE(loop_->ArtificiallyNotifyEvent(read_fd_, kSocketEventReadable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(total_called, 1); +} + +TEST_P(QuicEventLoopFactoryTest, UnregisterSelfInsideEventHandler) { + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)) + .WillOnce([&]() { ASSERT_TRUE(loop_->UnregisterSocket(write_fd_)); }); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +// Creates a bidirectional socket and tests its behavior when it's both readable +// and writable. +TEST_P(QuicEventLoopFactoryTest, ReadWriteSocket) { + int sockets[2]; + ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), 0); + auto close_sockets = absl::MakeCleanup([&]() { + close(sockets[0]); + close(sockets[1]); + }); + SetNonBlocking(sockets[0]); + SetNonBlocking(sockets[1]); + + testing::StrictMock listener; + ASSERT_TRUE(loop_->RegisterSocket(sockets[0], kAllEvents, &listener)); + EXPECT_CALL(listener, OnSocketEvent(_, sockets[0], kSocketEventWritable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(4)); + + int io_result; + std::string data(2048, 'a'); + do { + io_result = write(sockets[0], data.data(), data.size()); + } while (io_result > 0); + ASSERT_EQ(errno, EAGAIN); + + if (!loop_->SupportsEdgeTriggered()) { + ASSERT_TRUE(loop_->RearmSocket(sockets[0], kSocketEventWritable)); + } + // We are not write-blocked, so this should not notify. + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(4)); + + EXPECT_GT(write(sockets[1], data.data(), data.size()), 0); + EXPECT_CALL(listener, OnSocketEvent(_, sockets[0], kSocketEventReadable)); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(4)); + + do { + char buffer[2048]; + io_result = read(sockets[1], buffer, sizeof(buffer)); + } while (io_result > 0); + ASSERT_EQ(errno, EAGAIN); + // Here, we can receive either "writable" or "readable and writable" + // notification depending on the backend in question. + EXPECT_CALL(listener, + OnSocketEvent(_, sockets[0], HasFlagSet(kSocketEventWritable))); + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(4)); +} + +TEST_P(QuicEventLoopFactoryTest, AlarmInFuture) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm, delegate] = CreateAlarm(); + + alarm->Set(clock_.Now() + kAlarmTimeout); + + bool alarm_called = false; + EXPECT_CALL(*delegate, OnAlarm()).WillOnce([&]() { alarm_called = true; }); + RunEventLoopUntil([&]() { return alarm_called; }, + QuicTime::Delta::FromMilliseconds(100)); +} + +TEST_P(QuicEventLoopFactoryTest, AlarmsInPast) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm1, delegate1] = CreateAlarm(); + auto [alarm2, delegate2] = CreateAlarm(); + + alarm1->Set(clock_.Now() - 2 * kAlarmTimeout); + alarm2->Set(clock_.Now() - kAlarmTimeout); + + { + testing::InSequence s; + EXPECT_CALL(*delegate1, OnAlarm()); + EXPECT_CALL(*delegate2, OnAlarm()); + } + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(100)); +} + +TEST_P(QuicEventLoopFactoryTest, AlarmCancelled) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm, delegate] = CreateAlarm(); + + alarm->Set(clock_.Now() + kAlarmTimeout); + alarm->Cancel(); + + loop_->RunEventLoopOnce(kAlarmTimeout * 2); +} + +TEST_P(QuicEventLoopFactoryTest, AlarmCancelledAndSetAgain) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm, delegate] = CreateAlarm(); + + alarm->Set(clock_.Now() + kAlarmTimeout); + alarm->Cancel(); + alarm->Set(clock_.Now() + 2 * kAlarmTimeout); + + bool alarm_called = false; + EXPECT_CALL(*delegate, OnAlarm()).WillOnce([&]() { alarm_called = true; }); + RunEventLoopUntil([&]() { return alarm_called; }, + QuicTime::Delta::FromMilliseconds(100)); +} + +TEST_P(QuicEventLoopFactoryTest, AlarmCancelsAnotherAlarm) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm1_ptr, delegate1] = CreateAlarm(); + auto [alarm2_ptr, delegate2] = CreateAlarm(); + + QuicAlarm& alarm1 = *alarm1_ptr; + QuicAlarm& alarm2 = *alarm2_ptr; + alarm1.Set(clock_.Now() - kAlarmTimeout); + alarm2.Set(clock_.Now() - kAlarmTimeout); + + int alarms_called = 0; + // Since the order in which alarms are cancelled is not well-determined, make + // each one cancel another. + EXPECT_CALL(*delegate1, OnAlarm()).Times(AtMost(1)).WillOnce([&]() { + alarm2.Cancel(); + ++alarms_called; + }); + EXPECT_CALL(*delegate2, OnAlarm()).Times(AtMost(1)).WillOnce([&]() { + alarm1.Cancel(); + ++alarms_called; + }); + // Run event loop twice to ensure the second alarm is not called after two + // iterations. + loop_->RunEventLoopOnce(kAlarmTimeout * 2); + loop_->RunEventLoopOnce(kAlarmTimeout * 2); + EXPECT_EQ(alarms_called, 1); +} + +TEST_P(QuicEventLoopFactoryTest, DestructorWithPendingAlarm) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm1_ptr, delegate1] = CreateAlarm(); + + alarm1_ptr->Set(clock_.Now() + kAlarmTimeout); + // Expect destructor to cleanly unregister itself before the event loop is + // gone. +} + +TEST_P(QuicEventLoopFactoryTest, NegativeTimeout) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromSeconds(300); + auto [alarm1_ptr, delegate1] = CreateAlarm(); + + alarm1_ptr->Set(clock_.Now() + kAlarmTimeout); + + loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(-1)); +} + +TEST_P(QuicEventLoopFactoryTest, ScheduleAlarmInPastFromInsideAlarm) { + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(20); + auto [alarm1_ptr, delegate1] = CreateAlarm(); + auto [alarm2_ptr, delegate2] = CreateAlarm(); + + alarm1_ptr->Set(clock_.Now() - kAlarmTimeout); + EXPECT_CALL(*delegate1, OnAlarm()) + .WillOnce([&, alarm2_unowned = alarm2_ptr.get()]() { + alarm2_unowned->Set(clock_.Now() - 2 * kAlarmTimeout); + }); + bool fired = false; + EXPECT_CALL(*delegate2, OnAlarm()).WillOnce([&]() { fired = true; }); + + RunEventLoopUntil([&]() { return fired; }, + QuicTime::Delta::FromMilliseconds(100)); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/io/quic_default_event_loop.cc b/quiche/quic/core/io/quic_default_event_loop.cc new file mode 100644 index 000000000000..8c1877c4fd23 --- /dev/null +++ b/quiche/quic/core/io/quic_default_event_loop.cc @@ -0,0 +1,43 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/quic_default_event_loop.h" + +#include + +#include "quiche/quic/core/io/quic_poll_event_loop.h" +#include "quiche/common/platform/api/quiche_event_loop.h" + +#ifdef QUICHE_ENABLE_LIBEVENT +#include "quiche/quic/bindings/quic_libevent.h" +#endif + +namespace quic { + +QuicEventLoopFactory* GetDefaultEventLoop() { + if (QuicEventLoopFactory* factory = + quiche::GetOverrideForDefaultEventLoop()) { + return factory; + } +#ifdef QUICHE_ENABLE_LIBEVENT + return QuicLibeventEventLoopFactory::Get(); +#else + return QuicPollEventLoopFactory::Get(); +#endif +} + +std::vector GetAllSupportedEventLoops() { + std::vector loops = { +#ifdef QUICHE_ENABLE_LIBEVENT + QuicLibeventEventLoopFactory::Get(), + QuicLibeventEventLoopFactory::GetLevelTriggeredBackendForTests(), +#endif + QuicPollEventLoopFactory::Get()}; + std::vector extra = + quiche::GetExtraEventLoopImplementations(); + loops.insert(loops.end(), extra.begin(), extra.end()); + return loops; +} + +} // namespace quic diff --git a/quiche/quic/core/io/quic_default_event_loop.h b/quiche/quic/core/io/quic_default_event_loop.h new file mode 100644 index 000000000000..6073a6e9a38e --- /dev/null +++ b/quiche/quic/core/io/quic_default_event_loop.h @@ -0,0 +1,26 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_IO_QUIC_DEFAULT_EVENT_LOOP_H_ +#define QUICHE_QUIC_CORE_IO_QUIC_DEFAULT_EVENT_LOOP_H_ + +#include + +#include "quiche/quic/core/io/quic_event_loop.h" + +namespace quic { + +// Returns the default implementation of QuicheEventLoop. The embedders can +// override this using the platform API. The factory pointer returned is an +// unowned static variable. +QUICHE_NO_EXPORT QuicEventLoopFactory* GetDefaultEventLoop(); + +// Returns the factory objects for all event loops. This is particularly useful +// for the unit tests. The factory pointers returned are unowned static +// variables. +QUICHE_NO_EXPORT std::vector GetAllSupportedEventLoops(); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_IO_QUIC_DEFAULT_EVENT_LOOP_H_ diff --git a/quiche/quic/core/io/quic_event_loop.h b/quiche/quic/core/io/quic_event_loop.h new file mode 100644 index 000000000000..e02a0f0ca951 --- /dev/null +++ b/quiche/quic/core/io/quic_event_loop.h @@ -0,0 +1,101 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_IO_QUIC_EVENT_LOOP_H_ +#define QUICHE_QUIC_IO_QUIC_EVENT_LOOP_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_udp_socket.h" + +namespace quic { + +// A bitmask indicating a set of I/O events. +using QuicSocketEventMask = uint8_t; +inline constexpr QuicSocketEventMask kSocketEventReadable = 0x01; +inline constexpr QuicSocketEventMask kSocketEventWritable = 0x02; +inline constexpr QuicSocketEventMask kSocketEventError = 0x04; + +class QuicEventLoop; + +// A listener associated with a file descriptor. +class QUICHE_NO_EXPORT QuicSocketEventListener { + public: + virtual ~QuicSocketEventListener() = default; + + virtual void OnSocketEvent(QuicEventLoop* event_loop, QuicUdpSocketFd fd, + QuicSocketEventMask events) = 0; +}; + +// An abstraction for an event loop that can handle alarms and notify the +// listener about I/O events occuring to the registered UDP sockets. +// +// Note on error handling: while most of the methods below return a boolean to +// indicate whether the operation has succeeded or not, some will QUIC_BUG +// instead. +class QUICHE_NO_EXPORT QuicEventLoop { + public: + virtual ~QuicEventLoop() = default; + + // Indicates whether the event loop implementation supports edge-triggered + // notifications. If true, all of the events are permanent and are notified + // as long as they are registered. If false, whenever an event is triggered, + // the event registration is unset and has to be re-armed using RearmSocket(). + virtual bool SupportsEdgeTriggered() const = 0; + + // Watches for all of the requested |events| that occur on the |fd| and + // notifies the |listener| about them. |fd| must not be already registered; + // if it is, the function returns false. The |listener| must be alive for as + // long as it is registered. + virtual ABSL_MUST_USE_RESULT bool RegisterSocket( + QuicUdpSocketFd fd, QuicSocketEventMask events, + QuicSocketEventListener* listener) = 0; + // Removes the listener associated with |fd|. Returns false if the listener + // is not found. + virtual ABSL_MUST_USE_RESULT bool UnregisterSocket(QuicUdpSocketFd fd) = 0; + // Adds |events| to the list of the listened events for |fd|, given that |fd| + // is already registered. Must be only called if SupportsEdgeTriggered() is + // false. + virtual ABSL_MUST_USE_RESULT bool RearmSocket(QuicUdpSocketFd fd, + QuicSocketEventMask events) = 0; + // Causes the |fd| to be notified of |events| on the next event loop iteration + // even if none of the specified events has happened. + virtual ABSL_MUST_USE_RESULT bool ArtificiallyNotifyEvent( + QuicUdpSocketFd fd, QuicSocketEventMask events) = 0; + + // Runs a single iteration of the event loop. The iteration will run for at + // most |default_timeout|. + virtual void RunEventLoopOnce(QuicTime::Delta default_timeout) = 0; + + // Returns an alarm factory that allows alarms to be scheduled on this event + // loop. + virtual std::unique_ptr CreateAlarmFactory() = 0; + + // Returns the clock that is used by the alarm factory that the event loop + // provides. + virtual const QuicClock* GetClock() = 0; +}; + +// A factory object for the event loop. Every implementation is expected to have +// a static singleton instance. +class QUICHE_NO_EXPORT QuicEventLoopFactory { + public: + virtual ~QuicEventLoopFactory() {} + + // Creates an event loop. Note that |clock| may be ignored if the event loop + // implementation uses its own clock internally. + virtual std::unique_ptr Create(QuicClock* clock) = 0; + + // A human-readable name of the event loop implementation used in diagnostics + // output. + virtual std::string GetName() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_IO_QUIC_EVENT_LOOP_H_ diff --git a/quiche/quic/core/io/quic_poll_event_loop.cc b/quiche/quic/core/io/quic_poll_event_loop.cc new file mode 100644 index 000000000000..f56635f3fcf1 --- /dev/null +++ b/quiche/quic/core/io/quic_poll_event_loop.cc @@ -0,0 +1,263 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/quic_poll_event_loop.h" + +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +namespace { + +using PollMask = decltype(::pollfd().events); + +PollMask GetPollMask(QuicSocketEventMask event_mask) { + return ((event_mask & kSocketEventReadable) ? POLLIN : 0) | + ((event_mask & kSocketEventWritable) ? POLLOUT : 0) | + ((event_mask & kSocketEventError) ? POLLERR : 0); +} + +QuicSocketEventMask GetEventMask(PollMask poll_mask) { + return ((poll_mask & POLLIN) ? kSocketEventReadable : 0) | + ((poll_mask & POLLOUT) ? kSocketEventWritable : 0) | + ((poll_mask & POLLERR) ? kSocketEventError : 0); +} + +} // namespace + +QuicPollEventLoop::QuicPollEventLoop(QuicClock* clock) : clock_(clock) {} + +bool QuicPollEventLoop::RegisterSocket(QuicUdpSocketFd fd, + QuicSocketEventMask events, + QuicSocketEventListener* listener) { + auto [it, success] = + registrations_.insert({fd, std::make_shared()}); + if (!success) { + return false; + } + Registration& registration = *it->second; + registration.events = events; + registration.listener = listener; + return true; +} + +bool QuicPollEventLoop::UnregisterSocket(QuicUdpSocketFd fd) { + return registrations_.erase(fd); +} + +bool QuicPollEventLoop::RearmSocket(QuicUdpSocketFd fd, + QuicSocketEventMask events) { + auto it = registrations_.find(fd); + if (it == registrations_.end()) { + return false; + } + it->second->events |= events; + return true; +} + +bool QuicPollEventLoop::ArtificiallyNotifyEvent(QuicUdpSocketFd fd, + QuicSocketEventMask events) { + auto it = registrations_.find(fd); + if (it == registrations_.end()) { + return false; + } + has_artificial_events_pending_ = true; + it->second->artificially_notify_at_next_iteration |= events; + return true; +} + +void QuicPollEventLoop::RunEventLoopOnce(QuicTime::Delta default_timeout) { + const QuicTime start_time = clock_->Now(); + ProcessAlarmsUpTo(start_time); + + QuicTime::Delta timeout = ComputePollTimeout(start_time, default_timeout); + ProcessIoEvents(start_time, timeout); + + const QuicTime end_time = clock_->Now(); + ProcessAlarmsUpTo(end_time); +} + +QuicTime::Delta QuicPollEventLoop::ComputePollTimeout( + QuicTime now, QuicTime::Delta default_timeout) const { + default_timeout = std::max(default_timeout, QuicTime::Delta::Zero()); + if (has_artificial_events_pending_) { + return QuicTime::Delta::Zero(); + } + if (alarms_.empty()) { + return default_timeout; + } + QuicTime end_time = std::min(now + default_timeout, alarms_.begin()->first); + if (end_time < now) { + // We only run a single pass of processing alarm callbacks per + // RunEventLoopOnce() call. If an alarm schedules another alarm in the past + // while in the callback, this will happen. + return QuicTime::Delta::Zero(); + } + return end_time - now; +} + +int QuicPollEventLoop::PollWithRetries(absl::Span fds, + QuicTime start_time, + QuicTime::Delta timeout) { + const QuicTime timeout_at = start_time + timeout; + int poll_result; + for (;;) { + float timeout_ms = std::ceil(timeout.ToMicroseconds() / 1000.f); + poll_result = + PollSyscall(fds.data(), fds.size(), static_cast(timeout_ms)); + + // Retry if EINTR happens. + bool is_eintr = poll_result < 0 && errno == EINTR; + if (!is_eintr) { + break; + } + QuicTime now = clock_->Now(); + if (now >= timeout_at) { + break; + } + timeout = timeout_at - now; + } + return poll_result; +} + +void QuicPollEventLoop::ProcessIoEvents(QuicTime start_time, + QuicTime::Delta timeout) { + // Set up the pollfd[] array. + const size_t registration_count = registrations_.size(); + auto pollfds = std::make_unique(registration_count); + size_t i = 0; + for (auto& [fd, registration] : registrations_) { + QUICHE_CHECK_LT( + i, registration_count); // Crash instead of out-of-bounds access. + pollfds[i].fd = fd; + pollfds[i].events = GetPollMask(registration->events); + pollfds[i].revents = 0; + ++i; + } + + // Actually run poll(2). + int poll_result = + PollWithRetries(absl::Span(pollfds.get(), registration_count), + start_time, timeout); + if (poll_result == 0 && !has_artificial_events_pending_) { + return; + } + + // Prepare the list of all callbacks to be called, while resetting all events, + // since we're operating in the level-triggered mode. + std::vector ready_list; + ready_list.reserve(registration_count); + for (i = 0; i < registration_count; i++) { + DispatchIoEvent(ready_list, pollfds[i].fd, pollfds[i].revents); + } + has_artificial_events_pending_ = false; + + // Actually call all of the callbacks. + RunReadyCallbacks(ready_list); +} + +void QuicPollEventLoop::DispatchIoEvent(std::vector& ready_list, + QuicUdpSocketFd fd, PollMask mask) { + auto it = registrations_.find(fd); + if (it == registrations_.end()) { + QUIC_BUG(poll returned an unregistered fd) << fd; + return; + } + Registration& registration = *it->second; + + mask |= GetPollMask(registration.artificially_notify_at_next_iteration); + registration.artificially_notify_at_next_iteration = QuicSocketEventMask(); + + // poll() always returns certain classes of events even if not requested. + mask &= GetPollMask(registration.events); + if (!mask) { + return; + } + + ready_list.push_back(ReadyListEntry{fd, it->second, GetEventMask(mask)}); + registration.events &= ~GetEventMask(mask); +} + +void QuicPollEventLoop::RunReadyCallbacks( + std::vector& ready_list) { + for (ReadyListEntry& entry : ready_list) { + std::shared_ptr registration = entry.registration.lock(); + if (!registration) { + // The socket has been unregistered from within one of the callbacks. + continue; + } + registration->listener->OnSocketEvent(this, entry.fd, entry.events); + } + ready_list.clear(); +} + +void QuicPollEventLoop::ProcessAlarmsUpTo(QuicTime time) { + // Determine which alarm callbacks needs to be run. + std::vector> alarms_to_call; + while (!alarms_.empty() && alarms_.begin()->first <= time) { + auto& [deadline, schedule_handle_weak] = *alarms_.begin(); + alarms_to_call.push_back(std::move(schedule_handle_weak)); + alarms_.erase(alarms_.begin()); + } + // Actually run those callbacks. + for (std::weak_ptr& schedule_handle_weak : alarms_to_call) { + std::shared_ptr schedule_handle = schedule_handle_weak.lock(); + if (!schedule_handle) { + // The alarm has been cancelled and might not even exist anymore. + continue; + } + (*schedule_handle)->DoFire(); + } + // Clean up all of the alarms in the front that have been cancelled. + while (!alarms_.empty()) { + if (alarms_.begin()->second.expired()) { + alarms_.erase(alarms_.begin()); + } else { + break; + } + } +} + +QuicAlarm* QuicPollEventLoop::AlarmFactory::CreateAlarm( + QuicAlarm::Delegate* delegate) { + return new Alarm(loop_, QuicArenaScopedPtr(delegate)); +} + +QuicArenaScopedPtr QuicPollEventLoop::AlarmFactory::CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) { + if (arena != nullptr) { + return arena->New(loop_, std::move(delegate)); + } + return QuicArenaScopedPtr(new Alarm(loop_, std::move(delegate))); +} + +QuicPollEventLoop::Alarm::Alarm( + QuicPollEventLoop* loop, QuicArenaScopedPtr delegate) + : QuicAlarm(std::move(delegate)), loop_(loop) {} + +void QuicPollEventLoop::Alarm::SetImpl() { + current_schedule_handle_ = std::make_shared(this); + loop_->alarms_.insert({deadline(), current_schedule_handle_}); +} + +void QuicPollEventLoop::Alarm::CancelImpl() { + current_schedule_handle_.reset(); +} + +std::unique_ptr QuicPollEventLoop::CreateAlarmFactory() { + return std::make_unique(this); +} + +} // namespace quic diff --git a/quiche/quic/core/io/quic_poll_event_loop.h b/quiche/quic/core/io/quic_poll_event_loop.h new file mode 100644 index 000000000000..6b2ec4911074 --- /dev/null +++ b/quiche/quic/core/io/quic_poll_event_loop.h @@ -0,0 +1,166 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_IO_QUIC_POLL_EVENT_LOOP_H_ +#define QUICHE_QUIC_CORE_IO_QUIC_POLL_EVENT_LOOP_H_ + +#include + +#include + +#include "absl/container/btree_map.h" +#include "absl/types/span.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +// A simple and portable implementation of QuicEventLoop using poll(2). Works +// on all POSIX platforms (and can be potentially made to support Windows using +// WSAPoll). +// +// For most operations, this implementation has a typical runtime of +// O(N + log M), where N is the number of file descriptors, and M is the number +// of pending alarms. +// +// This API has to deal with the situations where callbacks are modified from +// the callbacks themselves. To address this, we use the following two +// approaches: +// 1. The code does not execute any callbacks until the very end of the +// processing, when all of the state for the event loop is consistent. +// 2. The callbacks are stored as weak pointers, since other callbacks can +// cause them to be unregistered. +class QUICHE_NO_EXPORT QuicPollEventLoop : public QuicEventLoop { + public: + QuicPollEventLoop(QuicClock* clock); + + // QuicEventLoop implementation. + bool SupportsEdgeTriggered() const override { return false; } + ABSL_MUST_USE_RESULT bool RegisterSocket( + QuicUdpSocketFd fd, QuicSocketEventMask events, + QuicSocketEventListener* listener) override; + ABSL_MUST_USE_RESULT bool UnregisterSocket(QuicUdpSocketFd fd) override; + ABSL_MUST_USE_RESULT bool RearmSocket(QuicUdpSocketFd fd, + QuicSocketEventMask events) override; + ABSL_MUST_USE_RESULT bool ArtificiallyNotifyEvent( + QuicUdpSocketFd fd, QuicSocketEventMask events) override; + void RunEventLoopOnce(QuicTime::Delta default_timeout) override; + std::unique_ptr CreateAlarmFactory() override; + const QuicClock* GetClock() override { return clock_; } + + protected: + // Allows poll(2) calls to be mocked out in unit tests. + virtual int PollSyscall(pollfd* fds, nfds_t nfds, int timeout) { + return ::poll(fds, nfds, timeout); + } + + private: + friend class QuicPollEventLoopPeer; + + struct Registration { + QuicSocketEventMask events = 0; + QuicSocketEventListener* listener; + + QuicSocketEventMask artificially_notify_at_next_iteration = 0; + }; + + class Alarm : public QuicAlarm { + public: + Alarm(QuicPollEventLoop* loop, + QuicArenaScopedPtr delegate); + + void SetImpl() override; + void CancelImpl() override; + + void DoFire() { + current_schedule_handle_.reset(); + Fire(); + } + + private: + QuicPollEventLoop* loop_; + // Deleted when the alarm is cancelled, causing the corresponding weak_ptr + // in the alarm list to not be executed. + std::shared_ptr current_schedule_handle_; + }; + + class AlarmFactory : public QuicAlarmFactory { + public: + AlarmFactory(QuicPollEventLoop* loop) : loop_(loop) {} + + // QuicAlarmFactory implementation. + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override; + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override; + + private: + QuicPollEventLoop* loop_; + }; + + // Used for deferred execution of I/O callbacks. + struct ReadyListEntry { + QuicUdpSocketFd fd; + std::weak_ptr registration; + QuicSocketEventMask events; + }; + + // We're using a linked hash map here to ensure the events are called in the + // registration order. This isn't strictly speaking necessary, but makes + // testing things easier. + using RegistrationMap = + quiche::QuicheLinkedHashMap>; + // Alarms are stored as weak pointers, since the alarm can be cancelled and + // disappear while in the queue. + using AlarmList = absl::btree_multimap>; + + // Returns the timeout for the next poll(2) call. It is typically the time at + // which the next alarm is supposed to activate. + QuicTime::Delta ComputePollTimeout(QuicTime now, + QuicTime::Delta default_timeout) const; + // Calls poll(2) with the provided timeout and dispatches the callbacks + // accordingly. + void ProcessIoEvents(QuicTime start_time, QuicTime::Delta timeout); + // Calls all of the alarm callbacks that are scheduled before or at |time|. + void ProcessAlarmsUpTo(QuicTime time); + + // Adds the I/O callbacks for |fd| to the |ready_lits| as appopriate. + void DispatchIoEvent(std::vector& ready_list, + QuicUdpSocketFd fd, short mask); // NOLINT(runtime/int) + // Runs all of the callbacks on the ready list. + void RunReadyCallbacks(std::vector& ready_list); + + // Calls poll() while handling EINTR. Returns the return value of poll(2) + // system call. + int PollWithRetries(absl::Span fds, QuicTime start_time, + QuicTime::Delta timeout); + + const QuicClock* clock_; + RegistrationMap registrations_; + AlarmList alarms_; + bool has_artificial_events_pending_ = false; +}; + +class QUICHE_NO_EXPORT QuicPollEventLoopFactory : public QuicEventLoopFactory { + public: + static QuicPollEventLoopFactory* Get() { + static QuicPollEventLoopFactory* factory = new QuicPollEventLoopFactory(); + return factory; + } + + std::unique_ptr Create(QuicClock* clock) override { + return std::make_unique(clock); + } + + std::string GetName() const override { return "poll(2)"; } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_IO_QUIC_POLL_EVENT_LOOP_H_ diff --git a/quiche/quic/core/io/quic_poll_event_loop_test.cc b/quiche/quic/core/io/quic_poll_event_loop_test.cc new file mode 100644 index 000000000000..f0e95577e1e1 --- /dev/null +++ b/quiche/quic/core/io/quic_poll_event_loop_test.cc @@ -0,0 +1,342 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/quic_poll_event_loop.h" + +#include +#include + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { + +class QuicPollEventLoopPeer { + public: + static QuicTime::Delta ComputePollTimeout(const QuicPollEventLoop& loop, + QuicTime now, + QuicTime::Delta default_timeout) { + return loop.ComputePollTimeout(now, default_timeout); + } +}; + +} // namespace quic + +namespace quic::test { +namespace { + +using testing::_; +using testing::AtMost; +using testing::ElementsAre; + +constexpr QuicSocketEventMask kAllEvents = + kSocketEventReadable | kSocketEventWritable | kSocketEventError; +constexpr QuicTime::Delta kDefaultTimeout = QuicTime::Delta::FromSeconds(100); + +class MockQuicSocketEventListener : public QuicSocketEventListener { + public: + MOCK_METHOD(void, OnSocketEvent, + (QuicEventLoop* /*event_loop*/, QuicUdpSocketFd /*fd*/, + QuicSocketEventMask /*events*/), + (override)); +}; + +class MockDelegate : public QuicAlarm::Delegate { + public: + QuicConnectionContext* GetConnectionContext() override { return nullptr; } + MOCK_METHOD(void, OnAlarm, (), (override)); +}; + +class QuicPollEventLoopForTest : public QuicPollEventLoop { + public: + QuicPollEventLoopForTest(MockClock* clock) + : QuicPollEventLoop(clock), clock_(clock) {} + + int PollSyscall(pollfd* fds, nfds_t nfds, int timeout) override { + timeouts_.push_back(timeout); + if (eintr_after_ != QuicTime::Delta::Infinite()) { + errno = EINTR; + clock_->AdvanceTime(eintr_after_); + eintr_after_ = QuicTime::Delta::Infinite(); + return -1; + } + clock_->AdvanceTime(QuicTime::Delta::FromMilliseconds(timeout)); + return QuicPollEventLoop::PollSyscall(fds, nfds, timeout); + } + + void TriggerEintrAfter(QuicTime::Delta time) { eintr_after_ = time; } + + const std::vector& timeouts() const { return timeouts_; } + + private: + MockClock* clock_; + QuicTime::Delta eintr_after_ = QuicTime::Delta::Infinite(); + std::vector timeouts_; +}; + +class QuicPollEventLoopTest : public QuicTest { + public: + QuicPollEventLoopTest() + : loop_(&clock_), factory_(loop_.CreateAlarmFactory()) { + int fds[2]; + int result = ::pipe(fds); + QUICHE_CHECK(result >= 0) << "Failed to create a pipe, errno: " << errno; + read_fd_ = fds[0]; + write_fd_ = fds[1]; + + QUICHE_CHECK(::fcntl(read_fd_, F_SETFL, + ::fcntl(read_fd_, F_GETFL) | O_NONBLOCK) == 0) + << "Failed to mark pipe FD non-blocking, errno: " << errno; + QUICHE_CHECK(::fcntl(write_fd_, F_SETFL, + ::fcntl(write_fd_, F_GETFL) | O_NONBLOCK) == 0) + << "Failed to mark pipe FD non-blocking, errno: " << errno; + + clock_.AdvanceTime(10 * kDefaultTimeout); + } + + ~QuicPollEventLoopTest() { + close(read_fd_); + close(write_fd_); + } + + QuicTime::Delta ComputePollTimeout() { + return QuicPollEventLoopPeer::ComputePollTimeout(loop_, clock_.Now(), + kDefaultTimeout); + } + + std::pair, MockDelegate*> CreateAlarm() { + auto delegate = std::make_unique>(); + MockDelegate* delegate_unowned = delegate.get(); + auto alarm = absl::WrapUnique(factory_->CreateAlarm(delegate.release())); + return std::make_pair(std::move(alarm), delegate_unowned); + } + + protected: + MockClock clock_; + QuicPollEventLoopForTest loop_; + std::unique_ptr factory_; + int read_fd_; + int write_fd_; +}; + +TEST_F(QuicPollEventLoopTest, NothingHappens) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(read_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + + // Attempt double-registration. + EXPECT_FALSE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(4)); + // Expect no further calls. + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_THAT(loop_.timeouts(), ElementsAre(4, 5)); +} + +TEST_F(QuicPollEventLoopTest, RearmWriter) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)) + .Times(2); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + ASSERT_TRUE(loop_.RearmSocket(write_fd_, kSocketEventWritable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicPollEventLoopTest, Readable) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(read_fd_, kAllEvents, &listener)); + + ASSERT_EQ(4, write(write_fd_, "test", 4)); + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + // Expect no further calls. + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicPollEventLoopTest, RearmReader) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(read_fd_, kAllEvents, &listener)); + + ASSERT_EQ(4, write(write_fd_, "test", 4)); + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + // Expect no further calls. + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicPollEventLoopTest, WriterUnblocked) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + + int io_result; + std::string data(2048, 'a'); + do { + io_result = write(write_fd_, data.data(), data.size()); + } while (io_result > 0); + ASSERT_EQ(errno, EAGAIN); + + // Rearm and expect no immediate calls. + ASSERT_TRUE(loop_.RearmSocket(write_fd_, kSocketEventWritable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + do { + io_result = read(read_fd_, data.data(), data.size()); + } while (io_result > 0); + ASSERT_EQ(errno, EAGAIN); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicPollEventLoopTest, ArtificialEvent) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(read_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + ASSERT_TRUE(loop_.ArtificiallyNotifyEvent(read_fd_, kSocketEventReadable)); + EXPECT_EQ(ComputePollTimeout(), QuicTime::Delta::Zero()); + + { + testing::InSequence s; + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)); + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)); + } + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); +} + +TEST_F(QuicPollEventLoopTest, Unregister) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_.UnregisterSocket(write_fd_)); + + // Expect nothing to happen. + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); + + EXPECT_FALSE(loop_.UnregisterSocket(write_fd_)); + EXPECT_FALSE(loop_.RearmSocket(write_fd_, kSocketEventWritable)); + EXPECT_FALSE(loop_.ArtificiallyNotifyEvent(write_fd_, kSocketEventWritable)); +} + +TEST_F(QuicPollEventLoopTest, UnregisterInsideEventHandler) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(read_fd_, kAllEvents, &listener)); + ASSERT_TRUE(loop_.RegisterSocket(write_fd_, kAllEvents, &listener)); + + EXPECT_CALL(listener, OnSocketEvent(_, read_fd_, kSocketEventReadable)) + .WillOnce([this]() { ASSERT_TRUE(loop_.UnregisterSocket(write_fd_)); }); + EXPECT_CALL(listener, OnSocketEvent(_, write_fd_, kSocketEventWritable)) + .Times(0); + ASSERT_TRUE(loop_.ArtificiallyNotifyEvent(read_fd_, kSocketEventReadable)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicPollEventLoopTest, EintrHandler) { + testing::StrictMock listener; + ASSERT_TRUE(loop_.RegisterSocket(read_fd_, kAllEvents, &listener)); + + loop_.TriggerEintrAfter(QuicTime::Delta::FromMilliseconds(25)); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(100)); + EXPECT_THAT(loop_.timeouts(), ElementsAre(100, 75)); +} + +TEST_F(QuicPollEventLoopTest, AlarmInFuture) { + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm, delegate] = CreateAlarm(); + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + alarm->Set(clock_.Now() + kAlarmTimeout); + EXPECT_EQ(ComputePollTimeout(), kAlarmTimeout); + + EXPECT_CALL(*delegate, OnAlarm()); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(100)); + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); +} + +TEST_F(QuicPollEventLoopTest, AlarmsInPast) { + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm1, delegate1] = CreateAlarm(); + auto [alarm2, delegate2] = CreateAlarm(); + + alarm1->Set(clock_.Now() - 2 * kAlarmTimeout); + alarm2->Set(clock_.Now() - kAlarmTimeout); + + { + testing::InSequence s; + EXPECT_CALL(*delegate1, OnAlarm()); + EXPECT_CALL(*delegate2, OnAlarm()); + } + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(100)); +} + +TEST_F(QuicPollEventLoopTest, AlarmCancelled) { + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm, delegate] = CreateAlarm(); + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + alarm->Set(clock_.Now() + kAlarmTimeout); + alarm->Cancel(); + alarm->Set(clock_.Now() + 2 * kAlarmTimeout); + EXPECT_EQ(ComputePollTimeout(), kAlarmTimeout); + + EXPECT_CALL(*delegate, OnAlarm()); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(100)); + EXPECT_THAT(loop_.timeouts(), ElementsAre(10)); + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); +} + +TEST_F(QuicPollEventLoopTest, AlarmCancelsAnotherAlarm) { + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); + + constexpr auto kAlarmTimeout = QuicTime::Delta::FromMilliseconds(5); + auto [alarm1_ptr, delegate1] = CreateAlarm(); + auto [alarm2_ptr, delegate2] = CreateAlarm(); + + QuicAlarm& alarm1 = *alarm1_ptr; + QuicAlarm& alarm2 = *alarm2_ptr; + alarm1.Set(clock_.Now() - kAlarmTimeout); + alarm2.Set(clock_.Now() - kAlarmTimeout); + + int alarms_called = 0; + // Since the order in which alarms are cancelled is not well-determined, make + // each one cancel another. + EXPECT_CALL(*delegate1, OnAlarm()).Times(AtMost(1)).WillOnce([&]() { + alarm2.Cancel(); + ++alarms_called; + }); + EXPECT_CALL(*delegate2, OnAlarm()).Times(AtMost(1)).WillOnce([&]() { + alarm1.Cancel(); + ++alarms_called; + }); + loop_.RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(100)); + EXPECT_EQ(alarms_called, 1); + EXPECT_EQ(ComputePollTimeout(), kDefaultTimeout); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/io/socket.h b/quiche/quic/core/io/socket.h new file mode 100644 index 000000000000..7298f7e827df --- /dev/null +++ b/quiche/quic/core/io/socket.h @@ -0,0 +1,131 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_IO_SOCKET_H_ +#define QUICHE_QUIC_CORE_IO_SOCKET_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_export.h" + +#if defined(_WIN32) +#include +#endif // defined(_WIN32) + +namespace quic { + +#if defined(_WIN32) +using SocketFd = SOCKET; +inline constexpr SocketFd kInvalidSocketFd = INVALID_SOCKET; +#else +using SocketFd = int; +inline constexpr SocketFd kInvalidSocketFd = -1; +#endif + +// Low-level platform-agnostic socket operations. Closely follows the behavior +// of basic POSIX socket APIs, diverging mostly only to convert to/from cleaner +// and platform-agnostic types. +namespace socket_api { +enum class SocketProtocol { + kUdp, + kTcp, +}; + +inline absl::string_view GetProtocolName(SocketProtocol protocol) { + switch (protocol) { + case SocketProtocol::kUdp: + return "UDP"; + case SocketProtocol::kTcp: + return "TCP"; + } + + return "unknown"; +} + +struct QUICHE_EXPORT AcceptResult { + // Socket for interacting with the accepted connection. + SocketFd fd; + + // Address of the connected peer. + QuicSocketAddress peer_address; +}; + +// Creates a socket with blocking or non-blocking behavior. +absl::StatusOr CreateSocket(IpAddressFamily address_family, + SocketProtocol protocol, + bool blocking = false); + +// Sets socket `fd` to blocking (if `blocking` true) or non-blocking (if +// `blocking` false). Must be a change from previous state. +absl::Status SetSocketBlocking(SocketFd fd, bool blocking); + +// Sets buffer sizes for socket `fd` to `size` bytes. +absl::Status SetReceiveBufferSize(SocketFd fd, QuicByteCount size); +absl::Status SetSendBufferSize(SocketFd fd, QuicByteCount size); + +// Connects socket `fd` to `peer_address`. Returns a status with +// `absl::StatusCode::kUnavailable` iff the socket is non-blocking and the +// connection could not be immediately completed. The socket will then complete +// connecting asynchronously, and on becoming writable, the result can be +// checked using GetSocketError(). +absl::Status Connect(SocketFd fd, const QuicSocketAddress& peer_address); + +// Gets and clears socket error information for socket `fd`. Note that returned +// error could be either the found socket error, or unusually, an error from the +// attempt to retrieve error information. Typically used to determine connection +// result after asynchronous completion of a Connect() call. +absl::Status GetSocketError(SocketFd fd); + +// Assign `address` to socket `fd`. +absl::Status Bind(SocketFd fd, const QuicSocketAddress& address); + +// Gets the address assigned to socket `fd`. +absl::StatusOr GetSocketAddress(SocketFd fd); + +// Marks socket `fd` as a passive socket listening for connection requests. +// `backlog` is the maximum number of queued connection requests. Typically +// expected to return a status with `absl::StatusCode::InvalidArgumentError` +// if `fd` is not a TCP socket. +absl::Status Listen(SocketFd fd, int backlog); + +// Accepts an incoming connection to the listening socket `fd`. The returned +// connection socket will be set as non-blocking iff `blocking` is false. +// Typically expected to return a status with +// `absl::StatusCode::InvalidArgumentError` if `fd` is not a TCP socket or not +// listening for connections. Returns a status with +// `absl::StatusCode::kUnavailable` iff the socket is non-blocking and no +// incoming connection could be immediately accepted. +absl::StatusOr Accept(SocketFd fd, bool blocking = false); + +// Receives data from socket `fd`. Will fill `buffer.data()` with up to +// `buffer.size()` bytes. On success, returns a span pointing to the buffer +// but resized to the actual number of bytes received. Returns a status with +// `absl::StatusCode::kUnavailable` iff the socket is non-blocking and the +// receive operation could not be immediately completed. If `peek` is true, +// received data is not removed from the underlying socket data queue. +absl::StatusOr> Receive(SocketFd fd, absl::Span buffer, + bool peek = false); + +// Sends some or all of the data in `buffer` to socket `fd`. On success, +// returns a string_view pointing to the unsent remainder of the buffer (or an +// empty string_view if all of `buffer` was successfully sent). Returns a status +// with `absl::StatusCode::kUnavailable` iff the socket is non-blocking and the +// send operation could not be immediately completed. +absl::StatusOr Send(SocketFd fd, absl::string_view buffer); + +// Closes socket `fd`. +absl::Status Close(SocketFd fd); +} // namespace socket_api + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_IO_SOCKET_H_ diff --git a/quiche/quic/core/io/socket_posix.cc b/quiche/quic/core/io/socket_posix.cc new file mode 100644 index 000000000000..f72b088c79e9 --- /dev/null +++ b/quiche/quic/core/io/socket_posix.cc @@ -0,0 +1,521 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include +#include + +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/common/platform/api/quiche_logging.h" + +// accept4() is a Linux-specific extension that is available in glibc 2.10+. +#if defined(__linux__) && defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ) +#if __GLIBC_PREREQ(2, 10) +#define HAS_ACCEPT4 +#endif +#endif + +namespace quic::socket_api { + +namespace { + +int ToPlatformSocketType(SocketProtocol protocol) { + switch (protocol) { + case SocketProtocol::kUdp: + return SOCK_DGRAM; + case SocketProtocol::kTcp: + return SOCK_STREAM; + } + + QUICHE_NOTREACHED(); + return -1; +} + +int ToPlatformProtocol(SocketProtocol protocol) { + switch (protocol) { + case SocketProtocol::kUdp: + return IPPROTO_UDP; + case SocketProtocol::kTcp: + return IPPROTO_TCP; + } + + QUICHE_NOTREACHED(); + return -1; +} + +// Wrapper of absl::ErrnoToStatus that ensures the `unavailable_error_numbers` +// and only those numbers result in `absl::StatusCode::kUnavailable`, converting +// any other would-be-unavailable Statuses to `absl::StatusCode::kNotFound`. +absl::Status ToStatus(int error_number, absl::string_view method_name, + absl::flat_hash_set unavailable_error_numbers = { + EAGAIN, EWOULDBLOCK}) { + QUICHE_DCHECK_NE(error_number, 0); + QUICHE_DCHECK_NE(error_number, EINTR); + + absl::Status status = absl::ErrnoToStatus(error_number, method_name); + QUICHE_DCHECK(!status.ok()); + + if (!absl::IsUnavailable(status) && + unavailable_error_numbers.contains(error_number)) { + status = absl::UnavailableError(status.message()); + } else if (absl::IsUnavailable(status) && + !unavailable_error_numbers.contains(error_number)) { + status = absl::NotFoundError(status.message()); + } + + return status; +} + +absl::Status SetSocketFlags(SocketFd fd, int to_add, int to_remove) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK(to_add || to_remove); + QUICHE_DCHECK(!(to_add & to_remove)); + + int flags; + do { + flags = ::fcntl(fd, F_GETFL); + } while (flags < 0 && errno == EINTR); + if (flags < 0) { + absl::Status status = ToStatus(errno, "::fcntl()"); + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Could not get flags for socket " << fd << " with error: " << status; + return status; + } + + QUICHE_DCHECK(!(flags & to_add) || (flags & to_remove)); + + int fcntl_result; + do { + fcntl_result = ::fcntl(fd, F_SETFL, (flags | to_add) & ~to_remove); + } while (fcntl_result < 0 && errno == EINTR); + if (fcntl_result < 0) { + absl::Status status = ToStatus(errno, "::fcntl()"); + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Could not set flags for socket " << fd << " with error: " << status; + return status; + } + + return absl::OkStatus(); +} + +absl::StatusOr ValidateAndConvertAddress( + const sockaddr_storage& addr, socklen_t addr_len) { + if (addr.ss_family != AF_INET && addr.ss_family != AF_INET6) { + QUICHE_DVLOG(1) << "Socket did not have recognized address family: " + << addr.ss_family; + return absl::UnimplementedError("Unrecognized address family."); + } + + if ((addr.ss_family == AF_INET && addr_len != sizeof(sockaddr_in)) || + (addr.ss_family == AF_INET6 && addr_len != sizeof(sockaddr_in6))) { + QUICHE_DVLOG(1) << "Socket did not have expected address size (" + << (addr.ss_family == AF_INET ? sizeof(sockaddr_in) + : sizeof(sockaddr_in6)) + << "), had: " << addr_len; + return absl::UnimplementedError("Unhandled address size."); + } + + return QuicSocketAddress(addr); +} + +absl::StatusOr CreateSocketWithFlags(IpAddressFamily address_family, + SocketProtocol protocol, + int flags) { + int address_family_int = quiche::ToPlatformAddressFamily(address_family); + + int type_int = ToPlatformSocketType(protocol); + type_int |= flags; + + int protocol_int = ToPlatformProtocol(protocol); + + SocketFd fd; + do { + fd = ::socket(address_family_int, type_int, protocol_int); + } while (fd < 0 && errno == EINTR); + + if (fd >= 0) { + return fd; + } else { + absl::Status status = ToStatus(errno, "::socket()"); + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Failed to create socket with error: " << status; + return status; + } +} + +absl::StatusOr AcceptInternal(SocketFd fd) { + QUICHE_DCHECK_GE(fd, 0); + + sockaddr_storage peer_addr; + socklen_t peer_addr_len = sizeof(peer_addr); + SocketFd connection_socket; + do { + connection_socket = ::accept( + fd, reinterpret_cast(&peer_addr), &peer_addr_len); + } while (connection_socket < 0 && errno == EINTR); + + if (connection_socket < 0) { + absl::Status status = ToStatus(errno, "::accept()"); + QUICHE_DVLOG(1) << "Failed to accept connection from socket " << fd + << " with error: " << status; + return status; + } + + absl::StatusOr peer_address = + ValidateAndConvertAddress(peer_addr, peer_addr_len); + + if (peer_address.ok()) { + return AcceptResult{connection_socket, peer_address.value()}; + } else { + return peer_address.status(); + } +} + +#if defined(HAS_ACCEPT4) +absl::StatusOr AcceptWithFlags(SocketFd fd, int flags) { + QUICHE_DCHECK_GE(fd, 0); + + sockaddr_storage peer_addr; + socklen_t peer_addr_len = sizeof(peer_addr); + SocketFd connection_socket; + do { + connection_socket = + ::accept4(fd, reinterpret_cast(&peer_addr), + &peer_addr_len, flags); + } while (connection_socket < 0 && errno == EINTR); + + if (connection_socket < 0) { + absl::Status status = ToStatus(errno, "::accept4()"); + QUICHE_DVLOG(1) << "Failed to accept connection from socket " << fd + << " with error: " << status; + return status; + } + + absl::StatusOr peer_address = + ValidateAndConvertAddress(peer_addr, peer_addr_len); + + if (peer_address.ok()) { + return AcceptResult{connection_socket, peer_address.value()}; + } else { + return peer_address.status(); + } +} +#endif // defined(HAS_ACCEPT4) + +socklen_t GetAddrlen(IpAddressFamily family) { + switch (family) { + case IpAddressFamily::IP_V4: + return sizeof(sockaddr_in); + case IpAddressFamily::IP_V6: + return sizeof(sockaddr_in6); + default: + QUICHE_NOTREACHED(); + return 0; + } +} + +absl::Status SetSockOptInt(SocketFd fd, int option, int value) { + QUICHE_DCHECK_GE(fd, 0); + + int result; + do { + result = ::setsockopt(fd, SOL_SOCKET, option, &value, sizeof(value)); + } while (result < 0 && errno == EINTR); + + if (result >= 0) { + return absl::OkStatus(); + } else { + absl::Status status = ToStatus(errno, "::setsockopt()"); + QUICHE_DVLOG(1) << "Failed to set socket " << fd << " option " << option + << " to " << value << " with error: " << status; + return status; + } +} + +} // namespace + +absl::StatusOr CreateSocket(IpAddressFamily address_family, + SocketProtocol protocol, bool blocking) { + int flags = 0; +#if defined(__linux__) && defined(SOCK_NONBLOCK) + if (!blocking) { + flags = SOCK_NONBLOCK; + } +#endif + + absl::StatusOr socket = + CreateSocketWithFlags(address_family, protocol, flags); + if (!socket.ok() || blocking) { + return socket; + } + +#if !defined(__linux__) || !defined(SOCK_NONBLOCK) + // If non-blocking could not be set directly on socket creation, need to do + // it now. + absl::Status set_non_blocking_result = + SetSocketBlocking(socket.value(), /*blocking=*/false); + if (!set_non_blocking_result.ok()) { + QUICHE_LOG_FIRST_N(ERROR, 100) << "Failed to set socket " << socket.value() + << " as non-blocking on creation."; + if (!Close(socket.value()).ok()) { + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Failed to close socket " << socket.value() + << " after set-non-blocking error on creation."; + } + return set_non_blocking_result; + } +#endif + + return socket; +} + +absl::Status SetSocketBlocking(SocketFd fd, bool blocking) { + if (blocking) { + return SetSocketFlags(fd, /*to_add=*/0, /*to_remove=*/O_NONBLOCK); + } else { + return SetSocketFlags(fd, /*to_add=*/O_NONBLOCK, /*to_remove=*/0); + } +} + +absl::Status SetReceiveBufferSize(SocketFd fd, QuicByteCount size) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK_LE(size, QuicByteCount{INT_MAX}); + + return SetSockOptInt(fd, SO_RCVBUF, static_cast(size)); +} + +absl::Status SetSendBufferSize(SocketFd fd, QuicByteCount size) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK_LE(size, QuicByteCount{INT_MAX}); + + return SetSockOptInt(fd, SO_SNDBUF, static_cast(size)); +} + +absl::Status Connect(SocketFd fd, const QuicSocketAddress& peer_address) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK(peer_address.IsInitialized()); + + sockaddr_storage addr = peer_address.generic_address(); + socklen_t addrlen = GetAddrlen(peer_address.host().address_family()); + + int connect_result; + do { + connect_result = ::connect(fd, reinterpret_cast(&addr), addrlen); + } while (connect_result < 0 && errno == EINTR); + + if (connect_result >= 0) { + return absl::OkStatus(); + } else { + // For ::connect(), only `EINPROGRESS` indicates unavailable. + absl::Status status = + ToStatus(errno, "::connect()", /*unavailable_error_numbers=*/ + {EINPROGRESS}); + QUICHE_DVLOG(1) << "Failed to connect socket " << fd + << " to address: " << peer_address.ToString() + << " with error: " << status; + return status; + } +} + +absl::Status GetSocketError(SocketFd fd) { + QUICHE_DCHECK_GE(fd, 0); + + int socket_error = 0; + socklen_t len = sizeof(socket_error); + int sockopt_result; + do { + sockopt_result = + ::getsockopt(fd, SOL_SOCKET, SO_ERROR, &socket_error, &len); + } while (sockopt_result < 0 && errno == EINTR); + + if (sockopt_result >= 0) { + if (socket_error == 0) { + return absl::OkStatus(); + } else { + return ToStatus(socket_error, "SO_ERROR"); + } + } else { + absl::Status status = ToStatus(errno, "::getsockopt()"); + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Failed to get socket error information from socket " << fd + << " with error: " << status; + return status; + } +} + +absl::Status Bind(SocketFd fd, const QuicSocketAddress& address) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK(address.IsInitialized()); + + sockaddr_storage addr = address.generic_address(); + socklen_t addr_len = GetAddrlen(address.host().address_family()); + + int result; + do { + result = ::bind(fd, reinterpret_cast(&addr), addr_len); + } while (result < 0 && errno == EINTR); + + if (result >= 0) { + return absl::OkStatus(); + } else { + absl::Status status = ToStatus(errno, "::bind()"); + QUICHE_DVLOG(1) << "Failed to bind socket " << fd + << " to address: " << address.ToString() + << " with error: " << status; + return status; + } +} + +absl::StatusOr GetSocketAddress(SocketFd fd) { + QUICHE_DCHECK_GE(fd, 0); + + sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + int result; + do { + result = ::getsockname(fd, reinterpret_cast(&addr), &addr_len); + } while (result < 0 && errno == EINTR); + + if (result >= 0) { + return ValidateAndConvertAddress(addr, addr_len); + } else { + absl::Status status = ToStatus(errno, "::getsockname()"); + QUICHE_DVLOG(1) << "Failed to get socket " << fd + << " name with error: " << status; + return status; + } +} + +absl::Status Listen(SocketFd fd, int backlog) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK_GT(backlog, 0); + + int result; + do { + result = ::listen(fd, backlog); + } while (result < 0 && errno == EINTR); + + if (result >= 0) { + return absl::OkStatus(); + } else { + absl::Status status = ToStatus(errno, "::listen()"); + QUICHE_DVLOG(1) << "Failed to mark socket: " << fd + << " to listen with error :" << status; + return status; + } +} + +absl::StatusOr Accept(SocketFd fd, bool blocking) { + QUICHE_DCHECK_GE(fd, 0); + +#if defined(HAS_ACCEPT4) + if (!blocking) { + return AcceptWithFlags(fd, SOCK_NONBLOCK); + } +#endif + + absl::StatusOr accept_result = AcceptInternal(fd); + if (!accept_result.ok() || blocking) { + return accept_result; + } + +#if !defined(__linux__) || !defined(SOCK_NONBLOCK) + // If non-blocking could not be set directly on socket acceptance, need to + // do it now. + absl::Status set_non_blocking_result = + SetSocketBlocking(accept_result.value().fd, /*blocking=*/false); + if (!set_non_blocking_result.ok()) { + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Failed to set socket " << fd << " as non-blocking on acceptance."; + if (!Close(accept_result.value().fd).ok()) { + QUICHE_LOG_FIRST_N(ERROR, 100) + << "Failed to close socket " << accept_result.value().fd + << " after error setting non-blocking on acceptance."; + } + return set_non_blocking_result; + } +#endif + + return accept_result; +} + +absl::StatusOr> Receive(SocketFd fd, absl::Span buffer, + bool peek) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK(!buffer.empty()); + + ssize_t num_read; + do { + num_read = + ::recv(fd, buffer.data(), buffer.size(), /*flags=*/peek ? MSG_PEEK : 0); + } while (num_read < 0 && errno == EINTR); + + if (num_read > 0 && static_cast(num_read) > buffer.size()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Received more bytes (" << num_read << ") from socket " << fd + << " than buffer size (" << buffer.size() << ")."; + return absl::OutOfRangeError( + "::recv(): Received more bytes than buffer size."); + } else if (num_read >= 0) { + return buffer.subspan(0, num_read); + } else { + absl::Status status = ToStatus(errno, "::recv()"); + QUICHE_DVLOG(1) << "Failed to receive from socket: " << fd + << " with error: " << status; + return status; + } +} + +absl::StatusOr Send(SocketFd fd, absl::string_view buffer) { + QUICHE_DCHECK_GE(fd, 0); + QUICHE_DCHECK(!buffer.empty()); + + ssize_t num_sent; + do { + num_sent = ::send(fd, buffer.data(), buffer.size(), /*flags=*/0); + } while (num_sent < 0 && errno == EINTR); + + if (num_sent > 0 && static_cast(num_sent) > buffer.size()) { + QUICHE_LOG_FIRST_N(WARNING, 100) + << "Sent more bytes (" << num_sent << ") to socket " << fd + << " than buffer size (" << buffer.size() << ")."; + return absl::OutOfRangeError("::send(): Sent more bytes than buffer size."); + } else if (num_sent >= 0) { + return buffer.substr(num_sent); + } else { + absl::Status status = ToStatus(errno, "::send()"); + QUICHE_DVLOG(1) << "Failed to send to socket: " << fd + << " with error: " << status; + return status; + } +} + +absl::Status Close(SocketFd fd) { + QUICHE_DCHECK_GE(fd, 0); + + int close_result = ::close(fd); + + if (close_result >= 0) { + return absl::OkStatus(); + } else if (errno == EINTR) { + // Ignore EINTR on close because the socket is left in an undefined state + // and can't be acted on again. + QUICHE_DVLOG(1) << "Socket " << fd << " close unspecified due to EINTR."; + return absl::OkStatus(); + } else { + absl::Status status = ToStatus(errno, "::close()"); + QUICHE_DVLOG(1) << "Failed to close socket: " << fd + << " with error: " << status; + return status; + } +} + +} // namespace quic::socket_api diff --git a/quiche/quic/core/io/socket_test.cc b/quiche/quic/core/io/socket_test.cc new file mode 100644 index 000000000000..9afe0f695a18 --- /dev/null +++ b/quiche/quic/core/io/socket_test.cc @@ -0,0 +1,197 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/io/socket.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/platform/api/quiche_test_loopback.h" + +namespace quic { +namespace { + +using quiche::test::QuicheTest; +using testing::Lt; +using testing::SizeIs; + +SocketFd CreateTestSocket(socket_api::SocketProtocol protocol, + bool blocking = true) { + absl::StatusOr socket = socket_api::CreateSocket( + quiche::TestLoopback().address_family(), protocol, blocking); + + if (socket.ok()) { + return socket.value(); + } else { + QUICHE_CHECK(false); + return kInvalidSocketFd; + } +} + +TEST(SocketTest, CreateAndCloseSocket) { + QuicIpAddress localhost_address = quiche::TestLoopback(); + absl::StatusOr created_socket = socket_api::CreateSocket( + localhost_address.address_family(), socket_api::SocketProtocol::kUdp); + + EXPECT_TRUE(created_socket.ok()); + + EXPECT_TRUE(socket_api::Close(created_socket.value()).ok()); +} + +TEST(SocketTest, SetSocketBlocking) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp, + /*blocking=*/true); + + EXPECT_TRUE(socket_api::SetSocketBlocking(socket, /*blocking=*/false).ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, SetReceiveBufferSize) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp, + /*blocking=*/true); + + EXPECT_TRUE(socket_api::SetReceiveBufferSize(socket, /*size=*/100).ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, SetSendBufferSize) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp, + /*blocking=*/true); + + EXPECT_TRUE(socket_api::SetSendBufferSize(socket, /*size=*/100).ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Connect) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp); + + // UDP, so "connecting" should succeed without any listening sockets. + EXPECT_TRUE(socket_api::Connect( + socket, QuicSocketAddress(quiche::TestLoopback(), /*port=*/0)) + .ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, GetSocketError) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp, + /*blocking=*/true); + + absl::Status error = socket_api::GetSocketError(socket); + EXPECT_TRUE(error.ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Bind) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp); + + EXPECT_TRUE(socket_api::Bind( + socket, QuicSocketAddress(quiche::TestLoopback(), /*port=*/0)) + .ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, GetSocketAddress) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp); + ASSERT_TRUE(socket_api::Bind( + socket, QuicSocketAddress(quiche::TestLoopback(), /*port=*/0)) + .ok()); + + absl::StatusOr address = + socket_api::GetSocketAddress(socket); + EXPECT_TRUE(address.ok()); + EXPECT_TRUE(address.value().IsInitialized()); + EXPECT_EQ(address.value().host(), quiche::TestLoopback()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Listen) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kTcp); + ASSERT_TRUE(socket_api::Bind( + socket, QuicSocketAddress(quiche::TestLoopback(), /*port=*/0)) + .ok()); + + EXPECT_TRUE(socket_api::Listen(socket, /*backlog=*/5).ok()); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Accept) { + // Need a non-blocking socket to avoid waiting when no connection comes. + SocketFd socket = + CreateTestSocket(socket_api::SocketProtocol::kTcp, /*blocking=*/false); + ASSERT_TRUE(socket_api::Bind( + socket, QuicSocketAddress(quiche::TestLoopback(), /*port=*/0)) + .ok()); + ASSERT_TRUE(socket_api::Listen(socket, /*backlog=*/5).ok()); + + // Nothing set up to connect, so expect kUnavailable. + absl::StatusOr result = socket_api::Accept(socket); + ASSERT_FALSE(result.ok()); + EXPECT_TRUE(absl::IsUnavailable(result.status())); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Receive) { + // Non-blocking to avoid waiting when no data to receive. + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp, + /*blocking=*/false); + + std::string buffer(100, 0); + absl::StatusOr> result = + socket_api::Receive(socket, absl::MakeSpan(buffer)); + ASSERT_FALSE(result.ok()); + EXPECT_TRUE(absl::IsUnavailable(result.status())); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Peek) { + // Non-blocking to avoid waiting when no data to receive. + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp, + /*blocking=*/false); + + std::string buffer(100, 0); + absl::StatusOr> result = + socket_api::Receive(socket, absl::MakeSpan(buffer), /*peek=*/true); + ASSERT_FALSE(result.ok()); + EXPECT_TRUE(absl::IsUnavailable(result.status())); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +TEST(SocketTest, Send) { + SocketFd socket = CreateTestSocket(socket_api::SocketProtocol::kUdp); + // UDP, so "connecting" should succeed without any listening sockets. + ASSERT_TRUE(socket_api::Connect( + socket, QuicSocketAddress(quiche::TestLoopback(), /*port=*/0)) + .ok()); + + char buffer[] = {12, 34, 56, 78}; + // Expect at least some data to be sent successfully. + absl::StatusOr result = + socket_api::Send(socket, absl::string_view(buffer, sizeof(buffer))); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(result.value(), SizeIs(Lt(4))); + + EXPECT_TRUE(socket_api::Close(socket).ok()); +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/core/legacy_quic_stream_id_manager.cc b/quiche/quic/core/legacy_quic_stream_id_manager.cc new file mode 100644 index 000000000000..0e422ef12cdc --- /dev/null +++ b/quiche/quic/core/legacy_quic_stream_id_manager.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "quiche/quic/core/legacy_quic_stream_id_manager.h" + +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" + +namespace quic { + +LegacyQuicStreamIdManager::LegacyQuicStreamIdManager( + Perspective perspective, QuicTransportVersion transport_version, + size_t max_open_outgoing_streams, size_t max_open_incoming_streams) + : perspective_(perspective), + transport_version_(transport_version), + max_open_outgoing_streams_(max_open_outgoing_streams), + max_open_incoming_streams_(max_open_incoming_streams), + next_outgoing_stream_id_(QuicUtils::GetFirstBidirectionalStreamId( + transport_version_, perspective_)), + largest_peer_created_stream_id_( + perspective_ == Perspective::IS_SERVER + ? (QuicVersionUsesCryptoFrames(transport_version_) + ? QuicUtils::GetInvalidStreamId(transport_version_) + : QuicUtils::GetCryptoStreamId(transport_version_)) + : QuicUtils::GetInvalidStreamId(transport_version_)), + num_open_incoming_streams_(0), + num_open_outgoing_streams_(0) {} + +LegacyQuicStreamIdManager::~LegacyQuicStreamIdManager() {} + +bool LegacyQuicStreamIdManager::CanOpenNextOutgoingStream() const { + QUICHE_DCHECK_LE(num_open_outgoing_streams_, max_open_outgoing_streams_); + QUIC_DLOG_IF(INFO, num_open_outgoing_streams_ == max_open_outgoing_streams_) + << "Failed to create a new outgoing stream. " + << "Already " << num_open_outgoing_streams_ << " open."; + return num_open_outgoing_streams_ < max_open_outgoing_streams_; +} + +bool LegacyQuicStreamIdManager::CanOpenIncomingStream() const { + return num_open_incoming_streams_ < max_open_incoming_streams_; +} + +bool LegacyQuicStreamIdManager::MaybeIncreaseLargestPeerStreamId( + const QuicStreamId stream_id) { + available_streams_.erase(stream_id); + + if (largest_peer_created_stream_id_ != + QuicUtils::GetInvalidStreamId(transport_version_) && + stream_id <= largest_peer_created_stream_id_) { + return true; + } + + // Check if the new number of available streams would cause the number of + // available streams to exceed the limit. Note that the peer can create + // only alternately-numbered streams. + size_t additional_available_streams = + (stream_id - largest_peer_created_stream_id_) / 2 - 1; + if (largest_peer_created_stream_id_ == + QuicUtils::GetInvalidStreamId(transport_version_)) { + additional_available_streams = (stream_id + 1) / 2 - 1; + } + size_t new_num_available_streams = + GetNumAvailableStreams() + additional_available_streams; + if (new_num_available_streams > MaxAvailableStreams()) { + QUIC_DLOG(INFO) << perspective_ + << "Failed to create a new incoming stream with id:" + << stream_id << ". There are already " + << GetNumAvailableStreams() + << " streams available, which would become " + << new_num_available_streams << ", which exceeds the limit " + << MaxAvailableStreams() << "."; + return false; + } + QuicStreamId first_available_stream = largest_peer_created_stream_id_ + 2; + if (largest_peer_created_stream_id_ == + QuicUtils::GetInvalidStreamId(transport_version_)) { + first_available_stream = QuicUtils::GetFirstBidirectionalStreamId( + transport_version_, QuicUtils::InvertPerspective(perspective_)); + } + for (QuicStreamId id = first_available_stream; id < stream_id; id += 2) { + available_streams_.insert(id); + } + largest_peer_created_stream_id_ = stream_id; + + return true; +} + +QuicStreamId LegacyQuicStreamIdManager::GetNextOutgoingStreamId() { + QuicStreamId id = next_outgoing_stream_id_; + next_outgoing_stream_id_ += 2; + return id; +} + +void LegacyQuicStreamIdManager::ActivateStream(bool is_incoming) { + if (is_incoming) { + ++num_open_incoming_streams_; + return; + } + ++num_open_outgoing_streams_; +} + +void LegacyQuicStreamIdManager::OnStreamClosed(bool is_incoming) { + if (is_incoming) { + QUIC_BUG_IF(quic_bug_12720_1, num_open_incoming_streams_ == 0); + --num_open_incoming_streams_; + return; + } + QUIC_BUG_IF(quic_bug_12720_2, num_open_outgoing_streams_ == 0); + --num_open_outgoing_streams_; +} + +bool LegacyQuicStreamIdManager::IsAvailableStream(QuicStreamId id) const { + if (!IsIncomingStream(id)) { + // Stream IDs under next_ougoing_stream_id_ are either open or previously + // open but now closed. + return id >= next_outgoing_stream_id_; + } + // For peer created streams, we also need to consider available streams. + return largest_peer_created_stream_id_ == + QuicUtils::GetInvalidStreamId(transport_version_) || + id > largest_peer_created_stream_id_ || + available_streams_.contains(id); +} + +bool LegacyQuicStreamIdManager::IsIncomingStream(QuicStreamId id) const { + return id % 2 != next_outgoing_stream_id_ % 2; +} + +size_t LegacyQuicStreamIdManager::GetNumAvailableStreams() const { + return available_streams_.size(); +} + +size_t LegacyQuicStreamIdManager::MaxAvailableStreams() const { + return max_open_incoming_streams_ * kMaxAvailableStreamsMultiplier; +} + +} // namespace quic diff --git a/quiche/quic/core/legacy_quic_stream_id_manager.h b/quiche/quic/core/legacy_quic_stream_id_manager.h new file mode 100644 index 000000000000..3c67028e734e --- /dev/null +++ b/quiche/quic/core/legacy_quic_stream_id_manager.h @@ -0,0 +1,128 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef QUICHE_QUIC_CORE_LEGACY_QUIC_STREAM_ID_MANAGER_H_ +#define QUICHE_QUIC_CORE_LEGACY_QUIC_STREAM_ID_MANAGER_H_ + +#include "absl/container/flat_hash_set.h" +#include "quiche/quic/core/quic_stream_id_manager.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" + +namespace quic { + +namespace test { +class QuicSessionPeer; +} // namespace test + +class QuicSession; + +// Manages Google QUIC stream IDs. This manager is responsible for two +// questions: 1) can next outgoing stream ID be allocated (if yes, what is the +// next outgoing stream ID) and 2) can a new incoming stream be opened. +class QUIC_EXPORT_PRIVATE LegacyQuicStreamIdManager { + public: + LegacyQuicStreamIdManager(Perspective perspective, + QuicTransportVersion transport_version, + size_t max_open_outgoing_streams, + size_t max_open_incoming_streams); + + ~LegacyQuicStreamIdManager(); + + // Returns true if the next outgoing stream ID can be allocated. + bool CanOpenNextOutgoingStream() const; + + // Returns true if a new incoming stream can be opened. + bool CanOpenIncomingStream() const; + + // Returns false when increasing the largest created stream id to |id| would + // violate the limit, so the connection should be closed. + bool MaybeIncreaseLargestPeerStreamId(const QuicStreamId id); + + // Returns true if |id| is still available. + bool IsAvailableStream(QuicStreamId id) const; + + // Returns the stream ID for a new outgoing stream, and increments the + // underlying counter. + QuicStreamId GetNextOutgoingStreamId(); + + // Called when a new stream is open. + void ActivateStream(bool is_incoming); + + // Called when a stream ID is closed. + void OnStreamClosed(bool is_incoming); + + // Return true if |id| is peer initiated. + bool IsIncomingStream(QuicStreamId id) const; + + size_t MaxAvailableStreams() const; + + void set_max_open_incoming_streams(size_t max_open_incoming_streams) { + max_open_incoming_streams_ = max_open_incoming_streams; + } + + void set_max_open_outgoing_streams(size_t max_open_outgoing_streams) { + max_open_outgoing_streams_ = max_open_outgoing_streams; + } + + void set_largest_peer_created_stream_id( + QuicStreamId largest_peer_created_stream_id) { + largest_peer_created_stream_id_ = largest_peer_created_stream_id; + } + + size_t max_open_incoming_streams() const { + return max_open_incoming_streams_; + } + + size_t max_open_outgoing_streams() const { + return max_open_outgoing_streams_; + } + + QuicStreamId next_outgoing_stream_id() const { + return next_outgoing_stream_id_; + } + + QuicStreamId largest_peer_created_stream_id() const { + return largest_peer_created_stream_id_; + } + + size_t GetNumAvailableStreams() const; + + size_t num_open_incoming_streams() const { + return num_open_incoming_streams_; + } + size_t num_open_outgoing_streams() const { + return num_open_outgoing_streams_; + } + + private: + friend class test::QuicSessionPeer; + + const Perspective perspective_; + const QuicTransportVersion transport_version_; + + // The maximum number of outgoing streams this connection can open. + size_t max_open_outgoing_streams_; + + // The maximum number of incoming streams this connection will allow. + size_t max_open_incoming_streams_; + + // The ID to use for the next outgoing stream. + QuicStreamId next_outgoing_stream_id_; + + // Set of stream ids that are less than the largest stream id that has been + // received, but are nonetheless available to be created. + absl::flat_hash_set available_streams_; + + QuicStreamId largest_peer_created_stream_id_; + + // A counter for peer initiated open streams. + size_t num_open_incoming_streams_; + + // A counter for self initiated open streams. + size_t num_open_outgoing_streams_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_LEGACY_QUIC_STREAM_ID_MANAGER_H_ diff --git a/quiche/quic/core/legacy_quic_stream_id_manager_test.cc b/quiche/quic/core/legacy_quic_stream_id_manager_test.cc new file mode 100644 index 000000000000..1dcc0574f2fa --- /dev/null +++ b/quiche/quic/core/legacy_quic_stream_id_manager_test.cc @@ -0,0 +1,178 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/legacy_quic_stream_id_manager.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +using testing::_; +using testing::StrictMock; + +struct TestParams { + TestParams(ParsedQuicVersion version, Perspective perspective) + : version(version), perspective(perspective) {} + + ParsedQuicVersion version; + Perspective perspective; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + ParsedQuicVersionToString(p.version), + (p.perspective == Perspective::IS_CLIENT ? "Client" : "Server")); +} + +std::vector GetTestParams() { + std::vector params; + for (ParsedQuicVersion version : AllSupportedVersions()) { + for (auto perspective : {Perspective::IS_CLIENT, Perspective::IS_SERVER}) { + // LegacyQuicStreamIdManager is only used when IETF QUIC frames are not + // presented. + if (!VersionHasIetfQuicFrames(version.transport_version)) { + params.push_back(TestParams(version, perspective)); + } + } + } + return params; +} + +class LegacyQuicStreamIdManagerTest : public QuicTestWithParam { + public: + LegacyQuicStreamIdManagerTest() + : manager_(GetParam().perspective, GetParam().version.transport_version, + kDefaultMaxStreamsPerConnection, + kDefaultMaxStreamsPerConnection) {} + + protected: + QuicStreamId GetNthPeerInitiatedId(int n) { + if (GetParam().perspective == Perspective::IS_SERVER) { + return QuicUtils::GetFirstBidirectionalStreamId( + GetParam().version.transport_version, Perspective::IS_CLIENT) + + 2 * n; + } else { + return 2 + 2 * n; + } + } + + LegacyQuicStreamIdManager manager_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, LegacyQuicStreamIdManagerTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(LegacyQuicStreamIdManagerTest, CanOpenNextOutgoingStream) { + for (size_t i = 0; i < manager_.max_open_outgoing_streams() - 1; ++i) { + manager_.ActivateStream(/*is_incoming=*/false); + } + EXPECT_TRUE(manager_.CanOpenNextOutgoingStream()); + manager_.ActivateStream(/*is_incoming=*/false); + EXPECT_FALSE(manager_.CanOpenNextOutgoingStream()); +} + +TEST_P(LegacyQuicStreamIdManagerTest, CanOpenIncomingStream) { + for (size_t i = 0; i < manager_.max_open_incoming_streams() - 1; ++i) { + manager_.ActivateStream(/*is_incoming=*/true); + } + EXPECT_TRUE(manager_.CanOpenIncomingStream()); + manager_.ActivateStream(/*is_incoming=*/true); + EXPECT_FALSE(manager_.CanOpenIncomingStream()); +} + +TEST_P(LegacyQuicStreamIdManagerTest, AvailableStreams) { + ASSERT_TRUE( + manager_.MaybeIncreaseLargestPeerStreamId(GetNthPeerInitiatedId(3))); + EXPECT_TRUE(manager_.IsAvailableStream(GetNthPeerInitiatedId(1))); + EXPECT_TRUE(manager_.IsAvailableStream(GetNthPeerInitiatedId(2))); + ASSERT_TRUE( + manager_.MaybeIncreaseLargestPeerStreamId(GetNthPeerInitiatedId(2))); + ASSERT_TRUE( + manager_.MaybeIncreaseLargestPeerStreamId(GetNthPeerInitiatedId(1))); +} + +TEST_P(LegacyQuicStreamIdManagerTest, MaxAvailableStreams) { + // Test that the server closes the connection if a client makes too many data + // streams available. The server accepts slightly more than the negotiated + // stream limit to deal with rare cases where a client FIN/RST is lost. + const size_t kMaxStreamsForTest = 10; + const size_t kAvailableStreamLimit = manager_.MaxAvailableStreams(); + EXPECT_EQ( + manager_.max_open_incoming_streams() * kMaxAvailableStreamsMultiplier, + manager_.MaxAvailableStreams()); + // The protocol specification requires that there can be at least 10 times + // as many available streams as the connection's maximum open streams. + EXPECT_LE(10 * kMaxStreamsForTest, kAvailableStreamLimit); + + EXPECT_TRUE( + manager_.MaybeIncreaseLargestPeerStreamId(GetNthPeerInitiatedId(0))); + + // Establish available streams up to the server's limit. + const int kLimitingStreamId = + GetNthPeerInitiatedId(kAvailableStreamLimit + 1); + // This exceeds the stream limit. In versions other than 99 + // this is allowed. Version 99 hews to the IETF spec and does + // not allow it. + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId(kLimitingStreamId)); + + // This forces stream kLimitingStreamId + 2 to become available, which + // violates the quota. + EXPECT_FALSE( + manager_.MaybeIncreaseLargestPeerStreamId(kLimitingStreamId + 2 * 2)); +} + +TEST_P(LegacyQuicStreamIdManagerTest, MaximumAvailableOpenedStreams) { + QuicStreamId stream_id = GetNthPeerInitiatedId(0); + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId(stream_id)); + + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + stream_id + 2 * (manager_.max_open_incoming_streams() - 1))); +} + +TEST_P(LegacyQuicStreamIdManagerTest, TooManyAvailableStreams) { + QuicStreamId stream_id = GetNthPeerInitiatedId(0); + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId(stream_id)); + + // A stream ID which is too large to create. + QuicStreamId stream_id2 = + GetNthPeerInitiatedId(2 * manager_.MaxAvailableStreams() + 4); + EXPECT_FALSE(manager_.MaybeIncreaseLargestPeerStreamId(stream_id2)); +} + +TEST_P(LegacyQuicStreamIdManagerTest, ManyAvailableStreams) { + // When max_open_streams_ is 200, should be able to create 200 streams + // out-of-order, that is, creating the one with the largest stream ID first. + manager_.set_max_open_incoming_streams(200); + QuicStreamId stream_id = GetNthPeerInitiatedId(0); + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId(stream_id)); + + // Create the largest stream ID of a threatened total of 200 streams. + // GetNth... starts at 0, so for 200 streams, get the 199th. + EXPECT_TRUE( + manager_.MaybeIncreaseLargestPeerStreamId(GetNthPeerInitiatedId(199))); +} + +TEST_P(LegacyQuicStreamIdManagerTest, + TestMaxIncomingAndOutgoingStreamsAllowed) { + EXPECT_EQ(manager_.max_open_incoming_streams(), + kDefaultMaxStreamsPerConnection); + EXPECT_EQ(manager_.max_open_outgoing_streams(), + kDefaultMaxStreamsPerConnection); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/packet_number_indexed_queue.h b/quiche/quic/core/packet_number_indexed_queue.h new file mode 100644 index 000000000000..ef48bd1bfabc --- /dev/null +++ b/quiche/quic/core/packet_number_indexed_queue.h @@ -0,0 +1,252 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_PACKET_NUMBER_INDEXED_QUEUE_H_ +#define QUICHE_QUIC_CORE_PACKET_NUMBER_INDEXED_QUEUE_H_ + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +// PacketNumberIndexedQueue is a queue of mostly continuous numbered entries +// which supports the following operations: +// - adding elements to the end of the queue, or at some point past the end +// - removing elements in any order +// - retrieving elements +// If all elements are inserted in order, all of the operations above are +// amortized O(1) time. +// +// Internally, the data structure is a deque where each element is marked as +// present or not. The deque starts at the lowest present index. Whenever an +// element is removed, it's marked as not present, and the front of the deque is +// cleared of elements that are not present. +// +// The tail of the queue is not cleared due to the assumption of entries being +// inserted in order, though removing all elements of the queue will return it +// to its initial state. +// +// Note that this data structure is inherently hazardous, since an addition of +// just two entries will cause it to consume all of the memory available. +// Because of that, it is not a general-purpose container and should not be used +// as one. +// TODO(wub): Update the comments when deprecating +// --quic_bw_sampler_remove_packets_once_per_congestion_event. +template +class QUIC_NO_EXPORT PacketNumberIndexedQueue { + public: + PacketNumberIndexedQueue() : number_of_present_entries_(0) {} + + // Retrieve the entry associated with the packet number. Returns the pointer + // to the entry in case of success, or nullptr if the entry does not exist. + T* GetEntry(QuicPacketNumber packet_number); + const T* GetEntry(QuicPacketNumber packet_number) const; + + // Inserts data associated |packet_number| into (or past) the end of the + // queue, filling up the missing intermediate entries as necessary. Returns + // true if the element has been inserted successfully, false if it was already + // in the queue or inserted out of order. + template + bool Emplace(QuicPacketNumber packet_number, Args&&... args); + + // Removes data associated with |packet_number| and frees the slots in the + // queue as necessary. + bool Remove(QuicPacketNumber packet_number); + + // Same as above, but if an entry is present in the queue, also call f(entry) + // before removing it. + template + bool Remove(QuicPacketNumber packet_number, Function f); + + // Remove up to, but not including |packet_number|. + // Unused slots in the front are also removed, which means when the function + // returns, |first_packet()| can be larger than |packet_number|. + void RemoveUpTo(QuicPacketNumber packet_number); + + bool IsEmpty() const { return number_of_present_entries_ == 0; } + + // Returns the number of entries in the queue. + size_t number_of_present_entries() const { + return number_of_present_entries_; + } + + // Returns the number of entries allocated in the underlying deque. This is + // proportional to the memory usage of the queue. + size_t entry_slots_used() const { return entries_.size(); } + + // Packet number of the first entry in the queue. + QuicPacketNumber first_packet() const { return first_packet_; } + + // Packet number of the last entry ever inserted in the queue. Note that the + // entry in question may have already been removed. Zero if the queue is + // empty. + QuicPacketNumber last_packet() const { + if (IsEmpty()) { + return QuicPacketNumber(); + } + return first_packet_ + entries_.size() - 1; + } + + private: + // Wrapper around T used to mark whether the entry is actually in the map. + struct QUIC_NO_EXPORT EntryWrapper : T { + // NOTE(wub): When quic_bw_sampler_remove_packets_once_per_congestion_event + // is enabled, |present| is false if and only if this is a placeholder entry + // for holes in the parent's |entries|. + bool present; + + EntryWrapper() : present(false) {} + + template + explicit EntryWrapper(Args&&... args) + : T(std::forward(args)...), present(true) {} + }; + + // Cleans up unused slots in the front after removing an element. + void Cleanup(); + + const EntryWrapper* GetEntryWrapper(QuicPacketNumber offset) const; + EntryWrapper* GetEntryWrapper(QuicPacketNumber offset) { + const auto* const_this = this; + return const_cast(const_this->GetEntryWrapper(offset)); + } + + quiche::QuicheCircularDeque entries_; + // NOTE(wub): When --quic_bw_sampler_remove_packets_once_per_congestion_event + // is enabled, |number_of_present_entries_| only represents number of holes, + // which does not include number of acked or lost packets. + size_t number_of_present_entries_; + QuicPacketNumber first_packet_; +}; + +template +T* PacketNumberIndexedQueue::GetEntry(QuicPacketNumber packet_number) { + EntryWrapper* entry = GetEntryWrapper(packet_number); + if (entry == nullptr) { + return nullptr; + } + return entry; +} + +template +const T* PacketNumberIndexedQueue::GetEntry( + QuicPacketNumber packet_number) const { + const EntryWrapper* entry = GetEntryWrapper(packet_number); + if (entry == nullptr) { + return nullptr; + } + return entry; +} + +template +template +bool PacketNumberIndexedQueue::Emplace(QuicPacketNumber packet_number, + Args&&... args) { + if (!packet_number.IsInitialized()) { + QUIC_BUG(quic_bug_10359_1) + << "Try to insert an uninitialized packet number"; + return false; + } + + if (IsEmpty()) { + QUICHE_DCHECK(entries_.empty()); + QUICHE_DCHECK(!first_packet_.IsInitialized()); + + entries_.emplace_back(std::forward(args)...); + number_of_present_entries_ = 1; + first_packet_ = packet_number; + return true; + } + + // Do not allow insertion out-of-order. + if (packet_number <= last_packet()) { + return false; + } + + // Handle potentially missing elements. + size_t offset = packet_number - first_packet_; + if (offset > entries_.size()) { + entries_.resize(offset); + } + + number_of_present_entries_++; + entries_.emplace_back(std::forward(args)...); + QUICHE_DCHECK_EQ(packet_number, last_packet()); + return true; +} + +template +bool PacketNumberIndexedQueue::Remove(QuicPacketNumber packet_number) { + return Remove(packet_number, [](const T&) {}); +} + +template +template +bool PacketNumberIndexedQueue::Remove(QuicPacketNumber packet_number, + Function f) { + EntryWrapper* entry = GetEntryWrapper(packet_number); + if (entry == nullptr) { + return false; + } + f(*static_cast(entry)); + entry->present = false; + number_of_present_entries_--; + + if (packet_number == first_packet()) { + Cleanup(); + } + return true; +} + +template +void PacketNumberIndexedQueue::RemoveUpTo(QuicPacketNumber packet_number) { + while (!entries_.empty() && first_packet_.IsInitialized() && + first_packet_ < packet_number) { + if (entries_.front().present) { + number_of_present_entries_--; + } + entries_.pop_front(); + first_packet_++; + } + Cleanup(); +} + +template +void PacketNumberIndexedQueue::Cleanup() { + while (!entries_.empty() && !entries_.front().present) { + entries_.pop_front(); + first_packet_++; + } + if (entries_.empty()) { + first_packet_.Clear(); + } +} + +template +auto PacketNumberIndexedQueue::GetEntryWrapper( + QuicPacketNumber packet_number) const -> const EntryWrapper* { + if (!packet_number.IsInitialized() || IsEmpty() || + packet_number < first_packet_) { + return nullptr; + } + + uint64_t offset = packet_number - first_packet_; + if (offset >= entries_.size()) { + return nullptr; + } + + const EntryWrapper* entry = &entries_[offset]; + if (!entry->present) { + return nullptr; + } + + return entry; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_PACKET_NUMBER_INDEXED_QUEUE_H_ diff --git a/quiche/quic/core/packet_number_indexed_queue_test.cc b/quiche/quic/core/packet_number_indexed_queue_test.cc new file mode 100644 index 000000000000..f05309beca23 --- /dev/null +++ b/quiche/quic/core/packet_number_indexed_queue_test.cc @@ -0,0 +1,205 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/packet_number_indexed_queue.h" + +#include +#include +#include + +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic::test { +namespace { + +class PacketNumberIndexedQueueTest : public QuicTest { + public: + PacketNumberIndexedQueueTest() {} + + protected: + PacketNumberIndexedQueue queue_; +}; + +TEST_F(PacketNumberIndexedQueueTest, InitialState) { + EXPECT_TRUE(queue_.IsEmpty()); + EXPECT_FALSE(queue_.first_packet().IsInitialized()); + EXPECT_FALSE(queue_.last_packet().IsInitialized()); + EXPECT_EQ(0u, queue_.number_of_present_entries()); + EXPECT_EQ(0u, queue_.entry_slots_used()); +} + +TEST_F(PacketNumberIndexedQueueTest, InsertingContinuousElements) { + ASSERT_TRUE(queue_.Emplace(QuicPacketNumber(1001), "one")); + EXPECT_EQ("one", *queue_.GetEntry(QuicPacketNumber(1001))); + + ASSERT_TRUE(queue_.Emplace(QuicPacketNumber(1002), "two")); + EXPECT_EQ("two", *queue_.GetEntry(QuicPacketNumber(1002))); + + EXPECT_FALSE(queue_.IsEmpty()); + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(1002u), queue_.last_packet()); + EXPECT_EQ(2u, queue_.number_of_present_entries()); + EXPECT_EQ(2u, queue_.entry_slots_used()); +} + +TEST_F(PacketNumberIndexedQueueTest, InsertingOutOfOrder) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + + ASSERT_TRUE(queue_.Emplace(QuicPacketNumber(1003), "three")); + EXPECT_EQ(nullptr, queue_.GetEntry(QuicPacketNumber(1002))); + EXPECT_EQ("three", *queue_.GetEntry(QuicPacketNumber(1003))); + + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(1003u), queue_.last_packet()); + EXPECT_EQ(2u, queue_.number_of_present_entries()); + EXPECT_EQ(3u, queue_.entry_slots_used()); + + ASSERT_FALSE(queue_.Emplace(QuicPacketNumber(1002), "two")); +} + +TEST_F(PacketNumberIndexedQueueTest, InsertingIntoPast) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + EXPECT_FALSE(queue_.Emplace(QuicPacketNumber(1000), "zero")); +} + +TEST_F(PacketNumberIndexedQueueTest, InsertingDuplicate) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + EXPECT_FALSE(queue_.Emplace(QuicPacketNumber(1001), "one")); +} + +TEST_F(PacketNumberIndexedQueueTest, RemoveInTheMiddle) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + queue_.Emplace(QuicPacketNumber(1002), "two"); + queue_.Emplace(QuicPacketNumber(1003), "three"); + + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1002))); + EXPECT_EQ(nullptr, queue_.GetEntry(QuicPacketNumber(1002))); + + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(1003u), queue_.last_packet()); + EXPECT_EQ(2u, queue_.number_of_present_entries()); + EXPECT_EQ(3u, queue_.entry_slots_used()); + + EXPECT_FALSE(queue_.Emplace(QuicPacketNumber(1002), "two")); + EXPECT_TRUE(queue_.Emplace(QuicPacketNumber(1004), "four")); +} + +TEST_F(PacketNumberIndexedQueueTest, RemoveAtImmediateEdges) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + queue_.Emplace(QuicPacketNumber(1002), "two"); + queue_.Emplace(QuicPacketNumber(1003), "three"); + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1001))); + EXPECT_EQ(nullptr, queue_.GetEntry(QuicPacketNumber(1001))); + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1003))); + EXPECT_EQ(nullptr, queue_.GetEntry(QuicPacketNumber(1003))); + + EXPECT_EQ(QuicPacketNumber(1002u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(1003u), queue_.last_packet()); + EXPECT_EQ(1u, queue_.number_of_present_entries()); + EXPECT_EQ(2u, queue_.entry_slots_used()); + + EXPECT_TRUE(queue_.Emplace(QuicPacketNumber(1004), "four")); +} + +TEST_F(PacketNumberIndexedQueueTest, RemoveAtDistantFront) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + queue_.Emplace(QuicPacketNumber(1002), "one (kinda)"); + queue_.Emplace(QuicPacketNumber(2001), "two"); + + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.last_packet()); + EXPECT_EQ(3u, queue_.number_of_present_entries()); + EXPECT_EQ(1001u, queue_.entry_slots_used()); + + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1002))); + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.last_packet()); + EXPECT_EQ(2u, queue_.number_of_present_entries()); + EXPECT_EQ(1001u, queue_.entry_slots_used()); + + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1001))); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.last_packet()); + EXPECT_EQ(1u, queue_.number_of_present_entries()); + EXPECT_EQ(1u, queue_.entry_slots_used()); +} + +TEST_F(PacketNumberIndexedQueueTest, RemoveAtDistantBack) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + queue_.Emplace(QuicPacketNumber(2001), "two"); + + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.last_packet()); + + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(2001))); + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.last_packet()); +} + +TEST_F(PacketNumberIndexedQueueTest, ClearAndRepopulate) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + queue_.Emplace(QuicPacketNumber(2001), "two"); + + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1001))); + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(2001))); + EXPECT_TRUE(queue_.IsEmpty()); + EXPECT_FALSE(queue_.first_packet().IsInitialized()); + EXPECT_FALSE(queue_.last_packet().IsInitialized()); + + EXPECT_TRUE(queue_.Emplace(QuicPacketNumber(101), "one")); + EXPECT_TRUE(queue_.Emplace(QuicPacketNumber(201), "two")); + EXPECT_EQ(QuicPacketNumber(101u), queue_.first_packet()); + EXPECT_EQ(QuicPacketNumber(201u), queue_.last_packet()); +} + +TEST_F(PacketNumberIndexedQueueTest, FailToRemoveElementsThatNeverExisted) { + ASSERT_FALSE(queue_.Remove(QuicPacketNumber(1000))); + queue_.Emplace(QuicPacketNumber(1001), "one"); + ASSERT_FALSE(queue_.Remove(QuicPacketNumber(1000))); + ASSERT_FALSE(queue_.Remove(QuicPacketNumber(1002))); +} + +TEST_F(PacketNumberIndexedQueueTest, FailToRemoveElementsTwice) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + ASSERT_TRUE(queue_.Remove(QuicPacketNumber(1001))); + ASSERT_FALSE(queue_.Remove(QuicPacketNumber(1001))); + ASSERT_FALSE(queue_.Remove(QuicPacketNumber(1001))); +} + +TEST_F(PacketNumberIndexedQueueTest, RemoveUpTo) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + queue_.Emplace(QuicPacketNumber(2001), "two"); + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(2u, queue_.number_of_present_entries()); + + queue_.RemoveUpTo(QuicPacketNumber(1001)); + EXPECT_EQ(QuicPacketNumber(1001u), queue_.first_packet()); + EXPECT_EQ(2u, queue_.number_of_present_entries()); + + // Remove up to 1100, since [1100, 2001) are !present, they should be cleaned + // up from the front. + queue_.RemoveUpTo(QuicPacketNumber(1100)); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.first_packet()); + EXPECT_EQ(1u, queue_.number_of_present_entries()); + + queue_.RemoveUpTo(QuicPacketNumber(2001)); + EXPECT_EQ(QuicPacketNumber(2001u), queue_.first_packet()); + EXPECT_EQ(1u, queue_.number_of_present_entries()); + + queue_.RemoveUpTo(QuicPacketNumber(2002)); + EXPECT_FALSE(queue_.first_packet().IsInitialized()); + EXPECT_EQ(0u, queue_.number_of_present_entries()); +} + +TEST_F(PacketNumberIndexedQueueTest, ConstGetter) { + queue_.Emplace(QuicPacketNumber(1001), "one"); + const auto& const_queue = queue_; + + EXPECT_EQ("one", *const_queue.GetEntry(QuicPacketNumber(1001))); + EXPECT_EQ(nullptr, const_queue.GetEntry(QuicPacketNumber(1002))); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/proto/cached_network_parameters.proto b/quiche/quic/core/proto/cached_network_parameters.proto new file mode 100644 index 000000000000..d609be9b0d14 --- /dev/null +++ b/quiche/quic/core/proto/cached_network_parameters.proto @@ -0,0 +1,43 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package quic; + +// CachedNetworkParameters contains data that can be used to choose appropriate +// connection parameters (initial RTT, initial CWND, etc.) in new connections. +// Next id: 8 +message CachedNetworkParameters { + // Describes the state of the connection during which the supplied network + // parameters were calculated. + enum PreviousConnectionState { + SLOW_START = 0; + CONGESTION_AVOIDANCE = 1; + } + + // serving_region is used to decide whether or not the bandwidth estimate and + // min RTT are reasonable and if they should be used. + // For example a group of geographically close servers may share the same + // serving_region string if they are expected to have similar network + // performance. + optional string serving_region = 1; + // The server can supply a bandwidth estimate (in bytes/s) which it may re-use + // on receipt of a source-address token with this field set. + optional int32 bandwidth_estimate_bytes_per_second = 2; + // The maximum bandwidth seen to the client, not necessarily the latest. + optional int32 max_bandwidth_estimate_bytes_per_second = 5; + // Timestamp (seconds since UNIX epoch) that indicates when the max bandwidth + // was seen by the server. + optional int64 max_bandwidth_timestamp_seconds = 6; + // The min RTT seen on a previous connection can be used by the server to + // inform initial connection parameters for new connections. + optional int32 min_rtt_ms = 3; + // Encodes the PreviousConnectionState enum. + optional int32 previous_connection_state = 4; + // UNIX timestamp when this bandwidth estimate was created. + optional int64 timestamp = 7; +}; diff --git a/quiche/quic/core/proto/cached_network_parameters_proto.h b/quiche/quic/core/proto/cached_network_parameters_proto.h new file mode 100644 index 000000000000..2b4388525801 --- /dev/null +++ b/quiche/quic/core/proto/cached_network_parameters_proto.h @@ -0,0 +1,10 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_PROTO_CACHED_NETWORK_PARAMETERS_PROTO_H_ +#define QUICHE_QUIC_CORE_PROTO_CACHED_NETWORK_PARAMETERS_PROTO_H_ + +#include "quiche/quic/core/proto/cached_network_parameters.pb.h" + +#endif // QUICHE_QUIC_CORE_PROTO_CACHED_NETWORK_PARAMETERS_PROTO_H_ diff --git a/quiche/quic/core/proto/crypto_server_config.proto b/quiche/quic/core/proto/crypto_server_config.proto new file mode 100644 index 000000000000..bb447dcc4b68 --- /dev/null +++ b/quiche/quic/core/proto/crypto_server_config.proto @@ -0,0 +1,34 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package quic; + +// QuicServerConfigProtobuf contains QUIC server config block and the private +// keys needed to prove ownership. +message QuicServerConfigProtobuf { + // config is a serialised config in QUIC wire format. + required bytes config = 1; + + // PrivateKey contains a QUIC tag of a key exchange algorithm and a + // serialised private key for that algorithm. The format of the serialised + // private key is specific to the algorithm in question. + message PrivateKey { + required uint32 tag = 1; + required bytes private_key = 2; + } + repeated PrivateKey key = 2; + + // primary_time contains a UNIX epoch seconds value that indicates when this + // config should become primary. + optional int64 primary_time = 3; + + // Relative priority of this config vs other configs with the same + // primary time. For use as a secondary sort key when selecting the + // primary config. + optional uint64 priority = 4; +}; diff --git a/quiche/quic/core/proto/crypto_server_config_proto.h b/quiche/quic/core/proto/crypto_server_config_proto.h new file mode 100644 index 000000000000..feda6ee3b61b --- /dev/null +++ b/quiche/quic/core/proto/crypto_server_config_proto.h @@ -0,0 +1,10 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_PROTO_CRYPTO_SERVER_CONFIG_PROTO_H_ +#define QUICHE_QUIC_CORE_PROTO_CRYPTO_SERVER_CONFIG_PROTO_H_ + +#include "quiche/quic/core/proto/crypto_server_config.pb.h" + +#endif // QUICHE_QUIC_CORE_PROTO_CRYPTO_SERVER_CONFIG_PROTO_H_ diff --git a/quiche/quic/core/proto/source_address_token.proto b/quiche/quic/core/proto/source_address_token.proto new file mode 100644 index 000000000000..d261d46ecdbd --- /dev/null +++ b/quiche/quic/core/proto/source_address_token.proto @@ -0,0 +1,32 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +import "quiche/quic/core/proto/cached_network_parameters.proto"; + +package quic; + +// A SourceAddressToken is serialised, encrypted and sent to clients so that +// they can prove ownership of an IP address. +message SourceAddressToken { + // ip contains either 4 (IPv4) or 16 (IPv6) bytes of IP address in network + // byte order. + required bytes ip = 1; + // timestamp contains a UNIX timestamp value of the time when the token was + // created. + required int64 timestamp = 2; + // The server can provide estimated network parameters to be used for + // initial parameter selection in future connections. + optional CachedNetworkParameters cached_network_parameters = 3; +}; + +// SourceAddressTokens are simply lists of SourceAddressToken messages. +message SourceAddressTokens { + // This field has id 4 to avoid ambiguity between the serialized form of + // SourceAddressToken vs SourceAddressTokens. + repeated SourceAddressToken tokens = 4; +}; diff --git a/quiche/quic/core/proto/source_address_token_proto.h b/quiche/quic/core/proto/source_address_token_proto.h new file mode 100644 index 000000000000..55c2c001f7b0 --- /dev/null +++ b/quiche/quic/core/proto/source_address_token_proto.h @@ -0,0 +1,10 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_PROTO_SOURCE_ADDRESS_TOKEN_PROTO_H_ +#define QUICHE_QUIC_CORE_PROTO_SOURCE_ADDRESS_TOKEN_PROTO_H_ + +#include "quiche/quic/core/proto/source_address_token.pb.h" + +#endif // QUICHE_QUIC_CORE_PROTO_SOURCE_ADDRESS_TOKEN_PROTO_H_ diff --git a/quiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc b/quiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc new file mode 100644 index 000000000000..8927604ce86b --- /dev/null +++ b/quiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc @@ -0,0 +1,193 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +namespace quic { +namespace test { + +struct DecoderAndHandler { + std::unique_ptr decoder; + std::unique_ptr handler; +}; + +using DecoderAndHandlerMap = std::map; + +// Class that sets externally owned |error_detected| to true +// on encoder stream error. +class ErrorDelegate : public QpackDecoder::EncoderStreamErrorDelegate { + public: + ErrorDelegate(bool* error_detected) : error_detected_(error_detected) {} + ~ErrorDelegate() override = default; + + void OnEncoderStreamError(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override { + *error_detected_ = true; + } + + private: + bool* const error_detected_; +}; + +// Class that destroys DecoderAndHandler when decoding completes, and sets +// externally owned |error_detected| to true on encoder stream error. +class HeadersHandler : public QpackProgressiveDecoder::HeadersHandlerInterface { + public: + HeadersHandler(QuicStreamId stream_id, + DecoderAndHandlerMap* processing_decoders, + bool* error_detected) + : stream_id_(stream_id), + processing_decoders_(processing_decoders), + error_detected_(error_detected) {} + ~HeadersHandler() override = default; + + void OnHeaderDecoded(absl::string_view /*name*/, + absl::string_view /*value*/) override {} + + // Remove DecoderAndHandler from |*processing_decoders|. + void OnDecodingCompleted() override { + // Will delete |this|. + size_t result = processing_decoders_->erase(stream_id_); + QUICHE_CHECK_EQ(1u, result); + } + + void OnDecodingErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override { + *error_detected_ = true; + } + + private: + const QuicStreamId stream_id_; + DecoderAndHandlerMap* const processing_decoders_; + bool* const error_detected_; +}; + +// This fuzzer exercises QpackDecoder. It should be able to cover all possible +// code paths. There is no point in encoding QpackDecoder's output to turn this +// into a roundtrip test, because the same header list can be encoded in many +// different ways, so the output could not be expected to match the original +// input. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + FuzzedDataProvider provider(data, size); + + // Maximum 256 byte dynamic table. Such a small size helps test draining + // entries and eviction. + const uint64_t maximum_dynamic_table_capacity = + provider.ConsumeIntegral(); + // Maximum 256 blocked streams. + const uint64_t maximum_blocked_streams = provider.ConsumeIntegral(); + + // |error_detected| will be set to true if an error is encountered either in a + // header block or on the encoder stream. + bool error_detected = false; + + ErrorDelegate encoder_stream_error_delegate(&error_detected); + QpackDecoder decoder(maximum_dynamic_table_capacity, maximum_blocked_streams, + &encoder_stream_error_delegate); + + NoopQpackStreamSenderDelegate decoder_stream_sender_delegate; + decoder.set_qpack_stream_sender_delegate(&decoder_stream_sender_delegate); + + // Decoders still reading the header block, with corresponding handlers. + DecoderAndHandlerMap reading_decoders; + + // Decoders still processing the completely read header block, + // with corresponding handlers. + DecoderAndHandlerMap processing_decoders; + + // Maximum 256 data fragments to limit runtime and memory usage. + auto fragment_count = provider.ConsumeIntegral(); + while (fragment_count > 0 && !error_detected && + provider.remaining_bytes() > 0) { + --fragment_count; + switch (provider.ConsumeIntegralInRange(0, 3)) { + // Feed encoder stream data to QpackDecoder. + case 0: { + size_t fragment_size = provider.ConsumeIntegral(); + std::string encoded_data = + provider.ConsumeRandomLengthString(fragment_size); + decoder.encoder_stream_receiver()->Decode(encoded_data); + + continue; + } + + // Create new progressive decoder. + case 1: { + QuicStreamId stream_id = provider.ConsumeIntegral(); + if (reading_decoders.find(stream_id) != reading_decoders.end() || + processing_decoders.find(stream_id) != processing_decoders.end()) { + continue; + } + + DecoderAndHandler decoder_and_handler; + decoder_and_handler.handler = std::make_unique( + stream_id, &processing_decoders, &error_detected); + decoder_and_handler.decoder = decoder.CreateProgressiveDecoder( + stream_id, decoder_and_handler.handler.get()); + reading_decoders.insert({stream_id, std::move(decoder_and_handler)}); + + continue; + } + + // Feed header block data to existing decoder. + case 2: { + if (reading_decoders.empty()) { + continue; + } + + auto it = reading_decoders.begin(); + auto distance = provider.ConsumeIntegralInRange( + 0, reading_decoders.size() - 1); + std::advance(it, distance); + + size_t fragment_size = provider.ConsumeIntegral(); + std::string encoded_data = + provider.ConsumeRandomLengthString(fragment_size); + it->second.decoder->Decode(encoded_data); + + continue; + } + + // End header block. + case 3: { + if (reading_decoders.empty()) { + continue; + } + + auto it = reading_decoders.begin(); + auto distance = provider.ConsumeIntegralInRange( + 0, reading_decoders.size() - 1); + std::advance(it, distance); + + QpackProgressiveDecoder* reading_decoder = it->second.decoder.get(); + + // Move DecoderAndHandler to |processing_decoders| first, because + // EndHeaderBlock() might synchronously call OnDecodingCompleted(). + QuicStreamId stream_id = it->first; + processing_decoders.insert({stream_id, std::move(it->second)}); + reading_decoders.erase(it); + + reading_decoder->EndHeaderBlock(); + + continue; + } + } + } + + return 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc b/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc new file mode 100644 index 000000000000..7d8542eab3b5 --- /dev/null +++ b/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc @@ -0,0 +1,62 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder_stream_receiver.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_stream.h" + +namespace quic { +namespace test { +namespace { + +// A QpackDecoderStreamReceiver::Delegate implementation that ignores all +// decoded instructions but keeps track of whether an error has been detected. +class NoOpDelegate : public QpackDecoderStreamReceiver::Delegate { + public: + NoOpDelegate() : error_detected_(false) {} + ~NoOpDelegate() override = default; + + void OnInsertCountIncrement(uint64_t /*increment*/) override {} + void OnHeaderAcknowledgement(QuicStreamId /*stream_id*/) override {} + void OnStreamCancellation(QuicStreamId /*stream_id*/) override {} + void OnErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override { + error_detected_ = true; + } + + bool error_detected() const { return error_detected_; } + + private: + bool error_detected_; +}; + +} // namespace + +// This fuzzer exercises QpackDecoderStreamReceiver. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + NoOpDelegate delegate; + QpackDecoderStreamReceiver receiver(&delegate); + + FuzzedDataProvider provider(data, size); + + while (!delegate.error_detected() && provider.remaining_bytes() != 0) { + // Process up to 64 kB fragments at a time. Too small upper bound might not + // provide enough coverage, too large might make fuzzing too inefficient. + size_t fragment_size = provider.ConsumeIntegralInRange( + 0, std::numeric_limits::max()); + receiver.Decode(provider.ConsumeRandomLengthString(fragment_size)); + } + + return 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc b/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc new file mode 100644 index 000000000000..57fe7183ee42 --- /dev/null +++ b/quiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc @@ -0,0 +1,54 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include + +#include "quiche/quic/core/qpack/qpack_decoder_stream_sender.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +namespace quic { +namespace test { + +// This fuzzer exercises QpackDecoderStreamSender. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + NoopQpackStreamSenderDelegate delegate; + QpackDecoderStreamSender sender; + sender.set_qpack_stream_sender_delegate(&delegate); + + FuzzedDataProvider provider(data, size); + + while (provider.remaining_bytes() != 0) { + switch (provider.ConsumeIntegral() % 4) { + case 0: { + uint64_t increment = provider.ConsumeIntegral(); + sender.SendInsertCountIncrement(increment); + break; + } + case 1: { + QuicStreamId stream_id = provider.ConsumeIntegral(); + sender.SendHeaderAcknowledgement(stream_id); + break; + } + case 2: { + QuicStreamId stream_id = provider.ConsumeIntegral(); + sender.SendStreamCancellation(stream_id); + break; + } + case 3: { + sender.Flush(); + break; + } + } + } + + sender.Flush(); + return 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc b/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc new file mode 100644 index 000000000000..53aaa63adef2 --- /dev/null +++ b/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc @@ -0,0 +1,67 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_encoder_stream_receiver.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace test { +namespace { + +// A QpackEncoderStreamReceiver::Delegate implementation that ignores all +// decoded instructions but keeps track of whether an error has been detected. +class NoOpDelegate : public QpackEncoderStreamReceiver::Delegate { + public: + NoOpDelegate() : error_detected_(false) {} + ~NoOpDelegate() override = default; + + void OnInsertWithNameReference(bool /*is_static*/, uint64_t /*name_index*/, + absl::string_view /*value*/) override {} + void OnInsertWithoutNameReference(absl::string_view /*name*/, + absl::string_view /*value*/) override {} + void OnDuplicate(uint64_t /*index*/) override {} + void OnSetDynamicTableCapacity(uint64_t /*capacity*/) override {} + void OnErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override { + error_detected_ = true; + } + + bool error_detected() const { return error_detected_; } + + private: + bool error_detected_; +}; + +} // namespace + +// This fuzzer exercises QpackEncoderStreamReceiver. +// Note that since string literals may be encoded with or without Huffman +// encoding, one could not expect identical encoded data if the decoded +// instructions were fed into QpackEncoderStreamSender. Therefore there is no +// point in extending this fuzzer into a round-trip test. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + NoOpDelegate delegate; + QpackEncoderStreamReceiver receiver(&delegate); + + FuzzedDataProvider provider(data, size); + + while (!delegate.error_detected() && provider.remaining_bytes() != 0) { + // Process up to 64 kB fragments at a time. Too small upper bound might not + // provide enough coverage, too large might make fuzzing too inefficient. + size_t fragment_size = provider.ConsumeIntegralInRange( + 0, std::numeric_limits::max()); + receiver.Decode(provider.ConsumeRandomLengthString(fragment_size)); + } + + return 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc b/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc new file mode 100644 index 000000000000..5109f61c6a41 --- /dev/null +++ b/quiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc @@ -0,0 +1,72 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include +#include +#include + +#include "quiche/quic/core/qpack/qpack_encoder_stream_sender.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +namespace quic { +namespace test { + +// This fuzzer exercises QpackEncoderStreamSender. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + NoopQpackStreamSenderDelegate delegate; + QpackEncoderStreamSender sender; + sender.set_qpack_stream_sender_delegate(&delegate); + + FuzzedDataProvider provider(data, size); + // Limit string literal length to 2 kB for efficiency. + const uint16_t kMaxStringLength = 2048; + + while (provider.remaining_bytes() != 0) { + switch (provider.ConsumeIntegral() % 5) { + case 0: { + bool is_static = provider.ConsumeBool(); + uint64_t name_index = provider.ConsumeIntegral(); + uint16_t value_length = + provider.ConsumeIntegralInRange(0, kMaxStringLength); + std::string value = provider.ConsumeRandomLengthString(value_length); + + sender.SendInsertWithNameReference(is_static, name_index, value); + break; + } + case 1: { + uint16_t name_length = + provider.ConsumeIntegralInRange(0, kMaxStringLength); + std::string name = provider.ConsumeRandomLengthString(name_length); + uint16_t value_length = + provider.ConsumeIntegralInRange(0, kMaxStringLength); + std::string value = provider.ConsumeRandomLengthString(value_length); + sender.SendInsertWithoutNameReference(name, value); + break; + } + case 2: { + uint64_t index = provider.ConsumeIntegral(); + sender.SendDuplicate(index); + break; + } + case 3: { + uint64_t capacity = provider.ConsumeIntegral(); + sender.SendSetDynamicTableCapacity(capacity); + break; + } + case 4: { + sender.Flush(); + break; + } + } + } + + sender.Flush(); + return 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc b/quiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc new file mode 100644 index 000000000000..b046565d840e --- /dev/null +++ b/quiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc @@ -0,0 +1,661 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/qpack/qpack_encoder.h" +#include "quiche/quic/core/qpack/qpack_stream_sender_delegate.h" +#include "quiche/quic/core/qpack/value_splitting_header_list.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" +#include "quiche/quic/test_tools/qpack/qpack_encoder_peer.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace test { +namespace { + +// Find the first occurrence of invalid characters NUL, LF, CR in |*value| and +// remove that and the remaining of the string. +void TruncateValueOnInvalidChars(std::string* value) { + for (auto it = value->begin(); it != value->end(); ++it) { + if (*it == '\0' || *it == '\n' || *it == '\r') { + value->erase(it, value->end()); + return; + } + } +} + +} // anonymous namespace + +// Class to hold QpackEncoder and its DecoderStreamErrorDelegate. +class EncodingEndpoint { + public: + EncodingEndpoint(uint64_t maximum_dynamic_table_capacity, + uint64_t maximum_blocked_streams) + : encoder_(&decoder_stream_error_delegate) { + encoder_.SetMaximumDynamicTableCapacity(maximum_dynamic_table_capacity); + encoder_.SetMaximumBlockedStreams(maximum_blocked_streams); + } + + ~EncodingEndpoint() { + // Every reference should be acknowledged. + QUICHE_CHECK_EQ(std::numeric_limits::max(), + QpackEncoderPeer::smallest_blocking_index(&encoder_)); + } + + void set_qpack_stream_sender_delegate(QpackStreamSenderDelegate* delegate) { + encoder_.set_qpack_stream_sender_delegate(delegate); + } + + void SetDynamicTableCapacity(uint64_t maximum_dynamic_table_capacity) { + encoder_.SetDynamicTableCapacity(maximum_dynamic_table_capacity); + } + + QpackStreamReceiver* decoder_stream_receiver() { + return encoder_.decoder_stream_receiver(); + } + + std::string EncodeHeaderList(QuicStreamId stream_id, + const spdy::Http2HeaderBlock& header_list) { + return encoder_.EncodeHeaderList(stream_id, header_list, nullptr); + } + + private: + // DecoderStreamErrorDelegate implementation that crashes on error. + class CrashingDecoderStreamErrorDelegate + : public QpackEncoder::DecoderStreamErrorDelegate { + public: + ~CrashingDecoderStreamErrorDelegate() override = default; + + void OnDecoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) override { + QUICHE_CHECK(false) << QuicErrorCodeToString(error_code) << " " + << error_message; + } + }; + + CrashingDecoderStreamErrorDelegate decoder_stream_error_delegate; + QpackEncoder encoder_; +}; + +// Class that receives all header blocks from the encoding endpoint and passes +// them to the decoding endpoint, with delay determined by fuzzer data, +// preserving order within each stream but not among streams. +class DelayedHeaderBlockTransmitter { + public: + class Visitor { + public: + virtual ~Visitor() = default; + + // If decoding of the previous header block is still in progress, then + // DelayedHeaderBlockTransmitter will not start transmitting the next header + // block. + virtual bool IsDecodingInProgressOnStream(QuicStreamId stream_id) = 0; + + // Called when a header block starts. + virtual void OnHeaderBlockStart(QuicStreamId stream_id) = 0; + // Called when part or all of a header block is transmitted. + virtual void OnHeaderBlockFragment(QuicStreamId stream_id, + absl::string_view data) = 0; + // Called when transmission of a header block is complete. + virtual void OnHeaderBlockEnd(QuicStreamId stream_id) = 0; + }; + + DelayedHeaderBlockTransmitter(Visitor* visitor, FuzzedDataProvider* provider) + : visitor_(visitor), provider_(provider) {} + + ~DelayedHeaderBlockTransmitter() { QUICHE_CHECK(header_blocks_.empty()); } + + // Enqueues |encoded_header_block| for delayed transmission. + void SendEncodedHeaderBlock(QuicStreamId stream_id, + std::string encoded_header_block) { + auto it = header_blocks_.lower_bound(stream_id); + if (it == header_blocks_.end() || it->first != stream_id) { + it = header_blocks_.insert(it, {stream_id, {}}); + } + QUICHE_CHECK_EQ(stream_id, it->first); + it->second.push(HeaderBlock(std::move(encoded_header_block))); + } + + // Release some (possibly none) header block data. + void MaybeTransmitSomeData() { + if (header_blocks_.empty()) { + return; + } + + auto index = + provider_->ConsumeIntegralInRange(0, header_blocks_.size() - 1); + auto it = header_blocks_.begin(); + std::advance(it, index); + const QuicStreamId stream_id = it->first; + + // Do not start new header block if processing of previous header block is + // blocked. + if (visitor_->IsDecodingInProgressOnStream(stream_id)) { + return; + } + + auto& header_block_queue = it->second; + HeaderBlock& header_block = header_block_queue.front(); + + if (header_block.ConsumedLength() == 0) { + visitor_->OnHeaderBlockStart(stream_id); + } + + QUICHE_DCHECK_NE(0u, header_block.RemainingLength()); + + size_t length = provider_->ConsumeIntegralInRange( + 1, header_block.RemainingLength()); + visitor_->OnHeaderBlockFragment(stream_id, header_block.Consume(length)); + + QUICHE_DCHECK_NE(0u, header_block.ConsumedLength()); + + if (header_block.RemainingLength() == 0) { + visitor_->OnHeaderBlockEnd(stream_id); + + header_block_queue.pop(); + if (header_block_queue.empty()) { + header_blocks_.erase(it); + } + } + } + + // Release all header block data. Must be called before destruction. All + // encoder stream data must have been released before calling Flush() so that + // all header blocks can be decoded synchronously. + void Flush() { + while (!header_blocks_.empty()) { + auto it = header_blocks_.begin(); + const QuicStreamId stream_id = it->first; + + auto& header_block_queue = it->second; + HeaderBlock& header_block = header_block_queue.front(); + + if (header_block.ConsumedLength() == 0) { + QUICHE_CHECK(!visitor_->IsDecodingInProgressOnStream(stream_id)); + visitor_->OnHeaderBlockStart(stream_id); + } + + QUICHE_DCHECK_NE(0u, header_block.RemainingLength()); + + visitor_->OnHeaderBlockFragment(stream_id, + header_block.ConsumeRemaining()); + + QUICHE_DCHECK_NE(0u, header_block.ConsumedLength()); + QUICHE_DCHECK_EQ(0u, header_block.RemainingLength()); + + visitor_->OnHeaderBlockEnd(stream_id); + QUICHE_CHECK(!visitor_->IsDecodingInProgressOnStream(stream_id)); + + header_block_queue.pop(); + if (header_block_queue.empty()) { + header_blocks_.erase(it); + } + } + } + + private: + // Helper class that allows the header block to be consumed in parts. + class HeaderBlock { + public: + explicit HeaderBlock(std::string data) + : data_(std::move(data)), offset_(0) { + // Valid QPACK header block cannot be empty. + QUICHE_DCHECK(!data_.empty()); + } + + size_t ConsumedLength() const { return offset_; } + + size_t RemainingLength() const { return data_.length() - offset_; } + + absl::string_view Consume(size_t length) { + QUICHE_DCHECK_NE(0u, length); + QUICHE_DCHECK_LE(length, RemainingLength()); + + absl::string_view consumed = absl::string_view(&data_[offset_], length); + offset_ += length; + return consumed; + } + + absl::string_view ConsumeRemaining() { return Consume(RemainingLength()); } + + private: + // Complete header block. + const std::string data_; + + // Offset of the part not consumed yet. Same as number of consumed bytes. + size_t offset_; + }; + + Visitor* const visitor_; + FuzzedDataProvider* const provider_; + + std::map> header_blocks_; +}; + +// Class to decode and verify a header block, and in case of blocked decoding, +// keep necessary decoding context while waiting for decoding to complete. +class VerifyingDecoder : public QpackDecodedHeadersAccumulator::Visitor { + public: + class Visitor { + public: + virtual ~Visitor() = default; + + // Called when header block is decoded, either synchronously or + // asynchronously. Might destroy VerifyingDecoder. + virtual void OnHeaderBlockDecoded(QuicStreamId stream_id) = 0; + }; + + VerifyingDecoder(QuicStreamId stream_id, Visitor* visitor, + QpackDecoder* qpack_decoder, + QuicHeaderList expected_header_list) + : stream_id_(stream_id), + visitor_(visitor), + accumulator_( + stream_id, qpack_decoder, this, + /* max_header_list_size = */ std::numeric_limits::max()), + expected_header_list_(std::move(expected_header_list)) {} + + VerifyingDecoder(const VerifyingDecoder&) = delete; + VerifyingDecoder& operator=(const VerifyingDecoder&) = delete; + // VerifyingDecoder must not be moved because it passes |this| to + // |accumulator_| upon construction. + VerifyingDecoder(VerifyingDecoder&&) = delete; + VerifyingDecoder& operator=(VerifyingDecoder&&) = delete; + + virtual ~VerifyingDecoder() = default; + + // QpackDecodedHeadersAccumulator::Visitor implementation. + void OnHeadersDecoded(QuicHeaderList headers, + bool header_list_size_limit_exceeded) override { + // Verify headers. + QUICHE_CHECK(!header_list_size_limit_exceeded); + QUICHE_CHECK(expected_header_list_ == headers); + + // Might destroy |this|. + visitor_->OnHeaderBlockDecoded(stream_id_); + } + + void OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) override { + QUICHE_CHECK(false) << QuicErrorCodeToString(error_code) << " " + << error_message; + } + + void Decode(absl::string_view data) { accumulator_.Decode(data); } + + void EndHeaderBlock() { accumulator_.EndHeaderBlock(); } + + private: + QuicStreamId stream_id_; + Visitor* const visitor_; + QpackDecodedHeadersAccumulator accumulator_; + QuicHeaderList expected_header_list_; +}; + +// Class that holds QpackDecoder and its EncoderStreamErrorDelegate, and creates +// and keeps VerifyingDecoders for each received header block until decoding is +// complete. +class DecodingEndpoint : public DelayedHeaderBlockTransmitter::Visitor, + public VerifyingDecoder::Visitor { + public: + DecodingEndpoint(uint64_t maximum_dynamic_table_capacity, + uint64_t maximum_blocked_streams) + : decoder_(maximum_dynamic_table_capacity, maximum_blocked_streams, + &encoder_stream_error_delegate_) {} + + ~DecodingEndpoint() override { + // All decoding must have been completed. + QUICHE_CHECK(expected_header_lists_.empty()); + QUICHE_CHECK(verifying_decoders_.empty()); + } + + void set_qpack_stream_sender_delegate(QpackStreamSenderDelegate* delegate) { + decoder_.set_qpack_stream_sender_delegate(delegate); + } + + QpackStreamReceiver* encoder_stream_receiver() { + return decoder_.encoder_stream_receiver(); + } + + void AddExpectedHeaderList(QuicStreamId stream_id, + QuicHeaderList expected_header_list) { + auto it = expected_header_lists_.lower_bound(stream_id); + if (it == expected_header_lists_.end() || it->first != stream_id) { + it = expected_header_lists_.insert(it, {stream_id, {}}); + } + QUICHE_CHECK_EQ(stream_id, it->first); + it->second.push(std::move(expected_header_list)); + } + + // VerifyingDecoder::Visitor implementation. + void OnHeaderBlockDecoded(QuicStreamId stream_id) override { + auto result = verifying_decoders_.erase(stream_id); + QUICHE_CHECK_EQ(1u, result); + } + + // DelayedHeaderBlockTransmitter::Visitor implementation. + bool IsDecodingInProgressOnStream(QuicStreamId stream_id) override { + return verifying_decoders_.find(stream_id) != verifying_decoders_.end(); + } + + void OnHeaderBlockStart(QuicStreamId stream_id) override { + QUICHE_CHECK(!IsDecodingInProgressOnStream(stream_id)); + auto it = expected_header_lists_.find(stream_id); + QUICHE_CHECK(it != expected_header_lists_.end()); + + auto& header_list_queue = it->second; + QuicHeaderList expected_header_list = std::move(header_list_queue.front()); + + header_list_queue.pop(); + if (header_list_queue.empty()) { + expected_header_lists_.erase(it); + } + + auto verifying_decoder = std::make_unique( + stream_id, this, &decoder_, std::move(expected_header_list)); + auto result = + verifying_decoders_.insert({stream_id, std::move(verifying_decoder)}); + QUICHE_CHECK(result.second); + } + + void OnHeaderBlockFragment(QuicStreamId stream_id, + absl::string_view data) override { + auto it = verifying_decoders_.find(stream_id); + QUICHE_CHECK(it != verifying_decoders_.end()); + it->second->Decode(data); + } + + void OnHeaderBlockEnd(QuicStreamId stream_id) override { + auto it = verifying_decoders_.find(stream_id); + QUICHE_CHECK(it != verifying_decoders_.end()); + it->second->EndHeaderBlock(); + } + + private: + // EncoderStreamErrorDelegate implementation that crashes on error. + class CrashingEncoderStreamErrorDelegate + : public QpackDecoder::EncoderStreamErrorDelegate { + public: + ~CrashingEncoderStreamErrorDelegate() override = default; + + void OnEncoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) override { + QUICHE_CHECK(false) << QuicErrorCodeToString(error_code) << " " + << error_message; + } + }; + + CrashingEncoderStreamErrorDelegate encoder_stream_error_delegate_; + QpackDecoder decoder_; + + // Expected header lists in order for each stream. + std::map> expected_header_lists_; + + // A VerifyingDecoder object keeps context necessary for asynchronously + // decoding blocked header blocks. It is destroyed as soon as it signals that + // decoding is completed, which might happen synchronously within an + // EndHeaderBlock() call. + std::map> verifying_decoders_; +}; + +// Class that receives encoder stream data from the encoder and passes it to the +// decoder, or receives decoder stream data from the decoder and passes it to +// the encoder, with delay determined by fuzzer data. +class DelayedStreamDataTransmitter : public QpackStreamSenderDelegate { + public: + DelayedStreamDataTransmitter(QpackStreamReceiver* receiver, + FuzzedDataProvider* provider) + : receiver_(receiver), provider_(provider) {} + + ~DelayedStreamDataTransmitter() { QUICHE_CHECK(stream_data.empty()); } + + // QpackStreamSenderDelegate implementation. + void WriteStreamData(absl::string_view data) override { + stream_data.push_back(std::string(data.data(), data.size())); + } + uint64_t NumBytesBuffered() const override { return 0; } + + // Release some (possibly none) delayed stream data. + void MaybeTransmitSomeData() { + auto count = provider_->ConsumeIntegral(); + while (!stream_data.empty() && count > 0) { + receiver_->Decode(stream_data.front()); + stream_data.pop_front(); + --count; + } + } + + // Release all delayed stream data. Must be called before destruction. + void Flush() { + while (!stream_data.empty()) { + receiver_->Decode(stream_data.front()); + stream_data.pop_front(); + } + } + + private: + QpackStreamReceiver* const receiver_; + FuzzedDataProvider* const provider_; + quiche::QuicheCircularDeque stream_data; +}; + +// Generate header list using fuzzer data. +spdy::Http2HeaderBlock GenerateHeaderList(FuzzedDataProvider* provider) { + spdy::Http2HeaderBlock header_list; + uint8_t header_count = provider->ConsumeIntegral(); + for (uint8_t header_index = 0; header_index < header_count; ++header_index) { + if (provider->remaining_bytes() == 0) { + // Do not add more headers if there is no more fuzzer data. + break; + } + + std::string name; + std::string value; + switch (provider->ConsumeIntegral()) { + case 0: + // Static table entry with no header value. + name = ":authority"; + break; + case 1: + // Static table entry with no header value, using non-empty header + // value. + name = ":authority"; + value = "www.example.org"; + break; + case 2: + // Static table entry with header value, using that header value. + name = ":accept-encoding"; + value = "gzip, deflate"; + break; + case 3: + // Static table entry with header value, using empty header value. + name = ":accept-encoding"; + break; + case 4: + // Static table entry with header value, using different, non-empty + // header value. + name = ":accept-encoding"; + value = "brotli"; + break; + case 5: + // Header name that has multiple entries in the static table, + // using header value from one of them. + name = ":method"; + value = "GET"; + break; + case 6: + // Header name that has multiple entries in the static table, + // using empty header value. + name = ":method"; + break; + case 7: + // Header name that has multiple entries in the static table, + // using different, non-empty header value. + name = ":method"; + value = "CONNECT"; + break; + case 8: + // Header name not in the static table, empty header value. + name = "foo"; + value = ""; + break; + case 9: + // Header name not in the static table, non-empty fixed header value. + name = "foo"; + value = "bar"; + break; + case 10: + // Header name not in the static table, fuzzed header value. + name = "foo"; + value = provider->ConsumeRandomLengthString(128); + TruncateValueOnInvalidChars(&value); + break; + case 11: + // Another header name not in the static table, empty header value. + name = "bar"; + value = ""; + break; + case 12: + // Another header name not in the static table, non-empty fixed header + // value. + name = "bar"; + value = "baz"; + break; + case 13: + // Another header name not in the static table, fuzzed header value. + name = "bar"; + value = provider->ConsumeRandomLengthString(128); + TruncateValueOnInvalidChars(&value); + break; + default: + // Fuzzed header name and header value. + name = provider->ConsumeRandomLengthString(128); + value = provider->ConsumeRandomLengthString(128); + TruncateValueOnInvalidChars(&value); + } + + header_list.AppendValueOrAddHeader(name, value); + } + + return header_list; +} + +// Splits |*header_list| header values along '\0' or ';' separators. +QuicHeaderList SplitHeaderList(const spdy::Http2HeaderBlock& header_list) { + QuicHeaderList split_header_list; + split_header_list.OnHeaderBlockStart(); + + size_t total_size = 0; + ValueSplittingHeaderList splitting_header_list(&header_list); + for (const auto& header : splitting_header_list) { + split_header_list.OnHeader(header.first, header.second); + total_size += header.first.size() + header.second.size(); + } + + split_header_list.OnHeaderBlockEnd(total_size, total_size); + + return split_header_list; +} + +// This fuzzer exercises QpackEncoder and QpackDecoder. It should be able to +// cover all possible code paths of QpackEncoder. However, since the resulting +// header block is always valid and is encoded in a particular way, this fuzzer +// is not expected to cover all code paths of QpackDecoder. On the other hand, +// encoding then decoding is expected to result in the original header list, and +// this fuzzer checks for that. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + FuzzedDataProvider provider(data, size); + + // Maximum 256 byte dynamic table. Such a small size helps test draining + // entries and eviction. + const uint64_t maximum_dynamic_table_capacity = + provider.ConsumeIntegral(); + // Maximum 256 blocked streams. + const uint64_t maximum_blocked_streams = provider.ConsumeIntegral(); + + // Set up encoder. + EncodingEndpoint encoder(maximum_dynamic_table_capacity, + maximum_blocked_streams); + + // Set up decoder. + DecodingEndpoint decoder(maximum_dynamic_table_capacity, + maximum_blocked_streams); + + // Transmit encoder stream data from encoder to decoder. + DelayedStreamDataTransmitter encoder_stream_transmitter( + decoder.encoder_stream_receiver(), &provider); + encoder.set_qpack_stream_sender_delegate(&encoder_stream_transmitter); + + // Use a dynamic table as large as the peer allows. This sends data on the + // encoder stream, so it can only be done after delegate is set. + encoder.SetDynamicTableCapacity(maximum_dynamic_table_capacity); + + // Transmit decoder stream data from encoder to decoder. + DelayedStreamDataTransmitter decoder_stream_transmitter( + encoder.decoder_stream_receiver(), &provider); + decoder.set_qpack_stream_sender_delegate(&decoder_stream_transmitter); + + // Transmit header blocks from encoder to decoder. + DelayedHeaderBlockTransmitter header_block_transmitter(&decoder, &provider); + + // Maximum 256 header lists to limit runtime and memory usage. + auto header_list_count = provider.ConsumeIntegral(); + while (header_list_count > 0 && provider.remaining_bytes() > 0) { + const QuicStreamId stream_id = provider.ConsumeIntegral(); + + // Generate header list. + spdy::Http2HeaderBlock header_list = GenerateHeaderList(&provider); + + // Encode header list. + std::string encoded_header_block = + encoder.EncodeHeaderList(stream_id, header_list); + + // TODO(bnc): Randomly cancel the stream. + + // Encoder splits |header_list| header values along '\0' or ';' separators. + // Do the same here so that we get matching results. + QuicHeaderList expected_header_list = SplitHeaderList(header_list); + decoder.AddExpectedHeaderList(stream_id, std::move(expected_header_list)); + + header_block_transmitter.SendEncodedHeaderBlock( + stream_id, std::move(encoded_header_block)); + + // Transmit some encoder stream data, decoder stream data, or header blocks + // on the request stream, repeating a few times. + for (auto transmit_data_count = provider.ConsumeIntegralInRange(1, 5); + transmit_data_count > 0; --transmit_data_count) { + encoder_stream_transmitter.MaybeTransmitSomeData(); + decoder_stream_transmitter.MaybeTransmitSomeData(); + header_block_transmitter.MaybeTransmitSomeData(); + } + + --header_list_count; + } + + // Release all delayed encoder stream data so that remaining header blocks can + // be decoded synchronously. + encoder_stream_transmitter.Flush(); + // Release all delayed header blocks. + header_block_transmitter.Flush(); + // Release all delayed decoder stream data. + decoder_stream_transmitter.Flush(); + + return 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_blocking_manager.cc b/quiche/quic/core/qpack/qpack_blocking_manager.cc new file mode 100644 index 000000000000..17b6e9fda2fc --- /dev/null +++ b/quiche/quic/core/qpack/qpack_blocking_manager.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_blocking_manager.h" + +#include +#include + +namespace quic { + +QpackBlockingManager::QpackBlockingManager() : known_received_count_(0) {} + +bool QpackBlockingManager::OnHeaderAcknowledgement(QuicStreamId stream_id) { + auto it = header_blocks_.find(stream_id); + if (it == header_blocks_.end()) { + return false; + } + + QUICHE_DCHECK(!it->second.empty()); + + const IndexSet& indices = it->second.front(); + QUICHE_DCHECK(!indices.empty()); + + const uint64_t required_index_count = RequiredInsertCount(indices); + if (known_received_count_ < required_index_count) { + known_received_count_ = required_index_count; + } + + DecreaseReferenceCounts(indices); + + it->second.pop_front(); + if (it->second.empty()) { + header_blocks_.erase(it); + } + + return true; +} + +void QpackBlockingManager::OnStreamCancellation(QuicStreamId stream_id) { + auto it = header_blocks_.find(stream_id); + if (it == header_blocks_.end()) { + return; + } + + for (const IndexSet& indices : it->second) { + DecreaseReferenceCounts(indices); + } + + header_blocks_.erase(it); +} + +bool QpackBlockingManager::OnInsertCountIncrement(uint64_t increment) { + if (increment > + std::numeric_limits::max() - known_received_count_) { + return false; + } + + known_received_count_ += increment; + return true; +} + +void QpackBlockingManager::OnHeaderBlockSent(QuicStreamId stream_id, + IndexSet indices) { + QUICHE_DCHECK(!indices.empty()); + + IncreaseReferenceCounts(indices); + header_blocks_[stream_id].push_back(std::move(indices)); +} + +bool QpackBlockingManager::blocking_allowed_on_stream( + QuicStreamId stream_id, uint64_t maximum_blocked_streams) const { + // This should be the most common case: the limit is larger than the number of + // streams that have unacknowledged header blocks (regardless of whether they + // are blocked or not) plus one for stream |stream_id|. + if (header_blocks_.size() + 1 <= maximum_blocked_streams) { + return true; + } + + // This should be another common case: no blocked stream allowed. + if (maximum_blocked_streams == 0) { + return false; + } + + uint64_t blocked_stream_count = 0; + for (const auto& header_blocks_for_stream : header_blocks_) { + for (const IndexSet& indices : header_blocks_for_stream.second) { + if (RequiredInsertCount(indices) > known_received_count_) { + if (header_blocks_for_stream.first == stream_id) { + // Sending blocking references is allowed if stream |stream_id| is + // already blocked. + return true; + } + ++blocked_stream_count; + // If stream |stream_id| is already blocked, then it is not counted yet, + // therefore the number of blocked streams is at least + // |blocked_stream_count + 1|, which cannot be more than + // |maximum_blocked_streams| by API contract. + // If stream |stream_id| is not blocked, then blocking will increase the + // blocked stream count to at least |blocked_stream_count + 1|. If that + // is larger than |maximum_blocked_streams|, then blocking is not + // allowed on stream |stream_id|. + if (blocked_stream_count + 1 > maximum_blocked_streams) { + return false; + } + break; + } + } + } + + // Stream |stream_id| is not blocked. + // If there are no blocked streams, then + // |blocked_stream_count + 1 <= maximum_blocked_streams| because + // |maximum_blocked_streams| is larger than zero. + // If there are are blocked streams, then + // |blocked_stream_count + 1 <= maximum_blocked_streams| otherwise the method + // would have returned false when |blocked_stream_count| was incremented. + // Therefore blocking on |stream_id| is allowed. + return true; +} + +uint64_t QpackBlockingManager::smallest_blocking_index() const { + return entry_reference_counts_.empty() + ? std::numeric_limits::max() + : entry_reference_counts_.begin()->first; +} + +// static +uint64_t QpackBlockingManager::RequiredInsertCount(const IndexSet& indices) { + return *indices.rbegin() + 1; +} + +void QpackBlockingManager::IncreaseReferenceCounts(const IndexSet& indices) { + for (const uint64_t index : indices) { + auto it = entry_reference_counts_.lower_bound(index); + if (it != entry_reference_counts_.end() && it->first == index) { + ++it->second; + } else { + entry_reference_counts_.insert(it, {index, 1}); + } + } +} + +void QpackBlockingManager::DecreaseReferenceCounts(const IndexSet& indices) { + for (const uint64_t index : indices) { + auto it = entry_reference_counts_.find(index); + QUICHE_DCHECK(it != entry_reference_counts_.end()); + QUICHE_DCHECK_NE(0u, it->second); + + if (it->second == 1) { + entry_reference_counts_.erase(it); + } else { + --it->second; + } + } +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_blocking_manager.h b/quiche/quic/core/qpack/qpack_blocking_manager.h new file mode 100644 index 000000000000..46e75b803e72 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_blocking_manager.h @@ -0,0 +1,98 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_BLOCKING_MANAGER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_BLOCKING_MANAGER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { + +class QpackBlockingManagerPeer; + +} // namespace test + +// Class to keep track of blocked streams and blocking dynamic table entries: +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#blocked-decoding +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#blocked-insertion +class QUIC_EXPORT_PRIVATE QpackBlockingManager { + public: + using IndexSet = std::multiset; + + QpackBlockingManager(); + + // Called when a Header Acknowledgement instruction is received on the decoder + // stream. Returns false if there are no outstanding header blocks to be + // acknowledged on |stream_id|. + bool OnHeaderAcknowledgement(QuicStreamId stream_id); + + // Called when a Stream Cancellation instruction is received on the decoder + // stream. + void OnStreamCancellation(QuicStreamId stream_id); + + // Called when an Insert Count Increment instruction is received on the + // decoder stream. Returns true if Known Received Count is successfully + // updated. Returns false on overflow. + bool OnInsertCountIncrement(uint64_t increment); + + // Called when sending a header block containing references to dynamic table + // entries with |indices|. |indices| must not be empty. + void OnHeaderBlockSent(QuicStreamId stream_id, IndexSet indices); + + // Returns true if sending blocking references on stream |stream_id| would not + // increase the total number of blocked streams above + // |maximum_blocked_streams|. Note that if |stream_id| is already blocked + // then it is always allowed to send more blocking references on it. + // Behavior is undefined if |maximum_blocked_streams| is smaller than number + // of currently blocked streams. + bool blocking_allowed_on_stream(QuicStreamId stream_id, + uint64_t maximum_blocked_streams) const; + + // Returns the index of the blocking entry with the smallest index, + // or std::numeric_limits::max() if there are no blocking entries. + uint64_t smallest_blocking_index() const; + + // Returns the Known Received Count as defined at + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#known-received-count. + uint64_t known_received_count() const { return known_received_count_; } + + // Required Insert Count for set of indices. + static uint64_t RequiredInsertCount(const IndexSet& indices); + + private: + friend test::QpackBlockingManagerPeer; + + // A stream typically has only one header block, except for the rare cases of + // 1xx responses, trailers, or push promises. Even if there are multiple + // header blocks sent on a single stream, they might not be blocked at the + // same time. Use std::list instead of quiche::QuicheCircularDeque because it + // has lower memory footprint when holding few elements. + using HeaderBlocksForStream = std::list; + using HeaderBlocks = absl::flat_hash_map; + + // Increase or decrease the reference count for each index in |indices|. + void IncreaseReferenceCounts(const IndexSet& indices); + void DecreaseReferenceCounts(const IndexSet& indices); + + // Multiset of indices in each header block for each stream. + // Must not contain a stream id with an empty queue. + HeaderBlocks header_blocks_; + + // Number of references in |header_blocks_| for each entry index. + std::map entry_reference_counts_; + + uint64_t known_received_count_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_BLOCKING_MANAGER_H_ diff --git a/quiche/quic/core/qpack/qpack_blocking_manager_test.cc b/quiche/quic/core/qpack/qpack_blocking_manager_test.cc new file mode 100644 index 000000000000..670264ee4c12 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_blocking_manager_test.cc @@ -0,0 +1,319 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_blocking_manager.h" + +#include + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class QpackBlockingManagerPeer { + public: + static bool stream_is_blocked(const QpackBlockingManager* manager, + QuicStreamId stream_id) { + for (const auto& header_blocks_for_stream : manager->header_blocks_) { + if (header_blocks_for_stream.first != stream_id) { + continue; + } + for (const auto& indices : header_blocks_for_stream.second) { + if (QpackBlockingManager::RequiredInsertCount(indices) > + manager->known_received_count_) { + return true; + } + } + } + + return false; + } +}; + +namespace { + +class QpackBlockingManagerTest : public QuicTest { + protected: + QpackBlockingManagerTest() = default; + ~QpackBlockingManagerTest() override = default; + + bool stream_is_blocked(QuicStreamId stream_id) const { + return QpackBlockingManagerPeer::stream_is_blocked(&manager_, stream_id); + } + + QpackBlockingManager manager_; +}; + +TEST_F(QpackBlockingManagerTest, Empty) { + EXPECT_EQ(0u, manager_.known_received_count()); + EXPECT_EQ(std::numeric_limits::max(), + manager_.smallest_blocking_index()); + + EXPECT_FALSE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_FALSE(manager_.OnHeaderAcknowledgement(1)); +} + +TEST_F(QpackBlockingManagerTest, NotBlockedByInsertCountIncrement) { + EXPECT_TRUE(manager_.OnInsertCountIncrement(2)); + + // Stream 0 is not blocked, because it only references entries that are + // already acknowledged by an Insert Count Increment instruction. + manager_.OnHeaderBlockSent(0, {1, 0}); + EXPECT_FALSE(stream_is_blocked(0)); +} + +TEST_F(QpackBlockingManagerTest, UnblockedByInsertCountIncrement) { + manager_.OnHeaderBlockSent(0, {1, 0}); + EXPECT_TRUE(stream_is_blocked(0)); + + EXPECT_TRUE(manager_.OnInsertCountIncrement(2)); + EXPECT_FALSE(stream_is_blocked(0)); +} + +TEST_F(QpackBlockingManagerTest, NotBlockedByHeaderAcknowledgement) { + manager_.OnHeaderBlockSent(0, {2, 1, 1}); + EXPECT_TRUE(stream_is_blocked(0)); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_FALSE(stream_is_blocked(0)); + + // Stream 1 is not blocked, because it only references entries that are + // already acknowledged by a Header Acknowledgement instruction. + manager_.OnHeaderBlockSent(1, {2, 2}); + EXPECT_FALSE(stream_is_blocked(1)); +} + +TEST_F(QpackBlockingManagerTest, UnblockedByHeaderAcknowledgement) { + manager_.OnHeaderBlockSent(0, {2, 1, 1}); + manager_.OnHeaderBlockSent(1, {2, 2}); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_TRUE(stream_is_blocked(1)); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_FALSE(stream_is_blocked(0)); + EXPECT_FALSE(stream_is_blocked(1)); +} + +TEST_F(QpackBlockingManagerTest, KnownReceivedCount) { + EXPECT_EQ(0u, manager_.known_received_count()); + + // Sending a header block does not change Known Received Count. + manager_.OnHeaderBlockSent(0, {0}); + EXPECT_EQ(0u, manager_.known_received_count()); + + manager_.OnHeaderBlockSent(1, {1}); + EXPECT_EQ(0u, manager_.known_received_count()); + + // Header Acknowledgement might increase Known Received Count. + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_EQ(1u, manager_.known_received_count()); + + manager_.OnHeaderBlockSent(2, {5}); + EXPECT_EQ(1u, manager_.known_received_count()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(1)); + EXPECT_EQ(2u, manager_.known_received_count()); + + // Insert Count Increment increases Known Received Count. + EXPECT_TRUE(manager_.OnInsertCountIncrement(2)); + EXPECT_EQ(4u, manager_.known_received_count()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(2)); + EXPECT_EQ(6u, manager_.known_received_count()); + + // Stream Cancellation does not change Known Received Count. + manager_.OnStreamCancellation(0); + EXPECT_EQ(6u, manager_.known_received_count()); + + // Header Acknowledgement of a block with smaller Required Insert Count does + // not increase Known Received Count. + manager_.OnHeaderBlockSent(0, {3}); + EXPECT_EQ(6u, manager_.known_received_count()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_EQ(6u, manager_.known_received_count()); + + // Header Acknowledgement of a block with equal Required Insert Count does not + // increase Known Received Count. + manager_.OnHeaderBlockSent(1, {5}); + EXPECT_EQ(6u, manager_.known_received_count()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(1)); + EXPECT_EQ(6u, manager_.known_received_count()); +} + +TEST_F(QpackBlockingManagerTest, SmallestBlockingIndex) { + EXPECT_EQ(std::numeric_limits::max(), + manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(0, {0}); + EXPECT_EQ(0u, manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(1, {2}); + EXPECT_EQ(0u, manager_.smallest_blocking_index()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_EQ(2u, manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(1, {1}); + EXPECT_EQ(1u, manager_.smallest_blocking_index()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(1)); + EXPECT_EQ(1u, manager_.smallest_blocking_index()); + + // Insert Count Increment does not change smallest blocking index. + EXPECT_TRUE(manager_.OnInsertCountIncrement(2)); + EXPECT_EQ(1u, manager_.smallest_blocking_index()); + + manager_.OnStreamCancellation(1); + EXPECT_EQ(std::numeric_limits::max(), + manager_.smallest_blocking_index()); +} + +TEST_F(QpackBlockingManagerTest, HeaderAcknowledgementsOnSingleStream) { + EXPECT_EQ(0u, manager_.known_received_count()); + EXPECT_EQ(std::numeric_limits::max(), + manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(0, {2, 1, 1}); + EXPECT_EQ(0u, manager_.known_received_count()); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_EQ(1u, manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(0, {1, 0}); + EXPECT_EQ(0u, manager_.known_received_count()); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_EQ(0u, manager_.smallest_blocking_index()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_EQ(3u, manager_.known_received_count()); + EXPECT_FALSE(stream_is_blocked(0)); + EXPECT_EQ(0u, manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(0, {3}); + EXPECT_EQ(3u, manager_.known_received_count()); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_EQ(0u, manager_.smallest_blocking_index()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_EQ(3u, manager_.known_received_count()); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_EQ(3u, manager_.smallest_blocking_index()); + + EXPECT_TRUE(manager_.OnHeaderAcknowledgement(0)); + EXPECT_EQ(4u, manager_.known_received_count()); + EXPECT_FALSE(stream_is_blocked(0)); + EXPECT_EQ(std::numeric_limits::max(), + manager_.smallest_blocking_index()); + + EXPECT_FALSE(manager_.OnHeaderAcknowledgement(0)); +} + +TEST_F(QpackBlockingManagerTest, CancelStream) { + manager_.OnHeaderBlockSent(0, {3}); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_EQ(3u, manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(0, {2}); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_EQ(2u, manager_.smallest_blocking_index()); + + manager_.OnHeaderBlockSent(1, {4}); + EXPECT_TRUE(stream_is_blocked(0)); + EXPECT_TRUE(stream_is_blocked(1)); + EXPECT_EQ(2u, manager_.smallest_blocking_index()); + + manager_.OnStreamCancellation(0); + EXPECT_FALSE(stream_is_blocked(0)); + EXPECT_TRUE(stream_is_blocked(1)); + EXPECT_EQ(4u, manager_.smallest_blocking_index()); + + manager_.OnStreamCancellation(1); + EXPECT_FALSE(stream_is_blocked(0)); + EXPECT_FALSE(stream_is_blocked(1)); + EXPECT_EQ(std::numeric_limits::max(), + manager_.smallest_blocking_index()); +} + +TEST_F(QpackBlockingManagerTest, BlockingAllowedOnStream) { + const QuicStreamId kStreamId1 = 1; + const QuicStreamId kStreamId2 = 2; + const QuicStreamId kStreamId3 = 3; + + // No stream can block if limit is 0. + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId1, 0)); + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId2, 0)); + + // Either stream can block if limit is larger. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 1)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 1)); + + // Doubly block first stream. + manager_.OnHeaderBlockSent(kStreamId1, {0}); + manager_.OnHeaderBlockSent(kStreamId1, {1}); + + // First stream is already blocked so it can carry more blocking references. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 1)); + // Second stream is not allowed to block if limit is already reached. + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId2, 1)); + + // Either stream can block if limit is larger than number of blocked streams. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 2)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 2)); + + // Block second stream. + manager_.OnHeaderBlockSent(kStreamId2, {2}); + + // Streams are already blocked so either can carry more blocking references. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 2)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 2)); + + // Third, unblocked stream is not allowed to block unless limit is strictly + // larger than number of blocked streams. + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId3, 2)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId3, 3)); + + // Acknowledge decoding of first header block on first stream. + // Stream is still blocked on its second header block. + manager_.OnHeaderAcknowledgement(kStreamId1); + + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 2)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 2)); + + // Acknowledge decoding of second header block on first stream. + // This unblocks the stream. + manager_.OnHeaderAcknowledgement(kStreamId1); + + // First stream is not allowed to block if limit is already reached. + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId1, 1)); + // Second stream is already blocked so it can carry more blocking references. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 1)); + + // Either stream can block if limit is larger than number of blocked streams. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 2)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 2)); + + // Unblock second stream. + manager_.OnHeaderAcknowledgement(kStreamId2); + + // No stream can block if limit is 0. + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId1, 0)); + EXPECT_FALSE(manager_.blocking_allowed_on_stream(kStreamId2, 0)); + + // Either stream can block if limit is larger. + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId1, 1)); + EXPECT_TRUE(manager_.blocking_allowed_on_stream(kStreamId2, 1)); +} + +TEST_F(QpackBlockingManagerTest, InsertCountIncrementOverflow) { + EXPECT_TRUE(manager_.OnInsertCountIncrement(10)); + EXPECT_EQ(10u, manager_.known_received_count()); + + EXPECT_FALSE(manager_.OnInsertCountIncrement( + std::numeric_limits::max() - 5)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc b/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc new file mode 100644 index 000000000000..41d64e7bf68e --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +QpackDecodedHeadersAccumulator::QpackDecodedHeadersAccumulator( + QuicStreamId id, QpackDecoder* qpack_decoder, Visitor* visitor, + size_t max_header_list_size) + : decoder_(qpack_decoder->CreateProgressiveDecoder(id, this)), + visitor_(visitor), + max_header_list_size_(max_header_list_size), + uncompressed_header_bytes_including_overhead_(0), + uncompressed_header_bytes_without_overhead_(0), + compressed_header_bytes_(0), + header_list_size_limit_exceeded_(false), + headers_decoded_(false), + error_detected_(false) { + quic_header_list_.OnHeaderBlockStart(); +} + +void QpackDecodedHeadersAccumulator::OnHeaderDecoded(absl::string_view name, + absl::string_view value) { + QUICHE_DCHECK(!error_detected_); + + uncompressed_header_bytes_without_overhead_ += name.size() + value.size(); + + if (header_list_size_limit_exceeded_) { + return; + } + + uncompressed_header_bytes_including_overhead_ += + name.size() + value.size() + kQpackEntrySizeOverhead; + + const size_t uncompressed_header_bytes = + GetQuicFlag(quic_header_size_limit_includes_overhead) + ? uncompressed_header_bytes_including_overhead_ + : uncompressed_header_bytes_without_overhead_; + if (uncompressed_header_bytes > max_header_list_size_) { + header_list_size_limit_exceeded_ = true; + quic_header_list_.Clear(); + } else { + quic_header_list_.OnHeader(name, value); + } +} + +void QpackDecodedHeadersAccumulator::OnDecodingCompleted() { + QUICHE_DCHECK(!headers_decoded_); + QUICHE_DCHECK(!error_detected_); + + headers_decoded_ = true; + + quic_header_list_.OnHeaderBlockEnd( + uncompressed_header_bytes_without_overhead_, compressed_header_bytes_); + + // Might destroy |this|. + visitor_->OnHeadersDecoded(std::move(quic_header_list_), + header_list_size_limit_exceeded_); +} + +void QpackDecodedHeadersAccumulator::OnDecodingErrorDetected( + QuicErrorCode error_code, absl::string_view error_message) { + QUICHE_DCHECK(!error_detected_); + QUICHE_DCHECK(!headers_decoded_); + + error_detected_ = true; + // Might destroy |this|. + visitor_->OnHeaderDecodingError(error_code, error_message); +} + +void QpackDecodedHeadersAccumulator::Decode(absl::string_view data) { + QUICHE_DCHECK(!error_detected_); + + compressed_header_bytes_ += data.size(); + // Might destroy |this|. + decoder_->Decode(data); +} + +void QpackDecodedHeadersAccumulator::EndHeaderBlock() { + QUICHE_DCHECK(!error_detected_); + QUICHE_DCHECK(!headers_decoded_); + + if (!decoder_) { + QUIC_BUG(b215142466_EndHeaderBlock); + return; + } + + // Might destroy |this|. + decoder_->EndHeaderBlock(); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h b/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h new file mode 100644 index 000000000000..0e6e31b88a9d --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h @@ -0,0 +1,104 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_DECODED_HEADERS_ACCUMULATOR_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_DECODED_HEADERS_ACCUMULATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/qpack/qpack_progressive_decoder.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QpackDecoder; + +// A class that creates and owns a QpackProgressiveDecoder instance, accumulates +// decoded headers in a QuicHeaderList, and keeps track of uncompressed and +// compressed size so that it can be passed to +// QuicHeaderList::OnHeaderBlockEnd(). +class QUIC_EXPORT_PRIVATE QpackDecodedHeadersAccumulator + : public QpackProgressiveDecoder::HeadersHandlerInterface { + public: + // Visitor interface to signal success or error. + // Exactly one method will be called. + // Methods may be called synchronously from Decode() and EndHeaderBlock(), + // or asynchronously. + // Method implementations are allowed to destroy |this|. + class QUIC_EXPORT_PRIVATE Visitor { + public: + virtual ~Visitor() = default; + + // Called when headers are successfully decoded. If the uncompressed header + // list size including an overhead for each header field exceeds the limit + // specified via |max_header_list_size| in QpackDecodedHeadersAccumulator + // constructor, then |header_list_size_limit_exceeded| will be true, and + // |headers| will be empty but will still have the correct compressed and + // uncompressed size + // information. + virtual void OnHeadersDecoded(QuicHeaderList headers, + bool header_list_size_limit_exceeded) = 0; + + // Called when an error has occurred. + virtual void OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) = 0; + }; + + QpackDecodedHeadersAccumulator(QuicStreamId id, QpackDecoder* qpack_decoder, + Visitor* visitor, size_t max_header_list_size); + virtual ~QpackDecodedHeadersAccumulator() = default; + + // QpackProgressiveDecoder::HeadersHandlerInterface implementation. + // These methods should only be called by |decoder_|. + void OnHeaderDecoded(absl::string_view name, + absl::string_view value) override; + void OnDecodingCompleted() override; + void OnDecodingErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) override; + + // Decode payload data. + // Must not be called if an error has been detected. + // Must not be called after EndHeaderBlock(). + void Decode(absl::string_view data); + + // Signal end of HEADERS frame. + // Must not be called if an error has been detected. + // Must not be called more that once. + void EndHeaderBlock(); + + private: + std::unique_ptr decoder_; + Visitor* visitor_; + // Maximum header list size including overhead. + size_t max_header_list_size_; + // Uncompressed header list size including overhead, for enforcing the limit. + size_t uncompressed_header_bytes_including_overhead_; + QuicHeaderList quic_header_list_; + // Uncompressed header list size with overhead, + // for passing in to QuicHeaderList::OnHeaderBlockEnd(). + size_t uncompressed_header_bytes_without_overhead_; + // Compressed header list size + // for passing in to QuicHeaderList::OnHeaderBlockEnd(). + size_t compressed_header_bytes_; + + // True if the header size limit has been exceeded. + // Input data is still fed to QpackProgressiveDecoder. + bool header_list_size_limit_exceeded_; + + // The following two members are only used for QUICHE_DCHECKs. + + // True if headers have been completedly and successfully decoded. + bool headers_decoded_; + // True if an error has been detected during decoding. + bool error_detected_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_DECODED_HEADERS_ACCUMULATOR_H_ diff --git a/quiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc b/quiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc new file mode 100644 index 000000000000..e4847c114916 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc @@ -0,0 +1,248 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoded_headers_accumulator.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Pair; +using ::testing::SaveArg; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +// Arbitrary stream ID used for testing. +QuicStreamId kTestStreamId = 1; + +// Limit on header list size. +const size_t kMaxHeaderListSize = 100; + +// Maximum dynamic table capacity. +const size_t kMaxDynamicTableCapacity = 100; + +// Maximum number of blocked streams. +const uint64_t kMaximumBlockedStreams = 1; + +// Header Acknowledgement decoder stream instruction with stream_id = 1. +const char* const kHeaderAcknowledgement = "\x81"; + +class MockVisitor : public QpackDecodedHeadersAccumulator::Visitor { + public: + ~MockVisitor() override = default; + MOCK_METHOD(void, OnHeadersDecoded, + (QuicHeaderList headers, bool header_list_size_limit_exceeded), + (override)); + MOCK_METHOD(void, OnHeaderDecodingError, + (QuicErrorCode error_code, absl::string_view error_message), + (override)); +}; + +} // anonymous namespace + +class QpackDecodedHeadersAccumulatorTest : public QuicTest { + protected: + QpackDecodedHeadersAccumulatorTest() + : qpack_decoder_(kMaxDynamicTableCapacity, kMaximumBlockedStreams, + &encoder_stream_error_delegate_), + accumulator_(kTestStreamId, &qpack_decoder_, &visitor_, + kMaxHeaderListSize) { + qpack_decoder_.set_qpack_stream_sender_delegate( + &decoder_stream_sender_delegate_); + } + + NoopEncoderStreamErrorDelegate encoder_stream_error_delegate_; + StrictMock decoder_stream_sender_delegate_; + QpackDecoder qpack_decoder_; + StrictMock visitor_; + QpackDecodedHeadersAccumulator accumulator_; +}; + +// HEADERS frame payload must have a complete Header Block Prefix. +TEST_F(QpackDecodedHeadersAccumulatorTest, EmptyPayload) { + EXPECT_CALL(visitor_, + OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header data prefix."))); + accumulator_.EndHeaderBlock(); +} + +// HEADERS frame payload must have a complete Header Block Prefix. +TEST_F(QpackDecodedHeadersAccumulatorTest, TruncatedHeaderBlockPrefix) { + accumulator_.Decode(absl::HexStringToBytes("00")); + + EXPECT_CALL(visitor_, + OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header data prefix."))); + accumulator_.EndHeaderBlock(); +} + +TEST_F(QpackDecodedHeadersAccumulatorTest, EmptyHeaderList) { + std::string encoded_data(absl::HexStringToBytes("0000")); + accumulator_.Decode(encoded_data); + + QuicHeaderList header_list; + EXPECT_CALL(visitor_, OnHeadersDecoded(_, false)) + .WillOnce(SaveArg<0>(&header_list)); + accumulator_.EndHeaderBlock(); + + EXPECT_EQ(0u, header_list.uncompressed_header_bytes()); + EXPECT_EQ(encoded_data.size(), header_list.compressed_header_bytes()); + EXPECT_TRUE(header_list.empty()); +} + +// This payload is the prefix of a valid payload, but EndHeaderBlock() is called +// before it can be completely decoded. +TEST_F(QpackDecodedHeadersAccumulatorTest, TruncatedPayload) { + accumulator_.Decode(absl::HexStringToBytes("00002366")); + + EXPECT_CALL(visitor_, OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header block."))); + accumulator_.EndHeaderBlock(); +} + +// This payload is invalid because it refers to a non-existing static entry. +TEST_F(QpackDecodedHeadersAccumulatorTest, InvalidPayload) { + EXPECT_CALL(visitor_, + OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Static table entry not found."))); + accumulator_.Decode(absl::HexStringToBytes("0000ff23ff24")); +} + +TEST_F(QpackDecodedHeadersAccumulatorTest, Success) { + std::string encoded_data(absl::HexStringToBytes("000023666f6f03626172")); + accumulator_.Decode(encoded_data); + + QuicHeaderList header_list; + EXPECT_CALL(visitor_, OnHeadersDecoded(_, false)) + .WillOnce(SaveArg<0>(&header_list)); + accumulator_.EndHeaderBlock(); + + EXPECT_THAT(header_list, ElementsAre(Pair("foo", "bar"))); + EXPECT_EQ(strlen("foo") + strlen("bar"), + header_list.uncompressed_header_bytes()); + EXPECT_EQ(encoded_data.size(), header_list.compressed_header_bytes()); +} + +// Test that Decode() calls are not ignored after header list limit is exceeded, +// otherwise decoding could fail with "incomplete header block" error. +TEST_F(QpackDecodedHeadersAccumulatorTest, ExceedLimitThenSplitInstruction) { + // Total length of header list exceeds kMaxHeaderListSize. + accumulator_.Decode(absl::HexStringToBytes( + "0000" // header block prefix + "26666f6f626172" // header key: "foobar" + "7d61616161616161616161616161616161616161" // header value: 'a' 125 times + "616161616161616161616161616161616161616161616161616161616161616161616161" + "616161616161616161616161616161616161616161616161616161616161616161616161" + "61616161616161616161616161616161616161616161616161616161616161616161" + "ff")); // first byte of a two-byte long Indexed Header Field instruction + accumulator_.Decode(absl::HexStringToBytes( + "0f" // second byte of a two-byte long Indexed Header Field instruction + )); + + EXPECT_CALL(visitor_, OnHeadersDecoded(_, true)); + accumulator_.EndHeaderBlock(); +} + +// Test that header list limit enforcement works with blocked encoding. +TEST_F(QpackDecodedHeadersAccumulatorTest, ExceedLimitBlocked) { + // Total length of header list exceeds kMaxHeaderListSize. + accumulator_.Decode(absl::HexStringToBytes( + "0200" // header block prefix + "80" // reference to dynamic table entry not yet received + "26666f6f626172" // header key: "foobar" + "7d61616161616161616161616161616161616161" // header value: 'a' 125 times + "616161616161616161616161616161616161616161616161616161616161616161616161" + "616161616161616161616161616161616161616161616161616161616161616161616161" + "61616161616161616161616161616161616161616161616161616161616161616161")); + accumulator_.EndHeaderBlock(); + + // Set dynamic table capacity. + qpack_decoder_.OnSetDynamicTableCapacity(kMaxDynamicTableCapacity); + // Adding dynamic table entry unblocks decoding. + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + + EXPECT_CALL(visitor_, OnHeadersDecoded(_, true)); + qpack_decoder_.OnInsertWithoutNameReference("foo", "bar"); +} + +TEST_F(QpackDecodedHeadersAccumulatorTest, BlockedDecoding) { + // Reference to dynamic table entry not yet received. + std::string encoded_data(absl::HexStringToBytes("020080")); + accumulator_.Decode(encoded_data); + accumulator_.EndHeaderBlock(); + + // Set dynamic table capacity. + qpack_decoder_.OnSetDynamicTableCapacity(kMaxDynamicTableCapacity); + // Adding dynamic table entry unblocks decoding. + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + + QuicHeaderList header_list; + EXPECT_CALL(visitor_, OnHeadersDecoded(_, false)) + .WillOnce(SaveArg<0>(&header_list)); + qpack_decoder_.OnInsertWithoutNameReference("foo", "bar"); + + EXPECT_THAT(header_list, ElementsAre(Pair("foo", "bar"))); + EXPECT_EQ(strlen("foo") + strlen("bar"), + header_list.uncompressed_header_bytes()); + EXPECT_EQ(encoded_data.size(), header_list.compressed_header_bytes()); +} + +TEST_F(QpackDecodedHeadersAccumulatorTest, + BlockedDecodingUnblockedBeforeEndOfHeaderBlock) { + // Reference to dynamic table entry not yet received. + accumulator_.Decode(absl::HexStringToBytes("020080")); + + // Set dynamic table capacity. + qpack_decoder_.OnSetDynamicTableCapacity(kMaxDynamicTableCapacity); + // Adding dynamic table entry unblocks decoding. + qpack_decoder_.OnInsertWithoutNameReference("foo", "bar"); + + // Rest of header block: same entry again. + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + accumulator_.Decode(absl::HexStringToBytes("80")); + + QuicHeaderList header_list; + EXPECT_CALL(visitor_, OnHeadersDecoded(_, false)) + .WillOnce(SaveArg<0>(&header_list)); + accumulator_.EndHeaderBlock(); + + EXPECT_THAT(header_list, ElementsAre(Pair("foo", "bar"), Pair("foo", "bar"))); +} + +// Regression test for https://crbug.com/1024263. +TEST_F(QpackDecodedHeadersAccumulatorTest, + BlockedDecodingUnblockedAndErrorBeforeEndOfHeaderBlock) { + // Required Insert Count higher than number of entries causes decoding to be + // blocked. + accumulator_.Decode(absl::HexStringToBytes("0200")); + // Indexed Header Field instruction addressing dynamic table entry with + // relative index 0, absolute index 0. + accumulator_.Decode(absl::HexStringToBytes("80")); + // Relative index larger than or equal to Base is invalid. + accumulator_.Decode(absl::HexStringToBytes("81")); + + // Set dynamic table capacity. + qpack_decoder_.OnSetDynamicTableCapacity(kMaxDynamicTableCapacity); + + // Adding dynamic table entry unblocks decoding. Error is detected. + EXPECT_CALL(visitor_, OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); + qpack_decoder_.OnInsertWithoutNameReference("foo", "bar"); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoder.cc b/quiche/quic/core/qpack/qpack_decoder.cc new file mode 100644 index 000000000000..05698502a81c --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoder.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_index_conversions.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QpackDecoder::QpackDecoder( + uint64_t maximum_dynamic_table_capacity, uint64_t maximum_blocked_streams, + EncoderStreamErrorDelegate* encoder_stream_error_delegate) + : encoder_stream_error_delegate_(encoder_stream_error_delegate), + encoder_stream_receiver_(this), + maximum_blocked_streams_(maximum_blocked_streams), + known_received_count_(0) { + QUICHE_DCHECK(encoder_stream_error_delegate_); + + header_table_.SetMaximumDynamicTableCapacity(maximum_dynamic_table_capacity); +} + +QpackDecoder::~QpackDecoder() {} + +void QpackDecoder::OnStreamReset(QuicStreamId stream_id) { + if (header_table_.maximum_dynamic_table_capacity() > 0) { + decoder_stream_sender_.SendStreamCancellation(stream_id); + decoder_stream_sender_.Flush(); + } +} + +bool QpackDecoder::OnStreamBlocked(QuicStreamId stream_id) { + auto result = blocked_streams_.insert(stream_id); + QUICHE_DCHECK(result.second); + return blocked_streams_.size() <= maximum_blocked_streams_; +} + +void QpackDecoder::OnStreamUnblocked(QuicStreamId stream_id) { + size_t result = blocked_streams_.erase(stream_id); + QUICHE_DCHECK_EQ(1u, result); +} + +void QpackDecoder::OnDecodingCompleted(QuicStreamId stream_id, + uint64_t required_insert_count) { + if (required_insert_count > 0) { + decoder_stream_sender_.SendHeaderAcknowledgement(stream_id); + + if (known_received_count_ < required_insert_count) { + known_received_count_ = required_insert_count; + } + } + + // Send an Insert Count Increment instruction if not all dynamic table entries + // have been acknowledged yet. This is necessary for efficient compression in + // case the encoder chooses not to reference unacknowledged dynamic table + // entries, otherwise inserted entries would never be acknowledged. + if (known_received_count_ < header_table_.inserted_entry_count()) { + decoder_stream_sender_.SendInsertCountIncrement( + header_table_.inserted_entry_count() - known_received_count_); + known_received_count_ = header_table_.inserted_entry_count(); + } + + decoder_stream_sender_.Flush(); +} + +void QpackDecoder::OnInsertWithNameReference(bool is_static, + uint64_t name_index, + absl::string_view value) { + if (is_static) { + auto entry = header_table_.LookupEntry(/* is_static = */ true, name_index); + if (!entry) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INVALID_STATIC_ENTRY, + "Invalid static table entry."); + return; + } + + if (!header_table_.EntryFitsDynamicTableCapacity(entry->name(), value)) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_STATIC, + "Error inserting entry with name reference."); + return; + } + header_table_.InsertEntry(entry->name(), value); + return; + } + + uint64_t absolute_index; + if (!QpackEncoderStreamRelativeIndexToAbsoluteIndex( + name_index, header_table_.inserted_entry_count(), &absolute_index)) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INSERTION_INVALID_RELATIVE_INDEX, + "Invalid relative index."); + return; + } + + const QpackEntry* entry = + header_table_.LookupEntry(/* is_static = */ false, absolute_index); + if (!entry) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INSERTION_DYNAMIC_ENTRY_NOT_FOUND, + "Dynamic table entry not found."); + return; + } + if (!header_table_.EntryFitsDynamicTableCapacity(entry->name(), value)) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_DYNAMIC, + "Error inserting entry with name reference."); + return; + } + header_table_.InsertEntry(entry->name(), value); +} + +void QpackDecoder::OnInsertWithoutNameReference(absl::string_view name, + absl::string_view value) { + if (!header_table_.EntryFitsDynamicTableCapacity(name, value)) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_LITERAL, + "Error inserting literal entry."); + return; + } + header_table_.InsertEntry(name, value); +} + +void QpackDecoder::OnDuplicate(uint64_t index) { + uint64_t absolute_index; + if (!QpackEncoderStreamRelativeIndexToAbsoluteIndex( + index, header_table_.inserted_entry_count(), &absolute_index)) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_DUPLICATE_INVALID_RELATIVE_INDEX, + "Invalid relative index."); + return; + } + + const QpackEntry* entry = + header_table_.LookupEntry(/* is_static = */ false, absolute_index); + if (!entry) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_DUPLICATE_DYNAMIC_ENTRY_NOT_FOUND, + "Dynamic table entry not found."); + return; + } + if (!header_table_.EntryFitsDynamicTableCapacity(entry->name(), + entry->value())) { + // This is impossible since entry was retrieved from the dynamic table. + OnErrorDetected(QUIC_INTERNAL_ERROR, "Error inserting duplicate entry."); + return; + } + header_table_.InsertEntry(entry->name(), entry->value()); +} + +void QpackDecoder::OnSetDynamicTableCapacity(uint64_t capacity) { + if (!header_table_.SetDynamicTableCapacity(capacity)) { + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_SET_DYNAMIC_TABLE_CAPACITY, + "Error updating dynamic table capacity."); + } +} + +void QpackDecoder::OnErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) { + encoder_stream_error_delegate_->OnEncoderStreamError(error_code, + error_message); +} + +std::unique_ptr QpackDecoder::CreateProgressiveDecoder( + QuicStreamId stream_id, + QpackProgressiveDecoder::HeadersHandlerInterface* handler) { + return std::make_unique(stream_id, this, this, + &header_table_, handler); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoder.h b/quiche/quic/core/qpack/qpack_decoder.h new file mode 100644 index 000000000000..9474378631c4 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder.h @@ -0,0 +1,137 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder_stream_sender.h" +#include "quiche/quic/core/qpack/qpack_encoder_stream_receiver.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/core/qpack/qpack_progressive_decoder.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QPACK decoder class. Exactly one instance should exist per QUIC connection. +// This class vends a new QpackProgressiveDecoder instance for each new header +// list to be encoded. +// QpackProgressiveDecoder detects and signals errors with header blocks, which +// are stream errors. +// The only input of QpackDecoder is the encoder stream. Any error QpackDecoder +// signals is an encoder stream error, which is fatal to the connection. +class QUIC_EXPORT_PRIVATE QpackDecoder + : public QpackEncoderStreamReceiver::Delegate, + public QpackProgressiveDecoder::BlockedStreamLimitEnforcer, + public QpackProgressiveDecoder::DecodingCompletedVisitor { + public: + // Interface for receiving notification that an error has occurred on the + // encoder stream. This MUST be treated as a connection error of type + // HTTP_QPACK_ENCODER_STREAM_ERROR. + class QUIC_EXPORT_PRIVATE EncoderStreamErrorDelegate { + public: + virtual ~EncoderStreamErrorDelegate() {} + + virtual void OnEncoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) = 0; + }; + + QpackDecoder(uint64_t maximum_dynamic_table_capacity, + uint64_t maximum_blocked_streams, + EncoderStreamErrorDelegate* encoder_stream_error_delegate); + ~QpackDecoder() override; + + // Signal to the peer's encoder that a stream is reset. This lets the peer's + // encoder know that no more header blocks will be processed on this stream, + // therefore references to dynamic table entries shall not prevent their + // eviction. + // This method should be called regardless of whether a header block is being + // decoded on that stream, because a header block might be in flight from the + // peer. + // This method should be called every time a request or push stream is reset + // for any reason: for example, client cancels request, or a decoding error + // occurs and HeadersHandlerInterface::OnDecodingErrorDetected() is called. + // This method should also be called if the stream is reset by the peer, + // because the peer's encoder can only evict entries referenced by header + // blocks once it receives acknowledgement from this endpoint that the stream + // is reset. + // However, this method should not be called if the stream is closed normally + // using the FIN bit. + void OnStreamReset(QuicStreamId stream_id); + + // QpackProgressiveDecoder::BlockedStreamLimitEnforcer implementation. + bool OnStreamBlocked(QuicStreamId stream_id) override; + void OnStreamUnblocked(QuicStreamId stream_id) override; + + // QpackProgressiveDecoder::DecodingCompletedVisitor implementation. + void OnDecodingCompleted(QuicStreamId stream_id, + uint64_t required_insert_count) override; + + // Factory method to create a QpackProgressiveDecoder for decoding a header + // block. |handler| must remain valid until the returned + // QpackProgressiveDecoder instance is destroyed or the decoder calls + // |handler->OnHeaderBlockEnd()|. + std::unique_ptr CreateProgressiveDecoder( + QuicStreamId stream_id, + QpackProgressiveDecoder::HeadersHandlerInterface* handler); + + // QpackEncoderStreamReceiver::Delegate implementation + void OnInsertWithNameReference(bool is_static, uint64_t name_index, + absl::string_view value) override; + void OnInsertWithoutNameReference(absl::string_view name, + absl::string_view value) override; + void OnDuplicate(uint64_t index) override; + void OnSetDynamicTableCapacity(uint64_t capacity) override; + void OnErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) override; + + // delegate must be set if dynamic table capacity is not zero. + void set_qpack_stream_sender_delegate(QpackStreamSenderDelegate* delegate) { + decoder_stream_sender_.set_qpack_stream_sender_delegate(delegate); + } + + QpackStreamReceiver* encoder_stream_receiver() { + return &encoder_stream_receiver_; + } + + // True if any dynamic table entries have been referenced from a header block. + bool dynamic_table_entry_referenced() const { + return header_table_.dynamic_table_entry_referenced(); + } + + private: + EncoderStreamErrorDelegate* const encoder_stream_error_delegate_; + QpackEncoderStreamReceiver encoder_stream_receiver_; + QpackDecoderStreamSender decoder_stream_sender_; + QpackDecoderHeaderTable header_table_; + std::set blocked_streams_; + const uint64_t maximum_blocked_streams_; + + // Known Received Count is the number of insertions the encoder has received + // acknowledgement for (through Header Acknowledgement and Insert Count + // Increment instructions). The encoder must keep track of it in order to be + // able to send Insert Count Increment instructions. See + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#known-received-count. + uint64_t known_received_count_; +}; + +// QpackDecoder::EncoderStreamErrorDelegate implementation that does nothing. +class QUIC_EXPORT_PRIVATE NoopEncoderStreamErrorDelegate + : public QpackDecoder::EncoderStreamErrorDelegate { + public: + ~NoopEncoderStreamErrorDelegate() override = default; + + void OnEncoderStreamError(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override {} +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_H_ diff --git a/quiche/quic/core/qpack/qpack_decoder_stream_receiver.cc b/quiche/quic/core/qpack/qpack_decoder_stream_receiver.cc new file mode 100644 index 000000000000..699eb4f7cb66 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_stream_receiver.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoder_stream_receiver.h" + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" + +namespace quic { + +QpackDecoderStreamReceiver::QpackDecoderStreamReceiver(Delegate* delegate) + : instruction_decoder_(QpackDecoderStreamLanguage(), this), + delegate_(delegate), + error_detected_(false) { + QUICHE_DCHECK(delegate_); +} + +void QpackDecoderStreamReceiver::Decode(absl::string_view data) { + if (data.empty() || error_detected_) { + return; + } + + instruction_decoder_.Decode(data); +} + +bool QpackDecoderStreamReceiver::OnInstructionDecoded( + const QpackInstruction* instruction) { + if (instruction == InsertCountIncrementInstruction()) { + delegate_->OnInsertCountIncrement(instruction_decoder_.varint()); + return true; + } + + if (instruction == HeaderAcknowledgementInstruction()) { + delegate_->OnHeaderAcknowledgement(instruction_decoder_.varint()); + return true; + } + + QUICHE_DCHECK_EQ(instruction, StreamCancellationInstruction()); + delegate_->OnStreamCancellation(instruction_decoder_.varint()); + return true; +} + +void QpackDecoderStreamReceiver::OnInstructionDecodingError( + QpackInstructionDecoder::ErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(!error_detected_); + + error_detected_ = true; + + // There is no string literals on the decoder stream, + // the only possible error is INTEGER_TOO_LARGE. + QuicErrorCode quic_error_code = + (error_code == QpackInstructionDecoder::ErrorCode::INTEGER_TOO_LARGE) + ? QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE + : QUIC_INTERNAL_ERROR; + delegate_->OnErrorDetected(quic_error_code, error_message); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoder_stream_receiver.h b/quiche/quic/core/qpack/qpack_decoder_stream_receiver.h new file mode 100644 index 000000000000..f78fca94f164 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_stream_receiver.h @@ -0,0 +1,69 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_STREAM_RECEIVER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_STREAM_RECEIVER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instruction_decoder.h" +#include "quiche/quic/core/qpack/qpack_stream_receiver.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This class decodes data received on the decoder stream, +// and passes it along to its Delegate. +class QUIC_EXPORT_PRIVATE QpackDecoderStreamReceiver + : public QpackInstructionDecoder::Delegate, + public QpackStreamReceiver { + public: + // An interface for handling instructions decoded from the decoder stream, see + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#rfc.section.5.3 + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() = default; + + // 5.3.1 Insert Count Increment + virtual void OnInsertCountIncrement(uint64_t increment) = 0; + // 5.3.2 Header Acknowledgement + virtual void OnHeaderAcknowledgement(QuicStreamId stream_id) = 0; + // 5.3.3 Stream Cancellation + virtual void OnStreamCancellation(QuicStreamId stream_id) = 0; + // Decoding error + virtual void OnErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) = 0; + }; + + explicit QpackDecoderStreamReceiver(Delegate* delegate); + QpackDecoderStreamReceiver() = delete; + QpackDecoderStreamReceiver(const QpackDecoderStreamReceiver&) = delete; + QpackDecoderStreamReceiver& operator=(const QpackDecoderStreamReceiver&) = + delete; + + // Implements QpackStreamReceiver::Decode(). + // Decode data and call appropriate Delegate method after each decoded + // instruction. Once an error occurs, Delegate::OnErrorDetected() is called, + // and all further data is ignored. + void Decode(absl::string_view data) override; + + // QpackInstructionDecoder::Delegate implementation. + bool OnInstructionDecoded(const QpackInstruction* instruction) override; + void OnInstructionDecodingError(QpackInstructionDecoder::ErrorCode error_code, + absl::string_view error_message) override; + + private: + QpackInstructionDecoder instruction_decoder_; + Delegate* const delegate_; + + // True if a decoding error has been detected. + bool error_detected_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_STREAM_RECEIVER_H_ diff --git a/quiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc b/quiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc new file mode 100644 index 000000000000..6cee903b4b56 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoder_stream_receiver.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" + +using testing::Eq; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class MockDelegate : public QpackDecoderStreamReceiver::Delegate { + public: + ~MockDelegate() override = default; + + MOCK_METHOD(void, OnInsertCountIncrement, (uint64_t increment), (override)); + MOCK_METHOD(void, OnHeaderAcknowledgement, (QuicStreamId stream_id), + (override)); + MOCK_METHOD(void, OnStreamCancellation, (QuicStreamId stream_id), (override)); + MOCK_METHOD(void, OnErrorDetected, + (QuicErrorCode error_code, absl::string_view error_message), + (override)); +}; + +class QpackDecoderStreamReceiverTest : public QuicTest { + protected: + QpackDecoderStreamReceiverTest() : stream_(&delegate_) {} + ~QpackDecoderStreamReceiverTest() override = default; + + QpackDecoderStreamReceiver stream_; + StrictMock delegate_; +}; + +TEST_F(QpackDecoderStreamReceiverTest, InsertCountIncrement) { + EXPECT_CALL(delegate_, OnInsertCountIncrement(0)); + stream_.Decode(absl::HexStringToBytes("00")); + + EXPECT_CALL(delegate_, OnInsertCountIncrement(10)); + stream_.Decode(absl::HexStringToBytes("0a")); + + EXPECT_CALL(delegate_, OnInsertCountIncrement(63)); + stream_.Decode(absl::HexStringToBytes("3f00")); + + EXPECT_CALL(delegate_, OnInsertCountIncrement(200)); + stream_.Decode(absl::HexStringToBytes("3f8901")); + + EXPECT_CALL(delegate_, + OnErrorDetected(QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + stream_.Decode(absl::HexStringToBytes("3fffffffffffffffffffff")); +} + +TEST_F(QpackDecoderStreamReceiverTest, HeaderAcknowledgement) { + EXPECT_CALL(delegate_, OnHeaderAcknowledgement(0)); + stream_.Decode(absl::HexStringToBytes("80")); + + EXPECT_CALL(delegate_, OnHeaderAcknowledgement(37)); + stream_.Decode(absl::HexStringToBytes("a5")); + + EXPECT_CALL(delegate_, OnHeaderAcknowledgement(127)); + stream_.Decode(absl::HexStringToBytes("ff00")); + + EXPECT_CALL(delegate_, OnHeaderAcknowledgement(503)); + stream_.Decode(absl::HexStringToBytes("fff802")); + + EXPECT_CALL(delegate_, + OnErrorDetected(QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + stream_.Decode(absl::HexStringToBytes("ffffffffffffffffffffff")); +} + +TEST_F(QpackDecoderStreamReceiverTest, StreamCancellation) { + EXPECT_CALL(delegate_, OnStreamCancellation(0)); + stream_.Decode(absl::HexStringToBytes("40")); + + EXPECT_CALL(delegate_, OnStreamCancellation(19)); + stream_.Decode(absl::HexStringToBytes("53")); + + EXPECT_CALL(delegate_, OnStreamCancellation(63)); + stream_.Decode(absl::HexStringToBytes("7f00")); + + EXPECT_CALL(delegate_, OnStreamCancellation(110)); + stream_.Decode(absl::HexStringToBytes("7f2f")); + + EXPECT_CALL(delegate_, + OnErrorDetected(QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + stream_.Decode(absl::HexStringToBytes("7fffffffffffffffffffff")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoder_stream_sender.cc b/quiche/quic/core/qpack/qpack_decoder_stream_sender.cc new file mode 100644 index 000000000000..cc4858768a54 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_stream_sender.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoder_stream_sender.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QpackDecoderStreamSender::QpackDecoderStreamSender() : delegate_(nullptr) {} + +void QpackDecoderStreamSender::SendInsertCountIncrement(uint64_t increment) { + instruction_encoder_.Encode( + QpackInstructionWithValues::InsertCountIncrement(increment), &buffer_); +} + +void QpackDecoderStreamSender::SendHeaderAcknowledgement( + QuicStreamId stream_id) { + instruction_encoder_.Encode( + QpackInstructionWithValues::HeaderAcknowledgement(stream_id), &buffer_); +} + +void QpackDecoderStreamSender::SendStreamCancellation(QuicStreamId stream_id) { + instruction_encoder_.Encode( + QpackInstructionWithValues::StreamCancellation(stream_id), &buffer_); +} + +void QpackDecoderStreamSender::Flush() { + if (buffer_.empty()) { + return; + } + + delegate_->WriteStreamData(buffer_); + buffer_.clear(); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoder_stream_sender.h b/quiche/quic/core/qpack/qpack_decoder_stream_sender.h new file mode 100644 index 000000000000..443e850f4941 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_stream_sender.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_STREAM_SENDER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_STREAM_SENDER_H_ + +#include + +#include "quiche/quic/core/qpack/qpack_instruction_encoder.h" +#include "quiche/quic/core/qpack/qpack_stream_sender_delegate.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This class serializes instructions for transmission on the decoder stream. +// Serialized instructions are buffered until Flush() is called. +class QUIC_EXPORT_PRIVATE QpackDecoderStreamSender { + public: + QpackDecoderStreamSender(); + QpackDecoderStreamSender(const QpackDecoderStreamSender&) = delete; + QpackDecoderStreamSender& operator=(const QpackDecoderStreamSender&) = delete; + + // Methods for serializing and buffering instructions, see + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#rfc.section.5.3 + + // 5.3.1 Insert Count Increment + void SendInsertCountIncrement(uint64_t increment); + // 5.3.2 Header Acknowledgement + void SendHeaderAcknowledgement(QuicStreamId stream_id); + // 5.3.3 Stream Cancellation + void SendStreamCancellation(QuicStreamId stream_id); + + // Writes all buffered instructions on the decoder stream. + void Flush(); + + // delegate must be set if dynamic table capacity is not zero. + void set_qpack_stream_sender_delegate(QpackStreamSenderDelegate* delegate) { + delegate_ = delegate; + } + + private: + QpackStreamSenderDelegate* delegate_; + QpackInstructionEncoder instruction_encoder_; + std::string buffer_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_DECODER_STREAM_SENDER_H_ diff --git a/quiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc b/quiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc new file mode 100644 index 000000000000..8c43af53e1e1 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoder_stream_sender.h" + +#include "absl/strings/escaping.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +using ::testing::Eq; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class QpackDecoderStreamSenderTest : public QuicTest { + protected: + QpackDecoderStreamSenderTest() { + stream_.set_qpack_stream_sender_delegate(&delegate_); + } + ~QpackDecoderStreamSenderTest() override = default; + + StrictMock delegate_; + QpackDecoderStreamSender stream_; +}; + +TEST_F(QpackDecoderStreamSenderTest, InsertCountIncrement) { + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("00")))); + stream_.SendInsertCountIncrement(0); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("0a")))); + stream_.SendInsertCountIncrement(10); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("3f00")))); + stream_.SendInsertCountIncrement(63); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("3f8901")))); + stream_.SendInsertCountIncrement(200); + stream_.Flush(); +} + +TEST_F(QpackDecoderStreamSenderTest, HeaderAcknowledgement) { + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("80")))); + stream_.SendHeaderAcknowledgement(0); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("a5")))); + stream_.SendHeaderAcknowledgement(37); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("ff00")))); + stream_.SendHeaderAcknowledgement(127); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("fff802")))); + stream_.SendHeaderAcknowledgement(503); + stream_.Flush(); +} + +TEST_F(QpackDecoderStreamSenderTest, StreamCancellation) { + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("40")))); + stream_.SendStreamCancellation(0); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("53")))); + stream_.SendStreamCancellation(19); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("7f00")))); + stream_.SendStreamCancellation(63); + stream_.Flush(); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("7f2f")))); + stream_.SendStreamCancellation(110); + stream_.Flush(); +} + +TEST_F(QpackDecoderStreamSenderTest, Coalesce) { + stream_.SendInsertCountIncrement(10); + stream_.SendHeaderAcknowledgement(37); + stream_.SendStreamCancellation(0); + + EXPECT_CALL(delegate_, WriteStreamData(Eq(absl::HexStringToBytes("0aa540")))); + stream_.Flush(); + + stream_.SendInsertCountIncrement(63); + stream_.SendStreamCancellation(110); + + EXPECT_CALL(delegate_, + WriteStreamData(Eq(absl::HexStringToBytes("3f007f2f")))); + stream_.Flush(); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_decoder_test.cc b/quiche/quic/core/qpack/qpack_decoder_test.cc new file mode 100644 index 000000000000..5cfd9602279f --- /dev/null +++ b/quiche/quic/core/qpack/qpack_decoder_test.cc @@ -0,0 +1,979 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_decoder.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +using ::testing::_; +using ::testing::Eq; +using ::testing::Invoke; +using ::testing::Mock; +using ::testing::Sequence; +using ::testing::StrictMock; +using ::testing::Values; + +namespace quic { +namespace test { +namespace { + +// Header Acknowledgement decoder stream instruction with stream_id = 1. +const char* const kHeaderAcknowledgement = "\x81"; + +const uint64_t kMaximumDynamicTableCapacity = 1024; +const uint64_t kMaximumBlockedStreams = 1; + +class QpackDecoderTest : public QuicTestWithParam { + protected: + QpackDecoderTest() + : qpack_decoder_(kMaximumDynamicTableCapacity, kMaximumBlockedStreams, + &encoder_stream_error_delegate_), + fragment_mode_(GetParam()) { + qpack_decoder_.set_qpack_stream_sender_delegate( + &decoder_stream_sender_delegate_); + } + + ~QpackDecoderTest() override = default; + + void SetUp() override { + // Destroy QpackProgressiveDecoder on error to test that it does not crash. + // See https://crbug.com/1025209. + ON_CALL(handler_, OnDecodingErrorDetected(_, _)) + .WillByDefault(Invoke([this](QuicErrorCode /* error_code */, + absl::string_view /* error_message */) { + progressive_decoder_.reset(); + })); + } + + void DecodeEncoderStreamData(absl::string_view data) { + qpack_decoder_.encoder_stream_receiver()->Decode(data); + } + + std::unique_ptr CreateProgressiveDecoder( + QuicStreamId stream_id) { + return qpack_decoder_.CreateProgressiveDecoder(stream_id, &handler_); + } + + // Set up |progressive_decoder_|. + void StartDecoding() { + progressive_decoder_ = CreateProgressiveDecoder(/* stream_id = */ 1); + } + + // Pass header block data to QpackProgressiveDecoder::Decode() + // in fragments dictated by |fragment_mode_|. + void DecodeData(absl::string_view data) { + auto fragment_size_generator = + FragmentModeToFragmentSizeGenerator(fragment_mode_); + while (progressive_decoder_ && !data.empty()) { + size_t fragment_size = std::min(fragment_size_generator(), data.size()); + progressive_decoder_->Decode(data.substr(0, fragment_size)); + data = data.substr(fragment_size); + } + } + + // Signal end of header block to QpackProgressiveDecoder. + void EndDecoding() { + if (progressive_decoder_) { + progressive_decoder_->EndHeaderBlock(); + } + // If no error was detected, |*progressive_decoder_| is kept alive so that + // it can handle callbacks later in case of blocked decoding. + } + + // Decode an entire header block. + void DecodeHeaderBlock(absl::string_view data) { + StartDecoding(); + DecodeData(data); + EndDecoding(); + } + + StrictMock encoder_stream_error_delegate_; + StrictMock decoder_stream_sender_delegate_; + StrictMock handler_; + + private: + QpackDecoder qpack_decoder_; + const FragmentMode fragment_mode_; + std::unique_ptr progressive_decoder_; +}; + +INSTANTIATE_TEST_SUITE_P(All, QpackDecoderTest, + Values(FragmentMode::kSingleChunk, + FragmentMode::kOctetByOctet)); + +TEST_P(QpackDecoderTest, NoPrefix) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header data prefix."))); + + // Header Data Prefix is at least two bytes long. + DecodeHeaderBlock(absl::HexStringToBytes("00")); +} + +// Regression test for https://1025209: QpackProgressiveDecoder must not crash +// in Decode() if it is destroyed by handler_.OnDecodingErrorDetected(). +TEST_P(QpackDecoderTest, InvalidPrefix) { + StartDecoding(); + + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Encoded integer too large."))); + + // Encoded Required Insert Count in Header Data Prefix is too large. + DecodeData(absl::HexStringToBytes("ffffffffffffffffffffffffffff")); +} + +TEST_P(QpackDecoderTest, EmptyHeaderBlock) { + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes("0000")); +} + +TEST_P(QpackDecoderTest, LiteralEntryEmptyName) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(""), Eq("foo"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes("00002003666f6f")); +} + +TEST_P(QpackDecoderTest, LiteralEntryEmptyValue) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq(""))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes("000023666f6f00")); +} + +TEST_P(QpackDecoderTest, LiteralEntryEmptyNameAndValue) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(""), Eq(""))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes("00002000")); +} + +TEST_P(QpackDecoderTest, SimpleLiteralEntry) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes("000023666f6f03626172")); +} + +TEST_P(QpackDecoderTest, MultipleLiteralEntries) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + std::string str(127, 'a'); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foobaar"), absl::string_view(str))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0000" // prefix + "23666f6f03626172" // foo: bar + "2700666f6f62616172" // 7 octet long header name, the smallest number + // that does not fit on a 3-bit prefix. + "7f0061616161616161" // 127 octet long header value, the smallest number + "616161616161616161" // that does not fit on a 7-bit prefix. + "6161616161616161616161616161616161616161616161616161616161616161616161" + "6161616161616161616161616161616161616161616161616161616161616161616161" + "6161616161616161616161616161616161616161616161616161616161616161616161" + "616161616161")); +} + +// Name Length value is too large for varint decoder to decode. +TEST_P(QpackDecoderTest, NameLenTooLargeForVarintDecoder) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Encoded integer too large."))); + + DecodeHeaderBlock(absl::HexStringToBytes("000027ffffffffffffffffffff")); +} + +// Name Length value can be decoded by varint decoder but exceeds 1 MB limit. +TEST_P(QpackDecoderTest, NameLenExceedsLimit) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("String literal too long."))); + + DecodeHeaderBlock(absl::HexStringToBytes("000027ffff7f")); +} + +// Value Length value is too large for varint decoder to decode. +TEST_P(QpackDecoderTest, ValueLenTooLargeForVarintDecoder) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Encoded integer too large."))); + + DecodeHeaderBlock( + absl::HexStringToBytes("000023666f6f7fffffffffffffffffffff")); +} + +// Value Length value can be decoded by varint decoder but exceeds 1 MB limit. +TEST_P(QpackDecoderTest, ValueLenExceedsLimit) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("String literal too long."))); + + DecodeHeaderBlock(absl::HexStringToBytes("000023666f6f7fffff7f")); +} + +TEST_P(QpackDecoderTest, LineFeedInValue) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ba\nr"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + DecodeHeaderBlock(absl::HexStringToBytes("000023666f6f0462610a72")); +} + +TEST_P(QpackDecoderTest, IncompleteHeaderBlock) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header block."))); + + DecodeHeaderBlock(absl::HexStringToBytes("00002366")); +} + +TEST_P(QpackDecoderTest, HuffmanSimple) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("custom-key"), Eq("custom-value"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock( + absl::HexStringToBytes("00002f0125a849e95ba97d7f8925a849e95bb8e8b4bf")); +} + +TEST_P(QpackDecoderTest, AlternatingHuffmanNonHuffman) { + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("custom-key"), Eq("custom-value"))) + .Times(4); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0000" // Prefix. + "2f0125a849e95ba97d7f" // Huffman-encoded name. + "8925a849e95bb8e8b4bf" // Huffman-encoded value. + "2703637573746f6d2d6b6579" // Non-Huffman encoded name. + "0c637573746f6d2d76616c7565" // Non-Huffman encoded value. + "2f0125a849e95ba97d7f" // Huffman-encoded name. + "0c637573746f6d2d76616c7565" // Non-Huffman encoded value. + "2703637573746f6d2d6b6579" // Non-Huffman encoded name. + "8925a849e95bb8e8b4bf")); // Huffman-encoded value. +} + +TEST_P(QpackDecoderTest, HuffmanNameDoesNotHaveEOSPrefix) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); + + // 'y' ends in 0b0 on the most significant bit of the last byte. + // The remaining 7 bits must be a prefix of EOS, which is all 1s. + DecodeHeaderBlock( + absl::HexStringToBytes("00002f0125a849e95ba97d7e8925a849e95bb8e8b4bf")); +} + +TEST_P(QpackDecoderTest, HuffmanValueDoesNotHaveEOSPrefix) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); + + // 'e' ends in 0b101, taking up the 3 most significant bits of the last byte. + // The remaining 5 bits must be a prefix of EOS, which is all 1s. + DecodeHeaderBlock( + absl::HexStringToBytes("00002f0125a849e95ba97d7f8925a849e95bb8e8b4be")); +} + +TEST_P(QpackDecoderTest, HuffmanNameEOSPrefixTooLong) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); + + // The trailing EOS prefix must be at most 7 bits long. Appending one octet + // with value 0xff is invalid, even though 0b111111111111111 (15 bits) is a + // prefix of EOS. + DecodeHeaderBlock( + absl::HexStringToBytes("00002f0225a849e95ba97d7fff8925a849e95bb8e8b4bf")); +} + +TEST_P(QpackDecoderTest, HuffmanValueEOSPrefixTooLong) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); + + // The trailing EOS prefix must be at most 7 bits long. Appending one octet + // with value 0xff is invalid, even though 0b1111111111111 (13 bits) is a + // prefix of EOS. + DecodeHeaderBlock( + absl::HexStringToBytes("00002f0125a849e95ba97d7f8a25a849e95bb8e8b4bfff")); +} + +TEST_P(QpackDecoderTest, StaticTable) { + // A header name that has multiple entries with different values. + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("GET"))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("POST"))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("TRACE"))); + + // A header name that has a single entry with non-empty value. + EXPECT_CALL(handler_, + OnHeaderDecoded(Eq("accept-encoding"), Eq("gzip, deflate, br"))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("accept-encoding"), Eq("compress"))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("accept-encoding"), Eq(""))); + + // A header name that has a single entry with empty value. + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("location"), Eq(""))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("location"), Eq("foo"))); + + EXPECT_CALL(handler_, OnDecodingCompleted()); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0000d1dfccd45f108621e9aec2a11f5c8294e75f000554524143455f1000")); +} + +TEST_P(QpackDecoderTest, TooHighStaticTableIndex) { + // This is the last entry in the static table with index 98. + EXPECT_CALL(handler_, + OnHeaderDecoded(Eq("x-frame-options"), Eq("sameorigin"))); + + // Addressing entry 99 should trigger an error. + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Static table entry not found."))); + + DecodeHeaderBlock(absl::HexStringToBytes("0000ff23ff24")); +} + +TEST_P(QpackDecoderTest, DynamicTable) { + DecodeEncoderStreamData(absl::HexStringToBytes( + "3fe107" // Set dynamic table capacity to 1024. + "6294e703626172" // Add literal entry with name "foo" and value "bar". + "80035a5a5a" // Add entry with name of dynamic table entry index 0 + // (relative index) and value "ZZZ". + "cf8294e7" // Add entry with name of static table entry index 15 + // and value "foo". + "01")); // Duplicate entry with relative index 1. + + // Now there are four entries in the dynamic table. + // Entry 0: "foo", "bar" + // Entry 1: "foo", "ZZZ" + // Entry 2: ":method", "foo" + // Entry 3: "foo", "ZZZ" + + // Use a Sequence to test that mock methods are called in order. + Sequence s; + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ZZZ"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("foo"))) + .InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ZZZ"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("ZZ"))).InSequence(s); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))) + .InSequence(s); + EXPECT_CALL(handler_, OnDecodingCompleted()).InSequence(s); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0500" // Required Insert Count 4 and Delta Base 0. + // Base is 4 + 0 = 4. + "83" // Dynamic table entry with relative index 3, absolute index 0. + "82" // Dynamic table entry with relative index 2, absolute index 1. + "81" // Dynamic table entry with relative index 1, absolute index 2. + "80" // Dynamic table entry with relative index 0, absolute index 3. + "41025a5a")); // Name of entry 1 (relative index) from dynamic table, + // with value "ZZ". + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ZZZ"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("foo"))) + .InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ZZZ"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("ZZ"))).InSequence(s); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))) + .InSequence(s); + EXPECT_CALL(handler_, OnDecodingCompleted()).InSequence(s); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0502" // Required Insert Count 4 and Delta Base 2. + // Base is 4 + 2 = 6. + "85" // Dynamic table entry with relative index 5, absolute index 0. + "84" // Dynamic table entry with relative index 4, absolute index 1. + "83" // Dynamic table entry with relative index 3, absolute index 2. + "82" // Dynamic table entry with relative index 2, absolute index 3. + "43025a5a")); // Name of entry 3 (relative index) from dynamic table, + // with value "ZZ". + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ZZZ"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("foo"))) + .InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("ZZZ"))).InSequence(s); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("ZZ"))).InSequence(s); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))) + .InSequence(s); + EXPECT_CALL(handler_, OnDecodingCompleted()).InSequence(s); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0582" // Required Insert Count 4 and Delta Base 2 with sign bit set. + // Base is 4 - 2 - 1 = 1. + "80" // Dynamic table entry with relative index 0, absolute index 0. + "10" // Dynamic table entry with post-base index 0, absolute index 1. + "11" // Dynamic table entry with post-base index 1, absolute index 2. + "12" // Dynamic table entry with post-base index 2, absolute index 3. + "01025a5a")); // Name of entry 1 (post-base index) from dynamic table, + // with value "ZZ". +} + +TEST_P(QpackDecoderTest, DecreasingDynamicTableCapacityEvictsEntries) { + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "80")); // Dynamic table entry with relative index 0, absolute index 0. + + // Change dynamic table capacity to 32 bytes, smaller than the entry. + // This must cause the entry to be evicted. + DecodeEncoderStreamData(absl::HexStringToBytes("3f01")); + + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Dynamic table entry already evicted."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "80")); // Dynamic table entry with relative index 0, absolute index 0. +} + +TEST_P(QpackDecoderTest, EncoderStreamErrorEntryTooLarge) { + EXPECT_CALL( + encoder_stream_error_delegate_, + OnEncoderStreamError(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_LITERAL, + Eq("Error inserting literal entry."))); + + // Set dynamic table capacity to 34. + DecodeEncoderStreamData(absl::HexStringToBytes("3f03")); + // Add literal entry with name "foo" and value "bar", size is 32 + 3 + 3 = 38. + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); +} + +TEST_P(QpackDecoderTest, EncoderStreamErrorInvalidStaticTableEntry) { + EXPECT_CALL( + encoder_stream_error_delegate_, + OnEncoderStreamError(QUIC_QPACK_ENCODER_STREAM_INVALID_STATIC_ENTRY, + Eq("Invalid static table entry."))); + + // Address invalid static table entry index 99. + DecodeEncoderStreamData(absl::HexStringToBytes("ff2400")); +} + +TEST_P(QpackDecoderTest, EncoderStreamErrorInvalidDynamicTableEntry) { + EXPECT_CALL(encoder_stream_error_delegate_, + OnEncoderStreamError( + QUIC_QPACK_ENCODER_STREAM_INSERTION_INVALID_RELATIVE_INDEX, + Eq("Invalid relative index."))); + + DecodeEncoderStreamData(absl::HexStringToBytes( + "3fe107" // Set dynamic table capacity to 1024. + "6294e703626172" // Add literal entry with name "foo" and value "bar". + "8100")); // Address dynamic table entry with relative index 1. Such + // entry does not exist. The most recently added and only + // dynamic table entry has relative index 0. +} + +TEST_P(QpackDecoderTest, EncoderStreamErrorDuplicateInvalidEntry) { + EXPECT_CALL(encoder_stream_error_delegate_, + OnEncoderStreamError( + QUIC_QPACK_ENCODER_STREAM_DUPLICATE_INVALID_RELATIVE_INDEX, + Eq("Invalid relative index."))); + + DecodeEncoderStreamData(absl::HexStringToBytes( + "3fe107" // Set dynamic table capacity to 1024. + "6294e703626172" // Add literal entry with name "foo" and value "bar". + "01")); // Duplicate dynamic table entry with relative index 1. Such + // entry does not exist. The most recently added and only + // dynamic table entry has relative index 0. +} + +TEST_P(QpackDecoderTest, EncoderStreamErrorTooLargeInteger) { + EXPECT_CALL(encoder_stream_error_delegate_, + OnEncoderStreamError(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + DecodeEncoderStreamData(absl::HexStringToBytes("3fffffffffffffffffffff")); +} + +TEST_P(QpackDecoderTest, InvalidDynamicEntryWhenBaseIsZero) { + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); + + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0280" // Required Insert Count is 1. Base 1 - 1 - 0 = 0 is explicitly + // permitted by the spec. + "80")); // However, addressing entry with relative index 0 would point to + // absolute index -1, which is invalid. +} + +TEST_P(QpackDecoderTest, InvalidNegativeBase) { + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error calculating Base."))); + + // Required Insert Count 1, Delta Base 1 with sign bit set, Base would + // be 1 - 1 - 1 = -1, but it is not allowed to be negative. + DecodeHeaderBlock(absl::HexStringToBytes("0281")); +} + +TEST_P(QpackDecoderTest, InvalidDynamicEntryByRelativeIndex) { + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "81")); // Indexed Header Field instruction addressing relative index 1. + // This is absolute index -1, which is invalid. + + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "4100")); // Literal Header Field with Name Reference instruction + // addressing relative index 1. This is absolute index -1, + // which is invalid. +} + +TEST_P(QpackDecoderTest, EvictedDynamicTableEntry) { + // Update dynamic table capacity to 128. + DecodeEncoderStreamData(absl::HexStringToBytes("3f61")); + + // Add literal entry with name "foo" and value "bar", size 32 + 3 + 3 = 38. + // This fits in the table three times. + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + // Duplicate entry four times. This evicts the first two instances. + DecodeEncoderStreamData(absl::HexStringToBytes("00000000")); + + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Dynamic table entry already evicted."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0500" // Required Insert Count 4 and Delta Base 0. + // Base is 4 + 0 = 4. + "82")); // Indexed Header Field instruction addressing relative index 2. + // This is absolute index 1. Such entry does not exist. + + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Dynamic table entry already evicted."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0500" // Required Insert Count 4 and Delta Base 0. + // Base is 4 + 0 = 4. + "4200")); // Literal Header Field with Name Reference instruction + // addressing relative index 2. This is absolute index 1. Such + // entry does not exist. + + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Dynamic table entry already evicted."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0380" // Required Insert Count 2 and Delta Base 0 with sign bit set. + // Base is 2 - 0 - 1 = 1 + "10")); // Indexed Header Field instruction addressing dynamic table + // entry with post-base index 0, absolute index 1. Such entry + // does not exist. + + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Dynamic table entry already evicted."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0380" // Required Insert Count 2 and Delta Base 0 with sign bit set. + // Base is 2 - 0 - 1 = 1 + "0000")); // Literal Header Field With Name Reference instruction + // addressing dynamic table entry with post-base index 0, + // absolute index 1. Such entry does not exist. +} + +TEST_P(QpackDecoderTest, TableCapacityMustNotExceedMaximum) { + EXPECT_CALL( + encoder_stream_error_delegate_, + OnEncoderStreamError(QUIC_QPACK_ENCODER_STREAM_SET_DYNAMIC_TABLE_CAPACITY, + Eq("Error updating dynamic table capacity."))); + + // Try to update dynamic table capacity to 2048, which exceeds the maximum. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe10f")); +} + +TEST_P(QpackDecoderTest, SetDynamicTableCapacity) { + // Update dynamic table capacity to 128, which does not exceed the maximum. + DecodeEncoderStreamData(absl::HexStringToBytes("3f61")); +} + +TEST_P(QpackDecoderTest, InvalidEncodedRequiredInsertCount) { + // Maximum dynamic table capacity is 1024. + // MaxEntries is 1024 / 32 = 32. + // Required Insert Count is decoded modulo 2 * MaxEntries, that is, modulo 64. + // A value of 1 cannot be encoded as 65 even though it has the same remainder. + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error decoding Required Insert Count."))); + DecodeHeaderBlock(absl::HexStringToBytes("4100")); +} + +// Regression test for https://crbug.com/970218: Decoder must stop processing +// after a Header Block Prefix with an invalid Encoded Required Insert Count. +TEST_P(QpackDecoderTest, DataAfterInvalidEncodedRequiredInsertCount) { + EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error decoding Required Insert Count."))); + // Header Block Prefix followed by some extra data. + DecodeHeaderBlock(absl::HexStringToBytes("410000")); +} + +TEST_P(QpackDecoderTest, WrappedRequiredInsertCount) { + // Maximum dynamic table capacity is 1024. + // MaxEntries is 1024 / 32 = 32. + + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and a 600 byte long value. This will fit + // in the dynamic table once but not twice. + DecodeEncoderStreamData( + absl::HexStringToBytes("6294e7" // Name "foo". + "7fd903")); // Value length 600. + std::string header_value(600, 'Z'); + DecodeEncoderStreamData(header_value); + + // Duplicate most recent entry 200 times. + DecodeEncoderStreamData(std::string(200, '\x00')); + + // Now there is only one entry in the dynamic table, with absolute index 200. + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq(header_value))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + + // Send header block with Required Insert Count = 201. + DecodeHeaderBlock(absl::HexStringToBytes( + "0a00" // Encoded Required Insert Count 10, Required Insert Count 201, + // Delta Base 0, Base 201. + "80")); // Emit dynamic table entry with relative index 0. +} + +TEST_P(QpackDecoderTest, NonZeroRequiredInsertCountButNoDynamicEntries) { + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("GET"))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count is 1. + "d1")); // But the only instruction references the static table. +} + +TEST_P(QpackDecoderTest, AddressEntryNotAllowedByRequiredInsertCount) { + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + + EXPECT_CALL( + handler_, + OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Absolute Index must be smaller than Required Insert Count."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0201" // Required Insert Count 1 and Delta Base 1. + // Base is 1 + 1 = 2. + "80")); // Indexed Header Field instruction addressing dynamic table + // entry with relative index 0, absolute index 1. This is not + // allowed by Required Insert Count. + + EXPECT_CALL( + handler_, + OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Absolute Index must be smaller than Required Insert Count."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0201" // Required Insert Count 1 and Delta Base 1. + // Base is 1 + 1 = 2. + "4000")); // Literal Header Field with Name Reference instruction + // addressing dynamic table entry with relative index 0, + // absolute index 1. This is not allowed by Required Index + // Count. + + EXPECT_CALL( + handler_, + OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Absolute Index must be smaller than Required Insert Count."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "10")); // Indexed Header Field with Post-Base Index instruction + // addressing dynamic table entry with post-base index 0, + // absolute index 1. This is not allowed by Required Insert + // Count. + + EXPECT_CALL( + handler_, + OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Absolute Index must be smaller than Required Insert Count."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "0000")); // Literal Header Field with Post-Base Name Reference + // instruction addressing dynamic table entry with post-base + // index 0, absolute index 1. This is not allowed by Required + // Index Count. +} + +TEST_P(QpackDecoderTest, PromisedRequiredInsertCountLargerThanActual) { + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + // Duplicate entry twice so that decoding of header blocks with Required + // Insert Count not exceeding 3 is not blocked. + DecodeEncoderStreamData(absl::HexStringToBytes("00")); + DecodeEncoderStreamData(absl::HexStringToBytes("00")); + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0300" // Required Insert Count 2 and Delta Base 0. + // Base is 2 + 0 = 2. + "81")); // Indexed Header Field instruction addressing dynamic table + // entry with relative index 1, absolute index 0. Header block + // requires insert count of 1, even though Required Insert Count + // is 2. + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq(""))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0300" // Required Insert Count 2 and Delta Base 0. + // Base is 2 + 0 = 2. + "4100")); // Literal Header Field with Name Reference instruction + // addressing dynamic table entry with relative index 1, + // absolute index 0. Header block requires insert count of 1, + // even though Required Insert Count is 2. + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0481" // Required Insert Count 3 and Delta Base 1 with sign bit set. + // Base is 3 - 1 - 1 = 1. + "10")); // Indexed Header Field with Post-Base Index instruction + // addressing dynamic table entry with post-base index 0, + // absolute index 1. Header block requires insert count of 2, + // even though Required Insert Count is 3. + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq(""))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0481" // Required Insert Count 3 and Delta Base 1 with sign bit set. + // Base is 3 - 1 - 1 = 1. + "0000")); // Literal Header Field with Post-Base Name Reference + // instruction addressing dynamic table entry with post-base + // index 0, absolute index 1. Header block requires insert + // count of 2, even though Required Insert Count is 3. +} + +TEST_P(QpackDecoderTest, BlockedDecoding) { + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "80")); // Indexed Header Field instruction addressing dynamic table + // entry with relative index 0, absolute index 0. + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); +} + +TEST_P(QpackDecoderTest, BlockedDecodingUnblockedBeforeEndOfHeaderBlock) { + StartDecoding(); + DecodeData(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "80" // Indexed Header Field instruction addressing dynamic table + // entry with relative index 0, absolute index 0. + "d1")); // Static table entry with index 17. + + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + + // Add literal entry with name "foo" and value "bar". Decoding is now + // unblocked because dynamic table Insert Count reached the Required Insert + // Count of the header block. |handler_| methods are called immediately for + // the already consumed part of the header block. + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("GET"))); + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + Mock::VerifyAndClearExpectations(&handler_); + + // Rest of header block is processed by QpackProgressiveDecoder + // in the unblocked state. + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":scheme"), Eq("https"))); + DecodeData(absl::HexStringToBytes( + "80" // Indexed Header Field instruction addressing dynamic table + // entry with relative index 0, absolute index 0. + "d7")); // Static table entry with index 23. + Mock::VerifyAndClearExpectations(&handler_); + + EXPECT_CALL(handler_, OnDecodingCompleted()); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + EndDecoding(); +} + +// Regression test for https://crbug.com/1024263. +TEST_P(QpackDecoderTest, + BlockedDecodingUnblockedAndErrorBeforeEndOfHeaderBlock) { + StartDecoding(); + DecodeData(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "80" // Indexed Header Field instruction addressing dynamic table + // entry with relative index 0, absolute index 0. + "81")); // Relative index 1 is equal to Base, therefore invalid. + + // Set dynamic table capacity to 1024. + DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); + + // Add literal entry with name "foo" and value "bar". Decoding is now + // unblocked because dynamic table Insert Count reached the Required Insert + // Count of the header block. |handler_| methods are called immediately for + // the already consumed part of the header block. + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); +} + +// Make sure that Required Insert Count is compared to Insert Count, +// not size of dynamic table. +TEST_P(QpackDecoderTest, BlockedDecodingAndEvictedEntries) { + // Update dynamic table capacity to 128. + // At most three non-empty entries fit in the dynamic table. + DecodeEncoderStreamData(absl::HexStringToBytes("3f61")); + + DecodeHeaderBlock(absl::HexStringToBytes( + "0700" // Required Insert Count 6 and Delta Base 0. + // Base is 6 + 0 = 6. + "80")); // Indexed Header Field instruction addressing dynamic table + // entry with relative index 0, absolute index 5. + + // Add literal entry with name "foo" and value "bar". + DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); + + // Duplicate entry four times. This evicts the first two instances. + DecodeEncoderStreamData(absl::HexStringToBytes("00000000")); + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("baz"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(kHeaderAcknowledgement))); + + // Add literal entry with name "foo" and value "bar". + // Insert Count is now 6, reaching Required Insert Count of the header block. + DecodeEncoderStreamData(absl::HexStringToBytes("6294e70362617a")); +} + +TEST_P(QpackDecoderTest, TooManyBlockedStreams) { + // Required Insert Count 1 and Delta Base 0. + // Without any dynamic table entries received, decoding is blocked. + std::string data = absl::HexStringToBytes("0200"); + + auto progressive_decoder1 = CreateProgressiveDecoder(/* stream_id = */ 1); + progressive_decoder1->Decode(data); + + EXPECT_CALL(handler_, + OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Limit on number of blocked streams exceeded."))); + + auto progressive_decoder2 = CreateProgressiveDecoder(/* stream_id = */ 2); + progressive_decoder2->Decode(data); +} + +TEST_P(QpackDecoderTest, InsertCountIncrement) { + DecodeEncoderStreamData(absl::HexStringToBytes( + "3fe107" // Set dynamic table capacity to 1024. + "6294e703626172" // Add literal entry with name "foo" and value "bar". + "00")); // Duplicate entry. + + EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); + EXPECT_CALL(handler_, OnDecodingCompleted()); + + // Decoder received two insertions, but Header Acknowledgement only increases + // Known Insert Count to one. Decoder should send an Insert Count Increment + // instruction with increment of one to update Known Insert Count to two. + EXPECT_CALL(decoder_stream_sender_delegate_, + WriteStreamData(Eq(absl::HexStringToBytes( + "81" // Header Acknowledgement on stream 1 + "01")))); // Insert Count Increment with increment of one + + DecodeHeaderBlock(absl::HexStringToBytes( + "0200" // Required Insert Count 1 and Delta Base 0. + // Base is 1 + 0 = 1. + "80")); // Dynamic table entry with relative index 0, absolute index 0. +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_encoder.cc b/quiche/quic/core/qpack/qpack_encoder.cc new file mode 100644 index 000000000000..f70fda43061a --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder.cc @@ -0,0 +1,455 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_encoder.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_index_conversions.h" +#include "quiche/quic/core/qpack/qpack_instruction_encoder.h" +#include "quiche/quic/core/qpack/qpack_required_insert_count.h" +#include "quiche/quic/core/qpack/value_splitting_header_list.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Fraction to calculate draining index. The oldest |kDrainingFraction| entries +// will not be referenced in header blocks. A new entry (duplicate or literal +// with name reference) will be added to the dynamic table instead. This allows +// the number of references to the draining entry to go to zero faster, so that +// it can be evicted. See +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#avoiding-blocked-insertions. +// TODO(bnc): Fine tune. +const float kDrainingFraction = 0.25; + +} // anonymous namespace + +QpackEncoder::QpackEncoder( + DecoderStreamErrorDelegate* decoder_stream_error_delegate) + : decoder_stream_error_delegate_(decoder_stream_error_delegate), + decoder_stream_receiver_(this), + maximum_blocked_streams_(0), + header_list_count_(0) { + QUICHE_DCHECK(decoder_stream_error_delegate_); +} + +QpackEncoder::~QpackEncoder() {} + +// static +QpackEncoder::Representation QpackEncoder::EncodeIndexedHeaderField( + bool is_static, uint64_t index, + QpackBlockingManager::IndexSet* referred_indices) { + // Add |index| to |*referred_indices| only if entry is in the dynamic table. + if (!is_static) { + referred_indices->insert(index); + } + return Representation::IndexedHeaderField(is_static, index); +} + +// static +QpackEncoder::Representation +QpackEncoder::EncodeLiteralHeaderFieldWithNameReference( + bool is_static, uint64_t index, absl::string_view value, + QpackBlockingManager::IndexSet* referred_indices) { + // Add |index| to |*referred_indices| only if entry is in the dynamic table. + if (!is_static) { + referred_indices->insert(index); + } + return Representation::LiteralHeaderFieldNameReference(is_static, index, + value); +} + +// static +QpackEncoder::Representation QpackEncoder::EncodeLiteralHeaderField( + absl::string_view name, absl::string_view value) { + return Representation::LiteralHeaderField(name, value); +} + +QpackEncoder::Representations QpackEncoder::FirstPassEncode( + QuicStreamId stream_id, const spdy::Http2HeaderBlock& header_list, + QpackBlockingManager::IndexSet* referred_indices, + QuicByteCount* encoder_stream_sent_byte_count) { + // If previous instructions are buffered in |encoder_stream_sender_|, + // do not count them towards the current header block. + const QuicByteCount initial_encoder_stream_buffered_byte_count = + encoder_stream_sender_.BufferedByteCount(); + + const bool can_write_to_encoder_stream = encoder_stream_sender_.CanWrite(); + + Representations representations; + representations.reserve(header_list.size()); + + // The index of the oldest entry that must not be evicted. + uint64_t smallest_blocking_index = + blocking_manager_.smallest_blocking_index(); + // Entries with index larger than or equal to |known_received_count| are + // blocking. + const uint64_t known_received_count = + blocking_manager_.known_received_count(); + // Only entries with index greater than or equal to |draining_index| are + // allowed to be referenced. + const uint64_t draining_index = + header_table_.draining_index(kDrainingFraction); + // Blocking references are allowed if the number of blocked streams is less + // than the limit. + const bool blocking_allowed = blocking_manager_.blocking_allowed_on_stream( + stream_id, maximum_blocked_streams_); + + // Track events for histograms. + bool dynamic_table_insertion_blocked = false; + bool blocked_stream_limit_exhausted = false; + + for (const auto& header : ValueSplittingHeaderList(&header_list)) { + // These strings are owned by |header_list|. + absl::string_view name = header.first; + absl::string_view value = header.second; + + bool is_static; + uint64_t index; + + auto match_type = + header_table_.FindHeaderField(name, value, &is_static, &index); + + switch (match_type) { + case QpackEncoderHeaderTable::MatchType::kNameAndValue: + if (is_static) { + // Refer to entry directly. + representations.push_back( + EncodeIndexedHeaderField(is_static, index, referred_indices)); + + break; + } + + if (index >= draining_index) { + // If allowed, refer to entry directly. + if (!blocking_allowed && index >= known_received_count) { + blocked_stream_limit_exhausted = true; + } else { + representations.push_back( + EncodeIndexedHeaderField(is_static, index, referred_indices)); + smallest_blocking_index = std::min(smallest_blocking_index, index); + header_table_.set_dynamic_table_entry_referenced(); + + break; + } + } else { + // Entry is draining, needs to be duplicated. + if (!blocking_allowed) { + blocked_stream_limit_exhausted = true; + } else if (QpackEntry::Size(name, value) > + header_table_.MaxInsertSizeWithoutEvictingGivenEntry( + std::min(smallest_blocking_index, index))) { + dynamic_table_insertion_blocked = true; + } else { + if (can_write_to_encoder_stream) { + // If allowed, duplicate entry and refer to it. + encoder_stream_sender_.SendDuplicate( + QpackAbsoluteIndexToEncoderStreamRelativeIndex( + index, header_table_.inserted_entry_count())); + uint64_t new_index = header_table_.InsertEntry(name, value); + representations.push_back(EncodeIndexedHeaderField( + is_static, new_index, referred_indices)); + smallest_blocking_index = + std::min(smallest_blocking_index, index); + header_table_.set_dynamic_table_entry_referenced(); + + break; + } + } + } + + // Encode entry as string literals. + // TODO(b/112770235): Use already acknowledged entry with lower index if + // exists. + // TODO(b/112770235): Use static entry name with literal value if + // dynamic entry exists but cannot be used. + representations.push_back(EncodeLiteralHeaderField(name, value)); + + break; + + case QpackEncoderHeaderTable::MatchType::kName: + if (is_static) { + if (blocking_allowed && + QpackEntry::Size(name, value) <= + header_table_.MaxInsertSizeWithoutEvictingGivenEntry( + smallest_blocking_index)) { + // If allowed, insert entry into dynamic table and refer to it. + if (can_write_to_encoder_stream) { + encoder_stream_sender_.SendInsertWithNameReference(is_static, + index, value); + uint64_t new_index = header_table_.InsertEntry(name, value); + representations.push_back(EncodeIndexedHeaderField( + /* is_static = */ false, new_index, referred_indices)); + smallest_blocking_index = + std::min(smallest_blocking_index, new_index); + + break; + } + } + + // Emit literal field with name reference. + representations.push_back(EncodeLiteralHeaderFieldWithNameReference( + is_static, index, value, referred_indices)); + + break; + } + + if (!blocking_allowed) { + blocked_stream_limit_exhausted = true; + } else if (QpackEntry::Size(name, value) > + header_table_.MaxInsertSizeWithoutEvictingGivenEntry( + std::min(smallest_blocking_index, index))) { + dynamic_table_insertion_blocked = true; + } else { + // If allowed, insert entry with name reference and refer to it. + if (can_write_to_encoder_stream) { + encoder_stream_sender_.SendInsertWithNameReference( + is_static, + QpackAbsoluteIndexToEncoderStreamRelativeIndex( + index, header_table_.inserted_entry_count()), + value); + uint64_t new_index = header_table_.InsertEntry(name, value); + representations.push_back(EncodeIndexedHeaderField( + is_static, new_index, referred_indices)); + smallest_blocking_index = std::min(smallest_blocking_index, index); + header_table_.set_dynamic_table_entry_referenced(); + + break; + } + } + + if ((blocking_allowed || index < known_received_count) && + index >= draining_index) { + // If allowed, refer to entry name directly, with literal value. + representations.push_back(EncodeLiteralHeaderFieldWithNameReference( + is_static, index, value, referred_indices)); + smallest_blocking_index = std::min(smallest_blocking_index, index); + header_table_.set_dynamic_table_entry_referenced(); + + break; + } + + // Encode entry as string literals. + // TODO(b/112770235): Use already acknowledged entry with lower index if + // exists. + // TODO(b/112770235): Use static entry name with literal value if + // dynamic entry exists but cannot be used. + representations.push_back(EncodeLiteralHeaderField(name, value)); + + break; + + case QpackEncoderHeaderTable::MatchType::kNoMatch: + // If allowed, insert entry and refer to it. + if (!blocking_allowed) { + blocked_stream_limit_exhausted = true; + } else if (QpackEntry::Size(name, value) > + header_table_.MaxInsertSizeWithoutEvictingGivenEntry( + smallest_blocking_index)) { + dynamic_table_insertion_blocked = true; + } else { + if (can_write_to_encoder_stream) { + encoder_stream_sender_.SendInsertWithoutNameReference(name, value); + uint64_t new_index = header_table_.InsertEntry(name, value); + representations.push_back(EncodeIndexedHeaderField( + /* is_static = */ false, new_index, referred_indices)); + smallest_blocking_index = + std::min(smallest_blocking_index, new_index); + + break; + } + } + + // Encode entry as string literals. + // TODO(b/112770235): Consider also adding to dynamic table to improve + // compression ratio for subsequent header blocks with peers that do not + // allow any blocked streams. + representations.push_back(EncodeLiteralHeaderField(name, value)); + + break; + } + } + + const QuicByteCount encoder_stream_buffered_byte_count = + encoder_stream_sender_.BufferedByteCount(); + QUICHE_DCHECK_GE(encoder_stream_buffered_byte_count, + initial_encoder_stream_buffered_byte_count); + + if (encoder_stream_sent_byte_count) { + *encoder_stream_sent_byte_count = + encoder_stream_buffered_byte_count - + initial_encoder_stream_buffered_byte_count; + } + if (can_write_to_encoder_stream) { + encoder_stream_sender_.Flush(); + } else { + QUICHE_DCHECK_EQ(encoder_stream_buffered_byte_count, + initial_encoder_stream_buffered_byte_count); + } + + ++header_list_count_; + + if (dynamic_table_insertion_blocked) { + QUIC_HISTOGRAM_COUNTS( + "QuicSession.Qpack.HeaderListCountWhenInsertionBlocked", + header_list_count_, /* min = */ 1, /* max = */ 1000, + /* bucket_count = */ 50, + "The ordinality of a header list within a connection during the " + "encoding of which at least one dynamic table insertion was " + "blocked."); + } else { + QUIC_HISTOGRAM_COUNTS( + "QuicSession.Qpack.HeaderListCountWhenInsertionNotBlocked", + header_list_count_, /* min = */ 1, /* max = */ 1000, + /* bucket_count = */ 50, + "The ordinality of a header list within a connection during the " + "encoding of which no dynamic table insertion was blocked."); + } + + if (blocked_stream_limit_exhausted) { + QUIC_HISTOGRAM_COUNTS( + "QuicSession.Qpack.HeaderListCountWhenBlockedStreamLimited", + header_list_count_, /* min = */ 1, /* max = */ 1000, + /* bucket_count = */ 50, + "The ordinality of a header list within a connection during the " + "encoding of which unacknowledged dynamic table entries could not be " + "referenced due to the limit on the number of blocked streams."); + } else { + QUIC_HISTOGRAM_COUNTS( + "QuicSession.Qpack.HeaderListCountWhenNotBlockedStreamLimited", + header_list_count_, /* min = */ 1, /* max = */ 1000, + /* bucket_count = */ 50, + "The ordinality of a header list within a connection during the " + "encoding of which the limit on the number of blocked streams did " + "not " + "prevent referencing unacknowledged dynamic table entries."); + } + + return representations; +} + +std::string QpackEncoder::SecondPassEncode( + QpackEncoder::Representations representations, + uint64_t required_insert_count) const { + QpackInstructionEncoder instruction_encoder; + std::string encoded_headers; + + // Header block prefix. + instruction_encoder.Encode( + Representation::Prefix(QpackEncodeRequiredInsertCount( + required_insert_count, header_table_.max_entries())), + &encoded_headers); + + const uint64_t base = required_insert_count; + + for (auto& representation : representations) { + // Dynamic table references must be transformed from absolute to relative + // indices. + if ((representation.instruction() == QpackIndexedHeaderFieldInstruction() || + representation.instruction() == + QpackLiteralHeaderFieldNameReferenceInstruction()) && + !representation.s_bit()) { + representation.set_varint(QpackAbsoluteIndexToRequestStreamRelativeIndex( + representation.varint(), base)); + } + instruction_encoder.Encode(representation, &encoded_headers); + } + + return encoded_headers; +} + +std::string QpackEncoder::EncodeHeaderList( + QuicStreamId stream_id, const spdy::Http2HeaderBlock& header_list, + QuicByteCount* encoder_stream_sent_byte_count) { + // Keep track of all dynamic table indices that this header block refers to so + // that it can be passed to QpackBlockingManager. + QpackBlockingManager::IndexSet referred_indices; + + // First pass: encode into |representations|. + Representations representations = + FirstPassEncode(stream_id, header_list, &referred_indices, + encoder_stream_sent_byte_count); + + const uint64_t required_insert_count = + referred_indices.empty() + ? 0 + : QpackBlockingManager::RequiredInsertCount(referred_indices); + if (!referred_indices.empty()) { + blocking_manager_.OnHeaderBlockSent(stream_id, std::move(referred_indices)); + } + + // Second pass. + return SecondPassEncode(std::move(representations), required_insert_count); +} + +bool QpackEncoder::SetMaximumDynamicTableCapacity( + uint64_t maximum_dynamic_table_capacity) { + return header_table_.SetMaximumDynamicTableCapacity( + maximum_dynamic_table_capacity); +} + +void QpackEncoder::SetDynamicTableCapacity(uint64_t dynamic_table_capacity) { + encoder_stream_sender_.SendSetDynamicTableCapacity(dynamic_table_capacity); + // Do not flush encoder stream. This write can safely be delayed until more + // instructions are written. + + bool success = header_table_.SetDynamicTableCapacity(dynamic_table_capacity); + QUICHE_DCHECK(success); +} + +bool QpackEncoder::SetMaximumBlockedStreams(uint64_t maximum_blocked_streams) { + if (maximum_blocked_streams < maximum_blocked_streams_) { + return false; + } + maximum_blocked_streams_ = maximum_blocked_streams; + return true; +} + +void QpackEncoder::OnInsertCountIncrement(uint64_t increment) { + if (increment == 0) { + OnErrorDetected(QUIC_QPACK_DECODER_STREAM_INVALID_ZERO_INCREMENT, + "Invalid increment value 0."); + return; + } + + if (!blocking_manager_.OnInsertCountIncrement(increment)) { + OnErrorDetected(QUIC_QPACK_DECODER_STREAM_INCREMENT_OVERFLOW, + "Insert Count Increment instruction causes overflow."); + } + + if (blocking_manager_.known_received_count() > + header_table_.inserted_entry_count()) { + OnErrorDetected(QUIC_QPACK_DECODER_STREAM_IMPOSSIBLE_INSERT_COUNT, + absl::StrCat("Increment value ", increment, + " raises known received count to ", + blocking_manager_.known_received_count(), + " exceeding inserted entry count ", + header_table_.inserted_entry_count())); + } +} + +void QpackEncoder::OnHeaderAcknowledgement(QuicStreamId stream_id) { + if (!blocking_manager_.OnHeaderAcknowledgement(stream_id)) { + OnErrorDetected( + QUIC_QPACK_DECODER_STREAM_INCORRECT_ACKNOWLEDGEMENT, + absl::StrCat("Header Acknowledgement received for stream ", stream_id, + " with no outstanding header blocks.")); + } +} + +void QpackEncoder::OnStreamCancellation(QuicStreamId stream_id) { + blocking_manager_.OnStreamCancellation(stream_id); +} + +void QpackEncoder::OnErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) { + decoder_stream_error_delegate_->OnDecoderStreamError(error_code, + error_message); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_encoder.h b/quiche/quic/core/qpack/qpack_encoder.h new file mode 100644 index 000000000000..89d08b1a9f6e --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder.h @@ -0,0 +1,170 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_blocking_manager.h" +#include "quiche/quic/core/qpack/qpack_decoder_stream_receiver.h" +#include "quiche/quic/core/qpack/qpack_encoder_stream_sender.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_exported_stats.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace test { + +class QpackEncoderPeer; + +} // namespace test + +// QPACK encoder class. Exactly one instance should exist per QUIC connection. +class QUIC_EXPORT_PRIVATE QpackEncoder + : public QpackDecoderStreamReceiver::Delegate { + public: + // Interface for receiving notification that an error has occurred on the + // decoder stream. This MUST be treated as a connection error of type + // HTTP_QPACK_DECODER_STREAM_ERROR. + class QUIC_EXPORT_PRIVATE DecoderStreamErrorDelegate { + public: + virtual ~DecoderStreamErrorDelegate() {} + + virtual void OnDecoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) = 0; + }; + + QpackEncoder(DecoderStreamErrorDelegate* decoder_stream_error_delegate); + ~QpackEncoder() override; + + // Encode a header list. If |encoder_stream_sent_byte_count| is not null, + // |*encoder_stream_sent_byte_count| will be set to the number of bytes sent + // on the encoder stream to insert dynamic table entries. + std::string EncodeHeaderList(QuicStreamId stream_id, + const spdy::Http2HeaderBlock& header_list, + QuicByteCount* encoder_stream_sent_byte_count); + + // Set maximum dynamic table capacity to |maximum_dynamic_table_capacity|, + // measured in bytes. Called when SETTINGS_QPACK_MAX_TABLE_CAPACITY is + // received. Encoder needs to know this value so that it can calculate + // MaxEntries, used as a modulus to encode Required Insert Count. + // Returns true if |maximum_dynamic_table_capacity| is set for the first time + // or if it doesn't change current value. The setting is not changed when + // returning false. + bool SetMaximumDynamicTableCapacity(uint64_t maximum_dynamic_table_capacity); + + // Set dynamic table capacity to |dynamic_table_capacity|. + // |dynamic_table_capacity| must not exceed maximum dynamic table capacity. + // Also sends Set Dynamic Table Capacity instruction on encoder stream. + void SetDynamicTableCapacity(uint64_t dynamic_table_capacity); + + // Set maximum number of blocked streams. + // Called when SETTINGS_QPACK_BLOCKED_STREAMS is received. + // Returns true if |maximum_blocked_streams| doesn't decrease current value. + // The setting is not changed when returning false. + bool SetMaximumBlockedStreams(uint64_t maximum_blocked_streams); + + // QpackDecoderStreamReceiver::Delegate implementation + void OnInsertCountIncrement(uint64_t increment) override; + void OnHeaderAcknowledgement(QuicStreamId stream_id) override; + void OnStreamCancellation(QuicStreamId stream_id) override; + void OnErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) override; + + // delegate must be set if dynamic table capacity is not zero. + void set_qpack_stream_sender_delegate(QpackStreamSenderDelegate* delegate) { + encoder_stream_sender_.set_qpack_stream_sender_delegate(delegate); + } + + QpackStreamReceiver* decoder_stream_receiver() { + return &decoder_stream_receiver_; + } + + // True if any dynamic table entries have been referenced from a header block. + bool dynamic_table_entry_referenced() const { + return header_table_.dynamic_table_entry_referenced(); + } + + uint64_t maximum_blocked_streams() const { return maximum_blocked_streams_; } + + uint64_t MaximumDynamicTableCapacity() const { + return header_table_.maximum_dynamic_table_capacity(); + } + + private: + friend class test::QpackEncoderPeer; + + using Representation = QpackInstructionWithValues; + using Representations = std::vector; + + // Generate indexed header field representation + // and optionally update |*referred_indices|. + static Representation EncodeIndexedHeaderField( + bool is_static, uint64_t index, + QpackBlockingManager::IndexSet* referred_indices); + + // Generate literal header field with name reference representation + // and optionally update |*referred_indices|. + static Representation EncodeLiteralHeaderFieldWithNameReference( + bool is_static, uint64_t index, absl::string_view value, + QpackBlockingManager::IndexSet* referred_indices); + + // Generate literal header field representation. + static Representation EncodeLiteralHeaderField(absl::string_view name, + absl::string_view value); + + // Performs first pass of two-pass encoding: represent each header field in + // |*header_list| as a reference to an existing entry, the name of an existing + // entry with a literal value, or a literal name and value pair. Sends + // necessary instructions on the encoder stream coalesced in a single write. + // Records absolute indices of referred dynamic table entries in + // |*referred_indices|. If |encoder_stream_sent_byte_count| is not null, then + // sets |*encoder_stream_sent_byte_count| to the number of bytes sent on the + // encoder stream to insert dynamic table entries. Returns list of header + // field representations, with all dynamic table entries referred to with + // absolute indices. Returned representation objects may have + // absl::string_views pointing to strings owned by |*header_list|. + Representations FirstPassEncode( + QuicStreamId stream_id, const spdy::Http2HeaderBlock& header_list, + QpackBlockingManager::IndexSet* referred_indices, + QuicByteCount* encoder_stream_sent_byte_count); + + // Performs second pass of two-pass encoding: serializes representations + // generated in first pass, transforming absolute indices of dynamic table + // entries to relative indices. + std::string SecondPassEncode(Representations representations, + uint64_t required_insert_count) const; + + DecoderStreamErrorDelegate* const decoder_stream_error_delegate_; + QpackDecoderStreamReceiver decoder_stream_receiver_; + QpackEncoderStreamSender encoder_stream_sender_; + QpackEncoderHeaderTable header_table_; + uint64_t maximum_blocked_streams_; + QpackBlockingManager blocking_manager_; + int header_list_count_; +}; + +// QpackEncoder::DecoderStreamErrorDelegate implementation that does nothing. +class QUIC_EXPORT_PRIVATE NoopDecoderStreamErrorDelegate + : public QpackEncoder::DecoderStreamErrorDelegate { + public: + ~NoopDecoderStreamErrorDelegate() override = default; + + void OnDecoderStreamError(QuicErrorCode /*error_code*/, absl::string_view + /*error_message*/) override {} +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_H_ diff --git a/quiche/quic/core/qpack/qpack_encoder_stream_receiver.cc b/quiche/quic/core/qpack/qpack_encoder_stream_receiver.cc new file mode 100644 index 000000000000..579efc212905 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_stream_receiver.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_encoder_stream_receiver.h" + +#include "absl/strings/string_view.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QpackEncoderStreamReceiver::QpackEncoderStreamReceiver(Delegate* delegate) + : instruction_decoder_(QpackEncoderStreamLanguage(), this), + delegate_(delegate), + error_detected_(false) { + QUICHE_DCHECK(delegate_); +} + +void QpackEncoderStreamReceiver::Decode(absl::string_view data) { + if (data.empty() || error_detected_) { + return; + } + + instruction_decoder_.Decode(data); +} + +bool QpackEncoderStreamReceiver::OnInstructionDecoded( + const QpackInstruction* instruction) { + if (instruction == InsertWithNameReferenceInstruction()) { + delegate_->OnInsertWithNameReference(instruction_decoder_.s_bit(), + instruction_decoder_.varint(), + instruction_decoder_.value()); + return true; + } + + if (instruction == InsertWithoutNameReferenceInstruction()) { + delegate_->OnInsertWithoutNameReference(instruction_decoder_.name(), + instruction_decoder_.value()); + return true; + } + + if (instruction == DuplicateInstruction()) { + delegate_->OnDuplicate(instruction_decoder_.varint()); + return true; + } + + QUICHE_DCHECK_EQ(instruction, SetDynamicTableCapacityInstruction()); + delegate_->OnSetDynamicTableCapacity(instruction_decoder_.varint()); + return true; +} + +void QpackEncoderStreamReceiver::OnInstructionDecodingError( + QpackInstructionDecoder::ErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(!error_detected_); + + error_detected_ = true; + + QuicErrorCode quic_error_code; + switch (error_code) { + case QpackInstructionDecoder::ErrorCode::INTEGER_TOO_LARGE: + quic_error_code = QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE; + break; + case QpackInstructionDecoder::ErrorCode::STRING_LITERAL_TOO_LONG: + quic_error_code = QUIC_QPACK_ENCODER_STREAM_STRING_LITERAL_TOO_LONG; + break; + case QpackInstructionDecoder::ErrorCode::HUFFMAN_ENCODING_ERROR: + quic_error_code = QUIC_QPACK_ENCODER_STREAM_HUFFMAN_ENCODING_ERROR; + break; + default: + quic_error_code = QUIC_INTERNAL_ERROR; + } + + delegate_->OnErrorDetected(quic_error_code, error_message); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_encoder_stream_receiver.h b/quiche/quic/core/qpack/qpack_encoder_stream_receiver.h new file mode 100644 index 000000000000..6db91e19d4db --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_stream_receiver.h @@ -0,0 +1,73 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_STREAM_RECEIVER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_STREAM_RECEIVER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instruction_decoder.h" +#include "quiche/quic/core/qpack/qpack_stream_receiver.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This class decodes data received on the encoder stream. +class QUIC_EXPORT_PRIVATE QpackEncoderStreamReceiver + : public QpackInstructionDecoder::Delegate, + public QpackStreamReceiver { + public: + // An interface for handling instructions decoded from the encoder stream, see + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#rfc.section.5.2 + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() = default; + + // 5.2.1. Insert With Name Reference + virtual void OnInsertWithNameReference(bool is_static, uint64_t name_index, + absl::string_view value) = 0; + // 5.2.2. Insert Without Name Reference + virtual void OnInsertWithoutNameReference(absl::string_view name, + absl::string_view value) = 0; + // 5.2.3. Duplicate + virtual void OnDuplicate(uint64_t index) = 0; + // 5.2.4. Set Dynamic Table Capacity + virtual void OnSetDynamicTableCapacity(uint64_t capacity) = 0; + // Decoding error + virtual void OnErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) = 0; + }; + + explicit QpackEncoderStreamReceiver(Delegate* delegate); + QpackEncoderStreamReceiver() = delete; + QpackEncoderStreamReceiver(const QpackEncoderStreamReceiver&) = delete; + QpackEncoderStreamReceiver& operator=(const QpackEncoderStreamReceiver&) = + delete; + ~QpackEncoderStreamReceiver() override = default; + + // Implements QpackStreamReceiver::Decode(). + // Decode data and call appropriate Delegate method after each decoded + // instruction. Once an error occurs, Delegate::OnErrorDetected() is called, + // and all further data is ignored. + void Decode(absl::string_view data) override; + + // QpackInstructionDecoder::Delegate implementation. + bool OnInstructionDecoded(const QpackInstruction* instruction) override; + void OnInstructionDecodingError(QpackInstructionDecoder::ErrorCode error_code, + absl::string_view error_message) override; + + private: + QpackInstructionDecoder instruction_decoder_; + Delegate* const delegate_; + + // True if a decoding error has been detected. + bool error_detected_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_STREAM_RECEIVER_H_ diff --git a/quiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc b/quiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc new file mode 100644 index 000000000000..5ab70427e163 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_encoder_stream_receiver.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" + +using testing::Eq; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class MockDelegate : public QpackEncoderStreamReceiver::Delegate { + public: + ~MockDelegate() override = default; + + MOCK_METHOD(void, OnInsertWithNameReference, + (bool is_static, uint64_t name_index, absl::string_view value), + (override)); + MOCK_METHOD(void, OnInsertWithoutNameReference, + (absl::string_view name, absl::string_view value), (override)); + MOCK_METHOD(void, OnDuplicate, (uint64_t index), (override)); + MOCK_METHOD(void, OnSetDynamicTableCapacity, (uint64_t capacity), (override)); + MOCK_METHOD(void, OnErrorDetected, + (QuicErrorCode error_code, absl::string_view error_message), + (override)); +}; + +class QpackEncoderStreamReceiverTest : public QuicTest { + protected: + QpackEncoderStreamReceiverTest() : stream_(&delegate_) {} + ~QpackEncoderStreamReceiverTest() override = default; + + void Decode(absl::string_view data) { stream_.Decode(data); } + StrictMock* delegate() { return &delegate_; } + + private: + QpackEncoderStreamReceiver stream_; + StrictMock delegate_; +}; + +TEST_F(QpackEncoderStreamReceiverTest, InsertWithNameReference) { + // Static, index fits in prefix, empty value. + EXPECT_CALL(*delegate(), OnInsertWithNameReference(true, 5, Eq(""))); + // Static, index fits in prefix, Huffman encoded value. + EXPECT_CALL(*delegate(), OnInsertWithNameReference(true, 2, Eq("foo"))); + // Not static, index does not fit in prefix, not Huffman encoded value. + EXPECT_CALL(*delegate(), OnInsertWithNameReference(false, 137, Eq("bar"))); + // Value length does not fit in prefix. + // 'Z' would be Huffman encoded to 8 bits, so no Huffman encoding is used. + EXPECT_CALL(*delegate(), + OnInsertWithNameReference(false, 42, Eq(std::string(127, 'Z')))); + + Decode(absl::HexStringToBytes( + "c500" + "c28294e7" + "bf4a03626172" + "aa7f005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a")); +} + +TEST_F(QpackEncoderStreamReceiverTest, InsertWithNameReferenceIndexTooLarge) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + Decode(absl::HexStringToBytes("bfffffffffffffffffffffff")); +} + +TEST_F(QpackEncoderStreamReceiverTest, InsertWithNameReferenceValueTooLong) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + Decode(absl::HexStringToBytes("c57fffffffffffffffffffff")); +} + +TEST_F(QpackEncoderStreamReceiverTest, InsertWithoutNameReference) { + // Empty name and value. + EXPECT_CALL(*delegate(), OnInsertWithoutNameReference(Eq(""), Eq(""))); + // Huffman encoded short strings. + EXPECT_CALL(*delegate(), OnInsertWithoutNameReference(Eq("bar"), Eq("bar"))); + // Not Huffman encoded short strings. + EXPECT_CALL(*delegate(), OnInsertWithoutNameReference(Eq("foo"), Eq("foo"))); + // Not Huffman encoded long strings; length does not fit on prefix. + // 'Z' would be Huffman encoded to 8 bits, so no Huffman encoding is used. + EXPECT_CALL(*delegate(), + OnInsertWithoutNameReference(Eq(std::string(31, 'Z')), + Eq(std::string(127, 'Z')))); + + Decode(absl::HexStringToBytes( + "4000" + "4362617203626172" + "6294e78294e7" + "5f005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a7f005a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a")); +} + +// Name Length value is too large for varint decoder to decode. +TEST_F(QpackEncoderStreamReceiverTest, + InsertWithoutNameReferenceNameTooLongForVarintDecoder) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + Decode(absl::HexStringToBytes("5fffffffffffffffffffff")); +} + +// Name Length value can be decoded by varint decoder but exceeds 1 MB limit. +TEST_F(QpackEncoderStreamReceiverTest, + InsertWithoutNameReferenceNameExceedsLimit) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_STRING_LITERAL_TOO_LONG, + Eq("String literal too long."))); + + Decode(absl::HexStringToBytes("5fffff7f")); +} + +// Value Length value is too large for varint decoder to decode. +TEST_F(QpackEncoderStreamReceiverTest, + InsertWithoutNameReferenceValueTooLongForVarintDecoder) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + Decode(absl::HexStringToBytes("436261727fffffffffffffffffffff")); +} + +// Value Length value can be decoded by varint decoder but exceeds 1 MB limit. +TEST_F(QpackEncoderStreamReceiverTest, + InsertWithoutNameReferenceValueExceedsLimit) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_STRING_LITERAL_TOO_LONG, + Eq("String literal too long."))); + + Decode(absl::HexStringToBytes("436261727fffff7f")); +} + +TEST_F(QpackEncoderStreamReceiverTest, Duplicate) { + // Small index fits in prefix. + EXPECT_CALL(*delegate(), OnDuplicate(17)); + // Large index requires two extension bytes. + EXPECT_CALL(*delegate(), OnDuplicate(500)); + + Decode(absl::HexStringToBytes("111fd503")); +} + +TEST_F(QpackEncoderStreamReceiverTest, DuplicateIndexTooLarge) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + Decode(absl::HexStringToBytes("1fffffffffffffffffffff")); +} + +TEST_F(QpackEncoderStreamReceiverTest, SetDynamicTableCapacity) { + // Small capacity fits in prefix. + EXPECT_CALL(*delegate(), OnSetDynamicTableCapacity(17)); + // Large capacity requires two extension bytes. + EXPECT_CALL(*delegate(), OnSetDynamicTableCapacity(500)); + + Decode(absl::HexStringToBytes("313fd503")); +} + +TEST_F(QpackEncoderStreamReceiverTest, SetDynamicTableCapacityTooLarge) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + Decode(absl::HexStringToBytes("3fffffffffffffffffffff")); +} + +TEST_F(QpackEncoderStreamReceiverTest, InvalidHuffmanEncoding) { + EXPECT_CALL(*delegate(), + OnErrorDetected(QUIC_QPACK_ENCODER_STREAM_HUFFMAN_ENCODING_ERROR, + Eq("Error in Huffman-encoded string."))); + + Decode(absl::HexStringToBytes("c281ff")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_encoder_stream_sender.cc b/quiche/quic/core/qpack/qpack_encoder_stream_sender.cc new file mode 100644 index 000000000000..574b3bb9393d --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_stream_sender.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_encoder_stream_sender.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// If QUIC stream bufferes more that this number of bytes, +// CanWrite() will return false. +constexpr uint64_t kMaxBytesBufferedByStream = 64 * 1024; + +} // anonymous namespace + +QpackEncoderStreamSender::QpackEncoderStreamSender() : delegate_(nullptr) {} + +void QpackEncoderStreamSender::SendInsertWithNameReference( + bool is_static, uint64_t name_index, absl::string_view value) { + instruction_encoder_.Encode( + QpackInstructionWithValues::InsertWithNameReference(is_static, name_index, + value), + &buffer_); +} + +void QpackEncoderStreamSender::SendInsertWithoutNameReference( + absl::string_view name, absl::string_view value) { + instruction_encoder_.Encode( + QpackInstructionWithValues::InsertWithoutNameReference(name, value), + &buffer_); +} + +void QpackEncoderStreamSender::SendDuplicate(uint64_t index) { + instruction_encoder_.Encode(QpackInstructionWithValues::Duplicate(index), + &buffer_); +} + +void QpackEncoderStreamSender::SendSetDynamicTableCapacity(uint64_t capacity) { + instruction_encoder_.Encode( + QpackInstructionWithValues::SetDynamicTableCapacity(capacity), &buffer_); +} + +bool QpackEncoderStreamSender::CanWrite() const { + return delegate_ && delegate_->NumBytesBuffered() + buffer_.size() <= + kMaxBytesBufferedByStream; +} + +void QpackEncoderStreamSender::Flush() { + if (buffer_.empty()) { + return; + } + + delegate_->WriteStreamData(buffer_); + buffer_.clear(); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_encoder_stream_sender.h b/quiche/quic/core/qpack/qpack_encoder_stream_sender.h new file mode 100644 index 000000000000..ab53a1dc318d --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_stream_sender.h @@ -0,0 +1,68 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_STREAM_SENDER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_STREAM_SENDER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instruction_encoder.h" +#include "quiche/quic/core/qpack/qpack_stream_sender_delegate.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This class serializes instructions for transmission on the encoder stream. +// Serialized instructions are buffered until Flush() is called. +class QUIC_EXPORT_PRIVATE QpackEncoderStreamSender { + public: + QpackEncoderStreamSender(); + QpackEncoderStreamSender(const QpackEncoderStreamSender&) = delete; + QpackEncoderStreamSender& operator=(const QpackEncoderStreamSender&) = delete; + + // Methods for serializing and buffering instructions, see + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#rfc.section.5.2 + + // 5.2.1. Insert With Name Reference + void SendInsertWithNameReference(bool is_static, uint64_t name_index, + absl::string_view value); + // 5.2.2. Insert Without Name Reference + void SendInsertWithoutNameReference(absl::string_view name, + absl::string_view value); + // 5.2.3. Duplicate + void SendDuplicate(uint64_t index); + // 5.2.4. Set Dynamic Table Capacity + void SendSetDynamicTableCapacity(uint64_t capacity); + + // Returns number of bytes buffered by this object. + // There is no limit on how much data this object is willing to buffer. + QuicByteCount BufferedByteCount() const { return buffer_.size(); } + + // Returns whether writing to the encoder stream is allowed. Writing is + // disallowed if the amount of data buffered by the underlying stream exceeds + // a hardcoded limit, in order to limit memory consumption in case the encoder + // stream is blocked. CanWrite() returning true does not mean that the + // encoder stream is not blocked, it just means the blocked data does not + // exceed the threshold. + bool CanWrite() const; + + // Writes all buffered instructions on the encoder stream. + void Flush(); + + // delegate must be set if dynamic table capacity is not zero. + void set_qpack_stream_sender_delegate(QpackStreamSenderDelegate* delegate) { + delegate_ = delegate; + } + + private: + QpackStreamSenderDelegate* delegate_; + QpackInstructionEncoder instruction_encoder_; + std::string buffer_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_ENCODER_STREAM_SENDER_H_ diff --git a/quiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc b/quiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc new file mode 100644 index 000000000000..5f90e740c958 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc @@ -0,0 +1,179 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_encoder_stream_sender.h" + +#include "absl/strings/escaping.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +using ::testing::Eq; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class QpackEncoderStreamSenderTest : public QuicTest { + protected: + QpackEncoderStreamSenderTest() { + stream_.set_qpack_stream_sender_delegate(&delegate_); + } + ~QpackEncoderStreamSenderTest() override = default; + + StrictMock delegate_; + QpackEncoderStreamSender stream_; +}; + +TEST_F(QpackEncoderStreamSenderTest, InsertWithNameReference) { + EXPECT_EQ(0u, stream_.BufferedByteCount()); + + // Static, index fits in prefix, empty value. + std::string expected_encoded_data = absl::HexStringToBytes("c500"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithNameReference(true, 5, ""); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Static, index fits in prefix, Huffman encoded value. + expected_encoded_data = absl::HexStringToBytes("c28294e7"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithNameReference(true, 2, "foo"); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Not static, index does not fit in prefix, not Huffman encoded value. + expected_encoded_data = absl::HexStringToBytes("bf4a03626172"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithNameReference(false, 137, "bar"); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Value length does not fit in prefix. + // 'Z' would be Huffman encoded to 8 bits, so no Huffman encoding is used. + expected_encoded_data = absl::HexStringToBytes( + "aa7f005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithNameReference(false, 42, std::string(127, 'Z')); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); +} + +TEST_F(QpackEncoderStreamSenderTest, InsertWithoutNameReference) { + EXPECT_EQ(0u, stream_.BufferedByteCount()); + + // Empty name and value. + std::string expected_encoded_data = absl::HexStringToBytes("4000"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithoutNameReference("", ""); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Huffman encoded short strings. + expected_encoded_data = absl::HexStringToBytes("6294e78294e7"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithoutNameReference("foo", "foo"); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Not Huffman encoded short strings. + expected_encoded_data = absl::HexStringToBytes("4362617203626172"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithoutNameReference("bar", "bar"); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Not Huffman encoded long strings; length does not fit on prefix. + // 'Z' would be Huffman encoded to 8 bits, so no Huffman encoding is used. + expected_encoded_data = absl::HexStringToBytes( + "5f005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a7f" + "005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendInsertWithoutNameReference(std::string(31, 'Z'), + std::string(127, 'Z')); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); +} + +TEST_F(QpackEncoderStreamSenderTest, Duplicate) { + EXPECT_EQ(0u, stream_.BufferedByteCount()); + + // Small index fits in prefix. + std::string expected_encoded_data = absl::HexStringToBytes("11"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendDuplicate(17); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + + // Large index requires two extension bytes. + expected_encoded_data = absl::HexStringToBytes("1fd503"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendDuplicate(500); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); +} + +TEST_F(QpackEncoderStreamSenderTest, SetDynamicTableCapacity) { + EXPECT_EQ(0u, stream_.BufferedByteCount()); + + // Small capacity fits in prefix. + std::string expected_encoded_data = absl::HexStringToBytes("31"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendSetDynamicTableCapacity(17); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + EXPECT_EQ(0u, stream_.BufferedByteCount()); + + // Large capacity requires two extension bytes. + expected_encoded_data = absl::HexStringToBytes("3fd503"); + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + stream_.SendSetDynamicTableCapacity(500); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + EXPECT_EQ(0u, stream_.BufferedByteCount()); +} + +// No writes should happen until Flush is called. +TEST_F(QpackEncoderStreamSenderTest, Coalesce) { + // Insert entry with static name reference, empty value. + stream_.SendInsertWithNameReference(true, 5, ""); + + // Insert entry with static name reference, Huffman encoded value. + stream_.SendInsertWithNameReference(true, 2, "foo"); + + // Insert literal entry, Huffman encoded short strings. + stream_.SendInsertWithoutNameReference("foo", "foo"); + + // Duplicate entry. + stream_.SendDuplicate(17); + + std::string expected_encoded_data = absl::HexStringToBytes( + "c500" // Insert entry with static name reference. + "c28294e7" // Insert entry with static name reference. + "6294e78294e7" // Insert literal entry. + "11"); // Duplicate entry. + + EXPECT_CALL(delegate_, WriteStreamData(Eq(expected_encoded_data))); + EXPECT_EQ(expected_encoded_data.size(), stream_.BufferedByteCount()); + stream_.Flush(); + EXPECT_EQ(0u, stream_.BufferedByteCount()); +} + +// No writes should happen if QpackEncoderStreamSender::Flush() is called +// when the buffer is empty. +TEST_F(QpackEncoderStreamSenderTest, FlushEmpty) { + EXPECT_EQ(0u, stream_.BufferedByteCount()); + stream_.Flush(); + EXPECT_EQ(0u, stream_.BufferedByteCount()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_encoder_test.cc b/quiche/quic/core/qpack/qpack_encoder_test.cc new file mode 100644 index 000000000000..aa40a8ae563a --- /dev/null +++ b/quiche/quic/core/qpack/qpack_encoder_test.cc @@ -0,0 +1,633 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_encoder.h" + +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_encoder_peer.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +using ::testing::_; +using ::testing::Eq; +using ::testing::Return; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +// A number larger than kMaxBytesBufferedByStream in +// qpack_encoder_stream_sender.cc. Returning this value from NumBytesBuffered() +// will instruct QpackEncoder not to generate any instructions for the encoder +// stream. +constexpr uint64_t kTooManyBytesBuffered = 1024 * 1024; + +// Mock QpackEncoder::DecoderStreamErrorDelegate implementation. +class MockDecoderStreamErrorDelegate + : public QpackEncoder::DecoderStreamErrorDelegate { + public: + ~MockDecoderStreamErrorDelegate() override = default; + + MOCK_METHOD(void, OnDecoderStreamError, + (QuicErrorCode error_code, absl::string_view error_message), + (override)); +}; + +class QpackEncoderTest : public QuicTest { + protected: + QpackEncoderTest() + : encoder_(&decoder_stream_error_delegate_), + encoder_stream_sent_byte_count_(0) { + encoder_.set_qpack_stream_sender_delegate(&encoder_stream_sender_delegate_); + encoder_.SetMaximumBlockedStreams(1); + } + + ~QpackEncoderTest() override = default; + + std::string Encode(const spdy::Http2HeaderBlock& header_list) { + return encoder_.EncodeHeaderList(/* stream_id = */ 1, header_list, + &encoder_stream_sent_byte_count_); + } + + StrictMock decoder_stream_error_delegate_; + StrictMock encoder_stream_sender_delegate_; + QpackEncoder encoder_; + QuicByteCount encoder_stream_sent_byte_count_; +}; + +TEST_F(QpackEncoderTest, Empty) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + std::string output = Encode(header_list); + + EXPECT_EQ(absl::HexStringToBytes("0000"), output); +} + +TEST_F(QpackEncoderTest, EmptyName) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + header_list[""] = "foo"; + std::string output = Encode(header_list); + + EXPECT_EQ(absl::HexStringToBytes("0000208294e7"), output); +} + +TEST_F(QpackEncoderTest, EmptyValue) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + header_list["foo"] = ""; + std::string output = Encode(header_list); + + EXPECT_EQ(absl::HexStringToBytes("00002a94e700"), output); +} + +TEST_F(QpackEncoderTest, EmptyNameAndValue) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + header_list[""] = ""; + std::string output = Encode(header_list); + + EXPECT_EQ(absl::HexStringToBytes("00002000"), output); +} + +TEST_F(QpackEncoderTest, Simple) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + header_list["foo"] = "bar"; + std::string output = Encode(header_list); + + EXPECT_EQ(absl::HexStringToBytes("00002a94e703626172"), output); +} + +TEST_F(QpackEncoderTest, Multiple) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + header_list["foo"] = "bar"; + // 'Z' would be Huffman encoded to 8 bits, so no Huffman encoding is used. + header_list["ZZZZZZZ"] = std::string(127, 'Z'); + std::string output = Encode(header_list); + + EXPECT_EQ( + absl::HexStringToBytes( + "0000" // prefix + "2a94e703626172" // foo: bar + "27005a5a5a5a5a5a5a" // 7 octet long header name, the smallest number + // that does not fit on a 3-bit prefix. + "7f005a5a5a5a5a5a5a" // 127 octet long header value, the smallest + "5a5a5a5a5a5a5a5a5a" // number that does not fit on a 7-bit prefix. + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" + "5a5a5a5a5a5a5a5a5a"), + output); +} + +TEST_F(QpackEncoderTest, StaticTable) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + { + spdy::Http2HeaderBlock header_list; + header_list[":method"] = "GET"; + header_list["accept-encoding"] = "gzip, deflate, br"; + header_list["location"] = ""; + + std::string output = Encode(header_list); + EXPECT_EQ(absl::HexStringToBytes("0000d1dfcc"), output); + } + { + spdy::Http2HeaderBlock header_list; + header_list[":method"] = "POST"; + header_list["accept-encoding"] = "compress"; + header_list["location"] = "foo"; + + std::string output = Encode(header_list); + EXPECT_EQ(absl::HexStringToBytes("0000d45f108621e9aec2a11f5c8294e7"), + output); + } + { + spdy::Http2HeaderBlock header_list; + header_list[":method"] = "TRACE"; + header_list["accept-encoding"] = ""; + + std::string output = Encode(header_list); + EXPECT_EQ(absl::HexStringToBytes("00005f000554524143455f1000"), output); + } +} + +TEST_F(QpackEncoderTest, DecoderStreamError) { + EXPECT_CALL(decoder_stream_error_delegate_, + OnDecoderStreamError(QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + + QpackEncoder encoder(&decoder_stream_error_delegate_); + encoder.set_qpack_stream_sender_delegate(&encoder_stream_sender_delegate_); + encoder.decoder_stream_receiver()->Decode( + absl::HexStringToBytes("ffffffffffffffffffffff")); +} + +TEST_F(QpackEncoderTest, SplitAlongNullCharacter) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list; + header_list["foo"] = absl::string_view("bar\0bar\0baz", 11); + std::string output = Encode(header_list); + + EXPECT_EQ(absl::HexStringToBytes("0000" // prefix + "2a94e703626172" // foo: bar + "2a94e703626172" // foo: bar + "2a94e70362617a" // foo: baz + ), + output); +} + +TEST_F(QpackEncoderTest, ZeroInsertCountIncrement) { + // Encoder receives insert count increment with forbidden value 0. + EXPECT_CALL( + decoder_stream_error_delegate_, + OnDecoderStreamError(QUIC_QPACK_DECODER_STREAM_INVALID_ZERO_INCREMENT, + Eq("Invalid increment value 0."))); + encoder_.OnInsertCountIncrement(0); +} + +TEST_F(QpackEncoderTest, TooLargeInsertCountIncrement) { + // Encoder receives insert count increment with value that increases Known + // Received Count to a value (one) which is larger than the number of dynamic + // table insertions sent (zero). + EXPECT_CALL( + decoder_stream_error_delegate_, + OnDecoderStreamError(QUIC_QPACK_DECODER_STREAM_IMPOSSIBLE_INSERT_COUNT, + Eq("Increment value 1 raises known received count " + "to 1 exceeding inserted entry count 0"))); + encoder_.OnInsertCountIncrement(1); +} + +// Regression test for https://crbug.com/1014372. +TEST_F(QpackEncoderTest, InsertCountIncrementOverflow) { + QpackEncoderHeaderTable* header_table = + QpackEncoderPeer::header_table(&encoder_); + + // Set dynamic table capacity large enough to hold one entry. + header_table->SetMaximumDynamicTableCapacity(4096); + header_table->SetDynamicTableCapacity(4096); + // Insert one entry into the header table. + header_table->InsertEntry("foo", "bar"); + + // Receive Insert Count Increment instruction with increment value 1. + encoder_.OnInsertCountIncrement(1); + + // Receive Insert Count Increment instruction that overflows the known + // received count. This must result in an error instead of a crash. + EXPECT_CALL(decoder_stream_error_delegate_, + OnDecoderStreamError( + QUIC_QPACK_DECODER_STREAM_INCREMENT_OVERFLOW, + Eq("Insert Count Increment instruction causes overflow."))); + encoder_.OnInsertCountIncrement(std::numeric_limits::max()); +} + +TEST_F(QpackEncoderTest, InvalidHeaderAcknowledgement) { + // Encoder receives header acknowledgement for a stream on which no header + // block with dynamic table entries was ever sent. + EXPECT_CALL( + decoder_stream_error_delegate_, + OnDecoderStreamError(QUIC_QPACK_DECODER_STREAM_INCORRECT_ACKNOWLEDGEMENT, + Eq("Header Acknowledgement received for stream 0 " + "with no outstanding header blocks."))); + encoder_.OnHeaderAcknowledgement(/* stream_id = */ 0); +} + +TEST_F(QpackEncoderTest, DynamicTable) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + encoder_.SetMaximumBlockedStreams(1); + encoder_.SetMaximumDynamicTableCapacity(4096); + encoder_.SetDynamicTableCapacity(4096); + + spdy::Http2HeaderBlock header_list; + header_list["foo"] = "bar"; + header_list.AppendValueOrAddHeader("foo", + "baz"); // name matches dynamic entry + header_list["cookie"] = "baz"; // name matches static entry + + // Set Dynamic Table Capacity instruction. + std::string set_dyanamic_table_capacity = absl::HexStringToBytes("3fe11f"); + // Insert three entries into the dynamic table. + std::string insert_entries = absl::HexStringToBytes( + "62" // insert without name reference + "94e7" // Huffman-encoded name "foo" + "03626172" // value "bar" + "80" // insert with name reference, dynamic index 0 + "0362617a" // value "baz" + "c5" // insert with name reference, static index 5 + "0362617a"); // value "baz" + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData(Eq( + absl::StrCat(set_dyanamic_table_capacity, insert_entries)))); + + EXPECT_EQ(absl::HexStringToBytes( + "0400" // prefix + "828180"), // dynamic entries with relative index 0, 1, and 2 + Encode(header_list)); + + EXPECT_EQ(insert_entries.size(), encoder_stream_sent_byte_count_); +} + +// There is no room in the dynamic table after inserting the first entry. +TEST_F(QpackEncoderTest, SmallDynamicTable) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + encoder_.SetMaximumBlockedStreams(1); + encoder_.SetMaximumDynamicTableCapacity(QpackEntry::Size("foo", "bar")); + encoder_.SetDynamicTableCapacity(QpackEntry::Size("foo", "bar")); + + spdy::Http2HeaderBlock header_list; + header_list["foo"] = "bar"; + header_list.AppendValueOrAddHeader("foo", + "baz"); // name matches dynamic entry + header_list["cookie"] = "baz"; // name matches static entry + header_list["bar"] = "baz"; // no match + + // Set Dynamic Table Capacity instruction. + std::string set_dyanamic_table_capacity = absl::HexStringToBytes("3f07"); + // Insert one entry into the dynamic table. + std::string insert_entry = absl::HexStringToBytes( + "62" // insert without name reference + "94e7" // Huffman-encoded name "foo" + "03626172"); // value "bar" + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData( + Eq(absl::StrCat(set_dyanamic_table_capacity, insert_entry)))); + + EXPECT_EQ(absl::HexStringToBytes("0200" // prefix + "80" // dynamic entry 0 + "40" // reference to dynamic entry 0 name + "0362617a" // with literal value "baz" + "55" // reference to static entry 5 name + "0362617a" // with literal value "baz" + "23626172" // literal name "bar" + "0362617a"), // with literal value "baz" + Encode(header_list)); + + EXPECT_EQ(insert_entry.size(), encoder_stream_sent_byte_count_); +} + +TEST_F(QpackEncoderTest, BlockedStream) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + encoder_.SetMaximumBlockedStreams(1); + encoder_.SetMaximumDynamicTableCapacity(4096); + encoder_.SetDynamicTableCapacity(4096); + + spdy::Http2HeaderBlock header_list1; + header_list1["foo"] = "bar"; + + // Set Dynamic Table Capacity instruction. + std::string set_dyanamic_table_capacity = absl::HexStringToBytes("3fe11f"); + // Insert one entry into the dynamic table. + std::string insert_entry1 = absl::HexStringToBytes( + "62" // insert without name reference + "94e7" // Huffman-encoded name "foo" + "03626172"); // value "bar" + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData(Eq( + absl::StrCat(set_dyanamic_table_capacity, insert_entry1)))); + + EXPECT_EQ(absl::HexStringToBytes("0200" // prefix + "80"), // dynamic entry 0 + encoder_.EncodeHeaderList(/* stream_id = */ 1, header_list1, + &encoder_stream_sent_byte_count_)); + EXPECT_EQ(insert_entry1.size(), encoder_stream_sent_byte_count_); + + // Stream 1 is blocked. Stream 2 is not allowed to block. + spdy::Http2HeaderBlock header_list2; + header_list2["foo"] = "bar"; // name and value match dynamic entry + header_list2.AppendValueOrAddHeader("foo", + "baz"); // name matches dynamic entry + header_list2["cookie"] = "baz"; // name matches static entry + header_list2["bar"] = "baz"; // no match + + EXPECT_EQ(absl::HexStringToBytes("0000" // prefix + "2a94e7" // literal name "foo" + "03626172" // with literal value "bar" + "2a94e7" // literal name "foo" + "0362617a" // with literal value "baz" + "55" // name of static entry 5 + "0362617a" // with literal value "baz" + "23626172" // literal name "bar" + "0362617a"), // with literal value "baz" + encoder_.EncodeHeaderList(/* stream_id = */ 2, header_list2, + &encoder_stream_sent_byte_count_)); + EXPECT_EQ(0u, encoder_stream_sent_byte_count_); + + // Peer acknowledges receipt of one dynamic table entry. + // Stream 1 is no longer blocked. + encoder_.OnInsertCountIncrement(1); + + // Insert three entries into the dynamic table. + std::string insert_entries = absl::HexStringToBytes( + "80" // insert with name reference, dynamic index 0 + "0362617a" // value "baz" + "c5" // insert with name reference, static index 5 + "0362617a" // value "baz" + "43" // insert without name reference + "626172" // name "bar" + "0362617a"); // value "baz" + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData(Eq(insert_entries))); + + EXPECT_EQ(absl::HexStringToBytes("0500" // prefix + "83828180"), // dynamic entries + encoder_.EncodeHeaderList(/* stream_id = */ 3, header_list2, + &encoder_stream_sent_byte_count_)); + EXPECT_EQ(insert_entries.size(), encoder_stream_sent_byte_count_); + + // Stream 3 is blocked. Stream 4 is not allowed to block, but it can + // reference already acknowledged dynamic entry 0. + EXPECT_EQ(absl::HexStringToBytes("0200" // prefix + "80" // dynamic entry 0 + "2a94e7" // literal name "foo" + "0362617a" // with literal value "baz" + "2c21cfd4c5" // literal name "cookie" + "0362617a" // with literal value "baz" + "23626172" // literal name "bar" + "0362617a"), // with literal value "baz" + encoder_.EncodeHeaderList(/* stream_id = */ 4, header_list2, + &encoder_stream_sent_byte_count_)); + EXPECT_EQ(0u, encoder_stream_sent_byte_count_); + + // Peer acknowledges receipt of two more dynamic table entries. + // Stream 3 is still blocked. + encoder_.OnInsertCountIncrement(2); + + // Stream 5 is not allowed to block, but it can reference already acknowledged + // dynamic entries 0, 1, and 2. + EXPECT_EQ(absl::HexStringToBytes("0400" // prefix + "828180" // dynamic entries + "23626172" // literal name "bar" + "0362617a"), // with literal value "baz" + encoder_.EncodeHeaderList(/* stream_id = */ 5, header_list2, + &encoder_stream_sent_byte_count_)); + EXPECT_EQ(0u, encoder_stream_sent_byte_count_); + + // Peer acknowledges decoding header block on stream 3. + // Stream 3 is not blocked any longer. + encoder_.OnHeaderAcknowledgement(3); + + EXPECT_EQ(absl::HexStringToBytes("0500" // prefix + "83828180"), // dynamic entries + encoder_.EncodeHeaderList(/* stream_id = */ 6, header_list2, + &encoder_stream_sent_byte_count_)); + EXPECT_EQ(0u, encoder_stream_sent_byte_count_); +} + +TEST_F(QpackEncoderTest, Draining) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + spdy::Http2HeaderBlock header_list1; + header_list1["one"] = "foo"; + header_list1["two"] = "foo"; + header_list1["three"] = "foo"; + header_list1["four"] = "foo"; + header_list1["five"] = "foo"; + header_list1["six"] = "foo"; + header_list1["seven"] = "foo"; + header_list1["eight"] = "foo"; + header_list1["nine"] = "foo"; + header_list1["ten"] = "foo"; + + // Make just enough room in the dynamic table for the header list plus the + // first entry duplicated. This will ensure that the oldest entries are + // draining. + uint64_t maximum_dynamic_table_capacity = 0; + for (const auto& header_field : header_list1) { + maximum_dynamic_table_capacity += + QpackEntry::Size(header_field.first, header_field.second); + } + maximum_dynamic_table_capacity += QpackEntry::Size("one", "foo"); + encoder_.SetMaximumDynamicTableCapacity(maximum_dynamic_table_capacity); + encoder_.SetDynamicTableCapacity(maximum_dynamic_table_capacity); + + // Set Dynamic Table Capacity instruction and insert ten entries into the + // dynamic table. + EXPECT_CALL(encoder_stream_sender_delegate_, WriteStreamData(_)); + + EXPECT_EQ(absl::HexStringToBytes("0b00" // prefix + "89888786858483828180"), // dynamic entries + Encode(header_list1)); + + // Entry is identical to oldest one, which is draining. It will be + // duplicated and referenced. + spdy::Http2HeaderBlock header_list2; + header_list2["one"] = "foo"; + + // Duplicate oldest entry. + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData(Eq(absl::HexStringToBytes("09")))); + + EXPECT_EQ(absl::HexStringToBytes("0c00" // prefix + "80"), // most recent dynamic table entry + Encode(header_list2)); + + spdy::Http2HeaderBlock header_list3; + // Entry is identical to second oldest one, which is draining. There is no + // room to duplicate, it will be encoded with string literals. + header_list3.AppendValueOrAddHeader("two", "foo"); + // Entry has name identical to second oldest one, which is draining. There is + // no room to insert new entry, it will be encoded with string literals. + header_list3.AppendValueOrAddHeader("two", "bar"); + + EXPECT_EQ(absl::HexStringToBytes("0000" // prefix + "2374776f" // literal name "two" + "8294e7" // literal value "foo" + "2374776f" // literal name "two" + "03626172"), // literal value "bar" + Encode(header_list3)); +} + +TEST_F(QpackEncoderTest, DynamicTableCapacityLessThanMaximum) { + encoder_.SetMaximumDynamicTableCapacity(1024); + encoder_.SetDynamicTableCapacity(30); + + QpackEncoderHeaderTable* header_table = + QpackEncoderPeer::header_table(&encoder_); + + EXPECT_EQ(1024u, header_table->maximum_dynamic_table_capacity()); + EXPECT_EQ(30u, header_table->dynamic_table_capacity()); +} + +TEST_F(QpackEncoderTest, EncoderStreamWritesDisallowedThenAllowed) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(kTooManyBytesBuffered)); + encoder_.SetMaximumBlockedStreams(1); + encoder_.SetMaximumDynamicTableCapacity(4096); + encoder_.SetDynamicTableCapacity(4096); + + spdy::Http2HeaderBlock header_list1; + header_list1["foo"] = "bar"; + header_list1.AppendValueOrAddHeader("foo", "baz"); + header_list1["cookie"] = "baz"; // name matches static entry + + // Encoder is not allowed to write on the encoder stream. + // No Set Dynamic Table Capacity or Insert instructions are sent. + // Headers are encoded as string literals. + EXPECT_EQ(absl::HexStringToBytes("0000" // prefix + "2a94e7" // literal name "foo" + "03626172" // with literal value "bar" + "2a94e7" // literal name "foo" + "0362617a" // with literal value "baz" + "55" // name of static entry 5 + "0362617a"), // with literal value "baz" + Encode(header_list1)); + + EXPECT_EQ(0u, encoder_stream_sent_byte_count_); + + // If number of bytes buffered by encoder stream goes under the threshold, + // then QpackEncoder will resume emitting encoder stream instructions. + ::testing::Mock::VerifyAndClearExpectations(&encoder_stream_sender_delegate_); + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + + spdy::Http2HeaderBlock header_list2; + header_list2["foo"] = "bar"; + header_list2.AppendValueOrAddHeader("foo", + "baz"); // name matches dynamic entry + header_list2["cookie"] = "baz"; // name matches static entry + + // Set Dynamic Table Capacity instruction. + std::string set_dyanamic_table_capacity = absl::HexStringToBytes("3fe11f"); + // Insert three entries into the dynamic table. + std::string insert_entries = absl::HexStringToBytes( + "62" // insert without name reference + "94e7" // Huffman-encoded name "foo" + "03626172" // value "bar" + "80" // insert with name reference, dynamic index 0 + "0362617a" // value "baz" + "c5" // insert with name reference, static index 5 + "0362617a"); // value "baz" + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData(Eq( + absl::StrCat(set_dyanamic_table_capacity, insert_entries)))); + + EXPECT_EQ(absl::HexStringToBytes( + "0400" // prefix + "828180"), // dynamic entries with relative index 0, 1, and 2 + Encode(header_list2)); + + EXPECT_EQ(insert_entries.size(), encoder_stream_sent_byte_count_); +} + +TEST_F(QpackEncoderTest, EncoderStreamWritesAllowedThenDisallowed) { + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(0)); + encoder_.SetMaximumBlockedStreams(1); + encoder_.SetMaximumDynamicTableCapacity(4096); + encoder_.SetDynamicTableCapacity(4096); + + spdy::Http2HeaderBlock header_list1; + header_list1["foo"] = "bar"; + header_list1.AppendValueOrAddHeader("foo", + "baz"); // name matches dynamic entry + header_list1["cookie"] = "baz"; // name matches static entry + + // Set Dynamic Table Capacity instruction. + std::string set_dyanamic_table_capacity = absl::HexStringToBytes("3fe11f"); + // Insert three entries into the dynamic table. + std::string insert_entries = absl::HexStringToBytes( + "62" // insert without name reference + "94e7" // Huffman-encoded name "foo" + "03626172" // value "bar" + "80" // insert with name reference, dynamic index 0 + "0362617a" // value "baz" + "c5" // insert with name reference, static index 5 + "0362617a"); // value "baz" + EXPECT_CALL(encoder_stream_sender_delegate_, + WriteStreamData(Eq( + absl::StrCat(set_dyanamic_table_capacity, insert_entries)))); + + EXPECT_EQ(absl::HexStringToBytes( + "0400" // prefix + "828180"), // dynamic entries with relative index 0, 1, and 2 + Encode(header_list1)); + + EXPECT_EQ(insert_entries.size(), encoder_stream_sent_byte_count_); + + // If number of bytes buffered by encoder stream goes over the threshold, + // then QpackEncoder will stop emitting encoder stream instructions. + ::testing::Mock::VerifyAndClearExpectations(&encoder_stream_sender_delegate_); + EXPECT_CALL(encoder_stream_sender_delegate_, NumBytesBuffered()) + .WillRepeatedly(Return(kTooManyBytesBuffered)); + + spdy::Http2HeaderBlock header_list2; + header_list2["foo"] = "bar"; // matches previously inserted dynamic entry + header_list2["bar"] = "baz"; + header_list2["cookie"] = "baz"; // name matches static entry + + // Encoder is not allowed to write on the encoder stream. + // No Set Dynamic Table Capacity or Insert instructions are sent. + // Headers are encoded as string literals. + EXPECT_EQ( + absl::HexStringToBytes("0400" // prefix + "82" // dynamic entry with relative index 0 + "23626172" // literal name "bar" + "0362617a" // with literal value "baz" + "80"), // dynamic entry with relative index 2 + Encode(header_list2)); + + EXPECT_EQ(0u, encoder_stream_sent_byte_count_); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_header_table.cc b/quiche/quic/core/qpack/qpack_header_table.cc new file mode 100644 index 000000000000..d5b834df5f3a --- /dev/null +++ b/quiche/quic/core/qpack/qpack_header_table.cc @@ -0,0 +1,239 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_header_table.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_static_table.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +QpackEncoderHeaderTable::QpackEncoderHeaderTable() + : static_index_(ObtainQpackStaticTable().GetStaticIndex()), + static_name_index_(ObtainQpackStaticTable().GetStaticNameIndex()) {} + +uint64_t QpackEncoderHeaderTable::InsertEntry(absl::string_view name, + absl::string_view value) { + const uint64_t index = + QpackHeaderTableBase::InsertEntry(name, value); + + // Make name and value point to the new entry. + name = dynamic_entries().back()->name(); + value = dynamic_entries().back()->value(); + + auto index_result = dynamic_index_.insert( + std::make_pair(QpackLookupEntry{name, value}, index)); + if (!index_result.second) { + // An entry with the same name and value already exists. It needs to be + // replaced, because |dynamic_index_| tracks the most recent entry for a + // given name and value. + QUICHE_DCHECK_GT(index, index_result.first->second); + dynamic_index_.erase(index_result.first); + auto result = dynamic_index_.insert( + std::make_pair(QpackLookupEntry{name, value}, index)); + QUICHE_CHECK(result.second); + } + + auto name_result = dynamic_name_index_.insert({name, index}); + if (!name_result.second) { + // An entry with the same name already exists. It needs to be replaced, + // because |dynamic_name_index_| tracks the most recent entry for a given + // name. + QUICHE_DCHECK_GT(index, name_result.first->second); + dynamic_name_index_.erase(name_result.first); + auto result = dynamic_name_index_.insert({name, index}); + QUICHE_CHECK(result.second); + } + + return index; +} + +QpackEncoderHeaderTable::MatchType QpackEncoderHeaderTable::FindHeaderField( + absl::string_view name, absl::string_view value, bool* is_static, + uint64_t* index) const { + QpackLookupEntry query{name, value}; + + // Look for exact match in static table. + auto index_it = static_index_.find(query); + if (index_it != static_index_.end()) { + *index = index_it->second; + *is_static = true; + return MatchType::kNameAndValue; + } + + // Look for exact match in dynamic table. + index_it = dynamic_index_.find(query); + if (index_it != dynamic_index_.end()) { + *index = index_it->second; + *is_static = false; + return MatchType::kNameAndValue; + } + + // Look for name match in static table. + auto name_index_it = static_name_index_.find(name); + if (name_index_it != static_name_index_.end()) { + *index = name_index_it->second; + *is_static = true; + return MatchType::kName; + } + + // Look for name match in dynamic table. + name_index_it = dynamic_name_index_.find(name); + if (name_index_it != dynamic_name_index_.end()) { + *index = name_index_it->second; + *is_static = false; + return MatchType::kName; + } + + return MatchType::kNoMatch; +} + +uint64_t QpackEncoderHeaderTable::MaxInsertSizeWithoutEvictingGivenEntry( + uint64_t index) const { + QUICHE_DCHECK_LE(dropped_entry_count(), index); + + if (index > inserted_entry_count()) { + // All entries are allowed to be evicted. + return dynamic_table_capacity(); + } + + // Initialize to current available capacity. + uint64_t max_insert_size = dynamic_table_capacity() - dynamic_table_size(); + + uint64_t entry_index = dropped_entry_count(); + for (const auto& entry : dynamic_entries()) { + if (entry_index >= index) { + break; + } + ++entry_index; + max_insert_size += entry->Size(); + } + + return max_insert_size; +} + +uint64_t QpackEncoderHeaderTable::draining_index( + float draining_fraction) const { + QUICHE_DCHECK_LE(0.0, draining_fraction); + QUICHE_DCHECK_LE(draining_fraction, 1.0); + + const uint64_t required_space = draining_fraction * dynamic_table_capacity(); + uint64_t space_above_draining_index = + dynamic_table_capacity() - dynamic_table_size(); + + if (dynamic_entries().empty() || + space_above_draining_index >= required_space) { + return dropped_entry_count(); + } + + auto it = dynamic_entries().begin(); + uint64_t entry_index = dropped_entry_count(); + while (space_above_draining_index < required_space) { + space_above_draining_index += (*it)->Size(); + ++it; + ++entry_index; + if (it == dynamic_entries().end()) { + return inserted_entry_count(); + } + } + + return entry_index; +} + +void QpackEncoderHeaderTable::RemoveEntryFromEnd() { + const QpackEntry* const entry = dynamic_entries().front().get(); + const uint64_t index = dropped_entry_count(); + + auto index_it = dynamic_index_.find({entry->name(), entry->value()}); + // Remove |dynamic_index_| entry only if it points to the same + // QpackEntry in dynamic_entries(). + if (index_it != dynamic_index_.end() && index_it->second == index) { + dynamic_index_.erase(index_it); + } + + auto name_it = dynamic_name_index_.find(entry->name()); + // Remove |dynamic_name_index_| entry only if it points to the same + // QpackEntry in dynamic_entries(). + if (name_it != dynamic_name_index_.end() && name_it->second == index) { + dynamic_name_index_.erase(name_it); + } + + QpackHeaderTableBase::RemoveEntryFromEnd(); +} + +QpackDecoderHeaderTable::QpackDecoderHeaderTable() + : static_entries_(ObtainQpackStaticTable().GetStaticEntries()) {} + +QpackDecoderHeaderTable::~QpackDecoderHeaderTable() { + for (auto& entry : observers_) { + entry.second->Cancel(); + } +} + +uint64_t QpackDecoderHeaderTable::InsertEntry(absl::string_view name, + absl::string_view value) { + const uint64_t index = + QpackHeaderTableBase::InsertEntry(name, value); + + // Notify and deregister observers whose threshold is met, if any. + while (!observers_.empty()) { + auto it = observers_.begin(); + if (it->first > inserted_entry_count()) { + break; + } + Observer* observer = it->second; + observers_.erase(it); + observer->OnInsertCountReachedThreshold(); + } + + return index; +} + +const QpackEntry* QpackDecoderHeaderTable::LookupEntry(bool is_static, + uint64_t index) const { + if (is_static) { + if (index >= static_entries_.size()) { + return nullptr; + } + + return &static_entries_[index]; + } + + if (index < dropped_entry_count()) { + return nullptr; + } + + index -= dropped_entry_count(); + + if (index >= dynamic_entries().size()) { + return nullptr; + } + + return &dynamic_entries()[index]; +} + +void QpackDecoderHeaderTable::RegisterObserver(uint64_t required_insert_count, + Observer* observer) { + QUICHE_DCHECK_GT(required_insert_count, 0u); + observers_.insert({required_insert_count, observer}); +} + +void QpackDecoderHeaderTable::UnregisterObserver(uint64_t required_insert_count, + Observer* observer) { + auto it = observers_.lower_bound(required_insert_count); + while (it != observers_.end() && it->first == required_insert_count) { + if (it->second == observer) { + observers_.erase(it); + return; + } + ++it; + } + + // |observer| must have been registered. + QUICHE_NOTREACHED(); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_header_table.h b/quiche/quic/core/qpack/qpack_header_table.h new file mode 100644 index 000000000000..f882e751cb7a --- /dev/null +++ b/quiche/quic/core/qpack/qpack_header_table.h @@ -0,0 +1,364 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_HEADER_TABLE_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_HEADER_TABLE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/spdy/core/hpack/hpack_entry.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" + +namespace quic { + +using QpackEntry = spdy::HpackEntry; +using QpackLookupEntry = spdy::HpackLookupEntry; +constexpr size_t kQpackEntrySizeOverhead = spdy::kHpackEntrySizeOverhead; + +// Encoder needs pointer stability for |dynamic_index_| and +// |dynamic_name_index_|. However, it does not need random access. +using QpackEncoderDynamicTable = + quiche::QuicheCircularDeque>; + +// Decoder needs random access for LookupEntry(). +// However, it does not need pointer stability. +using QpackDecoderDynamicTable = quiche::QuicheCircularDeque; + +// This is a base class for encoder and decoder classes that manage the QPACK +// static and dynamic tables. For dynamic entries, it only has a concept of +// absolute indices. The caller needs to perform the necessary transformations +// to and from relative indices and post-base indices. +template +class QUIC_EXPORT_PRIVATE QpackHeaderTableBase { + public: + QpackHeaderTableBase(); + QpackHeaderTableBase(const QpackHeaderTableBase&) = delete; + QpackHeaderTableBase& operator=(const QpackHeaderTableBase&) = delete; + + virtual ~QpackHeaderTableBase() = default; + + // Returns whether an entry with |name| and |value| has a size (including + // overhead) that is smaller than or equal to the capacity of the dynamic + // table. + bool EntryFitsDynamicTableCapacity(absl::string_view name, + absl::string_view value) const; + + // Inserts (name, value) into the dynamic table. Entry must not be larger + // than the capacity of the dynamic table. May evict entries. |name| and + // |value| are copied first, therefore it is safe for them to point to an + // entry in the dynamic table, even if it is about to be evicted, or even if + // the underlying container might move entries around when resizing for + // insertion. + // Returns the absolute index of the inserted dynamic table entry. + virtual uint64_t InsertEntry(absl::string_view name, absl::string_view value); + + // Change dynamic table capacity to |capacity|. Returns true on success. + // Returns false is |capacity| exceeds maximum dynamic table capacity. + bool SetDynamicTableCapacity(uint64_t capacity); + + // Set |maximum_dynamic_table_capacity_|. The initial value is zero. The + // final value is determined by the decoder and is sent to the encoder as + // SETTINGS_HEADER_TABLE_SIZE. Therefore in the decoding context the final + // value can be set upon connection establishment, whereas in the encoding + // context it can be set when the SETTINGS frame is received. + // This method must only be called at most once. + // Returns true if |maximum_dynamic_table_capacity| is set for the first time + // or if it doesn't change current value. The setting is not changed when + // returning false. + bool SetMaximumDynamicTableCapacity(uint64_t maximum_dynamic_table_capacity); + + uint64_t dynamic_table_size() const { return dynamic_table_size_; } + uint64_t dynamic_table_capacity() const { return dynamic_table_capacity_; } + uint64_t maximum_dynamic_table_capacity() const { + return maximum_dynamic_table_capacity_; + } + uint64_t max_entries() const { return max_entries_; } + + // The number of entries inserted to the dynamic table (including ones that + // were dropped since). Used for relative indexing on the encoder stream. + uint64_t inserted_entry_count() const { + return dynamic_entries_.size() + dropped_entry_count_; + } + + // The number of entries dropped from the dynamic table. + uint64_t dropped_entry_count() const { return dropped_entry_count_; } + + void set_dynamic_table_entry_referenced() { + dynamic_table_entry_referenced_ = true; + } + bool dynamic_table_entry_referenced() const { + return dynamic_table_entry_referenced_; + } + + protected: + // Removes a single entry from the end of the dynamic table, updates + // |dynamic_table_size_| and |dropped_entry_count_|. + virtual void RemoveEntryFromEnd(); + + const DynamicEntryTable& dynamic_entries() const { return dynamic_entries_; } + + private: + // Evict entries from the dynamic table until table size is less than or equal + // to |capacity|. + void EvictDownToCapacity(uint64_t capacity); + + // Dynamic Table entries. + DynamicEntryTable dynamic_entries_; + + // Size of the dynamic table. This is the sum of the size of its entries. + uint64_t dynamic_table_size_; + + // Dynamic Table Capacity is the maximum allowed value of + // |dynamic_table_size_|. Entries are evicted if necessary before inserting a + // new entry to ensure that dynamic table size never exceeds capacity. + // Initial value is |maximum_dynamic_table_capacity_|. Capacity can be + // changed by the encoder, as long as it does not exceed + // |maximum_dynamic_table_capacity_|. + uint64_t dynamic_table_capacity_; + + // Maximum allowed value of |dynamic_table_capacity|. The initial value is + // zero. Can be changed by SetMaximumDynamicTableCapacity(). + uint64_t maximum_dynamic_table_capacity_; + + // MaxEntries, see Section 3.2.2. Calculated based on + // |maximum_dynamic_table_capacity_|. Used on request streams to encode and + // decode Required Insert Count. + uint64_t max_entries_; + + // The number of entries dropped from the dynamic table. + uint64_t dropped_entry_count_; + + // True if any dynamic table entries have been referenced from a header block. + // Set directly by the encoder or decoder. Used for stats. + bool dynamic_table_entry_referenced_; +}; + +template +QpackHeaderTableBase::QpackHeaderTableBase() + : dynamic_table_size_(0), + dynamic_table_capacity_(0), + maximum_dynamic_table_capacity_(0), + max_entries_(0), + dropped_entry_count_(0), + dynamic_table_entry_referenced_(false) {} + +template +bool QpackHeaderTableBase::EntryFitsDynamicTableCapacity( + absl::string_view name, absl::string_view value) const { + return QpackEntry::Size(name, value) <= dynamic_table_capacity_; +} + +namespace internal { + +QUIC_NO_EXPORT inline size_t GetSize(const QpackEntry& entry) { + return entry.Size(); +} + +QUIC_NO_EXPORT inline size_t GetSize(const std::unique_ptr& entry) { + return entry->Size(); +} + +QUIC_NO_EXPORT inline std::unique_ptr NewEntry( + std::string name, std::string value, QpackEncoderDynamicTable& /*t*/) { + return std::make_unique(std::move(name), std::move(value)); +} + +QUIC_NO_EXPORT inline QpackEntry NewEntry(std::string name, std::string value, + QpackDecoderDynamicTable& /*t*/) { + return QpackEntry{std::move(name), std::move(value)}; +} + +} // namespace internal + +template +uint64_t QpackHeaderTableBase::InsertEntry( + absl::string_view name, absl::string_view value) { + QUICHE_DCHECK(EntryFitsDynamicTableCapacity(name, value)); + + const uint64_t index = dropped_entry_count_ + dynamic_entries_.size(); + + // Copy name and value before modifying the container, because evicting + // entries or even inserting a new one might invalidate |name| or |value| if + // they point to an entry. + auto new_entry = internal::NewEntry(std::string(name), std::string(value), + dynamic_entries_); + const size_t entry_size = internal::GetSize(new_entry); + EvictDownToCapacity(dynamic_table_capacity_ - entry_size); + + dynamic_table_size_ += entry_size; + dynamic_entries_.push_back(std::move(new_entry)); + + return index; +} + +template +bool QpackHeaderTableBase::SetDynamicTableCapacity( + uint64_t capacity) { + if (capacity > maximum_dynamic_table_capacity_) { + return false; + } + + dynamic_table_capacity_ = capacity; + EvictDownToCapacity(capacity); + + QUICHE_DCHECK_LE(dynamic_table_size_, dynamic_table_capacity_); + + return true; +} + +template +bool QpackHeaderTableBase::SetMaximumDynamicTableCapacity( + uint64_t maximum_dynamic_table_capacity) { + if (maximum_dynamic_table_capacity_ == 0) { + maximum_dynamic_table_capacity_ = maximum_dynamic_table_capacity; + max_entries_ = maximum_dynamic_table_capacity / 32; + return true; + } + // If the value is already set, it should not be changed. + return maximum_dynamic_table_capacity == maximum_dynamic_table_capacity_; +} + +template +void QpackHeaderTableBase::RemoveEntryFromEnd() { + const uint64_t entry_size = internal::GetSize(dynamic_entries_.front()); + QUICHE_DCHECK_GE(dynamic_table_size_, entry_size); + dynamic_table_size_ -= entry_size; + + dynamic_entries_.pop_front(); + ++dropped_entry_count_; +} + +template +void QpackHeaderTableBase::EvictDownToCapacity( + uint64_t capacity) { + while (dynamic_table_size_ > capacity) { + QUICHE_DCHECK(!dynamic_entries_.empty()); + RemoveEntryFromEnd(); + } +} + +class QUIC_EXPORT_PRIVATE QpackEncoderHeaderTable + : public QpackHeaderTableBase { + public: + // Result of header table lookup. + enum class MatchType { kNameAndValue, kName, kNoMatch }; + + QpackEncoderHeaderTable(); + ~QpackEncoderHeaderTable() override = default; + + uint64_t InsertEntry(absl::string_view name, + absl::string_view value) override; + + // Returns the absolute index of an entry with matching name and value if such + // exists, otherwise one with matching name is such exists. |index| is zero + // based for both the static and the dynamic table. + MatchType FindHeaderField(absl::string_view name, absl::string_view value, + bool* is_static, uint64_t* index) const; + + // Returns the size of the largest entry that could be inserted into the + // dynamic table without evicting entry |index|. |index| might be larger than + // inserted_entry_count(), in which case the capacity of the table is + // returned. |index| must not be smaller than dropped_entry_count(). + uint64_t MaxInsertSizeWithoutEvictingGivenEntry(uint64_t index) const; + + // Returns the draining index described at + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#avoiding-blocked-insertions. + // Entries with an index larger than or equal to the draining index take up + // approximately |1.0 - draining_fraction| of dynamic table capacity. The + // remaining capacity is taken up by draining entries and unused space. + // The returned index might not be the index of a valid entry. + uint64_t draining_index(float draining_fraction) const; + + protected: + void RemoveEntryFromEnd() override; + + private: + using NameValueToEntryMap = spdy::HpackHeaderTable::NameValueToEntryMap; + using NameToEntryMap = spdy::HpackHeaderTable::NameToEntryMap; + + // Static Table + + // |static_index_| and |static_name_index_| are owned by QpackStaticTable + // singleton. + + // Tracks the unique static entry for a given header name and value. + const NameValueToEntryMap& static_index_; + + // Tracks the first static entry for a given header name. + const NameToEntryMap& static_name_index_; + + // Dynamic Table + + // An unordered set of QpackEntry pointers with a comparison operator that + // only cares about name and value. This allows fast lookup of the most + // recently inserted dynamic entry for a given header name and value pair. + // Entries point to entries owned by |QpackHeaderTableBase::dynamic_entries_|. + NameValueToEntryMap dynamic_index_; + + // An unordered map of QpackEntry pointers keyed off header name. This allows + // fast lookup of the most recently inserted dynamic entry for a given header + // name. Entries point to entries owned by + // |QpackHeaderTableBase::dynamic_entries_|. + NameToEntryMap dynamic_name_index_; +}; + +class QUIC_EXPORT_PRIVATE QpackDecoderHeaderTable + : public QpackHeaderTableBase { + public: + // Observer interface for dynamic table insertion. + class QUIC_EXPORT_PRIVATE Observer { + public: + virtual ~Observer() = default; + + // Called when inserted_entry_count() reaches the threshold the Observer was + // registered with. After this call the Observer automatically gets + // deregistered. + virtual void OnInsertCountReachedThreshold() = 0; + + // Called when QpackDecoderHeaderTable is destroyed to let the Observer know + // that it must not call UnregisterObserver(). + virtual void Cancel() = 0; + }; + + QpackDecoderHeaderTable(); + ~QpackDecoderHeaderTable() override; + + uint64_t InsertEntry(absl::string_view name, + absl::string_view value) override; + + // Returns the entry at absolute index |index| from the static or dynamic + // table according to |is_static|. |index| is zero based for both the static + // and the dynamic table. The returned pointer is valid until the entry is + // evicted, even if other entries are inserted into the dynamic table. + // Returns nullptr if entry does not exist. + const QpackEntry* LookupEntry(bool is_static, uint64_t index) const; + + // Register an observer to be notified when inserted_entry_count() reaches + // |required_insert_count|. After the notification, |observer| automatically + // gets unregistered. Each observer must only be registered at most once. + void RegisterObserver(uint64_t required_insert_count, Observer* observer); + + // Unregister previously registered observer. Must be called with the same + // |required_insert_count| value that |observer| was registered with. Must be + // called before an observer still waiting for notification is destroyed, + // unless QpackDecoderHeaderTable already called Observer::Cancel(), in which + // case this method must not be called. + void UnregisterObserver(uint64_t required_insert_count, Observer* observer); + + private: + // Static Table entries. Owned by QpackStaticTable singleton. + using StaticEntryTable = spdy::HpackHeaderTable::StaticEntryTable; + const StaticEntryTable& static_entries_; + + // Observers waiting to be notified, sorted by required insert count. + std::multimap observers_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_HEADER_TABLE_H_ diff --git a/quiche/quic/core/qpack/qpack_header_table_test.cc b/quiche/quic/core/qpack/qpack_header_table_test.cc new file mode 100644 index 000000000000..3450e667e87d --- /dev/null +++ b/quiche/quic/core/qpack/qpack_header_table_test.cc @@ -0,0 +1,652 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_header_table.h" + +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_static_table.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/spdy/core/hpack/hpack_entry.h" + +using ::testing::Mock; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +const uint64_t kMaximumDynamicTableCapacityForTesting = 1024 * 1024; + +template +class QpackHeaderTableTest : public QuicTest { + protected: + ~QpackHeaderTableTest() override = default; + + void SetUp() override { + ASSERT_TRUE(table_.SetMaximumDynamicTableCapacity( + kMaximumDynamicTableCapacityForTesting)); + ASSERT_TRUE( + table_.SetDynamicTableCapacity(kMaximumDynamicTableCapacityForTesting)); + } + + bool EntryFitsDynamicTableCapacity(absl::string_view name, + absl::string_view value) const { + return table_.EntryFitsDynamicTableCapacity(name, value); + } + + void InsertEntry(absl::string_view name, absl::string_view value) { + table_.InsertEntry(name, value); + } + + bool SetDynamicTableCapacity(uint64_t capacity) { + return table_.SetDynamicTableCapacity(capacity); + } + + uint64_t max_entries() const { return table_.max_entries(); } + uint64_t inserted_entry_count() const { + return table_.inserted_entry_count(); + } + uint64_t dropped_entry_count() const { return table_.dropped_entry_count(); } + + T table_; +}; + +using MyTypes = + ::testing::Types; +TYPED_TEST_SUITE(QpackHeaderTableTest, MyTypes); + +// MaxEntries is determined by maximum dynamic table capacity, +// which is set at construction time. +TYPED_TEST(QpackHeaderTableTest, MaxEntries) { + TypeParam table1; + table1.SetMaximumDynamicTableCapacity(1024); + EXPECT_EQ(32u, table1.max_entries()); + + TypeParam table2; + table2.SetMaximumDynamicTableCapacity(500); + EXPECT_EQ(15u, table2.max_entries()); +} + +TYPED_TEST(QpackHeaderTableTest, SetDynamicTableCapacity) { + // Dynamic table capacity does not affect MaxEntries. + EXPECT_TRUE(this->SetDynamicTableCapacity(1024)); + EXPECT_EQ(32u * 1024, this->max_entries()); + + EXPECT_TRUE(this->SetDynamicTableCapacity(500)); + EXPECT_EQ(32u * 1024, this->max_entries()); + + // Dynamic table capacity cannot exceed maximum dynamic table capacity. + EXPECT_FALSE(this->SetDynamicTableCapacity( + 2 * kMaximumDynamicTableCapacityForTesting)); +} + +TYPED_TEST(QpackHeaderTableTest, EntryFitsDynamicTableCapacity) { + EXPECT_TRUE(this->SetDynamicTableCapacity(39)); + + EXPECT_TRUE(this->EntryFitsDynamicTableCapacity("foo", "bar")); + EXPECT_TRUE(this->EntryFitsDynamicTableCapacity("foo", "bar2")); + EXPECT_FALSE(this->EntryFitsDynamicTableCapacity("foo", "bar12")); +} + +class QpackEncoderHeaderTableTest + : public QpackHeaderTableTest { + protected: + ~QpackEncoderHeaderTableTest() override = default; + + void ExpectMatch(absl::string_view name, absl::string_view value, + QpackEncoderHeaderTable::MatchType expected_match_type, + bool expected_is_static, uint64_t expected_index) const { + // Initialize outparams to a value different from the expected to ensure + // that FindHeaderField() sets them. + bool is_static = !expected_is_static; + uint64_t index = expected_index + 1; + + QpackEncoderHeaderTable::MatchType matchtype = + table_.FindHeaderField(name, value, &is_static, &index); + + EXPECT_EQ(expected_match_type, matchtype) << name << ": " << value; + EXPECT_EQ(expected_is_static, is_static) << name << ": " << value; + EXPECT_EQ(expected_index, index) << name << ": " << value; + } + + void ExpectNoMatch(absl::string_view name, absl::string_view value) const { + bool is_static = false; + uint64_t index = 0; + + QpackEncoderHeaderTable::MatchType matchtype = + table_.FindHeaderField(name, value, &is_static, &index); + + EXPECT_EQ(QpackEncoderHeaderTable::MatchType::kNoMatch, matchtype) + << name << ": " << value; + } + + uint64_t MaxInsertSizeWithoutEvictingGivenEntry(uint64_t index) const { + return table_.MaxInsertSizeWithoutEvictingGivenEntry(index); + } + + uint64_t draining_index(float draining_fraction) const { + return table_.draining_index(draining_fraction); + } +}; + +TEST_F(QpackEncoderHeaderTableTest, FindStaticHeaderField) { + // A header name that has multiple entries with different values. + ExpectMatch(":method", "GET", + QpackEncoderHeaderTable::MatchType::kNameAndValue, true, 17u); + + ExpectMatch(":method", "POST", + QpackEncoderHeaderTable::MatchType::kNameAndValue, true, 20u); + + ExpectMatch(":method", "TRACE", QpackEncoderHeaderTable::MatchType::kName, + true, 15u); + + // A header name that has a single entry with non-empty value. + ExpectMatch("accept-encoding", "gzip, deflate, br", + QpackEncoderHeaderTable::MatchType::kNameAndValue, true, 31u); + + ExpectMatch("accept-encoding", "compress", + QpackEncoderHeaderTable::MatchType::kName, true, 31u); + + ExpectMatch("accept-encoding", "", QpackEncoderHeaderTable::MatchType::kName, + true, 31u); + + // A header name that has a single entry with empty value. + ExpectMatch("location", "", QpackEncoderHeaderTable::MatchType::kNameAndValue, + true, 12u); + + ExpectMatch("location", "foo", QpackEncoderHeaderTable::MatchType::kName, + true, 12u); + + // No matching header name. + ExpectNoMatch("foo", ""); + ExpectNoMatch("foo", "bar"); +} + +TEST_F(QpackEncoderHeaderTableTest, FindDynamicHeaderField) { + // Dynamic table is initially entry. + ExpectNoMatch("foo", "bar"); + ExpectNoMatch("foo", "baz"); + + // Insert one entry. + InsertEntry("foo", "bar"); + + // Match name and value. + ExpectMatch("foo", "bar", QpackEncoderHeaderTable::MatchType::kNameAndValue, + false, 0u); + + // Match name only. + ExpectMatch("foo", "baz", QpackEncoderHeaderTable::MatchType::kName, false, + 0u); + + // Insert an identical entry. FindHeaderField() should return the index of + // the most recently inserted matching entry. + InsertEntry("foo", "bar"); + + // Match name and value. + ExpectMatch("foo", "bar", QpackEncoderHeaderTable::MatchType::kNameAndValue, + false, 1u); + + // Match name only. + ExpectMatch("foo", "baz", QpackEncoderHeaderTable::MatchType::kName, false, + 1u); +} + +TEST_F(QpackEncoderHeaderTableTest, FindHeaderFieldPrefersStaticTable) { + // Insert an entry to the dynamic table that exists in the static table. + InsertEntry(":method", "GET"); + + // FindHeaderField() prefers static table if both have name-and-value match. + ExpectMatch(":method", "GET", + QpackEncoderHeaderTable::MatchType::kNameAndValue, true, 17u); + + // FindHeaderField() prefers static table if both have name match but no value + // match, and prefers the first entry with matching name. + ExpectMatch(":method", "TRACE", QpackEncoderHeaderTable::MatchType::kName, + true, 15u); + + // Add new entry to the dynamic table. + InsertEntry(":method", "TRACE"); + + // FindHeaderField prefers name-and-value match in dynamic table over name + // only match in static table. + ExpectMatch(":method", "TRACE", + QpackEncoderHeaderTable::MatchType::kNameAndValue, false, 1u); +} + +TEST_F(QpackEncoderHeaderTableTest, EvictByInsertion) { + EXPECT_TRUE(SetDynamicTableCapacity(40)); + + // Entry size is 3 + 3 + 32 = 38. + InsertEntry("foo", "bar"); + EXPECT_EQ(1u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + ExpectMatch("foo", "bar", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 0u); + + // Inserting second entry evicts the first one. + InsertEntry("baz", "qux"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectNoMatch("foo", "bar"); + ExpectMatch("baz", "qux", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 1u); +} + +TEST_F(QpackEncoderHeaderTableTest, EvictByUpdateTableSize) { + // Entry size is 3 + 3 + 32 = 38. + InsertEntry("foo", "bar"); + InsertEntry("baz", "qux"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + ExpectMatch("foo", "bar", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 0u); + ExpectMatch("baz", "qux", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 1u); + + EXPECT_TRUE(SetDynamicTableCapacity(40)); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectNoMatch("foo", "bar"); + ExpectMatch("baz", "qux", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 1u); + + EXPECT_TRUE(SetDynamicTableCapacity(20)); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(2u, dropped_entry_count()); + + ExpectNoMatch("foo", "bar"); + ExpectNoMatch("baz", "qux"); +} + +TEST_F(QpackEncoderHeaderTableTest, EvictOldestOfIdentical) { + EXPECT_TRUE(SetDynamicTableCapacity(80)); + + // Entry size is 3 + 3 + 32 = 38. + // Insert same entry twice. + InsertEntry("foo", "bar"); + InsertEntry("foo", "bar"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + // Find most recently inserted entry. + ExpectMatch("foo", "bar", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 1u); + + // Inserting third entry evicts the first one, not the second. + InsertEntry("baz", "qux"); + EXPECT_EQ(3u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectMatch("foo", "bar", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 1u); + ExpectMatch("baz", "qux", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 2u); +} + +TEST_F(QpackEncoderHeaderTableTest, EvictOldestOfSameName) { + EXPECT_TRUE(SetDynamicTableCapacity(80)); + + // Entry size is 3 + 3 + 32 = 38. + // Insert two entries with same name but different values. + InsertEntry("foo", "bar"); + InsertEntry("foo", "baz"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + // Find most recently inserted entry with matching name. + ExpectMatch("foo", "foo", QpackEncoderHeaderTable::MatchType::kName, + /* expected_is_static = */ false, 1u); + + // Inserting third entry evicts the first one, not the second. + InsertEntry("baz", "qux"); + EXPECT_EQ(3u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectMatch("foo", "foo", QpackEncoderHeaderTable::MatchType::kName, + /* expected_is_static = */ false, 1u); + ExpectMatch("baz", "qux", QpackEncoderHeaderTable::MatchType::kNameAndValue, + /* expected_is_static = */ false, 2u); +} + +// Returns the size of the largest entry that could be inserted into the +// dynamic table without evicting entry |index|. +TEST_F(QpackEncoderHeaderTableTest, MaxInsertSizeWithoutEvictingGivenEntry) { + const uint64_t dynamic_table_capacity = 100; + EXPECT_TRUE(SetDynamicTableCapacity(dynamic_table_capacity)); + + // Empty table can take an entry up to its capacity. + EXPECT_EQ(dynamic_table_capacity, MaxInsertSizeWithoutEvictingGivenEntry(0)); + + const uint64_t entry_size1 = QpackEntry::Size("foo", "bar"); + InsertEntry("foo", "bar"); + EXPECT_EQ(dynamic_table_capacity - entry_size1, + MaxInsertSizeWithoutEvictingGivenEntry(0)); + // Table can take an entry up to its capacity if all entries are allowed to be + // evicted. + EXPECT_EQ(dynamic_table_capacity, MaxInsertSizeWithoutEvictingGivenEntry(1)); + + const uint64_t entry_size2 = QpackEntry::Size("baz", "foobar"); + InsertEntry("baz", "foobar"); + // Table can take an entry up to its capacity if all entries are allowed to be + // evicted. + EXPECT_EQ(dynamic_table_capacity, MaxInsertSizeWithoutEvictingGivenEntry(2)); + // Second entry must stay. + EXPECT_EQ(dynamic_table_capacity - entry_size2, + MaxInsertSizeWithoutEvictingGivenEntry(1)); + // First and second entry must stay. + EXPECT_EQ(dynamic_table_capacity - entry_size2 - entry_size1, + MaxInsertSizeWithoutEvictingGivenEntry(0)); + + // Third entry evicts first one. + const uint64_t entry_size3 = QpackEntry::Size("last", "entry"); + InsertEntry("last", "entry"); + EXPECT_EQ(1u, dropped_entry_count()); + // Table can take an entry up to its capacity if all entries are allowed to be + // evicted. + EXPECT_EQ(dynamic_table_capacity, MaxInsertSizeWithoutEvictingGivenEntry(3)); + // Third entry must stay. + EXPECT_EQ(dynamic_table_capacity - entry_size3, + MaxInsertSizeWithoutEvictingGivenEntry(2)); + // Second and third entry must stay. + EXPECT_EQ(dynamic_table_capacity - entry_size3 - entry_size2, + MaxInsertSizeWithoutEvictingGivenEntry(1)); +} + +TEST_F(QpackEncoderHeaderTableTest, DrainingIndex) { + EXPECT_TRUE(SetDynamicTableCapacity(4 * QpackEntry::Size("foo", "bar"))); + + // Empty table: no draining entry. + EXPECT_EQ(0u, draining_index(0.0)); + EXPECT_EQ(0u, draining_index(1.0)); + + // Table with one entry. + InsertEntry("foo", "bar"); + // Any entry can be referenced if none of the table is draining. + EXPECT_EQ(0u, draining_index(0.0)); + // No entry can be referenced if all of the table is draining. + EXPECT_EQ(1u, draining_index(1.0)); + + // Table with two entries is at half capacity. + InsertEntry("foo", "bar"); + // Any entry can be referenced if at most half of the table is draining, + // because current entries only take up half of total capacity. + EXPECT_EQ(0u, draining_index(0.0)); + EXPECT_EQ(0u, draining_index(0.5)); + // No entry can be referenced if all of the table is draining. + EXPECT_EQ(2u, draining_index(1.0)); + + // Table with four entries is full. + InsertEntry("foo", "bar"); + InsertEntry("foo", "bar"); + // Any entry can be referenced if none of the table is draining. + EXPECT_EQ(0u, draining_index(0.0)); + // In a full table with identically sized entries, |draining_fraction| of all + // entries are draining. + EXPECT_EQ(2u, draining_index(0.5)); + // No entry can be referenced if all of the table is draining. + EXPECT_EQ(4u, draining_index(1.0)); +} + +class MockObserver : public QpackDecoderHeaderTable::Observer { + public: + ~MockObserver() override = default; + + MOCK_METHOD(void, OnInsertCountReachedThreshold, (), (override)); + MOCK_METHOD(void, Cancel, (), (override)); +}; + +class QpackDecoderHeaderTableTest + : public QpackHeaderTableTest { + protected: + ~QpackDecoderHeaderTableTest() override = default; + + void ExpectEntryAtIndex(bool is_static, uint64_t index, + absl::string_view expected_name, + absl::string_view expected_value) const { + const auto* entry = table_.LookupEntry(is_static, index); + ASSERT_TRUE(entry); + EXPECT_EQ(expected_name, entry->name()); + EXPECT_EQ(expected_value, entry->value()); + } + + void ExpectNoEntryAtIndex(bool is_static, uint64_t index) const { + EXPECT_FALSE(table_.LookupEntry(is_static, index)); + } + + void RegisterObserver(uint64_t required_insert_count, + QpackDecoderHeaderTable::Observer* observer) { + table_.RegisterObserver(required_insert_count, observer); + } + + void UnregisterObserver(uint64_t required_insert_count, + QpackDecoderHeaderTable::Observer* observer) { + table_.UnregisterObserver(required_insert_count, observer); + } +}; + +TEST_F(QpackDecoderHeaderTableTest, LookupStaticEntry) { + ExpectEntryAtIndex(/* is_static = */ true, 0, ":authority", ""); + + ExpectEntryAtIndex(/* is_static = */ true, 1, ":path", "/"); + + // 98 is the last entry. + ExpectEntryAtIndex(/* is_static = */ true, 98, "x-frame-options", + "sameorigin"); + + ASSERT_EQ(99u, QpackStaticTableVector().size()); + ExpectNoEntryAtIndex(/* is_static = */ true, 99); +} + +TEST_F(QpackDecoderHeaderTableTest, InsertAndLookupDynamicEntry) { + // Dynamic table is initially entry. + ExpectNoEntryAtIndex(/* is_static = */ false, 0); + ExpectNoEntryAtIndex(/* is_static = */ false, 1); + ExpectNoEntryAtIndex(/* is_static = */ false, 2); + ExpectNoEntryAtIndex(/* is_static = */ false, 3); + + // Insert one entry. + InsertEntry("foo", "bar"); + + ExpectEntryAtIndex(/* is_static = */ false, 0, "foo", "bar"); + + ExpectNoEntryAtIndex(/* is_static = */ false, 1); + ExpectNoEntryAtIndex(/* is_static = */ false, 2); + ExpectNoEntryAtIndex(/* is_static = */ false, 3); + + // Insert a different entry. + InsertEntry("baz", "bing"); + + ExpectEntryAtIndex(/* is_static = */ false, 0, "foo", "bar"); + + ExpectEntryAtIndex(/* is_static = */ false, 1, "baz", "bing"); + + ExpectNoEntryAtIndex(/* is_static = */ false, 2); + ExpectNoEntryAtIndex(/* is_static = */ false, 3); + + // Insert an entry identical to the most recently inserted one. + InsertEntry("baz", "bing"); + + ExpectEntryAtIndex(/* is_static = */ false, 0, "foo", "bar"); + + ExpectEntryAtIndex(/* is_static = */ false, 1, "baz", "bing"); + + ExpectEntryAtIndex(/* is_static = */ false, 2, "baz", "bing"); + + ExpectNoEntryAtIndex(/* is_static = */ false, 3); +} + +TEST_F(QpackDecoderHeaderTableTest, EvictByInsertion) { + EXPECT_TRUE(SetDynamicTableCapacity(40)); + + // Entry size is 3 + 3 + 32 = 38. + InsertEntry("foo", "bar"); + EXPECT_EQ(1u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + ExpectEntryAtIndex(/* is_static = */ false, 0u, "foo", "bar"); + + // Inserting second entry evicts the first one. + InsertEntry("baz", "qux"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectNoEntryAtIndex(/* is_static = */ false, 0u); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "baz", "qux"); +} + +TEST_F(QpackDecoderHeaderTableTest, EvictByUpdateTableSize) { + ExpectNoEntryAtIndex(/* is_static = */ false, 0u); + ExpectNoEntryAtIndex(/* is_static = */ false, 1u); + + // Entry size is 3 + 3 + 32 = 38. + InsertEntry("foo", "bar"); + InsertEntry("baz", "qux"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + ExpectEntryAtIndex(/* is_static = */ false, 0u, "foo", "bar"); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "baz", "qux"); + + EXPECT_TRUE(SetDynamicTableCapacity(40)); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectNoEntryAtIndex(/* is_static = */ false, 0u); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "baz", "qux"); + + EXPECT_TRUE(SetDynamicTableCapacity(20)); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(2u, dropped_entry_count()); + + ExpectNoEntryAtIndex(/* is_static = */ false, 0u); + ExpectNoEntryAtIndex(/* is_static = */ false, 1u); +} + +TEST_F(QpackDecoderHeaderTableTest, EvictOldestOfIdentical) { + EXPECT_TRUE(SetDynamicTableCapacity(80)); + + // Entry size is 3 + 3 + 32 = 38. + // Insert same entry twice. + InsertEntry("foo", "bar"); + InsertEntry("foo", "bar"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + ExpectEntryAtIndex(/* is_static = */ false, 0u, "foo", "bar"); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "foo", "bar"); + ExpectNoEntryAtIndex(/* is_static = */ false, 2u); + + // Inserting third entry evicts the first one, not the second. + InsertEntry("baz", "qux"); + EXPECT_EQ(3u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectNoEntryAtIndex(/* is_static = */ false, 0u); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "foo", "bar"); + ExpectEntryAtIndex(/* is_static = */ false, 2u, "baz", "qux"); +} + +TEST_F(QpackDecoderHeaderTableTest, EvictOldestOfSameName) { + EXPECT_TRUE(SetDynamicTableCapacity(80)); + + // Entry size is 3 + 3 + 32 = 38. + // Insert two entries with same name but different values. + InsertEntry("foo", "bar"); + InsertEntry("foo", "baz"); + EXPECT_EQ(2u, inserted_entry_count()); + EXPECT_EQ(0u, dropped_entry_count()); + + ExpectEntryAtIndex(/* is_static = */ false, 0u, "foo", "bar"); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "foo", "baz"); + ExpectNoEntryAtIndex(/* is_static = */ false, 2u); + + // Inserting third entry evicts the first one, not the second. + InsertEntry("baz", "qux"); + EXPECT_EQ(3u, inserted_entry_count()); + EXPECT_EQ(1u, dropped_entry_count()); + + ExpectNoEntryAtIndex(/* is_static = */ false, 0u); + ExpectEntryAtIndex(/* is_static = */ false, 1u, "foo", "baz"); + ExpectEntryAtIndex(/* is_static = */ false, 2u, "baz", "qux"); +} + +TEST_F(QpackDecoderHeaderTableTest, RegisterObserver) { + StrictMock observer1; + RegisterObserver(1, &observer1); + EXPECT_CALL(observer1, OnInsertCountReachedThreshold); + InsertEntry("foo", "bar"); + EXPECT_EQ(1u, inserted_entry_count()); + Mock::VerifyAndClearExpectations(&observer1); + + // Registration order does not matter. + StrictMock observer2; + StrictMock observer3; + RegisterObserver(3, &observer3); + RegisterObserver(2, &observer2); + + EXPECT_CALL(observer2, OnInsertCountReachedThreshold); + InsertEntry("foo", "bar"); + EXPECT_EQ(2u, inserted_entry_count()); + Mock::VerifyAndClearExpectations(&observer3); + + EXPECT_CALL(observer3, OnInsertCountReachedThreshold); + InsertEntry("foo", "bar"); + EXPECT_EQ(3u, inserted_entry_count()); + Mock::VerifyAndClearExpectations(&observer2); + + // Multiple observers with identical |required_insert_count| should all be + // notified. + StrictMock observer4; + StrictMock observer5; + RegisterObserver(4, &observer4); + RegisterObserver(4, &observer5); + + EXPECT_CALL(observer4, OnInsertCountReachedThreshold); + EXPECT_CALL(observer5, OnInsertCountReachedThreshold); + InsertEntry("foo", "bar"); + EXPECT_EQ(4u, inserted_entry_count()); + Mock::VerifyAndClearExpectations(&observer4); + Mock::VerifyAndClearExpectations(&observer5); +} + +TEST_F(QpackDecoderHeaderTableTest, UnregisterObserver) { + StrictMock observer1; + StrictMock observer2; + StrictMock observer3; + StrictMock observer4; + RegisterObserver(1, &observer1); + RegisterObserver(2, &observer2); + RegisterObserver(2, &observer3); + RegisterObserver(3, &observer4); + + UnregisterObserver(2, &observer3); + + EXPECT_CALL(observer1, OnInsertCountReachedThreshold); + EXPECT_CALL(observer2, OnInsertCountReachedThreshold); + EXPECT_CALL(observer4, OnInsertCountReachedThreshold); + InsertEntry("foo", "bar"); + InsertEntry("foo", "bar"); + InsertEntry("foo", "bar"); + EXPECT_EQ(3u, inserted_entry_count()); +} + +TEST_F(QpackDecoderHeaderTableTest, Cancel) { + StrictMock observer; + auto table = std::make_unique(); + table->RegisterObserver(1, &observer); + + EXPECT_CALL(observer, Cancel); + table.reset(); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_index_conversions.cc b/quiche/quic/core/qpack/qpack_index_conversions.cc new file mode 100644 index 000000000000..a54bfcd1086d --- /dev/null +++ b/quiche/quic/core/qpack/qpack_index_conversions.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_index_conversions.h" + +#include + +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +uint64_t QpackAbsoluteIndexToEncoderStreamRelativeIndex( + uint64_t absolute_index, uint64_t inserted_entry_count) { + QUICHE_DCHECK_LT(absolute_index, inserted_entry_count); + + return inserted_entry_count - absolute_index - 1; +} + +uint64_t QpackAbsoluteIndexToRequestStreamRelativeIndex(uint64_t absolute_index, + uint64_t base) { + QUICHE_DCHECK_LT(absolute_index, base); + + return base - absolute_index - 1; +} + +bool QpackEncoderStreamRelativeIndexToAbsoluteIndex( + uint64_t relative_index, uint64_t inserted_entry_count, + uint64_t* absolute_index) { + if (relative_index >= inserted_entry_count) { + return false; + } + + *absolute_index = inserted_entry_count - relative_index - 1; + return true; +} + +bool QpackRequestStreamRelativeIndexToAbsoluteIndex(uint64_t relative_index, + uint64_t base, + uint64_t* absolute_index) { + if (relative_index >= base) { + return false; + } + + *absolute_index = base - relative_index - 1; + return true; +} + +bool QpackPostBaseIndexToAbsoluteIndex(uint64_t post_base_index, uint64_t base, + uint64_t* absolute_index) { + if (post_base_index >= std::numeric_limits::max() - base) { + return false; + } + + *absolute_index = base + post_base_index; + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_index_conversions.h b/quiche/quic/core/qpack/qpack_index_conversions.h new file mode 100644 index 000000000000..ddd51cf05057 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_index_conversions.h @@ -0,0 +1,52 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Utility methods to convert between absolute indexing (used in the dynamic +// table), relative indexing used on the encoder stream, and relative indexing +// and post-base indexing used on request streams (in header blocks). See: +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#indexing +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#relative-indexing +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#post-base + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_INDEX_CONVERSIONS_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_INDEX_CONVERSIONS_H_ + +#include + +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Conversion functions used in the encoder do not check for overflow/underflow. +// Since the maximum index is limited by maximum dynamic table capacity +// (represented on uint64_t) divided by minimum header field size (defined to be +// 32 bytes), overflow is not possible. The caller is responsible for providing +// input that does not underflow. + +QUIC_EXPORT_PRIVATE uint64_t QpackAbsoluteIndexToEncoderStreamRelativeIndex( + uint64_t absolute_index, uint64_t inserted_entry_count); + +QUIC_EXPORT_PRIVATE uint64_t QpackAbsoluteIndexToRequestStreamRelativeIndex( + uint64_t absolute_index, uint64_t base); + +// Conversion functions used in the decoder operate on input received from the +// network. These functions return false on overflow or underflow. + +QUIC_EXPORT_PRIVATE bool QpackEncoderStreamRelativeIndexToAbsoluteIndex( + uint64_t relative_index, uint64_t inserted_entry_count, + uint64_t* absolute_index); + +// On success, |*absolute_index| is guaranteed to be strictly less than +// std::numeric_limits::max(). +QUIC_EXPORT_PRIVATE bool QpackRequestStreamRelativeIndexToAbsoluteIndex( + uint64_t relative_index, uint64_t base, uint64_t* absolute_index); + +// On success, |*absolute_index| is guaranteed to be strictly less than +// std::numeric_limits::max(). +QUIC_EXPORT_PRIVATE bool QpackPostBaseIndexToAbsoluteIndex( + uint64_t post_base_index, uint64_t base, uint64_t* absolute_index); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_INDEX_CONVERSIONS_H_ diff --git a/quiche/quic/core/qpack/qpack_index_conversions_test.cc b/quiche/quic/core/qpack/qpack_index_conversions_test.cc new file mode 100644 index 000000000000..80162df358fa --- /dev/null +++ b/quiche/quic/core/qpack/qpack_index_conversions_test.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_index_conversions.h" + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +struct { + uint64_t relative_index; + uint64_t inserted_entry_count; + uint64_t expected_absolute_index; +} kEncoderStreamRelativeIndexTestData[] = {{0, 1, 0}, {0, 2, 1}, {1, 2, 0}, + {0, 10, 9}, {5, 10, 4}, {9, 10, 0}}; + +TEST(QpackIndexConversions, EncoderStreamRelativeIndex) { + for (const auto& test_data : kEncoderStreamRelativeIndexTestData) { + uint64_t absolute_index = 42; + EXPECT_TRUE(QpackEncoderStreamRelativeIndexToAbsoluteIndex( + test_data.relative_index, test_data.inserted_entry_count, + &absolute_index)); + EXPECT_EQ(test_data.expected_absolute_index, absolute_index); + + EXPECT_EQ(test_data.relative_index, + QpackAbsoluteIndexToEncoderStreamRelativeIndex( + absolute_index, test_data.inserted_entry_count)); + } +} + +struct { + uint64_t relative_index; + uint64_t base; + uint64_t expected_absolute_index; +} kRequestStreamRelativeIndexTestData[] = {{0, 1, 0}, {0, 2, 1}, {1, 2, 0}, + {0, 10, 9}, {5, 10, 4}, {9, 10, 0}}; + +TEST(QpackIndexConversions, RequestStreamRelativeIndex) { + for (const auto& test_data : kRequestStreamRelativeIndexTestData) { + uint64_t absolute_index = 42; + EXPECT_TRUE(QpackRequestStreamRelativeIndexToAbsoluteIndex( + test_data.relative_index, test_data.base, &absolute_index)); + EXPECT_EQ(test_data.expected_absolute_index, absolute_index); + + EXPECT_EQ(test_data.relative_index, + QpackAbsoluteIndexToRequestStreamRelativeIndex(absolute_index, + test_data.base)); + } +} + +struct { + uint64_t post_base_index; + uint64_t base; + uint64_t expected_absolute_index; +} kPostBaseIndexTestData[] = {{0, 1, 1}, {1, 0, 1}, {2, 0, 2}, + {1, 1, 2}, {0, 2, 2}, {1, 2, 3}}; + +TEST(QpackIndexConversions, PostBaseIndex) { + for (const auto& test_data : kPostBaseIndexTestData) { + uint64_t absolute_index = 42; + EXPECT_TRUE(QpackPostBaseIndexToAbsoluteIndex( + test_data.post_base_index, test_data.base, &absolute_index)); + EXPECT_EQ(test_data.expected_absolute_index, absolute_index); + } +} + +TEST(QpackIndexConversions, EncoderStreamRelativeIndexUnderflow) { + uint64_t absolute_index; + EXPECT_FALSE(QpackEncoderStreamRelativeIndexToAbsoluteIndex( + /* relative_index = */ 10, + /* inserted_entry_count = */ 10, &absolute_index)); + EXPECT_FALSE(QpackEncoderStreamRelativeIndexToAbsoluteIndex( + /* relative_index = */ 12, + /* inserted_entry_count = */ 10, &absolute_index)); +} + +TEST(QpackIndexConversions, RequestStreamRelativeIndexUnderflow) { + uint64_t absolute_index; + EXPECT_FALSE(QpackRequestStreamRelativeIndexToAbsoluteIndex( + /* relative_index = */ 10, + /* base = */ 10, &absolute_index)); + EXPECT_FALSE(QpackRequestStreamRelativeIndexToAbsoluteIndex( + /* relative_index = */ 12, + /* base = */ 10, &absolute_index)); +} + +TEST(QpackIndexConversions, QpackPostBaseIndexToAbsoluteIndexOverflow) { + uint64_t absolute_index; + EXPECT_FALSE(QpackPostBaseIndexToAbsoluteIndex( + /* post_base_index = */ 20, + /* base = */ std::numeric_limits::max() - 10, &absolute_index)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_instruction_decoder.cc b/quiche/quic/core/qpack/qpack_instruction_decoder.cc new file mode 100644 index 000000000000..bc22db00ef4a --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instruction_decoder.cc @@ -0,0 +1,332 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_instruction_decoder.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Maximum length of header name and header value. This limits the amount of +// memory the peer can make the decoder allocate when sending string literals. +const size_t kStringLiteralLengthLimit = 1024 * 1024; + +} // namespace + +QpackInstructionDecoder::QpackInstructionDecoder(const QpackLanguage* language, + Delegate* delegate) + : language_(language), + delegate_(delegate), + s_bit_(false), + varint_(0), + varint2_(0), + is_huffman_encoded_(false), + string_length_(0), + error_detected_(false), + state_(State::kStartInstruction) {} + +bool QpackInstructionDecoder::Decode(absl::string_view data) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(!error_detected_); + + while (true) { + bool success = true; + size_t bytes_consumed = 0; + + switch (state_) { + case State::kStartInstruction: + success = DoStartInstruction(data); + break; + case State::kStartField: + success = DoStartField(); + break; + case State::kReadBit: + success = DoReadBit(data); + break; + case State::kVarintStart: + success = DoVarintStart(data, &bytes_consumed); + break; + case State::kVarintResume: + success = DoVarintResume(data, &bytes_consumed); + break; + case State::kVarintDone: + success = DoVarintDone(); + break; + case State::kReadString: + success = DoReadString(data, &bytes_consumed); + break; + case State::kReadStringDone: + success = DoReadStringDone(); + break; + } + + if (!success) { + return false; + } + + // |success| must be false if an error is detected. + QUICHE_DCHECK(!error_detected_); + + QUICHE_DCHECK_LE(bytes_consumed, data.size()); + + data = absl::string_view(data.data() + bytes_consumed, + data.size() - bytes_consumed); + + // Stop processing if no more data but next state would require it. + if (data.empty() && (state_ != State::kStartField) && + (state_ != State::kVarintDone) && (state_ != State::kReadStringDone)) { + return true; + } + } +} + +bool QpackInstructionDecoder::AtInstructionBoundary() const { + return state_ == State::kStartInstruction; +} + +bool QpackInstructionDecoder::DoStartInstruction(absl::string_view data) { + QUICHE_DCHECK(!data.empty()); + + instruction_ = LookupOpcode(data[0]); + field_ = instruction_->fields.begin(); + + state_ = State::kStartField; + return true; +} + +bool QpackInstructionDecoder::DoStartField() { + if (field_ == instruction_->fields.end()) { + // Completed decoding this instruction. + + if (!delegate_->OnInstructionDecoded(instruction_)) { + return false; + } + + state_ = State::kStartInstruction; + return true; + } + + switch (field_->type) { + case QpackInstructionFieldType::kSbit: + case QpackInstructionFieldType::kName: + case QpackInstructionFieldType::kValue: + state_ = State::kReadBit; + return true; + case QpackInstructionFieldType::kVarint: + case QpackInstructionFieldType::kVarint2: + state_ = State::kVarintStart; + return true; + default: + QUIC_BUG(quic_bug_10767_1) << "Invalid field type."; + return false; + } +} + +bool QpackInstructionDecoder::DoReadBit(absl::string_view data) { + QUICHE_DCHECK(!data.empty()); + + switch (field_->type) { + case QpackInstructionFieldType::kSbit: { + const uint8_t bitmask = field_->param; + s_bit_ = (data[0] & bitmask) == bitmask; + + ++field_; + state_ = State::kStartField; + + return true; + } + case QpackInstructionFieldType::kName: + case QpackInstructionFieldType::kValue: { + const uint8_t prefix_length = field_->param; + QUICHE_DCHECK_GE(7, prefix_length); + const uint8_t bitmask = 1 << prefix_length; + is_huffman_encoded_ = (data[0] & bitmask) == bitmask; + + state_ = State::kVarintStart; + + return true; + } + default: + QUIC_BUG(quic_bug_10767_2) << "Invalid field type."; + return false; + } +} + +bool QpackInstructionDecoder::DoVarintStart(absl::string_view data, + size_t* bytes_consumed) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kVarint || + field_->type == QpackInstructionFieldType::kVarint2 || + field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + http2::DecodeBuffer buffer(data.data() + 1, data.size() - 1); + http2::DecodeStatus status = + varint_decoder_.Start(data[0], field_->param, &buffer); + + *bytes_consumed = 1 + buffer.Offset(); + switch (status) { + case http2::DecodeStatus::kDecodeDone: + state_ = State::kVarintDone; + return true; + case http2::DecodeStatus::kDecodeInProgress: + state_ = State::kVarintResume; + return true; + case http2::DecodeStatus::kDecodeError: + OnError(ErrorCode::INTEGER_TOO_LARGE, "Encoded integer too large."); + return false; + default: + QUIC_BUG(quic_bug_10767_3) << "Unknown decode status " << status; + return false; + } +} + +bool QpackInstructionDecoder::DoVarintResume(absl::string_view data, + size_t* bytes_consumed) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kVarint || + field_->type == QpackInstructionFieldType::kVarint2 || + field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + http2::DecodeBuffer buffer(data); + http2::DecodeStatus status = varint_decoder_.Resume(&buffer); + + *bytes_consumed = buffer.Offset(); + switch (status) { + case http2::DecodeStatus::kDecodeDone: + state_ = State::kVarintDone; + return true; + case http2::DecodeStatus::kDecodeInProgress: + QUICHE_DCHECK_EQ(*bytes_consumed, data.size()); + QUICHE_DCHECK(buffer.Empty()); + return true; + case http2::DecodeStatus::kDecodeError: + OnError(ErrorCode::INTEGER_TOO_LARGE, "Encoded integer too large."); + return false; + default: + QUIC_BUG(quic_bug_10767_4) << "Unknown decode status " << status; + return false; + } +} + +bool QpackInstructionDecoder::DoVarintDone() { + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kVarint || + field_->type == QpackInstructionFieldType::kVarint2 || + field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + if (field_->type == QpackInstructionFieldType::kVarint) { + varint_ = varint_decoder_.value(); + + ++field_; + state_ = State::kStartField; + return true; + } + + if (field_->type == QpackInstructionFieldType::kVarint2) { + varint2_ = varint_decoder_.value(); + + ++field_; + state_ = State::kStartField; + return true; + } + + string_length_ = varint_decoder_.value(); + if (string_length_ > kStringLiteralLengthLimit) { + OnError(ErrorCode::STRING_LITERAL_TOO_LONG, "String literal too long."); + return false; + } + + std::string* const string = + (field_->type == QpackInstructionFieldType::kName) ? &name_ : &value_; + string->clear(); + + if (string_length_ == 0) { + ++field_; + state_ = State::kStartField; + return true; + } + + string->reserve(string_length_); + + state_ = State::kReadString; + return true; +} + +bool QpackInstructionDecoder::DoReadString(absl::string_view data, + size_t* bytes_consumed) { + QUICHE_DCHECK(!data.empty()); + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + std::string* const string = + (field_->type == QpackInstructionFieldType::kName) ? &name_ : &value_; + QUICHE_DCHECK_LT(string->size(), string_length_); + + *bytes_consumed = std::min(string_length_ - string->size(), data.size()); + string->append(data.data(), *bytes_consumed); + + QUICHE_DCHECK_LE(string->size(), string_length_); + if (string->size() == string_length_) { + state_ = State::kReadStringDone; + } + return true; +} + +bool QpackInstructionDecoder::DoReadStringDone() { + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + std::string* const string = + (field_->type == QpackInstructionFieldType::kName) ? &name_ : &value_; + QUICHE_DCHECK_EQ(string->size(), string_length_); + + if (is_huffman_encoded_) { + huffman_decoder_.Reset(); + // HpackHuffmanDecoder::Decode() cannot perform in-place decoding. + std::string decoded_value; + huffman_decoder_.Decode(*string, &decoded_value); + if (!huffman_decoder_.InputProperlyTerminated()) { + OnError(ErrorCode::HUFFMAN_ENCODING_ERROR, + "Error in Huffman-encoded string."); + return false; + } + *string = std::move(decoded_value); + } + + ++field_; + state_ = State::kStartField; + return true; +} + +const QpackInstruction* QpackInstructionDecoder::LookupOpcode( + uint8_t byte) const { + for (const auto* instruction : *language_) { + if ((byte & instruction->opcode.mask) == instruction->opcode.value) { + return instruction; + } + } + // |language_| should be defined such that instruction opcodes cover every + // possible input. + QUICHE_DCHECK(false); + return nullptr; +} + +void QpackInstructionDecoder::OnError(ErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(!error_detected_); + + error_detected_ = true; + delegate_->OnInstructionDecodingError(error_code, error_message); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_instruction_decoder.h b/quiche/quic/core/qpack/qpack_instruction_decoder.h new file mode 100644 index 000000000000..c848b3dc4e44 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instruction_decoder.h @@ -0,0 +1,160 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTION_DECODER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTION_DECODER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/huffman/hpack_huffman_decoder.h" +#include "quiche/http2/hpack/varint/hpack_varint_decoder.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Generic instruction decoder class. Takes a QpackLanguage that describes a +// language, that is, a set of instruction opcodes together with a list of +// fields that follow each instruction. +class QUIC_EXPORT_PRIVATE QpackInstructionDecoder { + public: + enum class ErrorCode { + INTEGER_TOO_LARGE, + STRING_LITERAL_TOO_LONG, + HUFFMAN_ENCODING_ERROR, + }; + + // Delegate is notified each time an instruction is decoded or when an error + // occurs. + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() = default; + + // Called when an instruction (including all its fields) is decoded. + // |instruction| points to an entry in |language|. + // Returns true if decoded fields are valid. + // Returns false otherwise, in which case QpackInstructionDecoder stops + // decoding: Delegate methods will not be called, and Decode() must not be + // called. Implementations are allowed to destroy the + // QpackInstructionDecoder instance synchronously if OnInstructionDecoded() + // returns false. + virtual bool OnInstructionDecoded(const QpackInstruction* instruction) = 0; + + // Called by QpackInstructionDecoder if an error has occurred. + // No more data is processed afterwards. + // Implementations are allowed to destroy the QpackInstructionDecoder + // instance synchronously. + virtual void OnInstructionDecodingError( + ErrorCode error_code, absl::string_view error_message) = 0; + }; + + // Both |*language| and |*delegate| must outlive this object. + QpackInstructionDecoder(const QpackLanguage* language, Delegate* delegate); + QpackInstructionDecoder() = delete; + QpackInstructionDecoder(const QpackInstructionDecoder&) = delete; + QpackInstructionDecoder& operator=(const QpackInstructionDecoder&) = delete; + + // Provide a data fragment to decode. Must not be called after an error has + // occurred. Must not be called with empty |data|. Return true on success, + // false on error (in which case Delegate::OnInstructionDecodingError() is + // called synchronously). + bool Decode(absl::string_view data); + + // Returns true if no decoding has taken place yet or if the last instruction + // has been entirely parsed. + bool AtInstructionBoundary() const; + + // Accessors for decoded values. Should only be called for fields that are + // part of the most recently decoded instruction, and only after |this| calls + // Delegate::OnInstructionDecoded() but before Decode() is called again. + bool s_bit() const { return s_bit_; } + uint64_t varint() const { return varint_; } + uint64_t varint2() const { return varint2_; } + const std::string& name() const { return name_; } + const std::string& value() const { return value_; } + + private: + enum class State { + // Identify instruction. + kStartInstruction, + // Start decoding next field. + kStartField, + // Read a single bit. + kReadBit, + // Start reading integer. + kVarintStart, + // Resume reading integer. + kVarintResume, + // Done reading integer. + kVarintDone, + // Read string. + kReadString, + // Done reading string. + kReadStringDone + }; + + // One method for each state. They each return true on success, false on + // error (in which case |this| might already be destroyed). Some take input + // data and set |*bytes_consumed| to the number of octets processed. Some + // take input data but do not consume any bytes. Some do not take any + // arguments because they only change internal state. + bool DoStartInstruction(absl::string_view data); + bool DoStartField(); + bool DoReadBit(absl::string_view data); + bool DoVarintStart(absl::string_view data, size_t* bytes_consumed); + bool DoVarintResume(absl::string_view data, size_t* bytes_consumed); + bool DoVarintDone(); + bool DoReadString(absl::string_view data, size_t* bytes_consumed); + bool DoReadStringDone(); + + // Identify instruction based on opcode encoded in |byte|. + // Returns a pointer to an element of |*language_|. + const QpackInstruction* LookupOpcode(uint8_t byte) const; + + // Stops decoding and calls Delegate::OnInstructionDecodingError(). + void OnError(ErrorCode error_code, absl::string_view error_message); + + // Describes the language used for decoding. + const QpackLanguage* const language_; + + // The Delegate to notify of decoded instructions and errors. + Delegate* const delegate_; + + // Storage for decoded field values. + bool s_bit_; + uint64_t varint_; + uint64_t varint2_; + std::string name_; + std::string value_; + // Whether the currently decoded header name or value is Huffman encoded. + bool is_huffman_encoded_; + // Length of string being read into |name_| or |value_|. + size_t string_length_; + + // Decoder instance for decoding integers. + http2::HpackVarintDecoder varint_decoder_; + + // Decoder instance for decoding Huffman encoded strings. + http2::HpackHuffmanDecoder huffman_decoder_; + + // True if a decoding error has been detected by QpackInstructionDecoder. + // Only used in QUICHE_DCHECKs. + bool error_detected_; + + // Decoding state. + State state_; + + // Instruction currently being decoded. + const QpackInstruction* instruction_; + + // Field currently being decoded. + QpackInstructionFields::const_iterator field_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTION_DECODER_H_ diff --git a/quiche/quic/core/qpack/qpack_instruction_decoder_test.cc b/quiche/quic/core/qpack/qpack_instruction_decoder_test.cc new file mode 100644 index 000000000000..1a2aa2cb2aad --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instruction_decoder_test.cc @@ -0,0 +1,222 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_instruction_decoder.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +using ::testing::_; +using ::testing::Eq; +using ::testing::Expectation; +using ::testing::InvokeWithoutArgs; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::Values; + +namespace quic { +namespace test { +namespace { + +// This instruction has three fields: an S bit and two varints. +const QpackInstruction* TestInstruction1() { + static const QpackInstruction* const instruction = + new QpackInstruction{QpackInstructionOpcode{0x00, 0x80}, + {{QpackInstructionFieldType::kSbit, 0x40}, + {QpackInstructionFieldType::kVarint, 6}, + {QpackInstructionFieldType::kVarint2, 8}}}; + return instruction; +} + +// This instruction has two fields: a header name with a 6-bit prefix, and a +// header value with a 7-bit prefix, both preceded by a Huffman bit. +const QpackInstruction* TestInstruction2() { + static const QpackInstruction* const instruction = + new QpackInstruction{QpackInstructionOpcode{0x80, 0x80}, + {{QpackInstructionFieldType::kName, 6}, + {QpackInstructionFieldType::kValue, 7}}}; + return instruction; +} + +const QpackLanguage* TestLanguage() { + static const QpackLanguage* const language = + new QpackLanguage{TestInstruction1(), TestInstruction2()}; + return language; +} + +class MockDelegate : public QpackInstructionDecoder::Delegate { + public: + MockDelegate() { + ON_CALL(*this, OnInstructionDecoded(_)).WillByDefault(Return(true)); + } + + MockDelegate(const MockDelegate&) = delete; + MockDelegate& operator=(const MockDelegate&) = delete; + ~MockDelegate() override = default; + + MOCK_METHOD(bool, OnInstructionDecoded, (const QpackInstruction*), + (override)); + MOCK_METHOD(void, OnInstructionDecodingError, + (QpackInstructionDecoder::ErrorCode error_code, + absl::string_view error_message), + (override)); +}; + +class QpackInstructionDecoderTest : public QuicTestWithParam { + protected: + QpackInstructionDecoderTest() + : decoder_(std::make_unique(TestLanguage(), + &delegate_)), + fragment_mode_(GetParam()) {} + ~QpackInstructionDecoderTest() override = default; + + void SetUp() override { + // Destroy QpackInstructionDecoder on error to test that it does not crash. + // See https://crbug.com/1025209. + ON_CALL(delegate_, OnInstructionDecodingError(_, _)) + .WillByDefault(InvokeWithoutArgs([this]() { decoder_.reset(); })); + } + + // Decode one full instruction with fragment sizes dictated by + // |fragment_mode_|. + // Assumes that |data| is a single complete instruction, and accordingly + // verifies that AtInstructionBoundary() returns true before and after the + // instruction, and returns false while decoding is in progress. + // Assumes that delegate methods destroy |decoder_| if they return false. + void DecodeInstruction(absl::string_view data) { + EXPECT_TRUE(decoder_->AtInstructionBoundary()); + + FragmentSizeGenerator fragment_size_generator = + FragmentModeToFragmentSizeGenerator(fragment_mode_); + + while (!data.empty()) { + size_t fragment_size = std::min(fragment_size_generator(), data.size()); + bool success = decoder_->Decode(data.substr(0, fragment_size)); + if (!decoder_) { + EXPECT_FALSE(success); + return; + } + EXPECT_TRUE(success); + data = data.substr(fragment_size); + if (!data.empty()) { + EXPECT_FALSE(decoder_->AtInstructionBoundary()); + } + } + + EXPECT_TRUE(decoder_->AtInstructionBoundary()); + } + + StrictMock delegate_; + std::unique_ptr decoder_; + + private: + const FragmentMode fragment_mode_; +}; + +INSTANTIATE_TEST_SUITE_P(All, QpackInstructionDecoderTest, + Values(FragmentMode::kSingleChunk, + FragmentMode::kOctetByOctet)); + +TEST_P(QpackInstructionDecoderTest, SBitAndVarint2) { + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction1())); + DecodeInstruction(absl::HexStringToBytes("7f01ff65")); + + EXPECT_TRUE(decoder_->s_bit()); + EXPECT_EQ(64u, decoder_->varint()); + EXPECT_EQ(356u, decoder_->varint2()); + + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction1())); + DecodeInstruction(absl::HexStringToBytes("05c8")); + + EXPECT_FALSE(decoder_->s_bit()); + EXPECT_EQ(5u, decoder_->varint()); + EXPECT_EQ(200u, decoder_->varint2()); +} + +TEST_P(QpackInstructionDecoderTest, NameAndValue) { + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction2())); + DecodeInstruction(absl::HexStringToBytes("83666f6f03626172")); + + EXPECT_EQ("foo", decoder_->name()); + EXPECT_EQ("bar", decoder_->value()); + + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction2())); + DecodeInstruction(absl::HexStringToBytes("8000")); + + EXPECT_EQ("", decoder_->name()); + EXPECT_EQ("", decoder_->value()); + + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction2())); + DecodeInstruction(absl::HexStringToBytes("c294e7838c767f")); + + EXPECT_EQ("foo", decoder_->name()); + EXPECT_EQ("bar", decoder_->value()); +} + +TEST_P(QpackInstructionDecoderTest, InvalidHuffmanEncoding) { + EXPECT_CALL(delegate_, + OnInstructionDecodingError( + QpackInstructionDecoder::ErrorCode::HUFFMAN_ENCODING_ERROR, + Eq("Error in Huffman-encoded string."))); + DecodeInstruction(absl::HexStringToBytes("c1ff")); +} + +TEST_P(QpackInstructionDecoderTest, InvalidVarintEncoding) { + EXPECT_CALL(delegate_, + OnInstructionDecodingError( + QpackInstructionDecoder::ErrorCode::INTEGER_TOO_LARGE, + Eq("Encoded integer too large."))); + DecodeInstruction(absl::HexStringToBytes("ffffffffffffffffffffff")); +} + +TEST_P(QpackInstructionDecoderTest, StringLiteralTooLong) { + EXPECT_CALL(delegate_, + OnInstructionDecodingError( + QpackInstructionDecoder::ErrorCode::STRING_LITERAL_TOO_LONG, + Eq("String literal too long."))); + DecodeInstruction(absl::HexStringToBytes("bfffff7f")); +} + +TEST_P(QpackInstructionDecoderTest, DelegateSignalsError) { + // First instruction is valid. + Expectation first_call = + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction1())) + .WillOnce(InvokeWithoutArgs([this]() -> bool { + EXPECT_EQ(1u, decoder_->varint()); + return true; + })); + + // Second instruction is invalid. Decoding must halt. + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction1())) + .After(first_call) + .WillOnce(InvokeWithoutArgs([this]() -> bool { + EXPECT_EQ(2u, decoder_->varint()); + return false; + })); + + EXPECT_FALSE( + decoder_->Decode(absl::HexStringToBytes("01000200030004000500"))); +} + +// QpackInstructionDecoder must not crash if it is destroyed from a +// Delegate::OnInstructionDecoded() call as long as it returns false. +TEST_P(QpackInstructionDecoderTest, DelegateSignalsErrorAndDestroysDecoder) { + EXPECT_CALL(delegate_, OnInstructionDecoded(TestInstruction1())) + .WillOnce(InvokeWithoutArgs([this]() -> bool { + EXPECT_EQ(1u, decoder_->varint()); + decoder_.reset(); + return false; + })); + DecodeInstruction(absl::HexStringToBytes("0100")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_instruction_encoder.cc b/quiche/quic/core/qpack/qpack_instruction_encoder.cc new file mode 100644 index 000000000000..21f549ccf97c --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instruction_encoder.cc @@ -0,0 +1,176 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_instruction_encoder.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/huffman/hpack_huffman_encoder.h" +#include "quiche/http2/hpack/varint/hpack_varint_encoder.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QpackInstructionEncoder::QpackInstructionEncoder() + : use_huffman_(false), + string_length_(0), + byte_(0), + state_(State::kOpcode), + instruction_(nullptr) {} + +void QpackInstructionEncoder::Encode( + const QpackInstructionWithValues& instruction_with_values, + std::string* output) { + QUICHE_DCHECK(instruction_with_values.instruction()); + + state_ = State::kOpcode; + instruction_ = instruction_with_values.instruction(); + field_ = instruction_->fields.begin(); + + // Field list must not be empty. + QUICHE_DCHECK(field_ != instruction_->fields.end()); + + do { + switch (state_) { + case State::kOpcode: + DoOpcode(); + break; + case State::kStartField: + DoStartField(); + break; + case State::kSbit: + DoSBit(instruction_with_values.s_bit()); + break; + case State::kVarintEncode: + DoVarintEncode(instruction_with_values.varint(), + instruction_with_values.varint2(), output); + break; + case State::kStartString: + DoStartString(instruction_with_values.name(), + instruction_with_values.value()); + break; + case State::kWriteString: + DoWriteString(instruction_with_values.name(), + instruction_with_values.value(), output); + break; + } + } while (field_ != instruction_->fields.end()); + + QUICHE_DCHECK(state_ == State::kStartField); +} + +void QpackInstructionEncoder::DoOpcode() { + QUICHE_DCHECK_EQ(0u, byte_); + + byte_ = instruction_->opcode.value; + + state_ = State::kStartField; +} + +void QpackInstructionEncoder::DoStartField() { + switch (field_->type) { + case QpackInstructionFieldType::kSbit: + state_ = State::kSbit; + return; + case QpackInstructionFieldType::kVarint: + case QpackInstructionFieldType::kVarint2: + state_ = State::kVarintEncode; + return; + case QpackInstructionFieldType::kName: + case QpackInstructionFieldType::kValue: + state_ = State::kStartString; + return; + } +} + +void QpackInstructionEncoder::DoSBit(bool s_bit) { + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kSbit); + + if (s_bit) { + QUICHE_DCHECK_EQ(0, byte_ & field_->param); + + byte_ |= field_->param; + } + + ++field_; + state_ = State::kStartField; +} + +void QpackInstructionEncoder::DoVarintEncode(uint64_t varint, uint64_t varint2, + std::string* output) { + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kVarint || + field_->type == QpackInstructionFieldType::kVarint2 || + field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + uint64_t integer_to_encode; + switch (field_->type) { + case QpackInstructionFieldType::kVarint: + integer_to_encode = varint; + break; + case QpackInstructionFieldType::kVarint2: + integer_to_encode = varint2; + break; + default: + integer_to_encode = string_length_; + break; + } + + http2::HpackVarintEncoder::Encode(byte_, field_->param, integer_to_encode, + output); + byte_ = 0; + + if (field_->type == QpackInstructionFieldType::kVarint || + field_->type == QpackInstructionFieldType::kVarint2) { + ++field_; + state_ = State::kStartField; + return; + } + + state_ = State::kWriteString; +} + +void QpackInstructionEncoder::DoStartString(absl::string_view name, + absl::string_view value) { + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + absl::string_view string_to_write = + (field_->type == QpackInstructionFieldType::kName) ? name : value; + string_length_ = string_to_write.size(); + + size_t encoded_size = http2::HuffmanSize(string_to_write); + use_huffman_ = encoded_size < string_length_; + + if (use_huffman_) { + QUICHE_DCHECK_EQ(0, byte_ & (1 << field_->param)); + byte_ |= (1 << field_->param); + + string_length_ = encoded_size; + } + + state_ = State::kVarintEncode; +} + +void QpackInstructionEncoder::DoWriteString(absl::string_view name, + absl::string_view value, + std::string* output) { + QUICHE_DCHECK(field_->type == QpackInstructionFieldType::kName || + field_->type == QpackInstructionFieldType::kValue); + + absl::string_view string_to_write = + (field_->type == QpackInstructionFieldType::kName) ? name : value; + if (use_huffman_) { + http2::HuffmanEncodeFast(string_to_write, string_length_, output); + } else { + absl::StrAppend(output, string_to_write); + } + + ++field_; + state_ = State::kStartField; +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_instruction_encoder.h b/quiche/quic/core/qpack/qpack_instruction_encoder.h new file mode 100644 index 000000000000..9b7551d33620 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instruction_encoder.h @@ -0,0 +1,83 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTION_ENCODER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTION_ENCODER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Generic instruction encoder class. Takes a QpackLanguage that describes a +// language, that is, a set of instruction opcodes together with a list of +// fields that follow each instruction. +class QUIC_EXPORT_PRIVATE QpackInstructionEncoder { + public: + QpackInstructionEncoder(); + QpackInstructionEncoder(const QpackInstructionEncoder&) = delete; + QpackInstructionEncoder& operator=(const QpackInstructionEncoder&) = delete; + + // Append encoded instruction to |output|. + void Encode(const QpackInstructionWithValues& instruction_with_values, + std::string* output); + + private: + enum class State { + // Write instruction opcode to |byte_|. + kOpcode, + // Select state based on type of current field. + kStartField, + // Write static bit to |byte_|. + kSbit, + // Encode an integer (|varint_| or |varint2_| or string length) with a + // prefix, using |byte_| for the high bits. + kVarintEncode, + // Determine if Huffman encoding should be used for the header name or + // value, set |use_huffman_| and |string_length_| appropriately, write the + // Huffman bit to |byte_|. + kStartString, + // Write header name or value, performing Huffman encoding if |use_huffman_| + // is true. + kWriteString + }; + + // One method for each state. Some append encoded bytes to |output|. + // Some only change internal state. + void DoOpcode(); + void DoStartField(); + void DoSBit(bool s_bit); + void DoVarintEncode(uint64_t varint, uint64_t varint2, std::string* output); + void DoStartString(absl::string_view name, absl::string_view value); + void DoWriteString(absl::string_view name, absl::string_view value, + std::string* output); + + // True if name or value should be Huffman encoded. + bool use_huffman_; + + // Length of name or value string to be written. + // If |use_huffman_| is true, length is after Huffman encoding. + size_t string_length_; + + // Storage for a single byte that contains multiple fields, that is, multiple + // states are writing it. + uint8_t byte_; + + // Encoding state. + State state_; + + // Instruction currently being decoded. + const QpackInstruction* instruction_; + + // Field currently being decoded. + QpackInstructionFields::const_iterator field_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTION_ENCODER_H_ diff --git a/quiche/quic/core/qpack/qpack_instruction_encoder_test.cc b/quiche/quic/core/qpack/qpack_instruction_encoder_test.cc new file mode 100644 index 000000000000..dfdec9e2f04f --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instruction_encoder_test.cc @@ -0,0 +1,204 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_instruction_encoder.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class QpackInstructionWithValuesPeer { + public: + static QpackInstructionWithValues CreateQpackInstructionWithValues( + const QpackInstruction* instruction) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = instruction; + return instruction_with_values; + } + + static void set_s_bit(QpackInstructionWithValues* instruction_with_values, + bool s_bit) { + instruction_with_values->s_bit_ = s_bit; + } + + static void set_varint(QpackInstructionWithValues* instruction_with_values, + uint64_t varint) { + instruction_with_values->varint_ = varint; + } + + static void set_varint2(QpackInstructionWithValues* instruction_with_values, + uint64_t varint2) { + instruction_with_values->varint2_ = varint2; + } + + static void set_name(QpackInstructionWithValues* instruction_with_values, + absl::string_view name) { + instruction_with_values->name_ = name; + } + + static void set_value(QpackInstructionWithValues* instruction_with_values, + absl::string_view value) { + instruction_with_values->value_ = value; + } +}; + +namespace { + +class QpackInstructionEncoderTest : public QuicTest { + protected: + QpackInstructionEncoderTest() : verified_position_(0) {} + ~QpackInstructionEncoderTest() override = default; + + // Append encoded |instruction| to |output_|. + void EncodeInstruction( + const QpackInstructionWithValues& instruction_with_values) { + encoder_.Encode(instruction_with_values, &output_); + } + + // Compare substring appended to |output_| since last EncodedSegmentMatches() + // call against hex-encoded argument. + bool EncodedSegmentMatches(absl::string_view hex_encoded_expected_substring) { + auto recently_encoded = + absl::string_view(output_).substr(verified_position_); + auto expected = absl::HexStringToBytes(hex_encoded_expected_substring); + verified_position_ = output_.size(); + return recently_encoded == expected; + } + + private: + QpackInstructionEncoder encoder_; + std::string output_; + std::string::size_type verified_position_; +}; + +TEST_F(QpackInstructionEncoderTest, Varint) { + const QpackInstruction instruction{QpackInstructionOpcode{0x00, 0x80}, + {{QpackInstructionFieldType::kVarint, 7}}}; + + auto instruction_with_values = + QpackInstructionWithValuesPeer::CreateQpackInstructionWithValues( + &instruction); + QpackInstructionWithValuesPeer::set_varint(&instruction_with_values, 5); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("05")); + + QpackInstructionWithValuesPeer::set_varint(&instruction_with_values, 127); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("7f00")); +} + +TEST_F(QpackInstructionEncoderTest, SBitAndTwoVarint2) { + const QpackInstruction instruction{ + QpackInstructionOpcode{0x80, 0xc0}, + {{QpackInstructionFieldType::kSbit, 0x20}, + {QpackInstructionFieldType::kVarint, 5}, + {QpackInstructionFieldType::kVarint2, 8}}}; + + auto instruction_with_values = + QpackInstructionWithValuesPeer::CreateQpackInstructionWithValues( + &instruction); + QpackInstructionWithValuesPeer::set_s_bit(&instruction_with_values, true); + QpackInstructionWithValuesPeer::set_varint(&instruction_with_values, 5); + QpackInstructionWithValuesPeer::set_varint2(&instruction_with_values, 200); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("a5c8")); + + QpackInstructionWithValuesPeer::set_s_bit(&instruction_with_values, false); + QpackInstructionWithValuesPeer::set_varint(&instruction_with_values, 31); + QpackInstructionWithValuesPeer::set_varint2(&instruction_with_values, 356); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("9f00ff65")); +} + +TEST_F(QpackInstructionEncoderTest, SBitAndVarintAndValue) { + const QpackInstruction instruction{QpackInstructionOpcode{0xc0, 0xc0}, + {{QpackInstructionFieldType::kSbit, 0x20}, + {QpackInstructionFieldType::kVarint, 5}, + {QpackInstructionFieldType::kValue, 7}}}; + + auto instruction_with_values = + QpackInstructionWithValuesPeer::CreateQpackInstructionWithValues( + &instruction); + QpackInstructionWithValuesPeer::set_s_bit(&instruction_with_values, true); + QpackInstructionWithValuesPeer::set_varint(&instruction_with_values, 100); + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, "foo"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("ff458294e7")); + + QpackInstructionWithValuesPeer::set_s_bit(&instruction_with_values, false); + QpackInstructionWithValuesPeer::set_varint(&instruction_with_values, 3); + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, "bar"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("c303626172")); +} + +TEST_F(QpackInstructionEncoderTest, Name) { + const QpackInstruction instruction{QpackInstructionOpcode{0xe0, 0xe0}, + {{QpackInstructionFieldType::kName, 4}}}; + + auto instruction_with_values = + QpackInstructionWithValuesPeer::CreateQpackInstructionWithValues( + &instruction); + QpackInstructionWithValuesPeer::set_name(&instruction_with_values, ""); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("e0")); + + QpackInstructionWithValuesPeer::set_name(&instruction_with_values, "foo"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("f294e7")); + + QpackInstructionWithValuesPeer::set_name(&instruction_with_values, "bar"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("e3626172")); +} + +TEST_F(QpackInstructionEncoderTest, Value) { + const QpackInstruction instruction{QpackInstructionOpcode{0xf0, 0xf0}, + {{QpackInstructionFieldType::kValue, 3}}}; + + auto instruction_with_values = + QpackInstructionWithValuesPeer::CreateQpackInstructionWithValues( + &instruction); + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, ""); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("f0")); + + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, "foo"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("fa94e7")); + + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, "bar"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("f3626172")); +} + +TEST_F(QpackInstructionEncoderTest, SBitAndNameAndValue) { + const QpackInstruction instruction{QpackInstructionOpcode{0xf0, 0xf0}, + {{QpackInstructionFieldType::kSbit, 0x08}, + {QpackInstructionFieldType::kName, 2}, + {QpackInstructionFieldType::kValue, 7}}}; + + auto instruction_with_values = + QpackInstructionWithValuesPeer::CreateQpackInstructionWithValues( + &instruction); + QpackInstructionWithValuesPeer::set_s_bit(&instruction_with_values, false); + QpackInstructionWithValuesPeer::set_name(&instruction_with_values, ""); + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, ""); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("f000")); + + QpackInstructionWithValuesPeer::set_s_bit(&instruction_with_values, true); + QpackInstructionWithValuesPeer::set_name(&instruction_with_values, "foo"); + QpackInstructionWithValuesPeer::set_value(&instruction_with_values, "bar"); + EncodeInstruction(instruction_with_values); + EXPECT_TRUE(EncodedSegmentMatches("fe94e703626172")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_instructions.cc b/quiche/quic/core/qpack/qpack_instructions.cc new file mode 100644 index 000000000000..9de0dcdea5a8 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instructions.cc @@ -0,0 +1,326 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_instructions.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Validate that +// * in each instruction, the bits of |value| that are zero in |mask| are zero; +// * every byte matches exactly one opcode. +void ValidateLangague(const QpackLanguage* language) { +#ifndef NDEBUG + for (const auto* instruction : *language) { + QUICHE_DCHECK_EQ(0, instruction->opcode.value & ~instruction->opcode.mask); + } + + for (uint8_t byte = 0; byte < std::numeric_limits::max(); ++byte) { + size_t match_count = 0; + for (const auto* instruction : *language) { + if ((byte & instruction->opcode.mask) == instruction->opcode.value) { + ++match_count; + } + } + QUICHE_DCHECK_EQ(1u, match_count) << static_cast(byte); + } +#else + (void)language; +#endif +} + +} // namespace + +bool operator==(const QpackInstructionOpcode& a, + const QpackInstructionOpcode& b) { + return std::tie(a.value, a.mask) == std::tie(b.value, b.mask); +} + +const QpackInstruction* InsertWithNameReferenceInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b10000000, 0b10000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kSbit, 0b01000000}, + {QpackInstructionFieldType::kVarint, 6}, + {QpackInstructionFieldType::kValue, 7}}}; + return instruction; +} + +const QpackInstruction* InsertWithoutNameReferenceInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b01000000, 0b11000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kName, 5}, + {QpackInstructionFieldType::kValue, 7}}}; + return instruction; +} + +const QpackInstruction* DuplicateInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00000000, 0b11100000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, {{QpackInstructionFieldType::kVarint, 5}}}; + return instruction; +} + +const QpackInstruction* SetDynamicTableCapacityInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00100000, 0b11100000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, {{QpackInstructionFieldType::kVarint, 5}}}; + return instruction; +} + +const QpackLanguage* QpackEncoderStreamLanguage() { + static const QpackLanguage* const language = new QpackLanguage{ + InsertWithNameReferenceInstruction(), + InsertWithoutNameReferenceInstruction(), DuplicateInstruction(), + SetDynamicTableCapacityInstruction()}; + ValidateLangague(language); + return language; +} + +const QpackInstruction* InsertCountIncrementInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00000000, 0b11000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, {{QpackInstructionFieldType::kVarint, 6}}}; + return instruction; +} + +const QpackInstruction* HeaderAcknowledgementInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b10000000, 0b10000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, {{QpackInstructionFieldType::kVarint, 7}}}; + return instruction; +} + +const QpackInstruction* StreamCancellationInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b01000000, 0b11000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, {{QpackInstructionFieldType::kVarint, 6}}}; + return instruction; +} + +const QpackLanguage* QpackDecoderStreamLanguage() { + static const QpackLanguage* const language = new QpackLanguage{ + InsertCountIncrementInstruction(), HeaderAcknowledgementInstruction(), + StreamCancellationInstruction()}; + ValidateLangague(language); + return language; +} + +const QpackInstruction* QpackPrefixInstruction() { + // This opcode matches every input. + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00000000, 0b00000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kVarint, 8}, + {QpackInstructionFieldType::kSbit, 0b10000000}, + {QpackInstructionFieldType::kVarint2, 7}}}; + return instruction; +} + +const QpackLanguage* QpackPrefixLanguage() { + static const QpackLanguage* const language = + new QpackLanguage{QpackPrefixInstruction()}; + ValidateLangague(language); + return language; +} + +const QpackInstruction* QpackIndexedHeaderFieldInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b10000000, 0b10000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kSbit, 0b01000000}, + {QpackInstructionFieldType::kVarint, 6}}}; + return instruction; +} + +const QpackInstruction* QpackIndexedHeaderFieldPostBaseInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00010000, 0b11110000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, {{QpackInstructionFieldType::kVarint, 4}}}; + return instruction; +} + +const QpackInstruction* QpackLiteralHeaderFieldNameReferenceInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b01000000, 0b11000000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kSbit, 0b00010000}, + {QpackInstructionFieldType::kVarint, 4}, + {QpackInstructionFieldType::kValue, 7}}}; + return instruction; +} + +const QpackInstruction* QpackLiteralHeaderFieldPostBaseInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00000000, 0b11110000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kVarint, 3}, + {QpackInstructionFieldType::kValue, 7}}}; + return instruction; +} + +const QpackInstruction* QpackLiteralHeaderFieldInstruction() { + static const QpackInstructionOpcode* const opcode = + new QpackInstructionOpcode{0b00100000, 0b11100000}; + static const QpackInstruction* const instruction = + new QpackInstruction{*opcode, + {{QpackInstructionFieldType::kName, 3}, + {QpackInstructionFieldType::kValue, 7}}}; + return instruction; +} + +const QpackLanguage* QpackRequestStreamLanguage() { + static const QpackLanguage* const language = + new QpackLanguage{QpackIndexedHeaderFieldInstruction(), + QpackIndexedHeaderFieldPostBaseInstruction(), + QpackLiteralHeaderFieldNameReferenceInstruction(), + QpackLiteralHeaderFieldPostBaseInstruction(), + QpackLiteralHeaderFieldInstruction()}; + ValidateLangague(language); + return language; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::InsertWithNameReference( + bool is_static, uint64_t name_index, absl::string_view value) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = InsertWithNameReferenceInstruction(); + instruction_with_values.s_bit_ = is_static; + instruction_with_values.varint_ = name_index; + instruction_with_values.value_ = value; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues +QpackInstructionWithValues::InsertWithoutNameReference( + absl::string_view name, absl::string_view value) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = + InsertWithoutNameReferenceInstruction(); + instruction_with_values.name_ = name; + instruction_with_values.value_ = value; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::Duplicate( + uint64_t index) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = DuplicateInstruction(); + instruction_with_values.varint_ = index; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::SetDynamicTableCapacity( + uint64_t capacity) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = SetDynamicTableCapacityInstruction(); + instruction_with_values.varint_ = capacity; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::InsertCountIncrement( + uint64_t increment) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = InsertCountIncrementInstruction(); + instruction_with_values.varint_ = increment; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::HeaderAcknowledgement( + uint64_t stream_id) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = HeaderAcknowledgementInstruction(); + instruction_with_values.varint_ = stream_id; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::StreamCancellation( + uint64_t stream_id) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = StreamCancellationInstruction(); + instruction_with_values.varint_ = stream_id; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::Prefix( + uint64_t required_insert_count) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = QpackPrefixInstruction(); + instruction_with_values.varint_ = required_insert_count; + instruction_with_values.varint2_ = 0; // Delta Base. + instruction_with_values.s_bit_ = false; // Delta Base sign. + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::IndexedHeaderField( + bool is_static, uint64_t index) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = QpackIndexedHeaderFieldInstruction(); + instruction_with_values.s_bit_ = is_static; + instruction_with_values.varint_ = index; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues +QpackInstructionWithValues::LiteralHeaderFieldNameReference( + bool is_static, uint64_t index, absl::string_view value) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = + QpackLiteralHeaderFieldNameReferenceInstruction(); + instruction_with_values.s_bit_ = is_static; + instruction_with_values.varint_ = index; + instruction_with_values.value_ = value; + + return instruction_with_values; +} + +// static +QpackInstructionWithValues QpackInstructionWithValues::LiteralHeaderField( + absl::string_view name, absl::string_view value) { + QpackInstructionWithValues instruction_with_values; + instruction_with_values.instruction_ = QpackLiteralHeaderFieldInstruction(); + instruction_with_values.name_ = name; + instruction_with_values.value_ = value; + + return instruction_with_values; +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_instructions.h b/quiche/quic/core/qpack/qpack_instructions.h new file mode 100644 index 000000000000..a96424278c06 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_instructions.h @@ -0,0 +1,205 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTIONS_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTIONS_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QpackInstructionWithValuesPeer; +} // namespace test + +// Each instruction is identified with an opcode in the first byte. +// |mask| determines which bits are part of the opcode. +// |value| is the value of these bits. (Other bits in value must be zero.) +struct QUIC_EXPORT_PRIVATE QpackInstructionOpcode { + uint8_t value; + uint8_t mask; +}; + +bool operator==(const QpackInstructionOpcode& a, + const QpackInstructionOpcode& b); + +// Possible types of an instruction field. Decoding a static bit does not +// consume the current byte. Decoding an integer or a length-prefixed string +// literal consumes all bytes containing the field value. +enum class QpackInstructionFieldType { + // A single bit indicating whether the index refers to the static table, or + // indicating the sign of Delta Base. Called "S" bit because both "static" + // and "sign" start with the letter "S". + kSbit, + // An integer encoded with variable length encoding. This could be an index, + // stream ID, maximum size, or Encoded Required Insert Count. + kVarint, + // A second integer encoded with variable length encoding. This could be + // Delta Base. + kVarint2, + // A header name or header value encoded as: + // a bit indicating whether it is Huffman encoded; + // the encoded length of the string; + // the header name or value optionally Huffman encoded. + kName, + kValue +}; + +// Each instruction field has a type and a parameter. +// The meaning of the parameter depends on the field type. +struct QUIC_EXPORT_PRIVATE QpackInstructionField { + QpackInstructionFieldType type; + // For a kSbit field, |param| is a mask with exactly one bit set. + // For kVarint fields, |param| is the prefix length of the integer encoding. + // For kName and kValue fields, |param| is the prefix length of the length of + // the string, and the bit immediately preceding the prefix is interpreted as + // the Huffman bit. + uint8_t param; +}; + +using QpackInstructionFields = std::vector; + +// A QPACK instruction consists of an opcode identifying the instruction, +// followed by a non-empty list of fields. The last field must be integer or +// string literal type to guarantee that all bytes of the instruction are +// consumed. +struct QUIC_EXPORT_PRIVATE QpackInstruction { + QpackInstruction(QpackInstructionOpcode opcode, QpackInstructionFields fields) + : opcode(std::move(opcode)), fields(std::move(fields)) {} + + QpackInstruction(const QpackInstruction&) = delete; + const QpackInstruction& operator=(const QpackInstruction&) = delete; + + QpackInstructionOpcode opcode; + QpackInstructionFields fields; +}; + +// A language is a collection of instructions. The order does not matter. +// Every possible input must match exactly one instruction. +using QpackLanguage = std::vector; + +// Wire format defined in +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#rfc.section.5 + +// 5.2 Encoder stream instructions + +// 5.2.1 Insert With Name Reference +const QpackInstruction* InsertWithNameReferenceInstruction(); + +// 5.2.2 Insert Without Name Reference +const QpackInstruction* InsertWithoutNameReferenceInstruction(); + +// 5.2.3 Duplicate +const QpackInstruction* DuplicateInstruction(); + +// 5.2.4 Dynamic Table Size Update +const QpackInstruction* SetDynamicTableCapacityInstruction(); + +// Encoder stream language. +const QpackLanguage* QpackEncoderStreamLanguage(); + +// 5.3 Decoder stream instructions + +// 5.3.1 Insert Count Increment +const QpackInstruction* InsertCountIncrementInstruction(); + +// 5.3.2 Header Acknowledgement +const QpackInstruction* HeaderAcknowledgementInstruction(); + +// 5.3.3 Stream Cancellation +const QpackInstruction* StreamCancellationInstruction(); + +// Decoder stream language. +const QpackLanguage* QpackDecoderStreamLanguage(); + +// 5.4.1. Header data prefix instructions + +const QpackInstruction* QpackPrefixInstruction(); + +const QpackLanguage* QpackPrefixLanguage(); + +// 5.4.2. Request and push stream instructions + +// 5.4.2.1. Indexed Header Field +const QpackInstruction* QpackIndexedHeaderFieldInstruction(); + +// 5.4.2.2. Indexed Header Field With Post-Base Index +const QpackInstruction* QpackIndexedHeaderFieldPostBaseInstruction(); + +// 5.4.2.3. Literal Header Field With Name Reference +const QpackInstruction* QpackLiteralHeaderFieldNameReferenceInstruction(); + +// 5.4.2.4. Literal Header Field With Post-Base Name Reference +const QpackInstruction* QpackLiteralHeaderFieldPostBaseInstruction(); + +// 5.4.2.5. Literal Header Field Without Name Reference +const QpackInstruction* QpackLiteralHeaderFieldInstruction(); + +// Request and push stream language. +const QpackLanguage* QpackRequestStreamLanguage(); + +// Storage for instruction and field values to be encoded. +// This class can only be instantiated using factory methods that take exactly +// the arguments that the corresponding instruction needs. +class QUIC_EXPORT_PRIVATE QpackInstructionWithValues { + public: + // 5.2 Encoder stream instructions + static QpackInstructionWithValues InsertWithNameReference( + bool is_static, uint64_t name_index, absl::string_view value); + static QpackInstructionWithValues InsertWithoutNameReference( + absl::string_view name, absl::string_view value); + static QpackInstructionWithValues Duplicate(uint64_t index); + static QpackInstructionWithValues SetDynamicTableCapacity(uint64_t capacity); + + // 5.3 Decoder stream instructions + static QpackInstructionWithValues InsertCountIncrement(uint64_t increment); + static QpackInstructionWithValues HeaderAcknowledgement(uint64_t stream_id); + static QpackInstructionWithValues StreamCancellation(uint64_t stream_id); + + // 5.4.1. Header data prefix. Delta Base is hardcoded to be zero. + static QpackInstructionWithValues Prefix(uint64_t required_insert_count); + + // 5.4.2. Request and push stream instructions + static QpackInstructionWithValues IndexedHeaderField(bool is_static, + uint64_t index); + static QpackInstructionWithValues LiteralHeaderFieldNameReference( + bool is_static, uint64_t index, absl::string_view value); + static QpackInstructionWithValues LiteralHeaderField(absl::string_view name, + absl::string_view value); + + const QpackInstruction* instruction() const { return instruction_; } + bool s_bit() const { return s_bit_; } + uint64_t varint() const { return varint_; } + uint64_t varint2() const { return varint2_; } + absl::string_view name() const { return name_; } + absl::string_view value() const { return value_; } + + // Used by QpackEncoder, because in the first pass it stores absolute indices, + // which are converted into relative indices in the second pass after base is + // determined. + void set_varint(uint64_t varint) { varint_ = varint; } + + private: + friend test::QpackInstructionWithValuesPeer; + + QpackInstructionWithValues() = default; + + // |*instruction| is not owned. + const QpackInstruction* instruction_ = nullptr; + bool s_bit_ = false; + uint64_t varint_ = 0; + uint64_t varint2_ = 0; + absl::string_view name_; + absl::string_view value_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_INSTRUCTIONS_H_ diff --git a/quiche/quic/core/qpack/qpack_progressive_decoder.cc b/quiche/quic/core/qpack/qpack_progressive_decoder.cc new file mode 100644 index 000000000000..5dc5a68cc209 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_progressive_decoder.cc @@ -0,0 +1,406 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_progressive_decoder.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_index_conversions.h" +#include "quiche/quic/core/qpack/qpack_instructions.h" +#include "quiche/quic/core/qpack/qpack_required_insert_count.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// The value argument passed to OnHeaderDecoded() is from an entry in the static +// table. +constexpr bool kValueFromStaticTable = true; + +} // anonymous namespace + +QpackProgressiveDecoder::QpackProgressiveDecoder( + QuicStreamId stream_id, BlockedStreamLimitEnforcer* enforcer, + DecodingCompletedVisitor* visitor, QpackDecoderHeaderTable* header_table, + HeadersHandlerInterface* handler) + : stream_id_(stream_id), + prefix_decoder_(std::make_unique( + QpackPrefixLanguage(), this)), + instruction_decoder_(QpackRequestStreamLanguage(), this), + enforcer_(enforcer), + visitor_(visitor), + header_table_(header_table), + handler_(handler), + required_insert_count_(0), + base_(0), + required_insert_count_so_far_(0), + prefix_decoded_(false), + blocked_(false), + decoding_(true), + error_detected_(false), + cancelled_(false) {} + +QpackProgressiveDecoder::~QpackProgressiveDecoder() { + if (blocked_ && !cancelled_) { + header_table_->UnregisterObserver(required_insert_count_, this); + } +} + +void QpackProgressiveDecoder::Decode(absl::string_view data) { + QUICHE_DCHECK(decoding_); + + if (data.empty() || error_detected_) { + return; + } + + // Decode prefix byte by byte until the first (and only) instruction is + // decoded. + while (!prefix_decoded_) { + QUICHE_DCHECK(!blocked_); + + if (!prefix_decoder_->Decode(data.substr(0, 1))) { + return; + } + + // |prefix_decoder_->Decode()| must return false if an error is detected. + QUICHE_DCHECK(!error_detected_); + + data = data.substr(1); + if (data.empty()) { + return; + } + } + + if (blocked_) { + buffer_.append(data.data(), data.size()); + } else { + QUICHE_DCHECK(buffer_.empty()); + + instruction_decoder_.Decode(data); + } +} + +void QpackProgressiveDecoder::EndHeaderBlock() { + QUICHE_DCHECK(decoding_); + decoding_ = false; + + if (!blocked_) { + FinishDecoding(); + } +} + +bool QpackProgressiveDecoder::OnInstructionDecoded( + const QpackInstruction* instruction) { + if (instruction == QpackPrefixInstruction()) { + return DoPrefixInstruction(); + } + + QUICHE_DCHECK(prefix_decoded_); + QUICHE_DCHECK_LE(required_insert_count_, + header_table_->inserted_entry_count()); + + if (instruction == QpackIndexedHeaderFieldInstruction()) { + return DoIndexedHeaderFieldInstruction(); + } + if (instruction == QpackIndexedHeaderFieldPostBaseInstruction()) { + return DoIndexedHeaderFieldPostBaseInstruction(); + } + if (instruction == QpackLiteralHeaderFieldNameReferenceInstruction()) { + return DoLiteralHeaderFieldNameReferenceInstruction(); + } + if (instruction == QpackLiteralHeaderFieldPostBaseInstruction()) { + return DoLiteralHeaderFieldPostBaseInstruction(); + } + QUICHE_DCHECK_EQ(instruction, QpackLiteralHeaderFieldInstruction()); + return DoLiteralHeaderFieldInstruction(); +} + +void QpackProgressiveDecoder::OnInstructionDecodingError( + QpackInstructionDecoder::ErrorCode /* error_code */, + absl::string_view error_message) { + // Ignore |error_code| and always use QUIC_QPACK_DECOMPRESSION_FAILED to avoid + // having to define a new QuicErrorCode for every instruction decoder error. + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, error_message); +} + +void QpackProgressiveDecoder::OnInsertCountReachedThreshold() { + QUICHE_DCHECK(blocked_); + + // Clear |blocked_| before calling instruction_decoder_.Decode() below, + // because that might destroy |this| and ~QpackProgressiveDecoder() needs to + // know not to call UnregisterObserver(). + blocked_ = false; + enforcer_->OnStreamUnblocked(stream_id_); + + if (!buffer_.empty()) { + std::string buffer(std::move(buffer_)); + buffer_.clear(); + if (!instruction_decoder_.Decode(buffer)) { + // |this| might be destroyed. + return; + } + } + + if (!decoding_) { + FinishDecoding(); + } +} + +void QpackProgressiveDecoder::Cancel() { cancelled_ = true; } + +bool QpackProgressiveDecoder::DoIndexedHeaderFieldInstruction() { + if (!instruction_decoder_.s_bit()) { + uint64_t absolute_index; + if (!QpackRequestStreamRelativeIndexToAbsoluteIndex( + instruction_decoder_.varint(), base_, &absolute_index)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid relative index."); + return false; + } + + if (absolute_index >= required_insert_count_) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); + return false; + } + + QUICHE_DCHECK_LT(absolute_index, std::numeric_limits::max()); + required_insert_count_so_far_ = + std::max(required_insert_count_so_far_, absolute_index + 1); + + auto entry = + header_table_->LookupEntry(/* is_static = */ false, absolute_index); + if (!entry) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); + return false; + } + + header_table_->set_dynamic_table_entry_referenced(); + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), + entry->value()); + } + + auto entry = header_table_->LookupEntry(/* is_static = */ true, + instruction_decoder_.varint()); + if (!entry) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Static table entry not found."); + return false; + } + + return OnHeaderDecoded(kValueFromStaticTable, entry->name(), entry->value()); +} + +bool QpackProgressiveDecoder::DoIndexedHeaderFieldPostBaseInstruction() { + uint64_t absolute_index; + if (!QpackPostBaseIndexToAbsoluteIndex(instruction_decoder_.varint(), base_, + &absolute_index)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid post-base index."); + return false; + } + + if (absolute_index >= required_insert_count_) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); + return false; + } + + QUICHE_DCHECK_LT(absolute_index, std::numeric_limits::max()); + required_insert_count_so_far_ = + std::max(required_insert_count_so_far_, absolute_index + 1); + + auto entry = + header_table_->LookupEntry(/* is_static = */ false, absolute_index); + if (!entry) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); + return false; + } + + header_table_->set_dynamic_table_entry_referenced(); + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), entry->value()); +} + +bool QpackProgressiveDecoder::DoLiteralHeaderFieldNameReferenceInstruction() { + if (!instruction_decoder_.s_bit()) { + uint64_t absolute_index; + if (!QpackRequestStreamRelativeIndexToAbsoluteIndex( + instruction_decoder_.varint(), base_, &absolute_index)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid relative index."); + return false; + } + + if (absolute_index >= required_insert_count_) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); + return false; + } + + QUICHE_DCHECK_LT(absolute_index, std::numeric_limits::max()); + required_insert_count_so_far_ = + std::max(required_insert_count_so_far_, absolute_index + 1); + + auto entry = + header_table_->LookupEntry(/* is_static = */ false, absolute_index); + if (!entry) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); + return false; + } + + header_table_->set_dynamic_table_entry_referenced(); + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), + instruction_decoder_.value()); + } + + auto entry = header_table_->LookupEntry(/* is_static = */ true, + instruction_decoder_.varint()); + if (!entry) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Static table entry not found."); + return false; + } + + return OnHeaderDecoded(kValueFromStaticTable, entry->name(), + instruction_decoder_.value()); +} + +bool QpackProgressiveDecoder::DoLiteralHeaderFieldPostBaseInstruction() { + uint64_t absolute_index; + if (!QpackPostBaseIndexToAbsoluteIndex(instruction_decoder_.varint(), base_, + &absolute_index)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid post-base index."); + return false; + } + + if (absolute_index >= required_insert_count_) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); + return false; + } + + QUICHE_DCHECK_LT(absolute_index, std::numeric_limits::max()); + required_insert_count_so_far_ = + std::max(required_insert_count_so_far_, absolute_index + 1); + + auto entry = + header_table_->LookupEntry(/* is_static = */ false, absolute_index); + if (!entry) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); + return false; + } + + header_table_->set_dynamic_table_entry_referenced(); + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), + instruction_decoder_.value()); +} + +bool QpackProgressiveDecoder::DoLiteralHeaderFieldInstruction() { + return OnHeaderDecoded(!kValueFromStaticTable, instruction_decoder_.name(), + instruction_decoder_.value()); +} + +bool QpackProgressiveDecoder::DoPrefixInstruction() { + QUICHE_DCHECK(!prefix_decoded_); + + if (!QpackDecodeRequiredInsertCount( + prefix_decoder_->varint(), header_table_->max_entries(), + header_table_->inserted_entry_count(), &required_insert_count_)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Error decoding Required Insert Count."); + return false; + } + + const bool sign = prefix_decoder_->s_bit(); + const uint64_t delta_base = prefix_decoder_->varint2(); + if (!DeltaBaseToBase(sign, delta_base, &base_)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Error calculating Base."); + return false; + } + + prefix_decoded_ = true; + + if (required_insert_count_ > header_table_->inserted_entry_count()) { + if (!enforcer_->OnStreamBlocked(stream_id_)) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Limit on number of blocked streams exceeded."); + return false; + } + blocked_ = true; + header_table_->RegisterObserver(required_insert_count_, this); + } + + return true; +} + +bool QpackProgressiveDecoder::OnHeaderDecoded(bool /*value_from_static_table*/, + absl::string_view name, + absl::string_view value) { + handler_->OnHeaderDecoded(name, value); + return true; +} + +void QpackProgressiveDecoder::FinishDecoding() { + QUICHE_DCHECK(buffer_.empty()); + QUICHE_DCHECK(!blocked_); + QUICHE_DCHECK(!decoding_); + + if (error_detected_) { + return; + } + + if (!instruction_decoder_.AtInstructionBoundary()) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Incomplete header block."); + return; + } + + if (!prefix_decoded_) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Incomplete header data prefix."); + return; + } + + if (required_insert_count_ != required_insert_count_so_far_) { + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Required Insert Count too large."); + return; + } + + visitor_->OnDecodingCompleted(stream_id_, required_insert_count_); + handler_->OnDecodingCompleted(); +} + +void QpackProgressiveDecoder::OnError(QuicErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(!error_detected_); + + error_detected_ = true; + // Might destroy |this|. + handler_->OnDecodingErrorDetected(error_code, error_message); +} + +bool QpackProgressiveDecoder::DeltaBaseToBase(bool sign, uint64_t delta_base, + uint64_t* base) { + if (sign) { + if (delta_base == std::numeric_limits::max() || + required_insert_count_ < delta_base + 1) { + return false; + } + *base = required_insert_count_ - delta_base - 1; + return true; + } + + if (delta_base > + std::numeric_limits::max() - required_insert_count_) { + return false; + } + *base = required_insert_count_ + delta_base; + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_progressive_decoder.h b/quiche/quic/core/qpack/qpack_progressive_decoder.h new file mode 100644 index 000000000000..7616376574e5 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_progressive_decoder.h @@ -0,0 +1,183 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_PROGRESSIVE_DECODER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_PROGRESSIVE_DECODER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_encoder_stream_receiver.h" +#include "quiche/quic/core/qpack/qpack_header_table.h" +#include "quiche/quic/core/qpack/qpack_instruction_decoder.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QpackDecoderHeaderTable; + +// Class to decode a single header block. +class QUIC_EXPORT_PRIVATE QpackProgressiveDecoder + : public QpackInstructionDecoder::Delegate, + public QpackDecoderHeaderTable::Observer { + public: + // Interface for receiving decoded header block from the decoder. + class QUIC_EXPORT_PRIVATE HeadersHandlerInterface { + public: + virtual ~HeadersHandlerInterface() {} + + // Called when a new header name-value pair is decoded. Multiple values for + // a given name will be emitted as multiple calls to OnHeader. + virtual void OnHeaderDecoded(absl::string_view name, + absl::string_view value) = 0; + + // Called when the header block is completely decoded. + // Indicates the total number of bytes in this block. + // The decoder will not access the handler after this call. + // Note that this method might not be called synchronously when the header + // block is received on the wire, in case decoding is blocked on receiving + // entries on the encoder stream. + virtual void OnDecodingCompleted() = 0; + + // Called when a decoding error has occurred. No other methods will be + // called afterwards. Implementations are allowed to destroy + // the QpackProgressiveDecoder instance synchronously. + virtual void OnDecodingErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) = 0; + }; + + // Interface for keeping track of blocked streams for the purpose of enforcing + // the limit communicated to peer via QPACK_BLOCKED_STREAMS settings. + class QUIC_EXPORT_PRIVATE BlockedStreamLimitEnforcer { + public: + virtual ~BlockedStreamLimitEnforcer() {} + + // Called when the stream becomes blocked. Returns true if allowed. Returns + // false if limit is violated, in which case QpackProgressiveDecoder signals + // an error. + // Stream must not be already blocked. + virtual bool OnStreamBlocked(QuicStreamId stream_id) = 0; + + // Called when the stream becomes unblocked. + // Stream must be blocked. + virtual void OnStreamUnblocked(QuicStreamId stream_id) = 0; + }; + + // Visitor to be notified when decoding is completed. + class QUIC_EXPORT_PRIVATE DecodingCompletedVisitor { + public: + virtual ~DecodingCompletedVisitor() = default; + + // Called when decoding is completed, with Required Insert Count of the + // decoded header block. Required Insert Count is defined at + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#blocked-streams. + virtual void OnDecodingCompleted(QuicStreamId stream_id, + uint64_t required_insert_count) = 0; + }; + + QpackProgressiveDecoder() = delete; + QpackProgressiveDecoder(QuicStreamId stream_id, + BlockedStreamLimitEnforcer* enforcer, + DecodingCompletedVisitor* visitor, + QpackDecoderHeaderTable* header_table, + HeadersHandlerInterface* handler); + QpackProgressiveDecoder(const QpackProgressiveDecoder&) = delete; + QpackProgressiveDecoder& operator=(const QpackProgressiveDecoder&) = delete; + ~QpackProgressiveDecoder() override; + + // Provide a data fragment to decode. + void Decode(absl::string_view data); + + // Signal that the entire header block has been received and passed in + // through Decode(). No methods must be called afterwards. + void EndHeaderBlock(); + + // QpackInstructionDecoder::Delegate implementation. + bool OnInstructionDecoded(const QpackInstruction* instruction) override; + void OnInstructionDecodingError(QpackInstructionDecoder::ErrorCode error_code, + absl::string_view error_message) override; + + // QpackDecoderHeaderTable::Observer implementation. + void OnInsertCountReachedThreshold() override; + void Cancel() override; + + private: + bool DoIndexedHeaderFieldInstruction(); + bool DoIndexedHeaderFieldPostBaseInstruction(); + bool DoLiteralHeaderFieldNameReferenceInstruction(); + bool DoLiteralHeaderFieldPostBaseInstruction(); + bool DoLiteralHeaderFieldInstruction(); + bool DoPrefixInstruction(); + + // Called when an entry is decoded. Performs validation and calls + // HeadersHandlerInterface::OnHeaderDecoded() or OnError() as needed. Returns + // true if header value is valid, false otherwise. Skips validation if + // |value_from_static_table| is true, because static table entries are always + // valid. + bool OnHeaderDecoded(bool value_from_static_table, absl::string_view name, + absl::string_view value); + + // Called as soon as EndHeaderBlock() is called and decoding is not blocked. + void FinishDecoding(); + + // Called on error. + void OnError(QuicErrorCode error_code, absl::string_view error_message); + + // Calculates Base from |required_insert_count_|, which must be set before + // calling this method, and sign bit and Delta Base in the Header Data Prefix, + // which are passed in as arguments. Returns true on success, false on + // failure due to overflow/underflow. + bool DeltaBaseToBase(bool sign, uint64_t delta_base, uint64_t* base); + + const QuicStreamId stream_id_; + + // |prefix_decoder_| only decodes a handful of bytes then it can be + // destroyed to conserve memory. |instruction_decoder_|, on the other hand, + // is used until the entire header block is decoded. + std::unique_ptr prefix_decoder_; + QpackInstructionDecoder instruction_decoder_; + + BlockedStreamLimitEnforcer* const enforcer_; + DecodingCompletedVisitor* const visitor_; + QpackDecoderHeaderTable* const header_table_; + HeadersHandlerInterface* const handler_; + + // Required Insert Count and Base are decoded from the Header Data Prefix. + uint64_t required_insert_count_; + uint64_t base_; + + // Required Insert Count is one larger than the largest absolute index of all + // referenced dynamic table entries, or zero if no dynamic table entries are + // referenced. |required_insert_count_so_far_| starts out as zero and keeps + // track of the Required Insert Count based on entries decoded so far. + // After decoding is completed, it is compared to |required_insert_count_|. + uint64_t required_insert_count_so_far_; + + // False until prefix is fully read and decoded. + bool prefix_decoded_; + + // True if waiting for dynamic table entries to arrive. + bool blocked_; + + // Buffer the entire header block after the prefix while decoding is blocked. + std::string buffer_; + + // True until EndHeaderBlock() is called. + bool decoding_; + + // True if a decoding error has been detected. + bool error_detected_; + + // True if QpackDecoderHeaderTable has been destroyed + // while decoding is still blocked. + bool cancelled_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_PROGRESSIVE_DECODER_H_ diff --git a/quiche/quic/core/qpack/qpack_receive_stream.cc b/quiche/quic/core/qpack/qpack_receive_stream.cc new file mode 100644 index 000000000000..b00a06f4c370 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_receive_stream.cc @@ -0,0 +1,33 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_receive_stream.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_session.h" + +namespace quic { +QpackReceiveStream::QpackReceiveStream(PendingStream* pending, + QuicSession* session, + QpackStreamReceiver* receiver) + : QuicStream(pending, session, /*is_static=*/true), receiver_(receiver) {} + +void QpackReceiveStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { + stream_delegate()->OnStreamError( + QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "RESET_STREAM received for QPACK receive stream"); +} + +void QpackReceiveStream::OnDataAvailable() { + iovec iov; + while (!reading_stopped() && sequencer()->GetReadableRegion(&iov)) { + QUICHE_DCHECK(!sequencer()->IsClosed()); + + receiver_->Decode(absl::string_view( + reinterpret_cast(iov.iov_base), iov.iov_len)); + sequencer()->MarkConsumed(iov.iov_len); + } +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_receive_stream.h b/quiche/quic/core/qpack/qpack_receive_stream.h new file mode 100644 index 000000000000..20ad878fd153 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_receive_stream.h @@ -0,0 +1,41 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_RECEIVE_STREAM_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_RECEIVE_STREAM_H_ + +#include "quiche/quic/core/qpack/qpack_stream_receiver.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QuicSession; + +// QPACK 4.2.1 Encoder and Decoder Streams. +// The QPACK receive stream is peer initiated and is read only. +class QUIC_EXPORT_PRIVATE QpackReceiveStream : public QuicStream { + public: + // Construct receive stream from pending stream, the |pending| object needs + // to be deleted after the construction. + QpackReceiveStream(PendingStream* pending, QuicSession* session, + QpackStreamReceiver* receiver); + QpackReceiveStream(const QpackReceiveStream&) = delete; + QpackReceiveStream& operator=(const QpackReceiveStream&) = delete; + ~QpackReceiveStream() override = default; + + // Overriding QuicStream::OnStreamReset to make sure QPACK stream is never + // closed before connection. + void OnStreamReset(const QuicRstStreamFrame& frame) override; + + // Implementation of QuicStream. + void OnDataAvailable() override; + + private: + QpackStreamReceiver* receiver_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_RECEIVE_STREAM_H_ diff --git a/quiche/quic/core/qpack/qpack_receive_stream_test.cc b/quiche/quic/core/qpack/qpack_receive_stream_test.cc new file mode 100644 index 000000000000..4f94802e1387 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_receive_stream_test.cc @@ -0,0 +1,95 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_receive_stream.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +namespace { +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::StrictMock; + +struct TestParams { + TestParams(const ParsedQuicVersion& version, Perspective perspective) + : version(version), perspective(perspective) { + QUIC_LOG(INFO) << "TestParams: version: " + << ParsedQuicVersionToString(version) + << ", perspective: " << perspective; + } + + TestParams(const TestParams& other) + : version(other.version), perspective(other.perspective) {} + + ParsedQuicVersion version; + Perspective perspective; +}; + +std::vector GetTestParams() { + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (const auto& version : AllSupportedVersions()) { + if (!VersionUsesHttp3(version.transport_version)) { + continue; + } + for (Perspective p : {Perspective::IS_SERVER, Perspective::IS_CLIENT}) { + params.emplace_back(version, p); + } + } + return params; +} + +class QpackReceiveStreamTest : public QuicTestWithParam { + public: + QpackReceiveStreamTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective(), + SupportedVersions(GetParam().version))), + session_(connection_) { + EXPECT_CALL(session_, OnCongestionWindowChange(_)).Times(AnyNumber()); + session_.Initialize(); + QuicStreamId id = perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId( + session_.transport_version(), 3) + : GetNthServerInitiatedUnidirectionalStreamId( + session_.transport_version(), 3); + char type[] = {0x03}; + QuicStreamFrame data1(id, false, 0, absl::string_view(type, 1)); + session_.OnStreamFrame(data1); + qpack_receive_stream_ = + QuicSpdySessionPeer::GetQpackDecoderReceiveStream(&session_); + } + + Perspective perspective() const { return GetParam().perspective; } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + StrictMock session_; + QpackReceiveStream* qpack_receive_stream_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QpackReceiveStreamTest, + ::testing::ValuesIn(GetTestParams())); + +TEST_P(QpackReceiveStreamTest, ResetQpackReceiveStream) { + EXPECT_TRUE(qpack_receive_stream_->is_static()); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, + qpack_receive_stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, _, _)); + qpack_receive_stream_->OnStreamReset(rst_frame); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_required_insert_count.cc b/quiche/quic/core/qpack/qpack_required_insert_count.cc new file mode 100644 index 000000000000..544dffefe1fd --- /dev/null +++ b/quiche/quic/core/qpack/qpack_required_insert_count.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_required_insert_count.h" + +#include + +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +uint64_t QpackEncodeRequiredInsertCount(uint64_t required_insert_count, + uint64_t max_entries) { + if (required_insert_count == 0) { + return 0; + } + + return required_insert_count % (2 * max_entries) + 1; +} + +bool QpackDecodeRequiredInsertCount(uint64_t encoded_required_insert_count, + uint64_t max_entries, + uint64_t total_number_of_inserts, + uint64_t* required_insert_count) { + if (encoded_required_insert_count == 0) { + *required_insert_count = 0; + return true; + } + + // |max_entries| is calculated by dividing an unsigned 64-bit integer by 32, + // precluding all calculations in this method from overflowing. + QUICHE_DCHECK_LE(max_entries, std::numeric_limits::max() / 32); + + if (encoded_required_insert_count > 2 * max_entries) { + return false; + } + + *required_insert_count = encoded_required_insert_count - 1; + QUICHE_DCHECK_LT(*required_insert_count, + std::numeric_limits::max() / 16); + + uint64_t current_wrapped = total_number_of_inserts % (2 * max_entries); + QUICHE_DCHECK_LT(current_wrapped, std::numeric_limits::max() / 16); + + if (current_wrapped >= *required_insert_count + max_entries) { + // Required Insert Count wrapped around 1 extra time. + *required_insert_count += 2 * max_entries; + } else if (current_wrapped + max_entries < *required_insert_count) { + // Decoder wrapped around 1 extra time. + current_wrapped += 2 * max_entries; + } + + if (*required_insert_count > + std::numeric_limits::max() - total_number_of_inserts) { + return false; + } + + *required_insert_count += total_number_of_inserts; + + // Prevent underflow, also disallow invalid value 0 for Required Insert Count. + if (current_wrapped >= *required_insert_count) { + return false; + } + + *required_insert_count -= current_wrapped; + + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_required_insert_count.h b/quiche/quic/core/qpack/qpack_required_insert_count.h new file mode 100644 index 000000000000..762dfb1e5fcb --- /dev/null +++ b/quiche/quic/core/qpack/qpack_required_insert_count.h @@ -0,0 +1,30 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_REQUIRED_INSERT_COUNT_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_REQUIRED_INSERT_COUNT_H_ + +#include + +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Calculate Encoded Required Insert Count from Required Insert Count and +// MaxEntries according to +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#ric. +QUIC_EXPORT_PRIVATE uint64_t QpackEncodeRequiredInsertCount( + uint64_t required_insert_count, uint64_t max_entries); + +// Calculate Required Insert Count from Encoded Required Insert Count, +// MaxEntries, and total number of dynamic table insertions according to +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#ric. Returns true +// on success, false on invalid input or overflow/underflow. +QUIC_EXPORT_PRIVATE bool QpackDecodeRequiredInsertCount( + uint64_t encoded_required_insert_count, uint64_t max_entries, + uint64_t total_number_of_inserts, uint64_t* required_insert_count); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_REQUIRED_INSERT_COUNT_H_ diff --git a/quiche/quic/core/qpack/qpack_required_insert_count_test.cc b/quiche/quic/core/qpack/qpack_required_insert_count_test.cc new file mode 100644 index 000000000000..95456b9d2e72 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_required_insert_count_test.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_required_insert_count.h" + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +TEST(QpackRequiredInsertCountTest, QpackEncodeRequiredInsertCount) { + EXPECT_EQ(0u, QpackEncodeRequiredInsertCount(0, 0)); + EXPECT_EQ(0u, QpackEncodeRequiredInsertCount(0, 8)); + EXPECT_EQ(0u, QpackEncodeRequiredInsertCount(0, 1024)); + + EXPECT_EQ(2u, QpackEncodeRequiredInsertCount(1, 8)); + EXPECT_EQ(5u, QpackEncodeRequiredInsertCount(20, 8)); + EXPECT_EQ(7u, QpackEncodeRequiredInsertCount(106, 10)); +} + +// For testing valid decodings, the Encoded Required Insert Count is calculated +// from Required Insert Count, so that there is an expected value to compare +// the decoded value against, and so that intricate inequalities can be +// documented. +struct { + uint64_t required_insert_count; + uint64_t max_entries; + uint64_t total_number_of_inserts; +} kTestData[] = { + // Maximum dynamic table capacity is zero. + {0, 0, 0}, + // No dynamic entries in header. + {0, 100, 0}, + {0, 100, 500}, + // Required Insert Count has not wrapped around yet, no entries evicted. + {15, 100, 25}, + {20, 100, 10}, + // Required Insert Count has not wrapped around yet, some entries evicted. + {90, 100, 110}, + // Required Insert Count has wrapped around. + {234, 100, 180}, + // Required Insert Count has wrapped around many times. + {5678, 100, 5701}, + // Lowest and highest possible Required Insert Count values + // for given MaxEntries and total number of insertions. + {401, 100, 500}, + {600, 100, 500}}; + +TEST(QpackRequiredInsertCountTest, QpackDecodeRequiredInsertCount) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestData); ++i) { + const uint64_t required_insert_count = kTestData[i].required_insert_count; + const uint64_t max_entries = kTestData[i].max_entries; + const uint64_t total_number_of_inserts = + kTestData[i].total_number_of_inserts; + + if (required_insert_count != 0) { + // Dynamic entries cannot be referenced if dynamic table capacity is zero. + ASSERT_LT(0u, max_entries) << i; + // Entry |total_number_of_inserts - 1 - max_entries| and earlier entries + // are evicted. Entry |required_insert_count - 1| is referenced. No + // evicted entry can be referenced. + ASSERT_LT(total_number_of_inserts, required_insert_count + max_entries) + << i; + // Entry |required_insert_count - 1 - max_entries| and earlier entries are + // evicted, entry |total_number_of_inserts - 1| is the last acknowledged + // entry. Every evicted entry must be acknowledged. + ASSERT_LE(required_insert_count, total_number_of_inserts + max_entries) + << i; + } + + uint64_t encoded_required_insert_count = + QpackEncodeRequiredInsertCount(required_insert_count, max_entries); + + // Initialize to a value different from the expected output to confirm that + // QpackDecodeRequiredInsertCount() modifies the value of + // |decoded_required_insert_count|. + uint64_t decoded_required_insert_count = required_insert_count + 1; + EXPECT_TRUE(QpackDecodeRequiredInsertCount( + encoded_required_insert_count, max_entries, total_number_of_inserts, + &decoded_required_insert_count)) + << i; + + EXPECT_EQ(decoded_required_insert_count, required_insert_count) << i; + } +} + +// Failures are tested with hardcoded values for encoded required insert count, +// to provide test coverage for values that would never be produced by a well +// behaved encoding function. +struct { + uint64_t encoded_required_insert_count; + uint64_t max_entries; + uint64_t total_number_of_inserts; +} kInvalidTestData[] = { + // Maximum dynamic table capacity is zero, yet header block + // claims to have a reference to a dynamic table entry. + {1, 0, 0}, + {9, 0, 0}, + // Examples from + // https://github.com/quicwg/base-drafts/issues/2112#issue-389626872. + {1, 10, 2}, + {18, 10, 2}, + // Encoded Required Insert Count value too small or too large + // for given MaxEntries and total number of insertions. + {400, 100, 500}, + {601, 100, 500}}; + +TEST(QpackRequiredInsertCountTest, DecodeRequiredInsertCountError) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(kInvalidTestData); ++i) { + uint64_t decoded_required_insert_count = 0; + EXPECT_FALSE(QpackDecodeRequiredInsertCount( + kInvalidTestData[i].encoded_required_insert_count, + kInvalidTestData[i].max_entries, + kInvalidTestData[i].total_number_of_inserts, + &decoded_required_insert_count)) + << i; + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_round_trip_test.cc b/quiche/quic/core/qpack/qpack_round_trip_test.cc new file mode 100644 index 000000000000..147c627e2842 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_round_trip_test.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/qpack/qpack_encoder.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +using ::testing::Values; + +namespace quic { +namespace test { +namespace { + +class QpackRoundTripTest : public QuicTestWithParam { + public: + QpackRoundTripTest() = default; + ~QpackRoundTripTest() override = default; + + spdy::Http2HeaderBlock EncodeThenDecode( + const spdy::Http2HeaderBlock& header_list) { + NoopDecoderStreamErrorDelegate decoder_stream_error_delegate; + NoopQpackStreamSenderDelegate encoder_stream_sender_delegate; + QpackEncoder encoder(&decoder_stream_error_delegate); + encoder.set_qpack_stream_sender_delegate(&encoder_stream_sender_delegate); + std::string encoded_header_block = + encoder.EncodeHeaderList(/* stream_id = */ 1, header_list, nullptr); + + TestHeadersHandler handler; + NoopEncoderStreamErrorDelegate encoder_stream_error_delegate; + NoopQpackStreamSenderDelegate decoder_stream_sender_delegate; + // TODO(b/112770235): Test dynamic table and blocked streams. + QpackDecode( + /* maximum_dynamic_table_capacity = */ 0, + /* maximum_blocked_streams = */ 0, &encoder_stream_error_delegate, + &decoder_stream_sender_delegate, &handler, + FragmentModeToFragmentSizeGenerator(GetParam()), encoded_header_block); + + EXPECT_TRUE(handler.decoding_completed()); + EXPECT_FALSE(handler.decoding_error_detected()); + + return handler.ReleaseHeaderList(); + } +}; + +INSTANTIATE_TEST_SUITE_P(All, QpackRoundTripTest, + Values(FragmentMode::kSingleChunk, + FragmentMode::kOctetByOctet)); + +TEST_P(QpackRoundTripTest, Empty) { + spdy::Http2HeaderBlock header_list; + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); +} + +TEST_P(QpackRoundTripTest, EmptyName) { + spdy::Http2HeaderBlock header_list; + header_list["foo"] = "bar"; + header_list[""] = "bar"; + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); +} + +TEST_P(QpackRoundTripTest, EmptyValue) { + spdy::Http2HeaderBlock header_list; + header_list["foo"] = ""; + header_list[""] = ""; + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); +} + +TEST_P(QpackRoundTripTest, MultipleWithLongEntries) { + spdy::Http2HeaderBlock header_list; + header_list["foo"] = "bar"; + header_list[":path"] = "/"; + header_list["foobaar"] = std::string(127, 'Z'); + header_list[std::string(1000, 'b')] = std::string(1000, 'c'); + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); +} + +TEST_P(QpackRoundTripTest, StaticTable) { + { + spdy::Http2HeaderBlock header_list; + header_list[":method"] = "GET"; + header_list["accept-encoding"] = "gzip, deflate"; + header_list["cache-control"] = ""; + header_list["foo"] = "bar"; + header_list[":path"] = "/"; + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); + } + { + spdy::Http2HeaderBlock header_list; + header_list[":method"] = "POST"; + header_list["accept-encoding"] = "brotli"; + header_list["cache-control"] = "foo"; + header_list["foo"] = "bar"; + header_list[":path"] = "/"; + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); + } + { + spdy::Http2HeaderBlock header_list; + header_list[":method"] = "CONNECT"; + header_list["accept-encoding"] = ""; + header_list["foo"] = "bar"; + header_list[":path"] = "/"; + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); + } +} + +TEST_P(QpackRoundTripTest, ValueHasNullCharacter) { + spdy::Http2HeaderBlock header_list; + header_list["foo"] = absl::string_view("bar\0bar\0baz", 11); + + spdy::Http2HeaderBlock output = EncodeThenDecode(header_list); + EXPECT_EQ(header_list, output); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_send_stream.cc b/quiche/quic/core/qpack/qpack_send_stream.cc new file mode 100644 index 000000000000..9616c1a4a200 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_send_stream.cc @@ -0,0 +1,51 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_send_stream.h" + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_session.h" + +namespace quic { +QpackSendStream::QpackSendStream(QuicStreamId id, QuicSession* session, + uint64_t http3_stream_type) + : QuicStream(id, session, /*is_static = */ true, WRITE_UNIDIRECTIONAL), + http3_stream_type_(http3_stream_type), + stream_type_sent_(false) {} + +void QpackSendStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { + QUIC_BUG(quic_bug_10805_1) + << "OnStreamReset() called for write unidirectional stream."; +} + +bool QpackSendStream::OnStopSending(QuicResetStreamError /* code */) { + stream_delegate()->OnStreamError( + QUIC_HTTP_CLOSED_CRITICAL_STREAM, + "STOP_SENDING received for QPACK send stream"); + return false; +} + +void QpackSendStream::WriteStreamData(absl::string_view data) { + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); + MaybeSendStreamType(); + WriteOrBufferData(data, false, nullptr); +} + +uint64_t QpackSendStream::NumBytesBuffered() const { + return QuicStream::BufferedDataBytes(); +} + +void QpackSendStream::MaybeSendStreamType() { + if (!stream_type_sent_) { + char type[sizeof(http3_stream_type_)]; + QuicDataWriter writer(ABSL_ARRAYSIZE(type), type); + writer.WriteVarInt62(http3_stream_type_); + WriteOrBufferData(absl::string_view(writer.data(), writer.length()), false, + nullptr); + stream_type_sent_ = true; + } +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_send_stream.h b/quiche/quic/core/qpack/qpack_send_stream.h new file mode 100644 index 000000000000..8c9a764345f9 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_send_stream.h @@ -0,0 +1,60 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_SEND_STREAM_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_SEND_STREAM_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_stream_sender_delegate.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +class QuicSession; + +// QPACK 4.2.1 Encoder and Decoder Streams. +// The QPACK send stream is self initiated and is write only. +class QUIC_EXPORT_PRIVATE QpackSendStream : public QuicStream, + public QpackStreamSenderDelegate { + public: + // |session| can't be nullptr, and the ownership is not passed. |session| owns + // this stream. + QpackSendStream(QuicStreamId id, QuicSession* session, + uint64_t http3_stream_type); + QpackSendStream(const QpackSendStream&) = delete; + QpackSendStream& operator=(const QpackSendStream&) = delete; + ~QpackSendStream() override = default; + + // Overriding QuicStream::OnStopSending() to make sure QPACK stream is never + // closed before connection. + void OnStreamReset(const QuicRstStreamFrame& frame) override; + bool OnStopSending(QuicResetStreamError code) override; + + // The send QPACK stream is write unidirectional, so this method + // should never be called. + void OnDataAvailable() override { QUICHE_NOTREACHED(); } + + // Writes the instructions to peer. The stream type will be sent + // before the first instruction so that the peer can open an qpack stream. + void WriteStreamData(absl::string_view data) override; + + // Return the number of bytes buffered due to underlying stream being blocked. + uint64_t NumBytesBuffered() const override; + + // TODO(b/112770235): Remove this method once QuicStreamIdManager supports + // creating HTTP/3 unidirectional streams dynamically. + void MaybeSendStreamType(); + + private: + const uint64_t http3_stream_type_; + bool stream_type_sent_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_SEND_STREAM_H_ diff --git a/quiche/quic/core/qpack/qpack_send_stream_test.cc b/quiche/quic/core/qpack/qpack_send_stream_test.cc new file mode 100644 index 000000000000..4f0cc2a6cbd0 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_send_stream_test.cc @@ -0,0 +1,133 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_send_stream.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/http_constants.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +namespace { +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Invoke; +using ::testing::StrictMock; + +struct TestParams { + TestParams(const ParsedQuicVersion& version, Perspective perspective) + : version(version), perspective(perspective) { + QUIC_LOG(INFO) << "TestParams: version: " + << ParsedQuicVersionToString(version) + << ", perspective: " << perspective; + } + + TestParams(const TestParams& other) + : version(other.version), perspective(other.perspective) {} + + ParsedQuicVersion version; + Perspective perspective; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& tp) { + return absl::StrCat( + ParsedQuicVersionToString(tp.version), "_", + (tp.perspective == Perspective::IS_CLIENT ? "client" : "server")); +} + +std::vector GetTestParams() { + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (const auto& version : AllSupportedVersions()) { + if (!VersionUsesHttp3(version.transport_version)) { + continue; + } + for (Perspective p : {Perspective::IS_SERVER, Perspective::IS_CLIENT}) { + params.emplace_back(version, p); + } + } + return params; +} + +class QpackSendStreamTest : public QuicTestWithParam { + public: + QpackSendStreamTest() + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective(), + SupportedVersions(GetParam().version))), + session_(connection_) { + EXPECT_CALL(session_, OnCongestionWindowChange(_)).Times(AnyNumber()); + session_.Initialize(); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + if (connection_->version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(connection_); + } + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(session_.config(), 3); + session_.OnConfigNegotiated(); + + qpack_send_stream_ = + QuicSpdySessionPeer::GetQpackDecoderSendStream(&session_); + + ON_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillByDefault(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + } + + Perspective perspective() const { return GetParam().perspective; } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + StrictMock session_; + QpackSendStream* qpack_send_stream_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QpackSendStreamTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QpackSendStreamTest, WriteStreamTypeOnlyFirstTime) { + std::string data = "data"; + EXPECT_CALL(session_, WritevData(_, 1, _, _, _, _)); + EXPECT_CALL(session_, WritevData(_, data.length(), _, _, _, _)); + qpack_send_stream_->WriteStreamData(absl::string_view(data)); + + EXPECT_CALL(session_, WritevData(_, data.length(), _, _, _, _)); + qpack_send_stream_->WriteStreamData(absl::string_view(data)); + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)).Times(0); + qpack_send_stream_->MaybeSendStreamType(); +} + +TEST_P(QpackSendStreamTest, StopSendingQpackStream) { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, _, _)); + qpack_send_stream_->OnStopSending( + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED)); +} + +TEST_P(QpackSendStreamTest, ReceiveDataOnSendStream) { + QuicStreamFrame frame(qpack_send_stream_->id(), false, 0, "test"); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM, _, _)); + qpack_send_stream_->OnStreamFrame(frame); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_static_table.cc b/quiche/quic/core/qpack/qpack_static_table.cc new file mode 100644 index 000000000000..fc8566716955 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_static_table.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_static_table.h" + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// The "constructor" for a QpackStaticEntry that computes the lengths at +// compile time. +#define STATIC_ENTRY(name, value) \ + { name, ABSL_ARRAYSIZE(name) - 1, value, ABSL_ARRAYSIZE(value) - 1 } + +const std::vector& QpackStaticTableVector() { + static const auto* kQpackStaticTable = new std::vector{ + STATIC_ENTRY(":authority", ""), // 0 + STATIC_ENTRY(":path", "/"), // 1 + STATIC_ENTRY("age", "0"), // 2 + STATIC_ENTRY("content-disposition", ""), // 3 + STATIC_ENTRY("content-length", "0"), // 4 + STATIC_ENTRY("cookie", ""), // 5 + STATIC_ENTRY("date", ""), // 6 + STATIC_ENTRY("etag", ""), // 7 + STATIC_ENTRY("if-modified-since", ""), // 8 + STATIC_ENTRY("if-none-match", ""), // 9 + STATIC_ENTRY("last-modified", ""), // 10 + STATIC_ENTRY("link", ""), // 11 + STATIC_ENTRY("location", ""), // 12 + STATIC_ENTRY("referer", ""), // 13 + STATIC_ENTRY("set-cookie", ""), // 14 + STATIC_ENTRY(":method", "CONNECT"), // 15 + STATIC_ENTRY(":method", "DELETE"), // 16 + STATIC_ENTRY(":method", "GET"), // 17 + STATIC_ENTRY(":method", "HEAD"), // 18 + STATIC_ENTRY(":method", "OPTIONS"), // 19 + STATIC_ENTRY(":method", "POST"), // 20 + STATIC_ENTRY(":method", "PUT"), // 21 + STATIC_ENTRY(":scheme", "http"), // 22 + STATIC_ENTRY(":scheme", "https"), // 23 + STATIC_ENTRY(":status", "103"), // 24 + STATIC_ENTRY(":status", "200"), // 25 + STATIC_ENTRY(":status", "304"), // 26 + STATIC_ENTRY(":status", "404"), // 27 + STATIC_ENTRY(":status", "503"), // 28 + STATIC_ENTRY("accept", "*/*"), // 29 + STATIC_ENTRY("accept", "application/dns-message"), // 30 + STATIC_ENTRY("accept-encoding", "gzip, deflate, br"), // 31 + STATIC_ENTRY("accept-ranges", "bytes"), // 32 + STATIC_ENTRY("access-control-allow-headers", "cache-control"), // 33 + STATIC_ENTRY("access-control-allow-headers", "content-type"), // 35 + STATIC_ENTRY("access-control-allow-origin", "*"), // 35 + STATIC_ENTRY("cache-control", "max-age=0"), // 36 + STATIC_ENTRY("cache-control", "max-age=2592000"), // 37 + STATIC_ENTRY("cache-control", "max-age=604800"), // 38 + STATIC_ENTRY("cache-control", "no-cache"), // 39 + STATIC_ENTRY("cache-control", "no-store"), // 40 + STATIC_ENTRY("cache-control", "public, max-age=31536000"), // 41 + STATIC_ENTRY("content-encoding", "br"), // 42 + STATIC_ENTRY("content-encoding", "gzip"), // 43 + STATIC_ENTRY("content-type", "application/dns-message"), // 44 + STATIC_ENTRY("content-type", "application/javascript"), // 45 + STATIC_ENTRY("content-type", "application/json"), // 46 + STATIC_ENTRY("content-type", "application/x-www-form-urlencoded"), // 47 + STATIC_ENTRY("content-type", "image/gif"), // 48 + STATIC_ENTRY("content-type", "image/jpeg"), // 49 + STATIC_ENTRY("content-type", "image/png"), // 50 + STATIC_ENTRY("content-type", "text/css"), // 51 + STATIC_ENTRY("content-type", "text/html; charset=utf-8"), // 52 + STATIC_ENTRY("content-type", "text/plain"), // 53 + STATIC_ENTRY("content-type", "text/plain;charset=utf-8"), // 54 + STATIC_ENTRY("range", "bytes=0-"), // 55 + STATIC_ENTRY("strict-transport-security", "max-age=31536000"), // 56 + STATIC_ENTRY("strict-transport-security", + "max-age=31536000; includesubdomains"), // 57 + STATIC_ENTRY("strict-transport-security", + "max-age=31536000; includesubdomains; preload"), // 58 + STATIC_ENTRY("vary", "accept-encoding"), // 59 + STATIC_ENTRY("vary", "origin"), // 60 + STATIC_ENTRY("x-content-type-options", "nosniff"), // 61 + STATIC_ENTRY("x-xss-protection", "1; mode=block"), // 62 + STATIC_ENTRY(":status", "100"), // 63 + STATIC_ENTRY(":status", "204"), // 64 + STATIC_ENTRY(":status", "206"), // 65 + STATIC_ENTRY(":status", "302"), // 66 + STATIC_ENTRY(":status", "400"), // 67 + STATIC_ENTRY(":status", "403"), // 68 + STATIC_ENTRY(":status", "421"), // 69 + STATIC_ENTRY(":status", "425"), // 70 + STATIC_ENTRY(":status", "500"), // 71 + STATIC_ENTRY("accept-language", ""), // 72 + STATIC_ENTRY("access-control-allow-credentials", "FALSE"), // 73 + STATIC_ENTRY("access-control-allow-credentials", "TRUE"), // 74 + STATIC_ENTRY("access-control-allow-headers", "*"), // 75 + STATIC_ENTRY("access-control-allow-methods", "get"), // 76 + STATIC_ENTRY("access-control-allow-methods", "get, post, options"), // 77 + STATIC_ENTRY("access-control-allow-methods", "options"), // 78 + STATIC_ENTRY("access-control-expose-headers", "content-length"), // 79 + STATIC_ENTRY("access-control-request-headers", "content-type"), // 80 + STATIC_ENTRY("access-control-request-method", "get"), // 81 + STATIC_ENTRY("access-control-request-method", "post"), // 82 + STATIC_ENTRY("alt-svc", "clear"), // 83 + STATIC_ENTRY("authorization", ""), // 84 + STATIC_ENTRY( + "content-security-policy", + "script-src 'none'; object-src 'none'; base-uri 'none'"), // 85 + STATIC_ENTRY("early-data", "1"), // 86 + STATIC_ENTRY("expect-ct", ""), // 87 + STATIC_ENTRY("forwarded", ""), // 88 + STATIC_ENTRY("if-range", ""), // 89 + STATIC_ENTRY("origin", ""), // 90 + STATIC_ENTRY("purpose", "prefetch"), // 91 + STATIC_ENTRY("server", ""), // 92 + STATIC_ENTRY("timing-allow-origin", "*"), // 93 + STATIC_ENTRY("upgrade-insecure-requests", "1"), // 94 + STATIC_ENTRY("user-agent", ""), // 95 + STATIC_ENTRY("x-forwarded-for", ""), // 96 + STATIC_ENTRY("x-frame-options", "deny"), // 97 + STATIC_ENTRY("x-frame-options", "sameorigin"), // 98 + }; + return *kQpackStaticTable; +} + +#undef STATIC_ENTRY + +const QpackStaticTable& ObtainQpackStaticTable() { + static const QpackStaticTable* const shared_static_table = []() { + auto* table = new QpackStaticTable(); + table->Initialize(QpackStaticTableVector().data(), + QpackStaticTableVector().size()); + QUICHE_CHECK(table->IsInitialized()); + return table; + }(); + return *shared_static_table; +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_static_table.h b/quiche/quic/core/qpack/qpack_static_table.h new file mode 100644 index 000000000000..43fd297a0a22 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_static_table.h @@ -0,0 +1,31 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_STATIC_TABLE_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_STATIC_TABLE_H_ + +#include + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_static_table.h" + +namespace quic { + +using QpackStaticEntry = spdy::HpackStaticEntry; +using QpackStaticTable = spdy::HpackStaticTable; + +// QPACK static table defined at +// https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#static-table. +QUIC_EXPORT_PRIVATE const std::vector& +QpackStaticTableVector(); + +// Returns a QpackStaticTable instance initialized with kQpackStaticTable. +// The instance is read-only, has static lifetime, and is safe to share amoung +// threads. This function is thread-safe. +QUIC_EXPORT_PRIVATE const QpackStaticTable& ObtainQpackStaticTable(); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_STATIC_TABLE_H_ diff --git a/quiche/quic/core/qpack/qpack_static_table_test.cc b/quiche/quic/core/qpack/qpack_static_table_test.cc new file mode 100644 index 000000000000..67cd9c0fdb8b --- /dev/null +++ b/quiche/quic/core/qpack/qpack_static_table_test.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/qpack_static_table.h" + +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { + +namespace test { + +namespace { + +// Check that an initialized instance has the right number of entries. +TEST(QpackStaticTableTest, Initialize) { + QpackStaticTable table; + EXPECT_FALSE(table.IsInitialized()); + + table.Initialize(QpackStaticTableVector().data(), + QpackStaticTableVector().size()); + EXPECT_TRUE(table.IsInitialized()); + + const auto& static_entries = table.GetStaticEntries(); + EXPECT_EQ(QpackStaticTableVector().size(), static_entries.size()); + + const auto& static_index = table.GetStaticIndex(); + EXPECT_EQ(QpackStaticTableVector().size(), static_index.size()); + + const auto& static_name_index = table.GetStaticNameIndex(); + // Count distinct names in static table. + std::set names; + for (const auto& entry : static_entries) { + names.insert(entry.name()); + } + EXPECT_EQ(names.size(), static_name_index.size()); +} + +// Test that ObtainQpackStaticTable returns the same instance every time. +TEST(QpackStaticTableTest, IsSingleton) { + const QpackStaticTable* static_table_one = &ObtainQpackStaticTable(); + const QpackStaticTable* static_table_two = &ObtainQpackStaticTable(); + EXPECT_EQ(static_table_one, static_table_two); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/core/qpack/qpack_stream_receiver.h b/quiche/quic/core/qpack/qpack_stream_receiver.h new file mode 100644 index 000000000000..ad95cce46de9 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_stream_receiver.h @@ -0,0 +1,24 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_STREAM_RECEIVER_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_STREAM_RECEIVER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This interface decodes QPACK data that are received on a QpackReceiveStream. +class QUIC_EXPORT_PRIVATE QpackStreamReceiver { + public: + virtual ~QpackStreamReceiver() = default; + + // Decode data. + virtual void Decode(absl::string_view data) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_STREAM_RECEIVER_H_ diff --git a/quiche/quic/core/qpack/qpack_stream_sender_delegate.h b/quiche/quic/core/qpack/qpack_stream_sender_delegate.h new file mode 100644 index 000000000000..8e267db829e3 --- /dev/null +++ b/quiche/quic/core/qpack/qpack_stream_sender_delegate.h @@ -0,0 +1,27 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_QPACK_STREAM_SENDER_DELEGATE_H_ +#define QUICHE_QUIC_CORE_QPACK_QPACK_STREAM_SENDER_DELEGATE_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This interface writes encoder/decoder data to peer. +class QUIC_EXPORT_PRIVATE QpackStreamSenderDelegate { + public: + virtual ~QpackStreamSenderDelegate() = default; + + // Write data on the unidirectional stream. + virtual void WriteStreamData(absl::string_view data) = 0; + + // Return the number of bytes buffered due to underlying stream being blocked. + virtual uint64_t NumBytesBuffered() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_QPACK_STREAM_SENDER_DELEGATE_H_ diff --git a/quiche/quic/core/qpack/value_splitting_header_list.cc b/quiche/quic/core/qpack/value_splitting_header_list.cc new file mode 100644 index 000000000000..faeccdf4ff3b --- /dev/null +++ b/quiche/quic/core/qpack/value_splitting_header_list.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/value_splitting_header_list.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace { + +const char kCookieKey[] = "cookie"; +const char kCookieSeparator = ';'; +const char kOptionalSpaceAfterCookieSeparator = ' '; +const char kNonCookieSeparator = '\0'; + +} // namespace + +ValueSplittingHeaderList::const_iterator::const_iterator( + const spdy::Http2HeaderBlock* header_list, + spdy::Http2HeaderBlock::const_iterator header_list_iterator) + : header_list_(header_list), + header_list_iterator_(header_list_iterator), + value_start_(0) { + UpdateHeaderField(); +} + +bool ValueSplittingHeaderList::const_iterator::operator==( + const const_iterator& other) const { + return header_list_iterator_ == other.header_list_iterator_ && + value_start_ == other.value_start_; +} + +bool ValueSplittingHeaderList::const_iterator::operator!=( + const const_iterator& other) const { + return !(*this == other); +} + +const ValueSplittingHeaderList::const_iterator& +ValueSplittingHeaderList::const_iterator::operator++() { + if (value_end_ == absl::string_view::npos) { + // This was the last frament within |*header_list_iterator_|, + // move on to the next header element of |header_list_|. + ++header_list_iterator_; + value_start_ = 0; + } else { + // Find the next fragment within |*header_list_iterator_|. + value_start_ = value_end_ + 1; + } + UpdateHeaderField(); + + return *this; +} + +const ValueSplittingHeaderList::value_type& +ValueSplittingHeaderList::const_iterator::operator*() const { + return header_field_; +} +const ValueSplittingHeaderList::value_type* +ValueSplittingHeaderList::const_iterator::operator->() const { + return &header_field_; +} + +void ValueSplittingHeaderList::const_iterator::UpdateHeaderField() { + QUICHE_DCHECK(value_start_ != absl::string_view::npos); + + if (header_list_iterator_ == header_list_->end()) { + return; + } + + const absl::string_view name = header_list_iterator_->first; + const absl::string_view original_value = header_list_iterator_->second; + + if (name == kCookieKey) { + value_end_ = original_value.find(kCookieSeparator, value_start_); + } else { + value_end_ = original_value.find(kNonCookieSeparator, value_start_); + } + + const absl::string_view value = + original_value.substr(value_start_, value_end_ - value_start_); + header_field_ = std::make_pair(name, value); + + // Skip character after ';' separator if it is a space. + if (name == kCookieKey && value_end_ != absl::string_view::npos && + value_end_ + 1 < original_value.size() && + original_value[value_end_ + 1] == kOptionalSpaceAfterCookieSeparator) { + ++value_end_; + } +} + +ValueSplittingHeaderList::ValueSplittingHeaderList( + const spdy::Http2HeaderBlock* header_list) + : header_list_(header_list) { + QUICHE_DCHECK(header_list_); +} + +ValueSplittingHeaderList::const_iterator ValueSplittingHeaderList::begin() + const { + return const_iterator(header_list_, header_list_->begin()); +} + +ValueSplittingHeaderList::const_iterator ValueSplittingHeaderList::end() const { + return const_iterator(header_list_, header_list_->end()); +} + +} // namespace quic diff --git a/quiche/quic/core/qpack/value_splitting_header_list.h b/quiche/quic/core/qpack/value_splitting_header_list.h new file mode 100644 index 000000000000..e2750332a165 --- /dev/null +++ b/quiche/quic/core/qpack/value_splitting_header_list.h @@ -0,0 +1,62 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QPACK_VALUE_SPLITTING_HEADER_LIST_H_ +#define QUICHE_QUIC_CORE_QPACK_VALUE_SPLITTING_HEADER_LIST_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// A wrapper class around Http2HeaderBlock that splits header values along ';' +// separators (while also removing optional space following separator) for +// cookies and along '\0' separators for other header fields. +class QUIC_EXPORT_PRIVATE ValueSplittingHeaderList { + public: + using value_type = spdy::Http2HeaderBlock::value_type; + + class QUIC_EXPORT_PRIVATE const_iterator { + public: + // |header_list| must outlive this object. + const_iterator(const spdy::Http2HeaderBlock* header_list, + spdy::Http2HeaderBlock::const_iterator header_list_iterator); + const_iterator(const const_iterator&) = default; + const_iterator& operator=(const const_iterator&) = delete; + + bool operator==(const const_iterator& other) const; + bool operator!=(const const_iterator& other) const; + + const const_iterator& operator++(); + + const value_type& operator*() const; + const value_type* operator->() const; + + private: + // Find next separator; update |value_end_| and |header_field_|. + void UpdateHeaderField(); + + const spdy::Http2HeaderBlock* const header_list_; + spdy::Http2HeaderBlock::const_iterator header_list_iterator_; + absl::string_view::size_type value_start_; + absl::string_view::size_type value_end_; + value_type header_field_; + }; + + // |header_list| must outlive this object. + explicit ValueSplittingHeaderList(const spdy::Http2HeaderBlock* header_list); + ValueSplittingHeaderList(const ValueSplittingHeaderList&) = delete; + ValueSplittingHeaderList& operator=(const ValueSplittingHeaderList&) = delete; + + const_iterator begin() const; + const_iterator end() const; + + private: + const spdy::Http2HeaderBlock* const header_list_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QPACK_VALUE_SPLITTING_HEADER_LIST_H_ diff --git a/quiche/quic/core/qpack/value_splitting_header_list_test.cc b/quiche/quic/core/qpack/value_splitting_header_list_test.cc new file mode 100644 index 000000000000..a3aae0af6fd3 --- /dev/null +++ b/quiche/quic/core/qpack/value_splitting_header_list_test.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/qpack/value_splitting_header_list.h" + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::ElementsAre; +using ::testing::Pair; + +TEST(ValueSplittingHeaderListTest, Comparison) { + spdy::Http2HeaderBlock block; + block["foo"] = absl::string_view("bar\0baz", 7); + block["baz"] = "qux"; + block["cookie"] = "foo; bar"; + + ValueSplittingHeaderList headers(&block); + ValueSplittingHeaderList::const_iterator it1 = headers.begin(); + const int kEnd = 6; + for (int i = 0; i < kEnd; ++i) { + // Compare to begin(). + if (i == 0) { + EXPECT_TRUE(it1 == headers.begin()); + EXPECT_TRUE(headers.begin() == it1); + EXPECT_FALSE(it1 != headers.begin()); + EXPECT_FALSE(headers.begin() != it1); + } else { + EXPECT_FALSE(it1 == headers.begin()); + EXPECT_FALSE(headers.begin() == it1); + EXPECT_TRUE(it1 != headers.begin()); + EXPECT_TRUE(headers.begin() != it1); + } + + // Compare to end(). + if (i == kEnd - 1) { + EXPECT_TRUE(it1 == headers.end()); + EXPECT_TRUE(headers.end() == it1); + EXPECT_FALSE(it1 != headers.end()); + EXPECT_FALSE(headers.end() != it1); + } else { + EXPECT_FALSE(it1 == headers.end()); + EXPECT_FALSE(headers.end() == it1); + EXPECT_TRUE(it1 != headers.end()); + EXPECT_TRUE(headers.end() != it1); + } + + // Compare to another iterator walking through the container. + ValueSplittingHeaderList::const_iterator it2 = headers.begin(); + for (int j = 0; j < kEnd; ++j) { + if (i == j) { + EXPECT_TRUE(it1 == it2); + EXPECT_FALSE(it1 != it2); + } else { + EXPECT_FALSE(it1 == it2); + EXPECT_TRUE(it1 != it2); + } + if (j < kEnd - 1) { + ASSERT_NE(it2, headers.end()); + ++it2; + } + } + + if (i < kEnd - 1) { + ASSERT_NE(it1, headers.end()); + ++it1; + } + } +} + +TEST(ValueSplittingHeaderListTest, Empty) { + spdy::Http2HeaderBlock block; + + ValueSplittingHeaderList headers(&block); + EXPECT_THAT(headers, ElementsAre()); + EXPECT_EQ(headers.begin(), headers.end()); +} + +TEST(ValueSplittingHeaderListTest, Split) { + struct { + const char* name; + absl::string_view value; + std::vector expected_values; + } kTestData[]{ + // Empty value. + {"foo", "", {""}}, + // Trivial case. + {"foo", "bar", {"bar"}}, + // Simple split. + {"foo", {"bar\0baz", 7}, {"bar", "baz"}}, + {"cookie", "foo;bar", {"foo", "bar"}}, + {"cookie", "foo; bar", {"foo", "bar"}}, + // Empty fragments with \0 separator. + {"foo", {"\0", 1}, {"", ""}}, + {"bar", {"foo\0", 4}, {"foo", ""}}, + {"baz", {"\0bar", 4}, {"", "bar"}}, + {"qux", {"\0foobar\0", 8}, {"", "foobar", ""}}, + // Empty fragments with ";" separator. + {"cookie", ";", {"", ""}}, + {"cookie", "foo;", {"foo", ""}}, + {"cookie", ";bar", {"", "bar"}}, + {"cookie", ";foobar;", {"", "foobar", ""}}, + // Empty fragments with "; " separator. + {"cookie", "; ", {"", ""}}, + {"cookie", "foo; ", {"foo", ""}}, + {"cookie", "; bar", {"", "bar"}}, + {"cookie", "; foobar; ", {"", "foobar", ""}}, + }; + + for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestData); ++i) { + spdy::Http2HeaderBlock block; + block[kTestData[i].name] = kTestData[i].value; + + ValueSplittingHeaderList headers(&block); + auto it = headers.begin(); + for (const char* expected_value : kTestData[i].expected_values) { + ASSERT_NE(it, headers.end()); + EXPECT_EQ(it->first, kTestData[i].name); + EXPECT_EQ(it->second, expected_value); + ++it; + } + EXPECT_EQ(it, headers.end()); + } +} + +TEST(ValueSplittingHeaderListTest, MultipleFields) { + spdy::Http2HeaderBlock block; + block["foo"] = absl::string_view("bar\0baz\0", 8); + block["cookie"] = "foo; bar"; + block["bar"] = absl::string_view("qux\0foo", 7); + + ValueSplittingHeaderList headers(&block); + EXPECT_THAT(headers, ElementsAre(Pair("foo", "bar"), Pair("foo", "baz"), + Pair("foo", ""), Pair("cookie", "foo"), + Pair("cookie", "bar"), Pair("bar", "qux"), + Pair("bar", "foo"))); +} + +TEST(ValueSplittingHeaderListTest, CookieStartsWithSpace) { + spdy::Http2HeaderBlock block; + block["foo"] = "bar"; + block["cookie"] = " foo"; + block["bar"] = "baz"; + + ValueSplittingHeaderList headers(&block); + EXPECT_THAT(headers, ElementsAre(Pair("foo", "bar"), Pair("cookie", " foo"), + Pair("bar", "baz"))); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_ack_listener_interface.cc b/quiche/quic/core/quic_ack_listener_interface.cc new file mode 100644 index 000000000000..f3c9d141246e --- /dev/null +++ b/quiche/quic/core/quic_ack_listener_interface.cc @@ -0,0 +1,11 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_ack_listener_interface.h" + +namespace quic { + +QuicAckListenerInterface::~QuicAckListenerInterface() {} + +} // namespace quic diff --git a/quiche/quic/core/quic_ack_listener_interface.h b/quiche/quic/core/quic_ack_listener_interface.h new file mode 100644 index 000000000000..edf4fd3111cd --- /dev/null +++ b/quiche/quic/core/quic_ack_listener_interface.h @@ -0,0 +1,37 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_ACK_LISTENER_INTERFACE_H_ +#define QUICHE_QUIC_CORE_QUIC_ACK_LISTENER_INTERFACE_H_ + +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" + +namespace quic { + +// Pure virtual class to listen for packet acknowledgements. +class QUIC_EXPORT_PRIVATE QuicAckListenerInterface + : public quiche::QuicheReferenceCounted { + public: + QuicAckListenerInterface() {} + + // Called when a packet is acked. Called once per packet. + // |acked_bytes| is the number of data bytes acked. + virtual void OnPacketAcked(int acked_bytes, + QuicTime::Delta ack_delay_time) = 0; + + // Called when a packet is retransmitted. Called once per packet. + // |retransmitted_bytes| is the number of data bytes retransmitted. + virtual void OnPacketRetransmitted(int retransmitted_bytes) = 0; + + protected: + // Delegates are ref counted. + ~QuicAckListenerInterface() override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_ACK_LISTENER_INTERFACE_H_ diff --git a/quiche/quic/core/quic_alarm.cc b/quiche/quic/core/quic_alarm.cc new file mode 100644 index 000000000000..029eefd1773c --- /dev/null +++ b/quiche/quic/core/quic_alarm.cc @@ -0,0 +1,105 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_alarm.h" + +#include + +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" + +namespace quic { + +QuicAlarm::QuicAlarm(QuicArenaScopedPtr delegate) + : delegate_(std::move(delegate)), deadline_(QuicTime::Zero()) {} + +QuicAlarm::~QuicAlarm() { + if (IsSet()) { + QUIC_CODE_COUNT(quic_alarm_not_cancelled_in_dtor); + } +} + +void QuicAlarm::Set(QuicTime new_deadline) { + QUICHE_DCHECK(!IsSet()); + QUICHE_DCHECK(new_deadline.IsInitialized()); + + if (IsPermanentlyCancelled()) { + QUIC_BUG(quic_alarm_illegal_set) + << "Set called after alarm is permanently cancelled. new_deadline:" + << new_deadline; + return; + } + + deadline_ = new_deadline; + SetImpl(); +} + +void QuicAlarm::CancelInternal(bool permanent) { + if (IsSet()) { + deadline_ = QuicTime::Zero(); + CancelImpl(); + } + + if (permanent) { + delegate_.reset(); + } +} + +bool QuicAlarm::IsPermanentlyCancelled() const { return delegate_ == nullptr; } + +void QuicAlarm::Update(QuicTime new_deadline, QuicTime::Delta granularity) { + if (IsPermanentlyCancelled()) { + QUIC_BUG(quic_alarm_illegal_update) + << "Update called after alarm is permanently cancelled. new_deadline:" + << new_deadline << ", granularity:" << granularity; + return; + } + + if (!new_deadline.IsInitialized()) { + Cancel(); + return; + } + if (std::abs((new_deadline - deadline_).ToMicroseconds()) < + granularity.ToMicroseconds()) { + return; + } + const bool was_set = IsSet(); + deadline_ = new_deadline; + if (was_set) { + UpdateImpl(); + } else { + SetImpl(); + } +} + +bool QuicAlarm::IsSet() const { return deadline_.IsInitialized(); } + +void QuicAlarm::Fire() { + if (!IsSet()) { + return; + } + + deadline_ = QuicTime::Zero(); + if (!IsPermanentlyCancelled()) { + QuicConnectionContextSwitcher context_switcher( + delegate_->GetConnectionContext()); + delegate_->OnAlarm(); + } +} + +void QuicAlarm::UpdateImpl() { + // CancelImpl and SetImpl take the new deadline by way of the deadline_ + // member, so save and restore deadline_ before canceling. + const QuicTime new_deadline = deadline_; + + deadline_ = QuicTime::Zero(); + CancelImpl(); + + deadline_ = new_deadline; + SetImpl(); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_alarm.h b/quiche/quic/core/quic_alarm.h new file mode 100644 index 000000000000..4352a441454b --- /dev/null +++ b/quiche/quic/core/quic_alarm.h @@ -0,0 +1,125 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_ALARM_H_ +#define QUICHE_QUIC_CORE_QUIC_ALARM_H_ + +#include "quiche/quic/core/quic_arena_scoped_ptr.h" +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Abstract class which represents an alarm which will go off at a +// scheduled time, and execute the |OnAlarm| method of the delegate. +// An alarm may be cancelled, in which case it may or may not be +// removed from the underlying scheduling system, but in either case +// the task will not be executed. +class QUIC_EXPORT_PRIVATE QuicAlarm { + public: + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + // If the alarm belongs to a single QuicConnection, return the corresponding + // QuicConnection.context_. Note the context_ is the first member of + // QuicConnection, so it should outlive the delegate. + // Otherwise return nullptr. + // The OnAlarm function will be called under the connection context, if any. + virtual QuicConnectionContext* GetConnectionContext() = 0; + + // Invoked when the alarm fires. + virtual void OnAlarm() = 0; + }; + + // DelegateWithContext is a Delegate with a QuicConnectionContext* stored as a + // member variable. + class QUIC_EXPORT_PRIVATE DelegateWithContext : public Delegate { + public: + explicit DelegateWithContext(QuicConnectionContext* context) + : context_(context) {} + ~DelegateWithContext() override {} + QuicConnectionContext* GetConnectionContext() override { return context_; } + + private: + QuicConnectionContext* context_; + }; + + // DelegateWithoutContext is a Delegate that does not have a corresponding + // context. Typically this means one object of the child class deals with many + // connections. + class QUIC_EXPORT_PRIVATE DelegateWithoutContext : public Delegate { + public: + ~DelegateWithoutContext() override {} + QuicConnectionContext* GetConnectionContext() override { return nullptr; } + }; + + explicit QuicAlarm(QuicArenaScopedPtr delegate); + QuicAlarm(const QuicAlarm&) = delete; + QuicAlarm& operator=(const QuicAlarm&) = delete; + virtual ~QuicAlarm(); + + // Sets the alarm to fire at |deadline|. Must not be called while + // the alarm is set. To reschedule an alarm, call Cancel() first, + // then Set(). + void Set(QuicTime new_deadline); + + // Both PermanentCancel() and Cancel() can cancel the alarm. If permanent, + // future calls to Set() and Update() will become no-op except emitting an + // error log. + // + // Both may be called repeatedly. Does not guarantee that the underlying + // scheduling system will remove the alarm's associated task, but guarantees + // that the delegates OnAlarm method will not be called. + void PermanentCancel() { CancelInternal(true); } + void Cancel() { CancelInternal(false); } + + // Return true if PermanentCancel() has been called. + bool IsPermanentlyCancelled() const; + + // Cancels and sets the alarm if the |deadline| is farther from the current + // deadline than |granularity|, and otherwise does nothing. If |deadline| is + // not initialized, the alarm is cancelled. + void Update(QuicTime new_deadline, QuicTime::Delta granularity); + + // Returns true if |deadline_| has been set to a non-zero time. + bool IsSet() const; + + QuicTime deadline() const { return deadline_; } + + protected: + // Subclasses implement this method to perform the platform-specific + // scheduling of the alarm. Is called from Set() or Fire(), after the + // deadline has been updated. + virtual void SetImpl() = 0; + + // Subclasses implement this method to perform the platform-specific + // cancelation of the alarm. + virtual void CancelImpl() = 0; + + // Subclasses implement this method to perform the platform-specific update of + // the alarm if there exists a more optimal implementation than calling + // CancelImpl() and SetImpl(). + virtual void UpdateImpl(); + + // Called by subclasses when the alarm fires. Invokes the + // delegates |OnAlarm| if a delegate is set, and if the deadline + // has been exceeded. Implementations which do not remove the + // alarm from the underlying scheduler on Cancel() may need to handle + // the situation where the task executes before the deadline has been + // reached, in which case they need to reschedule the task and must not + // call invoke this method. + void Fire(); + + private: + void CancelInternal(bool permanent); + + QuicArenaScopedPtr delegate_; + QuicTime deadline_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_ALARM_H_ diff --git a/quiche/quic/core/quic_alarm_factory.h b/quiche/quic/core/quic_alarm_factory.h new file mode 100644 index 000000000000..b1ce54c77edd --- /dev/null +++ b/quiche/quic/core/quic_alarm_factory.h @@ -0,0 +1,36 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_ALARM_FACTORY_H_ +#define QUICHE_QUIC_CORE_QUIC_ALARM_FACTORY_H_ + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Creates platform-specific alarms used throughout QUIC. +class QUIC_EXPORT_PRIVATE QuicAlarmFactory { + public: + virtual ~QuicAlarmFactory() {} + + // Creates a new platform-specific alarm which will be configured to notify + // |delegate| when the alarm fires. Returns an alarm allocated on the heap. + // Caller takes ownership of the new alarm, which will not yet be "set" to + // fire. + virtual QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) = 0; + + // Creates a new platform-specific alarm which will be configured to notify + // |delegate| when the alarm fires. Caller takes ownership of the new alarm, + // which will not yet be "set" to fire. If |arena| is null, then the alarm + // will be created on the heap. Otherwise, it will be created in |arena|. + virtual QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_ALARM_FACTORY_H_ diff --git a/quiche/quic/core/quic_alarm_test.cc b/quiche/quic/core/quic_alarm_test.cc new file mode 100644 index 000000000000..5feef7872d45 --- /dev/null +++ b/quiche/quic/core/quic_alarm_test.cc @@ -0,0 +1,259 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_alarm.h" + +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" + +using testing::ElementsAre; +using testing::Invoke; +using testing::Return; + +namespace quic { +namespace test { +namespace { + +class TraceCollector : public QuicConnectionTracer { + public: + ~TraceCollector() override = default; + + void PrintLiteral(const char* literal) override { trace_.push_back(literal); } + + void PrintString(absl::string_view s) override { + trace_.push_back(std::string(s)); + } + + const std::vector& trace() const { return trace_; } + + private: + std::vector trace_; +}; + +class MockDelegate : public QuicAlarm::Delegate { + public: + MOCK_METHOD(QuicConnectionContext*, GetConnectionContext, (), (override)); + MOCK_METHOD(void, OnAlarm, (), (override)); +}; + +class DestructiveDelegate : public QuicAlarm::DelegateWithoutContext { + public: + DestructiveDelegate() : alarm_(nullptr) {} + + void set_alarm(QuicAlarm* alarm) { alarm_ = alarm; } + + void OnAlarm() override { + QUICHE_DCHECK(alarm_); + delete alarm_; + } + + private: + QuicAlarm* alarm_; +}; + +class TestAlarm : public QuicAlarm { + public: + explicit TestAlarm(QuicAlarm::Delegate* delegate) + : QuicAlarm(QuicArenaScopedPtr(delegate)) {} + + bool scheduled() const { return scheduled_; } + + void FireAlarm() { + scheduled_ = false; + Fire(); + } + + protected: + void SetImpl() override { + QUICHE_DCHECK(deadline().IsInitialized()); + scheduled_ = true; + } + + void CancelImpl() override { + QUICHE_DCHECK(!deadline().IsInitialized()); + scheduled_ = false; + } + + private: + bool scheduled_; +}; + +class DestructiveAlarm : public QuicAlarm { + public: + explicit DestructiveAlarm(DestructiveDelegate* delegate) + : QuicAlarm(QuicArenaScopedPtr(delegate)) {} + + void FireAlarm() { Fire(); } + + protected: + void SetImpl() override {} + + void CancelImpl() override {} +}; + +class QuicAlarmTest : public QuicTest { + public: + QuicAlarmTest() + : delegate_(new MockDelegate()), + alarm_(delegate_), + deadline_(QuicTime::Zero() + QuicTime::Delta::FromSeconds(7)), + deadline2_(QuicTime::Zero() + QuicTime::Delta::FromSeconds(14)), + new_deadline_(QuicTime::Zero()) {} + + void ResetAlarm() { alarm_.Set(new_deadline_); } + + MockDelegate* delegate_; // not owned + TestAlarm alarm_; + QuicTime deadline_; + QuicTime deadline2_; + QuicTime new_deadline_; +}; + +TEST_F(QuicAlarmTest, IsSet) { EXPECT_FALSE(alarm_.IsSet()); } + +TEST_F(QuicAlarmTest, Set) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + EXPECT_TRUE(alarm_.IsSet()); + EXPECT_TRUE(alarm_.scheduled()); + EXPECT_EQ(deadline, alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, Cancel) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + alarm_.Cancel(); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, PermanentCancel) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + alarm_.PermanentCancel(); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); + + EXPECT_QUIC_BUG(alarm_.Set(deadline), + "Set called after alarm is permanently cancelled"); + EXPECT_TRUE(alarm_.IsPermanentlyCancelled()); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); + + EXPECT_QUIC_BUG(alarm_.Update(deadline, QuicTime::Delta::Zero()), + "Update called after alarm is permanently cancelled"); + EXPECT_TRUE(alarm_.IsPermanentlyCancelled()); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, Update) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + QuicTime new_deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(8); + alarm_.Update(new_deadline, QuicTime::Delta::Zero()); + EXPECT_TRUE(alarm_.IsSet()); + EXPECT_TRUE(alarm_.scheduled()); + EXPECT_EQ(new_deadline, alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, UpdateWithZero) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + alarm_.Update(QuicTime::Zero(), QuicTime::Delta::Zero()); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, Fire) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + EXPECT_CALL(*delegate_, OnAlarm()); + alarm_.FireAlarm(); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, FireAndResetViaSet) { + alarm_.Set(deadline_); + new_deadline_ = deadline2_; + EXPECT_CALL(*delegate_, OnAlarm()) + .WillOnce(Invoke(this, &QuicAlarmTest::ResetAlarm)); + alarm_.FireAlarm(); + EXPECT_TRUE(alarm_.IsSet()); + EXPECT_TRUE(alarm_.scheduled()); + EXPECT_EQ(deadline2_, alarm_.deadline()); +} + +TEST_F(QuicAlarmTest, FireDestroysAlarm) { + DestructiveDelegate* delegate(new DestructiveDelegate); + DestructiveAlarm* alarm = new DestructiveAlarm(delegate); + delegate->set_alarm(alarm); + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm->Set(deadline); + // This should not crash, even though it will destroy alarm. + alarm->FireAlarm(); +} + +TEST_F(QuicAlarmTest, NullAlarmContext) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + + EXPECT_CALL(*delegate_, GetConnectionContext()).WillOnce(Return(nullptr)); + + EXPECT_CALL(*delegate_, OnAlarm()).WillOnce(Invoke([] { + QUIC_TRACELITERAL("Alarm fired."); + })); + alarm_.FireAlarm(); +} + +TEST_F(QuicAlarmTest, AlarmContextWithNullTracer) { + QuicConnectionContext context; + ASSERT_EQ(context.tracer, nullptr); + + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + + EXPECT_CALL(*delegate_, GetConnectionContext()).WillOnce(Return(&context)); + + EXPECT_CALL(*delegate_, OnAlarm()).WillOnce(Invoke([] { + QUIC_TRACELITERAL("Alarm fired."); + })); + alarm_.FireAlarm(); +} + +TEST_F(QuicAlarmTest, AlarmContextWithTracer) { + QuicConnectionContext context; + std::unique_ptr tracer = std::make_unique(); + const TraceCollector& tracer_ref = *tracer; + context.tracer = std::move(tracer); + + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + + EXPECT_CALL(*delegate_, GetConnectionContext()).WillOnce(Return(&context)); + + EXPECT_CALL(*delegate_, OnAlarm()).WillOnce(Invoke([] { + QUIC_TRACELITERAL("Alarm fired."); + })); + + // Since |context| is not installed in the current thread, the messages before + // and after FireAlarm() should not be collected by |tracer|. + QUIC_TRACELITERAL("Should not be collected before alarm."); + alarm_.FireAlarm(); + QUIC_TRACELITERAL("Should not be collected after alarm."); + + EXPECT_THAT(tracer_ref.trace(), ElementsAre("Alarm fired.")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_arena_scoped_ptr.h b/quiche/quic/core/quic_arena_scoped_ptr.h new file mode 100644 index 000000000000..a0c4ed81e311 --- /dev/null +++ b/quiche/quic/core/quic_arena_scoped_ptr.h @@ -0,0 +1,208 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// unique_ptr-style pointer that stores values that may be from an arena. Takes +// up the same storage as the platform's native pointer type. Takes ownership +// of the value it's constructed with; if holding a value in an arena, and the +// type has a non-trivial destructor, the arena must outlive the +// QuicArenaScopedPtr. Does not support array overloads. + +#ifndef QUICHE_QUIC_CORE_QUIC_ARENA_SCOPED_PTR_H_ +#define QUICHE_QUIC_CORE_QUIC_ARENA_SCOPED_PTR_H_ + +#include // for uintptr_t + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +template +class QUIC_NO_EXPORT QuicArenaScopedPtr { + static_assert(alignof(T*) > 1, + "QuicArenaScopedPtr can only store objects that are aligned to " + "greater than 1 byte."); + + public: + // Constructs an empty QuicArenaScopedPtr. + QuicArenaScopedPtr(); + + // Constructs a QuicArenaScopedPtr referencing the heap-allocated memory + // provided. + explicit QuicArenaScopedPtr(T* value); + + template + QuicArenaScopedPtr(QuicArenaScopedPtr&& other); // NOLINT + template + QuicArenaScopedPtr& operator=(QuicArenaScopedPtr&& other); + ~QuicArenaScopedPtr(); + + // Returns a pointer to the value. + T* get() const; + + // Returns a reference to the value. + T& operator*() const; + + // Returns a pointer to the value. + T* operator->() const; + + // Swaps the value of this pointer with |other|. + void swap(QuicArenaScopedPtr& other); + + // Resets the held value to |value|. + void reset(T* value = nullptr); + + // Returns true if |this| came from an arena. Primarily exposed for testing + // and assertions. + bool is_from_arena(); + + private: + // Friends with other derived types of QuicArenaScopedPtr, to support the + // derived-types case. + template + friend class QuicArenaScopedPtr; + // Also befriend all known arenas, only to prevent misuse. + template + friend class QuicOneBlockArena; + + // Tag to denote that a QuicArenaScopedPtr is being explicitly created by an + // arena. + enum class ConstructFrom { kHeap, kArena }; + + // Constructs a QuicArenaScopedPtr with the given representation. + QuicArenaScopedPtr(void* value, ConstructFrom from); + QuicArenaScopedPtr(const QuicArenaScopedPtr&) = delete; + QuicArenaScopedPtr& operator=(const QuicArenaScopedPtr&) = delete; + + // Low-order bits of value_ that determine if the pointer came from an arena. + static const uintptr_t kFromArenaMask = 0x1; + + // Every platform we care about has at least 4B aligned integers, so store the + // is_from_arena bit in the least significant bit. + void* value_; +}; + +template +bool operator==(const QuicArenaScopedPtr& left, + const QuicArenaScopedPtr& right) { + return left.get() == right.get(); +} + +template +bool operator!=(const QuicArenaScopedPtr& left, + const QuicArenaScopedPtr& right) { + return left.get() != right.get(); +} + +template +bool operator==(std::nullptr_t, const QuicArenaScopedPtr& right) { + return nullptr == right.get(); +} + +template +bool operator!=(std::nullptr_t, const QuicArenaScopedPtr& right) { + return nullptr != right.get(); +} + +template +bool operator==(const QuicArenaScopedPtr& left, std::nullptr_t) { + return left.get() == nullptr; +} + +template +bool operator!=(const QuicArenaScopedPtr& left, std::nullptr_t) { + return left.get() != nullptr; +} + +template +QuicArenaScopedPtr::QuicArenaScopedPtr() : value_(nullptr) {} + +template +QuicArenaScopedPtr::QuicArenaScopedPtr(T* value) + : QuicArenaScopedPtr(value, ConstructFrom::kHeap) {} + +template +template +QuicArenaScopedPtr::QuicArenaScopedPtr(QuicArenaScopedPtr&& other) + : value_(other.value_) { + static_assert( + std::is_base_of::value || std::is_same::value, + "Cannot construct QuicArenaScopedPtr; type is not derived or same."); + other.value_ = nullptr; +} + +template +template +QuicArenaScopedPtr& QuicArenaScopedPtr::operator=( + QuicArenaScopedPtr&& other) { + static_assert( + std::is_base_of::value || std::is_same::value, + "Cannot assign QuicArenaScopedPtr; type is not derived or same."); + swap(other); + return *this; +} + +template +QuicArenaScopedPtr::~QuicArenaScopedPtr() { + reset(); +} + +template +T* QuicArenaScopedPtr::get() const { + return reinterpret_cast(reinterpret_cast(value_) & + ~kFromArenaMask); +} + +template +T& QuicArenaScopedPtr::operator*() const { + return *get(); +} + +template +T* QuicArenaScopedPtr::operator->() const { + return get(); +} + +template +void QuicArenaScopedPtr::swap(QuicArenaScopedPtr& other) { + using std::swap; + swap(value_, other.value_); +} + +template +bool QuicArenaScopedPtr::is_from_arena() { + return (reinterpret_cast(value_) & kFromArenaMask) != 0; +} + +template +void QuicArenaScopedPtr::reset(T* value) { + if (value_ != nullptr) { + if (is_from_arena()) { + // Manually invoke the destructor. + get()->~T(); + } else { + delete get(); + } + } + QUICHE_DCHECK_EQ(0u, reinterpret_cast(value) & kFromArenaMask); + value_ = value; +} + +template +QuicArenaScopedPtr::QuicArenaScopedPtr(void* value, ConstructFrom from_arena) + : value_(value) { + QUICHE_DCHECK_EQ(0u, reinterpret_cast(value_) & kFromArenaMask); + switch (from_arena) { + case ConstructFrom::kHeap: + break; + case ConstructFrom::kArena: + value_ = reinterpret_cast(reinterpret_cast(value_) | + QuicArenaScopedPtr::kFromArenaMask); + break; + } +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_ARENA_SCOPED_PTR_H_ diff --git a/quiche/quic/core/quic_arena_scoped_ptr_test.cc b/quiche/quic/core/quic_arena_scoped_ptr_test.cc new file mode 100644 index 000000000000..fd6dd640f7ce --- /dev/null +++ b/quiche/quic/core/quic_arena_scoped_ptr_test.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_arena_scoped_ptr.h" + +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic::test { +namespace { + +enum class TestParam { kFromHeap, kFromArena }; + +struct TestObject { + explicit TestObject(uintptr_t value) : value(value) { buffer.resize(1024); } + uintptr_t value; + + // Ensure that we have a non-trivial destructor that will leak memory if it's + // not called. + std::vector buffer; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParam& p) { + switch (p) { + case TestParam::kFromHeap: + return "heap"; + case TestParam::kFromArena: + return "arena"; + } + QUICHE_DCHECK(false); + return "?"; +} + +class QuicArenaScopedPtrParamTest : public QuicTestWithParam { + protected: + QuicArenaScopedPtr CreateObject(uintptr_t value) { + QuicArenaScopedPtr ptr; + switch (GetParam()) { + case TestParam::kFromHeap: + ptr = QuicArenaScopedPtr(new TestObject(value)); + QUICHE_CHECK(!ptr.is_from_arena()); + break; + case TestParam::kFromArena: + ptr = arena_.New(value); + QUICHE_CHECK(ptr.is_from_arena()); + break; + } + return ptr; + } + + private: + QuicOneBlockArena<1024> arena_; +}; + +INSTANTIATE_TEST_SUITE_P(QuicArenaScopedPtrParamTest, + QuicArenaScopedPtrParamTest, + testing::Values(TestParam::kFromHeap, + TestParam::kFromArena), + ::testing::PrintToStringParamName()); + +TEST_P(QuicArenaScopedPtrParamTest, NullObjects) { + QuicArenaScopedPtr def; + QuicArenaScopedPtr null(nullptr); + EXPECT_EQ(def, null); + EXPECT_EQ(def, nullptr); + EXPECT_EQ(null, nullptr); +} + +TEST_P(QuicArenaScopedPtrParamTest, FromArena) { + QuicOneBlockArena<1024> arena_; + EXPECT_TRUE(arena_.New(0).is_from_arena()); + EXPECT_FALSE( + QuicArenaScopedPtr(new TestObject(0)).is_from_arena()); +} + +TEST_P(QuicArenaScopedPtrParamTest, Assign) { + QuicArenaScopedPtr ptr = CreateObject(12345); + ptr = CreateObject(54321); + EXPECT_EQ(54321u, ptr->value); +} + +TEST_P(QuicArenaScopedPtrParamTest, MoveConstruct) { + QuicArenaScopedPtr ptr1 = CreateObject(12345); + QuicArenaScopedPtr ptr2(std::move(ptr1)); + EXPECT_EQ(nullptr, ptr1); + EXPECT_EQ(12345u, ptr2->value); +} + +TEST_P(QuicArenaScopedPtrParamTest, Accessors) { + QuicArenaScopedPtr ptr = CreateObject(12345); + EXPECT_EQ(12345u, (*ptr).value); + EXPECT_EQ(12345u, ptr->value); + // We explicitly want to test that get() returns a valid pointer to the data, + // but the call looks redundant. + EXPECT_EQ(12345u, ptr.get()->value); // NOLINT +} + +TEST_P(QuicArenaScopedPtrParamTest, Reset) { + QuicArenaScopedPtr ptr = CreateObject(12345); + ptr.reset(new TestObject(54321)); + EXPECT_EQ(54321u, ptr->value); +} + +TEST_P(QuicArenaScopedPtrParamTest, Swap) { + QuicArenaScopedPtr ptr1 = CreateObject(12345); + QuicArenaScopedPtr ptr2 = CreateObject(54321); + ptr1.swap(ptr2); + EXPECT_EQ(12345u, ptr2->value); + EXPECT_EQ(54321u, ptr1->value); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/quic_bandwidth.cc b/quiche/quic/core/quic_bandwidth.cc new file mode 100644 index 000000000000..4b25432da33e --- /dev/null +++ b/quiche/quic/core/quic_bandwidth.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_bandwidth.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace quic { + +std::string QuicBandwidth::ToDebuggingValue() const { + if (bits_per_second_ < 80000) { + return absl::StrFormat("%d bits/s (%d bytes/s)", bits_per_second_, + bits_per_second_ / 8); + } + + double divisor; + char unit; + if (bits_per_second_ < 8 * 1000 * 1000) { + divisor = 1e3; + unit = 'k'; + } else if (bits_per_second_ < INT64_C(8) * 1000 * 1000 * 1000) { + divisor = 1e6; + unit = 'M'; + } else { + divisor = 1e9; + unit = 'G'; + } + + double bits_per_second_with_unit = bits_per_second_ / divisor; + double bytes_per_second_with_unit = bits_per_second_with_unit / 8; + return absl::StrFormat("%.2f %cbits/s (%.2f %cbytes/s)", + bits_per_second_with_unit, unit, + bytes_per_second_with_unit, unit); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_bandwidth.h b/quiche/quic/core/quic_bandwidth.h new file mode 100644 index 000000000000..33356abc35e3 --- /dev/null +++ b/quiche/quic/core/quic_bandwidth.h @@ -0,0 +1,168 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// QuicBandwidth represents a bandwidth, stored in bits per second resolution. + +#ifndef QUICHE_QUIC_CORE_QUIC_BANDWIDTH_H_ +#define QUICHE_QUIC_CORE_QUIC_BANDWIDTH_H_ + +#include +#include +#include +#include +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicBandwidth { + public: + // Creates a new QuicBandwidth with an internal value of 0. + static constexpr QuicBandwidth Zero() { return QuicBandwidth(0); } + + // Creates a new QuicBandwidth with an internal value of INT64_MAX. + static constexpr QuicBandwidth Infinite() { + return QuicBandwidth(std::numeric_limits::max()); + } + + // Create a new QuicBandwidth holding the bits per second. + static constexpr QuicBandwidth FromBitsPerSecond(int64_t bits_per_second) { + return QuicBandwidth(bits_per_second); + } + + // Create a new QuicBandwidth holding the kilo bits per second. + static constexpr QuicBandwidth FromKBitsPerSecond(int64_t k_bits_per_second) { + return QuicBandwidth(k_bits_per_second * 1000); + } + + // Create a new QuicBandwidth holding the bytes per second. + static constexpr QuicBandwidth FromBytesPerSecond(int64_t bytes_per_second) { + return QuicBandwidth(bytes_per_second * 8); + } + + // Create a new QuicBandwidth holding the kilo bytes per second. + static constexpr QuicBandwidth FromKBytesPerSecond( + int64_t k_bytes_per_second) { + return QuicBandwidth(k_bytes_per_second * 8000); + } + + // Create a new QuicBandwidth based on the bytes per the elapsed delta. + static QuicBandwidth FromBytesAndTimeDelta(QuicByteCount bytes, + QuicTime::Delta delta) { + if (bytes == 0) { + return QuicBandwidth(0); + } + + // 1 bit is 1000000 micro bits. + int64_t num_micro_bits = 8 * bytes * kNumMicrosPerSecond; + if (num_micro_bits < delta.ToMicroseconds()) { + return QuicBandwidth(1); + } + + return QuicBandwidth(num_micro_bits / delta.ToMicroseconds()); + } + + int64_t ToBitsPerSecond() const { return bits_per_second_; } + + int64_t ToKBitsPerSecond() const { return bits_per_second_ / 1000; } + + int64_t ToBytesPerSecond() const { return bits_per_second_ / 8; } + + int64_t ToKBytesPerSecond() const { return bits_per_second_ / 8000; } + + constexpr QuicByteCount ToBytesPerPeriod(QuicTime::Delta time_period) const { + return bits_per_second_ * time_period.ToMicroseconds() / 8 / + kNumMicrosPerSecond; + } + + int64_t ToKBytesPerPeriod(QuicTime::Delta time_period) const { + return bits_per_second_ * time_period.ToMicroseconds() / 8000 / + kNumMicrosPerSecond; + } + + bool IsZero() const { return bits_per_second_ == 0; } + bool IsInfinite() const { + return bits_per_second_ == Infinite().ToBitsPerSecond(); + } + + constexpr QuicTime::Delta TransferTime(QuicByteCount bytes) const { + if (bits_per_second_ == 0) { + return QuicTime::Delta::Zero(); + } + return QuicTime::Delta::FromMicroseconds(bytes * 8 * kNumMicrosPerSecond / + bits_per_second_); + } + + std::string ToDebuggingValue() const; + + private: + explicit constexpr QuicBandwidth(int64_t bits_per_second) + : bits_per_second_(bits_per_second >= 0 ? bits_per_second : 0) {} + + int64_t bits_per_second_; + + friend constexpr QuicBandwidth operator+(QuicBandwidth lhs, + QuicBandwidth rhs); + friend constexpr QuicBandwidth operator-(QuicBandwidth lhs, + QuicBandwidth rhs); + friend QuicBandwidth operator*(QuicBandwidth lhs, float rhs); +}; + +// Non-member relational operators for QuicBandwidth. +inline bool operator==(QuicBandwidth lhs, QuicBandwidth rhs) { + return lhs.ToBitsPerSecond() == rhs.ToBitsPerSecond(); +} +inline bool operator!=(QuicBandwidth lhs, QuicBandwidth rhs) { + return !(lhs == rhs); +} +inline bool operator<(QuicBandwidth lhs, QuicBandwidth rhs) { + return lhs.ToBitsPerSecond() < rhs.ToBitsPerSecond(); +} +inline bool operator>(QuicBandwidth lhs, QuicBandwidth rhs) { + return rhs < lhs; +} +inline bool operator<=(QuicBandwidth lhs, QuicBandwidth rhs) { + return !(rhs < lhs); +} +inline bool operator>=(QuicBandwidth lhs, QuicBandwidth rhs) { + return !(lhs < rhs); +} + +// Non-member arithmetic operators for QuicBandwidth. +inline constexpr QuicBandwidth operator+(QuicBandwidth lhs, QuicBandwidth rhs) { + return QuicBandwidth(lhs.bits_per_second_ + rhs.bits_per_second_); +} +inline constexpr QuicBandwidth operator-(QuicBandwidth lhs, QuicBandwidth rhs) { + return QuicBandwidth(lhs.bits_per_second_ - rhs.bits_per_second_); +} +inline QuicBandwidth operator*(QuicBandwidth lhs, float rhs) { + return QuicBandwidth( + static_cast(std::llround(lhs.bits_per_second_ * rhs))); +} +inline QuicBandwidth operator*(float lhs, QuicBandwidth rhs) { + return rhs * lhs; +} +inline constexpr QuicByteCount operator*(QuicBandwidth lhs, + QuicTime::Delta rhs) { + return lhs.ToBytesPerPeriod(rhs); +} +inline constexpr QuicByteCount operator*(QuicTime::Delta lhs, + QuicBandwidth rhs) { + return rhs * lhs; +} + +// Override stream output operator for gtest. +inline std::ostream& operator<<(std::ostream& output, + const QuicBandwidth bandwidth) { + output << bandwidth.ToDebuggingValue(); + return output; +} + +} // namespace quic +#endif // QUICHE_QUIC_CORE_QUIC_BANDWIDTH_H_ diff --git a/quiche/quic/core/quic_bandwidth_test.cc b/quiche/quic/core/quic_bandwidth_test.cc new file mode 100644 index 000000000000..2d2f99471fb6 --- /dev/null +++ b/quiche/quic/core/quic_bandwidth_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_bandwidth.h" + +#include + +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class QuicBandwidthTest : public QuicTest {}; + +TEST_F(QuicBandwidthTest, FromTo) { + EXPECT_EQ(QuicBandwidth::FromKBitsPerSecond(1), + QuicBandwidth::FromBitsPerSecond(1000)); + EXPECT_EQ(QuicBandwidth::FromKBytesPerSecond(1), + QuicBandwidth::FromBytesPerSecond(1000)); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(8000), + QuicBandwidth::FromBytesPerSecond(1000)); + EXPECT_EQ(QuicBandwidth::FromKBitsPerSecond(8), + QuicBandwidth::FromKBytesPerSecond(1)); + + EXPECT_EQ(0, QuicBandwidth::Zero().ToBitsPerSecond()); + EXPECT_EQ(0, QuicBandwidth::Zero().ToKBitsPerSecond()); + EXPECT_EQ(0, QuicBandwidth::Zero().ToBytesPerSecond()); + EXPECT_EQ(0, QuicBandwidth::Zero().ToKBytesPerSecond()); + + EXPECT_EQ(1, QuicBandwidth::FromBitsPerSecond(1000).ToKBitsPerSecond()); + EXPECT_EQ(1000, QuicBandwidth::FromKBitsPerSecond(1).ToBitsPerSecond()); + EXPECT_EQ(1, QuicBandwidth::FromBytesPerSecond(1000).ToKBytesPerSecond()); + EXPECT_EQ(1000, QuicBandwidth::FromKBytesPerSecond(1).ToBytesPerSecond()); +} + +TEST_F(QuicBandwidthTest, Add) { + QuicBandwidth bandwidht_1 = QuicBandwidth::FromKBitsPerSecond(1); + QuicBandwidth bandwidht_2 = QuicBandwidth::FromKBytesPerSecond(1); + + EXPECT_EQ(9000, (bandwidht_1 + bandwidht_2).ToBitsPerSecond()); + EXPECT_EQ(9000, (bandwidht_2 + bandwidht_1).ToBitsPerSecond()); +} + +TEST_F(QuicBandwidthTest, Subtract) { + QuicBandwidth bandwidht_1 = QuicBandwidth::FromKBitsPerSecond(1); + QuicBandwidth bandwidht_2 = QuicBandwidth::FromKBytesPerSecond(1); + + EXPECT_EQ(7000, (bandwidht_2 - bandwidht_1).ToBitsPerSecond()); +} + +TEST_F(QuicBandwidthTest, TimeDelta) { + EXPECT_EQ(QuicBandwidth::FromKBytesPerSecond(1000), + QuicBandwidth::FromBytesAndTimeDelta( + 1000, QuicTime::Delta::FromMilliseconds(1))); + + EXPECT_EQ(QuicBandwidth::FromKBytesPerSecond(10), + QuicBandwidth::FromBytesAndTimeDelta( + 1000, QuicTime::Delta::FromMilliseconds(100))); + + EXPECT_EQ(QuicBandwidth::Zero(), QuicBandwidth::FromBytesAndTimeDelta( + 0, QuicTime::Delta::FromSeconds(9))); + + EXPECT_EQ( + QuicBandwidth::FromBitsPerSecond(1), + QuicBandwidth::FromBytesAndTimeDelta(1, QuicTime::Delta::FromSeconds(9))); +} + +TEST_F(QuicBandwidthTest, Scale) { + EXPECT_EQ(QuicBandwidth::FromKBytesPerSecond(500), + QuicBandwidth::FromKBytesPerSecond(1000) * 0.5f); + EXPECT_EQ(QuicBandwidth::FromKBytesPerSecond(750), + 0.75f * QuicBandwidth::FromKBytesPerSecond(1000)); + EXPECT_EQ(QuicBandwidth::FromKBytesPerSecond(1250), + QuicBandwidth::FromKBytesPerSecond(1000) * 1.25f); + + // Ensure we are rounding correctly within a 1bps level of precision. + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(5), + QuicBandwidth::FromBitsPerSecond(9) * 0.5f); + EXPECT_EQ(QuicBandwidth::FromBitsPerSecond(2), + QuicBandwidth::FromBitsPerSecond(12) * 0.2f); +} + +TEST_F(QuicBandwidthTest, BytesPerPeriod) { + EXPECT_EQ(2000, QuicBandwidth::FromKBytesPerSecond(2000).ToBytesPerPeriod( + QuicTime::Delta::FromMilliseconds(1))); + EXPECT_EQ(2, QuicBandwidth::FromKBytesPerSecond(2000).ToKBytesPerPeriod( + QuicTime::Delta::FromMilliseconds(1))); + EXPECT_EQ(200000, QuicBandwidth::FromKBytesPerSecond(2000).ToBytesPerPeriod( + QuicTime::Delta::FromMilliseconds(100))); + EXPECT_EQ(200, QuicBandwidth::FromKBytesPerSecond(2000).ToKBytesPerPeriod( + QuicTime::Delta::FromMilliseconds(100))); + + // 1599 * 1001 = 1600599 bits/ms = 200.074875 bytes/s. + EXPECT_EQ(200, QuicBandwidth::FromBitsPerSecond(1599).ToBytesPerPeriod( + QuicTime::Delta::FromMilliseconds(1001))); + + EXPECT_EQ(200, QuicBandwidth::FromBitsPerSecond(1599).ToKBytesPerPeriod( + QuicTime::Delta::FromSeconds(1001))); +} + +TEST_F(QuicBandwidthTest, TransferTime) { + EXPECT_EQ(QuicTime::Delta::FromSeconds(1), + QuicBandwidth::FromKBytesPerSecond(1).TransferTime(1000)); + EXPECT_EQ(QuicTime::Delta::Zero(), QuicBandwidth::Zero().TransferTime(1000)); +} + +TEST_F(QuicBandwidthTest, RelOps) { + const QuicBandwidth b1 = QuicBandwidth::FromKBitsPerSecond(1); + const QuicBandwidth b2 = QuicBandwidth::FromKBytesPerSecond(2); + EXPECT_EQ(b1, b1); + EXPECT_NE(b1, b2); + EXPECT_LT(b1, b2); + EXPECT_GT(b2, b1); + EXPECT_LE(b1, b1); + EXPECT_LE(b1, b2); + EXPECT_GE(b1, b1); + EXPECT_GE(b2, b1); +} + +TEST_F(QuicBandwidthTest, DebuggingValue) { + EXPECT_EQ("128 bits/s (16 bytes/s)", + QuicBandwidth::FromBytesPerSecond(16).ToDebuggingValue()); + EXPECT_EQ("4096 bits/s (512 bytes/s)", + QuicBandwidth::FromBytesPerSecond(512).ToDebuggingValue()); + + QuicBandwidth bandwidth = QuicBandwidth::FromBytesPerSecond(1000 * 50); + EXPECT_EQ("400.00 kbits/s (50.00 kbytes/s)", bandwidth.ToDebuggingValue()); + + bandwidth = bandwidth * 1000; + EXPECT_EQ("400.00 Mbits/s (50.00 Mbytes/s)", bandwidth.ToDebuggingValue()); + + bandwidth = bandwidth * 1000; + EXPECT_EQ("400.00 Gbits/s (50.00 Gbytes/s)", bandwidth.ToDebuggingValue()); +} + +TEST_F(QuicBandwidthTest, SpecialValues) { + EXPECT_EQ(0, QuicBandwidth::Zero().ToBitsPerSecond()); + EXPECT_EQ(std::numeric_limits::max(), + QuicBandwidth::Infinite().ToBitsPerSecond()); + + EXPECT_TRUE(QuicBandwidth::Zero().IsZero()); + EXPECT_FALSE(QuicBandwidth::Zero().IsInfinite()); + + EXPECT_TRUE(QuicBandwidth::Infinite().IsInfinite()); + EXPECT_FALSE(QuicBandwidth::Infinite().IsZero()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_blocked_writer_interface.h b/quiche/quic/core/quic_blocked_writer_interface.h new file mode 100644 index 000000000000..062b5ccec6ad --- /dev/null +++ b/quiche/quic/core/quic_blocked_writer_interface.h @@ -0,0 +1,29 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This is an interface for all objects that want to be notified that +// the underlying UDP socket is available for writing (not write blocked +// anymore). + +#ifndef QUICHE_QUIC_CORE_QUIC_BLOCKED_WRITER_INTERFACE_H_ +#define QUICHE_QUIC_CORE_QUIC_BLOCKED_WRITER_INTERFACE_H_ + +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicBlockedWriterInterface { + public: + virtual ~QuicBlockedWriterInterface() {} + + // Called by the PacketWriter when the underlying socket becomes writable + // so that the BlockedWriter can go ahead and try writing. + virtual void OnBlockedWriterCanWrite() = 0; + + virtual bool IsWriterBlocked() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_BLOCKED_WRITER_INTERFACE_H_ diff --git a/quiche/quic/core/quic_buffered_packet_store.cc b/quiche/quic/core/quic_buffered_packet_store.cc new file mode 100644 index 000000000000..df028f01ad09 --- /dev/null +++ b/quiche/quic/core/quic_buffered_packet_store.cc @@ -0,0 +1,321 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_buffered_packet_store.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +using BufferedPacket = QuicBufferedPacketStore::BufferedPacket; +using BufferedPacketList = QuicBufferedPacketStore::BufferedPacketList; +using EnqueuePacketResult = QuicBufferedPacketStore::EnqueuePacketResult; + +// Max number of connections this store can keep track. +static const size_t kDefaultMaxConnectionsInStore = 100; +// Up to half of the capacity can be used for storing non-CHLO packets. +static const size_t kMaxConnectionsWithoutCHLO = + kDefaultMaxConnectionsInStore / 2; + +namespace { + +// This alarm removes expired entries in map each time this alarm fires. +class ConnectionExpireAlarm : public QuicAlarm::DelegateWithoutContext { + public: + explicit ConnectionExpireAlarm(QuicBufferedPacketStore* store) + : connection_store_(store) {} + + void OnAlarm() override { connection_store_->OnExpirationTimeout(); } + + ConnectionExpireAlarm(const ConnectionExpireAlarm&) = delete; + ConnectionExpireAlarm& operator=(const ConnectionExpireAlarm&) = delete; + + private: + QuicBufferedPacketStore* connection_store_; +}; + +} // namespace + +BufferedPacket::BufferedPacket(std::unique_ptr packet, + QuicSocketAddress self_address, + QuicSocketAddress peer_address) + : packet(std::move(packet)), + self_address(self_address), + peer_address(peer_address) {} + +BufferedPacket::BufferedPacket(BufferedPacket&& other) = default; + +BufferedPacket& BufferedPacket::operator=(BufferedPacket&& other) = default; + +BufferedPacket::~BufferedPacket() {} + +BufferedPacketList::BufferedPacketList() + : creation_time(QuicTime::Zero()), + ietf_quic(false), + version(ParsedQuicVersion::Unsupported()) {} + +BufferedPacketList::BufferedPacketList(BufferedPacketList&& other) = default; + +BufferedPacketList& BufferedPacketList::operator=(BufferedPacketList&& other) = + default; + +BufferedPacketList::~BufferedPacketList() {} + +QuicBufferedPacketStore::QuicBufferedPacketStore( + VisitorInterface* visitor, const QuicClock* clock, + QuicAlarmFactory* alarm_factory) + : connection_life_span_( + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs)), + visitor_(visitor), + clock_(clock), + expiration_alarm_( + alarm_factory->CreateAlarm(new ConnectionExpireAlarm(this))) {} + +QuicBufferedPacketStore::~QuicBufferedPacketStore() { + if (expiration_alarm_ != nullptr) { + expiration_alarm_->PermanentCancel(); + } +} + +EnqueuePacketResult QuicBufferedPacketStore::EnqueuePacket( + QuicConnectionId connection_id, bool ietf_quic, + const QuicReceivedPacket& packet, QuicSocketAddress self_address, + QuicSocketAddress peer_address, const ParsedQuicVersion& version, + absl::optional parsed_chlo) { + const bool is_chlo = parsed_chlo.has_value(); + QUIC_BUG_IF(quic_bug_12410_1, !GetQuicFlag(quic_allow_chlo_buffering)) + << "Shouldn't buffer packets if disabled via flag."; + QUIC_BUG_IF(quic_bug_12410_2, + is_chlo && connections_with_chlo_.contains(connection_id)) + << "Shouldn't buffer duplicated CHLO on connection " << connection_id; + QUIC_BUG_IF(quic_bug_12410_4, is_chlo && !version.IsKnown()) + << "Should have version for CHLO packet."; + + const bool is_first_packet = !undecryptable_packets_.contains(connection_id); + if (is_first_packet) { + if (ShouldNotBufferPacket(is_chlo)) { + // Drop the packet if the upper limit of undecryptable packets has been + // reached or the whole capacity of the store has been reached. + return TOO_MANY_CONNECTIONS; + } + undecryptable_packets_.emplace( + std::make_pair(connection_id, BufferedPacketList())); + undecryptable_packets_.back().second.ietf_quic = ietf_quic; + undecryptable_packets_.back().second.version = version; + } + QUICHE_CHECK(undecryptable_packets_.contains(connection_id)); + BufferedPacketList& queue = + undecryptable_packets_.find(connection_id)->second; + + if (!is_chlo) { + // If current packet is not CHLO, it might not be buffered because store + // only buffers certain number of undecryptable packets per connection. + size_t num_non_chlo_packets = connections_with_chlo_.contains(connection_id) + ? (queue.buffered_packets.size() - 1) + : queue.buffered_packets.size(); + if (num_non_chlo_packets >= kDefaultMaxUndecryptablePackets) { + // If there are kMaxBufferedPacketsPerConnection packets buffered up for + // this connection, drop the current packet. + return TOO_MANY_PACKETS; + } + } + + if (queue.buffered_packets.empty()) { + // If this is the first packet arrived on a new connection, initialize the + // creation time. + queue.creation_time = clock_->ApproximateNow(); + } + + BufferedPacket new_entry(std::unique_ptr(packet.Clone()), + self_address, peer_address); + if (is_chlo) { + // Add CHLO to the beginning of buffered packets so that it can be delivered + // first later. + queue.buffered_packets.push_front(std::move(new_entry)); + queue.parsed_chlo = std::move(parsed_chlo); + connections_with_chlo_[connection_id] = false; // Dummy value. + // Set the version of buffered packets of this connection on CHLO. + queue.version = version; + } else { + // Buffer non-CHLO packets in arrival order. + queue.buffered_packets.push_back(std::move(new_entry)); + + // Attempt to parse multi-packet TLS CHLOs. + if (is_first_packet) { + queue.tls_chlo_extractor.IngestPacket(version, packet); + // Since this is the first packet and it's not a CHLO, the + // TlsChloExtractor should not have the entire CHLO. + QUIC_BUG_IF(quic_bug_12410_5, + queue.tls_chlo_extractor.HasParsedFullChlo()) + << "First packet in list should not contain full CHLO"; + } + // TODO(b/154857081) Reorder CHLO packets ahead of other ones. + } + + MaybeSetExpirationAlarm(); + return SUCCESS; +} + +bool QuicBufferedPacketStore::HasBufferedPackets( + QuicConnectionId connection_id) const { + return undecryptable_packets_.contains(connection_id); +} + +bool QuicBufferedPacketStore::HasChlosBuffered() const { + return !connections_with_chlo_.empty(); +} + +BufferedPacketList QuicBufferedPacketStore::DeliverPackets( + QuicConnectionId connection_id) { + BufferedPacketList packets_to_deliver; + auto it = undecryptable_packets_.find(connection_id); + if (it != undecryptable_packets_.end()) { + packets_to_deliver = std::move(it->second); + undecryptable_packets_.erase(connection_id); + std::list initial_packets; + std::list other_packets; + for (auto& packet : packets_to_deliver.buffered_packets) { + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + PacketHeaderFormat unused_format; + bool unused_version_flag; + bool unused_use_length_prefix; + QuicVersionLabel unused_version_label; + ParsedQuicVersion unused_parsed_version = UnsupportedQuicVersion(); + QuicConnectionId unused_destination_connection_id; + QuicConnectionId unused_source_connection_id; + absl::optional unused_retry_token; + std::string unused_detailed_error; + + // We don't need to pass |generator| because we already got the correct + // connection ID length when we buffered the packet and indexed by + // connection ID. + QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + *packet.packet, connection_id.length(), &unused_format, + &long_packet_type, &unused_version_flag, &unused_use_length_prefix, + &unused_version_label, &unused_parsed_version, + &unused_destination_connection_id, &unused_source_connection_id, + &unused_retry_token, &unused_detailed_error); + + if (error_code == QUIC_NO_ERROR && long_packet_type == INITIAL) { + initial_packets.push_back(std::move(packet)); + } else { + other_packets.push_back(std::move(packet)); + } + } + + initial_packets.splice(initial_packets.end(), other_packets); + packets_to_deliver.buffered_packets = std::move(initial_packets); + } + return packets_to_deliver; +} + +void QuicBufferedPacketStore::DiscardPackets(QuicConnectionId connection_id) { + undecryptable_packets_.erase(connection_id); + connections_with_chlo_.erase(connection_id); +} + +void QuicBufferedPacketStore::DiscardAllPackets() { + undecryptable_packets_.clear(); + connections_with_chlo_.clear(); + expiration_alarm_->Cancel(); +} + +void QuicBufferedPacketStore::OnExpirationTimeout() { + QuicTime expiration_time = clock_->ApproximateNow() - connection_life_span_; + while (!undecryptable_packets_.empty()) { + auto& entry = undecryptable_packets_.front(); + if (entry.second.creation_time > expiration_time) { + break; + } + QuicConnectionId connection_id = entry.first; + visitor_->OnExpiredPackets(connection_id, std::move(entry.second)); + undecryptable_packets_.pop_front(); + connections_with_chlo_.erase(connection_id); + } + if (!undecryptable_packets_.empty()) { + MaybeSetExpirationAlarm(); + } +} + +void QuicBufferedPacketStore::MaybeSetExpirationAlarm() { + if (!expiration_alarm_->IsSet()) { + expiration_alarm_->Set(clock_->ApproximateNow() + connection_life_span_); + } +} + +bool QuicBufferedPacketStore::ShouldNotBufferPacket(bool is_chlo) { + bool is_store_full = + undecryptable_packets_.size() >= kDefaultMaxConnectionsInStore; + + if (is_chlo) { + return is_store_full; + } + + size_t num_connections_without_chlo = + undecryptable_packets_.size() - connections_with_chlo_.size(); + bool reach_non_chlo_limit = + num_connections_without_chlo >= kMaxConnectionsWithoutCHLO; + + return is_store_full || reach_non_chlo_limit; +} + +BufferedPacketList QuicBufferedPacketStore::DeliverPacketsForNextConnection( + QuicConnectionId* connection_id) { + if (connections_with_chlo_.empty()) { + // Returns empty list if no CHLO has been buffered. + return BufferedPacketList(); + } + *connection_id = connections_with_chlo_.front().first; + connections_with_chlo_.pop_front(); + + BufferedPacketList packets = DeliverPackets(*connection_id); + QUICHE_DCHECK(!packets.buffered_packets.empty() && + packets.parsed_chlo.has_value()) + << "Try to deliver connectons without CHLO. # packets:" + << packets.buffered_packets.size() + << ", has_parsed_chlo:" << packets.parsed_chlo.has_value(); + return packets; +} + +bool QuicBufferedPacketStore::HasChloForConnection( + QuicConnectionId connection_id) { + return connections_with_chlo_.contains(connection_id); +} + +bool QuicBufferedPacketStore::IngestPacketForTlsChloExtraction( + const QuicConnectionId& connection_id, const ParsedQuicVersion& version, + const QuicReceivedPacket& packet, std::vector* out_alpns, + std::string* out_sni, bool* out_resumption_attempted, + bool* out_early_data_attempted, absl::optional* tls_alert) { + QUICHE_DCHECK_NE(out_alpns, nullptr); + QUICHE_DCHECK_NE(out_sni, nullptr); + QUICHE_DCHECK_NE(tls_alert, nullptr); + QUICHE_DCHECK_EQ(version.handshake_protocol, PROTOCOL_TLS1_3); + auto it = undecryptable_packets_.find(connection_id); + if (it == undecryptable_packets_.end()) { + QUIC_BUG(quic_bug_10838_1) + << "Cannot ingest packet for unknown connection ID " << connection_id; + return false; + } + it->second.tls_chlo_extractor.IngestPacket(version, packet); + if (!it->second.tls_chlo_extractor.HasParsedFullChlo()) { + *tls_alert = it->second.tls_chlo_extractor.tls_alert(); + return false; + } + const TlsChloExtractor& tls_chlo_extractor = it->second.tls_chlo_extractor; + *out_alpns = tls_chlo_extractor.alpns(); + *out_sni = tls_chlo_extractor.server_name(); + *out_resumption_attempted = tls_chlo_extractor.resumption_attempted(); + *out_early_data_attempted = tls_chlo_extractor.early_data_attempted(); + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_buffered_packet_store.h b/quiche/quic/core/quic_buffered_packet_store.h new file mode 100644 index 000000000000..95bb737e0ffe --- /dev/null +++ b/quiche/quic/core/quic_buffered_packet_store.h @@ -0,0 +1,194 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_BUFFERED_PACKET_STORE_H_ +#define QUICHE_QUIC_CORE_QUIC_BUFFERED_PACKET_STORE_H_ + +#include +#include + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/tls_chlo_extractor.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +namespace test { +class QuicBufferedPacketStorePeer; +} // namespace test + +// This class buffers packets for each connection until either +// 1) They are requested to be delivered via +// DeliverPacket()/DeliverPacketsForNextConnection(), or +// 2) They expire after exceeding their lifetime in the store. +// +// It can only buffer packets on certain number of connections. It has two pools +// of connections: connections with CHLO buffered and those without CHLO. The +// latter has its own upper limit along with the max number of connections this +// store can hold. The former pool can grow till this store is full. +class QUIC_NO_EXPORT QuicBufferedPacketStore { + public: + enum EnqueuePacketResult { + SUCCESS = 0, + TOO_MANY_PACKETS, // Too many packets stored up for a certain connection. + TOO_MANY_CONNECTIONS // Too many connections stored up in the store. + }; + + struct QUIC_NO_EXPORT BufferedPacket { + BufferedPacket(std::unique_ptr packet, + QuicSocketAddress self_address, + QuicSocketAddress peer_address); + BufferedPacket(BufferedPacket&& other); + + BufferedPacket& operator=(BufferedPacket&& other); + + ~BufferedPacket(); + + std::unique_ptr packet; + QuicSocketAddress self_address; + QuicSocketAddress peer_address; + }; + + // A queue of BufferedPackets for a connection. + struct QUIC_NO_EXPORT BufferedPacketList { + BufferedPacketList(); + BufferedPacketList(BufferedPacketList&& other); + + BufferedPacketList& operator=(BufferedPacketList&& other); + + ~BufferedPacketList(); + + std::list buffered_packets; + QuicTime creation_time; + // |parsed_chlo| is set iff the entire CHLO has been received. + absl::optional parsed_chlo; + // Indicating whether this is an IETF QUIC connection. + bool ietf_quic; + // If buffered_packets contains the CHLO, it is the version of the CHLO. + // Otherwise, it is the version of the first packet in |buffered_packets|. + ParsedQuicVersion version; + TlsChloExtractor tls_chlo_extractor; + }; + + using BufferedPacketMap = + quiche::QuicheLinkedHashMap; + + class QUIC_NO_EXPORT VisitorInterface { + public: + virtual ~VisitorInterface() {} + + // Called for each expired connection when alarm fires. + virtual void OnExpiredPackets(QuicConnectionId connection_id, + BufferedPacketList early_arrived_packets) = 0; + }; + + QuicBufferedPacketStore(VisitorInterface* visitor, const QuicClock* clock, + QuicAlarmFactory* alarm_factory); + + QuicBufferedPacketStore(const QuicBufferedPacketStore&) = delete; + + ~QuicBufferedPacketStore(); + + QuicBufferedPacketStore& operator=(const QuicBufferedPacketStore&) = delete; + + // Adds a copy of packet into the packet queue for given connection. If the + // packet is the last one of the CHLO, |parsed_chlo| will contain a parsed + // version of the CHLO. + EnqueuePacketResult EnqueuePacket( + QuicConnectionId connection_id, bool ietf_quic, + const QuicReceivedPacket& packet, QuicSocketAddress self_address, + QuicSocketAddress peer_address, const ParsedQuicVersion& version, + absl::optional parsed_chlo); + + // Returns true if there are any packets buffered for |connection_id|. + bool HasBufferedPackets(QuicConnectionId connection_id) const; + + // Ingests this packet into the corresponding TlsChloExtractor. This should + // only be called when HasBufferedPackets(connection_id) is true. + // Returns whether we've now parsed a full multi-packet TLS CHLO. + // When this returns true, |out_alpns| is populated with the list of ALPNs + // extracted from the CHLO. |out_sni| is populated with the SNI tag in CHLO. + // |out_resumption_attempted| is populated if the CHLO has the + // 'pre_shared_key' TLS extension. |out_early_data_attempted| is populated if + // the CHLO has the 'early_data' TLS extension. + // When this returns false, and an unrecoverable error happened due to a TLS + // alert, |*tls_alert| will be set to the alert value. + bool IngestPacketForTlsChloExtraction( + const QuicConnectionId& connection_id, const ParsedQuicVersion& version, + const QuicReceivedPacket& packet, std::vector* out_alpns, + std::string* out_sni, bool* out_resumption_attempted, + bool* out_early_data_attempted, absl::optional* tls_alert); + + // Returns the list of buffered packets for |connection_id| and removes them + // from the store. Returns an empty list if no early arrived packets for this + // connection are present. + BufferedPacketList DeliverPackets(QuicConnectionId connection_id); + + // Discards packets buffered for |connection_id|, if any. + void DiscardPackets(QuicConnectionId connection_id); + + // Discards all the packets. + void DiscardAllPackets(); + + // Examines how long packets have been buffered in the store for each + // connection. If they stay too long, removes them for new coming packets and + // calls |visitor_|'s OnPotentialConnectionExpire(). + // Resets the alarm at the end. + void OnExpirationTimeout(); + + // Delivers buffered packets for next connection with CHLO to open. + // Return connection id for next connection in |connection_id| + // and all buffered packets including CHLO. + // The returned list should at least has one packet(CHLO) if + // store does have any connection to open. If no connection in the store has + // received CHLO yet, empty list will be returned. + BufferedPacketList DeliverPacketsForNextConnection( + QuicConnectionId* connection_id); + + // Is given connection already buffered in the store? + bool HasChloForConnection(QuicConnectionId connection_id); + + // Is there any CHLO buffered in the store? + bool HasChlosBuffered() const; + + private: + friend class test::QuicBufferedPacketStorePeer; + + // Set expiration alarm if it hasn't been set. + void MaybeSetExpirationAlarm(); + + // Return true if add an extra packet will go beyond allowed max connection + // limit. The limit for non-CHLO packet and CHLO packet is different. + bool ShouldNotBufferPacket(bool is_chlo); + + // A map to store packet queues with creation time for each connection. + BufferedPacketMap undecryptable_packets_; + + // The max time the packets of a connection can be buffer in the store. + const QuicTime::Delta connection_life_span_; + + VisitorInterface* visitor_; // Unowned. + + const QuicClock* clock_; // Unowned. + + // This alarm fires every |connection_life_span_| to clean up + // packets staying in the store for too long. + std::unique_ptr expiration_alarm_; + + // Keeps track of connection with CHLO buffered up already and the order they + // arrive. + quiche::QuicheLinkedHashMap + connections_with_chlo_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_BUFFERED_PACKET_STORE_H_ diff --git a/quiche/quic/core/quic_buffered_packet_store_test.cc b/quiche/quic/core/quic_buffered_packet_store_test.cc new file mode 100644 index 000000000000..3b3230077aa7 --- /dev/null +++ b/quiche/quic/core/quic_buffered_packet_store_test.cc @@ -0,0 +1,600 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_buffered_packet_store.h" + +#include +#include +#include + +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/first_flight.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/quic_buffered_packet_store_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +static const size_t kDefaultMaxConnectionsInStore = 100; +static const size_t kMaxConnectionsWithoutCHLO = + kDefaultMaxConnectionsInStore / 2; + +namespace test { +namespace { + +const absl::optional kNoParsedChlo; +const absl::optional kDefaultParsedChlo = + absl::make_optional(); + +using BufferedPacket = QuicBufferedPacketStore::BufferedPacket; +using BufferedPacketList = QuicBufferedPacketStore::BufferedPacketList; +using EnqueuePacketResult = QuicBufferedPacketStore::EnqueuePacketResult; +using ::testing::A; +using ::testing::Conditional; +using ::testing::Each; +using ::testing::ElementsAre; +using ::testing::Ne; +using ::testing::SizeIs; +using ::testing::Truly; + +class QuicBufferedPacketStoreVisitor + : public QuicBufferedPacketStore::VisitorInterface { + public: + QuicBufferedPacketStoreVisitor() {} + + ~QuicBufferedPacketStoreVisitor() override {} + + void OnExpiredPackets(QuicConnectionId /*connection_id*/, + BufferedPacketList early_arrived_packets) override { + last_expired_packet_queue_ = std::move(early_arrived_packets); + } + + // The packets queue for most recently expirect connection. + BufferedPacketList last_expired_packet_queue_; +}; + +class QuicBufferedPacketStoreTest : public QuicTest { + public: + QuicBufferedPacketStoreTest() + : store_(&visitor_, &clock_, &alarm_factory_), + self_address_(QuicIpAddress::Any6(), 65535), + peer_address_(QuicIpAddress::Any6(), 65535), + packet_content_("some encrypted content"), + packet_time_(QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(42)), + packet_(packet_content_.data(), packet_content_.size(), packet_time_), + invalid_version_(UnsupportedQuicVersion()), + valid_version_(CurrentSupportedVersions().front()) {} + + protected: + QuicBufferedPacketStoreVisitor visitor_; + MockClock clock_; + MockAlarmFactory alarm_factory_; + QuicBufferedPacketStore store_; + QuicSocketAddress self_address_; + QuicSocketAddress peer_address_; + std::string packet_content_; + QuicTime packet_time_; + QuicReceivedPacket packet_; + const ParsedQuicVersion invalid_version_; + const ParsedQuicVersion valid_version_; +}; + +TEST_F(QuicBufferedPacketStoreTest, SimpleEnqueueAndDeliverPacket) { + QuicConnectionId connection_id = TestConnectionId(1); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); + auto packets = store_.DeliverPackets(connection_id); + const std::list& queue = packets.buffered_packets; + ASSERT_EQ(1u, queue.size()); + ASSERT_FALSE(packets.parsed_chlo.has_value()); + // There is no valid version because CHLO has not arrived. + EXPECT_EQ(invalid_version_, packets.version); + // Check content of the only packet in the queue. + EXPECT_EQ(packet_content_, queue.front().packet->AsStringPiece()); + EXPECT_EQ(packet_time_, queue.front().packet->receipt_time()); + EXPECT_EQ(peer_address_, queue.front().peer_address); + EXPECT_EQ(self_address_, queue.front().self_address); + // No more packets on connection 1 should remain in the store. + EXPECT_TRUE(store_.DeliverPackets(connection_id).buffered_packets.empty()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); +} + +TEST_F(QuicBufferedPacketStoreTest, DifferentPacketAddressOnOneConnection) { + QuicSocketAddress addr_with_new_port(QuicIpAddress::Any4(), 256); + QuicConnectionId connection_id = TestConnectionId(1); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + addr_with_new_port, invalid_version_, kNoParsedChlo); + std::list queue = + store_.DeliverPackets(connection_id).buffered_packets; + ASSERT_EQ(2u, queue.size()); + // The address migration path should be preserved. + EXPECT_EQ(peer_address_, queue.front().peer_address); + EXPECT_EQ(addr_with_new_port, queue.back().peer_address); +} + +TEST_F(QuicBufferedPacketStoreTest, + EnqueueAndDeliverMultiplePacketsOnMultipleConnections) { + size_t num_connections = 10; + for (uint64_t conn_id = 1; conn_id <= num_connections; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + } + + // Deliver packets in reversed order. + for (uint64_t conn_id = num_connections; conn_id > 0; --conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + std::list queue = + store_.DeliverPackets(connection_id).buffered_packets; + ASSERT_EQ(2u, queue.size()); + } +} + +TEST_F(QuicBufferedPacketStoreTest, + FailToBufferTooManyPacketsOnExistingConnection) { + // Tests that for one connection, only limited number of packets can be + // buffered. + size_t num_packets = kDefaultMaxUndecryptablePackets + 1; + QuicConnectionId connection_id = TestConnectionId(1); + // Arrived CHLO packet shouldn't affect how many non-CHLO pacekts store can + // keep. + EXPECT_EQ( + QuicBufferedPacketStore::SUCCESS, + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo)); + for (size_t i = 1; i <= num_packets; ++i) { + // Only first |kDefaultMaxUndecryptablePackets packets| will be buffered. + EnqueuePacketResult result = + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + if (i <= kDefaultMaxUndecryptablePackets) { + EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); + } else { + EXPECT_EQ(EnqueuePacketResult::TOO_MANY_PACKETS, result); + } + } + + // Only first |kDefaultMaxUndecryptablePackets| non-CHLO packets and CHLO are + // buffered. + EXPECT_EQ(kDefaultMaxUndecryptablePackets + 1, + store_.DeliverPackets(connection_id).buffered_packets.size()); +} + +TEST_F(QuicBufferedPacketStoreTest, ReachNonChloConnectionUpperLimit) { + // Tests that store can only keep early arrived packets for limited number of + // connections. + const size_t kNumConnections = kMaxConnectionsWithoutCHLO + 1; + for (uint64_t conn_id = 1; conn_id <= kNumConnections; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + EnqueuePacketResult result = + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + if (conn_id <= kMaxConnectionsWithoutCHLO) { + EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); + } else { + EXPECT_EQ(EnqueuePacketResult::TOO_MANY_CONNECTIONS, result); + } + } + // Store only keeps early arrived packets upto |kNumConnections| connections. + for (uint64_t conn_id = 1; conn_id <= kNumConnections; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + std::list queue = + store_.DeliverPackets(connection_id).buffered_packets; + if (conn_id <= kMaxConnectionsWithoutCHLO) { + EXPECT_EQ(1u, queue.size()); + } else { + EXPECT_EQ(0u, queue.size()); + } + } +} + +TEST_F(QuicBufferedPacketStoreTest, + FullStoreFailToBufferDataPacketOnNewConnection) { + // Send enough CHLOs so that store gets full before number of connections + // without CHLO reaches its upper limit. + size_t num_chlos = + kDefaultMaxConnectionsInStore - kMaxConnectionsWithoutCHLO + 1; + for (uint64_t conn_id = 1; conn_id <= num_chlos; ++conn_id) { + EXPECT_EQ(EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket(TestConnectionId(conn_id), false, packet_, + self_address_, peer_address_, valid_version_, + kDefaultParsedChlo)); + } + + // Send data packets on another |kMaxConnectionsWithoutCHLO| connections. + // Store should only be able to buffer till it's full. + for (uint64_t conn_id = num_chlos + 1; + conn_id <= (kDefaultMaxConnectionsInStore + 1); ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + EnqueuePacketResult result = + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo); + if (conn_id <= kDefaultMaxConnectionsInStore) { + EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); + } else { + EXPECT_EQ(EnqueuePacketResult::TOO_MANY_CONNECTIONS, result); + } + } +} + +TEST_F(QuicBufferedPacketStoreTest, EnqueueChloOnTooManyDifferentConnections) { + // Buffer data packets on different connections upto limit. + for (uint64_t conn_id = 1; conn_id <= kMaxConnectionsWithoutCHLO; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo)); + } + + // Buffer CHLOs on other connections till store is full. + for (size_t i = kMaxConnectionsWithoutCHLO + 1; + i <= kDefaultMaxConnectionsInStore + 1; ++i) { + QuicConnectionId connection_id = TestConnectionId(i); + EnqueuePacketResult rs = + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo); + if (i <= kDefaultMaxConnectionsInStore) { + EXPECT_EQ(EnqueuePacketResult::SUCCESS, rs); + EXPECT_TRUE(store_.HasChloForConnection(connection_id)); + } else { + // Last CHLO can't be buffered because store is full. + EXPECT_EQ(EnqueuePacketResult::TOO_MANY_CONNECTIONS, rs); + EXPECT_FALSE(store_.HasChloForConnection(connection_id)); + } + } + + // But buffering a CHLO belonging to a connection already has data packet + // buffered in the store should success. This is the connection should be + // delivered at last. + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket( + /*connection_id=*/TestConnectionId(1), false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo)); + EXPECT_TRUE(store_.HasChloForConnection( + /*connection_id=*/TestConnectionId(1))); + + QuicConnectionId delivered_conn_id; + for (size_t i = 0; + i < kDefaultMaxConnectionsInStore - kMaxConnectionsWithoutCHLO + 1; + ++i) { + if (i < kDefaultMaxConnectionsInStore - kMaxConnectionsWithoutCHLO) { + // Only CHLO is buffered. + EXPECT_EQ(1u, store_.DeliverPacketsForNextConnection(&delivered_conn_id) + .buffered_packets.size()); + EXPECT_EQ(TestConnectionId(i + kMaxConnectionsWithoutCHLO + 1), + delivered_conn_id); + } else { + EXPECT_EQ(2u, store_.DeliverPacketsForNextConnection(&delivered_conn_id) + .buffered_packets.size()); + EXPECT_EQ(TestConnectionId(1u), delivered_conn_id); + } + } + EXPECT_FALSE(store_.HasChlosBuffered()); +} + +// Tests that store expires long-staying connections appropriately for +// connections both with and without CHLOs. +TEST_F(QuicBufferedPacketStoreTest, PacketQueueExpiredBeforeDelivery) { + QuicConnectionId connection_id = TestConnectionId(1); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo)); + QuicConnectionId connection_id2 = TestConnectionId(2); + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket(connection_id2, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo)); + + // CHLO on connection 3 arrives 1ms later. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + QuicConnectionId connection_id3 = TestConnectionId(3); + // Use different client address to differetiate packets from different + // connections. + QuicSocketAddress another_client_address(QuicIpAddress::Any4(), 255); + store_.EnqueuePacket(connection_id3, false, packet_, self_address_, + another_client_address, valid_version_, + kDefaultParsedChlo); + + // Advance clock to the time when connection 1 and 2 expires. + clock_.AdvanceTime( + QuicBufferedPacketStorePeer::expiration_alarm(&store_)->deadline() - + clock_.ApproximateNow()); + ASSERT_GE(clock_.ApproximateNow(), + QuicBufferedPacketStorePeer::expiration_alarm(&store_)->deadline()); + // Fire alarm to remove long-staying connection 1 and 2 packets. + alarm_factory_.FireAlarm( + QuicBufferedPacketStorePeer::expiration_alarm(&store_)); + EXPECT_EQ(1u, visitor_.last_expired_packet_queue_.buffered_packets.size()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id2)); + + // Try to deliver packets, but packet queue has been removed so no + // packets can be returned. + ASSERT_EQ(0u, store_.DeliverPackets(connection_id).buffered_packets.size()); + ASSERT_EQ(0u, store_.DeliverPackets(connection_id2).buffered_packets.size()); + QuicConnectionId delivered_conn_id; + auto queue = store_.DeliverPacketsForNextConnection(&delivered_conn_id) + .buffered_packets; + // Connection 3 is the next to be delivered as connection 1 already expired. + EXPECT_EQ(connection_id3, delivered_conn_id); + ASSERT_EQ(1u, queue.size()); + // Packets in connection 3 should use another peer address. + EXPECT_EQ(another_client_address, queue.front().peer_address); + + // Test the alarm is reset by enqueueing 2 packets for 4th connection and wait + // for them to expire. + QuicConnectionId connection_id4 = TestConnectionId(4); + store_.EnqueuePacket(connection_id4, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id4, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + clock_.AdvanceTime( + QuicBufferedPacketStorePeer::expiration_alarm(&store_)->deadline() - + clock_.ApproximateNow()); + alarm_factory_.FireAlarm( + QuicBufferedPacketStorePeer::expiration_alarm(&store_)); + // |last_expired_packet_queue_| should be updated. + EXPECT_EQ(2u, visitor_.last_expired_packet_queue_.buffered_packets.size()); +} + +TEST_F(QuicBufferedPacketStoreTest, SimpleDiscardPackets) { + QuicConnectionId connection_id = TestConnectionId(1); + + // Enqueue some packets + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); + + // Dicard the packets + store_.DiscardPackets(connection_id); + + // No packets on connection 1 should remain in the store + EXPECT_TRUE(store_.DeliverPackets(connection_id).buffered_packets.empty()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); + + // Check idempotency + store_.DiscardPackets(connection_id); + EXPECT_TRUE(store_.DeliverPackets(connection_id).buffered_packets.empty()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); +} + +TEST_F(QuicBufferedPacketStoreTest, DiscardWithCHLOs) { + QuicConnectionId connection_id = TestConnectionId(1); + + // Enqueue some packets, which include a CHLO + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); + EXPECT_TRUE(store_.HasChlosBuffered()); + + // Dicard the packets + store_.DiscardPackets(connection_id); + + // No packets on connection 1 should remain in the store + EXPECT_TRUE(store_.DeliverPackets(connection_id).buffered_packets.empty()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); + + // Check idempotency + store_.DiscardPackets(connection_id); + EXPECT_TRUE(store_.DeliverPackets(connection_id).buffered_packets.empty()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); +} + +TEST_F(QuicBufferedPacketStoreTest, MultipleDiscardPackets) { + QuicConnectionId connection_id_1 = TestConnectionId(1); + QuicConnectionId connection_id_2 = TestConnectionId(2); + + // Enqueue some packets for two connection IDs + store_.EnqueuePacket(connection_id_1, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id_1, false, packet_, self_address_, + peer_address_, invalid_version_, kNoParsedChlo); + + ParsedClientHello parsed_chlo; + parsed_chlo.alpns.push_back("h3"); + parsed_chlo.sni = TestHostname(); + store_.EnqueuePacket(connection_id_2, false, packet_, self_address_, + peer_address_, valid_version_, parsed_chlo); + EXPECT_TRUE(store_.HasBufferedPackets(connection_id_1)); + EXPECT_TRUE(store_.HasBufferedPackets(connection_id_2)); + EXPECT_TRUE(store_.HasChlosBuffered()); + + // Discard the packets for connection 1 + store_.DiscardPackets(connection_id_1); + + // No packets on connection 1 should remain in the store + EXPECT_TRUE(store_.DeliverPackets(connection_id_1).buffered_packets.empty()); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id_1)); + EXPECT_TRUE(store_.HasChlosBuffered()); + + // Packets on connection 2 should remain + EXPECT_TRUE(store_.HasBufferedPackets(connection_id_2)); + auto packets = store_.DeliverPackets(connection_id_2); + EXPECT_EQ(1u, packets.buffered_packets.size()); + ASSERT_EQ(1u, packets.parsed_chlo->alpns.size()); + EXPECT_EQ("h3", packets.parsed_chlo->alpns[0]); + EXPECT_EQ(TestHostname(), packets.parsed_chlo->sni); + // Since connection_id_2's chlo arrives, verify version is set. + EXPECT_EQ(valid_version_, packets.version); + EXPECT_TRUE(store_.HasChlosBuffered()); + + // Discard the packets for connection 2 + store_.DiscardPackets(connection_id_2); + EXPECT_FALSE(store_.HasChlosBuffered()); +} + +TEST_F(QuicBufferedPacketStoreTest, DiscardPacketsEmpty) { + // Check that DiscardPackets on an unknown connection ID is safe and does + // nothing. + QuicConnectionId connection_id = TestConnectionId(11235); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); + store_.DiscardPackets(connection_id); + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.HasChlosBuffered()); +} + +TEST_F(QuicBufferedPacketStoreTest, IngestPacketForTlsChloExtraction) { + QuicConnectionId connection_id = TestConnectionId(1); + std::vector alpns; + std::string sni; + bool resumption_attempted = false; + bool early_data_attempted = false; + QuicConfig config; + absl::optional tls_alert; + + EXPECT_FALSE(store_.HasBufferedPackets(connection_id)); + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kNoParsedChlo); + EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); + + // The packet in 'packet_' is not a TLS CHLO packet. + EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction( + connection_id, valid_version_, packet_, &alpns, &sni, + &resumption_attempted, &early_data_attempted, &tls_alert)); + + store_.DiscardPackets(connection_id); + + // Force the TLS CHLO to span multiple packets. + constexpr auto kCustomParameterId = + static_cast(0xff33); + std::string kCustomParameterValue(2000, '-'); + config.custom_transport_parameters_to_send()[kCustomParameterId] = + kCustomParameterValue; + auto packets = GetFirstFlightOfPackets(valid_version_, config); + ASSERT_EQ(packets.size(), 2u); + + store_.EnqueuePacket(connection_id, false, *packets[0], self_address_, + peer_address_, valid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, *packets[1], self_address_, + peer_address_, valid_version_, kNoParsedChlo); + + EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); + EXPECT_FALSE(store_.IngestPacketForTlsChloExtraction( + connection_id, valid_version_, *packets[0], &alpns, &sni, + &resumption_attempted, &early_data_attempted, &tls_alert)); + EXPECT_TRUE(store_.IngestPacketForTlsChloExtraction( + connection_id, valid_version_, *packets[1], &alpns, &sni, + &resumption_attempted, &early_data_attempted, &tls_alert)); + + EXPECT_THAT(alpns, ElementsAre(AlpnForVersion(valid_version_))); + EXPECT_EQ(sni, TestHostname()); + + EXPECT_FALSE(resumption_attempted); + EXPECT_FALSE(early_data_attempted); +} + +TEST_F(QuicBufferedPacketStoreTest, DeliverInitialPacketsFirst) { + QuicConfig config; + QuicConnectionId connection_id = TestConnectionId(1); + + // Force the TLS CHLO to span multiple packets. + constexpr auto kCustomParameterId = + static_cast(0xff33); + std::string custom_parameter_value(2000, '-'); + config.custom_transport_parameters_to_send()[kCustomParameterId] = + custom_parameter_value; + auto initial_packets = GetFirstFlightOfPackets(valid_version_, config); + ASSERT_THAT(initial_packets, SizeIs(2)); + + // Verify that the packets generated are INITIAL packets. + EXPECT_THAT( + initial_packets, + Each(Truly([](const std::unique_ptr& packet) { + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + PacketHeaderFormat unused_format; + bool unused_version_flag; + bool unused_use_length_prefix; + QuicVersionLabel unused_version_label; + ParsedQuicVersion unused_parsed_version = UnsupportedQuicVersion(); + QuicConnectionId unused_destination_connection_id; + QuicConnectionId unused_source_connection_id; + absl::optional unused_retry_token; + std::string unused_detailed_error; + QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + *packet, kQuicDefaultConnectionIdLength, &unused_format, + &long_packet_type, &unused_version_flag, &unused_use_length_prefix, + &unused_version_label, &unused_parsed_version, + &unused_destination_connection_id, &unused_source_connection_id, + &unused_retry_token, &unused_detailed_error); + return error_code == QUIC_NO_ERROR && long_packet_type == INITIAL; + }))); + + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + PacketHeaderFormat unused_format; + bool unused_version_flag; + bool unused_use_length_prefix; + QuicVersionLabel unused_version_label; + ParsedQuicVersion unused_parsed_version = UnsupportedQuicVersion(); + QuicConnectionId unused_destination_connection_id; + QuicConnectionId unused_source_connection_id; + absl::optional unused_retry_token; + std::string unused_detailed_error; + QuicErrorCode error_code = QUIC_NO_ERROR; + + // Verify that packet_ is not an INITIAL packet. + error_code = QuicFramer::ParsePublicHeaderDispatcher( + packet_, kQuicDefaultConnectionIdLength, &unused_format, + &long_packet_type, &unused_version_flag, &unused_use_length_prefix, + &unused_version_label, &unused_parsed_version, + &unused_destination_connection_id, &unused_source_connection_id, + &unused_retry_token, &unused_detailed_error); + EXPECT_THAT(error_code, IsQuicNoError()); + EXPECT_NE(long_packet_type, INITIAL); + + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, *initial_packets[0], self_address_, + peer_address_, valid_version_, kNoParsedChlo); + store_.EnqueuePacket(connection_id, false, *initial_packets[1], self_address_, + peer_address_, valid_version_, kNoParsedChlo); + + BufferedPacketList delivered_packets = store_.DeliverPackets(connection_id); + EXPECT_THAT(delivered_packets.buffered_packets, SizeIs(3)); + + QuicLongHeaderType previous_packet_type = INITIAL; + for (const auto& packet : delivered_packets.buffered_packets) { + error_code = QuicFramer::ParsePublicHeaderDispatcher( + *packet.packet, kQuicDefaultConnectionIdLength, &unused_format, + &long_packet_type, &unused_version_flag, &unused_use_length_prefix, + &unused_version_label, &unused_parsed_version, + &unused_destination_connection_id, &unused_source_connection_id, + &unused_retry_token, &unused_detailed_error); + EXPECT_THAT(error_code, IsQuicNoError()); + + // INITIAL packets should not follow a non-INITIAL packet. + EXPECT_THAT(long_packet_type, + Conditional(previous_packet_type == INITIAL, + A(), Ne(INITIAL))); + previous_packet_type = long_packet_type; + } +} +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_chaos_protector.cc b/quiche/quic/core/quic_chaos_protector.cc new file mode 100644 index 000000000000..8daaeb7db00c --- /dev/null +++ b/quiche/quic/core/quic_chaos_protector.cc @@ -0,0 +1,225 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_chaos_protector.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/frames/quic_crypto_frame.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/frames/quic_padding_frame.h" +#include "quiche/quic/core/frames/quic_ping_frame.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +QuicChaosProtector::QuicChaosProtector(const QuicCryptoFrame& crypto_frame, + int num_padding_bytes, + size_t packet_size, QuicFramer* framer, + QuicRandom* random) + : packet_size_(packet_size), + crypto_data_length_(crypto_frame.data_length), + crypto_buffer_offset_(crypto_frame.offset), + level_(crypto_frame.level), + remaining_padding_bytes_(num_padding_bytes), + framer_(framer), + random_(random) { + QUICHE_DCHECK_NE(framer_, nullptr); + QUICHE_DCHECK_NE(framer_->data_producer(), nullptr); + QUICHE_DCHECK_NE(random_, nullptr); +} + +QuicChaosProtector::~QuicChaosProtector() { DeleteFrames(&frames_); } + +absl::optional QuicChaosProtector::BuildDataPacket( + const QuicPacketHeader& header, char* buffer) { + if (!CopyCryptoDataToLocalBuffer()) { + return absl::nullopt; + } + SplitCryptoFrame(); + AddPingFrames(); + SpreadPadding(); + ReorderFrames(); + return BuildPacket(header, buffer); +} + +WriteStreamDataResult QuicChaosProtector::WriteStreamData( + QuicStreamId id, QuicStreamOffset offset, QuicByteCount data_length, + QuicDataWriter* /*writer*/) { + QUIC_BUG(chaos stream) << "This should never be called; id " << id + << " offset " << offset << " data_length " + << data_length; + return STREAM_MISSING; +} + +bool QuicChaosProtector::WriteCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + if (level != level_) { + QUIC_BUG(chaos bad level) << "Unexpected " << level << " != " << level_; + return false; + } + // This is `offset + data_length > buffer_offset_ + buffer_length_` + // but with integer overflow protection. + if (offset < crypto_buffer_offset_ || data_length > crypto_data_length_ || + offset - crypto_buffer_offset_ > crypto_data_length_ - data_length) { + QUIC_BUG(chaos bad lengths) + << "Unexpected buffer_offset_ " << crypto_buffer_offset_ << " offset " + << offset << " buffer_length_ " << crypto_data_length_ + << " data_length " << data_length; + return false; + } + writer->WriteBytes(&crypto_data_buffer_[offset - crypto_buffer_offset_], + data_length); + return true; +} + +bool QuicChaosProtector::CopyCryptoDataToLocalBuffer() { + crypto_frame_buffer_ = std::make_unique(packet_size_); + frames_.push_back(QuicFrame( + new QuicCryptoFrame(level_, crypto_buffer_offset_, crypto_data_length_))); + // We use |framer_| to serialize the CRYPTO frame in order to extract its + // data from the crypto data producer. This ensures that we reuse the + // usual serialization code path, but has the downside that we then need to + // parse the offset and length in order to skip over those fields. + QuicDataWriter writer(packet_size_, crypto_frame_buffer_.get()); + if (!framer_->AppendCryptoFrame(*frames_.front().crypto_frame, &writer)) { + QUIC_BUG(chaos write crypto data); + return false; + } + QuicDataReader reader(crypto_frame_buffer_.get(), writer.length()); + uint64_t parsed_offset, parsed_length; + if (!reader.ReadVarInt62(&parsed_offset) || + !reader.ReadVarInt62(&parsed_length)) { + QUIC_BUG(chaos parse crypto frame); + return false; + } + + absl::string_view crypto_data = reader.ReadRemainingPayload(); + crypto_data_buffer_ = crypto_data.data(); + + QUICHE_DCHECK_EQ(parsed_offset, crypto_buffer_offset_); + QUICHE_DCHECK_EQ(parsed_length, crypto_data_length_); + QUICHE_DCHECK_EQ(parsed_length, crypto_data.length()); + + return true; +} + +void QuicChaosProtector::SplitCryptoFrame() { + const int max_overhead_of_adding_a_crypto_frame = + static_cast(QuicFramer::GetMinCryptoFrameSize( + crypto_buffer_offset_ + crypto_data_length_, crypto_data_length_)); + // Pick a random number of CRYPTO frames to add. + constexpr uint64_t kMaxAddedCryptoFrames = 10; + const uint64_t num_added_crypto_frames = + random_->InsecureRandUint64() % (kMaxAddedCryptoFrames + 1); + for (uint64_t i = 0; i < num_added_crypto_frames; i++) { + if (remaining_padding_bytes_ < max_overhead_of_adding_a_crypto_frame) { + break; + } + // Pick a random frame and split it by shrinking the picked frame and + // moving the second half of its data to a new frame that is then appended + // to |frames|. + size_t frame_to_split_index = + random_->InsecureRandUint64() % frames_.size(); + QuicCryptoFrame* frame_to_split = + frames_[frame_to_split_index].crypto_frame; + if (frame_to_split->data_length <= 1) { + continue; + } + const int frame_to_split_old_overhead = + static_cast(QuicFramer::GetMinCryptoFrameSize( + frame_to_split->offset, frame_to_split->data_length)); + const QuicPacketLength frame_to_split_new_data_length = + 1 + (random_->InsecureRandUint64() % (frame_to_split->data_length - 1)); + const QuicPacketLength new_frame_data_length = + frame_to_split->data_length - frame_to_split_new_data_length; + const QuicStreamOffset new_frame_offset = + frame_to_split->offset + frame_to_split_new_data_length; + frame_to_split->data_length -= new_frame_data_length; + frames_.push_back(QuicFrame( + new QuicCryptoFrame(level_, new_frame_offset, new_frame_data_length))); + const int frame_to_split_new_overhead = + static_cast(QuicFramer::GetMinCryptoFrameSize( + frame_to_split->offset, frame_to_split->data_length)); + const int new_frame_overhead = + static_cast(QuicFramer::GetMinCryptoFrameSize( + new_frame_offset, new_frame_data_length)); + QUICHE_DCHECK_LE(frame_to_split_new_overhead, frame_to_split_old_overhead); + // Readjust padding based on increased overhead. + remaining_padding_bytes_ -= new_frame_overhead; + remaining_padding_bytes_ -= frame_to_split_new_overhead; + remaining_padding_bytes_ += frame_to_split_old_overhead; + } +} + +void QuicChaosProtector::AddPingFrames() { + if (remaining_padding_bytes_ == 0) { + return; + } + constexpr uint64_t kMaxAddedPingFrames = 10; + const uint64_t num_ping_frames = + random_->InsecureRandUint64() % + std::min(kMaxAddedPingFrames, remaining_padding_bytes_); + for (uint64_t i = 0; i < num_ping_frames; i++) { + frames_.push_back(QuicFrame(QuicPingFrame())); + } + remaining_padding_bytes_ -= static_cast(num_ping_frames); +} + +void QuicChaosProtector::ReorderFrames() { + // Walk the array backwards and swap each frame with a random earlier one. + for (size_t i = frames_.size() - 1; i > 0; i--) { + std::swap(frames_[i], frames_[random_->InsecureRandUint64() % (i + 1)]); + } +} + +void QuicChaosProtector::SpreadPadding() { + for (auto it = frames_.begin(); it != frames_.end(); ++it) { + const int padding_bytes_in_this_frame = + random_->InsecureRandUint64() % (remaining_padding_bytes_ + 1); + if (padding_bytes_in_this_frame <= 0) { + continue; + } + it = frames_.insert( + it, QuicFrame(QuicPaddingFrame(padding_bytes_in_this_frame))); + ++it; // Skip over the padding frame we just added. + remaining_padding_bytes_ -= padding_bytes_in_this_frame; + } + if (remaining_padding_bytes_ > 0) { + frames_.push_back(QuicFrame(QuicPaddingFrame(remaining_padding_bytes_))); + } +} + +absl::optional QuicChaosProtector::BuildPacket( + const QuicPacketHeader& header, char* buffer) { + QuicStreamFrameDataProducer* original_data_producer = + framer_->data_producer(); + framer_->set_data_producer(this); + + size_t length = + framer_->BuildDataPacket(header, frames_, buffer, packet_size_, level_); + + framer_->set_data_producer(original_data_producer); + if (length == 0) { + return absl::nullopt; + } + return length; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_chaos_protector.h b/quiche/quic/core/quic_chaos_protector.h new file mode 100644 index 000000000000..6bcb3352217f --- /dev/null +++ b/quiche/quic/core/quic_chaos_protector.h @@ -0,0 +1,96 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CHAOS_PROTECTOR_H_ +#define QUICHE_QUIC_CORE_QUIC_CHAOS_PROTECTOR_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/frames/quic_crypto_frame.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +namespace test { +class QuicChaosProtectorTest; +} + +// QuicChaosProtector will take a crypto frame and an amount of padding and +// build a data packet that will parse to something equivalent. +class QUIC_EXPORT_PRIVATE QuicChaosProtector + : public QuicStreamFrameDataProducer { + public: + // |framer| and |random| must be valid for the lifetime of QuicChaosProtector. + explicit QuicChaosProtector(const QuicCryptoFrame& crypto_frame, + int num_padding_bytes, size_t packet_size, + QuicFramer* framer, QuicRandom* random); + + ~QuicChaosProtector() override; + + QuicChaosProtector(const QuicChaosProtector&) = delete; + QuicChaosProtector(QuicChaosProtector&&) = delete; + QuicChaosProtector& operator=(const QuicChaosProtector&) = delete; + QuicChaosProtector& operator=(QuicChaosProtector&&) = delete; + + // Attempts to build a data packet with chaos protection. If an error occurs, + // then absl::nullopt is returned. Otherwise returns the serialized length. + absl::optional BuildDataPacket(const QuicPacketHeader& header, + char* buffer); + + // From QuicStreamFrameDataProducer. + WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* /*writer*/) override; + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + + private: + friend class test::QuicChaosProtectorTest; + + // Allocate the crypto data buffer, create the CRYPTO frame and write the + // crypto data to our buffer. + bool CopyCryptoDataToLocalBuffer(); + + // Split the CRYPTO frame in |frames_| into one or more CRYPTO frames that + // collectively represent the same data. Adjusts padding to compensate. + void SplitCryptoFrame(); + + // Add a random number of PING frames to |frames_| and adjust padding. + void AddPingFrames(); + + // Randomly reorder |frames_|. + void ReorderFrames(); + + // Add PADDING frames randomly between all other frames. + void SpreadPadding(); + + // Serialize |frames_| using |framer_|. + absl::optional BuildPacket(const QuicPacketHeader& header, + char* buffer); + + size_t packet_size_; + std::unique_ptr crypto_frame_buffer_; + const char* crypto_data_buffer_ = nullptr; + QuicByteCount crypto_data_length_; + QuicStreamOffset crypto_buffer_offset_; + EncryptionLevel level_; + int remaining_padding_bytes_; + QuicFrames frames_; // Inner frames owned, will be deleted by destructor. + QuicFramer* framer_; // Unowned. + QuicRandom* random_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CHAOS_PROTECTOR_H_ diff --git a/quiche/quic/core/quic_chaos_protector_test.cc b/quiche/quic/core/quic_chaos_protector_test.cc new file mode 100644 index 000000000000..92d3af9a20f6 --- /dev/null +++ b/quiche/quic/core/quic_chaos_protector_test.cc @@ -0,0 +1,229 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_chaos_protector.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/frames/quic_crypto_frame.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_quic_framer.h" + +namespace quic { +namespace test { + +class QuicChaosProtectorTest : public QuicTestWithParam, + public QuicStreamFrameDataProducer { + public: + QuicChaosProtectorTest() + : version_(GetParam()), + framer_({version_}, QuicTime::Zero(), Perspective::IS_CLIENT, + kQuicDefaultConnectionIdLength), + validation_framer_({version_}), + random_(/*base=*/3), + level_(ENCRYPTION_INITIAL), + crypto_offset_(0), + crypto_data_length_(100), + crypto_frame_(level_, crypto_offset_, crypto_data_length_), + num_padding_bytes_(50), + packet_size_(1000), + packet_buffer_(std::make_unique(packet_size_)) { + ReCreateChaosProtector(); + } + + void ReCreateChaosProtector() { + chaos_protector_ = std::make_unique( + crypto_frame_, num_padding_bytes_, packet_size_, + SetupHeaderAndFramers(), &random_); + } + + // From QuicStreamFrameDataProducer. + WriteStreamDataResult WriteStreamData(QuicStreamId /*id*/, + QuicStreamOffset /*offset*/, + QuicByteCount /*data_length*/, + QuicDataWriter* /*writer*/) override { + ADD_FAILURE() << "This should never be called"; + return STREAM_MISSING; + } + + // From QuicStreamFrameDataProducer. + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override { + EXPECT_EQ(level, level); + EXPECT_EQ(offset, crypto_offset_); + EXPECT_EQ(data_length, crypto_data_length_); + for (QuicByteCount i = 0; i < data_length; i++) { + EXPECT_TRUE(writer->WriteUInt8(static_cast(i & 0xFF))); + } + return true; + } + + protected: + QuicFramer* SetupHeaderAndFramers() { + // Setup header. + header_.destination_connection_id = TestConnectionId(); + header_.destination_connection_id_included = CONNECTION_ID_PRESENT; + header_.source_connection_id = EmptyQuicConnectionId(); + header_.source_connection_id_included = CONNECTION_ID_PRESENT; + header_.reset_flag = false; + header_.version_flag = true; + header_.has_possible_stateless_reset_token = false; + header_.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header_.version = version_; + header_.packet_number = QuicPacketNumber(1); + header_.form = IETF_QUIC_LONG_HEADER_PACKET; + header_.long_packet_type = INITIAL; + header_.retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header_.length_length = quiche::kQuicheDefaultLongHeaderLengthLength; + // Setup validation framer. + validation_framer_.framer()->SetInitialObfuscators( + header_.destination_connection_id); + // Setup framer. + framer_.SetInitialObfuscators(header_.destination_connection_id); + framer_.set_data_producer(this); + return &framer_; + } + + void BuildEncryptAndParse() { + absl::optional length = + chaos_protector_->BuildDataPacket(header_, packet_buffer_.get()); + ASSERT_TRUE(length.has_value()); + ASSERT_GT(length.value(), 0u); + size_t encrypted_length = framer_.EncryptInPlace( + level_, header_.packet_number, + GetStartOfEncryptedData(framer_.transport_version(), header_), + length.value(), packet_size_, packet_buffer_.get()); + ASSERT_GT(encrypted_length, 0u); + ASSERT_TRUE(validation_framer_.ProcessPacket(QuicEncryptedPacket( + absl::string_view(packet_buffer_.get(), encrypted_length)))); + } + + void ResetOffset(QuicStreamOffset offset) { + crypto_offset_ = offset; + crypto_frame_.offset = offset; + ReCreateChaosProtector(); + } + + void ResetLength(QuicByteCount length) { + crypto_data_length_ = length; + crypto_frame_.data_length = length; + ReCreateChaosProtector(); + } + + ParsedQuicVersion version_; + QuicPacketHeader header_; + QuicFramer framer_; + SimpleQuicFramer validation_framer_; + MockRandom random_; + EncryptionLevel level_; + QuicStreamOffset crypto_offset_; + QuicByteCount crypto_data_length_; + QuicCryptoFrame crypto_frame_; + int num_padding_bytes_; + size_t packet_size_; + std::unique_ptr packet_buffer_; + std::unique_ptr chaos_protector_; +}; + +namespace { + +ParsedQuicVersionVector TestVersions() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.UsesCryptoFrames()) { + versions.push_back(version); + } + } + return versions; +} + +INSTANTIATE_TEST_SUITE_P(QuicChaosProtectorTests, QuicChaosProtectorTest, + ::testing::ValuesIn(TestVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicChaosProtectorTest, Main) { + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 4u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, 0u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, 1u); + ASSERT_EQ(validation_framer_.ping_frames().size(), 3u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 7u); + EXPECT_EQ(validation_framer_.padding_frames()[0].num_padding_bytes, 3); +} + +TEST_P(QuicChaosProtectorTest, DifferentRandom) { + random_.ResetBase(4); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 4u); + ASSERT_EQ(validation_framer_.ping_frames().size(), 4u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 8u); +} + +TEST_P(QuicChaosProtectorTest, RandomnessZero) { + random_.ResetBase(0); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 1u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, + crypto_data_length_); + ASSERT_EQ(validation_framer_.ping_frames().size(), 0u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 1u); +} + +TEST_P(QuicChaosProtectorTest, Offset) { + ResetOffset(123); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 4u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, 1u); + ASSERT_EQ(validation_framer_.ping_frames().size(), 3u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 7u); + EXPECT_EQ(validation_framer_.padding_frames()[0].num_padding_bytes, 3); +} + +TEST_P(QuicChaosProtectorTest, OffsetAndRandomnessZero) { + ResetOffset(123); + random_.ResetBase(0); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 1u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, + crypto_data_length_); + ASSERT_EQ(validation_framer_.ping_frames().size(), 0u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 1u); +} + +TEST_P(QuicChaosProtectorTest, ZeroRemainingBytesAfterSplit) { + QuicPacketLength new_length = 63; + num_padding_bytes_ = QuicFramer::GetMinCryptoFrameSize( + crypto_frame_.offset + new_length, new_length); + ResetLength(new_length); + BuildEncryptAndParse(); + + ASSERT_EQ(validation_framer_.crypto_frames().size(), 2u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, 4); + EXPECT_EQ(validation_framer_.crypto_frames()[1]->offset, crypto_offset_ + 4); + EXPECT_EQ(validation_framer_.crypto_frames()[1]->data_length, + crypto_data_length_ - 4); + ASSERT_EQ(validation_framer_.ping_frames().size(), 0u); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_clock.h b/quiche/quic/core/quic_clock.h new file mode 100644 index 000000000000..5e5fd5baa5ee --- /dev/null +++ b/quiche/quic/core/quic_clock.h @@ -0,0 +1,47 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CLOCK_H_ +#define QUICHE_QUIC_CORE_QUIC_CLOCK_H_ + +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +/* API_DESCRIPTION + QuicClock is used by QUIC core to get current time. Its instance is created by + applications and passed into QuicDispatcher and QuicConnectionHelperInterface. + API-DESCRIPTION */ + +namespace quic { + +// Interface for retrieving the current time. +class QUIC_EXPORT_PRIVATE QuicClock { + public: + QuicClock() = default; + virtual ~QuicClock() = default; + + QuicClock(const QuicClock&) = delete; + QuicClock& operator=(const QuicClock&) = delete; + + // Returns the approximate current time as a QuicTime object. + virtual QuicTime ApproximateNow() const = 0; + + // Returns the current time as a QuicTime object. + // Note: this use significant resources please use only if needed. + virtual QuicTime Now() const = 0; + + // WallNow returns the current wall-time - a time that is consistent across + // different clocks. + virtual QuicWallTime WallNow() const = 0; + + protected: + // Creates a new QuicTime using |time_us| as the internal value. + QuicTime CreateTimeFromMicroseconds(uint64_t time_us) const { + return QuicTime(time_us); + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CLOCK_H_ diff --git a/quiche/quic/core/quic_coalesced_packet.cc b/quiche/quic/core/quic_coalesced_packet.cc new file mode 100644 index 000000000000..802fac49b463 --- /dev/null +++ b/quiche/quic/core/quic_coalesced_packet.cc @@ -0,0 +1,194 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_coalesced_packet.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +QuicCoalescedPacket::QuicCoalescedPacket() + : length_(0), max_packet_length_(0) {} + +QuicCoalescedPacket::~QuicCoalescedPacket() { Clear(); } + +bool QuicCoalescedPacket::MaybeCoalescePacket( + const SerializedPacket& packet, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + quiche::QuicheBufferAllocator* allocator, + QuicPacketLength current_max_packet_length) { + if (packet.encrypted_length == 0) { + QUIC_BUG(quic_bug_10611_1) << "Trying to coalesce an empty packet"; + return true; + } + if (length_ == 0) { +#ifndef NDEBUG + for (const auto& buffer : encrypted_buffers_) { + QUICHE_DCHECK(buffer.empty()); + } +#endif + QUICHE_DCHECK(initial_packet_ == nullptr); + // This is the first packet, set max_packet_length and self/peer + // addresses. + max_packet_length_ = current_max_packet_length; + self_address_ = self_address; + peer_address_ = peer_address; + } else { + if (self_address_ != self_address || peer_address_ != peer_address) { + // Do not coalesce packet with different self/peer addresses. + QUIC_DLOG(INFO) + << "Cannot coalesce packet because self/peer address changed"; + return false; + } + if (max_packet_length_ != current_max_packet_length) { + QUIC_BUG(quic_bug_10611_2) + << "Max packet length changes in the middle of the write path"; + return false; + } + if (ContainsPacketOfEncryptionLevel(packet.encryption_level)) { + // Do not coalesce packets of the same encryption level. + return false; + } + } + + if (length_ + packet.encrypted_length > max_packet_length_) { + // Packet does not fit. + return false; + } + QUIC_DVLOG(1) << "Successfully coalesced packet: encryption_level: " + << packet.encryption_level + << ", encrypted_length: " << packet.encrypted_length + << ", current length: " << length_ + << ", max_packet_length: " << max_packet_length_; + if (length_ > 0) { + QUIC_CODE_COUNT(QUIC_SUCCESSFULLY_COALESCED_MULTIPLE_PACKETS); + } + length_ += packet.encrypted_length; + transmission_types_[packet.encryption_level] = packet.transmission_type; + if (packet.encryption_level == ENCRYPTION_INITIAL) { + // Save a copy of ENCRYPTION_INITIAL packet (excluding encrypted buffer, as + // the packet will be re-serialized later). + initial_packet_ = absl::WrapUnique( + CopySerializedPacket(packet, allocator, /*copy_buffer=*/false)); + return true; + } + // Copy encrypted buffer of packets with other encryption levels. + encrypted_buffers_[packet.encryption_level] = + std::string(packet.encrypted_buffer, packet.encrypted_length); + return true; +} + +void QuicCoalescedPacket::Clear() { + self_address_ = QuicSocketAddress(); + peer_address_ = QuicSocketAddress(); + length_ = 0; + max_packet_length_ = 0; + for (auto& packet : encrypted_buffers_) { + packet.clear(); + } + for (size_t i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) { + transmission_types_[i] = NOT_RETRANSMISSION; + } + initial_packet_ = nullptr; +} + +void QuicCoalescedPacket::NeuterInitialPacket() { + if (initial_packet_ == nullptr) { + return; + } + if (length_ < initial_packet_->encrypted_length) { + QUIC_BUG(quic_bug_10611_3) + << "length_: " << length_ << ", is less than initial packet length: " + << initial_packet_->encrypted_length; + Clear(); + return; + } + length_ -= initial_packet_->encrypted_length; + if (length_ == 0) { + Clear(); + return; + } + transmission_types_[ENCRYPTION_INITIAL] = NOT_RETRANSMISSION; + initial_packet_ = nullptr; +} + +bool QuicCoalescedPacket::CopyEncryptedBuffers(char* buffer, size_t buffer_len, + size_t* length_copied) const { + *length_copied = 0; + for (const auto& packet : encrypted_buffers_) { + if (packet.empty()) { + continue; + } + if (packet.length() > buffer_len) { + return false; + } + memcpy(buffer, packet.data(), packet.length()); + buffer += packet.length(); + buffer_len -= packet.length(); + *length_copied += packet.length(); + } + return true; +} + +bool QuicCoalescedPacket::ContainsPacketOfEncryptionLevel( + EncryptionLevel level) const { + return !encrypted_buffers_[level].empty() || + (level == ENCRYPTION_INITIAL && initial_packet_ != nullptr); +} + +TransmissionType QuicCoalescedPacket::TransmissionTypeOfPacket( + EncryptionLevel level) const { + if (!ContainsPacketOfEncryptionLevel(level)) { + QUIC_BUG(quic_bug_10611_4) + << "Coalesced packet does not contain packet of encryption level: " + << EncryptionLevelToString(level); + return NOT_RETRANSMISSION; + } + return transmission_types_[level]; +} + +size_t QuicCoalescedPacket::NumberOfPackets() const { + size_t num_of_packets = 0; + for (int8_t i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) { + if (ContainsPacketOfEncryptionLevel(static_cast(i))) { + ++num_of_packets; + } + } + return num_of_packets; +} + +std::string QuicCoalescedPacket::ToString(size_t serialized_length) const { + // Total length and padding size. + std::string info = absl::StrCat( + "total_length: ", serialized_length, + " padding_size: ", serialized_length - length_, " packets: {"); + // Packets' encryption levels. + bool first_packet = true; + for (int8_t i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) { + if (ContainsPacketOfEncryptionLevel(static_cast(i))) { + absl::StrAppend(&info, first_packet ? "" : ", ", + EncryptionLevelToString(static_cast(i))); + first_packet = false; + } + } + absl::StrAppend(&info, "}"); + return info; +} + +std::vector QuicCoalescedPacket::packet_lengths() const { + std::vector lengths; + for (const auto& packet : encrypted_buffers_) { + if (lengths.empty()) { + lengths.push_back( + initial_packet_ == nullptr ? 0 : initial_packet_->encrypted_length); + } else { + lengths.push_back(packet.length()); + } + } + return lengths; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_coalesced_packet.h b/quiche/quic/core/quic_coalesced_packet.h new file mode 100644 index 000000000000..21a6e1ed20e2 --- /dev/null +++ b/quiche/quic/core/quic_coalesced_packet.h @@ -0,0 +1,96 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_COALESCED_PACKET_H_ +#define QUICHE_QUIC_CORE_QUIC_COALESCED_PACKET_H_ + +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +namespace test { +class QuicCoalescedPacketPeer; +} + +// QuicCoalescedPacket is used to buffer multiple packets which can be coalesced +// into the same UDP datagram. +class QUIC_EXPORT_PRIVATE QuicCoalescedPacket { + public: + QuicCoalescedPacket(); + ~QuicCoalescedPacket(); + + // Returns true if |packet| is successfully coalesced with existing packets. + // Returns false otherwise. + bool MaybeCoalescePacket(const SerializedPacket& packet, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + quiche::QuicheBufferAllocator* allocator, + QuicPacketLength current_max_packet_length); + + // Clears this coalesced packet. + void Clear(); + + // Clears all state associated with initial_packet_. + void NeuterInitialPacket(); + + // Copies encrypted_buffers_ to |buffer| and sets |length_copied| to the + // copied amount. Returns false if copy fails (i.e., |buffer_len| is not + // enough). + bool CopyEncryptedBuffers(char* buffer, size_t buffer_len, + size_t* length_copied) const; + + std::string ToString(size_t serialized_length) const; + + // Returns true if this coalesced packet contains packet of |level|. + bool ContainsPacketOfEncryptionLevel(EncryptionLevel level) const; + + // Returns transmission type of packet of |level|. This should only be called + // when this coalesced packet contains packet of |level|. + TransmissionType TransmissionTypeOfPacket(EncryptionLevel level) const; + + // Returns number of packets contained in this coalesced packet. + size_t NumberOfPackets() const; + + const SerializedPacket* initial_packet() const { + return initial_packet_.get(); + } + + const QuicSocketAddress& self_address() const { return self_address_; } + + const QuicSocketAddress& peer_address() const { return peer_address_; } + + QuicPacketLength length() const { return length_; } + + QuicPacketLength max_packet_length() const { return max_packet_length_; } + + std::vector packet_lengths() const; + + private: + friend class test::QuicCoalescedPacketPeer; + + // self/peer addresses are set when trying to coalesce the first packet. + // Packets with different self/peer addresses cannot be coalesced. + QuicSocketAddress self_address_; + QuicSocketAddress peer_address_; + // Length of this coalesced packet. + QuicPacketLength length_; + // Max packet length. Do not try to coalesce packet when max packet length + // changes (e.g., with MTU discovery). + QuicPacketLength max_packet_length_; + // Copies of packets' encrypted buffers according to different encryption + // levels. + std::string encrypted_buffers_[NUM_ENCRYPTION_LEVELS]; + // Recorded transmission type according to different encryption levels. + TransmissionType transmission_types_[NUM_ENCRYPTION_LEVELS]; + + // A copy of ENCRYPTION_INITIAL packet if this coalesced packet contains one. + // Null otherwise. Please note, the encrypted_buffer field is not copied. The + // frames are copied to allow it be re-serialized when this coalesced packet + // gets sent. + std::unique_ptr initial_packet_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_COALESCED_PACKET_H_ diff --git a/quiche/quic/core/quic_coalesced_packet_test.cc b/quiche/quic/core/quic_coalesced_packet_test.cc new file mode 100644 index 000000000000..eb69372844f0 --- /dev/null +++ b/quiche/quic/core/quic_coalesced_packet_test.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_coalesced_packet.h" + +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { +namespace { + +TEST(QuicCoalescedPacketTest, MaybeCoalescePacket) { + QuicCoalescedPacket coalesced; + EXPECT_EQ("total_length: 0 padding_size: 0 packets: {}", + coalesced.ToString(0)); + quiche::SimpleBufferAllocator allocator; + EXPECT_EQ(0u, coalesced.length()); + EXPECT_EQ(0u, coalesced.NumberOfPackets()); + char buffer[1000]; + QuicSocketAddress self_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress peer_address(QuicIpAddress::Loopback4(), 2); + SerializedPacket packet1(QuicPacketNumber(1), PACKET_4BYTE_PACKET_NUMBER, + buffer, 500, false, false); + packet1.transmission_type = PTO_RETRANSMISSION; + QuicAckFrame ack_frame(InitAckFrame(1)); + packet1.nonretransmittable_frames.push_back(QuicFrame(&ack_frame)); + packet1.retransmittable_frames.push_back( + QuicFrame(QuicStreamFrame(1, true, 0, 100))); + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet1, self_address, peer_address, + &allocator, 1500)); + EXPECT_EQ(PTO_RETRANSMISSION, + coalesced.TransmissionTypeOfPacket(ENCRYPTION_INITIAL)); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(500u, coalesced.length()); + EXPECT_EQ(1u, coalesced.NumberOfPackets()); + EXPECT_EQ( + "total_length: 1500 padding_size: 1000 packets: {ENCRYPTION_INITIAL}", + coalesced.ToString(1500)); + + // Cannot coalesce packet of the same encryption level. + SerializedPacket packet2(QuicPacketNumber(2), PACKET_4BYTE_PACKET_NUMBER, + buffer, 500, false, false); + EXPECT_FALSE(coalesced.MaybeCoalescePacket(packet2, self_address, + peer_address, &allocator, 1500)); + + SerializedPacket packet3(QuicPacketNumber(3), PACKET_4BYTE_PACKET_NUMBER, + buffer, 500, false, false); + packet3.nonretransmittable_frames.push_back(QuicFrame(QuicPaddingFrame(100))); + packet3.encryption_level = ENCRYPTION_ZERO_RTT; + packet3.transmission_type = LOSS_RETRANSMISSION; + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet3, self_address, peer_address, + &allocator, 1500)); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(1000u, coalesced.length()); + EXPECT_EQ(2u, coalesced.NumberOfPackets()); + EXPECT_EQ(LOSS_RETRANSMISSION, + coalesced.TransmissionTypeOfPacket(ENCRYPTION_ZERO_RTT)); + EXPECT_EQ( + "total_length: 1500 padding_size: 500 packets: {ENCRYPTION_INITIAL, " + "ENCRYPTION_ZERO_RTT}", + coalesced.ToString(1500)); + + SerializedPacket packet4(QuicPacketNumber(4), PACKET_4BYTE_PACKET_NUMBER, + buffer, 500, false, false); + packet4.encryption_level = ENCRYPTION_FORWARD_SECURE; + // Cannot coalesce packet of changed self/peer address. + EXPECT_FALSE(coalesced.MaybeCoalescePacket( + packet4, QuicSocketAddress(QuicIpAddress::Loopback4(), 3), peer_address, + &allocator, 1500)); + + // Packet does not fit. + SerializedPacket packet5(QuicPacketNumber(5), PACKET_4BYTE_PACKET_NUMBER, + buffer, 501, false, false); + packet5.encryption_level = ENCRYPTION_FORWARD_SECURE; + EXPECT_FALSE(coalesced.MaybeCoalescePacket(packet5, self_address, + peer_address, &allocator, 1500)); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(1000u, coalesced.length()); + EXPECT_EQ(2u, coalesced.NumberOfPackets()); + + // Max packet number length changed. + SerializedPacket packet6(QuicPacketNumber(6), PACKET_4BYTE_PACKET_NUMBER, + buffer, 100, false, false); + packet6.encryption_level = ENCRYPTION_FORWARD_SECURE; + EXPECT_QUIC_BUG(coalesced.MaybeCoalescePacket(packet6, self_address, + peer_address, &allocator, 1000), + "Max packet length changes in the middle of the write path"); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(1000u, coalesced.length()); + EXPECT_EQ(2u, coalesced.NumberOfPackets()); +} + +TEST(QuicCoalescedPacketTest, CopyEncryptedBuffers) { + QuicCoalescedPacket coalesced; + quiche::SimpleBufferAllocator allocator; + QuicSocketAddress self_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress peer_address(QuicIpAddress::Loopback4(), 2); + std::string buffer(500, 'a'); + std::string buffer2(500, 'b'); + SerializedPacket packet1(QuicPacketNumber(1), PACKET_4BYTE_PACKET_NUMBER, + buffer.data(), 500, + /*has_ack=*/false, /*has_stop_waiting=*/false); + packet1.encryption_level = ENCRYPTION_ZERO_RTT; + SerializedPacket packet2(QuicPacketNumber(2), PACKET_4BYTE_PACKET_NUMBER, + buffer2.data(), 500, + /*has_ack=*/false, /*has_stop_waiting=*/false); + packet2.encryption_level = ENCRYPTION_FORWARD_SECURE; + + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet1, self_address, peer_address, + &allocator, 1500)); + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet2, self_address, peer_address, + &allocator, 1500)); + EXPECT_EQ(1000u, coalesced.length()); + + char copy_buffer[1000]; + size_t length_copied = 0; + EXPECT_FALSE( + coalesced.CopyEncryptedBuffers(copy_buffer, 900, &length_copied)); + ASSERT_TRUE( + coalesced.CopyEncryptedBuffers(copy_buffer, 1000, &length_copied)); + EXPECT_EQ(1000u, length_copied); + char expected[1000]; + memset(expected, 'a', 500); + memset(expected + 500, 'b', 500); + quiche::test::CompareCharArraysWithHexError("copied buffers", copy_buffer, + length_copied, expected, 1000); +} + +TEST(QuicCoalescedPacketTest, NeuterInitialPacket) { + QuicCoalescedPacket coalesced; + EXPECT_EQ("total_length: 0 padding_size: 0 packets: {}", + coalesced.ToString(0)); + // Noop when neutering initial packet on a empty coalescer. + coalesced.NeuterInitialPacket(); + EXPECT_EQ("total_length: 0 padding_size: 0 packets: {}", + coalesced.ToString(0)); + + quiche::SimpleBufferAllocator allocator; + EXPECT_EQ(0u, coalesced.length()); + char buffer[1000]; + QuicSocketAddress self_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress peer_address(QuicIpAddress::Loopback4(), 2); + SerializedPacket packet1(QuicPacketNumber(1), PACKET_4BYTE_PACKET_NUMBER, + buffer, 500, false, false); + packet1.transmission_type = PTO_RETRANSMISSION; + QuicAckFrame ack_frame(InitAckFrame(1)); + packet1.nonretransmittable_frames.push_back(QuicFrame(&ack_frame)); + packet1.retransmittable_frames.push_back( + QuicFrame(QuicStreamFrame(1, true, 0, 100))); + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet1, self_address, peer_address, + &allocator, 1500)); + EXPECT_EQ(PTO_RETRANSMISSION, + coalesced.TransmissionTypeOfPacket(ENCRYPTION_INITIAL)); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(500u, coalesced.length()); + EXPECT_EQ( + "total_length: 1500 padding_size: 1000 packets: {ENCRYPTION_INITIAL}", + coalesced.ToString(1500)); + // Neuter initial packet. + coalesced.NeuterInitialPacket(); + EXPECT_EQ(0u, coalesced.max_packet_length()); + EXPECT_EQ(0u, coalesced.length()); + EXPECT_EQ("total_length: 0 padding_size: 0 packets: {}", + coalesced.ToString(0)); + + // Coalesce initial packet again. + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet1, self_address, peer_address, + &allocator, 1500)); + + SerializedPacket packet2(QuicPacketNumber(3), PACKET_4BYTE_PACKET_NUMBER, + buffer, 500, false, false); + packet2.nonretransmittable_frames.push_back(QuicFrame(QuicPaddingFrame(100))); + packet2.encryption_level = ENCRYPTION_ZERO_RTT; + packet2.transmission_type = LOSS_RETRANSMISSION; + ASSERT_TRUE(coalesced.MaybeCoalescePacket(packet2, self_address, peer_address, + &allocator, 1500)); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(1000u, coalesced.length()); + EXPECT_EQ(LOSS_RETRANSMISSION, + coalesced.TransmissionTypeOfPacket(ENCRYPTION_ZERO_RTT)); + EXPECT_EQ( + "total_length: 1500 padding_size: 500 packets: {ENCRYPTION_INITIAL, " + "ENCRYPTION_ZERO_RTT}", + coalesced.ToString(1500)); + + // Neuter initial packet. + coalesced.NeuterInitialPacket(); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(500u, coalesced.length()); + EXPECT_EQ( + "total_length: 1500 padding_size: 1000 packets: {ENCRYPTION_ZERO_RTT}", + coalesced.ToString(1500)); + + SerializedPacket packet3(QuicPacketNumber(5), PACKET_4BYTE_PACKET_NUMBER, + buffer, 501, false, false); + packet3.encryption_level = ENCRYPTION_FORWARD_SECURE; + EXPECT_TRUE(coalesced.MaybeCoalescePacket(packet3, self_address, peer_address, + &allocator, 1500)); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(1001u, coalesced.length()); + // Neuter initial packet. + coalesced.NeuterInitialPacket(); + EXPECT_EQ(1500u, coalesced.max_packet_length()); + EXPECT_EQ(1001u, coalesced.length()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_config.cc b/quiche/quic/core/quic_config.cc new file mode 100644 index 000000000000..43f4bbd96994 --- /dev/null +++ b/quiche/quic/core/quic_config.cc @@ -0,0 +1,1434 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_config.h" + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_socket_address_coder.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// Reads the value corresponding to |name_| from |msg| into |out|. If the +// |name_| is absent in |msg| and |presence| is set to OPTIONAL |out| is set +// to |default_value|. +QuicErrorCode ReadUint32(const CryptoHandshakeMessage& msg, QuicTag tag, + QuicConfigPresence presence, uint32_t default_value, + uint32_t* out, std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + QuicErrorCode error = msg.GetUint32(tag, out); + switch (error) { + case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND: + if (presence == PRESENCE_REQUIRED) { + *error_details = "Missing " + QuicTagToString(tag); + break; + } + error = QUIC_NO_ERROR; + *out = default_value; + break; + case QUIC_NO_ERROR: + break; + default: + *error_details = "Bad " + QuicTagToString(tag); + break; + } + return error; +} + +QuicConfigValue::QuicConfigValue(QuicTag tag, QuicConfigPresence presence) + : tag_(tag), presence_(presence) {} +QuicConfigValue::~QuicConfigValue() {} + +QuicFixedUint32::QuicFixedUint32(QuicTag tag, QuicConfigPresence presence) + : QuicConfigValue(tag, presence), + has_send_value_(false), + has_receive_value_(false) {} +QuicFixedUint32::~QuicFixedUint32() {} + +bool QuicFixedUint32::HasSendValue() const { return has_send_value_; } + +uint32_t QuicFixedUint32::GetSendValue() const { + QUIC_BUG_IF(quic_bug_12743_1, !has_send_value_) + << "No send value to get for tag:" << QuicTagToString(tag_); + return send_value_; +} + +void QuicFixedUint32::SetSendValue(uint32_t value) { + has_send_value_ = true; + send_value_ = value; +} + +bool QuicFixedUint32::HasReceivedValue() const { return has_receive_value_; } + +uint32_t QuicFixedUint32::GetReceivedValue() const { + QUIC_BUG_IF(quic_bug_12743_2, !has_receive_value_) + << "No receive value to get for tag:" << QuicTagToString(tag_); + return receive_value_; +} + +void QuicFixedUint32::SetReceivedValue(uint32_t value) { + has_receive_value_ = true; + receive_value_ = value; +} + +void QuicFixedUint32::ToHandshakeMessage(CryptoHandshakeMessage* out) const { + if (tag_ == 0) { + QUIC_BUG(quic_bug_12743_3) + << "This parameter does not support writing to CryptoHandshakeMessage"; + return; + } + if (has_send_value_) { + out->SetValue(tag_, send_value_); + } +} + +QuicErrorCode QuicFixedUint32::ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType /*hello_type*/, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + if (tag_ == 0) { + *error_details = + "This parameter does not support reading from CryptoHandshakeMessage"; + QUIC_BUG(quic_bug_10575_1) << *error_details; + return QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } + QuicErrorCode error = peer_hello.GetUint32(tag_, &receive_value_); + switch (error) { + case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND: + if (presence_ == PRESENCE_OPTIONAL) { + return QUIC_NO_ERROR; + } + *error_details = "Missing " + QuicTagToString(tag_); + break; + case QUIC_NO_ERROR: + has_receive_value_ = true; + break; + default: + *error_details = "Bad " + QuicTagToString(tag_); + break; + } + return error; +} + +QuicFixedUint62::QuicFixedUint62(QuicTag name, QuicConfigPresence presence) + : QuicConfigValue(name, presence), + has_send_value_(false), + has_receive_value_(false) {} + +QuicFixedUint62::~QuicFixedUint62() {} + +bool QuicFixedUint62::HasSendValue() const { return has_send_value_; } + +uint64_t QuicFixedUint62::GetSendValue() const { + if (!has_send_value_) { + QUIC_BUG(quic_bug_10575_2) + << "No send value to get for tag:" << QuicTagToString(tag_); + return 0; + } + return send_value_; +} + +void QuicFixedUint62::SetSendValue(uint64_t value) { + if (value > quiche::kVarInt62MaxValue) { + QUIC_BUG(quic_bug_10575_3) << "QuicFixedUint62 invalid value " << value; + value = quiche::kVarInt62MaxValue; + } + has_send_value_ = true; + send_value_ = value; +} + +bool QuicFixedUint62::HasReceivedValue() const { return has_receive_value_; } + +uint64_t QuicFixedUint62::GetReceivedValue() const { + if (!has_receive_value_) { + QUIC_BUG(quic_bug_10575_4) + << "No receive value to get for tag:" << QuicTagToString(tag_); + return 0; + } + return receive_value_; +} + +void QuicFixedUint62::SetReceivedValue(uint64_t value) { + has_receive_value_ = true; + receive_value_ = value; +} + +void QuicFixedUint62::ToHandshakeMessage(CryptoHandshakeMessage* out) const { + if (!has_send_value_) { + return; + } + uint32_t send_value32; + if (send_value_ > std::numeric_limits::max()) { + QUIC_BUG(quic_bug_10575_5) << "Attempting to send " << send_value_ + << " for tag:" << QuicTagToString(tag_); + send_value32 = std::numeric_limits::max(); + } else { + send_value32 = static_cast(send_value_); + } + out->SetValue(tag_, send_value32); +} + +QuicErrorCode QuicFixedUint62::ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType /*hello_type*/, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + uint32_t receive_value32; + QuicErrorCode error = peer_hello.GetUint32(tag_, &receive_value32); + // GetUint32 is guaranteed to always initialize receive_value32. + receive_value_ = receive_value32; + switch (error) { + case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND: + if (presence_ == PRESENCE_OPTIONAL) { + return QUIC_NO_ERROR; + } + *error_details = "Missing " + QuicTagToString(tag_); + break; + case QUIC_NO_ERROR: + has_receive_value_ = true; + break; + default: + *error_details = "Bad " + QuicTagToString(tag_); + break; + } + return error; +} + +QuicFixedStatelessResetToken::QuicFixedStatelessResetToken( + QuicTag tag, QuicConfigPresence presence) + : QuicConfigValue(tag, presence), + has_send_value_(false), + has_receive_value_(false) {} +QuicFixedStatelessResetToken::~QuicFixedStatelessResetToken() {} + +bool QuicFixedStatelessResetToken::HasSendValue() const { + return has_send_value_; +} + +const StatelessResetToken& QuicFixedStatelessResetToken::GetSendValue() const { + QUIC_BUG_IF(quic_bug_12743_4, !has_send_value_) + << "No send value to get for tag:" << QuicTagToString(tag_); + return send_value_; +} + +void QuicFixedStatelessResetToken::SetSendValue( + const StatelessResetToken& value) { + has_send_value_ = true; + send_value_ = value; +} + +bool QuicFixedStatelessResetToken::HasReceivedValue() const { + return has_receive_value_; +} + +const StatelessResetToken& QuicFixedStatelessResetToken::GetReceivedValue() + const { + QUIC_BUG_IF(quic_bug_12743_5, !has_receive_value_) + << "No receive value to get for tag:" << QuicTagToString(tag_); + return receive_value_; +} + +void QuicFixedStatelessResetToken::SetReceivedValue( + const StatelessResetToken& value) { + has_receive_value_ = true; + receive_value_ = value; +} + +void QuicFixedStatelessResetToken::ToHandshakeMessage( + CryptoHandshakeMessage* out) const { + if (has_send_value_) { + out->SetValue(tag_, send_value_); + } +} + +QuicErrorCode QuicFixedStatelessResetToken::ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType /*hello_type*/, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + QuicErrorCode error = + peer_hello.GetStatelessResetToken(tag_, &receive_value_); + switch (error) { + case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND: + if (presence_ == PRESENCE_OPTIONAL) { + return QUIC_NO_ERROR; + } + *error_details = "Missing " + QuicTagToString(tag_); + break; + case QUIC_NO_ERROR: + has_receive_value_ = true; + break; + default: + *error_details = "Bad " + QuicTagToString(tag_); + break; + } + return error; +} + +QuicFixedTagVector::QuicFixedTagVector(QuicTag name, + QuicConfigPresence presence) + : QuicConfigValue(name, presence), + has_send_values_(false), + has_receive_values_(false) {} + +QuicFixedTagVector::QuicFixedTagVector(const QuicFixedTagVector& other) = + default; + +QuicFixedTagVector::~QuicFixedTagVector() {} + +bool QuicFixedTagVector::HasSendValues() const { return has_send_values_; } + +const QuicTagVector& QuicFixedTagVector::GetSendValues() const { + QUIC_BUG_IF(quic_bug_12743_6, !has_send_values_) + << "No send values to get for tag:" << QuicTagToString(tag_); + return send_values_; +} + +void QuicFixedTagVector::SetSendValues(const QuicTagVector& values) { + has_send_values_ = true; + send_values_ = values; +} + +bool QuicFixedTagVector::HasReceivedValues() const { + return has_receive_values_; +} + +const QuicTagVector& QuicFixedTagVector::GetReceivedValues() const { + QUIC_BUG_IF(quic_bug_12743_7, !has_receive_values_) + << "No receive value to get for tag:" << QuicTagToString(tag_); + return receive_values_; +} + +void QuicFixedTagVector::SetReceivedValues(const QuicTagVector& values) { + has_receive_values_ = true; + receive_values_ = values; +} + +void QuicFixedTagVector::ToHandshakeMessage(CryptoHandshakeMessage* out) const { + if (has_send_values_) { + out->SetVector(tag_, send_values_); + } +} + +QuicErrorCode QuicFixedTagVector::ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType /*hello_type*/, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + QuicTagVector values; + QuicErrorCode error = peer_hello.GetTaglist(tag_, &values); + switch (error) { + case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND: + if (presence_ == PRESENCE_OPTIONAL) { + return QUIC_NO_ERROR; + } + *error_details = "Missing " + QuicTagToString(tag_); + break; + case QUIC_NO_ERROR: + QUIC_DVLOG(1) << "Received Connection Option tags from receiver."; + has_receive_values_ = true; + receive_values_.insert(receive_values_.end(), values.begin(), + values.end()); + break; + default: + *error_details = "Bad " + QuicTagToString(tag_); + break; + } + return error; +} + +QuicFixedSocketAddress::QuicFixedSocketAddress(QuicTag tag, + QuicConfigPresence presence) + : QuicConfigValue(tag, presence), + has_send_value_(false), + has_receive_value_(false) {} + +QuicFixedSocketAddress::~QuicFixedSocketAddress() {} + +bool QuicFixedSocketAddress::HasSendValue() const { return has_send_value_; } + +const QuicSocketAddress& QuicFixedSocketAddress::GetSendValue() const { + QUIC_BUG_IF(quic_bug_12743_8, !has_send_value_) + << "No send value to get for tag:" << QuicTagToString(tag_); + return send_value_; +} + +void QuicFixedSocketAddress::SetSendValue(const QuicSocketAddress& value) { + has_send_value_ = true; + send_value_ = value; +} + +void QuicFixedSocketAddress::ClearSendValue() { + has_send_value_ = false; + send_value_ = QuicSocketAddress(); +} + +bool QuicFixedSocketAddress::HasReceivedValue() const { + return has_receive_value_; +} + +const QuicSocketAddress& QuicFixedSocketAddress::GetReceivedValue() const { + QUIC_BUG_IF(quic_bug_12743_9, !has_receive_value_) + << "No receive value to get for tag:" << QuicTagToString(tag_); + return receive_value_; +} + +void QuicFixedSocketAddress::SetReceivedValue(const QuicSocketAddress& value) { + has_receive_value_ = true; + receive_value_ = value; +} + +void QuicFixedSocketAddress::ToHandshakeMessage( + CryptoHandshakeMessage* out) const { + if (has_send_value_) { + QuicSocketAddressCoder address_coder(send_value_); + out->SetStringPiece(tag_, address_coder.Encode()); + } +} + +QuicErrorCode QuicFixedSocketAddress::ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType /*hello_type*/, + std::string* error_details) { + absl::string_view address; + if (!peer_hello.GetStringPiece(tag_, &address)) { + if (presence_ == PRESENCE_REQUIRED) { + *error_details = "Missing " + QuicTagToString(tag_); + return QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND; + } + } else { + QuicSocketAddressCoder address_coder; + if (address_coder.Decode(address.data(), address.length())) { + SetReceivedValue( + QuicSocketAddress(address_coder.ip(), address_coder.port())); + } + } + return QUIC_NO_ERROR; +} + +QuicConfig::QuicConfig() + : negotiated_(false), + max_time_before_crypto_handshake_(QuicTime::Delta::Zero()), + max_idle_time_before_crypto_handshake_(QuicTime::Delta::Zero()), + max_undecryptable_packets_(0), + connection_options_(kCOPT, PRESENCE_OPTIONAL), + client_connection_options_(kCLOP, PRESENCE_OPTIONAL), + max_idle_timeout_to_send_(QuicTime::Delta::Infinite()), + max_bidirectional_streams_(kMIBS, PRESENCE_REQUIRED), + max_unidirectional_streams_(kMIUS, PRESENCE_OPTIONAL), + bytes_for_connection_id_(kTCID, PRESENCE_OPTIONAL), + initial_round_trip_time_us_(kIRTT, PRESENCE_OPTIONAL), + initial_max_stream_data_bytes_incoming_bidirectional_(0, + PRESENCE_OPTIONAL), + initial_max_stream_data_bytes_outgoing_bidirectional_(0, + PRESENCE_OPTIONAL), + initial_max_stream_data_bytes_unidirectional_(0, PRESENCE_OPTIONAL), + initial_stream_flow_control_window_bytes_(kSFCW, PRESENCE_OPTIONAL), + initial_session_flow_control_window_bytes_(kCFCW, PRESENCE_OPTIONAL), + connection_migration_disabled_(kNCMR, PRESENCE_OPTIONAL), + alternate_server_address_ipv6_(kASAD, PRESENCE_OPTIONAL), + alternate_server_address_ipv4_(kASAD, PRESENCE_OPTIONAL), + stateless_reset_token_(kSRST, PRESENCE_OPTIONAL), + max_ack_delay_ms_(kMAD, PRESENCE_OPTIONAL), + min_ack_delay_ms_(0, PRESENCE_OPTIONAL), + ack_delay_exponent_(kADE, PRESENCE_OPTIONAL), + max_udp_payload_size_(0, PRESENCE_OPTIONAL), + max_datagram_frame_size_(0, PRESENCE_OPTIONAL), + active_connection_id_limit_(0, PRESENCE_OPTIONAL) { + SetDefaults(); +} + +QuicConfig::QuicConfig(const QuicConfig& other) = default; + +QuicConfig::~QuicConfig() {} + +bool QuicConfig::SetInitialReceivedConnectionOptions( + const QuicTagVector& tags) { + if (HasReceivedConnectionOptions()) { + // If we have already received connection options (via handshake or due to + // a previous call), don't re-initialize. + return false; + } + connection_options_.SetReceivedValues(tags); + return true; +} + +void QuicConfig::SetConnectionOptionsToSend( + const QuicTagVector& connection_options) { + connection_options_.SetSendValues(connection_options); +} + +void QuicConfig::SetGoogleHandshakeMessageToSend(std::string message) { + google_handshake_message_to_send_ = std::move(message); +} + +const absl::optional& +QuicConfig::GetReceivedGoogleHandshakeMessage() const { + return received_google_handshake_message_; +} + +bool QuicConfig::HasReceivedConnectionOptions() const { + return connection_options_.HasReceivedValues(); +} + +const QuicTagVector& QuicConfig::ReceivedConnectionOptions() const { + return connection_options_.GetReceivedValues(); +} + +bool QuicConfig::HasSendConnectionOptions() const { + return connection_options_.HasSendValues(); +} + +const QuicTagVector& QuicConfig::SendConnectionOptions() const { + return connection_options_.GetSendValues(); +} + +bool QuicConfig::HasClientSentConnectionOption(QuicTag tag, + Perspective perspective) const { + if (perspective == Perspective::IS_SERVER) { + if (HasReceivedConnectionOptions() && + ContainsQuicTag(ReceivedConnectionOptions(), tag)) { + return true; + } + } else if (HasSendConnectionOptions() && + ContainsQuicTag(SendConnectionOptions(), tag)) { + return true; + } + return false; +} + +void QuicConfig::SetClientConnectionOptions( + const QuicTagVector& client_connection_options) { + client_connection_options_.SetSendValues(client_connection_options); +} + +bool QuicConfig::HasClientRequestedIndependentOption( + QuicTag tag, Perspective perspective) const { + if (perspective == Perspective::IS_SERVER) { + return (HasReceivedConnectionOptions() && + ContainsQuicTag(ReceivedConnectionOptions(), tag)); + } + + return (client_connection_options_.HasSendValues() && + ContainsQuicTag(client_connection_options_.GetSendValues(), tag)); +} + +const QuicTagVector& QuicConfig::ClientRequestedIndependentOptions( + Perspective perspective) const { + static const QuicTagVector* no_options = new QuicTagVector; + if (perspective == Perspective::IS_SERVER) { + return HasReceivedConnectionOptions() ? ReceivedConnectionOptions() + : *no_options; + } + + return client_connection_options_.HasSendValues() + ? client_connection_options_.GetSendValues() + : *no_options; +} + +void QuicConfig::SetIdleNetworkTimeout(QuicTime::Delta idle_network_timeout) { + if (idle_network_timeout.ToMicroseconds() <= 0) { + QUIC_BUG(quic_bug_10575_6) + << "Invalid idle network timeout " << idle_network_timeout; + return; + } + max_idle_timeout_to_send_ = idle_network_timeout; +} + +QuicTime::Delta QuicConfig::IdleNetworkTimeout() const { + // TODO(b/152032210) add a QUIC_BUG to ensure that is not called before we've + // received the peer's values. This is true in production code but not in all + // of our tests that use a fake QuicConfig. + if (!received_max_idle_timeout_.has_value()) { + return max_idle_timeout_to_send_; + } + return received_max_idle_timeout_.value(); +} + +void QuicConfig::SetMaxBidirectionalStreamsToSend(uint32_t max_streams) { + max_bidirectional_streams_.SetSendValue(max_streams); +} + +uint32_t QuicConfig::GetMaxBidirectionalStreamsToSend() const { + return max_bidirectional_streams_.GetSendValue(); +} + +bool QuicConfig::HasReceivedMaxBidirectionalStreams() const { + return max_bidirectional_streams_.HasReceivedValue(); +} + +uint32_t QuicConfig::ReceivedMaxBidirectionalStreams() const { + return max_bidirectional_streams_.GetReceivedValue(); +} + +void QuicConfig::SetMaxUnidirectionalStreamsToSend(uint32_t max_streams) { + max_unidirectional_streams_.SetSendValue(max_streams); +} + +uint32_t QuicConfig::GetMaxUnidirectionalStreamsToSend() const { + return max_unidirectional_streams_.GetSendValue(); +} + +bool QuicConfig::HasReceivedMaxUnidirectionalStreams() const { + return max_unidirectional_streams_.HasReceivedValue(); +} + +uint32_t QuicConfig::ReceivedMaxUnidirectionalStreams() const { + return max_unidirectional_streams_.GetReceivedValue(); +} + +void QuicConfig::SetMaxAckDelayToSendMs(uint32_t max_ack_delay_ms) { + max_ack_delay_ms_.SetSendValue(max_ack_delay_ms); +} + +uint32_t QuicConfig::GetMaxAckDelayToSendMs() const { + return max_ack_delay_ms_.GetSendValue(); +} + +bool QuicConfig::HasReceivedMaxAckDelayMs() const { + return max_ack_delay_ms_.HasReceivedValue(); +} + +uint32_t QuicConfig::ReceivedMaxAckDelayMs() const { + return max_ack_delay_ms_.GetReceivedValue(); +} + +void QuicConfig::SetMinAckDelayMs(uint32_t min_ack_delay_ms) { + min_ack_delay_ms_.SetSendValue(min_ack_delay_ms); +} + +uint32_t QuicConfig::GetMinAckDelayToSendMs() const { + return min_ack_delay_ms_.GetSendValue(); +} + +bool QuicConfig::HasReceivedMinAckDelayMs() const { + return min_ack_delay_ms_.HasReceivedValue(); +} + +uint32_t QuicConfig::ReceivedMinAckDelayMs() const { + return min_ack_delay_ms_.GetReceivedValue(); +} + +void QuicConfig::SetAckDelayExponentToSend(uint32_t exponent) { + ack_delay_exponent_.SetSendValue(exponent); +} + +uint32_t QuicConfig::GetAckDelayExponentToSend() const { + return ack_delay_exponent_.GetSendValue(); +} + +bool QuicConfig::HasReceivedAckDelayExponent() const { + return ack_delay_exponent_.HasReceivedValue(); +} + +uint32_t QuicConfig::ReceivedAckDelayExponent() const { + return ack_delay_exponent_.GetReceivedValue(); +} + +void QuicConfig::SetMaxPacketSizeToSend(uint64_t max_udp_payload_size) { + max_udp_payload_size_.SetSendValue(max_udp_payload_size); +} + +uint64_t QuicConfig::GetMaxPacketSizeToSend() const { + return max_udp_payload_size_.GetSendValue(); +} + +bool QuicConfig::HasReceivedMaxPacketSize() const { + return max_udp_payload_size_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedMaxPacketSize() const { + return max_udp_payload_size_.GetReceivedValue(); +} + +void QuicConfig::SetMaxDatagramFrameSizeToSend( + uint64_t max_datagram_frame_size) { + max_datagram_frame_size_.SetSendValue(max_datagram_frame_size); +} + +uint64_t QuicConfig::GetMaxDatagramFrameSizeToSend() const { + return max_datagram_frame_size_.GetSendValue(); +} + +bool QuicConfig::HasReceivedMaxDatagramFrameSize() const { + return max_datagram_frame_size_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedMaxDatagramFrameSize() const { + return max_datagram_frame_size_.GetReceivedValue(); +} + +void QuicConfig::SetActiveConnectionIdLimitToSend( + uint64_t active_connection_id_limit) { + active_connection_id_limit_.SetSendValue(active_connection_id_limit); +} + +uint64_t QuicConfig::GetActiveConnectionIdLimitToSend() const { + return active_connection_id_limit_.GetSendValue(); +} + +bool QuicConfig::HasReceivedActiveConnectionIdLimit() const { + return active_connection_id_limit_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedActiveConnectionIdLimit() const { + return active_connection_id_limit_.GetReceivedValue(); +} + +bool QuicConfig::HasSetBytesForConnectionIdToSend() const { + return bytes_for_connection_id_.HasSendValue(); +} + +void QuicConfig::SetBytesForConnectionIdToSend(uint32_t bytes) { + bytes_for_connection_id_.SetSendValue(bytes); +} + +bool QuicConfig::HasReceivedBytesForConnectionId() const { + return bytes_for_connection_id_.HasReceivedValue(); +} + +uint32_t QuicConfig::ReceivedBytesForConnectionId() const { + return bytes_for_connection_id_.GetReceivedValue(); +} + +void QuicConfig::SetInitialRoundTripTimeUsToSend(uint64_t rtt) { + initial_round_trip_time_us_.SetSendValue(rtt); +} + +bool QuicConfig::HasReceivedInitialRoundTripTimeUs() const { + return initial_round_trip_time_us_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedInitialRoundTripTimeUs() const { + return initial_round_trip_time_us_.GetReceivedValue(); +} + +bool QuicConfig::HasInitialRoundTripTimeUsToSend() const { + return initial_round_trip_time_us_.HasSendValue(); +} + +uint64_t QuicConfig::GetInitialRoundTripTimeUsToSend() const { + return initial_round_trip_time_us_.GetSendValue(); +} + +void QuicConfig::SetInitialStreamFlowControlWindowToSend( + uint64_t window_bytes) { + if (window_bytes < kMinimumFlowControlSendWindow) { + QUIC_BUG(quic_bug_10575_7) + << "Initial stream flow control receive window (" << window_bytes + << ") cannot be set lower than minimum (" + << kMinimumFlowControlSendWindow << ")."; + window_bytes = kMinimumFlowControlSendWindow; + } + initial_stream_flow_control_window_bytes_.SetSendValue(window_bytes); +} + +uint64_t QuicConfig::GetInitialStreamFlowControlWindowToSend() const { + return initial_stream_flow_control_window_bytes_.GetSendValue(); +} + +bool QuicConfig::HasReceivedInitialStreamFlowControlWindowBytes() const { + return initial_stream_flow_control_window_bytes_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedInitialStreamFlowControlWindowBytes() const { + return initial_stream_flow_control_window_bytes_.GetReceivedValue(); +} + +void QuicConfig::SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + uint64_t window_bytes) { + initial_max_stream_data_bytes_incoming_bidirectional_.SetSendValue( + window_bytes); +} + +uint64_t QuicConfig::GetInitialMaxStreamDataBytesIncomingBidirectionalToSend() + const { + if (initial_max_stream_data_bytes_incoming_bidirectional_.HasSendValue()) { + return initial_max_stream_data_bytes_incoming_bidirectional_.GetSendValue(); + } + return initial_stream_flow_control_window_bytes_.GetSendValue(); +} + +bool QuicConfig::HasReceivedInitialMaxStreamDataBytesIncomingBidirectional() + const { + return initial_max_stream_data_bytes_incoming_bidirectional_ + .HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedInitialMaxStreamDataBytesIncomingBidirectional() + const { + return initial_max_stream_data_bytes_incoming_bidirectional_ + .GetReceivedValue(); +} + +void QuicConfig::SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend( + uint64_t window_bytes) { + initial_max_stream_data_bytes_outgoing_bidirectional_.SetSendValue( + window_bytes); +} + +uint64_t QuicConfig::GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend() + const { + if (initial_max_stream_data_bytes_outgoing_bidirectional_.HasSendValue()) { + return initial_max_stream_data_bytes_outgoing_bidirectional_.GetSendValue(); + } + return initial_stream_flow_control_window_bytes_.GetSendValue(); +} + +bool QuicConfig::HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional() + const { + return initial_max_stream_data_bytes_outgoing_bidirectional_ + .HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedInitialMaxStreamDataBytesOutgoingBidirectional() + const { + return initial_max_stream_data_bytes_outgoing_bidirectional_ + .GetReceivedValue(); +} + +void QuicConfig::SetInitialMaxStreamDataBytesUnidirectionalToSend( + uint64_t window_bytes) { + initial_max_stream_data_bytes_unidirectional_.SetSendValue(window_bytes); +} + +uint64_t QuicConfig::GetInitialMaxStreamDataBytesUnidirectionalToSend() const { + if (initial_max_stream_data_bytes_unidirectional_.HasSendValue()) { + return initial_max_stream_data_bytes_unidirectional_.GetSendValue(); + } + return initial_stream_flow_control_window_bytes_.GetSendValue(); +} + +bool QuicConfig::HasReceivedInitialMaxStreamDataBytesUnidirectional() const { + return initial_max_stream_data_bytes_unidirectional_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedInitialMaxStreamDataBytesUnidirectional() const { + return initial_max_stream_data_bytes_unidirectional_.GetReceivedValue(); +} + +void QuicConfig::SetInitialSessionFlowControlWindowToSend( + uint64_t window_bytes) { + if (window_bytes < kMinimumFlowControlSendWindow) { + QUIC_BUG(quic_bug_10575_8) + << "Initial session flow control receive window (" << window_bytes + << ") cannot be set lower than default (" + << kMinimumFlowControlSendWindow << ")."; + window_bytes = kMinimumFlowControlSendWindow; + } + initial_session_flow_control_window_bytes_.SetSendValue(window_bytes); +} + +uint64_t QuicConfig::GetInitialSessionFlowControlWindowToSend() const { + return initial_session_flow_control_window_bytes_.GetSendValue(); +} + +bool QuicConfig::HasReceivedInitialSessionFlowControlWindowBytes() const { + return initial_session_flow_control_window_bytes_.HasReceivedValue(); +} + +uint64_t QuicConfig::ReceivedInitialSessionFlowControlWindowBytes() const { + return initial_session_flow_control_window_bytes_.GetReceivedValue(); +} + +void QuicConfig::SetDisableConnectionMigration() { + connection_migration_disabled_.SetSendValue(1); +} + +bool QuicConfig::DisableConnectionMigration() const { + return connection_migration_disabled_.HasReceivedValue(); +} + +void QuicConfig::SetIPv6AlternateServerAddressToSend( + const QuicSocketAddress& alternate_server_address_ipv6) { + if (!alternate_server_address_ipv6.Normalized().host().IsIPv6()) { + QUIC_BUG(quic_bug_10575_9) + << "Cannot use SetIPv6AlternateServerAddressToSend with " + << alternate_server_address_ipv6; + return; + } + alternate_server_address_ipv6_.SetSendValue(alternate_server_address_ipv6); +} + +bool QuicConfig::HasReceivedIPv6AlternateServerAddress() const { + return alternate_server_address_ipv6_.HasReceivedValue(); +} + +const QuicSocketAddress& QuicConfig::ReceivedIPv6AlternateServerAddress() + const { + return alternate_server_address_ipv6_.GetReceivedValue(); +} + +void QuicConfig::SetIPv4AlternateServerAddressToSend( + const QuicSocketAddress& alternate_server_address_ipv4) { + if (!alternate_server_address_ipv4.host().IsIPv4()) { + QUIC_BUG(quic_bug_10575_11) + << "Cannot use SetIPv4AlternateServerAddressToSend with " + << alternate_server_address_ipv4; + return; + } + alternate_server_address_ipv4_.SetSendValue(alternate_server_address_ipv4); +} + +bool QuicConfig::HasReceivedIPv4AlternateServerAddress() const { + return alternate_server_address_ipv4_.HasReceivedValue(); +} + +const QuicSocketAddress& QuicConfig::ReceivedIPv4AlternateServerAddress() + const { + return alternate_server_address_ipv4_.GetReceivedValue(); +} + +void QuicConfig::SetPreferredAddressConnectionIdAndTokenToSend( + const QuicConnectionId& connection_id, + const StatelessResetToken& stateless_reset_token) { + if ((!alternate_server_address_ipv4_.HasSendValue() && + !alternate_server_address_ipv6_.HasSendValue()) || + preferred_address_connection_id_and_token_.has_value()) { + QUIC_BUG(quic_bug_10575_17) + << "Can not send connection ID and token for preferred address"; + return; + } + preferred_address_connection_id_and_token_ = + std::make_pair(connection_id, stateless_reset_token); +} + +bool QuicConfig::HasReceivedPreferredAddressConnectionIdAndToken() const { + return (HasReceivedIPv6AlternateServerAddress() || + HasReceivedIPv4AlternateServerAddress()) && + preferred_address_connection_id_and_token_.has_value(); +} + +const std::pair& +QuicConfig::ReceivedPreferredAddressConnectionIdAndToken() const { + QUICHE_DCHECK(HasReceivedPreferredAddressConnectionIdAndToken()); + return *preferred_address_connection_id_and_token_; +} + +void QuicConfig::SetOriginalConnectionIdToSend( + const QuicConnectionId& original_destination_connection_id) { + original_destination_connection_id_to_send_ = + original_destination_connection_id; +} + +bool QuicConfig::HasReceivedOriginalConnectionId() const { + return received_original_destination_connection_id_.has_value(); +} + +QuicConnectionId QuicConfig::ReceivedOriginalConnectionId() const { + if (!HasReceivedOriginalConnectionId()) { + QUIC_BUG(quic_bug_10575_13) << "No received original connection ID"; + return EmptyQuicConnectionId(); + } + return received_original_destination_connection_id_.value(); +} + +void QuicConfig::SetInitialSourceConnectionIdToSend( + const QuicConnectionId& initial_source_connection_id) { + initial_source_connection_id_to_send_ = initial_source_connection_id; +} + +bool QuicConfig::HasReceivedInitialSourceConnectionId() const { + return received_initial_source_connection_id_.has_value(); +} + +QuicConnectionId QuicConfig::ReceivedInitialSourceConnectionId() const { + if (!HasReceivedInitialSourceConnectionId()) { + QUIC_BUG(quic_bug_10575_14) << "No received initial source connection ID"; + return EmptyQuicConnectionId(); + } + return received_initial_source_connection_id_.value(); +} + +void QuicConfig::SetRetrySourceConnectionIdToSend( + const QuicConnectionId& retry_source_connection_id) { + retry_source_connection_id_to_send_ = retry_source_connection_id; +} + +bool QuicConfig::HasReceivedRetrySourceConnectionId() const { + return received_retry_source_connection_id_.has_value(); +} + +QuicConnectionId QuicConfig::ReceivedRetrySourceConnectionId() const { + if (!HasReceivedRetrySourceConnectionId()) { + QUIC_BUG(quic_bug_10575_15) << "No received retry source connection ID"; + return EmptyQuicConnectionId(); + } + return received_retry_source_connection_id_.value(); +} + +void QuicConfig::SetStatelessResetTokenToSend( + const StatelessResetToken& stateless_reset_token) { + stateless_reset_token_.SetSendValue(stateless_reset_token); +} + +bool QuicConfig::HasStatelessResetTokenToSend() const { + return stateless_reset_token_.HasSendValue(); +} + +bool QuicConfig::HasReceivedStatelessResetToken() const { + return stateless_reset_token_.HasReceivedValue(); +} + +const StatelessResetToken& QuicConfig::ReceivedStatelessResetToken() const { + return stateless_reset_token_.GetReceivedValue(); +} + +bool QuicConfig::negotiated() const { return negotiated_; } + +void QuicConfig::SetCreateSessionTagIndicators(QuicTagVector tags) { + create_session_tag_indicators_ = std::move(tags); +} + +const QuicTagVector& QuicConfig::create_session_tag_indicators() const { + return create_session_tag_indicators_; +} + +void QuicConfig::SetDefaults() { + SetIdleNetworkTimeout(QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs)); + SetMaxBidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection); + SetMaxUnidirectionalStreamsToSend(kDefaultMaxStreamsPerConnection); + max_time_before_crypto_handshake_ = + QuicTime::Delta::FromSeconds(kMaxTimeForCryptoHandshakeSecs); + max_idle_time_before_crypto_handshake_ = + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs); + max_undecryptable_packets_ = kDefaultMaxUndecryptablePackets; + + SetInitialStreamFlowControlWindowToSend(kMinimumFlowControlSendWindow); + SetInitialSessionFlowControlWindowToSend(kMinimumFlowControlSendWindow); + SetMaxAckDelayToSendMs(kDefaultDelayedAckTimeMs); + SetAckDelayExponentToSend(kDefaultAckDelayExponent); + SetMaxPacketSizeToSend(kMaxIncomingPacketSize); + SetMaxDatagramFrameSizeToSend(kMaxAcceptedDatagramFrameSize); +} + +void QuicConfig::ToHandshakeMessage( + CryptoHandshakeMessage* out, QuicTransportVersion transport_version) const { + // Idle timeout has custom rules that are different from other values. + // We configure ourselves with the minumum value between the one sent and + // the one received. Additionally, when QUIC_CRYPTO is used, the server + // MUST send an idle timeout no greater than the idle timeout it received + // from the client. We therefore send the received value if it is lower. + QuicFixedUint32 max_idle_timeout_seconds(kICSL, PRESENCE_REQUIRED); + uint32_t max_idle_timeout_to_send_seconds = + max_idle_timeout_to_send_.ToSeconds(); + if (received_max_idle_timeout_.has_value() && + received_max_idle_timeout_->ToSeconds() < + max_idle_timeout_to_send_seconds) { + max_idle_timeout_to_send_seconds = received_max_idle_timeout_->ToSeconds(); + } + max_idle_timeout_seconds.SetSendValue(max_idle_timeout_to_send_seconds); + max_idle_timeout_seconds.ToHandshakeMessage(out); + + // Do not need a version check here, max...bi... will encode + // as "MIDS" -- the max initial dynamic streams tag -- if + // doing some version other than IETF QUIC. + max_bidirectional_streams_.ToHandshakeMessage(out); + if (VersionHasIetfQuicFrames(transport_version)) { + max_unidirectional_streams_.ToHandshakeMessage(out); + ack_delay_exponent_.ToHandshakeMessage(out); + } + if (max_ack_delay_ms_.GetSendValue() != kDefaultDelayedAckTimeMs) { + // Only send max ack delay if it is using a non-default value, because + // the default value is used by QuicSentPacketManager if it is not + // sent during the handshake, and we want to save bytes. + max_ack_delay_ms_.ToHandshakeMessage(out); + } + bytes_for_connection_id_.ToHandshakeMessage(out); + initial_round_trip_time_us_.ToHandshakeMessage(out); + initial_stream_flow_control_window_bytes_.ToHandshakeMessage(out); + initial_session_flow_control_window_bytes_.ToHandshakeMessage(out); + connection_migration_disabled_.ToHandshakeMessage(out); + connection_options_.ToHandshakeMessage(out); + if (alternate_server_address_ipv6_.HasSendValue()) { + alternate_server_address_ipv6_.ToHandshakeMessage(out); + } else { + alternate_server_address_ipv4_.ToHandshakeMessage(out); + } + stateless_reset_token_.ToHandshakeMessage(out); +} + +QuicErrorCode QuicConfig::ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType hello_type, + std::string* error_details) { + QUICHE_DCHECK(error_details != nullptr); + + QuicErrorCode error = QUIC_NO_ERROR; + if (error == QUIC_NO_ERROR) { + // Idle timeout has custom rules that are different from other values. + // We configure ourselves with the minumum value between the one sent and + // the one received. Additionally, when QUIC_CRYPTO is used, the server + // MUST send an idle timeout no greater than the idle timeout it received + // from the client. + QuicFixedUint32 max_idle_timeout_seconds(kICSL, PRESENCE_REQUIRED); + error = max_idle_timeout_seconds.ProcessPeerHello(peer_hello, hello_type, + error_details); + if (error == QUIC_NO_ERROR) { + if (max_idle_timeout_seconds.GetReceivedValue() > + max_idle_timeout_to_send_.ToSeconds()) { + // The received value is higher than ours, ignore it if from the client + // and raise an error if from the server. + if (hello_type == SERVER) { + error = QUIC_INVALID_NEGOTIATED_VALUE; + *error_details = + "Invalid value received for " + QuicTagToString(kICSL); + } + } else { + received_max_idle_timeout_ = QuicTime::Delta::FromSeconds( + max_idle_timeout_seconds.GetReceivedValue()); + } + } + } + if (error == QUIC_NO_ERROR) { + error = max_bidirectional_streams_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + error = max_unidirectional_streams_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + error = bytes_for_connection_id_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + error = initial_round_trip_time_us_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + error = initial_stream_flow_control_window_bytes_.ProcessPeerHello( + peer_hello, hello_type, error_details); + } + if (error == QUIC_NO_ERROR) { + error = initial_session_flow_control_window_bytes_.ProcessPeerHello( + peer_hello, hello_type, error_details); + } + if (error == QUIC_NO_ERROR) { + error = connection_migration_disabled_.ProcessPeerHello( + peer_hello, hello_type, error_details); + } + if (error == QUIC_NO_ERROR) { + error = connection_options_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + QuicFixedSocketAddress alternate_server_address(kASAD, PRESENCE_OPTIONAL); + error = alternate_server_address.ProcessPeerHello(peer_hello, hello_type, + error_details); + if (error == QUIC_NO_ERROR && alternate_server_address.HasReceivedValue()) { + const QuicSocketAddress& received_address = + alternate_server_address.GetReceivedValue(); + if (received_address.host().IsIPv6()) { + alternate_server_address_ipv6_.SetReceivedValue(received_address); + } else if (received_address.host().IsIPv4()) { + alternate_server_address_ipv4_.SetReceivedValue(received_address); + } + } + } + if (error == QUIC_NO_ERROR) { + error = stateless_reset_token_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + + if (error == QUIC_NO_ERROR) { + error = max_ack_delay_ms_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + error = ack_delay_exponent_.ProcessPeerHello(peer_hello, hello_type, + error_details); + } + if (error == QUIC_NO_ERROR) { + negotiated_ = true; + } + return error; +} + +bool QuicConfig::FillTransportParameters(TransportParameters* params) const { + if (original_destination_connection_id_to_send_.has_value()) { + params->original_destination_connection_id = + original_destination_connection_id_to_send_.value(); + } + + params->max_idle_timeout_ms.set_value( + max_idle_timeout_to_send_.ToMilliseconds()); + + if (stateless_reset_token_.HasSendValue()) { + StatelessResetToken stateless_reset_token = + stateless_reset_token_.GetSendValue(); + params->stateless_reset_token.assign( + reinterpret_cast(&stateless_reset_token), + reinterpret_cast(&stateless_reset_token) + + sizeof(stateless_reset_token)); + } + + params->max_udp_payload_size.set_value(GetMaxPacketSizeToSend()); + params->max_datagram_frame_size.set_value(GetMaxDatagramFrameSizeToSend()); + params->initial_max_data.set_value( + GetInitialSessionFlowControlWindowToSend()); + // The max stream data bidirectional transport parameters can be either local + // or remote. A stream is local iff it is initiated by the endpoint that sent + // the transport parameter (see the Transport Parameter Definitions section of + // draft-ietf-quic-transport). In this function we are sending transport + // parameters, so a local stream is one we initiated, which means an outgoing + // stream. + params->initial_max_stream_data_bidi_local.set_value( + GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend()); + params->initial_max_stream_data_bidi_remote.set_value( + GetInitialMaxStreamDataBytesIncomingBidirectionalToSend()); + params->initial_max_stream_data_uni.set_value( + GetInitialMaxStreamDataBytesUnidirectionalToSend()); + params->initial_max_streams_bidi.set_value( + GetMaxBidirectionalStreamsToSend()); + params->initial_max_streams_uni.set_value( + GetMaxUnidirectionalStreamsToSend()); + params->max_ack_delay.set_value(GetMaxAckDelayToSendMs()); + if (min_ack_delay_ms_.HasSendValue()) { + params->min_ack_delay_us.set_value(min_ack_delay_ms_.GetSendValue() * + kNumMicrosPerMilli); + } + params->ack_delay_exponent.set_value(GetAckDelayExponentToSend()); + params->disable_active_migration = + connection_migration_disabled_.HasSendValue() && + connection_migration_disabled_.GetSendValue() != 0; + + if (alternate_server_address_ipv6_.HasSendValue() || + alternate_server_address_ipv4_.HasSendValue()) { + TransportParameters::PreferredAddress preferred_address; + if (alternate_server_address_ipv6_.HasSendValue()) { + preferred_address.ipv6_socket_address = + alternate_server_address_ipv6_.GetSendValue(); + } + if (alternate_server_address_ipv4_.HasSendValue()) { + preferred_address.ipv4_socket_address = + alternate_server_address_ipv4_.GetSendValue(); + } + if (preferred_address_connection_id_and_token_) { + preferred_address.connection_id = + preferred_address_connection_id_and_token_->first; + auto* begin = reinterpret_cast( + &preferred_address_connection_id_and_token_->second); + auto* end = + begin + sizeof(preferred_address_connection_id_and_token_->second); + preferred_address.stateless_reset_token.assign(begin, end); + } + params->preferred_address = + std::make_unique( + preferred_address); + } + + if (active_connection_id_limit_.HasSendValue()) { + params->active_connection_id_limit.set_value( + active_connection_id_limit_.GetSendValue()); + } + + if (initial_source_connection_id_to_send_.has_value()) { + params->initial_source_connection_id = + initial_source_connection_id_to_send_.value(); + } + + if (retry_source_connection_id_to_send_.has_value()) { + params->retry_source_connection_id = + retry_source_connection_id_to_send_.value(); + } + + if (initial_round_trip_time_us_.HasSendValue()) { + params->initial_round_trip_time_us.set_value( + initial_round_trip_time_us_.GetSendValue()); + } + if (connection_options_.HasSendValues() && + !connection_options_.GetSendValues().empty()) { + params->google_connection_options = connection_options_.GetSendValues(); + } + + if (google_handshake_message_to_send_.has_value()) { + params->google_handshake_message = google_handshake_message_to_send_; + } + + params->custom_parameters = custom_transport_parameters_to_send_; + + return true; +} + +QuicErrorCode QuicConfig::ProcessTransportParameters( + const TransportParameters& params, bool is_resumption, + std::string* error_details) { + if (!is_resumption && params.original_destination_connection_id.has_value()) { + received_original_destination_connection_id_ = + params.original_destination_connection_id.value(); + } + + if (params.max_idle_timeout_ms.value() > 0 && + params.max_idle_timeout_ms.value() < + static_cast(max_idle_timeout_to_send_.ToMilliseconds())) { + // An idle timeout of zero indicates it is disabled. + // We also ignore values higher than ours which will cause us to use the + // smallest value between ours and our peer's. + received_max_idle_timeout_ = + QuicTime::Delta::FromMilliseconds(params.max_idle_timeout_ms.value()); + } + + if (!is_resumption && !params.stateless_reset_token.empty()) { + StatelessResetToken stateless_reset_token; + if (params.stateless_reset_token.size() != sizeof(stateless_reset_token)) { + QUIC_BUG(quic_bug_10575_16) << "Bad stateless reset token length " + << params.stateless_reset_token.size(); + *error_details = "Bad stateless reset token length"; + return QUIC_INTERNAL_ERROR; + } + memcpy(&stateless_reset_token, params.stateless_reset_token.data(), + params.stateless_reset_token.size()); + stateless_reset_token_.SetReceivedValue(stateless_reset_token); + } + + if (params.max_udp_payload_size.IsValid()) { + max_udp_payload_size_.SetReceivedValue(params.max_udp_payload_size.value()); + } + + if (params.max_datagram_frame_size.IsValid()) { + max_datagram_frame_size_.SetReceivedValue( + params.max_datagram_frame_size.value()); + } + + initial_session_flow_control_window_bytes_.SetReceivedValue( + params.initial_max_data.value()); + + // IETF QUIC specifies stream IDs and stream counts as 62-bit integers but + // our implementation uses uint32_t to represent them to save memory. + max_bidirectional_streams_.SetReceivedValue( + std::min(params.initial_max_streams_bidi.value(), + std::numeric_limits::max())); + max_unidirectional_streams_.SetReceivedValue( + std::min(params.initial_max_streams_uni.value(), + std::numeric_limits::max())); + + // The max stream data bidirectional transport parameters can be either local + // or remote. A stream is local iff it is initiated by the endpoint that sent + // the transport parameter (see the Transport Parameter Definitions section of + // draft-ietf-quic-transport). However in this function we are processing + // received transport parameters, so a local stream is one initiated by our + // peer, which means an incoming stream. + initial_max_stream_data_bytes_incoming_bidirectional_.SetReceivedValue( + params.initial_max_stream_data_bidi_local.value()); + initial_max_stream_data_bytes_outgoing_bidirectional_.SetReceivedValue( + params.initial_max_stream_data_bidi_remote.value()); + initial_max_stream_data_bytes_unidirectional_.SetReceivedValue( + params.initial_max_stream_data_uni.value()); + + if (!is_resumption) { + max_ack_delay_ms_.SetReceivedValue(params.max_ack_delay.value()); + if (params.ack_delay_exponent.IsValid()) { + ack_delay_exponent_.SetReceivedValue(params.ack_delay_exponent.value()); + } + if (params.preferred_address != nullptr) { + if (params.preferred_address->ipv6_socket_address.port() != 0) { + alternate_server_address_ipv6_.SetReceivedValue( + params.preferred_address->ipv6_socket_address); + } + if (params.preferred_address->ipv4_socket_address.port() != 0) { + alternate_server_address_ipv4_.SetReceivedValue( + params.preferred_address->ipv4_socket_address); + } + // TODO(haoyuewang) Treat 0 length connection ID sent in preferred_address + // as a connection error of type TRANSPORT_PARAMETER_ERROR when server + // fully supports it. + if (!params.preferred_address->connection_id.IsEmpty()) { + preferred_address_connection_id_and_token_ = std::make_pair( + params.preferred_address->connection_id, + *reinterpret_cast( + ¶ms.preferred_address->stateless_reset_token.front())); + } + } + if (params.min_ack_delay_us.value() != 0) { + if (params.min_ack_delay_us.value() > + params.max_ack_delay.value() * kNumMicrosPerMilli) { + *error_details = "MinAckDelay is greater than MaxAckDelay."; + return IETF_QUIC_PROTOCOL_VIOLATION; + } + min_ack_delay_ms_.SetReceivedValue(params.min_ack_delay_us.value() / + kNumMicrosPerMilli); + } + } + + if (params.disable_active_migration) { + connection_migration_disabled_.SetReceivedValue(1u); + } + + active_connection_id_limit_.SetReceivedValue( + params.active_connection_id_limit.value()); + + if (!is_resumption) { + if (params.initial_source_connection_id.has_value()) { + received_initial_source_connection_id_ = + params.initial_source_connection_id.value(); + } + if (params.retry_source_connection_id.has_value()) { + received_retry_source_connection_id_ = + params.retry_source_connection_id.value(); + } + } + + if (params.initial_round_trip_time_us.value() > 0) { + initial_round_trip_time_us_.SetReceivedValue( + params.initial_round_trip_time_us.value()); + } + if (params.google_connection_options.has_value()) { + connection_options_.SetReceivedValues( + params.google_connection_options.value()); + } + if (params.google_handshake_message.has_value()) { + received_google_handshake_message_ = params.google_handshake_message; + } + + received_custom_transport_parameters_ = params.custom_parameters; + + if (!is_resumption) { + negotiated_ = true; + } + *error_details = ""; + return QUIC_NO_ERROR; +} + +void QuicConfig::ClearGoogleHandshakeMessage() { + google_handshake_message_to_send_.reset(); + received_google_handshake_message_.reset(); +} + +absl::optional QuicConfig::GetPreferredAddressToSend( + quiche::IpAddressFamily address_family) const { + if (alternate_server_address_ipv6_.HasSendValue() && + address_family == quiche::IpAddressFamily::IP_V6) { + return alternate_server_address_ipv6_.GetSendValue(); + } + + if (alternate_server_address_ipv4_.HasSendValue() && + address_family == quiche::IpAddressFamily::IP_V4) { + return alternate_server_address_ipv4_.GetSendValue(); + } + return absl::nullopt; +} + +void QuicConfig::ClearAlternateServerAddressToSend( + quiche::IpAddressFamily address_family) { + if (address_family == quiche::IpAddressFamily::IP_V4) { + alternate_server_address_ipv4_.ClearSendValue(); + } else if (address_family == quiche::IpAddressFamily::IP_V6) { + alternate_server_address_ipv6_.ClearSendValue(); + } +} + +} // namespace quic diff --git a/quiche/quic/core/quic_config.h b/quiche/quic/core/quic_config.h new file mode 100644 index 000000000000..f60b4818d66d --- /dev/null +++ b/quiche/quic/core/quic_config.h @@ -0,0 +1,683 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONFIG_H_ +#define QUICHE_QUIC_CORE_QUIC_CONFIG_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicConfigPeer; +} // namespace test + +class CryptoHandshakeMessage; + +// Describes whether or not a given QuicTag is required or optional in the +// handshake message. +enum QuicConfigPresence : uint8_t { + // This negotiable value can be absent from the handshake message. Default + // value is selected as the negotiated value in such a case. + PRESENCE_OPTIONAL, + // This negotiable value is required in the handshake message otherwise the + // Process*Hello function returns an error. + PRESENCE_REQUIRED, +}; + +// Whether the CryptoHandshakeMessage is from the client or server. +enum HelloType { + CLIENT, + SERVER, +}; + +// An abstract base class that stores a value that can be sent in CHLO/SHLO +// message. These values can be OPTIONAL or REQUIRED, depending on |presence_|. +class QUIC_EXPORT_PRIVATE QuicConfigValue { + public: + QuicConfigValue(QuicTag tag, QuicConfigPresence presence); + virtual ~QuicConfigValue(); + + // Serialises tag name and value(s) to |out|. + virtual void ToHandshakeMessage(CryptoHandshakeMessage* out) const = 0; + + // Selects a mutually acceptable value from those offered in |peer_hello| + // and those defined in the subclass. + virtual QuicErrorCode ProcessPeerHello( + const CryptoHandshakeMessage& peer_hello, HelloType hello_type, + std::string* error_details) = 0; + + protected: + const QuicTag tag_; + const QuicConfigPresence presence_; +}; + +// Stores uint32_t from CHLO or SHLO messages that are not negotiated. +class QUIC_EXPORT_PRIVATE QuicFixedUint32 : public QuicConfigValue { + public: + QuicFixedUint32(QuicTag tag, QuicConfigPresence presence); + ~QuicFixedUint32() override; + + bool HasSendValue() const; + + uint32_t GetSendValue() const; + + void SetSendValue(uint32_t value); + + bool HasReceivedValue() const; + + uint32_t GetReceivedValue() const; + + void SetReceivedValue(uint32_t value); + + // If has_send_value is true, serialises |tag_| and |send_value_| to |out|. + void ToHandshakeMessage(CryptoHandshakeMessage* out) const override; + + // Sets |value_| to the corresponding value from |peer_hello_| if it exists. + QuicErrorCode ProcessPeerHello(const CryptoHandshakeMessage& peer_hello, + HelloType hello_type, + std::string* error_details) override; + + private: + bool has_send_value_; + bool has_receive_value_; + uint32_t send_value_; + uint32_t receive_value_; +}; + +// Stores 62bit numbers from handshake messages that unilaterally shared by each +// endpoint. IMPORTANT: these are serialized as 32-bit unsigned integers when +// using QUIC_CRYPTO versions and CryptoHandshakeMessage. +class QUIC_EXPORT_PRIVATE QuicFixedUint62 : public QuicConfigValue { + public: + QuicFixedUint62(QuicTag name, QuicConfigPresence presence); + ~QuicFixedUint62() override; + + bool HasSendValue() const; + + uint64_t GetSendValue() const; + + void SetSendValue(uint64_t value); + + bool HasReceivedValue() const; + + uint64_t GetReceivedValue() const; + + void SetReceivedValue(uint64_t value); + + // If has_send_value is true, serialises |tag_| and |send_value_| to |out|. + // IMPORTANT: this method serializes |send_value_| as an unsigned 32bit + // integer. + void ToHandshakeMessage(CryptoHandshakeMessage* out) const override; + + // Sets |value_| to the corresponding value from |peer_hello_| if it exists. + QuicErrorCode ProcessPeerHello(const CryptoHandshakeMessage& peer_hello, + HelloType hello_type, + std::string* error_details) override; + + private: + bool has_send_value_; + bool has_receive_value_; + uint64_t send_value_; + uint64_t receive_value_; +}; + +// Stores StatelessResetToken from CHLO or SHLO messages that are not +// negotiated. +class QUIC_EXPORT_PRIVATE QuicFixedStatelessResetToken + : public QuicConfigValue { + public: + QuicFixedStatelessResetToken(QuicTag tag, QuicConfigPresence presence); + ~QuicFixedStatelessResetToken() override; + + bool HasSendValue() const; + + const StatelessResetToken& GetSendValue() const; + + void SetSendValue(const StatelessResetToken& value); + + bool HasReceivedValue() const; + + const StatelessResetToken& GetReceivedValue() const; + + void SetReceivedValue(const StatelessResetToken& value); + + // If has_send_value is true, serialises |tag_| and |send_value_| to |out|. + void ToHandshakeMessage(CryptoHandshakeMessage* out) const override; + + // Sets |value_| to the corresponding value from |peer_hello_| if it exists. + QuicErrorCode ProcessPeerHello(const CryptoHandshakeMessage& peer_hello, + HelloType hello_type, + std::string* error_details) override; + + private: + bool has_send_value_; + bool has_receive_value_; + StatelessResetToken send_value_; + StatelessResetToken receive_value_; +}; + +// Stores tag from CHLO or SHLO messages that are not negotiated. +class QUIC_EXPORT_PRIVATE QuicFixedTagVector : public QuicConfigValue { + public: + QuicFixedTagVector(QuicTag name, QuicConfigPresence presence); + QuicFixedTagVector(const QuicFixedTagVector& other); + ~QuicFixedTagVector() override; + + bool HasSendValues() const; + + const QuicTagVector& GetSendValues() const; + + void SetSendValues(const QuicTagVector& values); + + bool HasReceivedValues() const; + + const QuicTagVector& GetReceivedValues() const; + + void SetReceivedValues(const QuicTagVector& values); + + // If has_send_value is true, serialises |tag_vector_| and |send_value_| to + // |out|. + void ToHandshakeMessage(CryptoHandshakeMessage* out) const override; + + // Sets |receive_values_| to the corresponding value from |client_hello_| if + // it exists. + QuicErrorCode ProcessPeerHello(const CryptoHandshakeMessage& peer_hello, + HelloType hello_type, + std::string* error_details) override; + + private: + bool has_send_values_; + bool has_receive_values_; + QuicTagVector send_values_; + QuicTagVector receive_values_; +}; + +// Stores QuicSocketAddress from CHLO or SHLO messages that are not negotiated. +class QUIC_EXPORT_PRIVATE QuicFixedSocketAddress : public QuicConfigValue { + public: + QuicFixedSocketAddress(QuicTag tag, QuicConfigPresence presence); + ~QuicFixedSocketAddress() override; + + bool HasSendValue() const; + + const QuicSocketAddress& GetSendValue() const; + + void SetSendValue(const QuicSocketAddress& value); + + void ClearSendValue(); + + bool HasReceivedValue() const; + + const QuicSocketAddress& GetReceivedValue() const; + + void SetReceivedValue(const QuicSocketAddress& value); + + void ToHandshakeMessage(CryptoHandshakeMessage* out) const override; + + QuicErrorCode ProcessPeerHello(const CryptoHandshakeMessage& peer_hello, + HelloType hello_type, + std::string* error_details) override; + + private: + bool has_send_value_; + bool has_receive_value_; + QuicSocketAddress send_value_; + QuicSocketAddress receive_value_; +}; + +// QuicConfig contains non-crypto configuration options that are negotiated in +// the crypto handshake. +class QUIC_EXPORT_PRIVATE QuicConfig { + public: + QuicConfig(); + QuicConfig(const QuicConfig& other); + ~QuicConfig(); + + void SetConnectionOptionsToSend(const QuicTagVector& connection_options); + + bool HasReceivedConnectionOptions() const; + + void SetGoogleHandshakeMessageToSend(std::string message); + + const absl::optional& GetReceivedGoogleHandshakeMessage() const; + + // Sets initial received connection options. All received connection options + // will be initialized with these fields. Initial received options may only be + // set once per config, prior to the setting of any other options. If options + // have already been set (either by previous calls or via handshake), this + // function does nothing and returns false. + bool SetInitialReceivedConnectionOptions(const QuicTagVector& tags); + + const QuicTagVector& ReceivedConnectionOptions() const; + + bool HasSendConnectionOptions() const; + + const QuicTagVector& SendConnectionOptions() const; + + // Returns true if the client is sending or the server has received a + // connection option. + // TODO(ianswett): Rename to HasClientRequestedSharedOption + bool HasClientSentConnectionOption(QuicTag tag, + Perspective perspective) const; + + void SetClientConnectionOptions( + const QuicTagVector& client_connection_options); + + // Returns true if the client has requested the specified connection option. + // Checks the client connection options if the |perspective| is client and + // connection options if the |perspective| is the server. + bool HasClientRequestedIndependentOption(QuicTag tag, + Perspective perspective) const; + + const QuicTagVector& ClientRequestedIndependentOptions( + Perspective perspective) const; + + void SetIdleNetworkTimeout(QuicTime::Delta idle_network_timeout); + + QuicTime::Delta IdleNetworkTimeout() const; + + // Sets the max bidirectional stream count that this endpoint supports. + void SetMaxBidirectionalStreamsToSend(uint32_t max_streams); + uint32_t GetMaxBidirectionalStreamsToSend() const; + + bool HasReceivedMaxBidirectionalStreams() const; + // Gets the max bidirectional stream limit imposed by the peer. + uint32_t ReceivedMaxBidirectionalStreams() const; + + // Sets the max unidirectional stream count that this endpoint supports. + void SetMaxUnidirectionalStreamsToSend(uint32_t max_streams); + uint32_t GetMaxUnidirectionalStreamsToSend() const; + + bool HasReceivedMaxUnidirectionalStreams() const; + // Gets the max unidirectional stream limit imposed by the peer. + uint32_t ReceivedMaxUnidirectionalStreams() const; + + void set_max_time_before_crypto_handshake( + QuicTime::Delta max_time_before_crypto_handshake) { + max_time_before_crypto_handshake_ = max_time_before_crypto_handshake; + } + + QuicTime::Delta max_time_before_crypto_handshake() const { + return max_time_before_crypto_handshake_; + } + + void set_max_idle_time_before_crypto_handshake( + QuicTime::Delta max_idle_time_before_crypto_handshake) { + max_idle_time_before_crypto_handshake_ = + max_idle_time_before_crypto_handshake; + } + + QuicTime::Delta max_idle_time_before_crypto_handshake() const { + return max_idle_time_before_crypto_handshake_; + } + + void set_max_undecryptable_packets(size_t max_undecryptable_packets) { + max_undecryptable_packets_ = max_undecryptable_packets; + } + + size_t max_undecryptable_packets() const { + return max_undecryptable_packets_; + } + + // Peer's connection id length, in bytes. Only used in Q043 and Q046. + bool HasSetBytesForConnectionIdToSend() const; + void SetBytesForConnectionIdToSend(uint32_t bytes); + bool HasReceivedBytesForConnectionId() const; + uint32_t ReceivedBytesForConnectionId() const; + + // Estimated initial round trip time in us. + void SetInitialRoundTripTimeUsToSend(uint64_t rtt_us); + bool HasReceivedInitialRoundTripTimeUs() const; + uint64_t ReceivedInitialRoundTripTimeUs() const; + bool HasInitialRoundTripTimeUsToSend() const; + uint64_t GetInitialRoundTripTimeUsToSend() const; + + // Sets an initial stream flow control window size to transmit to the peer. + void SetInitialStreamFlowControlWindowToSend(uint64_t window_bytes); + uint64_t GetInitialStreamFlowControlWindowToSend() const; + bool HasReceivedInitialStreamFlowControlWindowBytes() const; + uint64_t ReceivedInitialStreamFlowControlWindowBytes() const; + + // Specifies the initial flow control window (max stream data) for + // incoming bidirectional streams. Incoming means streams initiated by our + // peer. If not set, GetInitialMaxStreamDataBytesIncomingBidirectionalToSend + // returns the value passed to SetInitialStreamFlowControlWindowToSend. + void SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + uint64_t window_bytes); + uint64_t GetInitialMaxStreamDataBytesIncomingBidirectionalToSend() const; + bool HasReceivedInitialMaxStreamDataBytesIncomingBidirectional() const; + uint64_t ReceivedInitialMaxStreamDataBytesIncomingBidirectional() const; + + // Specifies the initial flow control window (max stream data) for + // outgoing bidirectional streams. Outgoing means streams initiated by us. + // If not set, GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend + // returns the value passed to SetInitialStreamFlowControlWindowToSend. + void SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend( + uint64_t window_bytes); + uint64_t GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend() const; + bool HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional() const; + uint64_t ReceivedInitialMaxStreamDataBytesOutgoingBidirectional() const; + + // Specifies the initial flow control window (max stream data) for + // unidirectional streams. If not set, + // GetInitialMaxStreamDataBytesUnidirectionalToSend returns the value passed + // to SetInitialStreamFlowControlWindowToSend. + void SetInitialMaxStreamDataBytesUnidirectionalToSend(uint64_t window_bytes); + uint64_t GetInitialMaxStreamDataBytesUnidirectionalToSend() const; + bool HasReceivedInitialMaxStreamDataBytesUnidirectional() const; + uint64_t ReceivedInitialMaxStreamDataBytesUnidirectional() const; + + // Sets an initial session flow control window size to transmit to the peer. + void SetInitialSessionFlowControlWindowToSend(uint64_t window_bytes); + uint64_t GetInitialSessionFlowControlWindowToSend() const; + bool HasReceivedInitialSessionFlowControlWindowBytes() const; + uint64_t ReceivedInitialSessionFlowControlWindowBytes() const; + + // Disable connection migration. + void SetDisableConnectionMigration(); + bool DisableConnectionMigration() const; + + // IPv6 alternate server address. + void SetIPv6AlternateServerAddressToSend( + const QuicSocketAddress& alternate_server_address_ipv6); + bool HasReceivedIPv6AlternateServerAddress() const; + const QuicSocketAddress& ReceivedIPv6AlternateServerAddress() const; + + // IPv4 alternate server address. + void SetIPv4AlternateServerAddressToSend( + const QuicSocketAddress& alternate_server_address_ipv4); + bool HasReceivedIPv4AlternateServerAddress() const; + const QuicSocketAddress& ReceivedIPv4AlternateServerAddress() const; + + // Called to set |connection_id| and |stateless_reset_token| if server + // preferred address has been set via SetIPv(4|6)AlternateServerAddressToSend. + // Please note, this is different from SetStatelessResetTokenToSend(const + // StatelessResetToken&) which is used to send the token corresponding to the + // existing server_connection_id. + void SetPreferredAddressConnectionIdAndTokenToSend( + const QuicConnectionId& connection_id, + const StatelessResetToken& stateless_reset_token); + + // Preferred Address Connection ID and Token. + bool HasReceivedPreferredAddressConnectionIdAndToken() const; + const std::pair& + ReceivedPreferredAddressConnectionIdAndToken() const; + absl::optional GetPreferredAddressToSend( + quiche::IpAddressFamily address_family) const; + void ClearAlternateServerAddressToSend( + quiche::IpAddressFamily address_family); + + // Original destination connection ID. + void SetOriginalConnectionIdToSend( + const QuicConnectionId& original_destination_connection_id); + bool HasReceivedOriginalConnectionId() const; + QuicConnectionId ReceivedOriginalConnectionId() const; + + // Stateless reset token. + void SetStatelessResetTokenToSend( + const StatelessResetToken& stateless_reset_token); + bool HasStatelessResetTokenToSend() const; + bool HasReceivedStatelessResetToken() const; + const StatelessResetToken& ReceivedStatelessResetToken() const; + + // Manage the IETF QUIC Max ACK Delay transport parameter. + // The sent value is the delay that this node uses + // (QuicSentPacketManager::local_max_ack_delay_). + // The received delay is the value received from + // the peer (QuicSentPacketManager::peer_max_ack_delay_). + void SetMaxAckDelayToSendMs(uint32_t max_ack_delay_ms); + uint32_t GetMaxAckDelayToSendMs() const; + bool HasReceivedMaxAckDelayMs() const; + uint32_t ReceivedMaxAckDelayMs() const; + + // Manage the IETF QUIC extension Min Ack Delay transport parameter. + // An endpoint uses min_ack_delay to advsertise its support for + // AckFrequencyFrame sent by peer. + void SetMinAckDelayMs(uint32_t min_ack_delay_ms); + uint32_t GetMinAckDelayToSendMs() const; + bool HasReceivedMinAckDelayMs() const; + uint32_t ReceivedMinAckDelayMs() const; + + void SetAckDelayExponentToSend(uint32_t exponent); + uint32_t GetAckDelayExponentToSend() const; + bool HasReceivedAckDelayExponent() const; + uint32_t ReceivedAckDelayExponent() const; + + // IETF QUIC max_udp_payload_size transport parameter. + void SetMaxPacketSizeToSend(uint64_t max_udp_payload_size); + uint64_t GetMaxPacketSizeToSend() const; + bool HasReceivedMaxPacketSize() const; + uint64_t ReceivedMaxPacketSize() const; + + // IETF QUIC max_datagram_frame_size transport parameter. + void SetMaxDatagramFrameSizeToSend(uint64_t max_datagram_frame_size); + uint64_t GetMaxDatagramFrameSizeToSend() const; + bool HasReceivedMaxDatagramFrameSize() const; + uint64_t ReceivedMaxDatagramFrameSize() const; + + // IETF QUIC active_connection_id_limit transport parameter. + void SetActiveConnectionIdLimitToSend(uint64_t active_connection_id_limit); + uint64_t GetActiveConnectionIdLimitToSend() const; + bool HasReceivedActiveConnectionIdLimit() const; + uint64_t ReceivedActiveConnectionIdLimit() const; + + // Initial source connection ID. + void SetInitialSourceConnectionIdToSend( + const QuicConnectionId& initial_source_connection_id); + bool HasReceivedInitialSourceConnectionId() const; + QuicConnectionId ReceivedInitialSourceConnectionId() const; + + // Retry source connection ID. + void SetRetrySourceConnectionIdToSend( + const QuicConnectionId& retry_source_connection_id); + bool HasReceivedRetrySourceConnectionId() const; + QuicConnectionId ReceivedRetrySourceConnectionId() const; + + bool negotiated() const; + + void SetCreateSessionTagIndicators(QuicTagVector tags); + + const QuicTagVector& create_session_tag_indicators() const; + + // ToHandshakeMessage serialises the settings in this object as a series of + // tags /value pairs and adds them to |out|. + void ToHandshakeMessage(CryptoHandshakeMessage* out, + QuicTransportVersion transport_version) const; + + // Calls ProcessPeerHello on each negotiable parameter. On failure returns + // the corresponding QuicErrorCode and sets detailed error in |error_details|. + QuicErrorCode ProcessPeerHello(const CryptoHandshakeMessage& peer_hello, + HelloType hello_type, + std::string* error_details); + + // FillTransportParameters writes the values to send for ICSL, MIDS, CFCW, and + // SFCW to |*params|, returning true if the values could be written and false + // if something prevents them from being written (e.g. a value is too large). + bool FillTransportParameters(TransportParameters* params) const; + + // ProcessTransportParameters reads from |params| which were received from a + // peer. If |is_resumption|, some configs will not be processed. + // On failure, it returns a QuicErrorCode and puts a detailed error in + // |*error_details|. + QuicErrorCode ProcessTransportParameters(const TransportParameters& params, + bool is_resumption, + std::string* error_details); + + TransportParameters::ParameterMap& custom_transport_parameters_to_send() { + return custom_transport_parameters_to_send_; + } + const TransportParameters::ParameterMap& + received_custom_transport_parameters() const { + return received_custom_transport_parameters_; + } + + // Called to clear google_handshake_message to send or received. + void ClearGoogleHandshakeMessage(); + + private: + friend class test::QuicConfigPeer; + + // SetDefaults sets the members to sensible, default values. + void SetDefaults(); + + // Whether we've received the peer's config. + bool negotiated_; + + // Configurations options that are not negotiated. + // Maximum time the session can be alive before crypto handshake is finished. + QuicTime::Delta max_time_before_crypto_handshake_; + // Maximum idle time before the crypto handshake has completed. + QuicTime::Delta max_idle_time_before_crypto_handshake_; + // Maximum number of undecryptable packets stored before CHLO/SHLO. + size_t max_undecryptable_packets_; + + // Connection options which affect the server side. May also affect the + // client side in cases when identical behavior is desirable. + QuicFixedTagVector connection_options_; + // Connection options which only affect the client side. + QuicFixedTagVector client_connection_options_; + // Maximum idle network timeout. + // Uses the max_idle_timeout transport parameter in IETF QUIC. + // Note that received_max_idle_timeout_ is only populated if we receive the + // peer's value, which isn't guaranteed in IETF QUIC as sending is optional. + QuicTime::Delta max_idle_timeout_to_send_; + absl::optional received_max_idle_timeout_; + // Maximum number of dynamic streams that a Google QUIC connection + // can support or the maximum number of bidirectional streams that + // an IETF QUIC connection can support. + // The SendValue is the limit on peer-created streams that this endpoint is + // advertising. + // The ReceivedValue is the limit on locally-created streams that + // the peer advertised. + // Uses the initial_max_streams_bidi transport parameter in IETF QUIC. + QuicFixedUint32 max_bidirectional_streams_; + // Maximum number of unidirectional streams that the connection can + // support. + // The SendValue is the limit on peer-created streams that this endpoint is + // advertising. + // The ReceivedValue is the limit on locally-created streams that the peer + // advertised. + // Uses the initial_max_streams_uni transport parameter in IETF QUIC. + QuicFixedUint32 max_unidirectional_streams_; + // The number of bytes required for the connection ID. This is only used in + // the legacy header format used only by Q043 at this point. + QuicFixedUint32 bytes_for_connection_id_; + // Initial round trip time estimate in microseconds. + QuicFixedUint62 initial_round_trip_time_us_; + + // Initial IETF QUIC stream flow control receive windows in bytes. + // Incoming bidirectional streams. + // Uses the initial_max_stream_data_bidi_{local,remote} transport parameter + // in IETF QUIC, depending on whether we're sending or receiving. + QuicFixedUint62 initial_max_stream_data_bytes_incoming_bidirectional_; + // Outgoing bidirectional streams. + // Uses the initial_max_stream_data_bidi_{local,remote} transport parameter + // in IETF QUIC, depending on whether we're sending or receiving. + QuicFixedUint62 initial_max_stream_data_bytes_outgoing_bidirectional_; + // Unidirectional streams. + // Uses the initial_max_stream_data_uni transport parameter in IETF QUIC. + QuicFixedUint62 initial_max_stream_data_bytes_unidirectional_; + + // Initial Google QUIC stream flow control receive window in bytes. + QuicFixedUint62 initial_stream_flow_control_window_bytes_; + + // Initial session flow control receive window in bytes. + // Uses the initial_max_data transport parameter in IETF QUIC. + QuicFixedUint62 initial_session_flow_control_window_bytes_; + + // Whether active connection migration is allowed. + // Uses the disable_active_migration transport parameter in IETF QUIC. + QuicFixedUint32 connection_migration_disabled_; + + // Alternate server addresses the client could connect to. + // Uses the preferred_address transport parameter in IETF QUIC. + // Note that when QUIC_CRYPTO is in use, only one of the addresses is sent. + QuicFixedSocketAddress alternate_server_address_ipv6_; + QuicFixedSocketAddress alternate_server_address_ipv4_; + // Connection Id data to send from the server or receive at the client as part + // of the preferred address transport parameter. + absl::optional> + preferred_address_connection_id_and_token_; + + // Stateless reset token used in IETF public reset packet. + // Uses the stateless_reset_token transport parameter in IETF QUIC. + QuicFixedStatelessResetToken stateless_reset_token_; + + // List of QuicTags whose presence immediately causes the session to + // be created. This allows for CHLOs that are larger than a single + // packet to be processed. + QuicTagVector create_session_tag_indicators_; + + // Maximum ack delay. The sent value is the value used on this node. + // The received value is the value received from the peer and used by + // the peer. + // Uses the max_ack_delay transport parameter in IETF QUIC. + QuicFixedUint32 max_ack_delay_ms_; + + // Minimum ack delay. Used to enable sender control of max_ack_delay. + // Uses the min_ack_delay transport parameter in IETF QUIC extension. + QuicFixedUint32 min_ack_delay_ms_; + + // The sent exponent is the exponent that this node uses when serializing an + // ACK frame (and the peer should use when deserializing the frame); + // the received exponent is the value the peer uses to serialize frames and + // this node uses to deserialize them. + // Uses the ack_delay_exponent transport parameter in IETF QUIC. + QuicFixedUint32 ack_delay_exponent_; + + // Maximum packet size in bytes. + // Uses the max_udp_payload_size transport parameter in IETF QUIC. + QuicFixedUint62 max_udp_payload_size_; + + // Maximum DATAGRAM/MESSAGE frame size in bytes. + // Uses the max_datagram_frame_size transport parameter in IETF QUIC. + QuicFixedUint62 max_datagram_frame_size_; + + // Maximum number of connection IDs from the peer. + // Uses the active_connection_id_limit transport parameter in IETF QUIC. + QuicFixedUint62 active_connection_id_limit_; + + // The value of the Destination Connection ID field from the first + // Initial packet sent by the client. + // Uses the original_destination_connection_id transport parameter in + // IETF QUIC. + absl::optional original_destination_connection_id_to_send_; + absl::optional received_original_destination_connection_id_; + + // The value that the endpoint included in the Source Connection ID field of + // the first Initial packet it sent. + // Uses the initial_source_connection_id transport parameter in IETF QUIC. + absl::optional initial_source_connection_id_to_send_; + absl::optional received_initial_source_connection_id_; + + // The value that the server included in the Source Connection ID field of a + // Retry packet it sent. + // Uses the retry_source_connection_id transport parameter in IETF QUIC. + absl::optional retry_source_connection_id_to_send_; + absl::optional received_retry_source_connection_id_; + + // Custom transport parameters that can be sent and received in the TLS + // handshake. + TransportParameters::ParameterMap custom_transport_parameters_to_send_; + TransportParameters::ParameterMap received_custom_transport_parameters_; + + // Google internal handshake message. + absl::optional google_handshake_message_to_send_; + absl::optional received_google_handshake_message_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONFIG_H_ diff --git a/quiche/quic/core/quic_config_test.cc b/quiche/quic/core/quic_config_test.cc new file mode 100644 index 000000000000..86e5a996d26b --- /dev/null +++ b/quiche/quic/core/quic_config_test.cc @@ -0,0 +1,776 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_config.h" + +#include +#include + +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +class QuicConfigTest : public QuicTestWithParam { + public: + QuicConfigTest() : version_(GetParam()) {} + + protected: + ParsedQuicVersion version_; + QuicConfig config_; +}; + +// Run all tests with all versions of QUIC. +INSTANTIATE_TEST_SUITE_P(QuicConfigTests, QuicConfigTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicConfigTest, SetDefaults) { + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialStreamFlowControlWindowToSend()); + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialMaxStreamDataBytesIncomingBidirectionalToSend()); + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend()); + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialMaxStreamDataBytesUnidirectionalToSend()); + EXPECT_FALSE(config_.HasReceivedInitialStreamFlowControlWindowBytes()); + EXPECT_FALSE( + config_.HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + EXPECT_FALSE( + config_.HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + EXPECT_FALSE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()); + EXPECT_EQ(kMaxIncomingPacketSize, config_.GetMaxPacketSizeToSend()); + EXPECT_FALSE(config_.HasReceivedMaxPacketSize()); +} + +TEST_P(QuicConfigTest, AutoSetIetfFlowControl) { + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialStreamFlowControlWindowToSend()); + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialMaxStreamDataBytesIncomingBidirectionalToSend()); + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend()); + EXPECT_EQ(kMinimumFlowControlSendWindow, + config_.GetInitialMaxStreamDataBytesUnidirectionalToSend()); + static const uint32_t kTestWindowSize = 1234567; + config_.SetInitialStreamFlowControlWindowToSend(kTestWindowSize); + EXPECT_EQ(kTestWindowSize, config_.GetInitialStreamFlowControlWindowToSend()); + EXPECT_EQ(kTestWindowSize, + config_.GetInitialMaxStreamDataBytesIncomingBidirectionalToSend()); + EXPECT_EQ(kTestWindowSize, + config_.GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend()); + EXPECT_EQ(kTestWindowSize, + config_.GetInitialMaxStreamDataBytesUnidirectionalToSend()); + static const uint32_t kTestWindowSizeTwo = 2345678; + config_.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + kTestWindowSizeTwo); + EXPECT_EQ(kTestWindowSize, config_.GetInitialStreamFlowControlWindowToSend()); + EXPECT_EQ(kTestWindowSizeTwo, + config_.GetInitialMaxStreamDataBytesIncomingBidirectionalToSend()); + EXPECT_EQ(kTestWindowSize, + config_.GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend()); + EXPECT_EQ(kTestWindowSize, + config_.GetInitialMaxStreamDataBytesUnidirectionalToSend()); +} + +TEST_P(QuicConfigTest, ToHandshakeMessage) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + config_.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + config_.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + config_.SetIdleNetworkTimeout(QuicTime::Delta::FromSeconds(5)); + CryptoHandshakeMessage msg; + config_.ToHandshakeMessage(&msg, version_.transport_version); + + uint32_t value; + QuicErrorCode error = msg.GetUint32(kICSL, &value); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_EQ(5u, value); + + error = msg.GetUint32(kSFCW, &value); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_EQ(kInitialStreamFlowControlWindowForTest, value); + + error = msg.GetUint32(kCFCW, &value); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, value); +} + +TEST_P(QuicConfigTest, ProcessClientHello) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + const uint32_t kTestMaxAckDelayMs = + static_cast(kDefaultDelayedAckTimeMs + 1); + QuicConfig client_config; + QuicTagVector cgst; + cgst.push_back(kQBIC); + client_config.SetIdleNetworkTimeout( + QuicTime::Delta::FromSeconds(2 * kMaximumIdleTimeoutSecs)); + client_config.SetInitialRoundTripTimeUsToSend(10 * kNumMicrosPerMilli); + client_config.SetInitialStreamFlowControlWindowToSend( + 2 * kInitialStreamFlowControlWindowForTest); + client_config.SetInitialSessionFlowControlWindowToSend( + 2 * kInitialSessionFlowControlWindowForTest); + QuicTagVector copt; + copt.push_back(kTBBR); + client_config.SetConnectionOptionsToSend(copt); + client_config.SetMaxAckDelayToSendMs(kTestMaxAckDelayMs); + CryptoHandshakeMessage msg; + client_config.ToHandshakeMessage(&msg, version_.transport_version); + + std::string error_details; + QuicTagVector initial_received_options; + initial_received_options.push_back(kIW50); + EXPECT_TRUE( + config_.SetInitialReceivedConnectionOptions(initial_received_options)); + EXPECT_FALSE( + config_.SetInitialReceivedConnectionOptions(initial_received_options)) + << "You can only set initial options once."; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_FALSE( + config_.SetInitialReceivedConnectionOptions(initial_received_options)) + << "You cannot set initial options after the hello."; + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs), + config_.IdleNetworkTimeout()); + EXPECT_EQ(10 * kNumMicrosPerMilli, config_.ReceivedInitialRoundTripTimeUs()); + EXPECT_TRUE(config_.HasReceivedConnectionOptions()); + EXPECT_EQ(2u, config_.ReceivedConnectionOptions().size()); + EXPECT_EQ(config_.ReceivedConnectionOptions()[0], kIW50); + EXPECT_EQ(config_.ReceivedConnectionOptions()[1], kTBBR); + EXPECT_EQ(config_.ReceivedInitialStreamFlowControlWindowBytes(), + 2 * kInitialStreamFlowControlWindowForTest); + EXPECT_EQ(config_.ReceivedInitialSessionFlowControlWindowBytes(), + 2 * kInitialSessionFlowControlWindowForTest); + EXPECT_TRUE(config_.HasReceivedMaxAckDelayMs()); + EXPECT_EQ(kTestMaxAckDelayMs, config_.ReceivedMaxAckDelayMs()); + + // IETF QUIC stream limits should not be received in QUIC crypto messages. + EXPECT_FALSE( + config_.HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + EXPECT_FALSE( + config_.HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + EXPECT_FALSE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()); +} + +TEST_P(QuicConfigTest, ProcessServerHello) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + QuicIpAddress host; + host.FromString("127.0.3.1"); + const QuicSocketAddress kTestServerAddress = QuicSocketAddress(host, 1234); + const StatelessResetToken kTestStatelessResetToken{ + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f}; + const uint32_t kTestMaxAckDelayMs = + static_cast(kDefaultDelayedAckTimeMs + 1); + QuicConfig server_config; + QuicTagVector cgst; + cgst.push_back(kQBIC); + server_config.SetIdleNetworkTimeout( + QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs / 2)); + server_config.SetInitialRoundTripTimeUsToSend(10 * kNumMicrosPerMilli); + server_config.SetInitialStreamFlowControlWindowToSend( + 2 * kInitialStreamFlowControlWindowForTest); + server_config.SetInitialSessionFlowControlWindowToSend( + 2 * kInitialSessionFlowControlWindowForTest); + server_config.SetIPv4AlternateServerAddressToSend(kTestServerAddress); + server_config.SetStatelessResetTokenToSend(kTestStatelessResetToken); + server_config.SetMaxAckDelayToSendMs(kTestMaxAckDelayMs); + CryptoHandshakeMessage msg; + server_config.ToHandshakeMessage(&msg, version_.transport_version); + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, SERVER, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs / 2), + config_.IdleNetworkTimeout()); + EXPECT_EQ(10 * kNumMicrosPerMilli, config_.ReceivedInitialRoundTripTimeUs()); + EXPECT_EQ(config_.ReceivedInitialStreamFlowControlWindowBytes(), + 2 * kInitialStreamFlowControlWindowForTest); + EXPECT_EQ(config_.ReceivedInitialSessionFlowControlWindowBytes(), + 2 * kInitialSessionFlowControlWindowForTest); + EXPECT_TRUE(config_.HasReceivedIPv4AlternateServerAddress()); + EXPECT_EQ(kTestServerAddress, config_.ReceivedIPv4AlternateServerAddress()); + EXPECT_FALSE(config_.HasReceivedIPv6AlternateServerAddress()); + EXPECT_TRUE(config_.HasReceivedStatelessResetToken()); + EXPECT_EQ(kTestStatelessResetToken, config_.ReceivedStatelessResetToken()); + EXPECT_TRUE(config_.HasReceivedMaxAckDelayMs()); + EXPECT_EQ(kTestMaxAckDelayMs, config_.ReceivedMaxAckDelayMs()); + + // IETF QUIC stream limits should not be received in QUIC crypto messages. + EXPECT_FALSE( + config_.HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + EXPECT_FALSE( + config_.HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + EXPECT_FALSE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()); +} + +TEST_P(QuicConfigTest, MissingOptionalValuesInCHLO) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + CryptoHandshakeMessage msg; + msg.SetValue(kICSL, 1); + + // Set all REQUIRED tags. + msg.SetValue(kICSL, 1); + msg.SetValue(kMIBS, 1); + + // No error, as rest are optional. + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); +} + +TEST_P(QuicConfigTest, MissingOptionalValuesInSHLO) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + CryptoHandshakeMessage msg; + + // Set all REQUIRED tags. + msg.SetValue(kICSL, 1); + msg.SetValue(kMIBS, 1); + + // No error, as rest are optional. + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, SERVER, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); +} + +TEST_P(QuicConfigTest, MissingValueInCHLO) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + // Server receives CHLO with missing kICSL. + CryptoHandshakeMessage msg; + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsError(QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND)); +} + +TEST_P(QuicConfigTest, MissingValueInSHLO) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + // Client receives SHLO with missing kICSL. + CryptoHandshakeMessage msg; + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, SERVER, &error_details); + EXPECT_THAT(error, IsError(QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND)); +} + +TEST_P(QuicConfigTest, OutOfBoundSHLO) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + QuicConfig server_config; + server_config.SetIdleNetworkTimeout( + QuicTime::Delta::FromSeconds(2 * kMaximumIdleTimeoutSecs)); + + CryptoHandshakeMessage msg; + server_config.ToHandshakeMessage(&msg, version_.transport_version); + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, SERVER, &error_details); + EXPECT_THAT(error, IsError(QUIC_INVALID_NEGOTIATED_VALUE)); +} + +TEST_P(QuicConfigTest, InvalidFlowControlWindow) { + // QuicConfig should not accept an invalid flow control window to send to the + // peer: the receive window must be at least the default of 16 Kb. + QuicConfig config; + const uint64_t kInvalidWindow = kMinimumFlowControlSendWindow - 1; + EXPECT_QUIC_BUG( + config.SetInitialStreamFlowControlWindowToSend(kInvalidWindow), + "Initial stream flow control receive window"); + + EXPECT_EQ(kMinimumFlowControlSendWindow, + config.GetInitialStreamFlowControlWindowToSend()); +} + +TEST_P(QuicConfigTest, HasClientSentConnectionOption) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + QuicConfig client_config; + QuicTagVector copt; + copt.push_back(kTBBR); + client_config.SetConnectionOptionsToSend(copt); + EXPECT_TRUE(client_config.HasClientSentConnectionOption( + kTBBR, Perspective::IS_CLIENT)); + + CryptoHandshakeMessage msg; + client_config.ToHandshakeMessage(&msg, version_.transport_version); + + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); + + EXPECT_TRUE(config_.HasReceivedConnectionOptions()); + EXPECT_EQ(1u, config_.ReceivedConnectionOptions().size()); + EXPECT_TRUE( + config_.HasClientSentConnectionOption(kTBBR, Perspective::IS_SERVER)); +} + +TEST_P(QuicConfigTest, DontSendClientConnectionOptions) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + QuicConfig client_config; + QuicTagVector copt; + copt.push_back(kTBBR); + client_config.SetClientConnectionOptions(copt); + + CryptoHandshakeMessage msg; + client_config.ToHandshakeMessage(&msg, version_.transport_version); + + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); + + EXPECT_FALSE(config_.HasReceivedConnectionOptions()); +} + +TEST_P(QuicConfigTest, HasClientRequestedIndependentOption) { + if (version_.UsesTls()) { + // CryptoHandshakeMessage is only used for QUIC_CRYPTO. + return; + } + QuicConfig client_config; + QuicTagVector client_opt; + client_opt.push_back(kRENO); + QuicTagVector copt; + copt.push_back(kTBBR); + client_config.SetClientConnectionOptions(client_opt); + client_config.SetConnectionOptionsToSend(copt); + EXPECT_TRUE(client_config.HasClientSentConnectionOption( + kTBBR, Perspective::IS_CLIENT)); + EXPECT_TRUE(client_config.HasClientRequestedIndependentOption( + kRENO, Perspective::IS_CLIENT)); + EXPECT_FALSE(client_config.HasClientRequestedIndependentOption( + kTBBR, Perspective::IS_CLIENT)); + + CryptoHandshakeMessage msg; + client_config.ToHandshakeMessage(&msg, version_.transport_version); + + std::string error_details; + const QuicErrorCode error = + config_.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + EXPECT_TRUE(config_.negotiated()); + + EXPECT_TRUE(config_.HasReceivedConnectionOptions()); + EXPECT_EQ(1u, config_.ReceivedConnectionOptions().size()); + EXPECT_FALSE(config_.HasClientRequestedIndependentOption( + kRENO, Perspective::IS_SERVER)); + EXPECT_TRUE(config_.HasClientRequestedIndependentOption( + kTBBR, Perspective::IS_SERVER)); +} + +TEST_P(QuicConfigTest, IncomingLargeIdleTimeoutTransportParameter) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + // Configure our idle timeout to 60s, then receive 120s from peer. + // Since the received value is above ours, we should then use ours. + config_.SetIdleNetworkTimeout(quic::QuicTime::Delta::FromSeconds(60)); + TransportParameters params; + params.max_idle_timeout_ms.set_value(120000); + + std::string error_details = "foobar"; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_EQ("", error_details); + EXPECT_EQ(quic::QuicTime::Delta::FromSeconds(60), + config_.IdleNetworkTimeout()); +} + +TEST_P(QuicConfigTest, ReceivedInvalidMinAckDelayInTransportParameter) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + TransportParameters params; + + params.max_ack_delay.set_value(25 /*ms*/); + params.min_ack_delay_us.set_value(25 * kNumMicrosPerMilli + 1); + std::string error_details = "foobar"; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + EXPECT_EQ("MinAckDelay is greater than MaxAckDelay.", error_details); + + params.max_ack_delay.set_value(25 /*ms*/); + params.min_ack_delay_us.set_value(25 * kNumMicrosPerMilli); + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_TRUE(error_details.empty()); +} + +TEST_P(QuicConfigTest, FillTransportParams) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + const std::string kFakeGoogleHandshakeMessage = "Fake handshake message"; + config_.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + 2 * kMinimumFlowControlSendWindow); + config_.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend( + 3 * kMinimumFlowControlSendWindow); + config_.SetInitialMaxStreamDataBytesUnidirectionalToSend( + 4 * kMinimumFlowControlSendWindow); + config_.SetMaxPacketSizeToSend(kMaxPacketSizeForTest); + config_.SetMaxDatagramFrameSizeToSend(kMaxDatagramFrameSizeForTest); + config_.SetActiveConnectionIdLimitToSend(kActiveConnectionIdLimitForTest); + + config_.SetOriginalConnectionIdToSend(TestConnectionId(0x1111)); + config_.SetInitialSourceConnectionIdToSend(TestConnectionId(0x2222)); + config_.SetRetrySourceConnectionIdToSend(TestConnectionId(0x3333)); + config_.SetMinAckDelayMs(kDefaultMinAckDelayTimeMs); + config_.SetGoogleHandshakeMessageToSend(kFakeGoogleHandshakeMessage); + + QuicIpAddress host; + host.FromString("127.0.3.1"); + QuicSocketAddress kTestServerAddress = QuicSocketAddress(host, 1234); + QuicConnectionId new_connection_id = TestConnectionId(5); + StatelessResetToken new_stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(new_connection_id); + config_.SetIPv4AlternateServerAddressToSend(kTestServerAddress); + QuicSocketAddress kTestServerAddressV6 = + QuicSocketAddress(QuicIpAddress::Any6(), 1234); + config_.SetIPv6AlternateServerAddressToSend(kTestServerAddressV6); + config_.SetPreferredAddressConnectionIdAndTokenToSend( + new_connection_id, new_stateless_reset_token); + config_.ClearAlternateServerAddressToSend(quiche::IpAddressFamily::IP_V6); + EXPECT_TRUE(config_.GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V4) + .has_value()); + EXPECT_FALSE(config_.GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V6) + .has_value()); + + TransportParameters params; + config_.FillTransportParameters(¶ms); + + EXPECT_EQ(2 * kMinimumFlowControlSendWindow, + params.initial_max_stream_data_bidi_remote.value()); + EXPECT_EQ(3 * kMinimumFlowControlSendWindow, + params.initial_max_stream_data_bidi_local.value()); + EXPECT_EQ(4 * kMinimumFlowControlSendWindow, + params.initial_max_stream_data_uni.value()); + + EXPECT_EQ(static_cast(kMaximumIdleTimeoutSecs * 1000), + params.max_idle_timeout_ms.value()); + + EXPECT_EQ(kMaxPacketSizeForTest, params.max_udp_payload_size.value()); + EXPECT_EQ(kMaxDatagramFrameSizeForTest, + params.max_datagram_frame_size.value()); + EXPECT_EQ(kActiveConnectionIdLimitForTest, + params.active_connection_id_limit.value()); + + ASSERT_TRUE(params.original_destination_connection_id.has_value()); + EXPECT_EQ(TestConnectionId(0x1111), + params.original_destination_connection_id.value()); + ASSERT_TRUE(params.initial_source_connection_id.has_value()); + EXPECT_EQ(TestConnectionId(0x2222), + params.initial_source_connection_id.value()); + ASSERT_TRUE(params.retry_source_connection_id.has_value()); + EXPECT_EQ(TestConnectionId(0x3333), + params.retry_source_connection_id.value()); + + EXPECT_EQ( + static_cast(kDefaultMinAckDelayTimeMs) * kNumMicrosPerMilli, + params.min_ack_delay_us.value()); + + EXPECT_EQ(params.preferred_address->ipv4_socket_address, kTestServerAddress); + EXPECT_EQ(params.preferred_address->ipv6_socket_address, + QuicSocketAddress(QuicIpAddress::Any6(), 0)); + + EXPECT_EQ(*reinterpret_cast( + ¶ms.preferred_address->stateless_reset_token.front()), + new_stateless_reset_token); + EXPECT_EQ(kFakeGoogleHandshakeMessage, params.google_handshake_message); +} + +TEST_P(QuicConfigTest, FillTransportParamsNoV4PreferredAddress) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + + QuicIpAddress host; + host.FromString("127.0.3.1"); + QuicSocketAddress kTestServerAddress = QuicSocketAddress(host, 1234); + QuicConnectionId new_connection_id = TestConnectionId(5); + StatelessResetToken new_stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(new_connection_id); + config_.SetIPv4AlternateServerAddressToSend(kTestServerAddress); + QuicSocketAddress kTestServerAddressV6 = + QuicSocketAddress(QuicIpAddress::Any6(), 1234); + config_.SetIPv6AlternateServerAddressToSend(kTestServerAddressV6); + config_.SetPreferredAddressConnectionIdAndTokenToSend( + new_connection_id, new_stateless_reset_token); + config_.ClearAlternateServerAddressToSend(quiche::IpAddressFamily::IP_V4); + EXPECT_FALSE(config_.GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V4) + .has_value()); + config_.ClearAlternateServerAddressToSend(quiche::IpAddressFamily::IP_V4); + + TransportParameters params; + config_.FillTransportParameters(¶ms); + EXPECT_EQ(params.preferred_address->ipv4_socket_address, + QuicSocketAddress(QuicIpAddress::Any4(), 0)); + EXPECT_EQ(params.preferred_address->ipv6_socket_address, + kTestServerAddressV6); +} + +TEST_P(QuicConfigTest, ProcessTransportParametersServer) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + const std::string kFakeGoogleHandshakeMessage = "Fake handshake message"; + TransportParameters params; + + params.initial_max_stream_data_bidi_local.set_value( + 2 * kMinimumFlowControlSendWindow); + params.initial_max_stream_data_bidi_remote.set_value( + 3 * kMinimumFlowControlSendWindow); + params.initial_max_stream_data_uni.set_value(4 * + kMinimumFlowControlSendWindow); + params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + params.max_datagram_frame_size.set_value(kMaxDatagramFrameSizeForTest); + params.initial_max_streams_bidi.set_value(kDefaultMaxStreamsPerConnection); + params.stateless_reset_token = CreateStatelessResetTokenForTest(); + params.max_ack_delay.set_value(kMaxAckDelayForTest); + params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + params.active_connection_id_limit.set_value(kActiveConnectionIdLimitForTest); + params.original_destination_connection_id = TestConnectionId(0x1111); + params.initial_source_connection_id = TestConnectionId(0x2222); + params.retry_source_connection_id = TestConnectionId(0x3333); + params.google_handshake_message = kFakeGoogleHandshakeMessage; + + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ true, &error_details), + IsQuicNoError()) + << error_details; + + EXPECT_FALSE(config_.negotiated()); + + ASSERT_TRUE( + config_.HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + EXPECT_EQ(2 * kMinimumFlowControlSendWindow, + config_.ReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + + ASSERT_TRUE( + config_.HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + EXPECT_EQ(3 * kMinimumFlowControlSendWindow, + config_.ReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + + ASSERT_TRUE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()); + EXPECT_EQ(4 * kMinimumFlowControlSendWindow, + config_.ReceivedInitialMaxStreamDataBytesUnidirectional()); + + ASSERT_TRUE(config_.HasReceivedMaxPacketSize()); + EXPECT_EQ(kMaxPacketSizeForTest, config_.ReceivedMaxPacketSize()); + + ASSERT_TRUE(config_.HasReceivedMaxDatagramFrameSize()); + EXPECT_EQ(kMaxDatagramFrameSizeForTest, + config_.ReceivedMaxDatagramFrameSize()); + + ASSERT_TRUE(config_.HasReceivedMaxBidirectionalStreams()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + config_.ReceivedMaxBidirectionalStreams()); + + EXPECT_FALSE(config_.DisableConnectionMigration()); + + // The following config shouldn't be processed because of resumption. + EXPECT_FALSE(config_.HasReceivedStatelessResetToken()); + EXPECT_FALSE(config_.HasReceivedMaxAckDelayMs()); + EXPECT_FALSE(config_.HasReceivedAckDelayExponent()); + EXPECT_FALSE(config_.HasReceivedMinAckDelayMs()); + EXPECT_FALSE(config_.HasReceivedOriginalConnectionId()); + EXPECT_FALSE(config_.HasReceivedInitialSourceConnectionId()); + EXPECT_FALSE(config_.HasReceivedRetrySourceConnectionId()); + + // Let the config process another slightly tweaked transport paramters. + // Note that the values for flow control and stream limit cannot be smaller + // than before. This rule is enforced in QuicSession::OnConfigNegotiated(). + params.initial_max_stream_data_bidi_local.set_value( + 2 * kMinimumFlowControlSendWindow + 1); + params.initial_max_stream_data_bidi_remote.set_value( + 4 * kMinimumFlowControlSendWindow); + params.initial_max_stream_data_uni.set_value(5 * + kMinimumFlowControlSendWindow); + params.max_udp_payload_size.set_value(2 * kMaxPacketSizeForTest); + params.max_datagram_frame_size.set_value(2 * kMaxDatagramFrameSizeForTest); + params.initial_max_streams_bidi.set_value(2 * + kDefaultMaxStreamsPerConnection); + params.disable_active_migration = true; + + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()) + << error_details; + + EXPECT_TRUE(config_.negotiated()); + + ASSERT_TRUE( + config_.HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + EXPECT_EQ(2 * kMinimumFlowControlSendWindow + 1, + config_.ReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + + ASSERT_TRUE( + config_.HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + EXPECT_EQ(4 * kMinimumFlowControlSendWindow, + config_.ReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + + ASSERT_TRUE(config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()); + EXPECT_EQ(5 * kMinimumFlowControlSendWindow, + config_.ReceivedInitialMaxStreamDataBytesUnidirectional()); + + ASSERT_TRUE(config_.HasReceivedMaxPacketSize()); + EXPECT_EQ(2 * kMaxPacketSizeForTest, config_.ReceivedMaxPacketSize()); + + ASSERT_TRUE(config_.HasReceivedMaxDatagramFrameSize()); + EXPECT_EQ(2 * kMaxDatagramFrameSizeForTest, + config_.ReceivedMaxDatagramFrameSize()); + + ASSERT_TRUE(config_.HasReceivedMaxBidirectionalStreams()); + EXPECT_EQ(2 * kDefaultMaxStreamsPerConnection, + config_.ReceivedMaxBidirectionalStreams()); + + EXPECT_TRUE(config_.DisableConnectionMigration()); + + ASSERT_TRUE(config_.HasReceivedStatelessResetToken()); + + ASSERT_TRUE(config_.HasReceivedMaxAckDelayMs()); + EXPECT_EQ(config_.ReceivedMaxAckDelayMs(), kMaxAckDelayForTest); + + ASSERT_TRUE(config_.HasReceivedMinAckDelayMs()); + EXPECT_EQ(config_.ReceivedMinAckDelayMs(), + kMinAckDelayUsForTest / kNumMicrosPerMilli); + + ASSERT_TRUE(config_.HasReceivedAckDelayExponent()); + EXPECT_EQ(config_.ReceivedAckDelayExponent(), kAckDelayExponentForTest); + + ASSERT_TRUE(config_.HasReceivedActiveConnectionIdLimit()); + EXPECT_EQ(config_.ReceivedActiveConnectionIdLimit(), + kActiveConnectionIdLimitForTest); + + ASSERT_TRUE(config_.HasReceivedOriginalConnectionId()); + EXPECT_EQ(config_.ReceivedOriginalConnectionId(), TestConnectionId(0x1111)); + ASSERT_TRUE(config_.HasReceivedInitialSourceConnectionId()); + EXPECT_EQ(config_.ReceivedInitialSourceConnectionId(), + TestConnectionId(0x2222)); + ASSERT_TRUE(config_.HasReceivedRetrySourceConnectionId()); + EXPECT_EQ(config_.ReceivedRetrySourceConnectionId(), + TestConnectionId(0x3333)); + EXPECT_EQ(kFakeGoogleHandshakeMessage, + config_.GetReceivedGoogleHandshakeMessage()); +} + +TEST_P(QuicConfigTest, DisableMigrationTransportParameter) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + TransportParameters params; + params.disable_active_migration = true; + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + EXPECT_TRUE(config_.DisableConnectionMigration()); +} + +TEST_P(QuicConfigTest, SendPreferredIPv4Address) { + if (!version_.UsesTls()) { + // TransportParameters are only used for QUIC+TLS. + return; + } + + EXPECT_FALSE(config_.HasReceivedPreferredAddressConnectionIdAndToken()); + + TransportParameters params; + QuicIpAddress host; + host.FromString("::ffff:192.0.2.128"); + QuicSocketAddress kTestServerAddress = QuicSocketAddress(host, 1234); + QuicConnectionId new_connection_id = TestConnectionId(5); + StatelessResetToken new_stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(new_connection_id); + auto preferred_address = + std::make_unique(); + preferred_address->ipv6_socket_address = kTestServerAddress; + preferred_address->connection_id = new_connection_id; + preferred_address->stateless_reset_token.assign( + reinterpret_cast(&new_stateless_reset_token), + reinterpret_cast(&new_stateless_reset_token) + + sizeof(new_stateless_reset_token)); + params.preferred_address = std::move(preferred_address); + + std::string error_details; + EXPECT_THAT(config_.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + + EXPECT_TRUE(config_.HasReceivedIPv6AlternateServerAddress()); + EXPECT_EQ(config_.ReceivedIPv6AlternateServerAddress(), kTestServerAddress); + EXPECT_TRUE(config_.HasReceivedPreferredAddressConnectionIdAndToken()); + const std::pair& + preferred_address_connection_id_and_token = + config_.ReceivedPreferredAddressConnectionIdAndToken(); + EXPECT_EQ(preferred_address_connection_id_and_token.first, new_connection_id); + EXPECT_EQ(preferred_address_connection_id_and_token.second, + new_stateless_reset_token); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_connection.cc b/quiche/quic/core/quic_connection.cc new file mode 100644 index 000000000000..01e728aa7d88 --- /dev/null +++ b/quiche/quic/core/quic_connection.cc @@ -0,0 +1,7409 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_client_stats.h" +#include "quiche/quic/platform/api/quic_exported_stats.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_flag_utils.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +class QuicDecrypter; +class QuicEncrypter; + +namespace { + +// Maximum number of consecutive sent nonretransmittable packets. +const QuicPacketCount kMaxConsecutiveNonRetransmittablePackets = 19; + +// The minimum release time into future in ms. +const int kMinReleaseTimeIntoFutureMs = 1; + +// The maximum number of recorded client addresses. +const size_t kMaxReceivedClientAddressSize = 20; + +// Base class of all alarms owned by a QuicConnection. +class QuicConnectionAlarmDelegate : public QuicAlarm::Delegate { + public: + explicit QuicConnectionAlarmDelegate(QuicConnection* connection) + : connection_(connection) {} + QuicConnectionAlarmDelegate(const QuicConnectionAlarmDelegate&) = delete; + QuicConnectionAlarmDelegate& operator=(const QuicConnectionAlarmDelegate&) = + delete; + + QuicConnectionContext* GetConnectionContext() override { + return (connection_ == nullptr) ? nullptr : connection_->context(); + } + + protected: + QuicConnection* connection_; +}; + +// An alarm that is scheduled to send an ack if a timeout occurs. +class AckAlarmDelegate : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->ack_frame_updated()); + QUICHE_DCHECK(connection_->connected()); + QuicConnection::ScopedPacketFlusher flusher(connection_); + if (connection_->SupportsMultiplePacketNumberSpaces()) { + connection_->SendAllPendingAcks(); + } else { + connection_->SendAck(); + } + } +}; + +// This alarm will be scheduled any time a data-bearing packet is sent out. +// When the alarm goes off, the connection checks to see if the oldest packets +// have been acked, and retransmit them if they have not. +class RetransmissionAlarmDelegate : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + connection_->OnRetransmissionTimeout(); + } +}; + +// An alarm that is scheduled when the SentPacketManager requires a delay +// before sending packets and fires when the packet may be sent. +class SendAlarmDelegate : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + connection_->WriteIfNotBlocked(); + } +}; + +class MtuDiscoveryAlarmDelegate : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + connection_->DiscoverMtu(); + } +}; + +class ProcessUndecryptablePacketsAlarmDelegate + : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + QuicConnection::ScopedPacketFlusher flusher(connection_); + connection_->MaybeProcessUndecryptablePackets(); + } +}; + +class DiscardPreviousOneRttKeysAlarmDelegate + : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + connection_->DiscardPreviousOneRttKeys(); + } +}; + +class DiscardZeroRttDecryptionKeysAlarmDelegate + : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + QUIC_DLOG(INFO) << "0-RTT discard alarm fired"; + connection_->RemoveDecrypter(ENCRYPTION_ZERO_RTT); + connection_->RetireOriginalDestinationConnectionId(); + } +}; + +class MultiPortProbingAlarmDelegate : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; + + void OnAlarm() override { + QUICHE_DCHECK(connection_->connected()); + QUIC_DLOG(INFO) << "Alternative path probing alarm fired"; + connection_->MaybeProbeMultiPortPath(); + } +}; + +// When the clearer goes out of scope, the coalesced packet gets cleared. +class ScopedCoalescedPacketClearer { + public: + explicit ScopedCoalescedPacketClearer(QuicCoalescedPacket* coalesced) + : coalesced_(coalesced) {} + ~ScopedCoalescedPacketClearer() { coalesced_->Clear(); } + + private: + QuicCoalescedPacket* coalesced_; // Unowned. +}; + +// Whether this incoming packet is allowed to replace our connection ID. +bool PacketCanReplaceServerConnectionId(const QuicPacketHeader& header, + Perspective perspective) { + return perspective == Perspective::IS_CLIENT && + header.form == IETF_QUIC_LONG_HEADER_PACKET && + header.version.IsKnown() && + header.version.AllowsVariableLengthConnectionIds() && + (header.long_packet_type == INITIAL || + header.long_packet_type == RETRY); +} + +// Due to a lost Initial packet, a Handshake packet might use a new connection +// ID we haven't seen before. We shouldn't update the connection ID based on +// this, but should buffer the packet in case it works out. +bool NewServerConnectionIdMightBeValid(const QuicPacketHeader& header, + Perspective perspective, + bool connection_id_already_replaced) { + return perspective == Perspective::IS_CLIENT && + header.form == IETF_QUIC_LONG_HEADER_PACKET && + header.version.IsKnown() && + header.version.AllowsVariableLengthConnectionIds() && + header.long_packet_type == HANDSHAKE && + !connection_id_already_replaced; +} + +CongestionControlType GetDefaultCongestionControlType() { + if (GetQuicReloadableFlag(quic_default_to_bbr_v2)) { + return kBBRv2; + } + + if (GetQuicReloadableFlag(quic_default_to_bbr)) { + return kBBR; + } + + return kCubicBytes; +} + +bool ContainsNonProbingFrame(const SerializedPacket& packet) { + for (const QuicFrame& frame : packet.nonretransmittable_frames) { + if (!QuicUtils::IsProbingFrame(frame.type)) { + return true; + } + } + for (const QuicFrame& frame : packet.retransmittable_frames) { + if (!QuicUtils::IsProbingFrame(frame.type)) { + return true; + } + } + return false; +} + +} // namespace + +#define ENDPOINT \ + (perspective_ == Perspective::IS_SERVER ? "Server: " : "Client: ") + +QuicConnection::QuicConnection( + QuicConnectionId server_connection_id, + QuicSocketAddress initial_self_address, + QuicSocketAddress initial_peer_address, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + QuicPacketWriter* writer, bool owns_writer, Perspective perspective, + const ParsedQuicVersionVector& supported_versions, + ConnectionIdGeneratorInterface& generator) + : framer_(supported_versions, helper->GetClock()->ApproximateNow(), + perspective, server_connection_id.length()), + current_packet_content_(NO_FRAMES_RECEIVED), + is_current_packet_connectivity_probing_(false), + has_path_challenge_in_current_packet_(false), + current_effective_peer_migration_type_(NO_CHANGE), + helper_(helper), + alarm_factory_(alarm_factory), + per_packet_options_(nullptr), + writer_(writer), + owns_writer_(owns_writer), + encryption_level_(ENCRYPTION_INITIAL), + clock_(helper->GetClock()), + random_generator_(helper->GetRandomGenerator()), + client_connection_id_is_set_(false), + direct_peer_address_(initial_peer_address), + default_path_(initial_self_address, QuicSocketAddress(), + /*client_connection_id=*/EmptyQuicConnectionId(), + server_connection_id, + /*stateless_reset_token=*/absl::nullopt), + active_effective_peer_migration_type_(NO_CHANGE), + support_key_update_for_connection_(false), + current_packet_data_(nullptr), + should_last_packet_instigate_acks_(false), + max_undecryptable_packets_(0), + max_tracked_packets_(GetQuicFlag(quic_max_tracked_packet_count)), + idle_timeout_connection_close_behavior_( + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET), + num_rtos_for_blackhole_detection_(0), + uber_received_packet_manager_(&stats_), + stop_waiting_count_(0), + pending_retransmission_alarm_(false), + defer_send_in_response_to_packets_(false), + arena_(), + ack_alarm_(alarm_factory_->CreateAlarm(arena_.New(this), + &arena_)), + retransmission_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + send_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + mtu_discovery_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + process_undecryptable_packets_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + discard_previous_one_rtt_keys_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + discard_zero_rtt_decryption_keys_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), + &arena_)), + multi_port_probing_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + visitor_(nullptr), + debug_visitor_(nullptr), + packet_creator_(server_connection_id, &framer_, random_generator_, this), + last_received_packet_info_(clock_->ApproximateNow()), + sent_packet_manager_(perspective, clock_, random_generator_, &stats_, + GetDefaultCongestionControlType()), + version_negotiated_(false), + perspective_(perspective), + connected_(true), + can_truncate_connection_ids_(perspective == Perspective::IS_SERVER), + mtu_probe_count_(0), + previous_validated_mtu_(0), + peer_max_packet_size_(kDefaultMaxPacketSizeTransportParam), + largest_received_packet_size_(0), + write_error_occurred_(false), + no_stop_waiting_frames_(version().HasIetfInvariantHeader()), + consecutive_num_packets_with_no_retransmittable_frames_(0), + max_consecutive_num_packets_with_no_retransmittable_frames_( + kMaxConsecutiveNonRetransmittablePackets), + bundle_retransmittable_with_pto_ack_(false), + last_control_frame_id_(kInvalidControlFrameId), + is_path_degrading_(false), + processing_ack_frame_(false), + supports_release_time_(false), + release_time_into_future_(QuicTime::Delta::Zero()), + blackhole_detector_(this, &arena_, alarm_factory_, &context_), + idle_network_detector_(this, clock_->ApproximateNow(), &arena_, + alarm_factory_, &context_), + path_validator_(alarm_factory_, &arena_, this, random_generator_, clock_, + &context_), + ping_manager_(perspective, this, &arena_, alarm_factory_, &context_), + multi_port_probing_interval_(kDefaultMultiPortProbingInterval), + connection_id_generator_(generator), + received_client_addresses_cache_(kMaxReceivedClientAddressSize) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT || + default_path_.self_address.IsInitialized()); + + QUIC_DLOG(INFO) << ENDPOINT << "Created connection with server connection ID " + << server_connection_id + << " and version: " << ParsedQuicVersionToString(version()); + + QUIC_BUG_IF(quic_bug_12714_2, !QuicUtils::IsConnectionIdValidForVersion( + server_connection_id, transport_version())) + << "QuicConnection: attempted to use server connection ID " + << server_connection_id << " which is invalid with version " << version(); + framer_.set_visitor(this); + stats_.connection_creation_time = clock_->ApproximateNow(); + // TODO(ianswett): Supply the NetworkChangeVisitor as a constructor argument + // and make it required non-null, because it's always used. + sent_packet_manager_.SetNetworkChangeVisitor(this); + if (GetQuicRestartFlag(quic_offload_pacing_to_usps2)) { + sent_packet_manager_.SetPacingAlarmGranularity(QuicTime::Delta::Zero()); + release_time_into_future_ = + QuicTime::Delta::FromMilliseconds(kMinReleaseTimeIntoFutureMs); + } + // Allow the packet writer to potentially reduce the packet size to a value + // even smaller than kDefaultMaxPacketSize. + SetMaxPacketLength(perspective_ == Perspective::IS_SERVER + ? kDefaultServerMaxPacketSize + : kDefaultMaxPacketSize); + uber_received_packet_manager_.set_max_ack_ranges(255); + MaybeEnableMultiplePacketNumberSpacesSupport(); + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT || + supported_versions.size() == 1); + InstallInitialCrypters(default_path_.server_connection_id); + + // On the server side, version negotiation has been done by the dispatcher, + // and the server connection is created with the right version. + if (perspective_ == Perspective::IS_SERVER) { + SetVersionNegotiated(); + } + if (default_enable_5rto_blackhole_detection_) { + num_rtos_for_blackhole_detection_ = 5; + if (GetQuicReloadableFlag(quic_disable_server_blackhole_detection) && + perspective_ == Perspective::IS_SERVER) { + QUIC_RELOADABLE_FLAG_COUNT(quic_disable_server_blackhole_detection); + blackhole_detection_disabled_ = true; + } + } + if (perspective_ == Perspective::IS_CLIENT) { + AddKnownServerAddress(initial_peer_address); + } + packet_creator_.SetDefaultPeerAddress(initial_peer_address); +} + +void QuicConnection::InstallInitialCrypters(QuicConnectionId connection_id) { + CrypterPair crypters; + CryptoUtils::CreateInitialObfuscators(perspective_, version(), connection_id, + &crypters); + SetEncrypter(ENCRYPTION_INITIAL, std::move(crypters.encrypter)); + if (version().KnowsWhichDecrypterToUse()) { + InstallDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); + } else { + SetDecrypter(ENCRYPTION_INITIAL, std::move(crypters.decrypter)); + } +} + +QuicConnection::~QuicConnection() { + QUICHE_DCHECK_GE(stats_.max_egress_mtu, long_term_mtu_); + if (owns_writer_) { + delete writer_; + } + ClearQueuedPackets(); + if (stats_ + .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter > + 0) { + QUIC_CODE_COUNT_N( + quic_server_received_tls_zero_rtt_packet_after_discarding_decrypter, 2, + 3); + } else { + QUIC_CODE_COUNT_N( + quic_server_received_tls_zero_rtt_packet_after_discarding_decrypter, 3, + 3); + } +} + +void QuicConnection::ClearQueuedPackets() { buffered_packets_.clear(); } + +bool QuicConnection::ValidateConfigConnectionIds(const QuicConfig& config) { + QUICHE_DCHECK(config.negotiated()); + if (!version().UsesTls()) { + // QUIC+TLS is required to transmit connection ID transport parameters. + return true; + } + // This function validates connection IDs as defined in IETF draft-28 and + // later. + + // Validate initial_source_connection_id. + QuicConnectionId expected_initial_source_connection_id; + if (perspective_ == Perspective::IS_CLIENT) { + expected_initial_source_connection_id = default_path_.server_connection_id; + } else { + expected_initial_source_connection_id = default_path_.client_connection_id; + } + if (!config.HasReceivedInitialSourceConnectionId() || + config.ReceivedInitialSourceConnectionId() != + expected_initial_source_connection_id) { + std::string received_value; + if (config.HasReceivedInitialSourceConnectionId()) { + received_value = config.ReceivedInitialSourceConnectionId().ToString(); + } else { + received_value = "none"; + } + std::string error_details = + absl::StrCat("Bad initial_source_connection_id: expected ", + expected_initial_source_connection_id.ToString(), + ", received ", received_value); + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + if (perspective_ == Perspective::IS_CLIENT) { + // Validate original_destination_connection_id. + if (!config.HasReceivedOriginalConnectionId() || + config.ReceivedOriginalConnectionId() != + GetOriginalDestinationConnectionId()) { + std::string received_value; + if (config.HasReceivedOriginalConnectionId()) { + received_value = config.ReceivedOriginalConnectionId().ToString(); + } else { + received_value = "none"; + } + std::string error_details = + absl::StrCat("Bad original_destination_connection_id: expected ", + GetOriginalDestinationConnectionId().ToString(), + ", received ", received_value); + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + // Validate retry_source_connection_id. + if (retry_source_connection_id_.has_value()) { + // We received a RETRY packet, validate that the retry source + // connection ID from the config matches the one from the RETRY. + if (!config.HasReceivedRetrySourceConnectionId() || + config.ReceivedRetrySourceConnectionId() != + retry_source_connection_id_.value()) { + std::string received_value; + if (config.HasReceivedRetrySourceConnectionId()) { + received_value = config.ReceivedRetrySourceConnectionId().ToString(); + } else { + received_value = "none"; + } + std::string error_details = + absl::StrCat("Bad retry_source_connection_id: expected ", + retry_source_connection_id_.value().ToString(), + ", received ", received_value); + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + } else { + // We did not receive a RETRY packet, make sure we did not receive the + // retry_source_connection_id transport parameter. + if (config.HasReceivedRetrySourceConnectionId()) { + std::string error_details = absl::StrCat( + "Bad retry_source_connection_id: did not receive RETRY but " + "received ", + config.ReceivedRetrySourceConnectionId().ToString()); + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + } + } + return true; +} + +void QuicConnection::SetFromConfig(const QuicConfig& config) { + if (config.negotiated()) { + // Handshake complete, set handshake timeout to Infinite. + SetNetworkTimeouts(QuicTime::Delta::Infinite(), + config.IdleNetworkTimeout()); + idle_timeout_connection_close_behavior_ = + ConnectionCloseBehavior::SILENT_CLOSE; + if (perspective_ == Perspective::IS_SERVER) { + idle_timeout_connection_close_behavior_ = ConnectionCloseBehavior:: + SILENT_CLOSE_WITH_CONNECTION_CLOSE_PACKET_SERIALIZED; + } + if (config.HasClientRequestedIndependentOption(kNSLC, perspective_)) { + idle_timeout_connection_close_behavior_ = + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET; + } + if (!ValidateConfigConnectionIds(config)) { + return; + } + support_key_update_for_connection_ = version().UsesTls(); + framer_.SetKeyUpdateSupportForConnection( + support_key_update_for_connection_); + } else { + SetNetworkTimeouts(config.max_time_before_crypto_handshake(), + config.max_idle_time_before_crypto_handshake()); + } + + if (version().HasIetfQuicFrames() && + config.HasReceivedPreferredAddressConnectionIdAndToken()) { + QuicNewConnectionIdFrame frame; + std::tie(frame.connection_id, frame.stateless_reset_token) = + config.ReceivedPreferredAddressConnectionIdAndToken(); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + OnNewConnectionIdFrameInner(frame); + } + + sent_packet_manager_.SetFromConfig(config); + if (perspective_ == Perspective::IS_SERVER && + config.HasClientSentConnectionOption(kAFF2, perspective_)) { + send_ack_frequency_on_handshake_completion_ = true; + } + if (config.HasReceivedBytesForConnectionId() && + can_truncate_connection_ids_) { + packet_creator_.SetServerConnectionIdLength( + config.ReceivedBytesForConnectionId()); + } + max_undecryptable_packets_ = config.max_undecryptable_packets(); + + if (!GetQuicReloadableFlag(quic_enable_mtu_discovery_at_server)) { + if (config.HasClientRequestedIndependentOption(kMTUH, perspective_)) { + SetMtuDiscoveryTarget(kMtuDiscoveryTargetPacketSizeHigh); + } + } + if (config.HasClientRequestedIndependentOption(kMTUL, perspective_)) { + SetMtuDiscoveryTarget(kMtuDiscoveryTargetPacketSizeLow); + } + if (default_enable_5rto_blackhole_detection_) { + if (config.HasClientRequestedIndependentOption(kCBHD, perspective_)) { + QUIC_CODE_COUNT(quic_client_only_blackhole_detection); + blackhole_detection_disabled_ = true; + } + if (config.HasClientSentConnectionOption(kNBHD, perspective_)) { + blackhole_detection_disabled_ = true; + } + } + + if (config.HasClientRequestedIndependentOption(kFIDT, perspective_)) { + idle_network_detector_.enable_shorter_idle_timeout_on_sent_packet(); + } + if (perspective_ == Perspective::IS_CLIENT && version().HasIetfQuicFrames()) { + // Only conduct those experiments in IETF QUIC because random packets may + // elicit reset and gQUIC PUBLIC_RESET will cause connection close. + if (config.HasClientRequestedIndependentOption(kROWF, perspective_)) { + retransmittable_on_wire_behavior_ = SEND_FIRST_FORWARD_SECURE_PACKET; + } + if (config.HasClientRequestedIndependentOption(kROWR, perspective_)) { + retransmittable_on_wire_behavior_ = SEND_RANDOM_BYTES; + } + } + if (config.HasClientRequestedIndependentOption(k3AFF, perspective_)) { + anti_amplification_factor_ = 3; + } + if (config.HasClientRequestedIndependentOption(k10AF, perspective_)) { + anti_amplification_factor_ = 10; + } + + if (GetQuicReloadableFlag(quic_enable_server_on_wire_ping) && + perspective_ == Perspective::IS_SERVER && + config.HasClientSentConnectionOption(kSRWP, perspective_)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_enable_server_on_wire_ping); + set_initial_retransmittable_on_wire_timeout( + QuicTime::Delta::FromMilliseconds(200)); + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnSetFromConfig(config); + } + uber_received_packet_manager_.SetFromConfig(config, perspective_); + if (config.HasClientSentConnectionOption(k5RTO, perspective_)) { + num_rtos_for_blackhole_detection_ = 5; + } + if (config.HasClientSentConnectionOption(k6PTO, perspective_) || + config.HasClientSentConnectionOption(k7PTO, perspective_) || + config.HasClientSentConnectionOption(k8PTO, perspective_)) { + num_rtos_for_blackhole_detection_ = 5; + } + if (config.HasClientSentConnectionOption(kNSTP, perspective_)) { + no_stop_waiting_frames_ = true; + } + if (config.HasReceivedStatelessResetToken()) { + default_path_.stateless_reset_token = config.ReceivedStatelessResetToken(); + } + if (config.HasReceivedAckDelayExponent()) { + framer_.set_peer_ack_delay_exponent(config.ReceivedAckDelayExponent()); + } + if (config.HasClientSentConnectionOption(kEACK, perspective_)) { + bundle_retransmittable_with_pto_ack_ = true; + } + if (config.HasClientSentConnectionOption(kDFER, perspective_)) { + defer_send_in_response_to_packets_ = false; + } + + if (config.HasClientRequestedIndependentOption(kINVC, perspective_)) { + send_connection_close_for_invalid_version_ = true; + } + const bool remove_connection_migration_connection_option = + GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2); + if (remove_connection_migration_connection_option) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_remove_connection_migration_connection_option_v2); + } + if (framer_.version().HasIetfQuicFrames() && + (remove_connection_migration_connection_option || + config.HasClientSentConnectionOption(kRVCM, perspective_))) { + validate_client_addresses_ = true; + } + // Having connection_migration_use_new_cid_ depends on the same set of flags + // and connection option on both client and server sides has the advantage of: + // 1) Less chance of skew in using new connection ID or not between client + // and server in unit tests with random flag combinations. + // 2) Client side's rollout can be protected by the same connection option. + connection_migration_use_new_cid_ = + validate_client_addresses_ && + GetQuicReloadableFlag(quic_connection_migration_use_new_cid_v2); + + if (connection_migration_use_new_cid_ && + config.HasReceivedPreferredAddressConnectionIdAndToken() && + config.HasClientSentConnectionOption(kSPAD, perspective_)) { + if (self_address().host().IsIPv4() && + config.HasReceivedIPv4AlternateServerAddress()) { + received_server_preferred_address_ = + config.ReceivedIPv4AlternateServerAddress(); + } else if (self_address().host().IsIPv6() && + config.HasReceivedIPv6AlternateServerAddress()) { + received_server_preferred_address_ = + config.ReceivedIPv6AlternateServerAddress(); + } + if (received_server_preferred_address_.IsInitialized()) { + QUICHE_DLOG(INFO) << ENDPOINT << "Received server preferred address: " + << received_server_preferred_address_; + if (config.HasClientRequestedIndependentOption(kSPA2, perspective_)) { + accelerated_server_preferred_address_ = true; + visitor_->OnServerPreferredAddressAvailable( + received_server_preferred_address_); + } + } + } + + if (config.HasReceivedMaxPacketSize()) { + peer_max_packet_size_ = config.ReceivedMaxPacketSize(); + packet_creator_.SetMaxPacketLength( + GetLimitedMaxPacketSize(packet_creator_.max_packet_length())); + } + if (config.HasReceivedMaxDatagramFrameSize()) { + packet_creator_.SetMaxDatagramFrameSize( + config.ReceivedMaxDatagramFrameSize()); + } + + supports_release_time_ = + writer_ != nullptr && writer_->SupportsReleaseTime() && + !config.HasClientSentConnectionOption(kNPCO, perspective_); + + if (supports_release_time_) { + UpdateReleaseTimeIntoFuture(); + } + + if (perspective_ == Perspective::IS_CLIENT && + connection_migration_use_new_cid_ && + config.HasClientRequestedIndependentOption(kMPQC, perspective_)) { + multi_port_stats_ = std::make_unique(); + } +} + +bool QuicConnection::MaybeTestLiveness() { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + if (liveness_testing_disabled_ || + encryption_level_ != ENCRYPTION_FORWARD_SECURE) { + return false; + } + const QuicTime idle_network_deadline = + idle_network_detector_.GetIdleNetworkDeadline(); + if (!idle_network_deadline.IsInitialized()) { + return false; + } + const QuicTime now = clock_->ApproximateNow(); + if (now > idle_network_deadline) { + QUIC_DLOG(WARNING) << "Idle network deadline has passed"; + return false; + } + const QuicTime::Delta timeout = idle_network_deadline - now; + if (2 * timeout > idle_network_detector_.idle_network_timeout()) { + // Do not test liveness if timeout is > half timeout. This is used to + // prevent an infinite loop for short idle timeout. + return false; + } + if (!sent_packet_manager_.IsLessThanThreePTOs(timeout)) { + return false; + } + QUIC_LOG_EVERY_N_SEC(INFO, 60) + << "Testing liveness, idle_network_timeout: " + << idle_network_detector_.idle_network_timeout() + << ", timeout: " << timeout + << ", Pto delay: " << sent_packet_manager_.GetPtoDelay() + << ", smoothed_rtt: " + << sent_packet_manager_.GetRttStats()->smoothed_rtt() + << ", mean deviation: " + << sent_packet_manager_.GetRttStats()->mean_deviation(); + SendConnectivityProbingPacket(writer_, peer_address()); + return true; +} + +void QuicConnection::ApplyConnectionOptions( + const QuicTagVector& connection_options) { + sent_packet_manager_.ApplyConnectionOptions(connection_options); +} + +void QuicConnection::OnSendConnectionState( + const CachedNetworkParameters& cached_network_params) { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnSendConnectionState(cached_network_params); + } +} + +void QuicConnection::OnReceiveConnectionState( + const CachedNetworkParameters& cached_network_params) { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnReceiveConnectionState(cached_network_params); + } +} + +void QuicConnection::ResumeConnectionState( + const CachedNetworkParameters& cached_network_params, + bool max_bandwidth_resumption) { + sent_packet_manager_.ResumeConnectionState(cached_network_params, + max_bandwidth_resumption); +} + +void QuicConnection::SetMaxPacingRate(QuicBandwidth max_pacing_rate) { + sent_packet_manager_.SetMaxPacingRate(max_pacing_rate); +} + +void QuicConnection::AdjustNetworkParameters( + const SendAlgorithmInterface::NetworkParams& params) { + sent_packet_manager_.AdjustNetworkParameters(params); +} + +void QuicConnection::SetLossDetectionTuner( + std::unique_ptr tuner) { + sent_packet_manager_.SetLossDetectionTuner(std::move(tuner)); +} + +void QuicConnection::OnConfigNegotiated() { + sent_packet_manager_.OnConfigNegotiated(); + + if (GetQuicReloadableFlag(quic_enable_mtu_discovery_at_server) && + perspective_ == Perspective::IS_SERVER) { + QUIC_RELOADABLE_FLAG_COUNT(quic_enable_mtu_discovery_at_server); + SetMtuDiscoveryTarget(kMtuDiscoveryTargetPacketSizeHigh); + } +} + +QuicBandwidth QuicConnection::MaxPacingRate() const { + return sent_packet_manager_.MaxPacingRate(); +} + +bool QuicConnection::SelectMutualVersion( + const ParsedQuicVersionVector& available_versions) { + // Try to find the highest mutual version by iterating over supported + // versions, starting with the highest, and breaking out of the loop once we + // find a matching version in the provided available_versions vector. + const ParsedQuicVersionVector& supported_versions = + framer_.supported_versions(); + for (size_t i = 0; i < supported_versions.size(); ++i) { + const ParsedQuicVersion& version = supported_versions[i]; + if (std::find(available_versions.begin(), available_versions.end(), + version) != available_versions.end()) { + framer_.set_version(version); + return true; + } + } + + return false; +} + +void QuicConnection::OnError(QuicFramer* framer) { + // Packets that we can not or have not decrypted are dropped. + // TODO(rch): add stats to measure this. + if (!connected_ || !last_received_packet_info_.decrypted) { + return; + } + CloseConnection(framer->error(), framer->detailed_error(), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicConnection::OnPacket() { + last_received_packet_info_.decrypted = false; +} + +void QuicConnection::OnPublicResetPacket(const QuicPublicResetPacket& packet) { + // Check that any public reset packet with a different connection ID that was + // routed to this QuicConnection has been redirected before control reaches + // here. (Check for a bug regression.) + QUICHE_DCHECK_EQ(default_path_.server_connection_id, packet.connection_id); + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + QUICHE_DCHECK(!version().HasIetfInvariantHeader()); + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPublicResetPacket(packet); + } + std::string error_details = "Received public reset."; + if (perspective_ == Perspective::IS_CLIENT && !packet.endpoint_id.empty()) { + absl::StrAppend(&error_details, " From ", packet.endpoint_id, "."); + } + QUIC_DLOG(INFO) << ENDPOINT << error_details; + QUIC_CODE_COUNT(quic_tear_down_local_connection_on_public_reset); + TearDownLocalConnectionState(QUIC_PUBLIC_RESET, NO_IETF_QUIC_ERROR, + error_details, ConnectionCloseSource::FROM_PEER); +} + +bool QuicConnection::OnProtocolVersionMismatch( + ParsedQuicVersion received_version) { + QUIC_DLOG(INFO) << ENDPOINT << "Received packet with mismatched version " + << ParsedQuicVersionToString(received_version); + if (perspective_ == Perspective::IS_CLIENT) { + const std::string error_details = "Protocol version mismatch."; + QUIC_BUG(quic_bug_10511_3) << ENDPOINT << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details, + ConnectionCloseBehavior::SILENT_CLOSE); + } + + // Server drops old packets that were sent by the client before the version + // was negotiated. + return false; +} + +// Handles version negotiation for client connection. +void QuicConnection::OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& packet) { + // Check that any public reset packet with a different connection ID that was + // routed to this QuicConnection has been redirected before control reaches + // here. (Check for a bug regression.) + QUICHE_DCHECK_EQ(default_path_.server_connection_id, packet.connection_id); + if (perspective_ == Perspective::IS_SERVER) { + const std::string error_details = + "Server received version negotiation packet."; + QUIC_BUG(quic_bug_10511_4) << error_details; + QUIC_CODE_COUNT(quic_tear_down_local_connection_on_version_negotiation); + CloseConnection(QUIC_INTERNAL_ERROR, error_details, + ConnectionCloseBehavior::SILENT_CLOSE); + return; + } + if (debug_visitor_ != nullptr) { + debug_visitor_->OnVersionNegotiationPacket(packet); + } + + if (version_negotiated_) { + // Possibly a duplicate version negotiation packet. + return; + } + + if (std::find(packet.versions.begin(), packet.versions.end(), version()) != + packet.versions.end()) { + const std::string error_details = absl::StrCat( + "Server already supports client's version ", + ParsedQuicVersionToString(version()), + " and should have accepted the connection instead of sending {", + ParsedQuicVersionVectorToString(packet.versions), "}."); + QUIC_DLOG(WARNING) << error_details; + CloseConnection(QUIC_INVALID_VERSION_NEGOTIATION_PACKET, error_details, + ConnectionCloseBehavior::SILENT_CLOSE); + return; + } + + server_supported_versions_ = packet.versions; + CloseConnection( + QUIC_INVALID_VERSION, + absl::StrCat( + "Client may support one of the versions in the server's list, but " + "it's going to close the connection anyway. Supported versions: {", + ParsedQuicVersionVectorToString(framer_.supported_versions()), + "}, peer supported versions: {", + ParsedQuicVersionVectorToString(packet.versions), "}"), + send_connection_close_for_invalid_version_ + ? ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET + : ConnectionCloseBehavior::SILENT_CLOSE); +} + +// Handles retry for client connection. +void QuicConnection::OnRetryPacket(QuicConnectionId original_connection_id, + QuicConnectionId new_connection_id, + absl::string_view retry_token, + absl::string_view retry_integrity_tag, + absl::string_view retry_without_tag) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + if (version().UsesTls()) { + if (!CryptoUtils::ValidateRetryIntegrityTag( + version(), default_path_.server_connection_id, retry_without_tag, + retry_integrity_tag)) { + QUIC_DLOG(ERROR) << "Ignoring RETRY with invalid integrity tag"; + return; + } + } else { + if (original_connection_id != default_path_.server_connection_id) { + QUIC_DLOG(ERROR) << "Ignoring RETRY with original connection ID " + << original_connection_id << " not matching expected " + << default_path_.server_connection_id << " token " + << absl::BytesToHexString(retry_token); + return; + } + } + framer_.set_drop_incoming_retry_packets(true); + stats_.retry_packet_processed = true; + QUIC_DLOG(INFO) << "Received RETRY, replacing connection ID " + << default_path_.server_connection_id << " with " + << new_connection_id << ", received token " + << absl::BytesToHexString(retry_token); + if (!original_destination_connection_id_.has_value()) { + original_destination_connection_id_ = default_path_.server_connection_id; + } + QUICHE_DCHECK(!retry_source_connection_id_.has_value()) + << retry_source_connection_id_.value(); + retry_source_connection_id_ = new_connection_id; + ReplaceInitialServerConnectionId(new_connection_id); + packet_creator_.SetRetryToken(retry_token); + + // Reinstall initial crypters because the connection ID changed. + InstallInitialCrypters(default_path_.server_connection_id); + + sent_packet_manager_.MarkInitialPacketsForRetransmission(); +} + +void QuicConnection::SetOriginalDestinationConnectionId( + const QuicConnectionId& original_destination_connection_id) { + QUIC_DLOG(INFO) << "Setting original_destination_connection_id to " + << original_destination_connection_id + << " on connection with server_connection_id " + << default_path_.server_connection_id; + QUICHE_DCHECK_NE(original_destination_connection_id, + default_path_.server_connection_id); + InstallInitialCrypters(original_destination_connection_id); + QUICHE_DCHECK(!original_destination_connection_id_.has_value()) + << original_destination_connection_id_.value(); + original_destination_connection_id_ = original_destination_connection_id; + original_destination_connection_id_replacement_ = + default_path_.server_connection_id; +} + +QuicConnectionId QuicConnection::GetOriginalDestinationConnectionId() const { + if (original_destination_connection_id_.has_value()) { + return original_destination_connection_id_.value(); + } + return default_path_.server_connection_id; +} + +void QuicConnection::RetireOriginalDestinationConnectionId() { + if (original_destination_connection_id_.has_value()) { + visitor_->OnServerConnectionIdRetired(*original_destination_connection_id_); + original_destination_connection_id_.reset(); + } +} + +bool QuicConnection::ValidateServerConnectionId( + const QuicPacketHeader& header) const { + if (perspective_ == Perspective::IS_CLIENT && + header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + return true; + } + + QuicConnectionId server_connection_id = + GetServerConnectionIdAsRecipient(header, perspective_); + + if (server_connection_id == default_path_.server_connection_id || + server_connection_id == original_destination_connection_id_) { + return true; + } + + if (PacketCanReplaceServerConnectionId(header, perspective_)) { + QUIC_DLOG(INFO) << ENDPOINT << "Accepting packet with new connection ID " + << server_connection_id << " instead of " + << default_path_.server_connection_id; + return true; + } + + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_SERVER && + self_issued_cid_manager_ != nullptr && + self_issued_cid_manager_->IsConnectionIdInUse(server_connection_id)) { + return true; + } + + if (NewServerConnectionIdMightBeValid( + header, perspective_, server_connection_id_replaced_by_initial_)) { + return true; + } + + return false; +} + +bool QuicConnection::OnUnauthenticatedPublicHeader( + const QuicPacketHeader& header) { + last_received_packet_info_.destination_connection_id = + header.destination_connection_id; + // If last packet destination connection ID is the original server + // connection ID chosen by client, replaces it with the connection ID chosen + // by server. + if (perspective_ == Perspective::IS_SERVER && + original_destination_connection_id_.has_value() && + last_received_packet_info_.destination_connection_id == + *original_destination_connection_id_) { + last_received_packet_info_.destination_connection_id = + original_destination_connection_id_replacement_; + } + + // As soon as we receive an initial we start ignoring subsequent retries. + if (header.version_flag && header.long_packet_type == INITIAL) { + framer_.set_drop_incoming_retry_packets(true); + } + + if (!ValidateServerConnectionId(header)) { + ++stats_.packets_dropped; + QuicConnectionId server_connection_id = + GetServerConnectionIdAsRecipient(header, perspective_); + QUIC_DLOG(INFO) << ENDPOINT + << "Ignoring packet from unexpected server connection ID " + << server_connection_id << " instead of " + << default_path_.server_connection_id; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnIncorrectConnectionId(server_connection_id); + } + QUICHE_DCHECK_NE(Perspective::IS_SERVER, perspective_); + return false; + } + + if (!version().SupportsClientConnectionIds()) { + return true; + } + + if (perspective_ == Perspective::IS_SERVER && + header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + return true; + } + + QuicConnectionId client_connection_id = + GetClientConnectionIdAsRecipient(header, perspective_); + + if (client_connection_id == default_path_.client_connection_id) { + return true; + } + + if (!client_connection_id_is_set_ && perspective_ == Perspective::IS_SERVER) { + QUIC_DLOG(INFO) << ENDPOINT + << "Setting client connection ID from first packet to " + << client_connection_id; + set_client_connection_id(client_connection_id); + return true; + } + + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_CLIENT && + self_issued_cid_manager_ != nullptr && + self_issued_cid_manager_->IsConnectionIdInUse(client_connection_id)) { + return true; + } + + ++stats_.packets_dropped; + QUIC_DLOG(INFO) << ENDPOINT + << "Ignoring packet from unexpected client connection ID " + << client_connection_id << " instead of " + << default_path_.client_connection_id; + return false; +} + +bool QuicConnection::OnUnauthenticatedHeader(const QuicPacketHeader& header) { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnUnauthenticatedHeader(header); + } + + // Sanity check on the server connection ID in header. + QUICHE_DCHECK(ValidateServerConnectionId(header)); + + if (packet_creator_.HasPendingFrames()) { + // Incoming packets may change a queued ACK frame. + const std::string error_details = + "Pending frames must be serialized before incoming packets are " + "processed."; + QUIC_BUG(quic_pending_frames_not_serialized) + << error_details << ", received header: " << header; + CloseConnection(QUIC_INTERNAL_ERROR, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + return true; +} + +void QuicConnection::OnSuccessfulVersionNegotiation() { + visitor_->OnSuccessfulVersionNegotiation(version()); + if (debug_visitor_ != nullptr) { + debug_visitor_->OnSuccessfulVersionNegotiation(version()); + } +} + +void QuicConnection::OnSuccessfulMigration(bool is_port_change) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + if (IsPathDegrading() && !multi_port_stats_) { + // If path was previously degrading, and migration is successful after + // probing, restart the path degrading and blackhole detection. + // In the case of multi-port, since the alt-path state is inferred from + // historical data, we can't trust it until we receive data on the new path. + OnForwardProgressMade(); + } + if (IsAlternativePath(default_path_.self_address, + default_path_.peer_address)) { + // Reset alternative path state even if it is still under validation. + alternative_path_.Clear(); + } + // TODO(b/159074035): notify SentPacketManger with RTT sample from probing. + if (version().HasIetfQuicFrames() && !is_port_change) { + sent_packet_manager_.OnConnectionMigration(/*reset_send_algorithm=*/true); + } +} + +void QuicConnection::OnTransportParametersSent( + const TransportParameters& transport_parameters) const { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnTransportParametersSent(transport_parameters); + } +} + +void QuicConnection::OnTransportParametersReceived( + const TransportParameters& transport_parameters) const { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnTransportParametersReceived(transport_parameters); + } +} + +void QuicConnection::OnTransportParametersResumed( + const TransportParameters& transport_parameters) const { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnTransportParametersResumed(transport_parameters); + } +} + +bool QuicConnection::HasPendingAcks() const { return ack_alarm_->IsSet(); } + +void QuicConnection::OnUserAgentIdKnown(const std::string& /*user_agent_id*/) { + sent_packet_manager_.OnUserAgentIdKnown(); +} + +void QuicConnection::OnDecryptedPacket(size_t /*length*/, + EncryptionLevel level) { + last_received_packet_info_.decrypted_level = level; + last_received_packet_info_.decrypted = true; + if (level == ENCRYPTION_FORWARD_SECURE && + !have_decrypted_first_one_rtt_packet_) { + have_decrypted_first_one_rtt_packet_ = true; + if (version().UsesTls() && perspective_ == Perspective::IS_SERVER) { + // Servers MAY temporarily retain 0-RTT keys to allow decrypting reordered + // packets without requiring their contents to be retransmitted with 1-RTT + // keys. After receiving a 1-RTT packet, servers MUST discard 0-RTT keys + // within a short time; the RECOMMENDED time period is three times the + // Probe Timeout. + // https://quicwg.org/base-drafts/draft-ietf-quic-tls.html#name-discarding-0-rtt-keys + discard_zero_rtt_decryption_keys_alarm_->Set( + clock_->ApproximateNow() + sent_packet_manager_.GetPtoDelay() * 3); + } + } + if (EnforceAntiAmplificationLimit() && !IsHandshakeConfirmed() && + (level == ENCRYPTION_HANDSHAKE || level == ENCRYPTION_FORWARD_SECURE)) { + // Address is validated by successfully processing a HANDSHAKE or 1-RTT + // packet. + default_path_.validated = true; + stats_.address_validated_via_decrypting_packet = true; + } + idle_network_detector_.OnPacketReceived( + last_received_packet_info_.receipt_time); + + visitor_->OnPacketDecrypted(level); +} + +QuicSocketAddress QuicConnection::GetEffectivePeerAddressFromCurrentPacket() + const { + // By default, the connection is not proxied, and the effective peer address + // is the packet's source address, i.e. the direct peer address. + return last_received_packet_info_.source_address; +} + +bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPacketHeader(header, clock_->ApproximateNow(), + last_received_packet_info_.decrypted_level); + } + + // Will be decremented below if we fall through to return true. + ++stats_.packets_dropped; + + if (!ProcessValidatedPacket(header)) { + return false; + } + + // Initialize the current packet content state. + current_packet_content_ = NO_FRAMES_RECEIVED; + is_current_packet_connectivity_probing_ = false; + has_path_challenge_in_current_packet_ = false; + current_effective_peer_migration_type_ = NO_CHANGE; + + if (perspective_ == Perspective::IS_CLIENT) { + if (!GetLargestReceivedPacket().IsInitialized() || + header.packet_number > GetLargestReceivedPacket()) { + if (version().HasIetfQuicFrames()) { + // Client processes packets from any known server address, but only + // updates peer address on initialization and/or to validated server + // preferred address. + } else { + // Update direct_peer_address_ and default path peer_address immediately + // for client connections. + // TODO(fayang): only change peer addresses in application data packet + // number space. + UpdatePeerAddress(last_received_packet_info_.source_address); + default_path_.peer_address = GetEffectivePeerAddressFromCurrentPacket(); + } + } + } else { + // At server, remember the address change type of effective_peer_address + // in current_effective_peer_migration_type_. But this variable alone + // doesn't necessarily starts a migration. A migration will be started + // later, once the current packet is confirmed to meet the following + // conditions: + // 1) current_effective_peer_migration_type_ is not NO_CHANGE. + // 2) The current packet is not a connectivity probing. + // 3) The current packet is not reordered, i.e. its packet number is the + // largest of this connection so far. + // Once the above conditions are confirmed, a new migration will start + // even if there is an active migration underway. + current_effective_peer_migration_type_ = + QuicUtils::DetermineAddressChangeType( + default_path_.peer_address, + GetEffectivePeerAddressFromCurrentPacket()); + + if (connection_migration_use_new_cid_) { + auto effective_peer_address = GetEffectivePeerAddressFromCurrentPacket(); + // Since server does not send new connection ID to client before handshake + // completion and source connection ID is omitted in short header packet, + // the server_connection_id on PathState on the server side does not + // affect the packets server writes after handshake completion. On the + // other hand, it is still desirable to have the "correct" server + // connection ID set on path. + // 1) If client uses 1 unique server connection ID per path and the packet + // is received from an existing path, then + // last_received_packet_info_.destination_connection_id will always be the + // same as the server connection ID on path. Server side will maintain the + // 1-to-1 mapping from server connection ID to path. 2) If client uses + // multiple server connection IDs on the same path, compared to the + // server_connection_id on path, + // last_received_packet_info_.destination_connection_id has the advantage + // that it is still present in the session map since the packet can be + // routed here regardless of packet reordering. + if (IsDefaultPath(last_received_packet_info_.destination_address, + effective_peer_address)) { + default_path_.server_connection_id = + last_received_packet_info_.destination_connection_id; + } else if (IsAlternativePath( + last_received_packet_info_.destination_address, + effective_peer_address)) { + alternative_path_.server_connection_id = + last_received_packet_info_.destination_connection_id; + } + } + + if (last_received_packet_info_.destination_connection_id != + default_path_.server_connection_id && + (!original_destination_connection_id_.has_value() || + last_received_packet_info_.destination_connection_id != + *original_destination_connection_id_)) { + QUIC_CODE_COUNT(quic_connection_id_change); + } + + QUIC_DLOG_IF(INFO, current_effective_peer_migration_type_ != NO_CHANGE) + << ENDPOINT << "Effective peer's ip:port changed from " + << default_path_.peer_address.ToString() << " to " + << GetEffectivePeerAddressFromCurrentPacket().ToString() + << ", active_effective_peer_migration_type is " + << active_effective_peer_migration_type_; + } + + --stats_.packets_dropped; + QUIC_DVLOG(1) << ENDPOINT << "Received packet header: " << header; + last_received_packet_info_.header = header; + if (!stats_.first_decrypted_packet.IsInitialized()) { + stats_.first_decrypted_packet = + last_received_packet_info_.header.packet_number; + } + + switch (last_received_packet_info_.ecn_codepoint) { + case ECN_NOT_ECT: + break; + case ECN_ECT0: + stats_.num_ecn_marks_received.ect0++; + break; + case ECN_ECT1: + stats_.num_ecn_marks_received.ect1++; + break; + case ECN_CE: + stats_.num_ecn_marks_received.ce++; + break; + } + + // Record packet receipt to populate ack info before processing stream + // frames, since the processing may result in sending a bundled ack. + QuicTime receipt_time = idle_network_detector_.time_of_last_received_packet(); + if (SupportsMultiplePacketNumberSpaces()) { + receipt_time = last_received_packet_info_.receipt_time; + } + uber_received_packet_manager_.RecordPacketReceived( + last_received_packet_info_.decrypted_level, + last_received_packet_info_.header, receipt_time, + last_received_packet_info_.ecn_codepoint); + if (EnforceAntiAmplificationLimit() && !IsHandshakeConfirmed() && + !header.retry_token.empty() && + visitor_->ValidateToken(header.retry_token)) { + QUIC_DLOG(INFO) << ENDPOINT << "Address validated via token."; + QUIC_CODE_COUNT(quic_address_validated_via_token); + default_path_.validated = true; + stats_.address_validated_via_token = true; + } + QUICHE_DCHECK(connected_); + return true; +} + +bool QuicConnection::OnStreamFrame(const QuicStreamFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_3, !connected_) + << "Processing STREAM frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + + // Since a stream frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(STREAM_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnStreamFrame(frame); + } + if (!QuicUtils::IsCryptoStreamId(transport_version(), frame.stream_id) && + last_received_packet_info_.decrypted_level == ENCRYPTION_INITIAL) { + if (MaybeConsiderAsMemoryCorruption(frame)) { + CloseConnection(QUIC_MAYBE_CORRUPTED_MEMORY, + "Received crypto frame on non crypto stream.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + QUIC_PEER_BUG(quic_peer_bug_10511_6) + << ENDPOINT << "Received an unencrypted data frame: closing connection" + << " packet_number:" << last_received_packet_info_.header.packet_number + << " stream_id:" << frame.stream_id + << " received_packets:" << ack_frame(); + CloseConnection(QUIC_UNENCRYPTED_STREAM_DATA, + "Unencrypted stream data seen.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + // TODO(fayang): Consider moving UpdatePacketContent and + // MaybeUpdateAckTimeout to a stand-alone function instead of calling them for + // all frames. + MaybeUpdateAckTimeout(); + visitor_->OnStreamFrame(frame); + stats_.stream_bytes_received += frame.data_length; + ping_manager_.reset_consecutive_retransmittable_on_wire_count(); + return connected_; +} + +bool QuicConnection::OnCryptoFrame(const QuicCryptoFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_4, !connected_) + << "Processing CRYPTO frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + + // Since a CRYPTO frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(CRYPTO_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnCryptoFrame(frame); + } + MaybeUpdateAckTimeout(); + visitor_->OnCryptoFrame(frame); + return connected_; +} + +bool QuicConnection::OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) { + QUIC_BUG_IF(quic_bug_12714_5, !connected_) + << "Processing ACK frame start when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + + if (processing_ack_frame_) { + CloseConnection(QUIC_INVALID_ACK_DATA, + "Received a new ack while processing an ack frame.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + // Since an ack frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(ACK_FRAME)) { + return false; + } + + QUIC_DVLOG(1) << ENDPOINT + << "OnAckFrameStart, largest_acked: " << largest_acked; + + if (GetLargestReceivedPacketWithAck().IsInitialized() && + last_received_packet_info_.header.packet_number <= + GetLargestReceivedPacketWithAck()) { + QUIC_DLOG(INFO) << ENDPOINT << "Received an old ack frame: ignoring"; + return true; + } + + if (!sent_packet_manager_.GetLargestSentPacket().IsInitialized() || + largest_acked > sent_packet_manager_.GetLargestSentPacket()) { + QUIC_DLOG(WARNING) << ENDPOINT + << "Peer's observed unsent packet:" << largest_acked + << " vs " << sent_packet_manager_.GetLargestSentPacket() + << ". SupportsMultiplePacketNumberSpaces():" + << SupportsMultiplePacketNumberSpaces() + << ", last_received_packet_info_.decrypted_level:" + << last_received_packet_info_.decrypted_level; + // We got an ack for data we have not sent. + CloseConnection(QUIC_INVALID_ACK_DATA, "Largest observed too high.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + processing_ack_frame_ = true; + sent_packet_manager_.OnAckFrameStart( + largest_acked, ack_delay_time, + idle_network_detector_.time_of_last_received_packet()); + return true; +} + +bool QuicConnection::OnAckRange(QuicPacketNumber start, QuicPacketNumber end) { + QUIC_BUG_IF(quic_bug_12714_6, !connected_) + << "Processing ACK frame range when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + QUIC_DVLOG(1) << ENDPOINT << "OnAckRange: [" << start << ", " << end << ")"; + + if (GetLargestReceivedPacketWithAck().IsInitialized() && + last_received_packet_info_.header.packet_number <= + GetLargestReceivedPacketWithAck()) { + QUIC_DLOG(INFO) << ENDPOINT << "Received an old ack frame: ignoring"; + return true; + } + + sent_packet_manager_.OnAckRange(start, end); + return true; +} + +bool QuicConnection::OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) { + QUIC_BUG_IF(quic_bug_10511_7, !connected_) + << "Processing ACK frame time stamp when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + QUIC_DVLOG(1) << ENDPOINT << "OnAckTimestamp: [" << packet_number << ", " + << timestamp.ToDebuggingValue() << ")"; + + if (GetLargestReceivedPacketWithAck().IsInitialized() && + last_received_packet_info_.header.packet_number <= + GetLargestReceivedPacketWithAck()) { + QUIC_DLOG(INFO) << ENDPOINT << "Received an old ack frame: ignoring"; + return true; + } + + sent_packet_manager_.OnAckTimestamp(packet_number, timestamp); + return true; +} + +bool QuicConnection::OnAckFrameEnd( + QuicPacketNumber start, const absl::optional& ecn_counts) { + QUIC_BUG_IF(quic_bug_12714_7, !connected_) + << "Processing ACK frame end when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + QUIC_DVLOG(1) << ENDPOINT << "OnAckFrameEnd, start: " << start; + + if (GetLargestReceivedPacketWithAck().IsInitialized() && + last_received_packet_info_.header.packet_number <= + GetLargestReceivedPacketWithAck()) { + QUIC_DLOG(INFO) << ENDPOINT << "Received an old ack frame: ignoring"; + return true; + } + const bool one_rtt_packet_was_acked = + sent_packet_manager_.one_rtt_packet_acked(); + const bool zero_rtt_packet_was_acked = + sent_packet_manager_.zero_rtt_packet_acked(); + const AckResult ack_result = sent_packet_manager_.OnAckFrameEnd( + idle_network_detector_.time_of_last_received_packet(), + last_received_packet_info_.header.packet_number, + last_received_packet_info_.decrypted_level, ecn_counts); + if (ack_result != PACKETS_NEWLY_ACKED && + ack_result != NO_PACKETS_NEWLY_ACKED) { + // Error occurred (e.g., this ACK tries to ack packets in wrong packet + // number space), and this would cause the connection to be closed. + QUIC_DLOG(ERROR) << ENDPOINT + << "Error occurred when processing an ACK frame: " + << QuicUtils::AckResultToString(ack_result); + return false; + } + if (SupportsMultiplePacketNumberSpaces() && !one_rtt_packet_was_acked && + sent_packet_manager_.one_rtt_packet_acked()) { + visitor_->OnOneRttPacketAcknowledged(); + } + if (debug_visitor_ != nullptr && version().UsesTls() && + !zero_rtt_packet_was_acked && + sent_packet_manager_.zero_rtt_packet_acked()) { + debug_visitor_->OnZeroRttPacketAcked(); + } + // Cancel the send alarm because new packets likely have been acked, which + // may change the congestion window and/or pacing rate. Canceling the alarm + // causes CanWrite to recalculate the next send time. + if (send_alarm_->IsSet()) { + send_alarm_->Cancel(); + } + if (supports_release_time_) { + // Update pace time into future because smoothed RTT is likely updated. + UpdateReleaseTimeIntoFuture(); + } + SetLargestReceivedPacketWithAck( + last_received_packet_info_.header.packet_number); + // If the incoming ack's packets set expresses missing packets: peer is still + // waiting for a packet lower than a packet that we are no longer planning to + // send. + // If the incoming ack's packets set expresses received packets: peer is still + // acking packets which we never care about. + // Send an ack to raise the high water mark. + const bool send_stop_waiting = + no_stop_waiting_frames_ ? false : GetLeastUnacked() > start; + PostProcessAfterAckFrame(send_stop_waiting, + ack_result == PACKETS_NEWLY_ACKED); + processing_ack_frame_ = false; + return connected_; +} + +bool QuicConnection::OnStopWaitingFrame(const QuicStopWaitingFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_8, !connected_) + << "Processing STOP_WAITING frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + + // Since a stop waiting frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(STOP_WAITING_FRAME)) { + return false; + } + + if (no_stop_waiting_frames_) { + return true; + } + if (largest_seen_packet_with_stop_waiting_.IsInitialized() && + last_received_packet_info_.header.packet_number <= + largest_seen_packet_with_stop_waiting_) { + QUIC_DLOG(INFO) << ENDPOINT + << "Received an old stop waiting frame: ignoring"; + return true; + } + + const char* error = ValidateStopWaitingFrame(frame); + if (error != nullptr) { + CloseConnection(QUIC_INVALID_STOP_WAITING_DATA, error, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnStopWaitingFrame(frame); + } + + largest_seen_packet_with_stop_waiting_ = + last_received_packet_info_.header.packet_number; + uber_received_packet_manager_.DontWaitForPacketsBefore( + last_received_packet_info_.decrypted_level, frame.least_unacked); + return connected_; +} + +bool QuicConnection::OnPaddingFrame(const QuicPaddingFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_9, !connected_) + << "Processing PADDING frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + if (!UpdatePacketContent(PADDING_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPaddingFrame(frame); + } + return true; +} + +bool QuicConnection::OnPingFrame(const QuicPingFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_10, !connected_) + << "Processing PING frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + if (!UpdatePacketContent(PING_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + QuicTime::Delta ping_received_delay = QuicTime::Delta::Zero(); + const QuicTime now = clock_->ApproximateNow(); + if (now > stats_.connection_creation_time) { + ping_received_delay = now - stats_.connection_creation_time; + } + debug_visitor_->OnPingFrame(frame, ping_received_delay); + } + MaybeUpdateAckTimeout(); + return true; +} + +const char* QuicConnection::ValidateStopWaitingFrame( + const QuicStopWaitingFrame& stop_waiting) { + const QuicPacketNumber peer_least_packet_awaiting_ack = + uber_received_packet_manager_.peer_least_packet_awaiting_ack(); + if (peer_least_packet_awaiting_ack.IsInitialized() && + stop_waiting.least_unacked < peer_least_packet_awaiting_ack) { + QUIC_DLOG(ERROR) << ENDPOINT << "Peer's sent low least_unacked: " + << stop_waiting.least_unacked << " vs " + << peer_least_packet_awaiting_ack; + // We never process old ack frames, so this number should only increase. + return "Least unacked too small."; + } + + if (stop_waiting.least_unacked > + last_received_packet_info_.header.packet_number) { + QUIC_DLOG(ERROR) << ENDPOINT + << "Peer sent least_unacked:" << stop_waiting.least_unacked + << " greater than the enclosing packet number:" + << last_received_packet_info_.header.packet_number; + return "Least unacked too large."; + } + + return nullptr; +} + +bool QuicConnection::OnRstStreamFrame(const QuicRstStreamFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_11, !connected_) + << "Processing RST_STREAM frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + + // Since a reset stream frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(RST_STREAM_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnRstStreamFrame(frame); + } + QUIC_DLOG(INFO) << ENDPOINT + << "RST_STREAM_FRAME received for stream: " << frame.stream_id + << " with error: " + << QuicRstStreamErrorCodeToString(frame.error_code); + MaybeUpdateAckTimeout(); + visitor_->OnRstStream(frame); + return connected_; +} + +bool QuicConnection::OnStopSendingFrame(const QuicStopSendingFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_12, !connected_) + << "Processing STOP_SENDING frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + + // Since a reset stream frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(STOP_SENDING_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnStopSendingFrame(frame); + } + + QUIC_DLOG(INFO) << ENDPOINT << "STOP_SENDING frame received for stream: " + << frame.stream_id + << " with error: " << frame.ietf_error_code; + MaybeUpdateAckTimeout(); + visitor_->OnStopSendingFrame(frame); + return connected_; +} + +class ReversePathValidationContext : public QuicPathValidationContext { + public: + ReversePathValidationContext(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& effective_peer_address, + QuicConnection* connection) + : QuicPathValidationContext(self_address, peer_address, + effective_peer_address), + connection_(connection) {} + + QuicPacketWriter* WriterToUse() override { return connection_->writer(); } + + private: + QuicConnection* connection_; +}; + +bool QuicConnection::OnPathChallengeFrame(const QuicPathChallengeFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_8, !connected_) + << "Processing PATH_CHALLENGE frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + if (has_path_challenge_in_current_packet_) { + // Only respond to the 1st PATH_CHALLENGE in the packet. + return true; + } + if (!validate_client_addresses_) { + return OnPathChallengeFrameInternal(frame); + } + { + // TODO(danzh) inline OnPathChallengeFrameInternal() once + // validate_client_addresses_ is deprecated. + if (!OnPathChallengeFrameInternal(frame)) { + return false; + } + } + return connected_; +} + +bool QuicConnection::OnPathChallengeFrameInternal( + const QuicPathChallengeFrame& frame) { + should_proactively_validate_peer_address_on_path_challenge_ = false; + // UpdatePacketContent() may start reverse path validation. + if (!UpdatePacketContent(PATH_CHALLENGE_FRAME)) { + return false; + } + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPathChallengeFrame(frame); + } + // On the server side, send response to the source address of the current + // incoming packet according to RFC9000. + // On the client side, send response to the default peer address which should + // be on an existing path with a pre-assigned a destination CID. + const QuicSocketAddress effective_peer_address_to_respond = + perspective_ == Perspective::IS_CLIENT + ? effective_peer_address() + : GetEffectivePeerAddressFromCurrentPacket(); + const QuicSocketAddress direct_peer_address_to_respond = + perspective_ == Perspective::IS_CLIENT + ? direct_peer_address_ + : last_received_packet_info_.source_address; + QuicConnectionId client_cid, server_cid; + FindOnPathConnectionIds(last_received_packet_info_.destination_address, + effective_peer_address_to_respond, &client_cid, + &server_cid); + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, direct_peer_address_to_respond, client_cid, server_cid, + connection_migration_use_new_cid_); + if (should_proactively_validate_peer_address_on_path_challenge_) { + // Conditions to proactively validate peer address: + // The perspective is server + // The PATH_CHALLENGE is received on an unvalidated alternative path. + // The connection isn't validating migrated peer address, which is of + // higher prority. + QUIC_DVLOG(1) << "Proactively validate the effective peer address " + << effective_peer_address_to_respond; + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 2, 6); + ValidatePath(std::make_unique( + default_path_.self_address, direct_peer_address_to_respond, + effective_peer_address_to_respond, this), + std::make_unique( + this, peer_address()), + PathValidationReason::kReversePathValidation); + } + has_path_challenge_in_current_packet_ = true; + MaybeUpdateAckTimeout(); + // Queue or send PATH_RESPONSE. + if (!SendPathResponse(frame.data_buffer, direct_peer_address_to_respond, + effective_peer_address_to_respond)) { + QUIC_CODE_COUNT(quic_failed_to_send_path_response); + } + // TODO(b/150095588): change the stats to + // num_valid_path_challenge_received. + ++stats_.num_connectivity_probing_received; + + // SendPathResponse() might cause connection to be closed. + return connected_; +} + +bool QuicConnection::OnPathResponseFrame(const QuicPathResponseFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_9, !connected_) + << "Processing PATH_RESPONSE frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + ++stats_.num_path_response_received; + if (!UpdatePacketContent(PATH_RESPONSE_FRAME)) { + return false; + } + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPathResponseFrame(frame); + } + MaybeUpdateAckTimeout(); + path_validator_.OnPathResponse( + frame.data_buffer, last_received_packet_info_.destination_address); + return connected_; +} + +bool QuicConnection::OnConnectionCloseFrame( + const QuicConnectionCloseFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_10, !connected_) + << "Processing CONNECTION_CLOSE frame when connection is closed. " + "Received packet info: " + << last_received_packet_info_; + + // Since a connection close frame was received, this is not a connectivity + // probe. A probe only contains a PING and full padding. + if (!UpdatePacketContent(CONNECTION_CLOSE_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnConnectionCloseFrame(frame); + } + switch (frame.close_type) { + case GOOGLE_QUIC_CONNECTION_CLOSE: + QUIC_DLOG(INFO) << ENDPOINT << "Received ConnectionClose for connection: " + << connection_id() << ", with error: " + << QuicErrorCodeToString(frame.quic_error_code) << " (" + << frame.error_details << ")"; + break; + case IETF_QUIC_TRANSPORT_CONNECTION_CLOSE: + QUIC_DLOG(INFO) << ENDPOINT + << "Received Transport ConnectionClose for connection: " + << connection_id() << ", with error: " + << QuicErrorCodeToString(frame.quic_error_code) << " (" + << frame.error_details << ")" + << ", transport error code: " + << QuicIetfTransportErrorCodeString( + static_cast( + frame.wire_error_code)) + << ", error frame type: " + << frame.transport_close_frame_type; + break; + case IETF_QUIC_APPLICATION_CONNECTION_CLOSE: + QUIC_DLOG(INFO) << ENDPOINT + << "Received Application ConnectionClose for connection: " + << connection_id() << ", with error: " + << QuicErrorCodeToString(frame.quic_error_code) << " (" + << frame.error_details << ")" + << ", application error code: " << frame.wire_error_code; + break; + } + + if (frame.quic_error_code == QUIC_BAD_MULTIPATH_FLAG) { + QUIC_LOG_FIRST_N(ERROR, 10) + << "Unexpected QUIC_BAD_MULTIPATH_FLAG error." + << " last_received_header: " << last_received_packet_info_.header + << " encryption_level: " << encryption_level_; + } + TearDownLocalConnectionState(frame, ConnectionCloseSource::FROM_PEER); + return connected_; +} + +bool QuicConnection::OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_13, !connected_) + << "Processing MAX_STREAMS frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + if (!UpdatePacketContent(MAX_STREAMS_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnMaxStreamsFrame(frame); + } + MaybeUpdateAckTimeout(); + return visitor_->OnMaxStreamsFrame(frame) && connected_; +} + +bool QuicConnection::OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_11, !connected_) + << "Processing STREAMS_BLOCKED frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + if (!UpdatePacketContent(STREAMS_BLOCKED_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnStreamsBlockedFrame(frame); + } + MaybeUpdateAckTimeout(); + return visitor_->OnStreamsBlockedFrame(frame) && connected_; +} + +bool QuicConnection::OnGoAwayFrame(const QuicGoAwayFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_14, !connected_) + << "Processing GOAWAY frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + + // Since a go away frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(GOAWAY_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnGoAwayFrame(frame); + } + QUIC_DLOG(INFO) << ENDPOINT << "GOAWAY_FRAME received with last good stream: " + << frame.last_good_stream_id + << " and error: " << QuicErrorCodeToString(frame.error_code) + << " and reason: " << frame.reason_phrase; + MaybeUpdateAckTimeout(); + visitor_->OnGoAway(frame); + return connected_; +} + +bool QuicConnection::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_12, !connected_) + << "Processing WINDOW_UPDATE frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + + // Since a window update frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(WINDOW_UPDATE_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnWindowUpdateFrame( + frame, idle_network_detector_.time_of_last_received_packet()); + } + QUIC_DVLOG(1) << ENDPOINT << "WINDOW_UPDATE_FRAME received " << frame; + MaybeUpdateAckTimeout(); + visitor_->OnWindowUpdateFrame(frame); + return connected_; +} + +void QuicConnection::OnClientConnectionIdAvailable() { + QUICHE_DCHECK(perspective_ == Perspective::IS_SERVER); + if (!peer_issued_cid_manager_->HasUnusedConnectionId()) { + return; + } + if (default_path_.client_connection_id.IsEmpty()) { + // Count client connection ID patched onto the default path. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 3, + 6); + const QuicConnectionIdData* unused_cid_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + QUIC_DVLOG(1) << ENDPOINT << "Patch connection ID " + << unused_cid_data->connection_id << " to default path"; + default_path_.client_connection_id = unused_cid_data->connection_id; + default_path_.stateless_reset_token = + unused_cid_data->stateless_reset_token; + QUICHE_DCHECK(!packet_creator_.HasPendingFrames()); + QUICHE_DCHECK(packet_creator_.GetDestinationConnectionId().IsEmpty()); + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + return; + } + if (alternative_path_.peer_address.IsInitialized() && + alternative_path_.client_connection_id.IsEmpty()) { + // Count client connection ID patched onto the alternative path. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 4, + 6); + const QuicConnectionIdData* unused_cid_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + QUIC_DVLOG(1) << ENDPOINT << "Patch connection ID " + << unused_cid_data->connection_id << " to alternative path"; + alternative_path_.client_connection_id = unused_cid_data->connection_id; + alternative_path_.stateless_reset_token = + unused_cid_data->stateless_reset_token; + } +} + +bool QuicConnection::OnNewConnectionIdFrameInner( + const QuicNewConnectionIdFrame& frame) { + if (peer_issued_cid_manager_ == nullptr) { + CloseConnection( + IETF_QUIC_PROTOCOL_VIOLATION, + "Receives NEW_CONNECTION_ID while peer uses zero length connection ID", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + std::string error_detail; + QuicErrorCode error = + peer_issued_cid_manager_->OnNewConnectionIdFrame(frame, &error_detail); + if (error != QUIC_NO_ERROR) { + CloseConnection(error, error_detail, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + if (perspective_ == Perspective::IS_SERVER) { + OnClientConnectionIdAvailable(); + } + MaybeUpdateAckTimeout(); + return true; +} + +bool QuicConnection::OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& frame) { + QUICHE_DCHECK(version().HasIetfQuicFrames()); + QUIC_BUG_IF(quic_bug_10511_13, !connected_) + << "Processing NEW_CONNECTION_ID frame when connection is closed. " + "Received packet info: " + << last_received_packet_info_; + if (!UpdatePacketContent(NEW_CONNECTION_ID_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnNewConnectionIdFrame(frame); + } + + if (!OnNewConnectionIdFrameInner(frame)) { + return false; + } + if (multi_port_stats_ != nullptr) { + MaybeCreateMultiPortPath(); + } + return true; +} + +bool QuicConnection::OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) { + QUICHE_DCHECK(version().HasIetfQuicFrames()); + QUIC_BUG_IF(quic_bug_10511_14, !connected_) + << "Processing RETIRE_CONNECTION_ID frame when connection is closed. " + "Received packet info: " + << last_received_packet_info_; + if (!UpdatePacketContent(RETIRE_CONNECTION_ID_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnRetireConnectionIdFrame(frame); + } + if (!connection_migration_use_new_cid_) { + // Do not respond to RetireConnectionId frame. + return true; + } + if (self_issued_cid_manager_ == nullptr) { + CloseConnection( + IETF_QUIC_PROTOCOL_VIOLATION, + "Receives RETIRE_CONNECTION_ID while new connection ID is never issued", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + std::string error_detail; + QuicErrorCode error = self_issued_cid_manager_->OnRetireConnectionIdFrame( + frame, sent_packet_manager_.GetPtoDelay(), &error_detail); + if (error != QUIC_NO_ERROR) { + CloseConnection(error, error_detail, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + // Count successfully received RETIRE_CONNECTION_ID frames. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 5, 6); + MaybeUpdateAckTimeout(); + return true; +} + +bool QuicConnection::OnNewTokenFrame(const QuicNewTokenFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_15, !connected_) + << "Processing NEW_TOKEN frame when connection is closed. Received " + "packet info: " + << last_received_packet_info_; + if (!UpdatePacketContent(NEW_TOKEN_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnNewTokenFrame(frame); + } + if (perspective_ == Perspective::IS_SERVER) { + CloseConnection(QUIC_INVALID_NEW_TOKEN, "Server received new token frame.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + // NEW_TOKEN frame should insitgate ACKs. + MaybeUpdateAckTimeout(); + visitor_->OnNewTokenReceived(frame.token); + return true; +} + +bool QuicConnection::OnMessageFrame(const QuicMessageFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_16, !connected_) + << "Processing MESSAGE frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + + // Since a message frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(MESSAGE_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnMessageFrame(frame); + } + MaybeUpdateAckTimeout(); + visitor_->OnMessageReceived( + absl::string_view(frame.data, frame.message_length)); + return connected_; +} + +bool QuicConnection::OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_15, !connected_) + << "Processing HANDSHAKE_DONE frame when connection " + "is closed. Received packet " + "info: " + << last_received_packet_info_; + if (!version().UsesTls()) { + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, + "Handshake done frame is unsupported", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + if (perspective_ == Perspective::IS_SERVER) { + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, + "Server received handshake done frame.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + + // Since a handshake done frame was received, this is not a connectivity + // probe. A probe only contains a PING and full padding. + if (!UpdatePacketContent(HANDSHAKE_DONE_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnHandshakeDoneFrame(frame); + } + MaybeUpdateAckTimeout(); + visitor_->OnHandshakeDoneReceived(); + return connected_; +} + +bool QuicConnection::OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) { + QUIC_BUG_IF(quic_bug_10511_16, !connected_) + << "Processing ACK_FREQUENCY frame when connection " + "is closed. Received packet " + "info: " + << last_received_packet_info_; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnAckFrequencyFrame(frame); + } + if (!UpdatePacketContent(ACK_FREQUENCY_FRAME)) { + return false; + } + + if (!can_receive_ack_frequency_frame_) { + QUIC_LOG_EVERY_N_SEC(ERROR, 120) << "Get unexpected AckFrequencyFrame."; + return false; + } + if (auto packet_number_space = + QuicUtils::GetPacketNumberSpace( + last_received_packet_info_.decrypted_level) == APPLICATION_DATA) { + uber_received_packet_manager_.OnAckFrequencyFrame(frame); + } else { + QUIC_LOG_EVERY_N_SEC(ERROR, 120) + << "Get AckFrequencyFrame in packet number space " + << packet_number_space; + } + MaybeUpdateAckTimeout(); + return true; +} + +bool QuicConnection::OnBlockedFrame(const QuicBlockedFrame& frame) { + QUIC_BUG_IF(quic_bug_12714_17, !connected_) + << "Processing BLOCKED frame when connection is closed. Received packet " + "info: " + << last_received_packet_info_; + + // Since a blocked frame was received, this is not a connectivity probe. + // A probe only contains a PING and full padding. + if (!UpdatePacketContent(BLOCKED_FRAME)) { + return false; + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnBlockedFrame(frame); + } + QUIC_DLOG(INFO) << ENDPOINT + << "BLOCKED_FRAME received for stream: " << frame.stream_id; + MaybeUpdateAckTimeout(); + visitor_->OnBlockedFrame(frame); + stats_.blocked_frames_received++; + return connected_; +} + +void QuicConnection::OnPacketComplete() { + // Don't do anything if this packet closed the connection. + if (!connected_) { + ClearLastFrames(); + return; + } + + if (IsCurrentPacketConnectivityProbing()) { + QUICHE_DCHECK(!version().HasIetfQuicFrames()); + ++stats_.num_connectivity_probing_received; + } + + QUIC_DVLOG(1) << ENDPOINT << "Got" + << (SupportsMultiplePacketNumberSpaces() + ? (" " + + EncryptionLevelToString( + last_received_packet_info_.decrypted_level)) + : "") + << " packet " << last_received_packet_info_.header.packet_number + << " for " + << GetServerConnectionIdAsRecipient( + last_received_packet_info_.header, perspective_); + + QUIC_DLOG_IF(INFO, current_packet_content_ == SECOND_FRAME_IS_PADDING) + << ENDPOINT << "Received a padded PING packet. is_probing: " + << IsCurrentPacketConnectivityProbing(); + + if (!version().HasIetfQuicFrames()) { + MaybeRespondToConnectivityProbingOrMigration(); + } + + current_effective_peer_migration_type_ = NO_CHANGE; + + // For IETF QUIC, it is guaranteed that TLS will give connection the + // corresponding write key before read key. In other words, connection should + // never process a packet while an ACK for it cannot be encrypted. + if (!should_last_packet_instigate_acks_) { + uber_received_packet_manager_.MaybeUpdateAckTimeout( + should_last_packet_instigate_acks_, + last_received_packet_info_.decrypted_level, + last_received_packet_info_.header.packet_number, + last_received_packet_info_.receipt_time, clock_->ApproximateNow(), + sent_packet_manager_.GetRttStats()); + } + + ClearLastFrames(); + CloseIfTooManyOutstandingSentPackets(); +} + +void QuicConnection::MaybeRespondToConnectivityProbingOrMigration() { + QUICHE_DCHECK(!version().HasIetfQuicFrames()); + if (IsCurrentPacketConnectivityProbing()) { + visitor_->OnPacketReceived(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address, + /*is_connectivity_probe=*/true); + return; + } + if (perspective_ == Perspective::IS_CLIENT) { + // This node is a client, notify that a speculative connectivity probing + // packet has been received anyway. + QUIC_DVLOG(1) << ENDPOINT + << "Received a speculative connectivity probing packet for " + << GetServerConnectionIdAsRecipient( + last_received_packet_info_.header, perspective_) + << " from ip:port: " + << last_received_packet_info_.source_address.ToString() + << " to ip:port: " + << last_received_packet_info_.destination_address.ToString(); + visitor_->OnPacketReceived(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address, + /*is_connectivity_probe=*/false); + return; + } +} + +bool QuicConnection::IsValidStatelessResetToken( + const StatelessResetToken& token) const { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + return default_path_.stateless_reset_token.has_value() && + QuicUtils::AreStatelessResetTokensEqual( + token, *default_path_.stateless_reset_token); +} + +void QuicConnection::OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& /*packet*/) { + // TODO(fayang): Add OnAuthenticatedIetfStatelessResetPacket to + // debug_visitor_. + QUICHE_DCHECK(version().HasIetfInvariantHeader()); + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + + if (!IsDefaultPath(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address)) { + // This packet is received on a probing path. Do not close connection. + if (IsAlternativePath(last_received_packet_info_.destination_address, + GetEffectivePeerAddressFromCurrentPacket())) { + QUIC_BUG_IF(quic_bug_12714_18, alternative_path_.validated) + << "STATELESS_RESET received on alternate path after it's " + "validated."; + path_validator_.CancelPathValidation(); + } else { + QUIC_BUG(quic_bug_10511_17) + << "Received Stateless Reset on unknown socket."; + } + return; + } + + const std::string error_details = "Received stateless reset."; + QUIC_CODE_COUNT(quic_tear_down_local_connection_on_stateless_reset); + TearDownLocalConnectionState(QUIC_PUBLIC_RESET, NO_IETF_QUIC_ERROR, + error_details, ConnectionCloseSource::FROM_PEER); +} + +void QuicConnection::OnKeyUpdate(KeyUpdateReason reason) { + QUICHE_DCHECK(support_key_update_for_connection_); + QUIC_DLOG(INFO) << ENDPOINT << "Key phase updated for " << reason; + + lowest_packet_sent_in_current_key_phase_.Clear(); + stats_.key_update_count++; + + // If another key update triggers while the previous + // discard_previous_one_rtt_keys_alarm_ hasn't fired yet, cancel it since the + // old keys would already be discarded. + discard_previous_one_rtt_keys_alarm_->Cancel(); + + visitor_->OnKeyUpdate(reason); +} + +void QuicConnection::OnDecryptedFirstPacketInKeyPhase() { + QUIC_DLOG(INFO) << ENDPOINT << "OnDecryptedFirstPacketInKeyPhase"; + // An endpoint SHOULD retain old read keys for no more than three times the + // PTO after having received a packet protected using the new keys. After this + // period, old read keys and their corresponding secrets SHOULD be discarded. + // + // Note that this will cause an unnecessary + // discard_previous_one_rtt_keys_alarm_ on the first packet in the 1RTT + // encryption level, but this is harmless. + discard_previous_one_rtt_keys_alarm_->Set( + clock_->ApproximateNow() + sent_packet_manager_.GetPtoDelay() * 3); +} + +std::unique_ptr +QuicConnection::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + QUIC_DLOG(INFO) << ENDPOINT << "AdvanceKeysAndCreateCurrentOneRttDecrypter"; + return visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr QuicConnection::CreateCurrentOneRttEncrypter() { + QUIC_DLOG(INFO) << ENDPOINT << "CreateCurrentOneRttEncrypter"; + return visitor_->CreateCurrentOneRttEncrypter(); +} + +void QuicConnection::ClearLastFrames() { + should_last_packet_instigate_acks_ = false; +} + +void QuicConnection::CloseIfTooManyOutstandingSentPackets() { + // This occurs if we don't discard old packets we've seen fast enough. It's + // possible largest observed is less than leaset unacked. + const bool should_close = + sent_packet_manager_.GetLargestSentPacket().IsInitialized() && + sent_packet_manager_.GetLargestSentPacket() > + sent_packet_manager_.GetLeastUnacked() + max_tracked_packets_; + + if (should_close) { + CloseConnection( + QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS, + absl::StrCat("More than ", max_tracked_packets_, + " outstanding, least_unacked: ", + sent_packet_manager_.GetLeastUnacked().ToUint64(), + ", packets_processed: ", stats_.packets_processed, + ", last_decrypted_packet_level: ", + EncryptionLevelToString( + last_received_packet_info_.decrypted_level)), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + } +} + +const QuicFrame QuicConnection::GetUpdatedAckFrame() { + QUICHE_DCHECK(!uber_received_packet_manager_.IsAckFrameEmpty( + QuicUtils::GetPacketNumberSpace(encryption_level_))) + << "Try to retrieve an empty ACK frame"; + return uber_received_packet_manager_.GetUpdatedAckFrame( + QuicUtils::GetPacketNumberSpace(encryption_level_), + clock_->ApproximateNow()); +} + +void QuicConnection::PopulateStopWaitingFrame( + QuicStopWaitingFrame* stop_waiting) { + stop_waiting->least_unacked = GetLeastUnacked(); +} + +QuicPacketNumber QuicConnection::GetLeastUnacked() const { + return sent_packet_manager_.GetLeastUnacked(); +} + +bool QuicConnection::HandleWriteBlocked() { + if (!writer_->IsWriteBlocked()) { + return false; + } + + visitor_->OnWriteBlocked(); + return true; +} + +void QuicConnection::MaybeSendInResponseToPacket() { + if (!connected_) { + return; + } + + // If the writer is blocked, don't attempt to send packets now or in the send + // alarm. When the writer unblocks, OnCanWrite() will be called for this + // connection to send. + if (HandleWriteBlocked()) { + return; + } + + // Now that we have received an ack, we might be able to send packets which + // are queued locally, or drain streams which are blocked. + if (defer_send_in_response_to_packets_) { + send_alarm_->Update(clock_->ApproximateNow(), QuicTime::Delta::Zero()); + } else { + WriteIfNotBlocked(); + } +} + +size_t QuicConnection::SendCryptoData(EncryptionLevel level, + size_t write_length, + QuicStreamOffset offset) { + if (write_length == 0) { + QUIC_BUG(quic_bug_10511_18) << "Attempt to send empty crypto frame"; + return 0; + } + ScopedPacketFlusher flusher(this); + return packet_creator_.ConsumeCryptoData(level, write_length, offset); +} + +QuicConsumedData QuicConnection::SendStreamData(QuicStreamId id, + size_t write_length, + QuicStreamOffset offset, + StreamSendingState state) { + if (state == NO_FIN && write_length == 0) { + QUIC_BUG(quic_bug_10511_19) << "Attempt to send empty stream frame"; + return QuicConsumedData(0, false); + } + + if (perspective_ == Perspective::IS_SERVER && + version().CanSendCoalescedPackets() && !IsHandshakeConfirmed()) { + if (in_probe_time_out_ && coalesced_packet_.NumberOfPackets() == 0u) { + // PTO fires while handshake is not confirmed. Do not preempt handshake + // data with stream data. + QUIC_CODE_COUNT(quic_try_to_send_half_rtt_data_when_pto_fires); + return QuicConsumedData(0, false); + } + if (coalesced_packet_.ContainsPacketOfEncryptionLevel(ENCRYPTION_INITIAL) && + coalesced_packet_.NumberOfPackets() == 1u) { + // Handshake is not confirmed yet, if there is only an initial packet in + // the coalescer, try to bundle an ENCRYPTION_HANDSHAKE packet before + // sending stream data. + sent_packet_manager_.RetransmitDataOfSpaceIfAny(HANDSHAKE_DATA); + } + } + // Opportunistically bundle an ack with every outgoing packet. + // Particularly, we want to bundle with handshake packets since we don't + // know which decrypter will be used on an ack packet following a handshake + // packet (a handshake packet from client to server could result in a REJ or + // a SHLO from the server, leading to two different decrypters at the + // server.) + ScopedPacketFlusher flusher(this); + return packet_creator_.ConsumeData(id, write_length, offset, state); +} + +bool QuicConnection::SendControlFrame(const QuicFrame& frame) { + if (SupportsMultiplePacketNumberSpaces() && + (encryption_level_ == ENCRYPTION_INITIAL || + encryption_level_ == ENCRYPTION_HANDSHAKE) && + frame.type != PING_FRAME) { + // Allow PING frame to be sent without APPLICATION key. For example, when + // anti-amplification limit is used, client needs to send something to avoid + // handshake deadlock. + QUIC_DVLOG(1) << ENDPOINT << "Failed to send control frame: " << frame + << " at encryption level: " << encryption_level_; + return false; + } + ScopedPacketFlusher flusher(this); + const bool consumed = + packet_creator_.ConsumeRetransmittableControlFrame(frame); + if (!consumed) { + QUIC_DVLOG(1) << ENDPOINT << "Failed to send control frame: " << frame; + return false; + } + if (frame.type == PING_FRAME) { + // Flush PING frame immediately. + packet_creator_.FlushCurrentPacket(); + stats_.ping_frames_sent++; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPingSent(); + } + } + if (frame.type == BLOCKED_FRAME) { + stats_.blocked_frames_sent++; + } + return true; +} + +void QuicConnection::OnStreamReset(QuicStreamId id, + QuicRstStreamErrorCode error) { + if (error == QUIC_STREAM_NO_ERROR) { + // All data for streams which are reset with QUIC_STREAM_NO_ERROR must + // be received by the peer. + return; + } + // Flush stream frames of reset stream. + if (packet_creator_.HasPendingStreamFramesOfStream(id)) { + ScopedPacketFlusher flusher(this); + packet_creator_.FlushCurrentPacket(); + } + // TODO(ianswett): Consider checking for 3 RTOs when the last stream is + // cancelled as well. +} + +const QuicConnectionStats& QuicConnection::GetStats() { + const RttStats* rtt_stats = sent_packet_manager_.GetRttStats(); + + // Update rtt and estimated bandwidth. + QuicTime::Delta min_rtt = rtt_stats->min_rtt(); + if (min_rtt.IsZero()) { + // If min RTT has not been set, use initial RTT instead. + min_rtt = rtt_stats->initial_rtt(); + } + stats_.min_rtt_us = min_rtt.ToMicroseconds(); + + QuicTime::Delta srtt = rtt_stats->SmoothedOrInitialRtt(); + stats_.srtt_us = srtt.ToMicroseconds(); + + stats_.estimated_bandwidth = sent_packet_manager_.BandwidthEstimate(); + sent_packet_manager_.GetSendAlgorithm()->PopulateConnectionStats(&stats_); + stats_.egress_mtu = long_term_mtu_; + stats_.ingress_mtu = largest_received_packet_size_; + return stats_; +} + +void QuicConnection::OnCoalescedPacket(const QuicEncryptedPacket& packet) { + QueueCoalescedPacket(packet); +} + +void QuicConnection::OnUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, + bool has_decryption_key) { + QUIC_DVLOG(1) << ENDPOINT << "Received undecryptable packet of length " + << packet.length() << " with" + << (has_decryption_key ? "" : "out") << " key at level " + << decryption_level + << " while connection is at encryption level " + << encryption_level_; + QUICHE_DCHECK(EncryptionLevelIsValid(decryption_level)); + if (encryption_level_ != ENCRYPTION_FORWARD_SECURE) { + ++stats_.undecryptable_packets_received_before_handshake_complete; + } + + const bool should_enqueue = + ShouldEnqueueUnDecryptablePacket(decryption_level, has_decryption_key); + if (should_enqueue) { + QueueUndecryptablePacket(packet, decryption_level); + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnUndecryptablePacket(decryption_level, + /*dropped=*/!should_enqueue); + } + + if (has_decryption_key) { + stats_.num_failed_authentication_packets_received++; + if (version().UsesTls()) { + // Should always be non-null if has_decryption_key is true. + QUICHE_DCHECK(framer_.GetDecrypter(decryption_level)); + const QuicPacketCount integrity_limit = + framer_.GetDecrypter(decryption_level)->GetIntegrityLimit(); + QUIC_DVLOG(2) << ENDPOINT << "Checking AEAD integrity limits:" + << " num_failed_authentication_packets_received=" + << stats_.num_failed_authentication_packets_received + << " integrity_limit=" << integrity_limit; + if (stats_.num_failed_authentication_packets_received >= + integrity_limit) { + const std::string error_details = absl::StrCat( + "decrypter integrity limit reached:" + " num_failed_authentication_packets_received=", + stats_.num_failed_authentication_packets_received, + " integrity_limit=", integrity_limit); + CloseConnection(QUIC_AEAD_LIMIT_REACHED, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + } + } + } + + if (version().UsesTls() && perspective_ == Perspective::IS_SERVER && + decryption_level == ENCRYPTION_ZERO_RTT && !has_decryption_key && + had_zero_rtt_decrypter_) { + QUIC_CODE_COUNT_N( + quic_server_received_tls_zero_rtt_packet_after_discarding_decrypter, 1, + 3); + stats_ + .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter++; + } +} + +bool QuicConnection::ShouldEnqueueUnDecryptablePacket( + EncryptionLevel decryption_level, bool has_decryption_key) const { + if (has_decryption_key) { + // We already have the key for this decryption level, therefore no + // future keys will allow it be decrypted. + return false; + } + if (IsHandshakeComplete()) { + // We do not expect to install any further keys. + return false; + } + if (undecryptable_packets_.size() >= max_undecryptable_packets_) { + // We do not queue more than max_undecryptable_packets_ packets. + return false; + } + if (version().KnowsWhichDecrypterToUse() && + decryption_level == ENCRYPTION_INITIAL) { + // When the corresponding decryption key is not available, all + // non-Initial packets should be buffered until the handshake is complete. + return false; + } + if (perspective_ == Perspective::IS_CLIENT && version().UsesTls() && + decryption_level == ENCRYPTION_ZERO_RTT) { + // Only clients send Zero RTT packets in IETF QUIC. + QUIC_PEER_BUG(quic_peer_bug_client_received_zero_rtt) + << "Client received a Zero RTT packet, not buffering."; + return false; + } + return true; +} + +std::string QuicConnection::UndecryptablePacketsInfo() const { + std::string info = absl::StrCat( + "num_undecryptable_packets: ", undecryptable_packets_.size(), " {"); + for (const auto& packet : undecryptable_packets_) { + absl::StrAppend(&info, "[", + EncryptionLevelToString(packet.encryption_level), ", ", + packet.packet->length(), "]"); + } + absl::StrAppend(&info, "}"); + return info; +} + +void QuicConnection::ProcessUdpPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) { + if (!connected_) { + return; + } + QUIC_DVLOG(2) << ENDPOINT << "Received encrypted " << packet.length() + << " bytes:" << std::endl + << quiche::QuicheTextUtils::HexDump( + absl::string_view(packet.data(), packet.length())); + QUIC_BUG_IF(quic_bug_12714_21, current_packet_data_ != nullptr) + << "ProcessUdpPacket must not be called while processing a packet."; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPacketReceived(self_address, peer_address, packet); + } + last_received_packet_info_ = + ReceivedPacketInfo(self_address, peer_address, packet.receipt_time(), + packet.length(), packet.ecn_codepoint()); + current_packet_data_ = packet.data(); + + if (!default_path_.self_address.IsInitialized()) { + default_path_.self_address = last_received_packet_info_.destination_address; + } else if (default_path_.self_address != self_address && + sent_server_preferred_address_.IsInitialized() && + self_address.Normalized() == + sent_server_preferred_address_.Normalized()) { + // If the packet is received at the preferred address, treat it as if it is + // received on the original server address. + last_received_packet_info_.destination_address = default_path_.self_address; + last_received_packet_info_.actual_destination_address = self_address; + } + + if (!direct_peer_address_.IsInitialized()) { + if (perspective_ == Perspective::IS_CLIENT) { + AddKnownServerAddress(last_received_packet_info_.source_address); + } + UpdatePeerAddress(last_received_packet_info_.source_address); + } + + if (!default_path_.peer_address.IsInitialized()) { + const QuicSocketAddress effective_peer_addr = + GetEffectivePeerAddressFromCurrentPacket(); + + // The default path peer_address must be initialized at the beginning of the + // first packet processed(here). If effective_peer_addr is uninitialized, + // just set effective_peer_address_ to the direct peer address. + default_path_.peer_address = effective_peer_addr.IsInitialized() + ? effective_peer_addr + : direct_peer_address_; + } + + stats_.bytes_received += packet.length(); + ++stats_.packets_received; + if (IsDefaultPath(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address) && + EnforceAntiAmplificationLimit()) { + last_received_packet_info_.received_bytes_counted = true; + default_path_.bytes_received_before_address_validation += + last_received_packet_info_.length; + } + + // Ensure the time coming from the packet reader is within 2 minutes of now. + if (std::abs((packet.receipt_time() - clock_->ApproximateNow()).ToSeconds()) > + 2 * 60) { + QUIC_BUG(quic_bug_10511_21) + << "Packet receipt time:" << packet.receipt_time().ToDebuggingValue() + << " too far from current time:" + << clock_->ApproximateNow().ToDebuggingValue(); + } + QUIC_DVLOG(1) << ENDPOINT << "time of last received packet: " + << packet.receipt_time().ToDebuggingValue() << " from peer " + << last_received_packet_info_.source_address << ", to " + << last_received_packet_info_.destination_address; + + ScopedPacketFlusher flusher(this); + if (!framer_.ProcessPacket(packet)) { + // If we are unable to decrypt this packet, it might be + // because the CHLO or SHLO packet was lost. + QUIC_DVLOG(1) << ENDPOINT + << "Unable to process packet. Last packet processed: " + << last_received_packet_info_.header.packet_number; + current_packet_data_ = nullptr; + is_current_packet_connectivity_probing_ = false; + + MaybeProcessCoalescedPackets(); + return; + } + + ++stats_.packets_processed; + + QUIC_DLOG_IF(INFO, active_effective_peer_migration_type_ != NO_CHANGE) + << "sent_packet_manager_.GetLargestObserved() = " + << sent_packet_manager_.GetLargestObserved() + << ", highest_packet_sent_before_effective_peer_migration_ = " + << highest_packet_sent_before_effective_peer_migration_; + if (!validate_client_addresses_ && + active_effective_peer_migration_type_ != NO_CHANGE && + sent_packet_manager_.GetLargestObserved().IsInitialized() && + (!highest_packet_sent_before_effective_peer_migration_.IsInitialized() || + sent_packet_manager_.GetLargestObserved() > + highest_packet_sent_before_effective_peer_migration_)) { + if (perspective_ == Perspective::IS_SERVER) { + OnEffectivePeerMigrationValidated(/*is_migration_linkable=*/true); + } + } + + if (!MaybeProcessCoalescedPackets()) { + MaybeProcessUndecryptablePackets(); + MaybeSendInResponseToPacket(); + } + SetPingAlarm(); + RetirePeerIssuedConnectionIdsNoLongerOnPath(); + current_packet_data_ = nullptr; + is_current_packet_connectivity_probing_ = false; +} + +void QuicConnection::OnBlockedWriterCanWrite() { + writer_->SetWritable(); + OnCanWrite(); +} + +void QuicConnection::OnCanWrite() { + if (!connected_) { + return; + } + if (writer_->IsWriteBlocked()) { + const std::string error_details = + "Writer is blocked while calling OnCanWrite."; + QUIC_BUG(quic_bug_10511_22) << ENDPOINT << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + ScopedPacketFlusher flusher(this); + + WriteQueuedPackets(); + const QuicTime ack_timeout = + uber_received_packet_manager_.GetEarliestAckTimeout(); + if (ack_timeout.IsInitialized() && ack_timeout <= clock_->ApproximateNow()) { + // Send an ACK now because either 1) we were write blocked when we last + // tried to send an ACK, or 2) both ack alarm and send alarm were set to + // go off together. + if (SupportsMultiplePacketNumberSpaces()) { + SendAllPendingAcks(); + } else { + SendAck(); + } + } + + // Sending queued packets may have caused the socket to become write blocked, + // or the congestion manager to prohibit sending. + if (!CanWrite(HAS_RETRANSMITTABLE_DATA)) { + return; + } + + // Tell the session it can write. + visitor_->OnCanWrite(); + + // After the visitor writes, it may have caused the socket to become write + // blocked or the congestion manager to prohibit sending, so check again. + if (visitor_->WillingAndAbleToWrite() && !send_alarm_->IsSet() && + CanWrite(HAS_RETRANSMITTABLE_DATA)) { + // We're not write blocked, but some data wasn't written. Register for + // 'immediate' resumption so we'll keep writing after other connections. + send_alarm_->Set(clock_->ApproximateNow()); + } +} + +void QuicConnection::WriteIfNotBlocked() { + if (framer().is_processing_packet()) { + QUIC_BUG(connection_write_mid_packet_processing) + << ENDPOINT << "Tried to write in mid of packet processing"; + return; + } + if (!HandleWriteBlocked()) { + OnCanWrite(); + } +} + +void QuicConnection::MaybeClearQueuedPacketsOnPathChange() { + if (connection_migration_use_new_cid_ && + peer_issued_cid_manager_ != nullptr && HasQueuedPackets()) { + // Discard packets serialized with the connection ID on the old code path. + // It is possible to clear queued packets only if connection ID changes. + // However, the case where connection ID is unchanged and queued packets are + // non-empty is quite rare. + ClearQueuedPackets(); + } +} + +void QuicConnection::ReplaceInitialServerConnectionId( + const QuicConnectionId& new_server_connection_id) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); + if (version().HasIetfQuicFrames()) { + if (new_server_connection_id.IsEmpty()) { + peer_issued_cid_manager_ = nullptr; + } else { + if (peer_issued_cid_manager_ != nullptr) { + QUIC_BUG_IF(quic_bug_12714_22, + !peer_issued_cid_manager_->IsConnectionIdActive( + default_path_.server_connection_id)) + << "Connection ID replaced header is no longer active. old id: " + << default_path_.server_connection_id + << " new_id: " << new_server_connection_id; + peer_issued_cid_manager_->ReplaceConnectionId( + default_path_.server_connection_id, new_server_connection_id); + } else { + peer_issued_cid_manager_ = + std::make_unique( + kMinNumOfActiveConnectionIds, new_server_connection_id, clock_, + alarm_factory_, this, context()); + } + } + } + default_path_.server_connection_id = new_server_connection_id; + packet_creator_.SetServerConnectionId(default_path_.server_connection_id); +} + +void QuicConnection::FindMatchingOrNewClientConnectionIdOrToken( + const PathState& default_path, const PathState& alternative_path, + const QuicConnectionId& server_connection_id, + QuicConnectionId* client_connection_id, + absl::optional* stateless_reset_token) { + QUICHE_DCHECK(perspective_ == Perspective::IS_SERVER); + if (peer_issued_cid_manager_ == nullptr || + server_connection_id == default_path.server_connection_id) { + *client_connection_id = default_path.client_connection_id; + *stateless_reset_token = default_path.stateless_reset_token; + return; + } + if (server_connection_id == alternative_path_.server_connection_id) { + *client_connection_id = alternative_path.client_connection_id; + *stateless_reset_token = alternative_path.stateless_reset_token; + return; + } + if (!connection_migration_use_new_cid_) { + QUIC_BUG(quic_bug_46004) << "Cannot find matching connection ID."; + return; + } + auto* connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + if (connection_id_data == nullptr) { + return; + } + *client_connection_id = connection_id_data->connection_id; + *stateless_reset_token = connection_id_data->stateless_reset_token; +} + +bool QuicConnection::FindOnPathConnectionIds( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicConnectionId* client_connection_id, + QuicConnectionId* server_connection_id) const { + if (IsDefaultPath(self_address, peer_address)) { + *client_connection_id = default_path_.client_connection_id, + *server_connection_id = default_path_.server_connection_id; + return true; + } + if (IsAlternativePath(self_address, peer_address)) { + *client_connection_id = alternative_path_.client_connection_id, + *server_connection_id = alternative_path_.server_connection_id; + return true; + } + // Client should only send packets on either default or alternative path, so + // it shouldn't fail here. If the server fail to find CID to use, no packet + // will be generated on this path. + // TODO(danzh) fix SendPathResponse() to respond to probes from a different + // client port with non-Zero client CID. + QUIC_BUG_IF(failed to find on path connection ids, + perspective_ == Perspective::IS_CLIENT) + << "Fails to find on path connection IDs"; + return false; +} + +void QuicConnection::SetDefaultPathState(PathState new_path_state) { + default_path_ = std::move(new_path_state); + if (connection_migration_use_new_cid_) { + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + packet_creator_.SetServerConnectionId(default_path_.server_connection_id); + } +} + +bool QuicConnection::ProcessValidatedPacket(const QuicPacketHeader& header) { + if (perspective_ == Perspective::IS_CLIENT && version().HasIetfQuicFrames() && + direct_peer_address_.IsInitialized() && + last_received_packet_info_.source_address.IsInitialized() && + direct_peer_address_ != last_received_packet_info_.source_address && + !IsKnownServerAddress(last_received_packet_info_.source_address)) { + // Discard packets received from unseen server addresses. + return false; + } + + if (perspective_ == Perspective::IS_SERVER && + default_path_.self_address.IsInitialized() && + last_received_packet_info_.destination_address.IsInitialized() && + default_path_.self_address != + last_received_packet_info_.destination_address) { + // Allow change between pure IPv4 and equivalent mapped IPv4 address. + if (default_path_.self_address.port() != + last_received_packet_info_.destination_address.port() || + default_path_.self_address.host().Normalized() != + last_received_packet_info_.destination_address.host() + .Normalized()) { + if (!visitor_->AllowSelfAddressChange()) { + const std::string error_details = absl::StrCat( + "Self address migration is not supported at the server, current " + "address: ", + default_path_.self_address.ToString(), + ", server preferred address: ", + sent_server_preferred_address_.ToString(), + ", received packet address: ", + last_received_packet_info_.destination_address.ToString(), + ", size: ", last_received_packet_info_.length, + ", packet number: ", header.packet_number.ToString(), + ", encryption level: ", + EncryptionLevelToString( + last_received_packet_info_.decrypted_level)); + QUIC_LOG_EVERY_N_SEC(INFO, 100) << error_details; + QUIC_CODE_COUNT(quic_dropped_packets_with_changed_server_address); + return false; + } + } + default_path_.self_address = last_received_packet_info_.destination_address; + } + + if (GetQuicReloadableFlag(quic_use_received_client_addresses_cache) && + perspective_ == Perspective::IS_SERVER && + !last_received_packet_info_.actual_destination_address.IsInitialized() && + last_received_packet_info_.source_address.IsInitialized()) { + QUIC_RELOADABLE_FLAG_COUNT(quic_use_received_client_addresses_cache); + // Record client address of packets received on server original address. + received_client_addresses_cache_.Insert( + last_received_packet_info_.source_address, + std::make_unique(true)); + } + + if (perspective_ == Perspective::IS_SERVER && + last_received_packet_info_.actual_destination_address.IsInitialized() && + !IsHandshakeConfirmed() && + GetEffectivePeerAddressFromCurrentPacket() != + default_path_.peer_address) { + // Our client implementation has an optimization to spray packets from + // different sockets to the server's preferred address before handshake + // gets confirmed. In this case, do not kick off client address migration + // detection. + QUICHE_DCHECK(sent_server_preferred_address_.IsInitialized()); + last_received_packet_info_.source_address = direct_peer_address_; + } + + if (PacketCanReplaceServerConnectionId(header, perspective_) && + default_path_.server_connection_id != header.source_connection_id) { + QUICHE_DCHECK_EQ(header.long_packet_type, INITIAL); + if (server_connection_id_replaced_by_initial_) { + QUIC_DLOG(ERROR) << ENDPOINT << "Refusing to replace connection ID " + << default_path_.server_connection_id << " with " + << header.source_connection_id; + return false; + } + server_connection_id_replaced_by_initial_ = true; + QUIC_DLOG(INFO) << ENDPOINT << "Replacing connection ID " + << default_path_.server_connection_id << " with " + << header.source_connection_id; + if (!original_destination_connection_id_.has_value()) { + original_destination_connection_id_ = default_path_.server_connection_id; + } + ReplaceInitialServerConnectionId(header.source_connection_id); + } + + if (!ValidateReceivedPacketNumber(header.packet_number)) { + return false; + } + + if (!version_negotiated_) { + if (perspective_ == Perspective::IS_CLIENT) { + QUICHE_DCHECK(!header.version_flag || header.form != GOOGLE_QUIC_PACKET); + if (!version().HasIetfInvariantHeader()) { + // If the client gets a packet without the version flag from the server + // it should stop sending version since the version negotiation is done. + // IETF QUIC stops sending version once encryption level switches to + // forward secure. + packet_creator_.StopSendingVersion(); + } + version_negotiated_ = true; + OnSuccessfulVersionNegotiation(); + } + } + + if (last_received_packet_info_.length > largest_received_packet_size_) { + largest_received_packet_size_ = last_received_packet_info_.length; + } + + if (perspective_ == Perspective::IS_SERVER && + encryption_level_ == ENCRYPTION_INITIAL && + last_received_packet_info_.length > packet_creator_.max_packet_length()) { + if (GetQuicFlag(quic_use_lower_server_response_mtu_for_test)) { + SetMaxPacketLength( + std::min(last_received_packet_info_.length, QuicByteCount(1250))); + } else { + SetMaxPacketLength(last_received_packet_info_.length); + } + } + return true; +} + +bool QuicConnection::ValidateReceivedPacketNumber( + QuicPacketNumber packet_number) { + // If this packet has already been seen, or the sender has told us that it + // will not be retransmitted, then stop processing the packet. + if (!uber_received_packet_manager_.IsAwaitingPacket( + last_received_packet_info_.decrypted_level, packet_number)) { + QUIC_DLOG(INFO) << ENDPOINT << "Packet " << packet_number + << " no longer being waited for at level " + << static_cast( + last_received_packet_info_.decrypted_level) + << ". Discarding."; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnDuplicatePacket(packet_number); + } + return false; + } + + return true; +} + +void QuicConnection::WriteQueuedPackets() { + QUICHE_DCHECK(!writer_->IsWriteBlocked()); + QUIC_CLIENT_HISTOGRAM_COUNTS("QuicSession.NumQueuedPacketsBeforeWrite", + buffered_packets_.size(), 1, 1000, 50, ""); + + while (!buffered_packets_.empty()) { + if (HandleWriteBlocked()) { + break; + } + const BufferedPacket& packet = buffered_packets_.front(); + WriteResult result = SendPacketToWriter( + packet.data.get(), packet.length, packet.self_address.host(), + packet.peer_address, per_packet_options_); + QUIC_DVLOG(1) << ENDPOINT << "Sending buffered packet, result: " << result; + if (IsMsgTooBig(writer_, result) && packet.length > long_term_mtu_) { + // When MSG_TOO_BIG is returned, the system typically knows what the + // actual MTU is, so there is no need to probe further. + // TODO(wub): Reduce max packet size to a safe default, or the actual MTU. + mtu_discoverer_.Disable(); + mtu_discovery_alarm_->Cancel(); + buffered_packets_.pop_front(); + continue; + } + if (IsWriteError(result.status)) { + OnWriteError(result.error_code); + break; + } + if (result.status == WRITE_STATUS_OK || + result.status == WRITE_STATUS_BLOCKED_DATA_BUFFERED) { + buffered_packets_.pop_front(); + } + if (IsWriteBlockedStatus(result.status)) { + visitor_->OnWriteBlocked(); + break; + } + } +} + +void QuicConnection::MarkZeroRttPacketsForRetransmission(int reject_reason) { + sent_packet_manager_.MarkZeroRttPacketsForRetransmission(); + if (debug_visitor_ != nullptr && version().UsesTls()) { + debug_visitor_->OnZeroRttRejected(reject_reason); + } +} + +void QuicConnection::NeuterUnencryptedPackets() { + sent_packet_manager_.NeuterUnencryptedPackets(); + // This may have changed the retransmission timer, so re-arm it. + SetRetransmissionAlarm(); + if (default_enable_5rto_blackhole_detection_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_default_enable_5rto_blackhole_detection2, + 1, 3); + // Consider this as forward progress since this is called when initial key + // gets discarded (or previous unencrypted data is not needed anymore). + OnForwardProgressMade(); + } + if (SupportsMultiplePacketNumberSpaces()) { + // Stop sending ack of initial packet number space. + uber_received_packet_manager_.ResetAckStates(ENCRYPTION_INITIAL); + // Re-arm ack alarm. + ack_alarm_->Update(uber_received_packet_manager_.GetEarliestAckTimeout(), + kAlarmGranularity); + } +} + +bool QuicConnection::ShouldGeneratePacket( + HasRetransmittableData retransmittable, IsHandshake handshake) { + QUICHE_DCHECK(handshake != IS_HANDSHAKE || + QuicVersionUsesCryptoFrames(transport_version())) + << ENDPOINT + << "Handshake in STREAM frames should not check ShouldGeneratePacket"; + if (peer_issued_cid_manager_ != nullptr && + packet_creator_.GetDestinationConnectionId().IsEmpty()) { + QUICHE_DCHECK(version().HasIetfQuicFrames()); + QUIC_CODE_COUNT(quic_generate_packet_blocked_by_no_connection_id); + QUIC_BUG_IF(quic_bug_90265_1, perspective_ == Perspective::IS_CLIENT); + QUIC_DLOG(INFO) << ENDPOINT + << "There is no destination connection ID available to " + "generate packet."; + return false; + } + if (IsDefaultPath(default_path_.self_address, + packet_creator_.peer_address())) { + return CanWrite(retransmittable); + } + // This is checking on the alternative path with a different peer address. The + // self address and the writer used are the same as the default path. In the + // case of different self address and writer, writing packet would use a + // differnt code path without checking the states of the default writer. + return connected_ && !HandleWriteBlocked(); +} + +const QuicFrames QuicConnection::MaybeBundleAckOpportunistically() { + if (!ack_frequency_sent_ && sent_packet_manager_.CanSendAckFrequency()) { + if (packet_creator_.NextSendingPacketNumber() >= + FirstSendingPacketNumber() + kMinReceivedBeforeAckDecimation) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_can_send_ack_frequency, 3, 3); + ack_frequency_sent_ = true; + auto frame = sent_packet_manager_.GetUpdatedAckFrequencyFrame(); + visitor_->SendAckFrequency(frame); + } + } + + QuicFrames frames; + const bool has_pending_ack = + uber_received_packet_manager_ + .GetAckTimeout(QuicUtils::GetPacketNumberSpace(encryption_level_)) + .IsInitialized(); + if (!has_pending_ack && stop_waiting_count_ <= 1) { + // No need to send an ACK. + return frames; + } + ResetAckStates(); + + QUIC_DVLOG(1) << ENDPOINT << "Bundle an ACK opportunistically"; + QuicFrame updated_ack_frame = GetUpdatedAckFrame(); + QUIC_BUG_IF(quic_bug_12714_23, updated_ack_frame.ack_frame->packets.Empty()) + << ENDPOINT << "Attempted to opportunistically bundle an empty " + << encryption_level_ << " ACK, " << (has_pending_ack ? "" : "!") + << "has_pending_ack, stop_waiting_count_ " << stop_waiting_count_; + frames.push_back(updated_ack_frame); + + if (!no_stop_waiting_frames_) { + QuicStopWaitingFrame stop_waiting; + PopulateStopWaitingFrame(&stop_waiting); + frames.push_back(QuicFrame(stop_waiting)); + } + return frames; +} + +bool QuicConnection::CanWrite(HasRetransmittableData retransmittable) { + if (!connected_) { + return false; + } + + if (version().CanSendCoalescedPackets() && + framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL) && + framer_.is_processing_packet()) { + // While we still have initial keys, suppress sending in mid of packet + // processing. + // TODO(fayang): always suppress sending while in the mid of packet + // processing. + QUIC_DVLOG(1) << ENDPOINT + << "Suppress sending in the mid of packet processing"; + return false; + } + + if (fill_coalesced_packet_) { + // Try to coalesce packet, only allow to write when creator is on soft max + // packet length. Given the next created packet is going to fill current + // coalesced packet, do not check amplification factor. + return packet_creator_.HasSoftMaxPacketLength(); + } + + if (sent_packet_manager_.pending_timer_transmission_count() > 0) { + // Allow sending if there are pending tokens, which occurs when: + // 1) firing PTO, + // 2) bundling CRYPTO data with ACKs, + // 3) coalescing CRYPTO data of higher space. + return true; + } + + if (LimitedByAmplificationFactor(packet_creator_.max_packet_length())) { + // Server is constrained by the amplification restriction. + QUIC_CODE_COUNT(quic_throttled_by_amplification_limit); + QUIC_DVLOG(1) << ENDPOINT + << "Constrained by amplification restriction to peer address " + << default_path_.peer_address << " bytes received " + << default_path_.bytes_received_before_address_validation + << ", bytes sent" + << default_path_.bytes_sent_before_address_validation; + ++stats_.num_amplification_throttling; + return false; + } + + if (HandleWriteBlocked()) { + return false; + } + + // Allow acks and probing frames to be sent immediately. + if (retransmittable == NO_RETRANSMITTABLE_DATA) { + return true; + } + // If the send alarm is set, wait for it to fire. + if (send_alarm_->IsSet()) { + return false; + } + + QuicTime now = clock_->Now(); + QuicTime::Delta delay = sent_packet_manager_.TimeUntilSend(now); + if (delay.IsInfinite()) { + send_alarm_->Cancel(); + return false; + } + + // Scheduler requires a delay. + if (!delay.IsZero()) { + if (delay <= release_time_into_future_) { + // Required delay is within pace time into future, send now. + return true; + } + // Cannot send packet now because delay is too far in the future. + send_alarm_->Update(now + delay, kAlarmGranularity); + QUIC_DVLOG(1) << ENDPOINT << "Delaying sending " << delay.ToMilliseconds() + << "ms"; + return false; + } + return true; +} + +QuicTime QuicConnection::CalculatePacketSentTime() { + const QuicTime now = clock_->Now(); + if (!supports_release_time_ || per_packet_options_ == nullptr) { + // Don't change the release delay. + return now; + } + + auto next_release_time_result = sent_packet_manager_.GetNextReleaseTime(); + + // Release before |now| is impossible. + QuicTime next_release_time = + std::max(now, next_release_time_result.release_time); + per_packet_options_->release_time_delay = next_release_time - now; + per_packet_options_->allow_burst = next_release_time_result.allow_burst; + return next_release_time; +} + +bool QuicConnection::WritePacket(SerializedPacket* packet) { + if (sent_packet_manager_.GetLargestSentPacket().IsInitialized() && + packet->packet_number < sent_packet_manager_.GetLargestSentPacket()) { + QUIC_BUG(quic_bug_10511_23) + << "Attempt to write packet:" << packet->packet_number + << " after:" << sent_packet_manager_.GetLargestSentPacket(); + CloseConnection(QUIC_INTERNAL_ERROR, "Packet written out of order.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return true; + } + const bool is_mtu_discovery = QuicUtils::ContainsFrameType( + packet->nonretransmittable_frames, MTU_DISCOVERY_FRAME); + const SerializedPacketFate fate = packet->fate; + // Termination packets are encrypted and saved, so don't exit early. + QuicErrorCode error_code = QUIC_NO_ERROR; + const bool is_termination_packet = IsTerminationPacket(*packet, &error_code); + QuicPacketNumber packet_number = packet->packet_number; + QuicPacketLength encrypted_length = packet->encrypted_length; + // Termination packets are eventually owned by TimeWaitListManager. + // Others are deleted at the end of this call. + if (is_termination_packet) { + if (termination_packets_ == nullptr) { + termination_packets_.reset( + new std::vector>); + } + // Copy the buffer so it's owned in the future. + char* buffer_copy = CopyBuffer(*packet); + termination_packets_->emplace_back( + new QuicEncryptedPacket(buffer_copy, encrypted_length, true)); + if (error_code == QUIC_SILENT_IDLE_TIMEOUT) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, perspective_); + // TODO(fayang): populate histogram indicating the time elapsed from this + // connection gets closed to following client packets get received. + QUIC_DVLOG(1) << ENDPOINT + << "Added silent connection close to termination packets, " + "num of termination packets: " + << termination_packets_->size(); + return true; + } + } + + QUICHE_DCHECK_LE(encrypted_length, kMaxOutgoingPacketSize); + QUICHE_DCHECK(is_mtu_discovery || + encrypted_length <= packet_creator_.max_packet_length()) + << " encrypted_length=" << encrypted_length + << " > packet_creator max_packet_length=" + << packet_creator_.max_packet_length(); + QUIC_DVLOG(1) << ENDPOINT << "Sending packet " << packet_number << " : " + << (IsRetransmittable(*packet) == HAS_RETRANSMITTABLE_DATA + ? "data bearing " + : " ack or probing only ") + << ", encryption level: " << packet->encryption_level + << ", encrypted length:" << encrypted_length + << ", fate: " << fate << " to peer " << packet->peer_address; + QUIC_DVLOG(2) << ENDPOINT << packet->encryption_level << " packet number " + << packet_number << " of length " << encrypted_length << ": " + << std::endl + << quiche::QuicheTextUtils::HexDump(absl::string_view( + packet->encrypted_buffer, encrypted_length)); + + // Measure the RTT from before the write begins to avoid underestimating the + // min_rtt_, especially in cases where the thread blocks or gets swapped out + // during the WritePacket below. + QuicTime packet_send_time = CalculatePacketSentTime(); + WriteResult result(WRITE_STATUS_OK, encrypted_length); + QuicSocketAddress send_to_address = packet->peer_address; + QuicSocketAddress send_from_address = self_address(); + if (perspective_ == Perspective::IS_SERVER && + sent_server_preferred_address_.IsInitialized() && + received_client_addresses_cache_.Lookup(send_to_address) == + received_client_addresses_cache_.end()) { + // Given server has not received packets from send_to_address to + // self_address(), most NATs do not allow packets from self_address() to + // send_to_address to go through. Override packet's self address to + // sent_server_preferred_address_. + // TODO(b/262386897): server should validate reverse path before changing + // self address of packets to send. + send_from_address = sent_server_preferred_address_; + } + // Self address is always the default self address on this code path. + const bool send_on_current_path = send_to_address == peer_address(); + if (!send_on_current_path) { + QUIC_BUG_IF(quic_send_non_probing_frames_on_alternative_path, + ContainsNonProbingFrame(*packet)) + << "Packet " << packet->packet_number + << " with non-probing frames was sent on alternative path: " + "nonretransmittable_frames: " + << QuicFramesToString(packet->nonretransmittable_frames) + << " retransmittable_frames: " + << QuicFramesToString(packet->retransmittable_frames); + } + switch (fate) { + case DISCARD: + ++stats_.packets_discarded; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPacketDiscarded(*packet); + } + return true; + case COALESCE: + QUIC_BUG_IF(quic_bug_12714_24, + !version().CanSendCoalescedPackets() || coalescing_done_); + if (!coalesced_packet_.MaybeCoalescePacket( + *packet, send_from_address, send_to_address, + helper_->GetStreamSendBufferAllocator(), + packet_creator_.max_packet_length())) { + // Failed to coalesce packet, flush current coalesced packet. + if (!FlushCoalescedPacket()) { + QUIC_BUG_IF(quic_connection_connected_after_flush_coalesced_failure, + connected_) + << "QUIC connection is still connected after failing to flush " + "coalesced packet."; + // Failed to flush coalesced packet, write error has been handled. + return false; + } + if (!coalesced_packet_.MaybeCoalescePacket( + *packet, send_from_address, send_to_address, + helper_->GetStreamSendBufferAllocator(), + packet_creator_.max_packet_length())) { + // Failed to coalesce packet even it is the only packet, raise a write + // error. + QUIC_DLOG(ERROR) << ENDPOINT << "Failed to coalesce packet"; + result.error_code = WRITE_STATUS_FAILED_TO_COALESCE_PACKET; + break; + } + } + if (coalesced_packet_.length() < coalesced_packet_.max_packet_length()) { + QUIC_DVLOG(1) << ENDPOINT << "Trying to set soft max packet length to " + << coalesced_packet_.max_packet_length() - + coalesced_packet_.length(); + packet_creator_.SetSoftMaxPacketLength( + coalesced_packet_.max_packet_length() - coalesced_packet_.length()); + } + break; + case BUFFER: + QUIC_DVLOG(1) << ENDPOINT << "Adding packet: " << packet->packet_number + << " to buffered packets"; + buffered_packets_.emplace_back(*packet, send_from_address, + send_to_address); + break; + case SEND_TO_WRITER: + // Stop using coalescer from now on. + coalescing_done_ = true; + // At this point, packet->release_encrypted_buffer is either nullptr, + // meaning |packet->encrypted_buffer| is a stack buffer, or not-nullptr, + /// meaning it's a writer-allocated buffer. Note that connectivity probing + // packets do not use this function, so setting release_encrypted_buffer + // to nullptr will not cause probing packets to be leaked. + // + // writer_->WritePacket transfers buffer ownership back to the writer. + packet->release_encrypted_buffer = nullptr; + result = SendPacketToWriter(packet->encrypted_buffer, encrypted_length, + send_from_address.host(), send_to_address, + per_packet_options_); + // This is a work around for an issue with linux UDP GSO batch writers. + // When sending a GSO packet with 2 segments, if the first segment is + // larger than the path MTU, instead of EMSGSIZE, the linux kernel returns + // EINVAL, which translates to WRITE_STATUS_ERROR and causes conneciton to + // be closed. By manually flush the writer here, the MTU probe is sent in + // a normal(non-GSO) packet, so the kernel can return EMSGSIZE and we will + // not close the connection. + if (is_mtu_discovery && writer_->IsBatchMode()) { + result = writer_->Flush(); + } + break; + default: + QUICHE_DCHECK(false); + break; + } + + QUIC_HISTOGRAM_ENUM( + "QuicConnection.WritePacketStatus", result.status, + WRITE_STATUS_NUM_VALUES, + "Status code returned by writer_->WritePacket() in QuicConnection."); + + if (IsWriteBlockedStatus(result.status)) { + // Ensure the writer is still write blocked, otherwise QUIC may continue + // trying to write when it will not be able to. + QUICHE_DCHECK(writer_->IsWriteBlocked()); + visitor_->OnWriteBlocked(); + // If the socket buffers the data, then the packet should not + // be queued and sent again, which would result in an unnecessary + // duplicate packet being sent. The helper must call OnCanWrite + // when the write completes, and OnWriteError if an error occurs. + if (result.status != WRITE_STATUS_BLOCKED_DATA_BUFFERED) { + QUIC_DVLOG(1) << ENDPOINT << "Adding packet: " << packet->packet_number + << " to buffered packets"; + buffered_packets_.emplace_back(*packet, send_from_address, + send_to_address); + } + } + + // In some cases, an MTU probe can cause EMSGSIZE. This indicates that the + // MTU discovery is permanently unsuccessful. + if (IsMsgTooBig(writer_, result)) { + if (is_mtu_discovery) { + // When MSG_TOO_BIG is returned, the system typically knows what the + // actual MTU is, so there is no need to probe further. + // TODO(wub): Reduce max packet size to a safe default, or the actual MTU. + QUIC_DVLOG(1) << ENDPOINT + << " MTU probe packet too big, size:" << encrypted_length + << ", long_term_mtu_:" << long_term_mtu_; + mtu_discoverer_.Disable(); + mtu_discovery_alarm_->Cancel(); + // The write failed, but the writer is not blocked, so return true. + return true; + } + if (!send_on_current_path) { + // Only handle MSG_TOO_BIG as error on current path. + return true; + } + } + + if (IsWriteError(result.status)) { + QUIC_LOG_FIRST_N(ERROR, 10) + << ENDPOINT << "Failed writing packet " << packet_number << " of " + << encrypted_length << " bytes from " << send_from_address.host() + << " to " << send_to_address << ", with error code " + << result.error_code << ". long_term_mtu_:" << long_term_mtu_ + << ", previous_validated_mtu_:" << previous_validated_mtu_ + << ", max_packet_length():" << max_packet_length() + << ", is_mtu_discovery:" << is_mtu_discovery; + if (MaybeRevertToPreviousMtu()) { + return true; + } + + OnWriteError(result.error_code); + return false; + } + + if (result.status == WRITE_STATUS_OK) { + // packet_send_time is the ideal send time, if allow_burst is true, writer + // may have sent it earlier than that. + packet_send_time = packet_send_time + result.send_time_offset; + } + + if (IsRetransmittable(*packet) == HAS_RETRANSMITTABLE_DATA && + !is_termination_packet) { + // Start blackhole/path degrading detections if the sent packet is not + // termination packet and contains retransmittable data. + // Do not restart detection if detection is in progress indicating no + // forward progress has been made since last event (i.e., packet was sent + // or new packets were acknowledged). + if (!blackhole_detector_.IsDetectionInProgress()) { + // Try to start detections if no detection in progress. This could + // because either both detections are inactive when sending last packet + // or this connection just gets out of quiescence. + blackhole_detector_.RestartDetection(GetPathDegradingDeadline(), + GetNetworkBlackholeDeadline(), + GetPathMtuReductionDeadline()); + } + idle_network_detector_.OnPacketSent(packet_send_time, + sent_packet_manager_.GetPtoDelay()); + } + + MaybeSetMtuAlarm(packet_number); + QUIC_DVLOG(1) << ENDPOINT << "time we began writing last sent packet: " + << packet_send_time.ToDebuggingValue(); + + if (IsDefaultPath(default_path_.self_address, send_to_address)) { + if (EnforceAntiAmplificationLimit()) { + // Include bytes sent even if they are not in flight. + default_path_.bytes_sent_before_address_validation += encrypted_length; + } + } else { + MaybeUpdateBytesSentToAlternativeAddress(send_to_address, encrypted_length); + } + + // Do not measure rtt of this packet if it's not sent on current path. + QUIC_DLOG_IF(INFO, !send_on_current_path) + << ENDPOINT << " Sent packet " << packet->packet_number + << " on a different path with remote address " << send_to_address + << " while current path has peer address " << peer_address(); + const bool in_flight = sent_packet_manager_.OnPacketSent( + packet, packet_send_time, packet->transmission_type, + IsRetransmittable(*packet), /*measure_rtt=*/send_on_current_path, + ECN_NOT_ECT); + QUIC_BUG_IF(quic_bug_12714_25, + perspective_ == Perspective::IS_SERVER && + default_enable_5rto_blackhole_detection_ && + blackhole_detector_.IsDetectionInProgress() && + !sent_packet_manager_.HasInFlightPackets()) + << ENDPOINT + << "Trying to start blackhole detection without no bytes in flight"; + + if (debug_visitor_ != nullptr) { + if (sent_packet_manager_.unacked_packets().empty()) { + QUIC_BUG(quic_bug_10511_25) + << "Unacked map is empty right after packet is sent"; + } else { + debug_visitor_->OnPacketSent( + packet->packet_number, packet->encrypted_length, + packet->has_crypto_handshake, packet->transmission_type, + packet->encryption_level, + sent_packet_manager_.unacked_packets() + .rbegin() + ->retransmittable_frames, + packet->nonretransmittable_frames, packet_send_time); + } + } + if (packet->encryption_level == ENCRYPTION_HANDSHAKE) { + handshake_packet_sent_ = true; + } + + if (packet->encryption_level == ENCRYPTION_FORWARD_SECURE) { + if (!lowest_packet_sent_in_current_key_phase_.IsInitialized()) { + QUIC_DLOG(INFO) << ENDPOINT + << "lowest_packet_sent_in_current_key_phase_ = " + << packet_number; + lowest_packet_sent_in_current_key_phase_ = packet_number; + } + if (!is_termination_packet && + MaybeHandleAeadConfidentialityLimits(*packet)) { + return true; + } + } + if (in_flight || !retransmission_alarm_->IsSet()) { + SetRetransmissionAlarm(); + } + SetPingAlarm(); + RetirePeerIssuedConnectionIdsNoLongerOnPath(); + + // The packet number length must be updated after OnPacketSent, because it + // may change the packet number length in packet. + packet_creator_.UpdatePacketNumberLength( + sent_packet_manager_.GetLeastPacketAwaitedByPeer(encryption_level_), + sent_packet_manager_.EstimateMaxPacketsInFlight(max_packet_length())); + + stats_.bytes_sent += encrypted_length; + ++stats_.packets_sent; + if (packet->has_ack_ecn) { + stats_.num_ack_frames_sent_with_ecn++; + } + + QuicByteCount bytes_not_retransmitted = + packet->bytes_not_retransmitted.value_or(0); + if (packet->transmission_type != NOT_RETRANSMISSION) { + if (static_cast(encrypted_length) < bytes_not_retransmitted) { + QUIC_BUG(quic_packet_bytes_written_lt_bytes_not_retransmitted) + << "Total bytes written to the packet should be larger than the " + "bytes in not-retransmitted frames. Bytes written: " + << encrypted_length + << ", bytes not retransmitted: " << bytes_not_retransmitted; + } else { + // bytes_retransmitted includes packet's headers and encryption + // overhead. + stats_.bytes_retransmitted += + (encrypted_length - bytes_not_retransmitted); + } + ++stats_.packets_retransmitted; + } + + return true; +} + +bool QuicConnection::MaybeHandleAeadConfidentialityLimits( + const SerializedPacket& packet) { + if (!version().UsesTls()) { + return false; + } + + if (packet.encryption_level != ENCRYPTION_FORWARD_SECURE) { + QUIC_BUG(quic_bug_12714_26) + << "MaybeHandleAeadConfidentialityLimits called on non 1-RTT packet"; + return false; + } + if (!lowest_packet_sent_in_current_key_phase_.IsInitialized()) { + QUIC_BUG(quic_bug_10511_26) + << "lowest_packet_sent_in_current_key_phase_ must be initialized " + "before calling MaybeHandleAeadConfidentialityLimits"; + return false; + } + + // Calculate the number of packets encrypted from the packet number, which is + // simpler than keeping another counter. The packet number space may be + // sparse, so this might overcount, but doing a key update earlier than + // necessary would only improve security and has negligible cost. + if (packet.packet_number < lowest_packet_sent_in_current_key_phase_) { + const std::string error_details = + absl::StrCat("packet_number(", packet.packet_number.ToString(), + ") < lowest_packet_sent_in_current_key_phase_ (", + lowest_packet_sent_in_current_key_phase_.ToString(), ")"); + QUIC_BUG(quic_bug_10511_27) << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return true; + } + const QuicPacketCount num_packets_encrypted_in_current_key_phase = + packet.packet_number - lowest_packet_sent_in_current_key_phase_ + 1; + + const QuicPacketCount confidentiality_limit = + framer_.GetOneRttEncrypterConfidentialityLimit(); + + // Attempt to initiate a key update before reaching the AEAD + // confidentiality limit when the number of packets sent in the current + // key phase gets within |kKeyUpdateConfidentialityLimitOffset| packets of + // the limit, unless overridden by + // FLAGS_quic_key_update_confidentiality_limit. + constexpr QuicPacketCount kKeyUpdateConfidentialityLimitOffset = 1000; + QuicPacketCount key_update_limit = 0; + if (confidentiality_limit > kKeyUpdateConfidentialityLimitOffset) { + key_update_limit = + confidentiality_limit - kKeyUpdateConfidentialityLimitOffset; + } + const QuicPacketCount key_update_limit_override = + GetQuicFlag(quic_key_update_confidentiality_limit); + if (key_update_limit_override) { + key_update_limit = key_update_limit_override; + } + + QUIC_DVLOG(2) << ENDPOINT << "Checking AEAD confidentiality limits: " + << "num_packets_encrypted_in_current_key_phase=" + << num_packets_encrypted_in_current_key_phase + << " key_update_limit=" << key_update_limit + << " confidentiality_limit=" << confidentiality_limit + << " IsKeyUpdateAllowed()=" << IsKeyUpdateAllowed(); + + if (num_packets_encrypted_in_current_key_phase >= confidentiality_limit) { + // Reached the confidentiality limit without initiating a key update, + // must close the connection. + const std::string error_details = absl::StrCat( + "encrypter confidentiality limit reached: " + "num_packets_encrypted_in_current_key_phase=", + num_packets_encrypted_in_current_key_phase, + " key_update_limit=", key_update_limit, + " confidentiality_limit=", confidentiality_limit, + " IsKeyUpdateAllowed()=", IsKeyUpdateAllowed()); + CloseConnection(QUIC_AEAD_LIMIT_REACHED, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return true; + } + + if (IsKeyUpdateAllowed() && + num_packets_encrypted_in_current_key_phase >= key_update_limit) { + // Approaching the confidentiality limit, initiate key update so that + // the next set of keys will be ready for the next packet before the + // limit is reached. + KeyUpdateReason reason = KeyUpdateReason::kLocalAeadConfidentialityLimit; + if (key_update_limit_override) { + QUIC_DLOG(INFO) << ENDPOINT + << "reached FLAGS_quic_key_update_confidentiality_limit, " + "initiating key update: " + << "num_packets_encrypted_in_current_key_phase=" + << num_packets_encrypted_in_current_key_phase + << " key_update_limit=" << key_update_limit + << " confidentiality_limit=" << confidentiality_limit; + reason = KeyUpdateReason::kLocalKeyUpdateLimitOverride; + } else { + QUIC_DLOG(INFO) << ENDPOINT + << "approaching AEAD confidentiality limit, " + "initiating key update: " + << "num_packets_encrypted_in_current_key_phase=" + << num_packets_encrypted_in_current_key_phase + << " key_update_limit=" << key_update_limit + << " confidentiality_limit=" << confidentiality_limit; + } + InitiateKeyUpdate(reason); + } + + return false; +} + +void QuicConnection::FlushPackets() { + if (!connected_) { + return; + } + + if (!writer_->IsBatchMode()) { + return; + } + + if (HandleWriteBlocked()) { + QUIC_DLOG(INFO) << ENDPOINT << "FlushPackets called while blocked."; + return; + } + + WriteResult result = writer_->Flush(); + + QUIC_HISTOGRAM_ENUM("QuicConnection.FlushPacketStatus", result.status, + WRITE_STATUS_NUM_VALUES, + "Status code returned by writer_->Flush() in " + "QuicConnection::FlushPackets."); + + if (HandleWriteBlocked()) { + QUICHE_DCHECK_EQ(WRITE_STATUS_BLOCKED, result.status) + << "Unexpected flush result:" << result; + QUIC_DLOG(INFO) << ENDPOINT << "Write blocked in FlushPackets."; + return; + } + + if (IsWriteError(result.status) && !MaybeRevertToPreviousMtu()) { + OnWriteError(result.error_code); + } +} + +bool QuicConnection::IsMsgTooBig(const QuicPacketWriter* writer, + const WriteResult& result) { + absl::optional writer_error_code = writer->MessageTooBigErrorCode(); + return (result.status == WRITE_STATUS_MSG_TOO_BIG) || + (writer_error_code.has_value() && IsWriteError(result.status) && + result.error_code == *writer_error_code); +} + +bool QuicConnection::ShouldDiscardPacket(EncryptionLevel encryption_level) { + if (!connected_) { + QUIC_DLOG(INFO) << ENDPOINT + << "Not sending packet as connection is disconnected."; + return true; + } + + if (encryption_level_ == ENCRYPTION_FORWARD_SECURE && + encryption_level == ENCRYPTION_INITIAL) { + // Drop packets that are NULL encrypted since the peer won't accept them + // anymore. + QUIC_DLOG(INFO) << ENDPOINT + << "Dropping NULL encrypted packet since the connection is " + "forward secure."; + return true; + } + + return false; +} + +QuicTime QuicConnection::GetPathMtuReductionDeadline() const { + if (previous_validated_mtu_ == 0) { + return QuicTime::Zero(); + } + QuicTime::Delta delay = sent_packet_manager_.GetMtuReductionDelay( + num_rtos_for_blackhole_detection_); + if (delay.IsZero()) { + return QuicTime::Zero(); + } + return clock_->ApproximateNow() + delay; +} + +bool QuicConnection::MaybeRevertToPreviousMtu() { + if (previous_validated_mtu_ == 0) { + return false; + } + + SetMaxPacketLength(previous_validated_mtu_); + mtu_discoverer_.Disable(); + mtu_discovery_alarm_->Cancel(); + previous_validated_mtu_ = 0; + return true; +} + +void QuicConnection::OnWriteError(int error_code) { + if (write_error_occurred_) { + // A write error already occurred. The connection is being closed. + return; + } + write_error_occurred_ = true; + + const std::string error_details = absl::StrCat( + "Write failed with error: ", error_code, " (", strerror(error_code), ")"); + QUIC_LOG_FIRST_N(ERROR, 2) << ENDPOINT << error_details; + absl::optional writer_error_code = writer_->MessageTooBigErrorCode(); + if (writer_error_code.has_value() && error_code == *writer_error_code) { + CloseConnection(QUIC_PACKET_WRITE_ERROR, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + // We can't send an error as the socket is presumably borked. + if (version().HasIetfInvariantHeader()) { + QUIC_CODE_COUNT(quic_tear_down_local_connection_on_write_error_ietf); + } else { + QUIC_CODE_COUNT(quic_tear_down_local_connection_on_write_error_non_ietf); + } + CloseConnection(QUIC_PACKET_WRITE_ERROR, error_details, + ConnectionCloseBehavior::SILENT_CLOSE); +} + +QuicPacketBuffer QuicConnection::GetPacketBuffer() { + if (version().CanSendCoalescedPackets() && !coalescing_done_) { + // Do not use writer's packet buffer for coalesced packets which may + // contain multiple QUIC packets. + return {nullptr, nullptr}; + } + return writer_->GetNextWriteLocation(self_address().host(), peer_address()); +} + +void QuicConnection::OnSerializedPacket(SerializedPacket serialized_packet) { + if (serialized_packet.encrypted_buffer == nullptr) { + // We failed to serialize the packet, so close the connection. + // Specify that the close is silent, that no packet be sent, so no infinite + // loop here. + // TODO(ianswett): This is actually an internal error, not an + // encryption failure. + if (version().HasIetfInvariantHeader()) { + QUIC_CODE_COUNT( + quic_tear_down_local_connection_on_serialized_packet_ietf); + } else { + QUIC_CODE_COUNT( + quic_tear_down_local_connection_on_serialized_packet_non_ietf); + } + CloseConnection(QUIC_ENCRYPTION_FAILURE, + "Serialized packet does not have an encrypted buffer.", + ConnectionCloseBehavior::SILENT_CLOSE); + return; + } + + if (serialized_packet.retransmittable_frames.empty()) { + // Increment consecutive_num_packets_with_no_retransmittable_frames_ if + // this packet is a new transmission with no retransmittable frames. + ++consecutive_num_packets_with_no_retransmittable_frames_; + } else { + consecutive_num_packets_with_no_retransmittable_frames_ = 0; + } + if (retransmittable_on_wire_behavior_ == SEND_FIRST_FORWARD_SECURE_PACKET && + first_serialized_one_rtt_packet_ == nullptr && + serialized_packet.encryption_level == ENCRYPTION_FORWARD_SECURE) { + first_serialized_one_rtt_packet_ = std::make_unique( + serialized_packet, self_address(), peer_address()); + } + SendOrQueuePacket(std::move(serialized_packet)); +} + +void QuicConnection::OnUnrecoverableError(QuicErrorCode error, + const std::string& error_details) { + // The packet creator or generator encountered an unrecoverable error: tear + // down local connection state immediately. + if (version().HasIetfInvariantHeader()) { + QUIC_CODE_COUNT( + quic_tear_down_local_connection_on_unrecoverable_error_ietf); + } else { + QUIC_CODE_COUNT( + quic_tear_down_local_connection_on_unrecoverable_error_non_ietf); + } + CloseConnection(error, error_details, ConnectionCloseBehavior::SILENT_CLOSE); +} + +void QuicConnection::OnCongestionChange() { + visitor_->OnCongestionWindowChange(clock_->ApproximateNow()); + + // Uses the connection's smoothed RTT. If zero, uses initial_rtt. + QuicTime::Delta rtt = sent_packet_manager_.GetRttStats()->smoothed_rtt(); + if (rtt.IsZero()) { + rtt = sent_packet_manager_.GetRttStats()->initial_rtt(); + } + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnRttChanged(rtt); + } +} + +void QuicConnection::OnPathMtuIncreased(QuicPacketLength packet_size) { + if (packet_size > max_packet_length()) { + previous_validated_mtu_ = max_packet_length(); + SetMaxPacketLength(packet_size); + mtu_discoverer_.OnMaxPacketLengthUpdated(previous_validated_mtu_, + max_packet_length()); + } +} + +std::unique_ptr +QuicConnection::MakeSelfIssuedConnectionIdManager() { + QUICHE_DCHECK((perspective_ == Perspective::IS_CLIENT && + !default_path_.client_connection_id.IsEmpty()) || + (perspective_ == Perspective::IS_SERVER && + !default_path_.server_connection_id.IsEmpty())); + return std::make_unique( + kMinNumOfActiveConnectionIds, + perspective_ == Perspective::IS_CLIENT + ? default_path_.client_connection_id + : default_path_.server_connection_id, + clock_, alarm_factory_, this, context(), connection_id_generator_); +} + +void QuicConnection::MaybeSendConnectionIdToClient() { + if (perspective_ == Perspective::IS_CLIENT) { + return; + } + QUICHE_DCHECK(self_issued_cid_manager_ != nullptr); + self_issued_cid_manager_->MaybeSendNewConnectionIds(); +} + +void QuicConnection::OnHandshakeComplete() { + sent_packet_manager_.SetHandshakeConfirmed(); + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_SERVER && + self_issued_cid_manager_ != nullptr) { + self_issued_cid_manager_->MaybeSendNewConnectionIds(); + } + if (send_ack_frequency_on_handshake_completion_ && + sent_packet_manager_.CanSendAckFrequency()) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_can_send_ack_frequency, 2, 3); + auto ack_frequency_frame = + sent_packet_manager_.GetUpdatedAckFrequencyFrame(); + // This AckFrequencyFrame is meant to only update the max_ack_delay. Set + // packet tolerance to the default value for now. + ack_frequency_frame.packet_tolerance = + kDefaultRetransmittablePacketsBeforeAck; + visitor_->SendAckFrequency(ack_frequency_frame); + if (!connected_) { + return; + } + } + // This may have changed the retransmission timer, so re-arm it. + SetRetransmissionAlarm(); + if (default_enable_5rto_blackhole_detection_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_default_enable_5rto_blackhole_detection2, + 2, 3); + OnForwardProgressMade(); + } + if (!SupportsMultiplePacketNumberSpaces()) { + // The client should immediately ack the SHLO to confirm the handshake is + // complete with the server. + if (perspective_ == Perspective::IS_CLIENT && ack_frame_updated()) { + ack_alarm_->Update(clock_->ApproximateNow(), QuicTime::Delta::Zero()); + } + return; + } + // Stop sending ack of handshake packet number space. + uber_received_packet_manager_.ResetAckStates(ENCRYPTION_HANDSHAKE); + // Re-arm ack alarm. + ack_alarm_->Update(uber_received_packet_manager_.GetEarliestAckTimeout(), + kAlarmGranularity); + if (!accelerated_server_preferred_address_ && + received_server_preferred_address_.IsInitialized()) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + visitor_->OnServerPreferredAddressAvailable( + received_server_preferred_address_); + } +} + +void QuicConnection::MaybeCreateMultiPortPath() { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + if (path_validator_.HasPendingPathValidation()) { + QUIC_CLIENT_HISTOGRAM_ENUM("QuicConnection.MultiPortPathCreationCancelled", + path_validator_.GetPathValidationReason(), + PathValidationReason::kMaxValue, + "Reason for cancelled multi port path creation"); + return; + } + if (multi_port_stats_->num_multi_port_paths_created >= + kMaxNumMultiPortPaths) { + return; + } + auto path_context = visitor_->CreateContextForMultiPortPath(); + if (!path_context) { + return; + } + auto multi_port_validation_result_delegate = + std::make_unique(this); + multi_port_probing_alarm_->Cancel(); + multi_port_path_context_ = nullptr; + multi_port_stats_->num_multi_port_paths_created++; + ValidatePath(std::move(path_context), + std::move(multi_port_validation_result_delegate), + PathValidationReason::kMultiPort); +} + +void QuicConnection::SendOrQueuePacket(SerializedPacket packet) { + // The caller of this function is responsible for checking CanWrite(). + WritePacket(&packet); +} + +void QuicConnection::SendAck() { + QUICHE_DCHECK(!SupportsMultiplePacketNumberSpaces()); + QUIC_DVLOG(1) << ENDPOINT << "Sending an ACK proactively"; + QuicFrames frames; + frames.push_back(GetUpdatedAckFrame()); + if (!no_stop_waiting_frames_) { + QuicStopWaitingFrame stop_waiting; + PopulateStopWaitingFrame(&stop_waiting); + frames.push_back(QuicFrame(stop_waiting)); + } + if (!packet_creator_.FlushAckFrame(frames)) { + return; + } + ResetAckStates(); + if (!ShouldBundleRetransmittableFrameWithAck()) { + return; + } + consecutive_num_packets_with_no_retransmittable_frames_ = 0; + if (packet_creator_.HasPendingRetransmittableFrames() || + visitor_->WillingAndAbleToWrite()) { + // There are pending retransmittable frames. + return; + } + + visitor_->OnAckNeedsRetransmittableFrame(); +} + +EncryptionLevel QuicConnection::GetEncryptionLevelToSendPingForSpace( + PacketNumberSpace space) const { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return framer_.GetEncryptionLevelToSendApplicationData(); + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } +} + +bool QuicConnection::IsKnownServerAddress( + const QuicSocketAddress& address) const { + QUICHE_DCHECK(address.IsInitialized()); + return std::find(known_server_addresses_.cbegin(), + known_server_addresses_.cend(), + address) != known_server_addresses_.cend(); +} + +void QuicConnection::ClearEcnCodepoint() { + if (per_packet_options_ != nullptr) { + per_packet_options_->ecn_codepoint = ECN_NOT_ECT; + } +} + +WriteResult QuicConnection::SendPacketToWriter( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + if (!disable_ecn_codepoint_validation_) { + switch (GetNextEcnCodepoint()) { + case ECN_NOT_ECT: + break; + case ECN_ECT0: + if (!sent_packet_manager_.GetSendAlgorithm()->SupportsECT0()) { + ClearEcnCodepoint(); + } + break; + case ECN_ECT1: + if (!sent_packet_manager_.GetSendAlgorithm()->SupportsECT1()) { + ClearEcnCodepoint(); + } + break; + case ECN_CE: + ClearEcnCodepoint(); + break; + } + } + return writer_->WritePacket(buffer, buf_len, self_address, peer_address, + options); +} + +void QuicConnection::OnRetransmissionTimeout() { + ScopedRetransmissionTimeoutIndicator indicator(this); +#ifndef NDEBUG + if (sent_packet_manager_.unacked_packets().empty()) { + QUICHE_DCHECK(sent_packet_manager_.handshake_mode_disabled()); + QUICHE_DCHECK(!IsHandshakeComplete()); + } +#endif + if (!connected_) { + return; + } + + QuicPacketNumber previous_created_packet_number = + packet_creator_.packet_number(); + const auto retransmission_mode = + sent_packet_manager_.OnRetransmissionTimeout(); + if (retransmission_mode == QuicSentPacketManager::PTO_MODE) { + // Skip a packet number when PTO fires to elicit an immediate ACK. + const QuicPacketCount num_packet_numbers_to_skip = 1; + packet_creator_.SkipNPacketNumbers( + num_packet_numbers_to_skip, + sent_packet_manager_.GetLeastPacketAwaitedByPeer(encryption_level_), + sent_packet_manager_.EstimateMaxPacketsInFlight(max_packet_length())); + previous_created_packet_number += num_packet_numbers_to_skip; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnNPacketNumbersSkipped(num_packet_numbers_to_skip, + clock_->Now()); + } + } + if (default_enable_5rto_blackhole_detection_ && + !sent_packet_manager_.HasInFlightPackets() && + blackhole_detector_.IsDetectionInProgress()) { + // Stop detection in quiescence. + QUICHE_DCHECK_EQ(QuicSentPacketManager::LOSS_MODE, retransmission_mode); + blackhole_detector_.StopDetection(/*permanent=*/false); + } + WriteIfNotBlocked(); + + // A write failure can result in the connection being closed, don't attempt to + // write further packets, or to set alarms. + if (!connected_) { + return; + } + // When PTO fires, the SentPacketManager gives the connection the opportunity + // to send new data before retransmitting. + sent_packet_manager_.MaybeSendProbePacket(); + + if (packet_creator_.packet_number() == previous_created_packet_number && + retransmission_mode == QuicSentPacketManager::PTO_MODE && + !visitor_->WillingAndAbleToWrite()) { + // Send PING if timer fires in PTO mode but there is no data to send. + QUIC_DLOG(INFO) << ENDPOINT + << "No packet gets sent when timer fires in mode " + << retransmission_mode << ", send PING"; + QUICHE_DCHECK_LT(0u, + sent_packet_manager_.pending_timer_transmission_count()); + if (SupportsMultiplePacketNumberSpaces()) { + // Based on https://datatracker.ietf.org/doc/html/rfc9002#appendix-A.9 + PacketNumberSpace packet_number_space; + if (sent_packet_manager_ + .GetEarliestPacketSentTimeForPto(&packet_number_space) + .IsInitialized()) { + SendPingAtLevel( + GetEncryptionLevelToSendPingForSpace(packet_number_space)); + } else { + // The client must PTO when there is nothing in flight if the server + // could be blocked from sending by the amplification limit + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + if (framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_HANDSHAKE)) { + SendPingAtLevel(ENCRYPTION_HANDSHAKE); + } else if (framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL)) { + SendPingAtLevel(ENCRYPTION_INITIAL); + } else { + QUIC_BUG(quic_bug_no_pto) << "PTO fired but nothing was sent."; + } + } + } else { + SendPingAtLevel(encryption_level_); + } + } + if (retransmission_mode == QuicSentPacketManager::PTO_MODE) { + // When timer fires in PTO mode, ensure 1) at least one packet is created, + // or there is data to send and available credit (such that packets will be + // sent eventually). + QUIC_BUG_IF( + quic_bug_12714_27, + packet_creator_.packet_number() == previous_created_packet_number && + (!visitor_->WillingAndAbleToWrite() || + sent_packet_manager_.pending_timer_transmission_count() == 0u)) + << "retransmission_mode: " << retransmission_mode + << ", packet_number: " << packet_creator_.packet_number() + << ", session has data to write: " << visitor_->WillingAndAbleToWrite() + << ", writer is blocked: " << writer_->IsWriteBlocked() + << ", pending_timer_transmission_count: " + << sent_packet_manager_.pending_timer_transmission_count(); + } + + // Ensure the retransmission alarm is always set if there are unacked packets + // and nothing waiting to be sent. + // This happens if the loss algorithm invokes a timer based loss, but the + // packet doesn't need to be retransmitted. + if (!HasQueuedData() && !retransmission_alarm_->IsSet()) { + SetRetransmissionAlarm(); + } +} + +void QuicConnection::SetEncrypter(EncryptionLevel level, + std::unique_ptr encrypter) { + packet_creator_.SetEncrypter(level, std::move(encrypter)); +} + +void QuicConnection::RemoveEncrypter(EncryptionLevel level) { + framer_.RemoveEncrypter(level); +} + +void QuicConnection::SetDiversificationNonce( + const DiversificationNonce& nonce) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, perspective_); + packet_creator_.SetDiversificationNonce(nonce); +} + +void QuicConnection::SetDefaultEncryptionLevel(EncryptionLevel level) { + QUIC_DVLOG(1) << ENDPOINT << "Setting default encryption level from " + << encryption_level_ << " to " << level; + const bool changing_level = level != encryption_level_; + if (changing_level && packet_creator_.HasPendingFrames()) { + // Flush all queued frames when encryption level changes. + ScopedPacketFlusher flusher(this); + packet_creator_.FlushCurrentPacket(); + } + encryption_level_ = level; + packet_creator_.set_encryption_level(level); + QUIC_BUG_IF(quic_bug_12714_28, !framer_.HasEncrypterOfEncryptionLevel(level)) + << ENDPOINT << "Trying to set encryption level to " + << EncryptionLevelToString(level) << " while the key is missing"; + + if (!changing_level) { + return; + } + // The least packet awaited by the peer depends on the encryption level so + // we recalculate it here. + packet_creator_.UpdatePacketNumberLength( + sent_packet_manager_.GetLeastPacketAwaitedByPeer(encryption_level_), + sent_packet_manager_.EstimateMaxPacketsInFlight(max_packet_length())); +} + +void QuicConnection::SetDecrypter(EncryptionLevel level, + std::unique_ptr decrypter) { + framer_.SetDecrypter(level, std::move(decrypter)); + + if (!undecryptable_packets_.empty() && + !process_undecryptable_packets_alarm_->IsSet()) { + process_undecryptable_packets_alarm_->Set(clock_->ApproximateNow()); + } +} + +void QuicConnection::SetAlternativeDecrypter( + EncryptionLevel level, std::unique_ptr decrypter, + bool latch_once_used) { + framer_.SetAlternativeDecrypter(level, std::move(decrypter), latch_once_used); + + if (!undecryptable_packets_.empty() && + !process_undecryptable_packets_alarm_->IsSet()) { + process_undecryptable_packets_alarm_->Set(clock_->ApproximateNow()); + } +} + +void QuicConnection::InstallDecrypter( + EncryptionLevel level, std::unique_ptr decrypter) { + if (level == ENCRYPTION_ZERO_RTT) { + had_zero_rtt_decrypter_ = true; + } + framer_.InstallDecrypter(level, std::move(decrypter)); + if (!undecryptable_packets_.empty() && + !process_undecryptable_packets_alarm_->IsSet()) { + process_undecryptable_packets_alarm_->Set(clock_->ApproximateNow()); + } +} + +void QuicConnection::RemoveDecrypter(EncryptionLevel level) { + framer_.RemoveDecrypter(level); +} + +void QuicConnection::DiscardPreviousOneRttKeys() { + framer_.DiscardPreviousOneRttKeys(); +} + +bool QuicConnection::IsKeyUpdateAllowed() const { + return support_key_update_for_connection_ && + GetLargestAckedPacket().IsInitialized() && + lowest_packet_sent_in_current_key_phase_.IsInitialized() && + GetLargestAckedPacket() >= lowest_packet_sent_in_current_key_phase_; +} + +bool QuicConnection::HaveSentPacketsInCurrentKeyPhaseButNoneAcked() const { + return lowest_packet_sent_in_current_key_phase_.IsInitialized() && + (!GetLargestAckedPacket().IsInitialized() || + GetLargestAckedPacket() < lowest_packet_sent_in_current_key_phase_); +} + +QuicPacketCount QuicConnection::PotentialPeerKeyUpdateAttemptCount() const { + return framer_.PotentialPeerKeyUpdateAttemptCount(); +} + +bool QuicConnection::InitiateKeyUpdate(KeyUpdateReason reason) { + QUIC_DLOG(INFO) << ENDPOINT << "InitiateKeyUpdate"; + if (!IsKeyUpdateAllowed()) { + QUIC_BUG(quic_bug_10511_28) << "key update not allowed"; + return false; + } + return framer_.DoKeyUpdate(reason); +} + +const QuicDecrypter* QuicConnection::decrypter() const { + return framer_.decrypter(); +} + +const QuicDecrypter* QuicConnection::alternative_decrypter() const { + return framer_.alternative_decrypter(); +} + +void QuicConnection::QueueUndecryptablePacket( + const QuicEncryptedPacket& packet, EncryptionLevel decryption_level) { + for (const auto& saved_packet : undecryptable_packets_) { + if (packet.data() == saved_packet.packet->data() && + packet.length() == saved_packet.packet->length()) { + QUIC_DVLOG(1) << ENDPOINT << "Not queueing known undecryptable packet"; + return; + } + } + QUIC_DVLOG(1) << ENDPOINT << "Queueing undecryptable packet."; + undecryptable_packets_.emplace_back(packet, decryption_level, + last_received_packet_info_); + if (perspective_ == Perspective::IS_CLIENT) { + SetRetransmissionAlarm(); + } +} + +void QuicConnection::MaybeProcessUndecryptablePackets() { + process_undecryptable_packets_alarm_->Cancel(); + + if (undecryptable_packets_.empty() || + encryption_level_ == ENCRYPTION_INITIAL) { + return; + } + + auto iter = undecryptable_packets_.begin(); + while (connected_ && iter != undecryptable_packets_.end()) { + // Making sure there is no pending frames when processing next undecrypted + // packet because the queued ack frame may change. + packet_creator_.FlushCurrentPacket(); + if (!connected_) { + return; + } + UndecryptablePacket* undecryptable_packet = &*iter; + QUIC_DVLOG(1) << ENDPOINT << "Attempting to process undecryptable packet"; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnAttemptingToProcessUndecryptablePacket( + undecryptable_packet->encryption_level); + } + last_received_packet_info_ = undecryptable_packet->packet_info; + current_packet_data_ = undecryptable_packet->packet->data(); + const bool processed = framer_.ProcessPacket(*undecryptable_packet->packet); + current_packet_data_ = nullptr; + + if (processed) { + QUIC_DVLOG(1) << ENDPOINT << "Processed undecryptable packet!"; + iter = undecryptable_packets_.erase(iter); + ++stats_.packets_processed; + continue; + } + const bool has_decryption_key = version().KnowsWhichDecrypterToUse() && + framer_.HasDecrypterOfEncryptionLevel( + undecryptable_packet->encryption_level); + if (framer_.error() == QUIC_DECRYPTION_FAILURE && + ShouldEnqueueUnDecryptablePacket(undecryptable_packet->encryption_level, + has_decryption_key)) { + QUIC_DVLOG(1) + << ENDPOINT + << "Need to attempt to process this undecryptable packet later"; + ++iter; + continue; + } + iter = undecryptable_packets_.erase(iter); + } + + // Once handshake is complete, there will be no new keys installed and hence + // any undecryptable packets will never be able to be decrypted. + if (IsHandshakeComplete()) { + if (debug_visitor_ != nullptr) { + for (const auto& undecryptable_packet : undecryptable_packets_) { + debug_visitor_->OnUndecryptablePacket( + undecryptable_packet.encryption_level, /*dropped=*/true); + } + } + undecryptable_packets_.clear(); + } + if (perspective_ == Perspective::IS_CLIENT) { + SetRetransmissionAlarm(); + } +} + +void QuicConnection::QueueCoalescedPacket(const QuicEncryptedPacket& packet) { + QUIC_DVLOG(1) << ENDPOINT << "Queueing coalesced packet."; + received_coalesced_packets_.push_back(packet.Clone()); + ++stats_.num_coalesced_packets_received; +} + +bool QuicConnection::MaybeProcessCoalescedPackets() { + bool processed = false; + while (connected_ && !received_coalesced_packets_.empty()) { + // Making sure there are no pending frames when processing the next + // coalesced packet because the queued ack frame may change. + packet_creator_.FlushCurrentPacket(); + if (!connected_) { + return processed; + } + + std::unique_ptr packet = + std::move(received_coalesced_packets_.front()); + received_coalesced_packets_.pop_front(); + + QUIC_DVLOG(1) << ENDPOINT << "Processing coalesced packet"; + if (framer_.ProcessPacket(*packet)) { + processed = true; + ++stats_.num_coalesced_packets_processed; + } else { + // If we are unable to decrypt this packet, it might be + // because the CHLO or SHLO packet was lost. + } + } + if (processed) { + MaybeProcessUndecryptablePackets(); + MaybeSendInResponseToPacket(); + } + return processed; +} + +void QuicConnection::CloseConnection( + QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) { + CloseConnection(error, NO_IETF_QUIC_ERROR, details, + connection_close_behavior); +} + +void QuicConnection::CloseConnection( + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& error_details, + ConnectionCloseBehavior connection_close_behavior) { + QUICHE_DCHECK(!error_details.empty()); + if (!connected_) { + QUIC_DLOG(INFO) << "Connection is already closed."; + return; + } + + if (ietf_error != NO_IETF_QUIC_ERROR) { + QUIC_DLOG(INFO) << ENDPOINT << "Closing connection: " << connection_id() + << ", with wire error: " << ietf_error + << ", error: " << QuicErrorCodeToString(error) + << ", and details: " << error_details; + } else { + QUIC_DLOG(INFO) << ENDPOINT << "Closing connection: " << connection_id() + << ", with error: " << QuicErrorCodeToString(error) << " (" + << error << "), and details: " << error_details; + } + + if (connection_close_behavior != ConnectionCloseBehavior::SILENT_CLOSE) { + SendConnectionClosePacket(error, ietf_error, error_details); + } + + TearDownLocalConnectionState(error, ietf_error, error_details, + ConnectionCloseSource::FROM_SELF); +} + +void QuicConnection::SendConnectionClosePacket( + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details) { + // Always use the current path to send CONNECTION_CLOSE. + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, peer_address(), default_path_.client_connection_id, + default_path_.server_connection_id, connection_migration_use_new_cid_); + if (!SupportsMultiplePacketNumberSpaces()) { + QUIC_DLOG(INFO) << ENDPOINT << "Sending connection close packet."; + ScopedEncryptionLevelContext context(this, + GetConnectionCloseEncryptionLevel()); + if (version().CanSendCoalescedPackets()) { + coalesced_packet_.Clear(); + } + ClearQueuedPackets(); + // If there was a packet write error, write the smallest close possible. + ScopedPacketFlusher flusher(this); + // Always bundle an ACK with connection close for debugging purpose. + if (error != QUIC_PACKET_WRITE_ERROR && + !uber_received_packet_manager_.IsAckFrameEmpty( + QuicUtils::GetPacketNumberSpace(encryption_level_)) && + !packet_creator_.has_ack()) { + SendAck(); + } + QuicConnectionCloseFrame* frame; + + frame = new QuicConnectionCloseFrame(transport_version(), error, ietf_error, + details, + framer_.current_received_frame_type()); + packet_creator_.ConsumeRetransmittableControlFrame(QuicFrame(frame)); + packet_creator_.FlushCurrentPacket(); + if (version().CanSendCoalescedPackets()) { + FlushCoalescedPacket(); + } + ClearQueuedPackets(); + return; + } + ScopedPacketFlusher flusher(this); + + // Now that the connection is being closed, discard any unsent packets + // so the only packets to be sent will be connection close packets. + if (version().CanSendCoalescedPackets()) { + coalesced_packet_.Clear(); + } + ClearQueuedPackets(); + + for (EncryptionLevel level : + {ENCRYPTION_INITIAL, ENCRYPTION_HANDSHAKE, ENCRYPTION_ZERO_RTT, + ENCRYPTION_FORWARD_SECURE}) { + if (!framer_.HasEncrypterOfEncryptionLevel(level)) { + continue; + } + QUIC_DLOG(INFO) << ENDPOINT + << "Sending connection close packet at level: " << level; + ScopedEncryptionLevelContext context(this, level); + // Bundle an ACK of the corresponding packet number space for debugging + // purpose. + if (error != QUIC_PACKET_WRITE_ERROR && + !uber_received_packet_manager_.IsAckFrameEmpty( + QuicUtils::GetPacketNumberSpace(encryption_level_)) && + !packet_creator_.has_ack()) { + QuicFrames frames; + frames.push_back(GetUpdatedAckFrame()); + packet_creator_.FlushAckFrame(frames); + } + + if (level == ENCRYPTION_FORWARD_SECURE && + perspective_ == Perspective::IS_SERVER) { + visitor_->BeforeConnectionCloseSent(); + } + + auto* frame = new QuicConnectionCloseFrame( + transport_version(), error, ietf_error, details, + framer_.current_received_frame_type()); + packet_creator_.ConsumeRetransmittableControlFrame(QuicFrame(frame)); + packet_creator_.FlushCurrentPacket(); + } + if (version().CanSendCoalescedPackets()) { + FlushCoalescedPacket(); + } + // Since the connection is closing, if the connection close packets were not + // sent, then they should be discarded. + ClearQueuedPackets(); +} + +void QuicConnection::TearDownLocalConnectionState( + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& error_details, ConnectionCloseSource source) { + QuicConnectionCloseFrame frame(transport_version(), error, ietf_error, + error_details, + framer_.current_received_frame_type()); + return TearDownLocalConnectionState(frame, source); +} + +void QuicConnection::TearDownLocalConnectionState( + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { + if (!connected_) { + QUIC_DLOG(INFO) << "Connection is already closed."; + return; + } + + // If we are using a batch writer, flush packets queued in it, if any. + FlushPackets(); + connected_ = false; + QUICHE_DCHECK(visitor_ != nullptr); + visitor_->OnConnectionClosed(frame, source); + // LossDetectionTunerInterface::Finish() may be called from + // sent_packet_manager_.OnConnectionClosed. Which may require the session to + // finish its business first. + sent_packet_manager_.OnConnectionClosed(); + if (debug_visitor_ != nullptr) { + debug_visitor_->OnConnectionClosed(frame, source); + } + // Cancel the alarms so they don't trigger any action now that the + // connection is closed. + CancelAllAlarms(); + CancelPathValidation(); + + peer_issued_cid_manager_.reset(); + self_issued_cid_manager_.reset(); +} + +void QuicConnection::CancelAllAlarms() { + QUIC_DVLOG(1) << "Cancelling all QuicConnection alarms."; + + ack_alarm_->PermanentCancel(); + ping_manager_.Stop(); + retransmission_alarm_->PermanentCancel(); + send_alarm_->PermanentCancel(); + mtu_discovery_alarm_->PermanentCancel(); + process_undecryptable_packets_alarm_->PermanentCancel(); + discard_previous_one_rtt_keys_alarm_->PermanentCancel(); + discard_zero_rtt_decryption_keys_alarm_->PermanentCancel(); + multi_port_probing_alarm_->PermanentCancel(); + blackhole_detector_.StopDetection(/*permanent=*/true); + idle_network_detector_.StopDetection(); +} + +QuicByteCount QuicConnection::max_packet_length() const { + return packet_creator_.max_packet_length(); +} + +void QuicConnection::SetMaxPacketLength(QuicByteCount length) { + long_term_mtu_ = length; + stats_.max_egress_mtu = std::max(stats_.max_egress_mtu, long_term_mtu_); + packet_creator_.SetMaxPacketLength(GetLimitedMaxPacketSize(length)); +} + +bool QuicConnection::HasQueuedData() const { + return packet_creator_.HasPendingFrames() || !buffered_packets_.empty(); +} + +void QuicConnection::SetNetworkTimeouts(QuicTime::Delta handshake_timeout, + QuicTime::Delta idle_timeout) { + QUIC_BUG_IF(quic_bug_12714_29, idle_timeout > handshake_timeout) + << "idle_timeout:" << idle_timeout.ToMilliseconds() + << " handshake_timeout:" << handshake_timeout.ToMilliseconds(); + // Adjust the idle timeout on client and server to prevent clients from + // sending requests to servers which have already closed the connection. + if (perspective_ == Perspective::IS_SERVER) { + idle_timeout = idle_timeout + QuicTime::Delta::FromSeconds(3); + } else if (idle_timeout > QuicTime::Delta::FromSeconds(1)) { + idle_timeout = idle_timeout - QuicTime::Delta::FromSeconds(1); + } + idle_network_detector_.SetTimeouts(handshake_timeout, idle_timeout); +} + +void QuicConnection::SetPingAlarm() { + if (!connected_) { + return; + } + ping_manager_.SetAlarm(clock_->ApproximateNow(), + visitor_->ShouldKeepConnectionAlive(), + sent_packet_manager_.HasInFlightPackets()); +} + +void QuicConnection::SetRetransmissionAlarm() { + if (!connected_) { + if (retransmission_alarm_->IsSet()) { + QUIC_BUG(quic_bug_10511_29) + << ENDPOINT << "Retransmission alarm is set while disconnected"; + retransmission_alarm_->Cancel(); + } + return; + } + if (packet_creator_.PacketFlusherAttached()) { + pending_retransmission_alarm_ = true; + return; + } + if (LimitedByAmplificationFactor(packet_creator_.max_packet_length())) { + // Do not set retransmission timer if connection is anti-amplification limit + // throttled. Otherwise, nothing can be sent when timer fires. + retransmission_alarm_->Cancel(); + return; + } + PacketNumberSpace packet_number_space; + if (SupportsMultiplePacketNumberSpaces() && !IsHandshakeConfirmed() && + !sent_packet_manager_ + .GetEarliestPacketSentTimeForPto(&packet_number_space) + .IsInitialized()) { + // Before handshake gets confirmed, GetEarliestPacketSentTimeForPto + // returning 0 indicates no packets are in flight or only application data + // is in flight. + if (perspective_ == Perspective::IS_SERVER) { + // No need to arm PTO on server side. + retransmission_alarm_->Cancel(); + return; + } + if (retransmission_alarm_->IsSet() && + GetRetransmissionDeadline() > retransmission_alarm_->deadline()) { + // Do not postpone armed PTO on the client side. + return; + } + } + + retransmission_alarm_->Update(GetRetransmissionDeadline(), kAlarmGranularity); +} + +void QuicConnection::MaybeSetMtuAlarm(QuicPacketNumber sent_packet_number) { + if (mtu_discovery_alarm_->IsSet() || + !mtu_discoverer_.ShouldProbeMtu(sent_packet_number)) { + return; + } + mtu_discovery_alarm_->Set(clock_->ApproximateNow()); +} + +QuicConnection::ScopedPacketFlusher::ScopedPacketFlusher( + QuicConnection* connection) + : connection_(connection), + flush_and_set_pending_retransmission_alarm_on_delete_(false), + handshake_packet_sent_(connection != nullptr && + connection->handshake_packet_sent_) { + if (connection_ == nullptr) { + return; + } + + if (!connection_->packet_creator_.PacketFlusherAttached()) { + flush_and_set_pending_retransmission_alarm_on_delete_ = true; + connection->packet_creator_.AttachPacketFlusher(); + } +} + +QuicConnection::ScopedPacketFlusher::~ScopedPacketFlusher() { + if (connection_ == nullptr || !connection_->connected()) { + return; + } + + if (flush_and_set_pending_retransmission_alarm_on_delete_) { + const QuicTime ack_timeout = + connection_->uber_received_packet_manager_.GetEarliestAckTimeout(); + if (ack_timeout.IsInitialized()) { + if (ack_timeout <= connection_->clock_->ApproximateNow() && + !connection_->CanWrite(NO_RETRANSMITTABLE_DATA)) { + // Cancel ACK alarm if connection is write blocked, and ACK will be + // sent when connection gets unblocked. + connection_->ack_alarm_->Cancel(); + } else if (!connection_->ack_alarm_->IsSet() || + connection_->ack_alarm_->deadline() > ack_timeout) { + connection_->ack_alarm_->Update(ack_timeout, QuicTime::Delta::Zero()); + } + } + if (connection_->ack_alarm_->IsSet() && + connection_->ack_alarm_->deadline() <= + connection_->clock_->ApproximateNow()) { + // An ACK needs to be sent right now. This ACK did not get bundled + // because either there was no data to write or packets were marked as + // received after frames were queued in the generator. + if (connection_->send_alarm_->IsSet() && + connection_->send_alarm_->deadline() <= + connection_->clock_->ApproximateNow()) { + // If send alarm will go off soon, let send alarm send the ACK. + connection_->ack_alarm_->Cancel(); + } else if (connection_->SupportsMultiplePacketNumberSpaces()) { + connection_->SendAllPendingAcks(); + } else { + connection_->SendAck(); + } + } + + // INITIAL or HANDSHAKE retransmission could cause peer to derive new + // keys, such that the buffered undecryptable packets may be processed. + // This endpoint would derive an inflated RTT sample when receiving ACKs + // of those undecryptable packets. To mitigate this, tries to coalesce as + // many higher space packets as possible (via for loop inside + // MaybeCoalescePacketOfHigherSpace) to fill the remaining space in the + // coalescer. + if (connection_->version().CanSendCoalescedPackets()) { + connection_->MaybeCoalescePacketOfHigherSpace(); + } + connection_->packet_creator_.Flush(); + if (connection_->version().CanSendCoalescedPackets()) { + connection_->FlushCoalescedPacket(); + } + connection_->FlushPackets(); + + if (!connection_->connected()) { + return; + } + + if (!handshake_packet_sent_ && connection_->handshake_packet_sent_) { + // This would cause INITIAL key to be dropped. Drop keys here to avoid + // missing the write keys in the middle of writing. + connection_->visitor_->OnHandshakePacketSent(); + } + // Reset transmission type. + connection_->SetTransmissionType(NOT_RETRANSMISSION); + + // Once all transmissions are done, check if there is any outstanding data + // to send and notify the congestion controller if not. + // + // Note that this means that the application limited check will happen as + // soon as the last flusher gets destroyed, which is typically after a + // single stream write is finished. This means that if all the data from a + // single write goes through the connection, the application-limited signal + // will fire even if the caller does a write operation immediately after. + // There are two important approaches to remedy this situation: + // (1) Instantiate ScopedPacketFlusher before performing multiple subsequent + // writes, thus deferring this check until all writes are done. + // (2) Write data in chunks sufficiently large so that they cause the + // connection to be limited by the congestion control. Typically, this + // would mean writing chunks larger than the product of the current + // pacing rate and the pacer granularity. So, for instance, if the + // pacing rate of the connection is 1 Gbps, and the pacer granularity is + // 1 ms, the caller should send at least 125k bytes in order to not + // be marked as application-limited. + connection_->CheckIfApplicationLimited(); + + if (connection_->pending_retransmission_alarm_) { + connection_->SetRetransmissionAlarm(); + connection_->pending_retransmission_alarm_ = false; + } + } + QUICHE_DCHECK_EQ(flush_and_set_pending_retransmission_alarm_on_delete_, + !connection_->packet_creator_.PacketFlusherAttached()); +} + +QuicConnection::ScopedEncryptionLevelContext::ScopedEncryptionLevelContext( + QuicConnection* connection, EncryptionLevel encryption_level) + : connection_(connection), latched_encryption_level_(ENCRYPTION_INITIAL) { + if (connection_ == nullptr) { + return; + } + latched_encryption_level_ = connection_->encryption_level_; + connection_->SetDefaultEncryptionLevel(encryption_level); +} + +QuicConnection::ScopedEncryptionLevelContext::~ScopedEncryptionLevelContext() { + if (connection_ == nullptr || !connection_->connected_) { + return; + } + connection_->SetDefaultEncryptionLevel(latched_encryption_level_); +} + +QuicConnection::BufferedPacket::BufferedPacket( + const SerializedPacket& packet, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) + : BufferedPacket(packet.encrypted_buffer, packet.encrypted_length, + self_address, peer_address) {} + +QuicConnection::BufferedPacket::BufferedPacket( + const char* encrypted_buffer, QuicPacketLength encrypted_length, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) + : length(encrypted_length), + self_address(self_address), + peer_address(peer_address) { + data = std::make_unique(encrypted_length); + memcpy(data.get(), encrypted_buffer, encrypted_length); +} + +QuicConnection::BufferedPacket::BufferedPacket( + QuicRandom& random, QuicPacketLength encrypted_length, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) + : length(encrypted_length), + self_address(self_address), + peer_address(peer_address) { + data = std::make_unique(encrypted_length); + random.RandBytes(data.get(), encrypted_length); +} + +QuicConnection::ReceivedPacketInfo::ReceivedPacketInfo(QuicTime receipt_time) + : receipt_time(receipt_time) {} +QuicConnection::ReceivedPacketInfo::ReceivedPacketInfo( + const QuicSocketAddress& destination_address, + const QuicSocketAddress& source_address, QuicTime receipt_time, + QuicByteCount length, QuicEcnCodepoint ecn_codepoint) + : destination_address(destination_address), + source_address(source_address), + receipt_time(receipt_time), + length(length), + ecn_codepoint(ecn_codepoint) {} + +std::ostream& operator<<(std::ostream& os, + const QuicConnection::ReceivedPacketInfo& info) { + os << " { destination_address: " << info.destination_address.ToString() + << ", source_address: " << info.source_address.ToString() + << ", received_bytes_counted: " << info.received_bytes_counted + << ", length: " << info.length + << ", destination_connection_id: " << info.destination_connection_id; + if (!info.decrypted) { + os << " }\n"; + return os; + } + os << ", decrypted: " << info.decrypted + << ", decrypted_level: " << EncryptionLevelToString(info.decrypted_level) + << ", header: " << info.header << ", frames: "; + for (const auto frame : info.frames) { + os << frame; + } + os << " }\n"; + return os; +} + +HasRetransmittableData QuicConnection::IsRetransmittable( + const SerializedPacket& packet) { + // Retransmitted packets retransmittable frames are owned by the unacked + // packet map, but are not present in the serialized packet. + if (packet.transmission_type != NOT_RETRANSMISSION || + !packet.retransmittable_frames.empty()) { + return HAS_RETRANSMITTABLE_DATA; + } else { + return NO_RETRANSMITTABLE_DATA; + } +} + +bool QuicConnection::IsTerminationPacket(const SerializedPacket& packet, + QuicErrorCode* error_code) { + if (packet.retransmittable_frames.empty()) { + return false; + } + for (const QuicFrame& frame : packet.retransmittable_frames) { + if (frame.type == CONNECTION_CLOSE_FRAME) { + *error_code = frame.connection_close_frame->quic_error_code; + return true; + } + } + return false; +} + +void QuicConnection::SetMtuDiscoveryTarget(QuicByteCount target) { + QUIC_DVLOG(2) << ENDPOINT << "SetMtuDiscoveryTarget: " << target; + mtu_discoverer_.Disable(); + mtu_discoverer_.Enable(max_packet_length(), GetLimitedMaxPacketSize(target)); +} + +QuicByteCount QuicConnection::GetLimitedMaxPacketSize( + QuicByteCount suggested_max_packet_size) { + if (!peer_address().IsInitialized()) { + QUIC_BUG(quic_bug_10511_30) + << "Attempted to use a connection without a valid peer address"; + return suggested_max_packet_size; + } + + const QuicByteCount writer_limit = writer_->GetMaxPacketSize(peer_address()); + + QuicByteCount max_packet_size = suggested_max_packet_size; + if (max_packet_size > writer_limit) { + max_packet_size = writer_limit; + } + if (max_packet_size > peer_max_packet_size_) { + max_packet_size = peer_max_packet_size_; + } + if (max_packet_size > kMaxOutgoingPacketSize) { + max_packet_size = kMaxOutgoingPacketSize; + } + return max_packet_size; +} + +void QuicConnection::SendMtuDiscoveryPacket(QuicByteCount target_mtu) { + // Currently, this limit is ensured by the caller. + QUICHE_DCHECK_EQ(target_mtu, GetLimitedMaxPacketSize(target_mtu)); + + // Send the probe. + packet_creator_.GenerateMtuDiscoveryPacket(target_mtu); +} + +// TODO(zhongyi): change this method to generate a connectivity probing packet +// and let the caller to call writer to write the packet and handle write +// status. +bool QuicConnection::SendConnectivityProbingPacket( + QuicPacketWriter* probing_writer, const QuicSocketAddress& peer_address) { + QUICHE_DCHECK(peer_address.IsInitialized()); + if (!connected_) { + QUIC_BUG(quic_bug_10511_31) + << "Not sending connectivity probing packet as connection is " + << "disconnected."; + return false; + } + if (perspective_ == Perspective::IS_SERVER && probing_writer == nullptr) { + // Server can use default packet writer to write packet. + probing_writer = writer_; + } + QUICHE_DCHECK(probing_writer); + + if (probing_writer->IsWriteBlocked()) { + QUIC_DLOG(INFO) + << ENDPOINT + << "Writer blocked when sending connectivity probing packet."; + if (probing_writer == writer_) { + // Visitor should not be write blocked if the probing writer is not the + // default packet writer. + visitor_->OnWriteBlocked(); + } + return true; + } + + QUIC_DLOG(INFO) << ENDPOINT + << "Sending path probe packet for connection_id = " + << default_path_.server_connection_id; + + std::unique_ptr probing_packet; + if (!version().HasIetfQuicFrames()) { + // Non-IETF QUIC, generate a padded ping regardless of whether this is a + // request or a response. + probing_packet = packet_creator_.SerializeConnectivityProbingPacket(); + } else { + // IETF QUIC path challenge. + // Send a path probe request using IETF QUIC PATH_CHALLENGE frame. + QuicPathFrameBuffer transmitted_connectivity_probe_payload; + random_generator_->RandBytes(&transmitted_connectivity_probe_payload, + sizeof(QuicPathFrameBuffer)); + probing_packet = + packet_creator_.SerializePathChallengeConnectivityProbingPacket( + transmitted_connectivity_probe_payload); + } + QUICHE_DCHECK_EQ(IsRetransmittable(*probing_packet), NO_RETRANSMITTABLE_DATA); + return WritePacketUsingWriter(std::move(probing_packet), probing_writer, + self_address(), peer_address, + /*measure_rtt=*/true); +} + +bool QuicConnection::WritePacketUsingWriter( + std::unique_ptr packet, QuicPacketWriter* writer, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, bool measure_rtt) { + const QuicTime packet_send_time = clock_->Now(); + QUIC_DVLOG(2) << ENDPOINT + << "Sending path probe packet for server connection ID " + << default_path_.server_connection_id << std::endl + << quiche::QuicheTextUtils::HexDump(absl::string_view( + packet->encrypted_buffer, packet->encrypted_length)); + WriteResult result = writer->WritePacket( + packet->encrypted_buffer, packet->encrypted_length, self_address.host(), + peer_address, per_packet_options_); + + // If using a batch writer and the probing packet is buffered, flush it. + if (writer->IsBatchMode() && result.status == WRITE_STATUS_OK && + result.bytes_written == 0) { + result = writer->Flush(); + } + + if (IsWriteError(result.status)) { + // Write error for any connectivity probe should not affect the connection + // as it is sent on a different path. + QUIC_DLOG(INFO) << ENDPOINT << "Write probing packet failed with error = " + << result.error_code; + return false; + } + + // Send in currrent path. Call OnPacketSent regardless of the write result. + sent_packet_manager_.OnPacketSent( + packet.get(), packet_send_time, packet->transmission_type, + NO_RETRANSMITTABLE_DATA, measure_rtt, ECN_NOT_ECT); + + if (debug_visitor_ != nullptr) { + if (sent_packet_manager_.unacked_packets().empty()) { + QUIC_BUG(quic_bug_10511_32) + << "Unacked map is empty right after packet is sent"; + } else { + debug_visitor_->OnPacketSent( + packet->packet_number, packet->encrypted_length, + packet->has_crypto_handshake, packet->transmission_type, + packet->encryption_level, + sent_packet_manager_.unacked_packets() + .rbegin() + ->retransmittable_frames, + packet->nonretransmittable_frames, packet_send_time); + } + } + + if (IsWriteBlockedStatus(result.status)) { + if (writer == writer_) { + // Visitor should not be write blocked if the probing writer is not the + // default packet writer. + visitor_->OnWriteBlocked(); + } + if (result.status == WRITE_STATUS_BLOCKED_DATA_BUFFERED) { + QUIC_DLOG(INFO) << ENDPOINT << "Write probing packet blocked"; + } + } + + return true; +} + +void QuicConnection::DisableMtuDiscovery() { + mtu_discoverer_.Disable(); + mtu_discovery_alarm_->Cancel(); +} + +void QuicConnection::DiscoverMtu() { + QUICHE_DCHECK(!mtu_discovery_alarm_->IsSet()); + + const QuicPacketNumber largest_sent_packet = + sent_packet_manager_.GetLargestSentPacket(); + if (mtu_discoverer_.ShouldProbeMtu(largest_sent_packet)) { + ++mtu_probe_count_; + SendMtuDiscoveryPacket( + mtu_discoverer_.GetUpdatedMtuProbeSize(largest_sent_packet)); + } + QUICHE_DCHECK(!mtu_discovery_alarm_->IsSet()); +} + +void QuicConnection::OnEffectivePeerMigrationValidated( + bool /*is_migration_linkable*/) { + if (active_effective_peer_migration_type_ == NO_CHANGE) { + QUIC_BUG(quic_bug_10511_33) << "No migration underway."; + return; + } + highest_packet_sent_before_effective_peer_migration_.Clear(); + const bool send_address_token = + active_effective_peer_migration_type_ != PORT_CHANGE; + active_effective_peer_migration_type_ = NO_CHANGE; + ++stats_.num_validated_peer_migration; + if (!validate_client_addresses_) { + return; + } + if (debug_visitor_ != nullptr) { + const QuicTime now = clock_->ApproximateNow(); + if (now >= stats_.handshake_completion_time) { + debug_visitor_->OnPeerMigrationValidated( + now - stats_.handshake_completion_time); + } else { + QUIC_BUG(quic_bug_10511_34) + << "Handshake completion time is larger than current time."; + } + } + + // Lift anti-amplification limit. + default_path_.validated = true; + alternative_path_.Clear(); + if (send_address_token) { + visitor_->MaybeSendAddressToken(); + } +} + +void QuicConnection::StartEffectivePeerMigration(AddressChangeType type) { + // TODO(fayang): Currently, all peer address change type are allowed. Need to + // add a method ShouldAllowPeerAddressChange(PeerAddressChangeType type) to + // determine whether |type| is allowed. + if (!validate_client_addresses_) { + if (type == NO_CHANGE) { + QUIC_BUG(quic_bug_10511_35) + << "EffectivePeerMigration started without address change."; + return; + } + QUIC_DLOG(INFO) + << ENDPOINT << "Effective peer's ip:port changed from " + << default_path_.peer_address.ToString() << " to " + << GetEffectivePeerAddressFromCurrentPacket().ToString() + << ", address change type is " << type + << ", migrating connection without validating new client address."; + + highest_packet_sent_before_effective_peer_migration_ = + sent_packet_manager_.GetLargestSentPacket(); + default_path_.peer_address = GetEffectivePeerAddressFromCurrentPacket(); + active_effective_peer_migration_type_ = type; + + OnConnectionMigration(); + return; + } + + if (type == NO_CHANGE) { + UpdatePeerAddress(last_received_packet_info_.source_address); + QUIC_BUG(quic_bug_10511_36) + << "EffectivePeerMigration started without address change."; + return; + } + if (GetQuicReloadableFlag( + quic_flush_pending_frames_and_padding_bytes_on_migration)) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_flush_pending_frames_and_padding_bytes_on_migration); + // There could be pending NEW_TOKEN_FRAME triggered by non-probing + // PATH_RESPONSE_FRAME in the same packet or pending padding bytes in the + // packet creator. + packet_creator_.FlushCurrentPacket(); + packet_creator_.SendRemainingPendingPadding(); + if (!connected_) { + return; + } + } else { + if (packet_creator_.HasPendingFrames()) { + packet_creator_.FlushCurrentPacket(); + if (!connected_) { + return; + } + } + } + + // Action items: + // 1. Switch congestion controller; + // 2. Update default_path_ (addresses, validation and bytes accounting); + // 3. Save previous default path if needed; + // 4. Kick off reverse path validation if needed. + // Items 1 and 2 are must-to-do. Items 3 and 4 depends on if the new address + // is validated or not and which path the incoming packet is on. + + const QuicSocketAddress current_effective_peer_address = + GetEffectivePeerAddressFromCurrentPacket(); + QUIC_DLOG(INFO) << ENDPOINT << "Effective peer's ip:port changed from " + << default_path_.peer_address.ToString() << " to " + << current_effective_peer_address.ToString() + << ", address change type is " << type + << ", migrating connection."; + + const QuicSocketAddress previous_direct_peer_address = direct_peer_address_; + PathState previous_default_path = std::move(default_path_); + active_effective_peer_migration_type_ = type; + MaybeClearQueuedPacketsOnPathChange(); + OnConnectionMigration(); + + // Update congestion controller if the address change type is not PORT_CHANGE. + if (type == PORT_CHANGE) { + QUICHE_DCHECK(previous_default_path.validated || + (alternative_path_.validated && + alternative_path_.send_algorithm != nullptr)); + // No need to store previous congestion controller because either the new + // default path is validated or the alternative path is validated and + // already has associated congestion controller. + } else { + previous_default_path.rtt_stats.emplace(); + previous_default_path.rtt_stats->CloneFrom( + *sent_packet_manager_.GetRttStats()); + // If the new peer address share the same IP with the alternative path, the + // connection should switch to the congestion controller of the alternative + // path. Otherwise, the connection should use a brand new one. + // In order to re-use existing code in sent_packet_manager_, reset + // congestion controller to initial state first and then change to the one + // on alternative path. + // TODO(danzh) combine these two steps into one after deprecating gQUIC. + previous_default_path.send_algorithm = OnPeerIpAddressChanged(); + + if (alternative_path_.peer_address.host() == + current_effective_peer_address.host() && + alternative_path_.send_algorithm != nullptr) { + // Update the default path with the congestion controller of the + // alternative path. + sent_packet_manager_.SetSendAlgorithm( + alternative_path_.send_algorithm.release()); + sent_packet_manager_.SetRttStats( + std::move(alternative_path_.rtt_stats).value()); + } + } + // Update to the new peer address. + UpdatePeerAddress(last_received_packet_info_.source_address); + // Update the default path. + if (IsAlternativePath(last_received_packet_info_.destination_address, + current_effective_peer_address)) { + SetDefaultPathState(std::move(alternative_path_)); + } else { + QuicConnectionId client_connection_id; + absl::optional stateless_reset_token; + FindMatchingOrNewClientConnectionIdOrToken( + previous_default_path, alternative_path_, + last_received_packet_info_.destination_connection_id, + &client_connection_id, &stateless_reset_token); + SetDefaultPathState( + PathState(last_received_packet_info_.destination_address, + current_effective_peer_address, client_connection_id, + last_received_packet_info_.destination_connection_id, + stateless_reset_token)); + // The path is considered validated if its peer IP address matches any + // validated path's peer IP address. + default_path_.validated = + (alternative_path_.peer_address.host() == + current_effective_peer_address.host() && + alternative_path_.validated) || + (previous_default_path.validated && type == PORT_CHANGE); + } + if (!last_received_packet_info_.received_bytes_counted) { + // Increment bytes counting on the new default path. + default_path_.bytes_received_before_address_validation += + last_received_packet_info_.length; + last_received_packet_info_.received_bytes_counted = true; + } + + if (!previous_default_path.validated) { + // If the old address is under validation, cancel and fail it. Failing to + // validate the old path shouldn't take any effect. + QUIC_DVLOG(1) << "Cancel validation of previous peer address change to " + << previous_default_path.peer_address + << " upon peer migration to " << default_path_.peer_address; + path_validator_.CancelPathValidation(); + ++stats_.num_peer_migration_while_validating_default_path; + } + + // Clear alternative path if the new default path shares the same IP as the + // alternative path. + if (alternative_path_.peer_address.host() == + default_path_.peer_address.host()) { + alternative_path_.Clear(); + } + + if (default_path_.validated) { + QUIC_DVLOG(1) << "Peer migrated to a validated address."; + // No need to save previous default path, validate new peer address or + // update bytes sent/received. + if (!(previous_default_path.validated && type == PORT_CHANGE)) { + // The alternative path was validated because of proactive reverse path + // validation. + ++stats_.num_peer_migration_to_proactively_validated_address; + } + OnEffectivePeerMigrationValidated( + default_path_.server_connection_id == + previous_default_path.server_connection_id); + return; + } + + // The new default address is not validated yet. Anti-amplification limit is + // enforced. + QUICHE_DCHECK(EnforceAntiAmplificationLimit()); + QUIC_DVLOG(1) << "Apply anti-amplification limit to effective peer address " + << default_path_.peer_address << " with " + << default_path_.bytes_sent_before_address_validation + << " bytes sent and " + << default_path_.bytes_received_before_address_validation + << " bytes received."; + + QUICHE_DCHECK(!alternative_path_.peer_address.IsInitialized() || + alternative_path_.peer_address.host() != + default_path_.peer_address.host()); + + // Save previous default path to the altenative path. + if (previous_default_path.validated) { + // The old path is a validated path which the connection might revert back + // to later. Store it as the alternative path. + alternative_path_ = std::move(previous_default_path); + QUICHE_DCHECK(alternative_path_.send_algorithm != nullptr); + } + + // If the new address is not validated and the connection is not already + // validating that address, a new reverse path validation is needed. + if (!path_validator_.IsValidatingPeerAddress( + current_effective_peer_address)) { + ++stats_.num_reverse_path_validtion_upon_migration; + ValidatePath(std::make_unique( + default_path_.self_address, peer_address(), + default_path_.peer_address, this), + std::make_unique( + this, previous_direct_peer_address), + PathValidationReason::kReversePathValidation); + } else { + QUIC_DVLOG(1) << "Peer address " << default_path_.peer_address + << " is already under validation, wait for result."; + ++stats_.num_peer_migration_to_proactively_validated_address; + } +} + +void QuicConnection::OnConnectionMigration() { + if (debug_visitor_ != nullptr) { + const QuicTime now = clock_->ApproximateNow(); + if (now >= stats_.handshake_completion_time) { + debug_visitor_->OnPeerAddressChange( + active_effective_peer_migration_type_, + now - stats_.handshake_completion_time); + } + } + visitor_->OnConnectionMigration(active_effective_peer_migration_type_); + if (active_effective_peer_migration_type_ != PORT_CHANGE && + active_effective_peer_migration_type_ != IPV4_SUBNET_CHANGE && + !validate_client_addresses_) { + sent_packet_manager_.OnConnectionMigration(/*reset_send_algorithm=*/false); + } +} + +bool QuicConnection::IsCurrentPacketConnectivityProbing() const { + return is_current_packet_connectivity_probing_; +} + +bool QuicConnection::ack_frame_updated() const { + return uber_received_packet_manager_.IsAckFrameUpdated(); +} + +absl::string_view QuicConnection::GetCurrentPacket() { + if (current_packet_data_ == nullptr) { + return absl::string_view(); + } + return absl::string_view(current_packet_data_, + last_received_packet_info_.length); +} + +bool QuicConnection::MaybeConsiderAsMemoryCorruption( + const QuicStreamFrame& frame) { + if (QuicUtils::IsCryptoStreamId(transport_version(), frame.stream_id) || + last_received_packet_info_.decrypted_level != ENCRYPTION_INITIAL) { + return false; + } + + if (perspective_ == Perspective::IS_SERVER && + frame.data_length >= sizeof(kCHLO) && + strncmp(frame.data_buffer, reinterpret_cast(&kCHLO), + sizeof(kCHLO)) == 0) { + return true; + } + + if (perspective_ == Perspective::IS_CLIENT && + frame.data_length >= sizeof(kREJ) && + strncmp(frame.data_buffer, reinterpret_cast(&kREJ), + sizeof(kREJ)) == 0) { + return true; + } + + return false; +} + +void QuicConnection::CheckIfApplicationLimited() { + if (!connected_) { + return; + } + + bool application_limited = + buffered_packets_.empty() && !visitor_->WillingAndAbleToWrite(); + + if (!application_limited) { + return; + } + + sent_packet_manager_.OnApplicationLimited(); +} + +bool QuicConnection::UpdatePacketContent(QuicFrameType type) { + last_received_packet_info_.frames.push_back(type); + if (version().HasIetfQuicFrames()) { + if (perspective_ == Perspective::IS_CLIENT) { + return connected_; + } + if (!QuicUtils::IsProbingFrame(type)) { + MaybeStartIetfPeerMigration(); + return connected_; + } + QuicSocketAddress current_effective_peer_address = + GetEffectivePeerAddressFromCurrentPacket(); + if (IsDefaultPath(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address)) { + return connected_; + } + if (type == PATH_CHALLENGE_FRAME && + !IsAlternativePath(last_received_packet_info_.destination_address, + current_effective_peer_address)) { + QUIC_DVLOG(1) + << "The peer is probing a new path with effective peer address " + << current_effective_peer_address << ", self address " + << last_received_packet_info_.destination_address; + if (!validate_client_addresses_) { + QuicConnectionId client_cid; + absl::optional stateless_reset_token; + FindMatchingOrNewClientConnectionIdOrToken( + default_path_, alternative_path_, + last_received_packet_info_.destination_connection_id, &client_cid, + &stateless_reset_token); + alternative_path_ = + PathState(last_received_packet_info_.destination_address, + current_effective_peer_address, client_cid, + last_received_packet_info_.destination_connection_id, + stateless_reset_token); + } else if (!default_path_.validated) { + // Skip reverse path validation because either handshake hasn't + // completed or the connection is validating the default path. Using + // PATH_CHALLENGE to validate alternative client address before + // handshake gets comfirmed is meaningless because anyone can respond to + // it. If the connection is validating the default path, this + // alternative path is currently the only validated path which shouldn't + // be overridden. + QUIC_DVLOG(1) << "The connection hasn't finished handshake or is " + "validating a recent peer address change."; + QUIC_BUG_IF(quic_bug_12714_30, + IsHandshakeConfirmed() && !alternative_path_.validated) + << "No validated peer address to send after handshake comfirmed."; + } else if (!IsReceivedPeerAddressValidated()) { + QuicConnectionId client_connection_id; + absl::optional stateless_reset_token; + FindMatchingOrNewClientConnectionIdOrToken( + default_path_, alternative_path_, + last_received_packet_info_.destination_connection_id, + &client_connection_id, &stateless_reset_token); + // Only override alternative path state upon receiving a PATH_CHALLENGE + // from an unvalidated peer address, and the connection isn't validating + // a recent peer migration. + alternative_path_ = + PathState(last_received_packet_info_.destination_address, + current_effective_peer_address, client_connection_id, + last_received_packet_info_.destination_connection_id, + stateless_reset_token); + should_proactively_validate_peer_address_on_path_challenge_ = true; + } + } + MaybeUpdateBytesReceivedFromAlternativeAddress( + last_received_packet_info_.length); + return connected_; + } + // Packet content is tracked to identify connectivity probe in non-IETF + // version, where a connectivity probe is defined as + // - a padded PING packet with peer address change received by server, + // - a padded PING packet on new path received by client. + + if (current_packet_content_ == NOT_PADDED_PING) { + // We have already learned the current packet is not a connectivity + // probing packet. Peer migration should have already been started earlier + // if needed. + return connected_; + } + + if (type == PING_FRAME) { + if (current_packet_content_ == NO_FRAMES_RECEIVED) { + current_packet_content_ = FIRST_FRAME_IS_PING; + return connected_; + } + } + + // In Google QUIC, we look for a packet with just a PING and PADDING. + // If the condition is met, mark things as connectivity-probing, causing + // later processing to generate the correct response. + if (type == PADDING_FRAME && current_packet_content_ == FIRST_FRAME_IS_PING) { + current_packet_content_ = SECOND_FRAME_IS_PADDING; + if (perspective_ == Perspective::IS_SERVER) { + is_current_packet_connectivity_probing_ = + current_effective_peer_migration_type_ != NO_CHANGE; + QUIC_DLOG_IF(INFO, is_current_packet_connectivity_probing_) + << ENDPOINT + << "Detected connectivity probing packet. " + "current_effective_peer_migration_type_:" + << current_effective_peer_migration_type_; + } else { + is_current_packet_connectivity_probing_ = + (last_received_packet_info_.source_address != peer_address()) || + (last_received_packet_info_.destination_address != + default_path_.self_address); + QUIC_DLOG_IF(INFO, is_current_packet_connectivity_probing_) + << ENDPOINT + << "Detected connectivity probing packet. " + "last_packet_source_address:" + << last_received_packet_info_.source_address + << ", peer_address_:" << peer_address() + << ", last_packet_destination_address:" + << last_received_packet_info_.destination_address + << ", default path self_address :" << default_path_.self_address; + } + return connected_; + } + + current_packet_content_ = NOT_PADDED_PING; + if (GetLargestReceivedPacket().IsInitialized() && + last_received_packet_info_.header.packet_number == + GetLargestReceivedPacket()) { + UpdatePeerAddress(last_received_packet_info_.source_address); + if (current_effective_peer_migration_type_ != NO_CHANGE) { + // Start effective peer migration immediately when the current packet is + // confirmed not a connectivity probing packet. + StartEffectivePeerMigration(current_effective_peer_migration_type_); + } + } + current_effective_peer_migration_type_ = NO_CHANGE; + return connected_; +} + +void QuicConnection::MaybeStartIetfPeerMigration() { + QUICHE_DCHECK(version().HasIetfQuicFrames()); + if (current_effective_peer_migration_type_ != NO_CHANGE && + !IsHandshakeConfirmed()) { + QUIC_LOG_EVERY_N_SEC(INFO, 60) + << ENDPOINT << "Effective peer's ip:port changed from " + << default_path_.peer_address.ToString() << " to " + << GetEffectivePeerAddressFromCurrentPacket().ToString() + << " before handshake confirmed, " + "current_effective_peer_migration_type_: " + << current_effective_peer_migration_type_; + // Peer migrated before handshake gets confirmed. + CloseConnection((current_effective_peer_migration_type_ == PORT_CHANGE + ? QUIC_PEER_PORT_CHANGE_HANDSHAKE_UNCONFIRMED + : QUIC_CONNECTION_MIGRATION_HANDSHAKE_UNCONFIRMED), + "Peer address changed before handshake is confirmed.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (GetLargestReceivedPacket().IsInitialized() && + last_received_packet_info_.header.packet_number == + GetLargestReceivedPacket()) { + if (current_effective_peer_migration_type_ != NO_CHANGE) { + // Start effective peer migration when the current packet contains a + // non-probing frame. + // TODO(fayang): When multiple packet number spaces is supported, only + // start peer migration for the application data. + if (!validate_client_addresses_) { + UpdatePeerAddress(last_received_packet_info_.source_address); + } + StartEffectivePeerMigration(current_effective_peer_migration_type_); + } else { + UpdatePeerAddress(last_received_packet_info_.source_address); + } + } + current_effective_peer_migration_type_ = NO_CHANGE; +} + +void QuicConnection::PostProcessAfterAckFrame(bool send_stop_waiting, + bool acked_new_packet) { + if (no_stop_waiting_frames_ && !packet_creator_.has_ack()) { + uber_received_packet_manager_.DontWaitForPacketsBefore( + last_received_packet_info_.decrypted_level, + SupportsMultiplePacketNumberSpaces() + ? sent_packet_manager_.GetLargestPacketPeerKnowsIsAcked( + last_received_packet_info_.decrypted_level) + : sent_packet_manager_.largest_packet_peer_knows_is_acked()); + } + // Always reset the retransmission alarm when an ack comes in, since we now + // have a better estimate of the current rtt than when it was set. + SetRetransmissionAlarm(); + if (acked_new_packet) { + OnForwardProgressMade(); + } else if (default_enable_5rto_blackhole_detection_ && + !sent_packet_manager_.HasInFlightPackets() && + blackhole_detector_.IsDetectionInProgress()) { + // In case no new packets get acknowledged, it is possible packets are + // detected lost because of time based loss detection. Cancel blackhole + // detection if there is no packets in flight. + blackhole_detector_.StopDetection(/*permanent=*/false); + } + + if (send_stop_waiting) { + ++stop_waiting_count_; + } else { + stop_waiting_count_ = 0; + } +} + +void QuicConnection::SetSessionNotifier( + SessionNotifierInterface* session_notifier) { + sent_packet_manager_.SetSessionNotifier(session_notifier); +} + +void QuicConnection::SetDataProducer( + QuicStreamFrameDataProducer* data_producer) { + framer_.set_data_producer(data_producer); +} + +void QuicConnection::SetTransmissionType(TransmissionType type) { + packet_creator_.SetTransmissionType(type); +} + +void QuicConnection::UpdateReleaseTimeIntoFuture() { + QUICHE_DCHECK(supports_release_time_); + + const QuicTime::Delta prior_max_release_time = release_time_into_future_; + release_time_into_future_ = std::max( + QuicTime::Delta::FromMilliseconds(kMinReleaseTimeIntoFutureMs), + std::min(QuicTime::Delta::FromMilliseconds( + GetQuicFlag(quic_max_pace_time_into_future_ms)), + sent_packet_manager_.GetRttStats()->SmoothedOrInitialRtt() * + GetQuicFlag(quic_pace_time_into_future_srtt_fraction))); + QUIC_DVLOG(3) << "Updated max release time delay from " + << prior_max_release_time << " to " + << release_time_into_future_; +} + +void QuicConnection::ResetAckStates() { + ack_alarm_->Cancel(); + stop_waiting_count_ = 0; + uber_received_packet_manager_.ResetAckStates(encryption_level_); +} + +MessageStatus QuicConnection::SendMessage( + QuicMessageId message_id, absl::Span message, + bool flush) { + if (!VersionSupportsMessageFrames(transport_version())) { + QUIC_BUG(quic_bug_10511_38) + << "MESSAGE frame is not supported for version " << transport_version(); + return MESSAGE_STATUS_UNSUPPORTED; + } + if (MemSliceSpanTotalSize(message) > GetCurrentLargestMessagePayload()) { + return MESSAGE_STATUS_TOO_LARGE; + } + if (!connected_ || (!flush && !CanWrite(HAS_RETRANSMITTABLE_DATA))) { + return MESSAGE_STATUS_BLOCKED; + } + ScopedPacketFlusher flusher(this); + return packet_creator_.AddMessageFrame(message_id, message); +} + +QuicPacketLength QuicConnection::GetCurrentLargestMessagePayload() const { + return packet_creator_.GetCurrentLargestMessagePayload(); +} + +QuicPacketLength QuicConnection::GetGuaranteedLargestMessagePayload() const { + return packet_creator_.GetGuaranteedLargestMessagePayload(); +} + +uint32_t QuicConnection::cipher_id() const { + if (version().KnowsWhichDecrypterToUse()) { + return framer_.GetDecrypter(last_received_packet_info_.decrypted_level) + ->cipher_id(); + } + return framer_.decrypter()->cipher_id(); +} + +EncryptionLevel QuicConnection::GetConnectionCloseEncryptionLevel() const { + if (perspective_ == Perspective::IS_CLIENT) { + return encryption_level_; + } + if (IsHandshakeComplete()) { + // A forward secure packet has been received. + QUIC_BUG_IF(quic_bug_12714_31, + encryption_level_ != ENCRYPTION_FORWARD_SECURE) + << ENDPOINT << "Unexpected connection close encryption level " + << encryption_level_; + return ENCRYPTION_FORWARD_SECURE; + } + if (framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_ZERO_RTT)) { + if (encryption_level_ != ENCRYPTION_ZERO_RTT) { + if (version().HasIetfInvariantHeader()) { + QUIC_CODE_COUNT(quic_wrong_encryption_level_connection_close_ietf); + } else { + QUIC_CODE_COUNT(quic_wrong_encryption_level_connection_close); + } + } + return ENCRYPTION_ZERO_RTT; + } + return ENCRYPTION_INITIAL; +} + +void QuicConnection::MaybeBundleCryptoDataWithAcks() { + QUICHE_DCHECK(SupportsMultiplePacketNumberSpaces()); + if (IsHandshakeConfirmed()) { + return; + } + PacketNumberSpace space = HANDSHAKE_DATA; + if (perspective() == Perspective::IS_SERVER && + framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL)) { + // On the server side, sends INITIAL data with INITIAL ACK if initial key is + // available. + space = INITIAL_DATA; + } + const QuicTime ack_timeout = + uber_received_packet_manager_.GetAckTimeout(space); + if (!ack_timeout.IsInitialized() || + (ack_timeout > clock_->ApproximateNow() && + ack_timeout > uber_received_packet_manager_.GetEarliestAckTimeout())) { + // No pending ACK of space. + return; + } + if (coalesced_packet_.length() > 0) { + // Do not bundle CRYPTO data if the ACK could be coalesced with other + // packets. + return; + } + + if (!framer_.HasAnEncrypterForSpace(space)) { + QUIC_BUG(quic_bug_10511_39) + << ENDPOINT + << "Try to bundle crypto with ACK with missing key of space " + << PacketNumberSpaceToString(space); + return; + } + + sent_packet_manager_.RetransmitDataOfSpaceIfAny(space); +} + +void QuicConnection::SendAllPendingAcks() { + QUICHE_DCHECK(SupportsMultiplePacketNumberSpaces()); + QUIC_DVLOG(1) << ENDPOINT << "Trying to send all pending ACKs"; + ack_alarm_->Cancel(); + QuicTime earliest_ack_timeout = + uber_received_packet_manager_.GetEarliestAckTimeout(); + QUIC_BUG_IF(quic_bug_12714_32, !earliest_ack_timeout.IsInitialized()); + MaybeBundleCryptoDataWithAcks(); + earliest_ack_timeout = uber_received_packet_manager_.GetEarliestAckTimeout(); + if (!earliest_ack_timeout.IsInitialized()) { + return; + } + for (int8_t i = INITIAL_DATA; i <= APPLICATION_DATA; ++i) { + const QuicTime ack_timeout = uber_received_packet_manager_.GetAckTimeout( + static_cast(i)); + if (!ack_timeout.IsInitialized()) { + continue; + } + if (!framer_.HasAnEncrypterForSpace(static_cast(i))) { + // The key has been dropped. + continue; + } + if (ack_timeout > clock_->ApproximateNow() && + ack_timeout > earliest_ack_timeout) { + // Always send the earliest ACK to make forward progress in case alarm + // fires early. + continue; + } + QUIC_DVLOG(1) << ENDPOINT << "Sending ACK of packet number space " + << PacketNumberSpaceToString( + static_cast(i)); + ScopedEncryptionLevelContext context( + this, QuicUtils::GetEncryptionLevelToSendAckofSpace( + static_cast(i))); + QuicFrames frames; + frames.push_back(uber_received_packet_manager_.GetUpdatedAckFrame( + static_cast(i), clock_->ApproximateNow())); + const bool flushed = packet_creator_.FlushAckFrame(frames); + if (!flushed) { + // Connection is write blocked. + QUIC_BUG_IF(quic_bug_12714_33, + !writer_->IsWriteBlocked() && + !LimitedByAmplificationFactor( + packet_creator_.max_packet_length())) + << "Writer not blocked and not throttled by amplification factor, " + "but ACK not flushed for packet space:" + << PacketNumberSpaceToString(static_cast(i)) + << ", connected: " << connected_ + << ", fill_coalesced_packet: " << fill_coalesced_packet_ + << ", has_soft_max_packet_length: " + << packet_creator_.HasSoftMaxPacketLength() + << ", max_packet_length: " << packet_creator_.max_packet_length() + << ", pending frames: " << packet_creator_.GetPendingFramesInfo(); + break; + } + ResetAckStates(); + } + + const QuicTime timeout = + uber_received_packet_manager_.GetEarliestAckTimeout(); + if (timeout.IsInitialized()) { + // If there are ACKs pending, re-arm ack alarm. + ack_alarm_->Update(timeout, kAlarmGranularity); + } + // Only try to bundle retransmittable data with ACK frame if default + // encryption level is forward secure. + if (encryption_level_ != ENCRYPTION_FORWARD_SECURE || + !ShouldBundleRetransmittableFrameWithAck()) { + return; + } + consecutive_num_packets_with_no_retransmittable_frames_ = 0; + if (packet_creator_.HasPendingRetransmittableFrames() || + visitor_->WillingAndAbleToWrite()) { + // There are pending retransmittable frames. + return; + } + + visitor_->OnAckNeedsRetransmittableFrame(); +} + +bool QuicConnection::ShouldBundleRetransmittableFrameWithAck() const { + if (consecutive_num_packets_with_no_retransmittable_frames_ >= + max_consecutive_num_packets_with_no_retransmittable_frames_) { + return true; + } + if (bundle_retransmittable_with_pto_ack_ && + sent_packet_manager_.GetConsecutivePtoCount() > 0) { + // Bundle a retransmittable frame with an ACK if PTO has fired in order to + // recover more quickly in cases of temporary network outage. + return true; + } + return false; +} + +void QuicConnection::MaybeCoalescePacketOfHigherSpace() { + if (!connected() || !packet_creator_.HasSoftMaxPacketLength()) { + return; + } + if (fill_coalesced_packet_) { + // Make sure MaybeCoalescePacketOfHigherSpace is not re-entrant. + QUIC_BUG(quic_coalesce_packet_reentrant); + return; + } + for (EncryptionLevel retransmission_level : + {ENCRYPTION_INITIAL, ENCRYPTION_HANDSHAKE}) { + // Coalesce HANDSHAKE with INITIAL retransmission, and coalesce 1-RTT with + // HANDSHAKE retransmission. + const EncryptionLevel coalesced_level = + retransmission_level == ENCRYPTION_INITIAL ? ENCRYPTION_HANDSHAKE + : ENCRYPTION_FORWARD_SECURE; + if (coalesced_packet_.ContainsPacketOfEncryptionLevel( + retransmission_level) && + coalesced_packet_.TransmissionTypeOfPacket(retransmission_level) != + NOT_RETRANSMISSION && + framer_.HasEncrypterOfEncryptionLevel(coalesced_level) && + !coalesced_packet_.ContainsPacketOfEncryptionLevel(coalesced_level)) { + QUIC_DVLOG(1) << ENDPOINT + << "Trying to coalesce packet of encryption level: " + << EncryptionLevelToString(coalesced_level); + fill_coalesced_packet_ = true; + sent_packet_manager_.RetransmitDataOfSpaceIfAny( + QuicUtils::GetPacketNumberSpace(coalesced_level)); + fill_coalesced_packet_ = false; + } + } +} + +bool QuicConnection::FlushCoalescedPacket() { + ScopedCoalescedPacketClearer clearer(&coalesced_packet_); + if (!connected_) { + return false; + } + if (!version().CanSendCoalescedPackets()) { + QUIC_BUG_IF(quic_bug_12714_34, coalesced_packet_.length() > 0); + return true; + } + if (coalesced_packet_.ContainsPacketOfEncryptionLevel(ENCRYPTION_INITIAL) && + !framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL)) { + // Initial packet will be re-serialized. Neuter it in case initial key has + // been dropped. + QUIC_BUG(quic_bug_10511_40) + << ENDPOINT + << "Coalescer contains initial packet while initial key has " + "been dropped."; + coalesced_packet_.NeuterInitialPacket(); + } + if (coalesced_packet_.length() == 0) { + return true; + } + + char buffer[kMaxOutgoingPacketSize]; + const size_t length = packet_creator_.SerializeCoalescedPacket( + coalesced_packet_, buffer, coalesced_packet_.max_packet_length()); + if (length == 0) { + if (connected_) { + CloseConnection(QUIC_FAILED_TO_SERIALIZE_PACKET, + "Failed to serialize coalesced packet.", + ConnectionCloseBehavior::SILENT_CLOSE); + } + return false; + } + if (debug_visitor_ != nullptr) { + debug_visitor_->OnCoalescedPacketSent(coalesced_packet_, length); + } + QUIC_DVLOG(1) << ENDPOINT << "Sending coalesced packet " + << coalesced_packet_.ToString(length); + const size_t padding_size = + length - std::min(length, coalesced_packet_.length()); + // Buffer coalesced packet if padding + bytes_sent exceeds amplifcation limit. + if (!buffered_packets_.empty() || HandleWriteBlocked() || + (enforce_strict_amplification_factor_ && + LimitedByAmplificationFactor(padding_size))) { + QUIC_DVLOG(1) << ENDPOINT + << "Buffering coalesced packet of len: " << length; + buffered_packets_.emplace_back( + buffer, static_cast(length), + coalesced_packet_.self_address(), coalesced_packet_.peer_address()); + } else { + WriteResult result = SendPacketToWriter( + buffer, length, coalesced_packet_.self_address().host(), + coalesced_packet_.peer_address(), per_packet_options_); + if (IsWriteError(result.status)) { + OnWriteError(result.error_code); + return false; + } + if (IsWriteBlockedStatus(result.status)) { + visitor_->OnWriteBlocked(); + if (result.status != WRITE_STATUS_BLOCKED_DATA_BUFFERED) { + QUIC_DVLOG(1) << ENDPOINT + << "Buffering coalesced packet of len: " << length; + buffered_packets_.emplace_back( + buffer, static_cast(length), + coalesced_packet_.self_address(), coalesced_packet_.peer_address()); + } + } + } + if (accelerated_server_preferred_address_ && + stats_.num_duplicated_packets_sent_to_server_preferred_address < + kMaxDuplicatedPacketsSentToServerPreferredAddress) { + // Send coalesced packets to both addresses while the server preferred + // address validation is pending. + QUICHE_DCHECK(received_server_preferred_address_.IsInitialized()); + path_validator_.MaybeWritePacketToAddress( + buffer, length, received_server_preferred_address_); + ++stats_.num_duplicated_packets_sent_to_server_preferred_address; + } + // Account for added padding. + if (length > coalesced_packet_.length()) { + if (IsDefaultPath(coalesced_packet_.self_address(), + coalesced_packet_.peer_address())) { + if (EnforceAntiAmplificationLimit()) { + // Include bytes sent even if they are not in flight. + default_path_.bytes_sent_before_address_validation += padding_size; + } + } else { + MaybeUpdateBytesSentToAlternativeAddress(coalesced_packet_.peer_address(), + padding_size); + } + stats_.bytes_sent += padding_size; + if (coalesced_packet_.initial_packet() != nullptr && + coalesced_packet_.initial_packet()->transmission_type != + NOT_RETRANSMISSION) { + stats_.bytes_retransmitted += padding_size; + } + } + return true; +} + +void QuicConnection::MaybeEnableMultiplePacketNumberSpacesSupport() { + if (version().handshake_protocol != PROTOCOL_TLS1_3) { + return; + } + QUIC_DVLOG(1) << ENDPOINT << "connection " << connection_id() + << " supports multiple packet number spaces"; + framer_.EnableMultiplePacketNumberSpacesSupport(); + sent_packet_manager_.EnableMultiplePacketNumberSpacesSupport(); + uber_received_packet_manager_.EnableMultiplePacketNumberSpacesSupport( + perspective_); +} + +bool QuicConnection::SupportsMultiplePacketNumberSpaces() const { + return sent_packet_manager_.supports_multiple_packet_number_spaces(); +} + +void QuicConnection::SetLargestReceivedPacketWithAck( + QuicPacketNumber new_value) { + if (SupportsMultiplePacketNumberSpaces()) { + largest_seen_packets_with_ack_[QuicUtils::GetPacketNumberSpace( + last_received_packet_info_.decrypted_level)] = new_value; + } else { + largest_seen_packet_with_ack_ = new_value; + } +} + +void QuicConnection::OnForwardProgressMade() { + if (!connected_) { + return; + } + if (is_path_degrading_) { + visitor_->OnForwardProgressMadeAfterPathDegrading(); + is_path_degrading_ = false; + } + if (sent_packet_manager_.HasInFlightPackets()) { + // Restart detections if forward progress has been made. + blackhole_detector_.RestartDetection(GetPathDegradingDeadline(), + GetNetworkBlackholeDeadline(), + GetPathMtuReductionDeadline()); + } else { + // Stop detections in quiecense. + blackhole_detector_.StopDetection(/*permanent=*/false); + } + QUIC_BUG_IF(quic_bug_12714_35, + perspective_ == Perspective::IS_SERVER && + default_enable_5rto_blackhole_detection_ && + blackhole_detector_.IsDetectionInProgress() && + !sent_packet_manager_.HasInFlightPackets()) + << ENDPOINT + << "Trying to start blackhole detection without no bytes in flight"; +} + +QuicPacketNumber QuicConnection::GetLargestReceivedPacketWithAck() const { + if (SupportsMultiplePacketNumberSpaces()) { + return largest_seen_packets_with_ack_[QuicUtils::GetPacketNumberSpace( + last_received_packet_info_.decrypted_level)]; + } + return largest_seen_packet_with_ack_; +} + +QuicPacketNumber QuicConnection::GetLargestAckedPacket() const { + if (SupportsMultiplePacketNumberSpaces()) { + return sent_packet_manager_.GetLargestAckedPacket( + last_received_packet_info_.decrypted_level); + } + return sent_packet_manager_.GetLargestObserved(); +} + +QuicPacketNumber QuicConnection::GetLargestReceivedPacket() const { + return uber_received_packet_manager_.GetLargestObserved( + last_received_packet_info_.decrypted_level); +} + +bool QuicConnection::EnforceAntiAmplificationLimit() const { + return version().SupportsAntiAmplificationLimit() && + perspective_ == Perspective::IS_SERVER && !default_path_.validated; +} + +// TODO(danzh) Pass in path object or its reference of some sort to use this +// method to check anti-amplification limit on non-default path. +bool QuicConnection::LimitedByAmplificationFactor(QuicByteCount bytes) const { + return EnforceAntiAmplificationLimit() && + (default_path_.bytes_sent_before_address_validation + + (enforce_strict_amplification_factor_ ? bytes : 0)) >= + anti_amplification_factor_ * + default_path_.bytes_received_before_address_validation; +} + +SerializedPacketFate QuicConnection::GetSerializedPacketFate( + bool is_mtu_discovery, EncryptionLevel encryption_level) { + if (ShouldDiscardPacket(encryption_level)) { + return DISCARD; + } + if (version().CanSendCoalescedPackets() && !coalescing_done_ && + !is_mtu_discovery) { + if (!IsHandshakeConfirmed()) { + // Before receiving ACK for any 1-RTT packets, always try to coalesce + // packet (except MTU discovery packet). + return COALESCE; + } + if (coalesced_packet_.length() > 0) { + // If the coalescer is not empty, let this packet go through coalescer + // to avoid potential out of order sending. + return COALESCE; + } + } + if (!buffered_packets_.empty() || HandleWriteBlocked()) { + return BUFFER; + } + return SEND_TO_WRITER; +} + +bool QuicConnection::IsHandshakeComplete() const { + return visitor_->GetHandshakeState() >= HANDSHAKE_COMPLETE; +} + +bool QuicConnection::IsHandshakeConfirmed() const { + QUICHE_DCHECK_EQ(PROTOCOL_TLS1_3, version().handshake_protocol); + return visitor_->GetHandshakeState() == HANDSHAKE_CONFIRMED; +} + +size_t QuicConnection::min_received_before_ack_decimation() const { + return uber_received_packet_manager_.min_received_before_ack_decimation(); +} + +void QuicConnection::set_min_received_before_ack_decimation(size_t new_value) { + uber_received_packet_manager_.set_min_received_before_ack_decimation( + new_value); +} + +const QuicAckFrame& QuicConnection::ack_frame() const { + if (SupportsMultiplePacketNumberSpaces()) { + return uber_received_packet_manager_.GetAckFrame( + QuicUtils::GetPacketNumberSpace( + last_received_packet_info_.decrypted_level)); + } + return uber_received_packet_manager_.ack_frame(); +} + +void QuicConnection::set_client_connection_id( + QuicConnectionId client_connection_id) { + if (!version().SupportsClientConnectionIds()) { + QUIC_BUG_IF(quic_bug_12714_36, !client_connection_id.IsEmpty()) + << ENDPOINT << "Attempted to use client connection ID " + << client_connection_id << " with unsupported version " << version(); + return; + } + default_path_.client_connection_id = client_connection_id; + + client_connection_id_is_set_ = true; + if (version().HasIetfQuicFrames() && !client_connection_id.IsEmpty()) { + if (perspective_ == Perspective::IS_SERVER) { + QUICHE_DCHECK(peer_issued_cid_manager_ == nullptr); + peer_issued_cid_manager_ = + std::make_unique( + kMinNumOfActiveConnectionIds, client_connection_id, clock_, + alarm_factory_, this, context()); + } else { + // Note in Chromium client, set_client_connection_id is not called and + // thus self_issued_cid_manager_ should be null. + self_issued_cid_manager_ = MakeSelfIssuedConnectionIdManager(); + } + } + QUIC_DLOG(INFO) << ENDPOINT << "setting client connection ID to " + << default_path_.client_connection_id + << " for connection with server connection ID " + << default_path_.server_connection_id; + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + framer_.SetExpectedClientConnectionIdLength( + default_path_.client_connection_id.length()); +} + +void QuicConnection::OnPathDegradingDetected() { + is_path_degrading_ = true; + visitor_->OnPathDegrading(); + if (multi_port_stats_) { + multi_port_stats_->num_path_degrading++; + MaybeMigrateToMultiPortPath(); + } +} + +void QuicConnection::MaybeMigrateToMultiPortPath() { + if (!alternative_path_.validated) { + QUIC_CLIENT_HISTOGRAM_ENUM( + "QuicConnection.MultiPortPathStatusWhenMigrating", + MultiPortStatusOnMigration::kNotValidated, + MultiPortStatusOnMigration::kMaxValue, + "Status of the multi port path upon migration"); + return; + } + std::unique_ptr context; + const bool has_pending_validation = + path_validator_.HasPendingPathValidation(); + if (!has_pending_validation) { + // The multi-port path should have just finished the recent probe and + // waiting for the next one. + context = std::move(multi_port_path_context_); + multi_port_probing_alarm_->Cancel(); + QUIC_CLIENT_HISTOGRAM_ENUM( + "QuicConnection.MultiPortPathStatusWhenMigrating", + MultiPortStatusOnMigration::kWaitingForRefreshValidation, + MultiPortStatusOnMigration::kMaxValue, + "Status of the multi port path upon migration"); + } else { + // The multi-port path is currently under probing. + context = path_validator_.ReleaseContext(); + QUIC_CLIENT_HISTOGRAM_ENUM( + "QuicConnection.MultiPortPathStatusWhenMigrating", + MultiPortStatusOnMigration::kPendingRefreshValidation, + MultiPortStatusOnMigration::kMaxValue, + "Status of the multi port path upon migration"); + } + if (context == nullptr) { + QUICHE_BUG(quic_bug_12714_90) << "No multi-port context to migrate to"; + return; + } + visitor_->MigrateToMultiPortPath(std::move(context)); +} + +void QuicConnection::OnBlackholeDetected() { + if (default_enable_5rto_blackhole_detection_ && + !sent_packet_manager_.HasInFlightPackets()) { + QUIC_BUG(quic_bug_10511_41) + << ENDPOINT + << "Blackhole detected, but there is no bytes in flight, version: " + << version(); + // Do not close connection if there is no bytes in flight. + return; + } + CloseConnection(QUIC_TOO_MANY_RTOS, "Network blackhole detected", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicConnection::OnPathMtuReductionDetected() { + MaybeRevertToPreviousMtu(); +} + +void QuicConnection::OnHandshakeTimeout() { + const QuicTime::Delta duration = + clock_->ApproximateNow() - stats_.connection_creation_time; + std::string error_details = absl::StrCat( + "Handshake timeout expired after ", duration.ToDebuggingValue(), + ". Timeout:", + idle_network_detector_.handshake_timeout().ToDebuggingValue()); + if (perspective() == Perspective::IS_CLIENT && version().UsesTls()) { + absl::StrAppend(&error_details, UndecryptablePacketsInfo()); + } + QUIC_DVLOG(1) << ENDPOINT << error_details; + CloseConnection(QUIC_HANDSHAKE_TIMEOUT, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicConnection::OnIdleNetworkDetected() { + const QuicTime::Delta duration = + clock_->ApproximateNow() - + idle_network_detector_.last_network_activity_time(); + std::string error_details = absl::StrCat( + "No recent network activity after ", duration.ToDebuggingValue(), + ". Timeout:", + idle_network_detector_.idle_network_timeout().ToDebuggingValue()); + if (perspective() == Perspective::IS_CLIENT && version().UsesTls() && + !IsHandshakeComplete()) { + absl::StrAppend(&error_details, UndecryptablePacketsInfo()); + } + QUIC_DVLOG(1) << ENDPOINT << error_details; + const bool has_consecutive_pto = + sent_packet_manager_.GetConsecutivePtoCount() > 0; + if (has_consecutive_pto || visitor_->ShouldKeepConnectionAlive()) { + if (GetQuicReloadableFlag(quic_add_stream_info_to_idle_close_detail) && + !has_consecutive_pto) { + // Include stream information in error detail if there are open streams. + QUIC_RELOADABLE_FLAG_COUNT(quic_add_stream_info_to_idle_close_detail); + absl::StrAppend(&error_details, ", ", + visitor_->GetStreamsInfoForLogging()); + } + CloseConnection(QUIC_NETWORK_IDLE_TIMEOUT, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QuicErrorCode error_code = QUIC_NETWORK_IDLE_TIMEOUT; + if (idle_timeout_connection_close_behavior_ == + ConnectionCloseBehavior:: + SILENT_CLOSE_WITH_CONNECTION_CLOSE_PACKET_SERIALIZED) { + error_code = QUIC_SILENT_IDLE_TIMEOUT; + } + CloseConnection(error_code, error_details, + idle_timeout_connection_close_behavior_); +} + +void QuicConnection::OnBandwidthUpdateTimeout() { + visitor_->OnBandwidthUpdateTimeout(); +} + +void QuicConnection::OnKeepAliveTimeout() { + if (retransmission_alarm_->IsSet() || + !visitor_->ShouldKeepConnectionAlive()) { + return; + } + SendPingAtLevel(framer().GetEncryptionLevelToSendApplicationData()); +} + +void QuicConnection::OnRetransmittableOnWireTimeout() { + if (retransmission_alarm_->IsSet() || + !visitor_->ShouldKeepConnectionAlive()) { + return; + } + bool packet_buffered = false; + switch (retransmittable_on_wire_behavior_) { + case DEFAULT: + break; + case SEND_FIRST_FORWARD_SECURE_PACKET: + if (first_serialized_one_rtt_packet_ != nullptr) { + buffered_packets_.emplace_back( + first_serialized_one_rtt_packet_->data.get(), + first_serialized_one_rtt_packet_->length, self_address(), + peer_address()); + packet_buffered = true; + } + break; + case SEND_RANDOM_BYTES: + const QuicPacketLength random_bytes_length = std::max( + QuicFramer::GetMinStatelessResetPacketLength() + 1, + random_generator_->RandUint64() % + packet_creator_.max_packet_length()); + buffered_packets_.emplace_back(*random_generator_, random_bytes_length, + self_address(), peer_address()); + packet_buffered = true; + break; + } + if (packet_buffered) { + if (!writer_->IsWriteBlocked()) { + WriteQueuedPackets(); + } + if (connected_) { + // Always reset PING alarm with has_in_flight_packets=true. This is used + // to avoid re-arming the alarm in retransmittable-on-wire mode. + ping_manager_.SetAlarm(clock_->ApproximateNow(), + visitor_->ShouldKeepConnectionAlive(), + /*has_in_flight_packets=*/true); + } + return; + } + SendPingAtLevel(framer().GetEncryptionLevelToSendApplicationData()); +} + +void QuicConnection::OnPeerIssuedConnectionIdRetired() { + QUICHE_DCHECK(peer_issued_cid_manager_ != nullptr); + QuicConnectionId* default_path_cid = + perspective_ == Perspective::IS_CLIENT + ? &default_path_.server_connection_id + : &default_path_.client_connection_id; + QuicConnectionId* alternative_path_cid = + perspective_ == Perspective::IS_CLIENT + ? &alternative_path_.server_connection_id + : &alternative_path_.client_connection_id; + bool default_path_and_alternative_path_use_the_same_peer_connection_id = + *default_path_cid == *alternative_path_cid; + if (!default_path_cid->IsEmpty() && + !peer_issued_cid_manager_->IsConnectionIdActive(*default_path_cid)) { + *default_path_cid = QuicConnectionId(); + } + // TODO(haoyuewang) Handle the change for default_path_ & alternatvie_path_ + // via the same helper function. + if (default_path_cid->IsEmpty()) { + // Try setting a new connection ID now such that subsequent + // RetireConnectionId frames can be sent on the default path. + const QuicConnectionIdData* unused_connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + if (unused_connection_id_data != nullptr) { + *default_path_cid = unused_connection_id_data->connection_id; + default_path_.stateless_reset_token = + unused_connection_id_data->stateless_reset_token; + if (perspective_ == Perspective::IS_CLIENT) { + packet_creator_.SetServerConnectionId( + unused_connection_id_data->connection_id); + } else { + packet_creator_.SetClientConnectionId( + unused_connection_id_data->connection_id); + } + } + } + if (default_path_and_alternative_path_use_the_same_peer_connection_id) { + *alternative_path_cid = *default_path_cid; + alternative_path_.stateless_reset_token = + default_path_.stateless_reset_token; + } else if (!alternative_path_cid->IsEmpty() && + !peer_issued_cid_manager_->IsConnectionIdActive( + *alternative_path_cid)) { + *alternative_path_cid = EmptyQuicConnectionId(); + const QuicConnectionIdData* unused_connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + if (unused_connection_id_data != nullptr) { + *alternative_path_cid = unused_connection_id_data->connection_id; + alternative_path_.stateless_reset_token = + unused_connection_id_data->stateless_reset_token; + } + } + + std::vector retired_cid_sequence_numbers = + peer_issued_cid_manager_->ConsumeToBeRetiredConnectionIdSequenceNumbers(); + QUICHE_DCHECK(!retired_cid_sequence_numbers.empty()); + for (const auto& sequence_number : retired_cid_sequence_numbers) { + ++stats_.num_retire_connection_id_sent; + visitor_->SendRetireConnectionId(sequence_number); + } +} + +bool QuicConnection::SendNewConnectionId( + const QuicNewConnectionIdFrame& frame) { + visitor_->SendNewConnectionId(frame); + ++stats_.num_new_connection_id_sent; + return connected_; +} + +bool QuicConnection::MaybeReserveConnectionId( + const QuicConnectionId& connection_id) { + if (perspective_ == Perspective::IS_SERVER) { + return visitor_->MaybeReserveConnectionId(connection_id); + } + return true; +} + +void QuicConnection::OnSelfIssuedConnectionIdRetired( + const QuicConnectionId& connection_id) { + if (perspective_ == Perspective::IS_SERVER) { + visitor_->OnServerConnectionIdRetired(connection_id); + } +} + +void QuicConnection::MaybeUpdateAckTimeout() { + if (should_last_packet_instigate_acks_) { + return; + } + should_last_packet_instigate_acks_ = true; + uber_received_packet_manager_.MaybeUpdateAckTimeout( + /*should_last_packet_instigate_acks=*/true, + last_received_packet_info_.decrypted_level, + last_received_packet_info_.header.packet_number, + last_received_packet_info_.receipt_time, clock_->ApproximateNow(), + sent_packet_manager_.GetRttStats()); +} + +QuicTime QuicConnection::GetPathDegradingDeadline() const { + if (!ShouldDetectPathDegrading()) { + return QuicTime::Zero(); + } + return clock_->ApproximateNow() + + sent_packet_manager_.GetPathDegradingDelay(); +} + +bool QuicConnection::ShouldDetectPathDegrading() const { + if (!connected_) { + return false; + } + if (GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed) && + SupportsMultiplePacketNumberSpaces()) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_no_path_degrading_before_handshake_confirmed, 1, 2); + // No path degrading detection before handshake confirmed. + return perspective_ == Perspective::IS_CLIENT && IsHandshakeConfirmed() && + !is_path_degrading_; + } + // No path degrading detection before handshake completes. + if (!idle_network_detector_.handshake_timeout().IsInfinite()) { + return false; + } + return perspective_ == Perspective::IS_CLIENT && !is_path_degrading_; +} + +QuicTime QuicConnection::GetNetworkBlackholeDeadline() const { + if (!ShouldDetectBlackhole()) { + return QuicTime::Zero(); + } + QUICHE_DCHECK_LT(0u, num_rtos_for_blackhole_detection_); + + const QuicTime::Delta blackhole_delay = + sent_packet_manager_.GetNetworkBlackholeDelay( + num_rtos_for_blackhole_detection_); + if (!ShouldDetectPathDegrading()) { + return clock_->ApproximateNow() + blackhole_delay; + } + return clock_->ApproximateNow() + + CalculateNetworkBlackholeDelay( + blackhole_delay, sent_packet_manager_.GetPathDegradingDelay(), + sent_packet_manager_.GetPtoDelay()); +} + +// static +QuicTime::Delta QuicConnection::CalculateNetworkBlackholeDelay( + QuicTime::Delta blackhole_delay, QuicTime::Delta path_degrading_delay, + QuicTime::Delta pto_delay) { + const QuicTime::Delta min_delay = path_degrading_delay + pto_delay * 2; + if (blackhole_delay < min_delay) { + QUIC_CODE_COUNT(quic_extending_short_blackhole_delay); + } + return std::max(min_delay, blackhole_delay); +} + +void QuicConnection::AddKnownServerAddress(const QuicSocketAddress& address) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); + if (!address.IsInitialized() || IsKnownServerAddress(address)) { + return; + } + known_server_addresses_.push_back(address); +} + +absl::optional +QuicConnection::MaybeIssueNewConnectionIdForPreferredAddress() { + if (self_issued_cid_manager_ == nullptr) { + return absl::nullopt; + } + return self_issued_cid_manager_ + ->MaybeIssueNewConnectionIdForPreferredAddress(); +} + +bool QuicConnection::ShouldDetectBlackhole() const { + if (!connected_ || blackhole_detection_disabled_) { + return false; + } + if (GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed) && + SupportsMultiplePacketNumberSpaces() && !IsHandshakeConfirmed()) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_no_path_degrading_before_handshake_confirmed, 2, 2); + return false; + } + // No blackhole detection before handshake completes. + if (default_enable_5rto_blackhole_detection_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_default_enable_5rto_blackhole_detection2, + 3, 3); + return IsHandshakeComplete(); + } + + if (!idle_network_detector_.handshake_timeout().IsInfinite()) { + return false; + } + return num_rtos_for_blackhole_detection_ > 0; +} + +QuicTime QuicConnection::GetRetransmissionDeadline() const { + if (perspective_ == Perspective::IS_CLIENT && + SupportsMultiplePacketNumberSpaces() && !IsHandshakeConfirmed() && + stats_.pto_count == 0 && + !framer_.HasDecrypterOfEncryptionLevel(ENCRYPTION_HANDSHAKE) && + !undecryptable_packets_.empty()) { + // Retransmits ClientHello quickly when a Handshake or 1-RTT packet is + // received prior to having Handshake keys. Adding kAlarmGranulary will + // avoid spurious retransmissions in the case of small-scale reordering. + return clock_->ApproximateNow() + kAlarmGranularity; + } + return sent_packet_manager_.GetRetransmissionTime(); +} + +bool QuicConnection::SendPathChallenge( + const QuicPathFrameBuffer& data_buffer, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& effective_peer_address, QuicPacketWriter* writer) { + if (!framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE)) { + return connected_; + } + if (connection_migration_use_new_cid_) { + QuicConnectionId client_cid, server_cid; + FindOnPathConnectionIds(self_address, effective_peer_address, &client_cid, + &server_cid); + if (writer == writer_) { + ScopedPacketFlusher flusher(this); + { + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, peer_address, client_cid, server_cid, + connection_migration_use_new_cid_); + // It's using the default writer, add the PATH_CHALLENGE the same way as + // other frames. This may cause connection to be closed. + packet_creator_.AddPathChallengeFrame(data_buffer); + } + } else { + // Switch to the right CID and source/peer addresses. + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, peer_address, client_cid, server_cid, + connection_migration_use_new_cid_); + std::unique_ptr probing_packet = + packet_creator_.SerializePathChallengeConnectivityProbingPacket( + data_buffer); + QUICHE_DCHECK_EQ(IsRetransmittable(*probing_packet), + NO_RETRANSMITTABLE_DATA); + QUICHE_DCHECK_EQ(self_address, alternative_path_.self_address); + WritePacketUsingWriter(std::move(probing_packet), writer, self_address, + peer_address, /*measure_rtt=*/false); + } + return connected_; + } + if (writer == writer_) { + ScopedPacketFlusher flusher(this); + { + // It's on current path, add the PATH_CHALLENGE the same way as other + // frames. + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, peer_address, /*update_connection_id=*/false); + // This may cause connection to be closed. + packet_creator_.AddPathChallengeFrame(data_buffer); + } + // Return outside of the scope so that the flush result can be reflected. + return connected_; + } + std::unique_ptr probing_packet = + packet_creator_.SerializePathChallengeConnectivityProbingPacket( + data_buffer); + QUICHE_DCHECK_EQ(IsRetransmittable(*probing_packet), NO_RETRANSMITTABLE_DATA); + QUICHE_DCHECK_EQ(self_address, alternative_path_.self_address); + WritePacketUsingWriter(std::move(probing_packet), writer, self_address, + peer_address, /*measure_rtt=*/false); + return true; +} + +QuicTime QuicConnection::GetRetryTimeout( + const QuicSocketAddress& peer_address_to_use, + QuicPacketWriter* writer_to_use) const { + if (writer_to_use == writer_ && peer_address_to_use == peer_address()) { + return clock_->ApproximateNow() + sent_packet_manager_.GetPtoDelay(); + } + return clock_->ApproximateNow() + + QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs); +} + +void QuicConnection::ValidatePath( + std::unique_ptr context, + std::unique_ptr result_delegate, + PathValidationReason reason) { + if (!connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_CLIENT && + !IsDefaultPath(context->self_address(), context->peer_address())) { + alternative_path_ = PathState( + context->self_address(), context->peer_address(), + default_path_.client_connection_id, default_path_.server_connection_id, + default_path_.stateless_reset_token); + } + if (path_validator_.HasPendingPathValidation()) { + if (perspective_ == Perspective::IS_CLIENT && + IsValidatingServerPreferredAddress()) { + QUIC_CLIENT_HISTOGRAM_BOOL( + "QuicSession.ServerPreferredAddressValidationCancelled", true, + "How often the caller kicked off another validation while there is " + "an on-going server preferred address validation."); + } + // Cancel and fail any earlier validation. + path_validator_.CancelPathValidation(); + } + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_CLIENT && + !IsDefaultPath(context->self_address(), context->peer_address())) { + if (self_issued_cid_manager_ != nullptr) { + self_issued_cid_manager_->MaybeSendNewConnectionIds(); + if (!connected_) { + return; + } + } + if ((self_issued_cid_manager_ != nullptr && + !self_issued_cid_manager_->HasConnectionIdToConsume()) || + (peer_issued_cid_manager_ != nullptr && + !peer_issued_cid_manager_->HasUnusedConnectionId())) { + QUIC_DVLOG(1) << "Client cannot start new path validation as there is no " + "requried connection ID is available."; + result_delegate->OnPathValidationFailure(std::move(context)); + return; + } + QuicConnectionId client_connection_id, server_connection_id; + absl::optional stateless_reset_token; + if (self_issued_cid_manager_ != nullptr) { + client_connection_id = + *self_issued_cid_manager_->ConsumeOneConnectionId(); + } + if (peer_issued_cid_manager_ != nullptr) { + const auto* connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + server_connection_id = connection_id_data->connection_id; + stateless_reset_token = connection_id_data->stateless_reset_token; + } + alternative_path_ = PathState(context->self_address(), + context->peer_address(), client_connection_id, + server_connection_id, stateless_reset_token); + } + path_validator_.StartPathValidation(std::move(context), + std::move(result_delegate), reason); + if (perspective_ == Perspective::IS_CLIENT && + IsValidatingServerPreferredAddress()) { + AddKnownServerAddress(received_server_preferred_address_); + } +} + +bool QuicConnection::SendPathResponse( + const QuicPathFrameBuffer& data_buffer, + const QuicSocketAddress& peer_address_to_send, + const QuicSocketAddress& effective_peer_address) { + if (!framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE)) { + return false; + } + QuicConnectionId client_cid, server_cid; + if (connection_migration_use_new_cid_) { + FindOnPathConnectionIds(last_received_packet_info_.destination_address, + effective_peer_address, &client_cid, &server_cid); + } + // Send PATH_RESPONSE using the provided peer address. If the creator has been + // using a different peer address, it will flush before and after serializing + // the current PATH_RESPONSE. + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, peer_address_to_send, client_cid, server_cid, + connection_migration_use_new_cid_); + QUIC_DVLOG(1) << ENDPOINT << "Send PATH_RESPONSE to " << peer_address_to_send; + if (default_path_.self_address == + last_received_packet_info_.destination_address) { + // The PATH_CHALLENGE is received on the default socket. Respond on the same + // socket. + return packet_creator_.AddPathResponseFrame(data_buffer); + } + + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + // This PATH_CHALLENGE is received on an alternative socket which should be + // used to send PATH_RESPONSE. + if (!path_validator_.HasPendingPathValidation() || + path_validator_.GetContext()->self_address() != + last_received_packet_info_.destination_address) { + // Ignore this PATH_CHALLENGE if it's received from an uninteresting + // socket. + return true; + } + QuicPacketWriter* writer = path_validator_.GetContext()->WriterToUse(); + + std::unique_ptr probing_packet = + packet_creator_.SerializePathResponseConnectivityProbingPacket( + {data_buffer}, /*is_padded=*/true); + QUICHE_DCHECK_EQ(IsRetransmittable(*probing_packet), NO_RETRANSMITTABLE_DATA); + QUIC_DVLOG(1) << ENDPOINT + << "Send PATH_RESPONSE from alternative socket with address " + << last_received_packet_info_.destination_address; + // Ignore the return value to treat write error on the alternative writer as + // part of network error. If the writer becomes blocked, wait for the peer to + // send another PATH_CHALLENGE. + WritePacketUsingWriter(std::move(probing_packet), writer, + last_received_packet_info_.destination_address, + peer_address_to_send, + /*measure_rtt=*/false); + return true; +} + +void QuicConnection::UpdatePeerAddress(QuicSocketAddress peer_address) { + direct_peer_address_ = peer_address; + packet_creator_.SetDefaultPeerAddress(peer_address); +} + +void QuicConnection::SendPingAtLevel(EncryptionLevel level) { + ScopedEncryptionLevelContext context(this, level); + SendControlFrame(QuicFrame(QuicPingFrame())); +} + +bool QuicConnection::HasPendingPathValidation() const { + return path_validator_.HasPendingPathValidation(); +} + +QuicPathValidationContext* QuicConnection::GetPathValidationContext() const { + return path_validator_.GetContext(); +} + +void QuicConnection::CancelPathValidation() { + path_validator_.CancelPathValidation(); +} + +bool QuicConnection::UpdateConnectionIdsOnMigration( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); + if (IsAlternativePath(self_address, peer_address)) { + // Client migration is after path validation. + default_path_.client_connection_id = alternative_path_.client_connection_id; + default_path_.server_connection_id = alternative_path_.server_connection_id; + default_path_.stateless_reset_token = + alternative_path_.stateless_reset_token; + return true; + } + // Client migration is without path validation. + if (self_issued_cid_manager_ != nullptr) { + self_issued_cid_manager_->MaybeSendNewConnectionIds(); + if (!connected_) { + return false; + } + } + if ((self_issued_cid_manager_ != nullptr && + !self_issued_cid_manager_->HasConnectionIdToConsume()) || + (peer_issued_cid_manager_ != nullptr && + !peer_issued_cid_manager_->HasUnusedConnectionId())) { + return false; + } + if (self_issued_cid_manager_ != nullptr) { + default_path_.client_connection_id = + *self_issued_cid_manager_->ConsumeOneConnectionId(); + } + if (peer_issued_cid_manager_ != nullptr) { + const auto* connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + default_path_.server_connection_id = connection_id_data->connection_id; + default_path_.stateless_reset_token = + connection_id_data->stateless_reset_token; + } + return true; +} + +void QuicConnection::RetirePeerIssuedConnectionIdsNoLongerOnPath() { + if (!connection_migration_use_new_cid_ || + peer_issued_cid_manager_ == nullptr) { + return; + } + if (perspective_ == Perspective::IS_CLIENT) { + peer_issued_cid_manager_->MaybeRetireUnusedConnectionIds( + {default_path_.server_connection_id, + alternative_path_.server_connection_id}); + } else { + peer_issued_cid_manager_->MaybeRetireUnusedConnectionIds( + {default_path_.client_connection_id, + alternative_path_.client_connection_id}); + } +} + +void QuicConnection::RetirePeerIssuedConnectionIdsOnPathValidationFailure() { + // The alarm to retire connection IDs no longer on paths is scheduled at the + // end of writing and reading packet. On path validation failure, there could + // be no packet to write or read. Hence the retirement alarm for the + // connection ID associated with the failed path needs to be proactively + // scheduled here. + if (GetQuicReloadableFlag( + quic_retire_cid_on_reverse_path_validation_failure) || + perspective_ == Perspective::IS_CLIENT) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_retire_cid_on_reverse_path_validation_failure); + RetirePeerIssuedConnectionIdsNoLongerOnPath(); + } +} + +bool QuicConnection::MigratePath(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicPacketWriter* writer, bool owns_writer) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); + if (!connected_) { + if (owns_writer) { + delete writer; + } + return false; + } + QUICHE_DCHECK(!version().UsesHttp3() || IsHandshakeConfirmed() || + accelerated_server_preferred_address_); + + if (connection_migration_use_new_cid_) { + if (!UpdateConnectionIdsOnMigration(self_address, peer_address)) { + if (owns_writer) { + delete writer; + } + return false; + } + if (packet_creator_.GetServerConnectionId().length() != + default_path_.server_connection_id.length()) { + packet_creator_.FlushCurrentPacket(); + } + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + packet_creator_.SetServerConnectionId(default_path_.server_connection_id); + } + + const auto self_address_change_type = QuicUtils::DetermineAddressChangeType( + default_path_.self_address, self_address); + const auto peer_address_change_type = QuicUtils::DetermineAddressChangeType( + default_path_.peer_address, peer_address); + QUICHE_DCHECK(self_address_change_type != NO_CHANGE || + peer_address_change_type != NO_CHANGE); + const bool is_port_change = (self_address_change_type == PORT_CHANGE || + self_address_change_type == NO_CHANGE) && + (peer_address_change_type == PORT_CHANGE || + peer_address_change_type == NO_CHANGE); + SetSelfAddress(self_address); + UpdatePeerAddress(peer_address); + default_path_.peer_address = peer_address; + if (writer_ != writer) { + SetQuicPacketWriter(writer, owns_writer); + } + MaybeClearQueuedPacketsOnPathChange(); + OnSuccessfulMigration(is_port_change); + return true; +} + +void QuicConnection::OnPathValidationFailureAtClient( + bool is_multi_port, const QuicPathValidationContext& context) { + if (connection_migration_use_new_cid_) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); + alternative_path_.Clear(); + } + + if (is_multi_port && multi_port_stats_ != nullptr) { + if (is_path_degrading_) { + multi_port_stats_->num_multi_port_probe_failures_when_path_degrading++; + } else { + multi_port_stats_ + ->num_multi_port_probe_failures_when_path_not_degrading++; + } + } + + if (context.peer_address() == received_server_preferred_address_ && + received_server_preferred_address_ != default_path_.peer_address) { + QUIC_DLOG(INFO) << "Failed to validate server preferred address : " + << received_server_preferred_address_; + mutable_stats().failed_to_validate_server_preferred_address = true; + } + + RetirePeerIssuedConnectionIdsOnPathValidationFailure(); +} + +QuicConnectionId QuicConnection::GetOneActiveServerConnectionId() const { + if (perspective_ == Perspective::IS_CLIENT || + self_issued_cid_manager_ == nullptr) { + return connection_id(); + } + auto active_connection_ids = GetActiveServerConnectionIds(); + QUIC_BUG_IF(quic_bug_6944, active_connection_ids.empty()); + if (active_connection_ids.empty() || + std::find(active_connection_ids.begin(), active_connection_ids.end(), + connection_id()) != active_connection_ids.end()) { + return connection_id(); + } + QUICHE_CODE_COUNT(connection_id_on_default_path_has_been_retired); + auto active_connection_id = + self_issued_cid_manager_->GetOneActiveConnectionId(); + return active_connection_id; +} + +std::vector QuicConnection::GetActiveServerConnectionIds() + const { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, perspective_); + std::vector result; + if (self_issued_cid_manager_ == nullptr) { + result.push_back(default_path_.server_connection_id); + } else { + QUICHE_DCHECK(version().HasIetfQuicFrames()); + result = self_issued_cid_manager_->GetUnretiredConnectionIds(); + } + if (!original_destination_connection_id_.has_value()) { + return result; + } + // Add the original connection ID + if (std::find(result.begin(), result.end(), + original_destination_connection_id_.value()) != result.end()) { + QUIC_BUG(quic_unexpected_original_destination_connection_id) + << "original_destination_connection_id: " + << original_destination_connection_id_.value() + << " is unexpectedly in active list"; + } else { + result.insert(result.end(), original_destination_connection_id_.value()); + } + return result; +} + +void QuicConnection::CreateConnectionIdManager() { + if (!version().HasIetfQuicFrames()) { + return; + } + + if (perspective_ == Perspective::IS_CLIENT) { + if (!default_path_.server_connection_id.IsEmpty()) { + peer_issued_cid_manager_ = + std::make_unique( + kMinNumOfActiveConnectionIds, default_path_.server_connection_id, + clock_, alarm_factory_, this, context()); + } + } else { + if (!default_path_.server_connection_id.IsEmpty()) { + self_issued_cid_manager_ = MakeSelfIssuedConnectionIdManager(); + } + } +} + +void QuicConnection::QuicBugIfHasPendingFrames(QuicStreamId id) const { + QUIC_BUG_IF(quic_has_pending_frames_unexpectedly, + connected_ && packet_creator_.HasPendingStreamFramesOfStream(id)) + << "Stream " << id + << " has pending frames unexpectedly. Received packet info: " + << last_received_packet_info_; +} + +void QuicConnection::SetUnackedMapInitialCapacity() { + sent_packet_manager_.ReserveUnackedPacketsInitialCapacity( + GetUnackedMapInitialCapacity()); +} + +void QuicConnection::SetSourceAddressTokenToSend(absl::string_view token) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + if (!packet_creator_.HasRetryToken()) { + // Ignore received tokens (via NEW_TOKEN frame) from previous connections + // when a RETRY token has been received. + packet_creator_.SetRetryToken(std::string(token.data(), token.length())); + } +} + +void QuicConnection::MaybeUpdateBytesSentToAlternativeAddress( + const QuicSocketAddress& peer_address, QuicByteCount sent_packet_size) { + if (!version().SupportsAntiAmplificationLimit() || + perspective_ != Perspective::IS_SERVER) { + return; + } + QUICHE_DCHECK(!IsDefaultPath(default_path_.self_address, peer_address)); + if (!IsAlternativePath(default_path_.self_address, peer_address)) { + QUIC_DLOG(INFO) << "Wrote to uninteresting peer address: " << peer_address + << " default direct_peer_address_ " << direct_peer_address_ + << " alternative path peer address " + << alternative_path_.peer_address; + return; + } + if (alternative_path_.validated) { + return; + } + if (alternative_path_.bytes_sent_before_address_validation >= + anti_amplification_factor_ * + alternative_path_.bytes_received_before_address_validation) { + QUIC_LOG_FIRST_N(WARNING, 100) + << "Server sent more data than allowed to unverified alternative " + "peer address " + << peer_address << " bytes sent " + << alternative_path_.bytes_sent_before_address_validation + << ", bytes received " + << alternative_path_.bytes_received_before_address_validation; + } + alternative_path_.bytes_sent_before_address_validation += sent_packet_size; +} + +void QuicConnection::MaybeUpdateBytesReceivedFromAlternativeAddress( + QuicByteCount received_packet_size) { + if (!version().SupportsAntiAmplificationLimit() || + perspective_ != Perspective::IS_SERVER || + !IsAlternativePath(last_received_packet_info_.destination_address, + GetEffectivePeerAddressFromCurrentPacket()) || + last_received_packet_info_.received_bytes_counted) { + return; + } + // Only update bytes received if this probing frame is received on the most + // recent alternative path. + QUICHE_DCHECK(!IsDefaultPath(last_received_packet_info_.destination_address, + GetEffectivePeerAddressFromCurrentPacket())); + if (!alternative_path_.validated) { + alternative_path_.bytes_received_before_address_validation += + received_packet_size; + } + last_received_packet_info_.received_bytes_counted = true; +} + +bool QuicConnection::IsDefaultPath( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) const { + return direct_peer_address_ == peer_address && + default_path_.self_address == self_address; +} + +bool QuicConnection::IsAlternativePath( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) const { + return alternative_path_.peer_address == peer_address && + alternative_path_.self_address == self_address; +} + +void QuicConnection::PathState::Clear() { + self_address = QuicSocketAddress(); + peer_address = QuicSocketAddress(); + client_connection_id = {}; + server_connection_id = {}; + validated = false; + bytes_received_before_address_validation = 0; + bytes_sent_before_address_validation = 0; + send_algorithm = nullptr; + rtt_stats = absl::nullopt; + stateless_reset_token.reset(); +} + +QuicConnection::PathState::PathState(PathState&& other) { + *this = std::move(other); +} + +QuicConnection::PathState& QuicConnection::PathState::operator=( + QuicConnection::PathState&& other) { + if (this != &other) { + self_address = other.self_address; + peer_address = other.peer_address; + client_connection_id = other.client_connection_id; + server_connection_id = other.server_connection_id; + stateless_reset_token = other.stateless_reset_token; + validated = other.validated; + bytes_received_before_address_validation = + other.bytes_received_before_address_validation; + bytes_sent_before_address_validation = + other.bytes_sent_before_address_validation; + send_algorithm = std::move(other.send_algorithm); + if (other.rtt_stats.has_value()) { + rtt_stats.emplace(); + rtt_stats->CloneFrom(other.rtt_stats.value()); + } else { + rtt_stats.reset(); + } + other.Clear(); + } + return *this; +} + +bool QuicConnection::IsReceivedPeerAddressValidated() const { + QuicSocketAddress current_effective_peer_address = + GetEffectivePeerAddressFromCurrentPacket(); + QUICHE_DCHECK(current_effective_peer_address.IsInitialized()); + return (alternative_path_.peer_address.host() == + current_effective_peer_address.host() && + alternative_path_.validated) || + (default_path_.validated && default_path_.peer_address.host() == + current_effective_peer_address.host()); +} + +void QuicConnection::OnMultiPortPathProbingSuccess( + std::unique_ptr context, QuicTime start_time) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective()); + alternative_path_.validated = true; + multi_port_path_context_ = std::move(context); + multi_port_probing_alarm_->Set(clock_->ApproximateNow() + + multi_port_probing_interval_); + if (multi_port_stats_ != nullptr) { + auto now = clock_->Now(); + auto time_delta = now - start_time; + multi_port_stats_->rtt_stats.UpdateRtt(time_delta, QuicTime::Delta::Zero(), + now); + if (is_path_degrading_) { + multi_port_stats_->rtt_stats_when_default_path_degrading.UpdateRtt( + time_delta, QuicTime::Delta::Zero(), now); + } + } +} + +void QuicConnection::MaybeProbeMultiPortPath() { + if (!connected_ || path_validator_.HasPendingPathValidation() || + !multi_port_path_context_ || + alternative_path_.self_address != + multi_port_path_context_->self_address() || + alternative_path_.peer_address != + multi_port_path_context_->peer_address() || + !visitor_->ShouldKeepConnectionAlive() || + multi_port_probing_alarm_->IsSet()) { + return; + } + auto multi_port_validation_result_delegate = + std::make_unique(this); + path_validator_.StartPathValidation( + std::move(multi_port_path_context_), + std::move(multi_port_validation_result_delegate), + PathValidationReason::kMultiPort); +} + +QuicConnection::MultiPortPathValidationResultDelegate:: + MultiPortPathValidationResultDelegate(QuicConnection* connection) + : connection_(connection) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, connection->perspective()); +} + +void QuicConnection::MultiPortPathValidationResultDelegate:: + OnPathValidationSuccess(std::unique_ptr context, + QuicTime start_time) { + connection_->OnMultiPortPathProbingSuccess(std::move(context), start_time); +} + +void QuicConnection::MultiPortPathValidationResultDelegate:: + OnPathValidationFailure( + std::unique_ptr context) { + connection_->OnPathValidationFailureAtClient(/*is_multi_port=*/true, + *context); +} + +QuicConnection::ReversePathValidationResultDelegate:: + ReversePathValidationResultDelegate( + QuicConnection* connection, + const QuicSocketAddress& direct_peer_address) + : QuicPathValidator::ResultDelegate(), + connection_(connection), + original_direct_peer_address_(direct_peer_address), + peer_address_default_path_(connection->direct_peer_address_), + peer_address_alternative_path_( + connection_->alternative_path_.peer_address), + active_effective_peer_migration_type_( + connection_->active_effective_peer_migration_type_) { + if (connection_->count_reverse_path_validation_stats()) { + QUIC_CODE_COUNT_N(quic_reverse_path_validation, 1, 4); + } +} + +void QuicConnection::ReversePathValidationResultDelegate:: + OnPathValidationSuccess(std::unique_ptr context, + QuicTime start_time) { + if (connection_->count_reverse_path_validation_stats()) { + QUIC_CODE_COUNT_N(quic_reverse_path_validation, 2, 4); + } + QUIC_DLOG(INFO) << "Successfully validated new path " << *context + << ", validation started at " << start_time; + if (connection_->IsDefaultPath(context->self_address(), + context->peer_address())) { + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 3, 6); + if (connection_->active_effective_peer_migration_type_ == NO_CHANGE) { + std::string error_detail = absl::StrCat( + "Reverse path validation on default path from ", + context->self_address().ToString(), " to ", + context->peer_address().ToString(), + " completed without active peer address change: current " + "peer address on default path ", + connection_->direct_peer_address_.ToString(), + ", peer address on default path when the reverse path " + "validation was kicked off ", + peer_address_default_path_.ToString(), + ", peer address on alternative path when the reverse " + "path validation was kicked off ", + peer_address_alternative_path_.ToString(), + ", with active_effective_peer_migration_type_ = ", + AddressChangeTypeToString(active_effective_peer_migration_type_), + ". The last received packet number ", + connection_->last_received_packet_info_.header.packet_number + .ToString(), + " Connection is connected: ", connection_->connected_); + QUIC_BUG(quic_bug_10511_43) << error_detail; + } + connection_->OnEffectivePeerMigrationValidated( + connection_->alternative_path_.server_connection_id == + connection_->default_path_.server_connection_id); + } else { + QUICHE_DCHECK(connection_->IsAlternativePath( + context->self_address(), context->effective_peer_address())); + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 4, 6); + QUIC_DVLOG(1) << "Mark alternative peer address " + << context->effective_peer_address() << " validated."; + connection_->alternative_path_.validated = true; + } +} + +void QuicConnection::ReversePathValidationResultDelegate:: + OnPathValidationFailure( + std::unique_ptr context) { + if (connection_->count_reverse_path_validation_stats()) { + QUIC_CODE_COUNT_N(quic_reverse_path_validation, 3, 4); + } + if (!connection_->connected()) { + return; + } + QUIC_DLOG(INFO) << "Fail to validate new path " << *context; + if (connection_->IsDefaultPath(context->self_address(), + context->peer_address())) { + // Only act upon validation failure on the default path. + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 5, 6); + connection_->RestoreToLastValidatedPath(original_direct_peer_address_); + } else if (connection_->IsAlternativePath( + context->self_address(), context->effective_peer_address())) { + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 6, 6); + connection_->alternative_path_.Clear(); + } + connection_->RetirePeerIssuedConnectionIdsOnPathValidationFailure(); +} + +QuicConnection::ScopedRetransmissionTimeoutIndicator:: + ScopedRetransmissionTimeoutIndicator(QuicConnection* connection) + : connection_(connection) { + QUICHE_DCHECK(!connection_->in_probe_time_out_) + << "ScopedRetransmissionTimeoutIndicator is not supposed to be nested"; + connection_->in_probe_time_out_ = true; +} + +QuicConnection::ScopedRetransmissionTimeoutIndicator:: + ~ScopedRetransmissionTimeoutIndicator() { + QUICHE_DCHECK(connection_->in_probe_time_out_); + connection_->in_probe_time_out_ = false; +} + +void QuicConnection::RestoreToLastValidatedPath( + QuicSocketAddress original_direct_peer_address) { + QUIC_DLOG(INFO) << "Switch back to use the old peer address " + << alternative_path_.peer_address; + if (!alternative_path_.validated) { + // If not validated by now, close connection silently so that the following + // packets received will be rejected. + CloseConnection(QUIC_INTERNAL_ERROR, + "No validated peer address to use after reverse path " + "validation failure.", + ConnectionCloseBehavior::SILENT_CLOSE); + return; + } + MaybeClearQueuedPacketsOnPathChange(); + + // Revert congestion control context to old state. + OnPeerIpAddressChanged(); + + if (alternative_path_.send_algorithm != nullptr) { + sent_packet_manager_.SetSendAlgorithm( + alternative_path_.send_algorithm.release()); + sent_packet_manager_.SetRttStats(alternative_path_.rtt_stats.value()); + } else { + QUIC_BUG(quic_bug_10511_42) + << "Fail to store congestion controller before migration."; + } + + UpdatePeerAddress(original_direct_peer_address); + SetDefaultPathState(std::move(alternative_path_)); + + active_effective_peer_migration_type_ = NO_CHANGE; + ++stats_.num_invalid_peer_migration; + // The reverse path validation failed because of alarm firing, flush all the + // pending writes previously throttled by anti-amplification limit. + WriteIfNotBlocked(); +} + +std::unique_ptr +QuicConnection::OnPeerIpAddressChanged() { + QUICHE_DCHECK(validate_client_addresses_); + std::unique_ptr old_send_algorithm = + sent_packet_manager_.OnConnectionMigration( + /*reset_send_algorithm=*/true); + // OnConnectionMigration() should have marked in-flight packets to be + // retransmitted if there is any. + QUICHE_DCHECK(!sent_packet_manager_.HasInFlightPackets()); + // OnConnectionMigration() may have changed the retransmission timer, so + // re-arm it. + SetRetransmissionAlarm(); + // Stop detections in quiecense. + blackhole_detector_.StopDetection(/*permanent=*/false); + return old_send_algorithm; +} + +void QuicConnection::set_keep_alive_ping_timeout( + QuicTime::Delta keep_alive_ping_timeout) { + ping_manager_.set_keep_alive_timeout(keep_alive_ping_timeout); +} + +void QuicConnection::set_initial_retransmittable_on_wire_timeout( + QuicTime::Delta retransmittable_on_wire_timeout) { + ping_manager_.set_initial_retransmittable_on_wire_timeout( + retransmittable_on_wire_timeout); +} + +bool QuicConnection::IsValidatingServerPreferredAddress() const { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + return received_server_preferred_address_.IsInitialized() && + received_server_preferred_address_ != default_path_.peer_address && + path_validator_.HasPendingPathValidation() && + path_validator_.GetContext()->peer_address() == + received_server_preferred_address_; +} + +void QuicConnection::OnServerPreferredAddressValidated( + QuicPathValidationContext& context, bool owns_writer) { + QUIC_DLOG(INFO) << "Server preferred address: " << context.peer_address() + << " validated. Migrating path, self_address: " + << context.self_address() + << ", peer_address: " << context.peer_address(); + mutable_stats().server_preferred_address_validated = true; + const bool success = + MigratePath(context.self_address(), context.peer_address(), + context.WriterToUse(), owns_writer); + QUIC_BUG_IF(failed to migrate to server preferred address, !success) + << "Failed to migrate to server preferred address: " + << context.peer_address() << " after successful validation"; +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_connection.h b/quiche/quic/core/quic_connection.h new file mode 100644 index 000000000000..55714ad5019a --- /dev/null +++ b/quiche/quic/core/quic_connection.h @@ -0,0 +1,2387 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// The entity that handles framing writes for a Quic client or server. +// Each QuicSession will have a connection associated with it. +// +// On the server side, the Dispatcher handles the raw reads, and hands off +// packets via ProcessUdpPacket for framing and processing. +// +// On the client side, the Connection handles the raw reads, as well as the +// processing. +// +// Note: this class is not thread-safe. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONNECTION_H_ +#define QUICHE_QUIC_CORE_QUIC_CONNECTION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/frames/quic_max_streams_frame.h" +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_blocked_writer_interface.h" +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_connection_id_manager.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_idle_network_detector.h" +#include "quiche/quic/core/quic_lru_cache.h" +#include "quiche/quic/core/quic_mtu_discovery.h" +#include "quiche/quic/core/quic_network_blackhole_detector.h" +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_ping_manager.h" +#include "quiche/quic/core/quic_sent_packet_manager.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/uber_received_packet_manager.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +class QuicClock; +class QuicConfig; +class QuicConnection; + +namespace test { +class QuicConnectionPeer; +} // namespace test + +// Class that receives callbacks from the connection when frames are received +// and when other interesting events happen. +class QUIC_EXPORT_PRIVATE QuicConnectionVisitorInterface { + public: + virtual ~QuicConnectionVisitorInterface() {} + + // A simple visitor interface for dealing with a data frame. + virtual void OnStreamFrame(const QuicStreamFrame& frame) = 0; + + // Called when a CRYPTO frame containing handshake data is received. + virtual void OnCryptoFrame(const QuicCryptoFrame& frame) = 0; + + // The session should process the WINDOW_UPDATE frame, adjusting both stream + // and connection level flow control windows. + virtual void OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) = 0; + + // A BLOCKED frame indicates the peer is flow control blocked + // on a specified stream. + virtual void OnBlockedFrame(const QuicBlockedFrame& frame) = 0; + + // Called when the stream is reset by the peer. + virtual void OnRstStream(const QuicRstStreamFrame& frame) = 0; + + // Called when the connection is going away according to the peer. + virtual void OnGoAway(const QuicGoAwayFrame& frame) = 0; + + // Called when |message| has been received. + virtual void OnMessageReceived(absl::string_view message) = 0; + + // Called when a HANDSHAKE_DONE frame has been received. + virtual void OnHandshakeDoneReceived() = 0; + + // Called when a NEW_TOKEN frame has been received. + virtual void OnNewTokenReceived(absl::string_view token) = 0; + + // Called when a MAX_STREAMS frame has been received from the peer. + virtual bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) = 0; + + // Called when a STREAMS_BLOCKED frame has been received from the peer. + virtual bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) = 0; + + // Called when the connection is closed either locally by the framer, or + // remotely by the peer. + virtual void OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) = 0; + + // Called when the connection failed to write because the socket was blocked. + virtual void OnWriteBlocked() = 0; + + // Called once a specific QUIC version is agreed by both endpoints. + virtual void OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& version) = 0; + + // Called when a packet has been received by the connection, after being + // validated and parsed. Only called when the client receives a valid packet + // or the server receives a connectivity probing packet. + // |is_connectivity_probe| is true if the received packet is a connectivity + // probe. + virtual void OnPacketReceived(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + bool is_connectivity_probe) = 0; + + // Called when a blocked socket becomes writable. + virtual void OnCanWrite() = 0; + + // Called when the connection experiences a change in congestion window. + virtual void OnCongestionWindowChange(QuicTime now) = 0; + + // Called when the connection receives a packet from a migrated client. + virtual void OnConnectionMigration(AddressChangeType type) = 0; + + // Called when the peer seems unreachable over the current path. + virtual void OnPathDegrading() = 0; + + // Called when forward progress made after path degrading. + virtual void OnForwardProgressMadeAfterPathDegrading() = 0; + + // Called when the connection sends ack after + // max_consecutive_num_packets_with_no_retransmittable_frames_ consecutive not + // retransmittable packets sent. To instigate an ack from peer, a + // retransmittable frame needs to be added. + virtual void OnAckNeedsRetransmittableFrame() = 0; + + // Called when an AckFrequency frame need to be sent. + virtual void SendAckFrequency(const QuicAckFrequencyFrame& frame) = 0; + + // Called to send a NEW_CONNECTION_ID frame. + virtual void SendNewConnectionId(const QuicNewConnectionIdFrame& frame) = 0; + + // Called to send a RETIRE_CONNECTION_ID frame. + virtual void SendRetireConnectionId(uint64_t sequence_number) = 0; + + // Called when server starts to use a server issued connection ID. Returns + // true if this connection ID hasn't been used by another connection. + virtual bool MaybeReserveConnectionId( + const QuicConnectionId& server_connection_id) = 0; + + // Called when server stops to use a server issued connection ID. + virtual void OnServerConnectionIdRetired( + const QuicConnectionId& server_connection_id) = 0; + + // Called to ask if the visitor wants to schedule write resumption as it both + // has pending data to write, and is able to write (e.g. based on flow control + // limits). + // Writes may be pending because they were write-blocked, congestion-throttled + // or yielded to other connections. + virtual bool WillingAndAbleToWrite() const = 0; + + // Called to ask if the connection should be kept alive and prevented + // from timing out, for example if there are outstanding application + // transactions expecting a response. + virtual bool ShouldKeepConnectionAlive() const = 0; + + // Called to retrieve streams information for logging purpose. + virtual std::string GetStreamsInfoForLogging() const = 0; + + // Called when a self address change is observed. Returns true if self address + // change is allowed. + virtual bool AllowSelfAddressChange() const = 0; + + // Called to get current handshake state. + virtual HandshakeState GetHandshakeState() const = 0; + + // Called when a STOP_SENDING frame has been received. + virtual void OnStopSendingFrame(const QuicStopSendingFrame& frame) = 0; + + // Called when a packet of encryption |level| has been successfully decrypted. + virtual void OnPacketDecrypted(EncryptionLevel level) = 0; + + // Called when a 1RTT packet has been acknowledged. + virtual void OnOneRttPacketAcknowledged() = 0; + + // Called when a packet of ENCRYPTION_HANDSHAKE gets sent. + virtual void OnHandshakePacketSent() = 0; + + // Called when a key update has occurred. + virtual void OnKeyUpdate(KeyUpdateReason reason) = 0; + + // Called to generate a decrypter for the next key phase. Each call should + // generate the key for phase n+1. + virtual std::unique_ptr + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called to generate an encrypter for the same key phase of the last + // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr CreateCurrentOneRttEncrypter() = 0; + + // Called when connection is being closed right before a CONNECTION_CLOSE + // frame is serialized, but only on the server and only if forward secure + // encryption has already been established. + virtual void BeforeConnectionCloseSent() = 0; + + // Called by the server to validate |token| in received INITIAL packets. + // Consider the client address gets validated (and therefore remove + // amplification factor) once the |token| gets successfully validated. + virtual bool ValidateToken(absl::string_view token) = 0; + + // Called by the server to send another token. + // Return false if the crypto stream fail to generate one. + virtual bool MaybeSendAddressToken() = 0; + + // When bandwidth update alarms. + virtual void OnBandwidthUpdateTimeout() = 0; + + // Returns context needed for the connection to probe on the alternative path. + virtual std::unique_ptr + CreateContextForMultiPortPath() = 0; + + // Migrate to the multi-port path which is identified by |context|. + virtual void MigrateToMultiPortPath( + std::unique_ptr context) = 0; + + // Called when the client receives a preferred address from its peer. + virtual void OnServerPreferredAddressAvailable( + const QuicSocketAddress& server_preferred_address) = 0; +}; + +// Interface which gets callbacks from the QuicConnection at interesting +// points. Implementations must not mutate the state of the connection +// as a result of these callbacks. +class QUIC_EXPORT_PRIVATE QuicConnectionDebugVisitor + : public QuicSentPacketManager::DebugDelegate { + public: + ~QuicConnectionDebugVisitor() override {} + + // Called when a packet has been sent. + virtual void OnPacketSent(QuicPacketNumber /*packet_number*/, + QuicPacketLength /*packet_length*/, + bool /*has_crypto_handshake*/, + TransmissionType /*transmission_type*/, + EncryptionLevel /*encryption_level*/, + const QuicFrames& /*retransmittable_frames*/, + const QuicFrames& /*nonretransmittable_frames*/, + QuicTime /*sent_time*/) {} + + // Called when a coalesced packet is successfully serialized. + virtual void OnCoalescedPacketSent( + const QuicCoalescedPacket& /*coalesced_packet*/, size_t /*length*/) {} + + // Called when a PING frame has been sent. + virtual void OnPingSent() {} + + // Called when a packet has been received, but before it is + // validated or parsed. + virtual void OnPacketReceived(const QuicSocketAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, + const QuicEncryptedPacket& /*packet*/) {} + + // Called when the unauthenticated portion of the header has been parsed. + virtual void OnUnauthenticatedHeader(const QuicPacketHeader& /*header*/) {} + + // Called when a packet is received with a connection id that does not + // match the ID of this connection. + virtual void OnIncorrectConnectionId(QuicConnectionId /*connection_id*/) {} + + // Called when an undecryptable packet has been received. If |dropped| is + // true, the packet has been dropped. Otherwise, the packet will be queued and + // connection will attempt to process it later. + virtual void OnUndecryptablePacket(EncryptionLevel /*decryption_level*/, + bool /*dropped*/) {} + + // Called when attempting to process a previously undecryptable packet. + virtual void OnAttemptingToProcessUndecryptablePacket( + EncryptionLevel /*decryption_level*/) {} + + // Called when a duplicate packet has been received. + virtual void OnDuplicatePacket(QuicPacketNumber /*packet_number*/) {} + + // Called when the protocol version on the received packet doensn't match + // current protocol version of the connection. + virtual void OnProtocolVersionMismatch(ParsedQuicVersion /*version*/) {} + + // Called when the complete header of a packet has been parsed. + virtual void OnPacketHeader(const QuicPacketHeader& /*header*/, + QuicTime /*receive_time*/, + EncryptionLevel /*level*/) {} + + // Called when a StreamFrame has been parsed. + virtual void OnStreamFrame(const QuicStreamFrame& /*frame*/) {} + + // Called when a CRYPTO frame containing handshake data is received. + virtual void OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {} + + // Called when a StopWaitingFrame has been parsed. + virtual void OnStopWaitingFrame(const QuicStopWaitingFrame& /*frame*/) {} + + // Called when a QuicPaddingFrame has been parsed. + virtual void OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {} + + // Called when a Ping has been parsed. + virtual void OnPingFrame(const QuicPingFrame& /*frame*/, + QuicTime::Delta /*ping_received_delay*/) {} + + // Called when a GoAway has been parsed. + virtual void OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {} + + // Called when a RstStreamFrame has been parsed. + virtual void OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {} + + // Called when a ConnectionCloseFrame has been parsed. All forms + // of CONNECTION CLOSE are handled, Google QUIC, IETF QUIC + // CONNECTION CLOSE/Transport and IETF QUIC CONNECTION CLOSE/Application + virtual void OnConnectionCloseFrame( + const QuicConnectionCloseFrame& /*frame*/) {} + + // Called when a WindowUpdate has been parsed. + virtual void OnWindowUpdateFrame(const QuicWindowUpdateFrame& /*frame*/, + const QuicTime& /*receive_time*/) {} + + // Called when a BlockedFrame has been parsed. + virtual void OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {} + + // Called when a NewConnectionIdFrame has been parsed. + virtual void OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& /*frame*/) {} + + // Called when a RetireConnectionIdFrame has been parsed. + virtual void OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& /*frame*/) {} + + // Called when a NewTokenFrame has been parsed. + virtual void OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {} + + // Called when a MessageFrame has been parsed. + virtual void OnMessageFrame(const QuicMessageFrame& /*frame*/) {} + + // Called when a HandshakeDoneFrame has been parsed. + virtual void OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& /*frame*/) {} + + // Called when a public reset packet has been received. + virtual void OnPublicResetPacket(const QuicPublicResetPacket& /*packet*/) {} + + // Called when a version negotiation packet has been received. + virtual void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& /*packet*/) {} + + // Called when the connection is closed. + virtual void OnConnectionClosed(const QuicConnectionCloseFrame& /*frame*/, + ConnectionCloseSource /*source*/) {} + + // Called when the version negotiation is successful. + virtual void OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& /*version*/) {} + + // Called when a CachedNetworkParameters is sent to the client. + virtual void OnSendConnectionState( + const CachedNetworkParameters& /*cached_network_params*/) {} + + // Called when a CachedNetworkParameters are received from the client. + virtual void OnReceiveConnectionState( + const CachedNetworkParameters& /*cached_network_params*/) {} + + // Called when the connection parameters are set from the supplied + // |config|. + virtual void OnSetFromConfig(const QuicConfig& /*config*/) {} + + // Called when RTT may have changed, including when an RTT is read from + // the config. + virtual void OnRttChanged(QuicTime::Delta /*rtt*/) const {} + + // Called when a StopSendingFrame has been parsed. + virtual void OnStopSendingFrame(const QuicStopSendingFrame& /*frame*/) {} + + // Called when a PathChallengeFrame has been parsed. + virtual void OnPathChallengeFrame(const QuicPathChallengeFrame& /*frame*/) {} + + // Called when a PathResponseFrame has been parsed. + virtual void OnPathResponseFrame(const QuicPathResponseFrame& /*frame*/) {} + + // Called when a StreamsBlockedFrame has been parsed. + virtual void OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& /*frame*/) { + } + + // Called when a MaxStreamsFrame has been parsed. + virtual void OnMaxStreamsFrame(const QuicMaxStreamsFrame& /*frame*/) {} + + // Called when an AckFrequencyFrame has been parsed. + virtual void OnAckFrequencyFrame(const QuicAckFrequencyFrame& /*frame*/) {} + + // Called when |count| packet numbers have been skipped. + virtual void OnNPacketNumbersSkipped(QuicPacketCount /*count*/, + QuicTime /*now*/) {} + + // Called when a packet is serialized but discarded (i.e. not sent). + virtual void OnPacketDiscarded(const SerializedPacket& /*packet*/) {} + + // Called for QUIC+TLS versions when we send transport parameters. + virtual void OnTransportParametersSent( + const TransportParameters& /*transport_parameters*/) {} + + // Called for QUIC+TLS versions when we receive transport parameters. + virtual void OnTransportParametersReceived( + const TransportParameters& /*transport_parameters*/) {} + + // Called for QUIC+TLS versions when we resume cached transport parameters for + // 0-RTT. + virtual void OnTransportParametersResumed( + const TransportParameters& /*transport_parameters*/) {} + + // Called for QUIC+TLS versions when 0-RTT is rejected. + virtual void OnZeroRttRejected(int /*reject_reason*/) {} + + // Called for QUIC+TLS versions when 0-RTT packet gets acked. + virtual void OnZeroRttPacketAcked() {} + + // Called on peer address change. + virtual void OnPeerAddressChange(AddressChangeType /*type*/, + QuicTime::Delta /*connection_time*/) {} + + // Called after peer migration is validated. + virtual void OnPeerMigrationValidated(QuicTime::Delta /*connection_time*/) {} +}; + +class QUIC_EXPORT_PRIVATE QuicConnectionHelperInterface { + public: + virtual ~QuicConnectionHelperInterface() {} + + // Returns a QuicClock to be used for all time related functions. + virtual const QuicClock* GetClock() const = 0; + + // Returns a QuicRandom to be used for all random number related functions. + virtual QuicRandom* GetRandomGenerator() = 0; + + // Returns a QuicheBufferAllocator to be used for stream send buffers. + virtual quiche::QuicheBufferAllocator* GetStreamSendBufferAllocator() = 0; +}; + +class QUIC_EXPORT_PRIVATE QuicConnection + : public QuicFramerVisitorInterface, + public QuicBlockedWriterInterface, + public QuicPacketCreator::DelegateInterface, + public QuicSentPacketManager::NetworkChangeVisitor, + public QuicNetworkBlackholeDetector::Delegate, + public QuicIdleNetworkDetector::Delegate, + public QuicPathValidator::SendDelegate, + public QuicConnectionIdManagerVisitorInterface, + public QuicPingManager::Delegate { + public: + // Constructs a new QuicConnection for |connection_id| and + // |initial_peer_address| using |writer| to write packets. |owns_writer| + // specifies whether the connection takes ownership of |writer|. |helper| must + // outlive this connection. + QuicConnection(QuicConnectionId server_connection_id, + QuicSocketAddress initial_self_address, + QuicSocketAddress initial_peer_address, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, QuicPacketWriter* writer, + bool owns_writer, Perspective perspective, + const ParsedQuicVersionVector& supported_versions, + ConnectionIdGeneratorInterface& generator); + QuicConnection(const QuicConnection&) = delete; + QuicConnection& operator=(const QuicConnection&) = delete; + ~QuicConnection() override; + + struct MultiPortStats { + // general rtt stats of the multi-port path. + RttStats rtt_stats; + // rtt stats for the multi-port path when the default path is degrading. + RttStats rtt_stats_when_default_path_degrading; + // number of path degrading triggered when multi-port is enabled. + size_t num_path_degrading = 0; + // number of multi-port probe failures when path is not degrading + size_t num_multi_port_probe_failures_when_path_not_degrading = 0; + // number of multi-port probe failure when path is degrading + size_t num_multi_port_probe_failures_when_path_degrading = 0; + // number of total multi-port path creations in a connection + size_t num_multi_port_paths_created = 0; + }; + + // Sets connection parameters from the supplied |config|. + void SetFromConfig(const QuicConfig& config); + + // Apply |connection_options| for this connection. Unlike SetFromConfig, this + // can happen at anytime in the life of a connection. + // Note there is no guarantee that all options can be applied. Components will + // only apply cherrypicked options that make sense at the time of the call. + void ApplyConnectionOptions(const QuicTagVector& connection_options); + + // Called by the session when sending connection state to the client. + virtual void OnSendConnectionState( + const CachedNetworkParameters& cached_network_params); + + // Called by the session when receiving connection state from the client. + virtual void OnReceiveConnectionState( + const CachedNetworkParameters& cached_network_params); + + // Called by the Session when the client has provided CachedNetworkParameters. + virtual void ResumeConnectionState( + const CachedNetworkParameters& cached_network_params, + bool max_bandwidth_resumption); + + // Called by the Session when a max pacing rate for the connection is needed. + virtual void SetMaxPacingRate(QuicBandwidth max_pacing_rate); + + // Allows the client to adjust network parameters based on external + // information. + void AdjustNetworkParameters( + const SendAlgorithmInterface::NetworkParams& params); + void AdjustNetworkParameters(QuicBandwidth bandwidth, QuicTime::Delta rtt, + bool allow_cwnd_to_decrease); + + // Install a loss detection tuner. Must be called before OnConfigNegotiated. + void SetLossDetectionTuner( + std::unique_ptr tuner); + // Called by the session when session->is_configured() becomes true. + void OnConfigNegotiated(); + + // Returns the max pacing rate for the connection. + virtual QuicBandwidth MaxPacingRate() const; + + // Sends crypto handshake messages of length |write_length| to the peer in as + // few packets as possible. Returns the number of bytes consumed from the + // data. + virtual size_t SendCryptoData(EncryptionLevel level, size_t write_length, + QuicStreamOffset offset); + + // Send the data of length |write_length| to the peer in as few packets as + // possible. Returns the number of bytes consumed from data, and a boolean + // indicating if the fin bit was consumed. This does not indicate the data + // has been sent on the wire: it may have been turned into a packet and queued + // if the socket was unexpectedly blocked. + virtual QuicConsumedData SendStreamData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state); + + // Send |frame| to the peer. Returns true if frame is consumed, false + // otherwise. + virtual bool SendControlFrame(const QuicFrame& frame); + + // Called when stream |id| is reset because of |error|. + virtual void OnStreamReset(QuicStreamId id, QuicRstStreamErrorCode error); + + // Closes the connection. + // |connection_close_behavior| determines whether or not a connection close + // packet is sent to the peer. + virtual void CloseConnection( + QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior); + // Closes the connection, specifying the wire error code |ietf_error| + // explicitly. + virtual void CloseConnection( + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details, + ConnectionCloseBehavior connection_close_behavior); + + QuicConnectionStats& mutable_stats() { return stats_; } + + // Returns statistics tracked for this connection. + const QuicConnectionStats& GetStats(); + + // Processes an incoming UDP packet (consisting of a QuicEncryptedPacket) from + // the peer. + // In a client, the packet may be "stray" and have a different connection ID + // than that of this connection. + virtual void ProcessUdpPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet); + + // QuicBlockedWriterInterface + // Called when the underlying connection becomes writable to allow queued + // writes to happen. + void OnBlockedWriterCanWrite() override; + + bool IsWriterBlocked() const override { + return writer_ != nullptr && writer_->IsWriteBlocked(); + } + + // Called when the caller thinks it's worth a try to write. + // TODO(fayang): consider unifying this with QuicSession::OnCanWrite. + virtual void OnCanWrite(); + + // Called when an error occurs while attempting to write a packet to the + // network. + void OnWriteError(int error_code); + + // Whether |result| represents a MSG TOO BIG write error. + bool IsMsgTooBig(const QuicPacketWriter* writer, const WriteResult& result); + + // If the socket is not blocked, writes queued packets. + void WriteIfNotBlocked(); + + // Set the packet writer. + void SetQuicPacketWriter(QuicPacketWriter* writer, bool owns_writer) { + QUICHE_DCHECK(writer != nullptr); + if (writer_ != nullptr && owns_writer_) { + delete writer_; + } + writer_ = writer; + owns_writer_ = owns_writer; + } + + // Set self address. + void SetSelfAddress(QuicSocketAddress address) { + default_path_.self_address = address; + } + + // The version of the protocol this connection is using. + QuicTransportVersion transport_version() const { + return framer_.transport_version(); + } + + ParsedQuicVersion version() const { return framer_.version(); } + + // The versions of the protocol that this connection supports. + const ParsedQuicVersionVector& supported_versions() const { + return framer_.supported_versions(); + } + + // Mark version negotiated for this connection. Once called, the connection + // will ignore received version negotiation packets. + void SetVersionNegotiated() { + version_negotiated_ = true; + if (perspective_ == Perspective::IS_SERVER) { + framer_.InferPacketHeaderTypeFromVersion(); + } + } + + // From QuicFramerVisitorInterface + void OnError(QuicFramer* framer) override; + bool OnProtocolVersionMismatch(ParsedQuicVersion received_version) override; + void OnPacket() override; + void OnPublicResetPacket(const QuicPublicResetPacket& packet) override; + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& packet) override; + void OnRetryPacket(QuicConnectionId original_connection_id, + QuicConnectionId new_connection_id, + absl::string_view retry_token, + absl::string_view retry_integrity_tag, + absl::string_view retry_without_tag) override; + bool OnUnauthenticatedPublicHeader(const QuicPacketHeader& header) override; + bool OnUnauthenticatedHeader(const QuicPacketHeader& header) override; + void OnDecryptedPacket(size_t length, EncryptionLevel level) override; + bool OnPacketHeader(const QuicPacketHeader& header) override; + void OnCoalescedPacket(const QuicEncryptedPacket& packet) override; + void OnUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, + bool has_decryption_key) override; + bool OnStreamFrame(const QuicStreamFrame& frame) override; + bool OnCryptoFrame(const QuicCryptoFrame& frame) override; + bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) override; + bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) override; + bool OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) override; + bool OnAckFrameEnd(QuicPacketNumber start, + const absl::optional& ecn_counts) override; + bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) override; + bool OnPaddingFrame(const QuicPaddingFrame& frame) override; + bool OnPingFrame(const QuicPingFrame& frame) override; + bool OnRstStreamFrame(const QuicRstStreamFrame& frame) override; + bool OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override; + bool OnStopSendingFrame(const QuicStopSendingFrame& frame) override; + bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) override; + bool OnPathResponseFrame(const QuicPathResponseFrame& frame) override; + bool OnGoAwayFrame(const QuicGoAwayFrame& frame) override; + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override; + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override; + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override; + bool OnBlockedFrame(const QuicBlockedFrame& frame) override; + bool OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame) override; + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) override; + bool OnNewTokenFrame(const QuicNewTokenFrame& frame) override; + bool OnMessageFrame(const QuicMessageFrame& frame) override; + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) override; + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) override; + void OnPacketComplete() override; + bool IsValidStatelessResetToken( + const StatelessResetToken& token) const override; + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& packet) override; + void OnKeyUpdate(KeyUpdateReason reason) override; + void OnDecryptedFirstPacketInKeyPhase() override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + + // QuicPacketCreator::DelegateInterface + bool ShouldGeneratePacket(HasRetransmittableData retransmittable, + IsHandshake handshake) override; + const QuicFrames MaybeBundleAckOpportunistically() override; + QuicPacketBuffer GetPacketBuffer() override; + void OnSerializedPacket(SerializedPacket packet) override; + void OnUnrecoverableError(QuicErrorCode error, + const std::string& error_details) override; + SerializedPacketFate GetSerializedPacketFate( + bool is_mtu_discovery, EncryptionLevel encryption_level) override; + + // QuicSentPacketManager::NetworkChangeVisitor + void OnCongestionChange() override; + void OnPathMtuIncreased(QuicPacketLength packet_size) override; + + // QuicNetworkBlackholeDetector::Delegate + void OnPathDegradingDetected() override; + void OnBlackholeDetected() override; + void OnPathMtuReductionDetected() override; + + // QuicIdleNetworkDetector::Delegate + void OnHandshakeTimeout() override; + void OnIdleNetworkDetected() override; + void OnBandwidthUpdateTimeout() override; + + // QuicPingManager::Delegate + void OnKeepAliveTimeout() override; + void OnRetransmittableOnWireTimeout() override; + + // QuicConnectionIdManagerVisitorInterface + void OnPeerIssuedConnectionIdRetired() override; + bool SendNewConnectionId(const QuicNewConnectionIdFrame& frame) override; + bool MaybeReserveConnectionId(const QuicConnectionId& connection_id) override; + void OnSelfIssuedConnectionIdRetired( + const QuicConnectionId& connection_id) override; + + // Please note, this is not a const function. For logging purpose, please use + // ack_frame(). + const QuicFrame GetUpdatedAckFrame(); + + // Called to send a new connection ID to client if the # of connection ID has + // not exceeded the active connection ID limits. + void MaybeSendConnectionIdToClient(); + + // Called when the handshake completes. On the client side, handshake + // completes on receipt of SHLO. On the server side, handshake completes when + // SHLO gets ACKed (or a forward secure packet gets decrypted successfully). + // TODO(fayang): Add a guard that this only gets called once. + void OnHandshakeComplete(); + + // Creates and probes an multi-port path if none exists. + void MaybeCreateMultiPortPath(); + + // Called in multi-port QUIC when the alternative path validation succeeds. + // Stores the path validation context and prepares for the next validation. + void OnMultiPortPathProbingSuccess( + std::unique_ptr context, QuicTime start_time); + + // Probe the existing alternative path. Does not create a new alternative + // path. This method is the callback for |multi_port_probing_alarm_|. + virtual void MaybeProbeMultiPortPath(); + + // Accessors + void set_visitor(QuicConnectionVisitorInterface* visitor) { + visitor_ = visitor; + } + void set_debug_visitor(QuicConnectionDebugVisitor* debug_visitor) { + debug_visitor_ = debug_visitor; + sent_packet_manager_.SetDebugDelegate(debug_visitor); + } + // Used in Chromium, but not internally. + // Must only be called before ping_alarm_ is set. + void set_keep_alive_ping_timeout(QuicTime::Delta keep_alive_ping_timeout); + // Sets an initial timeout for the ping alarm when there is no retransmittable + // data in flight, allowing for a more aggressive ping alarm in that case. + void set_initial_retransmittable_on_wire_timeout( + QuicTime::Delta retransmittable_on_wire_timeout); + // Used in Chromium, but not internally. + void set_creator_debug_delegate(QuicPacketCreator::DebugDelegate* visitor) { + packet_creator_.set_debug_delegate(visitor); + } + const QuicSocketAddress& self_address() const { + return default_path_.self_address; + } + const QuicSocketAddress& peer_address() const { return direct_peer_address_; } + const QuicSocketAddress& effective_peer_address() const { + return default_path_.peer_address; + } + + // Returns the server connection ID used on the default path. + const QuicConnectionId& connection_id() const { + return default_path_.server_connection_id; + } + + const QuicConnectionId& client_connection_id() const { + return default_path_.client_connection_id; + } + void set_client_connection_id(QuicConnectionId client_connection_id); + const QuicClock* clock() const { return clock_; } + QuicRandom* random_generator() const { return random_generator_; } + QuicByteCount max_packet_length() const; + void SetMaxPacketLength(QuicByteCount length); + + size_t mtu_probe_count() const { return mtu_probe_count_; } + + bool connected() const { return connected_; } + + // Must only be called on client connections. + const ParsedQuicVersionVector& server_supported_versions() const { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + return server_supported_versions_; + } + + bool HasQueuedPackets() const { return !buffered_packets_.empty(); } + // Testing only. TODO(ianswett): Use a peer instead. + size_t NumQueuedPackets() const { return buffered_packets_.size(); } + + // Returns true if the connection has queued packets or frames. + bool HasQueuedData() const; + + // Sets the handshake and idle state connection timeouts. + void SetNetworkTimeouts(QuicTime::Delta handshake_timeout, + QuicTime::Delta idle_timeout); + + void SetMultiPortProbingInterval(QuicTime::Delta probing_interval) { + multi_port_probing_interval_ = probing_interval; + } + + const MultiPortStats* multi_port_stats() const { + return multi_port_stats_.get(); + } + + // Sets up a packet with an QuicAckFrame and sends it out. + void SendAck(); + + // Called when an RTO fires. Resets the retransmission alarm if there are + // remaining unacked packets. + void OnRetransmissionTimeout(); + + // Mark all sent 0-RTT encrypted packets for retransmission. Called when new + // 0-RTT or 1-RTT key is available in gQUIC, or when 0-RTT is rejected in IETF + // QUIC. |reject_reason| is used in TLS-QUIC to log why 0-RTT was rejected. + void MarkZeroRttPacketsForRetransmission(int reject_reason); + + // Calls |sent_packet_manager_|'s NeuterUnencryptedPackets. Used when the + // connection becomes forward secure and hasn't received acks for all packets. + void NeuterUnencryptedPackets(); + + // Changes the encrypter used for level |level| to |encrypter|. + void SetEncrypter(EncryptionLevel level, + std::unique_ptr encrypter); + + // Called to remove encrypter of encryption |level|. + void RemoveEncrypter(EncryptionLevel level); + + // SetNonceForPublicHeader sets the nonce that will be transmitted in the + // header of each packet encrypted at the initial encryption level decrypted. + // This should only be called on the server side. + void SetDiversificationNonce(const DiversificationNonce& nonce); + + // SetDefaultEncryptionLevel sets the encryption level that will be applied + // to new packets. + void SetDefaultEncryptionLevel(EncryptionLevel level); + + // SetDecrypter sets the primary decrypter, replacing any that already exists. + // If an alternative decrypter is in place then the function QUICHE_DCHECKs. + // This is intended for cases where one knows that future packets will be + // using the new decrypter and the previous decrypter is now obsolete. |level| + // indicates the encryption level of the new decrypter. + void SetDecrypter(EncryptionLevel level, + std::unique_ptr decrypter); + + // SetAlternativeDecrypter sets a decrypter that may be used to decrypt + // future packets. |level| indicates the encryption level of the decrypter. If + // |latch_once_used| is true, then the first time that the decrypter is + // successful it will replace the primary decrypter. Otherwise both + // decrypters will remain active and the primary decrypter will be the one + // last used. + void SetAlternativeDecrypter(EncryptionLevel level, + std::unique_ptr decrypter, + bool latch_once_used); + + void InstallDecrypter(EncryptionLevel level, + std::unique_ptr decrypter); + void RemoveDecrypter(EncryptionLevel level); + + // Discard keys for the previous key phase. + void DiscardPreviousOneRttKeys(); + + // Returns true if it is currently allowed to initiate a key update. + bool IsKeyUpdateAllowed() const; + + // Returns true if packets have been sent in the current 1-RTT key phase but + // none of these packets have been acked. + bool HaveSentPacketsInCurrentKeyPhaseButNoneAcked() const; + + // Returns the count of packets received that appeared to attempt a key + // update but failed decryption that have been received since the last + // successfully decrypted packet. + QuicPacketCount PotentialPeerKeyUpdateAttemptCount() const; + + // Increment the key phase. It is a bug to call this when IsKeyUpdateAllowed() + // is false. Returns false on error. + bool InitiateKeyUpdate(KeyUpdateReason reason); + + const QuicDecrypter* decrypter() const; + const QuicDecrypter* alternative_decrypter() const; + + Perspective perspective() const { return perspective_; } + + // Allow easy overriding of truncated connection IDs. + void set_can_truncate_connection_ids(bool can) { + can_truncate_connection_ids_ = can; + } + + // Returns the underlying sent packet manager. + const QuicSentPacketManager& sent_packet_manager() const { + return sent_packet_manager_; + } + + // Returns the underlying sent packet manager. + QuicSentPacketManager& sent_packet_manager() { return sent_packet_manager_; } + + UberReceivedPacketManager& received_packet_manager() { + return uber_received_packet_manager_; + } + + bool CanWrite(HasRetransmittableData retransmittable); + + // When the flusher is out of scope, only the outermost flusher will cause a + // flush of the connection and set the retransmission alarm if there is one + // pending. In addition, this flusher can be configured to ensure that an ACK + // frame is included in the first packet created, if there's new ack + // information to be sent. + class QUIC_EXPORT_PRIVATE ScopedPacketFlusher { + public: + explicit ScopedPacketFlusher(QuicConnection* connection); + ~ScopedPacketFlusher(); + + private: + QuicConnection* connection_; + // If true, when this flusher goes out of scope, flush connection and set + // retransmission alarm if there is one pending. + bool flush_and_set_pending_retransmission_alarm_on_delete_; + // Latched connection's handshake_packet_sent_ on creation of this flusher. + const bool handshake_packet_sent_; + }; + + class QUIC_EXPORT_PRIVATE ScopedEncryptionLevelContext { + public: + ScopedEncryptionLevelContext(QuicConnection* connection, + EncryptionLevel level); + ~ScopedEncryptionLevelContext(); + + private: + QuicConnection* connection_; + // Latched current write encryption level on creation of this context. + EncryptionLevel latched_encryption_level_; + }; + + QuicPacketWriter* writer() { return writer_; } + const QuicPacketWriter* writer() const { return writer_; } + + // Sends an MTU discovery packet of size |target_mtu|. If the packet is + // acknowledged by the peer, the maximum packet size will be increased to + // |target_mtu|. + void SendMtuDiscoveryPacket(QuicByteCount target_mtu); + + // Sends a connectivity probing packet to |peer_address| with + // |probing_writer|. If |probing_writer| is nullptr, will use default + // packet writer to write the packet. Returns true if subsequent packets can + // be written to the probing writer. If connection is V99, a padded IETF QUIC + // PATH_CHALLENGE packet is transmitted; if not V99, a Google QUIC padded PING + // packet is transmitted. + virtual bool SendConnectivityProbingPacket( + QuicPacketWriter* probing_writer, const QuicSocketAddress& peer_address); + + // Disable MTU discovery on this connection. + void DisableMtuDiscovery(); + + // Sends an MTU discovery packet and updates the MTU discovery alarm. + void DiscoverMtu(); + + // Sets the session notifier on the SentPacketManager. + void SetSessionNotifier(SessionNotifierInterface* session_notifier); + + // Set data producer in framer. + void SetDataProducer(QuicStreamFrameDataProducer* data_producer); + + // Set transmission type of next sending packets. + void SetTransmissionType(TransmissionType type); + + // Tries to send |message| and returns the message status. + // If |flush| is false, this will return a MESSAGE_STATUS_BLOCKED + // when the connection is deemed unwritable. + virtual MessageStatus SendMessage(QuicMessageId message_id, + absl::Span message, + bool flush); + + // Returns the largest payload that will fit into a single MESSAGE frame. + // Because overhead can vary during a connection, this method should be + // checked for every message. + QuicPacketLength GetCurrentLargestMessagePayload() const; + // Returns the largest payload that will fit into a single MESSAGE frame at + // any point during the connection. This assumes the version and + // connection ID lengths do not change. + QuicPacketLength GetGuaranteedLargestMessagePayload() const; + + void SetUnackedMapInitialCapacity(); + + virtual int GetUnackedMapInitialCapacity() const { + return kDefaultUnackedPacketsInitialCapacity; + } + + // Returns the id of the cipher last used for decrypting packets. + uint32_t cipher_id() const; + + std::vector>* termination_packets() { + return termination_packets_.get(); + } + + bool ack_frame_updated() const; + + QuicConnectionHelperInterface* helper() { return helper_; } + const QuicConnectionHelperInterface* helper() const { return helper_; } + QuicAlarmFactory* alarm_factory() { return alarm_factory_; } + + absl::string_view GetCurrentPacket(); + + const QuicFramer& framer() const { return framer_; } + + const QuicPacketCreator& packet_creator() const { return packet_creator_; } + + EncryptionLevel encryption_level() const { return encryption_level_; } + EncryptionLevel last_decrypted_level() const { + return last_received_packet_info_.decrypted_level; + } + + const QuicSocketAddress& last_packet_source_address() const { + return last_received_packet_info_.source_address; + } + + // This setting may be changed during the crypto handshake in order to + // enable/disable padding of different packets in the crypto handshake. + // + // This setting should never be set to false in public facing endpoints. It + // can only be set to false if there is some other mechanism of preventing + // amplification attacks, such as ICE (plus its a non-standard quic). + void set_fully_pad_crypto_handshake_packets(bool new_value) { + packet_creator_.set_fully_pad_crypto_handshake_packets(new_value); + } + + bool fully_pad_during_crypto_handshake() const { + return packet_creator_.fully_pad_crypto_handshake_packets(); + } + + size_t min_received_before_ack_decimation() const; + void set_min_received_before_ack_decimation(size_t new_value); + + // If |defer| is true, configures the connection to defer sending packets in + // response to an ACK to the SendAlarm. If |defer| is false, packets may be + // sent immediately after receiving an ACK. + void set_defer_send_in_response_to_packets(bool defer) { + defer_send_in_response_to_packets_ = defer; + } + + // Sets the current per-packet options for the connection. The QuicConnection + // does not take ownership of |options|; |options| must live for as long as + // the QuicConnection is in use. + void set_per_packet_options(PerPacketOptions* options) { + per_packet_options_ = options; + } + + bool IsPathDegrading() const { return is_path_degrading_; } + + // Attempts to process any queued undecryptable packets. + void MaybeProcessUndecryptablePackets(); + + // Queue a coalesced packet. + void QueueCoalescedPacket(const QuicEncryptedPacket& packet); + + // Process previously queued coalesced packets. Returns true if any coalesced + // packets have been successfully processed. + bool MaybeProcessCoalescedPackets(); + + enum PacketContent : uint8_t { + NO_FRAMES_RECEIVED, + // TODO(fkastenholz): Change name when we get rid of padded ping/ + // pre-version-99. + // Also PATH CHALLENGE and PATH RESPONSE. + FIRST_FRAME_IS_PING, + SECOND_FRAME_IS_PADDING, + NOT_PADDED_PING, // Set if the packet is not {PING, PADDING}. + }; + + // Whether the handshake completes from this connection's perspective. + bool IsHandshakeComplete() const; + + // Whether peer completes handshake. Only used with TLS handshake. + bool IsHandshakeConfirmed() const; + + // Returns the largest received packet number sent by peer. + QuicPacketNumber GetLargestReceivedPacket() const; + + // Sets the original destination connection ID on the connection. + // This is called by QuicDispatcher when it has replaced the connection ID. + void SetOriginalDestinationConnectionId( + const QuicConnectionId& original_destination_connection_id); + + // Returns the original destination connection ID used for this connection. + QuicConnectionId GetOriginalDestinationConnectionId() const; + + // Tells the visitor the serverside connection is no longer expecting packets + // with the client-generated destination connection ID. + void RetireOriginalDestinationConnectionId(); + + // Called when ACK alarm goes off. Sends ACKs of those packet number spaces + // which have expired ACK timeout. Only used when this connection supports + // multiple packet number spaces. + void SendAllPendingAcks(); + + // Returns true if this connection supports multiple packet number spaces. + bool SupportsMultiplePacketNumberSpaces() const; + + // For logging purpose. + const QuicAckFrame& ack_frame() const; + + // Install encrypter and decrypter for ENCRYPTION_INITIAL using + // |connection_id| as the first client-sent destination connection ID, + // or the one sent after an IETF Retry. + void InstallInitialCrypters(QuicConnectionId connection_id); + + // Called when version is considered negotiated. + void OnSuccessfulVersionNegotiation(); + + // Called when self migration succeeds after probing. + void OnSuccessfulMigration(bool is_port_change); + + // Called for QUIC+TLS versions when we send transport parameters. + void OnTransportParametersSent( + const TransportParameters& transport_parameters) const; + + // Called for QUIC+TLS versions when we receive transport parameters. + void OnTransportParametersReceived( + const TransportParameters& transport_parameters) const; + + // Called for QUIC+TLS versions when we resume cached transport parameters for + // 0-RTT. + void OnTransportParametersResumed( + const TransportParameters& transport_parameters) const; + + // Returns true if ack_alarm_ is set. + bool HasPendingAcks() const; + + virtual void OnUserAgentIdKnown(const std::string& user_agent_id); + + // If now is close to idle timeout, returns true and sends a connectivity + // probing packet to test the connection for liveness. Otherwise, returns + // false. + bool MaybeTestLiveness(); + + // QuicPathValidator::SendDelegate + // Send PATH_CHALLENGE using the given path information. If |writer| is the + // default writer, PATH_CHALLENGE can be bundled with other frames, and the + // containing packet can be buffered if the writer is blocked. Otherwise, + // PATH_CHALLENGE will be written in an individual packet and it will be + // dropped if write fails. |data_buffer| will be populated with the payload + // for future validation. + // Return false if the connection is closed thus the caller will not continue + // the validation, otherwise return true. + bool SendPathChallenge(const QuicPathFrameBuffer& data_buffer, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& effective_peer_address, + QuicPacketWriter* writer) override; + // If |writer| is the default writer and |peer_address| is the same as + // peer_address(), return the PTO of this connection. Otherwise, return 3 * + // kInitialRtt. + QuicTime GetRetryTimeout(const QuicSocketAddress& peer_address_to_use, + QuicPacketWriter* writer_to_use) const override; + + // Start vaildating the path defined by |context| asynchronously and call the + // |result_delegate| after validation finishes. If the connection is + // validating another path, cancel and fail that validation before starting + // this one. + void ValidatePath( + std::unique_ptr context, + std::unique_ptr result_delegate, + PathValidationReason reason); + + bool can_receive_ack_frequency_frame() const { + return can_receive_ack_frequency_frame_; + } + + void set_can_receive_ack_frequency_frame() { + can_receive_ack_frequency_frame_ = true; + } + + bool is_processing_packet() const { return framer_.is_processing_packet(); } + + bool HasPendingPathValidation() const; + + QuicPathValidationContext* GetPathValidationContext() const; + + void CancelPathValidation(); + + // Returns true if the migration succeeds, otherwise returns false (e.g., no + // available CIDs, connection disconnected, etc). + bool MigratePath(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicPacketWriter* writer, bool owns_writer); + + // Called to clear the alternative_path_ when path validation failed on the + // client side. + void OnPathValidationFailureAtClient( + bool is_multi_port, const QuicPathValidationContext& context); + + void SetSourceAddressTokenToSend(absl::string_view token); + + void SendPing() { + SendPingAtLevel(framer().GetEncryptionLevelToSendApplicationData()); + } + + // Returns one server connection ID that associates the current session in the + // session map. + virtual QuicConnectionId GetOneActiveServerConnectionId() const; + + // Returns all server connection IDs that have not been removed from the + // session map. + virtual std::vector GetActiveServerConnectionIds() const; + + bool validate_client_address() const { return validate_client_addresses_; } + + bool connection_migration_use_new_cid() const { + return connection_migration_use_new_cid_; + } + + // Instantiates connection ID manager. + void CreateConnectionIdManager(); + + // Log QUIC_BUG if there is pending frames for the stream with |id|. + void QuicBugIfHasPendingFrames(QuicStreamId id) const; + + QuicConnectionContext* context() { return &context_; } + const QuicConnectionContext* context() const { return &context_; } + + void set_tracer(std::unique_ptr tracer) { + context_.tracer.swap(tracer); + } + + void set_bug_listener(std::unique_ptr bug_listener) { + context_.bug_listener.swap(bug_listener); + } + + bool in_probe_time_out() const { return in_probe_time_out_; } + + // Ensures the network blackhole delay is longer than path degrading delay. + static QuicTime::Delta CalculateNetworkBlackholeDelay( + QuicTime::Delta blackhole_delay, QuicTime::Delta path_degrading_delay, + QuicTime::Delta pto_delay); + + void DisableLivenessTesting() { liveness_testing_disabled_ = true; } + + void AddKnownServerAddress(const QuicSocketAddress& address); + + absl::optional + MaybeIssueNewConnectionIdForPreferredAddress(); + + // Kicks off validation of received server preferred address. + void ValidateServerPreferredAddress(); + + // Returns true if the client is validating the server preferred address which + // hasn't been used before. + bool IsValidatingServerPreferredAddress() const; + + // Called by client to start sending packets to the preferred address. + // If |owns_writer| is true, the ownership of the writer in the |context| is + // also passed in. + void OnServerPreferredAddressValidated(QuicPathValidationContext& context, + bool owns_writer); + + void set_sent_server_preferred_address( + const QuicSocketAddress& sent_server_preferred_address) { + sent_server_preferred_address_ = sent_server_preferred_address; + } + + const QuicSocketAddress& sent_server_preferred_address() const { + return sent_server_preferred_address_; + } + + protected: + // Calls cancel() on all the alarms owned by this connection. + void CancelAllAlarms(); + + // Send a packet to the peer, and takes ownership of the packet if the packet + // cannot be written immediately. + virtual void SendOrQueuePacket(SerializedPacket packet); + + // Called after a packet is received from a new effective peer address and is + // decrypted. Starts validation of effective peer's address change. Calls + // OnConnectionMigration as soon as the address changed. + void StartEffectivePeerMigration(AddressChangeType type); + + // Called when a effective peer address migration is validated. + virtual void OnEffectivePeerMigrationValidated(bool is_migration_linkable); + + // Get the effective peer address from the packet being processed. For proxied + // connections, effective peer address is the address of the endpoint behind + // the proxy. For non-proxied connections, effective peer address is the same + // as peer address. + // + // Notes for implementations in subclasses: + // - If the connection is not proxied, the overridden method should use the + // base implementation: + // + // return QuicConnection::GetEffectivePeerAddressFromCurrentPacket(); + // + // - If the connection is proxied, the overridden method may return either of + // the following: + // a) The address of the endpoint behind the proxy. The address is used to + // drive effective peer migration. + // b) An uninitialized address, meaning the effective peer address does not + // change. + virtual QuicSocketAddress GetEffectivePeerAddressFromCurrentPacket() const; + + // Selects and updates the version of the protocol being used by selecting a + // version from |available_versions| which is also supported. Returns true if + // such a version exists, false otherwise. + bool SelectMutualVersion(const ParsedQuicVersionVector& available_versions); + + // Returns the current per-packet options for the connection. + PerPacketOptions* per_packet_options() { return per_packet_options_; } + + AddressChangeType active_effective_peer_migration_type() const { + return active_effective_peer_migration_type_; + } + + // Sends a connection close packet to the peer and includes an ACK if the ACK + // is not empty, the |error| is not PACKET_WRITE_ERROR, and it fits. + // |ietf_error| may optionally be be used to directly specify the wire + // error code. Otherwise if |ietf_error| is NO_IETF_QUIC_ERROR, the + // QuicErrorCodeToTransportErrorCode mapping of |error| will be used. + // Caller may choose to call SendConnectionClosePacket() directly instead of + // CloseConnection() to notify peer that the connection is going to be closed, + // for example, when the server is tearing down. Given + // SendConnectionClosePacket() does not close connection, multiple connection + // close packets could be sent to the peer. + virtual void SendConnectionClosePacket(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details); + + // Returns true if the packet should be discarded and not sent. + virtual bool ShouldDiscardPacket(EncryptionLevel encryption_level); + + // Notify various components(Session etc.) that this connection has been + // migrated. + virtual void OnConnectionMigration(); + + // Return whether the packet being processed is a connectivity probing. + // A packet is a connectivity probing if it is a padded ping packet with self + // and/or peer address changes. + bool IsCurrentPacketConnectivityProbing() const; + + // Return true iff the writer is blocked, if blocked, call + // visitor_->OnWriteBlocked() to add the connection into the write blocked + // list. + bool HandleWriteBlocked(); + + // Whether connection enforces anti-amplification limit. + bool EnforceAntiAmplificationLimit() const; + + void AddBytesReceivedBeforeAddressValidation(size_t length) { + default_path_.bytes_received_before_address_validation += length; + } + + void set_validate_client_addresses(bool value) { + validate_client_addresses_ = value; + } + + bool defer_send_in_response_to_packets() const { + return defer_send_in_response_to_packets_; + } + + ConnectionIdGeneratorInterface& connection_id_generator() const { + return connection_id_generator_; + } + + bool count_reverse_path_validation_stats() const { + return count_reverse_path_validation_stats_; + } + void set_count_reverse_path_validation_stats(bool value) { + count_reverse_path_validation_stats_ = value; + } + + private: + friend class test::QuicConnectionPeer; + + enum RetransmittableOnWireBehavior { + DEFAULT, // Send packet containing a PING frame. + SEND_FIRST_FORWARD_SECURE_PACKET, // Send 1st 1-RTT packet. + SEND_RANDOM_BYTES // Send random bytes which is an unprocessable packet. + }; + + enum class MultiPortStatusOnMigration { + kNotValidated, + kPendingRefreshValidation, + kWaitingForRefreshValidation, + kMaxValue, + }; + + struct QUIC_EXPORT_PRIVATE PendingPathChallenge { + QuicPathFrameBuffer received_path_challenge; + QuicSocketAddress peer_address; + }; + + struct QUIC_EXPORT_PRIVATE PathState { + PathState() = default; + + PathState(const QuicSocketAddress& alternative_self_address, + const QuicSocketAddress& alternative_peer_address, + const QuicConnectionId& client_connection_id, + const QuicConnectionId& server_connection_id, + absl::optional stateless_reset_token) + : self_address(alternative_self_address), + peer_address(alternative_peer_address), + client_connection_id(client_connection_id), + server_connection_id(server_connection_id), + stateless_reset_token(stateless_reset_token) {} + + PathState(PathState&& other); + + PathState& operator=(PathState&& other); + + // Reset all the members. + void Clear(); + + QuicSocketAddress self_address; + // The actual peer address behind the proxy if there is any. + QuicSocketAddress peer_address; + QuicConnectionId client_connection_id; + QuicConnectionId server_connection_id; + absl::optional stateless_reset_token; + // True if the peer address has been validated. Address is considered + // validated when 1) an address token of the peer address is received and + // validated, or 2) a HANDSHAKE packet has been successfully processed on + // this path, or 3) a path validation on this path has succeeded. + bool validated = false; + // Used by the sever to apply anti-amplification limit after this path + // becomes the default path if |peer_address| hasn't been validated. + QuicByteCount bytes_received_before_address_validation = 0; + QuicByteCount bytes_sent_before_address_validation = 0; + // Points to the send algorithm on the old default path while connection is + // validating migrated peer address. Nullptr otherwise. + std::unique_ptr send_algorithm; + absl::optional rtt_stats; + }; + + using QueuedPacketList = std::list; + + // BufferedPacket stores necessary information (encrypted buffer and self/peer + // addresses) of those packets which are serialized but failed to send because + // socket is blocked. From unacked packet map and send algorithm's + // perspective, buffered packets are treated as sent. + struct QUIC_EXPORT_PRIVATE BufferedPacket { + BufferedPacket(const SerializedPacket& packet, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address); + BufferedPacket(const char* encrypted_buffer, + QuicPacketLength encrypted_length, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address); + // Please note, this buffered packet contains random bytes (and is not + // *actually* a QUIC packet). + BufferedPacket(QuicRandom& random, QuicPacketLength encrypted_length, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address); + BufferedPacket(const BufferedPacket& other) = delete; + BufferedPacket(const BufferedPacket&& other) = delete; + + ~BufferedPacket() = default; + + std::unique_ptr data; + const QuicPacketLength length; + // Self and peer addresses when the packet is serialized. + const QuicSocketAddress self_address; + const QuicSocketAddress peer_address; + }; + + // ReceivedPacketInfo comprises the received packet information. + // TODO(fayang): move more fields to ReceivedPacketInfo. + struct QUIC_EXPORT_PRIVATE ReceivedPacketInfo { + explicit ReceivedPacketInfo(QuicTime receipt_time); + ReceivedPacketInfo(const QuicSocketAddress& destination_address, + const QuicSocketAddress& source_address, + QuicTime receipt_time, QuicByteCount length, + QuicEcnCodepoint ecn_codepoint); + + QuicSocketAddress destination_address; + QuicSocketAddress source_address; + QuicTime receipt_time = QuicTime::Zero(); + bool received_bytes_counted = false; + QuicByteCount length = 0; + QuicConnectionId destination_connection_id; + // Fields below are only populated if packet gets decrypted successfully. + // TODO(fayang): consider using absl::optional for following fields. + bool decrypted = false; + EncryptionLevel decrypted_level = ENCRYPTION_INITIAL; + QuicPacketHeader header; + absl::InlinedVector frames; + QuicEcnCodepoint ecn_codepoint = ECN_NOT_ECT; + // Stores the actual address this packet is received on when it is received + // on the preferred address. In this case, |destination_address| will + // be overridden to the current default self address. + QuicSocketAddress actual_destination_address; + }; + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicConnection::ReceivedPacketInfo& info); + + // UndecrytablePacket comprises a undecryptable packet and related + // information. + struct QUIC_EXPORT_PRIVATE UndecryptablePacket { + UndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel encryption_level, + const ReceivedPacketInfo& packet_info) + : packet(packet.Clone()), + encryption_level(encryption_level), + packet_info(packet_info) {} + + std::unique_ptr packet; + EncryptionLevel encryption_level; + ReceivedPacketInfo packet_info; + }; + + // Handles the reverse path validation result depending on connection state: + // whether the connection is validating a migrated peer address or is + // validating an alternative path. + class ReversePathValidationResultDelegate + : public QuicPathValidator::ResultDelegate { + public: + ReversePathValidationResultDelegate( + QuicConnection* connection, + const QuicSocketAddress& direct_peer_address); + + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime start_time) override; + + void OnPathValidationFailure( + std::unique_ptr context) override; + + private: + QuicConnection* connection_; + QuicSocketAddress original_direct_peer_address_; + // TODO(b/205023946) Debug-only fields, to be deprecated after the bug is + // fixed. + QuicSocketAddress peer_address_default_path_; + QuicSocketAddress peer_address_alternative_path_; + AddressChangeType active_effective_peer_migration_type_; + }; + + // Keeps an ongoing alternative path. The connection will not migrate upon + // validation success. + class MultiPortPathValidationResultDelegate + : public QuicPathValidator::ResultDelegate { + public: + MultiPortPathValidationResultDelegate(QuicConnection* connection); + + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime start_time) override; + + void OnPathValidationFailure( + std::unique_ptr context) override; + + private: + QuicConnection* connection_; + }; + + // A class which sets and clears in_probe_time_out_ when entering + // and exiting OnRetransmissionTimeout, respectively. + class QUIC_EXPORT_PRIVATE ScopedRetransmissionTimeoutIndicator { + public: + // |connection| must outlive this indicator. + explicit ScopedRetransmissionTimeoutIndicator(QuicConnection* connection); + + ~ScopedRetransmissionTimeoutIndicator(); + + private: + QuicConnection* connection_; // Not owned. + }; + + // If peer uses non-empty connection ID, discards any buffered packets on path + // change in IETF QUIC. + void MaybeClearQueuedPacketsOnPathChange(); + + // Notifies the visitor of the close and marks the connection as disconnected. + // Does not send a connection close frame to the peer. It should only be + // called by CloseConnection or OnConnectionCloseFrame, OnPublicResetPacket, + // and OnAuthenticatedIetfStatelessResetPacket. + // |ietf_error| may optionally be be used to directly specify the wire + // error code. Otherwise if |ietf_error| is NO_IETF_QUIC_ERROR, the + // QuicErrorCodeToTransportErrorCode mapping of |error| will be used. + void TearDownLocalConnectionState(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details, + ConnectionCloseSource source); + void TearDownLocalConnectionState(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source); + + // Replace server connection ID on the client side from retry packet or + // initial packets with a different source connection ID. + void ReplaceInitialServerConnectionId( + const QuicConnectionId& new_server_connection_id); + + // Given the server_connection_id find if there is already a corresponding + // client connection ID used on default/alternative path. If not, find if + // there is an unused connection ID. + void FindMatchingOrNewClientConnectionIdOrToken( + const PathState& default_path, const PathState& alternative_path, + const QuicConnectionId& server_connection_id, + QuicConnectionId* client_connection_id, + absl::optional* stateless_reset_token); + + // Returns true and sets connection IDs if (self_address, peer_address) + // corresponds to either the default path or alternative path. Returns false + // otherwise. + bool FindOnPathConnectionIds(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicConnectionId* client_connection_id, + QuicConnectionId* server_connection_id) const; + + // Set default_path_ to the new_path_state and update the connection IDs in + // packet creator accordingly. + void SetDefaultPathState(PathState new_path_state); + + // Returns true if header contains valid server connection ID. + bool ValidateServerConnectionId(const QuicPacketHeader& header) const; + + // Update the connection IDs when client migrates its own address + // (with/without validation) or switches to server preferred address. + // Returns false if required connection ID is not available. + bool UpdateConnectionIdsOnMigration(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address); + + // Retire active peer issued connection IDs after they are no longer used on + // any path. + void RetirePeerIssuedConnectionIdsNoLongerOnPath(); + + // When path validation fails, proactively retire peer issued connection IDs + // no longer used on any path. + void RetirePeerIssuedConnectionIdsOnPathValidationFailure(); + + // Writes the given packet to socket, encrypted with packet's + // encryption_level. Returns true on successful write, and false if the writer + // was blocked and the write needs to be tried again. Notifies the + // SentPacketManager when the write is successful and sets + // retransmittable frames to nullptr. + // Saves the connection close packet for later transmission, even if the + // writer is write blocked. + bool WritePacket(SerializedPacket* packet); + + // Enforce AEAD Confidentiality limits by iniating key update or closing + // connection if too many packets have been encrypted with the current key. + // Returns true if the connection was closed. Should not be called for + // termination packets. + bool MaybeHandleAeadConfidentialityLimits(const SerializedPacket& packet); + + // Flush packets buffered in the writer, if any. + void FlushPackets(); + + // Make sure a stop waiting we got from our peer is sane. + // Returns nullptr if the frame is valid or an error string if it was invalid. + const char* ValidateStopWaitingFrame( + const QuicStopWaitingFrame& stop_waiting); + + // Clears any accumulated frames from the last received packet. + void ClearLastFrames(); + + // Deletes and clears any queued packets. + void ClearQueuedPackets(); + + // Closes the connection if the sent packet manager is tracking too many + // outstanding packets. + void CloseIfTooManyOutstandingSentPackets(); + + // Writes as many queued packets as possible. The connection must not be + // blocked when this is called. + void WriteQueuedPackets(); + + // Queues |packet| in the hopes that it can be decrypted in the + // future, when a new key is installed. + void QueueUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level); + + // Sends any packets which are a response to the last packet, including both + // acks and pending writes if an ack opened the congestion window. + void MaybeSendInResponseToPacket(); + + // Gets the least unacked packet number, which is the next packet number to be + // sent if there are no outstanding packets. + QuicPacketNumber GetLeastUnacked() const; + + // Sets the ping alarm to the appropriate value, if any. + void SetPingAlarm(); + + // Sets the retransmission alarm based on SentPacketManager. + void SetRetransmissionAlarm(); + + // Sets the MTU discovery alarm if necessary. + // |sent_packet_number| is the recently sent packet number. + void MaybeSetMtuAlarm(QuicPacketNumber sent_packet_number); + + HasRetransmittableData IsRetransmittable(const SerializedPacket& packet); + bool IsTerminationPacket(const SerializedPacket& packet, + QuicErrorCode* error_code); + + // Set the size of the packet we are targeting while doing path MTU discovery. + void SetMtuDiscoveryTarget(QuicByteCount target); + + // Returns |suggested_max_packet_size| clamped to any limits set by the + // underlying writer, connection, or protocol. + QuicByteCount GetLimitedMaxPacketSize( + QuicByteCount suggested_max_packet_size); + + // Do any work which logically would be done in OnPacket but can not be + // safely done until the packet is validated. Returns true if packet can be + // handled, false otherwise. + bool ProcessValidatedPacket(const QuicPacketHeader& header); + + // Returns true if received |packet_number| can be processed. Please note, + // this is called after packet got decrypted successfully. + bool ValidateReceivedPacketNumber(QuicPacketNumber packet_number); + + // Consider receiving crypto frame on non crypto stream as memory corruption. + bool MaybeConsiderAsMemoryCorruption(const QuicStreamFrame& frame); + + // Check if the connection has no outstanding data to send and notify + // congestion controller if it is the case. + void CheckIfApplicationLimited(); + + // Sets |current_packet_content_| to |type| if applicable. And + // starts effective peer migration if current packet is confirmed not a + // connectivity probe and |current_effective_peer_migration_type_| indicates + // effective peer address change. + // Returns true if connection is still alive. + ABSL_MUST_USE_RESULT bool UpdatePacketContent(QuicFrameType type); + + // Called when last received ack frame has been processed. + // |send_stop_waiting| indicates whether a stop waiting needs to be sent. + // |acked_new_packet| is true if a previously-unacked packet was acked. + void PostProcessAfterAckFrame(bool send_stop_waiting, bool acked_new_packet); + + // Updates the release time into the future. + void UpdateReleaseTimeIntoFuture(); + + // Sends generic path probe packet to the peer. If we are not IETF QUIC, will + // always send a padded ping, regardless of whether this is a request or not. + bool SendGenericPathProbePacket(QuicPacketWriter* probing_writer, + const QuicSocketAddress& peer_address); + + // Called when an ACK is about to send. Resets ACK related internal states, + // e.g., cancels ack_alarm_, resets + // num_retransmittable_packets_received_since_last_ack_sent_ etc. + void ResetAckStates(); + + // Returns true if the ACK frame should be bundled with ACK-eliciting frame. + bool ShouldBundleRetransmittableFrameWithAck() const; + + void PopulateStopWaitingFrame(QuicStopWaitingFrame* stop_waiting); + + // Enables multiple packet number spaces support based on handshake protocol + // and flags. + void MaybeEnableMultiplePacketNumberSpacesSupport(); + + // Called to update ACK timeout when an retransmittable frame has been parsed. + void MaybeUpdateAckTimeout(); + + // Tries to fill coalesced packet with data of higher packet space. + void MaybeCoalescePacketOfHigherSpace(); + + // Serialize and send coalesced_packet. Returns false if serialization fails + // or the write causes errors, otherwise, returns true. + bool FlushCoalescedPacket(); + + // Returns the encryption level the connection close packet should be sent at, + // which is the highest encryption level that peer can guarantee to process. + EncryptionLevel GetConnectionCloseEncryptionLevel() const; + + // Called after an ACK frame is successfully processed to update largest + // received packet number which contains an ACK frame. + void SetLargestReceivedPacketWithAck(QuicPacketNumber new_value); + + // Called when new packets have been acknowledged or old keys have been + // discarded. + void OnForwardProgressMade(); + + // Returns largest received packet number which contains an ACK frame. + QuicPacketNumber GetLargestReceivedPacketWithAck() const; + + // Returns the largest packet number that has been sent. + QuicPacketNumber GetLargestSentPacket() const; + + // Returns the largest sent packet number that has been ACKed by peer. + QuicPacketNumber GetLargestAckedPacket() const; + + // Whether connection is limited by amplification factor. + // If enforce_strict_amplification_factor_ is true, this will return true if + // connection is amplification limited after sending |bytes|. + bool LimitedByAmplificationFactor(QuicByteCount bytes) const; + + // Called before sending a packet to get packet send time and to set the + // release time delay in |per_packet_options_|. Return the time when the + // packet is scheduled to be released(a.k.a send time), which is NOW + delay. + // Returns Now() and does not update release time delay if + // |supports_release_time_| is false. + QuicTime CalculatePacketSentTime(); + + // If we have a previously validate MTU value, e.g. due to a write error, + // revert to it and disable MTU discovery. + // Return true iff we reverted to a previously validate MTU. + bool MaybeRevertToPreviousMtu(); + + QuicTime GetPathMtuReductionDeadline() const; + + // Returns path degrading deadline. QuicTime::Zero() means no path degrading + // detection is needed. + QuicTime GetPathDegradingDeadline() const; + + // Returns true if path degrading should be detected. + bool ShouldDetectPathDegrading() const; + + // Returns network blackhole deadline. QuicTime::Zero() means no blackhole + // detection is needed. + QuicTime GetNetworkBlackholeDeadline() const; + + // Returns true if network blackhole should be detected. + bool ShouldDetectBlackhole() const; + + // Returns retransmission deadline. + QuicTime GetRetransmissionDeadline() const; + + // Validate connection IDs used during the handshake. Closes the connection + // on validation failure. + bool ValidateConfigConnectionIds(const QuicConfig& config); + + // Called when ACK alarm goes off. Try to bundle crypto data with ACKs. + void MaybeBundleCryptoDataWithAcks(); + + // Returns true if an undecryptable packet of |decryption_level| should be + // buffered (such that connection can try to decrypt it later). + bool ShouldEnqueueUnDecryptablePacket(EncryptionLevel decryption_level, + bool has_decryption_key) const; + + // Returns string which contains undecryptable packets information. + std::string UndecryptablePacketsInfo() const; + + // For Google Quic, if the current packet is connectivity probing packet, call + // session OnPacketReceived() which eventually sends connectivity probing + // response on server side. And no-op on client side. And for both Google Quic + // and IETF Quic, start migration if the current packet is a non-probing + // packet. + // TODO(danzh) rename to MaybeRespondToPeerMigration() when Google Quic is + // deprecated. + void MaybeRespondToConnectivityProbingOrMigration(); + + // Called in IETF QUIC. Start peer migration if a non-probing frame is + // received and the current packet number is largest received so far. + void MaybeStartIetfPeerMigration(); + + // Send PATH_RESPONSE to the given peer address. + bool SendPathResponse(const QuicPathFrameBuffer& data_buffer, + const QuicSocketAddress& peer_address_to_send, + const QuicSocketAddress& effective_peer_address); + + // Update both connection's and packet creator's peer address. + void UpdatePeerAddress(QuicSocketAddress peer_address); + + // Send PING at encryption level. + void SendPingAtLevel(EncryptionLevel level); + + // Write the given packet with |self_address| and |peer_address| using + // |writer|. + bool WritePacketUsingWriter(std::unique_ptr packet, + QuicPacketWriter* writer, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + bool measure_rtt); + + // Increment bytes sent/received on the alternative path if the current packet + // is sent/received on that path. + void MaybeUpdateBytesSentToAlternativeAddress( + const QuicSocketAddress& peer_address, QuicByteCount sent_packet_size); + void MaybeUpdateBytesReceivedFromAlternativeAddress( + QuicByteCount received_packet_size); + + // TODO(danzh) pass in PathState of the incoming packet or the packet sent + // once PathState is used in packet creator. Return true if the given self + // address and peer address is the same as the self address and peer address + // of the default path. + bool IsDefaultPath(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) const; + + // Return true if the |self_address| and |peer_address| is the same as the + // self address and peer address of the alternative path. + bool IsAlternativePath(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) const; + + // Restore connection default path and congestion control state to the last + // validated path and its state. Called after fail to validate peer address + // upon detecting a peer migration. + void RestoreToLastValidatedPath( + QuicSocketAddress original_direct_peer_address); + + // Return true if the current incoming packet is from a peer address that is + // validated. + bool IsReceivedPeerAddressValidated() const; + + // Called after receiving PATH_CHALLENGE. Update packet content and + // alternative path state if the current packet is from a non-default path. + // Return true if framer should continue processing the packet. + bool OnPathChallengeFrameInternal(const QuicPathChallengeFrame& frame); + + // Check the state of the multi-port alternative path and initiate path + // migration. + void MaybeMigrateToMultiPortPath(); + + std::unique_ptr + MakeSelfIssuedConnectionIdManager(); + + // Called on peer IP change or restoring to previous address to reset + // congestion window, RTT stats, retransmission timer, etc. Only used in IETF + // QUIC. + std::unique_ptr OnPeerIpAddressChanged(); + + // Process NewConnectionIdFrame either sent from peer or synsthesized from + // preferred_address transport parameter. + bool OnNewConnectionIdFrameInner(const QuicNewConnectionIdFrame& frame); + + // Called to patch missing client connection ID on default/alternative paths + // when a new client connection ID is received. + void OnClientConnectionIdAvailable(); + + // Determines encryption level to send ping in `packet_number_space`. + EncryptionLevel GetEncryptionLevelToSendPingForSpace( + PacketNumberSpace space) const; + + // Returns true if |address| is known server address. + bool IsKnownServerAddress(const QuicSocketAddress& address) const; + + // Retrieves the ECN codepoint to be sent on the next packet. + QuicEcnCodepoint GetNextEcnCodepoint() const { + return (per_packet_options_ != nullptr) ? per_packet_options_->ecn_codepoint + : ECN_NOT_ECT; + } + + // Sets the ECN codepoint to Not-ECT. + void ClearEcnCodepoint(); + + // Writes the packet to the writer and clears the ECN codepoint in |options| + // if it is invalid. + WriteResult SendPacketToWriter(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options); + + QuicConnectionContext context_; + + QuicFramer framer_; + + // Contents received in the current packet, especially used to identify + // whether the current packet is a padded PING packet. + PacketContent current_packet_content_; + // Set to true as soon as the packet currently being processed has been + // detected as a connectivity probing. + // Always false outside the context of ProcessUdpPacket(). + bool is_current_packet_connectivity_probing_; + + bool has_path_challenge_in_current_packet_; + + // Caches the current effective peer migration type if a effective peer + // migration might be initiated. As soon as the current packet is confirmed + // not a connectivity probe, effective peer migration will start. + AddressChangeType current_effective_peer_migration_type_; + QuicConnectionHelperInterface* helper_; // Not owned. + QuicAlarmFactory* alarm_factory_; // Not owned. + PerPacketOptions* per_packet_options_; // Not owned. + QuicPacketWriter* writer_; // Owned or not depending on |owns_writer_|. + bool owns_writer_; + // Encryption level for new packets. Should only be changed via + // SetDefaultEncryptionLevel(). + EncryptionLevel encryption_level_; + const QuicClock* clock_; + QuicRandom* random_generator_; + + // On the server, the connection ID is set when receiving the first packet. + // This variable ensures we only set it this way once. + bool client_connection_id_is_set_; + + // Whether we've already replaced our server connection ID due to receiving an + // INITIAL packet with a different source connection ID. Only used on client. + bool server_connection_id_replaced_by_initial_ = false; + // Address on the last successfully processed packet received from the + // direct peer. + + // Other than initialization, do not modify it directly, use + // UpdatePeerAddress() instead. + QuicSocketAddress direct_peer_address_; + // The default path on which the endpoint sends non-probing packets. + // The send algorithm and RTT stats of this path are stored in + // |sent_packet_manager_| instead of in this object. + PathState default_path_; + + // Records change type when the effective peer initiates migration to a new + // address. Reset to NO_CHANGE after effective peer migration is validated. + AddressChangeType active_effective_peer_migration_type_; + + // Records highest sent packet number when effective peer migration is + // started. + QuicPacketNumber highest_packet_sent_before_effective_peer_migration_; + + // True if Key Update is supported on this connection. + bool support_key_update_for_connection_; + + // Tracks the lowest packet sent in the current key phase. Will be + // uninitialized before the first one-RTT packet has been sent or after a + // key update but before the first packet has been sent. + QuicPacketNumber lowest_packet_sent_in_current_key_phase_; + + // TODO(rch): remove this when b/27221014 is fixed. + const char* current_packet_data_; // UDP payload of packet currently being + // parsed or nullptr. + bool should_last_packet_instigate_acks_; + + // Track some peer state so we can do less bookkeeping + // Largest sequence sent by the peer which had an ack frame (latest ack info). + // Do not read or write directly, use GetLargestReceivedPacketWithAck() and + // SetLargestReceivedPacketWithAck() instead. + QuicPacketNumber largest_seen_packet_with_ack_; + // Largest packet number sent by the peer which had an ACK frame per packet + // number space. Only used when this connection supports multiple packet + // number spaces. + QuicPacketNumber largest_seen_packets_with_ack_[NUM_PACKET_NUMBER_SPACES]; + + // Largest packet number sent by the peer which had a stop waiting frame. + QuicPacketNumber largest_seen_packet_with_stop_waiting_; + + // Collection of packets which were received before encryption was + // established, but which could not be decrypted. We buffer these on + // the assumption that they could not be processed because they were + // sent with the INITIAL encryption and the CHLO message was lost. + std::deque undecryptable_packets_; + + // Collection of coalesced packets which were received while processing + // the current packet. + quiche::QuicheCircularDeque> + received_coalesced_packets_; + + // Maximum number of undecryptable packets the connection will store. + size_t max_undecryptable_packets_; + + // Maximum number of tracked packets. + QuicPacketCount max_tracked_packets_; + + // Contains the connection close packets if the connection has been closed. + std::unique_ptr>> + termination_packets_; + + // Determines whether or not a connection close packet is sent to the peer + // after idle timeout due to lack of network activity. During the handshake, + // a connection close packet is sent, but not after. + ConnectionCloseBehavior idle_timeout_connection_close_behavior_; + + // When > 0, close the QUIC connection after this number of RTOs. + size_t num_rtos_for_blackhole_detection_; + + // Statistics for this session. + QuicConnectionStats stats_; + + UberReceivedPacketManager uber_received_packet_manager_; + + // Indicates how many consecutive times an ack has arrived which indicates + // the peer needs to stop waiting for some packets. + // TODO(fayang): remove this when deprecating Q043. + int stop_waiting_count_; + + // Indicates the retransmission alarm needs to be set. + bool pending_retransmission_alarm_; + + // If true, defer sending data in response to received packets to the + // SendAlarm. + bool defer_send_in_response_to_packets_; + + // Arena to store class implementations within the QuicConnection. + QuicConnectionArena arena_; + + // An alarm that fires when an ACK should be sent to the peer. + QuicArenaScopedPtr ack_alarm_; + // An alarm that fires when a packet needs to be retransmitted. + QuicArenaScopedPtr retransmission_alarm_; + // An alarm that is scheduled when the SentPacketManager requires a delay + // before sending packets and fires when the packet may be sent. + QuicArenaScopedPtr send_alarm_; + // An alarm that fires when an MTU probe should be sent. + QuicArenaScopedPtr mtu_discovery_alarm_; + // An alarm that fires to process undecryptable packets when new decyrption + // keys are available. + QuicArenaScopedPtr process_undecryptable_packets_alarm_; + // An alarm that fires to discard keys for the previous key phase some time + // after a key update has completed. + QuicArenaScopedPtr discard_previous_one_rtt_keys_alarm_; + // An alarm that fires to discard 0-RTT decryption keys some time after the + // first 1-RTT packet has been decrypted. Only used on server connections with + // TLS handshaker. + QuicArenaScopedPtr discard_zero_rtt_decryption_keys_alarm_; + // An alarm that fires to keep probing the multi-port path. + QuicArenaScopedPtr multi_port_probing_alarm_; + // Neither visitor is owned by this class. + QuicConnectionVisitorInterface* visitor_; + QuicConnectionDebugVisitor* debug_visitor_; + + QuicPacketCreator packet_creator_; + + // Information about the last received QUIC packet, which may not have been + // successfully decrypted and processed. + ReceivedPacketInfo last_received_packet_info_; + + // Sent packet manager which tracks the status of packets sent by this + // connection and contains the send and receive algorithms to determine when + // to send packets. + QuicSentPacketManager sent_packet_manager_; + + // Indicates whether connection version has been negotiated. + // Always true for server connections. + bool version_negotiated_; + + // Tracks if the connection was created by the server or the client. + Perspective perspective_; + + // True by default. False if we've received or sent an explicit connection + // close. + bool connected_; + + // Set to false if the connection should not send truncated connection IDs to + // the peer, even if the peer supports it. + bool can_truncate_connection_ids_; + + // If non-empty this contains the set of versions received in a + // version negotiation packet. + ParsedQuicVersionVector server_supported_versions_; + + // The number of MTU probes already sent. + size_t mtu_probe_count_; + + // The value of |long_term_mtu_| prior to the last successful MTU increase. + // 0 means either + // - MTU discovery has never been enabled, or + // - MTU discovery has been enabled, but the connection got a packet write + // error with a new (successfully probed) MTU, so it reverted + // |long_term_mtu_| to the value before the last increase. + QuicPacketLength previous_validated_mtu_; + // The value of the MTU regularly used by the connection. This is different + // from the value returned by max_packet_size(), as max_packet_size() returns + // the value of the MTU as currently used by the serializer, so if + // serialization of an MTU probe is in progress, those two values will be + // different. + QuicByteCount long_term_mtu_; + + // The maximum UDP payload size that our peer has advertised support for. + // Defaults to kDefaultMaxPacketSizeTransportParam until received from peer. + QuicByteCount peer_max_packet_size_; + + // The size of the largest packet received from peer. + QuicByteCount largest_received_packet_size_; + + // Indicates whether a write error is encountered currently. This is used to + // avoid infinite write errors. + bool write_error_occurred_; + + // Indicates not to send or process stop waiting frames. + bool no_stop_waiting_frames_; + + // Consecutive number of sent packets which have no retransmittable frames. + size_t consecutive_num_packets_with_no_retransmittable_frames_; + + // After this many packets sent without retransmittable frames, an artificial + // retransmittable frame(a WINDOW_UPDATE) will be created to solicit an ack + // from the peer. Default to kMaxConsecutiveNonRetransmittablePackets. + size_t max_consecutive_num_packets_with_no_retransmittable_frames_; + + // If true, bundle an ack-eliciting frame with an ACK if the PTO alarm have + // previously fired. + bool bundle_retransmittable_with_pto_ack_; + + // Id of latest sent control frame. 0 if no control frame has been sent. + QuicControlFrameId last_control_frame_id_; + + // True if the peer is unreachable on the current path. + bool is_path_degrading_; + + // True if an ack frame is being processed. + bool processing_ack_frame_; + + // True if the writer supports release timestamp. + bool supports_release_time_; + + std::unique_ptr peer_issued_cid_manager_; + std::unique_ptr self_issued_cid_manager_; + + // Time this connection can release packets into the future. + QuicTime::Delta release_time_into_future_; + + // Payloads that were received in the most recent probe. This needs to be a + // Deque because the peer might no be using this implementation, and others + // might send a packet with more than one PATH_CHALLENGE, so all need to be + // saved and responded to. + // TODO(danzh) deprecate this field when deprecating + // --quic_send_path_response. + quiche::QuicheCircularDeque + received_path_challenge_payloads_; + + // When we receive a RETRY packet or some INITIAL packets, we replace + // |server_connection_id_| with the value from that packet and save off the + // original value of |server_connection_id_| into + // |original_destination_connection_id_| for validation. + absl::optional original_destination_connection_id_; + + // The connection ID that replaces original_destination_connection_id_. + QuicConnectionId original_destination_connection_id_replacement_; + + // After we receive a RETRY packet, |retry_source_connection_id_| contains + // the source connection ID from that packet. + absl::optional retry_source_connection_id_; + + // Used to store content of packets which cannot be sent because of write + // blocked. Packets' encrypted buffers are copied and owned by + // buffered_packets_. From unacked_packet_map (and congestion control)'s + // perspective, those packets are considered sent. + std::list buffered_packets_; + + // Used to coalesce packets of different encryption level into the same UDP + // datagram. Connection stops trying to coalesce packets if a forward secure + // packet gets acknowledged. + QuicCoalescedPacket coalesced_packet_; + + QuicConnectionMtuDiscoverer mtu_discoverer_; + + QuicNetworkBlackholeDetector blackhole_detector_; + + QuicIdleNetworkDetector idle_network_detector_; + + bool blackhole_detection_disabled_ = false; + + const bool default_enable_5rto_blackhole_detection_ = + GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2); + + // True if next packet is intended to consume remaining space in the + // coalescer. + bool fill_coalesced_packet_ = false; + + size_t anti_amplification_factor_ = + GetQuicFlag(quic_anti_amplification_factor); + + // True if AckFrequencyFrame is supported. + bool can_receive_ack_frequency_frame_ = false; + + // Indicate whether coalescing is done. + bool coalescing_done_ = false; + + // Indicate whether any ENCRYPTION_HANDSHAKE packet has been sent. + bool handshake_packet_sent_ = false; + + // Indicate whether to send an AckFrequencyFrame upon handshake completion. + // The AckFrequencyFrame sent will updates client's max_ack_delay, which if + // chosen properly can reduce the CPU and bandwidth usage for ACK frames. + bool send_ack_frequency_on_handshake_completion_ = false; + + // Indicate whether AckFrequency frame has been sent. + bool ack_frequency_sent_ = false; + + // True if a 0-RTT decrypter was or is installed at some point in the + // connection's lifetime. + bool had_zero_rtt_decrypter_ = false; + + // True after the first 1-RTT packet has successfully decrypted. + bool have_decrypted_first_one_rtt_packet_ = false; + + // True if we are currently processing OnRetransmissionTimeout. + bool in_probe_time_out_ = false; + + QuicPathValidator path_validator_; + + // Stores information of a path which maybe used as default path in the + // future. On the client side, it gets created when the client starts + // validating a new path and gets cleared once it becomes the default path or + // the path validation fails or replaced by a newer path of interest. On the + // server side, alternative_path gets created when server: 1) receives + // PATH_CHALLENGE on non-default path, or 2) switches to a not yet validated + // default path such that it needs to store the previous validated default + // path. + // Note that if alternative_path_ stores a validated path information (case + // 2), do not override it on receiving PATH_CHALLENGE (case 1). + PathState alternative_path_; + + // If true, upon seeing a new client address, validate the client address. + bool validate_client_addresses_ = false; + + // Indicates whether we should proactively validate peer address on a + // PATH_CHALLENGE received. + bool should_proactively_validate_peer_address_on_path_challenge_ = false; + + // Enable this via reloadable flag once this feature is complete. + bool connection_migration_use_new_cid_ = false; + + // If true, send connection close packet on INVALID_VERSION. + bool send_connection_close_for_invalid_version_ = false; + + // If true, disable liveness testing. + bool liveness_testing_disabled_ = false; + + QuicPingManager ping_manager_; + + // Records first serialized 1-RTT packet. + std::unique_ptr first_serialized_one_rtt_packet_; + + std::unique_ptr multi_port_path_context_; + + QuicTime::Delta multi_port_probing_interval_; + + std::unique_ptr multi_port_stats_; + + RetransmittableOnWireBehavior retransmittable_on_wire_behavior_ = DEFAULT; + + // Server addresses that are known to the client. + std::vector known_server_addresses_; + + // Stores received server preferred address in transport param. Client side + // only. + QuicSocketAddress received_server_preferred_address_; + + // Stores sent server preferred address in transport param. Server side only. + QuicSocketAddress sent_server_preferred_address_; + + // If true, kicks off validation of server_preferred_address_ once it is + // received. Also, send all coalesced packets on both paths until handshake is + // confirmed. + bool accelerated_server_preferred_address_ = false; + + // TODO(b/223634460) Remove this. + bool count_reverse_path_validation_stats_ = false; + + // If true, throttle sending if next created packet will exceed amplification + // limit. + const bool enforce_strict_amplification_factor_ = + GetQuicFlag(quic_enforce_strict_amplification_factor); + + ConnectionIdGeneratorInterface& connection_id_generator_; + + // This LRU cache records source addresses of packets received on server's + // original address. + QuicLRUCache + received_client_addresses_cache_; + + // Endpoints should never mark packets with Congestion Experienced (CE), as + // this is only done by routers. Endpoints cannot send ECT(0) or ECT(1) if + // their congestion control cannot respond to these signals in accordance with + // the spec, or if the QUIC implementation doesn't validate ECN feedback. When + // true, the connection will not verify that the requested codepoint adheres + // to these policies. This is only accessible through QuicConnectionPeer. + bool disable_ecn_codepoint_validation_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONNECTION_H_ diff --git a/quiche/quic/core/quic_connection_context.cc b/quiche/quic/core/quic_connection_context.cc new file mode 100644 index 000000000000..28e9d8edeeec --- /dev/null +++ b/quiche/quic/core/quic_connection_context.cc @@ -0,0 +1,48 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_context.h" + +#include "absl/base/attributes.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { +ABSL_CONST_INIT thread_local QuicConnectionContext* current_context = nullptr; +} // namespace + +std::string QuicConnectionProcessPacketContext::DebugString() const { + if (decrypted_payload.empty()) { + return "Not processing packet"; + } + + return absl::StrCat("current_frame_offset: ", current_frame_offset, + ", payload size: ", decrypted_payload.size(), + ", payload hexdump: ", + quiche::QuicheTextUtils::HexDump(decrypted_payload)); +} + +// static +QuicConnectionContext* QuicConnectionContext::Current() { + return current_context; +} + +QuicConnectionContextSwitcher::QuicConnectionContextSwitcher( + QuicConnectionContext* new_context) + : old_context_(QuicConnectionContext::Current()) { + current_context = new_context; + if (new_context && new_context->tracer) { + new_context->tracer->Activate(); + } +} + +QuicConnectionContextSwitcher::~QuicConnectionContextSwitcher() { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->Deactivate(); + } + current_context = old_context_; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_connection_context.h b/quiche/quic/core/quic_connection_context.h new file mode 100644 index 000000000000..72d6b66c86ea --- /dev/null +++ b/quiche/quic/core/quic_connection_context.h @@ -0,0 +1,153 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONNECTION_CONTEXT_H_ +#define QUICHE_QUIC_CORE_QUIC_CONNECTION_CONTEXT_H_ + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +// QuicConnectionTracer is responsible for emit trace messages for a single +// QuicConnection. +// QuicConnectionTracer is part of the QuicConnectionContext. +class QUIC_EXPORT_PRIVATE QuicConnectionTracer { + public: + virtual ~QuicConnectionTracer() = default; + + // Emit a trace message from a string literal. The trace may simply remember + // the address of the literal in this function and read it at a later time. + virtual void PrintLiteral(const char* literal) = 0; + + // Emit a trace message from a string_view. Unlike PrintLiteral, this function + // will not read |s| after it returns. + virtual void PrintString(absl::string_view s) = 0; + + // Emit a trace message from printf-style arguments. + template + void Printf(const absl::FormatSpec& format, const Args&... args) { + std::string s = absl::StrFormat(format, args...); + PrintString(s); + } + + private: + friend class QuicConnectionContextSwitcher; + + // Called by QuicConnectionContextSwitcher, when |this| becomes the current + // thread's QUIC connection tracer. + // + // Activate/Deactivate are only called by QuicConnectionContextSwitcher's + // constructor/destructor, they always come in pairs. + virtual void Activate() {} + + // Called by QuicConnectionContextSwitcher, when |this| stops from being the + // current thread's QUIC connection tracer. + // + // Activate/Deactivate are only called by QuicConnectionContextSwitcher's + // constructor/destructor, they always come in pairs. + virtual void Deactivate() {} +}; + +// QuicBugListener is a helper class for implementing QUIC_BUG. The QUIC_BUG +// implementation can send the bug information into quic::CurrentBugListener(). +class QUIC_EXPORT_PRIVATE QuicBugListener { + public: + virtual ~QuicBugListener() = default; + virtual void OnQuicBug(const char* bug_id, const char* file, int line, + absl::string_view bug_message) = 0; +}; + +// QuicConnectionProcessPacketContext is a member of QuicConnectionContext that +// contains information of the packet currently being processed by the owning +// QuicConnection. +struct QUIC_EXPORT_PRIVATE QuicConnectionProcessPacketContext final { + // If !empty(), the decrypted payload of the packet currently being processed. + absl::string_view decrypted_payload; + + // The offset within |decrypted_payload|, if it's non-empty, that marks the + // start of the frame currently being processed. + // Should not be used when |decrypted_payload| is empty. + size_t current_frame_offset = 0; + + // NOTE: This can be very expansive. If used in logs, make sure it is rate + // limited via QUIC_BUG etc. + std::string DebugString() const; +}; + +// QuicConnectionContext is a per-QuicConnection context that includes +// facilities useable by any part of a QuicConnection. A QuicConnectionContext +// is owned by a QuicConnection. +// +// The 'top-level' QuicConnection functions are responsible for maintaining the +// thread-local QuicConnectionContext pointer, such that any function called by +// them(directly or indirectly) can access the context. +// +// Like QuicConnection, all facilities in QuicConnectionContext are assumed to +// be called from a single thread at a time, they are NOT thread-safe. +struct QUIC_EXPORT_PRIVATE QuicConnectionContext final { + // Get the context on the current executing thread. nullptr if the current + // function is not called from a 'top-level' QuicConnection function. + static QuicConnectionContext* Current(); + + std::unique_ptr tracer; + std::unique_ptr bug_listener; + + // Information about the packet currently being processed. + QuicConnectionProcessPacketContext process_packet_context; +}; + +// QuicConnectionContextSwitcher is a RAII object used for maintaining the +// thread-local QuicConnectionContext pointer. +class QUIC_EXPORT_PRIVATE QuicConnectionContextSwitcher final { + public: + // The constructor switches from QuicConnectionContext::Current() to + // |new_context|. + explicit QuicConnectionContextSwitcher(QuicConnectionContext* new_context); + + // The destructor switches from QuicConnectionContext::Current() back to the + // old context. + ~QuicConnectionContextSwitcher(); + + private: + QuicConnectionContext* old_context_; +}; + +// Emit a trace message from a string literal to the current tracer(if any). +inline void QUIC_TRACELITERAL(const char* literal) { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->PrintLiteral(literal); + } +} + +// Emit a trace message from a string_view to the current tracer(if any). +inline void QUIC_TRACESTRING(absl::string_view s) { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->PrintString(s); + } +} + +// Emit a trace message from printf-style arguments to the current tracer(if +// any). +template +void QUIC_TRACEPRINTF(const absl::FormatSpec& format, + const Args&... args) { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->Printf(format, args...); + } +} + +inline QuicBugListener* CurrentBugListener() { + QuicConnectionContext* current = QuicConnectionContext::Current(); + return (current != nullptr) ? current->bug_listener.get() : nullptr; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONNECTION_CONTEXT_H_ diff --git a/quiche/quic/core/quic_connection_context_test.cc b/quiche/quic/core/quic_connection_context_test.cc new file mode 100644 index 000000000000..1f68ae931722 --- /dev/null +++ b/quiche/quic/core/quic_connection_context_test.cc @@ -0,0 +1,173 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_context.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_thread.h" + +using testing::ElementsAre; + +namespace quic::test { +namespace { + +class TraceCollector : public QuicConnectionTracer { + public: + ~TraceCollector() override = default; + + void PrintLiteral(const char* literal) override { trace_.push_back(literal); } + + void PrintString(absl::string_view s) override { + trace_.push_back(std::string(s)); + } + + const std::vector& trace() const { return trace_; } + + private: + std::vector trace_; +}; + +struct FakeConnection { + FakeConnection() { context.tracer = std::make_unique(); } + + const std::vector& trace() const { + return static_cast(context.tracer.get())->trace(); + } + + QuicConnectionContext context; +}; + +void SimpleSwitch() { + FakeConnection connection; + + // These should be ignored since current context is nullptr. + EXPECT_EQ(QuicConnectionContext::Current(), nullptr); + QUIC_TRACELITERAL("before switch: literal"); + QUIC_TRACESTRING(std::string("before switch: string")); + QUIC_TRACEPRINTF("%s: %s", "before switch", "printf"); + + { + QuicConnectionContextSwitcher switcher(&connection.context); + QUIC_TRACELITERAL("literal"); + QUIC_TRACESTRING(std::string("string")); + QUIC_TRACEPRINTF("%s", "printf"); + } + + EXPECT_EQ(QuicConnectionContext::Current(), nullptr); + QUIC_TRACELITERAL("after switch: literal"); + QUIC_TRACESTRING(std::string("after switch: string")); + QUIC_TRACEPRINTF("%s: %s", "after switch", "printf"); + + EXPECT_THAT(connection.trace(), ElementsAre("literal", "string", "printf")); +} + +void NestedSwitch() { + FakeConnection outer, inner; + + { + QuicConnectionContextSwitcher switcher(&outer.context); + QUIC_TRACELITERAL("outer literal 0"); + QUIC_TRACESTRING(std::string("outer string 0")); + QUIC_TRACEPRINTF("%s %s %d", "outer", "printf", 0); + + { + QuicConnectionContextSwitcher switcher(&inner.context); + QUIC_TRACELITERAL("inner literal"); + QUIC_TRACESTRING(std::string("inner string")); + QUIC_TRACEPRINTF("%s %s", "inner", "printf"); + } + + QUIC_TRACELITERAL("outer literal 1"); + QUIC_TRACESTRING(std::string("outer string 1")); + QUIC_TRACEPRINTF("%s %s %d", "outer", "printf", 1); + } + + EXPECT_THAT(outer.trace(), ElementsAre("outer literal 0", "outer string 0", + "outer printf 0", "outer literal 1", + "outer string 1", "outer printf 1")); + + EXPECT_THAT(inner.trace(), + ElementsAre("inner literal", "inner string", "inner printf")); +} + +void AlternatingSwitch() { + FakeConnection zero, one, two; + for (int i = 0; i < 15; ++i) { + FakeConnection* connection = + ((i % 3) == 0) ? &zero : (((i % 3) == 1) ? &one : &two); + + QuicConnectionContextSwitcher switcher(&connection->context); + QUIC_TRACEPRINTF("%d", i); + } + + EXPECT_THAT(zero.trace(), ElementsAre("0", "3", "6", "9", "12")); + EXPECT_THAT(one.trace(), ElementsAre("1", "4", "7", "10", "13")); + EXPECT_THAT(two.trace(), ElementsAre("2", "5", "8", "11", "14")); +} + +typedef void (*ThreadFunction)(); + +template +class TestThread : public QuicThread { + public: + TestThread() : QuicThread("TestThread") {} + ~TestThread() override = default; + + protected: + void Run() override { func(); } +}; + +template +void RunInThreads(size_t n_threads) { + using ThreadType = TestThread; + std::vector threads(n_threads); + + for (ThreadType& t : threads) { + t.Start(); + } + + for (ThreadType& t : threads) { + t.Join(); + } +} + +class QuicConnectionContextTest : public QuicTest { + protected: +}; + +TEST_F(QuicConnectionContextTest, NullTracerOK) { + FakeConnection connection; + std::unique_ptr tracer; + + { + QuicConnectionContextSwitcher switcher(&connection.context); + QUIC_TRACELITERAL("msg 1 recorded"); + } + + connection.context.tracer.swap(tracer); + + { + QuicConnectionContextSwitcher switcher(&connection.context); + // Should be a no-op since connection.context.tracer is nullptr. + QUIC_TRACELITERAL("msg 2 ignored"); + } + + EXPECT_THAT(static_cast(tracer.get())->trace(), + ElementsAre("msg 1 recorded")); +} + +TEST_F(QuicConnectionContextTest, TestSimpleSwitch) { + RunInThreads(10); +} + +TEST_F(QuicConnectionContextTest, TestNestedSwitch) { + RunInThreads(10); +} + +TEST_F(QuicConnectionContextTest, TestAlternatingSwitch) { + RunInThreads(10); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/quic_connection_id.cc b/quiche/quic/core/quic_connection_id.cc new file mode 100644 index 000000000000..839097dafdc8 --- /dev/null +++ b/quiche/quic/core/quic_connection_id.cc @@ -0,0 +1,180 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_id.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "openssl/siphash.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +namespace { + +// QuicConnectionIdHasher can be used to generate a stable connection ID hash +// function that will return the same value for two equal connection IDs for +// the duration of process lifetime. It is meant to be used as input to data +// structures that do not outlast process lifetime. A new key is generated once +// per process to prevent attackers from crafting connection IDs in such a way +// that they always land in the same hash bucket. +class QuicConnectionIdHasher { + public: + inline QuicConnectionIdHasher() + : QuicConnectionIdHasher(QuicRandom::GetInstance()) {} + + explicit inline QuicConnectionIdHasher(QuicRandom* random) { + random->RandBytes(&sip_hash_key_, sizeof(sip_hash_key_)); + } + + inline size_t Hash(const char* input, size_t input_len) const { + return static_cast(SIPHASH_24( + sip_hash_key_, reinterpret_cast(input), input_len)); + } + + private: + uint64_t sip_hash_key_[2]; +}; + +} // namespace + +QuicConnectionId::QuicConnectionId() : QuicConnectionId(nullptr, 0) { + static_assert(offsetof(QuicConnectionId, padding_) == + offsetof(QuicConnectionId, length_), + "bad offset"); + static_assert(sizeof(QuicConnectionId) <= 16, "bad size"); +} + +QuicConnectionId::QuicConnectionId(const char* data, uint8_t length) { + length_ = length; + if (length_ == 0) { + return; + } + if (length_ <= sizeof(data_short_)) { + memcpy(data_short_, data, length_); + return; + } + data_long_ = reinterpret_cast(malloc(length_)); + QUICHE_CHECK_NE(nullptr, data_long_); + memcpy(data_long_, data, length_); +} + +QuicConnectionId::QuicConnectionId(const absl::Span data) + : QuicConnectionId(reinterpret_cast(data.data()), + data.length()) {} + +QuicConnectionId::~QuicConnectionId() { + if (length_ > sizeof(data_short_)) { + free(data_long_); + data_long_ = nullptr; + } +} + +QuicConnectionId::QuicConnectionId(const QuicConnectionId& other) + : QuicConnectionId(other.data(), other.length()) {} + +QuicConnectionId& QuicConnectionId::operator=(const QuicConnectionId& other) { + set_length(other.length()); + memcpy(mutable_data(), other.data(), length_); + return *this; +} + +const char* QuicConnectionId::data() const { + if (length_ <= sizeof(data_short_)) { + return data_short_; + } + return data_long_; +} + +char* QuicConnectionId::mutable_data() { + if (length_ <= sizeof(data_short_)) { + return data_short_; + } + return data_long_; +} + +uint8_t QuicConnectionId::length() const { return length_; } + +void QuicConnectionId::set_length(uint8_t length) { + char temporary_data[sizeof(data_short_)]; + if (length > sizeof(data_short_)) { + if (length_ <= sizeof(data_short_)) { + // Copy data from data_short_ to data_long_. + memcpy(temporary_data, data_short_, length_); + data_long_ = reinterpret_cast(malloc(length)); + QUICHE_CHECK_NE(nullptr, data_long_); + memcpy(data_long_, temporary_data, length_); + } else { + // Resize data_long_. + char* realloc_result = + reinterpret_cast(realloc(data_long_, length)); + QUICHE_CHECK_NE(nullptr, realloc_result); + data_long_ = realloc_result; + } + } else if (length_ > sizeof(data_short_)) { + // Copy data from data_long_ to data_short_. + memcpy(temporary_data, data_long_, length); + free(data_long_); + data_long_ = nullptr; + memcpy(data_short_, temporary_data, length); + } + length_ = length; +} + +bool QuicConnectionId::IsEmpty() const { return length_ == 0; } + +size_t QuicConnectionId::Hash() const { + static const QuicConnectionIdHasher hasher = QuicConnectionIdHasher(); + return hasher.Hash(data(), length_); +} + +std::string QuicConnectionId::ToString() const { + if (IsEmpty()) { + return std::string("0"); + } + return absl::BytesToHexString(absl::string_view(data(), length_)); +} + +std::ostream& operator<<(std::ostream& os, const QuicConnectionId& v) { + os << v.ToString(); + return os; +} + +bool QuicConnectionId::operator==(const QuicConnectionId& v) const { + return length_ == v.length_ && memcmp(data(), v.data(), length_) == 0; +} + +bool QuicConnectionId::operator!=(const QuicConnectionId& v) const { + return !(v == *this); +} + +bool QuicConnectionId::operator<(const QuicConnectionId& v) const { + if (length_ < v.length_) { + return true; + } + if (length_ > v.length_) { + return false; + } + return memcmp(data(), v.data(), length_) < 0; +} + +QuicConnectionId EmptyQuicConnectionId() { return QuicConnectionId(); } + +static_assert(kQuicDefaultConnectionIdLength == sizeof(uint64_t), + "kQuicDefaultConnectionIdLength changed"); +static_assert(kQuicDefaultConnectionIdLength == 8, + "kQuicDefaultConnectionIdLength changed"); + +} // namespace quic diff --git a/quiche/quic/core/quic_connection_id.h b/quiche/quic/core/quic_connection_id.h new file mode 100644 index 000000000000..b4e25fec0dbc --- /dev/null +++ b/quiche/quic/core/quic_connection_id.h @@ -0,0 +1,138 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_H_ +#define QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// This is a property of QUIC headers, it indicates whether the connection ID +// should actually be sent over the wire (or was sent on received packets). +enum QuicConnectionIdIncluded : uint8_t { + CONNECTION_ID_PRESENT = 1, + CONNECTION_ID_ABSENT = 2, +}; + +// Maximum connection ID length supported by versions that use the encoding from +// draft-ietf-quic-invariants-06. +const uint8_t kQuicMaxConnectionIdWithLengthPrefixLength = 20; + +// Maximum connection ID length supported by versions that use the encoding from +// draft-ietf-quic-invariants-05. +const uint8_t kQuicMaxConnectionId4BitLength = 18; + +// kQuicDefaultConnectionIdLength is the only supported length for QUIC +// versions < v99, and is the default picked for all versions. +const uint8_t kQuicDefaultConnectionIdLength = 8; + +// According to the IETF spec, the initial server connection ID generated by +// the client must be at least this long. +const uint8_t kQuicMinimumInitialConnectionIdLength = 8; + +class QUIC_EXPORT_PRIVATE QuicConnectionId { + public: + // Creates a connection ID of length zero. + QuicConnectionId(); + + // Creates a connection ID from network order bytes. + QuicConnectionId(const char* data, uint8_t length); + QuicConnectionId(const absl::Span data); + + // Creates a connection ID from another connection ID. + QuicConnectionId(const QuicConnectionId& other); + + // Assignment operator. + QuicConnectionId& operator=(const QuicConnectionId& other); + + ~QuicConnectionId(); + + // Returns the length of the connection ID, in bytes. + uint8_t length() const; + + // Sets the length of the connection ID, in bytes. + // WARNING: Calling set_length() can change the in-memory location of the + // connection ID. Callers must therefore ensure they call data() or + // mutable_data() after they call set_length(). + void set_length(uint8_t length); + + // Returns a pointer to the connection ID bytes, in network byte order. + const char* data() const; + + // Returns a mutable pointer to the connection ID bytes, + // in network byte order. + char* mutable_data(); + + // Returns whether the connection ID has length zero. + bool IsEmpty() const; + + // Hash() is required to use connection IDs as keys in hash tables. + // During the lifetime of a process, the output of Hash() is guaranteed to be + // the same for connection IDs that are equal to one another. Note however + // that this property is not guaranteed across process lifetimes. This makes + // Hash() suitable for data structures such as hash tables but not for sending + // a hash over the network. + size_t Hash() const; + + // Generates an ASCII string that represents + // the contents of the connection ID, or "0" if it is empty. + std::string ToString() const; + + // operator<< allows easily logging connection IDs. + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicConnectionId& v); + + bool operator==(const QuicConnectionId& v) const; + bool operator!=(const QuicConnectionId& v) const; + // operator< is required to use connection IDs as keys in hash tables. + bool operator<(const QuicConnectionId& v) const; + + private: + // The connection ID is represented in network byte order. + union { + // If the connection ID fits in |data_short_|, it is stored in the + // first |length_| bytes of |data_short_|. + // Otherwise it is stored in |data_long_| which is guaranteed to have a size + // equal to |length_|. + // A value of 11 was chosen because our commonly used connection ID length + // is 8 and with the length, the class is padded to at least 12 bytes + // anyway. + struct { + uint8_t padding_; // Match length_ field of the other union member. + char data_short_[11]; + }; + struct { + uint8_t length_; // length of the connection ID, in bytes. + char* data_long_; + }; + }; +}; + +// Creates a connection ID of length zero, unless the restart flag +// quic_connection_ids_network_byte_order is false in which case +// it returns an 8-byte all-zeroes connection ID. +QUIC_EXPORT_PRIVATE QuicConnectionId EmptyQuicConnectionId(); + +// QuicConnectionIdHash can be passed as hash argument to hash tables. +// During the lifetime of a process, the output of QuicConnectionIdHash is +// guaranteed to be the same for connection IDs that are equal to one another. +// Note however that this property is not guaranteed across process lifetimes. +// This makes QuicConnectionIdHash suitable for data structures such as hash +// tables but not for sending a hash over the network. +class QUIC_EXPORT_PRIVATE QuicConnectionIdHash { + public: + size_t operator()(QuicConnectionId const& connection_id) const noexcept { + return connection_id.Hash(); + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_H_ diff --git a/quiche/quic/core/quic_connection_id_manager.cc b/quiche/quic/core/quic_connection_id_manager.cc new file mode 100644 index 000000000000..4776545ce2ba --- /dev/null +++ b/quiche/quic/core/quic_connection_id_manager.cc @@ -0,0 +1,487 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_id_manager.h" + +#include + +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +QuicConnectionIdData::QuicConnectionIdData( + const QuicConnectionId& connection_id, uint64_t sequence_number, + const StatelessResetToken& stateless_reset_token) + : connection_id(connection_id), + sequence_number(sequence_number), + stateless_reset_token(stateless_reset_token) {} + +namespace { + +class RetirePeerIssuedConnectionIdAlarm + : public QuicAlarm::DelegateWithContext { + public: + explicit RetirePeerIssuedConnectionIdAlarm( + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), visitor_(visitor) {} + RetirePeerIssuedConnectionIdAlarm(const RetirePeerIssuedConnectionIdAlarm&) = + delete; + RetirePeerIssuedConnectionIdAlarm& operator=( + const RetirePeerIssuedConnectionIdAlarm&) = delete; + + void OnAlarm() override { visitor_->OnPeerIssuedConnectionIdRetired(); } + + private: + QuicConnectionIdManagerVisitorInterface* visitor_; +}; + +std::vector::const_iterator FindConnectionIdData( + const std::vector& cid_data_vector, + const QuicConnectionId& cid) { + return std::find_if(cid_data_vector.begin(), cid_data_vector.end(), + [&cid](const QuicConnectionIdData& cid_data) { + return cid == cid_data.connection_id; + }); +} + +std::vector::iterator FindConnectionIdData( + std::vector* cid_data_vector, + const QuicConnectionId& cid) { + return std::find_if(cid_data_vector->begin(), cid_data_vector->end(), + [&cid](const QuicConnectionIdData& cid_data) { + return cid == cid_data.connection_id; + }); +} + +} // namespace + +QuicPeerIssuedConnectionIdManager::QuicPeerIssuedConnectionIdManager( + size_t active_connection_id_limit, + const QuicConnectionId& initial_peer_issued_connection_id, + const QuicClock* clock, QuicAlarmFactory* alarm_factory, + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context) + : active_connection_id_limit_(active_connection_id_limit), + clock_(clock), + retire_connection_id_alarm_(alarm_factory->CreateAlarm( + new RetirePeerIssuedConnectionIdAlarm(visitor, context))) { + QUICHE_DCHECK_GE(active_connection_id_limit_, 2u); + QUICHE_DCHECK(!initial_peer_issued_connection_id.IsEmpty()); + active_connection_id_data_.emplace_back( + initial_peer_issued_connection_id, + /*sequence_number=*/0u, {}); + recent_new_connection_id_sequence_numbers_.Add(0u, 1u); +} + +QuicPeerIssuedConnectionIdManager::~QuicPeerIssuedConnectionIdManager() { + retire_connection_id_alarm_->Cancel(); +} + +bool QuicPeerIssuedConnectionIdManager::IsConnectionIdNew( + const QuicNewConnectionIdFrame& frame) { + auto is_old_connection_id = [&frame](const QuicConnectionIdData& cid_data) { + return cid_data.connection_id == frame.connection_id; + }; + if (std::any_of(active_connection_id_data_.begin(), + active_connection_id_data_.end(), is_old_connection_id)) { + return false; + } + if (std::any_of(unused_connection_id_data_.begin(), + unused_connection_id_data_.end(), is_old_connection_id)) { + return false; + } + if (std::any_of(to_be_retired_connection_id_data_.begin(), + to_be_retired_connection_id_data_.end(), + is_old_connection_id)) { + return false; + } + return true; +} + +void QuicPeerIssuedConnectionIdManager::PrepareToRetireConnectionIdPriorTo( + uint64_t retire_prior_to, + std::vector* cid_data_vector) { + auto it2 = cid_data_vector->begin(); + for (auto it = cid_data_vector->begin(); it != cid_data_vector->end(); ++it) { + if (it->sequence_number >= retire_prior_to) { + *it2++ = *it; + } else { + to_be_retired_connection_id_data_.push_back(*it); + if (!retire_connection_id_alarm_->IsSet()) { + retire_connection_id_alarm_->Set(clock_->ApproximateNow()); + } + } + } + cid_data_vector->erase(it2, cid_data_vector->end()); +} + +QuicErrorCode QuicPeerIssuedConnectionIdManager::OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& frame, std::string* error_detail) { + if (recent_new_connection_id_sequence_numbers_.Contains( + frame.sequence_number)) { + // This frame has a recently seen sequence number. Ignore. + return QUIC_NO_ERROR; + } + if (!IsConnectionIdNew(frame)) { + *error_detail = + "Received a NEW_CONNECTION_ID frame that reuses a previously seen Id."; + return IETF_QUIC_PROTOCOL_VIOLATION; + } + + recent_new_connection_id_sequence_numbers_.AddOptimizedForAppend( + frame.sequence_number, frame.sequence_number + 1); + + if (recent_new_connection_id_sequence_numbers_.Size() > + kMaxNumConnectionIdSequenceNumberIntervals) { + *error_detail = + "Too many disjoint connection Id sequence number intervals."; + return IETF_QUIC_PROTOCOL_VIOLATION; + } + + // QuicFramer::ProcessNewConnectionIdFrame guarantees that + // frame.sequence_number >= frame.retire_prior_to, and hence there is no need + // to check that. + if (frame.sequence_number < max_new_connection_id_frame_retire_prior_to_) { + // Later frames have asked for retirement of the current frame. + to_be_retired_connection_id_data_.emplace_back(frame.connection_id, + frame.sequence_number, + frame.stateless_reset_token); + if (!retire_connection_id_alarm_->IsSet()) { + retire_connection_id_alarm_->Set(clock_->ApproximateNow()); + } + return QUIC_NO_ERROR; + } + if (frame.retire_prior_to > max_new_connection_id_frame_retire_prior_to_) { + max_new_connection_id_frame_retire_prior_to_ = frame.retire_prior_to; + PrepareToRetireConnectionIdPriorTo(frame.retire_prior_to, + &active_connection_id_data_); + PrepareToRetireConnectionIdPriorTo(frame.retire_prior_to, + &unused_connection_id_data_); + } + + if (active_connection_id_data_.size() + unused_connection_id_data_.size() >= + active_connection_id_limit_) { + *error_detail = "Peer provides more connection IDs than the limit."; + return QUIC_CONNECTION_ID_LIMIT_ERROR; + } + + unused_connection_id_data_.emplace_back( + frame.connection_id, frame.sequence_number, frame.stateless_reset_token); + return QUIC_NO_ERROR; +} + +const QuicConnectionIdData* +QuicPeerIssuedConnectionIdManager::ConsumeOneUnusedConnectionId() { + if (unused_connection_id_data_.empty()) { + return nullptr; + } + active_connection_id_data_.push_back(unused_connection_id_data_.back()); + unused_connection_id_data_.pop_back(); + return &active_connection_id_data_.back(); +} + +void QuicPeerIssuedConnectionIdManager::PrepareToRetireActiveConnectionId( + const QuicConnectionId& cid) { + auto it = FindConnectionIdData(active_connection_id_data_, cid); + if (it == active_connection_id_data_.end()) { + // The cid has already been retired. + return; + } + to_be_retired_connection_id_data_.push_back(*it); + active_connection_id_data_.erase(it); + if (!retire_connection_id_alarm_->IsSet()) { + retire_connection_id_alarm_->Set(clock_->ApproximateNow()); + } +} + +void QuicPeerIssuedConnectionIdManager::MaybeRetireUnusedConnectionIds( + const std::vector& active_connection_ids_on_path) { + std::vector cids_to_retire; + for (const auto& cid_data : active_connection_id_data_) { + if (std::find(active_connection_ids_on_path.begin(), + active_connection_ids_on_path.end(), + cid_data.connection_id) == + active_connection_ids_on_path.end()) { + cids_to_retire.push_back(cid_data.connection_id); + } + } + for (const auto& cid : cids_to_retire) { + PrepareToRetireActiveConnectionId(cid); + } +} + +bool QuicPeerIssuedConnectionIdManager::IsConnectionIdActive( + const QuicConnectionId& cid) const { + return FindConnectionIdData(active_connection_id_data_, cid) != + active_connection_id_data_.end(); +} + +std::vector QuicPeerIssuedConnectionIdManager:: + ConsumeToBeRetiredConnectionIdSequenceNumbers() { + std::vector result; + for (auto const& cid_data : to_be_retired_connection_id_data_) { + result.push_back(cid_data.sequence_number); + } + to_be_retired_connection_id_data_.clear(); + return result; +} + +void QuicPeerIssuedConnectionIdManager::ReplaceConnectionId( + const QuicConnectionId& old_connection_id, + const QuicConnectionId& new_connection_id) { + auto it1 = + FindConnectionIdData(&active_connection_id_data_, old_connection_id); + if (it1 != active_connection_id_data_.end()) { + it1->connection_id = new_connection_id; + return; + } + auto it2 = FindConnectionIdData(&to_be_retired_connection_id_data_, + old_connection_id); + if (it2 != to_be_retired_connection_id_data_.end()) { + it2->connection_id = new_connection_id; + } +} + +namespace { + +class RetireSelfIssuedConnectionIdAlarmDelegate + : public QuicAlarm::DelegateWithContext { + public: + explicit RetireSelfIssuedConnectionIdAlarmDelegate( + QuicSelfIssuedConnectionIdManager* connection_id_manager, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), + connection_id_manager_(connection_id_manager) {} + RetireSelfIssuedConnectionIdAlarmDelegate( + const RetireSelfIssuedConnectionIdAlarmDelegate&) = delete; + RetireSelfIssuedConnectionIdAlarmDelegate& operator=( + const RetireSelfIssuedConnectionIdAlarmDelegate&) = delete; + + void OnAlarm() override { connection_id_manager_->RetireConnectionId(); } + + private: + QuicSelfIssuedConnectionIdManager* connection_id_manager_; +}; + +} // namespace + +QuicSelfIssuedConnectionIdManager::QuicSelfIssuedConnectionIdManager( + size_t active_connection_id_limit, + const QuicConnectionId& initial_connection_id, const QuicClock* clock, + QuicAlarmFactory* alarm_factory, + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context, ConnectionIdGeneratorInterface& generator) + : active_connection_id_limit_(active_connection_id_limit), + clock_(clock), + visitor_(visitor), + retire_connection_id_alarm_(alarm_factory->CreateAlarm( + new RetireSelfIssuedConnectionIdAlarmDelegate(this, context))), + last_connection_id_(initial_connection_id), + next_connection_id_sequence_number_(1u), + last_connection_id_consumed_by_self_sequence_number_(0u), + connection_id_generator_(generator) { + active_connection_ids_.emplace_back(initial_connection_id, 0u); +} + +QuicSelfIssuedConnectionIdManager::~QuicSelfIssuedConnectionIdManager() { + retire_connection_id_alarm_->Cancel(); +} + +absl::optional +QuicSelfIssuedConnectionIdManager::MaybeIssueNewConnectionId() { + const bool check_cid_collision_when_issue_new_cid = + GetQuicReloadableFlag(quic_check_cid_collision_when_issue_new_cid); + absl::optional new_cid = + connection_id_generator_.GenerateNextConnectionId(last_connection_id_); + if (!new_cid.has_value()) { + return {}; + } + if (check_cid_collision_when_issue_new_cid) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_check_cid_collision_when_issue_new_cid, 1, + 2); + if (!visitor_->MaybeReserveConnectionId(*new_cid)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_check_cid_collision_when_issue_new_cid, + 2, 2); + return {}; + } + } + QuicNewConnectionIdFrame frame; + frame.connection_id = *new_cid; + frame.sequence_number = next_connection_id_sequence_number_++; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + if (!check_cid_collision_when_issue_new_cid) { + visitor_->MaybeReserveConnectionId(frame.connection_id); + } + active_connection_ids_.emplace_back(frame.connection_id, + frame.sequence_number); + frame.retire_prior_to = active_connection_ids_.front().second; + last_connection_id_ = frame.connection_id; + return frame; +} + +absl::optional QuicSelfIssuedConnectionIdManager:: + MaybeIssueNewConnectionIdForPreferredAddress() { + absl::optional frame = MaybeIssueNewConnectionId(); + QUICHE_DCHECK(!frame.has_value() || (frame->sequence_number == 1u)); + return frame; +} + +QuicErrorCode QuicSelfIssuedConnectionIdManager::OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame, QuicTime::Delta pto_delay, + std::string* error_detail) { + QUICHE_DCHECK(!active_connection_ids_.empty()); + if (GetQuicReloadableFlag( + quic_check_retire_cid_with_next_cid_sequence_number)) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_check_retire_cid_with_next_cid_sequence_number); + if (frame.sequence_number >= next_connection_id_sequence_number_) { + *error_detail = "To be retired connecton ID is never issued."; + return IETF_QUIC_PROTOCOL_VIOLATION; + } + } else { + if (frame.sequence_number > active_connection_ids_.back().second) { + *error_detail = "To be retired connecton ID is never issued."; + return IETF_QUIC_PROTOCOL_VIOLATION; + } + } + + auto it = + std::find_if(active_connection_ids_.begin(), active_connection_ids_.end(), + [&frame](const std::pair& p) { + return p.second == frame.sequence_number; + }); + // The corresponding connection ID has been retired. Ignore. + if (it == active_connection_ids_.end()) { + return QUIC_NO_ERROR; + } + + if (to_be_retired_connection_ids_.size() + active_connection_ids_.size() >= + kMaxNumConnectonIdsInUse) { + // Close connection if the number of connection IDs in use will exeed the + // limit, i.e., peer retires connection ID too fast. + *error_detail = "There are too many connection IDs in use."; + return QUIC_TOO_MANY_CONNECTION_ID_WAITING_TO_RETIRE; + } + + QuicTime retirement_time = clock_->ApproximateNow() + 3 * pto_delay; + if (!to_be_retired_connection_ids_.empty()) { + retirement_time = + std::max(retirement_time, to_be_retired_connection_ids_.back().second); + } + + to_be_retired_connection_ids_.emplace_back(it->first, retirement_time); + if (!retire_connection_id_alarm_->IsSet()) { + retire_connection_id_alarm_->Set(retirement_time); + } + + active_connection_ids_.erase(it); + MaybeSendNewConnectionIds(); + + return QUIC_NO_ERROR; +} + +std::vector +QuicSelfIssuedConnectionIdManager::GetUnretiredConnectionIds() const { + std::vector unretired_ids; + for (const auto& cid_pair : to_be_retired_connection_ids_) { + unretired_ids.push_back(cid_pair.first); + } + for (const auto& cid_pair : active_connection_ids_) { + unretired_ids.push_back(cid_pair.first); + } + return unretired_ids; +} + +QuicConnectionId QuicSelfIssuedConnectionIdManager::GetOneActiveConnectionId() + const { + QUICHE_DCHECK(!active_connection_ids_.empty()); + return active_connection_ids_.front().first; +} + +void QuicSelfIssuedConnectionIdManager::RetireConnectionId() { + if (to_be_retired_connection_ids_.empty()) { + QUIC_BUG(quic_bug_12420_1) + << "retire_connection_id_alarm fired but there is no connection ID " + "to be retired."; + return; + } + QuicTime now = clock_->ApproximateNow(); + auto it = to_be_retired_connection_ids_.begin(); + do { + visitor_->OnSelfIssuedConnectionIdRetired(it->first); + ++it; + } while (it != to_be_retired_connection_ids_.end() && it->second <= now); + to_be_retired_connection_ids_.erase(to_be_retired_connection_ids_.begin(), + it); + // Set the alarm again if there is another connection ID to be removed. + if (!to_be_retired_connection_ids_.empty()) { + retire_connection_id_alarm_->Set( + to_be_retired_connection_ids_.front().second); + } +} + +void QuicSelfIssuedConnectionIdManager::MaybeSendNewConnectionIds() { + while (active_connection_ids_.size() < active_connection_id_limit_) { + absl::optional frame = + MaybeIssueNewConnectionId(); + if (!frame.has_value()) { + break; + } + if (!visitor_->SendNewConnectionId(*frame)) { + break; + } + } +} + +bool QuicSelfIssuedConnectionIdManager::HasConnectionIdToConsume() const { + for (const auto& active_cid_data : active_connection_ids_) { + if (active_cid_data.second > + last_connection_id_consumed_by_self_sequence_number_) { + return true; + } + } + return false; +} + +absl::optional +QuicSelfIssuedConnectionIdManager::ConsumeOneConnectionId() { + for (const auto& active_cid_data : active_connection_ids_) { + if (active_cid_data.second > + last_connection_id_consumed_by_self_sequence_number_) { + // Since connection IDs in active_connection_ids_ has monotonically + // increasing sequence numbers, the returned connection ID has the + // smallest sequence number among all unconsumed active connection IDs. + last_connection_id_consumed_by_self_sequence_number_ = + active_cid_data.second; + return active_cid_data.first; + } + } + return absl::nullopt; +} + +bool QuicSelfIssuedConnectionIdManager::IsConnectionIdInUse( + const QuicConnectionId& cid) const { + for (const auto& active_cid_data : active_connection_ids_) { + if (active_cid_data.first == cid) { + return true; + } + } + for (const auto& to_be_retired_cid_data : to_be_retired_connection_ids_) { + if (to_be_retired_cid_data.first == cid) { + return true; + } + } + return false; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_connection_id_manager.h b/quiche/quic/core/quic_connection_id_manager.h new file mode 100644 index 000000000000..b8454af09eed --- /dev/null +++ b/quiche/quic/core/quic_connection_id_manager.h @@ -0,0 +1,197 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// QuicPeerIssuedConnectionIdManager handles the states associated with receving +// and retiring peer issued connection Ids. +// QuicSelfIssuedConnectionIdManager handles the states associated with +// connection Ids issued by the current end point. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_MANAGER_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/connection_id_generator.h" +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_interval_set.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicConnectionIdManagerPeer; +} // namespace test + +struct QUIC_EXPORT_PRIVATE QuicConnectionIdData { + QuicConnectionIdData(const QuicConnectionId& connection_id, + uint64_t sequence_number, + const StatelessResetToken& stateless_reset_token); + + QuicConnectionId connection_id; + uint64_t sequence_number; + StatelessResetToken stateless_reset_token; +}; + +// Used by QuicSelfIssuedConnectionIdManager +// and QuicPeerIssuedConnectionIdManager. +class QUIC_EXPORT_PRIVATE QuicConnectionIdManagerVisitorInterface { + public: + virtual ~QuicConnectionIdManagerVisitorInterface() = default; + virtual void OnPeerIssuedConnectionIdRetired() = 0; + virtual bool SendNewConnectionId(const QuicNewConnectionIdFrame& frame) = 0; + virtual bool MaybeReserveConnectionId( + const QuicConnectionId& connection_id) = 0; + virtual void OnSelfIssuedConnectionIdRetired( + const QuicConnectionId& connection_id) = 0; +}; + +class QUIC_EXPORT_PRIVATE QuicPeerIssuedConnectionIdManager { + public: + // QuicPeerIssuedConnectionIdManager should be instantiated only when a peer + // issued-non empty connection ID is received. + QuicPeerIssuedConnectionIdManager( + size_t active_connection_id_limit, + const QuicConnectionId& initial_peer_issued_connection_id, + const QuicClock* clock, QuicAlarmFactory* alarm_factory, + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context); + + ~QuicPeerIssuedConnectionIdManager(); + + QuicErrorCode OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame, + std::string* error_detail); + + bool HasUnusedConnectionId() const { + return !unused_connection_id_data_.empty(); + } + + // Returns the data associated with an unused connection Id. After the call, + // the Id is marked as used. Returns nullptr if there is no unused connection + // Id. + const QuicConnectionIdData* ConsumeOneUnusedConnectionId(); + + // Add each active connection Id that is no longer on path to the pending + // retirement connection Id list. + void MaybeRetireUnusedConnectionIds( + const std::vector& active_connection_ids_on_path); + + bool IsConnectionIdActive(const QuicConnectionId& cid) const; + + // Get the sequence numbers of all the connection Ids pending retirement when + // it is safe to retires these Ids. + std::vector ConsumeToBeRetiredConnectionIdSequenceNumbers(); + + // If old_connection_id is still tracked by QuicPeerIssuedConnectionIdManager, + // replace it with new_connection_id. Otherwise, this is a no-op. + void ReplaceConnectionId(const QuicConnectionId& old_connection_id, + const QuicConnectionId& new_connection_id); + + private: + friend class test::QuicConnectionIdManagerPeer; + + // Add the connection Id to the pending retirement connection Id list and + // schedule an alarm if needed. + void PrepareToRetireActiveConnectionId(const QuicConnectionId& cid); + + bool IsConnectionIdNew(const QuicNewConnectionIdFrame& frame); + + void PrepareToRetireConnectionIdPriorTo( + uint64_t retire_prior_to, + std::vector* cid_data_vector); + + size_t active_connection_id_limit_; + const QuicClock* clock_; + std::unique_ptr retire_connection_id_alarm_; + std::vector active_connection_id_data_; + std::vector unused_connection_id_data_; + std::vector to_be_retired_connection_id_data_; + // Track sequence numbers of recent NEW_CONNECTION_ID frames received from + // the peer. + QuicIntervalSet recent_new_connection_id_sequence_numbers_; + uint64_t max_new_connection_id_frame_retire_prior_to_ = 0u; +}; + +class QUIC_EXPORT_PRIVATE QuicSelfIssuedConnectionIdManager { + public: + QuicSelfIssuedConnectionIdManager( + size_t active_connection_id_limit, + const QuicConnectionId& initial_connection_id, const QuicClock* clock, + QuicAlarmFactory* alarm_factory, + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context, + ConnectionIdGeneratorInterface& generator); + + virtual ~QuicSelfIssuedConnectionIdManager(); + + absl::optional + MaybeIssueNewConnectionIdForPreferredAddress(); + + QuicErrorCode OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame, QuicTime::Delta pto_delay, + std::string* error_detail); + + std::vector GetUnretiredConnectionIds() const; + + QuicConnectionId GetOneActiveConnectionId() const; + + // Called when the retire_connection_id alarm_ fires. Removes the to be + // retired connection ID locally. + void RetireConnectionId(); + + // Sends new connection IDs if more can be sent. + void MaybeSendNewConnectionIds(); + + // The two functions are called on the client side to associate a client + // connection ID with a new probing/migration path when client uses + // non-empty connection ID. + bool HasConnectionIdToConsume() const; + absl::optional ConsumeOneConnectionId(); + + // Returns true if the given connection ID is issued by the + // QuicSelfIssuedConnectionIdManager and not retired locally yet. Called to + // tell if a received packet has a valid connection ID. + bool IsConnectionIdInUse(const QuicConnectionId& cid) const; + + private: + friend class test::QuicConnectionIdManagerPeer; + + // Issue a new connection ID. Can return nullopt. + absl::optional MaybeIssueNewConnectionId(); + + // This should be set to the min of: + // (1) # of active connection IDs that peer can maintain. + // (2) maximum # of active connection IDs self plans to issue. + size_t active_connection_id_limit_; + const QuicClock* clock_; + QuicConnectionIdManagerVisitorInterface* visitor_; + // This tracks connection IDs issued to the peer but not retired by the peer. + // Each pair is a connection ID and its sequence number. + std::vector> active_connection_ids_; + // This tracks connection IDs retired by the peer but has not been retired + // locally. Each pair is a connection ID and the time by which it should be + // retired. + std::vector> + to_be_retired_connection_ids_; + // An alarm that fires when a connection ID should be retired. + std::unique_ptr retire_connection_id_alarm_; + // State of the last issued connection Id. + QuicConnectionId last_connection_id_; + uint64_t next_connection_id_sequence_number_; + // The sequence number of last connection ID consumed. + uint64_t last_connection_id_consumed_by_self_sequence_number_; + + ConnectionIdGeneratorInterface& connection_id_generator_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONNECTION_ID_MANAGER_H_ diff --git a/quiche/quic/core/quic_connection_id_manager_test.cc b/quiche/quic/core/quic_connection_id_manager_test.cc new file mode 100644 index 000000000000..2dd35bfbe64e --- /dev/null +++ b/quiche/quic/core/quic_connection_id_manager_test.cc @@ -0,0 +1,1074 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_id_manager.h" + +#include + +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/quic_connection_id_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic::test { +namespace { + +using ::quic::test::IsError; +using ::quic::test::IsQuicNoError; +using ::quic::test::QuicConnectionIdManagerPeer; +using ::quic::test::TestConnectionId; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsNull; +using ::testing::Return; +using ::testing::StrictMock; + +class TestPeerIssuedConnectionIdManagerVisitor + : public QuicConnectionIdManagerVisitorInterface { + public: + void SetPeerIssuedConnectionIdManager( + QuicPeerIssuedConnectionIdManager* peer_issued_connection_id_manager) { + peer_issued_connection_id_manager_ = peer_issued_connection_id_manager; + } + + void OnPeerIssuedConnectionIdRetired() override { + // Replace current connection Id if it has been retired. + if (!peer_issued_connection_id_manager_->IsConnectionIdActive( + current_peer_issued_connection_id_)) { + current_peer_issued_connection_id_ = + peer_issued_connection_id_manager_->ConsumeOneUnusedConnectionId() + ->connection_id; + } + // Retire all the to-be-retired connection Ids. + most_recent_retired_connection_id_sequence_numbers_ = + peer_issued_connection_id_manager_ + ->ConsumeToBeRetiredConnectionIdSequenceNumbers(); + } + + const std::vector& + most_recent_retired_connection_id_sequence_numbers() { + return most_recent_retired_connection_id_sequence_numbers_; + } + + void SetCurrentPeerConnectionId(QuicConnectionId cid) { + current_peer_issued_connection_id_ = cid; + } + + const QuicConnectionId& GetCurrentPeerConnectionId() { + return current_peer_issued_connection_id_; + } + + bool SendNewConnectionId(const QuicNewConnectionIdFrame& /*frame*/) override { + return false; + } + bool MaybeReserveConnectionId(const QuicConnectionId&) override { + return false; + } + + void OnSelfIssuedConnectionIdRetired( + const QuicConnectionId& /*connection_id*/) override {} + + private: + QuicPeerIssuedConnectionIdManager* peer_issued_connection_id_manager_ = + nullptr; + QuicConnectionId current_peer_issued_connection_id_; + std::vector most_recent_retired_connection_id_sequence_numbers_; +}; + +class QuicPeerIssuedConnectionIdManagerTest : public QuicTest { + public: + QuicPeerIssuedConnectionIdManagerTest() + : peer_issued_cid_manager_( + /*active_connection_id_limit=*/2, initial_connection_id_, &clock_, + &alarm_factory_, &cid_manager_visitor_, /*context=*/nullptr) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + cid_manager_visitor_.SetPeerIssuedConnectionIdManager( + &peer_issued_cid_manager_); + cid_manager_visitor_.SetCurrentPeerConnectionId(initial_connection_id_); + retire_peer_issued_cid_alarm_ = + QuicConnectionIdManagerPeer::GetRetirePeerIssuedConnectionIdAlarm( + &peer_issued_cid_manager_); + } + + protected: + MockClock clock_; + test::MockAlarmFactory alarm_factory_; + TestPeerIssuedConnectionIdManagerVisitor cid_manager_visitor_; + QuicConnectionId initial_connection_id_ = TestConnectionId(0); + QuicPeerIssuedConnectionIdManager peer_issued_cid_manager_; + QuicAlarm* retire_peer_issued_cid_alarm_ = nullptr; + std::string error_details_; +}; + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + ConnectionIdSequenceWhenMigrationSucceed) { + { + // Receives CID #1 from peer. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + + // Start to use CID #1 for alternative path. + const QuicConnectionIdData* aternative_connection_id_data = + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + ASSERT_THAT(aternative_connection_id_data, testing::NotNull()); + EXPECT_EQ(aternative_connection_id_data->connection_id, + TestConnectionId(1)); + EXPECT_EQ(aternative_connection_id_data->stateless_reset_token, + frame.stateless_reset_token); + + // Connection migration succeed. Prepares to retire CID #0. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {TestConnectionId(1)}); + cid_manager_visitor_.SetCurrentPeerConnectionId(TestConnectionId(1)); + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(0u)); + } + + { + // Receives CID #2 from peer since CID #0 is retired. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(2); + frame.sequence_number = 2u; + frame.retire_prior_to = 1u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #2 for alternative path. + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + // Connection migration succeed. Prepares to retire CID #1. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {TestConnectionId(2)}); + cid_manager_visitor_.SetCurrentPeerConnectionId(TestConnectionId(2)); + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(1u)); + } + + { + // Receives CID #3 from peer since CID #1 is retired. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(3); + frame.sequence_number = 3u; + frame.retire_prior_to = 2u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #3 for alternative path. + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + // Connection migration succeed. Prepares to retire CID #2. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {TestConnectionId(3)}); + cid_manager_visitor_.SetCurrentPeerConnectionId(TestConnectionId(3)); + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(2u)); + } + + { + // Receives CID #4 from peer since CID #2 is retired. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(4); + frame.sequence_number = 4u; + frame.retire_prior_to = 3u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + ConnectionIdSequenceWhenMigrationFail) { + { + // Receives CID #1 from peer. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #1 for alternative path. + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + // Connection migration fails. Prepares to retire CID #1. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {initial_connection_id_}); + // Actually retires CID #1. + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(1u)); + } + + { + // Receives CID #2 from peer since CID #1 is retired. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(2); + frame.sequence_number = 2u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #2 for alternative path. + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + // Connection migration fails again. Prepares to retire CID #2. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {initial_connection_id_}); + // Actually retires CID #2. + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(2u)); + } + + { + // Receives CID #3 from peer since CID #2 is retired. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(3); + frame.sequence_number = 3u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #3 for alternative path. + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + // Connection migration succeed. Prepares to retire CID #0. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {TestConnectionId(3)}); + // After CID #3 is default (i.e., when there is no pending frame to write + // associated with CID #0), #0 can actually be retired. + cid_manager_visitor_.SetCurrentPeerConnectionId(TestConnectionId(3)); + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(0u)); + } + + { + // Receives CID #4 from peer since CID #0 is retired. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(4); + frame.sequence_number = 4u; + frame.retire_prior_to = 3u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + EXPECT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + EXPECT_FALSE(retire_peer_issued_cid_alarm_->IsSet()); + } +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + ReceivesNewConnectionIdOutOfOrder) { + { + // Receives new CID #1 that retires prior to #0. + // Outcome: (active: #0 unused: #1) + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #1 for alternative path. + // Outcome: (active: #0 #1 unused: None) + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + } + + { + // Receives new CID #3 that retires prior to #2. + // Outcome: (active: None unused: #3) + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(3); + frame.sequence_number = 3u; + frame.retire_prior_to = 2u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } + + { + // Receives new CID #2 that retires prior to #1. + // Outcome: (active: None unused: #3, #2) + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(2); + frame.sequence_number = 2u; + frame.retire_prior_to = 1u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } + + { + EXPECT_FALSE( + peer_issued_cid_manager_.IsConnectionIdActive(TestConnectionId(0))); + EXPECT_FALSE( + peer_issued_cid_manager_.IsConnectionIdActive(TestConnectionId(1))); + // When there is no frame associated with #0 and #1 to write, replace the + // in-use CID with an unused CID (#2) and retires #0 & #1. + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT(cid_manager_visitor_ + .most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(0u, 1u)); + EXPECT_EQ(cid_manager_visitor_.GetCurrentPeerConnectionId(), + TestConnectionId(2)); + // Get another unused CID for path validation. + EXPECT_EQ( + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId()->connection_id, + TestConnectionId(3)); + } +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + VisitedNewConnectionIdFrameIsIgnored) { + // Receives new CID #1 that retires prior to #0. + // Outcome: (active: #0 unused: #1) + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + // Start to use CID #1 for alternative path. + // Outcome: (active: #0 #1 unused: None) + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(); + // Prepare to retire CID #1 as path validation fails. + peer_issued_cid_manager_.MaybeRetireUnusedConnectionIds( + {initial_connection_id_}); + // Actually retires CID #1. + ASSERT_TRUE(retire_peer_issued_cid_alarm_->IsSet()); + alarm_factory_.FireAlarm(retire_peer_issued_cid_alarm_); + EXPECT_THAT( + cid_manager_visitor_.most_recent_retired_connection_id_sequence_numbers(), + ElementsAre(1u)); + // Receives the same frame again. Should be a no-op. + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + EXPECT_THAT(peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(), + testing::IsNull()); +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + ErrorWhenActiveConnectionIdLimitExceeded) { + { + // Receives new CID #1 that retires prior to #0. + // Outcome: (active: #0 unused: #1) + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } + + { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(2); + frame.sequence_number = 2u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsError(QUIC_CONNECTION_ID_LIMIT_ERROR)); + } +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + ErrorWhenTheSameConnectionIdIsSeenWithDifferentSequenceNumbers) { + { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } + + { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 2u; + frame.retire_prior_to = 1u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(TestConnectionId(2)); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + } +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + NewConnectionIdFrameWithTheSameSequenceNumberIsIgnored) { + { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } + + { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(2); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(TestConnectionId(2)); + EXPECT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + EXPECT_EQ( + peer_issued_cid_manager_.ConsumeOneUnusedConnectionId()->connection_id, + TestConnectionId(1)); + EXPECT_THAT(peer_issued_cid_manager_.ConsumeOneUnusedConnectionId(), + IsNull()); + } +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, + ErrorWhenThereAreTooManyGapsInIssuedConnectionIdSequenceNumbers) { + // Add 20 intervals: [0, 1), [2, 3), ..., [38,39) + for (int i = 2; i <= 38; i += 2) { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(i); + frame.sequence_number = i; + frame.retire_prior_to = i; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsQuicNoError()); + } + + // Interval [40, 41) goes over the limit. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(40); + frame.sequence_number = 40u; + frame.retire_prior_to = 40u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + ASSERT_THAT( + peer_issued_cid_manager_.OnNewConnectionIdFrame(frame, &error_details_), + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_F(QuicPeerIssuedConnectionIdManagerTest, ReplaceConnectionId) { + ASSERT_TRUE( + peer_issued_cid_manager_.IsConnectionIdActive(initial_connection_id_)); + peer_issued_cid_manager_.ReplaceConnectionId(initial_connection_id_, + TestConnectionId(1)); + EXPECT_FALSE( + peer_issued_cid_manager_.IsConnectionIdActive(initial_connection_id_)); + EXPECT_TRUE( + peer_issued_cid_manager_.IsConnectionIdActive(TestConnectionId(1))); +} + +class TestSelfIssuedConnectionIdManagerVisitor + : public QuicConnectionIdManagerVisitorInterface { + public: + void OnPeerIssuedConnectionIdRetired() override {} + + MOCK_METHOD(bool, SendNewConnectionId, + (const QuicNewConnectionIdFrame& frame), (override)); + MOCK_METHOD(bool, MaybeReserveConnectionId, + (const QuicConnectionId& connection_id), (override)); + MOCK_METHOD(void, OnSelfIssuedConnectionIdRetired, + (const QuicConnectionId& connection_id), (override)); +}; + +class QuicSelfIssuedConnectionIdManagerTest : public QuicTest { + public: + QuicSelfIssuedConnectionIdManagerTest() + : cid_manager_(/*active_connection_id_limit*/ 2, initial_connection_id_, + &clock_, &alarm_factory_, &cid_manager_visitor_, + /*context=*/nullptr, connection_id_generator_) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + retire_self_issued_cid_alarm_ = + QuicConnectionIdManagerPeer::GetRetireSelfIssuedConnectionIdAlarm( + &cid_manager_); + } + + protected: + // Verify that a call to GenerateNewConnectionId() does the right thing. + QuicConnectionId CheckGenerate(QuicConnectionId old_cid) { + QuicConnectionId new_cid = old_cid; + (*new_cid.mutable_data())++; + // Ready for the actual call. + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(old_cid)) + .WillOnce(Return(new_cid)); + return new_cid; + } + + MockClock clock_; + test::MockAlarmFactory alarm_factory_; + TestSelfIssuedConnectionIdManagerVisitor cid_manager_visitor_; + QuicConnectionId initial_connection_id_ = TestConnectionId(0); + StrictMock cid_manager_; + QuicAlarm* retire_self_issued_cid_alarm_ = nullptr; + std::string error_details_; + QuicTime::Delta pto_delay_ = QuicTime::Delta::FromMilliseconds(10); + MockConnectionIdGenerator connection_id_generator_; +}; + +MATCHER_P3(ExpectedNewConnectionIdFrame, connection_id, sequence_number, + retire_prior_to, "") { + return (arg.connection_id == connection_id) && + (arg.sequence_number == sequence_number) && + (arg.retire_prior_to == retire_prior_to); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + RetireSelfIssuedConnectionIdInOrder) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + QuicConnectionId cid2 = CheckGenerate(cid1); + QuicConnectionId cid3 = CheckGenerate(cid2); + QuicConnectionId cid4 = CheckGenerate(cid3); + QuicConnectionId cid5 = CheckGenerate(cid4); + + // Sends CID #1 to peer. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid1)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid1, 1u, 0u))) + .WillOnce(Return(true)); + cid_manager_.MaybeSendNewConnectionIds(); + + { + // Peer retires CID #0; + // Sends CID #2 and asks peer to retire CIDs prior to #1. + // Outcome: (#1, #2) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid2)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid2, 2u, 1u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 0u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #1; + // Sends CID #3 and asks peer to retire CIDs prior to #2. + // Outcome: (#2, #3) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid3)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid3, 3u, 2u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 1u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #2; + // Sends CID #4 and asks peer to retire CIDs prior to #3. + // Outcome: (#3, #4) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid4)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid4, 4u, 3u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 2u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #3; + // Sends CID #5 and asks peer to retire CIDs prior to #4. + // Outcome: (#4, #5) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid5)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid5, 5u, 4u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 3u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + RetireSelfIssuedConnectionIdOutOfOrder) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + QuicConnectionId cid2 = CheckGenerate(cid1); + QuicConnectionId cid3 = CheckGenerate(cid2); + QuicConnectionId cid4 = CheckGenerate(cid3); + + // Sends CID #1 to peer. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid1)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid1, 1u, 0u))) + .WillOnce(Return(true)); + cid_manager_.MaybeSendNewConnectionIds(); + + { + // Peer retires CID #1; + // Sends CID #2 and asks peer to retire CIDs prior to #0. + // Outcome: (#0, #2) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid2)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid2, 2u, 0u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 1u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #1 again. This is a no-op. + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 1u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #0; + // Sends CID #3 and asks peer to retire CIDs prior to #2. + // Outcome: (#2, #3) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid3)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid3, 3u, 2u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 0u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #3; + // Sends CID #4 and asks peer to retire CIDs prior to #2. + // Outcome: (#2, #4) are active. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid4)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, + SendNewConnectionId(ExpectedNewConnectionIdFrame(cid4, 4u, 2u))) + .WillOnce(Return(true)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 3u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + + { + // Peer retires CID #0 again. This is a no-op. + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 0u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + ScheduleConnectionIdRetirementOneAtATime) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + QuicConnectionId cid2 = CheckGenerate(cid1); + QuicConnectionId cid3 = CheckGenerate(cid2); + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(_)) + .Times(3) + .WillRepeatedly(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .Times(3) + .WillRepeatedly(Return(true)); + QuicTime::Delta connection_id_expire_timeout = 3 * pto_delay_; + QuicRetireConnectionIdFrame retire_cid_frame; + + // CID #1 is sent to peer. + cid_manager_.MaybeSendNewConnectionIds(); + + // CID #0's retirement is scheduled and CID #2 is sent to peer. + retire_cid_frame.sequence_number = 0u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + // While CID #0's retirement is scheduled, it is not retired yet. + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid0, cid1, cid2)); + EXPECT_TRUE(retire_self_issued_cid_alarm_->IsSet()); + EXPECT_EQ(retire_self_issued_cid_alarm_->deadline(), + clock_.ApproximateNow() + connection_id_expire_timeout); + + // CID #0 is actually retired. + EXPECT_CALL(cid_manager_visitor_, OnSelfIssuedConnectionIdRetired(cid0)); + clock_.AdvanceTime(connection_id_expire_timeout); + alarm_factory_.FireAlarm(retire_self_issued_cid_alarm_); + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid1, cid2)); + EXPECT_FALSE(retire_self_issued_cid_alarm_->IsSet()); + + // CID #1's retirement is scheduled and CID #3 is sent to peer. + retire_cid_frame.sequence_number = 1u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + // While CID #1's retirement is scheduled, it is not retired yet. + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid1, cid2, cid3)); + EXPECT_TRUE(retire_self_issued_cid_alarm_->IsSet()); + EXPECT_EQ(retire_self_issued_cid_alarm_->deadline(), + clock_.ApproximateNow() + connection_id_expire_timeout); + + // CID #1 is actually retired. + EXPECT_CALL(cid_manager_visitor_, OnSelfIssuedConnectionIdRetired(cid1)); + clock_.AdvanceTime(connection_id_expire_timeout); + alarm_factory_.FireAlarm(retire_self_issued_cid_alarm_); + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid2, cid3)); + EXPECT_FALSE(retire_self_issued_cid_alarm_->IsSet()); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + ScheduleMultipleConnectionIdRetirement) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + QuicConnectionId cid2 = CheckGenerate(cid1); + QuicConnectionId cid3 = CheckGenerate(cid2); + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(_)) + .Times(3) + .WillRepeatedly(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .Times(3) + .WillRepeatedly(Return(true)); + QuicTime::Delta connection_id_expire_timeout = 3 * pto_delay_; + QuicRetireConnectionIdFrame retire_cid_frame; + + // CID #1 is sent to peer. + cid_manager_.MaybeSendNewConnectionIds(); + + // CID #0's retirement is scheduled and CID #2 is sent to peer. + retire_cid_frame.sequence_number = 0u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + + clock_.AdvanceTime(connection_id_expire_timeout * 0.25); + + // CID #1's retirement is scheduled and CID #3 is sent to peer. + retire_cid_frame.sequence_number = 1u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + + // While CID #0, #1s retirement is scheduled, they are not retired yet. + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid0, cid1, cid2, cid3)); + EXPECT_TRUE(retire_self_issued_cid_alarm_->IsSet()); + EXPECT_EQ(retire_self_issued_cid_alarm_->deadline(), + clock_.ApproximateNow() + connection_id_expire_timeout * 0.75); + + // CID #0 is actually retired. + EXPECT_CALL(cid_manager_visitor_, OnSelfIssuedConnectionIdRetired(cid0)); + clock_.AdvanceTime(connection_id_expire_timeout * 0.75); + alarm_factory_.FireAlarm(retire_self_issued_cid_alarm_); + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid1, cid2, cid3)); + EXPECT_TRUE(retire_self_issued_cid_alarm_->IsSet()); + EXPECT_EQ(retire_self_issued_cid_alarm_->deadline(), + clock_.ApproximateNow() + connection_id_expire_timeout * 0.25); + + // CID #1 is actually retired. + EXPECT_CALL(cid_manager_visitor_, OnSelfIssuedConnectionIdRetired(cid1)); + clock_.AdvanceTime(connection_id_expire_timeout * 0.25); + alarm_factory_.FireAlarm(retire_self_issued_cid_alarm_); + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid2, cid3)); + EXPECT_FALSE(retire_self_issued_cid_alarm_->IsSet()); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + AllExpiredConnectionIdsAreRetiredInOneBatch) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + QuicConnectionId cid2 = CheckGenerate(cid1); + QuicConnectionId cid3 = CheckGenerate(cid2); + QuicConnectionId cid; + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(_)) + .Times(3) + .WillRepeatedly(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .Times(3) + .WillRepeatedly(Return(true)); + QuicTime::Delta connection_id_expire_timeout = 3 * pto_delay_; + QuicRetireConnectionIdFrame retire_cid_frame; + EXPECT_TRUE(cid_manager_.IsConnectionIdInUse(cid0)); + EXPECT_FALSE(cid_manager_.HasConnectionIdToConsume()); + EXPECT_FALSE(cid_manager_.ConsumeOneConnectionId().has_value()); + + // CID #1 is sent to peer. + cid_manager_.MaybeSendNewConnectionIds(); + EXPECT_TRUE(cid_manager_.IsConnectionIdInUse(cid1)); + EXPECT_TRUE(cid_manager_.HasConnectionIdToConsume()); + cid = *cid_manager_.ConsumeOneConnectionId(); + EXPECT_EQ(cid1, cid); + EXPECT_FALSE(cid_manager_.HasConnectionIdToConsume()); + + // CID #0's retirement is scheduled and CID #2 is sent to peer. + retire_cid_frame.sequence_number = 0u; + cid_manager_.OnRetireConnectionIdFrame(retire_cid_frame, pto_delay_, + &error_details_); + EXPECT_TRUE(cid_manager_.IsConnectionIdInUse(cid0)); + EXPECT_TRUE(cid_manager_.IsConnectionIdInUse(cid1)); + EXPECT_TRUE(cid_manager_.IsConnectionIdInUse(cid2)); + EXPECT_TRUE(cid_manager_.HasConnectionIdToConsume()); + cid = *cid_manager_.ConsumeOneConnectionId(); + EXPECT_EQ(cid2, cid); + EXPECT_FALSE(cid_manager_.HasConnectionIdToConsume()); + + clock_.AdvanceTime(connection_id_expire_timeout * 0.1); + + // CID #1's retirement is scheduled and CID #3 is sent to peer. + retire_cid_frame.sequence_number = 1u; + cid_manager_.OnRetireConnectionIdFrame(retire_cid_frame, pto_delay_, + &error_details_); + + { + // CID #0 & #1 are retired in a single alarm fire. + clock_.AdvanceTime(connection_id_expire_timeout); + testing::InSequence s; + EXPECT_CALL(cid_manager_visitor_, OnSelfIssuedConnectionIdRetired(cid0)); + EXPECT_CALL(cid_manager_visitor_, OnSelfIssuedConnectionIdRetired(cid1)); + alarm_factory_.FireAlarm(retire_self_issued_cid_alarm_); + EXPECT_FALSE(cid_manager_.IsConnectionIdInUse(cid0)); + EXPECT_FALSE(cid_manager_.IsConnectionIdInUse(cid1)); + EXPECT_TRUE(cid_manager_.IsConnectionIdInUse(cid2)); + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid2, cid3)); + EXPECT_FALSE(retire_self_issued_cid_alarm_->IsSet()); + } +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + ErrorWhenRetireConnectionIdNeverIssued) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + + // CID #1 is sent to peer. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .WillOnce(Return(true)); + cid_manager_.MaybeSendNewConnectionIds(); + + // CID #2 is never issued. + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 2u; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + ErrorWhenTooManyConnectionIdWaitingToBeRetired) { + // CID #0 & #1 are issued. + QuicConnectionId last_connection_id = CheckGenerate(initial_connection_id_); + EXPECT_CALL(cid_manager_visitor_, + MaybeReserveConnectionId(last_connection_id)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .WillOnce(Return(true)); + cid_manager_.MaybeSendNewConnectionIds(); + + // Add 8 connection IDs to the to-be-retired list. + + for (int i = 0; i < 8; ++i) { + last_connection_id = CheckGenerate(last_connection_id); + EXPECT_CALL(cid_manager_visitor_, + MaybeReserveConnectionId(last_connection_id)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)); + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = i; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); + } + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 8u; + // This would have push the number of to-be-retired connection IDs over its + // limit. + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsError(QUIC_TOO_MANY_CONNECTION_ID_WAITING_TO_RETIRE)); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, CannotIssueNewCidDueToVisitor) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid1)) + .WillOnce(Return(false)); + if (GetQuicReloadableFlag(quic_check_cid_collision_when_issue_new_cid)) { + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)).Times(0); + } else { + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)).Times(1); + } + cid_manager_.MaybeSendNewConnectionIds(); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + CannotIssueNewCidUponRetireConnectionIdDueToVisitor) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + QuicConnectionId cid2 = CheckGenerate(cid1); + // CID #0 & #1 are issued. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid1)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .WillOnce(Return(true)); + cid_manager_.MaybeSendNewConnectionIds(); + + // CID #2 is not issued. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid2)) + .WillOnce(Return(false)); + if (GetQuicReloadableFlag(quic_check_cid_collision_when_issue_new_cid)) { + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)).Times(0); + } else { + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)).Times(1); + } + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 1; + ASSERT_THAT(cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, pto_delay_, &error_details_), + IsQuicNoError()); +} + +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + DoNotIssueConnectionIdVoluntarilyIfOneHasIssuedForPerferredAddress) { + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid1)) + .WillOnce(Return(true)); + absl::optional new_cid_frame = + cid_manager_.MaybeIssueNewConnectionIdForPreferredAddress(); + ASSERT_TRUE(new_cid_frame.has_value()); + ASSERT_THAT(*new_cid_frame, ExpectedNewConnectionIdFrame(cid1, 1u, 0u)); + EXPECT_THAT(cid_manager_.GetUnretiredConnectionIds(), + ElementsAre(cid0, cid1)); + + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(_)).Times(0); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)).Times(0); + cid_manager_.MaybeSendNewConnectionIds(); +} + +// Regression test for b/258450534 +TEST_F(QuicSelfIssuedConnectionIdManagerTest, + RetireConnectionIdAfterConnectionIdCollisionIsFine) { + SetQuicReloadableFlag(quic_check_cid_collision_when_issue_new_cid, true); + QuicConnectionId cid0 = initial_connection_id_; + QuicConnectionId cid1 = CheckGenerate(cid0); + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid1)) + .WillOnce(Return(true)); + EXPECT_CALL(cid_manager_visitor_, SendNewConnectionId(_)) + .WillOnce(Return(true)); + cid_manager_.MaybeSendNewConnectionIds(); + + QuicRetireConnectionIdFrame retire_cid_frame(/*control_frame_id=*/0, + /*sequence_number=*/1); + QuicConnectionId cid2 = CheckGenerate(cid1); + // This happens when cid2 is aleady present in the dispatcher map. + EXPECT_CALL(cid_manager_visitor_, MaybeReserveConnectionId(cid2)) + .WillOnce(Return(false)); + std::string error_details; + EXPECT_EQ( + cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, QuicTime::Delta::FromSeconds(1), &error_details), + QUIC_NO_ERROR) + << error_details; + + if (GetQuicReloadableFlag( + quic_check_retire_cid_with_next_cid_sequence_number)) { + EXPECT_EQ( + cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, QuicTime::Delta::FromSeconds(1), &error_details), + QUIC_NO_ERROR) + << error_details; + } else { + EXPECT_EQ( + cid_manager_.OnRetireConnectionIdFrame( + retire_cid_frame, QuicTime::Delta::FromSeconds(1), &error_details), + IETF_QUIC_PROTOCOL_VIOLATION); + } +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/quic_connection_id_test.cc b/quiche/quic/core/quic_connection_id_test.cc new file mode 100644 index 000000000000..fdd4c6f96c20 --- /dev/null +++ b/quiche/quic/core/quic_connection_id_test.cc @@ -0,0 +1,181 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_id.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic::test { + +namespace { + +class QuicConnectionIdTest : public QuicTest {}; + +TEST_F(QuicConnectionIdTest, Empty) { + QuicConnectionId connection_id_empty = EmptyQuicConnectionId(); + EXPECT_TRUE(connection_id_empty.IsEmpty()); +} + +TEST_F(QuicConnectionIdTest, DefaultIsEmpty) { + QuicConnectionId connection_id_empty = QuicConnectionId(); + EXPECT_TRUE(connection_id_empty.IsEmpty()); +} + +TEST_F(QuicConnectionIdTest, NotEmpty) { + QuicConnectionId connection_id = test::TestConnectionId(1); + EXPECT_FALSE(connection_id.IsEmpty()); +} + +TEST_F(QuicConnectionIdTest, ZeroIsNotEmpty) { + QuicConnectionId connection_id = test::TestConnectionId(0); + EXPECT_FALSE(connection_id.IsEmpty()); +} + +TEST_F(QuicConnectionIdTest, Data) { + char connection_id_data[kQuicDefaultConnectionIdLength]; + memset(connection_id_data, 0x42, sizeof(connection_id_data)); + QuicConnectionId connection_id1 = + QuicConnectionId(connection_id_data, sizeof(connection_id_data)); + QuicConnectionId connection_id2 = + QuicConnectionId(connection_id_data, sizeof(connection_id_data)); + EXPECT_EQ(connection_id1, connection_id2); + EXPECT_EQ(connection_id1.length(), kQuicDefaultConnectionIdLength); + EXPECT_EQ(connection_id1.data(), connection_id1.mutable_data()); + EXPECT_EQ(0, memcmp(connection_id1.data(), connection_id2.data(), + sizeof(connection_id_data))); + EXPECT_EQ(0, memcmp(connection_id1.data(), connection_id_data, + sizeof(connection_id_data))); + connection_id2.mutable_data()[0] = 0x33; + EXPECT_NE(connection_id1, connection_id2); + static const uint8_t kNewLength = 4; + connection_id2.set_length(kNewLength); + EXPECT_EQ(kNewLength, connection_id2.length()); +} + +TEST_F(QuicConnectionIdTest, SpanData) { + QuicConnectionId connection_id = QuicConnectionId({0x01, 0x02, 0x03}); + EXPECT_EQ(connection_id.length(), 3); + QuicConnectionId empty_connection_id = + QuicConnectionId(absl::Span()); + EXPECT_EQ(empty_connection_id.length(), 0); + QuicConnectionId connection_id2 = QuicConnectionId({ + 0x01, + 0x02, + 0x03, + 0x04, + 0x05, + 0x06, + 0x07, + 0x08, + 0x09, + 0x0a, + 0x0b, + 0x0c, + 0x0d, + 0x0e, + 0x0f, + 0x10, + }); + EXPECT_EQ(connection_id2.length(), 16); +} + +TEST_F(QuicConnectionIdTest, DoubleConvert) { + QuicConnectionId connection_id64_1 = test::TestConnectionId(1); + QuicConnectionId connection_id64_2 = test::TestConnectionId(42); + QuicConnectionId connection_id64_3 = + test::TestConnectionId(UINT64_C(0xfedcba9876543210)); + EXPECT_EQ(connection_id64_1, + test::TestConnectionId( + test::TestConnectionIdToUInt64(connection_id64_1))); + EXPECT_EQ(connection_id64_2, + test::TestConnectionId( + test::TestConnectionIdToUInt64(connection_id64_2))); + EXPECT_EQ(connection_id64_3, + test::TestConnectionId( + test::TestConnectionIdToUInt64(connection_id64_3))); + EXPECT_NE(connection_id64_1, connection_id64_2); + EXPECT_NE(connection_id64_1, connection_id64_3); + EXPECT_NE(connection_id64_2, connection_id64_3); +} + +TEST_F(QuicConnectionIdTest, Hash) { + QuicConnectionId connection_id64_1 = test::TestConnectionId(1); + QuicConnectionId connection_id64_1b = test::TestConnectionId(1); + QuicConnectionId connection_id64_2 = test::TestConnectionId(42); + QuicConnectionId connection_id64_3 = + test::TestConnectionId(UINT64_C(0xfedcba9876543210)); + EXPECT_EQ(connection_id64_1.Hash(), connection_id64_1b.Hash()); + EXPECT_NE(connection_id64_1.Hash(), connection_id64_2.Hash()); + EXPECT_NE(connection_id64_1.Hash(), connection_id64_3.Hash()); + EXPECT_NE(connection_id64_2.Hash(), connection_id64_3.Hash()); + + // Verify that any two all-zero connection IDs of different lengths never + // have the same hash. + const char connection_id_bytes[255] = {}; + for (uint8_t i = 0; i < sizeof(connection_id_bytes) - 1; ++i) { + QuicConnectionId connection_id_i(connection_id_bytes, i); + for (uint8_t j = i + 1; j < sizeof(connection_id_bytes); ++j) { + QuicConnectionId connection_id_j(connection_id_bytes, j); + EXPECT_NE(connection_id_i.Hash(), connection_id_j.Hash()); + } + } +} + +TEST_F(QuicConnectionIdTest, AssignAndCopy) { + QuicConnectionId connection_id = test::TestConnectionId(1); + QuicConnectionId connection_id2 = test::TestConnectionId(2); + connection_id = connection_id2; + EXPECT_EQ(connection_id, test::TestConnectionId(2)); + EXPECT_NE(connection_id, test::TestConnectionId(1)); + connection_id = QuicConnectionId(test::TestConnectionId(1)); + EXPECT_EQ(connection_id, test::TestConnectionId(1)); + EXPECT_NE(connection_id, test::TestConnectionId(2)); +} + +TEST_F(QuicConnectionIdTest, ChangeLength) { + QuicConnectionId connection_id64_1 = test::TestConnectionId(1); + QuicConnectionId connection_id64_2 = test::TestConnectionId(2); + QuicConnectionId connection_id136_2 = test::TestConnectionId(2); + connection_id136_2.set_length(17); + memset(connection_id136_2.mutable_data() + 8, 0, 9); + char connection_id136_2_bytes[17] = {0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + QuicConnectionId connection_id136_2b(connection_id136_2_bytes, + sizeof(connection_id136_2_bytes)); + EXPECT_EQ(connection_id136_2, connection_id136_2b); + QuicConnectionId connection_id = connection_id64_1; + connection_id.set_length(17); + EXPECT_NE(connection_id64_1, connection_id); + // Check resizing big to small. + connection_id.set_length(8); + EXPECT_EQ(connection_id64_1, connection_id); + // Check resizing small to big. + connection_id.set_length(17); + memset(connection_id.mutable_data(), 0, connection_id.length()); + memcpy(connection_id.mutable_data(), connection_id64_2.data(), + connection_id64_2.length()); + EXPECT_EQ(connection_id136_2, connection_id); + EXPECT_EQ(connection_id136_2b, connection_id); + QuicConnectionId connection_id120(connection_id136_2_bytes, 15); + connection_id.set_length(15); + EXPECT_EQ(connection_id120, connection_id); + // Check resizing big to big. + QuicConnectionId connection_id2 = connection_id120; + connection_id2.set_length(17); + connection_id2.mutable_data()[15] = 0; + connection_id2.mutable_data()[16] = 0; + EXPECT_EQ(connection_id136_2, connection_id2); + EXPECT_EQ(connection_id136_2b, connection_id2); +} + +} // namespace + +} // namespace quic::test diff --git a/quiche/quic/core/quic_connection_stats.cc b/quiche/quic/core/quic_connection_stats.cc new file mode 100644 index 000000000000..fd7428560230 --- /dev/null +++ b/quiche/quic/core/quic_connection_stats.cc @@ -0,0 +1,77 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection_stats.h" + +namespace quic { + +std::ostream& operator<<(std::ostream& os, const QuicConnectionStats& s) { + os << "{ bytes_sent: " << s.bytes_sent; + os << " packets_sent: " << s.packets_sent; + os << " stream_bytes_sent: " << s.stream_bytes_sent; + os << " packets_discarded: " << s.packets_discarded; + os << " bytes_received: " << s.bytes_received; + os << " packets_received: " << s.packets_received; + os << " packets_processed: " << s.packets_processed; + os << " stream_bytes_received: " << s.stream_bytes_received; + os << " bytes_retransmitted: " << s.bytes_retransmitted; + os << " packets_retransmitted: " << s.packets_retransmitted; + os << " bytes_spuriously_retransmitted: " << s.bytes_spuriously_retransmitted; + os << " packets_spuriously_retransmitted: " + << s.packets_spuriously_retransmitted; + os << " packets_lost: " << s.packets_lost; + os << " slowstart_packets_sent: " << s.slowstart_packets_sent; + os << " slowstart_packets_lost: " << s.slowstart_packets_lost; + os << " slowstart_bytes_lost: " << s.slowstart_bytes_lost; + os << " packets_dropped: " << s.packets_dropped; + os << " undecryptable_packets_received_before_handshake_complete: " + << s.undecryptable_packets_received_before_handshake_complete; + os << " crypto_retransmit_count: " << s.crypto_retransmit_count; + os << " loss_timeout_count: " << s.loss_timeout_count; + os << " tlp_count: " << s.tlp_count; + os << " rto_count: " << s.rto_count; + os << " pto_count: " << s.pto_count; + os << " min_rtt_us: " << s.min_rtt_us; + os << " srtt_us: " << s.srtt_us; + os << " egress_mtu: " << s.egress_mtu; + os << " max_egress_mtu: " << s.max_egress_mtu; + os << " ingress_mtu: " << s.ingress_mtu; + os << " estimated_bandwidth: " << s.estimated_bandwidth; + os << " packets_reordered: " << s.packets_reordered; + os << " max_sequence_reordering: " << s.max_sequence_reordering; + os << " max_time_reordering_us: " << s.max_time_reordering_us; + os << " tcp_loss_events: " << s.tcp_loss_events; + os << " connection_creation_time: " + << s.connection_creation_time.ToDebuggingValue(); + os << " blocked_frames_received: " << s.blocked_frames_received; + os << " blocked_frames_sent: " << s.blocked_frames_sent; + os << " num_connectivity_probing_received: " + << s.num_connectivity_probing_received; + os << " num_path_response_received: " << s.num_path_response_received; + os << " retry_packet_processed: " + << (s.retry_packet_processed ? "yes" : "no"); + os << " num_coalesced_packets_received: " << s.num_coalesced_packets_received; + os << " num_coalesced_packets_processed: " + << s.num_coalesced_packets_processed; + os << " num_ack_aggregation_epochs: " << s.num_ack_aggregation_epochs; + os << " key_update_count: " << s.key_update_count; + os << " num_failed_authentication_packets_received: " + << s.num_failed_authentication_packets_received; + os << " num_tls_server_zero_rtt_packets_received_after_discarding_decrypter: " + << s.num_tls_server_zero_rtt_packets_received_after_discarding_decrypter; + os << " address_validated_via_decrypting_packet: " + << s.address_validated_via_decrypting_packet; + os << " address_validated_via_token: " << s.address_validated_via_token; + os << " server_preferred_address_validated: " + << s.server_preferred_address_validated; + os << " failed_to_validate_server_preferred_address: " + << s.failed_to_validate_server_preferred_address; + os << " num_duplicated_packets_sent_to_server_preferred_address: " + << s.num_duplicated_packets_sent_to_server_preferred_address; + os << " }"; + + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_connection_stats.h b/quiche/quic/core/quic_connection_stats.h new file mode 100644 index 000000000000..336435ea7bfb --- /dev/null +++ b/quiche/quic/core/quic_connection_stats.h @@ -0,0 +1,253 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONNECTION_STATS_H_ +#define QUICHE_QUIC_CORE_QUIC_CONNECTION_STATS_H_ + +#include +#include + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_time_accumulator.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Structure to hold stats for a QuicConnection. +struct QUIC_EXPORT_PRIVATE QuicConnectionStats { + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicConnectionStats& s); + + QuicByteCount bytes_sent = 0; // Includes retransmissions. + QuicPacketCount packets_sent = 0; + // Non-retransmitted bytes sent in a stream frame. + QuicByteCount stream_bytes_sent = 0; + // Packets serialized and discarded before sending. + QuicPacketCount packets_discarded = 0; + + // These include version negotiation and public reset packets, which do not + // have packet numbers or frame data. + QuicByteCount bytes_received = 0; // Includes duplicate data for a stream. + // Includes packets which were not processable. + QuicPacketCount packets_received = 0; + // Excludes packets which were not processable. + QuicPacketCount packets_processed = 0; + QuicByteCount stream_bytes_received = 0; // Bytes received in a stream frame. + + QuicByteCount bytes_retransmitted = 0; + QuicPacketCount packets_retransmitted = 0; + + QuicByteCount bytes_spuriously_retransmitted = 0; + QuicPacketCount packets_spuriously_retransmitted = 0; + // Number of packets abandoned as lost by the loss detection algorithm. + QuicPacketCount packets_lost = 0; + QuicPacketCount packet_spuriously_detected_lost = 0; + + // The sum of loss detection response times of all lost packets, in number of + // round trips. + // Given a packet detected as lost: + // T(S) T(1Rtt) T(D) + // |_________________________________|_______| + // Where + // T(S) is the time when the packet is sent. + // T(1Rtt) is one rtt after T(S), using the rtt at the time of detection. + // T(D) is the time of detection, i.e. when the packet is declared as lost. + // The loss detection response time is defined as + // (T(D) - T(S)) / (T(1Rtt) - T(S)) + // + // The average loss detection response time is this number divided by + // |packets_lost|. Smaller result means detection is faster. + float total_loss_detection_response_time = 0.0; + + // Number of times this connection went through the slow start phase. + uint32_t slowstart_count = 0; + // Number of round trips spent in slow start. + uint32_t slowstart_num_rtts = 0; + // Number of packets sent in slow start. + QuicPacketCount slowstart_packets_sent = 0; + // Number of bytes sent in slow start. + QuicByteCount slowstart_bytes_sent = 0; + // Number of packets lost exiting slow start. + QuicPacketCount slowstart_packets_lost = 0; + // Number of bytes lost exiting slow start. + QuicByteCount slowstart_bytes_lost = 0; + // Time spent in slow start. Populated for BBRv1 and BBRv2. + QuicTimeAccumulator slowstart_duration; + + // Number of PROBE_BW cycles. Populated for BBRv1 and BBRv2. + uint32_t bbr_num_cycles = 0; + // Number of PROBE_BW cycles shortened for reno coexistence. BBRv2 only. + uint32_t bbr_num_short_cycles_for_reno_coexistence = 0; + // Whether BBR exited STARTUP due to excessive loss. Populated for BBRv1 and + // BBRv2. + bool bbr_exit_startup_due_to_loss = false; + + QuicPacketCount packets_dropped = 0; // Duplicate or less than least unacked. + + // Packets that failed to decrypt when they were first received, + // before the handshake was complete. + QuicPacketCount undecryptable_packets_received_before_handshake_complete = 0; + + size_t crypto_retransmit_count = 0; + // Count of times the loss detection alarm fired. At least one packet should + // be lost when the alarm fires. + size_t loss_timeout_count = 0; + size_t tlp_count = 0; + size_t rto_count = 0; // Count of times the rto timer fired. + size_t pto_count = 0; + + int64_t min_rtt_us = 0; // Minimum RTT in microseconds. + int64_t srtt_us = 0; // Smoothed RTT in microseconds. + int64_t cwnd_bootstrapping_rtt_us = 0; // RTT used in cwnd_bootstrapping. + // The connection's |long_term_mtu_| used for sending packets, populated by + // QuicConnection::GetStats(). + QuicByteCount egress_mtu = 0; + // The maximum |long_term_mtu_| the connection ever used. + QuicByteCount max_egress_mtu = 0; + // Size of the largest packet received from the peer, populated by + // QuicConnection::GetStats(). + QuicByteCount ingress_mtu = 0; + QuicBandwidth estimated_bandwidth = QuicBandwidth::Zero(); + + // Reordering stats for received packets. + // Number of packets received out of packet number order. + QuicPacketCount packets_reordered = 0; + // Maximum reordering observed in packet number space. + QuicPacketCount max_sequence_reordering = 0; + // Maximum reordering observed in microseconds + int64_t max_time_reordering_us = 0; + + // Maximum sequence reordering observed from acked packets. + QuicPacketCount sent_packets_max_sequence_reordering = 0; + // Number of times that a packet is not detected as lost per reordering_shift, + // but would have been if the reordering_shift increases by one. + QuicPacketCount sent_packets_num_borderline_time_reorderings = 0; + + // The following stats are used only in TcpCubicSender. + // The number of loss events from TCP's perspective. Each loss event includes + // one or more lost packets. + uint32_t tcp_loss_events = 0; + + // Creation time, as reported by the QuicClock. + QuicTime connection_creation_time = QuicTime::Zero(); + + // Handshake completion time. + QuicTime handshake_completion_time = QuicTime::Zero(); + + uint64_t blocked_frames_received = 0; + uint64_t blocked_frames_sent = 0; + + // Number of connectivity probing packets received by this connection. + uint64_t num_connectivity_probing_received = 0; + + // Number of PATH_RESPONSE frame received by this connection. + uint64_t num_path_response_received = 0; + + // Whether a RETRY packet was successfully processed. + bool retry_packet_processed = false; + + // Number of received coalesced packets. + uint64_t num_coalesced_packets_received = 0; + // Number of successfully processed coalesced packets. + uint64_t num_coalesced_packets_processed = 0; + // Number of ack aggregation epochs. For the same number of bytes acked, the + // smaller this value, the more ack aggregation is going on. + uint64_t num_ack_aggregation_epochs = 0; + + // Whether overshooting is detected (and pacing rate decreases) during start + // up with network parameters adjusted. + bool overshooting_detected_with_network_parameters_adjusted = false; + + // Whether there is any non app-limited bandwidth sample. + bool has_non_app_limited_sample = false; + + // Packet number of first decrypted packet. + QuicPacketNumber first_decrypted_packet; + + // Max consecutive retransmission timeout before making forward progress. + size_t max_consecutive_rto_with_forward_progress = 0; + + // Number of times when the connection tries to send data but gets throttled + // by amplification factor. + size_t num_amplification_throttling = 0; + + // Number of key phase updates that have occurred. In the case of a locally + // initiated key update, this is incremented when the keys are updated, before + // the peer has acknowledged the key update. + uint32_t key_update_count = 0; + + // Counts the number of undecryptable packets received across all keys. Does + // not include packets where a decryption key for that level was absent. + QuicPacketCount num_failed_authentication_packets_received = 0; + + // Counts the number of QUIC+TLS 0-RTT packets received after 0-RTT decrypter + // was discarded, only on server connections. + QuicPacketCount + num_tls_server_zero_rtt_packets_received_after_discarding_decrypter = 0; + + // Counts the number of packets received with each Explicit Congestion + // Notification (ECN) codepoint, except Not-ECT. There is one counter across + // all packet number spaces. + QuicEcnCounts num_ecn_marks_received; + + // Counts the number of ACK frames sent with ECN counts. + QuicPacketCount num_ack_frames_sent_with_ecn = 0; + + // True if address is validated via decrypting HANDSHAKE or 1-RTT packet. + bool address_validated_via_decrypting_packet = false; + + // True if address is validated via validating token received in INITIAL + // packet. + bool address_validated_via_token = false; + + size_t ping_frames_sent = 0; + + // Number of detected peer address changes which changes to a peer address + // validated by earlier path validation. + size_t num_peer_migration_to_proactively_validated_address = 0; + // Number of detected peer address changes which triggers reverse path + // validation. + size_t num_reverse_path_validtion_upon_migration = 0; + // Number of detected peer migrations which either succeed reverse path + // validation or no need to be validated. + size_t num_validated_peer_migration = 0; + // Number of detected peer migrations which triggered reverse path validation + // and failed and fell back to the old path. + size_t num_invalid_peer_migration = 0; + // Number of detected peer migrations which triggered reverse path validation + // which was canceled because the peer migrated again. Such migration is also + // counted as invalid peer migration. + size_t num_peer_migration_while_validating_default_path = 0; + // Number of NEW_CONNECTION_ID frames sent. + size_t num_new_connection_id_sent = 0; + // Number of RETIRE_CONNECTION_ID frames sent. + size_t num_retire_connection_id_sent = 0; + + bool server_preferred_address_validated = false; + bool failed_to_validate_server_preferred_address = false; + // Number of duplicated packets that have been sent to server preferred + // address while the validation is pending. + size_t num_duplicated_packets_sent_to_server_preferred_address = 0; + + struct QUIC_NO_EXPORT TlsServerOperationStats { + bool success = false; + // If the operation is performed asynchronously, how long did it take. + // Zero() for synchronous operations. + QuicTime::Delta async_latency = QuicTime::Delta::Zero(); + }; + + // The TLS server op stats only have values when the corresponding operation + // is performed by TlsServerHandshaker. If an operation is done within + // BoringSSL, e.g. ticket decrypted without using + // TlsServerHandshaker::SessionTicketOpen, it will not be recorded here. + absl::optional tls_server_select_cert_stats; + absl::optional tls_server_compute_signature_stats; + absl::optional tls_server_decrypt_ticket_stats; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONNECTION_STATS_H_ diff --git a/quiche/quic/core/quic_connection_test.cc b/quiche/quic/core/quic_connection_test.cc new file mode 100644 index 000000000000..7853992a9749 --- /dev/null +++ b/quiche/quic/core/quic_connection_test.cc @@ -0,0 +1,17517 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_connection.h" + +#include + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/congestion_control/loss_detection_interface.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/frames/quic_connection_close_frame.h" +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_path_response_frame.h" +#include "quiche/quic/core/frames/quic_rst_stream_frame.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_coalesced_packet_peer.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_packet_creator_peer.h" +#include "quiche/quic/test_tools/quic_path_validator_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_data_producer.h" +#include "quiche/quic/test_tools/simple_session_notifier.h" +#include "quiche/common/simple_buffer_allocator.h" + +using testing::_; +using testing::AnyNumber; +using testing::AtLeast; +using testing::DoAll; +using testing::ElementsAre; +using testing::Ge; +using testing::IgnoreResult; +using testing::InSequence; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::Lt; +using testing::Ref; +using testing::Return; +using testing::SaveArg; +using testing::SetArgPointee; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +const char data1[] = "foo data"; +const char data2[] = "bar data"; + +const bool kHasStopWaiting = true; + +const int kDefaultRetransmissionTimeMs = 500; + +DiversificationNonce kTestDiversificationNonce = { + 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', + 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', + 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', 'a', 'b', +}; + +const StatelessResetToken kTestStatelessResetToken{ + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f}; + +const QuicSocketAddress kPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), + /*port=*/12345); +const QuicSocketAddress kSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), + /*port=*/443); +const QuicSocketAddress kServerPreferredAddress = QuicSocketAddress( + []() { + QuicIpAddress address; + address.FromString("2604:31c0::"); + return address; + }(), + /*port=*/443); + +QuicStreamId GetNthClientInitiatedStreamId(int n, + QuicTransportVersion version) { + return QuicUtils::GetFirstBidirectionalStreamId(version, + Perspective::IS_CLIENT) + + n * 2; +} + +QuicLongHeaderType EncryptionlevelToLongHeaderType(EncryptionLevel level) { + switch (level) { + case ENCRYPTION_INITIAL: + return INITIAL; + case ENCRYPTION_HANDSHAKE: + return HANDSHAKE; + case ENCRYPTION_ZERO_RTT: + return ZERO_RTT_PROTECTED; + case ENCRYPTION_FORWARD_SECURE: + QUICHE_DCHECK(false); + return INVALID_PACKET_TYPE; + default: + QUICHE_DCHECK(false); + return INVALID_PACKET_TYPE; + } +} + +// A TaggingEncrypterWithConfidentialityLimit is a TaggingEncrypter that allows +// specifying the confidentiality limit on the maximum number of packets that +// may be encrypted per key phase in TLS+QUIC. +class TaggingEncrypterWithConfidentialityLimit : public TaggingEncrypter { + public: + TaggingEncrypterWithConfidentialityLimit( + uint8_t tag, QuicPacketCount confidentiality_limit) + : TaggingEncrypter(tag), confidentiality_limit_(confidentiality_limit) {} + + QuicPacketCount GetConfidentialityLimit() const override { + return confidentiality_limit_; + } + + private: + QuicPacketCount confidentiality_limit_; +}; + +class StrictTaggingDecrypterWithIntegrityLimit : public StrictTaggingDecrypter { + public: + StrictTaggingDecrypterWithIntegrityLimit(uint8_t tag, + QuicPacketCount integrity_limit) + : StrictTaggingDecrypter(tag), integrity_limit_(integrity_limit) {} + + QuicPacketCount GetIntegrityLimit() const override { + return integrity_limit_; + } + + private: + QuicPacketCount integrity_limit_; +}; + +class TestConnectionHelper : public QuicConnectionHelperInterface { + public: + TestConnectionHelper(MockClock* clock, MockRandom* random_generator) + : clock_(clock), random_generator_(random_generator) { + clock_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + TestConnectionHelper(const TestConnectionHelper&) = delete; + TestConnectionHelper& operator=(const TestConnectionHelper&) = delete; + + // QuicConnectionHelperInterface + const QuicClock* GetClock() const override { return clock_; } + + QuicRandom* GetRandomGenerator() override { return random_generator_; } + + quiche::QuicheBufferAllocator* GetStreamSendBufferAllocator() override { + return &buffer_allocator_; + } + + private: + MockClock* clock_; + MockRandom* random_generator_; + quiche::SimpleBufferAllocator buffer_allocator_; +}; + +class TestConnection : public QuicConnection { + public: + TestConnection(QuicConnectionId connection_id, + QuicSocketAddress initial_self_address, + QuicSocketAddress initial_peer_address, + TestConnectionHelper* helper, TestAlarmFactory* alarm_factory, + TestPacketWriter* writer, Perspective perspective, + ParsedQuicVersion version, + ConnectionIdGeneratorInterface& generator) + : QuicConnection(connection_id, initial_self_address, + initial_peer_address, helper, alarm_factory, writer, + /* owns_writer= */ false, perspective, + SupportedVersions(version), generator), + notifier_(nullptr) { + writer->set_perspective(perspective); + SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + SetDataProducer(&producer_); + ON_CALL(*this, OnSerializedPacket(_)) + .WillByDefault([this](SerializedPacket packet) { + QuicConnection::OnSerializedPacket(std::move(packet)); + }); + } + TestConnection(const TestConnection&) = delete; + TestConnection& operator=(const TestConnection&) = delete; + + MOCK_METHOD(void, OnSerializedPacket, (SerializedPacket packet), (override)); + + void OnEffectivePeerMigrationValidated(bool is_migration_linkable) override { + QuicConnection::OnEffectivePeerMigrationValidated(is_migration_linkable); + if (is_migration_linkable) { + num_linkable_client_migration_++; + } else { + num_unlinkable_client_migration_++; + } + } + + uint32_t num_unlinkable_client_migration() const { + return num_unlinkable_client_migration_; + } + + uint32_t num_linkable_client_migration() const { + return num_linkable_client_migration_; + } + + void SetSendAlgorithm(SendAlgorithmInterface* send_algorithm) { + QuicConnectionPeer::SetSendAlgorithm(this, send_algorithm); + } + + void SetLossAlgorithm(LossDetectionInterface* loss_algorithm) { + QuicConnectionPeer::SetLossAlgorithm(this, loss_algorithm); + } + + void SendPacket(EncryptionLevel /*level*/, uint64_t packet_number, + std::unique_ptr packet, + HasRetransmittableData retransmittable, bool has_ack, + bool has_pending_frames) { + ScopedPacketFlusher flusher(this); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + QuicConnectionPeer::GetFramer(this)->EncryptPayload( + ENCRYPTION_INITIAL, QuicPacketNumber(packet_number), *packet, + buffer, kMaxOutgoingPacketSize); + SerializedPacket serialized_packet( + QuicPacketNumber(packet_number), PACKET_4BYTE_PACKET_NUMBER, buffer, + encrypted_length, has_ack, has_pending_frames); + serialized_packet.peer_address = kPeerAddress; + if (retransmittable == HAS_RETRANSMITTABLE_DATA) { + serialized_packet.retransmittable_frames.push_back( + QuicFrame(QuicPingFrame())); + } + OnSerializedPacket(std::move(serialized_packet)); + } + + QuicConsumedData SaveAndSendStreamData(QuicStreamId id, + absl::string_view data, + QuicStreamOffset offset, + StreamSendingState state) { + return SaveAndSendStreamData(id, data, offset, state, NOT_RETRANSMISSION); + } + + QuicConsumedData SaveAndSendStreamData(QuicStreamId id, + absl::string_view data, + QuicStreamOffset offset, + StreamSendingState state, + TransmissionType transmission_type) { + ScopedPacketFlusher flusher(this); + producer_.SaveStreamData(id, data); + if (notifier_ != nullptr) { + return notifier_->WriteOrBufferData(id, data.length(), state, + transmission_type); + } + return QuicConnection::SendStreamData(id, data.length(), offset, state); + } + + QuicConsumedData SendStreamDataWithString(QuicStreamId id, + absl::string_view data, + QuicStreamOffset offset, + StreamSendingState state) { + ScopedPacketFlusher flusher(this); + if (!QuicUtils::IsCryptoStreamId(transport_version(), id) && + this->encryption_level() == ENCRYPTION_INITIAL) { + this->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + if (perspective() == Perspective::IS_CLIENT && !IsHandshakeComplete()) { + OnHandshakeComplete(); + } + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(this); + } + } + return SaveAndSendStreamData(id, data, offset, state); + } + + QuicConsumedData SendApplicationDataAtLevel(EncryptionLevel encryption_level, + QuicStreamId id, + absl::string_view data, + QuicStreamOffset offset, + StreamSendingState state) { + ScopedPacketFlusher flusher(this); + QUICHE_DCHECK(encryption_level >= ENCRYPTION_ZERO_RTT); + SetEncrypter(encryption_level, + std::make_unique(encryption_level)); + SetDefaultEncryptionLevel(encryption_level); + return SaveAndSendStreamData(id, data, offset, state); + } + + QuicConsumedData SendStreamData3() { + return SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, transport_version()), "food", 0, + NO_FIN); + } + + QuicConsumedData SendStreamData5() { + return SendStreamDataWithString( + GetNthClientInitiatedStreamId(2, transport_version()), "food2", 0, + NO_FIN); + } + + // Ensures the connection can write stream data before writing. + QuicConsumedData EnsureWritableAndSendStreamData5() { + EXPECT_TRUE(CanWrite(HAS_RETRANSMITTABLE_DATA)); + return SendStreamData5(); + } + + // The crypto stream has special semantics so that it is not blocked by a + // congestion window limitation, and also so that it gets put into a separate + // packet (so that it is easier to reason about a crypto frame not being + // split needlessly across packet boundaries). As a result, we have separate + // tests for some cases for this stream. + QuicConsumedData SendCryptoStreamData() { + QuicStreamOffset offset = 0; + absl::string_view data("chlo"); + if (!QuicVersionUsesCryptoFrames(transport_version())) { + return SendCryptoDataWithString(data, offset); + } + producer_.SaveCryptoData(ENCRYPTION_INITIAL, offset, data); + size_t bytes_written; + if (notifier_) { + bytes_written = + notifier_->WriteCryptoData(ENCRYPTION_INITIAL, data.length(), offset); + } else { + bytes_written = QuicConnection::SendCryptoData(ENCRYPTION_INITIAL, + data.length(), offset); + } + return QuicConsumedData(bytes_written, /*fin_consumed*/ false); + } + + QuicConsumedData SendCryptoDataWithString(absl::string_view data, + QuicStreamOffset offset) { + return SendCryptoDataWithString(data, offset, ENCRYPTION_INITIAL); + } + + QuicConsumedData SendCryptoDataWithString(absl::string_view data, + QuicStreamOffset offset, + EncryptionLevel encryption_level) { + if (!QuicVersionUsesCryptoFrames(transport_version())) { + return SendStreamDataWithString( + QuicUtils::GetCryptoStreamId(transport_version()), data, offset, + NO_FIN); + } + producer_.SaveCryptoData(encryption_level, offset, data); + size_t bytes_written; + if (notifier_) { + bytes_written = + notifier_->WriteCryptoData(encryption_level, data.length(), offset); + } else { + bytes_written = QuicConnection::SendCryptoData(encryption_level, + data.length(), offset); + } + return QuicConsumedData(bytes_written, /*fin_consumed*/ false); + } + + void set_version(ParsedQuicVersion version) { + QuicConnectionPeer::GetFramer(this)->set_version(version); + } + + void SetSupportedVersions(const ParsedQuicVersionVector& versions) { + QuicConnectionPeer::GetFramer(this)->SetSupportedVersions(versions); + writer()->SetSupportedVersions(versions); + } + + // This should be called before setting customized encrypters/decrypters for + // connection and peer creator. + void set_perspective(Perspective perspective) { + writer()->set_perspective(perspective); + QuicConnectionPeer::ResetPeerIssuedConnectionIdManager(this); + QuicConnectionPeer::SetPerspective(this, perspective); + QuicSentPacketManagerPeer::SetPerspective( + QuicConnectionPeer::GetSentPacketManager(this), perspective); + QuicConnectionPeer::GetFramer(this)->SetInitialObfuscators( + TestConnectionId()); + } + + // Enable path MTU discovery. Assumes that the test is performed from the + // server perspective and the higher value of MTU target is used. + void EnablePathMtuDiscovery(MockSendAlgorithm* send_algorithm) { + ASSERT_EQ(Perspective::IS_SERVER, perspective()); + + if (GetQuicReloadableFlag(quic_enable_mtu_discovery_at_server)) { + OnConfigNegotiated(); + } else { + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kMTUH); + config.SetInitialReceivedConnectionOptions(connection_options); + EXPECT_CALL(*send_algorithm, SetFromConfig(_, _)); + SetFromConfig(config); + } + + // Normally, the pacing would be disabled in the test, but calling + // SetFromConfig enables it. Set nearly-infinite bandwidth to make the + // pacing algorithm work. + EXPECT_CALL(*send_algorithm, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Infinite())); + } + + TestAlarmFactory::TestAlarm* GetAckAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetAckAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetPingAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetPingAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetRetransmissionAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetRetransmissionAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetSendAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetSendAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetTimeoutAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetIdleNetworkDetectorAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetMtuDiscoveryAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetMtuDiscoveryAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetProcessUndecryptablePacketsAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetProcessUndecryptablePacketsAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetDiscardPreviousOneRttKeysAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetDiscardZeroRttDecryptionKeysAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetDiscardZeroRttDecryptionKeysAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetBlackholeDetectorAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetBlackholeDetectorAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetRetirePeerIssuedConnectionIdAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetRetirePeerIssuedConnectionIdAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetRetireSelfIssuedConnectionIdAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetRetireSelfIssuedConnectionIdAlarm(this)); + } + + TestAlarmFactory::TestAlarm* GetMultiPortProbingAlarm() { + return reinterpret_cast( + QuicConnectionPeer::GetMultiPortProbingAlarm(this)); + } + + void PathDegradingTimeout() { + QUICHE_DCHECK(PathDegradingDetectionInProgress()); + GetBlackholeDetectorAlarm()->Fire(); + } + + bool PathDegradingDetectionInProgress() { + return QuicConnectionPeer::GetPathDegradingDeadline(this).IsInitialized(); + } + + bool BlackholeDetectionInProgress() { + return QuicConnectionPeer::GetBlackholeDetectionDeadline(this) + .IsInitialized(); + } + + bool PathMtuReductionDetectionInProgress() { + return QuicConnectionPeer::GetPathMtuReductionDetectionDeadline(this) + .IsInitialized(); + } + + QuicByteCount GetBytesInFlight() { + return QuicConnectionPeer::GetSentPacketManager(this)->GetBytesInFlight(); + } + + void set_notifier(SimpleSessionNotifier* notifier) { notifier_ = notifier; } + + void ReturnEffectivePeerAddressForNextPacket(const QuicSocketAddress& addr) { + next_effective_peer_addr_ = std::make_unique(addr); + } + + void SendOrQueuePacket(SerializedPacket packet) override { + QuicConnection::SendOrQueuePacket(std::move(packet)); + self_address_on_default_path_while_sending_packet_ = self_address(); + } + + QuicSocketAddress self_address_on_default_path_while_sending_packet() { + return self_address_on_default_path_while_sending_packet_; + } + + SimpleDataProducer* producer() { return &producer_; } + + using QuicConnection::active_effective_peer_migration_type; + using QuicConnection::IsCurrentPacketConnectivityProbing; + using QuicConnection::SelectMutualVersion; + using QuicConnection::set_defer_send_in_response_to_packets; + + protected: + QuicSocketAddress GetEffectivePeerAddressFromCurrentPacket() const override { + if (next_effective_peer_addr_) { + return *std::move(next_effective_peer_addr_); + } + return QuicConnection::GetEffectivePeerAddressFromCurrentPacket(); + } + + private: + TestPacketWriter* writer() { + return static_cast(QuicConnection::writer()); + } + + SimpleDataProducer producer_; + + SimpleSessionNotifier* notifier_; + + std::unique_ptr next_effective_peer_addr_; + + QuicSocketAddress self_address_on_default_path_while_sending_packet_; + + uint32_t num_unlinkable_client_migration_ = 0; + + uint32_t num_linkable_client_migration_ = 0; +}; + +enum class AckResponse { kDefer, kImmediate }; + +// Run tests with combinations of {ParsedQuicVersion, AckResponse}. +struct TestParams { + TestParams(ParsedQuicVersion version, AckResponse ack_response, + bool no_stop_waiting) + : version(version), + ack_response(ack_response), + no_stop_waiting(no_stop_waiting) {} + + ParsedQuicVersion version; + AckResponse ack_response; + bool no_stop_waiting; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + ParsedQuicVersionToString(p.version), "_", + (p.ack_response == AckResponse::kDefer ? "defer" : "immediate"), "_", + (p.no_stop_waiting ? "No" : ""), "StopWaiting"); +} + +// Constructs various test permutations. +std::vector GetTestParams() { + QuicFlagSaver flags; + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (size_t i = 0; i < all_supported_versions.size(); ++i) { + for (AckResponse ack_response : + {AckResponse::kDefer, AckResponse::kImmediate}) { + params.push_back( + TestParams(all_supported_versions[i], ack_response, true)); + if (!all_supported_versions[i].HasIetfInvariantHeader()) { + params.push_back( + TestParams(all_supported_versions[i], ack_response, false)); + } + } + } + return params; +} + +class QuicConnectionTest : public QuicTestWithParam { + public: + // For tests that do silent connection closes, no such packet is generated. In + // order to verify the contents of the OnConnectionClosed upcall, EXPECTs + // should invoke this method, saving the frame, and then the test can verify + // the contents. + void SaveConnectionCloseFrame(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource /*source*/) { + saved_connection_close_frame_ = frame; + connection_close_frame_count_++; + } + + protected: + QuicConnectionTest() + : connection_id_(TestConnectionId()), + framer_(SupportedVersions(version()), QuicTime::Zero(), + Perspective::IS_CLIENT, connection_id_.length()), + send_algorithm_(new StrictMock), + loss_algorithm_(new MockLossAlgorithm()), + helper_(new TestConnectionHelper(&clock_, &random_generator_)), + alarm_factory_(new TestAlarmFactory()), + peer_framer_(SupportedVersions(version()), QuicTime::Zero(), + Perspective::IS_SERVER, connection_id_.length()), + peer_creator_(connection_id_, &peer_framer_, + /*delegate=*/nullptr), + writer_( + new TestPacketWriter(version(), &clock_, Perspective::IS_CLIENT)), + connection_(connection_id_, kSelfAddress, kPeerAddress, helper_.get(), + alarm_factory_.get(), writer_.get(), Perspective::IS_CLIENT, + version(), connection_id_generator_), + creator_(QuicConnectionPeer::GetPacketCreator(&connection_)), + manager_(QuicConnectionPeer::GetSentPacketManager(&connection_)), + frame1_(0, false, 0, absl::string_view(data1)), + frame2_(0, false, 3, absl::string_view(data2)), + crypto_frame_(ENCRYPTION_INITIAL, 0, absl::string_view(data1)), + packet_number_length_(PACKET_4BYTE_PACKET_NUMBER), + connection_id_included_(CONNECTION_ID_PRESENT), + notifier_(&connection_), + connection_close_frame_count_(0) { + QUIC_DVLOG(2) << "QuicConnectionTest(" << PrintToString(GetParam()) << ")"; + connection_.set_defer_send_in_response_to_packets(GetParam().ack_response == + AckResponse::kDefer); + framer_.SetInitialObfuscators(TestConnectionId()); + connection_.InstallInitialCrypters(TestConnectionId()); + CrypterPair crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_SERVER, version(), + TestConnectionId(), &crypters); + peer_creator_.SetEncrypter(ENCRYPTION_INITIAL, + std::move(crypters.encrypter)); + if (version().KnowsWhichDecrypterToUse()) { + peer_framer_.InstallDecrypter(ENCRYPTION_INITIAL, + std::move(crypters.decrypter)); + } else { + peer_framer_.SetDecrypter(ENCRYPTION_INITIAL, + std::move(crypters.decrypter)); + } + for (EncryptionLevel level : + {ENCRYPTION_ZERO_RTT, ENCRYPTION_FORWARD_SECURE}) { + peer_creator_.SetEncrypter(level, + std::make_unique(level)); + } + QuicFramerPeer::SetLastSerializedServerConnectionId( + QuicConnectionPeer::GetFramer(&connection_), connection_id_); + QuicFramerPeer::SetLastWrittenPacketNumberLength( + QuicConnectionPeer::GetFramer(&connection_), packet_number_length_); + if (version().HasIetfInvariantHeader()) { + EXPECT_TRUE(QuicConnectionPeer::GetNoStopWaitingFrames(&connection_)); + } else { + QuicConnectionPeer::SetNoStopWaitingFrames(&connection_, + GetParam().no_stop_waiting); + } + QuicStreamId stream_id; + if (QuicVersionUsesCryptoFrames(version().transport_version)) { + stream_id = QuicUtils::GetFirstBidirectionalStreamId( + version().transport_version, Perspective::IS_CLIENT); + } else { + stream_id = QuicUtils::GetCryptoStreamId(version().transport_version); + } + frame1_.stream_id = stream_id; + frame2_.stream_id = stream_id; + connection_.set_visitor(&visitor_); + connection_.SetSessionNotifier(¬ifier_); + connection_.set_notifier(¬ifier_); + connection_.SetSendAlgorithm(send_algorithm_); + connection_.SetLossAlgorithm(loss_algorithm_.get()); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnPacketNeutered(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)) + .Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, GetCongestionControlType()) + .Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, GetCongestionControlType()) + .Times(AnyNumber()); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnPacketDecrypted(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnCanWrite()) + .WillRepeatedly(Invoke(¬ifier_, &SimpleSessionNotifier::OnCanWrite)); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(false)); + EXPECT_CALL(visitor_, OnCongestionWindowChange(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnOneRttPacketAcknowledged()) + .Times(testing::AtMost(1)); + EXPECT_CALL(*loss_algorithm_, GetLossTimeout()) + .WillRepeatedly(Return(QuicTime::Zero())); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_START)); + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + } else { + connection_.SetAlternativeDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE), + false); + } + peer_creator_.SetDefaultPeerAddress(kSelfAddress); + } + + QuicConnectionTest(const QuicConnectionTest&) = delete; + QuicConnectionTest& operator=(const QuicConnectionTest&) = delete; + + ParsedQuicVersion version() { return GetParam().version; } + + QuicStopWaitingFrame* stop_waiting() { + QuicConnectionPeer::PopulateStopWaitingFrame(&connection_, &stop_waiting_); + return &stop_waiting_; + } + + QuicPacketNumber least_unacked() { + if (writer_->stop_waiting_frames().empty()) { + return QuicPacketNumber(); + } + return writer_->stop_waiting_frames()[0].least_unacked; + } + + void SetClientConnectionId(const QuicConnectionId& client_connection_id) { + connection_.set_client_connection_id(client_connection_id); + writer_->framer()->framer()->SetExpectedClientConnectionIdLength( + client_connection_id.length()); + } + + void SetDecrypter(EncryptionLevel level, + std::unique_ptr decrypter) { + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter(level, std::move(decrypter)); + } else { + connection_.SetAlternativeDecrypter(level, std::move(decrypter), false); + } + } + + void ProcessPacket(uint64_t number) { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(number); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + } + + void ProcessReceivedPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) { + connection_.ProcessUdpPacket(self_address, peer_address, packet); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + } + + QuicFrame MakeCryptoFrame() const { + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + return QuicFrame(new QuicCryptoFrame(crypto_frame_)); + } + return QuicFrame(QuicStreamFrame( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), false, + 0u, absl::string_view())); + } + + void ProcessFramePacket(QuicFrame frame) { + ProcessFramePacketWithAddresses(frame, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + } + + void ProcessFramePacketWithAddresses(QuicFrame frame, + QuicSocketAddress self_address, + QuicSocketAddress peer_address, + EncryptionLevel level) { + QuicFrames frames; + frames.push_back(QuicFrame(frame)); + return ProcessFramesPacketWithAddresses(frames, self_address, peer_address, + level); + } + + std::unique_ptr ConstructPacket(QuicFrames frames, + EncryptionLevel level, + char* buffer, + size_t buffer_len) { + QUICHE_DCHECK(peer_framer_.HasEncrypterOfEncryptionLevel(level)); + peer_creator_.set_encryption_level(level); + QuicPacketCreatorPeer::SetSendVersionInPacket( + &peer_creator_, + level < ENCRYPTION_FORWARD_SECURE && + connection_.perspective() == Perspective::IS_SERVER); + + SerializedPacket serialized_packet = + QuicPacketCreatorPeer::SerializeAllFrames(&peer_creator_, frames, + buffer, buffer_len); + return std::make_unique( + serialized_packet.encrypted_buffer, serialized_packet.encrypted_length, + clock_.Now()); + } + + void ProcessFramesPacketWithAddresses(QuicFrames frames, + QuicSocketAddress self_address, + QuicSocketAddress peer_address, + EncryptionLevel level) { + char buffer[kMaxOutgoingPacketSize]; + connection_.ProcessUdpPacket( + self_address, peer_address, + *ConstructPacket(std::move(frames), level, buffer, + kMaxOutgoingPacketSize)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + } + + // Bypassing the packet creator is unrealistic, but allows us to process + // packets the QuicPacketCreator won't allow us to create. + void ForceProcessFramePacket(QuicFrame frame) { + QuicFrames frames; + frames.push_back(QuicFrame(frame)); + bool send_version = connection_.perspective() == Perspective::IS_SERVER; + if (connection_.version().KnowsWhichDecrypterToUse()) { + send_version = true; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(&peer_creator_, send_version); + QuicPacketHeader header; + QuicPacketCreatorPeer::FillPacketHeader(&peer_creator_, &header); + char encrypted_buffer[kMaxOutgoingPacketSize]; + size_t length = peer_framer_.BuildDataPacket( + header, frames, encrypted_buffer, kMaxOutgoingPacketSize, + ENCRYPTION_INITIAL); + QUICHE_DCHECK_GT(length, 0u); + + const size_t encrypted_length = peer_framer_.EncryptInPlace( + ENCRYPTION_INITIAL, header.packet_number, + GetStartOfEncryptedData(peer_framer_.version().transport_version, + header), + length, kMaxOutgoingPacketSize, encrypted_buffer); + QUICHE_DCHECK_GT(encrypted_length, 0u); + + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(encrypted_buffer, encrypted_length, clock_.Now())); + } + + size_t ProcessFramePacketAtLevel(uint64_t number, QuicFrame frame, + EncryptionLevel level) { + return ProcessFramePacketAtLevelWithEcn(number, frame, level, ECN_NOT_ECT); + } + + size_t ProcessFramePacketAtLevelWithEcn(uint64_t number, QuicFrame frame, + EncryptionLevel level, + QuicEcnCodepoint ecn_codepoint) { + QuicFrames frames; + frames.push_back(frame); + return ProcessFramesPacketAtLevelWithEcn(number, frames, level, + ecn_codepoint); + } + + size_t ProcessFramesPacketAtLevel(uint64_t number, QuicFrames frames, + EncryptionLevel level) { + return ProcessFramesPacketAtLevelWithEcn(number, frames, level, + ECN_NOT_ECT); + } + + size_t ProcessFramesPacketAtLevelWithEcn(uint64_t number, + const QuicFrames& frames, + EncryptionLevel level, + QuicEcnCodepoint ecn_codepoint) { + QuicPacketHeader header = ConstructPacketHeader(number, level); + // Set the correct encryption level and encrypter on peer_creator and + // peer_framer, respectively. + peer_creator_.set_encryption_level(level); + if (level > ENCRYPTION_INITIAL) { + peer_framer_.SetEncrypter(level, + std::make_unique(level)); + // Set the corresponding decrypter. + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter( + level, std::make_unique(level)); + } else { + connection_.SetAlternativeDecrypter( + level, std::make_unique(level), false); + } + } + std::unique_ptr packet(ConstructPacket(header, frames)); + + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(level, QuicPacketNumber(number), *packet, + buffer, kMaxOutgoingPacketSize); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false, 0, + true, nullptr, 0, false, ecn_codepoint)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + return encrypted_length; + } + + struct PacketInfo { + PacketInfo(uint64_t packet_number, QuicFrames frames, EncryptionLevel level) + : packet_number(packet_number), frames(frames), level(level) {} + + uint64_t packet_number; + QuicFrames frames; + EncryptionLevel level; + }; + + size_t ProcessCoalescedPacket(std::vector packets) { + return ProcessCoalescedPacket(packets, ECN_NOT_ECT); + } + + size_t ProcessCoalescedPacket(std::vector packets, + QuicEcnCodepoint ecn_codepoint) { + char coalesced_buffer[kMaxOutgoingPacketSize]; + size_t coalesced_size = 0; + bool contains_initial = false; + for (const auto& packet : packets) { + QuicPacketHeader header = + ConstructPacketHeader(packet.packet_number, packet.level); + // Set the correct encryption level and encrypter on peer_creator and + // peer_framer, respectively. + peer_creator_.set_encryption_level(packet.level); + if (packet.level == ENCRYPTION_INITIAL) { + contains_initial = true; + } + EncryptionLevel level = + QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_); + if (level > ENCRYPTION_INITIAL) { + peer_framer_.SetEncrypter(level, + std::make_unique(level)); + // Set the corresponding decrypter. + if (connection_.version().KnowsWhichDecrypterToUse()) { + connection_.InstallDecrypter( + level, std::make_unique(level)); + } else { + connection_.SetDecrypter( + level, std::make_unique(level)); + } + } + std::unique_ptr constructed_packet( + ConstructPacket(header, packet.frames)); + + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + packet.level, QuicPacketNumber(packet.packet_number), + *constructed_packet, buffer, kMaxOutgoingPacketSize); + QUICHE_DCHECK_LE(coalesced_size + encrypted_length, + kMaxOutgoingPacketSize); + memcpy(coalesced_buffer + coalesced_size, buffer, encrypted_length); + coalesced_size += encrypted_length; + } + if (contains_initial) { + // Padded coalesced packet to full if it contains initial packet. + memset(coalesced_buffer + coalesced_size, '0', + kMaxOutgoingPacketSize - coalesced_size); + } + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(coalesced_buffer, coalesced_size, clock_.Now(), + false, 0, true, nullptr, 0, false, ecn_codepoint)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + return coalesced_size; + } + + size_t ProcessDataPacket(uint64_t number) { + return ProcessDataPacketAtLevel(number, false, ENCRYPTION_FORWARD_SECURE); + } + + size_t ProcessDataPacket(QuicPacketNumber packet_number) { + return ProcessDataPacketAtLevel(packet_number, false, + ENCRYPTION_FORWARD_SECURE); + } + + size_t ProcessDataPacketAtLevel(QuicPacketNumber packet_number, + bool has_stop_waiting, + EncryptionLevel level) { + return ProcessDataPacketAtLevel(packet_number.ToUint64(), has_stop_waiting, + level); + } + + size_t ProcessCryptoPacketAtLevel(uint64_t number, EncryptionLevel level) { + QuicPacketHeader header = ConstructPacketHeader(number, level); + QuicFrames frames; + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + frames.push_back(QuicFrame(&crypto_frame_)); + } else { + frames.push_back(QuicFrame(frame1_)); + } + if (level == ENCRYPTION_INITIAL) { + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + } + std::unique_ptr packet = ConstructPacket(header, frames); + char buffer[kMaxOutgoingPacketSize]; + peer_creator_.set_encryption_level(level); + size_t encrypted_length = + peer_framer_.EncryptPayload(level, QuicPacketNumber(number), *packet, + buffer, kMaxOutgoingPacketSize); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + return encrypted_length; + } + + size_t ProcessDataPacketAtLevel(uint64_t number, bool has_stop_waiting, + EncryptionLevel level) { + std::unique_ptr packet( + ConstructDataPacket(number, has_stop_waiting, level)); + char buffer[kMaxOutgoingPacketSize]; + peer_creator_.set_encryption_level(level); + size_t encrypted_length = + peer_framer_.EncryptPayload(level, QuicPacketNumber(number), *packet, + buffer, kMaxOutgoingPacketSize); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + return encrypted_length; + } + + void ProcessClosePacket(uint64_t number) { + std::unique_ptr packet(ConstructClosePacket(number)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(number), *packet, buffer, + kMaxOutgoingPacketSize); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, QuicTime::Zero(), false)); + } + + QuicByteCount SendStreamDataToPeer(QuicStreamId id, absl::string_view data, + QuicStreamOffset offset, + StreamSendingState state, + QuicPacketNumber* last_packet) { + QuicByteCount packet_size = 0; + // Save the last packet's size. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(SaveArg<3>(&packet_size)); + connection_.SendStreamDataWithString(id, data, offset, state); + if (last_packet != nullptr) { + *last_packet = creator_->packet_number(); + } + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AnyNumber()); + return packet_size; + } + + void SendAckPacketToPeer() { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendAck(); + } + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AnyNumber()); + } + + void SendRstStream(QuicStreamId id, QuicRstStreamErrorCode error, + QuicStreamOffset bytes_written) { + notifier_.WriteOrBufferRstStream(id, error, bytes_written); + connection_.OnStreamReset(id, error); + } + + void SendPing() { notifier_.WriteOrBufferPing(); } + + MessageStatus SendMessage(absl::string_view message) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + quiche::QuicheMemSlice slice(quiche::QuicheBuffer::Copy( + connection_.helper()->GetStreamSendBufferAllocator(), message)); + return connection_.SendMessage(1, absl::MakeSpan(&slice, 1), false); + } + + void ProcessAckPacket(uint64_t packet_number, QuicAckFrame* frame) { + if (packet_number > 1) { + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, packet_number - 1); + } else { + QuicPacketCreatorPeer::ClearPacketNumber(&peer_creator_); + } + ProcessFramePacket(QuicFrame(frame)); + } + + void ProcessAckPacket(QuicAckFrame* frame) { + ProcessFramePacket(QuicFrame(frame)); + } + + void ProcessStopWaitingPacket(QuicStopWaitingFrame frame) { + ProcessFramePacket(QuicFrame(frame)); + } + + size_t ProcessStopWaitingPacketAtLevel(uint64_t number, + QuicStopWaitingFrame frame, + EncryptionLevel /*level*/) { + return ProcessFramePacketAtLevel(number, QuicFrame(frame), + ENCRYPTION_ZERO_RTT); + } + + void ProcessGoAwayPacket(QuicGoAwayFrame* frame) { + ProcessFramePacket(QuicFrame(frame)); + } + + bool IsMissing(uint64_t number) { + return IsAwaitingPacket(connection_.ack_frame(), QuicPacketNumber(number), + QuicPacketNumber()); + } + + std::unique_ptr ConstructPacket(const QuicPacketHeader& header, + const QuicFrames& frames) { + auto packet = BuildUnsizedDataPacket(&peer_framer_, header, frames); + EXPECT_NE(nullptr, packet.get()); + return packet; + } + + QuicPacketHeader ConstructPacketHeader(uint64_t number, + EncryptionLevel level) { + QuicPacketHeader header; + if (peer_framer_.version().HasIetfInvariantHeader() && + level < ENCRYPTION_FORWARD_SECURE) { + // Set long header type accordingly. + header.version_flag = true; + header.form = IETF_QUIC_LONG_HEADER_PACKET; + header.long_packet_type = EncryptionlevelToLongHeaderType(level); + if (QuicVersionHasLongHeaderLengths( + peer_framer_.version().transport_version)) { + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + if (header.long_packet_type == INITIAL) { + header.retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + } + } + } + // Set connection_id to peer's in memory representation as this data packet + // is created by peer_framer. + if (peer_framer_.perspective() == Perspective::IS_SERVER) { + header.source_connection_id = connection_id_; + header.source_connection_id_included = connection_id_included_; + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + } else { + header.destination_connection_id = connection_id_; + header.destination_connection_id_included = connection_id_included_; + } + if (peer_framer_.version().HasIetfInvariantHeader() && + peer_framer_.perspective() == Perspective::IS_SERVER) { + if (!connection_.client_connection_id().IsEmpty()) { + header.destination_connection_id = connection_.client_connection_id(); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + } else { + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + } + if (header.version_flag) { + header.source_connection_id = connection_id_; + header.source_connection_id_included = CONNECTION_ID_PRESENT; + if (GetParam().version.handshake_protocol == PROTOCOL_QUIC_CRYPTO && + header.long_packet_type == ZERO_RTT_PROTECTED) { + header.nonce = &kTestDiversificationNonce; + } + } + } + if (!peer_framer_.version().HasIetfInvariantHeader() && + peer_framer_.perspective() == Perspective::IS_SERVER && + GetParam().version.handshake_protocol == PROTOCOL_QUIC_CRYPTO && + level == ENCRYPTION_ZERO_RTT) { + header.nonce = &kTestDiversificationNonce; + } + header.packet_number_length = packet_number_length_; + header.packet_number = QuicPacketNumber(number); + return header; + } + + std::unique_ptr ConstructDataPacket(uint64_t number, + bool has_stop_waiting, + EncryptionLevel level) { + QuicPacketHeader header = ConstructPacketHeader(number, level); + QuicFrames frames; + if (VersionHasIetfQuicFrames(version().transport_version) && + (level == ENCRYPTION_INITIAL || level == ENCRYPTION_HANDSHAKE)) { + frames.push_back(QuicFrame(QuicPingFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(100))); + } else { + frames.push_back(QuicFrame(frame1_)); + if (has_stop_waiting) { + frames.push_back(QuicFrame(stop_waiting_)); + } + } + return ConstructPacket(header, frames); + } + + std::unique_ptr ConstructProbingPacket() { + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + if (VersionHasIetfQuicFrames(version().transport_version)) { + QuicPathFrameBuffer payload = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xfe}}; + return QuicPacketCreatorPeer:: + SerializePathChallengeConnectivityProbingPacket(&peer_creator_, + payload); + } + return QuicPacketCreatorPeer::SerializeConnectivityProbingPacket( + &peer_creator_); + } + + std::unique_ptr ConstructClosePacket(uint64_t number) { + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicPacketHeader header; + // Set connection_id to peer's in memory representation as this connection + // close packet is created by peer_framer. + if (peer_framer_.perspective() == Perspective::IS_SERVER) { + header.source_connection_id = connection_id_; + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + if (!peer_framer_.version().HasIetfInvariantHeader()) { + header.source_connection_id_included = CONNECTION_ID_PRESENT; + } + } else { + header.destination_connection_id = connection_id_; + if (peer_framer_.version().HasIetfInvariantHeader()) { + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + } + } + + header.packet_number = QuicPacketNumber(number); + + QuicErrorCode kQuicErrorCode = QUIC_PEER_GOING_AWAY; + QuicConnectionCloseFrame qccf(peer_framer_.transport_version(), + kQuicErrorCode, NO_IETF_QUIC_ERROR, "", + /*transport_close_frame_type=*/0); + QuicFrames frames; + frames.push_back(QuicFrame(&qccf)); + return ConstructPacket(header, frames); + } + + QuicTime::Delta DefaultRetransmissionTime() { + return QuicTime::Delta::FromMilliseconds(kDefaultRetransmissionTimeMs); + } + + QuicTime::Delta DefaultDelayedAckTime() { + return QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + } + + const QuicStopWaitingFrame InitStopWaitingFrame(uint64_t least_unacked) { + QuicStopWaitingFrame frame; + frame.least_unacked = QuicPacketNumber(least_unacked); + return frame; + } + + // Construct a ack_frame that acks all packet numbers between 1 and + // |largest_acked|, except |missing|. + // REQUIRES: 1 <= |missing| < |largest_acked| + QuicAckFrame ConstructAckFrame(uint64_t largest_acked, uint64_t missing) { + return ConstructAckFrame(QuicPacketNumber(largest_acked), + QuicPacketNumber(missing)); + } + + QuicAckFrame ConstructAckFrame(QuicPacketNumber largest_acked, + QuicPacketNumber missing) { + if (missing == QuicPacketNumber(1)) { + return InitAckFrame({{missing + 1, largest_acked + 1}}); + } + return InitAckFrame( + {{QuicPacketNumber(1), missing}, {missing + 1, largest_acked + 1}}); + } + + // Undo nacking a packet within the frame. + void AckPacket(QuicPacketNumber arrived, QuicAckFrame* frame) { + EXPECT_FALSE(frame->packets.Contains(arrived)); + frame->packets.Add(arrived); + } + + void TriggerConnectionClose() { + // Send an erroneous packet to close the connection. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + // Triggers a connection by receiving ACK of unsent packet. + QuicAckFrame frame = InitAckFrame(10000); + ProcessAckPacket(1, &frame); + EXPECT_FALSE(QuicConnectionPeer::GetConnectionClosePacket(&connection_) == + nullptr); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_INVALID_ACK_DATA)); + } + + void BlockOnNextWrite() { + writer_->BlockOnNextWrite(); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AtLeast(1)); + } + + void SimulateNextPacketTooLarge() { writer_->SimulateNextPacketTooLarge(); } + + void ExpectNextPacketUnprocessable() { + writer_->ExpectNextPacketUnprocessable(); + } + + void AlwaysGetPacketTooLarge() { writer_->AlwaysGetPacketTooLarge(); } + + void SetWritePauseTimeDelta(QuicTime::Delta delta) { + writer_->SetWritePauseTimeDelta(delta); + } + + void CongestionBlockWrites() { + EXPECT_CALL(*send_algorithm_, CanSend(_)) + .WillRepeatedly(testing::Return(false)); + } + + void CongestionUnblockWrites() { + EXPECT_CALL(*send_algorithm_, CanSend(_)) + .WillRepeatedly(testing::Return(true)); + } + + void set_perspective(Perspective perspective) { + connection_.set_perspective(perspective); + if (perspective == Perspective::IS_SERVER) { + QuicConfig config; + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicTagVector connection_options; + connection_options.push_back(kRVCM); + config.SetInitialReceivedConnectionOptions(connection_options); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + connection_.set_can_truncate_connection_ids(true); + QuicConnectionPeer::SetNegotiatedVersion(&connection_); + connection_.OnSuccessfulVersionNegotiation(); + } + QuicFramerPeer::SetPerspective(&peer_framer_, + QuicUtils::InvertPerspective(perspective)); + peer_framer_.SetInitialObfuscators(TestConnectionId()); + for (EncryptionLevel level : {ENCRYPTION_ZERO_RTT, ENCRYPTION_HANDSHAKE, + ENCRYPTION_FORWARD_SECURE}) { + if (peer_framer_.HasEncrypterOfEncryptionLevel(level)) { + peer_creator_.SetEncrypter(level, + std::make_unique(level)); + } + } + } + + void set_packets_between_probes_base( + const QuicPacketCount packets_between_probes_base) { + QuicConnectionPeer::ReInitializeMtuDiscoverer( + &connection_, packets_between_probes_base, + QuicPacketNumber(packets_between_probes_base)); + } + + bool IsDefaultTestConfiguration() { + TestParams p = GetParam(); + return p.ack_response == AckResponse::kImmediate && + p.version == AllSupportedVersions()[0] && p.no_stop_waiting; + } + + void TestConnectionCloseQuicErrorCode(QuicErrorCode expected_code) { + // Not strictly needed for this test, but is commonly done. + EXPECT_FALSE(QuicConnectionPeer::GetConnectionClosePacket(&connection_) == + nullptr); + const std::vector& connection_close_frames = + writer_->connection_close_frames(); + ASSERT_EQ(1u, connection_close_frames.size()); + + EXPECT_THAT(connection_close_frames[0].quic_error_code, + IsError(expected_code)); + + if (!VersionHasIetfQuicFrames(version().transport_version)) { + EXPECT_THAT(connection_close_frames[0].wire_error_code, + IsError(expected_code)); + EXPECT_EQ(GOOGLE_QUIC_CONNECTION_CLOSE, + connection_close_frames[0].close_type); + return; + } + + QuicErrorCodeToIetfMapping mapping = + QuicErrorCodeToTransportErrorCode(expected_code); + + if (mapping.is_transport_close) { + // This Google QUIC Error Code maps to a transport close, + EXPECT_EQ(IETF_QUIC_TRANSPORT_CONNECTION_CLOSE, + connection_close_frames[0].close_type); + } else { + // This maps to an application close. + EXPECT_EQ(IETF_QUIC_APPLICATION_CONNECTION_CLOSE, + connection_close_frames[0].close_type); + } + EXPECT_EQ(mapping.error_code, connection_close_frames[0].wire_error_code); + } + + void MtuDiscoveryTestInit() { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + EXPECT_TRUE(connection_.connected()); + } + + void PathProbeTestInit(Perspective perspective, + bool receive_new_server_connection_id = true) { + set_perspective(perspective); + connection_.CreateConnectionIdManager(); + EXPECT_EQ(connection_.perspective(), perspective); + if (perspective == Perspective::IS_SERVER) { + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + if (version().SupportsAntiAmplificationLimit() && + perspective == Perspective::IS_SERVER) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + } + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 2); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + if (perspective == Perspective::IS_CLIENT && + receive_new_server_connection_id && version().HasIetfQuicFrames()) { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + connection_.OnNewConnectionIdFrame(frame); + } + } + + void ServerHandlePreferredAddressInit() { + ASSERT_TRUE(GetParam().version.HasIetfQuicFrames()); + set_perspective(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + SetQuicReloadableFlag(quic_connection_migration_use_new_cid_v2, true); + SetQuicReloadableFlag(quic_use_received_client_addresses_cache, true); + EXPECT_CALL(visitor_, AllowSelfAddressChange()) + .WillRepeatedly(Return(true)); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + } + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 2); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + QuicConfig config; + config.SetInitialReceivedConnectionOptions(QuicTagVector{kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + connection_.set_sent_server_preferred_address(kServerPreferredAddress); + } + + // Receive server preferred address. + void ServerPreferredAddressInit(QuicConfig& config) { + ASSERT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + ASSERT_TRUE(version().HasIetfQuicFrames()); + ASSERT_TRUE(connection_.self_address().host().IsIPv6()); + SetQuicReloadableFlag(quic_connection_migration_use_new_cid_v2, true); + const QuicConnectionId connection_id = TestConnectionId(17); + const StatelessResetToken reset_token = + QuicUtils::GenerateStatelessResetToken(connection_id); + + connection_.CreateConnectionIdManager(); + + connection_.SendCryptoStreamData(); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = InitAckFrame(1); + // Received ACK for packet 1. + ProcessFramePacketAtLevel(1, QuicFrame(&frame), ENCRYPTION_INITIAL); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM, kSPAD}); + QuicConfigPeer::SetReceivedStatelessResetToken(&config, + kTestStatelessResetToken); + QuicConfigPeer::SetReceivedAlternateServerAddress(&config, + kServerPreferredAddress); + QuicConfigPeer::SetPreferredAddressConnectionIdAndToken( + &config, connection_id, reset_token); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + ASSERT_TRUE( + QuicConnectionPeer::GetReceivedServerPreferredAddress(&connection_) + .IsInitialized()); + EXPECT_EQ( + kServerPreferredAddress, + QuicConnectionPeer::GetReceivedServerPreferredAddress(&connection_)); + } + + void TestClientRetryHandling(bool invalid_retry_tag, + bool missing_original_id_in_config, + bool wrong_original_id_in_config, + bool missing_retry_id_in_config, + bool wrong_retry_id_in_config); + + void TestReplaceConnectionIdFromInitial(); + + QuicConnectionId connection_id_; + QuicFramer framer_; + + MockSendAlgorithm* send_algorithm_; + std::unique_ptr loss_algorithm_; + MockClock clock_; + MockRandom random_generator_; + quiche::SimpleBufferAllocator buffer_allocator_; + std::unique_ptr helper_; + std::unique_ptr alarm_factory_; + QuicFramer peer_framer_; + QuicPacketCreator peer_creator_; + std::unique_ptr writer_; + TestConnection connection_; + QuicPacketCreator* creator_; + QuicSentPacketManager* manager_; + StrictMock visitor_; + + QuicStreamFrame frame1_; + QuicStreamFrame frame2_; + QuicCryptoFrame crypto_frame_; + QuicAckFrame ack_; + QuicStopWaitingFrame stop_waiting_; + QuicPacketNumberLength packet_number_length_; + QuicConnectionIdIncluded connection_id_included_; + + SimpleSessionNotifier notifier_; + + QuicConnectionCloseFrame saved_connection_close_frame_; + int connection_close_frame_count_; + MockConnectionIdGenerator connection_id_generator_; +}; + +// Run all end to end tests with all supported versions. +INSTANTIATE_TEST_SUITE_P(QuicConnectionTests, QuicConnectionTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +// These two tests ensure that the QuicErrorCode mapping works correctly. +// Both tests expect to see a Google QUIC close if not running IETF QUIC. +// If running IETF QUIC, the first will generate a transport connection +// close, the second an application connection close. +// The connection close codes for the two tests are manually chosen; +// they are expected to always map to transport- and application- +// closes, respectively. If that changes, new codes should be chosen. +TEST_P(QuicConnectionTest, CloseErrorCodeTestTransport) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection( + IETF_QUIC_PROTOCOL_VIOLATION, "Should be transport close", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); +} + +// Test that the IETF QUIC Error code mapping function works +// properly for application connection close codes. +TEST_P(QuicConnectionTest, CloseErrorCodeTestApplication) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection( + QUIC_HEADERS_STREAM_DATA_DECOMPRESS_FAILURE, + "Should be application close", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_HEADERS_STREAM_DATA_DECOMPRESS_FAILURE); +} + +TEST_P(QuicConnectionTest, SelfAddressChangeAtClient) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + EXPECT_TRUE(connection_.connected()); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + // Cause change in self_address. + QuicIpAddress host; + host.FromString("1.1.1.1"); + QuicSocketAddress self_address(host, 123); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), self_address, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, SelfAddressChangeAtServer) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + EXPECT_TRUE(connection_.connected()); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + // Cause change in self_address. + QuicIpAddress host; + host.FromString("1.1.1.1"); + QuicSocketAddress self_address(host, 123); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + EXPECT_CALL(visitor_, AllowSelfAddressChange()).WillOnce(Return(false)); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), self_address, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(1u, connection_.GetStats().packets_dropped); +} + +TEST_P(QuicConnectionTest, AllowSelfAddressChangeToMappedIpv4AddressAtServer) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + EXPECT_TRUE(connection_.connected()); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(3); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(3); + } + QuicIpAddress host; + host.FromString("1.1.1.1"); + QuicSocketAddress self_address1(host, 443); + connection_.SetSelfAddress(self_address1); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), self_address1, + kPeerAddress, ENCRYPTION_INITIAL); + // Cause self_address change to mapped Ipv4 address. + QuicIpAddress host2; + host2.FromString( + absl::StrCat("::ffff:", connection_.self_address().host().ToString())); + QuicSocketAddress self_address2(host2, connection_.self_address().port()); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), self_address2, + kPeerAddress, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.connected()); + // self_address change back to Ipv4 address. + ProcessFramePacketWithAddresses(MakeCryptoFrame(), self_address1, + kPeerAddress, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, ClientAddressChangeAndPacketReordered) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + } + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 5); + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), + /*port=*/23456); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kNewPeerAddress, ENCRYPTION_INITIAL); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + + // Decrease packet number to simulate out-of-order packets. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 4); + // This is an old packet, do not migrate. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, PeerPortChangeAtServer) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + QuicTime::Delta default_init_rtt = rtt_stats->initial_rtt(); + rtt_stats->set_initial_rtt(default_init_rtt * 2); + EXPECT_EQ(2 * default_init_rtt, rtt_stats->initial_rtt()); + + QuicSentPacketManagerPeer::SetConsecutivePtoCount(manager_, 1); + EXPECT_EQ(1u, manager_->GetConsecutivePtoCount()); + + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kPeerAddress, connection_.peer_address()); })) + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); })); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Process another packet with a different peer address on server side will + // start connection migration. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + // PORT_CHANGE shouldn't state change in sent packet manager. + EXPECT_EQ(2 * default_init_rtt, rtt_stats->initial_rtt()); + EXPECT_EQ(1u, manager_->GetConsecutivePtoCount()); + EXPECT_EQ(manager_->GetSendAlgorithm(), send_algorithm_); + if (connection_.validate_client_address()) { + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); + EXPECT_EQ(1u, connection_.num_linkable_client_migration()); + } +} + +TEST_P(QuicConnectionTest, PeerIpAddressChangeAtServer) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.validate_client_address() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + QuicConnectionPeer::SetAddressValidated(&connection_); + connection_.OnHandshakeComplete(); + + // Enable 5 RTO + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k5RTO); + config.SetInitialReceivedConnectionOptions(connection_options); + QuicConfigPeer::SetNegotiated(&config, true); + QuicConfigPeer::SetReceivedOriginalConnectionId(&config, + connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId(&config, + QuicConnectionId()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kPeerAddress, connection_.peer_address()); })) + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); })); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Send some data to make connection has packets in flight. + connection_.SendStreamData3(); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_TRUE(connection_.BlackholeDetectionInProgress()); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Process another packet with a different peer address on server side will + // start connection migration. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + // IETF QUIC send algorithm should be changed to a different object, so no + // OnPacketSent() called on the old send algorithm. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .Times(0); + // Do not propagate OnCanWrite() to session notifier. + EXPECT_CALL(visitor_, OnCanWrite()).Times(AtLeast(1u)); + + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + EXPECT_FALSE(connection_.BlackholeDetectionInProgress()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + EXPECT_EQ(2u, writer_->packets_write_attempts()); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + QuicPathFrameBuffer payload = + writer_->path_challenge_frames().front().data_buffer; + EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + // Switch to use the mock send algorithm. + send_algorithm_ = new StrictMock(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); + connection_.SetSendAlgorithm(send_algorithm_); + + // PATH_CHALLENGE is expanded upto the max packet size which may exceeds the + // anti-amplification limit. + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(1u, + connection_.GetStats().num_reverse_path_validtion_upon_migration); + + // Verify server is throttled by anti-amplification limit. + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Receiving an ACK to the packet sent after changing peer address doesn't + // finish migration validation. + QuicAckFrame ack_frame = InitAckFrame(2); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramePacketWithAddresses(QuicFrame(&ack_frame), kSelfAddress, + kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + + // Receiving PATH_RESPONSE should lift the anti-amplification limit. + QuicFrames frames3; + frames3.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + EXPECT_CALL(visitor_, MaybeSendAddressToken()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(testing::AtLeast(1u)); + ProcessFramesPacketWithAddresses(frames3, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + + // Verify the anti-amplification limit is lifted by sending a packet larger + // than the anti-amplification limit. + connection_.SendCryptoDataWithString(std::string(1200, 'a'), 0); + EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); + EXPECT_EQ(1u, connection_.num_linkable_client_migration()); +} + +TEST_P(QuicConnectionTest, PeerIpAddressChangeAtServerWithMissingConnectionId) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + + QuicConnectionId client_cid0 = TestConnectionId(1); + QuicConnectionId client_cid1 = TestConnectionId(3); + QuicConnectionId server_cid1; + SetClientConnectionId(client_cid0); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + QuicConnectionPeer::SetAddressValidated(&connection_); + + // Sends new server CID to client. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + server_cid1 = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.OnHandshakeComplete(); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(2); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Send some data to make connection has packets in flight. + connection_.SendStreamData3(); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Process another packet with a different peer address on server side will + // start connection migration. + peer_creator_.SetServerConnectionId(server_cid1); + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + // Do not propagate OnCanWrite() to session notifier. + EXPECT_CALL(visitor_, OnCanWrite()).Times(AtLeast(1u)); + + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + if (GetQuicFlag(quic_enforce_strict_amplification_factor)) { + frames2.push_back(QuicFrame(QuicPaddingFrame(-1))); + } + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + + // Writing path response & reverse path challenge is blocked due to missing + // client connection ID, i.e., packets_write_attempts is unchanged. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Receives new client CID from client would unblock write. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + connection_.SendStreamData3(); + + EXPECT_EQ(2u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, EffectivePeerAddressChangeAtServer) { + if (GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is different from direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + const QuicSocketAddress kEffectivePeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/43210); + connection_.ReturnEffectivePeerAddressForNextPacket(kEffectivePeerAddress); + + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kEffectivePeerAddress, connection_.effective_peer_address()); + + // Process another packet with the same direct peer address and different + // effective peer address on server side will start connection migration. + const QuicSocketAddress kNewEffectivePeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/54321); + connection_.ReturnEffectivePeerAddressForNextPacket(kNewEffectivePeerAddress); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewEffectivePeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + if (connection_.validate_client_address()) { + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); + EXPECT_EQ(1u, connection_.num_linkable_client_migration()); + } + + // Process another packet with a different direct peer address and the same + // effective peer address on server side will not start connection migration. + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + connection_.ReturnEffectivePeerAddressForNextPacket(kNewEffectivePeerAddress); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + + if (!connection_.validate_client_address()) { + // ack_frame is used to complete the migration started by the last packet, + // we need to make sure a new migration does not start after the previous + // one is completed. + QuicAckFrame ack_frame = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramePacketWithAddresses(QuicFrame(&ack_frame), kSelfAddress, + kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewEffectivePeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + } + + // Process another packet with different direct peer address and different + // effective peer address on server side will start connection migration. + const QuicSocketAddress kNewerEffectivePeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/65432); + const QuicSocketAddress kFinalPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/34567); + connection_.ReturnEffectivePeerAddressForNextPacket( + kNewerEffectivePeerAddress); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kFinalPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kFinalPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewerEffectivePeerAddress, connection_.effective_peer_address()); + if (connection_.validate_client_address()) { + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + EXPECT_EQ(send_algorithm_, + connection_.sent_packet_manager().GetSendAlgorithm()); + EXPECT_EQ(2u, connection_.GetStats().num_validated_peer_migration); + } + + // While the previous migration is ongoing, process another packet with the + // same direct peer address and different effective peer address on server + // side will start a new connection migration. + const QuicSocketAddress kNewestEffectivePeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/65430); + connection_.ReturnEffectivePeerAddressForNextPacket( + kNewestEffectivePeerAddress); + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + if (!connection_.validate_client_address()) { + EXPECT_CALL(*send_algorithm_, OnConnectionMigration()).Times(1); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kFinalPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kFinalPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewestEffectivePeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + if (connection_.validate_client_address()) { + EXPECT_NE(send_algorithm_, + connection_.sent_packet_manager().GetSendAlgorithm()); + EXPECT_EQ(kFinalPeerAddress, writer_->last_write_peer_address()); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + EXPECT_EQ(0u, connection_.GetStats() + .num_peer_migration_while_validating_default_path); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + } +} + +// Regression test for b/200020764. +TEST_P(QuicConnectionTest, ConnectionMigrationWithPendingPaddingBytes) { + // TODO(haoyuewang) Move these test setup code to a common member function. + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicConnectionPeer::SetPeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetAddressValidated(&connection_); + + // Sends new server CID to client. + QuicConnectionId new_cid; + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + new_cid = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + packet_creator->FlushCurrentPacket(); + packet_creator->AddPendingPadding(50u); + const QuicSocketAddress kPeerAddress3 = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/56789); + auto ack_frame = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + ProcessFramesPacketWithAddresses({QuicFrame(&ack_frame)}, kSelfAddress, + kPeerAddress3, ENCRYPTION_FORWARD_SECURE); + if (GetQuicReloadableFlag( + quic_flush_pending_frames_and_padding_bytes_on_migration)) { + // Any pending frames/padding should be flushed before default_path_ is + // temporarily reset. + ASSERT_EQ(connection_.self_address_on_default_path_while_sending_packet() + .host() + .address_family(), + IpAddressFamily::IP_V6); + } else { + ASSERT_EQ(connection_.self_address_on_default_path_while_sending_packet() + .host() + .address_family(), + IpAddressFamily::IP_UNSPEC); + } +} + +// Regression test for b/196208556. +TEST_P(QuicConnectionTest, + ReversePathValidationResponseReceivedFromUnexpectedPeerAddress) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicConnectionPeer::SetPeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetAddressValidated(&connection_); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Sends new server CID to client. + QuicConnectionId new_cid; + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + new_cid = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + // Process a non-probing packet to migrate to path 2 and kick off reverse path + // validation. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + const QuicSocketAddress kPeerAddress2 = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + peer_creator_.SetServerConnectionId(new_cid); + ProcessFramesPacketWithAddresses({QuicFrame(QuicPingFrame())}, kSelfAddress, + kPeerAddress2, ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + QuicPathFrameBuffer reverse_path_challenge_payload = + writer_->path_challenge_frames().front().data_buffer; + + // Receiveds a packet from path 3 with PATH_RESPONSE frame intended to + // validate path 2 and a non-probing frame. + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + const QuicSocketAddress kPeerAddress3 = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/56789); + auto ack_frame = InitAckFrame(1); + EXPECT_CALL(visitor_, OnConnectionMigration(IPV4_TO_IPV6_CHANGE)).Times(1); + EXPECT_CALL(visitor_, MaybeSendAddressToken()).WillOnce(Invoke([this]() { + connection_.SendControlFrame( + QuicFrame(new QuicNewTokenFrame(1, "new_token"))); + return true; + })); + ProcessFramesPacketWithAddresses( + {QuicFrame(QuicPathResponseFrame(0, reverse_path_challenge_payload)), + QuicFrame(&ack_frame)}, + kSelfAddress, kPeerAddress3, ENCRYPTION_FORWARD_SECURE); + } +} + +TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + SetClientConnectionId(TestConnectionId(1)); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + QuicConnectionPeer::SetAddressValidated(&connection_); + + QuicConnectionId client_cid0 = connection_.client_connection_id(); + QuicConnectionId client_cid1 = TestConnectionId(2); + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId server_cid1; + // Sends new server CID to client. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + server_cid1 = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.OnHandshakeComplete(); + // Receives new client CID from client. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kPeerAddress, connection_.peer_address()); })) + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); })); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Process another packet with a different peer address on server side will + // start connection migration. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + // IETF QUIC send algorithm should be changed to a different object, so no + // OnPacketSent() called on the old send algorithm. + EXPECT_CALL(*send_algorithm_, OnConnectionMigration()).Times(0); + + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + QuicPaddingFrame padding; + frames2.push_back(QuicFrame(padding)); + peer_creator_.SetServerConnectionId(server_cid1); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + EXPECT_LT(0u, writer_->packets_write_attempts()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->client_connection_id, client_cid1); + EXPECT_EQ(default_path->server_connection_id, server_cid1); + EXPECT_EQ(alternative_path->client_connection_id, client_cid0); + EXPECT_EQ(alternative_path->server_connection_id, server_cid0); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid1); + + for (size_t i = 0; i < QuicPathValidator::kMaxRetryTimes; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + } + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + + // Make sure anti-amplification limit is not reached. + ProcessFramesPacketWithAddresses( + {QuicFrame(QuicPingFrame()), QuicFrame(QuicPaddingFrame())}, kSelfAddress, + kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Advance the time so that the reverse path validation times out. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Verify that default_path_ is reverted and alternative_path_ is cleared. + EXPECT_EQ(default_path->client_connection_id, client_cid0); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/1u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); +} + +TEST_P(QuicConnectionTest, ReceivePathProbeWithNoAddressChangeAtServer) { + PathProbeTestInit(Perspective::IS_SERVER); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + EXPECT_CALL(visitor_, OnPacketReceived(_, _, false)).Times(0); + + // Process a padded PING packet with no peer address change on server side + // will be ignored. But a PATH CHALLENGE packet with no peer address change + // will be considered as path probing. + std::unique_ptr probing_packet = ConstructProbingPacket(); + + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + ProcessReceivedPacket(kSelfAddress, kPeerAddress, *received); + + EXPECT_EQ( + num_probing_received + (GetParam().version.HasIetfQuicFrames() ? 1u : 0u), + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +// Regression test for b/150161358. +TEST_P(QuicConnectionTest, BufferedMtuPacketTooBig) { + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(1); + writer_->SetWriteBlocked(); + + // Send a MTU packet while blocked. It should be buffered. + connection_.SendMtuDiscoveryPacket(kMaxOutgoingPacketSize); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + EXPECT_TRUE(writer_->IsWriteBlocked()); + + writer_->AlwaysGetPacketTooLarge(); + writer_->SetWritable(); + connection_.OnCanWrite(); +} + +TEST_P(QuicConnectionTest, WriteOutOfOrderQueuedPackets) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration()) { + return; + } + + set_perspective(Perspective::IS_CLIENT); + + BlockOnNextWrite(); + + QuicStreamId stream_id = 2; + connection_.SendStreamDataWithString(stream_id, "foo", 0, NO_FIN); + + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + writer_->SetWritable(); + connection_.SendConnectivityProbingPacket(writer_.get(), + connection_.peer_address()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + connection_.OnCanWrite(); +} + +TEST_P(QuicConnectionTest, DiscardQueuedPacketsAfterConnectionClose) { + // Regression test for b/74073386. + { + InSequence seq; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1)); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(AtLeast(1)); + } + + set_perspective(Perspective::IS_CLIENT); + + writer_->SimulateNextPacketTooLarge(); + + // This packet write should fail, which should cause the connection to close + // after sending a connection close packet, then the failed packet should be + // queued. + connection_.SendStreamDataWithString(/*id=*/2, "foo", 0, NO_FIN); + + EXPECT_FALSE(connection_.connected()); + // No need to buffer packets. + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + + EXPECT_EQ(0u, connection_.GetStats().packets_discarded); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.GetStats().packets_discarded); +} + +class TestQuicPathValidationContext : public QuicPathValidationContext { + public: + TestQuicPathValidationContext(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + + QuicPacketWriter* writer) + : QuicPathValidationContext(self_address, peer_address), + writer_(writer) {} + + QuicPacketWriter* WriterToUse() override { return writer_; } + + private: + QuicPacketWriter* writer_; +}; + +class TestValidationResultDelegate : public QuicPathValidator::ResultDelegate { + public: + TestValidationResultDelegate(QuicConnection* connection, + const QuicSocketAddress& expected_self_address, + const QuicSocketAddress& expected_peer_address, + bool* success) + : QuicPathValidator::ResultDelegate(), + connection_(connection), + expected_self_address_(expected_self_address), + expected_peer_address_(expected_peer_address), + success_(success) {} + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime /*start_time*/) override { + EXPECT_EQ(expected_self_address_, context->self_address()); + EXPECT_EQ(expected_peer_address_, context->peer_address()); + *success_ = true; + } + + void OnPathValidationFailure( + std::unique_ptr context) override { + EXPECT_EQ(expected_self_address_, context->self_address()); + EXPECT_EQ(expected_peer_address_, context->peer_address()); + if (connection_->perspective() == Perspective::IS_CLIENT) { + connection_->OnPathValidationFailureAtClient(/*is_multi_port=*/false, + *context); + } + *success_ = false; + } + + private: + QuicConnection* connection_; + QuicSocketAddress expected_self_address_; + QuicSocketAddress expected_peer_address_; + bool* success_; +}; + +// A test implementation which migrates to server preferred address +// on path validation suceeds. Otherwise, client cleans up alternative path. +class ServerPreferredAddressTestResultDelegate + : public QuicPathValidator::ResultDelegate { + public: + explicit ServerPreferredAddressTestResultDelegate(QuicConnection* connection) + : connection_(connection) {} + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime /*start_time*/) override { + connection_->OnServerPreferredAddressValidated(*context, false); + } + + void OnPathValidationFailure( + std::unique_ptr context) override { + connection_->OnPathValidationFailureAtClient(/*is_multi_port=*/false, + *context); + } + + protected: + QuicConnection* connection() { return connection_; } + + private: + QuicConnection* connection_; +}; + +// Receive a path probe request at the server side, i.e., +// in non-IETF version: receive a padded PING packet with a peer addess change; +// in IETF version: receive a packet contains PATH CHALLENGE with peer address +// change. +TEST_P(QuicConnectionTest, ReceivePathProbingFromNewPeerAddressAtServer) { + PathProbeTestInit(Perspective::IS_SERVER); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + QuicPathFrameBuffer payload; + if (!GetParam().version.HasIetfQuicFrames()) { + EXPECT_CALL(visitor_, + OnPacketReceived(_, _, /*is_connectivity_probe=*/true)) + .Times(1); + } else { + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(0); + if (connection_.validate_client_address()) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->path_response_frames().size()); + payload = writer_->path_challenge_frames().front().data_buffer; + })); + } + } + // Process a probing packet from a new peer address on server side + // is effectively receiving a connectivity probing. + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/23456); + + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + ProcessReceivedPacket(kSelfAddress, kNewPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + if (GetParam().version.HasIetfQuicFrames()) { + QuicByteCount bytes_sent = + QuicConnectionPeer::BytesSentOnAlternativePath(&connection_); + EXPECT_LT(0u, bytes_sent); + EXPECT_EQ(received->length(), + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_)); + + // Receiving one more probing packet should update the bytes count. + probing_packet = ConstructProbingPacket(); + received.reset(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + ProcessReceivedPacket(kSelfAddress, kNewPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 2, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(2 * bytes_sent, + QuicConnectionPeer::BytesSentOnAlternativePath(&connection_)); + EXPECT_EQ(2 * received->length(), + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_)); + + bool success = false; + if (!connection_.validate_client_address()) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + payload = writer_->path_challenge_frames().front().data_buffer; + })); + + connection_.ValidatePath( + std::make_unique( + connection_.self_address(), kNewPeerAddress, writer_.get()), + std::make_unique( + &connection_, connection_.self_address(), kNewPeerAddress, + &success), + PathValidationReason::kReasonUnknown); + } + EXPECT_EQ((connection_.validate_client_address() ? 2 : 3) * bytes_sent, + QuicConnectionPeer::BytesSentOnAlternativePath(&connection_)); + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + ProcessFramesPacketWithAddresses(frames, connection_.self_address(), + kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_LT(2 * received->length(), + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_)); + if (connection_.validate_client_address()) { + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePathValidated(&connection_)); + } + // Receiving another probing packet from a newer address with a different + // port shouldn't trigger another reverse path validation. + QuicSocketAddress kNewerPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + probing_packet = ConstructProbingPacket(); + received.reset(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + ProcessReceivedPacket(kSelfAddress, kNewerPeerAddress, *received); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_EQ(connection_.validate_client_address(), + QuicConnectionPeer::IsAlternativePathValidated(&connection_)); + } + + // Process another packet with the old peer address on server side will not + // start peer migration. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +// Receive a packet contains PATH CHALLENGE with self address change. +TEST_P(QuicConnectionTest, ReceivePathProbingToPreferredAddressAtServer) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + ServerHandlePreferredAddressInit(); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(0); + + // Process a probing packet to the server preferred address. + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, writer_->path_response_frames().size()); + // Verify that the PATH_RESPONSE is sent from the original self address. + EXPECT_EQ(kSelfAddress.host(), writer_->last_write_source_address()); + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + })); + ProcessReceivedPacket(kServerPreferredAddress, kPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_FALSE(QuicConnectionPeer::IsAlternativePath( + &connection_, kServerPreferredAddress, kPeerAddress)); + EXPECT_NE(kServerPreferredAddress, connection_.self_address()); + + // Receiving another probing packet from a new client address. + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + probing_packet = ConstructProbingPacket(); + received.reset(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, writer_->path_response_frames().size()); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(kServerPreferredAddress.host(), + writer_->last_write_source_address()); + // The responses should be sent from preferred address given server + // has not received packet on original address from the new client + // address. + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + })); + ProcessReceivedPacket(kServerPreferredAddress, kNewPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 2, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath(&connection_, kSelfAddress, + kNewPeerAddress)); + EXPECT_LT(0u, QuicConnectionPeer::BytesSentOnAlternativePath(&connection_)); + EXPECT_EQ(received->length(), + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_)); +} + +// Receive a padded PING packet with a port change on server side. +TEST_P(QuicConnectionTest, ReceivePaddedPingWithPortChangeAtServer) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + if (GetParam().version.UsesCryptoFrames()) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + if (GetParam().version.HasIetfQuicFrames()) { + // In IETF version, a padded PING packet with port change is not taken as + // connectivity probe. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(0); + } else { + // In non-IETF version, process a padded PING packet from a new peer + // address on server side is effectively receiving a connectivity probing. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + EXPECT_CALL(visitor_, + OnPacketReceived(_, _, /*is_connectivity_probe=*/true)) + .Times(1); + } + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + + QuicFrames frames; + // Write a PING frame, which has no data payload. + QuicPingFrame ping_frame; + frames.push_back(QuicFrame(ping_frame)); + + // Add padding to the rest of the packet. + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(padding_frame)); + + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_INITIAL); + + if (GetParam().version.HasIetfQuicFrames()) { + // Padded PING with port changen is not considered as connectivity probe but + // a PORT CHANGE. + EXPECT_EQ(num_probing_received, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + } else { + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + } + + if (GetParam().version.HasIetfQuicFrames()) { + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + } + // Process another packet with the old peer address on server side. gQUIC + // shouldn't regard this as a peer migration. + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, ReceiveReorderedPathProbingAtServer) { + PathProbeTestInit(Perspective::IS_SERVER); + + // Decrease packet number to simulate out-of-order packets. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 4); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + if (!GetParam().version.HasIetfQuicFrames()) { + EXPECT_CALL(visitor_, + OnPacketReceived(_, _, /*is_connectivity_probe=*/true)) + .Times(1); + } else { + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(0); + } + + // Process a padded PING packet from a new peer address on server side + // is effectively receiving a connectivity probing, even if a newer packet has + // been received before this one. + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + ProcessReceivedPacket(kSelfAddress, kNewPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, MigrateAfterProbingAtServer) { + PathProbeTestInit(Perspective::IS_SERVER); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + if (!GetParam().version.HasIetfQuicFrames()) { + EXPECT_CALL(visitor_, + OnPacketReceived(_, _, /*is_connectivity_probe=*/true)) + .Times(1); + } else { + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(0); + } + + // Process a padded PING packet from a new peer address on server side + // is effectively receiving a connectivity probing. + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + ProcessReceivedPacket(kSelfAddress, kNewPeerAddress, *received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Process another non-probing packet with the new peer address on server + // side will start peer migration. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kNewPeerAddress, ENCRYPTION_INITIAL); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, ReceiveConnectivityProbingPacketAtClient) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + PathProbeTestInit(Perspective::IS_CLIENT); + + // Client takes all padded PING packet as speculative connectivity + // probing packet, and reports to visitor. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + ProcessReceivedPacket(kSelfAddress, kPeerAddress, *received); + + EXPECT_EQ( + num_probing_received + (GetParam().version.HasIetfQuicFrames() ? 1u : 0u), + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, ReceiveConnectivityProbingResponseAtClient) { + // TODO(b/150095484): add test coverage for IETF to verify that client takes + // PATH RESPONSE with peer address change as correct validation on the new + // path. + if (GetParam().version.HasIetfQuicFrames()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + PathProbeTestInit(Perspective::IS_CLIENT); + + // Process a padded PING packet with a different self address on client side + // is effectively receiving a connectivity probing. + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + if (!GetParam().version.HasIetfQuicFrames()) { + EXPECT_CALL(visitor_, + OnPacketReceived(_, _, /*is_connectivity_probe=*/true)) + .Times(1); + } else { + EXPECT_CALL(visitor_, OnPacketReceived(_, _, _)).Times(0); + } + + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + ProcessReceivedPacket(kNewSelfAddress, kPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, PeerAddressChangeAtClient) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + set_perspective(Perspective::IS_CLIENT); + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + if (connection_.version().HasIetfQuicFrames()) { + // Verify the 2nd packet from unknown server address gets dropped. + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(2); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(2); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kNewPeerAddress, ENCRYPTION_INITIAL); + if (connection_.version().HasIetfQuicFrames()) { + // IETF QUIC disallows server initiated address change. + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + } else { + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + } +} + +TEST_P(QuicConnectionTest, ServerAddressChangesToKnownAddress) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + set_perspective(Perspective::IS_CLIENT); + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + // Verify all 3 packets get processed. + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(3); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Process another packet with a different but known server address. + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + connection_.AddKnownServerAddress(kNewPeerAddress); + EXPECT_CALL(visitor_, OnConnectionMigration(_)).Times(0); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kNewPeerAddress, ENCRYPTION_INITIAL); + // Verify peer address does not change. + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Process 3rd packet from previous server address. + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + // Verify peer address does not change. + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); +} + +TEST_P(QuicConnectionTest, + PeerAddressChangesToPreferredAddressBeforeClientInitiates) { + if (!version().HasIetfQuicFrames()) { + return; + } + ASSERT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + ASSERT_TRUE(connection_.self_address().host().IsIPv6()); + SetQuicReloadableFlag(quic_connection_migration_use_new_cid_v2, true); + const QuicConnectionId connection_id = TestConnectionId(17); + const StatelessResetToken reset_token = + QuicUtils::GenerateStatelessResetToken(connection_id); + + connection_.CreateConnectionIdManager(); + + connection_.SendCryptoStreamData(); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = InitAckFrame(1); + // Received ACK for packet 1. + ProcessFramePacketAtLevel(1, QuicFrame(&frame), ENCRYPTION_INITIAL); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + QuicConfig config; + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM, kSPAD}); + QuicConfigPeer::SetReceivedStatelessResetToken(&config, + kTestStatelessResetToken); + QuicConfigPeer::SetReceivedAlternateServerAddress(&config, + kServerPreferredAddress); + QuicConfigPeer::SetPreferredAddressConnectionIdAndToken( + &config, connection_id, reset_token); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + ASSERT_TRUE( + QuicConnectionPeer::GetReceivedServerPreferredAddress(&connection_) + .IsInitialized()); + EXPECT_EQ( + kServerPreferredAddress, + QuicConnectionPeer::GetReceivedServerPreferredAddress(&connection_)); + + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(0); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, + kServerPreferredAddress, ENCRYPTION_INITIAL); +} + +TEST_P(QuicConnectionTest, MaxPacketSize) { + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + EXPECT_EQ(1250u, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, PeerLowersMaxPacketSize) { + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + constexpr uint32_t kTestMaxPacketSize = 1233u; + QuicConfig config; + QuicConfigPeer::SetReceivedMaxPacketSize(&config, kTestMaxPacketSize); + connection_.SetFromConfig(config); + + EXPECT_EQ(kTestMaxPacketSize, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, PeerCannotRaiseMaxPacketSize) { + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + constexpr uint32_t kTestMaxPacketSize = 1450u; + QuicConfig config; + QuicConfigPeer::SetReceivedMaxPacketSize(&config, kTestMaxPacketSize); + connection_.SetFromConfig(config); + + EXPECT_EQ(kDefaultMaxPacketSize, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, SmallerServerMaxPacketSize) { + TestConnection connection(TestConnectionId(), kSelfAddress, kPeerAddress, + helper_.get(), alarm_factory_.get(), writer_.get(), + Perspective::IS_SERVER, version(), + connection_id_generator_); + EXPECT_EQ(Perspective::IS_SERVER, connection.perspective()); + EXPECT_EQ(1000u, connection.max_packet_length()); +} + +TEST_P(QuicConnectionTest, LowerServerResponseMtuTest) { + set_perspective(Perspective::IS_SERVER); + connection_.SetMaxPacketLength(1000); + EXPECT_EQ(1000u, connection_.max_packet_length()); + + SetQuicFlag(quic_use_lower_server_response_mtu_for_test, true); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(::testing::AtMost(1)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(::testing::AtMost(1)); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_EQ(1250u, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, IncreaseServerMaxPacketSize) { + set_perspective(Perspective::IS_SERVER); + connection_.SetMaxPacketLength(1000); + + QuicPacketHeader header; + header.destination_connection_id = connection_id_; + header.version_flag = true; + header.packet_number = QuicPacketNumber(12); + + if (QuicVersionHasLongHeaderLengths( + peer_framer_.version().transport_version)) { + header.long_packet_type = INITIAL; + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + + QuicFrames frames; + QuicPaddingFrame padding; + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + frames.push_back(QuicFrame(&crypto_frame_)); + } else { + frames.push_back(QuicFrame(frame1_)); + } + frames.push_back(QuicFrame(padding)); + std::unique_ptr packet(ConstructPacket(header, frames)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(12), + *packet, buffer, kMaxOutgoingPacketSize); + EXPECT_EQ(kMaxOutgoingPacketSize, + encrypted_length + + (connection_.version().KnowsWhichDecrypterToUse() ? 0 : 4)); + + framer_.set_version(version()); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.ApproximateNow(), + false)); + + EXPECT_EQ(kMaxOutgoingPacketSize, + connection_.max_packet_length() + + (connection_.version().KnowsWhichDecrypterToUse() ? 0 : 4)); +} + +TEST_P(QuicConnectionTest, IncreaseServerMaxPacketSizeWhileWriterLimited) { + const QuicByteCount lower_max_packet_size = 1240; + writer_->set_max_packet_size(lower_max_packet_size); + set_perspective(Perspective::IS_SERVER); + connection_.SetMaxPacketLength(1000); + EXPECT_EQ(1000u, connection_.max_packet_length()); + + QuicPacketHeader header; + header.destination_connection_id = connection_id_; + header.version_flag = true; + header.packet_number = QuicPacketNumber(12); + + if (QuicVersionHasLongHeaderLengths( + peer_framer_.version().transport_version)) { + header.long_packet_type = INITIAL; + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + + QuicFrames frames; + QuicPaddingFrame padding; + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + frames.push_back(QuicFrame(&crypto_frame_)); + } else { + frames.push_back(QuicFrame(frame1_)); + } + frames.push_back(QuicFrame(padding)); + std::unique_ptr packet(ConstructPacket(header, frames)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(12), + *packet, buffer, kMaxOutgoingPacketSize); + EXPECT_EQ(kMaxOutgoingPacketSize, + encrypted_length + + (connection_.version().KnowsWhichDecrypterToUse() ? 0 : 4)); + + framer_.set_version(version()); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.ApproximateNow(), + false)); + + // Here, the limit imposed by the writer is lower than the size of the packet + // received, so the writer max packet size is used. + EXPECT_EQ(lower_max_packet_size, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, LimitMaxPacketSizeByWriter) { + const QuicByteCount lower_max_packet_size = 1240; + writer_->set_max_packet_size(lower_max_packet_size); + + static_assert(lower_max_packet_size < kDefaultMaxPacketSize, + "Default maximum packet size is too low"); + connection_.SetMaxPacketLength(kDefaultMaxPacketSize); + + EXPECT_EQ(lower_max_packet_size, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, LimitMaxPacketSizeByWriterForNewConnection) { + const QuicConnectionId connection_id = TestConnectionId(17); + const QuicByteCount lower_max_packet_size = 1240; + writer_->set_max_packet_size(lower_max_packet_size); + TestConnection connection(connection_id, kSelfAddress, kPeerAddress, + helper_.get(), alarm_factory_.get(), writer_.get(), + Perspective::IS_CLIENT, version(), + connection_id_generator_); + EXPECT_EQ(Perspective::IS_CLIENT, connection.perspective()); + EXPECT_EQ(lower_max_packet_size, connection.max_packet_length()); +} + +TEST_P(QuicConnectionTest, PacketsInOrder) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + ProcessPacket(1); + EXPECT_EQ(QuicPacketNumber(1u), LargestAcked(connection_.ack_frame())); + EXPECT_EQ(1u, connection_.ack_frame().packets.NumIntervals()); + + ProcessPacket(2); + EXPECT_EQ(QuicPacketNumber(2u), LargestAcked(connection_.ack_frame())); + EXPECT_EQ(1u, connection_.ack_frame().packets.NumIntervals()); + + ProcessPacket(3); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_EQ(1u, connection_.ack_frame().packets.NumIntervals()); +} + +TEST_P(QuicConnectionTest, PacketsOutOfOrder) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + ProcessPacket(3); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_TRUE(IsMissing(2)); + EXPECT_TRUE(IsMissing(1)); + + ProcessPacket(2); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_FALSE(IsMissing(2)); + EXPECT_TRUE(IsMissing(1)); + + ProcessPacket(1); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_FALSE(IsMissing(2)); + EXPECT_FALSE(IsMissing(1)); +} + +TEST_P(QuicConnectionTest, DuplicatePacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + ProcessPacket(3); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_TRUE(IsMissing(2)); + EXPECT_TRUE(IsMissing(1)); + + // Send packet 3 again, but do not set the expectation that + // the visitor OnStreamFrame() will be called. + ProcessDataPacket(3); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_TRUE(IsMissing(2)); + EXPECT_TRUE(IsMissing(1)); +} + +TEST_P(QuicConnectionTest, PacketsOutOfOrderWithAdditionsAndLeastAwaiting) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + ProcessPacket(3); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_TRUE(IsMissing(2)); + EXPECT_TRUE(IsMissing(1)); + + ProcessPacket(2); + EXPECT_EQ(QuicPacketNumber(3u), LargestAcked(connection_.ack_frame())); + EXPECT_TRUE(IsMissing(1)); + + ProcessPacket(5); + EXPECT_EQ(QuicPacketNumber(5u), LargestAcked(connection_.ack_frame())); + EXPECT_TRUE(IsMissing(1)); + EXPECT_TRUE(IsMissing(4)); + + // Pretend at this point the client has gotten acks for 2 and 3 and 1 is a + // packet the peer will not retransmit. It indicates this by sending 'least + // awaiting' is 4. The connection should then realize 1 will not be + // retransmitted, and will remove it from the missing list. + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessAckPacket(6, &frame); + + // Force an ack to be sent. + SendAckPacketToPeer(); + EXPECT_TRUE(IsMissing(4)); +} + +TEST_P(QuicConnectionTest, RejectUnencryptedStreamData) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration() || + VersionHasIetfQuicFrames(version().transport_version)) { + return; + } + + // Process an unencrypted packet from the non-crypto stream. + frame1_.stream_id = 3; + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_QUIC_PEER_BUG(ProcessDataPacketAtLevel(1, false, ENCRYPTION_INITIAL), + ""); + TestConnectionCloseQuicErrorCode(QUIC_UNENCRYPTED_STREAM_DATA); +} + +TEST_P(QuicConnectionTest, OutOfOrderReceiptCausesAckSend) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + ProcessPacket(3); + // Should not cause an ack. + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + ProcessPacket(2); + // Should ack immediately, since this fills the last hole. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + ProcessPacket(1); + // Should ack immediately, since this fills the last hole. + EXPECT_EQ(2u, writer_->packets_write_attempts()); + + ProcessPacket(4); + // Should not cause an ack. + EXPECT_EQ(2u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, OutOfOrderAckReceiptCausesNoAck) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); + SendStreamDataToPeer(1, "bar", 3, NO_FIN, nullptr); + EXPECT_EQ(2u, writer_->packets_write_attempts()); + + QuicAckFrame ack1 = InitAckFrame(1); + QuicAckFrame ack2 = InitAckFrame(2); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + if (connection_.SupportsMultiplePacketNumberSpaces()) { + EXPECT_CALL(visitor_, OnOneRttPacketAcknowledged()).Times(1); + } + ProcessAckPacket(2, &ack2); + // Should ack immediately since we have missing packets. + EXPECT_EQ(2u, writer_->packets_write_attempts()); + + if (connection_.SupportsMultiplePacketNumberSpaces()) { + EXPECT_CALL(visitor_, OnOneRttPacketAcknowledged()).Times(0); + } + ProcessAckPacket(1, &ack1); + // Should not ack an ack filling a missing packet. + EXPECT_EQ(2u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, AckReceiptCausesAckSend) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + QuicPacketNumber original, second; + + QuicByteCount packet_size = + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &original); // 1st packet. + SendStreamDataToPeer(3, "bar", 3, NO_FIN, &second); // 2nd packet. + + QuicAckFrame frame = InitAckFrame({{second, second + 1}}); + // First nack triggers early retransmit. + LostPacketVector lost_packets; + lost_packets.push_back(LostPacket(original, kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicPacketNumber retransmission; + // Packet 1 is short header for IETF QUIC because the encryption level + // switched to ENCRYPTION_FORWARD_SECURE in SendStreamDataToPeer. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, + GetParam().version.HasIetfInvariantHeader() + ? packet_size + : packet_size - kQuicVersionSize, + _)) + .WillOnce(SaveArg<2>(&retransmission)); + + ProcessAckPacket(&frame); + + QuicAckFrame frame2 = ConstructAckFrame(retransmission, original); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + ProcessAckPacket(&frame2); + + // Now if the peer sends an ack which still reports the retransmitted packet + // as missing, that will bundle an ack with data after two acks in a row + // indicate the high water mark needs to be raised. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, HAS_RETRANSMITTABLE_DATA)); + connection_.SendStreamDataWithString(3, "foo", 6, NO_FIN); + // No ack sent. + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->stream_frames().size()); + + // No more packet loss for the rest of the test. + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .Times(AnyNumber()); + ProcessAckPacket(&frame2); + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, HAS_RETRANSMITTABLE_DATA)); + connection_.SendStreamDataWithString(3, "foofoofoo", 9, NO_FIN); + // Ack bundled. + if (GetParam().no_stop_waiting) { + // Do not ACK acks. + EXPECT_EQ(1u, writer_->frame_count()); + } else { + EXPECT_EQ(3u, writer_->frame_count()); + } + EXPECT_EQ(1u, writer_->stream_frames().size()); + if (GetParam().no_stop_waiting) { + EXPECT_TRUE(writer_->ack_frames().empty()); + } else { + EXPECT_FALSE(writer_->ack_frames().empty()); + } + + // But an ack with no missing packets will not send an ack. + AckPacket(original, &frame2); + ProcessAckPacket(&frame2); + ProcessAckPacket(&frame2); +} + +TEST_P(QuicConnectionTest, AckFrequencyUpdatedFromAckFrequencyFrame) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + connection_.set_can_receive_ack_frequency_frame(); + + // Expect 13 acks, every 3rd packet including the first packet with + // AckFrequencyFrame. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(13); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + QuicAckFrequencyFrame ack_frequency_frame; + ack_frequency_frame.packet_tolerance = 3; + ProcessFramePacketAtLevel(1, QuicFrame(&ack_frequency_frame), + ENCRYPTION_FORWARD_SECURE); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(38); + // Receives packets 2 - 39. + for (size_t i = 2; i <= 39; ++i) { + ProcessDataPacket(i); + } +} + +TEST_P(QuicConnectionTest, AckDecimationReducesAcks) { + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()).Times(AnyNumber()); + + // Start ack decimation from 10th packet. + connection_.set_min_received_before_ack_decimation(10); + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(30); + + // Expect 6 acks: 5 acks between packets 1-10, and ack at 20. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(6); + // Receives packets 1 - 29. + for (size_t i = 1; i <= 29; ++i) { + ProcessDataPacket(i); + } + + // We now receive the 30th packet, and so we send an ack. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessDataPacket(30); +} + +TEST_P(QuicConnectionTest, AckNeedsRetransmittableFrames) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(99); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(19); + // Receives packets 1 - 39. + for (size_t i = 1; i <= 39; ++i) { + ProcessDataPacket(i); + } + // Receiving Packet 40 causes 20th ack to send. Session is informed and adds + // WINDOW_UPDATE. + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()) + .WillOnce(Invoke([this]() { + connection_.SendControlFrame(QuicFrame(QuicWindowUpdateFrame(1, 0, 0))); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_EQ(0u, writer_->window_update_frames().size()); + ProcessDataPacket(40); + EXPECT_EQ(1u, writer_->window_update_frames().size()); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(9); + // Receives packets 41 - 59. + for (size_t i = 41; i <= 59; ++i) { + ProcessDataPacket(i); + } + // Send a packet containing stream frame. + SendStreamDataToPeer( + QuicUtils::GetFirstBidirectionalStreamId( + connection_.version().transport_version, Perspective::IS_CLIENT), + "bar", 0, NO_FIN, nullptr); + + // Session will not be informed until receiving another 20 packets. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(19); + for (size_t i = 60; i <= 98; ++i) { + ProcessDataPacket(i); + EXPECT_EQ(0u, writer_->window_update_frames().size()); + } + // Session does not add a retransmittable frame. + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()) + .WillOnce(Invoke([this]() { + connection_.SendControlFrame(QuicFrame(QuicPingFrame(1))); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_EQ(0u, writer_->ping_frames().size()); + ProcessDataPacket(99); + EXPECT_EQ(0u, writer_->window_update_frames().size()); + // A ping frame will be added. + EXPECT_EQ(1u, writer_->ping_frames().size()); +} + +TEST_P(QuicConnectionTest, AckNeedsRetransmittableFramesAfterPto) { + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kEACK); + config.SetConnectionOptionsToSend(connection_options); + connection_.SetFromConfig(config); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(10); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(4); + // Receive packets 1 - 9. + for (size_t i = 1; i <= 9; ++i) { + ProcessDataPacket(i); + } + + // Send a ping and fire the retransmission alarm. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + SendPing(); + QuicTime retransmission_time = + connection_.GetRetransmissionAlarm()->deadline(); + clock_.AdvanceTime(retransmission_time - clock_.Now()); + connection_.GetRetransmissionAlarm()->Fire(); + ASSERT_LT(0u, manager_->GetConsecutivePtoCount()); + + // Process a packet, which requests a retransmittable frame be bundled + // with the ACK. + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()) + .WillOnce(Invoke([this]() { + connection_.SendControlFrame(QuicFrame(QuicWindowUpdateFrame(1, 0, 0))); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessDataPacket(11); + EXPECT_EQ(1u, writer_->window_update_frames().size()); +} + +TEST_P(QuicConnectionTest, LeastUnackedLower) { + if (GetParam().version.HasIetfInvariantHeader()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); + SendStreamDataToPeer(1, "bar", 3, NO_FIN, nullptr); + SendStreamDataToPeer(1, "eep", 6, NO_FIN, nullptr); + + // Start out saying the least unacked is 2. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 5); + ProcessStopWaitingPacket(InitStopWaitingFrame(2)); + + // Change it to 1, but lower the packet number to fake out-of-order packets. + // This should be fine. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 1); + // The scheduler will not process out of order acks, but all packet processing + // causes the connection to try to write. + if (!GetParam().no_stop_waiting) { + EXPECT_CALL(visitor_, OnCanWrite()); + } + ProcessStopWaitingPacket(InitStopWaitingFrame(1)); + + // Now claim it's one, but set the ordering so it was sent "after" the first + // one. This should cause a connection error. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 7); + if (!GetParam().no_stop_waiting) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1)); + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(AtLeast(1)); + } + ProcessStopWaitingPacket(InitStopWaitingFrame(1)); + if (!GetParam().no_stop_waiting) { + TestConnectionCloseQuicErrorCode(QUIC_INVALID_STOP_WAITING_DATA); + } +} + +TEST_P(QuicConnectionTest, TooManySentPackets) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + QuicPacketCount max_tracked_packets = 50; + QuicConnectionPeer::SetMaxTrackedPackets(&connection_, max_tracked_packets); + + const int num_packets = max_tracked_packets + 5; + + for (int i = 0; i < num_packets; ++i) { + SendStreamDataToPeer(1, "foo", 3 * i, NO_FIN, nullptr); + } + + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + + ProcessFramePacket(QuicFrame(QuicPingFrame())); + + TestConnectionCloseQuicErrorCode(QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS); +} + +TEST_P(QuicConnectionTest, LargestObservedLower) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); + SendStreamDataToPeer(1, "bar", 3, NO_FIN, nullptr); + SendStreamDataToPeer(1, "eep", 6, NO_FIN, nullptr); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + + // Start out saying the largest observed is 2. + QuicAckFrame frame1 = InitAckFrame(1); + QuicAckFrame frame2 = InitAckFrame(2); + ProcessAckPacket(&frame2); + + EXPECT_CALL(visitor_, OnCanWrite()); + ProcessAckPacket(&frame1); +} + +TEST_P(QuicConnectionTest, AckUnsentData) { + // Ack a packet which has not been sent. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(visitor_, OnCanWrite()).Times(0); + ProcessAckPacket(&frame); + TestConnectionCloseQuicErrorCode(QUIC_INVALID_ACK_DATA); +} + +TEST_P(QuicConnectionTest, BasicSending) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + const QuicConnectionStats& stats = connection_.GetStats(); + EXPECT_FALSE(stats.first_decrypted_packet.IsInitialized()); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(1); + EXPECT_EQ(QuicPacketNumber(1), stats.first_decrypted_packet); + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 2); + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); // Packet 1 + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + SendAckPacketToPeer(); // Packet 2 + + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(1u), least_unacked()); + } + + SendAckPacketToPeer(); // Packet 3 + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(1u), least_unacked()); + } + + SendStreamDataToPeer(1, "bar", 3, NO_FIN, &last_packet); // Packet 4 + EXPECT_EQ(QuicPacketNumber(4u), last_packet); + SendAckPacketToPeer(); // Packet 5 + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(1u), least_unacked()); + } + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + + // Peer acks up to packet 3. + QuicAckFrame frame = InitAckFrame(3); + ProcessAckPacket(&frame); + SendAckPacketToPeer(); // Packet 6 + + // As soon as we've acked one, we skip ack packets 2 and 3 and note lack of + // ack for 4. + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(4u), least_unacked()); + } + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + + // Peer acks up to packet 4, the last packet. + QuicAckFrame frame2 = InitAckFrame(6); + ProcessAckPacket(&frame2); // Acks don't instigate acks. + + // Verify that we did not send an ack. + EXPECT_EQ(QuicPacketNumber(6u), writer_->header().packet_number); + + // So the last ack has not changed. + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(4u), least_unacked()); + } + + // If we force an ack, we shouldn't change our retransmit state. + SendAckPacketToPeer(); // Packet 7 + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(7u), least_unacked()); + } + + // But if we send more data it should. + SendStreamDataToPeer(1, "eep", 6, NO_FIN, &last_packet); // Packet 8 + EXPECT_EQ(QuicPacketNumber(8u), last_packet); + SendAckPacketToPeer(); // Packet 9 + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(7u), least_unacked()); + } + EXPECT_EQ(QuicPacketNumber(1), stats.first_decrypted_packet); +} + +// QuicConnection should record the packet sent-time prior to sending the +// packet. +TEST_P(QuicConnectionTest, RecordSentTimeBeforePacketSent) { + // We're using a MockClock for the tests, so we have complete control over the + // time. + // Our recorded timestamp for the last packet sent time will be passed in to + // the send_algorithm. Make sure that it is set to the correct value. + QuicTime actual_recorded_send_time = QuicTime::Zero(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<0>(&actual_recorded_send_time)); + + // First send without any pause and check the result. + QuicTime expected_recorded_send_time = clock_.Now(); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_EQ(expected_recorded_send_time, actual_recorded_send_time) + << "Expected time = " << expected_recorded_send_time.ToDebuggingValue() + << ". Actual time = " << actual_recorded_send_time.ToDebuggingValue(); + + // Now pause during the write, and check the results. + actual_recorded_send_time = QuicTime::Zero(); + const QuicTime::Delta write_pause_time_delta = + QuicTime::Delta::FromMilliseconds(5000); + SetWritePauseTimeDelta(write_pause_time_delta); + expected_recorded_send_time = clock_.Now(); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<0>(&actual_recorded_send_time)); + connection_.SendStreamDataWithString(2, "baz", 0, NO_FIN); + EXPECT_EQ(expected_recorded_send_time, actual_recorded_send_time) + << "Expected time = " << expected_recorded_send_time.ToDebuggingValue() + << ". Actual time = " << actual_recorded_send_time.ToDebuggingValue(); +} + +TEST_P(QuicConnectionTest, ConnectionStatsRetransmission_WithRetransmissions) { + // Send two stream frames in 1 packet by queueing them. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), + "helloworld", 0, NO_FIN, PTO_RETRANSMISSION); + connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(2, connection_.transport_version()), + "helloworld", 0, NO_FIN, LOSS_RETRANSMISSION); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + } + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + EXPECT_EQ(2u, writer_->frame_count()); + for (auto& frame : writer_->stream_frames()) { + EXPECT_EQ(frame->data_length, 10u); + } + + ASSERT_EQ(connection_.GetStats().packets_retransmitted, 1u); + ASSERT_GE(connection_.GetStats().bytes_retransmitted, 20u); +} + +TEST_P(QuicConnectionTest, ConnectionStatsRetransmission_WithMixedFrames) { + // Send two stream frames in 1 packet by queueing them. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // First frame is retransmission. Second is NOT_RETRANSMISSION but the + // packet retains the PTO_RETRANSMISSION type. + connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), + "helloworld", 0, NO_FIN, PTO_RETRANSMISSION); + connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(2, connection_.transport_version()), + "helloworld", 0, NO_FIN, NOT_RETRANSMISSION); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + } + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + EXPECT_EQ(2u, writer_->frame_count()); + for (auto& frame : writer_->stream_frames()) { + EXPECT_EQ(frame->data_length, 10u); + } + + ASSERT_EQ(connection_.GetStats().packets_retransmitted, 1u); + ASSERT_GE(connection_.GetStats().bytes_retransmitted, 10u); +} + +TEST_P(QuicConnectionTest, ConnectionStatsRetransmission_NoRetransmission) { + // Send two stream frames in 1 packet by queueing them. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Both frames are NOT_RETRANSMISSION + connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), + "helloworld", 0, NO_FIN, NOT_RETRANSMISSION); + connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(2, connection_.transport_version()), + "helloworld", 0, NO_FIN, NOT_RETRANSMISSION); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + } + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + EXPECT_EQ(2u, writer_->frame_count()); + ASSERT_EQ(connection_.GetStats().packets_retransmitted, 0u); + ASSERT_EQ(connection_.GetStats().bytes_retransmitted, 0u); +} + +TEST_P(QuicConnectionTest, FramePacking) { + // Send two stream frames in 1 packet by queueing them. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendStreamData3(); + connection_.SendStreamData5(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + } + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's an ack and two stream frames from + // two different streams. + if (GetParam().no_stop_waiting) { + EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } + + EXPECT_TRUE(writer_->ack_frames().empty()); + + ASSERT_EQ(2u, writer_->stream_frames().size()); + EXPECT_EQ(GetNthClientInitiatedStreamId(1, connection_.transport_version()), + writer_->stream_frames()[0]->stream_id); + EXPECT_EQ(GetNthClientInitiatedStreamId(2, connection_.transport_version()), + writer_->stream_frames()[1]->stream_id); +} + +TEST_P(QuicConnectionTest, FramePackingNonCryptoThenCrypto) { + // Send two stream frames (one non-crypto, then one crypto) in 2 packets by + // queueing them. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendStreamData3(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Set the crypters for INITIAL packets in the TestPacketWriter. + if (!connection_.version().KnowsWhichDecrypterToUse()) { + writer_->framer()->framer()->SetAlternativeDecrypter( + ENCRYPTION_INITIAL, + std::make_unique(Perspective::IS_SERVER), false); + } + connection_.SendCryptoStreamData(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it contains a crypto stream frame. + EXPECT_LE(2u, writer_->frame_count()); + ASSERT_LE(1u, writer_->padding_frames().size()); + if (!QuicVersionUsesCryptoFrames(connection_.transport_version())) { + ASSERT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(QuicUtils::GetCryptoStreamId(connection_.transport_version()), + writer_->stream_frames()[0]->stream_id); + } else { + EXPECT_LE(1u, writer_->crypto_frames().size()); + } +} + +TEST_P(QuicConnectionTest, FramePackingCryptoThenNonCrypto) { + // Send two stream frames (one crypto, then one non-crypto) in 2 packets by + // queueing them. + { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendCryptoStreamData(); + connection_.SendStreamData3(); + } + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's the stream frame from stream 3. + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + ASSERT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(GetNthClientInitiatedStreamId(1, connection_.transport_version()), + writer_->stream_frames()[0]->stream_id); +} + +TEST_P(QuicConnectionTest, FramePackingAckResponse) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + // Process a data packet to queue up a pending ack. + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + + QuicPacketNumber last_packet; + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + connection_.SendCryptoDataWithString("foo", 0); + } else { + SendStreamDataToPeer( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), "foo", 0, + NO_FIN, &last_packet); + } + // Verify ack is bundled with outging packet. + EXPECT_FALSE(writer_->ack_frames().empty()); + + EXPECT_CALL(visitor_, OnCanWrite()) + .WillOnce(DoAll(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendStreamData3)), + IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendStreamData5)))); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + + // Process a data packet to cause the visitor's OnCanWrite to be invoked. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + ProcessDataPacket(2); + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's an ack and two stream frames from + // two different streams. + if (GetParam().no_stop_waiting) { + EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(4u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(writer_->ack_frames().empty()); + ASSERT_EQ(2u, writer_->stream_frames().size()); + EXPECT_EQ(GetNthClientInitiatedStreamId(1, connection_.transport_version()), + writer_->stream_frames()[0]->stream_id); + EXPECT_EQ(GetNthClientInitiatedStreamId(2, connection_.transport_version()), + writer_->stream_frames()[1]->stream_id); +} + +TEST_P(QuicConnectionTest, FramePackingSendv) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + connection_.transport_version(), Perspective::IS_CLIENT); + connection_.SaveAndSendStreamData(stream_id, "ABCDEF", 0, NO_FIN); + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure multiple iovector blocks have + // been packed into a single stream frame from one stream. + EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(0u, writer_->padding_frames().size()); + QuicStreamFrame* frame = writer_->stream_frames()[0].get(); + EXPECT_EQ(stream_id, frame->stream_id); + EXPECT_EQ("ABCDEF", + absl::string_view(frame->data_buffer, frame->data_length)); +} + +TEST_P(QuicConnectionTest, FramePackingSendvQueued) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + + BlockOnNextWrite(); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + connection_.transport_version(), Perspective::IS_CLIENT); + connection_.SaveAndSendStreamData(stream_id, "ABCDEF", 0, NO_FIN); + + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + EXPECT_TRUE(connection_.HasQueuedData()); + + // Unblock the writes and actually send. + writer_->SetWritable(); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + + // Parse the last packet and ensure it's one stream frame from one stream. + EXPECT_EQ(1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(0u, writer_->padding_frames().size()); + QuicStreamFrame* frame = writer_->stream_frames()[0].get(); + EXPECT_EQ(stream_id, frame->stream_id); + EXPECT_EQ("ABCDEF", + absl::string_view(frame->data_buffer, frame->data_length)); +} + +TEST_P(QuicConnectionTest, SendingZeroBytes) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Send a zero byte write with a fin using writev. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + connection_.transport_version(), Perspective::IS_CLIENT); + connection_.SaveAndSendStreamData(stream_id, {}, 0, FIN); + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Padding frames are added by v99 to ensure a minimum packet size. + size_t extra_padding_frames = 0; + if (GetParam().version.HasHeaderProtection()) { + extra_padding_frames = 1; + } + + // Parse the last packet and ensure it's one stream frame from one stream. + EXPECT_EQ(1u + extra_padding_frames, writer_->frame_count()); + EXPECT_EQ(extra_padding_frames, writer_->padding_frames().size()); + ASSERT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(stream_id, writer_->stream_frames()[0]->stream_id); + EXPECT_TRUE(writer_->stream_frames()[0]->fin); +} + +TEST_P(QuicConnectionTest, LargeSendWithPendingAck) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + // Set the ack alarm by processing a ping frame. + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Processs a PING frame. + ProcessFramePacket(QuicFrame(QuicPingFrame())); + // Ensure that this has caused the ACK alarm to be set. + EXPECT_TRUE(connection_.HasPendingAcks()); + + // Send data and ensure the ack is bundled. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(9); + const std::string data(10000, '?'); + QuicConsumedData consumed = connection_.SaveAndSendStreamData( + GetNthClientInitiatedStreamId(0, connection_.transport_version()), data, + 0, FIN); + EXPECT_EQ(data.length(), consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's one stream frame with a fin. + EXPECT_EQ(1u, writer_->frame_count()); + ASSERT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(GetNthClientInitiatedStreamId(0, connection_.transport_version()), + writer_->stream_frames()[0]->stream_id); + EXPECT_TRUE(writer_->stream_frames()[0]->fin); + // Ensure the ack alarm was cancelled when the ack was sent. + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, OnCanWrite) { + // Visitor's OnCanWrite will send data, but will have more pending writes. + EXPECT_CALL(visitor_, OnCanWrite()) + .WillOnce(DoAll(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendStreamData3)), + IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendStreamData5)))); + { + InSequence seq; + EXPECT_CALL(visitor_, WillingAndAbleToWrite()).WillOnce(Return(true)); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()) + .WillRepeatedly(Return(false)); + } + + EXPECT_CALL(*send_algorithm_, CanSend(_)) + .WillRepeatedly(testing::Return(true)); + + connection_.OnCanWrite(); + + // Parse the last packet and ensure it's the two stream frames from + // two different streams. + EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_EQ(2u, writer_->stream_frames().size()); + EXPECT_EQ(GetNthClientInitiatedStreamId(1, connection_.transport_version()), + writer_->stream_frames()[0]->stream_id); + EXPECT_EQ(GetNthClientInitiatedStreamId(2, connection_.transport_version()), + writer_->stream_frames()[1]->stream_id); +} + +TEST_P(QuicConnectionTest, RetransmitOnNack) { + QuicPacketNumber last_packet; + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(3, "foos", 3, NO_FIN, &last_packet); + SendStreamDataToPeer(3, "fooos", 7, NO_FIN, &last_packet); + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Don't lose a packet on an ack, and nothing is retransmitted. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame ack_one = InitAckFrame(1); + ProcessAckPacket(&ack_one); + + // Lose a packet and ensure it triggers retransmission. + QuicAckFrame nack_two = ConstructAckFrame(3, 2); + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(2), kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_FALSE(QuicPacketCreatorPeer::SendVersionInPacket(creator_)); + ProcessAckPacket(&nack_two); +} + +TEST_P(QuicConnectionTest, DoNotSendQueuedPacketForResetStream) { + // Block the connection to queue the packet. + BlockOnNextWrite(); + + QuicStreamId stream_id = 2; + connection_.SendStreamDataWithString(stream_id, "foo", 0, NO_FIN); + + // Now that there is a queued packet, reset the stream. + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 3); + + // Unblock the connection and verify that only the RST_STREAM is sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + writer_->SetWritable(); + connection_.OnCanWrite(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->rst_stream_frames().size()); +} + +TEST_P(QuicConnectionTest, SendQueuedPacketForQuicRstStreamNoError) { + // Block the connection to queue the packet. + BlockOnNextWrite(); + + QuicStreamId stream_id = 2; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(stream_id, "foo", 0, NO_FIN); + + // Now that there is a queued packet, reset the stream. + SendRstStream(stream_id, QUIC_STREAM_NO_ERROR, 3); + + // Unblock the connection and verify that the RST_STREAM is sent and the data + // packet is sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + writer_->SetWritable(); + connection_.OnCanWrite(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->rst_stream_frames().size()); +} + +TEST_P(QuicConnectionTest, DoNotRetransmitForResetStreamOnNack) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "foos", 3, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "fooos", 7, NO_FIN, &last_packet); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 12); + + // Lose a packet and ensure it does not trigger retransmission. + QuicAckFrame nack_two = ConstructAckFrame(last_packet, last_packet - 1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessAckPacket(&nack_two); +} + +TEST_P(QuicConnectionTest, RetransmitForQuicRstStreamNoErrorOnNack) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "foos", 3, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "fooos", 7, NO_FIN, &last_packet); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + SendRstStream(stream_id, QUIC_STREAM_NO_ERROR, 12); + + // Lose a packet, ensure it triggers retransmission. + QuicAckFrame nack_two = ConstructAckFrame(last_packet, last_packet - 1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + LostPacketVector lost_packets; + lost_packets.push_back(LostPacket(last_packet - 1, kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + ProcessAckPacket(&nack_two); +} + +TEST_P(QuicConnectionTest, DoNotRetransmitForResetStreamOnRTO) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_packet); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 3); + + // Fire the RTO and verify that the RST_STREAM is resent, not stream data. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + clock_.AdvanceTime(DefaultRetransmissionTime()); + connection_.GetRetransmissionAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->rst_stream_frames().size()); + EXPECT_EQ(stream_id, writer_->rst_stream_frames().front().stream_id); +} + +// Ensure that if the only data in flight is non-retransmittable, the +// retransmission alarm is not set. +TEST_P(QuicConnectionTest, CancelRetransmissionAlarmAfterResetStream) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_data_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_data_packet); + + // Cancel the stream. + const QuicPacketNumber rst_packet = last_data_packet + 1; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, rst_packet, _, _)).Times(1); + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 3); + + // Ack the RST_STREAM frame (since it's retransmittable), but not the data + // packet, which is no longer retransmittable since the stream was cancelled. + QuicAckFrame nack_stream_data = + ConstructAckFrame(rst_packet, last_data_packet); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessAckPacket(&nack_stream_data); + + // Ensure that the data is still in flight, but the retransmission alarm is no + // longer set. + EXPECT_GT(manager_->GetBytesInFlight(), 0u); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, RetransmitForQuicRstStreamNoErrorOnPTO) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_packet); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + SendRstStream(stream_id, QUIC_STREAM_NO_ERROR, 3); + + // Fire the RTO and verify that the RST_STREAM is resent, the stream data + // is sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + clock_.AdvanceTime(DefaultRetransmissionTime()); + connection_.GetRetransmissionAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); +} + +TEST_P(QuicConnectionTest, DoNotSendPendingRetransmissionForResetStream) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "foos", 3, NO_FIN, &last_packet); + BlockOnNextWrite(); + connection_.SendStreamDataWithString(stream_id, "fooos", 7, NO_FIN); + + // Lose a packet which will trigger a pending retransmission. + QuicAckFrame ack = ConstructAckFrame(last_packet, last_packet - 1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessAckPacket(&ack); + + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 12); + + // Unblock the connection and verify that the RST_STREAM is sent but not the + // second data packet nor a retransmit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + writer_->SetWritable(); + connection_.OnCanWrite(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + ASSERT_EQ(1u, writer_->rst_stream_frames().size()); + EXPECT_EQ(stream_id, writer_->rst_stream_frames().front().stream_id); +} + +TEST_P(QuicConnectionTest, SendPendingRetransmissionForQuicRstStreamNoError) { + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "foos", 3, NO_FIN, &last_packet); + BlockOnNextWrite(); + connection_.SendStreamDataWithString(stream_id, "fooos", 7, NO_FIN); + + // Lose a packet which will trigger a pending retransmission. + QuicAckFrame ack = ConstructAckFrame(last_packet, last_packet - 1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + LostPacketVector lost_packets; + lost_packets.push_back(LostPacket(last_packet - 1, kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessAckPacket(&ack); + + SendRstStream(stream_id, QUIC_STREAM_NO_ERROR, 12); + + // Unblock the connection and verify that the RST_STREAM is sent and the + // second data packet or a retransmit is sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(2)); + writer_->SetWritable(); + connection_.OnCanWrite(); + // The RST_STREAM_FRAME is sent after queued packets and pending + // retransmission. + connection_.SendControlFrame(QuicFrame( + new QuicRstStreamFrame(1, stream_id, QUIC_STREAM_NO_ERROR, 14))); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->rst_stream_frames().size()); +} + +TEST_P(QuicConnectionTest, RetransmitAckedPacket) { + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); // Packet 1 + SendStreamDataToPeer(1, "foos", 3, NO_FIN, &last_packet); // Packet 2 + SendStreamDataToPeer(1, "fooos", 7, NO_FIN, &last_packet); // Packet 3 + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Instigate a loss with an ack. + QuicAckFrame nack_two = ConstructAckFrame(3, 2); + // The first nack should trigger a fast retransmission, but we'll be + // write blocked, so the packet will be queued. + BlockOnNextWrite(); + + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(2), kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(4), _, _)) + .Times(1); + ProcessAckPacket(&nack_two); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Now, ack the previous transmission. + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(false, _, _, _, _, _, _)); + QuicAckFrame ack_all = InitAckFrame(3); + ProcessAckPacket(&ack_all); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(4), _, _)) + .Times(0); + + writer_->SetWritable(); + connection_.OnCanWrite(); + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + // We do not store retransmittable frames of this retransmission. + EXPECT_FALSE(QuicConnectionPeer::HasRetransmittableFrames(&connection_, 4)); +} + +TEST_P(QuicConnectionTest, RetransmitNackedLargestObserved) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + QuicPacketNumber original, second; + + QuicByteCount packet_size = + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &original); // 1st packet. + SendStreamDataToPeer(3, "bar", 3, NO_FIN, &second); // 2nd packet. + + QuicAckFrame frame = InitAckFrame({{second, second + 1}}); + // The first nack should retransmit the largest observed packet. + LostPacketVector lost_packets; + lost_packets.push_back(LostPacket(original, kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + // Packet 1 is short header for IETF QUIC because the encryption level + // switched to ENCRYPTION_FORWARD_SECURE in SendStreamDataToPeer. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, + GetParam().version.HasIetfInvariantHeader() + ? packet_size + : packet_size - kQuicVersionSize, + _)); + ProcessAckPacket(&frame); +} + +TEST_P(QuicConnectionTest, WriteBlockedBufferedThenSent) { + BlockOnNextWrite(); + writer_->set_is_write_blocked_data_buffered(true); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + writer_->SetWritable(); + connection_.OnCanWrite(); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, WriteBlockedThenSent) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + BlockOnNextWrite(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // The second packet should also be queued, in order to ensure packets are + // never sent out of order. + writer_->SetWritable(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_EQ(2u, connection_.NumQueuedPackets()); + + // Now both are sent in order when we unblock. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.OnCanWrite(); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); +} + +TEST_P(QuicConnectionTest, RetransmitWriteBlockedAckedOriginalThenSent) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + BlockOnNextWrite(); + writer_->set_is_write_blocked_data_buffered(true); + // Simulate the retransmission alarm firing. + clock_.AdvanceTime(DefaultRetransmissionTime()); + connection_.GetRetransmissionAlarm()->Fire(); + + // Ack the sent packet before the callback returns, which happens in + // rare circumstances with write blocked sockets. + QuicAckFrame ack = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&ack); + + writer_->SetWritable(); + connection_.OnCanWrite(); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_FALSE(QuicConnectionPeer::HasRetransmittableFrames(&connection_, 3)); +} + +TEST_P(QuicConnectionTest, AlarmsWhenWriteBlocked) { + // Block the connection. + BlockOnNextWrite(); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_TRUE(writer_->IsWriteBlocked()); + + // Set the send alarm. Fire the alarm and ensure it doesn't attempt to write. + connection_.GetSendAlarm()->Set(clock_.ApproximateNow()); + connection_.GetSendAlarm()->Fire(); + EXPECT_TRUE(writer_->IsWriteBlocked()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, NoSendAlarmAfterProcessPacketWhenWriteBlocked) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Block the connection. + BlockOnNextWrite(); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + EXPECT_TRUE(writer_->IsWriteBlocked()); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + // Process packet number 1. Can not call ProcessPacket or ProcessDataPacket + // here, because they will fire the alarm after QuicConnection::ProcessPacket + // is returned. + const uint64_t received_packet_num = 1; + const bool has_stop_waiting = false; + const EncryptionLevel level = ENCRYPTION_FORWARD_SECURE; + std::unique_ptr packet( + ConstructDataPacket(received_packet_num, has_stop_waiting, level)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(level, QuicPacketNumber(received_packet_num), + *packet, buffer, kMaxOutgoingPacketSize); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false)); + + EXPECT_TRUE(writer_->IsWriteBlocked()); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, AddToWriteBlockedListIfWriterBlockedWhenProcessing) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); + + // Simulate the case where a shared writer gets blocked by another connection. + writer_->SetWriteBlocked(); + + // Process an ACK, make sure the connection calls visitor_->OnWriteBlocked(). + QuicAckFrame ack1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(1); + ProcessAckPacket(1, &ack1); +} + +TEST_P(QuicConnectionTest, DoNotAddToWriteBlockedListAfterDisconnect) { + writer_->SetBatchMode(true); + EXPECT_TRUE(connection_.connected()); + // Have to explicitly grab the OnConnectionClosed frame and check + // its parameters because this is a silent connection close and the + // frame is not also transmitted to the peer. + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(0); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.CloseConnection(QUIC_PEER_GOING_AWAY, "no reason", + ConnectionCloseBehavior::SILENT_CLOSE); + + EXPECT_FALSE(connection_.connected()); + writer_->SetWriteBlocked(); + } + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PEER_GOING_AWAY)); +} + +TEST_P(QuicConnectionTest, AddToWriteBlockedListIfBlockedOnFlushPackets) { + writer_->SetBatchMode(true); + writer_->BlockOnNextFlush(); + + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(1); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // flusher's destructor will call connection_.FlushPackets, which should add + // the connection to the write blocked list. + } +} + +TEST_P(QuicConnectionTest, NoLimitPacketsPerNack) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + int offset = 0; + // Send packets 1 to 15. + for (int i = 0; i < 15; ++i) { + SendStreamDataToPeer(1, "foo", offset, NO_FIN, nullptr); + offset += 3; + } + + // Ack 15, nack 1-14. + + QuicAckFrame nack = + InitAckFrame({{QuicPacketNumber(15), QuicPacketNumber(16)}}); + + // 14 packets have been NACK'd and lost. + LostPacketVector lost_packets; + for (int i = 1; i < 15; ++i) { + lost_packets.push_back( + LostPacket(QuicPacketNumber(i), kMaxOutgoingPacketSize)); + } + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessAckPacket(&nack); +} + +// Test sending multiple acks from the connection to the session. +TEST_P(QuicConnectionTest, MultipleAcks) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(1); + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 2); + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); // Packet 1 + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &last_packet); // Packet 2 + EXPECT_EQ(QuicPacketNumber(2u), last_packet); + SendAckPacketToPeer(); // Packet 3 + SendStreamDataToPeer(5, "foo", 0, NO_FIN, &last_packet); // Packet 4 + EXPECT_EQ(QuicPacketNumber(4u), last_packet); + SendStreamDataToPeer(1, "foo", 3, NO_FIN, &last_packet); // Packet 5 + EXPECT_EQ(QuicPacketNumber(5u), last_packet); + SendStreamDataToPeer(3, "foo", 3, NO_FIN, &last_packet); // Packet 6 + EXPECT_EQ(QuicPacketNumber(6u), last_packet); + + // Client will ack packets 1, 2, [!3], 4, 5. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame1 = ConstructAckFrame(5, 3); + ProcessAckPacket(&frame1); + + // Now the client implicitly acks 3, and explicitly acks 6. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame2 = InitAckFrame(6); + ProcessAckPacket(&frame2); +} + +TEST_P(QuicConnectionTest, DontLatchUnackedPacket) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(1); + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 2); + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); // Packet 1; + // From now on, we send acks, so the send algorithm won't mark them pending. + SendAckPacketToPeer(); // Packet 2 + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = InitAckFrame(1); + ProcessAckPacket(&frame); + + // Verify that our internal state has least-unacked as 2, because we're still + // waiting for a potential ack for 2. + + EXPECT_EQ(QuicPacketNumber(2u), stop_waiting()->least_unacked); + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + frame = InitAckFrame(2); + ProcessAckPacket(&frame); + EXPECT_EQ(QuicPacketNumber(3u), stop_waiting()->least_unacked); + + // When we send an ack, we make sure our least-unacked makes sense. In this + // case since we're not waiting on an ack for 2 and all packets are acked, we + // set it to 3. + SendAckPacketToPeer(); // Packet 3 + // Least_unacked remains at 3 until another ack is received. + EXPECT_EQ(QuicPacketNumber(3u), stop_waiting()->least_unacked); + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + // Check that the outgoing ack had its packet number as least_unacked. + EXPECT_EQ(QuicPacketNumber(3u), least_unacked()); + } + + // Ack the ack, which updates the rtt and raises the least unacked. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + frame = InitAckFrame(3); + ProcessAckPacket(&frame); + + SendStreamDataToPeer(1, "bar", 3, NO_FIN, nullptr); // Packet 4 + EXPECT_EQ(QuicPacketNumber(4u), stop_waiting()->least_unacked); + SendAckPacketToPeer(); // Packet 5 + if (GetParam().no_stop_waiting) { + // Expect no stop waiting frame is sent. + EXPECT_FALSE(least_unacked().IsInitialized()); + } else { + EXPECT_EQ(QuicPacketNumber(4u), least_unacked()); + } + + // Send two data packets at the end, and ensure if the last one is acked, + // the least unacked is raised above the ack packets. + SendStreamDataToPeer(1, "bar", 6, NO_FIN, nullptr); // Packet 6 + SendStreamDataToPeer(1, "bar", 9, NO_FIN, nullptr); // Packet 7 + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + frame = InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(5)}, + {QuicPacketNumber(7), QuicPacketNumber(8)}}); + ProcessAckPacket(&frame); + + EXPECT_EQ(QuicPacketNumber(6u), stop_waiting()->least_unacked); +} + +TEST_P(QuicConnectionTest, SendHandshakeMessages) { + // Attempt to send a handshake message and have the socket block. + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + BlockOnNextWrite(); + connection_.SendCryptoDataWithString("foo", 0); + // The packet should be serialized, but not queued. + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Switch to the new encrypter. + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + + // Now become writeable and flush the packets. + writer_->SetWritable(); + EXPECT_CALL(visitor_, OnCanWrite()); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + + // Verify that the handshake packet went out with Initial encryption. + EXPECT_NE(0x02020202u, writer_->final_bytes_of_last_packet()); +} + +TEST_P(QuicConnectionTest, DropRetransmitsForInitialPacketAfterForwardSecure) { + connection_.SendCryptoStreamData(); + // Simulate the retransmission alarm firing and the socket blocking. + BlockOnNextWrite(); + clock_.AdvanceTime(DefaultRetransmissionTime()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Go forward secure. + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + notifier_.NeuterUnencryptedData(); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + + EXPECT_EQ(QuicTime::Zero(), connection_.GetRetransmissionAlarm()->deadline()); + // Unblock the socket and ensure that no packets are sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + writer_->SetWritable(); + connection_.OnCanWrite(); +} + +TEST_P(QuicConnectionTest, RetransmitPacketsWithInitialEncryption) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + + connection_.SendCryptoDataWithString("foo", 0); + + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + if (!connection_.version().KnowsWhichDecrypterToUse()) { + writer_->framer()->framer()->SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT), false); + } + + SendStreamDataToPeer(2, "bar", 0, NO_FIN, nullptr); + EXPECT_FALSE(notifier_.HasLostStreamData()); + connection_.MarkZeroRttPacketsForRetransmission(0); + EXPECT_TRUE(notifier_.HasLostStreamData()); +} + +TEST_P(QuicConnectionTest, BufferNonDecryptablePackets) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + if (!connection_.version().KnowsWhichDecrypterToUse()) { + writer_->framer()->framer()->SetDecrypter( + ENCRYPTION_ZERO_RTT, std::make_unique()); + } + + // Process an encrypted packet which can not yet be decrypted which should + // result in the packet being buffered. + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + // Transition to the new encryption state and process another encrypted packet + // which should result in the original packet being processed. + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(2); + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + // Finally, process a third packet and note that we do not reprocess the + // buffered packet. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); +} + +TEST_P(QuicConnectionTest, Buffer100NonDecryptablePacketsThenKeyChange) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(100); + connection_.SetFromConfig(config); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + + // Process an encrypted packet which can not yet be decrypted which should + // result in the packet being buffered. + for (uint64_t i = 1; i <= 100; ++i) { + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + } + + // Transition to the new encryption state and process another encrypted packet + // which should result in the original packets being processed. + EXPECT_FALSE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(100); + if (!connection_.version().KnowsWhichDecrypterToUse()) { + writer_->framer()->framer()->SetDecrypter( + ENCRYPTION_ZERO_RTT, std::make_unique()); + } + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + + // Finally, process a third packet and note that we do not reprocess the + // buffered packet. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(102, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); +} + +TEST_P(QuicConnectionTest, SetRTOAfterWritingToSocket) { + BlockOnNextWrite(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Test that RTO is started once we write to the socket. + writer_->SetWritable(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.OnCanWrite(); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, TestQueued) { + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + BlockOnNextWrite(); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Unblock the writes and actually send. + writer_->SetWritable(); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); +} + +TEST_P(QuicConnectionTest, InitialTimeout) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + + // SetFromConfig sets the initial timeouts before negotiation. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + // Subtract a second from the idle timeout on the client side. + QuicTime default_timeout = + clock_.ApproximateNow() + + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + EXPECT_EQ(default_timeout, connection_.GetTimeoutAlarm()->deadline()); + + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + // Simulate the timeout alarm firing. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1)); + connection_.GetTimeoutAlarm()->Fire(); + + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + + EXPECT_FALSE(connection_.HasPendingAcks()); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, IdleTimeoutAfterFirstSentPacket) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + QuicTime initial_ddl = + clock_.ApproximateNow() + + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + EXPECT_EQ(initial_ddl, connection_.GetTimeoutAlarm()->deadline()); + EXPECT_TRUE(connection_.connected()); + + // Advance the time and send the first packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20)); + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + // This will be the updated deadline for the connection to idle time out. + QuicTime new_ddl = clock_.ApproximateNow() + + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + + // Simulate the timeout alarm firing, the connection should not be closed as + // a new packet has been sent. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + QuicTime::Delta delay = initial_ddl - clock_.ApproximateNow(); + clock_.AdvanceTime(delay); + // Verify the timeout alarm deadline is updated. + EXPECT_TRUE(connection_.connected()); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_EQ(new_ddl, connection_.GetTimeoutAlarm()->deadline()); + + // Simulate the timeout alarm firing again, the connection now should be + // closed. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + clock_.AdvanceTime(new_ddl - clock_.ApproximateNow()); + connection_.GetTimeoutAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + + EXPECT_FALSE(connection_.HasPendingAcks()); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, IdleTimeoutAfterSendTwoPackets) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + QuicTime initial_ddl = + clock_.ApproximateNow() + + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + EXPECT_EQ(initial_ddl, connection_.GetTimeoutAlarm()->deadline()); + EXPECT_TRUE(connection_.connected()); + + // Immediately send the first packet, this is a rare case but test code will + // hit this issue often as MockClock used for tests doesn't move with code + // execution until manually adjusted. + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + + // Advance the time and send the second packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20)); + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(2u), last_packet); + + // Simulate the timeout alarm firing, the connection will be closed. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + clock_.AdvanceTime(initial_ddl - clock_.ApproximateNow()); + connection_.GetTimeoutAlarm()->Fire(); + + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + + EXPECT_FALSE(connection_.HasPendingAcks()); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, HandshakeTimeout) { + // Use a shorter handshake timeout than idle timeout for this test. + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(5); + connection_.SetNetworkTimeouts(timeout, timeout); + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + + QuicTime handshake_timeout = + clock_.ApproximateNow() + timeout - QuicTime::Delta::FromSeconds(1); + EXPECT_EQ(handshake_timeout, connection_.GetTimeoutAlarm()->deadline()); + EXPECT_TRUE(connection_.connected()); + + // Send and ack new data 3 seconds later to lengthen the idle timeout. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(0, connection_.transport_version()), + "GET /", 0, FIN, nullptr); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(3)); + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&frame); + + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_TRUE(connection_.connected()); + + clock_.AdvanceTime(timeout - QuicTime::Delta::FromSeconds(2)); + + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + // Simulate the timeout alarm firing. + connection_.GetTimeoutAlarm()->Fire(); + + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + + EXPECT_FALSE(connection_.HasPendingAcks()); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); + TestConnectionCloseQuicErrorCode(QUIC_HANDSHAKE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, PingAfterSend) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + + // Advance to 5ms, and send a packet to the peer, which will set + // the ping alarm. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(0, connection_.transport_version()), + "GET /", 0, FIN, nullptr); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(15), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now recevie an ACK of the previous packet, which will move the + // ping alarm forward. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + // The ping timer is set slightly less than 15 seconds in the future, because + // of the 1s ping timer alarm granularity. + EXPECT_EQ( + QuicTime::Delta::FromSeconds(15) - QuicTime::Delta::FromMilliseconds(5), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + writer_->Reset(); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(15)); + connection_.GetPingAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + ASSERT_EQ(1u, writer_->ping_frames().size()); + writer_->Reset(); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(false)); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + SendAckPacketToPeer(); + + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, ReducedPingTimeout) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + + // Use a reduced ping timeout for this connection. + connection_.set_keep_alive_ping_timeout(QuicTime::Delta::FromSeconds(10)); + + // Advance to 5ms, and send a packet to the peer, which will set + // the ping alarm. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(0, connection_.transport_version()), + "GET /", 0, FIN, nullptr); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(10), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now recevie an ACK of the previous packet, which will move the + // ping alarm forward. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + // The ping timer is set slightly less than 10 seconds in the future, because + // of the 1s ping timer alarm granularity. + EXPECT_EQ( + QuicTime::Delta::FromSeconds(10) - QuicTime::Delta::FromMilliseconds(5), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + writer_->Reset(); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(10)); + connection_.GetPingAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + ASSERT_EQ(1u, writer_->ping_frames().size()); + writer_->Reset(); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(false)); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + SendAckPacketToPeer(); + + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); +} + +// Tests whether sending an MTU discovery packet to peer successfully causes the +// maximum packet size to increase. +TEST_P(QuicConnectionTest, SendMtuDiscoveryPacket) { + MtuDiscoveryTestInit(); + + // Send an MTU probe. + const size_t new_mtu = kDefaultMaxPacketSize + 100; + QuicByteCount mtu_probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&mtu_probe_size)); + connection_.SendMtuDiscoveryPacket(new_mtu); + EXPECT_EQ(new_mtu, mtu_probe_size); + EXPECT_EQ(QuicPacketNumber(1u), creator_->packet_number()); + + // Send more than MTU worth of data. No acknowledgement was received so far, + // so the MTU should be at its old value. + const std::string data(kDefaultMaxPacketSize + 1, '.'); + QuicByteCount size_before_mtu_change; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(2) + .WillOnce(SaveArg<3>(&size_before_mtu_change)) + .WillOnce(Return()); + connection_.SendStreamDataWithString(3, data, 0, FIN); + EXPECT_EQ(QuicPacketNumber(3u), creator_->packet_number()); + EXPECT_EQ(kDefaultMaxPacketSize, size_before_mtu_change); + + // Acknowledge all packets so far. + QuicAckFrame probe_ack = InitAckFrame(3); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(new_mtu, connection_.max_packet_length()); + + // Send the same data again. Check that it fits into a single packet now. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(3, data, 0, FIN); + EXPECT_EQ(QuicPacketNumber(4u), creator_->packet_number()); +} + +// Verifies that when a MTU probe packet is sent and buffered in a batch writer, +// the writer is flushed immediately. +TEST_P(QuicConnectionTest, BatchWriterFlushedAfterMtuDiscoveryPacket) { + writer_->SetBatchMode(true); + MtuDiscoveryTestInit(); + + // Send an MTU probe. + const size_t target_mtu = kDefaultMaxPacketSize + 100; + QuicByteCount mtu_probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&mtu_probe_size)); + const uint32_t prior_flush_attempts = writer_->flush_attempts(); + connection_.SendMtuDiscoveryPacket(target_mtu); + EXPECT_EQ(target_mtu, mtu_probe_size); + EXPECT_EQ(writer_->flush_attempts(), prior_flush_attempts + 1); +} + +// Tests whether MTU discovery does not happen when it is not explicitly enabled +// by the connection options. +TEST_P(QuicConnectionTest, MtuDiscoveryDisabled) { + MtuDiscoveryTestInit(); + + const QuicPacketCount packets_between_probes_base = 10; + set_packets_between_probes_base(packets_between_probes_base); + + const QuicPacketCount number_of_packets = packets_between_probes_base * 2; + for (QuicPacketCount i = 0; i < number_of_packets; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + EXPECT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + EXPECT_EQ(0u, connection_.mtu_probe_count()); + } +} + +// Tests whether MTU discovery works when all probes are acknowledged on the +// first try. +TEST_P(QuicConnectionTest, MtuDiscoveryEnabled) { + MtuDiscoveryTestInit(); + + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + // Send enough packets so that the next one triggers path MTU discovery. + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the probe. + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + + EXPECT_THAT(probe_size, InRange(connection_.max_packet_length(), + kMtuDiscoveryTargetPacketSizeHigh)); + + const QuicPacketNumber probe_packet_number = + FirstSendingPacketNumber() + packets_between_probes_base; + ASSERT_EQ(probe_packet_number, creator_->packet_number()); + + // Acknowledge all packets sent so far. + QuicAckFrame probe_ack = InitAckFrame(probe_packet_number); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(probe_size, connection_.max_packet_length()); + EXPECT_EQ(0u, connection_.GetBytesInFlight()); + + EXPECT_EQ(1u, connection_.mtu_probe_count()); + + QuicStreamOffset stream_offset = packets_between_probes_base; + QuicByteCount last_probe_size = 0; + for (size_t num_probes = 1; num_probes < kMtuDiscoveryAttempts; + ++num_probes) { + // Send just enough packets without triggering the next probe. + for (QuicPacketCount i = 0; + i < (packets_between_probes_base << num_probes) - 1; ++i) { + SendStreamDataToPeer(3, ".", stream_offset++, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the next probe. + SendStreamDataToPeer(3, "!", stream_offset++, NO_FIN, nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount new_probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&new_probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + EXPECT_THAT(new_probe_size, + InRange(probe_size, kMtuDiscoveryTargetPacketSizeHigh)); + EXPECT_EQ(num_probes + 1, connection_.mtu_probe_count()); + + // Acknowledge all packets sent so far. + QuicAckFrame probe_ack = InitAckFrame(creator_->packet_number()); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(new_probe_size, connection_.max_packet_length()); + EXPECT_EQ(0u, connection_.GetBytesInFlight()); + + last_probe_size = probe_size; + probe_size = new_probe_size; + } + + // The last probe size should be equal to the target. + EXPECT_EQ(probe_size, kMtuDiscoveryTargetPacketSizeHigh); + + writer_->SetShouldWriteFail(); + + // Ignore PACKET_WRITE_ERROR once. + SendStreamDataToPeer(3, "(", stream_offset++, NO_FIN, nullptr); + EXPECT_EQ(last_probe_size, connection_.max_packet_length()); + EXPECT_TRUE(connection_.connected()); + + // Close connection on another PACKET_WRITE_ERROR. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + SendStreamDataToPeer(3, ")", stream_offset++, NO_FIN, nullptr); + EXPECT_EQ(last_probe_size, connection_.max_packet_length()); + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PACKET_WRITE_ERROR)); +} + +// After a successful MTU probe, one and only one write error should be ignored +// if it happened in QuicConnection::FlushPacket. +TEST_P(QuicConnectionTest, + MtuDiscoveryIgnoreOneWriteErrorInFlushAfterSuccessfulProbes) { + MtuDiscoveryTestInit(); + writer_->SetBatchMode(true); + + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + const QuicByteCount original_max_packet_length = + connection_.max_packet_length(); + // Send enough packets so that the next one triggers path MTU discovery. + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the probe. + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + + EXPECT_THAT(probe_size, InRange(connection_.max_packet_length(), + kMtuDiscoveryTargetPacketSizeHigh)); + + const QuicPacketNumber probe_packet_number = + FirstSendingPacketNumber() + packets_between_probes_base; + ASSERT_EQ(probe_packet_number, creator_->packet_number()); + + // Acknowledge all packets sent so far. + QuicAckFrame probe_ack = InitAckFrame(probe_packet_number); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(probe_size, connection_.max_packet_length()); + EXPECT_EQ(0u, connection_.GetBytesInFlight()); + + EXPECT_EQ(1u, connection_.mtu_probe_count()); + + writer_->SetShouldWriteFail(); + + // Ignore PACKET_WRITE_ERROR once. + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // flusher's destructor will call connection_.FlushPackets, which should + // get a WRITE_STATUS_ERROR from the writer and ignore it. + } + EXPECT_EQ(original_max_packet_length, connection_.max_packet_length()); + EXPECT_TRUE(connection_.connected()); + + // Close connection on another PACKET_WRITE_ERROR. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // flusher's destructor will call connection_.FlushPackets, which should + // get a WRITE_STATUS_ERROR from the writer and ignore it. + } + EXPECT_EQ(original_max_packet_length, connection_.max_packet_length()); + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PACKET_WRITE_ERROR)); +} + +// Simulate the case where the first attempt to send a probe is write blocked, +// and after unblock, the second attempt returns a MSG_TOO_BIG error. +TEST_P(QuicConnectionTest, MtuDiscoveryWriteBlocked) { + MtuDiscoveryTestInit(); + + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + // Send enough packets so that the next one triggers path MTU discovery. + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + QuicByteCount original_max_packet_length = connection_.max_packet_length(); + + // Trigger the probe. + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + BlockOnNextWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + connection_.GetMtuDiscoveryAlarm()->Fire(); + EXPECT_EQ(1u, connection_.mtu_probe_count()); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + ASSERT_TRUE(connection_.connected()); + + writer_->SetWritable(); + SimulateNextPacketTooLarge(); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_EQ(original_max_packet_length, connection_.max_packet_length()); + EXPECT_TRUE(connection_.connected()); +} + +// Tests whether MTU discovery works correctly when the probes never get +// acknowledged. +TEST_P(QuicConnectionTest, MtuDiscoveryFailed) { + MtuDiscoveryTestInit(); + + // Lower the number of probes between packets in order to make the test go + // much faster. + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + const QuicTime::Delta rtt = QuicTime::Delta::FromMilliseconds(100); + + EXPECT_EQ(packets_between_probes_base, + QuicConnectionPeer::GetPacketsBetweenMtuProbes(&connection_)); + + // This tests sends more packets than strictly necessary to make sure that if + // the connection was to send more discovery packets than needed, those would + // get caught as well. + const QuicPacketCount number_of_packets = + packets_between_probes_base * (1 << (kMtuDiscoveryAttempts + 1)); + std::vector mtu_discovery_packets; + // Called on many acks. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + for (QuicPacketCount i = 0; i < number_of_packets; i++) { + SendStreamDataToPeer(3, "!", i, NO_FIN, nullptr); + clock_.AdvanceTime(rtt); + + // Receive an ACK, which marks all data packets as received, and all MTU + // discovery packets as missing. + + QuicAckFrame ack; + + if (!mtu_discovery_packets.empty()) { + QuicPacketNumber min_packet = *min_element(mtu_discovery_packets.begin(), + mtu_discovery_packets.end()); + QuicPacketNumber max_packet = *max_element(mtu_discovery_packets.begin(), + mtu_discovery_packets.end()); + ack.packets.AddRange(QuicPacketNumber(1), min_packet); + ack.packets.AddRange(QuicPacketNumber(max_packet + 1), + creator_->packet_number() + 1); + ack.largest_acked = creator_->packet_number(); + + } else { + ack.packets.AddRange(QuicPacketNumber(1), creator_->packet_number() + 1); + ack.largest_acked = creator_->packet_number(); + } + + ProcessAckPacket(&ack); + + // Trigger MTU probe if it would be scheduled now. + if (!connection_.GetMtuDiscoveryAlarm()->IsSet()) { + continue; + } + + // Fire the alarm. The alarm should cause a packet to be sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + // Record the packet number of the MTU discovery packet in order to + // mark it as NACK'd. + mtu_discovery_packets.push_back(creator_->packet_number()); + } + + // Ensure the number of packets between probes grows exponentially by checking + // it against the closed-form expression for the packet number. + ASSERT_EQ(kMtuDiscoveryAttempts, mtu_discovery_packets.size()); + for (uint64_t i = 0; i < kMtuDiscoveryAttempts; i++) { + // 2^0 + 2^1 + 2^2 + ... + 2^n = 2^(n + 1) - 1 + const QuicPacketCount packets_between_probes = + packets_between_probes_base * ((1 << (i + 1)) - 1); + EXPECT_EQ(QuicPacketNumber(packets_between_probes + (i + 1)), + mtu_discovery_packets[i]); + } + + EXPECT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + EXPECT_EQ(kDefaultMaxPacketSize, connection_.max_packet_length()); + EXPECT_EQ(kMtuDiscoveryAttempts, connection_.mtu_probe_count()); +} + +// Probe 3 times, the first one succeeds, then fails, then succeeds again. +TEST_P(QuicConnectionTest, MtuDiscoverySecondProbeFailed) { + MtuDiscoveryTestInit(); + + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + // Send enough packets so that the next one triggers path MTU discovery. + QuicStreamOffset stream_offset = 0; + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", stream_offset++, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the probe. + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + EXPECT_THAT(probe_size, InRange(connection_.max_packet_length(), + kMtuDiscoveryTargetPacketSizeHigh)); + + const QuicPacketNumber probe_packet_number = + FirstSendingPacketNumber() + packets_between_probes_base; + ASSERT_EQ(probe_packet_number, creator_->packet_number()); + + // Acknowledge all packets sent so far. + QuicAckFrame first_ack = InitAckFrame(probe_packet_number); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + ProcessAckPacket(&first_ack); + EXPECT_EQ(probe_size, connection_.max_packet_length()); + EXPECT_EQ(0u, connection_.GetBytesInFlight()); + + EXPECT_EQ(1u, connection_.mtu_probe_count()); + + // Send just enough packets without triggering the second probe. + for (QuicPacketCount i = 0; i < (packets_between_probes_base << 1) - 1; ++i) { + SendStreamDataToPeer(3, ".", stream_offset++, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the second probe. + SendStreamDataToPeer(3, "!", stream_offset++, NO_FIN, nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount second_probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&second_probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + EXPECT_THAT(second_probe_size, + InRange(probe_size, kMtuDiscoveryTargetPacketSizeHigh)); + EXPECT_EQ(2u, connection_.mtu_probe_count()); + + // Acknowledge all packets sent so far, except the second probe. + QuicPacketNumber second_probe_packet_number = creator_->packet_number(); + QuicAckFrame second_ack = InitAckFrame(second_probe_packet_number - 1); + ProcessAckPacket(&first_ack); + EXPECT_EQ(probe_size, connection_.max_packet_length()); + + // Send just enough packets without triggering the third probe. + for (QuicPacketCount i = 0; i < (packets_between_probes_base << 2) - 1; ++i) { + SendStreamDataToPeer(3, "@", stream_offset++, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the third probe. + SendStreamDataToPeer(3, "#", stream_offset++, NO_FIN, nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount third_probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&third_probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + EXPECT_THAT(third_probe_size, InRange(probe_size, second_probe_size)); + EXPECT_EQ(3u, connection_.mtu_probe_count()); + + // Acknowledge all packets sent so far, except the second probe. + QuicAckFrame third_ack = + ConstructAckFrame(creator_->packet_number(), second_probe_packet_number); + ProcessAckPacket(&third_ack); + EXPECT_EQ(third_probe_size, connection_.max_packet_length()); + + SendStreamDataToPeer(3, "$", stream_offset++, NO_FIN, nullptr); + EXPECT_TRUE(connection_.PathMtuReductionDetectionInProgress()); + + if (connection_.PathDegradingDetectionInProgress() && + QuicConnectionPeer::GetPathDegradingDeadline(&connection_) < + QuicConnectionPeer::GetPathMtuReductionDetectionDeadline( + &connection_)) { + // Fire path degrading alarm first. + connection_.PathDegradingTimeout(); + } + + // Verify the max packet size has not reduced. + EXPECT_EQ(third_probe_size, connection_.max_packet_length()); + + // Fire alarm to get path mtu reduction callback called. + EXPECT_TRUE(connection_.PathMtuReductionDetectionInProgress()); + connection_.GetBlackholeDetectorAlarm()->Fire(); + + // Verify the max packet size has reduced to the previous value. + EXPECT_EQ(probe_size, connection_.max_packet_length()); +} + +// Tests whether MTU discovery works when the writer has a limit on how large a +// packet can be. +TEST_P(QuicConnectionTest, MtuDiscoveryWriterLimited) { + MtuDiscoveryTestInit(); + + const QuicByteCount mtu_limit = kMtuDiscoveryTargetPacketSizeHigh - 1; + writer_->set_max_packet_size(mtu_limit); + + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + // Send enough packets so that the next one triggers path MTU discovery. + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the probe. + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + + EXPECT_THAT(probe_size, InRange(connection_.max_packet_length(), mtu_limit)); + + const QuicPacketNumber probe_sequence_number = + FirstSendingPacketNumber() + packets_between_probes_base; + ASSERT_EQ(probe_sequence_number, creator_->packet_number()); + + // Acknowledge all packets sent so far. + QuicAckFrame probe_ack = InitAckFrame(probe_sequence_number); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(probe_size, connection_.max_packet_length()); + EXPECT_EQ(0u, connection_.GetBytesInFlight()); + + EXPECT_EQ(1u, connection_.mtu_probe_count()); + + QuicStreamOffset stream_offset = packets_between_probes_base; + for (size_t num_probes = 1; num_probes < kMtuDiscoveryAttempts; + ++num_probes) { + // Send just enough packets without triggering the next probe. + for (QuicPacketCount i = 0; + i < (packets_between_probes_base << num_probes) - 1; ++i) { + SendStreamDataToPeer(3, ".", stream_offset++, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the next probe. + SendStreamDataToPeer(3, "!", stream_offset++, NO_FIN, nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + QuicByteCount new_probe_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(SaveArg<3>(&new_probe_size)); + connection_.GetMtuDiscoveryAlarm()->Fire(); + EXPECT_THAT(new_probe_size, InRange(probe_size, mtu_limit)); + EXPECT_EQ(num_probes + 1, connection_.mtu_probe_count()); + + // Acknowledge all packets sent so far. + QuicAckFrame probe_ack = InitAckFrame(creator_->packet_number()); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(new_probe_size, connection_.max_packet_length()); + EXPECT_EQ(0u, connection_.GetBytesInFlight()); + + probe_size = new_probe_size; + } + + // The last probe size should be equal to the target. + EXPECT_EQ(probe_size, mtu_limit); +} + +// Tests whether MTU discovery works when the writer returns an error despite +// advertising higher packet length. +TEST_P(QuicConnectionTest, MtuDiscoveryWriterFailed) { + MtuDiscoveryTestInit(); + + const QuicByteCount mtu_limit = kMtuDiscoveryTargetPacketSizeHigh - 1; + const QuicByteCount initial_mtu = connection_.max_packet_length(); + EXPECT_LT(initial_mtu, mtu_limit); + writer_->set_max_packet_size(mtu_limit); + + const QuicPacketCount packets_between_probes_base = 5; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + // Send enough packets so that the next one triggers path MTU discovery. + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Trigger the probe. + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + ASSERT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + writer_->SimulateNextPacketTooLarge(); + connection_.GetMtuDiscoveryAlarm()->Fire(); + ASSERT_TRUE(connection_.connected()); + + // Send more data. + QuicPacketNumber probe_number = creator_->packet_number(); + QuicPacketCount extra_packets = packets_between_probes_base * 3; + for (QuicPacketCount i = 0; i < extra_packets; i++) { + connection_.EnsureWritableAndSendStreamData5(); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + // Acknowledge all packets sent so far, except for the lost probe. + QuicAckFrame probe_ack = + ConstructAckFrame(creator_->packet_number(), probe_number); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&probe_ack); + EXPECT_EQ(initial_mtu, connection_.max_packet_length()); + + // Send more packets, and ensure that none of them sets the alarm. + for (QuicPacketCount i = 0; i < 4 * packets_between_probes_base; i++) { + connection_.EnsureWritableAndSendStreamData5(); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + EXPECT_EQ(initial_mtu, connection_.max_packet_length()); + EXPECT_EQ(1u, connection_.mtu_probe_count()); +} + +TEST_P(QuicConnectionTest, NoMtuDiscoveryAfterConnectionClosed) { + MtuDiscoveryTestInit(); + + const QuicPacketCount packets_between_probes_base = 10; + set_packets_between_probes_base(packets_between_probes_base); + + connection_.EnablePathMtuDiscovery(send_algorithm_); + + // Send enough packets so that the next one triggers path MTU discovery. + for (QuicPacketCount i = 0; i < packets_between_probes_base - 1; i++) { + SendStreamDataToPeer(3, ".", i, NO_FIN, nullptr); + ASSERT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + } + + SendStreamDataToPeer(3, "!", packets_between_probes_base - 1, NO_FIN, + nullptr); + EXPECT_TRUE(connection_.GetMtuDiscoveryAlarm()->IsSet()); + + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection(QUIC_PEER_GOING_AWAY, "no reason", + ConnectionCloseBehavior::SILENT_CLOSE); + EXPECT_FALSE(connection_.GetMtuDiscoveryAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, TimeoutAfterSendDuringHandshake) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + + const QuicTime::Delta initial_idle_timeout = + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + const QuicTime::Delta five_ms = QuicTime::Delta::FromMilliseconds(5); + QuicTime default_timeout = clock_.ApproximateNow() + initial_idle_timeout; + + // When we send a packet, the timeout will change to 5ms + + // kInitialIdleTimeoutSecs. + clock_.AdvanceTime(five_ms); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // Now send more data. This will not move the timeout because + // no data has been received since the previous write. + clock_.AdvanceTime(five_ms); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 3, FIN, nullptr); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // The original alarm will fire. We should not time out because we had a + // network event at t=5ms. The alarm will reregister. + clock_.AdvanceTime(initial_idle_timeout - five_ms - five_ms); + EXPECT_EQ(default_timeout, clock_.ApproximateNow()); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // This time, we should time out. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + clock_.AdvanceTime(five_ms); + EXPECT_EQ(default_timeout + five_ms, clock_.ApproximateNow()); + connection_.GetTimeoutAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, TimeoutAfterSendAfterHandshake) { + // When the idle timeout fires, verify that by default we do not send any + // connection close packets. + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + + // Create a handshake message that also enables silent close. + CryptoHandshakeMessage msg; + std::string error_details; + QuicConfig client_config; + client_config.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + client_config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + client_config.SetIdleNetworkTimeout( + QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs)); + client_config.ToHandshakeMessage(&msg, connection_.transport_version()); + const QuicErrorCode error = + config.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + connection_.SetFromConfig(config); + QuicConnectionPeer::DisableBandwidthUpdate(&connection_); + + const QuicTime::Delta default_idle_timeout = + QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs - 1); + const QuicTime::Delta five_ms = QuicTime::Delta::FromMilliseconds(5); + QuicTime default_timeout = clock_.ApproximateNow() + default_idle_timeout; + + // When we send a packet, the timeout will change to 5ms + + // kInitialIdleTimeoutSecs. + clock_.AdvanceTime(five_ms); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // Now send more data. This will not move the timeout because + // no data has been received since the previous write. + clock_.AdvanceTime(five_ms); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 3, FIN, nullptr); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // The original alarm will fire. We should not time out because we had a + // network event at t=5ms. The alarm will reregister. + clock_.AdvanceTime(default_idle_timeout - five_ms - five_ms); + EXPECT_EQ(default_timeout, clock_.ApproximateNow()); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // This time, we should time out. + // This results in a SILENT_CLOSE, so the writer will not be invoked + // and will not save the frame. Grab the frame from OnConnectionClosed + // directly. + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + + clock_.AdvanceTime(five_ms); + EXPECT_EQ(default_timeout + five_ms, clock_.ApproximateNow()); + connection_.GetTimeoutAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_NETWORK_IDLE_TIMEOUT)); +} + +TEST_P(QuicConnectionTest, TimeoutAfterSendSilentCloseWithOpenStreams) { + // Same test as above, but having open streams causes a connection close + // to be sent. + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + + // Create a handshake message that also enables silent close. + CryptoHandshakeMessage msg; + std::string error_details; + QuicConfig client_config; + client_config.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + client_config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + client_config.SetIdleNetworkTimeout( + QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs)); + client_config.ToHandshakeMessage(&msg, connection_.transport_version()); + const QuicErrorCode error = + config.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + connection_.SetFromConfig(config); + QuicConnectionPeer::DisableBandwidthUpdate(&connection_); + + const QuicTime::Delta default_idle_timeout = + QuicTime::Delta::FromSeconds(kMaximumIdleTimeoutSecs - 1); + const QuicTime::Delta five_ms = QuicTime::Delta::FromMilliseconds(5); + QuicTime default_timeout = clock_.ApproximateNow() + default_idle_timeout; + + // When we send a packet, the timeout will change to 5ms + + // kInitialIdleTimeoutSecs. + clock_.AdvanceTime(five_ms); + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // Indicate streams are still open. + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + if (GetQuicReloadableFlag(quic_add_stream_info_to_idle_close_detail)) { + EXPECT_CALL(visitor_, GetStreamsInfoForLogging()).WillOnce(Return("")); + } + + // This time, we should time out and send a connection close due to the TLP. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + clock_.AdvanceTime(connection_.GetTimeoutAlarm()->deadline() - + clock_.ApproximateNow() + five_ms); + connection_.GetTimeoutAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, TimeoutAfterReceive) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + + const QuicTime::Delta initial_idle_timeout = + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + const QuicTime::Delta five_ms = QuicTime::Delta::FromMilliseconds(5); + QuicTime default_timeout = clock_.ApproximateNow() + initial_idle_timeout; + + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, NO_FIN); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 3, NO_FIN); + + EXPECT_EQ(default_timeout, connection_.GetTimeoutAlarm()->deadline()); + clock_.AdvanceTime(five_ms); + + // When we receive a packet, the timeout will change to 5ms + + // kInitialIdleTimeoutSecs. + QuicAckFrame ack = InitAckFrame(2); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&ack); + + // The original alarm will fire. We should not time out because we had a + // network event at t=5ms. The alarm will reregister. + clock_.AdvanceTime(initial_idle_timeout - five_ms); + EXPECT_EQ(default_timeout, clock_.ApproximateNow()); + EXPECT_TRUE(connection_.connected()); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // This time, we should time out. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + clock_.AdvanceTime(five_ms); + EXPECT_EQ(default_timeout + five_ms, clock_.ApproximateNow()); + connection_.GetTimeoutAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, TimeoutAfterReceiveNotSendWhenUnacked) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + + const QuicTime::Delta initial_idle_timeout = + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + connection_.SetNetworkTimeouts( + QuicTime::Delta::Infinite(), + initial_idle_timeout + QuicTime::Delta::FromSeconds(1)); + QuicConnectionPeer::DisableBandwidthUpdate(&connection_); + const QuicTime::Delta five_ms = QuicTime::Delta::FromMilliseconds(5); + QuicTime default_timeout = clock_.ApproximateNow() + initial_idle_timeout; + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, NO_FIN); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 3, NO_FIN); + + EXPECT_EQ(default_timeout, connection_.GetTimeoutAlarm()->deadline()); + + clock_.AdvanceTime(five_ms); + + // When we receive a packet, the timeout will change to 5ms + + // kInitialIdleTimeoutSecs. + QuicAckFrame ack = InitAckFrame(2); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&ack); + + // The original alarm will fire. We should not time out because we had a + // network event at t=5ms. The alarm will reregister. + clock_.AdvanceTime(initial_idle_timeout - five_ms); + EXPECT_EQ(default_timeout, clock_.ApproximateNow()); + EXPECT_TRUE(connection_.connected()); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_EQ(default_timeout + five_ms, + connection_.GetTimeoutAlarm()->deadline()); + + // Now, send packets while advancing the time and verify that the connection + // eventually times out. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + for (int i = 0; i < 100 && connection_.connected(); ++i) { + QUIC_LOG(INFO) << "sending data packet"; + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), + "foo", 0, NO_FIN); + connection_.GetTimeoutAlarm()->Fire(); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + EXPECT_FALSE(connection_.connected()); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + TestConnectionCloseQuicErrorCode(QUIC_NETWORK_IDLE_TIMEOUT); +} + +TEST_P(QuicConnectionTest, SendScheduler) { + // Test that if we send a packet without delay, it is not queued. + QuicFramerPeer::SetPerspective(&peer_framer_, Perspective::IS_CLIENT); + std::unique_ptr packet = + ConstructDataPacket(1, !kHasStopWaiting, ENCRYPTION_INITIAL); + QuicPacketCreatorPeer::SetPacketNumber(creator_, 1); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.SendPacket(ENCRYPTION_INITIAL, 1, std::move(packet), + HAS_RETRANSMITTABLE_DATA, false, false); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); +} + +TEST_P(QuicConnectionTest, FailToSendFirstPacket) { + // Test that the connection does not crash when it fails to send the first + // packet at which point self_address_ might be uninitialized. + QuicFramerPeer::SetPerspective(&peer_framer_, Perspective::IS_CLIENT); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(1); + std::unique_ptr packet = + ConstructDataPacket(1, !kHasStopWaiting, ENCRYPTION_INITIAL); + QuicPacketCreatorPeer::SetPacketNumber(creator_, 1); + writer_->SetShouldWriteFail(); + connection_.SendPacket(ENCRYPTION_INITIAL, 1, std::move(packet), + HAS_RETRANSMITTABLE_DATA, false, false); +} + +TEST_P(QuicConnectionTest, SendSchedulerEAGAIN) { + QuicFramerPeer::SetPerspective(&peer_framer_, Perspective::IS_CLIENT); + std::unique_ptr packet = + ConstructDataPacket(1, !kHasStopWaiting, ENCRYPTION_INITIAL); + QuicPacketCreatorPeer::SetPacketNumber(creator_, 1); + BlockOnNextWrite(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(2u), _, _)) + .Times(0); + connection_.SendPacket(ENCRYPTION_INITIAL, 1, std::move(packet), + HAS_RETRANSMITTABLE_DATA, false, false); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); +} + +TEST_P(QuicConnectionTest, TestQueueLimitsOnSendStreamData) { + // Queue the first packet. + size_t payload_length = connection_.max_packet_length(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillOnce(testing::Return(false)); + const std::string payload(payload_length, 'a'); + QuicStreamId first_bidi_stream_id(QuicUtils::GetFirstBidirectionalStreamId( + connection_.version().transport_version, Perspective::IS_CLIENT)); + EXPECT_EQ(0u, connection_ + .SendStreamDataWithString(first_bidi_stream_id, payload, 0, + NO_FIN) + .bytes_consumed); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); +} + +TEST_P(QuicConnectionTest, SendingThreePackets) { + // Make the payload twice the size of the packet, so 3 packets are written. + size_t total_payload_length = 2 * connection_.max_packet_length(); + const std::string payload(total_payload_length, 'a'); + QuicStreamId first_bidi_stream_id(QuicUtils::GetFirstBidirectionalStreamId( + connection_.version().transport_version, Perspective::IS_CLIENT)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(3); + EXPECT_EQ(payload.size(), connection_ + .SendStreamDataWithString(first_bidi_stream_id, + payload, 0, NO_FIN) + .bytes_consumed); +} + +TEST_P(QuicConnectionTest, LoopThroughSendingPacketsWithTruncation) { + set_perspective(Perspective::IS_SERVER); + if (!GetParam().version.HasIetfInvariantHeader()) { + // For IETF QUIC, encryption level will be switched to FORWARD_SECURE in + // SendStreamDataWithString. + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + } + // Set up a larger payload than will fit in one packet. + const std::string payload(connection_.max_packet_length(), 'a'); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(AnyNumber()); + + // Now send some packets with no truncation. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + EXPECT_EQ(payload.size(), + connection_.SendStreamDataWithString(3, payload, 0, NO_FIN) + .bytes_consumed); + // Track the size of the second packet here. The overhead will be the largest + // we see in this test, due to the non-truncated connection id. + size_t non_truncated_packet_size = writer_->last_packet_size(); + + // Change to a 0 byte connection id. + QuicConfig config; + QuicConfigPeer::SetReceivedBytesForConnectionId(&config, 0); + connection_.SetFromConfig(config); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + EXPECT_EQ(payload.size(), + connection_.SendStreamDataWithString(3, payload, 1350, NO_FIN) + .bytes_consumed); + if (connection_.version().HasIetfInvariantHeader()) { + // Short header packets sent from server omit connection ID already, and + // stream offset size increases from 0 to 2. + EXPECT_EQ(non_truncated_packet_size, writer_->last_packet_size() - 2); + } else { + // Just like above, we save 8 bytes on payload, and 8 on truncation. -2 + // because stream offset size is 2 instead of 0. + EXPECT_EQ(non_truncated_packet_size, + writer_->last_packet_size() + 8 * 2 - 2); + } +} + +TEST_P(QuicConnectionTest, SendDelayedAck) { + QuicTime ack_time = clock_.ApproximateNow() + DefaultDelayedAckTime(); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.HasPendingAcks()); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + // Process a packet from the non-crypto stream. + frame1_.stream_id = 3; + + // The same as ProcessPacket(1) except that ENCRYPTION_ZERO_RTT is used + // instead of ENCRYPTION_INITIAL. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + // Check if delayed ack timer is running for the expected interval. + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); + // Simulate delayed ack alarm firing. + clock_.AdvanceTime(DefaultDelayedAckTime()); + connection_.GetAckAlarm()->Fire(); + // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); + if (GetParam().no_stop_waiting) { + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, SendDelayedAckDecimation) { + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()).Times(AnyNumber()); + + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + // The ack time should be based on min_rtt/4, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(kMinRttMs / 4); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.HasPendingAcks()); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + // Process a packet from the non-crypto stream. + frame1_.stream_id = 3; + + // Process all the initial packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (unsigned int i = 0; i < kFirstDecimatedPacket - 1; ++i) { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1 + i, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + } + EXPECT_FALSE(connection_.HasPendingAcks()); + // The same as ProcessPacket(1) except that ENCRYPTION_ZERO_RTT is used + // instead of ENCRYPTION_INITIAL. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(kFirstDecimatedPacket, !kHasStopWaiting, + ENCRYPTION_ZERO_RTT); + + // Check if delayed ack timer is running for the expected interval. + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); + + // The 10th received packet causes an ack to be sent. + for (int i = 0; i < 9; ++i) { + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(kFirstDecimatedPacket + 1 + i, !kHasStopWaiting, + ENCRYPTION_ZERO_RTT); + } + // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); + if (GetParam().no_stop_waiting) { + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, SendDelayedAckDecimationUnlimitedAggregation) { + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + // No limit on the number of packets received before sending an ack. + connection_options.push_back(kAKDU); + config.SetConnectionOptionsToSend(connection_options); + connection_.SetFromConfig(config); + + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + // The ack time should be based on min_rtt/4, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(kMinRttMs / 4); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.HasPendingAcks()); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + // Process a packet from the non-crypto stream. + frame1_.stream_id = 3; + + // Process all the initial packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (unsigned int i = 0; i < kFirstDecimatedPacket - 1; ++i) { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1 + i, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + } + EXPECT_FALSE(connection_.HasPendingAcks()); + // The same as ProcessPacket(1) except that ENCRYPTION_ZERO_RTT is used + // instead of ENCRYPTION_INITIAL. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(kFirstDecimatedPacket, !kHasStopWaiting, + ENCRYPTION_ZERO_RTT); + + // Check if delayed ack timer is running for the expected interval. + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); + + // 18 packets will not cause an ack to be sent. 19 will because when + // stop waiting frames are in use, we ack every 20 packets no matter what. + for (int i = 0; i < 18; ++i) { + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(kFirstDecimatedPacket + 1 + i, !kHasStopWaiting, + ENCRYPTION_ZERO_RTT); + } + // The delayed ack timer should still be set to the expected deadline. + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, SendDelayedAckDecimationEighthRtt) { + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()).Times(AnyNumber()); + QuicConnectionPeer::SetAckDecimationDelay(&connection_, 0.125); + + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + // The ack time should be based on min_rtt/8, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(kMinRttMs / 8); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.HasPendingAcks()); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + // Process a packet from the non-crypto stream. + frame1_.stream_id = 3; + + // Process all the initial packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (unsigned int i = 0; i < kFirstDecimatedPacket - 1; ++i) { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1 + i, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + } + EXPECT_FALSE(connection_.HasPendingAcks()); + // The same as ProcessPacket(1) except that ENCRYPTION_ZERO_RTT is used + // instead of ENCRYPTION_INITIAL. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(kFirstDecimatedPacket, !kHasStopWaiting, + ENCRYPTION_ZERO_RTT); + + // Check if delayed ack timer is running for the expected interval. + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); + + // The 10th received packet causes an ack to be sent. + for (int i = 0; i < 9; ++i) { + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(kFirstDecimatedPacket + 1 + i, !kHasStopWaiting, + ENCRYPTION_ZERO_RTT); + } + // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); + if (GetParam().no_stop_waiting) { + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, SendDelayedAckOnHandshakeConfirmed) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); + // Check that ack is sent and that delayed ack alarm is set. + EXPECT_TRUE(connection_.HasPendingAcks()); + QuicTime ack_time = clock_.ApproximateNow() + DefaultDelayedAckTime(); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); + + // Completing the handshake as the server does nothing. + QuicConnectionPeer::SetPerspective(&connection_, Perspective::IS_SERVER); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(ack_time, connection_.GetAckAlarm()->deadline()); + + // Complete the handshake as the client decreases the delayed ack time to 0ms. + QuicConnectionPeer::SetPerspective(&connection_, Perspective::IS_CLIENT); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.HasPendingAcks()); + if (connection_.SupportsMultiplePacketNumberSpaces()) { + EXPECT_EQ(clock_.ApproximateNow() + DefaultDelayedAckTime(), + connection_.GetAckAlarm()->deadline()); + } else { + EXPECT_EQ(clock_.ApproximateNow(), connection_.GetAckAlarm()->deadline()); + } +} + +TEST_P(QuicConnectionTest, SendDelayedAckOnSecondPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); + ProcessPacket(2); + // Check that ack is sent and that delayed ack alarm is reset. + size_t padding_frame_count = writer_->padding_frames().size(); + if (GetParam().no_stop_waiting) { + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, NoAckOnOldNacks) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessPacket(2); + size_t frames_per_ack = GetParam().no_stop_waiting ? 1 : 2; + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessPacket(3); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + frames_per_ack, writer_->frame_count()); + EXPECT_FALSE(writer_->ack_frames().empty()); + writer_->Reset(); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessPacket(4); + EXPECT_EQ(0u, writer_->frame_count()); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessPacket(5); + padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + frames_per_ack, writer_->frame_count()); + EXPECT_FALSE(writer_->ack_frames().empty()); + writer_->Reset(); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + // Now only set the timer on the 6th packet, instead of sending another ack. + ProcessPacket(6); + padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count, writer_->frame_count()); + EXPECT_TRUE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, SendDelayedAckOnOutgoingPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + ProcessDataPacket(1); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, NO_FIN); + // Check that ack is bundled with outgoing data and that delayed ack + // alarm is reset. + if (GetParam().no_stop_waiting) { + EXPECT_EQ(2u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, SendDelayedAckOnOutgoingCryptoPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0); + // Check that ack is bundled with outgoing crypto data. + if (GetParam().no_stop_waiting) { + EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(4u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, BlockAndBufferOnFirstCHLOPacketOfTwo) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); + BlockOnNextWrite(); + writer_->set_is_write_blocked_data_buffered(true); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + } else { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + } + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_TRUE(writer_->IsWriteBlocked()); + EXPECT_FALSE(connection_.HasQueuedData()); + connection_.SendCryptoDataWithString("bar", 3); + EXPECT_TRUE(writer_->IsWriteBlocked()); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + // CRYPTO frames are not flushed when writer is blocked. + EXPECT_FALSE(connection_.HasQueuedData()); + } else { + EXPECT_TRUE(connection_.HasQueuedData()); + } +} + +TEST_P(QuicConnectionTest, BundleAckForSecondCHLO) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.HasPendingAcks()); + EXPECT_CALL(visitor_, OnCanWrite()) + .WillOnce(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendCryptoStreamData))); + // Process a packet from the crypto stream, which is frame1_'s default. + // Receiving the CHLO as packet 2 first will cause the connection to + // immediately send an ack, due to the packet gap. + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + // Check that ack is sent and that delayed ack alarm is reset. + if (GetParam().no_stop_waiting) { + EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(4u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + if (!QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_EQ(1u, writer_->stream_frames().size()); + } else { + EXPECT_EQ(1u, writer_->crypto_frames().size()); + } + EXPECT_EQ(1u, writer_->padding_frames().size()); + ASSERT_FALSE(writer_->ack_frames().empty()); + EXPECT_EQ(QuicPacketNumber(2u), LargestAcked(writer_->ack_frames().front())); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, BundleAckForSecondCHLOTwoPacketReject) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.HasPendingAcks()); + + // Process two packets from the crypto stream, which is frame1_'s default, + // simulating a 2 packet reject. + { + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + // Send the new CHLO when the REJ is processed. + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .WillOnce(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendCryptoStreamData))); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendCryptoStreamData))); + } + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + } + // Check that ack is sent and that delayed ack alarm is reset. + if (GetParam().no_stop_waiting) { + EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_TRUE(writer_->stop_waiting_frames().empty()); + } else { + EXPECT_EQ(4u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + if (!QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_EQ(1u, writer_->stream_frames().size()); + } else { + EXPECT_EQ(1u, writer_->crypto_frames().size()); + } + EXPECT_EQ(1u, writer_->padding_frames().size()); + ASSERT_FALSE(writer_->ack_frames().empty()); + EXPECT_EQ(QuicPacketNumber(2u), LargestAcked(writer_->ack_frames().front())); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, BundleAckWithDataOnIncomingAck) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, NO_FIN); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 3, NO_FIN); + // Ack the second packet, which will retransmit the first packet. + QuicAckFrame ack = ConstructAckFrame(2, 1); + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(1), kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&ack); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->stream_frames().size()); + writer_->Reset(); + + // Now ack the retransmission, which will both raise the high water mark + // and see if there is more data to send. + ack = ConstructAckFrame(3, 1); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(&ack); + + // Check that no packet is sent and the ack alarm isn't set. + EXPECT_EQ(0u, writer_->frame_count()); + EXPECT_FALSE(connection_.HasPendingAcks()); + writer_->Reset(); + + // Send the same ack, but send both data and an ack together. + ack = ConstructAckFrame(3, 1); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(visitor_, OnCanWrite()) + .WillOnce(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::EnsureWritableAndSendStreamData5))); + ProcessAckPacket(&ack); + + // Check that ack is bundled with outgoing data and the delayed ack + // alarm is reset. + if (GetParam().no_stop_waiting) { + // Do not ACK acks. + EXPECT_EQ(1u, writer_->frame_count()); + } else { + EXPECT_EQ(3u, writer_->frame_count()); + EXPECT_FALSE(writer_->stop_waiting_frames().empty()); + } + if (GetParam().no_stop_waiting) { + EXPECT_TRUE(writer_->ack_frames().empty()); + } else { + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_EQ(QuicPacketNumber(3u), + LargestAcked(writer_->ack_frames().front())); + } + EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, NoAckSentForClose) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_PEER)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessClosePacket(2); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PEER_GOING_AWAY)); +} + +TEST_P(QuicConnectionTest, SendWhenDisconnected) { + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + connection_.CloseConnection(QUIC_PEER_GOING_AWAY, "no reason", + ConnectionCloseBehavior::SILENT_CLOSE); + EXPECT_FALSE(connection_.connected()); + EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)); + EXPECT_EQ(DISCARD, connection_.GetSerializedPacketFate( + /*is_mtu_discovery=*/false, ENCRYPTION_INITIAL)); +} + +TEST_P(QuicConnectionTest, SendConnectivityProbingWhenDisconnected) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration()) { + return; + } + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + connection_.CloseConnection(QUIC_PEER_GOING_AWAY, "no reason", + ConnectionCloseBehavior::SILENT_CLOSE); + EXPECT_FALSE(connection_.connected()); + EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(1), _, _)) + .Times(0); + + EXPECT_QUIC_BUG(connection_.SendConnectivityProbingPacket( + writer_.get(), connection_.peer_address()), + "Not sending connectivity probing packet as connection is " + "disconnected."); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PEER_GOING_AWAY)); +} + +TEST_P(QuicConnectionTest, WriteBlockedAfterClientSendsConnectivityProbe) { + PathProbeTestInit(Perspective::IS_CLIENT); + TestPacketWriter probing_writer(version(), &clock_, Perspective::IS_CLIENT); + // Block next write so that sending connectivity probe will encounter a + // blocked write when send a connectivity probe to the peer. + probing_writer.BlockOnNextWrite(); + // Connection will not be marked as write blocked as connectivity probe only + // affects the probing_writer which is not the default. + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(0); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(1), _, _)) + .Times(1); + connection_.SendConnectivityProbingPacket(&probing_writer, + connection_.peer_address()); +} + +TEST_P(QuicConnectionTest, WriterBlockedAfterServerSendsConnectivityProbe) { + PathProbeTestInit(Perspective::IS_SERVER); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + + // Block next write so that sending connectivity probe will encounter a + // blocked write when send a connectivity probe to the peer. + writer_->BlockOnNextWrite(); + // Connection will be marked as write blocked as server uses the default + // writer to send connectivity probes. + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(1); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(1), _, _)) + .Times(1); + if (VersionHasIetfQuicFrames(GetParam().version.transport_version)) { + QuicPathFrameBuffer payload{ + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xfe}}; + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendPathChallenge( + payload, connection_.self_address(), connection_.peer_address(), + connection_.effective_peer_address(), writer_.get()); + } else { + connection_.SendConnectivityProbingPacket(writer_.get(), + connection_.peer_address()); + } +} + +TEST_P(QuicConnectionTest, WriterErrorWhenClientSendsConnectivityProbe) { + PathProbeTestInit(Perspective::IS_CLIENT); + TestPacketWriter probing_writer(version(), &clock_, Perspective::IS_CLIENT); + probing_writer.SetShouldWriteFail(); + + // Connection should not be closed if a connectivity probe is failed to be + // sent. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(1), _, _)) + .Times(0); + connection_.SendConnectivityProbingPacket(&probing_writer, + connection_.peer_address()); +} + +TEST_P(QuicConnectionTest, WriterErrorWhenServerSendsConnectivityProbe) { + PathProbeTestInit(Perspective::IS_SERVER); + + writer_->SetShouldWriteFail(); + // Connection should not be closed if a connectivity probe is failed to be + // sent. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(1), _, _)) + .Times(0); + connection_.SendConnectivityProbingPacket(writer_.get(), + connection_.peer_address()); +} + +TEST_P(QuicConnectionTest, PublicReset) { + if (GetParam().version.HasIetfInvariantHeader()) { + return; + } + QuicPublicResetPacket header; + // Public reset packet in only built by server. + header.connection_id = connection_id_; + std::unique_ptr packet( + framer_.BuildPublicResetPacket(header)); + std::unique_ptr received( + ConstructReceivedPacket(*packet, QuicTime::Zero())); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_PEER)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *received); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PUBLIC_RESET)); +} + +TEST_P(QuicConnectionTest, IetfStatelessReset) { + if (!GetParam().version.HasIetfInvariantHeader()) { + return; + } + QuicConfig config; + QuicConfigPeer::SetReceivedStatelessResetToken(&config, + kTestStatelessResetToken); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + std::unique_ptr packet( + QuicFramer::BuildIetfStatelessResetPacket(connection_id_, + /*received_packet_length=*/100, + kTestStatelessResetToken)); + std::unique_ptr received( + ConstructReceivedPacket(*packet, QuicTime::Zero())); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_PEER)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *received); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PUBLIC_RESET)); +} + +TEST_P(QuicConnectionTest, GoAway) { + if (VersionHasIetfQuicFrames(GetParam().version.transport_version)) { + // GoAway is not available in version 99. + return; + } + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + QuicGoAwayFrame* goaway = new QuicGoAwayFrame(); + goaway->last_good_stream_id = 1; + goaway->error_code = QUIC_PEER_GOING_AWAY; + goaway->reason_phrase = "Going away."; + EXPECT_CALL(visitor_, OnGoAway(_)); + ProcessGoAwayPacket(goaway); +} + +TEST_P(QuicConnectionTest, WindowUpdate) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + QuicWindowUpdateFrame window_update; + window_update.stream_id = 3; + window_update.max_data = 1234; + EXPECT_CALL(visitor_, OnWindowUpdateFrame(_)); + ProcessFramePacket(QuicFrame(window_update)); +} + +TEST_P(QuicConnectionTest, Blocked) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + QuicBlockedFrame blocked; + blocked.stream_id = 3; + EXPECT_CALL(visitor_, OnBlockedFrame(_)); + ProcessFramePacket(QuicFrame(blocked)); + EXPECT_EQ(1u, connection_.GetStats().blocked_frames_received); + EXPECT_EQ(0u, connection_.GetStats().blocked_frames_sent); +} + +TEST_P(QuicConnectionTest, ZeroBytePacket) { + // Don't close the connection for zero byte packets. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + QuicReceivedPacket encrypted(nullptr, 0, QuicTime::Zero()); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, encrypted); +} + +TEST_P(QuicConnectionTest, MissingPacketsBeforeLeastUnacked) { + if (GetParam().version.HasIetfInvariantHeader()) { + return; + } + // Set the packet number of the ack packet to be least unacked (4). + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 3); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessStopWaitingPacket(InitStopWaitingFrame(4)); + EXPECT_FALSE(connection_.ack_frame().packets.Empty()); +} + +TEST_P(QuicConnectionTest, ClientHandlesVersionNegotiation) { + // All supported versions except the one the connection supports. + ParsedQuicVersionVector versions; + for (auto version : AllSupportedVersions()) { + if (version != connection_.version()) { + versions.push_back(version); + } + } + + // Send a version negotiation packet. + std::unique_ptr encrypted( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), + connection_.version().HasIetfInvariantHeader(), + connection_.version().HasLengthPrefixedConnectionIds(), versions)); + std::unique_ptr received( + ConstructReceivedPacket(*encrypted, QuicTime::Zero())); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + // Verify no connection close packet gets sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *received); + EXPECT_FALSE(connection_.connected()); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_INVALID_VERSION)); +} + +TEST_P(QuicConnectionTest, ClientHandlesVersionNegotiationWithConnectionClose) { + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kINVC); + config.SetClientConnectionOptions(connection_options); + connection_.SetFromConfig(config); + + // All supported versions except the one the connection supports. + ParsedQuicVersionVector versions; + for (auto version : AllSupportedVersions()) { + if (version != connection_.version()) { + versions.push_back(version); + } + } + + // Send a version negotiation packet. + std::unique_ptr encrypted( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), + connection_.version().HasIetfInvariantHeader(), + connection_.version().HasLengthPrefixedConnectionIds(), versions)); + std::unique_ptr received( + ConstructReceivedPacket(*encrypted, QuicTime::Zero())); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + // Verify connection close packet gets sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1u)); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *received); + EXPECT_FALSE(connection_.connected()); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_INVALID_VERSION)); +} + +TEST_P(QuicConnectionTest, BadVersionNegotiation) { + // Send a version negotiation packet with the version the client started with. + // It should be rejected. + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + std::unique_ptr encrypted( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), + connection_.version().HasIetfInvariantHeader(), + connection_.version().HasLengthPrefixedConnectionIds(), + AllSupportedVersions())); + std::unique_ptr received( + ConstructReceivedPacket(*encrypted, QuicTime::Zero())); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *received); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_INVALID_VERSION_NEGOTIATION_PACKET)); +} + +TEST_P(QuicConnectionTest, ProcessFramesIfPacketClosedConnection) { + // Construct a packet with stream frame and connection close frame. + QuicPacketHeader header; + if (peer_framer_.perspective() == Perspective::IS_SERVER) { + header.source_connection_id = connection_id_; + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + if (!peer_framer_.version().HasIetfInvariantHeader()) { + header.source_connection_id_included = CONNECTION_ID_PRESENT; + } + } else { + header.destination_connection_id = connection_id_; + if (peer_framer_.version().HasIetfInvariantHeader()) { + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + } + } + header.packet_number = QuicPacketNumber(1); + header.version_flag = false; + + QuicErrorCode kQuicErrorCode = QUIC_PEER_GOING_AWAY; + // This QuicConnectionCloseFrame will default to being for a Google QUIC + // close. If doing IETF QUIC then set fields appropriately for CC/T or CC/A, + // depending on the mapping. + QuicConnectionCloseFrame qccf(peer_framer_.transport_version(), + kQuicErrorCode, NO_IETF_QUIC_ERROR, "", + /*transport_close_frame_type=*/0); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + frames.push_back(QuicFrame(&qccf)); + std::unique_ptr packet(ConstructPacket(header, frames)); + EXPECT_TRUE(nullptr != packet); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(1), *packet, buffer, + kMaxOutgoingPacketSize); + + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_PEER)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, QuicTime::Zero(), false)); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_PEER_GOING_AWAY)); +} + +TEST_P(QuicConnectionTest, SelectMutualVersion) { + connection_.SetSupportedVersions(AllSupportedVersions()); + // Set the connection to speak the lowest quic version. + connection_.set_version(QuicVersionMin()); + EXPECT_EQ(QuicVersionMin(), connection_.version()); + + // Pass in available versions which includes a higher mutually supported + // version. The higher mutually supported version should be selected. + ParsedQuicVersionVector supported_versions = AllSupportedVersions(); + EXPECT_TRUE(connection_.SelectMutualVersion(supported_versions)); + EXPECT_EQ(QuicVersionMax(), connection_.version()); + + // Expect that the lowest version is selected. + // Ensure the lowest supported version is less than the max, unless they're + // the same. + ParsedQuicVersionVector lowest_version_vector; + lowest_version_vector.push_back(QuicVersionMin()); + EXPECT_TRUE(connection_.SelectMutualVersion(lowest_version_vector)); + EXPECT_EQ(QuicVersionMin(), connection_.version()); + + // Shouldn't be able to find a mutually supported version. + ParsedQuicVersionVector unsupported_version; + unsupported_version.push_back(UnsupportedQuicVersion()); + EXPECT_FALSE(connection_.SelectMutualVersion(unsupported_version)); +} + +TEST_P(QuicConnectionTest, ConnectionCloseWhenWritable) { + EXPECT_FALSE(writer_->IsWriteBlocked()); + + // Send a packet. + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + TriggerConnectionClose(); + EXPECT_LE(2u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, ConnectionCloseGettingWriteBlocked) { + BlockOnNextWrite(); + TriggerConnectionClose(); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_TRUE(writer_->IsWriteBlocked()); +} + +TEST_P(QuicConnectionTest, ConnectionCloseWhenWriteBlocked) { + BlockOnNextWrite(); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_TRUE(writer_->IsWriteBlocked()); + TriggerConnectionClose(); + EXPECT_EQ(1u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, OnPacketSentDebugVisitor) { + PathProbeTestInit(Perspective::IS_CLIENT); + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(1, "foo", 0, NO_FIN); + + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)).Times(1); + connection_.SendConnectivityProbingPacket(writer_.get(), + connection_.peer_address()); +} + +TEST_P(QuicConnectionTest, OnPacketHeaderDebugVisitor) { + QuicPacketHeader header; + header.packet_number = QuicPacketNumber(1); + if (GetParam().version.HasIetfInvariantHeader()) { + header.form = IETF_QUIC_LONG_HEADER_PACKET; + } + + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + EXPECT_CALL(debug_visitor, OnPacketHeader(Ref(header), _, _)).Times(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)).Times(1); + EXPECT_CALL(debug_visitor, OnSuccessfulVersionNegotiation(_)).Times(1); + connection_.OnPacketHeader(header); +} + +TEST_P(QuicConnectionTest, Pacing) { + TestConnection server(connection_id_, kPeerAddress, kSelfAddress, + helper_.get(), alarm_factory_.get(), writer_.get(), + Perspective::IS_SERVER, version(), + connection_id_generator_); + TestConnection client(connection_id_, kSelfAddress, kPeerAddress, + helper_.get(), alarm_factory_.get(), writer_.get(), + Perspective::IS_CLIENT, version(), + connection_id_generator_); + EXPECT_FALSE(QuicSentPacketManagerPeer::UsingPacing( + static_cast( + &client.sent_packet_manager()))); + EXPECT_FALSE(QuicSentPacketManagerPeer::UsingPacing( + static_cast( + &server.sent_packet_manager()))); +} + +TEST_P(QuicConnectionTest, WindowUpdateInstigateAcks) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Send a WINDOW_UPDATE frame. + QuicWindowUpdateFrame window_update; + window_update.stream_id = 3; + window_update.max_data = 1234; + EXPECT_CALL(visitor_, OnWindowUpdateFrame(_)); + ProcessFramePacket(QuicFrame(window_update)); + + // Ensure that this has caused the ACK alarm to be set. + EXPECT_TRUE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, BlockedFrameInstigateAcks) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Send a BLOCKED frame. + QuicBlockedFrame blocked; + blocked.stream_id = 3; + EXPECT_CALL(visitor_, OnBlockedFrame(_)); + ProcessFramePacket(QuicFrame(blocked)); + + // Ensure that this has caused the ACK alarm to be set. + EXPECT_TRUE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, ReevaluateTimeUntilSendOnAck) { + // Enable pacing. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + + // Send two packets. One packet is not sufficient because if it gets acked, + // there will be no packets in flight after that and the pacer will always + // allow the next packet in that situation. + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, NO_FIN); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "bar", + 3, NO_FIN); + connection_.OnCanWrite(); + + // Schedule the next packet for a few milliseconds in future. + QuicSentPacketManagerPeer::DisablePacerBursts(manager_); + QuicTime scheduled_pacing_time = + clock_.Now() + QuicTime::Delta::FromMilliseconds(5); + QuicSentPacketManagerPeer::SetNextPacedPacketTime(manager_, + scheduled_pacing_time); + + // Send a packet and have it be blocked by congestion control. + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(false)); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "baz", + 6, NO_FIN); + EXPECT_FALSE(connection_.GetSendAlarm()->IsSet()); + + // Process an ack and the send alarm will be set to the new 5ms delay. + QuicAckFrame ack = InitAckFrame(1); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + ProcessAckPacket(&ack); + size_t padding_frame_count = writer_->padding_frames().size(); + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_TRUE(connection_.GetSendAlarm()->IsSet()); + EXPECT_EQ(scheduled_pacing_time, connection_.GetSendAlarm()->deadline()); + writer_->Reset(); +} + +TEST_P(QuicConnectionTest, SendAcksImmediately) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(1); + CongestionBlockWrites(); + SendAckPacketToPeer(); +} + +TEST_P(QuicConnectionTest, SendPingImmediately) { + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + + CongestionBlockWrites(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)).Times(1); + EXPECT_CALL(debug_visitor, OnPingSent()).Times(1); + connection_.SendControlFrame(QuicFrame(QuicPingFrame(1))); + EXPECT_FALSE(connection_.HasQueuedData()); +} + +TEST_P(QuicConnectionTest, SendBlockedImmediately) { + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)).Times(1); + EXPECT_EQ(0u, connection_.GetStats().blocked_frames_sent); + connection_.SendControlFrame(QuicFrame(QuicBlockedFrame(1, 3, 0))); + EXPECT_EQ(1u, connection_.GetStats().blocked_frames_sent); + EXPECT_FALSE(connection_.HasQueuedData()); +} + +TEST_P(QuicConnectionTest, FailedToSendBlockedFrames) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + QuicBlockedFrame blocked(1, 3, 0); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)).Times(0); + EXPECT_EQ(0u, connection_.GetStats().blocked_frames_sent); + connection_.SendControlFrame(QuicFrame(blocked)); + EXPECT_EQ(0u, connection_.GetStats().blocked_frames_sent); + EXPECT_FALSE(connection_.HasQueuedData()); +} + +TEST_P(QuicConnectionTest, SendingUnencryptedStreamDataFails) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration()) { + return; + } + + EXPECT_QUIC_BUG( + { + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce( + Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + connection_.SaveAndSendStreamData(3, {}, 0, FIN); + EXPECT_FALSE(connection_.connected()); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_ATTEMPT_TO_SEND_UNENCRYPTED_STREAM_DATA)); + }, + "Cannot send stream data with level: ENCRYPTION_INITIAL"); +} + +TEST_P(QuicConnectionTest, SetRetransmissionAlarmForCryptoPacket) { + EXPECT_TRUE(connection_.connected()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoStreamData(); + + // Verify retransmission timer is correctly set after crypto packet has been + // sent. + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime retransmission_time = + QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetRetransmissionTime(); + EXPECT_NE(retransmission_time, clock_.ApproximateNow()); + EXPECT_EQ(retransmission_time, + connection_.GetRetransmissionAlarm()->deadline()); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetRetransmissionAlarm()->Fire(); +} + +// Includes regression test for b/69979024. +TEST_P(QuicConnectionTest, PathDegradingDetectionForNonCryptoPackets) { + EXPECT_TRUE(connection_.connected()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + + for (int i = 0; i < 2; ++i) { + // Send a packet. Now there's a retransmittable packet on the wire, so the + // path degrading detection should be set. + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), data, + offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + // Check the deadline of the path degrading detection. + QuicTime::Delta delay = + QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + + // Send a second packet. The path degrading detection's deadline should + // remain the same. + // Regression test for b/69979024. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicTime prev_deadline = + connection_.GetBlackholeDetectorAlarm()->deadline(); + connection_.SendStreamDataWithString( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), data, + offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + EXPECT_EQ(prev_deadline, + connection_.GetBlackholeDetectorAlarm()->deadline()); + + // Now receive an ACK of the first packet. This should advance the path + // degrading detection's deadline since forward progress has been made. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + if (i == 0) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + } + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(1u + 2u * i), QuicPacketNumber(2u + 2u * i)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + // Check the deadline of the path degrading detection. + delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + + if (i == 0) { + // Now receive an ACK of the second packet. Since there are no more + // retransmittable packets on the wire, this should cancel the path + // degrading detection. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + frame = InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + ProcessAckPacket(&frame); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + } else { + // Advance time to the path degrading alarm's deadline and simulate + // firing the alarm. + clock_.AdvanceTime(delay); + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.PathDegradingTimeout(); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + } + } + EXPECT_TRUE(connection_.IsPathDegrading()); +} + +TEST_P(QuicConnectionTest, RetransmittableOnWireSetsPingAlarm) { + const QuicTime::Delta retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(50); + connection_.set_initial_retransmittable_on_wire_timeout( + retransmittable_on_wire_timeout); + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + + // Send a packet. + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + // Now there's a retransmittable packet on the wire, so the path degrading + // alarm should be set. + // The retransmittable-on-wire alarm should not be set. + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + QuicTime::Delta delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + ASSERT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + QuicTime::Delta ping_delay = QuicTime::Delta::FromSeconds(kPingTimeoutSecs); + EXPECT_EQ(ping_delay, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now receive an ACK of the packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + // No more retransmittable packets on the wire, so the path degrading alarm + // should be cancelled, and the ping alarm should be set to the + // retransmittable_on_wire_timeout. + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Simulate firing the ping alarm and sending a PING. + clock_.AdvanceTime(retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + + // Now there's a retransmittable packet (PING) on the wire, so the path + // degrading alarm should be set. + ASSERT_TRUE(connection_.PathDegradingDetectionInProgress()); + delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); +} + +TEST_P(QuicConnectionTest, ServerRetransmittableOnWire) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + SetQuicReloadableFlag(quic_enable_server_on_wire_ping, true); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kSRWP); + config.SetInitialReceivedConnectionOptions(connection_options); + connection_.SetFromConfig(config); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + ProcessPacket(1); + + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + QuicTime::Delta ping_delay = QuicTime::Delta::FromMilliseconds(200); + EXPECT_EQ(ping_delay, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + // Verify PING alarm gets cancelled. + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + + // Now receive an ACK of the packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(2, &frame); + // Verify PING alarm gets scheduled. + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(ping_delay, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); +} + +TEST_P(QuicConnectionTest, RetransmittableOnWireSendFirstPacket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + const QuicTime::Delta kRetransmittableOnWireTimeout = + QuicTime::Delta::FromMilliseconds(200); + const QuicTime::Delta kTestRtt = QuicTime::Delta::FromMilliseconds(100); + + connection_.set_initial_retransmittable_on_wire_timeout( + kRetransmittableOnWireTimeout); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kROWF); + config.SetClientConnectionOptions(connection_options); + connection_.SetFromConfig(config); + + // Send a request. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + // Receive an ACK after 1-RTT. + clock_.AdvanceTime(kTestRtt); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(kRetransmittableOnWireTimeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Fire retransmittable-on-wire alarm. + clock_.AdvanceTime(kRetransmittableOnWireTimeout); + connection_.GetPingAlarm()->Fire(); + EXPECT_EQ(2u, writer_->packets_write_attempts()); + // Verify alarm is set in keep-alive mode. + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); +} + +TEST_P(QuicConnectionTest, RetransmittableOnWireSendRandomBytes) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + const QuicTime::Delta kRetransmittableOnWireTimeout = + QuicTime::Delta::FromMilliseconds(200); + const QuicTime::Delta kTestRtt = QuicTime::Delta::FromMilliseconds(100); + + connection_.set_initial_retransmittable_on_wire_timeout( + kRetransmittableOnWireTimeout); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kROWR); + config.SetClientConnectionOptions(connection_options); + connection_.SetFromConfig(config); + + // Send a request. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + // Receive an ACK after 1-RTT. + clock_.AdvanceTime(kTestRtt); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(kRetransmittableOnWireTimeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Fire retransmittable-on-wire alarm. + clock_.AdvanceTime(kRetransmittableOnWireTimeout); + // Next packet is not processable by the framer in the test writer. + ExpectNextPacketUnprocessable(); + connection_.GetPingAlarm()->Fire(); + EXPECT_EQ(2u, writer_->packets_write_attempts()); + // Verify alarm is set in keep-alive mode. + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); +} + +TEST_P(QuicConnectionTest, + RetransmittableOnWireSendRandomBytesWithWriterBlocked) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + + const QuicTime::Delta kRetransmittableOnWireTimeout = + QuicTime::Delta::FromMilliseconds(200); + const QuicTime::Delta kTestRtt = QuicTime::Delta::FromMilliseconds(100); + + connection_.set_initial_retransmittable_on_wire_timeout( + kRetransmittableOnWireTimeout); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kROWR); + config.SetClientConnectionOptions(connection_options); + connection_.SetFromConfig(config); + + // Send a request. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + // Receive an ACK after 1-RTT. + clock_.AdvanceTime(kTestRtt); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + ASSERT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(kRetransmittableOnWireTimeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + // Receive an out of order data packet and block the ACK packet. + BlockOnNextWrite(); + ProcessDataPacket(3); + EXPECT_EQ(2u, writer_->packets_write_attempts()); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Fire retransmittable-on-wire alarm. + clock_.AdvanceTime(kRetransmittableOnWireTimeout); + connection_.GetPingAlarm()->Fire(); + // Verify the random bytes packet gets queued. + EXPECT_EQ(2u, connection_.NumQueuedPackets()); +} + +// This test verifies that the connection marks path as degrading and does not +// spin timer to detect path degrading when a new packet is sent on the +// degraded path. +TEST_P(QuicConnectionTest, NoPathDegradingDetectionIfPathIsDegrading) { + EXPECT_TRUE(connection_.connected()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + + // Send the first packet. Now there's a retransmittable packet on the wire, so + // the path degrading alarm should be set. + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + // Check the deadline of the path degrading detection. + QuicTime::Delta delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + + // Send a second packet. The path degrading detection's deadline should remain + // the same. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicTime prev_deadline = connection_.GetBlackholeDetectorAlarm()->deadline(); + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + EXPECT_EQ(prev_deadline, connection_.GetBlackholeDetectorAlarm()->deadline()); + + // Now receive an ACK of the first packet. This should advance the path + // degrading detection's deadline since forward progress has been made. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1u), QuicPacketNumber(2u)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + // Check the deadline of the path degrading alarm. + delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + + // Advance time to the path degrading detection's deadline and simulate + // firing the path degrading detection. This path will be considered as + // degrading. + clock_.AdvanceTime(delay); + EXPECT_CALL(visitor_, OnPathDegrading()).Times(1); + connection_.PathDegradingTimeout(); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_TRUE(connection_.IsPathDegrading()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + // Send a third packet. The path degrading detection is no longer set but path + // should still be marked as degrading. + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_TRUE(connection_.IsPathDegrading()); +} + +TEST_P(QuicConnectionTest, NoPathDegradingDetectionBeforeHandshakeConfirmed) { + EXPECT_TRUE(connection_.connected()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + + connection_.SendStreamDataWithString(1, "data", 0, NO_FIN); + if (GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed) && + connection_.SupportsMultiplePacketNumberSpaces()) { + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + } else { + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + } +} + +// This test verifies that the connection unmarks path as degrarding and spins +// the timer to detect future path degrading when forward progress is made +// after path has been marked degrading. +TEST_P(QuicConnectionTest, UnmarkPathDegradingOnForwardProgress) { + EXPECT_TRUE(connection_.connected()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + + // Send the first packet. Now there's a retransmittable packet on the wire, so + // the path degrading alarm should be set. + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + // Check the deadline of the path degrading alarm. + QuicTime::Delta delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + + // Send a second packet. The path degrading alarm's deadline should remain + // the same. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicTime prev_deadline = connection_.GetBlackholeDetectorAlarm()->deadline(); + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + EXPECT_EQ(prev_deadline, connection_.GetBlackholeDetectorAlarm()->deadline()); + + // Now receive an ACK of the first packet. This should advance the path + // degrading alarm's deadline since forward progress has been made. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1u), QuicPacketNumber(2u)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + // Check the deadline of the path degrading alarm. + delay = QuicConnectionPeer::GetSentPacketManager(&connection_) + ->GetPathDegradingDelay(); + EXPECT_EQ(delay, connection_.GetBlackholeDetectorAlarm()->deadline() - + clock_.ApproximateNow()); + + // Advance time to the path degrading alarm's deadline and simulate + // firing the alarm. + clock_.AdvanceTime(delay); + EXPECT_CALL(visitor_, OnPathDegrading()).Times(1); + connection_.PathDegradingTimeout(); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_TRUE(connection_.IsPathDegrading()); + + // Send a third packet. The path degrading alarm is no longer set but path + // should still be marked as degrading. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + EXPECT_TRUE(connection_.IsPathDegrading()); + + // Now receive an ACK of the second packet. This should unmark the path as + // degrading. And will set a timer to detect new path degrading. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(visitor_, OnForwardProgressMadeAfterPathDegrading()).Times(1); + frame = InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + ProcessAckPacket(&frame); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); +} + +TEST_P(QuicConnectionTest, NoPathDegradingOnServer) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + + // Send data. + const char data[] = "data"; + connection_.SendStreamDataWithString(1, data, 0, NO_FIN); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + + // Ack data. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1u), QuicPacketNumber(2u)}}); + ProcessAckPacket(&frame); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); +} + +TEST_P(QuicConnectionTest, NoPathDegradingAfterSendingAck) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(1); + SendAckPacketToPeer(); + EXPECT_FALSE(connection_.sent_packet_manager().unacked_packets().empty()); + EXPECT_FALSE(connection_.sent_packet_manager().HasInFlightPackets()); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); +} + +TEST_P(QuicConnectionTest, MultipleCallsToCloseConnection) { + // Verifies that multiple calls to CloseConnection do not + // result in multiple attempts to close the connection - it will be marked as + // disconnected after the first call. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(1); + connection_.CloseConnection(QUIC_NO_ERROR, "no reason", + ConnectionCloseBehavior::SILENT_CLOSE); + connection_.CloseConnection(QUIC_NO_ERROR, "no reason", + ConnectionCloseBehavior::SILENT_CLOSE); +} + +TEST_P(QuicConnectionTest, ServerReceivesChloOnNonCryptoStream) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + QuicConnectionPeer::SetAddressValidated(&connection_); + + CryptoHandshakeMessage message; + CryptoFramer framer; + message.set_tag(kCHLO); + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + frame1_.stream_id = 10; + frame1_.data_buffer = data->data(); + frame1_.data_length = data->length(); + + if (version().handshake_protocol == PROTOCOL_TLS1_3) { + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + } + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + ForceProcessFramePacket(QuicFrame(frame1_)); + if (VersionHasIetfQuicFrames(version().transport_version)) { + // INITIAL packet should not contain STREAM frame. + TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); + } else { + TestConnectionCloseQuicErrorCode(QUIC_MAYBE_CORRUPTED_MEMORY); + } +} + +TEST_P(QuicConnectionTest, ClientReceivesRejOnNonCryptoStream) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + CryptoHandshakeMessage message; + CryptoFramer framer; + message.set_tag(kREJ); + std::unique_ptr data = framer.ConstructHandshakeMessage(message); + frame1_.stream_id = 10; + frame1_.data_buffer = data->data(); + frame1_.data_length = data->length(); + + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + ForceProcessFramePacket(QuicFrame(frame1_)); + if (VersionHasIetfQuicFrames(version().transport_version)) { + // INITIAL packet should not contain STREAM frame. + TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); + } else { + TestConnectionCloseQuicErrorCode(QUIC_MAYBE_CORRUPTED_MEMORY); + } +} + +TEST_P(QuicConnectionTest, CloseConnectionOnPacketTooLarge) { + SimulateNextPacketTooLarge(); + // A connection close packet is sent + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + TestConnectionCloseQuicErrorCode(QUIC_PACKET_WRITE_ERROR); +} + +TEST_P(QuicConnectionTest, AlwaysGetPacketTooLarge) { + // Test even we always get packet too large, we do not infinitely try to send + // close packet. + AlwaysGetPacketTooLarge(); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + TestConnectionCloseQuicErrorCode(QUIC_PACKET_WRITE_ERROR); +} + +TEST_P(QuicConnectionTest, CloseConnectionOnQueuedWriteError) { + // Regression test for crbug.com/979507. + // + // If we get a write error when writing queued packets, we should attempt to + // send a connection close packet, but if sending that fails, it shouldn't get + // queued. + + // Queue a packet to write. + BlockOnNextWrite(); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Configure writer to always fail. + AlwaysGetPacketTooLarge(); + + // Expect that we attempt to close the connection exactly once. + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + + // Unblock the writes and actually send. + writer_->SetWritable(); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + + TestConnectionCloseQuicErrorCode(QUIC_PACKET_WRITE_ERROR); +} + +// Verify that if connection has no outstanding data, it notifies the send +// algorithm after the write. +TEST_P(QuicConnectionTest, SendDataAndBecomeApplicationLimited) { + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(1); + { + InSequence seq; + EXPECT_CALL(visitor_, WillingAndAbleToWrite()).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()) + .WillRepeatedly(Return(false)); + } + + connection_.SendStreamData3(); +} + +// Verify that the connection does not become app-limited if there is +// outstanding data to send after the write. +TEST_P(QuicConnectionTest, NotBecomeApplicationLimitedIfMoreDataAvailable) { + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(0); + { + InSequence seq; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()).WillRepeatedly(Return(true)); + } + + connection_.SendStreamData3(); +} + +// Verify that the connection does not become app-limited after blocked write +// even if there is outstanding data to send after the write. +TEST_P(QuicConnectionTest, NotBecomeApplicationLimitedDueToWriteBlock) { + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(0); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()).WillRepeatedly(Return(true)); + BlockOnNextWrite(); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamData3(); + + // Now unblock the writer, become congestion control blocked, + // and ensure we become app-limited after writing. + writer_->SetWritable(); + CongestionBlockWrites(); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()).WillRepeatedly(Return(false)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(1); + connection_.OnCanWrite(); +} + +TEST_P(QuicConnectionTest, DoNotForceSendingAckOnPacketTooLarge) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + // Send an ack by simulating delayed ack alarm firing. + ProcessPacket(1); + EXPECT_TRUE(connection_.HasPendingAcks()); + connection_.GetAckAlarm()->Fire(); + // Simulate data packet causes write error. + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + SimulateNextPacketTooLarge(); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + EXPECT_EQ(1u, writer_->connection_close_frames().size()); + // Ack frame is not bundled in connection close packet. + EXPECT_TRUE(writer_->ack_frames().empty()); + if (writer_->padding_frames().empty()) { + EXPECT_EQ(1u, writer_->frame_count()); + } else { + EXPECT_EQ(2u, writer_->frame_count()); + } + + TestConnectionCloseQuicErrorCode(QUIC_PACKET_WRITE_ERROR); +} + +TEST_P(QuicConnectionTest, CloseConnectionAllLevels) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + const QuicErrorCode kQuicErrorCode = QUIC_INTERNAL_ERROR; + connection_.CloseConnection( + kQuicErrorCode, "Some random error message", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + + EXPECT_EQ(2u, QuicConnectionPeer::GetNumEncryptionLevels(&connection_)); + + TestConnectionCloseQuicErrorCode(kQuicErrorCode); + EXPECT_EQ(1u, writer_->connection_close_frames().size()); + + if (!connection_.version().CanSendCoalescedPackets()) { + // Each connection close packet should be sent in distinct UDP packets. + EXPECT_EQ(QuicConnectionPeer::GetNumEncryptionLevels(&connection_), + writer_->connection_close_packets()); + EXPECT_EQ(QuicConnectionPeer::GetNumEncryptionLevels(&connection_), + writer_->packets_write_attempts()); + return; + } + + // A single UDP packet should be sent with multiple connection close packets + // coalesced together. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Only the first packet has been processed yet. + EXPECT_EQ(1u, writer_->connection_close_packets()); + + // ProcessPacket resets the visitor and frees the coalesced packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(1u, writer_->connection_close_packets()); + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); +} + +TEST_P(QuicConnectionTest, CloseConnectionOneLevel) { + if (connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + const QuicErrorCode kQuicErrorCode = QUIC_INTERNAL_ERROR; + connection_.CloseConnection( + kQuicErrorCode, "Some random error message", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + + EXPECT_EQ(2u, QuicConnectionPeer::GetNumEncryptionLevels(&connection_)); + + TestConnectionCloseQuicErrorCode(kQuicErrorCode); + EXPECT_EQ(1u, writer_->connection_close_frames().size()); + EXPECT_EQ(1u, writer_->connection_close_packets()); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); +} + +TEST_P(QuicConnectionTest, DoNotPadServerInitialConnectionClose) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + // Receives packet 1000 in initial data. + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + + if (version().handshake_protocol == PROTOCOL_TLS1_3) { + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + } + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + const QuicErrorCode kQuicErrorCode = QUIC_INTERNAL_ERROR; + connection_.CloseConnection( + kQuicErrorCode, "Some random error message", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + + EXPECT_EQ(2u, QuicConnectionPeer::GetNumEncryptionLevels(&connection_)); + + TestConnectionCloseQuicErrorCode(kQuicErrorCode); + EXPECT_EQ(1u, writer_->connection_close_frames().size()); + EXPECT_TRUE(writer_->padding_frames().empty()); + EXPECT_EQ(ENCRYPTION_INITIAL, writer_->framer()->last_decrypted_level()); +} + +// Regression test for b/63620844. +TEST_P(QuicConnectionTest, FailedToWriteHandshakePacket) { + SimulateNextPacketTooLarge(); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + + connection_.SendCryptoStreamData(); + TestConnectionCloseQuicErrorCode(QUIC_PACKET_WRITE_ERROR); +} + +TEST_P(QuicConnectionTest, MaxPacingRate) { + EXPECT_EQ(0, connection_.MaxPacingRate().ToBytesPerSecond()); + connection_.SetMaxPacingRate(QuicBandwidth::FromBytesPerSecond(100)); + EXPECT_EQ(100, connection_.MaxPacingRate().ToBytesPerSecond()); +} + +TEST_P(QuicConnectionTest, ClientAlwaysSendConnectionId) { + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + EXPECT_EQ(CONNECTION_ID_PRESENT, + writer_->last_packet_header().destination_connection_id_included); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + QuicConfigPeer::SetReceivedBytesForConnectionId(&config, 0); + connection_.SetFromConfig(config); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(3, "bar", 3, NO_FIN); + // Verify connection id is still sent in the packet. + EXPECT_EQ(CONNECTION_ID_PRESENT, + writer_->last_packet_header().destination_connection_id_included); +} + +TEST_P(QuicConnectionTest, PingAfterLastRetransmittablePacketAcked) { + const QuicTime::Delta retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(50); + connection_.set_initial_retransmittable_on_wire_timeout( + retransmittable_on_wire_timeout); + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + + // Advance 5ms, send a retransmittable packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + QuicTime::Delta ping_delay = QuicTime::Delta::FromSeconds(kPingTimeoutSecs); + EXPECT_EQ(ping_delay, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Advance 5ms, send a second retransmittable packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + + // Now receive an ACK of the first packet. This should not set the + // retransmittable-on-wire alarm since packet 2 is still on the wire. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + // The ping alarm has a 1 second granularity, and the clock has been advanced + // 10ms since it was originally set. + EXPECT_EQ(ping_delay - QuicTime::Delta::FromMilliseconds(10), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now receive an ACK of the second packet. This should set the + // retransmittable-on-wire alarm now that no retransmittable packets are on + // the wire. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + frame = InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now receive a duplicate ACK of the second packet. This should not update + // the ping alarm. + QuicTime prev_deadline = connection_.GetPingAlarm()->deadline(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + frame = InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(prev_deadline, connection_.GetPingAlarm()->deadline()); + + // Now receive a non-ACK packet. This should not update the ping alarm. + prev_deadline = connection_.GetPingAlarm()->deadline(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + ProcessPacket(4); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(prev_deadline, connection_.GetPingAlarm()->deadline()); + + // Simulate the alarm firing and check that a PING is sent. + connection_.GetPingAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); + if (GetParam().no_stop_waiting) { + EXPECT_EQ(padding_frame_count + 2u, writer_->frame_count()); + } else { + EXPECT_EQ(padding_frame_count + 3u, writer_->frame_count()); + } + ASSERT_EQ(1u, writer_->ping_frames().size()); +} + +TEST_P(QuicConnectionTest, NoPingIfRetransmittablePacketSent) { + const QuicTime::Delta retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(50); + connection_.set_initial_retransmittable_on_wire_timeout( + retransmittable_on_wire_timeout); + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + + // Advance 5ms, send a retransmittable packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + QuicTime::Delta ping_delay = QuicTime::Delta::FromSeconds(kPingTimeoutSecs); + EXPECT_EQ(ping_delay, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now receive an ACK of the first packet. This should set the + // retransmittable-on-wire alarm now that no retransmittable packets are on + // the wire. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Before the alarm fires, send another retransmittable packet. This should + // cancel the retransmittable-on-wire alarm since now there's a + // retransmittable packet on the wire. + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + + // Now receive an ACK of the second packet. This should set the + // retransmittable-on-wire alarm now that no retransmittable packets are on + // the wire. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + frame = InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Simulate the alarm firing and check that a PING is sent. + writer_->Reset(); + connection_.GetPingAlarm()->Fire(); + size_t padding_frame_count = writer_->padding_frames().size(); + if (GetParam().no_stop_waiting) { + // Do not ACK acks. + EXPECT_EQ(padding_frame_count + 1u, writer_->frame_count()); + } else { + EXPECT_EQ(padding_frame_count + 3u, writer_->frame_count()); + } + ASSERT_EQ(1u, writer_->ping_frames().size()); +} + +// When there is no stream data received but are open streams, send the +// first few consecutive pings with aggressive retransmittable-on-wire +// timeout. Exponentially back off the retransmittable-on-wire ping timeout +// afterwards until it exceeds the default ping timeout. +TEST_P(QuicConnectionTest, BackOffRetransmittableOnWireTimeout) { + int max_aggressive_retransmittable_on_wire_ping_count = 5; + SetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count, + max_aggressive_retransmittable_on_wire_ping_count); + const QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + connection_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + const char data[] = "data"; + // Advance 5ms, send a retransmittable data packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + connection_.SendStreamDataWithString(1, data, 0, NO_FIN); + EXPECT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + + // Verify that the first few consecutive retransmittable on wire pings are + // sent with aggressive timeout. + for (int i = 0; i <= max_aggressive_retransmittable_on_wire_ping_count; i++) { + // Receive an ACK of the previous packet. This should set the ping alarm + // with the initial retransmittable-on-wire timeout. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicPacketNumber ack_num = creator_->packet_number(); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + // Simulate the alarm firing and check that a PING is sent. + writer_->Reset(); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + } + + QuicTime::Delta retransmittable_on_wire_timeout = + initial_retransmittable_on_wire_timeout; + + // Verify subsequent pings are sent with timeout that is exponentially backed + // off. + while (retransmittable_on_wire_timeout * 2 < + QuicTime::Delta::FromSeconds(kPingTimeoutSecs)) { + // Receive an ACK for the previous PING. This should set the + // ping alarm with backed off retransmittable-on-wire timeout. + retransmittable_on_wire_timeout = retransmittable_on_wire_timeout * 2; + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicPacketNumber ack_num = creator_->packet_number(); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Simulate the alarm firing and check that a PING is sent. + writer_->Reset(); + clock_.AdvanceTime(retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + } + + // The ping alarm is set with default ping timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Receive an ACK for the previous PING. The ping alarm is set with an + // earlier deadline. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicPacketNumber ack_num = creator_->packet_number(); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs) - + QuicTime::Delta::FromMilliseconds(5), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); +} + +// This test verify that the count of consecutive aggressive pings is reset +// when new data is received. And it also verifies the connection resets +// the exponential back-off of the retransmittable-on-wire ping timeout +// after receiving new stream data. +TEST_P(QuicConnectionTest, ResetBackOffRetransmitableOnWireTimeout) { + int max_aggressive_retransmittable_on_wire_ping_count = 3; + SetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count, 3); + const QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + connection_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + + const char data[] = "data"; + // Advance 5ms, send a retransmittable data packet to the peer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + connection_.SendStreamDataWithString(1, data, 0, NO_FIN); + EXPECT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Receive an ACK of the first packet. This should set the ping alarm with + // initial retransmittable-on-wire timeout since there is no retransmittable + // packet on the wire. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(2)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Simulate the alarm firing and check that a PING is sent. + writer_->Reset(); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + + // Receive an ACK for the previous PING. Ping alarm will be set with + // aggressive timeout. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicPacketNumber ack_num = creator_->packet_number(); + frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Process a data packet. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacket(peer_creator_.packet_number() + 1); + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, + peer_creator_.packet_number() + 1); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + + // Verify the count of consecutive aggressive pings is reset. + for (int i = 0; i < max_aggressive_retransmittable_on_wire_ping_count; i++) { + // Receive an ACK of the previous packet. This should set the ping alarm + // with the initial retransmittable-on-wire timeout. + QuicPacketNumber ack_num = creator_->packet_number(); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + // Simulate the alarm firing and check that a PING is sent. + writer_->Reset(); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + // Advance 5ms to receive next packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + } + + // Receive another ACK for the previous PING. This should set the + // ping alarm with backed off retransmittable-on-wire timeout. + ack_num = creator_->packet_number(); + frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout * 2, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + writer_->Reset(); + clock_.AdvanceTime(2 * initial_retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + + // Process another data packet and a new ACK packet. The ping alarm is set + // with aggressive ping timeout again. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + ProcessDataPacket(peer_creator_.packet_number() + 1); + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, + peer_creator_.packet_number() + 1); + ack_num = creator_->packet_number(); + frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); +} + +// Make sure that we never send more retransmissible on the wire pings than +// the limit in FLAGS_quic_max_retransmittable_on_wire_ping_count. +TEST_P(QuicConnectionTest, RetransmittableOnWirePingLimit) { + static constexpr int kMaxRetransmittableOnWirePingCount = 3; + SetQuicFlag(quic_max_retransmittable_on_wire_ping_count, + kMaxRetransmittableOnWirePingCount); + static constexpr QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + static constexpr QuicTime::Delta short_delay = + QuicTime::Delta::FromMilliseconds(5); + ASSERT_LT(short_delay * 10, initial_retransmittable_on_wire_timeout); + connection_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + const char data[] = "data"; + // Advance 5ms, send a retransmittable data packet to the peer. + clock_.AdvanceTime(short_delay); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + connection_.SendStreamDataWithString(1, data, 0, NO_FIN); + EXPECT_TRUE(connection_.sent_packet_manager().HasInFlightPackets()); + // The ping alarm is set for the ping timeout, not the shorter + // retransmittable_on_wire_timeout. + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)) + .Times(AnyNumber()); + + // Verify that the first few consecutive retransmittable on wire pings are + // sent with aggressive timeout. + for (int i = 0; i <= kMaxRetransmittableOnWirePingCount; i++) { + // Receive an ACK of the previous packet. This should set the ping alarm + // with the initial retransmittable-on-wire timeout. + clock_.AdvanceTime(short_delay); + QuicPacketNumber ack_num = creator_->packet_number(); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + // Simulate the alarm firing and check that a PING is sent. + writer_->Reset(); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + connection_.GetPingAlarm()->Fire(); + } + + // Receive an ACK of the previous packet. This should set the ping alarm + // but this time with the default ping timeout. + QuicPacketNumber ack_num = creator_->packet_number(); + QuicAckFrame frame = InitAckFrame( + {{QuicPacketNumber(ack_num), QuicPacketNumber(ack_num + 1)}}); + ProcessAckPacket(&frame); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); +} + +TEST_P(QuicConnectionTest, ValidStatelessResetToken) { + const StatelessResetToken kTestToken{0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1}; + const StatelessResetToken kWrongTestToken{0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 2}; + QuicConfig config; + // No token has been received. + EXPECT_FALSE(connection_.IsValidStatelessResetToken(kTestToken)); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(2); + // Token is different from received token. + QuicConfigPeer::SetReceivedStatelessResetToken(&config, kTestToken); + connection_.SetFromConfig(config); + EXPECT_FALSE(connection_.IsValidStatelessResetToken(kWrongTestToken)); + + QuicConfigPeer::SetReceivedStatelessResetToken(&config, kTestToken); + connection_.SetFromConfig(config); + EXPECT_TRUE(connection_.IsValidStatelessResetToken(kTestToken)); +} + +TEST_P(QuicConnectionTest, WriteBlockedWithInvalidAck) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + BlockOnNextWrite(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(5, "foo", 0, FIN); + // This causes connection to be closed because packet 1 has not been sent yet. + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessAckPacket(1, &frame); + EXPECT_EQ(0, connection_close_frame_count_); +} + +TEST_P(QuicConnectionTest, SendMessage) { + if (!VersionSupportsMessageFrames(connection_.transport_version())) { + return; + } + if (connection_.version().UsesTls()) { + QuicConfig config; + QuicConfigPeer::SetReceivedMaxDatagramFrameSize( + &config, kMaxAcceptedDatagramFrameSize); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + std::string message(connection_.GetCurrentLargestMessagePayload() * 2, 'a'); + quiche::QuicheMemSlice slice; + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendStreamData3(); + // Send a message which cannot fit into current open packet, and 2 packets + // get sent, one contains stream frame, and the other only contains the + // message frame. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + slice = MemSliceFromString(absl::string_view( + message.data(), connection_.GetCurrentLargestMessagePayload())); + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, + connection_.SendMessage(1, absl::MakeSpan(&slice, 1), false)); + } + // Fail to send a message if connection is congestion control blocked. + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillOnce(Return(false)); + slice = MemSliceFromString("message"); + EXPECT_EQ(MESSAGE_STATUS_BLOCKED, + connection_.SendMessage(2, absl::MakeSpan(&slice, 1), false)); + + // Always fail to send a message which cannot fit into one packet. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + slice = MemSliceFromString(absl::string_view( + message.data(), connection_.GetCurrentLargestMessagePayload() + 1)); + EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, + connection_.SendMessage(3, absl::MakeSpan(&slice, 1), false)); +} + +TEST_P(QuicConnectionTest, GetCurrentLargestMessagePayload) { + if (!connection_.version().SupportsMessageFrames()) { + return; + } + QuicPacketLength expected_largest_payload = 1215; + if (connection_.version().SendsVariableLengthPacketNumberInLongHeader()) { + expected_largest_payload += 3; + } + if (connection_.version().HasLongHeaderLengths()) { + expected_largest_payload -= 2; + } + if (connection_.version().HasLengthPrefixedConnectionIds()) { + expected_largest_payload -= 1; + } + if (connection_.version().UsesTls()) { + // QUIC+TLS disallows DATAGRAM/MESSAGE frames before the handshake. + EXPECT_EQ(connection_.GetCurrentLargestMessagePayload(), 0); + QuicConfig config; + QuicConfigPeer::SetReceivedMaxDatagramFrameSize( + &config, kMaxAcceptedDatagramFrameSize); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + // Verify the value post-handshake. + EXPECT_EQ(connection_.GetCurrentLargestMessagePayload(), + expected_largest_payload); + } else { + EXPECT_EQ(connection_.GetCurrentLargestMessagePayload(), + expected_largest_payload); + } +} + +TEST_P(QuicConnectionTest, GetGuaranteedLargestMessagePayload) { + if (!connection_.version().SupportsMessageFrames()) { + return; + } + QuicPacketLength expected_largest_payload = 1215; + if (connection_.version().HasLongHeaderLengths()) { + expected_largest_payload -= 2; + } + if (connection_.version().HasLengthPrefixedConnectionIds()) { + expected_largest_payload -= 1; + } + if (connection_.version().UsesTls()) { + // QUIC+TLS disallows DATAGRAM/MESSAGE frames before the handshake. + EXPECT_EQ(connection_.GetGuaranteedLargestMessagePayload(), 0); + QuicConfig config; + QuicConfigPeer::SetReceivedMaxDatagramFrameSize( + &config, kMaxAcceptedDatagramFrameSize); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + // Verify the value post-handshake. + EXPECT_EQ(connection_.GetGuaranteedLargestMessagePayload(), + expected_largest_payload); + } else { + EXPECT_EQ(connection_.GetGuaranteedLargestMessagePayload(), + expected_largest_payload); + } +} + +TEST_P(QuicConnectionTest, LimitedLargestMessagePayload) { + if (!connection_.version().SupportsMessageFrames() || + !connection_.version().UsesTls()) { + return; + } + constexpr QuicPacketLength kFrameSizeLimit = 1000; + constexpr QuicPacketLength kPayloadSizeLimit = + kFrameSizeLimit - kQuicFrameTypeSize; + // QUIC+TLS disallows DATAGRAM/MESSAGE frames before the handshake. + EXPECT_EQ(connection_.GetCurrentLargestMessagePayload(), 0); + EXPECT_EQ(connection_.GetGuaranteedLargestMessagePayload(), 0); + QuicConfig config; + QuicConfigPeer::SetReceivedMaxDatagramFrameSize(&config, kFrameSizeLimit); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + // Verify the value post-handshake. + EXPECT_EQ(connection_.GetCurrentLargestMessagePayload(), kPayloadSizeLimit); + EXPECT_EQ(connection_.GetGuaranteedLargestMessagePayload(), + kPayloadSizeLimit); +} + +// Test to check that the path challenge/path response logic works +// correctly. This test is only for version-99 +TEST_P(QuicConnectionTest, ServerResponseToPathChallenge) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + QuicConnectionPeer::SetAddressValidated(&connection_); + // First check if the server can send probing packet. + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + + // Create and send the probe request (PATH_CHALLENGE frame). + // SendConnectivityProbingPacket ends up calling + // TestPacketWriter::WritePacket() which in turns receives and parses the + // packet by calling framer_.ProcessPacket() -- which in turn calls + // SimpleQuicFramer::OnPathChallengeFrame(). SimpleQuicFramer saves + // the packet in writer_->path_challenge_frames() + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendConnectivityProbingPacket(writer_.get(), + connection_.peer_address()); + // Save the random contents of the challenge for later comparison to the + // response. + ASSERT_GE(writer_->path_challenge_frames().size(), 1u); + QuicPathFrameBuffer challenge_data = + writer_->path_challenge_frames().front().data_buffer; + + // Normally, QuicConnection::OnPathChallengeFrame and OnPaddingFrame would be + // called and it will perform actions to ensure that the rest of the protocol + // is performed. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_TRUE(connection_.OnPathChallengeFrame( + writer_->path_challenge_frames().front())); + EXPECT_TRUE(connection_.OnPaddingFrame(writer_->padding_frames().front())); + creator_->FlushCurrentPacket(); + + // The final check is to ensure that the random data in the response matches + // the random data from the challenge. + EXPECT_EQ(1u, writer_->path_response_frames().size()); + EXPECT_EQ(0, memcmp(&challenge_data, + &(writer_->path_response_frames().front().data_buffer), + sizeof(challenge_data))); +} + +TEST_P(QuicConnectionTest, ClientResponseToPathChallengeOnDefaulSocket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + // First check if the client can send probing packet. + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + + // Create and send the probe request (PATH_CHALLENGE frame). + // SendConnectivityProbingPacket ends up calling + // TestPacketWriter::WritePacket() which in turns receives and parses the + // packet by calling framer_.ProcessPacket() -- which in turn calls + // SimpleQuicFramer::OnPathChallengeFrame(). SimpleQuicFramer saves + // the packet in writer_->path_challenge_frames() + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendConnectivityProbingPacket(writer_.get(), + connection_.peer_address()); + // Save the random contents of the challenge for later validation against the + // response. + ASSERT_GE(writer_->path_challenge_frames().size(), 1u); + QuicPathFrameBuffer challenge_data = + writer_->path_challenge_frames().front().data_buffer; + + // Normally, QuicConnection::OnPathChallengeFrame would be + // called and it will perform actions to ensure that the rest of the protocol + // is performed. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_TRUE(connection_.OnPathChallengeFrame( + writer_->path_challenge_frames().front())); + EXPECT_TRUE(connection_.OnPaddingFrame(writer_->padding_frames().front())); + creator_->FlushCurrentPacket(); + + // The final check is to ensure that the random data in the response matches + // the random data from the challenge. + EXPECT_EQ(1u, writer_->path_response_frames().size()); + EXPECT_EQ(0, memcmp(&challenge_data, + &(writer_->path_response_frames().front().data_buffer), + sizeof(challenge_data))); +} + +TEST_P(QuicConnectionTest, ClientResponseToPathChallengeOnAlternativeSocket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + + QuicSocketAddress kNewSelfAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + bool success = false; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + + // Receiving a PATH_CHALLENGE on the alternative path. Response to this + // PATH_CHALLENGE should be sent via the alternative writer. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(2u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_response_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + ProcessReceivedPacket(kNewSelfAddress, kPeerAddress, *received); + + QuicSocketAddress kNewerSelfAddress(QuicIpAddress::Loopback6(), + /*port=*/34567); + // Receiving a PATH_CHALLENGE on an unknown socket should be ignored. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0u); + ProcessReceivedPacket(kNewerSelfAddress, kPeerAddress, *received); +} + +TEST_P(QuicConnectionTest, + RestartPathDegradingDetectionAfterMigrationWithProbe) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + PathProbeTestInit(Perspective::IS_CLIENT); + + // Send data and verify the path degrading detection is set. + const char data[] = "data"; + size_t data_size = strlen(data); + QuicStreamOffset offset = 0; + connection_.SendStreamDataWithString(1, data, offset, NO_FIN); + offset += data_size; + + // Verify the path degrading detection is in progress. + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); + EXPECT_FALSE(connection_.IsPathDegrading()); + QuicTime ddl = connection_.GetBlackholeDetectorAlarm()->deadline(); + + // Simulate the firing of path degrading. + clock_.AdvanceTime(ddl - clock_.ApproximateNow()); + EXPECT_CALL(visitor_, OnPathDegrading()).Times(1); + connection_.PathDegradingTimeout(); + EXPECT_TRUE(connection_.IsPathDegrading()); + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + + if (!GetParam().version.HasIetfQuicFrames()) { + // Simulate path degrading handling by sending a probe on an alternet path. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + TestPacketWriter probing_writer(version(), &clock_, Perspective::IS_CLIENT); + connection_.SendConnectivityProbingPacket(&probing_writer, + connection_.peer_address()); + // Verify that path degrading detection is not reset. + EXPECT_FALSE(connection_.PathDegradingDetectionInProgress()); + + // Simulate successful path degrading handling by receiving probe response. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20)); + + EXPECT_CALL(visitor_, + OnPacketReceived(_, _, /*is_connectivity_probe=*/true)) + .Times(1); + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + ProcessReceivedPacket(kNewSelfAddress, kPeerAddress, *received); + + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_TRUE(connection_.IsPathDegrading()); + } + + // Verify new path degrading detection is activated. + EXPECT_CALL(visitor_, OnForwardProgressMadeAfterPathDegrading()).Times(1); + connection_.OnSuccessfulMigration(/*is_port_change*/ true); + EXPECT_FALSE(connection_.IsPathDegrading()); + EXPECT_TRUE(connection_.PathDegradingDetectionInProgress()); +} + +TEST_P(QuicConnectionTest, ClientsResetCwndAfterConnectionMigration) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + PathProbeTestInit(Perspective::IS_CLIENT); + EXPECT_EQ(kSelfAddress, connection_.self_address()); + + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + QuicTime::Delta default_init_rtt = rtt_stats->initial_rtt(); + rtt_stats->set_initial_rtt(default_init_rtt * 2); + EXPECT_EQ(2 * default_init_rtt, rtt_stats->initial_rtt()); + + QuicSentPacketManagerPeer::SetConsecutivePtoCount(manager_, 1); + EXPECT_EQ(1u, manager_->GetConsecutivePtoCount()); + const SendAlgorithmInterface* send_algorithm = manager_->GetSendAlgorithm(); + + // Migrate to a new address with different IP. + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + connection_.MigratePath(kNewSelfAddress, connection_.peer_address(), + &new_writer, false); + EXPECT_EQ(default_init_rtt, manager_->GetRttStats()->initial_rtt()); + EXPECT_EQ(0u, manager_->GetConsecutivePtoCount()); + EXPECT_NE(send_algorithm, manager_->GetSendAlgorithm()); +} + +// Regression test for b/110259444 +TEST_P(QuicConnectionTest, DoNotScheduleSpuriousAckAlarm) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AtLeast(1)); + writer_->SetWriteBlocked(); + + ProcessPacket(1); + // Verify ack alarm is set. + EXPECT_TRUE(connection_.HasPendingAcks()); + // Fire the ack alarm, verify no packet is sent because the writer is blocked. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.GetAckAlarm()->Fire(); + + writer_->SetWritable(); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessPacket(2); + // Verify ack alarm is not set. + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, DisablePacingOffloadConnectionOptions) { + EXPECT_FALSE(QuicConnectionPeer::SupportsReleaseTime(&connection_)); + writer_->set_supports_release_time(true); + QuicConfig config; + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + EXPECT_TRUE(QuicConnectionPeer::SupportsReleaseTime(&connection_)); + + QuicTagVector connection_options; + connection_options.push_back(kNPCO); + config.SetConnectionOptionsToSend(connection_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + // Verify pacing offload is disabled. + EXPECT_FALSE(QuicConnectionPeer::SupportsReleaseTime(&connection_)); +} + +// Regression test for b/110259444 +// Get a path response without having issued a path challenge... +TEST_P(QuicConnectionTest, OrphanPathResponse) { + QuicPathFrameBuffer data = {{0, 1, 2, 3, 4, 5, 6, 7}}; + + QuicPathResponseFrame frame(99, data); + EXPECT_TRUE(connection_.OnPathResponseFrame(frame)); + // If PATH_RESPONSE was accepted (payload matches the payload saved + // in QuicConnection::transmitted_connectivity_probe_payload_) then + // current_packet_content_ would be set to FIRST_FRAME_IS_PING. + // Since this PATH_RESPONSE does not match, current_packet_content_ + // must not be FIRST_FRAME_IS_PING. + EXPECT_NE(QuicConnection::FIRST_FRAME_IS_PING, + QuicConnectionPeer::GetCurrentPacketContent(&connection_)); +} + +// Regression test for b/120791670 +TEST_P(QuicConnectionTest, StopProcessingGQuicPacketInIetfQuicConnection) { + // This test mimics a problematic scenario where a QUIC connection using a + // modern version received a Q043 packet and processed it incorrectly. + // We can remove this test once Q043 is deprecated. + if (!version().HasIetfInvariantHeader()) { + return; + } + set_perspective(Perspective::IS_SERVER); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + + // Let connection process a Google QUIC packet. + peer_framer_.set_version_for_tests(ParsedQuicVersion::Q043()); + std::unique_ptr packet( + ConstructDataPacket(2, !kHasStopWaiting, ENCRYPTION_INITIAL)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(2), + *packet, buffer, kMaxOutgoingPacketSize); + // Make sure no stream frame is processed. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(0); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false)); + + EXPECT_EQ(2u, connection_.GetStats().packets_received); + EXPECT_EQ(1u, connection_.GetStats().packets_processed); +} + +TEST_P(QuicConnectionTest, AcceptPacketNumberZero) { + if (!VersionHasIetfQuicFrames(version().transport_version)) { + return; + } + // Set first_sending_packet_number to be 0 to allow successfully processing + // acks which ack packet number 0. + QuicFramerPeer::SetFirstSendingPacketNumber(writer_->framer()->framer(), 0); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + ProcessPacket(0); + EXPECT_EQ(QuicPacketNumber(0), LargestAcked(connection_.ack_frame())); + EXPECT_EQ(1u, connection_.ack_frame().packets.NumIntervals()); + + ProcessPacket(1); + EXPECT_EQ(QuicPacketNumber(1), LargestAcked(connection_.ack_frame())); + EXPECT_EQ(1u, connection_.ack_frame().packets.NumIntervals()); + + ProcessPacket(2); + EXPECT_EQ(QuicPacketNumber(2), LargestAcked(connection_.ack_frame())); + EXPECT_EQ(1u, connection_.ack_frame().packets.NumIntervals()); +} + +TEST_P(QuicConnectionTest, MultiplePacketNumberSpacesBasicSending) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + connection_.SendCryptoStreamData(); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + QuicAckFrame frame1 = InitAckFrame(1); + // Received ACK for packet 1. + ProcessFramePacketAtLevel(30, QuicFrame(&frame1), ENCRYPTION_INITIAL); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(4); + connection_.SendApplicationDataAtLevel(ENCRYPTION_ZERO_RTT, 5, "data", 0, + NO_FIN); + connection_.SendApplicationDataAtLevel(ENCRYPTION_ZERO_RTT, 5, "data", 4, + NO_FIN); + connection_.SendApplicationDataAtLevel(ENCRYPTION_FORWARD_SECURE, 5, "data", + 8, NO_FIN); + connection_.SendApplicationDataAtLevel(ENCRYPTION_FORWARD_SECURE, 5, "data", + 12, FIN); + // Received ACK for packets 2, 4, 5. + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + QuicAckFrame frame2 = + InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}, + {QuicPacketNumber(4), QuicPacketNumber(6)}}); + // Make sure although the same packet number is used, but they are in + // different packet number spaces. + ProcessFramePacketAtLevel(30, QuicFrame(&frame2), ENCRYPTION_FORWARD_SECURE); +} + +TEST_P(QuicConnectionTest, PeerAcksPacketsInWrongPacketNumberSpace) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x01)); + + connection_.SendCryptoStreamData(); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + QuicAckFrame frame1 = InitAckFrame(1); + // Received ACK for packet 1. + ProcessFramePacketAtLevel(30, QuicFrame(&frame1), ENCRYPTION_INITIAL); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + connection_.SendApplicationDataAtLevel(ENCRYPTION_ZERO_RTT, 5, "data", 0, + NO_FIN); + connection_.SendApplicationDataAtLevel(ENCRYPTION_ZERO_RTT, 5, "data", 4, + NO_FIN); + + // Received ACK for packets 2 and 3 in wrong packet number space. + QuicAckFrame invalid_ack = + InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(4)}}); + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + ProcessFramePacketAtLevel(300, QuicFrame(&invalid_ack), ENCRYPTION_INITIAL); + TestConnectionCloseQuicErrorCode(QUIC_INVALID_ACK_DATA); +} + +TEST_P(QuicConnectionTest, MultiplePacketNumberSpacesBasicReceiving) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Receives packet 1000 in initial data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + // Receives packet 1000 in application data. + ProcessDataPacketAtLevel(1000, false, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.HasPendingAcks()); + connection_.SendApplicationDataAtLevel(ENCRYPTION_FORWARD_SECURE, 5, "data", + 0, NO_FIN); + // Verify application data ACK gets bundled with outgoing data. + EXPECT_EQ(2u, writer_->frame_count()); + // Make sure ACK alarm is still set because initial data is not ACKed. + EXPECT_TRUE(connection_.HasPendingAcks()); + // Receive packet 1001 in application data. + ProcessDataPacketAtLevel(1001, false, ENCRYPTION_FORWARD_SECURE); + clock_.AdvanceTime(DefaultRetransmissionTime()); + // Simulates ACK alarm fires and verify two ACKs are flushed. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + connection_.GetAckAlarm()->Fire(); + EXPECT_FALSE(connection_.HasPendingAcks()); + // Receives more packets in application data. + ProcessDataPacketAtLevel(1002, false, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.HasPendingAcks()); + + // Verify zero rtt and forward secure packets get acked in the same packet. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessDataPacket(1003); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, CancelAckAlarmOnWriteBlocked) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Receives packet 1000 in initial data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + // Receives packet 1000 in application data. + ProcessDataPacketAtLevel(1000, false, ENCRYPTION_ZERO_RTT); + EXPECT_TRUE(connection_.HasPendingAcks()); + + writer_->SetWriteBlocked(); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AnyNumber()); + // Simulates ACK alarm fires and verify no ACK is flushed because of write + // blocked. + clock_.AdvanceTime(DefaultDelayedAckTime()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x02)); + connection_.GetAckAlarm()->Fire(); + // Verify ACK alarm is not set. + EXPECT_FALSE(connection_.HasPendingAcks()); + + writer_->SetWritable(); + // Verify 2 ACKs are sent when connection gets unblocked. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + connection_.OnCanWrite(); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +// Make sure a packet received with the right client connection ID is processed. +TEST_P(QuicConnectionTest, ValidClientConnectionId) { + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + SetClientConnectionId(TestConnectionId(0x33)); + QuicPacketHeader header = ConstructPacketHeader(1, ENCRYPTION_FORWARD_SECURE); + header.destination_connection_id = TestConnectionId(0x33); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + header.source_connection_id_included = CONNECTION_ID_ABSENT; + QuicFrames frames; + QuicPingFrame ping_frame; + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(ping_frame)); + frames.push_back(QuicFrame(padding_frame)); + std::unique_ptr packet = + BuildUnsizedDataPacket(&peer_framer_, header, frames); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(1), *packet, buffer, + kMaxOutgoingPacketSize); + QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(), + false); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); +} + +// Make sure a packet received with a different client connection ID is dropped. +TEST_P(QuicConnectionTest, InvalidClientConnectionId) { + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + SetClientConnectionId(TestConnectionId(0x33)); + QuicPacketHeader header = ConstructPacketHeader(1, ENCRYPTION_FORWARD_SECURE); + header.destination_connection_id = TestConnectionId(0xbad); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + header.source_connection_id_included = CONNECTION_ID_ABSENT; + QuicFrames frames; + QuicPingFrame ping_frame; + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(ping_frame)); + frames.push_back(QuicFrame(padding_frame)); + std::unique_ptr packet = + BuildUnsizedDataPacket(&peer_framer_, header, frames); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(1), *packet, buffer, + kMaxOutgoingPacketSize); + QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(), + false); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet); + EXPECT_EQ(1u, connection_.GetStats().packets_dropped); +} + +// Make sure the first packet received with a different client connection ID on +// the server is processed and it changes the client connection ID. +TEST_P(QuicConnectionTest, UpdateClientConnectionIdFromFirstPacket) { + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicPacketHeader header = ConstructPacketHeader(1, ENCRYPTION_INITIAL); + header.source_connection_id = TestConnectionId(0x33); + header.source_connection_id_included = CONNECTION_ID_PRESENT; + QuicFrames frames; + QuicPingFrame ping_frame; + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(ping_frame)); + frames.push_back(QuicFrame(padding_frame)); + std::unique_ptr packet = + BuildUnsizedDataPacket(&peer_framer_, header, frames); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(1), + *packet, buffer, kMaxOutgoingPacketSize); + QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(), + false); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + EXPECT_EQ(TestConnectionId(0x33), connection_.client_connection_id()); +} +void QuicConnectionTest::TestReplaceConnectionIdFromInitial() { + if (!framer_.version().AllowsVariableLengthConnectionIds()) { + return; + } + // We start with a known connection ID. + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + EXPECT_NE(TestConnectionId(0x33), connection_.connection_id()); + // Receiving an initial can replace the connection ID once. + { + QuicPacketHeader header = ConstructPacketHeader(1, ENCRYPTION_INITIAL); + header.source_connection_id = TestConnectionId(0x33); + header.source_connection_id_included = CONNECTION_ID_PRESENT; + QuicFrames frames; + QuicPingFrame ping_frame; + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(ping_frame)); + frames.push_back(QuicFrame(padding_frame)); + std::unique_ptr packet = + BuildUnsizedDataPacket(&peer_framer_, header, frames); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(1), + *packet, buffer, kMaxOutgoingPacketSize); + QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(), + false); + ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet); + } + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(0u, connection_.GetStats().packets_dropped); + EXPECT_EQ(TestConnectionId(0x33), connection_.connection_id()); + // Trying to replace the connection ID a second time drops the packet. + { + QuicPacketHeader header = ConstructPacketHeader(2, ENCRYPTION_INITIAL); + header.source_connection_id = TestConnectionId(0x66); + header.source_connection_id_included = CONNECTION_ID_PRESENT; + QuicFrames frames; + QuicPingFrame ping_frame; + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(ping_frame)); + frames.push_back(QuicFrame(padding_frame)); + std::unique_ptr packet = + BuildUnsizedDataPacket(&peer_framer_, header, frames); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + peer_framer_.EncryptPayload(ENCRYPTION_INITIAL, QuicPacketNumber(2), + *packet, buffer, kMaxOutgoingPacketSize); + QuicReceivedPacket received_packet(buffer, encrypted_length, clock_.Now(), + false); + ProcessReceivedPacket(kSelfAddress, kPeerAddress, received_packet); + } + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(1u, connection_.GetStats().packets_dropped); + EXPECT_EQ(TestConnectionId(0x33), connection_.connection_id()); +} + +TEST_P(QuicConnectionTest, ReplaceServerConnectionIdFromInitial) { + TestReplaceConnectionIdFromInitial(); +} + +TEST_P(QuicConnectionTest, ReplaceServerConnectionIdFromRetryAndInitial) { + // First make the connection process a RETRY and replace the server connection + // ID a first time. + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); + // Reset the test framer to use the right connection ID. + peer_framer_.SetInitialObfuscators(connection_.connection_id()); + // Now process an INITIAL and replace the server connection ID a second time. + TestReplaceConnectionIdFromInitial(); +} + +// Regression test for b/134416344. +TEST_P(QuicConnectionTest, CheckConnectedBeforeFlush) { + // This test mimics a scenario where a connection processes 2 packets and the + // 2nd packet contains connection close frame. When the 2nd flusher goes out + // of scope, a delayed ACK is pending, and ACK alarm should not be scheduled + // because connection is disconnected. + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + const QuicErrorCode kErrorCode = QUIC_INTERNAL_ERROR; + std::unique_ptr connection_close_frame( + new QuicConnectionCloseFrame(connection_.transport_version(), kErrorCode, + NO_IETF_QUIC_ERROR, "", + /*transport_close_frame_type=*/0)); + + // Received 2 packets. + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + } + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + ProcessFramePacketWithAddresses(QuicFrame(connection_close_frame.release()), + kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + // Verify ack alarm is not set. + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +// Verify that a packet containing three coalesced packets is parsed correctly. +TEST_P(QuicConnectionTest, CoalescedPacket) { + if (!QuicVersionHasLongHeaderLengths(connection_.transport_version())) { + // Coalesced packets can only be encoded using long header lengths. + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_TRUE(connection_.connected()); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(3); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(3); + } + + uint64_t packet_numbers[3] = {1, 2, 3}; + EncryptionLevel encryption_levels[3] = { + ENCRYPTION_INITIAL, ENCRYPTION_INITIAL, ENCRYPTION_FORWARD_SECURE}; + char buffer[kMaxOutgoingPacketSize] = {}; + size_t total_encrypted_length = 0; + for (int i = 0; i < 3; i++) { + QuicPacketHeader header = + ConstructPacketHeader(packet_numbers[i], encryption_levels[i]); + QuicFrames frames; + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + frames.push_back(QuicFrame(&crypto_frame_)); + } else { + frames.push_back(QuicFrame(frame1_)); + } + std::unique_ptr packet = ConstructPacket(header, frames); + peer_creator_.set_encryption_level(encryption_levels[i]); + size_t encrypted_length = peer_framer_.EncryptPayload( + encryption_levels[i], QuicPacketNumber(packet_numbers[i]), *packet, + buffer + total_encrypted_length, + sizeof(buffer) - total_encrypted_length); + EXPECT_GT(encrypted_length, 0u); + total_encrypted_length += encrypted_length; + } + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, total_encrypted_length, clock_.Now(), false)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + + EXPECT_TRUE(connection_.connected()); +} + +// Regression test for crbug.com/992831. +TEST_P(QuicConnectionTest, CoalescedPacketThatSavesFrames) { + if (!QuicVersionHasLongHeaderLengths(connection_.transport_version())) { + // Coalesced packets can only be encoded using long header lengths. + return; + } + if (connection_.SupportsMultiplePacketNumberSpaces()) { + // TODO(b/129151114) Enable this test with multiple packet number spaces. + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_TRUE(connection_.connected()); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .Times(3) + .WillRepeatedly([this](const QuicCryptoFrame& /*frame*/) { + // QuicFrame takes ownership of the QuicBlockedFrame. + connection_.SendControlFrame(QuicFrame(QuicBlockedFrame(1, 3, 0))); + }); + } else { + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .Times(3) + .WillRepeatedly([this](const QuicStreamFrame& /*frame*/) { + // QuicFrame takes ownership of the QuicBlockedFrame. + connection_.SendControlFrame(QuicFrame(QuicBlockedFrame(1, 3, 0))); + }); + } + + uint64_t packet_numbers[3] = {1, 2, 3}; + EncryptionLevel encryption_levels[3] = { + ENCRYPTION_INITIAL, ENCRYPTION_INITIAL, ENCRYPTION_FORWARD_SECURE}; + char buffer[kMaxOutgoingPacketSize] = {}; + size_t total_encrypted_length = 0; + for (int i = 0; i < 3; i++) { + QuicPacketHeader header = + ConstructPacketHeader(packet_numbers[i], encryption_levels[i]); + QuicFrames frames; + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + frames.push_back(QuicFrame(&crypto_frame_)); + } else { + frames.push_back(QuicFrame(frame1_)); + } + std::unique_ptr packet = ConstructPacket(header, frames); + peer_creator_.set_encryption_level(encryption_levels[i]); + size_t encrypted_length = peer_framer_.EncryptPayload( + encryption_levels[i], QuicPacketNumber(packet_numbers[i]), *packet, + buffer + total_encrypted_length, + sizeof(buffer) - total_encrypted_length); + EXPECT_GT(encrypted_length, 0u); + total_encrypted_length += encrypted_length; + } + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, total_encrypted_length, clock_.Now(), false)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + + EXPECT_TRUE(connection_.connected()); + + SendAckPacketToPeer(); +} + +// Regresstion test for b/138962304. +TEST_P(QuicConnectionTest, RtoAndWriteBlocked) { + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + QuicStreamId stream_id = 2; + QuicPacketNumber last_data_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_data_packet); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Writer gets blocked. + writer_->SetWriteBlocked(); + + // Cancel the stream. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AtLeast(1)); + EXPECT_CALL(visitor_, WillingAndAbleToWrite()) + .WillRepeatedly( + Invoke(¬ifier_, &SimpleSessionNotifier::WillingToWrite)); + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 3); + + // Retransmission timer fires in RTO mode. + connection_.GetRetransmissionAlarm()->Fire(); + // Verify no packets get flushed when writer is blocked. + EXPECT_EQ(0u, connection_.NumQueuedPackets()); +} + +// Regresstion test for b/138962304. +TEST_P(QuicConnectionTest, PtoAndWriteBlocked) { + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + QuicStreamId stream_id = 2; + QuicPacketNumber last_data_packet; + SendStreamDataToPeer(stream_id, "foo", 0, NO_FIN, &last_data_packet); + SendStreamDataToPeer(4, "foo", 0, NO_FIN, &last_data_packet); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Writer gets blocked. + writer_->SetWriteBlocked(); + + // Cancel stream 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AtLeast(1)); + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 3); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + // Retransmission timer fires in TLP mode. + connection_.GetRetransmissionAlarm()->Fire(); + // Verify one packets is forced flushed when writer is blocked. + EXPECT_EQ(1u, connection_.NumQueuedPackets()); +} + +TEST_P(QuicConnectionTest, ProbeTimeout) { + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k2PTO); + config.SetConnectionOptionsToSend(connection_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foooooo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "foooooo", 7, NO_FIN, &last_packet); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Reset stream. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + SendRstStream(stream_id, QUIC_ERROR_PROCESSING_STREAM, 3); + + // Fire the PTO and verify only the RST_STREAM is resent, not stream data. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_EQ(0u, writer_->stream_frames().size()); + EXPECT_EQ(1u, writer_->rst_stream_frames().size()); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, CloseConnectionAfter6ClientPTOs) { + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k1PTO); + connection_options.push_back(k6PTO); + config.SetConnectionOptionsToSend(connection_options); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2) || + GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed)) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + } + connection_.OnHandshakeComplete(); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Send stream data. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + + // Fire the retransmission alarm 5 times. + for (int i = 0; i < 5; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_TRUE(connection_.connected()); + } + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.PathDegradingTimeout(); + + EXPECT_EQ(5u, connection_.sent_packet_manager().GetConsecutivePtoCount()); + // Closes connection on 6th PTO. + // May send multiple connecction close packets with multiple PN spaces. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + ASSERT_TRUE(connection_.BlackholeDetectionInProgress()); + connection_.GetBlackholeDetectorAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_TOO_MANY_RTOS); +} + +TEST_P(QuicConnectionTest, CloseConnectionAfter7ClientPTOs) { + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k2PTO); + connection_options.push_back(k7PTO); + config.SetConnectionOptionsToSend(connection_options); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2) || + GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed)) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + } + connection_.OnHandshakeComplete(); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Send stream data. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + + // Fire the retransmission alarm 6 times. + for (int i = 0; i < 6; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_TRUE(connection_.connected()); + } + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.PathDegradingTimeout(); + + EXPECT_EQ(6u, connection_.sent_packet_manager().GetConsecutivePtoCount()); + // Closes connection on 7th PTO. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + ASSERT_TRUE(connection_.BlackholeDetectionInProgress()); + connection_.GetBlackholeDetectorAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_TOO_MANY_RTOS); +} + +TEST_P(QuicConnectionTest, CloseConnectionAfter8ClientPTOs) { + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k2PTO); + connection_options.push_back(k8PTO); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + config.SetConnectionOptionsToSend(connection_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2) || + GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed)) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + } + connection_.OnHandshakeComplete(); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Send stream data. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + + // Fire the retransmission alarm 7 times. + for (int i = 0; i < 7; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_TRUE(connection_.connected()); + } + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.PathDegradingTimeout(); + + EXPECT_EQ(7u, connection_.sent_packet_manager().GetConsecutivePtoCount()); + // Closes connection on 8th PTO. + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AtLeast(1)); + ASSERT_TRUE(connection_.BlackholeDetectionInProgress()); + connection_.GetBlackholeDetectorAlarm()->Fire(); + EXPECT_FALSE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_TOO_MANY_RTOS); +} + +TEST_P(QuicConnectionTest, DeprecateHandshakeMode) { + if (!connection_.version().SupportsAntiAmplificationLimit()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Send CHLO. + connection_.SendCryptoStreamData(); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + QuicAckFrame frame1 = InitAckFrame(1); + // Received ACK for packet 1. + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + + // Verify retransmission alarm is still set because handshake is not + // confirmed although there is nothing in flight. + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_EQ(0u, connection_.GetStats().pto_count); + EXPECT_EQ(0u, connection_.GetStats().crypto_retransmit_count); + + // PTO fires, verify a PING packet gets sent because there is no data to send. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(3), _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_EQ(1u, connection_.GetStats().pto_count); + EXPECT_EQ(1u, connection_.GetStats().crypto_retransmit_count); + EXPECT_EQ(1u, writer_->ping_frames().size()); +} + +TEST_P(QuicConnectionTest, AntiAmplificationLimit) { + if (!connection_.version().SupportsAntiAmplificationLimit() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + + set_perspective(Perspective::IS_SERVER); + // Verify no data can be sent at the beginning because bytes received is 0. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.CanWrite(NO_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Receives packet 1. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + + const size_t anti_amplification_factor = + GetQuicFlag(quic_anti_amplification_factor); + // Verify now packets can be sent. + for (size_t i = 1; i < anti_amplification_factor; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoDataWithString("foo", i * 3); + // Verify retransmission alarm is not set if throttled by anti-amplification + // limit. + EXPECT_EQ(i != anti_amplification_factor - 1, + connection_.GetRetransmissionAlarm()->IsSet()); + } + // Verify server is throttled by anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3); + + // Receives packet 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + // Verify more packets can be sent. + for (size_t i = anti_amplification_factor + 1; + i < anti_amplification_factor * 2; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoDataWithString("foo", i * 3); + } + // Verify server is throttled by anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", + 2 * anti_amplification_factor * 3); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessPacket(3); + // Verify anti-amplification limit is gone after address validation. + for (size_t i = 0; i < 100; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(3, "first", i * 0, NO_FIN); + } +} + +TEST_P(QuicConnectionTest, 3AntiAmplificationLimit) { + if (!connection_.version().SupportsAntiAmplificationLimit() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + + set_perspective(Perspective::IS_SERVER); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k3AFF); + config.SetInitialReceivedConnectionOptions(connection_options); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId(&config, + QuicConnectionId()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + // Verify no data can be sent at the beginning because bytes received is 0. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.CanWrite(NO_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Receives packet 1. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + + const size_t anti_amplification_factor = 3; + // Verify now packets can be sent. + for (size_t i = 1; i < anti_amplification_factor; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoDataWithString("foo", i * 3); + // Verify retransmission alarm is not set if throttled by anti-amplification + // limit. + EXPECT_EQ(i != anti_amplification_factor - 1, + connection_.GetRetransmissionAlarm()->IsSet()); + } + // Verify server is throttled by anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3); + + // Receives packet 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + // Verify more packets can be sent. + for (size_t i = anti_amplification_factor + 1; + i < anti_amplification_factor * 2; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoDataWithString("foo", i * 3); + } + // Verify server is throttled by anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", + 2 * anti_amplification_factor * 3); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessPacket(3); + // Verify anti-amplification limit is gone after address validation. + for (size_t i = 0; i < 100; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(3, "first", i * 0, NO_FIN); + } +} + +TEST_P(QuicConnectionTest, 10AntiAmplificationLimit) { + if (!connection_.version().SupportsAntiAmplificationLimit() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + + set_perspective(Perspective::IS_SERVER); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k10AF); + config.SetInitialReceivedConnectionOptions(connection_options); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId(&config, + QuicConnectionId()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + // Verify no data can be sent at the beginning because bytes received is 0. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.CanWrite(NO_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Receives packet 1. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + + const size_t anti_amplification_factor = 10; + // Verify now packets can be sent. + for (size_t i = 1; i < anti_amplification_factor; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoDataWithString("foo", i * 3); + // Verify retransmission alarm is not set if throttled by anti-amplification + // limit. + EXPECT_EQ(i != anti_amplification_factor - 1, + connection_.GetRetransmissionAlarm()->IsSet()); + } + // Verify server is throttled by anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3); + + // Receives packet 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + // Verify more packets can be sent. + for (size_t i = anti_amplification_factor + 1; + i < anti_amplification_factor * 2; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendCryptoDataWithString("foo", i * 3); + } + // Verify server is throttled by anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", + 2 * anti_amplification_factor * 3); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessPacket(3); + // Verify anti-amplification limit is gone after address validation. + for (size_t i = 0; i < 100; ++i) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.SendStreamDataWithString(3, "first", i * 0, NO_FIN); + } +} + +TEST_P(QuicConnectionTest, AckPendingWithAmplificationLimited) { + if (!connection_.version().SupportsAntiAmplificationLimit()) { + return; + } + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(AnyNumber()); + set_perspective(Perspective::IS_SERVER); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Receives packet 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_TRUE(connection_.HasPendingAcks()); + // Send response in different encryption level and cause amplification factor + // throttled. + size_t i = 0; + while (connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)) { + connection_.SendCryptoDataWithString(std::string(1024, 'a'), i * 1024, + ENCRYPTION_HANDSHAKE); + ++i; + } + // Verify ACK is still pending. + EXPECT_TRUE(connection_.HasPendingAcks()); + + // Fire ACK alarm and verify ACK cannot be sent due to amplification factor. + clock_.AdvanceTime(connection_.GetAckAlarm()->deadline() - clock_.Now()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.GetAckAlarm()->Fire(); + // Verify ACK alarm is cancelled. + EXPECT_FALSE(connection_.HasPendingAcks()); + + // Receives packet 2 and verify ACK gets flushed. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + EXPECT_FALSE(writer_->ack_frames().empty()); +} + +TEST_P(QuicConnectionTest, ConnectionCloseFrameType) { + if (!VersionHasIetfQuicFrames(version().transport_version)) { + // Test relevent only for IETF QUIC. + return; + } + const QuicErrorCode kQuicErrorCode = IETF_QUIC_PROTOCOL_VIOLATION; + // Use the (unknown) frame type of 9999 to avoid triggering any logic + // which might be associated with the processing of a known frame type. + const uint64_t kTransportCloseFrameType = 9999u; + QuicFramerPeer::set_current_received_frame_type( + QuicConnectionPeer::GetFramer(&connection_), kTransportCloseFrameType); + // Do a transport connection close + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection( + kQuicErrorCode, "Some random error message", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + const std::vector& connection_close_frames = + writer_->connection_close_frames(); + ASSERT_EQ(1u, connection_close_frames.size()); + EXPECT_EQ(IETF_QUIC_TRANSPORT_CONNECTION_CLOSE, + connection_close_frames[0].close_type); + EXPECT_EQ(kQuicErrorCode, connection_close_frames[0].quic_error_code); + EXPECT_EQ(kTransportCloseFrameType, + connection_close_frames[0].transport_close_frame_type); +} + +TEST_P(QuicConnectionTest, PtoSkipsPacketNumber) { + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k1PTO); + connection_options.push_back(kPTOS); + config.SetConnectionOptionsToSend(connection_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + QuicStreamId stream_id = 2; + QuicPacketNumber last_packet; + SendStreamDataToPeer(stream_id, "foooooo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(stream_id, "foooooo", 7, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(2), last_packet); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Fire PTO and verify the PTO retransmission skips one packet number. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(QuicPacketNumber(4), writer_->last_packet_header().packet_number); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, SendCoalescedPackets) { + if (!connection_.version().CanSendCoalescedPackets()) { + return; + } + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)).Times(3); + EXPECT_CALL(debug_visitor, OnCoalescedPacketSent(_, _)).Times(1); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0); + // Verify this packet is on hold. + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString("bar", 3); + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x03)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + SendStreamDataToPeer(2, "baz", 3, NO_FIN, nullptr); + } + // Verify all 3 packets are coalesced in the same UDP datagram. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + // Verify the packet is padded to full. + EXPECT_EQ(connection_.max_packet_length(), writer_->last_packet_size()); + + // Verify packet process. + EXPECT_EQ(1u, writer_->crypto_frames().size()); + EXPECT_EQ(0u, writer_->stream_frames().size()); + // Verify there is coalesced packet. + EXPECT_NE(nullptr, writer_->coalesced_packet()); +} + +TEST_P(QuicConnectionTest, FailToCoalescePacket) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration() || + !connection_.version().CanSendCoalescedPackets() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + + set_perspective(Perspective::IS_SERVER); + + auto test_body = [&] { + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_INITIAL); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0); + // Verify this packet is on hold. + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString("bar", 3); + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x03)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + SendStreamDataToPeer(2, "baz", 3, NO_FIN, nullptr); + + creator_->Flush(); + + auto& coalesced_packet = + QuicConnectionPeer::GetCoalescedPacket(&connection_); + QuicPacketLength coalesced_packet_max_length = + coalesced_packet.max_packet_length(); + QuicCoalescedPacketPeer::SetMaxPacketLength(coalesced_packet, + coalesced_packet.length()); + + // Make the coalescer's FORWARD_SECURE packet longer. + *QuicCoalescedPacketPeer::GetMutableEncryptedBuffer( + coalesced_packet, ENCRYPTION_FORWARD_SECURE) += "!!! TEST !!!"; + + QUIC_LOG(INFO) << "Reduced coalesced_packet_max_length from " + << coalesced_packet_max_length << " to " + << coalesced_packet.max_packet_length() + << ", coalesced_packet.length:" + << coalesced_packet.length() + << ", coalesced_packet.packet_lengths:" + << absl::StrJoin(coalesced_packet.packet_lengths(), ":"); + } + + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(QUIC_FAILED_TO_SERIALIZE_PACKET)); + EXPECT_EQ(saved_connection_close_frame_.error_details, + "Failed to serialize coalesced packet."); + }; + + EXPECT_QUIC_BUG(test_body(), "SerializeCoalescedPacket failed."); +} + +TEST_P(QuicConnectionTest, ClientReceivedHandshakeDone) { + if (!connection_.version().UsesTls()) { + return; + } + EXPECT_CALL(visitor_, OnHandshakeDoneReceived()); + QuicFrames frames; + frames.push_back(QuicFrame(QuicHandshakeDoneFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + ProcessFramesPacketAtLevel(1, frames, ENCRYPTION_FORWARD_SECURE); +} + +TEST_P(QuicConnectionTest, ServerReceivedHandshakeDone) { + if (!connection_.version().UsesTls()) { + return; + } + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(visitor_, OnHandshakeDoneReceived()).Times(0); + if (version().handshake_protocol == PROTOCOL_TLS1_3) { + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + } + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + QuicFrames frames; + frames.push_back(QuicFrame(QuicHandshakeDoneFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + ProcessFramesPacketAtLevel(1, frames, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(1, connection_close_frame_count_); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_P(QuicConnectionTest, MultiplePacketNumberSpacePto) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // Send handshake packet. + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + EXPECT_EQ(0x01010101u, writer_->final_bytes_of_last_packet()); + + // Send application data. + connection_.SendApplicationDataAtLevel(ENCRYPTION_FORWARD_SECURE, 5, "data", + 0, NO_FIN); + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + QuicTime retransmission_time = + connection_.GetRetransmissionAlarm()->deadline(); + EXPECT_NE(QuicTime::Zero(), retransmission_time); + + // Retransmit handshake data. + clock_.AdvanceTime(retransmission_time - clock_.Now()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(4), _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify 1-RTT packet gets coalesced with handshake retransmission. + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + + // Send application data. + connection_.SendApplicationDataAtLevel(ENCRYPTION_FORWARD_SECURE, 5, "data", + 4, NO_FIN); + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + retransmission_time = connection_.GetRetransmissionAlarm()->deadline(); + EXPECT_NE(QuicTime::Zero(), retransmission_time); + + // Retransmit handshake data again. + clock_.AdvanceTime(retransmission_time - clock_.Now()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(9), _, _)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(8), _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify 1-RTT packet gets coalesced with handshake retransmission. + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + + // Discard handshake key. + connection_.OnHandshakeComplete(); + retransmission_time = connection_.GetRetransmissionAlarm()->deadline(); + EXPECT_NE(QuicTime::Zero(), retransmission_time); + + // Retransmit application data. + clock_.AdvanceTime(retransmission_time - clock_.Now()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(11), _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); +} + +void QuicConnectionTest::TestClientRetryHandling( + bool invalid_retry_tag, bool missing_original_id_in_config, + bool wrong_original_id_in_config, bool missing_retry_id_in_config, + bool wrong_retry_id_in_config) { + if (invalid_retry_tag) { + ASSERT_FALSE(missing_original_id_in_config); + ASSERT_FALSE(wrong_original_id_in_config); + ASSERT_FALSE(missing_retry_id_in_config); + ASSERT_FALSE(wrong_retry_id_in_config); + } else { + ASSERT_FALSE(missing_original_id_in_config && wrong_original_id_in_config); + ASSERT_FALSE(missing_retry_id_in_config && wrong_retry_id_in_config); + } + if (!version().UsesTls()) { + return; + } + + // These values come from draft-ietf-quic-v2 Appendix A.4. + uint8_t retry_packet_rfcv2[] = { + 0xcf, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, + 0x42, 0x62, 0xb5, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc8, 0x64, 0x6c, 0xe8, + 0xbf, 0xe3, 0x39, 0x52, 0xd9, 0x55, 0x54, 0x36, 0x65, 0xdc, 0xc7, 0xb6}; + // These values come from RFC9001 Appendix A.4. + uint8_t retry_packet_rfcv1[] = { + 0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, + 0x42, 0x62, 0xb5, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, + 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58, 0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba}; + uint8_t retry_packet29[] = { + 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, + 0x42, 0x62, 0xb5, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, + 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a, 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49}; + + uint8_t* retry_packet; + size_t retry_packet_length; + if (version() == ParsedQuicVersion::V2Draft08()) { + retry_packet = retry_packet_rfcv2; + retry_packet_length = ABSL_ARRAYSIZE(retry_packet_rfcv2); + } else if (version() == ParsedQuicVersion::RFCv1()) { + retry_packet = retry_packet_rfcv1; + retry_packet_length = ABSL_ARRAYSIZE(retry_packet_rfcv1); + } else if (version() == ParsedQuicVersion::Draft29()) { + retry_packet = retry_packet29; + retry_packet_length = ABSL_ARRAYSIZE(retry_packet29); + } else { + // TODO(dschinazi) generate retry packets for all versions once we have + // server-side support for generating these programmatically. + return; + } + + uint8_t original_connection_id_bytes[] = {0x83, 0x94, 0xc8, 0xf0, + 0x3e, 0x51, 0x57, 0x08}; + uint8_t new_connection_id_bytes[] = {0xf0, 0x67, 0xa5, 0x50, + 0x2a, 0x42, 0x62, 0xb5}; + uint8_t retry_token_bytes[] = {0x74, 0x6f, 0x6b, 0x65, 0x6e}; + + QuicConnectionId original_connection_id( + reinterpret_cast(original_connection_id_bytes), + ABSL_ARRAYSIZE(original_connection_id_bytes)); + QuicConnectionId new_connection_id( + reinterpret_cast(new_connection_id_bytes), + ABSL_ARRAYSIZE(new_connection_id_bytes)); + + std::string retry_token(reinterpret_cast(retry_token_bytes), + ABSL_ARRAYSIZE(retry_token_bytes)); + + if (invalid_retry_tag) { + // Flip the last bit of the retry packet to prevent the integrity tag + // from validating correctly. + retry_packet[retry_packet_length - 1] ^= 1; + } + + QuicConnectionId config_original_connection_id = original_connection_id; + if (wrong_original_id_in_config) { + // Flip the first bit of the connection ID. + ASSERT_FALSE(config_original_connection_id.IsEmpty()); + config_original_connection_id.mutable_data()[0] ^= 0x80; + } + QuicConnectionId config_retry_source_connection_id = new_connection_id; + if (wrong_retry_id_in_config) { + // Flip the first bit of the connection ID. + ASSERT_FALSE(config_retry_source_connection_id.IsEmpty()); + config_retry_source_connection_id.mutable_data()[0] ^= 0x80; + } + + // Make sure the connection uses the connection ID from the test vectors, + QuicConnectionPeer::SetServerConnectionId(&connection_, + original_connection_id); + // Make sure our fake framer has the new post-retry INITIAL keys so that any + // retransmission triggered by retry can be decrypted. + writer_->framer()->framer()->SetInitialObfuscators(new_connection_id); + + // Process the RETRY packet. + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(reinterpret_cast(retry_packet), + retry_packet_length, clock_.Now())); + + if (invalid_retry_tag) { + // Make sure we refuse to process a RETRY with invalid tag. + EXPECT_FALSE(connection_.GetStats().retry_packet_processed); + EXPECT_EQ(connection_.connection_id(), original_connection_id); + EXPECT_TRUE(QuicPacketCreatorPeer::GetRetryToken( + QuicConnectionPeer::GetPacketCreator(&connection_)) + .empty()); + return; + } + + // Make sure we correctly parsed the RETRY. + EXPECT_TRUE(connection_.GetStats().retry_packet_processed); + EXPECT_EQ(connection_.connection_id(), new_connection_id); + EXPECT_EQ(QuicPacketCreatorPeer::GetRetryToken( + QuicConnectionPeer::GetPacketCreator(&connection_)), + retry_token); + + // Test validating the original_connection_id from the config. + QuicConfig received_config; + QuicConfigPeer::SetNegotiated(&received_config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &received_config, connection_.connection_id()); + if (!missing_retry_id_in_config) { + QuicConfigPeer::SetReceivedRetrySourceConnectionId( + &received_config, config_retry_source_connection_id); + } + } + if (!missing_original_id_in_config) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &received_config, config_original_connection_id); + } + + if (missing_original_id_in_config || wrong_original_id_in_config || + missing_retry_id_in_config || wrong_retry_id_in_config) { + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + } else { + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(0); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(AnyNumber()); + connection_.SetFromConfig(received_config); + if (missing_original_id_in_config || wrong_original_id_in_config || + missing_retry_id_in_config || wrong_retry_id_in_config) { + ASSERT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); + } else { + EXPECT_TRUE(connection_.connected()); + } +} + +TEST_P(QuicConnectionTest, ClientParsesRetry) { + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); +} + +TEST_P(QuicConnectionTest, ClientParsesRetryInvalidTag) { + TestClientRetryHandling(/*invalid_retry_tag=*/true, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); +} + +TEST_P(QuicConnectionTest, ClientParsesRetryMissingOriginalId) { + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/true, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); +} + +TEST_P(QuicConnectionTest, ClientParsesRetryWrongOriginalId) { + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/true, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); +} + +TEST_P(QuicConnectionTest, ClientParsesRetryMissingRetryId) { + if (!connection_.version().UsesTls()) { + // Versions that do not authenticate connection IDs never send the + // retry_source_connection_id transport parameter. + return; + } + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/true, + /*wrong_retry_id_in_config=*/false); +} + +TEST_P(QuicConnectionTest, ClientParsesRetryWrongRetryId) { + if (!connection_.version().UsesTls()) { + // Versions that do not authenticate connection IDs never send the + // retry_source_connection_id transport parameter. + return; + } + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/true); +} + +TEST_P(QuicConnectionTest, ClientRetransmitsInitialPacketsOnRetry) { + if (!connection_.version().HasIetfQuicFrames()) { + // TestClientRetryHandling() currently only supports IETF draft versions. + return; + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + + connection_.SendCryptoStreamData(); + + EXPECT_EQ(1u, writer_->packets_write_attempts()); + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); + + // Verify that initial data is retransmitted immediately after receiving + // RETRY. + if (GetParam().ack_response == AckResponse::kImmediate) { + EXPECT_EQ(2u, writer_->packets_write_attempts()); + EXPECT_EQ(1u, writer_->framer()->crypto_frames().size()); + } +} + +TEST_P(QuicConnectionTest, NoInitialPacketsRetransmissionOnInvalidRetry) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + + connection_.SendCryptoStreamData(); + + EXPECT_EQ(1u, writer_->packets_write_attempts()); + TestClientRetryHandling(/*invalid_retry_tag=*/true, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); + + EXPECT_EQ(1u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, ClientReceivesOriginalConnectionIdWithoutRetry) { + if (!connection_.version().UsesTls()) { + // QUIC+TLS is required to transmit connection ID transport parameters. + return; + } + if (connection_.version().UsesTls()) { + // Versions that authenticate connection IDs always send the + // original_destination_connection_id transport parameter. + return; + } + // Make sure that receiving the original_destination_connection_id transport + // parameter fails the handshake when no RETRY packet was received before it. + QuicConfig received_config; + QuicConfigPeer::SetNegotiated(&received_config, true); + QuicConfigPeer::SetReceivedOriginalConnectionId(&received_config, + TestConnectionId(0x12345)); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + connection_.SetFromConfig(received_config); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); +} + +TEST_P(QuicConnectionTest, ClientReceivesRetrySourceConnectionIdWithoutRetry) { + if (!connection_.version().UsesTls()) { + // Versions that do not authenticate connection IDs never send the + // retry_source_connection_id transport parameter. + return; + } + // Make sure that receiving the retry_source_connection_id transport parameter + // fails the handshake when no RETRY packet was received before it. + QuicConfig received_config; + QuicConfigPeer::SetNegotiated(&received_config, true); + QuicConfigPeer::SetReceivedRetrySourceConnectionId(&received_config, + TestConnectionId(0x12345)); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(1); + connection_.SetFromConfig(received_config); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); +} + +// Regression test for http://crbug/1047977 +TEST_P(QuicConnectionTest, MaxStreamsFrameCausesConnectionClose) { + if (!VersionHasIetfQuicFrames(connection_.transport_version())) { + return; + } + // Received frame causes connection close. + EXPECT_CALL(visitor_, OnMaxStreamsFrame(_)) + .WillOnce(InvokeWithoutArgs([this]() { + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection( + QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES, "error", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return true; + })); + QuicFrames frames; + frames.push_back(QuicFrame(QuicMaxStreamsFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + ProcessFramesPacketAtLevel(1, frames, ENCRYPTION_FORWARD_SECURE); +} + +TEST_P(QuicConnectionTest, StreamsBlockedFrameCausesConnectionClose) { + if (!VersionHasIetfQuicFrames(connection_.transport_version())) { + return; + } + // Received frame causes connection close. + EXPECT_CALL(visitor_, OnStreamsBlockedFrame(_)) + .WillOnce(InvokeWithoutArgs([this]() { + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection( + QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES, "error", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return true; + })); + QuicFrames frames; + frames.push_back( + QuicFrame(QuicStreamsBlockedFrame(kInvalidControlFrameId, 10, false))); + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + ProcessFramesPacketAtLevel(1, frames, ENCRYPTION_FORWARD_SECURE); +} + +TEST_P(QuicConnectionTest, + BundleAckWithConnectionCloseMultiplePacketNumberSpace) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Receives packet 1000 in initial data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + // Receives packet 2000 in application data. + ProcessDataPacketAtLevel(2000, false, ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + const QuicErrorCode kQuicErrorCode = QUIC_INTERNAL_ERROR; + connection_.CloseConnection( + kQuicErrorCode, "Some random error message", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + + EXPECT_EQ(2u, QuicConnectionPeer::GetNumEncryptionLevels(&connection_)); + + TestConnectionCloseQuicErrorCode(kQuicErrorCode); + EXPECT_EQ(1u, writer_->connection_close_frames().size()); + // Verify ack is bundled. + EXPECT_EQ(1u, writer_->ack_frames().size()); + + if (!connection_.version().CanSendCoalescedPackets()) { + // Each connection close packet should be sent in distinct UDP packets. + EXPECT_EQ(QuicConnectionPeer::GetNumEncryptionLevels(&connection_), + writer_->connection_close_packets()); + EXPECT_EQ(QuicConnectionPeer::GetNumEncryptionLevels(&connection_), + writer_->packets_write_attempts()); + return; + } + + // A single UDP packet should be sent with multiple connection close packets + // coalesced together. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Only the first packet has been processed yet. + EXPECT_EQ(1u, writer_->connection_close_packets()); + + // ProcessPacket resets the visitor and frees the coalesced packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(1u, writer_->connection_close_packets()); + EXPECT_EQ(1u, writer_->connection_close_frames().size()); + // Verify ack is bundled. + EXPECT_EQ(1u, writer_->ack_frames().size()); + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); +} + +// Regression test for b/151220135. +TEST_P(QuicConnectionTest, SendPingWhenSkipPacketNumberForPto) { + if (!VersionSupportsMessageFrames(connection_.transport_version())) { + return; + } + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kPTOS); + connection_options.push_back(k1PTO); + config.SetConnectionOptionsToSend(connection_options); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedMaxDatagramFrameSize( + &config, kMaxAcceptedDatagramFrameSize); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + connection_.OnHandshakeComplete(); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, SendMessage("message")); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // PTO fires, verify a PING packet gets sent because there is no data to + // send. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, QuicPacketNumber(3), _, _)); + connection_.GetRetransmissionAlarm()->Fire(); + EXPECT_EQ(1u, connection_.GetStats().pto_count); + EXPECT_EQ(0u, connection_.GetStats().crypto_retransmit_count); + EXPECT_EQ(1u, writer_->ping_frames().size()); +} + +// Regression test for b/155757133 +TEST_P(QuicConnectionTest, DonotChangeQueuedAcks) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + + ProcessPacket(2); + ProcessPacket(3); + ProcessPacket(4); + // Process a packet containing stream frame followed by ACK of packets 1. + QuicFrames frames; + frames.push_back(QuicFrame(QuicStreamFrame( + QuicUtils::GetFirstBidirectionalStreamId( + connection_.version().transport_version, Perspective::IS_CLIENT), + false, 0u, absl::string_view()))); + QuicAckFrame ack_frame = InitAckFrame(1); + frames.push_back(QuicFrame(&ack_frame)); + // Receiving stream frame causes something to send. + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([this]() { + connection_.SendControlFrame(QuicFrame(QuicWindowUpdateFrame(1, 0, 0))); + // Verify now the queued ACK contains packet number 2. + EXPECT_TRUE(QuicPacketCreatorPeer::QueuedFrames( + QuicConnectionPeer::GetPacketCreator(&connection_))[0] + .ack_frame->packets.Contains(QuicPacketNumber(2))); + })); + ProcessFramesPacketAtLevel(9, frames, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(writer_->ack_frames()[0].packets.Contains(QuicPacketNumber(2))); +} + +TEST_P(QuicConnectionTest, DoNotExtendIdleTimeOnUndecryptablePackets) { + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + // Subtract a second from the idle timeout on the client side. + QuicTime initial_deadline = + clock_.ApproximateNow() + + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs - 1); + EXPECT_EQ(initial_deadline, connection_.GetTimeoutAlarm()->deadline()); + + // Received an undecryptable packet. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(Perspective::IS_CLIENT)); + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + // Verify deadline does not get extended. + EXPECT_EQ(initial_deadline, connection_.GetTimeoutAlarm()->deadline()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(1); + QuicTime::Delta delay = initial_deadline - clock_.ApproximateNow(); + clock_.AdvanceTime(delay); + connection_.GetTimeoutAlarm()->Fire(); + // Verify connection gets closed. + EXPECT_FALSE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, BundleAckWithImmediateResponse) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([this]() { + notifier_.WriteOrBufferWindowUpate(0, 0); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + ProcessDataPacket(1); + // Verify ACK is bundled with WINDOW_UPDATE. + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, AckAlarmFiresEarly) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Receives packet 1000 in initial data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + // Receives packet 1000 in application data. + ProcessDataPacketAtLevel(1000, false, ENCRYPTION_ZERO_RTT); + EXPECT_TRUE(connection_.HasPendingAcks()); + // Verify ACK deadline does not change. + EXPECT_EQ(clock_.ApproximateNow() + kAlarmGranularity, + connection_.GetAckAlarm()->deadline()); + + // Ack alarm fires early. + // Verify the earliest ACK is flushed. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetAckAlarm()->Fire(); + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(clock_.ApproximateNow() + DefaultDelayedAckTime(), + connection_.GetAckAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, ClientOnlyBlackholeDetectionClient) { + if (!GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2)) { + return; + } + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kCBHD); + config.SetConnectionOptionsToSend(connection_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + EXPECT_FALSE(connection_.GetBlackholeDetectorAlarm()->IsSet()); + // Send stream data. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + // Verify blackhole detection is in progress. + EXPECT_TRUE(connection_.GetBlackholeDetectorAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, ClientOnlyBlackholeDetectionServer) { + if (!GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2)) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(kCBHD); + config.SetInitialReceivedConnectionOptions(connection_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + EXPECT_FALSE(connection_.GetBlackholeDetectorAlarm()->IsSet()); + // Send stream data. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + // Verify blackhole detection is disabled. + EXPECT_FALSE(connection_.GetBlackholeDetectorAlarm()->IsSet()); +} + +// Regresstion test for b/158491591. +TEST_P(QuicConnectionTest, MadeForwardProgressOnDiscardingKeys) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // Send handshake packet. + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + QuicConfig config; + QuicTagVector connection_options; + connection_options.push_back(k5RTO); + config.SetConnectionOptionsToSend(connection_options); + QuicConfigPeer::SetNegotiated(&config, true); + if (GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2) || + GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed)) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + } + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + if (GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed)) { + // No blackhole detection before handshake confirmed. + EXPECT_FALSE(connection_.BlackholeDetectionInProgress()); + } else { + EXPECT_TRUE(connection_.BlackholeDetectionInProgress()); + } + // Discard handshake keys. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + if (GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2) || + GetQuicReloadableFlag( + quic_no_path_degrading_before_handshake_confirmed)) { + // Verify blackhole detection stops. + EXPECT_FALSE(connection_.BlackholeDetectionInProgress()); + } else { + // Problematic: although there is nothing in flight, blackhole detection is + // still in progress. + EXPECT_TRUE(connection_.BlackholeDetectionInProgress()); + } +} + +TEST_P(QuicConnectionTest, ProcessUndecryptablePacketsBasedOnEncryptionLevel) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(AnyNumber()); + QuicConfig config; + connection_.SetFromConfig(config); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + + peer_framer_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + + for (uint64_t i = 1; i <= 3; ++i) { + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + } + ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + for (uint64_t j = 5; j <= 7; ++j) { + ProcessDataPacketAtLevel(j, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + } + EXPECT_EQ(7u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + EXPECT_FALSE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + // Verify all ENCRYPTION_HANDSHAKE packets get processed. + if (!VersionHasIetfQuicFrames(version().transport_version)) { + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(6); + } + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + EXPECT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + // Verify the 1-RTT packet gets processed. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); +} + +TEST_P(QuicConnectionTest, ServerBundlesInitialDataWithInitialAck) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Receives packet 1000 in initial data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + QuicTime expected_pto_time = + connection_.sent_packet_manager().GetRetransmissionTime(); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + // Verify PTO time does not change. + EXPECT_EQ(expected_pto_time, + connection_.sent_packet_manager().GetRetransmissionTime()); + + // Receives packet 1001 in initial data. + ProcessCryptoPacketAtLevel(1001, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + // Receives packet 1002 in initial data. + ProcessCryptoPacketAtLevel(1002, ENCRYPTION_INITIAL); + EXPECT_FALSE(writer_->ack_frames().empty()); + // Verify CRYPTO frame is bundled with INITIAL ACK. + EXPECT_FALSE(writer_->crypto_frames().empty()); + // Verify PTO time changes. + EXPECT_NE(expected_pto_time, + connection_.sent_packet_manager().GetRetransmissionTime()); +} + +TEST_P(QuicConnectionTest, ClientBundlesHandshakeDataWithHandshakeAck) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + peer_framer_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + // Receives packet 1000 in handshake data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_HANDSHAKE); + EXPECT_TRUE(connection_.HasPendingAcks()); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + + // Receives packet 1001 in handshake data. + ProcessCryptoPacketAtLevel(1001, ENCRYPTION_HANDSHAKE); + EXPECT_TRUE(connection_.HasPendingAcks()); + // Receives packet 1002 in handshake data. + ProcessCryptoPacketAtLevel(1002, ENCRYPTION_HANDSHAKE); + EXPECT_FALSE(writer_->ack_frames().empty()); + // Verify CRYPTO frame is bundled with HANDSHAKE ACK. + EXPECT_FALSE(writer_->crypto_frames().empty()); +} + +// Regresstion test for b/156232673. +TEST_P(QuicConnectionTest, CoalescePacketOfLowerEncryptionLevel) { + if (!connection_.version().CanSendCoalescedPackets()) { + return; + } + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + SendStreamDataToPeer(2, std::string(1286, 'a'), 0, NO_FIN, nullptr); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + // Try to coalesce a HANDSHAKE packet after 1-RTT packet. + // Verify soft max packet length gets resumed and handshake packet gets + // successfully sent. + connection_.SendCryptoDataWithString("a", 0, ENCRYPTION_HANDSHAKE); + } +} + +// Regression test for b/160790422. +TEST_P(QuicConnectionTest, ServerRetransmitsHandshakeDataEarly) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Receives packet 1000 in initial data. + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send INITIAL 1. + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + QuicTime expected_pto_time = + connection_.sent_packet_manager().GetRetransmissionTime(); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + // Send HANDSHAKE 2 and 3. + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString("bar", 3, ENCRYPTION_HANDSHAKE); + // Verify PTO time does not change. + EXPECT_EQ(expected_pto_time, + connection_.sent_packet_manager().GetRetransmissionTime()); + + // Receives ACK for HANDSHAKE 2. + QuicFrames frames; + auto ack_frame = InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + frames.push_back(QuicFrame(&ack_frame)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramesPacketAtLevel(30, frames, ENCRYPTION_HANDSHAKE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + // Receives PING from peer. + frames.clear(); + frames.push_back(QuicFrame(QuicPingFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(3))); + ProcessFramesPacketAtLevel(31, frames, ENCRYPTION_HANDSHAKE); + EXPECT_EQ(clock_.Now() + kAlarmGranularity, + connection_.GetAckAlarm()->deadline()); + // Fire ACK alarm. + clock_.AdvanceTime(kAlarmGranularity); + connection_.GetAckAlarm()->Fire(); + EXPECT_FALSE(writer_->ack_frames().empty()); + // Verify handshake data gets retransmitted early. + EXPECT_FALSE(writer_->crypto_frames().empty()); +} + +// Regression test for b/161228202 +TEST_P(QuicConnectionTest, InflatedRttSample) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // 30ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(30); + set_perspective(Perspective::IS_SERVER); + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + // Receives packet 1000 in initial data. + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send INITIAL 1. + std::string initial_crypto_data(512, 'a'); + connection_.SendCryptoDataWithString(initial_crypto_data, 0, + ENCRYPTION_INITIAL); + ASSERT_TRUE(connection_.sent_packet_manager() + .GetRetransmissionTime() + .IsInitialized()); + QuicTime::Delta pto_timeout = + connection_.sent_packet_manager().GetRetransmissionTime() - clock_.Now(); + // Send Handshake 2. + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + std::string handshake_crypto_data(1024, 'a'); + connection_.SendCryptoDataWithString(handshake_crypto_data, 0, + ENCRYPTION_HANDSHAKE); + + // INITIAL 1 gets lost and PTO fires. + clock_.AdvanceTime(pto_timeout); + connection_.GetRetransmissionAlarm()->Fire(); + + clock_.AdvanceTime(kTestRTT); + // Assume retransmitted INITIAL gets received. + QuicFrames frames; + auto ack_frame = InitAckFrame({{QuicPacketNumber(4), QuicPacketNumber(5)}}); + frames.push_back(QuicFrame(&ack_frame)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + ProcessFramesPacketAtLevel(1001, frames, ENCRYPTION_INITIAL); + EXPECT_EQ(kTestRTT, rtt_stats->latest_rtt()); + // Because retransmitted INITIAL gets received so HANDSHAKE 2 gets processed. + frames.clear(); + // HANDSHAKE 5 is also processed. + QuicAckFrame ack_frame2 = + InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}, + {QuicPacketNumber(5), QuicPacketNumber(6)}}); + ack_frame2.ack_delay_time = QuicTime::Delta::Zero(); + frames.push_back(QuicFrame(&ack_frame2)); + ProcessFramesPacketAtLevel(1, frames, ENCRYPTION_HANDSHAKE); + // Verify RTT inflation gets mitigated. + EXPECT_EQ(rtt_stats->latest_rtt(), kTestRTT); +} + +// Regression test for b/161228202 +TEST_P(QuicConnectionTest, CoalescingPacketCausesInfiniteLoop) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + // Receives packet 1000 in initial data. + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + + // Set anti amplification factor to 2, such that RetransmitDataOfSpaceIfAny + // makes no forward progress and causes infinite loop. + SetQuicFlag(quic_anti_amplification_factor, 2); + + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send INITIAL 1. + std::string initial_crypto_data(512, 'a'); + connection_.SendCryptoDataWithString(initial_crypto_data, 0, + ENCRYPTION_INITIAL); + ASSERT_TRUE(connection_.sent_packet_manager() + .GetRetransmissionTime() + .IsInitialized()); + QuicTime::Delta pto_timeout = + connection_.sent_packet_manager().GetRetransmissionTime() - clock_.Now(); + // Send Handshake 2. + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + // Verify HANDSHAKE packet is coalesced with INITIAL retransmission. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + std::string handshake_crypto_data(1024, 'a'); + connection_.SendCryptoDataWithString(handshake_crypto_data, 0, + ENCRYPTION_HANDSHAKE); + + // INITIAL 1 gets lost and PTO fires. + clock_.AdvanceTime(pto_timeout); + connection_.GetRetransmissionAlarm()->Fire(); +} + +TEST_P(QuicConnectionTest, ClientAckDelayForAsyncPacketProcessing) { + if (!version().HasIetfQuicFrames()) { + return; + } + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).WillOnce(Invoke([this]() { + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + })); + QuicConfig config; + connection_.SetFromConfig(config); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + peer_framer_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + + // Received undecryptable HANDSHAKE 2. + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + ASSERT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + // Received INITIAL 4 (which is retransmission of INITIAL 1) after 100ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_INITIAL); + // Generate HANDSHAKE key. + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + // Verify HANDSHAKE packet gets processed. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + // Verify immediate ACK has been sent out when flush went out of scope. + ASSERT_FALSE(connection_.HasPendingAcks()); + ASSERT_FALSE(writer_->ack_frames().empty()); + // Verify the ack_delay_time in the sent HANDSHAKE ACK frame is 100ms. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(100), + writer_->ack_frames()[0].ack_delay_time); + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); +} + +TEST_P(QuicConnectionTest, TestingLiveness) { + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + + CryptoHandshakeMessage msg; + std::string error_details; + QuicConfig client_config; + client_config.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + client_config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + client_config.SetIdleNetworkTimeout(QuicTime::Delta::FromSeconds(30)); + client_config.ToHandshakeMessage(&msg, connection_.transport_version()); + const QuicErrorCode error = + config.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + + connection_.SetFromConfig(config); + connection_.OnHandshakeComplete(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + ASSERT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.MaybeTestLiveness()); + + QuicTime deadline = QuicConnectionPeer::GetIdleNetworkDeadline(&connection_); + QuicTime::Delta timeout = deadline - clock_.ApproximateNow(); + // Advance time to near the idle timeout. + clock_.AdvanceTime(timeout - QuicTime::Delta::FromMilliseconds(1)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_TRUE(connection_.MaybeTestLiveness()); + // Verify idle deadline does not change. + EXPECT_EQ(deadline, QuicConnectionPeer::GetIdleNetworkDeadline(&connection_)); +} + +TEST_P(QuicConnectionTest, DisableLivenessTesting) { + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + + CryptoHandshakeMessage msg; + std::string error_details; + QuicConfig client_config; + client_config.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + client_config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + client_config.SetIdleNetworkTimeout(QuicTime::Delta::FromSeconds(30)); + client_config.ToHandshakeMessage(&msg, connection_.transport_version()); + const QuicErrorCode error = + config.ProcessPeerHello(msg, CLIENT, &error_details); + EXPECT_THAT(error, IsQuicNoError()); + + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + + connection_.SetFromConfig(config); + connection_.OnHandshakeComplete(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.DisableLivenessTesting(); + ASSERT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_FALSE(connection_.MaybeTestLiveness()); + + QuicTime deadline = QuicConnectionPeer::GetIdleNetworkDeadline(&connection_); + QuicTime::Delta timeout = deadline - clock_.ApproximateNow(); + // Advance time to near the idle timeout. + clock_.AdvanceTime(timeout - QuicTime::Delta::FromMilliseconds(1)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + EXPECT_FALSE(connection_.MaybeTestLiveness()); +} + +TEST_P(QuicConnectionTest, SilentIdleTimeout) { + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + + QuicConfig config; + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId(&config, + QuicConnectionId()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + EXPECT_TRUE(connection_.connected()); + EXPECT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + + if (version().handshake_protocol == PROTOCOL_TLS1_3) { + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + } + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + if (!QuicConnectionPeer::GetBandwidthUpdateTimeout(&connection_) + .IsInfinite()) { + // Fires the bandwidth update. + connection_.GetTimeoutAlarm()->Fire(); + } + connection_.GetTimeoutAlarm()->Fire(); + // Verify the connection close packets get serialized and added to + // termination packets list. + EXPECT_NE(nullptr, + QuicConnectionPeer::GetConnectionClosePacket(&connection_)); +} + +TEST_P(QuicConnectionTest, DoNotSendPing) { + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + EXPECT_FALSE(connection_.GetPingAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(0, connection_.transport_version()), + "GET /", 0, FIN, nullptr); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(15), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + // Now recevie an ACK and response of the previous packet, which will move the + // ping alarm forward. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + QuicFrames frames; + QuicAckFrame ack_frame = InitAckFrame(1); + frames.push_back(QuicFrame(&ack_frame)); + frames.push_back(QuicFrame(QuicStreamFrame( + GetNthClientInitiatedStreamId(0, connection_.transport_version()), true, + 0u, absl::string_view()))); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessFramesPacketAtLevel(1, frames, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.GetPingAlarm()->IsSet()); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + // The ping timer is set slightly less than 15 seconds in the future, because + // of the 1s ping timer alarm granularity. + EXPECT_EQ( + QuicTime::Delta::FromSeconds(15) - QuicTime::Delta::FromMilliseconds(5), + connection_.GetPingAlarm()->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(15)); + // Suppose now ShouldKeepConnectionAlive returns false. + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(false)); + // Verify PING does not get sent. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.GetPingAlarm()->Fire(); +} + +// Regression test for b/159698337 +TEST_P(QuicConnectionTest, DuplicateAckCausesLostPackets) { + if (!GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2)) { + return; + } + // Finish handshake. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + notifier_.NeuterUnencryptedData(); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + std::string data(1200, 'a'); + // Send data packets 1 - 5. + for (size_t i = 0; i < 5; ++i) { + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), data, + i * 1200, i == 4 ? FIN : NO_FIN, nullptr); + } + ASSERT_TRUE(connection_.BlackholeDetectionInProgress()); + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(3); + + // ACK packet 5 and 1 and 2 are detected lost. + QuicAckFrame frame = + InitAckFrame({{QuicPacketNumber(5), QuicPacketNumber(6)}}); + LostPacketVector lost_packets; + lost_packets.push_back( + LostPacket(QuicPacketNumber(1), kMaxOutgoingPacketSize)); + lost_packets.push_back( + LostPacket(QuicPacketNumber(2), kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .Times(AnyNumber()) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + ProcessAckPacket(1, &frame); + EXPECT_TRUE(connection_.BlackholeDetectionInProgress()); + QuicAlarm* retransmission_alarm = connection_.GetRetransmissionAlarm(); + EXPECT_TRUE(retransmission_alarm->IsSet()); + + // ACK packet 1 - 5 and 7. + QuicAckFrame frame2 = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(6)}, + {QuicPacketNumber(7), QuicPacketNumber(8)}}); + ProcessAckPacket(2, &frame2); + EXPECT_TRUE(connection_.BlackholeDetectionInProgress()); + + // ACK packet 7 again and assume packet 6 is detected lost. + QuicAckFrame frame3 = + InitAckFrame({{QuicPacketNumber(7), QuicPacketNumber(8)}}); + lost_packets.clear(); + lost_packets.push_back( + LostPacket(QuicPacketNumber(6), kMaxOutgoingPacketSize)); + EXPECT_CALL(*loss_algorithm_, DetectLosses(_, _, _, _, _, _)) + .Times(AnyNumber()) + .WillOnce(DoAll(SetArgPointee<5>(lost_packets), + Return(LossDetectionInterface::DetectionStats()))); + ProcessAckPacket(3, &frame3); + // Make sure loss detection is cancelled even there is no new acked packets. + EXPECT_FALSE(connection_.BlackholeDetectionInProgress()); +} + +TEST_P(QuicConnectionTest, ShorterIdleTimeoutOnSentPackets) { + EXPECT_TRUE(connection_.connected()); + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.SetClientConnectionOptions(QuicTagVector{kFIDT}); + QuicConfigPeer::SetNegotiated(&config, true); + if (GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2)) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + } + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + connection_.SetFromConfig(config); + + ASSERT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + // Send a packet close to timeout. + QuicTime::Delta timeout = + connection_.GetTimeoutAlarm()->deadline() - clock_.Now(); + clock_.AdvanceTime(timeout - QuicTime::Delta::FromSeconds(1)); + // Send stream data. + SendStreamDataToPeer( + GetNthClientInitiatedStreamId(1, connection_.transport_version()), "foo", + 0, FIN, nullptr); + // Verify this sent packet does not extend idle timeout since 1s is > PTO + // delay. + ASSERT_TRUE(connection_.GetTimeoutAlarm()->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(1), + connection_.GetTimeoutAlarm()->deadline() - clock_.Now()); + + // Received an ACK 100ms later. + clock_.AdvanceTime(timeout - QuicTime::Delta::FromMilliseconds(100)); + QuicAckFrame ack = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + ProcessAckPacket(1, &ack); + // Verify idle timeout gets extended. + EXPECT_EQ(clock_.Now() + timeout, connection_.GetTimeoutAlarm()->deadline()); +} + +// Regression test for b/166255274 +TEST_P(QuicConnectionTest, + ReserializeInitialPacketInCoalescerAfterDiscardingInitialKey) { + if (!connection_.version().CanSendCoalescedPackets()) { + return; + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).WillOnce(Invoke([this]() { + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + })); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + // Verify the packet is on hold. + EXPECT_EQ(0u, writer_->packets_write_attempts()); + // Flush pending ACKs. + connection_.GetAckAlarm()->Fire(); + } + EXPECT_FALSE(connection_.packet_creator().HasPendingFrames()); + // The ACK frame is deleted along with initial_packet_ in coalescer. Sending + // connection close would cause this (released) ACK frame be serialized (and + // crashes). + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1000, false, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, PathValidationOnNewSocketSuccess) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + bool success = false; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame( + 99, new_writer.path_challenge_frames().front().data_buffer))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(success); +} + +TEST_P(QuicConnectionTest, NewPathValidationCancelsPreviousOne) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + bool success = true; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + // Start another path validation request. + const QuicSocketAddress kNewSelfAddress2(QuicIpAddress::Any4(), 12346); + EXPECT_NE(kNewSelfAddress2, connection_.self_address()); + TestPacketWriter new_writer2(version(), &clock_, Perspective::IS_CLIENT); + if (!connection_.connection_migration_use_new_cid()) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer2.packets_write_attempts()); + EXPECT_EQ(1u, new_writer2.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer2.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress2.host(), + new_writer2.last_write_source_address()); + })); + } + bool success2 = false; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress2, connection_.peer_address(), &new_writer2), + std::make_unique( + &connection_, kNewSelfAddress2, connection_.peer_address(), + &success2), + PathValidationReason::kReasonUnknown); + EXPECT_FALSE(success); + if (connection_.connection_migration_use_new_cid()) { + // There is no pening path validation as there is no available connection + // ID. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + } else { + EXPECT_TRUE(connection_.HasPendingPathValidation()); + } +} + +// Regression test for b/182571515. +TEST_P(QuicConnectionTest, PathValidationRetry) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(2u) + .WillRepeatedly(Invoke([&]() { + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + })); + bool success = true; + connection_.ValidatePath(std::make_unique( + connection_.self_address(), + connection_.peer_address(), writer_.get()), + std::make_unique( + &connection_, connection_.self_address(), + connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + + // Retry after time out. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast(helper_->GetRandomGenerator())->ChangeValue(); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + EXPECT_EQ(2u, writer_->packets_write_attempts()); +} + +TEST_P(QuicConnectionTest, PathValidationReceivesStatelessReset) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + QuicConfig config; + QuicConfigPeer::SetReceivedStatelessResetToken(&config, + kTestStatelessResetToken); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + bool success = true; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_EQ(0u, writer_->packets_write_attempts()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + + std::unique_ptr packet( + QuicFramer::BuildIetfStatelessResetPacket(connection_id_, + /*received_packet_length=*/100, + kTestStatelessResetToken)); + std::unique_ptr received( + ConstructReceivedPacket(*packet, QuicTime::Zero())); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)).Times(0); + connection_.ProcessUdpPacket(kNewSelfAddress, kPeerAddress, *received); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(success); +} + +// Tests that PATH_CHALLENGE is dropped if it is sent via a blocked alternative +// writer. +TEST_P(QuicConnectionTest, SendPathChallengeUsingBlockedNewSocket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version) || + !connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + new_writer.BlockOnNextWrite(); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(0); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1)) + .WillOnce(Invoke([&]() { + // Even though the socket is blocked, the PATH_CHALLENGE should still be + // treated as sent. + EXPECT_EQ(1u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + bool success = false; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_EQ(0u, writer_->packets_write_attempts()); + + new_writer.SetWritable(); + // Write event on the default socket shouldn't make any difference. + connection_.OnCanWrite(); + // A NEW_CONNECTION_ID frame is received in PathProbeTestInit and OnCanWrite + // will write a acking packet. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_EQ(1u, new_writer.packets_write_attempts()); +} + +// Tests that PATH_CHALLENGE is dropped if it is sent via the default writer +// and the writer is blocked. +TEST_P(QuicConnectionTest, SendPathChallengeUsingBlockedDefaultSocket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Any4(), 12345); + writer_->BlockOnNextWrite(); + // 1st time is after writer returns WRITE_STATUS_BLOCKED. 2nd time is in + // ShouldGeneratePacket(). + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AtLeast(2)); + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + // This packet isn't sent actually, instead it is buffered in the + // connection. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + if (connection_.validate_client_address()) { + EXPECT_EQ(1u, writer_->path_response_frames().size()); + EXPECT_EQ(0, + memcmp(&path_challenge_payload, + &writer_->path_response_frames().front().data_buffer, + sizeof(path_challenge_payload))); + } + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + })) + .WillRepeatedly(Invoke([&]() { + // Only one PATH_CHALLENGE should be sent out. + EXPECT_EQ(0u, writer_->path_challenge_frames().size()); + })); + bool success = false; + if (connection_.validate_client_address()) { + // Receiving a PATH_CHALLENGE from the new peer address should trigger + // address validation. + QuicFrames frames; + frames.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + } else { + // Manually start to validate the new peer address. + connection_.ValidatePath( + std::make_unique( + connection_.self_address(), kNewPeerAddress, writer_.get()), + std::make_unique( + &connection_, connection_.self_address(), kNewPeerAddress, + &success), + PathValidationReason::kReasonUnknown); + } + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Try again with the new socket blocked from the beginning. The 2nd + // PATH_CHALLENGE shouldn't be serialized, but be dropped. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast(helper_->GetRandomGenerator())->ChangeValue(); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + + // No more write attempt should be made. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + writer_->SetWritable(); + // OnCanWrite() should actually write out the 1st PATH_CHALLENGE packet + // buffered earlier, thus incrementing the write counter. It may also send + // ACKs to previously received packets. + connection_.OnCanWrite(); + EXPECT_LE(2u, writer_->packets_write_attempts()); +} + +// Tests that write error on the alternate socket should be ignored. +TEST_P(QuicConnectionTest, SendPathChallengeFailOnNewSocket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + new_writer.SetShouldWriteFail(); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(0); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0u); + + bool success = false; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_EQ(1u, new_writer.packets_write_attempts()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress.host(), new_writer.last_write_source_address()); + + EXPECT_EQ(0u, writer_->packets_write_attempts()); + // Regardless of the write error, the connection should still be connected. + EXPECT_TRUE(connection_.connected()); +} + +// Tests that write error while sending PATH_CHALLANGE from the default socket +// should close the connection. +TEST_P(QuicConnectionTest, SendPathChallengeFailOnDefaultPath) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + + writer_->SetShouldWriteFail(); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce( + Invoke([](QuicConnectionCloseFrame frame, ConnectionCloseSource) { + EXPECT_EQ(QUIC_PACKET_WRITE_ERROR, frame.quic_error_code); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0u); + { + // Add a flusher to force flush, otherwise the frames will remain in the + // packet creator. + bool success = false; + QuicConnection::ScopedPacketFlusher flusher(&connection_); + connection_.ValidatePath(std::make_unique( + connection_.self_address(), + connection_.peer_address(), writer_.get()), + std::make_unique( + &connection_, connection_.self_address(), + connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + } + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(connection_.peer_address(), writer_->last_write_peer_address()); + EXPECT_FALSE(connection_.connected()); + // Closing connection should abandon ongoing path validation. + EXPECT_FALSE(connection_.HasPendingPathValidation()); +} + +TEST_P(QuicConnectionTest, SendPathChallengeFailOnAlternativePeerAddress) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + + writer_->SetShouldWriteFail(); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Any4(), 12345); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce( + Invoke([](QuicConnectionCloseFrame frame, ConnectionCloseSource) { + EXPECT_EQ(QUIC_PACKET_WRITE_ERROR, frame.quic_error_code); + })); + // Sending PATH_CHALLENGE to trigger a flush write which will fail and close + // the connection. + bool success = false; + connection_.ValidatePath( + std::make_unique( + connection_.self_address(), kNewPeerAddress, writer_.get()), + std::make_unique( + &connection_, connection_.self_address(), kNewPeerAddress, &success), + PathValidationReason::kReasonUnknown); + + EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_FALSE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, + SendPathChallengeFailPacketTooBigOnAlternativePeerAddress) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + // Make sure there is no outstanding ACK_FRAME to write. + connection_.OnCanWrite(); + uint32_t num_packets_write_attempts = writer_->packets_write_attempts(); + + writer_->SetShouldWriteFail(); + writer_->SetWriteError(*writer_->MessageTooBigErrorCode()); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Any4(), 12345); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .Times(0u); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0u); + // Sending PATH_CHALLENGE to trigger a flush write which will fail with + // MSG_TOO_BIG. + bool success = false; + connection_.ValidatePath( + std::make_unique( + connection_.self_address(), kNewPeerAddress, writer_.get()), + std::make_unique( + &connection_, connection_.self_address(), kNewPeerAddress, &success), + PathValidationReason::kReasonUnknown); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + // Connection shouldn't be closed. + EXPECT_TRUE(connection_.connected()); + EXPECT_EQ(++num_packets_write_attempts, writer_->packets_write_attempts()); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); +} + +// Check that if there are two PATH_CHALLENGE frames in the packet, the latter +// one is ignored. +TEST_P(QuicConnectionTest, ReceiveMultiplePathChallenge) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + + QuicPathFrameBuffer path_frame_buffer1{0, 1, 2, 3, 4, 5, 6, 7}; + QuicPathFrameBuffer path_frame_buffer2{8, 9, 10, 11, 12, 13, 14, 15}; + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer1))); + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer2))); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback6(), + /*port=*/23456); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0); + + // Expect 2 packets to be sent: the first are padded PATH_RESPONSE(s) to the + // alternative peer address. The 2nd is a ACK-only packet to the original + // peer address. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(2) + .WillOnce(Invoke([=]() { + EXPECT_EQ(1u, writer_->path_response_frames().size()); + // The final check is to ensure that the random data in the response + // matches the random data from the challenge. + EXPECT_EQ(0, + memcmp(path_frame_buffer1.data(), + &(writer_->path_response_frames().front().data_buffer), + sizeof(path_frame_buffer1))); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + })) + .WillOnce(Invoke([=]() { + // The last write of ACK-only packet should still use the old peer + // address. + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + })); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); +} + +TEST_P(QuicConnectionTest, ReceiveStreamFrameBeforePathChallenge) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + QuicPathFrameBuffer path_frame_buffer{0, 1, 2, 3, 4, 5, 6, 7}; + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer))); + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/23456); + + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)); + EXPECT_CALL(*send_algorithm_, OnConnectionMigration()) + .Times(connection_.validate_client_address() ? 0u : 1u); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(Invoke([=](const QuicStreamFrame& frame) { + // Send some data on the stream. The STREAM_FRAME should be built into + // one packet together with the latter PATH_RESPONSE and PATH_CHALLENGE. + const std::string data{"response body"}; + connection_.producer()->SaveStreamData(frame.stream_id, data); + return notifier_.WriteOrBufferData(frame.stream_id, data.length(), + NO_FIN); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(connection_.validate_client_address() ? 0u : 1u); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + + // Verify that this packet contains a STREAM_FRAME and a + // PATH_RESPONSE_FRAME. + EXPECT_EQ(1u, writer_->stream_frames().size()); + EXPECT_EQ(1u, writer_->path_response_frames().size()); + EXPECT_EQ(connection_.validate_client_address() ? 1u : 0u, + writer_->path_challenge_frames().size()); + // The final check is to ensure that the random data in the response + // matches the random data from the challenge. + EXPECT_EQ(0, memcmp(path_frame_buffer.data(), + &(writer_->path_response_frames().front().data_buffer), + sizeof(path_frame_buffer))); + EXPECT_EQ(connection_.validate_client_address() ? 1u : 0u, + writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + if (connection_.validate_client_address()) { + EXPECT_TRUE(connection_.HasPendingPathValidation()); + } +} + +TEST_P(QuicConnectionTest, ReceiveStreamFrameFollowingPathChallenge) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + + QuicFrames frames; + QuicPathFrameBuffer path_frame_buffer{0, 1, 2, 3, 4, 5, 6, 7}; + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer))); + // PATH_RESPONSE should be flushed out before the rest packet is parsed. + frames.push_back(QuicFrame(frame1_)); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/23456); + QuicByteCount received_packet_size; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([=, &received_packet_size]() { + // Verify that this packet contains a PATH_RESPONSE_FRAME. + EXPECT_EQ(0u, writer_->stream_frames().size()); + EXPECT_EQ(1u, writer_->path_response_frames().size()); + // The final check is to ensure that the random data in the response + // matches the random data from the challenge. + EXPECT_EQ(0, + memcmp(path_frame_buffer.data(), + &(writer_->path_response_frames().front().data_buffer), + sizeof(path_frame_buffer))); + EXPECT_EQ(connection_.validate_client_address() ? 1u : 0u, + writer_->path_challenge_frames().size()); + EXPECT_EQ(1u, writer_->padding_frames().size()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + received_packet_size = + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_); + })); + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)); + EXPECT_CALL(*send_algorithm_, OnConnectionMigration()) + .Times(connection_.validate_client_address() ? 0u : 1u); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(Invoke([=](const QuicStreamFrame& frame) { + // Send some data on the stream. The STREAM_FRAME should be built into a + // new packet but throttled by anti-amplifciation limit. + const std::string data{"response body"}; + connection_.producer()->SaveStreamData(frame.stream_id, data); + return notifier_.WriteOrBufferData(frame.stream_id, data.length(), + NO_FIN); + })); + + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + if (!connection_.validate_client_address()) { + return; + } + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_EQ(0u, + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_)); + EXPECT_EQ( + received_packet_size, + QuicConnectionPeer::BytesReceivedBeforeAddressValidation(&connection_)); +} + +// Tests that a PATH_CHALLENGE is received in between other frames in an out of +// order packet. +TEST_P(QuicConnectionTest, PathChallengeWithDataInOutOfOrderPacket) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + QuicPathFrameBuffer path_frame_buffer{0, 1, 2, 3, 4, 5, 6, 7}; + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer))); + frames.push_back(QuicFrame(frame2_)); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback6(), + /*port=*/23456); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0u); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .Times(2) + .WillRepeatedly(Invoke([=](const QuicStreamFrame& frame) { + // Send some data on the stream. The STREAM_FRAME should be built into + // one packet together with the latter PATH_RESPONSE. + const std::string data{"response body"}; + connection_.producer()->SaveStreamData(frame.stream_id, data); + return notifier_.WriteOrBufferData(frame.stream_id, data.length(), + NO_FIN); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(Invoke([=]() { + // Verify that this packet contains a STREAM_FRAME and is sent to the + // original peer address. + EXPECT_EQ(1u, writer_->stream_frames().size()); + // No connection migration should happen because the packet is received + // out of order. + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + })) + .WillOnce(Invoke([=]() { + EXPECT_EQ(1u, writer_->path_response_frames().size()); + // The final check is to ensure that the random data in the response + // matches the random data from the challenge. + EXPECT_EQ(0, + memcmp(path_frame_buffer.data(), + &(writer_->path_response_frames().front().data_buffer), + sizeof(path_frame_buffer))); + EXPECT_EQ(1u, writer_->padding_frames().size()); + // PATH_RESPONSE should be sent in another packet to a different peer + // address. + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + })) + .WillOnce(Invoke([=]() { + // Verify that this packet contains a STREAM_FRAME and is sent to the + // original peer address. + EXPECT_EQ(1u, writer_->stream_frames().size()); + // No connection migration should happen because the packet is received + // out of order. + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + })); + // Lower the packet number so that receiving this packet shouldn't trigger + // peer migration. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 1); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); +} + +// Tests that a PATH_CHALLENGE is cached if its PATH_RESPONSE can't be sent. +TEST_P(QuicConnectionTest, FailToWritePathResponse) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + + QuicFrames frames; + QuicPathFrameBuffer path_frame_buffer{0, 1, 2, 3, 4, 5, 6, 7}; + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer))); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback6(), + /*port=*/23456); + + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0u); + // Lower the packet number so that receiving this packet shouldn't trigger + // peer migration. + QuicPacketCreatorPeer::SetPacketNumber(&peer_creator_, 1); + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AtLeast(1)); + writer_->SetWriteBlocked(); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); +} + +// Regression test for b/168101557. +TEST_P(QuicConnectionTest, HandshakeDataDoesNotGetPtoed) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send INITIAL 1. + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + // Send HANDSHAKE packets. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Send half RTT packet. + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + + // Receives HANDSHAKE 1. + peer_framer_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_HANDSHAKE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + // Verify there is pending ACK. + ASSERT_TRUE(connection_.HasPendingAcks()); + // Set the send alarm. + connection_.GetSendAlarm()->Set(clock_.ApproximateNow()); + + // Fire ACK alarm. + connection_.GetAckAlarm()->Fire(); + // Verify 1-RTT packet is coalesced with handshake packet. + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + connection_.GetSendAlarm()->Fire(); + + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify a handshake packet gets PTOed and 1-RTT packet gets coalesced. + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); +} + +// Regression test for b/168294218. +TEST_P(QuicConnectionTest, CoalescerHandlesInitialKeyDiscard) { + if (!connection_.version().CanSendCoalescedPackets()) { + return; + } + SetQuicReloadableFlag(quic_discard_initial_packet_with_key_dropped, true); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).WillOnce(Invoke([this]() { + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + })); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + + EXPECT_EQ(0u, connection_.GetStats().packets_discarded); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + ProcessCryptoPacketAtLevel(1000, ENCRYPTION_INITIAL); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString(std::string(1200, 'a'), 0); + // Verify this packet is on hold. + EXPECT_EQ(0u, writer_->packets_write_attempts()); + } + EXPECT_TRUE(connection_.connected()); +} + +// Regresstion test for b/168294218 +TEST_P(QuicConnectionTest, ZeroRttRejectionAndMissingInitialKeys) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + // Not defer send in response to packet. + connection_.set_defer_send_in_response_to_packets(false); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).WillOnce(Invoke([this]() { + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + })); + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .WillRepeatedly(Invoke([=](const QuicCryptoFrame& frame) { + if (frame.level == ENCRYPTION_HANDSHAKE) { + // 0-RTT gets rejected. + connection_.MarkZeroRttPacketsForRetransmission(0); + // Send Crypto data. + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Advance INITIAL ack delay to trigger initial ACK to be sent AFTER + // the retransmission of rejected 0-RTT packets while the HANDSHAKE + // packet is still in the coalescer, such that the INITIAL key gets + // dropped between SendAllPendingAcks and actually send the ack frame, + // bummer. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + } + })); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send 0-RTT packet. + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + + QuicAckFrame frame1 = InitAckFrame(1); + // Received ACK for packet 1. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Fire retransmission alarm. + connection_.GetRetransmissionAlarm()->Fire(); + + QuicFrames frames1; + frames1.push_back(QuicFrame(&crypto_frame_)); + QuicFrames frames2; + QuicCryptoFrame crypto_frame(ENCRYPTION_HANDSHAKE, 0, + absl::string_view(data1)); + frames2.push_back(QuicFrame(&crypto_frame)); + ProcessCoalescedPacket( + {{2, frames1, ENCRYPTION_INITIAL}, {3, frames2, ENCRYPTION_HANDSHAKE}}); +} + +TEST_P(QuicConnectionTest, OnZeroRttPacketAcked) { + if (!connection_.version().UsesTls()) { + return; + } + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + connection_.SendCryptoStreamData(); + // Send 0-RTT packet. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + connection_.SendStreamDataWithString(4, "bar", 0, NO_FIN); + // Received ACK for packet 1, HANDSHAKE packet and 1-RTT ACK. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + QuicFrames frames1; + QuicAckFrame ack_frame1 = InitAckFrame(1); + frames1.push_back(QuicFrame(&ack_frame1)); + + QuicFrames frames2; + QuicCryptoFrame crypto_frame(ENCRYPTION_HANDSHAKE, 0, + absl::string_view(data1)); + frames2.push_back(QuicFrame(&crypto_frame)); + EXPECT_CALL(debug_visitor, OnZeroRttPacketAcked()).Times(0); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + ProcessCoalescedPacket( + {{1, frames1, ENCRYPTION_INITIAL}, {2, frames2, ENCRYPTION_HANDSHAKE}}); + + QuicFrames frames3; + QuicAckFrame ack_frame2 = + InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + frames3.push_back(QuicFrame(&ack_frame2)); + EXPECT_CALL(debug_visitor, OnZeroRttPacketAcked()).Times(1); + ProcessCoalescedPacket({{3, frames3, ENCRYPTION_FORWARD_SECURE}}); + + QuicFrames frames4; + QuicAckFrame ack_frame3 = + InitAckFrame({{QuicPacketNumber(3), QuicPacketNumber(4)}}); + frames4.push_back(QuicFrame(&ack_frame3)); + EXPECT_CALL(debug_visitor, OnZeroRttPacketAcked()).Times(0); + ProcessCoalescedPacket({{4, frames4, ENCRYPTION_FORWARD_SECURE}}); +} + +TEST_P(QuicConnectionTest, InitiateKeyUpdate) { + if (!connection_.version().UsesTls()) { + return; + } + + TransportParameters params; + QuicConfig config; + std::string error_details; + EXPECT_THAT(config.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + MockFramerVisitor peer_framer_visitor_; + peer_framer_.set_visitor(&peer_framer_visitor_); + + uint8_t correct_tag = ENCRYPTION_FORWARD_SECURE; + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(correct_tag)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(correct_tag)); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(correct_tag)); + + // Key update should still not be allowed, since no packet has been acked + // from the current key phase. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + EXPECT_FALSE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + // Send packet 1. + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + + // Key update should still not be allowed, even though a packet was sent in + // the current key phase it hasn't been acked yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + EXPECT_TRUE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + // Receive ack for packet 1. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame1 = InitAckFrame(1); + ProcessAckPacket(&frame1); + + // OnDecryptedFirstPacketInKeyPhase is called even on the first key phase, + // so discard_previous_keys_alarm_ should be set now. + EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + EXPECT_FALSE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + correct_tag++; + // Key update should now be allowed. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); + EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + // discard_previous_keys_alarm_ should not be set until a packet from the new + // key phase has been received. (The alarm that was set above should be + // cleared if it hasn't fired before the next key update happened.) + EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + EXPECT_FALSE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + // Pretend that peer accepts the key update. + EXPECT_CALL(peer_framer_visitor_, + AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + peer_framer_.SetKeyUpdateSupportForConnection(true); + peer_framer_.DoKeyUpdate(KeyUpdateReason::kRemote); + + // Another key update should not be allowed yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + // Send packet 2. + SendStreamDataToPeer(2, "bar", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(2u), last_packet); + EXPECT_TRUE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + // Receive ack for packet 2. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame2 = InitAckFrame(2); + ProcessAckPacket(&frame2); + EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + EXPECT_FALSE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + correct_tag++; + // Key update should be allowed again now that a packet has been acked from + // the current key phase. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); + EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + + // Pretend that peer accepts the key update. + EXPECT_CALL(peer_framer_visitor_, + AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + peer_framer_.DoKeyUpdate(KeyUpdateReason::kRemote); + + // Another key update should not be allowed yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + + // Send packet 3. + SendStreamDataToPeer(3, "baz", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(3u), last_packet); + + // Another key update should not be allowed yet. + EXPECT_FALSE(connection_.IsKeyUpdateAllowed()); + EXPECT_TRUE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + // Receive ack for packet 3. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame3 = InitAckFrame(3); + ProcessAckPacket(&frame3); + EXPECT_TRUE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + EXPECT_FALSE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); + + correct_tag++; + // Key update should be allowed now. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([&correct_tag]() { + return std::make_unique(correct_tag); + }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); + EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + EXPECT_FALSE(connection_.GetDiscardPreviousOneRttKeysAlarm()->IsSet()); + EXPECT_FALSE(connection_.HaveSentPacketsInCurrentKeyPhaseButNoneAcked()); +} + +TEST_P(QuicConnectionTest, InitiateKeyUpdateApproachingConfidentialityLimit) { + if (!connection_.version().UsesTls()) { + return; + } + + SetQuicFlag(quic_key_update_confidentiality_limit, 3U); + + std::string error_details; + TransportParameters params; + // Key update is enabled. + QuicConfig config; + EXPECT_THAT(config.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + MockFramerVisitor peer_framer_visitor_; + peer_framer_.set_visitor(&peer_framer_visitor_); + + uint8_t current_tag = ENCRYPTION_FORWARD_SECURE; + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(current_tag)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(current_tag)); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + peer_framer_.SetKeyUpdateSupportForConnection(true); + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(current_tag)); + + const QuicConnectionStats& stats = connection_.GetStats(); + + for (int packet_num = 1; packet_num <= 8; ++packet_num) { + if (packet_num == 3 || packet_num == 6) { + current_tag++; + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([current_tag]() { + return std::make_unique(current_tag); + }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([current_tag]() { + return std::make_unique(current_tag); + }); + EXPECT_CALL(visitor_, + OnKeyUpdate(KeyUpdateReason::kLocalKeyUpdateLimitOverride)); + } + // Send packet. + QuicPacketNumber last_packet; + SendStreamDataToPeer(packet_num, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(packet_num), last_packet); + if (packet_num >= 6) { + EXPECT_EQ(2U, stats.key_update_count); + } else if (packet_num >= 3) { + EXPECT_EQ(1U, stats.key_update_count); + } else { + EXPECT_EQ(0U, stats.key_update_count); + } + + if (packet_num == 4 || packet_num == 7) { + // Pretend that peer accepts the key update. + EXPECT_CALL(peer_framer_visitor_, + AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([current_tag]() { + return std::make_unique(current_tag); + }); + EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([current_tag]() { + return std::make_unique(current_tag); + }); + peer_framer_.DoKeyUpdate(KeyUpdateReason::kRemote); + } + // Receive ack for packet. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame1 = InitAckFrame(packet_num); + ProcessAckPacket(&frame1); + } +} + +TEST_P(QuicConnectionTest, + CloseConnectionOnConfidentialityLimitKeyUpdateNotAllowed) { + if (!connection_.version().UsesTls()) { + return; + } + + // Set key update confidentiality limit to 1 packet. + SetQuicFlag(quic_key_update_confidentiality_limit, 1U); + // Use confidentiality limit for connection close of 3 packets. + constexpr size_t kConfidentialityLimit = 3U; + + std::string error_details; + TransportParameters params; + // Key update is enabled. + QuicConfig config; + EXPECT_THAT(config.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique( + ENCRYPTION_FORWARD_SECURE, kConfidentialityLimit)); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + QuicPacketNumber last_packet; + // Send 3 packets without receiving acks for any of them. Key update will not + // be allowed, so the confidentiality limit should be reached, forcing the + // connection to be closed. + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_TRUE(connection_.connected()); + SendStreamDataToPeer(2, "foo", 0, NO_FIN, &last_packet); + EXPECT_TRUE(connection_.connected()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &last_packet); + EXPECT_FALSE(connection_.connected()); + const QuicConnectionStats& stats = connection_.GetStats(); + EXPECT_EQ(0U, stats.key_update_count); + TestConnectionCloseQuicErrorCode(QUIC_AEAD_LIMIT_REACHED); +} + +TEST_P(QuicConnectionTest, CloseConnectionOnIntegrityLimitDuringHandshake) { + if (!connection_.version().UsesTls()) { + return; + } + + constexpr uint8_t correct_tag = ENCRYPTION_HANDSHAKE; + constexpr uint8_t wrong_tag = 0xFE; + constexpr QuicPacketCount kIntegrityLimit = 3; + + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique( + correct_tag, kIntegrityLimit)); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(correct_tag)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(wrong_tag)); + for (uint64_t i = 1; i <= kIntegrityLimit; ++i) { + EXPECT_TRUE(connection_.connected()); + if (i == kIntegrityLimit) { + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(AnyNumber()); + } + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + EXPECT_EQ( + i, connection_.GetStats().num_failed_authentication_packets_received); + } + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_AEAD_LIMIT_REACHED); +} + +TEST_P(QuicConnectionTest, CloseConnectionOnIntegrityLimitAfterHandshake) { + if (!connection_.version().UsesTls()) { + return; + } + + constexpr uint8_t correct_tag = ENCRYPTION_FORWARD_SECURE; + constexpr uint8_t wrong_tag = 0xFE; + constexpr QuicPacketCount kIntegrityLimit = 3; + + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique( + correct_tag, kIntegrityLimit)); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(correct_tag)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(wrong_tag)); + for (uint64_t i = 1; i <= kIntegrityLimit; ++i) { + EXPECT_TRUE(connection_.connected()); + if (i == kIntegrityLimit) { + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + } + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ( + i, connection_.GetStats().num_failed_authentication_packets_received); + } + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_AEAD_LIMIT_REACHED); +} + +TEST_P(QuicConnectionTest, + CloseConnectionOnIntegrityLimitAcrossEncryptionLevels) { + if (!connection_.version().UsesTls()) { + return; + } + + uint8_t correct_tag = ENCRYPTION_HANDSHAKE; + constexpr uint8_t wrong_tag = 0xFE; + constexpr QuicPacketCount kIntegrityLimit = 4; + + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique( + correct_tag, kIntegrityLimit)); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(correct_tag)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(wrong_tag)); + for (uint64_t i = 1; i <= 2; ++i) { + EXPECT_TRUE(connection_.connected()); + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + EXPECT_EQ( + i, connection_.GetStats().num_failed_authentication_packets_received); + } + + correct_tag = ENCRYPTION_FORWARD_SECURE; + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique( + correct_tag, kIntegrityLimit)); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(correct_tag)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.RemoveEncrypter(ENCRYPTION_HANDSHAKE); + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(wrong_tag)); + for (uint64_t i = 3; i <= kIntegrityLimit; ++i) { + EXPECT_TRUE(connection_.connected()); + if (i == kIntegrityLimit) { + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + } + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ( + i, connection_.GetStats().num_failed_authentication_packets_received); + } + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_AEAD_LIMIT_REACHED); +} + +TEST_P(QuicConnectionTest, IntegrityLimitDoesNotApplyWithoutDecryptionKey) { + if (!connection_.version().UsesTls()) { + return; + } + + constexpr uint8_t correct_tag = ENCRYPTION_HANDSHAKE; + constexpr uint8_t wrong_tag = 0xFE; + constexpr QuicPacketCount kIntegrityLimit = 3; + + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique( + correct_tag, kIntegrityLimit)); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(correct_tag)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(wrong_tag)); + for (uint64_t i = 1; i <= kIntegrityLimit * 2; ++i) { + EXPECT_TRUE(connection_.connected()); + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ( + 0u, connection_.GetStats().num_failed_authentication_packets_received); + } + EXPECT_TRUE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, CloseConnectionOnIntegrityLimitAcrossKeyPhases) { + if (!connection_.version().UsesTls()) { + return; + } + + constexpr QuicPacketCount kIntegrityLimit = 4; + + TransportParameters params; + QuicConfig config; + std::string error_details; + EXPECT_THAT(config.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + QuicConfigPeer::SetNegotiated(&config, true); + if (connection_.version().UsesTls()) { + QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + MockFramerVisitor peer_framer_visitor_; + peer_framer_.set_visitor(&peer_framer_visitor_); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x01)); + SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique( + ENCRYPTION_FORWARD_SECURE, kIntegrityLimit)); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0xFF)); + for (uint64_t i = 1; i <= 2; ++i) { + EXPECT_TRUE(connection_.connected()); + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ( + i, connection_.GetStats().num_failed_authentication_packets_received); + } + + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + // Send packet 1. + QuicPacketNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(1u), last_packet); + // Receive ack for packet 1. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame1 = InitAckFrame(1); + ProcessAckPacket(&frame1); + // Key update should now be allowed, initiate it. + EXPECT_CALL(visitor_, AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce([kIntegrityLimit]() { + return std::make_unique( + 0x02, kIntegrityLimit); + }); + EXPECT_CALL(visitor_, CreateCurrentOneRttEncrypter()).WillOnce([]() { + return std::make_unique(0x02); + }); + EXPECT_CALL(visitor_, OnKeyUpdate(KeyUpdateReason::kLocalForTests)); + EXPECT_TRUE(connection_.InitiateKeyUpdate(KeyUpdateReason::kLocalForTests)); + + // Pretend that peer accepts the key update. + EXPECT_CALL(peer_framer_visitor_, + AdvanceKeysAndCreateCurrentOneRttDecrypter()) + .WillOnce( + []() { return std::make_unique(0x02); }); + EXPECT_CALL(peer_framer_visitor_, CreateCurrentOneRttEncrypter()) + .WillOnce([]() { return std::make_unique(0x02); }); + peer_framer_.SetKeyUpdateSupportForConnection(true); + peer_framer_.DoKeyUpdate(KeyUpdateReason::kLocalForTests); + + // Send packet 2. + SendStreamDataToPeer(2, "bar", 0, NO_FIN, &last_packet); + EXPECT_EQ(QuicPacketNumber(2u), last_packet); + // Receive ack for packet 2. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _, _, _)); + QuicAckFrame frame2 = InitAckFrame(2); + ProcessAckPacket(&frame2); + + EXPECT_EQ(2u, + connection_.GetStats().num_failed_authentication_packets_received); + + // Do two more undecryptable packets. Integrity limit should be reached. + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0xFF)); + for (uint64_t i = 3; i <= kIntegrityLimit; ++i) { + EXPECT_TRUE(connection_.connected()); + if (i == kIntegrityLimit) { + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + } + ProcessDataPacketAtLevel(i, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ( + i, connection_.GetStats().num_failed_authentication_packets_received); + } + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode(QUIC_AEAD_LIMIT_REACHED); +} + +TEST_P(QuicConnectionTest, SendAckFrequencyFrame) { + if (!version().HasIetfQuicFrames()) { + return; + } + SetQuicReloadableFlag(quic_can_send_ack_frequency, true); + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + + QuicConfig config; + QuicConfigPeer::SetReceivedMinAckDelayMs(&config, /*min_ack_delay_ms=*/1); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + QuicConnectionPeer::SetAddressValidated(&connection_); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + connection_.OnHandshakeComplete(); + + writer_->SetWritable(); + QuicPacketCreatorPeer::SetPacketNumber(creator_, 99); + // Send packet 100 + SendStreamDataToPeer(/*id=*/1, "foo", /*offset=*/0, NO_FIN, nullptr); + + QuicAckFrequencyFrame captured_frame; + EXPECT_CALL(visitor_, SendAckFrequency(_)) + .WillOnce(Invoke([&captured_frame](const QuicAckFrequencyFrame& frame) { + captured_frame = frame; + })); + // Send packet 101. + SendStreamDataToPeer(/*id=*/1, "bar", /*offset=*/3, NO_FIN, nullptr); + + EXPECT_EQ(captured_frame.packet_tolerance, 10u); + EXPECT_EQ(captured_frame.max_ack_delay, + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs)); + + // Sending packet 102 does not trigger sending another AckFrequencyFrame. + SendStreamDataToPeer(/*id=*/1, "baz", /*offset=*/6, NO_FIN, nullptr); +} + +TEST_P(QuicConnectionTest, SendAckFrequencyFrameUponHandshakeCompletion) { + if (!version().HasIetfQuicFrames()) { + return; + } + SetQuicReloadableFlag(quic_can_send_ack_frequency, true); + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + + QuicConfig config; + QuicConfigPeer::SetReceivedMinAckDelayMs(&config, /*min_ack_delay_ms=*/1); + QuicTagVector quic_tag_vector; + // Enable sending AckFrequency upon handshake completion. + quic_tag_vector.push_back(kAFF2); + QuicConfigPeer::SetReceivedConnectionOptions(&config, quic_tag_vector); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + QuicConnectionPeer::SetAddressValidated(&connection_); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + QuicAckFrequencyFrame captured_frame; + EXPECT_CALL(visitor_, SendAckFrequency(_)) + .WillOnce(Invoke([&captured_frame](const QuicAckFrequencyFrame& frame) { + captured_frame = frame; + })); + + connection_.OnHandshakeComplete(); + + EXPECT_EQ(captured_frame.packet_tolerance, 2u); + EXPECT_EQ(captured_frame.max_ack_delay, + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs)); +} + +TEST_P(QuicConnectionTest, FastRecoveryOfLostServerHello) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoStreamData(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20)); + + // Assume ServerHello gets lost. + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + ProcessCryptoPacketAtLevel(2, ENCRYPTION_HANDSHAKE); + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + // Shorten PTO for fast recovery from lost ServerHello. + EXPECT_EQ(clock_.ApproximateNow() + kAlarmGranularity, + connection_.GetRetransmissionAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, ServerHelloGetsReordered) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + connection_.SetFromConfig(config); + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .WillRepeatedly(Invoke([=](const QuicCryptoFrame& frame) { + if (frame.level == ENCRYPTION_INITIAL) { + // Install handshake read keys. + SetDecrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + } + })); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoStreamData(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20)); + + // Assume ServerHello gets reordered. + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + ProcessCryptoPacketAtLevel(2, ENCRYPTION_HANDSHAKE); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + // Verify fast recovery is not enabled. + EXPECT_EQ(connection_.sent_packet_manager().GetRetransmissionTime(), + connection_.GetRetransmissionAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, MigratePath) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.OnPathDegradingDetected(); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + + // Buffer a packet. + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(1); + writer_->SetWriteBlocked(); + connection_.SendMtuDiscoveryPacket(kMaxOutgoingPacketSize); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(visitor_, OnForwardProgressMadeAfterPathDegrading()); + connection_.MigratePath(kNewSelfAddress, connection_.peer_address(), + &new_writer, /*owns_writer=*/false); + + EXPECT_EQ(kNewSelfAddress, connection_.self_address()); + EXPECT_EQ(&new_writer, QuicConnectionPeer::GetWriter(&connection_)); + EXPECT_FALSE(connection_.IsPathDegrading()); + // Buffered packet on the old path should be discarded. + if (connection_.connection_migration_use_new_cid()) { + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + } else { + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + } +} + +TEST_P(QuicConnectionTest, MigrateToNewPathDuringProbing) { + if (!VersionHasIetfQuicFrames(connection_.version().transport_version)) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); + EXPECT_NE(kNewSelfAddress, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + bool success = false; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + + connection_.MigratePath(kNewSelfAddress, connection_.peer_address(), + &new_writer, /*owns_writer=*/false); + EXPECT_EQ(kNewSelfAddress, connection_.self_address()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); +} + +TEST_P(QuicConnectionTest, MultiPortConnection) { + set_perspective(Perspective::IS_CLIENT); + QuicConfig config; + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM}); + config.SetClientConnectionOptions(QuicTagVector{kMPQC}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.OnPathDegradingDetected(); + + auto self_address = connection_.self_address(); + const QuicSocketAddress kNewSelfAddress(self_address.host(), + self_address.port() + 1); + EXPECT_NE(kNewSelfAddress, self_address); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()).WillOnce(Return(false)); + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillRepeatedly(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer)))); + connection_.OnNewConnectionIdFrame(frame); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + auto* alt_path = QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_FALSE(alt_path->validated); + EXPECT_EQ(PathValidationReason::kMultiPort, + QuicConnectionPeer::path_validator(&connection_) + ->GetPathValidationReason()); + + // Suppose the server retransmits the NEW_CID frame, the client will receive + // the same frame again. It should be ignored. + // Regression test of crbug.com/1406762 + connection_.OnNewConnectionIdFrame(frame); + + // 30ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(30); + // Fake a response delay. + clock_.AdvanceTime(kTestRTT); + + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame( + 99, new_writer.path_challenge_frames().back().data_buffer))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + // No migration should happen and the alternative path should still be alive. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_TRUE(alt_path->validated); + + auto stats = connection_.multi_port_stats(); + EXPECT_EQ(1, stats->num_path_degrading); + EXPECT_EQ(0, stats->num_multi_port_probe_failures_when_path_degrading); + EXPECT_EQ(kTestRTT, stats->rtt_stats.latest_rtt()); + EXPECT_EQ(kTestRTT, + stats->rtt_stats_when_default_path_degrading.latest_rtt()); + + // When there's no active request, the probing shouldn't happen. But the + // probing context should be saved. + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()).WillOnce(Return(false)); + connection_.GetMultiPortProbingAlarm()->Fire(); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(connection_.GetMultiPortProbingAlarm()->IsSet()); + + // Simulate the situation where a new request stream is created. + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + random_generator_.ChangeValue(); + connection_.MaybeProbeMultiPortPath(); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_TRUE(alt_path->validated); + // Fake a response delay. + clock_.AdvanceTime(kTestRTT); + QuicFrames frames2; + frames2.push_back(QuicFrame(QuicPathResponseFrame( + 99, new_writer.path_challenge_frames().back().data_buffer))); + ProcessFramesPacketWithAddresses(frames2, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + // No migration should happen and the alternative path should still be alive. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_TRUE(alt_path->validated); + EXPECT_EQ(1, stats->num_path_degrading); + EXPECT_EQ(0, stats->num_multi_port_probe_failures_when_path_degrading); + EXPECT_EQ(kTestRTT, stats->rtt_stats.latest_rtt()); + EXPECT_EQ(kTestRTT, + stats->rtt_stats_when_default_path_degrading.latest_rtt()); + + EXPECT_TRUE(connection_.GetMultiPortProbingAlarm()->IsSet()); + // Since there's already a scheduled probing alarm, manual calls won't have + // any effect. + connection_.MaybeProbeMultiPortPath(); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + + // Simulate the case where the path validation fails after retries. + connection_.GetMultiPortProbingAlarm()->Fire(); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + for (size_t i = 0; i < QuicPathValidator::kMaxRetryTimes + 1; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + } + + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_EQ(1, stats->num_path_degrading); + EXPECT_EQ(1, stats->num_multi_port_probe_failures_when_path_degrading); + EXPECT_EQ(0, stats->num_multi_port_probe_failures_when_path_not_degrading); +} + +TEST_P(QuicConnectionTest, TooManyMultiPortPathCreations) { + set_perspective(Perspective::IS_CLIENT); + QuicConfig config; + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM}); + config.SetClientConnectionOptions(QuicTagVector{kMPQC}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + + EXPECT_CALL(visitor_, OnPathDegrading()); + connection_.OnPathDegradingDetected(); + + auto self_address = connection_.self_address(); + const QuicSocketAddress kNewSelfAddress(self_address.host(), + self_address.port() + 1); + EXPECT_NE(kNewSelfAddress, self_address); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillRepeatedly(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer)))); + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + auto* alt_path = QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_FALSE(alt_path->validated); + + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + for (size_t i = 0; i < QuicPathValidator::kMaxRetryTimes + 1; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + } + + auto stats = connection_.multi_port_stats(); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_EQ(1, stats->num_path_degrading); + EXPECT_EQ(1, stats->num_multi_port_probe_failures_when_path_degrading); + + uint64_t connection_id = 1235; + for (size_t i = 0; i < kMaxNumMultiPortPaths - 1; ++i) { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(connection_id + i); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = i + 2; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillRepeatedly(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer)))); + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_FALSE(alt_path->validated); + + for (size_t j = 0; j < QuicPathValidator::kMaxRetryTimes + 1; ++j) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + } + + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_EQ(1, stats->num_path_degrading); + EXPECT_EQ(i + 2, stats->num_multi_port_probe_failures_when_path_degrading); + } + + // The 6th attemp should fail. + QuicNewConnectionIdFrame frame2; + frame2.connection_id = TestConnectionId(1239); + ASSERT_NE(frame2.connection_id, connection_.connection_id()); + frame2.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame2.connection_id); + frame2.retire_prior_to = 0u; + frame2.sequence_number = 6u; + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame2)); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_EQ(kMaxNumMultiPortPaths, + stats->num_multi_port_probe_failures_when_path_degrading); +} + +// Verify that when multi-port is enabled and path degrading is triggered, if +// the alt-path is not ready, nothing happens. +TEST_P(QuicConnectionTest, PathDegradingWhenAltPathIsNotReady) { + set_perspective(Perspective::IS_CLIENT); + QuicConfig config; + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM}); + config.SetClientConnectionOptions(QuicTagVector{kMPQC}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + + auto self_address = connection_.self_address(); + const QuicSocketAddress kNewSelfAddress(self_address.host(), + self_address.port() + 1); + EXPECT_NE(kNewSelfAddress, self_address); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillRepeatedly(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer)))); + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + auto* alt_path = QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_FALSE(alt_path->validated); + + // The alt path is not ready, path degrading doesn't do anything. + EXPECT_CALL(visitor_, OnPathDegrading()); + EXPECT_CALL(visitor_, MigrateToMultiPortPath(_)).Times(0); + connection_.OnPathDegradingDetected(); + + // 30ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(30); + // Fake a response delay. + clock_.AdvanceTime(kTestRTT); + + // Even if the alt path is validated after path degrading, nothing should + // happen. + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame( + 99, new_writer.path_challenge_frames().back().data_buffer))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + // No migration should happen and the alternative path should still be alive. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_TRUE(alt_path->validated); +} + +// Verify that when multi-port is enabled and path degrading is triggered, if +// the alt-path is ready and not probing, it should be migrated. +TEST_P(QuicConnectionTest, PathDegradingWhenAltPathIsReadyAndNotProbing) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + set_perspective(Perspective::IS_CLIENT); + QuicConfig config; + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM}); + config.SetClientConnectionOptions(QuicTagVector{kMPQC}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + + auto self_address = connection_.self_address(); + const QuicSocketAddress kNewSelfAddress(self_address.host(), + self_address.port() + 1); + EXPECT_NE(kNewSelfAddress, self_address); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillRepeatedly(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer)))); + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + auto* alt_path = QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_FALSE(alt_path->validated); + + // 30ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(30); + // Fake a response delay. + clock_.AdvanceTime(kTestRTT); + + // Even if the alt path is validated after path degrading, nothing should + // happen. + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame( + 99, new_writer.path_challenge_frames().back().data_buffer))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + // No migration should happen and the alternative path should still be alive. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_TRUE(alt_path->validated); + + // Trigger path degrading and the connection should attempt to migrate. + EXPECT_CALL(visitor_, OnPathDegrading()); + EXPECT_CALL(visitor_, OnForwardProgressMadeAfterPathDegrading()).Times(0); + EXPECT_CALL(visitor_, MigrateToMultiPortPath(_)) + .WillOnce(Invoke([&](std::unique_ptr context) { + EXPECT_EQ(context->self_address(), kNewSelfAddress); + connection_.MigratePath(context->self_address(), + context->peer_address(), context->WriterToUse(), + /*owns_writer=*/false); + })); + connection_.OnPathDegradingDetected(); +} + +// Verify that when multi-port is enabled and path degrading is triggered, if +// the alt-path is probing, the probing should be cancelled and the path should +// be migrated. +TEST_P(QuicConnectionTest, PathDegradingWhenAltPathIsReadyAndProbing) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + set_perspective(Perspective::IS_CLIENT); + QuicConfig config; + config.SetConnectionOptionsToSend(QuicTagVector{kRVCM}); + config.SetClientConnectionOptions(QuicTagVector{kMPQC}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + + auto self_address = connection_.self_address(); + const QuicSocketAddress kNewSelfAddress(self_address.host(), + self_address.port() + 1); + EXPECT_NE(kNewSelfAddress, self_address); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + + EXPECT_CALL(visitor_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillRepeatedly(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress, connection_.peer_address(), &new_writer)))); + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + auto* alt_path = QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_FALSE(alt_path->validated); + + // 30ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(30); + // Fake a response delay. + clock_.AdvanceTime(kTestRTT); + + // Even if the alt path is validated after path degrading, nothing should + // happen. + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame( + 99, new_writer.path_challenge_frames().back().data_buffer))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + // No migration should happen and the alternative path should still be alive. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, connection_.peer_address())); + EXPECT_TRUE(alt_path->validated); + + random_generator_.ChangeValue(); + connection_.GetMultiPortProbingAlarm()->Fire(); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(connection_.GetMultiPortProbingAlarm()->IsSet()); + + // Trigger path degrading and the connection should attempt to migrate. + EXPECT_CALL(visitor_, OnPathDegrading()); + EXPECT_CALL(visitor_, OnForwardProgressMadeAfterPathDegrading()).Times(0); + EXPECT_CALL(visitor_, MigrateToMultiPortPath(_)) + .WillOnce(Invoke([&](std::unique_ptr context) { + EXPECT_EQ(context->self_address(), kNewSelfAddress); + connection_.MigratePath(context->self_address(), + context->peer_address(), context->WriterToUse(), + /*owns_writer=*/false); + })); + connection_.OnPathDegradingDetected(); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + auto* path_validator = QuicConnectionPeer::path_validator(&connection_); + EXPECT_FALSE(QuicPathValidatorPeer::retry_timer(path_validator)->IsSet()); +} + +TEST_P(QuicConnectionTest, SingleAckInPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([=]() { + connection_.SendStreamData3(); + connection_.CloseConnection( + QUIC_INTERNAL_ERROR, "error", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + })); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + ASSERT_FALSE(writer_->ack_frames().empty()); + EXPECT_EQ(1u, writer_->ack_frames().size()); +} + +TEST_P(QuicConnectionTest, + ServerReceivedZeroRttPacketAfterOneRttPacketWithRetainedKey) { + if (!connection_.version().UsesTls()) { + return; + } + + set_perspective(Perspective::IS_SERVER); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + // Finish handshake. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + notifier_.NeuterUnencryptedData(); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet()); + + // 0-RTT packet received out of order should be decoded since the decrypter + // is temporarily retained. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + EXPECT_EQ( + 0u, + connection_.GetStats() + .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter); + + // Simulate the timeout for discarding 0-RTT keys passing. + connection_.GetDiscardZeroRttDecryptionKeysAlarm()->Fire(); + + // Another 0-RTT packet received now should not be decoded. + EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(0); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + EXPECT_EQ( + 1u, + connection_.GetStats() + .num_tls_server_zero_rtt_packets_received_after_discarding_decrypter); + + // The |discard_zero_rtt_decryption_keys_alarm_| should only be set on the + // first 1-RTT packet received. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(5, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, NewTokenFrameInstigateAcks) { + if (!version().HasIetfQuicFrames()) { + return; + } + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + QuicNewTokenFrame* new_token = new QuicNewTokenFrame(); + EXPECT_CALL(visitor_, OnNewTokenReceived(_)); + ProcessFramePacket(QuicFrame(new_token)); + + // Ensure that this has caused the ACK alarm to be set. + EXPECT_TRUE(connection_.HasPendingAcks()); +} + +TEST_P(QuicConnectionTest, ServerClosesConnectionOnNewTokenFrame) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicNewTokenFrame* new_token = new QuicNewTokenFrame(); + EXPECT_CALL(visitor_, OnNewTokenReceived(_)).Times(0); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + ProcessFramePacket(QuicFrame(new_token)); + EXPECT_FALSE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, OverrideRetryTokenWithRetryPacket) { + if (!version().HasIetfQuicFrames()) { + return; + } + std::string address_token = "TestAddressToken"; + connection_.SetSourceAddressTokenToSend(address_token); + EXPECT_EQ(QuicPacketCreatorPeer::GetRetryToken( + QuicConnectionPeer::GetPacketCreator(&connection_)), + address_token); + // Passes valid retry and verify token gets overridden. + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); +} + +TEST_P(QuicConnectionTest, DonotOverrideRetryTokenWithAddressToken) { + if (!version().HasIetfQuicFrames()) { + return; + } + // Passes valid retry and verify token gets overridden. + TestClientRetryHandling(/*invalid_retry_tag=*/false, + /*missing_original_id_in_config=*/false, + /*wrong_original_id_in_config=*/false, + /*missing_retry_id_in_config=*/false, + /*wrong_retry_id_in_config=*/false); + std::string retry_token = QuicPacketCreatorPeer::GetRetryToken( + QuicConnectionPeer::GetPacketCreator(&connection_)); + + std::string address_token = "TestAddressToken"; + connection_.SetSourceAddressTokenToSend(address_token); + EXPECT_EQ(QuicPacketCreatorPeer::GetRetryToken( + QuicConnectionPeer::GetPacketCreator(&connection_)), + retry_token); +} + +TEST_P(QuicConnectionTest, + ServerReceivedZeroRttWithHigherPacketNumberThanOneRtt) { + if (!connection_.version().UsesTls()) { + return; + } + + // The code that checks for this error piggybacks on some book-keeping state + // kept for key update, so enable key update for the test. + std::string error_details; + TransportParameters params; + QuicConfig config; + EXPECT_THAT(config.ProcessTransportParameters( + params, /* is_resumption = */ false, &error_details), + IsQuicNoError()); + QuicConfigPeer::SetNegotiated(&config, true); + QuicConfigPeer::SetReceivedOriginalConnectionId(&config, + connection_.connection_id()); + QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_.connection_id()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + + set_perspective(Perspective::IS_SERVER); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + // Finish handshake. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + notifier_.NeuterUnencryptedData(); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); + + // Decrypt a 1-RTT packet. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet()); + + // 0-RTT packet with higher packet number than a 1-RTT packet is invalid and + // should cause the connection to be closed. + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + EXPECT_FALSE(connection_.connected()); + TestConnectionCloseQuicErrorCode( + QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER); +} + +// Regression test for b/177312785 +TEST_P(QuicConnectionTest, PeerMigrateBeforeHandshakeConfirm) { + if (!VersionHasIetfQuicFrames(version().transport_version)) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_START)); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Process another packet with a different peer address on server side will + // close connection. + QuicAckFrame frame = InitAckFrame(1); + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, + OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(0u); + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(0); + ProcessFramePacketWithAddresses(QuicFrame(&frame), kSelfAddress, + kNewPeerAddress, ENCRYPTION_INITIAL); + EXPECT_FALSE(connection_.connected()); +} + +// Regresstion test for b/175685916 +TEST_P(QuicConnectionTest, TryToFlushAckWithAckQueued) { + if (!version().HasIetfQuicFrames()) { + return; + } + SetQuicReloadableFlag(quic_can_send_ack_frequency, true); + set_perspective(Perspective::IS_SERVER); + + QuicConfig config; + QuicConfigPeer::SetReceivedMinAckDelayMs(&config, /*min_ack_delay_ms=*/1); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.OnHandshakeComplete(); + QuicPacketCreatorPeer::SetPacketNumber(creator_, 200); + + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + // Sending ACK_FREQUENCY bundles ACK. QuicConnectionPeer::SendPing + // will try to bundle ACK but there is no pending ACK. + EXPECT_CALL(visitor_, SendAckFrequency(_)) + .WillOnce(Invoke(¬ifier_, + &SimpleSessionNotifier::WriteOrBufferAckFrequency)); + QuicConnectionPeer::SendPing(&connection_); +} + +TEST_P(QuicConnectionTest, PathChallengeBeforePeerIpAddressChangeAtServer) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + SetClientConnectionId(TestConnectionId(1)); + connection_.CreateConnectionIdManager(); + + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId client_cid0 = connection_.client_connection_id(); + QuicConnectionId client_cid1 = TestConnectionId(2); + QuicConnectionId server_cid1; + // Sends new server CID to client. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + server_cid1 = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + // Receives new client CID from client. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + peer_creator_.SetServerConnectionId(server_cid1); + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + QuicFrames frames1; + frames1.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + QuicPathFrameBuffer payload; + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .Times(AtLeast(1)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_FALSE(writer_->path_response_frames().empty()); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + payload = writer_->path_challenge_frames().front().data_buffer; + })); + ProcessFramesPacketWithAddresses(frames1, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->client_connection_id, client_cid0); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_EQ(alternative_path->client_connection_id, client_cid1); + EXPECT_EQ(alternative_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + // Process another packet with a different peer address on server side will + // start connection migration. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([=]() { + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + })); + // IETF QUIC send algorithm should be changed to a different object, so no + // OnPacketSent() called on the old send algorithm. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .Times(0); + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + EXPECT_TRUE(writer_->path_challenge_frames().empty()); + EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + // Switch to use the mock send algorithm. + send_algorithm_ = new StrictMock(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); + connection_.SetSendAlgorithm(send_algorithm_); + EXPECT_EQ(default_path->client_connection_id, client_cid1); + EXPECT_EQ(default_path->server_connection_id, server_cid1); + // The previous default path is kept as alternative path before reverse path + // validation finishes. + EXPECT_EQ(alternative_path->client_connection_id, client_cid0); + EXPECT_EQ(alternative_path->server_connection_id, server_cid0); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid1); + + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + connection_.active_effective_peer_migration_type()); + EXPECT_EQ(1u, connection_.GetStats() + .num_peer_migration_to_proactively_validated_address); + + // The PATH_CHALLENGE and PATH_RESPONSE is expanded upto the max packet size + // which may exceeds the anti-amplification limit. Verify server is throttled + // by anti-amplification limit. + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Receiving PATH_RESPONSE should lift the anti-amplification limit. + QuicFrames frames3; + frames3.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + EXPECT_CALL(visitor_, MaybeSendAddressToken()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(testing::AtLeast(1u)); + ProcessFramesPacketWithAddresses(frames3, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + // Verify that alternative_path_ is cleared and the peer CID is retired. + EXPECT_TRUE(alternative_path->client_connection_id.IsEmpty()); + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + // Verify the anti-amplification limit is lifted by sending a packet larger + // than the anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + connection_.SendCryptoDataWithString(std::string(1200, 'a'), 0); + EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); + EXPECT_EQ(1u, connection_.num_unlinkable_client_migration()); +} + +TEST_P(QuicConnectionTest, + PathValidationSucceedsBeforePeerIpAddressChangeAtServer) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId server_cid1; + // Sends new server CID to client. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + server_cid1 = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + // Receive probing packet with new peer address. + peer_creator_.SetServerConnectionId(server_cid1); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/23456); + QuicPathFrameBuffer payload; + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_FALSE(writer_->path_response_frames().empty()); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + payload = writer_->path_challenge_frames().front().data_buffer; + })) + .WillRepeatedly(Invoke([&]() { + // Only start reverse path validation once. + EXPECT_TRUE(writer_->path_challenge_frames().empty()); + })); + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + QuicFrames frames1; + frames1.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + ProcessFramesPacketWithAddresses(frames1, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_EQ(alternative_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + // Receive PATH_RESPONSE should mark the new peer address validated. + QuicFrames frames3; + frames3.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + ProcessFramesPacketWithAddresses(frames3, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + + // Process another packet with a newer peer address with the same port will + // start connection migration. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + // IETF QUIC send algorithm should be changed to a different object, so no + // OnPacketSent() called on the old send algorithm. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .Times(0); + const QuicSocketAddress kNewerPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([=]() { + EXPECT_EQ(kNewerPeerAddress, connection_.peer_address()); + })); + EXPECT_CALL(visitor_, MaybeSendAddressToken()); + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewerPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewerPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewerPeerAddress, connection_.effective_peer_address()); + // Since the newer address has the same IP as the previously validated probing + // address. The peer migration becomes validated immediately. + EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + EXPECT_EQ(kNewerPeerAddress, writer_->last_write_peer_address()); + EXPECT_EQ(1u, connection_.GetStats() + .num_peer_migration_to_proactively_validated_address); + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + + EXPECT_EQ(default_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid1); + // Verify that alternative_path_ is cleared. + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + + // Switch to use the mock send algorithm. + send_algorithm_ = new StrictMock(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); + connection_.SetSendAlgorithm(send_algorithm_); + + // Verify the server is not throttled by the anti-amplification limit by + // sending a packet larger than the anti-amplification limit. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)); + connection_.SendCryptoDataWithString(std::string(1200, 'a'), 0); + EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); +} + +// Regression test of b/228645208. +TEST_P(QuicConnectionTest, NoNonProbingFrameOnAlternativePath) { + if (!connection_.connection_migration_use_new_cid()) { + return; + } + + PathProbeTestInit(Perspective::IS_SERVER); + SetClientConnectionId(TestConnectionId(1)); + connection_.CreateConnectionIdManager(); + + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId client_cid0 = connection_.client_connection_id(); + QuicConnectionId client_cid1 = TestConnectionId(2); + QuicConnectionId server_cid1; + // Sends new server CID to client. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { + server_cid1 = cid; + return true; + })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + // Receives new client CID from client. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + peer_creator_.SetServerConnectionId(server_cid1); + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + QuicFrames frames1; + frames1.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + QuicPathFrameBuffer payload; + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .Times(AtLeast(1)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_FALSE(writer_->path_response_frames().empty()); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + payload = writer_->path_challenge_frames().front().data_buffer; + })); + ProcessFramesPacketWithAddresses(frames1, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->client_connection_id, client_cid0); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_EQ(alternative_path->client_connection_id, client_cid1); + EXPECT_EQ(alternative_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + + // Process non-probing packets on the default path. + peer_creator_.SetServerConnectionId(server_cid0); + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillRepeatedly(Invoke([=]() { + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + })); + // Receives packets 3 - 39 to send 19 ACK-only packets, which will force the + // connection to reach |kMaxConsecutiveNonRetransmittablePackets| while + // sending the next ACK. + for (size_t i = 3; i <= 39; ++i) { + ProcessDataPacket(i); + } + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + EXPECT_TRUE(connection_.HasPendingAcks()); + QuicTime ack_time = connection_.GetAckAlarm()->deadline(); + QuicTime path_validation_retry_time = + connection_.GetRetryTimeout(kNewPeerAddress, writer_.get()); + // Advance time to simultaneously fire path validation retry and ACK alarms. + clock_.AdvanceTime(std::max(ack_time, path_validation_retry_time) - + clock_.ApproximateNow()); + + // The 20th ACK should bundle with a WINDOW_UPDATE frame. + EXPECT_CALL(visitor_, OnAckNeedsRetransmittableFrame()) + .WillOnce(Invoke([this]() { + connection_.SendControlFrame(QuicFrame(QuicWindowUpdateFrame(1, 0, 0))); + })); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + // Retry path validation shouldn't bundle ACK. + EXPECT_TRUE(writer_->ack_frames().empty()); + })) + .WillOnce(Invoke([&]() { + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + EXPECT_FALSE(writer_->ack_frames().empty()); + EXPECT_FALSE(writer_->window_update_frames().empty()); + })); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); +} + +TEST_P(QuicConnectionTest, DoNotIssueNewCidIfVisitorSaysNo) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + + connection_.CreateConnectionIdManager(); + + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId client_cid1 = TestConnectionId(2); + QuicConnectionId server_cid1; + // Sends new server CID to client. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)).WillOnce(Return(false)); + if (GetQuicReloadableFlag(quic_check_cid_collision_when_issue_new_cid)) { + EXPECT_CALL(visitor_, SendNewConnectionId(_)).Times(0); + } else { + EXPECT_CALL(visitor_, SendNewConnectionId(_)).Times(1); + } + connection_.MaybeSendConnectionIdToClient(); +} + +TEST_P(QuicConnectionTest, + ProbedOnAnotherPathAfterPeerIpAddressChangeAtServer) { + PathProbeTestInit(Perspective::IS_SERVER); + if (!connection_.validate_client_address()) { + return; + } + + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/23456); + + // Process a packet with a new peer address will start connection migration. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + // IETF QUIC send algorithm should be changed to a different object, so no + // OnPacketSent() called on the old send algorithm. + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, _, _, _, NO_RETRANSMITTABLE_DATA)) + .Times(0); + EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([=]() { + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + })); + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePathValidated(&connection_)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + + // Switch to use the mock send algorithm. + send_algorithm_ = new StrictMock(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); + connection_.SetSendAlgorithm(send_algorithm_); + + // Receive probing packet with a newer peer address shouldn't override the + // on-going path validation. + const QuicSocketAddress kNewerPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(kNewerPeerAddress, writer_->last_write_peer_address()); + EXPECT_FALSE(writer_->path_response_frames().empty()); + EXPECT_TRUE(writer_->path_challenge_frames().empty()); + })); + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + QuicFrames frames1; + frames1.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + ProcessFramesPacketWithAddresses(frames1, kSelfAddress, kNewerPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePathValidated(&connection_)); + EXPECT_TRUE(connection_.HasPendingPathValidation()); +} + +TEST_P(QuicConnectionTest, + PathValidationFailedOnClientDueToLackOfServerConnectionId) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + config.SetConnectionOptionsToSend({kRVCM}); + } + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); + + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + + bool success; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), writer_.get()), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + + EXPECT_FALSE(success); +} + +TEST_P(QuicConnectionTest, + PathValidationFailedOnClientDueToLackOfClientConnectionIdTheSecondTime) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); + SetClientConnectionId(TestConnectionId(1)); + + // Make sure server connection ID is available for the 1st validation. + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId server_cid1 = TestConnectionId(2); + QuicConnectionId server_cid2 = TestConnectionId(4); + QuicConnectionId client_cid1; + QuicNewConnectionIdFrame frame1; + frame1.connection_id = server_cid1; + frame1.sequence_number = 1u; + frame1.retire_prior_to = 0u; + frame1.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame1.connection_id); + connection_.OnNewConnectionIdFrame(frame1); + const auto* packet_creator = + QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), server_cid0); + + // Client will issue a new client connection ID to server. + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + EXPECT_CALL(visitor_, SendNewConnectionId(_)) + .WillOnce(Invoke([&](const QuicNewConnectionIdFrame& frame) { + client_cid1 = frame.connection_id; + })); + + const QuicSocketAddress kSelfAddress1(QuicIpAddress::Any4(), 12345); + ASSERT_NE(kSelfAddress1, connection_.self_address()); + bool success1; + connection_.ValidatePath( + std::make_unique( + kSelfAddress1, connection_.peer_address(), writer_.get()), + std::make_unique( + &connection_, kSelfAddress1, connection_.peer_address(), &success1), + PathValidationReason::kReasonUnknown); + + // Migrate upon 1st validation success. + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + ASSERT_TRUE(connection_.MigratePath(kSelfAddress1, connection_.peer_address(), + &new_writer, /*owns_writer=*/false)); + QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath(&connection_); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + EXPECT_EQ(default_path->client_connection_id, client_cid1); + EXPECT_EQ(default_path->server_connection_id, server_cid1); + EXPECT_EQ(default_path->stateless_reset_token, frame1.stateless_reset_token); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_TRUE(alternative_path->client_connection_id.IsEmpty()); + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), server_cid1); + + // Client will retire server connection ID on old default_path. + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + // Another server connection ID is available to client. + QuicNewConnectionIdFrame frame2; + frame2.connection_id = server_cid2; + frame2.sequence_number = 2u; + frame2.retire_prior_to = 1u; + frame2.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame2.connection_id); + connection_.OnNewConnectionIdFrame(frame2); + + const QuicSocketAddress kSelfAddress2(QuicIpAddress::Loopback4(), + /*port=*/45678); + bool success2; + connection_.ValidatePath( + std::make_unique( + kSelfAddress2, connection_.peer_address(), writer_.get()), + std::make_unique( + &connection_, kSelfAddress2, connection_.peer_address(), &success2), + PathValidationReason::kReasonUnknown); + // Since server does not retire any client connection ID yet, 2nd validation + // would fail due to lack of client connection ID. + EXPECT_FALSE(success2); +} + +TEST_P(QuicConnectionTest, ServerConnectionIdRetiredUponPathValidationFailure) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT); + + // Make sure server connection ID is available for validation. + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(2); + frame.sequence_number = 1u; + frame.retire_prior_to = 0u; + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + connection_.OnNewConnectionIdFrame(frame); + + const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + bool success; + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, connection_.peer_address(), writer_.get()), + std::make_unique( + &connection_, kNewSelfAddress, connection_.peer_address(), &success), + PathValidationReason::kReasonUnknown); + + auto* path_validator = QuicConnectionPeer::path_validator(&connection_); + path_validator->CancelPathValidation(); + QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath(&connection_); + EXPECT_FALSE(success); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_TRUE(alternative_path->client_connection_id.IsEmpty()); + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + + // Client will retire server connection ID on alternative_path. + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/1u)); + retire_peer_issued_cid_alarm->Fire(); +} + +TEST_P(QuicConnectionTest, + MigratePathDirectlyFailedDueToLackOfServerConnectionId) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); + const QuicSocketAddress kSelfAddress1(QuicIpAddress::Any4(), 12345); + ASSERT_NE(kSelfAddress1, connection_.self_address()); + + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + ASSERT_FALSE(connection_.MigratePath(kSelfAddress1, + connection_.peer_address(), &new_writer, + /*owns_writer=*/false)); +} + +TEST_P(QuicConnectionTest, + MigratePathDirectlyFailedDueToLackOfClientConnectionIdTheSecondTime) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { + return; + } + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); + SetClientConnectionId(TestConnectionId(1)); + + // Make sure server connection ID is available for the 1st migration. + QuicNewConnectionIdFrame frame1; + frame1.connection_id = TestConnectionId(2); + frame1.sequence_number = 1u; + frame1.retire_prior_to = 0u; + frame1.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame1.connection_id); + connection_.OnNewConnectionIdFrame(frame1); + + // Client will issue a new client connection ID to server. + QuicConnectionId new_client_connection_id; + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + EXPECT_CALL(visitor_, SendNewConnectionId(_)) + .WillOnce(Invoke([&](const QuicNewConnectionIdFrame& frame) { + new_client_connection_id = frame.connection_id; + })); + + // 1st migration is successful. + const QuicSocketAddress kSelfAddress1(QuicIpAddress::Any4(), 12345); + ASSERT_NE(kSelfAddress1, connection_.self_address()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + ASSERT_TRUE(connection_.MigratePath(kSelfAddress1, connection_.peer_address(), + &new_writer, + /*owns_writer=*/false)); + QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath(&connection_); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + EXPECT_EQ(default_path->client_connection_id, new_client_connection_id); + EXPECT_EQ(default_path->server_connection_id, frame1.connection_id); + EXPECT_EQ(default_path->stateless_reset_token, frame1.stateless_reset_token); + + // Client will retire server connection ID on old default_path. + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + // Another server connection ID is available to client. + QuicNewConnectionIdFrame frame2; + frame2.connection_id = TestConnectionId(4); + frame2.sequence_number = 2u; + frame2.retire_prior_to = 1u; + frame2.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame2.connection_id); + connection_.OnNewConnectionIdFrame(frame2); + + // Since server does not retire any client connection ID yet, 2nd migration + // would fail due to lack of client connection ID. + const QuicSocketAddress kSelfAddress2(QuicIpAddress::Loopback4(), + /*port=*/45678); + auto new_writer2 = std::make_unique(version(), &clock_, + Perspective::IS_CLIENT); + ASSERT_FALSE(connection_.MigratePath( + kSelfAddress2, connection_.peer_address(), new_writer2.release(), + /*owns_writer=*/true)); +} + +TEST_P(QuicConnectionTest, + CloseConnectionAfterReceiveNewConnectionIdFromPeerUsingEmptyCID) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + ASSERT_TRUE(connection_.client_connection_id().IsEmpty()); + + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = TestConnectionId(1); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + + EXPECT_FALSE(connection_.OnNewConnectionIdFrame(frame)); + + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_P(QuicConnectionTest, NewConnectionIdFrameResultsInError) { + if (!version().HasIetfQuicFrames()) { + return; + } + connection_.CreateConnectionIdManager(); + ASSERT_FALSE(connection_.connection_id().IsEmpty()); + + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = connection_id_; // Reuses connection ID casuing error. + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + + EXPECT_FALSE(connection_.OnNewConnectionIdFrame(frame)); + + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_P(QuicConnectionTest, + ClientRetirePeerIssuedConnectionIdTriggeredByNewConnectionIdFrame) { + if (!version().HasIetfQuicFrames()) { + return; + } + connection_.CreateConnectionIdManager(); + + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = TestConnectionId(1); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_FALSE(retire_peer_issued_cid_alarm->IsSet()); + + frame.sequence_number = 2u; + frame.connection_id = TestConnectionId(2); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 1u; // CID associated with #1 will be retired. + + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_EQ(connection_.connection_id(), connection_id_); + + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_EQ(connection_.connection_id(), TestConnectionId(2)); + EXPECT_EQ(connection_.packet_creator().GetDestinationConnectionId(), + TestConnectionId(2)); +} + +TEST_P(QuicConnectionTest, + ServerRetirePeerIssuedConnectionIdTriggeredByNewConnectionIdFrame) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + SetClientConnectionId(TestConnectionId(0)); + + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = TestConnectionId(1); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_FALSE(retire_peer_issued_cid_alarm->IsSet()); + + frame.sequence_number = 2u; + frame.connection_id = TestConnectionId(2); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 1u; // CID associated with #1 will be retired. + + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_EQ(connection_.client_connection_id(), TestConnectionId(0)); + + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_EQ(connection_.client_connection_id(), TestConnectionId(2)); + EXPECT_EQ(connection_.packet_creator().GetDestinationConnectionId(), + TestConnectionId(2)); +} + +TEST_P( + QuicConnectionTest, + ReplacePeerIssuedConnectionIdOnBothPathsTriggeredByNewConnectionIdFrame) { + if (!version().HasIetfQuicFrames()) { + return; + } + PathProbeTestInit(Perspective::IS_SERVER); + SetClientConnectionId(TestConnectionId(0)); + + // Populate alternative_path_ with probing packet. + std::unique_ptr probing_packet = ConstructProbingPacket(); + + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + QuicIpAddress new_host; + new_host.FromString("1.1.1.1"); + ProcessReceivedPacket(kSelfAddress, + QuicSocketAddress(new_host, /*port=*/23456), *received); + + EXPECT_EQ( + TestConnectionId(0), + QuicConnectionPeer::GetClientConnectionIdOnAlternativePath(&connection_)); + + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = TestConnectionId(1); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_FALSE(retire_peer_issued_cid_alarm->IsSet()); + + frame.sequence_number = 2u; + frame.connection_id = TestConnectionId(2); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 1u; // CID associated with #1 will be retired. + + EXPECT_TRUE(connection_.OnNewConnectionIdFrame(frame)); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_EQ(connection_.client_connection_id(), TestConnectionId(0)); + + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_EQ(connection_.client_connection_id(), TestConnectionId(2)); + EXPECT_EQ(connection_.packet_creator().GetDestinationConnectionId(), + TestConnectionId(2)); + // Clean up alternative path connection ID. + EXPECT_EQ( + TestConnectionId(2), + QuicConnectionPeer::GetClientConnectionIdOnAlternativePath(&connection_)); +} + +TEST_P(QuicConnectionTest, + CloseConnectionAfterReceiveRetireConnectionIdWhenNoCIDIssued) { + if (!version().HasIetfQuicFrames() || + !connection_.connection_migration_use_new_cid()) { + return; + } + set_perspective(Perspective::IS_SERVER); + + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + QuicRetireConnectionIdFrame frame; + frame.sequence_number = 1u; + + EXPECT_FALSE(connection_.OnRetireConnectionIdFrame(frame)); + + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_P(QuicConnectionTest, RetireConnectionIdFrameResultsInError) { + if (!version().HasIetfQuicFrames() || + !connection_.connection_migration_use_new_cid()) { + return; + } + set_perspective(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) + .WillOnce(Invoke(this, &QuicConnectionTest::SaveConnectionCloseFrame)); + QuicRetireConnectionIdFrame frame; + frame.sequence_number = 2u; // The corresponding ID is never issued. + + EXPECT_FALSE(connection_.OnRetireConnectionIdFrame(frame)); + + EXPECT_FALSE(connection_.connected()); + EXPECT_THAT(saved_connection_close_frame_.quic_error_code, + IsError(IETF_QUIC_PROTOCOL_VIOLATION)); +} + +TEST_P(QuicConnectionTest, + ServerRetireSelfIssuedConnectionIdWithoutSendingNewConnectionIdBefore) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + + auto* retire_self_issued_cid_alarm = + connection_.GetRetireSelfIssuedConnectionIdAlarm(); + ASSERT_FALSE(retire_self_issued_cid_alarm->IsSet()); + + QuicConnectionId cid0 = connection_id_; + QuicRetireConnectionIdFrame frame; + frame.sequence_number = 0u; + if (connection_.connection_migration_use_new_cid()) { + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(cid0)) + .WillOnce(Return(TestConnectionId(456))); + EXPECT_CALL(connection_id_generator_, + GenerateNextConnectionId(TestConnectionId(456))) + .WillOnce(Return(TestConnectionId(789))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .Times(2) + .WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, SendNewConnectionId(_)).Times(2); + } + EXPECT_TRUE(connection_.OnRetireConnectionIdFrame(frame)); +} + +TEST_P(QuicConnectionTest, ServerRetireSelfIssuedConnectionId) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { + return; + } + set_perspective(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + QuicConnectionId recorded_cid; + auto cid_recorder = [&recorded_cid](const QuicConnectionId& cid) -> bool { + recorded_cid = cid; + return true; + }; + QuicConnectionId cid0 = connection_id_; + QuicConnectionId cid1; + QuicConnectionId cid2; + EXPECT_EQ(connection_.connection_id(), cid0); + EXPECT_EQ(connection_.GetOneActiveServerConnectionId(), cid0); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke(cid_recorder)); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + cid1 = recorded_cid; + + auto* retire_self_issued_cid_alarm = + connection_.GetRetireSelfIssuedConnectionIdAlarm(); + ASSERT_FALSE(retire_self_issued_cid_alarm->IsSet()); + + // Generate three packets with different connection IDs that will arrive out + // of order (2, 1, 3) later. + char buffers[3][kMaxOutgoingPacketSize]; + // Destination connection ID of packet1 is cid0. + auto packet1 = + ConstructPacket({QuicFrame(QuicPingFrame())}, ENCRYPTION_FORWARD_SECURE, + buffers[0], kMaxOutgoingPacketSize); + peer_creator_.SetServerConnectionId(cid1); + auto retire_cid_frame = std::make_unique(); + retire_cid_frame->sequence_number = 0u; + // Destination connection ID of packet2 is cid1. + auto packet2 = ConstructPacket({QuicFrame(retire_cid_frame.release())}, + ENCRYPTION_FORWARD_SECURE, buffers[1], + kMaxOutgoingPacketSize); + // Destination connection ID of packet3 is cid1. + auto packet3 = + ConstructPacket({QuicFrame(QuicPingFrame())}, ENCRYPTION_FORWARD_SECURE, + buffers[2], kMaxOutgoingPacketSize); + + // Packet2 with RetireConnectionId frame trigers sending NewConnectionId + // immediately. + if (!connection_.connection_id().IsEmpty()) { + EXPECT_CALL(connection_id_generator_, GenerateNextConnectionId(_)) + .WillOnce(Return(TestConnectionId(456))); + } + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)) + .WillOnce(Invoke(cid_recorder)); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + peer_creator_.SetServerConnectionId(cid1); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *packet2); + cid2 = recorded_cid; + // cid0 is not retired immediately. + EXPECT_THAT(connection_.GetActiveServerConnectionIds(), + ElementsAre(cid0, cid1, cid2)); + ASSERT_TRUE(retire_self_issued_cid_alarm->IsSet()); + EXPECT_EQ(connection_.connection_id(), cid1); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid0 || + connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); + + // Packet1 updates the connection ID on the default path but not the active + // connection ID. + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *packet1); + EXPECT_EQ(connection_.connection_id(), cid0); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid0 || + connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); + + // cid0 is retired when the retire CID alarm fires. + EXPECT_CALL(visitor_, OnServerConnectionIdRetired(cid0)); + retire_self_issued_cid_alarm->Fire(); + EXPECT_THAT(connection_.GetActiveServerConnectionIds(), + ElementsAre(cid1, cid2)); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); + + // Packet3 updates the connection ID on the default path. + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *packet3); + EXPECT_EQ(connection_.connection_id(), cid1); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); +} + +TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoAlternativePath) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + connection_.set_client_connection_id(TestConnectionId(1)); + + // Set up the state after path probing. + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + auto* alternative_path = QuicConnectionPeer::GetAlternativePath(&connection_); + QuicIpAddress new_host; + new_host.FromString("12.12.12.12"); + alternative_path->self_address = default_path->self_address; + alternative_path->peer_address = QuicSocketAddress(new_host, 12345); + alternative_path->server_connection_id = TestConnectionId(3); + ASSERT_TRUE(alternative_path->client_connection_id.IsEmpty()); + ASSERT_FALSE(alternative_path->stateless_reset_token.has_value()); + + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = TestConnectionId(5); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + // New ID is patched onto the alternative path when the needed + // NEW_CONNECTION_ID frame is received after PATH_CHALLENGE frame. + connection_.OnNewConnectionIdFrame(frame); + + ASSERT_EQ(alternative_path->client_connection_id, frame.connection_id); + ASSERT_EQ(alternative_path->stateless_reset_token, + frame.stateless_reset_token); +} + +TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoDefaultPath) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + connection_.set_client_connection_id(TestConnectionId(1)); + + // Set up the state after peer migration without probing. + auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + auto* alternative_path = QuicConnectionPeer::GetAlternativePath(&connection_); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + *alternative_path = std::move(*default_path); + QuicIpAddress new_host; + new_host.FromString("12.12.12.12"); + default_path->self_address = default_path->self_address; + default_path->peer_address = QuicSocketAddress(new_host, 12345); + default_path->server_connection_id = TestConnectionId(3); + packet_creator->SetDefaultPeerAddress(default_path->peer_address); + packet_creator->SetServerConnectionId(default_path->server_connection_id); + packet_creator->SetClientConnectionId(default_path->client_connection_id); + + ASSERT_FALSE(default_path->validated); + ASSERT_TRUE(default_path->client_connection_id.IsEmpty()); + ASSERT_FALSE(default_path->stateless_reset_token.has_value()); + + QuicNewConnectionIdFrame frame; + frame.sequence_number = 1u; + frame.connection_id = TestConnectionId(5); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + // New ID is patched onto the default path when the needed + // NEW_CONNECTION_ID frame is received after PATH_CHALLENGE frame. + connection_.OnNewConnectionIdFrame(frame); + + ASSERT_EQ(default_path->client_connection_id, frame.connection_id); + ASSERT_EQ(default_path->stateless_reset_token, frame.stateless_reset_token); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), frame.connection_id); +} + +TEST_P(QuicConnectionTest, ShouldGeneratePacketBlockedByMissingConnectionId) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + connection_.set_client_connection_id(TestConnectionId(1)); + connection_.CreateConnectionIdManager(); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + + ASSERT_TRUE( + connection_.ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, NOT_HANDSHAKE)); + + QuicPacketCreator* packet_creator = + QuicConnectionPeer::GetPacketCreator(&connection_); + QuicIpAddress peer_host1; + peer_host1.FromString("12.12.12.12"); + QuicSocketAddress peer_address1(peer_host1, 1235); + + { + // No connection ID is available as context is created without any. + QuicPacketCreator::ScopedPeerAddressContext context( + packet_creator, peer_address1, EmptyQuicConnectionId(), + EmptyQuicConnectionId(), + /*update_connection_id=*/true); + ASSERT_FALSE(connection_.ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)); + } + ASSERT_TRUE( + connection_.ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, NOT_HANDSHAKE)); +} + +// Regression test for b/182571515 +TEST_P(QuicConnectionTest, LostDataThenGetAcknowledged) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.validate_client_address() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + if (version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(&connection_); + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + QuicPacketNumber last_packet; + // Send packets 1 to 4. + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &last_packet); // Packet 1 + SendStreamDataToPeer(3, "foo", 3, NO_FIN, &last_packet); // Packet 2 + SendStreamDataToPeer(3, "foo", 6, NO_FIN, &last_packet); // Packet 3 + SendStreamDataToPeer(3, "foo", 9, NO_FIN, &last_packet); // Packet 4 + + // Process a PING packet to set peer address. + ProcessFramePacket(QuicFrame(QuicPingFrame())); + + // Process a packet containing a STREAM_FRAME and an ACK with changed peer + // address. + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + QuicAckFrame ack = InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(5)}}); + frames.push_back(QuicFrame(&ack)); + + // Invoke OnCanWrite. + QuicIpAddress ip_address; + ASSERT_TRUE(ip_address.FromString("127.0.52.223")); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(visitor_, OnConnectionMigration(_)).Times(1); + EXPECT_CALL(visitor_, OnStreamFrame(_)) + .WillOnce(InvokeWithoutArgs(¬ifier_, + &SimpleSessionNotifier::OnCanWrite)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, + QuicSocketAddress(ip_address, 1000), + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + + // Verify stream frame will not be retransmitted. + EXPECT_TRUE(writer_->stream_frames().empty()); + }, + "Try to write mid packet processing"); +} + +TEST_P(QuicConnectionTest, PtoSendStreamData) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + } + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send INITIAL 1. + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + // Send HANDSHAKE packets. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + // Send half RTT packet with congestion control blocked. + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(false)); + connection_.SendStreamDataWithString(2, std::string(1500, 'a'), 0, NO_FIN); + + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify INITIAL and HANDSHAKE get retransmitted. + EXPECT_EQ(0x01010101u, writer_->final_bytes_of_last_packet()); +} + +TEST_P(QuicConnectionTest, SendingZeroRttPacketsDoesNotPostponePTO) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send CHLO. + connection_.SendCryptoStreamData(); + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + // Install 0-RTT keys. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + + // CHLO gets acknowledged after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + QuicAckFrame frame1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // Send 0-RTT packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + // PTO deadline should be unchanged. + EXPECT_EQ(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, QueueingUndecryptablePacketsDoesntPostponePTO) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(3); + connection_.SetFromConfig(config); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + // Send CHLO. + connection_.SendCryptoStreamData(); + + // Send 0-RTT packet. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + + // CHLO gets acknowledged after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + QuicAckFrame frame1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // Receive an undecryptable packets. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0xFF)); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + // Verify PTO deadline is sooner. + EXPECT_GT(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); + pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // PTO fires. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + clock_.AdvanceTime(pto_deadline - clock_.ApproximateNow()); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // Verify PTO deadline does not change. + ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, QueueUndecryptableHandshakePackets) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(3); + connection_.SetFromConfig(config); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.RemoveDecrypter(ENCRYPTION_HANDSHAKE); + // Send CHLO. + connection_.SendCryptoStreamData(); + + // Send 0-RTT packet. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + + // Receive an undecryptable handshake packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0xFF)); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + // Verify this handshake packet gets queued. + EXPECT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); +} + +TEST_P(QuicConnectionTest, PingNotSentAt0RTTLevelWhenInitialAvailable) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send CHLO. + connection_.SendCryptoStreamData(); + // Send 0-RTT packet. + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + + // CHLO gets acknowledged after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + QuicAckFrame frame1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // PTO fires. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + clock_.AdvanceTime(pto_deadline - clock_.ApproximateNow()); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify the PING gets sent in ENCRYPTION_INITIAL. + EXPECT_NE(0x02020202u, writer_->final_bytes_of_last_packet()); +} + +TEST_P(QuicConnectionTest, AckElicitingFrames) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option_v2)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!version().HasIetfQuicFrames() || + !connection_.connection_migration_use_new_cid()) { + return; + } + EXPECT_CALL(connection_id_generator_, + GenerateNextConnectionId(TestConnectionId(12))) + .WillOnce(Return(TestConnectionId(456))); + EXPECT_CALL(connection_id_generator_, + GenerateNextConnectionId(TestConnectionId(456))) + .WillOnce(Return(TestConnectionId(789))); + EXPECT_CALL(visitor_, SendNewConnectionId(_)).Times(2); + EXPECT_CALL(visitor_, OnRstStream(_)); + EXPECT_CALL(visitor_, OnWindowUpdateFrame(_)); + EXPECT_CALL(visitor_, OnBlockedFrame(_)); + EXPECT_CALL(visitor_, OnHandshakeDoneReceived()); + EXPECT_CALL(visitor_, OnStreamFrame(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + EXPECT_CALL(visitor_, OnMaxStreamsFrame(_)); + EXPECT_CALL(visitor_, OnStreamsBlockedFrame(_)); + EXPECT_CALL(visitor_, OnStopSendingFrame(_)); + EXPECT_CALL(visitor_, OnMessageReceived("")); + EXPECT_CALL(visitor_, OnNewTokenReceived("")); + + SetClientConnectionId(TestConnectionId(12)); + connection_.CreateConnectionIdManager(); + QuicConnectionPeer::GetSelfIssuedConnectionIdManager(&connection_) + ->MaybeSendNewConnectionIds(); + connection_.set_can_receive_ack_frequency_frame(); + + QuicAckFrame ack_frame = InitAckFrame(1); + QuicRstStreamFrame rst_stream_frame; + QuicWindowUpdateFrame window_update_frame; + QuicPathChallengeFrame path_challenge_frame; + QuicNewConnectionIdFrame new_connection_id_frame; + QuicRetireConnectionIdFrame retire_connection_id_frame; + retire_connection_id_frame.sequence_number = 1u; + QuicStopSendingFrame stop_sending_frame; + QuicPathResponseFrame path_response_frame; + QuicMessageFrame message_frame; + QuicNewTokenFrame new_token_frame; + QuicAckFrequencyFrame ack_frequency_frame; + QuicBlockedFrame blocked_frame; + size_t packet_number = 1; + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + for (uint8_t i = 0; i < NUM_FRAME_TYPES; ++i) { + QuicFrameType frame_type = static_cast(i); + bool skipped = false; + QuicFrame frame; + QuicFrames frames; + // Add some padding to fullfill the min size requirement of header + // protection. + frames.push_back(QuicFrame(QuicPaddingFrame(10))); + switch (frame_type) { + case PADDING_FRAME: + frame = QuicFrame(QuicPaddingFrame(10)); + break; + case MTU_DISCOVERY_FRAME: + frame = QuicFrame(QuicMtuDiscoveryFrame()); + break; + case PING_FRAME: + frame = QuicFrame(QuicPingFrame()); + break; + case MAX_STREAMS_FRAME: + frame = QuicFrame(QuicMaxStreamsFrame()); + break; + case STOP_WAITING_FRAME: + // Not supported. + skipped = true; + break; + case STREAMS_BLOCKED_FRAME: + frame = QuicFrame(QuicStreamsBlockedFrame()); + break; + case STREAM_FRAME: + frame = QuicFrame(QuicStreamFrame()); + break; + case HANDSHAKE_DONE_FRAME: + frame = QuicFrame(QuicHandshakeDoneFrame()); + break; + case ACK_FRAME: + frame = QuicFrame(&ack_frame); + break; + case RST_STREAM_FRAME: + frame = QuicFrame(&rst_stream_frame); + break; + case CONNECTION_CLOSE_FRAME: + // Do not test connection close. + skipped = true; + break; + case GOAWAY_FRAME: + // Does not exist in IETF QUIC. + skipped = true; + break; + case BLOCKED_FRAME: + frame = QuicFrame(blocked_frame); + break; + case WINDOW_UPDATE_FRAME: + frame = QuicFrame(window_update_frame); + break; + case PATH_CHALLENGE_FRAME: + frame = QuicFrame(path_challenge_frame); + break; + case STOP_SENDING_FRAME: + frame = QuicFrame(stop_sending_frame); + break; + case NEW_CONNECTION_ID_FRAME: + frame = QuicFrame(&new_connection_id_frame); + break; + case RETIRE_CONNECTION_ID_FRAME: + frame = QuicFrame(&retire_connection_id_frame); + break; + case PATH_RESPONSE_FRAME: + frame = QuicFrame(path_response_frame); + break; + case MESSAGE_FRAME: + frame = QuicFrame(&message_frame); + break; + case CRYPTO_FRAME: + // CRYPTO_FRAME is ack eliciting is covered by other tests. + skipped = true; + break; + case NEW_TOKEN_FRAME: + frame = QuicFrame(&new_token_frame); + break; + case ACK_FREQUENCY_FRAME: + frame = QuicFrame(&ack_frequency_frame); + break; + case NUM_FRAME_TYPES: + skipped = true; + break; + } + if (skipped) { + continue; + } + ASSERT_EQ(frame_type, frame.type); + frames.push_back(frame); + EXPECT_FALSE(connection_.HasPendingAcks()); + // Process frame. + ProcessFramesPacketAtLevel(packet_number++, frames, + ENCRYPTION_FORWARD_SECURE); + if (QuicUtils::IsAckElicitingFrame(frame_type)) { + ASSERT_TRUE(connection_.HasPendingAcks()) << frame; + // Flush ACK. + clock_.AdvanceTime(DefaultDelayedAckTime()); + connection_.GetAckAlarm()->Fire(); + } + EXPECT_FALSE(connection_.HasPendingAcks()); + ASSERT_TRUE(connection_.connected()); + } +} + +TEST_P(QuicConnectionTest, ReceivedChloAndAck) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + QuicFrames frames; + QuicAckFrame ack_frame = InitAckFrame(1); + frames.push_back(MakeCryptoFrame()); + frames.push_back(QuicFrame(&ack_frame)); + + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .WillOnce(IgnoreResult(InvokeWithoutArgs( + &connection_, &TestConnection::SendCryptoStreamData))); + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_INITIAL); +} + +// Regression test for b/201643321. +TEST_P(QuicConnectionTest, FailedToRetransmitShlo) { + if (!version().HasIetfQuicFrames() || + GetQuicFlag(quic_enforce_strict_amplification_factor)) { + return; + } + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Received INITIAL 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + // Received ENCRYPTION_ZERO_RTT 1. + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Send INITIAL 1. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Send half RTT data to exhaust amplification credit. + connection_.SendStreamDataWithString(0, std::string(100 * 1024, 'a'), 0, + NO_FIN); + } + // Received INITIAL 2. + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + ASSERT_TRUE(connection_.HasPendingAcks()); + // Verify ACK delay is 1ms. + EXPECT_EQ(clock_.Now() + kAlarmGranularity, + connection_.GetAckAlarm()->deadline()); + // ACK is not throttled by amplification limit, and SHLO is bundled. Also + // HANDSHAKE + 1RTT packets get coalesced. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(3); + // ACK alarm fires. + clock_.AdvanceTime(kAlarmGranularity); + connection_.GetAckAlarm()->Fire(); + // Verify 1-RTT packet is coalesced. + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + // Only the first packet in the coalesced packet has been processed, + // verify SHLO is bundled with INITIAL ACK. + EXPECT_EQ(1u, writer_->ack_frames().size()); + EXPECT_EQ(1u, writer_->crypto_frames().size()); + // Process the coalesced HANDSHAKE packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(0u, writer_->ack_frames().size()); + EXPECT_EQ(1u, writer_->crypto_frames().size()); + // Process the coalesced 1-RTT packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(0u, writer_->crypto_frames().size()); + EXPECT_EQ(1u, writer_->stream_frames().size()); + + // Received INITIAL 3. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + ProcessCryptoPacketAtLevel(3, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); +} + +// Regression test for b/216133388. +TEST_P(QuicConnectionTest, FailedToConsumeCryptoData) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + // Received INITIAL 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + // Received ENCRYPTION_ZERO_RTT 1. + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Send INITIAL 1. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString(std::string(200, 'a'), 0, + ENCRYPTION_HANDSHAKE); + // Send 1-RTT 3. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SendStreamDataWithString(0, std::string(40, 'a'), 0, NO_FIN); + } + // Received HANDSHAKE Ping, hence discard INITIAL keys. + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x03)); + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_HANDSHAKE); + clock_.AdvanceTime(kAlarmGranularity); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Sending this 1-RTT data would leave the coalescer only have space to + // accommodate the HANDSHAKE ACK. The crypto data cannot be bundled with the + // ACK. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SendStreamDataWithString(0, std::string(1395, 'a'), 40, NO_FIN); + } + // Verify retransmission alarm is armed. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + const QuicTime retransmission_time = + connection_.GetRetransmissionAlarm()->deadline(); + clock_.AdvanceTime(retransmission_time - clock_.Now()); + connection_.GetRetransmissionAlarm()->Fire(); + + // Verify the retransmission is a coalesced packet with HANDSHAKE 2 and + // 1-RTT 3. + EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + // Only the first packet in the coalesced packet has been processed. + EXPECT_EQ(1u, writer_->crypto_frames().size()); + // Process the coalesced 1-RTT packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(1u, writer_->stream_frames().size()); + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); + // Verify retransmission alarm is still armed. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, + RTTSampleDoesNotIncludeQueuingDelayWithPostponedAckProcessing) { + // An endpoint might postpone the processing of ACK when the corresponding + // decryption key is not available. This test makes sure the RTT sample does + // not include the queuing delay. + if (!version().HasIetfQuicFrames()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(3); + connection_.SetFromConfig(config); + + // 30ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(30); + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(kTestRTT, QuicTime::Delta::Zero(), QuicTime::Zero()); + + // Send 0-RTT packet. + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + connection_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(0, std::string(10, 'a'), 0, FIN); + + // Receives 1-RTT ACK for 0-RTT packet after RTT + ack_delay. + clock_.AdvanceTime( + kTestRTT + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs)); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + QuicAckFrame ack_frame = InitAckFrame(1); + // Peer reported ACK delay. + ack_frame.ack_delay_time = + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + QuicFrames frames; + frames.push_back(QuicFrame(&ack_frame)); + QuicPacketHeader header = + ConstructPacketHeader(30, ENCRYPTION_FORWARD_SECURE); + std::unique_ptr packet(ConstructPacket(header, frames)); + + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(30), *packet, buffer, + kMaxOutgoingPacketSize); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + ASSERT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + + // Assume 1-RTT decrypter is available after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + EXPECT_FALSE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + ASSERT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)); + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + // Verify RTT sample does not include queueing delay. + EXPECT_EQ(rtt_stats->latest_rtt(), kTestRTT); +} + +// Regression test for b/112480134. +TEST_P(QuicConnectionTest, NoExtraPaddingInReserializedInitial) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration() || + !connection_.version().CanSendCoalescedPackets()) { + return; + } + + set_perspective(Perspective::IS_SERVER); + MockQuicConnectionDebugVisitor debug_visitor; + connection_.set_debug_visitor(&debug_visitor); + + uint64_t debug_visitor_sent_count = 0; + EXPECT_CALL(debug_visitor, OnPacketSent(_, _, _, _, _, _, _, _)) + .WillRepeatedly([&]() { debug_visitor_sent_count++; }); + + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + + // Received INITIAL 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + + // Received ENCRYPTION_ZERO_RTT 2. + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Send INITIAL 1. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString(std::string(200, 'a'), 0, + ENCRYPTION_HANDSHAKE); + // Send 1-RTT 3. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SendStreamDataWithString(0, std::string(400, 'b'), 0, NO_FIN); + } + + // Arrange the stream data to be sent in response to ENCRYPTION_INITIAL 3. + const std::string data4(1000, '4'); // Data to send in stream id 4 + const std::string data8(3000, '8'); // Data to send in stream id 8 + EXPECT_CALL(visitor_, OnCanWrite()).WillOnce([&]() { + connection_.producer()->SaveStreamData(4, data4); + connection_.producer()->SaveStreamData(8, data8); + + notifier_.WriteOrBufferData(4, data4.size(), FIN_AND_PADDING); + + // This should trigger FlushCoalescedPacket. + notifier_.WriteOrBufferData(8, data8.size(), FIN); + }); + + QuicByteCount pending_padding_after_serialize_2nd_1rtt_packet = 0; + QuicPacketCount num_1rtt_packets_serialized = 0; + EXPECT_CALL(connection_, OnSerializedPacket(_)) + .WillRepeatedly([&](SerializedPacket packet) { + if (packet.encryption_level == ENCRYPTION_FORWARD_SECURE) { + num_1rtt_packets_serialized++; + if (num_1rtt_packets_serialized == 2) { + pending_padding_after_serialize_2nd_1rtt_packet = + connection_.packet_creator().pending_padding_bytes(); + } + } + connection_.QuicConnection::OnSerializedPacket(std::move(packet)); + }); + + // Server receives INITIAL 3, this will serialzie FS 7 (stream 4, stream 8), + // which will trigger a flush of a coalesced packet consists of INITIAL 4, + // HS 5 and FS 6 (stream 4). + + // Expect no QUIC_BUG. + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_INITIAL); + EXPECT_EQ( + debug_visitor_sent_count, + connection_.sent_packet_manager().GetLargestSentPacket().ToUint64()); + + // The error only happens if after serializing the second 1RTT packet(pkt #7), + // the pending padding bytes is non zero. + EXPECT_GT(pending_padding_after_serialize_2nd_1rtt_packet, 0u); + EXPECT_TRUE(connection_.connected()); +} + +TEST_P(QuicConnectionTest, ReportedAckDelayIncludesQueuingDelay) { + if (!version().HasIetfQuicFrames()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(3); + connection_.SetFromConfig(config); + + // Receive 1-RTT ack-eliciting packet while keys are not available. + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + peer_framer_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + QuicFrames frames; + frames.push_back(QuicFrame(QuicPingFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(100))); + QuicPacketHeader header = + ConstructPacketHeader(30, ENCRYPTION_FORWARD_SECURE); + std::unique_ptr packet(ConstructPacket(header, frames)); + + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + ENCRYPTION_FORWARD_SECURE, QuicPacketNumber(30), *packet, buffer, + kMaxOutgoingPacketSize); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + const QuicTime packet_receipt_time = clock_.Now(); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(buffer, encrypted_length, clock_.Now(), false)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + ASSERT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + // 1-RTT keys become available after 10ms. + const QuicTime::Delta kQueuingDelay = QuicTime::Delta::FromMilliseconds(10); + clock_.AdvanceTime(kQueuingDelay); + EXPECT_FALSE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + ASSERT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + ASSERT_TRUE(connection_.HasPendingAcks()); + EXPECT_EQ(packet_receipt_time + DefaultDelayedAckTime(), + connection_.GetAckAlarm()->deadline()); + clock_.AdvanceTime(packet_receipt_time + DefaultDelayedAckTime() - + clock_.Now()); + // Fire ACK alarm. + connection_.GetAckAlarm()->Fire(); + ASSERT_EQ(1u, writer_->ack_frames().size()); + // Verify ACK delay time does not include queuing delay. + EXPECT_EQ(DefaultDelayedAckTime(), writer_->ack_frames()[0].ack_delay_time); +} + +TEST_P(QuicConnectionTest, CoalesceOneRTTPacketWithInitialAndHandshakePackets) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + + // Received INITIAL 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + + peer_framer_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + + // Received ENCRYPTION_ZERO_RTT 2. + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Send INITIAL 1. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString(std::string(200, 'a'), 0, + ENCRYPTION_HANDSHAKE); + // Send 1-RTT data. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SendStreamDataWithString(0, std::string(2000, 'b'), 0, FIN); + } + // Verify coalesced packet [INITIAL 1 + HANDSHAKE 2 + part of 1-RTT data] + + // rest of 1-RTT data get sent. + EXPECT_EQ(2u, writer_->packets_write_attempts()); + + // Received ENCRYPTION_INITIAL 3. + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_INITIAL); + + // Verify a coalesced packet gets sent. + EXPECT_EQ(3u, writer_->packets_write_attempts()); + + // Only the first INITIAL packet has been processed yet. + EXPECT_EQ(1u, writer_->ack_frames().size()); + EXPECT_EQ(1u, writer_->crypto_frames().size()); + + // Process HANDSHAKE packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(1u, writer_->crypto_frames().size()); + // Process 1-RTT packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(1u, writer_->stream_frames().size()); +} + +// Regression test for b/180103273 +TEST_P(QuicConnectionTest, SendMultipleConnectionCloses) { + if (!version().HasIetfQuicFrames() || + !GetQuicReloadableFlag(quic_default_enable_5rto_blackhole_detection2)) { + return; + } + set_perspective(Perspective::IS_SERVER); + // Finish handshake. + QuicConnectionPeer::SetAddressValidated(&connection_); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + notifier_.NeuterUnencryptedData(); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.RemoveEncrypter(ENCRYPTION_HANDSHAKE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); + ASSERT_TRUE(connection_.BlackholeDetectionInProgress()); + // Verify that BeforeConnectionCloseSent() gets called twice, + // while OnConnectionClosed() is called only once. + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()).Times(2); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + // Send connection close w/o closing connection. + QuicConnectionPeer::SendConnectionClosePacket( + &connection_, INTERNAL_ERROR, QUIC_INTERNAL_ERROR, "internal error"); + // Fire blackhole detection alarm. This will invoke + // SendConnectionClosePacket() a second time. + connection_.GetBlackholeDetectorAlarm()->Fire(); +} + +// Regression test for b/157895910. +TEST_P(QuicConnectionTest, EarliestSentTimeNotInitializedWhenPtoFires) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + + // Received INITIAL 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Send INITIAL 1. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString(std::string(200, 'a'), 0, + ENCRYPTION_HANDSHAKE); + // Send half RTT data. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_.SendStreamDataWithString(0, std::string(2000, 'b'), 0, FIN); + } + + // Received ACKs for both INITIAL and HANDSHAKE packets. + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + QuicFrames frames1; + QuicAckFrame ack_frame1 = InitAckFrame(1); + frames1.push_back(QuicFrame(&ack_frame1)); + + QuicFrames frames2; + QuicAckFrame ack_frame2 = + InitAckFrame({{QuicPacketNumber(2), QuicPacketNumber(3)}}); + frames2.push_back(QuicFrame(&ack_frame2)); + ProcessCoalescedPacket( + {{2, frames1, ENCRYPTION_INITIAL}, {3, frames2, ENCRYPTION_HANDSHAKE}}); + // Verify PTO is not armed given the only outstanding data is half RTT data. + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); +} + +TEST_P(QuicConnectionTest, CalculateNetworkBlackholeDelay) { + if (!IsDefaultTestConfiguration()) { + return; + } + + const QuicTime::Delta kOneSec = QuicTime::Delta::FromSeconds(1); + const QuicTime::Delta kTwoSec = QuicTime::Delta::FromSeconds(2); + const QuicTime::Delta kFourSec = QuicTime::Delta::FromSeconds(4); + + // Normal case: blackhole_delay longer than path_degrading_delay + + // 2*pto_delay. + EXPECT_EQ(QuicConnection::CalculateNetworkBlackholeDelay(kFourSec, kOneSec, + kOneSec), + kFourSec); + + EXPECT_EQ(QuicConnection::CalculateNetworkBlackholeDelay(kFourSec, kOneSec, + kTwoSec), + QuicTime::Delta::FromSeconds(5)); +} + +TEST_P(QuicConnectionTest, FixBytesAccountingForBufferedCoalescedPackets) { + if (!connection_.version().CanSendCoalescedPackets()) { + return; + } + // Write is blocked. + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(AnyNumber()); + writer_->SetWriteBlocked(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + QuicConnectionPeer::SendPing(&connection_); + const QuicConnectionStats& stats = connection_.GetStats(); + // Verify padding is accounted. + EXPECT_EQ(stats.bytes_sent, connection_.max_packet_length()); +} + +TEST_P(QuicConnectionTest, StrictAntiAmplificationLimit) { + if (!connection_.version().SupportsAntiAmplificationLimit()) { + return; + } + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(AnyNumber()); + set_perspective(Perspective::IS_SERVER); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Verify no data can be sent at the beginning because bytes received is 0. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + connection_.SendCryptoDataWithString("foo", 0); + EXPECT_FALSE(connection_.CanWrite(HAS_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.CanWrite(NO_RETRANSMITTABLE_DATA)); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + + const size_t anti_amplification_factor = + GetQuicFlag(quic_anti_amplification_factor); + // Receives packet 1. + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(anti_amplification_factor); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x03)); + + for (size_t i = 1; i < anti_amplification_factor - 1; ++i) { + connection_.SendCryptoDataWithString("foo", i * 3); + } + // Send an addtion packet with max_packet_size - 1. + connection_.SetMaxPacketLength(connection_.max_packet_length() - 1); + connection_.SendCryptoDataWithString("bar", + (anti_amplification_factor - 1) * 3); + EXPECT_LT(writer_->total_bytes_written(), + anti_amplification_factor * + QuicConnectionPeer::BytesReceivedOnDefaultPath(&connection_)); + if (GetQuicFlag(quic_enforce_strict_amplification_factor)) { + // 3 connection closes which will be buffered. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(3); + // Verify retransmission alarm is not set. + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + } else { + // Crypto + 3 connection closes. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(4); + EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + } + // Try to send another packet with max_packet_size. + connection_.SetMaxPacketLength(connection_.max_packet_length() + 1); + connection_.SendCryptoDataWithString("bar", anti_amplification_factor * 3); + EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); + // Close connection. + EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); + EXPECT_CALL(visitor_, OnConnectionClosed(_, _)); + connection_.CloseConnection( + QUIC_INTERNAL_ERROR, "error", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + if (GetQuicFlag(quic_enforce_strict_amplification_factor)) { + EXPECT_LT(writer_->total_bytes_written(), + anti_amplification_factor * + QuicConnectionPeer::BytesReceivedOnDefaultPath(&connection_)); + } else { + EXPECT_LT(writer_->total_bytes_written(), + (anti_amplification_factor + 2) * + QuicConnectionPeer::BytesReceivedOnDefaultPath(&connection_)); + EXPECT_GT(writer_->total_bytes_written(), + (anti_amplification_factor + 1) * + QuicConnectionPeer::BytesReceivedOnDefaultPath(&connection_)); + } +} + +TEST_P(QuicConnectionTest, OriginalConnectionId) { + set_perspective(Perspective::IS_SERVER); + EXPECT_FALSE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet()); + EXPECT_EQ(connection_.GetOriginalDestinationConnectionId(), + connection_.connection_id()); + QuicConnectionId original({0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}); + connection_.SetOriginalDestinationConnectionId(original); + EXPECT_EQ(original, connection_.GetOriginalDestinationConnectionId()); + // Send a 1-RTT packet to start the DiscardZeroRttDecryptionKeys timer. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessDataPacketAtLevel(1, false, ENCRYPTION_FORWARD_SECURE); + if (connection_.version().UsesTls()) { + EXPECT_TRUE(connection_.GetDiscardZeroRttDecryptionKeysAlarm()->IsSet()); + EXPECT_CALL(visitor_, OnServerConnectionIdRetired(original)); + connection_.GetDiscardZeroRttDecryptionKeysAlarm()->Fire(); + EXPECT_EQ(connection_.GetOriginalDestinationConnectionId(), + connection_.connection_id()); + } else { + EXPECT_EQ(connection_.GetOriginalDestinationConnectionId(), original); + } +} + +ACTION_P2(InstallKeys, conn, level) { + uint8_t crypto_input = (level == ENCRYPTION_FORWARD_SECURE) ? 0x03 : 0x02; + conn->SetEncrypter(level, std::make_unique(crypto_input)); + conn->InstallDecrypter( + level, std::make_unique(crypto_input)); + conn->SetDefaultEncryptionLevel(level); +} + +TEST_P(QuicConnectionTest, ServerConnectionIdChangeWithLateInitial) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + // Call SetFromConfig so that the undecrypted packet buffer size is + // initialized above zero. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(1); + QuicConfig config; + connection_.SetFromConfig(config); + connection_.RemoveEncrypter(ENCRYPTION_FORWARD_SECURE); + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + + // Send Client Initial. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoStreamData(); + + EXPECT_EQ(1u, writer_->packets_write_attempts()); + // Server Handshake packet with new connection ID is buffered. + QuicConnectionId old_id = connection_id_; + connection_id_ = TestConnectionId(2); + peer_creator_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + ProcessCryptoPacketAtLevel(0, ENCRYPTION_HANDSHAKE); + EXPECT_EQ(QuicConnectionPeer::NumUndecryptablePackets(&connection_), 1u); + EXPECT_EQ(connection_.connection_id(), old_id); + + // Server 1-RTT Packet is buffered. + peer_creator_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x03)); + ProcessDataPacket(0); + EXPECT_EQ(QuicConnectionPeer::NumUndecryptablePackets(&connection_), 2u); + + // Pretend the server Initial packet will yield the Handshake keys. + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .Times(2) + .WillOnce(InstallKeys(&connection_, ENCRYPTION_HANDSHAKE)) + .WillOnce(InstallKeys(&connection_, ENCRYPTION_FORWARD_SECURE)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + ProcessCryptoPacketAtLevel(0, ENCRYPTION_INITIAL); + // Two packets processed, connection ID changed. + EXPECT_EQ(QuicConnectionPeer::NumUndecryptablePackets(&connection_), 0u); + EXPECT_EQ(connection_.connection_id(), connection_id_); +} + +TEST_P(QuicConnectionTest, ServerConnectionIdChangeTwiceWithLateInitial) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + // Call SetFromConfig so that the undecrypted packet buffer size is + // initialized above zero. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(1); + QuicConfig config; + connection_.SetFromConfig(config); + + // Send Client Initial. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoStreamData(); + + EXPECT_EQ(1u, writer_->packets_write_attempts()); + // Server Handshake Packet Arrives with new connection ID. + QuicConnectionId old_id = connection_id_; + connection_id_ = TestConnectionId(2); + peer_creator_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + ProcessCryptoPacketAtLevel(0, ENCRYPTION_HANDSHAKE); + // Packet is buffered. + EXPECT_EQ(QuicConnectionPeer::NumUndecryptablePackets(&connection_), 1u); + EXPECT_EQ(connection_.connection_id(), old_id); + + // Pretend the server Initial packet will yield the Handshake keys. + EXPECT_CALL(visitor_, OnCryptoFrame(_)) + .WillOnce(InstallKeys(&connection_, ENCRYPTION_HANDSHAKE)); + connection_id_ = TestConnectionId(1); + ProcessCryptoPacketAtLevel(0, ENCRYPTION_INITIAL); + // Handshake packet discarded because there's a different connection ID. + EXPECT_EQ(QuicConnectionPeer::NumUndecryptablePackets(&connection_), 0u); + EXPECT_EQ(connection_.connection_id(), connection_id_); +} + +TEST_P(QuicConnectionTest, ClientValidatedServerPreferredAddress) { + // Test the scenario where the client validates server preferred address by + // receiving PATH_RESPONSE from server preferred address. + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + QuicConfig config; + ServerPreferredAddressInit(config); + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + const StatelessResetToken kNewStatelessResetToken = + QuicUtils::GenerateStatelessResetToken(TestConnectionId(17)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + // Kick off path validation of server preferred address on handshake + // confirmed. + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, kServerPreferredAddress)); + EXPECT_EQ(TestConnectionId(17), + new_writer.last_packet_header().destination_connection_id); + EXPECT_EQ(kServerPreferredAddress, new_writer.last_write_peer_address()); + + ASSERT_FALSE(new_writer.path_challenge_frames().empty()); + QuicPathFrameBuffer payload = + new_writer.path_challenge_frames().front().data_buffer; + // Send data packet while path validation is pending. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + ASSERT_FALSE(writer_->stream_frames().empty()); + // While path validation is pending, packet is sent on default path. + EXPECT_EQ(TestConnectionId(), + writer_->last_packet_header().destination_connection_id); + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + EXPECT_TRUE(connection_.IsValidStatelessResetToken(kTestStatelessResetToken)); + EXPECT_FALSE(connection_.IsValidStatelessResetToken(kNewStatelessResetToken)); + + // Receive path response from server preferred address. + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + // Verify send_algorithm gets reset after migration (new sent packet is not + // updated to exsting send_algorithm_). + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, + kServerPreferredAddress, + ENCRYPTION_FORWARD_SECURE); + ASSERT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsDefaultPath(&connection_, kNewSelfAddress, + kServerPreferredAddress)); + ASSERT_FALSE(new_writer.stream_frames().empty()); + // Verify stream data is retransmitted on new path. + EXPECT_EQ(TestConnectionId(17), + new_writer.last_packet_header().destination_connection_id); + EXPECT_EQ(kServerPreferredAddress, new_writer.last_write_peer_address()); + // Verify stateless reset token gets changed. + EXPECT_FALSE( + connection_.IsValidStatelessResetToken(kTestStatelessResetToken)); + EXPECT_TRUE(connection_.IsValidStatelessResetToken(kNewStatelessResetToken)); + + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + // Verify client retires connection ID with sequence number 0. + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_TRUE(connection_.GetStats().server_preferred_address_validated); + EXPECT_FALSE( + connection_.GetStats().failed_to_validate_server_preferred_address); +} + +TEST_P(QuicConnectionTest, ClientValidatedServerPreferredAddress2) { + // Test the scenario where the client validates server preferred address by + // receiving PATH_RESPONSE from original server address. + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + QuicConfig config; + ServerPreferredAddressInit(config); + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + // Kick off path validation of server preferred address on handshake + // confirmed. + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + ASSERT_FALSE(new_writer.path_challenge_frames().empty()); + QuicPathFrameBuffer payload = + new_writer.path_challenge_frames().front().data_buffer; + // Send data packet while path validation is pending. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + ASSERT_FALSE(writer_->stream_frames().empty()); + EXPECT_EQ(TestConnectionId(), + writer_->last_packet_header().destination_connection_id); + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + + // Receive path response from original server address. + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + ASSERT_FALSE(connection_.HasPendingPathValidation()); + ASSERT_FALSE(new_writer.stream_frames().empty()); + // Verify stream data is retransmitted on new path. + EXPECT_EQ(TestConnectionId(17), + new_writer.last_packet_header().destination_connection_id); + EXPECT_EQ(kServerPreferredAddress, new_writer.last_write_peer_address()); + + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + // Verify client retires connection ID with sequence number 0. + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + // Verify another packet from original server address gets processed. + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(1); + frames.clear(); + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(connection_.GetStats().server_preferred_address_validated); + EXPECT_FALSE( + connection_.GetStats().failed_to_validate_server_preferred_address); +} + +TEST_P(QuicConnectionTest, ClientFailedToValidateServerPreferredAddress) { + // Test the scenario where the client fails to validate server preferred + // address. + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + QuicConfig config; + ServerPreferredAddressInit(config); + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + // Kick off path validation of server preferred address on handshake + // confirmed. + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.IsValidatingServerPreferredAddress()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, kServerPreferredAddress)); + ASSERT_FALSE(new_writer.path_challenge_frames().empty()); + + // Receive mismatched path challenge from original server address. + QuicFrames frames; + frames.push_back( + QuicFrame(QuicPathResponseFrame(99, {0, 1, 2, 3, 4, 5, 6, 7}))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + ASSERT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, kServerPreferredAddress)); + + // Simluate path validation times out. + for (size_t i = 0; i < QuicPathValidator::kMaxRetryTimes + 1; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + } + EXPECT_FALSE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress, kServerPreferredAddress)); + // Verify stream data is sent on the default path. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + ASSERT_FALSE(writer_->stream_frames().empty()); + EXPECT_EQ(TestConnectionId(), + writer_->last_packet_header().destination_connection_id); + EXPECT_EQ(kPeerAddress, writer_->last_write_peer_address()); + + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + // Verify client retires connection ID with sequence number 1. + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/1u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_TRUE(connection_.IsValidStatelessResetToken(kTestStatelessResetToken)); + EXPECT_FALSE(connection_.GetStats().server_preferred_address_validated); + EXPECT_TRUE( + connection_.GetStats().failed_to_validate_server_preferred_address); +} + +TEST_P(QuicConnectionTest, OptimizedServerPreferredAddress) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + QuicConfig config; + config.SetClientConnectionOptions(QuicTagVector{kSPA2}); + ServerPreferredAddressInit(config); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + ASSERT_FALSE(new_writer.path_challenge_frames().empty()); + + // Send data packet while path validation is pending. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + // Verify the packet is sent on both paths. + EXPECT_FALSE(writer_->stream_frames().empty()); + EXPECT_FALSE(new_writer.stream_frames().empty()); + + // Verify packet duplication stops on handshake confirmed. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + SendPing(); + EXPECT_FALSE(writer_->ping_frames().empty()); + EXPECT_TRUE(new_writer.ping_frames().empty()); +} + +TEST_P(QuicConnectionTest, OptimizedServerPreferredAddress2) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + QuicConfig config; + config.SetClientConnectionOptions(QuicTagVector{kSPA2}); + ServerPreferredAddressInit(config); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + ASSERT_FALSE(new_writer.path_challenge_frames().empty()); + + // Send data packet while path validation is pending. + connection_.SendStreamDataWithString(3, "foo", 0, NO_FIN); + // Verify the packet is sent on both paths. + EXPECT_FALSE(writer_->stream_frames().empty()); + EXPECT_FALSE(new_writer.stream_frames().empty()); + + // Simluate path validation times out. + for (size_t i = 0; i < QuicPathValidator::kMaxRetryTimes + 1; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + static_cast( + QuicPathValidatorPeer::retry_timer( + QuicConnectionPeer::path_validator(&connection_))) + ->Fire(); + } + EXPECT_FALSE(connection_.HasPendingPathValidation()); + // Verify packet duplication stops if there is no pending validation. + SendPing(); + EXPECT_FALSE(writer_->ping_frames().empty()); + EXPECT_TRUE(new_writer.ping_frames().empty()); +} + +TEST_P(QuicConnectionTest, MaxDuplicatedPacketsSentToServerPreferredAddress) { + if (!connection_.version().HasIetfQuicFrames()) { + return; + } + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + QuicConfig config; + config.SetClientConnectionOptions(QuicTagVector{kSPA2}); + ServerPreferredAddressInit(config); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + ASSERT_FALSE(new_writer.path_challenge_frames().empty()); + + // Send data packet while path validation is pending. + size_t write_limit = writer_->packets_write_attempts(); + size_t new_write_limit = new_writer.packets_write_attempts(); + for (size_t i = 0; i < kMaxDuplicatedPacketsSentToServerPreferredAddress; + ++i) { + connection_.SendStreamDataWithString(3, "foo", i * 3, NO_FIN); + // Verify the packet is sent on both paths. + ASSERT_EQ(write_limit + 1, writer_->packets_write_attempts()); + ASSERT_EQ(new_write_limit + 1, new_writer.packets_write_attempts()); + ++write_limit; + ++new_write_limit; + EXPECT_FALSE(writer_->stream_frames().empty()); + EXPECT_FALSE(new_writer.stream_frames().empty()); + } + + // Verify packet duplication stops if duplication limit is hit. + SendPing(); + ASSERT_EQ(write_limit + 1, writer_->packets_write_attempts()); + ASSERT_EQ(new_write_limit, new_writer.packets_write_attempts()); + EXPECT_FALSE(writer_->ping_frames().empty()); + EXPECT_TRUE(new_writer.ping_frames().empty()); +} + +TEST_P(QuicConnectionTest, MultiPortCreationAfterServerMigration) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + QuicConfig config; + config.SetClientConnectionOptions(QuicTagVector{kMPQC}); + ServerPreferredAddressInit(config); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicConnectionId cid_for_preferred_address = TestConnectionId(17); + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/23456); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer), + std::make_unique( + &connection_), + PathValidationReason::kReasonUnknown); + })); + // The connection should start probing the preferred address after handshake + // confirmed. + QuicPathFrameBuffer payload; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(testing::AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + payload = new_writer.path_challenge_frames().front().data_buffer; + EXPECT_EQ(kServerPreferredAddress, + new_writer.last_write_peer_address()); + })); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + EXPECT_TRUE(connection_.IsValidatingServerPreferredAddress()); + + // Receiving PATH_RESPONSE should cause the connection to migrate to the + // preferred address. + QuicFrames frames; + frames.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(connection_.IsValidatingServerPreferredAddress()); + EXPECT_EQ(kServerPreferredAddress, connection_.effective_peer_address()); + EXPECT_EQ(kNewSelfAddress, connection_.self_address()); + EXPECT_EQ(connection_.connection_id(), cid_for_preferred_address); + + // As the default path changed, the server issued CID 1 should be retired. + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + const QuicSocketAddress kNewSelfAddress2(kNewSelfAddress.host(), + kNewSelfAddress.port() + 1); + EXPECT_NE(kNewSelfAddress2, kNewSelfAddress); + TestPacketWriter new_writer2(version(), &clock_, Perspective::IS_CLIENT); + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(789); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 2u; + EXPECT_CALL(visitor_, CreateContextForMultiPortPath()) + .WillOnce(Return( + testing::ByMove(std::make_unique( + kNewSelfAddress2, connection_.peer_address(), &new_writer2)))); + connection_.OnNewConnectionIdFrame(frame); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_EQ(1u, new_writer.path_challenge_frames().size()); + payload = new_writer.path_challenge_frames().front().data_buffer; + EXPECT_EQ(kServerPreferredAddress, new_writer.last_write_peer_address()); + EXPECT_EQ(kNewSelfAddress2.host(), new_writer.last_write_source_address()); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress2, connection_.peer_address())); + auto* alt_path = QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_FALSE(alt_path->validated); + QuicFrames frames2; + frames2.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + ProcessFramesPacketWithAddresses(frames2, kNewSelfAddress2, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(alt_path->validated); +} + +// Tests that after half-way server migration, the client should be able to +// respond to any reverse path validation from the original server address. +TEST_P(QuicConnectionTest, ClientReceivePathChallengeAfterServerMigration) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + QuicConfig config; + ServerPreferredAddressInit(config); + QuicConnectionId cid_for_preferred_address = TestConnectionId(17); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.AddKnownServerAddress(kServerPreferredAddress); + })); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), kTestPort + 1); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + auto context = std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer); + // Pretend that the validation already succeeded. And start to use the server + // preferred address. + connection_.OnServerPreferredAddressValidated(*context, false); + EXPECT_EQ(kServerPreferredAddress, connection_.effective_peer_address()); + EXPECT_EQ(kServerPreferredAddress, connection_.peer_address()); + EXPECT_EQ(kNewSelfAddress, connection_.self_address()); + EXPECT_EQ(connection_.connection_id(), cid_for_preferred_address); + EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + // Switch to use a mock send algorithm. + send_algorithm_ = new StrictMock(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); + connection_.SetSendAlgorithm(send_algorithm_); + + // As the default path changed, the server issued CID 123 should be retired. + QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath(&connection_); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + // Receive PATH_CHALLENGE from the original server + // address. The client connection responds it on the default path. + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + QuicFrames frames1; + frames1.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1)) + .WillOnce(Invoke([&]() { + ASSERT_FALSE(new_writer.path_response_frames().empty()); + EXPECT_EQ( + 0, memcmp(&path_challenge_payload, + &(new_writer.path_response_frames().front().data_buffer), + sizeof(path_challenge_payload))); + EXPECT_EQ(kServerPreferredAddress, + new_writer.last_write_peer_address()); + EXPECT_EQ(kNewSelfAddress.host(), + new_writer.last_write_source_address()); + })); + ProcessFramesPacketWithAddresses(frames1, kNewSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); +} + +// Tests that after half-way server migration, the client should be able to +// probe with a different socket and respond to reverse path validation. +TEST_P(QuicConnectionTest, ClientProbesAfterServerMigration) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + QuicConfig config; + ServerPreferredAddressInit(config); + QuicConnectionId cid_for_preferred_address = TestConnectionId(17); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + // The connection should start probing the preferred address after handshake + // confirmed. + EXPECT_CALL(visitor_, + OnServerPreferredAddressAvailable(kServerPreferredAddress)) + .WillOnce(Invoke([&]() { + connection_.AddKnownServerAddress(kServerPreferredAddress); + })); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + connection_.OnHandshakeComplete(); + + const QuicSocketAddress kNewSelfAddress = + QuicSocketAddress(QuicIpAddress::Loopback6(), kTestPort + 1); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); + auto context = std::make_unique( + kNewSelfAddress, kServerPreferredAddress, &new_writer); + // Pretend that the validation already succeeded. + connection_.OnServerPreferredAddressValidated(*context, false); + EXPECT_EQ(kServerPreferredAddress, connection_.effective_peer_address()); + EXPECT_EQ(kServerPreferredAddress, connection_.peer_address()); + EXPECT_EQ(kNewSelfAddress, connection_.self_address()); + EXPECT_EQ(connection_.connection_id(), cid_for_preferred_address); + EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), + send_algorithm_); + // Switch to use a mock send algorithm. + send_algorithm_ = new StrictMock(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, OnApplicationLimited(_)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); + connection_.SetSendAlgorithm(send_algorithm_); + + // Receiving data from the original server address should not change the peer + // address. + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kNewSelfAddress, + kPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kServerPreferredAddress, connection_.effective_peer_address()); + EXPECT_EQ(kServerPreferredAddress, connection_.peer_address()); + + // As the default path changed, the server issued CID 123 should be retired. + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); + + // Receiving a new CID from the server. + QuicNewConnectionIdFrame new_cid_frame1; + new_cid_frame1.connection_id = TestConnectionId(456); + ASSERT_NE(new_cid_frame1.connection_id, connection_.connection_id()); + new_cid_frame1.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(new_cid_frame1.connection_id); + new_cid_frame1.retire_prior_to = 0u; + new_cid_frame1.sequence_number = 2u; + connection_.OnNewConnectionIdFrame(new_cid_frame1); + + // Probe from a new socket. + const QuicSocketAddress kNewSelfAddress2 = + QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort + 2); + TestPacketWriter new_writer2(version(), &clock_, Perspective::IS_CLIENT); + bool success; + QuicPathFrameBuffer payload; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(testing::AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer2.path_challenge_frames().size()); + payload = new_writer2.path_challenge_frames().front().data_buffer; + EXPECT_EQ(kServerPreferredAddress, + new_writer2.last_write_peer_address()); + EXPECT_EQ(kNewSelfAddress2.host(), + new_writer2.last_write_source_address()); + })); + connection_.ValidatePath( + std::make_unique( + kNewSelfAddress2, connection_.peer_address(), &new_writer2), + std::make_unique( + &connection_, kNewSelfAddress2, connection_.peer_address(), &success), + PathValidationReason::kServerPreferredAddressMigration); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath( + &connection_, kNewSelfAddress2, kServerPreferredAddress)); + + // Our server implementation will send PATH_CHALLENGE from the original server + // address. The client connection send PATH_RESPONSE to the default peer + // address. + QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; + QuicFrames frames; + frames.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_challenge_payload))); + frames.push_back(QuicFrame(QuicPathResponseFrame(99, payload))); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1)) + .WillOnce(Invoke([&]() { + EXPECT_FALSE(new_writer2.path_response_frames().empty()); + EXPECT_EQ( + 0, memcmp(&path_challenge_payload, + &(new_writer2.path_response_frames().front().data_buffer), + sizeof(path_challenge_payload))); + EXPECT_EQ(kServerPreferredAddress, + new_writer2.last_write_peer_address()); + EXPECT_EQ(kNewSelfAddress2.host(), + new_writer2.last_write_source_address()); + })); + ProcessFramesPacketWithAddresses(frames, kNewSelfAddress2, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(success); +} + +TEST_P(QuicConnectionTest, EcnMarksCorrectlyRecorded) { + set_perspective(Perspective::IS_SERVER); + QuicFrames frames; + frames.push_back(QuicFrame(QuicPingFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(7))); + QuicAckFrame ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + + ProcessFramesPacketAtLevelWithEcn(1, frames, ENCRYPTION_FORWARD_SECURE, + ECN_ECT0); + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + // Send two PINGs so that the ACK goes too. The second packet should not + // include an ACK, which checks that the packet state is cleared properly. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + if (connection_.version().HasIetfQuicFrames()) { + QuicConnectionPeer::SendPing(&connection_); + QuicConnectionPeer::SendPing(&connection_); + } + QuicConnectionStats stats = connection_.GetStats(); + if (GetQuicRestartFlag(quic_receive_ecn)) { + ASSERT_TRUE(ack_frame.ecn_counters.has_value()); + EXPECT_EQ(ack_frame.ecn_counters->ect0, 1); + EXPECT_EQ(stats.num_ack_frames_sent_with_ecn, + connection_.version().HasIetfQuicFrames() ? 1 : 0); + } else { + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + EXPECT_EQ(stats.num_ack_frames_sent_with_ecn, 0); + } + EXPECT_EQ(stats.num_ecn_marks_received.ect0, 1); + EXPECT_EQ(stats.num_ecn_marks_received.ect1, 0); + EXPECT_EQ(stats.num_ecn_marks_received.ce, 0); +} + +TEST_P(QuicConnectionTest, EcnMarksCoalescedPacket) { + if (!connection_.version().CanSendCoalescedPackets() || + !GetQuicRestartFlag(quic_receive_ecn)) { + return; + } + QuicCryptoFrame crypto_frame1{ENCRYPTION_HANDSHAKE, 0, "foo"}; + QuicFrames frames1; + frames1.push_back(QuicFrame(&crypto_frame1)); + QuicFrames frames2; + QuicCryptoFrame crypto_frame2{ENCRYPTION_FORWARD_SECURE, 0, "bar"}; + frames2.push_back(QuicFrame(&crypto_frame2)); + std::vector packets = {{2, frames1, ENCRYPTION_HANDSHAKE}, + {3, frames2, ENCRYPTION_FORWARD_SECURE}}; + QuicAckFrame ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + // Deliver packets. + connection_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(2); + ProcessCoalescedPacket(packets, ECN_ECT0); + // Send two PINGs so that the ACKs go too. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + if (connection_.version().HasIetfQuicFrames()) { + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + QuicConnectionPeer::SendPing(&connection_); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicConnectionPeer::SendPing(&connection_); + } + QuicConnectionStats stats = connection_.GetStats(); + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA) + : connection_.received_packet_manager().ack_frame(); + ASSERT_TRUE(ack_frame.ecn_counters.has_value()); + EXPECT_EQ(ack_frame.ecn_counters->ect0, + connection_.SupportsMultiplePacketNumberSpaces() ? 1 : 2); + if (connection_.SupportsMultiplePacketNumberSpaces()) { + ack_frame = connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame( + APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_TRUE(ack_frame.ecn_counters.has_value()); + EXPECT_EQ(ack_frame.ecn_counters->ect0, 1); + } + if (GetQuicRestartFlag(quic_receive_ecn)) { + EXPECT_EQ(stats.num_ecn_marks_received.ect0, 2); + EXPECT_EQ(stats.num_ack_frames_sent_with_ecn, + connection_.version().HasIetfQuicFrames() ? 2 : 0); + } else { + EXPECT_EQ(stats.num_ecn_marks_received.ect0, 0); + EXPECT_EQ(stats.num_ack_frames_sent_with_ecn, 0); + } + EXPECT_EQ(stats.num_ecn_marks_received.ect1, 0); + EXPECT_EQ(stats.num_ecn_marks_received.ce, 0); +} + +TEST_P(QuicConnectionTest, EcnMarksUndecryptableCoalescedPacket) { + if (!connection_.version().CanSendCoalescedPackets() || + !GetQuicRestartFlag(quic_receive_ecn)) { + return; + } + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(100); + connection_.SetFromConfig(config); + QuicCryptoFrame crypto_frame1{ENCRYPTION_HANDSHAKE, 0, "foo"}; + QuicFrames frames1; + frames1.push_back(QuicFrame(&crypto_frame1)); + QuicFrames frames2; + QuicCryptoFrame crypto_frame2{ENCRYPTION_FORWARD_SECURE, 0, "bar"}; + frames2.push_back(QuicFrame(&crypto_frame2)); + std::vector packets = {{2, frames1, ENCRYPTION_HANDSHAKE}, + {3, frames2, ENCRYPTION_FORWARD_SECURE}}; + char coalesced_buffer[kMaxOutgoingPacketSize]; + size_t coalesced_size = 0; + for (const auto& packet : packets) { + QuicPacketHeader header = + ConstructPacketHeader(packet.packet_number, packet.level); + // Set the correct encryption level and encrypter on peer_creator and + // peer_framer, respectively. + peer_creator_.set_encryption_level(packet.level); + peer_framer_.SetEncrypter(packet.level, + std::make_unique(packet.level)); + // Set the corresponding decrypter. + if (packet.level == ENCRYPTION_HANDSHAKE) { + connection_.SetEncrypter( + packet.level, std::make_unique(packet.level)); + connection_.SetDefaultEncryptionLevel(packet.level); + SetDecrypter(packet.level, + std::make_unique(packet.level)); + } + // Forward Secure packet is undecryptable. + std::unique_ptr constructed_packet( + ConstructPacket(header, packet.frames)); + + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = peer_framer_.EncryptPayload( + packet.level, QuicPacketNumber(packet.packet_number), + *constructed_packet, buffer, kMaxOutgoingPacketSize); + QUICHE_DCHECK_LE(coalesced_size + encrypted_length, kMaxOutgoingPacketSize); + memcpy(coalesced_buffer + coalesced_size, buffer, encrypted_length); + coalesced_size += encrypted_length; + } + QuicAckFrame ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + // Deliver packets, but first remove the Forward Secure decrypter so that + // packet has to be buffered. + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.ProcessUdpPacket( + kSelfAddress, kPeerAddress, + QuicReceivedPacket(coalesced_buffer, coalesced_size, clock_.Now(), false, + 0, true, nullptr, 0, true, ECN_ECT0)); + if (connection_.GetSendAlarm()->IsSet()) { + connection_.GetSendAlarm()->Fire(); + } + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA) + : connection_.received_packet_manager().ack_frame(); + ASSERT_TRUE(ack_frame.ecn_counters.has_value()); + EXPECT_EQ(ack_frame.ecn_counters->ect0, 1); + if (connection_.SupportsMultiplePacketNumberSpaces()) { + ack_frame = connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame( + APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + } + // Send PING packet with ECN_CE, which will change the ECN codepoint in + // last_received_packet_info_. + ProcessFramePacketAtLevelWithEcn(4, QuicFrame(QuicPingFrame()), + ENCRYPTION_HANDSHAKE, ECN_CE); + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(HANDSHAKE_DATA) + : connection_.received_packet_manager().ack_frame(); + ASSERT_TRUE(ack_frame.ecn_counters.has_value()); + EXPECT_EQ(ack_frame.ecn_counters->ect0, 1); + EXPECT_EQ(ack_frame.ecn_counters->ce, 1); + if (connection_.SupportsMultiplePacketNumberSpaces()) { + ack_frame = connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame( + APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + EXPECT_FALSE(ack_frame.ecn_counters.has_value()); + } + // Install decrypter for ENCRYPTION_FORWARD_SECURE. Make sure the original + // ECN codepoint is incremented. + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(1); + SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + ack_frame = + connection_.SupportsMultiplePacketNumberSpaces() + ? connection_.received_packet_manager().GetAckFrame(APPLICATION_DATA) + : connection_.received_packet_manager().ack_frame(); + ASSERT_TRUE(ack_frame.ecn_counters.has_value()); + // Should be recorded as ECT(0), not CE. + EXPECT_EQ(ack_frame.ecn_counters->ect0, + connection_.SupportsMultiplePacketNumberSpaces() ? 1 : 2); + QuicConnectionStats stats = connection_.GetStats(); + EXPECT_EQ(stats.num_ecn_marks_received.ect0, + GetQuicRestartFlag(quic_receive_ecn) ? 2 : 0); + EXPECT_EQ(stats.num_ecn_marks_received.ect1, 0); + EXPECT_EQ(stats.num_ecn_marks_received.ce, + GetQuicRestartFlag(quic_receive_ecn) ? 1 : 0); +} + +TEST_P(QuicConnectionTest, ReceivedPacketInfoDefaults) { + EXPECT_TRUE(QuicConnectionPeer::TestLastReceivedPacketInfoDefaults()); +} + +TEST_P(QuicConnectionTest, DetectMigrationToPreferredAddress) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + ServerHandlePreferredAddressInit(); + + // Issue a new server CID associated with the preferred address. + QuicConnectionId server_issued_cid_for_preferred_address = + TestConnectionId(17); + EXPECT_CALL(connection_id_generator_, + GenerateNextConnectionId(connection_id_)) + .WillOnce(Return(server_issued_cid_for_preferred_address)); + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)).WillOnce(Return(true)); + absl::optional frame = + connection_.MaybeIssueNewConnectionIdForPreferredAddress(); + ASSERT_TRUE(frame.has_value()); + + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), + connection_.client_connection_id()); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), connection_id_); + + // Process a packet received at the preferred Address. + peer_creator_.SetServerConnectionId(server_issued_cid_for_preferred_address); + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kServerPreferredAddress, + kPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + // The server migrates half-way with the default path unchanged, and + // continuing with the client issued CID 1. + EXPECT_EQ(kSelfAddress.host(), writer_->last_write_source_address()); + EXPECT_EQ(kSelfAddress, connection_.self_address()); + + // The peer retires CID 123. + QuicRetireConnectionIdFrame retire_cid_frame; + retire_cid_frame.sequence_number = 0u; + EXPECT_CALL(connection_id_generator_, + GenerateNextConnectionId(server_issued_cid_for_preferred_address)) + .WillOnce(Return(TestConnectionId(456))); + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)).WillOnce(Return(true)); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + EXPECT_TRUE(connection_.OnRetireConnectionIdFrame(retire_cid_frame)); + + // Process another packet received at Preferred Address. + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kServerPreferredAddress, + kPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kSelfAddress.host(), writer_->last_write_source_address()); + EXPECT_EQ(kSelfAddress, connection_.self_address()); +} + +TEST_P(QuicConnectionTest, + DetectSimutanuousServerAndClientAddressChangeWithProbe) { + if (!GetParam().version.HasIetfQuicFrames()) { + return; + } + ServerHandlePreferredAddressInit(); + + // Issue a new server CID associated with the preferred address. + QuicConnectionId server_issued_cid_for_preferred_address = + TestConnectionId(17); + EXPECT_CALL(connection_id_generator_, + GenerateNextConnectionId(connection_id_)) + .WillOnce(Return(server_issued_cid_for_preferred_address)); + EXPECT_CALL(visitor_, MaybeReserveConnectionId(_)).WillOnce(Return(true)); + absl::optional frame = + connection_.MaybeIssueNewConnectionIdForPreferredAddress(); + ASSERT_TRUE(frame.has_value()); + + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), connection_id_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), + connection_.client_connection_id()); + + // Receiving a probing packet from a new client address to the preferred + // address. + peer_creator_.SetServerConnectionId(server_issued_cid_for_preferred_address); + const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), + /*port=*/34567); + std::unique_ptr probing_packet = ConstructProbingPacket(); + std::unique_ptr received(ConstructReceivedPacket( + QuicEncryptedPacket(probing_packet->encrypted_buffer, + probing_packet->encrypted_length), + clock_.Now())); + uint64_t num_probing_received = + connection_.GetStats().num_connectivity_probing_received; + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, writer_->path_response_frames().size()); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + // The responses should be sent from preferred address given server + // has not received packet on original address from the new client + // address. + EXPECT_EQ(kServerPreferredAddress.host(), + writer_->last_write_source_address()); + EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); + })); + ProcessReceivedPacket(kServerPreferredAddress, kNewPeerAddress, *received); + EXPECT_EQ(num_probing_received + 1, + connection_.GetStats().num_connectivity_probing_received); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath(&connection_, kSelfAddress, + kNewPeerAddress)); + EXPECT_LT(0u, QuicConnectionPeer::BytesSentOnAlternativePath(&connection_)); + EXPECT_EQ(received->length(), + QuicConnectionPeer::BytesReceivedOnAlternativePath(&connection_)); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kSelfAddress, connection_.self_address()); + + // Process a data packet received at the preferred Address from the new client + // address. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)); + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kServerPreferredAddress, + kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); + // The server migrates half-way with the new peer address but the same default + // self address. + EXPECT_EQ(kSelfAddress.host(), writer_->last_write_source_address()); + EXPECT_EQ(kSelfAddress, connection_.self_address()); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_TRUE(connection_.HasPendingPathValidation()); + EXPECT_FALSE(QuicConnectionPeer::GetDefaultPath(&connection_)->validated); + EXPECT_TRUE(QuicConnectionPeer::IsAlternativePath(&connection_, kSelfAddress, + kPeerAddress)); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), + server_issued_cid_for_preferred_address); + + // Process another packet received at the preferred Address. + EXPECT_CALL(visitor_, OnCryptoFrame(_)); + ProcessFramePacketWithAddresses(MakeCryptoFrame(), kServerPreferredAddress, + kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kServerPreferredAddress.host(), + writer_->last_write_source_address()); + EXPECT_EQ(kSelfAddress, connection_.self_address()); +} + +TEST_P(QuicConnectionTest, EcnCodepointsRejected) { + TestPerPacketOptions per_packet_options; + connection_.set_per_packet_options(&per_packet_options); + for (QuicEcnCodepoint ecn : {ECN_NOT_ECT, ECN_ECT0, ECN_ECT1, ECN_CE}) { + per_packet_options.ecn_codepoint = ecn; + if (ecn == ECN_ECT0) { + EXPECT_CALL(*send_algorithm_, SupportsECT0()).WillOnce(Return(false)); + } else if (ecn == ECN_ECT1) { + EXPECT_CALL(*send_algorithm_, SupportsECT1()).WillOnce(Return(false)); + } + EXPECT_CALL(connection_, OnSerializedPacket(_)); + SendPing(); + EXPECT_EQ(per_packet_options.ecn_codepoint, ECN_NOT_ECT); + EXPECT_EQ(writer_->last_ecn_sent(), ECN_NOT_ECT); + } +} + +TEST_P(QuicConnectionTest, EcnCodepointsAccepted) { + TestPerPacketOptions per_packet_options; + connection_.set_per_packet_options(&per_packet_options); + for (QuicEcnCodepoint ecn : {ECN_NOT_ECT, ECN_ECT0, ECN_ECT1, ECN_CE}) { + per_packet_options.ecn_codepoint = ecn; + if (ecn == ECN_ECT0) { + EXPECT_CALL(*send_algorithm_, SupportsECT0()).WillOnce(Return(true)); + } else if (ecn == ECN_ECT1) { + EXPECT_CALL(*send_algorithm_, SupportsECT1()).WillOnce(Return(true)); + } + EXPECT_CALL(connection_, OnSerializedPacket(_)); + SendPing(); + QuicEcnCodepoint expected_codepoint = ecn; + if (ecn == ECN_CE) { + expected_codepoint = ECN_NOT_ECT; + } + EXPECT_EQ(per_packet_options.ecn_codepoint, expected_codepoint); + EXPECT_EQ(writer_->last_ecn_sent(), expected_codepoint); + } +} + +TEST_P(QuicConnectionTest, EcnValidationDisabled) { + TestPerPacketOptions per_packet_options; + connection_.set_per_packet_options(&per_packet_options); + QuicConnectionPeer::DisableEcnCodepointValidation(&connection_); + for (QuicEcnCodepoint ecn : {ECN_NOT_ECT, ECN_ECT0, ECN_ECT1, ECN_CE}) { + per_packet_options.ecn_codepoint = ecn; + EXPECT_CALL(connection_, OnSerializedPacket(_)); + SendPing(); + EXPECT_EQ(per_packet_options.ecn_codepoint, ecn); + EXPECT_EQ(writer_->last_ecn_sent(), ecn); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_constants.cc b/quiche/quic/core/quic_constants.cc new file mode 100644 index 000000000000..b9594fa5bacc --- /dev/null +++ b/quiche/quic/core/quic_constants.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_constants.h" + +namespace quic { + +const char* const kFinalOffsetHeaderKey = ":final-offset"; + +const char* const kEPIDGoogleFrontEnd = "GFE"; +const char* const kEPIDGoogleFrontEnd0 = "GFE0"; + +QuicPacketNumber MaxRandomInitialPacketNumber() { + static const QuicPacketNumber kMaxRandomInitialPacketNumber = + QuicPacketNumber(0x7fffffff); + return kMaxRandomInitialPacketNumber; +} + +QuicPacketNumber FirstSendingPacketNumber() { + static const QuicPacketNumber kFirstSendingPacketNumber = QuicPacketNumber(1); + return kFirstSendingPacketNumber; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_constants.h b/quiche/quic/core/quic_constants.h new file mode 100644 index 000000000000..dfd908e92eb6 --- /dev/null +++ b/quiche/quic/core/quic_constants.h @@ -0,0 +1,333 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONSTANTS_H_ +#define QUICHE_QUIC_CORE_QUIC_CONSTANTS_H_ + +#include + +#include +#include + +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +// Definitions of constant values used throughout the QUIC code. + +namespace quic { + +// Simple time constants. +inline constexpr uint64_t kNumSecondsPerMinute = 60; +inline constexpr uint64_t kNumSecondsPerHour = kNumSecondsPerMinute * 60; +inline constexpr uint64_t kNumSecondsPerWeek = kNumSecondsPerHour * 24 * 7; +inline constexpr uint64_t kNumMillisPerSecond = 1000; +inline constexpr uint64_t kNumMicrosPerMilli = 1000; +inline constexpr uint64_t kNumMicrosPerSecond = + kNumMicrosPerMilli * kNumMillisPerSecond; + +// Default number of connections for N-connection emulation. +inline constexpr uint32_t kDefaultNumConnections = 2; +// Default initial maximum size in bytes of a QUIC packet. +inline constexpr QuicByteCount kDefaultMaxPacketSize = 1250; +// Default initial maximum size in bytes of a QUIC packet for servers. +inline constexpr QuicByteCount kDefaultServerMaxPacketSize = 1000; +// Maximum transmission unit on Ethernet. +inline constexpr QuicByteCount kEthernetMTU = 1500; +// The maximum packet size of any QUIC packet over IPv6, based on ethernet's max +// size, minus the IP and UDP headers. IPv6 has a 40 byte header, UDP adds an +// additional 8 bytes. This is a total overhead of 48 bytes. Ethernet's +// max packet size is 1500 bytes, 1500 - 48 = 1452. +inline constexpr QuicByteCount kMaxV6PacketSize = 1452; +// The maximum packet size of any QUIC packet over IPv4. +// 1500(Ethernet) - 20(IPv4 header) - 8(UDP header) = 1472. +inline constexpr QuicByteCount kMaxV4PacketSize = 1472; +// The maximum incoming packet size allowed. +inline constexpr QuicByteCount kMaxIncomingPacketSize = kMaxV4PacketSize; +// The maximum outgoing packet size allowed. +inline constexpr QuicByteCount kMaxOutgoingPacketSize = kMaxV6PacketSize; +// ETH_MAX_MTU - MAX(sizeof(iphdr), sizeof(ip6_hdr)) - sizeof(udphdr). +inline constexpr QuicByteCount kMaxGsoPacketSize = 65535 - 40 - 8; +// The maximal IETF DATAGRAM frame size we'll accept. Choosing 2^16 ensures +// that it is greater than the biggest frame we could ever fit in a QUIC packet. +inline constexpr QuicByteCount kMaxAcceptedDatagramFrameSize = 65536; +// Default value of the max_packet_size transport parameter if it is not +// transmitted. +inline constexpr QuicByteCount kDefaultMaxPacketSizeTransportParam = 65527; +// Default maximum packet size used in the Linux TCP implementation. +// Used in QUIC for congestion window computations in bytes. +inline constexpr QuicByteCount kDefaultTCPMSS = 1460; +inline constexpr QuicByteCount kMaxSegmentSize = kDefaultTCPMSS; +// The minimum size of a packet which can elicit a version negotiation packet, +// as per section 8.1 of the QUIC spec. +inline constexpr QuicByteCount kMinPacketSizeForVersionNegotiation = 1200; + +// We match SPDY's use of 32 (since we'd compete with SPDY). +inline constexpr QuicPacketCount kInitialCongestionWindow = 32; + +// Do not allow initial congestion window to be greater than 200 packets. +inline constexpr QuicPacketCount kMaxInitialCongestionWindow = 200; + +// Do not allow initial congestion window to be smaller than 10 packets. +inline constexpr QuicPacketCount kMinInitialCongestionWindow = 10; + +// Minimum size of initial flow control window, for both stream and session. +// This is only enforced when version.AllowsLowFlowControlLimits() is false. +inline constexpr QuicByteCount kMinimumFlowControlSendWindow = + 16 * 1024; // 16 KB +// Default size of initial flow control window, for both stream and session. +inline constexpr QuicByteCount kDefaultFlowControlSendWindow = + 16 * 1024; // 16 KB + +// Maximum flow control receive window limits for connection and stream. +inline constexpr QuicByteCount kStreamReceiveWindowLimit = + 16 * 1024 * 1024; // 16 MB +inline constexpr QuicByteCount kSessionReceiveWindowLimit = + 24 * 1024 * 1024; // 24 MB + +// Minimum size of the CWND, in packets, when doing bandwidth resumption. +inline constexpr QuicPacketCount kMinCongestionWindowForBandwidthResumption = + 10; + +// Default size of the socket receive buffer in bytes. +inline constexpr QuicByteCount kDefaultSocketReceiveBuffer = 1024 * 1024; + +// The lower bound of an untrusted initial rtt value. +inline constexpr uint32_t kMinUntrustedInitialRoundTripTimeUs = + 10 * kNumMicrosPerMilli; + +// The lower bound of a trusted initial rtt value. +inline constexpr uint32_t kMinTrustedInitialRoundTripTimeUs = + 5 * kNumMicrosPerMilli; + +// Don't allow a client to suggest an RTT longer than 1 second. +inline constexpr uint32_t kMaxInitialRoundTripTimeUs = kNumMicrosPerSecond; + +// Maximum number of open streams per connection. +inline constexpr size_t kDefaultMaxStreamsPerConnection = 100; + +// Number of bytes reserved for public flags in the packet header. +inline constexpr size_t kPublicFlagsSize = 1; +// Number of bytes reserved for version number in the packet header. +inline constexpr size_t kQuicVersionSize = 4; + +// Minimum number of active connection IDs that an end point can maintain. +inline constexpr uint32_t kMinNumOfActiveConnectionIds = 2; + +// Length of the retry integrity tag in bytes. +// https://tools.ietf.org/html/draft-ietf-quic-transport-25#section-17.2.5 +inline constexpr size_t kRetryIntegrityTagLength = 16; + +// By default, UnackedPacketsMap allocates buffer of 64 after the first packet +// is added. +inline constexpr int kDefaultUnackedPacketsInitialCapacity = 64; + +// Signifies that the QuicPacket will contain version of the protocol. +inline constexpr bool kIncludeVersion = true; +// Signifies that the QuicPacket will include a diversification nonce. +inline constexpr bool kIncludeDiversificationNonce = true; + +// Header key used to identify final offset on data stream when sending HTTP/2 +// trailing headers over QUIC. +QUIC_EXPORT_PRIVATE extern const char* const kFinalOffsetHeaderKey; + +// Default maximum delayed ack time, in ms. +// Uses a 25ms delayed ack timer. Helps with better signaling +// in low-bandwidth (< ~384 kbps), where an ack is sent per packet. +inline constexpr int64_t kDefaultDelayedAckTimeMs = 25; + +// Default minimum delayed ack time, in ms (used only for sender control of ack +// frequency). +inline constexpr uint32_t kDefaultMinAckDelayTimeMs = 5; + +// Default shift of the ACK delay in the IETF QUIC ACK frame. +inline constexpr uint32_t kDefaultAckDelayExponent = 3; + +// Minimum tail loss probe time in ms. +inline constexpr int64_t kMinTailLossProbeTimeoutMs = 10; + +// The timeout before the handshake succeeds. +inline constexpr int64_t kInitialIdleTimeoutSecs = 5; +// The maximum idle timeout that can be negotiated. +inline constexpr int64_t kMaximumIdleTimeoutSecs = 60 * 10; // 10 minutes. +// The default timeout for a connection until the crypto handshake succeeds. +inline constexpr int64_t kMaxTimeForCryptoHandshakeSecs = 10; // 10 secs. + +// Default limit on the number of undecryptable packets the connection buffers +// before the CHLO/SHLO arrive. +inline constexpr size_t kDefaultMaxUndecryptablePackets = 10; + +// Default ping timeout. +inline constexpr int64_t kPingTimeoutSecs = 15; // 15 secs. + +// Minimum number of RTTs between Server Config Updates (SCUP) sent to client. +inline constexpr int kMinIntervalBetweenServerConfigUpdatesRTTs = 10; + +// Minimum time between Server Config Updates (SCUP) sent to client. +inline constexpr int kMinIntervalBetweenServerConfigUpdatesMs = 1000; + +// Minimum number of packets between Server Config Updates (SCUP). +inline constexpr int kMinPacketsBetweenServerConfigUpdates = 100; + +// The number of open streams that a server will accept is set to be slightly +// larger than the negotiated limit. Immediately closing the connection if the +// client opens slightly too many streams is not ideal: the client may have sent +// a FIN that was lost, and simultaneously opened a new stream. The number of +// streams a server accepts is a fixed increment over the negotiated limit, or a +// percentage increase, whichever is larger. +inline constexpr float kMaxStreamsMultiplier = 1.1f; +inline constexpr int kMaxStreamsMinimumIncrement = 10; + +// Available streams are ones with IDs less than the highest stream that has +// been opened which have neither been opened or reset. The limit on the number +// of available streams is 10 times the limit on the number of open streams. +inline constexpr int kMaxAvailableStreamsMultiplier = 10; + +// Track the number of promises that are not yet claimed by a +// corresponding get. This must be smaller than +// kMaxAvailableStreamsMultiplier, because RST on a promised stream my +// create available streams entries. +inline constexpr int kMaxPromisedStreamsMultiplier = + kMaxAvailableStreamsMultiplier - 1; + +// The 1st PTO is armed with max of earliest in flight sent time + PTO +// delay and kFirstPtoSrttMultiplier * srtt from last in flight packet. +inline constexpr float kFirstPtoSrttMultiplier = 1.5; + +// The multiplier of RTT variation when calculating PTO timeout. +inline constexpr int kPtoRttvarMultiplier = 2; + +// TCP RFC calls for 1 second RTO however Linux differs from this default and +// define the minimum RTO to 200ms, we will use the same until we have data to +// support a higher or lower value. +inline constexpr const int64_t kMinRetransmissionTimeMs = 200; +// The delayed ack time must not be greater than half the min RTO. +static_assert(kDefaultDelayedAckTimeMs <= kMinRetransmissionTimeMs / 2, + "Delayed ack time must be less than or equal half the MinRTO"); + +// We define an unsigned 16-bit floating point value, inspired by IEEE floats +// (http://en.wikipedia.org/wiki/Half_precision_floating-point_format), +// with 5-bit exponent (bias 1), 11-bit mantissa (effective 12 with hidden +// bit) and denormals, but without signs, transfinites or fractions. Wire format +// 16 bits (little-endian byte order) are split into exponent (high 5) and +// mantissa (low 11) and decoded as: +// uint64_t value; +// if (exponent == 0) value = mantissa; +// else value = (mantissa | 1 << 11) << (exponent - 1) +inline constexpr int kUFloat16ExponentBits = 5; +inline constexpr int kUFloat16MaxExponent = + (1 << kUFloat16ExponentBits) - 2; // 30 +inline constexpr int kUFloat16MantissaBits = 16 - kUFloat16ExponentBits; // 11 +inline constexpr int kUFloat16MantissaEffectiveBits = + kUFloat16MantissaBits + 1; // 12 +inline constexpr uint64_t kUFloat16MaxValue = // 0x3FFC0000000 + ((UINT64_C(1) << kUFloat16MantissaEffectiveBits) - 1) + << kUFloat16MaxExponent; + +// kDiversificationNonceSize is the size, in bytes, of the nonce that a server +// may set in the packet header to ensure that its INITIAL keys are not +// duplicated. +inline constexpr size_t kDiversificationNonceSize = 32; + +// The largest gap in packets we'll accept without closing the connection. +// This will likely have to be tuned. +inline constexpr QuicPacketCount kMaxPacketGap = 5000; + +// The max number of sequence number intervals that +// QuicPeerIssuedConnetionIdManager can maintain. +inline constexpr size_t kMaxNumConnectionIdSequenceNumberIntervals = 20; + +// The maximum number of random padding bytes to add. +inline constexpr QuicByteCount kMaxNumRandomPaddingBytes = 256; + +// The size of stream send buffer data slice size in bytes. A data slice is +// piece of stream data stored in contiguous memory, and a stream frame can +// contain data from multiple data slices. +inline constexpr QuicByteCount kQuicStreamSendBufferSliceSize = 4 * 1024; + +// For When using Random Initial Packet Numbers, they can start +// anyplace in the range 1...((2^31)-1) or 0x7fffffff +QUIC_EXPORT_PRIVATE QuicPacketNumber MaxRandomInitialPacketNumber(); + +// Used to represent an invalid or no control frame id. +inline constexpr QuicControlFrameId kInvalidControlFrameId = 0; + +// The max length a stream can have. +inline constexpr QuicByteCount kMaxStreamLength = (UINT64_C(1) << 62) - 1; + +// The max value that can be encoded using IETF Var Ints. +inline constexpr uint64_t kMaxIetfVarInt = UINT64_C(0x3fffffffffffffff); + +// The maximum stream id value that is supported - (2^32)-1 +inline constexpr QuicStreamId kMaxQuicStreamId = 0xffffffff; + +// The maximum value that can be stored in a 32-bit QuicStreamCount. +inline constexpr QuicStreamCount kMaxQuicStreamCount = 0xffffffff; + +// Number of bytes reserved for packet header type. +inline constexpr size_t kPacketHeaderTypeSize = 1; + +// Number of bytes reserved for connection ID length. +inline constexpr size_t kConnectionIdLengthSize = 1; + +// Minimum length of random bytes in IETF stateless reset packet. +inline constexpr size_t kMinRandomBytesLengthInStatelessReset = 24; + +// Maximum length allowed for the token in a NEW_TOKEN frame. +inline constexpr size_t kMaxNewTokenTokenLength = 0xffff; + +// The prefix used by a source address token in a NEW_TOKEN frame. +inline constexpr uint8_t kAddressTokenPrefix = 0; + +// Default initial rtt used before any samples are received. +inline constexpr int kInitialRttMs = 100; + +// Default threshold of packet reordering before a packet is declared lost. +inline constexpr QuicPacketCount kDefaultPacketReorderingThreshold = 3; + +// Default fraction (1/4) of an RTT the algorithm waits before determining a +// packet is lost due to early retransmission by time based loss detection. +inline constexpr int kDefaultLossDelayShift = 2; + +// Default fraction (1/8) of an RTT when doing IETF loss detection. +inline constexpr int kDefaultIetfLossDelayShift = 3; + +// Maximum number of retransmittable packets received before sending an ack. +inline constexpr QuicPacketCount kDefaultRetransmittablePacketsBeforeAck = 2; +// Wait for up to 10 retransmittable packets before sending an ack. +inline constexpr QuicPacketCount kMaxRetransmittablePacketsBeforeAck = 10; +// Minimum number of packets received before ack decimation is enabled. +// This intends to avoid the beginning of slow start, when CWNDs may be +// rapidly increasing. +inline constexpr QuicPacketCount kMinReceivedBeforeAckDecimation = 100; +// One quarter RTT delay when doing ack decimation. +inline constexpr float kAckDecimationDelay = 0.25; + +// The default alarm granularity assumed by QUIC code. +inline constexpr QuicTime::Delta kAlarmGranularity = + QuicTime::Delta::FromMilliseconds(1); + +// Maximum number of unretired connection IDs a connection can have. +inline constexpr size_t kMaxNumConnectonIdsInUse = 10u; + +// Packet number of first sending packet of a connection. Please note, this +// cannot be used as first received packet because peer can choose its starting +// packet number. +QUIC_EXPORT_PRIVATE QuicPacketNumber FirstSendingPacketNumber(); + +// Used by clients to tell if a public reset is sent from a Google frontend. +QUIC_EXPORT_PRIVATE extern const char* const kEPIDGoogleFrontEnd; +QUIC_EXPORT_PRIVATE extern const char* const kEPIDGoogleFrontEnd0; + +inline constexpr uint64_t kHttpDatagramStreamIdDivisor = 4; + +inline constexpr QuicTime::Delta kDefaultMultiPortProbingInterval = + QuicTime::Delta::FromSeconds(3); + +inline constexpr size_t kMaxNumMultiPortPaths = 5; + +inline constexpr size_t kMaxDuplicatedPacketsSentToServerPreferredAddress = 5; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONSTANTS_H_ diff --git a/quiche/quic/core/quic_control_frame_manager.cc b/quiche/quic/core/quic_control_frame_manager.cc new file mode 100644 index 000000000000..cc231b1567dd --- /dev/null +++ b/quiche/quic/core/quic_control_frame_manager.cc @@ -0,0 +1,364 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_control_frame_manager.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/frames/quic_new_connection_id_frame.h" +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace { + +// The maximum number of buffered control frames which are waiting to be ACKed +// or sent for the first time. +const size_t kMaxNumControlFrames = 1000; + +} // namespace + +QuicControlFrameManager::QuicControlFrameManager(QuicSession* session) + : last_control_frame_id_(kInvalidControlFrameId), + least_unacked_(1), + least_unsent_(1), + delegate_(session) {} + +QuicControlFrameManager::~QuicControlFrameManager() { + while (!control_frames_.empty()) { + DeleteFrame(&control_frames_.front()); + control_frames_.pop_front(); + } +} + +void QuicControlFrameManager::WriteOrBufferQuicFrame(QuicFrame frame) { + const bool had_buffered_frames = HasBufferedFrames(); + control_frames_.emplace_back(frame); + if (control_frames_.size() > kMaxNumControlFrames) { + delegate_->OnControlFrameManagerError( + QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES, + absl::StrCat("More than ", kMaxNumControlFrames, + "buffered control frames, least_unacked: ", least_unacked_, + ", least_unsent_: ", least_unsent_)); + return; + } + if (had_buffered_frames) { + return; + } + WriteBufferedFrames(); +} + +void QuicControlFrameManager::WriteOrBufferRstStream( + QuicStreamId id, QuicResetStreamError error, + QuicStreamOffset bytes_written) { + QUIC_DVLOG(1) << "Writing RST_STREAM_FRAME"; + WriteOrBufferQuicFrame((QuicFrame(new QuicRstStreamFrame( + ++last_control_frame_id_, id, error, bytes_written)))); +} + +void QuicControlFrameManager::WriteOrBufferGoAway( + QuicErrorCode error, QuicStreamId last_good_stream_id, + const std::string& reason) { + QUIC_DVLOG(1) << "Writing GOAWAY_FRAME"; + WriteOrBufferQuicFrame(QuicFrame(new QuicGoAwayFrame( + ++last_control_frame_id_, error, last_good_stream_id, reason))); +} + +void QuicControlFrameManager::WriteOrBufferWindowUpdate( + QuicStreamId id, QuicStreamOffset byte_offset) { + QUIC_DVLOG(1) << "Writing WINDOW_UPDATE_FRAME"; + WriteOrBufferQuicFrame(QuicFrame( + QuicWindowUpdateFrame(++last_control_frame_id_, id, byte_offset))); +} + +void QuicControlFrameManager::WriteOrBufferBlocked( + QuicStreamId id, QuicStreamOffset byte_offset) { + QUIC_DVLOG(1) << "Writing BLOCKED_FRAME"; + WriteOrBufferQuicFrame( + QuicFrame(QuicBlockedFrame(++last_control_frame_id_, id, byte_offset))); +} + +void QuicControlFrameManager::WriteOrBufferStreamsBlocked(QuicStreamCount count, + bool unidirectional) { + QUIC_DVLOG(1) << "Writing STREAMS_BLOCKED Frame"; + QUIC_CODE_COUNT(quic_streams_blocked_transmits); + WriteOrBufferQuicFrame(QuicFrame(QuicStreamsBlockedFrame( + ++last_control_frame_id_, count, unidirectional))); +} + +void QuicControlFrameManager::WriteOrBufferMaxStreams(QuicStreamCount count, + bool unidirectional) { + QUIC_DVLOG(1) << "Writing MAX_STREAMS Frame"; + QUIC_CODE_COUNT(quic_max_streams_transmits); + WriteOrBufferQuicFrame(QuicFrame( + QuicMaxStreamsFrame(++last_control_frame_id_, count, unidirectional))); +} + +void QuicControlFrameManager::WriteOrBufferStopSending( + QuicResetStreamError error, QuicStreamId stream_id) { + QUIC_DVLOG(1) << "Writing STOP_SENDING_FRAME"; + WriteOrBufferQuicFrame(QuicFrame( + QuicStopSendingFrame(++last_control_frame_id_, stream_id, error))); +} + +void QuicControlFrameManager::WriteOrBufferHandshakeDone() { + QUIC_DVLOG(1) << "Writing HANDSHAKE_DONE"; + WriteOrBufferQuicFrame( + QuicFrame(QuicHandshakeDoneFrame(++last_control_frame_id_))); +} + +void QuicControlFrameManager::WriteOrBufferAckFrequency( + const QuicAckFrequencyFrame& ack_frequency_frame) { + QUIC_DVLOG(1) << "Writing ACK_FREQUENCY frame"; + QuicControlFrameId control_frame_id = ++last_control_frame_id_; + // Using the control_frame_id for sequence_number here leaves gaps in + // sequence_number. + WriteOrBufferQuicFrame( + QuicFrame(new QuicAckFrequencyFrame(control_frame_id, + /*sequence_number=*/control_frame_id, + ack_frequency_frame.packet_tolerance, + ack_frequency_frame.max_ack_delay))); +} + +void QuicControlFrameManager::WriteOrBufferNewConnectionId( + const QuicConnectionId& connection_id, uint64_t sequence_number, + uint64_t retire_prior_to, + const StatelessResetToken& stateless_reset_token) { + QUIC_DVLOG(1) << "Writing NEW_CONNECTION_ID frame"; + WriteOrBufferQuicFrame(QuicFrame(new QuicNewConnectionIdFrame( + ++last_control_frame_id_, connection_id, sequence_number, + stateless_reset_token, retire_prior_to))); +} + +void QuicControlFrameManager::WriteOrBufferRetireConnectionId( + uint64_t sequence_number) { + QUIC_DVLOG(1) << "Writing RETIRE_CONNECTION_ID frame"; + WriteOrBufferQuicFrame(QuicFrame(new QuicRetireConnectionIdFrame( + ++last_control_frame_id_, sequence_number))); +} + +void QuicControlFrameManager::WriteOrBufferNewToken(absl::string_view token) { + QUIC_DVLOG(1) << "Writing NEW_TOKEN frame"; + WriteOrBufferQuicFrame( + QuicFrame(new QuicNewTokenFrame(++last_control_frame_id_, token))); +} + +void QuicControlFrameManager::OnControlFrameSent(const QuicFrame& frame) { + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + QUIC_BUG(quic_bug_12727_1) + << "Send or retransmit a control frame with invalid control frame id"; + return; + } + if (frame.type == WINDOW_UPDATE_FRAME) { + QuicStreamId stream_id = frame.window_update_frame.stream_id; + if (window_update_frames_.contains(stream_id) && + id > window_update_frames_[stream_id]) { + // Consider the older window update of the same stream as acked. + OnControlFrameIdAcked(window_update_frames_[stream_id]); + } + window_update_frames_[stream_id] = id; + } + if (pending_retransmissions_.contains(id)) { + // This is retransmitted control frame. + pending_retransmissions_.erase(id); + return; + } + if (id > least_unsent_) { + QUIC_BUG(quic_bug_10517_1) + << "Try to send control frames out of order, id: " << id + << " least_unsent: " << least_unsent_; + delegate_->OnControlFrameManagerError( + QUIC_INTERNAL_ERROR, "Try to send control frames out of order"); + return; + } + ++least_unsent_; +} + +bool QuicControlFrameManager::OnControlFrameAcked(const QuicFrame& frame) { + QuicControlFrameId id = GetControlFrameId(frame); + if (!OnControlFrameIdAcked(id)) { + return false; + } + if (frame.type == WINDOW_UPDATE_FRAME) { + QuicStreamId stream_id = frame.window_update_frame.stream_id; + if (window_update_frames_.contains(stream_id) && + window_update_frames_[stream_id] == id) { + window_update_frames_.erase(stream_id); + } + } + return true; +} + +void QuicControlFrameManager::OnControlFrameLost(const QuicFrame& frame) { + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + // Frame does not have a valid control frame ID, ignore it. + return; + } + if (id >= least_unsent_) { + QUIC_BUG(quic_bug_10517_2) << "Try to mark unsent control frame as lost"; + delegate_->OnControlFrameManagerError( + QUIC_INTERNAL_ERROR, "Try to mark unsent control frame as lost"); + return; + } + if (id < least_unacked_ || + GetControlFrameId(control_frames_.at(id - least_unacked_)) == + kInvalidControlFrameId) { + // This frame has already been acked. + return; + } + if (!pending_retransmissions_.contains(id)) { + pending_retransmissions_[id] = true; + QUIC_BUG_IF(quic_bug_12727_2, + pending_retransmissions_.size() > control_frames_.size()) + << "least_unacked_: " << least_unacked_ + << ", least_unsent_: " << least_unsent_; + } +} + +bool QuicControlFrameManager::IsControlFrameOutstanding( + const QuicFrame& frame) const { + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + // Frame without a control frame ID should not be retransmitted. + return false; + } + // Consider this frame is outstanding if it does not get acked. + return id < least_unacked_ + control_frames_.size() && id >= least_unacked_ && + GetControlFrameId(control_frames_.at(id - least_unacked_)) != + kInvalidControlFrameId; +} + +bool QuicControlFrameManager::HasPendingRetransmission() const { + return !pending_retransmissions_.empty(); +} + +bool QuicControlFrameManager::WillingToWrite() const { + return HasPendingRetransmission() || HasBufferedFrames(); +} + +QuicFrame QuicControlFrameManager::NextPendingRetransmission() const { + QUIC_BUG_IF(quic_bug_12727_3, pending_retransmissions_.empty()) + << "Unexpected call to NextPendingRetransmission() with empty pending " + << "retransmission list."; + QuicControlFrameId id = pending_retransmissions_.begin()->first; + return control_frames_.at(id - least_unacked_); +} + +void QuicControlFrameManager::OnCanWrite() { + if (HasPendingRetransmission()) { + // Exit early to allow streams to write pending retransmissions if any. + WritePendingRetransmission(); + return; + } + WriteBufferedFrames(); +} + +bool QuicControlFrameManager::RetransmitControlFrame(const QuicFrame& frame, + TransmissionType type) { + QUICHE_DCHECK(type == PTO_RETRANSMISSION); + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + // Frame does not have a valid control frame ID, ignore it. Returns true + // to allow writing following frames. + return true; + } + if (id >= least_unsent_) { + QUIC_BUG(quic_bug_10517_3) << "Try to retransmit unsent control frame"; + delegate_->OnControlFrameManagerError( + QUIC_INTERNAL_ERROR, "Try to retransmit unsent control frame"); + return false; + } + if (id < least_unacked_ || + GetControlFrameId(control_frames_.at(id - least_unacked_)) == + kInvalidControlFrameId) { + // This frame has already been acked. + return true; + } + QuicFrame copy = CopyRetransmittableControlFrame(frame); + QUIC_DVLOG(1) << "control frame manager is forced to retransmit frame: " + << frame; + if (delegate_->WriteControlFrame(copy, type)) { + return true; + } + DeleteFrame(©); + return false; +} + +void QuicControlFrameManager::WriteBufferedFrames() { + while (HasBufferedFrames()) { + QuicFrame frame_to_send = + control_frames_.at(least_unsent_ - least_unacked_); + QuicFrame copy = CopyRetransmittableControlFrame(frame_to_send); + if (!delegate_->WriteControlFrame(copy, NOT_RETRANSMISSION)) { + // Connection is write blocked. + DeleteFrame(©); + break; + } + OnControlFrameSent(frame_to_send); + } +} + +void QuicControlFrameManager::WritePendingRetransmission() { + while (HasPendingRetransmission()) { + QuicFrame pending = NextPendingRetransmission(); + QuicFrame copy = CopyRetransmittableControlFrame(pending); + if (!delegate_->WriteControlFrame(copy, LOSS_RETRANSMISSION)) { + // Connection is write blocked. + DeleteFrame(©); + break; + } + OnControlFrameSent(pending); + } +} + +bool QuicControlFrameManager::OnControlFrameIdAcked(QuicControlFrameId id) { + if (id == kInvalidControlFrameId) { + // Frame does not have a valid control frame ID, ignore it. + return false; + } + if (id >= least_unsent_) { + QUIC_BUG(quic_bug_10517_4) << "Try to ack unsent control frame"; + delegate_->OnControlFrameManagerError(QUIC_INTERNAL_ERROR, + "Try to ack unsent control frame"); + return false; + } + if (id < least_unacked_ || + GetControlFrameId(control_frames_.at(id - least_unacked_)) == + kInvalidControlFrameId) { + // This frame has already been acked. + return false; + } + + // Set control frame ID of acked frames to 0. + SetControlFrameId(kInvalidControlFrameId, + &control_frames_.at(id - least_unacked_)); + // Remove acked control frames from pending retransmissions. + pending_retransmissions_.erase(id); + // Clean up control frames queue and increment least_unacked_. + while (!control_frames_.empty() && + GetControlFrameId(control_frames_.front()) == kInvalidControlFrameId) { + DeleteFrame(&control_frames_.front()); + control_frames_.pop_front(); + ++least_unacked_; + } + return true; +} + +bool QuicControlFrameManager::HasBufferedFrames() const { + return least_unsent_ < least_unacked_ + control_frames_.size(); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_control_frame_manager.h b/quiche/quic/core/quic_control_frame_manager.h new file mode 100644 index 000000000000..46a0f47e5cef --- /dev/null +++ b/quiche/quic/core/quic_control_frame_manager.h @@ -0,0 +1,192 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONTROL_FRAME_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_CONTROL_FRAME_MANAGER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +class QuicSession; + +namespace test { +class QuicControlFrameManagerPeer; +} // namespace test + +// Control frame manager contains a list of sent control frames with valid +// control frame IDs. Control frames without valid control frame IDs include: +// (1) non-retransmittable frames (e.g., ACK_FRAME, PADDING_FRAME, +// STOP_WAITING_FRAME, etc.), (2) CONNECTION_CLOSE and IETF Quic +// APPLICATION_CLOSE frames. +// New control frames are added to the tail of the list when they are added to +// the generator. Control frames are removed from the head of the list when they +// get acked. Control frame manager also keeps track of lost control frames +// which need to be retransmitted. +class QUIC_EXPORT_PRIVATE QuicControlFrameManager { + public: + class QUIC_EXPORT_PRIVATE DelegateInterface { + public: + virtual ~DelegateInterface() = default; + + // Notifies the delegate of errors. + virtual void OnControlFrameManagerError(QuicErrorCode error_code, + std::string error_details) = 0; + + virtual bool WriteControlFrame(const QuicFrame& frame, + TransmissionType type) = 0; + }; + + explicit QuicControlFrameManager(QuicSession* session); + QuicControlFrameManager(const QuicControlFrameManager& other) = delete; + QuicControlFrameManager(QuicControlFrameManager&& other) = delete; + ~QuicControlFrameManager(); + + // Tries to send a WINDOW_UPDATE_FRAME. Buffers the frame if it cannot be sent + // immediately. + void WriteOrBufferRstStream(QuicControlFrameId id, QuicResetStreamError error, + QuicStreamOffset bytes_written); + + // Tries to send a GOAWAY_FRAME. Buffers the frame if it cannot be sent + // immediately. + void WriteOrBufferGoAway(QuicErrorCode error, + QuicStreamId last_good_stream_id, + const std::string& reason); + + // Tries to send a WINDOW_UPDATE_FRAME. Buffers the frame if it cannot be sent + // immediately. + void WriteOrBufferWindowUpdate(QuicStreamId id, QuicStreamOffset byte_offset); + + // Tries to send a BLOCKED_FRAME. Buffers the frame if it cannot be sent + // immediately. + void WriteOrBufferBlocked(QuicStreamId id, QuicStreamOffset byte_offset); + + // Tries to send a STREAMS_BLOCKED Frame. Buffers the frame if it cannot be + // sent immediately. + void WriteOrBufferStreamsBlocked(QuicStreamCount count, bool unidirectional); + + // Tries to send a MAX_STREAMS Frame. Buffers the frame if it cannot be sent + // immediately. + void WriteOrBufferMaxStreams(QuicStreamCount count, bool unidirectional); + + // Tries to send an IETF-QUIC STOP_SENDING frame. The frame is buffered if it + // can not be sent immediately. + void WriteOrBufferStopSending(QuicResetStreamError error, + QuicStreamId stream_id); + + // Tries to send an HANDSHAKE_DONE frame. The frame is buffered if it can not + // be sent immediately. + void WriteOrBufferHandshakeDone(); + + // Tries to send an AckFrequencyFrame. The frame is buffered if it cannot be + // sent immediately. + void WriteOrBufferAckFrequency( + const QuicAckFrequencyFrame& ack_frequency_frame); + + // Tries to send a NEW_CONNECTION_ID frame. The frame is buffered if it cannot + // be sent immediately. + void WriteOrBufferNewConnectionId( + const QuicConnectionId& connection_id, uint64_t sequence_number, + uint64_t retire_prior_to, + const StatelessResetToken& stateless_reset_token); + + // Tries to send a RETIRE_CONNNECTION_ID frame. The frame is buffered if it + // cannot be sent immediately. + void WriteOrBufferRetireConnectionId(uint64_t sequence_number); + + // Tries to send a NEW_TOKEN frame. Buffers the frame if it cannot be sent + // immediately. + void WriteOrBufferNewToken(absl::string_view token); + + // Called when |frame| gets acked. Returns true if |frame| gets acked for the + // first time, return false otherwise. + bool OnControlFrameAcked(const QuicFrame& frame); + + // Called when |frame| is considered as lost. + void OnControlFrameLost(const QuicFrame& frame); + + // Called by the session when the connection becomes writable. + void OnCanWrite(); + + // Retransmit |frame| if it is still outstanding. Returns false if the frame + // does not get retransmitted because the connection is blocked. Otherwise, + // returns true. + bool RetransmitControlFrame(const QuicFrame& frame, TransmissionType type); + + // Returns true if |frame| is outstanding and waiting to be acked. Returns + // false otherwise. + bool IsControlFrameOutstanding(const QuicFrame& frame) const; + + // Returns true if there is any lost control frames waiting to be + // retransmitted. + bool HasPendingRetransmission() const; + + // Returns true if there are any lost or new control frames waiting to be + // sent. + bool WillingToWrite() const; + + private: + friend class test::QuicControlFrameManagerPeer; + + // Tries to write buffered control frames to the peer. + void WriteBufferedFrames(); + + // Called when |frame| is sent for the first time or gets retransmitted. + void OnControlFrameSent(const QuicFrame& frame); + + // Writes pending retransmissions if any. + void WritePendingRetransmission(); + + // Called when frame with |id| gets acked. Returns true if |id| gets acked for + // the first time, return false otherwise. + bool OnControlFrameIdAcked(QuicControlFrameId id); + + // Retrieves the next pending retransmission. This must only be called when + // there are pending retransmissions. + QuicFrame NextPendingRetransmission() const; + + // Returns true if there are buffered frames waiting to be sent for the first + // time. + bool HasBufferedFrames() const; + + // Writes or buffers a control frame. Frame is buffered if there already + // are frames waiting to be sent. If no others waiting, will try to send the + // frame. + void WriteOrBufferQuicFrame(QuicFrame frame); + + quiche::QuicheCircularDeque control_frames_; + + // Id of latest saved control frame. 0 if no control frame has been saved. + QuicControlFrameId last_control_frame_id_; + + // The control frame at the 0th index of control_frames_. + QuicControlFrameId least_unacked_; + + // ID of the least unsent control frame. + QuicControlFrameId least_unsent_; + + // TODO(fayang): switch to linked_hash_set when chromium supports it. The bool + // is not used here. + // Lost control frames waiting to be retransmitted. + quiche::QuicheLinkedHashMap + pending_retransmissions_; + + DelegateInterface* delegate_; + + // Last sent window update frame for each stream. + absl::flat_hash_map window_update_frames_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONTROL_FRAME_MANAGER_H_ diff --git a/quiche/quic/core/quic_control_frame_manager_test.cc b/quiche/quic/core/quic_control_frame_manager_test.cc new file mode 100644 index 000000000000..e1d2538a05f3 --- /dev/null +++ b/quiche/quic/core/quic_control_frame_manager_test.cc @@ -0,0 +1,363 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_control_frame_manager.h" + +#include + +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/frames/quic_retire_connection_id_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::InSequence; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { + +class QuicControlFrameManagerPeer { + public: + static size_t QueueSize(QuicControlFrameManager* manager) { + return manager->control_frames_.size(); + } +}; + +namespace { + +const QuicStreamId kTestStreamId = 5; +const QuicRstStreamErrorCode kTestStopSendingCode = + QUIC_STREAM_ENCODER_STREAM_ERROR; + +class QuicControlFrameManagerTest : public QuicTest { + public: + bool SaveControlFrame(const QuicFrame& frame, TransmissionType /*type*/) { + frame_ = frame; + return true; + } + + protected: + // Pre-fills the control frame queue with the following frames: + // ID Type + // 1 RST_STREAM + // 2 GO_AWAY + // 3 WINDOW_UPDATE + // 4 BLOCKED + // 5 STOP_SENDING + // This is verified. The tests then perform manipulations on these. + void Initialize() { + connection_ = new MockQuicConnection(&helper_, &alarm_factory_, + Perspective::IS_SERVER); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + session_ = std::make_unique>(connection_); + manager_ = std::make_unique(session_.get()); + EXPECT_EQ(0u, QuicControlFrameManagerPeer::QueueSize(manager_.get())); + EXPECT_FALSE(manager_->HasPendingRetransmission()); + EXPECT_FALSE(manager_->WillingToWrite()); + + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + manager_->WriteOrBufferRstStream( + kTestStreamId, + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), 0); + manager_->WriteOrBufferGoAway(QUIC_PEER_GOING_AWAY, kTestStreamId, + "Going away."); + manager_->WriteOrBufferWindowUpdate(kTestStreamId, 100); + manager_->WriteOrBufferBlocked(kTestStreamId, 0); + manager_->WriteOrBufferStopSending( + QuicResetStreamError::FromInternal(kTestStopSendingCode), + kTestStreamId); + number_of_frames_ = 5u; + EXPECT_EQ(number_of_frames_, + QuicControlFrameManagerPeer::QueueSize(manager_.get())); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(&rst_stream_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(&goaway_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(window_update_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(blocked_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(stop_sending_))); + + EXPECT_FALSE(manager_->HasPendingRetransmission()); + EXPECT_TRUE(manager_->WillingToWrite()); + } + + QuicRstStreamFrame rst_stream_ = {1, kTestStreamId, QUIC_STREAM_CANCELLED, 0}; + QuicGoAwayFrame goaway_ = {2, QUIC_PEER_GOING_AWAY, kTestStreamId, + "Going away."}; + QuicWindowUpdateFrame window_update_ = {3, kTestStreamId, 100}; + QuicBlockedFrame blocked_ = {4, kTestStreamId, 0}; + QuicStopSendingFrame stop_sending_ = {5, kTestStreamId, kTestStopSendingCode}; + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnection* connection_; + std::unique_ptr> session_; + std::unique_ptr manager_; + QuicFrame frame_; + size_t number_of_frames_; +}; + +TEST_F(QuicControlFrameManagerTest, OnControlFrameAcked) { + Initialize(); + InSequence s; + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(3) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + // Send control frames 1, 2, 3. + manager_->OnCanWrite(); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(&rst_stream_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(&goaway_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(window_update_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(blocked_))); + EXPECT_TRUE(manager_->IsControlFrameOutstanding(QuicFrame(stop_sending_))); + + EXPECT_TRUE(manager_->OnControlFrameAcked(QuicFrame(window_update_))); + EXPECT_FALSE(manager_->IsControlFrameOutstanding(QuicFrame(window_update_))); + EXPECT_EQ(number_of_frames_, + QuicControlFrameManagerPeer::QueueSize(manager_.get())); + + EXPECT_TRUE(manager_->OnControlFrameAcked(QuicFrame(&goaway_))); + EXPECT_FALSE(manager_->IsControlFrameOutstanding(QuicFrame(&goaway_))); + EXPECT_EQ(number_of_frames_, + QuicControlFrameManagerPeer::QueueSize(manager_.get())); + EXPECT_TRUE(manager_->OnControlFrameAcked(QuicFrame(&rst_stream_))); + EXPECT_FALSE(manager_->IsControlFrameOutstanding(QuicFrame(&rst_stream_))); + // Only after the first frame in the queue is acked do the frames get + // removed ... now see that the length has been reduced by 3. + EXPECT_EQ(number_of_frames_ - 3u, + QuicControlFrameManagerPeer::QueueSize(manager_.get())); + // Duplicate ack. + EXPECT_FALSE(manager_->OnControlFrameAcked(QuicFrame(&goaway_))); + + EXPECT_FALSE(manager_->HasPendingRetransmission()); + EXPECT_TRUE(manager_->WillingToWrite()); + + // Send control frames 4, 5. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + EXPECT_FALSE(manager_->WillingToWrite()); +} + +TEST_F(QuicControlFrameManagerTest, OnControlFrameLost) { + Initialize(); + InSequence s; + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(3) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + // Send control frames 1, 2, 3. + manager_->OnCanWrite(); + + // Lost control frames 1, 2, 3. + manager_->OnControlFrameLost(QuicFrame(&rst_stream_)); + manager_->OnControlFrameLost(QuicFrame(&goaway_)); + manager_->OnControlFrameLost(QuicFrame(window_update_)); + EXPECT_TRUE(manager_->HasPendingRetransmission()); + + // Ack control frame 2. + manager_->OnControlFrameAcked(QuicFrame(&goaway_)); + + // Retransmit control frames 1, 3. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + EXPECT_FALSE(manager_->HasPendingRetransmission()); + EXPECT_TRUE(manager_->WillingToWrite()); + + // Send control frames 4, 5, and 6. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(number_of_frames_ - 3u) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + EXPECT_FALSE(manager_->WillingToWrite()); +} + +TEST_F(QuicControlFrameManagerTest, RetransmitControlFrame) { + Initialize(); + InSequence s; + // Send control frames 1, 2, 3, 4. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(number_of_frames_) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + + // Ack control frame 2. + manager_->OnControlFrameAcked(QuicFrame(&goaway_)); + // Do not retransmit an acked frame + EXPECT_CALL(*session_, WriteControlFrame(_, _)).Times(0); + EXPECT_TRUE(manager_->RetransmitControlFrame(QuicFrame(&goaway_), + PTO_RETRANSMISSION)); + + // Retransmit control frame 3. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillOnce(Invoke(&ClearControlFrameWithTransmissionType)); + EXPECT_TRUE(manager_->RetransmitControlFrame(QuicFrame(window_update_), + PTO_RETRANSMISSION)); + + // Retransmit control frame 4, and connection is write blocked. + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + EXPECT_FALSE(manager_->RetransmitControlFrame(QuicFrame(window_update_), + PTO_RETRANSMISSION)); +} + +TEST_F(QuicControlFrameManagerTest, SendAndAckAckFrequencyFrame) { + Initialize(); + InSequence s; + // Send Non-AckFrequency frame 1-5. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(5) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + manager_->OnCanWrite(); + + // Send AckFrequencyFrame as frame 6. + QuicAckFrequencyFrame frame_to_send; + frame_to_send.packet_tolerance = 10; + frame_to_send.max_ack_delay = QuicTime::Delta::FromMilliseconds(24); + manager_->WriteOrBufferAckFrequency(frame_to_send); + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillOnce(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + + // Ack AckFrequencyFrame. + QuicAckFrequencyFrame expected_ack_frequency = { + 6, 6, 10, QuicTime::Delta::FromMilliseconds(24)}; + EXPECT_TRUE( + manager_->OnControlFrameAcked(QuicFrame(&expected_ack_frequency))); +} + +TEST_F(QuicControlFrameManagerTest, NewAndRetireConnectionIdFrames) { + Initialize(); + InSequence s; + + // Send other frames 1-5. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(5) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + // Send NewConnectionIdFrame as frame 6. + manager_->WriteOrBufferNewConnectionId( + TestConnectionId(3), /*sequence_number=*/2, /*retire_prior_to=*/1, + /*stateless_reset_token=*/ + {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}); + // Send RetireConnectionIdFrame as frame 7. + manager_->WriteOrBufferRetireConnectionId(/*sequence_number=*/0); + manager_->OnCanWrite(); + + // Ack both frames. + QuicNewConnectionIdFrame new_connection_id_frame; + new_connection_id_frame.control_frame_id = 6; + QuicRetireConnectionIdFrame retire_connection_id_frame; + retire_connection_id_frame.control_frame_id = 7; + EXPECT_TRUE( + manager_->OnControlFrameAcked(QuicFrame(&new_connection_id_frame))); + EXPECT_TRUE( + manager_->OnControlFrameAcked(QuicFrame(&retire_connection_id_frame))); +} + +TEST_F(QuicControlFrameManagerTest, DonotRetransmitOldWindowUpdates) { + Initialize(); + // Send two more window updates of the same stream. + manager_->WriteOrBufferWindowUpdate(kTestStreamId, 200); + QuicWindowUpdateFrame window_update2(number_of_frames_ + 1, kTestStreamId, + 200); + + manager_->WriteOrBufferWindowUpdate(kTestStreamId, 300); + QuicWindowUpdateFrame window_update3(number_of_frames_ + 2, kTestStreamId, + 300); + InSequence s; + // Flush all buffered control frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + + // Mark all 3 window updates as lost. + manager_->OnControlFrameLost(QuicFrame(window_update_)); + manager_->OnControlFrameLost(QuicFrame(window_update2)); + manager_->OnControlFrameLost(QuicFrame(window_update3)); + EXPECT_TRUE(manager_->HasPendingRetransmission()); + EXPECT_TRUE(manager_->WillingToWrite()); + + // Verify only the latest window update gets retransmitted. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillOnce(Invoke(this, &QuicControlFrameManagerTest::SaveControlFrame)); + manager_->OnCanWrite(); + EXPECT_EQ(number_of_frames_ + 2u, + frame_.window_update_frame.control_frame_id); + EXPECT_FALSE(manager_->HasPendingRetransmission()); + EXPECT_FALSE(manager_->WillingToWrite()); + DeleteFrame(&frame_); +} + +TEST_F(QuicControlFrameManagerTest, RetransmitWindowUpdateOfDifferentStreams) { + Initialize(); + // Send two more window updates of different streams. + manager_->WriteOrBufferWindowUpdate(kTestStreamId + 2, 200); + QuicWindowUpdateFrame window_update2(5, kTestStreamId + 2, 200); + + manager_->WriteOrBufferWindowUpdate(kTestStreamId + 4, 300); + QuicWindowUpdateFrame window_update3(6, kTestStreamId + 4, 300); + InSequence s; + // Flush all buffered control frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + + // Mark all 3 window updates as lost. + manager_->OnControlFrameLost(QuicFrame(window_update_)); + manager_->OnControlFrameLost(QuicFrame(window_update2)); + manager_->OnControlFrameLost(QuicFrame(window_update3)); + EXPECT_TRUE(manager_->HasPendingRetransmission()); + EXPECT_TRUE(manager_->WillingToWrite()); + + // Verify all 3 window updates get retransmitted. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(3) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + manager_->OnCanWrite(); + EXPECT_FALSE(manager_->HasPendingRetransmission()); + EXPECT_FALSE(manager_->WillingToWrite()); +} + +TEST_F(QuicControlFrameManagerTest, TooManyBufferedControlFrames) { + Initialize(); + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(5) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + // Flush buffered frames. + manager_->OnCanWrite(); + // Write 995 control frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); + for (size_t i = 0; i < 995; ++i) { + manager_->WriteOrBufferRstStream( + kTestStreamId, + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), 0); + } + // Verify write one more control frame causes connection close. + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES, _, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + manager_->WriteOrBufferRstStream( + kTestStreamId, QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), + 0); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_client_handshaker.cc b/quiche/quic/core/quic_crypto_client_handshaker.cc new file mode 100644 index 000000000000..0f980cbd061a --- /dev/null +++ b/quiche/quic/core/quic_crypto_client_handshaker.cc @@ -0,0 +1,634 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_client_handshaker.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_client_stats.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +QuicCryptoClientHandshaker::ProofVerifierCallbackImpl:: + ProofVerifierCallbackImpl(QuicCryptoClientHandshaker* parent) + : parent_(parent) {} + +QuicCryptoClientHandshaker::ProofVerifierCallbackImpl:: + ~ProofVerifierCallbackImpl() {} + +void QuicCryptoClientHandshaker::ProofVerifierCallbackImpl::Run( + bool ok, const std::string& error_details, + std::unique_ptr* details) { + if (parent_ == nullptr) { + return; + } + + parent_->verify_ok_ = ok; + parent_->verify_error_details_ = error_details; + parent_->verify_details_ = std::move(*details); + parent_->proof_verify_callback_ = nullptr; + parent_->DoHandshakeLoop(nullptr); + + // The ProofVerifier owns this object and will delete it when this method + // returns. +} + +void QuicCryptoClientHandshaker::ProofVerifierCallbackImpl::Cancel() { + parent_ = nullptr; +} + +QuicCryptoClientHandshaker::QuicCryptoClientHandshaker( + const QuicServerId& server_id, QuicCryptoClientStream* stream, + QuicSession* session, std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, + QuicCryptoClientStream::ProofHandler* proof_handler) + : QuicCryptoHandshaker(stream, session), + stream_(stream), + session_(session), + delegate_(session), + next_state_(STATE_IDLE), + num_client_hellos_(0), + crypto_config_(crypto_config), + server_id_(server_id), + generation_counter_(0), + verify_context_(std::move(verify_context)), + proof_verify_callback_(nullptr), + proof_handler_(proof_handler), + verify_ok_(false), + proof_verify_start_time_(QuicTime::Zero()), + num_scup_messages_received_(0), + encryption_established_(false), + one_rtt_keys_available_(false), + crypto_negotiated_params_(new QuicCryptoNegotiatedParameters) {} + +QuicCryptoClientHandshaker::~QuicCryptoClientHandshaker() { + if (proof_verify_callback_) { + proof_verify_callback_->Cancel(); + } +} + +void QuicCryptoClientHandshaker::OnHandshakeMessage( + const CryptoHandshakeMessage& message) { + QuicCryptoHandshaker::OnHandshakeMessage(message); + if (message.tag() == kSCUP) { + if (!one_rtt_keys_available()) { + stream_->OnUnrecoverableError( + QUIC_CRYPTO_UPDATE_BEFORE_HANDSHAKE_COMPLETE, + "Early SCUP disallowed"); + return; + } + + // |message| is an update from the server, so we treat it differently from a + // handshake message. + HandleServerConfigUpdateMessage(message); + num_scup_messages_received_++; + return; + } + + // Do not process handshake messages after the handshake is confirmed. + if (one_rtt_keys_available()) { + stream_->OnUnrecoverableError(QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE, + "Unexpected handshake message"); + return; + } + + DoHandshakeLoop(&message); +} + +bool QuicCryptoClientHandshaker::CryptoConnect() { + next_state_ = STATE_INITIALIZE; + DoHandshakeLoop(nullptr); + return session()->connection()->connected(); +} + +int QuicCryptoClientHandshaker::num_sent_client_hellos() const { + return num_client_hellos_; +} + +bool QuicCryptoClientHandshaker::IsResumption() const { + QUIC_BUG_IF(quic_bug_12522_1, !one_rtt_keys_available_); + // While 0-RTT handshakes could be considered to be like resumption, QUIC + // Crypto doesn't have the same notion of a resumption like TLS does. + return false; +} + +bool QuicCryptoClientHandshaker::EarlyDataAccepted() const { + QUIC_BUG_IF(quic_bug_12522_2, !one_rtt_keys_available_); + return num_client_hellos_ == 1; +} + +ssl_early_data_reason_t QuicCryptoClientHandshaker::EarlyDataReason() const { + return early_data_reason_; +} + +bool QuicCryptoClientHandshaker::ReceivedInchoateReject() const { + QUIC_BUG_IF(quic_bug_12522_3, !one_rtt_keys_available_); + return num_client_hellos_ >= 3; +} + +int QuicCryptoClientHandshaker::num_scup_messages_received() const { + return num_scup_messages_received_; +} + +std::string QuicCryptoClientHandshaker::chlo_hash() const { return chlo_hash_; } + +bool QuicCryptoClientHandshaker::encryption_established() const { + return encryption_established_; +} + +bool QuicCryptoClientHandshaker::IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel /*level*/) const { + return true; +} + +EncryptionLevel +QuicCryptoClientHandshaker::GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const { + if (space == INITIAL_DATA) { + return ENCRYPTION_INITIAL; + } + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; +} + +bool QuicCryptoClientHandshaker::one_rtt_keys_available() const { + return one_rtt_keys_available_; +} + +const QuicCryptoNegotiatedParameters& +QuicCryptoClientHandshaker::crypto_negotiated_params() const { + return *crypto_negotiated_params_; +} + +CryptoMessageParser* QuicCryptoClientHandshaker::crypto_message_parser() { + return QuicCryptoHandshaker::crypto_message_parser(); +} + +HandshakeState QuicCryptoClientHandshaker::GetHandshakeState() const { + return one_rtt_keys_available() ? HANDSHAKE_COMPLETE : HANDSHAKE_START; +} + +void QuicCryptoClientHandshaker::OnHandshakeDoneReceived() { + QUICHE_DCHECK(false); +} + +void QuicCryptoClientHandshaker::OnNewTokenReceived( + absl::string_view /*token*/) { + QUICHE_DCHECK(false); +} + +size_t QuicCryptoClientHandshaker::BufferSizeLimitForLevel( + EncryptionLevel level) const { + return QuicCryptoHandshaker::BufferSizeLimitForLevel(level); +} + +std::unique_ptr +QuicCryptoClientHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + // Key update is only defined in QUIC+TLS. + QUICHE_DCHECK(false); + return nullptr; +} + +std::unique_ptr +QuicCryptoClientHandshaker::CreateCurrentOneRttEncrypter() { + // Key update is only defined in QUIC+TLS. + QUICHE_DCHECK(false); + return nullptr; +} + +void QuicCryptoClientHandshaker::OnConnectionClosed( + QuicErrorCode /*error*/, ConnectionCloseSource /*source*/) { + next_state_ = STATE_CONNECTION_CLOSED; +} + +void QuicCryptoClientHandshaker::HandleServerConfigUpdateMessage( + const CryptoHandshakeMessage& server_config_update) { + QUICHE_DCHECK(server_config_update.tag() == kSCUP); + std::string error_details; + QuicCryptoClientConfig::CachedState* cached = + crypto_config_->LookupOrCreate(server_id_); + QuicErrorCode error = crypto_config_->ProcessServerConfigUpdate( + server_config_update, session()->connection()->clock()->WallNow(), + session()->transport_version(), chlo_hash_, cached, + crypto_negotiated_params_, &error_details); + + if (error != QUIC_NO_ERROR) { + stream_->OnUnrecoverableError( + error, "Server config update invalid: " + error_details); + return; + } + + QUICHE_DCHECK(one_rtt_keys_available()); + if (proof_verify_callback_) { + proof_verify_callback_->Cancel(); + } + next_state_ = STATE_INITIALIZE_SCUP; + DoHandshakeLoop(nullptr); +} + +void QuicCryptoClientHandshaker::DoHandshakeLoop( + const CryptoHandshakeMessage* in) { + QuicCryptoClientConfig::CachedState* cached = + crypto_config_->LookupOrCreate(server_id_); + + QuicAsyncStatus rv = QUIC_SUCCESS; + do { + QUICHE_CHECK_NE(STATE_NONE, next_state_); + const State state = next_state_; + next_state_ = STATE_IDLE; + rv = QUIC_SUCCESS; + switch (state) { + case STATE_INITIALIZE: + DoInitialize(cached); + break; + case STATE_SEND_CHLO: + DoSendCHLO(cached); + return; // return waiting to hear from server. + case STATE_RECV_REJ: + DoReceiveREJ(in, cached); + break; + case STATE_VERIFY_PROOF: + rv = DoVerifyProof(cached); + break; + case STATE_VERIFY_PROOF_COMPLETE: + DoVerifyProofComplete(cached); + break; + case STATE_RECV_SHLO: + DoReceiveSHLO(in, cached); + break; + case STATE_IDLE: + // This means that the peer sent us a message that we weren't expecting. + stream_->OnUnrecoverableError(QUIC_INVALID_CRYPTO_MESSAGE_TYPE, + "Handshake in idle state"); + return; + case STATE_INITIALIZE_SCUP: + DoInitializeServerConfigUpdate(cached); + break; + case STATE_NONE: + QUICHE_NOTREACHED(); + return; + case STATE_CONNECTION_CLOSED: + rv = QUIC_FAILURE; + return; // We are done. + } + } while (rv != QUIC_PENDING && next_state_ != STATE_NONE); +} + +void QuicCryptoClientHandshaker::DoInitialize( + QuicCryptoClientConfig::CachedState* cached) { + if (!cached->IsEmpty() && !cached->signature().empty()) { + // Note that we verify the proof even if the cached proof is valid. + // This allows us to respond to CA trust changes or certificate + // expiration because it may have been a while since we last verified + // the proof. + QUICHE_DCHECK(crypto_config_->proof_verifier()); + // Track proof verification time when cached server config is used. + proof_verify_start_time_ = session()->connection()->clock()->Now(); + chlo_hash_ = cached->chlo_hash(); + // If the cached state needs to be verified, do it now. + next_state_ = STATE_VERIFY_PROOF; + } else { + next_state_ = STATE_SEND_CHLO; + } +} + +void QuicCryptoClientHandshaker::DoSendCHLO( + QuicCryptoClientConfig::CachedState* cached) { + // Send the client hello in plaintext. + session()->connection()->SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + encryption_established_ = false; + if (num_client_hellos_ >= QuicCryptoClientStream::kMaxClientHellos) { + stream_->OnUnrecoverableError( + QUIC_CRYPTO_TOO_MANY_REJECTS, + absl::StrCat("More than ", QuicCryptoClientStream::kMaxClientHellos, + " rejects")); + return; + } + num_client_hellos_++; + + CryptoHandshakeMessage out; + QUICHE_DCHECK(session() != nullptr); + QUICHE_DCHECK(session()->config() != nullptr); + // Send all the options, regardless of whether we're sending an + // inchoate or subsequent hello. + session()->config()->ToHandshakeMessage(&out, session()->transport_version()); + + bool fill_inchoate_client_hello = false; + if (!cached->IsComplete(session()->connection()->clock()->WallNow())) { + early_data_reason_ = ssl_early_data_no_session_offered; + fill_inchoate_client_hello = true; + } else if (session()->config()->HasClientRequestedIndependentOption( + kQNZ2, session()->perspective()) && + num_client_hellos_ == 1) { + early_data_reason_ = ssl_early_data_disabled; + fill_inchoate_client_hello = true; + } + if (fill_inchoate_client_hello) { + crypto_config_->FillInchoateClientHello( + server_id_, session()->supported_versions().front(), cached, + session()->connection()->random_generator(), + /* demand_x509_proof= */ true, crypto_negotiated_params_, &out); + // Pad the inchoate client hello to fill up a packet. + const QuicByteCount kFramingOverhead = 50; // A rough estimate. + const QuicByteCount max_packet_size = + session()->connection()->max_packet_length(); + if (max_packet_size <= kFramingOverhead) { + QUIC_DLOG(DFATAL) << "max_packet_length (" << max_packet_size + << ") has no room for framing overhead."; + stream_->OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "max_packet_size too smalll"); + return; + } + if (kClientHelloMinimumSize > max_packet_size - kFramingOverhead) { + QUIC_DLOG(DFATAL) << "Client hello won't fit in a single packet."; + stream_->OnUnrecoverableError(QUIC_INTERNAL_ERROR, "CHLO too large"); + return; + } + next_state_ = STATE_RECV_REJ; + chlo_hash_ = CryptoUtils::HashHandshakeMessage(out, Perspective::IS_CLIENT); + session()->connection()->set_fully_pad_crypto_handshake_packets( + crypto_config_->pad_inchoate_hello()); + SendHandshakeMessage(out, ENCRYPTION_INITIAL); + return; + } + + std::string error_details; + QuicErrorCode error = crypto_config_->FillClientHello( + server_id_, session()->connection()->connection_id(), + session()->supported_versions().front(), + session()->connection()->version(), cached, + session()->connection()->clock()->WallNow(), + session()->connection()->random_generator(), crypto_negotiated_params_, + &out, &error_details); + if (error != QUIC_NO_ERROR) { + // Flush the cached config so that, if it's bad, the server has a + // chance to send us another in the future. + cached->InvalidateServerConfig(); + stream_->OnUnrecoverableError(error, error_details); + return; + } + chlo_hash_ = CryptoUtils::HashHandshakeMessage(out, Perspective::IS_CLIENT); + if (cached->proof_verify_details()) { + proof_handler_->OnProofVerifyDetailsAvailable( + *cached->proof_verify_details()); + } + next_state_ = STATE_RECV_SHLO; + session()->connection()->set_fully_pad_crypto_handshake_packets( + crypto_config_->pad_full_hello()); + SendHandshakeMessage(out, ENCRYPTION_INITIAL); + // Be prepared to decrypt with the new server write key. + delegate_->OnNewEncryptionKeyAvailable( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.encrypter)); + delegate_->OnNewDecryptionKeyAvailable( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.decrypter), + /*set_alternative_decrypter=*/true, + /*latch_once_used=*/true); + encryption_established_ = true; + delegate_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + if (early_data_reason_ == ssl_early_data_unknown && num_client_hellos_ > 1) { + early_data_reason_ = ssl_early_data_peer_declined; + } +} + +void QuicCryptoClientHandshaker::DoReceiveREJ( + const CryptoHandshakeMessage* in, + QuicCryptoClientConfig::CachedState* cached) { + // We sent a dummy CHLO because we didn't have enough information to + // perform a handshake, or we sent a full hello that the server + // rejected. Here we hope to have a REJ that contains the information + // that we need. + if (in->tag() != kREJ) { + next_state_ = STATE_NONE; + stream_->OnUnrecoverableError(QUIC_INVALID_CRYPTO_MESSAGE_TYPE, + "Expected REJ"); + return; + } + + QuicTagVector reject_reasons; + static_assert(sizeof(QuicTag) == sizeof(uint32_t), "header out of sync"); + if (in->GetTaglist(kRREJ, &reject_reasons) == QUIC_NO_ERROR) { + uint32_t packed_error = 0; + for (size_t i = 0; i < reject_reasons.size(); ++i) { + // HANDSHAKE_OK is 0 and don't report that as error. + if (reject_reasons[i] == HANDSHAKE_OK || reject_reasons[i] >= 32) { + continue; + } + HandshakeFailureReason reason = + static_cast(reject_reasons[i]); + packed_error |= 1 << (reason - 1); + } + QUIC_DVLOG(1) << "Reasons for rejection: " << packed_error; + } + + // Receipt of a REJ message means that the server received the CHLO + // so we can cancel and retransmissions. + delegate_->NeuterUnencryptedData(); + + std::string error_details; + QuicErrorCode error = crypto_config_->ProcessRejection( + *in, session()->connection()->clock()->WallNow(), + session()->transport_version(), chlo_hash_, cached, + crypto_negotiated_params_, &error_details); + + if (error != QUIC_NO_ERROR) { + next_state_ = STATE_NONE; + stream_->OnUnrecoverableError(error, error_details); + return; + } + if (!cached->proof_valid()) { + if (!cached->signature().empty()) { + // Note that we only verify the proof if the cached proof is not + // valid. If the cached proof is valid here, someone else must have + // just added the server config to the cache and verified the proof, + // so we can assume no CA trust changes or certificate expiration + // has happened since then. + next_state_ = STATE_VERIFY_PROOF; + return; + } + } + next_state_ = STATE_SEND_CHLO; +} + +QuicAsyncStatus QuicCryptoClientHandshaker::DoVerifyProof( + QuicCryptoClientConfig::CachedState* cached) { + ProofVerifier* verifier = crypto_config_->proof_verifier(); + QUICHE_DCHECK(verifier); + next_state_ = STATE_VERIFY_PROOF_COMPLETE; + generation_counter_ = cached->generation_counter(); + + ProofVerifierCallbackImpl* proof_verify_callback = + new ProofVerifierCallbackImpl(this); + + verify_ok_ = false; + + QuicAsyncStatus status = verifier->VerifyProof( + server_id_.host(), server_id_.port(), cached->server_config(), + session()->transport_version(), chlo_hash_, cached->certs(), + cached->cert_sct(), cached->signature(), verify_context_.get(), + &verify_error_details_, &verify_details_, + std::unique_ptr(proof_verify_callback)); + + switch (status) { + case QUIC_PENDING: + proof_verify_callback_ = proof_verify_callback; + QUIC_DVLOG(1) << "Doing VerifyProof"; + break; + case QUIC_FAILURE: + break; + case QUIC_SUCCESS: + verify_ok_ = true; + break; + } + return status; +} + +void QuicCryptoClientHandshaker::DoVerifyProofComplete( + QuicCryptoClientConfig::CachedState* cached) { + if (proof_verify_start_time_.IsInitialized()) { + QUIC_CLIENT_HISTOGRAM_TIMES( + "QuicSession.VerifyProofTime.CachedServerConfig", + (session()->connection()->clock()->Now() - proof_verify_start_time_), + QuicTime::Delta::FromMilliseconds(1), QuicTime::Delta::FromSeconds(10), + 50, ""); + } + if (!verify_ok_) { + if (verify_details_) { + proof_handler_->OnProofVerifyDetailsAvailable(*verify_details_); + } + if (num_client_hellos_ == 0) { + cached->Clear(); + next_state_ = STATE_INITIALIZE; + return; + } + next_state_ = STATE_NONE; + QUIC_CLIENT_HISTOGRAM_BOOL("QuicVerifyProofFailed.HandshakeConfirmed", + one_rtt_keys_available(), ""); + stream_->OnUnrecoverableError(QUIC_PROOF_INVALID, + "Proof invalid: " + verify_error_details_); + return; + } + + // Check if generation_counter has changed between STATE_VERIFY_PROOF and + // STATE_VERIFY_PROOF_COMPLETE state changes. + if (generation_counter_ != cached->generation_counter()) { + next_state_ = STATE_VERIFY_PROOF; + } else { + SetCachedProofValid(cached); + cached->SetProofVerifyDetails(verify_details_.release()); + if (!one_rtt_keys_available()) { + next_state_ = STATE_SEND_CHLO; + } else { + next_state_ = STATE_NONE; + } + } +} + +void QuicCryptoClientHandshaker::DoReceiveSHLO( + const CryptoHandshakeMessage* in, + QuicCryptoClientConfig::CachedState* cached) { + next_state_ = STATE_NONE; + // We sent a CHLO that we expected to be accepted and now we're + // hoping for a SHLO from the server to confirm that. First check + // to see whether the response was a reject, and if so, move on to + // the reject-processing state. + if (in->tag() == kREJ) { + // A reject message must be sent in ENCRYPTION_INITIAL. + if (session()->connection()->last_decrypted_level() != ENCRYPTION_INITIAL) { + // The rejection was sent encrypted! + stream_->OnUnrecoverableError(QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT, + "encrypted REJ message"); + return; + } + next_state_ = STATE_RECV_REJ; + return; + } + + if (in->tag() != kSHLO) { + stream_->OnUnrecoverableError( + QUIC_INVALID_CRYPTO_MESSAGE_TYPE, + absl::StrCat("Expected SHLO or REJ. Received: ", + QuicTagToString(in->tag()))); + return; + } + + if (session()->connection()->last_decrypted_level() == ENCRYPTION_INITIAL) { + // The server hello was sent without encryption. + stream_->OnUnrecoverableError(QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT, + "unencrypted SHLO message"); + return; + } + if (num_client_hellos_ == 1) { + early_data_reason_ = ssl_early_data_accepted; + } + + std::string error_details; + QuicErrorCode error = crypto_config_->ProcessServerHello( + *in, session()->connection()->connection_id(), + session()->connection()->version(), + session()->connection()->server_supported_versions(), cached, + crypto_negotiated_params_, &error_details); + + if (error != QUIC_NO_ERROR) { + stream_->OnUnrecoverableError(error, + "Server hello invalid: " + error_details); + return; + } + error = session()->config()->ProcessPeerHello(*in, SERVER, &error_details); + if (error != QUIC_NO_ERROR) { + stream_->OnUnrecoverableError(error, + "Server hello invalid: " + error_details); + return; + } + session()->OnConfigNegotiated(); + + CrypterPair* crypters = &crypto_negotiated_params_->forward_secure_crypters; + // TODO(agl): we don't currently latch this decrypter because the idea + // has been floated that the server shouldn't send packets encrypted + // with the FORWARD_SECURE key until it receives a FORWARD_SECURE + // packet from the client. + delegate_->OnNewEncryptionKeyAvailable(ENCRYPTION_FORWARD_SECURE, + std::move(crypters->encrypter)); + delegate_->OnNewDecryptionKeyAvailable(ENCRYPTION_FORWARD_SECURE, + std::move(crypters->decrypter), + /*set_alternative_decrypter=*/true, + /*latch_once_used=*/false); + one_rtt_keys_available_ = true; + delegate_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + delegate_->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); + delegate_->NeuterHandshakeData(); +} + +void QuicCryptoClientHandshaker::DoInitializeServerConfigUpdate( + QuicCryptoClientConfig::CachedState* cached) { + bool update_ignored = false; + if (!cached->IsEmpty() && !cached->signature().empty()) { + // Note that we verify the proof even if the cached proof is valid. + QUICHE_DCHECK(crypto_config_->proof_verifier()); + next_state_ = STATE_VERIFY_PROOF; + } else { + update_ignored = true; + next_state_ = STATE_NONE; + } + QUIC_CLIENT_HISTOGRAM_COUNTS("QuicNumServerConfig.UpdateMessagesIgnored", + update_ignored, 1, 1000000, 50, ""); +} + +void QuicCryptoClientHandshaker::SetCachedProofValid( + QuicCryptoClientConfig::CachedState* cached) { + cached->SetProofValid(); + proof_handler_->OnProofValid(*cached); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_client_handshaker.h b/quiche/quic/core/quic_crypto_client_handshaker.h new file mode 100644 index 000000000000..dbbd942e709f --- /dev/null +++ b/quiche/quic/core/quic_crypto_client_handshaker.h @@ -0,0 +1,213 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CRYPTO_CLIENT_HANDSHAKER_H_ +#define QUICHE_QUIC_CORE_QUIC_CRYPTO_CLIENT_HANDSHAKER_H_ + +#include + +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +// An implementation of QuicCryptoClientStream::HandshakerInterface which uses +// QUIC crypto as the crypto handshake protocol. +class QUIC_EXPORT_PRIVATE QuicCryptoClientHandshaker + : public QuicCryptoClientStream::HandshakerInterface, + public QuicCryptoHandshaker { + public: + QuicCryptoClientHandshaker( + const QuicServerId& server_id, QuicCryptoClientStream* stream, + QuicSession* session, std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, + QuicCryptoClientStream::ProofHandler* proof_handler); + QuicCryptoClientHandshaker(const QuicCryptoClientHandshaker&) = delete; + QuicCryptoClientHandshaker& operator=(const QuicCryptoClientHandshaker&) = + delete; + + ~QuicCryptoClientHandshaker() override; + + // From QuicCryptoClientStream::HandshakerInterface + bool CryptoConnect() override; + int num_sent_client_hellos() const override; + bool IsResumption() const override; + bool EarlyDataAccepted() const override; + ssl_early_data_reason_t EarlyDataReason() const override; + bool ReceivedInchoateReject() const override; + int num_scup_messages_received() const override; + std::string chlo_hash() const override; + bool encryption_established() const override; + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override; + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override; + bool one_rtt_keys_available() const override; + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override; + CryptoMessageParser* crypto_message_parser() override; + HandshakeState GetHandshakeState() const override; + size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) override; + void OnHandshakeDoneReceived() override; + void OnNewTokenReceived(absl::string_view token) override; + void SetServerApplicationStateForResumption( + std::unique_ptr /*application_state*/) override { + QUICHE_NOTREACHED(); + } + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + QUICHE_NOTREACHED(); + return false; + } + + // From QuicCryptoHandshaker + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override; + + protected: + // Returns the QuicSession that this stream belongs to. + QuicSession* session() const { return session_; } + + // Send either InchoateClientHello or ClientHello message to the server. + void DoSendCHLO(QuicCryptoClientConfig::CachedState* cached); + + private: + // ProofVerifierCallbackImpl is passed as the callback method to VerifyProof. + // The ProofVerifier calls this class with the result of proof verification + // when verification is performed asynchronously. + class QUIC_EXPORT_PRIVATE ProofVerifierCallbackImpl + : public ProofVerifierCallback { + public: + explicit ProofVerifierCallbackImpl(QuicCryptoClientHandshaker* parent); + ~ProofVerifierCallbackImpl() override; + + // ProofVerifierCallback interface. + void Run(bool ok, const std::string& error_details, + std::unique_ptr* details) override; + + // Cancel causes any future callbacks to be ignored. It must be called on + // the same thread as the callback will be made on. + void Cancel(); + + private: + QuicCryptoClientHandshaker* parent_; + }; + + enum State { + STATE_IDLE, + STATE_INITIALIZE, + STATE_SEND_CHLO, + STATE_RECV_REJ, + STATE_VERIFY_PROOF, + STATE_VERIFY_PROOF_COMPLETE, + STATE_RECV_SHLO, + STATE_INITIALIZE_SCUP, + STATE_NONE, + STATE_CONNECTION_CLOSED, + }; + + // Handles new server config and optional source-address token provided by the + // server during a connection. + void HandleServerConfigUpdateMessage( + const CryptoHandshakeMessage& server_config_update); + + // DoHandshakeLoop performs a step of the handshake state machine. Note that + // |in| may be nullptr if the call did not result from a received message. + void DoHandshakeLoop(const CryptoHandshakeMessage* in); + + // Start the handshake process. + void DoInitialize(QuicCryptoClientConfig::CachedState* cached); + + // Process REJ message from the server. + void DoReceiveREJ(const CryptoHandshakeMessage* in, + QuicCryptoClientConfig::CachedState* cached); + + // Start the proof verification process. Returns the QuicAsyncStatus returned + // by the ProofVerifier's VerifyProof. + QuicAsyncStatus DoVerifyProof(QuicCryptoClientConfig::CachedState* cached); + + // If proof is valid then it sets the proof as valid (which persists the + // server config). If not, it closes the connection. + void DoVerifyProofComplete(QuicCryptoClientConfig::CachedState* cached); + + // Process SHLO message from the server. + void DoReceiveSHLO(const CryptoHandshakeMessage* in, + QuicCryptoClientConfig::CachedState* cached); + + // Start the proof verification if |server_id_| is https and |cached| has + // signature. + void DoInitializeServerConfigUpdate( + QuicCryptoClientConfig::CachedState* cached); + + // Called to set the proof of |cached| valid. Also invokes the session's + // OnProofValid() method. + void SetCachedProofValid(QuicCryptoClientConfig::CachedState* cached); + + QuicCryptoClientStream* stream_; + + QuicSession* session_; + HandshakerDelegateInterface* delegate_; + + State next_state_; + // num_client_hellos_ contains the number of client hello messages that this + // connection has sent. + int num_client_hellos_; + + ssl_early_data_reason_t early_data_reason_ = ssl_early_data_unknown; + + QuicCryptoClientConfig* const crypto_config_; + + // SHA-256 hash of the most recently sent CHLO. + std::string chlo_hash_; + + // Server's (hostname, port, is_https, privacy_mode) tuple. + const QuicServerId server_id_; + + // Generation counter from QuicCryptoClientConfig's CachedState. + uint64_t generation_counter_; + + // verify_context_ contains the context object that we pass to asynchronous + // proof verifications. + std::unique_ptr verify_context_; + + // proof_verify_callback_ contains the callback object that we passed to an + // asynchronous proof verification. The ProofVerifier owns this object. + ProofVerifierCallbackImpl* proof_verify_callback_; + // proof_handler_ contains the callback object used by a quic client + // for proof verification. It is not owned by this class. + QuicCryptoClientStream::ProofHandler* proof_handler_; + + // These members are used to store the result of an asynchronous proof + // verification. These members must not be used after + // STATE_VERIFY_PROOF_COMPLETE. + bool verify_ok_; + std::string verify_error_details_; + std::unique_ptr verify_details_; + + QuicTime proof_verify_start_time_; + + int num_scup_messages_received_; + + bool encryption_established_; + bool one_rtt_keys_available_; + quiche::QuicheReferenceCountedPointer + crypto_negotiated_params_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CRYPTO_CLIENT_HANDSHAKER_H_ diff --git a/quiche/quic/core/quic_crypto_client_handshaker_test.cc b/quiche/quic/core/quic_crypto_client_handshaker_test.cc new file mode 100644 index 000000000000..b31ec18b9813 --- /dev/null +++ b/quiche/quic/core/quic_crypto_client_handshaker_test.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_client_handshaker.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic::test { +namespace { + +class TestProofHandler : public QuicCryptoClientStream::ProofHandler { + public: + ~TestProofHandler() override {} + void OnProofValid( + const QuicCryptoClientConfig::CachedState& /*cached*/) override {} + void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& /*verify_details*/) override {} +}; + +class InsecureProofVerifier : public ProofVerifier { + public: + InsecureProofVerifier() {} + ~InsecureProofVerifier() override {} + + // ProofVerifier override. + QuicAsyncStatus VerifyProof( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::string& /*server_config*/, + QuicTransportVersion /*transport_version*/, + absl::string_view /*chlo_hash*/, + const std::vector& /*certs*/, + const std::string& /*cert_sct*/, const std::string& /*signature*/, + const ProofVerifyContext* /*context*/, std::string* /*error_details*/, + std::unique_ptr* /*verify_details*/, + std::unique_ptr /*callback*/) override { + return QUIC_SUCCESS; + } + + QuicAsyncStatus VerifyCertChain( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::vector& /*certs*/, + const std::string& /*ocsp_response*/, const std::string& /*cert_sct*/, + const ProofVerifyContext* /*context*/, std::string* /*error_details*/, + std::unique_ptr* /*details*/, uint8_t* /*out_alert*/, + std::unique_ptr /*callback*/) override { + return QUIC_SUCCESS; + } + + std::unique_ptr CreateDefaultContext() override { + return nullptr; + } +}; + +class DummyProofSource : public ProofSource { + public: + DummyProofSource() {} + ~DummyProofSource() override {} + + // ProofSource override. + void GetProof(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + const std::string& /*server_config*/, + QuicTransportVersion /*transport_version*/, + absl::string_view /*chlo_hash*/, + std::unique_ptr callback) override { + bool cert_matched_sni; + quiche::QuicheReferenceCountedPointer chain = + GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); + QuicCryptoProof proof; + proof.signature = "Dummy signature"; + proof.leaf_cert_scts = "Dummy timestamp"; + proof.cert_matched_sni = cert_matched_sni; + callback->Run(true, chain, proof, /*details=*/nullptr); + } + + quiche::QuicheReferenceCountedPointer GetCertChain( + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + const std::string& /*hostname*/, bool* /*cert_matched_sni*/) override { + std::vector certs; + certs.push_back("Dummy cert"); + return quiche::QuicheReferenceCountedPointer( + new ProofSource::Chain(certs)); + } + + void ComputeTlsSignature( + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + const std::string& /*hostname*/, uint16_t /*signature_algorit*/, + absl::string_view /*in*/, + std::unique_ptr callback) override { + callback->Run(true, "Dummy signature", /*details=*/nullptr); + } + + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override { + return {}; + } + + TicketCrypter* GetTicketCrypter() override { return nullptr; } +}; + +class Handshaker : public QuicCryptoClientHandshaker { + public: + Handshaker(const QuicServerId& server_id, QuicCryptoClientStream* stream, + QuicSession* session, + std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, + QuicCryptoClientStream::ProofHandler* proof_handler) + : QuicCryptoClientHandshaker(server_id, stream, session, + std::move(verify_context), crypto_config, + proof_handler) {} + + void DoSendCHLOTest(QuicCryptoClientConfig::CachedState* cached) { + QuicCryptoClientHandshaker::DoSendCHLO(cached); + } +}; + +class QuicCryptoClientHandshakerTest + : public QuicTestWithParam { + protected: + QuicCryptoClientHandshakerTest() + : version_(GetParam()), + proof_handler_(), + helper_(), + alarm_factory_(), + server_id_("host", 123), + connection_(new test::MockQuicConnection( + &helper_, &alarm_factory_, Perspective::IS_CLIENT, {version_})), + session_(connection_, false), + crypto_client_config_(std::make_unique()), + client_stream_( + new QuicCryptoClientStream(server_id_, &session_, nullptr, + &crypto_client_config_, &proof_handler_, + /*has_application_state = */ false)), + handshaker_(server_id_, client_stream_, &session_, nullptr, + &crypto_client_config_, &proof_handler_), + state_() { + // Session takes the ownership of the client stream! (but handshaker also + // takes a reference to it, but doesn't take the ownership). + session_.SetCryptoStream(client_stream_); + session_.Initialize(); + } + + void InitializeServerParametersToEnableFullHello() { + QuicCryptoServerConfig::ConfigOptions options; + QuicServerConfigProtobuf config = QuicCryptoServerConfig::GenerateConfig( + helper_.GetRandomGenerator(), helper_.GetClock(), options); + state_.Initialize( + config.config(), "sourcetoken", std::vector{"Dummy cert"}, + "", "chlo_hash", "signature", helper_.GetClock()->WallNow(), + helper_.GetClock()->WallNow().Add(QuicTime::Delta::FromSeconds(30))); + + state_.SetProofValid(); + } + + ParsedQuicVersion version_; + TestProofHandler proof_handler_; + test::MockQuicConnectionHelper helper_; + test::MockAlarmFactory alarm_factory_; + QuicServerId server_id_; + // Session takes the ownership of the connection. + test::MockQuicConnection* connection_; + test::MockQuicSession session_; + QuicCryptoClientConfig crypto_client_config_; + QuicCryptoClientStream* client_stream_; + Handshaker handshaker_; + QuicCryptoClientConfig::CachedState state_; +}; + +INSTANTIATE_TEST_SUITE_P( + QuicCryptoClientHandshakerTests, QuicCryptoClientHandshakerTest, + ::testing::ValuesIn(AllSupportedVersionsWithQuicCrypto()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicCryptoClientHandshakerTest, TestSendFullPaddingInInchoateHello) { + handshaker_.DoSendCHLOTest(&state_); + + EXPECT_TRUE(connection_->fully_pad_during_crypto_handshake()); +} + +TEST_P(QuicCryptoClientHandshakerTest, TestDisabledPaddingInInchoateHello) { + crypto_client_config_.set_pad_inchoate_hello(false); + handshaker_.DoSendCHLOTest(&state_); + EXPECT_FALSE(connection_->fully_pad_during_crypto_handshake()); +} + +TEST_P(QuicCryptoClientHandshakerTest, + TestPaddingInFullHelloEvenIfInchoateDisabled) { + // Disable inchoate, but full hello should still be padded. + crypto_client_config_.set_pad_inchoate_hello(false); + + InitializeServerParametersToEnableFullHello(); + + handshaker_.DoSendCHLOTest(&state_); + EXPECT_TRUE(connection_->fully_pad_during_crypto_handshake()); +} + +TEST_P(QuicCryptoClientHandshakerTest, TestNoPaddingInFullHelloWhenDisabled) { + crypto_client_config_.set_pad_full_hello(false); + + InitializeServerParametersToEnableFullHello(); + + handshaker_.DoSendCHLOTest(&state_); + EXPECT_FALSE(connection_->fully_pad_during_crypto_handshake()); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/quic_crypto_client_stream.cc b/quiche/quic/core/quic_crypto_client_stream.cc new file mode 100644 index 000000000000..3c6dbe767e3b --- /dev/null +++ b/quiche/quic/core/quic_crypto_client_stream.cc @@ -0,0 +1,178 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_client_stream.h" + +#include +#include +#include + +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/quic_crypto_client_handshaker.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/tls_client_handshaker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +const int QuicCryptoClientStream::kMaxClientHellos; + +QuicCryptoClientStreamBase::QuicCryptoClientStreamBase(QuicSession* session) + : QuicCryptoStream(session) {} + +QuicCryptoClientStream::QuicCryptoClientStream( + const QuicServerId& server_id, QuicSession* session, + std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, ProofHandler* proof_handler, + bool has_application_state) + : QuicCryptoClientStreamBase(session) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, + session->connection()->perspective()); + switch (session->connection()->version().handshake_protocol) { + case PROTOCOL_QUIC_CRYPTO: + handshaker_ = std::make_unique( + server_id, this, session, std::move(verify_context), crypto_config, + proof_handler); + break; + case PROTOCOL_TLS1_3: { + auto handshaker = std::make_unique( + server_id, this, session, std::move(verify_context), crypto_config, + proof_handler, has_application_state); + tls_handshaker_ = handshaker.get(); + handshaker_ = std::move(handshaker); + break; + } + case PROTOCOL_UNSUPPORTED: + QUIC_BUG(quic_bug_10296_1) + << "Attempting to create QuicCryptoClientStream for unknown " + "handshake protocol"; + } +} + +QuicCryptoClientStream::~QuicCryptoClientStream() {} + +bool QuicCryptoClientStream::CryptoConnect() { + return handshaker_->CryptoConnect(); +} + +int QuicCryptoClientStream::num_sent_client_hellos() const { + return handshaker_->num_sent_client_hellos(); +} + +bool QuicCryptoClientStream::IsResumption() const { + return handshaker_->IsResumption(); +} + +bool QuicCryptoClientStream::EarlyDataAccepted() const { + return handshaker_->EarlyDataAccepted(); +} + +ssl_early_data_reason_t QuicCryptoClientStream::EarlyDataReason() const { + return handshaker_->EarlyDataReason(); +} + +bool QuicCryptoClientStream::ReceivedInchoateReject() const { + return handshaker_->ReceivedInchoateReject(); +} + +int QuicCryptoClientStream::num_scup_messages_received() const { + return handshaker_->num_scup_messages_received(); +} + +bool QuicCryptoClientStream::encryption_established() const { + return handshaker_->encryption_established(); +} + +bool QuicCryptoClientStream::one_rtt_keys_available() const { + return handshaker_->one_rtt_keys_available(); +} + +const QuicCryptoNegotiatedParameters& +QuicCryptoClientStream::crypto_negotiated_params() const { + return handshaker_->crypto_negotiated_params(); +} + +CryptoMessageParser* QuicCryptoClientStream::crypto_message_parser() { + return handshaker_->crypto_message_parser(); +} + +HandshakeState QuicCryptoClientStream::GetHandshakeState() const { + return handshaker_->GetHandshakeState(); +} + +size_t QuicCryptoClientStream::BufferSizeLimitForLevel( + EncryptionLevel level) const { + return handshaker_->BufferSizeLimitForLevel(level); +} + +std::unique_ptr +QuicCryptoClientStream::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return handshaker_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr +QuicCryptoClientStream::CreateCurrentOneRttEncrypter() { + return handshaker_->CreateCurrentOneRttEncrypter(); +} + +bool QuicCryptoClientStream::ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + return handshaker_->ExportKeyingMaterial(label, context, result_len, result); +} + +std::string QuicCryptoClientStream::chlo_hash() const { + return handshaker_->chlo_hash(); +} + +void QuicCryptoClientStream::OnOneRttPacketAcknowledged() { + handshaker_->OnOneRttPacketAcknowledged(); +} + +void QuicCryptoClientStream::OnHandshakePacketSent() { + handshaker_->OnHandshakePacketSent(); +} + +void QuicCryptoClientStream::OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) { + handshaker_->OnConnectionClosed(error, source); +} + +void QuicCryptoClientStream::OnHandshakeDoneReceived() { + handshaker_->OnHandshakeDoneReceived(); +} + +void QuicCryptoClientStream::OnNewTokenReceived(absl::string_view token) { + handshaker_->OnNewTokenReceived(token); +} + +void QuicCryptoClientStream::SetServerApplicationStateForResumption( + std::unique_ptr application_state) { + handshaker_->SetServerApplicationStateForResumption( + std::move(application_state)); +} + +SSL* QuicCryptoClientStream::GetSsl() const { + return tls_handshaker_ == nullptr ? nullptr : tls_handshaker_->ssl(); +} + +bool QuicCryptoClientStream::IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const { + return handshaker_->IsCryptoFrameExpectedForEncryptionLevel(level); +} + +EncryptionLevel +QuicCryptoClientStream::GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const { + return handshaker_->GetEncryptionLevelToSendCryptoDataOfSpace(space); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_client_stream.h b/quiche/quic/core/quic_crypto_client_stream.h new file mode 100644 index 000000000000..0a3e5e4df573 --- /dev/null +++ b/quiche/quic/core/quic_crypto_client_stream.h @@ -0,0 +1,320 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CRYPTO_CLIENT_STREAM_H_ +#define QUICHE_QUIC_CORE_QUIC_CRYPTO_CLIENT_STREAM_H_ + +#include +#include +#include + +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_crypto_handshaker.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicCryptoClientStreamPeer; +} // namespace test + +class TlsClientHandshaker; + +class QUIC_EXPORT_PRIVATE QuicCryptoClientStreamBase : public QuicCryptoStream { + public: + explicit QuicCryptoClientStreamBase(QuicSession* session); + + ~QuicCryptoClientStreamBase() override {} + + // Performs a crypto handshake with the server. Returns true if the connection + // is still connected. + virtual bool CryptoConnect() = 0; + + // DEPRECATED: Use IsResumption, EarlyDataAccepted, and/or + // ReceivedInchoateReject instead. + // + // num_sent_client_hellos returns the number of client hello messages that + // have been sent. If the handshake has completed then this is one greater + // than the number of round-trips needed for the handshake. + virtual int num_sent_client_hellos() const = 0; + + // Returns true if the handshake performed was a resumption instead of a full + // handshake. Resumption only makes sense for TLS handshakes - there is no + // concept of resumption for QUIC crypto even though it supports a 0-RTT + // handshake. This function only returns valid results once the handshake is + // complete. + virtual bool IsResumption() const = 0; + + // Returns true if early data (0-RTT) was accepted in the connection. + virtual bool EarlyDataAccepted() const = 0; + + // Returns true if the client received an inchoate REJ during the handshake, + // extending the handshake by one round trip. This only applies for QUIC + // crypto handshakes. The equivalent feature in IETF QUIC is a Retry packet, + // but that is handled at the connection layer instead of the crypto layer. + virtual bool ReceivedInchoateReject() const = 0; + + // The number of server config update messages received by the + // client. Does not count update messages that were received prior + // to handshake confirmation. + virtual int num_scup_messages_received() const = 0; + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + QUICHE_NOTREACHED(); + return false; + } + + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_params*/) const override { + QUICHE_DCHECK(false); + return ""; + } + + bool ValidateAddressToken(absl::string_view /*token*/) const override { + QUICHE_DCHECK(false); + return false; + } + + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + QUICHE_DCHECK(false); + return nullptr; + } + + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override { + QUICHE_DCHECK(false); + } +}; + +class QUIC_EXPORT_PRIVATE QuicCryptoClientStream + : public QuicCryptoClientStreamBase { + public: + // kMaxClientHellos is the maximum number of times that we'll send a client + // hello. The value 4 accounts for: + // * One failure due to an incorrect or missing source-address token. + // * One failure due the server's certificate chain being unavailible and + // the server being unwilling to send it without a valid source-address + // token. + // * One failure due to the ServerConfig private key being located on a + // remote oracle which has become unavailable, forcing the server to send + // the client a fallback ServerConfig. + static const int kMaxClientHellos = 4; + + // QuicCryptoClientStream creates a HandshakerInterface at construction time + // based on the QuicTransportVersion of the connection. Different + // HandshakerInterfaces provide implementations of different crypto handshake + // protocols. Currently QUIC crypto is the only protocol implemented; a future + // HandshakerInterface will use TLS as the handshake protocol. + // QuicCryptoClientStream delegates all of its public methods to its + // HandshakerInterface. + // + // This setup of the crypto stream delegating its implementation to the + // handshaker results in the handshaker reading and writing bytes on the + // crypto stream, instead of the handshaker passing the stream bytes to send. + class QUIC_EXPORT_PRIVATE HandshakerInterface { + public: + virtual ~HandshakerInterface() {} + + // Performs a crypto handshake with the server. Returns true if the + // connection is still connected. + virtual bool CryptoConnect() = 0; + + // DEPRECATED: Use IsResumption, EarlyDataAccepted, and/or + // ReceivedInchoateReject instead. + // + // num_sent_client_hellos returns the number of client hello messages that + // have been sent. If the handshake has completed then this is one greater + // than the number of round-trips needed for the handshake. + virtual int num_sent_client_hellos() const = 0; + + // Returns true if the handshake performed was a resumption instead of a + // full handshake. Resumption only makes sense for TLS handshakes - there is + // no concept of resumption for QUIC crypto even though it supports a 0-RTT + // handshake. This function only returns valid results once the handshake is + // complete. + virtual bool IsResumption() const = 0; + + // Returns true if early data (0-RTT) was accepted in the connection. + virtual bool EarlyDataAccepted() const = 0; + + // Returns the ssl_early_data_reason_t describing why 0-RTT was accepted or + // rejected. + virtual ssl_early_data_reason_t EarlyDataReason() const = 0; + + // Returns true if the client received an inchoate REJ during the handshake, + // extending the handshake by one round trip. This only applies for QUIC + // crypto handshakes. The equivalent feature in IETF QUIC is a Retry packet, + // but that is handled at the connection layer instead of the crypto layer. + virtual bool ReceivedInchoateReject() const = 0; + + // The number of server config update messages received by the + // client. Does not count update messages that were received prior + // to handshake confirmation. + virtual int num_scup_messages_received() const = 0; + + virtual std::string chlo_hash() const = 0; + + // Returns true once any encrypter (initial/0RTT or final/1RTT) has been set + // for the connection. + virtual bool encryption_established() const = 0; + + // Returns true if receiving CRYPTO_FRAME at encryption `level` is expected. + virtual bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const = 0; + + // Returns the encryption level to send CRYPTO_FRAME for `space`. + virtual EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const = 0; + + // Returns true once 1RTT keys are available. + virtual bool one_rtt_keys_available() const = 0; + + // Returns the parameters negotiated in the crypto handshake. + virtual const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const = 0; + + // Used by QuicCryptoStream to parse data received on this stream. + virtual CryptoMessageParser* crypto_message_parser() = 0; + + // Used by QuicCryptoStream to know how much unprocessed data can be + // buffered at each encryption level. + virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const = 0; + + // Called to generate a decrypter for the next key phase. Each call should + // generate the key for phase n+1. + virtual std::unique_ptr + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called to generate an encrypter for the same key phase of the last + // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr CreateCurrentOneRttEncrypter() = 0; + + // Returns current handshake state. + virtual HandshakeState GetHandshakeState() const = 0; + + // Called when a 1RTT packet has been acknowledged. + virtual void OnOneRttPacketAcknowledged() = 0; + + // Called when a packet of ENCRYPTION_HANDSHAKE gets sent. + virtual void OnHandshakePacketSent() = 0; + + // Called when connection gets closed. + virtual void OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) = 0; + + // Called when handshake done has been received. + virtual void OnHandshakeDoneReceived() = 0; + + // Called when new token has been received. + virtual void OnNewTokenReceived(absl::string_view token) = 0; + + // Called when application state is received. + virtual void SetServerApplicationStateForResumption( + std::unique_ptr application_state) = 0; + + // Called to obtain keying material export of length |result_len| with the + // given |label| and |context|. Returns false on failure. + virtual bool ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) = 0; + }; + + // ProofHandler is an interface that handles callbacks from the crypto + // stream when the client has proof verification details of the server. + class QUIC_EXPORT_PRIVATE ProofHandler { + public: + virtual ~ProofHandler() {} + + // Called when the proof in |cached| is marked valid. If this is a secure + // QUIC session, then this will happen only after the proof verifier + // completes. + virtual void OnProofValid( + const QuicCryptoClientConfig::CachedState& cached) = 0; + + // Called when proof verification details become available, either because + // proof verification is complete, or when cached details are used. This + // will only be called for secure QUIC connections. + virtual void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) = 0; + }; + + QuicCryptoClientStream(const QuicServerId& server_id, QuicSession* session, + std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, + ProofHandler* proof_handler, + bool has_application_state); + QuicCryptoClientStream(const QuicCryptoClientStream&) = delete; + QuicCryptoClientStream& operator=(const QuicCryptoClientStream&) = delete; + + ~QuicCryptoClientStream() override; + + // From QuicCryptoClientStreamBase + bool CryptoConnect() override; + int num_sent_client_hellos() const override; + bool IsResumption() const override; + bool EarlyDataAccepted() const override; + ssl_early_data_reason_t EarlyDataReason() const override; + bool ReceivedInchoateReject() const override; + + int num_scup_messages_received() const override; + + // From QuicCryptoStream + bool encryption_established() const override; + bool one_rtt_keys_available() const override; + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override; + CryptoMessageParser* crypto_message_parser() override; + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override; + void OnHandshakePacketSent() override; + void OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) override; + void OnHandshakeDoneReceived() override; + void OnNewTokenReceived(absl::string_view token) override; + HandshakeState GetHandshakeState() const override; + void SetServerApplicationStateForResumption( + std::unique_ptr application_state) override; + size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + SSL* GetSsl() const override; + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override; + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override; + + bool ExportKeyingMaterial(absl::string_view label, absl::string_view context, + size_t result_len, std::string* result) override; + std::string chlo_hash() const; + + protected: + void set_handshaker(std::unique_ptr handshaker) { + handshaker_ = std::move(handshaker); + } + + private: + friend class test::QuicCryptoClientStreamPeer; + std::unique_ptr handshaker_; + // Points to |handshaker_| if it uses TLS1.3. Otherwise, nullptr. + // TODO(danzh) change the type of |handshaker_| to TlsClientHandshaker after + // deprecating Google QUIC. + TlsClientHandshaker* tls_handshaker_{nullptr}; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CRYPTO_CLIENT_STREAM_H_ diff --git a/quiche/quic/core/quic_crypto_client_stream_test.cc b/quiche/quic/core/quic_crypto_client_stream_test.cc new file mode 100644 index 000000000000..377f7a296e04 --- /dev/null +++ b/quiche/quic/core/quic_crypto_client_stream_test.cc @@ -0,0 +1,371 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_client_stream.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_sequencer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_quic_framer.h" +#include "quiche/quic/test_tools/simple_session_cache.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +using testing::_; + +namespace quic { +namespace test { +namespace { + +const char kServerHostname[] = "test.example.com"; +const uint16_t kServerPort = 443; + +// This test tests the client-side of the QUIC crypto handshake. It does not +// test the TLS handshake - that is in tls_client_handshaker_test.cc. +class QuicCryptoClientStreamTest : public QuicTest { + public: + QuicCryptoClientStreamTest() + : supported_versions_(AllSupportedVersionsWithQuicCrypto()), + server_id_(kServerHostname, kServerPort, false), + crypto_config_(crypto_test_utils::ProofVerifierForTesting(), + std::make_unique()), + server_crypto_config_( + crypto_test_utils::CryptoServerConfigForTesting()) { + CreateConnection(); + } + + void CreateSession() { + session_ = std::make_unique( + connection_, DefaultQuicConfig(), supported_versions_, server_id_, + &crypto_config_); + EXPECT_CALL(*session_, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector( + {AlpnForVersion(connection_->version())}))); + } + + void CreateConnection() { + connection_ = + new PacketSavingConnection(&client_helper_, &alarm_factory_, + Perspective::IS_CLIENT, supported_versions_); + // Advance the time, because timers do not like uninitialized times. + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + CreateSession(); + } + + void CompleteCryptoHandshake() { + int proof_verify_details_calls = 1; + if (stream()->handshake_protocol() != PROTOCOL_TLS1_3) { + EXPECT_CALL(*session_, OnProofValid(testing::_)) + .Times(testing::AtLeast(1)); + proof_verify_details_calls = 0; + } + EXPECT_CALL(*session_, OnProofVerifyDetailsAvailable(testing::_)) + .Times(testing::AtLeast(proof_verify_details_calls)); + stream()->CryptoConnect(); + QuicConfig config; + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &server_helper_, &alarm_factory_, + connection_, stream(), AlpnForVersion(connection_->version())); + } + + QuicCryptoClientStream* stream() { + return session_->GetMutableCryptoStream(); + } + + MockQuicConnectionHelper server_helper_; + MockQuicConnectionHelper client_helper_; + MockAlarmFactory alarm_factory_; + PacketSavingConnection* connection_; + ParsedQuicVersionVector supported_versions_; + std::unique_ptr session_; + QuicServerId server_id_; + CryptoHandshakeMessage message_; + QuicCryptoClientConfig crypto_config_; + std::unique_ptr server_crypto_config_; +}; + +TEST_F(QuicCryptoClientStreamTest, NotInitiallyConected) { + EXPECT_FALSE(stream()->encryption_established()); + EXPECT_FALSE(stream()->one_rtt_keys_available()); +} + +TEST_F(QuicCryptoClientStreamTest, ConnectedAfterSHLO) { + CompleteCryptoHandshake(); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_no_session_offered); +} + +TEST_F(QuicCryptoClientStreamTest, MessageAfterHandshake) { + CompleteCryptoHandshake(); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE, _, _)); + message_.set_tag(kCHLO); + crypto_test_utils::SendHandshakeMessageToStream(stream(), message_, + Perspective::IS_CLIENT); +} + +TEST_F(QuicCryptoClientStreamTest, BadMessageType) { + stream()->CryptoConnect(); + + message_.set_tag(kCHLO); + + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_CRYPTO_MESSAGE_TYPE, + "Expected REJ", _)); + crypto_test_utils::SendHandshakeMessageToStream(stream(), message_, + Perspective::IS_CLIENT); +} + +TEST_F(QuicCryptoClientStreamTest, NegotiatedParameters) { + CompleteCryptoHandshake(); + + const QuicConfig* config = session_->config(); + EXPECT_EQ(kMaximumIdleTimeoutSecs, config->IdleNetworkTimeout().ToSeconds()); + + const QuicCryptoNegotiatedParameters& crypto_params( + stream()->crypto_negotiated_params()); + EXPECT_EQ(crypto_config_.aead[0], crypto_params.aead); + EXPECT_EQ(crypto_config_.kexs[0], crypto_params.key_exchange); +} + +TEST_F(QuicCryptoClientStreamTest, ExpiredServerConfig) { + // Seed the config with a cached server config. + CompleteCryptoHandshake(); + + // Recreate connection with the new config. + CreateConnection(); + + // Advance time 5 years to ensure that we pass the expiry time of the cached + // server config. + connection_->AdvanceTime( + QuicTime::Delta::FromSeconds(60 * 60 * 24 * 365 * 5)); + + EXPECT_CALL(*session_, OnProofValid(testing::_)); + stream()->CryptoConnect(); + // Check that a client hello was sent. + ASSERT_EQ(1u, connection_->encrypted_packets_.size()); + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); +} + +TEST_F(QuicCryptoClientStreamTest, ClientTurnedOffZeroRtt) { + // Seed the config with a cached server config. + CompleteCryptoHandshake(); + + // Recreate connection with the new config. + CreateConnection(); + + // Set connection option. + QuicTagVector options; + options.push_back(kQNZ2); + session_->config()->SetClientConnectionOptions(options); + + CompleteCryptoHandshake(); + // Check that two client hellos were sent, one inchoate and one normal. + EXPECT_EQ(2, stream()->num_sent_client_hellos()); + EXPECT_FALSE(stream()->EarlyDataAccepted()); + EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_disabled); +} + +TEST_F(QuicCryptoClientStreamTest, ClockSkew) { + // Test that if the client's clock is skewed with respect to the server, + // the handshake succeeds. In the past, the client would get the server + // config, notice that it had already expired and then close the connection. + + // Advance time 5 years to ensure that we pass the expiry time in the server + // config, but the TTL is used instead. + connection_->AdvanceTime( + QuicTime::Delta::FromSeconds(60 * 60 * 24 * 365 * 5)); + + // The handshakes completes! + CompleteCryptoHandshake(); +} + +TEST_F(QuicCryptoClientStreamTest, InvalidCachedServerConfig) { + // Seed the config with a cached server config. + CompleteCryptoHandshake(); + + // Recreate connection with the new config. + CreateConnection(); + + QuicCryptoClientConfig::CachedState* state = + crypto_config_.LookupOrCreate(server_id_); + + std::vector certs = state->certs(); + std::string cert_sct = state->cert_sct(); + std::string signature = state->signature(); + std::string chlo_hash = state->chlo_hash(); + state->SetProof(certs, cert_sct, chlo_hash, signature + signature); + + EXPECT_CALL(*session_, OnProofVerifyDetailsAvailable(testing::_)) + .Times(testing::AnyNumber()); + stream()->CryptoConnect(); + // Check that a client hello was sent. + ASSERT_EQ(1u, connection_->encrypted_packets_.size()); +} + +TEST_F(QuicCryptoClientStreamTest, ServerConfigUpdate) { + // Test that the crypto client stream can receive server config updates after + // the connection has been established. + CompleteCryptoHandshake(); + + QuicCryptoClientConfig::CachedState* state = + crypto_config_.LookupOrCreate(server_id_); + + // Ensure cached STK is different to what we send in the handshake. + EXPECT_NE("xstk", state->source_address_token()); + + // Initialize using {...} syntax to avoid trailing \0 if converting from + // string. + unsigned char stk[] = {'x', 's', 't', 'k'}; + + // Minimum SCFG that passes config validation checks. + unsigned char scfg[] = {// SCFG + 0x53, 0x43, 0x46, 0x47, + // num entries + 0x01, 0x00, + // padding + 0x00, 0x00, + // EXPY + 0x45, 0x58, 0x50, 0x59, + // EXPY end offset + 0x08, 0x00, 0x00, 0x00, + // Value + '1', '2', '3', '4', '5', '6', '7', '8'}; + + CryptoHandshakeMessage server_config_update; + server_config_update.set_tag(kSCUP); + server_config_update.SetValue(kSourceAddressTokenTag, stk); + server_config_update.SetValue(kSCFG, scfg); + const uint64_t expiry_seconds = 60 * 60 * 24 * 2; + server_config_update.SetValue(kSTTL, expiry_seconds); + + crypto_test_utils::SendHandshakeMessageToStream( + stream(), server_config_update, Perspective::IS_SERVER); + + // Make sure that the STK and SCFG are cached correctly. + EXPECT_EQ("xstk", state->source_address_token()); + + const std::string& cached_scfg = state->server_config(); + quiche::test::CompareCharArraysWithHexError( + "scfg", cached_scfg.data(), cached_scfg.length(), + reinterpret_cast(scfg), ABSL_ARRAYSIZE(scfg)); + + QuicStreamSequencer* sequencer = QuicStreamPeer::sequencer(stream()); + EXPECT_FALSE(QuicStreamSequencerPeer::IsUnderlyingBufferAllocated(sequencer)); +} + +TEST_F(QuicCryptoClientStreamTest, ServerConfigUpdateWithCert) { + // Test that the crypto client stream can receive and use server config + // updates with certificates after the connection has been established. + CompleteCryptoHandshake(); + + // Build a server config update message with certificates + QuicCryptoServerConfig crypto_config( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + crypto_test_utils::ProofSourceForTesting(), KeyExchangeSource::Default()); + crypto_test_utils::SetupCryptoServerConfigForTest( + connection_->clock(), QuicRandom::GetInstance(), &crypto_config); + SourceAddressTokens tokens; + QuicCompressedCertsCache cache(1); + CachedNetworkParameters network_params; + CryptoHandshakeMessage server_config_update; + + class Callback : public BuildServerConfigUpdateMessageResultCallback { + public: + Callback(bool* ok, CryptoHandshakeMessage* message) + : ok_(ok), message_(message) {} + void Run(bool ok, const CryptoHandshakeMessage& message) override { + *ok_ = ok; + *message_ = message; + } + + private: + bool* ok_; + CryptoHandshakeMessage* message_; + }; + + // Note: relies on the callback being invoked synchronously + bool ok = false; + crypto_config.BuildServerConfigUpdateMessage( + session_->transport_version(), stream()->chlo_hash(), tokens, + QuicSocketAddress(QuicIpAddress::Loopback6(), 1234), + QuicSocketAddress(QuicIpAddress::Loopback6(), 4321), connection_->clock(), + QuicRandom::GetInstance(), &cache, stream()->crypto_negotiated_params(), + &network_params, + std::unique_ptr( + new Callback(&ok, &server_config_update))); + EXPECT_TRUE(ok); + + EXPECT_CALL(*session_, OnProofValid(testing::_)); + crypto_test_utils::SendHandshakeMessageToStream( + stream(), server_config_update, Perspective::IS_SERVER); + + // Recreate connection with the new config and verify a 0-RTT attempt. + CreateConnection(); + + EXPECT_CALL(*session_, OnProofValid(testing::_)); + EXPECT_CALL(*session_, OnProofVerifyDetailsAvailable(testing::_)) + .Times(testing::AnyNumber()); + stream()->CryptoConnect(); + EXPECT_TRUE(session_->IsEncryptionEstablished()); +} + +TEST_F(QuicCryptoClientStreamTest, ServerConfigUpdateBeforeHandshake) { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_CRYPTO_UPDATE_BEFORE_HANDSHAKE_COMPLETE, _, _)); + CryptoHandshakeMessage server_config_update; + server_config_update.set_tag(kSCUP); + crypto_test_utils::SendHandshakeMessageToStream( + stream(), server_config_update, Perspective::IS_SERVER); +} + +TEST_F(QuicCryptoClientStreamTest, PreferredVersion) { + // This mimics the case where client receives version negotiation packet, such + // that, the preferred version is different from the packets' version. + connection_ = new PacketSavingConnection( + &client_helper_, &alarm_factory_, Perspective::IS_CLIENT, + ParsedVersionOfIndex(supported_versions_, 1)); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + + CreateSession(); + CompleteCryptoHandshake(); + // 2 CHLOs are sent. + ASSERT_EQ(2u, session_->sent_crypto_handshake_messages().size()); + // Verify preferred version is the highest version that session supports, and + // is different from connection's version. + QuicVersionLabel client_version_label; + EXPECT_THAT(session_->sent_crypto_handshake_messages()[0].GetVersionLabel( + kVER, &client_version_label), + IsQuicNoError()); + EXPECT_EQ(CreateQuicVersionLabel(supported_versions_[0]), + client_version_label); + EXPECT_THAT(session_->sent_crypto_handshake_messages()[1].GetVersionLabel( + kVER, &client_version_label), + IsQuicNoError()); + EXPECT_EQ(CreateQuicVersionLabel(supported_versions_[0]), + client_version_label); + EXPECT_NE(CreateQuicVersionLabel(connection_->version()), + client_version_label); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_handshaker.cc b/quiche/quic/core/quic_crypto_handshaker.cc new file mode 100644 index 000000000000..07819de86bfd --- /dev/null +++ b/quiche/quic/core/quic_crypto_handshaker.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_handshaker.h" + +#include "quiche/quic/core/quic_session.h" + +namespace quic { + +#define ENDPOINT \ + (session()->perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") + +QuicCryptoHandshaker::QuicCryptoHandshaker(QuicCryptoStream* stream, + QuicSession* session) + : stream_(stream), session_(session), last_sent_handshake_message_tag_(0) { + crypto_framer_.set_visitor(this); +} + +QuicCryptoHandshaker::~QuicCryptoHandshaker() {} + +void QuicCryptoHandshaker::SendHandshakeMessage( + const CryptoHandshakeMessage& message, EncryptionLevel level) { + QUIC_DVLOG(1) << ENDPOINT << "Sending " << message.DebugString(); + session()->NeuterUnencryptedData(); + session()->OnCryptoHandshakeMessageSent(message); + last_sent_handshake_message_tag_ = message.tag(); + const QuicData& data = message.GetSerialized(); + stream_->WriteCryptoData(level, data.AsStringPiece()); +} + +void QuicCryptoHandshaker::OnError(CryptoFramer* framer) { + QUIC_DLOG(WARNING) << "Error processing crypto data: " + << QuicErrorCodeToString(framer->error()); +} + +void QuicCryptoHandshaker::OnHandshakeMessage( + const CryptoHandshakeMessage& message) { + QUIC_DVLOG(1) << ENDPOINT << "Received " << message.DebugString(); + session()->OnCryptoHandshakeMessageReceived(message); +} + +CryptoMessageParser* QuicCryptoHandshaker::crypto_message_parser() { + return &crypto_framer_; +} + +size_t QuicCryptoHandshaker::BufferSizeLimitForLevel(EncryptionLevel) const { + return GetQuicFlag(quic_max_buffered_crypto_bytes); +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_handshaker.h b/quiche/quic/core/quic_crypto_handshaker.h new file mode 100644 index 000000000000..526dbd574761 --- /dev/null +++ b/quiche/quic/core/quic_crypto_handshaker.h @@ -0,0 +1,52 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CRYPTO_HANDSHAKER_H_ +#define QUICHE_QUIC_CORE_QUIC_CRYPTO_HANDSHAKER_H_ + +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicCryptoHandshaker + : public CryptoFramerVisitorInterface { + public: + QuicCryptoHandshaker(QuicCryptoStream* stream, QuicSession* session); + QuicCryptoHandshaker(const QuicCryptoHandshaker&) = delete; + QuicCryptoHandshaker& operator=(const QuicCryptoHandshaker&) = delete; + + ~QuicCryptoHandshaker() override; + + // Sends |message| to the peer. + // TODO(wtc): return a success/failure status. + void SendHandshakeMessage(const CryptoHandshakeMessage& message, + EncryptionLevel level); + + void OnError(CryptoFramer* framer) override; + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override; + + CryptoMessageParser* crypto_message_parser(); + size_t BufferSizeLimitForLevel(EncryptionLevel level) const; + + protected: + QuicTag last_sent_handshake_message_tag() const { + return last_sent_handshake_message_tag_; + } + + private: + QuicSession* session() { return session_; } + + QuicCryptoStream* stream_; + QuicSession* session_; + + CryptoFramer crypto_framer_; + + // Records last sent crypto handshake message tag. + QuicTag last_sent_handshake_message_tag_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CRYPTO_HANDSHAKER_H_ diff --git a/quiche/quic/core/quic_crypto_server_stream.cc b/quiche/quic/core/quic_crypto_server_stream.cc new file mode 100644 index 000000000000..2bd8fbdc965c --- /dev/null +++ b/quiche/quic/core/quic_crypto_server_stream.cc @@ -0,0 +1,548 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_server_stream.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "openssl/sha.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_testvalue.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +class QuicCryptoServerStream::ProcessClientHelloCallback + : public ProcessClientHelloResultCallback { + public: + ProcessClientHelloCallback( + QuicCryptoServerStream* parent, + const quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result>& result) + : parent_(parent), result_(result) {} + + void Run( + QuicErrorCode error, const std::string& error_details, + std::unique_ptr message, + std::unique_ptr diversification_nonce, + std::unique_ptr proof_source_details) override { + if (parent_ == nullptr) { + return; + } + + parent_->FinishProcessingHandshakeMessageAfterProcessClientHello( + *result_, error, error_details, std::move(message), + std::move(diversification_nonce), std::move(proof_source_details)); + } + + void Cancel() { parent_ = nullptr; } + + private: + QuicCryptoServerStream* parent_; + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result_; +}; + +QuicCryptoServerStream::QuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, QuicSession* session, + QuicCryptoServerStreamBase::Helper* helper) + : QuicCryptoServerStreamBase(session), + QuicCryptoHandshaker(this, session), + session_(session), + delegate_(session), + crypto_config_(crypto_config), + compressed_certs_cache_(compressed_certs_cache), + signed_config_(new QuicSignedServerConfig), + helper_(helper), + num_handshake_messages_(0), + num_handshake_messages_with_server_nonces_(0), + send_server_config_update_cb_(nullptr), + num_server_config_update_messages_sent_(0), + zero_rtt_attempted_(false), + chlo_packet_size_(0), + validate_client_hello_cb_(nullptr), + encryption_established_(false), + one_rtt_keys_available_(false), + one_rtt_packet_decrypted_(false), + crypto_negotiated_params_(new QuicCryptoNegotiatedParameters) {} + +QuicCryptoServerStream::~QuicCryptoServerStream() { + CancelOutstandingCallbacks(); +} + +void QuicCryptoServerStream::CancelOutstandingCallbacks() { + // Detach from the validation callback. Calling this multiple times is safe. + if (validate_client_hello_cb_ != nullptr) { + validate_client_hello_cb_->Cancel(); + validate_client_hello_cb_ = nullptr; + } + if (send_server_config_update_cb_ != nullptr) { + send_server_config_update_cb_->Cancel(); + send_server_config_update_cb_ = nullptr; + } + if (std::shared_ptr cb = + process_client_hello_cb_.lock()) { + cb->Cancel(); + process_client_hello_cb_.reset(); + } +} + +void QuicCryptoServerStream::OnHandshakeMessage( + const CryptoHandshakeMessage& message) { + QuicCryptoHandshaker::OnHandshakeMessage(message); + ++num_handshake_messages_; + chlo_packet_size_ = session()->connection()->GetCurrentPacket().length(); + + // Do not process handshake messages after the handshake is confirmed. + if (one_rtt_keys_available_) { + OnUnrecoverableError(QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE, + "Unexpected handshake message from client"); + return; + } + + if (message.tag() != kCHLO) { + OnUnrecoverableError(QUIC_INVALID_CRYPTO_MESSAGE_TYPE, + "Handshake packet not CHLO"); + return; + } + + if (validate_client_hello_cb_ != nullptr || + !process_client_hello_cb_.expired()) { + // Already processing some other handshake message. The protocol + // does not allow for clients to send multiple handshake messages + // before the server has a chance to respond. + OnUnrecoverableError(QUIC_CRYPTO_MESSAGE_WHILE_VALIDATING_CLIENT_HELLO, + "Unexpected handshake message while processing CHLO"); + return; + } + + chlo_hash_ = + CryptoUtils::HashHandshakeMessage(message, Perspective::IS_SERVER); + + std::unique_ptr cb(new ValidateCallback(this)); + QUICHE_DCHECK(validate_client_hello_cb_ == nullptr); + QUICHE_DCHECK(process_client_hello_cb_.expired()); + validate_client_hello_cb_ = cb.get(); + crypto_config_->ValidateClientHello( + message, GetClientAddress(), session()->connection()->self_address(), + transport_version(), session()->connection()->clock(), signed_config_, + std::move(cb)); +} + +void QuicCryptoServerStream::FinishProcessingHandshakeMessage( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr details) { + // Clear the callback that got us here. + QUICHE_DCHECK(validate_client_hello_cb_ != nullptr); + QUICHE_DCHECK(process_client_hello_cb_.expired()); + validate_client_hello_cb_ = nullptr; + + auto cb = std::make_shared(this, result); + process_client_hello_cb_ = cb; + ProcessClientHello(result, std::move(details), std::move(cb)); +} + +void QuicCryptoServerStream:: + FinishProcessingHandshakeMessageAfterProcessClientHello( + const ValidateClientHelloResultCallback::Result& result, + QuicErrorCode error, const std::string& error_details, + std::unique_ptr reply, + std::unique_ptr diversification_nonce, + std::unique_ptr proof_source_details) { + // Clear the callback that got us here. + QUICHE_DCHECK(!process_client_hello_cb_.expired()); + QUICHE_DCHECK(validate_client_hello_cb_ == nullptr); + process_client_hello_cb_.reset(); + proof_source_details_ = std::move(proof_source_details); + + AdjustTestValue("quic::QuicCryptoServerStream::after_process_client_hello", + session()); + + if (!session()->connection()->connected()) { + QUIC_CODE_COUNT(quic_crypto_disconnected_after_process_client_hello); + QUIC_LOG_FIRST_N(INFO, 10) + << "After processing CHLO, QUIC connection has been closed with code " + << session()->error() << ", details: " << session()->error_details(); + return; + } + + const CryptoHandshakeMessage& message = result.client_hello; + if (error != QUIC_NO_ERROR) { + OnUnrecoverableError(error, error_details); + return; + } + + if (reply->tag() != kSHLO) { + session()->connection()->set_fully_pad_crypto_handshake_packets( + crypto_config_->pad_rej()); + // Send REJ in plaintext. + SendHandshakeMessage(*reply, ENCRYPTION_INITIAL); + return; + } + + // If we are returning a SHLO then we accepted the handshake. Now + // process the negotiated configuration options as part of the + // session config. + QuicConfig* config = session()->config(); + OverrideQuicConfigDefaults(config); + std::string process_error_details; + const QuicErrorCode process_error = + config->ProcessPeerHello(message, CLIENT, &process_error_details); + if (process_error != QUIC_NO_ERROR) { + OnUnrecoverableError(process_error, process_error_details); + return; + } + + session()->OnConfigNegotiated(); + + config->ToHandshakeMessage(reply.get(), session()->transport_version()); + + // Receiving a full CHLO implies the client is prepared to decrypt with + // the new server write key. We can start to encrypt with the new server + // write key. + // + // NOTE: the SHLO will be encrypted with the new server write key. + delegate_->OnNewEncryptionKeyAvailable( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.encrypter)); + delegate_->OnNewDecryptionKeyAvailable( + ENCRYPTION_ZERO_RTT, + std::move(crypto_negotiated_params_->initial_crypters.decrypter), + /*set_alternative_decrypter=*/false, + /*latch_once_used=*/false); + delegate_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + delegate_->DiscardOldDecryptionKey(ENCRYPTION_INITIAL); + session()->connection()->SetDiversificationNonce(*diversification_nonce); + + session()->connection()->set_fully_pad_crypto_handshake_packets( + crypto_config_->pad_shlo()); + // Send SHLO in ENCRYPTION_ZERO_RTT. + SendHandshakeMessage(*reply, ENCRYPTION_ZERO_RTT); + delegate_->OnNewEncryptionKeyAvailable( + ENCRYPTION_FORWARD_SECURE, + std::move(crypto_negotiated_params_->forward_secure_crypters.encrypter)); + delegate_->OnNewDecryptionKeyAvailable( + ENCRYPTION_FORWARD_SECURE, + std::move(crypto_negotiated_params_->forward_secure_crypters.decrypter), + /*set_alternative_decrypter=*/true, + /*latch_once_used=*/false); + encryption_established_ = true; + one_rtt_keys_available_ = true; + delegate_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + delegate_->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); +} + +void QuicCryptoServerStream::SendServerConfigUpdate( + const CachedNetworkParameters* cached_network_params) { + if (!one_rtt_keys_available_) { + return; + } + + if (send_server_config_update_cb_ != nullptr) { + QUIC_DVLOG(1) + << "Skipped server config update since one is already in progress"; + return; + } + + std::unique_ptr cb( + new SendServerConfigUpdateCallback(this)); + send_server_config_update_cb_ = cb.get(); + + crypto_config_->BuildServerConfigUpdateMessage( + session()->transport_version(), chlo_hash_, + previous_source_address_tokens_, session()->connection()->self_address(), + GetClientAddress(), session()->connection()->clock(), + session()->connection()->random_generator(), compressed_certs_cache_, + *crypto_negotiated_params_, cached_network_params, std::move(cb)); +} + +QuicCryptoServerStream::SendServerConfigUpdateCallback:: + SendServerConfigUpdateCallback(QuicCryptoServerStream* parent) + : parent_(parent) {} + +void QuicCryptoServerStream::SendServerConfigUpdateCallback::Cancel() { + parent_ = nullptr; +} + +// From BuildServerConfigUpdateMessageResultCallback +void QuicCryptoServerStream::SendServerConfigUpdateCallback::Run( + bool ok, const CryptoHandshakeMessage& message) { + if (parent_ == nullptr) { + return; + } + parent_->FinishSendServerConfigUpdate(ok, message); +} + +void QuicCryptoServerStream::FinishSendServerConfigUpdate( + bool ok, const CryptoHandshakeMessage& message) { + // Clear the callback that got us here. + QUICHE_DCHECK(send_server_config_update_cb_ != nullptr); + send_server_config_update_cb_ = nullptr; + + if (!ok) { + QUIC_DVLOG(1) << "Server: Failed to build server config update (SCUP)!"; + return; + } + + QUIC_DVLOG(1) << "Server: Sending server config update: " + << message.DebugString(); + + // Send server config update in ENCRYPTION_FORWARD_SECURE. + SendHandshakeMessage(message, ENCRYPTION_FORWARD_SECURE); + + ++num_server_config_update_messages_sent_; +} + +bool QuicCryptoServerStream::DisableResumption() { + QUICHE_DCHECK(false) << "Not supported for QUIC crypto."; + return false; +} + +bool QuicCryptoServerStream::IsZeroRtt() const { + return num_handshake_messages_ == 1 && + num_handshake_messages_with_server_nonces_ == 0; +} + +bool QuicCryptoServerStream::IsResumption() const { + // QUIC Crypto doesn't have a non-0-RTT resumption mode. + return IsZeroRtt(); +} + +int QuicCryptoServerStream::NumServerConfigUpdateMessagesSent() const { + return num_server_config_update_messages_sent_; +} + +const CachedNetworkParameters* +QuicCryptoServerStream::PreviousCachedNetworkParams() const { + return previous_cached_network_params_.get(); +} + +bool QuicCryptoServerStream::ResumptionAttempted() const { + return zero_rtt_attempted_; +} + +bool QuicCryptoServerStream::EarlyDataAttempted() const { + QUICHE_DCHECK(false) << "Not supported for QUIC crypto."; + return zero_rtt_attempted_; +} + +void QuicCryptoServerStream::SetPreviousCachedNetworkParams( + CachedNetworkParameters cached_network_params) { + previous_cached_network_params_.reset( + new CachedNetworkParameters(cached_network_params)); +} + +void QuicCryptoServerStream::OnPacketDecrypted(EncryptionLevel level) { + if (level == ENCRYPTION_FORWARD_SECURE) { + one_rtt_packet_decrypted_ = true; + delegate_->NeuterHandshakeData(); + } +} + +void QuicCryptoServerStream::OnHandshakeDoneReceived() { QUICHE_DCHECK(false); } + +void QuicCryptoServerStream::OnNewTokenReceived(absl::string_view /*token*/) { + QUICHE_DCHECK(false); +} + +std::string QuicCryptoServerStream::GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) const { + QUICHE_DCHECK(false); + return ""; +} + +bool QuicCryptoServerStream::ValidateAddressToken( + absl::string_view /*token*/) const { + QUICHE_DCHECK(false); + return false; +} + +bool QuicCryptoServerStream::ShouldSendExpectCTHeader() const { + return signed_config_->proof.send_expect_ct_header; +} + +bool QuicCryptoServerStream::DidCertMatchSni() const { + return signed_config_->proof.cert_matched_sni; +} + +const ProofSource::Details* QuicCryptoServerStream::ProofSourceDetails() const { + return proof_source_details_.get(); +} + +bool QuicCryptoServerStream::GetBase64SHA256ClientChannelID( + std::string* output) const { + if (!encryption_established() || + crypto_negotiated_params_->channel_id.empty()) { + return false; + } + + const std::string& channel_id(crypto_negotiated_params_->channel_id); + uint8_t digest[SHA256_DIGEST_LENGTH]; + SHA256(reinterpret_cast(channel_id.data()), channel_id.size(), + digest); + + quiche::QuicheTextUtils::Base64Encode(digest, ABSL_ARRAYSIZE(digest), output); + return true; +} + +ssl_early_data_reason_t QuicCryptoServerStream::EarlyDataReason() const { + if (IsZeroRtt()) { + return ssl_early_data_accepted; + } + if (zero_rtt_attempted_) { + return ssl_early_data_session_not_resumed; + } + return ssl_early_data_no_session_offered; +} + +bool QuicCryptoServerStream::encryption_established() const { + return encryption_established_; +} + +bool QuicCryptoServerStream::one_rtt_keys_available() const { + return one_rtt_keys_available_; +} + +const QuicCryptoNegotiatedParameters& +QuicCryptoServerStream::crypto_negotiated_params() const { + return *crypto_negotiated_params_; +} + +CryptoMessageParser* QuicCryptoServerStream::crypto_message_parser() { + return QuicCryptoHandshaker::crypto_message_parser(); +} + +HandshakeState QuicCryptoServerStream::GetHandshakeState() const { + return one_rtt_packet_decrypted_ ? HANDSHAKE_COMPLETE : HANDSHAKE_START; +} + +void QuicCryptoServerStream::SetServerApplicationStateForResumption( + std::unique_ptr /*state*/) { + // QUIC Crypto doesn't need to remember any application state as part of doing + // 0-RTT resumption, so this function is a no-op. +} + +size_t QuicCryptoServerStream::BufferSizeLimitForLevel( + EncryptionLevel level) const { + return QuicCryptoHandshaker::BufferSizeLimitForLevel(level); +} + +std::unique_ptr +QuicCryptoServerStream::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + // Key update is only defined in QUIC+TLS. + QUICHE_DCHECK(false); + return nullptr; +} + +std::unique_ptr +QuicCryptoServerStream::CreateCurrentOneRttEncrypter() { + // Key update is only defined in QUIC+TLS. + QUICHE_DCHECK(false); + return nullptr; +} + +void QuicCryptoServerStream::ProcessClientHello( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr proof_source_details, + std::shared_ptr done_cb) { + proof_source_details_ = std::move(proof_source_details); + const CryptoHandshakeMessage& message = result->client_hello; + std::string error_details; + if (!helper_->CanAcceptClientHello( + message, GetClientAddress(), session()->connection()->peer_address(), + session()->connection()->self_address(), &error_details)) { + done_cb->Run(QUIC_HANDSHAKE_FAILED, error_details, nullptr, nullptr, + nullptr); + return; + } + + absl::string_view user_agent_id; + message.GetStringPiece(quic::kUAID, &user_agent_id); + if (!session()->user_agent_id().has_value() && !user_agent_id.empty()) { + session()->SetUserAgentId(std::string(user_agent_id)); + } + + if (!result->info.server_nonce.empty()) { + ++num_handshake_messages_with_server_nonces_; + } + + if (num_handshake_messages_ == 1) { + // Client attempts zero RTT handshake by sending a non-inchoate CHLO. + absl::string_view public_value; + zero_rtt_attempted_ = message.GetStringPiece(kPUBS, &public_value); + } + + // Store the bandwidth estimate from the client. + if (result->cached_network_params.bandwidth_estimate_bytes_per_second() > 0) { + previous_cached_network_params_.reset( + new CachedNetworkParameters(result->cached_network_params)); + } + previous_source_address_tokens_ = result->info.source_address_tokens; + + QuicConnection* connection = session()->connection(); + crypto_config_->ProcessClientHello( + result, /*reject_only=*/false, connection->connection_id(), + connection->self_address(), GetClientAddress(), connection->version(), + session()->supported_versions(), connection->clock(), + connection->random_generator(), compressed_certs_cache_, + crypto_negotiated_params_, signed_config_, + QuicCryptoStream::CryptoMessageFramingOverhead( + transport_version(), connection->connection_id()), + chlo_packet_size_, std::move(done_cb)); +} + +void QuicCryptoServerStream::OverrideQuicConfigDefaults( + QuicConfig* /*config*/) {} + +QuicCryptoServerStream::ValidateCallback::ValidateCallback( + QuicCryptoServerStream* parent) + : parent_(parent) {} + +void QuicCryptoServerStream::ValidateCallback::Cancel() { parent_ = nullptr; } + +void QuicCryptoServerStream::ValidateCallback::Run( + quiche::QuicheReferenceCountedPointer result, + std::unique_ptr details) { + if (parent_ != nullptr) { + parent_->FinishProcessingHandshakeMessage(std::move(result), + std::move(details)); + } +} + +const QuicSocketAddress QuicCryptoServerStream::GetClientAddress() { + return session()->connection()->peer_address(); +} + +SSL* QuicCryptoServerStream::GetSsl() const { return nullptr; } + +bool QuicCryptoServerStream::IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel /*level*/) const { + return true; +} + +EncryptionLevel +QuicCryptoServerStream::GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const { + if (space == INITIAL_DATA) { + return ENCRYPTION_INITIAL; + } + if (space == APPLICATION_DATA) { + return ENCRYPTION_ZERO_RTT; + } + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_server_stream.h b/quiche/quic/core/quic_crypto_server_stream.h new file mode 100644 index 000000000000..665d071d0742 --- /dev/null +++ b/quiche/quic/core/quic_crypto_server_stream.h @@ -0,0 +1,271 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CRYPTO_SERVER_STREAM_H_ +#define QUICHE_QUIC_CORE_QUIC_CRYPTO_SERVER_STREAM_H_ + +#include + +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/proto/source_address_token_proto.h" +#include "quiche/quic/core/quic_crypto_handshaker.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicCryptoServerStreamPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE QuicCryptoServerStream + : public QuicCryptoServerStreamBase, + public QuicCryptoHandshaker { + public: + QuicCryptoServerStream(const QuicCryptoServerStream&) = delete; + QuicCryptoServerStream& operator=(const QuicCryptoServerStream&) = delete; + + ~QuicCryptoServerStream() override; + + // From QuicCryptoServerStreamBase + void CancelOutstandingCallbacks() override; + bool GetBase64SHA256ClientChannelID(std::string* output) const override; + void SendServerConfigUpdate( + const CachedNetworkParameters* cached_network_params) override; + bool DisableResumption() override; + bool IsZeroRtt() const override; + bool IsResumption() const override; + bool ResumptionAttempted() const override; + bool EarlyDataAttempted() const override; + int NumServerConfigUpdateMessagesSent() const override; + const CachedNetworkParameters* PreviousCachedNetworkParams() const override; + void SetPreviousCachedNetworkParams( + CachedNetworkParameters cached_network_params) override; + void OnPacketDecrypted(EncryptionLevel level) override; + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) override {} + void OnHandshakeDoneReceived() override; + void OnNewTokenReceived(absl::string_view token) override; + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_params*/) const override; + bool ValidateAddressToken(absl::string_view token) const override; + bool ShouldSendExpectCTHeader() const override; + bool DidCertMatchSni() const override; + const ProofSource::Details* ProofSourceDetails() const override; + + // From QuicCryptoStream + ssl_early_data_reason_t EarlyDataReason() const override; + bool encryption_established() const override; + bool one_rtt_keys_available() const override; + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override; + CryptoMessageParser* crypto_message_parser() override; + HandshakeState GetHandshakeState() const override; + void SetServerApplicationStateForResumption( + std::unique_ptr state) override; + size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + SSL* GetSsl() const override; + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override; + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override; + + // From QuicCryptoHandshaker + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override; + + protected: + QUIC_EXPORT_PRIVATE friend std::unique_ptr + CreateCryptoServerStream(const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSession* session, + QuicCryptoServerStreamBase::Helper* helper); + + // |crypto_config| must outlive the stream. + // |session| must outlive the stream. + // |helper| must outlive the stream. + QuicCryptoServerStream(const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSession* session, + QuicCryptoServerStreamBase::Helper* helper); + + virtual void ProcessClientHello( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr proof_source_details, + std::shared_ptr done_cb); + + // Hook that allows the server to set QuicConfig defaults just + // before going through the parameter negotiation step. + virtual void OverrideQuicConfigDefaults(QuicConfig* config); + + // Returns client address used to generate and validate source address token. + virtual const QuicSocketAddress GetClientAddress(); + + // Returns the QuicSession that this stream belongs to. + QuicSession* session() const { return session_; } + + void set_encryption_established(bool encryption_established) { + encryption_established_ = encryption_established; + } + + void set_one_rtt_keys_available(bool one_rtt_keys_available) { + one_rtt_keys_available_ = one_rtt_keys_available; + } + + private: + friend class test::QuicCryptoServerStreamPeer; + + class QUIC_EXPORT_PRIVATE ValidateCallback + : public ValidateClientHelloResultCallback { + public: + explicit ValidateCallback(QuicCryptoServerStream* parent); + ValidateCallback(const ValidateCallback&) = delete; + ValidateCallback& operator=(const ValidateCallback&) = delete; + // To allow the parent to detach itself from the callback before deletion. + void Cancel(); + + // From ValidateClientHelloResultCallback + void Run(quiche::QuicheReferenceCountedPointer result, + std::unique_ptr details) override; + + private: + QuicCryptoServerStream* parent_; + }; + + class SendServerConfigUpdateCallback + : public BuildServerConfigUpdateMessageResultCallback { + public: + explicit SendServerConfigUpdateCallback(QuicCryptoServerStream* parent); + SendServerConfigUpdateCallback(const SendServerConfigUpdateCallback&) = + delete; + void operator=(const SendServerConfigUpdateCallback&) = delete; + + // To allow the parent to detach itself from the callback before deletion. + void Cancel(); + + // From BuildServerConfigUpdateMessageResultCallback + void Run(bool ok, const CryptoHandshakeMessage& message) override; + + private: + QuicCryptoServerStream* parent_; + }; + + // Invoked by ValidateCallback::RunImpl once initial validation of + // the client hello is complete. Finishes processing of the client + // hello message and handles handshake success/failure. + void FinishProcessingHandshakeMessage( + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr details); + + class ProcessClientHelloCallback; + friend class ProcessClientHelloCallback; + + // Portion of FinishProcessingHandshakeMessage which executes after + // ProcessClientHello has been called. + void FinishProcessingHandshakeMessageAfterProcessClientHello( + const ValidateClientHelloResultCallback::Result& result, + QuicErrorCode error, const std::string& error_details, + std::unique_ptr reply, + std::unique_ptr diversification_nonce, + std::unique_ptr proof_source_details); + + // Invoked by SendServerConfigUpdateCallback::RunImpl once the proof has been + // received. |ok| indicates whether or not the proof was successfully + // acquired, and |message| holds the partially-constructed message from + // SendServerConfigUpdate. + void FinishSendServerConfigUpdate(bool ok, + const CryptoHandshakeMessage& message); + + // Returns the QuicTransportVersion of the connection. + QuicTransportVersion transport_version() const { + return session_->transport_version(); + } + + QuicSession* session_; + HandshakerDelegateInterface* delegate_; + + // crypto_config_ contains crypto parameters for the handshake. + const QuicCryptoServerConfig* crypto_config_; + + // compressed_certs_cache_ contains a set of most recently compressed certs. + // Owned by QuicDispatcher. + QuicCompressedCertsCache* compressed_certs_cache_; + + // Server's certificate chain and signature of the server config, as provided + // by ProofSource::GetProof. + quiche::QuicheReferenceCountedPointer signed_config_; + + // Hash of the last received CHLO message which can be used for generating + // server config update messages. + std::string chlo_hash_; + + // Pointer to the helper for this crypto stream. Must outlive this stream. + QuicCryptoServerStreamBase::Helper* helper_; + + // Number of handshake messages received by this stream. + uint8_t num_handshake_messages_; + + // Number of handshake messages received by this stream that contain + // server nonces (indicating that this is a non-zero-RTT handshake + // attempt). + uint8_t num_handshake_messages_with_server_nonces_; + + // Pointer to the active callback that will receive the result of + // BuildServerConfigUpdateMessage and forward it to + // FinishSendServerConfigUpdate. nullptr if no update message is currently + // being built. + SendServerConfigUpdateCallback* send_server_config_update_cb_; + + // Number of server config update (SCUP) messages sent by this stream. + int num_server_config_update_messages_sent_; + + // If the client provides CachedNetworkParameters in the STK in the CHLO, then + // store here, and send back in future STKs if we have no better bandwidth + // estimate to send. + std::unique_ptr previous_cached_network_params_; + + // Contains any source address tokens which were present in the CHLO. + SourceAddressTokens previous_source_address_tokens_; + + // True if client attempts 0-rtt handshake (which can succeed or fail). + bool zero_rtt_attempted_; + + // Size of the packet containing the most recently received CHLO. + QuicByteCount chlo_packet_size_; + + // Pointer to the active callback that will receive the result of the client + // hello validation request and forward it to FinishProcessingHandshakeMessage + // for processing. nullptr if no handshake message is being validated. Note + // that this field is mutually exclusive with process_client_hello_cb_. + ValidateCallback* validate_client_hello_cb_; + + // Pointer to the active callback which will receive the results of + // ProcessClientHello and forward it to + // FinishProcessingHandshakeMessageAfterProcessClientHello. Note that this + // field is mutually exclusive with validate_client_hello_cb_. + std::weak_ptr process_client_hello_cb_; + + // The ProofSource::Details from this connection. + std::unique_ptr proof_source_details_; + + bool encryption_established_; + bool one_rtt_keys_available_; + bool one_rtt_packet_decrypted_; + quiche::QuicheReferenceCountedPointer + crypto_negotiated_params_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CRYPTO_SERVER_STREAM_H_ diff --git a/quiche/quic/core/quic_crypto_server_stream_base.cc b/quiche/quic/core/quic_crypto_server_stream_base.cc new file mode 100644 index 000000000000..41af6e117e0b --- /dev/null +++ b/quiche/quic/core/quic_crypto_server_stream_base.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_server_stream_base.h" + +#include +#include +#include + +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_crypto_server_stream.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/tls_server_handshaker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicCryptoServerStreamBase::QuicCryptoServerStreamBase(QuicSession* session) + : QuicCryptoStream(session) {} + +std::unique_ptr CreateCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, QuicSession* session, + QuicCryptoServerStreamBase::Helper* helper) { + switch (session->connection()->version().handshake_protocol) { + case PROTOCOL_QUIC_CRYPTO: + return std::unique_ptr(new QuicCryptoServerStream( + crypto_config, compressed_certs_cache, session, helper)); + case PROTOCOL_TLS1_3: + return std::unique_ptr( + new TlsServerHandshaker(session, crypto_config)); + case PROTOCOL_UNSUPPORTED: + break; + } + QUIC_BUG(quic_bug_10492_1) + << "Unknown handshake protocol: " + << static_cast(session->connection()->version().handshake_protocol); + return nullptr; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_server_stream_base.h b/quiche/quic/core/quic_crypto_server_stream_base.h new file mode 100644 index 000000000000..84d61900d297 --- /dev/null +++ b/quiche/quic/core/quic_crypto_server_stream_base.h @@ -0,0 +1,122 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CRYPTO_SERVER_STREAM_BASE_H_ +#define QUICHE_QUIC_CORE_QUIC_CRYPTO_SERVER_STREAM_BASE_H_ + +#include +#include +#include + +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_crypto_handshaker.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class CachedNetworkParameters; +class CryptoHandshakeMessage; +class QuicCryptoServerConfig; +class QuicCryptoServerStreamBase; + +// TODO(alyssar) see what can be moved out of QuicCryptoServerStream with +// various code and test refactoring. +class QUIC_EXPORT_PRIVATE QuicCryptoServerStreamBase : public QuicCryptoStream { + public: + explicit QuicCryptoServerStreamBase(QuicSession* session); + + class QUIC_EXPORT_PRIVATE Helper { + public: + virtual ~Helper() {} + + // Returns true if |message|, which was received on |self_address| is + // acceptable according to the visitor's policy. Otherwise, returns false + // and populates |error_details|. + virtual bool CanAcceptClientHello(const CryptoHandshakeMessage& message, + const QuicSocketAddress& client_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& self_address, + std::string* error_details) const = 0; + }; + + ~QuicCryptoServerStreamBase() override {} + + // Cancel any outstanding callbacks, such as asynchronous validation of client + // hello. + virtual void CancelOutstandingCallbacks() = 0; + + // GetBase64SHA256ClientChannelID sets |*output| to the base64 encoded, + // SHA-256 hash of the client's ChannelID key and returns true, if the client + // presented a ChannelID. Otherwise it returns false. + virtual bool GetBase64SHA256ClientChannelID(std::string* output) const = 0; + + virtual int NumServerConfigUpdateMessagesSent() const = 0; + + // Sends the latest server config and source-address token to the client. + virtual void SendServerConfigUpdate( + const CachedNetworkParameters* cached_network_params) = 0; + + // Disables TLS resumption, should be called as early as possible. + // Return true if resumption is disabled. + // Return false if nothing happened, typically it means it is called too late. + virtual bool DisableResumption() = 0; + + // Returns true if the connection was a successful 0-RTT resumption. + virtual bool IsZeroRtt() const = 0; + + // Returns true if the connection was the result of a resumption handshake, + // whether 0-RTT or not. + virtual bool IsResumption() const = 0; + + // Returns true if the client attempted a resumption handshake, whether or not + // the resumption actually occurred. + virtual bool ResumptionAttempted() const = 0; + + // Returns true if the client attempted to use early data, as indicated by the + // "early_data" TLS extension. TLS only. + virtual bool EarlyDataAttempted() const = 0; + + // NOTE: Indicating that the Expect-CT header should be sent here presents + // a layering violation to some extent. The Expect-CT header only applies to + // HTTP connections, while this class can be used for non-HTTP applications. + // However, it is exposed here because that is the only place where the + // configuration for the certificate used in the connection is accessible. + virtual bool ShouldSendExpectCTHeader() const = 0; + + // Return true if a cert was picked that matched the SNI hostname. + virtual bool DidCertMatchSni() const = 0; + + // Returns the Details from the latest call to ProofSource::GetProof or + // ProofSource::ComputeTlsSignature. Returns nullptr if no such call has been + // made. The Details are owned by the QuicCryptoServerStreamBase and the + // pointer is only valid while the owning object is still valid. + virtual const ProofSource::Details* ProofSourceDetails() const = 0; + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + QUICHE_NOTREACHED(); + return false; + } +}; + +// Creates an appropriate QuicCryptoServerStream for the provided parameters, +// including the version used by |session|. |crypto_config|, |session|, and +// |helper| must all outlive the stream. The caller takes ownership of the +// returned object. +QUIC_EXPORT_PRIVATE std::unique_ptr +CreateCryptoServerStream(const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSession* session, + QuicCryptoServerStreamBase::Helper* helper); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CRYPTO_SERVER_STREAM_BASE_H_ diff --git a/quiche/quic/core/quic_crypto_server_stream_test.cc b/quiche/quic/core/quic_crypto_server_stream_test.cc new file mode 100644 index 000000000000..6f743d90469c --- /dev/null +++ b/quiche/quic/core/quic_crypto_server_stream_test.cc @@ -0,0 +1,397 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/aes_128_gcm_12_encrypter.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/failing_proof_source.h" +#include "quiche/quic/test_tools/fake_proof_source.h" +#include "quiche/quic/test_tools/quic_crypto_server_config_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +class QuicConnection; +class QuicStream; +} // namespace quic + +using testing::_; +using testing::NiceMock; + +namespace quic { +namespace test { + +namespace { + +const char kServerHostname[] = "test.example.com"; +const uint16_t kServerPort = 443; + +// This test tests the server-side of the QUIC crypto handshake. It does not +// test the TLS handshake - that is in tls_server_handshaker_test.cc. +class QuicCryptoServerStreamTest : public QuicTest { + public: + QuicCryptoServerStreamTest() + : QuicCryptoServerStreamTest(crypto_test_utils::ProofSourceForTesting()) { + } + + explicit QuicCryptoServerStreamTest(std::unique_ptr proof_source) + : server_crypto_config_( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + std::move(proof_source), KeyExchangeSource::Default()), + server_compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), + server_id_(kServerHostname, kServerPort, false), + client_crypto_config_(crypto_test_utils::ProofVerifierForTesting()) {} + + void Initialize() { InitializeServer(); } + + ~QuicCryptoServerStreamTest() override { + // Ensure that anything that might reference |helpers_| is destroyed before + // |helpers_| is destroyed. + server_session_.reset(); + client_session_.reset(); + helpers_.clear(); + alarm_factories_.clear(); + } + + // Initializes the crypto server stream state for testing. May be + // called multiple times. + void InitializeServer() { + TestQuicSpdyServerSession* server_session = nullptr; + helpers_.push_back(std::make_unique>()); + alarm_factories_.push_back(std::make_unique()); + CreateServerSessionForTest( + server_id_, QuicTime::Delta::FromSeconds(100000), supported_versions_, + helpers_.back().get(), alarm_factories_.back().get(), + &server_crypto_config_, &server_compressed_certs_cache_, + &server_connection_, &server_session); + QUICHE_CHECK(server_session); + server_session_.reset(server_session); + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillRepeatedly([this](const std::vector& alpns) { + return std::find( + alpns.cbegin(), alpns.cend(), + AlpnForVersion(server_session_->connection()->version())); + }); + crypto_test_utils::SetupCryptoServerConfigForTest( + server_connection_->clock(), server_connection_->random_generator(), + &server_crypto_config_); + } + + QuicCryptoServerStreamBase* server_stream() { + return server_session_->GetMutableCryptoStream(); + } + + QuicCryptoClientStream* client_stream() { + return client_session_->GetMutableCryptoStream(); + } + + // Initializes a fake client, and all its associated state, for + // testing. May be called multiple times. + void InitializeFakeClient() { + TestQuicSpdyClientSession* client_session = nullptr; + helpers_.push_back(std::make_unique>()); + alarm_factories_.push_back(std::make_unique()); + CreateClientSessionForTest( + server_id_, QuicTime::Delta::FromSeconds(100000), supported_versions_, + helpers_.back().get(), alarm_factories_.back().get(), + &client_crypto_config_, &client_connection_, &client_session); + QUICHE_CHECK(client_session); + client_session_.reset(client_session); + } + + int CompleteCryptoHandshake() { + QUICHE_CHECK(server_connection_); + QUICHE_CHECK(server_session_ != nullptr); + + return crypto_test_utils::HandshakeWithFakeClient( + helpers_.back().get(), alarm_factories_.back().get(), + server_connection_, server_stream(), server_id_, client_options_, + /*alpn=*/""); + } + + // Performs a single round of handshake message-exchange between the + // client and server. + void AdvanceHandshakeWithFakeClient() { + QUICHE_CHECK(server_connection_); + QUICHE_CHECK(client_session_ != nullptr); + + EXPECT_CALL(*client_session_, OnProofValid(_)).Times(testing::AnyNumber()); + EXPECT_CALL(*client_session_, OnProofVerifyDetailsAvailable(_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*client_connection_, OnCanWrite()).Times(testing::AnyNumber()); + EXPECT_CALL(*server_connection_, OnCanWrite()).Times(testing::AnyNumber()); + client_stream()->CryptoConnect(); + crypto_test_utils::AdvanceHandshake(client_connection_, client_stream(), 0, + server_connection_, server_stream(), 0); + } + + protected: + // Every connection gets its own MockQuicConnectionHelper and + // MockAlarmFactory, tracked separately from the server and client state so + // their lifetimes persist through the whole test. + std::vector> helpers_; + std::vector> alarm_factories_; + + // Server state. + PacketSavingConnection* server_connection_; + std::unique_ptr server_session_; + QuicCryptoServerConfig server_crypto_config_; + QuicCompressedCertsCache server_compressed_certs_cache_; + QuicServerId server_id_; + + // Client state. + PacketSavingConnection* client_connection_; + QuicCryptoClientConfig client_crypto_config_; + std::unique_ptr client_session_; + + CryptoHandshakeMessage message_; + crypto_test_utils::FakeClientOptions client_options_; + + // Which QUIC versions the client and server support. + ParsedQuicVersionVector supported_versions_ = + AllSupportedVersionsWithQuicCrypto(); +}; + +TEST_F(QuicCryptoServerStreamTest, NotInitiallyConected) { + Initialize(); + EXPECT_FALSE(server_stream()->encryption_established()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); +} + +TEST_F(QuicCryptoServerStreamTest, ConnectedAfterCHLO) { + // CompleteCryptoHandshake returns the number of client hellos sent. This + // test should send: + // * One to get a source-address token and certificates. + // * One to complete the handshake. + Initialize(); + EXPECT_EQ(2, CompleteCryptoHandshake()); + EXPECT_TRUE(server_stream()->encryption_established()); + EXPECT_TRUE(server_stream()->one_rtt_keys_available()); +} + +TEST_F(QuicCryptoServerStreamTest, ForwardSecureAfterCHLO) { + Initialize(); + InitializeFakeClient(); + + // Do a first handshake in order to prime the client config with the server's + // information. + AdvanceHandshakeWithFakeClient(); + EXPECT_FALSE(server_stream()->encryption_established()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); + + // Now do another handshake, with the blocking SHLO connection option. + InitializeServer(); + InitializeFakeClient(); + + AdvanceHandshakeWithFakeClient(); + EXPECT_TRUE(server_stream()->encryption_established()); + EXPECT_TRUE(server_stream()->one_rtt_keys_available()); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, + server_session_->connection()->encryption_level()); +} + +TEST_F(QuicCryptoServerStreamTest, ZeroRTT) { + Initialize(); + InitializeFakeClient(); + + // Do a first handshake in order to prime the client config with the server's + // information. + AdvanceHandshakeWithFakeClient(); + EXPECT_FALSE(server_stream()->ResumptionAttempted()); + + // Now do another handshake, hopefully in 0-RTT. + QUIC_LOG(INFO) << "Resetting for 0-RTT handshake attempt"; + InitializeFakeClient(); + InitializeServer(); + + EXPECT_CALL(*client_session_, OnProofValid(_)).Times(testing::AnyNumber()); + EXPECT_CALL(*client_session_, OnProofVerifyDetailsAvailable(_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*client_connection_, OnCanWrite()).Times(testing::AnyNumber()); + client_stream()->CryptoConnect(); + + EXPECT_CALL(*client_session_, OnProofValid(_)).Times(testing::AnyNumber()); + EXPECT_CALL(*client_session_, OnProofVerifyDetailsAvailable(_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*client_connection_, OnCanWrite()).Times(testing::AnyNumber()); + crypto_test_utils::CommunicateHandshakeMessages( + client_connection_, client_stream(), server_connection_, server_stream()); + + EXPECT_EQ(1, client_stream()->num_sent_client_hellos()); + EXPECT_TRUE(server_stream()->ResumptionAttempted()); +} + +TEST_F(QuicCryptoServerStreamTest, FailByPolicy) { + Initialize(); + InitializeFakeClient(); + + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); + + AdvanceHandshakeWithFakeClient(); +} + +TEST_F(QuicCryptoServerStreamTest, MessageAfterHandshake) { + Initialize(); + CompleteCryptoHandshake(); + EXPECT_CALL( + *server_connection_, + CloseConnection(QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE, _, _)); + message_.set_tag(kCHLO); + crypto_test_utils::SendHandshakeMessageToStream(server_stream(), message_, + Perspective::IS_CLIENT); +} + +TEST_F(QuicCryptoServerStreamTest, BadMessageType) { + Initialize(); + + message_.set_tag(kSHLO); + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_INVALID_CRYPTO_MESSAGE_TYPE, _, _)); + crypto_test_utils::SendHandshakeMessageToStream(server_stream(), message_, + Perspective::IS_SERVER); +} + +TEST_F(QuicCryptoServerStreamTest, OnlySendSCUPAfterHandshakeComplete) { + // An attempt to send a SCUP before completing handshake should fail. + Initialize(); + + server_stream()->SendServerConfigUpdate(nullptr); + EXPECT_EQ(0, server_stream()->NumServerConfigUpdateMessagesSent()); +} + +TEST_F(QuicCryptoServerStreamTest, SendSCUPAfterHandshakeComplete) { + Initialize(); + + InitializeFakeClient(); + + // Do a first handshake in order to prime the client config with the server's + // information. + AdvanceHandshakeWithFakeClient(); + + // Now do another handshake, with the blocking SHLO connection option. + InitializeServer(); + InitializeFakeClient(); + AdvanceHandshakeWithFakeClient(); + + // Send a SCUP message and ensure that the client was able to verify it. + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + server_stream()->SendServerConfigUpdate(nullptr); + crypto_test_utils::AdvanceHandshake(client_connection_, client_stream(), 1, + server_connection_, server_stream(), 1); + + EXPECT_EQ(1, server_stream()->NumServerConfigUpdateMessagesSent()); + EXPECT_EQ(1, client_stream()->num_scup_messages_received()); +} + +class QuicCryptoServerStreamTestWithFailingProofSource + : public QuicCryptoServerStreamTest { + public: + QuicCryptoServerStreamTestWithFailingProofSource() + : QuicCryptoServerStreamTest( + std::unique_ptr(new FailingProofSource)) {} +}; + +TEST_F(QuicCryptoServerStreamTestWithFailingProofSource, Test) { + Initialize(); + InitializeFakeClient(); + + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, "Failed to get proof", _)); + // Regression test for b/31521252, in which a crash would happen here. + AdvanceHandshakeWithFakeClient(); + EXPECT_FALSE(server_stream()->encryption_established()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); +} + +class QuicCryptoServerStreamTestWithFakeProofSource + : public QuicCryptoServerStreamTest { + public: + QuicCryptoServerStreamTestWithFakeProofSource() + : QuicCryptoServerStreamTest( + std::unique_ptr(new FakeProofSource)), + crypto_config_peer_(&server_crypto_config_) {} + + FakeProofSource* GetFakeProofSource() const { + return static_cast(crypto_config_peer_.GetProofSource()); + } + + protected: + QuicCryptoServerConfigPeer crypto_config_peer_; +}; + +// Regression test for b/35422225, in which multiple CHLOs arriving on the same +// connection in close succession could cause a crash. +TEST_F(QuicCryptoServerStreamTestWithFakeProofSource, MultipleChlo) { + Initialize(); + GetFakeProofSource()->Activate(); + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .WillOnce(testing::Return(true)); + + // The methods below use a PROTOCOL_QUIC_CRYPTO version so we pick the + // first one from the list of supported versions. + QuicTransportVersion transport_version = QUIC_VERSION_UNSUPPORTED; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + transport_version = version.transport_version; + break; + } + } + ASSERT_NE(QUIC_VERSION_UNSUPPORTED, transport_version); + + // Create a minimal CHLO + MockClock clock; + CryptoHandshakeMessage chlo = crypto_test_utils::GenerateDefaultInchoateCHLO( + &clock, transport_version, &server_crypto_config_); + + // Send in the CHLO, and check that a callback is now pending in the + // ProofSource. + crypto_test_utils::SendHandshakeMessageToStream(server_stream(), chlo, + Perspective::IS_CLIENT); + EXPECT_EQ(GetFakeProofSource()->NumPendingCallbacks(), 1); + + // Send in a second CHLO while processing of the first is still pending. + // Verify that the server closes the connection rather than crashing. Note + // that the crash is a use-after-free, so it may only show up consistently in + // ASAN tests. + EXPECT_CALL( + *server_connection_, + CloseConnection(QUIC_CRYPTO_MESSAGE_WHILE_VALIDATING_CLIENT_HELLO, + "Unexpected handshake message while processing CHLO", _)); + crypto_test_utils::SendHandshakeMessageToStream(server_stream(), chlo, + Perspective::IS_CLIENT); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_stream.cc b/quiche/quic/core/quic_crypto_stream.cc new file mode 100644 index 000000000000..31c2180c0a66 --- /dev/null +++ b/quiche/quic/core/quic_crypto_stream.cc @@ -0,0 +1,518 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_stream.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/frames/quic_crypto_frame.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +#define ENDPOINT \ + (session()->perspective() == Perspective::IS_SERVER ? "Server: " \ + : "Client:" \ + " ") + +QuicCryptoStream::QuicCryptoStream(QuicSession* session) + : QuicStream( + QuicVersionUsesCryptoFrames(session->transport_version()) + ? QuicUtils::GetInvalidStreamId(session->transport_version()) + : QuicUtils::GetCryptoStreamId(session->transport_version()), + session, + /*is_static=*/true, + QuicVersionUsesCryptoFrames(session->transport_version()) + ? CRYPTO + : BIDIRECTIONAL), + substreams_{{{this}, {this}, {this}}} { + // The crypto stream is exempt from connection level flow control. + DisableConnectionFlowControlForThisStream(); +} + +QuicCryptoStream::~QuicCryptoStream() {} + +// static +QuicByteCount QuicCryptoStream::CryptoMessageFramingOverhead( + QuicTransportVersion version, QuicConnectionId connection_id) { + QUICHE_DCHECK( + QuicUtils::IsConnectionIdValidForVersion(connection_id, version)); + quiche::QuicheVariableLengthIntegerLength retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + quiche::QuicheVariableLengthIntegerLength length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + if (!QuicVersionHasLongHeaderLengths(version)) { + retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + } + return QuicPacketCreator::StreamFramePacketOverhead( + version, connection_id.length(), 0, /*include_version=*/true, + /*include_diversification_nonce=*/true, + VersionHasIetfInvariantHeader(version) ? PACKET_4BYTE_PACKET_NUMBER + : PACKET_1BYTE_PACKET_NUMBER, + retry_token_length_length, length_length, + /*offset=*/0); +} + +void QuicCryptoStream::OnCryptoFrame(const QuicCryptoFrame& frame) { + QUIC_BUG_IF(quic_bug_12573_1, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 shouldn't receive CRYPTO frames"; + EncryptionLevel level = session()->connection()->last_decrypted_level(); + if (!IsCryptoFrameExpectedForEncryptionLevel(level)) { + OnUnrecoverableError( + IETF_QUIC_PROTOCOL_VIOLATION, + absl::StrCat("CRYPTO_FRAME is unexpectedly received at level ", level)); + return; + } + CryptoSubstream& substream = + substreams_[QuicUtils::GetPacketNumberSpace(level)]; + substream.sequencer.OnCryptoFrame(frame); + EncryptionLevel frame_level = level; + if (substream.sequencer.NumBytesBuffered() > + BufferSizeLimitForLevel(frame_level)) { + OnUnrecoverableError(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, + "Too much crypto data received"); + } +} + +void QuicCryptoStream::OnStreamFrame(const QuicStreamFrame& frame) { + if (QuicVersionUsesCryptoFrames(session()->transport_version())) { + QUIC_PEER_BUG(quic_peer_bug_12573_2) + << "Crypto data received in stream frame instead of crypto frame"; + OnUnrecoverableError(QUIC_INVALID_STREAM_DATA, "Unexpected stream frame"); + } + QuicStream::OnStreamFrame(frame); +} + +void QuicCryptoStream::OnDataAvailable() { + EncryptionLevel level = session()->connection()->last_decrypted_level(); + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + // Versions less than 47 only support QUIC crypto, which ignores the + // EncryptionLevel passed into CryptoMessageParser::ProcessInput (and + // OnDataAvailableInSequencer). + OnDataAvailableInSequencer(sequencer(), level); + return; + } + OnDataAvailableInSequencer( + &substreams_[QuicUtils::GetPacketNumberSpace(level)].sequencer, level); +} + +void QuicCryptoStream::OnDataAvailableInSequencer( + QuicStreamSequencer* sequencer, EncryptionLevel level) { + struct iovec iov; + while (sequencer->GetReadableRegion(&iov)) { + absl::string_view data(static_cast(iov.iov_base), iov.iov_len); + if (!crypto_message_parser()->ProcessInput(data, level)) { + OnUnrecoverableError(crypto_message_parser()->error(), + crypto_message_parser()->error_detail()); + return; + } + sequencer->MarkConsumed(iov.iov_len); + if (one_rtt_keys_available() && + crypto_message_parser()->InputBytesRemaining() == 0) { + // If the handshake is complete and the current message has been fully + // processed then no more handshake messages are likely to arrive soon + // so release the memory in the stream sequencer. + sequencer->ReleaseBufferIfEmpty(); + } + } +} + +void QuicCryptoStream::WriteCryptoData(EncryptionLevel level, + absl::string_view data) { + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + WriteOrBufferDataAtLevel(data, /*fin=*/false, level, + /*ack_listener=*/nullptr); + return; + } + if (data.empty()) { + QUIC_BUG(quic_bug_10322_1) << "Empty crypto data being written"; + return; + } + const bool had_buffered_data = HasBufferedCryptoFrames(); + QuicStreamSendBuffer* send_buffer = + &substreams_[QuicUtils::GetPacketNumberSpace(level)].send_buffer; + QuicStreamOffset offset = send_buffer->stream_offset(); + + // Ensure this data does not cause the send buffer for this encryption level + // to exceed its size limit. + if (GetQuicFlag(quic_bounded_crypto_send_buffer)) { + QUIC_BUG_IF(quic_crypto_stream_offset_lt_bytes_written, + offset < send_buffer->stream_bytes_written()); + uint64_t current_buffer_size = + offset - std::min(offset, send_buffer->stream_bytes_written()); + if (current_buffer_size > 0) { + QUIC_CODE_COUNT(quic_received_crypto_data_with_non_empty_send_buffer); + if (BufferSizeLimitForLevel(level) < + (current_buffer_size + data.length())) { + QUIC_BUG(quic_crypto_send_buffer_overflow) + << absl::StrCat("Too much data for crypto send buffer with level: ", + EncryptionLevelToString(level), + ", current_buffer_size: ", current_buffer_size, + ", data length: ", data.length(), + ", SNI: ", crypto_negotiated_params().sni); + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Too much data for crypto send buffer"); + return; + } + } + } + + // Append |data| to the send buffer for this encryption level. + send_buffer->SaveStreamData(data); + if (kMaxStreamLength - offset < data.length()) { + QUIC_BUG(quic_bug_10322_2) << "Writing too much crypto handshake data"; + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Writing too much crypto handshake data"); + return; + } + if (had_buffered_data) { + // Do not try to write if there is buffered data. + return; + } + + size_t bytes_consumed = stream_delegate()->SendCryptoData( + level, data.length(), offset, NOT_RETRANSMISSION); + send_buffer->OnStreamDataConsumed(bytes_consumed); +} + +size_t QuicCryptoStream::BufferSizeLimitForLevel(EncryptionLevel) const { + return GetQuicFlag(quic_max_buffered_crypto_bytes); +} + +bool QuicCryptoStream::OnCryptoFrameAcked(const QuicCryptoFrame& frame, + QuicTime::Delta /*ack_delay_time*/) { + QuicByteCount newly_acked_length = 0; + if (!substreams_[QuicUtils::GetPacketNumberSpace(frame.level)] + .send_buffer.OnStreamDataAcked(frame.offset, frame.data_length, + &newly_acked_length)) { + OnUnrecoverableError(QUIC_INTERNAL_ERROR, + "Trying to ack unsent crypto data."); + return false; + } + return newly_acked_length > 0; +} + +void QuicCryptoStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { + stream_delegate()->OnStreamError(QUIC_INVALID_STREAM_ID, + "Attempt to reset crypto stream"); +} + +void QuicCryptoStream::NeuterUnencryptedStreamData() { + NeuterStreamDataOfEncryptionLevel(ENCRYPTION_INITIAL); +} + +void QuicCryptoStream::NeuterStreamDataOfEncryptionLevel( + EncryptionLevel level) { + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + for (const auto& interval : bytes_consumed_[level]) { + QuicByteCount newly_acked_length = 0; + send_buffer().OnStreamDataAcked( + interval.min(), interval.max() - interval.min(), &newly_acked_length); + } + return; + } + QuicStreamSendBuffer* send_buffer = + &substreams_[QuicUtils::GetPacketNumberSpace(level)].send_buffer; + // TODO(nharper): Consider adding a Clear() method to QuicStreamSendBuffer + // to replace the following code. + QuicIntervalSet to_ack = send_buffer->bytes_acked(); + to_ack.Complement(0, send_buffer->stream_offset()); + for (const auto& interval : to_ack) { + QuicByteCount newly_acked_length = 0; + send_buffer->OnStreamDataAcked( + interval.min(), interval.max() - interval.min(), &newly_acked_length); + } +} + +void QuicCryptoStream::OnStreamDataConsumed(QuicByteCount bytes_consumed) { + if (QuicVersionUsesCryptoFrames(session()->transport_version())) { + QUIC_BUG(quic_bug_10322_3) + << "Stream data consumed when CRYPTO frames should be in use"; + } + if (bytes_consumed > 0) { + bytes_consumed_[session()->connection()->encryption_level()].Add( + stream_bytes_written(), stream_bytes_written() + bytes_consumed); + } + QuicStream::OnStreamDataConsumed(bytes_consumed); +} + +bool QuicCryptoStream::HasPendingCryptoRetransmission() const { + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + return false; + } + for (const auto& substream : substreams_) { + if (substream.send_buffer.HasPendingRetransmission()) { + return true; + } + } + return false; +} + +void QuicCryptoStream::WritePendingCryptoRetransmission() { + QUIC_BUG_IF(quic_bug_12573_3, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 don't write CRYPTO frames"; + for (uint8_t i = INITIAL_DATA; i <= APPLICATION_DATA; ++i) { + auto packet_number_space = static_cast(i); + QuicStreamSendBuffer* send_buffer = + &substreams_[packet_number_space].send_buffer; + while (send_buffer->HasPendingRetransmission()) { + auto pending = send_buffer->NextPendingRetransmission(); + size_t bytes_consumed = stream_delegate()->SendCryptoData( + GetEncryptionLevelToSendCryptoDataOfSpace(packet_number_space), + pending.length, pending.offset, HANDSHAKE_RETRANSMISSION); + send_buffer->OnStreamDataRetransmitted(pending.offset, bytes_consumed); + if (bytes_consumed < pending.length) { + return; + } + } + } +} + +void QuicCryptoStream::WritePendingRetransmission() { + while (HasPendingRetransmission()) { + StreamPendingRetransmission pending = + send_buffer().NextPendingRetransmission(); + QuicIntervalSet retransmission( + pending.offset, pending.offset + pending.length); + EncryptionLevel retransmission_encryption_level = ENCRYPTION_INITIAL; + // Determine the encryption level to write the retransmission + // at. The retransmission should be written at the same encryption level + // as the original transmission. + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + if (retransmission.Intersects(bytes_consumed_[i])) { + retransmission_encryption_level = static_cast(i); + retransmission.Intersection(bytes_consumed_[i]); + break; + } + } + pending.offset = retransmission.begin()->min(); + pending.length = + retransmission.begin()->max() - retransmission.begin()->min(); + QuicConsumedData consumed = RetransmitStreamDataAtLevel( + pending.offset, pending.length, retransmission_encryption_level, + HANDSHAKE_RETRANSMISSION); + if (consumed.bytes_consumed < pending.length) { + // The connection is write blocked. + break; + } + } +} + +bool QuicCryptoStream::RetransmitStreamData(QuicStreamOffset offset, + QuicByteCount data_length, + bool /*fin*/, + TransmissionType type) { + QUICHE_DCHECK(type == HANDSHAKE_RETRANSMISSION || type == PTO_RETRANSMISSION); + QuicIntervalSet retransmission(offset, + offset + data_length); + // Determine the encryption level to send data. This only needs to be once as + // [offset, offset + data_length) is guaranteed to be in the same packet. + EncryptionLevel send_encryption_level = ENCRYPTION_INITIAL; + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + if (retransmission.Intersects(bytes_consumed_[i])) { + send_encryption_level = static_cast(i); + break; + } + } + retransmission.Difference(bytes_acked()); + for (const auto& interval : retransmission) { + QuicStreamOffset retransmission_offset = interval.min(); + QuicByteCount retransmission_length = interval.max() - interval.min(); + QuicConsumedData consumed = RetransmitStreamDataAtLevel( + retransmission_offset, retransmission_length, send_encryption_level, + type); + if (consumed.bytes_consumed < retransmission_length) { + // The connection is write blocked. + return false; + } + } + + return true; +} + +QuicConsumedData QuicCryptoStream::RetransmitStreamDataAtLevel( + QuicStreamOffset retransmission_offset, QuicByteCount retransmission_length, + EncryptionLevel encryption_level, TransmissionType type) { + QUICHE_DCHECK(type == HANDSHAKE_RETRANSMISSION || type == PTO_RETRANSMISSION); + const auto consumed = stream_delegate()->WritevData( + id(), retransmission_length, retransmission_offset, NO_FIN, type, + encryption_level); + QUIC_DVLOG(1) << ENDPOINT << "stream " << id() + << " is forced to retransmit stream data [" + << retransmission_offset << ", " + << retransmission_offset + retransmission_length + << "), with encryption level: " << encryption_level + << ", consumed: " << consumed; + OnStreamFrameRetransmitted(retransmission_offset, consumed.bytes_consumed, + consumed.fin_consumed); + + return consumed; +} + +uint64_t QuicCryptoStream::crypto_bytes_read() const { + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + return stream_bytes_read(); + } + uint64_t bytes_read = 0; + for (const CryptoSubstream& substream : substreams_) { + bytes_read += substream.sequencer.NumBytesConsumed(); + } + return bytes_read; +} + +// TODO(haoyuewang) Move this test-only method under +// quiche/quic/test_tools. +uint64_t QuicCryptoStream::BytesReadOnLevel(EncryptionLevel level) const { + return substreams_[QuicUtils::GetPacketNumberSpace(level)] + .sequencer.NumBytesConsumed(); +} + +uint64_t QuicCryptoStream::BytesSentOnLevel(EncryptionLevel level) const { + return substreams_[QuicUtils::GetPacketNumberSpace(level)] + .send_buffer.stream_bytes_written(); +} + +bool QuicCryptoStream::WriteCryptoFrame(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + QUIC_BUG_IF(quic_bug_12573_4, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 don't write CRYPTO frames (2)"; + return substreams_[QuicUtils::GetPacketNumberSpace(level)] + .send_buffer.WriteStreamData(offset, data_length, writer); +} + +void QuicCryptoStream::OnCryptoFrameLost(QuicCryptoFrame* crypto_frame) { + QUIC_BUG_IF(quic_bug_12573_5, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 don't lose CRYPTO frames"; + substreams_[QuicUtils::GetPacketNumberSpace(crypto_frame->level)] + .send_buffer.OnStreamDataLost(crypto_frame->offset, + crypto_frame->data_length); +} + +bool QuicCryptoStream::RetransmitData(QuicCryptoFrame* crypto_frame, + TransmissionType type) { + QUIC_BUG_IF(quic_bug_12573_6, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 don't retransmit CRYPTO frames"; + QuicIntervalSet retransmission( + crypto_frame->offset, crypto_frame->offset + crypto_frame->data_length); + QuicStreamSendBuffer* send_buffer = + &substreams_[QuicUtils::GetPacketNumberSpace(crypto_frame->level)] + .send_buffer; + retransmission.Difference(send_buffer->bytes_acked()); + if (retransmission.Empty()) { + return true; + } + for (const auto& interval : retransmission) { + size_t retransmission_offset = interval.min(); + size_t retransmission_length = interval.max() - interval.min(); + EncryptionLevel retransmission_encryption_level = + GetEncryptionLevelToSendCryptoDataOfSpace( + QuicUtils::GetPacketNumberSpace(crypto_frame->level)); + size_t bytes_consumed = stream_delegate()->SendCryptoData( + retransmission_encryption_level, retransmission_length, + retransmission_offset, type); + send_buffer->OnStreamDataRetransmitted(retransmission_offset, + bytes_consumed); + if (bytes_consumed < retransmission_length) { + return false; + } + } + return true; +} + +void QuicCryptoStream::WriteBufferedCryptoFrames() { + QUIC_BUG_IF(quic_bug_12573_7, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 don't use CRYPTO frames"; + for (uint8_t i = INITIAL_DATA; i <= APPLICATION_DATA; ++i) { + auto packet_number_space = static_cast(i); + QuicStreamSendBuffer* send_buffer = + &substreams_[packet_number_space].send_buffer; + const size_t data_length = + send_buffer->stream_offset() - send_buffer->stream_bytes_written(); + if (data_length == 0) { + // No buffered data for this encryption level. + continue; + } + size_t bytes_consumed = stream_delegate()->SendCryptoData( + GetEncryptionLevelToSendCryptoDataOfSpace(packet_number_space), + data_length, send_buffer->stream_bytes_written(), NOT_RETRANSMISSION); + send_buffer->OnStreamDataConsumed(bytes_consumed); + if (bytes_consumed < data_length) { + // Connection is write blocked. + break; + } + } +} + +bool QuicCryptoStream::HasBufferedCryptoFrames() const { + QUIC_BUG_IF(quic_bug_12573_8, + !QuicVersionUsesCryptoFrames(session()->transport_version())) + << "Versions less than 47 don't use CRYPTO frames"; + for (const CryptoSubstream& substream : substreams_) { + const QuicStreamSendBuffer& send_buffer = substream.send_buffer; + QUICHE_DCHECK_GE(send_buffer.stream_offset(), + send_buffer.stream_bytes_written()); + if (send_buffer.stream_offset() > send_buffer.stream_bytes_written()) { + return true; + } + } + return false; +} + +bool QuicCryptoStream::IsFrameOutstanding(EncryptionLevel level, size_t offset, + size_t length) const { + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + // This only happens if a client was originally configured for a version + // greater than 45, but received a version negotiation packet and is + // attempting to retransmit for a version less than 47. Outside of tests, + // this is a misconfiguration of the client, and this connection will be + // doomed. Return false here to avoid trying to retransmit CRYPTO frames on + // the wrong transport version. + return false; + } + return substreams_[QuicUtils::GetPacketNumberSpace(level)] + .send_buffer.IsStreamDataOutstanding(offset, length); +} + +bool QuicCryptoStream::IsWaitingForAcks() const { + if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { + return QuicStream::IsWaitingForAcks(); + } + for (const CryptoSubstream& substream : substreams_) { + if (substream.send_buffer.stream_bytes_outstanding()) { + return true; + } + } + return false; +} + +QuicCryptoStream::CryptoSubstream::CryptoSubstream( + QuicCryptoStream* crypto_stream) + : sequencer(crypto_stream), + send_buffer(crypto_stream->session() + ->connection() + ->helper() + ->GetStreamSendBufferAllocator()) {} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_crypto_stream.h b/quiche/quic/core/quic_crypto_stream.h new file mode 100644 index 000000000000..0be31858c89b --- /dev/null +++ b/quiche/quic/core/quic_crypto_stream.h @@ -0,0 +1,281 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CRYPTO_STREAM_H_ +#define QUICHE_QUIC_CORE_QUIC_CRYPTO_STREAM_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class CachedNetworkParameters; +class QuicSession; + +// Crypto handshake messages in QUIC take place over a reserved stream with the +// id 1. Each endpoint (client and server) will allocate an instance of a +// subclass of QuicCryptoStream to send and receive handshake messages. (In the +// normal 1-RTT handshake, the client will send a client hello, CHLO, message. +// The server will receive this message and respond with a server hello message, +// SHLO. At this point both sides will have established a crypto context they +// can use to send encrypted messages. +// +// For more details: +// https://docs.google.com/document/d/1g5nIXAIkN_Y-7XJW5K45IblHd_L2f5LTaDUDwvZ5L6g/edit?usp=sharing +class QUIC_EXPORT_PRIVATE QuicCryptoStream : public QuicStream { + public: + explicit QuicCryptoStream(QuicSession* session); + QuicCryptoStream(const QuicCryptoStream&) = delete; + QuicCryptoStream& operator=(const QuicCryptoStream&) = delete; + + ~QuicCryptoStream() override; + + // Returns the per-packet framing overhead associated with sending a + // handshake message for |version|. + static QuicByteCount CryptoMessageFramingOverhead( + QuicTransportVersion version, QuicConnectionId connection_id); + + // QuicStream implementation + void OnStreamFrame(const QuicStreamFrame& frame) override; + void OnDataAvailable() override; + + // Called when a CRYPTO frame is received. + void OnCryptoFrame(const QuicCryptoFrame& frame); + + // Called when a CRYPTO frame is ACKed. + bool OnCryptoFrameAcked(const QuicCryptoFrame& frame, + QuicTime::Delta ack_delay_time); + + void OnStreamReset(const QuicRstStreamFrame& frame) override; + + // Performs key extraction to derive a new secret of |result_len| bytes + // dependent on |label|, |context|, and the stream's negotiated subkey secret. + // Returns false if the handshake has not been confirmed or the parameters are + // invalid (e.g. |label| contains null bytes); returns true on success. This + // method is only supported for IETF QUIC and MUST NOT be called in gQUIC as + // that'll trigger an assert in DEBUG build. + virtual bool ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, std::string* result) = 0; + + // Writes |data| to the QuicStream at level |level|. + virtual void WriteCryptoData(EncryptionLevel level, absl::string_view data); + + // Returns the ssl_early_data_reason_t describing why 0-RTT was accepted or + // rejected. Note that the value returned by this function may vary during the + // handshake. Once |one_rtt_keys_available| returns true, the value returned + // by this function will not change for the rest of the lifetime of the + // QuicCryptoStream. + virtual ssl_early_data_reason_t EarlyDataReason() const = 0; + + // Returns true once an encrypter has been set for the connection. + virtual bool encryption_established() const = 0; + + // Returns true once the crypto handshake has completed. + virtual bool one_rtt_keys_available() const = 0; + + // Returns the parameters negotiated in the crypto handshake. + virtual const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const = 0; + + // Provides the message parser to use when data is received on this stream. + virtual CryptoMessageParser* crypto_message_parser() = 0; + + // Called when a packet of encryption |level| has been successfully decrypted. + virtual void OnPacketDecrypted(EncryptionLevel level) = 0; + + // Called when a 1RTT packet has been acknowledged. + virtual void OnOneRttPacketAcknowledged() = 0; + + // Called when a packet of ENCRYPTION_HANDSHAKE gets sent. + virtual void OnHandshakePacketSent() = 0; + + // Called when a handshake done frame has been received. + virtual void OnHandshakeDoneReceived() = 0; + + // Called when a new token frame has been received. + virtual void OnNewTokenReceived(absl::string_view token) = 0; + + // Called to get an address token. + virtual std::string GetAddressToken( + const CachedNetworkParameters* cached_network_params) const = 0; + + // Called to validate |token|. + virtual bool ValidateAddressToken(absl::string_view token) const = 0; + + // Get the last CachedNetworkParameters received from a valid address token. + virtual const CachedNetworkParameters* PreviousCachedNetworkParams() + const = 0; + + // Set the CachedNetworkParameters that will be returned by + // PreviousCachedNetworkParams. + // TODO(wub): This function is test only, move it to a test only library. + virtual void SetPreviousCachedNetworkParams( + CachedNetworkParameters cached_network_params) = 0; + + // Returns current handshake state. + virtual HandshakeState GetHandshakeState() const = 0; + + // Called to provide the server-side application state that must be checked + // when performing a 0-RTT TLS resumption. + // + // On a client, this may be called at any time; 0-RTT tickets will not be + // cached until this function is called. When a 0-RTT resumption is attempted, + // QuicSession::SetApplicationState will be called with the state provided by + // a call to this function on a previous connection. + // + // On a server, this function must be called before commencing the handshake, + // otherwise 0-RTT tickets will not be issued. On subsequent connections, + // 0-RTT will be rejected if the data passed into this function does not match + // the data passed in on the connection where the 0-RTT ticket was issued. + virtual void SetServerApplicationStateForResumption( + std::unique_ptr state) = 0; + + // Returns the maximum number of bytes that can be buffered at a particular + // encryption level |level|. + virtual size_t BufferSizeLimitForLevel(EncryptionLevel level) const; + + // Called to generate a decrypter for the next key phase. Each call should + // generate the key for phase n+1. + virtual std::unique_ptr + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called to generate an encrypter for the same key phase of the last + // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr CreateCurrentOneRttEncrypter() = 0; + + // Return the SSL struct object created by BoringSSL if the stream is using + // TLS1.3. Otherwise, return nullptr. + // This method is used in Envoy. + virtual SSL* GetSsl() const = 0; + + // Called to cancel retransmission of unencrypted crypto stream data. + void NeuterUnencryptedStreamData(); + + // Called to cancel retransmission of data of encryption |level|. + void NeuterStreamDataOfEncryptionLevel(EncryptionLevel level); + + // Override to record the encryption level of consumed data. + void OnStreamDataConsumed(QuicByteCount bytes_consumed) override; + + // Returns whether there are any bytes pending retransmission in CRYPTO + // frames. + virtual bool HasPendingCryptoRetransmission() const; + + // Writes any pending CRYPTO frame retransmissions. + void WritePendingCryptoRetransmission(); + + // Override to retransmit lost crypto data with the appropriate encryption + // level. + void WritePendingRetransmission() override; + + // Override to send unacked crypto data with the appropriate encryption level. + bool RetransmitStreamData(QuicStreamOffset offset, QuicByteCount data_length, + bool fin, TransmissionType type) override; + + // Sends stream retransmission data at |encryption_level|. + QuicConsumedData RetransmitStreamDataAtLevel( + QuicStreamOffset retransmission_offset, + QuicByteCount retransmission_length, EncryptionLevel encryption_level, + TransmissionType type); + + // Returns the number of bytes of handshake data that have been received from + // the peer in either CRYPTO or STREAM frames. + uint64_t crypto_bytes_read() const; + + // Returns the number of bytes of handshake data that have been received from + // the peer in CRYPTO frames at a particular encryption level. + QuicByteCount BytesReadOnLevel(EncryptionLevel level) const; + + // Returns the number of bytes of handshake data that have been sent to + // the peer in CRYPTO frames at a particular encryption level. + QuicByteCount BytesSentOnLevel(EncryptionLevel level) const; + + // Writes |data_length| of data of a crypto frame to |writer|. The data + // written is from the send buffer for encryption level |level| and starts at + // |offset|. + bool WriteCryptoFrame(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, QuicDataWriter* writer); + + // Called when data from a CRYPTO frame is considered lost. The lost data is + // identified by the encryption level, offset, and length in |crypto_frame|. + void OnCryptoFrameLost(QuicCryptoFrame* crypto_frame); + + // Called to retransmit any outstanding data in the range indicated by the + // encryption level, offset, and length in |crypto_frame|. Returns true if all + // data gets retransmitted. + bool RetransmitData(QuicCryptoFrame* crypto_frame, TransmissionType type); + + // Called to write buffered crypto frames. + void WriteBufferedCryptoFrames(); + + // Returns true if there is buffered crypto frames. + bool HasBufferedCryptoFrames() const; + + // Returns true if any portion of the data at encryption level |level| + // starting at |offset| for |length| bytes is outstanding. + bool IsFrameOutstanding(EncryptionLevel level, size_t offset, + size_t length) const; + + // Returns true if the crypto handshake is still waiting for acks of sent + // data, and false if all data has been acked. + bool IsWaitingForAcks() const; + + // Helper method for OnDataAvailable. Calls CryptoMessageParser::ProcessInput + // with the data available in |sequencer| and |level|, and marks the data + // passed to ProcessInput as consumed. + virtual void OnDataAvailableInSequencer(QuicStreamSequencer* sequencer, + EncryptionLevel level); + + QuicStreamSequencer* GetStreamSequencerForPacketNumberSpace( + PacketNumberSpace packet_number_space) { + return &substreams_[packet_number_space].sequencer; + } + + // Called by OnCryptoFrame to check if a CRYPTO frame is received at an + // expected `level`. + virtual bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const = 0; + + // Called to determine the encryption level to send/retransmit crypto data. + virtual EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const = 0; + + private: + // Data sent and received in CRYPTO frames is sent at multiple packet number + // spaces. Some of the state for the single logical crypto stream is split + // across packet number spaces, and a CryptoSubstream is used to manage that + // state for a particular packet number space. + struct QUIC_EXPORT_PRIVATE CryptoSubstream { + CryptoSubstream(QuicCryptoStream* crypto_stream); + + QuicStreamSequencer sequencer; + QuicStreamSendBuffer send_buffer; + }; + + // Consumed data according to encryption levels. + // TODO(fayang): This is not needed once switching from QUIC crypto to + // TLS 1.3, which never encrypts crypto data. + QuicIntervalSet bytes_consumed_[NUM_ENCRYPTION_LEVELS]; + + // Keeps state for data sent/received in CRYPTO frames at each packet number + // space; + std::array substreams_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CRYPTO_STREAM_H_ diff --git a/quiche/quic/core/quic_crypto_stream_test.cc b/quiche/quic/core/quic_crypto_stream_test.cc new file mode 100644 index 000000000000..cd6d3cf0bef9 --- /dev/null +++ b/quiche/quic/core/quic_crypto_stream_test.cc @@ -0,0 +1,815 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_crypto_stream.h" + +#include +#include +#include +#include +#include + +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::InSequence; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::Return; + +namespace quic { +namespace test { +namespace { + +class MockQuicCryptoStream : public QuicCryptoStream, + public QuicCryptoHandshaker { + public: + explicit MockQuicCryptoStream(QuicSession* session) + : QuicCryptoStream(session), + QuicCryptoHandshaker(this, session), + params_(new QuicCryptoNegotiatedParameters) {} + MockQuicCryptoStream(const MockQuicCryptoStream&) = delete; + MockQuicCryptoStream& operator=(const MockQuicCryptoStream&) = delete; + + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override { + messages_.push_back(message); + } + + std::vector* messages() { return &messages_; } + + ssl_early_data_reason_t EarlyDataReason() const override { + return ssl_early_data_unknown; + } + bool encryption_established() const override { return false; } + bool one_rtt_keys_available() const override { return false; } + + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override { + return *params_; + } + CryptoMessageParser* crypto_message_parser() override { + return QuicCryptoHandshaker::crypto_message_parser(); + } + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnHandshakeDoneReceived() override {} + void OnNewTokenReceived(absl::string_view /*token*/) override {} + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } + bool ValidateAddressToken(absl::string_view /*token*/) const override { + return true; + } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} + HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; } + void SetServerApplicationStateForResumption( + std::unique_ptr /*application_state*/) override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + SSL* GetSsl() const override { return nullptr; } + + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override { + return level != ENCRYPTION_ZERO_RTT; + } + + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return QuicCryptoStream::session() + ->GetEncryptionLevelToSendApplicationData(); + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } + } + + private: + quiche::QuicheReferenceCountedPointer params_; + std::vector messages_; +}; + +class QuicCryptoStreamTest : public QuicTest { + public: + QuicCryptoStreamTest() + : connection_(new MockQuicConnection(&helper_, &alarm_factory_, + Perspective::IS_CLIENT)), + session_(connection_, /*create_mock_crypto_stream=*/false) { + EXPECT_CALL(*static_cast(connection_->writer()), + WritePacket(_, _, _, _, _)) + .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 0))); + stream_ = new MockQuicCryptoStream(&session_); + session_.SetCryptoStream(stream_); + session_.Initialize(); + message_.set_tag(kSHLO); + message_.SetStringPiece(1, "abc"); + message_.SetStringPiece(2, "def"); + ConstructHandshakeMessage(); + } + QuicCryptoStreamTest(const QuicCryptoStreamTest&) = delete; + QuicCryptoStreamTest& operator=(const QuicCryptoStreamTest&) = delete; + + void ConstructHandshakeMessage() { + CryptoFramer framer; + message_data_ = framer.ConstructHandshakeMessage(message_); + } + + protected: + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnection* connection_; + MockQuicSpdySession session_; + MockQuicCryptoStream* stream_; + CryptoHandshakeMessage message_; + std::unique_ptr message_data_; +}; + +TEST_F(QuicCryptoStreamTest, NotInitiallyConected) { + EXPECT_FALSE(stream_->encryption_established()); + EXPECT_FALSE(stream_->one_rtt_keys_available()); +} + +TEST_F(QuicCryptoStreamTest, ProcessRawData) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + stream_->OnStreamFrame(QuicStreamFrame( + QuicUtils::GetCryptoStreamId(connection_->transport_version()), + /*fin=*/false, + /*offset=*/0, message_data_->AsStringPiece())); + } else { + stream_->OnCryptoFrame(QuicCryptoFrame(ENCRYPTION_INITIAL, /*offset*/ 0, + message_data_->AsStringPiece())); + } + ASSERT_EQ(1u, stream_->messages()->size()); + const CryptoHandshakeMessage& message = (*stream_->messages())[0]; + EXPECT_EQ(kSHLO, message.tag()); + EXPECT_EQ(2u, message.tag_value_map().size()); + EXPECT_EQ("abc", crypto_test_utils::GetValueForTag(message, 1)); + EXPECT_EQ("def", crypto_test_utils::GetValueForTag(message, 2)); +} + +TEST_F(QuicCryptoStreamTest, ProcessBadData) { + std::string bad(message_data_->data(), message_data_->length()); + const int kFirstTagIndex = sizeof(uint32_t) + // message tag + sizeof(uint16_t) + // number of tag-value pairs + sizeof(uint16_t); // padding + EXPECT_EQ(1, bad[kFirstTagIndex]); + bad[kFirstTagIndex] = 0x7F; // out of order tag + + EXPECT_CALL(*connection_, CloseConnection(QUIC_CRYPTO_TAGS_OUT_OF_ORDER, + testing::_, testing::_)); + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + stream_->OnStreamFrame(QuicStreamFrame( + QuicUtils::GetCryptoStreamId(connection_->transport_version()), + /*fin=*/false, /*offset=*/0, bad)); + } else { + stream_->OnCryptoFrame( + QuicCryptoFrame(ENCRYPTION_INITIAL, /*offset*/ 0, bad)); + } +} + +TEST_F(QuicCryptoStreamTest, NoConnectionLevelFlowControl) { + EXPECT_FALSE( + QuicStreamPeer::StreamContributesToConnectionFlowControl(stream_)); +} + +TEST_F(QuicCryptoStreamTest, RetransmitCryptoData) { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + InSequence s; + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 0, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 1350, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Lost [0, 1000). + stream_->OnStreamFrameLost(0, 1000, false); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + // Lost [1200, 2000). + stream_->OnStreamFrameLost(1200, 800, false); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1000, 0, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + // Verify [1200, 2000) are sent in [1200, 1350) and [1350, 2000) because of + // they are in different encryption levels. + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 150, 1200, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 650, 1350, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->OnCanWrite(); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + // Verify connection's encryption level has restored. + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); +} + +TEST_F(QuicCryptoStreamTest, RetransmitCryptoDataInCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + InSequence s; + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT. + std::unique_ptr encrypter = + std::make_unique(Perspective::IS_CLIENT); + connection_->SetEncrypter(ENCRYPTION_ZERO_RTT, std::move(encrypter)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data); + + // Before encryption moves to ENCRYPTION_FORWARD_SECURE, ZERO RTT data are + // retranmitted at ENCRYPTION_ZERO_RTT. + QuicCryptoFrame lost_frame = QuicCryptoFrame(ENCRYPTION_ZERO_RTT, 0, 650); + stream_->OnCryptoFrameLost(&lost_frame); + + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 650, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WritePendingCryptoRetransmission(); + + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(Perspective::IS_CLIENT)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Lost [0, 1000). + lost_frame = QuicCryptoFrame(ENCRYPTION_INITIAL, 0, 1000); + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); + // Lost [1200, 2000). + lost_frame = QuicCryptoFrame(ENCRYPTION_INITIAL, 1200, 150); + stream_->OnCryptoFrameLost(&lost_frame); + lost_frame = QuicCryptoFrame(ENCRYPTION_ZERO_RTT, 0, 650); + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1000, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + // Verify [1200, 2000) are sent in [1200, 1350) and [1350, 2000) because of + // they are in different encryption levels. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 150, 1200)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_FORWARD_SECURE, 650, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WritePendingCryptoRetransmission(); + EXPECT_FALSE(stream_->HasPendingCryptoRetransmission()); + // Verify connection's encryption level has restored. + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); +} + +// Regression test for handling the missing ENCRYPTION_HANDSHAKE in +// quic_crypto_stream.cc. This test is essentially the same as +// RetransmitCryptoDataInCryptoFrames, except it uses ENCRYPTION_HANDSHAKE in +// place of ENCRYPTION_ZERO_RTT. +TEST_F(QuicCryptoStreamTest, RetransmitEncryptionHandshakeLevelCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + InSequence s; + // Send [0, 1000) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1000, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1000, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + // Send [1000, 2000) in ENCRYPTION_HANDSHAKE. + std::unique_ptr encrypter = + std::make_unique(Perspective::IS_CLIENT); + connection_->SetEncrypter(ENCRYPTION_HANDSHAKE, std::move(encrypter)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + EXPECT_EQ(ENCRYPTION_HANDSHAKE, connection_->encryption_level()); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_HANDSHAKE, 1000, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_HANDSHAKE, data); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(Perspective::IS_CLIENT)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Lost [1000, 1200). + QuicCryptoFrame lost_frame(ENCRYPTION_HANDSHAKE, 0, 200); + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); + // Verify [1000, 1200) is sent. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_HANDSHAKE, 200, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WritePendingCryptoRetransmission(); + EXPECT_FALSE(stream_->HasPendingCryptoRetransmission()); +} + +TEST_F(QuicCryptoStreamTest, NeuterUnencryptedStreamData) { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 0, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 1350, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + + // Lost [0, 1350). + stream_->OnStreamFrameLost(0, 1350, false); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + // Neuters [0, 1350). + stream_->NeuterUnencryptedStreamData(); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + // Lost [0, 1350) again. + stream_->OnStreamFrameLost(0, 1350, false); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + + // Lost [1350, 2000). + stream_->OnStreamFrameLost(1350, 650, false); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + stream_->NeuterUnencryptedStreamData(); + EXPECT_TRUE(stream_->HasPendingRetransmission()); +} + +TEST_F(QuicCryptoStreamTest, NeuterUnencryptedCryptoData) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT. + connection_->SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(Perspective::IS_CLIENT)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + std::unique_ptr encrypter = + std::make_unique(Perspective::IS_CLIENT); + connection_->SetEncrypter(ENCRYPTION_ZERO_RTT, std::move(encrypter)); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data); + + // Lost [0, 1350). + QuicCryptoFrame lost_frame(ENCRYPTION_INITIAL, 0, 1350); + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); + // Neuters [0, 1350). + stream_->NeuterUnencryptedStreamData(); + EXPECT_FALSE(stream_->HasPendingCryptoRetransmission()); + // Lost [0, 1350) again. + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_FALSE(stream_->HasPendingCryptoRetransmission()); + + // Lost [1350, 2000), which starts at offset 0 at the ENCRYPTION_ZERO_RTT + // level. + lost_frame = QuicCryptoFrame(ENCRYPTION_ZERO_RTT, 0, 650); + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); + stream_->NeuterUnencryptedStreamData(); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); +} + +TEST_F(QuicCryptoStreamTest, RetransmitStreamData) { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + InSequence s; + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 0, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 1350, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Ack [2000, 2500). + QuicByteCount newly_acked_length = 0; + stream_->OnStreamFrameAcked(2000, 500, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), &newly_acked_length); + EXPECT_EQ(500u, newly_acked_length); + + // Force crypto stream to send [1350, 2700) and only [1350, 1500) is consumed. + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 650, 1350, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_.ConsumeData( + QuicUtils::GetCryptoStreamId(connection_->transport_version()), 150, + 1350, NO_FIN, HANDSHAKE_RETRANSMISSION, absl::nullopt); + })); + + EXPECT_FALSE(stream_->RetransmitStreamData(1350, 1350, false, + HANDSHAKE_RETRANSMISSION)); + // Verify connection's encryption level has restored. + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Force session to send [1350, 1500) again and all data is consumed. + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 650, 1350, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 200, 2500, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + EXPECT_TRUE(stream_->RetransmitStreamData(1350, 1350, false, + HANDSHAKE_RETRANSMISSION)); + // Verify connection's encryption level has restored. + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)).Times(0); + // Force to send an empty frame. + EXPECT_TRUE( + stream_->RetransmitStreamData(0, 0, false, HANDSHAKE_RETRANSMISSION)); +} + +TEST_F(QuicCryptoStreamTest, RetransmitStreamDataWithCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + InSequence s; + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT. + std::unique_ptr encrypter = + std::make_unique(Perspective::IS_CLIENT); + connection_->SetEncrypter(ENCRYPTION_ZERO_RTT, std::move(encrypter)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(Perspective::IS_CLIENT)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Ack [2000, 2500). + QuicCryptoFrame acked_frame(ENCRYPTION_ZERO_RTT, 650, 500); + EXPECT_TRUE( + stream_->OnCryptoFrameAcked(acked_frame, QuicTime::Delta::Zero())); + + // Retransmit only [1350, 1500). + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_FORWARD_SECURE, 150, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + QuicCryptoFrame frame_to_retransmit(ENCRYPTION_ZERO_RTT, 0, 150); + stream_->RetransmitData(&frame_to_retransmit, HANDSHAKE_RETRANSMISSION); + + // Verify connection's encryption level has restored. + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + // Retransmit [1350, 2700) again and all data is sent. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_FORWARD_SECURE, 650, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + EXPECT_CALL(*connection_, + SendCryptoData(ENCRYPTION_FORWARD_SECURE, 200, 1150)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + frame_to_retransmit = QuicCryptoFrame(ENCRYPTION_ZERO_RTT, 0, 1350); + stream_->RetransmitData(&frame_to_retransmit, HANDSHAKE_RETRANSMISSION); + // Verify connection's encryption level has restored. + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, connection_->encryption_level()); + + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + // Force to send an empty frame. + QuicCryptoFrame empty_frame(ENCRYPTION_FORWARD_SECURE, 0, 0); + stream_->RetransmitData(&empty_frame, HANDSHAKE_RETRANSMISSION); +} + +// Regression test for b/115926584. +TEST_F(QuicCryptoStreamTest, HasUnackedCryptoData) { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + std::string data(1350, 'a'); + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 0, _, _, _)) + .WillOnce(testing::Return(QuicConsumedData(0, false))); + stream_->WriteOrBufferData(data, false, nullptr); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + // Although there is no outstanding data, verify session has pending crypto + // data. + EXPECT_TRUE(session_.HasUnackedCryptoData()); + + EXPECT_CALL( + session_, + WritevData(QuicUtils::GetCryptoStreamId(connection_->transport_version()), + 1350, 0, _, _, _)) + .WillOnce(Invoke(&session_, &MockQuicSpdySession::ConsumeData)); + stream_->OnCanWrite(); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_.HasUnackedCryptoData()); +} + +TEST_F(QuicCryptoStreamTest, HasUnackedCryptoDataWithCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_.HasUnackedCryptoData()); +} + +// Regression test for bugfix of GetPacketHeaderSize. +TEST_F(QuicCryptoStreamTest, CryptoMessageFramingOverhead) { + for (const ParsedQuicVersion& version : + AllSupportedVersionsWithQuicCrypto()) { + SCOPED_TRACE(version); + QuicByteCount expected_overhead = 48; + if (version.HasIetfInvariantHeader()) { + expected_overhead += 4; + } + if (version.HasLongHeaderLengths()) { + expected_overhead += 3; + } + if (version.HasLengthPrefixedConnectionIds()) { + expected_overhead += 1; + } + EXPECT_EQ(expected_overhead, + QuicCryptoStream::CryptoMessageFramingOverhead( + version.transport_version, TestConnectionId())); + } +} + +TEST_F(QuicCryptoStreamTest, WriteCryptoDataExceedsSendBufferLimit) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + int32_t buffer_limit = GetQuicFlag(quic_max_buffered_crypto_bytes); + + // Write data larger than the buffer limit, when there is no existing data in + // the buffer. Data is sent rather than closing the connection. + EXPECT_FALSE(stream_->HasBufferedCryptoFrames()); + int32_t over_limit = buffer_limit + 1; + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, over_limit, 0)) + // All the data is sent, no resulting buffer. + .WillOnce(Return(over_limit)); + std::string large_data(over_limit, 'a'); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, large_data); + + // Write data to the buffer up to the limit. One byte gets sent. + EXPECT_FALSE(stream_->HasBufferedCryptoFrames()); + EXPECT_CALL(*connection_, + SendCryptoData(ENCRYPTION_INITIAL, buffer_limit, over_limit)) + .WillOnce(Return(1)); + std::string data(buffer_limit, 'a'); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + EXPECT_TRUE(stream_->HasBufferedCryptoFrames()); + + // Write another byte that is not sent (due to there already being data in the + // buffer); send buffer is now full. + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + std::string data2(1, 'a'); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data2); + EXPECT_TRUE(stream_->HasBufferedCryptoFrames()); + + // Writing an additional byte to the send buffer closes the connection. + if (GetQuicFlag(quic_bounded_crypto_send_buffer)) { + EXPECT_CALL(*connection_, CloseConnection(QUIC_INTERNAL_ERROR, _, _)); + EXPECT_QUIC_BUG( + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data2), + "Too much data for crypto send buffer with level: ENCRYPTION_INITIAL, " + "current_buffer_size: 16384, data length: 1"); + } +} + +TEST_F(QuicCryptoStreamTest, WriteBufferedCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + EXPECT_FALSE(stream_->HasBufferedCryptoFrames()); + InSequence s; + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + // Only consumed 1000 bytes. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Return(1000)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + EXPECT_TRUE(stream_->HasBufferedCryptoFrames()); + + // Send [1350, 2700) in ENCRYPTION_ZERO_RTT and verify no write is attempted + // because there is buffered data. + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + connection_->SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(Perspective::IS_CLIENT)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + stream_->WriteCryptoData(ENCRYPTION_ZERO_RTT, data); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 350, 1000)) + .WillOnce(Return(350)); + // Partial write of ENCRYPTION_ZERO_RTT data. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1350, 0)) + .WillOnce(Return(1000)); + stream_->WriteBufferedCryptoFrames(); + EXPECT_TRUE(stream_->HasBufferedCryptoFrames()); + EXPECT_EQ(ENCRYPTION_ZERO_RTT, connection_->encryption_level()); + + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 350, 1000)) + .WillOnce(Return(350)); + stream_->WriteBufferedCryptoFrames(); + EXPECT_FALSE(stream_->HasBufferedCryptoFrames()); +} + +TEST_F(QuicCryptoStreamTest, LimitBufferedCryptoData) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + std::string large_frame(2 * GetQuicFlag(quic_max_buffered_crypto_bytes), 'a'); + + // Set offset to 1 so that we guarantee the data gets buffered instead of + // immediately processed. + QuicStreamOffset offset = 1; + stream_->OnCryptoFrame( + QuicCryptoFrame(ENCRYPTION_INITIAL, offset, large_frame)); +} + +TEST_F(QuicCryptoStreamTest, CloseConnectionWithZeroRttCryptoFrame) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + + EXPECT_CALL(*connection_, + CloseConnection(IETF_QUIC_PROTOCOL_VIOLATION, _, _)); + + test::QuicConnectionPeer::SetLastDecryptedLevel(connection_, + ENCRYPTION_ZERO_RTT); + QuicStreamOffset offset = 1; + stream_->OnCryptoFrame(QuicCryptoFrame(ENCRYPTION_ZERO_RTT, offset, "data")); +} + +TEST_F(QuicCryptoStreamTest, RetransmitCryptoFramesAndPartialWrite) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + InSequence s; + // Send [0, 1350) in ENCRYPTION_INITIAL. + EXPECT_EQ(ENCRYPTION_INITIAL, connection_->encryption_level()); + std::string data(1350, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WriteCryptoData(ENCRYPTION_INITIAL, data); + + // Lost [0, 1000). + QuicCryptoFrame lost_frame(ENCRYPTION_INITIAL, 0, 1000); + stream_->OnCryptoFrameLost(&lost_frame); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); + // Simulate connection is constrained by amplification restriction. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1000, 0)) + .WillOnce(Return(0)); + stream_->WritePendingCryptoRetransmission(); + EXPECT_TRUE(stream_->HasPendingCryptoRetransmission()); + // Connection gets unblocked. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1000, 0)) + .WillOnce(Invoke(connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + stream_->WritePendingCryptoRetransmission(); + EXPECT_FALSE(stream_->HasPendingCryptoRetransmission()); +} + +// Regression test for b/203199510 +TEST_F(QuicCryptoStreamTest, EmptyCryptoFrame) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + QuicCryptoFrame empty_crypto_frame(ENCRYPTION_INITIAL, 0, nullptr, 0); + stream_->OnCryptoFrame(empty_crypto_frame); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_data_reader.cc b/quiche/quic/core/quic_data_reader.cc new file mode 100644 index 000000000000..aa8b278ec011 --- /dev/null +++ b/quiche/quic/core/quic_data_reader.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_data_reader.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +QuicDataReader::QuicDataReader(absl::string_view data) + : quiche::QuicheDataReader(data) {} + +QuicDataReader::QuicDataReader(const char* data, const size_t len) + : QuicDataReader(data, len, quiche::NETWORK_BYTE_ORDER) {} + +QuicDataReader::QuicDataReader(const char* data, const size_t len, + quiche::Endianness endianness) + : quiche::QuicheDataReader(data, len, endianness) {} + +bool QuicDataReader::ReadUFloat16(uint64_t* result) { + uint16_t value; + if (!ReadUInt16(&value)) { + return false; + } + + *result = value; + if (*result < (1 << kUFloat16MantissaEffectiveBits)) { + // Fast path: either the value is denormalized (no hidden bit), or + // normalized (hidden bit set, exponent offset by one) with exponent zero. + // Zero exponent offset by one sets the bit exactly where the hidden bit is. + // So in both cases the value encodes itself. + return true; + } + + uint16_t exponent = + value >> kUFloat16MantissaBits; // No sign extend on uint! + // After the fast pass, the exponent is at least one (offset by one). + // Un-offset the exponent. + --exponent; + QUICHE_DCHECK_GE(exponent, 1); + QUICHE_DCHECK_LE(exponent, kUFloat16MaxExponent); + // Here we need to clear the exponent and set the hidden bit. We have already + // decremented the exponent, so when we subtract it, it leaves behind the + // hidden bit. + *result -= exponent << kUFloat16MantissaBits; + *result <<= exponent; + QUICHE_DCHECK_GE(*result, + static_cast(1 << kUFloat16MantissaEffectiveBits)); + QUICHE_DCHECK_LE(*result, kUFloat16MaxValue); + return true; +} + +bool QuicDataReader::ReadConnectionId(QuicConnectionId* connection_id, + uint8_t length) { + if (length == 0) { + connection_id->set_length(0); + return true; + } + + if (BytesRemaining() < length) { + return false; + } + + connection_id->set_length(length); + const bool ok = + ReadBytes(connection_id->mutable_data(), connection_id->length()); + QUICHE_DCHECK(ok); + return ok; +} + +bool QuicDataReader::ReadLengthPrefixedConnectionId( + QuicConnectionId* connection_id) { + uint8_t connection_id_length; + if (!ReadUInt8(&connection_id_length)) { + return false; + } + return ReadConnectionId(connection_id, connection_id_length); +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_data_reader.h b/quiche/quic/core/quic_data_reader.h new file mode 100644 index 000000000000..0a907e0208f1 --- /dev/null +++ b/quiche/quic/core/quic_data_reader.h @@ -0,0 +1,69 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_DATA_READER_H_ +#define QUICHE_QUIC_CORE_QUIC_DATA_READER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +// Used for reading QUIC data. Though there isn't really anything terribly +// QUIC-specific here, it's a helper class that's useful when doing QUIC +// framing. +// +// To use, simply construct a QuicDataReader using the underlying buffer that +// you'd like to read fields from, then call one of the Read*() methods to +// actually do some reading. +// +// This class keeps an internal iterator to keep track of what's already been +// read and each successive Read*() call automatically increments said iterator +// on success. On failure, internal state of the QuicDataReader should not be +// trusted and it is up to the caller to throw away the failed instance and +// handle the error as appropriate. None of the Read*() methods should ever be +// called after failure, as they will also fail immediately. +class QUIC_EXPORT_PRIVATE QuicDataReader : public quiche::QuicheDataReader { + public: + // Constructs a reader using NETWORK_BYTE_ORDER endianness. + // Caller must provide an underlying buffer to work on. + explicit QuicDataReader(absl::string_view data); + // Constructs a reader using NETWORK_BYTE_ORDER endianness. + // Caller must provide an underlying buffer to work on. + QuicDataReader(const char* data, const size_t len); + // Constructs a reader using the specified endianness. + // Caller must provide an underlying buffer to work on. + QuicDataReader(const char* data, const size_t len, + quiche::Endianness endianness); + QuicDataReader(const QuicDataReader&) = delete; + QuicDataReader& operator=(const QuicDataReader&) = delete; + + // Empty destructor. + ~QuicDataReader() {} + + // Reads a 16-bit unsigned float into the given output parameter. + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadUFloat16(uint64_t* result); + + // Reads connection ID into the given output parameter. + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadConnectionId(QuicConnectionId* connection_id, uint8_t length); + + // Reads 8-bit connection ID length followed by connection ID of that length. + // Forwards the internal iterator on success. + // Returns true on success, false otherwise. + bool ReadLengthPrefixedConnectionId(QuicConnectionId* connection_id); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DATA_READER_H_ diff --git a/quiche/quic/core/quic_data_writer.cc b/quiche/quic/core/quic_data_writer.cc new file mode 100644 index 000000000000..09f1923052b0 --- /dev/null +++ b/quiche/quic/core/quic_data_writer.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_data_writer.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +QuicDataWriter::QuicDataWriter(size_t size, char* buffer) + : quiche::QuicheDataWriter(size, buffer) {} + +QuicDataWriter::QuicDataWriter(size_t size, char* buffer, + quiche::Endianness endianness) + : quiche::QuicheDataWriter(size, buffer, endianness) {} + +QuicDataWriter::~QuicDataWriter() {} + +bool QuicDataWriter::WriteUFloat16(uint64_t value) { + uint16_t result; + if (value < (UINT64_C(1) << kUFloat16MantissaEffectiveBits)) { + // Fast path: either the value is denormalized, or has exponent zero. + // Both cases are represented by the value itself. + result = static_cast(value); + } else if (value >= kUFloat16MaxValue) { + // Value is out of range; clamp it to the maximum representable. + result = std::numeric_limits::max(); + } else { + // The highest bit is between position 13 and 42 (zero-based), which + // corresponds to exponent 1-30. In the output, mantissa is from 0 to 10, + // hidden bit is 11 and exponent is 11 to 15. Shift the highest bit to 11 + // and count the shifts. + uint16_t exponent = 0; + for (uint16_t offset = 16; offset > 0; offset /= 2) { + // Right-shift the value until the highest bit is in position 11. + // For offset of 16, 8, 4, 2 and 1 (binary search over 1-30), + // shift if the bit is at or above 11 + offset. + if (value >= (UINT64_C(1) << (kUFloat16MantissaBits + offset))) { + exponent += offset; + value >>= offset; + } + } + + QUICHE_DCHECK_GE(exponent, 1); + QUICHE_DCHECK_LE(exponent, kUFloat16MaxExponent); + QUICHE_DCHECK_GE(value, UINT64_C(1) << kUFloat16MantissaBits); + QUICHE_DCHECK_LT(value, UINT64_C(1) << kUFloat16MantissaEffectiveBits); + + // Hidden bit (position 11) is set. We should remove it and increment the + // exponent. Equivalently, we just add it to the exponent. + // This hides the bit. + result = static_cast(value + (exponent << kUFloat16MantissaBits)); + } + + if (endianness() == quiche::NETWORK_BYTE_ORDER) { + result = quiche::QuicheEndian::HostToNet16(result); + } + return WriteBytes(&result, sizeof(result)); +} + +bool QuicDataWriter::WriteConnectionId(QuicConnectionId connection_id) { + if (connection_id.IsEmpty()) { + return true; + } + return WriteBytes(connection_id.data(), connection_id.length()); +} + +bool QuicDataWriter::WriteLengthPrefixedConnectionId( + QuicConnectionId connection_id) { + return WriteUInt8(connection_id.length()) && WriteConnectionId(connection_id); +} + +bool QuicDataWriter::WriteRandomBytes(QuicRandom* random, size_t length) { + char* dest = BeginWrite(length); + if (!dest) { + return false; + } + + random->RandBytes(dest, length); + IncreaseLength(length); + return true; +} + +bool QuicDataWriter::WriteInsecureRandomBytes(QuicRandom* random, + size_t length) { + char* dest = BeginWrite(length); + if (!dest) { + return false; + } + + random->InsecureRandBytes(dest, length); + IncreaseLength(length); + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_data_writer.h b/quiche/quic/core/quic_data_writer.h new file mode 100644 index 000000000000..cd2486a591f7 --- /dev/null +++ b/quiche/quic/core/quic_data_writer.h @@ -0,0 +1,61 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_DATA_WRITER_H_ +#define QUICHE_QUIC_CORE_QUIC_DATA_WRITER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/quiche_data_writer.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +// This class provides facilities for packing QUIC data. +// +// The QuicDataWriter supports appending primitive values (int, string, etc) +// to a frame instance. The internal memory buffer is exposed as the "data" +// of the QuicDataWriter. +class QUIC_EXPORT_PRIVATE QuicDataWriter : public quiche::QuicheDataWriter { + public: + // Creates a QuicDataWriter where |buffer| is not owned + // using NETWORK_BYTE_ORDER endianness. + QuicDataWriter(size_t size, char* buffer); + // Creates a QuicDataWriter where |buffer| is not owned + // using the specified endianness. + QuicDataWriter(size_t size, char* buffer, quiche::Endianness endianness); + QuicDataWriter(const QuicDataWriter&) = delete; + QuicDataWriter& operator=(const QuicDataWriter&) = delete; + + ~QuicDataWriter(); + + // Methods for adding to the payload. These values are appended to the end + // of the QuicDataWriter payload. + + // Write unsigned floating point corresponding to the value. Large values are + // clamped to the maximum representable (kUFloat16MaxValue). Values that can + // not be represented directly are rounded down. + bool WriteUFloat16(uint64_t value); + // Write connection ID to the payload. + bool WriteConnectionId(QuicConnectionId connection_id); + + // Write 8-bit length followed by connection ID to the payload. + bool WriteLengthPrefixedConnectionId(QuicConnectionId connection_id); + + // Write |length| random bytes generated by |random|. + bool WriteRandomBytes(QuicRandom* random, size_t length); + + // Write |length| random bytes generated by |random|. This MUST NOT be used + // for any application that requires cryptographically-secure randomness. + bool WriteInsecureRandomBytes(QuicRandom* random, size_t length); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DATA_WRITER_H_ diff --git a/quiche/quic/core/quic_data_writer_test.cc b/quiche/quic/core/quic_data_writer_test.cc new file mode 100644 index 000000000000..9d454e93abfe --- /dev/null +++ b/quiche/quic/core/quic_data_writer_test.cc @@ -0,0 +1,874 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_data_writer.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { +namespace { + +char* AsChars(unsigned char* data) { return reinterpret_cast(data); } + +struct TestParams { + explicit TestParams(quiche::Endianness endianness) : endianness(endianness) {} + + quiche::Endianness endianness; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + (p.endianness == quiche::NETWORK_BYTE_ORDER ? "Network" : "Host"), + "ByteOrder"); +} + +std::vector GetTestParams() { + std::vector params; + for (quiche::Endianness endianness : + {quiche::NETWORK_BYTE_ORDER, quiche::HOST_BYTE_ORDER}) { + params.push_back(TestParams(endianness)); + } + return params; +} + +class QuicDataWriterTest : public QuicTestWithParam {}; + +INSTANTIATE_TEST_SUITE_P(QuicDataWriterTests, QuicDataWriterTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicDataWriterTest, SanityCheckUFloat16Consts) { + // Check the arithmetic on the constants - otherwise the values below make + // no sense. + EXPECT_EQ(30, kUFloat16MaxExponent); + EXPECT_EQ(11, kUFloat16MantissaBits); + EXPECT_EQ(12, kUFloat16MantissaEffectiveBits); + EXPECT_EQ(UINT64_C(0x3FFC0000000), kUFloat16MaxValue); +} + +TEST_P(QuicDataWriterTest, WriteUFloat16) { + struct TestCase { + uint64_t decoded; + uint16_t encoded; + }; + TestCase test_cases[] = { + // Small numbers represent themselves. + {0, 0}, + {1, 1}, + {2, 2}, + {3, 3}, + {4, 4}, + {5, 5}, + {6, 6}, + {7, 7}, + {15, 15}, + {31, 31}, + {42, 42}, + {123, 123}, + {1234, 1234}, + // Check transition through 2^11. + {2046, 2046}, + {2047, 2047}, + {2048, 2048}, + {2049, 2049}, + // Running out of mantissa at 2^12. + {4094, 4094}, + {4095, 4095}, + {4096, 4096}, + {4097, 4096}, + {4098, 4097}, + {4099, 4097}, + {4100, 4098}, + {4101, 4098}, + // Check transition through 2^13. + {8190, 6143}, + {8191, 6143}, + {8192, 6144}, + {8193, 6144}, + {8194, 6144}, + {8195, 6144}, + {8196, 6145}, + {8197, 6145}, + // Half-way through the exponents. + {0x7FF8000, 0x87FF}, + {0x7FFFFFF, 0x87FF}, + {0x8000000, 0x8800}, + {0xFFF0000, 0x8FFF}, + {0xFFFFFFF, 0x8FFF}, + {0x10000000, 0x9000}, + // Transition into the largest exponent. + {0x1FFFFFFFFFE, 0xF7FF}, + {0x1FFFFFFFFFF, 0xF7FF}, + {0x20000000000, 0xF800}, + {0x20000000001, 0xF800}, + {0x2003FFFFFFE, 0xF800}, + {0x2003FFFFFFF, 0xF800}, + {0x20040000000, 0xF801}, + {0x20040000001, 0xF801}, + // Transition into the max value and clamping. + {0x3FF80000000, 0xFFFE}, + {0x3FFBFFFFFFF, 0xFFFE}, + {0x3FFC0000000, 0xFFFF}, + {0x3FFC0000001, 0xFFFF}, + {0x3FFFFFFFFFF, 0xFFFF}, + {0x40000000000, 0xFFFF}, + {0xFFFFFFFFFFFFFFFF, 0xFFFF}, + }; + int num_test_cases = sizeof(test_cases) / sizeof(test_cases[0]); + + for (int i = 0; i < num_test_cases; ++i) { + char buffer[2]; + QuicDataWriter writer(2, buffer, GetParam().endianness); + EXPECT_TRUE(writer.WriteUFloat16(test_cases[i].decoded)); + uint16_t result = *reinterpret_cast(writer.data()); + if (GetParam().endianness == quiche::NETWORK_BYTE_ORDER) { + result = quiche::QuicheEndian::HostToNet16(result); + } + EXPECT_EQ(test_cases[i].encoded, result); + } +} + +TEST_P(QuicDataWriterTest, ReadUFloat16) { + struct TestCase { + uint64_t decoded; + uint16_t encoded; + }; + TestCase test_cases[] = { + // There are fewer decoding test cases because encoding truncates, and + // decoding returns the smallest expansion. + // Small numbers represent themselves. + {0, 0}, + {1, 1}, + {2, 2}, + {3, 3}, + {4, 4}, + {5, 5}, + {6, 6}, + {7, 7}, + {15, 15}, + {31, 31}, + {42, 42}, + {123, 123}, + {1234, 1234}, + // Check transition through 2^11. + {2046, 2046}, + {2047, 2047}, + {2048, 2048}, + {2049, 2049}, + // Running out of mantissa at 2^12. + {4094, 4094}, + {4095, 4095}, + {4096, 4096}, + {4098, 4097}, + {4100, 4098}, + // Check transition through 2^13. + {8190, 6143}, + {8192, 6144}, + {8196, 6145}, + // Half-way through the exponents. + {0x7FF8000, 0x87FF}, + {0x8000000, 0x8800}, + {0xFFF0000, 0x8FFF}, + {0x10000000, 0x9000}, + // Transition into the largest exponent. + {0x1FFE0000000, 0xF7FF}, + {0x20000000000, 0xF800}, + {0x20040000000, 0xF801}, + // Transition into the max value. + {0x3FF80000000, 0xFFFE}, + {0x3FFC0000000, 0xFFFF}, + }; + int num_test_cases = sizeof(test_cases) / sizeof(test_cases[0]); + + for (int i = 0; i < num_test_cases; ++i) { + uint16_t encoded_ufloat = test_cases[i].encoded; + if (GetParam().endianness == quiche::NETWORK_BYTE_ORDER) { + encoded_ufloat = quiche::QuicheEndian::HostToNet16(encoded_ufloat); + } + QuicDataReader reader(reinterpret_cast(&encoded_ufloat), 2, + GetParam().endianness); + uint64_t value; + EXPECT_TRUE(reader.ReadUFloat16(&value)); + EXPECT_EQ(test_cases[i].decoded, value); + } +} + +TEST_P(QuicDataWriterTest, RoundTripUFloat16) { + // Just test all 16-bit encoded values. 0 and max already tested above. + uint64_t previous_value = 0; + for (uint16_t i = 1; i < 0xFFFF; ++i) { + // Read the two bytes. + uint16_t read_number = i; + if (GetParam().endianness == quiche::NETWORK_BYTE_ORDER) { + read_number = quiche::QuicheEndian::HostToNet16(read_number); + } + QuicDataReader reader(reinterpret_cast(&read_number), 2, + GetParam().endianness); + uint64_t value; + // All values must be decodable. + EXPECT_TRUE(reader.ReadUFloat16(&value)); + // Check that small numbers represent themselves + if (i < 4097) { + EXPECT_EQ(i, value); + } + // Check there's monotonic growth. + EXPECT_LT(previous_value, value); + // Check that precision is within 0.5% away from the denormals. + if (i > 2000) { + EXPECT_GT(previous_value * 1005, value * 1000); + } + // Check we're always within the promised range. + EXPECT_LT(value, UINT64_C(0x3FFC0000000)); + previous_value = value; + char buffer[6]; + QuicDataWriter writer(6, buffer, GetParam().endianness); + EXPECT_TRUE(writer.WriteUFloat16(value - 1)); + EXPECT_TRUE(writer.WriteUFloat16(value)); + EXPECT_TRUE(writer.WriteUFloat16(value + 1)); + // Check minimal decoding (previous decoding has previous encoding). + uint16_t encoded1 = *reinterpret_cast(writer.data()); + uint16_t encoded2 = *reinterpret_cast(writer.data() + 2); + uint16_t encoded3 = *reinterpret_cast(writer.data() + 4); + if (GetParam().endianness == quiche::NETWORK_BYTE_ORDER) { + encoded1 = quiche::QuicheEndian::NetToHost16(encoded1); + encoded2 = quiche::QuicheEndian::NetToHost16(encoded2); + encoded3 = quiche::QuicheEndian::NetToHost16(encoded3); + } + EXPECT_EQ(i - 1, encoded1); + // Check roundtrip. + EXPECT_EQ(i, encoded2); + // Check next decoding. + EXPECT_EQ(i < 4096 ? i + 1 : i, encoded3); + } +} + +TEST_P(QuicDataWriterTest, WriteConnectionId) { + QuicConnectionId connection_id = + TestConnectionId(UINT64_C(0x0011223344556677)); + char big_endian[] = { + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }; + EXPECT_EQ(connection_id.length(), ABSL_ARRAYSIZE(big_endian)); + ASSERT_LE(connection_id.length(), 255); + char buffer[255]; + QuicDataWriter writer(connection_id.length(), buffer, GetParam().endianness); + EXPECT_TRUE(writer.WriteConnectionId(connection_id)); + quiche::test::CompareCharArraysWithHexError( + "connection_id", buffer, connection_id.length(), big_endian, + connection_id.length()); + + QuicConnectionId read_connection_id; + QuicDataReader reader(buffer, connection_id.length(), GetParam().endianness); + EXPECT_TRUE( + reader.ReadConnectionId(&read_connection_id, ABSL_ARRAYSIZE(big_endian))); + EXPECT_EQ(connection_id, read_connection_id); +} + +TEST_P(QuicDataWriterTest, LengthPrefixedConnectionId) { + QuicConnectionId connection_id = + TestConnectionId(UINT64_C(0x0011223344556677)); + char length_prefixed_connection_id[] = { + 0x08, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }; + EXPECT_EQ(ABSL_ARRAYSIZE(length_prefixed_connection_id), + kConnectionIdLengthSize + connection_id.length()); + char buffer[kConnectionIdLengthSize + 255] = {}; + QuicDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer); + EXPECT_TRUE(writer.WriteLengthPrefixedConnectionId(connection_id)); + quiche::test::CompareCharArraysWithHexError( + "WriteLengthPrefixedConnectionId", buffer, writer.length(), + length_prefixed_connection_id, + ABSL_ARRAYSIZE(length_prefixed_connection_id)); + + // Verify that writing length then connection ID produces the same output. + memset(buffer, 0, ABSL_ARRAYSIZE(buffer)); + QuicDataWriter writer2(ABSL_ARRAYSIZE(buffer), buffer); + EXPECT_TRUE(writer2.WriteUInt8(connection_id.length())); + EXPECT_TRUE(writer2.WriteConnectionId(connection_id)); + quiche::test::CompareCharArraysWithHexError( + "Write length then ConnectionId", buffer, writer2.length(), + length_prefixed_connection_id, + ABSL_ARRAYSIZE(length_prefixed_connection_id)); + + QuicConnectionId read_connection_id; + QuicDataReader reader(buffer, ABSL_ARRAYSIZE(buffer)); + EXPECT_TRUE(reader.ReadLengthPrefixedConnectionId(&read_connection_id)); + EXPECT_EQ(connection_id, read_connection_id); + + // Verify that reading length then connection ID produces the same output. + uint8_t read_connection_id_length2 = 33; + QuicConnectionId read_connection_id2; + QuicDataReader reader2(buffer, ABSL_ARRAYSIZE(buffer)); + ASSERT_TRUE(reader2.ReadUInt8(&read_connection_id_length2)); + EXPECT_EQ(connection_id.length(), read_connection_id_length2); + EXPECT_TRUE(reader2.ReadConnectionId(&read_connection_id2, + read_connection_id_length2)); + EXPECT_EQ(connection_id, read_connection_id2); +} + +TEST_P(QuicDataWriterTest, EmptyConnectionIds) { + QuicConnectionId empty_connection_id = EmptyQuicConnectionId(); + char buffer[2]; + QuicDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, GetParam().endianness); + EXPECT_TRUE(writer.WriteConnectionId(empty_connection_id)); + EXPECT_TRUE(writer.WriteUInt8(1)); + EXPECT_TRUE(writer.WriteConnectionId(empty_connection_id)); + EXPECT_TRUE(writer.WriteUInt8(2)); + EXPECT_TRUE(writer.WriteConnectionId(empty_connection_id)); + EXPECT_FALSE(writer.WriteUInt8(3)); + + EXPECT_EQ(buffer[0], 1); + EXPECT_EQ(buffer[1], 2); + + QuicConnectionId read_connection_id = TestConnectionId(); + uint8_t read_byte; + QuicDataReader reader(buffer, ABSL_ARRAYSIZE(buffer), GetParam().endianness); + EXPECT_TRUE(reader.ReadConnectionId(&read_connection_id, 0)); + EXPECT_EQ(read_connection_id, empty_connection_id); + EXPECT_TRUE(reader.ReadUInt8(&read_byte)); + EXPECT_EQ(read_byte, 1); + // Reset read_connection_id to something else to verify that + // ReadConnectionId properly sets it back to empty. + read_connection_id = TestConnectionId(); + EXPECT_TRUE(reader.ReadConnectionId(&read_connection_id, 0)); + EXPECT_EQ(read_connection_id, empty_connection_id); + EXPECT_TRUE(reader.ReadUInt8(&read_byte)); + EXPECT_EQ(read_byte, 2); + read_connection_id = TestConnectionId(); + EXPECT_TRUE(reader.ReadConnectionId(&read_connection_id, 0)); + EXPECT_EQ(read_connection_id, empty_connection_id); + EXPECT_FALSE(reader.ReadUInt8(&read_byte)); +} + +TEST_P(QuicDataWriterTest, WriteTag) { + char CHLO[] = { + 'C', + 'H', + 'L', + 'O', + }; + const int kBufferLength = sizeof(QuicTag); + char buffer[kBufferLength]; + QuicDataWriter writer(kBufferLength, buffer, GetParam().endianness); + writer.WriteTag(kCHLO); + quiche::test::CompareCharArraysWithHexError("CHLO", buffer, kBufferLength, + CHLO, kBufferLength); + + QuicTag read_chlo; + QuicDataReader reader(buffer, kBufferLength, GetParam().endianness); + reader.ReadTag(&read_chlo); + EXPECT_EQ(kCHLO, read_chlo); +} + +TEST_P(QuicDataWriterTest, Write16BitUnsignedIntegers) { + char little_endian16[] = {0x22, 0x11}; + char big_endian16[] = {0x11, 0x22}; + char buffer16[2]; + { + uint16_t in_memory16 = 0x1122; + QuicDataWriter writer(2, buffer16, GetParam().endianness); + writer.WriteUInt16(in_memory16); + quiche::test::CompareCharArraysWithHexError( + "uint16_t", buffer16, 2, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian16 + : little_endian16, + 2); + + uint16_t read_number16; + QuicDataReader reader(buffer16, 2, GetParam().endianness); + reader.ReadUInt16(&read_number16); + EXPECT_EQ(in_memory16, read_number16); + } + + { + uint64_t in_memory16 = 0x0000000000001122; + QuicDataWriter writer(2, buffer16, GetParam().endianness); + writer.WriteBytesToUInt64(2, in_memory16); + quiche::test::CompareCharArraysWithHexError( + "uint16_t", buffer16, 2, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian16 + : little_endian16, + 2); + + uint64_t read_number16; + QuicDataReader reader(buffer16, 2, GetParam().endianness); + reader.ReadBytesToUInt64(2, &read_number16); + EXPECT_EQ(in_memory16, read_number16); + } +} + +TEST_P(QuicDataWriterTest, Write24BitUnsignedIntegers) { + char little_endian24[] = {0x33, 0x22, 0x11}; + char big_endian24[] = {0x11, 0x22, 0x33}; + char buffer24[3]; + uint64_t in_memory24 = 0x0000000000112233; + QuicDataWriter writer(3, buffer24, GetParam().endianness); + writer.WriteBytesToUInt64(3, in_memory24); + quiche::test::CompareCharArraysWithHexError( + "uint24", buffer24, 3, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian24 + : little_endian24, + 3); + + uint64_t read_number24; + QuicDataReader reader(buffer24, 3, GetParam().endianness); + reader.ReadBytesToUInt64(3, &read_number24); + EXPECT_EQ(in_memory24, read_number24); +} + +TEST_P(QuicDataWriterTest, Write32BitUnsignedIntegers) { + char little_endian32[] = {0x44, 0x33, 0x22, 0x11}; + char big_endian32[] = {0x11, 0x22, 0x33, 0x44}; + char buffer32[4]; + { + uint32_t in_memory32 = 0x11223344; + QuicDataWriter writer(4, buffer32, GetParam().endianness); + writer.WriteUInt32(in_memory32); + quiche::test::CompareCharArraysWithHexError( + "uint32_t", buffer32, 4, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian32 + : little_endian32, + 4); + + uint32_t read_number32; + QuicDataReader reader(buffer32, 4, GetParam().endianness); + reader.ReadUInt32(&read_number32); + EXPECT_EQ(in_memory32, read_number32); + } + + { + uint64_t in_memory32 = 0x11223344; + QuicDataWriter writer(4, buffer32, GetParam().endianness); + writer.WriteBytesToUInt64(4, in_memory32); + quiche::test::CompareCharArraysWithHexError( + "uint32_t", buffer32, 4, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian32 + : little_endian32, + 4); + + uint64_t read_number32; + QuicDataReader reader(buffer32, 4, GetParam().endianness); + reader.ReadBytesToUInt64(4, &read_number32); + EXPECT_EQ(in_memory32, read_number32); + } +} + +TEST_P(QuicDataWriterTest, Write40BitUnsignedIntegers) { + uint64_t in_memory40 = 0x0000001122334455; + char little_endian40[] = {0x55, 0x44, 0x33, 0x22, 0x11}; + char big_endian40[] = {0x11, 0x22, 0x33, 0x44, 0x55}; + char buffer40[5]; + QuicDataWriter writer(5, buffer40, GetParam().endianness); + writer.WriteBytesToUInt64(5, in_memory40); + quiche::test::CompareCharArraysWithHexError( + "uint40", buffer40, 5, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian40 + : little_endian40, + 5); + + uint64_t read_number40; + QuicDataReader reader(buffer40, 5, GetParam().endianness); + reader.ReadBytesToUInt64(5, &read_number40); + EXPECT_EQ(in_memory40, read_number40); +} + +TEST_P(QuicDataWriterTest, Write48BitUnsignedIntegers) { + uint64_t in_memory48 = 0x0000112233445566; + char little_endian48[] = {0x66, 0x55, 0x44, 0x33, 0x22, 0x11}; + char big_endian48[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}; + char buffer48[6]; + QuicDataWriter writer(6, buffer48, GetParam().endianness); + writer.WriteBytesToUInt64(6, in_memory48); + quiche::test::CompareCharArraysWithHexError( + "uint48", buffer48, 6, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian48 + : little_endian48, + 6); + + uint64_t read_number48; + QuicDataReader reader(buffer48, 6, GetParam().endianness); + reader.ReadBytesToUInt64(6., &read_number48); + EXPECT_EQ(in_memory48, read_number48); +} + +TEST_P(QuicDataWriterTest, Write56BitUnsignedIntegers) { + uint64_t in_memory56 = 0x0011223344556677; + char little_endian56[] = {0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11}; + char big_endian56[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; + char buffer56[7]; + QuicDataWriter writer(7, buffer56, GetParam().endianness); + writer.WriteBytesToUInt64(7, in_memory56); + quiche::test::CompareCharArraysWithHexError( + "uint56", buffer56, 7, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER ? big_endian56 + : little_endian56, + 7); + + uint64_t read_number56; + QuicDataReader reader(buffer56, 7, GetParam().endianness); + reader.ReadBytesToUInt64(7, &read_number56); + EXPECT_EQ(in_memory56, read_number56); +} + +TEST_P(QuicDataWriterTest, Write64BitUnsignedIntegers) { + uint64_t in_memory64 = 0x1122334455667788; + unsigned char little_endian64[] = {0x88, 0x77, 0x66, 0x55, + 0x44, 0x33, 0x22, 0x11}; + unsigned char big_endian64[] = {0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88}; + char buffer64[8]; + QuicDataWriter writer(8, buffer64, GetParam().endianness); + writer.WriteBytesToUInt64(8, in_memory64); + quiche::test::CompareCharArraysWithHexError( + "uint64_t", buffer64, 8, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER + ? AsChars(big_endian64) + : AsChars(little_endian64), + 8); + + uint64_t read_number64; + QuicDataReader reader(buffer64, 8, GetParam().endianness); + reader.ReadBytesToUInt64(8, &read_number64); + EXPECT_EQ(in_memory64, read_number64); + + QuicDataWriter writer2(8, buffer64, GetParam().endianness); + writer2.WriteUInt64(in_memory64); + quiche::test::CompareCharArraysWithHexError( + "uint64_t", buffer64, 8, + GetParam().endianness == quiche::NETWORK_BYTE_ORDER + ? AsChars(big_endian64) + : AsChars(little_endian64), + 8); + read_number64 = 0u; + QuicDataReader reader2(buffer64, 8, GetParam().endianness); + reader2.ReadUInt64(&read_number64); + EXPECT_EQ(in_memory64, read_number64); +} + +TEST_P(QuicDataWriterTest, WriteIntegers) { + char buf[43]; + uint8_t i8 = 0x01; + uint16_t i16 = 0x0123; + uint32_t i32 = 0x01234567; + uint64_t i64 = 0x0123456789ABCDEF; + QuicDataWriter writer(46, buf, GetParam().endianness); + for (size_t i = 0; i < 10; ++i) { + switch (i) { + case 0u: + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 1u: + EXPECT_TRUE(writer.WriteUInt8(i8)); + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 2u: + EXPECT_TRUE(writer.WriteUInt16(i16)); + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 3u: + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 4u: + EXPECT_TRUE(writer.WriteUInt32(i32)); + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + case 5u: + case 6u: + case 7u: + case 8u: + EXPECT_TRUE(writer.WriteBytesToUInt64(i, i64)); + break; + default: + EXPECT_FALSE(writer.WriteBytesToUInt64(i, i64)); + } + } + + QuicDataReader reader(buf, 46, GetParam().endianness); + for (size_t i = 0; i < 10; ++i) { + uint8_t read8; + uint16_t read16; + uint32_t read32; + uint64_t read64; + switch (i) { + case 0u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0u, read64); + break; + case 1u: + EXPECT_TRUE(reader.ReadUInt8(&read8)); + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(i8, read8); + EXPECT_EQ(0xEFu, read64); + break; + case 2u: + EXPECT_TRUE(reader.ReadUInt16(&read16)); + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(i16, read16); + EXPECT_EQ(0xCDEFu, read64); + break; + case 3u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0xABCDEFu, read64); + break; + case 4u: + EXPECT_TRUE(reader.ReadUInt32(&read32)); + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(i32, read32); + EXPECT_EQ(0x89ABCDEFu, read64); + break; + case 5u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x6789ABCDEFu, read64); + break; + case 6u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x456789ABCDEFu, read64); + break; + case 7u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x23456789ABCDEFu, read64); + break; + case 8u: + EXPECT_TRUE(reader.ReadBytesToUInt64(i, &read64)); + EXPECT_EQ(0x0123456789ABCDEFu, read64); + break; + default: + EXPECT_FALSE(reader.ReadBytesToUInt64(i, &read64)); + } + } +} + +TEST_P(QuicDataWriterTest, WriteBytes) { + char bytes[] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + char buf[ABSL_ARRAYSIZE(bytes)]; + QuicDataWriter writer(ABSL_ARRAYSIZE(buf), buf, GetParam().endianness); + EXPECT_TRUE(writer.WriteBytes(bytes, ABSL_ARRAYSIZE(bytes))); + for (unsigned int i = 0; i < ABSL_ARRAYSIZE(bytes); ++i) { + EXPECT_EQ(bytes[i], buf[i]); + } +} + +// Following tests all try to fill the buffer with multiple values, +// go one value more than the buffer can accommodate, then read +// the successfully encoded values, and try to read the unsuccessfully +// encoded value. The following is the number of values to encode. +const int kMultiVarCount = 1000; + +// Test encoding/decoding stream-id values. +void EncodeDecodeStreamId(uint64_t value_in) { + char buffer[1 * kMultiVarCount]; + memset(buffer, 0, sizeof(buffer)); + + // Encode the given Stream ID. + QuicDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer.WriteVarInt62(value_in)); + + QuicDataReader reader(buffer, sizeof(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + QuicStreamId received_stream_id; + uint64_t temp; + EXPECT_TRUE(reader.ReadVarInt62(&temp)); + received_stream_id = static_cast(temp); + EXPECT_EQ(value_in, received_stream_id); +} + +// Test writing & reading stream-ids of various value. +TEST_P(QuicDataWriterTest, StreamId1) { + // Check a 1-byte QuicStreamId, should work + EncodeDecodeStreamId(UINT64_C(0x15)); + + // Check a 2-byte QuicStream ID. It should work. + EncodeDecodeStreamId(UINT64_C(0x1567)); + + // Check a QuicStreamId that requires 4 bytes of encoding + // This should work. + EncodeDecodeStreamId(UINT64_C(0x34567890)); + + // Check a QuicStreamId that requires 8 bytes of encoding + // but whose value is in the acceptable range. + // This should work. + EncodeDecodeStreamId(UINT64_C(0xf4567890)); +} + +TEST_P(QuicDataWriterTest, WriteRandomBytes) { + char buffer[20]; + char expected[20]; + for (size_t i = 0; i < 20; ++i) { + expected[i] = 'r'; + } + MockRandom random; + QuicDataWriter writer(20, buffer, GetParam().endianness); + EXPECT_FALSE(writer.WriteRandomBytes(&random, 30)); + + EXPECT_TRUE(writer.WriteRandomBytes(&random, 20)); + quiche::test::CompareCharArraysWithHexError("random", buffer, 20, expected, + 20); +} + +TEST_P(QuicDataWriterTest, WriteInsecureRandomBytes) { + char buffer[20]; + char expected[20]; + for (size_t i = 0; i < 20; ++i) { + expected[i] = 'r'; + } + MockRandom random; + QuicDataWriter writer(20, buffer, GetParam().endianness); + EXPECT_FALSE(writer.WriteInsecureRandomBytes(&random, 30)); + + EXPECT_TRUE(writer.WriteInsecureRandomBytes(&random, 20)); + quiche::test::CompareCharArraysWithHexError("random", buffer, 20, expected, + 20); +} + +TEST_P(QuicDataWriterTest, PeekVarInt62Length) { + // In range [0, 63], variable length should be 1 byte. + char buffer[20]; + QuicDataWriter writer(20, buffer, quiche::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer.WriteVarInt62(50)); + QuicDataReader reader(buffer, 20, quiche::NETWORK_BYTE_ORDER); + EXPECT_EQ(1, reader.PeekVarInt62Length()); + // In range (63-16383], variable length should be 2 byte2. + char buffer2[20]; + QuicDataWriter writer2(20, buffer2, quiche::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer2.WriteVarInt62(100)); + QuicDataReader reader2(buffer2, 20, quiche::NETWORK_BYTE_ORDER); + EXPECT_EQ(2, reader2.PeekVarInt62Length()); + // In range (16383, 1073741823], variable length should be 4 bytes. + char buffer3[20]; + QuicDataWriter writer3(20, buffer3, quiche::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer3.WriteVarInt62(20000)); + QuicDataReader reader3(buffer3, 20, quiche::NETWORK_BYTE_ORDER); + EXPECT_EQ(4, reader3.PeekVarInt62Length()); + // In range (1073741823, 4611686018427387903], variable length should be 8 + // bytes. + char buffer4[20]; + QuicDataWriter writer4(20, buffer4, quiche::NETWORK_BYTE_ORDER); + EXPECT_TRUE(writer4.WriteVarInt62(2000000000)); + QuicDataReader reader4(buffer4, 20, quiche::NETWORK_BYTE_ORDER); + EXPECT_EQ(8, reader4.PeekVarInt62Length()); +} + +TEST_P(QuicDataWriterTest, ValidStreamCount) { + char buffer[1024]; + memset(buffer, 0, sizeof(buffer)); + QuicDataWriter writer(sizeof(buffer), static_cast(buffer), + quiche::Endianness::NETWORK_BYTE_ORDER); + QuicDataReader reader(buffer, sizeof(buffer)); + const QuicStreamCount write_stream_count = 0xffeeddcc; + EXPECT_TRUE(writer.WriteVarInt62(write_stream_count)); + QuicStreamCount read_stream_count; + uint64_t temp; + EXPECT_TRUE(reader.ReadVarInt62(&temp)); + read_stream_count = static_cast(temp); + EXPECT_EQ(write_stream_count, read_stream_count); +} + +TEST_P(QuicDataWriterTest, Seek) { + char buffer[3] = {}; + QuicDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, GetParam().endianness); + EXPECT_TRUE(writer.WriteUInt8(42)); + EXPECT_TRUE(writer.Seek(1)); + EXPECT_TRUE(writer.WriteUInt8(3)); + + char expected[] = {42, 0, 3}; + for (size_t i = 0; i < ABSL_ARRAYSIZE(expected); ++i) { + EXPECT_EQ(buffer[i], expected[i]); + } +} + +TEST_P(QuicDataWriterTest, SeekTooFarFails) { + char buffer[20]; + + // Check that one can seek to the end of the writer, but not past. + { + QuicDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.Seek(20)); + EXPECT_FALSE(writer.Seek(1)); + } + + // Seeking several bytes past the end fails. + { + QuicDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_FALSE(writer.Seek(100)); + } + + // Seeking so far that arithmetic overflow could occur also fails. + { + QuicDataWriter writer(ABSL_ARRAYSIZE(buffer), buffer, + GetParam().endianness); + EXPECT_TRUE(writer.Seek(10)); + EXPECT_FALSE(writer.Seek(std::numeric_limits::max())); + } +} + +TEST_P(QuicDataWriterTest, PayloadReads) { + char buffer[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + char expected_first_read[4] = {1, 2, 3, 4}; + char expected_remaining[12] = {5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + QuicDataReader reader(buffer, sizeof(buffer)); + char first_read_buffer[4] = {}; + EXPECT_TRUE(reader.ReadBytes(first_read_buffer, sizeof(first_read_buffer))); + quiche::test::CompareCharArraysWithHexError( + "first read", first_read_buffer, sizeof(first_read_buffer), + expected_first_read, sizeof(expected_first_read)); + absl::string_view peeked_remaining_payload = reader.PeekRemainingPayload(); + quiche::test::CompareCharArraysWithHexError( + "peeked_remaining_payload", peeked_remaining_payload.data(), + peeked_remaining_payload.length(), expected_remaining, + sizeof(expected_remaining)); + absl::string_view full_payload = reader.FullPayload(); + quiche::test::CompareCharArraysWithHexError( + "full_payload", full_payload.data(), full_payload.length(), buffer, + sizeof(buffer)); + absl::string_view read_remaining_payload = reader.ReadRemainingPayload(); + quiche::test::CompareCharArraysWithHexError( + "read_remaining_payload", read_remaining_payload.data(), + read_remaining_payload.length(), expected_remaining, + sizeof(expected_remaining)); + EXPECT_TRUE(reader.IsDoneReading()); + absl::string_view full_payload2 = reader.FullPayload(); + quiche::test::CompareCharArraysWithHexError( + "full_payload2", full_payload2.data(), full_payload2.length(), buffer, + sizeof(buffer)); +} + +TEST_P(QuicDataWriterTest, StringPieceVarInt62) { + char inner_buffer[16] = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + absl::string_view inner_payload_write(inner_buffer, sizeof(inner_buffer)); + char buffer[sizeof(inner_buffer) + sizeof(uint8_t)] = {}; + QuicDataWriter writer(sizeof(buffer), buffer); + EXPECT_TRUE(writer.WriteStringPieceVarInt62(inner_payload_write)); + EXPECT_EQ(0u, writer.remaining()); + QuicDataReader reader(buffer, sizeof(buffer)); + absl::string_view inner_payload_read; + EXPECT_TRUE(reader.ReadStringPieceVarInt62(&inner_payload_read)); + quiche::test::CompareCharArraysWithHexError( + "inner_payload", inner_payload_write.data(), inner_payload_write.length(), + inner_payload_read.data(), inner_payload_read.length()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_datagram_queue.cc b/quiche/quic/core/quic_datagram_queue.cc new file mode 100644 index 000000000000..97965873771e --- /dev/null +++ b/quiche/quic/core/quic_datagram_queue.cc @@ -0,0 +1,102 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_datagram_queue.h" + +#include "absl/types/span.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +constexpr float kExpiryInMinRtts = 1.25; +constexpr float kMinPacingWindows = 4; + +QuicDatagramQueue::QuicDatagramQueue(QuicSession* session) + : QuicDatagramQueue(session, nullptr) {} + +QuicDatagramQueue::QuicDatagramQueue(QuicSession* session, + std::unique_ptr observer) + : session_(session), + clock_(session->connection()->clock()), + observer_(std::move(observer)), + force_flush_(false) {} + +MessageStatus QuicDatagramQueue::SendOrQueueDatagram( + quiche::QuicheMemSlice datagram) { + // If the queue is non-empty, always queue the daragram. This ensures that + // the datagrams are sent in the same order that they were sent by the + // application. + if (queue_.empty()) { + MessageResult result = session_->SendMessage(absl::MakeSpan(&datagram, 1), + /*flush=*/force_flush_); + if (result.status != MESSAGE_STATUS_BLOCKED) { + if (observer_) { + observer_->OnDatagramProcessed(result.status); + } + return result.status; + } + } + + queue_.emplace_back(Datagram{std::move(datagram), + clock_->ApproximateNow() + GetMaxTimeInQueue()}); + return MESSAGE_STATUS_BLOCKED; +} + +absl::optional QuicDatagramQueue::TrySendingNextDatagram() { + RemoveExpiredDatagrams(); + if (queue_.empty()) { + return absl::nullopt; + } + + MessageResult result = + session_->SendMessage(absl::MakeSpan(&queue_.front().datagram, 1)); + if (result.status != MESSAGE_STATUS_BLOCKED) { + queue_.pop_front(); + if (observer_) { + observer_->OnDatagramProcessed(result.status); + } + } + return result.status; +} + +size_t QuicDatagramQueue::SendDatagrams() { + size_t num_datagrams = 0; + for (;;) { + absl::optional status = TrySendingNextDatagram(); + if (!status.has_value()) { + break; + } + if (*status == MESSAGE_STATUS_BLOCKED) { + break; + } + num_datagrams++; + } + return num_datagrams; +} + +QuicTime::Delta QuicDatagramQueue::GetMaxTimeInQueue() const { + if (!max_time_in_queue_.IsZero()) { + return max_time_in_queue_; + } + + const QuicTime::Delta min_rtt = + session_->connection()->sent_packet_manager().GetRttStats()->min_rtt(); + return std::max(kExpiryInMinRtts * min_rtt, + kMinPacingWindows * kAlarmGranularity); +} + +void QuicDatagramQueue::RemoveExpiredDatagrams() { + QuicTime now = clock_->ApproximateNow(); + while (!queue_.empty() && queue_.front().expiry <= now) { + queue_.pop_front(); + if (observer_) { + observer_->OnDatagramProcessed(absl::nullopt); + } + } +} + +} // namespace quic diff --git a/quiche/quic/core/quic_datagram_queue.h b/quiche/quic/core/quic_datagram_queue.h new file mode 100644 index 000000000000..851cd0a9253d --- /dev/null +++ b/quiche/quic/core/quic_datagram_queue.h @@ -0,0 +1,95 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_DATAGRAM_QUEUE_H_ +#define QUICHE_QUIC_CORE_QUIC_DATAGRAM_QUEUE_H_ + +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +class QuicSession; + +// Provides a way to buffer QUIC datagrams (messages) in case they cannot +// be sent due to congestion control. Datagrams are buffered for a limited +// amount of time, and deleted after that time passes. +class QUIC_EXPORT_PRIVATE QuicDatagramQueue { + public: + // An interface used to monitor events on the associated `QuicDatagramQueue`. + class QUIC_EXPORT_PRIVATE Observer { + public: + virtual ~Observer() = default; + + // Called when a datagram in the associated queue is sent or discarded. + // Identity information for the datagram is not given, because the sending + // and discarding order is always first-in-first-out. + // This function is called synchronously in `QuicDatagramQueue` methods. + // `status` is nullopt when the datagram is dropped due to being in the + // queue for too long. + virtual void OnDatagramProcessed(absl::optional status) = 0; + }; + + // |session| is not owned and must outlive this object. + explicit QuicDatagramQueue(QuicSession* session); + + // |session| is not owned and must outlive this object. + QuicDatagramQueue(QuicSession* session, std::unique_ptr observer); + + // Adds the datagram to the end of the queue. May send it immediately; if + // not, MESSAGE_STATUS_BLOCKED is returned. + MessageStatus SendOrQueueDatagram(quiche::QuicheMemSlice datagram); + + // Attempts to send a single datagram from the queue. Returns the result of + // SendMessage(), or nullopt if there were no unexpired datagrams to send. + absl::optional TrySendingNextDatagram(); + + // Sends all of the unexpired datagrams until either the connection becomes + // write-blocked or the queue is empty. Returns the number of datagrams sent. + size_t SendDatagrams(); + + // Returns the amount of time a datagram is allowed to be in the queue before + // it is dropped. If not set explicitly using SetMaxTimeInQueue(), an + // RTT-based heuristic is used. + QuicTime::Delta GetMaxTimeInQueue() const; + + void SetMaxTimeInQueue(QuicTime::Delta max_time_in_queue) { + max_time_in_queue_ = max_time_in_queue; + } + + // If set to true, all datagrams added into the queue would be sent with the + // flush flag set to true, meaning that they will bypass congestion control + // and related logic. + void SetForceFlush(bool force_flush) { force_flush_ = force_flush; } + + size_t queue_size() { return queue_.size(); } + + bool empty() { return queue_.empty(); } + + private: + struct QUIC_EXPORT_PRIVATE Datagram { + quiche::QuicheMemSlice datagram; + QuicTime expiry; + }; + + // Removes expired datagrams from the front of the queue. + void RemoveExpiredDatagrams(); + + QuicSession* session_; // Not owned. + const QuicClock* clock_; + + QuicTime::Delta max_time_in_queue_ = QuicTime::Delta::Zero(); + quiche::QuicheCircularDeque queue_; + std::unique_ptr observer_; + bool force_flush_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DATAGRAM_QUEUE_H_ diff --git a/quiche/quic/core/quic_datagram_queue_test.cc b/quiche/quic/core/quic_datagram_queue_test.cc new file mode 100644 index 000000000000..7895941fb78a --- /dev/null +++ b/quiche/quic/core/quic_datagram_queue_test.cc @@ -0,0 +1,297 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_datagram_queue.h" + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" +#include "quiche/common/quiche_buffer_allocator.h" + +namespace quic { +namespace test { +namespace { + +using testing::_; +using testing::ElementsAre; +using testing::Return; + +class EstablishedCryptoStream : public MockQuicCryptoStream { + public: + using MockQuicCryptoStream::MockQuicCryptoStream; + + bool encryption_established() const override { return true; } +}; + +class QuicDatagramQueueObserver final : public QuicDatagramQueue::Observer { + public: + class Context : public quiche::QuicheReferenceCounted { + public: + std::vector> statuses; + }; + + QuicDatagramQueueObserver() : context_(new Context()) {} + QuicDatagramQueueObserver(const QuicDatagramQueueObserver&) = delete; + QuicDatagramQueueObserver& operator=(const QuicDatagramQueueObserver&) = + delete; + + void OnDatagramProcessed(absl::optional status) override { + context_->statuses.push_back(std::move(status)); + } + + const quiche::QuicheReferenceCountedPointer& context() { + return context_; + } + + private: + quiche::QuicheReferenceCountedPointer context_; +}; + +class QuicDatagramQueueTestBase : public QuicTest { + protected: + QuicDatagramQueueTestBase() + : connection_(new MockQuicConnection(&helper_, &alarm_factory_, + Perspective::IS_CLIENT)), + session_(connection_) { + session_.SetCryptoStream(new EstablishedCryptoStream(&session_)); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + } + + ~QuicDatagramQueueTestBase() = default; + + quiche::QuicheMemSlice CreateMemSlice(absl::string_view data) { + return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy( + helper_.GetStreamSendBufferAllocator(), data)); + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnection* connection_; // Owned by |session_|. + MockQuicSession session_; +}; + +class QuicDatagramQueueTest : public QuicDatagramQueueTestBase { + public: + QuicDatagramQueueTest() : queue_(&session_) {} + + protected: + QuicDatagramQueue queue_; +}; + +TEST_F(QuicDatagramQueueTest, SendDatagramImmediately) { + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + MessageStatus status = queue_.SendOrQueueDatagram(CreateMemSlice("test")); + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, status); + EXPECT_EQ(0u, queue_.queue_size()); +} + +TEST_F(QuicDatagramQueueTest, SendDatagramAfterBuffering) { + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + MessageStatus initial_status = + queue_.SendOrQueueDatagram(CreateMemSlice("test")); + EXPECT_EQ(MESSAGE_STATUS_BLOCKED, initial_status); + EXPECT_EQ(1u, queue_.queue_size()); + + // Verify getting write blocked does not remove the datagram from the queue. + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + absl::optional status = queue_.TrySendingNextDatagram(); + ASSERT_TRUE(status.has_value()); + EXPECT_EQ(MESSAGE_STATUS_BLOCKED, *status); + EXPECT_EQ(1u, queue_.queue_size()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + status = queue_.TrySendingNextDatagram(); + ASSERT_TRUE(status.has_value()); + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, *status); + EXPECT_EQ(0u, queue_.queue_size()); +} + +TEST_F(QuicDatagramQueueTest, EmptyBuffer) { + absl::optional status = queue_.TrySendingNextDatagram(); + EXPECT_FALSE(status.has_value()); + + size_t num_messages = queue_.SendDatagrams(); + EXPECT_EQ(0u, num_messages); +} + +TEST_F(QuicDatagramQueueTest, MultipleDatagrams) { + // Note that SendMessage() is called only once here, since all the remaining + // messages are automatically queued due to the queue being non-empty. + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + queue_.SendOrQueueDatagram(CreateMemSlice("a")); + queue_.SendOrQueueDatagram(CreateMemSlice("b")); + queue_.SendOrQueueDatagram(CreateMemSlice("c")); + queue_.SendOrQueueDatagram(CreateMemSlice("d")); + queue_.SendOrQueueDatagram(CreateMemSlice("e")); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .Times(5) + .WillRepeatedly(Return(MESSAGE_STATUS_SUCCESS)); + size_t num_messages = queue_.SendDatagrams(); + EXPECT_EQ(5u, num_messages); +} + +TEST_F(QuicDatagramQueueTest, DefaultMaxTimeInQueue) { + EXPECT_EQ(QuicTime::Delta::Zero(), + connection_->sent_packet_manager().GetRttStats()->min_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(4), queue_.GetMaxTimeInQueue()); + + RttStats* stats = + const_cast(connection_->sent_packet_manager().GetRttStats()); + stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), helper_.GetClock()->Now()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(125), queue_.GetMaxTimeInQueue()); +} + +TEST_F(QuicDatagramQueueTest, Expiry) { + constexpr QuicTime::Delta expiry = QuicTime::Delta::FromMilliseconds(100); + queue_.SetMaxTimeInQueue(expiry); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + queue_.SendOrQueueDatagram(CreateMemSlice("a")); + helper_.AdvanceTime(0.6 * expiry); + queue_.SendOrQueueDatagram(CreateMemSlice("b")); + helper_.AdvanceTime(0.6 * expiry); + queue_.SendOrQueueDatagram(CreateMemSlice("c")); + + std::vector messages; + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillRepeatedly([&messages](QuicMessageId /*id*/, + absl::Span message, + bool /*flush*/) { + messages.push_back(std::string(message[0].AsStringView())); + return MESSAGE_STATUS_SUCCESS; + }); + EXPECT_EQ(2u, queue_.SendDatagrams()); + EXPECT_THAT(messages, ElementsAre("b", "c")); +} + +TEST_F(QuicDatagramQueueTest, ExpireAll) { + constexpr QuicTime::Delta expiry = QuicTime::Delta::FromMilliseconds(100); + queue_.SetMaxTimeInQueue(expiry); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + queue_.SendOrQueueDatagram(CreateMemSlice("a")); + queue_.SendOrQueueDatagram(CreateMemSlice("b")); + queue_.SendOrQueueDatagram(CreateMemSlice("c")); + + helper_.AdvanceTime(100 * expiry); + EXPECT_CALL(*connection_, SendMessage(_, _, _)).Times(0); + EXPECT_EQ(0u, queue_.SendDatagrams()); +} + +class QuicDatagramQueueWithObserverTest : public QuicDatagramQueueTestBase { + public: + QuicDatagramQueueWithObserverTest() + : observer_(std::make_unique()), + context_(observer_->context()), + queue_(&session_, std::move(observer_)) {} + + protected: + // This is moved out immediately. + std::unique_ptr observer_; + + quiche::QuicheReferenceCountedPointer + context_; + QuicDatagramQueue queue_; +}; + +TEST_F(QuicDatagramQueueWithObserverTest, ObserveSuccessImmediately) { + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, + queue_.SendOrQueueDatagram(CreateMemSlice("a"))); + + EXPECT_THAT(context_->statuses, ElementsAre(MESSAGE_STATUS_SUCCESS)); +} + +TEST_F(QuicDatagramQueueWithObserverTest, ObserveFailureImmediately) { + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_TOO_LARGE)); + + EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, + queue_.SendOrQueueDatagram(CreateMemSlice("a"))); + + EXPECT_THAT(context_->statuses, ElementsAre(MESSAGE_STATUS_TOO_LARGE)); +} + +TEST_F(QuicDatagramQueueWithObserverTest, BlockingShouldNotBeObserved) { + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillRepeatedly(Return(MESSAGE_STATUS_BLOCKED)); + + EXPECT_EQ(MESSAGE_STATUS_BLOCKED, + queue_.SendOrQueueDatagram(CreateMemSlice("a"))); + EXPECT_EQ(0u, queue_.SendDatagrams()); + + EXPECT_TRUE(context_->statuses.empty()); +} + +TEST_F(QuicDatagramQueueWithObserverTest, ObserveSuccessAfterBuffering) { + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + + EXPECT_EQ(MESSAGE_STATUS_BLOCKED, + queue_.SendOrQueueDatagram(CreateMemSlice("a"))); + + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + + EXPECT_EQ(1u, queue_.SendDatagrams()); + EXPECT_THAT(context_->statuses, ElementsAre(MESSAGE_STATUS_SUCCESS)); +} + +TEST_F(QuicDatagramQueueWithObserverTest, ObserveExpiry) { + constexpr QuicTime::Delta expiry = QuicTime::Delta::FromMilliseconds(100); + queue_.SetMaxTimeInQueue(expiry); + + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)) + .WillOnce(Return(MESSAGE_STATUS_BLOCKED)); + + EXPECT_EQ(MESSAGE_STATUS_BLOCKED, + queue_.SendOrQueueDatagram(CreateMemSlice("a"))); + + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_CALL(*connection_, SendMessage(_, _, _)).Times(0); + helper_.AdvanceTime(100 * expiry); + + EXPECT_TRUE(context_->statuses.empty()); + + EXPECT_EQ(0u, queue_.SendDatagrams()); + EXPECT_THAT(context_->statuses, ElementsAre(absl::nullopt)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_default_clock.cc b/quiche/quic/core/quic_default_clock.cc new file mode 100644 index 000000000000..208687b6440e --- /dev/null +++ b/quiche/quic/core/quic_default_clock.cc @@ -0,0 +1,26 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_default_clock.h" + +#include "absl/time/clock.h" + +namespace quic { + +QuicDefaultClock* QuicDefaultClock::Get() { + static QuicDefaultClock* clock = new QuicDefaultClock(); + return clock; +} + +QuicTime QuicDefaultClock::ApproximateNow() const { return Now(); } + +QuicTime QuicDefaultClock::Now() const { + return CreateTimeFromMicroseconds(absl::GetCurrentTimeNanos() / 1000); +} + +QuicWallTime QuicDefaultClock::WallNow() const { + return QuicWallTime::FromUNIXMicroseconds(absl::GetCurrentTimeNanos() / 1000); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_default_clock.h b/quiche/quic/core/quic_default_clock.h new file mode 100644 index 000000000000..64b89dd3368e --- /dev/null +++ b/quiche/quic/core/quic_default_clock.h @@ -0,0 +1,32 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_DEFAULT_CLOCK_H_ +#define QUICHE_QUIC_CORE_QUIC_DEFAULT_CLOCK_H_ + +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A QuicClock based on Abseil time API. Thread-safe. +class QUIC_EXPORT_PRIVATE QuicDefaultClock : public QuicClock { + public: + // Provides a single default stateless instance of QuicDefaultClock. + static QuicDefaultClock* Get(); + + explicit QuicDefaultClock() = default; + QuicDefaultClock(const QuicDefaultClock&) = delete; + QuicDefaultClock& operator=(const QuicDefaultClock&) = delete; + + // QuicClock implementation. + QuicTime ApproximateNow() const override; + QuicTime Now() const override; + QuicWallTime WallNow() const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DEFAULT_CLOCK_H_ diff --git a/quiche/quic/core/quic_default_connection_helper.h b/quiche/quic/core/quic_default_connection_helper.h new file mode 100644 index 000000000000..3a10b43319c6 --- /dev/null +++ b/quiche/quic/core/quic_default_connection_helper.h @@ -0,0 +1,49 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_DEFAULT_CONNECTION_HELPER_H_ +#define QUICHE_QUIC_CORE_QUIC_DEFAULT_CONNECTION_HELPER_H_ + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { + +// A default implementation of QuicConnectionHelperInterface. Thread-safe. +class QUIC_EXPORT_PRIVATE QuicDefaultConnectionHelper + : public QuicConnectionHelperInterface { + public: + static QuicDefaultConnectionHelper* Get() { + static QuicDefaultConnectionHelper* helper = + new QuicDefaultConnectionHelper(); + return helper; + } + + // Creates a helper that uses the default allocator. + QuicDefaultConnectionHelper() : QuicDefaultConnectionHelper(nullptr) {} + // If |allocator| is nullptr, the default one is used. + QuicDefaultConnectionHelper( + std::unique_ptr allocator) + : allocator_(std::move(allocator)) {} + + QuicDefaultConnectionHelper(const QuicDefaultConnectionHelper&) = delete; + QuicDefaultConnectionHelper& operator=(const QuicDefaultConnectionHelper&) = + delete; + + const QuicClock* GetClock() const override { return QuicDefaultClock::Get(); } + QuicRandom* GetRandomGenerator() override { + return QuicRandom::GetInstance(); + } + quiche::QuicheBufferAllocator* GetStreamSendBufferAllocator() override { + return allocator_ ? allocator_.get() : quiche::SimpleBufferAllocator::Get(); + } + + private: + std::unique_ptr allocator_; +}; +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DEFAULT_CONNECTION_HELPER_H_ diff --git a/quiche/quic/core/quic_default_packet_writer.cc b/quiche/quic/core/quic_default_packet_writer.cc new file mode 100644 index 000000000000..78feee512c5d --- /dev/null +++ b/quiche/quic/core/quic_default_packet_writer.cc @@ -0,0 +1,65 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_default_packet_writer.h" + +#include "quiche/quic/core/quic_udp_socket.h" + +namespace quic { + +QuicDefaultPacketWriter::QuicDefaultPacketWriter(int fd) + : fd_(fd), write_blocked_(false) {} + +QuicDefaultPacketWriter::~QuicDefaultPacketWriter() = default; + +WriteResult QuicDefaultPacketWriter::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + QUICHE_DCHECK(!write_blocked_); + QuicUdpPacketInfo packet_info; + packet_info.SetPeerAddress(peer_address); + packet_info.SetSelfIp(self_address); + if (options != nullptr) { + packet_info.SetEcnCodepoint(options->ecn_codepoint); + } + WriteResult result = + QuicUdpSocketApi().WritePacket(fd_, buffer, buf_len, packet_info); + if (IsWriteBlockedStatus(result.status)) { + write_blocked_ = true; + } + return result; +} + +bool QuicDefaultPacketWriter::IsWriteBlocked() const { return write_blocked_; } + +void QuicDefaultPacketWriter::SetWritable() { write_blocked_ = false; } + +absl::optional QuicDefaultPacketWriter::MessageTooBigErrorCode() const { + return EMSGSIZE; +} + +QuicByteCount QuicDefaultPacketWriter::GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const { + return kMaxOutgoingPacketSize; +} + +bool QuicDefaultPacketWriter::SupportsReleaseTime() const { return false; } + +bool QuicDefaultPacketWriter::IsBatchMode() const { return false; } + +QuicPacketBuffer QuicDefaultPacketWriter::GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) { + return {nullptr, nullptr}; +} + +WriteResult QuicDefaultPacketWriter::Flush() { + return WriteResult(WRITE_STATUS_OK, 0); +} + +void QuicDefaultPacketWriter::set_write_blocked(bool is_blocked) { + write_blocked_ = is_blocked; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_default_packet_writer.h b/quiche/quic/core/quic_default_packet_writer.h new file mode 100644 index 000000000000..b2a0a86c4399 --- /dev/null +++ b/quiche/quic/core/quic_default_packet_writer.h @@ -0,0 +1,56 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_DEFAULT_PACKET_WRITER_H_ +#define QUICHE_QUIC_CORE_QUIC_DEFAULT_PACKET_WRITER_H_ + +#include + +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +struct WriteResult; + +// Default packet writer which wraps QuicSocketUtils WritePacket. +class QUIC_EXPORT_PRIVATE QuicDefaultPacketWriter : public QuicPacketWriter { + public: + explicit QuicDefaultPacketWriter(int fd); + QuicDefaultPacketWriter(const QuicDefaultPacketWriter&) = delete; + QuicDefaultPacketWriter& operator=(const QuicDefaultPacketWriter&) = delete; + ~QuicDefaultPacketWriter() override; + + // QuicPacketWriter + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + bool IsWriteBlocked() const override; + void SetWritable() override; + absl::optional MessageTooBigErrorCode() const override; + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& peer_address) const override; + bool SupportsReleaseTime() const override; + bool IsBatchMode() const override; + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address) override; + WriteResult Flush() override; + + void set_fd(int fd) { fd_ = fd; } + + protected: + void set_write_blocked(bool is_blocked); + int fd() { return fd_; } + + private: + int fd_; + bool write_blocked_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DEFAULT_PACKET_WRITER_H_ diff --git a/quiche/quic/core/quic_dispatcher.cc b/quiche/quic/core/quic_dispatcher.cc new file mode 100644 index 000000000000..9b1b470c3801 --- /dev/null +++ b/quiche/quic/core/quic_dispatcher.cc @@ -0,0 +1,1382 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_dispatcher.h" + +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/chlo_extractor.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_time_wait_list_manager.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/tls_chlo_extractor.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +using BufferedPacket = QuicBufferedPacketStore::BufferedPacket; +using BufferedPacketList = QuicBufferedPacketStore::BufferedPacketList; +using EnqueuePacketResult = QuicBufferedPacketStore::EnqueuePacketResult; + +namespace { + +// Minimal INITIAL packet length sent by clients is 1200. +const QuicPacketLength kMinClientInitialPacketLength = 1200; + +// An alarm that informs the QuicDispatcher to delete old sessions. +class DeleteSessionsAlarm : public QuicAlarm::DelegateWithoutContext { + public: + explicit DeleteSessionsAlarm(QuicDispatcher* dispatcher) + : dispatcher_(dispatcher) {} + DeleteSessionsAlarm(const DeleteSessionsAlarm&) = delete; + DeleteSessionsAlarm& operator=(const DeleteSessionsAlarm&) = delete; + + void OnAlarm() override { dispatcher_->DeleteSessions(); } + + private: + // Not owned. + QuicDispatcher* dispatcher_; +}; + +// An alarm that informs the QuicDispatcher to clear +// recent_stateless_reset_addresses_. +class ClearStatelessResetAddressesAlarm + : public QuicAlarm::DelegateWithoutContext { + public: + explicit ClearStatelessResetAddressesAlarm(QuicDispatcher* dispatcher) + : dispatcher_(dispatcher) {} + ClearStatelessResetAddressesAlarm(const DeleteSessionsAlarm&) = delete; + ClearStatelessResetAddressesAlarm& operator=(const DeleteSessionsAlarm&) = + delete; + + void OnAlarm() override { dispatcher_->ClearStatelessResetAddresses(); } + + private: + // Not owned. + QuicDispatcher* dispatcher_; +}; + +// Collects packets serialized by a QuicPacketCreator in order +// to be handed off to the time wait list manager. +class PacketCollector : public QuicPacketCreator::DelegateInterface, + public QuicStreamFrameDataProducer { + public: + explicit PacketCollector(quiche::QuicheBufferAllocator* allocator) + : send_buffer_(allocator) {} + ~PacketCollector() override = default; + + // QuicPacketCreator::DelegateInterface methods: + void OnSerializedPacket(SerializedPacket serialized_packet) override { + // Make a copy of the serialized packet to send later. + packets_.emplace_back( + new QuicEncryptedPacket(CopyBuffer(serialized_packet), + serialized_packet.encrypted_length, true)); + } + + QuicPacketBuffer GetPacketBuffer() override { + // Let QuicPacketCreator to serialize packets on stack buffer. + return {nullptr, nullptr}; + } + + void OnUnrecoverableError(QuicErrorCode /*error*/, + const std::string& /*error_details*/) override {} + + bool ShouldGeneratePacket(HasRetransmittableData /*retransmittable*/, + IsHandshake /*handshake*/) override { + QUICHE_DCHECK(false); + return true; + } + + const QuicFrames MaybeBundleAckOpportunistically() override { + QUICHE_DCHECK(false); + return {}; + } + + SerializedPacketFate GetSerializedPacketFate( + bool /*is_mtu_discovery*/, + EncryptionLevel /*encryption_level*/) override { + return SEND_TO_WRITER; + } + + // QuicStreamFrameDataProducer + WriteStreamDataResult WriteStreamData(QuicStreamId /*id*/, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override { + if (send_buffer_.WriteStreamData(offset, data_length, writer)) { + return WRITE_SUCCESS; + } + return WRITE_FAILED; + } + bool WriteCryptoData(EncryptionLevel /*level*/, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override { + return send_buffer_.WriteStreamData(offset, data_length, writer); + } + + std::vector>* packets() { + return &packets_; + } + + private: + std::vector> packets_; + // This is only needed until the packets are encrypted. Once packets are + // encrypted, the stream data is no longer required. + QuicStreamSendBuffer send_buffer_; +}; + +// Helper for statelessly closing connections by generating the +// correct termination packets and adding the connection to the time wait +// list manager. +class StatelessConnectionTerminator { + public: + StatelessConnectionTerminator(QuicConnectionId server_connection_id, + QuicConnectionId original_server_connection_id, + const ParsedQuicVersion version, + QuicConnectionHelperInterface* helper, + QuicTimeWaitListManager* time_wait_list_manager) + : server_connection_id_(server_connection_id), + framer_(ParsedQuicVersionVector{version}, + /*unused*/ QuicTime::Zero(), Perspective::IS_SERVER, + /*unused*/ kQuicDefaultConnectionIdLength), + collector_(helper->GetStreamSendBufferAllocator()), + creator_(server_connection_id, &framer_, &collector_), + time_wait_list_manager_(time_wait_list_manager) { + framer_.set_data_producer(&collector_); + // Always set encrypter with original_server_connection_id. + framer_.SetInitialObfuscators(original_server_connection_id); + } + + ~StatelessConnectionTerminator() { + // Clear framer's producer. + framer_.set_data_producer(nullptr); + } + + // Generates a packet containing a CONNECTION_CLOSE frame specifying + // |error_code| and |error_details| and add the connection to time wait. + void CloseConnection(QuicErrorCode error_code, + const std::string& error_details, bool ietf_quic, + std::vector active_connection_ids) { + SerializeConnectionClosePacket(error_code, error_details); + + time_wait_list_manager_->AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + TimeWaitConnectionInfo(ietf_quic, collector_.packets(), + std::move(active_connection_ids), + /*srtt=*/QuicTime::Delta::Zero())); + } + + private: + void SerializeConnectionClosePacket(QuicErrorCode error_code, + const std::string& error_details) { + QuicConnectionCloseFrame* frame = + new QuicConnectionCloseFrame(framer_.transport_version(), error_code, + NO_IETF_QUIC_ERROR, error_details, + /*transport_close_frame_type=*/0); + + if (!creator_.AddFrame(QuicFrame(frame), NOT_RETRANSMISSION)) { + QUIC_BUG(quic_bug_10287_1) << "Unable to add frame to an empty packet"; + delete frame; + return; + } + creator_.FlushCurrentPacket(); + QUICHE_DCHECK_EQ(1u, collector_.packets()->size()); + } + + QuicConnectionId server_connection_id_; + QuicFramer framer_; + // Set as the visitor of |creator_| to collect any generated packets. + PacketCollector collector_; + QuicPacketCreator creator_; + QuicTimeWaitListManager* time_wait_list_manager_; +}; + +// Class which extracts the ALPN and SNI from a QUIC_CRYPTO CHLO packet. +class ChloAlpnSniExtractor : public ChloExtractor::Delegate { + public: + void OnChlo(QuicTransportVersion /*version*/, + QuicConnectionId /*server_connection_id*/, + const CryptoHandshakeMessage& chlo) override { + absl::string_view alpn_value; + if (chlo.GetStringPiece(kALPN, &alpn_value)) { + alpn_ = std::string(alpn_value); + } + absl::string_view sni; + if (chlo.GetStringPiece(quic::kSNI, &sni)) { + sni_ = std::string(sni); + } + absl::string_view uaid_value; + if (chlo.GetStringPiece(quic::kUAID, &uaid_value)) { + uaid_ = std::string(uaid_value); + } + } + + std::string&& ConsumeAlpn() { return std::move(alpn_); } + + std::string&& ConsumeSni() { return std::move(sni_); } + + std::string&& ConsumeUaid() { return std::move(uaid_); } + + private: + std::string alpn_; + std::string sni_; + std::string uaid_; +}; + +} // namespace + +QuicDispatcher::QuicDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& connection_id_generator) + : config_(config), + crypto_config_(crypto_config), + compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), + helper_(std::move(helper)), + session_helper_(std::move(session_helper)), + alarm_factory_(std::move(alarm_factory)), + delete_sessions_alarm_( + alarm_factory_->CreateAlarm(new DeleteSessionsAlarm(this))), + buffered_packets_(this, helper_->GetClock(), alarm_factory_.get()), + version_manager_(version_manager), + last_error_(QUIC_NO_ERROR), + new_sessions_allowed_per_event_loop_(0u), + accept_new_connections_(true), + allow_short_initial_server_connection_ids_(false), + expected_server_connection_id_length_( + expected_server_connection_id_length), + clear_stateless_reset_addresses_alarm_(alarm_factory_->CreateAlarm( + new ClearStatelessResetAddressesAlarm(this))), + should_update_expected_server_connection_id_length_(false), + connection_id_generator_(connection_id_generator) { + QUIC_BUG_IF(quic_bug_12724_1, GetSupportedVersions().empty()) + << "Trying to create dispatcher without any supported versions"; + QUIC_DLOG(INFO) << "Created QuicDispatcher with versions: " + << ParsedQuicVersionVectorToString(GetSupportedVersions()); +} + +QuicDispatcher::~QuicDispatcher() { + if (delete_sessions_alarm_ != nullptr) { + delete_sessions_alarm_->PermanentCancel(); + } + if (clear_stateless_reset_addresses_alarm_ != nullptr) { + clear_stateless_reset_addresses_alarm_->PermanentCancel(); + } + reference_counted_session_map_.clear(); + closed_session_list_.clear(); + num_sessions_in_session_map_ = 0; +} + +void QuicDispatcher::InitializeWithWriter(QuicPacketWriter* writer) { + QUICHE_DCHECK(writer_ == nullptr); + writer_.reset(writer); + time_wait_list_manager_.reset(CreateQuicTimeWaitListManager()); +} + +void QuicDispatcher::ProcessPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) { + QUIC_DVLOG(2) << "Dispatcher received encrypted " << packet.length() + << " bytes:" << std::endl + << quiche::QuicheTextUtils::HexDump( + absl::string_view(packet.data(), packet.length())); + ReceivedPacketInfo packet_info(self_address, peer_address, packet); + std::string detailed_error; + QuicErrorCode error; + error = QuicFramer::ParsePublicHeaderDispatcherShortHeaderLengthUnknown( + packet, &packet_info.form, &packet_info.long_packet_type, + &packet_info.version_flag, &packet_info.use_length_prefix, + &packet_info.version_label, &packet_info.version, + &packet_info.destination_connection_id, &packet_info.source_connection_id, + &packet_info.retry_token, &detailed_error, connection_id_generator_); + + if (error != QUIC_NO_ERROR) { + // Packet has framing error. + SetLastError(error); + QUIC_DLOG(ERROR) << detailed_error; + return; + } + if (packet_info.destination_connection_id.length() != + expected_server_connection_id_length_ && + !should_update_expected_server_connection_id_length_ && + packet_info.version.IsKnown() && + !packet_info.version.AllowsVariableLengthConnectionIds()) { + SetLastError(QUIC_INVALID_PACKET_HEADER); + QUIC_DLOG(ERROR) << "Invalid Connection Id Length"; + return; + } + + if (packet_info.version_flag && IsSupportedVersion(packet_info.version)) { + if (!QuicUtils::IsConnectionIdValidForVersion( + packet_info.destination_connection_id, + packet_info.version.transport_version)) { + SetLastError(QUIC_INVALID_PACKET_HEADER); + QUIC_DLOG(ERROR) + << "Invalid destination connection ID length for version"; + return; + } + if (packet_info.version.SupportsClientConnectionIds() && + !QuicUtils::IsConnectionIdValidForVersion( + packet_info.source_connection_id, + packet_info.version.transport_version)) { + SetLastError(QUIC_INVALID_PACKET_HEADER); + QUIC_DLOG(ERROR) << "Invalid source connection ID length for version"; + return; + } + } + + // Before introducing the flag, it was impossible for a short header to + // update |expected_server_connection_id_length_|. + if (should_update_expected_server_connection_id_length_ && + packet_info.version_flag) { + expected_server_connection_id_length_ = + packet_info.destination_connection_id.length(); + } + + if (MaybeDispatchPacket(packet_info)) { + // Packet has been dropped or successfully dispatched, stop processing. + return; + } + // The framer might have extracted the incorrect Connection ID length from a + // short header. |packet| could be gQUIC; if Q043, the connection ID has been + // parsed correctly thanks to the fixed bit. If a Q046 or Q050 short header, + // the dispatcher might have assumed it was a long connection ID when (because + // it was gQUIC) it actually issued or kept an 8-byte ID. The other case is + // where NEW_CONNECTION_IDs are not using the generator, and the dispatcher + // is, due to flag misconfiguration. + if (!packet_info.version_flag && + (IsSupportedVersion(ParsedQuicVersion::Q046()) || + IsSupportedVersion(ParsedQuicVersion::Q050()))) { + ReceivedPacketInfo gquic_packet_info(self_address, peer_address, packet); + // Try again without asking |connection_id_generator_| for the length. + const QuicErrorCode gquic_error = QuicFramer::ParsePublicHeaderDispatcher( + packet, expected_server_connection_id_length_, &gquic_packet_info.form, + &gquic_packet_info.long_packet_type, &gquic_packet_info.version_flag, + &gquic_packet_info.use_length_prefix, &gquic_packet_info.version_label, + &gquic_packet_info.version, + &gquic_packet_info.destination_connection_id, + &gquic_packet_info.source_connection_id, &gquic_packet_info.retry_token, + &detailed_error); + if (gquic_error == QUIC_NO_ERROR) { + if (MaybeDispatchPacket(gquic_packet_info)) { + return; + } + } else { + QUICHE_VLOG(1) << "Tried to parse short header as gQUIC packet: " + << detailed_error; + } + } + ProcessHeader(&packet_info); +} + +namespace { +constexpr bool IsSourceUdpPortBlocked(uint16_t port) { + // These UDP source ports have been observed in large scale denial of service + // attacks and are not expected to ever carry user traffic, they are therefore + // blocked as a safety measure. See draft-ietf-quic-applicability for details. + constexpr uint16_t blocked_ports[] = { + 0, // We cannot send to port 0 so drop that source port. + 17, // Quote of the Day, can loop with QUIC. + 19, // Chargen, can loop with QUIC. + 53, // DNS, vulnerable to reflection attacks. + 111, // Portmap. + 123, // NTP, vulnerable to reflection attacks. + 137, // NETBIOS Name Service, + 138, // NETBIOS Datagram Service + 161, // SNMP. + 389, // CLDAP. + 500, // IKE, can loop with QUIC. + 1900, // SSDP, vulnerable to reflection attacks. + 3702, // WS-Discovery, vulnerable to reflection attacks. + 5353, // mDNS, vulnerable to reflection attacks. + 5355, // LLMNR, vulnerable to reflection attacks. + 11211, // memcache, vulnerable to reflection attacks. + // This list MUST be sorted in increasing order. + }; + constexpr size_t num_blocked_ports = ABSL_ARRAYSIZE(blocked_ports); + constexpr uint16_t highest_blocked_port = + blocked_ports[num_blocked_ports - 1]; + if (ABSL_PREDICT_TRUE(port > highest_blocked_port)) { + // Early-return to skip comparisons for the majority of traffic. + return false; + } + for (size_t i = 0; i < num_blocked_ports; i++) { + if (port == blocked_ports[i]) { + return true; + } + } + return false; +} +} // namespace + +bool QuicDispatcher::MaybeDispatchPacket( + const ReceivedPacketInfo& packet_info) { + if (IsSourceUdpPortBlocked(packet_info.peer_address.port())) { + // Silently drop the received packet. + QUIC_CODE_COUNT(quic_dropped_blocked_port); + return true; + } + + const QuicConnectionId server_connection_id = + packet_info.destination_connection_id; + + // The IETF spec requires the client to generate an initial server + // connection ID that is at least 64 bits long. After that initial + // connection ID, the dispatcher picks a new one of its expected length. + // Therefore we should never receive a connection ID that is smaller + // than 64 bits and smaller than what we expect. Unless the version is + // unknown, in which case we allow short connection IDs for version + // negotiation because that version could allow those. + if (packet_info.version_flag && packet_info.version.IsKnown() && + IsServerConnectionIdTooShort(server_connection_id)) { + QUICHE_DCHECK(packet_info.version_flag); + QUICHE_DCHECK(packet_info.version.AllowsVariableLengthConnectionIds()); + QUIC_DLOG(INFO) << "Packet with short destination connection ID " + << server_connection_id << " expected " + << static_cast(expected_server_connection_id_length_); + // Drop the packet silently. + QUIC_CODE_COUNT(quic_dropped_invalid_small_initial_connection_id); + return true; + } + + if (packet_info.version_flag && packet_info.version.IsKnown() && + !QuicUtils::IsConnectionIdLengthValidForVersion( + server_connection_id.length(), + packet_info.version.transport_version)) { + QUIC_DLOG(INFO) << "Packet with destination connection ID " + << server_connection_id << " is invalid with version " + << packet_info.version; + // Drop the packet silently. + QUIC_CODE_COUNT(quic_dropped_invalid_initial_connection_id); + return true; + } + + // Packets with connection IDs for active connections are processed + // immediately. + auto it = reference_counted_session_map_.find(server_connection_id); + if (it != reference_counted_session_map_.end()) { + QUICHE_DCHECK(!buffered_packets_.HasBufferedPackets(server_connection_id)); + it->second->ProcessUdpPacket(packet_info.self_address, + packet_info.peer_address, packet_info.packet); + return true; + } + + if (buffered_packets_.HasChloForConnection(server_connection_id)) { + BufferEarlyPacket(packet_info); + return true; + } + + if (OnFailedToDispatchPacket(packet_info)) { + return true; + } + + if (time_wait_list_manager_->IsConnectionIdInTimeWait(server_connection_id)) { + // This connection ID is already in time-wait state. + time_wait_list_manager_->ProcessPacket( + packet_info.self_address, packet_info.peer_address, + packet_info.destination_connection_id, packet_info.form, + packet_info.packet.length(), GetPerPacketContext()); + return true; + } + + // The packet has an unknown connection ID. + if (!accept_new_connections_ && packet_info.version_flag) { + // If not accepting new connections, reject packets with version which can + // potentially result in new connection creation. But if the packet doesn't + // have version flag, leave it to ValidityChecks() to reset it. + // By adding the connection to time wait list, following packets on this + // connection will not reach ShouldAcceptNewConnections(). + StatelesslyTerminateConnection( + packet_info.destination_connection_id, packet_info.form, + packet_info.version_flag, packet_info.use_length_prefix, + packet_info.version, QUIC_HANDSHAKE_FAILED, + "Stop accepting new connections", + quic::QuicTimeWaitListManager::SEND_STATELESS_RESET); + // Time wait list will reject the packet correspondingly.. + time_wait_list_manager()->ProcessPacket( + packet_info.self_address, packet_info.peer_address, + packet_info.destination_connection_id, packet_info.form, + packet_info.packet.length(), GetPerPacketContext()); + OnNewConnectionRejected(); + return true; + } + + // Unless the packet provides a version, assume that we can continue + // processing using our preferred version. + if (packet_info.version_flag) { + if (!IsSupportedVersion(packet_info.version)) { + if (ShouldCreateSessionForUnknownVersion(packet_info.version_label)) { + return false; + } + // Since the version is not supported, send a version negotiation + // packet and stop processing the current packet. + MaybeSendVersionNegotiationPacket(packet_info); + return true; + } + + if (crypto_config()->validate_chlo_size() && + packet_info.form == IETF_QUIC_LONG_HEADER_PACKET && + packet_info.long_packet_type == INITIAL && + packet_info.packet.length() < kMinClientInitialPacketLength) { + QUIC_DVLOG(1) << "Dropping initial packet which is too short, length: " + << packet_info.packet.length(); + QUIC_CODE_COUNT(quic_drop_small_initial_packets); + return true; + } + } + + return false; +} + +void QuicDispatcher::ProcessHeader(ReceivedPacketInfo* packet_info) { + QuicConnectionId server_connection_id = + packet_info->destination_connection_id; + // Packet's connection ID is unknown. Apply the validity checks. + QuicPacketFate fate = ValidityChecks(*packet_info); + + // |connection_close_error_code| is used if the final packet fate is + // kFateTimeWait. + QuicErrorCode connection_close_error_code = QUIC_HANDSHAKE_FAILED; + + // If a fatal TLS alert was received when extracting Client Hello, + // |tls_alert_error_detail| will be set and will be used as the error_details + // of the connection close. + std::string tls_alert_error_detail; + + if (fate == kFateProcess) { + ExtractChloResult extract_chlo_result = + TryExtractChloOrBufferEarlyPacket(*packet_info); + auto& parsed_chlo = extract_chlo_result.parsed_chlo; + + if (extract_chlo_result.tls_alert.has_value()) { + QUIC_BUG_IF(quic_dispatcher_parsed_chlo_and_tls_alert_coexist_1, + parsed_chlo.has_value()) + << "parsed_chlo and tls_alert should not be set at the same time."; + // Fatal TLS alert when parsing Client Hello. + fate = kFateTimeWait; + uint8_t tls_alert = *extract_chlo_result.tls_alert; + connection_close_error_code = TlsAlertToQuicErrorCode(tls_alert); + tls_alert_error_detail = + absl::StrCat("TLS handshake failure (", + EncryptionLevelToString(ENCRYPTION_INITIAL), ") ", + static_cast(tls_alert), ": ", + SSL_alert_desc_string_long(tls_alert)); + } else if (!parsed_chlo.has_value()) { + // Client Hello incomplete. Packet has been buffered or (rarely) dropped. + return; + } else { + // Client Hello fully received. + fate = ValidityChecksOnFullChlo(*packet_info, *parsed_chlo); + + if (fate == kFateProcess) { + ProcessChlo(*std::move(parsed_chlo), packet_info); + return; + } + } + } + + switch (fate) { + case kFateProcess: + // kFateProcess have been processed above. + QUIC_BUG(quic_dispatcher_bad_packet_fate) << fate; + break; + case kFateTimeWait: { + // Add this connection_id to the time-wait state, to safely reject + // future packets. + QUIC_DLOG(INFO) << "Adding connection ID " << server_connection_id + << " to time-wait list."; + QUIC_CODE_COUNT(quic_reject_fate_time_wait); + const std::string& connection_close_error_detail = + tls_alert_error_detail.empty() ? "Reject connection" + : tls_alert_error_detail; + StatelesslyTerminateConnection( + server_connection_id, packet_info->form, packet_info->version_flag, + packet_info->use_length_prefix, packet_info->version, + connection_close_error_code, connection_close_error_detail, + quic::QuicTimeWaitListManager::SEND_STATELESS_RESET); + + QUICHE_DCHECK(time_wait_list_manager_->IsConnectionIdInTimeWait( + server_connection_id)); + time_wait_list_manager_->ProcessPacket( + packet_info->self_address, packet_info->peer_address, + server_connection_id, packet_info->form, packet_info->packet.length(), + GetPerPacketContext()); + + buffered_packets_.DiscardPackets(server_connection_id); + } break; + case kFateDrop: + break; + } +} + +QuicDispatcher::ExtractChloResult +QuicDispatcher::TryExtractChloOrBufferEarlyPacket( + const ReceivedPacketInfo& packet_info) { + ExtractChloResult result; + if (packet_info.version.UsesTls()) { + bool has_full_tls_chlo = false; + std::string sni; + std::vector alpns; + bool resumption_attempted = false, early_data_attempted = false; + if (buffered_packets_.HasBufferedPackets( + packet_info.destination_connection_id)) { + // If we already have buffered packets for this connection ID, + // use the associated TlsChloExtractor to parse this packet. + has_full_tls_chlo = buffered_packets_.IngestPacketForTlsChloExtraction( + packet_info.destination_connection_id, packet_info.version, + packet_info.packet, &alpns, &sni, &resumption_attempted, + &early_data_attempted, &result.tls_alert); + } else { + // If we do not have a BufferedPacketList for this connection ID, + // create a single-use one to check whether this packet contains a + // full single-packet CHLO. + TlsChloExtractor tls_chlo_extractor; + tls_chlo_extractor.IngestPacket(packet_info.version, packet_info.packet); + if (tls_chlo_extractor.HasParsedFullChlo()) { + // This packet contains a full single-packet CHLO. + has_full_tls_chlo = true; + alpns = tls_chlo_extractor.alpns(); + sni = tls_chlo_extractor.server_name(); + resumption_attempted = tls_chlo_extractor.resumption_attempted(); + early_data_attempted = tls_chlo_extractor.early_data_attempted(); + } else { + result.tls_alert = tls_chlo_extractor.tls_alert(); + } + } + + if (result.tls_alert.has_value()) { + QUIC_BUG_IF(quic_dispatcher_parsed_chlo_and_tls_alert_coexist_2, + has_full_tls_chlo) + << "parsed_chlo and tls_alert should not be set at the same time."; + return result; + } + + if (GetQuicFlag(quic_allow_chlo_buffering) && !has_full_tls_chlo) { + // This packet does not contain a full CHLO. It could be a 0-RTT + // packet that arrived before the CHLO (due to loss or reordering), + // or it could be a fragment of a multi-packet CHLO. + BufferEarlyPacket(packet_info); + return result; + } + + ParsedClientHello& parsed_chlo = result.parsed_chlo.emplace(); + parsed_chlo.sni = std::move(sni); + parsed_chlo.alpns = std::move(alpns); + if (packet_info.retry_token.has_value()) { + parsed_chlo.retry_token = std::string(*packet_info.retry_token); + } + parsed_chlo.resumption_attempted = resumption_attempted; + parsed_chlo.early_data_attempted = early_data_attempted; + return result; + } + + ChloAlpnSniExtractor alpn_extractor; + if (GetQuicFlag(quic_allow_chlo_buffering) && + !ChloExtractor::Extract(packet_info.packet, packet_info.version, + config_->create_session_tag_indicators(), + &alpn_extractor, + packet_info.destination_connection_id.length())) { + // Buffer non-CHLO packets. + BufferEarlyPacket(packet_info); + return result; + } + + // We only apply this check for versions that do not use the IETF + // invariant header because those versions are already checked in + // QuicDispatcher::MaybeDispatchPacket. + if (packet_info.version_flag && + !packet_info.version.HasIetfInvariantHeader() && + crypto_config()->validate_chlo_size() && + packet_info.packet.length() < kMinClientInitialPacketLength) { + QUIC_DVLOG(1) << "Dropping CHLO packet which is too short, length: " + << packet_info.packet.length(); + QUIC_CODE_COUNT(quic_drop_small_chlo_packets); + return result; + } + + ParsedClientHello& parsed_chlo = result.parsed_chlo.emplace(); + parsed_chlo.sni = alpn_extractor.ConsumeSni(); + parsed_chlo.uaid = alpn_extractor.ConsumeUaid(); + parsed_chlo.alpns = {alpn_extractor.ConsumeAlpn()}; + return result; +} + +std::string QuicDispatcher::SelectAlpn(const std::vector& alpns) { + if (alpns.empty()) { + return ""; + } + if (alpns.size() > 1u) { + const std::vector& supported_alpns = + version_manager_->GetSupportedAlpns(); + for (const std::string& alpn : alpns) { + if (std::find(supported_alpns.begin(), supported_alpns.end(), alpn) != + supported_alpns.end()) { + return alpn; + } + } + } + return alpns[0]; +} + +QuicDispatcher::QuicPacketFate QuicDispatcher::ValidityChecks( + const ReceivedPacketInfo& packet_info) { + if (!packet_info.version_flag) { + QUIC_DLOG(INFO) + << "Packet without version arrived for unknown connection ID " + << packet_info.destination_connection_id; + MaybeResetPacketsWithNoVersion(packet_info); + return kFateDrop; + } + + // Let the connection parse and validate packet number. + return kFateProcess; +} + +void QuicDispatcher::CleanUpSession(QuicConnectionId server_connection_id, + QuicConnection* connection, + QuicErrorCode /*error*/, + const std::string& /*error_details*/, + ConnectionCloseSource /*source*/) { + write_blocked_list_.erase(connection); + QuicTimeWaitListManager::TimeWaitAction action = + QuicTimeWaitListManager::SEND_STATELESS_RESET; + if (connection->termination_packets() != nullptr && + !connection->termination_packets()->empty()) { + action = QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS; + } else { + if (!connection->IsHandshakeComplete()) { + // TODO(fayang): Do not serialize connection close packet if the + // connection is closed by the client. + if (!connection->version().HasIetfInvariantHeader()) { + QUIC_CODE_COUNT(gquic_add_to_time_wait_list_with_handshake_failed); + } else { + QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_handshake_failed); + } + // This serializes a connection close termination packet and adds the + // connection to the time wait list. + StatelessConnectionTerminator terminator( + server_connection_id, + connection->GetOriginalDestinationConnectionId(), + connection->version(), helper_.get(), time_wait_list_manager_.get()); + terminator.CloseConnection( + QUIC_HANDSHAKE_FAILED, + "Connection is closed by server before handshake confirmed", + connection->version().HasIetfInvariantHeader(), + connection->GetActiveServerConnectionIds()); + return; + } + QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_stateless_reset); + } + time_wait_list_manager_->AddConnectionIdToTimeWait( + action, + TimeWaitConnectionInfo( + connection->version().HasIetfInvariantHeader(), + connection->termination_packets(), + connection->GetActiveServerConnectionIds(), + connection->sent_packet_manager().GetRttStats()->smoothed_rtt())); +} + +void QuicDispatcher::StartAcceptingNewConnections() { + accept_new_connections_ = true; +} + +void QuicDispatcher::StopAcceptingNewConnections() { + accept_new_connections_ = false; + // No more CHLO will arrive and buffered CHLOs shouldn't be able to create + // connections. + buffered_packets_.DiscardAllPackets(); +} + +void QuicDispatcher::PerformActionOnActiveSessions( + std::function operation) const { + absl::flat_hash_set visited_session; + visited_session.reserve(reference_counted_session_map_.size()); + for (auto const& kv : reference_counted_session_map_) { + QuicSession* session = kv.second.get(); + if (visited_session.insert(session).second) { + operation(session); + } + } +} + +// Get a snapshot of all sessions. +std::vector> QuicDispatcher::GetSessionsSnapshot() + const { + std::vector> snapshot; + snapshot.reserve(reference_counted_session_map_.size()); + absl::flat_hash_set visited_session; + visited_session.reserve(reference_counted_session_map_.size()); + for (auto const& kv : reference_counted_session_map_) { + QuicSession* session = kv.second.get(); + if (visited_session.insert(session).second) { + snapshot.push_back(kv.second); + } + } + return snapshot; +} + +std::unique_ptr QuicDispatcher::GetPerPacketContext() + const { + return nullptr; +} + +void QuicDispatcher::DeleteSessions() { + if (!write_blocked_list_.empty()) { + for (const auto& session : closed_session_list_) { + if (write_blocked_list_.erase(session->connection()) != 0) { + QUIC_BUG(quic_bug_12724_2) + << "QuicConnection was in WriteBlockedList before destruction " + << session->connection()->connection_id(); + } + } + } + closed_session_list_.clear(); +} + +void QuicDispatcher::ClearStatelessResetAddresses() { + recent_stateless_reset_addresses_.clear(); +} + +void QuicDispatcher::OnCanWrite() { + // The socket is now writable. + writer_->SetWritable(); + + // Move every blocked writer in |write_blocked_list_| to a temporary list. + const size_t num_blocked_writers_before = write_blocked_list_.size(); + WriteBlockedList temp_list; + temp_list.swap(write_blocked_list_); + QUICHE_DCHECK(write_blocked_list_.empty()); + + // Give each blocked writer a chance to write what they indended to write. + // If they are blocked again, they will call |OnWriteBlocked| to add + // themselves back into |write_blocked_list_|. + while (!temp_list.empty()) { + QuicBlockedWriterInterface* blocked_writer = temp_list.begin()->first; + temp_list.erase(temp_list.begin()); + blocked_writer->OnBlockedWriterCanWrite(); + } + const size_t num_blocked_writers_after = write_blocked_list_.size(); + if (num_blocked_writers_after != 0) { + if (num_blocked_writers_before == num_blocked_writers_after) { + QUIC_CODE_COUNT(quic_zero_progress_on_can_write); + } else { + QUIC_CODE_COUNT(quic_blocked_again_on_can_write); + } + } +} + +bool QuicDispatcher::HasPendingWrites() const { + return !write_blocked_list_.empty(); +} + +void QuicDispatcher::Shutdown() { + while (!reference_counted_session_map_.empty()) { + QuicSession* session = reference_counted_session_map_.begin()->second.get(); + session->connection()->CloseConnection( + QUIC_PEER_GOING_AWAY, "Server shutdown imminent", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + // Validate that the session removes itself from the session map on close. + QUICHE_DCHECK(reference_counted_session_map_.empty() || + reference_counted_session_map_.begin()->second.get() != + session); + } + DeleteSessions(); +} + +void QuicDispatcher::OnConnectionClosed(QuicConnectionId server_connection_id, + QuicErrorCode error, + const std::string& error_details, + ConnectionCloseSource source) { + auto it = reference_counted_session_map_.find(server_connection_id); + if (it == reference_counted_session_map_.end()) { + QUIC_BUG(quic_bug_10287_3) << "ConnectionId " << server_connection_id + << " does not exist in the session map. Error: " + << QuicErrorCodeToString(error); + QUIC_BUG(quic_bug_10287_4) << QuicStackTrace(); + return; + } + + QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR) + << "Closing connection (" << server_connection_id + << ") due to error: " << QuicErrorCodeToString(error) + << ", with details: " << error_details; + + const QuicSession* session = it->second.get(); + QuicConnection* connection = it->second->connection(); + // Set up alarm to fire immediately to bring destruction of this session + // out of current call stack. + if (closed_session_list_.empty()) { + delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(), + QuicTime::Delta::Zero()); + } + closed_session_list_.push_back(std::move(it->second)); + CleanUpSession(it->first, connection, error, error_details, source); + bool session_removed = false; + for (const QuicConnectionId& cid : + connection->GetActiveServerConnectionIds()) { + auto it1 = reference_counted_session_map_.find(cid); + if (it1 != reference_counted_session_map_.end()) { + const QuicSession* session2 = it1->second.get(); + // For cid == server_connection_id, session2 is a nullptr (and hence + // session2 != session) now since we have std::move the session into + // closed_session_list_ above. + if (session2 == session || cid == server_connection_id) { + reference_counted_session_map_.erase(it1); + session_removed = true; + } else { + // Leave this session in the map. + QUIC_BUG(quic_dispatcher_session_mismatch) + << "Session is mismatched in the map. server_connection_id: " + << server_connection_id << ". Current cid: " << cid + << ". Cid of the other session " + << (session2 == nullptr + ? "null" + : session2->connection()->connection_id().ToString()); + } + } else { + // GetActiveServerConnectionIds might return the original destination + // ID, which is not contained in the session map. + QUIC_BUG_IF(quic_dispatcher_session_not_found, + cid != connection->GetOriginalDestinationConnectionId()) + << "Missing session for cid " << cid + << ". server_connection_id: " << server_connection_id; + } + } + QUIC_BUG_IF(quic_session_is_not_removed, !session_removed); + --num_sessions_in_session_map_; +} + +void QuicDispatcher::OnWriteBlocked( + QuicBlockedWriterInterface* blocked_writer) { + if (!blocked_writer->IsWriterBlocked()) { + // It is a programming error if this ever happens. When we are sure it is + // not happening, replace it with a QUICHE_DCHECK. + QUIC_BUG(quic_bug_12724_4) + << "Tried to add writer into blocked list when it shouldn't be added"; + // Return without adding the connection to the blocked list, to avoid + // infinite loops in OnCanWrite. + return; + } + + write_blocked_list_.insert(std::make_pair(blocked_writer, true)); +} + +void QuicDispatcher::OnRstStreamReceived(const QuicRstStreamFrame& /*frame*/) {} + +void QuicDispatcher::OnStopSendingReceived( + const QuicStopSendingFrame& /*frame*/) {} + +bool QuicDispatcher::TryAddNewConnectionId( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id) { + auto it = reference_counted_session_map_.find(server_connection_id); + if (it == reference_counted_session_map_.end()) { + QUIC_BUG(quic_bug_10287_7) + << "Couldn't locate the session that issues the connection ID in " + "reference_counted_session_map_. server_connection_id:" + << server_connection_id << " new_connection_id: " << new_connection_id; + return false; + } + // Count new connection ID added to the dispatcher map. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 6, 6); + auto insertion_result = reference_counted_session_map_.insert( + std::make_pair(new_connection_id, it->second)); + if (!insertion_result.second) { + QUIC_CODE_COUNT(quic_cid_already_in_session_map); + } + return insertion_result.second; +} + +void QuicDispatcher::OnConnectionIdRetired( + const QuicConnectionId& server_connection_id) { + reference_counted_session_map_.erase(server_connection_id); +} + +void QuicDispatcher::OnConnectionAddedToTimeWaitList( + QuicConnectionId server_connection_id) { + QUIC_DLOG(INFO) << "Connection " << server_connection_id + << " added to time wait list."; +} + +void QuicDispatcher::StatelesslyTerminateConnection( + QuicConnectionId server_connection_id, PacketHeaderFormat format, + bool version_flag, bool use_length_prefix, ParsedQuicVersion version, + QuicErrorCode error_code, const std::string& error_details, + QuicTimeWaitListManager::TimeWaitAction action) { + if (format != IETF_QUIC_LONG_HEADER_PACKET && !version_flag) { + QUIC_DVLOG(1) << "Statelessly terminating " << server_connection_id + << " based on a non-ietf-long packet, action:" << action + << ", error_code:" << error_code + << ", error_details:" << error_details; + time_wait_list_manager_->AddConnectionIdToTimeWait( + action, TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr, + {server_connection_id})); + return; + } + + // If the version is known and supported by framer, send a connection close. + if (IsSupportedVersion(version)) { + QUIC_DVLOG(1) + << "Statelessly terminating " << server_connection_id + << " based on an ietf-long packet, which has a supported version:" + << version << ", error_code:" << error_code + << ", error_details:" << error_details; + + StatelessConnectionTerminator terminator( + server_connection_id, server_connection_id, version, helper_.get(), + time_wait_list_manager_.get()); + // This also adds the connection to time wait list. + terminator.CloseConnection( + error_code, error_details, format != GOOGLE_QUIC_PACKET, + /*active_connection_ids=*/{server_connection_id}); + QUIC_CODE_COUNT(quic_dispatcher_generated_connection_close); + QuicSession::RecordConnectionCloseAtServer( + error_code, ConnectionCloseSource::FROM_SELF); + return; + } + + QUIC_DVLOG(1) + << "Statelessly terminating " << server_connection_id + << " based on an ietf-long packet, which has an unsupported version:" + << version << ", error_code:" << error_code + << ", error_details:" << error_details; + // Version is unknown or unsupported by framer, send a version negotiation + // with an empty version list, which can be understood by the client. + std::vector> termination_packets; + termination_packets.push_back(QuicFramer::BuildVersionNegotiationPacket( + server_connection_id, EmptyQuicConnectionId(), + /*ietf_quic=*/format != GOOGLE_QUIC_PACKET, use_length_prefix, + /*versions=*/{})); + time_wait_list_manager()->AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/format != GOOGLE_QUIC_PACKET, + &termination_packets, {server_connection_id})); +} + +bool QuicDispatcher::ShouldCreateSessionForUnknownVersion( + QuicVersionLabel /*version_label*/) { + return false; +} + +void QuicDispatcher::OnExpiredPackets( + QuicConnectionId server_connection_id, + BufferedPacketList early_arrived_packets) { + QUIC_CODE_COUNT(quic_reject_buffered_packets_expired); + StatelesslyTerminateConnection( + server_connection_id, + early_arrived_packets.ietf_quic ? IETF_QUIC_LONG_HEADER_PACKET + : GOOGLE_QUIC_PACKET, + /*version_flag=*/true, + early_arrived_packets.version.HasLengthPrefixedConnectionIds(), + early_arrived_packets.version, QUIC_HANDSHAKE_FAILED, + "Packets buffered for too long", + quic::QuicTimeWaitListManager::SEND_STATELESS_RESET); +} + +void QuicDispatcher::ProcessBufferedChlos(size_t max_connections_to_create) { + // Reset the counter before starting creating connections. + new_sessions_allowed_per_event_loop_ = max_connections_to_create; + for (; new_sessions_allowed_per_event_loop_ > 0; + --new_sessions_allowed_per_event_loop_) { + QuicConnectionId server_connection_id; + BufferedPacketList packet_list = + buffered_packets_.DeliverPacketsForNextConnection( + &server_connection_id); + const std::list& packets = packet_list.buffered_packets; + if (packets.empty()) { + return; + } + if (!packet_list.parsed_chlo.has_value()) { + QUIC_BUG(quic_dispatcher_no_parsed_chlo_in_buffered_packets) + << "Buffered connection has no CHLO. connection_id:" + << server_connection_id; + continue; + } + auto session_ptr = QuicDispatcher::CreateSessionFromChlo( + server_connection_id, *packet_list.parsed_chlo, packet_list.version, + packets.front().self_address, packets.front().peer_address); + if (session_ptr != nullptr) { + DeliverPacketsToSession(packets, session_ptr.get()); + } + } +} + +bool QuicDispatcher::HasChlosBuffered() const { + return buffered_packets_.HasChlosBuffered(); +} + +// Return true if there is any packet buffered in the store. +bool QuicDispatcher::HasBufferedPackets(QuicConnectionId server_connection_id) { + return buffered_packets_.HasBufferedPackets(server_connection_id); +} + +void QuicDispatcher::OnBufferPacketFailure( + EnqueuePacketResult result, QuicConnectionId server_connection_id) { + QUIC_DLOG(INFO) << "Fail to buffer packet on connection " + << server_connection_id << " because of " << result; +} + +QuicTimeWaitListManager* QuicDispatcher::CreateQuicTimeWaitListManager() { + return new QuicTimeWaitListManager(writer_.get(), this, helper_->GetClock(), + alarm_factory_.get()); +} + +void QuicDispatcher::BufferEarlyPacket(const ReceivedPacketInfo& packet_info) { + EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( + packet_info.destination_connection_id, + packet_info.form != GOOGLE_QUIC_PACKET, packet_info.packet, + packet_info.self_address, packet_info.peer_address, packet_info.version, + /*parsed_chlo=*/absl::nullopt); + if (rs != EnqueuePacketResult::SUCCESS) { + OnBufferPacketFailure(rs, packet_info.destination_connection_id); + } +} + +void QuicDispatcher::ProcessChlo(ParsedClientHello parsed_chlo, + ReceivedPacketInfo* packet_info) { + if (GetQuicFlag(quic_allow_chlo_buffering) && + new_sessions_allowed_per_event_loop_ <= 0) { + // Can't create new session any more. Wait till next event loop. + QUIC_BUG_IF(quic_bug_12724_7, buffered_packets_.HasChloForConnection( + packet_info->destination_connection_id)); + EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( + packet_info->destination_connection_id, + packet_info->form != GOOGLE_QUIC_PACKET, packet_info->packet, + packet_info->self_address, packet_info->peer_address, + packet_info->version, std::move(parsed_chlo)); + if (rs != EnqueuePacketResult::SUCCESS) { + OnBufferPacketFailure(rs, packet_info->destination_connection_id); + } + return; + } + + auto session_ptr = QuicDispatcher::CreateSessionFromChlo( + packet_info->destination_connection_id, parsed_chlo, packet_info->version, + packet_info->self_address, packet_info->peer_address); + if (session_ptr == nullptr) { + return; + } + std::list packets = + buffered_packets_.DeliverPackets(packet_info->destination_connection_id) + .buffered_packets; + if (packet_info->destination_connection_id != session_ptr->connection_id()) { + // Provide the calling function with access to the new connection ID. + packet_info->destination_connection_id = session_ptr->connection_id(); + if (!packets.empty()) { + QUIC_CODE_COUNT( + quic_delivered_buffered_packets_to_connection_with_replaced_id); + } + } + // Process CHLO at first. + session_ptr->ProcessUdpPacket(packet_info->self_address, + packet_info->peer_address, packet_info->packet); + // Deliver queued-up packets in the same order as they arrived. + // Do this even when flag is off because there might be still some packets + // buffered in the store before flag is turned off. + DeliverPacketsToSession(packets, session_ptr.get()); + --new_sessions_allowed_per_event_loop_; +} + +void QuicDispatcher::SetLastError(QuicErrorCode error) { last_error_ = error; } + +bool QuicDispatcher::OnFailedToDispatchPacket( + const ReceivedPacketInfo& /*packet_info*/) { + return false; +} + +const ParsedQuicVersionVector& QuicDispatcher::GetSupportedVersions() { + return version_manager_->GetSupportedVersions(); +} + +void QuicDispatcher::DeliverPacketsToSession( + const std::list& packets, QuicSession* session) { + for (const BufferedPacket& packet : packets) { + session->ProcessUdpPacket(packet.self_address, packet.peer_address, + *(packet.packet)); + } +} + +bool QuicDispatcher::IsSupportedVersion(const ParsedQuicVersion version) { + for (const ParsedQuicVersion& supported_version : + version_manager_->GetSupportedVersions()) { + if (version == supported_version) { + return true; + } + } + return false; +} + +bool QuicDispatcher::IsServerConnectionIdTooShort( + QuicConnectionId connection_id) const { + if (connection_id.length() >= kQuicMinimumInitialConnectionIdLength || + connection_id.length() >= expected_server_connection_id_length_ || + allow_short_initial_server_connection_ids_) { + return false; + } + uint8_t generator_output = + connection_id.IsEmpty() + ? connection_id_generator_.ConnectionIdLength(0x00) + : connection_id_generator_.ConnectionIdLength( + static_cast(*connection_id.data())); + return connection_id.length() < generator_output; +} + +std::shared_ptr QuicDispatcher::CreateSessionFromChlo( + const QuicConnectionId original_connection_id, + const ParsedClientHello& parsed_chlo, const ParsedQuicVersion version, + const QuicSocketAddress self_address, + const QuicSocketAddress peer_address) { + absl::optional server_connection_id = + connection_id_generator_.MaybeReplaceConnectionId(original_connection_id, + version); + const bool replaced_connection_id = server_connection_id.has_value(); + if (!replaced_connection_id) { + server_connection_id = original_connection_id; + } + if (reference_counted_session_map_.count(*server_connection_id) > 0) { + // The new connection ID is owned by another session. Avoid creating one + // altogether, as this connection attempt cannot possibly succeed. + if (replaced_connection_id) { + // The original connection ID does not correspond to an existing + // session. It is safe to send CONNECTION_CLOSE and add to TIME_WAIT. + StatelesslyTerminateConnection( + original_connection_id, IETF_QUIC_LONG_HEADER_PACKET, + /*version_flag=*/true, version.HasLengthPrefixedConnectionIds(), + version, QUIC_HANDSHAKE_FAILED, + "Connection ID collision, please retry", + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS); + } + return nullptr; + } + // Creates a new session and process all buffered packets for this connection. + std::string alpn = SelectAlpn(parsed_chlo.alpns); + std::unique_ptr session = + CreateQuicSession(*server_connection_id, self_address, peer_address, alpn, + version, parsed_chlo); + if (ABSL_PREDICT_FALSE(session == nullptr)) { + QUIC_BUG(quic_bug_10287_8) + << "CreateQuicSession returned nullptr for " << *server_connection_id + << " from " << peer_address << " to " << self_address << " ALPN \"" + << alpn << "\" version " << version; + return nullptr; + } + + if (replaced_connection_id) { + session->connection()->SetOriginalDestinationConnectionId( + original_connection_id); + } + QUIC_DLOG(INFO) << "Created new session for " << *server_connection_id; + + auto insertion_result = reference_counted_session_map_.insert(std::make_pair( + *server_connection_id, std::shared_ptr(std::move(session)))); + std::shared_ptr session_ptr = insertion_result.first->second; + if (!insertion_result.second) { + QUIC_BUG(quic_bug_10287_9) + << "Tried to add a session to session_map with existing " + "connection id: " + << *server_connection_id; + } else { + ++num_sessions_in_session_map_; + if (replaced_connection_id) { + auto insertion_result2 = reference_counted_session_map_.insert( + std::make_pair(original_connection_id, session_ptr)); + QUIC_BUG_IF(quic_460317833_02, !insertion_result2.second) + << "Original connection ID already in session_map: " + << original_connection_id; + // If insertion of the original connection ID fails, it might cause + // loss of 0-RTT and other first flight packets, but the connection + // will usually progress. + } + } + return session_ptr; +} + +void QuicDispatcher::MaybeResetPacketsWithNoVersion( + const ReceivedPacketInfo& packet_info) { + QUICHE_DCHECK(!packet_info.version_flag); + // Do not send a stateless reset if a reset has been sent to this address + // recently. + if (recent_stateless_reset_addresses_.contains(packet_info.peer_address)) { + QUIC_CODE_COUNT(quic_donot_send_reset_repeatedly); + return; + } + if (packet_info.form != GOOGLE_QUIC_PACKET) { + // Drop IETF packets smaller than the minimal stateless reset length. + if (packet_info.packet.length() <= + QuicFramer::GetMinStatelessResetPacketLength()) { + QUIC_CODE_COUNT(quic_drop_too_small_short_header_packets); + return; + } + } else { + const size_t MinValidPacketLength = + kPacketHeaderTypeSize + expected_server_connection_id_length_ + + PACKET_1BYTE_PACKET_NUMBER + /*payload size=*/1 + /*tag size=*/12; + if (packet_info.packet.length() < MinValidPacketLength) { + // The packet size is too small. + QUIC_CODE_COUNT(drop_too_small_packets); + return; + } + } + // Do not send a stateless reset if there are too many stateless reset + // addresses. + if (recent_stateless_reset_addresses_.size() >= + GetQuicFlag(quic_max_recent_stateless_reset_addresses)) { + QUIC_CODE_COUNT(quic_too_many_recent_reset_addresses); + return; + } + if (recent_stateless_reset_addresses_.empty()) { + clear_stateless_reset_addresses_alarm_->Update( + helper()->GetClock()->ApproximateNow() + + QuicTime::Delta::FromMilliseconds( + GetQuicFlag(quic_recent_stateless_reset_addresses_lifetime_ms)), + QuicTime::Delta::Zero()); + } + recent_stateless_reset_addresses_.emplace(packet_info.peer_address); + + time_wait_list_manager()->SendPublicReset( + packet_info.self_address, packet_info.peer_address, + packet_info.destination_connection_id, + packet_info.form != GOOGLE_QUIC_PACKET, packet_info.packet.length(), + GetPerPacketContext()); +} + +void QuicDispatcher::MaybeSendVersionNegotiationPacket( + const ReceivedPacketInfo& packet_info) { + if (crypto_config()->validate_chlo_size() && + packet_info.packet.length() < kMinPacketSizeForVersionNegotiation) { + return; + } + time_wait_list_manager()->SendVersionNegotiationPacket( + packet_info.destination_connection_id, packet_info.source_connection_id, + packet_info.form != GOOGLE_QUIC_PACKET, packet_info.use_length_prefix, + GetSupportedVersions(), packet_info.self_address, + packet_info.peer_address, GetPerPacketContext()); +} + +size_t QuicDispatcher::NumSessions() const { + return num_sessions_in_session_map_; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_dispatcher.h b/quiche/quic/core/quic_dispatcher.h new file mode 100644 index 000000000000..24602b38e959 --- /dev/null +++ b/quiche/quic/core/quic_dispatcher.h @@ -0,0 +1,470 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A server side dispatcher which dispatches a given client's data to their +// stream. + +#ifndef QUICHE_QUIC_CORE_QUIC_DISPATCHER_H_ +#define QUICHE_QUIC_CORE_QUIC_DISPATCHER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/connection_id_generator.h" +#include "quiche/quic/core/crypto/quic_compressed_certs_cache.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_blocked_writer_interface.h" +#include "quiche/quic/core/quic_buffered_packet_store.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_process_packet_interface.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_time_wait_list_manager.h" +#include "quiche/quic/core/quic_version_manager.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { +namespace test { +class QuicDispatcherPeer; +} // namespace test + +class QuicConfig; +class QuicCryptoServerConfig; + +class QUIC_NO_EXPORT QuicDispatcher + : public QuicTimeWaitListManager::Visitor, + public ProcessPacketInterface, + public QuicBufferedPacketStore::VisitorInterface { + public: + // Ideally we'd have a linked_hash_set: the boolean is unused. + using WriteBlockedList = + quiche::QuicheLinkedHashMap; + + QuicDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& connection_id_generator); + QuicDispatcher(const QuicDispatcher&) = delete; + QuicDispatcher& operator=(const QuicDispatcher&) = delete; + + ~QuicDispatcher() override; + + // Takes ownership of |writer|. + void InitializeWithWriter(QuicPacketWriter* writer); + + // Process the incoming packet by creating a new session, passing it to + // an existing session, or passing it to the time wait list. + void ProcessPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) override; + + // Called when the socket becomes writable to allow queued writes to happen. + virtual void OnCanWrite(); + + // Returns true if there's anything in the blocked writer list. + virtual bool HasPendingWrites() const; + + // Sends ConnectionClose frames to all connected clients. + void Shutdown(); + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Ensure that the closed connection is cleaned up asynchronously. + void OnConnectionClosed(QuicConnectionId server_connection_id, + QuicErrorCode error, const std::string& error_details, + ConnectionCloseSource source) override; + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Queues the blocked writer for later resumption. + void OnWriteBlocked(QuicBlockedWriterInterface* blocked_writer) override; + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Collects reset error code received on streams. + void OnRstStreamReceived(const QuicRstStreamFrame& frame) override; + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Collects reset error code received on streams. + void OnStopSendingReceived(const QuicStopSendingFrame& frame) override; + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Try to add the new connection ID to the session map. Returns true on + // success. + bool TryAddNewConnectionId( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id) override; + + // QuicSession::Visitor interface implementation (via inheritance of + // QuicTimeWaitListManager::Visitor): + // Remove the retired connection ID from the session map. + void OnConnectionIdRetired( + const QuicConnectionId& server_connection_id) override; + + void OnServerPreferredAddressAvailable( + const QuicSocketAddress& /*server_preferred_address*/) override { + QUICHE_DCHECK(false); + } + + // QuicTimeWaitListManager::Visitor interface implementation + // Called whenever the time wait list manager adds a new connection to the + // time-wait list. + void OnConnectionAddedToTimeWaitList( + QuicConnectionId server_connection_id) override; + + using ReferenceCountedSessionMap = + absl::flat_hash_map, + QuicConnectionIdHash>; + + size_t NumSessions() const; + + // Deletes all sessions on the closed session list and clears the list. + virtual void DeleteSessions(); + + // Clear recent_stateless_reset_addresses_. + void ClearStatelessResetAddresses(); + + using ConnectionIdMap = + absl::flat_hash_map; + + // QuicBufferedPacketStore::VisitorInterface implementation. + void OnExpiredPackets(QuicConnectionId server_connection_id, + QuicBufferedPacketStore::BufferedPacketList + early_arrived_packets) override; + + // Create connections for previously buffered CHLOs as many as allowed. + virtual void ProcessBufferedChlos(size_t max_connections_to_create); + + // Return true if there is CHLO buffered. + virtual bool HasChlosBuffered() const; + + // Start accepting new ConnectionIds. + void StartAcceptingNewConnections(); + + // Stop accepting new ConnectionIds, either as a part of the lame + // duck process or because explicitly configured. + void StopAcceptingNewConnections(); + + // Apply an operation for each session. + void PerformActionOnActiveSessions( + std::function operation) const; + + // Get a snapshot of all sessions. + std::vector> GetSessionsSnapshot() const; + + bool accept_new_connections() const { return accept_new_connections_; } + + protected: + // Creates a QUIC session based on the given information. + // |alpn| is the selected ALPN from |parsed_chlo.alpns|. + virtual std::unique_ptr CreateQuicSession( + QuicConnectionId server_connection_id, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const ParsedClientHello& parsed_chlo) = 0; + + // Tries to validate and dispatch packet based on available information. + // Returns true if packet is dropped or successfully dispatched (e.g., + // processed by existing session, processed by time wait list, etc.), + // otherwise, returns false and the packet needs further processing. + virtual bool MaybeDispatchPacket(const ReceivedPacketInfo& packet_info); + + // Values to be returned by ValidityChecks() to indicate what should be done + // with a packet. Fates with greater values are considered to be higher + // priority. ValidityChecks should return fate based on the priority order + // (i.e., returns higher priority fate first) + enum QuicPacketFate { + // Process the packet normally, which is usually to establish a connection. + kFateProcess, + // Put the connection ID into time-wait state and send a public reset. + kFateTimeWait, + // Drop the packet. + kFateDrop, + }; + + // This method is called by ProcessHeader on packets not associated with a + // known connection ID. It applies validity checks and returns a + // QuicPacketFate to tell what should be done with the packet. + // TODO(fayang): Merge ValidityChecks into MaybeDispatchPacket. + virtual QuicPacketFate ValidityChecks(const ReceivedPacketInfo& packet_info); + + // Extra validity checks after the full Client Hello is parsed, this allows + // subclasses to reject a connection based on sni or alpn. + // Only called if ValidityChecks returns kFateProcess. + virtual QuicPacketFate ValidityChecksOnFullChlo( + const ReceivedPacketInfo& /*packet_info*/, + const ParsedClientHello& /*parsed_chlo*/) const { + return kFateProcess; + } + + // Create and return the time wait list manager for this dispatcher, which + // will be owned by the dispatcher as time_wait_list_manager_ + virtual QuicTimeWaitListManager* CreateQuicTimeWaitListManager(); + + // Buffers packet until it can be delivered to a connection. + void BufferEarlyPacket(const ReceivedPacketInfo& packet_info); + + // Called when |packet_info| is the last received packet of the client hello. + // |parsed_chlo| is the parsed version of the client hello. Creates a new + // connection and delivers any buffered packets for that connection id. + void ProcessChlo(ParsedClientHello parsed_chlo, + ReceivedPacketInfo* packet_info); + + QuicTimeWaitListManager* time_wait_list_manager() { + return time_wait_list_manager_.get(); + } + + const ParsedQuicVersionVector& GetSupportedVersions(); + + const QuicConfig& config() const { return *config_; } + + const QuicCryptoServerConfig* crypto_config() const { return crypto_config_; } + + QuicCompressedCertsCache* compressed_certs_cache() { + return &compressed_certs_cache_; + } + + QuicConnectionHelperInterface* helper() { return helper_.get(); } + + QuicCryptoServerStreamBase::Helper* session_helper() { + return session_helper_.get(); + } + + const QuicCryptoServerStreamBase::Helper* session_helper() const { + return session_helper_.get(); + } + + QuicAlarmFactory* alarm_factory() { return alarm_factory_.get(); } + + QuicPacketWriter* writer() { return writer_.get(); } + + // Returns true if a session should be created for a connection with an + // unknown version identified by |version_label|. + virtual bool ShouldCreateSessionForUnknownVersion( + QuicVersionLabel version_label); + + void SetLastError(QuicErrorCode error); + + // Called by MaybeDispatchPacket when current packet cannot be dispatched. + // Used by subclasses to conduct specific logic to dispatch packet. Returns + // true if packet is successfully dispatched. + virtual bool OnFailedToDispatchPacket(const ReceivedPacketInfo& packet_info); + + bool HasBufferedPackets(QuicConnectionId server_connection_id); + + // Called when BufferEarlyPacket() fail to buffer the packet. + virtual void OnBufferPacketFailure( + QuicBufferedPacketStore::EnqueuePacketResult result, + QuicConnectionId server_connection_id); + + // Removes the session from the write blocked list, and adds the ConnectionId + // to the time-wait list. The caller needs to manually remove the session + // from the map after that. + void CleanUpSession(QuicConnectionId server_connection_id, + QuicConnection* connection, QuicErrorCode error, + const std::string& error_details, + ConnectionCloseSource source); + + // Called to terminate a connection statelessly. Depending on |format|, either + // 1) send connection close with |error_code| and |error_details| and add + // connection to time wait list or 2) directly add connection to time wait + // list with |action|. + void StatelesslyTerminateConnection( + QuicConnectionId server_connection_id, PacketHeaderFormat format, + bool version_flag, bool use_length_prefix, ParsedQuicVersion version, + QuicErrorCode error_code, const std::string& error_details, + QuicTimeWaitListManager::TimeWaitAction action); + + // Save/Restore per packet context. + virtual std::unique_ptr GetPerPacketContext() const; + virtual void RestorePerPacketContext( + std::unique_ptr /*context*/) {} + + // If true, our framer will change its expected connection ID length + // to the received destination connection ID length of all IETF long headers. + void SetShouldUpdateExpectedServerConnectionIdLength( + bool should_update_expected_server_connection_id_length) { + should_update_expected_server_connection_id_length_ = + should_update_expected_server_connection_id_length; + } + + // If true, the dispatcher will allow incoming initial packets that have + // destination connection IDs shorter than 64 bits. + void SetAllowShortInitialServerConnectionIds( + bool allow_short_initial_server_connection_ids) { + allow_short_initial_server_connection_ids_ = + allow_short_initial_server_connection_ids; + } + + // Called if a packet from an unseen connection is reset or rejected. + virtual void OnNewConnectionRejected() {} + + // Selects the preferred ALPN from a vector of ALPNs. + // This runs through the list of ALPNs provided by the client and picks the + // first one it supports. If no supported versions are found, the first + // element of the vector is returned. + std::string SelectAlpn(const std::vector& alpns); + + // Sends public/stateless reset packets with no version and unknown + // connection ID according to the packet's size. + virtual void MaybeResetPacketsWithNoVersion( + const quic::ReceivedPacketInfo& packet_info); + + // Called on packets with unsupported versions. + virtual void MaybeSendVersionNegotiationPacket( + const ReceivedPacketInfo& packet_info); + + ConnectionIdGeneratorInterface& connection_id_generator() { + return connection_id_generator_; + } + + private: + friend class test::QuicDispatcherPeer; + + // TODO(fayang): Consider to rename this function to + // ProcessValidatedPacketWithUnknownConnectionId. + void ProcessHeader(ReceivedPacketInfo* packet_info); + + struct ExtractChloResult { + // If set, a full client hello has been successfully parsed. + absl::optional parsed_chlo; + // If set, the TLS alert that will cause a connection close. + // Always empty for Google QUIC. + absl::optional tls_alert; + }; + + // Try to extract information(sni, alpns, ...) if the full Client Hello has + // been parsed. + // + // Returns the parsed client hello in ExtractChloResult.parsed_chlo, if the + // full Client Hello has been successfully parsed. + // + // Returns the TLS alert in ExtractChloResult.tls_alert, if the extraction of + // Client Hello failed due to that alert. + // + // Otherwise returns a default-constructed ExtractChloResult and either buffer + // or (rarely) drop the packet. + ExtractChloResult TryExtractChloOrBufferEarlyPacket( + const ReceivedPacketInfo& packet_info); + + // Deliver |packets| to |session| for further processing. + void DeliverPacketsToSession( + const std::list& packets, + QuicSession* session); + + // Returns true if |version| is a supported protocol version. + bool IsSupportedVersion(const ParsedQuicVersion version); + + // Returns true if a server connection ID length is below all the minima + // required by various parameters. + bool IsServerConnectionIdTooShort(QuicConnectionId connection_id) const; + + // Core CHLO processing logic. + std::shared_ptr CreateSessionFromChlo( + const QuicConnectionId original_connection_id, + const ParsedClientHello& parsed_chlo, const ParsedQuicVersion version, + const QuicSocketAddress self_address, + const QuicSocketAddress peer_address); + + const QuicConfig* config_; + + const QuicCryptoServerConfig* crypto_config_; + + // The cache for most recently compressed certs. + QuicCompressedCertsCache compressed_certs_cache_; + + // The list of connections waiting to write. + WriteBlockedList write_blocked_list_; + + ReferenceCountedSessionMap reference_counted_session_map_; + + // Entity that manages connection_ids in time wait state. + std::unique_ptr time_wait_list_manager_; + + // The list of closed but not-yet-deleted sessions. + std::vector> closed_session_list_; + + // The helper used for all connections. + std::unique_ptr helper_; + + // The helper used for all sessions. + std::unique_ptr session_helper_; + + // Creates alarms. + std::unique_ptr alarm_factory_; + + // An alarm which deletes closed sessions. + std::unique_ptr delete_sessions_alarm_; + + // The writer to write to the socket with. + std::unique_ptr writer_; + + // Packets which are buffered until a connection can be created to handle + // them. + QuicBufferedPacketStore buffered_packets_; + + // Used to get the supported versions based on flag. Does not own. + QuicVersionManager* version_manager_; + + // The last error set by SetLastError(). + // TODO(fayang): consider removing last_error_. + QuicErrorCode last_error_; + + // Number of unique session in session map. + size_t num_sessions_in_session_map_ = 0; + + // A backward counter of how many new sessions can be create within current + // event loop. When reaches 0, it means can't create sessions for now. + int16_t new_sessions_allowed_per_event_loop_; + + // True if this dispatcher is accepting new ConnectionIds (new client + // connections), false otherwise. + bool accept_new_connections_; + + // If false, the dispatcher follows the IETF spec and rejects packets with + // invalid destination connection IDs lengths below 64 bits. + // If true they are allowed. + bool allow_short_initial_server_connection_ids_; + + // IETF short headers contain a destination connection ID but do not + // encode its length. This variable contains the length we expect to read. + // This is also used to signal an error when a long header packet with + // different destination connection ID length is received when + // should_update_expected_server_connection_id_length_ is false and packet's + // version does not allow variable length connection ID. + uint8_t expected_server_connection_id_length_; + + // Records client addresses that have been recently reset. + absl::flat_hash_set + recent_stateless_reset_addresses_; + + // An alarm which clear recent_stateless_reset_addresses_. + std::unique_ptr clear_stateless_reset_addresses_alarm_; + + // If true, change expected_server_connection_id_length_ to be the received + // destination connection ID length of all IETF long headers. + bool should_update_expected_server_connection_id_length_; + + ConnectionIdGeneratorInterface& connection_id_generator_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_DISPATCHER_H_ diff --git a/quiche/quic/core/quic_dispatcher_test.cc b/quiche/quic/core/quic_dispatcher_test.cc new file mode 100644 index 000000000000..12a62bb1dd0f --- /dev/null +++ b/quiche/quic/core/quic_dispatcher_test.cc @@ -0,0 +1,3003 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_dispatcher.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/chlo_extractor.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_packet_writer_wrapper.h" +#include "quiche/quic/core/quic_time_wait_list_manager.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/first_flight.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/mock_quic_time_wait_list_manager.h" +#include "quiche/quic/test_tools/quic_buffered_packet_store_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_dispatcher_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +using testing::_; +using testing::ByMove; +using testing::Eq; +using testing::InSequence; +using testing::Invoke; +using testing::NiceMock; +using testing::Return; +using testing::WithArg; +using testing::WithoutArgs; + +static const size_t kDefaultMaxConnectionsInStore = 100; +static const size_t kMaxConnectionsWithoutCHLO = + kDefaultMaxConnectionsInStore / 2; +static const int16_t kMaxNumSessionsToCreate = 16; + +namespace quic { +namespace test { +namespace { + +const QuicConnectionId kReturnConnectionId{ + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}}; + +class TestQuicSpdyServerSession : public QuicServerSessionBase { + public: + TestQuicSpdyServerSession(const QuicConfig& config, + QuicConnection* connection, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) + : QuicServerSessionBase(config, CurrentSupportedVersions(), connection, + nullptr, nullptr, crypto_config, + compressed_certs_cache) { + Initialize(); + } + TestQuicSpdyServerSession(const TestQuicSpdyServerSession&) = delete; + TestQuicSpdyServerSession& operator=(const TestQuicSpdyServerSession&) = + delete; + + ~TestQuicSpdyServerSession() override { DeleteConnection(); } + + MOCK_METHOD(void, OnConnectionClosed, + (const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), + (override)); + + std::unique_ptr CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) override { + return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this, + stream_helper()); + } + + QuicCryptoServerStreamBase::Helper* stream_helper() { + return QuicServerSessionBase::stream_helper(); + } +}; + +class TestDispatcher : public QuicDispatcher { + public: + TestDispatcher(const QuicConfig* config, + const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, QuicRandom* random, + ConnectionIdGeneratorInterface& generator) + : QuicDispatcher(config, crypto_config, version_manager, + std::make_unique(), + std::unique_ptr( + new QuicSimpleCryptoServerStreamHelper()), + std::make_unique(), + kQuicDefaultConnectionIdLength, generator), + random_(random) {} + + MOCK_METHOD(std::unique_ptr, CreateQuicSession, + (QuicConnectionId connection_id, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const ParsedClientHello& parsed_chlo), + (override)); + + struct TestQuicPerPacketContext : public QuicPerPacketContext { + std::string custom_packet_context; + }; + + std::unique_ptr GetPerPacketContext() const override { + auto test_context = std::make_unique(); + test_context->custom_packet_context = custom_packet_context_; + return std::move(test_context); + } + + void RestorePerPacketContext( + std::unique_ptr context) override { + TestQuicPerPacketContext* test_context = + static_cast(context.get()); + custom_packet_context_ = test_context->custom_packet_context; + } + + std::string custom_packet_context_; + + using QuicDispatcher::MaybeDispatchPacket; + using QuicDispatcher::SetAllowShortInitialServerConnectionIds; + using QuicDispatcher::writer; + + QuicRandom* random_; +}; + +// A Connection class which unregisters the session from the dispatcher when +// sending connection close. +// It'd be slightly more realistic to do this from the Session but it would +// involve a lot more mocking. +class MockServerConnection : public MockQuicConnection { + public: + MockServerConnection(QuicConnectionId connection_id, + MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory, + QuicDispatcher* dispatcher) + : MockQuicConnection(connection_id, helper, alarm_factory, + Perspective::IS_SERVER), + dispatcher_(dispatcher), + active_connection_ids_({connection_id}) {} + + void AddNewConnectionId(QuicConnectionId id) { + if (!dispatcher_->TryAddNewConnectionId(active_connection_ids_.back(), + id)) { + return; + } + QuicConnectionPeer::SetServerConnectionId(this, id); + active_connection_ids_.push_back(id); + } + + void UnconditionallyAddNewConnectionIdForTest(QuicConnectionId id) { + dispatcher_->TryAddNewConnectionId(active_connection_ids_.back(), id); + active_connection_ids_.push_back(id); + } + + void RetireConnectionId(QuicConnectionId id) { + auto it = std::find(active_connection_ids_.begin(), + active_connection_ids_.end(), id); + QUICHE_DCHECK(it != active_connection_ids_.end()); + dispatcher_->OnConnectionIdRetired(id); + active_connection_ids_.erase(it); + } + + std::vector GetActiveServerConnectionIds() const override { + std::vector result; + for (const auto& cid : active_connection_ids_) { + result.push_back(cid); + } + auto original_connection_id = GetOriginalDestinationConnectionId(); + if (std::find(result.begin(), result.end(), original_connection_id) == + result.end()) { + result.push_back(original_connection_id); + } + return result; + } + + void UnregisterOnConnectionClosed() { + QUIC_LOG(ERROR) << "Unregistering " << connection_id(); + dispatcher_->OnConnectionClosed(connection_id(), QUIC_NO_ERROR, + "Unregistering.", + ConnectionCloseSource::FROM_SELF); + } + + private: + QuicDispatcher* dispatcher_; + std::vector active_connection_ids_; +}; + +class QuicDispatcherTestBase : public QuicTestWithParam { + public: + QuicDispatcherTestBase() + : QuicDispatcherTestBase(crypto_test_utils::ProofSourceForTesting()) {} + + explicit QuicDispatcherTestBase(std::unique_ptr proof_source) + : version_(GetParam()), + version_manager_(AllSupportedVersions()), + crypto_config_(QuicCryptoServerConfig::TESTING, + QuicRandom::GetInstance(), std::move(proof_source), + KeyExchangeSource::Default()), + server_address_(QuicIpAddress::Any4(), 5), + dispatcher_(new NiceMock( + &config_, &crypto_config_, &version_manager_, + mock_helper_.GetRandomGenerator(), connection_id_generator_)), + time_wait_list_manager_(nullptr), + session1_(nullptr), + session2_(nullptr), + store_(nullptr), + connection_id_(1) {} + + void SetUp() override { + dispatcher_->InitializeWithWriter(new NiceMock()); + // Set the counter to some value to start with. + QuicDispatcherPeer::set_new_sessions_allowed_per_event_loop( + dispatcher_.get(), kMaxNumSessionsToCreate); + } + + MockQuicConnection* connection1() { + if (session1_ == nullptr) { + return nullptr; + } + return reinterpret_cast(session1_->connection()); + } + + MockQuicConnection* connection2() { + if (session2_ == nullptr) { + return nullptr; + } + return reinterpret_cast(session2_->connection()); + } + + // Process a packet with an 8 byte connection id, + // 6 byte packet number, default path id, and packet number 1, + // using the version under test. + void ProcessPacket(QuicSocketAddress peer_address, + QuicConnectionId server_connection_id, + bool has_version_flag, const std::string& data) { + ProcessPacket(peer_address, server_connection_id, has_version_flag, data, + CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER); + } + + // Process a packet with a default path id, and packet number 1, + // using the version under test. + void ProcessPacket(QuicSocketAddress peer_address, + QuicConnectionId server_connection_id, + bool has_version_flag, const std::string& data, + QuicConnectionIdIncluded server_connection_id_included, + QuicPacketNumberLength packet_number_length) { + ProcessPacket(peer_address, server_connection_id, has_version_flag, data, + server_connection_id_included, packet_number_length, 1); + } + + // Process a packet using the version under test. + void ProcessPacket(QuicSocketAddress peer_address, + QuicConnectionId server_connection_id, + bool has_version_flag, const std::string& data, + QuicConnectionIdIncluded server_connection_id_included, + QuicPacketNumberLength packet_number_length, + uint64_t packet_number) { + ProcessPacket(peer_address, server_connection_id, has_version_flag, + version_, data, true, server_connection_id_included, + packet_number_length, packet_number); + } + + // Processes a packet. + void ProcessPacket(QuicSocketAddress peer_address, + QuicConnectionId server_connection_id, + bool has_version_flag, ParsedQuicVersion version, + const std::string& data, bool full_padding, + QuicConnectionIdIncluded server_connection_id_included, + QuicPacketNumberLength packet_number_length, + uint64_t packet_number) { + ProcessPacket(peer_address, server_connection_id, EmptyQuicConnectionId(), + has_version_flag, version, data, full_padding, + server_connection_id_included, CONNECTION_ID_ABSENT, + packet_number_length, packet_number); + } + + // Processes a packet. + void ProcessPacket(QuicSocketAddress peer_address, + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, + bool has_version_flag, ParsedQuicVersion version, + const std::string& data, bool full_padding, + QuicConnectionIdIncluded server_connection_id_included, + QuicConnectionIdIncluded client_connection_id_included, + QuicPacketNumberLength packet_number_length, + uint64_t packet_number) { + ParsedQuicVersionVector versions(SupportedVersions(version)); + std::unique_ptr packet(ConstructEncryptedPacket( + server_connection_id, client_connection_id, has_version_flag, false, + packet_number, data, full_padding, server_connection_id_included, + client_connection_id_included, packet_number_length, &versions)); + std::unique_ptr received_packet( + ConstructReceivedPacket(*packet, mock_helper_.GetClock()->Now())); + // Call ConnectionIdLength if the packet clears the Long Header bit, or + // if the test involves sending a connection ID that is too short + if (!has_version_flag || !version.AllowsVariableLengthConnectionIds() || + server_connection_id.length() == 0 || + server_connection_id_included == CONNECTION_ID_ABSENT) { + // Short headers will ask for the length + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(_)) + .WillRepeatedly(Return(generated_connection_id_.has_value() + ? generated_connection_id_->length() + : kQuicDefaultConnectionIdLength)); + } + ProcessReceivedPacket(std::move(received_packet), peer_address, version, + server_connection_id); + } + + void ProcessReceivedPacket( + std::unique_ptr received_packet, + const QuicSocketAddress& peer_address, const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id) { + if (version.UsesQuicCrypto() && + ChloExtractor::Extract(*received_packet, version, {}, nullptr, + server_connection_id.length())) { + // Add CHLO packet to the beginning to be verified first, because it is + // also processed first by new session. + data_connection_map_[server_connection_id].push_front( + std::string(received_packet->data(), received_packet->length())); + } else { + // For non-CHLO, always append to last. + data_connection_map_[server_connection_id].push_back( + std::string(received_packet->data(), received_packet->length())); + } + dispatcher_->ProcessPacket(server_address_, peer_address, *received_packet); + } + + void ValidatePacket(QuicConnectionId conn_id, + const QuicEncryptedPacket& packet) { + EXPECT_EQ(data_connection_map_[conn_id].front().length(), + packet.AsStringPiece().length()); + EXPECT_EQ(data_connection_map_[conn_id].front(), packet.AsStringPiece()); + data_connection_map_[conn_id].pop_front(); + } + + std::unique_ptr CreateSession( + TestDispatcher* dispatcher, const QuicConfig& config, + QuicConnectionId connection_id, const QuicSocketAddress& /*peer_address*/, + MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + TestQuicSpdyServerSession** session_ptr) { + MockServerConnection* connection = new MockServerConnection( + connection_id, helper, alarm_factory, dispatcher); + connection->SetQuicPacketWriter(dispatcher->writer(), + /*owns_writer=*/false); + auto session = std::make_unique( + config, connection, crypto_config, compressed_certs_cache); + *session_ptr = session.get(); + connection->set_visitor(session.get()); + ON_CALL(*connection, CloseConnection(_, _, _)) + .WillByDefault(WithoutArgs(Invoke( + connection, &MockServerConnection::UnregisterOnConnectionClosed))); + return session; + } + + void CreateTimeWaitListManager() { + time_wait_list_manager_ = new MockTimeWaitListManager( + QuicDispatcherPeer::GetWriter(dispatcher_.get()), dispatcher_.get(), + mock_helper_.GetClock(), &mock_alarm_factory_); + // dispatcher_ takes the ownership of time_wait_list_manager_. + QuicDispatcherPeer::SetTimeWaitListManager(dispatcher_.get(), + time_wait_list_manager_); + } + + std::string SerializeCHLO() { + CryptoHandshakeMessage client_hello; + client_hello.set_tag(kCHLO); + client_hello.SetStringPiece(kALPN, ExpectedAlpn()); + return std::string(client_hello.GetSerialized().AsStringPiece()); + } + + void ProcessUndecryptableEarlyPacket( + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessUndecryptableEarlyPacket(version_, peer_address, + server_connection_id); + } + + void ProcessUndecryptableEarlyPacket( + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + std::unique_ptr encrypted_packet = + GetUndecryptableEarlyPacket(version, server_connection_id); + std::unique_ptr received_packet(ConstructReceivedPacket( + *encrypted_packet, mock_helper_.GetClock()->Now())); + ProcessReceivedPacket(std::move(received_packet), peer_address, version, + server_connection_id); + } + + void ProcessFirstFlight(const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(version_, peer_address, server_connection_id); + } + + void ProcessFirstFlight(const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(version, peer_address, server_connection_id, + EmptyQuicConnectionId()); + } + + void ProcessFirstFlight(const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) { + ProcessFirstFlight(version, peer_address, server_connection_id, + client_connection_id, TestClientCryptoConfig()); + } + + void ProcessFirstFlight( + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr client_crypto_config) { + if (expect_generator_is_called_) { + if (version.AllowsVariableLengthConnectionIds()) { + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(server_connection_id, version)) + .WillOnce(Return(generated_connection_id_)); + } else { + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(server_connection_id, version)) + .WillOnce(Return(absl::nullopt)); + } + } + std::vector> packets = + GetFirstFlightOfPackets(version, DefaultQuicConfig(), + server_connection_id, client_connection_id, + std::move(client_crypto_config)); + for (auto&& packet : packets) { + ProcessReceivedPacket(std::move(packet), peer_address, version, + server_connection_id); + } + } + + std::unique_ptr TestClientCryptoConfig() { + auto client_crypto_config = std::make_unique( + crypto_test_utils::ProofVerifierForTesting()); + if (address_token_.has_value()) { + client_crypto_config->LookupOrCreate(TestServerId()) + ->set_source_address_token(*address_token_); + } + return client_crypto_config; + } + + // If called, the first flight packets generated in |ProcessFirstFlight| will + // contain the given |address_token|. + void SetAddressToken(std::string address_token) { + address_token_ = std::move(address_token); + } + + std::string ExpectedAlpnForVersion(ParsedQuicVersion version) { + return AlpnForVersion(version); + } + + std::string ExpectedAlpn() { return ExpectedAlpnForVersion(version_); } + + ParsedClientHello ParsedClientHelloForTest() { + ParsedClientHello parsed_chlo; + parsed_chlo.alpns = {ExpectedAlpn()}; + parsed_chlo.sni = TestHostname(); + return parsed_chlo; + } + + void MarkSession1Deleted() { session1_ = nullptr; } + + void VerifyVersionSupported(ParsedQuicVersion version) { + expect_generator_is_called_ = true; + QuicConnectionId connection_id = TestConnectionId(++connection_id_); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(connection_id, _, client_address, + Eq(ExpectedAlpnForVersion(version)), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, connection_id, client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, connection_id](const QuicEncryptedPacket& packet) { + ValidatePacket(connection_id, packet); + }))); + ProcessFirstFlight(version, client_address, connection_id); + } + + void VerifyVersionNotSupported(ParsedQuicVersion version) { + QuicConnectionId connection_id = TestConnectionId(++connection_id_); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(connection_id, _, client_address, _, _, _)) + .Times(0); + expect_generator_is_called_ = false; + ProcessFirstFlight(version, client_address, connection_id); + } + + void TestTlsMultiPacketClientHello(bool add_reordering, + bool long_connection_id); + + void TestVersionNegotiationForUnknownVersionInvalidShortInitialConnectionId( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id); + + TestAlarmFactory::TestAlarm* GetClearResetAddressesAlarm() { + return reinterpret_cast( + QuicDispatcherPeer::GetClearResetAddressesAlarm(dispatcher_.get())); + } + + ParsedQuicVersion version_; + MockQuicConnectionHelper mock_helper_; + MockAlarmFactory mock_alarm_factory_; + QuicConfig config_; + QuicVersionManager version_manager_; + QuicCryptoServerConfig crypto_config_; + QuicSocketAddress server_address_; + // Set to false if the dispatcher won't create a session. + bool expect_generator_is_called_ = true; + // Set in conditions where the generator should return a different connection + // ID. + absl::optional generated_connection_id_; + MockConnectionIdGenerator connection_id_generator_; + std::unique_ptr> dispatcher_; + MockTimeWaitListManager* time_wait_list_manager_; + TestQuicSpdyServerSession* session1_; + TestQuicSpdyServerSession* session2_; + std::map> data_connection_map_; + QuicBufferedPacketStore* store_; + uint64_t connection_id_; + absl::optional address_token_; +}; + +class QuicDispatcherTestAllVersions : public QuicDispatcherTestBase {}; +class QuicDispatcherTestOneVersion : public QuicDispatcherTestBase {}; + +INSTANTIATE_TEST_SUITE_P(QuicDispatcherTestsAllVersions, + QuicDispatcherTestAllVersions, + ::testing::ValuesIn(CurrentSupportedVersions()), + ::testing::PrintToStringParamName()); + +INSTANTIATE_TEST_SUITE_P(QuicDispatcherTestsOneVersion, + QuicDispatcherTestOneVersion, + ::testing::Values(CurrentSupportedVersions().front()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicDispatcherTestAllVersions, TlsClientHelloCreatesSession) { + if (version_.UsesQuicCrypto()) { + return; + } + SetAddressToken("hsdifghdsaifnasdpfjdsk"); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + + ProcessFirstFlight(client_address, TestConnectionId(1)); +} + +TEST_P(QuicDispatcherTestAllVersions, VariableServerConnectionIdLength) { + QuicConnectionId old_id = TestConnectionId(1); + // Return a connection ID that is not expected_server_connection_id_length_ + // bytes long. + if (version_.HasIetfQuicFrames()) { + generated_connection_id_ = + QuicConnectionId({0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b}); + } + QuicConnectionId new_id = + generated_connection_id_.has_value() ? *generated_connection_id_ : old_id; + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(new_id, _, client_address, Eq(ExpectedAlpn()), + _, Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, new_id, client_address, &mock_helper_, + &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, old_id); + + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(1); + ProcessPacket(client_address, new_id, false, "foo"); +} + +void QuicDispatcherTestBase::TestTlsMultiPacketClientHello( + bool add_reordering, bool long_connection_id) { + if (!version_.UsesTls()) { + return; + } + SetAddressToken("857293462398"); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicConnectionId original_connection_id, new_connection_id; + if (long_connection_id) { + original_connection_id = TestConnectionIdNineBytesLong(1); + new_connection_id = kReturnConnectionId; + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(original_connection_id, version_)) + .WillOnce(Return(new_connection_id)); + + } else { + original_connection_id = TestConnectionId(); + new_connection_id = original_connection_id; + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(original_connection_id, version_)) + .WillOnce(Return(absl::nullopt)); + } + QuicConfig client_config = DefaultQuicConfig(); + // Add a 2000-byte custom parameter to increase the length of the CHLO. + constexpr auto kCustomParameterId = + static_cast(0xff33); + std::string kCustomParameterValue(2000, '-'); + client_config.custom_transport_parameters_to_send()[kCustomParameterId] = + kCustomParameterValue; + std::vector> packets = + GetFirstFlightOfPackets(version_, client_config, original_connection_id, + EmptyQuicConnectionId(), + TestClientCryptoConfig()); + ASSERT_EQ(packets.size(), 2u); + if (add_reordering) { + std::swap(packets[0], packets[1]); + } + + // Processing the first packet should not create a new session. + ProcessReceivedPacket(std::move(packets[0]), client_address, version_, + original_connection_id); + + EXPECT_EQ(dispatcher_->NumSessions(), 0u) + << "No session should be created before the rest of the CHLO arrives."; + + // Processing the second packet should create the new session. + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(new_connection_id, _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, new_connection_id, client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(2); + + ProcessReceivedPacket(std::move(packets[1]), client_address, version_, + original_connection_id); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); +} + +TEST_P(QuicDispatcherTestAllVersions, TlsMultiPacketClientHello) { + TestTlsMultiPacketClientHello(/*add_reordering=*/false, + /*long_connection_id=*/false); +} + +TEST_P(QuicDispatcherTestAllVersions, TlsMultiPacketClientHelloWithReordering) { + TestTlsMultiPacketClientHello(/*add_reordering=*/true, + /*long_connection_id=*/false); +} + +TEST_P(QuicDispatcherTestAllVersions, TlsMultiPacketClientHelloWithLongId) { + TestTlsMultiPacketClientHello(/*add_reordering=*/false, + /*long_connection_id=*/true); +} + +TEST_P(QuicDispatcherTestAllVersions, + TlsMultiPacketClientHelloWithReorderingAndLongId) { + TestTlsMultiPacketClientHello(/*add_reordering=*/true, + /*long_connection_id=*/true); +} + +TEST_P(QuicDispatcherTestAllVersions, ProcessPackets) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(2), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(2), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_)))); + EXPECT_CALL(*reinterpret_cast(session2_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(2), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(2)); + + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(1) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessPacket(client_address, TestConnectionId(1), false, "data"); +} + +// Regression test of b/93325907. +TEST_P(QuicDispatcherTestAllVersions, DispatcherDoesNotRejectPacketNumberZero) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + // Verify both packets 1 and 2 are processed by connection 1. + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(2) + .WillRepeatedly( + WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + // Packet number 256 with packet number length 1 would be considered as 0 in + // dispatcher. + ProcessPacket(client_address, TestConnectionId(1), false, version_, "", true, + CONNECTION_ID_PRESENT, PACKET_1BYTE_PACKET_NUMBER, 256); +} + +TEST_P(QuicDispatcherTestOneVersion, StatelessVersionNegotiation) { + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(TestConnectionId(1), _, _, _, _, _, _, _)) + .Times(1); + expect_generator_is_called_ = false; + ProcessFirstFlight(QuicVersionReservedForNegotiation(), client_address, + TestConnectionId(1)); +} + +TEST_P(QuicDispatcherTestOneVersion, + StatelessVersionNegotiationWithVeryLongConnectionId) { + QuicConnectionId connection_id = QuicUtils::CreateRandomConnectionId(33); + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, + SendVersionNegotiationPacket(connection_id, _, _, _, _, _, _, _)) + .Times(1); + expect_generator_is_called_ = false; + ProcessFirstFlight(QuicVersionReservedForNegotiation(), client_address, + connection_id); +} + +TEST_P(QuicDispatcherTestOneVersion, + StatelessVersionNegotiationWithClientConnectionId) { + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, + SendVersionNegotiationPacket( + TestConnectionId(1), TestConnectionId(2), _, _, _, _, _, _)) + .Times(1); + expect_generator_is_called_ = false; + ProcessFirstFlight(QuicVersionReservedForNegotiation(), client_address, + TestConnectionId(1), TestConnectionId(2)); +} + +TEST_P(QuicDispatcherTestOneVersion, NoVersionNegotiationWithSmallPacket) { + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, _, _, _, _, _, _)) + .Times(0); + std::string chlo = SerializeCHLO() + std::string(1200, 'a'); + // Truncate to 1100 bytes of payload which results in a packet just + // under 1200 bytes after framing, packet, and encryption overhead. + QUICHE_DCHECK_LE(1200u, chlo.length()); + std::string truncated_chlo = chlo.substr(0, 1100); + QUICHE_DCHECK_EQ(1100u, truncated_chlo.length()); + ProcessPacket(client_address, TestConnectionId(1), true, + QuicVersionReservedForNegotiation(), truncated_chlo, false, + CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); +} + +// Disabling CHLO size validation allows the dispatcher to send version +// negotiation packets in response to a CHLO that is otherwise too small. +TEST_P(QuicDispatcherTestOneVersion, + VersionNegotiationWithoutChloSizeValidation) { + crypto_config_.set_validate_chlo_size(false); + + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, _, _, _, _, _, _)) + .Times(1); + std::string chlo = SerializeCHLO() + std::string(1200, 'a'); + // Truncate to 1100 bytes of payload which results in a packet just + // under 1200 bytes after framing, packet, and encryption overhead. + QUICHE_DCHECK_LE(1200u, chlo.length()); + std::string truncated_chlo = chlo.substr(0, 1100); + QUICHE_DCHECK_EQ(1100u, truncated_chlo.length()); + ProcessPacket(client_address, TestConnectionId(1), true, + QuicVersionReservedForNegotiation(), truncated_chlo, true, + CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); +} + +TEST_P(QuicDispatcherTestAllVersions, Shutdown) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, + CreateQuicSession(_, _, client_address, Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + + ProcessFirstFlight(client_address, TestConnectionId(1)); + + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherTestAllVersions, TimeWaitListManager) { + CreateTimeWaitListManager(); + + // Create a new session. + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + EXPECT_CALL(*dispatcher_, CreateQuicSession(connection_id, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, connection_id, client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + + ProcessFirstFlight(client_address, connection_id); + + // Now close the connection, which should add it to the time wait list. + session1_->connection()->CloseConnection( + QUIC_INVALID_VERSION, + "Server: Packet 2 without version flag before version negotiated.", + ConnectionCloseBehavior::SILENT_CLOSE); + EXPECT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id)); + + // Dispatcher forwards subsequent packets for this connection_id to the time + // wait list manager. + EXPECT_CALL(*time_wait_list_manager_, + ProcessPacket(_, _, connection_id, _, _, _)) + .Times(1); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + ProcessPacket(client_address, connection_id, true, "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, NoVersionPacketToTimeWaitListManager) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + // Dispatcher forwards all packets for this connection_id to the time wait + // list manager. + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, + ProcessPacket(_, _, connection_id, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(1); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, + DonotTimeWaitPacketsWithUnknownConnectionIdAndNoVersion) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + + uint8_t short_packet[22] = {0x70, 0xa7, 0x02, 0x6b}; + uint8_t valid_size_packet[23] = {0x70, 0xa7, 0x02, 0x6c}; + size_t short_packet_len; + if (version_.HasIetfInvariantHeader()) { + short_packet_len = 21; + } else { + short_packet_len = 22; + short_packet[0] = 0x0a; + valid_size_packet[0] = 0x0a; + } + QuicReceivedPacket packet(reinterpret_cast(short_packet), + short_packet_len, QuicTime::Zero()); + QuicReceivedPacket packet2(reinterpret_cast(valid_size_packet), + short_packet_len + 1, QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + // Verify small packet is silently dropped. + if (version_.HasIetfInvariantHeader()) { + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(0xa7)) + .WillOnce(Return(kQuicDefaultConnectionIdLength)); + } else { + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(_)).Times(0); + } + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(0); + dispatcher_->ProcessPacket(server_address_, client_address, packet); + if (version_.HasIetfInvariantHeader()) { + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(0xa7)) + .WillOnce(Return(kQuicDefaultConnectionIdLength)); + } else { + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(_)).Times(0); + } + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, packet2); +} + +TEST_P(QuicDispatcherTestOneVersion, DropPacketWithInvalidFlags) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t all_zero_packet[1200] = {}; + QuicReceivedPacket packet(reinterpret_cast(all_zero_packet), + sizeof(all_zero_packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(_)) + .WillOnce(Return(kQuicDefaultConnectionIdLength)); + dispatcher_->ProcessPacket(server_address_, client_address, packet); +} + +TEST_P(QuicDispatcherTestAllVersions, LimitResetsToSameClientAddress) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress client_address2(QuicIpAddress::Loopback4(), 2); + QuicSocketAddress client_address3(QuicIpAddress::Loopback6(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + + // Verify only one reset is sent to the address, although multiple packets + // are received. + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(1); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data2"); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data3"); + + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(2); + ProcessPacket(client_address2, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address3, connection_id, /*has_version_flag=*/false, + "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, + StopSendingResetOnTooManyRecentAddresses) { + SetQuicFlag(quic_max_recent_stateless_reset_addresses, 2); + const size_t kTestLifeTimeMs = 10; + SetQuicFlag(quic_recent_stateless_reset_addresses_lifetime_ms, + kTestLifeTimeMs); + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress client_address2(QuicIpAddress::Loopback4(), 2); + QuicSocketAddress client_address3(QuicIpAddress::Loopback6(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(2); + EXPECT_FALSE(GetClearResetAddressesAlarm()->IsSet()); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); + const QuicTime expected_deadline = + mock_helper_.GetClock()->Now() + + QuicTime::Delta::FromMilliseconds(kTestLifeTimeMs); + ASSERT_TRUE(GetClearResetAddressesAlarm()->IsSet()); + EXPECT_EQ(expected_deadline, GetClearResetAddressesAlarm()->deadline()); + // Received no version packet 2 after 5ms. + mock_helper_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + ProcessPacket(client_address2, connection_id, /*has_version_flag=*/false, + "data"); + ASSERT_TRUE(GetClearResetAddressesAlarm()->IsSet()); + // Verify deadline does not change. + EXPECT_EQ(expected_deadline, GetClearResetAddressesAlarm()->deadline()); + // Verify reset gets throttled since there are too many recent addresses. + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(0); + ProcessPacket(client_address3, connection_id, /*has_version_flag=*/false, + "data"); + + mock_helper_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + GetClearResetAddressesAlarm()->Fire(); + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(2); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address2, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address3, connection_id, /*has_version_flag=*/false, + "data"); +} + +// Makes sure nine-byte connection IDs are replaced by 8-byte ones. +TEST_P(QuicDispatcherTestAllVersions, LongConnectionIdLengthReplaced) { + if (!version_.AllowsVariableLengthConnectionIds()) { + // When variable length connection IDs are not supported, the connection + // fails. See StrayPacketTruncatedConnectionId. + return; + } + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + QuicConnectionId bad_connection_id = TestConnectionIdNineBytesLong(2); + generated_connection_id_ = kReturnConnectionId; + + EXPECT_CALL(*dispatcher_, + CreateQuicSession(*generated_connection_id_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, *generated_connection_id_, client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, bad_connection_id](const QuicEncryptedPacket& packet) { + ValidatePacket(bad_connection_id, packet); + }))); + ProcessFirstFlight(client_address, bad_connection_id); +} + +// Makes sure zero-byte connection IDs are replaced by 8-byte ones. +TEST_P(QuicDispatcherTestAllVersions, InvalidShortConnectionIdLengthReplaced) { + if (!version_.AllowsVariableLengthConnectionIds()) { + // When variable length connection IDs are not supported, the connection + // fails. See StrayPacketTruncatedConnectionId. + return; + } + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + QuicConnectionId bad_connection_id = EmptyQuicConnectionId(); + generated_connection_id_ = kReturnConnectionId; + + // Disable validation of invalid short connection IDs. + dispatcher_->SetAllowShortInitialServerConnectionIds(true); + // Note that StrayPacketTruncatedConnectionId covers the case where the + // validation is still enabled. + EXPECT_CALL(*dispatcher_, + CreateQuicSession(*generated_connection_id_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, *generated_connection_id_, client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, bad_connection_id](const QuicEncryptedPacket& packet) { + ValidatePacket(bad_connection_id, packet); + }))); + ProcessFirstFlight(client_address, bad_connection_id); +} + +// Makes sure TestConnectionId(1) creates a new connection and +// TestConnectionIdNineBytesLong(2) gets replaced. +TEST_P(QuicDispatcherTestAllVersions, MixGoodAndBadConnectionIdLengthPackets) { + if (!version_.AllowsVariableLengthConnectionIds()) { + return; + } + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicConnectionId bad_connection_id = TestConnectionIdNineBytesLong(2); + + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + + generated_connection_id_ = kReturnConnectionId; + EXPECT_CALL(*dispatcher_, + CreateQuicSession(*generated_connection_id_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, *generated_connection_id_, client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_)))); + EXPECT_CALL(*reinterpret_cast(session2_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, bad_connection_id](const QuicEncryptedPacket& packet) { + ValidatePacket(bad_connection_id, packet); + }))); + ProcessFirstFlight(client_address, bad_connection_id); + + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(1) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessPacket(client_address, TestConnectionId(1), false, "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, ProcessPacketWithZeroPort) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 0); + + // dispatcher_ should drop this packet. + EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(1), _, + client_address, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, + "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, ProcessPacketWithBlockedPort) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 17); + + // dispatcher_ should drop this packet. + EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(1), _, + client_address, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, + "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, ProcessPacketWithNonBlockedPort) { + CreateTimeWaitListManager(); + + // Port 443 must not be blocked because it might be useful for proxies to send + // proxied traffic with source port 443 as that allows building a full QUIC + // proxy using a single UDP socket. + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 443); + + // dispatcher_ should not drop this packet. + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + ProcessFirstFlight(client_address, TestConnectionId(1)); +} + +TEST_P(QuicDispatcherTestAllVersions, + DropPacketWithKnownVersionAndInvalidShortInitialConnectionId) { + if (!version_.AllowsVariableLengthConnectionIds()) { + return; + } + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + // dispatcher_ should drop this packet. + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(0x00)) + .WillOnce(Return(10)); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + expect_generator_is_called_ = false; + ProcessFirstFlight(client_address, EmptyQuicConnectionId()); +} + +TEST_P(QuicDispatcherTestAllVersions, + DropPacketWithKnownVersionAndInvalidInitialConnectionId) { + CreateTimeWaitListManager(); + + QuicSocketAddress server_address; + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + // dispatcher_ should drop this packet with invalid connection ID. + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + absl::string_view cid_str = "123456789abcdefg123456789abcdefg"; + QuicConnectionId invalid_connection_id(cid_str.data(), cid_str.length()); + QuicReceivedPacket packet("packet", 6, QuicTime::Zero()); + ReceivedPacketInfo packet_info(server_address, client_address, packet); + packet_info.version_flag = true; + packet_info.version = version_; + packet_info.destination_connection_id = invalid_connection_id; + + ASSERT_TRUE(dispatcher_->MaybeDispatchPacket(packet_info)); +} + +void QuicDispatcherTestBase:: + TestVersionNegotiationForUnknownVersionInvalidShortInitialConnectionId( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, + SendVersionNegotiationPacket( + server_connection_id, client_connection_id, + /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, client_address, _)) + .Times(1); + expect_generator_is_called_ = false; + EXPECT_CALL(connection_id_generator_, ConnectionIdLength(_)).Times(0); + ProcessFirstFlight(ParsedQuicVersion::ReservedForNegotiation(), + client_address, server_connection_id, + client_connection_id); +} + +TEST_P(QuicDispatcherTestOneVersion, + VersionNegotiationForUnknownVersionInvalidShortInitialConnectionId) { + TestVersionNegotiationForUnknownVersionInvalidShortInitialConnectionId( + EmptyQuicConnectionId(), EmptyQuicConnectionId()); +} + +TEST_P(QuicDispatcherTestOneVersion, + VersionNegotiationForUnknownVersionInvalidShortInitialConnectionId2) { + char server_connection_id_bytes[3] = {1, 2, 3}; + QuicConnectionId server_connection_id(server_connection_id_bytes, + sizeof(server_connection_id_bytes)); + TestVersionNegotiationForUnknownVersionInvalidShortInitialConnectionId( + server_connection_id, EmptyQuicConnectionId()); +} + +TEST_P(QuicDispatcherTestOneVersion, + VersionNegotiationForUnknownVersionInvalidShortInitialConnectionId3) { + char client_connection_id_bytes[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + QuicConnectionId client_connection_id(client_connection_id_bytes, + sizeof(client_connection_id_bytes)); + TestVersionNegotiationForUnknownVersionInvalidShortInitialConnectionId( + EmptyQuicConnectionId(), client_connection_id); +} + +TEST_P(QuicDispatcherTestOneVersion, VersionsChangeInFlight) { + VerifyVersionNotSupported(QuicVersionReservedForNegotiation()); + for (ParsedQuicVersion version : CurrentSupportedVersions()) { + VerifyVersionSupported(version); + QuicDisableVersion(version); + VerifyVersionNotSupported(version); + QuicEnableVersion(version); + VerifyVersionSupported(version); + } +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionDraft28WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 0xFF, 0x00, 0x00, 28, /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionDraft27WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 0xFF, 0x00, 0x00, 27, /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionDraft25WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 0xFF, 0x00, 0x00, 25, /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionT050WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 'T', '0', '5', '0', /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionQ049WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 'Q', '0', '4', '9', /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionQ048WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 'Q', '0', '4', '8', /*connection ID length byte*/ 0x50}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/false, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionQ047WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 'Q', '0', '4', '7', /*connection ID length byte*/ 0x50}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/false, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionQ045WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xC0, 'Q', '0', '4', '5', /*connection ID length byte*/ 0x50}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/false, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionQ044WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet44[kMinPacketSizeForVersionNegotiation] = { + 0xFF, 'Q', '0', '4', '4', /*connection ID length byte*/ 0x50}; + QuicReceivedPacket received_packet44(reinterpret_cast(packet44), + kMinPacketSizeForVersionNegotiation, + QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/false, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, + received_packet44); +} + +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionT051WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xFF, 'T', '0', '5', '1', /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + kMinPacketSizeForVersionNegotiation, + QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +static_assert(quic::SupportedVersions().size() == 6u, + "Please add new RejectDeprecatedVersion tests above this assert " + "when deprecating versions"); + +TEST_P(QuicDispatcherTestOneVersion, VersionNegotiationProbe) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + char packet[1200]; + char destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + EXPECT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket( + packet, sizeof(packet), destination_connection_id_bytes, + sizeof(destination_connection_id_bytes))); + QuicEncryptedPacket encrypted(packet, sizeof(packet), false); + std::unique_ptr received_packet( + ConstructReceivedPacket(encrypted, mock_helper_.GetClock()->Now())); + QuicConnectionId client_connection_id = EmptyQuicConnectionId(); + QuicConnectionId server_connection_id( + destination_connection_id_bytes, sizeof(destination_connection_id_bytes)); + EXPECT_CALL(*time_wait_list_manager_, + SendVersionNegotiationPacket( + server_connection_id, client_connection_id, + /*ietf_quic=*/true, /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + + dispatcher_->ProcessPacket(server_address_, client_address, *received_packet); +} + +// Testing packet writer that saves all packets instead of sending them. +// Useful for tests that need access to sent packets. +class SavingWriter : public QuicPacketWriterWrapper { + public: + bool IsWriteBlocked() const override { return false; } + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& /*self_client_address*/, + const QuicSocketAddress& /*peer_client_address*/, + PerPacketOptions* /*options*/) override { + packets_.push_back( + QuicEncryptedPacket(buffer, buf_len, /*owns_buffer=*/false).Clone()); + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + std::vector>* packets() { + return &packets_; + } + + private: + std::vector> packets_; +}; + +TEST_P(QuicDispatcherTestOneVersion, VersionNegotiationProbeEndToEnd) { + SavingWriter* saving_writer = new SavingWriter(); + // dispatcher_ takes ownership of saving_writer. + QuicDispatcherPeer::UseWriter(dispatcher_.get(), saving_writer); + + QuicTimeWaitListManager* time_wait_list_manager = new QuicTimeWaitListManager( + saving_writer, dispatcher_.get(), mock_helper_.GetClock(), + &mock_alarm_factory_); + // dispatcher_ takes ownership of time_wait_list_manager. + QuicDispatcherPeer::SetTimeWaitListManager(dispatcher_.get(), + time_wait_list_manager); + char packet[1200] = {}; + char destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + EXPECT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket( + packet, sizeof(packet), destination_connection_id_bytes, + sizeof(destination_connection_id_bytes))); + QuicEncryptedPacket encrypted(packet, sizeof(packet), false); + std::unique_ptr received_packet( + ConstructReceivedPacket(encrypted, mock_helper_.GetClock()->Now())); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + dispatcher_->ProcessPacket(server_address_, client_address, *received_packet); + ASSERT_EQ(1u, saving_writer->packets()->size()); + + char source_connection_id_bytes[255] = {}; + uint8_t source_connection_id_length = sizeof(source_connection_id_bytes); + std::string detailed_error = "foobar"; + EXPECT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse( + (*(saving_writer->packets()))[0]->data(), + (*(saving_writer->packets()))[0]->length(), source_connection_id_bytes, + &source_connection_id_length, &detailed_error)); + EXPECT_EQ("", detailed_error); + + // The source connection ID of the probe response should match the + // destination connection ID of the probe request. + quiche::test::CompareCharArraysWithHexError( + "parsed probe", source_connection_id_bytes, source_connection_id_length, + destination_connection_id_bytes, sizeof(destination_connection_id_bytes)); +} + +TEST_P(QuicDispatcherTestOneVersion, AndroidConformanceTest) { + // WARNING: do not remove or modify this test without making sure that we + // still have adequate coverage for the Android conformance test. + SavingWriter* saving_writer = new SavingWriter(); + // dispatcher_ takes ownership of saving_writer. + QuicDispatcherPeer::UseWriter(dispatcher_.get(), saving_writer); + + QuicTimeWaitListManager* time_wait_list_manager = new QuicTimeWaitListManager( + saving_writer, dispatcher_.get(), mock_helper_.GetClock(), + &mock_alarm_factory_); + // dispatcher_ takes ownership of time_wait_list_manager. + QuicDispatcherPeer::SetTimeWaitListManager(dispatcher_.get(), + time_wait_list_manager); + // clang-format off + static const unsigned char packet[1200] = { + // Android UDP network conformance test packet as it was after this change: + // https://android-review.googlesource.com/c/platform/cts/+/1454515 + 0xc0, // long header + 0xaa, 0xda, 0xca, 0xca, // reserved-space version number + 0x08, // destination connection ID length + 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, // 8-byte connection ID + 0x00, // source connection ID length + }; + // clang-format on + + QuicEncryptedPacket encrypted(reinterpret_cast(packet), + sizeof(packet), false); + std::unique_ptr received_packet( + ConstructReceivedPacket(encrypted, mock_helper_.GetClock()->Now())); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + dispatcher_->ProcessPacket(server_address_, client_address, *received_packet); + ASSERT_EQ(1u, saving_writer->packets()->size()); + + // The Android UDP network conformance test directly checks that these bytes + // of the response match the connection ID that was sent. + ASSERT_GE((*(saving_writer->packets()))[0]->length(), 15u); + quiche::test::CompareCharArraysWithHexError( + "response connection ID", &(*(saving_writer->packets()))[0]->data()[7], 8, + reinterpret_cast(&packet[6]), 8); +} + +TEST_P(QuicDispatcherTestOneVersion, AndroidConformanceTestOld) { + // WARNING: this test covers an old Android Conformance Test that has now been + // changed, but it'll take time for the change to propagate through the + // Android ecosystem. The Android team has asked us to keep this test + // supported until at least 2021-03-31. After that date, and when we drop + // support for sending QUIC version negotiation packets using the legacy + // Google QUIC format (Q001-Q043), then we can delete this test. + // TODO(dschinazi) delete this test after 2021-03-31 + SavingWriter* saving_writer = new SavingWriter(); + // dispatcher_ takes ownership of saving_writer. + QuicDispatcherPeer::UseWriter(dispatcher_.get(), saving_writer); + + QuicTimeWaitListManager* time_wait_list_manager = new QuicTimeWaitListManager( + saving_writer, dispatcher_.get(), mock_helper_.GetClock(), + &mock_alarm_factory_); + // dispatcher_ takes ownership of time_wait_list_manager. + QuicDispatcherPeer::SetTimeWaitListManager(dispatcher_.get(), + time_wait_list_manager); + // clang-format off + static const unsigned char packet[1200] = { + // Android UDP network conformance test packet as it was after this change: + // https://android-review.googlesource.com/c/platform/cts/+/1104285 + // but before this change: + // https://android-review.googlesource.com/c/platform/cts/+/1454515 + 0x0d, // public flags: version, 8-byte connection ID, 1-byte packet number + 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, // 8-byte connection ID + 0xaa, 0xda, 0xca, 0xaa, // reserved-space version number + 0x01, // 1-byte packet number + 0x00, // private flags + 0x07, // PING frame + }; + // clang-format on + + QuicEncryptedPacket encrypted(reinterpret_cast(packet), + sizeof(packet), false); + std::unique_ptr received_packet( + ConstructReceivedPacket(encrypted, mock_helper_.GetClock()->Now())); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + dispatcher_->ProcessPacket(server_address_, client_address, *received_packet); + ASSERT_EQ(1u, saving_writer->packets()->size()); + + // The Android UDP network conformance test directly checks that bytes 1-9 + // of the response match the connection ID that was sent. + static const char connection_id_bytes[] = {0x71, 0x72, 0x73, 0x74, + 0x75, 0x76, 0x77, 0x78}; + ASSERT_GE((*(saving_writer->packets()))[0]->length(), + 1u + sizeof(connection_id_bytes)); + quiche::test::CompareCharArraysWithHexError( + "response connection ID", &(*(saving_writer->packets()))[0]->data()[1], + sizeof(connection_id_bytes), connection_id_bytes, + sizeof(connection_id_bytes)); +} + +TEST_P(QuicDispatcherTestAllVersions, DoNotProcessSmallPacket) { + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, SendPacket(_, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, + version_, SerializeCHLO(), /*full_padding=*/false, + CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, 1); +} + +TEST_P(QuicDispatcherTestAllVersions, ProcessSmallCoalescedPacket) { + CreateTimeWaitListManager(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*time_wait_list_manager_, SendPacket(_, _, _)).Times(0); + + // clang-format off + uint8_t coalesced_packet[1200] = { + // first coalesced packet + // public flags (long header with packet type INITIAL and + // 4-byte packet number) + 0xC3, + // version + 'Q', '0', '9', '9', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x78, + // Padding + 0x00, + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xC3, + // version + 'Q', '0', '9', '9', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + }; + // clang-format on + QuicReceivedPacket packet(reinterpret_cast(coalesced_packet), 1200, + QuicTime::Zero()); + dispatcher_->ProcessPacket(server_address_, client_address, packet); +} + +TEST_P(QuicDispatcherTestAllVersions, StopAcceptingNewConnections) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + + dispatcher_->StopAcceptingNewConnections(); + EXPECT_FALSE(dispatcher_->accept_new_connections()); + + // No more new connections afterwards. + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(2), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .Times(0u); + expect_generator_is_called_ = false; + ProcessFirstFlight(client_address, TestConnectionId(2)); + + // Existing connections should be able to continue. + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(1u) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessPacket(client_address, TestConnectionId(1), false, "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, StartAcceptingNewConnections) { + dispatcher_->StopAcceptingNewConnections(); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + // No more new connections afterwards. + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(2), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .Times(0u); + expect_generator_is_called_ = false; + ProcessFirstFlight(client_address, TestConnectionId(2)); + + dispatcher_->StartAcceptingNewConnections(); + EXPECT_TRUE(dispatcher_->accept_new_connections()); + + expect_generator_is_called_ = true; + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); +} + +TEST_P(QuicDispatcherTestOneVersion, SelectAlpn) { + EXPECT_EQ(QuicDispatcherPeer::SelectAlpn(dispatcher_.get(), {}), ""); + EXPECT_EQ(QuicDispatcherPeer::SelectAlpn(dispatcher_.get(), {""}), ""); + EXPECT_EQ(QuicDispatcherPeer::SelectAlpn(dispatcher_.get(), {"hq"}), "hq"); + // Q033 is no longer supported but Q050 is. + QuicEnableVersion(ParsedQuicVersion::Q050()); + EXPECT_EQ( + QuicDispatcherPeer::SelectAlpn(dispatcher_.get(), {"h3-Q033", "h3-Q050"}), + "h3-Q050"); +} + +// Verify the stopgap test: Packets with truncated connection IDs should be +// dropped. +class QuicDispatcherTestStrayPacketConnectionId + : public QuicDispatcherTestBase {}; + +INSTANTIATE_TEST_SUITE_P(QuicDispatcherTestsStrayPacketConnectionId, + QuicDispatcherTestStrayPacketConnectionId, + ::testing::ValuesIn(CurrentSupportedVersions()), + ::testing::PrintToStringParamName()); + +// Packets with truncated connection IDs should be dropped. +TEST_P(QuicDispatcherTestStrayPacketConnectionId, + StrayPacketTruncatedConnectionId) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + + ProcessPacket(client_address, connection_id, true, "data", + CONNECTION_ID_ABSENT, PACKET_4BYTE_PACKET_NUMBER); +} + +class BlockingWriter : public QuicPacketWriterWrapper { + public: + BlockingWriter() : write_blocked_(false) {} + + bool IsWriteBlocked() const override { return write_blocked_; } + void SetWritable() override { write_blocked_ = false; } + + WriteResult WritePacket(const char* /*buffer*/, size_t /*buf_len*/, + const QuicIpAddress& /*self_client_address*/, + const QuicSocketAddress& /*peer_client_address*/, + PerPacketOptions* /*options*/) override { + // It would be quite possible to actually implement this method here with + // the fake blocked status, but it would be significantly more work in + // Chromium, and since it's not called anyway, don't bother. + QUIC_LOG(DFATAL) << "Not supported"; + return WriteResult(); + } + + bool write_blocked_; +}; + +class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTestBase { + public: + void SetUp() override { + QuicDispatcherTestBase::SetUp(); + writer_ = new BlockingWriter; + QuicDispatcherPeer::UseWriter(dispatcher_.get(), writer_); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &helper_, &alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(2), client_address, + &helper_, &alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_)))); + EXPECT_CALL(*reinterpret_cast(session2_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(2), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(2)); + + blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(dispatcher_.get()); + } + + void TearDown() override { + if (connection1() != nullptr) { + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + } + + if (connection2() != nullptr) { + EXPECT_CALL(*connection2(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + } + dispatcher_->Shutdown(); + } + + // Set the dispatcher's writer to be blocked. By default, all connections use + // the same writer as the dispatcher in this test. + void SetBlocked() { + QUIC_LOG(INFO) << "set writer " << writer_ << " to blocked"; + writer_->write_blocked_ = true; + } + + // Simulate what happens when connection1 gets blocked when writing. + void BlockConnection1() { + Connection1Writer()->write_blocked_ = true; + dispatcher_->OnWriteBlocked(connection1()); + } + + BlockingWriter* Connection1Writer() { + return static_cast(connection1()->writer()); + } + + // Simulate what happens when connection2 gets blocked when writing. + void BlockConnection2() { + Connection2Writer()->write_blocked_ = true; + dispatcher_->OnWriteBlocked(connection2()); + } + + BlockingWriter* Connection2Writer() { + return static_cast(connection2()->writer()); + } + + protected: + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + BlockingWriter* writer_; + QuicDispatcher::WriteBlockedList* blocked_list_; +}; + +INSTANTIATE_TEST_SUITE_P(QuicDispatcherWriteBlockedListTests, + QuicDispatcherWriteBlockedListTest, + ::testing::Values(CurrentSupportedVersions().front()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicDispatcherWriteBlockedListTest, BasicOnCanWrite) { + // No OnCanWrite calls because no connections are blocked. + dispatcher_->OnCanWrite(); + + // Register connection 1 for events, and make sure it's notified. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + EXPECT_CALL(*connection1(), OnCanWrite()); + dispatcher_->OnCanWrite(); + + // It should get only one notification. + EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); + dispatcher_->OnCanWrite(); + EXPECT_FALSE(dispatcher_->HasPendingWrites()); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, OnCanWriteOrder) { + // Make sure we handle events in order. + InSequence s; + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection2()); + EXPECT_CALL(*connection1(), OnCanWrite()); + EXPECT_CALL(*connection2(), OnCanWrite()); + dispatcher_->OnCanWrite(); + + // Check the other ordering. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection2()); + dispatcher_->OnWriteBlocked(connection1()); + EXPECT_CALL(*connection2(), OnCanWrite()); + EXPECT_CALL(*connection1(), OnCanWrite()); + dispatcher_->OnCanWrite(); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, OnCanWriteRemove) { + // Add and remove one connction. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + blocked_list_->erase(connection1()); + EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); + dispatcher_->OnCanWrite(); + + // Add and remove one connction and make sure it doesn't affect others. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection2()); + blocked_list_->erase(connection1()); + EXPECT_CALL(*connection2(), OnCanWrite()); + dispatcher_->OnCanWrite(); + + // Add it, remove it, and add it back and make sure things are OK. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + blocked_list_->erase(connection1()); + dispatcher_->OnWriteBlocked(connection1()); + EXPECT_CALL(*connection1(), OnCanWrite()).Times(1); + dispatcher_->OnCanWrite(); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, DoubleAdd) { + // Make sure a double add does not necessitate a double remove. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection1()); + blocked_list_->erase(connection1()); + EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); + dispatcher_->OnCanWrite(); + + // Make sure a double add does not result in two OnCanWrite calls. + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection1()); + EXPECT_CALL(*connection1(), OnCanWrite()).Times(1); + dispatcher_->OnCanWrite(); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlockConnection1) { + // If the 1st blocked writer gets blocked in OnCanWrite, it will be added back + // into the write blocked list. + InSequence s; + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection2()); + EXPECT_CALL(*connection1(), OnCanWrite()) + .WillOnce( + Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection1)); + EXPECT_CALL(*connection2(), OnCanWrite()); + dispatcher_->OnCanWrite(); + + // connection1 should be still in the write blocked list. + EXPECT_TRUE(dispatcher_->HasPendingWrites()); + + // Now call OnCanWrite again, connection1 should get its second chance. + EXPECT_CALL(*connection1(), OnCanWrite()); + EXPECT_CALL(*connection2(), OnCanWrite()).Times(0); + dispatcher_->OnCanWrite(); + EXPECT_FALSE(dispatcher_->HasPendingWrites()); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlockConnection2) { + // If the 2nd blocked writer gets blocked in OnCanWrite, it will be added back + // into the write blocked list. + InSequence s; + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection2()); + EXPECT_CALL(*connection1(), OnCanWrite()); + EXPECT_CALL(*connection2(), OnCanWrite()) + .WillOnce( + Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2)); + dispatcher_->OnCanWrite(); + + // connection2 should be still in the write blocked list. + EXPECT_TRUE(dispatcher_->HasPendingWrites()); + + // Now call OnCanWrite again, connection2 should get its second chance. + EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); + EXPECT_CALL(*connection2(), OnCanWrite()); + dispatcher_->OnCanWrite(); + EXPECT_FALSE(dispatcher_->HasPendingWrites()); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, + OnCanWriteHandleBlockBothConnections) { + // Both connections get blocked in OnCanWrite, and added back into the write + // blocked list. + InSequence s; + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + dispatcher_->OnWriteBlocked(connection2()); + EXPECT_CALL(*connection1(), OnCanWrite()) + .WillOnce( + Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection1)); + EXPECT_CALL(*connection2(), OnCanWrite()) + .WillOnce( + Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2)); + dispatcher_->OnCanWrite(); + + // Both connections should be still in the write blocked list. + EXPECT_TRUE(dispatcher_->HasPendingWrites()); + + // Now call OnCanWrite again, both connections should get its second chance. + EXPECT_CALL(*connection1(), OnCanWrite()); + EXPECT_CALL(*connection2(), OnCanWrite()); + dispatcher_->OnCanWrite(); + EXPECT_FALSE(dispatcher_->HasPendingWrites()); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, PerConnectionWriterBlocked) { + // By default, all connections share the same packet writer with the + // dispatcher. + EXPECT_EQ(dispatcher_->writer(), connection1()->writer()); + EXPECT_EQ(dispatcher_->writer(), connection2()->writer()); + + // Test the case where connection1 shares the same packet writer as the + // dispatcher, whereas connection2 owns it's packet writer. + // Change connection2's writer. + connection2()->SetQuicPacketWriter(new BlockingWriter, /*owns_writer=*/true); + EXPECT_NE(dispatcher_->writer(), connection2()->writer()); + + BlockConnection2(); + EXPECT_TRUE(dispatcher_->HasPendingWrites()); + + EXPECT_CALL(*connection2(), OnCanWrite()); + dispatcher_->OnCanWrite(); + EXPECT_FALSE(dispatcher_->HasPendingWrites()); +} + +TEST_P(QuicDispatcherWriteBlockedListTest, + RemoveConnectionFromWriteBlockedListWhenDeletingSessions) { + EXPECT_QUIC_BUG( + { + dispatcher_->OnConnectionClosed( + connection1()->connection_id(), QUIC_PACKET_WRITE_ERROR, + "Closed by test.", ConnectionCloseSource::FROM_SELF); + + SetBlocked(); + + ASSERT_FALSE(dispatcher_->HasPendingWrites()); + SetBlocked(); + dispatcher_->OnWriteBlocked(connection1()); + ASSERT_TRUE(dispatcher_->HasPendingWrites()); + + dispatcher_->DeleteSessions(); + MarkSession1Deleted(); + }, + "QuicConnection was in WriteBlockedList before destruction"); +} + +class QuicDispatcherSupportMultipleConnectionIdPerConnectionTest + : public QuicDispatcherTestBase { + public: + QuicDispatcherSupportMultipleConnectionIdPerConnectionTest() + : QuicDispatcherTestBase(crypto_test_utils::ProofSourceForTesting()) { + dispatcher_ = std::make_unique>( + &config_, &crypto_config_, &version_manager_, + mock_helper_.GetRandomGenerator(), connection_id_generator_); + } + void AddConnection1() { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &helper_, &alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(1), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(1)); + } + + void AddConnection2() { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 2); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(2), client_address, + &helper_, &alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session2_)))); + EXPECT_CALL(*reinterpret_cast(session2_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>(Invoke([this](const QuicEncryptedPacket& packet) { + ValidatePacket(TestConnectionId(2), packet); + }))); + ProcessFirstFlight(client_address, TestConnectionId(2)); + } + + protected: + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; +}; + +INSTANTIATE_TEST_SUITE_P( + QuicDispatcherSupportMultipleConnectionIdPerConnectionTests, + QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + ::testing::Values(CurrentSupportedVersions().front()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + FailToAddExistingConnectionId) { + AddConnection1(); + EXPECT_FALSE(dispatcher_->TryAddNewConnectionId(TestConnectionId(1), + TestConnectionId(1))); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + TryAddNewConnectionId) { + AddConnection1(); + ASSERT_EQ(dispatcher_->NumSessions(), 1u); + ASSERT_THAT(session1_, testing::NotNull()); + MockServerConnection* mock_server_connection1 = + reinterpret_cast(connection1()); + + { + mock_server_connection1->AddNewConnectionId(TestConnectionId(3)); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(3)); + ASSERT_EQ(session, session1_); + } + + { + mock_server_connection1->AddNewConnectionId(TestConnectionId(4)); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(4)); + ASSERT_EQ(session, session1_); + } + + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + // Would timed out unless all sessions have been removed from the session map. + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + TryAddNewConnectionIdWithCollision) { + AddConnection1(); + AddConnection2(); + ASSERT_EQ(dispatcher_->NumSessions(), 2u); + ASSERT_THAT(session1_, testing::NotNull()); + ASSERT_THAT(session2_, testing::NotNull()); + MockServerConnection* mock_server_connection1 = + reinterpret_cast(connection1()); + MockServerConnection* mock_server_connection2 = + reinterpret_cast(connection2()); + + { + // TestConnectionId(2) is already claimed by connection2 but connection1 + // still thinks it owns it. + mock_server_connection1->UnconditionallyAddNewConnectionIdForTest( + TestConnectionId(2)); + EXPECT_EQ(dispatcher_->NumSessions(), 2u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(2)); + ASSERT_EQ(session, session2_); + EXPECT_THAT(mock_server_connection1->GetActiveServerConnectionIds(), + testing::ElementsAre(TestConnectionId(1), TestConnectionId(2))); + } + + { + mock_server_connection2->AddNewConnectionId(TestConnectionId(3)); + EXPECT_EQ(dispatcher_->NumSessions(), 2u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(3)); + ASSERT_EQ(session, session2_); + EXPECT_THAT(mock_server_connection2->GetActiveServerConnectionIds(), + testing::ElementsAre(TestConnectionId(2), TestConnectionId(3))); + } + + // Connection2 removes both TestConnectionId(2) & TestConnectionId(3) from the + // session map. + dispatcher_->OnConnectionClosed(TestConnectionId(2), + QuicErrorCode::QUIC_NO_ERROR, "detail", + quic::ConnectionCloseSource::FROM_SELF); + // QUICHE_BUG fires when connection1 tries to remove TestConnectionId(2) + // again from the session_map. + EXPECT_QUICHE_BUG(dispatcher_->OnConnectionClosed( + TestConnectionId(1), QuicErrorCode::QUIC_NO_ERROR, + "detail", quic::ConnectionCloseSource::FROM_SELF), + "Missing session for cid"); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + MismatchedSessionAfterAddingCollidedConnectionId) { + AddConnection1(); + AddConnection2(); + MockServerConnection* mock_server_connection1 = + reinterpret_cast(connection1()); + + { + // TestConnectionId(2) is already claimed by connection2 but connection1 + // still thinks it owns it. + mock_server_connection1->UnconditionallyAddNewConnectionIdForTest( + TestConnectionId(2)); + EXPECT_EQ(dispatcher_->NumSessions(), 2u); + auto* session = + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(2)); + ASSERT_EQ(session, session2_); + EXPECT_THAT(mock_server_connection1->GetActiveServerConnectionIds(), + testing::ElementsAre(TestConnectionId(1), TestConnectionId(2))); + } + + // Connection1 tries to remove both Cid1 & Cid2, but they point to different + // sessions. + EXPECT_QUIC_BUG(dispatcher_->OnConnectionClosed( + TestConnectionId(1), QuicErrorCode::QUIC_NO_ERROR, + "detail", quic::ConnectionCloseSource::FROM_SELF), + "Session is mismatched in the map"); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + RetireConnectionIdFromSingleConnection) { + AddConnection1(); + ASSERT_EQ(dispatcher_->NumSessions(), 1u); + ASSERT_THAT(session1_, testing::NotNull()); + MockServerConnection* mock_server_connection1 = + reinterpret_cast(connection1()); + + // Adds 1 new connection id every turn and retires 2 connection ids every + // other turn. + for (int i = 2; i < 10; ++i) { + mock_server_connection1->AddNewConnectionId(TestConnectionId(i)); + ASSERT_EQ( + QuicDispatcherPeer::FindSession(dispatcher_.get(), TestConnectionId(i)), + session1_); + ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(), + TestConnectionId(i - 1)), + session1_); + EXPECT_EQ(dispatcher_->NumSessions(), 1u); + if (i % 2 == 1) { + mock_server_connection1->RetireConnectionId(TestConnectionId(i - 2)); + mock_server_connection1->RetireConnectionId(TestConnectionId(i - 1)); + } + } + + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + // Would timed out unless all sessions have been removed from the session map. + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + RetireConnectionIdFromMultipleConnections) { + AddConnection1(); + AddConnection2(); + ASSERT_EQ(dispatcher_->NumSessions(), 2u); + MockServerConnection* mock_server_connection1 = + reinterpret_cast(connection1()); + MockServerConnection* mock_server_connection2 = + reinterpret_cast(connection2()); + + for (int i = 2; i < 10; ++i) { + mock_server_connection1->AddNewConnectionId(TestConnectionId(2 * i - 1)); + mock_server_connection2->AddNewConnectionId(TestConnectionId(2 * i)); + ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(), + TestConnectionId(2 * i - 1)), + session1_); + ASSERT_EQ(QuicDispatcherPeer::FindSession(dispatcher_.get(), + TestConnectionId(2 * i)), + session2_); + EXPECT_EQ(dispatcher_->NumSessions(), 2u); + mock_server_connection1->RetireConnectionId(TestConnectionId(2 * i - 3)); + mock_server_connection2->RetireConnectionId(TestConnectionId(2 * i - 2)); + } + + mock_server_connection1->AddNewConnectionId(TestConnectionId(19)); + mock_server_connection2->AddNewConnectionId(TestConnectionId(20)); + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + EXPECT_CALL(*connection2(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + // Would timed out unless all sessions have been removed from the session map. + dispatcher_->Shutdown(); +} + +TEST_P(QuicDispatcherSupportMultipleConnectionIdPerConnectionTest, + TimeWaitListPoplulateCorrectly) { + QuicTimeWaitListManager* time_wait_list_manager = + QuicDispatcherPeer::GetTimeWaitListManager(dispatcher_.get()); + AddConnection1(); + MockServerConnection* mock_server_connection1 = + reinterpret_cast(connection1()); + + mock_server_connection1->AddNewConnectionId(TestConnectionId(2)); + mock_server_connection1->AddNewConnectionId(TestConnectionId(3)); + mock_server_connection1->AddNewConnectionId(TestConnectionId(4)); + mock_server_connection1->RetireConnectionId(TestConnectionId(1)); + mock_server_connection1->RetireConnectionId(TestConnectionId(2)); + + EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _)); + connection1()->CloseConnection( + QUIC_PEER_GOING_AWAY, "Close for testing", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + + EXPECT_FALSE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(1))); + EXPECT_FALSE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(2))); + EXPECT_TRUE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(3))); + EXPECT_TRUE( + time_wait_list_manager->IsConnectionIdInTimeWait(TestConnectionId(4))); + + dispatcher_->Shutdown(); +} + +class BufferedPacketStoreTest : public QuicDispatcherTestBase { + public: + BufferedPacketStoreTest() + : QuicDispatcherTestBase(), + client_addr_(QuicIpAddress::Loopback4(), 1234) {} + + void ProcessFirstFlight(const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + QuicDispatcherTestBase::ProcessFirstFlight(version, peer_address, + server_connection_id); + } + + void ProcessFirstFlight(const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(version_, peer_address, server_connection_id); + } + + void ProcessFirstFlight(const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(client_addr_, server_connection_id); + } + + void ProcessFirstFlight(const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id) { + ProcessFirstFlight(version, client_addr_, server_connection_id); + } + + void ProcessUndecryptableEarlyPacket( + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + QuicDispatcherTestBase::ProcessUndecryptableEarlyPacket( + version, peer_address, server_connection_id); + } + + void ProcessUndecryptableEarlyPacket( + const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id) { + ProcessUndecryptableEarlyPacket(version_, peer_address, + server_connection_id); + } + + void ProcessUndecryptableEarlyPacket( + const QuicConnectionId& server_connection_id) { + ProcessUndecryptableEarlyPacket(version_, client_addr_, + server_connection_id); + } + + protected: + QuicSocketAddress client_addr_; +}; + +INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests, BufferedPacketStoreTest, + ::testing::ValuesIn(CurrentSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(BufferedPacketStoreTest, ProcessNonChloPacketBeforeChlo) { + InSequence s; + QuicConnectionId conn_id = TestConnectionId(1); + // Process non-CHLO packet. + ProcessUndecryptableEarlyPacket(conn_id); + EXPECT_EQ(0u, dispatcher_->NumSessions()) + << "No session should be created before CHLO arrives."; + + // When CHLO arrives, a new session should be created, and all packets + // buffered should be delivered to the session. + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(conn_id, version_)) + .WillOnce(Return(absl::nullopt)); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(conn_id, _, client_addr_, Eq(ExpectedAlpn()), _, + Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_, + &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(2) // non-CHLO + CHLO. + .WillRepeatedly( + WithArg<2>(Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(conn_id, packet); + } + }))); + expect_generator_is_called_ = false; + ProcessFirstFlight(conn_id); +} + +TEST_P(BufferedPacketStoreTest, ProcessNonChloPacketsUptoLimitAndProcessChlo) { + InSequence s; + QuicConnectionId conn_id = TestConnectionId(1); + for (size_t i = 1; i <= kDefaultMaxUndecryptablePackets + 1; ++i) { + ProcessUndecryptableEarlyPacket(conn_id); + } + EXPECT_EQ(0u, dispatcher_->NumSessions()) + << "No session should be created before CHLO arrives."; + + // Pop out the last packet as it is also be dropped by the store. + data_connection_map_[conn_id].pop_back(); + // When CHLO arrives, a new session should be created, and all packets + // buffered should be delivered to the session. + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(conn_id, version_)) + .WillOnce(Return(absl::nullopt)); + EXPECT_CALL(*dispatcher_, CreateQuicSession(conn_id, _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_, + &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + + // Only |kDefaultMaxUndecryptablePackets| packets were buffered, and they + // should be delivered in arrival order. + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(kDefaultMaxUndecryptablePackets + 1) // + 1 for CHLO. + .WillRepeatedly( + WithArg<2>(Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(conn_id, packet); + } + }))); + expect_generator_is_called_ = false; + ProcessFirstFlight(conn_id); +} + +TEST_P(BufferedPacketStoreTest, + ProcessNonChloPacketsForDifferentConnectionsUptoLimit) { + InSequence s; + // A bunch of non-CHLO should be buffered upon arrival. + size_t kNumConnections = kMaxConnectionsWithoutCHLO + 1; + for (size_t i = 1; i <= kNumConnections; ++i) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 20000 + i); + QuicConnectionId conn_id = TestConnectionId(i); + ProcessUndecryptableEarlyPacket(client_address, conn_id); + } + + // Pop out the packet on last connection as it shouldn't be enqueued in store + // as well. + data_connection_map_[TestConnectionId(kNumConnections)].pop_front(); + + // Reset session creation counter to ensure processing CHLO can always + // create session. + QuicDispatcherPeer::set_new_sessions_allowed_per_event_loop(dispatcher_.get(), + kNumConnections); + // Deactivate the EXPECT_CALL in ProcessFirstFlight() because we have to be + // in sequence, so the EXPECT_CALL has to explicitly be in order here. + expect_generator_is_called_ = false; + // Process CHLOs to create session for these connections. + for (size_t i = 1; i <= kNumConnections; ++i) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 20000 + i); + QuicConnectionId conn_id = TestConnectionId(i); + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(conn_id, version_)) + .WillOnce(Return(absl::nullopt)); + EXPECT_CALL(*dispatcher_, CreateQuicSession(conn_id, _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, conn_id, client_address, &mock_helper_, + &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + // First |kNumConnections| - 1 connections should have buffered + // a packet in store. The rest should have been dropped. + size_t num_packet_to_process = i <= kMaxConnectionsWithoutCHLO ? 2u : 1u; + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, client_address, _)) + .Times(num_packet_to_process) + .WillRepeatedly(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(conn_id, packet); + } + }))); + ProcessFirstFlight(client_address, conn_id); + } +} + +// Tests that store delivers empty packet list if CHLO arrives firstly. +TEST_P(BufferedPacketStoreTest, DeliverEmptyPackets) { + QuicConnectionId conn_id = TestConnectionId(1); + EXPECT_CALL(*dispatcher_, CreateQuicSession(conn_id, _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_, + &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, client_addr_, _)); + ProcessFirstFlight(conn_id); +} + +// Tests that a retransmitted CHLO arrives after a connection for the +// CHLO has been created. +TEST_P(BufferedPacketStoreTest, ReceiveRetransmittedCHLO) { + InSequence s; + QuicConnectionId conn_id = TestConnectionId(1); + ProcessUndecryptableEarlyPacket(conn_id); + + // When CHLO arrives, a new session should be created, and all packets + // buffered should be delivered to the session. + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(conn_id, version_)) + .WillOnce(Return(absl::nullopt)); + EXPECT_CALL(*dispatcher_, CreateQuicSession(conn_id, _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .Times(1) // Only triggered by 1st CHLO. + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_, + &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(3) // Triggered by 1 data packet and 2 CHLOs. + .WillRepeatedly( + WithArg<2>(Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(conn_id, packet); + } + }))); + + std::vector> packets = + GetFirstFlightOfPackets(version_, conn_id); + ASSERT_EQ(packets.size(), 1u); + // Receive the CHLO once. + ProcessReceivedPacket(packets[0]->Clone(), client_addr_, version_, conn_id); + // Receive the CHLO a second time to simulate retransmission. + ProcessReceivedPacket(std::move(packets[0]), client_addr_, version_, conn_id); +} + +// Tests that expiration of a connection add connection id to time wait list. +TEST_P(BufferedPacketStoreTest, ReceiveCHLOAfterExpiration) { + InSequence s; + CreateTimeWaitListManager(); + QuicBufferedPacketStore* store = + QuicDispatcherPeer::GetBufferedPackets(dispatcher_.get()); + QuicBufferedPacketStorePeer::set_clock(store, mock_helper_.GetClock()); + + QuicConnectionId conn_id = TestConnectionId(1); + ProcessPacket(client_addr_, conn_id, true, absl::StrCat("data packet ", 2), + CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER, + /*packet_number=*/2); + + mock_helper_.AdvanceTime( + QuicTime::Delta::FromSeconds(kInitialIdleTimeoutSecs)); + QuicAlarm* alarm = QuicBufferedPacketStorePeer::expiration_alarm(store); + // Cancel alarm as if it had been fired. + alarm->Cancel(); + store->OnExpirationTimeout(); + // New arrived CHLO will be dropped because this connection is in time wait + // list. + ASSERT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(conn_id)); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, conn_id, _, _, _)); + expect_generator_is_called_ = false; + ProcessFirstFlight(conn_id); +} + +TEST_P(BufferedPacketStoreTest, ProcessCHLOsUptoLimitAndBufferTheRest) { + // Process more than (|kMaxNumSessionsToCreate| + + // |kDefaultMaxConnectionsInStore|) CHLOs, + // the first |kMaxNumSessionsToCreate| should create connections immediately, + // the next |kDefaultMaxConnectionsInStore| should be buffered, + // the rest should be dropped. + QuicBufferedPacketStore* store = + QuicDispatcherPeer::GetBufferedPackets(dispatcher_.get()); + const size_t kNumCHLOs = + kMaxNumSessionsToCreate + kDefaultMaxConnectionsInStore + 1; + for (uint64_t conn_id = 1; conn_id <= kNumCHLOs; ++conn_id) { + if (conn_id <= kMaxNumSessionsToCreate) { + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) + .WillOnce(Return(absl::nullopt)); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpn()), _, + Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), + client_addr_, &mock_helper_, &mock_alarm_factory_, + &crypto_config_, QuicDispatcherPeer::GetCache(dispatcher_.get()), + &session1_)))); + EXPECT_CALL( + *reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } + expect_generator_is_called_ = false; + ProcessFirstFlight(TestConnectionId(conn_id)); + if (conn_id <= kMaxNumSessionsToCreate + kDefaultMaxConnectionsInStore && + conn_id > kMaxNumSessionsToCreate) { + EXPECT_TRUE(store->HasChloForConnection(TestConnectionId(conn_id))); + } else { + // First |kMaxNumSessionsToCreate| CHLOs should be passed to new + // connections immediately, and the last CHLO should be dropped as the + // store is full. + EXPECT_FALSE(store->HasChloForConnection(TestConnectionId(conn_id))); + } + } + + // Graduately consume buffered CHLOs. The buffered connections should be + // created but the dropped one shouldn't. + for (uint64_t conn_id = kMaxNumSessionsToCreate + 1; + conn_id <= kMaxNumSessionsToCreate + kDefaultMaxConnectionsInStore; + ++conn_id) { + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(TestConnectionId(conn_id), version_)) + .WillOnce(Return(absl::nullopt)); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpn()), _, + Eq(ParsedClientHelloForTest()))) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), client_addr_, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } + EXPECT_CALL(connection_id_generator_, + MaybeReplaceConnectionId(TestConnectionId(kNumCHLOs), version_)) + .Times(0); + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(kNumCHLOs), _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .Times(0); + + while (store->HasChlosBuffered()) { + dispatcher_->ProcessBufferedChlos(kMaxNumSessionsToCreate); + } + + EXPECT_EQ(TestConnectionId(static_cast(kMaxNumSessionsToCreate) + + kDefaultMaxConnectionsInStore), + session1_->connection_id()); +} + +// Duplicated CHLO shouldn't be buffered. +TEST_P(BufferedPacketStoreTest, BufferDuplicatedCHLO) { + for (uint64_t conn_id = 1; conn_id <= kMaxNumSessionsToCreate + 1; + ++conn_id) { + // Last CHLO will be buffered. Others will create connection right away. + if (conn_id <= kMaxNumSessionsToCreate) { + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), + client_addr_, &mock_helper_, &mock_alarm_factory_, + &crypto_config_, QuicDispatcherPeer::GetCache(dispatcher_.get()), + &session1_)))); + EXPECT_CALL( + *reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } + ProcessFirstFlight(TestConnectionId(conn_id)); + } + // Retransmit CHLO on last connection should be dropped. + QuicConnectionId last_connection = + TestConnectionId(kMaxNumSessionsToCreate + 1); + expect_generator_is_called_ = false; + ProcessFirstFlight(last_connection); + + size_t packets_buffered = 2; + + // Reset counter and process buffered CHLO. + EXPECT_CALL(*dispatcher_, CreateQuicSession(last_connection, _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, last_connection, client_addr_, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + // Only one packet(CHLO) should be process. + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(packets_buffered) + .WillRepeatedly(WithArg<2>( + Invoke([this, last_connection](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(last_connection, packet); + } + }))); + dispatcher_->ProcessBufferedChlos(kMaxNumSessionsToCreate); +} + +TEST_P(BufferedPacketStoreTest, BufferNonChloPacketsUptoLimitWithChloBuffered) { + uint64_t last_conn_id = kMaxNumSessionsToCreate + 1; + QuicConnectionId last_connection_id = TestConnectionId(last_conn_id); + for (uint64_t conn_id = 1; conn_id <= last_conn_id; ++conn_id) { + // Last CHLO will be buffered. Others will create connection right away. + if (conn_id <= kMaxNumSessionsToCreate) { + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), + client_addr_, &mock_helper_, &mock_alarm_factory_, + &crypto_config_, QuicDispatcherPeer::GetCache(dispatcher_.get()), + &session1_)))); + EXPECT_CALL( + *reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillRepeatedly(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } + ProcessFirstFlight(TestConnectionId(conn_id)); + } + + // Process another |kDefaultMaxUndecryptablePackets| + 1 data packets. The + // last one should be dropped. + for (uint64_t packet_number = 2; + packet_number <= kDefaultMaxUndecryptablePackets + 2; ++packet_number) { + ProcessPacket(client_addr_, last_connection_id, true, "data packet"); + } + + // Reset counter and process buffered CHLO. + EXPECT_CALL(*dispatcher_, + CreateQuicSession(last_connection_id, _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, last_connection_id, client_addr_, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + // Only CHLO and following |kDefaultMaxUndecryptablePackets| data packets + // should be process. + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .Times(kDefaultMaxUndecryptablePackets + 1) + .WillRepeatedly(WithArg<2>( + Invoke([this, last_connection_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(last_connection_id, packet); + } + }))); + dispatcher_->ProcessBufferedChlos(kMaxNumSessionsToCreate); +} + +// Tests that when dispatcher's packet buffer is full, a CHLO on connection +// which doesn't have buffered CHLO should be buffered. +TEST_P(BufferedPacketStoreTest, ReceiveCHLOForBufferedConnection) { + QuicBufferedPacketStore* store = + QuicDispatcherPeer::GetBufferedPackets(dispatcher_.get()); + + uint64_t conn_id = 1; + ProcessUndecryptableEarlyPacket(TestConnectionId(conn_id)); + // Fill packet buffer to full with CHLOs on other connections. Need to feed + // extra CHLOs because the first |kMaxNumSessionsToCreate| are going to create + // session directly. + for (conn_id = 2; + conn_id <= kDefaultMaxConnectionsInStore + kMaxNumSessionsToCreate; + ++conn_id) { + if (conn_id <= kMaxNumSessionsToCreate + 1) { + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), + client_addr_, &mock_helper_, &mock_alarm_factory_, + &crypto_config_, QuicDispatcherPeer::GetCache(dispatcher_.get()), + &session1_)))); + EXPECT_CALL( + *reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillOnce(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } else { + expect_generator_is_called_ = false; + } + ProcessFirstFlight(TestConnectionId(conn_id)); + } + EXPECT_FALSE(store->HasChloForConnection( + /*connection_id=*/TestConnectionId(1))); + + // CHLO on connection 1 should still be buffered. + ProcessFirstFlight(TestConnectionId(1)); + EXPECT_TRUE(store->HasChloForConnection( + /*connection_id=*/TestConnectionId(1))); +} + +// Regression test for b/117874922. +TEST_P(BufferedPacketStoreTest, ProcessBufferedChloWithDifferentVersion) { + // Ensure the preferred version is not supported by the server. + QuicDisableVersion(AllSupportedVersions().front()); + + uint64_t last_connection_id = kMaxNumSessionsToCreate + 5; + ParsedQuicVersionVector supported_versions = CurrentSupportedVersions(); + for (uint64_t conn_id = 1; conn_id <= last_connection_id; ++conn_id) { + // Last 5 CHLOs will be buffered. Others will create connection right away. + ParsedQuicVersion version = + supported_versions[(conn_id - 1) % supported_versions.size()]; + if (conn_id <= kMaxNumSessionsToCreate) { + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpnForVersion(version)), version, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), + client_addr_, &mock_helper_, &mock_alarm_factory_, + &crypto_config_, QuicDispatcherPeer::GetCache(dispatcher_.get()), + &session1_)))); + EXPECT_CALL( + *reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillRepeatedly(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } + ProcessFirstFlight(version, TestConnectionId(conn_id)); + } + + // Process buffered CHLOs. Verify the version is correct. + for (uint64_t conn_id = kMaxNumSessionsToCreate + 1; + conn_id <= last_connection_id; ++conn_id) { + ParsedQuicVersion version = + supported_versions[(conn_id - 1) % supported_versions.size()]; + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, + Eq(ExpectedAlpnForVersion(version)), version, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(conn_id), client_addr_, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + EXPECT_CALL(*reinterpret_cast(session1_->connection()), + ProcessUdpPacket(_, _, _)) + .WillRepeatedly(WithArg<2>( + Invoke([this, conn_id](const QuicEncryptedPacket& packet) { + if (version_.UsesQuicCrypto()) { + ValidatePacket(TestConnectionId(conn_id), packet); + } + }))); + } + dispatcher_->ProcessBufferedChlos(kMaxNumSessionsToCreate); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_error_codes.cc b/quiche/quic/core/quic_error_codes.cc new file mode 100644 index 000000000000..1b176b470eee --- /dev/null +++ b/quiche/quic/core/quic_error_codes.cc @@ -0,0 +1,992 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_error_codes.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "openssl/ssl.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x; + +const char* QuicRstStreamErrorCodeToString(QuicRstStreamErrorCode error) { + switch (error) { + RETURN_STRING_LITERAL(QUIC_STREAM_NO_ERROR); + RETURN_STRING_LITERAL(QUIC_ERROR_PROCESSING_STREAM); + RETURN_STRING_LITERAL(QUIC_MULTIPLE_TERMINATION_OFFSETS); + RETURN_STRING_LITERAL(QUIC_BAD_APPLICATION_PAYLOAD); + RETURN_STRING_LITERAL(QUIC_STREAM_CONNECTION_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_PEER_GOING_AWAY); + RETURN_STRING_LITERAL(QUIC_STREAM_CANCELLED); + RETURN_STRING_LITERAL(QUIC_RST_ACKNOWLEDGEMENT); + RETURN_STRING_LITERAL(QUIC_REFUSED_STREAM); + RETURN_STRING_LITERAL(QUIC_INVALID_PROMISE_URL); + RETURN_STRING_LITERAL(QUIC_UNAUTHORIZED_PROMISE_URL); + RETURN_STRING_LITERAL(QUIC_DUPLICATE_PROMISE_URL); + RETURN_STRING_LITERAL(QUIC_PROMISE_VARY_MISMATCH); + RETURN_STRING_LITERAL(QUIC_INVALID_PROMISE_METHOD); + RETURN_STRING_LITERAL(QUIC_PUSH_STREAM_TIMED_OUT); + RETURN_STRING_LITERAL(QUIC_HEADERS_TOO_LARGE); + RETURN_STRING_LITERAL(QUIC_STREAM_TTL_EXPIRED); + RETURN_STRING_LITERAL(QUIC_DATA_AFTER_CLOSE_OFFSET); + RETURN_STRING_LITERAL(QUIC_STREAM_GENERAL_PROTOCOL_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_INTERNAL_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_STREAM_CREATION_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_CLOSED_CRITICAL_STREAM); + RETURN_STRING_LITERAL(QUIC_STREAM_FRAME_UNEXPECTED); + RETURN_STRING_LITERAL(QUIC_STREAM_FRAME_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_EXCESSIVE_LOAD); + RETURN_STRING_LITERAL(QUIC_STREAM_ID_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_SETTINGS_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_MISSING_SETTINGS); + RETURN_STRING_LITERAL(QUIC_STREAM_REQUEST_REJECTED); + RETURN_STRING_LITERAL(QUIC_STREAM_REQUEST_INCOMPLETE); + RETURN_STRING_LITERAL(QUIC_STREAM_CONNECT_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_VERSION_FALLBACK); + RETURN_STRING_LITERAL(QUIC_STREAM_DECOMPRESSION_FAILED); + RETURN_STRING_LITERAL(QUIC_STREAM_ENCODER_STREAM_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_DECODER_STREAM_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_UNKNOWN_APPLICATION_ERROR_CODE); + RETURN_STRING_LITERAL(QUIC_STREAM_WEBTRANSPORT_SESSION_GONE); + RETURN_STRING_LITERAL( + QUIC_STREAM_WEBTRANSPORT_BUFFERED_STREAMS_LIMIT_EXCEEDED); + RETURN_STRING_LITERAL(QUIC_APPLICATION_DONE_WITH_STREAM); + RETURN_STRING_LITERAL(QUIC_STREAM_LAST_ERROR); + } + // Return a default value so that we return this when |error| doesn't match + // any of the QuicRstStreamErrorCodes. This can happen when the RstStream + // frame sent by the peer (attacker) has invalid error code. + return "INVALID_RST_STREAM_ERROR_CODE"; +} + +const char* QuicErrorCodeToString(QuicErrorCode error) { + switch (error) { + RETURN_STRING_LITERAL(QUIC_NO_ERROR); + RETURN_STRING_LITERAL(QUIC_INTERNAL_ERROR); + RETURN_STRING_LITERAL(QUIC_STREAM_DATA_AFTER_TERMINATION); + RETURN_STRING_LITERAL(QUIC_INVALID_PACKET_HEADER); + RETURN_STRING_LITERAL(QUIC_INVALID_FRAME_DATA); + RETURN_STRING_LITERAL(QUIC_MISSING_PAYLOAD); + RETURN_STRING_LITERAL(QUIC_INVALID_FEC_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_STREAM_DATA); + RETURN_STRING_LITERAL(QUIC_OVERLAPPING_STREAM_DATA); + RETURN_STRING_LITERAL(QUIC_UNENCRYPTED_STREAM_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_RST_STREAM_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_CONNECTION_CLOSE_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_GOAWAY_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_WINDOW_UPDATE_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_BLOCKED_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_STOP_WAITING_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_PATH_CLOSE_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_ACK_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_VERSION_NEGOTIATION_PACKET); + RETURN_STRING_LITERAL(QUIC_INVALID_PUBLIC_RST_PACKET); + RETURN_STRING_LITERAL(QUIC_DECRYPTION_FAILURE); + RETURN_STRING_LITERAL(QUIC_ENCRYPTION_FAILURE); + RETURN_STRING_LITERAL(QUIC_PACKET_TOO_LARGE); + RETURN_STRING_LITERAL(QUIC_PEER_GOING_AWAY); + RETURN_STRING_LITERAL(QUIC_HANDSHAKE_FAILED); + RETURN_STRING_LITERAL(QUIC_CRYPTO_TAGS_OUT_OF_ORDER); + RETURN_STRING_LITERAL(QUIC_CRYPTO_TOO_MANY_ENTRIES); + RETURN_STRING_LITERAL(QUIC_CRYPTO_TOO_MANY_REJECTS); + RETURN_STRING_LITERAL(QUIC_CRYPTO_INVALID_VALUE_LENGTH) + RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE); + RETURN_STRING_LITERAL(QUIC_CRYPTO_INTERNAL_ERROR); + RETURN_STRING_LITERAL(QUIC_CRYPTO_VERSION_NOT_SUPPORTED); + RETURN_STRING_LITERAL(QUIC_CRYPTO_NO_SUPPORT); + RETURN_STRING_LITERAL(QUIC_INVALID_CRYPTO_MESSAGE_TYPE); + RETURN_STRING_LITERAL(QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER); + RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND); + RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_PARAMETER_NO_OVERLAP); + RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND); + RETURN_STRING_LITERAL(QUIC_UNSUPPORTED_PROOF_DEMAND); + RETURN_STRING_LITERAL(QUIC_INVALID_STREAM_ID); + RETURN_STRING_LITERAL(QUIC_INVALID_PRIORITY); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_OPEN_STREAMS); + RETURN_STRING_LITERAL(QUIC_PUBLIC_RESET); + RETURN_STRING_LITERAL(QUIC_INVALID_VERSION); + RETURN_STRING_LITERAL(QUIC_PACKET_WRONG_VERSION); + RETURN_STRING_LITERAL(QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER); + RETURN_STRING_LITERAL(QUIC_INVALID_HEADER_ID); + RETURN_STRING_LITERAL(QUIC_INVALID_NEGOTIATED_VALUE); + RETURN_STRING_LITERAL(QUIC_DECOMPRESSION_FAILURE); + RETURN_STRING_LITERAL(QUIC_NETWORK_IDLE_TIMEOUT); + RETURN_STRING_LITERAL(QUIC_HANDSHAKE_TIMEOUT); + RETURN_STRING_LITERAL(QUIC_ERROR_MIGRATING_ADDRESS); + RETURN_STRING_LITERAL(QUIC_ERROR_MIGRATING_PORT); + RETURN_STRING_LITERAL(QUIC_PACKET_WRITE_ERROR); + RETURN_STRING_LITERAL(QUIC_PACKET_READ_ERROR); + RETURN_STRING_LITERAL(QUIC_EMPTY_STREAM_FRAME_NO_FIN); + RETURN_STRING_LITERAL(QUIC_INVALID_HEADERS_STREAM_DATA); + RETURN_STRING_LITERAL(QUIC_HEADERS_STREAM_DATA_DECOMPRESS_FAILURE); + RETURN_STRING_LITERAL(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA); + RETURN_STRING_LITERAL(QUIC_FLOW_CONTROL_SENT_TOO_MUCH_DATA); + RETURN_STRING_LITERAL(QUIC_FLOW_CONTROL_INVALID_WINDOW); + RETURN_STRING_LITERAL(QUIC_CONNECTION_IP_POOLED); + RETURN_STRING_LITERAL(QUIC_PROOF_INVALID); + RETURN_STRING_LITERAL(QUIC_CRYPTO_DUPLICATE_TAG); + RETURN_STRING_LITERAL(QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT); + RETURN_STRING_LITERAL(QUIC_CRYPTO_SERVER_CONFIG_EXPIRED); + RETURN_STRING_LITERAL(QUIC_INVALID_CHANNEL_ID_SIGNATURE); + RETURN_STRING_LITERAL(QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED); + RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_WHILE_VALIDATING_CLIENT_HELLO); + RETURN_STRING_LITERAL(QUIC_CRYPTO_UPDATE_BEFORE_HANDSHAKE_COMPLETE); + RETURN_STRING_LITERAL(QUIC_VERSION_NEGOTIATION_MISMATCH); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_OUTSTANDING_RECEIVED_PACKETS); + RETURN_STRING_LITERAL(QUIC_CONNECTION_CANCELLED); + RETURN_STRING_LITERAL(QUIC_BAD_PACKET_LOSS_RATE); + RETURN_STRING_LITERAL(QUIC_PUBLIC_RESETS_POST_HANDSHAKE); + RETURN_STRING_LITERAL(QUIC_FAILED_TO_SERIALIZE_PACKET); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_AVAILABLE_STREAMS); + RETURN_STRING_LITERAL(QUIC_UNENCRYPTED_FEC_DATA); + RETURN_STRING_LITERAL(QUIC_BAD_MULTIPATH_FLAG); + RETURN_STRING_LITERAL(QUIC_IP_ADDRESS_CHANGED); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_NO_MIGRATABLE_STREAMS); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_TOO_MANY_CHANGES); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_NO_NEW_NETWORK); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_NON_MIGRATABLE_STREAM); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_RTOS); + RETURN_STRING_LITERAL(QUIC_ATTEMPT_TO_SEND_UNENCRYPTED_STREAM_DATA); + RETURN_STRING_LITERAL(QUIC_MAYBE_CORRUPTED_MEMORY); + RETURN_STRING_LITERAL(QUIC_CRYPTO_CHLO_TOO_LARGE); + RETURN_STRING_LITERAL(QUIC_MULTIPATH_PATH_DOES_NOT_EXIST); + RETURN_STRING_LITERAL(QUIC_MULTIPATH_PATH_NOT_ACTIVE); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_STREAM_DATA_INTERVALS); + RETURN_STRING_LITERAL(QUIC_STREAM_SEQUENCER_INVALID_STATE); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_SESSIONS_ON_SERVER); + RETURN_STRING_LITERAL(QUIC_STREAM_LENGTH_OVERFLOW); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_DISABLED_BY_CONFIG); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_INTERNAL_ERROR); + RETURN_STRING_LITERAL(QUIC_INVALID_MAX_DATA_FRAME_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_STREAM_BLOCKED_DATA); + RETURN_STRING_LITERAL(QUIC_MAX_STREAMS_DATA); + RETURN_STRING_LITERAL(QUIC_STREAMS_BLOCKED_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_NEW_CONNECTION_ID_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_RETIRE_CONNECTION_ID_DATA); + RETURN_STRING_LITERAL(QUIC_CONNECTION_ID_LIMIT_ERROR); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_CONNECTION_ID_WAITING_TO_RETIRE); + RETURN_STRING_LITERAL(QUIC_INVALID_STOP_SENDING_FRAME_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_PATH_CHALLENGE_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_PATH_RESPONSE_DATA); + RETURN_STRING_LITERAL(QUIC_CONNECTION_MIGRATION_HANDSHAKE_UNCONFIRMED); + RETURN_STRING_LITERAL(QUIC_PEER_PORT_CHANGE_HANDSHAKE_UNCONFIRMED); + RETURN_STRING_LITERAL(QUIC_INVALID_MESSAGE_DATA); + RETURN_STRING_LITERAL(IETF_QUIC_PROTOCOL_VIOLATION); + RETURN_STRING_LITERAL(QUIC_INVALID_NEW_TOKEN); + RETURN_STRING_LITERAL(QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM); + RETURN_STRING_LITERAL(QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM); + RETURN_STRING_LITERAL(QUIC_STREAMS_BLOCKED_ERROR); + RETURN_STRING_LITERAL(QUIC_MAX_STREAMS_ERROR); + RETURN_STRING_LITERAL(QUIC_HTTP_DECODER_ERROR); + RETURN_STRING_LITERAL(QUIC_STALE_CONNECTION_CANCELLED); + RETURN_STRING_LITERAL(QUIC_IETF_GQUIC_ERROR_MISSING); + RETURN_STRING_LITERAL( + QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM); + RETURN_STRING_LITERAL(QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES); + RETURN_STRING_LITERAL(QUIC_TRANSPORT_INVALID_CLIENT_INDICATION); + RETURN_STRING_LITERAL(QUIC_QPACK_DECOMPRESSION_FAILED); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_ERROR); + RETURN_STRING_LITERAL(QUIC_QPACK_DECODER_STREAM_ERROR); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_STRING_LITERAL_TOO_LONG); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_HUFFMAN_ENCODING_ERROR); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_INVALID_STATIC_ENTRY); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_STATIC); + RETURN_STRING_LITERAL( + QUIC_QPACK_ENCODER_STREAM_INSERTION_INVALID_RELATIVE_INDEX); + RETURN_STRING_LITERAL( + QUIC_QPACK_ENCODER_STREAM_INSERTION_DYNAMIC_ENTRY_NOT_FOUND); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_DYNAMIC); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_LITERAL); + RETURN_STRING_LITERAL( + QUIC_QPACK_ENCODER_STREAM_DUPLICATE_INVALID_RELATIVE_INDEX); + RETURN_STRING_LITERAL( + QUIC_QPACK_ENCODER_STREAM_DUPLICATE_DYNAMIC_ENTRY_NOT_FOUND); + RETURN_STRING_LITERAL(QUIC_QPACK_ENCODER_STREAM_SET_DYNAMIC_TABLE_CAPACITY); + RETURN_STRING_LITERAL(QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE); + RETURN_STRING_LITERAL(QUIC_QPACK_DECODER_STREAM_INVALID_ZERO_INCREMENT); + RETURN_STRING_LITERAL(QUIC_QPACK_DECODER_STREAM_INCREMENT_OVERFLOW); + RETURN_STRING_LITERAL(QUIC_QPACK_DECODER_STREAM_IMPOSSIBLE_INSERT_COUNT); + RETURN_STRING_LITERAL(QUIC_QPACK_DECODER_STREAM_INCORRECT_ACKNOWLEDGEMENT); + RETURN_STRING_LITERAL(QUIC_STREAM_DATA_BEYOND_CLOSE_OFFSET); + RETURN_STRING_LITERAL(QUIC_STREAM_MULTIPLE_OFFSET); + RETURN_STRING_LITERAL(QUIC_HTTP_FRAME_TOO_LARGE); + RETURN_STRING_LITERAL(QUIC_HTTP_FRAME_ERROR); + RETURN_STRING_LITERAL(QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_CONTROL_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_SERVER_INITIATED_BIDIRECTIONAL_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_STREAM_WRONG_DIRECTION); + RETURN_STRING_LITERAL(QUIC_HTTP_CLOSED_CRITICAL_STREAM); + RETURN_STRING_LITERAL(QUIC_HTTP_MISSING_SETTINGS_FRAME); + RETURN_STRING_LITERAL(QUIC_HTTP_DUPLICATE_SETTING_IDENTIFIER); + RETURN_STRING_LITERAL(QUIC_HTTP_INVALID_MAX_PUSH_ID); + RETURN_STRING_LITERAL(QUIC_HTTP_STREAM_LIMIT_TOO_LOW); + RETURN_STRING_LITERAL(QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH); + RETURN_STRING_LITERAL(QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH); + RETURN_STRING_LITERAL(QUIC_HTTP_GOAWAY_INVALID_STREAM_ID); + RETURN_STRING_LITERAL(QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS); + RETURN_STRING_LITERAL(QUIC_HTTP_RECEIVE_SPDY_SETTING); + RETURN_STRING_LITERAL(QUIC_HTTP_RECEIVE_SPDY_FRAME); + RETURN_STRING_LITERAL(QUIC_HTTP_RECEIVE_SERVER_PUSH); + RETURN_STRING_LITERAL(QUIC_HTTP_INVALID_SETTING_VALUE); + RETURN_STRING_LITERAL(QUIC_HPACK_INDEX_VARINT_ERROR); + RETURN_STRING_LITERAL(QUIC_HPACK_NAME_LENGTH_VARINT_ERROR); + RETURN_STRING_LITERAL(QUIC_HPACK_VALUE_LENGTH_VARINT_ERROR); + RETURN_STRING_LITERAL(QUIC_HPACK_NAME_TOO_LONG); + RETURN_STRING_LITERAL(QUIC_HPACK_VALUE_TOO_LONG); + RETURN_STRING_LITERAL(QUIC_HPACK_NAME_HUFFMAN_ERROR); + RETURN_STRING_LITERAL(QUIC_HPACK_VALUE_HUFFMAN_ERROR); + RETURN_STRING_LITERAL(QUIC_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE); + RETURN_STRING_LITERAL(QUIC_HPACK_INVALID_INDEX); + RETURN_STRING_LITERAL(QUIC_HPACK_INVALID_NAME_INDEX); + RETURN_STRING_LITERAL(QUIC_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED); + RETURN_STRING_LITERAL( + QUIC_HPACK_INITIAL_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK); + RETURN_STRING_LITERAL( + QUIC_HPACK_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING); + RETURN_STRING_LITERAL(QUIC_HPACK_TRUNCATED_BLOCK); + RETURN_STRING_LITERAL(QUIC_HPACK_FRAGMENT_TOO_LONG); + RETURN_STRING_LITERAL(QUIC_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT); + RETURN_STRING_LITERAL(QUIC_ZERO_RTT_UNRETRANSMITTABLE); + RETURN_STRING_LITERAL(QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED); + RETURN_STRING_LITERAL(QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED); + RETURN_STRING_LITERAL(QUIC_SILENT_IDLE_TIMEOUT); + RETURN_STRING_LITERAL(QUIC_MISSING_WRITE_KEYS); + RETURN_STRING_LITERAL(QUIC_KEY_UPDATE_ERROR); + RETURN_STRING_LITERAL(QUIC_AEAD_LIMIT_REACHED); + RETURN_STRING_LITERAL(QUIC_MAX_AGE_TIMEOUT); + RETURN_STRING_LITERAL(QUIC_INVALID_PRIORITY_UPDATE); + RETURN_STRING_LITERAL(QUIC_TLS_BAD_CERTIFICATE); + RETURN_STRING_LITERAL(QUIC_TLS_UNSUPPORTED_CERTIFICATE); + RETURN_STRING_LITERAL(QUIC_TLS_CERTIFICATE_REVOKED); + RETURN_STRING_LITERAL(QUIC_TLS_CERTIFICATE_EXPIRED); + RETURN_STRING_LITERAL(QUIC_TLS_CERTIFICATE_UNKNOWN); + RETURN_STRING_LITERAL(QUIC_TLS_INTERNAL_ERROR); + RETURN_STRING_LITERAL(QUIC_TLS_UNRECOGNIZED_NAME); + RETURN_STRING_LITERAL(QUIC_TLS_CERTIFICATE_REQUIRED); + RETURN_STRING_LITERAL(QUIC_INVALID_CHARACTER_IN_FIELD_VALUE); + RETURN_STRING_LITERAL(QUIC_TLS_UNEXPECTED_KEYING_MATERIAL_EXPORT_LABEL); + RETURN_STRING_LITERAL(QUIC_TLS_KEYING_MATERIAL_EXPORTS_MISMATCH); + RETURN_STRING_LITERAL(QUIC_TLS_KEYING_MATERIAL_EXPORT_NOT_AVAILABLE); + RETURN_STRING_LITERAL(QUIC_UNEXPECTED_DATA_BEFORE_ENCRYPTION_ESTABLISHED); + RETURN_STRING_LITERAL(QUIC_SERVER_UNHEALTHY); + + RETURN_STRING_LITERAL(QUIC_LAST_ERROR); + // Intentionally have no default case, so we'll break the build + // if we add errors and don't put them here. + } + // Return a default value so that we return this when |error| doesn't match + // any of the QuicErrorCodes. This can happen when the ConnectionClose + // frame sent by the peer (attacker) has invalid error code. + return "INVALID_ERROR_CODE"; +} + +std::string QuicIetfTransportErrorCodeString(QuicIetfTransportErrorCodes c) { + if (c >= CRYPTO_ERROR_FIRST && c <= CRYPTO_ERROR_LAST) { + const int tls_error = static_cast(c - CRYPTO_ERROR_FIRST); + const char* tls_error_description = SSL_alert_desc_string_long(tls_error); + if (strcmp("unknown", tls_error_description) != 0) { + return absl::StrCat("CRYPTO_ERROR(", tls_error_description, ")"); + } + return absl::StrCat("CRYPTO_ERROR(unknown(", tls_error, "))"); + } + + switch (c) { + RETURN_STRING_LITERAL(NO_IETF_QUIC_ERROR); + RETURN_STRING_LITERAL(INTERNAL_ERROR); + RETURN_STRING_LITERAL(SERVER_BUSY_ERROR); + RETURN_STRING_LITERAL(FLOW_CONTROL_ERROR); + RETURN_STRING_LITERAL(STREAM_LIMIT_ERROR); + RETURN_STRING_LITERAL(STREAM_STATE_ERROR); + RETURN_STRING_LITERAL(FINAL_SIZE_ERROR); + RETURN_STRING_LITERAL(FRAME_ENCODING_ERROR); + RETURN_STRING_LITERAL(TRANSPORT_PARAMETER_ERROR); + RETURN_STRING_LITERAL(CONNECTION_ID_LIMIT_ERROR); + RETURN_STRING_LITERAL(PROTOCOL_VIOLATION); + RETURN_STRING_LITERAL(INVALID_TOKEN); + RETURN_STRING_LITERAL(CRYPTO_BUFFER_EXCEEDED); + RETURN_STRING_LITERAL(KEY_UPDATE_ERROR); + RETURN_STRING_LITERAL(AEAD_LIMIT_REACHED); + // CRYPTO_ERROR is handled in the if before this switch, these cases do not + // change behavior and are only here to make the compiler happy. + case CRYPTO_ERROR_FIRST: + case CRYPTO_ERROR_LAST: + QUICHE_DCHECK(false) << "Unexpected error " << static_cast(c); + break; + } + + return absl::StrCat("Unknown(", static_cast(c), ")"); +} + +std::ostream& operator<<(std::ostream& os, + const QuicIetfTransportErrorCodes& c) { + os << QuicIetfTransportErrorCodeString(c); + return os; +} + +QuicErrorCodeToIetfMapping QuicErrorCodeToTransportErrorCode( + QuicErrorCode error) { + switch (error) { + case QUIC_NO_ERROR: + return {true, static_cast(NO_IETF_QUIC_ERROR)}; + case QUIC_INTERNAL_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_STREAM_DATA_AFTER_TERMINATION: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_PACKET_HEADER: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_FRAME_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_MISSING_PAYLOAD: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_FEC_DATA: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_STREAM_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_OVERLAPPING_STREAM_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_UNENCRYPTED_STREAM_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_ATTEMPT_TO_SEND_UNENCRYPTED_STREAM_DATA: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_MAYBE_CORRUPTED_MEMORY: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_UNENCRYPTED_FEC_DATA: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_RST_STREAM_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_CONNECTION_CLOSE_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_GOAWAY_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_WINDOW_UPDATE_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_BLOCKED_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_STOP_WAITING_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_PATH_CLOSE_DATA: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_ACK_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_MESSAGE_DATA: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_VERSION_NEGOTIATION_PACKET: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_PUBLIC_RST_PACKET: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_DECRYPTION_FAILURE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_ENCRYPTION_FAILURE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_PACKET_TOO_LARGE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_PEER_GOING_AWAY: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_STREAM_ID: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_PRIORITY: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_TOO_MANY_OPEN_STREAMS: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TOO_MANY_AVAILABLE_STREAMS: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_PUBLIC_RESET: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_VERSION: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_PACKET_WRONG_VERSION: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_HEADER_ID: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_NEGOTIATED_VALUE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_DECOMPRESSION_FAILURE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_NETWORK_IDLE_TIMEOUT: + return {true, static_cast(NO_IETF_QUIC_ERROR)}; + case QUIC_SILENT_IDLE_TIMEOUT: + return {true, static_cast(NO_IETF_QUIC_ERROR)}; + case QUIC_HANDSHAKE_TIMEOUT: + return {true, static_cast(NO_IETF_QUIC_ERROR)}; + case QUIC_ERROR_MIGRATING_ADDRESS: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_ERROR_MIGRATING_PORT: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_PACKET_WRITE_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_PACKET_READ_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_EMPTY_STREAM_FRAME_NO_FIN: + return {true, static_cast(FRAME_ENCODING_ERROR)}; + case QUIC_INVALID_HEADERS_STREAM_DATA: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HEADERS_STREAM_DATA_DECOMPRESS_FAILURE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA: + return {true, static_cast(FLOW_CONTROL_ERROR)}; + case QUIC_FLOW_CONTROL_SENT_TOO_MUCH_DATA: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_FLOW_CONTROL_INVALID_WINDOW: + return {true, static_cast(FLOW_CONTROL_ERROR)}; + case QUIC_CONNECTION_IP_POOLED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_TOO_MANY_OUTSTANDING_RECEIVED_PACKETS: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_CANCELLED: + return {true, static_cast(NO_IETF_QUIC_ERROR)}; + case QUIC_BAD_PACKET_LOSS_RATE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_PUBLIC_RESETS_POST_HANDSHAKE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_FAILED_TO_SERIALIZE_PACKET: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_TOO_MANY_RTOS: + return {true, static_cast(NO_IETF_QUIC_ERROR)}; + case QUIC_HANDSHAKE_FAILED: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_TAGS_OUT_OF_ORDER: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_TOO_MANY_ENTRIES: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_INVALID_VALUE_LENGTH: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_CRYPTO_MESSAGE_TYPE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_CHANNEL_ID_SIGNATURE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_MESSAGE_PARAMETER_NO_OVERLAP: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_UNSUPPORTED_PROOF_DEMAND: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_INTERNAL_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CRYPTO_VERSION_NOT_SUPPORTED: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_NO_SUPPORT: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_TOO_MANY_REJECTS: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_PROOF_INVALID: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_DUPLICATE_TAG: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_SERVER_CONFIG_EXPIRED: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_MESSAGE_WHILE_VALIDATING_CLIENT_HELLO: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_UPDATE_BEFORE_HANDSHAKE_COMPLETE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_CRYPTO_CHLO_TOO_LARGE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_VERSION_NEGOTIATION_MISMATCH: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_BAD_MULTIPATH_FLAG: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_MULTIPATH_PATH_DOES_NOT_EXIST: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_MULTIPATH_PATH_NOT_ACTIVE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_IP_ADDRESS_CHANGED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_NO_MIGRATABLE_STREAMS: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_TOO_MANY_CHANGES: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_NO_NEW_NETWORK: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_NON_MIGRATABLE_STREAM: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_DISABLED_BY_CONFIG: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_INTERNAL_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_CONNECTION_MIGRATION_HANDSHAKE_UNCONFIRMED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_PEER_PORT_CHANGE_HANDSHAKE_UNCONFIRMED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_TOO_MANY_STREAM_DATA_INTERVALS: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_STREAM_SEQUENCER_INVALID_STATE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_TOO_MANY_SESSIONS_ON_SERVER: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_STREAM_LENGTH_OVERFLOW: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_MAX_DATA_FRAME_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_MAX_STREAMS_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_STREAMS_BLOCKED_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_STREAM_BLOCKED_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_NEW_CONNECTION_ID_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_STOP_SENDING_FRAME_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_PATH_CHALLENGE_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_PATH_RESPONSE_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case IETF_QUIC_PROTOCOL_VIOLATION: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_INVALID_NEW_TOKEN: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM: + return {true, static_cast(STREAM_STATE_ERROR)}; + case QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_RETIRE_CONNECTION_ID_DATA: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_STREAMS_BLOCKED_ERROR: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_MAX_STREAMS_ERROR: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_HTTP_DECODER_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_STALE_CONNECTION_CANCELLED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_IETF_GQUIC_ERROR_MISSING: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TRANSPORT_INVALID_CLIENT_INDICATION: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_QPACK_DECOMPRESSION_FAILED: + return {false, static_cast( + QuicHttpQpackErrorCode::DECOMPRESSION_FAILED)}; + case QUIC_QPACK_ENCODER_STREAM_ERROR: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_DECODER_STREAM_ERROR: + return {false, static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_STRING_LITERAL_TOO_LONG: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_HUFFMAN_ENCODING_ERROR: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_INVALID_STATIC_ENTRY: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_STATIC: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_INSERTION_INVALID_RELATIVE_INDEX: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_INSERTION_DYNAMIC_ENTRY_NOT_FOUND: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_DYNAMIC: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_LITERAL: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_DUPLICATE_INVALID_RELATIVE_INDEX: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_DUPLICATE_DYNAMIC_ENTRY_NOT_FOUND: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_ENCODER_STREAM_SET_DYNAMIC_TABLE_CAPACITY: + return {false, static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR)}; + case QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE: + return {false, static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}; + case QUIC_QPACK_DECODER_STREAM_INVALID_ZERO_INCREMENT: + return {false, static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}; + case QUIC_QPACK_DECODER_STREAM_INCREMENT_OVERFLOW: + return {false, static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}; + case QUIC_QPACK_DECODER_STREAM_IMPOSSIBLE_INSERT_COUNT: + return {false, static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}; + case QUIC_QPACK_DECODER_STREAM_INCORRECT_ACKNOWLEDGEMENT: + return {false, static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}; + case QUIC_STREAM_DATA_BEYOND_CLOSE_OFFSET: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_STREAM_MULTIPLE_OFFSET: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_HTTP_FRAME_TOO_LARGE: + return {false, static_cast(QuicHttp3ErrorCode::EXCESSIVE_LOAD)}; + case QUIC_HTTP_FRAME_ERROR: + return {false, static_cast(QuicHttp3ErrorCode::FRAME_ERROR)}; + case QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM: + return {false, + static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED)}; + case QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM: + return {false, + static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED)}; + case QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM: + return {false, + static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED)}; + case QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_CONTROL_STREAM: + return {false, + static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED)}; + case QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM: + return {false, + static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR)}; + case QUIC_HTTP_SERVER_INITIATED_BIDIRECTIONAL_STREAM: + return {false, + static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR)}; + case QUIC_HTTP_STREAM_WRONG_DIRECTION: + return {true, static_cast(STREAM_STATE_ERROR)}; + case QUIC_HTTP_CLOSED_CRITICAL_STREAM: + return {false, static_cast( + QuicHttp3ErrorCode::CLOSED_CRITICAL_STREAM)}; + case QUIC_HTTP_MISSING_SETTINGS_FRAME: + return {false, + static_cast(QuicHttp3ErrorCode::MISSING_SETTINGS)}; + case QUIC_HTTP_DUPLICATE_SETTING_IDENTIFIER: + return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; + case QUIC_HTTP_INVALID_MAX_PUSH_ID: + return {false, static_cast(QuicHttp3ErrorCode::ID_ERROR)}; + case QUIC_HTTP_STREAM_LIMIT_TOO_LOW: + return {false, static_cast( + QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR)}; + case QUIC_HTTP_RECEIVE_SERVER_PUSH: + return {false, static_cast( + QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR)}; + case QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH: + return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; + case QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HTTP_GOAWAY_INVALID_STREAM_ID: + return {false, static_cast(QuicHttp3ErrorCode::ID_ERROR)}; + case QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS: + return {false, static_cast(QuicHttp3ErrorCode::ID_ERROR)}; + case QUIC_HTTP_RECEIVE_SPDY_SETTING: + return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; + case QUIC_HTTP_INVALID_SETTING_VALUE: + return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; + case QUIC_HTTP_RECEIVE_SPDY_FRAME: + return {false, + static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED)}; + case QUIC_HPACK_INDEX_VARINT_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_NAME_LENGTH_VARINT_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_VALUE_LENGTH_VARINT_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_NAME_TOO_LONG: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_VALUE_TOO_LONG: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_NAME_HUFFMAN_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_VALUE_HUFFMAN_ERROR: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_INVALID_INDEX: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_INVALID_NAME_INDEX: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_INITIAL_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_TRUNCATED_BLOCK: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_FRAGMENT_TOO_LONG: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_ZERO_RTT_UNRETRANSMITTABLE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_MISSING_WRITE_KEYS: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_KEY_UPDATE_ERROR: + return {true, static_cast(KEY_UPDATE_ERROR)}; + case QUIC_AEAD_LIMIT_REACHED: + return {true, static_cast(AEAD_LIMIT_REACHED)}; + case QUIC_MAX_AGE_TIMEOUT: + return {false, static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR)}; + case QUIC_INVALID_PRIORITY_UPDATE: + return {false, static_cast( + QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR)}; + case QUIC_TLS_BAD_CERTIFICATE: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_BAD_CERTIFICATE)}; + case QUIC_TLS_UNSUPPORTED_CERTIFICATE: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_UNSUPPORTED_CERTIFICATE)}; + case QUIC_TLS_CERTIFICATE_REVOKED: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_CERTIFICATE_REVOKED)}; + case QUIC_TLS_CERTIFICATE_EXPIRED: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_CERTIFICATE_EXPIRED)}; + case QUIC_TLS_CERTIFICATE_UNKNOWN: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_CERTIFICATE_UNKNOWN)}; + case QUIC_TLS_INTERNAL_ERROR: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_INTERNAL_ERROR)}; + case QUIC_TLS_UNRECOGNIZED_NAME: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_UNRECOGNIZED_NAME)}; + case QUIC_TLS_CERTIFICATE_REQUIRED: + return {true, static_cast(CRYPTO_ERROR_FIRST + + SSL_AD_CERTIFICATE_REQUIRED)}; + case QUIC_CONNECTION_ID_LIMIT_ERROR: + return {true, static_cast(CONNECTION_ID_LIMIT_ERROR)}; + case QUIC_TOO_MANY_CONNECTION_ID_WAITING_TO_RETIRE: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_CHARACTER_IN_FIELD_VALUE: + return {false, static_cast(QuicHttp3ErrorCode::MESSAGE_ERROR)}; + case QUIC_TLS_UNEXPECTED_KEYING_MATERIAL_EXPORT_LABEL: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TLS_KEYING_MATERIAL_EXPORTS_MISMATCH: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TLS_KEYING_MATERIAL_EXPORT_NOT_AVAILABLE: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_UNEXPECTED_DATA_BEFORE_ENCRYPTION_ESTABLISHED: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_SERVER_UNHEALTHY: + return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_LAST_ERROR: + return {false, static_cast(QUIC_LAST_ERROR)}; + } + // This function should not be called with unknown error code. + return {true, static_cast(INTERNAL_ERROR)}; +} + +QuicErrorCode TlsAlertToQuicErrorCode(uint8_t desc) { + switch (desc) { + case SSL_AD_BAD_CERTIFICATE: + return QUIC_TLS_BAD_CERTIFICATE; + case SSL_AD_UNSUPPORTED_CERTIFICATE: + return QUIC_TLS_UNSUPPORTED_CERTIFICATE; + case SSL_AD_CERTIFICATE_REVOKED: + return QUIC_TLS_CERTIFICATE_REVOKED; + case SSL_AD_CERTIFICATE_EXPIRED: + return QUIC_TLS_CERTIFICATE_EXPIRED; + case SSL_AD_CERTIFICATE_UNKNOWN: + return QUIC_TLS_CERTIFICATE_UNKNOWN; + case SSL_AD_INTERNAL_ERROR: + return QUIC_TLS_INTERNAL_ERROR; + case SSL_AD_UNRECOGNIZED_NAME: + return QUIC_TLS_UNRECOGNIZED_NAME; + case SSL_AD_CERTIFICATE_REQUIRED: + return QUIC_TLS_CERTIFICATE_REQUIRED; + default: + return QUIC_HANDSHAKE_FAILED; + } +} + +// Convert a QuicRstStreamErrorCode to an application error code to be used in +// an IETF QUIC RESET_STREAM frame +uint64_t RstStreamErrorCodeToIetfResetStreamErrorCode( + QuicRstStreamErrorCode rst_stream_error_code) { + switch (rst_stream_error_code) { + case QUIC_STREAM_NO_ERROR: + return static_cast(QuicHttp3ErrorCode::HTTP3_NO_ERROR); + case QUIC_ERROR_PROCESSING_STREAM: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_MULTIPLE_TERMINATION_OFFSETS: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_BAD_APPLICATION_PAYLOAD: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_STREAM_CONNECTION_ERROR: + return static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR); + case QUIC_STREAM_PEER_GOING_AWAY: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_STREAM_CANCELLED: + return static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED); + case QUIC_RST_ACKNOWLEDGEMENT: + return static_cast(QuicHttp3ErrorCode::HTTP3_NO_ERROR); + case QUIC_REFUSED_STREAM: + return static_cast(QuicHttp3ErrorCode::ID_ERROR); + case QUIC_INVALID_PROMISE_URL: + return static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR); + case QUIC_UNAUTHORIZED_PROMISE_URL: + return static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR); + case QUIC_DUPLICATE_PROMISE_URL: + return static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR); + case QUIC_PROMISE_VARY_MISMATCH: + return static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED); + case QUIC_INVALID_PROMISE_METHOD: + return static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR); + case QUIC_PUSH_STREAM_TIMED_OUT: + return static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED); + case QUIC_HEADERS_TOO_LARGE: + return static_cast(QuicHttp3ErrorCode::EXCESSIVE_LOAD); + case QUIC_STREAM_TTL_EXPIRED: + return static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED); + case QUIC_DATA_AFTER_CLOSE_OFFSET: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_STREAM_GENERAL_PROTOCOL_ERROR: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_STREAM_INTERNAL_ERROR: + return static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR); + case QUIC_STREAM_STREAM_CREATION_ERROR: + return static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR); + case QUIC_STREAM_CLOSED_CRITICAL_STREAM: + return static_cast(QuicHttp3ErrorCode::CLOSED_CRITICAL_STREAM); + case QUIC_STREAM_FRAME_UNEXPECTED: + return static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED); + case QUIC_STREAM_FRAME_ERROR: + return static_cast(QuicHttp3ErrorCode::FRAME_ERROR); + case QUIC_STREAM_EXCESSIVE_LOAD: + return static_cast(QuicHttp3ErrorCode::EXCESSIVE_LOAD); + case QUIC_STREAM_ID_ERROR: + return static_cast(QuicHttp3ErrorCode::ID_ERROR); + case QUIC_STREAM_SETTINGS_ERROR: + return static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR); + case QUIC_STREAM_MISSING_SETTINGS: + return static_cast(QuicHttp3ErrorCode::MISSING_SETTINGS); + case QUIC_STREAM_REQUEST_REJECTED: + return static_cast(QuicHttp3ErrorCode::REQUEST_REJECTED); + case QUIC_STREAM_REQUEST_INCOMPLETE: + return static_cast(QuicHttp3ErrorCode::REQUEST_INCOMPLETE); + case QUIC_STREAM_CONNECT_ERROR: + return static_cast(QuicHttp3ErrorCode::CONNECT_ERROR); + case QUIC_STREAM_VERSION_FALLBACK: + return static_cast(QuicHttp3ErrorCode::VERSION_FALLBACK); + case QUIC_STREAM_DECOMPRESSION_FAILED: + return static_cast( + QuicHttpQpackErrorCode::DECOMPRESSION_FAILED); + case QUIC_STREAM_ENCODER_STREAM_ERROR: + return static_cast( + QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR); + case QUIC_STREAM_DECODER_STREAM_ERROR: + return static_cast( + QuicHttpQpackErrorCode::DECODER_STREAM_ERROR); + case QUIC_STREAM_UNKNOWN_APPLICATION_ERROR_CODE: + return static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR); + case QUIC_STREAM_WEBTRANSPORT_SESSION_GONE: + return static_cast(QuicHttp3ErrorCode::CONNECT_ERROR); + case QUIC_STREAM_WEBTRANSPORT_BUFFERED_STREAMS_LIMIT_EXCEEDED: + return static_cast(QuicHttp3ErrorCode::CONNECT_ERROR); + case QUIC_APPLICATION_DONE_WITH_STREAM: + return static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR); + case QUIC_STREAM_LAST_ERROR: + return static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR); + } + return static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR); +} + +// Convert the application error code of an IETF QUIC RESET_STREAM frame +// to QuicRstStreamErrorCode. +QuicRstStreamErrorCode IetfResetStreamErrorCodeToRstStreamErrorCode( + uint64_t ietf_error_code) { + switch (ietf_error_code) { + case static_cast(QuicHttp3ErrorCode::HTTP3_NO_ERROR): + return QUIC_STREAM_NO_ERROR; + case static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR): + return QUIC_STREAM_GENERAL_PROTOCOL_ERROR; + case static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR): + return QUIC_STREAM_INTERNAL_ERROR; + case static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR): + return QUIC_STREAM_STREAM_CREATION_ERROR; + case static_cast(QuicHttp3ErrorCode::CLOSED_CRITICAL_STREAM): + return QUIC_STREAM_CLOSED_CRITICAL_STREAM; + case static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED): + return QUIC_STREAM_FRAME_UNEXPECTED; + case static_cast(QuicHttp3ErrorCode::FRAME_ERROR): + return QUIC_STREAM_FRAME_ERROR; + case static_cast(QuicHttp3ErrorCode::EXCESSIVE_LOAD): + return QUIC_STREAM_EXCESSIVE_LOAD; + case static_cast(QuicHttp3ErrorCode::ID_ERROR): + return QUIC_STREAM_ID_ERROR; + case static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR): + return QUIC_STREAM_SETTINGS_ERROR; + case static_cast(QuicHttp3ErrorCode::MISSING_SETTINGS): + return QUIC_STREAM_MISSING_SETTINGS; + case static_cast(QuicHttp3ErrorCode::REQUEST_REJECTED): + return QUIC_STREAM_REQUEST_REJECTED; + case static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED): + return QUIC_STREAM_CANCELLED; + case static_cast(QuicHttp3ErrorCode::REQUEST_INCOMPLETE): + return QUIC_STREAM_REQUEST_INCOMPLETE; + case static_cast(QuicHttp3ErrorCode::CONNECT_ERROR): + return QUIC_STREAM_CONNECT_ERROR; + case static_cast(QuicHttp3ErrorCode::VERSION_FALLBACK): + return QUIC_STREAM_VERSION_FALLBACK; + case static_cast(QuicHttpQpackErrorCode::DECOMPRESSION_FAILED): + return QUIC_STREAM_DECOMPRESSION_FAILED; + case static_cast(QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR): + return QUIC_STREAM_ENCODER_STREAM_ERROR; + case static_cast(QuicHttpQpackErrorCode::DECODER_STREAM_ERROR): + return QUIC_STREAM_DECODER_STREAM_ERROR; + } + return QUIC_STREAM_UNKNOWN_APPLICATION_ERROR_CODE; +} + +// static +QuicResetStreamError QuicResetStreamError::FromInternal( + QuicRstStreamErrorCode code) { + return QuicResetStreamError( + code, RstStreamErrorCodeToIetfResetStreamErrorCode(code)); +} + +// static +QuicResetStreamError QuicResetStreamError::FromIetf(uint64_t code) { + return QuicResetStreamError( + IetfResetStreamErrorCodeToRstStreamErrorCode(code), code); +} + +// static +QuicResetStreamError QuicResetStreamError::FromIetf(QuicHttp3ErrorCode code) { + return FromIetf(static_cast(code)); +} + +// static +QuicResetStreamError QuicResetStreamError::FromIetf( + QuicHttpQpackErrorCode code) { + return FromIetf(static_cast(code)); +} + +#undef RETURN_STRING_LITERAL // undef for jumbo builds + +} // namespace quic diff --git a/quiche/quic/core/quic_error_codes.h b/quiche/quic/core/quic_error_codes.h new file mode 100644 index 000000000000..e1a7b51e3d4f --- /dev/null +++ b/quiche/quic/core/quic_error_codes.h @@ -0,0 +1,776 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_ERROR_CODES_H_ +#define QUICHE_QUIC_CORE_QUIC_ERROR_CODES_H_ + +#include +#include +#include + +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QuicRstStreamErrorCode is encoded as a single octet on-the-wire in IETF QUIC +// and a 32-bit integer in gQUIC. +enum QuicRstStreamErrorCode : uint32_t { + // Complete response has been sent, sending a RST to ask the other endpoint + // to stop sending request data without discarding the response. + QUIC_STREAM_NO_ERROR = 0, + + // There was some error which halted stream processing. + QUIC_ERROR_PROCESSING_STREAM = 1, + // We got two fin or reset offsets which did not match. + QUIC_MULTIPLE_TERMINATION_OFFSETS = 2, + // We got bad payload and can not respond to it at the protocol level. + QUIC_BAD_APPLICATION_PAYLOAD = 3, + // Stream closed due to connection error. No reset frame is sent when this + // happens. + QUIC_STREAM_CONNECTION_ERROR = 4, + // GoAway frame sent. No more stream can be created. + QUIC_STREAM_PEER_GOING_AWAY = 5, + // The stream has been cancelled. + QUIC_STREAM_CANCELLED = 6, + // Closing stream locally, sending a RST to allow for proper flow control + // accounting. Sent in response to a RST from the peer. + QUIC_RST_ACKNOWLEDGEMENT = 7, + // Receiver refused to create the stream (because its limit on open streams + // has been reached). The sender should retry the request later (using + // another stream). + QUIC_REFUSED_STREAM = 8, + // Invalid URL in PUSH_PROMISE request header. + QUIC_INVALID_PROMISE_URL = 9, + // Server is not authoritative for this URL. + QUIC_UNAUTHORIZED_PROMISE_URL = 10, + // Can't have more than one active PUSH_PROMISE per URL. + QUIC_DUPLICATE_PROMISE_URL = 11, + // Vary check failed. + QUIC_PROMISE_VARY_MISMATCH = 12, + // Only GET and HEAD methods allowed. + QUIC_INVALID_PROMISE_METHOD = 13, + // The push stream is unclaimed and timed out. + QUIC_PUSH_STREAM_TIMED_OUT = 14, + // Received headers were too large. + QUIC_HEADERS_TOO_LARGE = 15, + // The data is not likely arrive in time. + QUIC_STREAM_TTL_EXPIRED = 16, + // The stream received data that goes beyond its close offset. + QUIC_DATA_AFTER_CLOSE_OFFSET = 17, + // Peer violated protocol requirements in a way which does not match a more + // specific error code, or endpoint declines to use the more specific error + // code. + QUIC_STREAM_GENERAL_PROTOCOL_ERROR = 18, + // An internal error has occurred. + QUIC_STREAM_INTERNAL_ERROR = 19, + // Peer created a stream that will not be accepted. + QUIC_STREAM_STREAM_CREATION_ERROR = 20, + // A stream required by the connection was closed or reset. + QUIC_STREAM_CLOSED_CRITICAL_STREAM = 21, + // A frame was received which was not permitted in the current state or on the + // current stream. + QUIC_STREAM_FRAME_UNEXPECTED = 22, + // A frame that fails to satisfy layout requirements or with an invalid size + // was received. + QUIC_STREAM_FRAME_ERROR = 23, + // Peer exhibits a behavior that might be generating excessive load. + QUIC_STREAM_EXCESSIVE_LOAD = 24, + // A Stream ID or Push ID was used incorrectly, such as exceeding a limit, + // reducing a limit, or being reused. + QUIC_STREAM_ID_ERROR = 25, + // Error in the payload of a SETTINGS frame. + QUIC_STREAM_SETTINGS_ERROR = 26, + // No SETTINGS frame was received at the beginning of the control stream. + QUIC_STREAM_MISSING_SETTINGS = 27, + // A server rejected a request without performing any application processing. + QUIC_STREAM_REQUEST_REJECTED = 28, + // The client's stream terminated without containing a fully-formed request. + QUIC_STREAM_REQUEST_INCOMPLETE = 29, + // The connection established in response to a CONNECT request was reset or + // abnormally closed. + QUIC_STREAM_CONNECT_ERROR = 30, + // The requested operation cannot be served over HTTP/3. + // The peer should retry over HTTP/1.1. + QUIC_STREAM_VERSION_FALLBACK = 31, + // The QPACK decoder failed to interpret a header block and is not able to + // continue decoding that header block. + QUIC_STREAM_DECOMPRESSION_FAILED = 32, + // The QPACK decoder failed to interpret an encoder instruction received on + // the encoder stream. + QUIC_STREAM_ENCODER_STREAM_ERROR = 33, + // The QPACK encoder failed to interpret a decoder instruction received on the + // decoder stream. + QUIC_STREAM_DECODER_STREAM_ERROR = 34, + // IETF RESET_FRAME application error code not matching any HTTP/3 or QPACK + // error codes. + QUIC_STREAM_UNKNOWN_APPLICATION_ERROR_CODE = 35, + // WebTransport session is going away, causing all underlying streams to be + // reset. + QUIC_STREAM_WEBTRANSPORT_SESSION_GONE = 36, + // There is no corresponding WebTransport session to associate this stream + // with, and the limit for buffered streams has been exceeded. + QUIC_STREAM_WEBTRANSPORT_BUFFERED_STREAMS_LIMIT_EXCEEDED = 37, + // Application layer done with the current stream. + QUIC_APPLICATION_DONE_WITH_STREAM = 38, + // No error. Used as bound while iterating. + QUIC_STREAM_LAST_ERROR = 39, +}; +// QuicRstStreamErrorCode is encoded as a single octet on-the-wire. +static_assert(static_cast(QUIC_STREAM_LAST_ERROR) <= + std::numeric_limits::max(), + "QuicRstStreamErrorCode exceeds single octet"); + +// These values must remain stable as they are uploaded to UMA histograms. +// To add a new error code, use the current value of QUIC_LAST_ERROR and +// increment QUIC_LAST_ERROR. +enum QuicErrorCode { + QUIC_NO_ERROR = 0, + + // Connection has reached an invalid state. + QUIC_INTERNAL_ERROR = 1, + // There were data frames after the a fin or reset. + QUIC_STREAM_DATA_AFTER_TERMINATION = 2, + // Control frame is malformed. + QUIC_INVALID_PACKET_HEADER = 3, + // Frame data is malformed. + QUIC_INVALID_FRAME_DATA = 4, + // The packet contained no payload. + QUIC_MISSING_PAYLOAD = 48, + // FEC data is malformed. + QUIC_INVALID_FEC_DATA = 5, + // STREAM frame data is malformed. + QUIC_INVALID_STREAM_DATA = 46, + // STREAM frame data overlaps with buffered data. + QUIC_OVERLAPPING_STREAM_DATA = 87, + // Received STREAM frame data is not encrypted. + QUIC_UNENCRYPTED_STREAM_DATA = 61, + // Attempt to send unencrypted STREAM frame. + QUIC_ATTEMPT_TO_SEND_UNENCRYPTED_STREAM_DATA = 88, + // Received a frame which is likely the result of memory corruption. + QUIC_MAYBE_CORRUPTED_MEMORY = 89, + // FEC frame data is not encrypted. + QUIC_UNENCRYPTED_FEC_DATA = 77, + // RST_STREAM frame data is malformed. + QUIC_INVALID_RST_STREAM_DATA = 6, + // CONNECTION_CLOSE frame data is malformed. + QUIC_INVALID_CONNECTION_CLOSE_DATA = 7, + // GOAWAY frame data is malformed. + QUIC_INVALID_GOAWAY_DATA = 8, + // WINDOW_UPDATE frame data is malformed. + QUIC_INVALID_WINDOW_UPDATE_DATA = 57, + // BLOCKED frame data is malformed. + QUIC_INVALID_BLOCKED_DATA = 58, + // STOP_WAITING frame data is malformed. + QUIC_INVALID_STOP_WAITING_DATA = 60, + // PATH_CLOSE frame data is malformed. + QUIC_INVALID_PATH_CLOSE_DATA = 78, + // ACK frame data is malformed. + QUIC_INVALID_ACK_DATA = 9, + // Message frame data is malformed. + QUIC_INVALID_MESSAGE_DATA = 112, + + // Version negotiation packet is malformed. + QUIC_INVALID_VERSION_NEGOTIATION_PACKET = 10, + // Public RST packet is malformed. + QUIC_INVALID_PUBLIC_RST_PACKET = 11, + // There was an error decrypting. + QUIC_DECRYPTION_FAILURE = 12, + // There was an error encrypting. + QUIC_ENCRYPTION_FAILURE = 13, + // The packet exceeded kMaxOutgoingPacketSize. + QUIC_PACKET_TOO_LARGE = 14, + // The peer is going away. May be a client or server. + QUIC_PEER_GOING_AWAY = 16, + // A stream ID was invalid. + QUIC_INVALID_STREAM_ID = 17, + // A priority was invalid. + QUIC_INVALID_PRIORITY = 49, + // Too many streams already open. + QUIC_TOO_MANY_OPEN_STREAMS = 18, + // The peer created too many available streams. + QUIC_TOO_MANY_AVAILABLE_STREAMS = 76, + // Received public reset for this connection. + QUIC_PUBLIC_RESET = 19, + // Version selected by client is not acceptable to the server. + QUIC_INVALID_VERSION = 20, + // Received packet indicates version that does not match connection version. + QUIC_PACKET_WRONG_VERSION = 212, + + // The Header ID for a stream was too far from the previous. + QUIC_INVALID_HEADER_ID = 22, + // Negotiable parameter received during handshake had invalid value. + QUIC_INVALID_NEGOTIATED_VALUE = 23, + // There was an error decompressing data. + QUIC_DECOMPRESSION_FAILURE = 24, + // The connection timed out due to no network activity. + QUIC_NETWORK_IDLE_TIMEOUT = 25, + // The connection timed out waiting for the handshake to complete. + QUIC_HANDSHAKE_TIMEOUT = 67, + // There was an error encountered migrating addresses. + QUIC_ERROR_MIGRATING_ADDRESS = 26, + // There was an error encountered migrating port only. + QUIC_ERROR_MIGRATING_PORT = 86, + // There was an error while writing to the socket. + QUIC_PACKET_WRITE_ERROR = 27, + // There was an error while reading from the socket. + QUIC_PACKET_READ_ERROR = 51, + // We received a STREAM_FRAME with no data and no fin flag set. + QUIC_EMPTY_STREAM_FRAME_NO_FIN = 50, + // We received invalid data on the headers stream. + QUIC_INVALID_HEADERS_STREAM_DATA = 56, + // Invalid data on the headers stream received because of decompression + // failure. + QUIC_HEADERS_STREAM_DATA_DECOMPRESS_FAILURE = 97, + // The peer received too much data, violating flow control. + QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA = 59, + // The peer sent too much data, violating flow control. + QUIC_FLOW_CONTROL_SENT_TOO_MUCH_DATA = 63, + // The peer received an invalid flow control window. + QUIC_FLOW_CONTROL_INVALID_WINDOW = 64, + // The connection has been IP pooled into an existing connection. + QUIC_CONNECTION_IP_POOLED = 62, + // The connection has too many outstanding sent packets. + QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS = 68, + // The connection has too many outstanding received packets. + QUIC_TOO_MANY_OUTSTANDING_RECEIVED_PACKETS = 69, + // The quic connection has been cancelled. + QUIC_CONNECTION_CANCELLED = 70, + // Disabled QUIC because of high packet loss rate. + QUIC_BAD_PACKET_LOSS_RATE = 71, + // Disabled QUIC because of too many PUBLIC_RESETs post handshake. + QUIC_PUBLIC_RESETS_POST_HANDSHAKE = 73, + // Closed because we failed to serialize a packet. + QUIC_FAILED_TO_SERIALIZE_PACKET = 75, + // QUIC timed out after too many RTOs. + QUIC_TOO_MANY_RTOS = 85, + + // Crypto errors. + + // Handshake failed. + QUIC_HANDSHAKE_FAILED = 28, + // Handshake message contained out of order tags. + QUIC_CRYPTO_TAGS_OUT_OF_ORDER = 29, + // Handshake message contained too many entries. + QUIC_CRYPTO_TOO_MANY_ENTRIES = 30, + // Handshake message contained an invalid value length. + QUIC_CRYPTO_INVALID_VALUE_LENGTH = 31, + // A crypto message was received after the handshake was complete. + QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE = 32, + // A crypto message was received with an illegal message tag. + QUIC_INVALID_CRYPTO_MESSAGE_TYPE = 33, + // A crypto message was received with an illegal parameter. + QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER = 34, + // An invalid channel id signature was supplied. + QUIC_INVALID_CHANNEL_ID_SIGNATURE = 52, + // A crypto message was received with a mandatory parameter missing. + QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND = 35, + // A crypto message was received with a parameter that has no overlap + // with the local parameter. + QUIC_CRYPTO_MESSAGE_PARAMETER_NO_OVERLAP = 36, + // A crypto message was received that contained a parameter with too few + // values. + QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND = 37, + // A demand for an unsupport proof type was received. + QUIC_UNSUPPORTED_PROOF_DEMAND = 94, + // An internal error occurred in crypto processing. + QUIC_CRYPTO_INTERNAL_ERROR = 38, + // A crypto handshake message specified an unsupported version. + QUIC_CRYPTO_VERSION_NOT_SUPPORTED = 39, + // (Deprecated) A crypto handshake message resulted in a stateless reject. + // QUIC_CRYPTO_HANDSHAKE_STATELESS_REJECT = 72, + // There was no intersection between the crypto primitives supported by the + // peer and ourselves. + QUIC_CRYPTO_NO_SUPPORT = 40, + // The server rejected our client hello messages too many times. + QUIC_CRYPTO_TOO_MANY_REJECTS = 41, + // The client rejected the server's certificate chain or signature. + QUIC_PROOF_INVALID = 42, + // A crypto message was received with a duplicate tag. + QUIC_CRYPTO_DUPLICATE_TAG = 43, + // A crypto message was received with the wrong encryption level (i.e. it + // should have been encrypted but was not.) + QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT = 44, + // The server config for a server has expired. + QUIC_CRYPTO_SERVER_CONFIG_EXPIRED = 45, + // We failed to setup the symmetric keys for a connection. + QUIC_CRYPTO_SYMMETRIC_KEY_SETUP_FAILED = 53, + // A handshake message arrived, but we are still validating the + // previous handshake message. + QUIC_CRYPTO_MESSAGE_WHILE_VALIDATING_CLIENT_HELLO = 54, + // A server config update arrived before the handshake is complete. + QUIC_CRYPTO_UPDATE_BEFORE_HANDSHAKE_COMPLETE = 65, + // CHLO cannot fit in one packet. + QUIC_CRYPTO_CHLO_TOO_LARGE = 90, + // This connection involved a version negotiation which appears to have been + // tampered with. + QUIC_VERSION_NEGOTIATION_MISMATCH = 55, + + // Multipath errors. + // Multipath is not enabled, but a packet with multipath flag on is received. + QUIC_BAD_MULTIPATH_FLAG = 79, + // A path is supposed to exist but does not. + QUIC_MULTIPATH_PATH_DOES_NOT_EXIST = 91, + // A path is supposed to be active but is not. + QUIC_MULTIPATH_PATH_NOT_ACTIVE = 92, + + // IP address changed causing connection close. + QUIC_IP_ADDRESS_CHANGED = 80, + + // Connection migration errors. + // Network changed, but connection had no migratable streams. + QUIC_CONNECTION_MIGRATION_NO_MIGRATABLE_STREAMS = 81, + // Connection changed networks too many times. + QUIC_CONNECTION_MIGRATION_TOO_MANY_CHANGES = 82, + // Connection migration was attempted, but there was no new network to + // migrate to. + QUIC_CONNECTION_MIGRATION_NO_NEW_NETWORK = 83, + // Network changed, but connection had one or more non-migratable streams. + QUIC_CONNECTION_MIGRATION_NON_MIGRATABLE_STREAM = 84, + // Network changed, but connection migration was disabled by config. + QUIC_CONNECTION_MIGRATION_DISABLED_BY_CONFIG = 99, + // Network changed, but error was encountered on the alternative network. + QUIC_CONNECTION_MIGRATION_INTERNAL_ERROR = 100, + // Network changed, but handshake is not confirmed yet. + QUIC_CONNECTION_MIGRATION_HANDSHAKE_UNCONFIRMED = 111, + QUIC_PEER_PORT_CHANGE_HANDSHAKE_UNCONFIRMED = 194, + + // Stream frames arrived too discontiguously so that stream sequencer buffer + // maintains too many intervals. + QUIC_TOO_MANY_STREAM_DATA_INTERVALS = 93, + + // Sequencer buffer get into weird state where continuing read/write will lead + // to crash. + QUIC_STREAM_SEQUENCER_INVALID_STATE = 95, + + // Connection closed because of server hits max number of sessions allowed. + QUIC_TOO_MANY_SESSIONS_ON_SERVER = 96, + + // Receive a RST_STREAM with offset larger than kMaxStreamLength. + QUIC_STREAM_LENGTH_OVERFLOW = 98, + // Received a MAX DATA frame with errors. + QUIC_INVALID_MAX_DATA_FRAME_DATA = 102, + // Received a MAX STREAM DATA frame with errors. + QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA = 103, + // Received a MAX_STREAMS frame with bad data + QUIC_MAX_STREAMS_DATA = 104, + // Received a STREAMS_BLOCKED frame with bad data + QUIC_STREAMS_BLOCKED_DATA = 105, + // Error deframing a STREAM BLOCKED frame. + QUIC_INVALID_STREAM_BLOCKED_DATA = 106, + // NEW CONNECTION ID frame data is malformed. + QUIC_INVALID_NEW_CONNECTION_ID_DATA = 107, + // More connection IDs than allowed are issued. + QUIC_CONNECTION_ID_LIMIT_ERROR = 203, + // The peer retires connection IDs too quickly. + QUIC_TOO_MANY_CONNECTION_ID_WAITING_TO_RETIRE = 204, + // Received a MAX STREAM DATA frame with errors. + QUIC_INVALID_STOP_SENDING_FRAME_DATA = 108, + // Error deframing PATH CHALLENGE or PATH RESPONSE frames. + QUIC_INVALID_PATH_CHALLENGE_DATA = 109, + QUIC_INVALID_PATH_RESPONSE_DATA = 110, + // This is used to indicate an IETF QUIC PROTOCOL VIOLATION + // transport error within Google (pre-v99) QUIC. + IETF_QUIC_PROTOCOL_VIOLATION = 113, + QUIC_INVALID_NEW_TOKEN = 114, + + // Received stream data on a WRITE_UNIDIRECTIONAL stream. + QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM = 115, + // Try to send stream data on a READ_UNIDIRECTIONAL stream. + QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM = 116, + + // RETIRE CONNECTION ID frame data is malformed. + QUIC_INVALID_RETIRE_CONNECTION_ID_DATA = 117, + + // Error in a received STREAMS BLOCKED frame. + QUIC_STREAMS_BLOCKED_ERROR = 118, + // Error in a received MAX STREAMS frame + QUIC_MAX_STREAMS_ERROR = 119, + // Error in Http decoder + QUIC_HTTP_DECODER_ERROR = 120, + // Connection from stale host needs to be cancelled. + QUIC_STALE_CONNECTION_CANCELLED = 121, + + // A pseudo error, used as an extended error reason code in the error_details + // of IETF-QUIC CONNECTION_CLOSE frames. It is used in + // OnConnectionClosed upcalls to indicate that extended error information was + // not available in a received CONNECTION_CLOSE frame. + QUIC_IETF_GQUIC_ERROR_MISSING = 122, + + // Received WindowUpdate on a READ_UNIDIRECTIONAL stream. + QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM = 123, + + // There are too many buffered control frames in control frame manager. + QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES = 124, + + // QuicTransport received invalid client indication. + QUIC_TRANSPORT_INVALID_CLIENT_INDICATION = 125, + + // Internal error codes for QPACK errors. + QUIC_QPACK_DECOMPRESSION_FAILED = 126, + + // Obsolete generic QPACK encoder and decoder stream error codes. + QUIC_QPACK_ENCODER_STREAM_ERROR = 127, + QUIC_QPACK_DECODER_STREAM_ERROR = 128, + + // QPACK encoder stream errors. + + // Variable integer exceeding 2^64-1 received. + QUIC_QPACK_ENCODER_STREAM_INTEGER_TOO_LARGE = 174, + // String literal exceeding kStringLiteralLengthLimit in length received. + QUIC_QPACK_ENCODER_STREAM_STRING_LITERAL_TOO_LONG = 175, + // String literal with invalid Huffman encoding received. + QUIC_QPACK_ENCODER_STREAM_HUFFMAN_ENCODING_ERROR = 176, + // Invalid static table index in Insert With Name Reference instruction. + QUIC_QPACK_ENCODER_STREAM_INVALID_STATIC_ENTRY = 177, + // Error inserting entry with static name reference in Insert With Name + // Reference instruction due to entry size exceeding dynamic table capacity. + QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_STATIC = 178, + // Invalid relative index in Insert With Name Reference instruction. + QUIC_QPACK_ENCODER_STREAM_INSERTION_INVALID_RELATIVE_INDEX = 179, + // Dynamic entry not found in Insert With Name Reference instruction. + QUIC_QPACK_ENCODER_STREAM_INSERTION_DYNAMIC_ENTRY_NOT_FOUND = 180, + // Error inserting entry with dynamic name reference in Insert With Name + // Reference instruction due to entry size exceeding dynamic table capacity. + QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_DYNAMIC = 181, + // Error inserting entry in Insert With Literal Name instruction due to entry + // size exceeding dynamic table capacity. + QUIC_QPACK_ENCODER_STREAM_ERROR_INSERTING_LITERAL = 182, + // Invalid relative index in Duplicate instruction. + QUIC_QPACK_ENCODER_STREAM_DUPLICATE_INVALID_RELATIVE_INDEX = 183, + // Dynamic entry not found in Duplicate instruction. + QUIC_QPACK_ENCODER_STREAM_DUPLICATE_DYNAMIC_ENTRY_NOT_FOUND = 184, + // Error in Set Dynamic Table Capacity instruction due to new capacity + // exceeding maximum dynamic table capacity. + QUIC_QPACK_ENCODER_STREAM_SET_DYNAMIC_TABLE_CAPACITY = 185, + + // QPACK decoder stream errors. + + // Variable integer exceeding 2^64-1 received. + QUIC_QPACK_DECODER_STREAM_INTEGER_TOO_LARGE = 186, + // Insert Count Increment instruction received with invalid 0 increment. + QUIC_QPACK_DECODER_STREAM_INVALID_ZERO_INCREMENT = 187, + // Insert Count Increment instruction causes uint64_t overflow. + QUIC_QPACK_DECODER_STREAM_INCREMENT_OVERFLOW = 188, + // Insert Count Increment instruction increases Known Received Count beyond + // inserted entry cound. + QUIC_QPACK_DECODER_STREAM_IMPOSSIBLE_INSERT_COUNT = 189, + // Header Acknowledgement received for stream that has no outstanding header + // blocks. + QUIC_QPACK_DECODER_STREAM_INCORRECT_ACKNOWLEDGEMENT = 190, + + // Received stream data beyond close offset. + QUIC_STREAM_DATA_BEYOND_CLOSE_OFFSET = 129, + + // Received multiple close offset. + QUIC_STREAM_MULTIPLE_OFFSET = 130, + + // HTTP/3 errors. + + // Frame payload larger than what HttpDecoder is willing to buffer. + QUIC_HTTP_FRAME_TOO_LARGE = 131, + // Malformed HTTP/3 frame, or PUSH_PROMISE or CANCEL_PUSH received (which is + // an error because MAX_PUSH_ID is never sent). + QUIC_HTTP_FRAME_ERROR = 132, + // A frame that is never allowed on a request stream is received. + QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM = 133, + // A frame that is never allowed on the control stream is received. + QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM = 134, + // An invalid sequence of frames normally allowed on a request stream is + // received. + QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_SPDY_STREAM = 151, + // A second SETTINGS frame is received on the control stream. + QUIC_HTTP_INVALID_FRAME_SEQUENCE_ON_CONTROL_STREAM = 152, + // A second instance of a unidirectional stream of a certain type is created. + QUIC_HTTP_DUPLICATE_UNIDIRECTIONAL_STREAM = 153, + // Client receives a server-initiated bidirectional stream. + QUIC_HTTP_SERVER_INITIATED_BIDIRECTIONAL_STREAM = 154, + // Server opens stream with stream ID corresponding to client-initiated + // stream or vice versa. + QUIC_HTTP_STREAM_WRONG_DIRECTION = 155, + // Peer closes one of the six critical unidirectional streams (control, QPACK + // encoder or decoder, in either direction). + QUIC_HTTP_CLOSED_CRITICAL_STREAM = 156, + // The first frame received on the control stream is not a SETTINGS frame. + QUIC_HTTP_MISSING_SETTINGS_FRAME = 157, + // The received SETTINGS frame contains duplicate setting identifiers. + QUIC_HTTP_DUPLICATE_SETTING_IDENTIFIER = 158, + // MAX_PUSH_ID frame received with push ID value smaller than a previously + // received value. + QUIC_HTTP_INVALID_MAX_PUSH_ID = 159, + // Received unidirectional stream limit is lower than required by HTTP/3. + QUIC_HTTP_STREAM_LIMIT_TOO_LOW = 160, + // Received mismatched SETTINGS frame from HTTP/3 connection where early data + // is accepted. Server violated the HTTP/3 spec. + QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH = 164, + // Received mismatched SETTINGS frame from HTTP/3 connection where early data + // is rejected. Our implementation currently doesn't support it. + QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH = 165, + // Client received GOAWAY frame with stream ID that is not for a + // client-initiated bidirectional stream. + QUIC_HTTP_GOAWAY_INVALID_STREAM_ID = 166, + // Received GOAWAY frame with ID that is greater than previously received ID. + QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS = 167, + // HTTP/3 session received SETTINGS frame which contains HTTP/2 specific + // settings. + QUIC_HTTP_RECEIVE_SPDY_SETTING = 169, + // HTTP/3 session received an HTTP/2 only frame. + QUIC_HTTP_RECEIVE_SPDY_FRAME = 171, + // HTTP/3 session received SERVER_PUSH stream, which is an error because + // PUSH_PROMISE is not accepted. + QUIC_HTTP_RECEIVE_SERVER_PUSH = 205, + // HTTP/3 session received invalid SETTING value. + QUIC_HTTP_INVALID_SETTING_VALUE = 207, + + // HPACK header block decoding errors. + // Index varint beyond implementation limit. + QUIC_HPACK_INDEX_VARINT_ERROR = 135, + // Name length varint beyond implementation limit. + QUIC_HPACK_NAME_LENGTH_VARINT_ERROR = 136, + // Value length varint beyond implementation limit. + QUIC_HPACK_VALUE_LENGTH_VARINT_ERROR = 137, + // Name length exceeds buffer limit. + QUIC_HPACK_NAME_TOO_LONG = 138, + // Value length exceeds buffer limit. + QUIC_HPACK_VALUE_TOO_LONG = 139, + // Name Huffman encoding error. + QUIC_HPACK_NAME_HUFFMAN_ERROR = 140, + // Value Huffman encoding error. + QUIC_HPACK_VALUE_HUFFMAN_ERROR = 141, + // Next instruction should have been a dynamic table size update. + QUIC_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE = 142, + // Invalid index in indexed header field representation. + QUIC_HPACK_INVALID_INDEX = 143, + // Invalid index in literal header field with indexed name representation. + QUIC_HPACK_INVALID_NAME_INDEX = 144, + // Dynamic table size update not allowed. + QUIC_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED = 145, + // Initial dynamic table size update is above low water mark. + QUIC_HPACK_INITIAL_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK = 146, + // Dynamic table size update is above acknowledged setting. + QUIC_HPACK_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING = 147, + // HPACK block ends in the middle of an instruction. + QUIC_HPACK_TRUNCATED_BLOCK = 148, + // Incoming data fragment exceeds buffer limit. + QUIC_HPACK_FRAGMENT_TOO_LONG = 149, + // Total compressed HPACK data size exceeds limit. + QUIC_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT = 150, + + // Stream/flow control limit from 1-RTT handshake is too low to retransmit + // 0-RTT data. This is our implentation error. We could in theory keep the + // connection alive but chose not to for simplicity. + QUIC_ZERO_RTT_UNRETRANSMITTABLE = 161, + // Stream/flow control limit from 0-RTT rejection reduces cached limit. + // This is our implentation error. We could in theory keep the connection + // alive but chose not to for simplicity. + QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED = 162, + // Stream/flow control limit from 0-RTT resumption reduces cached limit. + // This is the peer violating QUIC spec. + QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED = 163, + + // The connection silently timed out due to no network activity. + QUIC_SILENT_IDLE_TIMEOUT = 168, + + // Try to write data without the right write keys. + QUIC_MISSING_WRITE_KEYS = 170, + + // An endpoint detected errors in performing key updates. + QUIC_KEY_UPDATE_ERROR = 172, + + // An endpoint has reached the confidentiality or integrity limit for the + // AEAD algorithm used by the given connection. + QUIC_AEAD_LIMIT_REACHED = 173, + + // Connection reached maximum age (regardless of activity), no new requests + // are accepted. This error code is sent in transport layer GOAWAY frame when + // using gQUIC, and only used internally when using HTTP/3. Active requests + // are still served, after which connection will be closed due to idle + // timeout. + QUIC_MAX_AGE_TIMEOUT = 191, + + // Decrypted a 0-RTT packet with a higher packet number than a 1-RTT packet. + QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER = 192, + + // Received PRIORITY_UPDATE frame with invalid payload. + QUIC_INVALID_PRIORITY_UPDATE = 193, + + // Maps to specific errors from the CRYPTO_ERROR range from + // https://quicwg.org/base-drafts/draft-ietf-quic-transport.html#name-transport-error-codes + // This attempts to choose a subset of the most interesting errors rather + // than mapping every possible CRYPTO_ERROR code. + QUIC_TLS_BAD_CERTIFICATE = 195, + QUIC_TLS_UNSUPPORTED_CERTIFICATE = 196, + QUIC_TLS_CERTIFICATE_REVOKED = 197, + QUIC_TLS_CERTIFICATE_EXPIRED = 198, + QUIC_TLS_CERTIFICATE_UNKNOWN = 199, + QUIC_TLS_INTERNAL_ERROR = 200, + QUIC_TLS_UNRECOGNIZED_NAME = 201, + QUIC_TLS_CERTIFICATE_REQUIRED = 202, + + // An HTTP field value containing an invalid character has been received. + QUIC_INVALID_CHARACTER_IN_FIELD_VALUE = 206, + + // Error code related to the usage of TLS keying material export. + QUIC_TLS_UNEXPECTED_KEYING_MATERIAL_EXPORT_LABEL = 208, + QUIC_TLS_KEYING_MATERIAL_EXPORTS_MISMATCH = 209, + QUIC_TLS_KEYING_MATERIAL_EXPORT_NOT_AVAILABLE = 210, + QUIC_UNEXPECTED_DATA_BEFORE_ENCRYPTION_ESTABLISHED = 211, + + // Error code related to backend health-check. + QUIC_SERVER_UNHEALTHY = 213, + + // No error. Used as bound while iterating. + QUIC_LAST_ERROR = 214, +}; +// QuicErrorCodes is encoded as four octets on-the-wire when doing Google QUIC, +// or a varint62 when doing IETF QUIC. Ensure that its value does not exceed +// the smaller of the two limits. +static_assert(static_cast(QUIC_LAST_ERROR) <= + static_cast(std::numeric_limits::max()), + "QuicErrorCode exceeds four octets"); + +// Wire values for HTTP/3 errors. +// https://www.rfc-editor.org/rfc/rfc9114.html#http-error-codes +enum class QuicHttp3ErrorCode { + // NO_ERROR is defined as a C preprocessor macro on Windows. + HTTP3_NO_ERROR = 0x100, + GENERAL_PROTOCOL_ERROR = 0x101, + INTERNAL_ERROR = 0x102, + STREAM_CREATION_ERROR = 0x103, + CLOSED_CRITICAL_STREAM = 0x104, + FRAME_UNEXPECTED = 0x105, + FRAME_ERROR = 0x106, + EXCESSIVE_LOAD = 0x107, + ID_ERROR = 0x108, + SETTINGS_ERROR = 0x109, + MISSING_SETTINGS = 0x10A, + REQUEST_REJECTED = 0x10B, + REQUEST_CANCELLED = 0x10C, + REQUEST_INCOMPLETE = 0x10D, + MESSAGE_ERROR = 0x10E, + CONNECT_ERROR = 0x10F, + VERSION_FALLBACK = 0x110, +}; + +// Wire values for QPACK errors. +// https://www.rfc-editor.org/rfc/rfc9204.html#error-code-registration +enum class QuicHttpQpackErrorCode { + DECOMPRESSION_FAILED = 0x200, + ENCODER_STREAM_ERROR = 0x201, + DECODER_STREAM_ERROR = 0x202 +}; + +// Represents a reason for resetting a stream in both gQUIC and IETF error code +// space. Both error codes have to be present. +class QUIC_EXPORT_PRIVATE QuicResetStreamError { + public: + // Constructs a QuicResetStreamError from QuicRstStreamErrorCode; the IETF + // error code is inferred. + static QuicResetStreamError FromInternal(QuicRstStreamErrorCode code); + // Constructs a QuicResetStreamError from an IETF error code; the internal + // error code is inferred. + static QuicResetStreamError FromIetf(uint64_t code); + static QuicResetStreamError FromIetf(QuicHttp3ErrorCode code); + static QuicResetStreamError FromIetf(QuicHttpQpackErrorCode code); + // Constructs a QuicResetStreamError with no error. + static QuicResetStreamError NoError() { + return FromInternal(QUIC_STREAM_NO_ERROR); + } + + QuicResetStreamError(QuicRstStreamErrorCode internal_code, + uint64_t ietf_application_code) + : internal_code_(internal_code), + ietf_application_code_(ietf_application_code) {} + + QuicRstStreamErrorCode internal_code() const { return internal_code_; } + uint64_t ietf_application_code() const { return ietf_application_code_; } + + bool operator==(const QuicResetStreamError& other) const { + return internal_code() == other.internal_code() && + ietf_application_code() == other.ietf_application_code(); + } + + // Returns true if the object holds no error. + bool ok() const { return internal_code() == QUIC_STREAM_NO_ERROR; } + + private: + // Error code used in gQUIC. Even when IETF QUIC is in use, this needs to be + // populated as we use those internally. + QuicRstStreamErrorCode internal_code_; + // Application error code used in IETF QUIC. + uint64_t ietf_application_code_; +}; + +// Convert TLS alert code to QuicErrorCode. +QUIC_EXPORT_PRIVATE QuicErrorCode TlsAlertToQuicErrorCode(uint8_t desc); + +// Returns the name of the QuicRstStreamErrorCode as a char* +QUIC_EXPORT_PRIVATE const char* QuicRstStreamErrorCodeToString( + QuicRstStreamErrorCode error); + +// Returns the name of the QuicErrorCode as a char* +QUIC_EXPORT_PRIVATE const char* QuicErrorCodeToString(QuicErrorCode error); + +// Wire values for QUIC transport errors. +// https://quicwg.org/base-drafts/draft-ietf-quic-transport.html#name-transport-error-codes +enum QuicIetfTransportErrorCodes : uint64_t { + NO_IETF_QUIC_ERROR = 0x0, + INTERNAL_ERROR = 0x1, + SERVER_BUSY_ERROR = 0x2, + FLOW_CONTROL_ERROR = 0x3, + STREAM_LIMIT_ERROR = 0x4, + STREAM_STATE_ERROR = 0x5, + FINAL_SIZE_ERROR = 0x6, + FRAME_ENCODING_ERROR = 0x7, + TRANSPORT_PARAMETER_ERROR = 0x8, + CONNECTION_ID_LIMIT_ERROR = 0x9, + PROTOCOL_VIOLATION = 0xA, + INVALID_TOKEN = 0xB, + CRYPTO_BUFFER_EXCEEDED = 0xD, + KEY_UPDATE_ERROR = 0xE, + AEAD_LIMIT_REACHED = 0xF, + CRYPTO_ERROR_FIRST = 0x100, + CRYPTO_ERROR_LAST = 0x1FF, +}; + +QUIC_EXPORT_PRIVATE std::string QuicIetfTransportErrorCodeString( + QuicIetfTransportErrorCodes c); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicIetfTransportErrorCodes& c); + +// A transport error code (if is_transport_close is true) or application error +// code (if is_transport_close is false) to be used in CONNECTION_CLOSE frames. +struct QUIC_EXPORT_PRIVATE QuicErrorCodeToIetfMapping { + bool is_transport_close; + uint64_t error_code; +}; + +// Convert QuicErrorCode to transport or application IETF error code +// to be used in CONNECTION_CLOSE frames. +QUIC_EXPORT_PRIVATE QuicErrorCodeToIetfMapping +QuicErrorCodeToTransportErrorCode(QuicErrorCode error); + +// Convert a QuicRstStreamErrorCode to an application error code to be used in +// an IETF QUIC RESET_STREAM frame +QUIC_EXPORT_PRIVATE uint64_t RstStreamErrorCodeToIetfResetStreamErrorCode( + QuicRstStreamErrorCode rst_stream_error_code); + +// Convert the application error code of an IETF QUIC RESET_STREAM frame +// to QuicRstStreamErrorCode. +QUIC_EXPORT_PRIVATE QuicRstStreamErrorCode +IetfResetStreamErrorCodeToRstStreamErrorCode(uint64_t ietf_error_code); + +QUIC_EXPORT_PRIVATE inline std::string HistogramEnumString( + QuicErrorCode enum_value) { + return QuicErrorCodeToString(enum_value); +} + +QUIC_EXPORT_PRIVATE inline std::string HistogramEnumDescription( + QuicErrorCode /*dummy*/) { + return "cause"; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_ERROR_CODES_H_ diff --git a/quiche/quic/core/quic_error_codes_test.cc b/quiche/quic/core/quic_error_codes_test.cc new file mode 100644 index 000000000000..42b254193a52 --- /dev/null +++ b/quiche/quic/core/quic_error_codes_test.cc @@ -0,0 +1,143 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_error_codes.h" + +#include + +#include "openssl/ssl.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +using QuicErrorCodesTest = QuicTest; + +TEST_F(QuicErrorCodesTest, QuicErrorCodeToString) { + EXPECT_STREQ("QUIC_NO_ERROR", QuicErrorCodeToString(QUIC_NO_ERROR)); +} + +TEST_F(QuicErrorCodesTest, QuicIetfTransportErrorCodeString) { + EXPECT_EQ("CRYPTO_ERROR(missing extension)", + QuicIetfTransportErrorCodeString( + static_cast( + CRYPTO_ERROR_FIRST + SSL_AD_MISSING_EXTENSION))); + + EXPECT_EQ("NO_IETF_QUIC_ERROR", + QuicIetfTransportErrorCodeString(NO_IETF_QUIC_ERROR)); + EXPECT_EQ("INTERNAL_ERROR", QuicIetfTransportErrorCodeString(INTERNAL_ERROR)); + EXPECT_EQ("SERVER_BUSY_ERROR", + QuicIetfTransportErrorCodeString(SERVER_BUSY_ERROR)); + EXPECT_EQ("FLOW_CONTROL_ERROR", + QuicIetfTransportErrorCodeString(FLOW_CONTROL_ERROR)); + EXPECT_EQ("STREAM_LIMIT_ERROR", + QuicIetfTransportErrorCodeString(STREAM_LIMIT_ERROR)); + EXPECT_EQ("STREAM_STATE_ERROR", + QuicIetfTransportErrorCodeString(STREAM_STATE_ERROR)); + EXPECT_EQ("FINAL_SIZE_ERROR", + QuicIetfTransportErrorCodeString(FINAL_SIZE_ERROR)); + EXPECT_EQ("FRAME_ENCODING_ERROR", + QuicIetfTransportErrorCodeString(FRAME_ENCODING_ERROR)); + EXPECT_EQ("TRANSPORT_PARAMETER_ERROR", + QuicIetfTransportErrorCodeString(TRANSPORT_PARAMETER_ERROR)); + EXPECT_EQ("CONNECTION_ID_LIMIT_ERROR", + QuicIetfTransportErrorCodeString(CONNECTION_ID_LIMIT_ERROR)); + EXPECT_EQ("PROTOCOL_VIOLATION", + QuicIetfTransportErrorCodeString(PROTOCOL_VIOLATION)); + EXPECT_EQ("INVALID_TOKEN", QuicIetfTransportErrorCodeString(INVALID_TOKEN)); + EXPECT_EQ("CRYPTO_BUFFER_EXCEEDED", + QuicIetfTransportErrorCodeString(CRYPTO_BUFFER_EXCEEDED)); + EXPECT_EQ("KEY_UPDATE_ERROR", + QuicIetfTransportErrorCodeString(KEY_UPDATE_ERROR)); + EXPECT_EQ("AEAD_LIMIT_REACHED", + QuicIetfTransportErrorCodeString(AEAD_LIMIT_REACHED)); + + EXPECT_EQ("Unknown(1024)", + QuicIetfTransportErrorCodeString( + static_cast(0x400))); +} + +TEST_F(QuicErrorCodesTest, QuicErrorCodeToTransportErrorCode) { + for (int internal_error_code = 0; internal_error_code < QUIC_LAST_ERROR; + ++internal_error_code) { + std::string internal_error_code_string = + QuicErrorCodeToString(static_cast(internal_error_code)); + if (internal_error_code_string == "INVALID_ERROR_CODE") { + // Not a valid QuicErrorCode. + continue; + } + QuicErrorCodeToIetfMapping ietf_error_code = + QuicErrorCodeToTransportErrorCode( + static_cast(internal_error_code)); + if (ietf_error_code.is_transport_close) { + QuicIetfTransportErrorCodes transport_error_code = + static_cast(ietf_error_code.error_code); + bool is_transport_crypto_error_code = + transport_error_code >= 0x100 && transport_error_code <= 0x1ff; + if (is_transport_crypto_error_code) { + // Ensure that every QuicErrorCode that maps to a CRYPTO_ERROR code has + // a corresponding reverse mapping in TlsAlertToQuicErrorCode: + EXPECT_EQ( + internal_error_code, + TlsAlertToQuicErrorCode(transport_error_code - CRYPTO_ERROR_FIRST)); + } + bool is_valid_transport_error_code = + transport_error_code <= 0x0f || is_transport_crypto_error_code; + EXPECT_TRUE(is_valid_transport_error_code) << internal_error_code_string; + } else { + // Non-transport errors are application errors, either HTTP/3 or QPACK. + uint64_t application_error_code = ietf_error_code.error_code; + bool is_valid_http3_error_code = + application_error_code >= 0x100 && application_error_code <= 0x110; + bool is_valid_qpack_error_code = + application_error_code >= 0x200 && application_error_code <= 0x202; + EXPECT_TRUE(is_valid_http3_error_code || is_valid_qpack_error_code) + << internal_error_code_string; + } + } +} + +using QuicRstErrorCodesTest = QuicTest; + +TEST_F(QuicRstErrorCodesTest, QuicRstStreamErrorCodeToString) { + EXPECT_STREQ("QUIC_BAD_APPLICATION_PAYLOAD", + QuicRstStreamErrorCodeToString(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +// When an IETF application protocol error code (sent on the wire in +// RESET_STREAM and STOP_SENDING frames) is translated into a +// QuicRstStreamErrorCode and back, it must yield the original value. +TEST_F(QuicRstErrorCodesTest, + IetfResetStreamErrorCodeToRstStreamErrorCodeAndBack) { + for (uint64_t wire_code : + {static_cast(QuicHttp3ErrorCode::HTTP3_NO_ERROR), + static_cast(QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR), + static_cast(QuicHttp3ErrorCode::INTERNAL_ERROR), + static_cast(QuicHttp3ErrorCode::STREAM_CREATION_ERROR), + static_cast(QuicHttp3ErrorCode::CLOSED_CRITICAL_STREAM), + static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED), + static_cast(QuicHttp3ErrorCode::FRAME_ERROR), + static_cast(QuicHttp3ErrorCode::EXCESSIVE_LOAD), + static_cast(QuicHttp3ErrorCode::ID_ERROR), + static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR), + static_cast(QuicHttp3ErrorCode::MISSING_SETTINGS), + static_cast(QuicHttp3ErrorCode::REQUEST_REJECTED), + static_cast(QuicHttp3ErrorCode::REQUEST_CANCELLED), + static_cast(QuicHttp3ErrorCode::REQUEST_INCOMPLETE), + static_cast(QuicHttp3ErrorCode::CONNECT_ERROR), + static_cast(QuicHttp3ErrorCode::VERSION_FALLBACK), + static_cast(QuicHttpQpackErrorCode::DECOMPRESSION_FAILED), + static_cast(QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR), + static_cast(QuicHttpQpackErrorCode::DECODER_STREAM_ERROR)}) { + QuicRstStreamErrorCode rst_stream_error_code = + IetfResetStreamErrorCodeToRstStreamErrorCode(wire_code); + EXPECT_EQ(wire_code, RstStreamErrorCodeToIetfResetStreamErrorCode( + rst_stream_error_code)); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_flags_list.h b/quiche/quic/core/quic_flags_list.h new file mode 100644 index 000000000000..cebbbff55fe0 --- /dev/null +++ b/quiche/quic/core/quic_flags_list.h @@ -0,0 +1,106 @@ +// Copyright (c) 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file is autogenerated by the QUICHE Copybara export script. + +#ifdef QUIC_FLAG + +QUIC_FLAG(quic_restart_flag_quic_offload_pacing_to_usps2, false) +// A testonly reloadable flag that will always default to false. +QUIC_FLAG(quic_reloadable_flag_quic_testonly_default_false, false) +// A testonly reloadable flag that will always default to true. +QUIC_FLAG(quic_reloadable_flag_quic_testonly_default_true, true) +// A testonly restart flag that will always default to false. +QUIC_FLAG(quic_restart_flag_quic_testonly_default_false, false) +// A testonly restart flag that will always default to true. +QUIC_FLAG(quic_restart_flag_quic_testonly_default_true, true) +// If trrue, early return before write control frame in OnCanWrite() if the connection is already closed. +QUIC_FLAG(quic_reloadable_flag_quic_no_write_control_frame_upon_connection_close, true) +// If true, QUIC will default enable MTU discovery at server, with a target of 1450 bytes. +QUIC_FLAG(quic_reloadable_flag_quic_enable_mtu_discovery_at_server, false) +// If true, QuicGsoBatchWriter will support release time if it is available and the process has the permission to do so. +QUIC_FLAG(quic_restart_flag_quic_support_release_time_for_gso, false) +// If true, ack frequency frame can be sent from server to client. +QUIC_FLAG(quic_reloadable_flag_quic_can_send_ack_frequency, true) +// If true, allow client to enable BBRv2 on server via connection option \'B2ON\'. +QUIC_FLAG(quic_reloadable_flag_quic_allow_client_enabled_bbr_v2, true) +// If true, an endpoint does not detect path degrading or blackholing until handshake gets confirmed. +QUIC_FLAG(quic_reloadable_flag_quic_no_path_degrading_before_handshake_confirmed, true) +// If true, default-enable 5RTO blachole detection. +QUIC_FLAG(quic_reloadable_flag_quic_default_enable_5rto_blackhole_detection2, true) +// If true, disable QUIC version Q043. +QUIC_FLAG(quic_reloadable_flag_quic_disable_version_q043, true) +// If true, disable QUIC version Q046. +QUIC_FLAG(quic_reloadable_flag_quic_disable_version_q046, true) +// If true, disable QUIC version Q050. +QUIC_FLAG(quic_reloadable_flag_quic_disable_version_q050, true) +// If true, disable QUIC version h3 (RFCv1). +QUIC_FLAG(quic_reloadable_flag_quic_disable_version_rfcv1, false) +// If true, disable QUIC version h3-29. +QUIC_FLAG(quic_reloadable_flag_quic_disable_version_draft_29, false) +// If true, disable blackhole detection on server side. +QUIC_FLAG(quic_reloadable_flag_quic_disable_server_blackhole_detection, false) +// If true, disable resumption when receiving NRES connection option. +QUIC_FLAG(quic_reloadable_flag_quic_enable_disable_resumption, true) +// If true, discard INITIAL packet if the key has been dropped. +QUIC_FLAG(quic_reloadable_flag_quic_discard_initial_packet_with_key_dropped, true) +// If true, do not close QUIC connection in SSL_QUIC_METHOD.send_alert, instead close it after SSL_do_handshake failed. +QUIC_FLAG(quic_reloadable_flag_quic_dont_close_connection_in_tls_alert_callback, true) +// If true, do not issue a new connection ID that has been claimed by another connection. +QUIC_FLAG(quic_reloadable_flag_quic_check_cid_collision_when_issue_new_cid, true) +// If true, enable server retransmittable on wire PING. +QUIC_FLAG(quic_reloadable_flag_quic_enable_server_on_wire_ping, true) +// If true, flush pending frames as well as pending padding bytes on connection migration. +QUIC_FLAG(quic_reloadable_flag_quic_flush_pending_frames_and_padding_bytes_on_migration, true) +// If true, ietf connection migration is no longer conditioned on connection option RVCM. +QUIC_FLAG(quic_reloadable_flag_quic_remove_connection_migration_connection_option_v2, true) +// If true, include stream information in idle timeout connection close detail. +QUIC_FLAG(quic_reloadable_flag_quic_add_stream_info_to_idle_close_detail, true) +// If true, quic server will send ENABLE_CONNECT_PROTOCOL setting and and endpoint will validate required request/response headers and extended CONNECT mechanism and update code counts of valid/invalid headers. +QUIC_FLAG(quic_reloadable_flag_quic_verify_request_headers_2, true) +// If true, reject or send error response code upon receiving invalid request or response headers. This flag depends on --gfe2_reloadable_flag_quic_verify_request_headers_2. +QUIC_FLAG(quic_reloadable_flag_quic_act_upon_invalid_header, false) +// If true, require handshake confirmation for QUIC connections, functionally disabling 0-rtt handshakes. +QUIC_FLAG(quic_reloadable_flag_quic_require_handshake_confirmation, false) +// If true, respect the incremental parameter of each stream in QuicWriteBlockedList. +QUIC_FLAG(quic_reloadable_flag_quic_priority_respect_incremental, false) +// If true, round-robin stream writes instead of batching in QuicWriteBlockedList. +QUIC_FLAG(quic_reloadable_flag_quic_disable_batch_write, false) +// If true, server proactively retires client issued connection ID on reverse path validation failure. +QUIC_FLAG(quic_reloadable_flag_quic_retire_cid_on_reverse_path_validation_failure, true) +// If true, server sends bandwidth eastimate when network is idle for a while. +QUIC_FLAG(quic_restart_flag_quic_enable_sending_bandwidth_estimate_when_network_idle_v2, true) +// If true, set burst token to 2 in cwnd bootstrapping experiment. +QUIC_FLAG(quic_reloadable_flag_quic_conservative_bursts, false) +// If true, use BBRv2 as the default congestion controller. Takes precedence over --quic_default_to_bbr. +QUIC_FLAG(quic_reloadable_flag_quic_default_to_bbr_v2, false) +// If true, use a LRU cache to record client addresses of packets received on server\'s original address. +QUIC_FLAG(quic_reloadable_flag_quic_use_received_client_addresses_cache, true) +// If true, use new connection ID in connection migration. +QUIC_FLAG(quic_reloadable_flag_quic_connection_migration_use_new_cid_v2, true) +// If true, use next_connection_id_sequence_number to validate retired cid number. +QUIC_FLAG(quic_reloadable_flag_quic_check_retire_cid_with_next_cid_sequence_number, true) +// If true, uses conservative cwnd gain and pacing gain when cwnd gets bootstrapped. +QUIC_FLAG(quic_reloadable_flag_quic_conservative_cwnd_and_pacing_gains, false) +// If true, when TicketCrypter fails to encrypt a session ticket, quic::TlsServerHandshaker will send a placeholder ticket, instead of an empty one, to the client. +QUIC_FLAG(quic_reloadable_flag_quic_send_placeholder_ticket_when_encrypt_ticket_fails, true) +// When true, check what sockopt is used to set the IP TOS byte on the platform. +QUIC_FLAG(quic_restart_flag_quic_platform_tos_sockopt, false) +// When true, defaults to BBR congestion control instead of Cubic. +QUIC_FLAG(quic_reloadable_flag_quic_default_to_bbr, false) +// When true, quiche UDP sockets report Explicit Congestion Notification (ECN) [RFC3168, RFC9330] results. +QUIC_FLAG(quic_restart_flag_quic_quiche_ecn_sockets, true) +// When true, report received ECN markings to the peer. +QUIC_FLAG(quic_restart_flag_quic_receive_ecn, true) +// When true, support draft-ietf-quic-v2-08 +QUIC_FLAG(quic_reloadable_flag_quic_enable_version_2_draft_08, false) +// When true, the BB2U copt causes BBR2 to wait two rounds with out draining the queue before exiting PROBE_UP and BB2S has the same effect in STARTUP. +QUIC_FLAG(quic_reloadable_flag_quic_bbr2_probe_two_rounds, true) +// When true, the BBHI copt causes QUIC BBRv2 to use a simpler algorithm for raising inflight_hi in PROBE_UP. +QUIC_FLAG(quic_reloadable_flag_quic_bbr2_simplify_inflight_hi, true) +// When true, the BBR4 copt sets the extra_acked window to 20 RTTs and BBR5 sets it to 40 RTTs. +QUIC_FLAG(quic_reloadable_flag_quic_bbr2_extra_acked_window, true) + +#endif + diff --git a/quiche/quic/core/quic_flow_controller.cc b/quiche/quic/core/quic_flow_controller.cc new file mode 100644 index 000000000000..4acdd513f3f7 --- /dev/null +++ b/quiche/quic/core/quic_flow_controller.cc @@ -0,0 +1,314 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_flow_controller.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +#define ENDPOINT \ + (perspective_ == Perspective::IS_SERVER ? "Server: " : "Client: ") + +std::string QuicFlowController::LogLabel() { + if (is_connection_flow_controller_) { + return "connection"; + } + return absl::StrCat("stream ", id_); +} + +QuicFlowController::QuicFlowController( + QuicSession* session, QuicStreamId id, bool is_connection_flow_controller, + QuicStreamOffset send_window_offset, QuicStreamOffset receive_window_offset, + QuicByteCount receive_window_size_limit, + bool should_auto_tune_receive_window, + QuicFlowControllerInterface* session_flow_controller) + : session_(session), + connection_(session->connection()), + id_(id), + is_connection_flow_controller_(is_connection_flow_controller), + perspective_(session->perspective()), + bytes_sent_(0), + send_window_offset_(send_window_offset), + bytes_consumed_(0), + highest_received_byte_offset_(0), + receive_window_offset_(receive_window_offset), + receive_window_size_(receive_window_offset), + receive_window_size_limit_(receive_window_size_limit), + auto_tune_receive_window_(should_auto_tune_receive_window), + session_flow_controller_(session_flow_controller), + last_blocked_send_window_offset_(0), + prev_window_update_time_(QuicTime::Zero()) { + QUICHE_DCHECK_LE(receive_window_size_, receive_window_size_limit_); + QUICHE_DCHECK_EQ( + is_connection_flow_controller_, + QuicUtils::GetInvalidStreamId(session_->transport_version()) == id_); + + QUIC_DVLOG(1) << ENDPOINT << "Created flow controller for " << LogLabel() + << ", setting initial receive window offset to: " + << receive_window_offset_ + << ", max receive window to: " << receive_window_size_ + << ", max receive window limit to: " + << receive_window_size_limit_ + << ", setting send window offset to: " << send_window_offset_; +} + +void QuicFlowController::AddBytesConsumed(QuicByteCount bytes_consumed) { + bytes_consumed_ += bytes_consumed; + QUIC_DVLOG(1) << ENDPOINT << LogLabel() << " consumed " << bytes_consumed_ + << " bytes."; + + MaybeSendWindowUpdate(); +} + +bool QuicFlowController::UpdateHighestReceivedOffset( + QuicStreamOffset new_offset) { + // Only update if offset has increased. + if (new_offset <= highest_received_byte_offset_) { + return false; + } + + QUIC_DVLOG(1) << ENDPOINT << LogLabel() + << " highest byte offset increased from " + << highest_received_byte_offset_ << " to " << new_offset; + highest_received_byte_offset_ = new_offset; + return true; +} + +void QuicFlowController::AddBytesSent(QuicByteCount bytes_sent) { + if (bytes_sent_ + bytes_sent > send_window_offset_) { + QUIC_BUG(quic_bug_10836_1) + << ENDPOINT << LogLabel() << " Trying to send an extra " << bytes_sent + << " bytes, when bytes_sent = " << bytes_sent_ + << ", and send_window_offset_ = " << send_window_offset_; + bytes_sent_ = send_window_offset_; + + // This is an error on our side, close the connection as soon as possible. + connection_->CloseConnection( + QUIC_FLOW_CONTROL_SENT_TOO_MUCH_DATA, + absl::StrCat(send_window_offset_ - (bytes_sent_ + bytes_sent), + "bytes over send window offset"), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + bytes_sent_ += bytes_sent; + QUIC_DVLOG(1) << ENDPOINT << LogLabel() << " sent " << bytes_sent_ + << " bytes."; +} + +bool QuicFlowController::FlowControlViolation() { + if (highest_received_byte_offset_ > receive_window_offset_) { + QUIC_DLOG(INFO) << ENDPOINT << "Flow control violation on " << LogLabel() + << ", receive window offset: " << receive_window_offset_ + << ", highest received byte offset: " + << highest_received_byte_offset_; + return true; + } + return false; +} + +void QuicFlowController::MaybeIncreaseMaxWindowSize() { + // Core of receive window auto tuning. This method should be called before a + // WINDOW_UPDATE frame is sent. Ideally, window updates should occur close to + // once per RTT. If a window update happens much faster than RTT, it implies + // that the flow control window is imposing a bottleneck. To prevent this, + // this method will increase the receive window size (subject to a reasonable + // upper bound). For simplicity this algorithm is deliberately asymmetric, in + // that it may increase window size but never decreases. + + // Keep track of timing between successive window updates. + QuicTime now = connection_->clock()->ApproximateNow(); + QuicTime prev = prev_window_update_time_; + prev_window_update_time_ = now; + if (!prev.IsInitialized()) { + QUIC_DVLOG(1) << ENDPOINT << "first window update for " << LogLabel(); + return; + } + + if (!auto_tune_receive_window_) { + return; + } + + // Get outbound RTT. + QuicTime::Delta rtt = + connection_->sent_packet_manager().GetRttStats()->smoothed_rtt(); + if (rtt.IsZero()) { + QUIC_DVLOG(1) << ENDPOINT << "rtt zero for " << LogLabel(); + return; + } + + // Now we can compare timing of window updates with RTT. + QuicTime::Delta since_last = now - prev; + QuicTime::Delta two_rtt = 2 * rtt; + + if (since_last >= two_rtt) { + // If interval between window updates is sufficiently large, there + // is no need to increase receive_window_size_. + return; + } + QuicByteCount old_window = receive_window_size_; + IncreaseWindowSize(); + + if (receive_window_size_ > old_window) { + QUIC_DVLOG(1) << ENDPOINT << "New max window increase for " << LogLabel() + << " after " << since_last.ToMicroseconds() + << " us, and RTT is " << rtt.ToMicroseconds() + << "us. max wndw: " << receive_window_size_; + if (session_flow_controller_ != nullptr) { + session_flow_controller_->EnsureWindowAtLeast( + kSessionFlowControlMultiplier * receive_window_size_); + } + } else { + // TODO(ckrasic) - add a varz to track this (?). + QUIC_LOG_FIRST_N(INFO, 1) + << ENDPOINT << "Max window at limit for " << LogLabel() << " after " + << since_last.ToMicroseconds() << " us, and RTT is " + << rtt.ToMicroseconds() << "us. Limit size: " << receive_window_size_; + } +} + +void QuicFlowController::IncreaseWindowSize() { + receive_window_size_ *= 2; + receive_window_size_ = + std::min(receive_window_size_, receive_window_size_limit_); +} + +QuicByteCount QuicFlowController::WindowUpdateThreshold() { + return receive_window_size_ / 2; +} + +void QuicFlowController::MaybeSendWindowUpdate() { + if (!session_->connection()->connected()) { + return; + } + // Send WindowUpdate to increase receive window if + // (receive window offset - consumed bytes) < (max window / 2). + // This is behaviour copied from SPDY. + QUICHE_DCHECK_LE(bytes_consumed_, receive_window_offset_); + QuicStreamOffset available_window = receive_window_offset_ - bytes_consumed_; + QuicByteCount threshold = WindowUpdateThreshold(); + + if (!prev_window_update_time_.IsInitialized()) { + // Treat the initial window as if it is a window update, so if 1/2 the + // window is used in less than 2 RTTs, the window is increased. + prev_window_update_time_ = connection_->clock()->ApproximateNow(); + } + + if (available_window >= threshold) { + QUIC_DVLOG(1) << ENDPOINT << "Not sending WindowUpdate for " << LogLabel() + << ", available window: " << available_window + << " >= threshold: " << threshold; + return; + } + + MaybeIncreaseMaxWindowSize(); + UpdateReceiveWindowOffsetAndSendWindowUpdate(available_window); +} + +void QuicFlowController::UpdateReceiveWindowOffsetAndSendWindowUpdate( + QuicStreamOffset available_window) { + // Update our receive window. + receive_window_offset_ += (receive_window_size_ - available_window); + + QUIC_DVLOG(1) << ENDPOINT << "Sending WindowUpdate frame for " << LogLabel() + << ", consumed bytes: " << bytes_consumed_ + << ", available window: " << available_window + << ", and threshold: " << WindowUpdateThreshold() + << ", and receive window size: " << receive_window_size_ + << ". New receive window offset is: " << receive_window_offset_; + + SendWindowUpdate(); +} + +void QuicFlowController::MaybeSendBlocked() { + if (SendWindowSize() != 0 || + last_blocked_send_window_offset_ >= send_window_offset_) { + return; + } + QUIC_DLOG(INFO) << ENDPOINT << LogLabel() << " is flow control blocked. " + << "Send window: " << SendWindowSize() + << ", bytes sent: " << bytes_sent_ + << ", send limit: " << send_window_offset_; + // The entire send_window has been consumed, we are now flow control + // blocked. + + // Keep track of when we last sent a BLOCKED frame so that we only send one + // at a given send offset. + last_blocked_send_window_offset_ = send_window_offset_; + session_->SendBlocked(id_, last_blocked_send_window_offset_); +} + +bool QuicFlowController::UpdateSendWindowOffset( + QuicStreamOffset new_send_window_offset) { + // Only update if send window has increased. + if (new_send_window_offset <= send_window_offset_) { + return false; + } + + QUIC_DVLOG(1) << ENDPOINT << "UpdateSendWindowOffset for " << LogLabel() + << " with new offset " << new_send_window_offset + << " current offset: " << send_window_offset_ + << " bytes_sent: " << bytes_sent_; + + // The flow is now unblocked but could have also been unblocked + // before. Return true iff this update caused a change from blocked + // to unblocked. + const bool was_previously_blocked = IsBlocked(); + send_window_offset_ = new_send_window_offset; + return was_previously_blocked; +} + +void QuicFlowController::EnsureWindowAtLeast(QuicByteCount window_size) { + if (receive_window_size_limit_ >= window_size) { + return; + } + + QuicStreamOffset available_window = receive_window_offset_ - bytes_consumed_; + IncreaseWindowSize(); + UpdateReceiveWindowOffsetAndSendWindowUpdate(available_window); +} + +bool QuicFlowController::IsBlocked() const { return SendWindowSize() == 0; } + +uint64_t QuicFlowController::SendWindowSize() const { + if (bytes_sent_ > send_window_offset_) { + return 0; + } + return send_window_offset_ - bytes_sent_; +} + +void QuicFlowController::UpdateReceiveWindowSize(QuicStreamOffset size) { + QUICHE_DCHECK_LE(size, receive_window_size_limit_); + QUIC_DVLOG(1) << ENDPOINT << "UpdateReceiveWindowSize for " << LogLabel() + << ": " << size; + if (receive_window_size_ != receive_window_offset_) { + QUIC_BUG(quic_bug_10836_2) + << "receive_window_size_:" << receive_window_size_ + << " != receive_window_offset:" << receive_window_offset_; + return; + } + receive_window_size_ = size; + receive_window_offset_ = size; +} + +void QuicFlowController::SendWindowUpdate() { + QuicStreamId id = id_; + if (is_connection_flow_controller_) { + id = QuicUtils::GetInvalidStreamId(connection_->transport_version()); + } + session_->SendWindowUpdate(id, receive_window_offset_); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_flow_controller.h b/quiche/quic/core/quic_flow_controller.h new file mode 100644 index 000000000000..eaf660a8d2f3 --- /dev/null +++ b/quiche/quic/core/quic_flow_controller.h @@ -0,0 +1,216 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_FLOW_CONTROLLER_H_ +#define QUICHE_QUIC_CORE_QUIC_FLOW_CONTROLLER_H_ + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicFlowControllerPeer; +} // namespace test + +class QuicConnection; +class QuicSession; + +// How much larger the session flow control window needs to be relative to any +// stream's flow control window. +const float kSessionFlowControlMultiplier = 1.5; + +class QUIC_EXPORT_PRIVATE QuicFlowControllerInterface { + public: + virtual ~QuicFlowControllerInterface() {} + + // Ensures the flow control window is at least |window_size| and send out an + // update frame if it is increased. + virtual void EnsureWindowAtLeast(QuicByteCount window_size) = 0; +}; + +// QuicFlowController allows a QUIC stream or connection to perform flow +// control. The stream/connection owns a QuicFlowController which keeps track of +// bytes sent/received, can tell the owner if it is flow control blocked, and +// can send WINDOW_UPDATE or BLOCKED frames when needed. +class QUIC_EXPORT_PRIVATE QuicFlowController + : public QuicFlowControllerInterface { + public: + QuicFlowController(QuicSession* session, QuicStreamId id, + bool is_connection_flow_controller, + QuicStreamOffset send_window_offset, + QuicStreamOffset receive_window_offset, + QuicByteCount receive_window_size_limit, + bool should_auto_tune_receive_window, + QuicFlowControllerInterface* session_flow_controller); + + QuicFlowController(const QuicFlowController&) = delete; + QuicFlowController(QuicFlowController&&) = default; + QuicFlowController& operator=(const QuicFlowController&) = delete; + + ~QuicFlowController() override {} + + // Called when we see a new highest received byte offset from the peer, either + // via a data frame or a RST. + // Returns true if this call changes highest_received_byte_offset_, and false + // in the case where |new_offset| is <= highest_received_byte_offset_. + bool UpdateHighestReceivedOffset(QuicStreamOffset new_offset); + + // Called when bytes received from the peer are consumed locally. This may + // trigger the sending of a WINDOW_UPDATE frame using |connection|. + void AddBytesConsumed(QuicByteCount bytes_consumed); + + // Called when bytes are sent to the peer. + void AddBytesSent(QuicByteCount bytes_sent); + + // Increases |send_window_offset_| if |new_send_window_offset| is + // greater than the current value. Returns true if this increase + // also causes us to change from a blocked state to unblocked. In + // all other cases, returns false. + bool UpdateSendWindowOffset(QuicStreamOffset new_send_window_offset); + + // QuicFlowControllerInterface. + void EnsureWindowAtLeast(QuicByteCount window_size) override; + + // Returns the current available send window. + QuicByteCount SendWindowSize() const; + + QuicByteCount receive_window_size() const { return receive_window_size_; } + + // Sends a BLOCKED frame if needed. + void MaybeSendBlocked(); + + // Returns true if flow control send limits have been reached. + bool IsBlocked() const; + + // Returns true if flow control receive limits have been violated by the peer. + bool FlowControlViolation(); + + // Inform the peer of new receive window. + void SendWindowUpdate(); + + QuicByteCount bytes_consumed() const { return bytes_consumed_; } + + QuicByteCount bytes_sent() const { return bytes_sent_; } + + QuicStreamOffset send_window_offset() const { return send_window_offset_; } + + QuicStreamOffset highest_received_byte_offset() const { + return highest_received_byte_offset_; + } + + void set_receive_window_size_limit(QuicByteCount receive_window_size_limit) { + QUICHE_DCHECK_GE(receive_window_size_limit, receive_window_size_limit_); + receive_window_size_limit_ = receive_window_size_limit; + } + + // Should only be called before any data is received. + void UpdateReceiveWindowSize(QuicStreamOffset size); + + bool auto_tune_receive_window() { return auto_tune_receive_window_; } + + private: + friend class test::QuicFlowControllerPeer; + + // Send a WINDOW_UPDATE frame if appropriate. + void MaybeSendWindowUpdate(); + + // Auto-tune the max receive window size. + void MaybeIncreaseMaxWindowSize(); + + // Updates the current offset and sends a window update frame. + void UpdateReceiveWindowOffsetAndSendWindowUpdate( + QuicStreamOffset available_window); + + // Double the window size as long as we haven't hit the max window size. + void IncreaseWindowSize(); + + // Returns "stream $ID" (where $ID is set to |id_|) or "connection" based on + // |is_connection_flow_controller_|. + std::string LogLabel(); + + // The parent session/connection, used to send connection close on flow + // control violation, and WINDOW_UPDATE and BLOCKED frames when appropriate. + // Not owned. + QuicSession* session_; + QuicConnection* connection_; + + // ID of stream this flow controller belongs to. If + // |is_connection_flow_controller_| is false, this must be a valid stream ID. + QuicStreamId id_; + + // Whether this flow controller is the connection level flow controller + // instead of the flow controller for a stream. If true, |id_| is ignored. + bool is_connection_flow_controller_; + + // Tracks if this is owned by a server or a client. + Perspective perspective_; + + // Tracks number of bytes sent to the peer. + QuicByteCount bytes_sent_; + + // The absolute offset in the outgoing byte stream. If this offset is reached + // then we become flow control blocked until we receive a WINDOW_UPDATE. + QuicStreamOffset send_window_offset_; + + // Overview of receive flow controller. + // + // 0=...===1=======2-------3 ...... FIN + // |<--- <= 4 --->| + // + + // 1) bytes_consumed_ - moves forward when data is read out of the + // stream. + // + // 2) highest_received_byte_offset_ - moves when data is received + // from the peer. + // + // 3) receive_window_offset_ - moves when WINDOW_UPDATE is sent. + // + // 4) receive_window_size_ - maximum allowed unread data (3 - 1). + // This value may be increased by auto-tuning. + // + // 5) receive_window_size_limit_ - limit on receive_window_size_; + // auto-tuning will not increase window size beyond this limit. + + // Track number of bytes received from the peer, which have been consumed + // locally. + QuicByteCount bytes_consumed_; + + // The highest byte offset we have seen from the peer. This could be the + // highest offset in a data frame, or a final value in a RST. + QuicStreamOffset highest_received_byte_offset_; + + // The absolute offset in the incoming byte stream. The peer should never send + // us bytes which are beyond this offset. + QuicStreamOffset receive_window_offset_; + + // Largest size the receive window can grow to. + QuicByteCount receive_window_size_; + + // Upper limit on receive_window_size_; + QuicByteCount receive_window_size_limit_; + + // Used to dynamically enable receive window auto-tuning. + bool auto_tune_receive_window_; + + // The session's flow controller. Null if this is the session flow controller. + // Not owned. + QuicFlowControllerInterface* session_flow_controller_; + + // Send window update when receive window size drops below this. + QuicByteCount WindowUpdateThreshold(); + + // Keep track of the last time we sent a BLOCKED frame. We should only send + // another when the number of bytes we have sent has changed. + QuicStreamOffset last_blocked_send_window_offset_; + + // Keep time of the last time a window update was sent. We use this + // as part of the receive window auto tuning. + QuicTime prev_window_update_time_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_FLOW_CONTROLLER_H_ diff --git a/quiche/quic/core/quic_flow_controller_test.cc b/quiche/quic/core/quic_flow_controller_test.cc new file mode 100644 index 000000000000..567b120e682a --- /dev/null +++ b/quiche/quic/core/quic_flow_controller_test.cc @@ -0,0 +1,416 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_flow_controller.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::Invoke; +using testing::StrictMock; + +namespace quic { +namespace test { + +// Receive window auto-tuning uses RTT in its logic. +const int64_t kRtt = 100; + +class MockFlowController : public QuicFlowControllerInterface { + public: + MockFlowController() {} + MockFlowController(const MockFlowController&) = delete; + MockFlowController& operator=(const MockFlowController&) = delete; + ~MockFlowController() override {} + + MOCK_METHOD(void, EnsureWindowAtLeast, (QuicByteCount), (override)); +}; + +class QuicFlowControllerTest : public QuicTest { + public: + void Initialize() { + connection_ = new MockQuicConnection(&helper_, &alarm_factory_, + Perspective::IS_CLIENT); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + session_ = std::make_unique>(connection_); + flow_controller_ = std::make_unique( + session_.get(), stream_id_, /*is_connection_flow_controller*/ false, + send_window_, receive_window_, kStreamReceiveWindowLimit, + should_auto_tune_receive_window_, &session_flow_controller_); + } + + protected: + QuicStreamId stream_id_ = 1234; + QuicByteCount send_window_ = kInitialSessionFlowControlWindowForTest; + QuicByteCount receive_window_ = kInitialSessionFlowControlWindowForTest; + std::unique_ptr flow_controller_; + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnection* connection_; + std::unique_ptr> session_; + MockFlowController session_flow_controller_; + bool should_auto_tune_receive_window_ = false; +}; + +TEST_F(QuicFlowControllerTest, SendingBytes) { + Initialize(); + + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(send_window_, flow_controller_->SendWindowSize()); + + // Send some bytes, but not enough to block. + flow_controller_->AddBytesSent(send_window_ / 2); + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_EQ(send_window_ / 2, flow_controller_->SendWindowSize()); + + // Send enough bytes to block. + flow_controller_->AddBytesSent(send_window_ / 2); + EXPECT_TRUE(flow_controller_->IsBlocked()); + EXPECT_EQ(0u, flow_controller_->SendWindowSize()); + + // BLOCKED frame should get sent. + EXPECT_CALL(*session_, SendBlocked(_, _)).Times(1); + flow_controller_->MaybeSendBlocked(); + + // Update the send window, and verify this has unblocked. + EXPECT_TRUE(flow_controller_->UpdateSendWindowOffset(2 * send_window_)); + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_EQ(send_window_, flow_controller_->SendWindowSize()); + + // Updating with a smaller offset doesn't change anything. + EXPECT_FALSE(flow_controller_->UpdateSendWindowOffset(send_window_ / 10)); + EXPECT_EQ(send_window_, flow_controller_->SendWindowSize()); + + // Try to send more bytes, violating flow control. + EXPECT_QUIC_BUG( + { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_FLOW_CONTROL_SENT_TOO_MUCH_DATA, _, _)); + flow_controller_->AddBytesSent(send_window_ * 10); + EXPECT_TRUE(flow_controller_->IsBlocked()); + EXPECT_EQ(0u, flow_controller_->SendWindowSize()); + }, + absl::StrCat("Trying to send an extra ", send_window_ * 10, " bytes")); +} + +TEST_F(QuicFlowControllerTest, ReceivingBytes) { + Initialize(); + + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + // Receive some bytes, updating highest received offset, but not enough to + // fill flow control receive window. + EXPECT_TRUE( + flow_controller_->UpdateHighestReceivedOffset(1 + receive_window_ / 2)); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ((receive_window_ / 2) - 1, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + // Consume enough bytes to send a WINDOW_UPDATE frame. + EXPECT_CALL(*session_, WriteControlFrame(_, _)).Times(1); + + flow_controller_->AddBytesConsumed(1 + receive_window_ / 2); + + // Result is that once again we have a fully open receive window. + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); +} + +TEST_F(QuicFlowControllerTest, Move) { + Initialize(); + + flow_controller_->AddBytesSent(send_window_ / 2); + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_EQ(send_window_ / 2, flow_controller_->SendWindowSize()); + + EXPECT_TRUE( + flow_controller_->UpdateHighestReceivedOffset(1 + receive_window_ / 2)); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ((receive_window_ / 2) - 1, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + QuicFlowController flow_controller2(std::move(*flow_controller_)); + EXPECT_EQ(send_window_ / 2, flow_controller2.SendWindowSize()); + EXPECT_FALSE(flow_controller2.FlowControlViolation()); + EXPECT_EQ((receive_window_ / 2) - 1, + QuicFlowControllerPeer::ReceiveWindowSize(&flow_controller2)); +} + +TEST_F(QuicFlowControllerTest, OnlySendBlockedFrameOncePerOffset) { + Initialize(); + + // Test that we don't send duplicate BLOCKED frames. We should only send one + // BLOCKED frame at a given send window offset. + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(send_window_, flow_controller_->SendWindowSize()); + + // Send enough bytes to block. + flow_controller_->AddBytesSent(send_window_); + EXPECT_TRUE(flow_controller_->IsBlocked()); + EXPECT_EQ(0u, flow_controller_->SendWindowSize()); + + // BLOCKED frame should get sent. + EXPECT_CALL(*session_, SendBlocked(_, _)).Times(1); + flow_controller_->MaybeSendBlocked(); + + // BLOCKED frame should not get sent again until our send offset changes. + EXPECT_CALL(*session_, SendBlocked(_, _)).Times(0); + flow_controller_->MaybeSendBlocked(); + flow_controller_->MaybeSendBlocked(); + flow_controller_->MaybeSendBlocked(); + flow_controller_->MaybeSendBlocked(); + flow_controller_->MaybeSendBlocked(); + + // Update the send window, then send enough bytes to block again. + EXPECT_TRUE(flow_controller_->UpdateSendWindowOffset(2 * send_window_)); + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_EQ(send_window_, flow_controller_->SendWindowSize()); + flow_controller_->AddBytesSent(send_window_); + EXPECT_TRUE(flow_controller_->IsBlocked()); + EXPECT_EQ(0u, flow_controller_->SendWindowSize()); + + // BLOCKED frame should get sent as send offset has changed. + EXPECT_CALL(*session_, SendBlocked(_, _)).Times(1); + flow_controller_->MaybeSendBlocked(); +} + +TEST_F(QuicFlowControllerTest, ReceivingBytesFastIncreasesFlowWindow) { + should_auto_tune_receive_window_ = true; + Initialize(); + // This test will generate two WINDOW_UPDATE frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)).Times(1); + EXPECT_TRUE(flow_controller_->auto_tune_receive_window()); + + // Make sure clock is inititialized. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicSentPacketManager* manager = + QuicConnectionPeer::GetSentPacketManager(connection_); + + RttStats* rtt_stats = const_cast(manager->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kRtt), + QuicTime::Delta::Zero(), QuicTime::Zero()); + + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + QuicByteCount threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + + QuicStreamOffset receive_offset = threshold + 1; + // Receive some bytes, updating highest received offset, but not enough to + // fill flow control receive window. + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest - receive_offset, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + EXPECT_CALL( + session_flow_controller_, + EnsureWindowAtLeast(kInitialSessionFlowControlWindowForTest * 2 * 1.5)); + + // Consume enough bytes to send a WINDOW_UPDATE frame. + flow_controller_->AddBytesConsumed(threshold + 1); + // Result is that once again we have a fully open receive window. + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(2 * kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(2 * kRtt - 1)); + receive_offset += threshold + 1; + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + flow_controller_->AddBytesConsumed(threshold + 1); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + QuicByteCount new_threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + EXPECT_GT(new_threshold, threshold); +} + +TEST_F(QuicFlowControllerTest, ReceivingBytesFastNoAutoTune) { + Initialize(); + // This test will generate two WINDOW_UPDATE frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + EXPECT_FALSE(flow_controller_->auto_tune_receive_window()); + + // Make sure clock is inititialized. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicSentPacketManager* manager = + QuicConnectionPeer::GetSentPacketManager(connection_); + + RttStats* rtt_stats = const_cast(manager->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kRtt), + QuicTime::Delta::Zero(), QuicTime::Zero()); + + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + QuicByteCount threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + + QuicStreamOffset receive_offset = threshold + 1; + // Receive some bytes, updating highest received offset, but not enough to + // fill flow control receive window. + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest - receive_offset, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + // Consume enough bytes to send a WINDOW_UPDATE frame. + flow_controller_->AddBytesConsumed(threshold + 1); + // Result is that once again we have a fully open receive window. + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + // Move time forward, but by less than two RTTs. Then receive and consume + // some more, forcing a second WINDOW_UPDATE with an increased max window + // size. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(2 * kRtt - 1)); + receive_offset += threshold + 1; + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + flow_controller_->AddBytesConsumed(threshold + 1); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + QuicByteCount new_threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + EXPECT_EQ(new_threshold, threshold); +} + +TEST_F(QuicFlowControllerTest, ReceivingBytesNormalStableFlowWindow) { + should_auto_tune_receive_window_ = true; + Initialize(); + // This test will generate two WINDOW_UPDATE frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)).Times(1); + EXPECT_TRUE(flow_controller_->auto_tune_receive_window()); + + // Make sure clock is inititialized. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicSentPacketManager* manager = + QuicConnectionPeer::GetSentPacketManager(connection_); + RttStats* rtt_stats = const_cast(manager->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kRtt), + QuicTime::Delta::Zero(), QuicTime::Zero()); + + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + QuicByteCount threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + + QuicStreamOffset receive_offset = threshold + 1; + // Receive some bytes, updating highest received offset, but not enough to + // fill flow control receive window. + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest - receive_offset, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + EXPECT_CALL( + session_flow_controller_, + EnsureWindowAtLeast(kInitialSessionFlowControlWindowForTest * 2 * 1.5)); + flow_controller_->AddBytesConsumed(threshold + 1); + + // Result is that once again we have a fully open receive window. + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(2 * kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + // Move time forward, but by more than two RTTs. Then receive and consume + // some more, forcing a second WINDOW_UPDATE with unchanged max window size. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(2 * kRtt + 1)); + + receive_offset += threshold + 1; + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + + flow_controller_->AddBytesConsumed(threshold + 1); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + + QuicByteCount new_threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + EXPECT_EQ(new_threshold, 2 * threshold); +} + +TEST_F(QuicFlowControllerTest, ReceivingBytesNormalNoAutoTune) { + Initialize(); + // This test will generate two WINDOW_UPDATE frames. + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + EXPECT_FALSE(flow_controller_->auto_tune_receive_window()); + + // Make sure clock is inititialized. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicSentPacketManager* manager = + QuicConnectionPeer::GetSentPacketManager(connection_); + RttStats* rtt_stats = const_cast(manager->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kRtt), + QuicTime::Delta::Zero(), QuicTime::Zero()); + + EXPECT_FALSE(flow_controller_->IsBlocked()); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + QuicByteCount threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + + QuicStreamOffset receive_offset = threshold + 1; + // Receive some bytes, updating highest received offset, but not enough to + // fill flow control receive window. + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest - receive_offset, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + flow_controller_->AddBytesConsumed(threshold + 1); + + // Result is that once again we have a fully open receive window. + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + EXPECT_EQ(kInitialSessionFlowControlWindowForTest, + QuicFlowControllerPeer::ReceiveWindowSize(flow_controller_.get())); + + // Move time forward, but by more than two RTTs. Then receive and consume + // some more, forcing a second WINDOW_UPDATE with unchanged max window size. + connection_->AdvanceTime(QuicTime::Delta::FromMilliseconds(2 * kRtt + 1)); + + receive_offset += threshold + 1; + EXPECT_TRUE(flow_controller_->UpdateHighestReceivedOffset(receive_offset)); + + flow_controller_->AddBytesConsumed(threshold + 1); + EXPECT_FALSE(flow_controller_->FlowControlViolation()); + + QuicByteCount new_threshold = + QuicFlowControllerPeer::WindowUpdateThreshold(flow_controller_.get()); + + EXPECT_EQ(new_threshold, threshold); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_framer.cc b/quiche/quic/core/quic_framer.cc new file mode 100644 index 000000000000..200c96174824 --- /dev/null +++ b/quiche/quic/core/quic_framer.cc @@ -0,0 +1,7306 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_framer.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/cleanup/cleanup.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_socket_address_coder.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_client_stats.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +namespace { + +#define ENDPOINT \ + (perspective_ == Perspective::IS_SERVER ? "Server: " : "Client: ") + +// Number of bits the packet number length bits are shifted from the right +// edge of the header. +const uint8_t kPublicHeaderSequenceNumberShift = 4; + +// There are two interpretations for the Frame Type byte in the QUIC protocol, +// resulting in two Frame Types: Special Frame Types and Regular Frame Types. +// +// Regular Frame Types use the Frame Type byte simply. Currently defined +// Regular Frame Types are: +// Padding : 0b 00000000 (0x00) +// ResetStream : 0b 00000001 (0x01) +// ConnectionClose : 0b 00000010 (0x02) +// GoAway : 0b 00000011 (0x03) +// WindowUpdate : 0b 00000100 (0x04) +// Blocked : 0b 00000101 (0x05) +// +// Special Frame Types encode both a Frame Type and corresponding flags +// all in the Frame Type byte. Currently defined Special Frame Types +// are: +// Stream : 0b 1xxxxxxx +// Ack : 0b 01xxxxxx +// +// Semantics of the flag bits above (the x bits) depends on the frame type. + +// Masks to determine if the frame type is a special use +// and for specific special frame types. +const uint8_t kQuicFrameTypeBrokenMask = 0xE0; // 0b 11100000 +const uint8_t kQuicFrameTypeSpecialMask = 0xC0; // 0b 11000000 +const uint8_t kQuicFrameTypeStreamMask = 0x80; +const uint8_t kQuicFrameTypeAckMask = 0x40; +static_assert(kQuicFrameTypeSpecialMask == + (kQuicFrameTypeStreamMask | kQuicFrameTypeAckMask), + "Invalid kQuicFrameTypeSpecialMask"); + +// The stream type format is 1FDOOOSS, where +// F is the fin bit. +// D is the data length bit (0 or 2 bytes). +// OO/OOO are the size of the offset. +// SS is the size of the stream ID. +// Note that the stream encoding can not be determined by inspection. It can +// be determined only by knowing the QUIC Version. +// Stream frame relative shifts and masks for interpreting the stream flags. +// StreamID may be 1, 2, 3, or 4 bytes. +const uint8_t kQuicStreamIdShift = 2; +const uint8_t kQuicStreamIDLengthMask = 0x03; + +// Offset may be 0, 2, 4, or 8 bytes. +const uint8_t kQuicStreamShift = 3; +const uint8_t kQuicStreamOffsetMask = 0x07; + +// Data length may be 0 or 2 bytes. +const uint8_t kQuicStreamDataLengthShift = 1; +const uint8_t kQuicStreamDataLengthMask = 0x01; + +// Fin bit may be set or not. +const uint8_t kQuicStreamFinShift = 1; +const uint8_t kQuicStreamFinMask = 0x01; + +// The format is 01M0LLOO, where +// M if set, there are multiple ack blocks in the frame. +// LL is the size of the largest ack field. +// OO is the size of the ack blocks offset field. +// packet number size shift used in AckFrames. +const uint8_t kQuicSequenceNumberLengthNumBits = 2; +const uint8_t kActBlockLengthOffset = 0; +const uint8_t kLargestAckedOffset = 2; + +// Acks may have only one ack block. +const uint8_t kQuicHasMultipleAckBlocksOffset = 5; + +// Timestamps are 4 bytes followed by 2 bytes. +const uint8_t kQuicNumTimestampsLength = 1; +const uint8_t kQuicFirstTimestampLength = 4; +const uint8_t kQuicTimestampLength = 2; +// Gaps between packet numbers are 1 byte. +const uint8_t kQuicTimestampPacketNumberGapLength = 1; + +// Maximum length of encoded error strings. +const int kMaxErrorStringLength = 256; + +const uint8_t kConnectionIdLengthAdjustment = 3; +const uint8_t kDestinationConnectionIdLengthMask = 0xF0; +const uint8_t kSourceConnectionIdLengthMask = 0x0F; + +// Returns the absolute value of the difference between |a| and |b|. +uint64_t Delta(uint64_t a, uint64_t b) { + // Since these are unsigned numbers, we can't just return abs(a - b) + if (a < b) { + return b - a; + } + return a - b; +} + +uint64_t ClosestTo(uint64_t target, uint64_t a, uint64_t b) { + return (Delta(target, a) < Delta(target, b)) ? a : b; +} + +QuicPacketNumberLength ReadSequenceNumberLength(uint8_t flags) { + switch (flags & PACKET_FLAGS_8BYTE_PACKET) { + case PACKET_FLAGS_8BYTE_PACKET: + return PACKET_6BYTE_PACKET_NUMBER; + case PACKET_FLAGS_4BYTE_PACKET: + return PACKET_4BYTE_PACKET_NUMBER; + case PACKET_FLAGS_2BYTE_PACKET: + return PACKET_2BYTE_PACKET_NUMBER; + case PACKET_FLAGS_1BYTE_PACKET: + return PACKET_1BYTE_PACKET_NUMBER; + default: + QUIC_BUG(quic_bug_10850_1) << "Unreachable case statement."; + return PACKET_6BYTE_PACKET_NUMBER; + } +} + +QuicPacketNumberLength ReadAckPacketNumberLength(uint8_t flags) { + switch (flags & PACKET_FLAGS_8BYTE_PACKET) { + case PACKET_FLAGS_8BYTE_PACKET: + return PACKET_6BYTE_PACKET_NUMBER; + case PACKET_FLAGS_4BYTE_PACKET: + return PACKET_4BYTE_PACKET_NUMBER; + case PACKET_FLAGS_2BYTE_PACKET: + return PACKET_2BYTE_PACKET_NUMBER; + case PACKET_FLAGS_1BYTE_PACKET: + return PACKET_1BYTE_PACKET_NUMBER; + default: + QUIC_BUG(quic_bug_10850_2) << "Unreachable case statement."; + return PACKET_6BYTE_PACKET_NUMBER; + } +} + +uint8_t PacketNumberLengthToOnWireValue( + QuicPacketNumberLength packet_number_length) { + return packet_number_length - 1; +} + +QuicPacketNumberLength GetShortHeaderPacketNumberLength(uint8_t type) { + QUICHE_DCHECK(!(type & FLAGS_LONG_HEADER)); + return static_cast((type & 0x03) + 1); +} + +uint8_t LongHeaderTypeToOnWireValue(QuicLongHeaderType type, + const ParsedQuicVersion& version) { + switch (type) { + case INITIAL: + return version.UsesV2PacketTypes() ? (1 << 4) : 0; + case ZERO_RTT_PROTECTED: + return version.UsesV2PacketTypes() ? (2 << 4) : (1 << 4); + case HANDSHAKE: + return version.UsesV2PacketTypes() ? (3 << 4) : (2 << 4); + case RETRY: + return version.UsesV2PacketTypes() ? 0 : (3 << 4); + case VERSION_NEGOTIATION: + return 0xF0; // Value does not matter + default: + QUIC_BUG(quic_bug_10850_3) << "Invalid long header type: " << type; + return 0xFF; + } +} + +QuicLongHeaderType GetLongHeaderType(uint8_t type, + const ParsedQuicVersion& version) { + QUICHE_DCHECK((type & FLAGS_LONG_HEADER)); + switch ((type & 0x30) >> 4) { + case 0: + return version.UsesV2PacketTypes() ? RETRY : INITIAL; + case 1: + return version.UsesV2PacketTypes() ? INITIAL : ZERO_RTT_PROTECTED; + case 2: + return version.UsesV2PacketTypes() ? ZERO_RTT_PROTECTED : HANDSHAKE; + case 3: + return version.UsesV2PacketTypes() ? HANDSHAKE : RETRY; + default: + QUIC_BUG(quic_bug_10850_4) << "Unreachable statement"; + return INVALID_PACKET_TYPE; + } +} + +QuicPacketNumberLength GetLongHeaderPacketNumberLength(uint8_t type) { + return static_cast((type & 0x03) + 1); +} + +// Used to get packet number space before packet gets decrypted. +PacketNumberSpace GetPacketNumberSpace(const QuicPacketHeader& header) { + switch (header.form) { + case GOOGLE_QUIC_PACKET: + QUIC_BUG(quic_bug_10850_5) + << "Try to get packet number space of Google QUIC packet"; + break; + case IETF_QUIC_SHORT_HEADER_PACKET: + return APPLICATION_DATA; + case IETF_QUIC_LONG_HEADER_PACKET: + switch (header.long_packet_type) { + case INITIAL: + return INITIAL_DATA; + case HANDSHAKE: + return HANDSHAKE_DATA; + case ZERO_RTT_PROTECTED: + return APPLICATION_DATA; + case VERSION_NEGOTIATION: + case RETRY: + case INVALID_PACKET_TYPE: + QUIC_BUG(quic_bug_10850_6) + << "Try to get packet number space of long header type: " + << QuicUtils::QuicLongHeaderTypetoString(header.long_packet_type); + break; + } + } + + return NUM_PACKET_NUMBER_SPACES; +} + +EncryptionLevel GetEncryptionLevel(const QuicPacketHeader& header) { + switch (header.form) { + case GOOGLE_QUIC_PACKET: + QUIC_BUG(quic_bug_10850_7) + << "Cannot determine EncryptionLevel from Google QUIC header"; + break; + case IETF_QUIC_SHORT_HEADER_PACKET: + return ENCRYPTION_FORWARD_SECURE; + case IETF_QUIC_LONG_HEADER_PACKET: + switch (header.long_packet_type) { + case INITIAL: + return ENCRYPTION_INITIAL; + case HANDSHAKE: + return ENCRYPTION_HANDSHAKE; + case ZERO_RTT_PROTECTED: + return ENCRYPTION_ZERO_RTT; + case VERSION_NEGOTIATION: + case RETRY: + case INVALID_PACKET_TYPE: + QUIC_BUG(quic_bug_10850_8) + << "No encryption used with type " + << QuicUtils::QuicLongHeaderTypetoString(header.long_packet_type); + } + } + return NUM_ENCRYPTION_LEVELS; +} + +absl::string_view TruncateErrorString(absl::string_view error) { + if (error.length() <= kMaxErrorStringLength) { + return error; + } + return absl::string_view(error.data(), kMaxErrorStringLength); +} + +size_t TruncatedErrorStringSize(const absl::string_view& error) { + if (error.length() < kMaxErrorStringLength) { + return error.length(); + } + return kMaxErrorStringLength; +} + +uint8_t GetConnectionIdLengthValue(uint8_t length) { + if (length == 0) { + return 0; + } + return static_cast(length - kConnectionIdLengthAdjustment); +} + +bool IsValidPacketNumberLength(QuicPacketNumberLength packet_number_length) { + size_t length = packet_number_length; + return length == 1 || length == 2 || length == 4 || length == 6 || + length == 8; +} + +bool IsValidFullPacketNumber(uint64_t full_packet_number, + ParsedQuicVersion version) { + return full_packet_number > 0 || version.HasIetfQuicFrames(); +} + +bool AppendIetfConnectionIds(bool version_flag, bool use_length_prefix, + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, + QuicDataWriter* writer) { + if (!version_flag) { + return writer->WriteConnectionId(destination_connection_id); + } + + if (use_length_prefix) { + return writer->WriteLengthPrefixedConnectionId(destination_connection_id) && + writer->WriteLengthPrefixedConnectionId(source_connection_id); + } + + // Compute connection ID length byte. + uint8_t dcil = GetConnectionIdLengthValue(destination_connection_id.length()); + uint8_t scil = GetConnectionIdLengthValue(source_connection_id.length()); + uint8_t connection_id_length = dcil << 4 | scil; + + return writer->WriteUInt8(connection_id_length) && + writer->WriteConnectionId(destination_connection_id) && + writer->WriteConnectionId(source_connection_id); +} + +enum class DroppedPacketReason { + // General errors + INVALID_PUBLIC_HEADER, + VERSION_MISMATCH, + // Version negotiation packet errors + INVALID_VERSION_NEGOTIATION_PACKET, + // Public reset packet errors, pre-v44 + INVALID_PUBLIC_RESET_PACKET, + // Data packet errors + INVALID_PACKET_NUMBER, + INVALID_DIVERSIFICATION_NONCE, + DECRYPTION_FAILURE, + NUM_REASONS, +}; + +void RecordDroppedPacketReason(DroppedPacketReason reason) { + QUIC_CLIENT_HISTOGRAM_ENUM("QuicDroppedPacketReason", reason, + DroppedPacketReason::NUM_REASONS, + "The reason a packet was not processed. Recorded " + "each time such a packet is dropped"); +} + +PacketHeaderFormat GetIetfPacketHeaderFormat(uint8_t type_byte) { + return type_byte & FLAGS_LONG_HEADER ? IETF_QUIC_LONG_HEADER_PACKET + : IETF_QUIC_SHORT_HEADER_PACKET; +} + +std::string GenerateErrorString(std::string initial_error_string, + QuicErrorCode quic_error_code) { + if (quic_error_code == QUIC_IETF_GQUIC_ERROR_MISSING) { + // QUIC_IETF_GQUIC_ERROR_MISSING is special -- it means not to encode + // the error value in the string. + return initial_error_string; + } + return absl::StrCat(std::to_string(static_cast(quic_error_code)), + ":", initial_error_string); +} + +// Return the minimum size of the ECN fields in an ACK frame +size_t AckEcnCountSize(const QuicAckFrame& ack_frame) { + if (!ack_frame.ecn_counters.has_value()) { + return 0; + } + return (QuicDataWriter::GetVarInt62Len(ack_frame.ecn_counters->ect0) + + QuicDataWriter::GetVarInt62Len(ack_frame.ecn_counters->ect1) + + QuicDataWriter::GetVarInt62Len(ack_frame.ecn_counters->ce)); +} + +} // namespace + +QuicFramer::QuicFramer(const ParsedQuicVersionVector& supported_versions, + QuicTime creation_time, Perspective perspective, + uint8_t expected_server_connection_id_length) + : visitor_(nullptr), + error_(QUIC_NO_ERROR), + last_serialized_server_connection_id_(EmptyQuicConnectionId()), + last_serialized_client_connection_id_(EmptyQuicConnectionId()), + version_(ParsedQuicVersion::Unsupported()), + supported_versions_(supported_versions), + decrypter_level_(ENCRYPTION_INITIAL), + alternative_decrypter_level_(NUM_ENCRYPTION_LEVELS), + alternative_decrypter_latch_(false), + perspective_(perspective), + validate_flags_(true), + process_timestamps_(false), + max_receive_timestamps_per_ack_(std::numeric_limits::max()), + receive_timestamps_exponent_(0), + creation_time_(creation_time), + last_timestamp_(QuicTime::Delta::Zero()), + support_key_update_for_connection_(false), + current_key_phase_bit_(false), + potential_peer_key_update_attempt_count_(0), + first_sending_packet_number_(FirstSendingPacketNumber()), + data_producer_(nullptr), + infer_packet_header_type_from_version_(perspective == + Perspective::IS_CLIENT), + expected_server_connection_id_length_( + expected_server_connection_id_length), + expected_client_connection_id_length_(0), + supports_multiple_packet_number_spaces_(false), + last_written_packet_number_length_(0), + peer_ack_delay_exponent_(kDefaultAckDelayExponent), + local_ack_delay_exponent_(kDefaultAckDelayExponent), + current_received_frame_type_(0), + previously_received_frame_type_(0) { + QUICHE_DCHECK(!supported_versions.empty()); + version_ = supported_versions_[0]; + QUICHE_DCHECK(version_.IsKnown()) + << ParsedQuicVersionVectorToString(supported_versions_); +} + +QuicFramer::~QuicFramer() {} + +// static +size_t QuicFramer::GetMinStreamFrameSize(QuicTransportVersion version, + QuicStreamId stream_id, + QuicStreamOffset offset, + bool last_frame_in_packet, + size_t data_length) { + if (VersionHasIetfQuicFrames(version)) { + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(stream_id) + + (last_frame_in_packet + ? 0 + : QuicDataWriter::GetVarInt62Len(data_length)) + + (offset != 0 ? QuicDataWriter::GetVarInt62Len(offset) : 0); + } + return kQuicFrameTypeSize + GetStreamIdSize(stream_id) + + GetStreamOffsetSize(offset) + + (last_frame_in_packet ? 0 : kQuicStreamPayloadLengthSize); +} + +// static +size_t QuicFramer::GetMinCryptoFrameSize(QuicStreamOffset offset, + QuicPacketLength data_length) { + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(offset) + + QuicDataWriter::GetVarInt62Len(data_length); +} + +// static +size_t QuicFramer::GetMessageFrameSize(QuicTransportVersion version, + bool last_frame_in_packet, + QuicByteCount length) { + QUIC_BUG_IF(quic_bug_12975_1, !VersionSupportsMessageFrames(version)) + << "Try to serialize MESSAGE frame in " << version; + return kQuicFrameTypeSize + + (last_frame_in_packet ? 0 : QuicDataWriter::GetVarInt62Len(length)) + + length; +} + +// static +size_t QuicFramer::GetMinAckFrameSize( + QuicTransportVersion version, const QuicAckFrame& ack_frame, + uint32_t local_ack_delay_exponent, + bool use_ietf_ack_with_receive_timestamp) { + if (VersionHasIetfQuicFrames(version)) { + // The minimal ack frame consists of the following fields: Largest + // Acknowledged, ACK Delay, 0 ACK Block Count, First ACK Block and either 0 + // Timestamp Range Count or ECN counts. + // Type byte + largest acked. + size_t min_size = + kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(LargestAcked(ack_frame).ToUint64()); + // Ack delay. + min_size += QuicDataWriter::GetVarInt62Len( + ack_frame.ack_delay_time.ToMicroseconds() >> local_ack_delay_exponent); + // 0 ack block count. + min_size += QuicDataWriter::GetVarInt62Len(0); + // First ack block. + min_size += QuicDataWriter::GetVarInt62Len( + ack_frame.packets.Empty() ? 0 + : ack_frame.packets.rbegin()->Length() - 1); + + if (use_ietf_ack_with_receive_timestamp) { + // 0 Timestamp Range Count. + min_size += QuicDataWriter::GetVarInt62Len(0); + } else { + min_size += AckEcnCountSize(ack_frame); + } + return min_size; + } + return kQuicFrameTypeSize + + GetMinPacketNumberLength(LargestAcked(ack_frame)) + + kQuicDeltaTimeLargestObservedSize + kQuicNumTimestampsSize; +} + +// static +size_t QuicFramer::GetStopWaitingFrameSize( + QuicPacketNumberLength packet_number_length) { + size_t min_size = kQuicFrameTypeSize + packet_number_length; + return min_size; +} + +// static +size_t QuicFramer::GetRstStreamFrameSize(QuicTransportVersion version, + const QuicRstStreamFrame& frame) { + if (VersionHasIetfQuicFrames(version)) { + return QuicDataWriter::GetVarInt62Len(frame.stream_id) + + QuicDataWriter::GetVarInt62Len(frame.byte_offset) + + kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.ietf_error_code); + } + return kQuicFrameTypeSize + kQuicMaxStreamIdSize + kQuicMaxStreamOffsetSize + + kQuicErrorCodeSize; +} + +// static +size_t QuicFramer::GetConnectionCloseFrameSize( + QuicTransportVersion version, const QuicConnectionCloseFrame& frame) { + if (!VersionHasIetfQuicFrames(version)) { + // Not IETF QUIC, return Google QUIC CONNECTION CLOSE frame size. + return kQuicFrameTypeSize + kQuicErrorCodeSize + + kQuicErrorDetailsLengthSize + + TruncatedErrorStringSize(frame.error_details); + } + + // Prepend the extra error information to the string and get the result's + // length. + const size_t truncated_error_string_size = TruncatedErrorStringSize( + GenerateErrorString(frame.error_details, frame.quic_error_code)); + + const size_t frame_size = + truncated_error_string_size + + QuicDataWriter::GetVarInt62Len(truncated_error_string_size) + + kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.wire_error_code); + if (frame.close_type == IETF_QUIC_APPLICATION_CONNECTION_CLOSE) { + return frame_size; + } + // The Transport close frame has the transport_close_frame_type, so include + // its length. + return frame_size + + QuicDataWriter::GetVarInt62Len(frame.transport_close_frame_type); +} + +// static +size_t QuicFramer::GetMinGoAwayFrameSize() { + return kQuicFrameTypeSize + kQuicErrorCodeSize + kQuicErrorDetailsLengthSize + + kQuicMaxStreamIdSize; +} + +// static +size_t QuicFramer::GetWindowUpdateFrameSize( + QuicTransportVersion version, const QuicWindowUpdateFrame& frame) { + if (!VersionHasIetfQuicFrames(version)) { + return kQuicFrameTypeSize + kQuicMaxStreamIdSize + kQuicMaxStreamOffsetSize; + } + if (frame.stream_id == QuicUtils::GetInvalidStreamId(version)) { + // Frame would be a MAX DATA frame, which has only a Maximum Data field. + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(frame.max_data); + } + // Frame would be MAX STREAM DATA, has Maximum Stream Data and Stream ID + // fields. + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(frame.max_data) + + QuicDataWriter::GetVarInt62Len(frame.stream_id); +} + +// static +size_t QuicFramer::GetMaxStreamsFrameSize(QuicTransportVersion version, + const QuicMaxStreamsFrame& frame) { + if (!VersionHasIetfQuicFrames(version)) { + QUIC_BUG(quic_bug_10850_9) + << "In version " << version + << ", which does not support IETF Frames, and tried to serialize " + "MaxStreams Frame."; + } + return kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.stream_count); +} + +// static +size_t QuicFramer::GetStreamsBlockedFrameSize( + QuicTransportVersion version, const QuicStreamsBlockedFrame& frame) { + if (!VersionHasIetfQuicFrames(version)) { + QUIC_BUG(quic_bug_10850_10) + << "In version " << version + << ", which does not support IETF frames, and tried to serialize " + "StreamsBlocked Frame."; + } + + return kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.stream_count); +} + +// static +size_t QuicFramer::GetBlockedFrameSize(QuicTransportVersion version, + const QuicBlockedFrame& frame) { + if (!VersionHasIetfQuicFrames(version)) { + return kQuicFrameTypeSize + kQuicMaxStreamIdSize; + } + if (frame.stream_id == QuicUtils::GetInvalidStreamId(version)) { + // return size of IETF QUIC Blocked frame + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(frame.offset); + } + // return size of IETF QUIC Stream Blocked frame. + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(frame.offset) + + QuicDataWriter::GetVarInt62Len(frame.stream_id); +} + +// static +size_t QuicFramer::GetStopSendingFrameSize(const QuicStopSendingFrame& frame) { + return kQuicFrameTypeSize + QuicDataWriter::GetVarInt62Len(frame.stream_id) + + QuicDataWriter::GetVarInt62Len(frame.ietf_error_code); +} + +// static +size_t QuicFramer::GetAckFrequencyFrameSize( + const QuicAckFrequencyFrame& frame) { + return QuicDataWriter::GetVarInt62Len(IETF_ACK_FREQUENCY) + + QuicDataWriter::GetVarInt62Len(frame.sequence_number) + + QuicDataWriter::GetVarInt62Len(frame.packet_tolerance) + + QuicDataWriter::GetVarInt62Len(frame.max_ack_delay.ToMicroseconds()) + + // One byte for encoding boolean + 1; +} + +// static +size_t QuicFramer::GetPathChallengeFrameSize( + const QuicPathChallengeFrame& frame) { + return kQuicFrameTypeSize + sizeof(frame.data_buffer); +} + +// static +size_t QuicFramer::GetPathResponseFrameSize( + const QuicPathResponseFrame& frame) { + return kQuicFrameTypeSize + sizeof(frame.data_buffer); +} + +// static +size_t QuicFramer::GetRetransmittableControlFrameSize( + QuicTransportVersion version, const QuicFrame& frame) { + switch (frame.type) { + case PING_FRAME: + // Ping has no payload. + return kQuicFrameTypeSize; + case RST_STREAM_FRAME: + return GetRstStreamFrameSize(version, *frame.rst_stream_frame); + case CONNECTION_CLOSE_FRAME: + return GetConnectionCloseFrameSize(version, + *frame.connection_close_frame); + case GOAWAY_FRAME: + return GetMinGoAwayFrameSize() + + TruncatedErrorStringSize(frame.goaway_frame->reason_phrase); + case WINDOW_UPDATE_FRAME: + // For IETF QUIC, this could be either a MAX DATA or MAX STREAM DATA. + // GetWindowUpdateFrameSize figures this out and returns the correct + // length. + return GetWindowUpdateFrameSize(version, frame.window_update_frame); + case BLOCKED_FRAME: + return GetBlockedFrameSize(version, frame.blocked_frame); + case NEW_CONNECTION_ID_FRAME: + return GetNewConnectionIdFrameSize(*frame.new_connection_id_frame); + case RETIRE_CONNECTION_ID_FRAME: + return GetRetireConnectionIdFrameSize(*frame.retire_connection_id_frame); + case NEW_TOKEN_FRAME: + return GetNewTokenFrameSize(*frame.new_token_frame); + case MAX_STREAMS_FRAME: + return GetMaxStreamsFrameSize(version, frame.max_streams_frame); + case STREAMS_BLOCKED_FRAME: + return GetStreamsBlockedFrameSize(version, frame.streams_blocked_frame); + case PATH_RESPONSE_FRAME: + return GetPathResponseFrameSize(frame.path_response_frame); + case PATH_CHALLENGE_FRAME: + return GetPathChallengeFrameSize(frame.path_challenge_frame); + case STOP_SENDING_FRAME: + return GetStopSendingFrameSize(frame.stop_sending_frame); + case HANDSHAKE_DONE_FRAME: + // HANDSHAKE_DONE has no payload. + return kQuicFrameTypeSize; + case ACK_FREQUENCY_FRAME: + return GetAckFrequencyFrameSize(*frame.ack_frequency_frame); + case STREAM_FRAME: + case ACK_FRAME: + case STOP_WAITING_FRAME: + case MTU_DISCOVERY_FRAME: + case PADDING_FRAME: + case MESSAGE_FRAME: + case CRYPTO_FRAME: + case NUM_FRAME_TYPES: + QUICHE_DCHECK(false); + return 0; + } + + // Not reachable, but some Chrome compilers can't figure that out. *sigh* + QUICHE_DCHECK(false); + return 0; +} + +// static +size_t QuicFramer::GetStreamIdSize(QuicStreamId stream_id) { + // Sizes are 1 through 4 bytes. + for (int i = 1; i <= 4; ++i) { + stream_id >>= 8; + if (stream_id == 0) { + return i; + } + } + QUIC_BUG(quic_bug_10850_11) << "Failed to determine StreamIDSize."; + return 4; +} + +// static +size_t QuicFramer::GetStreamOffsetSize(QuicStreamOffset offset) { + // 0 is a special case. + if (offset == 0) { + return 0; + } + // 2 through 8 are the remaining sizes. + offset >>= 8; + for (int i = 2; i <= 8; ++i) { + offset >>= 8; + if (offset == 0) { + return i; + } + } + QUIC_BUG(quic_bug_10850_12) << "Failed to determine StreamOffsetSize."; + return 8; +} + +// static +size_t QuicFramer::GetNewConnectionIdFrameSize( + const QuicNewConnectionIdFrame& frame) { + return kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.sequence_number) + + QuicDataWriter::GetVarInt62Len(frame.retire_prior_to) + + kConnectionIdLengthSize + frame.connection_id.length() + + sizeof(frame.stateless_reset_token); +} + +// static +size_t QuicFramer::GetRetireConnectionIdFrameSize( + const QuicRetireConnectionIdFrame& frame) { + return kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.sequence_number); +} + +// static +size_t QuicFramer::GetNewTokenFrameSize(const QuicNewTokenFrame& frame) { + return kQuicFrameTypeSize + + QuicDataWriter::GetVarInt62Len(frame.token.length()) + + frame.token.length(); +} + +// TODO(nharper): Change this method to take a ParsedQuicVersion. +bool QuicFramer::IsSupportedTransportVersion( + const QuicTransportVersion version) const { + for (const ParsedQuicVersion& supported_version : supported_versions_) { + if (version == supported_version.transport_version) { + return true; + } + } + return false; +} + +bool QuicFramer::IsSupportedVersion(const ParsedQuicVersion version) const { + for (const ParsedQuicVersion& supported_version : supported_versions_) { + if (version == supported_version) { + return true; + } + } + return false; +} + +size_t QuicFramer::GetSerializedFrameLength( + const QuicFrame& frame, size_t free_bytes, bool first_frame, + bool last_frame, QuicPacketNumberLength packet_number_length) { + // Prevent a rare crash reported in b/19458523. + if (frame.type == ACK_FRAME && frame.ack_frame == nullptr) { + QUIC_BUG(quic_bug_10850_13) + << "Cannot compute the length of a null ack frame. free_bytes:" + << free_bytes << " first_frame:" << first_frame + << " last_frame:" << last_frame + << " seq num length:" << packet_number_length; + set_error(QUIC_INTERNAL_ERROR); + visitor_->OnError(this); + return 0; + } + if (frame.type == PADDING_FRAME) { + if (frame.padding_frame.num_padding_bytes == -1) { + // Full padding to the end of the packet. + return free_bytes; + } else { + // Lite padding. + return free_bytes < + static_cast(frame.padding_frame.num_padding_bytes) + ? free_bytes + : frame.padding_frame.num_padding_bytes; + } + } + + size_t frame_len = + ComputeFrameLength(frame, last_frame, packet_number_length); + if (frame_len <= free_bytes) { + // Frame fits within packet. Note that acks may be truncated. + return frame_len; + } + // Only truncate the first frame in a packet, so if subsequent ones go + // over, stop including more frames. + if (!first_frame) { + return 0; + } + bool can_truncate = + frame.type == ACK_FRAME && + free_bytes >= + GetMinAckFrameSize(version_.transport_version, *frame.ack_frame, + local_ack_delay_exponent_, + UseIetfAckWithReceiveTimestamp(*frame.ack_frame)); + if (can_truncate) { + // Truncate the frame so the packet will not exceed kMaxOutgoingPacketSize. + // Note that we may not use every byte of the writer in this case. + QUIC_DLOG(INFO) << ENDPOINT + << "Truncating large frame, free bytes: " << free_bytes; + return free_bytes; + } + return 0; +} + +QuicFramer::AckFrameInfo::AckFrameInfo() + : max_block_length(0), first_block_length(0), num_ack_blocks(0) {} + +QuicFramer::AckFrameInfo::AckFrameInfo(const AckFrameInfo& other) = default; + +QuicFramer::AckFrameInfo::~AckFrameInfo() {} + +bool QuicFramer::WriteIetfLongHeaderLength(const QuicPacketHeader& header, + QuicDataWriter* writer, + size_t length_field_offset, + EncryptionLevel level) { + if (!QuicVersionHasLongHeaderLengths(transport_version()) || + !header.version_flag || length_field_offset == 0) { + return true; + } + if (writer->length() < length_field_offset || + writer->length() - length_field_offset < + quiche::kQuicheDefaultLongHeaderLengthLength) { + set_detailed_error("Invalid length_field_offset."); + QUIC_BUG(quic_bug_10850_14) << "Invalid length_field_offset."; + return false; + } + size_t length_to_write = writer->length() - length_field_offset - + quiche::kQuicheDefaultLongHeaderLengthLength; + // Add length of auth tag. + length_to_write = GetCiphertextSize(level, length_to_write); + + QuicDataWriter length_writer(writer->length() - length_field_offset, + writer->data() + length_field_offset); + if (!length_writer.WriteVarInt62WithForcedLength( + length_to_write, quiche::kQuicheDefaultLongHeaderLengthLength)) { + set_detailed_error("Failed to overwrite long header length."); + QUIC_BUG(quic_bug_10850_15) << "Failed to overwrite long header length."; + return false; + } + return true; +} + +size_t QuicFramer::BuildDataPacket(const QuicPacketHeader& header, + const QuicFrames& frames, char* buffer, + size_t packet_length, + EncryptionLevel level) { + QUIC_BUG_IF(quic_bug_12975_2, + header.version_flag && version().HasIetfInvariantHeader() && + header.long_packet_type == RETRY && !frames.empty()) + << "IETF RETRY packets cannot contain frames " << header; + QuicDataWriter writer(packet_length, buffer); + size_t length_field_offset = 0; + if (!AppendPacketHeader(header, &writer, &length_field_offset)) { + QUIC_BUG(quic_bug_10850_16) << "AppendPacketHeader failed"; + return 0; + } + + if (VersionHasIetfQuicFrames(transport_version())) { + if (AppendIetfFrames(frames, &writer) == 0) { + return 0; + } + if (!WriteIetfLongHeaderLength(header, &writer, length_field_offset, + level)) { + return 0; + } + return writer.length(); + } + + size_t i = 0; + for (const QuicFrame& frame : frames) { + // Determine if we should write stream frame length in header. + const bool last_frame_in_packet = i == frames.size() - 1; + if (!AppendTypeByte(frame, last_frame_in_packet, &writer)) { + QUIC_BUG(quic_bug_10850_17) << "AppendTypeByte failed"; + return 0; + } + + switch (frame.type) { + case PADDING_FRAME: + if (!AppendPaddingFrame(frame.padding_frame, &writer)) { + QUIC_BUG(quic_bug_10850_18) + << "AppendPaddingFrame of " + << frame.padding_frame.num_padding_bytes << " failed"; + return 0; + } + break; + case STREAM_FRAME: + if (!AppendStreamFrame(frame.stream_frame, last_frame_in_packet, + &writer)) { + QUIC_BUG(quic_bug_10850_19) << "AppendStreamFrame failed"; + return 0; + } + break; + case ACK_FRAME: + if (!AppendAckFrameAndTypeByte(*frame.ack_frame, &writer)) { + QUIC_BUG(quic_bug_10850_20) + << "AppendAckFrameAndTypeByte failed: " << detailed_error_; + return 0; + } + break; + case STOP_WAITING_FRAME: + if (!AppendStopWaitingFrame(header, frame.stop_waiting_frame, + &writer)) { + QUIC_BUG(quic_bug_10850_21) << "AppendStopWaitingFrame failed"; + return 0; + } + break; + case MTU_DISCOVERY_FRAME: + // MTU discovery frames are serialized as ping frames. + ABSL_FALLTHROUGH_INTENDED; + case PING_FRAME: + // Ping has no payload. + break; + case RST_STREAM_FRAME: + if (!AppendRstStreamFrame(*frame.rst_stream_frame, &writer)) { + QUIC_BUG(quic_bug_10850_22) << "AppendRstStreamFrame failed"; + return 0; + } + break; + case CONNECTION_CLOSE_FRAME: + if (!AppendConnectionCloseFrame(*frame.connection_close_frame, + &writer)) { + QUIC_BUG(quic_bug_10850_23) << "AppendConnectionCloseFrame failed"; + return 0; + } + break; + case GOAWAY_FRAME: + if (!AppendGoAwayFrame(*frame.goaway_frame, &writer)) { + QUIC_BUG(quic_bug_10850_24) << "AppendGoAwayFrame failed"; + return 0; + } + break; + case WINDOW_UPDATE_FRAME: + if (!AppendWindowUpdateFrame(frame.window_update_frame, &writer)) { + QUIC_BUG(quic_bug_10850_25) << "AppendWindowUpdateFrame failed"; + return 0; + } + break; + case BLOCKED_FRAME: + if (!AppendBlockedFrame(frame.blocked_frame, &writer)) { + QUIC_BUG(quic_bug_10850_26) << "AppendBlockedFrame failed"; + return 0; + } + break; + case NEW_CONNECTION_ID_FRAME: + set_detailed_error( + "Attempt to append NEW_CONNECTION_ID frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case RETIRE_CONNECTION_ID_FRAME: + set_detailed_error( + "Attempt to append RETIRE_CONNECTION_ID frame and not in IETF " + "QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case NEW_TOKEN_FRAME: + set_detailed_error( + "Attempt to append NEW_TOKEN_ID frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case MAX_STREAMS_FRAME: + set_detailed_error( + "Attempt to append MAX_STREAMS frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case STREAMS_BLOCKED_FRAME: + set_detailed_error( + "Attempt to append STREAMS_BLOCKED frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case PATH_RESPONSE_FRAME: + set_detailed_error( + "Attempt to append PATH_RESPONSE frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case PATH_CHALLENGE_FRAME: + set_detailed_error( + "Attempt to append PATH_CHALLENGE frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case STOP_SENDING_FRAME: + set_detailed_error( + "Attempt to append STOP_SENDING frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case MESSAGE_FRAME: + if (!AppendMessageFrameAndTypeByte(*frame.message_frame, + last_frame_in_packet, &writer)) { + QUIC_BUG(quic_bug_10850_27) << "AppendMessageFrame failed"; + return 0; + } + break; + case CRYPTO_FRAME: + if (!QuicVersionUsesCryptoFrames(version_.transport_version)) { + set_detailed_error( + "Attempt to append CRYPTO frame in version prior to 47."); + return RaiseError(QUIC_INTERNAL_ERROR); + } + if (!AppendCryptoFrame(*frame.crypto_frame, &writer)) { + QUIC_BUG(quic_bug_10850_28) << "AppendCryptoFrame failed"; + return 0; + } + break; + case HANDSHAKE_DONE_FRAME: + // HANDSHAKE_DONE has no payload. + break; + default: + RaiseError(QUIC_INVALID_FRAME_DATA); + QUIC_BUG(quic_bug_10850_29) << "QUIC_INVALID_FRAME_DATA"; + return 0; + } + ++i; + } + + if (!WriteIetfLongHeaderLength(header, &writer, length_field_offset, level)) { + return 0; + } + + return writer.length(); +} + +size_t QuicFramer::AppendIetfFrames(const QuicFrames& frames, + QuicDataWriter* writer) { + size_t i = 0; + for (const QuicFrame& frame : frames) { + // Determine if we should write stream frame length in header. + const bool last_frame_in_packet = i == frames.size() - 1; + if (!AppendIetfFrameType(frame, last_frame_in_packet, writer)) { + QUIC_BUG(quic_bug_10850_30) + << "AppendIetfFrameType failed: " << detailed_error(); + return 0; + } + + switch (frame.type) { + case PADDING_FRAME: + if (!AppendPaddingFrame(frame.padding_frame, writer)) { + QUIC_BUG(quic_bug_10850_31) << "AppendPaddingFrame of " + << frame.padding_frame.num_padding_bytes + << " failed: " << detailed_error(); + return 0; + } + break; + case STREAM_FRAME: + if (!AppendStreamFrame(frame.stream_frame, last_frame_in_packet, + writer)) { + QUIC_BUG(quic_bug_10850_32) + << "AppendStreamFrame " << frame.stream_frame + << " failed: " << detailed_error(); + return 0; + } + break; + case ACK_FRAME: + if (!AppendIetfAckFrameAndTypeByte(*frame.ack_frame, writer)) { + QUIC_BUG(quic_bug_10850_33) + << "AppendIetfAckFrameAndTypeByte failed: " << detailed_error(); + return 0; + } + break; + case STOP_WAITING_FRAME: + set_detailed_error( + "Attempt to append STOP WAITING frame in IETF QUIC."); + RaiseError(QUIC_INTERNAL_ERROR); + QUIC_BUG(quic_bug_10850_34) << detailed_error(); + return 0; + case MTU_DISCOVERY_FRAME: + // MTU discovery frames are serialized as ping frames. + ABSL_FALLTHROUGH_INTENDED; + case PING_FRAME: + // Ping has no payload. + break; + case RST_STREAM_FRAME: + if (!AppendRstStreamFrame(*frame.rst_stream_frame, writer)) { + QUIC_BUG(quic_bug_10850_35) + << "AppendRstStreamFrame failed: " << detailed_error(); + return 0; + } + break; + case CONNECTION_CLOSE_FRAME: + if (!AppendIetfConnectionCloseFrame(*frame.connection_close_frame, + writer)) { + QUIC_BUG(quic_bug_10850_36) + << "AppendIetfConnectionCloseFrame failed: " << detailed_error(); + return 0; + } + break; + case GOAWAY_FRAME: + set_detailed_error("Attempt to append GOAWAY frame in IETF QUIC."); + RaiseError(QUIC_INTERNAL_ERROR); + QUIC_BUG(quic_bug_10850_37) << detailed_error(); + return 0; + case WINDOW_UPDATE_FRAME: + // Depending on whether there is a stream ID or not, will be either a + // MAX STREAM DATA frame or a MAX DATA frame. + if (frame.window_update_frame.stream_id == + QuicUtils::GetInvalidStreamId(transport_version())) { + if (!AppendMaxDataFrame(frame.window_update_frame, writer)) { + QUIC_BUG(quic_bug_10850_38) + << "AppendMaxDataFrame failed: " << detailed_error(); + return 0; + } + } else { + if (!AppendMaxStreamDataFrame(frame.window_update_frame, writer)) { + QUIC_BUG(quic_bug_10850_39) + << "AppendMaxStreamDataFrame failed: " << detailed_error(); + return 0; + } + } + break; + case BLOCKED_FRAME: + if (!AppendBlockedFrame(frame.blocked_frame, writer)) { + QUIC_BUG(quic_bug_10850_40) + << "AppendBlockedFrame failed: " << detailed_error(); + return 0; + } + break; + case MAX_STREAMS_FRAME: + if (!AppendMaxStreamsFrame(frame.max_streams_frame, writer)) { + QUIC_BUG(quic_bug_10850_41) + << "AppendMaxStreamsFrame failed: " << detailed_error(); + return 0; + } + break; + case STREAMS_BLOCKED_FRAME: + if (!AppendStreamsBlockedFrame(frame.streams_blocked_frame, writer)) { + QUIC_BUG(quic_bug_10850_42) + << "AppendStreamsBlockedFrame failed: " << detailed_error(); + return 0; + } + break; + case NEW_CONNECTION_ID_FRAME: + if (!AppendNewConnectionIdFrame(*frame.new_connection_id_frame, + writer)) { + QUIC_BUG(quic_bug_10850_43) + << "AppendNewConnectionIdFrame failed: " << detailed_error(); + return 0; + } + break; + case RETIRE_CONNECTION_ID_FRAME: + if (!AppendRetireConnectionIdFrame(*frame.retire_connection_id_frame, + writer)) { + QUIC_BUG(quic_bug_10850_44) + << "AppendRetireConnectionIdFrame failed: " << detailed_error(); + return 0; + } + break; + case NEW_TOKEN_FRAME: + if (!AppendNewTokenFrame(*frame.new_token_frame, writer)) { + QUIC_BUG(quic_bug_10850_45) + << "AppendNewTokenFrame failed: " << detailed_error(); + return 0; + } + break; + case STOP_SENDING_FRAME: + if (!AppendStopSendingFrame(frame.stop_sending_frame, writer)) { + QUIC_BUG(quic_bug_10850_46) + << "AppendStopSendingFrame failed: " << detailed_error(); + return 0; + } + break; + case PATH_CHALLENGE_FRAME: + if (!AppendPathChallengeFrame(frame.path_challenge_frame, writer)) { + QUIC_BUG(quic_bug_10850_47) + << "AppendPathChallengeFrame failed: " << detailed_error(); + return 0; + } + break; + case PATH_RESPONSE_FRAME: + if (!AppendPathResponseFrame(frame.path_response_frame, writer)) { + QUIC_BUG(quic_bug_10850_48) + << "AppendPathResponseFrame failed: " << detailed_error(); + return 0; + } + break; + case MESSAGE_FRAME: + if (!AppendMessageFrameAndTypeByte(*frame.message_frame, + last_frame_in_packet, writer)) { + QUIC_BUG(quic_bug_10850_49) + << "AppendMessageFrame failed: " << detailed_error(); + return 0; + } + break; + case CRYPTO_FRAME: + if (!AppendCryptoFrame(*frame.crypto_frame, writer)) { + QUIC_BUG(quic_bug_10850_50) + << "AppendCryptoFrame failed: " << detailed_error(); + return 0; + } + break; + case HANDSHAKE_DONE_FRAME: + // HANDSHAKE_DONE has no payload. + break; + case ACK_FREQUENCY_FRAME: + if (!AppendAckFrequencyFrame(*frame.ack_frequency_frame, writer)) { + QUIC_BUG(quic_bug_10850_51) + << "AppendAckFrequencyFrame failed: " << detailed_error(); + return 0; + } + break; + default: + set_detailed_error("Tried to append unknown frame type."); + RaiseError(QUIC_INVALID_FRAME_DATA); + QUIC_BUG(quic_bug_10850_52) + << "QUIC_INVALID_FRAME_DATA: " << frame.type; + return 0; + } + ++i; + } + + return writer->length(); +} + +// static +std::unique_ptr QuicFramer::BuildPublicResetPacket( + const QuicPublicResetPacket& packet) { + CryptoHandshakeMessage reset; + reset.set_tag(kPRST); + reset.SetValue(kRNON, packet.nonce_proof); + if (packet.client_address.host().address_family() != + IpAddressFamily::IP_UNSPEC) { + // packet.client_address is non-empty. + QuicSocketAddressCoder address_coder(packet.client_address); + std::string serialized_address = address_coder.Encode(); + if (serialized_address.empty()) { + return nullptr; + } + reset.SetStringPiece(kCADR, serialized_address); + } + if (!packet.endpoint_id.empty()) { + reset.SetStringPiece(kEPID, packet.endpoint_id); + } + const QuicData& reset_serialized = reset.GetSerialized(); + + size_t len = kPublicFlagsSize + packet.connection_id.length() + + reset_serialized.length(); + std::unique_ptr buffer(new char[len]); + QuicDataWriter writer(len, buffer.get()); + + uint8_t flags = static_cast(PACKET_PUBLIC_FLAGS_RST | + PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID); + // This hack makes post-v33 public reset packet look like pre-v33 packets. + flags |= static_cast(PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID_OLD); + if (!writer.WriteUInt8(flags)) { + return nullptr; + } + + if (!writer.WriteConnectionId(packet.connection_id)) { + return nullptr; + } + + if (!writer.WriteBytes(reset_serialized.data(), reset_serialized.length())) { + return nullptr; + } + + return std::make_unique(buffer.release(), len, true); +} + +// static +size_t QuicFramer::GetMinStatelessResetPacketLength() { + // 5 bytes (40 bits) = 2 Fixed Bits (01) + 38 Unpredictable bits + return 5 + kStatelessResetTokenLength; +} + +// static +std::unique_ptr QuicFramer::BuildIetfStatelessResetPacket( + QuicConnectionId connection_id, size_t received_packet_length, + StatelessResetToken stateless_reset_token) { + return BuildIetfStatelessResetPacket(connection_id, received_packet_length, + stateless_reset_token, + QuicRandom::GetInstance()); +} + +// static +std::unique_ptr QuicFramer::BuildIetfStatelessResetPacket( + QuicConnectionId /*connection_id*/, size_t received_packet_length, + StatelessResetToken stateless_reset_token, QuicRandom* random) { + QUIC_DVLOG(1) << "Building IETF stateless reset packet."; + if (received_packet_length <= GetMinStatelessResetPacketLength()) { + QUICHE_DLOG(ERROR) + << "Tried to build stateless reset packet with received packet " + "length " + << received_packet_length; + return nullptr; + } + // To ensure stateless reset is indistinguishable from a valid packet, + // include the max connection ID length. + size_t len = std::min(received_packet_length - 1, + GetMinStatelessResetPacketLength() + 1 + + kQuicMaxConnectionIdWithLengthPrefixLength); + std::unique_ptr buffer(new char[len]); + QuicDataWriter writer(len, buffer.get()); + // Append random bytes. This randomness only exists to prevent middleboxes + // from comparing the entire packet to a known value. Therefore it has no + // cryptographic use, and does not need a secure cryptographic pseudo-random + // number generator. It's therefore safe to use WriteInsecureRandomBytes. + const size_t random_bytes_size = len - kStatelessResetTokenLength; + if (!writer.WriteInsecureRandomBytes(random, random_bytes_size)) { + QUIC_BUG(362045737_2) << "Failed to append random bytes of length: " + << random_bytes_size; + return nullptr; + } + // Change first 2 fixed bits to 01. + buffer[0] &= ~FLAGS_LONG_HEADER; + buffer[0] |= FLAGS_FIXED_BIT; + + // Append stateless reset token. + if (!writer.WriteBytes(&stateless_reset_token, + sizeof(stateless_reset_token))) { + QUIC_BUG(362045737_3) << "Failed to write stateless reset token"; + return nullptr; + } + return std::make_unique(buffer.release(), len, + /*owns_buffer=*/true); +} + +// static +std::unique_ptr QuicFramer::BuildVersionNegotiationPacket( + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, + bool use_length_prefix, const ParsedQuicVersionVector& versions) { + QUIC_CODE_COUNT(quic_build_version_negotiation); + if (use_length_prefix) { + QUICHE_DCHECK(ietf_quic); + QUIC_CODE_COUNT(quic_build_version_negotiation_ietf); + } else if (ietf_quic) { + QUIC_CODE_COUNT(quic_build_version_negotiation_old_ietf); + } else { + QUIC_CODE_COUNT(quic_build_version_negotiation_old_gquic); + } + ParsedQuicVersionVector wire_versions = versions; + // Add a version reserved for negotiation as suggested by the + // "Using Reserved Versions" section of draft-ietf-quic-transport. + if (wire_versions.empty()) { + // Ensure that version negotiation packets we send have at least two + // versions. This guarantees that, under all circumstances, all QUIC + // packets we send are at least 14 bytes long. + wire_versions = {QuicVersionReservedForNegotiation(), + QuicVersionReservedForNegotiation()}; + } else { + // This is not uniformely distributed but is acceptable since no security + // depends on this randomness. + size_t version_index = 0; + const bool disable_randomness = + GetQuicFlag(quic_disable_version_negotiation_grease_randomness); + if (!disable_randomness) { + version_index = + QuicRandom::GetInstance()->RandUint64() % (wire_versions.size() + 1); + } + wire_versions.insert(wire_versions.begin() + version_index, + QuicVersionReservedForNegotiation()); + } + if (ietf_quic) { + return BuildIetfVersionNegotiationPacket( + use_length_prefix, server_connection_id, client_connection_id, + wire_versions); + } + + // The GQUIC encoding does not support encoding client connection IDs. + QUICHE_DCHECK(client_connection_id.IsEmpty()); + // The GQUIC encoding does not support length-prefixed connection IDs. + QUICHE_DCHECK(!use_length_prefix); + + QUICHE_DCHECK(!wire_versions.empty()); + size_t len = kPublicFlagsSize + server_connection_id.length() + + wire_versions.size() * kQuicVersionSize; + std::unique_ptr buffer(new char[len]); + QuicDataWriter writer(len, buffer.get()); + + uint8_t flags = static_cast( + PACKET_PUBLIC_FLAGS_VERSION | PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID | + PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID_OLD); + if (!writer.WriteUInt8(flags)) { + return nullptr; + } + + if (!writer.WriteConnectionId(server_connection_id)) { + return nullptr; + } + + for (const ParsedQuicVersion& version : wire_versions) { + if (!writer.WriteUInt32(CreateQuicVersionLabel(version))) { + return nullptr; + } + } + + return std::make_unique(buffer.release(), len, true); +} + +// static +std::unique_ptr +QuicFramer::BuildIetfVersionNegotiationPacket( + bool use_length_prefix, QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, + const ParsedQuicVersionVector& versions) { + QUIC_DVLOG(1) << "Building IETF version negotiation packet with" + << (use_length_prefix ? "" : "out") + << " length prefix, server_connection_id " + << server_connection_id << " client_connection_id " + << client_connection_id << " versions " + << ParsedQuicVersionVectorToString(versions); + QUICHE_DCHECK(!versions.empty()); + size_t len = kPacketHeaderTypeSize + kConnectionIdLengthSize + + client_connection_id.length() + server_connection_id.length() + + (versions.size() + 1) * kQuicVersionSize; + if (use_length_prefix) { + // When using length-prefixed connection IDs, packets carry two lengths + // instead of one. + len += kConnectionIdLengthSize; + } + std::unique_ptr buffer(new char[len]); + QuicDataWriter writer(len, buffer.get()); + + // TODO(fayang): Randomly select a value for the type. + uint8_t type = static_cast(FLAGS_LONG_HEADER | FLAGS_FIXED_BIT); + if (!writer.WriteUInt8(type)) { + return nullptr; + } + + if (!writer.WriteUInt32(0)) { + return nullptr; + } + + if (!AppendIetfConnectionIds(true, use_length_prefix, client_connection_id, + server_connection_id, &writer)) { + return nullptr; + } + + for (const ParsedQuicVersion& version : versions) { + if (!writer.WriteUInt32(CreateQuicVersionLabel(version))) { + return nullptr; + } + } + + return std::make_unique(buffer.release(), len, true); +} + +bool QuicFramer::ProcessPacket(const QuicEncryptedPacket& packet) { + QUICHE_DCHECK(!is_processing_packet_) << ENDPOINT << "Nested ProcessPacket"; + is_processing_packet_ = true; + bool result = ProcessPacketInternal(packet); + is_processing_packet_ = false; + return result; +} + +bool QuicFramer::ProcessPacketInternal(const QuicEncryptedPacket& packet) { + QuicDataReader reader(packet.data(), packet.length()); + + bool packet_has_ietf_packet_header = false; + if (infer_packet_header_type_from_version_) { + packet_has_ietf_packet_header = version_.HasIetfInvariantHeader(); + } else if (!reader.IsDoneReading()) { + uint8_t type = reader.PeekByte(); + packet_has_ietf_packet_header = QuicUtils::IsIetfPacketHeader(type); + } + if (packet_has_ietf_packet_header) { + QUIC_DVLOG(1) << ENDPOINT << "Processing IETF QUIC packet."; + } + + visitor_->OnPacket(); + + QuicPacketHeader header; + if (!ProcessPublicHeader(&reader, packet_has_ietf_packet_header, &header)) { + QUICHE_DCHECK_NE("", detailed_error_); + QUIC_DVLOG(1) << ENDPOINT << "Unable to process public header. Error: " + << detailed_error_; + QUICHE_DCHECK_NE("", detailed_error_); + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PUBLIC_HEADER); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + + if (!visitor_->OnUnauthenticatedPublicHeader(header)) { + // The visitor suppresses further processing of the packet. + return true; + } + + if (IsVersionNegotiation(header, packet_has_ietf_packet_header)) { + if (perspective_ == Perspective::IS_CLIENT) { + QUIC_DVLOG(1) << "Client received version negotiation packet"; + return ProcessVersionNegotiationPacket(&reader, header); + } else { + QUIC_DLOG(ERROR) << "Server received version negotiation packet"; + set_detailed_error("Server received version negotiation packet."); + return RaiseError(QUIC_INVALID_VERSION_NEGOTIATION_PACKET); + } + } + + if (header.version_flag && header.version != version_) { + if (perspective_ == Perspective::IS_SERVER) { + if (!visitor_->OnProtocolVersionMismatch(header.version)) { + RecordDroppedPacketReason(DroppedPacketReason::VERSION_MISMATCH); + return true; + } + } else { + // A client received a packet of a different version but that packet is + // not a version negotiation packet. It is therefore invalid and dropped. + QUIC_DLOG(ERROR) << "Client received unexpected version " + << ParsedQuicVersionToString(header.version) + << " instead of " << ParsedQuicVersionToString(version_); + set_detailed_error("Client received unexpected version."); + return RaiseError(QUIC_PACKET_WRONG_VERSION); + } + } + + bool rv; + if (header.long_packet_type == RETRY) { + rv = ProcessRetryPacket(&reader, header); + } else if (header.reset_flag) { + rv = ProcessPublicResetPacket(&reader, header); + } else if (packet.length() <= kMaxIncomingPacketSize) { + // The optimized decryption algorithm implementations run faster when + // operating on aligned memory. + ABSL_CACHELINE_ALIGNED char buffer[kMaxIncomingPacketSize]; + if (packet_has_ietf_packet_header) { + rv = ProcessIetfDataPacket(&reader, &header, packet, buffer, + ABSL_ARRAYSIZE(buffer)); + } else { + rv = ProcessDataPacket(&reader, &header, packet, buffer, + ABSL_ARRAYSIZE(buffer)); + } + } else { + std::unique_ptr large_buffer(new char[packet.length()]); + if (packet_has_ietf_packet_header) { + rv = ProcessIetfDataPacket(&reader, &header, packet, large_buffer.get(), + packet.length()); + } else { + rv = ProcessDataPacket(&reader, &header, packet, large_buffer.get(), + packet.length()); + } + QUIC_BUG_IF(quic_bug_10850_53, rv) + << "QUIC should never successfully process packets larger" + << "than kMaxIncomingPacketSize. packet size:" << packet.length(); + } + return rv; +} + +bool QuicFramer::ProcessVersionNegotiationPacket( + QuicDataReader* reader, const QuicPacketHeader& header) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + + QuicVersionNegotiationPacket packet( + GetServerConnectionIdAsRecipient(header, perspective_)); + // Try reading at least once to raise error if the packet is invalid. + do { + QuicVersionLabel version_label; + if (!ProcessVersionLabel(reader, &version_label)) { + set_detailed_error("Unable to read supported version in negotiation."); + RecordDroppedPacketReason( + DroppedPacketReason::INVALID_VERSION_NEGOTIATION_PACKET); + return RaiseError(QUIC_INVALID_VERSION_NEGOTIATION_PACKET); + } + ParsedQuicVersion parsed_version = ParseQuicVersionLabel(version_label); + if (parsed_version != UnsupportedQuicVersion()) { + packet.versions.push_back(parsed_version); + } + } while (!reader->IsDoneReading()); + + QUIC_DLOG(INFO) << ENDPOINT << "parsed version negotiation: " + << ParsedQuicVersionVectorToString(packet.versions); + + visitor_->OnVersionNegotiationPacket(packet); + return true; +} + +bool QuicFramer::ProcessRetryPacket(QuicDataReader* reader, + const QuicPacketHeader& header) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + if (drop_incoming_retry_packets_) { + QUIC_DLOG(INFO) << "Ignoring received RETRY packet"; + return true; + } + + if (version_.UsesTls()) { + QUICHE_DCHECK(version_.HasLengthPrefixedConnectionIds()) << version_; + const size_t bytes_remaining = reader->BytesRemaining(); + if (bytes_remaining <= kRetryIntegrityTagLength) { + set_detailed_error("Retry packet too short to parse integrity tag."); + return false; + } + const size_t retry_token_length = + bytes_remaining - kRetryIntegrityTagLength; + QUICHE_DCHECK_GT(retry_token_length, 0u); + absl::string_view retry_token; + if (!reader->ReadStringPiece(&retry_token, retry_token_length)) { + set_detailed_error("Failed to read retry token."); + return false; + } + absl::string_view retry_without_tag = reader->PreviouslyReadPayload(); + absl::string_view integrity_tag = reader->ReadRemainingPayload(); + QUICHE_DCHECK_EQ(integrity_tag.length(), kRetryIntegrityTagLength); + visitor_->OnRetryPacket(EmptyQuicConnectionId(), + header.source_connection_id, retry_token, + integrity_tag, retry_without_tag); + return true; + } + + QuicConnectionId original_destination_connection_id; + if (version_.HasLengthPrefixedConnectionIds()) { + // Parse Original Destination Connection ID. + if (!reader->ReadLengthPrefixedConnectionId( + &original_destination_connection_id)) { + set_detailed_error("Unable to read Original Destination ConnectionId."); + return false; + } + } else { + // Parse Original Destination Connection ID Length. + uint8_t odcil = header.type_byte & 0xf; + if (odcil != 0) { + odcil += kConnectionIdLengthAdjustment; + } + + // Parse Original Destination Connection ID. + if (!reader->ReadConnectionId(&original_destination_connection_id, odcil)) { + set_detailed_error("Unable to read Original Destination ConnectionId."); + return false; + } + } + + if (!QuicUtils::IsConnectionIdValidForVersion( + original_destination_connection_id, transport_version())) { + set_detailed_error( + "Received Original Destination ConnectionId with invalid length."); + return false; + } + + absl::string_view retry_token = reader->ReadRemainingPayload(); + visitor_->OnRetryPacket(original_destination_connection_id, + header.source_connection_id, retry_token, + /*retry_integrity_tag=*/absl::string_view(), + /*retry_without_tag=*/absl::string_view()); + return true; +} + +// Seeks the current packet to check for a coalesced packet at the end. +// If the IETF length field only spans part of the outer packet, +// then there is a coalesced packet after this one. +void QuicFramer::MaybeProcessCoalescedPacket( + const QuicDataReader& encrypted_reader, uint64_t remaining_bytes_length, + const QuicPacketHeader& header) { + if (header.remaining_packet_length >= remaining_bytes_length) { + // There is no coalesced packet. + return; + } + + absl::string_view remaining_data = encrypted_reader.PeekRemainingPayload(); + QUICHE_DCHECK_EQ(remaining_data.length(), remaining_bytes_length); + + const char* coalesced_data = + remaining_data.data() + header.remaining_packet_length; + uint64_t coalesced_data_length = + remaining_bytes_length - header.remaining_packet_length; + QuicDataReader coalesced_reader(coalesced_data, coalesced_data_length); + + QuicPacketHeader coalesced_header; + if (!ProcessIetfPacketHeader(&coalesced_reader, &coalesced_header)) { + // Some implementations pad their INITIAL packets by sending random invalid + // data after the INITIAL, and that is allowed by the specification. If we + // fail to parse a subsequent coalesced packet, simply ignore it. + QUIC_DLOG(INFO) << ENDPOINT + << "Failed to parse received coalesced header of length " + << coalesced_data_length + << " with error: " << detailed_error_ << ": " + << absl::BytesToHexString(absl::string_view( + coalesced_data, coalesced_data_length)) + << " previous header was " << header; + return; + } + + if (coalesced_header.destination_connection_id != + header.destination_connection_id) { + // Drop coalesced packets with mismatched connection IDs. + QUIC_DLOG(INFO) << ENDPOINT << "Received mismatched coalesced header " + << coalesced_header << " previous header was " << header; + QUIC_CODE_COUNT( + quic_received_coalesced_packets_with_mismatched_connection_id); + return; + } + + QuicEncryptedPacket coalesced_packet(coalesced_data, coalesced_data_length, + /*owns_buffer=*/false); + visitor_->OnCoalescedPacket(coalesced_packet); +} + +bool QuicFramer::MaybeProcessIetfLength(QuicDataReader* encrypted_reader, + QuicPacketHeader* header) { + if (!QuicVersionHasLongHeaderLengths(header->version.transport_version) || + header->form != IETF_QUIC_LONG_HEADER_PACKET || + (header->long_packet_type != INITIAL && + header->long_packet_type != HANDSHAKE && + header->long_packet_type != ZERO_RTT_PROTECTED)) { + return true; + } + header->length_length = encrypted_reader->PeekVarInt62Length(); + if (!encrypted_reader->ReadVarInt62(&header->remaining_packet_length)) { + set_detailed_error("Unable to read long header payload length."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + uint64_t remaining_bytes_length = encrypted_reader->BytesRemaining(); + if (header->remaining_packet_length > remaining_bytes_length) { + set_detailed_error("Long header payload length longer than packet."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + + MaybeProcessCoalescedPacket(*encrypted_reader, remaining_bytes_length, + *header); + + if (!encrypted_reader->TruncateRemaining(header->remaining_packet_length)) { + set_detailed_error("Length TruncateRemaining failed."); + QUIC_BUG(quic_bug_10850_54) << "Length TruncateRemaining failed."; + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + return true; +} + +bool QuicFramer::ProcessIetfDataPacket(QuicDataReader* encrypted_reader, + QuicPacketHeader* header, + const QuicEncryptedPacket& packet, + char* decrypted_buffer, + size_t buffer_length) { + QUICHE_DCHECK_NE(GOOGLE_QUIC_PACKET, header->form); + QUICHE_DCHECK(!header->has_possible_stateless_reset_token); + header->length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + header->remaining_packet_length = 0; + if (header->form == IETF_QUIC_SHORT_HEADER_PACKET && + perspective_ == Perspective::IS_CLIENT) { + // Peek possible stateless reset token. Will only be used on decryption + // failure. + absl::string_view remaining = encrypted_reader->PeekRemainingPayload(); + if (remaining.length() >= sizeof(header->possible_stateless_reset_token)) { + header->has_possible_stateless_reset_token = true; + memcpy(&header->possible_stateless_reset_token, + &remaining.data()[remaining.length() - + sizeof(header->possible_stateless_reset_token)], + sizeof(header->possible_stateless_reset_token)); + } + } + + if (!MaybeProcessIetfLength(encrypted_reader, header)) { + return false; + } + + absl::string_view associated_data; + std::vector ad_storage; + QuicPacketNumber base_packet_number; + if (header->form == IETF_QUIC_SHORT_HEADER_PACKET || + header->long_packet_type != VERSION_NEGOTIATION) { + QUICHE_DCHECK(header->form == IETF_QUIC_SHORT_HEADER_PACKET || + header->long_packet_type == INITIAL || + header->long_packet_type == HANDSHAKE || + header->long_packet_type == ZERO_RTT_PROTECTED); + // Process packet number. + if (supports_multiple_packet_number_spaces_) { + PacketNumberSpace pn_space = GetPacketNumberSpace(*header); + if (pn_space == NUM_PACKET_NUMBER_SPACES) { + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + base_packet_number = largest_decrypted_packet_numbers_[pn_space]; + } else { + base_packet_number = largest_packet_number_; + } + uint64_t full_packet_number; + bool hp_removal_failed = false; + if (version_.HasHeaderProtection()) { + if (!RemoveHeaderProtection(encrypted_reader, packet, header, + &full_packet_number, &ad_storage)) { + hp_removal_failed = true; + } + associated_data = absl::string_view(ad_storage.data(), ad_storage.size()); + } else if (!ProcessAndCalculatePacketNumber( + encrypted_reader, header->packet_number_length, + base_packet_number, &full_packet_number)) { + set_detailed_error("Unable to read packet number."); + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + + if (hp_removal_failed || + !IsValidFullPacketNumber(full_packet_number, version())) { + if (IsIetfStatelessResetPacket(*header)) { + // This is a stateless reset packet. + QuicIetfStatelessResetPacket reset_packet( + *header, header->possible_stateless_reset_token); + visitor_->OnAuthenticatedIetfStatelessResetPacket(reset_packet); + return true; + } + if (hp_removal_failed) { + const EncryptionLevel decryption_level = GetEncryptionLevel(*header); + const bool has_decryption_key = decrypter_[decryption_level] != nullptr; + visitor_->OnUndecryptablePacket( + QuicEncryptedPacket(encrypted_reader->FullPayload()), + decryption_level, has_decryption_key); + RecordDroppedPacketReason(DroppedPacketReason::DECRYPTION_FAILURE); + set_detailed_error(absl::StrCat( + "Unable to decrypt ", EncryptionLevelToString(decryption_level), + " header protection", has_decryption_key ? "" : " (missing key)", + ".")); + return RaiseError(QUIC_DECRYPTION_FAILURE); + } + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER); + set_detailed_error("packet numbers cannot be 0."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + header->packet_number = QuicPacketNumber(full_packet_number); + } + + // A nonce should only present in SHLO from the server to the client when + // using QUIC crypto. + if (header->form == IETF_QUIC_LONG_HEADER_PACKET && + header->long_packet_type == ZERO_RTT_PROTECTED && + perspective_ == Perspective::IS_CLIENT && + version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + if (!encrypted_reader->ReadBytes( + reinterpret_cast(last_nonce_.data()), + last_nonce_.size())) { + set_detailed_error("Unable to read nonce."); + RecordDroppedPacketReason( + DroppedPacketReason::INVALID_DIVERSIFICATION_NONCE); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + + header->nonce = &last_nonce_; + } else { + header->nonce = nullptr; + } + + if (!visitor_->OnUnauthenticatedHeader(*header)) { + set_detailed_error( + "Visitor asked to stop processing of unauthenticated header."); + return false; + } + + absl::string_view encrypted = encrypted_reader->ReadRemainingPayload(); + if (!version_.HasHeaderProtection()) { + associated_data = GetAssociatedDataFromEncryptedPacket( + version_.transport_version, packet, + GetIncludedDestinationConnectionIdLength(*header), + GetIncludedSourceConnectionIdLength(*header), header->version_flag, + header->nonce != nullptr, header->packet_number_length, + header->retry_token_length_length, header->retry_token.length(), + header->length_length); + } + + size_t decrypted_length = 0; + EncryptionLevel decrypted_level; + if (!DecryptPayload(packet.length(), encrypted, associated_data, *header, + decrypted_buffer, buffer_length, &decrypted_length, + &decrypted_level)) { + if (IsIetfStatelessResetPacket(*header)) { + // This is a stateless reset packet. + QuicIetfStatelessResetPacket reset_packet( + *header, header->possible_stateless_reset_token); + visitor_->OnAuthenticatedIetfStatelessResetPacket(reset_packet); + return true; + } + const EncryptionLevel decryption_level = GetEncryptionLevel(*header); + const bool has_decryption_key = version_.KnowsWhichDecrypterToUse() && + decrypter_[decryption_level] != nullptr; + visitor_->OnUndecryptablePacket( + QuicEncryptedPacket(encrypted_reader->FullPayload()), decryption_level, + has_decryption_key); + set_detailed_error(absl::StrCat( + "Unable to decrypt ", EncryptionLevelToString(decryption_level), + " payload with reconstructed packet number ", + header->packet_number.ToString(), " (largest decrypted was ", + base_packet_number.ToString(), ")", + has_decryption_key || !version_.KnowsWhichDecrypterToUse() + ? "" + : " (missing key)", + ".")); + RecordDroppedPacketReason(DroppedPacketReason::DECRYPTION_FAILURE); + return RaiseError(QUIC_DECRYPTION_FAILURE); + } + QuicDataReader reader(decrypted_buffer, decrypted_length); + + // Remember decrypted_payload in the current connection context until the end + // of this function. + auto* connection_context = QuicConnectionContext::Current(); + if (connection_context != nullptr) { + connection_context->process_packet_context.decrypted_payload = + reader.FullPayload(); + connection_context->process_packet_context.current_frame_offset = 0; + } + auto clear_decrypted_payload = absl::MakeCleanup([&]() { + if (connection_context != nullptr) { + connection_context->process_packet_context.decrypted_payload = + absl::string_view(); + } + }); + + // Update the largest packet number after we have decrypted the packet + // so we are confident is not attacker controlled. + if (supports_multiple_packet_number_spaces_) { + largest_decrypted_packet_numbers_[QuicUtils::GetPacketNumberSpace( + decrypted_level)] + .UpdateMax(header->packet_number); + } else { + largest_packet_number_.UpdateMax(header->packet_number); + } + + if (!visitor_->OnPacketHeader(*header)) { + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER); + // The visitor suppresses further processing of the packet. + return true; + } + + if (packet.length() > kMaxIncomingPacketSize) { + set_detailed_error("Packet too large."); + return RaiseError(QUIC_PACKET_TOO_LARGE); + } + + // Handle the payload. + if (VersionHasIetfQuicFrames(version_.transport_version)) { + current_received_frame_type_ = 0; + previously_received_frame_type_ = 0; + if (!ProcessIetfFrameData(&reader, *header, decrypted_level)) { + current_received_frame_type_ = 0; + previously_received_frame_type_ = 0; + QUICHE_DCHECK_NE(QUIC_NO_ERROR, + error_); // ProcessIetfFrameData sets the error. + QUICHE_DCHECK_NE("", detailed_error_); + QUIC_DLOG(WARNING) << ENDPOINT << "Unable to process frame data. Error: " + << detailed_error_; + return false; + } + current_received_frame_type_ = 0; + previously_received_frame_type_ = 0; + } else { + if (!ProcessFrameData(&reader, *header)) { + QUICHE_DCHECK_NE(QUIC_NO_ERROR, + error_); // ProcessFrameData sets the error. + QUICHE_DCHECK_NE("", detailed_error_); + QUIC_DLOG(WARNING) << ENDPOINT << "Unable to process frame data. Error: " + << detailed_error_; + return false; + } + } + + visitor_->OnPacketComplete(); + return true; +} + +bool QuicFramer::ProcessDataPacket(QuicDataReader* encrypted_reader, + QuicPacketHeader* header, + const QuicEncryptedPacket& packet, + char* decrypted_buffer, + size_t buffer_length) { + if (!ProcessUnauthenticatedHeader(encrypted_reader, header)) { + QUICHE_DCHECK_NE("", detailed_error_); + QUIC_DVLOG(1) + << ENDPOINT + << "Unable to process packet header. Stopping parsing. Error: " + << detailed_error_; + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PACKET_NUMBER); + return false; + } + + absl::string_view encrypted = encrypted_reader->ReadRemainingPayload(); + absl::string_view associated_data = GetAssociatedDataFromEncryptedPacket( + version_.transport_version, packet, + GetIncludedDestinationConnectionIdLength(*header), + GetIncludedSourceConnectionIdLength(*header), header->version_flag, + header->nonce != nullptr, header->packet_number_length, + header->retry_token_length_length, header->retry_token.length(), + header->length_length); + + size_t decrypted_length = 0; + EncryptionLevel decrypted_level; + if (!DecryptPayload(packet.length(), encrypted, associated_data, *header, + decrypted_buffer, buffer_length, &decrypted_length, + &decrypted_level)) { + const EncryptionLevel decryption_level = decrypter_level_; + // This version uses trial decryption so we always report to our visitor + // that we are not certain we have the correct decryption key. + const bool has_decryption_key = false; + visitor_->OnUndecryptablePacket( + QuicEncryptedPacket(encrypted_reader->FullPayload()), decryption_level, + has_decryption_key); + RecordDroppedPacketReason(DroppedPacketReason::DECRYPTION_FAILURE); + set_detailed_error(absl::StrCat("Unable to decrypt ", + EncryptionLevelToString(decryption_level), + " payload.")); + return RaiseError(QUIC_DECRYPTION_FAILURE); + } + + QuicDataReader reader(decrypted_buffer, decrypted_length); + + // Update the largest packet number after we have decrypted the packet + // so we are confident is not attacker controlled. + if (supports_multiple_packet_number_spaces_) { + largest_decrypted_packet_numbers_[QuicUtils::GetPacketNumberSpace( + decrypted_level)] + .UpdateMax(header->packet_number); + } else { + largest_packet_number_.UpdateMax(header->packet_number); + } + + if (!visitor_->OnPacketHeader(*header)) { + // The visitor suppresses further processing of the packet. + return true; + } + + if (packet.length() > kMaxIncomingPacketSize) { + set_detailed_error("Packet too large."); + return RaiseError(QUIC_PACKET_TOO_LARGE); + } + + // Handle the payload. + if (!ProcessFrameData(&reader, *header)) { + QUICHE_DCHECK_NE(QUIC_NO_ERROR, + error_); // ProcessFrameData sets the error. + QUICHE_DCHECK_NE("", detailed_error_); + QUIC_DLOG(WARNING) << ENDPOINT << "Unable to process frame data. Error: " + << detailed_error_; + return false; + } + + visitor_->OnPacketComplete(); + return true; +} + +bool QuicFramer::ProcessPublicResetPacket(QuicDataReader* reader, + const QuicPacketHeader& header) { + QuicPublicResetPacket packet( + GetServerConnectionIdAsRecipient(header, perspective_)); + + std::unique_ptr reset( + CryptoFramer::ParseMessage(reader->ReadRemainingPayload())); + if (!reset) { + set_detailed_error("Unable to read reset message."); + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PUBLIC_RESET_PACKET); + return RaiseError(QUIC_INVALID_PUBLIC_RST_PACKET); + } + if (reset->tag() != kPRST) { + set_detailed_error("Incorrect message tag."); + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PUBLIC_RESET_PACKET); + return RaiseError(QUIC_INVALID_PUBLIC_RST_PACKET); + } + + if (reset->GetUint64(kRNON, &packet.nonce_proof) != QUIC_NO_ERROR) { + set_detailed_error("Unable to read nonce proof."); + RecordDroppedPacketReason(DroppedPacketReason::INVALID_PUBLIC_RESET_PACKET); + return RaiseError(QUIC_INVALID_PUBLIC_RST_PACKET); + } + // TODO(satyamshekhar): validate nonce to protect against DoS. + + absl::string_view address; + if (reset->GetStringPiece(kCADR, &address)) { + QuicSocketAddressCoder address_coder; + if (address_coder.Decode(address.data(), address.length())) { + packet.client_address = + QuicSocketAddress(address_coder.ip(), address_coder.port()); + } + } + + absl::string_view endpoint_id; + if (perspective_ == Perspective::IS_CLIENT && + reset->GetStringPiece(kEPID, &endpoint_id)) { + packet.endpoint_id = std::string(endpoint_id); + packet.endpoint_id += '\0'; + } + + visitor_->OnPublicResetPacket(packet); + return true; +} + +bool QuicFramer::IsIetfStatelessResetPacket( + const QuicPacketHeader& header) const { + QUIC_BUG_IF(quic_bug_12975_3, header.has_possible_stateless_reset_token && + perspective_ != Perspective::IS_CLIENT) + << "has_possible_stateless_reset_token can only be true at client side."; + return header.form == IETF_QUIC_SHORT_HEADER_PACKET && + header.has_possible_stateless_reset_token && + visitor_->IsValidStatelessResetToken( + header.possible_stateless_reset_token); +} + +bool QuicFramer::HasEncrypterOfEncryptionLevel(EncryptionLevel level) const { + return encrypter_[level] != nullptr; +} + +bool QuicFramer::HasDecrypterOfEncryptionLevel(EncryptionLevel level) const { + return decrypter_[level] != nullptr; +} + +bool QuicFramer::HasAnEncrypterForSpace(PacketNumberSpace space) const { + switch (space) { + case INITIAL_DATA: + return HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL); + case HANDSHAKE_DATA: + return HasEncrypterOfEncryptionLevel(ENCRYPTION_HANDSHAKE); + case APPLICATION_DATA: + return HasEncrypterOfEncryptionLevel(ENCRYPTION_ZERO_RTT) || + HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + case NUM_PACKET_NUMBER_SPACES: + break; + } + QUIC_BUG(quic_bug_10850_55) + << ENDPOINT + << "Try to send data of space: " << PacketNumberSpaceToString(space); + return false; +} + +EncryptionLevel QuicFramer::GetEncryptionLevelToSendApplicationData() const { + if (!HasAnEncrypterForSpace(APPLICATION_DATA)) { + QUIC_BUG(quic_bug_12975_4) + << "Tried to get encryption level to send application data with no " + "encrypter available."; + return NUM_ENCRYPTION_LEVELS; + } + if (HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE)) { + return ENCRYPTION_FORWARD_SECURE; + } + QUICHE_DCHECK(HasEncrypterOfEncryptionLevel(ENCRYPTION_ZERO_RTT)); + return ENCRYPTION_ZERO_RTT; +} + +bool QuicFramer::AppendPacketHeader(const QuicPacketHeader& header, + QuicDataWriter* writer, + size_t* length_field_offset) { + if (version().HasIetfInvariantHeader()) { + return AppendIetfPacketHeader(header, writer, length_field_offset); + } + QUIC_DVLOG(1) << ENDPOINT << "Appending header: " << header; + uint8_t public_flags = 0; + if (header.reset_flag) { + public_flags |= PACKET_PUBLIC_FLAGS_RST; + } + if (header.version_flag) { + public_flags |= PACKET_PUBLIC_FLAGS_VERSION; + } + + public_flags |= GetPacketNumberFlags(header.packet_number_length) + << kPublicHeaderSequenceNumberShift; + + if (header.nonce != nullptr) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, perspective_); + public_flags |= PACKET_PUBLIC_FLAGS_NONCE; + } + + QuicConnectionId server_connection_id = + GetServerConnectionIdAsSender(header, perspective_); + QuicConnectionIdIncluded server_connection_id_included = + GetServerConnectionIdIncludedAsSender(header, perspective_); + QUICHE_DCHECK_EQ(CONNECTION_ID_ABSENT, + GetClientConnectionIdIncludedAsSender(header, perspective_)) + << ENDPOINT << ParsedQuicVersionToString(version_) + << " invalid header: " << header; + + switch (server_connection_id_included) { + case CONNECTION_ID_ABSENT: + if (!writer->WriteUInt8(public_flags | + PACKET_PUBLIC_FLAGS_0BYTE_CONNECTION_ID)) { + return false; + } + break; + case CONNECTION_ID_PRESENT: + QUIC_BUG_IF(quic_bug_12975_5, + !QuicUtils::IsConnectionIdValidForVersion( + server_connection_id, transport_version())) + << "AppendPacketHeader: attempted to use connection ID " + << server_connection_id << " which is invalid with version " + << version(); + + public_flags |= PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID; + if (perspective_ == Perspective::IS_CLIENT) { + public_flags |= PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID_OLD; + } + if (!writer->WriteUInt8(public_flags) || + !writer->WriteConnectionId(server_connection_id)) { + return false; + } + break; + } + last_serialized_server_connection_id_ = server_connection_id; + + if (header.version_flag) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + QuicVersionLabel version_label = CreateQuicVersionLabel(version_); + if (!writer->WriteUInt32(version_label)) { + return false; + } + + QUIC_DVLOG(1) << ENDPOINT << "label = '" + << QuicVersionLabelToString(version_label) << "'"; + } + + if (header.nonce != nullptr && + !writer->WriteBytes(header.nonce, kDiversificationNonceSize)) { + return false; + } + + if (!AppendPacketNumber(header.packet_number_length, header.packet_number, + writer)) { + return false; + } + + return true; +} + +bool QuicFramer::AppendIetfHeaderTypeByte(const QuicPacketHeader& header, + QuicDataWriter* writer) { + uint8_t type = 0; + if (header.version_flag) { + type = static_cast( + FLAGS_LONG_HEADER | FLAGS_FIXED_BIT | + LongHeaderTypeToOnWireValue(header.long_packet_type, version_) | + PacketNumberLengthToOnWireValue(header.packet_number_length)); + } else { + type = static_cast( + FLAGS_FIXED_BIT | (current_key_phase_bit_ ? FLAGS_KEY_PHASE_BIT : 0) | + PacketNumberLengthToOnWireValue(header.packet_number_length)); + } + return writer->WriteUInt8(type); +} + +bool QuicFramer::AppendIetfPacketHeader(const QuicPacketHeader& header, + QuicDataWriter* writer, + size_t* length_field_offset) { + QUIC_DVLOG(1) << ENDPOINT << "Appending IETF header: " << header; + QuicConnectionId server_connection_id = + GetServerConnectionIdAsSender(header, perspective_); + QUIC_BUG_IF(quic_bug_12975_6, !QuicUtils::IsConnectionIdValidForVersion( + server_connection_id, transport_version())) + << "AppendIetfPacketHeader: attempted to use connection ID " + << server_connection_id << " which is invalid with version " << version(); + if (!AppendIetfHeaderTypeByte(header, writer)) { + return false; + } + + if (header.version_flag) { + QUICHE_DCHECK_NE(VERSION_NEGOTIATION, header.long_packet_type) + << "QuicFramer::AppendIetfPacketHeader does not support sending " + "version negotiation packets, use " + "QuicFramer::BuildVersionNegotiationPacket instead " + << header; + // Append version for long header. + QuicVersionLabel version_label = CreateQuicVersionLabel(version_); + if (!writer->WriteUInt32(version_label)) { + return false; + } + } + + // Append connection ID. + if (!AppendIetfConnectionIds( + header.version_flag, version_.HasLengthPrefixedConnectionIds(), + header.destination_connection_id_included != CONNECTION_ID_ABSENT + ? header.destination_connection_id + : EmptyQuicConnectionId(), + header.source_connection_id_included != CONNECTION_ID_ABSENT + ? header.source_connection_id + : EmptyQuicConnectionId(), + writer)) { + return false; + } + + last_serialized_server_connection_id_ = server_connection_id; + if (version_.SupportsClientConnectionIds()) { + last_serialized_client_connection_id_ = + GetClientConnectionIdAsSender(header, perspective_); + } + + // TODO(b/141924462) Remove this QUIC_BUG once we do support sending RETRY. + QUIC_BUG_IF(quic_bug_12975_7, + header.version_flag && header.long_packet_type == RETRY) + << "Sending IETF RETRY packets is not currently supported " << header; + + if (QuicVersionHasLongHeaderLengths(transport_version()) && + header.version_flag) { + if (header.long_packet_type == INITIAL) { + QUICHE_DCHECK_NE(quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, + header.retry_token_length_length) + << ENDPOINT << ParsedQuicVersionToString(version_) + << " bad retry token length length in header: " << header; + // Write retry token length. + if (!writer->WriteVarInt62WithForcedLength( + header.retry_token.length(), header.retry_token_length_length)) { + return false; + } + // Write retry token. + if (!header.retry_token.empty() && + !writer->WriteStringPiece(header.retry_token)) { + return false; + } + } + if (length_field_offset != nullptr) { + *length_field_offset = writer->length(); + } + // Add fake length to reserve two bytes to add length in later. + writer->WriteVarInt62(256); + } else if (length_field_offset != nullptr) { + *length_field_offset = 0; + } + + // Append packet number. + if (!AppendPacketNumber(header.packet_number_length, header.packet_number, + writer)) { + return false; + } + last_written_packet_number_length_ = header.packet_number_length; + + if (!header.version_flag) { + return true; + } + + if (header.nonce != nullptr) { + QUICHE_DCHECK(header.version_flag); + QUICHE_DCHECK_EQ(ZERO_RTT_PROTECTED, header.long_packet_type); + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, perspective_); + if (!writer->WriteBytes(header.nonce, kDiversificationNonceSize)) { + return false; + } + } + + return true; +} + +const QuicTime::Delta QuicFramer::CalculateTimestampFromWire( + uint32_t time_delta_us) { + // The new time_delta might have wrapped to the next epoch, or it + // might have reverse wrapped to the previous epoch, or it might + // remain in the same epoch. Select the time closest to the previous + // time. + // + // epoch_delta is the delta between epochs. A delta is 4 bytes of + // microseconds. + const uint64_t epoch_delta = UINT64_C(1) << 32; + uint64_t epoch = last_timestamp_.ToMicroseconds() & ~(epoch_delta - 1); + // Wrapping is safe here because a wrapped value will not be ClosestTo below. + uint64_t prev_epoch = epoch - epoch_delta; + uint64_t next_epoch = epoch + epoch_delta; + + uint64_t time = ClosestTo( + last_timestamp_.ToMicroseconds(), epoch + time_delta_us, + ClosestTo(last_timestamp_.ToMicroseconds(), prev_epoch + time_delta_us, + next_epoch + time_delta_us)); + + return QuicTime::Delta::FromMicroseconds(time); +} + +uint64_t QuicFramer::CalculatePacketNumberFromWire( + QuicPacketNumberLength packet_number_length, + QuicPacketNumber base_packet_number, uint64_t packet_number) const { + // The new packet number might have wrapped to the next epoch, or + // it might have reverse wrapped to the previous epoch, or it might + // remain in the same epoch. Select the packet number closest to the + // next expected packet number, the previous packet number plus 1. + + // epoch_delta is the delta between epochs the packet number was serialized + // with, so the correct value is likely the same epoch as the last sequence + // number or an adjacent epoch. + if (!base_packet_number.IsInitialized()) { + return packet_number; + } + const uint64_t epoch_delta = UINT64_C(1) << (8 * packet_number_length); + uint64_t next_packet_number = base_packet_number.ToUint64() + 1; + uint64_t epoch = base_packet_number.ToUint64() & ~(epoch_delta - 1); + uint64_t prev_epoch = epoch - epoch_delta; + uint64_t next_epoch = epoch + epoch_delta; + + return ClosestTo(next_packet_number, epoch + packet_number, + ClosestTo(next_packet_number, prev_epoch + packet_number, + next_epoch + packet_number)); +} + +bool QuicFramer::ProcessPublicHeader(QuicDataReader* reader, + bool packet_has_ietf_packet_header, + QuicPacketHeader* header) { + if (packet_has_ietf_packet_header) { + return ProcessIetfPacketHeader(reader, header); + } + uint8_t public_flags; + if (!reader->ReadBytes(&public_flags, 1)) { + set_detailed_error("Unable to read public flags."); + return false; + } + + header->reset_flag = (public_flags & PACKET_PUBLIC_FLAGS_RST) != 0; + header->version_flag = (public_flags & PACKET_PUBLIC_FLAGS_VERSION) != 0; + + if (validate_flags_ && !header->version_flag && + public_flags > PACKET_PUBLIC_FLAGS_MAX) { + set_detailed_error("Illegal public flags value."); + return false; + } + + if (header->reset_flag && header->version_flag) { + set_detailed_error("Got version flag in reset packet"); + return false; + } + + QuicConnectionId* header_connection_id = &header->destination_connection_id; + QuicConnectionIdIncluded* header_connection_id_included = + &header->destination_connection_id_included; + if (perspective_ == Perspective::IS_CLIENT) { + header_connection_id = &header->source_connection_id; + header_connection_id_included = &header->source_connection_id_included; + } + switch (public_flags & PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID) { + case PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID: + if (!reader->ReadConnectionId(header_connection_id, + kQuicDefaultConnectionIdLength)) { + set_detailed_error("Unable to read ConnectionId."); + return false; + } + *header_connection_id_included = CONNECTION_ID_PRESENT; + break; + case PACKET_PUBLIC_FLAGS_0BYTE_CONNECTION_ID: + *header_connection_id_included = CONNECTION_ID_ABSENT; + *header_connection_id = last_serialized_server_connection_id_; + break; + } + + header->packet_number_length = ReadSequenceNumberLength( + public_flags >> kPublicHeaderSequenceNumberShift); + + // Read the version only if the packet is from the client. + // version flag from the server means version negotiation packet. + if (header->version_flag && perspective_ == Perspective::IS_SERVER) { + QuicVersionLabel version_label; + if (!ProcessVersionLabel(reader, &version_label)) { + set_detailed_error("Unable to read protocol version."); + return false; + } + // If the version from the new packet is the same as the version of this + // framer, then the public flags should be set to something we understand. + // If not, this raises an error. + ParsedQuicVersion version = ParseQuicVersionLabel(version_label); + if (version == version_ && public_flags > PACKET_PUBLIC_FLAGS_MAX) { + set_detailed_error("Illegal public flags value."); + return false; + } + header->version = version; + } + + // A nonce should only be present in packets from the server to the client, + // which are neither version negotiation nor public reset packets. + if (public_flags & PACKET_PUBLIC_FLAGS_NONCE && + !(public_flags & PACKET_PUBLIC_FLAGS_VERSION) && + !(public_flags & PACKET_PUBLIC_FLAGS_RST) && + // The nonce flag from a client is ignored and is assumed to be an older + // client indicating an eight-byte connection ID. + perspective_ == Perspective::IS_CLIENT) { + if (!reader->ReadBytes(reinterpret_cast(last_nonce_.data()), + last_nonce_.size())) { + set_detailed_error("Unable to read nonce."); + return false; + } + header->nonce = &last_nonce_; + } else { + header->nonce = nullptr; + } + + return true; +} + +// static +QuicPacketNumberLength QuicFramer::GetMinPacketNumberLength( + QuicPacketNumber packet_number) { + QUICHE_DCHECK(packet_number.IsInitialized()); + if (packet_number < QuicPacketNumber(1 << (PACKET_1BYTE_PACKET_NUMBER * 8))) { + return PACKET_1BYTE_PACKET_NUMBER; + } else if (packet_number < + QuicPacketNumber(1 << (PACKET_2BYTE_PACKET_NUMBER * 8))) { + return PACKET_2BYTE_PACKET_NUMBER; + } else if (packet_number < + QuicPacketNumber(UINT64_C(1) + << (PACKET_4BYTE_PACKET_NUMBER * 8))) { + return PACKET_4BYTE_PACKET_NUMBER; + } else { + return PACKET_6BYTE_PACKET_NUMBER; + } +} + +// static +uint8_t QuicFramer::GetPacketNumberFlags( + QuicPacketNumberLength packet_number_length) { + switch (packet_number_length) { + case PACKET_1BYTE_PACKET_NUMBER: + return PACKET_FLAGS_1BYTE_PACKET; + case PACKET_2BYTE_PACKET_NUMBER: + return PACKET_FLAGS_2BYTE_PACKET; + case PACKET_4BYTE_PACKET_NUMBER: + return PACKET_FLAGS_4BYTE_PACKET; + case PACKET_6BYTE_PACKET_NUMBER: + case PACKET_8BYTE_PACKET_NUMBER: + return PACKET_FLAGS_8BYTE_PACKET; + default: + QUIC_BUG(quic_bug_10850_56) << "Unreachable case statement."; + return PACKET_FLAGS_8BYTE_PACKET; + } +} + +// static +QuicFramer::AckFrameInfo QuicFramer::GetAckFrameInfo( + const QuicAckFrame& frame) { + AckFrameInfo new_ack_info; + if (frame.packets.Empty()) { + return new_ack_info; + } + // The first block is the last interval. It isn't encoded with the gap-length + // encoding, so skip it. + new_ack_info.first_block_length = frame.packets.LastIntervalLength(); + auto itr = frame.packets.rbegin(); + QuicPacketNumber previous_start = itr->min(); + new_ack_info.max_block_length = itr->Length(); + ++itr; + + // Don't do any more work after getting information for 256 ACK blocks; any + // more can't be encoded anyway. + for (; itr != frame.packets.rend() && + new_ack_info.num_ack_blocks < std::numeric_limits::max(); + previous_start = itr->min(), ++itr) { + const auto& interval = *itr; + const QuicPacketCount total_gap = previous_start - interval.max(); + new_ack_info.num_ack_blocks += + (total_gap + std::numeric_limits::max() - 1) / + std::numeric_limits::max(); + new_ack_info.max_block_length = + std::max(new_ack_info.max_block_length, interval.Length()); + } + return new_ack_info; +} + +bool QuicFramer::ProcessUnauthenticatedHeader(QuicDataReader* encrypted_reader, + QuicPacketHeader* header) { + QuicPacketNumber base_packet_number; + if (supports_multiple_packet_number_spaces_) { + PacketNumberSpace pn_space = GetPacketNumberSpace(*header); + if (pn_space == NUM_PACKET_NUMBER_SPACES) { + set_detailed_error("Unable to determine packet number space."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + base_packet_number = largest_decrypted_packet_numbers_[pn_space]; + } else { + base_packet_number = largest_packet_number_; + } + uint64_t full_packet_number; + if (!ProcessAndCalculatePacketNumber( + encrypted_reader, header->packet_number_length, base_packet_number, + &full_packet_number)) { + set_detailed_error("Unable to read packet number."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + + if (!IsValidFullPacketNumber(full_packet_number, version())) { + set_detailed_error("packet numbers cannot be 0."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + header->packet_number = QuicPacketNumber(full_packet_number); + + if (!visitor_->OnUnauthenticatedHeader(*header)) { + set_detailed_error( + "Visitor asked to stop processing of unauthenticated header."); + return false; + } + // The function we are in is called because the framer believes that it is + // processing a packet that uses the non-IETF (i.e. Google QUIC) packet header + // type. Usually, the framer makes that decision based on the framer's + // version, but when the framer is used with Perspective::IS_SERVER, then + // before version negotiation is complete (specifically, before + // InferPacketHeaderTypeFromVersion is called), this decision is made based on + // the type byte of the packet. + // + // If the framer's version KnowsWhichDecrypterToUse, then that version expects + // to use the IETF packet header type. If that's the case and we're in this + // function, then the packet received is invalid: the framer was expecting an + // IETF packet header and didn't get one. + if (version().KnowsWhichDecrypterToUse()) { + set_detailed_error("Invalid public header type for expected version."); + return RaiseError(QUIC_INVALID_PACKET_HEADER); + } + return true; +} + +bool QuicFramer::ProcessIetfHeaderTypeByte(QuicDataReader* reader, + QuicPacketHeader* header) { + uint8_t type; + if (!reader->ReadBytes(&type, 1)) { + set_detailed_error("Unable to read first byte."); + return false; + } + header->type_byte = type; + // Determine whether this is a long or short header. + header->form = GetIetfPacketHeaderFormat(type); + if (header->form == IETF_QUIC_LONG_HEADER_PACKET) { + // Version is always present in long headers. + header->version_flag = true; + // In versions that do not support client connection IDs, we mark the + // corresponding connection ID as absent. + header->destination_connection_id_included = + (perspective_ == Perspective::IS_SERVER || + version_.SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; + header->source_connection_id_included = + (perspective_ == Perspective::IS_CLIENT || + version_.SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; + // Read version tag. + QuicVersionLabel version_label; + if (!ProcessVersionLabel(reader, &version_label)) { + set_detailed_error("Unable to read protocol version."); + return false; + } + if (!version_label) { + // Version label is 0 indicating this is a version negotiation packet. + header->long_packet_type = VERSION_NEGOTIATION; + } else { + header->version = ParseQuicVersionLabel(version_label); + if (header->version.IsKnown()) { + if (!(type & FLAGS_FIXED_BIT)) { + set_detailed_error("Fixed bit is 0 in long header."); + return false; + } + header->long_packet_type = GetLongHeaderType(type, header->version); + switch (header->long_packet_type) { + case INVALID_PACKET_TYPE: + set_detailed_error("Illegal long header type value."); + return false; + case RETRY: + if (!version().SupportsRetry()) { + set_detailed_error("RETRY not supported in this version."); + return false; + } + if (perspective_ == Perspective::IS_SERVER) { + set_detailed_error("Client-initiated RETRY is invalid."); + return false; + } + break; + default: + if (!header->version.HasHeaderProtection()) { + header->packet_number_length = + GetLongHeaderPacketNumberLength(type); + } + break; + } + } + } + + QUIC_DVLOG(1) << ENDPOINT << "Received IETF long header: " + << QuicUtils::QuicLongHeaderTypetoString( + header->long_packet_type); + return true; + } + + QUIC_DVLOG(1) << ENDPOINT << "Received IETF short header"; + // Version is not present in short headers. + header->version_flag = false; + // In versions that do not support client connection IDs, the client will not + // receive destination connection IDs. + header->destination_connection_id_included = + (perspective_ == Perspective::IS_SERVER || + version_.SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; + header->source_connection_id_included = CONNECTION_ID_ABSENT; + if (!(type & FLAGS_FIXED_BIT)) { + set_detailed_error("Fixed bit is 0 in short header."); + return false; + } + if (!version_.HasHeaderProtection()) { + header->packet_number_length = GetShortHeaderPacketNumberLength(type); + } + QUIC_DVLOG(1) << "packet_number_length = " << header->packet_number_length; + return true; +} + +// static +bool QuicFramer::ProcessVersionLabel(QuicDataReader* reader, + QuicVersionLabel* version_label) { + if (!reader->ReadUInt32(version_label)) { + return false; + } + return true; +} + +// static +bool QuicFramer::ProcessAndValidateIetfConnectionIdLength( + QuicDataReader* reader, ParsedQuicVersion version, Perspective perspective, + bool should_update_expected_server_connection_id_length, + uint8_t* expected_server_connection_id_length, + uint8_t* destination_connection_id_length, + uint8_t* source_connection_id_length, std::string* detailed_error) { + uint8_t connection_id_lengths_byte; + if (!reader->ReadBytes(&connection_id_lengths_byte, 1)) { + *detailed_error = "Unable to read ConnectionId length."; + return false; + } + uint8_t dcil = + (connection_id_lengths_byte & kDestinationConnectionIdLengthMask) >> 4; + if (dcil != 0) { + dcil += kConnectionIdLengthAdjustment; + } + uint8_t scil = connection_id_lengths_byte & kSourceConnectionIdLengthMask; + if (scil != 0) { + scil += kConnectionIdLengthAdjustment; + } + if (should_update_expected_server_connection_id_length) { + uint8_t server_connection_id_length = + perspective == Perspective::IS_SERVER ? dcil : scil; + if (*expected_server_connection_id_length != server_connection_id_length) { + QUIC_DVLOG(1) << "Updating expected_server_connection_id_length: " + << static_cast(*expected_server_connection_id_length) + << " -> " << static_cast(server_connection_id_length); + *expected_server_connection_id_length = server_connection_id_length; + } + } + if (!should_update_expected_server_connection_id_length && + (dcil != *destination_connection_id_length || + scil != *source_connection_id_length) && + version.IsKnown() && !version.AllowsVariableLengthConnectionIds()) { + QUIC_DVLOG(1) << "dcil: " << static_cast(dcil) + << ", scil: " << static_cast(scil); + *detailed_error = "Invalid ConnectionId length."; + return false; + } + *destination_connection_id_length = dcil; + *source_connection_id_length = scil; + return true; +} + +bool QuicFramer::ValidateReceivedConnectionIds(const QuicPacketHeader& header) { + bool skip_server_connection_id_validation = + perspective_ == Perspective::IS_CLIENT && + header.form == IETF_QUIC_SHORT_HEADER_PACKET; + if (!skip_server_connection_id_validation && + !QuicUtils::IsConnectionIdValidForVersion( + GetServerConnectionIdAsRecipient(header, perspective_), + transport_version())) { + set_detailed_error("Received server connection ID with invalid length."); + return false; + } + + bool skip_client_connection_id_validation = + perspective_ == Perspective::IS_SERVER && + header.form == IETF_QUIC_SHORT_HEADER_PACKET; + if (!skip_client_connection_id_validation && + version_.SupportsClientConnectionIds() && + !QuicUtils::IsConnectionIdValidForVersion( + GetClientConnectionIdAsRecipient(header, perspective_), + transport_version())) { + set_detailed_error("Received client connection ID with invalid length."); + return false; + } + return true; +} + +bool QuicFramer::ProcessIetfPacketHeader(QuicDataReader* reader, + QuicPacketHeader* header) { + if (version_.HasLengthPrefixedConnectionIds()) { + uint8_t expected_destination_connection_id_length = + perspective_ == Perspective::IS_CLIENT + ? expected_client_connection_id_length_ + : expected_server_connection_id_length_; + QuicVersionLabel version_label; + bool has_length_prefix; + std::string detailed_error; + QuicErrorCode parse_result = QuicFramer::ParsePublicHeader( + reader, expected_destination_connection_id_length, + version_.HasIetfInvariantHeader(), &header->type_byte, &header->form, + &header->version_flag, &has_length_prefix, &version_label, + &header->version, &header->destination_connection_id, + &header->source_connection_id, &header->long_packet_type, + &header->retry_token_length_length, &header->retry_token, + &detailed_error); + if (parse_result != QUIC_NO_ERROR) { + set_detailed_error(detailed_error); + return false; + } + header->destination_connection_id_included = CONNECTION_ID_PRESENT; + header->source_connection_id_included = + header->version_flag ? CONNECTION_ID_PRESENT : CONNECTION_ID_ABSENT; + + if (!ValidateReceivedConnectionIds(*header)) { + return false; + } + + if (header->version_flag && + header->long_packet_type != VERSION_NEGOTIATION && + !(header->type_byte & FLAGS_FIXED_BIT)) { + set_detailed_error("Fixed bit is 0 in long header."); + return false; + } + if (!header->version_flag && !(header->type_byte & FLAGS_FIXED_BIT)) { + set_detailed_error("Fixed bit is 0 in short header."); + return false; + } + if (!header->version_flag) { + if (!version_.HasHeaderProtection()) { + header->packet_number_length = + GetShortHeaderPacketNumberLength(header->type_byte); + } + return true; + } + if (header->long_packet_type == RETRY) { + if (!version().SupportsRetry()) { + set_detailed_error("RETRY not supported in this version."); + return false; + } + if (perspective_ == Perspective::IS_SERVER) { + set_detailed_error("Client-initiated RETRY is invalid."); + return false; + } + return true; + } + if (header->version.IsKnown() && !header->version.HasHeaderProtection()) { + header->packet_number_length = + GetLongHeaderPacketNumberLength(header->type_byte); + } + + return true; + } + + if (!ProcessIetfHeaderTypeByte(reader, header)) { + return false; + } + + uint8_t destination_connection_id_length = + header->destination_connection_id_included == CONNECTION_ID_PRESENT + ? (perspective_ == Perspective::IS_SERVER + ? expected_server_connection_id_length_ + : expected_client_connection_id_length_) + : 0; + uint8_t source_connection_id_length = + header->source_connection_id_included == CONNECTION_ID_PRESENT + ? (perspective_ == Perspective::IS_CLIENT + ? expected_server_connection_id_length_ + : expected_client_connection_id_length_) + : 0; + if (header->form == IETF_QUIC_LONG_HEADER_PACKET) { + if (!ProcessAndValidateIetfConnectionIdLength( + reader, header->version, perspective_, + /*should_update_expected_server_connection_id_length=*/false, + &expected_server_connection_id_length_, + &destination_connection_id_length, &source_connection_id_length, + &detailed_error_)) { + return false; + } + } + + // Read connection ID. + if (!reader->ReadConnectionId(&header->destination_connection_id, + destination_connection_id_length)) { + set_detailed_error("Unable to read destination connection ID."); + return false; + } + + if (!reader->ReadConnectionId(&header->source_connection_id, + source_connection_id_length)) { + set_detailed_error("Unable to read source connection ID."); + return false; + } + + if (header->source_connection_id_included == CONNECTION_ID_ABSENT) { + if (!header->source_connection_id.IsEmpty()) { + QUICHE_DCHECK(!version_.SupportsClientConnectionIds()); + set_detailed_error("Client connection ID not supported in this version."); + return false; + } + } + + return ValidateReceivedConnectionIds(*header); +} + +bool QuicFramer::ProcessAndCalculatePacketNumber( + QuicDataReader* reader, QuicPacketNumberLength packet_number_length, + QuicPacketNumber base_packet_number, uint64_t* packet_number) { + uint64_t wire_packet_number; + if (!reader->ReadBytesToUInt64(packet_number_length, &wire_packet_number)) { + return false; + } + + // TODO(ianswett): Explore the usefulness of trying multiple packet numbers + // in case the first guess is incorrect. + *packet_number = CalculatePacketNumberFromWire( + packet_number_length, base_packet_number, wire_packet_number); + return true; +} + +bool QuicFramer::ProcessFrameData(QuicDataReader* reader, + const QuicPacketHeader& header) { + QUICHE_DCHECK(!VersionHasIetfQuicFrames(version_.transport_version)) + << "IETF QUIC Framing negotiated but attempting to process frames as " + "non-IETF QUIC."; + if (reader->IsDoneReading()) { + set_detailed_error("Packet has no frames."); + return RaiseError(QUIC_MISSING_PAYLOAD); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing packet with header " << header; + while (!reader->IsDoneReading()) { + uint8_t frame_type; + if (!reader->ReadBytes(&frame_type, 1)) { + set_detailed_error("Unable to read frame type."); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + const uint8_t special_mask = version_.HasIetfInvariantHeader() + ? kQuicFrameTypeSpecialMask + : kQuicFrameTypeBrokenMask; + if (frame_type & special_mask) { + // Stream Frame + if (frame_type & kQuicFrameTypeStreamMask) { + QuicStreamFrame frame; + if (!ProcessStreamFrame(reader, frame_type, &frame)) { + return RaiseError(QUIC_INVALID_STREAM_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing stream frame " << frame; + if (!visitor_->OnStreamFrame(frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + // Ack Frame + if (frame_type & kQuicFrameTypeAckMask) { + if (!ProcessAckFrame(reader, frame_type)) { + return RaiseError(QUIC_INVALID_ACK_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing ACK frame"; + continue; + } + + // This was a special frame type that did not match any + // of the known ones. Error. + set_detailed_error("Illegal frame type."); + QUIC_DLOG(WARNING) << ENDPOINT << "Illegal frame type: " + << static_cast(frame_type); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + + switch (frame_type) { + case PADDING_FRAME: { + QuicPaddingFrame frame; + ProcessPaddingFrame(reader, &frame); + QUIC_DVLOG(2) << ENDPOINT << "Processing padding frame " << frame; + if (!visitor_->OnPaddingFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case RST_STREAM_FRAME: { + QuicRstStreamFrame frame; + if (!ProcessRstStreamFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_RST_STREAM_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing reset stream frame " << frame; + if (!visitor_->OnRstStreamFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case CONNECTION_CLOSE_FRAME: { + QuicConnectionCloseFrame frame; + if (!ProcessConnectionCloseFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_CONNECTION_CLOSE_DATA); + } + + QUIC_DVLOG(2) << ENDPOINT << "Processing connection close frame " + << frame; + if (!visitor_->OnConnectionCloseFrame(frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case GOAWAY_FRAME: { + QuicGoAwayFrame goaway_frame; + if (!ProcessGoAwayFrame(reader, &goaway_frame)) { + return RaiseError(QUIC_INVALID_GOAWAY_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing go away frame " + << goaway_frame; + if (!visitor_->OnGoAwayFrame(goaway_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case WINDOW_UPDATE_FRAME: { + QuicWindowUpdateFrame window_update_frame; + if (!ProcessWindowUpdateFrame(reader, &window_update_frame)) { + return RaiseError(QUIC_INVALID_WINDOW_UPDATE_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing window update frame " + << window_update_frame; + if (!visitor_->OnWindowUpdateFrame(window_update_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case BLOCKED_FRAME: { + QuicBlockedFrame blocked_frame; + if (!ProcessBlockedFrame(reader, &blocked_frame)) { + return RaiseError(QUIC_INVALID_BLOCKED_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing blocked frame " + << blocked_frame; + if (!visitor_->OnBlockedFrame(blocked_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case STOP_WAITING_FRAME: { + QuicStopWaitingFrame stop_waiting_frame; + if (!ProcessStopWaitingFrame(reader, header, &stop_waiting_frame)) { + return RaiseError(QUIC_INVALID_STOP_WAITING_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing stop waiting frame " + << stop_waiting_frame; + if (!visitor_->OnStopWaitingFrame(stop_waiting_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + case PING_FRAME: { + // Ping has no payload. + QuicPingFrame ping_frame; + if (!visitor_->OnPingFrame(ping_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + QUIC_DVLOG(2) << ENDPOINT << "Processing ping frame " << ping_frame; + continue; + } + case IETF_EXTENSION_MESSAGE_NO_LENGTH: + ABSL_FALLTHROUGH_INTENDED; + case IETF_EXTENSION_MESSAGE: { + QuicMessageFrame message_frame; + if (!ProcessMessageFrame(reader, + frame_type == IETF_EXTENSION_MESSAGE_NO_LENGTH, + &message_frame)) { + return RaiseError(QUIC_INVALID_MESSAGE_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing message frame " + << message_frame; + if (!visitor_->OnMessageFrame(message_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case CRYPTO_FRAME: { + if (!QuicVersionUsesCryptoFrames(version_.transport_version)) { + set_detailed_error("Illegal frame type."); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + QuicCryptoFrame frame; + if (!ProcessCryptoFrame(reader, GetEncryptionLevel(header), &frame)) { + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing crypto frame " << frame; + if (!visitor_->OnCryptoFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case HANDSHAKE_DONE_FRAME: { + // HANDSHAKE_DONE has no payload. + QuicHandshakeDoneFrame handshake_done_frame; + QUIC_DVLOG(2) << ENDPOINT << "Processing handshake done frame " + << handshake_done_frame; + if (!visitor_->OnHandshakeDoneFrame(handshake_done_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + + default: + set_detailed_error("Illegal frame type."); + QUIC_DLOG(WARNING) << ENDPOINT << "Illegal frame type: " + << static_cast(frame_type); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + } + + return true; +} + +// static +bool QuicFramer::IsIetfFrameTypeExpectedForEncryptionLevel( + uint64_t frame_type, EncryptionLevel level) { + // IETF_CRYPTO is allowed for any level here and is separately checked in + // QuicCryptoStream::OnCryptoFrame. + switch (level) { + case ENCRYPTION_INITIAL: + case ENCRYPTION_HANDSHAKE: + return frame_type == IETF_CRYPTO || frame_type == IETF_ACK || + frame_type == IETF_ACK_ECN || + frame_type == IETF_ACK_RECEIVE_TIMESTAMPS || + frame_type == IETF_PING || frame_type == IETF_PADDING || + frame_type == IETF_CONNECTION_CLOSE; + case ENCRYPTION_ZERO_RTT: + return !(frame_type == IETF_ACK || frame_type == IETF_ACK_ECN || + frame_type == IETF_ACK_RECEIVE_TIMESTAMPS || + frame_type == IETF_HANDSHAKE_DONE || + frame_type == IETF_NEW_TOKEN || + frame_type == IETF_PATH_RESPONSE || + frame_type == IETF_RETIRE_CONNECTION_ID); + case ENCRYPTION_FORWARD_SECURE: + return true; + default: + QUIC_BUG(quic_bug_10850_57) << "Unknown encryption level: " << level; + } + return false; +} + +bool QuicFramer::ProcessIetfFrameData(QuicDataReader* reader, + const QuicPacketHeader& header, + EncryptionLevel decrypted_level) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(version_.transport_version)) + << "Attempt to process frames as IETF frames but version (" + << version_.transport_version << ") does not support IETF Framing."; + + if (reader->IsDoneReading()) { + set_detailed_error("Packet has no frames."); + return RaiseError(QUIC_MISSING_PAYLOAD); + } + + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF packet with header " << header; + auto* connection_context = QuicConnectionContext::Current(); + while (!reader->IsDoneReading()) { + if (connection_context != nullptr) { + connection_context->process_packet_context.current_frame_offset = + connection_context->process_packet_context.decrypted_payload.size() - + reader->BytesRemaining(); + } + uint64_t frame_type; + // Will be the number of bytes into which frame_type was encoded. + size_t encoded_bytes = reader->BytesRemaining(); + if (!reader->ReadVarInt62(&frame_type)) { + set_detailed_error("Unable to read frame type."); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + if (!IsIetfFrameTypeExpectedForEncryptionLevel(frame_type, + decrypted_level)) { + set_detailed_error(absl::StrCat( + "IETF frame type ", + QuicIetfFrameTypeString(static_cast(frame_type)), + " is unexpected at encryption level ", + EncryptionLevelToString(decrypted_level))); + return RaiseError(IETF_QUIC_PROTOCOL_VIOLATION); + } + previously_received_frame_type_ = current_received_frame_type_; + current_received_frame_type_ = frame_type; + + // Is now the number of bytes into which the frame type was encoded. + encoded_bytes -= reader->BytesRemaining(); + + // Check that the frame type is minimally encoded. + if (encoded_bytes != + static_cast(QuicDataWriter::GetVarInt62Len(frame_type))) { + // The frame type was not minimally encoded. + set_detailed_error("Frame type not minimally encoded."); + return RaiseError(IETF_QUIC_PROTOCOL_VIOLATION); + } + + if (IS_IETF_STREAM_FRAME(frame_type)) { + QuicStreamFrame frame; + if (!ProcessIetfStreamFrame(reader, frame_type, &frame)) { + return RaiseError(QUIC_INVALID_STREAM_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF stream frame " << frame; + if (!visitor_->OnStreamFrame(frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + } else { + switch (frame_type) { + case IETF_PADDING: { + QuicPaddingFrame frame; + ProcessPaddingFrame(reader, &frame); + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF padding frame " + << frame; + if (!visitor_->OnPaddingFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_RST_STREAM: { + QuicRstStreamFrame frame; + if (!ProcessIetfResetStreamFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_RST_STREAM_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF reset stream frame " + << frame; + if (!visitor_->OnRstStreamFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_APPLICATION_CLOSE: + case IETF_CONNECTION_CLOSE: { + QuicConnectionCloseFrame frame; + if (!ProcessIetfConnectionCloseFrame( + reader, + (frame_type == IETF_CONNECTION_CLOSE) + ? IETF_QUIC_TRANSPORT_CONNECTION_CLOSE + : IETF_QUIC_APPLICATION_CONNECTION_CLOSE, + &frame)) { + return RaiseError(QUIC_INVALID_CONNECTION_CLOSE_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF connection close frame " + << frame; + if (!visitor_->OnConnectionCloseFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_MAX_DATA: { + QuicWindowUpdateFrame frame; + if (!ProcessMaxDataFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_MAX_DATA_FRAME_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF max data frame " + << frame; + if (!visitor_->OnWindowUpdateFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_MAX_STREAM_DATA: { + QuicWindowUpdateFrame frame; + if (!ProcessMaxStreamDataFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF max stream data frame " + << frame; + if (!visitor_->OnWindowUpdateFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_MAX_STREAMS_BIDIRECTIONAL: + case IETF_MAX_STREAMS_UNIDIRECTIONAL: { + QuicMaxStreamsFrame frame; + if (!ProcessMaxStreamsFrame(reader, &frame, frame_type)) { + return RaiseError(QUIC_MAX_STREAMS_DATA); + } + QUIC_CODE_COUNT_N(quic_max_streams_received, 1, 2); + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF max streams frame " + << frame; + if (!visitor_->OnMaxStreamsFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_PING: { + // Ping has no payload. + QuicPingFrame ping_frame; + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF ping frame " + << ping_frame; + if (!visitor_->OnPingFrame(ping_frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_DATA_BLOCKED: { + QuicBlockedFrame frame; + if (!ProcessDataBlockedFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_BLOCKED_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF blocked frame " + << frame; + if (!visitor_->OnBlockedFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_STREAM_DATA_BLOCKED: { + QuicBlockedFrame frame; + if (!ProcessStreamDataBlockedFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_STREAM_BLOCKED_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF stream blocked frame " + << frame; + if (!visitor_->OnBlockedFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_STREAMS_BLOCKED_UNIDIRECTIONAL: + case IETF_STREAMS_BLOCKED_BIDIRECTIONAL: { + QuicStreamsBlockedFrame frame; + if (!ProcessStreamsBlockedFrame(reader, &frame, frame_type)) { + return RaiseError(QUIC_STREAMS_BLOCKED_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF streams blocked frame " + << frame; + if (!visitor_->OnStreamsBlockedFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_NEW_CONNECTION_ID: { + QuicNewConnectionIdFrame frame; + if (!ProcessNewConnectionIdFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_NEW_CONNECTION_ID_DATA); + } + QUIC_DVLOG(2) << ENDPOINT + << "Processing IETF new connection ID frame " << frame; + if (!visitor_->OnNewConnectionIdFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_RETIRE_CONNECTION_ID: { + QuicRetireConnectionIdFrame frame; + if (!ProcessRetireConnectionIdFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_RETIRE_CONNECTION_ID_DATA); + } + QUIC_DVLOG(2) << ENDPOINT + << "Processing IETF retire connection ID frame " + << frame; + if (!visitor_->OnRetireConnectionIdFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_NEW_TOKEN: { + QuicNewTokenFrame frame; + if (!ProcessNewTokenFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_NEW_TOKEN); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF new token frame " + << frame; + if (!visitor_->OnNewTokenFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_STOP_SENDING: { + QuicStopSendingFrame frame; + if (!ProcessStopSendingFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_STOP_SENDING_FRAME_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF stop sending frame " + << frame; + if (!visitor_->OnStopSendingFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_ACK_RECEIVE_TIMESTAMPS: + if (!process_timestamps_) { + set_detailed_error("Unsupported frame type."); + QUIC_DLOG(WARNING) + << ENDPOINT << "IETF_ACK_RECEIVE_TIMESTAMPS not supported"; + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + ABSL_FALLTHROUGH_INTENDED; + case IETF_ACK_ECN: + case IETF_ACK: { + QuicAckFrame frame; + if (!ProcessIetfAckFrame(reader, frame_type, &frame)) { + return RaiseError(QUIC_INVALID_ACK_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF ACK frame " << frame; + break; + } + case IETF_PATH_CHALLENGE: { + QuicPathChallengeFrame frame; + if (!ProcessPathChallengeFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_PATH_CHALLENGE_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF path challenge frame " + << frame; + if (!visitor_->OnPathChallengeFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_PATH_RESPONSE: { + QuicPathResponseFrame frame; + if (!ProcessPathResponseFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_PATH_RESPONSE_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF path response frame " + << frame; + if (!visitor_->OnPathResponseFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_EXTENSION_MESSAGE_NO_LENGTH_V99: + ABSL_FALLTHROUGH_INTENDED; + case IETF_EXTENSION_MESSAGE_V99: { + QuicMessageFrame message_frame; + if (!ProcessMessageFrame( + reader, frame_type == IETF_EXTENSION_MESSAGE_NO_LENGTH_V99, + &message_frame)) { + return RaiseError(QUIC_INVALID_MESSAGE_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF message frame " + << message_frame; + if (!visitor_->OnMessageFrame(message_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_CRYPTO: { + QuicCryptoFrame frame; + if (!ProcessCryptoFrame(reader, GetEncryptionLevel(header), &frame)) { + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF crypto frame " << frame; + if (!visitor_->OnCryptoFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + case IETF_HANDSHAKE_DONE: { + // HANDSHAKE_DONE has no payload. + QuicHandshakeDoneFrame handshake_done_frame; + if (!visitor_->OnHandshakeDoneFrame(handshake_done_frame)) { + QUIC_DVLOG(1) << ENDPOINT + << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + QUIC_DVLOG(2) << ENDPOINT << "Processing handshake done frame " + << handshake_done_frame; + break; + } + case IETF_ACK_FREQUENCY: { + QuicAckFrequencyFrame frame; + if (!ProcessAckFrequencyFrame(reader, &frame)) { + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + QUIC_DVLOG(2) << ENDPOINT << "Processing IETF ack frequency frame " + << frame; + if (!visitor_->OnAckFrequencyFrame(frame)) { + QUIC_DVLOG(1) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + break; + } + default: + set_detailed_error("Illegal frame type."); + QUIC_DLOG(WARNING) + << ENDPOINT + << "Illegal frame type: " << static_cast(frame_type); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + } + } + return true; +} + +namespace { +// Create a mask that sets the last |num_bits| to 1 and the rest to 0. +inline uint8_t GetMaskFromNumBits(uint8_t num_bits) { + return (1u << num_bits) - 1; +} + +// Extract |num_bits| from |flags| offset by |offset|. +uint8_t ExtractBits(uint8_t flags, uint8_t num_bits, uint8_t offset) { + return (flags >> offset) & GetMaskFromNumBits(num_bits); +} + +// Extract the bit at position |offset| from |flags| as a bool. +bool ExtractBit(uint8_t flags, uint8_t offset) { + return ((flags >> offset) & GetMaskFromNumBits(1)) != 0; +} + +// Set |num_bits|, offset by |offset| to |val| in |flags|. +void SetBits(uint8_t* flags, uint8_t val, uint8_t num_bits, uint8_t offset) { + QUICHE_DCHECK_LE(val, GetMaskFromNumBits(num_bits)); + *flags |= val << offset; +} + +// Set the bit at position |offset| to |val| in |flags|. +void SetBit(uint8_t* flags, bool val, uint8_t offset) { + SetBits(flags, val ? 1 : 0, 1, offset); +} +} // namespace + +bool QuicFramer::ProcessStreamFrame(QuicDataReader* reader, uint8_t frame_type, + QuicStreamFrame* frame) { + uint8_t stream_flags = frame_type; + + uint8_t stream_id_length = 0; + uint8_t offset_length = 4; + bool has_data_length = true; + stream_flags &= ~kQuicFrameTypeStreamMask; + + // Read from right to left: StreamID, Offset, Data Length, Fin. + stream_id_length = (stream_flags & kQuicStreamIDLengthMask) + 1; + stream_flags >>= kQuicStreamIdShift; + + offset_length = (stream_flags & kQuicStreamOffsetMask); + // There is no encoding for 1 byte, only 0 and 2 through 8. + if (offset_length > 0) { + offset_length += 1; + } + stream_flags >>= kQuicStreamShift; + + has_data_length = + (stream_flags & kQuicStreamDataLengthMask) == kQuicStreamDataLengthMask; + stream_flags >>= kQuicStreamDataLengthShift; + + frame->fin = (stream_flags & kQuicStreamFinMask) == kQuicStreamFinShift; + + uint64_t stream_id; + if (!reader->ReadBytesToUInt64(stream_id_length, &stream_id)) { + set_detailed_error("Unable to read stream_id."); + return false; + } + frame->stream_id = static_cast(stream_id); + + if (!reader->ReadBytesToUInt64(offset_length, &frame->offset)) { + set_detailed_error("Unable to read offset."); + return false; + } + + // TODO(ianswett): Don't use absl::string_view as an intermediary. + absl::string_view data; + if (has_data_length) { + if (!reader->ReadStringPiece16(&data)) { + set_detailed_error("Unable to read frame data."); + return false; + } + } else { + if (!reader->ReadStringPiece(&data, reader->BytesRemaining())) { + set_detailed_error("Unable to read frame data."); + return false; + } + } + frame->data_buffer = data.data(); + frame->data_length = static_cast(data.length()); + + return true; +} + +bool QuicFramer::ProcessIetfStreamFrame(QuicDataReader* reader, + uint8_t frame_type, + QuicStreamFrame* frame) { + // Read stream id from the frame. It's always present. + if (!ReadUint32FromVarint62(reader, IETF_STREAM, &frame->stream_id)) { + return false; + } + + // If we have a data offset, read it. If not, set to 0. + if (frame_type & IETF_STREAM_FRAME_OFF_BIT) { + if (!reader->ReadVarInt62(&frame->offset)) { + set_detailed_error("Unable to read stream data offset."); + return false; + } + } else { + // no offset in the frame, ensure it's 0 in the Frame. + frame->offset = 0; + } + + // If we have a data length, read it. If not, set to 0. + if (frame_type & IETF_STREAM_FRAME_LEN_BIT) { + uint64_t length; + if (!reader->ReadVarInt62(&length)) { + set_detailed_error("Unable to read stream data length."); + return false; + } + if (length > std::numeric_limitsdata_length)>::max()) { + set_detailed_error("Stream data length is too large."); + return false; + } + frame->data_length = length; + } else { + // no length in the frame, it is the number of bytes remaining in the + // packet. + frame->data_length = reader->BytesRemaining(); + } + + if (frame_type & IETF_STREAM_FRAME_FIN_BIT) { + frame->fin = true; + } else { + frame->fin = false; + } + + // TODO(ianswett): Don't use absl::string_view as an intermediary. + absl::string_view data; + if (!reader->ReadStringPiece(&data, frame->data_length)) { + set_detailed_error("Unable to read frame data."); + return false; + } + frame->data_buffer = data.data(); + QUICHE_DCHECK_EQ(frame->data_length, data.length()); + + return true; +} + +bool QuicFramer::ProcessCryptoFrame(QuicDataReader* reader, + EncryptionLevel encryption_level, + QuicCryptoFrame* frame) { + frame->level = encryption_level; + if (!reader->ReadVarInt62(&frame->offset)) { + set_detailed_error("Unable to read crypto data offset."); + return false; + } + uint64_t len; + if (!reader->ReadVarInt62(&len) || + len > std::numeric_limits::max()) { + set_detailed_error("Invalid data length."); + return false; + } + frame->data_length = len; + + // TODO(ianswett): Don't use absl::string_view as an intermediary. + absl::string_view data; + if (!reader->ReadStringPiece(&data, frame->data_length)) { + set_detailed_error("Unable to read frame data."); + return false; + } + frame->data_buffer = data.data(); + return true; +} + +bool QuicFramer::ProcessAckFrequencyFrame(QuicDataReader* reader, + QuicAckFrequencyFrame* frame) { + if (!reader->ReadVarInt62(&frame->sequence_number)) { + set_detailed_error("Unable to read sequence number."); + return false; + } + + if (!reader->ReadVarInt62(&frame->packet_tolerance)) { + set_detailed_error("Unable to read packet tolerance."); + return false; + } + if (frame->packet_tolerance == 0) { + set_detailed_error("Invalid packet tolerance."); + return false; + } + uint64_t max_ack_delay_us; + if (!reader->ReadVarInt62(&max_ack_delay_us)) { + set_detailed_error("Unable to read max_ack_delay_us."); + return false; + } + constexpr uint64_t kMaxAckDelayUsBound = 1u << 24; + if (max_ack_delay_us > kMaxAckDelayUsBound) { + set_detailed_error("Invalid max_ack_delay_us."); + return false; + } + frame->max_ack_delay = QuicTime::Delta::FromMicroseconds(max_ack_delay_us); + + uint8_t ignore_order; + if (!reader->ReadUInt8(&ignore_order)) { + set_detailed_error("Unable to read ignore_order."); + return false; + } + if (ignore_order > 1) { + set_detailed_error("Invalid ignore_order."); + return false; + } + frame->ignore_order = ignore_order; + + return true; +} + +bool QuicFramer::ProcessAckFrame(QuicDataReader* reader, uint8_t frame_type) { + const bool has_ack_blocks = + ExtractBit(frame_type, kQuicHasMultipleAckBlocksOffset); + uint8_t num_ack_blocks = 0; + uint8_t num_received_packets = 0; + + // Determine the two lengths from the frame type: largest acked length, + // ack block length. + const QuicPacketNumberLength ack_block_length = + ReadAckPacketNumberLength(ExtractBits( + frame_type, kQuicSequenceNumberLengthNumBits, kActBlockLengthOffset)); + const QuicPacketNumberLength largest_acked_length = + ReadAckPacketNumberLength(ExtractBits( + frame_type, kQuicSequenceNumberLengthNumBits, kLargestAckedOffset)); + + uint64_t largest_acked; + if (!reader->ReadBytesToUInt64(largest_acked_length, &largest_acked)) { + set_detailed_error("Unable to read largest acked."); + return false; + } + + if (largest_acked < first_sending_packet_number_.ToUint64()) { + // Connection always sends packet starting from kFirstSendingPacketNumber > + // 0, peer has observed an unsent packet. + set_detailed_error("Largest acked is 0."); + return false; + } + + uint64_t ack_delay_time_us; + if (!reader->ReadUFloat16(&ack_delay_time_us)) { + set_detailed_error("Unable to read ack delay time."); + return false; + } + + if (!visitor_->OnAckFrameStart( + QuicPacketNumber(largest_acked), + ack_delay_time_us == kUFloat16MaxValue + ? QuicTime::Delta::Infinite() + : QuicTime::Delta::FromMicroseconds(ack_delay_time_us))) { + // The visitor suppresses further processing of the packet. Although this is + // not a parsing error, returns false as this is in middle of processing an + // ack frame, + set_detailed_error("Visitor suppresses further processing of ack frame."); + return false; + } + + if (has_ack_blocks && !reader->ReadUInt8(&num_ack_blocks)) { + set_detailed_error("Unable to read num of ack blocks."); + return false; + } + + uint64_t first_block_length; + if (!reader->ReadBytesToUInt64(ack_block_length, &first_block_length)) { + set_detailed_error("Unable to read first ack block length."); + return false; + } + + if (first_block_length == 0) { + set_detailed_error("First block length is zero."); + return false; + } + bool first_ack_block_underflow = first_block_length > largest_acked + 1; + if (first_block_length + first_sending_packet_number_.ToUint64() > + largest_acked + 1) { + first_ack_block_underflow = true; + } + if (first_ack_block_underflow) { + set_detailed_error(absl::StrCat("Underflow with first ack block length ", + first_block_length, " largest acked is ", + largest_acked, ".") + .c_str()); + return false; + } + + uint64_t first_received = largest_acked + 1 - first_block_length; + if (!visitor_->OnAckRange(QuicPacketNumber(first_received), + QuicPacketNumber(largest_acked + 1))) { + // The visitor suppresses further processing of the packet. Although + // this is not a parsing error, returns false as this is in middle + // of processing an ack frame, + set_detailed_error("Visitor suppresses further processing of ack frame."); + return false; + } + + if (num_ack_blocks > 0) { + for (size_t i = 0; i < num_ack_blocks; ++i) { + uint8_t gap = 0; + if (!reader->ReadUInt8(&gap)) { + set_detailed_error("Unable to read gap to next ack block."); + return false; + } + uint64_t current_block_length; + if (!reader->ReadBytesToUInt64(ack_block_length, ¤t_block_length)) { + set_detailed_error("Unable to ack block length."); + return false; + } + bool ack_block_underflow = first_received < gap + current_block_length; + if (first_received < gap + current_block_length + + first_sending_packet_number_.ToUint64()) { + ack_block_underflow = true; + } + if (ack_block_underflow) { + set_detailed_error(absl::StrCat("Underflow with ack block length ", + current_block_length, + ", end of block is ", + first_received - gap, ".") + .c_str()); + return false; + } + + first_received -= (gap + current_block_length); + if (current_block_length > 0) { + if (!visitor_->OnAckRange( + QuicPacketNumber(first_received), + QuicPacketNumber(first_received) + current_block_length)) { + // The visitor suppresses further processing of the packet. Although + // this is not a parsing error, returns false as this is in middle + // of processing an ack frame, + set_detailed_error( + "Visitor suppresses further processing of ack frame."); + return false; + } + } + } + } + + if (!reader->ReadUInt8(&num_received_packets)) { + set_detailed_error("Unable to read num received packets."); + return false; + } + + if (!ProcessTimestampsInAckFrame(num_received_packets, + QuicPacketNumber(largest_acked), reader)) { + return false; + } + + // Done processing the ACK frame. + absl::optional ecn_counts = absl::nullopt; + if (!visitor_->OnAckFrameEnd(QuicPacketNumber(first_received), ecn_counts)) { + set_detailed_error( + "Error occurs when visitor finishes processing the ACK frame."); + return false; + } + + return true; +} + +bool QuicFramer::ProcessTimestampsInAckFrame(uint8_t num_received_packets, + QuicPacketNumber largest_acked, + QuicDataReader* reader) { + if (num_received_packets == 0) { + return true; + } + uint8_t delta_from_largest_observed; + if (!reader->ReadUInt8(&delta_from_largest_observed)) { + set_detailed_error("Unable to read sequence delta in received packets."); + return false; + } + + if (largest_acked.ToUint64() <= delta_from_largest_observed) { + set_detailed_error( + absl::StrCat("delta_from_largest_observed too high: ", + delta_from_largest_observed, + ", largest_acked: ", largest_acked.ToUint64()) + .c_str()); + return false; + } + + // Time delta from the framer creation. + uint32_t time_delta_us; + if (!reader->ReadUInt32(&time_delta_us)) { + set_detailed_error("Unable to read time delta in received packets."); + return false; + } + + QuicPacketNumber seq_num = largest_acked - delta_from_largest_observed; + if (process_timestamps_) { + last_timestamp_ = CalculateTimestampFromWire(time_delta_us); + + visitor_->OnAckTimestamp(seq_num, creation_time_ + last_timestamp_); + } + + for (uint8_t i = 1; i < num_received_packets; ++i) { + if (!reader->ReadUInt8(&delta_from_largest_observed)) { + set_detailed_error("Unable to read sequence delta in received packets."); + return false; + } + if (largest_acked.ToUint64() <= delta_from_largest_observed) { + set_detailed_error( + absl::StrCat("delta_from_largest_observed too high: ", + delta_from_largest_observed, + ", largest_acked: ", largest_acked.ToUint64()) + .c_str()); + return false; + } + seq_num = largest_acked - delta_from_largest_observed; + + // Time delta from the previous timestamp. + uint64_t incremental_time_delta_us; + if (!reader->ReadUFloat16(&incremental_time_delta_us)) { + set_detailed_error( + "Unable to read incremental time delta in received packets."); + return false; + } + + if (process_timestamps_) { + last_timestamp_ = last_timestamp_ + QuicTime::Delta::FromMicroseconds( + incremental_time_delta_us); + visitor_->OnAckTimestamp(seq_num, creation_time_ + last_timestamp_); + } + } + return true; +} + +bool QuicFramer::ProcessIetfAckFrame(QuicDataReader* reader, + uint64_t frame_type, + QuicAckFrame* ack_frame) { + uint64_t largest_acked; + if (!reader->ReadVarInt62(&largest_acked)) { + set_detailed_error("Unable to read largest acked."); + return false; + } + if (largest_acked < first_sending_packet_number_.ToUint64()) { + // Connection always sends packet starting from kFirstSendingPacketNumber > + // 0, peer has observed an unsent packet. + set_detailed_error("Largest acked is 0."); + return false; + } + ack_frame->largest_acked = static_cast(largest_acked); + uint64_t ack_delay_time_in_us; + if (!reader->ReadVarInt62(&ack_delay_time_in_us)) { + set_detailed_error("Unable to read ack delay time."); + return false; + } + + if (ack_delay_time_in_us >= + (quiche::kVarInt62MaxValue >> peer_ack_delay_exponent_)) { + ack_frame->ack_delay_time = QuicTime::Delta::Infinite(); + } else { + ack_delay_time_in_us = (ack_delay_time_in_us << peer_ack_delay_exponent_); + ack_frame->ack_delay_time = + QuicTime::Delta::FromMicroseconds(ack_delay_time_in_us); + } + if (!visitor_->OnAckFrameStart(QuicPacketNumber(largest_acked), + ack_frame->ack_delay_time)) { + // The visitor suppresses further processing of the packet. Although this is + // not a parsing error, returns false as this is in middle of processing an + // ACK frame. + set_detailed_error("Visitor suppresses further processing of ACK frame."); + return false; + } + + // Get number of ACK blocks from the packet. + uint64_t ack_block_count; + if (!reader->ReadVarInt62(&ack_block_count)) { + set_detailed_error("Unable to read ack block count."); + return false; + } + // There always is a first ACK block, which is the (number of packets being + // acked)-1, up to and including the packet at largest_acked. Therefore if the + // value is 0, then only largest is acked. If it is 1, then largest-1, + // largest] are acked, etc + uint64_t ack_block_value; + if (!reader->ReadVarInt62(&ack_block_value)) { + set_detailed_error("Unable to read first ack block length."); + return false; + } + // Calculate the packets being acked in the first block. + // +1 because AddRange implementation requires [low,high) + uint64_t block_high = largest_acked + 1; + uint64_t block_low = largest_acked - ack_block_value; + + // ack_block_value is the number of packets preceding the + // largest_acked packet which are in the block being acked. Thus, + // its maximum value is largest_acked-1. Test this, reporting an + // error if the value is wrong. + if (ack_block_value + first_sending_packet_number_.ToUint64() > + largest_acked) { + set_detailed_error(absl::StrCat("Underflow with first ack block length ", + ack_block_value + 1, " largest acked is ", + largest_acked, ".") + .c_str()); + return false; + } + + if (!visitor_->OnAckRange(QuicPacketNumber(block_low), + QuicPacketNumber(block_high))) { + // The visitor suppresses further processing of the packet. Although + // this is not a parsing error, returns false as this is in middle + // of processing an ACK frame. + set_detailed_error("Visitor suppresses further processing of ACK frame."); + return false; + } + + while (ack_block_count != 0) { + uint64_t gap_block_value; + // Get the sizes of the gap and ack blocks, + if (!reader->ReadVarInt62(&gap_block_value)) { + set_detailed_error("Unable to read gap block value."); + return false; + } + // It's an error if the gap is larger than the space from packet + // number 0 to the start of the block that's just been acked, PLUS + // there must be space for at least 1 packet to be acked. For + // example, if block_low is 10 and gap_block_value is 9, it means + // the gap block is 10 packets long, leaving no room for a packet + // to be acked. Thus, gap_block_value+2 can not be larger than + // block_low. + // The test is written this way to detect wrap-arounds. + if ((gap_block_value + 2) > block_low) { + set_detailed_error( + absl::StrCat("Underflow with gap block length ", gap_block_value + 1, + " previous ack block start is ", block_low, ".") + .c_str()); + return false; + } + + // Adjust block_high to be the top of the next ack block. + // There is a gap of |gap_block_value| packets between the bottom + // of ack block N and top of block N+1. Note that gap_block_value + // is he size of the gap minus 1 (per the QUIC protocol), and + // block_high is the packet number of the first packet of the gap + // (per the implementation of OnAckRange/AddAckRange, below). + block_high = block_low - 1 - gap_block_value; + + if (!reader->ReadVarInt62(&ack_block_value)) { + set_detailed_error("Unable to read ack block value."); + return false; + } + if (ack_block_value + first_sending_packet_number_.ToUint64() > + (block_high - 1)) { + set_detailed_error( + absl::StrCat("Underflow with ack block length ", ack_block_value + 1, + " latest ack block end is ", block_high - 1, ".") + .c_str()); + return false; + } + // Calculate the low end of the new nth ack block. The +1 is + // because the encoded value is the blocksize-1. + block_low = block_high - 1 - ack_block_value; + if (!visitor_->OnAckRange(QuicPacketNumber(block_low), + QuicPacketNumber(block_high))) { + // The visitor suppresses further processing of the packet. Although + // this is not a parsing error, returns false as this is in middle + // of processing an ACK frame. + set_detailed_error("Visitor suppresses further processing of ACK frame."); + return false; + } + + // Another one done. + ack_block_count--; + } + + QUICHE_DCHECK(!ack_frame->ecn_counters.has_value()); + if (frame_type == IETF_ACK_RECEIVE_TIMESTAMPS) { + QUICHE_DCHECK(process_timestamps_); + if (!ProcessIetfTimestampsInAckFrame(ack_frame->largest_acked, reader)) { + return false; + } + } else if (frame_type == IETF_ACK_ECN) { + ack_frame->ecn_counters = QuicEcnCounts(); + if (!reader->ReadVarInt62(&ack_frame->ecn_counters->ect0)) { + set_detailed_error("Unable to read ack ect_0_count."); + return false; + } + if (!reader->ReadVarInt62(&ack_frame->ecn_counters->ect1)) { + set_detailed_error("Unable to read ack ect_1_count."); + return false; + } + if (!reader->ReadVarInt62(&ack_frame->ecn_counters->ce)) { + set_detailed_error("Unable to read ack ecn_ce_count."); + return false; + } + if (GetQuicRestartFlag(quic_receive_ecn)) { + QUIC_RESTART_FLAG_COUNT_N(quic_receive_ecn, 2, 3); + } + } + + if (!visitor_->OnAckFrameEnd(QuicPacketNumber(block_low), + ack_frame->ecn_counters)) { + set_detailed_error( + "Error occurs when visitor finishes processing the ACK frame."); + return false; + } + + return true; +} + +bool QuicFramer::ProcessIetfTimestampsInAckFrame(QuicPacketNumber largest_acked, + QuicDataReader* reader) { + uint64_t timestamp_range_count; + if (!reader->ReadVarInt62(×tamp_range_count)) { + set_detailed_error("Unable to read receive timestamp range count."); + return false; + } + if (timestamp_range_count == 0) { + return true; + } + + QuicPacketNumber packet_number = largest_acked; + + // Iterate through all timestamp ranges, each of which represents a block of + // contiguous packets for which receive timestamps are being reported. Each + // range is of the form: + // + // Timestamp Range { + // Gap (i), + // Timestamp Delta Count (i), + // Timestamp Delta (i) ..., + // } + for (uint64_t i = 0; i < timestamp_range_count; i++) { + uint64_t gap; + if (!reader->ReadVarInt62(&gap)) { + set_detailed_error("Unable to read receive timestamp gap."); + return false; + } + if (packet_number.ToUint64() < gap) { + set_detailed_error("Receive timestamp gap too high."); + return false; + } + packet_number = packet_number - gap; + uint64_t timestamp_count; + if (!reader->ReadVarInt62(×tamp_count)) { + set_detailed_error("Unable to read receive timestamp count."); + return false; + } + if (packet_number.ToUint64() < timestamp_count) { + set_detailed_error("Receive timestamp count too high."); + return false; + } + for (uint64_t j = 0; j < timestamp_count; j++) { + uint64_t timestamp_delta; + if (!reader->ReadVarInt62(×tamp_delta)) { + set_detailed_error("Unable to read receive timestamp delta."); + return false; + } + // The first timestamp delta is relative to framer creation time; whereas + // subsequent deltas are relative to the previous delta in decreasing + // packet order. + timestamp_delta = timestamp_delta << receive_timestamps_exponent_; + if (i == 0 && j == 0) { + last_timestamp_ = QuicTime::Delta::FromMicroseconds(timestamp_delta); + } else { + last_timestamp_ = last_timestamp_ - + QuicTime::Delta::FromMicroseconds(timestamp_delta); + if (last_timestamp_ < QuicTime::Delta::Zero()) { + set_detailed_error("Receive timestamp delta too high."); + return false; + } + } + visitor_->OnAckTimestamp(packet_number, creation_time_ + last_timestamp_); + packet_number--; + } + packet_number--; + } + return true; +} + +bool QuicFramer::ProcessStopWaitingFrame(QuicDataReader* reader, + const QuicPacketHeader& header, + QuicStopWaitingFrame* stop_waiting) { + uint64_t least_unacked_delta; + if (!reader->ReadBytesToUInt64(header.packet_number_length, + &least_unacked_delta)) { + set_detailed_error("Unable to read least unacked delta."); + return false; + } + if (header.packet_number.ToUint64() <= least_unacked_delta) { + set_detailed_error("Invalid unacked delta."); + return false; + } + stop_waiting->least_unacked = header.packet_number - least_unacked_delta; + + return true; +} + +bool QuicFramer::ProcessRstStreamFrame(QuicDataReader* reader, + QuicRstStreamFrame* frame) { + if (!reader->ReadUInt32(&frame->stream_id)) { + set_detailed_error("Unable to read stream_id."); + return false; + } + + if (!reader->ReadUInt64(&frame->byte_offset)) { + set_detailed_error("Unable to read rst stream sent byte offset."); + return false; + } + + uint32_t error_code; + if (!reader->ReadUInt32(&error_code)) { + set_detailed_error("Unable to read rst stream error code."); + return false; + } + + if (error_code >= QUIC_STREAM_LAST_ERROR) { + // Ignore invalid stream error code if any. + error_code = QUIC_STREAM_LAST_ERROR; + } + + frame->error_code = static_cast(error_code); + + return true; +} + +bool QuicFramer::ProcessConnectionCloseFrame(QuicDataReader* reader, + QuicConnectionCloseFrame* frame) { + uint32_t error_code; + frame->close_type = GOOGLE_QUIC_CONNECTION_CLOSE; + + if (!reader->ReadUInt32(&error_code)) { + set_detailed_error("Unable to read connection close error code."); + return false; + } + + // For Google QUIC connection closes, |wire_error_code| and |quic_error_code| + // must have the same value. + frame->wire_error_code = error_code; + frame->quic_error_code = static_cast(error_code); + + absl::string_view error_details; + if (!reader->ReadStringPiece16(&error_details)) { + set_detailed_error("Unable to read connection close error details."); + return false; + } + frame->error_details = std::string(error_details); + + return true; +} + +bool QuicFramer::ProcessGoAwayFrame(QuicDataReader* reader, + QuicGoAwayFrame* frame) { + uint32_t error_code; + if (!reader->ReadUInt32(&error_code)) { + set_detailed_error("Unable to read go away error code."); + return false; + } + + frame->error_code = static_cast(error_code); + + uint32_t stream_id; + if (!reader->ReadUInt32(&stream_id)) { + set_detailed_error("Unable to read last good stream id."); + return false; + } + frame->last_good_stream_id = static_cast(stream_id); + + absl::string_view reason_phrase; + if (!reader->ReadStringPiece16(&reason_phrase)) { + set_detailed_error("Unable to read goaway reason."); + return false; + } + frame->reason_phrase = std::string(reason_phrase); + + return true; +} + +bool QuicFramer::ProcessWindowUpdateFrame(QuicDataReader* reader, + QuicWindowUpdateFrame* frame) { + if (!reader->ReadUInt32(&frame->stream_id)) { + set_detailed_error("Unable to read stream_id."); + return false; + } + + if (!reader->ReadUInt64(&frame->max_data)) { + set_detailed_error("Unable to read window byte_offset."); + return false; + } + + return true; +} + +bool QuicFramer::ProcessBlockedFrame(QuicDataReader* reader, + QuicBlockedFrame* frame) { + QUICHE_DCHECK(!VersionHasIetfQuicFrames(version_.transport_version)) + << "Attempt to process non-IETF QUIC frames in an IETF QUIC version."; + + if (!reader->ReadUInt32(&frame->stream_id)) { + set_detailed_error("Unable to read stream_id."); + return false; + } + + return true; +} + +void QuicFramer::ProcessPaddingFrame(QuicDataReader* reader, + QuicPaddingFrame* frame) { + // Type byte has been read. + frame->num_padding_bytes = 1; + uint8_t next_byte; + while (!reader->IsDoneReading() && reader->PeekByte() == 0x00) { + reader->ReadBytes(&next_byte, 1); + QUICHE_DCHECK_EQ(0x00, next_byte); + ++frame->num_padding_bytes; + } +} + +bool QuicFramer::ProcessMessageFrame(QuicDataReader* reader, + bool no_message_length, + QuicMessageFrame* frame) { + if (no_message_length) { + absl::string_view remaining(reader->ReadRemainingPayload()); + frame->data = remaining.data(); + frame->message_length = remaining.length(); + return true; + } + + uint64_t message_length; + if (!reader->ReadVarInt62(&message_length)) { + set_detailed_error("Unable to read message length"); + return false; + } + + absl::string_view message_piece; + if (!reader->ReadStringPiece(&message_piece, message_length)) { + set_detailed_error("Unable to read message data"); + return false; + } + + frame->data = message_piece.data(); + frame->message_length = message_length; + + return true; +} + +// static +absl::string_view QuicFramer::GetAssociatedDataFromEncryptedPacket( + QuicTransportVersion version, const QuicEncryptedPacket& encrypted, + uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool includes_version, + bool includes_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + uint64_t retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length) { + // TODO(ianswett): This is identical to QuicData::AssociatedData. + return absl::string_view( + encrypted.data(), + GetStartOfEncryptedData(version, destination_connection_id_length, + source_connection_id_length, includes_version, + includes_diversification_nonce, + packet_number_length, retry_token_length_length, + retry_token_length, length_length)); +} + +void QuicFramer::SetDecrypter(EncryptionLevel level, + std::unique_ptr decrypter) { + QUICHE_DCHECK_GE(level, decrypter_level_); + QUICHE_DCHECK(!version_.KnowsWhichDecrypterToUse()); + QUIC_DVLOG(1) << ENDPOINT << "Setting decrypter from level " + << decrypter_level_ << " to " << level; + decrypter_[decrypter_level_] = nullptr; + decrypter_[level] = std::move(decrypter); + decrypter_level_ = level; +} + +void QuicFramer::SetAlternativeDecrypter( + EncryptionLevel level, std::unique_ptr decrypter, + bool latch_once_used) { + QUICHE_DCHECK_NE(level, decrypter_level_); + QUICHE_DCHECK(!version_.KnowsWhichDecrypterToUse()); + QUIC_DVLOG(1) << ENDPOINT << "Setting alternative decrypter from level " + << alternative_decrypter_level_ << " to " << level; + if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) { + decrypter_[alternative_decrypter_level_] = nullptr; + } + decrypter_[level] = std::move(decrypter); + alternative_decrypter_level_ = level; + alternative_decrypter_latch_ = latch_once_used; +} + +void QuicFramer::InstallDecrypter(EncryptionLevel level, + std::unique_ptr decrypter) { + QUICHE_DCHECK(version_.KnowsWhichDecrypterToUse()); + QUIC_DVLOG(1) << ENDPOINT << "Installing decrypter at level " << level; + decrypter_[level] = std::move(decrypter); +} + +void QuicFramer::RemoveDecrypter(EncryptionLevel level) { + QUICHE_DCHECK(version_.KnowsWhichDecrypterToUse()); + QUIC_DVLOG(1) << ENDPOINT << "Removing decrypter at level " << level; + decrypter_[level] = nullptr; +} + +void QuicFramer::SetKeyUpdateSupportForConnection(bool enabled) { + QUIC_DVLOG(1) << ENDPOINT << "SetKeyUpdateSupportForConnection: " << enabled; + support_key_update_for_connection_ = enabled; +} + +void QuicFramer::DiscardPreviousOneRttKeys() { + QUICHE_DCHECK(support_key_update_for_connection_); + QUIC_DVLOG(1) << ENDPOINT << "Discarding previous set of 1-RTT keys"; + previous_decrypter_ = nullptr; +} + +bool QuicFramer::DoKeyUpdate(KeyUpdateReason reason) { + QUICHE_DCHECK(support_key_update_for_connection_); + if (!next_decrypter_) { + // If key update is locally initiated, next decrypter might not be created + // yet. + next_decrypter_ = visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); + } + std::unique_ptr next_encrypter = + visitor_->CreateCurrentOneRttEncrypter(); + if (!next_decrypter_ || !next_encrypter) { + QUIC_BUG(quic_bug_10850_58) << "Failed to create next crypters"; + return false; + } + key_update_performed_ = true; + current_key_phase_bit_ = !current_key_phase_bit_; + QUIC_DLOG(INFO) << ENDPOINT << "DoKeyUpdate: new current_key_phase_bit_=" + << current_key_phase_bit_; + current_key_phase_first_received_packet_number_.Clear(); + previous_decrypter_ = std::move(decrypter_[ENCRYPTION_FORWARD_SECURE]); + decrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_decrypter_); + encrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_encrypter); + switch (reason) { + case KeyUpdateReason::kInvalid: + QUIC_CODE_COUNT(quic_key_update_invalid); + break; + case KeyUpdateReason::kRemote: + QUIC_CODE_COUNT(quic_key_update_remote); + break; + case KeyUpdateReason::kLocalForTests: + QUIC_CODE_COUNT(quic_key_update_local_for_tests); + break; + case KeyUpdateReason::kLocalForInteropRunner: + QUIC_CODE_COUNT(quic_key_update_local_for_interop_runner); + break; + case KeyUpdateReason::kLocalAeadConfidentialityLimit: + QUIC_CODE_COUNT(quic_key_update_local_aead_confidentiality_limit); + break; + case KeyUpdateReason::kLocalKeyUpdateLimitOverride: + QUIC_CODE_COUNT(quic_key_update_local_limit_override); + break; + } + visitor_->OnKeyUpdate(reason); + return true; +} + +QuicPacketCount QuicFramer::PotentialPeerKeyUpdateAttemptCount() const { + return potential_peer_key_update_attempt_count_; +} + +const QuicDecrypter* QuicFramer::GetDecrypter(EncryptionLevel level) const { + QUICHE_DCHECK(version_.KnowsWhichDecrypterToUse()); + return decrypter_[level].get(); +} + +const QuicDecrypter* QuicFramer::decrypter() const { + return decrypter_[decrypter_level_].get(); +} + +const QuicDecrypter* QuicFramer::alternative_decrypter() const { + if (alternative_decrypter_level_ == NUM_ENCRYPTION_LEVELS) { + return nullptr; + } + return decrypter_[alternative_decrypter_level_].get(); +} + +void QuicFramer::SetEncrypter(EncryptionLevel level, + std::unique_ptr encrypter) { + QUICHE_DCHECK_GE(level, 0); + QUICHE_DCHECK_LT(level, NUM_ENCRYPTION_LEVELS); + QUIC_DVLOG(1) << ENDPOINT << "Setting encrypter at level " << level; + encrypter_[level] = std::move(encrypter); +} + +void QuicFramer::RemoveEncrypter(EncryptionLevel level) { + QUIC_DVLOG(1) << ENDPOINT << "Removing encrypter of " << level; + encrypter_[level] = nullptr; +} + +void QuicFramer::SetInitialObfuscators(QuicConnectionId connection_id) { + CrypterPair crypters; + CryptoUtils::CreateInitialObfuscators(perspective_, version_, connection_id, + &crypters); + encrypter_[ENCRYPTION_INITIAL] = std::move(crypters.encrypter); + decrypter_[ENCRYPTION_INITIAL] = std::move(crypters.decrypter); +} + +size_t QuicFramer::EncryptInPlace(EncryptionLevel level, + QuicPacketNumber packet_number, size_t ad_len, + size_t total_len, size_t buffer_len, + char* buffer) { + QUICHE_DCHECK(packet_number.IsInitialized()); + if (encrypter_[level] == nullptr) { + QUIC_BUG(quic_bug_10850_59) + << ENDPOINT + << "Attempted to encrypt in place without encrypter at level " << level; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + + size_t output_length = 0; + if (!encrypter_[level]->EncryptPacket( + packet_number.ToUint64(), + absl::string_view(buffer, ad_len), // Associated data + absl::string_view(buffer + ad_len, + total_len - ad_len), // Plaintext + buffer + ad_len, // Destination buffer + &output_length, buffer_len - ad_len)) { + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + if (version_.HasHeaderProtection() && + !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) { + QUIC_DLOG(ERROR) << "Applying header protection failed."; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + + return ad_len + output_length; +} + +namespace { + +const size_t kHPSampleLen = 16; + +constexpr bool IsLongHeader(uint8_t type_byte) { + return (type_byte & FLAGS_LONG_HEADER) != 0; +} + +} // namespace + +bool QuicFramer::ApplyHeaderProtection(EncryptionLevel level, char* buffer, + size_t buffer_len, size_t ad_len) { + QuicDataReader buffer_reader(buffer, buffer_len); + QuicDataWriter buffer_writer(buffer_len, buffer); + // The sample starts 4 bytes after the start of the packet number. + if (ad_len < last_written_packet_number_length_) { + return false; + } + size_t pn_offset = ad_len - last_written_packet_number_length_; + // Sample the ciphertext and generate the mask to use for header protection. + size_t sample_offset = pn_offset + 4; + QuicDataReader sample_reader(buffer, buffer_len); + absl::string_view sample; + if (!sample_reader.Seek(sample_offset) || + !sample_reader.ReadStringPiece(&sample, kHPSampleLen)) { + QUIC_BUG(quic_bug_10850_60) + << "Not enough bytes to sample: sample_offset " << sample_offset + << ", sample len: " << kHPSampleLen << ", buffer len: " << buffer_len; + return false; + } + + if (encrypter_[level] == nullptr) { + QUIC_BUG(quic_bug_12975_8) + << ENDPOINT + << "Attempted to apply header protection without encrypter at level " + << level << " using " << version_; + return false; + } + + std::string mask = encrypter_[level]->GenerateHeaderProtectionMask(sample); + if (mask.empty()) { + QUIC_BUG(quic_bug_10850_61) << "Unable to generate header protection mask."; + return false; + } + QuicDataReader mask_reader(mask.data(), mask.size()); + + // Apply the mask to the 4 or 5 least significant bits of the first byte. + uint8_t bitmask = 0x1f; + uint8_t type_byte; + if (!buffer_reader.ReadUInt8(&type_byte)) { + return false; + } + QuicLongHeaderType header_type; + if (IsLongHeader(type_byte)) { + bitmask = 0x0f; + header_type = GetLongHeaderType(type_byte, version_); + if (header_type == INVALID_PACKET_TYPE) { + return false; + } + } + uint8_t mask_byte; + if (!mask_reader.ReadUInt8(&mask_byte) || + !buffer_writer.WriteUInt8(type_byte ^ (mask_byte & bitmask))) { + return false; + } + + // Adjust |pn_offset| to account for the diversification nonce. + if (IsLongHeader(type_byte) && header_type == ZERO_RTT_PROTECTED && + perspective_ == Perspective::IS_SERVER && + version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + if (pn_offset <= kDiversificationNonceSize) { + QUIC_BUG(quic_bug_10850_62) + << "Expected diversification nonce, but not enough bytes"; + return false; + } + pn_offset -= kDiversificationNonceSize; + } + // Advance the reader and writer to the packet number. Both the reader and + // writer have each read/written one byte. + if (!buffer_writer.Seek(pn_offset - 1) || + !buffer_reader.Seek(pn_offset - 1)) { + return false; + } + // Apply the rest of the mask to the packet number. + for (size_t i = 0; i < last_written_packet_number_length_; ++i) { + uint8_t buffer_byte; + uint8_t pn_mask_byte; + if (!mask_reader.ReadUInt8(&pn_mask_byte) || + !buffer_reader.ReadUInt8(&buffer_byte) || + !buffer_writer.WriteUInt8(buffer_byte ^ pn_mask_byte)) { + return false; + } + } + return true; +} + +bool QuicFramer::RemoveHeaderProtection(QuicDataReader* reader, + const QuicEncryptedPacket& packet, + QuicPacketHeader* header, + uint64_t* full_packet_number, + std::vector* associated_data) { + EncryptionLevel expected_decryption_level = GetEncryptionLevel(*header); + QuicDecrypter* decrypter = decrypter_[expected_decryption_level].get(); + if (decrypter == nullptr) { + QUIC_DVLOG(1) + << ENDPOINT + << "No decrypter available for removing header protection at level " + << expected_decryption_level; + return false; + } + + bool has_diversification_nonce = + header->form == IETF_QUIC_LONG_HEADER_PACKET && + header->long_packet_type == ZERO_RTT_PROTECTED && + perspective_ == Perspective::IS_CLIENT && + version_.handshake_protocol == PROTOCOL_QUIC_CRYPTO; + + // Read a sample from the ciphertext and compute the mask to use for header + // protection. + absl::string_view remaining_packet = reader->PeekRemainingPayload(); + QuicDataReader sample_reader(remaining_packet); + + // The sample starts 4 bytes after the start of the packet number. + absl::string_view pn; + if (!sample_reader.ReadStringPiece(&pn, 4)) { + QUIC_DVLOG(1) << "Not enough data to sample"; + return false; + } + if (has_diversification_nonce) { + // In Google QUIC, the diversification nonce comes between the packet number + // and the sample. + if (!sample_reader.Seek(kDiversificationNonceSize)) { + QUIC_DVLOG(1) << "No diversification nonce to skip over"; + return false; + } + } + std::string mask = decrypter->GenerateHeaderProtectionMask(&sample_reader); + QuicDataReader mask_reader(mask.data(), mask.size()); + if (mask.empty()) { + QUIC_DVLOG(1) << "Failed to compute mask"; + return false; + } + + // Unmask the rest of the type byte. + uint8_t bitmask = 0x1f; + if (IsLongHeader(header->type_byte)) { + bitmask = 0x0f; + } + uint8_t mask_byte; + if (!mask_reader.ReadUInt8(&mask_byte)) { + QUIC_DVLOG(1) << "No first byte to read from mask"; + return false; + } + header->type_byte ^= (mask_byte & bitmask); + + // Compute the packet number length. + header->packet_number_length = + static_cast((header->type_byte & 0x03) + 1); + + char pn_buffer[IETF_MAX_PACKET_NUMBER_LENGTH] = {}; + QuicDataWriter pn_writer(ABSL_ARRAYSIZE(pn_buffer), pn_buffer); + + // Read the (protected) packet number from the reader and unmask the packet + // number. + for (size_t i = 0; i < header->packet_number_length; ++i) { + uint8_t protected_pn_byte, pn_mask_byte; + if (!mask_reader.ReadUInt8(&pn_mask_byte) || + !reader->ReadUInt8(&protected_pn_byte) || + !pn_writer.WriteUInt8(protected_pn_byte ^ pn_mask_byte)) { + QUIC_DVLOG(1) << "Failed to unmask packet number"; + return false; + } + } + QuicDataReader packet_number_reader(pn_writer.data(), pn_writer.length()); + QuicPacketNumber base_packet_number; + if (supports_multiple_packet_number_spaces_) { + PacketNumberSpace pn_space = GetPacketNumberSpace(*header); + if (pn_space == NUM_PACKET_NUMBER_SPACES) { + return false; + } + base_packet_number = largest_decrypted_packet_numbers_[pn_space]; + } else { + base_packet_number = largest_packet_number_; + } + if (!ProcessAndCalculatePacketNumber( + &packet_number_reader, header->packet_number_length, + base_packet_number, full_packet_number)) { + return false; + } + + // Get the associated data, and apply the same unmasking operations to it. + absl::string_view ad = GetAssociatedDataFromEncryptedPacket( + version_.transport_version, packet, + GetIncludedDestinationConnectionIdLength(*header), + GetIncludedSourceConnectionIdLength(*header), header->version_flag, + has_diversification_nonce, header->packet_number_length, + header->retry_token_length_length, header->retry_token.length(), + header->length_length); + *associated_data = std::vector(ad.begin(), ad.end()); + QuicDataWriter ad_writer(associated_data->size(), associated_data->data()); + + // Apply the unmasked type byte and packet number to |associated_data|. + if (!ad_writer.WriteUInt8(header->type_byte)) { + return false; + } + // Put the packet number at the end of the AD, or if there's a diversification + // nonce, before that (which is at the end of the AD). + size_t seek_len = ad_writer.remaining() - header->packet_number_length; + if (has_diversification_nonce) { + seek_len -= kDiversificationNonceSize; + } + if (!ad_writer.Seek(seek_len) || + !ad_writer.WriteBytes(pn_writer.data(), pn_writer.length())) { + QUIC_DVLOG(1) << "Failed to apply unmasking operations to AD"; + return false; + } + + return true; +} + +size_t QuicFramer::EncryptPayload(EncryptionLevel level, + QuicPacketNumber packet_number, + const QuicPacket& packet, char* buffer, + size_t buffer_len) { + QUICHE_DCHECK(packet_number.IsInitialized()); + if (encrypter_[level] == nullptr) { + QUIC_BUG(quic_bug_10850_63) + << ENDPOINT << "Attempted to encrypt without encrypter at level " + << level; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + + absl::string_view associated_data = + packet.AssociatedData(version_.transport_version); + // Copy in the header, because the encrypter only populates the encrypted + // plaintext content. + const size_t ad_len = associated_data.length(); + if (packet.length() < ad_len) { + QUIC_BUG(quic_bug_10850_64) + << ENDPOINT << "packet is shorter than associated data length. version:" + << version() << ", packet length:" << packet.length() + << ", associated data length:" << ad_len; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + memmove(buffer, associated_data.data(), ad_len); + // Encrypt the plaintext into the buffer. + size_t output_length = 0; + if (!encrypter_[level]->EncryptPacket( + packet_number.ToUint64(), associated_data, + packet.Plaintext(version_.transport_version), buffer + ad_len, + &output_length, buffer_len - ad_len)) { + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + if (version_.HasHeaderProtection() && + !ApplyHeaderProtection(level, buffer, ad_len + output_length, ad_len)) { + QUIC_DLOG(ERROR) << "Applying header protection failed."; + RaiseError(QUIC_ENCRYPTION_FAILURE); + return 0; + } + + return ad_len + output_length; +} + +size_t QuicFramer::GetCiphertextSize(EncryptionLevel level, + size_t plaintext_size) const { + if (encrypter_[level] == nullptr) { + QUIC_BUG(quic_bug_10850_65) + << ENDPOINT + << "Attempted to get ciphertext size without encrypter at level " + << level << " using " << version_; + return plaintext_size; + } + return encrypter_[level]->GetCiphertextSize(plaintext_size); +} + +size_t QuicFramer::GetMaxPlaintextSize(size_t ciphertext_size) { + // In order to keep the code simple, we don't have the current encryption + // level to hand. Both the NullEncrypter and AES-GCM have a tag length of 12. + size_t min_plaintext_size = ciphertext_size; + + for (int i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; i++) { + if (encrypter_[i] != nullptr) { + size_t size = encrypter_[i]->GetMaxPlaintextSize(ciphertext_size); + if (size < min_plaintext_size) { + min_plaintext_size = size; + } + } + } + + return min_plaintext_size; +} + +QuicPacketCount QuicFramer::GetOneRttEncrypterConfidentialityLimit() const { + if (!encrypter_[ENCRYPTION_FORWARD_SECURE]) { + QUIC_BUG(quic_bug_10850_66) << "1-RTT encrypter not set"; + return 0; + } + return encrypter_[ENCRYPTION_FORWARD_SECURE]->GetConfidentialityLimit(); +} + +bool QuicFramer::DecryptPayload(size_t udp_packet_length, + absl::string_view encrypted, + absl::string_view associated_data, + const QuicPacketHeader& header, + char* decrypted_buffer, size_t buffer_length, + size_t* decrypted_length, + EncryptionLevel* decrypted_level) { + if (!EncryptionLevelIsValid(decrypter_level_)) { + QUIC_BUG(quic_bug_10850_67) + << "Attempted to decrypt with bad decrypter_level_"; + return false; + } + EncryptionLevel level = decrypter_level_; + QuicDecrypter* decrypter = decrypter_[level].get(); + QuicDecrypter* alternative_decrypter = nullptr; + bool key_phase_parsed = false; + bool key_phase; + bool attempt_key_update = false; + if (version().KnowsWhichDecrypterToUse()) { + if (header.form == GOOGLE_QUIC_PACKET) { + QUIC_BUG(quic_bug_10850_68) + << "Attempted to decrypt GOOGLE_QUIC_PACKET with a version that " + "knows which decrypter to use"; + return false; + } + level = GetEncryptionLevel(header); + if (!EncryptionLevelIsValid(level)) { + QUIC_BUG(quic_bug_10850_69) << "Attempted to decrypt with bad level"; + return false; + } + decrypter = decrypter_[level].get(); + if (decrypter == nullptr) { + return false; + } + if (level == ENCRYPTION_ZERO_RTT && + perspective_ == Perspective::IS_CLIENT && header.nonce != nullptr) { + decrypter->SetDiversificationNonce(*header.nonce); + } + if (support_key_update_for_connection_ && + header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + QUICHE_DCHECK(version().UsesTls()); + QUICHE_DCHECK_EQ(level, ENCRYPTION_FORWARD_SECURE); + key_phase = (header.type_byte & FLAGS_KEY_PHASE_BIT) != 0; + key_phase_parsed = true; + QUIC_DVLOG(1) << ENDPOINT << "packet " << header.packet_number + << " received key_phase=" << key_phase + << " current_key_phase_bit_=" << current_key_phase_bit_; + if (key_phase != current_key_phase_bit_) { + if ((current_key_phase_first_received_packet_number_.IsInitialized() && + header.packet_number > + current_key_phase_first_received_packet_number_) || + (!current_key_phase_first_received_packet_number_.IsInitialized() && + !key_update_performed_)) { + if (!next_decrypter_) { + next_decrypter_ = + visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); + if (!next_decrypter_) { + QUIC_BUG(quic_bug_10850_70) << "Failed to create next_decrypter"; + return false; + } + } + QUIC_DVLOG(1) << ENDPOINT << "packet " << header.packet_number + << " attempt_key_update=true"; + attempt_key_update = true; + potential_peer_key_update_attempt_count_++; + decrypter = next_decrypter_.get(); + } else { + if (previous_decrypter_) { + QUIC_DVLOG(1) << ENDPOINT + << "trying previous_decrypter_ for packet " + << header.packet_number; + decrypter = previous_decrypter_.get(); + } else { + QUIC_DVLOG(1) << ENDPOINT << "dropping packet " + << header.packet_number << " with old key phase"; + return false; + } + } + } + } + } else if (alternative_decrypter_level_ != NUM_ENCRYPTION_LEVELS) { + if (!EncryptionLevelIsValid(alternative_decrypter_level_)) { + QUIC_BUG(quic_bug_10850_71) + << "Attempted to decrypt with bad alternative_decrypter_level_"; + return false; + } + alternative_decrypter = decrypter_[alternative_decrypter_level_].get(); + } + + if (decrypter == nullptr) { + QUIC_BUG(quic_bug_10850_72) + << "Attempting to decrypt without decrypter, encryption level:" << level + << " version:" << version(); + return false; + } + + bool success = decrypter->DecryptPacket( + header.packet_number.ToUint64(), associated_data, encrypted, + decrypted_buffer, decrypted_length, buffer_length); + if (success) { + visitor_->OnDecryptedPacket(udp_packet_length, level); + if (level == ENCRYPTION_ZERO_RTT && + current_key_phase_first_received_packet_number_.IsInitialized() && + header.packet_number > + current_key_phase_first_received_packet_number_) { + set_detailed_error(absl::StrCat( + "Decrypted a 0-RTT packet with a packet number ", + header.packet_number.ToString(), + " which is higher than a 1-RTT packet number ", + current_key_phase_first_received_packet_number_.ToString())); + return RaiseError(QUIC_INVALID_0RTT_PACKET_NUMBER_OUT_OF_ORDER); + } + *decrypted_level = level; + potential_peer_key_update_attempt_count_ = 0; + if (attempt_key_update) { + if (!DoKeyUpdate(KeyUpdateReason::kRemote)) { + set_detailed_error("Key update failed due to internal error"); + return RaiseError(QUIC_INTERNAL_ERROR); + } + QUICHE_DCHECK_EQ(current_key_phase_bit_, key_phase); + } + if (key_phase_parsed && + !current_key_phase_first_received_packet_number_.IsInitialized() && + key_phase == current_key_phase_bit_) { + // Set packet number for current key phase if it hasn't been initialized + // yet. This is set outside of attempt_key_update since the key update + // may have been initiated locally, and in that case we don't know yet + // which packet number from the remote side to use until we receive a + // packet with that phase. + QUIC_DVLOG(1) << ENDPOINT + << "current_key_phase_first_received_packet_number_ = " + << header.packet_number; + current_key_phase_first_received_packet_number_ = header.packet_number; + visitor_->OnDecryptedFirstPacketInKeyPhase(); + } + } else if (alternative_decrypter != nullptr) { + if (header.nonce != nullptr) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + alternative_decrypter->SetDiversificationNonce(*header.nonce); + } + bool try_alternative_decryption = true; + if (alternative_decrypter_level_ == ENCRYPTION_ZERO_RTT) { + if (perspective_ == Perspective::IS_CLIENT) { + if (header.nonce == nullptr) { + // Can not use INITIAL decryption without a diversification nonce. + try_alternative_decryption = false; + } + } else { + QUICHE_DCHECK(header.nonce == nullptr); + } + } + + if (try_alternative_decryption) { + success = alternative_decrypter->DecryptPacket( + header.packet_number.ToUint64(), associated_data, encrypted, + decrypted_buffer, decrypted_length, buffer_length); + } + if (success) { + visitor_->OnDecryptedPacket(udp_packet_length, + alternative_decrypter_level_); + *decrypted_level = decrypter_level_; + if (alternative_decrypter_latch_) { + if (!EncryptionLevelIsValid(alternative_decrypter_level_)) { + QUIC_BUG(quic_bug_10850_73) + << "Attempted to latch alternate decrypter with bad " + "alternative_decrypter_level_"; + return false; + } + // Switch to the alternative decrypter and latch so that we cannot + // switch back. + decrypter_level_ = alternative_decrypter_level_; + alternative_decrypter_level_ = NUM_ENCRYPTION_LEVELS; + } else { + // Switch the alternative decrypter so that we use it first next time. + EncryptionLevel alt_level = alternative_decrypter_level_; + alternative_decrypter_level_ = decrypter_level_; + decrypter_level_ = alt_level; + } + } + } + + if (!success) { + QUIC_DVLOG(1) << ENDPOINT << "DecryptPacket failed for: " << header; + return false; + } + + return true; +} + +size_t QuicFramer::GetIetfAckFrameSize(const QuicAckFrame& frame) { + // Type byte, largest_acked, and delay_time are straight-forward. + size_t ack_frame_size = kQuicFrameTypeSize; + QuicPacketNumber largest_acked = LargestAcked(frame); + ack_frame_size += QuicDataWriter::GetVarInt62Len(largest_acked.ToUint64()); + uint64_t ack_delay_time_us; + ack_delay_time_us = frame.ack_delay_time.ToMicroseconds(); + ack_delay_time_us = ack_delay_time_us >> local_ack_delay_exponent_; + ack_frame_size += QuicDataWriter::GetVarInt62Len(ack_delay_time_us); + + if (frame.packets.Empty() || frame.packets.Max() != largest_acked) { + QUIC_BUG(quic_bug_10850_74) << "Malformed ack frame"; + // ACK frame serialization will fail and connection will be closed. + return ack_frame_size; + } + + // Ack block count. + ack_frame_size += + QuicDataWriter::GetVarInt62Len(frame.packets.NumIntervals() - 1); + + // First Ack range. + auto iter = frame.packets.rbegin(); + ack_frame_size += QuicDataWriter::GetVarInt62Len(iter->Length() - 1); + QuicPacketNumber previous_smallest = iter->min(); + ++iter; + + // Ack blocks. + for (; iter != frame.packets.rend(); ++iter) { + const uint64_t gap = previous_smallest - iter->max() - 1; + const uint64_t ack_range = iter->Length() - 1; + ack_frame_size += (QuicDataWriter::GetVarInt62Len(gap) + + QuicDataWriter::GetVarInt62Len(ack_range)); + previous_smallest = iter->min(); + } + + if (UseIetfAckWithReceiveTimestamp(frame)) { + ack_frame_size += GetIetfAckFrameTimestampSize(frame); + } else { + ack_frame_size += AckEcnCountSize(frame); + } + + return ack_frame_size; +} + +size_t QuicFramer::GetIetfAckFrameTimestampSize(const QuicAckFrame& ack) { + QUICHE_DCHECK(!ack.received_packet_times.empty()); + std::string detailed_error; + absl::InlinedVector timestamp_ranges = + GetAckTimestampRanges(ack, detailed_error); + if (!detailed_error.empty()) { + return 0; + } + + int64_t size = + FrameAckTimestampRanges(ack, timestamp_ranges, /*writer=*/nullptr); + return std::max(0, size); +} + +size_t QuicFramer::GetAckFrameSize( + const QuicAckFrame& ack, QuicPacketNumberLength /*packet_number_length*/) { + QUICHE_DCHECK(!ack.packets.Empty()); + size_t ack_size = 0; + + if (VersionHasIetfQuicFrames(version_.transport_version)) { + return GetIetfAckFrameSize(ack); + } + AckFrameInfo ack_info = GetAckFrameInfo(ack); + QuicPacketNumberLength ack_block_length = + GetMinPacketNumberLength(QuicPacketNumber(ack_info.max_block_length)); + + ack_size = GetMinAckFrameSize(version_.transport_version, ack, + local_ack_delay_exponent_, + UseIetfAckWithReceiveTimestamp(ack)); + // First ack block length. + ack_size += ack_block_length; + if (ack_info.num_ack_blocks != 0) { + ack_size += kNumberOfAckBlocksSize; + ack_size += std::min(ack_info.num_ack_blocks, kMaxAckBlocks) * + (ack_block_length + PACKET_1BYTE_PACKET_NUMBER); + } + + // Include timestamps. + if (process_timestamps_) { + ack_size += GetAckFrameTimeStampSize(ack); + } + + return ack_size; +} + +size_t QuicFramer::GetAckFrameTimeStampSize(const QuicAckFrame& ack) { + if (ack.received_packet_times.empty()) { + return 0; + } + + return kQuicNumTimestampsLength + kQuicFirstTimestampLength + + (kQuicTimestampLength + kQuicTimestampPacketNumberGapLength) * + (ack.received_packet_times.size() - 1); +} + +size_t QuicFramer::ComputeFrameLength( + const QuicFrame& frame, bool last_frame_in_packet, + QuicPacketNumberLength packet_number_length) { + switch (frame.type) { + case STREAM_FRAME: + return GetMinStreamFrameSize( + version_.transport_version, frame.stream_frame.stream_id, + frame.stream_frame.offset, last_frame_in_packet, + frame.stream_frame.data_length) + + frame.stream_frame.data_length; + case CRYPTO_FRAME: + return GetMinCryptoFrameSize(frame.crypto_frame->offset, + frame.crypto_frame->data_length) + + frame.crypto_frame->data_length; + case ACK_FRAME: { + return GetAckFrameSize(*frame.ack_frame, packet_number_length); + } + case STOP_WAITING_FRAME: + return GetStopWaitingFrameSize(packet_number_length); + case MTU_DISCOVERY_FRAME: + // MTU discovery frames are serialized as ping frames. + return kQuicFrameTypeSize; + case MESSAGE_FRAME: + return GetMessageFrameSize(version_.transport_version, + last_frame_in_packet, + frame.message_frame->message_length); + case PADDING_FRAME: + QUICHE_DCHECK(false); + return 0; + default: + return GetRetransmittableControlFrameSize(version_.transport_version, + frame); + } +} + +bool QuicFramer::AppendTypeByte(const QuicFrame& frame, + bool last_frame_in_packet, + QuicDataWriter* writer) { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + return AppendIetfFrameType(frame, last_frame_in_packet, writer); + } + uint8_t type_byte = 0; + switch (frame.type) { + case STREAM_FRAME: + type_byte = + GetStreamFrameTypeByte(frame.stream_frame, last_frame_in_packet); + break; + case ACK_FRAME: + return true; + case MTU_DISCOVERY_FRAME: + type_byte = static_cast(PING_FRAME); + break; + case NEW_CONNECTION_ID_FRAME: + set_detailed_error( + "Attempt to append NEW_CONNECTION_ID frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case RETIRE_CONNECTION_ID_FRAME: + set_detailed_error( + "Attempt to append RETIRE_CONNECTION_ID frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case NEW_TOKEN_FRAME: + set_detailed_error( + "Attempt to append NEW_TOKEN frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case MAX_STREAMS_FRAME: + set_detailed_error( + "Attempt to append MAX_STREAMS frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case STREAMS_BLOCKED_FRAME: + set_detailed_error( + "Attempt to append STREAMS_BLOCKED frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case PATH_RESPONSE_FRAME: + set_detailed_error( + "Attempt to append PATH_RESPONSE frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case PATH_CHALLENGE_FRAME: + set_detailed_error( + "Attempt to append PATH_CHALLENGE frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case STOP_SENDING_FRAME: + set_detailed_error( + "Attempt to append STOP_SENDING frame and not in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case MESSAGE_FRAME: + return true; + + default: + type_byte = static_cast(frame.type); + break; + } + + return writer->WriteUInt8(type_byte); +} + +bool QuicFramer::AppendIetfFrameType(const QuicFrame& frame, + bool last_frame_in_packet, + QuicDataWriter* writer) { + uint8_t type_byte = 0; + switch (frame.type) { + case PADDING_FRAME: + type_byte = IETF_PADDING; + break; + case RST_STREAM_FRAME: + type_byte = IETF_RST_STREAM; + break; + case CONNECTION_CLOSE_FRAME: + switch (frame.connection_close_frame->close_type) { + case IETF_QUIC_APPLICATION_CONNECTION_CLOSE: + type_byte = IETF_APPLICATION_CLOSE; + break; + case IETF_QUIC_TRANSPORT_CONNECTION_CLOSE: + type_byte = IETF_CONNECTION_CLOSE; + break; + default: + set_detailed_error(absl::StrCat( + "Invalid QuicConnectionCloseFrame type: ", + static_cast(frame.connection_close_frame->close_type))); + return RaiseError(QUIC_INTERNAL_ERROR); + } + break; + case GOAWAY_FRAME: + set_detailed_error( + "Attempt to create non-IETF QUIC GOAWAY frame in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case WINDOW_UPDATE_FRAME: + // Depending on whether there is a stream ID or not, will be either a + // MAX_STREAM_DATA frame or a MAX_DATA frame. + if (frame.window_update_frame.stream_id == + QuicUtils::GetInvalidStreamId(transport_version())) { + type_byte = IETF_MAX_DATA; + } else { + type_byte = IETF_MAX_STREAM_DATA; + } + break; + case BLOCKED_FRAME: + if (frame.blocked_frame.stream_id == + QuicUtils::GetInvalidStreamId(transport_version())) { + type_byte = IETF_DATA_BLOCKED; + } else { + type_byte = IETF_STREAM_DATA_BLOCKED; + } + break; + case STOP_WAITING_FRAME: + set_detailed_error( + "Attempt to append type byte of STOP WAITING frame in IETF QUIC."); + return RaiseError(QUIC_INTERNAL_ERROR); + case PING_FRAME: + type_byte = IETF_PING; + break; + case STREAM_FRAME: + type_byte = + GetStreamFrameTypeByte(frame.stream_frame, last_frame_in_packet); + break; + case ACK_FRAME: + // Do nothing here, AppendIetfAckFrameAndTypeByte() will put the type byte + // in the buffer. + return true; + case MTU_DISCOVERY_FRAME: + // The path MTU discovery frame is encoded as a PING frame on the wire. + type_byte = IETF_PING; + break; + case NEW_CONNECTION_ID_FRAME: + type_byte = IETF_NEW_CONNECTION_ID; + break; + case RETIRE_CONNECTION_ID_FRAME: + type_byte = IETF_RETIRE_CONNECTION_ID; + break; + case NEW_TOKEN_FRAME: + type_byte = IETF_NEW_TOKEN; + break; + case MAX_STREAMS_FRAME: + if (frame.max_streams_frame.unidirectional) { + type_byte = IETF_MAX_STREAMS_UNIDIRECTIONAL; + } else { + type_byte = IETF_MAX_STREAMS_BIDIRECTIONAL; + } + break; + case STREAMS_BLOCKED_FRAME: + if (frame.streams_blocked_frame.unidirectional) { + type_byte = IETF_STREAMS_BLOCKED_UNIDIRECTIONAL; + } else { + type_byte = IETF_STREAMS_BLOCKED_BIDIRECTIONAL; + } + break; + case PATH_RESPONSE_FRAME: + type_byte = IETF_PATH_RESPONSE; + break; + case PATH_CHALLENGE_FRAME: + type_byte = IETF_PATH_CHALLENGE; + break; + case STOP_SENDING_FRAME: + type_byte = IETF_STOP_SENDING; + break; + case MESSAGE_FRAME: + return true; + case CRYPTO_FRAME: + type_byte = IETF_CRYPTO; + break; + case HANDSHAKE_DONE_FRAME: + type_byte = IETF_HANDSHAKE_DONE; + break; + case ACK_FREQUENCY_FRAME: + type_byte = IETF_ACK_FREQUENCY; + break; + default: + QUIC_BUG(quic_bug_10850_75) + << "Attempt to generate a frame type for an unsupported value: " + << frame.type; + return false; + } + return writer->WriteVarInt62(type_byte); +} + +// static +bool QuicFramer::AppendPacketNumber(QuicPacketNumberLength packet_number_length, + QuicPacketNumber packet_number, + QuicDataWriter* writer) { + QUICHE_DCHECK(packet_number.IsInitialized()); + if (!IsValidPacketNumberLength(packet_number_length)) { + QUIC_BUG(quic_bug_10850_76) + << "Invalid packet_number_length: " << packet_number_length; + return false; + } + return writer->WriteBytesToUInt64(packet_number_length, + packet_number.ToUint64()); +} + +// static +bool QuicFramer::AppendStreamId(size_t stream_id_length, QuicStreamId stream_id, + QuicDataWriter* writer) { + if (stream_id_length == 0 || stream_id_length > 4) { + QUIC_BUG(quic_bug_10850_77) + << "Invalid stream_id_length: " << stream_id_length; + return false; + } + return writer->WriteBytesToUInt64(stream_id_length, stream_id); +} + +// static +bool QuicFramer::AppendStreamOffset(size_t offset_length, + QuicStreamOffset offset, + QuicDataWriter* writer) { + if (offset_length == 1 || offset_length > 8) { + QUIC_BUG(quic_bug_10850_78) + << "Invalid stream_offset_length: " << offset_length; + return false; + } + + return writer->WriteBytesToUInt64(offset_length, offset); +} + +// static +bool QuicFramer::AppendAckBlock(uint8_t gap, + QuicPacketNumberLength length_length, + uint64_t length, QuicDataWriter* writer) { + if (length == 0) { + if (!IsValidPacketNumberLength(length_length)) { + QUIC_BUG(quic_bug_10850_79) + << "Invalid packet_number_length: " << length_length; + return false; + } + return writer->WriteUInt8(gap) && + writer->WriteBytesToUInt64(length_length, length); + } + return writer->WriteUInt8(gap) && + AppendPacketNumber(length_length, QuicPacketNumber(length), writer); +} + +bool QuicFramer::AppendStreamFrame(const QuicStreamFrame& frame, + bool no_stream_frame_length, + QuicDataWriter* writer) { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + return AppendIetfStreamFrame(frame, no_stream_frame_length, writer); + } + if (!AppendStreamId(GetStreamIdSize(frame.stream_id), frame.stream_id, + writer)) { + QUIC_BUG(quic_bug_10850_80) << "Writing stream id size failed."; + return false; + } + if (!AppendStreamOffset(GetStreamOffsetSize(frame.offset), frame.offset, + writer)) { + QUIC_BUG(quic_bug_10850_81) << "Writing offset size failed."; + return false; + } + if (!no_stream_frame_length) { + static_assert( + std::numeric_limits::max() <= + std::numeric_limits::max(), + "If frame.data_length can hold more than a uint16_t than we need to " + "check that frame.data_length <= std::numeric_limits::max()"); + if (!writer->WriteUInt16(static_cast(frame.data_length))) { + QUIC_BUG(quic_bug_10850_82) << "Writing stream frame length failed"; + return false; + } + } + + if (data_producer_ != nullptr) { + QUICHE_DCHECK_EQ(nullptr, frame.data_buffer); + if (frame.data_length == 0) { + return true; + } + if (data_producer_->WriteStreamData(frame.stream_id, frame.offset, + frame.data_length, + writer) != WRITE_SUCCESS) { + QUIC_BUG(quic_bug_10850_83) << "Writing frame data failed."; + return false; + } + return true; + } + + if (!writer->WriteBytes(frame.data_buffer, frame.data_length)) { + QUIC_BUG(quic_bug_10850_84) << "Writing frame data failed."; + return false; + } + return true; +} + +bool QuicFramer::AppendNewTokenFrame(const QuicNewTokenFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(static_cast(frame.token.length()))) { + set_detailed_error("Writing token length failed."); + return false; + } + if (!writer->WriteBytes(frame.token.data(), frame.token.length())) { + set_detailed_error("Writing token buffer failed."); + return false; + } + return true; +} + +bool QuicFramer::ProcessNewTokenFrame(QuicDataReader* reader, + QuicNewTokenFrame* frame) { + uint64_t length; + if (!reader->ReadVarInt62(&length)) { + set_detailed_error("Unable to read new token length."); + return false; + } + if (length > kMaxNewTokenTokenLength) { + set_detailed_error("Token length larger than maximum."); + return false; + } + + // TODO(ianswett): Don't use absl::string_view as an intermediary. + absl::string_view data; + if (!reader->ReadStringPiece(&data, length)) { + set_detailed_error("Unable to read new token data."); + return false; + } + frame->token = std::string(data); + return true; +} + +// Add a new ietf-format stream frame. +// Bits controlling whether there is a frame-length and frame-offset +// are in the QuicStreamFrame. +bool QuicFramer::AppendIetfStreamFrame(const QuicStreamFrame& frame, + bool last_frame_in_packet, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(static_cast(frame.stream_id))) { + set_detailed_error("Writing stream id failed."); + return false; + } + + if (frame.offset != 0) { + if (!writer->WriteVarInt62(static_cast(frame.offset))) { + set_detailed_error("Writing data offset failed."); + return false; + } + } + + if (!last_frame_in_packet) { + if (!writer->WriteVarInt62(frame.data_length)) { + set_detailed_error("Writing data length failed."); + return false; + } + } + + if (frame.data_length == 0) { + return true; + } + if (data_producer_ == nullptr) { + if (!writer->WriteBytes(frame.data_buffer, frame.data_length)) { + set_detailed_error("Writing frame data failed."); + return false; + } + } else { + QUICHE_DCHECK_EQ(nullptr, frame.data_buffer); + + if (data_producer_->WriteStreamData(frame.stream_id, frame.offset, + frame.data_length, + writer) != WRITE_SUCCESS) { + set_detailed_error("Writing frame data from producer failed."); + return false; + } + } + return true; +} + +bool QuicFramer::AppendCryptoFrame(const QuicCryptoFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(static_cast(frame.offset))) { + set_detailed_error("Writing data offset failed."); + return false; + } + if (!writer->WriteVarInt62(static_cast(frame.data_length))) { + set_detailed_error("Writing data length failed."); + return false; + } + if (data_producer_ == nullptr) { + if (frame.data_buffer == nullptr || + !writer->WriteBytes(frame.data_buffer, frame.data_length)) { + set_detailed_error("Writing frame data failed."); + return false; + } + } else { + QUICHE_DCHECK_EQ(nullptr, frame.data_buffer); + if (!data_producer_->WriteCryptoData(frame.level, frame.offset, + frame.data_length, writer)) { + return false; + } + } + return true; +} + +bool QuicFramer::AppendAckFrequencyFrame(const QuicAckFrequencyFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.sequence_number)) { + set_detailed_error("Writing sequence number failed."); + return false; + } + if (!writer->WriteVarInt62(frame.packet_tolerance)) { + set_detailed_error("Writing packet tolerance failed."); + return false; + } + if (!writer->WriteVarInt62( + static_cast(frame.max_ack_delay.ToMicroseconds()))) { + set_detailed_error("Writing max_ack_delay_us failed."); + return false; + } + if (!writer->WriteUInt8(static_cast(frame.ignore_order))) { + set_detailed_error("Writing ignore_order failed."); + return false; + } + + return true; +} + +void QuicFramer::set_version(const ParsedQuicVersion version) { + QUICHE_DCHECK(IsSupportedVersion(version)) + << ParsedQuicVersionToString(version); + version_ = version; +} + +bool QuicFramer::AppendAckFrameAndTypeByte(const QuicAckFrame& frame, + QuicDataWriter* writer) { + if (VersionHasIetfQuicFrames(transport_version())) { + return AppendIetfAckFrameAndTypeByte(frame, writer); + } + + const AckFrameInfo new_ack_info = GetAckFrameInfo(frame); + QuicPacketNumber largest_acked = LargestAcked(frame); + QuicPacketNumberLength largest_acked_length = + GetMinPacketNumberLength(largest_acked); + QuicPacketNumberLength ack_block_length = + GetMinPacketNumberLength(QuicPacketNumber(new_ack_info.max_block_length)); + // Calculate available bytes for timestamps and ack blocks. + int32_t available_timestamp_and_ack_block_bytes = + writer->capacity() - writer->length() - ack_block_length - + GetMinAckFrameSize(version_.transport_version, frame, + local_ack_delay_exponent_, + UseIetfAckWithReceiveTimestamp(frame)) - + (new_ack_info.num_ack_blocks != 0 ? kNumberOfAckBlocksSize : 0); + QUICHE_DCHECK_LE(0, available_timestamp_and_ack_block_bytes); + + uint8_t type_byte = 0; + SetBit(&type_byte, new_ack_info.num_ack_blocks != 0, + kQuicHasMultipleAckBlocksOffset); + + SetBits(&type_byte, GetPacketNumberFlags(largest_acked_length), + kQuicSequenceNumberLengthNumBits, kLargestAckedOffset); + + SetBits(&type_byte, GetPacketNumberFlags(ack_block_length), + kQuicSequenceNumberLengthNumBits, kActBlockLengthOffset); + + type_byte |= kQuicFrameTypeAckMask; + + if (!writer->WriteUInt8(type_byte)) { + return false; + } + + size_t max_num_ack_blocks = available_timestamp_and_ack_block_bytes / + (ack_block_length + PACKET_1BYTE_PACKET_NUMBER); + + // Number of ack blocks. + size_t num_ack_blocks = + std::min(new_ack_info.num_ack_blocks, max_num_ack_blocks); + if (num_ack_blocks > std::numeric_limits::max()) { + num_ack_blocks = std::numeric_limits::max(); + } + + // Largest acked. + if (!AppendPacketNumber(largest_acked_length, largest_acked, writer)) { + return false; + } + + // Largest acked delta time. + uint64_t ack_delay_time_us = kUFloat16MaxValue; + if (!frame.ack_delay_time.IsInfinite()) { + QUICHE_DCHECK_LE(0u, frame.ack_delay_time.ToMicroseconds()); + ack_delay_time_us = frame.ack_delay_time.ToMicroseconds(); + } + if (!writer->WriteUFloat16(ack_delay_time_us)) { + return false; + } + + if (num_ack_blocks > 0) { + if (!writer->WriteBytes(&num_ack_blocks, 1)) { + return false; + } + } + + // First ack block length. + if (!AppendPacketNumber(ack_block_length, + QuicPacketNumber(new_ack_info.first_block_length), + writer)) { + return false; + } + + // Ack blocks. + if (num_ack_blocks > 0) { + size_t num_ack_blocks_written = 0; + // Append, in descending order from the largest ACKed packet, a series of + // ACK blocks that represents the successfully acknoweldged packets. Each + // appended gap/block length represents a descending delta from the previous + // block. i.e.: + // |--- length ---|--- gap ---|--- length ---|--- gap ---|--- largest ---| + // For gaps larger than can be represented by a single encoded gap, a 0 + // length gap of the maximum is used, i.e.: + // |--- length ---|--- gap ---|- 0 -|--- gap ---|--- largest ---| + auto itr = frame.packets.rbegin(); + QuicPacketNumber previous_start = itr->min(); + ++itr; + + for (; + itr != frame.packets.rend() && num_ack_blocks_written < num_ack_blocks; + previous_start = itr->min(), ++itr) { + const auto& interval = *itr; + const uint64_t total_gap = previous_start - interval.max(); + const size_t num_encoded_gaps = + (total_gap + std::numeric_limits::max() - 1) / + std::numeric_limits::max(); + + // Append empty ACK blocks because the gap is longer than a single gap. + for (size_t i = 1; + i < num_encoded_gaps && num_ack_blocks_written < num_ack_blocks; + ++i) { + if (!AppendAckBlock(std::numeric_limits::max(), + ack_block_length, 0, writer)) { + return false; + } + ++num_ack_blocks_written; + } + if (num_ack_blocks_written >= num_ack_blocks) { + if (ABSL_PREDICT_FALSE(num_ack_blocks_written != num_ack_blocks)) { + QUIC_BUG(quic_bug_10850_85) + << "Wrote " << num_ack_blocks_written << ", expected to write " + << num_ack_blocks; + } + break; + } + + const uint8_t last_gap = + total_gap - + (num_encoded_gaps - 1) * std::numeric_limits::max(); + // Append the final ACK block with a non-empty size. + if (!AppendAckBlock(last_gap, ack_block_length, interval.Length(), + writer)) { + return false; + } + ++num_ack_blocks_written; + } + QUICHE_DCHECK_EQ(num_ack_blocks, num_ack_blocks_written); + } + // Timestamps. + // If we don't process timestamps or if we don't have enough available space + // to append all the timestamps, don't append any of them. + if (process_timestamps_ && writer->capacity() - writer->length() >= + GetAckFrameTimeStampSize(frame)) { + if (!AppendTimestampsToAckFrame(frame, writer)) { + return false; + } + } else { + uint8_t num_received_packets = 0; + if (!writer->WriteBytes(&num_received_packets, 1)) { + return false; + } + } + + return true; +} + +bool QuicFramer::AppendTimestampsToAckFrame(const QuicAckFrame& frame, + QuicDataWriter* writer) { + QUICHE_DCHECK_GE(std::numeric_limits::max(), + frame.received_packet_times.size()); + // num_received_packets is only 1 byte. + if (frame.received_packet_times.size() > + std::numeric_limits::max()) { + return false; + } + + uint8_t num_received_packets = frame.received_packet_times.size(); + if (!writer->WriteBytes(&num_received_packets, 1)) { + return false; + } + if (num_received_packets == 0) { + return true; + } + + auto it = frame.received_packet_times.begin(); + QuicPacketNumber packet_number = it->first; + uint64_t delta_from_largest_observed = LargestAcked(frame) - packet_number; + + QUICHE_DCHECK_GE(std::numeric_limits::max(), + delta_from_largest_observed); + if (delta_from_largest_observed > std::numeric_limits::max()) { + return false; + } + + if (!writer->WriteUInt8(delta_from_largest_observed)) { + return false; + } + + // Use the lowest 4 bytes of the time delta from the creation_time_. + const uint64_t time_epoch_delta_us = UINT64_C(1) << 32; + uint32_t time_delta_us = + static_cast((it->second - creation_time_).ToMicroseconds() & + (time_epoch_delta_us - 1)); + if (!writer->WriteUInt32(time_delta_us)) { + return false; + } + + QuicTime prev_time = it->second; + + for (++it; it != frame.received_packet_times.end(); ++it) { + packet_number = it->first; + delta_from_largest_observed = LargestAcked(frame) - packet_number; + + if (delta_from_largest_observed > std::numeric_limits::max()) { + return false; + } + + if (!writer->WriteUInt8(delta_from_largest_observed)) { + return false; + } + + uint64_t frame_time_delta_us = (it->second - prev_time).ToMicroseconds(); + prev_time = it->second; + if (!writer->WriteUFloat16(frame_time_delta_us)) { + return false; + } + } + return true; +} + +absl::InlinedVector +QuicFramer::GetAckTimestampRanges(const QuicAckFrame& frame, + std::string& detailed_error) const { + detailed_error = ""; + if (frame.received_packet_times.empty()) { + return {}; + } + + absl::InlinedVector timestamp_ranges; + + for (size_t r = 0; r < std::min(max_receive_timestamps_per_ack_, + frame.received_packet_times.size()); + ++r) { + const size_t i = frame.received_packet_times.size() - 1 - r; + const QuicPacketNumber packet_number = frame.received_packet_times[i].first; + const QuicTime receive_timestamp = frame.received_packet_times[i].second; + + if (timestamp_ranges.empty()) { + if (receive_timestamp < creation_time_ || + LargestAcked(frame) < packet_number) { + detailed_error = + "The first packet is either received earlier than framer creation " + "time, or larger than largest acked packet."; + QUIC_BUG(quic_framer_ack_ts_first_packet_bad) + << detailed_error << " receive_timestamp:" << receive_timestamp + << ", framer_creation_time:" << creation_time_ + << ", packet_number:" << packet_number + << ", largest_acked:" << LargestAcked(frame); + return {}; + } + timestamp_ranges.push_back(AckTimestampRange()); + timestamp_ranges.back().gap = LargestAcked(frame) - packet_number; + timestamp_ranges.back().range_begin = i; + timestamp_ranges.back().range_end = i; + continue; + } + + const size_t prev_i = timestamp_ranges.back().range_end; + const QuicPacketNumber prev_packet_number = + frame.received_packet_times[prev_i].first; + const QuicTime prev_receive_timestamp = + frame.received_packet_times[prev_i].second; + + QUIC_DVLOG(3) << "prev_packet_number:" << prev_packet_number + << ", packet_number:" << packet_number; + if (prev_receive_timestamp < receive_timestamp || + prev_packet_number <= packet_number) { + detailed_error = "Packet number and/or receive time not in order."; + QUIC_BUG(quic_framer_ack_ts_packet_out_of_order) + << detailed_error << " packet_number:" << packet_number + << ", receive_timestamp:" << receive_timestamp + << ", prev_packet_number:" << prev_packet_number + << ", prev_receive_timestamp:" << prev_receive_timestamp; + return {}; + } + + if (prev_packet_number == packet_number + 1) { + timestamp_ranges.back().range_end = i; + } else { + timestamp_ranges.push_back(AckTimestampRange()); + timestamp_ranges.back().gap = prev_packet_number - 2 - packet_number; + timestamp_ranges.back().range_begin = i; + timestamp_ranges.back().range_end = i; + } + } + + return timestamp_ranges; +} + +int64_t QuicFramer::FrameAckTimestampRanges( + const QuicAckFrame& frame, + const absl::InlinedVector& timestamp_ranges, + QuicDataWriter* writer) const { + int64_t size = 0; + auto maybe_write_var_int62 = [&](uint64_t value) { + size += QuicDataWriter::GetVarInt62Len(value); + if (writer != nullptr && !writer->WriteVarInt62(value)) { + return false; + } + return true; + }; + + if (!maybe_write_var_int62(timestamp_ranges.size())) { + return -1; + } + + // |effective_prev_time| is the exponent-encoded timestamp of the previous + // packet. + absl::optional effective_prev_time; + for (const AckTimestampRange& range : timestamp_ranges) { + QUIC_DVLOG(3) << "Range: gap:" << range.gap << ", beg:" << range.range_begin + << ", end:" << range.range_end; + if (!maybe_write_var_int62(range.gap)) { + return -1; + } + + if (!maybe_write_var_int62(range.range_begin - range.range_end + 1)) { + return -1; + } + + for (int64_t i = range.range_begin; i >= range.range_end; --i) { + const QuicTime receive_timestamp = frame.received_packet_times[i].second; + uint64_t time_delta; + if (effective_prev_time.has_value()) { + time_delta = + (*effective_prev_time - receive_timestamp).ToMicroseconds(); + QUIC_DVLOG(3) << "time_delta:" << time_delta + << ", exponent:" << receive_timestamps_exponent_ + << ", effective_prev_time:" << *effective_prev_time + << ", recv_time:" << receive_timestamp; + time_delta = time_delta >> receive_timestamps_exponent_; + effective_prev_time = effective_prev_time.value() - + QuicTime::Delta::FromMicroseconds( + time_delta << receive_timestamps_exponent_); + } else { + // The first delta is from framer creation to the current receive + // timestamp (forward in time), whereas in the common case subsequent + // deltas move backwards in time. + time_delta = (receive_timestamp - creation_time_).ToMicroseconds(); + QUIC_DVLOG(3) << "First time_delta:" << time_delta + << ", exponent:" << receive_timestamps_exponent_ + << ", recv_time:" << receive_timestamp + << ", creation_time:" << creation_time_; + // Round up the first exponent-encoded time delta so that the next + // receive timestamp is guaranteed to be decreasing. + time_delta = ((time_delta - 1) >> receive_timestamps_exponent_) + 1; + effective_prev_time = + creation_time_ + QuicTime::Delta::FromMicroseconds( + time_delta << receive_timestamps_exponent_); + } + + if (!maybe_write_var_int62(time_delta)) { + return -1; + } + } + } + + return size; +} + +bool QuicFramer::AppendIetfTimestampsToAckFrame(const QuicAckFrame& frame, + QuicDataWriter* writer) { + QUICHE_DCHECK(!frame.received_packet_times.empty()); + std::string detailed_error; + const absl::InlinedVector timestamp_ranges = + GetAckTimestampRanges(frame, detailed_error); + if (!detailed_error.empty()) { + set_detailed_error(std::move(detailed_error)); + return false; + } + + // Compute the size first using a null writer. + int64_t size = + FrameAckTimestampRanges(frame, timestamp_ranges, /*writer=*/nullptr); + if (size > static_cast(writer->capacity() - writer->length())) { + QUIC_DVLOG(1) << "Insufficient room to write IETF ack receive timestamps. " + "size_remain:" + << (writer->capacity() - writer->length()) + << ", size_needed:" << size; + // Write a Timestamp Range Count of 0. + return writer->WriteVarInt62(0); + } + + return FrameAckTimestampRanges(frame, timestamp_ranges, writer) > 0; +} + +bool QuicFramer::AppendStopWaitingFrame(const QuicPacketHeader& header, + const QuicStopWaitingFrame& frame, + QuicDataWriter* writer) { + QUICHE_DCHECK(!version_.HasIetfInvariantHeader()); + QUICHE_DCHECK(frame.least_unacked.IsInitialized()); + QUICHE_DCHECK_GE(header.packet_number, frame.least_unacked); + const uint64_t least_unacked_delta = + header.packet_number - frame.least_unacked; + const uint64_t length_shift = header.packet_number_length * 8; + + if (least_unacked_delta >> length_shift > 0) { + QUIC_BUG(quic_bug_10850_86) + << "packet_number_length " << header.packet_number_length + << " is too small for least_unacked_delta: " << least_unacked_delta + << " packet_number:" << header.packet_number + << " least_unacked:" << frame.least_unacked + << " version:" << version_.transport_version; + return false; + } + if (least_unacked_delta == 0) { + return writer->WriteBytesToUInt64(header.packet_number_length, + least_unacked_delta); + } + if (!AppendPacketNumber(header.packet_number_length, + QuicPacketNumber(least_unacked_delta), writer)) { + QUIC_BUG(quic_bug_10850_87) + << " seq failed: " << header.packet_number_length; + return false; + } + + return true; +} + +bool QuicFramer::AppendIetfAckFrameAndTypeByte(const QuicAckFrame& frame, + QuicDataWriter* writer) { + uint8_t type = IETF_ACK; + uint64_t ecn_size = 0; + if (UseIetfAckWithReceiveTimestamp(frame)) { + type = IETF_ACK_RECEIVE_TIMESTAMPS; + } else if (frame.ecn_counters.has_value()) { + // Change frame type to ACK_ECN if any ECN count is available. + type = IETF_ACK_ECN; + ecn_size = AckEcnCountSize(frame); + } + + if (!writer->WriteVarInt62(type)) { + set_detailed_error("No room for frame-type"); + return false; + } + + QuicPacketNumber largest_acked = LargestAcked(frame); + if (!writer->WriteVarInt62(largest_acked.ToUint64())) { + set_detailed_error("No room for largest-acked in ack frame"); + return false; + } + + uint64_t ack_delay_time_us = quiche::kVarInt62MaxValue; + if (!frame.ack_delay_time.IsInfinite()) { + QUICHE_DCHECK_LE(0u, frame.ack_delay_time.ToMicroseconds()); + ack_delay_time_us = frame.ack_delay_time.ToMicroseconds(); + ack_delay_time_us = ack_delay_time_us >> local_ack_delay_exponent_; + } + + if (!writer->WriteVarInt62(ack_delay_time_us)) { + set_detailed_error("No room for ack-delay in ack frame"); + return false; + } + + if (frame.packets.Empty() || frame.packets.Max() != largest_acked) { + QUIC_BUG(quic_bug_10850_88) << "Malformed ack frame: " << frame; + set_detailed_error("Malformed ack frame"); + return false; + } + + // Latch ack_block_count for potential truncation. + const uint64_t ack_block_count = frame.packets.NumIntervals() - 1; + QuicDataWriter count_writer(QuicDataWriter::GetVarInt62Len(ack_block_count), + writer->data() + writer->length()); + if (!writer->WriteVarInt62(ack_block_count)) { + set_detailed_error("No room for ack block count in ack frame"); + return false; + } + auto iter = frame.packets.rbegin(); + if (!writer->WriteVarInt62(iter->Length() - 1)) { + set_detailed_error("No room for first ack block in ack frame"); + return false; + } + QuicPacketNumber previous_smallest = iter->min(); + ++iter; + // Append remaining ACK blocks. + uint64_t appended_ack_blocks = 0; + for (; iter != frame.packets.rend(); ++iter) { + const uint64_t gap = previous_smallest - iter->max() - 1; + const uint64_t ack_range = iter->Length() - 1; + + if (type == IETF_ACK_RECEIVE_TIMESTAMPS && + writer->remaining() < + static_cast(QuicDataWriter::GetVarInt62Len(gap) + + QuicDataWriter::GetVarInt62Len(ack_range) + + QuicDataWriter::GetVarInt62Len(0))) { + // If we write this ACK range we won't have space for a timestamp range + // count of 0. + break; + } else if (writer->remaining() < ecn_size || + writer->remaining() - ecn_size < + static_cast( + QuicDataWriter::GetVarInt62Len(gap) + + QuicDataWriter::GetVarInt62Len(ack_range))) { + // ACK range does not fit, truncate it. + break; + } + const bool success = + writer->WriteVarInt62(gap) && writer->WriteVarInt62(ack_range); + QUICHE_DCHECK(success); + previous_smallest = iter->min(); + ++appended_ack_blocks; + } + + if (appended_ack_blocks < ack_block_count) { + // Truncation is needed, rewrite the ack block count. + if (QuicDataWriter::GetVarInt62Len(appended_ack_blocks) != + QuicDataWriter::GetVarInt62Len(ack_block_count) || + !count_writer.WriteVarInt62(appended_ack_blocks)) { + // This should never happen as ack_block_count is limited by + // max_ack_ranges_. + QUIC_BUG(quic_bug_10850_89) + << "Ack frame truncation fails. ack_block_count: " << ack_block_count + << ", appended count: " << appended_ack_blocks; + set_detailed_error("ACK frame truncation fails"); + return false; + } + QUIC_DLOG(INFO) << ENDPOINT << "ACK ranges get truncated from " + << ack_block_count << " to " << appended_ack_blocks; + } + + if (type == IETF_ACK_ECN) { + // Encode the ECN counts. + if (!writer->WriteVarInt62(frame.ecn_counters->ect0)) { + set_detailed_error("No room for ect_0_count in ack frame"); + return false; + } + if (!writer->WriteVarInt62(frame.ecn_counters->ect1)) { + set_detailed_error("No room for ect_1_count in ack frame"); + return false; + } + if (!writer->WriteVarInt62(frame.ecn_counters->ce)) { + set_detailed_error("No room for ecn_ce_count in ack frame"); + return false; + } + } + + if (type == IETF_ACK_RECEIVE_TIMESTAMPS) { + if (!AppendIetfTimestampsToAckFrame(frame, writer)) { + return false; + } + } + + return true; +} + +bool QuicFramer::AppendRstStreamFrame(const QuicRstStreamFrame& frame, + QuicDataWriter* writer) { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + return AppendIetfResetStreamFrame(frame, writer); + } + if (!writer->WriteUInt32(frame.stream_id)) { + return false; + } + + if (!writer->WriteUInt64(frame.byte_offset)) { + return false; + } + + uint32_t error_code = static_cast(frame.error_code); + if (!writer->WriteUInt32(error_code)) { + return false; + } + + return true; +} + +bool QuicFramer::AppendConnectionCloseFrame( + const QuicConnectionCloseFrame& frame, QuicDataWriter* writer) { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + return AppendIetfConnectionCloseFrame(frame, writer); + } + uint32_t error_code = static_cast(frame.wire_error_code); + if (!writer->WriteUInt32(error_code)) { + return false; + } + if (!writer->WriteStringPiece16(TruncateErrorString(frame.error_details))) { + return false; + } + return true; +} + +bool QuicFramer::AppendGoAwayFrame(const QuicGoAwayFrame& frame, + QuicDataWriter* writer) { + uint32_t error_code = static_cast(frame.error_code); + if (!writer->WriteUInt32(error_code)) { + return false; + } + uint32_t stream_id = static_cast(frame.last_good_stream_id); + if (!writer->WriteUInt32(stream_id)) { + return false; + } + if (!writer->WriteStringPiece16(TruncateErrorString(frame.reason_phrase))) { + return false; + } + return true; +} + +bool QuicFramer::AppendWindowUpdateFrame(const QuicWindowUpdateFrame& frame, + QuicDataWriter* writer) { + uint32_t stream_id = static_cast(frame.stream_id); + if (!writer->WriteUInt32(stream_id)) { + return false; + } + if (!writer->WriteUInt64(frame.max_data)) { + return false; + } + return true; +} + +bool QuicFramer::AppendBlockedFrame(const QuicBlockedFrame& frame, + QuicDataWriter* writer) { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + if (frame.stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { + return AppendDataBlockedFrame(frame, writer); + } + return AppendStreamDataBlockedFrame(frame, writer); + } + uint32_t stream_id = static_cast(frame.stream_id); + if (!writer->WriteUInt32(stream_id)) { + return false; + } + return true; +} + +bool QuicFramer::AppendPaddingFrame(const QuicPaddingFrame& frame, + QuicDataWriter* writer) { + if (frame.num_padding_bytes == 0) { + return false; + } + if (frame.num_padding_bytes < 0) { + QUIC_BUG_IF(quic_bug_12975_9, frame.num_padding_bytes != -1); + writer->WritePadding(); + return true; + } + // Please note, num_padding_bytes includes type byte which has been written. + return writer->WritePaddingBytes(frame.num_padding_bytes - 1); +} + +bool QuicFramer::AppendMessageFrameAndTypeByte(const QuicMessageFrame& frame, + bool last_frame_in_packet, + QuicDataWriter* writer) { + uint8_t type_byte; + if (VersionHasIetfQuicFrames(version_.transport_version)) { + type_byte = last_frame_in_packet ? IETF_EXTENSION_MESSAGE_NO_LENGTH_V99 + : IETF_EXTENSION_MESSAGE_V99; + } else { + type_byte = last_frame_in_packet ? IETF_EXTENSION_MESSAGE_NO_LENGTH + : IETF_EXTENSION_MESSAGE; + } + if (!writer->WriteUInt8(type_byte)) { + return false; + } + if (!last_frame_in_packet && !writer->WriteVarInt62(frame.message_length)) { + return false; + } + for (const auto& slice : frame.message_data) { + if (!writer->WriteBytes(slice.data(), slice.length())) { + return false; + } + } + return true; +} + +bool QuicFramer::RaiseError(QuicErrorCode error) { + QUIC_DLOG(INFO) << ENDPOINT << "Error: " << QuicErrorCodeToString(error) + << " detail: " << detailed_error_; + set_error(error); + if (visitor_) { + visitor_->OnError(this); + } + return false; +} + +bool QuicFramer::IsVersionNegotiation( + const QuicPacketHeader& header, bool packet_has_ietf_packet_header) const { + if (!packet_has_ietf_packet_header && + perspective_ == Perspective::IS_CLIENT) { + return header.version_flag; + } + if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + return false; + } + return header.long_packet_type == VERSION_NEGOTIATION; +} + +bool QuicFramer::AppendIetfConnectionCloseFrame( + const QuicConnectionCloseFrame& frame, QuicDataWriter* writer) { + if (frame.close_type != IETF_QUIC_TRANSPORT_CONNECTION_CLOSE && + frame.close_type != IETF_QUIC_APPLICATION_CONNECTION_CLOSE) { + QUIC_BUG(quic_bug_10850_90) + << "Invalid close_type for writing IETF CONNECTION CLOSE."; + set_detailed_error("Invalid close_type for writing IETF CONNECTION CLOSE."); + return false; + } + + if (!writer->WriteVarInt62(frame.wire_error_code)) { + set_detailed_error("Can not write connection close frame error code"); + return false; + } + + if (frame.close_type == IETF_QUIC_TRANSPORT_CONNECTION_CLOSE) { + // Write the frame-type of the frame causing the error only + // if it's a CONNECTION_CLOSE/Transport. + if (!writer->WriteVarInt62(frame.transport_close_frame_type)) { + set_detailed_error("Writing frame type failed."); + return false; + } + } + + // There may be additional error information available in the extracted error + // code. Encode the error information in the reason phrase and serialize the + // result. + std::string final_error_string = + GenerateErrorString(frame.error_details, frame.quic_error_code); + if (!writer->WriteStringPieceVarInt62( + TruncateErrorString(final_error_string))) { + set_detailed_error("Can not write connection close phrase"); + return false; + } + return true; +} + +bool QuicFramer::ProcessIetfConnectionCloseFrame( + QuicDataReader* reader, QuicConnectionCloseType type, + QuicConnectionCloseFrame* frame) { + frame->close_type = type; + + uint64_t error_code; + if (!reader->ReadVarInt62(&error_code)) { + set_detailed_error("Unable to read connection close error code."); + return false; + } + + frame->wire_error_code = error_code; + + if (type == IETF_QUIC_TRANSPORT_CONNECTION_CLOSE) { + // The frame-type of the frame causing the error is present only + // if it's a CONNECTION_CLOSE/Transport. + if (!reader->ReadVarInt62(&frame->transport_close_frame_type)) { + set_detailed_error("Unable to read connection close frame type."); + return false; + } + } + + uint64_t phrase_length; + if (!reader->ReadVarInt62(&phrase_length)) { + set_detailed_error("Unable to read connection close error details."); + return false; + } + + absl::string_view phrase; + if (!reader->ReadStringPiece(&phrase, static_cast(phrase_length))) { + set_detailed_error("Unable to read connection close error details."); + return false; + } + frame->error_details = std::string(phrase); + + // The frame may have an extracted error code in it. Look for it and + // extract it. If it's not present, MaybeExtract will return + // QUIC_IETF_GQUIC_ERROR_MISSING. + MaybeExtractQuicErrorCode(frame); + return true; +} + +// IETF Quic Path Challenge/Response frames. +bool QuicFramer::ProcessPathChallengeFrame(QuicDataReader* reader, + QuicPathChallengeFrame* frame) { + if (!reader->ReadBytes(frame->data_buffer.data(), + frame->data_buffer.size())) { + set_detailed_error("Can not read path challenge data."); + return false; + } + return true; +} + +bool QuicFramer::ProcessPathResponseFrame(QuicDataReader* reader, + QuicPathResponseFrame* frame) { + if (!reader->ReadBytes(frame->data_buffer.data(), + frame->data_buffer.size())) { + set_detailed_error("Can not read path response data."); + return false; + } + return true; +} + +bool QuicFramer::AppendPathChallengeFrame(const QuicPathChallengeFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteBytes(frame.data_buffer.data(), frame.data_buffer.size())) { + set_detailed_error("Writing Path Challenge data failed."); + return false; + } + return true; +} + +bool QuicFramer::AppendPathResponseFrame(const QuicPathResponseFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteBytes(frame.data_buffer.data(), frame.data_buffer.size())) { + set_detailed_error("Writing Path Response data failed."); + return false; + } + return true; +} + +// Add a new ietf-format stream reset frame. +// General format is +// stream id +// application error code +// final offset +bool QuicFramer::AppendIetfResetStreamFrame(const QuicRstStreamFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(static_cast(frame.stream_id))) { + set_detailed_error("Writing reset-stream stream id failed."); + return false; + } + if (!writer->WriteVarInt62(static_cast(frame.ietf_error_code))) { + set_detailed_error("Writing reset-stream error code failed."); + return false; + } + if (!writer->WriteVarInt62(static_cast(frame.byte_offset))) { + set_detailed_error("Writing reset-stream final-offset failed."); + return false; + } + return true; +} + +bool QuicFramer::ProcessIetfResetStreamFrame(QuicDataReader* reader, + QuicRstStreamFrame* frame) { + // Get Stream ID from frame. ReadVarIntStreamID returns false + // if either A) there is a read error or B) the resulting value of + // the Stream ID is larger than the maximum allowed value. + if (!ReadUint32FromVarint62(reader, IETF_RST_STREAM, &frame->stream_id)) { + return false; + } + + if (!reader->ReadVarInt62(&frame->ietf_error_code)) { + set_detailed_error("Unable to read rst stream error code."); + return false; + } + + frame->error_code = + IetfResetStreamErrorCodeToRstStreamErrorCode(frame->ietf_error_code); + + if (!reader->ReadVarInt62(&frame->byte_offset)) { + set_detailed_error("Unable to read rst stream sent byte offset."); + return false; + } + return true; +} + +bool QuicFramer::ProcessStopSendingFrame( + QuicDataReader* reader, QuicStopSendingFrame* stop_sending_frame) { + if (!ReadUint32FromVarint62(reader, IETF_STOP_SENDING, + &stop_sending_frame->stream_id)) { + return false; + } + + if (!reader->ReadVarInt62(&stop_sending_frame->ietf_error_code)) { + set_detailed_error("Unable to read stop sending application error code."); + return false; + } + + stop_sending_frame->error_code = IetfResetStreamErrorCodeToRstStreamErrorCode( + stop_sending_frame->ietf_error_code); + return true; +} + +bool QuicFramer::AppendStopSendingFrame( + const QuicStopSendingFrame& stop_sending_frame, QuicDataWriter* writer) { + if (!writer->WriteVarInt62(stop_sending_frame.stream_id)) { + set_detailed_error("Can not write stop sending stream id"); + return false; + } + if (!writer->WriteVarInt62( + static_cast(stop_sending_frame.ietf_error_code))) { + set_detailed_error("Can not write application error code"); + return false; + } + return true; +} + +// Append/process IETF-Format MAX_DATA Frame +bool QuicFramer::AppendMaxDataFrame(const QuicWindowUpdateFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.max_data)) { + set_detailed_error("Can not write MAX_DATA byte-offset"); + return false; + } + return true; +} + +bool QuicFramer::ProcessMaxDataFrame(QuicDataReader* reader, + QuicWindowUpdateFrame* frame) { + frame->stream_id = QuicUtils::GetInvalidStreamId(transport_version()); + if (!reader->ReadVarInt62(&frame->max_data)) { + set_detailed_error("Can not read MAX_DATA byte-offset"); + return false; + } + return true; +} + +// Append/process IETF-Format MAX_STREAM_DATA Frame +bool QuicFramer::AppendMaxStreamDataFrame(const QuicWindowUpdateFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.stream_id)) { + set_detailed_error("Can not write MAX_STREAM_DATA stream id"); + return false; + } + if (!writer->WriteVarInt62(frame.max_data)) { + set_detailed_error("Can not write MAX_STREAM_DATA byte-offset"); + return false; + } + return true; +} + +bool QuicFramer::ProcessMaxStreamDataFrame(QuicDataReader* reader, + QuicWindowUpdateFrame* frame) { + if (!ReadUint32FromVarint62(reader, IETF_MAX_STREAM_DATA, + &frame->stream_id)) { + return false; + } + if (!reader->ReadVarInt62(&frame->max_data)) { + set_detailed_error("Can not read MAX_STREAM_DATA byte-count"); + return false; + } + return true; +} + +bool QuicFramer::AppendMaxStreamsFrame(const QuicMaxStreamsFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.stream_count)) { + set_detailed_error("Can not write MAX_STREAMS stream count"); + return false; + } + return true; +} + +bool QuicFramer::ProcessMaxStreamsFrame(QuicDataReader* reader, + QuicMaxStreamsFrame* frame, + uint64_t frame_type) { + if (!ReadUint32FromVarint62(reader, + static_cast(frame_type), + &frame->stream_count)) { + return false; + } + frame->unidirectional = (frame_type == IETF_MAX_STREAMS_UNIDIRECTIONAL); + return true; +} + +bool QuicFramer::AppendDataBlockedFrame(const QuicBlockedFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.offset)) { + set_detailed_error("Can not write blocked offset."); + return false; + } + return true; +} + +bool QuicFramer::ProcessDataBlockedFrame(QuicDataReader* reader, + QuicBlockedFrame* frame) { + // Indicates that it is a BLOCKED frame (as opposed to STREAM_BLOCKED). + frame->stream_id = QuicUtils::GetInvalidStreamId(transport_version()); + if (!reader->ReadVarInt62(&frame->offset)) { + set_detailed_error("Can not read blocked offset."); + return false; + } + return true; +} + +bool QuicFramer::AppendStreamDataBlockedFrame(const QuicBlockedFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.stream_id)) { + set_detailed_error("Can not write stream blocked stream id."); + return false; + } + if (!writer->WriteVarInt62(frame.offset)) { + set_detailed_error("Can not write stream blocked offset."); + return false; + } + return true; +} + +bool QuicFramer::ProcessStreamDataBlockedFrame(QuicDataReader* reader, + QuicBlockedFrame* frame) { + if (!ReadUint32FromVarint62(reader, IETF_STREAM_DATA_BLOCKED, + &frame->stream_id)) { + return false; + } + if (!reader->ReadVarInt62(&frame->offset)) { + set_detailed_error("Can not read stream blocked offset."); + return false; + } + return true; +} + +bool QuicFramer::AppendStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame, + QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.stream_count)) { + set_detailed_error("Can not write STREAMS_BLOCKED stream count"); + return false; + } + return true; +} + +bool QuicFramer::ProcessStreamsBlockedFrame(QuicDataReader* reader, + QuicStreamsBlockedFrame* frame, + uint64_t frame_type) { + if (!ReadUint32FromVarint62(reader, + static_cast(frame_type), + &frame->stream_count)) { + return false; + } + if (frame->stream_count > QuicUtils::GetMaxStreamCount()) { + // If stream count is such that the resulting stream ID would exceed our + // implementation limit, generate an error. + set_detailed_error( + "STREAMS_BLOCKED stream count exceeds implementation limit."); + return false; + } + frame->unidirectional = (frame_type == IETF_STREAMS_BLOCKED_UNIDIRECTIONAL); + return true; +} + +bool QuicFramer::AppendNewConnectionIdFrame( + const QuicNewConnectionIdFrame& frame, QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.sequence_number)) { + set_detailed_error("Can not write New Connection ID sequence number"); + return false; + } + if (!writer->WriteVarInt62(frame.retire_prior_to)) { + set_detailed_error("Can not write New Connection ID retire_prior_to"); + return false; + } + if (!writer->WriteLengthPrefixedConnectionId(frame.connection_id)) { + set_detailed_error("Can not write New Connection ID frame connection ID"); + return false; + } + + if (!writer->WriteBytes( + static_cast(&frame.stateless_reset_token), + sizeof(frame.stateless_reset_token))) { + set_detailed_error("Can not write New Connection ID Reset Token"); + return false; + } + return true; +} + +bool QuicFramer::ProcessNewConnectionIdFrame(QuicDataReader* reader, + QuicNewConnectionIdFrame* frame) { + if (!reader->ReadVarInt62(&frame->sequence_number)) { + set_detailed_error( + "Unable to read new connection ID frame sequence number."); + return false; + } + + if (!reader->ReadVarInt62(&frame->retire_prior_to)) { + set_detailed_error( + "Unable to read new connection ID frame retire_prior_to."); + return false; + } + if (frame->retire_prior_to > frame->sequence_number) { + set_detailed_error("Retire_prior_to > sequence_number."); + return false; + } + + if (!reader->ReadLengthPrefixedConnectionId(&frame->connection_id)) { + set_detailed_error("Unable to read new connection ID frame connection id."); + return false; + } + + if (!QuicUtils::IsConnectionIdValidForVersion(frame->connection_id, + transport_version())) { + set_detailed_error("Invalid new connection ID length for version."); + return false; + } + + if (!reader->ReadBytes(&frame->stateless_reset_token, + sizeof(frame->stateless_reset_token))) { + set_detailed_error("Can not read new connection ID frame reset token."); + return false; + } + return true; +} + +bool QuicFramer::AppendRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame, QuicDataWriter* writer) { + if (!writer->WriteVarInt62(frame.sequence_number)) { + set_detailed_error("Can not write Retire Connection ID sequence number"); + return false; + } + return true; +} + +bool QuicFramer::ProcessRetireConnectionIdFrame( + QuicDataReader* reader, QuicRetireConnectionIdFrame* frame) { + if (!reader->ReadVarInt62(&frame->sequence_number)) { + set_detailed_error( + "Unable to read retire connection ID frame sequence number."); + return false; + } + return true; +} + +bool QuicFramer::ReadUint32FromVarint62(QuicDataReader* reader, + QuicIetfFrameType type, + QuicStreamId* id) { + uint64_t temp_uint64; + if (!reader->ReadVarInt62(&temp_uint64)) { + set_detailed_error("Unable to read " + QuicIetfFrameTypeString(type) + + " frame stream id/count."); + return false; + } + if (temp_uint64 > kMaxQuicStreamId) { + set_detailed_error("Stream id/count of " + QuicIetfFrameTypeString(type) + + "frame is too large."); + return false; + } + *id = static_cast(temp_uint64); + return true; +} + +uint8_t QuicFramer::GetStreamFrameTypeByte(const QuicStreamFrame& frame, + bool last_frame_in_packet) const { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + return GetIetfStreamFrameTypeByte(frame, last_frame_in_packet); + } + uint8_t type_byte = 0; + // Fin bit. + type_byte |= frame.fin ? kQuicStreamFinMask : 0; + + // Data Length bit. + type_byte <<= kQuicStreamDataLengthShift; + type_byte |= last_frame_in_packet ? 0 : kQuicStreamDataLengthMask; + + // Offset 3 bits. + type_byte <<= kQuicStreamShift; + const size_t offset_len = GetStreamOffsetSize(frame.offset); + if (offset_len > 0) { + type_byte |= offset_len - 1; + } + + // stream id 2 bits. + type_byte <<= kQuicStreamIdShift; + type_byte |= GetStreamIdSize(frame.stream_id) - 1; + type_byte |= kQuicFrameTypeStreamMask; // Set Stream Frame Type to 1. + + return type_byte; +} + +uint8_t QuicFramer::GetIetfStreamFrameTypeByte( + const QuicStreamFrame& frame, bool last_frame_in_packet) const { + QUICHE_DCHECK(VersionHasIetfQuicFrames(version_.transport_version)); + uint8_t type_byte = IETF_STREAM; + if (!last_frame_in_packet) { + type_byte |= IETF_STREAM_FRAME_LEN_BIT; + } + if (frame.offset != 0) { + type_byte |= IETF_STREAM_FRAME_OFF_BIT; + } + if (frame.fin) { + type_byte |= IETF_STREAM_FRAME_FIN_BIT; + } + return type_byte; +} + +void QuicFramer::InferPacketHeaderTypeFromVersion() { + // This function should only be called when server connection negotiates the + // version. + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_SERVER); + QUICHE_DCHECK(!infer_packet_header_type_from_version_); + infer_packet_header_type_from_version_ = true; +} + +void QuicFramer::EnableMultiplePacketNumberSpacesSupport() { + if (supports_multiple_packet_number_spaces_) { + QUIC_BUG(quic_bug_10850_91) + << "Multiple packet number spaces has already been enabled"; + return; + } + if (largest_packet_number_.IsInitialized()) { + QUIC_BUG(quic_bug_10850_92) + << "Try to enable multiple packet number spaces support after any " + "packet has been received."; + return; + } + + supports_multiple_packet_number_spaces_ = true; +} + +// static +QuicErrorCode QuicFramer::ParsePublicHeaderDispatcher( + const QuicEncryptedPacket& packet, + uint8_t expected_destination_connection_id_length, + PacketHeaderFormat* format, QuicLongHeaderType* long_packet_type, + bool* version_present, bool* has_length_prefix, + QuicVersionLabel* version_label, ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, + absl::optional* retry_token, + std::string* detailed_error) { + QuicDataReader reader(packet.data(), packet.length()); + if (reader.IsDoneReading()) { + *detailed_error = "Unable to read first byte."; + return QUIC_INVALID_PACKET_HEADER; + } + const uint8_t first_byte = reader.PeekByte(); + if ((first_byte & FLAGS_LONG_HEADER) == 0 && + (first_byte & FLAGS_FIXED_BIT) == 0 && + (first_byte & FLAGS_DEMULTIPLEXING_BIT) == 0) { + // All versions of Google QUIC up to and including Q043 set + // FLAGS_DEMULTIPLEXING_BIT to one on all client-to-server packets. Q044 + // and Q045 were never default-enabled in production. All subsequent + // versions of Google QUIC (starting with Q046) require FLAGS_FIXED_BIT to + // be set to one on all packets. All versions of IETF QUIC (since + // draft-ietf-quic-transport-17 which was earlier than the first IETF QUIC + // version that was deployed in production by any implementation) also + // require FLAGS_FIXED_BIT to be set to one on all packets. If a packet + // has the FLAGS_LONG_HEADER bit set to one, it could be a first flight + // from an unknown future version that allows the other two bits to be set + // to zero. Based on this, packets that have all three of those bits set + // to zero are known to be invalid. + *detailed_error = "Invalid flags."; + return QUIC_INVALID_PACKET_HEADER; + } + const bool ietf_format = QuicUtils::IsIetfPacketHeader(first_byte); + uint8_t unused_first_byte; + quiche::QuicheVariableLengthIntegerLength retry_token_length_length; + absl::string_view maybe_retry_token; + QuicErrorCode error_code = ParsePublicHeader( + &reader, expected_destination_connection_id_length, ietf_format, + &unused_first_byte, format, version_present, has_length_prefix, + version_label, parsed_version, destination_connection_id, + source_connection_id, long_packet_type, &retry_token_length_length, + &maybe_retry_token, detailed_error); + if (retry_token_length_length != quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0) { + *retry_token = maybe_retry_token; + } else { + retry_token->reset(); + } + return error_code; +} + +// static +QuicErrorCode QuicFramer::ParsePublicHeaderDispatcherShortHeaderLengthUnknown( + const QuicEncryptedPacket& packet, PacketHeaderFormat* format, + QuicLongHeaderType* long_packet_type, bool* version_present, + bool* has_length_prefix, QuicVersionLabel* version_label, + ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, + absl::optional* retry_token, std::string* detailed_error, + ConnectionIdGeneratorInterface& generator) { + QuicDataReader reader(packet.data(), packet.length()); + // Get the first two bytes. + if (reader.BytesRemaining() < 2) { + *detailed_error = "Unable to read first two bytes."; + return QUIC_INVALID_PACKET_HEADER; + } + uint8_t two_bytes[2]; + reader.ReadBytes(two_bytes, 2); + uint8_t expected_destination_connection_id_length = + (!QuicUtils::IsIetfPacketHeader(two_bytes[0]) || + two_bytes[0] & FLAGS_LONG_HEADER) + ? 0 + : generator.ConnectionIdLength(two_bytes[1]); + return ParsePublicHeaderDispatcher( + packet, expected_destination_connection_id_length, format, + long_packet_type, version_present, has_length_prefix, version_label, + parsed_version, destination_connection_id, source_connection_id, + retry_token, detailed_error); +} + +// static +QuicErrorCode QuicFramer::ParsePublicHeaderGoogleQuic( + QuicDataReader* reader, uint8_t* first_byte, PacketHeaderFormat* format, + bool* version_present, QuicVersionLabel* version_label, + ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, std::string* detailed_error) { + *format = GOOGLE_QUIC_PACKET; + *version_present = (*first_byte & PACKET_PUBLIC_FLAGS_VERSION) != 0; + uint8_t destination_connection_id_length = 0; + if ((*first_byte & PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID) != 0) { + destination_connection_id_length = kQuicDefaultConnectionIdLength; + } + if (!reader->ReadConnectionId(destination_connection_id, + destination_connection_id_length)) { + *detailed_error = "Unable to read ConnectionId."; + return QUIC_INVALID_PACKET_HEADER; + } + if (*version_present) { + if (!ProcessVersionLabel(reader, version_label)) { + *detailed_error = "Unable to read protocol version."; + return QUIC_INVALID_PACKET_HEADER; + } + *parsed_version = ParseQuicVersionLabel(*version_label); + } + return QUIC_NO_ERROR; +} + +namespace { + +const QuicVersionLabel kProxVersionLabel = 0x50524F58; // "PROX" + +inline bool PacketHasLengthPrefixedConnectionIds( + const QuicDataReader& reader, ParsedQuicVersion parsed_version, + QuicVersionLabel version_label, uint8_t first_byte) { + if (parsed_version.IsKnown()) { + return parsed_version.HasLengthPrefixedConnectionIds(); + } + + // Received unsupported version, check known old unsupported versions. + if (QuicVersionLabelUses4BitConnectionIdLength(version_label)) { + return false; + } + + // Received unknown version, check connection ID length byte. + if (reader.IsDoneReading()) { + // This check is required to safely peek the connection ID length byte. + return true; + } + const uint8_t connection_id_length_byte = reader.PeekByte(); + + // Check for packets produced by older versions of + // QuicFramer::WriteClientVersionNegotiationProbePacket + if (first_byte == 0xc0 && (connection_id_length_byte & 0x0f) == 0 && + connection_id_length_byte >= 0x50 && version_label == 0xcabadaba) { + return false; + } + + // Check for munged packets with version tag PROX. + if ((connection_id_length_byte & 0x0f) == 0 && + connection_id_length_byte >= 0x20 && version_label == kProxVersionLabel) { + return false; + } + + return true; +} + +inline bool ParseLongHeaderConnectionIds( + QuicDataReader& reader, bool has_length_prefix, + QuicVersionLabel version_label, QuicConnectionId& destination_connection_id, + QuicConnectionId& source_connection_id, std::string& detailed_error) { + if (has_length_prefix) { + if (!reader.ReadLengthPrefixedConnectionId(&destination_connection_id)) { + detailed_error = "Unable to read destination connection ID."; + return false; + } + if (!reader.ReadLengthPrefixedConnectionId(&source_connection_id)) { + if (version_label == kProxVersionLabel) { + // The "PROX" version does not follow the length-prefixed invariants, + // and can therefore attempt to read a payload byte and interpret it + // as the source connection ID length, which could fail to parse. + // In that scenario we keep the source connection ID empty but mark + // parsing as successful. + return true; + } + detailed_error = "Unable to read source connection ID."; + return false; + } + } else { + // Parse connection ID lengths. + uint8_t connection_id_lengths_byte; + if (!reader.ReadUInt8(&connection_id_lengths_byte)) { + detailed_error = "Unable to read connection ID lengths."; + return false; + } + uint8_t destination_connection_id_length = + (connection_id_lengths_byte & kDestinationConnectionIdLengthMask) >> 4; + if (destination_connection_id_length != 0) { + destination_connection_id_length += kConnectionIdLengthAdjustment; + } + uint8_t source_connection_id_length = + connection_id_lengths_byte & kSourceConnectionIdLengthMask; + if (source_connection_id_length != 0) { + source_connection_id_length += kConnectionIdLengthAdjustment; + } + + // Read destination connection ID. + if (!reader.ReadConnectionId(&destination_connection_id, + destination_connection_id_length)) { + detailed_error = "Unable to read destination connection ID."; + return false; + } + + // Read source connection ID. + if (!reader.ReadConnectionId(&source_connection_id, + source_connection_id_length)) { + detailed_error = "Unable to read source connection ID."; + return false; + } + } + return true; +} + +} // namespace + +// static +QuicErrorCode QuicFramer::ParsePublicHeader( + QuicDataReader* reader, uint8_t expected_destination_connection_id_length, + bool ietf_format, uint8_t* first_byte, PacketHeaderFormat* format, + bool* version_present, bool* has_length_prefix, + QuicVersionLabel* version_label, ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, + QuicLongHeaderType* long_packet_type, + quiche::QuicheVariableLengthIntegerLength* retry_token_length_length, + absl::string_view* retry_token, std::string* detailed_error) { + *version_present = false; + *has_length_prefix = false; + *version_label = 0; + *parsed_version = UnsupportedQuicVersion(); + *source_connection_id = EmptyQuicConnectionId(); + *long_packet_type = INVALID_PACKET_TYPE; + *retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + *retry_token = absl::string_view(); + *detailed_error = ""; + + if (!reader->ReadUInt8(first_byte)) { + *detailed_error = "Unable to read first byte."; + return QUIC_INVALID_PACKET_HEADER; + } + + if (!ietf_format) { + return ParsePublicHeaderGoogleQuic( + reader, first_byte, format, version_present, version_label, + parsed_version, destination_connection_id, detailed_error); + } + + *format = GetIetfPacketHeaderFormat(*first_byte); + + if (*format == IETF_QUIC_SHORT_HEADER_PACKET) { + if (!reader->ReadConnectionId(destination_connection_id, + expected_destination_connection_id_length)) { + *detailed_error = "Unable to read destination connection ID."; + return QUIC_INVALID_PACKET_HEADER; + } + return QUIC_NO_ERROR; + } + + QUICHE_DCHECK_EQ(IETF_QUIC_LONG_HEADER_PACKET, *format); + *version_present = true; + if (!ProcessVersionLabel(reader, version_label)) { + *detailed_error = "Unable to read protocol version."; + return QUIC_INVALID_PACKET_HEADER; + } + + if (*version_label == 0) { + *long_packet_type = VERSION_NEGOTIATION; + } + + // Parse version. + *parsed_version = ParseQuicVersionLabel(*version_label); + + // Figure out which IETF QUIC invariants this packet follows. + *has_length_prefix = PacketHasLengthPrefixedConnectionIds( + *reader, *parsed_version, *version_label, *first_byte); + + // Parse connection IDs. + if (!ParseLongHeaderConnectionIds(*reader, *has_length_prefix, *version_label, + *destination_connection_id, + *source_connection_id, *detailed_error)) { + return QUIC_INVALID_PACKET_HEADER; + } + + if (!parsed_version->IsKnown()) { + // Skip parsing of long packet type and retry token for unknown versions. + return QUIC_NO_ERROR; + } + + // Parse long packet type. + *long_packet_type = GetLongHeaderType(*first_byte, *parsed_version); + + switch (*long_packet_type) { + case INVALID_PACKET_TYPE: + *detailed_error = "Unable to parse long packet type."; + return QUIC_INVALID_PACKET_HEADER; + case INITIAL: + if (!parsed_version->SupportsRetry()) { + // Retry token is only present on initial packets for some versions. + return QUIC_NO_ERROR; + } + break; + default: + return QUIC_NO_ERROR; + } + + *retry_token_length_length = reader->PeekVarInt62Length(); + uint64_t retry_token_length; + if (!reader->ReadVarInt62(&retry_token_length)) { + *retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + *detailed_error = "Unable to read retry token length."; + return QUIC_INVALID_PACKET_HEADER; + } + + if (!reader->ReadStringPiece(retry_token, retry_token_length)) { + *detailed_error = "Unable to read retry token."; + return QUIC_INVALID_PACKET_HEADER; + } + + return QUIC_NO_ERROR; +} + +// static +bool QuicFramer::WriteClientVersionNegotiationProbePacket( + char* packet_bytes, QuicByteCount packet_length, + const char* destination_connection_id_bytes, + uint8_t destination_connection_id_length) { + if (packet_bytes == nullptr) { + QUIC_BUG(quic_bug_10850_93) << "Invalid packet_bytes"; + return false; + } + if (packet_length < kMinPacketSizeForVersionNegotiation || + packet_length > 65535) { + QUIC_BUG(quic_bug_10850_94) << "Invalid packet_length"; + return false; + } + if (destination_connection_id_length > kQuicMaxConnectionId4BitLength || + destination_connection_id_length < kQuicDefaultConnectionIdLength) { + QUIC_BUG(quic_bug_10850_95) << "Invalid connection_id_length"; + return false; + } + // clang-format off + const unsigned char packet_start_bytes[] = { + // IETF long header with fixed bit set, type initial, all-0 encrypted bits. + 0xc0, + // Version, part of the IETF space reserved for negotiation. + // This intentionally differs from QuicVersionReservedForNegotiation() + // to allow differentiating them over the wire. + 0xca, 0xba, 0xda, 0xda, + }; + // clang-format on + static_assert(sizeof(packet_start_bytes) == 5, "bad packet_start_bytes size"); + QuicDataWriter writer(packet_length, packet_bytes); + if (!writer.WriteBytes(packet_start_bytes, sizeof(packet_start_bytes))) { + QUIC_BUG(quic_bug_10850_96) << "Failed to write packet start"; + return false; + } + + QuicConnectionId destination_connection_id(destination_connection_id_bytes, + destination_connection_id_length); + if (!AppendIetfConnectionIds( + /*version_flag=*/true, /*use_length_prefix=*/true, + destination_connection_id, EmptyQuicConnectionId(), &writer)) { + QUIC_BUG(quic_bug_10850_97) << "Failed to write connection IDs"; + return false; + } + // Add 8 bytes of zeroes followed by 8 bytes of ones to ensure that this does + // not parse with any known version. The zeroes make sure that packet numbers, + // retry token lengths and payload lengths are parsed as zero, and if the + // zeroes are treated as padding frames, 0xff is known to not parse as a + // valid frame type. + if (!writer.WriteUInt64(0) || + !writer.WriteUInt64(std::numeric_limits::max())) { + QUIC_BUG(quic_bug_10850_98) << "Failed to write 18 bytes"; + return false; + } + // Make sure the polite greeting below is padded to a 16-byte boundary to + // make it easier to read in tcpdump. + while (writer.length() % 16 != 0) { + if (!writer.WriteUInt8(0)) { + QUIC_BUG(quic_bug_10850_99) << "Failed to write padding byte"; + return false; + } + } + // Add a polite greeting in case a human sees this in tcpdump. + static const char polite_greeting[] = + "This packet only exists to trigger IETF QUIC version negotiation. " + "Please respond with a Version Negotiation packet indicating what " + "versions you support. Thank you and have a nice day."; + if (!writer.WriteBytes(polite_greeting, sizeof(polite_greeting))) { + QUIC_BUG(quic_bug_10850_100) << "Failed to write polite greeting"; + return false; + } + // Fill the rest of the packet with zeroes. + writer.WritePadding(); + QUICHE_DCHECK_EQ(0u, writer.remaining()); + return true; +} + +// static +bool QuicFramer::ParseServerVersionNegotiationProbeResponse( + const char* packet_bytes, QuicByteCount packet_length, + char* source_connection_id_bytes, uint8_t* source_connection_id_length_out, + std::string* detailed_error) { + if (detailed_error == nullptr) { + QUIC_BUG(quic_bug_10850_101) << "Invalid error_details"; + return false; + } + *detailed_error = ""; + if (packet_bytes == nullptr) { + *detailed_error = "Invalid packet_bytes"; + return false; + } + if (packet_length < 6) { + *detailed_error = "Invalid packet_length"; + return false; + } + if (source_connection_id_bytes == nullptr) { + *detailed_error = "Invalid source_connection_id_bytes"; + return false; + } + if (source_connection_id_length_out == nullptr) { + *detailed_error = "Invalid source_connection_id_length_out"; + return false; + } + QuicDataReader reader(packet_bytes, packet_length); + uint8_t type_byte = 0; + if (!reader.ReadUInt8(&type_byte)) { + *detailed_error = "Failed to read type byte"; + return false; + } + if ((type_byte & 0x80) == 0) { + *detailed_error = "Packet does not have long header"; + return false; + } + uint32_t version = 0; + if (!reader.ReadUInt32(&version)) { + *detailed_error = "Failed to read version"; + return false; + } + if (version != 0) { + *detailed_error = "Packet is not a version negotiation packet"; + return false; + } + + QuicConnectionId destination_connection_id, source_connection_id; + if (!reader.ReadLengthPrefixedConnectionId(&destination_connection_id)) { + *detailed_error = "Failed to read destination connection ID"; + return false; + } + if (!reader.ReadLengthPrefixedConnectionId(&source_connection_id)) { + *detailed_error = "Failed to read source connection ID"; + return false; + } + + if (destination_connection_id.length() != 0) { + *detailed_error = "Received unexpected destination connection ID length"; + return false; + } + if (*source_connection_id_length_out < source_connection_id.length()) { + *detailed_error = + absl::StrCat("*source_connection_id_length_out too small ", + static_cast(*source_connection_id_length_out), " < ", + static_cast(source_connection_id.length())); + return false; + } + + memcpy(source_connection_id_bytes, source_connection_id.data(), + source_connection_id.length()); + *source_connection_id_length_out = source_connection_id.length(); + + return true; +} + +// Look for and parse the error code from the ":" text that +// may be present at the start of the CONNECTION_CLOSE error details string. +// This text, inserted by the peer if it's using Google's QUIC implementation, +// contains additional error information that narrows down the exact error. If +// the string is not found, or is not properly formed, it returns +// ErrorCode::QUIC_IETF_GQUIC_ERROR_MISSING +void MaybeExtractQuicErrorCode(QuicConnectionCloseFrame* frame) { + std::vector ed = absl::StrSplit(frame->error_details, ':'); + uint64_t extracted_error_code; + if (ed.size() < 2 || !quiche::QuicheTextUtils::IsAllDigits(ed[0]) || + !absl::SimpleAtoi(ed[0], &extracted_error_code)) { + if (frame->close_type == IETF_QUIC_TRANSPORT_CONNECTION_CLOSE && + frame->wire_error_code == NO_IETF_QUIC_ERROR) { + frame->quic_error_code = QUIC_NO_ERROR; + } else { + frame->quic_error_code = QUIC_IETF_GQUIC_ERROR_MISSING; + } + return; + } + // Return the error code (numeric) and the error details string without the + // error code prefix. Note that Split returns everything up to, but not + // including, the split character, so the length of ed[0] is just the number + // of digits in the error number. In removing the prefix, 1 is added to the + // length to account for the : + absl::string_view x = absl::string_view(frame->error_details); + x.remove_prefix(ed[0].length() + 1); + frame->error_details = std::string(x); + frame->quic_error_code = static_cast(extracted_error_code); +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_framer.h b/quiche/quic/core/quic_framer.h new file mode 100644 index 000000000000..349b80f76895 --- /dev/null +++ b/quiche/quic/core/quic_framer.h @@ -0,0 +1,1243 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_FRAMER_H_ +#define QUICHE_QUIC_CORE_QUIC_FRAMER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/connection_id_generator.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicFramerPeer; +} // namespace test + +class QuicDataReader; +class QuicDataWriter; +class QuicFramer; +class QuicStreamFrameDataProducer; + +// Number of bytes reserved for the frame type preceding each frame. +const size_t kQuicFrameTypeSize = 1; +// Number of bytes reserved for error code. +const size_t kQuicErrorCodeSize = 4; +// Number of bytes reserved to denote the length of error details field. +const size_t kQuicErrorDetailsLengthSize = 2; + +// Maximum number of bytes reserved for stream id. +const size_t kQuicMaxStreamIdSize = 4; +// Maximum number of bytes reserved for byte offset in stream frame. +const size_t kQuicMaxStreamOffsetSize = 8; +// Number of bytes reserved to store payload length in stream frame. +const size_t kQuicStreamPayloadLengthSize = 2; +// Number of bytes to reserve for IQ Error codes (for the Connection Close, +// Application Close, and Reset Stream frames). +const size_t kQuicIetfQuicErrorCodeSize = 2; +// Minimum size of the IETF QUIC Error Phrase's length field +const size_t kIetfQuicMinErrorPhraseLengthSize = 1; + +// Size in bytes reserved for the delta time of the largest observed +// packet number in ack frames. +const size_t kQuicDeltaTimeLargestObservedSize = 2; +// Size in bytes reserved for the number of received packets with timestamps. +const size_t kQuicNumTimestampsSize = 1; +// Size in bytes reserved for the number of missing packets in ack frames. +const size_t kNumberOfNackRangesSize = 1; +// Size in bytes reserved for the number of ack blocks in ack frames. +const size_t kNumberOfAckBlocksSize = 1; +// Maximum number of missing packet ranges that can fit within an ack frame. +const size_t kMaxNackRanges = (1 << (kNumberOfNackRangesSize * 8)) - 1; +// Maximum number of ack blocks that can fit within an ack frame. +const size_t kMaxAckBlocks = (1 << (kNumberOfAckBlocksSize * 8)) - 1; + +// This class receives callbacks from the framer when packets +// are processed. +class QUIC_EXPORT_PRIVATE QuicFramerVisitorInterface { + public: + virtual ~QuicFramerVisitorInterface() {} + + // Called if an error is detected in the QUIC protocol. + virtual void OnError(QuicFramer* framer) = 0; + + // Called only when |perspective_| is IS_SERVER and the framer gets a + // packet with version flag true and the version on the packet doesn't match + // |quic_version_|. The visitor should return true after it updates the + // version of the |framer_| to |received_version| or false to stop processing + // this packet. + virtual bool OnProtocolVersionMismatch( + ParsedQuicVersion received_version) = 0; + + // Called when a new packet has been received, before it + // has been validated or processed. + virtual void OnPacket() = 0; + + // Called when a public reset packet has been parsed but has not yet + // been validated. + virtual void OnPublicResetPacket(const QuicPublicResetPacket& packet) = 0; + + // Called only when |perspective_| is IS_CLIENT and a version negotiation + // packet has been parsed. + virtual void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& packet) = 0; + + // Called only when |perspective_| is IS_CLIENT and a retry packet has been + // parsed. |new_connection_id| contains the value of the Source Connection + // ID field, and |retry_token| contains the value of the Retry Token field. + // On versions where UsesTls() is false, + // |original_connection_id| contains the value of the Original Destination + // Connection ID field, and both |retry_integrity_tag| and + // |retry_without_tag| are empty. + // On versions where UsesTls() is true, + // |original_connection_id| is empty, |retry_integrity_tag| contains the + // value of the Retry Integrity Tag field, and |retry_without_tag| contains + // the entire RETRY packet except the Retry Integrity Tag field. + virtual void OnRetryPacket(QuicConnectionId original_connection_id, + QuicConnectionId new_connection_id, + absl::string_view retry_token, + absl::string_view retry_integrity_tag, + absl::string_view retry_without_tag) = 0; + + // Called when all fields except packet number has been parsed, but has not + // been authenticated. If it returns false, framing for this packet will + // cease. + virtual bool OnUnauthenticatedPublicHeader( + const QuicPacketHeader& header) = 0; + + // Called when the unauthenticated portion of the header has been parsed. + // If OnUnauthenticatedHeader returns false, framing for this packet will + // cease. + virtual bool OnUnauthenticatedHeader(const QuicPacketHeader& header) = 0; + + // Called when a packet has been decrypted. |length| is the packet length, + // and |level| is the encryption level of the packet. + virtual void OnDecryptedPacket(size_t length, EncryptionLevel level) = 0; + + // Called when the complete header of a packet had been parsed. + // If OnPacketHeader returns false, framing for this packet will cease. + virtual bool OnPacketHeader(const QuicPacketHeader& header) = 0; + + // Called when the packet being processed contains multiple IETF QUIC packets, + // which is due to there being more data after what is covered by the length + // field. |packet| contains the remaining data which can be processed. + // Note that this is called when the framer parses the length field, before + // it attempts to decrypt the first payload. It is the visitor's + // responsibility to buffer the packet and call ProcessPacket on it + // after the framer is done parsing the current payload. |packet| does not + // own its internal buffer, the visitor should make a copy of it. + virtual void OnCoalescedPacket(const QuicEncryptedPacket& packet) = 0; + + // Called when the packet being processed failed to decrypt. + // |has_decryption_key| indicates whether the framer knew which decryption + // key to use for this packet and already had a suitable key. + virtual void OnUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, + bool has_decryption_key) = 0; + + // Called when a StreamFrame has been parsed. + virtual bool OnStreamFrame(const QuicStreamFrame& frame) = 0; + + // Called when a CRYPTO frame has been parsed. + virtual bool OnCryptoFrame(const QuicCryptoFrame& frame) = 0; + + // Called when largest acked of an AckFrame has been parsed. + virtual bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) = 0; + + // Called when ack range [start, end) of an AckFrame has been parsed. + virtual bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) = 0; + + // Called when a timestamp in the AckFrame has been parsed. + virtual bool OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) = 0; + + // Called after the last ack range in an AckFrame has been parsed. + // |start| is the starting value of the last ack range. |ecn_counts| are + // the reported ECN counts in the ack frame, if present. + virtual bool OnAckFrameEnd( + QuicPacketNumber start, + const absl::optional& ecn_counts) = 0; + + // Called when a StopWaitingFrame has been parsed. + virtual bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) = 0; + + // Called when a QuicPaddingFrame has been parsed. + virtual bool OnPaddingFrame(const QuicPaddingFrame& frame) = 0; + + // Called when a PingFrame has been parsed. + virtual bool OnPingFrame(const QuicPingFrame& frame) = 0; + + // Called when a RstStreamFrame has been parsed. + virtual bool OnRstStreamFrame(const QuicRstStreamFrame& frame) = 0; + + // Called when a ConnectionCloseFrame, of any type, has been parsed. + virtual bool OnConnectionCloseFrame( + const QuicConnectionCloseFrame& frame) = 0; + + // Called when a StopSendingFrame has been parsed. + virtual bool OnStopSendingFrame(const QuicStopSendingFrame& frame) = 0; + + // Called when a PathChallengeFrame has been parsed. + virtual bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) = 0; + + // Called when a PathResponseFrame has been parsed. + virtual bool OnPathResponseFrame(const QuicPathResponseFrame& frame) = 0; + + // Called when a GoAwayFrame has been parsed. + virtual bool OnGoAwayFrame(const QuicGoAwayFrame& frame) = 0; + + // Called when a WindowUpdateFrame has been parsed. + virtual bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) = 0; + + // Called when a BlockedFrame has been parsed. + virtual bool OnBlockedFrame(const QuicBlockedFrame& frame) = 0; + + // Called when a NewConnectionIdFrame has been parsed. + virtual bool OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& frame) = 0; + + // Called when a RetireConnectionIdFrame has been parsed. + virtual bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) = 0; + + // Called when a NewTokenFrame has been parsed. + virtual bool OnNewTokenFrame(const QuicNewTokenFrame& frame) = 0; + + // Called when a message frame has been parsed. + virtual bool OnMessageFrame(const QuicMessageFrame& frame) = 0; + + // Called when a handshake done frame has been parsed. + virtual bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) = 0; + + // Called when an AckFrequencyFrame has been parsed. + virtual bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) = 0; + + // Called when a packet has been completely processed. + virtual void OnPacketComplete() = 0; + + // Called to check whether |token| is a valid stateless reset token. + virtual bool IsValidStatelessResetToken( + const StatelessResetToken& token) const = 0; + + // Called when an IETF stateless reset packet has been parsed and validated + // with the stateless reset token. + virtual void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& packet) = 0; + + // Called when an IETF MaxStreams frame has been parsed. + virtual bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) = 0; + + // Called when an IETF StreamsBlocked frame has been parsed. + virtual bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) = 0; + + // Called when a Key Phase Update has been initiated. This is called for both + // locally and peer initiated key updates. If the key update was locally + // initiated, this does not indicate the peer has received the key update yet. + virtual void OnKeyUpdate(KeyUpdateReason reason) = 0; + + // Called on the first decrypted packet in each key phase (including the + // first key phase.) + virtual void OnDecryptedFirstPacketInKeyPhase() = 0; + + // Called when the framer needs to generate a decrypter for the next key + // phase. Each call should generate the key for phase n+1. + virtual std::unique_ptr + AdvanceKeysAndCreateCurrentOneRttDecrypter() = 0; + + // Called when the framer needs to generate an encrypter. The key corresponds + // to the key phase of the last decrypter returned by + // AdvanceKeysAndCreateCurrentOneRttDecrypter(). + virtual std::unique_ptr CreateCurrentOneRttEncrypter() = 0; +}; + +// Class for parsing and constructing QUIC packets. It has a +// QuicFramerVisitorInterface that is called when packets are parsed. +class QUIC_EXPORT_PRIVATE QuicFramer { + public: + // Constructs a new framer that installs a kNULL QuicEncrypter and + // QuicDecrypter for level ENCRYPTION_INITIAL. |supported_versions| specifies + // the list of supported QUIC versions. |quic_version_| is set to the maximum + // version in |supported_versions|. + QuicFramer(const ParsedQuicVersionVector& supported_versions, + QuicTime creation_time, Perspective perspective, + uint8_t expected_server_connection_id_length); + QuicFramer(const QuicFramer&) = delete; + QuicFramer& operator=(const QuicFramer&) = delete; + + virtual ~QuicFramer(); + + // Returns true if |version| is a supported transport version. + bool IsSupportedTransportVersion(const QuicTransportVersion version) const; + + // Returns true if |version| is a supported protocol version. + bool IsSupportedVersion(const ParsedQuicVersion version) const; + + // Set callbacks to be called from the framer. A visitor must be set, or + // else the framer will likely crash. It is acceptable for the visitor + // to do nothing. If this is called multiple times, only the last visitor + // will be used. + void set_visitor(QuicFramerVisitorInterface* visitor) { visitor_ = visitor; } + + const ParsedQuicVersionVector& supported_versions() const { + return supported_versions_; + } + + QuicTransportVersion transport_version() const { + return version_.transport_version; + } + + ParsedQuicVersion version() const { return version_; } + + void set_version(const ParsedQuicVersion version); + + // Does not QUICHE_DCHECK for supported version. Used by tests to set + // unsupported version to trigger version negotiation. + void set_version_for_tests(const ParsedQuicVersion version) { + version_ = version; + } + + QuicErrorCode error() const { return error_; } + + // Allows enabling or disabling of timestamp processing and serialization. + // TODO(ianswett): Remove the const once timestamps are negotiated via + // transport params. + void set_process_timestamps(bool process_timestamps) const { + process_timestamps_ = process_timestamps; + } + + // Sets the max number of receive timestamps to send per ACK frame. + // TODO(wub): Remove the const once timestamps are negotiated via + // transport params. + void set_max_receive_timestamps_per_ack(uint32_t max_timestamps) const { + max_receive_timestamps_per_ack_ = max_timestamps; + } + + // Sets the exponent to use when writing/reading ACK receive timestamps. + void set_receive_timestamps_exponent(uint32_t exponent) const { + receive_timestamps_exponent_ = exponent; + } + + // Pass a UDP packet into the framer for parsing. + // Return true if the packet was processed successfully. |packet| must be a + // single, complete UDP packet (not a frame of a packet). This packet + // might be null padded past the end of the payload, which will be correctly + // ignored. + bool ProcessPacket(const QuicEncryptedPacket& packet); + + // Whether we are in the middle of a call to this->ProcessPacket. + bool is_processing_packet() const { return is_processing_packet_; } + + // Largest size in bytes of all stream frame fields without the payload. + static size_t GetMinStreamFrameSize(QuicTransportVersion version, + QuicStreamId stream_id, + QuicStreamOffset offset, + bool last_frame_in_packet, + size_t data_length); + // Returns the overhead of framing a CRYPTO frame with the specific offset and + // data length provided, but not counting the size of the data payload. + static size_t GetMinCryptoFrameSize(QuicStreamOffset offset, + QuicPacketLength data_length); + static size_t GetMessageFrameSize(QuicTransportVersion version, + bool last_frame_in_packet, + QuicByteCount length); + // Size in bytes of all ack frame fields without the missing packets or ack + // blocks. + static size_t GetMinAckFrameSize(QuicTransportVersion version, + const QuicAckFrame& ack_frame, + uint32_t local_ack_delay_exponent, + bool use_ietf_ack_with_receive_timestamp); + // Size in bytes of a stop waiting frame. + static size_t GetStopWaitingFrameSize( + QuicPacketNumberLength packet_number_length); + // Size in bytes of all reset stream frame fields. + static size_t GetRstStreamFrameSize(QuicTransportVersion version, + const QuicRstStreamFrame& frame); + // Size in bytes of all ack frenquency frame fields. + static size_t GetAckFrequencyFrameSize(const QuicAckFrequencyFrame& frame); + // Size in bytes of all connection close frame fields, including the error + // details. + static size_t GetConnectionCloseFrameSize( + QuicTransportVersion version, const QuicConnectionCloseFrame& frame); + // Size in bytes of all GoAway frame fields without the reason phrase. + static size_t GetMinGoAwayFrameSize(); + // Size in bytes of all WindowUpdate frame fields. + // For version 99, determines whether a MAX DATA or MAX STREAM DATA frame will + // be generated and calculates the appropriate size. + static size_t GetWindowUpdateFrameSize(QuicTransportVersion version, + const QuicWindowUpdateFrame& frame); + // Size in bytes of all MaxStreams frame fields. + static size_t GetMaxStreamsFrameSize(QuicTransportVersion version, + const QuicMaxStreamsFrame& frame); + // Size in bytes of all StreamsBlocked frame fields. + static size_t GetStreamsBlockedFrameSize( + QuicTransportVersion version, const QuicStreamsBlockedFrame& frame); + // Size in bytes of all Blocked frame fields. + static size_t GetBlockedFrameSize(QuicTransportVersion version, + const QuicBlockedFrame& frame); + // Size in bytes of PathChallenge frame. + static size_t GetPathChallengeFrameSize(const QuicPathChallengeFrame& frame); + // Size in bytes of PathResponse frame. + static size_t GetPathResponseFrameSize(const QuicPathResponseFrame& frame); + // Size in bytes required to serialize the stream id. + static size_t GetStreamIdSize(QuicStreamId stream_id); + // Size in bytes required to serialize the stream offset. + static size_t GetStreamOffsetSize(QuicStreamOffset offset); + // Size in bytes for a serialized new connection id frame + static size_t GetNewConnectionIdFrameSize( + const QuicNewConnectionIdFrame& frame); + + // Size in bytes for a serialized retire connection id frame + static size_t GetRetireConnectionIdFrameSize( + const QuicRetireConnectionIdFrame& frame); + + // Size in bytes for a serialized new token frame + static size_t GetNewTokenFrameSize(const QuicNewTokenFrame& frame); + + // Size in bytes required for a serialized stop sending frame. + static size_t GetStopSendingFrameSize(const QuicStopSendingFrame& frame); + + // Size in bytes required for a serialized retransmittable control |frame|. + static size_t GetRetransmittableControlFrameSize(QuicTransportVersion version, + const QuicFrame& frame); + + // Returns the number of bytes added to the packet for the specified frame, + // and 0 if the frame doesn't fit. Includes the header size for the first + // frame. + size_t GetSerializedFrameLength(const QuicFrame& frame, size_t free_bytes, + bool first_frame_in_packet, + bool last_frame_in_packet, + QuicPacketNumberLength packet_number_length); + + // Returns the associated data from the encrypted packet |encrypted| as a + // stringpiece. + static absl::string_view GetAssociatedDataFromEncryptedPacket( + QuicTransportVersion version, const QuicEncryptedPacket& encrypted, + uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool includes_version, + bool includes_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + uint64_t retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length); + + // Parses the unencrypted fields in a QUIC header using |reader| as input, + // stores the result in the other parameters. + // |expected_destination_connection_id_length| is only used for short headers. + // When server connection IDs are generated by a + // ConnectionIdGeneartor interface, and callers need an accurate + // Destination Connection ID for short header packets, call + // ParsePublicHeaderDispatcherShortHeaderLengthUnknown() instead. + static QuicErrorCode ParsePublicHeader( + QuicDataReader* reader, uint8_t expected_destination_connection_id_length, + bool ietf_format, uint8_t* first_byte, PacketHeaderFormat* format, + bool* version_present, bool* has_length_prefix, + QuicVersionLabel* version_label, ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, + QuicLongHeaderType* long_packet_type, + quiche::QuicheVariableLengthIntegerLength* retry_token_length_length, + absl::string_view* retry_token, std::string* detailed_error); + + // Parses the unencrypted fields in |packet| and stores them in the other + // parameters. This can only be called on the server. + // |expected_destination_connection_id_length| is only used + // for short headers. When callers need an accurate Destination Connection ID + // specifically for short header packets, call + // ParsePublicHeaderDispatcherShortHeaderLengthUnknown() instead. + static QuicErrorCode ParsePublicHeaderDispatcher( + const QuicEncryptedPacket& packet, + uint8_t expected_destination_connection_id_length, + PacketHeaderFormat* format, QuicLongHeaderType* long_packet_type, + bool* version_present, bool* has_length_prefix, + QuicVersionLabel* version_label, ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, + absl::optional* retry_token, + std::string* detailed_error); + + // Parses the unencrypted fields in |packet| and stores them in the other + // parameters. The only callers that should use this method are ones where + // (1) the short-header connection ID length is only known by looking at the + // connection ID itself (and |generator| can provide the answer), and (2) + // the caller is interested in the parsed contents even if the packet has a + // short header. Some callers are only interested in parsing long header + // packets to peer into the handshake, and should use + // ParsePublicHeaderDispatcher instead. + static QuicErrorCode ParsePublicHeaderDispatcherShortHeaderLengthUnknown( + const QuicEncryptedPacket& packet, PacketHeaderFormat* format, + QuicLongHeaderType* long_packet_type, bool* version_present, + bool* has_length_prefix, QuicVersionLabel* version_label, + ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, + QuicConnectionId* source_connection_id, + absl::optional* retry_token, + std::string* detailed_error, ConnectionIdGeneratorInterface& generator); + + // Serializes a packet containing |frames| into |buffer|. + // Returns the length of the packet, which must not be longer than + // |packet_length|. Returns 0 if it fails to serialize. + size_t BuildDataPacket(const QuicPacketHeader& header, + const QuicFrames& frames, char* buffer, + size_t packet_length, EncryptionLevel level); + + // Returns a new public reset packet. + static std::unique_ptr BuildPublicResetPacket( + const QuicPublicResetPacket& packet); + + // Returns the minimal stateless reset packet length. + static size_t GetMinStatelessResetPacketLength(); + + // Returns a new IETF stateless reset packet. + static std::unique_ptr BuildIetfStatelessResetPacket( + QuicConnectionId connection_id, size_t received_packet_length, + StatelessResetToken stateless_reset_token); + + // Returns a new IETF stateless reset packet with random bytes generated from + // |random|->InsecureRandBytes(). NOTE: the first two bits of the random bytes + // will be modified to 01b to make it look like a short header packet. + static std::unique_ptr BuildIetfStatelessResetPacket( + QuicConnectionId connection_id, size_t received_packet_length, + StatelessResetToken stateless_reset_token, QuicRandom* random); + + // Returns a new version negotiation packet. + static std::unique_ptr BuildVersionNegotiationPacket( + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, + bool use_length_prefix, const ParsedQuicVersionVector& versions); + + // Returns a new IETF version negotiation packet. + static std::unique_ptr BuildIetfVersionNegotiationPacket( + bool use_length_prefix, QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, + const ParsedQuicVersionVector& versions); + + // If header.version_flag is set, the version in the + // packet will be set -- but it will be set from version_ not + // header.versions. + bool AppendPacketHeader(const QuicPacketHeader& header, + QuicDataWriter* writer, size_t* length_field_offset); + bool AppendIetfHeaderTypeByte(const QuicPacketHeader& header, + QuicDataWriter* writer); + bool AppendIetfPacketHeader(const QuicPacketHeader& header, + QuicDataWriter* writer, + size_t* length_field_offset); + bool WriteIetfLongHeaderLength(const QuicPacketHeader& header, + QuicDataWriter* writer, + size_t length_field_offset, + EncryptionLevel level); + bool AppendTypeByte(const QuicFrame& frame, bool last_frame_in_packet, + QuicDataWriter* writer); + bool AppendIetfFrameType(const QuicFrame& frame, bool last_frame_in_packet, + QuicDataWriter* writer); + size_t AppendIetfFrames(const QuicFrames& frames, QuicDataWriter* writer); + bool AppendStreamFrame(const QuicStreamFrame& frame, + bool no_stream_frame_length, QuicDataWriter* writer); + bool AppendCryptoFrame(const QuicCryptoFrame& frame, QuicDataWriter* writer); + bool AppendAckFrequencyFrame(const QuicAckFrequencyFrame& frame, + QuicDataWriter* writer); + + // SetDecrypter sets the primary decrypter, replacing any that already exists. + // If an alternative decrypter is in place then the function QUICHE_DCHECKs. + // This is intended for cases where one knows that future packets will be + // using the new decrypter and the previous decrypter is now obsolete. |level| + // indicates the encryption level of the new decrypter. + void SetDecrypter(EncryptionLevel level, + std::unique_ptr decrypter); + + // SetAlternativeDecrypter sets a decrypter that may be used to decrypt + // future packets. |level| indicates the encryption level of the decrypter. If + // |latch_once_used| is true, then the first time that the decrypter is + // successful it will replace the primary decrypter. Otherwise both + // decrypters will remain active and the primary decrypter will be the one + // last used. + void SetAlternativeDecrypter(EncryptionLevel level, + std::unique_ptr decrypter, + bool latch_once_used); + + void InstallDecrypter(EncryptionLevel level, + std::unique_ptr decrypter); + void RemoveDecrypter(EncryptionLevel level); + + // Enables key update support. + void SetKeyUpdateSupportForConnection(bool enabled); + // Discard the decrypter for the previous key phase. + void DiscardPreviousOneRttKeys(); + // Update the key phase. + bool DoKeyUpdate(KeyUpdateReason reason); + // Returns the count of packets received that appeared to attempt a key + // update but failed decryption which have been received since the last + // successfully decrypted packet. + QuicPacketCount PotentialPeerKeyUpdateAttemptCount() const; + + const QuicDecrypter* GetDecrypter(EncryptionLevel level) const; + const QuicDecrypter* decrypter() const; + const QuicDecrypter* alternative_decrypter() const; + + // Changes the encrypter used for level |level| to |encrypter|. + void SetEncrypter(EncryptionLevel level, + std::unique_ptr encrypter); + + // Called to remove encrypter of encryption |level|. + void RemoveEncrypter(EncryptionLevel level); + + // Sets the encrypter and decrypter for the ENCRYPTION_INITIAL level. + void SetInitialObfuscators(QuicConnectionId connection_id); + + // Encrypts a payload in |buffer|. |ad_len| is the length of the associated + // data. |total_len| is the length of the associated data plus plaintext. + // |buffer_len| is the full length of the allocated buffer. + size_t EncryptInPlace(EncryptionLevel level, QuicPacketNumber packet_number, + size_t ad_len, size_t total_len, size_t buffer_len, + char* buffer); + + // Returns the length of the data encrypted into |buffer| if |buffer_len| is + // long enough, and otherwise 0. + size_t EncryptPayload(EncryptionLevel level, QuicPacketNumber packet_number, + const QuicPacket& packet, char* buffer, + size_t buffer_len); + + // Returns the length of the ciphertext that would be generated by encrypting + // to plaintext of size |plaintext_size| at the given level. + size_t GetCiphertextSize(EncryptionLevel level, size_t plaintext_size) const; + + // Returns the maximum length of plaintext that can be encrypted + // to ciphertext no larger than |ciphertext_size|. + size_t GetMaxPlaintextSize(size_t ciphertext_size); + + // Returns the maximum number of packets that can be safely encrypted with + // the active AEAD. 1-RTT keys must be set before calling this method. + QuicPacketCount GetOneRttEncrypterConfidentialityLimit() const; + + const std::string& detailed_error() { return detailed_error_; } + + // The minimum packet number length required to represent |packet_number|. + static QuicPacketNumberLength GetMinPacketNumberLength( + QuicPacketNumber packet_number); + + void SetSupportedVersions(const ParsedQuicVersionVector& versions) { + supported_versions_ = versions; + version_ = versions[0]; + } + + // Tell framer to infer packet header type from version_. + void InferPacketHeaderTypeFromVersion(); + + // Returns true if |header| is considered as an stateless reset packet. + bool IsIetfStatelessResetPacket(const QuicPacketHeader& header) const; + + // Returns true if encrypter of |level| is available. + bool HasEncrypterOfEncryptionLevel(EncryptionLevel level) const; + // Returns true if decrypter of |level| is available. + bool HasDecrypterOfEncryptionLevel(EncryptionLevel level) const; + + // Returns true if an encrypter of |space| is available. + bool HasAnEncrypterForSpace(PacketNumberSpace space) const; + + // Returns the encryption level to send application data. This should be only + // called with available encrypter for application data. + EncryptionLevel GetEncryptionLevelToSendApplicationData() const; + + void set_validate_flags(bool value) { validate_flags_ = value; } + + Perspective perspective() const { return perspective_; } + + QuicStreamFrameDataProducer* data_producer() const { return data_producer_; } + + void set_data_producer(QuicStreamFrameDataProducer* data_producer) { + data_producer_ = data_producer; + } + + QuicTime creation_time() const { return creation_time_; } + + QuicPacketNumber first_sending_packet_number() const { + return first_sending_packet_number_; + } + + uint64_t current_received_frame_type() const { + return current_received_frame_type_; + } + + uint64_t previously_received_frame_type() const { + return previously_received_frame_type_; + } + + // The connection ID length the framer expects on incoming IETF short headers + // on the server. + uint8_t GetExpectedServerConnectionIdLength() { + return expected_server_connection_id_length_; + } + + // Change the expected destination connection ID length for short headers on + // the client. + void SetExpectedClientConnectionIdLength( + uint8_t expected_client_connection_id_length) { + expected_client_connection_id_length_ = + expected_client_connection_id_length; + } + + void EnableMultiplePacketNumberSpacesSupport(); + + // Writes an array of bytes that, if sent as a UDP datagram, will trigger + // IETF QUIC Version Negotiation on servers. The bytes will be written to + // |packet_bytes|, which must point to |packet_length| bytes of memory. + // |packet_length| must be in the range [1200, 65535]. + // |destination_connection_id_bytes| will be sent as the destination + // connection ID, and must point to |destination_connection_id_length| bytes + // of memory. |destination_connection_id_length| must be in the range [8,18]. + // When targeting Google servers, it is recommended to use a + // |destination_connection_id_length| of 8. + static bool WriteClientVersionNegotiationProbePacket( + char* packet_bytes, QuicByteCount packet_length, + const char* destination_connection_id_bytes, + uint8_t destination_connection_id_length); + + // Parses a packet which a QUIC server sent in response to a packet sent by + // WriteClientVersionNegotiationProbePacket. |packet_bytes| must point to + // |packet_length| bytes in memory which represent the response. + // |packet_length| must be greater or equal to 6. This method will fill in + // |source_connection_id_bytes| which must point to at least + // |*source_connection_id_length_out| bytes in memory. + // |*source_connection_id_length_out| must be at least 18. + // |*source_connection_id_length_out| will contain the length of the received + // source connection ID, which on success will match the contents of the + // destination connection ID passed in to + // WriteClientVersionNegotiationProbePacket. In the case of a failure, + // |detailed_error| will be filled in with an explanation of what failed. + static bool ParseServerVersionNegotiationProbeResponse( + const char* packet_bytes, QuicByteCount packet_length, + char* source_connection_id_bytes, + uint8_t* source_connection_id_length_out, std::string* detailed_error); + + void set_local_ack_delay_exponent(uint32_t exponent) { + local_ack_delay_exponent_ = exponent; + } + uint32_t local_ack_delay_exponent() const { + return local_ack_delay_exponent_; + } + + void set_peer_ack_delay_exponent(uint32_t exponent) { + peer_ack_delay_exponent_ = exponent; + } + uint32_t peer_ack_delay_exponent() const { return peer_ack_delay_exponent_; } + + void set_drop_incoming_retry_packets(bool drop_incoming_retry_packets) { + drop_incoming_retry_packets_ = drop_incoming_retry_packets; + } + + private: + friend class test::QuicFramerPeer; + + using NackRangeMap = std::map; + + // AckTimestampRange is a data structure derived from a QuicAckFrame. It is + // used to serialize timestamps in a IETF_ACK_RECEIVE_TIMESTAMPS frame. + struct QUIC_EXPORT_PRIVATE AckTimestampRange { + QuicPacketCount gap; + // |range_begin| and |range_end| are index(es) in + // QuicAckFrame.received_packet_times, representing a continuous range of + // packet numbers in descending order. |range_begin| >= |range_end|. + int64_t range_begin; // Inclusive + int64_t range_end; // Inclusive + }; + absl::InlinedVector GetAckTimestampRanges( + const QuicAckFrame& frame, std::string& detailed_error) const; + int64_t FrameAckTimestampRanges( + const QuicAckFrame& frame, + const absl::InlinedVector& timestamp_ranges, + QuicDataWriter* writer) const; + + struct QUIC_EXPORT_PRIVATE AckFrameInfo { + AckFrameInfo(); + AckFrameInfo(const AckFrameInfo& other); + ~AckFrameInfo(); + + // The maximum ack block length. + QuicPacketCount max_block_length; + // Length of first ack block. + QuicPacketCount first_block_length; + // Number of ACK blocks needed for the ACK frame. + size_t num_ack_blocks; + }; + + // Applies header protection to an IETF QUIC packet header in |buffer| using + // the encrypter for level |level|. The buffer has |buffer_len| bytes of data, + // with the first protected packet bytes starting at |ad_len|. + bool ApplyHeaderProtection(EncryptionLevel level, char* buffer, + size_t buffer_len, size_t ad_len); + + // Removes header protection from an IETF QUIC packet header. + // + // The packet number from the header is read from |reader|, where the packet + // number is the next contents in |reader|. |reader| is only advanced by the + // length of the packet number, but it is also used to peek the sample needed + // for removing header protection. + // + // Properties needed for removing header protection are read from |header|. + // The packet number length and type byte are written to |header|. + // + // The packet number, after removing header protection and decoding it, is + // written to |full_packet_number|. Finally, the header, with header + // protection removed, is written to |associated_data| to be used in packet + // decryption. |packet| is used in computing the asociated data. + bool RemoveHeaderProtection(QuicDataReader* reader, + const QuicEncryptedPacket& packet, + QuicPacketHeader* header, + uint64_t* full_packet_number, + std::vector* associated_data); + + bool ProcessDataPacket(QuicDataReader* reader, QuicPacketHeader* header, + const QuicEncryptedPacket& packet, + char* decrypted_buffer, size_t buffer_length); + + bool ProcessIetfDataPacket(QuicDataReader* encrypted_reader, + QuicPacketHeader* header, + const QuicEncryptedPacket& packet, + char* decrypted_buffer, size_t buffer_length); + + bool ProcessPublicResetPacket(QuicDataReader* reader, + const QuicPacketHeader& header); + + bool ProcessVersionNegotiationPacket(QuicDataReader* reader, + const QuicPacketHeader& header); + + bool ProcessRetryPacket(QuicDataReader* reader, + const QuicPacketHeader& header); + + void MaybeProcessCoalescedPacket(const QuicDataReader& encrypted_reader, + uint64_t remaining_bytes_length, + const QuicPacketHeader& header); + + bool MaybeProcessIetfLength(QuicDataReader* encrypted_reader, + QuicPacketHeader* header); + + bool ProcessPublicHeader(QuicDataReader* reader, + bool packet_has_ietf_packet_header, + QuicPacketHeader* header); + + // Processes the unauthenticated portion of the header into |header| from + // the current QuicDataReader. Returns true on success, false on failure. + bool ProcessUnauthenticatedHeader(QuicDataReader* encrypted_reader, + QuicPacketHeader* header); + + // Processes the version label in the packet header. + static bool ProcessVersionLabel(QuicDataReader* reader, + QuicVersionLabel* version_label); + + // Validates and updates |destination_connection_id_length| and + // |source_connection_id_length|. When + // |should_update_expected_server_connection_id_length| is true, length + // validation is disabled and |expected_server_connection_id_length| is set + // to the appropriate length. + // TODO(b/133873272) refactor this method. + static bool ProcessAndValidateIetfConnectionIdLength( + QuicDataReader* reader, ParsedQuicVersion version, + Perspective perspective, + bool should_update_expected_server_connection_id_length, + uint8_t* expected_server_connection_id_length, + uint8_t* destination_connection_id_length, + uint8_t* source_connection_id_length, std::string* detailed_error); + + bool ProcessIetfHeaderTypeByte(QuicDataReader* reader, + QuicPacketHeader* header); + bool ProcessIetfPacketHeader(QuicDataReader* reader, + QuicPacketHeader* header); + + // First processes possibly truncated packet number. Calculates the full + // packet number from the truncated one and the last seen packet number, and + // stores it to |packet_number|. + bool ProcessAndCalculatePacketNumber( + QuicDataReader* reader, QuicPacketNumberLength packet_number_length, + QuicPacketNumber base_packet_number, uint64_t* packet_number); + bool ProcessFrameData(QuicDataReader* reader, const QuicPacketHeader& header); + + static bool IsIetfFrameTypeExpectedForEncryptionLevel(uint64_t frame_type, + EncryptionLevel level); + + bool ProcessIetfFrameData(QuicDataReader* reader, + const QuicPacketHeader& header, + EncryptionLevel decrypted_level); + bool ProcessStreamFrame(QuicDataReader* reader, uint8_t frame_type, + QuicStreamFrame* frame); + bool ProcessAckFrame(QuicDataReader* reader, uint8_t frame_type); + bool ProcessTimestampsInAckFrame(uint8_t num_received_packets, + QuicPacketNumber largest_acked, + QuicDataReader* reader); + bool ProcessIetfAckFrame(QuicDataReader* reader, uint64_t frame_type, + QuicAckFrame* ack_frame); + bool ProcessIetfTimestampsInAckFrame(QuicPacketNumber largest_acked, + QuicDataReader* reader); + bool ProcessStopWaitingFrame(QuicDataReader* reader, + const QuicPacketHeader& header, + QuicStopWaitingFrame* stop_waiting); + bool ProcessRstStreamFrame(QuicDataReader* reader, QuicRstStreamFrame* frame); + bool ProcessConnectionCloseFrame(QuicDataReader* reader, + QuicConnectionCloseFrame* frame); + bool ProcessGoAwayFrame(QuicDataReader* reader, QuicGoAwayFrame* frame); + bool ProcessWindowUpdateFrame(QuicDataReader* reader, + QuicWindowUpdateFrame* frame); + bool ProcessBlockedFrame(QuicDataReader* reader, QuicBlockedFrame* frame); + void ProcessPaddingFrame(QuicDataReader* reader, QuicPaddingFrame* frame); + bool ProcessMessageFrame(QuicDataReader* reader, bool no_message_length, + QuicMessageFrame* frame); + + bool DecryptPayload(size_t udp_packet_length, absl::string_view encrypted, + absl::string_view associated_data, + const QuicPacketHeader& header, char* decrypted_buffer, + size_t buffer_length, size_t* decrypted_length, + EncryptionLevel* decrypted_level); + + // Returns the full packet number from the truncated + // wire format version and the last seen packet number. + uint64_t CalculatePacketNumberFromWire( + QuicPacketNumberLength packet_number_length, + QuicPacketNumber base_packet_number, uint64_t packet_number) const; + + // Returns the QuicTime::Delta corresponding to the time from when the framer + // was created. + const QuicTime::Delta CalculateTimestampFromWire(uint32_t time_delta_us); + + // Computes the wire size in bytes of time stamps in |ack|. + size_t GetAckFrameTimeStampSize(const QuicAckFrame& ack); + size_t GetIetfAckFrameTimestampSize(const QuicAckFrame& ack); + + // Computes the wire size in bytes of the |ack| frame. + size_t GetAckFrameSize(const QuicAckFrame& ack, + QuicPacketNumberLength packet_number_length); + // Computes the wire-size, in bytes, of the |frame| ack frame, for IETF Quic. + size_t GetIetfAckFrameSize(const QuicAckFrame& frame); + + // Computes the wire size in bytes of the |ack| frame. + size_t GetAckFrameSize(const QuicAckFrame& ack); + + // Computes the wire size in bytes of the payload of |frame|. + size_t ComputeFrameLength(const QuicFrame& frame, bool last_frame_in_packet, + QuicPacketNumberLength packet_number_length); + + static bool AppendPacketNumber(QuicPacketNumberLength packet_number_length, + QuicPacketNumber packet_number, + QuicDataWriter* writer); + static bool AppendStreamId(size_t stream_id_length, QuicStreamId stream_id, + QuicDataWriter* writer); + static bool AppendStreamOffset(size_t offset_length, QuicStreamOffset offset, + QuicDataWriter* writer); + + // Appends a single ACK block to |writer| and returns true if the block was + // successfully appended. + static bool AppendAckBlock(uint8_t gap, QuicPacketNumberLength length_length, + uint64_t length, QuicDataWriter* writer); + + static uint8_t GetPacketNumberFlags( + QuicPacketNumberLength packet_number_length); + + static AckFrameInfo GetAckFrameInfo(const QuicAckFrame& frame); + + static QuicErrorCode ParsePublicHeaderGoogleQuic( + QuicDataReader* reader, uint8_t* first_byte, PacketHeaderFormat* format, + bool* version_present, QuicVersionLabel* version_label, + ParsedQuicVersion* parsed_version, + QuicConnectionId* destination_connection_id, std::string* detailed_error); + + bool ValidateReceivedConnectionIds(const QuicPacketHeader& header); + + // The Append* methods attempt to write the provided header or frame using the + // |writer|, and return true if successful. + + bool AppendAckFrameAndTypeByte(const QuicAckFrame& frame, + QuicDataWriter* writer); + bool AppendTimestampsToAckFrame(const QuicAckFrame& frame, + QuicDataWriter* writer); + + // Append IETF format ACK frame. + // + // AppendIetfAckFrameAndTypeByte adds the IETF type byte and the body + // of the frame. + bool AppendIetfAckFrameAndTypeByte(const QuicAckFrame& frame, + QuicDataWriter* writer); + bool AppendIetfTimestampsToAckFrame(const QuicAckFrame& frame, + QuicDataWriter* writer); + + bool AppendStopWaitingFrame(const QuicPacketHeader& header, + const QuicStopWaitingFrame& frame, + QuicDataWriter* writer); + bool AppendRstStreamFrame(const QuicRstStreamFrame& frame, + QuicDataWriter* writer); + bool AppendConnectionCloseFrame(const QuicConnectionCloseFrame& frame, + QuicDataWriter* writer); + bool AppendGoAwayFrame(const QuicGoAwayFrame& frame, QuicDataWriter* writer); + bool AppendWindowUpdateFrame(const QuicWindowUpdateFrame& frame, + QuicDataWriter* writer); + bool AppendBlockedFrame(const QuicBlockedFrame& frame, + QuicDataWriter* writer); + bool AppendPaddingFrame(const QuicPaddingFrame& frame, + QuicDataWriter* writer); + bool AppendMessageFrameAndTypeByte(const QuicMessageFrame& frame, + bool last_frame_in_packet, + QuicDataWriter* writer); + + // IETF frame processing methods. + bool ProcessIetfStreamFrame(QuicDataReader* reader, uint8_t frame_type, + QuicStreamFrame* frame); + bool ProcessIetfConnectionCloseFrame(QuicDataReader* reader, + QuicConnectionCloseType type, + QuicConnectionCloseFrame* frame); + bool ProcessPathChallengeFrame(QuicDataReader* reader, + QuicPathChallengeFrame* frame); + bool ProcessPathResponseFrame(QuicDataReader* reader, + QuicPathResponseFrame* frame); + bool ProcessIetfResetStreamFrame(QuicDataReader* reader, + QuicRstStreamFrame* frame); + bool ProcessStopSendingFrame(QuicDataReader* reader, + QuicStopSendingFrame* stop_sending_frame); + bool ProcessCryptoFrame(QuicDataReader* reader, + EncryptionLevel encryption_level, + QuicCryptoFrame* frame); + bool ProcessAckFrequencyFrame(QuicDataReader* reader, + QuicAckFrequencyFrame* frame); + // IETF frame appending methods. All methods append the type byte as well. + bool AppendIetfStreamFrame(const QuicStreamFrame& frame, + bool last_frame_in_packet, QuicDataWriter* writer); + bool AppendIetfConnectionCloseFrame(const QuicConnectionCloseFrame& frame, + QuicDataWriter* writer); + bool AppendPathChallengeFrame(const QuicPathChallengeFrame& frame, + QuicDataWriter* writer); + bool AppendPathResponseFrame(const QuicPathResponseFrame& frame, + QuicDataWriter* writer); + bool AppendIetfResetStreamFrame(const QuicRstStreamFrame& frame, + QuicDataWriter* writer); + bool AppendStopSendingFrame(const QuicStopSendingFrame& stop_sending_frame, + QuicDataWriter* writer); + + // Append/consume IETF-Format MAX_DATA and MAX_STREAM_DATA frames + bool AppendMaxDataFrame(const QuicWindowUpdateFrame& frame, + QuicDataWriter* writer); + bool AppendMaxStreamDataFrame(const QuicWindowUpdateFrame& frame, + QuicDataWriter* writer); + bool ProcessMaxDataFrame(QuicDataReader* reader, + QuicWindowUpdateFrame* frame); + bool ProcessMaxStreamDataFrame(QuicDataReader* reader, + QuicWindowUpdateFrame* frame); + + bool AppendMaxStreamsFrame(const QuicMaxStreamsFrame& frame, + QuicDataWriter* writer); + bool ProcessMaxStreamsFrame(QuicDataReader* reader, + QuicMaxStreamsFrame* frame, uint64_t frame_type); + + bool AppendDataBlockedFrame(const QuicBlockedFrame& frame, + QuicDataWriter* writer); + bool ProcessDataBlockedFrame(QuicDataReader* reader, QuicBlockedFrame* frame); + + bool AppendStreamDataBlockedFrame(const QuicBlockedFrame& frame, + QuicDataWriter* writer); + bool ProcessStreamDataBlockedFrame(QuicDataReader* reader, + QuicBlockedFrame* frame); + + bool AppendStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame, + QuicDataWriter* writer); + bool ProcessStreamsBlockedFrame(QuicDataReader* reader, + QuicStreamsBlockedFrame* frame, + uint64_t frame_type); + + bool AppendNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame, + QuicDataWriter* writer); + bool ProcessNewConnectionIdFrame(QuicDataReader* reader, + QuicNewConnectionIdFrame* frame); + bool AppendRetireConnectionIdFrame(const QuicRetireConnectionIdFrame& frame, + QuicDataWriter* writer); + bool ProcessRetireConnectionIdFrame(QuicDataReader* reader, + QuicRetireConnectionIdFrame* frame); + + bool AppendNewTokenFrame(const QuicNewTokenFrame& frame, + QuicDataWriter* writer); + bool ProcessNewTokenFrame(QuicDataReader* reader, QuicNewTokenFrame* frame); + + bool RaiseError(QuicErrorCode error); + + // Returns true if |header| indicates a version negotiation packet. + bool IsVersionNegotiation(const QuicPacketHeader& header, + bool packet_has_ietf_packet_header) const; + + // Calculates and returns type byte of stream frame. + uint8_t GetStreamFrameTypeByte(const QuicStreamFrame& frame, + bool last_frame_in_packet) const; + uint8_t GetIetfStreamFrameTypeByte(const QuicStreamFrame& frame, + bool last_frame_in_packet) const; + + void set_error(QuicErrorCode error) { error_ = error; } + + void set_detailed_error(const char* error) { detailed_error_ = error; } + void set_detailed_error(std::string error) { detailed_error_ = error; } + + // Returns false if the reading fails. + bool ReadUint32FromVarint62(QuicDataReader* reader, QuicIetfFrameType type, + QuicStreamId* id); + + bool ProcessPacketInternal(const QuicEncryptedPacket& packet); + + // Determine whether the given QuicAckFrame should be serialized with a + // IETF_ACK_RECEIVE_TIMESTAMPS frame type. + bool UseIetfAckWithReceiveTimestamp(const QuicAckFrame& frame) const { + return VersionHasIetfQuicFrames(version_.transport_version) && + process_timestamps_ && + std::min(max_receive_timestamps_per_ack_, + frame.received_packet_times.size()) > 0; + } + + std::string detailed_error_; + QuicFramerVisitorInterface* visitor_; + QuicErrorCode error_; + // Updated by ProcessPacketHeader when it succeeds decrypting a larger packet. + QuicPacketNumber largest_packet_number_; + // Largest successfully decrypted packet number per packet number space. Only + // used when supports_multiple_packet_number_spaces_ is true. + QuicPacketNumber largest_decrypted_packet_numbers_[NUM_PACKET_NUMBER_SPACES]; + // Last server connection ID seen on the wire. + QuicConnectionId last_serialized_server_connection_id_; + // Last client connection ID seen on the wire. + QuicConnectionId last_serialized_client_connection_id_; + // Version of the protocol being used. + ParsedQuicVersion version_; + // This vector contains QUIC versions which we currently support. + // This should be ordered such that the highest supported version is the first + // element, with subsequent elements in descending order (versions can be + // skipped as necessary). + ParsedQuicVersionVector supported_versions_; + // Decrypters used to decrypt packets during parsing. + std::unique_ptr decrypter_[NUM_ENCRYPTION_LEVELS]; + // The encryption level of the primary decrypter to use in |decrypter_|. + EncryptionLevel decrypter_level_; + // The encryption level of the alternative decrypter to use in |decrypter_|. + // When set to NUM_ENCRYPTION_LEVELS, indicates that there is no alternative + // decrypter. + EncryptionLevel alternative_decrypter_level_; + // |alternative_decrypter_latch_| is true if, when the decrypter at + // |alternative_decrypter_level_| successfully decrypts a packet, we should + // install it as the only decrypter. + bool alternative_decrypter_latch_; + // Encrypters used to encrypt packets via EncryptPayload(). + std::unique_ptr encrypter_[NUM_ENCRYPTION_LEVELS]; + // Tracks if the framer is being used by the entity that received the + // connection or the entity that initiated it. + Perspective perspective_; + // If false, skip validation that the public flags are set to legal values. + bool validate_flags_; + // The diversification nonce from the last received packet. + DiversificationNonce last_nonce_; + // If true, send and process timestamps in the ACK frame. + // TODO(ianswett): Remove the mutables once set_process_timestamps and + // set_receive_timestamp_exponent_ aren't const. + mutable bool process_timestamps_; + // The max number of receive timestamps to send per ACK frame. + mutable uint32_t max_receive_timestamps_per_ack_; + // The exponent to use when writing/reading ACK receive timestamps. + mutable uint32_t receive_timestamps_exponent_; + // The creation time of the connection, used to calculate timestamps. + QuicTime creation_time_; + // The last timestamp received if process_timestamps_ is true. + QuicTime::Delta last_timestamp_; + + // Whether IETF QUIC Key Update is supported on this connection. + bool support_key_update_for_connection_; + // The value of the current key phase bit, which is toggled when the keys are + // changed. + bool current_key_phase_bit_; + // Whether we have performed a key update at least once. + bool key_update_performed_ = false; + // Tracks the first packet received in the current key phase. Will be + // uninitialized before the first one-RTT packet has been received or after a + // locally initiated key update but before the first packet from the peer in + // the new key phase is received. + QuicPacketNumber current_key_phase_first_received_packet_number_; + // Counts the number of packets received that might have been failed key + // update attempts. Reset to zero every time a packet is successfully + // decrypted. + QuicPacketCount potential_peer_key_update_attempt_count_; + // Decrypter for the previous key phase. Will be null if in the first key + // phase or previous keys have been discarded. + std::unique_ptr previous_decrypter_; + // Decrypter for the next key phase. May be null if next keys haven't been + // generated yet. + std::unique_ptr next_decrypter_; + + // If this is a framer of a connection, this is the packet number of first + // sending packet. If this is a framer of a framer of dispatcher, this is the + // packet number of sent packets (for those which have packet number). + const QuicPacketNumber first_sending_packet_number_; + + // If not null, framer asks data_producer_ to write stream frame data. Not + // owned. TODO(fayang): Consider add data producer to framer's constructor. + QuicStreamFrameDataProducer* data_producer_; + + // Whether we are in the middle of a call to this->ProcessPacket. + bool is_processing_packet_ = false; + + // If true, framer infers packet header type (IETF/GQUIC) from version_. + // Otherwise, framer infers packet header type from first byte of a received + // packet. + bool infer_packet_header_type_from_version_; + + // IETF short headers contain a destination connection ID but do not + // encode its length. These variables contains the length we expect to read. + // This is also used to validate the long header destination connection ID + // lengths in older versions of QUIC. + uint8_t expected_server_connection_id_length_; + uint8_t expected_client_connection_id_length_; + + // Indicates whether this framer supports multiple packet number spaces. + bool supports_multiple_packet_number_spaces_; + + // Indicates whether received RETRY packets should be dropped. + bool drop_incoming_retry_packets_ = false; + + // The length in bytes of the last packet number written to an IETF-framed + // packet. + size_t last_written_packet_number_length_; + + // The amount to shift the ack timestamp in ACK frames. The default is 3. + // Local_ is the amount this node shifts timestamps in ACK frames it + // generates. it is sent to the peer in a transport parameter negotiation. + // Peer_ is the amount the peer shifts timestamps when it sends ACK frames to + // this node. This node "unshifts" by this amount. The value is received from + // the peer in the transport parameter negotiation. IETF QUIC only. + uint32_t peer_ack_delay_exponent_; + uint32_t local_ack_delay_exponent_; + + // The type of received IETF frame currently being processed. 0 when not + // processing a frame or when processing Google QUIC frames. Used to populate + // the Transport Connection Close when there is an error during frame + // processing. + uint64_t current_received_frame_type_; + + // TODO(haoyuewang) Remove this debug utility. + // The type of the IETF frame preceding the frame currently being processed. 0 + // when not processing a frame or only 1 frame has been processed. + uint64_t previously_received_frame_type_; +}; + +// Look for and parse the error code from the ":" text that +// may be present at the start of the CONNECTION_CLOSE error details string. +// This text, inserted by the peer if it's using Google's QUIC implementation, +// contains additional error information that narrows down the exact error. The +// extracted error code and (possibly updated) error_details string are returned +// in |*frame|. If an error code is not found in the error details, then +// frame->quic_error_code is set to +// QuicErrorCode::QUIC_IETF_GQUIC_ERROR_MISSING. If there is an error code in +// the string then it is removed from the string. +QUIC_EXPORT_PRIVATE void MaybeExtractQuicErrorCode( + QuicConnectionCloseFrame* frame); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_FRAMER_H_ diff --git a/quiche/quic/core/quic_framer_test.cc b/quiche/quic/core/quic_framer_test.cc new file mode 100644 index 000000000000..7efaebe2dd0d --- /dev/null +++ b/quiche/quic/core/quic_framer_test.cc @@ -0,0 +1,16544 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_framer.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_data_producer.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +using testing::_; +using testing::ContainerEq; +using testing::Return; + +namespace quic { +namespace test { +namespace { + +const uint64_t kEpoch = UINT64_C(1) << 32; +const uint64_t kMask = kEpoch - 1; +const uint8_t kPacket0ByteConnectionId = 0; +const uint8_t kPacket8ByteConnectionId = 8; +constexpr size_t kTagSize = 16; + +const StatelessResetToken kTestStatelessResetToken{ + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f}; + +// Use fields in which each byte is distinct to ensure that every byte is +// framed correctly. The values are otherwise arbitrary. +QuicConnectionId FramerTestConnectionId() { + return TestConnectionId(UINT64_C(0xFEDCBA9876543210)); +} + +QuicConnectionId FramerTestConnectionIdPlusOne() { + return TestConnectionId(UINT64_C(0xFEDCBA9876543211)); +} + +QuicConnectionId FramerTestConnectionIdNineBytes() { + uint8_t connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, + 0x54, 0x32, 0x10, 0x42}; + return QuicConnectionId(reinterpret_cast(connection_id_bytes), + sizeof(connection_id_bytes)); +} + +const QuicPacketNumber kPacketNumber = QuicPacketNumber(UINT64_C(0x12345678)); +const QuicPacketNumber kSmallLargestObserved = + QuicPacketNumber(UINT16_C(0x1234)); +const QuicPacketNumber kSmallMissingPacket = QuicPacketNumber(UINT16_C(0x1233)); +const QuicPacketNumber kLeastUnacked = QuicPacketNumber(UINT64_C(0x012345670)); +const QuicStreamId kStreamId = UINT64_C(0x01020304); +// Note that the high 4 bits of the stream offset must be less than 0x40 +// in order to ensure that the value can be encoded using VarInt62 encoding. +const QuicStreamOffset kStreamOffset = UINT64_C(0x3A98FEDC32107654); +const QuicPublicResetNonceProof kNonceProof = UINT64_C(0xABCDEF0123456789); + +// In testing that we can ack the full range of packets... +// This is the largest packet number that can be represented in IETF QUIC +// varint62 format. +const QuicPacketNumber kLargestIetfLargestObserved = + QuicPacketNumber(UINT64_C(0x3fffffffffffffff)); +// Encodings for the two bits in a VarInt62 that +// describe the length of the VarInt61. For binary packet +// formats in this file, the convention is to code the +// first byte as +// kVarInt62FourBytes + 0x +const uint8_t kVarInt62OneByte = 0x00; +const uint8_t kVarInt62TwoBytes = 0x40; +const uint8_t kVarInt62FourBytes = 0x80; +const uint8_t kVarInt62EightBytes = 0xc0; + +class TestEncrypter : public QuicEncrypter { + public: + ~TestEncrypter() override {} + bool SetKey(absl::string_view /*key*/) override { return true; } + bool SetNoncePrefix(absl::string_view /*nonce_prefix*/) override { + return true; + } + bool SetIV(absl::string_view /*iv*/) override { return true; } + bool SetHeaderProtectionKey(absl::string_view /*key*/) override { + return true; + } + bool EncryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, + size_t /*max_output_length*/) override { + packet_number_ = QuicPacketNumber(packet_number); + associated_data_ = std::string(associated_data); + plaintext_ = std::string(plaintext); + memcpy(output, plaintext.data(), plaintext.length()); + *output_length = plaintext.length(); + return true; + } + std::string GenerateHeaderProtectionMask( + absl::string_view /*sample*/) override { + return std::string(5, 0); + } + size_t GetKeySize() const override { return 0; } + size_t GetNoncePrefixSize() const override { return 0; } + size_t GetIVSize() const override { return 0; } + size_t GetMaxPlaintextSize(size_t ciphertext_size) const override { + return ciphertext_size; + } + size_t GetCiphertextSize(size_t plaintext_size) const override { + return plaintext_size; + } + QuicPacketCount GetConfidentialityLimit() const override { + return std::numeric_limits::max(); + } + absl::string_view GetKey() const override { return absl::string_view(); } + absl::string_view GetNoncePrefix() const override { + return absl::string_view(); + } + + QuicPacketNumber packet_number_; + std::string associated_data_; + std::string plaintext_; +}; + +class TestDecrypter : public QuicDecrypter { + public: + ~TestDecrypter() override {} + bool SetKey(absl::string_view /*key*/) override { return true; } + bool SetNoncePrefix(absl::string_view /*nonce_prefix*/) override { + return true; + } + bool SetIV(absl::string_view /*iv*/) override { return true; } + bool SetHeaderProtectionKey(absl::string_view /*key*/) override { + return true; + } + bool SetPreliminaryKey(absl::string_view /*key*/) override { + QUIC_BUG(quic_bug_10486_1) << "should not be called"; + return false; + } + bool SetDiversificationNonce(const DiversificationNonce& /*key*/) override { + return true; + } + bool DecryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, + size_t /*max_output_length*/) override { + packet_number_ = QuicPacketNumber(packet_number); + associated_data_ = std::string(associated_data); + ciphertext_ = std::string(ciphertext); + memcpy(output, ciphertext.data(), ciphertext.length()); + *output_length = ciphertext.length(); + return true; + } + std::string GenerateHeaderProtectionMask( + QuicDataReader* /*sample_reader*/) override { + return std::string(5, 0); + } + size_t GetKeySize() const override { return 0; } + size_t GetNoncePrefixSize() const override { return 0; } + size_t GetIVSize() const override { return 0; } + absl::string_view GetKey() const override { return absl::string_view(); } + absl::string_view GetNoncePrefix() const override { + return absl::string_view(); + } + // Use a distinct value starting with 0xFFFFFF, which is never used by TLS. + uint32_t cipher_id() const override { return 0xFFFFFFF2; } + QuicPacketCount GetIntegrityLimit() const override { + return std::numeric_limits::max(); + } + QuicPacketNumber packet_number_; + std::string associated_data_; + std::string ciphertext_; +}; + +std::unique_ptr EncryptPacketWithTagAndPhase( + const QuicPacket& packet, uint8_t tag, bool phase) { + std::string packet_data = std::string(packet.AsStringPiece()); + if (phase) { + packet_data[0] |= FLAGS_KEY_PHASE_BIT; + } else { + packet_data[0] &= ~FLAGS_KEY_PHASE_BIT; + } + + TaggingEncrypter crypter(tag); + const size_t packet_size = crypter.GetCiphertextSize(packet_data.size()); + char* buffer = new char[packet_size]; + size_t buf_len = 0; + if (!crypter.EncryptPacket(0, absl::string_view(), packet_data, buffer, + &buf_len, packet_size)) { + delete[] buffer; + return nullptr; + } + + return std::make_unique(buffer, buf_len, + /*owns_buffer=*/true); +} + +class TestQuicVisitor : public QuicFramerVisitorInterface { + public: + TestQuicVisitor() + : error_count_(0), + version_mismatch_(0), + packet_count_(0), + frame_count_(0), + complete_packets_(0), + derive_next_key_count_(0), + decrypted_first_packet_in_key_phase_count_(0), + accept_packet_(true), + accept_public_header_(true) {} + + ~TestQuicVisitor() override {} + + void OnError(QuicFramer* f) override { + QUIC_DLOG(INFO) << "QuicFramer Error: " << QuicErrorCodeToString(f->error()) + << " (" << f->error() << ")"; + ++error_count_; + } + + void OnPacket() override {} + + void OnPublicResetPacket(const QuicPublicResetPacket& packet) override { + public_reset_packet_ = std::make_unique((packet)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& packet) override { + version_negotiation_packet_ = + std::make_unique((packet)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + + void OnRetryPacket(QuicConnectionId original_connection_id, + QuicConnectionId new_connection_id, + absl::string_view retry_token, + absl::string_view retry_integrity_tag, + absl::string_view retry_without_tag) override { + on_retry_packet_called_ = true; + retry_original_connection_id_ = + std::make_unique(original_connection_id); + retry_new_connection_id_ = + std::make_unique(new_connection_id); + retry_token_ = std::make_unique(std::string(retry_token)); + retry_token_integrity_tag_ = + std::make_unique(std::string(retry_integrity_tag)); + retry_without_tag_ = + std::make_unique(std::string(retry_without_tag)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + + bool OnProtocolVersionMismatch(ParsedQuicVersion received_version) override { + QUIC_DLOG(INFO) << "QuicFramer Version Mismatch, version: " + << received_version; + ++version_mismatch_; + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return false; + } + + bool OnUnauthenticatedPublicHeader(const QuicPacketHeader& header) override { + header_ = std::make_unique((header)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return accept_public_header_; + } + + bool OnUnauthenticatedHeader(const QuicPacketHeader& /*header*/) override { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return true; + } + + void OnDecryptedPacket(size_t /*length*/, + EncryptionLevel /*level*/) override { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + + bool OnPacketHeader(const QuicPacketHeader& header) override { + ++packet_count_; + header_ = std::make_unique((header)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return accept_packet_; + } + + void OnCoalescedPacket(const QuicEncryptedPacket& packet) override { + coalesced_packets_.push_back(packet.Clone()); + } + + void OnUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, + bool has_decryption_key) override { + undecryptable_packets_.push_back(packet.Clone()); + undecryptable_decryption_levels_.push_back(decryption_level); + undecryptable_has_decryption_keys_.push_back(has_decryption_key); + } + + bool OnStreamFrame(const QuicStreamFrame& frame) override { + ++frame_count_; + // Save a copy of the data so it is valid after the packet is processed. + std::string* string_data = + new std::string(frame.data_buffer, frame.data_length); + stream_data_.push_back(absl::WrapUnique(string_data)); + stream_frames_.push_back(std::make_unique( + frame.stream_id, frame.fin, frame.offset, *string_data)); + if (VersionHasIetfQuicFrames(transport_version_)) { + // Low order bits of type encode flags, ignore them for this test. + EXPECT_TRUE(IS_IETF_STREAM_FRAME(framer_->current_received_frame_type())); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnCryptoFrame(const QuicCryptoFrame& frame) override { + ++frame_count_; + // Save a copy of the data so it is valid after the packet is processed. + std::string* string_data = + new std::string(frame.data_buffer, frame.data_length); + crypto_data_.push_back(absl::WrapUnique(string_data)); + crypto_frames_.push_back(std::make_unique( + frame.level, frame.offset, *string_data)); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_EQ(IETF_CRYPTO, framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) override { + ++frame_count_; + QuicAckFrame ack_frame; + ack_frame.largest_acked = largest_acked; + ack_frame.ack_delay_time = ack_delay_time; + ack_frames_.push_back(std::make_unique(ack_frame)); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_ACK == framer_->current_received_frame_type() || + IETF_ACK_ECN == framer_->current_received_frame_type() || + IETF_ACK_RECEIVE_TIMESTAMPS == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) override { + QUICHE_DCHECK(!ack_frames_.empty()); + ack_frames_[ack_frames_.size() - 1]->packets.AddRange(start, end); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_ACK == framer_->current_received_frame_type() || + IETF_ACK_ECN == framer_->current_received_frame_type() || + IETF_ACK_RECEIVE_TIMESTAMPS == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) override { + ack_frames_[ack_frames_.size() - 1]->received_packet_times.push_back( + std::make_pair(packet_number, timestamp)); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_ACK == framer_->current_received_frame_type() || + IETF_ACK_ECN == framer_->current_received_frame_type() || + IETF_ACK_RECEIVE_TIMESTAMPS == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnAckFrameEnd( + QuicPacketNumber /*start*/, + const absl::optional& /*ecn_counts*/) override { + return true; + } + + bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) override { + ++frame_count_; + stop_waiting_frames_.push_back( + std::make_unique(frame)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return true; + } + + bool OnPaddingFrame(const QuicPaddingFrame& frame) override { + padding_frames_.push_back(std::make_unique(frame)); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_EQ(IETF_PADDING, framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnPingFrame(const QuicPingFrame& frame) override { + ++frame_count_; + ping_frames_.push_back(std::make_unique(frame)); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_EQ(IETF_PING, framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnMessageFrame(const QuicMessageFrame& frame) override { + ++frame_count_; + message_frames_.push_back( + std::make_unique(frame.data, frame.message_length)); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_EXTENSION_MESSAGE_NO_LENGTH_V99 == + framer_->current_received_frame_type() || + IETF_EXTENSION_MESSAGE_V99 == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) override { + ++frame_count_; + handshake_done_frames_.push_back( + std::make_unique(frame)); + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version_)); + EXPECT_EQ(IETF_HANDSHAKE_DONE, framer_->current_received_frame_type()); + return true; + } + + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) override { + ++frame_count_; + ack_frequency_frames_.emplace_back( + std::make_unique(frame)); + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version_)); + EXPECT_EQ(IETF_ACK_FREQUENCY, framer_->current_received_frame_type()); + return true; + } + + void OnPacketComplete() override { ++complete_packets_; } + + bool OnRstStreamFrame(const QuicRstStreamFrame& frame) override { + rst_stream_frame_ = frame; + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_EQ(IETF_RST_STREAM, framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override { + connection_close_frame_ = frame; + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_NE(GOOGLE_QUIC_CONNECTION_CLOSE, frame.close_type); + if (frame.close_type == IETF_QUIC_TRANSPORT_CONNECTION_CLOSE) { + EXPECT_EQ(IETF_CONNECTION_CLOSE, + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(IETF_APPLICATION_CLOSE, + framer_->current_received_frame_type()); + } + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnStopSendingFrame(const QuicStopSendingFrame& frame) override { + stop_sending_frame_ = frame; + EXPECT_EQ(IETF_STOP_SENDING, framer_->current_received_frame_type()); + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + return true; + } + + bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) override { + path_challenge_frame_ = frame; + EXPECT_EQ(IETF_PATH_CHALLENGE, framer_->current_received_frame_type()); + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + return true; + } + + bool OnPathResponseFrame(const QuicPathResponseFrame& frame) override { + path_response_frame_ = frame; + EXPECT_EQ(IETF_PATH_RESPONSE, framer_->current_received_frame_type()); + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + return true; + } + + bool OnGoAwayFrame(const QuicGoAwayFrame& frame) override { + goaway_frame_ = frame; + EXPECT_FALSE(VersionHasIetfQuicFrames(transport_version_)); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return true; + } + + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override { + max_streams_frame_ = frame; + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + EXPECT_TRUE(IETF_MAX_STREAMS_UNIDIRECTIONAL == + framer_->current_received_frame_type() || + IETF_MAX_STREAMS_BIDIRECTIONAL == + framer_->current_received_frame_type()); + return true; + } + + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override { + streams_blocked_frame_ = frame; + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + EXPECT_TRUE(IETF_STREAMS_BLOCKED_UNIDIRECTIONAL == + framer_->current_received_frame_type() || + IETF_STREAMS_BLOCKED_BIDIRECTIONAL == + framer_->current_received_frame_type()); + return true; + } + + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override { + window_update_frame_ = frame; + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_MAX_DATA == framer_->current_received_frame_type() || + IETF_MAX_STREAM_DATA == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnBlockedFrame(const QuicBlockedFrame& frame) override { + blocked_frame_ = frame; + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_DATA_BLOCKED == framer_->current_received_frame_type() || + IETF_STREAM_DATA_BLOCKED == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + return true; + } + + bool OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame) override { + new_connection_id_ = frame; + EXPECT_EQ(IETF_NEW_CONNECTION_ID, framer_->current_received_frame_type()); + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + return true; + } + + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) override { + EXPECT_EQ(IETF_RETIRE_CONNECTION_ID, + framer_->current_received_frame_type()); + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + retire_connection_id_ = frame; + return true; + } + + bool OnNewTokenFrame(const QuicNewTokenFrame& frame) override { + new_token_ = frame; + EXPECT_EQ(IETF_NEW_TOKEN, framer_->current_received_frame_type()); + EXPECT_TRUE(VersionHasIetfQuicFrames(transport_version_)); + return true; + } + + bool IsValidStatelessResetToken( + const StatelessResetToken& token) const override { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + return token == kTestStatelessResetToken; + } + + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& packet) override { + stateless_reset_packet_ = + std::make_unique(packet); + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } + + void OnKeyUpdate(KeyUpdateReason reason) override { + key_update_reasons_.push_back(reason); + } + + void OnDecryptedFirstPacketInKeyPhase() override { + decrypted_first_packet_in_key_phase_count_++; + } + + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + derive_next_key_count_++; + return std::make_unique(derive_next_key_count_); + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return std::make_unique(derive_next_key_count_); + } + + void set_framer(QuicFramer* framer) { + framer_ = framer; + transport_version_ = framer->transport_version(); + } + + size_t key_update_count() const { return key_update_reasons_.size(); } + + // Counters from the visitor_ callbacks. + int error_count_; + int version_mismatch_; + int packet_count_; + int frame_count_; + int complete_packets_; + std::vector key_update_reasons_; + int derive_next_key_count_; + int decrypted_first_packet_in_key_phase_count_; + bool accept_packet_; + bool accept_public_header_; + + std::unique_ptr header_; + std::unique_ptr public_reset_packet_; + std::unique_ptr stateless_reset_packet_; + std::unique_ptr version_negotiation_packet_; + std::unique_ptr retry_original_connection_id_; + std::unique_ptr retry_new_connection_id_; + std::unique_ptr retry_token_; + std::unique_ptr retry_token_integrity_tag_; + std::unique_ptr retry_without_tag_; + bool on_retry_packet_called_ = false; + std::vector> stream_frames_; + std::vector> crypto_frames_; + std::vector> ack_frames_; + std::vector> stop_waiting_frames_; + std::vector> padding_frames_; + std::vector> ping_frames_; + std::vector> message_frames_; + std::vector> handshake_done_frames_; + std::vector> ack_frequency_frames_; + std::vector> coalesced_packets_; + std::vector> undecryptable_packets_; + std::vector undecryptable_decryption_levels_; + std::vector undecryptable_has_decryption_keys_; + QuicRstStreamFrame rst_stream_frame_; + QuicConnectionCloseFrame connection_close_frame_; + QuicStopSendingFrame stop_sending_frame_; + QuicGoAwayFrame goaway_frame_; + QuicPathChallengeFrame path_challenge_frame_; + QuicPathResponseFrame path_response_frame_; + QuicWindowUpdateFrame window_update_frame_; + QuicBlockedFrame blocked_frame_; + QuicStreamsBlockedFrame streams_blocked_frame_; + QuicMaxStreamsFrame max_streams_frame_; + QuicNewConnectionIdFrame new_connection_id_; + QuicRetireConnectionIdFrame retire_connection_id_; + QuicNewTokenFrame new_token_; + std::vector> stream_data_; + std::vector> crypto_data_; + QuicTransportVersion transport_version_; + QuicFramer* framer_; +}; + +// Simple struct for defining a packet's content, and associated +// parse error. +struct PacketFragment { + std::string error_if_missing; + std::vector fragment; +}; + +using PacketFragments = std::vector; + +class QuicFramerTest : public QuicTestWithParam { + public: + QuicFramerTest() + : encrypter_(new test::TestEncrypter()), + decrypter_(new test::TestDecrypter()), + version_(GetParam()), + start_(QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(0x10)), + framer_(AllSupportedVersions(), start_, Perspective::IS_SERVER, + kQuicDefaultConnectionIdLength) { + framer_.set_version(version_); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_INITIAL, + std::unique_ptr(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, + std::unique_ptr(decrypter_)); + } + framer_.SetEncrypter(ENCRYPTION_INITIAL, + std::unique_ptr(encrypter_)); + + framer_.set_visitor(&visitor_); + framer_.InferPacketHeaderTypeFromVersion(); + visitor_.set_framer(&framer_); + } + + void SetDecrypterLevel(EncryptionLevel level) { + if (!framer_.version().KnowsWhichDecrypterToUse()) { + return; + } + decrypter_ = new TestDecrypter(); + framer_.InstallDecrypter(level, std::unique_ptr(decrypter_)); + } + + // Helper function to get unsigned char representation of the handshake + // protocol byte at position |pos| of the current QUIC version number. + unsigned char GetQuicVersionByte(int pos) { + return (CreateQuicVersionLabel(version_) >> 8 * (3 - pos)) & 0xff; + } + + // Helper functions to take a v1 long header packet and make it v2. These are + // not needed for short header packets, but if sent, this function will exit + // cleanly. It needs to be called twice for coalesced packets (see references + // to length_of_first_coalesced_packet below for examples of how to do this). + inline void ReviseFirstByteByVersion(unsigned char packet_ietf[]) { + if (version_.UsesV2PacketTypes() && (packet_ietf[0] >= 0x80)) { + packet_ietf[0] = (packet_ietf[0] + 0x10) | 0xc0; + } + } + inline void ReviseFirstByteByVersion(PacketFragments& packet_ietf) { + ReviseFirstByteByVersion(&packet_ietf[0].fragment[0]); + } + + bool CheckEncryption(QuicPacketNumber packet_number, QuicPacket* packet) { + if (packet_number != encrypter_->packet_number_) { + QUIC_LOG(ERROR) << "Encrypted incorrect packet number. expected " + << packet_number + << " actual: " << encrypter_->packet_number_; + return false; + } + if (packet->AssociatedData(framer_.transport_version()) != + encrypter_->associated_data_) { + QUIC_LOG(ERROR) << "Encrypted incorrect associated data. expected " + << packet->AssociatedData(framer_.transport_version()) + << " actual: " << encrypter_->associated_data_; + return false; + } + if (packet->Plaintext(framer_.transport_version()) != + encrypter_->plaintext_) { + QUIC_LOG(ERROR) << "Encrypted incorrect plaintext data. expected " + << packet->Plaintext(framer_.transport_version()) + << " actual: " << encrypter_->plaintext_; + return false; + } + return true; + } + + bool CheckDecryption(const QuicEncryptedPacket& encrypted, + bool includes_version, + bool includes_diversification_nonce, + uint8_t destination_connection_id_length, + uint8_t source_connection_id_length) { + return CheckDecryption( + encrypted, includes_version, includes_diversification_nonce, + destination_connection_id_length, source_connection_id_length, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + } + + bool CheckDecryption( + const QuicEncryptedPacket& encrypted, bool includes_version, + bool includes_diversification_nonce, + uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + size_t retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length) { + if (visitor_.header_->packet_number != decrypter_->packet_number_) { + QUIC_LOG(ERROR) << "Decrypted incorrect packet number. expected " + << visitor_.header_->packet_number + << " actual: " << decrypter_->packet_number_; + return false; + } + absl::string_view associated_data = + QuicFramer::GetAssociatedDataFromEncryptedPacket( + framer_.transport_version(), encrypted, + destination_connection_id_length, source_connection_id_length, + includes_version, includes_diversification_nonce, + PACKET_4BYTE_PACKET_NUMBER, retry_token_length_length, + retry_token_length, length_length); + if (associated_data != decrypter_->associated_data_) { + QUIC_LOG(ERROR) << "Decrypted incorrect associated data. expected " + << absl::BytesToHexString(associated_data) << " actual: " + << absl::BytesToHexString(decrypter_->associated_data_); + return false; + } + absl::string_view ciphertext( + encrypted.AsStringPiece().substr(GetStartOfEncryptedData( + framer_.transport_version(), destination_connection_id_length, + source_connection_id_length, includes_version, + includes_diversification_nonce, PACKET_4BYTE_PACKET_NUMBER, + retry_token_length_length, retry_token_length, length_length))); + if (ciphertext != decrypter_->ciphertext_) { + QUIC_LOG(ERROR) << "Decrypted incorrect ciphertext data. expected " + << absl::BytesToHexString(ciphertext) << " actual: " + << absl::BytesToHexString(decrypter_->ciphertext_) + << " associated data: " + << absl::BytesToHexString(associated_data); + return false; + } + return true; + } + + char* AsChars(unsigned char* data) { return reinterpret_cast(data); } + + // Creates a new QuicEncryptedPacket by concatenating the various + // packet fragments in |fragments|. + std::unique_ptr AssemblePacketFromFragments( + const PacketFragments& fragments) { + char* buffer = new char[kMaxOutgoingPacketSize + 1]; + size_t len = 0; + for (const auto& fragment : fragments) { + memcpy(buffer + len, fragment.fragment.data(), fragment.fragment.size()); + len += fragment.fragment.size(); + } + return std::make_unique(buffer, len, true); + } + + void CheckFramingBoundaries(const PacketFragments& fragments, + QuicErrorCode error_code) { + std::unique_ptr packet( + AssemblePacketFromFragments(fragments)); + // Check all the various prefixes of |packet| for the expected + // parse error and error code. + for (size_t i = 0; i < packet->length(); ++i) { + std::string expected_error; + size_t len = 0; + for (const auto& fragment : fragments) { + len += fragment.fragment.size(); + if (i < len) { + expected_error = fragment.error_if_missing; + break; + } + } + + if (expected_error.empty()) continue; + + CheckProcessingFails(*packet, i, expected_error, error_code); + } + } + + void CheckProcessingFails(const QuicEncryptedPacket& packet, size_t len, + std::string expected_error, + QuicErrorCode error_code) { + QuicEncryptedPacket encrypted(packet.data(), len, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)) << "len: " << len; + EXPECT_EQ(expected_error, framer_.detailed_error()) << "len: " << len; + EXPECT_EQ(error_code, framer_.error()) << "len: " << len; + } + + void CheckProcessingFails(unsigned char* packet, size_t len, + std::string expected_error, + QuicErrorCode error_code) { + QuicEncryptedPacket encrypted(AsChars(packet), len, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)) << "len: " << len; + EXPECT_EQ(expected_error, framer_.detailed_error()) << "len: " << len; + EXPECT_EQ(error_code, framer_.error()) << "len: " << len; + } + + // Checks if the supplied string matches data in the supplied StreamFrame. + void CheckStreamFrameData(std::string str, QuicStreamFrame* frame) { + EXPECT_EQ(str, std::string(frame->data_buffer, frame->data_length)); + } + + void CheckCalculatePacketNumber(uint64_t expected_packet_number, + QuicPacketNumber last_packet_number) { + uint64_t wire_packet_number = expected_packet_number & kMask; + EXPECT_EQ(expected_packet_number, + QuicFramerPeer::CalculatePacketNumberFromWire( + &framer_, PACKET_4BYTE_PACKET_NUMBER, last_packet_number, + wire_packet_number)) + << "last_packet_number: " << last_packet_number + << " wire_packet_number: " << wire_packet_number; + } + + std::unique_ptr BuildDataPacket(const QuicPacketHeader& header, + const QuicFrames& frames) { + return BuildUnsizedDataPacket(&framer_, header, frames); + } + + std::unique_ptr BuildDataPacket(const QuicPacketHeader& header, + const QuicFrames& frames, + size_t packet_size) { + return BuildUnsizedDataPacket(&framer_, header, frames, packet_size); + } + + // N starts at 1. + QuicStreamId GetNthStreamid(QuicTransportVersion transport_version, + Perspective perspective, bool bidirectional, + int n) { + if (bidirectional) { + return QuicUtils::GetFirstBidirectionalStreamId(transport_version, + perspective) + + ((n - 1) * QuicUtils::StreamIdDelta(transport_version)); + } + // Unidirectional + return QuicUtils::GetFirstUnidirectionalStreamId(transport_version, + perspective) + + ((n - 1) * QuicUtils::StreamIdDelta(transport_version)); + } + + QuicTime CreationTimePlus(uint64_t offset_us) { + return framer_.creation_time() + + QuicTime::Delta::FromMicroseconds(offset_us); + } + + test::TestEncrypter* encrypter_; + test::TestDecrypter* decrypter_; + ParsedQuicVersion version_; + QuicTime start_; + QuicFramer framer_; + test::TestQuicVisitor visitor_; + quiche::SimpleBufferAllocator allocator_; +}; + +// Multiple test cases of QuicFramerTest use byte arrays to define packets for +// testing, and these byte arrays contain the QUIC version. This macro explodes +// the 32-bit version into four bytes in network order. Since it uses methods of +// QuicFramerTest, it is only valid to use this in a QuicFramerTest. +#define QUIC_VERSION_BYTES \ + GetQuicVersionByte(0), GetQuicVersionByte(1), GetQuicVersionByte(2), \ + GetQuicVersionByte(3) + +// Run all framer tests with all supported versions of QUIC. +INSTANTIATE_TEST_SUITE_P(QuicFramerTests, QuicFramerTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicFramerTest, CalculatePacketNumberFromWireNearEpochStart) { + // A few quick manual sanity checks. + CheckCalculatePacketNumber(UINT64_C(1), QuicPacketNumber()); + CheckCalculatePacketNumber(kEpoch + 1, QuicPacketNumber(kMask)); + CheckCalculatePacketNumber(kEpoch, QuicPacketNumber(kMask)); + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(j, QuicPacketNumber()); + CheckCalculatePacketNumber(kEpoch - 1 - j, QuicPacketNumber()); + } + + // Cases where the last number was close to the start of the range. + for (QuicPacketNumber last = QuicPacketNumber(1); last < QuicPacketNumber(10); + last++) { + // Small numbers should not wrap (even if they're out of order). + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(j, last); + } + + // Large numbers should not wrap either (because we're near 0 already). + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(kEpoch - 1 - j, last); + } + } +} + +TEST_P(QuicFramerTest, CalculatePacketNumberFromWireNearEpochEnd) { + // Cases where the last number was close to the end of the range + for (uint64_t i = 0; i < 10; i++) { + QuicPacketNumber last = QuicPacketNumber(kEpoch - i); + + // Small numbers should wrap. + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(kEpoch + j, last); + } + + // Large numbers should not (even if they're out of order). + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(kEpoch - 1 - j, last); + } + } +} + +// Next check where we're in a non-zero epoch to verify we handle +// reverse wrapping, too. +TEST_P(QuicFramerTest, CalculatePacketNumberFromWireNearPrevEpoch) { + const uint64_t prev_epoch = 1 * kEpoch; + const uint64_t cur_epoch = 2 * kEpoch; + // Cases where the last number was close to the start of the range + for (uint64_t i = 0; i < 10; i++) { + QuicPacketNumber last = QuicPacketNumber(cur_epoch + i); + // Small number should not wrap (even if they're out of order). + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(cur_epoch + j, last); + } + + // But large numbers should reverse wrap. + for (uint64_t j = 0; j < 10; j++) { + uint64_t num = kEpoch - 1 - j; + CheckCalculatePacketNumber(prev_epoch + num, last); + } + } +} + +TEST_P(QuicFramerTest, CalculatePacketNumberFromWireNearNextEpoch) { + const uint64_t cur_epoch = 2 * kEpoch; + const uint64_t next_epoch = 3 * kEpoch; + // Cases where the last number was close to the end of the range + for (uint64_t i = 0; i < 10; i++) { + QuicPacketNumber last = QuicPacketNumber(next_epoch - 1 - i); + + // Small numbers should wrap. + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(next_epoch + j, last); + } + + // but large numbers should not (even if they're out of order). + for (uint64_t j = 0; j < 10; j++) { + uint64_t num = kEpoch - 1 - j; + CheckCalculatePacketNumber(cur_epoch + num, last); + } + } +} + +TEST_P(QuicFramerTest, CalculatePacketNumberFromWireNearNextMax) { + const uint64_t max_number = std::numeric_limits::max(); + const uint64_t max_epoch = max_number & ~kMask; + + // Cases where the last number was close to the end of the range + for (uint64_t i = 0; i < 10; i++) { + // Subtract 1, because the expected next packet number is 1 more than the + // last packet number. + QuicPacketNumber last = QuicPacketNumber(max_number - i - 1); + + // Small numbers should not wrap, because they have nowhere to go. + for (uint64_t j = 0; j < 10; j++) { + CheckCalculatePacketNumber(max_epoch + j, last); + } + + // Large numbers should not wrap either. + for (uint64_t j = 0; j < 10; j++) { + uint64_t num = kEpoch - 1 - j; + CheckCalculatePacketNumber(max_epoch + num, last); + } + } +} + +TEST_P(QuicFramerTest, EmptyPacket) { + char packet[] = {0x00}; + QuicEncryptedPacket encrypted(packet, 0, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); +} + +TEST_P(QuicFramerTest, LargePacket) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[kMaxIncomingPacketSize + 1] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x78, 0x56, 0x34, 0x12, + // private flags + 0x00, + }; + unsigned char packet46[kMaxIncomingPacketSize + 1] = { + // type (short header 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x78, 0x56, 0x34, 0x12, + }; + // clang-format on + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + const size_t header_size = GetPacketHeaderSize( + framer_.transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_4BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + + memset(p + header_size, 0, kMaxIncomingPacketSize - header_size); + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + + ASSERT_TRUE(visitor_.header_.get()); + // Make sure we've parsed the packet header, so we can send an error. + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + // Make sure the correct error is propagated. + EXPECT_THAT(framer_.error(), IsError(QUIC_PACKET_TOO_LARGE)); + EXPECT_EQ("Packet too large.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, PacketHeader) { + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"Unable to read public flags.", + {0x28}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + // clang-format on + + PacketFragments& fragments = packet; + + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_FALSE(visitor_.header_->version_flag); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); + + PacketHeaderFormat format; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_flag; + QuicConnectionId destination_connection_id, source_connection_id; + QuicVersionLabel version_label; + std::string detailed_error; + bool use_length_prefix; + absl::optional retry_token; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + *encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_flag, &use_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_FALSE(retry_token.has_value()); + EXPECT_FALSE(use_length_prefix); + EXPECT_THAT(error_code, IsQuicNoError()); + EXPECT_EQ(GOOGLE_QUIC_PACKET, format); + EXPECT_FALSE(version_flag); + EXPECT_EQ(kQuicDefaultConnectionIdLength, destination_connection_id.length()); + EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); + EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); +} + +TEST_P(QuicFramerTest, LongPacketHeader) { + // clang-format off + PacketFragments packet46 = { + // type (long header with packet type ZERO_RTT) + {"Unable to read first byte.", + {0xD3}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // connection_id length + {"Unable to read ConnectionId length.", + {0x50}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + // clang-format on + + if (!framer_.version().HasIetfInvariantHeader() || + QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + return; + } + + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet46)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_TRUE(visitor_.header_->version_flag); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(packet46, QUIC_INVALID_PACKET_HEADER); + + PacketHeaderFormat format; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_flag; + QuicConnectionId destination_connection_id, source_connection_id; + QuicVersionLabel version_label; + std::string detailed_error; + bool use_length_prefix; + absl::optional retry_token; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + *encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_flag, &use_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_THAT(error_code, IsQuicNoError()); + EXPECT_EQ("", detailed_error); + EXPECT_FALSE(retry_token.has_value()); + EXPECT_FALSE(use_length_prefix); + EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); + EXPECT_TRUE(version_flag); + EXPECT_EQ(kQuicDefaultConnectionIdLength, destination_connection_id.length()); + EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); + EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); +} + +TEST_P(QuicFramerTest, LongPacketHeaderWithBothConnectionIds) { + if (!framer_.version().HasIetfInvariantHeader()) { + // This test requires an IETF long header. + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x55, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frame + 0x00, + }; + unsigned char packet49[] = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frame + 0x00, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + ReviseFirstByteByVersion(packet49); + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_flag = false; + QuicConnectionId destination_connection_id, source_connection_id; + QuicVersionLabel version_label = 0; + std::string detailed_error = ""; + bool use_length_prefix; + absl::optional retry_token; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_flag, &use_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_THAT(error_code, IsQuicNoError()); + EXPECT_FALSE(retry_token.has_value()); + EXPECT_EQ(framer_.version().HasLengthPrefixedConnectionIds(), + use_length_prefix); + EXPECT_EQ("", detailed_error); + EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); + EXPECT_TRUE(version_flag); + EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); + EXPECT_EQ(FramerTestConnectionIdPlusOne(), source_connection_id); +} + +TEST_P(QuicFramerTest, AllZeroPacketParsingFails) { + unsigned char packet[1200] = {}; + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_flag = false; + QuicConnectionId destination_connection_id, source_connection_id; + QuicVersionLabel version_label = 0; + std::string detailed_error = ""; + bool use_length_prefix; + absl::optional retry_token; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_flag, &use_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_EQ(error_code, QUIC_INVALID_PACKET_HEADER); + EXPECT_EQ(detailed_error, "Invalid flags."); +} + +TEST_P(QuicFramerTest, ParsePublicHeader) { + // clang-format off + unsigned char packet[] = { + // public flags (version included, 8-byte connection ID, + // 4-byte packet number) + 0x29, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // version + QUIC_VERSION_BYTES, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + unsigned char packet46[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x50, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + unsigned char packet49[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + // clang-format on + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + ReviseFirstByteByVersion(packet49); + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_length = ABSL_ARRAYSIZE(packet46); + } + + uint8_t first_byte = 0x33; + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + bool version_present = false, has_length_prefix = false; + QuicVersionLabel version_label = 0; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + QuicConnectionId destination_connection_id = EmptyQuicConnectionId(), + source_connection_id = EmptyQuicConnectionId(); + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + quiche::QuicheVariableLengthIntegerLength retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_4; + absl::string_view retry_token; + std::string detailed_error = "foobar"; + + QuicDataReader reader(AsChars(p), p_length); + const QuicErrorCode parse_error = QuicFramer::ParsePublicHeader( + &reader, kQuicDefaultConnectionIdLength, + /*ietf_format=*/ + framer_.version().HasIetfInvariantHeader(), &first_byte, &format, + &version_present, &has_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &long_packet_type, + &retry_token_length_length, &retry_token, &detailed_error); + EXPECT_THAT(parse_error, IsQuicNoError()); + EXPECT_EQ("", detailed_error); + EXPECT_EQ(p[0], first_byte); + EXPECT_TRUE(version_present); + EXPECT_EQ(framer_.version().HasLengthPrefixedConnectionIds(), + has_length_prefix); + EXPECT_EQ(CreateQuicVersionLabel(framer_.version()), version_label); + EXPECT_EQ(framer_.version(), parsed_version); + EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); + EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); + EXPECT_EQ(quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, + retry_token_length_length); + EXPECT_EQ(absl::string_view(), retry_token); + if (framer_.version().HasIetfInvariantHeader()) { + EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); + EXPECT_EQ(HANDSHAKE, long_packet_type); + } else { + EXPECT_EQ(GOOGLE_QUIC_PACKET, format); + } +} + +TEST_P(QuicFramerTest, ParsePublicHeaderProxBadSourceConnectionIdLength) { + if (!framer_.version().HasLengthPrefixedConnectionIds()) { + return; + } + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + 'P', 'R', 'O', 'X', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length (bogus) + 0xEE, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + // clang-format on + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + + uint8_t first_byte = 0x33; + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + bool version_present = false, has_length_prefix = false; + QuicVersionLabel version_label = 0; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + QuicConnectionId destination_connection_id = EmptyQuicConnectionId(), + source_connection_id = EmptyQuicConnectionId(); + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + quiche::QuicheVariableLengthIntegerLength retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_4; + absl::string_view retry_token; + std::string detailed_error = "foobar"; + + QuicDataReader reader(AsChars(p), p_length); + const QuicErrorCode parse_error = QuicFramer::ParsePublicHeader( + &reader, kQuicDefaultConnectionIdLength, + /*ietf_format=*/true, &first_byte, &format, &version_present, + &has_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &long_packet_type, + &retry_token_length_length, &retry_token, &detailed_error); + EXPECT_THAT(parse_error, IsQuicNoError()); + EXPECT_EQ("", detailed_error); + EXPECT_EQ(p[0], first_byte); + EXPECT_TRUE(version_present); + EXPECT_TRUE(has_length_prefix); + EXPECT_EQ(0x50524F58u, version_label); // "PROX" + EXPECT_EQ(UnsupportedQuicVersion(), parsed_version); + EXPECT_EQ(FramerTestConnectionId(), destination_connection_id); + EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); + EXPECT_EQ(quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, + retry_token_length_length); + EXPECT_EQ(absl::string_view(), retry_token); + EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); +} + +TEST_P(QuicFramerTest, ClientConnectionIdFromShortHeaderToClient) { + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + QuicFramerPeer::SetLastSerializedServerConnectionId(&framer_, + TestConnectionId(0x33)); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + framer_.SetExpectedClientConnectionIdLength(kQuicDefaultConnectionIdLength); + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x13, 0x37, 0x42, 0x33, + // padding frame + 0x00, + }; + // clang-format on + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + EXPECT_EQ("", framer_.detailed_error()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); +} + +// In short header packets from client to server, the client connection ID +// is omitted, but the framer adds it to the header struct using its +// last serialized client connection ID. This test ensures that this +// mechanism behaves as expected. +TEST_P(QuicFramerTest, ClientConnectionIdFromShortHeaderToServer) { + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + QuicFramerPeer::SetLastSerializedClientConnectionId(&framer_, + TestConnectionId(0x33)); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x13, 0x37, 0x42, 0x33, + // padding frame + 0x00, + }; + // clang-format on + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + EXPECT_EQ("", framer_.detailed_error()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); +} + +TEST_P(QuicFramerTest, PacketHeaderWith0ByteConnectionId) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + QuicFramerPeer::SetLastSerializedServerConnectionId(&framer_, + FramerTestConnectionId()); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + // clang-format off + PacketFragments packet = { + // public flags (0 byte connection_id) + {"Unable to read public flags.", + {0x20}}, + // connection_id + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"Unable to read first byte.", + {0x43}}, + // connection_id + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + + PacketFragments packet_hp = { + // type (short header, 4 byte packet number) + {"Unable to read first byte.", + {0x43}}, + // connection_id + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasHeaderProtection() + ? packet_hp + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_FALSE(visitor_.header_->version_flag); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, PacketHeaderWithVersionFlag) { + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + PacketFragments packet = { + // public flags (0 byte connection_id) + {"Unable to read public flags.", + {0x29}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + + PacketFragments packet46 = { + // type (long header with packet type ZERO_RTT_PROTECTED and 4 bytes + // packet number) + {"Unable to read first byte.", + {0xD3}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // connection_id length + {"Unable to read ConnectionId length.", + {0x50}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + + PacketFragments packet49 = { + // type (long header with packet type ZERO_RTT_PROTECTED and 4 bytes + // packet number) + {"Unable to read first byte.", + {0xD3}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // destination connection ID length + {"Unable to read destination connection ID.", + {0x08}}, + // destination connection ID + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // source connection ID length + {"Unable to read source connection ID.", + {0x00}}, + // long header packet length + {"Unable to read long header payload length.", + {0x04}}, + // packet number + {"Long header payload length longer than packet.", + {0x12, 0x34, 0x56, 0x78}}, + }; + // clang-format on + + ReviseFirstByteByVersion(packet49); + PacketFragments& fragments = + framer_.version().HasLongHeaderLengths() + ? packet49 + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_TRUE(visitor_.header_->version_flag); + EXPECT_EQ(GetParam(), visitor_.header_->version); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, PacketHeaderWith4BytePacketNumber) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); + + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id and 4 byte packet number) + {"Unable to read public flags.", + {0x28}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"Unable to read first byte.", + {0x43}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x12, 0x34, 0x56, 0x78}}, + }; + + PacketFragments packet_hp = { + // type (short header, 4 byte packet number) + {"Unable to read first byte.", + {0x43}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasHeaderProtection() + ? packet_hp + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_FALSE(visitor_.header_->version_flag); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, PacketHeaderWith2BytePacketNumber) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); + + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id and 2 byte packet number) + {"Unable to read public flags.", + {0x18}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x56, 0x78}}, + }; + + PacketFragments packet46 = { + // type (short header, 2 byte packet number) + {"Unable to read first byte.", + {0x41}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x56, 0x78}}, + }; + + PacketFragments packet_hp = { + // type (short header, 2 byte packet number) + {"Unable to read first byte.", + {0x41}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x56, 0x78}}, + // padding + {"", {0x00, 0x00}}, + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasHeaderProtection() + ? packet_hp + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + if (framer_.version().HasHeaderProtection()) { + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + } else { + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + } + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_FALSE(visitor_.header_->version_flag); + EXPECT_EQ(PACKET_2BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, PacketHeaderWith1BytePacketNumber) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); + + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id and 1 byte packet number) + {"Unable to read public flags.", + {0x08}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x78}}, + }; + + PacketFragments packet46 = { + // type (8 byte connection_id and 1 byte packet number) + {"Unable to read first byte.", + {0x40}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x78}}, + }; + + PacketFragments packet_hp = { + // type (8 byte connection_id and 1 byte packet number) + {"Unable to read first byte.", + {0x40}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x78}}, + // padding + {"", {0x00, 0x00, 0x00}}, + }; + + // clang-format on + + PacketFragments& fragments = + framer_.version().HasHeaderProtection() + ? packet_hp + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + if (framer_.version().HasHeaderProtection()) { + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + } else { + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + } + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_FALSE(visitor_.header_->version_flag); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, PacketNumberDecreasesThenIncreases) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // Test the case when a packet is received from the past and future packet + // numbers are still calculated relative to the largest received packet. + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber - 2; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + QuicEncryptedPacket encrypted(data->data(), data->length(), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber - 2, visitor_.header_->packet_number); + + // Receive a 1 byte packet number. + header.packet_number = kPacketNumber; + header.packet_number_length = PACKET_1BYTE_PACKET_NUMBER; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + QuicEncryptedPacket encrypted1(data->data(), data->length(), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(encrypted1)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + // Process a 2 byte packet number 256 packets ago. + header.packet_number = kPacketNumber - 256; + header.packet_number_length = PACKET_2BYTE_PACKET_NUMBER; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + QuicEncryptedPacket encrypted2(data->data(), data->length(), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(encrypted2)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_EQ(PACKET_2BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber - 256, visitor_.header_->packet_number); + + // Process another 1 byte packet number and ensure it works. + header.packet_number = kPacketNumber - 1; + header.packet_number_length = PACKET_1BYTE_PACKET_NUMBER; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + QuicEncryptedPacket encrypted3(data->data(), data->length(), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(encrypted3)); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_->destination_connection_id); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber - 1, visitor_.header_->packet_number); +} + +TEST_P(QuicFramerTest, PacketWithDiversificationNonce) { + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // public flags: includes nonce flag + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // nonce + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[] = { + // type: Long header with packet type ZERO_RTT_PROTECTED and 1 byte packet + // number. + 0xD0, + // version tag + QUIC_VERSION_BYTES, + // connection_id length + 0x05, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x78, + // nonce + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + + // frame type (padding) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet49[] = { + // type: Long header with packet type ZERO_RTT_PROTECTED and 1 byte packet + // number. + 0xD0, + // version tag + QUIC_VERSION_BYTES, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x26, + // packet number + 0x78, + // nonce + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + + // frame type (padding) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + if (framer_.version().handshake_protocol != PROTOCOL_QUIC_CRYPTO) { + return; + } + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + p = packet49; + p_size = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + ASSERT_TRUE(visitor_.header_->nonce != nullptr); + for (char i = 0; i < 32; ++i) { + EXPECT_EQ(i, (*visitor_.header_->nonce)[static_cast(i)]); + } + EXPECT_EQ(1u, visitor_.padding_frames_.size()); + EXPECT_EQ(5, visitor_.padding_frames_[0]->num_padding_bytes); +} + +TEST_P(QuicFramerTest, LargePublicFlagWithMismatchedVersions) { + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id, version flag and an unknown flag) + 0x29, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // version tag + 'Q', '0', '0', '0', + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[] = { + // type (long header, ZERO_RTT_PROTECTED, 4-byte packet number) + 0xD3, + // version tag + 'Q', '0', '0', '0', + // connection_id length + 0x50, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet49[] = { + // type (long header, ZERO_RTT_PROTECTED, 4-byte packet number) + 0xD3, + // version tag + 'Q', '0', '0', '0', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + p = packet49; + p_size = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(0, visitor_.frame_count_); + EXPECT_EQ(1, visitor_.version_mismatch_); +} + +TEST_P(QuicFramerTest, PaddingFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // paddings + 0x00, 0x00, + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // paddings + 0x00, 0x00, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // paddings + 0x00, 0x00, + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // paddings + 0x00, 0x00, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // paddings + 0x00, 0x00, + // frame type - IETF_STREAM with FIN, LEN, and OFFSET bits set. + 0x08 | 0x01 | 0x02 | 0x04, + + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // paddings + 0x00, 0x00, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + EXPECT_EQ(2u, visitor_.padding_frames_.size()); + EXPECT_EQ(2, visitor_.padding_frames_[0]->num_padding_bytes); + EXPECT_EQ(2, visitor_.padding_frames_[1]->num_padding_bytes); + EXPECT_EQ(kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); +} + +TEST_P(QuicFramerTest, StreamFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFF}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFF}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type - IETF_STREAM with FIN, LEN, and OFFSET bits set. + {"", + { 0x08 | 0x01 | 0x02 | 0x04 }}, + // stream id + {"Unable to read IETF_STREAM frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + // offset + {"Unable to read stream data offset.", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Unable to read stream data length.", + {kVarInt62OneByte + 0x0c}}, + // data + {"Unable to read frame data.", + { 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + EXPECT_EQ(kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_STREAM_DATA); +} + +// Test an empty (no data) stream frame. +TEST_P(QuicFramerTest, EmptyStreamFrame) { + // Only the IETF QUIC spec explicitly says that empty + // stream frames are supported. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type - IETF_STREAM with FIN, LEN, and OFFSET bits set. + {"", + { 0x08 | 0x01 | 0x02 | 0x04 }}, + // stream id + {"Unable to read IETF_STREAM frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + // offset + {"Unable to read stream data offset.", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Unable to read stream data length.", + {kVarInt62OneByte + 0x00}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + EXPECT_EQ(kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + EXPECT_EQ(visitor_.stream_frames_[0].get()->data_length, 0u); + + CheckFramingBoundaries(packet, QUIC_INVALID_STREAM_DATA); +} + +TEST_P(QuicFramerTest, MissingDiversificationNonce) { + if (framer_.version().handshake_protocol != PROTOCOL_QUIC_CRYPTO) { + // TLS does not use diversification nonces. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + decrypter_ = new test::TestDecrypter(); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter( + ENCRYPTION_INITIAL, + std::make_unique(Perspective::IS_CLIENT)); + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::unique_ptr(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, std::make_unique( + Perspective::IS_CLIENT)); + framer_.SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, std::unique_ptr(decrypter_), false); + } + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + + unsigned char packet46[] = { + // type (long header, ZERO_RTT_PROTECTED, 4-byte packet number) + 0xD3, + // version tag + QUIC_VERSION_BYTES, + // connection_id length + 0x05, + // connection_id + 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + + unsigned char packet49[] = { + // type (long header, ZERO_RTT_PROTECTED, 4-byte packet number) + 0xD3, + // version tag + QUIC_VERSION_BYTES, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, + // IETF long header payload length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_length = ABSL_ARRAYSIZE(packet46); + } + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + if (framer_.version().HasHeaderProtection()) { + EXPECT_THAT(framer_.error(), IsError(QUIC_DECRYPTION_FAILURE)); + EXPECT_EQ("Unable to decrypt ENCRYPTION_ZERO_RTT header protection.", + framer_.detailed_error()); + } else if (framer_.version().HasIetfInvariantHeader()) { + // Cannot read diversification nonce. + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Unable to read nonce.", framer_.detailed_error()); + } else { + EXPECT_THAT(framer_.error(), IsError(QUIC_DECRYPTION_FAILURE)); + } +} + +TEST_P(QuicFramerTest, StreamFrame3ByteStreamId) { + if (framer_.version().HasIetfInvariantHeader()) { + // This test is nonsensical for IETF Quic. + return; + } + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFE}}, + // stream id + {"Unable to read stream_id.", + {0x02, 0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + // clang-format on + + PacketFragments& fragments = packet; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_STREAM_DATA); +} + +TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFD}}, + // stream id + {"Unable to read stream_id.", + {0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFD}}, + // stream id + {"Unable to read stream_id.", + {0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_STREAM frame with LEN, FIN, and OFFSET bits set) + {"", + {0x08 | 0x01 | 0x02 | 0x04}}, + // stream id + {"Unable to read IETF_STREAM frame stream id/count.", + {kVarInt62TwoBytes + 0x03, 0x04}}, + // offset + {"Unable to read stream data offset.", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Unable to read stream data length.", + {kVarInt62OneByte + 0x0c}}, + // data + {"Unable to read frame data.", + { 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + // Stream ID should be the last 2 bytes of kStreamId. + EXPECT_EQ(0x0000FFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_STREAM_DATA); +} + +TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFC}}, + // stream id + {"Unable to read stream_id.", + {0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFC}}, + // stream id + {"Unable to read stream_id.", + {0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_STREAM frame with LEN, FIN, and OFFSET bits set) + {"", + {0x08 | 0x01 | 0x02 | 0x04}}, + // stream id + {"Unable to read IETF_STREAM frame stream id/count.", + {kVarInt62OneByte + 0x04}}, + // offset + {"Unable to read stream data offset.", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Unable to read stream data length.", + {kVarInt62OneByte + 0x0c}}, + // data + {"Unable to read frame data.", + { 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + // Stream ID should be the last 1 byte of kStreamId. + EXPECT_EQ(0x000000FF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_STREAM_DATA); +} + +TEST_P(QuicFramerTest, StreamFrameWithVersion) { + // If IETF frames are in use then we must also have the IETF + // header invariants. + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + QUICHE_DCHECK(framer_.version().HasIetfInvariantHeader()); + } + + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + PacketFragments packet = { + // public flags (version, 8 byte connection_id) + {"", + {0x29}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // version tag + {"", + {QUIC_VERSION_BYTES}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFE}}, + // stream id + {"Unable to read stream_id.", + {0x02, 0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet46 = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + {"", + {0xD3}}, + // version tag + {"", + {QUIC_VERSION_BYTES}}, + // connection_id length + {"", + {0x50}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFE}}, + // stream id + {"Unable to read stream_id.", + {0x02, 0x03, 0x04}}, + // offset + {"Unable to read offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Unable to read frame data.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet49 = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + {"", + {0xD3}}, + // version tag + {"", + {QUIC_VERSION_BYTES}}, + // destination connection ID length + {"", + {0x08}}, + // destination connection ID + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // source connection ID length + {"", + {0x00}}, + // long header packet length + {"", + {0x1E}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stream frame with fin) + {"", + {0xFE}}, + // stream id + {"Long header payload length longer than packet.", + {0x02, 0x03, 0x04}}, + // offset + {"Long header payload length longer than packet.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + {"Long header payload length longer than packet.", + { + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet_ietf = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + {"", + {0xD3}}, + // version tag + {"", + {QUIC_VERSION_BYTES}}, + // destination connection ID length + {"", + {0x08}}, + // destination connection ID + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // source connection ID length + {"", + {0x00}}, + // long header packet length + {"", + {0x1E}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + {"", + {0x08 | 0x01 | 0x02 | 0x04}}, + // stream id + {"Long header payload length longer than packet.", + {kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04}}, + // offset + {"Long header payload length longer than packet.", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Long header payload length longer than packet.", + {kVarInt62OneByte + 0x0c}}, + // data + {"Long header payload length longer than packet.", + { 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + // clang-format on + + quiche::QuicheVariableLengthIntegerLength retry_token_length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + size_t retry_token_length = 0; + quiche::QuicheVariableLengthIntegerLength length_length = + QuicVersionHasLongHeaderLengths(framer_.transport_version()) + ? quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1 + : quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + + ReviseFirstByteByVersion(packet_ietf); + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasLongHeaderLengths() + ? packet49 + : (framer_.version().HasIetfInvariantHeader() ? packet46 + : packet)); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId, + retry_token_length_length, retry_token_length, length_length)); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + CheckFramingBoundaries(fragments, framer_.version().HasLongHeaderLengths() + ? QUIC_INVALID_PACKET_HEADER + : QUIC_INVALID_STREAM_DATA); +} + +TEST_P(QuicFramerTest, RejectPacket) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + visitor_.accept_packet_ = false; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (STREAM Frame with FIN, LEN, and OFFSET bits set) + 0x10 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + // clang-format on + + unsigned char* p = packet; + if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + QuicEncryptedPacket encrypted(AsChars(p), + framer_.version().HasIetfInvariantHeader() + ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet), + false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(0u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); +} + +TEST_P(QuicFramerTest, RejectPublicHeader) { + visitor_.accept_public_header_ = false; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + }; + + unsigned char packet46[] = { + // type (short header, 1 byte packet number) + 0x40, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x01, + }; + // clang-format on + + QuicEncryptedPacket encrypted( + framer_.version().HasIetfInvariantHeader() ? AsChars(packet46) + : AsChars(packet), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet), + false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_FALSE(visitor_.header_->packet_number.IsInitialized()); +} + +TEST_P(QuicFramerTest, AckFrameOneAckBlock) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x2C}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (ack frame) + // (one ack block, 2 byte largest observed, 2 byte block length) + {"", + {0x45}}, + // largest acked + {"Unable to read largest acked.", + {0x12, 0x34}}, + // Zero delta time. + {"Unable to read ack delay time.", + {0x00, 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {0x12, 0x34}}, + // num timestamps. + {"Unable to read num received packets.", + {0x00}} + }; + + PacketFragments packet46 = { + // type (short packet, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (ack frame) + // (one ack block, 2 byte largest observed, 2 byte block length) + {"", + {0x45}}, + // largest acked + {"Unable to read largest acked.", + {0x12, 0x34}}, + // Zero delta time. + {"Unable to read ack delay time.", + {0x00, 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {0x12, 0x34}}, + // num timestamps. + {"Unable to read num received packets.", + {0x00}} + }; + + PacketFragments packet_ietf = { + // type (short packet, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK) + // (one ack block, 2 byte largest observed, 2 byte block length) + // IETF-Quic ignores the bit-fields in the ack type, all of + // that information is encoded elsewhere in the frame. + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62TwoBytes + 0x12, 0x34}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count (0 -- no blocks after the first) + {"Unable to read ack block count.", + {kVarInt62OneByte + 0x00}}, + // first ack block length - 1. + // IETF Quic defines the ack block's value as the "number of + // packets that preceed the largest packet number in the block" + // which for the 1st ack block is the largest acked field, + // above. This means that if we are acking just packet 0x1234 + // then the 1st ack block will be 0. + {"Unable to read first ack block length.", + {kVarInt62TwoBytes + 0x12, 0x33}} + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(kSmallLargestObserved, LargestAcked(frame)); + ASSERT_EQ(4660u, frame.packets.NumPacketsSlow()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_ACK_DATA); +} + +// This test checks that the ack frame processor correctly identifies +// and handles the case where the first ack block is larger than the +// largest_acked packet. +TEST_P(QuicFramerTest, FirstAckFrameUnderflow) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x2C}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (ack frame) + // (one ack block, 2 byte largest observed, 2 byte block length) + {"", + {0x45}}, + // largest acked + {"Unable to read largest acked.", + {0x12, 0x34}}, + // Zero delta time. + {"Unable to read ack delay time.", + {0x00, 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {0x88, 0x88}}, + // num timestamps. + {"Underflow with first ack block length 34952 largest acked is 4660.", + {0x00}} + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (ack frame) + // (one ack block, 2 byte largest observed, 2 byte block length) + {"", + {0x45}}, + // largest acked + {"Unable to read largest acked.", + {0x12, 0x34}}, + // Zero delta time. + {"Unable to read ack delay time.", + {0x00, 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {0x88, 0x88}}, + // num timestamps. + {"Underflow with first ack block length 34952 largest acked is 4660.", + {0x00}} + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62TwoBytes + 0x12, 0x34}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count (0 -- no blocks after the first) + {"Unable to read ack block count.", + {kVarInt62OneByte + 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62TwoBytes + 0x28, 0x88}} + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + CheckFramingBoundaries(fragments, QUIC_INVALID_ACK_DATA); +} + +// This test checks that the ack frame processor correctly identifies +// and handles the case where the third ack block's gap is larger than the +// available space in the ack range. +TEST_P(QuicFramerTest, ThirdAckBlockUnderflowGap) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Test originally written for development of IETF QUIC. The test may + // also apply to Google QUIC. If so, the test should be extended to + // include Google QUIC (frame formats, etc). See b/141858819. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK frame) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62OneByte + 63}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count (2 -- 2 blocks after the first) + {"Unable to read ack block count.", + {kVarInt62OneByte + 0x02}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62OneByte + 13}}, // Ack 14 packets, range 50..63 (inclusive) + + {"Unable to read gap block value.", + {kVarInt62OneByte + 9}}, // Gap 10 packets, 40..49 (inclusive) + {"Unable to read ack block value.", + {kVarInt62OneByte + 9}}, // Ack 10 packets, 30..39 (inclusive) + {"Unable to read gap block value.", + {kVarInt62OneByte + 29}}, // A gap of 30 packets (0..29 inclusive) + // should be too big, leaving no room + // for the ack. + {"Underflow with gap block length 30 previous ack block start is 30.", + {kVarInt62OneByte + 10}}, // Don't care + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ( + framer_.detailed_error(), + "Underflow with gap block length 30 previous ack block start is 30."); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); +} + +// This test checks that the ack frame processor correctly identifies +// and handles the case where the third ack block's length is larger than the +// available space in the ack range. +TEST_P(QuicFramerTest, ThirdAckBlockUnderflowAck) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Test originally written for development of IETF QUIC. The test may + // also apply to Google QUIC. If so, the test should be extended to + // include Google QUIC (frame formats, etc). See b/141858819. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK frame) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62OneByte + 63}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count (2 -- 2 blocks after the first) + {"Unable to read ack block count.", + {kVarInt62OneByte + 0x02}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62OneByte + 13}}, // only 50 packet numbers "left" + + {"Unable to read gap block value.", + {kVarInt62OneByte + 10}}, // Only 40 packet numbers left + {"Unable to read ack block value.", + {kVarInt62OneByte + 10}}, // only 30 packet numbers left. + {"Unable to read gap block value.", + {kVarInt62OneByte + 1}}, // Gap is OK, 29 packet numbers left + {"Unable to read ack block value.", + {kVarInt62OneByte + 30}}, // Use up all 30, should be an error + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(framer_.detailed_error(), + "Underflow with ack block length 31 latest ack block end is 25."); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); +} + +// Tests a variety of ack block wrap scenarios. For example, if the +// N-1th block causes packet 0 to be acked, then a gap would wrap +// around to 0x3fffffff ffffffff... Make sure we detect this +// condition. +TEST_P(QuicFramerTest, AckBlockUnderflowGapWrap) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Test originally written for development of IETF QUIC. The test may + // also apply to Google QUIC. If so, the test should be extended to + // include Google QUIC (frame formats, etc). See b/141858819. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK frame) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62OneByte + 10}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count (1 -- 1 blocks after the first) + {"Unable to read ack block count.", + {kVarInt62OneByte + 1}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62OneByte + 9}}, // Ack packets 1..10 (inclusive) + + {"Unable to read gap block value.", + {kVarInt62OneByte + 1}}, // Gap of 2 packets (-1...0), should wrap + {"Underflow with gap block length 2 previous ack block start is 1.", + {kVarInt62OneByte + 9}}, // irrelevant + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(framer_.detailed_error(), + "Underflow with gap block length 2 previous ack block start is 1."); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); +} + +// As AckBlockUnderflowGapWrap, but in this test, it's the ack +// component of the ack-block that causes the wrap, not the gap. +TEST_P(QuicFramerTest, AckBlockUnderflowAckWrap) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Test originally written for development of IETF QUIC. The test may + // also apply to Google QUIC. If so, the test should be extended to + // include Google QUIC (frame formats, etc). See b/141858819. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK frame) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62OneByte + 10}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count (1 -- 1 blocks after the first) + {"Unable to read ack block count.", + {kVarInt62OneByte + 1}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62OneByte + 6}}, // Ack packets 4..10 (inclusive) + + {"Unable to read gap block value.", + {kVarInt62OneByte + 1}}, // Gap of 2 packets (2..3) + {"Unable to read ack block value.", + {kVarInt62OneByte + 9}}, // Should wrap. + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(framer_.detailed_error(), + "Underflow with ack block length 10 latest ack block end is 1."); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); +} + +// An ack block that acks the entire range, 1...0x3fffffffffffffff +TEST_P(QuicFramerTest, AckBlockAcksEverything) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Test originally written for development of IETF QUIC. The test may + // also apply to Google QUIC. If so, the test should be extended to + // include Google QUIC (frame formats, etc). See b/141858819. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_ACK frame) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62EightBytes + 0x3f, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Ack block count No additional blocks + {"Unable to read ack block count.", + {kVarInt62OneByte + 0}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62EightBytes + 0x3f, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xfe}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.ack_frames_.size()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(1u, frame.packets.NumIntervals()); + EXPECT_EQ(kLargestIetfLargestObserved, LargestAcked(frame)); + EXPECT_EQ(kLargestIetfLargestObserved.ToUint64(), + frame.packets.NumPacketsSlow()); +} + +// This test looks for a malformed ack where +// - There is a largest-acked value (that is, the frame is acking +// something, +// - But the length of the first ack block is 0 saying that no frames +// are being acked with the largest-acked value or there are no +// additional ack blocks. +// +TEST_P(QuicFramerTest, AckFrameFirstAckBlockLengthZero) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // Not applicable to version 99 -- first ack block contains the + // number of packets that preceed the largest_acked packet. + // A value of 0 means no packets preceed --- that the block's + // length is 1. Therefore the condition that this test checks can + // not arise. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + { 0x2C }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (ack frame) + // (more than one ack block, 2 byte largest observed, 2 byte block length) + {"", + { 0x65 }}, + // largest acked + {"Unable to read largest acked.", + { 0x12, 0x34 }}, + // Zero delta time. + {"Unable to read ack delay time.", + { 0x00, 0x00 }}, + // num ack blocks ranges. + {"Unable to read num of ack blocks.", + { 0x01 }}, + // first ack block length. + {"Unable to read first ack block length.", + { 0x00, 0x00 }}, + // gap to next block. + { "First block length is zero.", + { 0x01 }}, + // ack block length. + { "First block length is zero.", + { 0x0e, 0xaf }}, + // Number of timestamps. + { "First block length is zero.", + { 0x00 }}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (ack frame) + // (more than one ack block, 2 byte largest observed, 2 byte block length) + {"", + { 0x65 }}, + // largest acked + {"Unable to read largest acked.", + { 0x12, 0x34 }}, + // Zero delta time. + {"Unable to read ack delay time.", + { 0x00, 0x00 }}, + // num ack blocks ranges. + {"Unable to read num of ack blocks.", + { 0x01 }}, + // first ack block length. + {"Unable to read first ack block length.", + { 0x00, 0x00 }}, + // gap to next block. + { "First block length is zero.", + { 0x01 }}, + // ack block length. + { "First block length is zero.", + { 0x0e, 0xaf }}, + // Number of timestamps. + { "First block length is zero.", + { 0x00 }}, + }; + + // clang-format on + PacketFragments& fragments = + framer_.version().HasIetfInvariantHeader() ? packet46 : packet; + + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_ACK_DATA)); + + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_ACK_DATA); +} + +TEST_P(QuicFramerTest, AckFrameOneAckBlockMaxLength) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x2C}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (ack frame) + // (one ack block, 4 byte largest observed, 2 byte block length) + {"", + {0x49}}, + // largest acked + {"Unable to read largest acked.", + {0x12, 0x34, 0x56, 0x78}}, + // Zero delta time. + {"Unable to read ack delay time.", + {0x00, 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {0x12, 0x34}}, + // num timestamps. + {"Unable to read num received packets.", + {0x00}} + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x56, 0x78, 0x9A, 0xBC}}, + // frame type (ack frame) + // (one ack block, 4 byte largest observed, 2 byte block length) + {"", + {0x49}}, + // largest acked + {"Unable to read largest acked.", + {0x12, 0x34, 0x56, 0x78}}, + // Zero delta time. + {"Unable to read ack delay time.", + {0x00, 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {0x12, 0x34}}, + // num timestamps. + {"Unable to read num received packets.", + {0x00}} + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x56, 0x78, 0x9A, 0xBC}}, + // frame type (IETF_ACK frame) + {"", + {0x02}}, + // largest acked + {"Unable to read largest acked.", + {kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x78}}, + // Zero delta time. + {"Unable to read ack delay time.", + {kVarInt62OneByte + 0x00}}, + // Number of ack blocks after first + {"Unable to read ack block count.", + {kVarInt62OneByte + 0x00}}, + // first ack block length. + {"Unable to read first ack block length.", + {kVarInt62TwoBytes + 0x12, 0x33}} + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(kPacketNumber, LargestAcked(frame)); + ASSERT_EQ(4660u, frame.packets.NumPacketsSlow()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_ACK_DATA); +} + +// Tests ability to handle multiple ackblocks after the first ack +// block. Non-version-99 tests include multiple timestamps as well. +TEST_P(QuicFramerTest, AckFrameTwoTimeStampsMultipleAckBlocks) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + { 0x2C }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (ack frame) + // (more than one ack block, 2 byte largest observed, 2 byte block length) + {"", + { 0x65 }}, + // largest acked + {"Unable to read largest acked.", + { 0x12, 0x34 }}, + // Zero delta time. + {"Unable to read ack delay time.", + { 0x00, 0x00 }}, + // num ack blocks ranges. + {"Unable to read num of ack blocks.", + { 0x04 }}, + // first ack block length. + {"Unable to read first ack block length.", + { 0x00, 0x01 }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0x01 }}, + // ack block length. + { "Unable to ack block length.", + { 0x0e, 0xaf }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0xff }}, + // ack block length. + { "Unable to ack block length.", + { 0x00, 0x00 }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0x91 }}, + // ack block length. + { "Unable to ack block length.", + { 0x01, 0xea }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0x05 }}, + // ack block length. + { "Unable to ack block length.", + { 0x00, 0x04 }}, + // Number of timestamps. + { "Unable to read num received packets.", + { 0x02 }}, + // Delta from largest observed. + { "Unable to read sequence delta in received packets.", + { 0x01 }}, + // Delta time. + { "Unable to read time delta in received packets.", + { 0x76, 0x54, 0x32, 0x10 }}, + // Delta from largest observed. + { "Unable to read sequence delta in received packets.", + { 0x02 }}, + // Delta time. + { "Unable to read incremental time delta in received packets.", + { 0x32, 0x10 }}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (ack frame) + // (more than one ack block, 2 byte largest observed, 2 byte block length) + {"", + { 0x65 }}, + // largest acked + {"Unable to read largest acked.", + { 0x12, 0x34 }}, + // Zero delta time. + {"Unable to read ack delay time.", + { 0x00, 0x00 }}, + // num ack blocks ranges. + {"Unable to read num of ack blocks.", + { 0x04 }}, + // first ack block length. + {"Unable to read first ack block length.", + { 0x00, 0x01 }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0x01 }}, + // ack block length. + { "Unable to ack block length.", + { 0x0e, 0xaf }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0xff }}, + // ack block length. + { "Unable to ack block length.", + { 0x00, 0x00 }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0x91 }}, + // ack block length. + { "Unable to ack block length.", + { 0x01, 0xea }}, + // gap to next block. + { "Unable to read gap to next ack block.", + { 0x05 }}, + // ack block length. + { "Unable to ack block length.", + { 0x00, 0x04 }}, + // Number of timestamps. + { "Unable to read num received packets.", + { 0x02 }}, + // Delta from largest observed. + { "Unable to read sequence delta in received packets.", + { 0x01 }}, + // Delta time. + { "Unable to read time delta in received packets.", + { 0x76, 0x54, 0x32, 0x10 }}, + // Delta from largest observed. + { "Unable to read sequence delta in received packets.", + { 0x02 }}, + // Delta time. + { "Unable to read incremental time delta in received packets.", + { 0x32, 0x10 }}, + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x03 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Additional ACK Block #1 + // gap to next block. + { "Unable to read gap block value.", + { kVarInt62OneByte + 0x00 }}, // gap of 1 packet + // ack block length. + { "Unable to read ack block value.", + { kVarInt62TwoBytes + 0x0e, 0xae }}, // 3759 + + // pre-version-99 test includes an ack block of 0 length. this + // can not happen in version 99. ergo the second block is not + // present in the v99 test and the gap length of the next block + // is the sum of the two gaps in the pre-version-99 tests. + // Additional ACK Block #2 + // gap to next block. + { "Unable to read gap block value.", + { kVarInt62TwoBytes + 0x01, 0x8f }}, // Gap is 400 (0x190) pkts + // ack block length. + { "Unable to read ack block value.", + { kVarInt62TwoBytes + 0x01, 0xe9 }}, // block is 389 (x1ea) pkts + + // Additional ACK Block #3 + // gap to next block. + { "Unable to read gap block value.", + { kVarInt62OneByte + 0x04 }}, // Gap is 5 packets. + // ack block length. + { "Unable to read ack block value.", + { kVarInt62OneByte + 0x03 }}, // block is 3 packets. + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62FourBytes + 0x36, 0x54, 0x32, 0x10 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x32, 0x10 }}, + }; + + // clang-format on + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + + framer_.set_process_timestamps(true); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(kSmallLargestObserved, LargestAcked(frame)); + ASSERT_EQ(4254u, frame.packets.NumPacketsSlow()); + EXPECT_EQ(4u, frame.packets.NumIntervals()); + EXPECT_EQ(2u, frame.received_packet_times.size()); +} + +TEST_P(QuicFramerTest, AckFrameMultipleReceiveTimestampRanges) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x03 }}, + + // Timestamp range 1 (three packets). + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x03 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62FourBytes + 0x29, 0xff, 0xff, 0xff}}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x11, 0x11 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x01}}, + + // Timestamp range 2 (one packet). + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x05 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x10, 0x00 }}, + + // Timestamp range 3 (two packets). + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x08 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x10 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x01, 0x00 }}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + // Timestamp Range 1. + {LargestAcked(frame) - 2, CreationTimePlus(0x29ffffff)}, + {LargestAcked(frame) - 3, CreationTimePlus(0x29ffeeee)}, + {LargestAcked(frame) - 4, CreationTimePlus(0x29ffeeed)}, + // Timestamp Range 2. + {LargestAcked(frame) - 11, CreationTimePlus(0x29ffdeed)}, + // Timestamp Range 3. + {LargestAcked(frame) - 21, CreationTimePlus(0x29ffdedd)}, + {LargestAcked(frame) - 22, CreationTimePlus(0x29ffdddd)}, + })); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampWithExponent) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x00 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x03 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x29, 0xff}}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x11, 0x11 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x01}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_receive_timestamps_exponent(3); + framer_.set_process_timestamps(true); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + // Timestamp Range 1. + {LargestAcked(frame), CreationTimePlus(0x29ff << 3)}, + {LargestAcked(frame) - 1, CreationTimePlus(0x18ee << 3)}, + {LargestAcked(frame) - 2, CreationTimePlus(0x18ed << 3)}, + })); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampGapTooHigh) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x79 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x29, 0xff}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "Receive timestamp gap too high.")); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampCountTooHigh) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x0a}}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x0b}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "Receive timestamp delta too high.")); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampDeltaTooHigh) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp count.", + { kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x77 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x29, 0xff}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "Receive timestamp count too high.")); +} + +TEST_P(QuicFramerTest, AckFrameTimeStampDeltaTooHigh) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 1 byte largest observed, 1 byte block length) + 0x40, + // largest acked + 0x01, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x01, + // num timestamps. + 0x01, + // Delta from largest observed. + 0x01, + // Delta time. + 0x10, 0x32, 0x54, 0x76, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 1 byte largest observed, 1 byte block length) + 0x40, + // largest acked + 0x01, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x01, + // num timestamps. + 0x01, + // Delta from largest observed. + 0x01, + // Delta time. + 0x10, 0x32, 0x54, 0x76, + }; + // clang-format on + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // ACK Timestamp is not a feature of IETF QUIC. + return; + } + QuicEncryptedPacket encrypted( + AsChars(framer_.version().HasIetfInvariantHeader() ? packet46 : packet), + ABSL_ARRAYSIZE(packet), false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "delta_from_largest_observed too high")); +} + +TEST_P(QuicFramerTest, AckFrameTimeStampSecondDeltaTooHigh) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 1 byte largest observed, 1 byte block length) + 0x40, + // largest acked + 0x03, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x03, + // num timestamps. + 0x02, + // Delta from largest observed. + 0x01, + // Delta time. + 0x10, 0x32, 0x54, 0x76, + // Delta from largest observed. + 0x03, + // Delta time. + 0x10, 0x32, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 1 byte largest observed, 1 byte block length) + 0x40, + // largest acked + 0x03, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x03, + // num timestamps. + 0x02, + // Delta from largest observed. + 0x01, + // Delta time. + 0x10, 0x32, 0x54, 0x76, + // Delta from largest observed. + 0x03, + // Delta time. + 0x10, 0x32, + }; + // clang-format on + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // ACK Timestamp is not a feature of IETF QUIC. + return; + } + QuicEncryptedPacket encrypted( + AsChars(framer_.version().HasIetfInvariantHeader() ? packet46 : packet), + ABSL_ARRAYSIZE(packet), false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "delta_from_largest_observed too high")); +} + +TEST_P(QuicFramerTest, NewStopWaitingFrame) { + if (VersionHasIetfQuicFrames(version_.transport_version)) { + // The Stop Waiting frame is not in IETF QUIC + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x2C}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stop waiting frame) + {"", + {0x06}}, + // least packet number awaiting an ack, delta from packet number. + {"Unable to read least unacked delta.", + {0x00, 0x00, 0x00, 0x08}} + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (stop waiting frame) + {"", + {0x06}}, + // least packet number awaiting an ack, delta from packet number. + {"Unable to read least unacked delta.", + {0x00, 0x00, 0x00, 0x08}} + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasIetfInvariantHeader() ? packet46 : packet; + + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + ASSERT_EQ(1u, visitor_.stop_waiting_frames_.size()); + const QuicStopWaitingFrame& frame = *visitor_.stop_waiting_frames_[0]; + EXPECT_EQ(kLeastUnacked, frame.least_unacked); + + CheckFramingBoundaries(fragments, QUIC_INVALID_STOP_WAITING_DATA); +} + +TEST_P(QuicFramerTest, InvalidNewStopWaitingFrame) { + // The Stop Waiting frame is not in IETF QUIC + if (VersionHasIetfQuicFrames(version_.transport_version) && + framer_.version().HasIetfInvariantHeader()) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stop waiting frame) + 0x06, + // least packet number awaiting an ack, delta from packet number. + 0x13, 0x34, 0x56, 0x78, + 0x9A, 0xA8, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stop waiting frame) + 0x06, + // least packet number awaiting an ack, delta from packet number. + 0x57, 0x78, 0x9A, 0xA8, + }; + // clang-format on + + QuicEncryptedPacket encrypted( + AsChars(framer_.version().HasIetfInvariantHeader() ? packet46 : packet), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet), + false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_STOP_WAITING_DATA)); + EXPECT_EQ("Invalid unacked delta.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, RstStreamFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (rst stream frame) + {"", + {0x01}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + // sent byte offset + {"Unable to read rst stream sent byte offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // error code QUIC_STREAM_CANCELLED + {"Unable to read rst stream error code.", + {0x00, 0x00, 0x00, 0x06}} + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (rst stream frame) + {"", + {0x01}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + // sent byte offset + {"Unable to read rst stream sent byte offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // error code QUIC_STREAM_CANCELLED + {"Unable to read rst stream error code.", + {0x00, 0x00, 0x00, 0x06}} + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_RST_STREAM frame) + {"", + {0x04}}, + // stream id + {"Unable to read IETF_RST_STREAM frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + // application error code H3_REQUEST_CANCELLED gets translated to + // QuicRstStreamErrorCode::QUIC_STREAM_CANCELLED. + {"Unable to read rst stream error code.", + {kVarInt62TwoBytes + 0x01, 0x0c}}, + // Final Offset + {"Unable to read rst stream sent byte offset.", + {kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54}} + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.rst_stream_frame_.stream_id); + EXPECT_EQ(QUIC_STREAM_CANCELLED, visitor_.rst_stream_frame_.error_code); + EXPECT_EQ(kStreamOffset, visitor_.rst_stream_frame_.byte_offset); + CheckFramingBoundaries(fragments, QUIC_INVALID_RST_STREAM_DATA); +} + +TEST_P(QuicFramerTest, ConnectionCloseFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (connection close frame) + {"", + {0x02}}, + // error code + {"Unable to read connection close error code.", + {0x00, 0x00, 0x00, 0x11}}, + {"Unable to read connection close error details.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (connection close frame) + {"", + {0x02}}, + // error code + {"Unable to read connection close error code.", + {0x00, 0x00, 0x00, 0x11}}, + {"Unable to read connection close error details.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF Transport CONNECTION_CLOSE frame) + {"", + {0x1c}}, + // error code + {"Unable to read connection close error code.", + {kVarInt62TwoBytes + 0x00, 0x11}}, + {"Unable to read connection close frame type.", + {kVarInt62TwoBytes + 0x12, 0x34 }}, + {"Unable to read connection close error details.", + { + // error details length + kVarInt62OneByte + 0x11, + // error details with QuicErrorCode serialized + '1', '1', '5', ':', + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + EXPECT_EQ(0x11u, static_cast( + visitor_.connection_close_frame_.wire_error_code)); + EXPECT_EQ("because I can", visitor_.connection_close_frame_.error_details); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + EXPECT_EQ(0x1234u, + visitor_.connection_close_frame_.transport_close_frame_type); + EXPECT_EQ(115u, visitor_.connection_close_frame_.quic_error_code); + } else { + // For Google QUIC frame, |quic_error_code| and |wire_error_code| has the + // same value. + EXPECT_EQ(0x11u, static_cast( + visitor_.connection_close_frame_.quic_error_code)); + } + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_CONNECTION_CLOSE_DATA); +} + +TEST_P(QuicFramerTest, ConnectionCloseFrameWithUnknownErrorCode) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (connection close frame) + {"", + {0x02}}, + // error code larger than QUIC_LAST_ERROR + {"Unable to read connection close error code.", + {0x00, 0x00, 0xC0, 0xDE}}, + {"Unable to read connection close error details.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (connection close frame) + {"", + {0x02}}, + // error code larger than QUIC_LAST_ERROR + {"Unable to read connection close error code.", + {0x00, 0x00, 0xC0, 0xDE}}, + {"Unable to read connection close error details.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF Transport CONNECTION_CLOSE frame) + {"", + {0x1c}}, + // error code + {"Unable to read connection close error code.", + {kVarInt62FourBytes + 0x00, 0x00, 0xC0, 0xDE}}, + {"Unable to read connection close frame type.", + {kVarInt62TwoBytes + 0x12, 0x34 }}, + {"Unable to read connection close error details.", + { + // error details length + kVarInt62OneByte + 0x11, + // error details with QuicErrorCode larger than QUIC_LAST_ERROR + '8', '4', '9', ':', + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + EXPECT_EQ("because I can", visitor_.connection_close_frame_.error_details); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + EXPECT_EQ(0x1234u, + visitor_.connection_close_frame_.transport_close_frame_type); + EXPECT_EQ(0xC0DEu, visitor_.connection_close_frame_.wire_error_code); + EXPECT_EQ(849u, visitor_.connection_close_frame_.quic_error_code); + } else { + // For Google QUIC frame, |quic_error_code| and |wire_error_code| has the + // same value. + EXPECT_EQ(0xC0DEu, visitor_.connection_close_frame_.wire_error_code); + EXPECT_EQ(0xC0DEu, visitor_.connection_close_frame_.quic_error_code); + } + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_CONNECTION_CLOSE_DATA); +} + +// As above, but checks that for Google-QUIC, if there happens +// to be an ErrorCode string at the start of the details, it is +// NOT extracted/parsed/folded/spindled/and/mutilated. +TEST_P(QuicFramerTest, ConnectionCloseFrameWithExtractedInfoIgnoreGCuic) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (connection close frame) + {"", + {0x02}}, + // error code + {"Unable to read connection close error code.", + {0x00, 0x00, 0x00, 0x11}}, + {"Unable to read connection close error details.", + { + // error details length + 0x0, 0x13, + // error details + '1', '7', '7', '6', + '7', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n'} + } + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (connection close frame) + {"", + {0x02}}, + // error code + {"Unable to read connection close error code.", + {0x00, 0x00, 0x00, 0x11}}, + {"Unable to read connection close error details.", + { + // error details length + 0x0, 0x13, + // error details + '1', '7', '7', '6', + '7', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n'} + } + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF Transport CONNECTION_CLOSE frame) + {"", + {0x1c}}, + // error code + {"Unable to read connection close error code.", + {kVarInt62OneByte + 0x11}}, + {"Unable to read connection close frame type.", + {kVarInt62TwoBytes + 0x12, 0x34 }}, + {"Unable to read connection close error details.", + { + // error details length + kVarInt62OneByte + 0x13, + // error details + '1', '7', '7', '6', + '7', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n'} + } + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + EXPECT_EQ(0x11u, static_cast( + visitor_.connection_close_frame_.wire_error_code)); + + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + EXPECT_EQ(0x1234u, + visitor_.connection_close_frame_.transport_close_frame_type); + EXPECT_EQ(17767u, visitor_.connection_close_frame_.quic_error_code); + EXPECT_EQ("because I can", visitor_.connection_close_frame_.error_details); + } else { + EXPECT_EQ(0x11u, visitor_.connection_close_frame_.quic_error_code); + // Error code is not prepended in GQUIC, so it is not removed and should + // remain in the reason phrase. + EXPECT_EQ("17767:because I can", + visitor_.connection_close_frame_.error_details); + } + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(fragments, QUIC_INVALID_CONNECTION_CLOSE_DATA); +} + +// Test the CONNECTION_CLOSE/Application variant. +TEST_P(QuicFramerTest, ApplicationCloseFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only in IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_CONNECTION_CLOSE/Application frame) + {"", + {0x1d}}, + // error code + {"Unable to read connection close error code.", + {kVarInt62TwoBytes + 0x00, 0x11}}, + {"Unable to read connection close error details.", + { + // error details length + kVarInt62OneByte + 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + + EXPECT_EQ(IETF_QUIC_APPLICATION_CONNECTION_CLOSE, + visitor_.connection_close_frame_.close_type); + EXPECT_EQ(122u, visitor_.connection_close_frame_.quic_error_code); + EXPECT_EQ(0x11u, visitor_.connection_close_frame_.wire_error_code); + EXPECT_EQ("because I can", visitor_.connection_close_frame_.error_details); + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_CONNECTION_CLOSE_DATA); +} + +// Check that we can extract an error code from an application close. +TEST_P(QuicFramerTest, ApplicationCloseFrameExtract) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only in IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_CONNECTION_CLOSE/Application frame) + {"", + {0x1d}}, + // error code + {"Unable to read connection close error code.", + {kVarInt62OneByte + 0x11}}, + {"Unable to read connection close error details.", + { + // error details length + kVarInt62OneByte + 0x13, + // error details + '1', '7', '7', '6', + '7', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n'} + } + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + + EXPECT_EQ(IETF_QUIC_APPLICATION_CONNECTION_CLOSE, + visitor_.connection_close_frame_.close_type); + EXPECT_EQ(17767u, visitor_.connection_close_frame_.quic_error_code); + EXPECT_EQ(0x11u, visitor_.connection_close_frame_.wire_error_code); + EXPECT_EQ("because I can", visitor_.connection_close_frame_.error_details); + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_CONNECTION_CLOSE_DATA); +} + +TEST_P(QuicFramerTest, GoAwayFrame) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is not in IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (go away frame) + {"", + {0x03}}, + // error code + {"Unable to read go away error code.", + {0x00, 0x00, 0x00, 0x09}}, + // stream id + {"Unable to read last good stream id.", + {0x01, 0x02, 0x03, 0x04}}, + // stream id + {"Unable to read goaway reason.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (go away frame) + {"", + {0x03}}, + // error code + {"Unable to read go away error code.", + {0x00, 0x00, 0x00, 0x09}}, + // stream id + {"Unable to read last good stream id.", + {0x01, 0x02, 0x03, 0x04}}, + // stream id + {"Unable to read goaway reason.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasIetfInvariantHeader() ? packet46 : packet; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.goaway_frame_.last_good_stream_id); + EXPECT_EQ(0x9u, visitor_.goaway_frame_.error_code); + EXPECT_EQ("because I can", visitor_.goaway_frame_.reason_phrase); + + CheckFramingBoundaries(fragments, QUIC_INVALID_GOAWAY_DATA); +} + +TEST_P(QuicFramerTest, GoAwayFrameWithUnknownErrorCode) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is not in IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (go away frame) + {"", + {0x03}}, + // error code larger than QUIC_LAST_ERROR + {"Unable to read go away error code.", + {0x00, 0x00, 0xC0, 0xDE}}, + // stream id + {"Unable to read last good stream id.", + {0x01, 0x02, 0x03, 0x04}}, + // stream id + {"Unable to read goaway reason.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (go away frame) + {"", + {0x03}}, + // error code larger than QUIC_LAST_ERROR + {"Unable to read go away error code.", + {0x00, 0x00, 0xC0, 0xDE}}, + // stream id + {"Unable to read last good stream id.", + {0x01, 0x02, 0x03, 0x04}}, + // stream id + {"Unable to read goaway reason.", + { + // error details length + 0x0, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n'} + } + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasIetfInvariantHeader() ? packet46 : packet; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.goaway_frame_.last_good_stream_id); + EXPECT_EQ(0xC0DE, visitor_.goaway_frame_.error_code); + EXPECT_EQ("because I can", visitor_.goaway_frame_.reason_phrase); + + CheckFramingBoundaries(fragments, QUIC_INVALID_GOAWAY_DATA); +} + +TEST_P(QuicFramerTest, WindowUpdateFrame) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is not in IETF QUIC, see MaxDataFrame and MaxStreamDataFrame + // for IETF QUIC equivalents. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (window update frame) + {"", + {0x04}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + // byte offset + {"Unable to read window byte_offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (window update frame) + {"", + {0x04}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + // byte offset + {"Unable to read window byte_offset.", + {0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + }; + + // clang-format on + + PacketFragments& fragments = + framer_.version().HasIetfInvariantHeader() ? packet46 : packet; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.window_update_frame_.stream_id); + EXPECT_EQ(kStreamOffset, visitor_.window_update_frame_.max_data); + + CheckFramingBoundaries(fragments, QUIC_INVALID_WINDOW_UPDATE_DATA); +} + +TEST_P(QuicFramerTest, MaxDataFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is available only in IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_MAX_DATA frame) + {"", + {0x10}}, + // byte offset + {"Can not read MAX_DATA byte-offset", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(QuicUtils::GetInvalidStreamId(framer_.transport_version()), + visitor_.window_update_frame_.stream_id); + EXPECT_EQ(kStreamOffset, visitor_.window_update_frame_.max_data); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_MAX_DATA_FRAME_DATA); +} + +TEST_P(QuicFramerTest, MaxStreamDataFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame available only in IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_MAX_STREAM_DATA frame) + {"", + {0x11}}, + // stream id + {"Unable to read IETF_MAX_STREAM_DATA frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + // byte offset + {"Can not read MAX_STREAM_DATA byte-count", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.window_update_frame_.stream_id); + EXPECT_EQ(kStreamOffset, visitor_.window_update_frame_.max_data); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA); +} + +TEST_P(QuicFramerTest, BlockedFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // public flags (8 byte connection_id) + {"", + {0x28}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (blocked frame) + {"", + {0x05}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + }; + + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (blocked frame) + {"", + {0x05}}, + // stream id + {"Unable to read stream_id.", + {0x01, 0x02, 0x03, 0x04}}, + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_STREAM_BLOCKED frame) + {"", + {0x15}}, + // stream id + {"Unable to read IETF_STREAM_DATA_BLOCKED frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + // Offset + {"Can not read stream blocked offset.", + {kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54}}, + }; + // clang-format on + + PacketFragments& fragments = + VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + EXPECT_EQ(kStreamOffset, visitor_.blocked_frame_.offset); + } else { + EXPECT_EQ(0u, visitor_.blocked_frame_.offset); + } + EXPECT_EQ(kStreamId, visitor_.blocked_frame_.stream_id); + + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + CheckFramingBoundaries(fragments, QUIC_INVALID_STREAM_BLOCKED_DATA); + } else { + CheckFramingBoundaries(fragments, QUIC_INVALID_BLOCKED_DATA); + } +} + +TEST_P(QuicFramerTest, PingFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ping frame) + 0x07, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type + 0x07, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_PING frame) + 0x01, + }; + // clang-format on + + QuicEncryptedPacket encrypted( + AsChars(VersionHasIetfQuicFrames(framer_.transport_version()) + ? packet_ietf + : (framer_.version().HasIetfInvariantHeader() ? packet46 + : packet)), + VersionHasIetfQuicFrames(framer_.transport_version()) + ? ABSL_ARRAYSIZE(packet_ietf) + : (framer_.version().HasIetfInvariantHeader() + ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)), + false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(1u, visitor_.ping_frames_.size()); + + // No need to check the PING frame boundaries because it has no payload. +} + +TEST_P(QuicFramerTest, HandshakeDoneFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (Handshake done frame) + 0x1e, + }; + // clang-format on + + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(1u, visitor_.handshake_done_frames_.size()); +} + +TEST_P(QuicFramerTest, ParseAckFrequencyFrame) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // ack frequency frame type (which needs two bytes as it is > 0x3F) + 0x40, 0xAF, + // sequence_number + 0x11, + // packet_tolerance + 0x02, + // max_ack_delay_us = 2'5000 us + 0x80, 0x00, 0x61, 0xA8, + // ignore_order + 0x01 + }; + // clang-format on + + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(1u, visitor_.ack_frequency_frames_.size()); + const auto& frame = visitor_.ack_frequency_frames_.front(); + EXPECT_EQ(17u, frame->sequence_number); + EXPECT_EQ(2u, frame->packet_tolerance); + EXPECT_EQ(2'5000u, frame->max_ack_delay.ToMicroseconds()); + EXPECT_EQ(true, frame->ignore_order); +} + +TEST_P(QuicFramerTest, MessageFrame) { + if (!VersionSupportsMessageFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet46 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // message frame type. + {"", + { 0x21 }}, + // message length + {"Unable to read message length", + {0x07}}, + // message data + {"Unable to read message data", + {'m', 'e', 's', 's', 'a', 'g', 'e'}}, + // message frame no length. + {"", + { 0x20 }}, + // message data + {{}, + {'m', 'e', 's', 's', 'a', 'g', 'e', '2'}}, + }; + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // message frame type. + {"", + { 0x31 }}, + // message length + {"Unable to read message length", + {0x07}}, + // message data + {"Unable to read message data", + {'m', 'e', 's', 's', 'a', 'g', 'e'}}, + // message frame no length. + {"", + { 0x30 }}, + // message data + {{}, + {'m', 'e', 's', 's', 'a', 'g', 'e', '2'}}, + }; + // clang-format on + + std::unique_ptr encrypted; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + encrypted = AssemblePacketFromFragments(packet_ietf); + } else { + encrypted = AssemblePacketFromFragments(packet46); + } + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + ASSERT_EQ(2u, visitor_.message_frames_.size()); + EXPECT_EQ(7u, visitor_.message_frames_[0]->message_length); + EXPECT_EQ(8u, visitor_.message_frames_[1]->message_length); + + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_MESSAGE_DATA); + } else { + CheckFramingBoundaries(packet46, QUIC_INVALID_MESSAGE_DATA); + } +} + +TEST_P(QuicFramerTest, PublicResetPacketV33) { + // clang-format off + PacketFragments packet = { + // public flags (public reset, 8 byte connection_id) + {"", + {0x0A}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + {"Unable to read reset message.", + { + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x02, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // tag kRSEQ + 'R', 'S', 'E', 'Q', + // end offset 16 + 0x10, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + // rejected packet number + 0xBC, 0x9A, 0x78, 0x56, + 0x34, 0x12, 0x00, 0x00, + } + } + }; + // clang-format on + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + ASSERT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.public_reset_packet_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.public_reset_packet_->connection_id); + EXPECT_EQ(kNonceProof, visitor_.public_reset_packet_->nonce_proof); + EXPECT_EQ( + IpAddressFamily::IP_UNSPEC, + visitor_.public_reset_packet_->client_address.host().address_family()); + + CheckFramingBoundaries(packet, QUIC_INVALID_PUBLIC_RST_PACKET); +} + +TEST_P(QuicFramerTest, PublicResetPacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + // clang-format off + PacketFragments packet = { + // public flags (public reset, 8 byte connection_id) + {"", + {0x0E}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + {"Unable to read reset message.", + { + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x02, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // tag kRSEQ + 'R', 'S', 'E', 'Q', + // end offset 16 + 0x10, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + // rejected packet number + 0xBC, 0x9A, 0x78, 0x56, + 0x34, 0x12, 0x00, 0x00, + } + } + }; + // clang-format on + + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + ASSERT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.public_reset_packet_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.public_reset_packet_->connection_id); + EXPECT_EQ(kNonceProof, visitor_.public_reset_packet_->nonce_proof); + EXPECT_EQ( + IpAddressFamily::IP_UNSPEC, + visitor_.public_reset_packet_->client_address.host().address_family()); + + CheckFramingBoundaries(packet, QUIC_INVALID_PUBLIC_RST_PACKET); +} + +TEST_P(QuicFramerTest, PublicResetPacketWithTrailingJunk) { + // clang-format off + unsigned char packet[] = { + // public flags (public reset, 8 byte connection_id) + 0x0A, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x02, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // tag kRSEQ + 'R', 'S', 'E', 'Q', + // end offset 16 + 0x10, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + // rejected packet number + 0xBC, 0x9A, 0x78, 0x56, + 0x34, 0x12, 0x00, 0x00, + // trailing junk + 'j', 'u', 'n', 'k', + }; + // clang-format on + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + ASSERT_THAT(framer_.error(), IsError(QUIC_INVALID_PUBLIC_RST_PACKET)); + EXPECT_EQ("Unable to read reset message.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, PublicResetPacketWithClientAddress) { + // clang-format off + PacketFragments packet = { + // public flags (public reset, 8 byte connection_id) + {"", + {0x0A}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + {"Unable to read reset message.", + { + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x03, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // tag kRSEQ + 'R', 'S', 'E', 'Q', + // end offset 16 + 0x10, 0x00, 0x00, 0x00, + // tag kCADR + 'C', 'A', 'D', 'R', + // end offset 24 + 0x18, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + // rejected packet number + 0xBC, 0x9A, 0x78, 0x56, + 0x34, 0x12, 0x00, 0x00, + // client address: 4.31.198.44:443 + 0x02, 0x00, + 0x04, 0x1F, 0xC6, 0x2C, + 0xBB, 0x01, + } + } + }; + // clang-format on + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + ASSERT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.public_reset_packet_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.public_reset_packet_->connection_id); + EXPECT_EQ(kNonceProof, visitor_.public_reset_packet_->nonce_proof); + EXPECT_EQ("4.31.198.44", + visitor_.public_reset_packet_->client_address.host().ToString()); + EXPECT_EQ(443, visitor_.public_reset_packet_->client_address.port()); + + CheckFramingBoundaries(packet, QUIC_INVALID_PUBLIC_RST_PACKET); +} + +TEST_P(QuicFramerTest, IetfStatelessResetPacket) { + // clang-format off + unsigned char packet[] = { + // type (short packet, 1 byte packet number) + 0x50, + // Random bytes + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + // stateless reset token + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, + }; + // clang-format on + if (!framer_.version().HasIetfInvariantHeader()) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicFramerPeer::SetLastSerializedServerConnectionId(&framer_, + TestConnectionId(0x33)); + decrypter_ = new test::TestDecrypter(); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter( + ENCRYPTION_INITIAL, + std::make_unique(Perspective::IS_CLIENT)); + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::unique_ptr(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, std::make_unique( + Perspective::IS_CLIENT)); + framer_.SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, std::unique_ptr(decrypter_), false); + } + // This packet cannot be decrypted because diversification nonce is missing. + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + ASSERT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.stateless_reset_packet_.get()); + EXPECT_EQ(kTestStatelessResetToken, + visitor_.stateless_reset_packet_->stateless_reset_token); +} + +TEST_P(QuicFramerTest, IetfStatelessResetPacketInvalidStatelessResetToken) { + // clang-format off + unsigned char packet[] = { + // type (short packet, 1 byte packet number) + 0x50, + // Random bytes + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + 0x01, 0x11, 0x02, 0x22, 0x03, 0x33, 0x04, 0x44, + // stateless reset token + 0xB6, 0x69, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + // clang-format on + if (!framer_.version().HasIetfInvariantHeader()) { + return; + } + QuicFramerPeer::SetLastSerializedServerConnectionId(&framer_, + TestConnectionId(0x33)); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + decrypter_ = new test::TestDecrypter(); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter( + ENCRYPTION_INITIAL, + std::make_unique(Perspective::IS_CLIENT)); + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::unique_ptr(decrypter_)); + } else { + framer_.SetDecrypter(ENCRYPTION_INITIAL, std::make_unique( + Perspective::IS_CLIENT)); + framer_.SetAlternativeDecrypter( + ENCRYPTION_ZERO_RTT, std::unique_ptr(decrypter_), false); + } + // This packet cannot be decrypted because diversification nonce is missing. + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_DECRYPTION_FAILURE)); + ASSERT_FALSE(visitor_.stateless_reset_packet_); +} + +TEST_P(QuicFramerTest, VersionNegotiationPacketClient) { + // clang-format off + PacketFragments packet = { + // public flags (version, 8 byte connection_id) + {"", + {0x29}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // version tag + {"Unable to read supported version in negotiation.", + {QUIC_VERSION_BYTES, + 'Q', '2', '.', '0'}}, + }; + + PacketFragments packet46 = { + // type (long header) + {"", + {0x8F}}, + // version tag + {"", + {0x00, 0x00, 0x00, 0x00}}, + {"", + {0x05}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // Supported versions + {"Unable to read supported version in negotiation.", + {QUIC_VERSION_BYTES, + 'Q', '2', '.', '0'}}, + }; + + PacketFragments packet49 = { + // type (long header) + {"", + {0x8F}}, + // version tag + {"", + {0x00, 0x00, 0x00, 0x00}}, + {"", + {0x08}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + {"", + {0x00}}, + // Supported versions + {"Unable to read supported version in negotiation.", + {QUIC_VERSION_BYTES, + 'Q', '2', '.', '0'}}, + }; + // clang-format on + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + PacketFragments& fragments = + framer_.version().HasLongHeaderLengths() ? packet49 + : framer_.version().HasIetfInvariantHeader() ? packet46 + : packet; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + ASSERT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.version_negotiation_packet_.get()); + EXPECT_EQ(1u, visitor_.version_negotiation_packet_->versions.size()); + EXPECT_EQ(GetParam(), visitor_.version_negotiation_packet_->versions[0]); + + // Remove the last version from the packet so that every truncated + // version of the packet is invalid, otherwise checking boundaries + // is annoyingly complicated. + for (size_t i = 0; i < 4; ++i) { + fragments.back().fragment.pop_back(); + } + CheckFramingBoundaries(fragments, QUIC_INVALID_VERSION_NEGOTIATION_PACKET); +} + +TEST_P(QuicFramerTest, VersionNegotiationPacketServer) { + if (!framer_.version().HasIetfInvariantHeader()) { + return; + } + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // clang-format off + unsigned char packet[] = { + // public flags (long header with all ignored bits set) + 0xFF, + // version + 0x00, 0x00, 0x00, 0x00, + // connection ID lengths + 0x50, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // supported versions + QUIC_VERSION_BYTES, + 'Q', '2', '.', '0', + }; + unsigned char packet2[] = { + // public flags (long header with all ignored bits set) + 0xFF, + // version + 0x00, 0x00, 0x00, 0x00, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // source connection ID length + 0x00, + // supported versions + QUIC_VERSION_BYTES, + 'Q', '2', '.', '0', + }; + // clang-format on + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLengthPrefixedConnectionIds()) { + p = packet2; + p_length = ABSL_ARRAYSIZE(packet2); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), + IsError(QUIC_INVALID_VERSION_NEGOTIATION_PACKET)); + EXPECT_EQ("Server received version negotiation packet.", + framer_.detailed_error()); + EXPECT_FALSE(visitor_.version_negotiation_packet_.get()); +} + +TEST_P(QuicFramerTest, OldVersionNegotiationPacket) { + // clang-format off + PacketFragments packet = { + // public flags (version, 8 byte connection_id) + {"", + {0x2D}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // version tag + {"Unable to read supported version in negotiation.", + {QUIC_VERSION_BYTES, + 'Q', '2', '.', '0'}}, + }; + // clang-format on + + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + ASSERT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.version_negotiation_packet_.get()); + EXPECT_EQ(1u, visitor_.version_negotiation_packet_->versions.size()); + EXPECT_EQ(GetParam(), visitor_.version_negotiation_packet_->versions[0]); + + // Remove the last version from the packet so that every truncated + // version of the packet is invalid, otherwise checking boundaries + // is annoyingly complicated. + for (size_t i = 0; i < 4; ++i) { + packet.back().fragment.pop_back(); + } + CheckFramingBoundaries(packet, QUIC_INVALID_VERSION_NEGOTIATION_PACKET); +} + +TEST_P(QuicFramerTest, ParseIetfRetryPacket) { + if (!framer_.version().SupportsRetry()) { + return; + } + // IETF RETRY is only sent from client to server. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type RETRY and ODCIL=8) + 0xF5, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x05, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // original destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // retry token + 'H', 'e', 'l', 'l', 'o', ' ', 't', 'h', 'i', 's', + ' ', 'i', 's', ' ', 'R', 'E', 'T', 'R', 'Y', '!', + }; + unsigned char packet49[] = { + // public flags (long header with packet type RETRY) + 0xF0, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // original destination connection ID length + 0x08, + // original destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // retry token + 'H', 'e', 'l', 'l', 'o', ' ', 't', 'h', 'i', 's', + ' ', 'i', 's', ' ', 'R', 'E', 'T', 'R', 'Y', '!', + }; + unsigned char packet_with_tag[] = { + // public flags (long header with packet type RETRY) + 0xF0, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // retry token + 'H', 'e', 'l', 'l', 'o', ' ', 't', 'h', 'i', 's', + ' ', 'i', 's', ' ', 'R', 'E', 'T', 'R', 'Y', '!', + // retry token integrity tag + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().UsesTls()) { + ReviseFirstByteByVersion(packet_with_tag); + p = packet_with_tag; + p_length = ABSL_ARRAYSIZE(packet_with_tag); + } else if (framer_.version().HasLongHeaderLengths()) { + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_TRUE(visitor_.on_retry_packet_called_); + ASSERT_TRUE(visitor_.retry_new_connection_id_.get()); + ASSERT_TRUE(visitor_.retry_token_.get()); + + if (framer_.version().UsesTls()) { + ASSERT_TRUE(visitor_.retry_token_integrity_tag_.get()); + static const unsigned char expected_integrity_tag[16] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + }; + quiche::test::CompareCharArraysWithHexError( + "retry integrity tag", visitor_.retry_token_integrity_tag_->data(), + visitor_.retry_token_integrity_tag_->length(), + reinterpret_cast(expected_integrity_tag), + ABSL_ARRAYSIZE(expected_integrity_tag)); + ASSERT_TRUE(visitor_.retry_without_tag_.get()); + quiche::test::CompareCharArraysWithHexError( + "retry without tag", visitor_.retry_without_tag_->data(), + visitor_.retry_without_tag_->length(), + reinterpret_cast(packet_with_tag), 35); + } else { + ASSERT_TRUE(visitor_.retry_original_connection_id_.get()); + EXPECT_EQ(FramerTestConnectionId(), + *visitor_.retry_original_connection_id_.get()); + } + + EXPECT_EQ(FramerTestConnectionIdPlusOne(), + *visitor_.retry_new_connection_id_.get()); + EXPECT_EQ("Hello this is RETRY!", *visitor_.retry_token_.get()); + + // IETF RETRY is only sent from client to server, the rest of this test + // ensures that the server correctly drops them without acting on them. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Reset our visitor state to default settings. + visitor_.retry_original_connection_id_.reset(); + visitor_.retry_new_connection_id_.reset(); + visitor_.retry_token_.reset(); + visitor_.retry_token_integrity_tag_.reset(); + visitor_.retry_without_tag_.reset(); + visitor_.on_retry_packet_called_ = false; + + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Client-initiated RETRY is invalid.", framer_.detailed_error()); + + EXPECT_FALSE(visitor_.on_retry_packet_called_); + EXPECT_FALSE(visitor_.retry_new_connection_id_.get()); + EXPECT_FALSE(visitor_.retry_token_.get()); + EXPECT_FALSE(visitor_.retry_token_integrity_tag_.get()); + EXPECT_FALSE(visitor_.retry_without_tag_.get()); +} + +TEST_P(QuicFramerTest, BuildPaddingFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // clang-format off + unsigned char packet[kMaxOutgoingPacketSize] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[kMaxOutgoingPacketSize] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + unsigned char* p = packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + uint64_t header_size = GetPacketHeaderSize( + framer_.transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_4BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + memset(p + header_size + 1, 0x00, kMaxOutgoingPacketSize - header_size - 1); + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildStreamFramePacketWithNewPaddingFrame) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + QuicStreamFrame stream_frame(kStreamId, true, kStreamOffset, + absl::string_view("hello world!")); + QuicPaddingFrame padding_frame(2); + QuicFrames frames = {QuicFrame(padding_frame), QuicFrame(stream_frame), + QuicFrame(padding_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // paddings + 0x00, 0x00, + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // paddings + 0x00, 0x00, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // paddings + 0x00, 0x00, + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // paddings + 0x00, 0x00, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // paddings + 0x00, 0x00, + // frame type (IETF_STREAM with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // paddings + 0x00, 0x00, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, Build4ByteSequenceNumberPaddingFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // clang-format off + unsigned char packet[kMaxOutgoingPacketSize] = { + // public flags (8 byte connection_id and 4 byte packet number) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[kMaxOutgoingPacketSize] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + unsigned char* p = packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + uint64_t header_size = GetPacketHeaderSize( + framer_.transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_4BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + memset(p + header_size + 1, 0x00, kMaxOutgoingPacketSize - header_size - 1); + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, Build2ByteSequenceNumberPaddingFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number_length = PACKET_2BYTE_PACKET_NUMBER; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // clang-format off + unsigned char packet[kMaxOutgoingPacketSize] = { + // public flags (8 byte connection_id and 2 byte packet number) + 0x1C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[kMaxOutgoingPacketSize] = { + // type (short header, 2 byte packet number) + 0x41, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { + // type (short header, 2 byte packet number) + 0x41, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + unsigned char* p = packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + uint64_t header_size = GetPacketHeaderSize( + framer_.transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_2BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + memset(p + header_size + 1, 0x00, kMaxOutgoingPacketSize - header_size - 1); + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, Build1ByteSequenceNumberPaddingFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number_length = PACKET_1BYTE_PACKET_NUMBER; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // clang-format off + unsigned char packet[kMaxOutgoingPacketSize] = { + // public flags (8 byte connection_id and 1 byte packet number) + 0x0C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[kMaxOutgoingPacketSize] = { + // type (short header, 1 byte packet number) + 0x40, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { + // type (short header, 1 byte packet number) + 0x40, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + unsigned char* p = packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + uint64_t header_size = GetPacketHeaderSize( + framer_.transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_1BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + memset(p + header_size + 1, 0x00, kMaxOutgoingPacketSize - header_size - 1); + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildStreamFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + if (QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + + QuicStreamFrame stream_frame(kStreamId, true, kStreamOffset, + absl::string_view("hello world!")); + + QuicFrames frames = {QuicFrame(stream_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin and no length) + 0xDF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin and no length) + 0xDF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STREAM frame with FIN and OFFSET, no length) + 0x08 | 0x01 | 0x04, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildStreamFramePacketWithVersionFlag) { + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = true; + if (framer_.version().HasIetfInvariantHeader()) { + header.long_packet_type = ZERO_RTT_PROTECTED; + } + header.packet_number = kPacketNumber; + if (QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + + QuicStreamFrame stream_frame(kStreamId, true, kStreamOffset, + absl::string_view("hello world!")); + QuicFrames frames = {QuicFrame(stream_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (version, 8 byte connection_id) + 0x2D, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // version tag + QUIC_VERSION_BYTES, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin and no length) + 0xDF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', + }; + + unsigned char packet46[] = { + // type (long header with packet type ZERO_RTT_PROTECTED) + 0xD3, + // version tag + QUIC_VERSION_BYTES, + // connection_id length + 0x50, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin and no length) + 0xDF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', + }; + + unsigned char packet49[] = { + // type (long header with packet type ZERO_RTT_PROTECTED) + 0xD3, + // version tag + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // length + 0x40, 0x1D, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin and no length) + 0xDF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', + }; + + unsigned char packet_ietf[] = { + // type (long header with packet type ZERO_RTT_PROTECTED) + 0xD3, + // version tag + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // length + 0x40, 0x1D, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STREAM frame with fin and offset, no length) + 0x08 | 0x01 | 0x04, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', + }; + // clang-format on + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + ReviseFirstByteByVersion(packet_ietf); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasLongHeaderLengths()) { + p = packet49; + p_size = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildCryptoFramePacket) { + if (!QuicVersionUsesCryptoFrames(framer_.transport_version())) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + SimpleDataProducer data_producer; + framer_.set_data_producer(&data_producer); + + absl::string_view crypto_frame_contents("hello world!"); + QuicCryptoFrame crypto_frame(ENCRYPTION_INITIAL, kStreamOffset, + crypto_frame_contents.length()); + data_producer.SaveCryptoData(ENCRYPTION_INITIAL, kStreamOffset, + crypto_frame_contents); + + QuicFrames frames = {QuicFrame(&crypto_frame)}; + + // clang-format off + unsigned char packet48[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (QuicFrameType CRYPTO_FRAME) + 0x08, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // length + kVarInt62OneByte + 12, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_CRYPTO frame) + 0x06, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // length + kVarInt62OneByte + 12, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + }; + // clang-format on + + unsigned char* packet = packet48; + size_t packet_size = ABSL_ARRAYSIZE(packet48); + if (framer_.version().HasIetfQuicFrames()) { + packet = packet_ietf; + packet_size = ABSL_ARRAYSIZE(packet_ietf); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError("constructed packet", + data->data(), data->length(), + AsChars(packet), packet_size); +} + +TEST_P(QuicFramerTest, CryptoFrame) { + if (!QuicVersionUsesCryptoFrames(framer_.transport_version())) { + // CRYPTO frames aren't supported prior to v48. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet48 = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (QuicFrameType CRYPTO_FRAME) + {"", + {0x08}}, + // offset + {"", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Invalid data length.", + {kVarInt62OneByte + 12}}, + // data + {"Unable to read frame data.", + {'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_CRYPTO frame) + {"", + {0x06}}, + // offset + {"", + {kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54}}, + // data length + {"Invalid data length.", + {kVarInt62OneByte + 12}}, + // data + {"Unable to read frame data.", + {'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!'}}, + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasIetfQuicFrames() ? packet_ietf : packet48; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + ASSERT_EQ(1u, visitor_.crypto_frames_.size()); + QuicCryptoFrame* frame = visitor_.crypto_frames_[0].get(); + EXPECT_EQ(ENCRYPTION_FORWARD_SECURE, frame->level); + EXPECT_EQ(kStreamOffset, frame->offset); + EXPECT_EQ("hello world!", + std::string(frame->data_buffer, frame->data_length)); + + CheckFramingBoundaries(fragments, QUIC_INVALID_FRAME_DATA); +} + +TEST_P(QuicFramerTest, BuildVersionNegotiationPacket) { + SetQuicFlag(quic_disable_version_negotiation_grease_randomness, true); + // clang-format off + unsigned char packet[] = { + // public flags (version, 8 byte connection_id) + 0x0D, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // supported versions + 0xDA, 0x5A, 0x3A, 0x3A, + QUIC_VERSION_BYTES, + }; + unsigned char packet46[] = { + // type (long header) + 0xC0, + // version tag + 0x00, 0x00, 0x00, 0x00, + // connection_id length + 0x05, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // supported versions + 0xDA, 0x5A, 0x3A, 0x3A, + QUIC_VERSION_BYTES, + }; + unsigned char packet49[] = { + // type (long header) + 0xC0, + // version tag + 0x00, 0x00, 0x00, 0x00, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // supported versions + 0xDA, 0x5A, 0x3A, 0x3A, + QUIC_VERSION_BYTES, + }; + // clang-format on + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + p = packet49; + p_size = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + QuicConnectionId connection_id = FramerTestConnectionId(); + std::unique_ptr data( + QuicFramer::BuildVersionNegotiationPacket( + connection_id, EmptyQuicConnectionId(), + framer_.version().HasIetfInvariantHeader(), + framer_.version().HasLengthPrefixedConnectionIds(), + SupportedVersions(GetParam()))); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildVersionNegotiationPacketWithClientConnectionId) { + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + + SetQuicFlag(quic_disable_version_negotiation_grease_randomness, true); + + // clang-format off + unsigned char packet[] = { + // type (long header) + 0xC0, + // version tag + 0x00, 0x00, 0x00, 0x00, + // client/destination connection ID + 0x08, + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // server/source connection ID + 0x08, + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // supported versions + 0xDA, 0x5A, 0x3A, 0x3A, + QUIC_VERSION_BYTES, + }; + // clang-format on + + QuicConnectionId server_connection_id = FramerTestConnectionId(); + QuicConnectionId client_connection_id = FramerTestConnectionIdPlusOne(); + std::unique_ptr data( + QuicFramer::BuildVersionNegotiationPacket( + server_connection_id, client_connection_id, true, true, + SupportedVersions(GetParam()))); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildAckFramePacketOneAckBlock) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Use kSmallLargestObserved to make this test finished in a short time. + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + + QuicFrames frames = {QuicFrame(&ack_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 2 byte largest observed, 2 byte block length) + 0x45, + // largest acked + 0x12, 0x34, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x12, 0x34, + // num timestamps. + 0x00, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 2 byte largest observed, 2 byte block length) + 0x45, + // largest acked + 0x12, 0x34, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x12, 0x34, + // num timestamps. + 0x00, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // largest acked + kVarInt62TwoBytes + 0x12, 0x34, + // Zero delta time. + kVarInt62OneByte + 0x00, + // Number of additional ack blocks. + kVarInt62OneByte + 0x00, + // first ack block length. + kVarInt62TwoBytes + 0x12, 0x33, + }; + // clang-format on + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildAckReceiveTimestampsFrameMultipleRanges) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + // Timestamp Range 3. + {kSmallLargestObserved - 22, CreationTimePlus(0x29ffdddd)}, + {kSmallLargestObserved - 21, CreationTimePlus(0x29ffdedd)}, + // Timestamp Range 2. + {kSmallLargestObserved - 11, CreationTimePlus(0x29ffdeed)}, + // Timestamp Range 1. + {kSmallLargestObserved - 4, CreationTimePlus(0x29ffeeed)}, + {kSmallLargestObserved - 3, CreationTimePlus(0x29ffeeee)}, + {kSmallLargestObserved - 2, CreationTimePlus(0x29ffffff)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, + 0xDC, + 0xBA, + 0x98, + 0x76, + 0x54, + 0x32, + 0x10, + // packet number + 0x12, + 0x34, + 0x56, + 0x78, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + 0x22, + // largest acked + kVarInt62TwoBytes + 0x12, + 0x34, // = 4660 + // Zero delta time. + kVarInt62OneByte + 0x00, + // number of additional ack blocks + kVarInt62OneByte + 0x00, + // first ack block length. + kVarInt62TwoBytes + 0x12, + 0x33, + + // Receive Timestamps. + + // Timestamp Range Count + kVarInt62OneByte + 0x03, + + // Timestamp range 1 (three packets). + // Gap + kVarInt62OneByte + 0x02, + // Timestamp Range Count + kVarInt62OneByte + 0x03, + // Timestamp Delta + kVarInt62FourBytes + 0x29, + 0xff, + 0xff, + 0xff, + // Timestamp Delta + kVarInt62TwoBytes + 0x11, + 0x11, + // Timestamp Delta + kVarInt62OneByte + 0x01, + + // Timestamp range 2 (one packet). + // Gap + kVarInt62OneByte + 0x05, + // Timestamp Range Count + kVarInt62OneByte + 0x01, + // Timestamp Delta + kVarInt62TwoBytes + 0x10, + 0x00, + + // Timestamp range 3 (two packets). + // Gap + kVarInt62OneByte + 0x08, + // Timestamp Range Count + kVarInt62OneByte + 0x02, + // Timestamp Delta + kVarInt62OneByte + 0x10, + // Timestamp Delta + kVarInt62TwoBytes + 0x01, + 0x00, + }; + // clang-format on + + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildAckReceiveTimestampsFrameExceedsMaxTimestamps) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + // Timestamp Range 3 (not included because max receive timestamps = 4). + {kSmallLargestObserved - 20, CreationTimePlus(0x29ffdddd)}, + // Timestamp Range 2. + {kSmallLargestObserved - 10, CreationTimePlus(0x29ffdedd)}, + {kSmallLargestObserved - 9, CreationTimePlus(0x29ffdeed)}, + // Timestamp Range 1. + {kSmallLargestObserved - 2, CreationTimePlus(0x29ffeeed)}, + {kSmallLargestObserved - 1, CreationTimePlus(0x29ffeeee)}, + {kSmallLargestObserved, CreationTimePlus(0x29ffffff)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, + 0xDC, + 0xBA, + 0x98, + 0x76, + 0x54, + 0x32, + 0x10, + // packet number + 0x12, + 0x34, + 0x56, + 0x78, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + 0x22, + // largest acked + kVarInt62TwoBytes + 0x12, + 0x34, // = 4660 + // Zero delta time. + kVarInt62OneByte + 0x00, + // number of additional ack blocks + kVarInt62OneByte + 0x00, + // first ack block length. + kVarInt62TwoBytes + 0x12, + 0x33, + + // Receive Timestamps. + + // Timestamp Range Count + kVarInt62OneByte + 0x02, + + // Timestamp range 1 (three packets). + // Gap + kVarInt62OneByte + 0x00, + // Timestamp Range Count + kVarInt62OneByte + 0x03, + // Timestamp Delta + kVarInt62FourBytes + 0x29, + 0xff, + 0xff, + 0xff, + // Timestamp Delta + kVarInt62TwoBytes + 0x11, + 0x11, + // Timestamp Delta + kVarInt62OneByte + 0x01, + + // Timestamp range 2 (one packet). + // Gap + kVarInt62OneByte + 0x05, + // Timestamp Range Count + kVarInt62OneByte + 0x01, + // Timestamp Delta + kVarInt62TwoBytes + 0x10, + 0x00, + }; + // clang-format on + + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(4); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildAckReceiveTimestampsFrameWithExponentEncoding) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + // Timestamp Range 2. + {kSmallLargestObserved - 12, CreationTimePlus((0x06c00 << 3) + 0x03)}, + {kSmallLargestObserved - 11, CreationTimePlus((0x28e00 << 3) + 0x00)}, + // Timestamp Range 1. + {kSmallLargestObserved - 5, CreationTimePlus((0x29f00 << 3) + 0x00)}, + {kSmallLargestObserved - 4, CreationTimePlus((0x29f00 << 3) + 0x01)}, + {kSmallLargestObserved - 3, CreationTimePlus((0x29f00 << 3) + 0x02)}, + {kSmallLargestObserved - 2, CreationTimePlus((0x29f00 << 3) + 0x03)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + + QuicFrames frames = {QuicFrame(&ack_frame)}; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, + 0xDC, + 0xBA, + 0x98, + 0x76, + 0x54, + 0x32, + 0x10, + // packet number + 0x12, + 0x34, + 0x56, + 0x78, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + 0x22, + // largest acked + kVarInt62TwoBytes + 0x12, + 0x34, // = 4660 + // Zero delta time. + kVarInt62OneByte + 0x00, + // number of additional ack blocks + kVarInt62OneByte + 0x00, + // first ack block length. + kVarInt62TwoBytes + 0x12, + 0x33, + + // Receive Timestamps. + + // Timestamp Range Count + kVarInt62OneByte + 0x02, + + // Timestamp range 1 (three packets). + // Gap + kVarInt62OneByte + 0x02, + // Timestamp Range Count + kVarInt62OneByte + 0x04, + // Timestamp Delta + kVarInt62FourBytes + 0x00, + 0x02, + 0x9f, + 0x01, // round up + // Timestamp Delta + kVarInt62OneByte + 0x00, + // Timestamp Delta + kVarInt62OneByte + 0x00, + // Timestamp Delta + kVarInt62OneByte + 0x01, + + // Timestamp range 2 (one packet). + // Gap + kVarInt62OneByte + 0x04, + // Timestamp Range Count + kVarInt62OneByte + 0x02, + // Timestamp Delta + kVarInt62TwoBytes + 0x11, + 0x00, + // Timestamp Delta + kVarInt62FourBytes + 0x00, + 0x02, + 0x21, + 0xff, + }; + // clang-format on + + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + framer_.set_receive_timestamps_exponent(3); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildAndProcessAckReceiveTimestampsWithMultipleRanges) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + {kSmallLargestObserved - 1201, CreationTimePlus(0x8bcaef234)}, + {kSmallLargestObserved - 1200, CreationTimePlus(0x8bcdef123)}, + {kSmallLargestObserved - 1000, CreationTimePlus(0xaacdef123)}, + {kSmallLargestObserved - 4, CreationTimePlus(0xabcdea125)}, + {kSmallLargestObserved - 2, CreationTimePlus(0xabcdee124)}, + {kSmallLargestObserved - 1, CreationTimePlus(0xabcdef123)}, + {kSmallLargestObserved, CreationTimePlus(0xabcdef123)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + {kSmallLargestObserved, CreationTimePlus(0xabcdef123)}, + {kSmallLargestObserved - 1, CreationTimePlus(0xabcdef123)}, + {kSmallLargestObserved - 2, CreationTimePlus(0xabcdee124)}, + {kSmallLargestObserved - 4, CreationTimePlus(0xabcdea125)}, + {kSmallLargestObserved - 1000, CreationTimePlus(0xaacdef123)}, + {kSmallLargestObserved - 1200, CreationTimePlus(0x8bcdef123)}, + {kSmallLargestObserved - 1201, CreationTimePlus(0x8bcaef234)}, + })); +} + +TEST_P(QuicFramerTest, + BuildAndProcessAckReceiveTimestampsExceedsMaxTimestamps) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(2); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + {kSmallLargestObserved - 1201, CreationTimePlus(0x8bcaef234)}, + {kSmallLargestObserved - 1200, CreationTimePlus(0x8bcdef123)}, + {kSmallLargestObserved - 1000, CreationTimePlus(0xaacdef123)}, + {kSmallLargestObserved - 5, CreationTimePlus(0xabcdea125)}, + {kSmallLargestObserved - 3, CreationTimePlus(0xabcded124)}, + {kSmallLargestObserved - 2, CreationTimePlus(0xabcdee124)}, + {kSmallLargestObserved - 1, CreationTimePlus(0xabcdef123)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + {kSmallLargestObserved - 1, CreationTimePlus(0xabcdef123)}, + {kSmallLargestObserved - 2, CreationTimePlus(0xabcdee124)}, + })); +} + +TEST_P(QuicFramerTest, + BuildAndProcessAckReceiveTimestampsWithExponentNoTruncation) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + framer_.set_receive_timestamps_exponent(3); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + {kSmallLargestObserved - 8, CreationTimePlus(0x1add << 3)}, + {kSmallLargestObserved - 7, CreationTimePlus(0x29ed << 3)}, + {kSmallLargestObserved - 3, CreationTimePlus(0x29fe << 3)}, + {kSmallLargestObserved - 2, CreationTimePlus(0x29ff << 3)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + {kSmallLargestObserved - 2, CreationTimePlus(0x29ff << 3)}, + {kSmallLargestObserved - 3, CreationTimePlus(0x29fe << 3)}, + {kSmallLargestObserved - 7, CreationTimePlus(0x29ed << 3)}, + {kSmallLargestObserved - 8, CreationTimePlus(0x1add << 3)}, + })); +} + +TEST_P(QuicFramerTest, + BuildAndProcessAckReceiveTimestampsWithExponentTruncation) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + framer_.set_receive_timestamps_exponent(3); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + {kSmallLargestObserved - 10, CreationTimePlus((0x1001 << 3) + 1)}, + {kSmallLargestObserved - 9, CreationTimePlus((0x2995 << 3) - 1)}, + {kSmallLargestObserved - 8, CreationTimePlus((0x2995 << 3) + 0)}, + {kSmallLargestObserved - 7, CreationTimePlus((0x2995 << 3) + 1)}, + {kSmallLargestObserved - 6, CreationTimePlus((0x2995 << 3) + 2)}, + {kSmallLargestObserved - 3, CreationTimePlus((0x2995 << 3) + 3)}, + {kSmallLargestObserved - 2, CreationTimePlus((0x2995 << 3) + 4)}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + {kSmallLargestObserved - 2, CreationTimePlus(0x2996 << 3)}, + {kSmallLargestObserved - 3, CreationTimePlus(0x2996 << 3)}, + {kSmallLargestObserved - 6, CreationTimePlus(0x2996 << 3)}, + {kSmallLargestObserved - 7, CreationTimePlus(0x2996 << 3)}, + {kSmallLargestObserved - 8, CreationTimePlus(0x2995 << 3)}, + {kSmallLargestObserved - 9, CreationTimePlus(0x2995 << 3)}, + {kSmallLargestObserved - 10, CreationTimePlus(0x1002 << 3)}, + })); +} + +TEST_P(QuicFramerTest, AckReceiveTimestamps) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + framer_.set_receive_timestamps_exponent(3); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Use kSmallLargestObserved to make this test finished in a short time. + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + ack_frame.received_packet_times = PacketTimeVector{ + {kSmallLargestObserved - 5, CreationTimePlus((0x29ff << 3))}, + {kSmallLargestObserved - 4, CreationTimePlus((0x29ff << 3))}, + {kSmallLargestObserved - 3, CreationTimePlus((0x29ff << 3))}, + {kSmallLargestObserved - 2, CreationTimePlus((0x29ff << 3))}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + {kSmallLargestObserved - 2, CreationTimePlus(0x29ff << 3)}, + {kSmallLargestObserved - 3, CreationTimePlus(0x29ff << 3)}, + {kSmallLargestObserved - 4, CreationTimePlus(0x29ff << 3)}, + {kSmallLargestObserved - 5, CreationTimePlus(0x29ff << 3)}, + })); +} + +TEST_P(QuicFramerTest, AckReceiveTimestampsPacketOutOfOrder) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + framer_.set_receive_timestamps_exponent(3); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Use kSmallLargestObserved to make this test finished in a short time. + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + + // The packet numbers below are out of order, this is impossible because we + // don't record out of order packets in received_packet_times. The test is + // intended to ensure this error is raised when it happens. + ack_frame.received_packet_times = PacketTimeVector{ + {kSmallLargestObserved - 5, CreationTimePlus((0x29ff << 3))}, + {kSmallLargestObserved - 2, CreationTimePlus((0x29ff << 3))}, + {kSmallLargestObserved - 4, CreationTimePlus((0x29ff << 3))}, + {kSmallLargestObserved - 3, CreationTimePlus((0x29ff << 3))}, + }; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + EXPECT_QUIC_BUG(BuildDataPacket(header, frames), + "Packet number and/or receive time not in order."); +} + +// If there's insufficient room for IETF ack receive timestamps, don't write any +// timestamp ranges. +TEST_P(QuicFramerTest, IetfAckReceiveTimestampsTruncate) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8192); + framer_.set_receive_timestamps_exponent(3); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Use kSmallLargestObserved to make this test finished in a short time. + QuicAckFrame ack_frame = InitAckFrame(kSmallLargestObserved); + for (QuicPacketNumber i(1); i <= kSmallLargestObserved; i += 2) { + ack_frame.received_packet_times.push_back( + {i, CreationTimePlus((0x29ff << 3))}); + } + + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + EXPECT_TRUE(frame.received_packet_times.empty()); +} + +// If there are too many ack ranges, they will be truncated to make room for a +// timestamp range count of 0. +TEST_P(QuicFramerTest, IetfAckReceiveTimestampsAckRangeTruncation) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + framer_.set_process_timestamps(true); + framer_.set_max_receive_timestamps_per_ack(8); + framer_.set_receive_timestamps_exponent(3); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame; + // Create a packet with just the ack. + ack_frame = MakeAckFrameWithGaps(/*gap_size=*/0xffffffff, + /*max_num_gaps=*/200, + /*largest_acked=*/kMaxIetfVarInt); + ack_frame.received_packet_times = PacketTimeVector{ + {QuicPacketNumber(kMaxIetfVarInt) - 2, CreationTimePlus((0x29ff << 3))}, + }; + QuicFrames frames = {QuicFrame(&ack_frame)}; + // Build an ACK packet. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr raw_ack_packet(BuildDataPacket(header, frames)); + ASSERT_TRUE(raw_ack_packet != nullptr); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, + *raw_ack_packet, buffer, kMaxOutgoingPacketSize); + ASSERT_NE(0u, encrypted_length); + // Now make sure we can turn our ack packet back into an ack frame. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + ASSERT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(buffer, encrypted_length, false))); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + QuicAckFrame& processed_ack_frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(QuicPacketNumber(kMaxIetfVarInt), + LargestAcked(processed_ack_frame)); + // Verify ACK ranges in the frame gets truncated. + ASSERT_LT(processed_ack_frame.packets.NumPacketsSlow(), + ack_frame.packets.NumIntervals()); + EXPECT_EQ(158u, processed_ack_frame.packets.NumPacketsSlow()); + EXPECT_LT(processed_ack_frame.packets.NumIntervals(), + ack_frame.packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(kMaxIetfVarInt), + processed_ack_frame.packets.Max()); + // But the receive timestamps are not truncated because they are small. + EXPECT_FALSE(processed_ack_frame.received_packet_times.empty()); +} + +TEST_P(QuicFramerTest, BuildAckFramePacketOneAckBlockMaxLength) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(kPacketNumber); + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + + QuicFrames frames = {QuicFrame(&ack_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 4 byte largest observed, 4 byte block length) + 0x4A, + // largest acked + 0x12, 0x34, 0x56, 0x78, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x12, 0x34, 0x56, 0x78, + // num timestamps. + 0x00, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (no ack blocks, 4 byte largest observed, 4 byte block length) + 0x4A, + // largest acked + 0x12, 0x34, 0x56, 0x78, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x12, 0x34, 0x56, 0x78, + // num timestamps. + 0x00, + }; + + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // largest acked + kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x78, + // Zero delta time. + kVarInt62OneByte + 0x00, + // Nr. of additional ack blocks + kVarInt62OneByte + 0x00, + // first ack block length. + kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x77, + }; + // clang-format on + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildAckFramePacketMultipleAckBlocks) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Use kSmallLargestObserved to make this test finished in a short time. + QuicAckFrame ack_frame = + InitAckFrame({{QuicPacketNumber(1), QuicPacketNumber(5)}, + {QuicPacketNumber(10), QuicPacketNumber(500)}, + {QuicPacketNumber(900), kSmallMissingPacket}, + {kSmallMissingPacket + 1, kSmallLargestObserved + 1}}); + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + + QuicFrames frames = {QuicFrame(&ack_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (has ack blocks, 2 byte largest observed, 2 byte block length) + 0x65, + // largest acked + 0x12, 0x34, + // Zero delta time. + 0x00, 0x00, + // num ack blocks ranges. + 0x04, + // first ack block length. + 0x00, 0x01, + // gap to next block. + 0x01, + // ack block length. + 0x0e, 0xaf, + // gap to next block. + 0xff, + // ack block length. + 0x00, 0x00, + // gap to next block. + 0x91, + // ack block length. + 0x01, 0xea, + // gap to next block. + 0x05, + // ack block length. + 0x00, 0x04, + // num timestamps. + 0x00, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + // (has ack blocks, 2 byte largest observed, 2 byte block length) + 0x65, + // largest acked + 0x12, 0x34, + // Zero delta time. + 0x00, 0x00, + // num ack blocks ranges. + 0x04, + // first ack block length. + 0x00, 0x01, + // gap to next block. + 0x01, + // ack block length. + 0x0e, 0xaf, + // gap to next block. + 0xff, + // ack block length. + 0x00, 0x00, + // gap to next block. + 0x91, + // ack block length. + 0x01, 0xea, + // gap to next block. + 0x05, + // ack block length. + 0x00, 0x04, + // num timestamps. + 0x00, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // largest acked + kVarInt62TwoBytes + 0x12, 0x34, + // Zero delta time. + kVarInt62OneByte + 0x00, + // num additional ack blocks. + kVarInt62OneByte + 0x03, + // first ack block length. + kVarInt62OneByte + 0x00, + + // gap to next block. + kVarInt62OneByte + 0x00, + // ack block length. + kVarInt62TwoBytes + 0x0e, 0xae, + + // gap to next block. + kVarInt62TwoBytes + 0x01, 0x8f, + // ack block length. + kVarInt62TwoBytes + 0x01, 0xe9, + + // gap to next block. + kVarInt62OneByte + 0x04, + // ack block length. + kVarInt62OneByte + 0x03, + }; + // clang-format on + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildAckFramePacketMaxAckBlocks) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Use kSmallLargestObservedto make this test finished in a short time. + QuicAckFrame ack_frame; + ack_frame.largest_acked = kSmallLargestObserved; + ack_frame.ack_delay_time = QuicTime::Delta::Zero(); + // 300 ack blocks. + for (size_t i = 2; i < 2 * 300; i += 2) { + ack_frame.packets.Add(QuicPacketNumber(i)); + } + ack_frame.packets.AddRange(QuicPacketNumber(600), kSmallLargestObserved + 1); + + QuicFrames frames = {QuicFrame(&ack_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (ack frame) + // (has ack blocks, 2 byte largest observed, 2 byte block length) + 0x65, + // largest acked + 0x12, 0x34, + // Zero delta time. + 0x00, 0x00, + // num ack blocks ranges. + 0xff, + // first ack block length. + 0x0f, 0xdd, + // 255 = 4 * 63 + 3 + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + // num timestamps. + 0x00, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (ack frame) + // (has ack blocks, 2 byte largest observed, 2 byte block length) + 0x65, + // largest acked + 0x12, 0x34, + // Zero delta time. + 0x00, 0x00, + // num ack blocks ranges. + 0xff, + // first ack block length. + 0x0f, 0xdd, + // 255 = 4 * 63 + 3 + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, + // num timestamps. + 0x00, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_ACK frame) + 0x02, + // largest acked + kVarInt62TwoBytes + 0x12, 0x34, + // Zero delta time. + kVarInt62OneByte + 0x00, + // num ack blocks ranges. + kVarInt62TwoBytes + 0x01, 0x2b, + // first ack block length. + kVarInt62TwoBytes + 0x0f, 0xdc, + // 255 added blocks of gap_size == 1, ack_size == 1 +#define V99AddedBLOCK kVarInt62OneByte + 0x00, kVarInt62OneByte + 0x00 + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, V99AddedBLOCK, + +#undef V99AddedBLOCK + }; + // clang-format on + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildNewStopWaitingPacket) { + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicStopWaitingFrame stop_waiting_frame; + stop_waiting_frame.least_unacked = kLeastUnacked; + + QuicFrames frames = {QuicFrame(stop_waiting_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stop waiting frame) + 0x06, + // least packet number awaiting an ack, delta from packet number. + 0x00, 0x00, 0x00, 0x08, + }; + + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildRstFramePacketQuic) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicRstStreamFrame rst_frame; + rst_frame.stream_id = kStreamId; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + rst_frame.ietf_error_code = 0x01; + } else { + rst_frame.error_code = static_cast(0x05060708); + } + rst_frame.byte_offset = 0x0807060504030201; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (rst stream frame) + 0x01, + // stream id + 0x01, 0x02, 0x03, 0x04, + // sent byte offset + 0x08, 0x07, 0x06, 0x05, + 0x04, 0x03, 0x02, 0x01, + // error code + 0x05, 0x06, 0x07, 0x08, + }; + + unsigned char packet46[] = { + // type (short packet, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (rst stream frame) + 0x01, + // stream id + 0x01, 0x02, 0x03, 0x04, + // sent byte offset + 0x08, 0x07, 0x06, 0x05, + 0x04, 0x03, 0x02, 0x01, + // error code + 0x05, 0x06, 0x07, 0x08, + }; + + unsigned char packet_ietf[] = { + // type (short packet, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_RST_STREAM frame) + 0x04, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // error code + kVarInt62OneByte + 0x01, + // sent byte offset + kVarInt62EightBytes + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01 + }; + // clang-format on + + QuicFrames frames = {QuicFrame(&rst_frame)}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildCloseFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicConnectionCloseFrame close_frame(framer_.transport_version(), + QUIC_INTERNAL_ERROR, NO_IETF_QUIC_ERROR, + "because I can", 0x05); + QuicFrames frames = {QuicFrame(&close_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (connection close frame) + 0x02, + // error code + 0x00, 0x00, 0x00, 0x01, + // error details length + 0x00, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (connection close frame) + 0x02, + // error code + 0x00, 0x00, 0x00, 0x01, + // error details length + 0x00, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n', + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_CONNECTION_CLOSE frame) + 0x1c, + // error code + kVarInt62OneByte + 0x01, + // Frame type within the CONNECTION_CLOSE frame + kVarInt62OneByte + 0x05, + // error details length + kVarInt62OneByte + 0x0f, + // error details + '1', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n', + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildCloseFramePacketExtendedInfo) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicConnectionCloseFrame close_frame( + framer_.transport_version(), + static_cast( + VersionHasIetfQuicFrames(framer_.transport_version()) ? 0x01 + : 0x05060708), + NO_IETF_QUIC_ERROR, "because I can", 0x05); + // Set this so that it is "there" for both Google QUIC and IETF QUIC + // framing. It better not show up for Google QUIC! + close_frame.quic_error_code = static_cast(0x4567); + + QuicFrames frames = {QuicFrame(&close_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (connection close frame) + 0x02, + // error code + 0x05, 0x06, 0x07, 0x08, + // error details length + 0x00, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (connection close frame) + 0x02, + // error code + 0x05, 0x06, 0x07, 0x08, + // error details length + 0x00, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n', + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_CONNECTION_CLOSE frame) + 0x1c, + // IETF error code INTERNAL_ERROR = 0x01 corresponding to + // QuicErrorCode::QUIC_INTERNAL_ERROR = 0x01. + kVarInt62OneByte + 0x01, + // Frame type within the CONNECTION_CLOSE frame + kVarInt62OneByte + 0x05, + // error details length + kVarInt62OneByte + 0x13, + // error details + '1', '7', '7', '6', + '7', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n' + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildTruncatedCloseFramePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicConnectionCloseFrame close_frame(framer_.transport_version(), + QUIC_INTERNAL_ERROR, NO_IETF_QUIC_ERROR, + std::string(2048, 'A'), 0x05); + QuicFrames frames = {QuicFrame(&close_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (connection close frame) + 0x02, + // error code + 0x00, 0x00, 0x00, 0x01, + // error details length + 0x01, 0x00, + // error details (truncated to 256 bytes) + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (connection close frame) + 0x02, + // error code + 0x00, 0x00, 0x00, 0x01, + // error details length + 0x01, 0x00, + // error details (truncated to 256 bytes) + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_CONNECTION_CLOSE frame) + 0x1c, + // error code + kVarInt62OneByte + 0x01, + // Frame type within the CONNECTION_CLOSE frame + kVarInt62OneByte + 0x05, + // error details length + kVarInt62TwoBytes + 0x01, 0x00, + // error details (truncated to 256 bytes) + '1', ':', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildApplicationCloseFramePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicConnectionCloseFrame app_close_frame; + app_close_frame.wire_error_code = 0x11; + app_close_frame.error_details = "because I can"; + app_close_frame.close_type = IETF_QUIC_APPLICATION_CONNECTION_CLOSE; + + QuicFrames frames = {QuicFrame(&app_close_frame)}; + + // clang-format off + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_APPLICATION_CLOSE frame) + 0x1d, + // error code + kVarInt62OneByte + 0x11, + // error details length + kVarInt62OneByte + 0x0f, + // error details, note that it includes an extended error code. + '0', ':', 'b', 'e', + 'c', 'a', 'u', 's', + 'e', ' ', 'I', ' ', + 'c', 'a', 'n', + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildTruncatedApplicationCloseFramePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicConnectionCloseFrame app_close_frame; + app_close_frame.wire_error_code = 0x11; + app_close_frame.error_details = std::string(2048, 'A'); + app_close_frame.close_type = IETF_QUIC_APPLICATION_CONNECTION_CLOSE; + // Setting to missing ensures that if it is missing, the extended + // code is not added to the text message. + app_close_frame.quic_error_code = QUIC_IETF_GQUIC_ERROR_MISSING; + + QuicFrames frames = {QuicFrame(&app_close_frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_APPLICATION_CLOSE frame) + 0x1d, + // error code + kVarInt62OneByte + 0x11, + // error details length + kVarInt62TwoBytes + 0x01, 0x00, + // error details (truncated to 256 bytes) + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildGoAwayPacket) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for Google QUIC. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicGoAwayFrame goaway_frame; + goaway_frame.error_code = static_cast(0x05060708); + goaway_frame.last_good_stream_id = kStreamId; + goaway_frame.reason_phrase = "because I can"; + + QuicFrames frames = {QuicFrame(&goaway_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (go away frame) + 0x03, + // error code + 0x05, 0x06, 0x07, 0x08, + // stream id + 0x01, 0x02, 0x03, 0x04, + // error details length + 0x00, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (go away frame) + 0x03, + // error code + 0x05, 0x06, 0x07, 0x08, + // stream id + 0x01, 0x02, 0x03, 0x04, + // error details length + 0x00, 0x0d, + // error details + 'b', 'e', 'c', 'a', + 'u', 's', 'e', ' ', + 'I', ' ', 'c', 'a', + 'n', + }; + + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildTruncatedGoAwayPacket) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for Google QUIC. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicGoAwayFrame goaway_frame; + goaway_frame.error_code = static_cast(0x05060708); + goaway_frame.last_good_stream_id = kStreamId; + goaway_frame.reason_phrase = std::string(2048, 'A'); + + QuicFrames frames = {QuicFrame(&goaway_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (go away frame) + 0x03, + // error code + 0x05, 0x06, 0x07, 0x08, + // stream id + 0x01, 0x02, 0x03, 0x04, + // error details length + 0x01, 0x00, + // error details (truncated to 256 bytes) + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (go away frame) + 0x03, + // error code + 0x05, 0x06, 0x07, 0x08, + // stream id + 0x01, 0x02, 0x03, 0x04, + // error details length + 0x01, 0x00, + // error details (truncated to 256 bytes) + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildWindowUpdatePacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicWindowUpdateFrame window_update_frame; + window_update_frame.stream_id = kStreamId; + window_update_frame.max_data = 0x1122334455667788; + + QuicFrames frames = {QuicFrame(window_update_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (window update frame) + 0x04, + // stream id + 0x01, 0x02, 0x03, 0x04, + // byte offset + 0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (window update frame) + 0x04, + // stream id + 0x01, 0x02, 0x03, 0x04, + // byte offset + 0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_MAX_STREAM_DATA frame) + 0x11, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // byte offset + kVarInt62EightBytes + 0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildMaxStreamDataPacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicWindowUpdateFrame window_update_frame; + window_update_frame.stream_id = kStreamId; + window_update_frame.max_data = 0x1122334455667788; + + QuicFrames frames = {QuicFrame(window_update_frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_MAX_STREAM_DATA frame) + 0x11, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // byte offset + kVarInt62EightBytes + 0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildMaxDataPacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicWindowUpdateFrame window_update_frame; + window_update_frame.stream_id = + QuicUtils::GetInvalidStreamId(framer_.transport_version()); + window_update_frame.max_data = 0x1122334455667788; + + QuicFrames frames = {QuicFrame(window_update_frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_MAX_DATA frame) + 0x10, + // byte offset + kVarInt62EightBytes + 0x11, 0x22, 0x33, 0x44, + 0x55, 0x66, 0x77, 0x88, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildBlockedPacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicBlockedFrame blocked_frame; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // For IETF QUIC, the stream ID must be for the frame + // to be a BLOCKED frame. if it's valid, it will be a + // STREAM_BLOCKED frame. + blocked_frame.stream_id = + QuicUtils::GetInvalidStreamId(framer_.transport_version()); + } else { + blocked_frame.stream_id = kStreamId; + } + blocked_frame.offset = kStreamOffset; + + QuicFrames frames = {QuicFrame(blocked_frame)}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (blocked frame) + 0x05, + // stream id + 0x01, 0x02, 0x03, 0x04, + }; + + unsigned char packet46[] = { + // type (short packet, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (blocked frame) + 0x05, + // stream id + 0x01, 0x02, 0x03, 0x04, + }; + + unsigned char packet_ietf[] = { + // type (short packet, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_DATA_BLOCKED frame) + 0x14, + // Offset + kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), p_size); +} + +TEST_P(QuicFramerTest, BuildPingPacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPingFrame())}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ping frame) + 0x07, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type + 0x07, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_PING frame) + 0x01, + }; + // clang-format on + + unsigned char* p = packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildHandshakeDonePacket) { + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicHandshakeDoneFrame())}; + + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (Handshake done frame) + 0x1e, + }; + // clang-format on + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildAckFrequencyPacket) { + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrequencyFrame ack_frequency_frame; + ack_frequency_frame.sequence_number = 3; + ack_frequency_frame.packet_tolerance = 5; + ack_frequency_frame.max_ack_delay = QuicTime::Delta::FromMicroseconds(0x3fff); + ack_frequency_frame.ignore_order = false; + QuicFrames frames = {QuicFrame(&ack_frequency_frame)}; + + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (Ack Frequency frame) + 0x40, 0xaf, + // sequence number + 0x03, + // packet tolerance + 0x05, + // max_ack_delay_us + 0x7f, 0xff, + // ignore_oder + 0x00 + }; + // clang-format on + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildMessagePacket) { + if (!VersionSupportsMessageFrames(framer_.transport_version())) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicMessageFrame frame(1, MemSliceFromString("message")); + QuicMessageFrame frame2(2, MemSliceFromString("message2")); + QuicFrames frames = {QuicFrame(&frame), QuicFrame(&frame2)}; + + // clang-format off + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (message frame) + 0x21, + // Length + 0x07, + // Message Data + 'm', 'e', 's', 's', 'a', 'g', 'e', + // frame type (message frame no length) + 0x20, + // Message Data + 'm', 'e', 's', 's', 'a', 'g', 'e', '2' + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_MESSAGE frame) + 0x31, + // Length + 0x07, + // Message Data + 'm', 'e', 's', 's', 'a', 'g', 'e', + // frame type (message frame no length) + 0x30, + // Message Data + 'm', 'e', 's', 's', 'a', 'g', 'e', '2' + }; + // clang-format on + + unsigned char* p = packet46; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + ABSL_ARRAYSIZE(packet46)); +} + +// Test that the MTU discovery packet is serialized correctly as a PING packet. +TEST_P(QuicFramerTest, BuildMtuDiscoveryPacket) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicMtuDiscoveryFrame())}; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ping frame) + 0x07, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type + 0x07, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_PING frame) + 0x01, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + unsigned char* p = packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(p), + framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) + : ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildPublicResetPacket) { + QuicPublicResetPacket reset_packet; + reset_packet.connection_id = FramerTestConnectionId(); + reset_packet.nonce_proof = kNonceProof; + + // clang-format off + unsigned char packet[] = { + // public flags (public reset, 8 byte ConnectionId) + 0x0E, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (1) + padding + 0x01, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + }; + // clang-format on + + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + std::unique_ptr data( + framer_.BuildPublicResetPacket(reset_packet)); + ASSERT_TRUE(data != nullptr); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildPublicResetPacketWithClientAddress) { + QuicPublicResetPacket reset_packet; + reset_packet.connection_id = FramerTestConnectionId(); + reset_packet.nonce_proof = kNonceProof; + reset_packet.client_address = + QuicSocketAddress(QuicIpAddress::Loopback4(), 0x1234); + + // clang-format off + unsigned char packet[] = { + // public flags (public reset, 8 byte ConnectionId) + 0x0E, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, + 0x76, 0x54, 0x32, 0x10, + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x02, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // tag kCADR + 'C', 'A', 'D', 'R', + // end offset 16 + 0x10, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + // client address + 0x02, 0x00, + 0x7F, 0x00, 0x00, 0x01, + 0x34, 0x12, + }; + // clang-format on + + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + std::unique_ptr data( + framer_.BuildPublicResetPacket(reset_packet)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, BuildPublicResetPacketWithEndpointId) { + QuicPublicResetPacket reset_packet; + reset_packet.connection_id = FramerTestConnectionId(); + reset_packet.nonce_proof = kNonceProof; + reset_packet.endpoint_id = "FakeServerId"; + + // The tag value map in CryptoHandshakeMessage is a std::map, so the two tags + // in the packet, kRNON and kEPID, have unspecified ordering w.r.t each other. + // clang-format off + unsigned char packet_variant1[] = { + // public flags (public reset, 8 byte ConnectionId) + 0x0E, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, + 0x76, 0x54, 0x32, 0x10, + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x02, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 8 + 0x08, 0x00, 0x00, 0x00, + // tag kEPID + 'E', 'P', 'I', 'D', + // end offset 20 + 0x14, 0x00, 0x00, 0x00, + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + // Endpoint ID + 'F', 'a', 'k', 'e', 'S', 'e', 'r', 'v', 'e', 'r', 'I', 'd', + }; + unsigned char packet_variant2[] = { + // public flags (public reset, 8 byte ConnectionId) + 0x0E, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, + 0x76, 0x54, 0x32, 0x10, + // message tag (kPRST) + 'P', 'R', 'S', 'T', + // num_entries (2) + padding + 0x02, 0x00, 0x00, 0x00, + // tag kEPID + 'E', 'P', 'I', 'D', + // end offset 12 + 0x0C, 0x00, 0x00, 0x00, + // tag kRNON + 'R', 'N', 'O', 'N', + // end offset 20 + 0x14, 0x00, 0x00, 0x00, + // Endpoint ID + 'F', 'a', 'k', 'e', 'S', 'e', 'r', 'v', 'e', 'r', 'I', 'd', + // nonce proof + 0x89, 0x67, 0x45, 0x23, + 0x01, 0xEF, 0xCD, 0xAB, + }; + // clang-format on + + if (framer_.version().HasIetfInvariantHeader()) { + return; + } + + std::unique_ptr data( + framer_.BuildPublicResetPacket(reset_packet)); + ASSERT_TRUE(data != nullptr); + + // Variant 1 ends with char 'd'. Variant 1 ends with char 0xAB. + if ('d' == data->data()[data->length() - 1]) { + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), + AsChars(packet_variant1), ABSL_ARRAYSIZE(packet_variant1)); + } else { + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), + AsChars(packet_variant2), ABSL_ARRAYSIZE(packet_variant2)); + } +} + +TEST_P(QuicFramerTest, BuildIetfStatelessResetPacket) { + // clang-format off + unsigned char packet[] = { + // 1st byte 01XX XXXX + 0x40, + // At least 4 bytes of random bytes. + 0x00, 0x00, 0x00, 0x00, + // stateless reset token + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f + }; + // clang-format on + + // Build the minimal stateless reset packet. + std::unique_ptr data( + framer_.BuildIetfStatelessResetPacket( + FramerTestConnectionId(), + QuicFramer::GetMinStatelessResetPacketLength() + 1, + kTestStatelessResetToken)); + ASSERT_TRUE(data); + EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength(), data->length()); + // Verify the first 2 bits are 01. + EXPECT_FALSE(data->data()[0] & FLAGS_LONG_HEADER); + EXPECT_TRUE(data->data()[0] & FLAGS_FIXED_BIT); + // Verify stateless reset token. + quiche::test::CompareCharArraysWithHexError( + "constructed packet", + data->data() + data->length() - kStatelessResetTokenLength, + kStatelessResetTokenLength, + AsChars(packet) + ABSL_ARRAYSIZE(packet) - kStatelessResetTokenLength, + kStatelessResetTokenLength); + + // Packets with length <= minimal stateless reset does not trigger stateless + // reset. + std::unique_ptr data2( + framer_.BuildIetfStatelessResetPacket( + FramerTestConnectionId(), + QuicFramer::GetMinStatelessResetPacketLength(), + kTestStatelessResetToken)); + ASSERT_FALSE(data2); + + // Do not send stateless reset >= minimal stateless reset + 1 + max + // connection ID length. + std::unique_ptr data3( + framer_.BuildIetfStatelessResetPacket(FramerTestConnectionId(), 1000, + kTestStatelessResetToken)); + ASSERT_TRUE(data3); + EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength() + 1 + + kQuicMaxConnectionIdWithLengthPrefixLength, + data3->length()); +} + +TEST_P(QuicFramerTest, BuildIetfStatelessResetPacketCallerProvidedRandomBytes) { + // clang-format off + unsigned char packet[] = { + // 1st byte 01XX XXXX + 0x7c, + // Random bytes + 0x7c, 0x7c, 0x7c, 0x7c, + // stateless reset token + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f + }; + // clang-format on + + // Build the minimal stateless reset packet with caller-provided random bytes. + MockRandom random; + auto generate_random_bytes = [](void* data, size_t len) { + std::string bytes(len, 0x7c); + memcpy(data, bytes.data(), bytes.size()); + }; + EXPECT_CALL(random, InsecureRandBytes(_, _)) + .WillOnce(testing::Invoke(generate_random_bytes)); + std::unique_ptr data( + framer_.BuildIetfStatelessResetPacket( + FramerTestConnectionId(), + QuicFramer::GetMinStatelessResetPacketLength() + 1, + kTestStatelessResetToken, &random)); + ASSERT_TRUE(data); + EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength(), data->length()); + // Verify the first 2 bits are 01. + EXPECT_FALSE(data->data()[0] & FLAGS_LONG_HEADER); + EXPECT_TRUE(data->data()[0] & FLAGS_FIXED_BIT); + // Verify the entire packet. + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, EncryptPacket) { + QuicPacketNumber packet_number = kPacketNumber; + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // redundancy + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 'm', 'n', 'o', 'p', + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // redundancy + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 'm', 'n', 'o', 'p', + }; + + unsigned char packet50[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // redundancy + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 'm', 'n', 'o', 'p', + 'q', 'r', 's', 't', + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasHeaderProtection()) { + p = packet50; + p_size = ABSL_ARRAYSIZE(packet50); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + std::unique_ptr raw(new QuicPacket( + AsChars(p), p_size, false, kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_4BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = framer_.EncryptPayload( + ENCRYPTION_INITIAL, packet_number, *raw, buffer, kMaxOutgoingPacketSize); + + ASSERT_NE(0u, encrypted_length); + EXPECT_TRUE(CheckEncryption(packet_number, raw.get())); +} + +// Regression test for b/158014497. +TEST_P(QuicFramerTest, EncryptEmptyPacket) { + auto packet = std::make_unique( + new char[100], 0, true, kPacket8ByteConnectionId, + kPacket0ByteConnectionId, + /*includes_version=*/true, + /*includes_diversification_nonce=*/true, PACKET_1BYTE_PACKET_NUMBER, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, + /*retry_token_length=*/0, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = 1; + EXPECT_QUIC_BUG( + { + encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, kPacketNumber, *packet, + buffer, kMaxOutgoingPacketSize); + EXPECT_EQ(0u, encrypted_length); + }, + "packet is shorter than associated data length"); +} + +TEST_P(QuicFramerTest, EncryptPacketWithVersionFlag) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketNumber packet_number = kPacketNumber; + // clang-format off + unsigned char packet[] = { + // public flags (version, 8 byte connection_id) + 0x29, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // version tag + 'Q', '.', '1', '0', + // packet number + 0x12, 0x34, 0x56, 0x78, + + // redundancy + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 'm', 'n', 'o', 'p', + }; + + unsigned char packet46[] = { + // type (long header with packet type ZERO_RTT_PROTECTED) + 0xD3, + // version tag + 'Q', '.', '1', '0', + // connection_id length + 0x50, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // redundancy + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 'm', 'n', 'o', 'p', + }; + + unsigned char packet50[] = { + // type (long header with packet type ZERO_RTT_PROTECTED) + 0xD3, + // version tag + 'Q', '.', '1', '0', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // redundancy + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 'm', 'n', 'o', 'p', + 'q', 'r', 's', 't', + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + // TODO(ianswett): see todo in previous test. + if (framer_.version().HasHeaderProtection()) { + p = packet50; + p_size = ABSL_ARRAYSIZE(packet50); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr raw(new QuicPacket( + AsChars(p), p_size, false, kPacket8ByteConnectionId, + kPacket0ByteConnectionId, kIncludeVersion, !kIncludeDiversificationNonce, + PACKET_4BYTE_PACKET_NUMBER, quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0)); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = framer_.EncryptPayload( + ENCRYPTION_INITIAL, packet_number, *raw, buffer, kMaxOutgoingPacketSize); + + ASSERT_NE(0u, encrypted_length); + EXPECT_TRUE(CheckEncryption(packet_number, raw.get())); +} + +TEST_P(QuicFramerTest, AckTruncationLargePacket) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This test is not applicable to this version; the range count is + // effectively unlimited + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame; + // Create a packet with just the ack. + ack_frame = MakeAckFrameWithAckBlocks(300, 0u); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + // Build an ack packet with truncation due to limit in number of nack ranges. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr raw_ack_packet(BuildDataPacket(header, frames)); + ASSERT_TRUE(raw_ack_packet != nullptr); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, + *raw_ack_packet, buffer, kMaxOutgoingPacketSize); + ASSERT_NE(0u, encrypted_length); + // Now make sure we can turn our ack packet back into an ack frame. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + ASSERT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(buffer, encrypted_length, false))); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + QuicAckFrame& processed_ack_frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(QuicPacketNumber(600u), LargestAcked(processed_ack_frame)); + ASSERT_EQ(256u, processed_ack_frame.packets.NumPacketsSlow()); + EXPECT_EQ(QuicPacketNumber(90u), processed_ack_frame.packets.Min()); + EXPECT_EQ(QuicPacketNumber(600u), processed_ack_frame.packets.Max()); +} + +// Regression test for b/150386368. +TEST_P(QuicFramerTest, IetfAckFrameTruncation) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame; + // Create a packet with just the ack. + ack_frame = MakeAckFrameWithGaps(/*gap_size=*/0xffffffff, + /*max_num_gaps=*/200, + /*largest_acked=*/kMaxIetfVarInt); + ack_frame.ecn_counters = QuicEcnCounts(100, 10000, 1000000); + QuicFrames frames = {QuicFrame(&ack_frame)}; + // Build an ACK packet. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr raw_ack_packet(BuildDataPacket(header, frames)); + ASSERT_TRUE(raw_ack_packet != nullptr); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, + *raw_ack_packet, buffer, kMaxOutgoingPacketSize); + ASSERT_NE(0u, encrypted_length); + // Now make sure we can turn our ack packet back into an ack frame. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + ASSERT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(buffer, encrypted_length, false))); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + QuicAckFrame& processed_ack_frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(QuicPacketNumber(kMaxIetfVarInt), + LargestAcked(processed_ack_frame)); + // Verify ACK frame gets truncated. + ASSERT_LT(processed_ack_frame.packets.NumPacketsSlow(), + ack_frame.packets.NumIntervals()); + EXPECT_EQ(157u, processed_ack_frame.packets.NumPacketsSlow()); + EXPECT_LT(processed_ack_frame.packets.NumIntervals(), + ack_frame.packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(kMaxIetfVarInt), + processed_ack_frame.packets.Max()); +} + +TEST_P(QuicFramerTest, AckTruncationSmallPacket) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This test is not applicable to this version; the range count is + // effectively unlimited + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // Create a packet with just the ack. + QuicAckFrame ack_frame; + ack_frame = MakeAckFrameWithAckBlocks(300, 0u); + QuicFrames frames = {QuicFrame(&ack_frame)}; + + // Build an ack packet with truncation due to limit in number of nack ranges. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr raw_ack_packet( + BuildDataPacket(header, frames, 500)); + ASSERT_TRUE(raw_ack_packet != nullptr); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, + *raw_ack_packet, buffer, kMaxOutgoingPacketSize); + ASSERT_NE(0u, encrypted_length); + // Now make sure we can turn our ack packet back into an ack frame. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + ASSERT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(buffer, encrypted_length, false))); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + QuicAckFrame& processed_ack_frame = *visitor_.ack_frames_[0]; + EXPECT_EQ(QuicPacketNumber(600u), LargestAcked(processed_ack_frame)); + ASSERT_EQ(240u, processed_ack_frame.packets.NumPacketsSlow()); + EXPECT_EQ(QuicPacketNumber(122u), processed_ack_frame.packets.Min()); + EXPECT_EQ(QuicPacketNumber(600u), processed_ack_frame.packets.Max()); +} + +TEST_P(QuicFramerTest, CleanTruncation) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // This test is not applicable to this version; the range count is + // effectively unlimited + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicAckFrame ack_frame = InitAckFrame(201); + + // Create a packet with just the ack. + QuicFrames frames = {QuicFrame(&ack_frame)}; + if (framer_.version().HasHeaderProtection()) { + frames.push_back(QuicFrame(QuicPaddingFrame(12))); + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr raw_ack_packet(BuildDataPacket(header, frames)); + ASSERT_TRUE(raw_ack_packet != nullptr); + + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, + *raw_ack_packet, buffer, kMaxOutgoingPacketSize); + ASSERT_NE(0u, encrypted_length); + + // Now make sure we can turn our ack packet back into an ack frame. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + ASSERT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(buffer, encrypted_length, false))); + + // Test for clean truncation of the ack by comparing the length of the + // original packets to the re-serialized packets. + frames.clear(); + frames.push_back(QuicFrame(visitor_.ack_frames_[0].get())); + if (framer_.version().HasHeaderProtection()) { + frames.push_back(QuicFrame(*visitor_.padding_frames_[0].get())); + } + + size_t original_raw_length = raw_ack_packet->length(); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + raw_ack_packet = BuildDataPacket(header, frames); + ASSERT_TRUE(raw_ack_packet != nullptr); + EXPECT_EQ(original_raw_length, raw_ack_packet->length()); + ASSERT_TRUE(raw_ack_packet != nullptr); +} + +TEST_P(QuicFramerTest, StopPacketProcessing) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + + // frame type (ack frame) + 0x40, + // least packet number awaiting an ack + 0x12, 0x34, 0x56, 0x78, + 0x9A, 0xA0, + // largest observed packet number + 0x12, 0x34, 0x56, 0x78, + 0x9A, 0xBF, + // num missing packets + 0x01, + // missing packet + 0x12, 0x34, 0x56, 0x78, + 0x9A, 0xBE, + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (stream frame with fin) + 0xFF, + // stream id + 0x01, 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + + // frame type (ack frame) + 0x40, + // least packet number awaiting an ack + 0x12, 0x34, 0x56, 0x78, + 0x9A, 0xA0, + // largest observed packet number + 0x12, 0x34, 0x56, 0x78, + 0x9A, 0xBF, + // num missing packets + 0x01, + // missing packet + 0x12, 0x34, 0x56, 0x78, + 0x9A, 0xBE, + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STREAM frame with fin, length, and offset bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62TwoBytes + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + + // frame type (ack frame) + 0x0d, + // largest observed packet number + kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x78, + // Delta time + kVarInt62OneByte + 0x00, + // Ack Block count + kVarInt62OneByte + 0x01, + // First block size (one packet) + kVarInt62OneByte + 0x00, + + // Next gap size & ack. Missing all preceding packets + kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x77, + kVarInt62OneByte + 0x00, + }; + // clang-format on + + MockFramerVisitor visitor; + framer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnPacket()); + EXPECT_CALL(visitor, OnPacketHeader(_)); + EXPECT_CALL(visitor, OnStreamFrame(_)).WillOnce(Return(false)); + EXPECT_CALL(visitor, OnPacketComplete()); + EXPECT_CALL(visitor, OnUnauthenticatedPublicHeader(_)).WillOnce(Return(true)); + EXPECT_CALL(visitor, OnUnauthenticatedHeader(_)).WillOnce(Return(true)); + EXPECT_CALL(visitor, OnDecryptedPacket(_, _)); + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); +} + +static char kTestString[] = "At least 20 characters."; +static QuicStreamId kTestQuicStreamId = 1; + +MATCHER_P(ExpectedStreamFrame, version, "") { + return (arg.stream_id == kTestQuicStreamId || + QuicUtils::IsCryptoStreamId(version.transport_version, + arg.stream_id)) && + !arg.fin && arg.offset == 0 && + std::string(arg.data_buffer, arg.data_length) == kTestString; + // FIN is hard-coded false in ConstructEncryptedPacket. + // Offset 0 is hard-coded in ConstructEncryptedPacket. +} + +// Verify that the packet returned by ConstructEncryptedPacket() can be properly +// parsed by the framer. +TEST_P(QuicFramerTest, ConstructEncryptedPacket) { + // Since we are using ConstructEncryptedPacket, we have to set the framer's + // crypto to be Null. + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique( + (uint8_t)ENCRYPTION_FORWARD_SECURE)); + } else { + framer_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique( + (uint8_t)ENCRYPTION_FORWARD_SECURE)); + } + ParsedQuicVersionVector versions; + versions.push_back(framer_.version()); + std::unique_ptr packet(ConstructEncryptedPacket( + TestConnectionId(), EmptyQuicConnectionId(), false, false, + kTestQuicStreamId, kTestString, CONNECTION_ID_PRESENT, + CONNECTION_ID_ABSENT, PACKET_4BYTE_PACKET_NUMBER, &versions)); + + MockFramerVisitor visitor; + framer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnPacket()).Times(1); + EXPECT_CALL(visitor, OnUnauthenticatedPublicHeader(_)) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(visitor, OnUnauthenticatedHeader(_)) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(visitor, OnPacketHeader(_)).Times(1).WillOnce(Return(true)); + EXPECT_CALL(visitor, OnDecryptedPacket(_, _)).Times(1); + EXPECT_CALL(visitor, OnError(_)).Times(0); + EXPECT_CALL(visitor, OnStreamFrame(_)).Times(0); + if (!QuicVersionUsesCryptoFrames(framer_.version().transport_version)) { + EXPECT_CALL(visitor, OnStreamFrame(ExpectedStreamFrame(framer_.version()))) + .Times(1); + } else { + EXPECT_CALL(visitor, OnCryptoFrame(_)).Times(1); + } + EXPECT_CALL(visitor, OnPacketComplete()).Times(1); + + EXPECT_TRUE(framer_.ProcessPacket(*packet)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); +} + +// Verify that the packet returned by ConstructMisFramedEncryptedPacket() +// does cause the framer to return an error. +TEST_P(QuicFramerTest, ConstructMisFramedEncryptedPacket) { + // Since we are using ConstructEncryptedPacket, we have to set the framer's + // crypto to be Null. + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + } + std::unique_ptr packet(ConstructMisFramedEncryptedPacket( + TestConnectionId(), EmptyQuicConnectionId(), false, false, + kTestQuicStreamId, kTestString, CONNECTION_ID_PRESENT, + CONNECTION_ID_ABSENT, PACKET_4BYTE_PACKET_NUMBER, framer_.version(), + Perspective::IS_CLIENT)); + + MockFramerVisitor visitor; + framer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnPacket()).Times(1); + EXPECT_CALL(visitor, OnUnauthenticatedPublicHeader(_)) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(visitor, OnUnauthenticatedHeader(_)) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(visitor, OnPacketHeader(_)).Times(1); + EXPECT_CALL(visitor, OnDecryptedPacket(_, _)).Times(1); + EXPECT_CALL(visitor, OnError(_)).Times(1); + EXPECT_CALL(visitor, OnStreamFrame(_)).Times(0); + EXPECT_CALL(visitor, OnPacketComplete()).Times(0); + + EXPECT_FALSE(framer_.ProcessPacket(*packet)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_FRAME_DATA)); +} + +TEST_P(QuicFramerTest, IetfBlockedFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_DATA_BLOCKED) + {"", + {0x14}}, + // blocked offset + {"Can not read blocked offset.", + {kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamOffset, visitor_.blocked_frame_.offset); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_BLOCKED_DATA); +} + +TEST_P(QuicFramerTest, BuildIetfBlockedPacket) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicBlockedFrame frame; + frame.stream_id = QuicUtils::GetInvalidStreamId(framer_.transport_version()); + frame.offset = kStreamOffset; + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_DATA_BLOCKED) + 0x14, + // Offset + kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, IetfStreamBlockedFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_STREAM_DATA_BLOCKED) + {"", + {0x15}}, + // blocked offset + {"Unable to read IETF_STREAM_DATA_BLOCKED frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + {"Can not read stream blocked offset.", + {kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.blocked_frame_.stream_id); + EXPECT_EQ(kStreamOffset, visitor_.blocked_frame_.offset); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_STREAM_BLOCKED_DATA); +} + +TEST_P(QuicFramerTest, BuildIetfStreamBlockedPacket) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicBlockedFrame frame; + frame.stream_id = kStreamId; + frame.offset = kStreamOffset; + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STREAM_DATA_BLOCKED) + 0x15, + // Stream ID + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // Offset + kVarInt62EightBytes + 0x3a, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BiDiMaxStreamsFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_MAX_STREAMS_BIDIRECTIONAL) + {"", + {0x12}}, + // max. streams + {"Unable to read IETF_MAX_STREAMS_BIDIRECTIONAL frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); + EXPECT_FALSE(visitor_.max_streams_frame_.unidirectional); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); +} + +TEST_P(QuicFramerTest, UniDiMaxStreamsFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // Test runs in client mode, no connection id + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_MAX_STREAMS_UNIDIRECTIONAL) + {"", + {0x13}}, + // max. streams + {"Unable to read IETF_MAX_STREAMS_UNIDIRECTIONAL frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket0ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); + EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); +} + +TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_MAX_STREAMS_UNIDIRECTIONAL) + {"", + {0x13}}, + // max. streams + {"Unable to read IETF_MAX_STREAMS_UNIDIRECTIONAL frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); + EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); +} + +TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // Test runs in client mode, no connection id + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_MAX_STREAMS_UNIDIRECTIONAL) + {"", + {0x13}}, + // max. streams + {"Unable to read IETF_MAX_STREAMS_UNIDIRECTIONAL frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket0ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); + EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); +} + +// The following four tests ensure that the framer can deserialize a stream +// count that is large enough to cause the resulting stream ID to exceed the +// current implementation limit(32 bits). The intent is that when this happens, +// the stream limit is pegged to the maximum supported value. There are four +// tests, for the four combinations of uni- and bi-directional, server- and +// client- initiated. +TEST_P(QuicFramerTest, BiDiMaxStreamsFrameTooBig) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x9A, 0xBC, + // frame type (IETF_MAX_STREAMS_BIDIRECTIONAL) + 0x12, + + // max. streams. Max stream ID allowed is 0xffffffff + // This encodes a count of 0x40000000, leading to stream + // IDs in the range 0x1 00000000 to 0x1 00000003. + kVarInt62EightBytes + 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00 + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0x40000000u, visitor_.max_streams_frame_.stream_count); + EXPECT_FALSE(visitor_.max_streams_frame_.unidirectional); +} + +TEST_P(QuicFramerTest, ClientBiDiMaxStreamsFrameTooBig) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // Test runs in client mode, no connection id + // packet number + 0x12, 0x34, 0x9A, 0xBC, + // frame type (IETF_MAX_STREAMS_BIDIRECTIONAL) + 0x12, + + // max. streams. Max stream ID allowed is 0xffffffff + // This encodes a count of 0x40000000, leading to stream + // IDs in the range 0x1 00000000 to 0x1 00000003. + kVarInt62EightBytes + 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00 + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket0ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0x40000000u, visitor_.max_streams_frame_.stream_count); + EXPECT_FALSE(visitor_.max_streams_frame_.unidirectional); +} + +TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrameTooBig) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x9A, 0xBC, + // frame type (IETF_MAX_STREAMS_UNIDIRECTIONAL) + 0x13, + + // max. streams. Max stream ID allowed is 0xffffffff + // This encodes a count of 0x40000000, leading to stream + // IDs in the range 0x1 00000000 to 0x1 00000003. + kVarInt62EightBytes + 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00 + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0x40000000u, visitor_.max_streams_frame_.stream_count); + EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); +} + +TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrameTooBig) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // Test runs in client mode, no connection id + // packet number + 0x12, 0x34, 0x9A, 0xBC, + // frame type (IETF_MAX_STREAMS_UNDIRECTIONAL) + 0x13, + + // max. streams. Max stream ID allowed is 0xffffffff + // This encodes a count of 0x40000000, leading to stream + // IDs in the range 0x1 00000000 to 0x1 00000003. + kVarInt62EightBytes + 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00 + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket0ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0x40000000u, visitor_.max_streams_frame_.stream_count); + EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); +} + +// Specifically test that count==0 is accepted. +TEST_P(QuicFramerTest, MaxStreamsFrameZeroCount) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x9A, 0xBC, + // frame type (IETF_MAX_STREAMS_BIDIRECTIONAL) + 0x12, + // max. streams == 0. + kVarInt62OneByte + 0x00 + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); +} + +TEST_P(QuicFramerTest, ServerBiDiStreamsBlockedFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_MAX_STREAMS_UNIDIRECTIONAL frame) + {"", + {0x13}}, + // stream count + {"Unable to read IETF_MAX_STREAMS_UNIDIRECTIONAL frame stream id/count.", + {kVarInt62OneByte + 0x00}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.max_streams_frame_.stream_count); + EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); + + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); +} + +TEST_P(QuicFramerTest, BiDiStreamsBlockedFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_STREAMS_BLOCKED_BIDIRECTIONAL frame) + {"", + {0x16}}, + // stream id + {"Unable to read IETF_STREAMS_BLOCKED_BIDIRECTIONAL " + "frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.streams_blocked_frame_.stream_count); + EXPECT_FALSE(visitor_.streams_blocked_frame_.unidirectional); + + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); +} + +TEST_P(QuicFramerTest, UniDiStreamsBlockedFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_STREAMS_BLOCKED_UNIDIRECTIONAL frame) + {"", + {0x17}}, + // stream id + {"Unable to read IETF_STREAMS_BLOCKED_UNIDIRECTIONAL " + "frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.streams_blocked_frame_.stream_count); + EXPECT_TRUE(visitor_.streams_blocked_frame_.unidirectional); + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); +} + +TEST_P(QuicFramerTest, ClientUniDiStreamsBlockedFrame) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // Test runs in client mode, no connection id + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_STREAMS_BLOCKED_UNIDIRECTIONAL frame) + {"", + {0x17}}, + // stream id + {"Unable to read IETF_STREAMS_BLOCKED_UNIDIRECTIONAL " + "frame stream id/count.", + {kVarInt62OneByte + 0x03}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket0ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(3u, visitor_.streams_blocked_frame_.stream_count); + EXPECT_TRUE(visitor_.streams_blocked_frame_.unidirectional); + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); +} + +// Check that when we get a STREAMS_BLOCKED frame that specifies too large +// a stream count, we reject with an appropriate error. There is no need to +// check for different combinations of Uni/Bi directional and client/server +// initiated; the logic does not take these into account. +TEST_P(QuicFramerTest, StreamsBlockedFrameTooBig) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // Test runs in client mode, no connection id + // packet number + 0x12, 0x34, 0x9A, 0xBC, + // frame type (IETF_STREAMS_BLOCKED_BIDIRECTIONAL) + 0x16, + + // max. streams. Max stream ID allowed is 0xffffffff + // This encodes a count of 0x40000000, leading to stream + // IDs in the range 0x1 00000000 to 0x1 00000003. + kVarInt62EightBytes + 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x01 + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_STREAMS_BLOCKED_DATA)); + EXPECT_EQ(framer_.detailed_error(), + "STREAMS_BLOCKED stream count exceeds implementation limit."); +} + +// Specifically test that count==0 is accepted. +TEST_P(QuicFramerTest, StreamsBlockedFrameZeroCount) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_STREAMS_BLOCKED_UNIDIRECTIONAL frame) + {"", + {0x17}}, + // stream id + {"Unable to read IETF_STREAMS_BLOCKED_UNIDIRECTIONAL " + "frame stream id/count.", + {kVarInt62OneByte + 0x00}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.streams_blocked_frame_.stream_count); + EXPECT_TRUE(visitor_.streams_blocked_frame_.unidirectional); + + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); +} + +TEST_P(QuicFramerTest, BuildBiDiStreamsBlockedPacket) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicStreamsBlockedFrame frame; + frame.stream_count = 3; + frame.unidirectional = false; + + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STREAMS_BLOCKED_BIDIRECTIONAL frame) + 0x16, + // Stream count + kVarInt62OneByte + 0x03 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildUniStreamsBlockedPacket) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicStreamsBlockedFrame frame; + frame.stream_count = 3; + frame.unidirectional = true; + + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STREAMS_BLOCKED_UNIDIRECTIONAL frame) + 0x17, + // Stream count + kVarInt62OneByte + 0x03 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildBiDiMaxStreamsPacket) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicMaxStreamsFrame frame; + frame.stream_count = 3; + frame.unidirectional = false; + + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_MAX_STREAMS_BIDIRECTIONAL frame) + 0x12, + // Stream count + kVarInt62OneByte + 0x03 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, BuildUniDiMaxStreamsPacket) { + // This frame is only for IETF QUIC. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + // This test runs in client mode. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicMaxStreamsFrame frame; + frame.stream_count = 3; + frame.unidirectional = true; + + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_MAX_STREAMS_UNIDIRECTIONAL frame) + 0x13, + // Stream count + kVarInt62OneByte + 0x03 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, NewConnectionIdFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_NEW_CONNECTION_ID frame) + {"", + {0x18}}, + // error code + {"Unable to read new connection ID frame sequence number.", + {kVarInt62OneByte + 0x11}}, + {"Unable to read new connection ID frame retire_prior_to.", + {kVarInt62OneByte + 0x09}}, + {"Unable to read new connection ID frame connection id.", + {0x08}}, // connection ID length + {"Unable to read new connection ID frame connection id.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11}}, + {"Can not read new connection ID frame reset token.", + {0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + + EXPECT_EQ(FramerTestConnectionIdPlusOne(), + visitor_.new_connection_id_.connection_id); + EXPECT_EQ(0x11u, visitor_.new_connection_id_.sequence_number); + EXPECT_EQ(0x09u, visitor_.new_connection_id_.retire_prior_to); + EXPECT_EQ(kTestStatelessResetToken, + visitor_.new_connection_id_.stateless_reset_token); + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_NEW_CONNECTION_ID_DATA); +} + +TEST_P(QuicFramerTest, NewConnectionIdFrameVariableLength) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_NEW_CONNECTION_ID frame) + {"", + {0x18}}, + // error code + {"Unable to read new connection ID frame sequence number.", + {kVarInt62OneByte + 0x11}}, + {"Unable to read new connection ID frame retire_prior_to.", + {kVarInt62OneByte + 0x0a}}, + {"Unable to read new connection ID frame connection id.", + {0x09}}, // connection ID length + {"Unable to read new connection ID frame connection id.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0x42}}, + {"Can not read new connection ID frame reset token.", + {0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + + EXPECT_EQ(FramerTestConnectionIdNineBytes(), + visitor_.new_connection_id_.connection_id); + EXPECT_EQ(0x11u, visitor_.new_connection_id_.sequence_number); + EXPECT_EQ(0x0au, visitor_.new_connection_id_.retire_prior_to); + EXPECT_EQ(kTestStatelessResetToken, + visitor_.new_connection_id_.stateless_reset_token); + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_NEW_CONNECTION_ID_DATA); +} + +// Verifies that parsing a NEW_CONNECTION_ID frame with a length above the +// specified maximum fails. +TEST_P(QuicFramerTest, InvalidLongNewConnectionIdFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // The NEW_CONNECTION_ID frame is only for IETF QUIC. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_NEW_CONNECTION_ID frame) + {"", + {0x18}}, + // error code + {"Unable to read new connection ID frame sequence number.", + {kVarInt62OneByte + 0x11}}, + {"Unable to read new connection ID frame retire_prior_to.", + {kVarInt62OneByte + 0x0b}}, + {"Unable to read new connection ID frame connection id.", + {0x40}}, // connection ID length + {"Unable to read new connection ID frame connection id.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + 0xF0, 0xD2, 0xB4, 0x96, 0x78, 0x5A, 0x3C, 0x1E, + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + 0xF0, 0xD2, 0xB4, 0x96, 0x78, 0x5A, 0x3C, 0x1E, + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + 0xF0, 0xD2, 0xB4, 0x96, 0x78, 0x5A, 0x3C, 0x1E, + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + 0xF0, 0xD2, 0xB4, 0x96, 0x78, 0x5A, 0x3C, 0x1E}}, + {"Can not read new connection ID frame reset token.", + {0xb5, 0x69, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_NEW_CONNECTION_ID_DATA)); + EXPECT_EQ("Invalid new connection ID length for version.", + framer_.detailed_error()); +} + +// Verifies that parsing a NEW_CONNECTION_ID frame with an invalid +// retire-prior-to fails. +TEST_P(QuicFramerTest, InvalidRetirePriorToNewConnectionIdFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC only. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_NEW_CONNECTION_ID frame) + {"", + {0x18}}, + // sequence number + {"Unable to read new connection ID frame sequence number.", + {kVarInt62OneByte + 0x11}}, + {"Unable to read new connection ID frame retire_prior_to.", + {kVarInt62OneByte + 0x1b}}, + {"Unable to read new connection ID frame connection id length.", + {0x08}}, // connection ID length + {"Unable to read new connection ID frame connection id.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11}}, + {"Can not read new connection ID frame reset token.", + {0xb5, 0x69, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_NEW_CONNECTION_ID_DATA)); + EXPECT_EQ("Retire_prior_to > sequence_number.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, BuildNewConnectionIdFramePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC only. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicNewConnectionIdFrame frame; + frame.sequence_number = 0x11; + frame.retire_prior_to = 0x0c; + // Use this value to force a 4-byte encoded variable length connection ID + // in the frame. + frame.connection_id = FramerTestConnectionIdPlusOne(); + frame.stateless_reset_token = kTestStatelessResetToken; + + QuicFrames frames = {QuicFrame(&frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_NEW_CONNECTION_ID frame) + 0x18, + // sequence number + kVarInt62OneByte + 0x11, + // retire_prior_to + kVarInt62OneByte + 0x0c, + // new connection id length + 0x08, + // new connection id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // stateless reset token + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, + 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, NewTokenFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC only. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_NEW_TOKEN frame) + {"", + {0x07}}, + // Length + {"Unable to read new token length.", + {kVarInt62OneByte + 0x08}}, + {"Unable to read new token data.", + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}} + }; + // clang-format on + uint8_t expected_token_value[] = {0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07}; + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + + EXPECT_EQ(sizeof(expected_token_value), visitor_.new_token_.token.length()); + EXPECT_EQ(0, memcmp(expected_token_value, visitor_.new_token_.token.data(), + sizeof(expected_token_value))); + + CheckFramingBoundaries(packet, QUIC_INVALID_NEW_TOKEN); +} + +TEST_P(QuicFramerTest, BuildNewTokenFramePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for IETF QUIC only. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + uint8_t expected_token_value[] = {0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07}; + + QuicNewTokenFrame frame(0, + absl::string_view((const char*)(expected_token_value), + sizeof(expected_token_value))); + + QuicFrames frames = {QuicFrame(&frame)}; + + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_NEW_TOKEN frame) + 0x07, + // Length and token + kVarInt62OneByte + 0x08, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet), + ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicFramerTest, IetfStopSendingFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Stop sending frame is IETF QUIC only. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_STOP_SENDING frame) + {"", + {0x05}}, + // stream id + {"Unable to read IETF_STOP_SENDING frame stream id/count.", + {kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04}}, + {"Unable to read stop sending application error code.", + {kVarInt62FourBytes + 0x00, 0x00, 0x76, 0x54}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(kStreamId, visitor_.stop_sending_frame_.stream_id); + EXPECT_EQ(QUIC_STREAM_UNKNOWN_APPLICATION_ERROR_CODE, + visitor_.stop_sending_frame_.error_code); + EXPECT_EQ(static_cast(0x7654), + visitor_.stop_sending_frame_.ietf_error_code); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_STOP_SENDING_FRAME_DATA); +} + +TEST_P(QuicFramerTest, BuildIetfStopSendingPacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Stop sending frame is IETF QUIC only. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicStopSendingFrame frame; + frame.stream_id = kStreamId; + frame.error_code = QUIC_STREAM_ENCODER_STREAM_ERROR; + frame.ietf_error_code = + static_cast(QuicHttpQpackErrorCode::ENCODER_STREAM_ERROR); + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_STOP_SENDING frame) + 0x05, + // Stream ID + kVarInt62FourBytes + 0x01, 0x02, 0x03, 0x04, + // Application error code + kVarInt62TwoBytes + 0x02, 0x01, + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, IetfPathChallengeFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Path Challenge frame is IETF QUIC only. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_PATH_CHALLENGE) + {"", + {0x1a}}, + // data + {"Can not read path challenge data.", + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(QuicPathFrameBuffer({{0, 1, 2, 3, 4, 5, 6, 7}}), + visitor_.path_challenge_frame_.data_buffer); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_PATH_CHALLENGE_DATA); +} + +TEST_P(QuicFramerTest, BuildIetfPathChallengePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Path Challenge frame is IETF QUIC only. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicPathChallengeFrame frame; + frame.data_buffer = QuicPathFrameBuffer({{0, 1, 2, 3, 4, 5, 6, 7}}); + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_PATH_CHALLENGE) + 0x1a, + // Data + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, IetfPathResponseFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Path response frame is IETF QUIC only. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (IETF_PATH_RESPONSE) + {"", + {0x1b}}, + // data + {"Can not read path response data.", + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(QuicPathFrameBuffer({{0, 1, 2, 3, 4, 5, 6, 7}}), + visitor_.path_response_frame_.data_buffer); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_PATH_RESPONSE_DATA); +} + +TEST_P(QuicFramerTest, BuildIetfPathResponsePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Path response frame is IETF QUIC only + return; + } + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicPathResponseFrame frame; + frame.data_buffer = QuicPathFrameBuffer({{0, 1, 2, 3, 4, 5, 6, 7}}); + QuicFrames frames = {QuicFrame(frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_PATH_RESPONSE) + 0x1b, + // Data + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, GetRetransmittableControlFrameSize) { + QuicRstStreamFrame rst_stream(1, 3, QUIC_STREAM_CANCELLED, 1024); + EXPECT_EQ(QuicFramer::GetRstStreamFrameSize(framer_.transport_version(), + rst_stream), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(&rst_stream))); + + std::string error_detail(2048, 'e'); + QuicConnectionCloseFrame connection_close(framer_.transport_version(), + QUIC_NETWORK_IDLE_TIMEOUT, + NO_IETF_QUIC_ERROR, error_detail, + /*transport_close_frame_type=*/0); + + EXPECT_EQ(QuicFramer::GetConnectionCloseFrameSize(framer_.transport_version(), + connection_close), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(&connection_close))); + + QuicGoAwayFrame goaway(2, QUIC_PEER_GOING_AWAY, 3, error_detail); + EXPECT_EQ(QuicFramer::GetMinGoAwayFrameSize() + 256, + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(&goaway))); + + QuicWindowUpdateFrame window_update(3, 3, 1024); + EXPECT_EQ(QuicFramer::GetWindowUpdateFrameSize(framer_.transport_version(), + window_update), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(window_update))); + + QuicBlockedFrame blocked(4, 3, 1024); + EXPECT_EQ( + QuicFramer::GetBlockedFrameSize(framer_.transport_version(), blocked), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(blocked))); + + // Following frames are IETF QUIC frames only. + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + + QuicNewConnectionIdFrame new_connection_id(5, TestConnectionId(), 1, + kTestStatelessResetToken, 1); + EXPECT_EQ(QuicFramer::GetNewConnectionIdFrameSize(new_connection_id), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(&new_connection_id))); + + QuicMaxStreamsFrame max_streams(6, 3, /*unidirectional=*/false); + EXPECT_EQ(QuicFramer::GetMaxStreamsFrameSize(framer_.transport_version(), + max_streams), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(max_streams))); + + QuicStreamsBlockedFrame streams_blocked(7, 3, /*unidirectional=*/false); + EXPECT_EQ(QuicFramer::GetStreamsBlockedFrameSize(framer_.transport_version(), + streams_blocked), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(streams_blocked))); + + QuicPathFrameBuffer buffer = { + {0x80, 0x91, 0xa2, 0xb3, 0xc4, 0xd5, 0xe5, 0xf7}}; + QuicPathResponseFrame path_response_frame(8, buffer); + EXPECT_EQ(QuicFramer::GetPathResponseFrameSize(path_response_frame), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(path_response_frame))); + + QuicPathChallengeFrame path_challenge_frame(9, buffer); + EXPECT_EQ(QuicFramer::GetPathChallengeFrameSize(path_challenge_frame), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(path_challenge_frame))); + + QuicStopSendingFrame stop_sending_frame(10, 3, QUIC_STREAM_CANCELLED); + EXPECT_EQ(QuicFramer::GetStopSendingFrameSize(stop_sending_frame), + QuicFramer::GetRetransmittableControlFrameSize( + framer_.transport_version(), QuicFrame(stop_sending_frame))); +} + +// A set of tests to ensure that bad frame-type encodings +// are properly detected and handled. +// First, four tests to see that unknown frame types generate +// a QUIC_INVALID_FRAME_DATA error with detailed information +// "Illegal frame type." This regardless of the encoding of the type +// (1/2/4/8 bytes). +// This only for version 99. +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorUnknown1Byte) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (unknown value, single-byte encoding) + {"", + {0x38}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_FRAME_DATA)); + EXPECT_EQ("Illegal frame type.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorUnknown2Bytes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (unknown value, two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x01, 0x38}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_FRAME_DATA)); + EXPECT_EQ("Illegal frame type.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorUnknown4Bytes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (unknown value, four-byte encoding) + {"", + {kVarInt62FourBytes + 0x01, 0x00, 0x00, 0x38}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_FRAME_DATA)); + EXPECT_EQ("Illegal frame type.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorUnknown8Bytes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (unknown value, eight-byte encoding) + {"", + {kVarInt62EightBytes + 0x01, 0x00, 0x00, 0x01, 0x02, 0x34, 0x56, 0x38}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_FRAME_DATA)); + EXPECT_EQ("Illegal frame type.", framer_.detailed_error()); +} + +// Three tests to check that known frame types that are not minimally +// encoded generate IETF_QUIC_PROTOCOL_VIOLATION errors with detailed +// information "Frame type not minimally encoded." +// Look at the frame-type encoded in 2, 4, and 8 bytes. +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorKnown2Bytes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (Blocked, two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x08}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + EXPECT_EQ("Frame type not minimally encoded.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorKnown4Bytes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (Blocked, four-byte encoding) + {"", + {kVarInt62FourBytes + 0x00, 0x00, 0x00, 0x08}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + EXPECT_EQ("Frame type not minimally encoded.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorKnown8Bytes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (Blocked, eight-byte encoding) + {"", + {kVarInt62EightBytes + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + EXPECT_EQ("Frame type not minimally encoded.", framer_.detailed_error()); +} + +// Tests to check that all known IETF frame types that are not minimally +// encoded generate IETF_QUIC_PROTOCOL_VIOLATION errors with detailed +// information "Frame type not minimally encoded." +// Just look at 2-byte encoding. +TEST_P(QuicFramerTest, IetfFrameTypeEncodingErrorKnown2BytesAllTypes) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // Only IETF QUIC encodes frame types such that this test is relevant. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + // clang-format off + PacketFragments packets[] = { + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x00}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x01}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x02}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x03}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x04}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x05}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x06}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x07}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x08}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x09}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x0a}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x0b}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x0c}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x0d}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x0e}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x0f}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x10}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x11}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x12}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x13}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x14}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x15}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x16}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x17}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x18}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x20}} + }, + { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x9A, 0xBC}}, + // frame type (two-byte encoding) + {"", + {kVarInt62TwoBytes + 0x00, 0x21}} + }, + }; + // clang-format on + + for (PacketFragments& packet : packets) { + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + EXPECT_EQ("Frame type not minimally encoded.", framer_.detailed_error()); + } +} + +TEST_P(QuicFramerTest, RetireConnectionIdFrame) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for version 99. + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + {0x43}}, + // connection_id + {"", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"", + {0x12, 0x34, 0x56, 0x78}}, + // frame type (IETF_RETIRE_CONNECTION_ID frame) + {"", + {0x19}}, + // Sequence number + {"Unable to read retire connection ID frame sequence number.", + {kVarInt62TwoBytes + 0x11, 0x22}} + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_TRUE(CheckDecryption( + *encrypted, !kIncludeVersion, !kIncludeDiversificationNonce, + kPacket8ByteConnectionId, kPacket0ByteConnectionId)); + + EXPECT_EQ(0u, visitor_.stream_frames_.size()); + + EXPECT_EQ(0x1122u, visitor_.retire_connection_id_.sequence_number); + + ASSERT_EQ(0u, visitor_.ack_frames_.size()); + + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_RETIRE_CONNECTION_ID_DATA); +} + +TEST_P(QuicFramerTest, BuildRetireConnectionIdFramePacket) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + // This frame is only for version 99. + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicRetireConnectionIdFrame frame; + frame.sequence_number = 0x1122; + + QuicFrames frames = {QuicFrame(&frame)}; + + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_RETIRE_CONNECTION_ID frame) + 0x19, + // sequence number + kVarInt62TwoBytes + 0x11, 0x22 + }; + // clang-format on + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); +} + +TEST_P(QuicFramerTest, AckFrameWithInvalidLargestObserved) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + 0x45, + // largest observed + 0x00, 0x00, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x00, 0x00, + // num timestamps. + 0x00 + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + 0x45, + // largest observed + 0x00, 0x00, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x00, 0x00, + // num timestamps. + 0x00 + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // Largest acked + kVarInt62OneByte + 0x00, + // Zero delta time. + kVarInt62OneByte + 0x00, + // Ack block count 0 + kVarInt62OneByte + 0x00, + // First ack block length + kVarInt62OneByte + 0x00, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + } + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_EQ(framer_.detailed_error(), "Largest acked is 0."); +} + +TEST_P(QuicFramerTest, FirstAckBlockJustUnderFlow) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + 0x45, + // largest observed + 0x00, 0x02, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x00, 0x03, + // num timestamps. + 0x00 + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + 0x45, + // largest observed + 0x00, 0x02, + // Zero delta time. + 0x00, 0x00, + // first ack block length. + 0x00, 0x03, + // num timestamps. + 0x00 + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // Largest acked + kVarInt62OneByte + 0x02, + // Zero delta time. + kVarInt62OneByte + 0x00, + // Ack block count 0 + kVarInt62OneByte + 0x00, + // First ack block length + kVarInt62OneByte + 0x02, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + EXPECT_EQ(framer_.detailed_error(), + "Underflow with first ack block length 3 largest acked is 2."); +} + +TEST_P(QuicFramerTest, ThirdAckBlockJustUnderflow) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + 0x60, + // largest observed + 0x0A, + // Zero delta time. + 0x00, 0x00, + // Num of ack blocks + 0x02, + // first ack block length. + 0x02, + // gap to next block + 0x01, + // ack block length + 0x01, + // gap to next block + 0x01, + // ack block length + 0x06, + // num timestamps. + 0x00 + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ack frame) + 0x60, + // largest observed + 0x0A, + // Zero delta time. + 0x00, 0x00, + // Num of ack blocks + 0x02, + // first ack block length. + 0x02, + // gap to next block + 0x01, + // ack block length + 0x01, + // gap to next block + 0x01, + // ack block length + 0x06, + // num timestamps. + 0x00 + }; + + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // Largest acked + kVarInt62OneByte + 0x0A, + // Zero delta time. + kVarInt62OneByte + 0x00, + // Ack block count 2 + kVarInt62OneByte + 0x02, + // First ack block length + kVarInt62OneByte + 0x01, + // gap to next block length + kVarInt62OneByte + 0x00, + // ack block length + kVarInt62OneByte + 0x00, + // gap to next block length + kVarInt62OneByte + 0x00, + // ack block length + kVarInt62OneByte + 0x05, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_size = ABSL_ARRAYSIZE(packet46); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + EXPECT_EQ(framer_.detailed_error(), + "Underflow with ack block length 6 latest ack block end is 5."); + } else { + EXPECT_EQ(framer_.detailed_error(), + "Underflow with ack block length 6, end of block is 6."); + } +} + +TEST_P(QuicFramerTest, CoalescedPacket) { + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + unsigned char packet_ietf[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + // clang-format on + const size_t first_packet_ietf_size = 46; + // If the first packet changes, the attempt to fix the first byte of the + // second packet will fail. + EXPECT_EQ(packet_ietf[first_packet_ietf_size], 0xD3); + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfQuicFrames()) { + ReviseFirstByteByVersion(packet_ietf); + ReviseFirstByteByVersion(&packet_ietf[first_packet_ietf_size]); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + ASSERT_EQ(visitor_.coalesced_packets_.size(), 1u); + EXPECT_TRUE(framer_.ProcessPacket(*visitor_.coalesced_packets_[0].get())); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(2u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[1]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[1]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[1]->offset); + CheckStreamFrameData("HELLO_WORLD?", visitor_.stream_frames_[1].get()); +} + +TEST_P(QuicFramerTest, CoalescedPacketWithUdpPadding) { + if (!framer_.version().HasLongHeaderLengths()) { + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // padding + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }; + unsigned char packet_ietf[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // padding + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfQuicFrames()) { + ReviseFirstByteByVersion(packet_ietf); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + EXPECT_EQ(visitor_.coalesced_packets_.size(), 0u); +} + +TEST_P(QuicFramerTest, CoalescedPacketWithDifferentVersion) { + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // garbage version + 'G', 'A', 'B', 'G', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + unsigned char packet_ietf[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // garbage version + 'G', 'A', 'B', 'G', + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + // clang-format on + const size_t first_packet_ietf_size = 46; + // If the first packet changes, the attempt to fix the first byte of the + // second packet will fail. + EXPECT_EQ(packet_ietf[first_packet_ietf_size], 0xD3); + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfQuicFrames()) { + ReviseFirstByteByVersion(packet_ietf); + ReviseFirstByteByVersion(&packet_ietf[first_packet_ietf_size]); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + ASSERT_EQ(visitor_.coalesced_packets_.size(), 1u); + EXPECT_TRUE(framer_.ProcessPacket(*visitor_.coalesced_packets_[0].get())); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + // Verify version mismatch gets reported. + EXPECT_EQ(1, visitor_.version_mismatch_); +} + +TEST_P(QuicFramerTest, UndecryptablePacketWithoutDecrypter) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + if (!framer_.version().KnowsWhichDecrypterToUse()) { + // We create a bad client decrypter by using initial encryption with a + // bogus connection ID; it should fail to decrypt everything. + QuicConnectionId bogus_connection_id = TestConnectionId(0xbad); + CrypterPair bogus_crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_CLIENT, + framer_.version(), + bogus_connection_id, &bogus_crypters); + // This removes all other decrypters. + framer_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::move(bogus_crypters.decrypter)); + } + + // clang-format off + unsigned char packet[] = { + // public flags (version included, 8-byte connection ID, + // 4-byte packet number) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frames + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + unsigned char packet46[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x05, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frames + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + unsigned char packet49[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x24, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frames + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + // clang-format on + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + ReviseFirstByteByVersion(packet49); + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_length = ABSL_ARRAYSIZE(packet46); + } + // First attempt decryption without the handshake crypter. + EXPECT_FALSE( + framer_.ProcessPacket(QuicEncryptedPacket(AsChars(p), p_length, false))); + EXPECT_THAT(framer_.error(), IsError(QUIC_DECRYPTION_FAILURE)); + ASSERT_EQ(1u, visitor_.undecryptable_packets_.size()); + ASSERT_EQ(1u, visitor_.undecryptable_decryption_levels_.size()); + ASSERT_EQ(1u, visitor_.undecryptable_has_decryption_keys_.size()); + quiche::test::CompareCharArraysWithHexError( + "undecryptable packet", visitor_.undecryptable_packets_[0]->data(), + visitor_.undecryptable_packets_[0]->length(), AsChars(p), p_length); + if (framer_.version().KnowsWhichDecrypterToUse()) { + EXPECT_EQ(ENCRYPTION_HANDSHAKE, + visitor_.undecryptable_decryption_levels_[0]); + } + EXPECT_FALSE(visitor_.undecryptable_has_decryption_keys_[0]); +} + +TEST_P(QuicFramerTest, UndecryptablePacketWithDecrypter) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + // We create a bad client decrypter by using initial encryption with a + // bogus connection ID; it should fail to decrypt everything. + QuicConnectionId bogus_connection_id = TestConnectionId(0xbad); + CrypterPair bad_handshake_crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_CLIENT, + framer_.version(), bogus_connection_id, + &bad_handshake_crypters); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_HANDSHAKE, + std::move(bad_handshake_crypters.decrypter)); + } else { + framer_.SetDecrypter(ENCRYPTION_HANDSHAKE, + std::move(bad_handshake_crypters.decrypter)); + } + + // clang-format off + unsigned char packet[] = { + // public flags (version included, 8-byte connection ID, + // 4-byte packet number) + 0x28, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frames + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + unsigned char packet46[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x05, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frames + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + unsigned char packet49[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x00, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x24, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frames + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + // clang-format on + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + ReviseFirstByteByVersion(packet49); + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } else if (framer_.version().HasIetfInvariantHeader()) { + p = packet46; + p_length = ABSL_ARRAYSIZE(packet46); + } + + EXPECT_FALSE( + framer_.ProcessPacket(QuicEncryptedPacket(AsChars(p), p_length, false))); + EXPECT_THAT(framer_.error(), IsError(QUIC_DECRYPTION_FAILURE)); + ASSERT_EQ(1u, visitor_.undecryptable_packets_.size()); + ASSERT_EQ(1u, visitor_.undecryptable_decryption_levels_.size()); + ASSERT_EQ(1u, visitor_.undecryptable_has_decryption_keys_.size()); + quiche::test::CompareCharArraysWithHexError( + "undecryptable packet", visitor_.undecryptable_packets_[0]->data(), + visitor_.undecryptable_packets_[0]->length(), AsChars(p), p_length); + if (framer_.version().KnowsWhichDecrypterToUse()) { + EXPECT_EQ(ENCRYPTION_HANDSHAKE, + visitor_.undecryptable_decryption_levels_[0]); + } + EXPECT_EQ(framer_.version().KnowsWhichDecrypterToUse(), + visitor_.undecryptable_has_decryption_keys_[0]); +} + +TEST_P(QuicFramerTest, UndecryptableCoalescedPacket) { + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // We create a bad client decrypter by using initial encryption with a + // bogus connection ID; it should fail to decrypt everything. + QuicConnectionId bogus_connection_id = TestConnectionId(0xbad); + CrypterPair bad_handshake_crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_CLIENT, + framer_.version(), bogus_connection_id, + &bad_handshake_crypters); + framer_.InstallDecrypter(ENCRYPTION_HANDSHAKE, + std::move(bad_handshake_crypters.decrypter)); + // clang-format off + unsigned char packet[] = { + // first coalesced packet + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + unsigned char packet_ietf[] = { + // first coalesced packet + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + // clang-format on + const size_t length_of_first_coalesced_packet = 46; + // If the first packet changes, the attempt to fix the first byte of the + // second packet will fail. + EXPECT_EQ(packet_ietf[length_of_first_coalesced_packet], 0xD3); + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfQuicFrames()) { + ReviseFirstByteByVersion(packet_ietf); + ReviseFirstByteByVersion(&packet_ietf[length_of_first_coalesced_packet]); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_DECRYPTION_FAILURE)); + + ASSERT_EQ(1u, visitor_.undecryptable_packets_.size()); + ASSERT_EQ(1u, visitor_.undecryptable_decryption_levels_.size()); + ASSERT_EQ(1u, visitor_.undecryptable_has_decryption_keys_.size()); + // Make sure we only receive the first undecryptable packet and not the + // full packet including the second coalesced packet. + quiche::test::CompareCharArraysWithHexError( + "undecryptable packet", visitor_.undecryptable_packets_[0]->data(), + visitor_.undecryptable_packets_[0]->length(), AsChars(p), + length_of_first_coalesced_packet); + EXPECT_EQ(ENCRYPTION_HANDSHAKE, visitor_.undecryptable_decryption_levels_[0]); + EXPECT_TRUE(visitor_.undecryptable_has_decryption_keys_[0]); + + // Make sure the second coalesced packet is parsed correctly. + ASSERT_EQ(visitor_.coalesced_packets_.size(), 1u); + EXPECT_TRUE(framer_.ProcessPacket(*visitor_.coalesced_packets_[0].get())); + + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("HELLO_WORLD?", visitor_.stream_frames_[0].get()); +} + +TEST_P(QuicFramerTest, MismatchedCoalescedPacket) { + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + unsigned char packet_ietf[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x79, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'H', 'E', 'L', 'L', + 'O', '_', 'W', 'O', + 'R', 'L', 'D', '?', + }; + // clang-format on + const size_t length_of_first_coalesced_packet = 46; + // If the first packet changes, the attempt to fix the first byte of the + // second packet will fail. + EXPECT_EQ(packet_ietf[length_of_first_coalesced_packet], 0xD3); + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfQuicFrames()) { + ReviseFirstByteByVersion(packet_ietf); + ReviseFirstByteByVersion(&packet_ietf[length_of_first_coalesced_packet]); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + ASSERT_EQ(visitor_.coalesced_packets_.size(), 0u); +} + +TEST_P(QuicFramerTest, InvalidCoalescedPacket) { + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (stream frame with fin) + 0xFE, + // stream id + 0x02, 0x03, 0x04, + // offset + 0x3A, 0x98, 0xFE, 0xDC, 0x32, 0x10, 0x76, 0x54, + // data length + 0x00, 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version would be here but we cut off the invalid coalesced header. + }; + unsigned char packet_ietf[] = { + // first coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x1E, + // packet number + 0x12, 0x34, 0x56, 0x78, + // frame type (IETF_STREAM frame with FIN, LEN, and OFFSET bits set) + 0x08 | 0x01 | 0x02 | 0x04, + // stream id + kVarInt62FourBytes + 0x00, 0x02, 0x03, 0x04, + // offset + kVarInt62EightBytes + 0x3A, 0x98, 0xFE, 0xDC, + 0x32, 0x10, 0x76, 0x54, + // data length + kVarInt62OneByte + 0x0c, + // data + 'h', 'e', 'l', 'l', + 'o', ' ', 'w', 'o', + 'r', 'l', 'd', '!', + // second coalesced packet + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version would be here but we cut off the invalid coalesced header. + }; + // clang-format on + const size_t length_of_first_coalesced_packet = 46; + // If the first packet changes, the attempt to fix the first byte of the + // second packet will fail. + EXPECT_EQ(packet_ietf[length_of_first_coalesced_packet], 0xD3); + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasIetfQuicFrames()) { + ReviseFirstByteByVersion(packet_ietf); + ReviseFirstByteByVersion(&packet_ietf[length_of_first_coalesced_packet]); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); + } + + QuicEncryptedPacket encrypted(AsChars(p), p_length, false); + + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + + ASSERT_EQ(1u, visitor_.stream_frames_.size()); + EXPECT_EQ(0u, visitor_.ack_frames_.size()); + + // Stream ID should be the last 3 bytes of kStreamId. + EXPECT_EQ(0x00FFFFFF & kStreamId, visitor_.stream_frames_[0]->stream_id); + EXPECT_TRUE(visitor_.stream_frames_[0]->fin); + EXPECT_EQ(kStreamOffset, visitor_.stream_frames_[0]->offset); + CheckStreamFrameData("hello world!", visitor_.stream_frames_[0].get()); + + ASSERT_EQ(visitor_.coalesced_packets_.size(), 0u); +} + +// Some IETF implementations send an initial followed by zeroes instead of +// padding inside the initial. We need to make sure that we still process +// the initial correctly and ignore the zeroes. +TEST_P(QuicFramerTest, CoalescedPacketWithZeroesRoundTrip) { + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version()) || + !framer_.version().UsesInitialObfuscators()) { + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + QuicConnectionId connection_id = FramerTestConnectionId(); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + CrypterPair client_crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_CLIENT, + framer_.version(), connection_id, + &client_crypters); + framer_.SetEncrypter(ENCRYPTION_INITIAL, + std::move(client_crypters.encrypter)); + + QuicPacketHeader header; + header.destination_connection_id = connection_id; + header.version_flag = true; + header.packet_number = kPacketNumber; + header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header.long_packet_type = INITIAL; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + QuicFrames frames = {QuicFrame(QuicPingFrame()), + QuicFrame(QuicPaddingFrame(3))}; + + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_NE(nullptr, data); + + // Add zeroes after the valid initial packet. + unsigned char packet[kMaxOutgoingPacketSize] = {}; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, *data, + AsChars(packet), ABSL_ARRAYSIZE(packet)); + ASSERT_NE(0u, encrypted_length); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + CrypterPair server_crypters; + CryptoUtils::CreateInitialObfuscators(Perspective::IS_SERVER, + framer_.version(), connection_id, + &server_crypters); + framer_.InstallDecrypter(ENCRYPTION_INITIAL, + std::move(server_crypters.decrypter)); + + // Make sure the first long header initial packet parses correctly. + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + + // Make sure we discard the subsequent zeroes. + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + EXPECT_TRUE(visitor_.coalesced_packets_.empty()); +} + +TEST_P(QuicFramerTest, ClientReceivesWrongVersion) { + if (!framer_.version().HasIetfInvariantHeader()) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type INITIAL) + 0xC3, + // version that is different from the framer's version + 'Q', '0', '4', '3', + // connection ID lengths + 0x05, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x01, + // padding frame + 0x00, + }; + // clang-format on + + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsError(QUIC_PACKET_WRONG_VERSION)); + EXPECT_EQ("Client received unexpected version.", framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, PacketHeaderWithVariableLengthConnectionId) { + if (!framer_.version().AllowsVariableLengthConnectionIds()) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + uint8_t connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, + 0x54, 0x32, 0x10, 0x42}; + QuicConnectionId connection_id(reinterpret_cast(connection_id_bytes), + sizeof(connection_id_bytes)); + QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); + QuicFramerPeer::SetExpectedServerConnectionIDLength(&framer_, + connection_id.length()); + + // clang-format off + PacketFragments packet = { + // type (8 byte connection_id and 1 byte packet number) + {"Unable to read first byte.", + {0x40}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0x42}}, + // packet number + {"Unable to read packet number.", + {0x78}}, + }; + + PacketFragments packet_with_padding = { + // type (8 byte connection_id and 1 byte packet number) + {"Unable to read first byte.", + {0x40}}, + // connection_id + {"Unable to read destination connection ID.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0x42}}, + // packet number + {"", + {0x78}}, + // padding + {"", {0x00, 0x00, 0x00}}, + }; + // clang-format on + + PacketFragments& fragments = + framer_.version().HasHeaderProtection() ? packet_with_padding : packet; + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + if (framer_.version().HasHeaderProtection()) { + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + } else { + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); + } + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(connection_id, visitor_.header_->destination_connection_id); + EXPECT_FALSE(visitor_.header_->reset_flag); + EXPECT_FALSE(visitor_.header_->version_flag); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, visitor_.header_->packet_number_length); + EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); + + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, MultiplePacketNumberSpaces) { + if (!framer_.version().HasIetfInvariantHeader()) { + return; + } + framer_.EnableMultiplePacketNumberSpacesSupport(); + + // clang-format off + unsigned char long_header_packet[] = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x50, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + unsigned char long_header_packet_ietf[] = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x78, + // padding frame + 0x00, + }; + // clang-format on + + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique()); + framer_.RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + framer_.SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique()); + } + if (!QuicVersionHasLongHeaderLengths(framer_.transport_version())) { + EXPECT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(AsChars(long_header_packet), + ABSL_ARRAYSIZE(long_header_packet), false))); + } else { + ReviseFirstByteByVersion(long_header_packet_ietf); + EXPECT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(AsChars(long_header_packet_ietf), + ABSL_ARRAYSIZE(long_header_packet_ietf), false))); + } + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + EXPECT_FALSE( + QuicFramerPeer::GetLargestDecryptedPacketNumber(&framer_, INITIAL_DATA) + .IsInitialized()); + EXPECT_FALSE( + QuicFramerPeer::GetLargestDecryptedPacketNumber(&framer_, HANDSHAKE_DATA) + .IsInitialized()); + EXPECT_EQ(kPacketNumber, QuicFramerPeer::GetLargestDecryptedPacketNumber( + &framer_, APPLICATION_DATA)); + + // clang-format off + unsigned char short_header_packet[] = { + // type (short header, 1 byte packet number) + 0x40, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x79, + // padding frame + 0x00, 0x00, 0x00, + }; + // clang-format on + + QuicEncryptedPacket short_header_encrypted( + AsChars(short_header_packet), ABSL_ARRAYSIZE(short_header_packet), false); + if (framer_.version().KnowsWhichDecrypterToUse()) { + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique()); + framer_.RemoveDecrypter(ENCRYPTION_ZERO_RTT); + } else { + framer_.SetDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique()); + } + EXPECT_TRUE(framer_.ProcessPacket(short_header_encrypted)); + + EXPECT_THAT(framer_.error(), IsQuicNoError()); + EXPECT_FALSE( + QuicFramerPeer::GetLargestDecryptedPacketNumber(&framer_, INITIAL_DATA) + .IsInitialized()); + EXPECT_FALSE( + QuicFramerPeer::GetLargestDecryptedPacketNumber(&framer_, HANDSHAKE_DATA) + .IsInitialized()); + EXPECT_EQ(kPacketNumber + 1, QuicFramerPeer::GetLargestDecryptedPacketNumber( + &framer_, APPLICATION_DATA)); +} + +TEST_P(QuicFramerTest, IetfRetryPacketRejected) { + if (!framer_.version().KnowsWhichDecrypterToUse() || + framer_.version().SupportsRetry()) { + return; + } + + // clang-format off + PacketFragments packet46 = { + // public flags (IETF Retry packet, 0-length original destination CID) + {"Unable to read first byte.", + {0xf0}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // connection_id length + {"RETRY not supported in this version.", + {0x00}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet46)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + CheckFramingBoundaries(packet46, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, RetryPacketRejectedWithMultiplePacketNumberSpaces) { + if (!framer_.version().HasIetfInvariantHeader() || + framer_.version().SupportsRetry()) { + return; + } + framer_.EnableMultiplePacketNumberSpacesSupport(); + + // clang-format off + PacketFragments packet = { + // public flags (IETF Retry packet, 0-length original destination CID) + {"Unable to read first byte.", + {0xf0}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // connection_id length + {"RETRY not supported in this version.", + {0x00}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + CheckFramingBoundaries(packet, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, ProcessPublicHeaderNoVersionInferredType) { + // The framer needs to have Perspective::IS_SERVER and configured to infer the + // packet header type from the packet (not the version). The framer's version + // needs to be one that uses the IETF packet format. + if (!framer_.version().KnowsWhichDecrypterToUse()) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + + // Prepare a packet that uses the Google QUIC packet header but has no version + // field. + + // clang-format off + PacketFragments packet = { + // public flags (1-byte packet number, 8-byte connection_id, no version) + {"Unable to read public flags.", + {0x08}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // packet number + {"Unable to read packet number.", + {0x01}}, + // padding + {"Invalid public header type for expected version.", + {0x00}}, + }; + // clang-format on + + PacketFragments& fragments = packet; + + std::unique_ptr encrypted( + AssemblePacketFromFragments(fragments)); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Invalid public header type for expected version.", + framer_.detailed_error()); + CheckFramingBoundaries(fragments, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, ProcessMismatchedHeaderVersion) { + // The framer needs to have Perspective::IS_SERVER and configured to infer the + // packet header type from the packet (not the version). The framer's version + // needs to be one that uses the IETF packet format. + if (!framer_.version().KnowsWhichDecrypterToUse()) { + return; + } + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + + // clang-format off + PacketFragments packet = { + // public flags (Google QUIC header with version present) + {"Unable to read public flags.", + {0x09}}, + // connection_id + {"Unable to read ConnectionId.", + {0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}}, + // version tag + {"Unable to read protocol version.", + {QUIC_VERSION_BYTES}}, + // packet number + {"Unable to read packet number.", + {0x01}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet)); + framer_.ProcessPacket(*encrypted); + + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Invalid public header type for expected version.", + framer_.detailed_error()); + CheckFramingBoundaries(packet, QUIC_INVALID_PACKET_HEADER); +} + +TEST_P(QuicFramerTest, WriteClientVersionNegotiationProbePacket) { + // clang-format off + static const uint8_t expected_packet[1200] = { + // IETF long header with fixed bit set, type initial, all-0 encrypted bits. + 0xc0, + // Version, part of the IETF space reserved for negotiation. + 0xca, 0xba, 0xda, 0xda, + // Destination connection ID length 8. + 0x08, + // 8-byte destination connection ID. + 0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21, + // Source connection ID length 0. + 0x00, + // 8 bytes of zeroes followed by 8 bytes of ones to ensure that this does + // not parse with any known version. + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + // zeroes to pad to 16 byte boundary. + 0x00, + // A polite greeting in case a human sees this in tcpdump. + 0x54, 0x68, 0x69, 0x73, 0x20, 0x70, 0x61, 0x63, + 0x6b, 0x65, 0x74, 0x20, 0x6f, 0x6e, 0x6c, 0x79, + 0x20, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x20, + 0x74, 0x6f, 0x20, 0x74, 0x72, 0x69, 0x67, 0x67, + 0x65, 0x72, 0x20, 0x49, 0x45, 0x54, 0x46, 0x20, + 0x51, 0x55, 0x49, 0x43, 0x20, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x20, 0x6e, 0x65, 0x67, + 0x6f, 0x74, 0x69, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x2e, 0x20, 0x50, 0x6c, 0x65, 0x61, 0x73, 0x65, + 0x20, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x64, + 0x20, 0x77, 0x69, 0x74, 0x68, 0x20, 0x61, 0x20, + 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x20, + 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x20, 0x70, 0x61, 0x63, 0x6b, + 0x65, 0x74, 0x20, 0x69, 0x6e, 0x64, 0x69, 0x63, + 0x61, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x77, 0x68, + 0x61, 0x74, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x73, 0x20, 0x79, 0x6f, 0x75, 0x20, + 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x20, 0x54, 0x68, 0x61, 0x6e, 0x6b, 0x20, 0x79, + 0x6f, 0x75, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x68, + 0x61, 0x76, 0x65, 0x20, 0x61, 0x20, 0x6e, 0x69, + 0x63, 0x65, 0x20, 0x64, 0x61, 0x79, 0x2e, 0x00, + }; + // clang-format on + char packet[1200]; + char destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + EXPECT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket( + packet, sizeof(packet), destination_connection_id_bytes, + sizeof(destination_connection_id_bytes))); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", packet, sizeof(packet), + reinterpret_cast(expected_packet), sizeof(expected_packet)); + QuicEncryptedPacket encrypted(reinterpret_cast(packet), + sizeof(packet), false); + if (!framer_.version().HasLengthPrefixedConnectionIds()) { + // We can only parse the connection ID with a parser expecting + // length-prefixed connection IDs. + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + return; + } + EXPECT_TRUE(framer_.ProcessPacket(encrypted)); + ASSERT_TRUE(visitor_.header_.get()); + QuicConnectionId probe_payload_connection_id( + reinterpret_cast(destination_connection_id_bytes), + sizeof(destination_connection_id_bytes)); + EXPECT_EQ(probe_payload_connection_id, + visitor_.header_.get()->destination_connection_id); +} + +TEST_P(QuicFramerTest, DispatcherParseOldClientVersionNegotiationProbePacket) { + // clang-format off + static const uint8_t packet[1200] = { + // IETF long header with fixed bit set, type initial, all-0 encrypted bits. + 0xc0, + // Version, part of the IETF space reserved for negotiation. + 0xca, 0xba, 0xda, 0xba, + // Destination connection ID length 8, source connection ID length 0. + 0x50, + // 8-byte destination connection ID. + 0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21, + // 8 bytes of zeroes followed by 8 bytes of ones to ensure that this does + // not parse with any known version. + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + // 2 bytes of zeroes to pad to 16 byte boundary. + 0x00, 0x00, + // A polite greeting in case a human sees this in tcpdump. + 0x54, 0x68, 0x69, 0x73, 0x20, 0x70, 0x61, 0x63, + 0x6b, 0x65, 0x74, 0x20, 0x6f, 0x6e, 0x6c, 0x79, + 0x20, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x20, + 0x74, 0x6f, 0x20, 0x74, 0x72, 0x69, 0x67, 0x67, + 0x65, 0x72, 0x20, 0x49, 0x45, 0x54, 0x46, 0x20, + 0x51, 0x55, 0x49, 0x43, 0x20, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x20, 0x6e, 0x65, 0x67, + 0x6f, 0x74, 0x69, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x2e, 0x20, 0x50, 0x6c, 0x65, 0x61, 0x73, 0x65, + 0x20, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x64, + 0x20, 0x77, 0x69, 0x74, 0x68, 0x20, 0x61, 0x20, + 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x20, + 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x20, 0x70, 0x61, 0x63, 0x6b, + 0x65, 0x74, 0x20, 0x69, 0x6e, 0x64, 0x69, 0x63, + 0x61, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x77, 0x68, + 0x61, 0x74, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x73, 0x20, 0x79, 0x6f, 0x75, 0x20, + 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x20, 0x54, 0x68, 0x61, 0x6e, 0x6b, 0x20, 0x79, + 0x6f, 0x75, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x68, + 0x61, 0x76, 0x65, 0x20, 0x61, 0x20, 0x6e, 0x69, + 0x63, 0x65, 0x20, 0x64, 0x61, 0x79, 0x2e, 0x00, + }; + // clang-format on + char expected_destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + QuicConnectionId expected_destination_connection_id( + reinterpret_cast(expected_destination_connection_id_bytes), + sizeof(expected_destination_connection_id_bytes)); + + QuicEncryptedPacket encrypted(reinterpret_cast(packet), + sizeof(packet)); + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_present = false, has_length_prefix = true; + QuicVersionLabel version_label = 33; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + QuicConnectionId destination_connection_id = TestConnectionId(1); + QuicConnectionId source_connection_id = TestConnectionId(2); + absl::optional retry_token; + std::string detailed_error = "foobar"; + QuicErrorCode header_parse_result = QuicFramer::ParsePublicHeaderDispatcher( + encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_present, &has_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_THAT(header_parse_result, IsQuicNoError()); + EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); + EXPECT_TRUE(version_present); + EXPECT_FALSE(has_length_prefix); + EXPECT_EQ(0xcabadaba, version_label); + EXPECT_EQ(expected_destination_connection_id, destination_connection_id); + EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); + EXPECT_FALSE(retry_token.has_value()); + EXPECT_EQ("", detailed_error); +} + +TEST_P(QuicFramerTest, DispatcherParseClientVersionNegotiationProbePacket) { + // clang-format off + static const uint8_t packet[1200] = { + // IETF long header with fixed bit set, type initial, all-0 encrypted bits. + 0xc0, + // Version, part of the IETF space reserved for negotiation. + 0xca, 0xba, 0xda, 0xba, + // Destination connection ID length 8. + 0x08, + // 8-byte destination connection ID. + 0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21, + // Source connection ID length 0. + 0x00, + // 8 bytes of zeroes followed by 8 bytes of ones to ensure that this does + // not parse with any known version. + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + // 1 byte of zeroes to pad to 16 byte boundary. + 0x00, + // A polite greeting in case a human sees this in tcpdump. + 0x54, 0x68, 0x69, 0x73, 0x20, 0x70, 0x61, 0x63, + 0x6b, 0x65, 0x74, 0x20, 0x6f, 0x6e, 0x6c, 0x79, + 0x20, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x20, + 0x74, 0x6f, 0x20, 0x74, 0x72, 0x69, 0x67, 0x67, + 0x65, 0x72, 0x20, 0x49, 0x45, 0x54, 0x46, 0x20, + 0x51, 0x55, 0x49, 0x43, 0x20, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x20, 0x6e, 0x65, 0x67, + 0x6f, 0x74, 0x69, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x2e, 0x20, 0x50, 0x6c, 0x65, 0x61, 0x73, 0x65, + 0x20, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x64, + 0x20, 0x77, 0x69, 0x74, 0x68, 0x20, 0x61, 0x20, + 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x20, + 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x20, 0x70, 0x61, 0x63, 0x6b, + 0x65, 0x74, 0x20, 0x69, 0x6e, 0x64, 0x69, 0x63, + 0x61, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x77, 0x68, + 0x61, 0x74, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x73, 0x20, 0x79, 0x6f, 0x75, 0x20, + 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x20, 0x54, 0x68, 0x61, 0x6e, 0x6b, 0x20, 0x79, + 0x6f, 0x75, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x68, + 0x61, 0x76, 0x65, 0x20, 0x61, 0x20, 0x6e, 0x69, + 0x63, 0x65, 0x20, 0x64, 0x61, 0x79, 0x2e, 0x00, + }; + // clang-format on + char expected_destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + QuicConnectionId expected_destination_connection_id( + reinterpret_cast(expected_destination_connection_id_bytes), + sizeof(expected_destination_connection_id_bytes)); + + QuicEncryptedPacket encrypted(reinterpret_cast(packet), + sizeof(packet)); + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_present = false, has_length_prefix = false; + QuicVersionLabel version_label = 33; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + QuicConnectionId destination_connection_id = TestConnectionId(1); + QuicConnectionId source_connection_id = TestConnectionId(2); + absl::optional retry_token; + std::string detailed_error = "foobar"; + QuicErrorCode header_parse_result = QuicFramer::ParsePublicHeaderDispatcher( + encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_present, &has_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_THAT(header_parse_result, IsQuicNoError()); + EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); + EXPECT_TRUE(version_present); + EXPECT_TRUE(has_length_prefix); + EXPECT_EQ(0xcabadaba, version_label); + EXPECT_EQ(expected_destination_connection_id, destination_connection_id); + EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); + EXPECT_EQ("", detailed_error); +} + +TEST_P(QuicFramerTest, ParseServerVersionNegotiationProbeResponse) { + // clang-format off + const uint8_t packet[] = { + // IETF long header with fixed bit set, type initial, all-0 encrypted bits. + 0xc0, + // Version of 0, indicating version negotiation. + 0x00, 0x00, 0x00, 0x00, + // Destination connection ID length 0, source connection ID length 8. + 0x00, 0x08, + // 8-byte source connection ID. + 0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21, + // A few supported versions. + 0xaa, 0xaa, 0xaa, 0xaa, + QUIC_VERSION_BYTES, + }; + // clang-format on + char probe_payload_bytes[] = {0x56, 0x4e, 0x20, 0x70, 0x6c, 0x7a, 0x20, 0x21}; + char parsed_probe_payload_bytes[255] = {}; + uint8_t parsed_probe_payload_length = sizeof(parsed_probe_payload_bytes); + std::string parse_detailed_error = ""; + EXPECT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse( + reinterpret_cast(packet), sizeof(packet), + reinterpret_cast(parsed_probe_payload_bytes), + &parsed_probe_payload_length, &parse_detailed_error)); + EXPECT_EQ("", parse_detailed_error); + quiche::test::CompareCharArraysWithHexError( + "parsed probe", parsed_probe_payload_bytes, parsed_probe_payload_length, + probe_payload_bytes, sizeof(probe_payload_bytes)); +} + +TEST_P(QuicFramerTest, ParseClientVersionNegotiationProbePacket) { + char packet[1200]; + char input_destination_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + ASSERT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket( + packet, sizeof(packet), input_destination_connection_id_bytes, + sizeof(input_destination_connection_id_bytes))); + char parsed_destination_connection_id_bytes[255] = {0}; + uint8_t parsed_destination_connection_id_length = + sizeof(parsed_destination_connection_id_bytes); + ASSERT_TRUE(ParseClientVersionNegotiationProbePacket( + packet, sizeof(packet), parsed_destination_connection_id_bytes, + &parsed_destination_connection_id_length)); + quiche::test::CompareCharArraysWithHexError( + "parsed destination connection ID", + parsed_destination_connection_id_bytes, + parsed_destination_connection_id_length, + input_destination_connection_id_bytes, + sizeof(input_destination_connection_id_bytes)); +} + +TEST_P(QuicFramerTest, WriteServerVersionNegotiationProbeResponse) { + char packet[1200]; + size_t packet_length = sizeof(packet); + char input_source_connection_id_bytes[] = {0x56, 0x4e, 0x20, 0x70, + 0x6c, 0x7a, 0x20, 0x21}; + ASSERT_TRUE(WriteServerVersionNegotiationProbeResponse( + packet, &packet_length, input_source_connection_id_bytes, + sizeof(input_source_connection_id_bytes))); + char parsed_source_connection_id_bytes[255] = {0}; + uint8_t parsed_source_connection_id_length = + sizeof(parsed_source_connection_id_bytes); + std::string detailed_error; + ASSERT_TRUE(QuicFramer::ParseServerVersionNegotiationProbeResponse( + packet, packet_length, parsed_source_connection_id_bytes, + &parsed_source_connection_id_length, &detailed_error)) + << detailed_error; + quiche::test::CompareCharArraysWithHexError( + "parsed destination connection ID", parsed_source_connection_id_bytes, + parsed_source_connection_id_length, input_source_connection_id_bytes, + sizeof(input_source_connection_id_bytes)); +} + +TEST_P(QuicFramerTest, ClientConnectionIdFromLongHeaderToClient) { + if (!framer_.version().HasIetfInvariantHeader()) { + // This test requires an IETF long header. + return; + } + SetDecrypterLevel(ENCRYPTION_HANDSHAKE); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x50, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frame + 0x00, + }; + unsigned char packet49[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x00, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frame + 0x00, + }; + // clang-format on + + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + ReviseFirstByteByVersion(packet49); + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } + const bool parse_success = + framer_.ProcessPacket(QuicEncryptedPacket(AsChars(p), p_length, false)); + if (!framer_.version().AllowsVariableLengthConnectionIds()) { + EXPECT_FALSE(parse_success); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Invalid ConnectionId length.", framer_.detailed_error()); + return; + } + EXPECT_TRUE(parse_success); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + EXPECT_EQ("", framer_.detailed_error()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_.get()->destination_connection_id); +} + +TEST_P(QuicFramerTest, ClientConnectionIdFromLongHeaderToServer) { + if (!framer_.version().HasIetfInvariantHeader()) { + // This test requires an IETF long header. + return; + } + SetDecrypterLevel(ENCRYPTION_HANDSHAKE); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x05, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frame + 0x00, + }; + unsigned char packet49[] = { + // public flags (long header with packet type HANDSHAKE and + // 4-byte packet number) + 0xE3, + // version + QUIC_VERSION_BYTES, + // connection ID lengths + 0x00, 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // padding frame + 0x00, + }; + // clang-format on + unsigned char* p = packet; + size_t p_length = ABSL_ARRAYSIZE(packet); + if (framer_.version().HasLongHeaderLengths()) { + ReviseFirstByteByVersion(packet49); + p = packet49; + p_length = ABSL_ARRAYSIZE(packet49); + } + const bool parse_success = + framer_.ProcessPacket(QuicEncryptedPacket(AsChars(p), p_length, false)); + if (!framer_.version().AllowsVariableLengthConnectionIds()) { + EXPECT_FALSE(parse_success); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Invalid ConnectionId length.", framer_.detailed_error()); + return; + } + if (!framer_.version().SupportsClientConnectionIds()) { + EXPECT_FALSE(parse_success); + EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_PACKET_HEADER)); + EXPECT_EQ("Client connection ID not supported in this version.", + framer_.detailed_error()); + return; + } + EXPECT_TRUE(parse_success); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + EXPECT_EQ("", framer_.detailed_error()); + ASSERT_TRUE(visitor_.header_.get()); + EXPECT_EQ(FramerTestConnectionId(), + visitor_.header_.get()->source_connection_id); +} + +TEST_P(QuicFramerTest, ProcessAndValidateIetfConnectionIdLengthClient) { + if (!framer_.version().HasIetfInvariantHeader()) { + // This test requires an IETF long header. + return; + } + char connection_id_lengths = 0x05; + QuicDataReader reader(&connection_id_lengths, 1); + + bool should_update_expected_server_connection_id_length = false; + uint8_t expected_server_connection_id_length = 8; + uint8_t destination_connection_id_length = 0; + uint8_t source_connection_id_length = 8; + std::string detailed_error = ""; + + EXPECT_TRUE(QuicFramerPeer::ProcessAndValidateIetfConnectionIdLength( + &reader, framer_.version(), Perspective::IS_CLIENT, + should_update_expected_server_connection_id_length, + &expected_server_connection_id_length, &destination_connection_id_length, + &source_connection_id_length, &detailed_error)); + EXPECT_EQ(8, expected_server_connection_id_length); + EXPECT_EQ(0, destination_connection_id_length); + EXPECT_EQ(8, source_connection_id_length); + EXPECT_EQ("", detailed_error); + + QuicDataReader reader2(&connection_id_lengths, 1); + should_update_expected_server_connection_id_length = true; + expected_server_connection_id_length = 33; + EXPECT_TRUE(QuicFramerPeer::ProcessAndValidateIetfConnectionIdLength( + &reader2, framer_.version(), Perspective::IS_CLIENT, + should_update_expected_server_connection_id_length, + &expected_server_connection_id_length, &destination_connection_id_length, + &source_connection_id_length, &detailed_error)); + EXPECT_EQ(8, expected_server_connection_id_length); + EXPECT_EQ(0, destination_connection_id_length); + EXPECT_EQ(8, source_connection_id_length); + EXPECT_EQ("", detailed_error); +} + +TEST_P(QuicFramerTest, ProcessAndValidateIetfConnectionIdLengthServer) { + if (!framer_.version().HasIetfInvariantHeader()) { + // This test requires an IETF long header. + return; + } + char connection_id_lengths = 0x50; + QuicDataReader reader(&connection_id_lengths, 1); + + bool should_update_expected_server_connection_id_length = false; + uint8_t expected_server_connection_id_length = 8; + uint8_t destination_connection_id_length = 8; + uint8_t source_connection_id_length = 0; + std::string detailed_error = ""; + + EXPECT_TRUE(QuicFramerPeer::ProcessAndValidateIetfConnectionIdLength( + &reader, framer_.version(), Perspective::IS_SERVER, + should_update_expected_server_connection_id_length, + &expected_server_connection_id_length, &destination_connection_id_length, + &source_connection_id_length, &detailed_error)); + EXPECT_EQ(8, expected_server_connection_id_length); + EXPECT_EQ(8, destination_connection_id_length); + EXPECT_EQ(0, source_connection_id_length); + EXPECT_EQ("", detailed_error); + + QuicDataReader reader2(&connection_id_lengths, 1); + should_update_expected_server_connection_id_length = true; + expected_server_connection_id_length = 33; + EXPECT_TRUE(QuicFramerPeer::ProcessAndValidateIetfConnectionIdLength( + &reader2, framer_.version(), Perspective::IS_SERVER, + should_update_expected_server_connection_id_length, + &expected_server_connection_id_length, &destination_connection_id_length, + &source_connection_id_length, &detailed_error)); + EXPECT_EQ(8, expected_server_connection_id_length); + EXPECT_EQ(8, destination_connection_id_length); + EXPECT_EQ(0, source_connection_id_length); + EXPECT_EQ("", detailed_error); +} + +TEST_P(QuicFramerTest, TestExtendedErrorCodeParser) { + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + // Extended error codes only in IETF QUIC + return; + } + QuicConnectionCloseFrame frame; + + frame.error_details = "this has no error code info in it"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ("this has no error code info in it", frame.error_details); + + frame.error_details = "1234this does not have the colon in it"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ("1234this does not have the colon in it", frame.error_details); + + frame.error_details = "1a234:this has a colon, but a malformed error number"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ("1a234:this has a colon, but a malformed error number", + frame.error_details); + + frame.error_details = "1234:this is good"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_EQ(1234u, frame.quic_error_code); + EXPECT_EQ("this is good", frame.error_details); + + frame.error_details = + "1234 :this is not good, space between last digit and colon"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ("1234 :this is not good, space between last digit and colon", + frame.error_details); + + frame.error_details = "123456789"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT( + frame.quic_error_code, + IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); // Not good, all numbers, no : + EXPECT_EQ("123456789", frame.error_details); + + frame.error_details = "1234:"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_EQ(1234u, + frame.quic_error_code); // corner case. + EXPECT_EQ("", frame.error_details); + + frame.error_details = "1234:5678"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_EQ(1234u, + frame.quic_error_code); // another corner case. + EXPECT_EQ("5678", frame.error_details); + + frame.error_details = "12345 6789:"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, + IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); // Not good + EXPECT_EQ("12345 6789:", frame.error_details); + + frame.error_details = ":no numbers, is not good"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ(":no numbers, is not good", frame.error_details); + + frame.error_details = "qwer:also no numbers, is not good"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ("qwer:also no numbers, is not good", frame.error_details); + + frame.error_details = " 1234:this is not good, space before first digit"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_THAT(frame.quic_error_code, IsError(QUIC_IETF_GQUIC_ERROR_MISSING)); + EXPECT_EQ(" 1234:this is not good, space before first digit", + frame.error_details); + + frame.error_details = "1234:"; + MaybeExtractQuicErrorCode(&frame); + EXPECT_EQ(1234u, + frame.quic_error_code); // this is good + EXPECT_EQ("", frame.error_details); +} + +// Regression test for crbug/1029636. +TEST_P(QuicFramerTest, OverlyLargeAckDelay) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet_ietf[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_ACK frame) + 0x02, + // largest acked + kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x78, + // ack delay time. + kVarInt62EightBytes + 0x31, 0x00, 0x00, 0x00, 0xF3, 0xA0, 0x81, 0xE0, + // Nr. of additional ack blocks + kVarInt62OneByte + 0x00, + // first ack block length. + kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x77, + }; + // clang-format on + + framer_.ProcessPacket(QuicEncryptedPacket( + AsChars(packet_ietf), ABSL_ARRAYSIZE(packet_ietf), false)); + ASSERT_EQ(1u, visitor_.ack_frames_.size()); + // Verify ack_delay_time is set correctly. + EXPECT_EQ(QuicTime::Delta::Infinite(), + visitor_.ack_frames_[0]->ack_delay_time); +} + +TEST_P(QuicFramerTest, KeyUpdate) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=0, key=1: no key update. + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=1, key=2: key update should have + // occurred. + ASSERT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(KeyUpdateReason::kRemote, visitor_.key_update_reasons_[0]); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed another valid packet with phase=1, key=2: no key update. + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process another key update. + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 2, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + ASSERT_EQ(2u, visitor_.key_update_count()); + EXPECT_EQ(KeyUpdateReason::kRemote, visitor_.key_update_reasons_[1]); + EXPECT_EQ(2, visitor_.derive_next_key_count_); + EXPECT_EQ(3, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateOldPacketAfterUpdate) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 1. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 0. (Receiving packet from previous phase + // after packet from new phase was received.) + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateOldPacketAfterDiscardPreviousOneRttKeys) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 1. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Discard keys for previous key phase. + framer_.DiscardPreviousOneRttKeys(); + + // Process packet N+1 with phase 0. (Receiving packet from previous phase + // after packet from new phase was received.) + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should not decrypt and key update count should not change. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdatePacketsOutOfOrder) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 1. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 1. (Receiving packet from new phase out of + // order.) + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(2, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateWrongKey) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 0, false)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=0, key=1: no key update. + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + EXPECT_EQ(0u, framer_.PotentialPeerKeyUpdateAttemptCount()); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 2, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=1 but key=3, should not process and should not cause key + // update, but next decrypter key should have been created to attempt to + // decode it. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + EXPECT_EQ(1u, framer_.PotentialPeerKeyUpdateAttemptCount()); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=1 but key=1, should not process and should not cause key + // update. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + EXPECT_EQ(2u, framer_.PotentialPeerKeyUpdateAttemptCount()); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=0 but key=2, should not process and should not cause key + // update. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + EXPECT_EQ(2u, framer_.PotentialPeerKeyUpdateAttemptCount()); + + header.packet_number += 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet with phase=0 and key=0, should process and reset + // potential_peer_key_update_attempt_count_. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + EXPECT_EQ(0u, framer_.PotentialPeerKeyUpdateAttemptCount()); +} + +TEST_P(QuicFramerTest, KeyUpdateReceivedWhenNotEnabled) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 1, true)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Received a packet with key phase updated even though framer hasn't had key + // update enabled (SetNextOneRttCrypters never called). Should fail to + // process. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(0u, visitor_.key_update_count()); + EXPECT_EQ(0, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateLocallyInitiated) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + EXPECT_TRUE(framer_.DoKeyUpdate(KeyUpdateReason::kLocalForTests)); + // Key update count should be updated, but haven't received packet from peer + // with new key phase. + ASSERT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(KeyUpdateReason::kLocalForTests, visitor_.key_update_reasons_[0]); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 1. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, 1, true)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change and + // OnDecryptedFirstPacketInKeyPhase should have been called. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N-1 with phase 0. (Receiving packet from previous phase + // after packet from new phase was received.) + header.packet_number = kPacketNumber - 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 0 and key 1. This should not decrypt even + // though it's using the previous key, since the packet number is higher than + // a packet number received using the current key. + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should not decrypt and key update count should not change. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(2, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateLocallyInitiatedReceivedOldPacket) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + EXPECT_TRUE(framer_.DoKeyUpdate(KeyUpdateReason::kLocalForTests)); + // Key update count should be updated, but haven't received packet + // from peer with new key phase. + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + // Process packet N with phase 0. (Receiving packet from previous phase + // after locally initiated key update, but before any packet from new phase + // was received.) + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted = + EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change and + // OnDecryptedFirstPacketInKeyPhase should not have been called since the + // packet was from the previous key phase. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(0, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+1 with phase 1. + header.packet_number = kPacketNumber + 1; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 1, true); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should decrypt and key update count should not change, but + // OnDecryptedFirstPacketInKeyPhase should have been called. + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); + + // Process packet N+2 with phase 0 and key 1. This should not decrypt even + // though it's using the previous key, since the packet number is higher than + // a packet number received using the current key. + header.packet_number = kPacketNumber + 2; + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + data = BuildDataPacket(header, frames); + ASSERT_TRUE(data != nullptr); + encrypted = EncryptPacketWithTagAndPhase(*data, 0, false); + ASSERT_TRUE(encrypted); + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + // Packet should not decrypt and key update count should not change. + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(2, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, KeyUpdateOnFirstReceivedPacket) { + if (!framer_.version().UsesTls()) { + // Key update is only used in QUIC+TLS. + return; + } + ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); + // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter + // instead of TestDecrypter. + framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(/*key=*/0)); + framer_.SetKeyUpdateSupportForConnection(true); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = QuicPacketNumber(123); + + QuicFrames frames = {QuicFrame(QuicPaddingFrame())}; + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr data(BuildDataPacket(header, frames)); + ASSERT_TRUE(data != nullptr); + std::unique_ptr encrypted( + EncryptPacketWithTagAndPhase(*data, /*tag=*/1, /*phase=*/true)); + ASSERT_TRUE(encrypted); + + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + // Processed valid packet with phase=1, key=1: do key update. + EXPECT_EQ(1u, visitor_.key_update_count()); + EXPECT_EQ(1, visitor_.derive_next_key_count_); + EXPECT_EQ(1, visitor_.decrypted_first_packet_in_key_phase_count_); +} + +TEST_P(QuicFramerTest, ErrorWhenUnexpectedFrameTypeEncountered) { + if (!VersionHasIetfQuicFrames(framer_.transport_version()) || + !QuicVersionHasLongHeaderLengths(framer_.transport_version()) || + !framer_.version().HasLongHeaderLengths()) { + return; + } + SetDecrypterLevel(ENCRYPTION_ZERO_RTT); + // clang-format off + unsigned char packet[] = { + // public flags (long header with packet type ZERO_RTT_PROTECTED and + // 4-byte packet number) + 0xD3, + // version + QUIC_VERSION_BYTES, + // destination connection ID length + 0x08, + // destination connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // source connection ID length + 0x08, + // source connection ID + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x11, + // long header packet length + 0x05, + // packet number + 0x12, 0x34, 0x56, 0x00, + // unexpected ietf ack frame type in 0-RTT packet + 0x02, + }; + // clang-format on + + ReviseFirstByteByVersion(packet); + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + + EXPECT_FALSE(framer_.ProcessPacket(encrypted)); + + EXPECT_THAT(framer_.error(), IsError(IETF_QUIC_PROTOCOL_VIOLATION)); + EXPECT_EQ( + "IETF frame type IETF_ACK is unexpected at encryption level " + "ENCRYPTION_ZERO_RTT", + framer_.detailed_error()); +} + +TEST_P(QuicFramerTest, ShortHeaderWithNonDefaultConnectionIdLength) { + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + unsigned char packet[kMaxIncomingPacketSize + 1] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0x28, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0x48, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + MockConnectionIdGenerator generator; + if (version_.HasIetfInvariantHeader()) { + EXPECT_CALL(generator, ConnectionIdLength(0x28)).WillOnce(Return(9)); + } else { + packet[0] = 0x0a; + EXPECT_CALL(generator, ConnectionIdLength(_)).Times(0); + } + unsigned char* p = packet; + size_t p_size = ABSL_ARRAYSIZE(packet); + + const size_t header_size = GetPacketHeaderSize( + framer_.transport_version(), kPacket8ByteConnectionId + 1, + kPacket0ByteConnectionId, !kIncludeVersion, + !kIncludeDiversificationNonce, PACKET_4BYTE_PACKET_NUMBER, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0) + 1; + // Add one because it's a 9 byte connection ID. + + memset(p + header_size, 0, kMaxIncomingPacketSize - header_size); + + QuicEncryptedPacket encrypted(AsChars(p), p_size, false); + PacketHeaderFormat format; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_flag; + QuicConnectionId destination_connection_id, source_connection_id; + QuicVersionLabel version_label; + std::string detailed_error; + bool use_length_prefix; + absl::optional retry_token; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + EXPECT_EQ(QUIC_NO_ERROR, + QuicFramer::ParsePublicHeaderDispatcherShortHeaderLengthUnknown( + encrypted, &format, &long_packet_type, &version_flag, + &use_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error, generator)); + if (version_.HasIetfInvariantHeader()) { + EXPECT_EQ(format, IETF_QUIC_SHORT_HEADER_PACKET); + EXPECT_EQ(destination_connection_id.length(), 9); + } else { + EXPECT_EQ(format, GOOGLE_QUIC_PACKET); + EXPECT_EQ(destination_connection_id.length(), 8); + } + EXPECT_EQ(long_packet_type, INVALID_PACKET_TYPE); + EXPECT_FALSE(version_flag); + EXPECT_FALSE(use_length_prefix); + EXPECT_EQ(version_label, 0); + EXPECT_EQ(parsed_version, UnsupportedQuicVersion()); + EXPECT_EQ(source_connection_id.length(), 0); + EXPECT_FALSE(retry_token.has_value()); + EXPECT_EQ(detailed_error, ""); +} + +TEST_P(QuicFramerTest, ReportEcnCountsIfPresent) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + + QuicPacketHeader header; + header.destination_connection_id = FramerTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + for (bool ecn_marks : { false, true }) { + // Add some padding, because TestEncrypter doesn't add an authentication + // tag. For a small packet, this will cause QuicFramer to fail to get a + // header protection sample. + QuicPaddingFrame padding_frame(kTagSize); + // Create a packet with just an ack. + QuicAckFrame ack_frame = InitAckFrame(5); + if (ecn_marks) { + ack_frame.ecn_counters = QuicEcnCounts(100, 10000, 1000000); + } else { + ack_frame.ecn_counters = absl::nullopt; + } + QuicFrames frames = {QuicFrame(padding_frame), QuicFrame(&ack_frame)}; + // Build an ACK packet. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); + std::unique_ptr raw_ack_packet(BuildDataPacket(header, frames)); + ASSERT_TRUE(raw_ack_packet != nullptr); + char buffer[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer_.EncryptPayload(ENCRYPTION_INITIAL, header.packet_number, + *raw_ack_packet, buffer, kMaxOutgoingPacketSize); + ASSERT_NE(0u, encrypted_length); + // Now make sure we can turn our ack packet back into an ack frame. + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + MockFramerVisitor visitor; + framer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnPacket()).Times(1); + EXPECT_CALL(visitor, OnUnauthenticatedPublicHeader(_)) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(visitor, OnUnauthenticatedHeader(_)) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(visitor, OnPacketHeader(_)).Times(1); + EXPECT_CALL(visitor, OnDecryptedPacket(_, _)).Times(1); + EXPECT_CALL(visitor, OnAckFrameStart(_, _)).Times(1).WillOnce(Return(true)); + EXPECT_CALL(visitor, OnAckRange(_, _)).Times(1).WillOnce(Return(true)); + EXPECT_CALL(visitor, OnAckFrameEnd(_, ack_frame.ecn_counters)) + .Times(1).WillOnce(Return(true)); + EXPECT_CALL(visitor, OnPacketComplete()).Times(1); + ASSERT_TRUE(framer_.ProcessPacket( + QuicEncryptedPacket(buffer, encrypted_length, false))); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_idle_network_detector.cc b/quiche/quic/core/quic_idle_network_detector.cc new file mode 100644 index 000000000000..5a63ebc2ade3 --- /dev/null +++ b/quiche/quic/core/quic_idle_network_detector.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_idle_network_detector.h" + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +namespace { + +class AlarmDelegate : public QuicAlarm::DelegateWithContext { + public: + explicit AlarmDelegate(QuicIdleNetworkDetector* detector, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), detector_(detector) {} + AlarmDelegate(const AlarmDelegate&) = delete; + AlarmDelegate& operator=(const AlarmDelegate&) = delete; + + void OnAlarm() override { detector_->OnAlarm(); } + + private: + QuicIdleNetworkDetector* detector_; +}; + +} // namespace + +QuicIdleNetworkDetector::QuicIdleNetworkDetector( + Delegate* delegate, QuicTime now, QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, QuicConnectionContext* context) + : delegate_(delegate), + start_time_(now), + handshake_timeout_(QuicTime::Delta::Infinite()), + time_of_last_received_packet_(now), + time_of_first_packet_sent_after_receiving_(QuicTime::Zero()), + idle_network_timeout_(QuicTime::Delta::Infinite()), + bandwidth_update_timeout_(QuicTime::Delta::Infinite()), + alarm_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)) {} + +void QuicIdleNetworkDetector::OnAlarm() { + if (!bandwidth_update_timeout_.IsInfinite()) { + QUICHE_DCHECK(handshake_timeout_.IsInfinite()); + bandwidth_update_timeout_ = QuicTime::Delta::Infinite(); + SetAlarm(); + delegate_->OnBandwidthUpdateTimeout(); + return; + } + if (handshake_timeout_.IsInfinite()) { + delegate_->OnIdleNetworkDetected(); + return; + } + if (idle_network_timeout_.IsInfinite()) { + delegate_->OnHandshakeTimeout(); + return; + } + if (last_network_activity_time() + idle_network_timeout_ > + start_time_ + handshake_timeout_) { + delegate_->OnHandshakeTimeout(); + return; + } + delegate_->OnIdleNetworkDetected(); +} + +void QuicIdleNetworkDetector::SetTimeouts( + QuicTime::Delta handshake_timeout, QuicTime::Delta idle_network_timeout) { + handshake_timeout_ = handshake_timeout; + idle_network_timeout_ = idle_network_timeout; + bandwidth_update_timeout_ = QuicTime::Delta::Infinite(); + + if (GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2) && + handshake_timeout_.IsInfinite()) { + QUIC_RESTART_FLAG_COUNT_N( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2, 1, 3); + bandwidth_update_timeout_ = idle_network_timeout_ * 0.5; + } + + SetAlarm(); +} + +void QuicIdleNetworkDetector::StopDetection() { + alarm_->PermanentCancel(); + handshake_timeout_ = QuicTime::Delta::Infinite(); + idle_network_timeout_ = QuicTime::Delta::Infinite(); + handshake_timeout_ = QuicTime::Delta::Infinite(); + stopped_ = true; +} + +void QuicIdleNetworkDetector::OnPacketSent(QuicTime now, + QuicTime::Delta pto_delay) { + if (time_of_first_packet_sent_after_receiving_ > + time_of_last_received_packet_) { + return; + } + time_of_first_packet_sent_after_receiving_ = + std::max(time_of_first_packet_sent_after_receiving_, now); + if (shorter_idle_timeout_on_sent_packet_) { + MaybeSetAlarmOnSentPacket(pto_delay); + return; + } + + SetAlarm(); +} + +void QuicIdleNetworkDetector::OnPacketReceived(QuicTime now) { + time_of_last_received_packet_ = std::max(time_of_last_received_packet_, now); + + SetAlarm(); +} + +void QuicIdleNetworkDetector::SetAlarm() { + if (stopped_) { + // TODO(wub): If this QUIC_BUG fires, it indicates a problem in the + // QuicConnection, which somehow called this function while disconnected. + // That problem needs to be fixed. + QUIC_BUG(quic_idle_detector_set_alarm_after_stopped) + << "SetAlarm called after stopped"; + return; + } + // Set alarm to the nearer deadline. + QuicTime new_deadline = QuicTime::Zero(); + if (!handshake_timeout_.IsInfinite()) { + new_deadline = start_time_ + handshake_timeout_; + } + if (!idle_network_timeout_.IsInfinite()) { + const QuicTime idle_network_deadline = GetIdleNetworkDeadline(); + if (new_deadline.IsInitialized()) { + new_deadline = std::min(new_deadline, idle_network_deadline); + } else { + new_deadline = idle_network_deadline; + } + } + if (!bandwidth_update_timeout_.IsInfinite()) { + new_deadline = std::min(new_deadline, GetBandwidthUpdateDeadline()); + } + alarm_->Update(new_deadline, kAlarmGranularity); +} + +void QuicIdleNetworkDetector::MaybeSetAlarmOnSentPacket( + QuicTime::Delta pto_delay) { + QUICHE_DCHECK(shorter_idle_timeout_on_sent_packet_); + if (!handshake_timeout_.IsInfinite() || !alarm_->IsSet()) { + SetAlarm(); + return; + } + // Make sure connection will be alive for another PTO. + const QuicTime deadline = alarm_->deadline(); + const QuicTime min_deadline = last_network_activity_time() + pto_delay; + if (deadline > min_deadline) { + return; + } + alarm_->Update(min_deadline, kAlarmGranularity); +} + +QuicTime QuicIdleNetworkDetector::GetIdleNetworkDeadline() const { + if (idle_network_timeout_.IsInfinite()) { + return QuicTime::Zero(); + } + return last_network_activity_time() + idle_network_timeout_; +} + +QuicTime QuicIdleNetworkDetector::GetBandwidthUpdateDeadline() const { + QUICHE_DCHECK(!bandwidth_update_timeout_.IsInfinite()); + return last_network_activity_time() + bandwidth_update_timeout_; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_idle_network_detector.h b/quiche/quic/core/quic_idle_network_detector.h new file mode 100644 index 000000000000..ca4314825452 --- /dev/null +++ b/quiche/quic/core/quic_idle_network_detector.h @@ -0,0 +1,130 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_IDLE_NETWORK_DETECTOR_H_ +#define QUICHE_QUIC_CORE_QUIC_IDLE_NETWORK_DETECTOR_H_ + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace test { +class QuicConnectionPeer; +class QuicIdleNetworkDetectorTestPeer; +} // namespace test + +// QuicIdleNetworkDetector detects handshake timeout and idle network timeout. +// Handshake timeout detection is disabled after handshake completes. Idle +// network deadline is extended by network activity (e.g., sending or receiving +// packets). +class QUIC_EXPORT_PRIVATE QuicIdleNetworkDetector { + public: + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + // Called when handshake times out. + virtual void OnHandshakeTimeout() = 0; + + // Called when idle network has been detected. + virtual void OnIdleNetworkDetected() = 0; + + // Called when bandwidth update alarms. + virtual void OnBandwidthUpdateTimeout() = 0; + }; + + QuicIdleNetworkDetector(Delegate* delegate, QuicTime now, + QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, + QuicConnectionContext* context); + + void OnAlarm(); + + // Called to set handshake_timeout_ and idle_network_timeout_. + void SetTimeouts(QuicTime::Delta handshake_timeout, + QuicTime::Delta idle_network_timeout); + + // Stop the detection once and for all. + void StopDetection(); + + // Called when a packet gets sent. + void OnPacketSent(QuicTime now, QuicTime::Delta pto_delay); + + // Called when a packet gets received. + void OnPacketReceived(QuicTime now); + + void enable_shorter_idle_timeout_on_sent_packet() { + shorter_idle_timeout_on_sent_packet_ = true; + } + + QuicTime::Delta handshake_timeout() const { return handshake_timeout_; } + + QuicTime time_of_last_received_packet() const { + return time_of_last_received_packet_; + } + + QuicTime last_network_activity_time() const { + return std::max(time_of_last_received_packet_, + time_of_first_packet_sent_after_receiving_); + } + + QuicTime::Delta idle_network_timeout() const { return idle_network_timeout_; } + + QuicTime::Delta bandwidth_update_timeout() const { + return bandwidth_update_timeout_; + } + + QuicTime GetIdleNetworkDeadline() const; + + private: + friend class test::QuicConnectionPeer; + friend class test::QuicIdleNetworkDetectorTestPeer; + + void SetAlarm(); + + void MaybeSetAlarmOnSentPacket(QuicTime::Delta pto_delay); + + QuicTime GetBandwidthUpdateDeadline() const; + + Delegate* delegate_; // Not owned. + + // Start time of the detector, handshake deadline = start_time_ + + // handshake_timeout_. + const QuicTime start_time_; + + // Handshake timeout. Infinite means handshake has completed. + QuicTime::Delta handshake_timeout_; + + // Time that last packet is received for this connection. Initialized to + // start_time_. + QuicTime time_of_last_received_packet_; + + // Time that the first packet gets sent after the received packet. idle + // network deadline = std::max(time_of_last_received_packet_, + // time_of_first_packet_sent_after_receiving_) + idle_network_timeout_. + // Initialized to 0. + QuicTime time_of_first_packet_sent_after_receiving_; + + // Idle network timeout. Infinite means no idle network timeout. + QuicTime::Delta idle_network_timeout_; + + // Bandwidth update timeout. Infinite means no bandwidth update timeout. + QuicTime::Delta bandwidth_update_timeout_; + + QuicArenaScopedPtr alarm_; + + bool shorter_idle_timeout_on_sent_packet_ = false; + + // Whether |StopDetection| has been called. + bool stopped_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_IDLE_NETWORK_DETECTOR_H_ diff --git a/quiche/quic/core/quic_idle_network_detector_test.cc b/quiche/quic/core/quic_idle_network_detector_test.cc new file mode 100644 index 000000000000..a139eb25ec14 --- /dev/null +++ b/quiche/quic/core/quic_idle_network_detector_test.cc @@ -0,0 +1,282 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_idle_network_detector.h" + +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +class QuicIdleNetworkDetectorTestPeer { + public: + static QuicAlarm* GetAlarm(QuicIdleNetworkDetector* detector) { + return detector->alarm_.get(); + } +}; + +namespace { + +class MockDelegate : public QuicIdleNetworkDetector::Delegate { + public: + MOCK_METHOD(void, OnHandshakeTimeout, (), (override)); + MOCK_METHOD(void, OnIdleNetworkDetected, (), (override)); + MOCK_METHOD(void, OnBandwidthUpdateTimeout, (), (override)); +}; + +class QuicIdleNetworkDetectorTest : public QuicTest { + public: + QuicIdleNetworkDetectorTest() { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + detector_ = std::make_unique( + &delegate_, clock_.Now(), &arena_, &alarm_factory_, + /*context=*/nullptr); + alarm_ = static_cast( + QuicIdleNetworkDetectorTestPeer::GetAlarm(detector_.get())); + } + + protected: + testing::StrictMock delegate_; + QuicConnectionArena arena_; + MockAlarmFactory alarm_factory_; + + std::unique_ptr detector_; + + MockAlarmFactory::TestAlarm* alarm_; + MockClock clock_; +}; + +TEST_F(QuicIdleNetworkDetectorTest, + IdleNetworkDetectedBeforeHandshakeCompletes) { + EXPECT_FALSE(alarm_->IsSet()); + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(30), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(20), + alarm_->deadline()); + + // No network activity for 20s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(20)); + EXPECT_CALL(delegate_, OnIdleNetworkDetected()); + alarm_->Fire(); +} + +TEST_F(QuicIdleNetworkDetectorTest, HandshakeTimeout) { + EXPECT_FALSE(alarm_->IsSet()); + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(30), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + EXPECT_TRUE(alarm_->IsSet()); + + // Has network activity after 15s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(15)); + detector_->OnPacketReceived(clock_.Now()); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(15), + alarm_->deadline()); + // Handshake does not complete for another 15s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(15)); + EXPECT_CALL(delegate_, OnHandshakeTimeout()); + alarm_->Fire(); +} + +TEST_F(QuicIdleNetworkDetectorTest, + IdleNetworkDetectedAfterHandshakeCompletes) { + EXPECT_FALSE(alarm_->IsSet()); + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(30), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(20), + alarm_->deadline()); + + // Handshake completes in 200ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(200)); + detector_->OnPacketReceived(clock_.Now()); + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::Infinite(), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(600)); + if (!GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2)) { + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(600), + alarm_->deadline()); + + // No network activity for 600s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(600)); + EXPECT_CALL(delegate_, OnIdleNetworkDetected()); + alarm_->Fire(); + return; + } + + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(300), + alarm_->deadline()); + + // No network activity for 300s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(300)); + EXPECT_CALL(delegate_, OnBandwidthUpdateTimeout()); + alarm_->Fire(); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(300), + alarm_->deadline()); + + // No network activity for 600s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(300)); + EXPECT_CALL(delegate_, OnIdleNetworkDetected()); + alarm_->Fire(); +} + +TEST_F(QuicIdleNetworkDetectorTest, + DoNotExtendIdleDeadlineOnConsecutiveSentPackets) { + EXPECT_FALSE(alarm_->IsSet()); + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(30), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + EXPECT_TRUE(alarm_->IsSet()); + + // Handshake completes in 200ms. + const bool enable_sending_bandwidth_estimate_when_network_idle = + GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(200)); + detector_->OnPacketReceived(clock_.Now()); + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::Infinite(), + enable_sending_bandwidth_estimate_when_network_idle + ? QuicTime::Delta::FromSeconds(1200) + : QuicTime::Delta::FromSeconds(600)); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(600), + alarm_->deadline()); + + // Sent packets after 200ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(200)); + detector_->OnPacketSent(clock_.Now(), QuicTime::Delta::Zero()); + const QuicTime packet_sent_time = clock_.Now(); + EXPECT_EQ(packet_sent_time + QuicTime::Delta::FromSeconds(600), + alarm_->deadline()); + + // Sent another packet after 200ms + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(200)); + detector_->OnPacketSent(clock_.Now(), QuicTime::Delta::Zero()); + // Verify network deadline does not extend. + EXPECT_EQ(packet_sent_time + QuicTime::Delta::FromSeconds(600), + alarm_->deadline()); + + if (!enable_sending_bandwidth_estimate_when_network_idle) { + // No network activity for 600s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(600) - + QuicTime::Delta::FromMilliseconds(200)); + EXPECT_CALL(delegate_, OnIdleNetworkDetected()); + alarm_->Fire(); + return; + } + + // Bandwidth update times out after no network activity for 600s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(600) - + QuicTime::Delta::FromMilliseconds(200)); + EXPECT_CALL(delegate_, OnBandwidthUpdateTimeout()); + alarm_->Fire(); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(packet_sent_time + QuicTime::Delta::FromSeconds(1200), + alarm_->deadline()); + + // Network idle time out after no network activity for 1200s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1200) - + QuicTime::Delta::FromMilliseconds(600)); + EXPECT_CALL(delegate_, OnIdleNetworkDetected()); + alarm_->Fire(); +} + +TEST_F(QuicIdleNetworkDetectorTest, ShorterIdleTimeoutOnSentPacket) { + detector_->enable_shorter_idle_timeout_on_sent_packet(); + QuicTime::Delta idle_network_timeout = QuicTime::Delta::Zero(); + if (GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2)) { + idle_network_timeout = QuicTime::Delta::FromSeconds(60); + } else { + idle_network_timeout = QuicTime::Delta::FromSeconds(30); + } + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::Infinite(), idle_network_timeout); + EXPECT_TRUE(alarm_->IsSet()); + const QuicTime deadline = alarm_->deadline(); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(30), deadline); + + // Send a packet after 15s and 2s PTO delay. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(15)); + detector_->OnPacketSent(clock_.Now(), QuicTime::Delta::FromSeconds(2)); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm does not get extended because deadline is > PTO delay. + EXPECT_EQ(deadline, alarm_->deadline()); + + // Send another packet near timeout and 2 s PTO delay. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(14)); + detector_->OnPacketSent(clock_.Now(), QuicTime::Delta::FromSeconds(2)); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm does not get extended although it is shorter than PTO. + EXPECT_EQ(deadline, alarm_->deadline()); + + // Receive a packet after 1s. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + detector_->OnPacketReceived(clock_.Now()); + EXPECT_TRUE(alarm_->IsSet()); + // Verify idle timeout gets extended by 30s. + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(30), + alarm_->deadline()); + + // Send a packet near timeout. + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(29)); + detector_->OnPacketSent(clock_.Now(), QuicTime::Delta::FromSeconds(2)); + EXPECT_TRUE(alarm_->IsSet()); + // Verify idle timeout gets extended by 1s. + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(2), alarm_->deadline()); +} + +TEST_F(QuicIdleNetworkDetectorTest, NoAlarmAfterStopped) { + detector_->StopDetection(); + + EXPECT_QUIC_BUG( + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(30), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)), + "SetAlarm called after stopped"); + EXPECT_FALSE(alarm_->IsSet()); +} + +TEST_F(QuicIdleNetworkDetectorTest, + ResetBandwidthTimeoutWhenHandshakeTimeoutIsSet) { + if (!GetQuicRestartFlag( + quic_enable_sending_bandwidth_estimate_when_network_idle_v2)) { + return; + } + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::Infinite(), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + // The deadline is set based on the bandwidth timeout. + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(10), + alarm_->deadline()); + + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(15), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + // Bandwidth timeout is reset and the deadline is set based on the handshake + // timeout. + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(15), + alarm_->deadline()); + + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::Infinite(), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)); + // The deadline is set based on the bandwidth timeout. + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(10), + alarm_->deadline()); +} + +} // namespace + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_interval.h b/quiche/quic/core/quic_interval.h new file mode 100644 index 000000000000..e291c53cae0b --- /dev/null +++ b/quiche/quic/core/quic_interval.h @@ -0,0 +1,381 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_INTERVAL_H_ +#define QUICHE_QUIC_CORE_QUIC_INTERVAL_H_ + +// An QuicInterval is a data structure used to represent a contiguous, +// mutable range over an ordered type T. Supported operations include testing a +// value to see whether it is included in the QuicInterval, comparing two +// QuicIntervals, and performing their union, intersection, and difference. For +// the purposes of this library, an "ordered type" is any type that induces a +// total order on its values via its less-than operator (operator<()). Examples +// of such types are basic arithmetic types like int and double as well as class +// types like string. +// +// An QuicInterval is represented using the usual C++ STL convention, namely +// as the half-open QuicInterval [min, max). A point p is considered to be +// contained in the QuicInterval iff p >= min && p < max. One consequence of +// this definition is that for any non-empty QuicInterval, min is contained in +// the QuicInterval but max is not. There is no canonical representation for the +// empty QuicInterval; rather, any QuicInterval where max <= min is regarded as +// empty. As a consequence, two empty QuicIntervals will still compare as equal +// despite possibly having different underlying min() or max() values. Also +// beware of the terminology used here: the library uses the terms "min" and +// "max" rather than "begin" and "end" as is conventional for the STL. +// +// T is required to be default- and copy-constructable, to have an assignment +// operator, and the full complement of comparison operators (<, <=, ==, !=, >=, +// >). A difference operator (operator-()) is required if +// QuicInterval::Length is used. +// +// QuicInterval supports operator==. Two QuicIntervals are considered equal if +// either they are both empty or if their corresponding min and max fields +// compare equal. QuicInterval also provides an operator<. Unfortunately, +// operator< is currently buggy because its behavior is inconsistent with +// operator==: two empty ranges with different representations may be regarded +// as equal by operator== but regarded as different by operator<. Bug 9240050 +// has been created to address this. +// +// +// Examples: +// QuicInterval r1(0, 100); // The QuicInterval [0, 100). +// EXPECT_TRUE(r1.Contains(0)); +// EXPECT_TRUE(r1.Contains(50)); +// EXPECT_FALSE(r1.Contains(100)); // 100 is just outside the QuicInterval. +// +// QuicInterval r2(50, 150); // The QuicInterval [50, 150). +// EXPECT_TRUE(r1.Intersects(r2)); +// EXPECT_FALSE(r1.Contains(r2)); +// EXPECT_TRUE(r1.IntersectWith(r2)); // Mutates r1. +// EXPECT_EQ(QuicInterval(50, 100), r1); // r1 is now [50, 100). +// +// QuicInterval r3(1000, 2000); // The QuicInterval [1000, 2000). +// EXPECT_TRUE(r1.IntersectWith(r3)); // Mutates r1. +// EXPECT_TRUE(r1.Empty()); // Now r1 is empty. +// EXPECT_FALSE(r1.Contains(r1.min())); // e.g. doesn't contain its own min. + +#include + +#include +#include +#include +#include +#include + +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +template +class QUIC_NO_EXPORT QuicInterval { + private: + // Type trait for deriving the return type for QuicInterval::Length. If + // operator-() is not defined for T, then the return type is void. This makes + // the signature for Length compile so that the class can be used for such T, + // but code that calls Length would still generate a compilation error. + template + class QUIC_NO_EXPORT DiffTypeOrVoid { + private: + template + static auto f(const V* v) -> decltype(*v - *v); + template + static void f(...); + + public: + using type = typename std::decay(nullptr))>::type; + }; + + public: + // Construct an QuicInterval representing an empty QuicInterval. + QuicInterval() : min_(), max_() {} + + // Construct an QuicInterval representing the QuicInterval [min, max). If min + // < max, the constructed object will represent the non-empty QuicInterval + // containing all values from min up to (but not including) max. On the other + // hand, if min >= max, the constructed object will represent the empty + // QuicInterval. + QuicInterval(const T& min, const T& max) : min_(min), max_(max) {} + + template ::value && + std::is_convertible::value>::type> + QuicInterval(U1&& min, U2&& max) + : min_(std::forward(min)), max_(std::forward(max)) {} + + const T& min() const { return min_; } + const T& max() const { return max_; } + void SetMin(const T& t) { min_ = t; } + void SetMax(const T& t) { max_ = t; } + + void Set(const T& min, const T& max) { + SetMin(min); + SetMax(max); + } + + void Clear() { *this = {}; } + + bool Empty() const { return min() >= max(); } + + // Returns the length of this QuicInterval. The value returned is zero if + // Empty() is true; otherwise the value returned is max() - min(). + typename DiffTypeOrVoid::type Length() const { + return (Empty() ? min() : max()) - min(); + } + + // Returns true iff t >= min() && t < max(). + bool Contains(const T& t) const { return min() <= t && max() > t; } + + // Returns true iff *this and i are non-empty, and *this includes i. "*this + // includes i" means that for all t, if i.Contains(t) then this->Contains(t). + // Note the unintuitive consequence of this definition: this method always + // returns false when i is the empty QuicInterval. + bool Contains(const QuicInterval& i) const { + return !Empty() && !i.Empty() && min() <= i.min() && max() >= i.max(); + } + + // Returns true iff there exists some point t for which this->Contains(t) && + // i.Contains(t) evaluates to true, i.e. if the intersection is non-empty. + bool Intersects(const QuicInterval& i) const { + return !Empty() && !i.Empty() && min() < i.max() && max() > i.min(); + } + + // Returns true iff there exists some point t for which this->Contains(t) && + // i.Contains(t) evaluates to true, i.e. if the intersection is non-empty. + // Furthermore, if the intersection is non-empty and the out pointer is not + // null, this method stores the calculated intersection in *out. + bool Intersects(const QuicInterval& i, QuicInterval* out) const; + + // Sets *this to be the intersection of itself with i. Returns true iff + // *this was modified. + bool IntersectWith(const QuicInterval& i); + + // Returns true iff this and other have disjoint closures. For nonempty + // intervals, that means there is at least one point between this and other. + // Roughly speaking that means the intervals don't intersect, and they are not + // adjacent. Empty intervals are always separated from any other interval. + bool Separated(const QuicInterval& other) const { + if (Empty() || other.Empty()) return true; + return other.max() < min() || max() < other.min(); + } + + // Calculates the smallest QuicInterval containing both *this i, and updates + // *this to represent that QuicInterval, and returns true iff *this was + // modified. + bool SpanningUnion(const QuicInterval& i); + + // Determines the difference between two QuicIntervals by finding all points + // that are contained in *this but not in i, coalesces those points into the + // largest possible contiguous QuicIntervals, and appends those QuicIntervals + // to the *difference vector. Intuitively this can be thought of as "erasing" + // i from *this. This will either completely erase *this (leaving nothing + // behind), partially erase some of *this from the left or right side (leaving + // some residual behind), or erase a hole in the middle of *this (leaving + // behind an QuicInterval on either side). Therefore, 0, 1, or 2 QuicIntervals + // will be appended to *difference. The method returns true iff the + // intersection of *this and i is non-empty. The caller owns the vector and + // the QuicInterval* pointers inside it. The difference vector is required to + // be non-null. + bool Difference(const QuicInterval& i, + std::vector* difference) const; + + // Determines the difference between two QuicIntervals as in + // Difference(QuicInterval&, vector*), but stores the results directly in out + // parameters rather than dynamically allocating an QuicInterval* and + // appending it to a vector. If two results are generated, the one with the + // smaller value of min() will be stored in *lo and the other in *hi. + // Otherwise (if fewer than two results are generated), unused arguments will + // be set to the empty QuicInterval (it is possible that *lo will be empty and + // *hi non-empty). The method returns true iff the intersection of *this and i + // is non-empty. + bool Difference(const QuicInterval& i, QuicInterval* lo, + QuicInterval* hi) const; + + friend bool operator==(const QuicInterval& a, const QuicInterval& b) { + bool ae = a.Empty(); + bool be = b.Empty(); + if (ae && be) return true; // All empties are equal. + if (ae != be) return false; // Empty cannot equal nonempty. + return a.min() == b.min() && a.max() == b.max(); + } + + friend bool operator!=(const QuicInterval& a, const QuicInterval& b) { + return !(a == b); + } + + // Defines a comparator which can be used to induce an order on QuicIntervals, + // so that, for example, they can be stored in an ordered container such as + // std::set. The ordering is arbitrary, but does provide the guarantee that, + // for non-empty QuicIntervals X and Y, if X contains Y, then X <= Y. + // TODO(kosak): The current implementation of this comparator has a problem + // because the ordering it induces is inconsistent with that of Equals(). In + // particular, this comparator does not properly consider all empty + // QuicIntervals equivalent. Bug 9240050 has been created to track this. + friend bool operator<(const QuicInterval& a, const QuicInterval& b) { + return a.min() < b.min() || (!(b.min() < a.min()) && b.max() < a.max()); + } + + private: + T min_; // Inclusive lower bound. + T max_; // Exclusive upper bound. +}; + +// Constructs an QuicInterval by deducing the types from the function arguments. +template +QuicInterval MakeQuicInterval(T&& lhs, T&& rhs) { + return QuicInterval(std::forward(lhs), std::forward(rhs)); +} + +// Note: ideally we'd use +// decltype(out << "[" << i.min() << ", " << i.max() << ")") +// as return type of the function, but as of July 2017 this triggers g++ +// "sorry, unimplemented: string literal in function template signature" error. +template +auto operator<<(std::ostream& out, const QuicInterval& i) + -> decltype(out << i.min()) { + return out << "[" << i.min() << ", " << i.max() << ")"; +} + +//============================================================================== +// Implementation details: Clients can stop reading here. + +template +bool QuicInterval::Intersects(const QuicInterval& i, + QuicInterval* out) const { + if (!Intersects(i)) return false; + if (out != nullptr) { + *out = QuicInterval(std::max(min(), i.min()), std::min(max(), i.max())); + } + return true; +} + +template +bool QuicInterval::IntersectWith(const QuicInterval& i) { + if (Empty()) return false; + bool modified = false; + if (i.min() > min()) { + SetMin(i.min()); + modified = true; + } + if (i.max() < max()) { + SetMax(i.max()); + modified = true; + } + return modified; +} + +template +bool QuicInterval::SpanningUnion(const QuicInterval& i) { + if (i.Empty()) return false; + if (Empty()) { + *this = i; + return true; + } + bool modified = false; + if (i.min() < min()) { + SetMin(i.min()); + modified = true; + } + if (i.max() > max()) { + SetMax(i.max()); + modified = true; + } + return modified; +} + +template +bool QuicInterval::Difference(const QuicInterval& i, + std::vector* difference) const { + if (Empty()) { + // - = + return false; + } + if (i.Empty()) { + // - = + difference->push_back(new QuicInterval(*this)); + return false; + } + if (min() < i.max() && min() >= i.min() && max() > i.max()) { + // [------ this ------) + // [------ i ------) + // [-- result ---) + difference->push_back(new QuicInterval(i.max(), max())); + return true; + } + if (max() > i.min() && max() <= i.max() && min() < i.min()) { + // [------ this ------) + // [------ i ------) + // [- result -) + difference->push_back(new QuicInterval(min(), i.min())); + return true; + } + if (min() < i.min() && max() > i.max()) { + // [------- this --------) + // [---- i ----) + // [ R1 ) [ R2 ) + // There are two results: R1 and R2. + difference->push_back(new QuicInterval(min(), i.min())); + difference->push_back(new QuicInterval(i.max(), max())); + return true; + } + if (min() >= i.min() && max() <= i.max()) { + // [--- this ---) + // [------ i --------) + // Intersection is , so difference yields the empty QuicInterval. + // Nothing is appended to *difference. + return true; + } + // No intersection. Append . + difference->push_back(new QuicInterval(*this)); + return false; +} + +template +bool QuicInterval::Difference(const QuicInterval& i, QuicInterval* lo, + QuicInterval* hi) const { + // Initialize *lo and *hi to empty + *lo = {}; + *hi = {}; + if (Empty()) return false; + if (i.Empty()) { + *lo = *this; + return false; + } + if (min() < i.max() && min() >= i.min() && max() > i.max()) { + // [------ this ------) + // [------ i ------) + // [-- result ---) + *hi = QuicInterval(i.max(), max()); + return true; + } + if (max() > i.min() && max() <= i.max() && min() < i.min()) { + // [------ this ------) + // [------ i ------) + // [- result -) + *lo = QuicInterval(min(), i.min()); + return true; + } + if (min() < i.min() && max() > i.max()) { + // [------- this --------) + // [---- i ----) + // [ R1 ) [ R2 ) + // There are two results: R1 and R2. + *lo = QuicInterval(min(), i.min()); + *hi = QuicInterval(i.max(), max()); + return true; + } + if (min() >= i.min() && max() <= i.max()) { + // [--- this ---) + // [------ i --------) + // Intersection is , so difference yields the empty QuicInterval. + return true; + } + *lo = *this; // No intersection. + return false; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_INTERVAL_H_ diff --git a/quiche/quic/core/quic_interval_deque.h b/quiche/quic/core/quic_interval_deque.h new file mode 100644 index 000000000000..ed13e7a44df6 --- /dev/null +++ b/quiche/quic/core/quic_interval_deque.h @@ -0,0 +1,391 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_INTERVAL_DEQUE_H_ +#define QUICHE_QUIC_CORE_QUIC_INTERVAL_DEQUE_H_ + +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +namespace test { +class QuicIntervalDequePeer; +} // namespace test + +// QuicIntervalDeque is a templated wrapper container, wrapping random +// access data structures. The wrapper allows items to be added to the +// underlying container with intervals associated with each item. The intervals +// _should_ be added already sorted and represent searchable indices. The search +// is optimized for sequential usage. +// +// As the intervals are to be searched sequentially the search for the next +// interval can be achieved in O(1), by simply remembering the last interval +// consumed. The structure also checks for an "off-by-one" use case wherein the +// |cached_index_| is off by one index as the caller didn't call operator |++| +// to increment the index. Other intervals can be found in O(log(n)) as they are +// binary searchable. A use case for this structure is packet buffering: Packets +// are sent sequentially but can sometimes needed for retransmission. The +// packets and their payloads are added with associated intervals representing +// data ranges they carry. When a user asks for a particular interval it's very +// likely they are requesting the next sequential interval, receiving it in O(1) +// time. Updating the sequential index is done automatically through the +// |DataAt| method and its iterator operator |++|. +// +// The intervals are represented using the usual C++ STL convention, namely as +// the half-open QuicInterval [min, max). A point p is considered to be +// contained in the QuicInterval iff p >= min && p < max. One consequence of +// this definition is that for any non-empty QuicInterval, min is contained in +// the QuicInterval but max is not. There is no canonical representation for the +// empty QuicInterval; and empty intervals are forbidden from being added to +// this container as they would be unsearchable. +// +// The type T is required to be copy-constructable or move-constructable. The +// type T is also expected to have an |interval()| method returning a +// QuicInterval for the particular value. The type C is required to +// be a random access container supporting the methods |pop_front|, |push_back|, +// |operator[]|, |size|, and iterator support for |std::lower_bound| eg. a +// |deque| or |vector|. +// +// The QuicIntervalDeque, like other C++ STL random access containers, +// doesn't have any explicit support for any equality operators. +// +// +// Examples with internal state: +// +// // An example class to be stored inside the Interval Deque. +// struct IntervalVal { +// const int32_t val; +// const size_t interval_begin, interval_end; +// QuicInterval interval(); +// }; +// typedef IntervalVal IV; +// QuicIntervialDeque deque; +// +// // State: +// // cached_index -> None +// // container -> {} +// +// // Add interval items +// deque.PushBack(IV(val: 0, interval_begin: 0, interval_end: 10)); +// deque.PushBack(IV(val: 1, interval_begin: 20, interval_end: 25)); +// deque.PushBack(IV(val: 2, interval_begin: 25, interval_end: 30)); +// +// // State: +// // cached_index -> 0 +// // container -> {{0, [0, 10)}, {1, [20, 25)}, {2, [25, 30)}} +// +// // Look for 0 and return [0, 10). Time: O(1) +// auto it = deque.DataAt(0); +// assert(it->val == 0); +// it++; // Increment and move the |cached_index_| over [0, 10) to [20, 25). +// assert(it->val == 1); +// +// // State: +// // cached_index -> 1 +// // container -> {{0, [0, 10)}, {1, [20, 25)}, {2, [25, 30)}} +// +// // Look for 20 and return [20, 25). Time: O(1) +// auto it = deque.DataAt(20); // |cached_index_| remains unchanged. +// assert(it->val == 1); +// +// // State: +// // cached_index -> 1 +// // container -> {{0, [0, 10)}, {1, [20, 25)}, {2, [25, 30)}} +// +// // Look for 15 and return deque.DataEnd(). Time: O(log(n)) +// auto it = deque.DataAt(15); // |cached_index_| remains unchanged. +// assert(it == deque.DataEnd()); +// +// // Look for 25 and return [25, 30). Time: O(1) with off-by-one. +// auto it = deque.DataAt(25); // |cached_index_| is updated to 2. +// assert(it->val == 2); +// it++; // |cached_index_| is set to |None| as all data has been iterated. +// +// +// // State: +// // cached_index -> None +// // container -> {{0, [0, 10)}, {1, [20, 25)}, {2, [25, 30)}} +// +// // Look again for 0 and return [0, 10). Time: O(log(n)) +// auto it = deque.DataAt(0); +// +// +// deque.PopFront(); // Pop -> {0, [0, 10)} +// +// // State: +// // cached_index -> None +// // container -> {{1, [20, 25)}, {2, [25, 30)}} +// +// deque.PopFront(); // Pop -> {1, [20, 25)} +// +// // State: +// // cached_index -> None +// // container -> {{2, [25, 30)}} +// +// deque.PushBack(IV(val: 3, interval_begin: 35, interval_end: 50)); +// +// // State: +// // cached_index -> 1 +// // container -> {{2, [25, 30)}, {3, [35, 50)}} + +template > +class QUIC_NO_EXPORT QuicIntervalDeque { + public: + class QUIC_NO_EXPORT Iterator { + public: + // Used by |std::lower_bound| + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = T&; + + // Every increment of the iterator will increment the |cached_index_| if + // |index_| is larger than the current |cached_index_|. |index_| is bounded + // at |deque_.size()| and any attempt to increment above that will be + // ignored. Once an iterator has iterated all elements the |cached_index_| + // will be reset. + Iterator(std::size_t index, QuicIntervalDeque* deque) + : index_(index), deque_(deque) {} + // Only the ++ operator attempts to update the cached index. Other operators + // are used by |lower_bound| to binary search and are thus private. + Iterator& operator++() { + // Don't increment when we are at the end. + const std::size_t container_size = deque_->container_.size(); + if (index_ >= container_size) { + QUIC_BUG(quic_bug_10862_1) << "Iterator out of bounds."; + return *this; + } + index_++; + if (deque_->cached_index_.has_value()) { + const std::size_t cached_index = deque_->cached_index_.value(); + // If all items are iterated then reset the |cached_index_| + if (index_ == container_size) { + deque_->cached_index_.reset(); + } else { + // Otherwise the new |cached_index_| is the max of itself and |index_| + if (cached_index < index_) { + deque_->cached_index_ = index_; + } + } + } + return *this; + } + Iterator operator++(int) { + Iterator copy = *this; + ++(*this); + return copy; + } + reference operator*() { return deque_->container_[index_]; } + reference operator*() const { return deque_->container_[index_]; } + pointer operator->() { return &deque_->container_[index_]; } + bool operator==(const Iterator& rhs) const { + return index_ == rhs.index_ && deque_ == rhs.deque_; + } + bool operator!=(const Iterator& rhs) const { return !(*this == rhs); } + + private: + // A set of private operators for |std::lower_bound| + Iterator operator+(difference_type amount) const { + Iterator copy = *this; + copy.index_ += amount; + QUICHE_DCHECK(copy.index_ < copy.deque_->size()); + return copy; + } + Iterator& operator+=(difference_type amount) { + index_ += amount; + QUICHE_DCHECK(index_ < deque_->size()); + return *this; + } + difference_type operator-(const Iterator& rhs) const { + return static_cast(index_) - + static_cast(rhs.index_); + } + + // |index_| is the index of the item in |*deque_|. + std::size_t index_; + // |deque_| is a pointer to the container the iterator came from. + QuicIntervalDeque* deque_; + + friend class QuicIntervalDeque; + }; + + QuicIntervalDeque(); + + // Adds an item to the underlying container. The |item|'s interval _should_ be + // strictly greater than the last interval added. + void PushBack(T&& item); + void PushBack(const T& item); + // Removes the front/top of the underlying container and the associated + // interval. + void PopFront(); + // Returns an iterator to the beginning of the data. The iterator will move + // the |cached_index_| as the iterator moves. + Iterator DataBegin(); + // Returns an iterator to the end of the data. + Iterator DataEnd(); + // Returns an iterator pointing to the item in |interval_begin|. The iterator + // will move the |cached_index_| as the iterator moves. + Iterator DataAt(const std::size_t interval_begin); + + // Returns the number of items contained inside the structure. + std::size_t Size() const; + // Returns whether the structure is empty. + bool Empty() const; + + private: + struct QUIC_NO_EXPORT IntervalCompare { + bool operator()(const T& item, std::size_t interval_begin) const { + return item.interval().max() <= interval_begin; + } + }; + + template + void PushBackUniversal(U&& item); + + Iterator Search(const std::size_t interval_begin, + const std::size_t begin_index, const std::size_t end_index); + + // For accessing the |cached_index_| + friend class test::QuicIntervalDequePeer; + + C container_; + absl::optional cached_index_; +}; + +template +QuicIntervalDeque::QuicIntervalDeque() {} + +template +void QuicIntervalDeque::PushBack(T&& item) { + PushBackUniversal(std::move(item)); +} + +template +void QuicIntervalDeque::PushBack(const T& item) { + PushBackUniversal(item); +} + +template +void QuicIntervalDeque::PopFront() { + if (container_.size() == 0) { + QUIC_BUG(quic_bug_10862_2) << "Trying to pop from an empty container."; + return; + } + container_.pop_front(); + if (container_.size() == 0) { + cached_index_.reset(); + } + if (cached_index_.value_or(0) > 0) { + cached_index_ = cached_index_.value() - 1; + } +} + +template +typename QuicIntervalDeque::Iterator +QuicIntervalDeque::DataBegin() { + return Iterator(0, this); +} + +template +typename QuicIntervalDeque::Iterator QuicIntervalDeque::DataEnd() { + return Iterator(container_.size(), this); +} + +template +typename QuicIntervalDeque::Iterator QuicIntervalDeque::DataAt( + const std::size_t interval_begin) { + // No |cached_index_| value means all items can be searched. + if (!cached_index_.has_value()) { + return Search(interval_begin, 0, container_.size()); + } + + const std::size_t cached_index = cached_index_.value(); + QUICHE_DCHECK(cached_index < container_.size()); + + const QuicInterval cached_interval = + container_[cached_index].interval(); + // Does our cached index point directly to what we want? + if (cached_interval.Contains(interval_begin)) { + return Iterator(cached_index, this); + } + + // Are we off-by-one? + const std::size_t next_index = cached_index + 1; + if (next_index < container_.size()) { + if (container_[next_index].interval().Contains(interval_begin)) { + cached_index_ = next_index; + return Iterator(next_index, this); + } + } + + // Small optimization: + // Determine if we should binary search above or below the cached interval. + const std::size_t cached_begin = cached_interval.min(); + bool looking_below = interval_begin < cached_begin; + const std::size_t lower = looking_below ? 0 : cached_index + 1; + const std::size_t upper = looking_below ? cached_index : container_.size(); + Iterator ret = Search(interval_begin, lower, upper); + if (ret == DataEnd()) { + return ret; + } + // Update the |cached_index_| to point to the higher index. + if (!looking_below) { + cached_index_ = ret.index_; + } + return ret; +} + +template +std::size_t QuicIntervalDeque::Size() const { + return container_.size(); +} + +template +bool QuicIntervalDeque::Empty() const { + return container_.size() == 0; +} + +template +template +void QuicIntervalDeque::PushBackUniversal(U&& item) { + QuicInterval interval = item.interval(); + // Adding an empty interval is a bug. + if (interval.Empty()) { + QUIC_BUG(quic_bug_10862_3) + << "Trying to save empty interval to quiche::QuicheCircularDeque."; + return; + } + container_.push_back(std::forward(item)); + if (!cached_index_.has_value()) { + cached_index_ = container_.size() - 1; + } +} + +template +typename QuicIntervalDeque::Iterator QuicIntervalDeque::Search( + const std::size_t interval_begin, const std::size_t begin_index, + const std::size_t end_index) { + auto begin = container_.begin() + begin_index; + auto end = container_.begin() + end_index; + auto res = std::lower_bound(begin, end, interval_begin, IntervalCompare()); + // Just because we run |lower_bound| and it didn't return |container_.end()| + // doesn't mean we found our desired interval. + if (res != end && res->interval().Contains(interval_begin)) { + return Iterator(std::distance(begin, res) + begin_index, this); + } + return DataEnd(); +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_INTERVAL_DEQUE_H_ diff --git a/quiche/quic/core/quic_interval_deque_test.cc b/quiche/quic/core/quic_interval_deque_test.cc new file mode 100644 index 000000000000..318059f286fa --- /dev/null +++ b/quiche/quic/core/quic_interval_deque_test.cc @@ -0,0 +1,361 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_interval_deque.h" + +#include +#include + +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_interval_deque_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +const int32_t kSize = 100; +const std::size_t kIntervalStep = 10; + +} // namespace + +struct TestIntervalItem { + int32_t val; + std::size_t interval_start, interval_end; + QuicInterval interval() const { + return QuicInterval(interval_start, interval_end); + } + TestIntervalItem(int32_t val, std::size_t interval_start, + std::size_t interval_end) + : val(val), interval_start(interval_start), interval_end(interval_end) {} +}; + +using QID = QuicIntervalDeque; + +class QuicIntervalDequeTest : public QuicTest { + public: + QuicIntervalDequeTest() { + // Add items with intervals of |kIntervalStep| size. + for (int32_t i = 0; i < kSize; ++i) { + const std::size_t interval_begin = kIntervalStep * i; + const std::size_t interval_end = interval_begin + kIntervalStep; + qid_.PushBack(TestIntervalItem(i, interval_begin, interval_end)); + } + } + + QID qid_; +}; + +// The goal of this test is to show insertion/push_back, iteration, and and +// deletion/pop_front from the container. +TEST_F(QuicIntervalDequeTest, InsertRemoveSize) { + QID qid; + + EXPECT_EQ(qid.Size(), std::size_t(0)); + qid.PushBack(TestIntervalItem(0, 0, 10)); + EXPECT_EQ(qid.Size(), std::size_t(1)); + qid.PushBack(TestIntervalItem(1, 10, 20)); + EXPECT_EQ(qid.Size(), std::size_t(2)); + qid.PushBack(TestIntervalItem(2, 20, 30)); + EXPECT_EQ(qid.Size(), std::size_t(3)); + qid.PushBack(TestIntervalItem(3, 30, 40)); + EXPECT_EQ(qid.Size(), std::size_t(4)); + + // Advance the index all the way... + int32_t i = 0; + for (auto it = qid.DataAt(0); it != qid.DataEnd(); ++it, ++i) { + const int32_t index = QuicIntervalDequePeer::GetCachedIndex(&qid); + EXPECT_EQ(index, i); + EXPECT_EQ(it->val, i); + } + const int32_t index = QuicIntervalDequePeer::GetCachedIndex(&qid); + EXPECT_EQ(index, -1); + + qid.PopFront(); + EXPECT_EQ(qid.Size(), std::size_t(3)); + qid.PopFront(); + EXPECT_EQ(qid.Size(), std::size_t(2)); + qid.PopFront(); + EXPECT_EQ(qid.Size(), std::size_t(1)); + qid.PopFront(); + EXPECT_EQ(qid.Size(), std::size_t(0)); + + EXPECT_QUIC_BUG(qid.PopFront(), "Trying to pop from an empty container."); +} + +// The goal of this test is to push data into the container at specific +// intervals and show how the |DataAt| method can move the |cached_index| as the +// iterator moves through the data. +TEST_F(QuicIntervalDequeTest, InsertIterateWhole) { + // The write index should point to the beginning of the container. + const int32_t cached_index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(cached_index, 0); + + auto it = qid_.DataBegin(); + auto end = qid_.DataEnd(); + for (int32_t i = 0; i < kSize; ++i, ++it) { + EXPECT_EQ(it->val, i); + const std::size_t current_iteraval_begin = i * kIntervalStep; + // The |DataAt| method should find the correct interval. + auto lookup = qid_.DataAt(current_iteraval_begin); + EXPECT_EQ(i, lookup->val); + // Make sure the index hasn't changed just from using |DataAt| + const int32_t index_before = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_before, i); + // This increment should move the index forward. + lookup++; + // Check that the index has changed. + const int32_t index_after = QuicIntervalDequePeer::GetCachedIndex(&qid_); + const int32_t after_i = (i + 1) == kSize ? -1 : (i + 1); + EXPECT_EQ(index_after, after_i); + EXPECT_NE(it, end); + } +} + +// The goal of this test is to push data into the container at specific +// intervals and show how the |DataAt| method can move the |cached_index| using +// the off-by-one logic. +TEST_F(QuicIntervalDequeTest, OffByOne) { + // The write index should point to the beginning of the container. + const int32_t cached_index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(cached_index, 0); + + auto it = qid_.DataBegin(); + auto end = qid_.DataEnd(); + for (int32_t i = 0; i < kSize - 1; ++i, ++it) { + EXPECT_EQ(it->val, i); + const int32_t off_by_one_i = i + 1; + const std::size_t current_iteraval_begin = off_by_one_i * kIntervalStep; + // Make sure the index has changed just from using |DataAt| + const int32_t index_before = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_before, i); + // The |DataAt| method should find the correct interval. + auto lookup = qid_.DataAt(current_iteraval_begin); + EXPECT_EQ(off_by_one_i, lookup->val); + // Check that the index has changed. + const int32_t index_after = QuicIntervalDequePeer::GetCachedIndex(&qid_); + const int32_t after_i = off_by_one_i == kSize ? -1 : off_by_one_i; + EXPECT_EQ(index_after, after_i); + EXPECT_NE(it, end); + } +} + +// The goal of this test is to push data into the container at specific +// intervals and show modify the structure with a live iterator. +TEST_F(QuicIntervalDequeTest, IteratorInvalidation) { + // The write index should point to the beginning of the container. + const int32_t cached_index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(cached_index, 0); + + const std::size_t iteraval_begin = (kSize - 1) * kIntervalStep; + auto lookup = qid_.DataAt(iteraval_begin); + EXPECT_EQ((*lookup).val, (kSize - 1)); + qid_.PopFront(); + EXPECT_QUIC_BUG(lookup++, "Iterator out of bounds."); + auto lookup_end = qid_.DataAt(iteraval_begin + kIntervalStep); + EXPECT_EQ(lookup_end, qid_.DataEnd()); +} + +// The goal of this test is the same as |InsertIterateWhole| but to +// skip certain intervals and show the |cached_index| is updated properly. +TEST_F(QuicIntervalDequeTest, InsertIterateSkip) { + // The write index should point to the beginning of the container. + const int32_t cached_index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(cached_index, 0); + + const std::size_t step = 4; + for (int32_t i = 0; i < kSize; i += 4) { + if (i != 0) { + const int32_t before_i = (i - (step - 1)); + EXPECT_EQ(QuicIntervalDequePeer::GetCachedIndex(&qid_), before_i); + } + const std::size_t current_iteraval_begin = i * kIntervalStep; + // The |DataAt| method should find the correct interval. + auto lookup = qid_.DataAt(current_iteraval_begin); + EXPECT_EQ(i, lookup->val); + // Make sure the index _has_ changed just from using |DataAt| since we're + // skipping data. + const int32_t index_before = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_before, i); + // This increment should move the index forward. + lookup++; + // Check that the index has changed. + const int32_t index_after = QuicIntervalDequePeer::GetCachedIndex(&qid_); + const int32_t after_i = (i + 1) == kSize ? -1 : (i + 1); + EXPECT_EQ(index_after, after_i); + } +} + +// The goal of this test is the same as |InsertIterateWhole| but it has +// |PopFront| calls interleaved to show the |cached_index| updates correctly. +TEST_F(QuicIntervalDequeTest, InsertDeleteIterate) { + // The write index should point to the beginning of the container. + const int32_t index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index, 0); + + std::size_t limit = 0; + for (int32_t i = 0; limit < qid_.Size(); ++i, ++limit) { + // Always point to the beginning of the container. + auto it = qid_.DataBegin(); + EXPECT_EQ(it->val, i); + + // Get an iterator. + const std::size_t current_iteraval_begin = i * kIntervalStep; + auto lookup = qid_.DataAt(current_iteraval_begin); + const int32_t index_before = QuicIntervalDequePeer::GetCachedIndex(&qid_); + // The index should always point to 0. + EXPECT_EQ(index_before, 0); + // This iterator increment should effect the index. + lookup++; + const int32_t index_after = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_after, 1); + // Decrement the |temp_size| and pop from the front. + qid_.PopFront(); + // Show the index has been updated to point to 0 again (from 1). + const int32_t index_after_pop = + QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_after_pop, 0); + } +} + +// The goal of this test is to move the index to the end and then add more data +// to show it can be reset to a valid index. +TEST_F(QuicIntervalDequeTest, InsertIterateInsert) { + // The write index should point to the beginning of the container. + const int32_t index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index, 0); + + int32_t iterated_elements = 0; + for (int32_t i = 0; i < kSize; ++i, ++iterated_elements) { + // Get an iterator. + const std::size_t current_iteraval_begin = i * kIntervalStep; + auto lookup = qid_.DataAt(current_iteraval_begin); + const int32_t index_before = QuicIntervalDequePeer::GetCachedIndex(&qid_); + // The index should always point to i. + EXPECT_EQ(index_before, i); + // This iterator increment should effect the index. + lookup++; + // Show the index has been updated to point to i + 1 or -1 if at the end. + const int32_t index_after = QuicIntervalDequePeer::GetCachedIndex(&qid_); + const int32_t after_i = (i + 1) == kSize ? -1 : (i + 1); + EXPECT_EQ(index_after, after_i); + } + const int32_t invalid_index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(invalid_index, -1); + + // Add more data to the container, making the index valid. + const std::size_t offset = qid_.Size(); + for (int32_t i = 0; i < kSize; ++i) { + const std::size_t interval_begin = offset + (kIntervalStep * i); + const std::size_t interval_end = offset + interval_begin + kIntervalStep; + qid_.PushBack(TestIntervalItem(i + offset, interval_begin, interval_end)); + const int32_t index_current = QuicIntervalDequePeer::GetCachedIndex(&qid_); + // Index should now be valid and equal to the size of the container before + // adding more items to it. + EXPECT_EQ(index_current, iterated_elements); + } + // Show the index is still valid and hasn't changed since the first iteration + // of the loop. + const int32_t index_after_add = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_after_add, iterated_elements); + + // Iterate over all the data in the container and eventually reset the index + // as we did before. + for (int32_t i = 0; i < kSize; ++i, ++iterated_elements) { + const std::size_t interval_begin = offset + (kIntervalStep * i); + const int32_t index_current = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_current, iterated_elements); + auto lookup = qid_.DataAt(interval_begin); + const int32_t expected_value = i + offset; + EXPECT_EQ(lookup->val, expected_value); + lookup++; + const int32_t after_inc = + (iterated_elements + 1) == (kSize * 2) ? -1 : (iterated_elements + 1); + const int32_t after_index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(after_index, after_inc); + } + // Show the index is now invalid. + const int32_t invalid_index_again = + QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(invalid_index_again, -1); +} + +// The goal of this test is to push data into the container at specific +// intervals and show how the |DataAt| can iterate over already scanned data. +TEST_F(QuicIntervalDequeTest, RescanData) { + // The write index should point to the beginning of the container. + const int32_t index = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index, 0); + + auto it = qid_.DataBegin(); + auto end = qid_.DataEnd(); + for (int32_t i = 0; i < kSize - 1; ++i, ++it) { + EXPECT_EQ(it->val, i); + const std::size_t current_iteraval_begin = i * kIntervalStep; + // The |DataAt| method should find the correct interval. + auto lookup = qid_.DataAt(current_iteraval_begin); + EXPECT_EQ(i, lookup->val); + // Make sure the index has changed just from using |DataAt| + const int32_t cached_index_before = + QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(cached_index_before, i); + // Ensure the real index has changed just from using |DataAt| and the + // off-by-one logic + const int32_t index_before = QuicIntervalDequePeer::GetCachedIndex(&qid_); + const int32_t before_i = i; + EXPECT_EQ(index_before, before_i); + // This increment should move the cached index forward. + lookup++; + // Check that the cached index has moved foward. + const int32_t cached_index_after = + QuicIntervalDequePeer::GetCachedIndex(&qid_); + const int32_t after_i = (i + 1); + EXPECT_EQ(cached_index_after, after_i); + EXPECT_NE(it, end); + } + + // Iterate over items which have been consumed before. + int32_t expected_index = static_cast(kSize - 1); + for (int32_t i = 0; i < kSize - 1; ++i) { + const std::size_t current_iteraval_begin = i * kIntervalStep; + // The |DataAt| method should find the correct interval. + auto lookup = qid_.DataAt(current_iteraval_begin); + EXPECT_EQ(i, lookup->val); + // This increment shouldn't move the index forward as the index is currently + // ahead. + lookup++; + // Check that the index hasn't moved foward. + const int32_t index_after = QuicIntervalDequePeer::GetCachedIndex(&qid_); + EXPECT_EQ(index_after, expected_index); + EXPECT_NE(it, end); + } +} + +// The goal of this test is to show that popping from an empty container is a +// bug. +TEST_F(QuicIntervalDequeTest, PopEmpty) { + QID qid; + EXPECT_TRUE(qid.Empty()); + EXPECT_QUIC_BUG(qid.PopFront(), "Trying to pop from an empty container."); +} + +// The goal of this test is to show that adding a zero-sized interval is a bug. +TEST_F(QuicIntervalDequeTest, ZeroSizedInterval) { + QID qid; + EXPECT_QUIC_BUG(qid.PushBack(TestIntervalItem(0, 0, 0)), + "Trying to save empty interval to ."); +} + +// The goal of this test is to show that an iterator to an empty container +// returns |DataEnd|. +TEST_F(QuicIntervalDequeTest, IteratorEmpty) { + QID qid; + auto it = qid.DataAt(0); + EXPECT_EQ(it, qid.DataEnd()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_interval_set.h b/quiche/quic/core/quic_interval_set.h new file mode 100644 index 000000000000..11d0270d46cc --- /dev/null +++ b/quiche/quic/core/quic_interval_set.h @@ -0,0 +1,885 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_INTERVAL_SET_H_ +#define QUICHE_QUIC_CORE_QUIC_INTERVAL_SET_H_ + +// QuicIntervalSet is a data structure used to represent a sorted set of +// non-empty, non-adjacent, and mutually disjoint intervals. Mutations to an +// interval set preserve these properties, altering the set as needed. For +// example, adding [2, 3) to a set containing only [1, 2) would result in the +// set containing the single interval [1, 3). +// +// Supported operations include testing whether an Interval is contained in the +// QuicIntervalSet, comparing two QuicIntervalSets, and performing +// QuicIntervalSet union, intersection, and difference. +// +// QuicIntervalSet maintains the minimum number of entries needed to represent +// the set of underlying intervals. When the QuicIntervalSet is modified (e.g. +// due to an Add operation), other interval entries may be coalesced, removed, +// or otherwise modified in order to maintain this invariant. The intervals are +// maintained in sorted order, by ascending min() value. +// +// The reader is cautioned to beware of the terminology used here: this library +// uses the terms "min" and "max" rather than "begin" and "end" as is +// conventional for the STL. The terminology [min, max) refers to the half-open +// interval which (if the interval is not empty) contains min but does not +// contain max. An interval is considered empty if min >= max. +// +// T is required to be default- and copy-constructible, to have an assignment +// operator, a difference operator (operator-()), and the full complement of +// comparison operators (<, <=, ==, !=, >=, >). These requirements are inherited +// from value_type. +// +// QuicIntervalSet has constant-time move operations. +// +// +// Examples: +// QuicIntervalSet intervals; +// intervals.Add(Interval(10, 20)); +// intervals.Add(Interval(30, 40)); +// // intervals contains [10,20) and [30,40). +// intervals.Add(Interval(15, 35)); +// // intervals has been coalesced. It now contains the single range [10,40). +// EXPECT_EQ(1, intervals.Size()); +// EXPECT_TRUE(intervals.Contains(Interval(10, 40))); +// +// intervals.Difference(Interval(10, 20)); +// // intervals should now contain the single range [20, 40). +// EXPECT_EQ(1, intervals.Size()); +// EXPECT_TRUE(intervals.Contains(Interval(20, 40))); + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_containers.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +template +class QUIC_NO_EXPORT QuicIntervalSet { + public: + using value_type = QuicInterval; + + private: + struct QUIC_NO_EXPORT IntervalLess { + using is_transparent = void; + bool operator()(const value_type& a, const value_type& b) const; + // These transparent overloads are used when we do all of our searches (via + // Set::lower_bound() and Set::upper_bound()), which avoids the need to + // construct an interval when we are looking for a point and also avoids + // needing to worry about comparing overlapping intervals in the overload + // that takes two value_types (the one just above this comment). + bool operator()(const value_type& a, const T& point) const; + bool operator()(const value_type& a, T&& point) const; + bool operator()(const T& point, const value_type& a) const; + bool operator()(T&& point, const value_type& a) const; + }; + + using Set = quiche::QuicheSmallOrderedSet; + + public: + using const_iterator = typename Set::const_iterator; + using const_reverse_iterator = typename Set::const_reverse_iterator; + + // Instantiates an empty QuicIntervalSet. + QuicIntervalSet() = default; + + // Instantiates a QuicIntervalSet containing exactly one initial half-open + // interval [min, max), unless the given interval is empty, in which case the + // QuicIntervalSet will be empty. + explicit QuicIntervalSet(const value_type& interval) { Add(interval); } + + // Instantiates a QuicIntervalSet containing the half-open interval [min, + // max). + QuicIntervalSet(const T& min, const T& max) { Add(min, max); } + + QuicIntervalSet(std::initializer_list il) { assign(il); } + + // Clears this QuicIntervalSet. + void Clear() { intervals_.clear(); } + + // Returns the number of disjoint intervals contained in this QuicIntervalSet. + size_t Size() const { return intervals_.size(); } + + // Returns the smallest interval that contains all intervals in this + // QuicIntervalSet, or the empty interval if the set is empty. + value_type SpanningInterval() const; + + // Adds "interval" to this QuicIntervalSet. Adding the empty interval has no + // effect. + void Add(const value_type& interval); + + // Adds the interval [min, max) to this QuicIntervalSet. Adding the empty + // interval has no effect. + void Add(const T& min, const T& max) { Add(value_type(min, max)); } + + // Same semantics as Add(const value_type&), but optimized for the case where + // rbegin()->min() <= |interval|.min() <= rbegin()->max(). + void AddOptimizedForAppend(const value_type& interval) { + if (Empty() || !GetQuicFlag(quic_interval_set_enable_add_optimization)) { + Add(interval); + return; + } + + const_reverse_iterator last_interval = intervals_.rbegin(); + + // If interval.min() is outside of [last_interval->min, last_interval->max], + // we can not simply extend last_interval->max. + if (interval.min() < last_interval->min() || + interval.min() > last_interval->max()) { + Add(interval); + return; + } + + if (interval.max() <= last_interval->max()) { + // interval is fully contained by last_interval. + return; + } + + // Extend last_interval.max to interval.max, in place. + // + // Set does not allow in-place updates due to the potential of violating its + // ordering requirements. But we know setting the max of the last interval + // is safe w.r.t set ordering and other invariants of QuicIntervalSet, so we + // force an in-place update for performance. + const_cast(&(*last_interval))->SetMax(interval.max()); + } + + // Same semantics as Add(const T&, const T&), but optimized for the case where + // rbegin()->max() == |min|. + void AddOptimizedForAppend(const T& min, const T& max) { + AddOptimizedForAppend(value_type(min, max)); + } + + // TODO(wub): Similar to AddOptimizedForAppend, we can also have a + // AddOptimizedForPrepend if there is a use case. + + // Remove the first interval. + // REQUIRES: !Empty() + void PopFront() { + QUICHE_DCHECK(!Empty()); + intervals_.erase(intervals_.begin()); + } + + // Trim all values that are smaller than |value|. Which means + // a) If all values in an interval is smaller than |value|, the entire + // interval is removed. + // b) If some but not all values in an interval is smaller than |value|, the + // min of that interval is raised to |value|. + // Returns true if some intervals are trimmed. + bool TrimLessThan(const T& value) { + // Number of intervals that are fully or partially trimmed. + size_t num_intervals_trimmed = 0; + + while (!intervals_.empty()) { + const_iterator first_interval = intervals_.begin(); + if (first_interval->min() >= value) { + break; + } + + ++num_intervals_trimmed; + + if (first_interval->max() <= value) { + // a) Trim the entire interval. + intervals_.erase(first_interval); + continue; + } + + // b) Trim a prefix of the interval. + // + // Set does not allow in-place updates due to the potential of violating + // its ordering requirements. But increasing the min of the first interval + // will not break the ordering, hence the const_cast. + const_cast(&(*first_interval))->SetMin(value); + break; + } + + return num_intervals_trimmed != 0; + } + + // Returns true if this QuicIntervalSet is empty. + bool Empty() const { return intervals_.empty(); } + + // Returns true if any interval in this QuicIntervalSet contains the indicated + // value. + bool Contains(const T& value) const; + + // Returns true if there is some interval in this QuicIntervalSet that wholly + // contains the given interval. An interval O "wholly contains" a non-empty + // interval I if O.Contains(p) is true for every p in I. This is the same + // definition used by value_type::Contains(). This method returns false on + // the empty interval, due to a (perhaps unintuitive) convention inherited + // from value_type. + // Example: + // Assume an QuicIntervalSet containing the entries { [10,20), [30,40) }. + // Contains(Interval(15, 16)) returns true, because [10,20) contains + // [15,16). However, Contains(Interval(15, 35)) returns false. + bool Contains(const value_type& interval) const; + + // Returns true if for each interval in "other", there is some (possibly + // different) interval in this QuicIntervalSet which wholly contains it. See + // Contains(const value_type& interval) for the meaning of "wholly contains". + // Perhaps unintuitively, this method returns false if "other" is the empty + // set. The algorithmic complexity of this method is O(other.Size() * + // log(this->Size())). The method could be rewritten to run in O(other.Size() + // + this->Size()), and this alternative could be implemented as a free + // function using the public API. + bool Contains(const QuicIntervalSet& other) const; + + // Returns true if there is some interval in this QuicIntervalSet that wholly + // contains the interval [min, max). See Contains(const value_type&). + bool Contains(const T& min, const T& max) const { + return Contains(value_type(min, max)); + } + + // Returns true if for some interval in "other", there is some interval in + // this QuicIntervalSet that intersects with it. See value_type::Intersects() + // for the definition of interval intersection. Runs in time O(n+m) where n + // is the number of intervals in this and m is the number of intervals in + // other. + bool Intersects(const QuicIntervalSet& other) const; + + // Returns an iterator to the value_type in the QuicIntervalSet that contains + // the given value. In other words, returns an iterator to the unique interval + // [min, max) in the QuicIntervalSet that has the property min <= value < max. + // If there is no such interval, this method returns end(). + const_iterator Find(const T& value) const; + + // Returns an iterator to the value_type in the QuicIntervalSet that wholly + // contains the given interval. In other words, returns an iterator to the + // unique interval outer in the QuicIntervalSet that has the property that + // outer.Contains(interval). If there is no such interval, or if interval is + // empty, returns end(). + const_iterator Find(const value_type& interval) const; + + // Returns an iterator to the value_type in the QuicIntervalSet that wholly + // contains [min, max). In other words, returns an iterator to the unique + // interval outer in the QuicIntervalSet that has the property that + // outer.Contains(Interval(min, max)). If there is no such interval, or if + // interval is empty, returns end(). + const_iterator Find(const T& min, const T& max) const { + return Find(value_type(min, max)); + } + + // Returns an iterator pointing to the first value_type which contains or + // goes after the given value. + // + // Example: + // [10, 20) [30, 40) + // ^ LowerBound(10) + // ^ LowerBound(15) + // ^ LowerBound(20) + // ^ LowerBound(25) + const_iterator LowerBound(const T& value) const; + + // Returns an iterator pointing to the first value_type which goes after + // the given value. + // + // Example: + // [10, 20) [30, 40) + // ^ UpperBound(10) + // ^ UpperBound(15) + // ^ UpperBound(20) + // ^ UpperBound(25) + const_iterator UpperBound(const T& value) const; + + // Returns true if every value within the passed interval is not Contained + // within the QuicIntervalSet. + // Note that empty intervals are always considered disjoint from the + // QuicIntervalSet (even though the QuicIntervalSet doesn't `Contain` them). + bool IsDisjoint(const value_type& interval) const; + + // Merges all the values contained in "other" into this QuicIntervalSet. + // + // Performance: Let n == Size() and m = other.Size(). Union() runs in O(m) + // Set operations, so that if Set is a tree, it runs in time O(m log(n+m)) and + // if Set is a flat_set it runs in time O(m(n+m)). In principle, for the + // flat_set, we should be able to make this run in time O(n+m). + // + // TODO(bradleybear): Make Union() run in time O(n+m) for flat_set. This may + // require an additional template parameter to indicate that the Set is a + // linear-time data structure instead of a log-time data structure. + void Union(const QuicIntervalSet& other); + + // Modifies this QuicIntervalSet so that it contains only those values that + // are currently present both in *this and in the QuicIntervalSet "other". + void Intersection(const QuicIntervalSet& other); + + // Mutates this QuicIntervalSet so that it contains only those values that are + // currently in *this but not in "interval". + void Difference(const value_type& interval); + + // Mutates this QuicIntervalSet so that it contains only those values that are + // currently in *this but not in the interval [min, max). + void Difference(const T& min, const T& max); + + // Mutates this QuicIntervalSet so that it contains only those values that are + // currently in *this but not in the QuicIntervalSet "other". Runs in time + // O(n+m) where n is this->Size(), m is other.Size(), regardless of whether + // the Set is a flat_set or a std::set. + void Difference(const QuicIntervalSet& other); + + // Mutates this QuicIntervalSet so that it contains only those values that are + // in [min, max) but not currently in *this. + void Complement(const T& min, const T& max); + + // QuicIntervalSet's begin() iterator. The invariants of QuicIntervalSet + // guarantee that for each entry e in the set, e.min() < e.max() (because the + // entries are non-empty) and for each entry f that appears later in the set, + // e.max() < f.min() (because the entries are ordered, pairwise-disjoint, and + // non-adjacent). Modifications to this QuicIntervalSet invalidate these + // iterators. + const_iterator begin() const { return intervals_.begin(); } + + // QuicIntervalSet's end() iterator. + const_iterator end() const { return intervals_.end(); } + + // QuicIntervalSet's rbegin() and rend() iterators. Iterator invalidation + // semantics are the same as those for begin() / end(). + const_reverse_iterator rbegin() const { return intervals_.rbegin(); } + + const_reverse_iterator rend() const { return intervals_.rend(); } + + template + void assign(Iter first, Iter last) { + Clear(); + for (; first != last; ++first) Add(*first); + } + + void assign(std::initializer_list il) { + assign(il.begin(), il.end()); + } + + // Returns a human-readable representation of this set. This will typically be + // (though is not guaranteed to be) of the form + // "[a1, b1) [a2, b2) ... [an, bn)" + // where the intervals are in the same order as given by traversal from + // begin() to end(). This representation is intended for human consumption; + // computer programs should not rely on the output being in exactly this form. + std::string ToString() const; + + QuicIntervalSet& operator=(std::initializer_list il) { + assign(il.begin(), il.end()); + return *this; + } + + friend bool operator==(const QuicIntervalSet& a, const QuicIntervalSet& b) { + return a.Size() == b.Size() && + std::equal(a.begin(), a.end(), b.begin(), NonemptyIntervalEq()); + } + + friend bool operator!=(const QuicIntervalSet& a, const QuicIntervalSet& b) { + return !(a == b); + } + + private: + // Simple member-wise equality, since all intervals are non-empty. + struct QUIC_NO_EXPORT NonemptyIntervalEq { + bool operator()(const value_type& a, const value_type& b) const { + return a.min() == b.min() && a.max() == b.max(); + } + }; + + // Returns true if this set is valid (i.e. all intervals in it are non-empty, + // non-adjacent, and mutually disjoint). Currently this is used as an + // integrity check by the Intersection() and Difference() methods, but is only + // invoked for debug builds (via QUICHE_DCHECK). + bool Valid() const; + + // Finds the first interval that potentially intersects 'other'. + const_iterator FindIntersectionCandidate(const QuicIntervalSet& other) const; + + // Finds the first interval that potentially intersects 'interval'. More + // precisely, return an interator it pointing at the last interval J such that + // interval <= J. If all the intervals are > J then return begin(). + const_iterator FindIntersectionCandidate(const value_type& interval) const; + + // Helper for Intersection() and Difference(): Finds the next pair of + // intervals from 'x' and 'y' that intersect. 'mine' is an iterator + // over x->intervals_. 'theirs' is an iterator over y.intervals_. 'mine' + // and 'theirs' are advanced until an intersecting pair is found. + // Non-intersecting intervals (aka "holes") from x->intervals_ can be + // optionally erased by "on_hole". "on_hole" must return an iterator to the + // first element in 'x' after the hole, or x->intervals_.end() if no elements + // exist after the hole. + template + static bool FindNextIntersectingPairImpl(X* x, const QuicIntervalSet& y, + const_iterator* mine, + const_iterator* theirs, + Func on_hole); + + // The variant of the above method that doesn't mutate this QuicIntervalSet. + bool FindNextIntersectingPair(const QuicIntervalSet& other, + const_iterator* mine, + const_iterator* theirs) const { + return FindNextIntersectingPairImpl( + this, other, mine, theirs, + [](const QuicIntervalSet*, const_iterator, const_iterator end) { + return end; + }); + } + + // The variant of the above method that mutates this QuicIntervalSet by + // erasing holes. + bool FindNextIntersectingPairAndEraseHoles(const QuicIntervalSet& other, + const_iterator* mine, + const_iterator* theirs) { + return FindNextIntersectingPairImpl( + this, other, mine, theirs, + [](QuicIntervalSet* x, const_iterator from, const_iterator to) { + return x->intervals_.erase(from, to); + }); + } + + // The representation for the intervals. The intervals in this set are + // non-empty, pairwise-disjoint, non-adjacent and ordered in ascending order + // by min(). + Set intervals_; +}; + +template +auto operator<<(std::ostream& out, const QuicIntervalSet& seq) + -> decltype(out << *seq.begin()) { + out << "{"; + for (const auto& interval : seq) { + out << " " << interval; + } + out << " }"; + + return out; +} + +//============================================================================== +// Implementation details: Clients can stop reading here. + +template +typename QuicIntervalSet::value_type QuicIntervalSet::SpanningInterval() + const { + value_type result; + if (!intervals_.empty()) { + result.SetMin(intervals_.begin()->min()); + result.SetMax(intervals_.rbegin()->max()); + } + return result; +} + +template +void QuicIntervalSet::Add(const value_type& interval) { + if (interval.Empty()) return; + const_iterator it = intervals_.lower_bound(interval.min()); + value_type the_union = interval; + if (it != intervals_.begin()) { + --it; + if (it->Separated(the_union)) { + ++it; + } + } + // Don't erase the elements one at a time, since that will produce quadratic + // work on a flat_set, and apparently an extra log-factor of work for a + // tree-based set. Instead identify the first and last intervals that need to + // be erased, and call erase only once. + const_iterator start = it; + while (it != intervals_.end() && !it->Separated(the_union)) { + the_union.SpanningUnion(*it); + ++it; + } + intervals_.erase(start, it); + intervals_.insert(the_union); +} + +template +bool QuicIntervalSet::Contains(const T& value) const { + // Find the first interval with min() > value, then move back one step + const_iterator it = intervals_.upper_bound(value); + if (it == intervals_.begin()) return false; + --it; + return it->Contains(value); +} + +template +bool QuicIntervalSet::Contains(const value_type& interval) const { + // Find the first interval with min() > value, then move back one step. + const_iterator it = intervals_.upper_bound(interval.min()); + if (it == intervals_.begin()) return false; + --it; + return it->Contains(interval); +} + +template +bool QuicIntervalSet::Contains(const QuicIntervalSet& other) const { + if (!SpanningInterval().Contains(other.SpanningInterval())) { + return false; + } + + for (const_iterator i = other.begin(); i != other.end(); ++i) { + // If we don't contain the interval, can return false now. + if (!Contains(*i)) { + return false; + } + } + return true; +} + +// This method finds the interval that Contains() "value", if such an interval +// exists in the QuicIntervalSet. The way this is done is to locate the +// "candidate interval", the only interval that could *possibly* contain value, +// and test it using Contains(). The candidate interval is the interval with the +// largest min() having min() <= value. +// +// Another detail involves the choice of which Set method to use to try to find +// the candidate interval. The most appropriate entry point is +// Set::upper_bound(), which finds the least interval with a min > the +// value. The semantics of upper_bound() are slightly different from what we +// want (namely, to find the greatest interval which is <= the probe interval) +// but they are close enough; the interval found by upper_bound() will always be +// one step past the interval we are looking for (if it exists) or at begin() +// (if it does not). Getting to the proper interval is a simple matter of +// decrementing the iterator. +template +typename QuicIntervalSet::const_iterator QuicIntervalSet::Find( + const T& value) const { + const_iterator it = intervals_.upper_bound(value); + if (it == intervals_.begin()) return intervals_.end(); + --it; + if (it->Contains(value)) + return it; + else + return intervals_.end(); +} + +// This method finds the interval that Contains() the interval "probe", if such +// an interval exists in the QuicIntervalSet. The way this is done is to locate +// the "candidate interval", the only interval that could *possibly* contain +// "probe", and test it using Contains(). We use the same algorithm as for +// Find(value), except that instead of checking that the value is contained, we +// check that the probe is contained. +template +typename QuicIntervalSet::const_iterator QuicIntervalSet::Find( + const value_type& probe) const { + const_iterator it = intervals_.upper_bound(probe.min()); + if (it == intervals_.begin()) return intervals_.end(); + --it; + if (it->Contains(probe)) + return it; + else + return intervals_.end(); +} + +template +typename QuicIntervalSet::const_iterator QuicIntervalSet::LowerBound( + const T& value) const { + const_iterator it = intervals_.lower_bound(value); + if (it == intervals_.begin()) { + return it; + } + + // The previous intervals_.lower_bound() checking is essentially based on + // interval.min(), so we need to check whether the `value` is contained in + // the previous interval. + --it; + if (it->Contains(value)) { + return it; + } else { + return ++it; + } +} + +template +typename QuicIntervalSet::const_iterator QuicIntervalSet::UpperBound( + const T& value) const { + return intervals_.upper_bound(value); +} + +template +bool QuicIntervalSet::IsDisjoint(const value_type& interval) const { + if (interval.Empty()) return true; + // Find the first interval with min() > interval.min() + const_iterator it = intervals_.upper_bound(interval.min()); + if (it != intervals_.end() && interval.max() > it->min()) return false; + if (it == intervals_.begin()) return true; + --it; + return it->max() <= interval.min(); +} + +template +void QuicIntervalSet::Union(const QuicIntervalSet& other) { + for (const value_type& interval : other.intervals_) { + Add(interval); + } +} + +template +typename QuicIntervalSet::const_iterator +QuicIntervalSet::FindIntersectionCandidate( + const QuicIntervalSet& other) const { + return FindIntersectionCandidate(*other.intervals_.begin()); +} + +template +typename QuicIntervalSet::const_iterator +QuicIntervalSet::FindIntersectionCandidate( + const value_type& interval) const { + // Use upper_bound to efficiently find the first interval in intervals_ + // where min() is greater than interval.min(). If the result + // isn't the beginning of intervals_ then move backwards one interval since + // the interval before it is the first candidate where max() may be + // greater than interval.min(). + // In other words, no interval before that can possibly intersect with any + // of other.intervals_. + const_iterator mine = intervals_.upper_bound(interval.min()); + if (mine != intervals_.begin()) { + --mine; + } + return mine; +} + +template +template +bool QuicIntervalSet::FindNextIntersectingPairImpl(X* x, + const QuicIntervalSet& y, + const_iterator* mine, + const_iterator* theirs, + Func on_hole) { + QUICHE_CHECK(x != nullptr); + if ((*mine == x->intervals_.end()) || (*theirs == y.intervals_.end())) { + return false; + } + while (!(**mine).Intersects(**theirs)) { + const_iterator erase_first = *mine; + // Skip over intervals in 'mine' that don't reach 'theirs'. + while (*mine != x->intervals_.end() && (**mine).max() <= (**theirs).min()) { + ++(*mine); + } + *mine = on_hole(x, erase_first, *mine); + // We're done if the end of intervals_ is reached. + if (*mine == x->intervals_.end()) { + return false; + } + // Skip over intervals 'theirs' that don't reach 'mine'. + while (*theirs != y.intervals_.end() && + (**theirs).max() <= (**mine).min()) { + ++(*theirs); + } + // If the end of other.intervals_ is reached, we're done. + if (*theirs == y.intervals_.end()) { + on_hole(x, *mine, x->intervals_.end()); + return false; + } + } + return true; +} + +template +void QuicIntervalSet::Intersection(const QuicIntervalSet& other) { + if (!SpanningInterval().Intersects(other.SpanningInterval())) { + intervals_.clear(); + return; + } + + const_iterator mine = FindIntersectionCandidate(other); + // Remove any intervals that cannot possibly intersect with other.intervals_. + mine = intervals_.erase(intervals_.begin(), mine); + const_iterator theirs = other.FindIntersectionCandidate(*this); + + while (FindNextIntersectingPairAndEraseHoles(other, &mine, &theirs)) { + // OK, *mine and *theirs intersect. Now, we find the largest + // span of intervals in other (starting at theirs) - say [a..b] + // - that intersect *mine, and we replace *mine with (*mine + // intersect x) for all x in [a..b] Note that subsequent + // intervals in this can't intersect any intervals in [a..b) -- + // they may only intersect b or subsequent intervals in other. + value_type i(*mine); + intervals_.erase(mine); + mine = intervals_.end(); + value_type intersection; + while (theirs != other.intervals_.end() && + i.Intersects(*theirs, &intersection)) { + std::pair ins = intervals_.insert(intersection); + QUICHE_DCHECK(ins.second); + mine = ins.first; + ++theirs; + } + QUICHE_DCHECK(mine != intervals_.end()); + --theirs; + ++mine; + } + QUICHE_DCHECK(Valid()); +} + +template +bool QuicIntervalSet::Intersects(const QuicIntervalSet& other) const { + // Don't bother to handle nonoverlapping spanning intervals as a special case. + // This code runs in time O(n+m), as guaranteed, even for that case . + // Handling the nonoverlapping spanning intervals as a special case doesn't + // improve the asymptotics but does make the code more complex. + auto mine = intervals_.begin(); + auto theirs = other.intervals_.begin(); + while (mine != intervals_.end() && theirs != other.intervals_.end()) { + if (mine->Intersects(*theirs)) + return true; + else if (*mine < *theirs) + ++mine; + else + ++theirs; + } + return false; +} + +template +void QuicIntervalSet::Difference(const value_type& interval) { + if (!SpanningInterval().Intersects(interval)) { + return; + } + Difference(QuicIntervalSet(interval)); +} + +template +void QuicIntervalSet::Difference(const T& min, const T& max) { + Difference(value_type(min, max)); +} + +template +void QuicIntervalSet::Difference(const QuicIntervalSet& other) { + // In order to avoid quadratic-time when using a flat set, we don't try to + // update intervals_ in place. Instead we build up a new result_, always + // inserting at the end which is O(1) time per insertion. Since the number of + // elements in the result is O(Size() + other.Size()), the cost for all the + // insertions is also O(Size() + other.Size()). + // + // We look at all the elements of intervals_, so that's O(Size()). + // + // We also look at all the elements of other.intervals_, for O(other.Size()). + if (Empty()) return; + Set result; + const_iterator mine = intervals_.begin(); + value_type myinterval = *mine; + const_iterator theirs = other.intervals_.begin(); + while (mine != intervals_.end()) { + // Loop invariants: + // myinterval is nonempty. + // mine points at a range that is a suffix of myinterval. + QUICHE_DCHECK(!myinterval.Empty()); + QUICHE_DCHECK(myinterval.max() == mine->max()); + + // There are 3 cases. + // myinterval is completely before theirs (treat theirs==end() as if it is + // infinity). + // --> consume myinterval into result. + // myinterval is completely after theirs + // --> theirs can no longer affect us, so ++theirs. + // myinterval touches theirs with a prefix of myinterval not touching + // *theirs. + // --> consume the prefix of myinterval into the result. + // myinterval touches theirs, with the first element of myinterval in + // *theirs. + // -> reduce myinterval + if (theirs == other.intervals_.end() || myinterval.max() <= theirs->min()) { + // Keep all of my_interval. + result.insert(result.end(), myinterval); + myinterval.Clear(); + } else if (theirs->max() <= myinterval.min()) { + ++theirs; + } else if (myinterval.min() < theirs->min()) { + // Keep a nonempty prefix of my interval. + result.insert(result.end(), value_type(myinterval.min(), theirs->min())); + myinterval.SetMin(theirs->max()); + } else { + // myinterval starts at or after *theirs, chop down myinterval. + myinterval.SetMin(theirs->max()); + } + // if myinterval became empty, find the next interval + if (myinterval.Empty()) { + ++mine; + if (mine != intervals_.end()) { + myinterval = *mine; + } + } + } + std::swap(result, intervals_); + QUICHE_DCHECK(Valid()); +} + +template +void QuicIntervalSet::Complement(const T& min, const T& max) { + QuicIntervalSet span(min, max); + span.Difference(*this); + intervals_.swap(span.intervals_); +} + +template +std::string QuicIntervalSet::ToString() const { + std::ostringstream os; + os << *this; + return os.str(); +} + +template +bool QuicIntervalSet::Valid() const { + const_iterator prev = end(); + for (const_iterator it = begin(); it != end(); ++it) { + // invalid or empty interval. + if (it->min() >= it->max()) return false; + // Not sorted, not disjoint, or adjacent. + if (prev != end() && prev->max() >= it->min()) return false; + prev = it; + } + return true; +} + +// This comparator orders intervals first by ascending min(). The Set never +// contains overlapping intervals, so that suffices. +template +bool QuicIntervalSet::IntervalLess::operator()(const value_type& a, + const value_type& b) const { + // This overload is probably used only by Set::insert(). + return a.min() < b.min(); +} + +// It appears that the Set::lower_bound(T) method uses only two overloads of the +// comparison operator that take a T as the second argument.. In contrast +// Set::upper_bound(T) uses the two overloads that take T as the first argument. +template +bool QuicIntervalSet::IntervalLess::operator()(const value_type& a, + const T& point) const { + // Compare an interval to a point. + return a.min() < point; +} + +template +bool QuicIntervalSet::IntervalLess::operator()(const value_type& a, + T&& point) const { + // Compare an interval to a point + return a.min() < point; +} + +// It appears that the Set::upper_bound(T) method uses only the next two +// overloads of the comparison operator. +template +bool QuicIntervalSet::IntervalLess::operator()(const T& point, + const value_type& a) const { + // Compare an interval to a point. + return point < a.min(); +} + +template +bool QuicIntervalSet::IntervalLess::operator()(T&& point, + const value_type& a) const { + // Compare an interval to a point. + return point < a.min(); +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_INTERVAL_SET_H_ diff --git a/quiche/quic/core/quic_interval_set_test.cc b/quiche/quic/core/quic_interval_set_test.cc new file mode 100644 index 000000000000..b3ac3e66376a --- /dev/null +++ b/quiche/quic/core/quic_interval_set_test.cc @@ -0,0 +1,1062 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_interval_set.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::ElementsAreArray; + +class QuicIntervalSetTest : public QuicTest { + protected: + virtual void SetUp() { + // Initialize two QuicIntervalSets for union, intersection, and difference + // tests + is.Add(100, 200); + is.Add(300, 400); + is.Add(500, 600); + is.Add(700, 800); + is.Add(900, 1000); + is.Add(1100, 1200); + is.Add(1300, 1400); + is.Add(1500, 1600); + is.Add(1700, 1800); + is.Add(1900, 2000); + is.Add(2100, 2200); + + // Lots of different cases: + other.Add(50, 70); // disjoint, at the beginning + other.Add(2250, 2270); // disjoint, at the end + other.Add(650, 670); // disjoint, in the middle + other.Add(350, 360); // included + other.Add(370, 380); // also included (two at once) + other.Add(470, 530); // overlaps low end + other.Add(770, 830); // overlaps high end + other.Add(870, 900); // meets at low end + other.Add(1200, 1230); // meets at high end + other.Add(1270, 1830); // overlaps multiple ranges + } + + virtual void TearDown() { + is.Clear(); + EXPECT_TRUE(is.Empty()); + other.Clear(); + EXPECT_TRUE(other.Empty()); + } + QuicIntervalSet is; + QuicIntervalSet other; +}; + +TEST_F(QuicIntervalSetTest, IsDisjoint) { + EXPECT_TRUE(is.IsDisjoint(QuicInterval(0, 99))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(0, 100))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(200, 200))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(200, 299))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(400, 407))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(405, 499))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(2300, 2300))); + EXPECT_TRUE( + is.IsDisjoint(QuicInterval(2300, std::numeric_limits::max()))); + EXPECT_FALSE(is.IsDisjoint(QuicInterval(100, 105))); + EXPECT_FALSE(is.IsDisjoint(QuicInterval(199, 300))); + EXPECT_FALSE(is.IsDisjoint(QuicInterval(250, 450))); + EXPECT_FALSE(is.IsDisjoint(QuicInterval(299, 400))); + EXPECT_FALSE(is.IsDisjoint(QuicInterval(250, 2000))); + EXPECT_FALSE( + is.IsDisjoint(QuicInterval(2199, std::numeric_limits::max()))); + // Empty intervals. + EXPECT_TRUE(is.IsDisjoint(QuicInterval(90, 90))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(100, 100))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(100, 90))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(150, 150))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(200, 200))); + EXPECT_TRUE(is.IsDisjoint(QuicInterval(400, 300))); +} + +// Base helper method for verifying the contents of an interval set. +// Returns true iff contains intervals whose successive +// endpoints match the sequence of args in : +static bool VA_Check(const QuicIntervalSet& is, int count, va_list ap) { + std::vector> intervals(is.begin(), is.end()); + if (count != static_cast(intervals.size())) { + QUIC_LOG(ERROR) << "Expected " << count << " intervals, got " + << intervals.size() << ": " << is; + return false; + } + if (count != static_cast(is.Size())) { + QUIC_LOG(ERROR) << "Expected " << count << " intervals, got Size " + << is.Size() << ": " << is; + return false; + } + bool result = true; + for (int i = 0; i < count; i++) { + int min = va_arg(ap, int); + int max = va_arg(ap, int); + if (min != intervals[i].min() || max != intervals[i].max()) { + QUIC_LOG(ERROR) << "Expected: [" << min << ", " << max << ") got " + << intervals[i] << " in " << is; + result = false; + } + } + return result; +} + +static bool Check(const QuicIntervalSet& is, int count, ...) { + va_list ap; + va_start(ap, count); + const bool result = VA_Check(is, count, ap); + va_end(ap); + return result; +} + +// Some helper functions for testing Contains and Find, which are logically the +// same. +static void TestContainsAndFind(const QuicIntervalSet& is, int value) { + EXPECT_TRUE(is.Contains(value)) << "Set does not contain " << value; + auto it = is.Find(value); + EXPECT_NE(it, is.end()) << "No iterator to interval containing " << value; + EXPECT_TRUE(it->Contains(value)) << "Iterator does not contain " << value; +} + +static void TestContainsAndFind(const QuicIntervalSet& is, int min, + int max) { + EXPECT_TRUE(is.Contains(min, max)) + << "Set does not contain interval with min " << min << "and max " << max; + auto it = is.Find(min, max); + EXPECT_NE(it, is.end()) << "No iterator to interval with min " << min + << "and max " << max; + EXPECT_TRUE(it->Contains(QuicInterval(min, max))) + << "Iterator does not contain interval with min " << min << "and max " + << max; +} + +static void TestNotContainsAndFind(const QuicIntervalSet& is, int value) { + EXPECT_FALSE(is.Contains(value)) << "Set contains " << value; + auto it = is.Find(value); + EXPECT_EQ(it, is.end()) << "There is iterator to interval containing " + << value; +} + +static void TestNotContainsAndFind(const QuicIntervalSet& is, int min, + int max) { + EXPECT_FALSE(is.Contains(min, max)) + << "Set contains interval with min " << min << "and max " << max; + auto it = is.Find(min, max); + EXPECT_EQ(it, is.end()) << "There is iterator to interval with min " << min + << "and max " << max; +} + +TEST_F(QuicIntervalSetTest, AddInterval) { + QuicIntervalSet s; + s.Add(QuicInterval(0, 10)); + EXPECT_TRUE(Check(s, 1, 0, 10)); +} + +TEST_F(QuicIntervalSetTest, DecrementIterator) { + auto it = is.end(); + EXPECT_NE(it, is.begin()); + --it; + EXPECT_EQ(*it, QuicInterval(2100, 2200)); + ++it; + EXPECT_EQ(it, is.end()); +} + +TEST_F(QuicIntervalSetTest, AddOptimizedForAppend) { + QuicIntervalSet empty_one, empty_two; + empty_one.AddOptimizedForAppend(QuicInterval(0, 99)); + EXPECT_TRUE(Check(empty_one, 1, 0, 99)); + + empty_two.AddOptimizedForAppend(1, 50); + EXPECT_TRUE(Check(empty_two, 1, 1, 50)); + + QuicIntervalSet iset; + iset.AddOptimizedForAppend(100, 150); + iset.AddOptimizedForAppend(200, 250); + EXPECT_TRUE(Check(iset, 2, 100, 150, 200, 250)); + + iset.AddOptimizedForAppend(199, 200); + EXPECT_TRUE(Check(iset, 2, 100, 150, 199, 250)); + + iset.AddOptimizedForAppend(251, 260); + EXPECT_TRUE(Check(iset, 3, 100, 150, 199, 250, 251, 260)); + + iset.AddOptimizedForAppend(252, 260); + EXPECT_TRUE(Check(iset, 3, 100, 150, 199, 250, 251, 260)); + + iset.AddOptimizedForAppend(252, 300); + EXPECT_TRUE(Check(iset, 3, 100, 150, 199, 250, 251, 300)); + + iset.AddOptimizedForAppend(300, 350); + EXPECT_TRUE(Check(iset, 3, 100, 150, 199, 250, 251, 350)); +} + +TEST_F(QuicIntervalSetTest, PopFront) { + QuicIntervalSet iset{{100, 200}, {400, 500}, {700, 800}}; + EXPECT_TRUE(Check(iset, 3, 100, 200, 400, 500, 700, 800)); + + iset.PopFront(); + EXPECT_TRUE(Check(iset, 2, 400, 500, 700, 800)); + + iset.PopFront(); + EXPECT_TRUE(Check(iset, 1, 700, 800)); + + iset.PopFront(); + EXPECT_TRUE(iset.Empty()); +} + +TEST_F(QuicIntervalSetTest, TrimLessThan) { + QuicIntervalSet iset{{100, 200}, {400, 500}, {700, 800}}; + EXPECT_TRUE(Check(iset, 3, 100, 200, 400, 500, 700, 800)); + + EXPECT_FALSE(iset.TrimLessThan(99)); + EXPECT_FALSE(iset.TrimLessThan(100)); + EXPECT_TRUE(Check(iset, 3, 100, 200, 400, 500, 700, 800)); + + EXPECT_TRUE(iset.TrimLessThan(101)); + EXPECT_TRUE(Check(iset, 3, 101, 200, 400, 500, 700, 800)); + + EXPECT_TRUE(iset.TrimLessThan(199)); + EXPECT_TRUE(Check(iset, 3, 199, 200, 400, 500, 700, 800)); + + EXPECT_TRUE(iset.TrimLessThan(450)); + EXPECT_TRUE(Check(iset, 2, 450, 500, 700, 800)); + + EXPECT_TRUE(iset.TrimLessThan(500)); + EXPECT_TRUE(Check(iset, 1, 700, 800)); + + EXPECT_TRUE(iset.TrimLessThan(801)); + EXPECT_TRUE(iset.Empty()); + + EXPECT_FALSE(iset.TrimLessThan(900)); + EXPECT_TRUE(iset.Empty()); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetBasic) { + // Test Add, Get, Contains and Find + QuicIntervalSet iset; + EXPECT_TRUE(iset.Empty()); + EXPECT_EQ(0u, iset.Size()); + iset.Add(100, 200); + EXPECT_FALSE(iset.Empty()); + EXPECT_EQ(1u, iset.Size()); + iset.Add(100, 150); + iset.Add(150, 200); + iset.Add(130, 170); + iset.Add(90, 150); + iset.Add(170, 220); + iset.Add(300, 400); + iset.Add(250, 450); + EXPECT_FALSE(iset.Empty()); + EXPECT_EQ(2u, iset.Size()); + EXPECT_TRUE(Check(iset, 2, 90, 220, 250, 450)); + + // Test two intervals with a.max == b.min, that will just join up. + iset.Clear(); + iset.Add(100, 200); + iset.Add(200, 300); + EXPECT_FALSE(iset.Empty()); + EXPECT_EQ(1u, iset.Size()); + EXPECT_TRUE(Check(iset, 1, 100, 300)); + + // Test adding two sets together. + iset.Clear(); + QuicIntervalSet iset_add; + iset.Add(100, 200); + iset.Add(100, 150); + iset.Add(150, 200); + iset.Add(130, 170); + iset_add.Add(90, 150); + iset_add.Add(170, 220); + iset_add.Add(300, 400); + iset_add.Add(250, 450); + + iset.Union(iset_add); + EXPECT_FALSE(iset.Empty()); + EXPECT_EQ(2u, iset.Size()); + EXPECT_TRUE(Check(iset, 2, 90, 220, 250, 450)); + + // Test begin()/end(), and rbegin()/rend() + // to iterate over intervals. + { + std::vector> expected(iset.begin(), iset.end()); + + std::vector> actual1; + std::copy(iset.begin(), iset.end(), back_inserter(actual1)); + ASSERT_EQ(expected.size(), actual1.size()); + + std::vector> actual2; + std::copy(iset.begin(), iset.end(), back_inserter(actual2)); + ASSERT_EQ(expected.size(), actual2.size()); + + for (size_t i = 0; i < expected.size(); i++) { + EXPECT_EQ(expected[i].min(), actual1[i].min()); + EXPECT_EQ(expected[i].max(), actual1[i].max()); + + EXPECT_EQ(expected[i].min(), actual2[i].min()); + EXPECT_EQ(expected[i].max(), actual2[i].max()); + } + + // Ensure that the rbegin()/rend() iterators correctly yield the intervals + // in reverse order. + EXPECT_THAT(std::vector>(iset.rbegin(), iset.rend()), + ElementsAreArray(expected.rbegin(), expected.rend())); + } + + TestNotContainsAndFind(iset, 89); + TestContainsAndFind(iset, 90); + TestContainsAndFind(iset, 120); + TestContainsAndFind(iset, 219); + TestNotContainsAndFind(iset, 220); + TestNotContainsAndFind(iset, 235); + TestNotContainsAndFind(iset, 249); + TestContainsAndFind(iset, 250); + TestContainsAndFind(iset, 300); + TestContainsAndFind(iset, 449); + TestNotContainsAndFind(iset, 450); + TestNotContainsAndFind(iset, 451); + + TestNotContainsAndFind(iset, 50, 60); + TestNotContainsAndFind(iset, 50, 90); + TestNotContainsAndFind(iset, 50, 200); + TestNotContainsAndFind(iset, 90, 90); + TestContainsAndFind(iset, 90, 200); + TestContainsAndFind(iset, 100, 200); + TestContainsAndFind(iset, 100, 220); + TestNotContainsAndFind(iset, 100, 221); + TestNotContainsAndFind(iset, 220, 220); + TestNotContainsAndFind(iset, 240, 300); + TestContainsAndFind(iset, 250, 300); + TestContainsAndFind(iset, 260, 300); + TestContainsAndFind(iset, 300, 450); + TestNotContainsAndFind(iset, 300, 451); + + QuicIntervalSet iset_contains; + iset_contains.Add(50, 90); + EXPECT_FALSE(iset.Contains(iset_contains)); + iset_contains.Clear(); + + iset_contains.Add(90, 200); + EXPECT_TRUE(iset.Contains(iset_contains)); + iset_contains.Add(100, 200); + EXPECT_TRUE(iset.Contains(iset_contains)); + iset_contains.Add(100, 220); + EXPECT_TRUE(iset.Contains(iset_contains)); + iset_contains.Add(250, 300); + EXPECT_TRUE(iset.Contains(iset_contains)); + iset_contains.Add(300, 450); + EXPECT_TRUE(iset.Contains(iset_contains)); + iset_contains.Add(300, 451); + EXPECT_FALSE(iset.Contains(iset_contains)); + EXPECT_FALSE(iset.Contains(QuicInterval())); + EXPECT_FALSE(iset.Contains(QuicIntervalSet())); + + // Check the case where the query set isn't contained, but the spanning + // intervals do overlap. + QuicIntervalSet i2({{220, 230}}); + EXPECT_FALSE(iset.Contains(i2)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetContainsEmpty) { + const QuicIntervalSet empty; + const QuicIntervalSet other_empty; + const QuicIntervalSet non_empty({{10, 20}, {40, 50}}); + EXPECT_FALSE(empty.Contains(empty)); + EXPECT_FALSE(empty.Contains(other_empty)); + EXPECT_FALSE(empty.Contains(non_empty)); + EXPECT_FALSE(non_empty.Contains(empty)); +} + +TEST_F(QuicIntervalSetTest, Equality) { + QuicIntervalSet is_copy = is; + EXPECT_EQ(is, is); + EXPECT_EQ(is, is_copy); + EXPECT_NE(is, other); + EXPECT_NE(is, QuicIntervalSet()); + EXPECT_EQ(QuicIntervalSet(), QuicIntervalSet()); +} + +TEST_F(QuicIntervalSetTest, LowerAndUpperBound) { + QuicIntervalSet intervals; + intervals.Add(10, 20); + intervals.Add(30, 40); + + // [10, 20) [30, 40) end + // ^ LowerBound(5) + // ^ LowerBound(10) + // ^ LowerBound(15) + // ^ LowerBound(20) + // ^ LowerBound(25) + // ^ LowerBound(30) + // ^ LowerBound(35) + // ^ LowerBound(40) + // ^ LowerBound(50) + EXPECT_EQ(intervals.LowerBound(5)->min(), 10); + EXPECT_EQ(intervals.LowerBound(10)->min(), 10); + EXPECT_EQ(intervals.LowerBound(15)->min(), 10); + EXPECT_EQ(intervals.LowerBound(20)->min(), 30); + EXPECT_EQ(intervals.LowerBound(25)->min(), 30); + EXPECT_EQ(intervals.LowerBound(30)->min(), 30); + EXPECT_EQ(intervals.LowerBound(35)->min(), 30); + EXPECT_EQ(intervals.LowerBound(40), intervals.end()); + EXPECT_EQ(intervals.LowerBound(50), intervals.end()); + + // [10, 20) [30, 40) end + // ^ UpperBound(5) + // ^ UpperBound(10) + // ^ UpperBound(15) + // ^ UpperBound(20) + // ^ UpperBound(25) + // ^ UpperBound(30) + // ^ UpperBound(35) + // ^ UpperBound(40) + // ^ UpperBound(50) + EXPECT_EQ(intervals.UpperBound(5)->min(), 10); + EXPECT_EQ(intervals.UpperBound(10)->min(), 30); + EXPECT_EQ(intervals.UpperBound(15)->min(), 30); + EXPECT_EQ(intervals.UpperBound(20)->min(), 30); + EXPECT_EQ(intervals.UpperBound(25)->min(), 30); + EXPECT_EQ(intervals.UpperBound(30), intervals.end()); + EXPECT_EQ(intervals.UpperBound(35), intervals.end()); + EXPECT_EQ(intervals.UpperBound(40), intervals.end()); + EXPECT_EQ(intervals.UpperBound(50), intervals.end()); +} + +TEST_F(QuicIntervalSetTest, SpanningInterval) { + // Spanning interval of an empty set is empty: + { + QuicIntervalSet iset; + const QuicInterval& ival = iset.SpanningInterval(); + EXPECT_TRUE(ival.Empty()); + } + + // Spanning interval of a set with one interval is that interval: + { + QuicIntervalSet iset; + iset.Add(100, 200); + const QuicInterval& ival = iset.SpanningInterval(); + EXPECT_EQ(100, ival.min()); + EXPECT_EQ(200, ival.max()); + } + + // Spanning interval of a set with multiple elements is determined + // by the endpoints of the first and last element: + { + const QuicInterval& ival = is.SpanningInterval(); + EXPECT_EQ(100, ival.min()); + EXPECT_EQ(2200, ival.max()); + } + { + const QuicInterval& ival = other.SpanningInterval(); + EXPECT_EQ(50, ival.min()); + EXPECT_EQ(2270, ival.max()); + } +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetUnion) { + is.Union(other); + EXPECT_TRUE(Check(is, 12, 50, 70, 100, 200, 300, 400, 470, 600, 650, 670, 700, + 830, 870, 1000, 1100, 1230, 1270, 1830, 1900, 2000, 2100, + 2200, 2250, 2270)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersection) { + EXPECT_TRUE(is.Intersects(other)); + EXPECT_TRUE(other.Intersects(is)); + is.Intersection(other); + EXPECT_TRUE(Check(is, 7, 350, 360, 370, 380, 500, 530, 770, 800, 1300, 1400, + 1500, 1600, 1700, 1800)); + EXPECT_TRUE(is.Intersects(other)); + EXPECT_TRUE(other.Intersects(is)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionBothEmpty) { + QuicIntervalSet mine, theirs; + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionEmptyMine) { + QuicIntervalSet mine; + QuicIntervalSet theirs("a", "b"); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionEmptyTheirs) { + QuicIntervalSet mine("a", "b"); + QuicIntervalSet theirs; + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionTheirsBeforeMine) { + QuicIntervalSet mine("y", "z"); + QuicIntervalSet theirs; + theirs.Add("a", "b"); + theirs.Add("c", "d"); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionMineBeforeTheirs) { + QuicIntervalSet mine; + mine.Add("a", "b"); + mine.Add("c", "d"); + QuicIntervalSet theirs("y", "z"); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, + QuicIntervalSetIntersectionTheirsBeforeMineInt64Singletons) { + QuicIntervalSet mine({{10, 15}}); + QuicIntervalSet theirs({{-20, -5}}); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, + QuicIntervalSetIntersectionMineBeforeTheirsIntSingletons) { + QuicIntervalSet mine({{10, 15}}); + QuicIntervalSet theirs({{90, 95}}); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionTheirsBetweenMine) { + QuicIntervalSet mine({{0, 5}, {40, 50}}); + QuicIntervalSet theirs({{10, 15}}); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionMineBetweenTheirs) { + QuicIntervalSet mine({{20, 25}}); + QuicIntervalSet theirs({{10, 15}, {30, 32}}); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionAlternatingIntervals) { + QuicIntervalSet mine, theirs; + mine.Add(10, 20); + mine.Add(40, 50); + mine.Add(60, 70); + theirs.Add(25, 39); + theirs.Add(55, 59); + theirs.Add(75, 79); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(mine.Empty()); + EXPECT_FALSE(mine.Intersects(theirs)); + EXPECT_FALSE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, + QuicIntervalSetIntersectionAdjacentAlternatingNonIntersectingIntervals) { + // Make sure that intersection with adjacent interval set is empty. + const QuicIntervalSet x1({{0, 10}}); + const QuicIntervalSet y1({{-50, 0}, {10, 95}}); + + QuicIntervalSet result1 = x1; + result1.Intersection(y1); + EXPECT_TRUE(result1.Empty()) << result1; + + const QuicIntervalSet x2({{0, 10}, {20, 30}, {40, 90}}); + const QuicIntervalSet y2( + {{-50, -40}, {-2, 0}, {10, 20}, {32, 40}, {90, 95}}); + + QuicIntervalSet result2 = x2; + result2.Intersection(y2); + EXPECT_TRUE(result2.Empty()) << result2; + + const QuicIntervalSet x3({{-1, 5}, {5, 10}}); + const QuicIntervalSet y3({{-10, -1}, {10, 95}}); + + QuicIntervalSet result3 = x3; + result3.Intersection(y3); + EXPECT_TRUE(result3.Empty()) << result3; +} + +TEST_F(QuicIntervalSetTest, + QuicIntervalSetIntersectionAlternatingIntersectingIntervals) { + const QuicIntervalSet x1({{0, 10}}); + const QuicIntervalSet y1({{-50, 1}, {9, 95}}); + const QuicIntervalSet expected_result1({{0, 1}, {9, 10}}); + + QuicIntervalSet result1 = x1; + result1.Intersection(y1); + EXPECT_EQ(result1, expected_result1); + + const QuicIntervalSet x2({{0, 10}, {20, 30}, {40, 90}}); + const QuicIntervalSet y2( + {{-50, -40}, {-2, 2}, {9, 21}, {32, 41}, {85, 95}}); + const QuicIntervalSet expected_result2( + {{0, 2}, {9, 10}, {20, 21}, {40, 41}, {85, 90}}); + + QuicIntervalSet result2 = x2; + result2.Intersection(y2); + EXPECT_EQ(result2, expected_result2); + + const QuicIntervalSet x3({{-1, 5}, {5, 10}}); + const QuicIntervalSet y3({{-10, 3}, {4, 95}}); + const QuicIntervalSet expected_result3({{-1, 3}, {4, 10}}); + + QuicIntervalSet result3 = x3; + result3.Intersection(y3); + EXPECT_EQ(result3, expected_result3); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionIdentical) { + QuicIntervalSet copy(is); + EXPECT_TRUE(copy.Intersects(is)); + EXPECT_TRUE(is.Intersects(copy)); + is.Intersection(copy); + EXPECT_EQ(copy, is); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionSuperset) { + QuicIntervalSet mine(-1, 10000); + EXPECT_TRUE(mine.Intersects(is)); + EXPECT_TRUE(is.Intersects(mine)); + mine.Intersection(is); + EXPECT_EQ(is, mine); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionSubset) { + QuicIntervalSet copy(is); + QuicIntervalSet theirs(-1, 10000); + EXPECT_TRUE(copy.Intersects(theirs)); + EXPECT_TRUE(theirs.Intersects(copy)); + is.Intersection(theirs); + EXPECT_EQ(copy, is); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetIntersectionLargeSet) { + QuicIntervalSet mine, theirs; + // mine: [0, 9), [10, 19), ..., [990, 999) + for (int i = 0; i < 1000; i += 10) { + mine.Add(i, i + 9); + } + + theirs.Add(500, 520); + theirs.Add(535, 545); + theirs.Add(801, 809); + EXPECT_TRUE(mine.Intersects(theirs)); + EXPECT_TRUE(theirs.Intersects(mine)); + mine.Intersection(theirs); + EXPECT_TRUE(Check(mine, 5, 500, 509, 510, 519, 535, 539, 540, 545, 801, 809)); + EXPECT_TRUE(mine.Intersects(theirs)); + EXPECT_TRUE(theirs.Intersects(mine)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifference) { + is.Difference(other); + EXPECT_TRUE(Check(is, 10, 100, 200, 300, 350, 360, 370, 380, 400, 530, 600, + 700, 770, 900, 1000, 1100, 1200, 1900, 2000, 2100, 2200)); + QuicIntervalSet copy = is; + is.Difference(copy); + EXPECT_TRUE(is.Empty()); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceSingleBounds) { + std::vector> ivals(other.begin(), other.end()); + for (const QuicInterval& ival : ivals) { + is.Difference(ival.min(), ival.max()); + } + EXPECT_TRUE(Check(is, 10, 100, 200, 300, 350, 360, 370, 380, 400, 530, 600, + 700, 770, 900, 1000, 1100, 1200, 1900, 2000, 2100, 2200)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceSingleInterval) { + std::vector> ivals(other.begin(), other.end()); + for (const QuicInterval& ival : ivals) { + is.Difference(ival); + } + EXPECT_TRUE(Check(is, 10, 100, 200, 300, 350, 360, 370, 380, 400, 530, 600, + 700, 770, 900, 1000, 1100, 1200, 1900, 2000, 2100, 2200)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceAlternatingIntervals) { + QuicIntervalSet mine, theirs; + mine.Add(10, 20); + mine.Add(40, 50); + mine.Add(60, 70); + theirs.Add(25, 39); + theirs.Add(55, 59); + theirs.Add(75, 79); + + mine.Difference(theirs); + EXPECT_TRUE(Check(mine, 3, 10, 20, 40, 50, 60, 70)); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceEmptyMine) { + QuicIntervalSet mine, theirs; + theirs.Add("a", "b"); + + mine.Difference(theirs); + EXPECT_TRUE(mine.Empty()); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceEmptyTheirs) { + QuicIntervalSet mine, theirs; + mine.Add("a", "b"); + + mine.Difference(theirs); + EXPECT_EQ(1u, mine.Size()); + EXPECT_EQ("a", mine.begin()->min()); + EXPECT_EQ("b", mine.begin()->max()); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceTheirsBeforeMine) { + QuicIntervalSet mine, theirs; + mine.Add("y", "z"); + theirs.Add("a", "b"); + + mine.Difference(theirs); + EXPECT_EQ(1u, mine.Size()); + EXPECT_EQ("y", mine.begin()->min()); + EXPECT_EQ("z", mine.begin()->max()); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceMineBeforeTheirs) { + QuicIntervalSet mine, theirs; + mine.Add("a", "b"); + theirs.Add("y", "z"); + + mine.Difference(theirs); + EXPECT_EQ(1u, mine.Size()); + EXPECT_EQ("a", mine.begin()->min()); + EXPECT_EQ("b", mine.begin()->max()); +} + +TEST_F(QuicIntervalSetTest, QuicIntervalSetDifferenceIdentical) { + QuicIntervalSet mine; + mine.Add("a", "b"); + mine.Add("c", "d"); + QuicIntervalSet theirs(mine); + + mine.Difference(theirs); + EXPECT_TRUE(mine.Empty()); +} + +TEST_F(QuicIntervalSetTest, EmptyComplement) { + // The complement of an empty set is the input interval: + QuicIntervalSet iset; + iset.Complement(100, 200); + EXPECT_TRUE(Check(iset, 1, 100, 200)); +} + +TEST(QuicIntervalSetMultipleCompactionTest, OuterCovering) { + QuicIntervalSet iset; + // First add a bunch of disjoint ranges + iset.Add(100, 150); + iset.Add(200, 250); + iset.Add(300, 350); + iset.Add(400, 450); + EXPECT_TRUE(Check(iset, 4, 100, 150, 200, 250, 300, 350, 400, 450)); + // Now add a big range that covers all of these ranges + iset.Add(0, 500); + EXPECT_TRUE(Check(iset, 1, 0, 500)); +} + +TEST(QuicIntervalSetMultipleCompactionTest, InnerCovering) { + QuicIntervalSet iset; + // First add a bunch of disjoint ranges + iset.Add(100, 150); + iset.Add(200, 250); + iset.Add(300, 350); + iset.Add(400, 450); + EXPECT_TRUE(Check(iset, 4, 100, 150, 200, 250, 300, 350, 400, 450)); + // Now add a big range that partially covers the left and right most ranges. + iset.Add(125, 425); + EXPECT_TRUE(Check(iset, 1, 100, 450)); +} + +TEST(QuicIntervalSetMultipleCompactionTest, LeftCovering) { + QuicIntervalSet iset; + // First add a bunch of disjoint ranges + iset.Add(100, 150); + iset.Add(200, 250); + iset.Add(300, 350); + iset.Add(400, 450); + EXPECT_TRUE(Check(iset, 4, 100, 150, 200, 250, 300, 350, 400, 450)); + // Now add a big range that partially covers the left most range. + iset.Add(125, 500); + EXPECT_TRUE(Check(iset, 1, 100, 500)); +} + +TEST(QuicIntervalSetMultipleCompactionTest, RightCovering) { + QuicIntervalSet iset; + // First add a bunch of disjoint ranges + iset.Add(100, 150); + iset.Add(200, 250); + iset.Add(300, 350); + iset.Add(400, 450); + EXPECT_TRUE(Check(iset, 4, 100, 150, 200, 250, 300, 350, 400, 450)); + // Now add a big range that partially covers the right most range. + iset.Add(0, 425); + EXPECT_TRUE(Check(iset, 1, 0, 450)); +} + +// Helper method for testing and verifying the results of a one-interval +// completement case. +static bool CheckOneComplement(int add_min, int add_max, int comp_min, + int comp_max, int count, ...) { + QuicIntervalSet iset; + iset.Add(add_min, add_max); + iset.Complement(comp_min, comp_max); + bool result = true; + va_list ap; + va_start(ap, count); + if (!VA_Check(iset, count, ap)) { + result = false; + } + va_end(ap); + return result; +} + +TEST_F(QuicIntervalSetTest, SingleIntervalComplement) { + // Verify the complement of a set with one interval (i): + // |----- i -----| + // |----- args -----| + EXPECT_TRUE(CheckOneComplement(0, 10, 50, 150, 1, 50, 150)); + + // |----- i -----| + // |----- args -----| + EXPECT_TRUE(CheckOneComplement(50, 150, 0, 100, 1, 0, 50)); + + // |----- i -----| + // |----- args -----| + EXPECT_TRUE(CheckOneComplement(50, 150, 50, 150, 0)); + + // |---------- i ----------| + // |----- args -----| + EXPECT_TRUE(CheckOneComplement(50, 500, 100, 300, 0)); + + // |----- i -----| + // |---------- args ----------| + EXPECT_TRUE(CheckOneComplement(50, 500, 0, 800, 2, 0, 50, 500, 800)); + + // |----- i -----| + // |----- args -----| + EXPECT_TRUE(CheckOneComplement(50, 150, 100, 300, 1, 150, 300)); + + // |----- i -----| + // |----- args -----| + EXPECT_TRUE(CheckOneComplement(50, 150, 200, 300, 1, 200, 300)); +} + +// Helper method that copies and takes its complement, +// returning false if Check succeeds. +static bool CheckComplement(const QuicIntervalSet& iset, int comp_min, + int comp_max, int count, ...) { + QuicIntervalSet iset_copy = iset; + iset_copy.Complement(comp_min, comp_max); + bool result = true; + va_list ap; + va_start(ap, count); + if (!VA_Check(iset_copy, count, ap)) { + result = false; + } + va_end(ap); + return result; +} + +TEST_F(QuicIntervalSetTest, MultiIntervalComplement) { + // Initialize a small test set: + QuicIntervalSet iset; + iset.Add(100, 200); + iset.Add(300, 400); + iset.Add(500, 600); + + // |----- i -----| + // |----- comp -----| + EXPECT_TRUE(CheckComplement(iset, 0, 50, 1, 0, 50)); + + // |----- i -----| + // |----- comp -----| + EXPECT_TRUE(CheckComplement(iset, 0, 200, 1, 0, 100)); + EXPECT_TRUE(CheckComplement(iset, 0, 220, 2, 0, 100, 200, 220)); + + // |----- i -----| + // |----- comp -----| + EXPECT_TRUE(CheckComplement(iset, 100, 600, 2, 200, 300, 400, 500)); + + // |---------- i ----------| + // |----- comp -----| + EXPECT_TRUE(CheckComplement(iset, 300, 400, 0)); + EXPECT_TRUE(CheckComplement(iset, 250, 400, 1, 250, 300)); + EXPECT_TRUE(CheckComplement(iset, 300, 450, 1, 400, 450)); + EXPECT_TRUE(CheckComplement(iset, 250, 450, 2, 250, 300, 400, 450)); + + // |----- i -----| + // |---------- comp ----------| + EXPECT_TRUE( + CheckComplement(iset, 0, 700, 4, 0, 100, 200, 300, 400, 500, 600, 700)); + + // |----- i -----| + // |----- comp -----| + EXPECT_TRUE(CheckComplement(iset, 400, 700, 2, 400, 500, 600, 700)); + EXPECT_TRUE(CheckComplement(iset, 350, 700, 2, 400, 500, 600, 700)); + + // |----- i -----| + // |----- comp -----| + EXPECT_TRUE(CheckComplement(iset, 700, 800, 1, 700, 800)); +} + +// Verifies ToString, operator<< don't assert. +TEST_F(QuicIntervalSetTest, ToString) { + QuicIntervalSet iset; + iset.Add(300, 400); + iset.Add(100, 200); + iset.Add(500, 600); + EXPECT_TRUE(!iset.ToString().empty()); + QUIC_VLOG(2) << iset; + // Order and format of ToString() output is guaranteed. + EXPECT_EQ("{ [100, 200) [300, 400) [500, 600) }", iset.ToString()); + EXPECT_EQ("{ [1, 2) }", QuicIntervalSet(1, 2).ToString()); + EXPECT_EQ("{ }", QuicIntervalSet().ToString()); +} + +TEST_F(QuicIntervalSetTest, ConstructionDiscardsEmptyInterval) { + EXPECT_TRUE(QuicIntervalSet(QuicInterval(2, 2)).Empty()); + EXPECT_TRUE(QuicIntervalSet(2, 2).Empty()); + EXPECT_FALSE(QuicIntervalSet(QuicInterval(2, 3)).Empty()); + EXPECT_FALSE(QuicIntervalSet(2, 3).Empty()); +} + +TEST_F(QuicIntervalSetTest, Swap) { + QuicIntervalSet a, b; + a.Add(300, 400); + b.Add(100, 200); + b.Add(500, 600); + std::swap(a, b); + EXPECT_TRUE(Check(a, 2, 100, 200, 500, 600)); + EXPECT_TRUE(Check(b, 1, 300, 400)); + std::swap(a, b); + EXPECT_TRUE(Check(a, 1, 300, 400)); + EXPECT_TRUE(Check(b, 2, 100, 200, 500, 600)); +} + +TEST_F(QuicIntervalSetTest, OutputReturnsOstreamRef) { + std::stringstream ss; + const QuicIntervalSet v(QuicInterval(1, 2)); + auto return_type_is_a_ref = [](std::ostream&) {}; + return_type_is_a_ref(ss << v); +} + +struct NotOstreamable { + bool operator<(const NotOstreamable&) const { return false; } + bool operator>(const NotOstreamable&) const { return false; } + bool operator!=(const NotOstreamable&) const { return false; } + bool operator>=(const NotOstreamable&) const { return true; } + bool operator<=(const NotOstreamable&) const { return true; } + bool operator==(const NotOstreamable&) const { return true; } +}; + +TEST_F(QuicIntervalSetTest, IntervalOfTypeWithNoOstreamSupport) { + const NotOstreamable v; + const QuicIntervalSet d(QuicInterval(v, v)); + // EXPECT_EQ builds a string representation of d. If d::operator<<() + // would be defined then this test would not compile because NotOstreamable + // objects lack the operator<<() support. + EXPECT_EQ(d, d); +} + +class QuicIntervalSetInitTest : public QuicTest { + protected: + const std::vector> intervals_{{0, 1}, {2, 4}}; +}; + +TEST_F(QuicIntervalSetInitTest, DirectInit) { + std::initializer_list> il = {{0, 1}, {2, 3}, {3, 4}}; + QuicIntervalSet s(il); + EXPECT_THAT(s, ElementsAreArray(intervals_)); +} + +TEST_F(QuicIntervalSetInitTest, CopyInit) { + std::initializer_list> il = {{0, 1}, {2, 3}, {3, 4}}; + QuicIntervalSet s = il; + EXPECT_THAT(s, ElementsAreArray(intervals_)); +} + +TEST_F(QuicIntervalSetInitTest, AssignIterPair) { + QuicIntervalSet s(0, 1000); // Make sure assign clears. + s.assign(intervals_.begin(), intervals_.end()); + EXPECT_THAT(s, ElementsAreArray(intervals_)); +} + +TEST_F(QuicIntervalSetInitTest, AssignInitList) { + QuicIntervalSet s(0, 1000); // Make sure assign clears. + s.assign({{0, 1}, {2, 3}, {3, 4}}); + EXPECT_THAT(s, ElementsAreArray(intervals_)); +} + +TEST_F(QuicIntervalSetInitTest, AssignmentInitList) { + std::initializer_list> il = {{0, 1}, {2, 3}, {3, 4}}; + QuicIntervalSet s; + s = il; + EXPECT_THAT(s, ElementsAreArray(intervals_)); +} + +TEST_F(QuicIntervalSetInitTest, BracedInitThenBracedAssign) { + QuicIntervalSet s{{0, 1}, {2, 3}, {3, 4}}; + s = {{0, 1}, {2, 4}}; + EXPECT_THAT(s, ElementsAreArray(intervals_)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_interval_test.cc b/quiche/quic/core/quic_interval_test.cc new file mode 100644 index 000000000000..9a7c70d9c2cf --- /dev/null +++ b/quiche/quic/core/quic_interval_test.cc @@ -0,0 +1,467 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_interval.h" + +#include +#include +#include +#include + +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +template +void STLDeleteContainerPointers(ForwardIterator begin, ForwardIterator end) { + while (begin != end) { + auto temp = begin; + ++begin; + delete *temp; + } +} + +template +void STLDeleteElements(T* container) { + if (!container) return; + STLDeleteContainerPointers(container->begin(), container->end()); + container->clear(); +} + +class ConstructorListener { + public: + ConstructorListener(int* copy_construct_counter, int* move_construct_counter) + : copy_construct_counter_(copy_construct_counter), + move_construct_counter_(move_construct_counter) { + *copy_construct_counter_ = 0; + *move_construct_counter_ = 0; + } + ConstructorListener(const ConstructorListener& other) { + copy_construct_counter_ = other.copy_construct_counter_; + move_construct_counter_ = other.move_construct_counter_; + ++*copy_construct_counter_; + } + ConstructorListener(ConstructorListener&& other) { + copy_construct_counter_ = other.copy_construct_counter_; + move_construct_counter_ = other.move_construct_counter_; + ++*move_construct_counter_; + } + bool operator<(const ConstructorListener&) { return false; } + bool operator>(const ConstructorListener&) { return false; } + bool operator<=(const ConstructorListener&) { return true; } + bool operator>=(const ConstructorListener&) { return true; } + bool operator==(const ConstructorListener&) { return true; } + + private: + int* copy_construct_counter_; + int* move_construct_counter_; +}; + +TEST(QuicIntervalConstructorTest, Move) { + int object1_copy_count, object1_move_count; + ConstructorListener object1(&object1_copy_count, &object1_move_count); + int object2_copy_count, object2_move_count; + ConstructorListener object2(&object2_copy_count, &object2_move_count); + + QuicInterval interval(object1, std::move(object2)); + EXPECT_EQ(1, object1_copy_count); + EXPECT_EQ(0, object1_move_count); + EXPECT_EQ(0, object2_copy_count); + EXPECT_EQ(1, object2_move_count); +} + +TEST(QuicIntervalConstructorTest, ImplicitConversion) { + struct WrappedInt { + WrappedInt(int value) : value(value) {} + bool operator<(const WrappedInt& other) { return value < other.value; } + bool operator>(const WrappedInt& other) { return value > other.value; } + bool operator<=(const WrappedInt& other) { return value <= other.value; } + bool operator>=(const WrappedInt& other) { return value >= other.value; } + bool operator==(const WrappedInt& other) { return value == other.value; } + int value; + }; + + static_assert(std::is_convertible::value, ""); + static_assert( + std::is_constructible, int, int>::value, ""); + + QuicInterval i(10, 20); + EXPECT_EQ(10, i.min().value); + EXPECT_EQ(20, i.max().value); +} + +class QuicIntervalTest : public QuicTest { + protected: + // Test intersection between the two intervals i1 and i2. Tries + // i1.IntersectWith(i2) and vice versa. The intersection should change i1 iff + // changes_i1 is true, and the same for changes_i2. The resulting + // intersection should be result. + void TestIntersect(const QuicInterval& i1, + const QuicInterval& i2, bool changes_i1, + bool changes_i2, const QuicInterval& result) { + QuicInterval i; + i = i1; + EXPECT_TRUE(i.IntersectWith(i2) == changes_i1 && i == result); + i = i2; + EXPECT_TRUE(i.IntersectWith(i1) == changes_i2 && i == result); + } +}; + +TEST_F(QuicIntervalTest, ConstructorsCopyAndClear) { + QuicInterval empty; + EXPECT_TRUE(empty.Empty()); + + QuicInterval d2(0, 100); + EXPECT_EQ(0, d2.min()); + EXPECT_EQ(100, d2.max()); + EXPECT_EQ(QuicInterval(0, 100), d2); + EXPECT_NE(QuicInterval(0, 99), d2); + + empty = d2; + EXPECT_EQ(0, d2.min()); + EXPECT_EQ(100, d2.max()); + EXPECT_TRUE(empty == d2); + EXPECT_EQ(empty, d2); + EXPECT_TRUE(d2 == empty); + EXPECT_EQ(d2, empty); + + QuicInterval max_less_than_min(40, 20); + EXPECT_TRUE(max_less_than_min.Empty()); + EXPECT_EQ(40, max_less_than_min.min()); + EXPECT_EQ(20, max_less_than_min.max()); + + QuicInterval d3(10, 20); + d3.Clear(); + EXPECT_TRUE(d3.Empty()); +} + +TEST_F(QuicIntervalTest, MakeQuicInterval) { + static_assert( + std::is_same, decltype(MakeQuicInterval(0, 3))>::value, + "Type is deduced incorrectly."); + static_assert(std::is_same, + decltype(MakeQuicInterval(0., 3.))>::value, + "Type is deduced incorrectly."); + + EXPECT_EQ(MakeQuicInterval(0., 3.), QuicInterval(0, 3)); +} + +TEST_F(QuicIntervalTest, GettersSetters) { + QuicInterval d1(100, 200); + + // SetMin: + d1.SetMin(30); + EXPECT_EQ(30, d1.min()); + EXPECT_EQ(200, d1.max()); + + // SetMax: + d1.SetMax(220); + EXPECT_EQ(30, d1.min()); + EXPECT_EQ(220, d1.max()); + + // Set: + d1.Clear(); + d1.Set(30, 220); + EXPECT_EQ(30, d1.min()); + EXPECT_EQ(220, d1.max()); + + // SpanningUnion: + QuicInterval d2; + EXPECT_TRUE(!d1.SpanningUnion(d2)); + EXPECT_EQ(30, d1.min()); + EXPECT_EQ(220, d1.max()); + + EXPECT_TRUE(d2.SpanningUnion(d1)); + EXPECT_EQ(30, d2.min()); + EXPECT_EQ(220, d2.max()); + + d2.SetMin(40); + d2.SetMax(100); + EXPECT_TRUE(!d1.SpanningUnion(d2)); + EXPECT_EQ(30, d1.min()); + EXPECT_EQ(220, d1.max()); + + d2.SetMin(20); + d2.SetMax(100); + EXPECT_TRUE(d1.SpanningUnion(d2)); + EXPECT_EQ(20, d1.min()); + EXPECT_EQ(220, d1.max()); + + d2.SetMin(50); + d2.SetMax(300); + EXPECT_TRUE(d1.SpanningUnion(d2)); + EXPECT_EQ(20, d1.min()); + EXPECT_EQ(300, d1.max()); + + d2.SetMin(0); + d2.SetMax(500); + EXPECT_TRUE(d1.SpanningUnion(d2)); + EXPECT_EQ(0, d1.min()); + EXPECT_EQ(500, d1.max()); + + d2.SetMin(100); + d2.SetMax(0); + EXPECT_TRUE(!d1.SpanningUnion(d2)); + EXPECT_EQ(0, d1.min()); + EXPECT_EQ(500, d1.max()); + EXPECT_TRUE(d2.SpanningUnion(d1)); + EXPECT_EQ(0, d2.min()); + EXPECT_EQ(500, d2.max()); +} + +TEST_F(QuicIntervalTest, CoveringOps) { + const QuicInterval empty; + const QuicInterval d(100, 200); + const QuicInterval d1(0, 50); + const QuicInterval d2(50, 110); + const QuicInterval d3(110, 180); + const QuicInterval d4(180, 220); + const QuicInterval d5(220, 300); + const QuicInterval d6(100, 150); + const QuicInterval d7(150, 200); + const QuicInterval d8(0, 300); + + // Intersection: + EXPECT_TRUE(d.Intersects(d)); + EXPECT_TRUE(!empty.Intersects(d) && !d.Intersects(empty)); + EXPECT_TRUE(!d.Intersects(d1) && !d1.Intersects(d)); + EXPECT_TRUE(d.Intersects(d2) && d2.Intersects(d)); + EXPECT_TRUE(d.Intersects(d3) && d3.Intersects(d)); + EXPECT_TRUE(d.Intersects(d4) && d4.Intersects(d)); + EXPECT_TRUE(!d.Intersects(d5) && !d5.Intersects(d)); + EXPECT_TRUE(d.Intersects(d6) && d6.Intersects(d)); + EXPECT_TRUE(d.Intersects(d7) && d7.Intersects(d)); + EXPECT_TRUE(d.Intersects(d8) && d8.Intersects(d)); + + QuicInterval i; + EXPECT_TRUE(d.Intersects(d, &i) && d == i); + EXPECT_TRUE(!empty.Intersects(d, nullptr) && !d.Intersects(empty, nullptr)); + EXPECT_TRUE(!d.Intersects(d1, nullptr) && !d1.Intersects(d, nullptr)); + EXPECT_TRUE(d.Intersects(d2, &i) && i == QuicInterval(100, 110)); + EXPECT_TRUE(d2.Intersects(d, &i) && i == QuicInterval(100, 110)); + EXPECT_TRUE(d.Intersects(d3, &i) && i == d3); + EXPECT_TRUE(d3.Intersects(d, &i) && i == d3); + EXPECT_TRUE(d.Intersects(d4, &i) && i == QuicInterval(180, 200)); + EXPECT_TRUE(d4.Intersects(d, &i) && i == QuicInterval(180, 200)); + EXPECT_TRUE(!d.Intersects(d5, nullptr) && !d5.Intersects(d, nullptr)); + EXPECT_TRUE(d.Intersects(d6, &i) && i == d6); + EXPECT_TRUE(d6.Intersects(d, &i) && i == d6); + EXPECT_TRUE(d.Intersects(d7, &i) && i == d7); + EXPECT_TRUE(d7.Intersects(d, &i) && i == d7); + EXPECT_TRUE(d.Intersects(d8, &i) && i == d); + EXPECT_TRUE(d8.Intersects(d, &i) && i == d); + + // Test IntersectsWith(). + // Arguments are TestIntersect(i1, i2, changes_i1, changes_i2, result). + TestIntersect(empty, d, false, true, empty); + TestIntersect(d, d1, true, true, empty); + TestIntersect(d1, d2, true, true, empty); + TestIntersect(d, d2, true, true, QuicInterval(100, 110)); + TestIntersect(d8, d, true, false, d); + TestIntersect(d8, d1, true, false, d1); + TestIntersect(d8, d5, true, false, d5); + + // Contains: + EXPECT_TRUE(!empty.Contains(d) && !d.Contains(empty)); + EXPECT_TRUE(d.Contains(d)); + EXPECT_TRUE(!d.Contains(d1) && !d1.Contains(d)); + EXPECT_TRUE(!d.Contains(d2) && !d2.Contains(d)); + EXPECT_TRUE(d.Contains(d3) && !d3.Contains(d)); + EXPECT_TRUE(!d.Contains(d4) && !d4.Contains(d)); + EXPECT_TRUE(!d.Contains(d5) && !d5.Contains(d)); + EXPECT_TRUE(d.Contains(d6) && !d6.Contains(d)); + EXPECT_TRUE(d.Contains(d7) && !d7.Contains(d)); + EXPECT_TRUE(!d.Contains(d8) && d8.Contains(d)); + + EXPECT_TRUE(d.Contains(100)); + EXPECT_TRUE(!d.Contains(200)); + EXPECT_TRUE(d.Contains(150)); + EXPECT_TRUE(!d.Contains(99)); + EXPECT_TRUE(!d.Contains(201)); + + // Difference: + std::vector*> diff; + + EXPECT_TRUE(!d.Difference(empty, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(100, diff[0]->min()); + EXPECT_EQ(200, diff[0]->max()); + STLDeleteElements(&diff); + EXPECT_TRUE(!empty.Difference(d, &diff) && diff.empty()); + + EXPECT_TRUE(d.Difference(d, &diff) && diff.empty()); + EXPECT_TRUE(!d.Difference(d1, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(100, diff[0]->min()); + EXPECT_EQ(200, diff[0]->max()); + STLDeleteElements(&diff); + + QuicInterval lo; + QuicInterval hi; + + EXPECT_TRUE(d.Difference(d2, &lo, &hi)); + EXPECT_TRUE(lo.Empty()); + EXPECT_EQ(110, hi.min()); + EXPECT_EQ(200, hi.max()); + EXPECT_TRUE(d.Difference(d2, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(110, diff[0]->min()); + EXPECT_EQ(200, diff[0]->max()); + STLDeleteElements(&diff); + + EXPECT_TRUE(d.Difference(d3, &lo, &hi)); + EXPECT_EQ(100, lo.min()); + EXPECT_EQ(110, lo.max()); + EXPECT_EQ(180, hi.min()); + EXPECT_EQ(200, hi.max()); + EXPECT_TRUE(d.Difference(d3, &diff)); + EXPECT_EQ(2u, diff.size()); + EXPECT_EQ(100, diff[0]->min()); + EXPECT_EQ(110, diff[0]->max()); + EXPECT_EQ(180, diff[1]->min()); + EXPECT_EQ(200, diff[1]->max()); + STLDeleteElements(&diff); + + EXPECT_TRUE(d.Difference(d4, &lo, &hi)); + EXPECT_EQ(100, lo.min()); + EXPECT_EQ(180, lo.max()); + EXPECT_TRUE(hi.Empty()); + EXPECT_TRUE(d.Difference(d4, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(100, diff[0]->min()); + EXPECT_EQ(180, diff[0]->max()); + STLDeleteElements(&diff); + + EXPECT_FALSE(d.Difference(d5, &lo, &hi)); + EXPECT_EQ(100, lo.min()); + EXPECT_EQ(200, lo.max()); + EXPECT_TRUE(hi.Empty()); + EXPECT_FALSE(d.Difference(d5, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(100, diff[0]->min()); + EXPECT_EQ(200, diff[0]->max()); + STLDeleteElements(&diff); + + EXPECT_TRUE(d.Difference(d6, &lo, &hi)); + EXPECT_TRUE(lo.Empty()); + EXPECT_EQ(150, hi.min()); + EXPECT_EQ(200, hi.max()); + EXPECT_TRUE(d.Difference(d6, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(150, diff[0]->min()); + EXPECT_EQ(200, diff[0]->max()); + STLDeleteElements(&diff); + + EXPECT_TRUE(d.Difference(d7, &lo, &hi)); + EXPECT_EQ(100, lo.min()); + EXPECT_EQ(150, lo.max()); + EXPECT_TRUE(hi.Empty()); + EXPECT_TRUE(d.Difference(d7, &diff)); + EXPECT_EQ(1u, diff.size()); + EXPECT_EQ(100, diff[0]->min()); + EXPECT_EQ(150, diff[0]->max()); + STLDeleteElements(&diff); + + EXPECT_TRUE(d.Difference(d8, &lo, &hi)); + EXPECT_TRUE(lo.Empty()); + EXPECT_TRUE(hi.Empty()); + EXPECT_TRUE(d.Difference(d8, &diff) && diff.empty()); +} + +TEST_F(QuicIntervalTest, Separated) { + using QI = QuicInterval; + EXPECT_FALSE(QI(100, 200).Separated(QI(100, 200))); + EXPECT_FALSE(QI(100, 200).Separated(QI(200, 300))); + EXPECT_TRUE(QI(100, 200).Separated(QI(201, 300))); + EXPECT_FALSE(QI(100, 200).Separated(QI(0, 100))); + EXPECT_TRUE(QI(100, 200).Separated(QI(0, 99))); + EXPECT_FALSE(QI(100, 200).Separated(QI(150, 170))); + EXPECT_FALSE(QI(150, 170).Separated(QI(100, 200))); + EXPECT_FALSE(QI(100, 200).Separated(QI(150, 250))); + EXPECT_FALSE(QI(150, 250).Separated(QI(100, 200))); +} + +TEST_F(QuicIntervalTest, Length) { + const QuicInterval empty1; + const QuicInterval empty2(1, 1); + const QuicInterval empty3(1, 0); + const QuicInterval empty4( + QuicTime::Zero() + QuicTime::Delta::FromSeconds(1), QuicTime::Zero()); + const QuicInterval d1(1, 2); + const QuicInterval d2(0, 50); + const QuicInterval d3( + QuicTime::Zero(), QuicTime::Zero() + QuicTime::Delta::FromSeconds(1)); + const QuicInterval d4( + QuicTime::Zero() + QuicTime::Delta::FromSeconds(3600), + QuicTime::Zero() + QuicTime::Delta::FromSeconds(5400)); + + EXPECT_EQ(0, empty1.Length()); + EXPECT_EQ(0, empty2.Length()); + EXPECT_EQ(0, empty3.Length()); + EXPECT_EQ(QuicTime::Delta::Zero(), empty4.Length()); + EXPECT_EQ(1, d1.Length()); + EXPECT_EQ(50, d2.Length()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(1), d3.Length()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(1800), d4.Length()); +} + +TEST_F(QuicIntervalTest, IntervalOfTypeWithNoOperatorMinus) { + // QuicInterval should work even if T does not support operator-(). We + // just can't call QuicInterval::Length() for such types. + const QuicInterval d1("a", "b"); + const QuicInterval> d2({1, 2}, {4, 3}); + EXPECT_EQ("a", d1.min()); + EXPECT_EQ("b", d1.max()); + EXPECT_EQ(std::make_pair(1, 2), d2.min()); + EXPECT_EQ(std::make_pair(4, 3), d2.max()); +} + +struct NoEquals { + NoEquals(int v) : value(v) {} // NOLINT + int value; + bool operator<(const NoEquals& other) const { return value < other.value; } +}; + +TEST_F(QuicIntervalTest, OrderedComparisonForTypeWithoutEquals) { + const QuicInterval d1(0, 4); + const QuicInterval d2(0, 3); + const QuicInterval d3(1, 4); + const QuicInterval d4(1, 5); + const QuicInterval d6(0, 4); + EXPECT_TRUE(d1 < d2); + EXPECT_TRUE(d1 < d3); + EXPECT_TRUE(d1 < d4); + EXPECT_FALSE(d1 < d6); +} + +TEST_F(QuicIntervalTest, OutputReturnsOstreamRef) { + std::stringstream ss; + const QuicInterval v(1, 2); + // If (ss << v) were to return a value, it wouldn't match the signature of + // return_type_is_a_ref() function. + auto return_type_is_a_ref = [](std::ostream&) {}; + return_type_is_a_ref(ss << v); +} + +struct NotOstreamable { + bool operator<(const NotOstreamable&) const { return false; } + bool operator>=(const NotOstreamable&) const { return true; } + bool operator==(const NotOstreamable&) const { return true; } +}; + +TEST_F(QuicIntervalTest, IntervalOfTypeWithNoOstreamSupport) { + const NotOstreamable v; + const QuicInterval d(v, v); + // EXPECT_EQ builds a string representation of d. If d::operator<<() would be + // defined then this test would not compile because NotOstreamable objects + // lack the operator<<() support. + EXPECT_EQ(d, d); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_linux_socket_utils.cc b/quiche/quic/core/quic_linux_socket_utils.cc new file mode 100644 index 000000000000..ba3541eee273 --- /dev/null +++ b/quiche/quic/core/quic_linux_socket_utils.cc @@ -0,0 +1,310 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_linux_socket_utils.h" + +#include +#include + +#include + +#include "quiche/quic/core/quic_syscall_wrapper.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +QuicMsgHdr::QuicMsgHdr(const char* buffer, size_t buf_len, + const QuicSocketAddress& peer_address, char* cbuf, + size_t cbuf_size) + : iov_{const_cast(buffer), buf_len}, + cbuf_(cbuf), + cbuf_size_(cbuf_size), + cmsg_(nullptr) { + // Only support unconnected sockets. + QUICHE_DCHECK(peer_address.IsInitialized()); + + raw_peer_address_ = peer_address.generic_address(); + hdr_.msg_name = &raw_peer_address_; + hdr_.msg_namelen = raw_peer_address_.ss_family == AF_INET + ? sizeof(sockaddr_in) + : sizeof(sockaddr_in6); + + hdr_.msg_iov = &iov_; + hdr_.msg_iovlen = 1; + hdr_.msg_flags = 0; + + hdr_.msg_control = nullptr; + hdr_.msg_controllen = 0; +} + +void QuicMsgHdr::SetIpInNextCmsg(const QuicIpAddress& self_address) { + if (!self_address.IsInitialized()) { + return; + } + + if (self_address.IsIPv4()) { + QuicLinuxSocketUtils::SetIpInfoInCmsgData( + self_address, GetNextCmsgData(IPPROTO_IP, IP_PKTINFO)); + } else { + QuicLinuxSocketUtils::SetIpInfoInCmsgData( + self_address, GetNextCmsgData(IPPROTO_IPV6, IPV6_PKTINFO)); + } +} + +void* QuicMsgHdr::GetNextCmsgDataInternal(int cmsg_level, int cmsg_type, + size_t data_size) { + // msg_controllen needs to be increased first, otherwise CMSG_NXTHDR will + // return nullptr. + hdr_.msg_controllen += CMSG_SPACE(data_size); + QUICHE_DCHECK_LE(hdr_.msg_controllen, cbuf_size_); + + if (cmsg_ == nullptr) { + QUICHE_DCHECK_EQ(nullptr, hdr_.msg_control); + memset(cbuf_, 0, cbuf_size_); + hdr_.msg_control = cbuf_; + cmsg_ = CMSG_FIRSTHDR(&hdr_); + } else { + QUICHE_DCHECK_NE(nullptr, hdr_.msg_control); + cmsg_ = CMSG_NXTHDR(&hdr_, cmsg_); + } + + QUICHE_DCHECK_NE(nullptr, cmsg_) << "Insufficient control buffer space"; + + cmsg_->cmsg_len = CMSG_LEN(data_size); + cmsg_->cmsg_level = cmsg_level; + cmsg_->cmsg_type = cmsg_type; + + return CMSG_DATA(cmsg_); +} + +void QuicMMsgHdr::InitOneHeader(int i, const BufferedWrite& buffered_write) { + mmsghdr* mhdr = GetMMsgHdr(i); + msghdr* hdr = &mhdr->msg_hdr; + iovec* iov = GetIov(i); + + iov->iov_base = const_cast(buffered_write.buffer); + iov->iov_len = buffered_write.buf_len; + hdr->msg_iov = iov; + hdr->msg_iovlen = 1; + hdr->msg_control = nullptr; + hdr->msg_controllen = 0; + + // Only support unconnected sockets. + QUICHE_DCHECK(buffered_write.peer_address.IsInitialized()); + + sockaddr_storage* peer_address_storage = GetPeerAddressStorage(i); + *peer_address_storage = buffered_write.peer_address.generic_address(); + hdr->msg_name = peer_address_storage; + hdr->msg_namelen = peer_address_storage->ss_family == AF_INET + ? sizeof(sockaddr_in) + : sizeof(sockaddr_in6); +} + +void QuicMMsgHdr::SetIpInNextCmsg(int i, const QuicIpAddress& self_address) { + if (!self_address.IsInitialized()) { + return; + } + + if (self_address.IsIPv4()) { + QuicLinuxSocketUtils::SetIpInfoInCmsgData( + self_address, GetNextCmsgData(i, IPPROTO_IP, IP_PKTINFO)); + } else { + QuicLinuxSocketUtils::SetIpInfoInCmsgData( + self_address, + GetNextCmsgData(i, IPPROTO_IPV6, IPV6_PKTINFO)); + } +} + +void* QuicMMsgHdr::GetNextCmsgDataInternal(int i, int cmsg_level, int cmsg_type, + size_t data_size) { + mmsghdr* mhdr = GetMMsgHdr(i); + msghdr* hdr = &mhdr->msg_hdr; + cmsghdr*& cmsg = *GetCmsgHdr(i); + + // msg_controllen needs to be increased first, otherwise CMSG_NXTHDR will + // return nullptr. + hdr->msg_controllen += CMSG_SPACE(data_size); + QUICHE_DCHECK_LE(hdr->msg_controllen, cbuf_size_); + + if (cmsg == nullptr) { + QUICHE_DCHECK_EQ(nullptr, hdr->msg_control); + hdr->msg_control = GetCbuf(i); + cmsg = CMSG_FIRSTHDR(hdr); + } else { + QUICHE_DCHECK_NE(nullptr, hdr->msg_control); + cmsg = CMSG_NXTHDR(hdr, cmsg); + } + + QUICHE_DCHECK_NE(nullptr, cmsg) << "Insufficient control buffer space"; + + cmsg->cmsg_len = CMSG_LEN(data_size); + cmsg->cmsg_level = cmsg_level; + cmsg->cmsg_type = cmsg_type; + + return CMSG_DATA(cmsg); +} + +int QuicMMsgHdr::num_bytes_sent(int num_packets_sent) { + QUICHE_DCHECK_LE(0, num_packets_sent); + QUICHE_DCHECK_LE(num_packets_sent, num_msgs_); + + int bytes_sent = 0; + iovec* iov = GetIov(0); + for (int i = 0; i < num_packets_sent; ++i) { + bytes_sent += iov[i].iov_len; + } + return bytes_sent; +} + +// static +int QuicLinuxSocketUtils::GetUDPSegmentSize(int fd) { + int optval; + socklen_t optlen = sizeof(optval); + int rc = getsockopt(fd, SOL_UDP, UDP_SEGMENT, &optval, &optlen); + if (rc < 0) { + QUIC_LOG_EVERY_N_SEC(INFO, 10) + << "getsockopt(UDP_SEGMENT) failed: " << strerror(errno); + return -1; + } + QUIC_LOG_EVERY_N_SEC(INFO, 10) + << "getsockopt(UDP_SEGMENT) returned segment size: " << optval; + return optval; +} + +// static +bool QuicLinuxSocketUtils::EnableReleaseTime(int fd, clockid_t clockid) { + // TODO(wub): Change to sock_txtime once it is available in linux/net_tstamp.h + struct LinuxSockTxTime { + clockid_t clockid; /* reference clockid */ + uint32_t flags; /* flags defined by enum txtime_flags */ + }; + + LinuxSockTxTime so_txtime_val{clockid, 0}; + + if (setsockopt(fd, SOL_SOCKET, SO_TXTIME, &so_txtime_val, + sizeof(so_txtime_val)) != 0) { + QUIC_LOG_EVERY_N_SEC(INFO, 10) + << "setsockopt(SOL_SOCKET,SO_TXTIME) failed: " << strerror(errno); + return false; + } + + return true; +} + +// static +bool QuicLinuxSocketUtils::GetTtlFromMsghdr(struct msghdr* hdr, int* ttl) { + if (hdr->msg_controllen > 0) { + struct cmsghdr* cmsg; + for (cmsg = CMSG_FIRSTHDR(hdr); cmsg != nullptr; + cmsg = CMSG_NXTHDR(hdr, cmsg)) { + if ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TTL) || + (cmsg->cmsg_level == IPPROTO_IPV6 && + cmsg->cmsg_type == IPV6_HOPLIMIT)) { + *ttl = *(reinterpret_cast(CMSG_DATA(cmsg))); + return true; + } + } + } + return false; +} + +// static +void QuicLinuxSocketUtils::SetIpInfoInCmsgData( + const QuicIpAddress& self_address, void* cmsg_data) { + QUICHE_DCHECK(self_address.IsInitialized()); + const std::string& address_str = self_address.ToPackedString(); + if (self_address.IsIPv4()) { + in_pktinfo* pktinfo = static_cast(cmsg_data); + pktinfo->ipi_ifindex = 0; + memcpy(&pktinfo->ipi_spec_dst, address_str.c_str(), address_str.length()); + } else if (self_address.IsIPv6()) { + in6_pktinfo* pktinfo = static_cast(cmsg_data); + memcpy(&pktinfo->ipi6_addr, address_str.c_str(), address_str.length()); + } else { + QUIC_BUG(quic_bug_10598_1) << "Unrecognized IPAddress"; + } +} + +// static +size_t QuicLinuxSocketUtils::SetIpInfoInCmsg(const QuicIpAddress& self_address, + cmsghdr* cmsg) { + std::string address_string; + if (self_address.IsIPv4()) { + cmsg->cmsg_len = CMSG_LEN(sizeof(in_pktinfo)); + cmsg->cmsg_level = IPPROTO_IP; + cmsg->cmsg_type = IP_PKTINFO; + in_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + memset(pktinfo, 0, sizeof(in_pktinfo)); + pktinfo->ipi_ifindex = 0; + address_string = self_address.ToPackedString(); + memcpy(&pktinfo->ipi_spec_dst, address_string.c_str(), + address_string.length()); + return sizeof(in_pktinfo); + } else if (self_address.IsIPv6()) { + cmsg->cmsg_len = CMSG_LEN(sizeof(in6_pktinfo)); + cmsg->cmsg_level = IPPROTO_IPV6; + cmsg->cmsg_type = IPV6_PKTINFO; + in6_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + memset(pktinfo, 0, sizeof(in6_pktinfo)); + address_string = self_address.ToPackedString(); + memcpy(&pktinfo->ipi6_addr, address_string.c_str(), + address_string.length()); + return sizeof(in6_pktinfo); + } else { + QUIC_BUG(quic_bug_10598_2) << "Unrecognized IPAddress"; + return 0; + } +} + +// static +WriteResult QuicLinuxSocketUtils::WritePacket(int fd, const QuicMsgHdr& hdr) { + int rc; + do { + rc = GetGlobalSyscallWrapper()->Sendmsg(fd, hdr.hdr(), 0); + } while (rc < 0 && errno == EINTR); + if (rc >= 0) { + return WriteResult(WRITE_STATUS_OK, rc); + } + return WriteResult((errno == EAGAIN || errno == EWOULDBLOCK) + ? WRITE_STATUS_BLOCKED + : WRITE_STATUS_ERROR, + errno); +} + +// static +WriteResult QuicLinuxSocketUtils::WriteMultiplePackets(int fd, + QuicMMsgHdr* mhdr, + int* num_packets_sent) { + *num_packets_sent = 0; + + if (mhdr->num_msgs() <= 0) { + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + + int rc; + do { + rc = GetGlobalSyscallWrapper()->Sendmmsg(fd, mhdr->mhdr(), mhdr->num_msgs(), + 0); + } while (rc < 0 && errno == EINTR); + + if (rc > 0) { + *num_packets_sent = rc; + + return WriteResult(WRITE_STATUS_OK, mhdr->num_bytes_sent(rc)); + } else if (rc == 0) { + QUIC_BUG(quic_bug_10598_3) + << "sendmmsg returned 0, returning WRITE_STATUS_ERROR. errno: " + << errno; + errno = EIO; + } + + return WriteResult((errno == EAGAIN || errno == EWOULDBLOCK) + ? WRITE_STATUS_BLOCKED + : WRITE_STATUS_ERROR, + errno); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_linux_socket_utils.h b/quiche/quic/core/quic_linux_socket_utils.h new file mode 100644 index 000000000000..de80dfd7ff35 --- /dev/null +++ b/quiche/quic/core/quic_linux_socket_utils.h @@ -0,0 +1,285 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_LINUX_SOCKET_UTILS_H_ +#define QUICHE_QUIC_CORE_QUIC_LINUX_SOCKET_UTILS_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +#ifndef SOL_UDP +#define SOL_UDP 17 +#endif + +#ifndef UDP_SEGMENT +#define UDP_SEGMENT 103 +#endif + +#ifndef UDP_MAX_SEGMENTS +#define UDP_MAX_SEGMENTS (1 << 6UL) +#endif + +#ifndef SO_TXTIME +#define SO_TXTIME 61 +#endif + +namespace quic { + +const int kCmsgSpaceForIpv4 = CMSG_SPACE(sizeof(in_pktinfo)); +const int kCmsgSpaceForIpv6 = CMSG_SPACE(sizeof(in6_pktinfo)); +// kCmsgSpaceForIp should be big enough to hold both IPv4 and IPv6 packet info. +const int kCmsgSpaceForIp = (kCmsgSpaceForIpv4 < kCmsgSpaceForIpv6) + ? kCmsgSpaceForIpv6 + : kCmsgSpaceForIpv4; + +const int kCmsgSpaceForSegmentSize = CMSG_SPACE(sizeof(uint16_t)); + +const int kCmsgSpaceForTxTime = CMSG_SPACE(sizeof(uint64_t)); + +const int kCmsgSpaceForTTL = CMSG_SPACE(sizeof(int)); + +// QuicMsgHdr is used to build msghdr objects that can be used send packets via +// ::sendmsg. +// +// Example: +// // cbuf holds control messages(cmsgs). The size is determined from what +// // cmsgs will be set for this msghdr. +// char cbuf[kCmsgSpaceForIp + kCmsgSpaceForSegmentSize]; +// QuicMsgHdr hdr(packet_buf, packet_buf_len, peer_addr, cbuf, sizeof(cbuf)); +// +// // Set IP in cmsgs. +// hdr.SetIpInNextCmsg(self_addr); +// +// // Set GSO size in cmsgs. +// *hdr.GetNextCmsgData(SOL_UDP, UDP_SEGMENT) = 1200; +// +// QuicLinuxSocketUtils::WritePacket(fd, hdr); +class QUIC_EXPORT_PRIVATE QuicMsgHdr { + public: + QuicMsgHdr(const char* buffer, size_t buf_len, + const QuicSocketAddress& peer_address, char* cbuf, + size_t cbuf_size); + + // Set IP info in the next cmsg. Both IPv4 and IPv6 are supported. + void SetIpInNextCmsg(const QuicIpAddress& self_address); + + template + DataType* GetNextCmsgData(int cmsg_level, int cmsg_type) { + return reinterpret_cast( + GetNextCmsgDataInternal(cmsg_level, cmsg_type, sizeof(DataType))); + } + + const msghdr* hdr() const { return &hdr_; } + + protected: + void* GetNextCmsgDataInternal(int cmsg_level, int cmsg_type, + size_t data_size); + + msghdr hdr_; + iovec iov_; + sockaddr_storage raw_peer_address_; + char* cbuf_; + const size_t cbuf_size_; + // The last cmsg populated so far. nullptr means nothing has been populated. + cmsghdr* cmsg_; +}; + +// BufferedWrite holds all information needed to send a packet. +struct QUIC_EXPORT_PRIVATE BufferedWrite { + BufferedWrite(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address) + : BufferedWrite(buffer, buf_len, self_address, peer_address, + std::unique_ptr(), + /*release_time=*/0) {} + + BufferedWrite(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + std::unique_ptr options, + uint64_t release_time) + : buffer(buffer), + buf_len(buf_len), + self_address(self_address), + peer_address(peer_address), + options(std::move(options)), + release_time(release_time) {} + + const char* buffer; // Not owned. + size_t buf_len; + QuicIpAddress self_address; + QuicSocketAddress peer_address; + std::unique_ptr options; + + // The release time according to the owning packet writer's clock, which is + // often not a QuicClock. Calculated from packet writer's Now() and the + // release time delay in |options|. + // 0 means it can be sent at the same time as the previous packet in a batch, + // or can be sent Now() if this is the first packet of a batch. + uint64_t release_time; +}; + +// QuicMMsgHdr is used to build mmsghdr objects that can be used to send +// multiple packets at once via ::sendmmsg. +// +// Example: +// quiche::QuicheCircularDeque buffered_writes; +// ... (Populate buffered_writes) ... +// +// QuicMMsgHdr mhdr( +// buffered_writes.begin(), buffered_writes.end(), kCmsgSpaceForIp, +// [](QuicMMsgHdr* mhdr, int i, const BufferedWrite& buffered_write) { +// mhdr->SetIpInNextCmsg(i, buffered_write.self_address); +// }); +// +// int num_packets_sent; +// QuicSocketUtils::WriteMultiplePackets(fd, &mhdr, &num_packets_sent); +class QUIC_EXPORT_PRIVATE QuicMMsgHdr { + public: + using ControlBufferInitializer = std::function; + template + QuicMMsgHdr(const IteratorT& first, const IteratorT& last, size_t cbuf_size, + ControlBufferInitializer cbuf_initializer) + : num_msgs_(std::distance(first, last)), cbuf_size_(cbuf_size) { + static_assert( + std::is_same::value_type, + BufferedWrite>::value, + "Must iterate over a collection of BufferedWrite."); + + QUICHE_DCHECK_LE(0, num_msgs_); + if (num_msgs_ == 0) { + return; + } + + storage_.reset(new char[StorageSize()]); + memset(&storage_[0], 0, StorageSize()); + + int i = -1; + for (auto it = first; it != last; ++it) { + ++i; + + InitOneHeader(i, *it); + if (cbuf_initializer) { + cbuf_initializer(this, i, *it); + } + } + } + + void SetIpInNextCmsg(int i, const QuicIpAddress& self_address); + + template + DataType* GetNextCmsgData(int i, int cmsg_level, int cmsg_type) { + return reinterpret_cast( + GetNextCmsgDataInternal(i, cmsg_level, cmsg_type, sizeof(DataType))); + } + + mmsghdr* mhdr() { return GetMMsgHdr(0); } + + int num_msgs() const { return num_msgs_; } + + // Get the total number of bytes in the first |num_packets_sent| packets. + int num_bytes_sent(int num_packets_sent); + + protected: + void InitOneHeader(int i, const BufferedWrite& buffered_write); + + void* GetNextCmsgDataInternal(int i, int cmsg_level, int cmsg_type, + size_t data_size); + + size_t StorageSize() const { + return num_msgs_ * + (sizeof(mmsghdr) + sizeof(iovec) + sizeof(sockaddr_storage) + + sizeof(cmsghdr*) + cbuf_size_); + } + + mmsghdr* GetMMsgHdr(int i) { + auto* first = reinterpret_cast(&storage_[0]); + return &first[i]; + } + + iovec* GetIov(int i) { + auto* first = reinterpret_cast(GetMMsgHdr(num_msgs_)); + return &first[i]; + } + + sockaddr_storage* GetPeerAddressStorage(int i) { + auto* first = reinterpret_cast(GetIov(num_msgs_)); + return &first[i]; + } + + cmsghdr** GetCmsgHdr(int i) { + auto* first = reinterpret_cast(GetPeerAddressStorage(num_msgs_)); + return &first[i]; + } + + char* GetCbuf(int i) { + auto* first = reinterpret_cast(GetCmsgHdr(num_msgs_)); + return &first[i * cbuf_size_]; + } + + const int num_msgs_; + // Size of cmsg buffer for each message. + const size_t cbuf_size_; + // storage_ holds the memory of + // |num_msgs_| mmsghdr + // |num_msgs_| iovec + // |num_msgs_| sockaddr_storage, for peer addresses + // |num_msgs_| cmsghdr* + // |num_msgs_| cbuf, each of size cbuf_size + std::unique_ptr storage_; +}; + +class QUIC_EXPORT_PRIVATE QuicLinuxSocketUtils { + public: + // Return the UDP segment size of |fd|, 0 means segment size has not been set + // on this socket. If GSO is not supported, return -1. + static int GetUDPSegmentSize(int fd); + + // Enable release time on |fd|. + static bool EnableReleaseTime(int fd, clockid_t clockid); + + // If the msghdr contains an IP_TTL entry, this will set ttl to the correct + // value and return true. Otherwise it will return false. + static bool GetTtlFromMsghdr(struct msghdr* hdr, int* ttl); + + // Set IP(self_address) in |cmsg_data|. Does not touch other fields in the + // containing cmsghdr. + static void SetIpInfoInCmsgData(const QuicIpAddress& self_address, + void* cmsg_data); + + // A helper for WritePacket which fills in the cmsg with the supplied self + // address. + // Returns the length of the packet info structure used. + static size_t SetIpInfoInCmsg(const QuicIpAddress& self_address, + cmsghdr* cmsg); + + // Writes the packet in |hdr| to the socket, using ::sendmsg. + static WriteResult WritePacket(int fd, const QuicMsgHdr& hdr); + + // Writes the packets in |mhdr| to the socket, using ::sendmmsg if available. + static WriteResult WriteMultiplePackets(int fd, QuicMMsgHdr* mhdr, + int* num_packets_sent); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_LINUX_SOCKET_UTILS_H_ diff --git a/quiche/quic/core/quic_linux_socket_utils_test.cc b/quiche/quic/core/quic_linux_socket_utils_test.cc new file mode 100644 index 000000000000..e9d1b1475046 --- /dev/null +++ b/quiche/quic/core/quic_linux_socket_utils_test.cc @@ -0,0 +1,324 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_linux_socket_utils.h" + +#include +#include + +#include +#include +#include +#include + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_mock_syscall_wrapper.h" +#include "quiche/common/quiche_circular_deque.h" + +using testing::_; +using testing::InSequence; +using testing::Invoke; + +namespace quic { +namespace test { +namespace { + +class QuicLinuxSocketUtilsTest : public QuicTest { + protected: + WriteResult TestWriteMultiplePackets( + int fd, + const quiche::QuicheCircularDeque::const_iterator& first, + const quiche::QuicheCircularDeque::const_iterator& last, + int* num_packets_sent) { + QuicMMsgHdr mhdr( + first, last, kCmsgSpaceForIp, + [](QuicMMsgHdr* mhdr, int i, const BufferedWrite& buffered_write) { + mhdr->SetIpInNextCmsg(i, buffered_write.self_address); + }); + + WriteResult res = + QuicLinuxSocketUtils::WriteMultiplePackets(fd, &mhdr, num_packets_sent); + return res; + } + + MockQuicSyscallWrapper mock_syscalls_; + ScopedGlobalSyscallWrapperOverride syscall_override_{&mock_syscalls_}; +}; + +void CheckIpAndTtlInCbuf(msghdr* hdr, const void* cbuf, + const QuicIpAddress& self_addr, int ttl) { + const bool is_ipv4 = self_addr.IsIPv4(); + const size_t ip_cmsg_space = is_ipv4 ? kCmsgSpaceForIpv4 : kCmsgSpaceForIpv6; + + EXPECT_EQ(cbuf, hdr->msg_control); + EXPECT_EQ(ip_cmsg_space + CMSG_SPACE(sizeof(uint16_t)), hdr->msg_controllen); + + cmsghdr* cmsg = CMSG_FIRSTHDR(hdr); + EXPECT_EQ(cmsg->cmsg_len, is_ipv4 ? CMSG_LEN(sizeof(in_pktinfo)) + : CMSG_LEN(sizeof(in6_pktinfo))); + EXPECT_EQ(cmsg->cmsg_level, is_ipv4 ? IPPROTO_IP : IPPROTO_IPV6); + EXPECT_EQ(cmsg->cmsg_type, is_ipv4 ? IP_PKTINFO : IPV6_PKTINFO); + + const std::string& self_addr_str = self_addr.ToPackedString(); + if (is_ipv4) { + in_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + EXPECT_EQ(0, memcmp(&pktinfo->ipi_spec_dst, self_addr_str.c_str(), + self_addr_str.length())); + } else { + in6_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + EXPECT_EQ(0, memcmp(&pktinfo->ipi6_addr, self_addr_str.c_str(), + self_addr_str.length())); + } + + cmsg = CMSG_NXTHDR(hdr, cmsg); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, is_ipv4 ? IPPROTO_IP : IPPROTO_IPV6); + EXPECT_EQ(cmsg->cmsg_type, is_ipv4 ? IP_TTL : IPV6_HOPLIMIT); + EXPECT_EQ(ttl, *reinterpret_cast(CMSG_DATA(cmsg))); + + EXPECT_EQ(nullptr, CMSG_NXTHDR(hdr, cmsg)); +} + +void CheckMsghdrWithoutCbuf(const msghdr* hdr, const void* buffer, + size_t buf_len, + const QuicSocketAddress& peer_addr) { + EXPECT_EQ( + peer_addr.host().IsIPv4() ? sizeof(sockaddr_in) : sizeof(sockaddr_in6), + hdr->msg_namelen); + sockaddr_storage peer_generic_addr = peer_addr.generic_address(); + EXPECT_EQ(0, memcmp(hdr->msg_name, &peer_generic_addr, hdr->msg_namelen)); + EXPECT_EQ(1u, hdr->msg_iovlen); + EXPECT_EQ(buffer, hdr->msg_iov->iov_base); + EXPECT_EQ(buf_len, hdr->msg_iov->iov_len); + EXPECT_EQ(0, hdr->msg_flags); + EXPECT_EQ(nullptr, hdr->msg_control); + EXPECT_EQ(0u, hdr->msg_controllen); +} + +void CheckIpAndGsoSizeInCbuf(msghdr* hdr, const void* cbuf, + const QuicIpAddress& self_addr, + uint16_t gso_size) { + const bool is_ipv4 = self_addr.IsIPv4(); + const size_t ip_cmsg_space = is_ipv4 ? kCmsgSpaceForIpv4 : kCmsgSpaceForIpv6; + + EXPECT_EQ(cbuf, hdr->msg_control); + EXPECT_EQ(ip_cmsg_space + CMSG_SPACE(sizeof(uint16_t)), hdr->msg_controllen); + + cmsghdr* cmsg = CMSG_FIRSTHDR(hdr); + EXPECT_EQ(cmsg->cmsg_len, is_ipv4 ? CMSG_LEN(sizeof(in_pktinfo)) + : CMSG_LEN(sizeof(in6_pktinfo))); + EXPECT_EQ(cmsg->cmsg_level, is_ipv4 ? IPPROTO_IP : IPPROTO_IPV6); + EXPECT_EQ(cmsg->cmsg_type, is_ipv4 ? IP_PKTINFO : IPV6_PKTINFO); + + const std::string& self_addr_str = self_addr.ToPackedString(); + if (is_ipv4) { + in_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + EXPECT_EQ(0, memcmp(&pktinfo->ipi_spec_dst, self_addr_str.c_str(), + self_addr_str.length())); + } else { + in6_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + EXPECT_EQ(0, memcmp(&pktinfo->ipi6_addr, self_addr_str.c_str(), + self_addr_str.length())); + } + + cmsg = CMSG_NXTHDR(hdr, cmsg); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(uint16_t))); + EXPECT_EQ(cmsg->cmsg_level, SOL_UDP); + EXPECT_EQ(cmsg->cmsg_type, UDP_SEGMENT); + EXPECT_EQ(gso_size, *reinterpret_cast(CMSG_DATA(cmsg))); + + EXPECT_EQ(nullptr, CMSG_NXTHDR(hdr, cmsg)); +} + +TEST_F(QuicLinuxSocketUtilsTest, QuicMsgHdr) { + QuicSocketAddress peer_addr(QuicIpAddress::Loopback4(), 1234); + char packet_buf[1024]; + + QuicMsgHdr quic_hdr(packet_buf, sizeof(packet_buf), peer_addr, nullptr, 0); + CheckMsghdrWithoutCbuf(quic_hdr.hdr(), packet_buf, sizeof(packet_buf), + peer_addr); + + for (bool is_ipv4 : {true, false}) { + QuicIpAddress self_addr = + is_ipv4 ? QuicIpAddress::Loopback4() : QuicIpAddress::Loopback6(); + char cbuf[kCmsgSpaceForIp + kCmsgSpaceForTTL]; + QuicMsgHdr quic_hdr(packet_buf, sizeof(packet_buf), peer_addr, cbuf, + sizeof(cbuf)); + msghdr* hdr = const_cast(quic_hdr.hdr()); + + EXPECT_EQ(nullptr, hdr->msg_control); + EXPECT_EQ(0u, hdr->msg_controllen); + + quic_hdr.SetIpInNextCmsg(self_addr); + EXPECT_EQ(cbuf, hdr->msg_control); + const size_t ip_cmsg_space = + is_ipv4 ? kCmsgSpaceForIpv4 : kCmsgSpaceForIpv6; + EXPECT_EQ(ip_cmsg_space, hdr->msg_controllen); + + if (is_ipv4) { + *quic_hdr.GetNextCmsgData(IPPROTO_IP, IP_TTL) = 32; + } else { + *quic_hdr.GetNextCmsgData(IPPROTO_IPV6, IPV6_HOPLIMIT) = 32; + } + + CheckIpAndTtlInCbuf(hdr, cbuf, self_addr, 32); + } +} + +TEST_F(QuicLinuxSocketUtilsTest, QuicMMsgHdr) { + quiche::QuicheCircularDeque buffered_writes; + char packet_buf1[1024]; + char packet_buf2[512]; + buffered_writes.emplace_back( + packet_buf1, sizeof(packet_buf1), QuicIpAddress::Loopback4(), + QuicSocketAddress(QuicIpAddress::Loopback4(), 4)); + buffered_writes.emplace_back( + packet_buf2, sizeof(packet_buf2), QuicIpAddress::Loopback6(), + QuicSocketAddress(QuicIpAddress::Loopback6(), 6)); + + QuicMMsgHdr quic_mhdr_without_cbuf(buffered_writes.begin(), + buffered_writes.end(), 0, nullptr); + for (size_t i = 0; i < buffered_writes.size(); ++i) { + const BufferedWrite& bw = buffered_writes[i]; + CheckMsghdrWithoutCbuf(&quic_mhdr_without_cbuf.mhdr()[i].msg_hdr, bw.buffer, + bw.buf_len, bw.peer_address); + } + + QuicMMsgHdr quic_mhdr_with_cbuf( + buffered_writes.begin(), buffered_writes.end(), + kCmsgSpaceForIp + kCmsgSpaceForSegmentSize, + [](QuicMMsgHdr* mhdr, int i, const BufferedWrite& buffered_write) { + mhdr->SetIpInNextCmsg(i, buffered_write.self_address); + *mhdr->GetNextCmsgData(i, SOL_UDP, UDP_SEGMENT) = 1300; + }); + for (size_t i = 0; i < buffered_writes.size(); ++i) { + const BufferedWrite& bw = buffered_writes[i]; + msghdr* hdr = &quic_mhdr_with_cbuf.mhdr()[i].msg_hdr; + CheckIpAndGsoSizeInCbuf(hdr, hdr->msg_control, bw.self_address, 1300); + } +} + +TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_NoPacketsToSend) { + int num_packets_sent; + quiche::QuicheCircularDeque buffered_writes; + + EXPECT_CALL(mock_syscalls_, Sendmmsg(_, _, _, _)).Times(0); + + EXPECT_EQ(WriteResult(WRITE_STATUS_ERROR, EINVAL), + TestWriteMultiplePackets(1, buffered_writes.begin(), + buffered_writes.end(), &num_packets_sent)); +} + +TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteBlocked) { + int num_packets_sent; + quiche::QuicheCircularDeque buffered_writes; + buffered_writes.emplace_back(nullptr, 0, QuicIpAddress(), + QuicSocketAddress(QuicIpAddress::Any4(), 0)); + + EXPECT_CALL(mock_syscalls_, Sendmmsg(_, _, _, _)) + .WillOnce(Invoke([](int /*fd*/, mmsghdr* /*msgvec*/, + unsigned int /*vlen*/, int /*flags*/) { + errno = EWOULDBLOCK; + return -1; + })); + + EXPECT_EQ(WriteResult(WRITE_STATUS_BLOCKED, EWOULDBLOCK), + TestWriteMultiplePackets(1, buffered_writes.begin(), + buffered_writes.end(), &num_packets_sent)); + EXPECT_EQ(0, num_packets_sent); +} + +TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteError) { + int num_packets_sent; + quiche::QuicheCircularDeque buffered_writes; + buffered_writes.emplace_back(nullptr, 0, QuicIpAddress(), + QuicSocketAddress(QuicIpAddress::Any4(), 0)); + + EXPECT_CALL(mock_syscalls_, Sendmmsg(_, _, _, _)) + .WillOnce(Invoke([](int /*fd*/, mmsghdr* /*msgvec*/, + unsigned int /*vlen*/, int /*flags*/) { + errno = EPERM; + return -1; + })); + + EXPECT_EQ(WriteResult(WRITE_STATUS_ERROR, EPERM), + TestWriteMultiplePackets(1, buffered_writes.begin(), + buffered_writes.end(), &num_packets_sent)); + EXPECT_EQ(0, num_packets_sent); +} + +TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteSuccess) { + int num_packets_sent; + quiche::QuicheCircularDeque buffered_writes; + const int kNumBufferedWrites = 10; + static_assert(kNumBufferedWrites < 256, "Must be less than 256"); + std::vector buffer_holder; + for (int i = 0; i < kNumBufferedWrites; ++i) { + size_t buf_len = (i + 1) * 2; + std::ostringstream buffer_ostream; + while (buffer_ostream.str().length() < buf_len) { + buffer_ostream << i; + } + buffer_holder.push_back(buffer_ostream.str().substr(0, buf_len - 1) + '$'); + + buffered_writes.emplace_back(buffer_holder.back().data(), buf_len, + QuicIpAddress(), + QuicSocketAddress(QuicIpAddress::Any4(), 0)); + + // Leave the first self_address uninitialized. + if (i != 0) { + ASSERT_TRUE(buffered_writes.back().self_address.FromString("127.0.0.1")); + } + + std::ostringstream peer_ip_ostream; + QuicIpAddress peer_ip_address; + peer_ip_ostream << "127.0.1." << i + 1; + ASSERT_TRUE(peer_ip_address.FromString(peer_ip_ostream.str())); + buffered_writes.back().peer_address = + QuicSocketAddress(peer_ip_address, i + 1); + } + + InSequence s; + + for (int expected_num_packets_sent : {1, 2, 3, 10}) { + SCOPED_TRACE(testing::Message() + << "expected_num_packets_sent=" << expected_num_packets_sent); + EXPECT_CALL(mock_syscalls_, Sendmmsg(_, _, _, _)) + .WillOnce(Invoke([&](int /*fd*/, mmsghdr* msgvec, unsigned int vlen, + int /*flags*/) { + EXPECT_LE(static_cast(expected_num_packets_sent), vlen); + for (unsigned int i = 0; i < vlen; ++i) { + const BufferedWrite& buffered_write = buffered_writes[i]; + const msghdr& hdr = msgvec[i].msg_hdr; + EXPECT_EQ(1u, hdr.msg_iovlen); + EXPECT_EQ(buffered_write.buffer, hdr.msg_iov->iov_base); + EXPECT_EQ(buffered_write.buf_len, hdr.msg_iov->iov_len); + sockaddr_storage expected_peer_address = + buffered_write.peer_address.generic_address(); + EXPECT_EQ(0, memcmp(&expected_peer_address, hdr.msg_name, + sizeof(sockaddr_storage))); + EXPECT_EQ(buffered_write.self_address.IsInitialized(), + hdr.msg_control != nullptr); + } + return expected_num_packets_sent; + })) + .RetiresOnSaturation(); + + int expected_bytes_written = 0; + for (auto it = buffered_writes.cbegin(); + it != buffered_writes.cbegin() + expected_num_packets_sent; ++it) { + expected_bytes_written += it->buf_len; + } + + EXPECT_EQ( + WriteResult(WRITE_STATUS_OK, expected_bytes_written), + TestWriteMultiplePackets(1, buffered_writes.cbegin(), + buffered_writes.cend(), &num_packets_sent)); + EXPECT_EQ(expected_num_packets_sent, num_packets_sent); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_lru_cache.h b/quiche/quic/core/quic_lru_cache.h new file mode 100644 index 000000000000..d1c010e68088 --- /dev/null +++ b/quiche/quic/core/quic_lru_cache.h @@ -0,0 +1,98 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_LRU_CACHE_H_ +#define QUICHE_QUIC_CORE_QUIC_LRU_CACHE_H_ + +#include + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +// A LRU cache that maps from type Key to Value* in QUIC. +// This cache CANNOT be shared by multiple threads (even with locks) because +// Value* returned by Lookup() can be invalid if the entry is evicted by other +// threads. +template , + class Eq = std::equal_to> +class QUIC_NO_EXPORT QuicLRUCache { + private: + using HashMapType = + typename quiche::QuicheLinkedHashMap, Hash, Eq>; + + public: + // The iterator, if valid, points to std::pair>. + using iterator = typename HashMapType::iterator; + using const_iterator = typename HashMapType::const_iterator; + using reverse_iterator = typename HashMapType::reverse_iterator; + using const_reverse_iterator = typename HashMapType::const_reverse_iterator; + + explicit QuicLRUCache(size_t capacity) : capacity_(capacity) {} + QuicLRUCache(const QuicLRUCache&) = delete; + QuicLRUCache& operator=(const QuicLRUCache&) = delete; + + iterator begin() { return cache_.begin(); } + const_iterator begin() const { return cache_.begin(); } + + iterator end() { return cache_.end(); } + const_iterator end() const { return cache_.end(); } + + reverse_iterator rbegin() { return cache_.rbegin(); } + const_reverse_iterator rbegin() const { return cache_.rbegin(); } + + reverse_iterator rend() { return cache_.rend(); } + const_reverse_iterator rend() const { return cache_.rend(); } + + // Inserts one unit of |key|, |value| pair to the cache. Cache takes ownership + // of inserted |value|. + void Insert(const K& key, std::unique_ptr value) { + auto it = cache_.find(key); + if (it != cache_.end()) { + cache_.erase(it); + } + cache_.emplace(key, std::move(value)); + + if (cache_.size() > capacity_) { + cache_.pop_front(); + } + QUICHE_DCHECK_LE(cache_.size(), capacity_); + } + + iterator Lookup(const K& key) { + auto iter = cache_.find(key); + if (iter == cache_.end()) { + return iter; + } + + std::unique_ptr value = std::move(iter->second); + cache_.erase(iter); + auto result = cache_.emplace(key, std::move(value)); + QUICHE_DCHECK(result.second); + return result.first; + } + + iterator Erase(iterator iter) { return cache_.erase(iter); } + + // Removes all entries from the cache. + void Clear() { cache_.clear(); } + + // Returns maximum size of the cache. + size_t MaxSize() const { return capacity_; } + + // Returns current size of the cache. + size_t Size() const { return cache_.size(); } + + private: + quiche::QuicheLinkedHashMap, Hash, Eq> cache_; + const size_t capacity_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_LRU_CACHE_H_ diff --git a/quiche/quic/core/quic_lru_cache_test.cc b/quiche/quic/core/quic_lru_cache_test.cc new file mode 100644 index 000000000000..91a7913b84cc --- /dev/null +++ b/quiche/quic/core/quic_lru_cache_test.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_lru_cache.h" + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +struct CachedItem { + explicit CachedItem(uint32_t new_value) : value(new_value) {} + + uint32_t value; +}; + +TEST(QuicLRUCacheTest, InsertAndLookup) { + QuicLRUCache cache(5); + EXPECT_EQ(cache.end(), cache.Lookup(1)); + EXPECT_EQ(0u, cache.Size()); + EXPECT_EQ(5u, cache.MaxSize()); + + // Check that item 1 was properly inserted. + std::unique_ptr item1(new CachedItem(11)); + cache.Insert(1, std::move(item1)); + EXPECT_EQ(1u, cache.Size()); + EXPECT_EQ(11u, cache.Lookup(1)->second->value); + + // Check that item 2 overrides item 1. + std::unique_ptr item2(new CachedItem(12)); + cache.Insert(1, std::move(item2)); + EXPECT_EQ(1u, cache.Size()); + EXPECT_EQ(12u, cache.Lookup(1)->second->value); + + std::unique_ptr item3(new CachedItem(13)); + cache.Insert(3, std::move(item3)); + EXPECT_EQ(2u, cache.Size()); + auto iter = cache.Lookup(3); + ASSERT_NE(cache.end(), iter); + EXPECT_EQ(13u, iter->second->value); + cache.Erase(iter); + ASSERT_EQ(cache.end(), cache.Lookup(3)); + EXPECT_EQ(1u, cache.Size()); + + // No memory leakage. + cache.Clear(); + EXPECT_EQ(0u, cache.Size()); +} + +TEST(QuicLRUCacheTest, Eviction) { + QuicLRUCache cache(3); + + for (size_t i = 1; i <= 4; ++i) { + std::unique_ptr item(new CachedItem(10 + i)); + cache.Insert(i, std::move(item)); + } + + EXPECT_EQ(3u, cache.Size()); + EXPECT_EQ(3u, cache.MaxSize()); + + // Make sure item 1 is evicted. + EXPECT_EQ(cache.end(), cache.Lookup(1)); + EXPECT_EQ(14u, cache.Lookup(4)->second->value); + + EXPECT_EQ(12u, cache.Lookup(2)->second->value); + std::unique_ptr item5(new CachedItem(15)); + cache.Insert(5, std::move(item5)); + // Make sure item 3 is evicted. + EXPECT_EQ(cache.end(), cache.Lookup(3)); + EXPECT_EQ(15u, cache.Lookup(5)->second->value); + + // No memory leakage. + cache.Clear(); + EXPECT_EQ(0u, cache.Size()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_mtu_discovery.cc b/quiche/quic/core/quic_mtu_discovery.cc new file mode 100644 index 000000000000..373239344fd7 --- /dev/null +++ b/quiche/quic/core/quic_mtu_discovery.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_mtu_discovery.h" + +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" + +namespace quic { + +QuicConnectionMtuDiscoverer::QuicConnectionMtuDiscoverer( + QuicPacketCount packets_between_probes_base, QuicPacketNumber next_probe_at) + : packets_between_probes_(packets_between_probes_base), + next_probe_at_(next_probe_at) {} + +void QuicConnectionMtuDiscoverer::Enable( + QuicByteCount max_packet_length, QuicByteCount target_max_packet_length) { + QUICHE_DCHECK(!IsEnabled()); + + if (target_max_packet_length <= max_packet_length) { + QUIC_DVLOG(1) << "MtuDiscoverer not enabled. target_max_packet_length:" + << target_max_packet_length + << " <= max_packet_length:" << max_packet_length; + return; + } + + min_probe_length_ = max_packet_length; + max_probe_length_ = target_max_packet_length; + QUICHE_DCHECK(IsEnabled()); + + QUIC_DVLOG(1) << "MtuDiscoverer enabled. min:" << min_probe_length_ + << ", max:" << max_probe_length_ + << ", next:" << next_probe_packet_length(); +} + +void QuicConnectionMtuDiscoverer::Disable() { + *this = QuicConnectionMtuDiscoverer(packets_between_probes_, next_probe_at_); +} + +bool QuicConnectionMtuDiscoverer::IsEnabled() const { + return min_probe_length_ < max_probe_length_; +} + +bool QuicConnectionMtuDiscoverer::ShouldProbeMtu( + QuicPacketNumber largest_sent_packet) const { + if (!IsEnabled()) { + return false; + } + + if (remaining_probe_count_ == 0) { + QUIC_DVLOG(1) + << "ShouldProbeMtu returns false because max probe count reached"; + return false; + } + + if (largest_sent_packet < next_probe_at_) { + QUIC_DVLOG(1) << "ShouldProbeMtu returns false because not enough packets " + "sent since last probe. largest_sent_packet:" + << largest_sent_packet + << ", next_probe_at_:" << next_probe_at_; + return false; + } + + QUIC_DVLOG(1) << "ShouldProbeMtu returns true. largest_sent_packet:" + << largest_sent_packet; + return true; +} + +QuicPacketLength QuicConnectionMtuDiscoverer::GetUpdatedMtuProbeSize( + QuicPacketNumber largest_sent_packet) { + QUICHE_DCHECK(ShouldProbeMtu(largest_sent_packet)); + + QuicPacketLength probe_packet_length = next_probe_packet_length(); + if (probe_packet_length == last_probe_length_) { + // The next probe packet is as big as the previous one. Assuming the + // previous one exceeded MTU, we need to decrease the probe packet length. + max_probe_length_ = probe_packet_length; + } else { + QUICHE_DCHECK_GT(probe_packet_length, last_probe_length_); + } + last_probe_length_ = next_probe_packet_length(); + + packets_between_probes_ *= 2; + next_probe_at_ = largest_sent_packet + packets_between_probes_ + 1; + if (remaining_probe_count_ > 0) { + --remaining_probe_count_; + } + + QUIC_DVLOG(1) << "GetUpdatedMtuProbeSize: probe_packet_length:" + << last_probe_length_ + << ", New packets_between_probes_:" << packets_between_probes_ + << ", next_probe_at_:" << next_probe_at_ + << ", remaining_probe_count_:" << remaining_probe_count_; + QUICHE_DCHECK(!ShouldProbeMtu(largest_sent_packet)); + return last_probe_length_; +} + +QuicPacketLength QuicConnectionMtuDiscoverer::next_probe_packet_length() const { + QUICHE_DCHECK_NE(min_probe_length_, 0); + QUICHE_DCHECK_NE(max_probe_length_, 0); + QUICHE_DCHECK_GE(max_probe_length_, min_probe_length_); + + const QuicPacketLength normal_next_probe_length = + (min_probe_length_ + max_probe_length_ + 1) / 2; + + if (remaining_probe_count_ == 1 && + normal_next_probe_length > last_probe_length_) { + // If the previous probe succeeded, and there is only one last probe to + // send, use |max_probe_length_| for the last probe. + return max_probe_length_; + } + return normal_next_probe_length; +} + +void QuicConnectionMtuDiscoverer::OnMaxPacketLengthUpdated( + QuicByteCount old_value, QuicByteCount new_value) { + if (!IsEnabled() || new_value <= old_value) { + return; + } + + QUICHE_DCHECK_EQ(old_value, min_probe_length_); + min_probe_length_ = new_value; +} + +std::ostream& operator<<(std::ostream& os, + const QuicConnectionMtuDiscoverer& d) { + os << "{ min_probe_length_:" << d.min_probe_length_ + << " max_probe_length_:" << d.max_probe_length_ + << " last_probe_length_:" << d.last_probe_length_ + << " remaining_probe_count_:" << d.remaining_probe_count_ + << " packets_between_probes_:" << d.packets_between_probes_ + << " next_probe_at_:" << d.next_probe_at_ << " }"; + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_mtu_discovery.h b/quiche/quic/core/quic_mtu_discovery.h new file mode 100644 index 000000000000..c44894f260ae --- /dev/null +++ b/quiche/quic/core/quic_mtu_discovery.h @@ -0,0 +1,116 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_MTU_DISCOVERY_H_ +#define QUICHE_QUIC_CORE_QUIC_MTU_DISCOVERY_H_ + +#include + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +// The initial number of packets between MTU probes. After each attempt the +// number is doubled. +const QuicPacketCount kPacketsBetweenMtuProbesBase = 100; + +// The number of MTU probes that get sent before giving up. +const size_t kMtuDiscoveryAttempts = 3; + +// Ensure that exponential back-off does not result in an integer overflow. +// The number of packets can be potentially capped, but that is not useful at +// current kMtuDiscoveryAttempts value, and hence is not implemented at present. +static_assert(kMtuDiscoveryAttempts + 8 < 8 * sizeof(QuicPacketNumber), + "The number of MTU discovery attempts is too high"); +static_assert(kPacketsBetweenMtuProbesBase < (1 << 8), + "The initial number of packets between MTU probes is too high"); + +// The increased packet size targeted when doing path MTU discovery. +const QuicByteCount kMtuDiscoveryTargetPacketSizeHigh = 1400; +const QuicByteCount kMtuDiscoveryTargetPacketSizeLow = 1380; + +static_assert(kMtuDiscoveryTargetPacketSizeLow <= kMaxOutgoingPacketSize, + "MTU discovery target is too large"); +static_assert(kMtuDiscoveryTargetPacketSizeHigh <= kMaxOutgoingPacketSize, + "MTU discovery target is too large"); + +static_assert(kMtuDiscoveryTargetPacketSizeLow > kDefaultMaxPacketSize, + "MTU discovery target does not exceed the default packet size"); +static_assert(kMtuDiscoveryTargetPacketSizeHigh > kDefaultMaxPacketSize, + "MTU discovery target does not exceed the default packet size"); + +// QuicConnectionMtuDiscoverer is a MTU discovery controller, it answers two +// questions: +// 1) Probe scheduling: Whether a connection should send a MTU probe packet +// right now. +// 2) MTU search stradegy: When it is time to send, what should be the size of +// the probing packet. +// Note the discoverer does not actually send or process probing packets. +// +// Unit tests are in QuicConnectionTest.MtuDiscovery*. +class QUIC_EXPORT_PRIVATE QuicConnectionMtuDiscoverer { + public: + // Construct a discoverer in the disabled state. + QuicConnectionMtuDiscoverer() = default; + + // Construct a discoverer in the disabled state, with the given parameters. + QuicConnectionMtuDiscoverer(QuicPacketCount packets_between_probes_base, + QuicPacketNumber next_probe_at); + + // Enable the discoverer by setting the probe target. + // max_packet_length: The max packet length currently used. + // target_max_packet_length: The target max packet length to probe. + void Enable(QuicByteCount max_packet_length, + QuicByteCount target_max_packet_length); + + // Disable the discoverer by unsetting the probe target. + void Disable(); + + // Whether a MTU probe packet should be sent right now. + // Always return false if disabled. + bool ShouldProbeMtu(QuicPacketNumber largest_sent_packet) const; + + // Called immediately before a probing packet is sent, to get the size of the + // packet. + // REQUIRES: ShouldProbeMtu(largest_sent_packet) == true. + QuicPacketLength GetUpdatedMtuProbeSize(QuicPacketNumber largest_sent_packet); + + // Called after the max packet length is updated, which is triggered by a ack + // of a probing packet. + void OnMaxPacketLengthUpdated(QuicByteCount old_value, + QuicByteCount new_value); + + QuicPacketCount packets_between_probes() const { + return packets_between_probes_; + } + + QuicPacketNumber next_probe_at() const { return next_probe_at_; } + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicConnectionMtuDiscoverer& d); + + private: + bool IsEnabled() const; + QuicPacketLength next_probe_packet_length() const; + + QuicPacketLength min_probe_length_ = 0; + QuicPacketLength max_probe_length_ = 0; + + QuicPacketLength last_probe_length_ = 0; + + uint16_t remaining_probe_count_ = kMtuDiscoveryAttempts; + + // The number of packets between MTU probes. + QuicPacketCount packets_between_probes_ = kPacketsBetweenMtuProbesBase; + + // The packet number of the packet after which the next MTU probe will be + // sent. + QuicPacketNumber next_probe_at_ = + QuicPacketNumber(kPacketsBetweenMtuProbesBase); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_MTU_DISCOVERY_H_ diff --git a/quiche/quic/core/quic_network_blackhole_detector.cc b/quiche/quic/core/quic_network_blackhole_detector.cc new file mode 100644 index 000000000000..4ded8bddc743 --- /dev/null +++ b/quiche/quic/core/quic_network_blackhole_detector.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_network_blackhole_detector.h" + +#include "quiche/quic/core/quic_constants.h" + +namespace quic { + +namespace { + +class AlarmDelegate : public QuicAlarm::DelegateWithContext { + public: + explicit AlarmDelegate(QuicNetworkBlackholeDetector* detector, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), detector_(detector) {} + AlarmDelegate(const AlarmDelegate&) = delete; + AlarmDelegate& operator=(const AlarmDelegate&) = delete; + + void OnAlarm() override { detector_->OnAlarm(); } + + private: + QuicNetworkBlackholeDetector* detector_; +}; + +} // namespace + +QuicNetworkBlackholeDetector::QuicNetworkBlackholeDetector( + Delegate* delegate, QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, QuicConnectionContext* context) + : delegate_(delegate), + alarm_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)) {} + +void QuicNetworkBlackholeDetector::OnAlarm() { + QuicTime next_deadline = GetEarliestDeadline(); + if (!next_deadline.IsInitialized()) { + QUIC_BUG(quic_bug_10328_1) << "BlackholeDetector alarm fired unexpectedly"; + return; + } + + QUIC_DVLOG(1) << "BlackholeDetector alarm firing. next_deadline:" + << next_deadline + << ", path_degrading_deadline_:" << path_degrading_deadline_ + << ", path_mtu_reduction_deadline_:" + << path_mtu_reduction_deadline_ + << ", blackhole_deadline_:" << blackhole_deadline_; + if (path_degrading_deadline_ == next_deadline) { + path_degrading_deadline_ = QuicTime::Zero(); + delegate_->OnPathDegradingDetected(); + } + + if (path_mtu_reduction_deadline_ == next_deadline) { + path_mtu_reduction_deadline_ = QuicTime::Zero(); + delegate_->OnPathMtuReductionDetected(); + } + + if (blackhole_deadline_ == next_deadline) { + blackhole_deadline_ = QuicTime::Zero(); + delegate_->OnBlackholeDetected(); + } + + UpdateAlarm(); +} + +void QuicNetworkBlackholeDetector::StopDetection(bool permanent) { + if (permanent) { + alarm_->PermanentCancel(); + } else { + alarm_->Cancel(); + } + path_degrading_deadline_ = QuicTime::Zero(); + blackhole_deadline_ = QuicTime::Zero(); + path_mtu_reduction_deadline_ = QuicTime::Zero(); +} + +void QuicNetworkBlackholeDetector::RestartDetection( + QuicTime path_degrading_deadline, QuicTime blackhole_deadline, + QuicTime path_mtu_reduction_deadline) { + path_degrading_deadline_ = path_degrading_deadline; + blackhole_deadline_ = blackhole_deadline; + path_mtu_reduction_deadline_ = path_mtu_reduction_deadline; + + QUIC_BUG_IF(quic_bug_12708_1, blackhole_deadline_.IsInitialized() && + blackhole_deadline_ != GetLastDeadline()) + << "Blackhole detection deadline should be the last deadline."; + + UpdateAlarm(); +} + +QuicTime QuicNetworkBlackholeDetector::GetEarliestDeadline() const { + QuicTime result = QuicTime::Zero(); + for (QuicTime t : {path_degrading_deadline_, blackhole_deadline_, + path_mtu_reduction_deadline_}) { + if (!t.IsInitialized()) { + continue; + } + + if (!result.IsInitialized() || t < result) { + result = t; + } + } + + return result; +} + +QuicTime QuicNetworkBlackholeDetector::GetLastDeadline() const { + return std::max({path_degrading_deadline_, blackhole_deadline_, + path_mtu_reduction_deadline_}); +} + +void QuicNetworkBlackholeDetector::UpdateAlarm() const { + // If called after OnBlackholeDetected(), the alarm may have been permanently + // cancelled and is not safe to be armed again. + if (alarm_->IsPermanentlyCancelled()) { + return; + } + + QuicTime next_deadline = GetEarliestDeadline(); + + QUIC_DVLOG(1) << "Updating alarm. next_deadline:" << next_deadline + << ", path_degrading_deadline_:" << path_degrading_deadline_ + << ", path_mtu_reduction_deadline_:" + << path_mtu_reduction_deadline_ + << ", blackhole_deadline_:" << blackhole_deadline_; + + alarm_->Update(next_deadline, kAlarmGranularity); +} + +bool QuicNetworkBlackholeDetector::IsDetectionInProgress() const { + return alarm_->IsSet(); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_network_blackhole_detector.h b/quiche/quic/core/quic_network_blackhole_detector.h new file mode 100644 index 000000000000..7defc4760035 --- /dev/null +++ b/quiche/quic/core/quic_network_blackhole_detector.h @@ -0,0 +1,91 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_NETWORK_BLACKHOLE_DETECTOR_H_ +#define QUICHE_QUIC_CORE_QUIC_NETWORK_BLACKHOLE_DETECTOR_H_ + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace test { +class QuicConnectionPeer; +class QuicNetworkBlackholeDetectorPeer; +} // namespace test + +// QuicNetworkBlackholeDetector can detect path degrading and/or network +// blackhole. If both detections are in progress, detector will be in path +// degrading detection mode. After reporting path degrading detected, detector +// switches to blackhole detection mode. So blackhole detection deadline must +// be later than path degrading deadline. +class QUIC_EXPORT_PRIVATE QuicNetworkBlackholeDetector { + public: + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + // Called when the path degrading alarm fires. + virtual void OnPathDegradingDetected() = 0; + + // Called when the path blackhole alarm fires. + virtual void OnBlackholeDetected() = 0; + + // Called when the path mtu reduction alarm fires. + virtual void OnPathMtuReductionDetected() = 0; + }; + + QuicNetworkBlackholeDetector(Delegate* delegate, QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, + QuicConnectionContext* context); + + // Called to stop all detections. If |permanent|, the alarm will be cancelled + // permanently and future calls to RestartDetection will be no-op. + void StopDetection(bool permanent); + + // Called to restart path degrading, path mtu reduction and blackhole + // detections. Please note, if |blackhole_deadline| is set, it must be the + // furthest in the future of all deadlines. + void RestartDetection(QuicTime path_degrading_deadline, + QuicTime blackhole_deadline, + QuicTime path_mtu_reduction_deadline); + + // Called when |alarm_| fires. + void OnAlarm(); + + // Returns true if |alarm_| is set. + bool IsDetectionInProgress() const; + + private: + friend class test::QuicConnectionPeer; + friend class test::QuicNetworkBlackholeDetectorPeer; + + QuicTime GetEarliestDeadline() const; + QuicTime GetLastDeadline() const; + + // Update alarm to the next deadline. + void UpdateAlarm() const; + + Delegate* delegate_; // Not owned. + + // Time that Delegate::OnPathDegrading will be called. 0 means no path + // degrading detection is in progress. + QuicTime path_degrading_deadline_ = QuicTime::Zero(); + // Time that Delegate::OnBlackholeDetected will be called. 0 means no + // blackhole detection is in progress. + QuicTime blackhole_deadline_ = QuicTime::Zero(); + // Time that Delegate::OnPathMtuReductionDetected will be called. 0 means no + // path mtu reduction detection is in progress. + QuicTime path_mtu_reduction_deadline_ = QuicTime::Zero(); + + QuicArenaScopedPtr alarm_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_NETWORK_BLACKHOLE_DETECTOR_H_ diff --git a/quiche/quic/core/quic_network_blackhole_detector_test.cc b/quiche/quic/core/quic_network_blackhole_detector_test.cc new file mode 100644 index 000000000000..ca2c87d55280 --- /dev/null +++ b/quiche/quic/core/quic_network_blackhole_detector_test.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_network_blackhole_detector.h" + +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +class QuicNetworkBlackholeDetectorPeer { + public: + static QuicAlarm* GetAlarm(QuicNetworkBlackholeDetector* detector) { + return detector->alarm_.get(); + } +}; + +namespace { +class MockDelegate : public QuicNetworkBlackholeDetector::Delegate { + public: + MOCK_METHOD(void, OnPathDegradingDetected, (), (override)); + MOCK_METHOD(void, OnBlackholeDetected, (), (override)); + MOCK_METHOD(void, OnPathMtuReductionDetected, (), (override)); +}; + +const size_t kPathDegradingDelayInSeconds = 5; +const size_t kPathMtuReductionDelayInSeconds = 7; +const size_t kBlackholeDelayInSeconds = 10; + +class QuicNetworkBlackholeDetectorTest : public QuicTest { + public: + QuicNetworkBlackholeDetectorTest() + : detector_(&delegate_, &arena_, &alarm_factory_, /*context=*/nullptr), + alarm_(static_cast( + QuicNetworkBlackholeDetectorPeer::GetAlarm(&detector_))), + path_degrading_delay_( + QuicTime::Delta::FromSeconds(kPathDegradingDelayInSeconds)), + path_mtu_reduction_delay_( + QuicTime::Delta::FromSeconds(kPathMtuReductionDelayInSeconds)), + blackhole_delay_( + QuicTime::Delta::FromSeconds(kBlackholeDelayInSeconds)) { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + + protected: + void RestartDetection() { + detector_.RestartDetection(clock_.Now() + path_degrading_delay_, + clock_.Now() + blackhole_delay_, + clock_.Now() + path_mtu_reduction_delay_); + } + + testing::StrictMock delegate_; + QuicConnectionArena arena_; + MockAlarmFactory alarm_factory_; + + QuicNetworkBlackholeDetector detector_; + + MockAlarmFactory::TestAlarm* alarm_; + MockClock clock_; + const QuicTime::Delta path_degrading_delay_; + const QuicTime::Delta path_mtu_reduction_delay_; + const QuicTime::Delta blackhole_delay_; +}; + +TEST_F(QuicNetworkBlackholeDetectorTest, StartAndFire) { + EXPECT_FALSE(detector_.IsDetectionInProgress()); + + RestartDetection(); + EXPECT_TRUE(detector_.IsDetectionInProgress()); + EXPECT_EQ(clock_.Now() + path_degrading_delay_, alarm_->deadline()); + + // Fire path degrading alarm. + clock_.AdvanceTime(path_degrading_delay_); + EXPECT_CALL(delegate_, OnPathDegradingDetected()); + alarm_->Fire(); + + // Verify path mtu reduction detection is still in progress. + EXPECT_TRUE(detector_.IsDetectionInProgress()); + EXPECT_EQ(clock_.Now() + path_mtu_reduction_delay_ - path_degrading_delay_, + alarm_->deadline()); + + // Fire path mtu reduction detection alarm. + clock_.AdvanceTime(path_mtu_reduction_delay_ - path_degrading_delay_); + EXPECT_CALL(delegate_, OnPathMtuReductionDetected()); + alarm_->Fire(); + + // Verify blackhole detection is still in progress. + EXPECT_TRUE(detector_.IsDetectionInProgress()); + EXPECT_EQ(clock_.Now() + blackhole_delay_ - path_mtu_reduction_delay_, + alarm_->deadline()); + + // Fire blackhole detection alarm. + clock_.AdvanceTime(blackhole_delay_ - path_mtu_reduction_delay_); + EXPECT_CALL(delegate_, OnBlackholeDetected()); + alarm_->Fire(); + EXPECT_FALSE(detector_.IsDetectionInProgress()); +} + +TEST_F(QuicNetworkBlackholeDetectorTest, RestartAndStop) { + RestartDetection(); + + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + RestartDetection(); + EXPECT_EQ(clock_.Now() + path_degrading_delay_, alarm_->deadline()); + + detector_.StopDetection(/*permanent=*/false); + EXPECT_FALSE(detector_.IsDetectionInProgress()); +} + +TEST_F(QuicNetworkBlackholeDetectorTest, PathDegradingFiresAndRestart) { + EXPECT_FALSE(detector_.IsDetectionInProgress()); + RestartDetection(); + EXPECT_TRUE(detector_.IsDetectionInProgress()); + EXPECT_EQ(clock_.Now() + path_degrading_delay_, alarm_->deadline()); + + // Fire path degrading alarm. + clock_.AdvanceTime(path_degrading_delay_); + EXPECT_CALL(delegate_, OnPathDegradingDetected()); + alarm_->Fire(); + + // Verify path mtu reduction detection is still in progress. + EXPECT_TRUE(detector_.IsDetectionInProgress()); + EXPECT_EQ(clock_.Now() + path_mtu_reduction_delay_ - path_degrading_delay_, + alarm_->deadline()); + + // After 100ms, restart detections on forward progress. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + RestartDetection(); + // Verify alarm is armed based on path degrading deadline. + EXPECT_EQ(clock_.Now() + path_degrading_delay_, alarm_->deadline()); +} + +} // namespace + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_one_block_arena.h b/quiche/quic/core/quic_one_block_arena.h new file mode 100644 index 000000000000..b4162541fb5e --- /dev/null +++ b/quiche/quic/core/quic_one_block_arena.h @@ -0,0 +1,77 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// An arena that consists of a single inlined block of |ArenaSize|. Useful to +// avoid repeated calls to malloc/new and to improve memory locality. +// QUICHE_DCHECK's if an allocation out of the arena ever fails in debug builds; +// falls back to heap allocation in release builds. + +#ifndef QUICHE_QUIC_CORE_QUIC_ONE_BLOCK_ARENA_H_ +#define QUICHE_QUIC_CORE_QUIC_ONE_BLOCK_ARENA_H_ + +#include + +#include "absl/base/optimization.h" +#include "quiche/quic/core/quic_arena_scoped_ptr.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +template +class QUIC_EXPORT_PRIVATE QuicOneBlockArena { + static const uint32_t kMaxAlign = 8; + + public: + QuicOneBlockArena() : offset_(0) {} + QuicOneBlockArena(const QuicOneBlockArena&) = delete; + QuicOneBlockArena& operator=(const QuicOneBlockArena&) = delete; + + // Instantiates an object of type |T| with |args|. |args| are perfectly + // forwarded to |T|'s constructor. The returned pointer's lifetime is + // controlled by QuicArenaScopedPtr. + template + QuicArenaScopedPtr New(Args&&... args) { + QUICHE_DCHECK_LT(AlignedSize(), ArenaSize) + << "Object is too large for the arena."; + static_assert(alignof(T) > 1, + "Objects added to the arena must be at least 2B aligned."); + if (ABSL_PREDICT_FALSE(offset_ > ArenaSize - AlignedSize())) { + QUIC_BUG(quic_bug_10593_1) + << "Ran out of space in QuicOneBlockArena at " << this + << ", max size was " << ArenaSize << ", failing request was " + << AlignedSize() << ", end of arena was " << offset_; + return QuicArenaScopedPtr(new T(std::forward(args)...)); + } + + void* buf = &storage_[offset_]; + new (buf) T(std::forward(args)...); + offset_ += AlignedSize(); + return QuicArenaScopedPtr(buf, + QuicArenaScopedPtr::ConstructFrom::kArena); + } + + private: + // Returns the size of |T| aligned up to |kMaxAlign|. + template + static inline uint32_t AlignedSize() { + return ((sizeof(T) + (kMaxAlign - 1)) / kMaxAlign) * kMaxAlign; + } + + // Actual storage. + // Subtle/annoying: the value '8' must be coded explicitly into the alignment + // declaration for MSVC. + alignas(8) char storage_[ArenaSize]; + // Current offset into the storage. + uint32_t offset_; +}; + +// QuicConnections currently use around 1KB of polymorphic types which would +// ordinarily be on the heap. Instead, store them inline in an arena. +using QuicConnectionArena = QuicOneBlockArena<1380>; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_ONE_BLOCK_ARENA_H_ diff --git a/quiche/quic/core/quic_one_block_arena_test.cc b/quiche/quic/core/quic_one_block_arena_test.cc new file mode 100644 index 000000000000..5c1079b77305 --- /dev/null +++ b/quiche/quic/core/quic_one_block_arena_test.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_one_block_arena.h" + +#include + +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic::test { +namespace { + +static const uint32_t kMaxAlign = 8; + +struct TestObject { + uint32_t value; +}; + +class QuicOneBlockArenaTest : public QuicTest {}; + +TEST_F(QuicOneBlockArenaTest, AllocateSuccess) { + QuicOneBlockArena<1024> arena; + QuicArenaScopedPtr ptr = arena.New(); + EXPECT_TRUE(ptr.is_from_arena()); +} + +TEST_F(QuicOneBlockArenaTest, Exhaust) { + QuicOneBlockArena<1024> arena; + for (size_t i = 0; i < 1024 / kMaxAlign; ++i) { + QuicArenaScopedPtr ptr = arena.New(); + EXPECT_TRUE(ptr.is_from_arena()); + } + QuicArenaScopedPtr ptr; + EXPECT_QUIC_BUG(ptr = arena.New(), + "Ran out of space in QuicOneBlockArena"); + EXPECT_FALSE(ptr.is_from_arena()); +} + +TEST_F(QuicOneBlockArenaTest, NoOverlaps) { + QuicOneBlockArena<1024> arena; + std::vector> objects; + QuicIntervalSet used; + for (size_t i = 0; i < 1024 / kMaxAlign; ++i) { + QuicArenaScopedPtr ptr = arena.New(); + EXPECT_TRUE(ptr.is_from_arena()); + + uintptr_t begin = reinterpret_cast(ptr.get()); + uintptr_t end = begin + sizeof(TestObject); + EXPECT_FALSE(used.Contains(begin)); + EXPECT_FALSE(used.Contains(end - 1)); + used.Add(begin, end); + } +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/quic_packet_creator.cc b/quiche/quic/core/quic_packet_creator.cc new file mode 100644 index 000000000000..c04a4180761b --- /dev/null +++ b/quiche/quic/core/quic_packet_creator.cc @@ -0,0 +1,2289 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packet_creator.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/frames/quic_padding_frame.h" +#include "quiche/quic/core/frames/quic_path_challenge_frame.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/quic_chaos_protector.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_exported_stats.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_server_stats.h" +#include "quiche/common/print_elements.h" + +namespace quic { +namespace { + +QuicLongHeaderType EncryptionlevelToLongHeaderType(EncryptionLevel level) { + switch (level) { + case ENCRYPTION_INITIAL: + return INITIAL; + case ENCRYPTION_HANDSHAKE: + return HANDSHAKE; + case ENCRYPTION_ZERO_RTT: + return ZERO_RTT_PROTECTED; + case ENCRYPTION_FORWARD_SECURE: + QUIC_BUG(quic_bug_12398_1) + << "Try to derive long header type for packet with encryption level: " + << level; + return INVALID_PACKET_TYPE; + default: + QUIC_BUG(quic_bug_10752_1) << level; + return INVALID_PACKET_TYPE; + } +} + +void LogCoalesceStreamFrameStatus(bool success) { + QUIC_HISTOGRAM_BOOL("QuicSession.CoalesceStreamFrameStatus", success, + "Success rate of coalesing stream frames attempt."); +} + +// ScopedPacketContextSwitcher saves |packet|'s states and change states +// during its construction. When the switcher goes out of scope, it restores +// saved states. +class ScopedPacketContextSwitcher { + public: + ScopedPacketContextSwitcher(QuicPacketNumber packet_number, + QuicPacketNumberLength packet_number_length, + EncryptionLevel encryption_level, + SerializedPacket* packet) + + : saved_packet_number_(packet->packet_number), + saved_packet_number_length_(packet->packet_number_length), + saved_encryption_level_(packet->encryption_level), + packet_(packet) { + packet_->packet_number = packet_number, + packet_->packet_number_length = packet_number_length; + packet_->encryption_level = encryption_level; + } + + ~ScopedPacketContextSwitcher() { + packet_->packet_number = saved_packet_number_; + packet_->packet_number_length = saved_packet_number_length_; + packet_->encryption_level = saved_encryption_level_; + } + + private: + const QuicPacketNumber saved_packet_number_; + const QuicPacketNumberLength saved_packet_number_length_; + const EncryptionLevel saved_encryption_level_; + SerializedPacket* packet_; +}; + +} // namespace + +#define ENDPOINT \ + (framer_->perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") + +QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, + QuicFramer* framer, + DelegateInterface* delegate) + : QuicPacketCreator(server_connection_id, framer, QuicRandom::GetInstance(), + delegate) {} + +QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, + QuicFramer* framer, QuicRandom* random, + DelegateInterface* delegate) + : delegate_(delegate), + debug_delegate_(nullptr), + framer_(framer), + random_(random), + send_version_in_packet_(framer->perspective() == Perspective::IS_CLIENT), + have_diversification_nonce_(false), + max_packet_length_(0), + server_connection_id_included_(CONNECTION_ID_PRESENT), + packet_size_(0), + server_connection_id_(server_connection_id), + client_connection_id_(EmptyQuicConnectionId()), + packet_(QuicPacketNumber(), PACKET_1BYTE_PACKET_NUMBER, nullptr, 0, false, + false), + pending_padding_bytes_(0), + needs_full_padding_(false), + next_transmission_type_(NOT_RETRANSMISSION), + flusher_attached_(false), + fully_pad_crypto_handshake_packets_(true), + latched_hard_max_packet_length_(0), + max_datagram_frame_size_(0) { + SetMaxPacketLength(kDefaultMaxPacketSize); + if (!framer_->version().UsesTls()) { + // QUIC+TLS negotiates the maximum datagram frame size via the + // IETF QUIC max_datagram_frame_size transport parameter. + // QUIC_CRYPTO however does not negotiate this so we set its value here. + SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } +} + +QuicPacketCreator::~QuicPacketCreator() { + DeleteFrames(&packet_.retransmittable_frames); +} + +void QuicPacketCreator::SetEncrypter(EncryptionLevel level, + std::unique_ptr encrypter) { + framer_->SetEncrypter(level, std::move(encrypter)); + max_plaintext_size_ = framer_->GetMaxPlaintextSize(max_packet_length_); +} + +bool QuicPacketCreator::CanSetMaxPacketLength() const { + // |max_packet_length_| should not be changed mid-packet. + return queued_frames_.empty(); +} + +void QuicPacketCreator::SetMaxPacketLength(QuicByteCount length) { + QUICHE_DCHECK(CanSetMaxPacketLength()) << ENDPOINT; + + // Avoid recomputing |max_plaintext_size_| if the length does not actually + // change. + if (length == max_packet_length_) { + return; + } + QUIC_DVLOG(1) << ENDPOINT << "Updating packet creator max packet length from " + << max_packet_length_ << " to " << length; + + max_packet_length_ = length; + max_plaintext_size_ = framer_->GetMaxPlaintextSize(max_packet_length_); + QUIC_BUG_IF( + quic_bug_12398_2, + max_plaintext_size_ - PacketHeaderSize() < + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength())) + << ENDPOINT << "Attempted to set max packet length too small"; +} + +void QuicPacketCreator::SetMaxDatagramFrameSize( + QuicByteCount max_datagram_frame_size) { + constexpr QuicByteCount upper_bound = + std::min(std::numeric_limits::max(), + std::numeric_limits::max()); + if (max_datagram_frame_size > upper_bound) { + // A value of |max_datagram_frame_size| that is equal or greater than + // 2^16-1 is effectively infinite because QUIC packets cannot be that large. + // We therefore clamp the value here to allow us to safely cast + // |max_datagram_frame_size_| to QuicPacketLength or size_t. + max_datagram_frame_size = upper_bound; + } + max_datagram_frame_size_ = max_datagram_frame_size; +} + +void QuicPacketCreator::SetSoftMaxPacketLength(QuicByteCount length) { + QUICHE_DCHECK(CanSetMaxPacketLength()) << ENDPOINT; + if (length > max_packet_length_) { + QUIC_BUG(quic_bug_10752_2) + << ENDPOINT + << "Try to increase max_packet_length_ in " + "SetSoftMaxPacketLength, use SetMaxPacketLength instead."; + return; + } + if (framer_->GetMaxPlaintextSize(length) < + PacketHeaderSize() + + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength())) { + // Please note: this would not guarantee to fit next packet if the size of + // packet header increases (e.g., encryption level changes). + QUIC_DLOG(INFO) << ENDPOINT << length + << " is too small to fit packet header"; + RemoveSoftMaxPacketLength(); + return; + } + QUIC_DVLOG(1) << ENDPOINT << "Setting soft max packet length to: " << length; + latched_hard_max_packet_length_ = max_packet_length_; + max_packet_length_ = length; + max_plaintext_size_ = framer_->GetMaxPlaintextSize(length); +} + +// Stops serializing version of the protocol in packets sent after this call. +// A packet that is already open might send kQuicVersionSize bytes less than the +// maximum packet size if we stop sending version before it is serialized. +void QuicPacketCreator::StopSendingVersion() { + QUICHE_DCHECK(send_version_in_packet_) << ENDPOINT; + QUICHE_DCHECK(!version().HasIetfInvariantHeader()) << ENDPOINT; + send_version_in_packet_ = false; + if (packet_size_ > 0) { + QUICHE_DCHECK_LT(kQuicVersionSize, packet_size_) << ENDPOINT; + packet_size_ -= kQuicVersionSize; + } +} + +void QuicPacketCreator::SetDiversificationNonce( + const DiversificationNonce& nonce) { + QUICHE_DCHECK(!have_diversification_nonce_) << ENDPOINT; + have_diversification_nonce_ = true; + diversification_nonce_ = nonce; +} + +void QuicPacketCreator::UpdatePacketNumberLength( + QuicPacketNumber least_packet_awaited_by_peer, + QuicPacketCount max_packets_in_flight) { + if (!queued_frames_.empty()) { + // Don't change creator state if there are frames queued. + QUIC_BUG(quic_bug_10752_3) + << ENDPOINT << "Called UpdatePacketNumberLength with " + << queued_frames_.size() + << " queued_frames. First frame type:" << queued_frames_.front().type + << " last frame type:" << queued_frames_.back().type; + return; + } + + const QuicPacketNumber next_packet_number = NextSendingPacketNumber(); + QUICHE_DCHECK_LE(least_packet_awaited_by_peer, next_packet_number) + << ENDPOINT; + const uint64_t current_delta = + next_packet_number - least_packet_awaited_by_peer; + const uint64_t delta = std::max(current_delta, max_packets_in_flight); + const QuicPacketNumberLength packet_number_length = + QuicFramer::GetMinPacketNumberLength(QuicPacketNumber(delta * 4)); + if (packet_.packet_number_length == packet_number_length) { + return; + } + QUIC_DVLOG(1) << ENDPOINT << "Updating packet number length from " + << static_cast(packet_.packet_number_length) << " to " + << static_cast(packet_number_length) + << ", least_packet_awaited_by_peer: " + << least_packet_awaited_by_peer + << " max_packets_in_flight: " << max_packets_in_flight + << " next_packet_number: " << next_packet_number; + packet_.packet_number_length = packet_number_length; +} + +void QuicPacketCreator::SkipNPacketNumbers( + QuicPacketCount count, QuicPacketNumber least_packet_awaited_by_peer, + QuicPacketCount max_packets_in_flight) { + if (!queued_frames_.empty()) { + // Don't change creator state if there are frames queued. + QUIC_BUG(quic_bug_10752_4) + << ENDPOINT << "Called SkipNPacketNumbers with " + << queued_frames_.size() + << " queued_frames. First frame type:" << queued_frames_.front().type + << " last frame type:" << queued_frames_.back().type; + return; + } + if (packet_.packet_number > packet_.packet_number + count) { + // Skipping count packet numbers causes packet number wrapping around, + // reject it. + QUIC_LOG(WARNING) << ENDPOINT << "Skipping " << count + << " packet numbers causes packet number wrapping " + "around, least_packet_awaited_by_peer: " + << least_packet_awaited_by_peer + << " packet_number:" << packet_.packet_number; + return; + } + packet_.packet_number += count; + // Packet number changes, update packet number length if necessary. + UpdatePacketNumberLength(least_packet_awaited_by_peer, max_packets_in_flight); +} + +bool QuicPacketCreator::ConsumeCryptoDataToFillCurrentPacket( + EncryptionLevel level, size_t write_length, QuicStreamOffset offset, + bool needs_full_padding, TransmissionType transmission_type, + QuicFrame* frame) { + QUIC_DVLOG(2) << ENDPOINT << "ConsumeCryptoDataToFillCurrentPacket " << level + << " write_length " << write_length << " offset " << offset + << (needs_full_padding ? " needs_full_padding" : "") << " " + << transmission_type; + if (!CreateCryptoFrame(level, write_length, offset, frame)) { + return false; + } + // When crypto data was sent in stream frames, ConsumeData is called with + // |needs_full_padding = true|. Keep the same behavior here when sending + // crypto frames. + // + // TODO(nharper): Check what the IETF drafts say about padding out initial + // messages and change this as appropriate. + if (needs_full_padding) { + needs_full_padding_ = true; + } + return AddFrame(*frame, transmission_type); +} + +bool QuicPacketCreator::ConsumeDataToFillCurrentPacket( + QuicStreamId id, size_t data_size, QuicStreamOffset offset, bool fin, + bool needs_full_padding, TransmissionType transmission_type, + QuicFrame* frame) { + if (!HasRoomForStreamFrame(id, offset, data_size)) { + return false; + } + CreateStreamFrame(id, data_size, offset, fin, frame); + // Explicitly disallow multi-packet CHLOs. + if (GetQuicFlag(quic_enforce_single_packet_chlo) && + StreamFrameIsClientHello(frame->stream_frame) && + frame->stream_frame.data_length < data_size) { + const std::string error_details = + "Client hello won't fit in a single packet."; + QUIC_BUG(quic_bug_10752_5) + << ENDPOINT << error_details << " Constructed stream frame length: " + << frame->stream_frame.data_length << " CHLO length: " << data_size; + delegate_->OnUnrecoverableError(QUIC_CRYPTO_CHLO_TOO_LARGE, error_details); + return false; + } + if (!AddFrame(*frame, transmission_type)) { + // Fails if we try to write unencrypted stream data. + return false; + } + if (needs_full_padding) { + needs_full_padding_ = true; + } + + return true; +} + +bool QuicPacketCreator::HasRoomForStreamFrame(QuicStreamId id, + QuicStreamOffset offset, + size_t data_size) { + const size_t min_stream_frame_size = QuicFramer::GetMinStreamFrameSize( + framer_->transport_version(), id, offset, /*last_frame_in_packet=*/true, + data_size); + if (BytesFree() > min_stream_frame_size) { + return true; + } + if (!RemoveSoftMaxPacketLength()) { + return false; + } + return BytesFree() > min_stream_frame_size; +} + +bool QuicPacketCreator::HasRoomForMessageFrame(QuicByteCount length) { + const size_t message_frame_size = QuicFramer::GetMessageFrameSize( + framer_->transport_version(), /*last_frame_in_packet=*/true, length); + if (static_cast(message_frame_size) > + max_datagram_frame_size_) { + return false; + } + if (BytesFree() >= message_frame_size) { + return true; + } + if (!RemoveSoftMaxPacketLength()) { + return false; + } + return BytesFree() >= message_frame_size; +} + +// static +size_t QuicPacketCreator::StreamFramePacketOverhead( + QuicTransportVersion version, uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool include_version, + bool include_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + quiche::QuicheVariableLengthIntegerLength length_length, + QuicStreamOffset offset) { + return GetPacketHeaderSize(version, destination_connection_id_length, + source_connection_id_length, include_version, + include_diversification_nonce, + packet_number_length, retry_token_length_length, 0, + length_length) + + + // Assumes a packet with a single stream frame, which omits the length, + // causing the data length argument to be ignored. + QuicFramer::GetMinStreamFrameSize(version, 1u, offset, true, + kMaxOutgoingPacketSize /* unused */); +} + +void QuicPacketCreator::CreateStreamFrame(QuicStreamId id, size_t data_size, + QuicStreamOffset offset, bool fin, + QuicFrame* frame) { + // Make sure max_packet_length_ is greater than the largest possible overhead + // or max_packet_length_ is set to the soft limit. + QUICHE_DCHECK( + max_packet_length_ > + StreamFramePacketOverhead( + framer_->transport_version(), GetDestinationConnectionIdLength(), + GetSourceConnectionIdLength(), kIncludeVersion, + IncludeNonceInPublicHeader(), PACKET_6BYTE_PACKET_NUMBER, + GetRetryTokenLengthLength(), GetLengthLength(), offset) || + latched_hard_max_packet_length_ > 0) + << ENDPOINT; + + QUIC_BUG_IF(quic_bug_12398_3, !HasRoomForStreamFrame(id, offset, data_size)) + << ENDPOINT << "No room for Stream frame, BytesFree: " << BytesFree() + << " MinStreamFrameSize: " + << QuicFramer::GetMinStreamFrameSize(framer_->transport_version(), id, + offset, true, data_size); + + QUIC_BUG_IF(quic_bug_12398_4, data_size == 0 && !fin) + << ENDPOINT << "Creating a stream frame for stream ID:" << id + << " with no data or fin."; + size_t min_frame_size = QuicFramer::GetMinStreamFrameSize( + framer_->transport_version(), id, offset, + /* last_frame_in_packet= */ true, data_size); + size_t bytes_consumed = + std::min(BytesFree() - min_frame_size, data_size); + + bool set_fin = fin && bytes_consumed == data_size; // Last frame. + *frame = QuicFrame(QuicStreamFrame(id, set_fin, offset, bytes_consumed)); +} + +bool QuicPacketCreator::CreateCryptoFrame(EncryptionLevel level, + size_t write_length, + QuicStreamOffset offset, + QuicFrame* frame) { + const size_t min_frame_size = + QuicFramer::GetMinCryptoFrameSize(write_length, offset); + if (BytesFree() <= min_frame_size && + (!RemoveSoftMaxPacketLength() || BytesFree() <= min_frame_size)) { + return false; + } + size_t max_write_length = BytesFree() - min_frame_size; + size_t bytes_consumed = std::min(max_write_length, write_length); + *frame = QuicFrame(new QuicCryptoFrame(level, offset, bytes_consumed)); + return true; +} + +void QuicPacketCreator::FlushCurrentPacket() { + if (!HasPendingFrames() && pending_padding_bytes_ == 0) { + return; + } + + ABSL_CACHELINE_ALIGNED char stack_buffer[kMaxOutgoingPacketSize]; + QuicOwnedPacketBuffer external_buffer(delegate_->GetPacketBuffer()); + + if (external_buffer.buffer == nullptr) { + external_buffer.buffer = stack_buffer; + external_buffer.release_buffer = nullptr; + } + + QUICHE_DCHECK_EQ(nullptr, packet_.encrypted_buffer) << ENDPOINT; + if (!SerializePacket(std::move(external_buffer), kMaxOutgoingPacketSize, + /*allow_padding=*/true)) { + return; + } + OnSerializedPacket(); +} + +void QuicPacketCreator::OnSerializedPacket() { + QUIC_BUG_IF(quic_bug_12398_5, packet_.encrypted_buffer == nullptr) + << ENDPOINT; + + // Clear bytes_not_retransmitted for packets containing only + // NOT_RETRANSMISSION frames. + if (packet_.transmission_type == NOT_RETRANSMISSION) { + packet_.bytes_not_retransmitted.reset(); + } + + SerializedPacket packet(std::move(packet_)); + ClearPacket(); + RemoveSoftMaxPacketLength(); + delegate_->OnSerializedPacket(std::move(packet)); +} + +void QuicPacketCreator::ClearPacket() { + packet_.has_ack = false; + packet_.has_stop_waiting = false; + packet_.has_ack_ecn = false; + packet_.has_crypto_handshake = NOT_HANDSHAKE; + packet_.transmission_type = NOT_RETRANSMISSION; + packet_.encrypted_buffer = nullptr; + packet_.encrypted_length = 0; + packet_.has_ack_frequency = false; + packet_.has_message = false; + packet_.fate = SEND_TO_WRITER; + QUIC_BUG_IF(quic_bug_12398_6, packet_.release_encrypted_buffer != nullptr) + << ENDPOINT << "packet_.release_encrypted_buffer should be empty"; + packet_.release_encrypted_buffer = nullptr; + QUICHE_DCHECK(packet_.retransmittable_frames.empty()) << ENDPOINT; + QUICHE_DCHECK(packet_.nonretransmittable_frames.empty()) << ENDPOINT; + packet_.largest_acked.Clear(); + needs_full_padding_ = false; + packet_.bytes_not_retransmitted.reset(); + packet_.initial_header.reset(); +} + +size_t QuicPacketCreator::ReserializeInitialPacketInCoalescedPacket( + const SerializedPacket& packet, size_t padding_size, char* buffer, + size_t buffer_len) { + QUIC_BUG_IF(quic_bug_12398_7, packet.encryption_level != ENCRYPTION_INITIAL); + QUIC_BUG_IF(quic_bug_12398_8, packet.nonretransmittable_frames.empty() && + packet.retransmittable_frames.empty()) + << ENDPOINT + << "Attempt to serialize empty ENCRYPTION_INITIAL packet in coalesced " + "packet"; + + if (HasPendingFrames()) { + QUIC_BUG(quic_packet_creator_unexpected_queued_frames) + << "Unexpected queued frames: " << GetPendingFramesInfo(); + return 0; + } + + ScopedPacketContextSwitcher switcher( + packet.packet_number - + 1, // -1 because serialize packet increase packet number. + packet.packet_number_length, packet.encryption_level, &packet_); + for (const QuicFrame& frame : packet.nonretransmittable_frames) { + if (!AddFrame(frame, packet.transmission_type)) { + QUIC_BUG(quic_bug_10752_6) + << ENDPOINT << "Failed to serialize frame: " << frame; + return 0; + } + } + for (const QuicFrame& frame : packet.retransmittable_frames) { + if (!AddFrame(frame, packet.transmission_type)) { + QUIC_BUG(quic_bug_10752_7) + << ENDPOINT << "Failed to serialize frame: " << frame; + return 0; + } + } + // Add necessary padding. + if (padding_size > 0) { + QUIC_DVLOG(2) << ENDPOINT << "Add padding of size: " << padding_size; + if (!AddFrame(QuicFrame(QuicPaddingFrame(padding_size)), + packet.transmission_type)) { + QUIC_BUG(quic_bug_10752_8) + << ENDPOINT << "Failed to add padding of size " << padding_size + << " when serializing ENCRYPTION_INITIAL " + "packet in coalesced packet"; + return 0; + } + } + + if (!SerializePacket(QuicOwnedPacketBuffer(buffer, nullptr), buffer_len, + /*allow_padding=*/false)) { + return 0; + } + if (!packet.initial_header.has_value() || + !packet_.initial_header.has_value()) { + QUIC_BUG(missing initial packet header) + << "initial serialized packet does not have header populated"; + } else if (packet.initial_header.value() != packet_.initial_header.value()) { + QUIC_BUG(initial packet header changed before reserialization) + << ENDPOINT << "original header: " << packet.initial_header.value() + << ", new header: " << packet_.initial_header.value(); + } + const size_t encrypted_length = packet_.encrypted_length; + // Clear frames in packet_. No need to DeleteFrames since frames are owned by + // initial_packet. + packet_.retransmittable_frames.clear(); + packet_.nonretransmittable_frames.clear(); + ClearPacket(); + return encrypted_length; +} + +void QuicPacketCreator::CreateAndSerializeStreamFrame( + QuicStreamId id, size_t write_length, QuicStreamOffset iov_offset, + QuicStreamOffset stream_offset, bool fin, + TransmissionType transmission_type, size_t* num_bytes_consumed) { + // TODO(b/167222597): consider using ScopedSerializationFailureHandler. + QUICHE_DCHECK(queued_frames_.empty()) << ENDPOINT; + QUICHE_DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id)) + << ENDPOINT; + // Write out the packet header + QuicPacketHeader header; + FillPacketHeader(&header); + packet_.fate = delegate_->GetSerializedPacketFate( + /*is_mtu_discovery=*/false, packet_.encryption_level); + QUIC_DVLOG(1) << ENDPOINT << "fate of packet " << packet_.packet_number + << ": " << SerializedPacketFateToString(packet_.fate) << " of " + << EncryptionLevelToString(packet_.encryption_level); + + ABSL_CACHELINE_ALIGNED char stack_buffer[kMaxOutgoingPacketSize]; + QuicOwnedPacketBuffer packet_buffer(delegate_->GetPacketBuffer()); + + if (packet_buffer.buffer == nullptr) { + packet_buffer.buffer = stack_buffer; + packet_buffer.release_buffer = nullptr; + } + + char* encrypted_buffer = packet_buffer.buffer; + + QuicDataWriter writer(kMaxOutgoingPacketSize, encrypted_buffer); + size_t length_field_offset = 0; + if (!framer_->AppendPacketHeader(header, &writer, &length_field_offset)) { + QUIC_BUG(quic_bug_10752_9) << ENDPOINT << "AppendPacketHeader failed"; + return; + } + + // Create a Stream frame with the remaining space. + QUIC_BUG_IF(quic_bug_12398_9, iov_offset == write_length && !fin) + << ENDPOINT << "Creating a stream frame with no data or fin."; + const size_t remaining_data_size = write_length - iov_offset; + size_t min_frame_size = QuicFramer::GetMinStreamFrameSize( + framer_->transport_version(), id, stream_offset, + /* last_frame_in_packet= */ true, remaining_data_size); + size_t available_size = + max_plaintext_size_ - writer.length() - min_frame_size; + size_t bytes_consumed = std::min(available_size, remaining_data_size); + size_t plaintext_bytes_written = min_frame_size + bytes_consumed; + bool needs_padding = false; + const size_t min_plaintext_size = + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength()); + if (plaintext_bytes_written < min_plaintext_size) { + needs_padding = true; + } + + const bool set_fin = fin && (bytes_consumed == remaining_data_size); + QuicStreamFrame frame(id, set_fin, stream_offset, bytes_consumed); + if (debug_delegate_ != nullptr) { + debug_delegate_->OnFrameAddedToPacket(QuicFrame(frame)); + } + QUIC_DVLOG(1) << ENDPOINT << "Adding frame: " << frame; + + QUIC_DVLOG(2) << ENDPOINT << "Serializing stream packet " << header << frame; + + // TODO(ianswett): AppendTypeByte and AppendStreamFrame could be optimized + // into one method that takes a QuicStreamFrame, if warranted. + if (needs_padding) { + if (!writer.WritePaddingBytes(min_plaintext_size - + plaintext_bytes_written)) { + QUIC_BUG(quic_bug_10752_12) << ENDPOINT << "Unable to add padding bytes"; + return; + } + needs_padding = false; + } + bool omit_frame_length = !needs_padding; + if (!framer_->AppendTypeByte(QuicFrame(frame), omit_frame_length, &writer)) { + QUIC_BUG(quic_bug_10752_10) << ENDPOINT << "AppendTypeByte failed"; + return; + } + if (!framer_->AppendStreamFrame(frame, omit_frame_length, &writer)) { + QUIC_BUG(quic_bug_10752_11) << ENDPOINT << "AppendStreamFrame failed"; + return; + } + if (needs_padding && plaintext_bytes_written < min_plaintext_size && + !writer.WritePaddingBytes(min_plaintext_size - plaintext_bytes_written)) { + QUIC_BUG(quic_bug_10752_12) << ENDPOINT << "Unable to add padding bytes"; + return; + } + + if (!framer_->WriteIetfLongHeaderLength(header, &writer, length_field_offset, + packet_.encryption_level)) { + return; + } + + packet_.transmission_type = transmission_type; + + QUICHE_DCHECK(packet_.encryption_level == ENCRYPTION_FORWARD_SECURE || + packet_.encryption_level == ENCRYPTION_ZERO_RTT) + << ENDPOINT << packet_.encryption_level; + size_t encrypted_length = framer_->EncryptInPlace( + packet_.encryption_level, packet_.packet_number, + GetStartOfEncryptedData(framer_->transport_version(), header), + writer.length(), kMaxOutgoingPacketSize, encrypted_buffer); + if (encrypted_length == 0) { + QUIC_BUG(quic_bug_10752_13) + << ENDPOINT << "Failed to encrypt packet number " + << header.packet_number; + return; + } + // TODO(ianswett): Optimize the storage so RetransmitableFrames can be + // unioned with a QuicStreamFrame and a UniqueStreamBuffer. + *num_bytes_consumed = bytes_consumed; + packet_size_ = 0; + packet_.encrypted_buffer = encrypted_buffer; + packet_.encrypted_length = encrypted_length; + + packet_buffer.buffer = nullptr; + packet_.release_encrypted_buffer = std::move(packet_buffer).release_buffer; + + packet_.retransmittable_frames.push_back(QuicFrame(frame)); + OnSerializedPacket(); +} + +bool QuicPacketCreator::HasPendingFrames() const { + return !queued_frames_.empty(); +} + +std::string QuicPacketCreator::GetPendingFramesInfo() const { + return QuicFramesToString(queued_frames_); +} + +bool QuicPacketCreator::HasPendingRetransmittableFrames() const { + return !packet_.retransmittable_frames.empty(); +} + +bool QuicPacketCreator::HasPendingStreamFramesOfStream(QuicStreamId id) const { + for (const auto& frame : packet_.retransmittable_frames) { + if (frame.type == STREAM_FRAME && frame.stream_frame.stream_id == id) { + return true; + } + } + return false; +} + +size_t QuicPacketCreator::ExpansionOnNewFrame() const { + // If the last frame in the packet is a message frame, then it will expand to + // include the varint message length when a new frame is added. + if (queued_frames_.empty()) { + return 0; + } + return ExpansionOnNewFrameWithLastFrame(queued_frames_.back(), + framer_->transport_version()); +} + +// static +size_t QuicPacketCreator::ExpansionOnNewFrameWithLastFrame( + const QuicFrame& last_frame, QuicTransportVersion version) { + if (last_frame.type == MESSAGE_FRAME) { + return QuicDataWriter::GetVarInt62Len( + last_frame.message_frame->message_length); + } + if (last_frame.type != STREAM_FRAME) { + return 0; + } + if (VersionHasIetfQuicFrames(version)) { + return QuicDataWriter::GetVarInt62Len(last_frame.stream_frame.data_length); + } + return kQuicStreamPayloadLengthSize; +} + +size_t QuicPacketCreator::BytesFree() const { + return max_plaintext_size_ - + std::min(max_plaintext_size_, PacketSize() + ExpansionOnNewFrame()); +} + +size_t QuicPacketCreator::BytesFreeForPadding() const { + size_t consumed = PacketSize(); + return max_plaintext_size_ - std::min(max_plaintext_size_, consumed); +} + +size_t QuicPacketCreator::PacketSize() const { + return queued_frames_.empty() ? PacketHeaderSize() : packet_size_; +} + +bool QuicPacketCreator::AddPaddedSavedFrame( + const QuicFrame& frame, TransmissionType transmission_type) { + if (AddFrame(frame, transmission_type)) { + needs_full_padding_ = true; + return true; + } + return false; +} + +absl::optional +QuicPacketCreator::MaybeBuildDataPacketWithChaosProtection( + const QuicPacketHeader& header, char* buffer) { + if (!GetQuicFlag(quic_enable_chaos_protection) || + framer_->perspective() != Perspective::IS_CLIENT || + packet_.encryption_level != ENCRYPTION_INITIAL || + !framer_->version().UsesCryptoFrames() || queued_frames_.size() != 2u || + queued_frames_[0].type != CRYPTO_FRAME || + queued_frames_[1].type != PADDING_FRAME || + // Do not perform chaos protection if we do not have a known number of + // padding bytes to work with. + queued_frames_[1].padding_frame.num_padding_bytes <= 0 || + // Chaos protection relies on the framer using a crypto data producer, + // which is always the case in practice. + framer_->data_producer() == nullptr) { + return absl::nullopt; + } + const QuicCryptoFrame& crypto_frame = *queued_frames_[0].crypto_frame; + if (packet_.encryption_level != crypto_frame.level) { + QUIC_BUG(chaos frame level) + << ENDPOINT << packet_.encryption_level << " != " << crypto_frame.level; + return absl::nullopt; + } + QuicChaosProtector chaos_protector( + crypto_frame, queued_frames_[1].padding_frame.num_padding_bytes, + packet_size_, framer_, random_); + return chaos_protector.BuildDataPacket(header, buffer); +} + +bool QuicPacketCreator::SerializePacket(QuicOwnedPacketBuffer encrypted_buffer, + size_t encrypted_buffer_len, + bool allow_padding) { + if (packet_.encrypted_buffer != nullptr) { + const std::string error_details = + "Packet's encrypted buffer is not empty before serialization"; + QUIC_BUG(quic_bug_10752_14) << ENDPOINT << error_details; + delegate_->OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET, + error_details); + return false; + } + ScopedSerializationFailureHandler handler(this); + + QUICHE_DCHECK_LT(0u, encrypted_buffer_len) << ENDPOINT; + QUIC_BUG_IF(quic_bug_12398_10, + queued_frames_.empty() && pending_padding_bytes_ == 0) + << ENDPOINT << "Attempt to serialize empty packet"; + QuicPacketHeader header; + // FillPacketHeader increments packet_number_. + FillPacketHeader(&header); + if (packet_.encryption_level == ENCRYPTION_INITIAL) { + packet_.initial_header = header; + } + if (delegate_ != nullptr) { + packet_.fate = delegate_->GetSerializedPacketFate( + /*is_mtu_discovery=*/QuicUtils::ContainsFrameType(queued_frames_, + MTU_DISCOVERY_FRAME), + packet_.encryption_level); + QUIC_DVLOG(1) << ENDPOINT << "fate of packet " << packet_.packet_number + << ": " << SerializedPacketFateToString(packet_.fate) + << " of " + << EncryptionLevelToString(packet_.encryption_level); + } + + if (allow_padding) { + MaybeAddPadding(); + } + + QUIC_DVLOG(2) << ENDPOINT << "Serializing packet " << header + << QuicFramesToString(queued_frames_) << " at encryption_level " + << packet_.encryption_level + << ", allow_padding:" << allow_padding; + + if (!framer_->HasEncrypterOfEncryptionLevel(packet_.encryption_level)) { + // TODO(fayang): Use QUIC_MISSING_WRITE_KEYS for serialization failures due + // to missing keys. + QUIC_BUG(quic_bug_10752_15) + << ENDPOINT << "Attempting to serialize " << header + << QuicFramesToString(queued_frames_) << " at missing encryption_level " + << packet_.encryption_level << " using " << framer_->version(); + return false; + } + + QUICHE_DCHECK_GE(max_plaintext_size_, packet_size_) << ENDPOINT; + // Use the packet_size_ instead of the buffer size to ensure smaller + // packet sizes are properly used. + + size_t length; + absl::optional length_with_chaos_protection = + MaybeBuildDataPacketWithChaosProtection(header, encrypted_buffer.buffer); + if (length_with_chaos_protection.has_value()) { + length = length_with_chaos_protection.value(); + } else { + length = framer_->BuildDataPacket(header, queued_frames_, + encrypted_buffer.buffer, packet_size_, + packet_.encryption_level); + } + + if (length == 0) { + QUIC_BUG(quic_bug_10752_16) + << ENDPOINT << "Failed to serialize " + << QuicFramesToString(queued_frames_) + << " at encryption_level: " << packet_.encryption_level + << ", needs_full_padding_: " << needs_full_padding_ + << ", pending_padding_bytes_: " << pending_padding_bytes_ + << ", latched_hard_max_packet_length_: " + << latched_hard_max_packet_length_ + << ", max_packet_length_: " << max_packet_length_ + << ", header: " << header; + return false; + } + + // ACK Frames will be truncated due to length only if they're the only frame + // in the packet, and if packet_size_ was set to max_plaintext_size_. If + // truncation due to length occurred, then GetSerializedFrameLength will have + // returned all bytes free. + bool possibly_truncated_by_length = packet_size_ == max_plaintext_size_ && + queued_frames_.size() == 1 && + queued_frames_.back().type == ACK_FRAME; + // Because of possible truncation, we can't be confident that our + // packet size calculation worked correctly. + if (!possibly_truncated_by_length) { + QUICHE_DCHECK_EQ(packet_size_, length) << ENDPOINT; + } + const size_t encrypted_length = framer_->EncryptInPlace( + packet_.encryption_level, packet_.packet_number, + GetStartOfEncryptedData(framer_->transport_version(), header), length, + encrypted_buffer_len, encrypted_buffer.buffer); + if (encrypted_length == 0) { + QUIC_BUG(quic_bug_10752_17) + << ENDPOINT << "Failed to encrypt packet number " + << packet_.packet_number; + return false; + } + + packet_size_ = 0; + packet_.encrypted_buffer = encrypted_buffer.buffer; + packet_.encrypted_length = encrypted_length; + + encrypted_buffer.buffer = nullptr; + packet_.release_encrypted_buffer = std::move(encrypted_buffer).release_buffer; + return true; +} + +std::unique_ptr +QuicPacketCreator::SerializeConnectivityProbingPacket() { + QUIC_BUG_IF(quic_bug_12398_11, + VersionHasIetfQuicFrames(framer_->transport_version())) + << ENDPOINT + << "Must not be version 99 to serialize padded ping connectivity probe"; + RemoveSoftMaxPacketLength(); + QuicPacketHeader header; + // FillPacketHeader increments packet_number_. + FillPacketHeader(&header); + + QUIC_DVLOG(2) << ENDPOINT << "Serializing connectivity probing packet " + << header; + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + size_t length = BuildConnectivityProbingPacket( + header, buffer.get(), max_plaintext_size_, packet_.encryption_level); + QUICHE_DCHECK(length) << ENDPOINT; + + QUICHE_DCHECK_EQ(packet_.encryption_level, ENCRYPTION_FORWARD_SECURE) + << ENDPOINT; + const size_t encrypted_length = framer_->EncryptInPlace( + packet_.encryption_level, packet_.packet_number, + GetStartOfEncryptedData(framer_->transport_version(), header), length, + kMaxOutgoingPacketSize, buffer.get()); + QUICHE_DCHECK(encrypted_length) << ENDPOINT; + + std::unique_ptr serialize_packet(new SerializedPacket( + header.packet_number, header.packet_number_length, buffer.release(), + encrypted_length, /*has_ack=*/false, /*has_stop_waiting=*/false)); + + serialize_packet->release_encrypted_buffer = [](const char* p) { + delete[] p; + }; + serialize_packet->encryption_level = packet_.encryption_level; + serialize_packet->transmission_type = NOT_RETRANSMISSION; + + return serialize_packet; +} + +std::unique_ptr +QuicPacketCreator::SerializePathChallengeConnectivityProbingPacket( + const QuicPathFrameBuffer& payload) { + QUIC_BUG_IF(quic_bug_12398_12, + !VersionHasIetfQuicFrames(framer_->transport_version())) + << ENDPOINT + << "Must be version 99 to serialize path challenge connectivity probe, " + "is version " + << framer_->transport_version(); + RemoveSoftMaxPacketLength(); + QuicPacketHeader header; + // FillPacketHeader increments packet_number_. + FillPacketHeader(&header); + + QUIC_DVLOG(2) << ENDPOINT << "Serializing path challenge packet " << header; + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + size_t length = + BuildPaddedPathChallengePacket(header, buffer.get(), max_plaintext_size_, + payload, packet_.encryption_level); + QUICHE_DCHECK(length) << ENDPOINT; + + QUICHE_DCHECK_EQ(packet_.encryption_level, ENCRYPTION_FORWARD_SECURE) + << ENDPOINT; + const size_t encrypted_length = framer_->EncryptInPlace( + packet_.encryption_level, packet_.packet_number, + GetStartOfEncryptedData(framer_->transport_version(), header), length, + kMaxOutgoingPacketSize, buffer.get()); + QUICHE_DCHECK(encrypted_length) << ENDPOINT; + + std::unique_ptr serialize_packet( + new SerializedPacket(header.packet_number, header.packet_number_length, + buffer.release(), encrypted_length, + /*has_ack=*/false, /*has_stop_waiting=*/false)); + + serialize_packet->release_encrypted_buffer = [](const char* p) { + delete[] p; + }; + serialize_packet->encryption_level = packet_.encryption_level; + serialize_packet->transmission_type = NOT_RETRANSMISSION; + + return serialize_packet; +} + +std::unique_ptr +QuicPacketCreator::SerializePathResponseConnectivityProbingPacket( + const quiche::QuicheCircularDeque& payloads, + const bool is_padded) { + QUIC_BUG_IF(quic_bug_12398_13, + !VersionHasIetfQuicFrames(framer_->transport_version())) + << ENDPOINT + << "Must be version 99 to serialize path response connectivity probe, is " + "version " + << framer_->transport_version(); + RemoveSoftMaxPacketLength(); + QuicPacketHeader header; + // FillPacketHeader increments packet_number_. + FillPacketHeader(&header); + + QUIC_DVLOG(2) << ENDPOINT << "Serializing path response packet " << header; + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + size_t length = + BuildPathResponsePacket(header, buffer.get(), max_plaintext_size_, + payloads, is_padded, packet_.encryption_level); + QUICHE_DCHECK(length) << ENDPOINT; + + QUICHE_DCHECK_EQ(packet_.encryption_level, ENCRYPTION_FORWARD_SECURE) + << ENDPOINT; + const size_t encrypted_length = framer_->EncryptInPlace( + packet_.encryption_level, packet_.packet_number, + GetStartOfEncryptedData(framer_->transport_version(), header), length, + kMaxOutgoingPacketSize, buffer.get()); + QUICHE_DCHECK(encrypted_length) << ENDPOINT; + + std::unique_ptr serialize_packet( + new SerializedPacket(header.packet_number, header.packet_number_length, + buffer.release(), encrypted_length, + /*has_ack=*/false, /*has_stop_waiting=*/false)); + + serialize_packet->release_encrypted_buffer = [](const char* p) { + delete[] p; + }; + serialize_packet->encryption_level = packet_.encryption_level; + serialize_packet->transmission_type = NOT_RETRANSMISSION; + + return serialize_packet; +} + +size_t QuicPacketCreator::BuildPaddedPathChallengePacket( + const QuicPacketHeader& header, char* buffer, size_t packet_length, + const QuicPathFrameBuffer& payload, EncryptionLevel level) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(framer_->transport_version())) + << ENDPOINT; + QuicFrames frames; + + // Write a PATH_CHALLENGE frame, which has a random 8-byte payload + frames.push_back(QuicFrame(QuicPathChallengeFrame(0, payload))); + + if (debug_delegate_ != nullptr) { + debug_delegate_->OnFrameAddedToPacket(frames.back()); + } + + // Add padding to the rest of the packet in order to assess Path MTU + // characteristics. + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(padding_frame)); + + return framer_->BuildDataPacket(header, frames, buffer, packet_length, level); +} + +size_t QuicPacketCreator::BuildPathResponsePacket( + const QuicPacketHeader& header, char* buffer, size_t packet_length, + const quiche::QuicheCircularDeque& payloads, + const bool is_padded, EncryptionLevel level) { + if (payloads.empty()) { + QUIC_BUG(quic_bug_12398_14) + << ENDPOINT + << "Attempt to generate connectivity response with no request payloads"; + return 0; + } + QUICHE_DCHECK(VersionHasIetfQuicFrames(framer_->transport_version())) + << ENDPOINT; + + QuicFrames frames; + for (const QuicPathFrameBuffer& payload : payloads) { + // Note that the control frame ID can be 0 since this is not retransmitted. + frames.push_back(QuicFrame(QuicPathResponseFrame(0, payload))); + if (debug_delegate_ != nullptr) { + debug_delegate_->OnFrameAddedToPacket(frames.back()); + } + } + + if (is_padded) { + // Add padding to the rest of the packet in order to assess Path MTU + // characteristics. + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(padding_frame)); + } + + return framer_->BuildDataPacket(header, frames, buffer, packet_length, level); +} + +size_t QuicPacketCreator::BuildConnectivityProbingPacket( + const QuicPacketHeader& header, char* buffer, size_t packet_length, + EncryptionLevel level) { + QuicFrames frames; + + // Write a PING frame, which has no data payload. + QuicPingFrame ping_frame; + frames.push_back(QuicFrame(ping_frame)); + + // Add padding to the rest of the packet. + QuicPaddingFrame padding_frame; + frames.push_back(QuicFrame(padding_frame)); + + return framer_->BuildDataPacket(header, frames, buffer, packet_length, level); +} + +size_t QuicPacketCreator::SerializeCoalescedPacket( + const QuicCoalescedPacket& coalesced, char* buffer, size_t buffer_len) { + if (HasPendingFrames()) { + QUIC_BUG(quic_bug_10752_18) + << ENDPOINT << "Try to serialize coalesced packet with pending frames"; + return 0; + } + RemoveSoftMaxPacketLength(); + QUIC_BUG_IF(quic_bug_12398_15, coalesced.length() == 0) + << ENDPOINT << "Attempt to serialize empty coalesced packet"; + size_t packet_length = 0; + size_t initial_length = 0; + size_t padding_size = 0; + if (coalesced.initial_packet() != nullptr) { + // Padding coalesced packet containing initial packet to full. + padding_size = coalesced.max_packet_length() - coalesced.length(); + if (framer_->perspective() == Perspective::IS_SERVER && + QuicUtils::ContainsFrameType( + coalesced.initial_packet()->retransmittable_frames, + CONNECTION_CLOSE_FRAME)) { + // Do not pad server initial connection close packet. + padding_size = 0; + } + initial_length = ReserializeInitialPacketInCoalescedPacket( + *coalesced.initial_packet(), padding_size, buffer, buffer_len); + if (initial_length == 0) { + QUIC_BUG(quic_bug_10752_19) + << ENDPOINT + << "Failed to reserialize ENCRYPTION_INITIAL packet in " + "coalesced packet"; + return 0; + } + QUIC_BUG_IF(quic_reserialize_initial_packet_unexpected_size, + coalesced.initial_packet()->encrypted_length + padding_size != + initial_length) + << "Reserialize initial packet in coalescer has unexpected size, " + "original_length: " + << coalesced.initial_packet()->encrypted_length + << ", coalesced.max_packet_length: " << coalesced.max_packet_length() + << ", coalesced.length: " << coalesced.length() + << ", padding_size: " << padding_size + << ", serialized_length: " << initial_length + << ", retransmittable frames: " + << QuicFramesToString( + coalesced.initial_packet()->retransmittable_frames) + << ", nonretransmittable frames: " + << QuicFramesToString( + coalesced.initial_packet()->nonretransmittable_frames); + buffer += initial_length; + buffer_len -= initial_length; + packet_length += initial_length; + } + size_t length_copied = 0; + if (!coalesced.CopyEncryptedBuffers(buffer, buffer_len, &length_copied)) { + QUIC_BUG(quic_serialize_coalesced_packet_copy_failure) + << "SerializeCoalescedPacket failed. buffer_len:" << buffer_len + << ", initial_length:" << initial_length + << ", padding_size: " << padding_size + << ", length_copied:" << length_copied + << ", coalesced.length:" << coalesced.length() + << ", coalesced.max_packet_length:" << coalesced.max_packet_length() + << ", coalesced.packet_lengths:" + << absl::StrJoin(coalesced.packet_lengths(), ":"); + return 0; + } + packet_length += length_copied; + QUIC_DVLOG(1) << ENDPOINT + << "Successfully serialized coalesced packet of length: " + << packet_length; + return packet_length; +} + +// TODO(b/74062209): Make this a public method of framer? +SerializedPacket QuicPacketCreator::NoPacket() { + return SerializedPacket(QuicPacketNumber(), PACKET_1BYTE_PACKET_NUMBER, + nullptr, 0, false, false); +} + +QuicConnectionId QuicPacketCreator::GetDestinationConnectionId() const { + if (framer_->perspective() == Perspective::IS_SERVER) { + return client_connection_id_; + } + return server_connection_id_; +} + +QuicConnectionId QuicPacketCreator::GetSourceConnectionId() const { + if (framer_->perspective() == Perspective::IS_CLIENT) { + return client_connection_id_; + } + return server_connection_id_; +} + +QuicConnectionIdIncluded QuicPacketCreator::GetDestinationConnectionIdIncluded() + const { + // In versions that do not support client connection IDs, the destination + // connection ID is only sent from client to server. + return (framer_->perspective() == Perspective::IS_CLIENT || + framer_->version().SupportsClientConnectionIds()) + ? CONNECTION_ID_PRESENT + : CONNECTION_ID_ABSENT; +} + +QuicConnectionIdIncluded QuicPacketCreator::GetSourceConnectionIdIncluded() + const { + // Long header packets sent by server include source connection ID. + // Ones sent by the client only include source connection ID if the version + // supports client connection IDs. + if (HasIetfLongHeader() && + (framer_->perspective() == Perspective::IS_SERVER || + framer_->version().SupportsClientConnectionIds())) { + return CONNECTION_ID_PRESENT; + } + if (framer_->perspective() == Perspective::IS_SERVER) { + return server_connection_id_included_; + } + return CONNECTION_ID_ABSENT; +} + +uint8_t QuicPacketCreator::GetDestinationConnectionIdLength() const { + QUICHE_DCHECK(QuicUtils::IsConnectionIdValidForVersion(server_connection_id_, + transport_version())) + << ENDPOINT; + return GetDestinationConnectionIdIncluded() == CONNECTION_ID_PRESENT + ? GetDestinationConnectionId().length() + : 0; +} + +uint8_t QuicPacketCreator::GetSourceConnectionIdLength() const { + QUICHE_DCHECK(QuicUtils::IsConnectionIdValidForVersion(server_connection_id_, + transport_version())) + << ENDPOINT; + return GetSourceConnectionIdIncluded() == CONNECTION_ID_PRESENT + ? GetSourceConnectionId().length() + : 0; +} + +QuicPacketNumberLength QuicPacketCreator::GetPacketNumberLength() const { + if (HasIetfLongHeader() && + !framer_->version().SendsVariableLengthPacketNumberInLongHeader()) { + return PACKET_4BYTE_PACKET_NUMBER; + } + return packet_.packet_number_length; +} + +size_t QuicPacketCreator::PacketHeaderSize() const { + return GetPacketHeaderSize( + framer_->transport_version(), GetDestinationConnectionIdLength(), + GetSourceConnectionIdLength(), IncludeVersionInHeader(), + IncludeNonceInPublicHeader(), GetPacketNumberLength(), + GetRetryTokenLengthLength(), GetRetryToken().length(), GetLengthLength()); +} + +quiche::QuicheVariableLengthIntegerLength +QuicPacketCreator::GetRetryTokenLengthLength() const { + if (QuicVersionHasLongHeaderLengths(framer_->transport_version()) && + HasIetfLongHeader() && + EncryptionlevelToLongHeaderType(packet_.encryption_level) == INITIAL) { + return QuicDataWriter::GetVarInt62Len(GetRetryToken().length()); + } + return quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; +} + +absl::string_view QuicPacketCreator::GetRetryToken() const { + if (QuicVersionHasLongHeaderLengths(framer_->transport_version()) && + HasIetfLongHeader() && + EncryptionlevelToLongHeaderType(packet_.encryption_level) == INITIAL) { + return retry_token_; + } + return absl::string_view(); +} + +void QuicPacketCreator::SetRetryToken(absl::string_view retry_token) { + retry_token_ = std::string(retry_token); +} + +bool QuicPacketCreator::ConsumeRetransmittableControlFrame( + const QuicFrame& frame) { + QUIC_BUG_IF(quic_bug_12398_16, IsControlFrame(frame.type) && + !GetControlFrameId(frame) && + frame.type != PING_FRAME) + << ENDPOINT + << "Adding a control frame with no control frame id: " << frame; + QUICHE_DCHECK(QuicUtils::IsRetransmittableFrame(frame.type)) + << ENDPOINT << frame; + MaybeBundleAckOpportunistically(); + if (HasPendingFrames()) { + if (AddFrame(frame, next_transmission_type_)) { + // There is pending frames and current frame fits. + return true; + } + } + QUICHE_DCHECK(!HasPendingFrames()) << ENDPOINT; + if (frame.type != PING_FRAME && frame.type != CONNECTION_CLOSE_FRAME && + !delegate_->ShouldGeneratePacket(HAS_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + // Do not check congestion window for ping or connection close frames. + return false; + } + const bool success = AddFrame(frame, next_transmission_type_); + QUIC_BUG_IF(quic_bug_10752_20, !success) + << ENDPOINT << "Failed to add frame:" << frame + << " transmission_type:" << next_transmission_type_; + return success; +} + +QuicConsumedData QuicPacketCreator::ConsumeData(QuicStreamId id, + size_t write_length, + QuicStreamOffset offset, + StreamSendingState state) { + QUIC_BUG_IF(quic_bug_10752_21, !flusher_attached_) + << ENDPOINT + << "Packet flusher is not attached when " + "generator tries to write stream data."; + bool has_handshake = QuicUtils::IsCryptoStreamId(transport_version(), id); + MaybeBundleAckOpportunistically(); + bool fin = state != NO_FIN; + QUIC_BUG_IF(quic_bug_12398_17, has_handshake && fin) + << ENDPOINT << "Handshake packets should never send a fin"; + // To make reasoning about crypto frames easier, we don't combine them with + // other retransmittable frames in a single packet. + if (has_handshake && HasPendingRetransmittableFrames()) { + FlushCurrentPacket(); + } + + size_t total_bytes_consumed = 0; + bool fin_consumed = false; + + if (!HasRoomForStreamFrame(id, offset, write_length)) { + FlushCurrentPacket(); + } + + if (!fin && (write_length == 0)) { + QUIC_BUG(quic_bug_10752_22) + << ENDPOINT << "Attempt to consume empty data without FIN."; + return QuicConsumedData(0, false); + } + // We determine if we can enter the fast path before executing + // the slow path loop. + bool run_fast_path = + !has_handshake && state != FIN_AND_PADDING && !HasPendingFrames() && + write_length - total_bytes_consumed > kMaxOutgoingPacketSize && + latched_hard_max_packet_length_ == 0; + + while (!run_fast_path && + (has_handshake || delegate_->ShouldGeneratePacket( + HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE))) { + QuicFrame frame; + bool needs_full_padding = + has_handshake && fully_pad_crypto_handshake_packets_; + + if (!ConsumeDataToFillCurrentPacket(id, write_length - total_bytes_consumed, + offset + total_bytes_consumed, fin, + needs_full_padding, + next_transmission_type_, &frame)) { + // The creator is always flushed if there's not enough room for a new + // stream frame before ConsumeData, so ConsumeData should always succeed. + QUIC_BUG(quic_bug_10752_23) + << ENDPOINT << "Failed to ConsumeData, stream:" << id; + return QuicConsumedData(0, false); + } + + // A stream frame is created and added. + size_t bytes_consumed = frame.stream_frame.data_length; + total_bytes_consumed += bytes_consumed; + fin_consumed = fin && total_bytes_consumed == write_length; + if (fin_consumed && state == FIN_AND_PADDING) { + AddRandomPadding(); + } + QUICHE_DCHECK(total_bytes_consumed == write_length || + (bytes_consumed > 0 && HasPendingFrames())) + << ENDPOINT; + + if (total_bytes_consumed == write_length) { + // We're done writing the data. Exit the loop. + // We don't make this a precondition because we could have 0 bytes of data + // if we're simply writing a fin. + break; + } + FlushCurrentPacket(); + + run_fast_path = + !has_handshake && state != FIN_AND_PADDING && !HasPendingFrames() && + write_length - total_bytes_consumed > kMaxOutgoingPacketSize && + latched_hard_max_packet_length_ == 0; + } + + if (run_fast_path) { + return ConsumeDataFastPath(id, write_length, offset, state != NO_FIN, + total_bytes_consumed); + } + + // Don't allow the handshake to be bundled with other retransmittable frames. + if (has_handshake) { + FlushCurrentPacket(); + } + + return QuicConsumedData(total_bytes_consumed, fin_consumed); +} + +QuicConsumedData QuicPacketCreator::ConsumeDataFastPath( + QuicStreamId id, size_t write_length, QuicStreamOffset offset, bool fin, + size_t total_bytes_consumed) { + QUICHE_DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id)) + << ENDPOINT; + if (AttemptingToSendUnencryptedStreamData()) { + return QuicConsumedData(total_bytes_consumed, + fin && (total_bytes_consumed == write_length)); + } + + while (total_bytes_consumed < write_length && + delegate_->ShouldGeneratePacket(HAS_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + // Serialize and encrypt the packet. + size_t bytes_consumed = 0; + CreateAndSerializeStreamFrame(id, write_length, total_bytes_consumed, + offset + total_bytes_consumed, fin, + next_transmission_type_, &bytes_consumed); + if (bytes_consumed == 0) { + const std::string error_details = + "Failed in CreateAndSerializeStreamFrame."; + QUIC_BUG(quic_bug_10752_24) << ENDPOINT << error_details; + delegate_->OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET, + error_details); + break; + } + total_bytes_consumed += bytes_consumed; + } + + return QuicConsumedData(total_bytes_consumed, + fin && (total_bytes_consumed == write_length)); +} + +size_t QuicPacketCreator::ConsumeCryptoData(EncryptionLevel level, + size_t write_length, + QuicStreamOffset offset) { + QUIC_DVLOG(2) << ENDPOINT << "ConsumeCryptoData " << level << " write_length " + << write_length << " offset " << offset; + QUIC_BUG_IF(quic_bug_10752_25, !flusher_attached_) + << ENDPOINT + << "Packet flusher is not attached when " + "generator tries to write crypto data."; + MaybeBundleAckOpportunistically(); + // To make reasoning about crypto frames easier, we don't combine them with + // other retransmittable frames in a single packet. + // TODO(nharper): Once we have separate packet number spaces, everything + // should be driven by encryption level, and we should stop flushing in this + // spot. + if (HasPendingRetransmittableFrames()) { + FlushCurrentPacket(); + } + + size_t total_bytes_consumed = 0; + + while ( + total_bytes_consumed < write_length && + delegate_->ShouldGeneratePacket(HAS_RETRANSMITTABLE_DATA, IS_HANDSHAKE)) { + QuicFrame frame; + if (!ConsumeCryptoDataToFillCurrentPacket( + level, write_length - total_bytes_consumed, + offset + total_bytes_consumed, fully_pad_crypto_handshake_packets_, + next_transmission_type_, &frame)) { + // The only pending data in the packet is non-retransmittable frames. + // I'm assuming here that they won't occupy so much of the packet that a + // CRYPTO frame won't fit. + QUIC_BUG_IF(quic_bug_10752_26, !HasSoftMaxPacketLength()) << absl::StrCat( + ENDPOINT, "Failed to ConsumeCryptoData at level ", level, + ", pending_frames: ", GetPendingFramesInfo(), + ", has_soft_max_packet_length: ", HasSoftMaxPacketLength(), + ", max_packet_length: ", max_packet_length_, ", transmission_type: ", + TransmissionTypeToString(next_transmission_type_), + ", packet_number: ", packet_number().ToString()); + return 0; + } + total_bytes_consumed += frame.crypto_frame->data_length; + FlushCurrentPacket(); + } + + // Don't allow the handshake to be bundled with other retransmittable frames. + FlushCurrentPacket(); + + return total_bytes_consumed; +} + +void QuicPacketCreator::GenerateMtuDiscoveryPacket(QuicByteCount target_mtu) { + // MTU discovery frames must be sent by themselves. + if (!CanSetMaxPacketLength()) { + QUIC_BUG(quic_bug_10752_27) + << ENDPOINT + << "MTU discovery packets should only be sent when no other " + << "frames needs to be sent."; + return; + } + const QuicByteCount current_mtu = max_packet_length(); + + // The MTU discovery frame is allocated on the stack, since it is going to be + // serialized within this function. + QuicMtuDiscoveryFrame mtu_discovery_frame; + QuicFrame frame(mtu_discovery_frame); + + // Send the probe packet with the new length. + SetMaxPacketLength(target_mtu); + const bool success = AddPaddedSavedFrame(frame, next_transmission_type_); + FlushCurrentPacket(); + // The only reason AddFrame can fail is that the packet is too full to fit in + // a ping. This is not possible for any sane MTU. + QUIC_BUG_IF(quic_bug_10752_28, !success) + << ENDPOINT << "Failed to send path MTU target_mtu:" << target_mtu + << " transmission_type:" << next_transmission_type_; + + // Reset the packet length back. + SetMaxPacketLength(current_mtu); +} + +void QuicPacketCreator::MaybeBundleAckOpportunistically() { + if (has_ack()) { + // Ack already queued, nothing to do. + return; + } + if (!delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + return; + } + const bool flushed = + FlushAckFrame(delegate_->MaybeBundleAckOpportunistically()); + QUIC_BUG_IF(quic_bug_10752_29, !flushed) + << ENDPOINT << "Failed to flush ACK frame. encryption_level:" + << packet_.encryption_level; +} + +bool QuicPacketCreator::FlushAckFrame(const QuicFrames& frames) { + QUIC_BUG_IF(quic_bug_10752_30, !flusher_attached_) + << ENDPOINT + << "Packet flusher is not attached when " + "generator tries to send ACK frame."; + // MaybeBundleAckOpportunistically could be called nestedly when sending a + // control frame causing another control frame to be sent. + QUIC_BUG_IF(quic_bug_12398_18, !frames.empty() && has_ack()) + << ENDPOINT << "Trying to flush " << quiche::PrintElements(frames) + << " when there is ACK queued"; + for (const auto& frame : frames) { + QUICHE_DCHECK(frame.type == ACK_FRAME || frame.type == STOP_WAITING_FRAME) + << ENDPOINT; + if (HasPendingFrames()) { + if (AddFrame(frame, next_transmission_type_)) { + // There is pending frames and current frame fits. + continue; + } + } + QUICHE_DCHECK(!HasPendingFrames()) << ENDPOINT; + // There is no pending frames, consult the delegate whether a packet can be + // generated. + if (!delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + return false; + } + const bool success = AddFrame(frame, next_transmission_type_); + QUIC_BUG_IF(quic_bug_10752_31, !success) + << ENDPOINT << "Failed to flush " << frame; + } + return true; +} + +void QuicPacketCreator::AddRandomPadding() { + AddPendingPadding(random_->RandUint64() % kMaxNumRandomPaddingBytes + 1); +} + +void QuicPacketCreator::AttachPacketFlusher() { + flusher_attached_ = true; + if (!write_start_packet_number_.IsInitialized()) { + write_start_packet_number_ = NextSendingPacketNumber(); + } +} + +void QuicPacketCreator::Flush() { + FlushCurrentPacket(); + SendRemainingPendingPadding(); + flusher_attached_ = false; + if (GetQuicFlag(quic_export_write_path_stats_at_server)) { + if (!write_start_packet_number_.IsInitialized()) { + QUIC_BUG(quic_bug_10752_32) + << ENDPOINT << "write_start_packet_number is not initialized"; + return; + } + QUIC_SERVER_HISTOGRAM_COUNTS( + "quic_server_num_written_packets_per_write", + NextSendingPacketNumber() - write_start_packet_number_, 1, 200, 50, + "Number of QUIC packets written per write operation"); + } + write_start_packet_number_.Clear(); +} + +void QuicPacketCreator::SendRemainingPendingPadding() { + while ( + pending_padding_bytes() > 0 && !HasPendingFrames() && + delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, NOT_HANDSHAKE)) { + FlushCurrentPacket(); + } +} + +void QuicPacketCreator::SetServerConnectionIdLength(uint32_t length) { + if (length == 0) { + SetServerConnectionIdIncluded(CONNECTION_ID_ABSENT); + } else { + SetServerConnectionIdIncluded(CONNECTION_ID_PRESENT); + } +} + +void QuicPacketCreator::SetTransmissionType(TransmissionType type) { + next_transmission_type_ = type; +} + +MessageStatus QuicPacketCreator::AddMessageFrame( + QuicMessageId message_id, absl::Span message) { + QUIC_BUG_IF(quic_bug_10752_33, !flusher_attached_) + << ENDPOINT + << "Packet flusher is not attached when " + "generator tries to add message frame."; + MaybeBundleAckOpportunistically(); + const QuicByteCount message_length = MemSliceSpanTotalSize(message); + if (message_length > GetCurrentLargestMessagePayload()) { + return MESSAGE_STATUS_TOO_LARGE; + } + if (!HasRoomForMessageFrame(message_length)) { + FlushCurrentPacket(); + } + QuicMessageFrame* frame = new QuicMessageFrame(message_id, message); + const bool success = AddFrame(QuicFrame(frame), next_transmission_type_); + if (!success) { + QUIC_BUG(quic_bug_10752_34) + << ENDPOINT << "Failed to send message " << message_id; + delete frame; + return MESSAGE_STATUS_INTERNAL_ERROR; + } + QUICHE_DCHECK_EQ(MemSliceSpanTotalSize(message), + 0u); // Ensure the old slices are empty. + return MESSAGE_STATUS_SUCCESS; +} + +quiche::QuicheVariableLengthIntegerLength QuicPacketCreator::GetLengthLength() + const { + if (QuicVersionHasLongHeaderLengths(framer_->transport_version()) && + HasIetfLongHeader()) { + QuicLongHeaderType long_header_type = + EncryptionlevelToLongHeaderType(packet_.encryption_level); + if (long_header_type == INITIAL || long_header_type == ZERO_RTT_PROTECTED || + long_header_type == HANDSHAKE) { + return quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + } + return quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; +} + +void QuicPacketCreator::FillPacketHeader(QuicPacketHeader* header) { + header->destination_connection_id = GetDestinationConnectionId(); + header->destination_connection_id_included = + GetDestinationConnectionIdIncluded(); + header->source_connection_id = GetSourceConnectionId(); + header->source_connection_id_included = GetSourceConnectionIdIncluded(); + header->reset_flag = false; + header->version_flag = IncludeVersionInHeader(); + if (IncludeNonceInPublicHeader()) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, framer_->perspective()) + << ENDPOINT; + header->nonce = &diversification_nonce_; + } else { + header->nonce = nullptr; + } + packet_.packet_number = NextSendingPacketNumber(); + header->packet_number = packet_.packet_number; + header->packet_number_length = GetPacketNumberLength(); + header->retry_token_length_length = GetRetryTokenLengthLength(); + header->retry_token = GetRetryToken(); + header->length_length = GetLengthLength(); + header->remaining_packet_length = 0; + if (!HasIetfLongHeader()) { + return; + } + header->long_packet_type = + EncryptionlevelToLongHeaderType(packet_.encryption_level); +} + +size_t QuicPacketCreator::GetSerializedFrameLength(const QuicFrame& frame) { + size_t serialized_frame_length = framer_->GetSerializedFrameLength( + frame, BytesFree(), queued_frames_.empty(), + /* last_frame_in_packet= */ true, GetPacketNumberLength()); + if (!framer_->version().HasHeaderProtection() || + serialized_frame_length == 0) { + return serialized_frame_length; + } + // Calculate frame bytes and bytes free with this frame added. + const size_t frame_bytes = PacketSize() - PacketHeaderSize() + + ExpansionOnNewFrame() + serialized_frame_length; + if (frame_bytes >= + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength())) { + // No extra bytes is needed. + return serialized_frame_length; + } + if (BytesFree() < serialized_frame_length) { + QUIC_BUG(quic_bug_10752_35) << ENDPOINT << "Frame does not fit: " << frame; + return 0; + } + // Please note bytes_free does not take |frame|'s expansion into account. + size_t bytes_free = BytesFree() - serialized_frame_length; + // Extra bytes needed (this is NOT padding needed) should be at least 1 + // padding + expansion. + const size_t extra_bytes_needed = std::max( + 1 + ExpansionOnNewFrameWithLastFrame(frame, framer_->transport_version()), + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength()) - + frame_bytes); + if (bytes_free < extra_bytes_needed) { + // This frame does not fit. + return 0; + } + return serialized_frame_length; +} + +bool QuicPacketCreator::AddFrame(const QuicFrame& frame, + TransmissionType transmission_type) { + QUIC_DVLOG(1) << ENDPOINT << "Adding frame with transmission type " + << transmission_type << ": " << frame; + if (frame.type == STREAM_FRAME && + !QuicUtils::IsCryptoStreamId(framer_->transport_version(), + frame.stream_frame.stream_id) && + AttemptingToSendUnencryptedStreamData()) { + return false; + } + + // Sanity check to ensure we don't send frames at the wrong encryption level. + QUICHE_DCHECK( + packet_.encryption_level == ENCRYPTION_ZERO_RTT || + packet_.encryption_level == ENCRYPTION_FORWARD_SECURE || + (frame.type != GOAWAY_FRAME && frame.type != WINDOW_UPDATE_FRAME && + frame.type != HANDSHAKE_DONE_FRAME && + frame.type != NEW_CONNECTION_ID_FRAME && + frame.type != MAX_STREAMS_FRAME && frame.type != STREAMS_BLOCKED_FRAME && + frame.type != PATH_RESPONSE_FRAME && + frame.type != PATH_CHALLENGE_FRAME && frame.type != STOP_SENDING_FRAME && + frame.type != MESSAGE_FRAME && frame.type != NEW_TOKEN_FRAME && + frame.type != RETIRE_CONNECTION_ID_FRAME && + frame.type != ACK_FREQUENCY_FRAME)) + << ENDPOINT << frame.type << " not allowed at " + << packet_.encryption_level; + + if (frame.type == STREAM_FRAME) { + if (MaybeCoalesceStreamFrame(frame.stream_frame)) { + LogCoalesceStreamFrameStatus(true); + return true; + } else { + LogCoalesceStreamFrameStatus(false); + } + } + + // If this is an ACK frame, validate that it is non-empty and that + // largest_acked matches the max packet number. + QUICHE_DCHECK(frame.type != ACK_FRAME || (!frame.ack_frame->packets.Empty() && + frame.ack_frame->packets.Max() == + frame.ack_frame->largest_acked)) + << ENDPOINT << "Invalid ACK frame: " << frame; + + size_t frame_len = GetSerializedFrameLength(frame); + if (frame_len == 0 && RemoveSoftMaxPacketLength()) { + // Remove soft max_packet_length and retry. + frame_len = GetSerializedFrameLength(frame); + } + if (frame_len == 0) { + QUIC_DVLOG(1) << ENDPOINT + << "Flushing because current open packet is full when adding " + << frame; + FlushCurrentPacket(); + return false; + } + if (queued_frames_.empty()) { + packet_size_ = PacketHeaderSize(); + } + QUICHE_DCHECK_LT(0u, packet_size_) << ENDPOINT; + + packet_size_ += ExpansionOnNewFrame() + frame_len; + + if (QuicUtils::IsRetransmittableFrame(frame.type)) { + packet_.retransmittable_frames.push_back(frame); + queued_frames_.push_back(frame); + if (QuicUtils::IsHandshakeFrame(frame, framer_->transport_version())) { + packet_.has_crypto_handshake = IS_HANDSHAKE; + } + } else { + if (frame.type == PADDING_FRAME && + frame.padding_frame.num_padding_bytes == -1) { + // Populate the actual length of full padding frame, such that one can + // know how much padding is actually added. + packet_.nonretransmittable_frames.push_back( + QuicFrame(QuicPaddingFrame(frame_len))); + } else { + packet_.nonretransmittable_frames.push_back(frame); + } + queued_frames_.push_back(frame); + } + + if (frame.type == ACK_FRAME) { + packet_.has_ack = true; + packet_.largest_acked = LargestAcked(*frame.ack_frame); + if (frame.ack_frame->ecn_counters.has_value()) { + packet_.has_ack_ecn = true; + } + } else if (frame.type == STOP_WAITING_FRAME) { + packet_.has_stop_waiting = true; + } else if (frame.type == ACK_FREQUENCY_FRAME) { + packet_.has_ack_frequency = true; + } else if (frame.type == MESSAGE_FRAME) { + packet_.has_message = true; + } + if (debug_delegate_ != nullptr) { + debug_delegate_->OnFrameAddedToPacket(frame); + } + + if (transmission_type == NOT_RETRANSMISSION) { + packet_.bytes_not_retransmitted.emplace( + packet_.bytes_not_retransmitted.value_or(0) + frame_len); + } else if (QuicUtils::IsRetransmittableFrame(frame.type)) { + // Packet transmission type is determined by the last added retransmittable + // frame of a retransmission type. If a packet has no retransmittable + // retransmission frames, it has type NOT_RETRANSMISSION. + packet_.transmission_type = transmission_type; + } + return true; +} + +void QuicPacketCreator::MaybeAddExtraPaddingForHeaderProtection() { + if (!framer_->version().HasHeaderProtection() || needs_full_padding_) { + return; + } + const size_t frame_bytes = PacketSize() - PacketHeaderSize(); + if (frame_bytes >= + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength())) { + return; + } + QuicByteCount min_header_protection_padding = + MinPlaintextPacketSize(framer_->version(), GetPacketNumberLength()) - + frame_bytes; + // Update pending_padding_bytes_. + pending_padding_bytes_ = + std::max(pending_padding_bytes_, min_header_protection_padding); +} + +bool QuicPacketCreator::MaybeCoalesceStreamFrame(const QuicStreamFrame& frame) { + if (queued_frames_.empty() || queued_frames_.back().type != STREAM_FRAME) { + return false; + } + QuicStreamFrame* candidate = &queued_frames_.back().stream_frame; + if (candidate->stream_id != frame.stream_id || + candidate->offset + candidate->data_length != frame.offset || + frame.data_length > BytesFree()) { + return false; + } + candidate->data_length += frame.data_length; + candidate->fin = frame.fin; + + // The back of retransmittable frames must be the same as the original + // queued frames' back. + QUICHE_DCHECK_EQ(packet_.retransmittable_frames.back().type, STREAM_FRAME) + << ENDPOINT; + QuicStreamFrame* retransmittable = + &packet_.retransmittable_frames.back().stream_frame; + QUICHE_DCHECK_EQ(retransmittable->stream_id, frame.stream_id) << ENDPOINT; + QUICHE_DCHECK_EQ(retransmittable->offset + retransmittable->data_length, + frame.offset) + << ENDPOINT; + retransmittable->data_length = candidate->data_length; + retransmittable->fin = candidate->fin; + packet_size_ += frame.data_length; + if (debug_delegate_ != nullptr) { + debug_delegate_->OnStreamFrameCoalesced(*candidate); + } + return true; +} + +bool QuicPacketCreator::RemoveSoftMaxPacketLength() { + if (latched_hard_max_packet_length_ == 0) { + return false; + } + if (!CanSetMaxPacketLength()) { + return false; + } + QUIC_DVLOG(1) << ENDPOINT << "Restoring max packet length to: " + << latched_hard_max_packet_length_; + SetMaxPacketLength(latched_hard_max_packet_length_); + // Reset latched_max_packet_length_. + latched_hard_max_packet_length_ = 0; + return true; +} + +void QuicPacketCreator::MaybeAddPadding() { + // The current packet should have no padding bytes because padding is only + // added when this method is called just before the packet is serialized. + if (BytesFreeForPadding() == 0) { + // Don't pad full packets. + return; + } + + if (packet_.fate == COALESCE) { + // Do not add full padding if the packet is going to be coalesced. + needs_full_padding_ = false; + } + + // Header protection requires a minimum plaintext packet size. + MaybeAddExtraPaddingForHeaderProtection(); + + QUIC_DVLOG(3) << "MaybeAddPadding for " << packet_.packet_number + << ": transmission_type:" << packet_.transmission_type + << ", fate:" << packet_.fate + << ", needs_full_padding_:" << needs_full_padding_ + << ", pending_padding_bytes_:" << pending_padding_bytes_ + << ", BytesFree:" << BytesFree(); + + if (!needs_full_padding_ && pending_padding_bytes_ == 0) { + // Do not need padding. + return; + } + + int padding_bytes = -1; + if (!needs_full_padding_) { + padding_bytes = + std::min(pending_padding_bytes_, BytesFreeForPadding()); + pending_padding_bytes_ -= padding_bytes; + } + + if (!queued_frames_.empty()) { + // Insert PADDING before the other frames to avoid adding a length field + // to any trailing STREAM frame. + if (needs_full_padding_) { + padding_bytes = BytesFreeForPadding(); + } + // AddFrame cannot be used here because it adds the frame to the end of the + // packet. + QuicFrame frame{QuicPaddingFrame(padding_bytes)}; + queued_frames_.insert(queued_frames_.begin(), frame); + packet_size_ += padding_bytes; + packet_.nonretransmittable_frames.push_back(frame); + if (packet_.transmission_type == NOT_RETRANSMISSION) { + packet_.bytes_not_retransmitted.emplace( + packet_.bytes_not_retransmitted.value_or(0) + padding_bytes); + } + } else { + bool success = AddFrame(QuicFrame(QuicPaddingFrame(padding_bytes)), + packet_.transmission_type); + QUIC_BUG_IF(quic_bug_10752_36, !success) + << ENDPOINT << "Failed to add padding_bytes: " << padding_bytes + << " transmission_type: " << packet_.transmission_type; + } +} + +bool QuicPacketCreator::IncludeNonceInPublicHeader() const { + return have_diversification_nonce_ && + packet_.encryption_level == ENCRYPTION_ZERO_RTT; +} + +bool QuicPacketCreator::IncludeVersionInHeader() const { + if (version().HasIetfInvariantHeader()) { + return packet_.encryption_level < ENCRYPTION_FORWARD_SECURE; + } + return send_version_in_packet_; +} + +void QuicPacketCreator::AddPendingPadding(QuicByteCount size) { + pending_padding_bytes_ += size; + QUIC_DVLOG(3) << "After AddPendingPadding(" << size + << "), pending_padding_bytes_:" << pending_padding_bytes_; +} + +bool QuicPacketCreator::StreamFrameIsClientHello( + const QuicStreamFrame& frame) const { + if (framer_->perspective() == Perspective::IS_SERVER || + !QuicUtils::IsCryptoStreamId(framer_->transport_version(), + frame.stream_id)) { + return false; + } + // The ClientHello is always sent with INITIAL encryption. + return packet_.encryption_level == ENCRYPTION_INITIAL; +} + +void QuicPacketCreator::SetServerConnectionIdIncluded( + QuicConnectionIdIncluded server_connection_id_included) { + QUICHE_DCHECK(server_connection_id_included == CONNECTION_ID_PRESENT || + server_connection_id_included == CONNECTION_ID_ABSENT) + << ENDPOINT; + QUICHE_DCHECK(framer_->perspective() == Perspective::IS_SERVER || + server_connection_id_included != CONNECTION_ID_ABSENT) + << ENDPOINT; + server_connection_id_included_ = server_connection_id_included; +} + +void QuicPacketCreator::SetServerConnectionId( + QuicConnectionId server_connection_id) { + server_connection_id_ = server_connection_id; +} + +void QuicPacketCreator::SetClientConnectionId( + QuicConnectionId client_connection_id) { + QUICHE_DCHECK(client_connection_id.IsEmpty() || + framer_->version().SupportsClientConnectionIds()) + << ENDPOINT; + client_connection_id_ = client_connection_id; +} + +QuicPacketLength QuicPacketCreator::GetCurrentLargestMessagePayload() const { + if (!VersionSupportsMessageFrames(framer_->transport_version())) { + return 0; + } + const size_t packet_header_size = GetPacketHeaderSize( + framer_->transport_version(), GetDestinationConnectionIdLength(), + GetSourceConnectionIdLength(), IncludeVersionInHeader(), + IncludeNonceInPublicHeader(), GetPacketNumberLength(), + // No Retry token on packets containing application data. + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, GetLengthLength()); + // This is the largest possible message payload when the length field is + // omitted. + size_t max_plaintext_size = + latched_hard_max_packet_length_ == 0 + ? max_plaintext_size_ + : framer_->GetMaxPlaintextSize(latched_hard_max_packet_length_); + size_t largest_frame = + max_plaintext_size - std::min(max_plaintext_size, packet_header_size); + if (static_cast(largest_frame) > max_datagram_frame_size_) { + largest_frame = static_cast(max_datagram_frame_size_); + } + return largest_frame - std::min(largest_frame, kQuicFrameTypeSize); +} + +QuicPacketLength QuicPacketCreator::GetGuaranteedLargestMessagePayload() const { + if (!VersionSupportsMessageFrames(framer_->transport_version())) { + return 0; + } + // QUIC Crypto server packets may include a diversification nonce. + const bool may_include_nonce = + framer_->version().handshake_protocol == PROTOCOL_QUIC_CRYPTO && + framer_->perspective() == Perspective::IS_SERVER; + // IETF QUIC long headers include a length on client 0RTT packets. + quiche::QuicheVariableLengthIntegerLength length_length = + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + if (framer_->perspective() == Perspective::IS_CLIENT) { + length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + if (!QuicVersionHasLongHeaderLengths(framer_->transport_version())) { + length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0; + } + const size_t packet_header_size = GetPacketHeaderSize( + framer_->transport_version(), GetDestinationConnectionIdLength(), + // Assume CID lengths don't change, but version may be present. + GetSourceConnectionIdLength(), kIncludeVersion, may_include_nonce, + PACKET_4BYTE_PACKET_NUMBER, + // No Retry token on packets containing application data. + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0, length_length); + // This is the largest possible message payload when the length field is + // omitted. + size_t max_plaintext_size = + latched_hard_max_packet_length_ == 0 + ? max_plaintext_size_ + : framer_->GetMaxPlaintextSize(latched_hard_max_packet_length_); + size_t largest_frame = + max_plaintext_size - std::min(max_plaintext_size, packet_header_size); + if (static_cast(largest_frame) > max_datagram_frame_size_) { + largest_frame = static_cast(max_datagram_frame_size_); + } + const QuicPacketLength largest_payload = + largest_frame - std::min(largest_frame, kQuicFrameTypeSize); + // This must always be less than or equal to GetCurrentLargestMessagePayload. + QUICHE_DCHECK_LE(largest_payload, GetCurrentLargestMessagePayload()) + << ENDPOINT; + return largest_payload; +} + +bool QuicPacketCreator::AttemptingToSendUnencryptedStreamData() { + if (packet_.encryption_level == ENCRYPTION_ZERO_RTT || + packet_.encryption_level == ENCRYPTION_FORWARD_SECURE) { + return false; + } + const std::string error_details = + absl::StrCat("Cannot send stream data with level: ", + EncryptionLevelToString(packet_.encryption_level)); + QUIC_BUG(quic_bug_10752_37) << ENDPOINT << error_details; + delegate_->OnUnrecoverableError(QUIC_ATTEMPT_TO_SEND_UNENCRYPTED_STREAM_DATA, + error_details); + return true; +} + +bool QuicPacketCreator::HasIetfLongHeader() const { + return version().HasIetfInvariantHeader() && + packet_.encryption_level < ENCRYPTION_FORWARD_SECURE; +} + +// static +size_t QuicPacketCreator::MinPlaintextPacketSize( + const ParsedQuicVersion& version, + QuicPacketNumberLength packet_number_length) { + if (!version.HasHeaderProtection()) { + return 0; + } + // Header protection samples 16 bytes of ciphertext starting 4 bytes after the + // packet number. In IETF QUIC, all AEAD algorithms have a 16-byte auth tag + // (i.e. the ciphertext is 16 bytes larger than the plaintext). Since packet + // numbers could be as small as 1 byte, but the sample starts 4 bytes after + // the packet number, at least 3 bytes of plaintext are needed to make sure + // that there is enough ciphertext to sample. + // + // Google QUIC crypto uses different AEAD algorithms - in particular the auth + // tags are only 12 bytes instead of 16 bytes. Since the auth tag is 4 bytes + // shorter, 4 more bytes of plaintext are needed to guarantee there is enough + // ciphertext to sample. + // + // This method could check for PROTOCOL_TLS1_3 vs PROTOCOL_QUIC_CRYPTO and + // return 3 when TLS 1.3 is in use (the use of IETF vs Google QUIC crypters is + // determined based on the handshake protocol used). However, even when TLS + // 1.3 is used, unittests still use NullEncrypter/NullDecrypter (and other + // test crypters) which also only use 12 byte tags. + // + return (version.UsesTls() ? 4 : 8) - packet_number_length; +} + +QuicPacketNumber QuicPacketCreator::NextSendingPacketNumber() const { + if (!packet_number().IsInitialized()) { + return framer_->first_sending_packet_number(); + } + return packet_number() + 1; +} + +bool QuicPacketCreator::PacketFlusherAttached() const { + return flusher_attached_; +} + +bool QuicPacketCreator::HasSoftMaxPacketLength() const { + return latched_hard_max_packet_length_ != 0; +} + +void QuicPacketCreator::SetDefaultPeerAddress(QuicSocketAddress address) { + if (!packet_.peer_address.IsInitialized()) { + packet_.peer_address = address; + return; + } + if (packet_.peer_address != address) { + FlushCurrentPacket(); + packet_.peer_address = address; + } +} + +#define ENDPOINT2 \ + (creator_->framer_->perspective() == Perspective::IS_SERVER ? "Server: " \ + : "Client: ") + +QuicPacketCreator::ScopedPeerAddressContext::ScopedPeerAddressContext( + QuicPacketCreator* creator, QuicSocketAddress address, + bool update_connection_id) + : ScopedPeerAddressContext(creator, address, EmptyQuicConnectionId(), + EmptyQuicConnectionId(), update_connection_id) {} + +QuicPacketCreator::ScopedPeerAddressContext::ScopedPeerAddressContext( + QuicPacketCreator* creator, QuicSocketAddress address, + const QuicConnectionId& client_connection_id, + const QuicConnectionId& server_connection_id, bool update_connection_id) + : creator_(creator), + old_peer_address_(creator_->packet_.peer_address), + old_client_connection_id_(creator_->GetClientConnectionId()), + old_server_connection_id_(creator_->GetServerConnectionId()), + update_connection_id_(update_connection_id) { + QUIC_BUG_IF(quic_bug_12398_19, !old_peer_address_.IsInitialized()) + << ENDPOINT2 + << "Context is used before serialized packet's peer address is " + "initialized."; + creator_->SetDefaultPeerAddress(address); + if (update_connection_id_) { + // Flush current packet if connection ID length changes. + if (address == old_peer_address_ && + ((client_connection_id.length() != + old_client_connection_id_.length()) || + (server_connection_id.length() != + old_server_connection_id_.length()))) { + creator_->FlushCurrentPacket(); + } + creator_->SetClientConnectionId(client_connection_id); + creator_->SetServerConnectionId(server_connection_id); + } +} + +QuicPacketCreator::ScopedPeerAddressContext::~ScopedPeerAddressContext() { + creator_->SetDefaultPeerAddress(old_peer_address_); + if (update_connection_id_) { + creator_->SetClientConnectionId(old_client_connection_id_); + creator_->SetServerConnectionId(old_server_connection_id_); + } +} + +QuicPacketCreator::ScopedSerializationFailureHandler:: + ScopedSerializationFailureHandler(QuicPacketCreator* creator) + : creator_(creator) {} + +QuicPacketCreator::ScopedSerializationFailureHandler:: + ~ScopedSerializationFailureHandler() { + if (creator_ == nullptr) { + return; + } + // Always clear queued_frames_. + creator_->queued_frames_.clear(); + + if (creator_->packet_.encrypted_buffer == nullptr) { + const std::string error_details = "Failed to SerializePacket."; + QUIC_BUG(quic_bug_10752_38) << ENDPOINT2 << error_details; + creator_->delegate_->OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET, + error_details); + } +} + +#undef ENDPOINT2 + +void QuicPacketCreator::set_encryption_level(EncryptionLevel level) { + QUICHE_DCHECK(level == packet_.encryption_level || !HasPendingFrames()) + << ENDPOINT << "Cannot update encryption level from " + << packet_.encryption_level << " to " << level + << " when we already have pending frames: " + << QuicFramesToString(queued_frames_); + packet_.encryption_level = level; +} + +void QuicPacketCreator::AddPathChallengeFrame( + const QuicPathFrameBuffer& payload) { + // TODO(danzh) Unify similar checks at several entry points into one in + // AddFrame(). Sort out test helper functions and peer class that don't + // enforce this check. + QUIC_BUG_IF(quic_bug_10752_39, !flusher_attached_) + << ENDPOINT + << "Packet flusher is not attached when " + "generator tries to write stream data."; + // Write a PATH_CHALLENGE frame, which has a random 8-byte payload. + QuicFrame frame(QuicPathChallengeFrame(0, payload)); + if (AddPaddedFrameWithRetry(frame)) { + return; + } + // Fail silently if the probing packet cannot be written, path validation + // initiator will retry sending automatically. + // TODO(danzh) This will consume retry budget, if it causes performance + // regression, consider to notify the caller about the sending failure and let + // the caller to decide if it worth retrying. + QUIC_DVLOG(1) << ENDPOINT << "Can't send PATH_CHALLENGE now"; +} + +bool QuicPacketCreator::AddPathResponseFrame( + const QuicPathFrameBuffer& data_buffer) { + QuicFrame frame(QuicPathResponseFrame(kInvalidControlFrameId, data_buffer)); + if (AddPaddedFrameWithRetry(frame)) { + return true; + } + + QUIC_DVLOG(1) << ENDPOINT << "Can't send PATH_RESPONSE now"; + return false; +} + +bool QuicPacketCreator::AddPaddedFrameWithRetry(const QuicFrame& frame) { + if (HasPendingFrames()) { + if (AddPaddedSavedFrame(frame, NOT_RETRANSMISSION)) { + // Frame is queued. + return true; + } + } + // Frame was not queued but queued frames were flushed. + QUICHE_DCHECK(!HasPendingFrames()) << ENDPOINT; + if (!delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + return false; + } + bool success = AddPaddedSavedFrame(frame, NOT_RETRANSMISSION); + QUIC_BUG_IF(quic_bug_12398_20, !success) << ENDPOINT; + return true; +} + +bool QuicPacketCreator::HasRetryToken() const { return !retry_token_.empty(); } + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_packet_creator.h b/quiche/quic/core/quic_packet_creator.h new file mode 100644 index 000000000000..785efb143910 --- /dev/null +++ b/quiche/quic/core/quic_packet_creator.h @@ -0,0 +1,693 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Responsible for creating packets on behalf of a QuicConnection. +// Packets are serialized just-in-time. Stream data and control frames will be +// requested from the Connection just-in-time. Frames are accumulated into +// "current" packet until no more frames can fit, then current packet gets +// serialized and passed to connection via OnSerializedPacket(). +// +// Whether a packet should be serialized is determined by whether delegate is +// writable. If the Delegate is not writable, then no operations will cause +// a packet to be serialized. + +#ifndef QUICHE_QUIC_CORE_QUIC_PACKET_CREATOR_H_ +#define QUICHE_QUIC_CORE_QUIC_PACKET_CREATOR_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/quic_coalesced_packet.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { +namespace test { +class QuicPacketCreatorPeer; +} + +class QUIC_EXPORT_PRIVATE QuicPacketCreator { + public: + // A delegate interface for further processing serialized packet. + class QUIC_EXPORT_PRIVATE DelegateInterface { + public: + virtual ~DelegateInterface() {} + // Get a buffer of kMaxOutgoingPacketSize bytes to serialize the next + // packet. If the return value's buffer is nullptr, QuicPacketCreator will + // serialize on a stack buffer. + virtual QuicPacketBuffer GetPacketBuffer() = 0; + // Called when a packet is serialized. Delegate take the ownership of + // |serialized_packet|. + virtual void OnSerializedPacket(SerializedPacket serialized_packet) = 0; + + // Called when an unrecoverable error is encountered. + virtual void OnUnrecoverableError(QuicErrorCode error, + const std::string& error_details) = 0; + + // Consults delegate whether a packet should be generated. + virtual bool ShouldGeneratePacket(HasRetransmittableData retransmittable, + IsHandshake handshake) = 0; + // Called when there is data to be sent. Retrieves updated ACK frame from + // the delegate. + virtual const QuicFrames MaybeBundleAckOpportunistically() = 0; + + // Returns the packet fate for serialized packets which will be handed over + // to delegate via OnSerializedPacket(). Called when a packet is about to be + // serialized. + virtual SerializedPacketFate GetSerializedPacketFate( + bool is_mtu_discovery, EncryptionLevel encryption_level) = 0; + }; + + // Interface which gets callbacks from the QuicPacketCreator at interesting + // points. Implementations must not mutate the state of the creator + // as a result of these callbacks. + class QUIC_EXPORT_PRIVATE DebugDelegate { + public: + virtual ~DebugDelegate() {} + + // Called when a frame has been added to the current packet. + virtual void OnFrameAddedToPacket(const QuicFrame& /*frame*/) {} + + // Called when a stream frame is coalesced with an existing stream frame. + // |frame| is the new stream frame. + virtual void OnStreamFrameCoalesced(const QuicStreamFrame& /*frame*/) {} + }; + + // Set the peer address and connection IDs with which the serialized packet + // will be sent to during the scope of this object. Upon exiting the scope, + // the original peer address and connection IDs are restored. + class QUIC_EXPORT_PRIVATE ScopedPeerAddressContext { + public: + ScopedPeerAddressContext(QuicPacketCreator* creator, + QuicSocketAddress address, + bool update_connection_id); + + ScopedPeerAddressContext(QuicPacketCreator* creator, + QuicSocketAddress address, + const QuicConnectionId& client_connection_id, + const QuicConnectionId& server_connection_id, + bool update_connection_id); + ~ScopedPeerAddressContext(); + + private: + QuicPacketCreator* creator_; + QuicSocketAddress old_peer_address_; + QuicConnectionId old_client_connection_id_; + QuicConnectionId old_server_connection_id_; + bool update_connection_id_; + }; + + QuicPacketCreator(QuicConnectionId server_connection_id, QuicFramer* framer, + DelegateInterface* delegate); + QuicPacketCreator(QuicConnectionId server_connection_id, QuicFramer* framer, + QuicRandom* random, DelegateInterface* delegate); + QuicPacketCreator(const QuicPacketCreator&) = delete; + QuicPacketCreator& operator=(const QuicPacketCreator&) = delete; + + ~QuicPacketCreator(); + + // Makes the framer not serialize the protocol version in sent packets. + void StopSendingVersion(); + + // SetDiversificationNonce sets the nonce that will be sent in each public + // header of packets encrypted at the initial encryption level. Should only + // be called by servers. + void SetDiversificationNonce(const DiversificationNonce& nonce); + + // Update the packet number length to use in future packets as soon as it + // can be safely changed. + // TODO(fayang): Directly set packet number length instead of compute it in + // creator. + void UpdatePacketNumberLength(QuicPacketNumber least_packet_awaited_by_peer, + QuicPacketCount max_packets_in_flight); + + // Skip |count| packet numbers. + void SkipNPacketNumbers(QuicPacketCount count, + QuicPacketNumber least_packet_awaited_by_peer, + QuicPacketCount max_packets_in_flight); + + // The overhead the framing will add for a packet with one frame. + static size_t StreamFramePacketOverhead( + QuicTransportVersion version, uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool include_version, + bool include_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + quiche::QuicheVariableLengthIntegerLength length_length, + QuicStreamOffset offset); + + // Returns false and flushes all pending frames if current open packet is + // full. + // If current packet is not full, creates a stream frame that fits into the + // open packet and adds it to the packet. + bool ConsumeDataToFillCurrentPacket(QuicStreamId id, size_t data_size, + QuicStreamOffset offset, bool fin, + bool needs_full_padding, + TransmissionType transmission_type, + QuicFrame* frame); + + // Creates a CRYPTO frame that fits into the current packet (which must be + // empty) and adds it to the packet. + bool ConsumeCryptoDataToFillCurrentPacket(EncryptionLevel level, + size_t write_length, + QuicStreamOffset offset, + bool needs_full_padding, + TransmissionType transmission_type, + QuicFrame* frame); + + // Returns true if current open packet can accommodate more stream frames of + // stream |id| at |offset| and data length |data_size|, false otherwise. + // TODO(fayang): mark this const by moving RemoveSoftMaxPacketLength out. + bool HasRoomForStreamFrame(QuicStreamId id, QuicStreamOffset offset, + size_t data_size); + + // Returns true if current open packet can accommodate a message frame of + // |length|. + // TODO(fayang): mark this const by moving RemoveSoftMaxPacketLength out. + bool HasRoomForMessageFrame(QuicByteCount length); + + // Serializes all added frames into a single packet and invokes the delegate_ + // to further process the SerializedPacket. + void FlushCurrentPacket(); + + // Optimized method to create a QuicStreamFrame and serialize it. Adds the + // QuicStreamFrame to the returned SerializedPacket. Sets + // |num_bytes_consumed| to the number of bytes consumed to create the + // QuicStreamFrame. + void CreateAndSerializeStreamFrame(QuicStreamId id, size_t write_length, + QuicStreamOffset iov_offset, + QuicStreamOffset stream_offset, bool fin, + TransmissionType transmission_type, + size_t* num_bytes_consumed); + + // Returns true if there are frames pending to be serialized. + bool HasPendingFrames() const; + + // TODO(haoyuewang) Remove this debug utility. + // Returns the information of pending frames as a string. + std::string GetPendingFramesInfo() const; + + // Returns true if there are retransmittable frames pending to be serialized. + bool HasPendingRetransmittableFrames() const; + + // Returns true if there are stream frames for |id| pending to be serialized. + bool HasPendingStreamFramesOfStream(QuicStreamId id) const; + + // Returns the number of bytes which are available to be used by additional + // frames in the packet. Since stream frames are slightly smaller when they + // are the last frame in a packet, this method will return a different + // value than max_packet_size - PacketSize(), in this case. + size_t BytesFree() const; + + // Since PADDING frames are always prepended, a separate function computes + // available space without considering STREAM frame expansion. + size_t BytesFreeForPadding() const; + + // Returns the number of bytes that the packet will expand by if a new frame + // is added to the packet. If the last frame was a stream frame, it will + // expand slightly when a new frame is added, and this method returns the + // amount of expected expansion. + size_t ExpansionOnNewFrame() const; + + // Returns the number of bytes that the packet will expand by when a new frame + // is going to be added. |last_frame| is the last frame of the packet. + static size_t ExpansionOnNewFrameWithLastFrame(const QuicFrame& last_frame, + QuicTransportVersion version); + + // Returns the number of bytes in the current packet, including the header, + // if serialized with the current frames. Adding a frame to the packet + // may change the serialized length of existing frames, as per the comment + // in BytesFree. + size_t PacketSize() const; + + // Tries to add |frame| to the packet creator's list of frames to be + // serialized. If the frame does not fit into the current packet, flushes the + // packet and returns false. + bool AddFrame(const QuicFrame& frame, TransmissionType transmission_type); + + // Identical to AddSavedFrame, but allows the frame to be padded. + bool AddPaddedSavedFrame(const QuicFrame& frame, + TransmissionType transmission_type); + + // Creates a connectivity probing packet for versions prior to version 99. + std::unique_ptr SerializeConnectivityProbingPacket(); + + // Create connectivity probing request and response packets using PATH + // CHALLENGE and PATH RESPONSE frames, respectively, for version 99/IETF QUIC. + // SerializePathChallengeConnectivityProbingPacket will pad the packet to be + // MTU bytes long. + std::unique_ptr + SerializePathChallengeConnectivityProbingPacket( + const QuicPathFrameBuffer& payload); + + // If |is_padded| is true then SerializePathResponseConnectivityProbingPacket + // will pad the packet to be MTU bytes long, else it will not pad the packet. + // |payloads| is cleared. + std::unique_ptr + SerializePathResponseConnectivityProbingPacket( + const quiche::QuicheCircularDeque& payloads, + const bool is_padded); + + // Add PATH_RESPONSE to current packet, flush before or afterwards if needed. + bool AddPathResponseFrame(const QuicPathFrameBuffer& data_buffer); + + // Add PATH_CHALLENGE to current packet, flush before or afterwards if needed. + // This is a best effort adding. It may fail becasue of delegate state, but + // it's okay because of path validation retry mechanism. + void AddPathChallengeFrame(const QuicPathFrameBuffer& payload); + + // Returns a dummy packet that is valid but contains no useful information. + static SerializedPacket NoPacket(); + + // Returns the server connection ID to send over the wire. + const QuicConnectionId& GetServerConnectionId() const { + return server_connection_id_; + } + + // Returns the client connection ID to send over the wire. + const QuicConnectionId& GetClientConnectionId() const { + return client_connection_id_; + } + + // Returns the destination connection ID to send over the wire. + QuicConnectionId GetDestinationConnectionId() const; + + // Returns the source connection ID to send over the wire. + QuicConnectionId GetSourceConnectionId() const; + + // Returns length of destination connection ID to send over the wire. + uint8_t GetDestinationConnectionIdLength() const; + + // Returns length of source connection ID to send over the wire. + uint8_t GetSourceConnectionIdLength() const; + + // Sets whether the server connection ID should be sent over the wire. + void SetServerConnectionIdIncluded( + QuicConnectionIdIncluded server_connection_id_included); + + // Update the server connection ID used in outgoing packets. + void SetServerConnectionId(QuicConnectionId server_connection_id); + + // Update the client connection ID used in outgoing packets. + void SetClientConnectionId(QuicConnectionId client_connection_id); + + // Sets the encryption level that will be applied to new packets. + void set_encryption_level(EncryptionLevel level); + EncryptionLevel encryption_level() { return packet_.encryption_level; } + + // packet number of the last created packet, or 0 if no packets have been + // created. + QuicPacketNumber packet_number() const { return packet_.packet_number; } + + QuicByteCount max_packet_length() const { return max_packet_length_; } + + bool has_ack() const { return packet_.has_ack; } + + bool has_stop_waiting() const { return packet_.has_stop_waiting; } + + // Sets the encrypter to use for the encryption level and updates the max + // plaintext size. + void SetEncrypter(EncryptionLevel level, + std::unique_ptr encrypter); + + // Indicates whether the packet creator is in a state where it can change + // current maximum packet length. + bool CanSetMaxPacketLength() const; + + // Sets the maximum packet length. + void SetMaxPacketLength(QuicByteCount length); + + // Sets the maximum DATAGRAM/MESSAGE frame size we can send. + void SetMaxDatagramFrameSize(QuicByteCount max_datagram_frame_size); + + // Set a soft maximum packet length in the creator. If a packet cannot be + // successfully created, creator will remove the soft limit and use the actual + // max packet length. + void SetSoftMaxPacketLength(QuicByteCount length); + + // Increases pending_padding_bytes by |size|. Pending padding will be sent by + // MaybeAddPadding(). + void AddPendingPadding(QuicByteCount size); + + // Sets the retry token to be sent over the wire in IETF Initial packets. + void SetRetryToken(absl::string_view retry_token); + + // Consumes retransmittable control |frame|. Returns true if the frame is + // successfully consumed. Returns false otherwise. + bool ConsumeRetransmittableControlFrame(const QuicFrame& frame); + + // Given some data, may consume part or all of it and pass it to the + // packet creator to be serialized into packets. If not in batch + // mode, these packets will also be sent during this call. + // When |state| is FIN_AND_PADDING, random padding of size [1, 256] will be + // added after stream frames. If current constructed packet cannot + // accommodate, the padding will overflow to the next packet(s). + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state); + + // Sends as many data only packets as allowed by the send algorithm and the + // available iov. + // This path does not support padding, or bundling pending frames. + // In case we access this method from ConsumeData, total_bytes_consumed + // keeps track of how many bytes have already been consumed. + QuicConsumedData ConsumeDataFastPath(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, bool fin, + size_t total_bytes_consumed); + + // Consumes data for CRYPTO frames sent at |level| starting at |offset| for a + // total of |write_length| bytes, and returns the number of bytes consumed. + // The data is passed into the packet creator and serialized into one or more + // packets. + size_t ConsumeCryptoData(EncryptionLevel level, size_t write_length, + QuicStreamOffset offset); + + // Generates an MTU discovery packet of specified size. + void GenerateMtuDiscoveryPacket(QuicByteCount target_mtu); + + // Called when there is data to be sent, Retrieves updated ACK frame from + // delegate_ and flushes it. + void MaybeBundleAckOpportunistically(); + + // Called to flush ACK and STOP_WAITING frames, returns false if the flush + // fails. + bool FlushAckFrame(const QuicFrames& frames); + + // Adds a random amount of padding (between 1 to 256 bytes). + void AddRandomPadding(); + + // Attaches packet flusher. + void AttachPacketFlusher(); + + // Flushes everything, including current open packet and pending padding. + void Flush(); + + // Sends remaining pending padding. + // Pending paddings should only be sent when there is nothing else to send. + void SendRemainingPendingPadding(); + + // Set the minimum number of bytes for the server connection id length; + void SetServerConnectionIdLength(uint32_t length); + + // Set transmission type of next constructed packets. + void SetTransmissionType(TransmissionType type); + + // Tries to add a message frame containing |message| and returns the status. + MessageStatus AddMessageFrame(QuicMessageId message_id, + absl::Span message); + + // Returns the largest payload that will fit into a single MESSAGE frame. + QuicPacketLength GetCurrentLargestMessagePayload() const; + // Returns the largest payload that will fit into a single MESSAGE frame at + // any point during the connection. This assumes the version and + // connection ID lengths do not change. + QuicPacketLength GetGuaranteedLargestMessagePayload() const; + + // Packet number of next created packet. + QuicPacketNumber NextSendingPacketNumber() const; + + void set_debug_delegate(DebugDelegate* debug_delegate) { + debug_delegate_ = debug_delegate; + } + + QuicByteCount pending_padding_bytes() const { return pending_padding_bytes_; } + + ParsedQuicVersion version() const { return framer_->version(); } + + QuicTransportVersion transport_version() const { + return framer_->transport_version(); + } + + // Returns the minimum size that the plaintext of a packet must be. + static size_t MinPlaintextPacketSize( + const ParsedQuicVersion& version, + QuicPacketNumberLength packet_number_length); + + // Indicates whether packet flusher is currently attached. + bool PacketFlusherAttached() const; + + void set_fully_pad_crypto_handshake_packets(bool new_value) { + fully_pad_crypto_handshake_packets_ = new_value; + } + + bool fully_pad_crypto_handshake_packets() const { + return fully_pad_crypto_handshake_packets_; + } + + // Serialize a probing packet that uses IETF QUIC's PATH CHALLENGE frame. Also + // fills the packet with padding. + size_t BuildPaddedPathChallengePacket(const QuicPacketHeader& header, + char* buffer, size_t packet_length, + const QuicPathFrameBuffer& payload, + EncryptionLevel level); + + // Serialize a probing response packet that uses IETF QUIC's PATH RESPONSE + // frame. Also fills the packet with padding if |is_padded| is + // true. |payloads| is always emptied, even if the packet can not be + // successfully built. + size_t BuildPathResponsePacket( + const QuicPacketHeader& header, char* buffer, size_t packet_length, + const quiche::QuicheCircularDeque& payloads, + const bool is_padded, EncryptionLevel level); + + // Serializes a probing packet, which is a padded PING packet. Returns the + // length of the packet. Returns 0 if it fails to serialize. + size_t BuildConnectivityProbingPacket(const QuicPacketHeader& header, + char* buffer, size_t packet_length, + EncryptionLevel level); + + // Serializes |coalesced| to provided |buffer|, returns coalesced packet + // length if serialization succeeds. Otherwise, returns 0. + size_t SerializeCoalescedPacket(const QuicCoalescedPacket& coalesced, + char* buffer, size_t buffer_len); + + // Returns true if max_packet_length_ is currently a soft value. + bool HasSoftMaxPacketLength() const; + + // Use this address to sent to the peer from now on. If this address is + // different from the current one, flush all the queue frames first. + void SetDefaultPeerAddress(QuicSocketAddress address); + + // Return true if retry_token_ is not empty. + bool HasRetryToken() const; + + const QuicSocketAddress& peer_address() const { return packet_.peer_address; } + + private: + friend class test::QuicPacketCreatorPeer; + + // Used to 1) clear queued_frames_, 2) report unrecoverable error (if + // serialization fails) upon exiting the scope. + class QUIC_EXPORT_PRIVATE ScopedSerializationFailureHandler { + public: + explicit ScopedSerializationFailureHandler(QuicPacketCreator* creator); + ~ScopedSerializationFailureHandler(); + + private: + QuicPacketCreator* creator_; // Unowned. + }; + + // Attempts to build a data packet with chaos protection. If this packet isn't + // supposed to be protected or if serialization fails then absl::nullopt is + // returned. Otherwise returns the serialized length. + absl::optional MaybeBuildDataPacketWithChaosProtection( + const QuicPacketHeader& header, char* buffer); + + // Creates a stream frame which fits into the current open packet. If + // |data_size| is 0 and fin is true, the expected behavior is to consume + // the fin. + void CreateStreamFrame(QuicStreamId id, size_t data_size, + QuicStreamOffset offset, bool fin, QuicFrame* frame); + + // Creates a CRYPTO frame which fits into the current open packet. Returns + // false if there isn't enough room in the current open packet for a CRYPTO + // frame, and true if there is. + bool CreateCryptoFrame(EncryptionLevel level, size_t write_length, + QuicStreamOffset offset, QuicFrame* frame); + + void FillPacketHeader(QuicPacketHeader* header); + + // Adds a padding frame to the current packet (if there is space) when (1) + // current packet needs full padding or (2) there are pending paddings. + void MaybeAddPadding(); + + // Serializes all frames which have been added and adds any which should be + // retransmitted to packet_.retransmittable_frames. All frames must fit into + // a single packet. Returns true on success, otherwise, returns false. + // Fails if |encrypted_buffer| is not large enough for the encrypted packet. + // + // Padding may be added if |allow_padding|. Currently, the only case where it + // is disallowed is reserializing a coalesced initial packet. + ABSL_MUST_USE_RESULT bool SerializePacket( + QuicOwnedPacketBuffer encrypted_buffer, size_t encrypted_buffer_len, + bool allow_padding); + + // Called after a new SerialiedPacket is created to call the delegate's + // OnSerializedPacket and reset state. + void OnSerializedPacket(); + + // Clears all fields of packet_ that should be cleared between serializations. + void ClearPacket(); + + // Re-serialzes frames of ENCRYPTION_INITIAL packet in coalesced packet with + // the original packet's packet number and packet number length. + // |padding_size| indicates the size of necessary padding. Returns 0 if + // serialization fails. + size_t ReserializeInitialPacketInCoalescedPacket( + const SerializedPacket& packet, size_t padding_size, char* buffer, + size_t buffer_len); + + // Tries to coalesce |frame| with the back of |queued_frames_|. + // Returns true on success. + bool MaybeCoalesceStreamFrame(const QuicStreamFrame& frame); + + // Called to remove the soft max_packet_length and restores + // latched_hard_max_packet_length_ if the packet cannot accommodate a single + // frame. Returns true if the soft limit is successfully removed. Returns + // false if either there is no current soft limit or there are queued frames + // (such that the packet length cannot be changed). + bool RemoveSoftMaxPacketLength(); + + // Returns true if a diversification nonce should be included in the current + // packet's header. + bool IncludeNonceInPublicHeader() const; + + // Returns true if version should be included in current packet's header. + bool IncludeVersionInHeader() const; + + // Returns length of packet number to send over the wire. + // packet_.packet_number_length should never be read directly, use this + // function instead. + QuicPacketNumberLength GetPacketNumberLength() const; + + // Returns the size in bytes of the packet header. + size_t PacketHeaderSize() const; + + // Returns whether the destination connection ID is sent over the wire. + QuicConnectionIdIncluded GetDestinationConnectionIdIncluded() const; + + // Returns whether the source connection ID is sent over the wire. + QuicConnectionIdIncluded GetSourceConnectionIdIncluded() const; + + // Returns length of the retry token variable length integer to send over the + // wire. Is non-zero for v99 IETF Initial packets. + quiche::QuicheVariableLengthIntegerLength GetRetryTokenLengthLength() const; + + // Returns the retry token to send over the wire, only sent in + // v99 IETF Initial packets. + absl::string_view GetRetryToken() const; + + // Returns length of the length variable length integer to send over the + // wire. Is non-zero for v99 IETF Initial, 0-RTT or Handshake packets. + quiche::QuicheVariableLengthIntegerLength GetLengthLength() const; + + // Returns true if |frame| is a ClientHello. + bool StreamFrameIsClientHello(const QuicStreamFrame& frame) const; + + // Returns true if packet under construction has IETF long header. + bool HasIetfLongHeader() const; + + // Get serialized frame length. Returns 0 if the frame does not fit into + // current packet. + size_t GetSerializedFrameLength(const QuicFrame& frame); + + // Add extra padding to pending_padding_bytes_ to meet minimum plaintext + // packet size required for header protection. + void MaybeAddExtraPaddingForHeaderProtection(); + + // Returns true and close connection if it attempts to send unencrypted data. + bool AttemptingToSendUnencryptedStreamData(); + + // Add the given frame to the current packet with full padding. If the current + // packet doesn't have enough space, flush once and try again. Return false if + // fail to add. + bool AddPaddedFrameWithRetry(const QuicFrame& frame); + + // Does not own these delegates or the framer. + DelegateInterface* delegate_; + DebugDelegate* debug_delegate_; + QuicFramer* framer_; + QuicRandom* random_; + + // Controls whether version should be included while serializing the packet. + // send_version_in_packet_ should never be read directly, use + // IncludeVersionInHeader() instead. + bool send_version_in_packet_; + // If true, then |diversification_nonce_| will be included in the header of + // all packets created at the initial encryption level. + bool have_diversification_nonce_; + DiversificationNonce diversification_nonce_; + // Maximum length including headers and encryption (UDP payload length.) + QuicByteCount max_packet_length_; + size_t max_plaintext_size_; + // Whether the server_connection_id is sent over the wire. + QuicConnectionIdIncluded server_connection_id_included_; + + // Frames to be added to the next SerializedPacket + QuicFrames queued_frames_; + + // Serialization size of header + frames. If there is no queued frames, + // packet_size_ is 0. + // TODO(ianswett): Move packet_size_ into SerializedPacket once + // QuicEncryptedPacket has been flattened into SerializedPacket. + size_t packet_size_; + QuicConnectionId server_connection_id_; + QuicConnectionId client_connection_id_; + + // Packet used to invoke OnSerializedPacket. + SerializedPacket packet_; + + // Retry token to send over the wire in v99 IETF Initial packets. + std::string retry_token_; + + // Pending padding bytes to send. Pending padding bytes will be sent in next + // packet(s) (after all other frames) if current constructed packet does not + // have room to send all of them. + QuicByteCount pending_padding_bytes_; + + // Indicates whether current constructed packet needs full padding to max + // packet size. Please note, full padding does not consume pending padding + // bytes. + bool needs_full_padding_; + + // Transmission type of the next serialized packet. + TransmissionType next_transmission_type_; + + // True if packet flusher is currently attached. + bool flusher_attached_; + + // Whether crypto handshake packets should be fully padded. + bool fully_pad_crypto_handshake_packets_; + + // Packet number of the first packet of a write operation. This gets set + // when the out-most flusher attaches and gets cleared when the out-most + // flusher detaches. + QuicPacketNumber write_start_packet_number_; + + // If not 0, this latches the actual max_packet_length when + // SetSoftMaxPacketLength is called and max_packet_length_ gets + // set to a soft value. + QuicByteCount latched_hard_max_packet_length_; + + // The maximum length of a MESSAGE/DATAGRAM frame that our peer is willing to + // accept. There is no limit for QUIC_CRYPTO connections, but QUIC+TLS + // negotiates this during the handshake. + QuicByteCount max_datagram_frame_size_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PACKET_CREATOR_H_ diff --git a/quiche/quic/core/quic_packet_creator_test.cc b/quiche/quic/core/quic_packet_creator_test.cc new file mode 100644 index 000000000000..ec8bcd70606c --- /dev/null +++ b/quiche/quic/core/quic_packet_creator_test.cc @@ -0,0 +1,4148 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packet_creator.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_packet_creator_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_data_producer.h" +#include "quiche/quic/test_tools/simple_quic_framer.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +using ::testing::_; +using ::testing::AtLeast; +using ::testing::DoAll; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::SaveArg; +using ::testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +const QuicPacketNumber kPacketNumber = QuicPacketNumber(UINT64_C(0x12345678)); +// Use fields in which each byte is distinct to ensure that every byte is +// framed correctly. The values are otherwise arbitrary. +QuicConnectionId CreateTestConnectionId() { + return TestConnectionId(UINT64_C(0xFEDCBA9876543210)); +} + +// Run tests with combinations of {ParsedQuicVersion, +// ToggleVersionSerialization}. +struct TestParams { + TestParams(ParsedQuicVersion version, bool version_serialization) + : version(version), version_serialization(version_serialization) {} + + ParsedQuicVersion version; + bool version_serialization; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat(ParsedQuicVersionToString(p.version), "_", + (p.version_serialization ? "Include" : "No"), "Version"); +} + +// Constructs various test permutations. +std::vector GetTestParams() { + std::vector params; + ParsedQuicVersionVector all_supported_versions = AllSupportedVersions(); + for (size_t i = 0; i < all_supported_versions.size(); ++i) { + params.push_back(TestParams(all_supported_versions[i], true)); + params.push_back(TestParams(all_supported_versions[i], false)); + } + return params; +} + +class MockDebugDelegate : public QuicPacketCreator::DebugDelegate { + public: + ~MockDebugDelegate() override = default; + + MOCK_METHOD(void, OnFrameAddedToPacket, (const QuicFrame& frame), (override)); + + MOCK_METHOD(void, OnStreamFrameCoalesced, (const QuicStreamFrame& frame), + (override)); +}; + +class TestPacketCreator : public QuicPacketCreator { + public: + TestPacketCreator(QuicConnectionId connection_id, QuicFramer* framer, + DelegateInterface* delegate, SimpleDataProducer* producer) + : QuicPacketCreator(connection_id, framer, delegate), + producer_(producer), + version_(framer->version()) {} + + bool ConsumeDataToFillCurrentPacket(QuicStreamId id, absl::string_view data, + QuicStreamOffset offset, bool fin, + bool needs_full_padding, + TransmissionType transmission_type, + QuicFrame* frame) { + // Save data before data is consumed. + if (!data.empty()) { + producer_->SaveStreamData(id, data); + } + return QuicPacketCreator::ConsumeDataToFillCurrentPacket( + id, data.length(), offset, fin, needs_full_padding, transmission_type, + frame); + } + + void StopSendingVersion() { + if (version_.HasIetfInvariantHeader()) { + set_encryption_level(ENCRYPTION_FORWARD_SECURE); + return; + } + QuicPacketCreator::StopSendingVersion(); + } + + SimpleDataProducer* producer_; + ParsedQuicVersion version_; +}; + +class QuicPacketCreatorTest : public QuicTestWithParam { + public: + void ClearSerializedPacketForTests(SerializedPacket /*serialized_packet*/) { + // serialized packet self-clears on destruction. + } + + void SaveSerializedPacket(SerializedPacket serialized_packet) { + serialized_packet_.reset(CopySerializedPacket( + serialized_packet, &allocator_, /*copy_buffer=*/true)); + } + + void DeleteSerializedPacket() { serialized_packet_ = nullptr; } + + protected: + QuicPacketCreatorTest() + : connection_id_(TestConnectionId(2)), + server_framer_(SupportedVersions(GetParam().version), QuicTime::Zero(), + Perspective::IS_SERVER, connection_id_.length()), + client_framer_(SupportedVersions(GetParam().version), QuicTime::Zero(), + Perspective::IS_CLIENT, connection_id_.length()), + data_("foo"), + creator_(connection_id_, &client_framer_, &delegate_, &producer_) { + EXPECT_CALL(delegate_, GetPacketBuffer()) + .WillRepeatedly(Return(QuicPacketBuffer())); + EXPECT_CALL(delegate_, GetSerializedPacketFate(_, _)) + .WillRepeatedly(Return(SEND_TO_WRITER)); + creator_.SetEncrypter( + ENCRYPTION_INITIAL, + std::make_unique(ENCRYPTION_INITIAL)); + creator_.SetEncrypter( + ENCRYPTION_HANDSHAKE, + std::make_unique(ENCRYPTION_HANDSHAKE)); + creator_.SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + creator_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + client_framer_.set_visitor(&framer_visitor_); + server_framer_.set_visitor(&framer_visitor_); + client_framer_.set_data_producer(&producer_); + if (server_framer_.version().KnowsWhichDecrypterToUse()) { + server_framer_.InstallDecrypter(ENCRYPTION_INITIAL, + std::make_unique()); + server_framer_.InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique()); + server_framer_.InstallDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique()); + server_framer_.InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique()); + } else { + server_framer_.SetDecrypter(ENCRYPTION_INITIAL, + std::make_unique()); + server_framer_.SetAlternativeDecrypter( + ENCRYPTION_FORWARD_SECURE, std::make_unique(), + false); + } + } + + ~QuicPacketCreatorTest() override {} + + SerializedPacket SerializeAllFrames(const QuicFrames& frames) { + SerializedPacket packet = QuicPacketCreatorPeer::SerializeAllFrames( + &creator_, frames, buffer_, kMaxOutgoingPacketSize); + EXPECT_EQ(QuicPacketCreatorPeer::GetEncryptionLevel(&creator_), + packet.encryption_level); + return packet; + } + + void ProcessPacket(const SerializedPacket& packet) { + QuicEncryptedPacket encrypted_packet(packet.encrypted_buffer, + packet.encrypted_length); + server_framer_.ProcessPacket(encrypted_packet); + } + + void CheckStreamFrame(const QuicFrame& frame, QuicStreamId stream_id, + const std::string& data, QuicStreamOffset offset, + bool fin) { + EXPECT_EQ(STREAM_FRAME, frame.type); + EXPECT_EQ(stream_id, frame.stream_frame.stream_id); + char buf[kMaxOutgoingPacketSize]; + QuicDataWriter writer(kMaxOutgoingPacketSize, buf, quiche::HOST_BYTE_ORDER); + if (frame.stream_frame.data_length > 0) { + producer_.WriteStreamData(stream_id, frame.stream_frame.offset, + frame.stream_frame.data_length, &writer); + } + EXPECT_EQ(data, absl::string_view(buf, frame.stream_frame.data_length)); + EXPECT_EQ(offset, frame.stream_frame.offset); + EXPECT_EQ(fin, frame.stream_frame.fin); + } + + // Returns the number of bytes consumed by the header of packet, including + // the version. + size_t GetPacketHeaderOverhead(QuicTransportVersion version) { + return GetPacketHeaderSize( + version, creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), 0, + QuicPacketCreatorPeer::GetLengthLength(&creator_)); + } + + // Returns the number of bytes of overhead that will be added to a packet + // of maximum length. + size_t GetEncryptionOverhead() { + return creator_.max_packet_length() - + client_framer_.GetMaxPlaintextSize(creator_.max_packet_length()); + } + + // Returns the number of bytes consumed by the non-data fields of a stream + // frame, assuming it is the last frame in the packet + size_t GetStreamFrameOverhead(QuicTransportVersion version) { + return QuicFramer::GetMinStreamFrameSize( + version, GetNthClientInitiatedStreamId(1), kOffset, true, + /* data_length= */ 0); + } + + bool IsDefaultTestConfiguration() { + TestParams p = GetParam(); + return p.version == AllSupportedVersions()[0] && p.version_serialization; + } + + QuicStreamId GetNthClientInitiatedStreamId(int n) const { + return QuicUtils::GetFirstBidirectionalStreamId( + creator_.transport_version(), Perspective::IS_CLIENT) + + n * 2; + } + + void TestChaosProtection(bool enabled); + + static constexpr QuicStreamOffset kOffset = 0u; + + char buffer_[kMaxOutgoingPacketSize]; + QuicConnectionId connection_id_; + QuicFrames frames_; + QuicFramer server_framer_; + QuicFramer client_framer_; + StrictMock framer_visitor_; + StrictMock delegate_; + std::string data_; + TestPacketCreator creator_; + std::unique_ptr serialized_packet_; + SimpleDataProducer producer_; + quiche::SimpleBufferAllocator allocator_; +}; + +// Run all packet creator tests with all supported versions of QUIC, and with +// and without version in the packet header, as well as doing a run for each +// length of truncated connection id. +INSTANTIATE_TEST_SUITE_P(QuicPacketCreatorTests, QuicPacketCreatorTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicPacketCreatorTest, SerializeFrames) { + ParsedQuicVersion version = client_framer_.version(); + for (int i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) { + EncryptionLevel level = static_cast(i); + bool has_ack = false, has_stream = false; + creator_.set_encryption_level(level); + size_t payload_len = 0; + if (level != ENCRYPTION_ZERO_RTT) { + frames_.push_back(QuicFrame(new QuicAckFrame(InitAckFrame(1)))); + has_ack = true; + payload_len += version.UsesTls() ? 12 : 6; + } + if (level != ENCRYPTION_INITIAL && level != ENCRYPTION_HANDSHAKE) { + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + frames_.push_back(QuicFrame( + QuicStreamFrame(stream_id, false, 0u, absl::string_view()))); + has_stream = true; + payload_len += 2; + } + SerializedPacket serialized = SerializeAllFrames(frames_); + EXPECT_EQ(level, serialized.encryption_level); + if (level != ENCRYPTION_ZERO_RTT) { + delete frames_[0].ack_frame; + } + frames_.clear(); + ASSERT_GT(payload_len, 0); // Must have a frame! + size_t min_payload = version.UsesTls() ? 3 : 7; + bool need_padding = + (version.HasHeaderProtection() && (payload_len < min_payload)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + if (need_padding) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } + if (has_ack) { + EXPECT_CALL(framer_visitor_, OnAckFrameStart(_, _)) + .WillOnce(Return(true)); + EXPECT_CALL(framer_visitor_, + OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2))) + .WillOnce(Return(true)); + EXPECT_CALL(framer_visitor_, OnAckFrameEnd(QuicPacketNumber(1), _)) + .WillOnce(Return(true)); + } + if (has_stream) { + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized); + } +} + +TEST_P(QuicPacketCreatorTest, SerializeConnectionClose) { + QuicConnectionCloseFrame* frame = new QuicConnectionCloseFrame( + creator_.transport_version(), QUIC_NO_ERROR, NO_IETF_QUIC_ERROR, "error", + /*transport_close_frame_type=*/0); + + QuicFrames frames; + frames.push_back(QuicFrame(frame)); + SerializedPacket serialized = SerializeAllFrames(frames); + EXPECT_EQ(ENCRYPTION_INITIAL, serialized.encryption_level); + ASSERT_EQ(QuicPacketNumber(1u), serialized.packet_number); + ASSERT_EQ(QuicPacketNumber(1u), creator_.packet_number()); + + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnConnectionCloseFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + + ProcessPacket(serialized); +} + +TEST_P(QuicPacketCreatorTest, ConsumeCryptoDataToFillCurrentPacket) { + std::string data = "crypto data"; + QuicFrame frame; + ASSERT_TRUE(creator_.ConsumeCryptoDataToFillCurrentPacket( + ENCRYPTION_INITIAL, data.length(), 0, + /*needs_full_padding=*/true, NOT_RETRANSMISSION, &frame)); + EXPECT_EQ(frame.crypto_frame->data_length, data.length()); + EXPECT_TRUE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, ConsumeDataToFillCurrentPacket) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicFrame frame; + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + const std::string data("test"); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, NOT_RETRANSMISSION, &frame)); + size_t consumed = frame.stream_frame.data_length; + EXPECT_EQ(4u, consumed); + CheckStreamFrame(frame, stream_id, "test", 0u, false); + EXPECT_TRUE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, ConsumeDataFin) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicFrame frame; + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + const std::string data("test"); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, true, false, NOT_RETRANSMISSION, &frame)); + size_t consumed = frame.stream_frame.data_length; + EXPECT_EQ(4u, consumed); + CheckStreamFrame(frame, stream_id, "test", 0u, true); + EXPECT_TRUE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, ConsumeDataFinOnly) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicFrame frame; + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, {}, 0u, true, false, NOT_RETRANSMISSION, &frame)); + size_t consumed = frame.stream_frame.data_length; + EXPECT_EQ(0u, consumed); + CheckStreamFrame(frame, stream_id, std::string(), 0u, true); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(absl::StartsWith(creator_.GetPendingFramesInfo(), + "type { STREAM_FRAME }")); +} + +TEST_P(QuicPacketCreatorTest, CreateAllFreeBytesForStreamFrames) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + const size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead(); + for (size_t i = overhead + + QuicPacketCreator::MinPlaintextPacketSize( + client_framer_.version(), + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + i < overhead + 100; ++i) { + SCOPED_TRACE(i); + creator_.SetMaxPacketLength(i); + const bool should_have_room = + i > + overhead + GetStreamFrameOverhead(client_framer_.transport_version()); + ASSERT_EQ(should_have_room, + creator_.HasRoomForStreamFrame(GetNthClientInitiatedStreamId(1), + kOffset, /* data_size=*/0xffff)); + if (should_have_room) { + QuicFrame frame; + const std::string data("testdata"); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly(Invoke( + this, &QuicPacketCreatorTest::ClearSerializedPacketForTests)); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + GetNthClientInitiatedStreamId(1), data, kOffset, false, false, + NOT_RETRANSMISSION, &frame)); + size_t bytes_consumed = frame.stream_frame.data_length; + EXPECT_LT(0u, bytes_consumed); + creator_.FlushCurrentPacket(); + } + } +} + +TEST_P(QuicPacketCreatorTest, StreamFrameConsumption) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Compute the total overhead for a single frame in packet. + const size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead() + + GetStreamFrameOverhead(client_framer_.transport_version()); + size_t capacity = kDefaultMaxPacketSize - overhead; + // Now, test various sizes around this size. + for (int delta = -5; delta <= 5; ++delta) { + std::string data(capacity + delta, 'A'); + size_t bytes_free = delta > 0 ? 0 : 0 - delta; + QuicFrame frame; + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + GetNthClientInitiatedStreamId(1), data, kOffset, false, false, + NOT_RETRANSMISSION, &frame)); + + // BytesFree() returns bytes available for the next frame, which will + // be two bytes smaller since the stream frame would need to be grown. + EXPECT_EQ(2u, creator_.ExpansionOnNewFrame()); + size_t expected_bytes_free = bytes_free < 3 ? 0 : bytes_free - 2; + EXPECT_EQ(expected_bytes_free, creator_.BytesFree()) << "delta: " << delta; + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + DeleteSerializedPacket(); + } +} + +TEST_P(QuicPacketCreatorTest, CryptoStreamFramePacketPadding) { + // This test serializes crypto payloads slightly larger than a packet, which + // Causes the multi-packet ClientHello check to fail. + SetQuicFlag(quic_enforce_single_packet_chlo, false); + // Compute the total overhead for a single frame in packet. + size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead(); + if (QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + overhead += + QuicFramer::GetMinCryptoFrameSize(kOffset, kMaxOutgoingPacketSize); + } else { + overhead += QuicFramer::GetMinStreamFrameSize( + client_framer_.transport_version(), GetNthClientInitiatedStreamId(1), + kOffset, false, 0); + } + ASSERT_GT(kMaxOutgoingPacketSize, overhead); + size_t capacity = kDefaultMaxPacketSize - overhead; + // Now, test various sizes around this size. + for (int delta = -5; delta <= 5; ++delta) { + SCOPED_TRACE(delta); + std::string data(capacity + delta, 'A'); + size_t bytes_free = delta > 0 ? 0 : 0 - delta; + + QuicFrame frame; + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + if (client_framer_.version().CanSendCoalescedPackets()) { + EXPECT_CALL(delegate_, GetSerializedPacketFate(_, _)) + .WillRepeatedly(Return(COALESCE)); + } + if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), + data, kOffset, false, true, NOT_RETRANSMISSION, &frame)); + size_t bytes_consumed = frame.stream_frame.data_length; + EXPECT_LT(0u, bytes_consumed); + } else { + producer_.SaveCryptoData(ENCRYPTION_INITIAL, kOffset, data); + ASSERT_TRUE(creator_.ConsumeCryptoDataToFillCurrentPacket( + ENCRYPTION_INITIAL, data.length(), kOffset, + /*needs_full_padding=*/true, NOT_RETRANSMISSION, &frame)); + size_t bytes_consumed = frame.crypto_frame->data_length; + EXPECT_LT(0u, bytes_consumed); + } + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + // If there is not enough space in the packet to fit a padding frame + // (1 byte) and to expand the stream frame (another 2 bytes) the packet + // will not be padded. + // Padding is skipped when we try to send coalesced packets. + if (client_framer_.version().CanSendCoalescedPackets()) { + EXPECT_EQ(kDefaultMaxPacketSize - bytes_free, + serialized_packet_->encrypted_length); + } else { + EXPECT_EQ(kDefaultMaxPacketSize, serialized_packet_->encrypted_length); + } + DeleteSerializedPacket(); + } +} + +TEST_P(QuicPacketCreatorTest, NonCryptoStreamFramePacketNonPadding) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Compute the total overhead for a single frame in packet. + const size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead() + + GetStreamFrameOverhead(client_framer_.transport_version()); + ASSERT_GT(kDefaultMaxPacketSize, overhead); + size_t capacity = kDefaultMaxPacketSize - overhead; + // Now, test various sizes around this size. + for (int delta = -5; delta <= 5; ++delta) { + std::string data(capacity + delta, 'A'); + size_t bytes_free = delta > 0 ? 0 : 0 - delta; + + QuicFrame frame; + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + GetNthClientInitiatedStreamId(1), data, kOffset, false, false, + NOT_RETRANSMISSION, &frame)); + size_t bytes_consumed = frame.stream_frame.data_length; + EXPECT_LT(0u, bytes_consumed); + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + if (bytes_free > 0) { + EXPECT_EQ(kDefaultMaxPacketSize - bytes_free, + serialized_packet_->encrypted_length); + } else { + EXPECT_EQ(kDefaultMaxPacketSize, serialized_packet_->encrypted_length); + } + DeleteSerializedPacket(); + } +} + +// Test that the path challenge connectivity probing packet is serialized +// correctly as a padded PATH CHALLENGE packet. +TEST_P(QuicPacketCreatorTest, BuildPathChallengePacket) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = CreateTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + MockRandom randomizer; + QuicPathFrameBuffer payload; + randomizer.RandBytes(payload.data(), payload.size()); + + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // Path Challenge Frame type (IETF_PATH_CHALLENGE) + 0x1a, + // 8 "random" bytes, MockRandom makes lots of r's + 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + + size_t length = creator_.BuildPaddedPathChallengePacket( + header, buffer.get(), ABSL_ARRAYSIZE(packet), payload, + ENCRYPTION_INITIAL); + EXPECT_EQ(length, ABSL_ARRAYSIZE(packet)); + + // Payload has the random bytes that were generated. Copy them into packet, + // above, before checking that the generated packet is correct. + EXPECT_EQ(kQuicPathFrameBufferSize, payload.size()); + + QuicPacket data(creator_.transport_version(), buffer.release(), length, true, + header); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data.data(), data.length(), + reinterpret_cast(packet), ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicPacketCreatorTest, BuildConnectivityProbingPacket) { + QuicPacketHeader header; + header.destination_connection_id = CreateTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + + // clang-format off + unsigned char packet[] = { + // public flags (8 byte connection_id) + 0x2C, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (ping frame) + 0x07, + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet46[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type + 0x07, + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + + unsigned char packet99[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // frame type (IETF_PING frame) + 0x01, + // frame type (padding frame) + 0x00, + 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + unsigned char* p = packet; + size_t packet_size = ABSL_ARRAYSIZE(packet); + if (creator_.version().HasIetfQuicFrames()) { + p = packet99; + packet_size = ABSL_ARRAYSIZE(packet99); + } else if (creator_.version().HasIetfInvariantHeader()) { + p = packet46; + packet_size = ABSL_ARRAYSIZE(packet46); + } + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + + size_t length = creator_.BuildConnectivityProbingPacket( + header, buffer.get(), packet_size, ENCRYPTION_INITIAL); + + EXPECT_NE(0u, length); + QuicPacket data(creator_.transport_version(), buffer.release(), length, true, + header); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data.data(), data.length(), + reinterpret_cast(p), packet_size); +} + +// Several tests that the path response connectivity probing packet is +// serialized correctly as either a padded and unpadded PATH RESPONSE +// packet. Also generates packets with 1 and 3 PATH_RESPONSES in them to +// exercised the single- and multiple- payload cases. +TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket1ResponseUnpadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = CreateTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + QuicPathFrameBuffer payload0 = { + {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}; + + // Build 1 PATH RESPONSE, not padded + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // Path Response Frame type (IETF_PATH_RESPONSE) + 0x1b, + // 8 "random" bytes + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + }; + // clang-format on + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + size_t length = creator_.BuildPathResponsePacket( + header, buffer.get(), ABSL_ARRAYSIZE(packet), payloads, + /*is_padded=*/false, ENCRYPTION_INITIAL); + EXPECT_EQ(length, ABSL_ARRAYSIZE(packet)); + QuicPacket data(creator_.transport_version(), buffer.release(), length, true, + header); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data.data(), data.length(), + reinterpret_cast(packet), ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket1ResponsePadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = CreateTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + QuicPathFrameBuffer payload0 = { + {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}; + + // Build 1 PATH RESPONSE, padded + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // Path Response Frame type (IETF_PATH_RESPONSE) + 0x1b, + // 8 "random" bytes + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + // Padding type and pad + 0x00, 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + size_t length = creator_.BuildPathResponsePacket( + header, buffer.get(), ABSL_ARRAYSIZE(packet), payloads, + /*is_padded=*/true, ENCRYPTION_INITIAL); + EXPECT_EQ(length, ABSL_ARRAYSIZE(packet)); + QuicPacket data(creator_.transport_version(), buffer.release(), length, true, + header); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data.data(), data.length(), + reinterpret_cast(packet), ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket3ResponsesUnpadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = CreateTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + QuicPathFrameBuffer payload0 = { + {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}; + QuicPathFrameBuffer payload1 = { + {0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18}}; + QuicPathFrameBuffer payload2 = { + {0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28}}; + + // Build one packet with 3 PATH RESPONSES, no padding + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // 3 path response frames (IETF_PATH_RESPONSE type byte and payload) + 0x1b, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x1b, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x1b, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + }; + // clang-format on + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + payloads.push_back(payload1); + payloads.push_back(payload2); + size_t length = creator_.BuildPathResponsePacket( + header, buffer.get(), ABSL_ARRAYSIZE(packet), payloads, + /*is_padded=*/false, ENCRYPTION_INITIAL); + EXPECT_EQ(length, ABSL_ARRAYSIZE(packet)); + QuicPacket data(creator_.transport_version(), buffer.release(), length, true, + header); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data.data(), data.length(), + reinterpret_cast(packet), ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket3ResponsesPadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + // This frame is only for IETF QUIC. + return; + } + + QuicPacketHeader header; + header.destination_connection_id = CreateTestConnectionId(); + header.reset_flag = false; + header.version_flag = false; + header.packet_number = kPacketNumber; + QuicPathFrameBuffer payload0 = { + {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}; + QuicPathFrameBuffer payload1 = { + {0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18}}; + QuicPathFrameBuffer payload2 = { + {0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28}}; + + // Build one packet with 3 PATH RESPONSES, with padding + // clang-format off + unsigned char packet[] = { + // type (short header, 4 byte packet number) + 0x43, + // connection_id + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, + // packet number + 0x12, 0x34, 0x56, 0x78, + + // 3 path response frames (IETF_PATH_RESPONSE byte and payload) + 0x1b, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x1b, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x1b, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + // Padding + 0x00, 0x00, 0x00, 0x00, 0x00 + }; + // clang-format on + + std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + payloads.push_back(payload1); + payloads.push_back(payload2); + size_t length = creator_.BuildPathResponsePacket( + header, buffer.get(), ABSL_ARRAYSIZE(packet), payloads, + /*is_padded=*/true, ENCRYPTION_INITIAL); + EXPECT_EQ(length, ABSL_ARRAYSIZE(packet)); + QuicPacket data(creator_.transport_version(), buffer.release(), length, true, + header); + + quiche::test::CompareCharArraysWithHexError( + "constructed packet", data.data(), data.length(), + reinterpret_cast(packet), ABSL_ARRAYSIZE(packet)); +} + +TEST_P(QuicPacketCreatorTest, SerializeConnectivityProbingPacket) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + std::unique_ptr encrypted; + if (VersionHasIetfQuicFrames(creator_.transport_version())) { + QuicPathFrameBuffer payload = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xfe}}; + encrypted = + creator_.SerializePathChallengeConnectivityProbingPacket(payload); + } else { + encrypted = creator_.SerializeConnectivityProbingPacket(); + } + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + if (VersionHasIetfQuicFrames(creator_.transport_version())) { + EXPECT_CALL(framer_visitor_, OnPathChallengeFrame(_)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } else { + EXPECT_CALL(framer_visitor_, OnPingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + // QuicFramerPeer::SetPerspective(&client_framer_, Perspective::IS_SERVER); + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, SerializePathChallengeProbePacket) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + std::unique_ptr encrypted( + creator_.SerializePathChallengeConnectivityProbingPacket(payload)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathChallengeFrame(_)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + // QuicFramerPeer::SetPerspective(&client_framer_, Perspective::IS_SERVER); + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, SerializePathResponseProbePacket1PayloadPadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload0 = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + + std::unique_ptr encrypted( + creator_.SerializePathResponseConnectivityProbingPacket(payloads, true)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, + SerializePathResponseProbePacket1PayloadUnPadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload0 = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + + std::unique_ptr encrypted( + creator_.SerializePathResponseConnectivityProbingPacket(payloads, false)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, SerializePathResponseProbePacket2PayloadsPadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload0 = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + QuicPathFrameBuffer payload1 = { + {0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee, 0xde}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + payloads.push_back(payload1); + + std::unique_ptr encrypted( + creator_.SerializePathResponseConnectivityProbingPacket(payloads, true)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)).Times(2); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, + SerializePathResponseProbePacket2PayloadsUnPadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload0 = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + QuicPathFrameBuffer payload1 = { + {0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee, 0xde}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + payloads.push_back(payload1); + + std::unique_ptr encrypted( + creator_.SerializePathResponseConnectivityProbingPacket(payloads, false)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)).Times(2); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, SerializePathResponseProbePacket3PayloadsPadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload0 = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + QuicPathFrameBuffer payload1 = { + {0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee, 0xde}}; + QuicPathFrameBuffer payload2 = { + {0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee, 0xde, 0xad}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + payloads.push_back(payload1); + payloads.push_back(payload2); + + std::unique_ptr encrypted( + creator_.SerializePathResponseConnectivityProbingPacket(payloads, true)); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)).Times(3); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, + SerializePathResponseProbePacket3PayloadsUnpadded) { + if (!VersionHasIetfQuicFrames(creator_.transport_version())) { + return; + } + QuicPathFrameBuffer payload0 = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee}}; + QuicPathFrameBuffer payload1 = { + {0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee, 0xde}}; + QuicPathFrameBuffer payload2 = { + {0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xee, 0xde, 0xad}}; + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + quiche::QuicheCircularDeque payloads; + payloads.push_back(payload0); + payloads.push_back(payload1); + payloads.push_back(payload2); + + std::unique_ptr encrypted( + creator_.SerializePathResponseConnectivityProbingPacket(payloads, false)); + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)).Times(3); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + + server_framer_.ProcessPacket(QuicEncryptedPacket( + encrypted->encrypted_buffer, encrypted->encrypted_length)); +} + +TEST_P(QuicPacketCreatorTest, UpdatePacketSequenceNumberLengthLeastAwaiting) { + if (creator_.version().HasIetfInvariantHeader() && + !GetParam().version.SendsVariableLengthPacketNumberInLongHeader()) { + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + } else { + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + } + + QuicPacketCreatorPeer::SetPacketNumber(&creator_, 64); + creator_.UpdatePacketNumberLength(QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + QuicPacketCreatorPeer::SetPacketNumber(&creator_, 64 * 256); + creator_.UpdatePacketNumberLength(QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_2BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + QuicPacketCreatorPeer::SetPacketNumber(&creator_, 64 * 256 * 256); + creator_.UpdatePacketNumberLength(QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + QuicPacketCreatorPeer::SetPacketNumber(&creator_, + UINT64_C(64) * 256 * 256 * 256 * 256); + creator_.UpdatePacketNumberLength(QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_6BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); +} + +TEST_P(QuicPacketCreatorTest, UpdatePacketSequenceNumberLengthCwnd) { + QuicPacketCreatorPeer::SetPacketNumber(&creator_, 1); + if (creator_.version().HasIetfInvariantHeader() && + !GetParam().version.SendsVariableLengthPacketNumberInLongHeader()) { + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + } else { + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + } + + creator_.UpdatePacketNumberLength(QuicPacketNumber(1), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + creator_.UpdatePacketNumberLength(QuicPacketNumber(1), + 10000 * 256 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_2BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + creator_.UpdatePacketNumberLength(QuicPacketNumber(1), + 10000 * 256 * 256 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + creator_.UpdatePacketNumberLength( + QuicPacketNumber(1), + UINT64_C(1000) * 256 * 256 * 256 * 256 / kDefaultMaxPacketSize); + EXPECT_EQ(PACKET_6BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); +} + +TEST_P(QuicPacketCreatorTest, SkipNPacketNumbers) { + QuicPacketCreatorPeer::SetPacketNumber(&creator_, 1); + if (creator_.version().HasIetfInvariantHeader() && + !GetParam().version.SendsVariableLengthPacketNumberInLongHeader()) { + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + } else { + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + } + creator_.SkipNPacketNumbers(63, QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(QuicPacketNumber(64), creator_.packet_number()); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + creator_.SkipNPacketNumbers(64 * 255, QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(QuicPacketNumber(64 * 256), creator_.packet_number()); + EXPECT_EQ(PACKET_2BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); + + creator_.SkipNPacketNumbers(64 * 256 * 255, QuicPacketNumber(2), + 10000 / kDefaultMaxPacketSize); + EXPECT_EQ(QuicPacketNumber(64 * 256 * 256), creator_.packet_number()); + EXPECT_EQ(PACKET_4BYTE_PACKET_NUMBER, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)); +} + +TEST_P(QuicPacketCreatorTest, SerializeFrame) { + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + std::string data("test data"); + if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + QuicStreamFrame stream_frame( + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), + /*fin=*/false, 0u, absl::string_view()); + frames_.push_back(QuicFrame(stream_frame)); + } else { + producer_.SaveCryptoData(ENCRYPTION_INITIAL, 0, data); + frames_.push_back( + QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data.length()))); + } + SerializedPacket serialized = SerializeAllFrames(frames_); + + QuicPacketHeader header; + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)) + .WillOnce(DoAll(SaveArg<0>(&header), Return(true))); + if (QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)); + } else { + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized); + EXPECT_EQ(GetParam().version_serialization, header.version_flag); +} + +TEST_P(QuicPacketCreatorTest, SerializeFrameShortData) { + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + std::string data("Hello World!"); + if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + QuicStreamFrame stream_frame( + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), + /*fin=*/false, 0u, absl::string_view()); + frames_.push_back(QuicFrame(stream_frame)); + } else { + producer_.SaveCryptoData(ENCRYPTION_INITIAL, 0, data); + frames_.push_back( + QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data.length()))); + } + SerializedPacket serialized = SerializeAllFrames(frames_); + + QuicPacketHeader header; + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)) + .WillOnce(DoAll(SaveArg<0>(&header), Return(true))); + if (QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)); + } else { + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized); + EXPECT_EQ(GetParam().version_serialization, header.version_flag); +} + +void QuicPacketCreatorTest::TestChaosProtection(bool enabled) { + if (!GetParam().version.UsesCryptoFrames()) { + return; + } + MockRandom mock_random(2); + QuicPacketCreatorPeer::SetRandom(&creator_, &mock_random); + std::string data("ChAoS_ThEoRy!"); + producer_.SaveCryptoData(ENCRYPTION_INITIAL, 0, data); + frames_.push_back( + QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data.length()))); + frames_.push_back(QuicFrame(QuicPaddingFrame(33))); + SerializedPacket serialized = SerializeAllFrames(frames_); + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + if (enabled) { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)).Times(AtLeast(2)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)).Times(AtLeast(2)); + EXPECT_CALL(framer_visitor_, OnPingFrame(_)).Times(AtLeast(1)); + } else { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)).Times(1); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)).Times(1); + EXPECT_CALL(framer_visitor_, OnPingFrame(_)).Times(0); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + ProcessPacket(serialized); +} + +TEST_P(QuicPacketCreatorTest, ChaosProtectionEnabled) { + TestChaosProtection(true); +} + +TEST_P(QuicPacketCreatorTest, ChaosProtectionDisabled) { + SetQuicFlag(quic_enable_chaos_protection, false); + TestChaosProtection(false); +} + +TEST_P(QuicPacketCreatorTest, ConsumeDataLargerThanOneStreamFrame) { + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // A string larger than fits into a frame. + QuicFrame frame; + size_t payload_length = creator_.max_packet_length(); + const std::string too_long_payload(payload_length, 'a'); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, too_long_payload, 0u, true, false, NOT_RETRANSMISSION, + &frame)); + size_t consumed = frame.stream_frame.data_length; + // The entire payload could not be consumed. + EXPECT_GT(payload_length, consumed); + creator_.FlushCurrentPacket(); + DeleteSerializedPacket(); +} + +TEST_P(QuicPacketCreatorTest, AddFrameAndFlush) { + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + const size_t max_plaintext_size = + client_framer_.GetMaxPlaintextSize(creator_.max_packet_length()); + EXPECT_FALSE(creator_.HasPendingFrames()); + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + stream_id = + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()); + } + EXPECT_FALSE(creator_.HasPendingStreamFramesOfStream(stream_id)); + EXPECT_EQ(max_plaintext_size - + GetPacketHeaderSize( + client_framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), + 0, QuicPacketCreatorPeer::GetLengthLength(&creator_)), + creator_.BytesFree()); + StrictMock debug; + creator_.set_debug_delegate(&debug); + + // Add a variety of frame types and then a padding frame. + QuicAckFrame ack_frame(InitAckFrame(10u)); + EXPECT_CALL(debug, OnFrameAddedToPacket(_)); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(&ack_frame), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingStreamFramesOfStream(stream_id)); + + QuicFrame frame; + const std::string data("test"); + EXPECT_CALL(debug, OnFrameAddedToPacket(_)); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, NOT_RETRANSMISSION, &frame)); + size_t consumed = frame.stream_frame.data_length; + EXPECT_EQ(4u, consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingStreamFramesOfStream(stream_id)); + + QuicPaddingFrame padding_frame; + EXPECT_CALL(debug, OnFrameAddedToPacket(_)); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(padding_frame), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_EQ(0u, creator_.BytesFree()); + + // Packet is full. Creator will flush. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + EXPECT_FALSE(creator_.AddFrame(QuicFrame(&ack_frame), NOT_RETRANSMISSION)); + + // Ensure the packet is successfully created. + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->retransmittable_frames.empty()); + const QuicFrames& retransmittable = + serialized_packet_->retransmittable_frames; + ASSERT_EQ(1u, retransmittable.size()); + EXPECT_EQ(STREAM_FRAME, retransmittable[0].type); + EXPECT_TRUE(serialized_packet_->has_ack); + EXPECT_EQ(QuicPacketNumber(10u), serialized_packet_->largest_acked); + DeleteSerializedPacket(); + + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingStreamFramesOfStream(stream_id)); + EXPECT_EQ(max_plaintext_size - + GetPacketHeaderSize( + client_framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), + 0, QuicPacketCreatorPeer::GetLengthLength(&creator_)), + creator_.BytesFree()); +} + +TEST_P(QuicPacketCreatorTest, SerializeAndSendStreamFrame) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + EXPECT_FALSE(creator_.HasPendingFrames()); + + const std::string data("test"); + producer_.SaveStreamData(GetNthClientInitiatedStreamId(0), data); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + size_t num_bytes_consumed; + StrictMock debug; + creator_.set_debug_delegate(&debug); + EXPECT_CALL(debug, OnFrameAddedToPacket(_)); + creator_.CreateAndSerializeStreamFrame( + GetNthClientInitiatedStreamId(0), data.length(), 0, 0, true, + NOT_RETRANSMISSION, &num_bytes_consumed); + EXPECT_EQ(4u, num_bytes_consumed); + + // Ensure the packet is successfully created. + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->retransmittable_frames.empty()); + const QuicFrames& retransmittable = + serialized_packet_->retransmittable_frames; + ASSERT_EQ(1u, retransmittable.size()); + EXPECT_EQ(STREAM_FRAME, retransmittable[0].type); + DeleteSerializedPacket(); + + EXPECT_FALSE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, SerializeStreamFrameWithPadding) { + // Regression test to check that CreateAndSerializeStreamFrame uses a + // correctly formatted stream frame header when appending padding. + + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + EXPECT_FALSE(creator_.HasPendingFrames()); + + // Send zero bytes of stream data. This requires padding. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + size_t num_bytes_consumed; + creator_.CreateAndSerializeStreamFrame(GetNthClientInitiatedStreamId(0), 0, 0, + 0, true, NOT_RETRANSMISSION, + &num_bytes_consumed); + EXPECT_EQ(0u, num_bytes_consumed); + + // Check that a packet is created. + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->retransmittable_frames.empty()); + ASSERT_EQ(serialized_packet_->packet_number_length, + PACKET_1BYTE_PACKET_NUMBER); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + if (client_framer_.version().HasHeaderProtection()) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } else { + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(*serialized_packet_); +} + +TEST_P(QuicPacketCreatorTest, AddUnencryptedStreamDataClosesConnection) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration()) { + return; + } + + creator_.set_encryption_level(ENCRYPTION_INITIAL); + QuicStreamFrame stream_frame(GetNthClientInitiatedStreamId(0), + /*fin=*/false, 0u, absl::string_view()); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(delegate_, OnUnrecoverableError(_, _)); + creator_.AddFrame(QuicFrame(stream_frame), NOT_RETRANSMISSION); + }, + "Cannot send stream data with level: ENCRYPTION_INITIAL"); +} + +TEST_P(QuicPacketCreatorTest, SendStreamDataWithEncryptionHandshake) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration()) { + return; + } + + creator_.set_encryption_level(ENCRYPTION_HANDSHAKE); + QuicStreamFrame stream_frame(GetNthClientInitiatedStreamId(0), + /*fin=*/false, 0u, absl::string_view()); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(delegate_, OnUnrecoverableError(_, _)); + creator_.AddFrame(QuicFrame(stream_frame), NOT_RETRANSMISSION); + }, + "Cannot send stream data with level: ENCRYPTION_HANDSHAKE"); +} + +TEST_P(QuicPacketCreatorTest, ChloTooLarge) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (!IsDefaultTestConfiguration()) { + return; + } + + // This test only matters when the crypto handshake is sent in stream frames. + // TODO(b/128596274): Re-enable when this check is supported for CRYPTO + // frames. + if (QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + return; + } + + CryptoHandshakeMessage message; + message.set_tag(kCHLO); + message.set_minimum_size(kMaxOutgoingPacketSize); + CryptoFramer framer; + std::unique_ptr message_data; + message_data = framer.ConstructHandshakeMessage(message); + + QuicFrame frame; + EXPECT_CALL(delegate_, OnUnrecoverableError(QUIC_CRYPTO_CHLO_TOO_LARGE, _)); + EXPECT_QUIC_BUG( + creator_.ConsumeDataToFillCurrentPacket( + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), + absl::string_view(message_data->data(), message_data->length()), 0u, + false, false, NOT_RETRANSMISSION, &frame), + "Client hello won't fit in a single packet."); +} + +TEST_P(QuicPacketCreatorTest, PendingPadding) { + EXPECT_EQ(0u, creator_.pending_padding_bytes()); + creator_.AddPendingPadding(kMaxNumRandomPaddingBytes * 10); + EXPECT_EQ(kMaxNumRandomPaddingBytes * 10, creator_.pending_padding_bytes()); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + // Flush all paddings. + while (creator_.pending_padding_bytes() > 0) { + creator_.FlushCurrentPacket(); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + // Packet only contains padding. + ProcessPacket(*serialized_packet_); + } + EXPECT_EQ(0u, creator_.pending_padding_bytes()); +} + +TEST_P(QuicPacketCreatorTest, FullPaddingDoesNotConsumePendingPadding) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + creator_.AddPendingPadding(kMaxNumRandomPaddingBytes); + QuicFrame frame; + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + const std::string data("test"); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, + /*needs_full_padding=*/true, NOT_RETRANSMISSION, &frame)); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + EXPECT_EQ(kMaxNumRandomPaddingBytes, creator_.pending_padding_bytes()); +} + +TEST_P(QuicPacketCreatorTest, ConsumeDataAndRandomPadding) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + const QuicByteCount kStreamFramePayloadSize = 100u; + // Set the packet size be enough for one stream frame with 0 stream offset + + // 1. + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + size_t length = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead() + + QuicFramer::GetMinStreamFrameSize( + client_framer_.transport_version(), stream_id, 0, + /*last_frame_in_packet=*/true, kStreamFramePayloadSize + 1) + + kStreamFramePayloadSize + 1; + creator_.SetMaxPacketLength(length); + creator_.AddPendingPadding(kMaxNumRandomPaddingBytes); + QuicByteCount pending_padding_bytes = creator_.pending_padding_bytes(); + QuicFrame frame; + char buf[kStreamFramePayloadSize + 1] = {}; + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + // Send stream frame of size kStreamFramePayloadSize. + creator_.ConsumeDataToFillCurrentPacket( + stream_id, absl::string_view(buf, kStreamFramePayloadSize), 0u, false, + false, NOT_RETRANSMISSION, &frame); + creator_.FlushCurrentPacket(); + // 1 byte padding is sent. + EXPECT_EQ(pending_padding_bytes - 1, creator_.pending_padding_bytes()); + // Send stream frame of size kStreamFramePayloadSize + 1. + creator_.ConsumeDataToFillCurrentPacket( + stream_id, absl::string_view(buf, kStreamFramePayloadSize + 1), + kStreamFramePayloadSize, false, false, NOT_RETRANSMISSION, &frame); + // No padding is sent. + creator_.FlushCurrentPacket(); + EXPECT_EQ(pending_padding_bytes - 1, creator_.pending_padding_bytes()); + // Flush all paddings. + while (creator_.pending_padding_bytes() > 0) { + creator_.FlushCurrentPacket(); + } + EXPECT_EQ(0u, creator_.pending_padding_bytes()); +} + +TEST_P(QuicPacketCreatorTest, FlushWithExternalBuffer) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + char* buffer = new char[kMaxOutgoingPacketSize]; + QuicPacketBuffer external_buffer = {buffer, + [](const char* p) { delete[] p; }}; + EXPECT_CALL(delegate_, GetPacketBuffer()).WillOnce(Return(external_buffer)); + + QuicFrame frame; + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + const std::string data("test"); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, + /*needs_full_padding=*/true, NOT_RETRANSMISSION, &frame)); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke([&external_buffer](SerializedPacket serialized_packet) { + EXPECT_EQ(external_buffer.buffer, serialized_packet.encrypted_buffer); + })); + creator_.FlushCurrentPacket(); +} + +// Test for error found in +// https://bugs.chromium.org/p/chromium/issues/detail?id=859949 where a gap +// length that crosses an IETF VarInt length boundary would cause a +// failure. While this test is not applicable to versions other than version 99, +// it should still work. Hence, it is not made version-specific. +TEST_P(QuicPacketCreatorTest, IetfAckGapErrorRegression) { + QuicAckFrame ack_frame = + InitAckFrame({{QuicPacketNumber(60), QuicPacketNumber(61)}, + {QuicPacketNumber(125), QuicPacketNumber(126)}}); + frames_.push_back(QuicFrame(&ack_frame)); + SerializeAllFrames(frames_); +} + +TEST_P(QuicPacketCreatorTest, AddMessageFrame) { + if (!VersionSupportsMessageFrames(client_framer_.transport_version())) { + return; + } + if (client_framer_.version().UsesTls()) { + creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .Times(3) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorTest::ClearSerializedPacketForTests)); + // Verify that there is enough room for the largest message payload. + EXPECT_TRUE(creator_.HasRoomForMessageFrame( + creator_.GetCurrentLargestMessagePayload())); + std::string large_message(creator_.GetCurrentLargestMessagePayload(), 'a'); + QuicMessageFrame* message_frame = + new QuicMessageFrame(1, MemSliceFromString(large_message)); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(message_frame), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); + creator_.FlushCurrentPacket(); + + QuicMessageFrame* frame2 = + new QuicMessageFrame(2, MemSliceFromString("message")); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame2), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); + // Verify if a new frame is added, 1 byte message length will be added. + EXPECT_EQ(1u, creator_.ExpansionOnNewFrame()); + QuicMessageFrame* frame3 = + new QuicMessageFrame(3, MemSliceFromString("message2")); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame3), NOT_RETRANSMISSION)); + EXPECT_EQ(1u, creator_.ExpansionOnNewFrame()); + creator_.FlushCurrentPacket(); + + QuicFrame frame; + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + const std::string data("test"); + EXPECT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, NOT_RETRANSMISSION, &frame)); + QuicMessageFrame* frame4 = + new QuicMessageFrame(4, MemSliceFromString("message")); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame4), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); + // Verify there is not enough room for largest payload. + EXPECT_FALSE(creator_.HasRoomForMessageFrame( + creator_.GetCurrentLargestMessagePayload())); + // Add largest message will causes the flush of the stream frame. + QuicMessageFrame frame5(5, MemSliceFromString(large_message)); + EXPECT_FALSE(creator_.AddFrame(QuicFrame(&frame5), NOT_RETRANSMISSION)); + EXPECT_FALSE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, MessageFrameConsumption) { + if (!VersionSupportsMessageFrames(client_framer_.transport_version())) { + return; + } + if (client_framer_.version().UsesTls()) { + creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } + std::string message_data(kDefaultMaxPacketSize, 'a'); + // Test all possible encryption levels of message frames. + for (EncryptionLevel level : + {ENCRYPTION_ZERO_RTT, ENCRYPTION_FORWARD_SECURE}) { + creator_.set_encryption_level(level); + // Test all possible sizes of message frames. + for (size_t message_size = 0; + message_size <= creator_.GetCurrentLargestMessagePayload(); + ++message_size) { + QuicMessageFrame* frame = + new QuicMessageFrame(0, MemSliceFromString(absl::string_view( + message_data.data(), message_size))); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); + + size_t expansion_bytes = message_size >= 64 ? 2 : 1; + EXPECT_EQ(expansion_bytes, creator_.ExpansionOnNewFrame()); + // Verify BytesFree returns bytes available for the next frame, which + // should subtract the message length. + size_t expected_bytes_free = + creator_.GetCurrentLargestMessagePayload() - message_size < + expansion_bytes + ? 0 + : creator_.GetCurrentLargestMessagePayload() - expansion_bytes - + message_size; + EXPECT_EQ(expected_bytes_free, creator_.BytesFree()); + EXPECT_LE(creator_.GetGuaranteedLargestMessagePayload(), + creator_.GetCurrentLargestMessagePayload()); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + DeleteSerializedPacket(); + } + } +} + +TEST_P(QuicPacketCreatorTest, GetGuaranteedLargestMessagePayload) { + ParsedQuicVersion version = GetParam().version; + if (!version.SupportsMessageFrames()) { + return; + } + if (version.UsesTls()) { + creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } + QuicPacketLength expected_largest_payload = 1215; + if (version.HasLongHeaderLengths()) { + expected_largest_payload -= 2; + } + if (version.HasLengthPrefixedConnectionIds()) { + expected_largest_payload -= 1; + } + EXPECT_EQ(expected_largest_payload, + creator_.GetGuaranteedLargestMessagePayload()); + EXPECT_TRUE(creator_.HasRoomForMessageFrame( + creator_.GetGuaranteedLargestMessagePayload())); + + // Now test whether SetMaxDatagramFrameSize works. + creator_.SetMaxDatagramFrameSize(expected_largest_payload + 1 + + kQuicFrameTypeSize); + EXPECT_EQ(expected_largest_payload, + creator_.GetGuaranteedLargestMessagePayload()); + EXPECT_TRUE(creator_.HasRoomForMessageFrame( + creator_.GetGuaranteedLargestMessagePayload())); + + creator_.SetMaxDatagramFrameSize(expected_largest_payload + + kQuicFrameTypeSize); + EXPECT_EQ(expected_largest_payload, + creator_.GetGuaranteedLargestMessagePayload()); + EXPECT_TRUE(creator_.HasRoomForMessageFrame( + creator_.GetGuaranteedLargestMessagePayload())); + + creator_.SetMaxDatagramFrameSize(expected_largest_payload - 1 + + kQuicFrameTypeSize); + EXPECT_EQ(expected_largest_payload - 1, + creator_.GetGuaranteedLargestMessagePayload()); + EXPECT_TRUE(creator_.HasRoomForMessageFrame( + creator_.GetGuaranteedLargestMessagePayload())); + + constexpr QuicPacketLength kFrameSizeLimit = 1000; + constexpr QuicPacketLength kPayloadSizeLimit = + kFrameSizeLimit - kQuicFrameTypeSize; + creator_.SetMaxDatagramFrameSize(kFrameSizeLimit); + EXPECT_EQ(creator_.GetGuaranteedLargestMessagePayload(), kPayloadSizeLimit); + EXPECT_TRUE(creator_.HasRoomForMessageFrame(kPayloadSizeLimit)); + EXPECT_FALSE(creator_.HasRoomForMessageFrame(kPayloadSizeLimit + 1)); +} + +TEST_P(QuicPacketCreatorTest, GetCurrentLargestMessagePayload) { + ParsedQuicVersion version = GetParam().version; + if (!version.SupportsMessageFrames()) { + return; + } + if (version.UsesTls()) { + creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } + QuicPacketLength expected_largest_payload = 1215; + if (version.SendsVariableLengthPacketNumberInLongHeader()) { + expected_largest_payload += 3; + } + if (version.HasLongHeaderLengths()) { + expected_largest_payload -= 2; + } + if (version.HasLengthPrefixedConnectionIds()) { + expected_largest_payload -= 1; + } + EXPECT_EQ(expected_largest_payload, + creator_.GetCurrentLargestMessagePayload()); + + // Now test whether SetMaxDatagramFrameSize works. + creator_.SetMaxDatagramFrameSize(expected_largest_payload + 1 + + kQuicFrameTypeSize); + EXPECT_EQ(expected_largest_payload, + creator_.GetCurrentLargestMessagePayload()); + + creator_.SetMaxDatagramFrameSize(expected_largest_payload + + kQuicFrameTypeSize); + EXPECT_EQ(expected_largest_payload, + creator_.GetCurrentLargestMessagePayload()); + + creator_.SetMaxDatagramFrameSize(expected_largest_payload - 1 + + kQuicFrameTypeSize); + EXPECT_EQ(expected_largest_payload - 1, + creator_.GetCurrentLargestMessagePayload()); +} + +TEST_P(QuicPacketCreatorTest, PacketTransmissionType) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + QuicAckFrame temp_ack_frame = InitAckFrame(1); + QuicFrame ack_frame(&temp_ack_frame); + ASSERT_FALSE(QuicUtils::IsRetransmittableFrame(ack_frame.type)); + + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + QuicFrame stream_frame(QuicStreamFrame(stream_id, + /*fin=*/false, 0u, + absl::string_view())); + ASSERT_TRUE(QuicUtils::IsRetransmittableFrame(stream_frame.type)); + + QuicFrame stream_frame_2(QuicStreamFrame(stream_id, + /*fin=*/false, 1u, + absl::string_view())); + + QuicFrame padding_frame{QuicPaddingFrame()}; + ASSERT_FALSE(QuicUtils::IsRetransmittableFrame(padding_frame.type)); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + + EXPECT_TRUE(creator_.AddFrame(ack_frame, LOSS_RETRANSMISSION)); + ASSERT_EQ(serialized_packet_, nullptr); + + EXPECT_TRUE(creator_.AddFrame(stream_frame, PTO_RETRANSMISSION)); + ASSERT_EQ(serialized_packet_, nullptr); + + EXPECT_TRUE(creator_.AddFrame(stream_frame_2, PATH_RETRANSMISSION)); + ASSERT_EQ(serialized_packet_, nullptr); + + EXPECT_TRUE(creator_.AddFrame(padding_frame, PTO_RETRANSMISSION)); + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + + // The last retransmittable frame on packet is a stream frame, the packet's + // transmission type should be the same as the stream frame's. + EXPECT_EQ(serialized_packet_->transmission_type, PATH_RETRANSMISSION); + DeleteSerializedPacket(); +} + +TEST_P(QuicPacketCreatorTest, + PacketBytesRetransmitted_AddFrame_Retransmission) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + QuicAckFrame temp_ack_frame = InitAckFrame(1); + QuicFrame ack_frame(&temp_ack_frame); + EXPECT_TRUE(creator_.AddFrame(ack_frame, LOSS_RETRANSMISSION)); + + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + + QuicFrame stream_frame; + const std::string data("data"); + // ConsumeDataToFillCurrentPacket calls AddFrame + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, PTO_RETRANSMISSION, &stream_frame)); + EXPECT_EQ(4u, stream_frame.stream_frame.data_length); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->bytes_not_retransmitted.has_value()); + + DeleteSerializedPacket(); +} + +TEST_P(QuicPacketCreatorTest, + PacketBytesRetransmitted_AddFrame_NotRetransmission) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + QuicAckFrame temp_ack_frame = InitAckFrame(1); + QuicFrame ack_frame(&temp_ack_frame); + EXPECT_TRUE(creator_.AddFrame(ack_frame, NOT_RETRANSMISSION)); + + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + + QuicFrame stream_frame; + const std::string data("data"); + // ConsumeDataToFillCurrentPacket calls AddFrame + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, NOT_RETRANSMISSION, &stream_frame)); + EXPECT_EQ(4u, stream_frame.stream_frame.data_length); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->bytes_not_retransmitted.has_value()); + + DeleteSerializedPacket(); +} + +TEST_P(QuicPacketCreatorTest, PacketBytesRetransmitted_AddFrame_MixedFrames) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + QuicAckFrame temp_ack_frame = InitAckFrame(1); + QuicFrame ack_frame(&temp_ack_frame); + EXPECT_TRUE(creator_.AddFrame(ack_frame, NOT_RETRANSMISSION)); + + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + + QuicFrame stream_frame; + const std::string data("data"); + // ConsumeDataToFillCurrentPacket calls AddFrame + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, NOT_RETRANSMISSION, &stream_frame)); + EXPECT_EQ(4u, stream_frame.stream_frame.data_length); + + QuicFrame stream_frame2; + // ConsumeDataToFillCurrentPacket calls AddFrame + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id, data, 0u, false, false, LOSS_RETRANSMISSION, &stream_frame2)); + EXPECT_EQ(4u, stream_frame2.stream_frame.data_length); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + + creator_.FlushCurrentPacket(); + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_TRUE(serialized_packet_->bytes_not_retransmitted.has_value()); + ASSERT_GE(serialized_packet_->bytes_not_retransmitted.value(), 4u); + + DeleteSerializedPacket(); +} + +TEST_P(QuicPacketCreatorTest, + PacketBytesRetransmitted_CreateAndSerializeStreamFrame_Retransmission) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + const std::string data("test"); + producer_.SaveStreamData(GetNthClientInitiatedStreamId(0), data); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + size_t num_bytes_consumed; + // Retransmission frame adds to packet's bytes_retransmitted + creator_.CreateAndSerializeStreamFrame( + GetNthClientInitiatedStreamId(0), data.length(), 0, 0, true, + LOSS_RETRANSMISSION, &num_bytes_consumed); + EXPECT_EQ(4u, num_bytes_consumed); + + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->bytes_not_retransmitted.has_value()); + DeleteSerializedPacket(); + + EXPECT_FALSE(creator_.HasPendingFrames()); +} + +TEST_P( + QuicPacketCreatorTest, + PacketBytesRetransmitted_CreateAndSerializeStreamFrame_NotRetransmission) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + + const std::string data("test"); + producer_.SaveStreamData(GetNthClientInitiatedStreamId(0), data); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + size_t num_bytes_consumed; + // Non-retransmission frame does not add to packet's bytes_retransmitted + creator_.CreateAndSerializeStreamFrame( + GetNthClientInitiatedStreamId(0), data.length(), 0, 0, true, + NOT_RETRANSMISSION, &num_bytes_consumed); + EXPECT_EQ(4u, num_bytes_consumed); + + ASSERT_TRUE(serialized_packet_->encrypted_buffer); + ASSERT_FALSE(serialized_packet_->bytes_not_retransmitted.has_value()); + DeleteSerializedPacket(); + + EXPECT_FALSE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, RetryToken) { + if (!GetParam().version_serialization || + !QuicVersionHasLongHeaderLengths(client_framer_.transport_version())) { + return; + } + + char retry_token_bytes[] = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + + creator_.SetRetryToken( + std::string(retry_token_bytes, sizeof(retry_token_bytes))); + + frames_.push_back(QuicFrame(QuicPingFrame())); + SerializedPacket serialized = SerializeAllFrames(frames_); + + QuicPacketHeader header; + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)) + .WillOnce(DoAll(SaveArg<0>(&header), Return(true))); + if (client_framer_.version().HasHeaderProtection()) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPingFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized); + ASSERT_TRUE(header.version_flag); + ASSERT_EQ(header.long_packet_type, INITIAL); + ASSERT_EQ(header.retry_token.length(), sizeof(retry_token_bytes)); + quiche::test::CompareCharArraysWithHexError( + "retry token", header.retry_token.data(), header.retry_token.length(), + retry_token_bytes, sizeof(retry_token_bytes)); +} + +TEST_P(QuicPacketCreatorTest, GetConnectionId) { + EXPECT_EQ(TestConnectionId(2), creator_.GetDestinationConnectionId()); + EXPECT_EQ(EmptyQuicConnectionId(), creator_.GetSourceConnectionId()); +} + +TEST_P(QuicPacketCreatorTest, ClientConnectionId) { + if (!client_framer_.version().SupportsClientConnectionIds()) { + return; + } + EXPECT_EQ(TestConnectionId(2), creator_.GetDestinationConnectionId()); + EXPECT_EQ(EmptyQuicConnectionId(), creator_.GetSourceConnectionId()); + creator_.SetClientConnectionId(TestConnectionId(0x33)); + EXPECT_EQ(TestConnectionId(2), creator_.GetDestinationConnectionId()); + EXPECT_EQ(TestConnectionId(0x33), creator_.GetSourceConnectionId()); +} + +TEST_P(QuicPacketCreatorTest, CoalesceStreamFrames) { + InSequence s; + if (!GetParam().version_serialization) { + creator_.StopSendingVersion(); + } + const size_t max_plaintext_size = + client_framer_.GetMaxPlaintextSize(creator_.max_packet_length()); + EXPECT_FALSE(creator_.HasPendingFrames()); + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicStreamId stream_id1 = QuicUtils::GetFirstBidirectionalStreamId( + client_framer_.transport_version(), Perspective::IS_CLIENT); + QuicStreamId stream_id2 = GetNthClientInitiatedStreamId(1); + EXPECT_FALSE(creator_.HasPendingStreamFramesOfStream(stream_id1)); + EXPECT_EQ(max_plaintext_size - + GetPacketHeaderSize( + client_framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), + 0, QuicPacketCreatorPeer::GetLengthLength(&creator_)), + creator_.BytesFree()); + StrictMock debug; + creator_.set_debug_delegate(&debug); + + QuicFrame frame; + const std::string data1("test"); + EXPECT_CALL(debug, OnFrameAddedToPacket(_)); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id1, data1, 0u, false, false, NOT_RETRANSMISSION, &frame)); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingStreamFramesOfStream(stream_id1)); + + const std::string data2("coalesce"); + // frame will be coalesced with the first frame. + const auto previous_size = creator_.PacketSize(); + QuicStreamFrame target(stream_id1, true, 0, data1.length() + data2.length()); + EXPECT_CALL(debug, OnStreamFrameCoalesced(target)); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id1, data2, 4u, true, false, NOT_RETRANSMISSION, &frame)); + EXPECT_EQ(frame.stream_frame.data_length, + creator_.PacketSize() - previous_size); + + // frame is for another stream, so it won't be coalesced. + const auto length = creator_.BytesFree() - 10u; + const std::string data3(length, 'x'); + EXPECT_CALL(debug, OnFrameAddedToPacket(_)); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id2, data3, 0u, false, false, NOT_RETRANSMISSION, &frame)); + EXPECT_TRUE(creator_.HasPendingStreamFramesOfStream(stream_id2)); + + // The packet doesn't have enough free bytes for all data, but will still be + // able to consume and coalesce part of them. + EXPECT_CALL(debug, OnStreamFrameCoalesced(_)); + const std::string data4("somerandomdata"); + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + stream_id2, data4, length, false, false, NOT_RETRANSMISSION, &frame)); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + // The packet should only have 2 stream frames. + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + ProcessPacket(*serialized_packet_); +} + +TEST_P(QuicPacketCreatorTest, SaveNonRetransmittableFrames) { + QuicAckFrame ack_frame(InitAckFrame(1)); + frames_.push_back(QuicFrame(&ack_frame)); + frames_.push_back(QuicFrame(QuicPaddingFrame(-1))); + SerializedPacket serialized = SerializeAllFrames(frames_); + ASSERT_EQ(2u, serialized.nonretransmittable_frames.size()); + EXPECT_EQ(ACK_FRAME, serialized.nonretransmittable_frames[0].type); + EXPECT_EQ(PADDING_FRAME, serialized.nonretransmittable_frames[1].type); + // Verify full padding frame is translated to a padding frame with actual + // bytes of padding. + EXPECT_LT( + 0, + serialized.nonretransmittable_frames[1].padding_frame.num_padding_bytes); + frames_.clear(); + + // Serialize another packet with the same frames. + SerializedPacket packet = QuicPacketCreatorPeer::SerializeAllFrames( + &creator_, serialized.nonretransmittable_frames, buffer_, + kMaxOutgoingPacketSize); + // Verify the packet length of both packets are equal. + EXPECT_EQ(serialized.encrypted_length, packet.encrypted_length); +} + +TEST_P(QuicPacketCreatorTest, SerializeCoalescedPacket) { + QuicCoalescedPacket coalesced; + quiche::SimpleBufferAllocator allocator; + QuicSocketAddress self_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress peer_address(QuicIpAddress::Loopback4(), 2); + for (size_t i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) { + EncryptionLevel level = static_cast(i); + creator_.set_encryption_level(level); + QuicAckFrame ack_frame(InitAckFrame(1)); + if (level != ENCRYPTION_ZERO_RTT) { + frames_.push_back(QuicFrame(&ack_frame)); + } + if (level != ENCRYPTION_INITIAL && level != ENCRYPTION_HANDSHAKE) { + frames_.push_back( + QuicFrame(QuicStreamFrame(1, false, 0u, absl::string_view()))); + } + SerializedPacket serialized = SerializeAllFrames(frames_); + EXPECT_EQ(level, serialized.encryption_level); + frames_.clear(); + ASSERT_TRUE(coalesced.MaybeCoalescePacket(serialized, self_address, + peer_address, &allocator, + creator_.max_packet_length())); + } + char buffer[kMaxOutgoingPacketSize]; + size_t coalesced_length = creator_.SerializeCoalescedPacket( + coalesced, buffer, kMaxOutgoingPacketSize); + // Verify packet is padded to full. + ASSERT_EQ(coalesced.max_packet_length(), coalesced_length); + if (!QuicVersionHasLongHeaderLengths(server_framer_.transport_version())) { + return; + } + // Verify packet process. + std::unique_ptr packets[NUM_ENCRYPTION_LEVELS]; + packets[ENCRYPTION_INITIAL] = + std::make_unique(buffer, coalesced_length); + for (size_t i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; ++i) { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + if (i < ENCRYPTION_FORWARD_SECURE) { + // Save coalesced packet. + EXPECT_CALL(framer_visitor_, OnCoalescedPacket(_)) + .WillOnce(Invoke([i, &packets](const QuicEncryptedPacket& packet) { + packets[i + 1] = packet.Clone(); + })); + } + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + if (i != ENCRYPTION_ZERO_RTT) { + if (i != ENCRYPTION_INITIAL) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)) + .Times(testing::AtMost(1)); + } + EXPECT_CALL(framer_visitor_, OnAckFrameStart(_, _)) + .WillOnce(Return(true)); + EXPECT_CALL(framer_visitor_, + OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2))) + .WillOnce(Return(true)); + EXPECT_CALL(framer_visitor_, OnAckFrameEnd(_, _)).WillOnce(Return(true)); + } + if (i == ENCRYPTION_INITIAL) { + // Verify padding is added. + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } + if (i == ENCRYPTION_ZERO_RTT) { + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); + } + if (i != ENCRYPTION_INITIAL && i != ENCRYPTION_HANDSHAKE) { + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + server_framer_.ProcessPacket(*packets[i]); + } +} + +TEST_P(QuicPacketCreatorTest, SoftMaxPacketLength) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + QuicByteCount previous_max_packet_length = creator_.max_packet_length(); + const size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + QuicPacketCreator::MinPlaintextPacketSize( + client_framer_.version(), + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)) + + GetEncryptionOverhead(); + // Make sure a length which cannot accommodate header (includes header + // protection minimal length) gets rejected. + creator_.SetSoftMaxPacketLength(overhead - 1); + EXPECT_EQ(previous_max_packet_length, creator_.max_packet_length()); + + creator_.SetSoftMaxPacketLength(overhead); + EXPECT_EQ(overhead, creator_.max_packet_length()); + + // Verify creator has room for stream frame because max_packet_length_ gets + // restored. + ASSERT_TRUE(creator_.HasRoomForStreamFrame( + GetNthClientInitiatedStreamId(1), kMaxIetfVarInt, + std::numeric_limits::max())); + EXPECT_EQ(previous_max_packet_length, creator_.max_packet_length()); + + // Same for message frame. + if (VersionSupportsMessageFrames(client_framer_.transport_version())) { + creator_.SetSoftMaxPacketLength(overhead); + if (client_framer_.version().UsesTls()) { + creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } + // Verify GetCurrentLargestMessagePayload is based on the actual + // max_packet_length. + EXPECT_LT(1u, creator_.GetCurrentLargestMessagePayload()); + EXPECT_EQ(overhead, creator_.max_packet_length()); + ASSERT_TRUE(creator_.HasRoomForMessageFrame( + creator_.GetCurrentLargestMessagePayload())); + EXPECT_EQ(previous_max_packet_length, creator_.max_packet_length()); + } + + // Verify creator can consume crypto data because max_packet_length_ gets + // restored. + creator_.SetSoftMaxPacketLength(overhead); + EXPECT_EQ(overhead, creator_.max_packet_length()); + const std::string data = "crypto data"; + QuicFrame frame; + if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), data, + kOffset, false, true, NOT_RETRANSMISSION, &frame)); + size_t bytes_consumed = frame.stream_frame.data_length; + EXPECT_LT(0u, bytes_consumed); + } else { + producer_.SaveCryptoData(ENCRYPTION_INITIAL, kOffset, data); + ASSERT_TRUE(creator_.ConsumeCryptoDataToFillCurrentPacket( + ENCRYPTION_INITIAL, data.length(), kOffset, + /*needs_full_padding=*/true, NOT_RETRANSMISSION, &frame)); + size_t bytes_consumed = frame.crypto_frame->data_length; + EXPECT_LT(0u, bytes_consumed); + } + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + + // Verify ACK frame can be consumed. + creator_.SetSoftMaxPacketLength(overhead); + EXPECT_EQ(overhead, creator_.max_packet_length()); + QuicAckFrame ack_frame(InitAckFrame(10u)); + EXPECT_TRUE(creator_.AddFrame(QuicFrame(&ack_frame), NOT_RETRANSMISSION)); + EXPECT_TRUE(creator_.HasPendingFrames()); +} + +TEST_P(QuicPacketCreatorTest, + ChangingEncryptionLevelRemovesSoftMaxPacketLength) { + if (!client_framer_.version().CanSendCoalescedPackets()) { + return; + } + // First set encryption level to forward secure which has the shortest header. + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + const QuicByteCount previous_max_packet_length = creator_.max_packet_length(); + const size_t min_acceptable_packet_size = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + QuicPacketCreator::MinPlaintextPacketSize( + client_framer_.version(), + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_)) + + GetEncryptionOverhead(); + // Then set the soft max packet length to the lowest allowed value. + creator_.SetSoftMaxPacketLength(min_acceptable_packet_size); + // Make sure that the low value was accepted. + EXPECT_EQ(creator_.max_packet_length(), min_acceptable_packet_size); + // Now set the encryption level to handshake which increases the header size. + creator_.set_encryption_level(ENCRYPTION_HANDSHAKE); + // Make sure that adding a frame removes the the soft max packet length. + QuicAckFrame ack_frame(InitAckFrame(1)); + frames_.push_back(QuicFrame(&ack_frame)); + SerializedPacket serialized = SerializeAllFrames(frames_); + EXPECT_EQ(serialized.encryption_level, ENCRYPTION_HANDSHAKE); + EXPECT_EQ(creator_.max_packet_length(), previous_max_packet_length); +} + +TEST_P(QuicPacketCreatorTest, MinPayloadLength) { + ParsedQuicVersion version = client_framer_.version(); + for (QuicPacketNumberLength pn_length : + {PACKET_1BYTE_PACKET_NUMBER, PACKET_2BYTE_PACKET_NUMBER, + PACKET_3BYTE_PACKET_NUMBER, PACKET_4BYTE_PACKET_NUMBER}) { + if (!version.HasHeaderProtection()) { + EXPECT_EQ(creator_.MinPlaintextPacketSize(version, pn_length), 0); + } else { + EXPECT_EQ(creator_.MinPlaintextPacketSize(version, pn_length), + (version.UsesTls() ? 4 : 8) - pn_length); + } + } +} + +// A variant of StreamFrameConsumption that tests when expansion of the stream +// frame puts it at or over the max length, but the packet is supposed to be +// padded to max length. +TEST_P(QuicPacketCreatorTest, PadWhenAlmostMaxLength) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Compute the total overhead for a single frame in packet. + const size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead() + + GetStreamFrameOverhead(client_framer_.transport_version()); + size_t capacity = kDefaultMaxPacketSize - overhead; + for (size_t bytes_free = 1; bytes_free <= 2; bytes_free++) { + std::string data(capacity - bytes_free, 'A'); + + QuicFrame frame; + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + GetNthClientInitiatedStreamId(1), data, kOffset, false, + /*needs_full_padding=*/true, NOT_RETRANSMISSION, &frame)); + + // BytesFree() returns bytes available for the next frame, which will + // be two bytes smaller since the stream frame would need to be grown. + EXPECT_EQ(2u, creator_.ExpansionOnNewFrame()); + EXPECT_EQ(0u, creator_.BytesFree()); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + EXPECT_EQ(serialized_packet_->encrypted_length, kDefaultMaxPacketSize); + DeleteSerializedPacket(); + } +} + +TEST_P(QuicPacketCreatorTest, MorePendingPaddingThanBytesFree) { + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Compute the total overhead for a single frame in packet. + const size_t overhead = + GetPacketHeaderOverhead(client_framer_.transport_version()) + + GetEncryptionOverhead() + + GetStreamFrameOverhead(client_framer_.transport_version()); + size_t capacity = kDefaultMaxPacketSize - overhead; + const size_t pending_padding = 10; + std::string data(capacity - pending_padding, 'A'); + QuicFrame frame; + // The stream frame means that BytesFree() will be less than the + // available space, because of the frame length field. + ASSERT_TRUE(creator_.ConsumeDataToFillCurrentPacket( + GetNthClientInitiatedStreamId(1), data, kOffset, false, + /*needs_full_padding=*/false, NOT_RETRANSMISSION, &frame)); + creator_.AddPendingPadding(pending_padding); + EXPECT_EQ(2u, creator_.ExpansionOnNewFrame()); + // BytesFree() does not know about pending_padding because that's added + // when flushed. + EXPECT_EQ(pending_padding - 2u, creator_.BytesFree()); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke(this, &QuicPacketCreatorTest::SaveSerializedPacket)); + creator_.FlushCurrentPacket(); + /* Without the fix, the packet is not full-length. */ + EXPECT_EQ(serialized_packet_->encrypted_length, kDefaultMaxPacketSize); + DeleteSerializedPacket(); +} + +class MockDelegate : public QuicPacketCreator::DelegateInterface { + public: + MockDelegate() {} + MockDelegate(const MockDelegate&) = delete; + MockDelegate& operator=(const MockDelegate&) = delete; + ~MockDelegate() override {} + + MOCK_METHOD(bool, ShouldGeneratePacket, + (HasRetransmittableData retransmittable, IsHandshake handshake), + (override)); + MOCK_METHOD(const QuicFrames, MaybeBundleAckOpportunistically, (), + (override)); + MOCK_METHOD(QuicPacketBuffer, GetPacketBuffer, (), (override)); + MOCK_METHOD(void, OnSerializedPacket, (SerializedPacket), (override)); + MOCK_METHOD(void, OnUnrecoverableError, (QuicErrorCode, const std::string&), + (override)); + MOCK_METHOD(SerializedPacketFate, GetSerializedPacketFate, + (bool, EncryptionLevel), (override)); + + void SetCanWriteAnything() { + EXPECT_CALL(*this, ShouldGeneratePacket(_, _)).WillRepeatedly(Return(true)); + EXPECT_CALL(*this, ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, _)) + .WillRepeatedly(Return(true)); + } + + void SetCanNotWrite() { + EXPECT_CALL(*this, ShouldGeneratePacket(_, _)) + .WillRepeatedly(Return(false)); + EXPECT_CALL(*this, ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, _)) + .WillRepeatedly(Return(false)); + } + + // Use this when only ack frames should be allowed to be written. + void SetCanWriteOnlyNonRetransmittable() { + EXPECT_CALL(*this, ShouldGeneratePacket(_, _)) + .WillRepeatedly(Return(false)); + EXPECT_CALL(*this, ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, _)) + .WillRepeatedly(Return(true)); + } +}; + +// Simple struct for describing the contents of a packet. +// Useful in conjunction with a SimpleQuicFrame for validating that a packet +// contains the expected frames. +struct PacketContents { + PacketContents() + : num_ack_frames(0), + num_connection_close_frames(0), + num_goaway_frames(0), + num_rst_stream_frames(0), + num_stop_waiting_frames(0), + num_stream_frames(0), + num_crypto_frames(0), + num_ping_frames(0), + num_mtu_discovery_frames(0), + num_padding_frames(0) {} + + size_t num_ack_frames; + size_t num_connection_close_frames; + size_t num_goaway_frames; + size_t num_rst_stream_frames; + size_t num_stop_waiting_frames; + size_t num_stream_frames; + size_t num_crypto_frames; + size_t num_ping_frames; + size_t num_mtu_discovery_frames; + size_t num_padding_frames; +}; + +class MultiplePacketsTestPacketCreator : public QuicPacketCreator { + public: + MultiplePacketsTestPacketCreator( + QuicConnectionId connection_id, QuicFramer* framer, + QuicRandom* random_generator, + QuicPacketCreator::DelegateInterface* delegate, + SimpleDataProducer* producer) + : QuicPacketCreator(connection_id, framer, random_generator, delegate), + ack_frame_(InitAckFrame(1)), + delegate_(static_cast(delegate)), + producer_(producer) {} + + bool ConsumeRetransmittableControlFrame(const QuicFrame& frame, + bool bundle_ack) { + if (!has_ack()) { + QuicFrames frames; + if (bundle_ack) { + frames.push_back(QuicFrame(&ack_frame_)); + } + if (delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + EXPECT_CALL(*delegate_, MaybeBundleAckOpportunistically()) + .WillOnce(Return(frames)); + } + } + return QuicPacketCreator::ConsumeRetransmittableControlFrame(frame); + } + + QuicConsumedData ConsumeDataFastPath(QuicStreamId id, + absl::string_view data) { + // Save data before data is consumed. + if (!data.empty()) { + producer_->SaveStreamData(id, data); + } + return QuicPacketCreator::ConsumeDataFastPath(id, data.length(), + /* offset = */ 0, + /* fin = */ true, 0); + } + + QuicConsumedData ConsumeData(QuicStreamId id, absl::string_view data, + QuicStreamOffset offset, + StreamSendingState state) { + // Save data before data is consumed. + if (!data.empty()) { + producer_->SaveStreamData(id, data); + } + if (!has_ack() && delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + EXPECT_CALL(*delegate_, MaybeBundleAckOpportunistically()).Times(1); + } + return QuicPacketCreator::ConsumeData(id, data.length(), offset, state); + } + + MessageStatus AddMessageFrame(QuicMessageId message_id, + quiche::QuicheMemSlice message) { + if (!has_ack() && delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + EXPECT_CALL(*delegate_, MaybeBundleAckOpportunistically()).Times(1); + } + return QuicPacketCreator::AddMessageFrame(message_id, + absl::MakeSpan(&message, 1)); + } + + size_t ConsumeCryptoData(EncryptionLevel level, absl::string_view data, + QuicStreamOffset offset) { + producer_->SaveCryptoData(level, offset, data); + if (!has_ack() && delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, + NOT_HANDSHAKE)) { + EXPECT_CALL(*delegate_, MaybeBundleAckOpportunistically()).Times(1); + } + return QuicPacketCreator::ConsumeCryptoData(level, data.length(), offset); + } + + QuicAckFrame ack_frame_; + MockDelegate* delegate_; + SimpleDataProducer* producer_; +}; + +class QuicPacketCreatorMultiplePacketsTest : public QuicTest { + public: + QuicPacketCreatorMultiplePacketsTest() + : framer_(AllSupportedVersions(), QuicTime::Zero(), + Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength), + creator_(TestConnectionId(), &framer_, &random_creator_, &delegate_, + &producer_), + ack_frame_(InitAckFrame(1)) { + EXPECT_CALL(delegate_, GetPacketBuffer()) + .WillRepeatedly(Return(QuicPacketBuffer())); + EXPECT_CALL(delegate_, GetSerializedPacketFate(_, _)) + .WillRepeatedly(Return(SEND_TO_WRITER)); + creator_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE)); + creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + framer_.set_data_producer(&producer_); + if (simple_framer_.framer()->version().KnowsWhichDecrypterToUse()) { + simple_framer_.framer()->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, std::make_unique()); + } + creator_.AttachPacketFlusher(); + } + + ~QuicPacketCreatorMultiplePacketsTest() override {} + + void SavePacket(SerializedPacket packet) { + QUICHE_DCHECK(packet.release_encrypted_buffer == nullptr); + packet.encrypted_buffer = CopyBuffer(packet); + packet.release_encrypted_buffer = [](const char* p) { delete[] p; }; + packets_.push_back(std::move(packet)); + } + + protected: + QuicRstStreamFrame* CreateRstStreamFrame() { + return new QuicRstStreamFrame(1, 1, QUIC_STREAM_NO_ERROR, 0); + } + + QuicGoAwayFrame* CreateGoAwayFrame() { + return new QuicGoAwayFrame(2, QUIC_NO_ERROR, 1, std::string()); + } + + void CheckPacketContains(const PacketContents& contents, + size_t packet_index) { + ASSERT_GT(packets_.size(), packet_index); + const SerializedPacket& packet = packets_[packet_index]; + size_t num_retransmittable_frames = + contents.num_connection_close_frames + contents.num_goaway_frames + + contents.num_rst_stream_frames + contents.num_stream_frames + + contents.num_crypto_frames + contents.num_ping_frames; + size_t num_frames = + contents.num_ack_frames + contents.num_stop_waiting_frames + + contents.num_mtu_discovery_frames + contents.num_padding_frames + + num_retransmittable_frames; + + if (num_retransmittable_frames == 0) { + ASSERT_TRUE(packet.retransmittable_frames.empty()); + } else { + EXPECT_EQ(num_retransmittable_frames, + packet.retransmittable_frames.size()); + } + + ASSERT_TRUE(packet.encrypted_buffer != nullptr); + ASSERT_TRUE(simple_framer_.ProcessPacket( + QuicEncryptedPacket(packet.encrypted_buffer, packet.encrypted_length))); + size_t num_padding_frames = 0; + if (contents.num_padding_frames == 0) { + num_padding_frames = simple_framer_.padding_frames().size(); + } + EXPECT_EQ(num_frames + num_padding_frames, simple_framer_.num_frames()); + EXPECT_EQ(contents.num_ack_frames, simple_framer_.ack_frames().size()); + EXPECT_EQ(contents.num_connection_close_frames, + simple_framer_.connection_close_frames().size()); + EXPECT_EQ(contents.num_goaway_frames, + simple_framer_.goaway_frames().size()); + EXPECT_EQ(contents.num_rst_stream_frames, + simple_framer_.rst_stream_frames().size()); + EXPECT_EQ(contents.num_stream_frames, + simple_framer_.stream_frames().size()); + EXPECT_EQ(contents.num_crypto_frames, + simple_framer_.crypto_frames().size()); + EXPECT_EQ(contents.num_stop_waiting_frames, + simple_framer_.stop_waiting_frames().size()); + if (contents.num_padding_frames != 0) { + EXPECT_EQ(contents.num_padding_frames, + simple_framer_.padding_frames().size()); + } + + // From the receiver's perspective, MTU discovery frames are ping frames. + EXPECT_EQ(contents.num_ping_frames + contents.num_mtu_discovery_frames, + simple_framer_.ping_frames().size()); + } + + void CheckPacketHasSingleStreamFrame(size_t packet_index) { + ASSERT_GT(packets_.size(), packet_index); + const SerializedPacket& packet = packets_[packet_index]; + ASSERT_FALSE(packet.retransmittable_frames.empty()); + EXPECT_EQ(1u, packet.retransmittable_frames.size()); + ASSERT_TRUE(packet.encrypted_buffer != nullptr); + ASSERT_TRUE(simple_framer_.ProcessPacket( + QuicEncryptedPacket(packet.encrypted_buffer, packet.encrypted_length))); + EXPECT_EQ(1u, simple_framer_.num_frames()); + EXPECT_EQ(1u, simple_framer_.stream_frames().size()); + } + + void CheckAllPacketsHaveSingleStreamFrame() { + for (size_t i = 0; i < packets_.size(); i++) { + CheckPacketHasSingleStreamFrame(i); + } + } + + QuicFramer framer_; + MockRandom random_creator_; + StrictMock delegate_; + MultiplePacketsTestPacketCreator creator_; + SimpleQuicFramer simple_framer_; + std::vector packets_; + QuicAckFrame ack_frame_; + struct iovec iov_; + quiche::SimpleBufferAllocator allocator_; + + private: + std::unique_ptr data_array_; + SimpleDataProducer producer_; +}; + +TEST_F(QuicPacketCreatorMultiplePacketsTest, AddControlFrame_NotWritable) { + delegate_.SetCanNotWrite(); + + QuicRstStreamFrame* rst_frame = CreateRstStreamFrame(); + const bool consumed = + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/false); + EXPECT_FALSE(consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + delete rst_frame; +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + WrongEncryptionLevelForStreamDataFastPath) { + creator_.set_encryption_level(ENCRYPTION_HANDSHAKE); + delegate_.SetCanWriteAnything(); + const std::string data(10000, '?'); + EXPECT_CALL(delegate_, OnSerializedPacket(_)).Times(0); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(delegate_, OnUnrecoverableError(_, _)); + creator_.ConsumeDataFastPath( + QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT), + data); + }, + ""); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, AddControlFrame_OnlyAckWritable) { + delegate_.SetCanWriteOnlyNonRetransmittable(); + + QuicRstStreamFrame* rst_frame = CreateRstStreamFrame(); + const bool consumed = + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/false); + EXPECT_FALSE(consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + delete rst_frame; +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + AddControlFrame_WritableAndShouldNotFlush) { + delegate_.SetCanWriteAnything(); + + creator_.ConsumeRetransmittableControlFrame(QuicFrame(CreateRstStreamFrame()), + /*bundle_ack=*/false); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + AddControlFrame_NotWritableBatchThenFlush) { + delegate_.SetCanNotWrite(); + + QuicRstStreamFrame* rst_frame = CreateRstStreamFrame(); + const bool consumed = + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/false); + EXPECT_FALSE(consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + delete rst_frame; +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + AddControlFrame_WritableAndShouldFlush) { + delegate_.SetCanWriteAnything(); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + creator_.ConsumeRetransmittableControlFrame(QuicFrame(CreateRstStreamFrame()), + /*bundle_ack=*/false); + creator_.Flush(); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_rst_stream_frames = 1; + CheckPacketContains(contents, 0); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeCryptoData) { + delegate_.SetCanWriteAnything(); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + std::string data = "crypto data"; + size_t consumed_bytes = + creator_.ConsumeCryptoData(ENCRYPTION_INITIAL, data, 0); + creator_.Flush(); + EXPECT_EQ(data.length(), consumed_bytes); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_crypto_frames = 1; + contents.num_padding_frames = 1; + CheckPacketContains(contents, 0); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConsumeCryptoDataCheckShouldGeneratePacket) { + delegate_.SetCanNotWrite(); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)).Times(0); + std::string data = "crypto data"; + size_t consumed_bytes = + creator_.ConsumeCryptoData(ENCRYPTION_INITIAL, data, 0); + creator_.Flush(); + EXPECT_EQ(0u, consumed_bytes); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeData_NotWritable) { + delegate_.SetCanNotWrite(); + + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, FIN); + EXPECT_EQ(0u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConsumeData_WritableAndShouldNotFlush) { + delegate_.SetCanWriteAnything(); + + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConsumeData_WritableAndShouldFlush) { + delegate_.SetCanWriteAnything(); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, FIN); + creator_.Flush(); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); +} + +// Test the behavior of ConsumeData when the data consumed is for the crypto +// handshake stream. Ensure that the packet is always sent and padded even if +// the creator operates in batch mode. +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeData_Handshake) { + delegate_.SetCanWriteAnything(); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + const std::string data = "foo bar"; + size_t consumed_bytes = 0; + if (QuicVersionUsesCryptoFrames(framer_.transport_version())) { + consumed_bytes = creator_.ConsumeCryptoData(ENCRYPTION_INITIAL, data, 0); + } else { + consumed_bytes = + creator_ + .ConsumeData( + QuicUtils::GetCryptoStreamId(framer_.transport_version()), data, + 0, NO_FIN) + .bytes_consumed; + } + EXPECT_EQ(7u, consumed_bytes); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + if (QuicVersionUsesCryptoFrames(framer_.transport_version())) { + contents.num_crypto_frames = 1; + } else { + contents.num_stream_frames = 1; + } + contents.num_padding_frames = 1; + CheckPacketContains(contents, 0); + + ASSERT_EQ(1u, packets_.size()); + ASSERT_EQ(kDefaultMaxPacketSize, creator_.max_packet_length()); + EXPECT_EQ(kDefaultMaxPacketSize, packets_[0].encrypted_length); +} + +// Test the behavior of ConsumeData when the data is for the crypto handshake +// stream, but padding is disabled. +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConsumeData_Handshake_PaddingDisabled) { + creator_.set_fully_pad_crypto_handshake_packets(false); + + delegate_.SetCanWriteAnything(); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + const std::string data = "foo"; + size_t bytes_consumed = 0; + if (QuicVersionUsesCryptoFrames(framer_.transport_version())) { + bytes_consumed = creator_.ConsumeCryptoData(ENCRYPTION_INITIAL, data, 0); + } else { + bytes_consumed = + creator_ + .ConsumeData( + QuicUtils::GetCryptoStreamId(framer_.transport_version()), data, + 0, NO_FIN) + .bytes_consumed; + } + EXPECT_EQ(3u, bytes_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + if (QuicVersionUsesCryptoFrames(framer_.transport_version())) { + contents.num_crypto_frames = 1; + } else { + contents.num_stream_frames = 1; + } + contents.num_padding_frames = 0; + CheckPacketContains(contents, 0); + + ASSERT_EQ(1u, packets_.size()); + + // Packet is not fully padded, but we want to future packets to be larger. + ASSERT_EQ(kDefaultMaxPacketSize, creator_.max_packet_length()); + size_t expected_packet_length = 31; + if (QuicVersionUsesCryptoFrames(framer_.transport_version())) { + // The framing of CRYPTO frames is slightly different than that of stream + // frames, so the expected packet length differs slightly. + expected_packet_length = 32; + } + EXPECT_EQ(expected_packet_length, packets_[0].encrypted_length); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeData_EmptyData) { + delegate_.SetCanWriteAnything(); + + EXPECT_QUIC_BUG(creator_.ConsumeData( + QuicUtils::QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT), + {}, 0, NO_FIN), + "Attempt to consume empty data without FIN."); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConsumeDataMultipleTimes_WritableAndShouldNotFlush) { + delegate_.SetCanWriteAnything(); + + creator_.ConsumeData(QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT), + "foo", 0, FIN); + QuicConsumedData consumed = creator_.ConsumeData(3, "quux", 3, NO_FIN); + EXPECT_EQ(4u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeData_BatchOperations) { + delegate_.SetCanWriteAnything(); + + creator_.ConsumeData(QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT), + "foo", 0, NO_FIN); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + "quux", 3, FIN); + EXPECT_EQ(4u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + // Now both frames will be flushed out. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + creator_.Flush(); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConsumeData_FramesPreviouslyQueued) { + // Set the packet size be enough for two stream frames with 0 stream offset, + // but not enough for a stream frame of 0 offset and one with non-zero offset. + size_t length = + TaggingEncrypter(0x00).GetCiphertextSize(0) + + GetPacketHeaderSize( + framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), 0, + QuicPacketCreatorPeer::GetLengthLength(&creator_)) + + // Add an extra 3 bytes for the payload and 1 byte so + // BytesFree is larger than the GetMinStreamFrameSize. + QuicFramer::GetMinStreamFrameSize(framer_.transport_version(), 1, 0, + false, 3) + + 3 + + QuicFramer::GetMinStreamFrameSize(framer_.transport_version(), 1, 0, true, + 1) + + 1; + creator_.SetMaxPacketLength(length); + delegate_.SetCanWriteAnything(); + { + InSequence dummy; + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + } + // Queue enough data to prevent a stream frame with a non-zero offset from + // fitting. + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, NO_FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + // This frame will not fit with the existing frame, causing the queued frame + // to be serialized, and it will be added to a new open packet. + consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + "bar", 3, FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + creator_.FlushCurrentPacket(); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); + CheckPacketContains(contents, 1); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeDataFastPath) { + delegate_.SetCanWriteAnything(); + creator_.SetTransmissionType(LOSS_RETRANSMISSION); + + const std::string data(10000, '?'); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeDataFastPath( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data); + EXPECT_EQ(10000u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); + EXPECT_FALSE(packets_.empty()); + SerializedPacket& packet = packets_.back(); + EXPECT_TRUE(!packet.retransmittable_frames.empty()); + EXPECT_EQ(LOSS_RETRANSMISSION, packet.transmission_type); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + const QuicStreamFrame& stream_frame = + packet.retransmittable_frames.front().stream_frame; + EXPECT_EQ(10000u, stream_frame.data_length + stream_frame.offset); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeDataLarge) { + delegate_.SetCanWriteAnything(); + + const std::string data(10000, '?'); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, 0, FIN); + EXPECT_EQ(10000u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + PacketContents contents; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); + EXPECT_FALSE(packets_.empty()); + SerializedPacket& packet = packets_.back(); + EXPECT_TRUE(!packet.retransmittable_frames.empty()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + const QuicStreamFrame& stream_frame = + packet.retransmittable_frames.front().stream_frame; + EXPECT_EQ(10000u, stream_frame.data_length + stream_frame.offset); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeDataLargeSendAckFalse) { + delegate_.SetCanNotWrite(); + + QuicRstStreamFrame* rst_frame = CreateRstStreamFrame(); + const bool success = + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/true); + EXPECT_FALSE(success); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + delegate_.SetCanWriteAnything(); + + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/false); + + const std::string data(10000, '?'); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + creator_.ConsumeRetransmittableControlFrame(QuicFrame(CreateRstStreamFrame()), + /*bundle_ack=*/true); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, 0, FIN); + creator_.Flush(); + + EXPECT_EQ(10000u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + EXPECT_FALSE(packets_.empty()); + SerializedPacket& packet = packets_.back(); + EXPECT_TRUE(!packet.retransmittable_frames.empty()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + const QuicStreamFrame& stream_frame = + packet.retransmittable_frames.front().stream_frame; + EXPECT_EQ(10000u, stream_frame.data_length + stream_frame.offset); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConsumeDataLargeSendAckTrue) { + delegate_.SetCanNotWrite(); + delegate_.SetCanWriteAnything(); + + const std::string data(10000, '?'); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, 0, FIN); + creator_.Flush(); + + EXPECT_EQ(10000u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + EXPECT_FALSE(packets_.empty()); + SerializedPacket& packet = packets_.back(); + EXPECT_TRUE(!packet.retransmittable_frames.empty()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + const QuicStreamFrame& stream_frame = + packet.retransmittable_frames.front().stream_frame; + EXPECT_EQ(10000u, stream_frame.data_length + stream_frame.offset); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, NotWritableThenBatchOperations) { + delegate_.SetCanNotWrite(); + + QuicRstStreamFrame* rst_frame = CreateRstStreamFrame(); + const bool consumed = + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/true); + EXPECT_FALSE(consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + EXPECT_FALSE(creator_.HasPendingStreamFramesOfStream(3)); + + delegate_.SetCanWriteAnything(); + + EXPECT_TRUE( + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/false)); + // Send some data and a control frame + creator_.ConsumeData(3, "quux", 0, NO_FIN); + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + creator_.ConsumeRetransmittableControlFrame(QuicFrame(CreateGoAwayFrame()), + /*bundle_ack=*/false); + } + EXPECT_TRUE(creator_.HasPendingStreamFramesOfStream(3)); + + // All five frames will be flushed out in a single packet. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + creator_.Flush(); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + EXPECT_FALSE(creator_.HasPendingStreamFramesOfStream(3)); + + PacketContents contents; + // ACK will be flushed by connection. + contents.num_ack_frames = 0; + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + contents.num_goaway_frames = 1; + } else { + contents.num_goaway_frames = 0; + } + contents.num_rst_stream_frames = 1; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, NotWritableThenBatchOperations2) { + delegate_.SetCanNotWrite(); + + QuicRstStreamFrame* rst_frame = CreateRstStreamFrame(); + const bool success = + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/true); + EXPECT_FALSE(success); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + delegate_.SetCanWriteAnything(); + + { + InSequence dummy; + // All five frames will be flushed out in a single packet + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + } + EXPECT_TRUE( + creator_.ConsumeRetransmittableControlFrame(QuicFrame(rst_frame), + /*bundle_ack=*/false)); + // Send enough data to exceed one packet + size_t data_len = kDefaultMaxPacketSize + 100; + const std::string data(data_len, '?'); + QuicConsumedData consumed = creator_.ConsumeData(3, data, 0, FIN); + EXPECT_EQ(data_len, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + creator_.ConsumeRetransmittableControlFrame(QuicFrame(CreateGoAwayFrame()), + /*bundle_ack=*/false); + } + + creator_.Flush(); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // The first packet should have the queued data and part of the stream data. + PacketContents contents; + // ACK will be sent by connection. + contents.num_ack_frames = 0; + contents.num_rst_stream_frames = 1; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); + + // The second should have the remainder of the stream data. + PacketContents contents2; + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + contents2.num_goaway_frames = 1; + } else { + contents2.num_goaway_frames = 0; + } + contents2.num_stream_frames = 1; + CheckPacketContains(contents2, 1); +} + +// Regression test of b/120493795. +TEST_F(QuicPacketCreatorMultiplePacketsTest, PacketTransmissionType) { + delegate_.SetCanWriteAnything(); + + // The first ConsumeData will fill the packet without flush. + creator_.SetTransmissionType(LOSS_RETRANSMISSION); + + size_t data_len = 1220; + const std::string data(data_len, '?'); + QuicStreamId stream1_id = QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT); + QuicConsumedData consumed = creator_.ConsumeData(stream1_id, data, 0, NO_FIN); + EXPECT_EQ(data_len, consumed.bytes_consumed); + ASSERT_EQ(0u, creator_.BytesFree()) + << "Test setup failed: Please increase data_len to " + << data_len + creator_.BytesFree() << " bytes."; + + // The second ConsumeData can not be added to the packet and will flush. + creator_.SetTransmissionType(NOT_RETRANSMISSION); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + QuicStreamId stream2_id = stream1_id + 4; + + consumed = creator_.ConsumeData(stream2_id, data, 0, NO_FIN); + EXPECT_EQ(data_len, consumed.bytes_consumed); + + // Ensure the packet is successfully created. + ASSERT_EQ(1u, packets_.size()); + ASSERT_TRUE(packets_[0].encrypted_buffer); + ASSERT_EQ(1u, packets_[0].retransmittable_frames.size()); + EXPECT_EQ(stream1_id, + packets_[0].retransmittable_frames[0].stream_frame.stream_id); + + // Since the second frame was not added, the packet's transmission type + // should be the first frame's type. + EXPECT_EQ(packets_[0].transmission_type, LOSS_RETRANSMISSION); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, TestConnectionIdLength) { + QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_SERVER); + creator_.SetServerConnectionIdLength(0); + EXPECT_EQ(0, creator_.GetDestinationConnectionIdLength()); + + for (size_t i = 1; i < 10; i++) { + creator_.SetServerConnectionIdLength(i); + if (framer_.version().HasIetfInvariantHeader()) { + EXPECT_EQ(0, creator_.GetDestinationConnectionIdLength()); + } else { + EXPECT_EQ(8, creator_.GetDestinationConnectionIdLength()); + } + } +} + +// Test whether SetMaxPacketLength() works in the situation when the queue is +// empty, and we send three packets worth of data. +TEST_F(QuicPacketCreatorMultiplePacketsTest, SetMaxPacketLength_Initial) { + delegate_.SetCanWriteAnything(); + + // Send enough data for three packets. + size_t data_len = 3 * kDefaultMaxPacketSize + 1; + size_t packet_len = kDefaultMaxPacketSize + 100; + ASSERT_LE(packet_len, kMaxOutgoingPacketSize); + creator_.SetMaxPacketLength(packet_len); + EXPECT_EQ(packet_len, creator_.max_packet_length()); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .Times(3) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + const std::string data(data_len, '?'); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, + /*offset=*/0, FIN); + EXPECT_EQ(data_len, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // We expect three packets, and first two of them have to be of packet_len + // size. We check multiple packets (instead of just one) because we want to + // ensure that |max_packet_length_| does not get changed incorrectly by the + // creator after first packet is serialized. + ASSERT_EQ(3u, packets_.size()); + EXPECT_EQ(packet_len, packets_[0].encrypted_length); + EXPECT_EQ(packet_len, packets_[1].encrypted_length); + CheckAllPacketsHaveSingleStreamFrame(); +} + +// Test whether SetMaxPacketLength() works in the situation when we first write +// data, then change packet size, then write data again. +TEST_F(QuicPacketCreatorMultiplePacketsTest, SetMaxPacketLength_Middle) { + delegate_.SetCanWriteAnything(); + + // We send enough data to overflow default packet length, but not the altered + // one. + size_t data_len = kDefaultMaxPacketSize; + size_t packet_len = kDefaultMaxPacketSize + 100; + ASSERT_LE(packet_len, kMaxOutgoingPacketSize); + + // We expect to see three packets in total. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .Times(3) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + // Send two packets before packet size change. + const std::string data(data_len, '?'); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, + /*offset=*/0, NO_FIN); + creator_.Flush(); + EXPECT_EQ(data_len, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // Make sure we already have two packets. + ASSERT_EQ(2u, packets_.size()); + + // Increase packet size. + creator_.SetMaxPacketLength(packet_len); + EXPECT_EQ(packet_len, creator_.max_packet_length()); + + // Send a packet after packet size change. + creator_.AttachPacketFlusher(); + consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, data_len, FIN); + creator_.Flush(); + EXPECT_EQ(data_len, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // We expect first data chunk to get fragmented, but the second one to fit + // into a single packet. + ASSERT_EQ(3u, packets_.size()); + EXPECT_EQ(kDefaultMaxPacketSize, packets_[0].encrypted_length); + EXPECT_LE(kDefaultMaxPacketSize, packets_[2].encrypted_length); + CheckAllPacketsHaveSingleStreamFrame(); +} + +// Test whether SetMaxPacketLength() works correctly when we force the change of +// the packet size in the middle of the batched packet. +TEST_F(QuicPacketCreatorMultiplePacketsTest, + SetMaxPacketLength_MidpacketFlush) { + delegate_.SetCanWriteAnything(); + + size_t first_write_len = kDefaultMaxPacketSize / 2; + size_t packet_len = kDefaultMaxPacketSize + 100; + size_t second_write_len = packet_len + 1; + ASSERT_LE(packet_len, kMaxOutgoingPacketSize); + + // First send half of the packet worth of data. We are in the batch mode, so + // should not cause packet serialization. + const std::string first_write(first_write_len, '?'); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + first_write, + /*offset=*/0, NO_FIN); + EXPECT_EQ(first_write_len, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + // Make sure we have no packets so far. + ASSERT_EQ(0u, packets_.size()); + + // Expect a packet to be flushed. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + // Increase packet size after flushing all frames. + // Ensure it's immediately enacted. + creator_.FlushCurrentPacket(); + creator_.SetMaxPacketLength(packet_len); + EXPECT_EQ(packet_len, creator_.max_packet_length()); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // We expect to see exactly one packet serialized after that, because we send + // a value somewhat exceeding new max packet size, and the tail data does not + // get serialized because we are still in the batch mode. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + // Send a more than a packet worth of data to the same stream. This should + // trigger serialization of one packet, and queue another one. + const std::string second_write(second_write_len, '?'); + consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + second_write, + /*offset=*/first_write_len, FIN); + EXPECT_EQ(second_write_len, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + // We expect the first packet to be underfilled, and the second packet be up + // to the new max packet size. + ASSERT_EQ(2u, packets_.size()); + EXPECT_GT(kDefaultMaxPacketSize, packets_[0].encrypted_length); + EXPECT_EQ(packet_len, packets_[1].encrypted_length); + + CheckAllPacketsHaveSingleStreamFrame(); +} + +// Test sending a connectivity probing packet. +TEST_F(QuicPacketCreatorMultiplePacketsTest, + GenerateConnectivityProbingPacket) { + delegate_.SetCanWriteAnything(); + + std::unique_ptr probing_packet; + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + QuicPathFrameBuffer payload = { + {0xde, 0xad, 0xbe, 0xef, 0xba, 0xdc, 0x0f, 0xfe}}; + probing_packet = + creator_.SerializePathChallengeConnectivityProbingPacket(payload); + } else { + probing_packet = creator_.SerializeConnectivityProbingPacket(); + } + + ASSERT_TRUE(simple_framer_.ProcessPacket(QuicEncryptedPacket( + probing_packet->encrypted_buffer, probing_packet->encrypted_length))); + + EXPECT_EQ(2u, simple_framer_.num_frames()); + if (VersionHasIetfQuicFrames(framer_.transport_version())) { + EXPECT_EQ(1u, simple_framer_.path_challenge_frames().size()); + } else { + EXPECT_EQ(1u, simple_framer_.ping_frames().size()); + } + EXPECT_EQ(1u, simple_framer_.padding_frames().size()); +} + +// Test sending an MTU probe, without any surrounding data. +TEST_F(QuicPacketCreatorMultiplePacketsTest, + GenerateMtuDiscoveryPacket_Simple) { + delegate_.SetCanWriteAnything(); + + const size_t target_mtu = kDefaultMaxPacketSize + 100; + static_assert(target_mtu < kMaxOutgoingPacketSize, + "The MTU probe used by the test exceeds maximum packet size"); + + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + creator_.GenerateMtuDiscoveryPacket(target_mtu); + + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + ASSERT_EQ(1u, packets_.size()); + EXPECT_EQ(target_mtu, packets_[0].encrypted_length); + + PacketContents contents; + contents.num_mtu_discovery_frames = 1; + contents.num_padding_frames = 1; + CheckPacketContains(contents, 0); +} + +// Test sending an MTU probe. Surround it with data, to ensure that it resets +// the MTU to the value before the probe was sent. +TEST_F(QuicPacketCreatorMultiplePacketsTest, + GenerateMtuDiscoveryPacket_SurroundedByData) { + delegate_.SetCanWriteAnything(); + + const size_t target_mtu = kDefaultMaxPacketSize + 100; + static_assert(target_mtu < kMaxOutgoingPacketSize, + "The MTU probe used by the test exceeds maximum packet size"); + + // Send enough data so it would always cause two packets to be sent. + const size_t data_len = target_mtu + 1; + + // Send a total of five packets: two packets before the probe, the probe + // itself, and two packets after the probe. + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .Times(5) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + // Send data before the MTU probe. + const std::string data(data_len, '?'); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, + /*offset=*/0, NO_FIN); + creator_.Flush(); + EXPECT_EQ(data_len, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // Send the MTU probe. + creator_.GenerateMtuDiscoveryPacket(target_mtu); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + // Send data after the MTU probe. + creator_.AttachPacketFlusher(); + consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(framer_.transport_version(), + Perspective::IS_CLIENT), + data, + /*offset=*/data_len, FIN); + creator_.Flush(); + EXPECT_EQ(data_len, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + ASSERT_EQ(5u, packets_.size()); + EXPECT_EQ(kDefaultMaxPacketSize, packets_[0].encrypted_length); + EXPECT_EQ(target_mtu, packets_[2].encrypted_length); + EXPECT_EQ(kDefaultMaxPacketSize, packets_[3].encrypted_length); + + PacketContents probe_contents; + probe_contents.num_mtu_discovery_frames = 1; + probe_contents.num_padding_frames = 1; + + CheckPacketHasSingleStreamFrame(0); + CheckPacketHasSingleStreamFrame(1); + CheckPacketContains(probe_contents, 2); + CheckPacketHasSingleStreamFrame(3); + CheckPacketHasSingleStreamFrame(4); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, DontCrashOnInvalidStopWaiting) { + if (VersionSupportsMessageFrames(framer_.transport_version())) { + return; + } + // Test added to ensure the creator does not crash when an invalid frame is + // added. Because this is an indication of internal programming errors, + // DFATALs are expected. + // A 1 byte packet number length can't encode a gap of 1000. + QuicPacketCreatorPeer::SetPacketNumber(&creator_, 1000); + + delegate_.SetCanNotWrite(); + delegate_.SetCanWriteAnything(); + + // This will not serialize any packets, because of the invalid frame. + EXPECT_CALL(delegate_, + OnUnrecoverableError(QUIC_FAILED_TO_SERIALIZE_PACKET, _)); + EXPECT_QUIC_BUG(creator_.Flush(), + "packet_number_length 1 is too small " + "for least_unacked_delta: 1001"); +} + +// Regression test for b/31486443. +TEST_F(QuicPacketCreatorMultiplePacketsTest, + ConnectionCloseFrameLargerThanPacketSize) { + delegate_.SetCanWriteAnything(); + char buf[2000] = {}; + absl::string_view error_details(buf, 2000); + const QuicErrorCode kQuicErrorCode = QUIC_PACKET_WRITE_ERROR; + + QuicConnectionCloseFrame* frame = new QuicConnectionCloseFrame( + framer_.transport_version(), kQuicErrorCode, NO_IETF_QUIC_ERROR, + std::string(error_details), + /*transport_close_frame_type=*/0); + creator_.ConsumeRetransmittableControlFrame(QuicFrame(frame), + /*bundle_ack=*/false); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + RandomPaddingAfterFinSingleStreamSinglePacket) { + const QuicByteCount kStreamFramePayloadSize = 100u; + char buf[kStreamFramePayloadSize] = {}; + const QuicStreamId kDataStreamId = 5; + // Set the packet size be enough for one stream frame with 0 stream offset and + // max size of random padding. + size_t length = + TaggingEncrypter(0x00).GetCiphertextSize(0) + + GetPacketHeaderSize( + framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), 0, + QuicPacketCreatorPeer::GetLengthLength(&creator_)) + + QuicFramer::GetMinStreamFrameSize( + framer_.transport_version(), kDataStreamId, 0, + /*last_frame_in_packet=*/false, + kStreamFramePayloadSize + kMaxNumRandomPaddingBytes) + + kStreamFramePayloadSize + kMaxNumRandomPaddingBytes; + creator_.SetMaxPacketLength(length); + delegate_.SetCanWriteAnything(); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeData( + kDataStreamId, absl::string_view(buf, kStreamFramePayloadSize), 0, + FIN_AND_PADDING); + creator_.Flush(); + EXPECT_EQ(kStreamFramePayloadSize, consumed.bytes_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + EXPECT_EQ(1u, packets_.size()); + PacketContents contents; + // The packet has both stream and padding frames. + contents.num_padding_frames = 1; + contents.num_stream_frames = 1; + CheckPacketContains(contents, 0); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + RandomPaddingAfterFinSingleStreamMultiplePackets) { + const QuicByteCount kStreamFramePayloadSize = 100u; + char buf[kStreamFramePayloadSize] = {}; + const QuicStreamId kDataStreamId = 5; + // Set the packet size be enough for one stream frame with 0 stream offset + + // 1. One or more packets will accommodate. + size_t length = + TaggingEncrypter(0x00).GetCiphertextSize(0) + + GetPacketHeaderSize( + framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), 0, + QuicPacketCreatorPeer::GetLengthLength(&creator_)) + + QuicFramer::GetMinStreamFrameSize( + framer_.transport_version(), kDataStreamId, 0, + /*last_frame_in_packet=*/false, kStreamFramePayloadSize + 1) + + kStreamFramePayloadSize + 1; + creator_.SetMaxPacketLength(length); + delegate_.SetCanWriteAnything(); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeData( + kDataStreamId, absl::string_view(buf, kStreamFramePayloadSize), 0, + FIN_AND_PADDING); + creator_.Flush(); + EXPECT_EQ(kStreamFramePayloadSize, consumed.bytes_consumed); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + EXPECT_LE(1u, packets_.size()); + PacketContents contents; + // The first packet has both stream and padding frames. + contents.num_stream_frames = 1; + contents.num_padding_frames = 1; + CheckPacketContains(contents, 0); + + for (size_t i = 1; i < packets_.size(); ++i) { + // Following packets only have paddings. + contents.num_stream_frames = 0; + contents.num_padding_frames = 1; + CheckPacketContains(contents, i); + } +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + RandomPaddingAfterFinMultipleStreamsMultiplePackets) { + const QuicByteCount kStreamFramePayloadSize = 100u; + char buf[kStreamFramePayloadSize] = {}; + const QuicStreamId kDataStreamId1 = 5; + const QuicStreamId kDataStreamId2 = 6; + // Set the packet size be enough for first frame with 0 stream offset + second + // frame + 1 byte payload. two or more packets will accommodate. + size_t length = + TaggingEncrypter(0x00).GetCiphertextSize(0) + + GetPacketHeaderSize( + framer_.transport_version(), + creator_.GetDestinationConnectionIdLength(), + creator_.GetSourceConnectionIdLength(), + QuicPacketCreatorPeer::SendVersionInPacket(&creator_), + !kIncludeDiversificationNonce, + QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + QuicPacketCreatorPeer::GetRetryTokenLengthLength(&creator_), 0, + QuicPacketCreatorPeer::GetLengthLength(&creator_)) + + QuicFramer::GetMinStreamFrameSize( + framer_.transport_version(), kDataStreamId1, 0, + /*last_frame_in_packet=*/false, kStreamFramePayloadSize) + + kStreamFramePayloadSize + + QuicFramer::GetMinStreamFrameSize(framer_.transport_version(), + kDataStreamId1, 0, + /*last_frame_in_packet=*/false, 1) + + 1; + creator_.SetMaxPacketLength(length); + delegate_.SetCanWriteAnything(); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillRepeatedly( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + QuicConsumedData consumed = creator_.ConsumeData( + kDataStreamId1, absl::string_view(buf, kStreamFramePayloadSize), 0, + FIN_AND_PADDING); + EXPECT_EQ(kStreamFramePayloadSize, consumed.bytes_consumed); + consumed = creator_.ConsumeData( + kDataStreamId2, absl::string_view(buf, kStreamFramePayloadSize), 0, + FIN_AND_PADDING); + EXPECT_EQ(kStreamFramePayloadSize, consumed.bytes_consumed); + creator_.Flush(); + EXPECT_FALSE(creator_.HasPendingFrames()); + EXPECT_FALSE(creator_.HasPendingRetransmittableFrames()); + + EXPECT_LE(2u, packets_.size()); + PacketContents contents; + // The first packet has two stream frames. + contents.num_stream_frames = 2; + CheckPacketContains(contents, 0); + + // The second packet has one stream frame and padding frames. + contents.num_stream_frames = 1; + contents.num_padding_frames = 1; + CheckPacketContains(contents, 1); + + for (size_t i = 2; i < packets_.size(); ++i) { + // Following packets only have paddings. + contents.num_stream_frames = 0; + contents.num_padding_frames = 1; + CheckPacketContains(contents, i); + } +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, AddMessageFrame) { + if (!VersionSupportsMessageFrames(framer_.transport_version())) { + return; + } + if (framer_.version().UsesTls()) { + creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); + } + delegate_.SetCanWriteAnything(); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + + creator_.ConsumeData(QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT), + "foo", 0, FIN); + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, + creator_.AddMessageFrame(1, MemSliceFromString("message"))); + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + // Add a message which causes the flush of current packet. + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, + creator_.AddMessageFrame( + 2, MemSliceFromString(std::string( + creator_.GetCurrentLargestMessagePayload(), 'a')))); + EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); + + // Failed to send messages which cannot fit into one packet. + EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, + creator_.AddMessageFrame( + 3, MemSliceFromString(std::string( + creator_.GetCurrentLargestMessagePayload() + 10, 'a')))); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, ConnectionId) { + creator_.SetServerConnectionId(TestConnectionId(0x1337)); + EXPECT_EQ(TestConnectionId(0x1337), creator_.GetDestinationConnectionId()); + EXPECT_EQ(EmptyQuicConnectionId(), creator_.GetSourceConnectionId()); + if (!framer_.version().SupportsClientConnectionIds()) { + return; + } + creator_.SetClientConnectionId(TestConnectionId(0x33)); + EXPECT_EQ(TestConnectionId(0x1337), creator_.GetDestinationConnectionId()); + EXPECT_EQ(TestConnectionId(0x33), creator_.GetSourceConnectionId()); +} + +// Regresstion test for b/159812345. +TEST_F(QuicPacketCreatorMultiplePacketsTest, ExtraPaddingNeeded) { + if (!framer_.version().HasHeaderProtection()) { + return; + } + delegate_.SetCanWriteAnything(); + // If the packet number length > 1, we won't get padding. + EXPECT_EQ(QuicPacketCreatorPeer::GetPacketNumberLength(&creator_), + PACKET_1BYTE_PACKET_NUMBER); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce( + Invoke(this, &QuicPacketCreatorMultiplePacketsTest::SavePacket)); + // with no data and no offset, this is a 2B STREAM frame. + creator_.ConsumeData(QuicUtils::GetFirstBidirectionalStreamId( + framer_.transport_version(), Perspective::IS_CLIENT), + "", 0, FIN); + creator_.Flush(); + ASSERT_FALSE(packets_[0].nonretransmittable_frames.empty()); + QuicFrame padding = packets_[0].nonretransmittable_frames[0]; + // Verify stream frame expansion is excluded. + EXPECT_EQ(padding.padding_frame.num_padding_bytes, 1); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + PeerAddressContextWithSameAddress) { + QuicConnectionId client_connection_id = TestConnectionId(1); + QuicConnectionId server_connection_id = TestConnectionId(2); + QuicSocketAddress peer_addr(QuicIpAddress::Any4(), 12345); + creator_.SetDefaultPeerAddress(peer_addr); + creator_.SetClientConnectionId(client_connection_id); + creator_.SetServerConnectionId(server_connection_id); + // Send some stream data. + EXPECT_CALL(delegate_, ShouldGeneratePacket(_, _)) + .WillRepeatedly(Return(true)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(creator_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, NO_FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + { + // Set the same address via context which should not trigger flush. + QuicPacketCreator::ScopedPeerAddressContext context( + &creator_, peer_addr, client_connection_id, server_connection_id, + /*update_connection_id=*/true); + ASSERT_EQ(client_connection_id, creator_.GetClientConnectionId()); + ASSERT_EQ(server_connection_id, creator_.GetServerConnectionId()); + EXPECT_TRUE(creator_.HasPendingFrames()); + // Queue another STREAM_FRAME. + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(creator_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + } + // After exiting the scope, the last queued frame should be flushed. + EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke([=](SerializedPacket packet) { + EXPECT_EQ(peer_addr, packet.peer_address); + ASSERT_EQ(2u, packet.retransmittable_frames.size()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.back().type); + })); + creator_.FlushCurrentPacket(); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + PeerAddressContextWithDifferentAddress) { + QuicSocketAddress peer_addr(QuicIpAddress::Any4(), 12345); + creator_.SetDefaultPeerAddress(peer_addr); + // Send some stream data. + EXPECT_CALL(delegate_, ShouldGeneratePacket(_, _)) + .WillRepeatedly(Return(true)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(creator_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, NO_FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + + QuicSocketAddress peer_addr1(QuicIpAddress::Any4(), 12346); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke([=](SerializedPacket packet) { + EXPECT_EQ(peer_addr, packet.peer_address); + ASSERT_EQ(1u, packet.retransmittable_frames.size()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + })) + .WillOnce(Invoke([=](SerializedPacket packet) { + EXPECT_EQ(peer_addr1, packet.peer_address); + ASSERT_EQ(1u, packet.retransmittable_frames.size()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + })); + EXPECT_TRUE(creator_.HasPendingFrames()); + { + QuicConnectionId client_connection_id = TestConnectionId(1); + QuicConnectionId server_connection_id = TestConnectionId(2); + // Set a different address via context which should trigger flush. + QuicPacketCreator::ScopedPeerAddressContext context( + &creator_, peer_addr1, client_connection_id, server_connection_id, + /*update_connection_id=*/true); + ASSERT_EQ(client_connection_id, creator_.GetClientConnectionId()); + ASSERT_EQ(server_connection_id, creator_.GetServerConnectionId()); + EXPECT_FALSE(creator_.HasPendingFrames()); + // Queue another STREAM_FRAME. + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(creator_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + } + // After exiting the scope, the last queued frame should be flushed. + EXPECT_FALSE(creator_.HasPendingFrames()); +} + +TEST_F(QuicPacketCreatorMultiplePacketsTest, + NestedPeerAddressContextWithDifferentAddress) { + QuicConnectionId client_connection_id1 = creator_.GetClientConnectionId(); + QuicConnectionId server_connection_id1 = creator_.GetServerConnectionId(); + QuicSocketAddress peer_addr(QuicIpAddress::Any4(), 12345); + creator_.SetDefaultPeerAddress(peer_addr); + QuicPacketCreator::ScopedPeerAddressContext context( + &creator_, peer_addr, client_connection_id1, server_connection_id1, + /*update_connection_id=*/true); + ASSERT_EQ(client_connection_id1, creator_.GetClientConnectionId()); + ASSERT_EQ(server_connection_id1, creator_.GetServerConnectionId()); + + // Send some stream data. + EXPECT_CALL(delegate_, ShouldGeneratePacket(_, _)) + .WillRepeatedly(Return(true)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId(creator_.transport_version(), + Perspective::IS_CLIENT), + "foo", 0, NO_FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + + QuicSocketAddress peer_addr1(QuicIpAddress::Any4(), 12346); + EXPECT_CALL(delegate_, OnSerializedPacket(_)) + .WillOnce(Invoke([=](SerializedPacket packet) { + EXPECT_EQ(peer_addr, packet.peer_address); + ASSERT_EQ(1u, packet.retransmittable_frames.size()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + + QuicConnectionId client_connection_id2 = TestConnectionId(3); + QuicConnectionId server_connection_id2 = TestConnectionId(4); + // Set up another context with a different address. + QuicPacketCreator::ScopedPeerAddressContext context( + &creator_, peer_addr1, client_connection_id2, server_connection_id2, + /*update_connection_id=*/true); + ASSERT_EQ(client_connection_id2, creator_.GetClientConnectionId()); + ASSERT_EQ(server_connection_id2, creator_.GetServerConnectionId()); + EXPECT_CALL(delegate_, ShouldGeneratePacket(_, _)) + .WillRepeatedly(Return(true)); + QuicConsumedData consumed = creator_.ConsumeData( + QuicUtils::GetFirstBidirectionalStreamId( + creator_.transport_version(), Perspective::IS_CLIENT), + "foo", 0, NO_FIN); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(creator_.HasPendingFrames()); + // This should trigger another OnSerializedPacket() with the 2nd + // address. + creator_.FlushCurrentPacket(); + })) + .WillOnce(Invoke([=](SerializedPacket packet) { + EXPECT_EQ(peer_addr1, packet.peer_address); + ASSERT_EQ(1u, packet.retransmittable_frames.size()); + EXPECT_EQ(STREAM_FRAME, packet.retransmittable_frames.front().type); + })); + creator_.FlushCurrentPacket(); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_packet_number.cc b/quiche/quic/core/quic_packet_number.cc new file mode 100644 index 000000000000..c7bda674398b --- /dev/null +++ b/quiche/quic/core/quic_packet_number.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packet_number.h" + +#include +#include + +#include "absl/strings/str_cat.h" + +namespace quic { + +void QuicPacketNumber::Clear() { packet_number_ = UninitializedPacketNumber(); } + +void QuicPacketNumber::UpdateMax(QuicPacketNumber new_value) { + if (!new_value.IsInitialized()) { + return; + } + if (!IsInitialized()) { + packet_number_ = new_value.ToUint64(); + } else { + packet_number_ = std::max(packet_number_, new_value.ToUint64()); + } +} + +uint64_t QuicPacketNumber::Hash() const { + QUICHE_DCHECK(IsInitialized()); + return packet_number_; +} + +uint64_t QuicPacketNumber::ToUint64() const { + QUICHE_DCHECK(IsInitialized()); + return packet_number_; +} + +bool QuicPacketNumber::IsInitialized() const { + return packet_number_ != UninitializedPacketNumber(); +} + +QuicPacketNumber& QuicPacketNumber::operator++() { +#ifndef NDEBUG + QUICHE_DCHECK(IsInitialized()); + QUICHE_DCHECK_LT(ToUint64(), std::numeric_limits::max() - 1); +#endif + packet_number_++; + return *this; +} + +QuicPacketNumber QuicPacketNumber::operator++(int) { +#ifndef NDEBUG + QUICHE_DCHECK(IsInitialized()); + QUICHE_DCHECK_LT(ToUint64(), std::numeric_limits::max() - 1); +#endif + QuicPacketNumber previous(*this); + packet_number_++; + return previous; +} + +QuicPacketNumber& QuicPacketNumber::operator--() { +#ifndef NDEBUG + QUICHE_DCHECK(IsInitialized()); + QUICHE_DCHECK_GE(ToUint64(), 1UL); +#endif + packet_number_--; + return *this; +} + +QuicPacketNumber QuicPacketNumber::operator--(int) { +#ifndef NDEBUG + QUICHE_DCHECK(IsInitialized()); + QUICHE_DCHECK_GE(ToUint64(), 1UL); +#endif + QuicPacketNumber previous(*this); + packet_number_--; + return previous; +} + +QuicPacketNumber& QuicPacketNumber::operator+=(uint64_t delta) { +#ifndef NDEBUG + QUICHE_DCHECK(IsInitialized()); + QUICHE_DCHECK_GT(std::numeric_limits::max() - ToUint64(), delta); +#endif + packet_number_ += delta; + return *this; +} + +QuicPacketNumber& QuicPacketNumber::operator-=(uint64_t delta) { +#ifndef NDEBUG + QUICHE_DCHECK(IsInitialized()); + QUICHE_DCHECK_GE(ToUint64(), delta); +#endif + packet_number_ -= delta; + return *this; +} + +std::string QuicPacketNumber::ToString() const { + if (!IsInitialized()) { + return "uninitialized"; + } + return absl::StrCat(ToUint64()); +} + +std::ostream& operator<<(std::ostream& os, const QuicPacketNumber& p) { + os << p.ToString(); + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_packet_number.h b/quiche/quic/core/quic_packet_number.h new file mode 100644 index 000000000000..8d6b1b63a5d7 --- /dev/null +++ b/quiche/quic/core/quic_packet_number.h @@ -0,0 +1,164 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PACKET_NUMBER_H_ +#define QUICHE_QUIC_CORE_QUIC_PACKET_NUMBER_H_ + +#include +#include +#include + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// QuicPacketNumber can either initialized or uninitialized. An initialized +// packet number is simply an ordinal number. A sentinel value is used to +// represent an uninitialized packet number. +class QUIC_EXPORT_PRIVATE QuicPacketNumber { + public: + // Construct an uninitialized packet number. + constexpr QuicPacketNumber() : packet_number_(UninitializedPacketNumber()) {} + + // Construct a packet number from uint64_t. |packet_number| cannot equal the + // sentinel value. + explicit constexpr QuicPacketNumber(uint64_t packet_number) + : packet_number_(packet_number) { + QUICHE_DCHECK_NE(UninitializedPacketNumber(), packet_number) + << "Use default constructor for uninitialized packet number"; + } + + // The sentinel value representing an uninitialized packet number. + static constexpr uint64_t UninitializedPacketNumber() { + return std::numeric_limits::max(); + } + + // Packet number becomes uninitialized after calling this function. + void Clear(); + + // Updates this packet number to be |new_value| if it is greater than current + // value. + void UpdateMax(QuicPacketNumber new_value); + + // REQUIRES: IsInitialized() == true. + uint64_t Hash() const; + + // Converts packet number to uint64_t. + // REQUIRES: IsInitialized() == true. + uint64_t ToUint64() const; + + // Returns true if packet number is considered initialized. + bool IsInitialized() const; + + // REQUIRES: IsInitialized() == true && ToUint64() < + // numeric_limits::max() - 1. + QuicPacketNumber& operator++(); + QuicPacketNumber operator++(int); + // REQUIRES: IsInitialized() == true && ToUint64() >= 1. + QuicPacketNumber& operator--(); + QuicPacketNumber operator--(int); + + // REQUIRES: IsInitialized() == true && numeric_limits::max() - + // ToUint64() > |delta|. + QuicPacketNumber& operator+=(uint64_t delta); + // REQUIRES: IsInitialized() == true && ToUint64() >= |delta|. + QuicPacketNumber& operator-=(uint64_t delta); + + // Human-readable representation suitable for logging. + std::string ToString() const; + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicPacketNumber& p); + + private: + // All following operators REQUIRE operands.Initialized() == true. + friend inline bool operator==(QuicPacketNumber lhs, QuicPacketNumber rhs); + friend inline bool operator!=(QuicPacketNumber lhs, QuicPacketNumber rhs); + friend inline bool operator<(QuicPacketNumber lhs, QuicPacketNumber rhs); + friend inline bool operator<=(QuicPacketNumber lhs, QuicPacketNumber rhs); + friend inline bool operator>(QuicPacketNumber lhs, QuicPacketNumber rhs); + friend inline bool operator>=(QuicPacketNumber lhs, QuicPacketNumber rhs); + + // REQUIRES: numeric_limits::max() - lhs.ToUint64() > |delta|. + friend inline QuicPacketNumber operator+(QuicPacketNumber lhs, + uint64_t delta); + // REQUIRES: lhs.ToUint64() >= |delta|. + friend inline QuicPacketNumber operator-(QuicPacketNumber lhs, + uint64_t delta); + // REQUIRES: lhs >= rhs. + friend inline uint64_t operator-(QuicPacketNumber lhs, QuicPacketNumber rhs); + + uint64_t packet_number_; +}; + +class QUIC_EXPORT_PRIVATE QuicPacketNumberHash { + public: + uint64_t operator()(QuicPacketNumber packet_number) const noexcept { + return packet_number.Hash(); + } +}; + +inline bool operator==(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized()) + << lhs << " vs. " << rhs; + return lhs.packet_number_ == rhs.packet_number_; +} + +inline bool operator!=(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized()) + << lhs << " vs. " << rhs; + return lhs.packet_number_ != rhs.packet_number_; +} + +inline bool operator<(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized()) + << lhs << " vs. " << rhs; + return lhs.packet_number_ < rhs.packet_number_; +} + +inline bool operator<=(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized()) + << lhs << " vs. " << rhs; + return lhs.packet_number_ <= rhs.packet_number_; +} + +inline bool operator>(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized()) + << lhs << " vs. " << rhs; + return lhs.packet_number_ > rhs.packet_number_; +} + +inline bool operator>=(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized()) + << lhs << " vs. " << rhs; + return lhs.packet_number_ >= rhs.packet_number_; +} + +inline QuicPacketNumber operator+(QuicPacketNumber lhs, uint64_t delta) { +#ifndef NDEBUG + QUICHE_DCHECK(lhs.IsInitialized()); + QUICHE_DCHECK_GT(std::numeric_limits::max() - lhs.ToUint64(), + delta); +#endif + return QuicPacketNumber(lhs.packet_number_ + delta); +} + +inline QuicPacketNumber operator-(QuicPacketNumber lhs, uint64_t delta) { +#ifndef NDEBUG + QUICHE_DCHECK(lhs.IsInitialized()); + QUICHE_DCHECK_GE(lhs.ToUint64(), delta); +#endif + return QuicPacketNumber(lhs.packet_number_ - delta); +} + +inline uint64_t operator-(QuicPacketNumber lhs, QuicPacketNumber rhs) { + QUICHE_DCHECK(lhs.IsInitialized() && rhs.IsInitialized() && lhs >= rhs) + << lhs << " vs. " << rhs; + return lhs.packet_number_ - rhs.packet_number_; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PACKET_NUMBER_H_ diff --git a/quiche/quic/core/quic_packet_number_test.cc b/quiche/quic/core/quic_packet_number_test.cc new file mode 100644 index 000000000000..084a43574476 --- /dev/null +++ b/quiche/quic/core/quic_packet_number_test.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packet_number.h" + +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { + +namespace test { + +namespace { + +TEST(QuicPacketNumberTest, BasicTest) { + QuicPacketNumber num; + EXPECT_FALSE(num.IsInitialized()); + + QuicPacketNumber num2(10); + EXPECT_TRUE(num2.IsInitialized()); + EXPECT_EQ(10u, num2.ToUint64()); + EXPECT_EQ(10u, num2.Hash()); + num2.UpdateMax(num); + EXPECT_EQ(10u, num2.ToUint64()); + num2.UpdateMax(QuicPacketNumber(9)); + EXPECT_EQ(10u, num2.ToUint64()); + num2.UpdateMax(QuicPacketNumber(11)); + EXPECT_EQ(11u, num2.ToUint64()); + num2.Clear(); + EXPECT_FALSE(num2.IsInitialized()); + num2.UpdateMax(QuicPacketNumber(9)); + EXPECT_EQ(9u, num2.ToUint64()); + + QuicPacketNumber num4(0); + EXPECT_TRUE(num4.IsInitialized()); + EXPECT_EQ(0u, num4.ToUint64()); + EXPECT_EQ(0u, num4.Hash()); + num4.Clear(); + EXPECT_FALSE(num4.IsInitialized()); +} + +TEST(QuicPacketNumberTest, Operators) { + QuicPacketNumber num(100); + EXPECT_EQ(QuicPacketNumber(100), num++); + EXPECT_EQ(QuicPacketNumber(101), num); + EXPECT_EQ(QuicPacketNumber(101), num--); + EXPECT_EQ(QuicPacketNumber(100), num); + + EXPECT_EQ(QuicPacketNumber(101), ++num); + EXPECT_EQ(QuicPacketNumber(100), --num); + + QuicPacketNumber num3(0); + EXPECT_EQ(QuicPacketNumber(0), num3++); + EXPECT_EQ(QuicPacketNumber(1), num3); + EXPECT_EQ(QuicPacketNumber(2), ++num3); + + EXPECT_EQ(QuicPacketNumber(2), num3--); + EXPECT_EQ(QuicPacketNumber(1), num3); + EXPECT_EQ(QuicPacketNumber(0), --num3); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/core/quic_packet_reader.cc b/quiche/quic/core/quic_packet_reader.cc new file mode 100644 index 000000000000..eaa3441b3d8b --- /dev/null +++ b/quiche/quic/core/quic_packet_reader.cc @@ -0,0 +1,136 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packet_reader.h" + +#include "absl/base/macros.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_process_packet_interface.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_server_stats.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +QuicPacketReader::QuicPacketReader() + : read_buffers_(kNumPacketsPerReadMmsgCall), + read_results_(kNumPacketsPerReadMmsgCall) { + QUICHE_DCHECK_EQ(read_buffers_.size(), read_results_.size()); + for (size_t i = 0; i < read_results_.size(); ++i) { + read_results_[i].packet_buffer.buffer = read_buffers_[i].packet_buffer; + read_results_[i].packet_buffer.buffer_len = + sizeof(read_buffers_[i].packet_buffer); + + read_results_[i].control_buffer.buffer = read_buffers_[i].control_buffer; + read_results_[i].control_buffer.buffer_len = + sizeof(read_buffers_[i].control_buffer); + } +} + +QuicPacketReader::~QuicPacketReader() = default; + +bool QuicPacketReader::ReadAndDispatchPackets( + int fd, int port, const QuicClock& clock, ProcessPacketInterface* processor, + QuicPacketCount* /*packets_dropped*/) { + // Reset all read_results for reuse. + for (size_t i = 0; i < read_results_.size(); ++i) { + read_results_[i].Reset( + /*packet_buffer_length=*/sizeof(read_buffers_[i].packet_buffer)); + } + + // Use clock.Now() as the packet receipt time, the time between packet + // arriving at the host and now is considered part of the network delay. + QuicTime now = clock.Now(); + + BitMask64 info_bits{QuicUdpPacketInfoBit::DROPPED_PACKETS, + QuicUdpPacketInfoBit::PEER_ADDRESS, + QuicUdpPacketInfoBit::V4_SELF_IP, + QuicUdpPacketInfoBit::V6_SELF_IP, + QuicUdpPacketInfoBit::RECV_TIMESTAMP, + QuicUdpPacketInfoBit::TTL, + QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER}; + if (GetQuicRestartFlag(quic_receive_ecn)) { + QUIC_RESTART_FLAG_COUNT_N(quic_receive_ecn, 3, 3); + info_bits.Set(QuicUdpPacketInfoBit::ECN); + } + size_t packets_read = + socket_api_.ReadMultiplePackets(fd, info_bits, &read_results_); + for (size_t i = 0; i < packets_read; ++i) { + auto& result = read_results_[i]; + if (!result.ok) { + QUIC_CODE_COUNT(quic_packet_reader_read_failure); + continue; + } + + if (!result.packet_info.HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)) { + QUIC_BUG(quic_bug_10329_1) << "Unable to get peer socket address."; + continue; + } + + QuicSocketAddress peer_address = + result.packet_info.peer_address().Normalized(); + + QuicIpAddress self_ip = GetSelfIpFromPacketInfo( + result.packet_info, peer_address.host().IsIPv6()); + if (!self_ip.IsInitialized()) { + QUIC_BUG(quic_bug_10329_2) << "Unable to get self IP address."; + continue; + } + + bool has_ttl = result.packet_info.HasValue(QuicUdpPacketInfoBit::TTL); + int ttl = has_ttl ? result.packet_info.ttl() : 0; + if (!has_ttl) { + QUIC_CODE_COUNT(quic_packet_reader_no_ttl); + } + + char* headers = nullptr; + size_t headers_length = 0; + if (result.packet_info.HasValue( + QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER)) { + headers = result.packet_info.google_packet_headers().buffer; + headers_length = result.packet_info.google_packet_headers().buffer_len; + } else { + QUIC_CODE_COUNT(quic_packet_reader_no_google_packet_header); + } + + QuicReceivedPacket packet( + result.packet_buffer.buffer, result.packet_buffer.buffer_len, now, + /*owns_buffer=*/false, ttl, has_ttl, headers, headers_length, + /*owns_header_buffer=*/false, result.packet_info.ecn_codepoint()); + QuicSocketAddress self_address(self_ip, port); + processor->ProcessPacket(self_address, peer_address, packet); + } + + // We may not have read all of the packets available on the socket. + return packets_read == kNumPacketsPerReadMmsgCall; +} + +// static +QuicIpAddress QuicPacketReader::GetSelfIpFromPacketInfo( + const QuicUdpPacketInfo& packet_info, bool prefer_v6_ip) { + if (prefer_v6_ip) { + if (packet_info.HasValue(QuicUdpPacketInfoBit::V6_SELF_IP)) { + return packet_info.self_v6_ip(); + } + if (packet_info.HasValue(QuicUdpPacketInfoBit::V4_SELF_IP)) { + return packet_info.self_v4_ip(); + } + } else { + if (packet_info.HasValue(QuicUdpPacketInfoBit::V4_SELF_IP)) { + return packet_info.self_v4_ip(); + } + if (packet_info.HasValue(QuicUdpPacketInfoBit::V6_SELF_IP)) { + return packet_info.self_v6_ip(); + } + } + return QuicIpAddress(); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_packet_reader.h b/quiche/quic/core/quic_packet_reader.h new file mode 100644 index 000000000000..6ec2682f0243 --- /dev/null +++ b/quiche/quic/core/quic_packet_reader.h @@ -0,0 +1,64 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A class to read incoming QUIC packets from the UDP socket. + +#ifndef QUICHE_QUIC_CORE_QUIC_PACKET_READER_H_ +#define QUICHE_QUIC_CORE_QUIC_PACKET_READER_H_ + +#include "absl/base/optimization.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_process_packet_interface.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// Read in larger batches to minimize recvmmsg overhead. +const int kNumPacketsPerReadMmsgCall = 16; + +class QUIC_EXPORT_PRIVATE QuicPacketReader { + public: + QuicPacketReader(); + QuicPacketReader(const QuicPacketReader&) = delete; + QuicPacketReader& operator=(const QuicPacketReader&) = delete; + + virtual ~QuicPacketReader(); + + // Reads a number of packets from the given fd, and then passes them off to + // the PacketProcessInterface. Returns true if there may be additional + // packets available on the socket. + // Populates |packets_dropped| if it is non-null and the socket is configured + // to track dropped packets and some packets are read. + // If the socket has timestamping enabled, the per packet timestamps will be + // passed to the processor. Otherwise, |clock| will be used. + virtual bool ReadAndDispatchPackets(int fd, int port, const QuicClock& clock, + ProcessPacketInterface* processor, + QuicPacketCount* packets_dropped); + + private: + // Return the self ip from |packet_info|. + // For dual stack sockets, |packet_info| may contain both a v4 and a v6 ip, in + // that case, |prefer_v6_ip| is used to determine which one is used as the + // return value. If neither v4 nor v6 ip exists, return an uninitialized ip. + static QuicIpAddress GetSelfIpFromPacketInfo( + const QuicUdpPacketInfo& packet_info, bool prefer_v6_ip); + + struct QUIC_EXPORT_PRIVATE ReadBuffer { + ABSL_CACHELINE_ALIGNED char + control_buffer[kDefaultUdpPacketControlBufferSize]; // For ancillary + // data. + ABSL_CACHELINE_ALIGNED char packet_buffer[kMaxIncomingPacketSize]; + }; + + QuicUdpSocketApi socket_api_; + std::vector read_buffers_; + QuicUdpSocketApi::ReadPacketResults read_results_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PACKET_READER_H_ diff --git a/quiche/quic/core/quic_packet_writer.h b/quiche/quic/core/quic_packet_writer.h new file mode 100644 index 000000000000..3e6cb21a4f08 --- /dev/null +++ b/quiche/quic/core/quic_packet_writer.h @@ -0,0 +1,171 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PACKET_WRITER_H_ +#define QUICHE_QUIC_CORE_QUIC_PACKET_WRITER_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +struct WriteResult; + +struct QUIC_EXPORT_PRIVATE PerPacketOptions { + virtual ~PerPacketOptions() {} + + // Returns a heap-allocated copy of |this|. + // + // The subclass implementation of this method should look like this: + // return std::make_unique(*this); + // + // This method is declared pure virtual in order to ensure the subclasses + // would not forget to override it. + virtual std::unique_ptr Clone() const = 0; + + // Specifies ideal release time delay for this packet. + QuicTime::Delta release_time_delay = QuicTime::Delta::Zero(); + // Whether it is allowed to send this packet without |release_time_delay|. + bool allow_burst = false; + // ECN codepoint to use when sending this packet. + QuicEcnCodepoint ecn_codepoint = ECN_NOT_ECT; +}; + +// An interface between writers and the entity managing the +// socket (in our case the QuicDispatcher). This allows the Dispatcher to +// control writes, and manage any writers who end up write blocked. +// A concrete writer works in one of the two modes: +// - PassThrough mode. This is the default mode. Caller calls WritePacket with +// caller-allocated packet buffer. Unless the writer is blocked, each call to +// WritePacket triggers a write using the underlying socket API. +// +// - Batch mode. In this mode, a call to WritePacket may not cause a packet to +// be sent using the underlying socket API. Instead, multiple packets are +// saved in the writer's internal buffer until they are flushed. The flush can +// be explicit, by calling Flush, or implicit, e.g. by calling +// WritePacket when the internal buffer is near full. +// +// Buffer management: +// In Batch mode, a writer manages an internal buffer, which is large enough to +// hold multiple packets' data. If the caller calls WritePacket with a +// caller-allocated packet buffer, the writer will memcpy the buffer into the +// internal buffer. Caller can also avoid this memcpy by: +// 1. Call GetNextWriteLocation to get a pointer P into the internal buffer. +// 2. Serialize the packet directly to P. +// 3. Call WritePacket with P as the |buffer|. +class QUIC_EXPORT_PRIVATE QuicPacketWriter { + public: + virtual ~QuicPacketWriter() {} + + // PassThrough mode: + // Sends the packet out to the peer, with some optional per-packet options. + // If the write succeeded, the result's status is WRITE_STATUS_OK and + // bytes_written is populated. If the write failed, the result's status is + // WRITE_STATUS_BLOCKED or WRITE_STATUS_ERROR and error_code is populated. + // + // Batch mode: + // If the writer is blocked, return WRITE_STATUS_BLOCKED immediately. + // If the packet can be batched with other buffered packets, save the packet + // to the internal buffer. + // If the packet can not be batched, or the internal buffer is near full after + // it is buffered, the internal buffer is flushed to free up space. + // Return WriteResult(WRITE_STATUS_OK, ) on success. When + // is zero, it means the packet is buffered and not flushed. + // Return WRITE_STATUS_BLOCKED if the packet is not buffered and the socket is + // blocked while flushing. + // Otherwise return an error status. + // + // Options must be either null, or created for the particular QuicPacketWriter + // implementation. Options may be ignored, depending on the implementation. + // + // Some comment about memory management if |buffer| was previously acquired + // by a call to "GetNextWriteLocation()": + // + // a) When WRITE_STATUS_OK is returned, the caller expects the writer owns the + // packet buffers and they will be released when the write finishes. + // + // b) When this function returns any status >= WRITE_STATUS_ERROR, the caller + // expects the writer releases the buffer (if needed) before the function + // returns. + // + // c) When WRITE_STATUS_BLOCKED is returned, the caller makes a copy of the + // buffer and will retry after unblock, so if |payload| is allocated from + // GetNextWriteLocation(), it + // 1) needs to be released before return, and + // 2) the content of |payload| should not change after return. + // + // d) When WRITE_STATUS_BLOCKED_DATA_BUFFERED is returned, the caller expects + // 1) the writer owns the packet buffers, and 2) the writer will re-send the + // packet when it unblocks. + virtual WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) = 0; + + // Returns true if the network socket is not writable. + virtual bool IsWriteBlocked() const = 0; + + // Records that the socket has become writable, for example when an EPOLLOUT + // is received or an asynchronous write completes. + virtual void SetWritable() = 0; + + // The error code used by the writer to indicate that the write failed due to + // supplied packet being too big. This is equivalent to returning + // WRITE_STATUS_MSG_TOO_BIG as a status. + virtual absl::optional MessageTooBigErrorCode() const = 0; + + // Returns the maximum size of the packet which can be written using this + // writer for the supplied peer address. This size may actually exceed the + // size of a valid QUIC packet. + virtual QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& peer_address) const = 0; + + // Returns true if the socket supports release timestamp. + virtual bool SupportsReleaseTime() const = 0; + + // True=Batch mode. False=PassThrough mode. + virtual bool IsBatchMode() const = 0; + + // PassThrough mode: Return {nullptr, nullptr} + // + // Batch mode: + // Return the QuicPacketBuffer for the next packet. A minimum of + // kMaxOutgoingPacketSize is guaranteed to be available from the returned + // address. If the internal buffer does not have enough space, + // {nullptr, nullptr} is returned. All arguments should be identical to the + // follow-up call to |WritePacket|, they are here to allow advanced packet + // memory management in packet writers, e.g. one packet buffer pool per + // |peer_address|. + // + // If QuicPacketBuffer.release_buffer is !nullptr, it should be called iff + // the caller does not call WritePacket for the returned buffer. + virtual QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address) = 0; + + // PassThrough mode: Return WriteResult(WRITE_STATUS_OK, 0). + // + // Batch mode: + // Try send all buffered packets. + // - Return WriteResult(WRITE_STATUS_OK, ) if all buffered + // packets were sent successfully. + // - Return WRITE_STATUS_BLOCKED if the underlying socket is blocked while + // sending. Some packets may have been sent, packets not sent will stay in + // the internal buffer. + // - Return a status >= WRITE_STATUS_ERROR if an error was encuontered while + // sending. As this is not a re-tryable error, any batched packets which + // were on memory acquired via GetNextWriteLocation() should be released and + // the batch should be dropped. + virtual WriteResult Flush() = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PACKET_WRITER_H_ diff --git a/quiche/quic/core/quic_packet_writer_wrapper.cc b/quiche/quic/core/quic_packet_writer_wrapper.cc new file mode 100644 index 000000000000..c040c91042da --- /dev/null +++ b/quiche/quic/core/quic_packet_writer_wrapper.cc @@ -0,0 +1,73 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packet_writer_wrapper.h" + +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +QuicPacketWriterWrapper::QuicPacketWriterWrapper() = default; + +QuicPacketWriterWrapper::~QuicPacketWriterWrapper() { unset_writer(); } + +WriteResult QuicPacketWriterWrapper::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + return writer_->WritePacket(buffer, buf_len, self_address, peer_address, + options); +} + +bool QuicPacketWriterWrapper::IsWriteBlocked() const { + return writer_->IsWriteBlocked(); +} + +void QuicPacketWriterWrapper::SetWritable() { writer_->SetWritable(); } + +absl::optional QuicPacketWriterWrapper::MessageTooBigErrorCode() const { + return writer_->MessageTooBigErrorCode(); +} + +QuicByteCount QuicPacketWriterWrapper::GetMaxPacketSize( + const QuicSocketAddress& peer_address) const { + return writer_->GetMaxPacketSize(peer_address); +} + +bool QuicPacketWriterWrapper::SupportsReleaseTime() const { + return writer_->SupportsReleaseTime(); +} + +bool QuicPacketWriterWrapper::IsBatchMode() const { + return writer_->IsBatchMode(); +} + +QuicPacketBuffer QuicPacketWriterWrapper::GetNextWriteLocation( + const QuicIpAddress& self_address, const QuicSocketAddress& peer_address) { + return writer_->GetNextWriteLocation(self_address, peer_address); +} + +WriteResult QuicPacketWriterWrapper::Flush() { return writer_->Flush(); } + +void QuicPacketWriterWrapper::set_writer(QuicPacketWriter* writer) { + unset_writer(); + writer_ = writer; + owns_writer_ = true; +} + +void QuicPacketWriterWrapper::set_non_owning_writer(QuicPacketWriter* writer) { + unset_writer(); + writer_ = writer; + owns_writer_ = false; +} + +void QuicPacketWriterWrapper::unset_writer() { + if (owns_writer_) { + delete writer_; + } + + owns_writer_ = false; + writer_ = nullptr; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_packet_writer_wrapper.h b/quiche/quic/core/quic_packet_writer_wrapper.h new file mode 100644 index 000000000000..3afeaf17491a --- /dev/null +++ b/quiche/quic/core/quic_packet_writer_wrapper.h @@ -0,0 +1,62 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PACKET_WRITER_WRAPPER_H_ +#define QUICHE_QUIC_CORE_QUIC_PACKET_WRITER_WRAPPER_H_ + +#include +#include + +#include "quiche/quic/core/quic_packet_writer.h" + +namespace quic { + +// Wraps a writer object to allow dynamically extending functionality. Use +// cases: replace writer while dispatcher and connections hold on to the +// wrapper; mix in monitoring; mix in mocks in unit tests. +class QUIC_NO_EXPORT QuicPacketWriterWrapper : public QuicPacketWriter { + public: + QuicPacketWriterWrapper(); + QuicPacketWriterWrapper(const QuicPacketWriterWrapper&) = delete; + QuicPacketWriterWrapper& operator=(const QuicPacketWriterWrapper&) = delete; + ~QuicPacketWriterWrapper() override; + + // Default implementation of the QuicPacketWriter interface. Passes everything + // to |writer_|. + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + bool IsWriteBlocked() const override; + void SetWritable() override; + absl::optional MessageTooBigErrorCode() const override; + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& peer_address) const override; + bool SupportsReleaseTime() const override; + bool IsBatchMode() const override; + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address) override; + WriteResult Flush() override; + + // Takes ownership of |writer|. + void set_writer(QuicPacketWriter* writer); + + // Does not take ownership of |writer|. + void set_non_owning_writer(QuicPacketWriter* writer); + + virtual void set_peer_address(const QuicSocketAddress& /*peer_address*/) {} + + QuicPacketWriter* writer() { return writer_; } + + private: + void unset_writer(); + + QuicPacketWriter* writer_ = nullptr; + bool owns_writer_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PACKET_WRITER_WRAPPER_H_ diff --git a/quiche/quic/core/quic_packets.cc b/quiche/quic/core/quic_packets.cc new file mode 100644 index 000000000000..39e7eaf51d55 --- /dev/null +++ b/quiche/quic/core/quic_packets.cc @@ -0,0 +1,601 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packets.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +QuicConnectionId GetServerConnectionIdAsRecipient( + const QuicPacketHeader& header, Perspective perspective) { + if (perspective == Perspective::IS_SERVER) { + return header.destination_connection_id; + } + return header.source_connection_id; +} + +QuicConnectionId GetClientConnectionIdAsRecipient( + const QuicPacketHeader& header, Perspective perspective) { + if (perspective == Perspective::IS_CLIENT) { + return header.destination_connection_id; + } + return header.source_connection_id; +} + +QuicConnectionId GetServerConnectionIdAsSender(const QuicPacketHeader& header, + Perspective perspective) { + if (perspective == Perspective::IS_CLIENT) { + return header.destination_connection_id; + } + return header.source_connection_id; +} + +QuicConnectionIdIncluded GetServerConnectionIdIncludedAsSender( + const QuicPacketHeader& header, Perspective perspective) { + if (perspective == Perspective::IS_CLIENT) { + return header.destination_connection_id_included; + } + return header.source_connection_id_included; +} + +QuicConnectionId GetClientConnectionIdAsSender(const QuicPacketHeader& header, + Perspective perspective) { + if (perspective == Perspective::IS_CLIENT) { + return header.source_connection_id; + } + return header.destination_connection_id; +} + +QuicConnectionIdIncluded GetClientConnectionIdIncludedAsSender( + const QuicPacketHeader& header, Perspective perspective) { + if (perspective == Perspective::IS_CLIENT) { + return header.source_connection_id_included; + } + return header.destination_connection_id_included; +} + +uint8_t GetIncludedConnectionIdLength( + QuicConnectionId connection_id, + QuicConnectionIdIncluded connection_id_included) { + QUICHE_DCHECK(connection_id_included == CONNECTION_ID_PRESENT || + connection_id_included == CONNECTION_ID_ABSENT); + return connection_id_included == CONNECTION_ID_PRESENT + ? connection_id.length() + : 0; +} + +uint8_t GetIncludedDestinationConnectionIdLength( + const QuicPacketHeader& header) { + return GetIncludedConnectionIdLength( + header.destination_connection_id, + header.destination_connection_id_included); +} + +uint8_t GetIncludedSourceConnectionIdLength(const QuicPacketHeader& header) { + return GetIncludedConnectionIdLength(header.source_connection_id, + header.source_connection_id_included); +} + +size_t GetPacketHeaderSize(QuicTransportVersion version, + const QuicPacketHeader& header) { + return GetPacketHeaderSize( + version, GetIncludedDestinationConnectionIdLength(header), + GetIncludedSourceConnectionIdLength(header), header.version_flag, + header.nonce != nullptr, header.packet_number_length, + header.retry_token_length_length, header.retry_token.length(), + header.length_length); +} + +size_t GetPacketHeaderSize( + QuicTransportVersion version, uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool include_version, + bool include_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + QuicByteCount retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length) { + if (VersionHasIetfInvariantHeader(version)) { + if (include_version) { + // Long header. + size_t size = kPacketHeaderTypeSize + kConnectionIdLengthSize + + destination_connection_id_length + + source_connection_id_length + packet_number_length + + kQuicVersionSize; + if (include_diversification_nonce) { + size += kDiversificationNonceSize; + } + if (VersionHasLengthPrefixedConnectionIds(version)) { + size += kConnectionIdLengthSize; + } + QUICHE_DCHECK( + QuicVersionHasLongHeaderLengths(version) || + retry_token_length_length + retry_token_length + length_length == 0); + if (QuicVersionHasLongHeaderLengths(version)) { + size += retry_token_length_length + retry_token_length + length_length; + } + return size; + } + // Short header. + return kPacketHeaderTypeSize + destination_connection_id_length + + packet_number_length; + } + // Google QUIC versions <= 43 can only carry one connection ID. + QUICHE_DCHECK(destination_connection_id_length == 0 || + source_connection_id_length == 0); + return kPublicFlagsSize + destination_connection_id_length + + source_connection_id_length + + (include_version ? kQuicVersionSize : 0) + packet_number_length + + (include_diversification_nonce ? kDiversificationNonceSize : 0); +} + +size_t GetStartOfEncryptedData(QuicTransportVersion version, + const QuicPacketHeader& header) { + return GetPacketHeaderSize(version, header); +} + +size_t GetStartOfEncryptedData( + QuicTransportVersion version, uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool include_version, + bool include_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + QuicByteCount retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length) { + // Encryption starts before private flags. + return GetPacketHeaderSize( + version, destination_connection_id_length, source_connection_id_length, + include_version, include_diversification_nonce, packet_number_length, + retry_token_length_length, retry_token_length, length_length); +} + +QuicPacketHeader::QuicPacketHeader() + : destination_connection_id(EmptyQuicConnectionId()), + destination_connection_id_included(CONNECTION_ID_PRESENT), + source_connection_id(EmptyQuicConnectionId()), + source_connection_id_included(CONNECTION_ID_ABSENT), + reset_flag(false), + version_flag(false), + has_possible_stateless_reset_token(false), + packet_number_length(PACKET_4BYTE_PACKET_NUMBER), + type_byte(0), + version(UnsupportedQuicVersion()), + nonce(nullptr), + form(GOOGLE_QUIC_PACKET), + long_packet_type(INITIAL), + possible_stateless_reset_token({}), + retry_token_length_length(quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0), + retry_token(absl::string_view()), + length_length(quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0), + remaining_packet_length(0) {} + +QuicPacketHeader::QuicPacketHeader(const QuicPacketHeader& other) = default; + +QuicPacketHeader::~QuicPacketHeader() {} + +QuicPacketHeader& QuicPacketHeader::operator=(const QuicPacketHeader& other) = + default; + +QuicPublicResetPacket::QuicPublicResetPacket() + : connection_id(EmptyQuicConnectionId()), nonce_proof(0) {} + +QuicPublicResetPacket::QuicPublicResetPacket(QuicConnectionId connection_id) + : connection_id(connection_id), nonce_proof(0) {} + +QuicVersionNegotiationPacket::QuicVersionNegotiationPacket() + : connection_id(EmptyQuicConnectionId()) {} + +QuicVersionNegotiationPacket::QuicVersionNegotiationPacket( + QuicConnectionId connection_id) + : connection_id(connection_id) {} + +QuicVersionNegotiationPacket::QuicVersionNegotiationPacket( + const QuicVersionNegotiationPacket& other) = default; + +QuicVersionNegotiationPacket::~QuicVersionNegotiationPacket() {} + +QuicIetfStatelessResetPacket::QuicIetfStatelessResetPacket() + : stateless_reset_token({}) {} + +QuicIetfStatelessResetPacket::QuicIetfStatelessResetPacket( + const QuicPacketHeader& header, StatelessResetToken token) + : header(header), stateless_reset_token(token) {} + +QuicIetfStatelessResetPacket::QuicIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& other) = default; + +QuicIetfStatelessResetPacket::~QuicIetfStatelessResetPacket() {} + +std::ostream& operator<<(std::ostream& os, const QuicPacketHeader& header) { + os << "{ destination_connection_id: " << header.destination_connection_id + << " (" + << (header.destination_connection_id_included == CONNECTION_ID_PRESENT + ? "present" + : "absent") + << "), source_connection_id: " << header.source_connection_id << " (" + << (header.source_connection_id_included == CONNECTION_ID_PRESENT + ? "present" + : "absent") + << "), packet_number_length: " + << static_cast(header.packet_number_length) + << ", reset_flag: " << header.reset_flag + << ", version_flag: " << header.version_flag; + if (header.version_flag) { + os << ", version: " << ParsedQuicVersionToString(header.version); + if (header.long_packet_type != INVALID_PACKET_TYPE) { + os << ", long_packet_type: " + << QuicUtils::QuicLongHeaderTypetoString(header.long_packet_type); + } + if (header.retry_token_length_length != + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0) { + os << ", retry_token_length_length: " + << static_cast(header.retry_token_length_length); + } + if (header.retry_token.length() != 0) { + os << ", retry_token_length: " << header.retry_token.length(); + } + if (header.length_length != quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0) { + os << ", length_length: " << static_cast(header.length_length); + } + if (header.remaining_packet_length != 0) { + os << ", remaining_packet_length: " << header.remaining_packet_length; + } + } + if (header.nonce != nullptr) { + os << ", diversification_nonce: " + << absl::BytesToHexString( + absl::string_view(header.nonce->data(), header.nonce->size())); + } + os << ", packet_number: " << header.packet_number << " }\n"; + return os; +} + +QuicData::QuicData(const char* buffer, size_t length) + : buffer_(buffer), length_(length), owns_buffer_(false) {} + +QuicData::QuicData(const char* buffer, size_t length, bool owns_buffer) + : buffer_(buffer), length_(length), owns_buffer_(owns_buffer) {} + +QuicData::QuicData(absl::string_view packet_data) + : buffer_(packet_data.data()), + length_(packet_data.length()), + owns_buffer_(false) {} + +QuicData::~QuicData() { + if (owns_buffer_) { + delete[] const_cast(buffer_); + } +} + +QuicPacket::QuicPacket( + char* buffer, size_t length, bool owns_buffer, + uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool includes_version, + bool includes_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + QuicByteCount retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length) + : QuicData(buffer, length, owns_buffer), + buffer_(buffer), + destination_connection_id_length_(destination_connection_id_length), + source_connection_id_length_(source_connection_id_length), + includes_version_(includes_version), + includes_diversification_nonce_(includes_diversification_nonce), + packet_number_length_(packet_number_length), + retry_token_length_length_(retry_token_length_length), + retry_token_length_(retry_token_length), + length_length_(length_length) {} + +QuicPacket::QuicPacket(QuicTransportVersion /*version*/, char* buffer, + size_t length, bool owns_buffer, + const QuicPacketHeader& header) + : QuicPacket(buffer, length, owns_buffer, + GetIncludedDestinationConnectionIdLength(header), + GetIncludedSourceConnectionIdLength(header), + header.version_flag, header.nonce != nullptr, + header.packet_number_length, header.retry_token_length_length, + header.retry_token.length(), header.length_length) {} + +QuicEncryptedPacket::QuicEncryptedPacket(const char* buffer, size_t length) + : QuicData(buffer, length) {} + +QuicEncryptedPacket::QuicEncryptedPacket(const char* buffer, size_t length, + bool owns_buffer) + : QuicData(buffer, length, owns_buffer) {} + +QuicEncryptedPacket::QuicEncryptedPacket(absl::string_view data) + : QuicData(data) {} + +std::unique_ptr QuicEncryptedPacket::Clone() const { + char* buffer = new char[this->length()]; + memcpy(buffer, this->data(), this->length()); + return std::make_unique(buffer, this->length(), true); +} + +std::ostream& operator<<(std::ostream& os, const QuicEncryptedPacket& s) { + os << s.length() << "-byte data"; + return os; +} + +QuicReceivedPacket::QuicReceivedPacket(const char* buffer, size_t length, + QuicTime receipt_time) + : QuicReceivedPacket(buffer, length, receipt_time, + false /* owns_buffer */) {} + +QuicReceivedPacket::QuicReceivedPacket(const char* buffer, size_t length, + QuicTime receipt_time, bool owns_buffer) + : QuicReceivedPacket(buffer, length, receipt_time, owns_buffer, 0 /* ttl */, + true /* ttl_valid */) {} + +QuicReceivedPacket::QuicReceivedPacket(const char* buffer, size_t length, + QuicTime receipt_time, bool owns_buffer, + int ttl, bool ttl_valid) + : quic::QuicReceivedPacket(buffer, length, receipt_time, owns_buffer, ttl, + ttl_valid, nullptr /* packet_headers */, + 0 /* headers_length */, + false /* owns_header_buffer */, ECN_NOT_ECT) {} + +QuicReceivedPacket::QuicReceivedPacket(const char* buffer, size_t length, + QuicTime receipt_time, bool owns_buffer, + int ttl, bool ttl_valid, + char* packet_headers, + size_t headers_length, + bool owns_header_buffer) + : quic::QuicReceivedPacket(buffer, length, receipt_time, owns_buffer, ttl, + ttl_valid, packet_headers, headers_length, + owns_header_buffer, ECN_NOT_ECT) {} + +QuicReceivedPacket::QuicReceivedPacket( + const char* buffer, size_t length, QuicTime receipt_time, bool owns_buffer, + int ttl, bool ttl_valid, char* packet_headers, size_t headers_length, + bool owns_header_buffer, QuicEcnCodepoint ecn_codepoint) + : QuicEncryptedPacket(buffer, length, owns_buffer), + receipt_time_(receipt_time), + ttl_(ttl_valid ? ttl : -1), + packet_headers_(packet_headers), + headers_length_(headers_length), + owns_header_buffer_(owns_header_buffer), + ecn_codepoint_(ecn_codepoint) {} + +QuicReceivedPacket::~QuicReceivedPacket() { + if (owns_header_buffer_) { + delete[] static_cast(packet_headers_); + } +} + +std::unique_ptr QuicReceivedPacket::Clone() const { + char* buffer = new char[this->length()]; + memcpy(buffer, this->data(), this->length()); + if (this->packet_headers()) { + char* headers_buffer = new char[this->headers_length()]; + memcpy(headers_buffer, this->packet_headers(), this->headers_length()); + return std::make_unique( + buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0, + headers_buffer, this->headers_length(), true); + } + + return std::make_unique( + buffer, this->length(), receipt_time(), true, ttl(), ttl() >= 0); +} + +std::ostream& operator<<(std::ostream& os, const QuicReceivedPacket& s) { + os << s.length() << "-byte data"; + return os; +} + +absl::string_view QuicPacket::AssociatedData( + QuicTransportVersion version) const { + return absl::string_view( + data(), + GetStartOfEncryptedData(version, destination_connection_id_length_, + source_connection_id_length_, includes_version_, + includes_diversification_nonce_, + packet_number_length_, retry_token_length_length_, + retry_token_length_, length_length_)); +} + +absl::string_view QuicPacket::Plaintext(QuicTransportVersion version) const { + const size_t start_of_encrypted_data = GetStartOfEncryptedData( + version, destination_connection_id_length_, source_connection_id_length_, + includes_version_, includes_diversification_nonce_, packet_number_length_, + retry_token_length_length_, retry_token_length_, length_length_); + return absl::string_view(data() + start_of_encrypted_data, + length() - start_of_encrypted_data); +} + +SerializedPacket::SerializedPacket(QuicPacketNumber packet_number, + QuicPacketNumberLength packet_number_length, + const char* encrypted_buffer, + QuicPacketLength encrypted_length, + bool has_ack, bool has_stop_waiting) + : encrypted_buffer(encrypted_buffer), + encrypted_length(encrypted_length), + has_crypto_handshake(NOT_HANDSHAKE), + packet_number(packet_number), + packet_number_length(packet_number_length), + encryption_level(ENCRYPTION_INITIAL), + has_ack(has_ack), + has_stop_waiting(has_stop_waiting), + transmission_type(NOT_RETRANSMISSION), + has_ack_frame_copy(false), + has_ack_frequency(false), + has_message(false), + fate(SEND_TO_WRITER) {} + +SerializedPacket::SerializedPacket(SerializedPacket&& other) + : has_crypto_handshake(other.has_crypto_handshake), + packet_number(other.packet_number), + packet_number_length(other.packet_number_length), + encryption_level(other.encryption_level), + has_ack(other.has_ack), + has_stop_waiting(other.has_stop_waiting), + has_ack_ecn(other.has_ack_ecn), + transmission_type(other.transmission_type), + largest_acked(other.largest_acked), + has_ack_frame_copy(other.has_ack_frame_copy), + has_ack_frequency(other.has_ack_frequency), + has_message(other.has_message), + fate(other.fate), + peer_address(other.peer_address), + bytes_not_retransmitted(other.bytes_not_retransmitted), + initial_header(other.initial_header) { + if (this != &other) { + if (release_encrypted_buffer && encrypted_buffer != nullptr) { + release_encrypted_buffer(encrypted_buffer); + } + encrypted_buffer = other.encrypted_buffer; + encrypted_length = other.encrypted_length; + release_encrypted_buffer = std::move(other.release_encrypted_buffer); + other.release_encrypted_buffer = nullptr; + + retransmittable_frames.swap(other.retransmittable_frames); + nonretransmittable_frames.swap(other.nonretransmittable_frames); + } +} + +SerializedPacket::~SerializedPacket() { + if (release_encrypted_buffer && encrypted_buffer != nullptr) { + release_encrypted_buffer(encrypted_buffer); + } + + if (!retransmittable_frames.empty()) { + DeleteFrames(&retransmittable_frames); + } + for (auto& frame : nonretransmittable_frames) { + if (!has_ack_frame_copy && frame.type == ACK_FRAME) { + // Do not delete ack frame if the packet does not own a copy of it. + continue; + } + DeleteFrame(&frame); + } +} + +SerializedPacket* CopySerializedPacket(const SerializedPacket& serialized, + quiche::QuicheBufferAllocator* allocator, + bool copy_buffer) { + SerializedPacket* copy = new SerializedPacket( + serialized.packet_number, serialized.packet_number_length, + serialized.encrypted_buffer, serialized.encrypted_length, + serialized.has_ack, serialized.has_stop_waiting); + copy->has_crypto_handshake = serialized.has_crypto_handshake; + copy->encryption_level = serialized.encryption_level; + copy->transmission_type = serialized.transmission_type; + copy->largest_acked = serialized.largest_acked; + copy->has_ack_frequency = serialized.has_ack_frequency; + copy->has_message = serialized.has_message; + copy->fate = serialized.fate; + copy->peer_address = serialized.peer_address; + copy->bytes_not_retransmitted = serialized.bytes_not_retransmitted; + copy->initial_header = serialized.initial_header; + copy->has_ack_ecn = serialized.has_ack_ecn; + + if (copy_buffer) { + copy->encrypted_buffer = CopyBuffer(serialized); + copy->release_encrypted_buffer = [](const char* p) { delete[] p; }; + } + // Copy underlying frames. + copy->retransmittable_frames = + CopyQuicFrames(allocator, serialized.retransmittable_frames); + QUICHE_DCHECK(copy->nonretransmittable_frames.empty()); + for (const auto& frame : serialized.nonretransmittable_frames) { + if (frame.type == ACK_FRAME) { + copy->has_ack_frame_copy = true; + } + copy->nonretransmittable_frames.push_back(CopyQuicFrame(allocator, frame)); + } + return copy; +} + +char* CopyBuffer(const SerializedPacket& packet) { + return CopyBuffer(packet.encrypted_buffer, packet.encrypted_length); +} + +char* CopyBuffer(const char* encrypted_buffer, + QuicPacketLength encrypted_length) { + char* dst_buffer = new char[encrypted_length]; + memcpy(dst_buffer, encrypted_buffer, encrypted_length); + return dst_buffer; +} + +ReceivedPacketInfo::ReceivedPacketInfo(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) + : self_address(self_address), + peer_address(peer_address), + packet(packet), + form(GOOGLE_QUIC_PACKET), + long_packet_type(INVALID_PACKET_TYPE), + version_flag(false), + use_length_prefix(false), + version_label(0), + version(ParsedQuicVersion::Unsupported()), + destination_connection_id(EmptyQuicConnectionId()), + source_connection_id(EmptyQuicConnectionId()) {} + +ReceivedPacketInfo::~ReceivedPacketInfo() {} + +std::string ReceivedPacketInfo::ToString() const { + std::string output = + absl::StrCat("{ self_address: ", self_address.ToString(), + ", peer_address: ", peer_address.ToString(), + ", packet_length: ", packet.length(), + ", header_format: ", form, ", version_flag: ", version_flag); + if (version_flag) { + absl::StrAppend(&output, ", version: ", ParsedQuicVersionToString(version)); + } + absl::StrAppend( + &output, + ", destination_connection_id: ", destination_connection_id.ToString(), + ", source_connection_id: ", source_connection_id.ToString(), " }\n"); + return output; +} + +std::ostream& operator<<(std::ostream& os, + const ReceivedPacketInfo& packet_info) { + os << packet_info.ToString(); + return os; +} + +bool QuicPacketHeader::operator==(const QuicPacketHeader& other) const { + return destination_connection_id == other.destination_connection_id && + destination_connection_id_included == + other.destination_connection_id_included && + source_connection_id == other.source_connection_id && + source_connection_id_included == other.source_connection_id_included && + reset_flag == other.reset_flag && version_flag == other.version_flag && + has_possible_stateless_reset_token == + other.has_possible_stateless_reset_token && + packet_number_length == other.packet_number_length && + type_byte == other.type_byte && version == other.version && + nonce == other.nonce && + ((!packet_number.IsInitialized() && + !other.packet_number.IsInitialized()) || + (packet_number.IsInitialized() && + other.packet_number.IsInitialized() && + packet_number == other.packet_number)) && + form == other.form && long_packet_type == other.long_packet_type && + possible_stateless_reset_token == + other.possible_stateless_reset_token && + retry_token_length_length == other.retry_token_length_length && + retry_token == other.retry_token && + length_length == other.length_length && + remaining_packet_length == other.remaining_packet_length; +} + +bool QuicPacketHeader::operator!=(const QuicPacketHeader& other) const { + return !operator==(other); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_packets.h b/quiche/quic/core/quic_packets.h new file mode 100644 index 000000000000..d1eb52baa358 --- /dev/null +++ b/quiche/quic/core/quic_packets.h @@ -0,0 +1,452 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PACKETS_H_ +#define QUICHE_QUIC_CORE_QUIC_PACKETS_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_ack_listener_interface.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +class QuicPacket; +struct QuicPacketHeader; + +// Returns the destination connection ID of |header| when |perspective| is +// server, and the source connection ID when |perspective| is client. +QUIC_EXPORT_PRIVATE QuicConnectionId GetServerConnectionIdAsRecipient( + const QuicPacketHeader& header, Perspective perspective); + +// Returns the destination connection ID of |header| when |perspective| is +// client, and the source connection ID when |perspective| is server. +QUIC_EXPORT_PRIVATE QuicConnectionId GetClientConnectionIdAsRecipient( + const QuicPacketHeader& header, Perspective perspective); + +// Returns the destination connection ID of |header| when |perspective| is +// client, and the source connection ID when |perspective| is server. +QUIC_EXPORT_PRIVATE QuicConnectionId GetServerConnectionIdAsSender( + const QuicPacketHeader& header, Perspective perspective); + +// Returns the destination connection ID included of |header| when |perspective| +// is client, and the source connection ID included when |perspective| is +// server. +QUIC_EXPORT_PRIVATE QuicConnectionIdIncluded +GetServerConnectionIdIncludedAsSender(const QuicPacketHeader& header, + Perspective perspective); + +// Returns the destination connection ID of |header| when |perspective| is +// server, and the source connection ID when |perspective| is client. +QUIC_EXPORT_PRIVATE QuicConnectionId GetClientConnectionIdAsSender( + const QuicPacketHeader& header, Perspective perspective); + +// Returns the destination connection ID included of |header| when |perspective| +// is server, and the source connection ID included when |perspective| is +// client. +QUIC_EXPORT_PRIVATE QuicConnectionIdIncluded +GetClientConnectionIdIncludedAsSender(const QuicPacketHeader& header, + Perspective perspective); + +// Number of connection ID bytes that are actually included over the wire. +QUIC_EXPORT_PRIVATE uint8_t +GetIncludedConnectionIdLength(QuicConnectionId connection_id, + QuicConnectionIdIncluded connection_id_included); + +// Number of destination connection ID bytes that are actually included over the +// wire for this particular header. +QUIC_EXPORT_PRIVATE uint8_t +GetIncludedDestinationConnectionIdLength(const QuicPacketHeader& header); + +// Number of source connection ID bytes that are actually included over the +// wire for this particular header. +QUIC_EXPORT_PRIVATE uint8_t +GetIncludedSourceConnectionIdLength(const QuicPacketHeader& header); + +// Size in bytes of the data packet header. +QUIC_EXPORT_PRIVATE size_t GetPacketHeaderSize(QuicTransportVersion version, + const QuicPacketHeader& header); + +QUIC_EXPORT_PRIVATE size_t GetPacketHeaderSize( + QuicTransportVersion version, uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool include_version, + bool include_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + QuicByteCount retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length); + +// Index of the first byte in a QUIC packet of encrypted data. +QUIC_EXPORT_PRIVATE size_t GetStartOfEncryptedData( + QuicTransportVersion version, const QuicPacketHeader& header); + +QUIC_EXPORT_PRIVATE size_t GetStartOfEncryptedData( + QuicTransportVersion version, uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool include_version, + bool include_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + QuicByteCount retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length); + +struct QUIC_EXPORT_PRIVATE QuicPacketHeader { + QuicPacketHeader(); + QuicPacketHeader(const QuicPacketHeader& other); + ~QuicPacketHeader(); + + QuicPacketHeader& operator=(const QuicPacketHeader& other); + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicPacketHeader& header); + + // Universal header. All QuicPacket headers will have a connection_id and + // public flags. + QuicConnectionId destination_connection_id; + QuicConnectionIdIncluded destination_connection_id_included; + QuicConnectionId source_connection_id; + QuicConnectionIdIncluded source_connection_id_included; + // This is only used for Google QUIC. + bool reset_flag; + // For Google QUIC, version flag in packets from the server means version + // negotiation packet. For IETF QUIC, version flag means long header. + bool version_flag; + // Indicates whether |possible_stateless_reset_token| contains a valid value + // parsed from the packet buffer. IETF QUIC only, always false for GQUIC. + bool has_possible_stateless_reset_token; + QuicPacketNumberLength packet_number_length; + uint8_t type_byte; + ParsedQuicVersion version; + // nonce contains an optional, 32-byte nonce value. If not included in the + // packet, |nonce| will be empty. + DiversificationNonce* nonce; + QuicPacketNumber packet_number; + // Format of this header. + PacketHeaderFormat form; + // Short packet type is reflected in packet_number_length. + QuicLongHeaderType long_packet_type; + // Only valid if |has_possible_stateless_reset_token| is true. + // Stores last 16 bytes of a this packet, used to check whether this packet is + // a stateless reset packet on decryption failure. + StatelessResetToken possible_stateless_reset_token; + // Length of the retry token length variable length integer field, + // carried only by v99 IETF Initial packets. + quiche::QuicheVariableLengthIntegerLength retry_token_length_length; + // Retry token, carried only by v99 IETF Initial packets. + absl::string_view retry_token; + // Length of the length variable length integer field, + // carried only by v99 IETF Initial, 0-RTT and Handshake packets. + quiche::QuicheVariableLengthIntegerLength length_length; + // Length of the packet number and payload, carried only by v99 IETF Initial, + // 0-RTT and Handshake packets. Also includes the length of the + // diversification nonce in server to client 0-RTT packets. + QuicByteCount remaining_packet_length; + + bool operator==(const QuicPacketHeader& other) const; + bool operator!=(const QuicPacketHeader& other) const; +}; + +struct QUIC_EXPORT_PRIVATE QuicPublicResetPacket { + QuicPublicResetPacket(); + explicit QuicPublicResetPacket(QuicConnectionId connection_id); + + QuicConnectionId connection_id; + QuicPublicResetNonceProof nonce_proof; + QuicSocketAddress client_address; + // An arbitrary string to identify an endpoint. Used by clients to + // differentiate traffic from Google servers vs Non-google servers. + // Will not be used if empty(). + std::string endpoint_id; +}; + +struct QUIC_EXPORT_PRIVATE QuicVersionNegotiationPacket { + QuicVersionNegotiationPacket(); + explicit QuicVersionNegotiationPacket(QuicConnectionId connection_id); + QuicVersionNegotiationPacket(const QuicVersionNegotiationPacket& other); + ~QuicVersionNegotiationPacket(); + + QuicConnectionId connection_id; + ParsedQuicVersionVector versions; +}; + +struct QUIC_EXPORT_PRIVATE QuicIetfStatelessResetPacket { + QuicIetfStatelessResetPacket(); + QuicIetfStatelessResetPacket(const QuicPacketHeader& header, + StatelessResetToken token); + QuicIetfStatelessResetPacket(const QuicIetfStatelessResetPacket& other); + ~QuicIetfStatelessResetPacket(); + + QuicPacketHeader header; + StatelessResetToken stateless_reset_token; +}; + +class QUIC_EXPORT_PRIVATE QuicData { + public: + // Creates a QuicData from a buffer and length. Does not own the buffer. + QuicData(const char* buffer, size_t length); + // Creates a QuicData from a buffer and length, + // optionally taking ownership of the buffer. + QuicData(const char* buffer, size_t length, bool owns_buffer); + // Creates a QuicData from a absl::string_view. Does not own the + // buffer. + QuicData(absl::string_view data); + QuicData(const QuicData&) = delete; + QuicData& operator=(const QuicData&) = delete; + virtual ~QuicData(); + + absl::string_view AsStringPiece() const { + return absl::string_view(data(), length()); + } + + const char* data() const { return buffer_; } + size_t length() const { return length_; } + + private: + const char* buffer_; + size_t length_; + bool owns_buffer_; +}; + +class QUIC_EXPORT_PRIVATE QuicPacket : public QuicData { + public: + QuicPacket( + char* buffer, size_t length, bool owns_buffer, + uint8_t destination_connection_id_length, + uint8_t source_connection_id_length, bool includes_version, + bool includes_diversification_nonce, + QuicPacketNumberLength packet_number_length, + quiche::QuicheVariableLengthIntegerLength retry_token_length_length, + QuicByteCount retry_token_length, + quiche::QuicheVariableLengthIntegerLength length_length); + QuicPacket(QuicTransportVersion version, char* buffer, size_t length, + bool owns_buffer, const QuicPacketHeader& header); + QuicPacket(const QuicPacket&) = delete; + QuicPacket& operator=(const QuicPacket&) = delete; + + absl::string_view AssociatedData(QuicTransportVersion version) const; + absl::string_view Plaintext(QuicTransportVersion version) const; + + char* mutable_data() { return buffer_; } + + private: + char* buffer_; + const uint8_t destination_connection_id_length_; + const uint8_t source_connection_id_length_; + const bool includes_version_; + const bool includes_diversification_nonce_; + const QuicPacketNumberLength packet_number_length_; + const quiche::QuicheVariableLengthIntegerLength retry_token_length_length_; + const QuicByteCount retry_token_length_; + const quiche::QuicheVariableLengthIntegerLength length_length_; +}; + +class QUIC_EXPORT_PRIVATE QuicEncryptedPacket : public QuicData { + public: + // Creates a QuicEncryptedPacket from a buffer and length. + // Does not own the buffer. + QuicEncryptedPacket(const char* buffer, size_t length); + // Creates a QuicEncryptedPacket from a buffer and length, + // optionally taking ownership of the buffer. + QuicEncryptedPacket(const char* buffer, size_t length, bool owns_buffer); + // Creates a QuicEncryptedPacket from a absl::string_view. + // Does not own the buffer. + QuicEncryptedPacket(absl::string_view data); + + QuicEncryptedPacket(const QuicEncryptedPacket&) = delete; + QuicEncryptedPacket& operator=(const QuicEncryptedPacket&) = delete; + + // Clones the packet into a new packet which owns the buffer. + std::unique_ptr Clone() const; + + // By default, gtest prints the raw bytes of an object. The bool data + // member (in the base class QuicData) causes this object to have padding + // bytes, which causes the default gtest object printer to read + // uninitialize memory. So we need to teach gtest how to print this object. + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicEncryptedPacket& s); +}; + +// A received encrypted QUIC packet, with a recorded time of receipt. +class QUIC_EXPORT_PRIVATE QuicReceivedPacket : public QuicEncryptedPacket { + public: + QuicReceivedPacket(const char* buffer, size_t length, QuicTime receipt_time); + QuicReceivedPacket(const char* buffer, size_t length, QuicTime receipt_time, + bool owns_buffer); + QuicReceivedPacket(const char* buffer, size_t length, QuicTime receipt_time, + bool owns_buffer, int ttl, bool ttl_valid); + QuicReceivedPacket(const char* buffer, size_t length, QuicTime receipt_time, + bool owns_buffer, int ttl, bool ttl_valid, + char* packet_headers, size_t headers_length, + bool owns_header_buffer); + QuicReceivedPacket(const char* buffer, size_t length, QuicTime receipt_time, + bool owns_buffer, int ttl, bool ttl_valid, + char* packet_headers, size_t headers_length, + bool owns_header_buffer, QuicEcnCodepoint ecn_codepoint); + ~QuicReceivedPacket(); + QuicReceivedPacket(const QuicReceivedPacket&) = delete; + QuicReceivedPacket& operator=(const QuicReceivedPacket&) = delete; + + // Clones the packet into a new packet which owns the buffer. + std::unique_ptr Clone() const; + + // Returns the time at which the packet was received. + QuicTime receipt_time() const { return receipt_time_; } + + // This is the TTL of the packet, assuming ttl_vaild_ is true. + int ttl() const { return ttl_; } + + // Start of packet headers. + char* packet_headers() const { return packet_headers_; } + + // Length of packet headers. + int headers_length() const { return headers_length_; } + + // By default, gtest prints the raw bytes of an object. The bool data + // member (in the base class QuicData) causes this object to have padding + // bytes, which causes the default gtest object printer to read + // uninitialize memory. So we need to teach gtest how to print this object. + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicReceivedPacket& s); + + QuicEcnCodepoint ecn_codepoint() const { return ecn_codepoint_; } + + private: + const QuicTime receipt_time_; + int ttl_; + // Points to the start of packet headers. + char* packet_headers_; + // Length of packet headers. + int headers_length_; + // Whether owns the buffer for packet headers. + bool owns_header_buffer_; + QuicEcnCodepoint ecn_codepoint_; +}; + +// SerializedPacket contains information of a serialized(encrypted) packet. +// +// WARNING: +// +// If you add a member field to this class, please make sure it is properly +// copied in |CopySerializedPacket|. +// +struct QUIC_EXPORT_PRIVATE SerializedPacket { + SerializedPacket(QuicPacketNumber packet_number, + QuicPacketNumberLength packet_number_length, + const char* encrypted_buffer, + QuicPacketLength encrypted_length, bool has_ack, + bool has_stop_waiting); + + // Copy constructor & assignment are deleted. Use |CopySerializedPacket| to + // make a copy. + SerializedPacket(const SerializedPacket& other) = delete; + SerializedPacket& operator=(const SerializedPacket& other) = delete; + SerializedPacket(SerializedPacket&& other); + ~SerializedPacket(); + + // TODO(wub): replace |encrypted_buffer|+|release_encrypted_buffer| by a + // QuicOwnedPacketBuffer. + // Not owned if |release_encrypted_buffer| is nullptr. Otherwise it is + // released by |release_encrypted_buffer| on destruction. + const char* encrypted_buffer; + QuicPacketLength encrypted_length; + std::function release_encrypted_buffer; + + QuicFrames retransmittable_frames; + QuicFrames nonretransmittable_frames; + IsHandshake has_crypto_handshake; + QuicPacketNumber packet_number; + QuicPacketNumberLength packet_number_length; + EncryptionLevel encryption_level; + // TODO(fayang): Remove has_ack and has_stop_waiting. + bool has_ack; + bool has_stop_waiting; + bool has_ack_ecn = false; // ack frame contains ECN counts. + TransmissionType transmission_type; + // The largest acked of the AckFrame in this packet if has_ack is true, + // 0 otherwise. + QuicPacketNumber largest_acked; + // Indicates whether this packet has a copy of ack frame in + // nonretransmittable_frames. + bool has_ack_frame_copy; + bool has_ack_frequency; + bool has_message; + SerializedPacketFate fate; + QuicSocketAddress peer_address; + // Sum of bytes from frames that are not retransmissions. This field is only + // populated for packets with "mixed frames": at least one frame of a + // retransmission type and at least one frame of NOT_RETRANSMISSION type. + absl::optional bytes_not_retransmitted; + // Only populated if encryption_level is ENCRYPTION_INITIAL. + // TODO(b/265777524): remove this. + absl::optional initial_header; +}; + +// Make a copy of |serialized| (including the underlying frames). |copy_buffer| +// indicates whether the encrypted buffer should be copied. +QUIC_EXPORT_PRIVATE SerializedPacket* CopySerializedPacket( + const SerializedPacket& serialized, + quiche::QuicheBufferAllocator* allocator, bool copy_buffer); + +// Allocates a new char[] of size |packet.encrypted_length| and copies in +// |packet.encrypted_buffer|. +QUIC_EXPORT_PRIVATE char* CopyBuffer(const SerializedPacket& packet); +// Allocates a new char[] of size |encrypted_length| and copies in +// |encrypted_buffer|. +QUIC_EXPORT_PRIVATE char* CopyBuffer(const char* encrypted_buffer, + QuicPacketLength encrypted_length); + +// Context for an incoming packet. +struct QUIC_EXPORT_PRIVATE QuicPerPacketContext { + virtual ~QuicPerPacketContext() {} +}; + +// ReceivedPacketInfo comprises information obtained by parsing the unencrypted +// bytes of a received packet. +struct QUIC_EXPORT_PRIVATE ReceivedPacketInfo { + ReceivedPacketInfo(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet); + ReceivedPacketInfo(const ReceivedPacketInfo& other) = default; + + ~ReceivedPacketInfo(); + + std::string ToString() const; + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const ReceivedPacketInfo& packet_info); + + const QuicSocketAddress& self_address; + const QuicSocketAddress& peer_address; + const QuicReceivedPacket& packet; + + PacketHeaderFormat form; + // This is only used if the form is IETF_QUIC_LONG_HEADER_PACKET. + QuicLongHeaderType long_packet_type; + bool version_flag; + bool use_length_prefix; + QuicVersionLabel version_label; + ParsedQuicVersion version; + QuicConnectionId destination_connection_id; + QuicConnectionId source_connection_id; + absl::optional retry_token; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PACKETS_H_ diff --git a/quiche/quic/core/quic_packets_test.cc b/quiche/quic/core/quic_packets_test.cc new file mode 100644 index 000000000000..4e6598ddb5b4 --- /dev/null +++ b/quiche/quic/core/quic_packets_test.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_packets.h" + +#include "absl/memory/memory.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { +namespace { + +QuicPacketHeader CreateFakePacketHeader() { + QuicPacketHeader header; + header.destination_connection_id = TestConnectionId(1); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + header.source_connection_id = TestConnectionId(2); + header.source_connection_id_included = CONNECTION_ID_ABSENT; + return header; +} + +class QuicPacketsTest : public QuicTest {}; + +TEST_F(QuicPacketsTest, GetServerConnectionIdAsRecipient) { + QuicPacketHeader header = CreateFakePacketHeader(); + EXPECT_EQ(TestConnectionId(1), + GetServerConnectionIdAsRecipient(header, Perspective::IS_SERVER)); + EXPECT_EQ(TestConnectionId(2), + GetServerConnectionIdAsRecipient(header, Perspective::IS_CLIENT)); +} + +TEST_F(QuicPacketsTest, GetServerConnectionIdAsSender) { + QuicPacketHeader header = CreateFakePacketHeader(); + EXPECT_EQ(TestConnectionId(2), + GetServerConnectionIdAsSender(header, Perspective::IS_SERVER)); + EXPECT_EQ(TestConnectionId(1), + GetServerConnectionIdAsSender(header, Perspective::IS_CLIENT)); +} + +TEST_F(QuicPacketsTest, GetServerConnectionIdIncludedAsSender) { + QuicPacketHeader header = CreateFakePacketHeader(); + EXPECT_EQ(CONNECTION_ID_ABSENT, GetServerConnectionIdIncludedAsSender( + header, Perspective::IS_SERVER)); + EXPECT_EQ(CONNECTION_ID_PRESENT, GetServerConnectionIdIncludedAsSender( + header, Perspective::IS_CLIENT)); +} + +TEST_F(QuicPacketsTest, GetClientConnectionIdIncludedAsSender) { + QuicPacketHeader header = CreateFakePacketHeader(); + EXPECT_EQ(CONNECTION_ID_PRESENT, GetClientConnectionIdIncludedAsSender( + header, Perspective::IS_SERVER)); + EXPECT_EQ(CONNECTION_ID_ABSENT, GetClientConnectionIdIncludedAsSender( + header, Perspective::IS_CLIENT)); +} + +TEST_F(QuicPacketsTest, GetClientConnectionIdAsRecipient) { + QuicPacketHeader header = CreateFakePacketHeader(); + EXPECT_EQ(TestConnectionId(2), + GetClientConnectionIdAsRecipient(header, Perspective::IS_SERVER)); + EXPECT_EQ(TestConnectionId(1), + GetClientConnectionIdAsRecipient(header, Perspective::IS_CLIENT)); +} + +TEST_F(QuicPacketsTest, GetClientConnectionIdAsSender) { + QuicPacketHeader header = CreateFakePacketHeader(); + EXPECT_EQ(TestConnectionId(1), + GetClientConnectionIdAsSender(header, Perspective::IS_SERVER)); + EXPECT_EQ(TestConnectionId(2), + GetClientConnectionIdAsSender(header, Perspective::IS_CLIENT)); +} + +TEST_F(QuicPacketsTest, CopyQuicPacketHeader) { + QuicPacketHeader header; + QuicPacketHeader header2 = CreateFakePacketHeader(); + EXPECT_NE(header, header2); + QuicPacketHeader header3(header2); + EXPECT_EQ(header2, header3); +} + +TEST_F(QuicPacketsTest, CopySerializedPacket) { + std::string buffer(1000, 'a'); + quiche::SimpleBufferAllocator allocator; + SerializedPacket packet(QuicPacketNumber(1), PACKET_1BYTE_PACKET_NUMBER, + buffer.data(), buffer.length(), /*has_ack=*/false, + /*has_stop_waiting=*/false); + packet.retransmittable_frames.push_back(QuicFrame(QuicWindowUpdateFrame())); + packet.retransmittable_frames.push_back(QuicFrame(QuicStreamFrame())); + + QuicAckFrame ack_frame(InitAckFrame(1)); + packet.nonretransmittable_frames.push_back(QuicFrame(&ack_frame)); + packet.nonretransmittable_frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + + std::unique_ptr copy = absl::WrapUnique( + CopySerializedPacket(packet, &allocator, /*copy_buffer=*/true)); + EXPECT_EQ(quic::QuicPacketNumber(1), copy->packet_number); + EXPECT_EQ(PACKET_1BYTE_PACKET_NUMBER, copy->packet_number_length); + ASSERT_EQ(2u, copy->retransmittable_frames.size()); + EXPECT_EQ(WINDOW_UPDATE_FRAME, copy->retransmittable_frames[0].type); + EXPECT_EQ(STREAM_FRAME, copy->retransmittable_frames[1].type); + + ASSERT_EQ(2u, copy->nonretransmittable_frames.size()); + EXPECT_EQ(ACK_FRAME, copy->nonretransmittable_frames[0].type); + EXPECT_EQ(PADDING_FRAME, copy->nonretransmittable_frames[1].type); + EXPECT_EQ(1000u, copy->encrypted_length); + quiche::test::CompareCharArraysWithHexError( + "encrypted_buffer", copy->encrypted_buffer, copy->encrypted_length, + packet.encrypted_buffer, packet.encrypted_length); + + std::unique_ptr copy2 = absl::WrapUnique( + CopySerializedPacket(packet, &allocator, /*copy_buffer=*/false)); + EXPECT_EQ(packet.encrypted_buffer, copy2->encrypted_buffer); + EXPECT_EQ(1000u, copy2->encrypted_length); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_path_validator.cc b/quiche/quic/core/quic_path_validator.cc new file mode 100644 index 000000000000..48c21d2330e1 --- /dev/null +++ b/quiche/quic/core/quic_path_validator.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_path_validator.h" + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +class RetryAlarmDelegate : public QuicAlarm::DelegateWithContext { + public: + explicit RetryAlarmDelegate(QuicPathValidator* path_validator, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), + path_validator_(path_validator) {} + RetryAlarmDelegate(const RetryAlarmDelegate&) = delete; + RetryAlarmDelegate& operator=(const RetryAlarmDelegate&) = delete; + + void OnAlarm() override { path_validator_->OnRetryTimeout(); } + + private: + QuicPathValidator* path_validator_; +}; + +std::ostream& operator<<(std::ostream& os, + const QuicPathValidationContext& context) { + return os << " from " << context.self_address_ << " to " + << context.peer_address_; +} + +QuicPathValidator::QuicPathValidator(QuicAlarmFactory* alarm_factory, + QuicConnectionArena* arena, + SendDelegate* send_delegate, + QuicRandom* random, const QuicClock* clock, + QuicConnectionContext* context) + : send_delegate_(send_delegate), + random_(random), + clock_(clock), + retry_timer_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)), + retry_count_(0u) {} + +void QuicPathValidator::OnPathResponse(const QuicPathFrameBuffer& probing_data, + QuicSocketAddress self_address) { + if (!HasPendingPathValidation()) { + return; + } + + QUIC_DVLOG(1) << "Match PATH_RESPONSE received on " << self_address; + QUIC_BUG_IF(quic_bug_12402_1, !path_context_->self_address().IsInitialized()) + << "Self address should have been known by now"; + if (self_address != path_context_->self_address()) { + QUIC_DVLOG(1) << "Expect the response to be received on " + << path_context_->self_address(); + return; + } + // This iterates at most 3 times. + for (auto it = probing_data_.begin(); it != probing_data_.end(); ++it) { + if (it->frame_buffer == probing_data) { + result_delegate_->OnPathValidationSuccess(std::move(path_context_), + it->send_time); + ResetPathValidation(); + return; + } + } + QUIC_DVLOG(1) << "PATH_RESPONSE with payload " << probing_data.data() + << " doesn't match the probing data."; +} + +void QuicPathValidator::StartPathValidation( + std::unique_ptr context, + std::unique_ptr result_delegate, + PathValidationReason reason) { + QUICHE_DCHECK(context); + QUIC_DLOG(INFO) << "Start validating path " << *context + << " via writer: " << context->WriterToUse(); + if (path_context_ != nullptr) { + QUIC_BUG(quic_bug_10876_1) + << "There is an on-going validation on path " << *path_context_; + ResetPathValidation(); + } + + reason_ = reason; + path_context_ = std::move(context); + result_delegate_ = std::move(result_delegate); + SendPathChallengeAndSetAlarm(); +} + +void QuicPathValidator::ResetPathValidation() { + path_context_ = nullptr; + result_delegate_ = nullptr; + retry_timer_->Cancel(); + retry_count_ = 0; + reason_ = PathValidationReason::kReasonUnknown; +} + +void QuicPathValidator::CancelPathValidation() { + if (path_context_ == nullptr) { + return; + } + QUIC_DVLOG(1) << "Cancel validation on path" << *path_context_; + result_delegate_->OnPathValidationFailure(std::move(path_context_)); + ResetPathValidation(); +} + +bool QuicPathValidator::HasPendingPathValidation() const { + return path_context_ != nullptr; +} + +QuicPathValidationContext* QuicPathValidator::GetContext() const { + return path_context_.get(); +} + +std::unique_ptr QuicPathValidator::ReleaseContext() { + auto ret = std::move(path_context_); + ResetPathValidation(); + return ret; +} + +const QuicPathFrameBuffer& QuicPathValidator::GeneratePathChallengePayload() { + probing_data_.emplace_back(clock_->Now()); + random_->RandBytes(probing_data_.back().frame_buffer.data(), + sizeof(QuicPathFrameBuffer)); + return probing_data_.back().frame_buffer; +} + +void QuicPathValidator::OnRetryTimeout() { + ++retry_count_; + if (retry_count_ > kMaxRetryTimes) { + CancelPathValidation(); + return; + } + QUIC_DVLOG(1) << "Send another PATH_CHALLENGE on path " << *path_context_; + SendPathChallengeAndSetAlarm(); +} + +void QuicPathValidator::SendPathChallengeAndSetAlarm() { + bool should_continue = send_delegate_->SendPathChallenge( + GeneratePathChallengePayload(), path_context_->self_address(), + path_context_->peer_address(), path_context_->effective_peer_address(), + path_context_->WriterToUse()); + + if (!should_continue) { + // The delegate doesn't want to continue the path validation. + CancelPathValidation(); + return; + } + retry_timer_->Set(send_delegate_->GetRetryTimeout( + path_context_->peer_address(), path_context_->WriterToUse())); +} + +bool QuicPathValidator::IsValidatingPeerAddress( + const QuicSocketAddress& effective_peer_address) { + return path_context_ != nullptr && + path_context_->effective_peer_address() == effective_peer_address; +} + +void QuicPathValidator::MaybeWritePacketToAddress( + const char* buffer, size_t buf_len, const QuicSocketAddress& peer_address) { + if (!HasPendingPathValidation() || + path_context_->peer_address() != peer_address) { + return; + } + QUIC_DVLOG(1) << "Path validator is sending packet of size " << buf_len + << " from " << path_context_->self_address() << " to " + << path_context_->peer_address(); + path_context_->WriterToUse()->WritePacket( + buffer, buf_len, path_context_->self_address().host(), + path_context_->peer_address(), nullptr); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_path_validator.h b/quiche/quic/core/quic_path_validator.h new file mode 100644 index 000000000000..1079f21ca515 --- /dev/null +++ b/quiche/quic/core/quic_path_validator.h @@ -0,0 +1,194 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PATH_VALIDATOR_H_ +#define QUICHE_QUIC_CORE_QUIC_PATH_VALIDATOR_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_arena_scoped_ptr.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +namespace test { +class QuicPathValidatorPeer; +} + +class QuicConnection; + +enum class PathValidationReason { + kReasonUnknown, + kMultiPort, + kReversePathValidation, + kServerPreferredAddressMigration, + kPortMigration, + kConnectionMigration, + kMaxValue, +}; + +// Interface to provide the information of the path to be validated. +class QUIC_EXPORT_PRIVATE QuicPathValidationContext { + public: + QuicPathValidationContext(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) + : self_address_(self_address), + peer_address_(peer_address), + effective_peer_address_(peer_address) {} + + QuicPathValidationContext(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& effective_peer_address) + : self_address_(self_address), + peer_address_(peer_address), + effective_peer_address_(effective_peer_address) {} + + virtual ~QuicPathValidationContext() = default; + + virtual QuicPacketWriter* WriterToUse() = 0; + + const QuicSocketAddress& self_address() const { return self_address_; } + const QuicSocketAddress& peer_address() const { return peer_address_; } + const QuicSocketAddress& effective_peer_address() const { + return effective_peer_address_; + } + + private: + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicPathValidationContext& context); + + QuicSocketAddress self_address_; + // The address to send PATH_CHALLENGE. + QuicSocketAddress peer_address_; + // The actual peer address which is different from |peer_address_| if the peer + // is behind a proxy. + QuicSocketAddress effective_peer_address_; +}; + +// Used to validate a path by sending up to 3 PATH_CHALLENGE frames before +// declaring a path validation failure. +class QUIC_EXPORT_PRIVATE QuicPathValidator { + public: + static const uint16_t kMaxRetryTimes = 2; + + // Used to write PATH_CHALLENGE on the path to be validated and to get retry + // timeout. + class QUIC_EXPORT_PRIVATE SendDelegate { + public: + virtual ~SendDelegate() = default; + + // Send a PATH_CHALLENGE with |data_buffer| as the frame payload using given + // path information. Return false if the delegate doesn't want to continue + // the validation. + virtual bool SendPathChallenge( + const QuicPathFrameBuffer& data_buffer, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& effective_peer_address, + QuicPacketWriter* writer) = 0; + // Return the time to retry sending PATH_CHALLENGE again based on given peer + // address and writer. + virtual QuicTime GetRetryTimeout(const QuicSocketAddress& peer_address, + QuicPacketWriter* writer) const = 0; + }; + + // Handles the validation result. + // TODO(danzh) consider to simplify this interface and its life time to + // outlive a validation. + class QUIC_EXPORT_PRIVATE ResultDelegate { + public: + virtual ~ResultDelegate() = default; + + // Called when a PATH_RESPONSE is received with a matching PATH_CHALLANGE. + // |start_time| is the time when the matching PATH_CHALLANGE was sent. + virtual void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime start_time) = 0; + + virtual void OnPathValidationFailure( + std::unique_ptr context) = 0; + }; + + QuicPathValidator(QuicAlarmFactory* alarm_factory, QuicConnectionArena* arena, + SendDelegate* delegate, QuicRandom* random, + const QuicClock* clock, QuicConnectionContext* context); + + // Send PATH_CHALLENGE and start the retry timer. + void StartPathValidation(std::unique_ptr context, + std::unique_ptr result_delegate, + PathValidationReason reason); + + // Called when a PATH_RESPONSE frame has been received. Matches the received + // PATH_RESPONSE payload with the payloads previously sent in PATH_CHALLANGE + // frames and the self address on which it was sent. + void OnPathResponse(const QuicPathFrameBuffer& probing_data, + QuicSocketAddress self_address); + + // Cancel the retry timer and reset the path and result delegate. + void CancelPathValidation(); + + bool HasPendingPathValidation() const; + + QuicPathValidationContext* GetContext() const; + + // Pass the ownership of path_validation context to the caller and reset the + // validator. + std::unique_ptr ReleaseContext(); + + PathValidationReason GetPathValidationReason() const { return reason_; } + + // Send another PATH_CHALLENGE on the same path. After retrying + // |kMaxRetryTimes| times, fail the current path validation. + void OnRetryTimeout(); + + bool IsValidatingPeerAddress(const QuicSocketAddress& effective_peer_address); + + // Called to send packet to |peer_address| if the path validation to this + // address is pending. + void MaybeWritePacketToAddress(const char* buffer, size_t buf_len, + const QuicSocketAddress& peer_address); + + private: + friend class test::QuicPathValidatorPeer; + + // Return the payload to be used in the next PATH_CHALLENGE frame. + const QuicPathFrameBuffer& GeneratePathChallengePayload(); + + void SendPathChallengeAndSetAlarm(); + + void ResetPathValidation(); + + struct QUIC_NO_EXPORT ProbingData { + explicit ProbingData(QuicTime send_time) : send_time(send_time) {} + QuicPathFrameBuffer frame_buffer; + QuicTime send_time; + }; + + // Has at most 3 entries due to validation timeout. + absl::InlinedVector probing_data_; + SendDelegate* send_delegate_; + QuicRandom* random_; + const QuicClock* clock_; + std::unique_ptr path_context_; + std::unique_ptr result_delegate_; + QuicArenaScopedPtr retry_timer_; + size_t retry_count_; + PathValidationReason reason_ = PathValidationReason::kReasonUnknown; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PATH_VALIDATOR_H_ diff --git a/quiche/quic/core/quic_path_validator_test.cc b/quiche/quic/core/quic_path_validator_test.cc new file mode 100644 index 000000000000..6d0be9e829f3 --- /dev/null +++ b/quiche/quic/core/quic_path_validator_test.cc @@ -0,0 +1,276 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_path_validator.h" + +#include + +#include "quiche/quic/core/frames/quic_path_challenge_frame.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_path_validator_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::Invoke; +using testing::Return; + +namespace quic { +namespace test { + +class MockSendDelegate : public QuicPathValidator::SendDelegate { + public: + // Send a PATH_CHALLENGE frame using given path information and populate + // |data_buffer| with the frame payload. Return true if the validator should + // move forward in validation, i.e. arm the retry timer. + MOCK_METHOD(bool, SendPathChallenge, + (const QuicPathFrameBuffer&, const QuicSocketAddress&, + const QuicSocketAddress&, const QuicSocketAddress&, + QuicPacketWriter*), + (override)); + + MOCK_METHOD(QuicTime, GetRetryTimeout, + (const QuicSocketAddress&, QuicPacketWriter*), (const, override)); +}; + +class QuicPathValidatorTest : public QuicTest { + public: + QuicPathValidatorTest() + : path_validator_(&alarm_factory_, &arena_, &send_delegate_, &random_, + &clock_, + /*context=*/nullptr), + context_(new MockQuicPathValidationContext( + self_address_, peer_address_, effective_peer_address_, &writer_)), + result_delegate_( + new testing::StrictMock()) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + ON_CALL(send_delegate_, GetRetryTimeout(_, _)) + .WillByDefault( + Return(clock_.ApproximateNow() + + 3 * QuicTime::Delta::FromMilliseconds(kInitialRttMs))); + } + + protected: + quic::test::MockAlarmFactory alarm_factory_; + MockSendDelegate send_delegate_; + MockRandom random_; + MockClock clock_; + QuicConnectionArena arena_; + QuicPathValidator path_validator_; + QuicSocketAddress self_address_{QuicIpAddress::Any4(), 443}; + QuicSocketAddress peer_address_{QuicIpAddress::Loopback4(), 443}; + QuicSocketAddress effective_peer_address_{QuicIpAddress::Loopback4(), 12345}; + MockPacketWriter writer_; + MockQuicPathValidationContext* context_; + MockQuicPathValidationResultDelegate* result_delegate_; +}; + +TEST_F(QuicPathValidatorTest, PathValidationSuccessOnFirstRound) { + QuicPathFrameBuffer challenge_data; + EXPECT_CALL(send_delegate_, + SendPathChallenge(_, self_address_, peer_address_, + effective_peer_address_, &writer_)) + .WillOnce(Invoke([&](const QuicPathFrameBuffer& payload, + const QuicSocketAddress&, const QuicSocketAddress&, + const QuicSocketAddress&, QuicPacketWriter*) { + memcpy(challenge_data.data(), payload.data(), payload.size()); + return true; + })); + EXPECT_CALL(send_delegate_, GetRetryTimeout(peer_address_, &writer_)); + const QuicTime expected_start_time = clock_.Now(); + path_validator_.StartPathValidation( + std::unique_ptr(context_), + std::unique_ptr(result_delegate_), + PathValidationReason::kMultiPort); + EXPECT_TRUE(path_validator_.HasPendingPathValidation()); + EXPECT_EQ(PathValidationReason::kMultiPort, + path_validator_.GetPathValidationReason()); + EXPECT_TRUE(path_validator_.IsValidatingPeerAddress(effective_peer_address_)); + EXPECT_CALL(*result_delegate_, OnPathValidationSuccess(_, _)) + .WillOnce(Invoke([=](std::unique_ptr context, + QuicTime start_time) { + EXPECT_EQ(context.get(), context_); + EXPECT_EQ(start_time, expected_start_time); + })); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(kInitialRttMs)); + path_validator_.OnPathResponse(challenge_data, self_address_); + EXPECT_FALSE(path_validator_.HasPendingPathValidation()); + EXPECT_EQ(PathValidationReason::kReasonUnknown, + path_validator_.GetPathValidationReason()); +} + +TEST_F(QuicPathValidatorTest, RespondWithDifferentSelfAddress) { + QuicPathFrameBuffer challenge_data; + EXPECT_CALL(send_delegate_, + SendPathChallenge(_, self_address_, peer_address_, + effective_peer_address_, &writer_)) + .WillOnce(Invoke([&](const QuicPathFrameBuffer payload, + const QuicSocketAddress&, const QuicSocketAddress&, + const QuicSocketAddress&, QuicPacketWriter*) { + memcpy(challenge_data.data(), payload.data(), payload.size()); + return true; + })); + EXPECT_CALL(send_delegate_, GetRetryTimeout(peer_address_, &writer_)); + const QuicTime expected_start_time = clock_.Now(); + path_validator_.StartPathValidation( + std::unique_ptr(context_), + std::unique_ptr(result_delegate_), + PathValidationReason::kMultiPort); + + // Reception of a PATH_RESPONSE on a different self address should be ignored. + const QuicSocketAddress kAlternativeSelfAddress(QuicIpAddress::Any6(), 54321); + EXPECT_NE(kAlternativeSelfAddress, self_address_); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(kInitialRttMs)); + path_validator_.OnPathResponse(challenge_data, kAlternativeSelfAddress); + + EXPECT_CALL(*result_delegate_, OnPathValidationSuccess(_, _)) + .WillOnce(Invoke([=](std::unique_ptr context, + QuicTime start_time) { + EXPECT_EQ(context->self_address(), self_address_); + EXPECT_EQ(start_time, expected_start_time); + })); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(kInitialRttMs)); + path_validator_.OnPathResponse(challenge_data, self_address_); + EXPECT_EQ(PathValidationReason::kReasonUnknown, + path_validator_.GetPathValidationReason()); +} + +TEST_F(QuicPathValidatorTest, RespondAfter1stRetry) { + QuicPathFrameBuffer challenge_data; + EXPECT_CALL(send_delegate_, + SendPathChallenge(_, self_address_, peer_address_, + effective_peer_address_, &writer_)) + .WillOnce(Invoke([&](const QuicPathFrameBuffer& payload, + const QuicSocketAddress&, const QuicSocketAddress&, + const QuicSocketAddress&, QuicPacketWriter*) { + // Store up the 1st PATH_CHALLANGE payload. + memcpy(challenge_data.data(), payload.data(), payload.size()); + return true; + })) + .WillOnce(Invoke([&](const QuicPathFrameBuffer& payload, + const QuicSocketAddress&, const QuicSocketAddress&, + const QuicSocketAddress&, QuicPacketWriter*) { + EXPECT_NE(payload, challenge_data); + return true; + })); + EXPECT_CALL(send_delegate_, GetRetryTimeout(peer_address_, &writer_)) + .Times(2u); + const QuicTime start_time = clock_.Now(); + path_validator_.StartPathValidation( + std::unique_ptr(context_), + std::unique_ptr(result_delegate_), + PathValidationReason::kMultiPort); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + random_.ChangeValue(); + alarm_factory_.FireAlarm( + QuicPathValidatorPeer::retry_timer(&path_validator_)); + + EXPECT_CALL(*result_delegate_, OnPathValidationSuccess(_, start_time)); + // Respond to the 1st PATH_CHALLENGE should complete the validation. + path_validator_.OnPathResponse(challenge_data, self_address_); + EXPECT_FALSE(path_validator_.HasPendingPathValidation()); +} + +TEST_F(QuicPathValidatorTest, RespondToRetryChallenge) { + QuicPathFrameBuffer challenge_data; + EXPECT_CALL(send_delegate_, + SendPathChallenge(_, self_address_, peer_address_, + effective_peer_address_, &writer_)) + .WillOnce(Invoke([&](const QuicPathFrameBuffer& payload, + const QuicSocketAddress&, const QuicSocketAddress&, + const QuicSocketAddress&, QuicPacketWriter*) { + memcpy(challenge_data.data(), payload.data(), payload.size()); + return true; + })) + .WillOnce(Invoke([&](const QuicPathFrameBuffer& payload, + const QuicSocketAddress&, const QuicSocketAddress&, + const QuicSocketAddress&, QuicPacketWriter*) { + EXPECT_NE(challenge_data, payload); + memcpy(challenge_data.data(), payload.data(), payload.size()); + return true; + })); + EXPECT_CALL(send_delegate_, GetRetryTimeout(peer_address_, &writer_)) + .Times(2u); + path_validator_.StartPathValidation( + std::unique_ptr(context_), + std::unique_ptr(result_delegate_), + PathValidationReason::kMultiPort); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + const QuicTime start_time = clock_.Now(); + random_.ChangeValue(); + alarm_factory_.FireAlarm( + QuicPathValidatorPeer::retry_timer(&path_validator_)); + + // Respond to the 2nd PATH_CHALLENGE should complete the validation. + EXPECT_CALL(*result_delegate_, OnPathValidationSuccess(_, start_time)); + path_validator_.OnPathResponse(challenge_data, self_address_); + EXPECT_FALSE(path_validator_.HasPendingPathValidation()); +} + +TEST_F(QuicPathValidatorTest, ValidationTimeOut) { + EXPECT_CALL(send_delegate_, + SendPathChallenge(_, self_address_, peer_address_, + effective_peer_address_, &writer_)) + .Times(3u) + .WillRepeatedly(Return(true)); + EXPECT_CALL(send_delegate_, GetRetryTimeout(peer_address_, &writer_)) + .Times(3u); + path_validator_.StartPathValidation( + std::unique_ptr(context_), + std::unique_ptr(result_delegate_), + PathValidationReason::kMultiPort); + + QuicPathFrameBuffer challenge_data; + memset(challenge_data.data(), 'a', challenge_data.size()); + // Reception of a PATH_RESPONSE with different payload should be ignored. + path_validator_.OnPathResponse(challenge_data, self_address_); + + // Retry 3 times. The 3rd time should fail the validation. + EXPECT_CALL(*result_delegate_, OnPathValidationFailure(_)) + .WillOnce(Invoke([=](std::unique_ptr context) { + EXPECT_EQ(context_, context.get()); + })); + for (size_t i = 0; i <= QuicPathValidator::kMaxRetryTimes; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); + alarm_factory_.FireAlarm( + QuicPathValidatorPeer::retry_timer(&path_validator_)); + } + EXPECT_EQ(PathValidationReason::kReasonUnknown, + path_validator_.GetPathValidationReason()); +} + +TEST_F(QuicPathValidatorTest, SendPathChallengeError) { + EXPECT_CALL(send_delegate_, + SendPathChallenge(_, self_address_, peer_address_, + effective_peer_address_, &writer_)) + .WillOnce(Invoke([&](const QuicPathFrameBuffer&, const QuicSocketAddress&, + const QuicSocketAddress&, const QuicSocketAddress&, + QuicPacketWriter*) { + // Abandon this validation in the call stack shouldn't cause crash and + // should cancel the alarm. + path_validator_.CancelPathValidation(); + return false; + })); + EXPECT_CALL(send_delegate_, GetRetryTimeout(peer_address_, &writer_)) + .Times(0u); + EXPECT_CALL(*result_delegate_, OnPathValidationFailure(_)); + path_validator_.StartPathValidation( + std::unique_ptr(context_), + std::unique_ptr(result_delegate_), + PathValidationReason::kMultiPort); + EXPECT_FALSE(path_validator_.HasPendingPathValidation()); + EXPECT_FALSE(QuicPathValidatorPeer::retry_timer(&path_validator_)->IsSet()); + EXPECT_EQ(PathValidationReason::kReasonUnknown, + path_validator_.GetPathValidationReason()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_ping_manager.cc b/quiche/quic/core/quic_ping_manager.cc new file mode 100644 index 000000000000..1574632b5a3f --- /dev/null +++ b/quiche/quic/core/quic_ping_manager.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_ping_manager.h" + +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace { + +// Maximum shift used to calculate retransmittable on wire timeout. For 200ms +// initial retransmittable on wire delay, this would get a maximum of 200ms * (1 +// << 10) = 204.8s +const int kMaxRetransmittableOnWireDelayShift = 10; + +class AlarmDelegate : public QuicAlarm::DelegateWithContext { + public: + explicit AlarmDelegate(QuicPingManager* manager, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), manager_(manager) {} + AlarmDelegate(const AlarmDelegate&) = delete; + AlarmDelegate& operator=(const AlarmDelegate&) = delete; + + void OnAlarm() override { manager_->OnAlarm(); } + + private: + QuicPingManager* manager_; +}; + +} // namespace + +QuicPingManager::QuicPingManager(Perspective perspective, Delegate* delegate, + QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, + QuicConnectionContext* context) + : perspective_(perspective), + delegate_(delegate), + alarm_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)) {} + +void QuicPingManager::SetAlarm(QuicTime now, bool should_keep_alive, + bool has_in_flight_packets) { + UpdateDeadlines(now, should_keep_alive, has_in_flight_packets); + const QuicTime earliest_deadline = GetEarliestDeadline(); + if (!earliest_deadline.IsInitialized()) { + alarm_->Cancel(); + return; + } + if (earliest_deadline == keep_alive_deadline_) { + // Use 1s granularity for keep-alive time. + alarm_->Update(earliest_deadline, QuicTime::Delta::FromSeconds(1)); + return; + } + alarm_->Update(earliest_deadline, kAlarmGranularity); +} + +void QuicPingManager::OnAlarm() { + const QuicTime earliest_deadline = GetEarliestDeadline(); + if (!earliest_deadline.IsInitialized()) { + QUIC_BUG(quic_ping_manager_alarm_fires_unexpectedly) + << "QuicPingManager alarm fires unexpectedly."; + return; + } + // Please note, alarm does not get re-armed here, and we are relying on caller + // to SetAlarm later. + if (earliest_deadline == retransmittable_on_wire_deadline_) { + retransmittable_on_wire_deadline_ = QuicTime::Zero(); + if (GetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count) != + 0) { + ++consecutive_retransmittable_on_wire_count_; + } + ++retransmittable_on_wire_count_; + delegate_->OnRetransmittableOnWireTimeout(); + return; + } + if (earliest_deadline == keep_alive_deadline_) { + keep_alive_deadline_ = QuicTime::Zero(); + delegate_->OnKeepAliveTimeout(); + } +} + +void QuicPingManager::Stop() { + alarm_->PermanentCancel(); + retransmittable_on_wire_deadline_ = QuicTime::Zero(); + keep_alive_deadline_ = QuicTime::Zero(); +} + +void QuicPingManager::UpdateDeadlines(QuicTime now, bool should_keep_alive, + bool has_in_flight_packets) { + // Reset keep-alive deadline given it will be set later (with left edge + // |now|). + keep_alive_deadline_ = QuicTime::Zero(); + if (perspective_ == Perspective::IS_SERVER && + initial_retransmittable_on_wire_timeout_.IsInfinite()) { + // The PING alarm exists to support two features: + // 1) clients send PINGs every 15s to prevent NAT timeouts, + // 2) both clients and servers can send retransmittable on the wire PINGs + // (ROWP) while ShouldKeepConnectionAlive is true and there is no packets in + // flight. + QUICHE_DCHECK(!retransmittable_on_wire_deadline_.IsInitialized()); + return; + } + if (!should_keep_alive) { + // Don't send a ping unless the application (ie: HTTP/3) says to, usually + // because it is expecting a response from the peer. + retransmittable_on_wire_deadline_ = QuicTime::Zero(); + return; + } + if (perspective_ == Perspective::IS_CLIENT) { + // Clients send 15s PINGs to avoid NATs from timing out. + keep_alive_deadline_ = now + keep_alive_timeout_; + } + if (initial_retransmittable_on_wire_timeout_.IsInfinite() || + has_in_flight_packets || + retransmittable_on_wire_count_ > + GetQuicFlag(quic_max_retransmittable_on_wire_ping_count)) { + // No need to set retransmittable-on-wire timeout. + retransmittable_on_wire_deadline_ = QuicTime::Zero(); + return; + } + + QUICHE_DCHECK_LT(initial_retransmittable_on_wire_timeout_, + keep_alive_timeout_); + QuicTime::Delta retransmittable_on_wire_timeout = + initial_retransmittable_on_wire_timeout_; + const int max_aggressive_retransmittable_on_wire_count = + GetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count); + QUICHE_DCHECK_LE(0, max_aggressive_retransmittable_on_wire_count); + if (consecutive_retransmittable_on_wire_count_ > + max_aggressive_retransmittable_on_wire_count) { + // Exponentially back off the timeout if the number of consecutive + // retransmittable on wire pings has exceeds the allowance. + int shift = std::min(consecutive_retransmittable_on_wire_count_ - + max_aggressive_retransmittable_on_wire_count, + kMaxRetransmittableOnWireDelayShift); + retransmittable_on_wire_timeout = + initial_retransmittable_on_wire_timeout_ * (1 << shift); + } + if (retransmittable_on_wire_deadline_.IsInitialized() && + retransmittable_on_wire_deadline_ < + now + retransmittable_on_wire_timeout) { + // Alarm is set to an earlier time. Do not postpone it. + return; + } + retransmittable_on_wire_deadline_ = now + retransmittable_on_wire_timeout; +} + +QuicTime QuicPingManager::GetEarliestDeadline() const { + QuicTime earliest_deadline = QuicTime::Zero(); + for (QuicTime t : {retransmittable_on_wire_deadline_, keep_alive_deadline_}) { + if (!t.IsInitialized()) { + continue; + } + if (!earliest_deadline.IsInitialized() || t < earliest_deadline) { + earliest_deadline = t; + } + } + return earliest_deadline; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_ping_manager.h b/quiche/quic/core/quic_ping_manager.h new file mode 100644 index 000000000000..d88dac26f369 --- /dev/null +++ b/quiche/quic/core/quic_ping_manager.h @@ -0,0 +1,108 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PING_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_PING_MANAGER_H_ + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicConnectionPeer; +class QuicPingManagerPeer; +} // namespace test + +// QuicPingManager manages an alarm that has two modes: +// 1) keep-alive. When alarm fires, send packet to extend idle timeout to keep +// connection alive. +// 2) retransmittable-on-wire. When alarm fires, send packets to detect path +// degrading (used in IP/port migrations). +class QUIC_EXPORT_PRIVATE QuicPingManager { + public: + // Interface that get notified when |alarm_| fires. + class QUIC_EXPORT_PRIVATE Delegate { + public: + virtual ~Delegate() {} + + // Called when alarm fires in keep-alive mode. + virtual void OnKeepAliveTimeout() = 0; + // Called when alarm fires in retransmittable-on-wire mode. + virtual void OnRetransmittableOnWireTimeout() = 0; + }; + + QuicPingManager(Perspective perspective, Delegate* delegate, + QuicConnectionArena* arena, QuicAlarmFactory* alarm_factory, + QuicConnectionContext* context); + + // Called to set |alarm_|. + void SetAlarm(QuicTime now, bool should_keep_alive, + bool has_in_flight_packets); + + // Called when |alarm_| fires. + void OnAlarm(); + + // Called to stop |alarm_| permanently. + void Stop(); + + void set_keep_alive_timeout(QuicTime::Delta keep_alive_timeout) { + QUICHE_DCHECK(!alarm_->IsSet()); + keep_alive_timeout_ = keep_alive_timeout; + } + + void set_initial_retransmittable_on_wire_timeout( + QuicTime::Delta retransmittable_on_wire_timeout) { + QUICHE_DCHECK(!alarm_->IsSet()); + initial_retransmittable_on_wire_timeout_ = retransmittable_on_wire_timeout; + } + + void reset_consecutive_retransmittable_on_wire_count() { + consecutive_retransmittable_on_wire_count_ = 0; + } + + private: + friend class test::QuicConnectionPeer; + friend class test::QuicPingManagerPeer; + + // Update |retransmittable_on_wire_deadline_| and |keep_alive_deadline_|. + void UpdateDeadlines(QuicTime now, bool should_keep_alive, + bool has_in_flight_packets); + + // Get earliest deadline of |retransmittable_on_wire_deadline_| and + // |keep_alive_deadline_|. Returns 0 if both deadlines are not initialized. + QuicTime GetEarliestDeadline() const; + + Perspective perspective_; + + Delegate* delegate_; // Not owned. + + // Initial timeout for how long the wire can have no retransmittable packets. + QuicTime::Delta initial_retransmittable_on_wire_timeout_ = + QuicTime::Delta::Infinite(); + + // Indicates how many consecutive retransmittable-on-wire has been armed + // (since last reset). + int consecutive_retransmittable_on_wire_count_ = 0; + + // Indicates how many retransmittable-on-wire has been armed in total. + int retransmittable_on_wire_count_ = 0; + + QuicTime::Delta keep_alive_timeout_ = + QuicTime::Delta::FromSeconds(kPingTimeoutSecs); + + QuicTime retransmittable_on_wire_deadline_ = QuicTime::Zero(); + + QuicTime keep_alive_deadline_ = QuicTime::Zero(); + + QuicArenaScopedPtr alarm_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PING_MANAGER_H_ diff --git a/quiche/quic/core/quic_ping_manager_test.cc b/quiche/quic/core/quic_ping_manager_test.cc new file mode 100644 index 000000000000..d9acc7aa48d0 --- /dev/null +++ b/quiche/quic/core/quic_ping_manager_test.cc @@ -0,0 +1,429 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_ping_manager.h" + +#include "quiche/quic/core/quic_one_block_arena.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +class QuicPingManagerPeer { + public: + static QuicAlarm* GetAlarm(QuicPingManager* manager) { + return manager->alarm_.get(); + } + + static void SetPerspective(QuicPingManager* manager, + Perspective perspective) { + manager->perspective_ = perspective; + } +}; + +namespace { + +const bool kShouldKeepAlive = true; +const bool kHasInflightPackets = true; + +class MockDelegate : public QuicPingManager::Delegate { + public: + MOCK_METHOD(void, OnKeepAliveTimeout, (), (override)); + MOCK_METHOD(void, OnRetransmittableOnWireTimeout, (), (override)); +}; + +class QuicPingManagerTest : public QuicTest { + public: + QuicPingManagerTest() + : manager_(Perspective::IS_CLIENT, &delegate_, &arena_, &alarm_factory_, + /*context=*/nullptr), + alarm_(static_cast( + QuicPingManagerPeer::GetAlarm(&manager_))) { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + + protected: + testing::StrictMock delegate_; + MockClock clock_; + QuicConnectionArena arena_; + MockAlarmFactory alarm_factory_; + QuicPingManager manager_; + MockAlarmFactory::TestAlarm* alarm_; +}; + +TEST_F(QuicPingManagerTest, KeepAliveTimeout) { + EXPECT_FALSE(alarm_->IsSet()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Set alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Reset alarm with no in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // Verify the deadline is set slightly less than 15 seconds in the future, + // because of the 1s alarm granularity. + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs) - + QuicTime::Delta::FromMilliseconds(5), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(kPingTimeoutSecs)); + EXPECT_CALL(delegate_, OnKeepAliveTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); + // Reset alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + + // Verify alarm is not armed if !kShouldKeepAlive. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + manager_.SetAlarm(clock_.ApproximateNow(), !kShouldKeepAlive, + kHasInflightPackets); + EXPECT_FALSE(alarm_->IsSet()); +} + +TEST_F(QuicPingManagerTest, CustomizedKeepAliveTimeout) { + EXPECT_FALSE(alarm_->IsSet()); + + // Set customized keep-alive timeout. + manager_.set_keep_alive_timeout(QuicTime::Delta::FromSeconds(10)); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Set alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(10), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Set alarm with no in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // The deadline is set slightly less than 10 seconds in the future, because + // of the 1s alarm granularity. + EXPECT_EQ( + QuicTime::Delta::FromSeconds(10) - QuicTime::Delta::FromMilliseconds(5), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(10)); + EXPECT_CALL(delegate_, OnKeepAliveTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); + // Reset alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + + // Verify alarm is not armed if !kShouldKeepAlive. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + manager_.SetAlarm(clock_.ApproximateNow(), !kShouldKeepAlive, + kHasInflightPackets); + EXPECT_FALSE(alarm_->IsSet()); +} + +TEST_F(QuicPingManagerTest, RetransmittableOnWireTimeout) { + const QuicTime::Delta kRtransmittableOnWireTimeout = + QuicTime::Delta::FromMilliseconds(50); + manager_.set_initial_retransmittable_on_wire_timeout( + kRtransmittableOnWireTimeout); + + EXPECT_FALSE(alarm_->IsSet()); + + // Set alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + // Verify alarm is in keep-alive mode. + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Set alarm with no in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm is in retransmittable-on-wire mode. + EXPECT_EQ(kRtransmittableOnWireTimeout, + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(kRtransmittableOnWireTimeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); + // Reset alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + // Verify the alarm is in keep-alive mode. + ASSERT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); +} + +TEST_F(QuicPingManagerTest, RetransmittableOnWireTimeoutExponentiallyBackOff) { + const int kMaxAggressiveRetransmittableOnWireCount = 5; + SetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count, + kMaxAggressiveRetransmittableOnWireCount); + const QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + manager_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(alarm_->IsSet()); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + // Verify alarm is in keep-alive mode. + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + + // Verify no exponential backoff on the first few retransmittable on wire + // timeouts. + for (int i = 0; i <= kMaxAggressiveRetransmittableOnWireCount; ++i) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Reset alarm with no in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm is in retransmittable-on-wire mode. + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); + // Reset alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + } + + QuicTime::Delta retransmittable_on_wire_timeout = + initial_retransmittable_on_wire_timeout; + + // Verify subsequent retransmittable-on-wire timeout is exponentially backed + // off. + while (retransmittable_on_wire_timeout * 2 < + QuicTime::Delta::FromSeconds(kPingTimeoutSecs)) { + retransmittable_on_wire_timeout = retransmittable_on_wire_timeout * 2; + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(retransmittable_on_wire_timeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); + // Reset alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + } + + // Verify alarm is in keep-alive mode. + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + // Reset alarm with no in flight packets + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm is in keep-alive mode because retransmittable-on-wire deadline + // is later. + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs) - + QuicTime::Delta::FromMilliseconds(5), + alarm_->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(kPingTimeoutSecs) - + QuicTime::Delta::FromMilliseconds(5)); + EXPECT_CALL(delegate_, OnKeepAliveTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); +} + +TEST_F(QuicPingManagerTest, + ResetRetransmitableOnWireTimeoutExponentiallyBackOff) { + const int kMaxAggressiveRetransmittableOnWireCount = 3; + SetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count, + kMaxAggressiveRetransmittableOnWireCount); + const QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + manager_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_FALSE(alarm_->IsSet()); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + // Verify alarm is in keep-alive mode. + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm is in retransmittable-on-wire mode. + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + alarm_->Fire(); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + + manager_.reset_consecutive_retransmittable_on_wire_count(); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + alarm_->Fire(); + + for (int i = 0; i < kMaxAggressiveRetransmittableOnWireCount; i++) { + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + // Reset alarm with in flight packets. + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + // Advance 5ms to receive next packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + } + + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout * 2, + alarm_->deadline() - clock_.ApproximateNow()); + + clock_.AdvanceTime(2 * initial_retransmittable_on_wire_timeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + manager_.reset_consecutive_retransmittable_on_wire_count(); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); +} + +TEST_F(QuicPingManagerTest, RetransmittableOnWireLimit) { + static constexpr int kMaxRetransmittableOnWirePingCount = 3; + SetQuicFlag(quic_max_retransmittable_on_wire_ping_count, + kMaxRetransmittableOnWirePingCount); + static constexpr QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + static constexpr QuicTime::Delta kShortDelay = + QuicTime::Delta::FromMilliseconds(5); + ASSERT_LT(kShortDelay * 10, initial_retransmittable_on_wire_timeout); + manager_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + clock_.AdvanceTime(kShortDelay); + EXPECT_FALSE(alarm_->IsSet()); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + + for (int i = 0; i <= kMaxRetransmittableOnWirePingCount; i++) { + clock_.AdvanceTime(kShortDelay); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + } + + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + // Verify alarm is in keep-alive mode. + EXPECT_EQ(QuicTime::Delta::FromSeconds(kPingTimeoutSecs), + alarm_->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(kPingTimeoutSecs)); + EXPECT_CALL(delegate_, OnKeepAliveTimeout()); + alarm_->Fire(); + EXPECT_FALSE(alarm_->IsSet()); +} + +TEST_F(QuicPingManagerTest, MaxRetransmittableOnWireDelayShift) { + QuicPingManagerPeer::SetPerspective(&manager_, Perspective::IS_SERVER); + const int kMaxAggressiveRetransmittableOnWireCount = 3; + SetQuicFlag(quic_max_aggressive_retransmittable_on_wire_ping_count, + kMaxAggressiveRetransmittableOnWireCount); + const QuicTime::Delta initial_retransmittable_on_wire_timeout = + QuicTime::Delta::FromMilliseconds(200); + manager_.set_initial_retransmittable_on_wire_timeout( + initial_retransmittable_on_wire_timeout); + + for (int i = 0; i <= kMaxAggressiveRetransmittableOnWireCount; i++) { + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + EXPECT_EQ(initial_retransmittable_on_wire_timeout, + alarm_->deadline() - clock_.ApproximateNow()); + clock_.AdvanceTime(initial_retransmittable_on_wire_timeout); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + kHasInflightPackets); + } + for (int i = 1; i <= 20; ++i) { + manager_.SetAlarm(clock_.ApproximateNow(), kShouldKeepAlive, + !kHasInflightPackets); + EXPECT_TRUE(alarm_->IsSet()); + if (i <= 10) { + EXPECT_EQ(initial_retransmittable_on_wire_timeout * (1 << i), + alarm_->deadline() - clock_.ApproximateNow()); + } else { + // Verify shift is capped. + EXPECT_EQ(initial_retransmittable_on_wire_timeout * (1 << 10), + alarm_->deadline() - clock_.ApproximateNow()); + } + clock_.AdvanceTime(alarm_->deadline() - clock_.ApproximateNow()); + EXPECT_CALL(delegate_, OnRetransmittableOnWireTimeout()); + alarm_->Fire(); + } +} + +} // namespace + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_process_packet_interface.h b/quiche/quic/core/quic_process_packet_interface.h new file mode 100644 index 000000000000..30bd0710e961 --- /dev/null +++ b/quiche/quic/core/quic_process_packet_interface.h @@ -0,0 +1,24 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_PROCESS_PACKET_INTERFACE_H_ +#define QUICHE_QUIC_CORE_QUIC_PROCESS_PACKET_INTERFACE_H_ + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// A class to process each incoming packet. +class QUIC_NO_EXPORT ProcessPacketInterface { + public: + virtual ~ProcessPacketInterface() {} + virtual void ProcessPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_PROCESS_PACKET_INTERFACE_H_ diff --git a/quiche/quic/core/quic_protocol_flags_list.h b/quiche/quic/core/quic_protocol_flags_list.h new file mode 100644 index 000000000000..7644cbfb68fc --- /dev/null +++ b/quiche/quic/core/quic_protocol_flags_list.h @@ -0,0 +1,229 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// NOLINT(build/header_guard) +// This file intentionally does not have header guards, it's intended to be +// included multiple times, each time with a different definition of +// QUIC_PROTOCOL_FLAG. + +#if defined(QUIC_PROTOCOL_FLAG) + +QUIC_PROTOCOL_FLAG( + bool, quic_allow_chlo_buffering, true, + "If true, allows packets to be buffered in anticipation of a " + "future CHLO, and allow CHLO packets to be buffered until next " + "iteration of the event loop.") + +QUIC_PROTOCOL_FLAG(bool, quic_disable_pacing_for_perf_tests, false, + "If true, disable pacing in QUIC") + +// Note that single-packet CHLOs are only enforced for Google QUIC versions that +// do not use CRYPTO frames. This currently means only Q043 and Q046. All other +// versions of QUIC (both Google QUIC and IETF) allow multi-packet CHLOs +// regardless of the value of this flag. +QUIC_PROTOCOL_FLAG(bool, quic_enforce_single_packet_chlo, true, + "If true, enforce that sent QUIC CHLOs fit in one packet. " + "Only applies to Q043 and Q046.") + +// Currently, this number is quite conservative. At a hypothetical 1000 qps, +// this means that the longest time-wait list we should see is: +// 200 seconds * 1000 qps = 200000. +// Of course, there are usually many queries per QUIC connection, so we allow a +// factor of 3 leeway. +QUIC_PROTOCOL_FLAG(int64_t, quic_time_wait_list_max_connections, 600000, + "Maximum number of connections on the time-wait list. " + "A negative value implies no configured limit.") + +QUIC_PROTOCOL_FLAG(int64_t, quic_time_wait_list_seconds, 200, + "Time period for which a given connection_id should live in " + "the time-wait state.") + +// This number is relatively conservative. For example, there are at most 1K +// queued stateless resets, which consume 1K * 21B = 21KB. +QUIC_PROTOCOL_FLAG( + uint64_t, quic_time_wait_list_max_pending_packets, 1024, + "Upper limit of pending packets in time wait list when writer is blocked.") + +// Stop sending a reset if the recorded number of addresses that server has +// recently sent stateless reset to exceeds this limit. +QUIC_PROTOCOL_FLAG(uint64_t, quic_max_recent_stateless_reset_addresses, 1024, + "Max number of recorded recent reset addresses.") + +// After this timeout, recent reset addresses will be cleared. +// FLAGS_quic_max_recent_stateless_reset_addresses * (1000ms / +// FLAGS_quic_recent_stateless_reset_addresses_lifetime_ms) is roughly the max +// reset per second. For example, 1024 * (1000ms / 1000ms) = 1K reset per +// second. +QUIC_PROTOCOL_FLAG( + uint64_t, quic_recent_stateless_reset_addresses_lifetime_ms, 1000, + "Max time that a client address lives in recent reset addresses set.") + +QUIC_PROTOCOL_FLAG(double, quic_bbr_cwnd_gain, 2.0f, + "Congestion window gain for QUIC BBR during PROBE_BW phase.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_buffered_data_threshold, 8 * 1024, + "If buffered data in QUIC stream is less than this " + "threshold, buffers all provided data or asks upper layer for more data") + +QUIC_PROTOCOL_FLAG( + uint64_t, quic_send_buffer_max_data_slice_size, 4 * 1024, + "Max size of data slice in bytes for QUIC stream send buffer.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_lumpy_pacing_size, 2, + "Number of packets that the pacing sender allows in bursts during " + "pacing. This flag is ignored if a flow's estimated bandwidth is " + "lower than 1200 kbps.") + +QUIC_PROTOCOL_FLAG( + double, quic_lumpy_pacing_cwnd_fraction, 0.25f, + "Congestion window fraction that the pacing sender allows in bursts " + "during pacing.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_lumpy_pacing_min_bandwidth_kbps, 1200, + "The minimum estimated client bandwidth below which the pacing sender will " + "not allow bursts.") + +QUIC_PROTOCOL_FLAG(int32_t, quic_max_pace_time_into_future_ms, 10, + "Max time that QUIC can pace packets into the future in ms.") + +QUIC_PROTOCOL_FLAG( + double, quic_pace_time_into_future_srtt_fraction, + 0.125f, // One-eighth smoothed RTT + "Smoothed RTT fraction that a connection can pace packets into the future.") + +QUIC_PROTOCOL_FLAG(bool, quic_export_write_path_stats_at_server, false, + "If true, export detailed write path statistics at server.") + +QUIC_PROTOCOL_FLAG(bool, quic_disable_version_negotiation_grease_randomness, + false, + "If true, use predictable version negotiation versions.") + +QUIC_PROTOCOL_FLAG(bool, quic_enable_http3_grease_randomness, true, + "If true, use random greased settings and frames.") + +QUIC_PROTOCOL_FLAG(int64_t, quic_max_tracked_packet_count, 10000, + "Maximum number of tracked packets.") + +QUIC_PROTOCOL_FLAG( + bool, quic_client_convert_http_header_name_to_lowercase, true, + "If true, HTTP request header names sent from QuicSpdyClientBase(and " + "descendents) will be automatically converted to lower case.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_bbr2_default_probe_bw_base_duration_ms, 2000, + "The default minimum duration for BBRv2-native probes, in milliseconds.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_bbr2_default_probe_bw_max_rand_duration_ms, 1000, + "The default upper bound of the random amount of BBRv2-native " + "probes, in milliseconds.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_bbr2_default_probe_rtt_period_ms, 10000, + "The default period for entering PROBE_RTT, in milliseconds.") + +QUIC_PROTOCOL_FLAG( + double, quic_bbr2_default_loss_threshold, 0.02, + "The default loss threshold for QUIC BBRv2, should be a value " + "between 0 and 1.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_bbr2_default_startup_full_loss_count, 8, + "The default minimum number of loss marking events to exit STARTUP.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_bbr2_default_probe_bw_full_loss_count, 2, + "The default minimum number of loss marking events to exit PROBE_UP phase.") + +QUIC_PROTOCOL_FLAG( + double, quic_bbr2_default_inflight_hi_headroom, 0.15, + "The default fraction of unutilized headroom to try to leave in path " + "upon high loss.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_bbr2_default_initial_ack_height_filter_window, 10, + "The default initial value of the max ack height filter's window length.") + +QUIC_PROTOCOL_FLAG( + double, quic_ack_aggregation_bandwidth_threshold, 1.0, + "If the bandwidth during ack aggregation is smaller than (estimated " + "bandwidth * this flag), consider the current aggregation completed " + "and starts a new one.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_anti_amplification_factor, 3, + "Anti-amplification factor. Before address validation, server will " + "send no more than factor times bytes received.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_max_buffered_crypto_bytes, + 16 * 1024, // 16 KB + "The maximum amount of CRYPTO frame data that can be buffered.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_max_aggressive_retransmittable_on_wire_ping_count, 5, + "Maximum number of consecutive pings that can be sent with the " + "aggressive initial retransmittable on the wire timeout if there is " + "no new stream data received. After this limit, the timeout will be " + "doubled each ping until it exceeds the default ping timeout.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_max_retransmittable_on_wire_ping_count, 1000, + "Maximum number of pings that can be sent with the retransmittable " + "on the wire timeout, over the lifetime of a connection. After this " + "limit, the timeout will be the default ping timeout.") + +QUIC_PROTOCOL_FLAG(int32_t, quic_max_congestion_window, 2000, + "The maximum congestion window in packets.") + +QUIC_PROTOCOL_FLAG( + int32_t, quic_max_streams_window_divisor, 2, + "The divisor that controls how often MAX_STREAMS frame is sent.") + +QUIC_PROTOCOL_FLAG( + uint64_t, quic_key_update_confidentiality_limit, 0, + "If non-zero and key update is allowed, the maximum number of " + "packets sent for each key phase before initiating a key update.") + +QUIC_PROTOCOL_FLAG(bool, quic_disable_client_tls_zero_rtt, false, + "If true, QUIC client with TLS will not try 0-RTT.") + +QUIC_PROTOCOL_FLAG(bool, quic_disable_server_tls_resumption, false, + "If true, QUIC server will disable TLS resumption by not " + "issuing or processing session tickets.") + +QUIC_PROTOCOL_FLAG(bool, quic_defer_send_in_response, true, + "If true, QUIC servers will defer sending in response to " + "incoming packets by default.") + +QUIC_PROTOCOL_FLAG( + bool, quic_header_size_limit_includes_overhead, true, + "If true, QUIC QPACK decoder includes 32-bytes overheader per entry while " + "comparing request/response header size against its upper limit.") + +QUIC_PROTOCOL_FLAG( + bool, quic_reject_retry_token_in_initial_packet, false, + "If true, always reject retry_token received in INITIAL packets") + +QUIC_PROTOCOL_FLAG(bool, quic_use_lower_server_response_mtu_for_test, false, + "If true, cap server response packet size at 1250.") + +QUIC_PROTOCOL_FLAG(bool, quic_enforce_strict_amplification_factor, false, + "If true, enforce strict amplification factor") + +QUIC_PROTOCOL_FLAG(bool, quic_bounded_crypto_send_buffer, false, + "If true, close the connection if a crypto send buffer " + "exceeds its size limit.") + +QUIC_PROTOCOL_FLAG(bool, quic_interval_set_enable_add_optimization, true, + "If true, enable an optimization in QuicIntervalSet") + +QUIC_PROTOCOL_FLAG( + bool, quic_enable_chaos_protection, true, + "If true, use chaos protection to randomize client initials.") + +#endif diff --git a/quiche/quic/core/quic_received_packet_manager.cc b/quiche/quic/core/quic_received_packet_manager.cc new file mode 100644 index 000000000000..048661821a03 --- /dev/null +++ b/quiche/quic/core/quic_received_packet_manager.cc @@ -0,0 +1,362 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_received_packet_manager.h" + +#include +#include +#include + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// The maximum number of packets to ack immediately after a missing packet for +// fast retransmission to kick in at the sender. This limit is created to +// reduce the number of acks sent that have no benefit for fast retransmission. +// Set to the number of nacks needed for fast retransmit plus one for protection +// against an ack loss +const size_t kMaxPacketsAfterNewMissing = 4; + +// One eighth RTT delay when doing ack decimation. +const float kShortAckDecimationDelay = 0.125; +} // namespace + +QuicReceivedPacketManager::QuicReceivedPacketManager() + : QuicReceivedPacketManager(nullptr) {} + +QuicReceivedPacketManager::QuicReceivedPacketManager(QuicConnectionStats* stats) + : ack_frame_updated_(false), + max_ack_ranges_(0), + time_largest_observed_(QuicTime::Zero()), + save_timestamps_(false), + save_timestamps_for_in_order_packets_(false), + stats_(stats), + num_retransmittable_packets_received_since_last_ack_sent_(0), + min_received_before_ack_decimation_(kMinReceivedBeforeAckDecimation), + ack_frequency_(kDefaultRetransmittablePacketsBeforeAck), + ack_decimation_delay_(kAckDecimationDelay), + unlimited_ack_decimation_(false), + one_immediate_ack_(false), + ignore_order_(false), + local_max_ack_delay_( + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs)), + ack_timeout_(QuicTime::Zero()), + time_of_previous_received_packet_(QuicTime::Zero()), + was_last_packet_missing_(false), + last_ack_frequency_frame_sequence_number_(-1) {} + +QuicReceivedPacketManager::~QuicReceivedPacketManager() {} + +void QuicReceivedPacketManager::SetFromConfig(const QuicConfig& config, + Perspective perspective) { + if (config.HasClientSentConnectionOption(kAKD3, perspective)) { + ack_decimation_delay_ = kShortAckDecimationDelay; + } + if (config.HasClientSentConnectionOption(kAKDU, perspective)) { + unlimited_ack_decimation_ = true; + } + if (config.HasClientSentConnectionOption(k1ACK, perspective)) { + one_immediate_ack_ = true; + } +} + +void QuicReceivedPacketManager::RecordPacketReceived( + const QuicPacketHeader& header, QuicTime receipt_time, + const QuicEcnCodepoint ecn) { + const QuicPacketNumber packet_number = header.packet_number; + QUICHE_DCHECK(IsAwaitingPacket(packet_number)) + << " packet_number:" << packet_number; + was_last_packet_missing_ = IsMissing(packet_number); + if (!ack_frame_updated_) { + ack_frame_.received_packet_times.clear(); + } + ack_frame_updated_ = true; + + // Whether |packet_number| is received out of order. + bool packet_reordered = false; + if (LargestAcked(ack_frame_).IsInitialized() && + LargestAcked(ack_frame_) > packet_number) { + // Record how out of order stats. + packet_reordered = true; + ++stats_->packets_reordered; + stats_->max_sequence_reordering = + std::max(stats_->max_sequence_reordering, + LargestAcked(ack_frame_) - packet_number); + int64_t reordering_time_us = + (receipt_time - time_largest_observed_).ToMicroseconds(); + stats_->max_time_reordering_us = + std::max(stats_->max_time_reordering_us, reordering_time_us); + } + if (!LargestAcked(ack_frame_).IsInitialized() || + packet_number > LargestAcked(ack_frame_)) { + ack_frame_.largest_acked = packet_number; + time_largest_observed_ = receipt_time; + } + ack_frame_.packets.Add(packet_number); + + if (save_timestamps_) { + // The timestamp format only handles packets in time order. + if (save_timestamps_for_in_order_packets_ && packet_reordered) { + QUIC_DLOG(WARNING) << "Not saving receive timestamp for packet " + << packet_number; + } else if (!ack_frame_.received_packet_times.empty() && + ack_frame_.received_packet_times.back().second > receipt_time) { + QUIC_LOG(WARNING) + << "Receive time went backwards from: " + << ack_frame_.received_packet_times.back().second.ToDebuggingValue() + << " to " << receipt_time.ToDebuggingValue(); + } else { + ack_frame_.received_packet_times.push_back( + std::make_pair(packet_number, receipt_time)); + } + } + + if (GetQuicRestartFlag(quic_receive_ecn) && ecn != ECN_NOT_ECT) { + QUIC_RESTART_FLAG_COUNT_N(quic_receive_ecn, 1, 3); + if (!ack_frame_.ecn_counters.has_value()) { + ack_frame_.ecn_counters = QuicEcnCounts(); + } + switch (ecn) { + case ECN_NOT_ECT: + QUICHE_NOTREACHED(); + break; // It's impossible to get here, but the compiler complains. + case ECN_ECT0: + ack_frame_.ecn_counters->ect0++; + break; + case ECN_ECT1: + ack_frame_.ecn_counters->ect1++; + break; + case ECN_CE: + ack_frame_.ecn_counters->ce++; + break; + } + } + + if (least_received_packet_number_.IsInitialized()) { + least_received_packet_number_ = + std::min(least_received_packet_number_, packet_number); + } else { + least_received_packet_number_ = packet_number; + } +} + +bool QuicReceivedPacketManager::IsMissing(QuicPacketNumber packet_number) { + return LargestAcked(ack_frame_).IsInitialized() && + packet_number < LargestAcked(ack_frame_) && + !ack_frame_.packets.Contains(packet_number); +} + +bool QuicReceivedPacketManager::IsAwaitingPacket( + QuicPacketNumber packet_number) const { + return quic::IsAwaitingPacket(ack_frame_, packet_number, + peer_least_packet_awaiting_ack_); +} + +const QuicFrame QuicReceivedPacketManager::GetUpdatedAckFrame( + QuicTime approximate_now) { + if (time_largest_observed_ == QuicTime::Zero()) { + // We have received no packets. + ack_frame_.ack_delay_time = QuicTime::Delta::Infinite(); + } else { + // Ensure the delta is zero if approximate now is "in the past". + ack_frame_.ack_delay_time = approximate_now < time_largest_observed_ + ? QuicTime::Delta::Zero() + : approximate_now - time_largest_observed_; + } + while (max_ack_ranges_ > 0 && + ack_frame_.packets.NumIntervals() > max_ack_ranges_) { + ack_frame_.packets.RemoveSmallestInterval(); + } + // Clear all packet times if any are too far from largest observed. + // It's expected this is extremely rare. + for (auto it = ack_frame_.received_packet_times.begin(); + it != ack_frame_.received_packet_times.end();) { + if (LargestAcked(ack_frame_) - it->first >= + std::numeric_limits::max()) { + it = ack_frame_.received_packet_times.erase(it); + } else { + ++it; + } + } + +#if QUIC_FRAME_DEBUG + QuicFrame frame = QuicFrame(&ack_frame_); + frame.delete_forbidden = true; + return frame; +#else // QUIC_FRAME_DEBUG + return QuicFrame(&ack_frame_); +#endif // QUIC_FRAME_DEBUG +} + +void QuicReceivedPacketManager::DontWaitForPacketsBefore( + QuicPacketNumber least_unacked) { + if (!least_unacked.IsInitialized()) { + return; + } + // ValidateAck() should fail if peer_least_packet_awaiting_ack shrinks. + QUICHE_DCHECK(!peer_least_packet_awaiting_ack_.IsInitialized() || + peer_least_packet_awaiting_ack_ <= least_unacked); + if (!peer_least_packet_awaiting_ack_.IsInitialized() || + least_unacked > peer_least_packet_awaiting_ack_) { + peer_least_packet_awaiting_ack_ = least_unacked; + bool packets_updated = ack_frame_.packets.RemoveUpTo(least_unacked); + if (packets_updated) { + // Ack frame gets updated because packets set is updated because of stop + // waiting frame. + ack_frame_updated_ = true; + } + } + QUICHE_DCHECK(ack_frame_.packets.Empty() || + !peer_least_packet_awaiting_ack_.IsInitialized() || + ack_frame_.packets.Min() >= peer_least_packet_awaiting_ack_); +} + +QuicTime::Delta QuicReceivedPacketManager::GetMaxAckDelay( + QuicPacketNumber last_received_packet_number, + const RttStats& rtt_stats) const { + if (AckFrequencyFrameReceived() || + last_received_packet_number < PeerFirstSendingPacketNumber() + + min_received_before_ack_decimation_) { + return local_max_ack_delay_; + } + + // Wait for the minimum of the ack decimation delay or the delayed ack time + // before sending an ack. + QuicTime::Delta ack_delay = std::min( + local_max_ack_delay_, rtt_stats.min_rtt() * ack_decimation_delay_); + return std::max(ack_delay, kAlarmGranularity); +} + +void QuicReceivedPacketManager::MaybeUpdateAckFrequency( + QuicPacketNumber last_received_packet_number) { + if (AckFrequencyFrameReceived()) { + // Skip Ack Decimation below after receiving an AckFrequencyFrame from the + // other end point. + return; + } + if (last_received_packet_number < + PeerFirstSendingPacketNumber() + min_received_before_ack_decimation_) { + return; + } + ack_frequency_ = unlimited_ack_decimation_ + ? std::numeric_limits::max() + : kMaxRetransmittablePacketsBeforeAck; +} + +void QuicReceivedPacketManager::MaybeUpdateAckTimeout( + bool should_last_packet_instigate_acks, + QuicPacketNumber last_received_packet_number, + QuicTime last_packet_receipt_time, QuicTime now, + const RttStats* rtt_stats) { + if (!ack_frame_updated_) { + // ACK frame has not been updated, nothing to do. + return; + } + + if (!ignore_order_ && was_last_packet_missing_ && + last_sent_largest_acked_.IsInitialized() && + last_received_packet_number < last_sent_largest_acked_) { + // Only ack immediately if an ACK frame was sent with a larger largest acked + // than the newly received packet number. + ack_timeout_ = now; + return; + } + + if (!should_last_packet_instigate_acks) { + return; + } + + ++num_retransmittable_packets_received_since_last_ack_sent_; + + MaybeUpdateAckFrequency(last_received_packet_number); + if (num_retransmittable_packets_received_since_last_ack_sent_ >= + ack_frequency_) { + ack_timeout_ = now; + return; + } + + if (!ignore_order_ && HasNewMissingPackets()) { + ack_timeout_ = now; + return; + } + + const QuicTime updated_ack_time = std::max( + now, std::min(last_packet_receipt_time, now) + + GetMaxAckDelay(last_received_packet_number, *rtt_stats)); + if (!ack_timeout_.IsInitialized() || ack_timeout_ > updated_ack_time) { + ack_timeout_ = updated_ack_time; + } +} + +void QuicReceivedPacketManager::ResetAckStates() { + ack_frame_updated_ = false; + ack_timeout_ = QuicTime::Zero(); + num_retransmittable_packets_received_since_last_ack_sent_ = 0; + last_sent_largest_acked_ = LargestAcked(ack_frame_); +} + +bool QuicReceivedPacketManager::HasMissingPackets() const { + if (ack_frame_.packets.Empty()) { + return false; + } + if (ack_frame_.packets.NumIntervals() > 1) { + return true; + } + return peer_least_packet_awaiting_ack_.IsInitialized() && + ack_frame_.packets.Min() > peer_least_packet_awaiting_ack_; +} + +bool QuicReceivedPacketManager::HasNewMissingPackets() const { + if (one_immediate_ack_) { + return HasMissingPackets() && ack_frame_.packets.LastIntervalLength() == 1; + } + return HasMissingPackets() && + ack_frame_.packets.LastIntervalLength() <= kMaxPacketsAfterNewMissing; +} + +bool QuicReceivedPacketManager::ack_frame_updated() const { + return ack_frame_updated_; +} + +QuicPacketNumber QuicReceivedPacketManager::GetLargestObserved() const { + return LargestAcked(ack_frame_); +} + +QuicPacketNumber QuicReceivedPacketManager::PeerFirstSendingPacketNumber() + const { + if (!least_received_packet_number_.IsInitialized()) { + QUIC_BUG(quic_bug_10849_1) << "No packets have been received yet"; + return QuicPacketNumber(1); + } + return least_received_packet_number_; +} + +bool QuicReceivedPacketManager::IsAckFrameEmpty() const { + return ack_frame_.packets.Empty(); +} + +void QuicReceivedPacketManager::OnAckFrequencyFrame( + const QuicAckFrequencyFrame& frame) { + int64_t new_sequence_number = frame.sequence_number; + if (new_sequence_number <= last_ack_frequency_frame_sequence_number_) { + // Ignore old ACK_FREQUENCY frames. + return; + } + last_ack_frequency_frame_sequence_number_ = new_sequence_number; + ack_frequency_ = frame.packet_tolerance; + local_max_ack_delay_ = frame.max_ack_delay; + ignore_order_ = frame.ignore_order; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_received_packet_manager.h b/quiche/quic/core/quic_received_packet_manager.h new file mode 100644 index 000000000000..ab298d424db7 --- /dev/null +++ b/quiche/quic/core/quic_received_packet_manager.h @@ -0,0 +1,221 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_RECEIVED_PACKET_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_RECEIVED_PACKET_MANAGER_H_ + +#include + +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class RttStats; + +namespace test { +class QuicConnectionPeer; +class QuicReceivedPacketManagerPeer; +class UberReceivedPacketManagerPeer; +} // namespace test + +struct QuicConnectionStats; + +// Records all received packets by a connection. +class QUIC_EXPORT_PRIVATE QuicReceivedPacketManager { + public: + QuicReceivedPacketManager(); + explicit QuicReceivedPacketManager(QuicConnectionStats* stats); + QuicReceivedPacketManager(const QuicReceivedPacketManager&) = delete; + QuicReceivedPacketManager& operator=(const QuicReceivedPacketManager&) = + delete; + virtual ~QuicReceivedPacketManager(); + + void SetFromConfig(const QuicConfig& config, Perspective perspective); + + // Updates the internal state concerning which packets have been received. + // header: the packet header. + // timestamp: the arrival time of the packet. + virtual void RecordPacketReceived(const QuicPacketHeader& header, + QuicTime receipt_time, + QuicEcnCodepoint ecn); + + // Checks whether |packet_number| is missing and less than largest observed. + virtual bool IsMissing(QuicPacketNumber packet_number); + + // Checks if we're still waiting for the packet with |packet_number|. + virtual bool IsAwaitingPacket(QuicPacketNumber packet_number) const; + + // Retrieves a frame containing a QuicAckFrame. The ack frame may not be + // changed outside QuicReceivedPacketManager and must be serialized before + // another packet is received, or it will change. + const QuicFrame GetUpdatedAckFrame(QuicTime approximate_now); + + // Deletes all missing packets before least unacked. The connection won't + // process any packets with packet number before |least_unacked| that it + // received after this call. + void DontWaitForPacketsBefore(QuicPacketNumber least_unacked); + + // Called to update ack_timeout_ to the time when an ACK needs to be sent. A + // caller can decide whether and when to send an ACK by retrieving + // ack_timeout_. If ack_timeout_ is not initialized, no ACK needs to be sent. + // Otherwise, ACK needs to be sent by the specified time. + void MaybeUpdateAckTimeout(bool should_last_packet_instigate_acks, + QuicPacketNumber last_received_packet_number, + QuicTime last_packet_receipt_time, QuicTime now, + const RttStats* rtt_stats); + + // Resets ACK related states, called after an ACK is successfully sent. + void ResetAckStates(); + + // Returns true if there are any missing packets. + bool HasMissingPackets() const; + + // Returns true when there are new missing packets to be reported within 3 + // packets of the largest observed. + virtual bool HasNewMissingPackets() const; + + QuicPacketNumber peer_least_packet_awaiting_ack() const { + return peer_least_packet_awaiting_ack_; + } + + virtual bool ack_frame_updated() const; + + QuicPacketNumber GetLargestObserved() const; + + // Returns peer first sending packet number to our best knowledge. Considers + // least_received_packet_number_ as peer first sending packet number. Please + // note, this function should only be called when at least one packet has been + // received. + QuicPacketNumber PeerFirstSendingPacketNumber() const; + + // Returns true if ack frame is empty. + bool IsAckFrameEmpty() const; + + void set_connection_stats(QuicConnectionStats* stats) { stats_ = stats; } + + // For logging purposes. + const QuicAckFrame& ack_frame() const { return ack_frame_; } + + void set_max_ack_ranges(size_t max_ack_ranges) { + max_ack_ranges_ = max_ack_ranges; + } + + void set_save_timestamps(bool save_timestamps, bool in_order_packets_only) { + save_timestamps_ = save_timestamps; + save_timestamps_for_in_order_packets_ = in_order_packets_only; + } + + size_t min_received_before_ack_decimation() const { + return min_received_before_ack_decimation_; + } + void set_min_received_before_ack_decimation(size_t new_value) { + min_received_before_ack_decimation_ = new_value; + } + + void set_ack_frequency(size_t new_value) { + QUICHE_DCHECK_GT(new_value, 0u); + ack_frequency_ = new_value; + } + + void set_local_max_ack_delay(QuicTime::Delta local_max_ack_delay) { + local_max_ack_delay_ = local_max_ack_delay; + } + + QuicTime ack_timeout() const { return ack_timeout_; } + + void OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame); + + private: + friend class test::QuicConnectionPeer; + friend class test::QuicReceivedPacketManagerPeer; + friend class test::UberReceivedPacketManagerPeer; + + // Sets ack_timeout_ to |time| if ack_timeout_ is not initialized or > time. + void MaybeUpdateAckTimeoutTo(QuicTime time); + + // Maybe update ack_frequency_ when condition meets. + void MaybeUpdateAckFrequency(QuicPacketNumber last_received_packet_number); + + QuicTime::Delta GetMaxAckDelay(QuicPacketNumber last_received_packet_number, + const RttStats& rtt_stats) const; + + bool AckFrequencyFrameReceived() const { + return last_ack_frequency_frame_sequence_number_ >= 0; + } + + // Least packet number of the the packet sent by the peer for which it + // hasn't received an ack. + QuicPacketNumber peer_least_packet_awaiting_ack_; + + // Received packet information used to produce acks. + QuicAckFrame ack_frame_; + + // True if |ack_frame_| has been updated since UpdateReceivedPacketInfo was + // last called. + bool ack_frame_updated_; + + // Maximum number of ack ranges allowed to be stored in the ack frame. + size_t max_ack_ranges_; + + // The time we received the largest_observed packet number, or zero if + // no packet numbers have been received since UpdateReceivedPacketInfo. + // Needed for calculating ack_delay_time. + QuicTime time_largest_observed_; + + // If true, save timestamps in the ack_frame_. + bool save_timestamps_; + + // If true and |save_timestamps_|, only save timestamps for packets that are + // received in order. + bool save_timestamps_for_in_order_packets_; + + // Least packet number received from peer. + QuicPacketNumber least_received_packet_number_; + + QuicConnectionStats* stats_; + + // How many retransmittable packets have arrived without sending an ack. + QuicPacketCount num_retransmittable_packets_received_since_last_ack_sent_; + // Ack decimation will start happening after this many packets are received. + size_t min_received_before_ack_decimation_; + // Ack every n-th packet. + size_t ack_frequency_; + // The max delay in fraction of min_rtt to use when sending decimated acks. + float ack_decimation_delay_; + // When true, removes ack decimation's max number of packets(10) before + // sending an ack. + bool unlimited_ack_decimation_; + // When true, only send 1 immediate ACK when reordering is detected. + bool one_immediate_ack_; + // When true, do not ack immediately upon observation of packet reordering. + bool ignore_order_; + + // The local node's maximum ack delay time. This is the maximum amount of + // time to wait before sending an acknowledgement. + QuicTime::Delta local_max_ack_delay_; + // Time that an ACK needs to be sent. 0 means no ACK is pending. Used when + // decide_when_to_send_acks_ is true. + QuicTime ack_timeout_; + + // The time the previous ack-instigating packet was received and processed. + QuicTime time_of_previous_received_packet_; + // Whether the most recent packet was missing before it was received. + bool was_last_packet_missing_; + + // Last sent largest acked, which gets updated when ACK was successfully sent. + QuicPacketNumber last_sent_largest_acked_; + + // The sequence number of the last received AckFrequencyFrame. Negative if + // none received. + int64_t last_ack_frequency_frame_sequence_number_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_RECEIVED_PACKET_MANAGER_H_ diff --git a/quiche/quic/core/quic_received_packet_manager_test.cc b/quiche/quic/core/quic_received_packet_manager_test.cc new file mode 100644 index 000000000000..143654e61b79 --- /dev/null +++ b/quiche/quic/core/quic_received_packet_manager_test.cc @@ -0,0 +1,704 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_received_packet_manager.h" + +#include +#include +#include +#include + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +class QuicReceivedPacketManagerPeer { + public: + static void SetOneImmediateAck(QuicReceivedPacketManager* manager, + bool one_immediate_ack) { + manager->one_immediate_ack_ = one_immediate_ack; + } + + static void SetAckDecimationDelay(QuicReceivedPacketManager* manager, + float ack_decimation_delay) { + manager->ack_decimation_delay_ = ack_decimation_delay; + } +}; + +namespace { + +const bool kInstigateAck = true; +const QuicTime::Delta kMinRttMs = QuicTime::Delta::FromMilliseconds(40); +const QuicTime::Delta kDelayedAckTime = + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + +class QuicReceivedPacketManagerTest : public QuicTest { + protected: + QuicReceivedPacketManagerTest() : received_manager_(&stats_) { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + rtt_stats_.UpdateRtt(kMinRttMs, QuicTime::Delta::Zero(), QuicTime::Zero()); + received_manager_.set_save_timestamps(true, false); + } + + void RecordPacketReceipt(uint64_t packet_number) { + RecordPacketReceipt(packet_number, QuicTime::Zero()); + } + + void RecordPacketReceipt(uint64_t packet_number, QuicTime receipt_time) { + RecordPacketReceipt(packet_number, receipt_time, ECN_NOT_ECT); + } + + void RecordPacketReceipt(uint64_t packet_number, QuicTime receipt_time, + QuicEcnCodepoint ecn_codepoint) { + QuicPacketHeader header; + header.packet_number = QuicPacketNumber(packet_number); + received_manager_.RecordPacketReceived(header, receipt_time, ecn_codepoint); + } + + bool HasPendingAck() { + return received_manager_.ack_timeout().IsInitialized(); + } + + void MaybeUpdateAckTimeout(bool should_last_packet_instigate_acks, + uint64_t last_received_packet_number) { + received_manager_.MaybeUpdateAckTimeout( + should_last_packet_instigate_acks, + QuicPacketNumber(last_received_packet_number), + /*last_packet_receipt_time=*/clock_.ApproximateNow(), + /*now=*/clock_.ApproximateNow(), &rtt_stats_); + } + + void CheckAckTimeout(QuicTime time) { + QUICHE_DCHECK(HasPendingAck()); + QUICHE_DCHECK_EQ(received_manager_.ack_timeout(), time); + if (time <= clock_.ApproximateNow()) { + // ACK timeout expires, send an ACK. + received_manager_.ResetAckStates(); + QUICHE_DCHECK(!HasPendingAck()); + } + } + + MockClock clock_; + RttStats rtt_stats_; + QuicConnectionStats stats_; + QuicReceivedPacketManager received_manager_; +}; + +TEST_F(QuicReceivedPacketManagerTest, DontWaitForPacketsBefore) { + QuicPacketHeader header; + header.packet_number = QuicPacketNumber(2u); + received_manager_.RecordPacketReceived(header, QuicTime::Zero(), ECN_NOT_ECT); + header.packet_number = QuicPacketNumber(7u); + received_manager_.RecordPacketReceived(header, QuicTime::Zero(), ECN_NOT_ECT); + EXPECT_TRUE(received_manager_.IsAwaitingPacket(QuicPacketNumber(3u))); + EXPECT_TRUE(received_manager_.IsAwaitingPacket(QuicPacketNumber(6u))); + received_manager_.DontWaitForPacketsBefore(QuicPacketNumber(4)); + EXPECT_FALSE(received_manager_.IsAwaitingPacket(QuicPacketNumber(3u))); + EXPECT_TRUE(received_manager_.IsAwaitingPacket(QuicPacketNumber(6u))); +} + +TEST_F(QuicReceivedPacketManagerTest, GetUpdatedAckFrame) { + QuicPacketHeader header; + header.packet_number = QuicPacketNumber(2u); + QuicTime two_ms = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(2); + EXPECT_FALSE(received_manager_.ack_frame_updated()); + received_manager_.RecordPacketReceived(header, two_ms, ECN_NOT_ECT); + EXPECT_TRUE(received_manager_.ack_frame_updated()); + + QuicFrame ack = received_manager_.GetUpdatedAckFrame(QuicTime::Zero()); + received_manager_.ResetAckStates(); + EXPECT_FALSE(received_manager_.ack_frame_updated()); + // When UpdateReceivedPacketInfo with a time earlier than the time of the + // largest observed packet, make sure that the delta is 0, not negative. + EXPECT_EQ(QuicTime::Delta::Zero(), ack.ack_frame->ack_delay_time); + EXPECT_EQ(1u, ack.ack_frame->received_packet_times.size()); + + QuicTime four_ms = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(4); + ack = received_manager_.GetUpdatedAckFrame(four_ms); + received_manager_.ResetAckStates(); + EXPECT_FALSE(received_manager_.ack_frame_updated()); + // When UpdateReceivedPacketInfo after not having received a new packet, + // the delta should still be accurate. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(2), + ack.ack_frame->ack_delay_time); + // And received packet times won't have change. + EXPECT_EQ(1u, ack.ack_frame->received_packet_times.size()); + + header.packet_number = QuicPacketNumber(999u); + received_manager_.RecordPacketReceived(header, two_ms, ECN_NOT_ECT); + header.packet_number = QuicPacketNumber(4u); + received_manager_.RecordPacketReceived(header, two_ms, ECN_NOT_ECT); + header.packet_number = QuicPacketNumber(1000u); + received_manager_.RecordPacketReceived(header, two_ms, ECN_NOT_ECT); + EXPECT_TRUE(received_manager_.ack_frame_updated()); + ack = received_manager_.GetUpdatedAckFrame(two_ms); + received_manager_.ResetAckStates(); + EXPECT_FALSE(received_manager_.ack_frame_updated()); + // UpdateReceivedPacketInfo should discard any times which can't be + // expressed on the wire. + EXPECT_EQ(2u, ack.ack_frame->received_packet_times.size()); +} + +TEST_F(QuicReceivedPacketManagerTest, UpdateReceivedConnectionStats) { + EXPECT_FALSE(received_manager_.ack_frame_updated()); + RecordPacketReceipt(1); + EXPECT_TRUE(received_manager_.ack_frame_updated()); + RecordPacketReceipt(6); + RecordPacketReceipt(2, + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1)); + + EXPECT_EQ(4u, stats_.max_sequence_reordering); + EXPECT_EQ(1000, stats_.max_time_reordering_us); + EXPECT_EQ(1u, stats_.packets_reordered); +} + +TEST_F(QuicReceivedPacketManagerTest, LimitAckRanges) { + received_manager_.set_max_ack_ranges(10); + EXPECT_FALSE(received_manager_.ack_frame_updated()); + for (int i = 0; i < 100; ++i) { + RecordPacketReceipt(1 + 2 * i); + EXPECT_TRUE(received_manager_.ack_frame_updated()); + received_manager_.GetUpdatedAckFrame(QuicTime::Zero()); + EXPECT_GE(10u, received_manager_.ack_frame().packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(1u + 2 * i), + received_manager_.ack_frame().packets.Max()); + for (int j = 0; j < std::min(10, i + 1); ++j) { + ASSERT_GE(i, j); + EXPECT_TRUE(received_manager_.ack_frame().packets.Contains( + QuicPacketNumber(1 + (i - j) * 2))); + if (i > j) { + EXPECT_FALSE(received_manager_.ack_frame().packets.Contains( + QuicPacketNumber((i - j) * 2))); + } + } + } +} + +TEST_F(QuicReceivedPacketManagerTest, IgnoreOutOfOrderTimestamps) { + EXPECT_FALSE(received_manager_.ack_frame_updated()); + RecordPacketReceipt(1, QuicTime::Zero()); + EXPECT_TRUE(received_manager_.ack_frame_updated()); + EXPECT_EQ(1u, received_manager_.ack_frame().received_packet_times.size()); + RecordPacketReceipt(2, + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(2u, received_manager_.ack_frame().received_packet_times.size()); + RecordPacketReceipt(3, QuicTime::Zero()); + EXPECT_EQ(2u, received_manager_.ack_frame().received_packet_times.size()); +} + +TEST_F(QuicReceivedPacketManagerTest, IgnoreOutOfOrderPackets) { + received_manager_.set_save_timestamps(true, true); + EXPECT_FALSE(received_manager_.ack_frame_updated()); + RecordPacketReceipt(1, QuicTime::Zero()); + EXPECT_TRUE(received_manager_.ack_frame_updated()); + EXPECT_EQ(1u, received_manager_.ack_frame().received_packet_times.size()); + RecordPacketReceipt(4, + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(2u, received_manager_.ack_frame().received_packet_times.size()); + + RecordPacketReceipt(3, + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(3)); + EXPECT_EQ(2u, received_manager_.ack_frame().received_packet_times.size()); +} + +TEST_F(QuicReceivedPacketManagerTest, HasMissingPackets) { + EXPECT_QUIC_BUG(received_manager_.PeerFirstSendingPacketNumber(), + "No packets have been received yet"); + RecordPacketReceipt(4, QuicTime::Zero()); + EXPECT_EQ(QuicPacketNumber(4), + received_manager_.PeerFirstSendingPacketNumber()); + EXPECT_FALSE(received_manager_.HasMissingPackets()); + RecordPacketReceipt(3, QuicTime::Zero()); + EXPECT_FALSE(received_manager_.HasMissingPackets()); + EXPECT_EQ(QuicPacketNumber(3), + received_manager_.PeerFirstSendingPacketNumber()); + RecordPacketReceipt(1, QuicTime::Zero()); + EXPECT_EQ(QuicPacketNumber(1), + received_manager_.PeerFirstSendingPacketNumber()); + EXPECT_TRUE(received_manager_.HasMissingPackets()); + RecordPacketReceipt(2, QuicTime::Zero()); + EXPECT_EQ(QuicPacketNumber(1), + received_manager_.PeerFirstSendingPacketNumber()); + EXPECT_FALSE(received_manager_.HasMissingPackets()); +} + +TEST_F(QuicReceivedPacketManagerTest, OutOfOrderReceiptCausesAckSent) { + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(3, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 3); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(5, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 5); + // Immediate ack is sent. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(6, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 6); + // Immediate ack is scheduled, because 4 is still missing. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 2); + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 1); + // Should ack immediately, since this fills the last hole. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(7, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 7); + // Immediate ack is scheduled, because 4 is still missing. + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, OutOfOrderReceiptCausesAckSent1Ack) { + QuicReceivedPacketManagerPeer::SetOneImmediateAck(&received_manager_, true); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(3, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 3); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(5, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 5); + // Immediate ack is sent. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(6, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 6); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 2); + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 1); + // Should ack immediately, since this fills the last hole. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(7, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 7); + // Delayed ack is scheduled, even though 4 is still missing. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); +} + +TEST_F(QuicReceivedPacketManagerTest, OutOfOrderAckReceiptCausesNoAck) { + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 2); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 1); + EXPECT_FALSE(HasPendingAck()); +} + +TEST_F(QuicReceivedPacketManagerTest, AckReceiptCausesAckSend) { + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 1); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 2); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(3, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 3); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + clock_.AdvanceTime(kDelayedAckTime); + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(4, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 4); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(5, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 5); + EXPECT_FALSE(HasPendingAck()); +} + +TEST_F(QuicReceivedPacketManagerTest, AckSentEveryNthPacket) { + EXPECT_FALSE(HasPendingAck()); + received_manager_.set_ack_frequency(3); + + // Receives packets 1 - 39. + for (size_t i = 1; i <= 39; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 3 == 0) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } +} + +TEST_F(QuicReceivedPacketManagerTest, AckDecimationReducesAcks) { + EXPECT_FALSE(HasPendingAck()); + + // Start ack decimation from 10th packet. + received_manager_.set_min_received_before_ack_decimation(10); + + // Receives packets 1 - 29. + for (size_t i = 1; i <= 29; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i <= 10) { + // For packets 1-10, ack every 2 packets. + if (i % 2 == 0) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + continue; + } + // ack at 20. + if (i == 20) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kMinRttMs * 0.25); + } + } + + // We now receive the 30th packet, and so we send an ack. + RecordPacketReceipt(30, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 30); + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, SendDelayedAckDecimation) { + EXPECT_FALSE(HasPendingAck()); + // The ack time should be based on min_rtt * 1/4, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + kMinRttMs * 0.25; + + // Process all the packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // The 10th received packet causes an ack to be sent. + for (uint64_t i = 1; i < 10; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, SendDelayedAckDecimationMin1ms) { + EXPECT_FALSE(HasPendingAck()); + // Seed the min_rtt with a kAlarmGranularity signal. + rtt_stats_.UpdateRtt(kAlarmGranularity, QuicTime::Delta::Zero(), + clock_.ApproximateNow()); + // The ack time should be based on kAlarmGranularity, since the RTT is 1ms. + QuicTime ack_time = clock_.ApproximateNow() + kAlarmGranularity; + + // Process all the packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // The 10th received packet causes an ack to be sent. + for (uint64_t i = 1; i < 10; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, + SendDelayedAckDecimationUnlimitedAggregation) { + EXPECT_FALSE(HasPendingAck()); + QuicConfig config; + QuicTagVector connection_options; + // No limit on the number of packets received before sending an ack. + connection_options.push_back(kAKDU); + config.SetConnectionOptionsToSend(connection_options); + received_manager_.SetFromConfig(config, Perspective::IS_CLIENT); + + // The ack time should be based on min_rtt/4, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + kMinRttMs * 0.25; + + // Process all the initial packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // 18 packets will not cause an ack to be sent. 19 will because when + // stop waiting frames are in use, we ack every 20 packets no matter what. + for (int i = 1; i <= 18; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(ack_time); +} + +TEST_F(QuicReceivedPacketManagerTest, SendDelayedAckDecimationEighthRtt) { + EXPECT_FALSE(HasPendingAck()); + QuicReceivedPacketManagerPeer::SetAckDecimationDelay(&received_manager_, + 0.125); + + // The ack time should be based on min_rtt/8, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + kMinRttMs * 0.125; + + // Process all the packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // The 10th received packet causes an ack to be sent. + for (uint64_t i = 1; i < 10; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, + UpdateMaxAckDelayAndAckFrequencyFromAckFrequencyFrame) { + EXPECT_FALSE(HasPendingAck()); + + QuicAckFrequencyFrame frame; + frame.max_ack_delay = QuicTime::Delta::FromMilliseconds(10); + frame.packet_tolerance = 5; + received_manager_.OnAckFrequencyFrame(frame); + + for (int i = 1; i <= 50; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % frame.packet_tolerance == 0) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + frame.max_ack_delay); + } + } +} + +TEST_F(QuicReceivedPacketManagerTest, + DisableOutOfOrderAckByIgnoreOrderFromAckFrequencyFrame) { + EXPECT_FALSE(HasPendingAck()); + + QuicAckFrequencyFrame frame; + frame.max_ack_delay = kDelayedAckTime; + frame.packet_tolerance = 2; + frame.ignore_order = true; + received_manager_.OnAckFrequencyFrame(frame); + + RecordPacketReceipt(4, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 4); + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + RecordPacketReceipt(5, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 5); + // Immediate ack is sent as this is the 2nd packet of every two packets. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(3, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 3); + // Don't ack as ignore_order is set by AckFrequencyFrame. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 2); + // Immediate ack is sent as this is the 2nd packet of every two packets. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 1); + // Don't ack as ignore_order is set by AckFrequencyFrame. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); +} + +TEST_F(QuicReceivedPacketManagerTest, + DisableMissingPaketsAckByIgnoreOrderFromAckFrequencyFrame) { + EXPECT_FALSE(HasPendingAck()); + QuicConfig config; + config.SetConnectionOptionsToSend({kAFFE}); + received_manager_.SetFromConfig(config, Perspective::IS_CLIENT); + + QuicAckFrequencyFrame frame; + frame.max_ack_delay = kDelayedAckTime; + frame.packet_tolerance = 2; + frame.ignore_order = true; + received_manager_.OnAckFrequencyFrame(frame); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 1); + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 2); + // Immediate ack is sent as this is the 2nd packet of every two packets. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(4, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 4); + // Don't ack even if packet 3 is newly missing as ignore_order is set by + // AckFrequencyFrame. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(5, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 5); + // Immediate ack is sent as this is the 2nd packet of every two packets. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(7, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 7); + // Don't ack even if packet 6 is newly missing as ignore_order is set by + // AckFrequencyFrame. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); +} + +TEST_F(QuicReceivedPacketManagerTest, + AckDecimationDisabledWhenAckFrequencyFrameIsReceived) { + EXPECT_FALSE(HasPendingAck()); + + QuicAckFrequencyFrame frame; + frame.max_ack_delay = kDelayedAckTime; + frame.packet_tolerance = 3; + frame.ignore_order = true; + received_manager_.OnAckFrequencyFrame(frame); + + // Process all the packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + uint64_t FiftyPacketsAfterAckDecimation = kFirstDecimatedPacket + 50; + for (uint64_t i = 1; i < FiftyPacketsAfterAckDecimation; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 3 == 0) { + // Ack every 3 packets as decimation is disabled. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + // Ack at default delay as decimation is disabled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } +} + +TEST_F(QuicReceivedPacketManagerTest, UpdateAckTimeoutOnPacketReceiptTime) { + EXPECT_FALSE(HasPendingAck()); + + // Received packets 3 and 4. + QuicTime packet_receipt_time3 = clock_.ApproximateNow(); + // Packet 3 gets processed after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + RecordPacketReceipt(3, packet_receipt_time3); + received_manager_.MaybeUpdateAckTimeout( + kInstigateAck, QuicPacketNumber(3), + /*last_packet_receipt_time=*/packet_receipt_time3, + clock_.ApproximateNow(), &rtt_stats_); + // Make sure ACK timeout is based on receipt time. + CheckAckTimeout(packet_receipt_time3 + kDelayedAckTime); + + RecordPacketReceipt(4, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 4); + // Immediate ack is sent. + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, + UpdateAckTimeoutOnPacketReceiptTimeLongerQueuingTime) { + EXPECT_FALSE(HasPendingAck()); + + // Received packets 3 and 4. + QuicTime packet_receipt_time3 = clock_.ApproximateNow(); + // Packet 3 gets processed after 100ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + RecordPacketReceipt(3, packet_receipt_time3); + received_manager_.MaybeUpdateAckTimeout( + kInstigateAck, QuicPacketNumber(3), + /*last_packet_receipt_time=*/packet_receipt_time3, + clock_.ApproximateNow(), &rtt_stats_); + // Given 100ms > ack delay, verify immediate ACK. + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(QuicReceivedPacketManagerTest, CountEcnPackets) { + EXPECT_FALSE(HasPendingAck()); + RecordPacketReceipt(3, QuicTime::Zero(), ECN_NOT_ECT); + RecordPacketReceipt(4, QuicTime::Zero(), ECN_ECT0); + RecordPacketReceipt(5, QuicTime::Zero(), ECN_ECT1); + RecordPacketReceipt(6, QuicTime::Zero(), ECN_CE); + QuicFrame ack = received_manager_.GetUpdatedAckFrame(QuicTime::Zero()); + if (GetQuicRestartFlag(quic_receive_ecn)) { + EXPECT_TRUE(ack.ack_frame->ecn_counters.has_value()); + EXPECT_EQ(ack.ack_frame->ecn_counters->ect0, 1); + EXPECT_EQ(ack.ack_frame->ecn_counters->ect1, 1); + EXPECT_EQ(ack.ack_frame->ecn_counters->ce, 1); + } else { + EXPECT_FALSE(ack.ack_frame->ecn_counters.has_value()); + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_sent_packet_manager.cc b/quiche/quic/core/quic_sent_packet_manager.cc new file mode 100644 index 000000000000..1a098c9d00db --- /dev/null +++ b/quiche/quic/core/quic_sent_packet_manager.cc @@ -0,0 +1,1468 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_sent_packet_manager.h" + +#include +#include +#include + +#include "quiche/quic/core/congestion_control/general_loss_algorithm.h" +#include "quiche/quic/core/congestion_control/pacing_sender.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_transmission_info.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/print_elements.h" + +namespace quic { + +namespace { +static const int64_t kDefaultRetransmissionTimeMs = 500; + +// Ensure the handshake timer isnt't faster than 10ms. +// This limits the tenth retransmitted packet to 10s after the initial CHLO. +static const int64_t kMinHandshakeTimeoutMs = 10; + +// Sends up to two tail loss probes before firing an RTO, +// per draft RFC draft-dukkipati-tcpm-tcp-loss-probe. +static const size_t kDefaultMaxTailLossProbes = 2; + +// The multiplier for calculating PTO timeout before any RTT sample is +// available. +static const float kPtoMultiplierWithoutRttSamples = 3; + +// Returns true of retransmissions of the specified type should retransmit +// the frames directly (as opposed to resulting in a loss notification). +inline bool ShouldForceRetransmission(TransmissionType transmission_type) { + return transmission_type == HANDSHAKE_RETRANSMISSION || + transmission_type == PTO_RETRANSMISSION; +} + +// If pacing rate is accurate, > 2 burst token is not likely to help first ACK +// to arrive earlier, and overly large burst token could cause incast packet +// losses. +static const uint32_t kConservativeUnpacedBurst = 2; + +// The default number of PTOs to trigger path degrading. +static const uint32_t kNumProbeTimeoutsForPathDegradingDelay = 4; + +} // namespace + +#define ENDPOINT \ + (unacked_packets_.perspective() == Perspective::IS_SERVER ? "Server: " \ + : "Client: ") + +QuicSentPacketManager::QuicSentPacketManager( + Perspective perspective, const QuicClock* clock, QuicRandom* random, + QuicConnectionStats* stats, CongestionControlType congestion_control_type) + : unacked_packets_(perspective), + clock_(clock), + random_(random), + stats_(stats), + debug_delegate_(nullptr), + network_change_visitor_(nullptr), + initial_congestion_window_(kInitialCongestionWindow), + loss_algorithm_(&uber_loss_algorithm_), + consecutive_crypto_retransmission_count_(0), + pending_timer_transmission_count_(0), + using_pacing_(false), + conservative_handshake_retransmits_(false), + largest_mtu_acked_(0), + handshake_finished_(false), + peer_max_ack_delay_( + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs)), + rtt_updated_(false), + acked_packets_iter_(last_ack_frame_.packets.rbegin()), + consecutive_pto_count_(0), + handshake_mode_disabled_(false), + handshake_packet_acked_(false), + zero_rtt_packet_acked_(false), + one_rtt_packet_acked_(false), + num_ptos_for_path_degrading_(kNumProbeTimeoutsForPathDegradingDelay), + ignore_pings_(false), + ignore_ack_delay_(false) { + SetSendAlgorithm(congestion_control_type); +} + +QuicSentPacketManager::~QuicSentPacketManager() {} + +void QuicSentPacketManager::SetFromConfig(const QuicConfig& config) { + const Perspective perspective = unacked_packets_.perspective(); + if (config.HasReceivedInitialRoundTripTimeUs() && + config.ReceivedInitialRoundTripTimeUs() > 0) { + if (!config.HasClientSentConnectionOption(kNRTT, perspective)) { + SetInitialRtt(QuicTime::Delta::FromMicroseconds( + config.ReceivedInitialRoundTripTimeUs()), + /*trusted=*/false); + } + } else if (config.HasInitialRoundTripTimeUsToSend() && + config.GetInitialRoundTripTimeUsToSend() > 0) { + SetInitialRtt(QuicTime::Delta::FromMicroseconds( + config.GetInitialRoundTripTimeUsToSend()), + /*trusted=*/false); + } + if (config.HasReceivedMaxAckDelayMs()) { + peer_max_ack_delay_ = + QuicTime::Delta::FromMilliseconds(config.ReceivedMaxAckDelayMs()); + } + if (GetQuicReloadableFlag(quic_can_send_ack_frequency) && + perspective == Perspective::IS_SERVER) { + if (config.HasReceivedMinAckDelayMs()) { + peer_min_ack_delay_ = + QuicTime::Delta::FromMilliseconds(config.ReceivedMinAckDelayMs()); + } + if (config.HasClientSentConnectionOption(kAFF1, perspective)) { + use_smoothed_rtt_in_ack_delay_ = true; + } + } + if (config.HasClientSentConnectionOption(kMAD0, perspective)) { + ignore_ack_delay_ = true; + } + + // Configure congestion control. + if (config.HasClientRequestedIndependentOption(kTBBR, perspective)) { + SetSendAlgorithm(kBBR); + } + if (GetQuicReloadableFlag(quic_allow_client_enabled_bbr_v2) && + config.HasClientRequestedIndependentOption(kB2ON, perspective)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_allow_client_enabled_bbr_v2); + SetSendAlgorithm(kBBRv2); + } + + if (config.HasClientRequestedIndependentOption(kRENO, perspective)) { + SetSendAlgorithm(kRenoBytes); + } else if (config.HasClientRequestedIndependentOption(kBYTE, perspective) || + (GetQuicReloadableFlag(quic_default_to_bbr) && + config.HasClientRequestedIndependentOption(kQBIC, perspective))) { + SetSendAlgorithm(kCubicBytes); + } + + // Initial window. + if (config.HasClientRequestedIndependentOption(kIW03, perspective)) { + initial_congestion_window_ = 3; + send_algorithm_->SetInitialCongestionWindowInPackets(3); + } + if (config.HasClientRequestedIndependentOption(kIW10, perspective)) { + initial_congestion_window_ = 10; + send_algorithm_->SetInitialCongestionWindowInPackets(10); + } + if (config.HasClientRequestedIndependentOption(kIW20, perspective)) { + initial_congestion_window_ = 20; + send_algorithm_->SetInitialCongestionWindowInPackets(20); + } + if (config.HasClientRequestedIndependentOption(kIW50, perspective)) { + initial_congestion_window_ = 50; + send_algorithm_->SetInitialCongestionWindowInPackets(50); + } + if (config.HasClientRequestedIndependentOption(kBWS5, perspective)) { + initial_congestion_window_ = 10; + send_algorithm_->SetInitialCongestionWindowInPackets(10); + } + + if (config.HasClientRequestedIndependentOption(kIGNP, perspective)) { + ignore_pings_ = true; + } + + using_pacing_ = !GetQuicFlag(quic_disable_pacing_for_perf_tests); + // Configure loss detection. + if (config.HasClientRequestedIndependentOption(kILD0, perspective)) { + uber_loss_algorithm_.SetReorderingShift(kDefaultIetfLossDelayShift); + uber_loss_algorithm_.DisableAdaptiveReorderingThreshold(); + } + if (config.HasClientRequestedIndependentOption(kILD1, perspective)) { + uber_loss_algorithm_.SetReorderingShift(kDefaultLossDelayShift); + uber_loss_algorithm_.DisableAdaptiveReorderingThreshold(); + } + if (config.HasClientRequestedIndependentOption(kILD2, perspective)) { + uber_loss_algorithm_.EnableAdaptiveReorderingThreshold(); + uber_loss_algorithm_.SetReorderingShift(kDefaultIetfLossDelayShift); + } + if (config.HasClientRequestedIndependentOption(kILD3, perspective)) { + uber_loss_algorithm_.SetReorderingShift(kDefaultLossDelayShift); + uber_loss_algorithm_.EnableAdaptiveReorderingThreshold(); + } + if (config.HasClientRequestedIndependentOption(kILD4, perspective)) { + uber_loss_algorithm_.SetReorderingShift(kDefaultLossDelayShift); + uber_loss_algorithm_.EnableAdaptiveReorderingThreshold(); + uber_loss_algorithm_.EnableAdaptiveTimeThreshold(); + } + if (config.HasClientRequestedIndependentOption(kRUNT, perspective)) { + uber_loss_algorithm_.DisablePacketThresholdForRuntPackets(); + } + if (config.HasClientSentConnectionOption(kCONH, perspective)) { + conservative_handshake_retransmits_ = true; + } + send_algorithm_->SetFromConfig(config, perspective); + loss_algorithm_->SetFromConfig(config, perspective); + + if (network_change_visitor_ != nullptr) { + network_change_visitor_->OnCongestionChange(); + } + + if (debug_delegate_ != nullptr) { + DebugDelegate::SendParameters parameters; + parameters.congestion_control_type = + send_algorithm_->GetCongestionControlType(); + parameters.use_pacing = using_pacing_; + parameters.initial_congestion_window = initial_congestion_window_; + debug_delegate_->OnConfigProcessed(parameters); + } +} + +void QuicSentPacketManager::ApplyConnectionOptions( + const QuicTagVector& connection_options) { + absl::optional cc_type; + if (ContainsQuicTag(connection_options, kB2ON)) { + cc_type = kBBRv2; + } else if (ContainsQuicTag(connection_options, kTBBR)) { + cc_type = kBBR; + } else if (ContainsQuicTag(connection_options, kRENO)) { + cc_type = kRenoBytes; + } else if (ContainsQuicTag(connection_options, kQBIC)) { + cc_type = kCubicBytes; + } + + if (cc_type.has_value()) { + SetSendAlgorithm(*cc_type); + } + + send_algorithm_->ApplyConnectionOptions(connection_options); +} + +void QuicSentPacketManager::ResumeConnectionState( + const CachedNetworkParameters& cached_network_params, + bool max_bandwidth_resumption) { + QuicBandwidth bandwidth = QuicBandwidth::FromBytesPerSecond( + max_bandwidth_resumption + ? cached_network_params.max_bandwidth_estimate_bytes_per_second() + : cached_network_params.bandwidth_estimate_bytes_per_second()); + QuicTime::Delta rtt = + QuicTime::Delta::FromMilliseconds(cached_network_params.min_rtt_ms()); + // This calls the old AdjustNetworkParameters interface, and fills certain + // fields in SendAlgorithmInterface::NetworkParams + // (e.g., quic_bbr_fix_pacing_rate) using GFE flags. + SendAlgorithmInterface::NetworkParams params( + bandwidth, rtt, /*allow_cwnd_to_decrease = */ false); + // The rtt is trusted because it's a min_rtt measured from a previous + // connection with the same network path between client and server. + params.is_rtt_trusted = true; + AdjustNetworkParameters(params); +} + +void QuicSentPacketManager::AdjustNetworkParameters( + const SendAlgorithmInterface::NetworkParams& params) { + const QuicBandwidth& bandwidth = params.bandwidth; + const QuicTime::Delta& rtt = params.rtt; + + if (!rtt.IsZero()) { + if (params.is_rtt_trusted) { + // Always set initial rtt if it's trusted. + SetInitialRtt(rtt, /*trusted=*/true); + } else if (rtt_stats_.initial_rtt() == + QuicTime::Delta::FromMilliseconds(kInitialRttMs)) { + // Only set initial rtt if we are using the default. This avoids + // overwriting a trusted initial rtt by an untrusted one. + SetInitialRtt(rtt, /*trusted=*/false); + } + } + + const QuicByteCount old_cwnd = send_algorithm_->GetCongestionWindow(); + if (GetQuicReloadableFlag(quic_conservative_bursts) && using_pacing_ && + !bandwidth.IsZero()) { + QUIC_RELOADABLE_FLAG_COUNT(quic_conservative_bursts); + pacing_sender_.SetBurstTokens(kConservativeUnpacedBurst); + } + send_algorithm_->AdjustNetworkParameters(params); + if (debug_delegate_ != nullptr) { + debug_delegate_->OnAdjustNetworkParameters( + bandwidth, rtt.IsZero() ? rtt_stats_.MinOrInitialRtt() : rtt, old_cwnd, + send_algorithm_->GetCongestionWindow()); + } +} + +void QuicSentPacketManager::SetLossDetectionTuner( + std::unique_ptr tuner) { + uber_loss_algorithm_.SetLossDetectionTuner(std::move(tuner)); +} + +void QuicSentPacketManager::OnConfigNegotiated() { + loss_algorithm_->OnConfigNegotiated(); +} + +void QuicSentPacketManager::OnConnectionClosed() { + loss_algorithm_->OnConnectionClosed(); +} + +void QuicSentPacketManager::SetHandshakeConfirmed() { + if (!handshake_finished_) { + handshake_finished_ = true; + NeuterHandshakePackets(); + } +} + +void QuicSentPacketManager::PostProcessNewlyAckedPackets( + QuicPacketNumber ack_packet_number, EncryptionLevel ack_decrypted_level, + const QuicAckFrame& ack_frame, QuicTime ack_receive_time, bool rtt_updated, + QuicByteCount prior_bytes_in_flight) { + unacked_packets_.NotifyAggregatedStreamFrameAcked( + last_ack_frame_.ack_delay_time); + InvokeLossDetection(ack_receive_time); + MaybeInvokeCongestionEvent(rtt_updated, prior_bytes_in_flight, + ack_receive_time); + unacked_packets_.RemoveObsoletePackets(); + + sustained_bandwidth_recorder_.RecordEstimate( + send_algorithm_->InRecovery(), send_algorithm_->InSlowStart(), + send_algorithm_->BandwidthEstimate(), ack_receive_time, clock_->WallNow(), + rtt_stats_.smoothed_rtt()); + + // Anytime we are making forward progress and have a new RTT estimate, reset + // the backoff counters. + if (rtt_updated) { + // Records the max consecutive PTO before forward progress has been made. + if (consecutive_pto_count_ > + stats_->max_consecutive_rto_with_forward_progress) { + stats_->max_consecutive_rto_with_forward_progress = + consecutive_pto_count_; + } + // Reset all retransmit counters any time a new packet is acked. + consecutive_pto_count_ = 0; + consecutive_crypto_retransmission_count_ = 0; + } + + if (debug_delegate_ != nullptr) { + debug_delegate_->OnIncomingAck( + ack_packet_number, ack_decrypted_level, ack_frame, ack_receive_time, + LargestAcked(ack_frame), rtt_updated, GetLeastUnacked()); + } + // Remove packets below least unacked from all_packets_acked_ and + // last_ack_frame_. + last_ack_frame_.packets.RemoveUpTo(unacked_packets_.GetLeastUnacked()); + last_ack_frame_.received_packet_times.clear(); +} + +void QuicSentPacketManager::MaybeInvokeCongestionEvent( + bool rtt_updated, QuicByteCount prior_in_flight, QuicTime event_time) { + if (!rtt_updated && packets_acked_.empty() && packets_lost_.empty()) { + return; + } + const bool overshooting_detected = + stats_->overshooting_detected_with_network_parameters_adjusted; + if (using_pacing_) { + pacing_sender_.OnCongestionEvent(rtt_updated, prior_in_flight, event_time, + packets_acked_, packets_lost_, 0, 0); + } else { + send_algorithm_->OnCongestionEvent(rtt_updated, prior_in_flight, event_time, + packets_acked_, packets_lost_, 0, 0); + } + if (debug_delegate_ != nullptr && !overshooting_detected && + stats_->overshooting_detected_with_network_parameters_adjusted) { + debug_delegate_->OnOvershootingDetected(); + } + packets_acked_.clear(); + packets_lost_.clear(); + if (network_change_visitor_ != nullptr) { + network_change_visitor_->OnCongestionChange(); + } +} + +void QuicSentPacketManager::MarkInitialPacketsForRetransmission() { + if (unacked_packets_.empty()) { + return; + } + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + QuicPacketNumber largest_sent_packet = unacked_packets_.largest_sent_packet(); + for (; packet_number <= largest_sent_packet; ++packet_number) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + if (transmission_info->encryption_level == ENCRYPTION_INITIAL) { + if (transmission_info->in_flight) { + unacked_packets_.RemoveFromInFlight(transmission_info); + } + if (unacked_packets_.HasRetransmittableFrames(*transmission_info)) { + MarkForRetransmission(packet_number, ALL_INITIAL_RETRANSMISSION); + } + } + } +} + +void QuicSentPacketManager::MarkZeroRttPacketsForRetransmission() { + if (unacked_packets_.empty()) { + return; + } + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + QuicPacketNumber largest_sent_packet = unacked_packets_.largest_sent_packet(); + for (; packet_number <= largest_sent_packet; ++packet_number) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + if (transmission_info->encryption_level == ENCRYPTION_ZERO_RTT) { + if (transmission_info->in_flight) { + // Remove 0-RTT packets and packets of the wrong version from flight, + // because neither can be processed by the peer. + unacked_packets_.RemoveFromInFlight(transmission_info); + } + if (unacked_packets_.HasRetransmittableFrames(*transmission_info)) { + MarkForRetransmission(packet_number, ALL_ZERO_RTT_RETRANSMISSION); + } + } + } +} + +void QuicSentPacketManager::NeuterUnencryptedPackets() { + for (QuicPacketNumber packet_number : + unacked_packets_.NeuterUnencryptedPackets()) { + send_algorithm_->OnPacketNeutered(packet_number); + } + if (handshake_mode_disabled_) { + consecutive_pto_count_ = 0; + uber_loss_algorithm_.ResetLossDetection(INITIAL_DATA); + } +} + +void QuicSentPacketManager::NeuterHandshakePackets() { + for (QuicPacketNumber packet_number : + unacked_packets_.NeuterHandshakePackets()) { + send_algorithm_->OnPacketNeutered(packet_number); + } + if (handshake_mode_disabled_) { + consecutive_pto_count_ = 0; + uber_loss_algorithm_.ResetLossDetection(HANDSHAKE_DATA); + } +} + +bool QuicSentPacketManager::ShouldAddMaxAckDelay( + PacketNumberSpace space) const { + // Do not include max_ack_delay when PTO is armed for Initial or Handshake + // packet number spaces. + return !supports_multiple_packet_number_spaces() || space == APPLICATION_DATA; +} + +QuicTime QuicSentPacketManager::GetEarliestPacketSentTimeForPto( + PacketNumberSpace* packet_number_space) const { + QUICHE_DCHECK(supports_multiple_packet_number_spaces()); + QuicTime earliest_sent_time = QuicTime::Zero(); + for (int8_t i = 0; i < NUM_PACKET_NUMBER_SPACES; ++i) { + const QuicTime sent_time = unacked_packets_.GetLastInFlightPacketSentTime( + static_cast(i)); + if (!handshake_finished_ && i == APPLICATION_DATA) { + // Do not arm PTO for application data until handshake gets confirmed. + continue; + } + if (!sent_time.IsInitialized() || (earliest_sent_time.IsInitialized() && + earliest_sent_time <= sent_time)) { + continue; + } + earliest_sent_time = sent_time; + *packet_number_space = static_cast(i); + } + + return earliest_sent_time; +} + +void QuicSentPacketManager::MarkForRetransmission( + QuicPacketNumber packet_number, TransmissionType transmission_type) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + // Packets without retransmittable frames can only be marked for loss + // retransmission. + QUIC_BUG_IF(quic_bug_12552_2, transmission_type != LOSS_RETRANSMISSION && + !unacked_packets_.HasRetransmittableFrames( + *transmission_info)) + << "packet number " << packet_number + << " transmission_type: " << transmission_type << " transmission_info " + << transmission_info->DebugString(); + if (ShouldForceRetransmission(transmission_type)) { + if (!unacked_packets_.RetransmitFrames( + QuicFrames(transmission_info->retransmittable_frames), + transmission_type)) { + // Do not set packet state if the data is not fully retransmitted. + // This should only happen if packet payload size decreases which can be + // caused by: + // 1) connection tries to opportunistically retransmit data + // when sending a packet of a different packet number space, or + // 2) path MTU decreases, or + // 3) packet header size increases (e.g., packet number length + // increases). + QUIC_CODE_COUNT(quic_retransmit_frames_failed); + return; + } + QUIC_CODE_COUNT(quic_retransmit_frames_succeeded); + } else { + unacked_packets_.NotifyFramesLost(*transmission_info, transmission_type); + + if (!transmission_info->retransmittable_frames.empty()) { + if (transmission_type == LOSS_RETRANSMISSION) { + // Record the first packet sent after loss, which allows to wait 1 + // more RTT before giving up on this lost packet. + transmission_info->first_sent_after_loss = + unacked_packets_.largest_sent_packet() + 1; + } else { + // Clear the recorded first packet sent after loss when version or + // encryption changes. + transmission_info->first_sent_after_loss.Clear(); + } + } + } + + // Get the latest transmission_info here as it can be invalidated after + // HandleRetransmission adding new sent packets into unacked_packets_. + transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + + // Update packet state according to transmission type. + transmission_info->state = + QuicUtils::RetransmissionTypeToPacketState(transmission_type); +} + +void QuicSentPacketManager::RecordOneSpuriousRetransmission( + const QuicTransmissionInfo& info) { + stats_->bytes_spuriously_retransmitted += info.bytes_sent; + ++stats_->packets_spuriously_retransmitted; + if (debug_delegate_ != nullptr) { + debug_delegate_->OnSpuriousPacketRetransmission(info.transmission_type, + info.bytes_sent); + } +} + +void QuicSentPacketManager::MarkPacketHandled(QuicPacketNumber packet_number, + QuicTransmissionInfo* info, + QuicTime ack_receive_time, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) { + if (info->has_ack_frequency) { + for (const auto& frame : info->retransmittable_frames) { + if (frame.type == ACK_FREQUENCY_FRAME) { + OnAckFrequencyFrameAcked(*frame.ack_frequency_frame); + } + } + } + // Try to aggregate acked stream frames if acked packet is not a + // retransmission. + if (info->transmission_type == NOT_RETRANSMISSION) { + unacked_packets_.MaybeAggregateAckedStreamFrame(*info, ack_delay_time, + receive_timestamp); + } else { + unacked_packets_.NotifyAggregatedStreamFrameAcked(ack_delay_time); + const bool new_data_acked = unacked_packets_.NotifyFramesAcked( + *info, ack_delay_time, receive_timestamp); + if (!new_data_acked && info->transmission_type != NOT_RETRANSMISSION) { + // Record as a spurious retransmission if this packet is a + // retransmission and no new data gets acked. + QUIC_DVLOG(1) << "Detect spurious retransmitted packet " << packet_number + << " transmission type: " << info->transmission_type; + RecordOneSpuriousRetransmission(*info); + } + } + if (info->state == LOST) { + // Record as a spurious loss as a packet previously declared lost gets + // acked. + const PacketNumberSpace packet_number_space = + unacked_packets_.GetPacketNumberSpace(info->encryption_level); + const QuicPacketNumber previous_largest_acked = + supports_multiple_packet_number_spaces() + ? unacked_packets_.GetLargestAckedOfPacketNumberSpace( + packet_number_space) + : unacked_packets_.largest_acked(); + QUIC_DVLOG(1) << "Packet " << packet_number + << " was detected lost spuriously, " + "previous_largest_acked: " + << previous_largest_acked; + loss_algorithm_->SpuriousLossDetected(unacked_packets_, rtt_stats_, + ack_receive_time, packet_number, + previous_largest_acked); + ++stats_->packet_spuriously_detected_lost; + } + + if (network_change_visitor_ != nullptr && + info->bytes_sent > largest_mtu_acked_) { + largest_mtu_acked_ = info->bytes_sent; + network_change_visitor_->OnPathMtuIncreased(largest_mtu_acked_); + } + unacked_packets_.RemoveFromInFlight(info); + unacked_packets_.RemoveRetransmittability(info); + info->state = ACKED; +} + +bool QuicSentPacketManager::CanSendAckFrequency() const { + return !peer_min_ack_delay_.IsInfinite() && handshake_finished_; +} + +QuicAckFrequencyFrame QuicSentPacketManager::GetUpdatedAckFrequencyFrame() + const { + QuicAckFrequencyFrame frame; + if (!CanSendAckFrequency()) { + QUIC_BUG(quic_bug_10750_1) + << "New AckFrequencyFrame is created while it shouldn't."; + return frame; + } + + QUIC_RELOADABLE_FLAG_COUNT_N(quic_can_send_ack_frequency, 1, 3); + frame.packet_tolerance = kMaxRetransmittablePacketsBeforeAck; + auto rtt = use_smoothed_rtt_in_ack_delay_ ? rtt_stats_.SmoothedOrInitialRtt() + : rtt_stats_.MinOrInitialRtt(); + frame.max_ack_delay = rtt * kAckDecimationDelay; + frame.max_ack_delay = std::max(frame.max_ack_delay, peer_min_ack_delay_); + // TODO(haoyuewang) Remove this once kDefaultMinAckDelayTimeMs is updated to + // 5 ms on the client side. + frame.max_ack_delay = + std::max(frame.max_ack_delay, + QuicTime::Delta::FromMilliseconds(kDefaultMinAckDelayTimeMs)); + return frame; +} + +bool QuicSentPacketManager::OnPacketSent( + SerializedPacket* mutable_packet, QuicTime sent_time, + TransmissionType transmission_type, + HasRetransmittableData has_retransmittable_data, bool measure_rtt, + QuicEcnCodepoint ecn_codepoint) { + const SerializedPacket& packet = *mutable_packet; + QuicPacketNumber packet_number = packet.packet_number; + QUICHE_DCHECK_LE(FirstSendingPacketNumber(), packet_number); + QUICHE_DCHECK(!unacked_packets_.IsUnacked(packet_number)); + QUIC_BUG_IF(quic_bug_10750_2, packet.encrypted_length == 0) + << "Cannot send empty packets."; + if (pending_timer_transmission_count_ > 0) { + --pending_timer_transmission_count_; + } + + bool in_flight = has_retransmittable_data == HAS_RETRANSMITTABLE_DATA; + if (ignore_pings_ && mutable_packet->retransmittable_frames.size() == 1 && + mutable_packet->retransmittable_frames[0].type == PING_FRAME) { + // Dot not use PING only packet for RTT measure or congestion control. + in_flight = false; + measure_rtt = false; + } + if (using_pacing_) { + pacing_sender_.OnPacketSent(sent_time, unacked_packets_.bytes_in_flight(), + packet_number, packet.encrypted_length, + has_retransmittable_data); + } else { + send_algorithm_->OnPacketSent(sent_time, unacked_packets_.bytes_in_flight(), + packet_number, packet.encrypted_length, + has_retransmittable_data); + } + + // Deallocate message data in QuicMessageFrame immediately after packet + // sent. + if (packet.has_message) { + for (auto& frame : mutable_packet->retransmittable_frames) { + if (frame.type == MESSAGE_FRAME) { + frame.message_frame->message_data.clear(); + frame.message_frame->message_length = 0; + } + } + } + + if (packet.has_ack_frequency) { + for (const auto& frame : packet.retransmittable_frames) { + if (frame.type == ACK_FREQUENCY_FRAME) { + OnAckFrequencyFrameSent(*frame.ack_frequency_frame); + } + } + } + unacked_packets_.AddSentPacket(mutable_packet, transmission_type, sent_time, + in_flight, measure_rtt, ecn_codepoint); + // Reset the retransmission timer anytime a pending packet is sent. + return in_flight; +} + +QuicSentPacketManager::RetransmissionTimeoutMode +QuicSentPacketManager::OnRetransmissionTimeout() { + QUICHE_DCHECK(unacked_packets_.HasInFlightPackets() || + (handshake_mode_disabled_ && !handshake_finished_)); + QUICHE_DCHECK_EQ(0u, pending_timer_transmission_count_); + // Handshake retransmission, timer based loss detection, TLP, and RTO are + // implemented with a single alarm. The handshake alarm is set when the + // handshake has not completed, the loss alarm is set when the loss detection + // algorithm says to, and the TLP and RTO alarms are set after that. + // The TLP alarm is always set to run for under an RTO. + switch (GetRetransmissionMode()) { + case HANDSHAKE_MODE: + QUICHE_DCHECK(!handshake_mode_disabled_); + ++stats_->crypto_retransmit_count; + RetransmitCryptoPackets(); + return HANDSHAKE_MODE; + case LOSS_MODE: { + ++stats_->loss_timeout_count; + QuicByteCount prior_in_flight = unacked_packets_.bytes_in_flight(); + const QuicTime now = clock_->Now(); + InvokeLossDetection(now); + MaybeInvokeCongestionEvent(false, prior_in_flight, now); + return LOSS_MODE; + } + case PTO_MODE: + QUIC_DVLOG(1) << ENDPOINT << "PTO mode"; + ++stats_->pto_count; + if (handshake_mode_disabled_ && !handshake_finished_) { + ++stats_->crypto_retransmit_count; + } + ++consecutive_pto_count_; + pending_timer_transmission_count_ = 1; + return PTO_MODE; + } + QUIC_BUG(quic_bug_10750_3) + << "Unknown retransmission mode " << GetRetransmissionMode(); + return GetRetransmissionMode(); +} + +void QuicSentPacketManager::RetransmitCryptoPackets() { + QUICHE_DCHECK_EQ(HANDSHAKE_MODE, GetRetransmissionMode()); + ++consecutive_crypto_retransmission_count_; + bool packet_retransmitted = false; + std::vector crypto_retransmissions; + if (!unacked_packets_.empty()) { + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + QuicPacketNumber largest_sent_packet = + unacked_packets_.largest_sent_packet(); + for (; packet_number <= largest_sent_packet; ++packet_number) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + // Only retransmit frames which are in flight, and therefore have been + // sent. + if (!transmission_info->in_flight || + transmission_info->state != OUTSTANDING || + !transmission_info->has_crypto_handshake || + !unacked_packets_.HasRetransmittableFrames(*transmission_info)) { + continue; + } + packet_retransmitted = true; + crypto_retransmissions.push_back(packet_number); + ++pending_timer_transmission_count_; + } + } + QUICHE_DCHECK(packet_retransmitted) + << "No crypto packets found to retransmit."; + for (QuicPacketNumber retransmission : crypto_retransmissions) { + MarkForRetransmission(retransmission, HANDSHAKE_RETRANSMISSION); + } +} + +bool QuicSentPacketManager::MaybeRetransmitOldestPacket(TransmissionType type) { + if (!unacked_packets_.empty()) { + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + QuicPacketNumber largest_sent_packet = + unacked_packets_.largest_sent_packet(); + for (; packet_number <= largest_sent_packet; ++packet_number) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + // Only retransmit frames which are in flight, and therefore have been + // sent. + if (!transmission_info->in_flight || + transmission_info->state != OUTSTANDING || + !unacked_packets_.HasRetransmittableFrames(*transmission_info)) { + continue; + } + MarkForRetransmission(packet_number, type); + return true; + } + } + QUIC_DVLOG(1) + << "No retransmittable packets, so RetransmitOldestPacket failed."; + return false; +} + +void QuicSentPacketManager::MaybeSendProbePacket() { + if (pending_timer_transmission_count_ == 0) { + return; + } + PacketNumberSpace packet_number_space; + if (supports_multiple_packet_number_spaces()) { + // Find out the packet number space to send probe packets. + if (!GetEarliestPacketSentTimeForPto(&packet_number_space) + .IsInitialized()) { + QUIC_BUG_IF(quic_earliest_sent_time_not_initialized, + unacked_packets_.perspective() == Perspective::IS_SERVER) + << "earliest_sent_time not initialized when trying to send PTO " + "retransmissions"; + return; + } + } + std::vector probing_packets; + if (!unacked_packets_.empty()) { + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + QuicPacketNumber largest_sent_packet = + unacked_packets_.largest_sent_packet(); + for (; packet_number <= largest_sent_packet; ++packet_number) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + if (transmission_info->state == OUTSTANDING && + unacked_packets_.HasRetransmittableFrames(*transmission_info) && + (!supports_multiple_packet_number_spaces() || + unacked_packets_.GetPacketNumberSpace( + transmission_info->encryption_level) == packet_number_space)) { + QUICHE_DCHECK(transmission_info->in_flight); + probing_packets.push_back(packet_number); + if (probing_packets.size() == pending_timer_transmission_count_) { + break; + } + } + } + } + + for (QuicPacketNumber retransmission : probing_packets) { + QUIC_DVLOG(1) << ENDPOINT << "Marking " << retransmission + << " for probing retransmission"; + MarkForRetransmission(retransmission, PTO_RETRANSMISSION); + } + // It is possible that there is not enough outstanding data for probing. +} + +void QuicSentPacketManager::EnableIetfPtoAndLossDetection() { + // Disable handshake mode. + handshake_mode_disabled_ = true; +} + +void QuicSentPacketManager::RetransmitDataOfSpaceIfAny( + PacketNumberSpace space) { + QUICHE_DCHECK(supports_multiple_packet_number_spaces()); + if (!unacked_packets_.GetLastInFlightPacketSentTime(space).IsInitialized()) { + // No in flight data of space. + return; + } + if (unacked_packets_.empty()) { + return; + } + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + QuicPacketNumber largest_sent_packet = unacked_packets_.largest_sent_packet(); + for (; packet_number <= largest_sent_packet; ++packet_number) { + QuicTransmissionInfo* transmission_info = + unacked_packets_.GetMutableTransmissionInfo(packet_number); + if (transmission_info->state == OUTSTANDING && + unacked_packets_.HasRetransmittableFrames(*transmission_info) && + unacked_packets_.GetPacketNumberSpace( + transmission_info->encryption_level) == space) { + QUICHE_DCHECK(transmission_info->in_flight); + if (pending_timer_transmission_count_ == 0) { + pending_timer_transmission_count_ = 1; + } + MarkForRetransmission(packet_number, PTO_RETRANSMISSION); + return; + } + } +} + +QuicSentPacketManager::RetransmissionTimeoutMode +QuicSentPacketManager::GetRetransmissionMode() const { + QUICHE_DCHECK(unacked_packets_.HasInFlightPackets() || + (handshake_mode_disabled_ && !handshake_finished_)); + if (!handshake_mode_disabled_ && !handshake_finished_ && + unacked_packets_.HasPendingCryptoPackets()) { + return HANDSHAKE_MODE; + } + if (loss_algorithm_->GetLossTimeout() != QuicTime::Zero()) { + return LOSS_MODE; + } + return PTO_MODE; +} + +void QuicSentPacketManager::InvokeLossDetection(QuicTime time) { + if (!packets_acked_.empty()) { + QUICHE_DCHECK_LE(packets_acked_.front().packet_number, + packets_acked_.back().packet_number); + largest_newly_acked_ = packets_acked_.back().packet_number; + } + LossDetectionInterface::DetectionStats detection_stats = + loss_algorithm_->DetectLosses(unacked_packets_, time, rtt_stats_, + largest_newly_acked_, packets_acked_, + &packets_lost_); + + if (detection_stats.sent_packets_max_sequence_reordering > + stats_->sent_packets_max_sequence_reordering) { + stats_->sent_packets_max_sequence_reordering = + detection_stats.sent_packets_max_sequence_reordering; + } + + stats_->sent_packets_num_borderline_time_reorderings += + detection_stats.sent_packets_num_borderline_time_reorderings; + + stats_->total_loss_detection_response_time += + detection_stats.total_loss_detection_response_time; + + for (const LostPacket& packet : packets_lost_) { + QuicTransmissionInfo* info = + unacked_packets_.GetMutableTransmissionInfo(packet.packet_number); + ++stats_->packets_lost; + if (debug_delegate_ != nullptr) { + debug_delegate_->OnPacketLoss(packet.packet_number, + info->encryption_level, LOSS_RETRANSMISSION, + time); + } + unacked_packets_.RemoveFromInFlight(info); + + MarkForRetransmission(packet.packet_number, LOSS_RETRANSMISSION); + } +} + +bool QuicSentPacketManager::MaybeUpdateRTT(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time, + QuicTime ack_receive_time) { + // We rely on ack_delay_time to compute an RTT estimate, so we + // only update rtt when the largest observed gets acked and the acked packet + // is not useless. + if (!unacked_packets_.IsUnacked(largest_acked)) { + return false; + } + // We calculate the RTT based on the highest ACKed packet number, the lower + // packet numbers will include the ACK aggregation delay. + const QuicTransmissionInfo& transmission_info = + unacked_packets_.GetTransmissionInfo(largest_acked); + // Ensure the packet has a valid sent time. + if (transmission_info.sent_time == QuicTime::Zero()) { + QUIC_BUG(quic_bug_10750_4) + << "Acked packet has zero sent time, largest_acked:" << largest_acked; + return false; + } + if (transmission_info.state == NOT_CONTRIBUTING_RTT) { + return false; + } + if (transmission_info.sent_time > ack_receive_time) { + QUIC_CODE_COUNT(quic_receive_acked_before_sending); + } + + QuicTime::Delta send_delta = ack_receive_time - transmission_info.sent_time; + const bool min_rtt_available = !rtt_stats_.min_rtt().IsZero(); + rtt_stats_.UpdateRtt(send_delta, ack_delay_time, ack_receive_time); + + if (!min_rtt_available && !rtt_stats_.min_rtt().IsZero()) { + loss_algorithm_->OnMinRttAvailable(); + } + + return true; +} + +QuicTime::Delta QuicSentPacketManager::TimeUntilSend(QuicTime now) const { + // The TLP logic is entirely contained within QuicSentPacketManager, so the + // send algorithm does not need to be consulted. + if (pending_timer_transmission_count_ > 0) { + return QuicTime::Delta::Zero(); + } + + if (using_pacing_) { + return pacing_sender_.TimeUntilSend(now, + unacked_packets_.bytes_in_flight()); + } + + return send_algorithm_->CanSend(unacked_packets_.bytes_in_flight()) + ? QuicTime::Delta::Zero() + : QuicTime::Delta::Infinite(); +} + +const QuicTime QuicSentPacketManager::GetRetransmissionTime() const { + if (!unacked_packets_.HasInFlightPackets() && + PeerCompletedAddressValidation()) { + return QuicTime::Zero(); + } + if (pending_timer_transmission_count_ > 0) { + // Do not set the timer if there is any credit left. + return QuicTime::Zero(); + } + switch (GetRetransmissionMode()) { + case HANDSHAKE_MODE: + return unacked_packets_.GetLastCryptoPacketSentTime() + + GetCryptoRetransmissionDelay(); + case LOSS_MODE: + return loss_algorithm_->GetLossTimeout(); + case PTO_MODE: { + if (!supports_multiple_packet_number_spaces()) { + if (unacked_packets_.HasInFlightPackets() && + consecutive_pto_count_ == 0) { + // Arm 1st PTO with earliest in flight sent time, and make sure at + // least kFirstPtoSrttMultiplier * RTT has been passed since last + // in flight packet. + return std::max( + clock_->ApproximateNow(), + std::max(unacked_packets_.GetFirstInFlightTransmissionInfo() + ->sent_time + + GetProbeTimeoutDelay(NUM_PACKET_NUMBER_SPACES), + unacked_packets_.GetLastInFlightPacketSentTime() + + kFirstPtoSrttMultiplier * + rtt_stats_.SmoothedOrInitialRtt())); + } + // Ensure PTO never gets set to a time in the past. + return std::max(clock_->ApproximateNow(), + unacked_packets_.GetLastInFlightPacketSentTime() + + GetProbeTimeoutDelay(NUM_PACKET_NUMBER_SPACES)); + } + + PacketNumberSpace packet_number_space = NUM_PACKET_NUMBER_SPACES; + // earliest_right_edge is the earliest sent time of the last in flight + // packet of all packet number spaces. + QuicTime earliest_right_edge = + GetEarliestPacketSentTimeForPto(&packet_number_space); + if (!earliest_right_edge.IsInitialized()) { + // Arm PTO from now if there is no in flight packets. + earliest_right_edge = clock_->ApproximateNow(); + } + if (packet_number_space == APPLICATION_DATA && + consecutive_pto_count_ == 0) { + const QuicTransmissionInfo* first_application_info = + unacked_packets_.GetFirstInFlightTransmissionInfoOfSpace( + APPLICATION_DATA); + if (first_application_info != nullptr) { + // Arm 1st PTO with earliest in flight sent time, and make sure at + // least kFirstPtoSrttMultiplier * RTT has been passed since last + // in flight packet. Only do this for application data. + return std::max( + clock_->ApproximateNow(), + std::max( + first_application_info->sent_time + + GetProbeTimeoutDelay(packet_number_space), + earliest_right_edge + kFirstPtoSrttMultiplier * + rtt_stats_.SmoothedOrInitialRtt())); + } + } + return std::max( + clock_->ApproximateNow(), + earliest_right_edge + GetProbeTimeoutDelay(packet_number_space)); + } + } + QUICHE_DCHECK(false); + return QuicTime::Zero(); +} + +const QuicTime::Delta QuicSentPacketManager::GetPathDegradingDelay() const { + QUICHE_DCHECK_GT(num_ptos_for_path_degrading_, 0); + return num_ptos_for_path_degrading_ * GetPtoDelay(); +} + +const QuicTime::Delta QuicSentPacketManager::GetNetworkBlackholeDelay( + int8_t num_rtos_for_blackhole_detection) const { + return GetNConsecutiveRetransmissionTimeoutDelay( + kDefaultMaxTailLossProbes + num_rtos_for_blackhole_detection); +} + +QuicTime::Delta QuicSentPacketManager::GetMtuReductionDelay( + int8_t num_rtos_for_blackhole_detection) const { + return GetNetworkBlackholeDelay(num_rtos_for_blackhole_detection / 2); +} + +const QuicTime::Delta QuicSentPacketManager::GetCryptoRetransmissionDelay() + const { + // This is equivalent to the TailLossProbeDelay, but slightly more aggressive + // because crypto handshake messages don't incur a delayed ack time. + QuicTime::Delta srtt = rtt_stats_.SmoothedOrInitialRtt(); + int64_t delay_ms; + if (conservative_handshake_retransmits_) { + // Using the delayed ack time directly could cause conservative handshake + // retransmissions to actually be more aggressive than the default. + delay_ms = std::max(peer_max_ack_delay_.ToMilliseconds(), + static_cast(2 * srtt.ToMilliseconds())); + } else { + delay_ms = std::max(kMinHandshakeTimeoutMs, + static_cast(1.5 * srtt.ToMilliseconds())); + } + return QuicTime::Delta::FromMilliseconds( + delay_ms << consecutive_crypto_retransmission_count_); +} + +const QuicTime::Delta QuicSentPacketManager::GetProbeTimeoutDelay( + PacketNumberSpace space) const { + if (rtt_stats_.smoothed_rtt().IsZero()) { + // Respect kMinHandshakeTimeoutMs to avoid a potential amplification attack. + QUIC_BUG_IF(quic_bug_12552_6, rtt_stats_.initial_rtt().IsZero()); + return std::max(kPtoMultiplierWithoutRttSamples * rtt_stats_.initial_rtt(), + QuicTime::Delta::FromMilliseconds(kMinHandshakeTimeoutMs)) * + (1 << consecutive_pto_count_); + } + QuicTime::Delta pto_delay = + rtt_stats_.smoothed_rtt() + + std::max(kPtoRttvarMultiplier * rtt_stats_.mean_deviation(), + kAlarmGranularity) + + (ShouldAddMaxAckDelay(space) ? peer_max_ack_delay_ + : QuicTime::Delta::Zero()); + return pto_delay * (1 << consecutive_pto_count_); +} + +QuicTime::Delta QuicSentPacketManager::GetSlowStartDuration() const { + if (send_algorithm_->GetCongestionControlType() == kBBR || + send_algorithm_->GetCongestionControlType() == kBBRv2) { + return stats_->slowstart_duration.GetTotalElapsedTime( + clock_->ApproximateNow()); + } + return QuicTime::Delta::Infinite(); +} + +QuicByteCount QuicSentPacketManager::GetAvailableCongestionWindowInBytes() + const { + QuicByteCount congestion_window = GetCongestionWindowInBytes(); + QuicByteCount bytes_in_flight = GetBytesInFlight(); + return congestion_window - std::min(congestion_window, bytes_in_flight); +} + +std::string QuicSentPacketManager::GetDebugState() const { + return send_algorithm_->GetDebugState(); +} + +void QuicSentPacketManager::SetSendAlgorithm( + CongestionControlType congestion_control_type) { + if (send_algorithm_ && + send_algorithm_->GetCongestionControlType() == congestion_control_type) { + return; + } + + SetSendAlgorithm(SendAlgorithmInterface::Create( + clock_, &rtt_stats_, &unacked_packets_, congestion_control_type, random_, + stats_, initial_congestion_window_, send_algorithm_.get())); +} + +void QuicSentPacketManager::SetSendAlgorithm( + SendAlgorithmInterface* send_algorithm) { + send_algorithm_.reset(send_algorithm); + pacing_sender_.set_sender(send_algorithm); +} + +std::unique_ptr +QuicSentPacketManager::OnConnectionMigration(bool reset_send_algorithm) { + consecutive_pto_count_ = 0; + rtt_stats_.OnConnectionMigration(); + if (!reset_send_algorithm) { + send_algorithm_->OnConnectionMigration(); + return nullptr; + } + + std::unique_ptr old_send_algorithm = + std::move(send_algorithm_); + SetSendAlgorithm(old_send_algorithm->GetCongestionControlType()); + // Treat all in flight packets sent to the old peer address as lost and + // retransmit them. + QuicPacketNumber packet_number = unacked_packets_.GetLeastUnacked(); + for (auto it = unacked_packets_.begin(); it != unacked_packets_.end(); + ++it, ++packet_number) { + if (it->in_flight) { + // Proactively retransmit any packet which is in flight on the old path. + // As a result, these packets will not contribute to congestion control. + unacked_packets_.RemoveFromInFlight(packet_number); + // Retransmitting these packets with PATH_CHANGE_RETRANSMISSION will mark + // them as useless, thus not contributing to RTT stats. + if (unacked_packets_.HasRetransmittableFrames(packet_number)) { + MarkForRetransmission(packet_number, PATH_RETRANSMISSION); + QUICHE_DCHECK_EQ(it->state, NOT_CONTRIBUTING_RTT); + } + } + it->state = NOT_CONTRIBUTING_RTT; + } + return old_send_algorithm; +} + +void QuicSentPacketManager::OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time, + QuicTime ack_receive_time) { + QUICHE_DCHECK(packets_acked_.empty()); + QUICHE_DCHECK_LE(largest_acked, unacked_packets_.largest_sent_packet()); + // Ignore peer_max_ack_delay and use received ack_delay during + // handshake when supporting multiple packet number spaces. + if (!supports_multiple_packet_number_spaces() || handshake_finished_) { + if (ack_delay_time > peer_max_ack_delay()) { + ack_delay_time = peer_max_ack_delay(); + } + if (ignore_ack_delay_) { + ack_delay_time = QuicTime::Delta::Zero(); + } + } + rtt_updated_ = + MaybeUpdateRTT(largest_acked, ack_delay_time, ack_receive_time); + last_ack_frame_.ack_delay_time = ack_delay_time; + acked_packets_iter_ = last_ack_frame_.packets.rbegin(); +} + +void QuicSentPacketManager::OnAckRange(QuicPacketNumber start, + QuicPacketNumber end) { + if (!last_ack_frame_.largest_acked.IsInitialized() || + end > last_ack_frame_.largest_acked + 1) { + // Largest acked increases. + unacked_packets_.IncreaseLargestAcked(end - 1); + last_ack_frame_.largest_acked = end - 1; + } + // Drop ack ranges which ack packets below least_unacked. + QuicPacketNumber least_unacked = unacked_packets_.GetLeastUnacked(); + if (least_unacked.IsInitialized() && end <= least_unacked) { + return; + } + start = std::max(start, least_unacked); + do { + QuicPacketNumber newly_acked_start = start; + if (acked_packets_iter_ != last_ack_frame_.packets.rend()) { + newly_acked_start = std::max(start, acked_packets_iter_->max()); + } + for (QuicPacketNumber acked = end - 1; acked >= newly_acked_start; + --acked) { + // Check if end is above the current range. If so add newly acked packets + // in descending order. + packets_acked_.push_back(AckedPacket(acked, 0, QuicTime::Zero())); + if (acked == FirstSendingPacketNumber()) { + break; + } + } + if (acked_packets_iter_ == last_ack_frame_.packets.rend() || + start > acked_packets_iter_->min()) { + // Finish adding all newly acked packets. + return; + } + end = std::min(end, acked_packets_iter_->min()); + ++acked_packets_iter_; + } while (start < end); +} + +void QuicSentPacketManager::OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) { + last_ack_frame_.received_packet_times.push_back({packet_number, timestamp}); + for (AckedPacket& packet : packets_acked_) { + if (packet.packet_number == packet_number) { + packet.receive_timestamp = timestamp; + return; + } + } +} + +AckResult QuicSentPacketManager::OnAckFrameEnd( + QuicTime ack_receive_time, QuicPacketNumber ack_packet_number, + EncryptionLevel ack_decrypted_level, + const absl::optional& ecn_counts) { + QuicByteCount prior_bytes_in_flight = unacked_packets_.bytes_in_flight(); + // Reverse packets_acked_ so that it is in ascending order. + std::reverse(packets_acked_.begin(), packets_acked_.end()); + for (AckedPacket& acked_packet : packets_acked_) { + QuicTransmissionInfo* info = + unacked_packets_.GetMutableTransmissionInfo(acked_packet.packet_number); + if (!QuicUtils::IsAckable(info->state)) { + if (info->state == ACKED) { + QUIC_BUG(quic_bug_10750_5) + << "Trying to ack an already acked packet: " + << acked_packet.packet_number + << ", last_ack_frame_: " << last_ack_frame_ + << ", least_unacked: " << unacked_packets_.GetLeastUnacked() + << ", packets_acked_: " << quiche::PrintElements(packets_acked_); + } else { + QUIC_PEER_BUG(quic_peer_bug_10750_6) + << "Received " << ack_decrypted_level + << " ack for unackable packet: " << acked_packet.packet_number + << " with state: " + << QuicUtils::SentPacketStateToString(info->state); + if (supports_multiple_packet_number_spaces()) { + if (info->state == NEVER_SENT) { + return UNSENT_PACKETS_ACKED; + } + return UNACKABLE_PACKETS_ACKED; + } + } + continue; + } + QUIC_DVLOG(1) << ENDPOINT << "Got an " << ack_decrypted_level + << " ack for packet " << acked_packet.packet_number + << " , state: " + << QuicUtils::SentPacketStateToString(info->state); + const PacketNumberSpace packet_number_space = + unacked_packets_.GetPacketNumberSpace(info->encryption_level); + if (supports_multiple_packet_number_spaces() && + QuicUtils::GetPacketNumberSpace(ack_decrypted_level) != + packet_number_space) { + return PACKETS_ACKED_IN_WRONG_PACKET_NUMBER_SPACE; + } + last_ack_frame_.packets.Add(acked_packet.packet_number); + if (info->encryption_level == ENCRYPTION_HANDSHAKE) { + handshake_packet_acked_ = true; + } else if (info->encryption_level == ENCRYPTION_ZERO_RTT) { + zero_rtt_packet_acked_ = true; + } else if (info->encryption_level == ENCRYPTION_FORWARD_SECURE) { + one_rtt_packet_acked_ = true; + } + largest_packet_peer_knows_is_acked_.UpdateMax(info->largest_acked); + if (supports_multiple_packet_number_spaces()) { + largest_packets_peer_knows_is_acked_[packet_number_space].UpdateMax( + info->largest_acked); + } + // If data is associated with the most recent transmission of this + // packet, then inform the caller. + if (info->in_flight) { + acked_packet.bytes_acked = info->bytes_sent; + } else { + // Unackable packets are skipped earlier. + largest_newly_acked_ = acked_packet.packet_number; + } + unacked_packets_.MaybeUpdateLargestAckedOfPacketNumberSpace( + packet_number_space, acked_packet.packet_number); + MarkPacketHandled(acked_packet.packet_number, info, ack_receive_time, + last_ack_frame_.ack_delay_time, + acked_packet.receive_timestamp); + } + PacketNumberSpace packet_number_space = + QuicUtils::GetPacketNumberSpace(ack_decrypted_level); + const bool acked_new_packet = !packets_acked_.empty(); + PostProcessNewlyAckedPackets(ack_packet_number, ack_decrypted_level, + last_ack_frame_, ack_receive_time, rtt_updated_, + prior_bytes_in_flight); + if (ecn_counts.has_value()) { + peer_ack_ecn_counts_[packet_number_space] = ecn_counts.value(); + } + + return acked_new_packet ? PACKETS_NEWLY_ACKED : NO_PACKETS_NEWLY_ACKED; +} + +void QuicSentPacketManager::SetDebugDelegate(DebugDelegate* debug_delegate) { + debug_delegate_ = debug_delegate; +} + +void QuicSentPacketManager::OnApplicationLimited() { + if (using_pacing_) { + pacing_sender_.OnApplicationLimited(); + } + send_algorithm_->OnApplicationLimited(unacked_packets_.bytes_in_flight()); + if (debug_delegate_ != nullptr) { + debug_delegate_->OnApplicationLimited(); + } +} + +NextReleaseTimeResult QuicSentPacketManager::GetNextReleaseTime() const { + if (!using_pacing_) { + return {QuicTime::Zero(), false}; + } + + return pacing_sender_.GetNextReleaseTime(); +} + +void QuicSentPacketManager::SetInitialRtt(QuicTime::Delta rtt, bool trusted) { + const QuicTime::Delta min_rtt = QuicTime::Delta::FromMicroseconds( + trusted ? kMinTrustedInitialRoundTripTimeUs + : kMinUntrustedInitialRoundTripTimeUs); + QuicTime::Delta max_rtt = + QuicTime::Delta::FromMicroseconds(kMaxInitialRoundTripTimeUs); + rtt_stats_.set_initial_rtt(std::max(min_rtt, std::min(max_rtt, rtt))); +} + +void QuicSentPacketManager::EnableMultiplePacketNumberSpacesSupport() { + EnableIetfPtoAndLossDetection(); + unacked_packets_.EnableMultiplePacketNumberSpacesSupport(); +} + +QuicPacketNumber QuicSentPacketManager::GetLargestAckedPacket( + EncryptionLevel decrypted_packet_level) const { + QUICHE_DCHECK(supports_multiple_packet_number_spaces()); + return unacked_packets_.GetLargestAckedOfPacketNumberSpace( + QuicUtils::GetPacketNumberSpace(decrypted_packet_level)); +} + +QuicPacketNumber QuicSentPacketManager::GetLeastPacketAwaitedByPeer( + EncryptionLevel encryption_level) const { + QuicPacketNumber largest_acked; + if (supports_multiple_packet_number_spaces()) { + largest_acked = GetLargestAckedPacket(encryption_level); + } else { + largest_acked = GetLargestObserved(); + } + if (!largest_acked.IsInitialized()) { + // If no packets have been acked, return the first sent packet to ensure + // we use a large enough packet number length. + return FirstSendingPacketNumber(); + } + QuicPacketNumber least_awaited = largest_acked + 1; + QuicPacketNumber least_unacked = GetLeastUnacked(); + if (least_unacked.IsInitialized() && least_unacked < least_awaited) { + least_awaited = least_unacked; + } + return least_awaited; +} + +QuicPacketNumber QuicSentPacketManager::GetLargestPacketPeerKnowsIsAcked( + EncryptionLevel decrypted_packet_level) const { + QUICHE_DCHECK(supports_multiple_packet_number_spaces()); + return largest_packets_peer_knows_is_acked_[QuicUtils::GetPacketNumberSpace( + decrypted_packet_level)]; +} + +QuicTime::Delta +QuicSentPacketManager::GetNConsecutiveRetransmissionTimeoutDelay( + int num_timeouts) const { + QuicTime::Delta total_delay = QuicTime::Delta::Zero(); + const QuicTime::Delta srtt = rtt_stats_.SmoothedOrInitialRtt(); + int num_tlps = + std::min(num_timeouts, static_cast(kDefaultMaxTailLossProbes)); + num_timeouts -= num_tlps; + if (num_tlps > 0) { + const QuicTime::Delta tlp_delay = std::max( + 2 * srtt, + unacked_packets_.HasMultipleInFlightPackets() + ? QuicTime::Delta::FromMilliseconds(kMinTailLossProbeTimeoutMs) + : (1.5 * srtt + + (QuicTime::Delta::FromMilliseconds(kMinRetransmissionTimeMs) * + 0.5))); + total_delay = total_delay + num_tlps * tlp_delay; + } + if (num_timeouts == 0) { + return total_delay; + } + + const QuicTime::Delta retransmission_delay = + rtt_stats_.smoothed_rtt().IsZero() + ? QuicTime::Delta::FromMilliseconds(kDefaultRetransmissionTimeMs) + : std::max( + srtt + 4 * rtt_stats_.mean_deviation(), + QuicTime::Delta::FromMilliseconds(kMinRetransmissionTimeMs)); + total_delay = total_delay + ((1 << num_timeouts) - 1) * retransmission_delay; + return total_delay; +} + +bool QuicSentPacketManager::PeerCompletedAddressValidation() const { + if (unacked_packets_.perspective() == Perspective::IS_SERVER || + !handshake_mode_disabled_) { + return true; + } + + // To avoid handshake deadlock due to anti-amplification limit, client needs + // to set PTO timer until server successfully processed any HANDSHAKE packet. + return handshake_finished_ || handshake_packet_acked_; +} + +bool QuicSentPacketManager::IsLessThanThreePTOs(QuicTime::Delta timeout) const { + return timeout < 3 * GetPtoDelay(); +} + +QuicTime::Delta QuicSentPacketManager::GetPtoDelay() const { + return GetProbeTimeoutDelay(APPLICATION_DATA); +} + +void QuicSentPacketManager::OnAckFrequencyFrameSent( + const QuicAckFrequencyFrame& ack_frequency_frame) { + in_use_sent_ack_delays_.emplace_back(ack_frequency_frame.max_ack_delay, + ack_frequency_frame.sequence_number); + if (ack_frequency_frame.max_ack_delay > peer_max_ack_delay_) { + peer_max_ack_delay_ = ack_frequency_frame.max_ack_delay; + } +} + +void QuicSentPacketManager::OnAckFrequencyFrameAcked( + const QuicAckFrequencyFrame& ack_frequency_frame) { + int stale_entry_count = 0; + for (auto it = in_use_sent_ack_delays_.cbegin(); + it != in_use_sent_ack_delays_.cend(); ++it) { + if (it->second < ack_frequency_frame.sequence_number) { + ++stale_entry_count; + } else { + break; + } + } + if (stale_entry_count > 0) { + in_use_sent_ack_delays_.pop_front_n(stale_entry_count); + } + if (in_use_sent_ack_delays_.empty()) { + QUIC_BUG(quic_bug_10750_7) << "in_use_sent_ack_delays_ is empty."; + return; + } + peer_max_ack_delay_ = std::max_element(in_use_sent_ack_delays_.cbegin(), + in_use_sent_ack_delays_.cend()) + ->first; +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_sent_packet_manager.h b/quiche/quic/core/quic_sent_packet_manager.h new file mode 100644 index 000000000000..e6c8105c808a --- /dev/null +++ b/quiche/quic/core/quic_sent_packet_manager.h @@ -0,0 +1,680 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_SENT_PACKET_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_SENT_PACKET_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "quiche/quic/core/congestion_control/pacing_sender.h" +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/congestion_control/uber_loss_algorithm.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_sustained_bandwidth_recorder.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_transmission_info.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_unacked_packet_map.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +namespace test { +class QuicConnectionPeer; +class QuicSentPacketManagerPeer; +} // namespace test + +class QuicClock; +class QuicConfig; +struct QuicConnectionStats; + +// Class which tracks the set of packets sent on a QUIC connection and contains +// a send algorithm to decide when to send new packets. It keeps track of any +// retransmittable data associated with each packet. If a packet is +// retransmitted, it will keep track of each version of a packet so that if a +// previous transmission is acked, the data will not be retransmitted. +class QUIC_EXPORT_PRIVATE QuicSentPacketManager { + public: + // Interface which gets callbacks from the QuicSentPacketManager at + // interesting points. Implementations must not mutate the state of + // the packet manager or connection as a result of these callbacks. + class QUIC_EXPORT_PRIVATE DebugDelegate { + public: + struct QUIC_EXPORT_PRIVATE SendParameters { + CongestionControlType congestion_control_type; + bool use_pacing; + QuicPacketCount initial_congestion_window; + }; + + virtual ~DebugDelegate() {} + + // Called when a spurious retransmission is detected. + virtual void OnSpuriousPacketRetransmission( + TransmissionType /*transmission_type*/, QuicByteCount /*byte_size*/) {} + + virtual void OnIncomingAck(QuicPacketNumber /*ack_packet_number*/, + EncryptionLevel /*ack_decrypted_level*/, + const QuicAckFrame& /*ack_frame*/, + QuicTime /*ack_receive_time*/, + QuicPacketNumber /*largest_observed*/, + bool /*rtt_updated*/, + QuicPacketNumber /*least_unacked_sent_packet*/) { + } + + virtual void OnPacketLoss(QuicPacketNumber /*lost_packet_number*/, + EncryptionLevel /*encryption_level*/, + TransmissionType /*transmission_type*/, + QuicTime /*detection_time*/) {} + + virtual void OnApplicationLimited() {} + + virtual void OnAdjustNetworkParameters(QuicBandwidth /*bandwidth*/, + QuicTime::Delta /*rtt*/, + QuicByteCount /*old_cwnd*/, + QuicByteCount /*new_cwnd*/) {} + + virtual void OnAdjustBurstSize(int /*old_burst_size*/, + int /*new_burst_size*/) {} + + virtual void OnOvershootingDetected() {} + + virtual void OnConfigProcessed(const SendParameters& /*parameters*/) {} + }; + + // Interface which gets callbacks from the QuicSentPacketManager when + // network-related state changes. Implementations must not mutate the + // state of the packet manager as a result of these callbacks. + class QUIC_EXPORT_PRIVATE NetworkChangeVisitor { + public: + virtual ~NetworkChangeVisitor() {} + + // Called when congestion window or RTT may have changed. + virtual void OnCongestionChange() = 0; + + // Called when the Path MTU may have increased. + virtual void OnPathMtuIncreased(QuicPacketLength packet_size) = 0; + }; + + // The retransmission timer is a single timer which switches modes depending + // upon connection state. + enum RetransmissionTimeoutMode { + // Retransmission of handshake packets prior to handshake completion. + HANDSHAKE_MODE, + // Re-invoke the loss detection when a packet is not acked before the + // loss detection algorithm expects. + LOSS_MODE, + // A probe timeout. At least one probe packet must be sent when timer + // expires. + PTO_MODE, + }; + + QuicSentPacketManager(Perspective perspective, const QuicClock* clock, + QuicRandom* random, QuicConnectionStats* stats, + CongestionControlType congestion_control_type); + QuicSentPacketManager(const QuicSentPacketManager&) = delete; + QuicSentPacketManager& operator=(const QuicSentPacketManager&) = delete; + virtual ~QuicSentPacketManager(); + + virtual void SetFromConfig(const QuicConfig& config); + + void ReserveUnackedPacketsInitialCapacity(int initial_capacity) { + unacked_packets_.ReserveInitialCapacity(initial_capacity); + } + + void ApplyConnectionOptions(const QuicTagVector& connection_options); + + // Pass the CachedNetworkParameters to the send algorithm. + void ResumeConnectionState( + const CachedNetworkParameters& cached_network_params, + bool max_bandwidth_resumption); + + void SetMaxPacingRate(QuicBandwidth max_pacing_rate) { + pacing_sender_.set_max_pacing_rate(max_pacing_rate); + } + + QuicBandwidth MaxPacingRate() const { + return pacing_sender_.max_pacing_rate(); + } + + // Called to mark the handshake state complete, and all handshake packets are + // neutered. + // TODO(fayang): Rename this function to OnHandshakeComplete. + void SetHandshakeConfirmed(); + + // Requests retransmission of all unacked 0-RTT packets. + // Only 0-RTT encrypted packets will be retransmitted. This can happen, + // for example, when a CHLO has been rejected and the previously encrypted + // data needs to be encrypted with a new key. + void MarkZeroRttPacketsForRetransmission(); + + // Request retransmission of all unacked INITIAL packets. + void MarkInitialPacketsForRetransmission(); + + // Notify the sent packet manager of an external network measurement or + // prediction for either |bandwidth| or |rtt|; either can be empty. + void AdjustNetworkParameters( + const SendAlgorithmInterface::NetworkParams& params); + + void SetLossDetectionTuner( + std::unique_ptr tuner); + void OnConfigNegotiated(); + void OnConnectionClosed(); + + // Retransmits the oldest pending packet. + bool MaybeRetransmitOldestPacket(TransmissionType type); + + // Removes the retransmittable frames from all unencrypted packets to ensure + // they don't get retransmitted. + void NeuterUnencryptedPackets(); + + // Returns true if there's outstanding crypto data. + bool HasUnackedCryptoPackets() const { + return unacked_packets_.HasPendingCryptoPackets(); + } + + // Returns true if there are packets in flight expecting to be acknowledged. + bool HasInFlightPackets() const { + return unacked_packets_.HasInFlightPackets(); + } + + // Returns the smallest packet number of a serialized packet which has not + // been acked by the peer. + QuicPacketNumber GetLeastUnacked() const { + return unacked_packets_.GetLeastUnacked(); + } + + // Called when we have sent bytes to the peer. This informs the manager both + // the number of bytes sent and if they were retransmitted and if this packet + // is used for rtt measuring. Returns true if the sender should reset the + // retransmission timer. + bool OnPacketSent(SerializedPacket* mutable_packet, QuicTime sent_time, + TransmissionType transmission_type, + HasRetransmittableData has_retransmittable_data, + bool measure_rtt, QuicEcnCodepoint ecn_codepoint); + + bool CanSendAckFrequency() const; + + QuicAckFrequencyFrame GetUpdatedAckFrequencyFrame() const; + + // Called when the retransmission timer expires and returns the retransmission + // mode. + RetransmissionTimeoutMode OnRetransmissionTimeout(); + + // Calculate the time until we can send the next packet to the wire. + // Note 1: When kUnknownWaitTime is returned, there is no need to poll + // TimeUntilSend again until we receive an OnIncomingAckFrame event. + // Note 2: Send algorithms may or may not use |retransmit| in their + // calculations. + QuicTime::Delta TimeUntilSend(QuicTime now) const; + + // Returns the current delay for the retransmission timer, which may send + // either a tail loss probe or do a full RTO. Returns QuicTime::Zero() if + // there are no retransmittable packets. + const QuicTime GetRetransmissionTime() const; + + // Returns the current delay for the path degrading timer, which is used to + // notify the session that this connection is degrading. + const QuicTime::Delta GetPathDegradingDelay() const; + + // Returns the current delay for detecting network blackhole. + const QuicTime::Delta GetNetworkBlackholeDelay( + int8_t num_rtos_for_blackhole_detection) const; + + // Returns the delay before reducing max packet size. This delay is guranteed + // to be smaller than the network blackhole delay. + QuicTime::Delta GetMtuReductionDelay( + int8_t num_rtos_for_blackhole_detection) const; + + const RttStats* GetRttStats() const { return &rtt_stats_; } + + void SetRttStats(const RttStats& rtt_stats) { + rtt_stats_.CloneFrom(rtt_stats); + } + + // Returns the estimated bandwidth calculated by the congestion algorithm. + QuicBandwidth BandwidthEstimate() const { + return send_algorithm_->BandwidthEstimate(); + } + + const QuicSustainedBandwidthRecorder* SustainedBandwidthRecorder() const { + return &sustained_bandwidth_recorder_; + } + + // Returns the size of the current congestion window in number of + // kDefaultTCPMSS-sized segments. Note, this is not the *available* window. + // Some send algorithms may not use a congestion window and will return 0. + QuicPacketCount GetCongestionWindowInTcpMss() const { + return send_algorithm_->GetCongestionWindow() / kDefaultTCPMSS; + } + + // Returns the number of packets of length |max_packet_length| which fit in + // the current congestion window. More packets may end up in flight if the + // congestion window has been recently reduced, of if non-full packets are + // sent. + QuicPacketCount EstimateMaxPacketsInFlight( + QuicByteCount max_packet_length) const { + return send_algorithm_->GetCongestionWindow() / max_packet_length; + } + + // Returns the size of the current congestion window size in bytes. + QuicByteCount GetCongestionWindowInBytes() const { + return send_algorithm_->GetCongestionWindow(); + } + + // Returns the difference between current congestion window and bytes in + // flight. Returns 0 if bytes in flight is bigger than the current congestion + // window. + QuicByteCount GetAvailableCongestionWindowInBytes() const; + + QuicBandwidth GetPacingRate() const { + return send_algorithm_->PacingRate(GetBytesInFlight()); + } + + // Returns the size of the slow start congestion window in nume of 1460 byte + // TCP segments, aka ssthresh. Some send algorithms do not define a slow + // start threshold and will return 0. + QuicPacketCount GetSlowStartThresholdInTcpMss() const { + return send_algorithm_->GetSlowStartThreshold() / kDefaultTCPMSS; + } + + // Return the total time spent in slow start so far. If the sender is + // currently in slow start, the return value will include the duration between + // the most recent entry to slow start and now. + // + // Only implemented for BBR. Return QuicTime::Delta::Infinite() for other + // congestion controllers. + QuicTime::Delta GetSlowStartDuration() const; + + // Returns debugging information about the state of the congestion controller. + std::string GetDebugState() const; + + // Returns the number of bytes that are considered in-flight, i.e. not lost or + // acknowledged. + QuicByteCount GetBytesInFlight() const { + return unacked_packets_.bytes_in_flight(); + } + + // Called when peer address changes. Must be called IFF the address change is + // not NAT rebinding. If reset_send_algorithm is true, switch to a new send + // algorithm object and retransmit all the in-flight packets. Return the send + // algorithm object used on the previous path. + std::unique_ptr OnConnectionMigration( + bool reset_send_algorithm); + + // Called when an ack frame is initially parsed. + void OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time, + QuicTime ack_receive_time); + + // Called when ack range [start, end) is received. Populates packets_acked_ + // with newly acked packets. + void OnAckRange(QuicPacketNumber start, QuicPacketNumber end); + + // Called when a timestamp is processed. If it's present in packets_acked_, + // the timestamp field is set. Otherwise, the timestamp is ignored. + void OnAckTimestamp(QuicPacketNumber packet_number, QuicTime timestamp); + + // Called when an ack frame is parsed completely. + AckResult OnAckFrameEnd(QuicTime ack_receive_time, + QuicPacketNumber ack_packet_number, + EncryptionLevel ack_decrypted_level, + const absl::optional& ecn_counts); + + void EnableMultiplePacketNumberSpacesSupport(); + + void SetDebugDelegate(DebugDelegate* debug_delegate); + + void SetPacingAlarmGranularity(QuicTime::Delta alarm_granularity) { + pacing_sender_.set_alarm_granularity(alarm_granularity); + } + + QuicPacketNumber GetLargestObserved() const { + return unacked_packets_.largest_acked(); + } + + QuicPacketNumber GetLargestAckedPacket( + EncryptionLevel decrypted_packet_level) const; + + QuicPacketNumber GetLargestSentPacket() const { + return unacked_packets_.largest_sent_packet(); + } + + // Returns the lowest of the largest acknowledged packet and the least + // unacked packet. This is designed to be used when computing the packet + // number length to send. + QuicPacketNumber GetLeastPacketAwaitedByPeer( + EncryptionLevel encryption_level) const; + + QuicPacketNumber GetLargestPacketPeerKnowsIsAcked( + EncryptionLevel decrypted_packet_level) const; + + void SetNetworkChangeVisitor(NetworkChangeVisitor* visitor) { + QUICHE_DCHECK(!network_change_visitor_); + QUICHE_DCHECK(visitor); + network_change_visitor_ = visitor; + } + + bool InSlowStart() const { return send_algorithm_->InSlowStart(); } + + size_t GetConsecutivePtoCount() const { return consecutive_pto_count_; } + + void OnApplicationLimited(); + + const SendAlgorithmInterface* GetSendAlgorithm() const { + return send_algorithm_.get(); + } + + void SetSessionNotifier(SessionNotifierInterface* session_notifier) { + unacked_packets_.SetSessionNotifier(session_notifier); + } + + NextReleaseTimeResult GetNextReleaseTime() const; + + QuicPacketCount initial_congestion_window() const { + return initial_congestion_window_; + } + + QuicPacketNumber largest_packet_peer_knows_is_acked() const { + QUICHE_DCHECK(!supports_multiple_packet_number_spaces()); + return largest_packet_peer_knows_is_acked_; + } + + size_t pending_timer_transmission_count() const { + return pending_timer_transmission_count_; + } + + QuicTime::Delta peer_max_ack_delay() const { return peer_max_ack_delay_; } + + void set_peer_max_ack_delay(QuicTime::Delta peer_max_ack_delay) { + // The delayed ack time should never be more than one half the min RTO time. + QUICHE_DCHECK_LE( + peer_max_ack_delay, + (QuicTime::Delta::FromMilliseconds(kMinRetransmissionTimeMs) * 0.5)); + peer_max_ack_delay_ = peer_max_ack_delay; + } + + const QuicUnackedPacketMap& unacked_packets() const { + return unacked_packets_; + } + + const UberLossAlgorithm* uber_loss_algorithm() const { + return &uber_loss_algorithm_; + } + + // Sets the send algorithm to the given congestion control type and points the + // pacing sender at |send_algorithm_|. Can be called any number of times. + void SetSendAlgorithm(CongestionControlType congestion_control_type); + + // Sets the send algorithm to |send_algorithm| and points the pacing sender at + // |send_algorithm_|. Takes ownership of |send_algorithm|. Can be called any + // number of times. + // Setting the send algorithm once the connection is underway is dangerous. + void SetSendAlgorithm(SendAlgorithmInterface* send_algorithm); + + // Sends one probe packet. + void MaybeSendProbePacket(); + + // Called to disable HANDSHAKE_MODE, and only PTO and LOSS modes are used. + // Also enable IETF loss detection. + void EnableIetfPtoAndLossDetection(); + + // Called to retransmit in flight packet of |space| if any. + void RetransmitDataOfSpaceIfAny(PacketNumberSpace space); + + // Returns true if |timeout| is less than 3 * RTO/PTO delay. + bool IsLessThanThreePTOs(QuicTime::Delta timeout) const; + + // Returns current PTO delay. + QuicTime::Delta GetPtoDelay() const; + + bool supports_multiple_packet_number_spaces() const { + return unacked_packets_.supports_multiple_packet_number_spaces(); + } + + bool handshake_mode_disabled() const { return handshake_mode_disabled_; } + + bool zero_rtt_packet_acked() const { return zero_rtt_packet_acked_; } + + bool one_rtt_packet_acked() const { return one_rtt_packet_acked_; } + + void OnUserAgentIdKnown() { loss_algorithm_->OnUserAgentIdKnown(); } + + // Gets the earliest in flight packet sent time to calculate PTO. Also + // updates |packet_number_space| if a PTO timer should be armed. + QuicTime GetEarliestPacketSentTimeForPto( + PacketNumberSpace* packet_number_space) const; + + void set_num_ptos_for_path_degrading(int num_ptos_for_path_degrading) { + num_ptos_for_path_degrading_ = num_ptos_for_path_degrading; + } + + // Sets the initial RTT of the connection. The inital RTT is clamped to + // - A maximum of kMaxInitialRoundTripTimeUs. + // - A minimum of kMinTrustedInitialRoundTripTimeUs if |trusted|, or + // kMinUntrustedInitialRoundTripTimeUs if not |trusted|. + void SetInitialRtt(QuicTime::Delta rtt, bool trusted); + + private: + friend class test::QuicConnectionPeer; + friend class test::QuicSentPacketManagerPeer; + + // Returns the current retransmission mode. + RetransmissionTimeoutMode GetRetransmissionMode() const; + + // Retransmits all crypto stream packets. + void RetransmitCryptoPackets(); + + // Returns the timeout for retransmitting crypto handshake packets. + const QuicTime::Delta GetCryptoRetransmissionDelay() const; + + // Returns the probe timeout. + const QuicTime::Delta GetProbeTimeoutDelay(PacketNumberSpace space) const; + + // Update the RTT if the ack is for the largest acked packet number. + // Returns true if the rtt was updated. + bool MaybeUpdateRTT(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time, + QuicTime ack_receive_time); + + // Invokes the loss detection algorithm and loses and retransmits packets if + // necessary. + void InvokeLossDetection(QuicTime time); + + // Invokes OnCongestionEvent if |rtt_updated| is true, there are pending acks, + // or pending losses. Clears pending acks and pending losses afterwards. + // |prior_in_flight| is the number of bytes in flight before the losses or + // acks, |event_time| is normally the timestamp of the ack packet which caused + // the event, although it can be the time at which loss detection was + // triggered. + void MaybeInvokeCongestionEvent(bool rtt_updated, + QuicByteCount prior_in_flight, + QuicTime event_time); + + // Removes the retransmittability and in flight properties from the packet at + // |info| due to receipt by the peer. + void MarkPacketHandled(QuicPacketNumber packet_number, + QuicTransmissionInfo* info, QuicTime ack_receive_time, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp); + + // Request that |packet_number| be retransmitted after the other pending + // retransmissions. Does not add it to the retransmissions if it's already + // a pending retransmission. Do not reuse iterator of the underlying + // unacked_packets_ after calling this function as it can be invalidated. + void MarkForRetransmission(QuicPacketNumber packet_number, + TransmissionType transmission_type); + + // Called after packets have been marked handled with last received ack frame. + void PostProcessNewlyAckedPackets(QuicPacketNumber ack_packet_number, + EncryptionLevel ack_decrypted_level, + const QuicAckFrame& ack_frame, + QuicTime ack_receive_time, bool rtt_updated, + QuicByteCount prior_bytes_in_flight); + + // Notify observers that packet with QuicTransmissionInfo |info| is a spurious + // retransmission. It is caller's responsibility to guarantee the packet with + // QuicTransmissionInfo |info| is a spurious retransmission before calling + // this function. + void RecordOneSpuriousRetransmission(const QuicTransmissionInfo& info); + + // Called when handshake is confirmed to remove the retransmittable frames + // from all packets of HANDSHAKE_DATA packet number space to ensure they don't + // get retransmitted and will eventually be removed from unacked packets map. + void NeuterHandshakePackets(); + + // Indicates whether including peer_max_ack_delay_ when calculating PTO + // timeout. + bool ShouldAddMaxAckDelay(PacketNumberSpace space) const; + + // A helper function to return total delay of |num_timeouts| retransmission + // timeout with TLP and RTO mode. + // TODO(fayang): remove this method and calculate blackhole delay by PTO. + QuicTime::Delta GetNConsecutiveRetransmissionTimeoutDelay( + int num_timeouts) const; + + // Returns true if peer has finished address validation, such that + // retransmission timer is not armed if there is no packets in flight. + bool PeerCompletedAddressValidation() const; + + // Called when an AckFrequencyFrame is sent. + void OnAckFrequencyFrameSent( + const QuicAckFrequencyFrame& ack_frequency_frame); + + // Called when an AckFrequencyFrame is acked. + void OnAckFrequencyFrameAcked( + const QuicAckFrequencyFrame& ack_frequency_frame); + + // Newly serialized retransmittable packets are added to this map, which + // contains owning pointers to any contained frames. If a packet is + // retransmitted, this map will contain entries for both the old and the new + // packet. The old packet's retransmittable frames entry will be nullptr, + // while the new packet's entry will contain the frames to retransmit. + // If the old packet is acked before the new packet, then the old entry will + // be removed from the map and the new entry's retransmittable frames will be + // set to nullptr. + QuicUnackedPacketMap unacked_packets_; + + const QuicClock* clock_; + QuicRandom* random_; + QuicConnectionStats* stats_; + + DebugDelegate* debug_delegate_; + NetworkChangeVisitor* network_change_visitor_; + QuicPacketCount initial_congestion_window_; + RttStats rtt_stats_; + std::unique_ptr send_algorithm_; + // Not owned. Always points to |uber_loss_algorithm_| outside of tests. + LossDetectionInterface* loss_algorithm_; + UberLossAlgorithm uber_loss_algorithm_; + + // Number of times the crypto handshake has been retransmitted. + size_t consecutive_crypto_retransmission_count_; + // Number of pending transmissions of PTO or crypto packets. + size_t pending_timer_transmission_count_; + + bool using_pacing_; + // If true, use a more conservative handshake retransmission policy. + bool conservative_handshake_retransmits_; + + // Vectors packets acked and lost as a result of the last congestion event. + AckedPacketVector packets_acked_; + LostPacketVector packets_lost_; + // Largest newly acknowledged packet. + QuicPacketNumber largest_newly_acked_; + // Largest packet in bytes ever acknowledged. + QuicPacketLength largest_mtu_acked_; + + // Replaces certain calls to |send_algorithm_| when |using_pacing_| is true. + // Calls into |send_algorithm_| for the underlying congestion control. + PacingSender pacing_sender_; + + // Indicates whether handshake is finished. This is purely used to determine + // retransmission mode. DONOT use this to infer handshake state. + bool handshake_finished_; + + // Records bandwidth from server to client in normal operation, over periods + // of time with no loss events. + QuicSustainedBandwidthRecorder sustained_bandwidth_recorder_; + + // The largest acked value that was sent in an ack, which has then been acked. + QuicPacketNumber largest_packet_peer_knows_is_acked_; + // The largest acked value that was sent in an ack, which has then been acked + // for per packet number space. Only used when connection supports multiple + // packet number spaces. + QuicPacketNumber + largest_packets_peer_knows_is_acked_[NUM_PACKET_NUMBER_SPACES]; + + // The maximum ACK delay time that the peer might uses. Initialized to be the + // same as local_max_ack_delay_, may be changed via transport parameter + // negotiation or subsequently by AckFrequencyFrame. + QuicTime::Delta peer_max_ack_delay_; + + // Peer sends min_ack_delay in TransportParameter to advertise its support for + // AckFrequencyFrame. + QuicTime::Delta peer_min_ack_delay_ = QuicTime::Delta::Infinite(); + + // Use smoothed RTT for computing max_ack_delay in AckFrequency frame. + bool use_smoothed_rtt_in_ack_delay_ = false; + + // The history of outstanding max_ack_delays sent to peer. Outstanding means + // a max_ack_delay is sent as part of the last acked AckFrequencyFrame or + // an unacked AckFrequencyFrame after that. + quiche::QuicheCircularDeque< + std::pair> + in_use_sent_ack_delays_; + + // Latest received ack frame. + QuicAckFrame last_ack_frame_; + + // Record whether RTT gets updated by last largest acked.. + bool rtt_updated_; + + // A reverse iterator of last_ack_frame_.packets. This is reset in + // OnAckRangeStart, and gradually moves in OnAckRange.. + PacketNumberQueue::const_reverse_iterator acked_packets_iter_; + + // Number of times the PTO timer has fired in a row without receiving an ack. + size_t consecutive_pto_count_; + + // True if HANDSHAKE mode has been disabled. + bool handshake_mode_disabled_; + + // True if any ENCRYPTION_HANDSHAKE packet gets acknowledged. + bool handshake_packet_acked_; + + // True if any 0-RTT packet gets acknowledged. + bool zero_rtt_packet_acked_; + + // True if any 1-RTT packet gets acknowledged. + bool one_rtt_packet_acked_; + + // The number of PTOs needed for path degrading alarm. If equals to 0, the + // traditional path degrading mechanism will be used. + int num_ptos_for_path_degrading_; + + // If true, do not use PING only packets for RTT measurement or congestion + // control. + bool ignore_pings_; + + // Whether to ignore the ack_delay in received ACKs. + bool ignore_ack_delay_; + + // Most recent ECN codepoint counts received in an ACK frame sent by the peer. + QuicEcnCounts peer_ack_ecn_counts_[NUM_PACKET_NUMBER_SPACES]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_SENT_PACKET_MANAGER_H_ diff --git a/quiche/quic/core/quic_sent_packet_manager_test.cc b/quiche/quic/core/quic_sent_packet_manager_test.cc new file mode 100644 index 000000000000..86a1d404a82c --- /dev/null +++ b/quiche/quic/core/quic_sent_packet_manager_test.cc @@ -0,0 +1,3216 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_sent_packet_manager.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +using testing::_; +using testing::AnyNumber; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::IsEmpty; +using testing::Not; +using testing::Pointwise; +using testing::Return; +using testing::StrictMock; +using testing::WithArgs; + +namespace quic { +namespace test { +namespace { +// Default packet length. +const uint32_t kDefaultLength = 1000; + +// Stream ID for data sent in CreatePacket(). +const QuicStreamId kStreamId = 7; + +// The compiler won't allow absl::nullopt as an argument. +const absl::optional kEmptyCounts = absl::nullopt; + +// Matcher to check that the packet number matches the second argument. +MATCHER(PacketNumberEq, "") { + return std::get<0>(arg).packet_number == QuicPacketNumber(std::get<1>(arg)); +} + +class MockDebugDelegate : public QuicSentPacketManager::DebugDelegate { + public: + MOCK_METHOD(void, OnSpuriousPacketRetransmission, + (TransmissionType transmission_type, QuicByteCount byte_size), + (override)); + MOCK_METHOD(void, OnPacketLoss, + (QuicPacketNumber lost_packet_number, + EncryptionLevel encryption_level, + TransmissionType transmission_type, QuicTime detection_time), + (override)); +}; + +class QuicSentPacketManagerTest : public QuicTest { + public: + bool RetransmitCryptoPacket(uint64_t packet_number) { + EXPECT_CALL( + *send_algorithm_, + OnPacketSent(_, BytesInFlight(), QuicPacketNumber(packet_number), + kDefaultLength, HAS_RETRANSMITTABLE_DATA)); + SerializedPacket packet(CreatePacket(packet_number, false)); + packet.retransmittable_frames.push_back( + QuicFrame(QuicStreamFrame(1, false, 0, absl::string_view()))); + packet.has_crypto_handshake = IS_HANDSHAKE; + manager_.OnPacketSent(&packet, clock_.Now(), HANDSHAKE_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + return true; + } + + bool RetransmitDataPacket(uint64_t packet_number, TransmissionType type, + EncryptionLevel level) { + EXPECT_CALL( + *send_algorithm_, + OnPacketSent(_, BytesInFlight(), QuicPacketNumber(packet_number), + kDefaultLength, HAS_RETRANSMITTABLE_DATA)); + SerializedPacket packet(CreatePacket(packet_number, true)); + packet.encryption_level = level; + manager_.OnPacketSent(&packet, clock_.Now(), type, HAS_RETRANSMITTABLE_DATA, + true, ECN_NOT_ECT); + return true; + } + + bool RetransmitDataPacket(uint64_t packet_number, TransmissionType type) { + return RetransmitDataPacket(packet_number, type, ENCRYPTION_INITIAL); + } + + protected: + const CongestionControlType kInitialCongestionControlType = kCubicBytes; + QuicSentPacketManagerTest() + : manager_(Perspective::IS_SERVER, &clock_, QuicRandom::GetInstance(), + &stats_, kInitialCongestionControlType), + send_algorithm_(new StrictMock), + network_change_visitor_(new StrictMock) { + QuicSentPacketManagerPeer::SetSendAlgorithm(&manager_, send_algorithm_); + // Advance the time 1s so the send times are never QuicTime::Zero. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1000)); + manager_.SetNetworkChangeVisitor(network_change_visitor_.get()); + manager_.SetSessionNotifier(¬ifier_); + + EXPECT_CALL(*send_algorithm_, GetCongestionControlType()) + .WillRepeatedly(Return(kInitialCongestionControlType)); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()) + .Times(AnyNumber()) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, InSlowStart()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnPacketNeutered(_)).Times(AnyNumber()); + EXPECT_CALL(*network_change_visitor_, OnPathMtuIncreased(1000)) + .Times(AnyNumber()); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(notifier_, HasUnackedCryptoData()) + .WillRepeatedly(Return(false)); + EXPECT_CALL(notifier_, OnStreamFrameRetransmitted(_)).Times(AnyNumber()); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).WillRepeatedly(Return(true)); + } + + ~QuicSentPacketManagerTest() override {} + + QuicByteCount BytesInFlight() { return manager_.GetBytesInFlight(); } + void VerifyUnackedPackets(uint64_t* packets, size_t num_packets) { + if (num_packets == 0) { + EXPECT_TRUE(manager_.unacked_packets().empty()); + EXPECT_EQ(0u, QuicSentPacketManagerPeer::GetNumRetransmittablePackets( + &manager_)); + return; + } + + EXPECT_FALSE(manager_.unacked_packets().empty()); + EXPECT_EQ(QuicPacketNumber(packets[0]), manager_.GetLeastUnacked()); + for (size_t i = 0; i < num_packets; ++i) { + EXPECT_TRUE( + manager_.unacked_packets().IsUnacked(QuicPacketNumber(packets[i]))) + << packets[i]; + } + } + + void VerifyRetransmittablePackets(uint64_t* packets, size_t num_packets) { + EXPECT_EQ( + num_packets, + QuicSentPacketManagerPeer::GetNumRetransmittablePackets(&manager_)); + for (size_t i = 0; i < num_packets; ++i) { + EXPECT_TRUE(QuicSentPacketManagerPeer::HasRetransmittableFrames( + &manager_, packets[i])) + << " packets[" << i << "]:" << packets[i]; + } + } + + void ExpectAck(uint64_t largest_observed) { + EXPECT_CALL( + *send_algorithm_, + // Ensure the AckedPacketVector argument contains largest_observed. + OnCongestionEvent(true, _, _, + Pointwise(PacketNumberEq(), {largest_observed}), + IsEmpty(), _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + } + + void ExpectUpdatedRtt(uint64_t /*largest_observed*/) { + EXPECT_CALL(*send_algorithm_, + OnCongestionEvent(true, _, _, IsEmpty(), IsEmpty(), _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + } + + void ExpectAckAndLoss(bool rtt_updated, uint64_t largest_observed, + uint64_t lost_packet) { + EXPECT_CALL( + *send_algorithm_, + OnCongestionEvent(rtt_updated, _, _, + Pointwise(PacketNumberEq(), {largest_observed}), + Pointwise(PacketNumberEq(), {lost_packet}), _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + } + + // |packets_acked| and |packets_lost| should be in packet number order. + void ExpectAcksAndLosses(bool rtt_updated, uint64_t* packets_acked, + size_t num_packets_acked, uint64_t* packets_lost, + size_t num_packets_lost) { + std::vector ack_vector; + for (size_t i = 0; i < num_packets_acked; ++i) { + ack_vector.push_back(QuicPacketNumber(packets_acked[i])); + } + std::vector lost_vector; + for (size_t i = 0; i < num_packets_lost; ++i) { + lost_vector.push_back(QuicPacketNumber(packets_lost[i])); + } + EXPECT_CALL(*send_algorithm_, + OnCongestionEvent( + rtt_updated, _, _, Pointwise(PacketNumberEq(), ack_vector), + Pointwise(PacketNumberEq(), lost_vector), _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()) + .Times(AnyNumber()); + } + + void RetransmitAndSendPacket(uint64_t old_packet_number, + uint64_t new_packet_number) { + RetransmitAndSendPacket(old_packet_number, new_packet_number, + PTO_RETRANSMISSION); + } + + void RetransmitAndSendPacket(uint64_t old_packet_number, + uint64_t new_packet_number, + TransmissionType transmission_type) { + bool is_lost = false; + if (transmission_type == HANDSHAKE_RETRANSMISSION || + transmission_type == PTO_RETRANSMISSION) { + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>( + Invoke([this, new_packet_number](TransmissionType type) { + return RetransmitDataPacket(new_packet_number, type); + }))); + } else { + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(1); + is_lost = true; + } + QuicSentPacketManagerPeer::MarkForRetransmission( + &manager_, old_packet_number, transmission_type); + if (!is_lost) { + return; + } + EXPECT_CALL( + *send_algorithm_, + OnPacketSent(_, BytesInFlight(), QuicPacketNumber(new_packet_number), + kDefaultLength, HAS_RETRANSMITTABLE_DATA)); + SerializedPacket packet(CreatePacket(new_packet_number, true)); + manager_.OnPacketSent(&packet, clock_.Now(), transmission_type, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + } + + SerializedPacket CreateDataPacket(uint64_t packet_number) { + return CreatePacket(packet_number, true); + } + + SerializedPacket CreatePacket(uint64_t packet_number, bool retransmittable) { + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_4BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + false, false); + if (retransmittable) { + packet.retransmittable_frames.push_back( + QuicFrame(QuicStreamFrame(kStreamId, false, 0, absl::string_view()))); + } + return packet; + } + + SerializedPacket CreatePingPacket(uint64_t packet_number) { + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_4BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + false, false); + packet.retransmittable_frames.push_back(QuicFrame(QuicPingFrame())); + return packet; + } + + void SendDataPacket(uint64_t packet_number) { + SendDataPacket(packet_number, ENCRYPTION_INITIAL); + } + + void SendDataPacket(uint64_t packet_number, + EncryptionLevel encryption_level) { + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, BytesInFlight(), + QuicPacketNumber(packet_number), _, _)); + SerializedPacket packet(CreateDataPacket(packet_number)); + packet.encryption_level = encryption_level; + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + } + + void SendPingPacket(uint64_t packet_number, + EncryptionLevel encryption_level) { + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, BytesInFlight(), + QuicPacketNumber(packet_number), _, _)); + SerializedPacket packet(CreatePingPacket(packet_number)); + packet.encryption_level = encryption_level; + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + } + + void SendCryptoPacket(uint64_t packet_number) { + EXPECT_CALL( + *send_algorithm_, + OnPacketSent(_, BytesInFlight(), QuicPacketNumber(packet_number), + kDefaultLength, HAS_RETRANSMITTABLE_DATA)); + SerializedPacket packet(CreatePacket(packet_number, false)); + packet.retransmittable_frames.push_back( + QuicFrame(QuicStreamFrame(1, false, 0, absl::string_view()))); + packet.has_crypto_handshake = IS_HANDSHAKE; + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + EXPECT_CALL(notifier_, HasUnackedCryptoData()).WillRepeatedly(Return(true)); + } + + void SendAckPacket(uint64_t packet_number, uint64_t largest_acked) { + SendAckPacket(packet_number, largest_acked, ENCRYPTION_INITIAL); + } + + void SendAckPacket(uint64_t packet_number, uint64_t largest_acked, + EncryptionLevel level) { + EXPECT_CALL( + *send_algorithm_, + OnPacketSent(_, BytesInFlight(), QuicPacketNumber(packet_number), + kDefaultLength, NO_RETRANSMITTABLE_DATA)); + SerializedPacket packet(CreatePacket(packet_number, false)); + packet.largest_acked = QuicPacketNumber(largest_acked); + packet.encryption_level = level; + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + NO_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + } + + quiche::SimpleBufferAllocator allocator_; + QuicSentPacketManager manager_; + MockClock clock_; + QuicConnectionStats stats_; + MockSendAlgorithm* send_algorithm_; + std::unique_ptr network_change_visitor_; + StrictMock notifier_; +}; + +TEST_F(QuicSentPacketManagerTest, IsUnacked) { + VerifyUnackedPackets(nullptr, 0); + SendDataPacket(1); + + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + uint64_t retransmittable[] = {1}; + VerifyRetransmittablePackets(retransmittable, + ABSL_ARRAYSIZE(retransmittable)); +} + +TEST_F(QuicSentPacketManagerTest, IsUnAckedRetransmit) { + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + + EXPECT_TRUE(QuicSentPacketManagerPeer::IsRetransmission(&manager_, 2)); + uint64_t unacked[] = {1, 2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + std::vector retransmittable = {1, 2}; + VerifyRetransmittablePackets(&retransmittable[0], retransmittable.size()); +} + +TEST_F(QuicSentPacketManagerTest, RetransmitThenAck) { + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + + // Ack 2 but not 1. + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + // Packet 1 is unacked, pending, but not retransmittable. + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + EXPECT_TRUE(manager_.HasInFlightPackets()); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_F(QuicSentPacketManagerTest, RetransmitThenAckBeforeSend) { + SendDataPacket(1); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(2, type); + }))); + QuicSentPacketManagerPeer::MarkForRetransmission(&manager_, 1, + PTO_RETRANSMISSION); + // Ack 1. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + uint64_t unacked[] = {2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + // We do not know packet 2 is a spurious retransmission until it gets acked. + VerifyRetransmittablePackets(nullptr, 0); + EXPECT_EQ(0u, stats_.packets_spuriously_retransmitted); +} + +TEST_F(QuicSentPacketManagerTest, RetransmitThenStopRetransmittingBeforeSend) { + SendDataPacket(1); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)).WillRepeatedly(Return(true)); + QuicSentPacketManagerPeer::MarkForRetransmission(&manager_, 1, + PTO_RETRANSMISSION); + + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); + EXPECT_EQ(0u, stats_.packets_spuriously_retransmitted); +} + +TEST_F(QuicSentPacketManagerTest, RetransmitThenAckPrevious) { + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + QuicTime::Delta rtt = QuicTime::Delta::FromMilliseconds(15); + clock_.AdvanceTime(rtt); + + // Ack 1 but not 2. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + // 2 remains unacked, but no packets have retransmittable data. + uint64_t unacked[] = {2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + EXPECT_TRUE(manager_.HasInFlightPackets()); + VerifyRetransmittablePackets(nullptr, 0); + // Ack 2 causes 2 be considered as spurious retransmission. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).WillOnce(Return(false)); + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + + EXPECT_EQ(1u, stats_.packets_spuriously_retransmitted); +} + +TEST_F(QuicSentPacketManagerTest, RetransmitThenAckPreviousThenNackRetransmit) { + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + QuicTime::Delta rtt = QuicTime::Delta::FromMilliseconds(15); + clock_.AdvanceTime(rtt); + + // First, ACK packet 1 which makes packet 2 non-retransmittable. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + SendDataPacket(3); + SendDataPacket(4); + SendDataPacket(5); + clock_.AdvanceTime(rtt); + + // Next, NACK packet 2 three times. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(1); + ExpectAckAndLoss(true, 3, 2); + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + + ExpectAck(4); + manager_.OnAckFrameStart(QuicPacketNumber(4), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(5)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_INITIAL, kEmptyCounts)); + + ExpectAck(5); + manager_.OnAckFrameStart(QuicPacketNumber(5), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(6)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(4), + ENCRYPTION_INITIAL, kEmptyCounts)); + + uint64_t unacked[] = {2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + EXPECT_FALSE(manager_.HasInFlightPackets()); + VerifyRetransmittablePackets(nullptr, 0); + + // Verify that the retransmission alarm would not fire, + // since there is no retransmittable data outstanding. + EXPECT_EQ(QuicTime::Zero(), manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, + DISABLED_RetransmitTwiceThenAckPreviousBeforeSend) { + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + + // Fire the RTO, which will mark 2 for retransmission (but will not send it). + EXPECT_CALL(*send_algorithm_, OnRetransmissionTimeout(true)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.OnRetransmissionTimeout(); + + // Ack 1 but not 2, before 2 is able to be sent. + // Since 1 has been retransmitted, it has already been lost, and so the + // send algorithm is not informed that it has been ACK'd. + ExpectUpdatedRtt(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + // Since 2 was marked for retransmit, when 1 is acked, 2 is kept for RTT. + uint64_t unacked[] = {2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + EXPECT_FALSE(manager_.HasInFlightPackets()); + VerifyRetransmittablePackets(nullptr, 0); + + // Verify that the retransmission alarm would not fire, + // since there is no retransmittable data outstanding. + EXPECT_EQ(QuicTime::Zero(), manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, RetransmitTwiceThenAckFirst) { + StrictMock debug_delegate; + EXPECT_CALL(debug_delegate, OnSpuriousPacketRetransmission(PTO_RETRANSMISSION, + kDefaultLength)) + .Times(1); + manager_.SetDebugDelegate(&debug_delegate); + + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + RetransmitAndSendPacket(2, 3); + QuicTime::Delta rtt = QuicTime::Delta::FromMilliseconds(15); + clock_.AdvanceTime(rtt); + + // Ack 1 but not 2 or 3. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + // Frames in packets 2 and 3 are acked. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)) + .Times(2) + .WillRepeatedly(Return(false)); + + // 2 and 3 remain unacked, but no packets have retransmittable data. + uint64_t unacked[] = {2, 3}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + EXPECT_TRUE(manager_.HasInFlightPackets()); + VerifyRetransmittablePackets(nullptr, 0); + + // Ensure packet 2 is lost when 4 is sent and 3 and 4 are acked. + SendDataPacket(4); + // No new data gets acked in packet 3. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)) + .WillOnce(Return(false)) + .WillRepeatedly(Return(true)); + uint64_t acked[] = {3, 4}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(4), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(5)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + + uint64_t unacked2[] = {2}; + VerifyUnackedPackets(unacked2, ABSL_ARRAYSIZE(unacked2)); + EXPECT_TRUE(manager_.HasInFlightPackets()); + + SendDataPacket(5); + ExpectAckAndLoss(true, 5, 2); + EXPECT_CALL(debug_delegate, + OnPacketLoss(QuicPacketNumber(2), _, LOSS_RETRANSMISSION, _)); + // Frames in all packets are acked. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + // Notify session that stream frame in packet 2 gets lost although it is + // not outstanding. + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(1); + manager_.OnAckFrameStart(QuicPacketNumber(5), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(6)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_INITIAL, kEmptyCounts)); + + uint64_t unacked3[] = {2}; + VerifyUnackedPackets(unacked3, ABSL_ARRAYSIZE(unacked3)); + EXPECT_FALSE(manager_.HasInFlightPackets()); + // Spurious retransmission is detected when packet 3 gets acked. We cannot + // know packet 2 is a spurious until it gets acked. + EXPECT_EQ(1u, stats_.packets_spuriously_retransmitted); + EXPECT_EQ(1u, stats_.packets_lost); + EXPECT_LT(0.0, stats_.total_loss_detection_response_time); + EXPECT_LE(1u, stats_.sent_packets_max_sequence_reordering); +} + +TEST_F(QuicSentPacketManagerTest, AckOriginalTransmission) { + auto loss_algorithm = std::make_unique(); + QuicSentPacketManagerPeer::SetLossAlgorithm(&manager_, loss_algorithm.get()); + + SendDataPacket(1); + RetransmitAndSendPacket(1, 2); + + // Ack original transmission, but that wasn't lost via fast retransmit, + // so no call on OnSpuriousRetransmission is expected. + { + ExpectAck(1); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + } + + SendDataPacket(3); + SendDataPacket(4); + // Ack 4, which causes 3 to be retransmitted. + { + ExpectAck(4); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + manager_.OnAckFrameStart(QuicPacketNumber(4), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(5)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + RetransmitAndSendPacket(3, 5, LOSS_RETRANSMISSION); + } + + // Ack 3, which causes SpuriousRetransmitDetected to be called. + { + uint64_t acked[] = {3}; + ExpectAcksAndLosses(false, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*loss_algorithm, + SpuriousLossDetected(_, _, _, QuicPacketNumber(3), + QuicPacketNumber(4))); + manager_.OnAckFrameStart(QuicPacketNumber(4), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(5)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(0u, stats_.packet_spuriously_detected_lost); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(1u, stats_.packet_spuriously_detected_lost); + // Ack 3 will not cause 5 be considered as a spurious retransmission. Ack + // 5 will cause 5 be considered as a spurious retransmission as no new + // data gets acked. + ExpectAck(5); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).WillOnce(Return(false)); + manager_.OnAckFrameStart(QuicPacketNumber(5), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(6)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(4), + ENCRYPTION_INITIAL, kEmptyCounts)); + } +} + +TEST_F(QuicSentPacketManagerTest, GetLeastUnacked) { + EXPECT_EQ(QuicPacketNumber(1u), manager_.GetLeastUnacked()); +} + +TEST_F(QuicSentPacketManagerTest, GetLeastUnackedUnacked) { + SendDataPacket(1); + EXPECT_EQ(QuicPacketNumber(1u), manager_.GetLeastUnacked()); +} + +TEST_F(QuicSentPacketManagerTest, AckAckAndUpdateRtt) { + EXPECT_FALSE(manager_.largest_packet_peer_knows_is_acked().IsInitialized()); + SendDataPacket(1); + SendAckPacket(2, 1); + + // Now ack the ack and expect an RTT update. + uint64_t acked[] = {1, 2}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(2), + QuicTime::Delta::FromMilliseconds(5), clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(1), manager_.largest_packet_peer_knows_is_acked()); + + SendAckPacket(3, 3); + + // Now ack the ack and expect only an RTT update. + uint64_t acked2[] = {3}; + ExpectAcksAndLosses(true, acked2, ABSL_ARRAYSIZE(acked2), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(3u), + manager_.largest_packet_peer_knows_is_acked()); +} + +TEST_F(QuicSentPacketManagerTest, Rtt) { + QuicTime::Delta expected_rtt = QuicTime::Delta::FromMilliseconds(20); + SendDataPacket(1); + clock_.AdvanceTime(expected_rtt); + + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(expected_rtt, manager_.GetRttStats()->latest_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, RttWithInvalidDelta) { + // Expect that the RTT is equal to the local time elapsed, since the + // ack_delay_time is larger than the local time elapsed + // and is hence invalid. + QuicTime::Delta expected_rtt = QuicTime::Delta::FromMilliseconds(10); + SendDataPacket(1); + clock_.AdvanceTime(expected_rtt); + + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), + QuicTime::Delta::FromMilliseconds(11), clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(expected_rtt, manager_.GetRttStats()->latest_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, RttWithInfiniteDelta) { + // Expect that the RTT is equal to the local time elapsed, since the + // ack_delay_time is infinite, and is hence invalid. + QuicTime::Delta expected_rtt = QuicTime::Delta::FromMilliseconds(10); + SendDataPacket(1); + clock_.AdvanceTime(expected_rtt); + + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(expected_rtt, manager_.GetRttStats()->latest_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, RttWithDeltaExceedingLimit) { + // Initialize min and smoothed rtt to 10ms. + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(10), + QuicTime::Delta::Zero(), QuicTime::Zero()); + + QuicTime::Delta send_delta = QuicTime::Delta::FromMilliseconds(100); + QuicTime::Delta ack_delay = + QuicTime::Delta::FromMilliseconds(5) + manager_.peer_max_ack_delay(); + ASSERT_GT(send_delta - rtt_stats->min_rtt(), ack_delay); + SendDataPacket(1); + clock_.AdvanceTime(send_delta); + + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), ack_delay, clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + + QuicTime::Delta expected_rtt_sample = + send_delta - manager_.peer_max_ack_delay(); + EXPECT_EQ(expected_rtt_sample, manager_.GetRttStats()->latest_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, RttZeroDelta) { + // Expect that the RTT is the time between send and receive since the + // ack_delay_time is zero. + QuicTime::Delta expected_rtt = QuicTime::Delta::FromMilliseconds(10); + SendDataPacket(1); + clock_.AdvanceTime(expected_rtt); + + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Zero(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(expected_rtt, manager_.GetRttStats()->latest_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, CryptoHandshakeTimeout) { + // Send 2 crypto packets and 3 data packets. + const size_t kNumSentCryptoPackets = 2; + for (size_t i = 1; i <= kNumSentCryptoPackets; ++i) { + SendCryptoPacket(i); + } + const size_t kNumSentDataPackets = 3; + for (size_t i = 1; i <= kNumSentDataPackets; ++i) { + SendDataPacket(kNumSentCryptoPackets + i); + } + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + EXPECT_EQ(5 * kDefaultLength, manager_.GetBytesInFlight()); + + // The first retransmits 2 packets. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .Times(2) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(6); })) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(7); })); + manager_.OnRetransmissionTimeout(); + // Expect all 4 handshake packets to be in flight and 3 data packets. + EXPECT_EQ(7 * kDefaultLength, manager_.GetBytesInFlight()); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // The second retransmits 2 packets. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .Times(2) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(8); })) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(9); })); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(9 * kDefaultLength, manager_.GetBytesInFlight()); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // Now ack the two crypto packets and the speculatively encrypted request, + // and ensure the first four crypto packets get abandoned, but not lost. + // Crypto packets remain in flight, so any that aren't acked will be lost. + uint64_t acked[] = {3, 4, 5, 8, 9}; + uint64_t lost[] = {1, 2, 6}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), lost, + ABSL_ARRAYSIZE(lost)); + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(3); + EXPECT_CALL(notifier_, HasUnackedCryptoData()).WillRepeatedly(Return(false)); + manager_.OnAckFrameStart(QuicPacketNumber(9), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(8), QuicPacketNumber(10)); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(6)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + EXPECT_FALSE(manager_.HasUnackedCryptoPackets()); +} + +TEST_F(QuicSentPacketManagerTest, CryptoHandshakeSpuriousRetransmission) { + // Send 1 crypto packet. + SendCryptoPacket(1); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // Retransmit the crypto packet as 2. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(2); })); + manager_.OnRetransmissionTimeout(); + + // Retransmit the crypto packet as 3. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(3); })); + manager_.OnRetransmissionTimeout(); + + // Now ack the second crypto packet, and ensure the first gets removed, but + // the third does not. + uint64_t acked[] = {2}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + EXPECT_CALL(notifier_, HasUnackedCryptoData()).WillRepeatedly(Return(false)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + EXPECT_FALSE(manager_.HasUnackedCryptoPackets()); + uint64_t unacked[] = {1, 3}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); +} + +TEST_F(QuicSentPacketManagerTest, CryptoHandshakeTimeoutUnsentDataPacket) { + // Send 2 crypto packets and 1 data packet. + const size_t kNumSentCryptoPackets = 2; + for (size_t i = 1; i <= kNumSentCryptoPackets; ++i) { + SendCryptoPacket(i); + } + SendDataPacket(3); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // Retransmit 2 crypto packets, but not the serialized packet. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .Times(2) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(4); })) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(5); })); + manager_.OnRetransmissionTimeout(); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); +} + +TEST_F(QuicSentPacketManagerTest, + CryptoHandshakeRetransmissionThenNeuterAndAck) { + // Send 1 crypto packet. + SendCryptoPacket(1); + + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // Retransmit the crypto packet as 2. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(2); })); + manager_.OnRetransmissionTimeout(); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // Retransmit the crypto packet as 3. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(3); })); + manager_.OnRetransmissionTimeout(); + EXPECT_TRUE(manager_.HasUnackedCryptoPackets()); + + // Now neuter all unacked unencrypted packets, which occurs when the + // connection goes forward secure. + EXPECT_CALL(notifier_, HasUnackedCryptoData()).WillRepeatedly(Return(false)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.NeuterUnencryptedPackets(); + EXPECT_FALSE(manager_.HasUnackedCryptoPackets()); + uint64_t unacked[] = {1, 2, 3}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); + EXPECT_FALSE(manager_.HasUnackedCryptoPackets()); + EXPECT_FALSE(manager_.HasInFlightPackets()); + + // Ensure both packets get discarded when packet 2 is acked. + uint64_t acked[] = {3}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + VerifyUnackedPackets(nullptr, 0); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_F(QuicSentPacketManagerTest, GetTransmissionTime) { + EXPECT_EQ(QuicTime::Zero(), manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, GetTransmissionTimeCryptoHandshake) { + QuicTime crypto_packet_send_time = clock_.Now(); + SendCryptoPacket(1); + + // Check the min. + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->set_initial_rtt(QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromMilliseconds(10), + manager_.GetRetransmissionTime()); + + // Test with a standard smoothed RTT. + rtt_stats->set_initial_rtt(QuicTime::Delta::FromMilliseconds(100)); + + QuicTime::Delta srtt = rtt_stats->initial_rtt(); + QuicTime expected_time = clock_.Now() + 1.5 * srtt; + EXPECT_EQ(expected_time, manager_.GetRetransmissionTime()); + + // Retransmit the packet by invoking the retransmission timeout. + clock_.AdvanceTime(1.5 * srtt); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(2); })); + // When session decides what to write, crypto_packet_send_time gets updated. + crypto_packet_send_time = clock_.Now(); + manager_.OnRetransmissionTimeout(); + + // The retransmission time should now be twice as far in the future. + expected_time = crypto_packet_send_time + srtt * 2 * 1.5; + EXPECT_EQ(expected_time, manager_.GetRetransmissionTime()); + + // Retransmit the packet for the 2nd time. + clock_.AdvanceTime(2 * 1.5 * srtt); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(3); })); + // When session decides what to write, crypto_packet_send_time gets updated. + crypto_packet_send_time = clock_.Now(); + manager_.OnRetransmissionTimeout(); + + // Verify exponential backoff of the retransmission timeout. + expected_time = crypto_packet_send_time + srtt * 4 * 1.5; + EXPECT_EQ(expected_time, manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, + GetConservativeTransmissionTimeCryptoHandshake) { + QuicConfig config; + QuicTagVector options; + options.push_back(kCONH); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + // Calling SetFromConfig requires mocking out some send algorithm methods. + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + + QuicTime crypto_packet_send_time = clock_.Now(); + SendCryptoPacket(1); + + // Check the min. + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->set_initial_rtt(QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromMilliseconds(25), + manager_.GetRetransmissionTime()); + + // Test with a standard smoothed RTT. + rtt_stats->set_initial_rtt(QuicTime::Delta::FromMilliseconds(100)); + + QuicTime::Delta srtt = rtt_stats->initial_rtt(); + QuicTime expected_time = clock_.Now() + 2 * srtt; + EXPECT_EQ(expected_time, manager_.GetRetransmissionTime()); + + // Retransmit the packet by invoking the retransmission timeout. + clock_.AdvanceTime(2 * srtt); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + InvokeWithoutArgs([this]() { return RetransmitCryptoPacket(2); })); + crypto_packet_send_time = clock_.Now(); + manager_.OnRetransmissionTimeout(); + + // The retransmission time should now be twice as far in the future. + expected_time = crypto_packet_send_time + srtt * 2 * 2; + EXPECT_EQ(expected_time, manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, GetLossDelay) { + auto loss_algorithm = std::make_unique(); + QuicSentPacketManagerPeer::SetLossAlgorithm(&manager_, loss_algorithm.get()); + + EXPECT_CALL(*loss_algorithm, GetLossTimeout()) + .WillRepeatedly(Return(QuicTime::Zero())); + SendDataPacket(1); + SendDataPacket(2); + + // Handle an ack which causes the loss algorithm to be evaluated and + // set the loss timeout. + ExpectAck(2); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + QuicTime timeout(clock_.Now() + QuicTime::Delta::FromMilliseconds(10)); + EXPECT_CALL(*loss_algorithm, GetLossTimeout()) + .WillRepeatedly(Return(timeout)); + EXPECT_EQ(timeout, manager_.GetRetransmissionTime()); + + // Fire the retransmission timeout and ensure the loss detection algorithm + // is invoked. + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + manager_.OnRetransmissionTimeout(); +} + +TEST_F(QuicSentPacketManagerTest, NegotiateIetfLossDetectionFromOptions) { + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled(&manager_)); + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + + QuicConfig config; + QuicTagVector options; + options.push_back(kILD0); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_EQ(3, QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); +} + +TEST_F(QuicSentPacketManagerTest, + NegotiateIetfLossDetectionOneFourthRttFromOptions) { + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled(&manager_)); + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + + QuicConfig config; + QuicTagVector options; + options.push_back(kILD1); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); +} + +TEST_F(QuicSentPacketManagerTest, + NegotiateIetfLossDetectionAdaptiveReorderingThreshold) { + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled(&manager_)); + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + + QuicConfig config; + QuicTagVector options; + options.push_back(kILD2); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_EQ(3, QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); +} + +TEST_F(QuicSentPacketManagerTest, + NegotiateIetfLossDetectionAdaptiveReorderingThreshold2) { + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled(&manager_)); + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + + QuicConfig config; + QuicTagVector options; + options.push_back(kILD3); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); +} + +TEST_F(QuicSentPacketManagerTest, + NegotiateIetfLossDetectionAdaptiveReorderingAndTimeThreshold) { + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); + EXPECT_FALSE( + QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled(&manager_)); + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + + QuicConfig config; + QuicTagVector options; + options.push_back(kILD4); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_EQ(kDefaultLossDelayShift, + QuicSentPacketManagerPeer::GetReorderingShift(&manager_)); + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled(&manager_)); + EXPECT_TRUE( + QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled(&manager_)); +} + +TEST_F(QuicSentPacketManagerTest, NegotiateCongestionControlFromOptions) { + QuicConfig config; + QuicTagVector options; + + options.push_back(kRENO); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kRenoBytes, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); + + options.clear(); + options.push_back(kTBBR); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kBBR, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); + + options.clear(); + options.push_back(kBYTE); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kCubicBytes, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); + options.clear(); + options.push_back(kRENO); + options.push_back(kBYTE); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kRenoBytes, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); +} + +TEST_F(QuicSentPacketManagerTest, NegotiateClientCongestionControlFromOptions) { + QuicConfig config; + QuicTagVector options; + + // No change if the server receives client options. + const SendAlgorithmInterface* mock_sender = + QuicSentPacketManagerPeer::GetSendAlgorithm(manager_); + options.push_back(kRENO); + config.SetClientConnectionOptions(options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(mock_sender, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_)); + + // Change the congestion control on the client with client options. + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kRenoBytes, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); + + options.clear(); + options.push_back(kTBBR); + config.SetClientConnectionOptions(options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kBBR, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); + + options.clear(); + options.push_back(kBYTE); + config.SetClientConnectionOptions(options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kCubicBytes, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); + + options.clear(); + options.push_back(kRENO); + options.push_back(kBYTE); + config.SetClientConnectionOptions(options); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + EXPECT_EQ(kRenoBytes, QuicSentPacketManagerPeer::GetSendAlgorithm(manager_) + ->GetCongestionControlType()); +} + +TEST_F(QuicSentPacketManagerTest, UseInitialRoundTripTimeToSend) { + QuicTime::Delta initial_rtt = QuicTime::Delta::FromMilliseconds(325); + EXPECT_NE(initial_rtt, manager_.GetRttStats()->smoothed_rtt()); + + QuicConfig config; + config.SetInitialRoundTripTimeUsToSend(initial_rtt.ToMicroseconds()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_EQ(QuicTime::Delta::Zero(), manager_.GetRttStats()->smoothed_rtt()); + EXPECT_EQ(initial_rtt, manager_.GetRttStats()->initial_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, ResumeConnectionState) { + // The sent packet manager should use the RTT from CachedNetworkParameters if + // it is provided. + const QuicTime::Delta kRtt = QuicTime::Delta::FromMilliseconds(123); + CachedNetworkParameters cached_network_params; + cached_network_params.set_min_rtt_ms(kRtt.ToMilliseconds()); + + SendAlgorithmInterface::NetworkParams params; + params.bandwidth = QuicBandwidth::Zero(); + params.allow_cwnd_to_decrease = false; + params.rtt = kRtt; + params.is_rtt_trusted = true; + + EXPECT_CALL(*send_algorithm_, AdjustNetworkParameters(params)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .Times(testing::AnyNumber()); + manager_.ResumeConnectionState(cached_network_params, false); + EXPECT_EQ(kRtt, manager_.GetRttStats()->initial_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, ConnectionMigrationUnspecifiedChange) { + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + QuicTime::Delta default_init_rtt = rtt_stats->initial_rtt(); + rtt_stats->set_initial_rtt(default_init_rtt * 2); + EXPECT_EQ(2 * default_init_rtt, rtt_stats->initial_rtt()); + + QuicSentPacketManagerPeer::SetConsecutivePtoCount(&manager_, 1); + EXPECT_EQ(1u, manager_.GetConsecutivePtoCount()); + + EXPECT_CALL(*send_algorithm_, OnConnectionMigration()); + EXPECT_EQ(nullptr, + manager_.OnConnectionMigration(/*reset_send_algorithm=*/false)); + + EXPECT_EQ(default_init_rtt, rtt_stats->initial_rtt()); + EXPECT_EQ(0u, manager_.GetConsecutivePtoCount()); +} + +// Tests that ResetCongestionControlUponPeerAddressChange() resets send +// algorithm and RTT. And unACK'ed packets are handled correctly. +TEST_F(QuicSentPacketManagerTest, + ConnectionMigrationUnspecifiedChangeResetSendAlgorithm) { + auto loss_algorithm = std::make_unique(); + QuicSentPacketManagerPeer::SetLossAlgorithm(&manager_, loss_algorithm.get()); + + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + QuicTime::Delta default_init_rtt = rtt_stats->initial_rtt(); + rtt_stats->set_initial_rtt(default_init_rtt * 2); + EXPECT_EQ(2 * default_init_rtt, rtt_stats->initial_rtt()); + + QuicSentPacketManagerPeer::SetConsecutivePtoCount(&manager_, 1); + EXPECT_EQ(1u, manager_.GetConsecutivePtoCount()); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + + RttStats old_rtt_stats; + old_rtt_stats.CloneFrom(*manager_.GetRttStats()); + + // Packet1 will be mark for retransmission upon migration. + EXPECT_CALL(notifier_, OnFrameLost(_)); + std::unique_ptr old_send_algorithm = + manager_.OnConnectionMigration(/*reset_send_algorithm=*/true); + + EXPECT_NE(old_send_algorithm.get(), manager_.GetSendAlgorithm()); + EXPECT_EQ(old_send_algorithm->GetCongestionControlType(), + manager_.GetSendAlgorithm()->GetCongestionControlType()); + EXPECT_EQ(default_init_rtt, rtt_stats->initial_rtt()); + EXPECT_EQ(0u, manager_.GetConsecutivePtoCount()); + // Packets sent earlier shouldn't be regarded as in flight. + EXPECT_EQ(0u, BytesInFlight()); + + // Replace the new send algorithm with the mock object. + manager_.SetSendAlgorithm(old_send_algorithm.release()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + // Application retransmit the data as LOSS_RETRANSMISSION. + RetransmitDataPacket(2, LOSS_RETRANSMISSION, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kDefaultLength, BytesInFlight()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + // Receiving an ACK for packet1 20s later shouldn't update the RTT, and + // shouldn't be treated as spurious retransmission. + EXPECT_CALL( + *send_algorithm_, + OnCongestionEvent(/*rtt_updated=*/false, kDefaultLength, _, _, _, _, _)) + .WillOnce(testing::WithArg<3>( + Invoke([](const AckedPacketVector& acked_packets) { + EXPECT_EQ(1u, acked_packets.size()); + EXPECT_EQ(QuicPacketNumber(1), acked_packets[0].packet_number); + // The bytes in packet1 shouldn't contribute to congestion control. + EXPECT_EQ(0u, acked_packets[0].bytes_acked); + }))); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*loss_algorithm, SpuriousLossDetected(_, _, _, _, _)).Times(0u); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_TRUE(manager_.GetRttStats()->latest_rtt().IsZero()); + + // Receiving an ACK for packet2 should update RTT and congestion control. + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL( + *send_algorithm_, + OnCongestionEvent(/*rtt_updated=*/true, kDefaultLength, _, _, _, _, _)) + .WillOnce(testing::WithArg<3>( + Invoke([](const AckedPacketVector& acked_packets) { + EXPECT_EQ(1u, acked_packets.size()); + EXPECT_EQ(QuicPacketNumber(2), acked_packets[0].packet_number); + // The bytes in packet2 should contribute to congestion control. + EXPECT_EQ(kDefaultLength, acked_packets[0].bytes_acked); + }))); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_EQ(0u, BytesInFlight()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), + manager_.GetRttStats()->latest_rtt()); + + SendDataPacket(3, ENCRYPTION_FORWARD_SECURE); + // Trigger loss timeout and mark packet3 for retransmission. + EXPECT_CALL(*loss_algorithm, GetLossTimeout()) + .WillOnce(Return(clock_.Now() + QuicTime::Delta::FromMilliseconds(10))); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)) + .WillOnce(WithArgs<5>(Invoke([](LostPacketVector* packet_lost) { + packet_lost->emplace_back(QuicPacketNumber(3u), kDefaultLength); + return LossDetectionInterface::DetectionStats(); + }))); + EXPECT_CALL(notifier_, OnFrameLost(_)); + EXPECT_CALL(*send_algorithm_, + OnCongestionEvent(false, kDefaultLength, _, _, _, _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(0u, BytesInFlight()); + + // Migrate again with unACK'ed but not in-flight packet. + // Packet3 shouldn't be marked for retransmission again as it is not in + // flight. + old_send_algorithm = + manager_.OnConnectionMigration(/*reset_send_algorithm=*/true); + + EXPECT_NE(old_send_algorithm.get(), manager_.GetSendAlgorithm()); + EXPECT_EQ(old_send_algorithm->GetCongestionControlType(), + manager_.GetSendAlgorithm()->GetCongestionControlType()); + EXPECT_EQ(default_init_rtt, rtt_stats->initial_rtt()); + EXPECT_EQ(0u, manager_.GetConsecutivePtoCount()); + EXPECT_EQ(0u, BytesInFlight()); + EXPECT_TRUE(manager_.GetRttStats()->latest_rtt().IsZero()); + + manager_.SetSendAlgorithm(old_send_algorithm.release()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(30)); + // Receiving an ACK for packet3 shouldn't update RTT. Though packet 3 was + // marked lost, this spurious retransmission shouldn't be reported to the loss + // algorithm. + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*loss_algorithm, SpuriousLossDetected(_, _, _, _, _)).Times(0u); + EXPECT_CALL(*send_algorithm_, + OnCongestionEvent(/*rtt_updated=*/false, 0, _, _, _, _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_EQ(0u, BytesInFlight()); + EXPECT_TRUE(manager_.GetRttStats()->latest_rtt().IsZero()); + + SendDataPacket(4, ENCRYPTION_FORWARD_SECURE); + // Trigger loss timeout and mark packet4 for retransmission. + EXPECT_CALL(*loss_algorithm, GetLossTimeout()) + .WillOnce(Return(clock_.Now() + QuicTime::Delta::FromMilliseconds(10))); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)) + .WillOnce(WithArgs<5>(Invoke([](LostPacketVector* packet_lost) { + packet_lost->emplace_back(QuicPacketNumber(4u), kDefaultLength); + return LossDetectionInterface::DetectionStats(); + }))); + EXPECT_CALL(notifier_, OnFrameLost(_)); + EXPECT_CALL(*send_algorithm_, + OnCongestionEvent(false, kDefaultLength, _, _, _, _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(0u, BytesInFlight()); + + // Application retransmit the data as LOSS_RETRANSMISSION. + RetransmitDataPacket(5, LOSS_RETRANSMISSION, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kDefaultLength, BytesInFlight()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(30)); + // Receiving an ACK for packet4 should update RTT, but not bytes in flight. + // This spurious retransmission should be reported to the loss algorithm. + manager_.OnAckFrameStart(QuicPacketNumber(4), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(5)); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*loss_algorithm, SpuriousLossDetected(_, _, _, _, _)); + EXPECT_CALL( + *send_algorithm_, + OnCongestionEvent(/*rtt_updated=*/true, kDefaultLength, _, _, _, _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_EQ(kDefaultLength, BytesInFlight()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(30), + manager_.GetRttStats()->latest_rtt()); + + // Migrate again with in-flight packet5 whose retransmittable frames are all + // ACKed. Packet5 should be marked for retransmission but nothing to + // retransmit. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillOnce(Return(false)); + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(0u); + old_send_algorithm = + manager_.OnConnectionMigration(/*reset_send_algorithm=*/true); + EXPECT_EQ(default_init_rtt, rtt_stats->initial_rtt()); + EXPECT_EQ(0u, manager_.GetConsecutivePtoCount()); + EXPECT_EQ(0u, BytesInFlight()); + EXPECT_TRUE(manager_.GetRttStats()->latest_rtt().IsZero()); + + manager_.SetSendAlgorithm(old_send_algorithm.release()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + // Receiving an ACK for packet5 shouldn't update RTT. Though packet 5 was + // marked for retransmission, this spurious retransmission shouldn't be + // reported to the loss algorithm. + manager_.OnAckFrameStart(QuicPacketNumber(5), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(5), QuicPacketNumber(6)); + EXPECT_CALL(*loss_algorithm, DetectLosses(_, _, _, _, _, _)); + EXPECT_CALL(*loss_algorithm, SpuriousLossDetected(_, _, _, _, _)).Times(0u); + EXPECT_CALL(*send_algorithm_, + OnCongestionEvent(/*rtt_updated=*/false, 0, _, _, _, _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_EQ(0u, BytesInFlight()); + EXPECT_TRUE(manager_.GetRttStats()->latest_rtt().IsZero()); +} + +TEST_F(QuicSentPacketManagerTest, PathMtuIncreased) { + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, BytesInFlight(), QuicPacketNumber(1), _, _)); + SerializedPacket packet(QuicPacketNumber(1), PACKET_4BYTE_PACKET_NUMBER, + nullptr, kDefaultLength + 100, false, false); + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); + + // Ack the large packet and expect the path MTU to increase. + ExpectAck(1); + EXPECT_CALL(*network_change_visitor_, + OnPathMtuIncreased(kDefaultLength + 100)); + QuicAckFrame ack_frame = InitAckFrame(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); +} + +TEST_F(QuicSentPacketManagerTest, OnAckRangeSlowPath) { + // Send packets 1 - 20. + for (size_t i = 1; i <= 20; ++i) { + SendDataPacket(i); + } + // Ack [5, 7), [10, 12), [15, 17). + uint64_t acked1[] = {5, 6, 10, 11, 15, 16}; + uint64_t lost1[] = {1, 2, 3, 4, 7, 8, 9, 12, 13}; + ExpectAcksAndLosses(true, acked1, ABSL_ARRAYSIZE(acked1), lost1, + ABSL_ARRAYSIZE(lost1)); + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(AnyNumber()); + manager_.OnAckFrameStart(QuicPacketNumber(16), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(15), QuicPacketNumber(17)); + manager_.OnAckRange(QuicPacketNumber(10), QuicPacketNumber(12)); + manager_.OnAckRange(QuicPacketNumber(5), QuicPacketNumber(7)); + // Make sure empty range does not harm. + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + // Ack [4, 8), [9, 13), [14, 21). + uint64_t acked2[] = {4, 7, 9, 12, 14, 17, 18, 19, 20}; + ExpectAcksAndLosses(true, acked2, ABSL_ARRAYSIZE(acked2), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(20), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(14), QuicPacketNumber(21)); + manager_.OnAckRange(QuicPacketNumber(9), QuicPacketNumber(13)); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(8)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); +} + +TEST_F(QuicSentPacketManagerTest, TolerateReneging) { + // Send packets 1 - 20. + for (size_t i = 1; i <= 20; ++i) { + SendDataPacket(i); + } + // Ack [5, 7), [10, 12), [15, 17). + uint64_t acked1[] = {5, 6, 10, 11, 15, 16}; + uint64_t lost1[] = {1, 2, 3, 4, 7, 8, 9, 12, 13}; + ExpectAcksAndLosses(true, acked1, ABSL_ARRAYSIZE(acked1), lost1, + ABSL_ARRAYSIZE(lost1)); + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(AnyNumber()); + manager_.OnAckFrameStart(QuicPacketNumber(16), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(15), QuicPacketNumber(17)); + manager_.OnAckRange(QuicPacketNumber(10), QuicPacketNumber(12)); + manager_.OnAckRange(QuicPacketNumber(5), QuicPacketNumber(7)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + // Making sure reneged ACK does not harm. Ack [4, 8), [9, 13). + uint64_t acked2[] = {4, 7, 9, 12}; + ExpectAcksAndLosses(true, acked2, ABSL_ARRAYSIZE(acked2), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(12), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(9), QuicPacketNumber(13)); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(8)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(16), manager_.GetLargestObserved()); +} + +TEST_F(QuicSentPacketManagerTest, MultiplePacketNumberSpaces) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + const QuicUnackedPacketMap* unacked_packets = + QuicSentPacketManagerPeer::GetUnackedPacketMap(&manager_); + EXPECT_FALSE( + unacked_packets + ->GetLargestSentRetransmittableOfPacketNumberSpace(INITIAL_DATA) + .IsInitialized()); + EXPECT_FALSE( + manager_.GetLargestAckedPacket(ENCRYPTION_INITIAL).IsInitialized()); + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_FALSE( + unacked_packets + ->GetLargestSentRetransmittableOfPacketNumberSpace(HANDSHAKE_DATA) + .IsInitialized()); + // Ack packet 1. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(1), + manager_.GetLargestAckedPacket(ENCRYPTION_INITIAL)); + EXPECT_FALSE( + manager_.GetLargestAckedPacket(ENCRYPTION_HANDSHAKE).IsInitialized()); + // Send packets 2 and 3. + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + SendDataPacket(3, ENCRYPTION_HANDSHAKE); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_EQ(QuicPacketNumber(3), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + HANDSHAKE_DATA)); + EXPECT_FALSE( + unacked_packets + ->GetLargestSentRetransmittableOfPacketNumberSpace(APPLICATION_DATA) + .IsInitialized()); + // Ack packet 2. + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_HANDSHAKE, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(2), + manager_.GetLargestAckedPacket(ENCRYPTION_HANDSHAKE)); + EXPECT_FALSE( + manager_.GetLargestAckedPacket(ENCRYPTION_ZERO_RTT).IsInitialized()); + // Ack packet 3. + ExpectAck(3); + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_HANDSHAKE, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(3), + manager_.GetLargestAckedPacket(ENCRYPTION_HANDSHAKE)); + EXPECT_FALSE( + manager_.GetLargestAckedPacket(ENCRYPTION_ZERO_RTT).IsInitialized()); + // Send packets 4 and 5. + SendDataPacket(4, ENCRYPTION_ZERO_RTT); + SendDataPacket(5, ENCRYPTION_ZERO_RTT); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_EQ(QuicPacketNumber(3), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + HANDSHAKE_DATA)); + EXPECT_EQ(QuicPacketNumber(5), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + APPLICATION_DATA)); + // Ack packet 5. + ExpectAck(5); + manager_.OnAckFrameStart(QuicPacketNumber(5), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(5), QuicPacketNumber(6)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(4), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(3), + manager_.GetLargestAckedPacket(ENCRYPTION_HANDSHAKE)); + EXPECT_EQ(QuicPacketNumber(5), + manager_.GetLargestAckedPacket(ENCRYPTION_ZERO_RTT)); + EXPECT_EQ(QuicPacketNumber(5), + manager_.GetLargestAckedPacket(ENCRYPTION_FORWARD_SECURE)); + + // Send packets 6 - 8. + SendDataPacket(6, ENCRYPTION_FORWARD_SECURE); + SendDataPacket(7, ENCRYPTION_FORWARD_SECURE); + SendDataPacket(8, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_EQ(QuicPacketNumber(3), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + HANDSHAKE_DATA)); + EXPECT_EQ(QuicPacketNumber(8), + unacked_packets->GetLargestSentRetransmittableOfPacketNumberSpace( + APPLICATION_DATA)); + // Ack all packets. + uint64_t acked[] = {4, 6, 7, 8}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(8), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(9)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(5), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + EXPECT_EQ(QuicPacketNumber(3), + manager_.GetLargestAckedPacket(ENCRYPTION_HANDSHAKE)); + EXPECT_EQ(QuicPacketNumber(8), + manager_.GetLargestAckedPacket(ENCRYPTION_ZERO_RTT)); + EXPECT_EQ(QuicPacketNumber(8), + manager_.GetLargestAckedPacket(ENCRYPTION_FORWARD_SECURE)); +} + +TEST_F(QuicSentPacketManagerTest, PacketsGetAckedInWrongPacketNumberSpace) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + // Send packets 2 and 3. + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + SendDataPacket(3, ENCRYPTION_HANDSHAKE); + + // ACK packets 2 and 3 in the wrong packet number space. + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_ACKED_IN_WRONG_PACKET_NUMBER_SPACE, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); +} + +TEST_F(QuicSentPacketManagerTest, PacketsGetAckedInWrongPacketNumberSpace2) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + // Send packets 2 and 3. + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + SendDataPacket(3, ENCRYPTION_HANDSHAKE); + + // ACK packet 1 in the wrong packet number space. + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_ACKED_IN_WRONG_PACKET_NUMBER_SPACE, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_HANDSHAKE, kEmptyCounts)); +} + +TEST_F(QuicSentPacketManagerTest, + ToleratePacketsGetAckedInWrongPacketNumberSpace) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + // Ack packet 1. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + + // Send packets 2 and 3. + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + SendDataPacket(3, ENCRYPTION_HANDSHAKE); + + // Packet 1 gets acked in the wrong packet number space. Since packet 1 has + // been acked in the correct packet number space, tolerate it. + uint64_t acked[] = {2, 3}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_HANDSHAKE, kEmptyCounts)); +} + +TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeout) { + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + QuicTime packet1_sent_time = clock_.Now(); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is set based on left edge. + QuicTime deadline = packet1_sent_time + expected_pto_delay; + EXPECT_EQ(deadline, manager_.GetRetransmissionTime()); + EXPECT_EQ(0u, stats_.pto_count); + + // Invoke PTO. + clock_.AdvanceTime(deadline - clock_.Now()); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(QuicTime::Delta::Zero(), manager_.TimeUntilSend(clock_.Now())); + EXPECT_EQ(1u, stats_.pto_count); + EXPECT_EQ(0u, stats_.max_consecutive_rto_with_forward_progress); + + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(3, type, ENCRYPTION_FORWARD_SECURE); + }))); + manager_.MaybeSendProbePacket(); + // Verify PTO period gets set to twice the current value. + QuicTime sent_time = clock_.Now(); + EXPECT_EQ(sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Received ACK for packets 1 and 2. + uint64_t acked[] = {1, 2}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + expected_pto_delay = + rtt_stats->SmoothedOrInitialRtt() + + std::max(kPtoRttvarMultiplier * rtt_stats->mean_deviation(), + QuicTime::Delta::FromMilliseconds(1)) + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + + // Verify PTO is correctly re-armed based on sent time of packet 4. + EXPECT_EQ(sent_time + expected_pto_delay, manager_.GetRetransmissionTime()); + EXPECT_EQ(1u, stats_.max_consecutive_rto_with_forward_progress); +} + +TEST_F(QuicSentPacketManagerTest, SendOneProbePacket) { + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + QuicTime packet1_sent_time = clock_.Now(); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_FORWARD_SECURE); + + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + // Verify PTO period is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + // Verify PTO is set based on left edge. + QuicTime deadline = packet1_sent_time + expected_pto_delay; + EXPECT_EQ(deadline, manager_.GetRetransmissionTime()); + + // Invoke PTO. + clock_.AdvanceTime(deadline - clock_.Now()); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(QuicTime::Delta::Zero(), manager_.TimeUntilSend(clock_.Now())); + + // Verify one probe packet gets sent. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(3, type, ENCRYPTION_FORWARD_SECURE); + }))); + manager_.MaybeSendProbePacket(); +} + +TEST_F(QuicSentPacketManagerTest, DisableHandshakeModeClient) { + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + manager_.EnableMultiplePacketNumberSpacesSupport(); + // Send CHLO. + SendCryptoPacket(1); + EXPECT_NE(QuicTime::Zero(), manager_.GetRetransmissionTime()); + // Ack packet 1. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(0u, manager_.GetBytesInFlight()); + // Verify retransmission timeout is not zero because handshake is not + // confirmed although there is no in flight packet. + EXPECT_NE(QuicTime::Zero(), manager_.GetRetransmissionTime()); + // Fire PTO. + EXPECT_EQ(QuicSentPacketManager::PTO_MODE, + manager_.OnRetransmissionTimeout()); + // Send handshake packet. + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + // Ack packet 2. + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_HANDSHAKE, kEmptyCounts)); + // Verify retransmission timeout is zero because server has successfully + // processed HANDSHAKE packet. + EXPECT_EQ(QuicTime::Zero(), manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, DisableHandshakeModeServer) { + manager_.EnableIetfPtoAndLossDetection(); + // Send SHLO. + SendCryptoPacket(1); + EXPECT_NE(QuicTime::Zero(), manager_.GetRetransmissionTime()); + // Ack packet 1. + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(0u, manager_.GetBytesInFlight()); + // Verify retransmission timeout is not set on server side because there is + // nothing in flight. + EXPECT_EQ(QuicTime::Zero(), manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, PtoTimeoutRttVarMultiple) { + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is correctly set based on 2 times rtt var. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, IW10ForUpAndDown) { + QuicConfig config; + QuicTagVector options; + options.push_back(kBWS5); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*send_algorithm_, SetInitialCongestionWindowInPackets(10)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_EQ(10u, manager_.initial_congestion_window()); +} + +TEST_F(QuicSentPacketManagerTest, ClientMultiplePacketNumberSpacePtoTimeout) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::Zero(); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard initial key and send packet 2 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.NeuterUnencryptedPackets(); + + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(true)); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + // Verify PTO is correctly set based on sent time of packet 2. + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + // Invoke PTO. + clock_.AdvanceTime(expected_pto_delay); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(QuicTime::Delta::Zero(), manager_.TimeUntilSend(clock_.Now())); + EXPECT_EQ(1u, stats_.pto_count); + EXPECT_EQ(1u, stats_.crypto_retransmit_count); + + // Verify probe packet gets sent. + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(3, type, ENCRYPTION_HANDSHAKE); + }))); + manager_.MaybeSendProbePacket(); + // Verify PTO period gets set to twice the current value. + const QuicTime packet3_sent_time = clock_.Now(); + EXPECT_EQ(packet3_sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Send packet 4 in application data with 0-RTT. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(4, ENCRYPTION_ZERO_RTT); + const QuicTime packet4_sent_time = clock_.Now(); + // Verify PTO timeout is still based on packet 3. + EXPECT_EQ(packet3_sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Send packet 5 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(5, ENCRYPTION_HANDSHAKE); + const QuicTime packet5_sent_time = clock_.Now(); + // Verify PTO timeout is now based on packet 5 because packet 4 should be + // ignored. + EXPECT_EQ(clock_.Now() + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Send packet 6 in 1-RTT. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(6, ENCRYPTION_FORWARD_SECURE); + // Verify PTO timeout is now based on packet 5. + EXPECT_EQ(packet5_sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Send packet 7 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + const QuicTime packet7_sent_time = clock_.Now(); + SendDataPacket(7, ENCRYPTION_HANDSHAKE); + + expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation(); + // Verify PTO timeout is now based on packet 7. + EXPECT_EQ(packet7_sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Neuter handshake key. + manager_.SetHandshakeConfirmed(); + // Forward progress has been made, verify PTO counter gets reset. PTO timeout + // is armed by left edge. + expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + EXPECT_EQ(packet4_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, ServerMultiplePacketNumberSpacePtoTimeout) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + const QuicTime packet1_sent_time = clock_.Now(); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::Zero(); + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 2 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + const QuicTime packet2_sent_time = clock_.Now(); + // Verify PTO timeout is still based on packet 1. + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard initial keys. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.NeuterUnencryptedPackets(); + + // Send packet 3 in 1-RTT. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(3, ENCRYPTION_FORWARD_SECURE); + // Verify PTO timeout is based on packet 2. + const QuicTime packet3_sent_time = clock_.Now(); + EXPECT_EQ(packet2_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 4 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(4, ENCRYPTION_HANDSHAKE); + // Verify PTO timeout is based on packet 4 as application data is ignored. + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard handshake keys. + manager_.SetHandshakeConfirmed(); + expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + // Verify PTO timeout is now based on packet 3 as handshake is + // complete/confirmed. + EXPECT_EQ(packet3_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeoutByLeftEdge) { + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + const QuicTime packet1_sent_time = clock_.Now(); + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is still based on packet 1. + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + EXPECT_EQ(0u, stats_.pto_count); + + // Invoke PTO. + clock_.AdvanceTime(expected_pto_delay); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(QuicTime::Delta::Zero(), manager_.TimeUntilSend(clock_.Now())); + EXPECT_EQ(1u, stats_.pto_count); + + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(3, type, ENCRYPTION_FORWARD_SECURE); + }))); + manager_.MaybeSendProbePacket(); + // Verify PTO period gets set to twice the current value and based on packet3. + QuicTime packet3_sent_time = clock_.Now(); + EXPECT_EQ(packet3_sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Received ACK for packets 1 and 2. + uint64_t acked[] = {1, 2}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + expected_pto_delay = + rtt_stats->SmoothedOrInitialRtt() + + std::max(kPtoRttvarMultiplier * rtt_stats->mean_deviation(), + QuicTime::Delta::FromMilliseconds(1)) + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + + // Verify PTO is correctly re-armed based on sent time of packet 4. + EXPECT_EQ(packet3_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, ComputingProbeTimeoutByLeftEdge2) { + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + SendDataPacket(1, ENCRYPTION_FORWARD_SECURE); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + const QuicTime packet1_sent_time = clock_.Now(); + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Sent a packet 10ms before PTO expiring. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds( + expected_pto_delay.ToMilliseconds() - 10)); + SendDataPacket(2, ENCRYPTION_FORWARD_SECURE); + // Verify PTO expands to packet 2 sent time + 1.5 * srtt. + expected_pto_delay = kFirstPtoSrttMultiplier * rtt_stats->smoothed_rtt(); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + EXPECT_EQ(0u, stats_.pto_count); + + // Invoke PTO. + clock_.AdvanceTime(expected_pto_delay); + manager_.OnRetransmissionTimeout(); + EXPECT_EQ(QuicTime::Delta::Zero(), manager_.TimeUntilSend(clock_.Now())); + EXPECT_EQ(1u, stats_.pto_count); + + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(3, type, ENCRYPTION_FORWARD_SECURE); + }))); + manager_.MaybeSendProbePacket(); + // Verify PTO period gets set to twice the expected value and based on + // packet3 (right edge). + expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + QuicTime packet3_sent_time = clock_.Now(); + EXPECT_EQ(packet3_sent_time + expected_pto_delay * 2, + manager_.GetRetransmissionTime()); + + // Received ACK for packets 1 and 2. + uint64_t acked[] = {1, 2}; + ExpectAcksAndLosses(true, acked, ABSL_ARRAYSIZE(acked), nullptr, 0); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); + expected_pto_delay = + rtt_stats->SmoothedOrInitialRtt() + + std::max(kPtoRttvarMultiplier * rtt_stats->mean_deviation(), + QuicTime::Delta::FromMilliseconds(1)) + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + + // Verify PTO is correctly re-armed based on sent time of packet 3 (left + // edge). + EXPECT_EQ(packet3_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, + ComputingProbeTimeoutByLeftEdgeMultiplePacketNumberSpaces) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + const QuicTime packet1_sent_time = clock_.Now(); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::Zero(); + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 2 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + const QuicTime packet2_sent_time = clock_.Now(); + // Verify PTO timeout is still based on packet 1. + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard initial keys. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.NeuterUnencryptedPackets(); + + // Send packet 3 in 1-RTT. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(3, ENCRYPTION_FORWARD_SECURE); + // Verify PTO timeout is based on packet 2. + const QuicTime packet3_sent_time = clock_.Now(); + EXPECT_EQ(packet2_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 4 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(4, ENCRYPTION_HANDSHAKE); + // Verify PTO timeout is based on packet 4 as application data is ignored. + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard handshake keys. + manager_.SetHandshakeConfirmed(); + // Verify PTO timeout is now based on packet 3 as handshake is + // complete/confirmed. + expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + EXPECT_EQ(packet3_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(5, ENCRYPTION_FORWARD_SECURE); + // Verify PTO timeout is still based on packet 3. + EXPECT_EQ(packet3_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, + ComputingProbeTimeoutByLeftEdge2MultiplePacketNumberSpaces) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + const QuicTime packet1_sent_time = clock_.Now(); + // Verify PTO is correctly set. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::Zero(); + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 2 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + const QuicTime packet2_sent_time = clock_.Now(); + // Verify PTO timeout is still based on packet 1. + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard initial keys. + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.NeuterUnencryptedPackets(); + + // Send packet 3 in 1-RTT. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(3, ENCRYPTION_FORWARD_SECURE); + // Verify PTO timeout is based on packet 2. + const QuicTime packet3_sent_time = clock_.Now(); + EXPECT_EQ(packet2_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 4 in handshake. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(4, ENCRYPTION_HANDSHAKE); + // Verify PTO timeout is based on packet 4 as application data is ignored. + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Discard handshake keys. + manager_.SetHandshakeConfirmed(); + // Verify PTO timeout is now based on packet 3 as handshake is + // complete/confirmed. + expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + EXPECT_EQ(packet3_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Send packet 5 10ms before PTO expiring. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds( + expected_pto_delay.ToMilliseconds() - 10)); + SendDataPacket(5, ENCRYPTION_FORWARD_SECURE); + // Verify PTO timeout expands to packet 5 sent time + 1.5 * srtt. + EXPECT_EQ(clock_.Now() + kFirstPtoSrttMultiplier * rtt_stats->smoothed_rtt(), + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, SetHandshakeConfirmed) { + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + manager_.EnableMultiplePacketNumberSpacesSupport(); + + SendDataPacket(1, ENCRYPTION_INITIAL); + + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)) + .WillOnce( + Invoke([](const QuicFrame& /*frame*/, QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) { + EXPECT_TRUE(ack_delay_time.IsZero()); + EXPECT_EQ(receive_timestamp, QuicTime::Zero()); + return true; + })); + + EXPECT_CALL(*send_algorithm_, OnPacketNeutered(QuicPacketNumber(2))).Times(1); + manager_.SetHandshakeConfirmed(); +} + +// Regresstion test for b/148841700. +TEST_F(QuicSentPacketManagerTest, NeuterUnencryptedPackets) { + SendCryptoPacket(1); + SendPingPacket(2, ENCRYPTION_INITIAL); + // Crypto data has been discarded but ping does not. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)) + .Times(2) + .WillOnce(Return(false)) + .WillOnce(Return(true)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + + EXPECT_CALL(*send_algorithm_, OnPacketNeutered(QuicPacketNumber(1))).Times(1); + manager_.NeuterUnencryptedPackets(); +} + +TEST_F(QuicSentPacketManagerTest, MarkInitialPacketsForRetransmission) { + SendCryptoPacket(1); + SendPingPacket(2, ENCRYPTION_HANDSHAKE); + // Only the INITIAL packet will be retransmitted. + EXPECT_CALL(notifier_, OnFrameLost(_)).Times(1); + manager_.MarkInitialPacketsForRetransmission(); +} + +TEST_F(QuicSentPacketManagerTest, NoPacketThresholdDetectionForRuntPackets) { + EXPECT_TRUE( + QuicSentPacketManagerPeer::UsePacketThresholdForRuntPackets(&manager_)); + + QuicConfig config; + QuicTagVector options; + options.push_back(kRUNT); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(config); + + EXPECT_FALSE( + QuicSentPacketManagerPeer::UsePacketThresholdForRuntPackets(&manager_)); +} + +TEST_F(QuicSentPacketManagerTest, GetPathDegradingDelayDefaultPTO) { + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + QuicTime::Delta expected_delay = 4 * manager_.GetPtoDelay(); + EXPECT_EQ(expected_delay, manager_.GetPathDegradingDelay()); +} + +TEST_F(QuicSentPacketManagerTest, ClientsIgnorePings) { + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + QuicConfig client_config; + QuicTagVector options; + QuicTagVector client_options; + client_options.push_back(kIGNP); + client_config.SetConnectionOptionsToSend(options); + client_config.SetClientConnectionOptions(client_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(client_config); + + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + + SendPingPacket(1, ENCRYPTION_INITIAL); + // Verify PING only packet is not considered in flight. + EXPECT_EQ(QuicTime::Zero(), manager_.GetRetransmissionTime()); + SendDataPacket(2, ENCRYPTION_INITIAL); + EXPECT_NE(QuicTime::Zero(), manager_.GetRetransmissionTime()); + + uint64_t acked[] = {1}; + ExpectAcksAndLosses(/*rtt_updated=*/false, acked, ABSL_ARRAYSIZE(acked), + nullptr, 0); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(90)); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + // Verify no RTT samples for PING only packet. + EXPECT_TRUE(rtt_stats->smoothed_rtt().IsZero()); + + ExpectAck(2); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(100), rtt_stats->smoothed_rtt()); +} + +// Regression test for b/154050235. +TEST_F(QuicSentPacketManagerTest, ExponentialBackoffWithNoRttMeasurement) { + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + manager_.EnableMultiplePacketNumberSpacesSupport(); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(kInitialRttMs), + rtt_stats->initial_rtt()); + EXPECT_TRUE(rtt_stats->smoothed_rtt().IsZero()); + + SendCryptoPacket(1); + QuicTime::Delta expected_pto_delay = + QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Invoke PTO. + clock_.AdvanceTime(expected_pto_delay); + manager_.OnRetransmissionTimeout(); + + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + WithArgs<1>(Invoke([this]() { return RetransmitCryptoPacket(3); }))); + manager_.MaybeSendProbePacket(); + // Verify exponential backoff of the PTO timeout. + EXPECT_EQ(clock_.Now() + 2 * expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, PtoDelayWithTinyInitialRtt) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + // Assume client provided a tiny initial RTT. + rtt_stats->set_initial_rtt(QuicTime::Delta::FromMicroseconds(1)); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(1), rtt_stats->initial_rtt()); + EXPECT_TRUE(rtt_stats->smoothed_rtt().IsZero()); + + SendCryptoPacket(1); + QuicTime::Delta expected_pto_delay = QuicTime::Delta::FromMilliseconds(10); + // Verify kMinHandshakeTimeoutMs is respected. + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Invoke PTO. + clock_.AdvanceTime(expected_pto_delay); + manager_.OnRetransmissionTimeout(); + + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce( + WithArgs<1>(Invoke([this]() { return RetransmitCryptoPacket(3); }))); + manager_.MaybeSendProbePacket(); + // Verify exponential backoff of the PTO timeout. + EXPECT_EQ(clock_.Now() + 2 * expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, HandshakeAckCausesInitialKeyDropping) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + // Send INITIAL packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + QuicTime::Delta expected_pto_delay = + QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs); + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + // Send HANDSHAKE ack. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendAckPacket(2, /*largest_acked=*/1, ENCRYPTION_HANDSHAKE); + // Sending HANDSHAKE packet causes dropping of INITIAL key. + EXPECT_CALL(notifier_, HasUnackedCryptoData()).WillRepeatedly(Return(false)); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + manager_.NeuterUnencryptedPackets(); + // There is no in flight packets. + EXPECT_FALSE(manager_.HasInFlightPackets()); + // Verify PTO timer gets rearmed from now because of anti-amplification. + EXPECT_EQ(clock_.Now() + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Invoke PTO. + clock_.AdvanceTime(expected_pto_delay); + manager_.OnRetransmissionTimeout(); + // Verify nothing to probe (and connection will send PING for current + // encryption level). + EXPECT_CALL(notifier_, RetransmitFrames(_, _)).Times(0); + manager_.MaybeSendProbePacket(); +} + +// Regression test for b/156487311 +TEST_F(QuicSentPacketManagerTest, ClearLastInflightPacketsSentTime) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + + // Send INITIAL 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + SendDataPacket(3, ENCRYPTION_HANDSHAKE); + SendDataPacket(4, ENCRYPTION_HANDSHAKE); + const QuicTime packet2_sent_time = clock_.Now(); + + // Send half RTT 5. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + SendDataPacket(5, ENCRYPTION_FORWARD_SECURE); + + // Received ACK for INITIAL 1. + ExpectAck(1); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(90)); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + const QuicTime::Delta pto_delay = + rtt_stats->smoothed_rtt() + + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::Zero(); + // Verify PTO is armed based on handshake data. + EXPECT_EQ(packet2_sent_time + pto_delay, manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, MaybeRetransmitInitialData) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + RttStats* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(100), + QuicTime::Delta::Zero(), QuicTime::Zero()); + QuicTime::Delta srtt = rtt_stats->smoothed_rtt(); + + // Send packet 1. + SendDataPacket(1, ENCRYPTION_INITIAL); + QuicTime packet1_sent_time = clock_.Now(); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + // Send packets 2 and 3. + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + QuicTime packet2_sent_time = clock_.Now(); + SendDataPacket(3, ENCRYPTION_HANDSHAKE); + // Verify PTO is correctly set based on packet 1. + QuicTime::Delta expected_pto_delay = + srtt + kPtoRttvarMultiplier * rtt_stats->mean_deviation() + + QuicTime::Delta::Zero(); + EXPECT_EQ(packet1_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Assume connection is going to send INITIAL ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(4, type, ENCRYPTION_INITIAL); + }))); + manager_.RetransmitDataOfSpaceIfAny(INITIAL_DATA); + // Verify PTO is re-armed based on packet 2. + EXPECT_EQ(packet2_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); + + // Connection is going to send another INITIAL ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + EXPECT_CALL(notifier_, RetransmitFrames(_, _)) + .WillOnce(WithArgs<1>(Invoke([this](TransmissionType type) { + return RetransmitDataPacket(5, type, ENCRYPTION_INITIAL); + }))); + manager_.RetransmitDataOfSpaceIfAny(INITIAL_DATA); + // Verify PTO does not change. + EXPECT_EQ(packet2_sent_time + expected_pto_delay, + manager_.GetRetransmissionTime()); +} + +TEST_F(QuicSentPacketManagerTest, SendPathChallengeAndGetAck) { + QuicPacketNumber packet_number(1); + EXPECT_CALL(*send_algorithm_, + OnPacketSent(_, BytesInFlight(), packet_number, _, _)); + SerializedPacket packet(packet_number, PACKET_4BYTE_PACKET_NUMBER, nullptr, + kDefaultLength, false, false); + QuicPathFrameBuffer path_frame_buffer{0, 1, 2, 3, 4, 5, 6, 7}; + packet.nonretransmittable_frames.push_back( + QuicFrame(QuicPathChallengeFrame(0, path_frame_buffer))); + packet.encryption_level = ENCRYPTION_FORWARD_SECURE; + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + NO_RETRANSMITTABLE_DATA, false, ECN_NOT_ECT); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + EXPECT_CALL( + *send_algorithm_, + OnCongestionEvent(/*rtt_updated=*/false, _, _, + Pointwise(PacketNumberEq(), {1}), IsEmpty(), _, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + + // Get ACK for the packet. + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts)); +} + +SerializedPacket MakePacketWithAckFrequencyFrame( + int packet_number, int ack_frequency_sequence_number, + QuicTime::Delta max_ack_delay) { + auto* ack_frequency_frame = new QuicAckFrequencyFrame(); + ack_frequency_frame->max_ack_delay = max_ack_delay; + ack_frequency_frame->sequence_number = ack_frequency_sequence_number; + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_4BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + /*has_ack=*/false, + /*has_stop_waiting=*/false); + packet.retransmittable_frames.push_back(QuicFrame(ack_frequency_frame)); + packet.has_ack_frequency = true; + packet.encryption_level = ENCRYPTION_FORWARD_SECURE; + return packet; +} + +TEST_F(QuicSentPacketManagerTest, + PeerMaxAckDelayUpdatedFromAckFrequencyFrameOneAtATime) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()) + .Times(AnyNumber()); + + auto initial_peer_max_ack_delay = manager_.peer_max_ack_delay(); + auto one_ms = QuicTime::Delta::FromMilliseconds(1); + auto plus_1_ms_delay = initial_peer_max_ack_delay + one_ms; + auto minus_1_ms_delay = initial_peer_max_ack_delay - one_ms; + + // Send and Ack frame1. + SerializedPacket packet1 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/1, /*ack_frequency_sequence_number=*/1, + plus_1_ms_delay); + // Higher on the fly max_ack_delay changes peer_max_ack_delay. + manager_.OnPacketSent(&packet1, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + EXPECT_EQ(manager_.peer_max_ack_delay(), plus_1_ms_delay); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), plus_1_ms_delay); + + // Send and Ack frame2. + SerializedPacket packet2 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/2, /*ack_frequency_sequence_number=*/2, + minus_1_ms_delay); + // Lower on the fly max_ack_delay does not change peer_max_ack_delay. + manager_.OnPacketSent(&packet2, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + EXPECT_EQ(manager_.peer_max_ack_delay(), plus_1_ms_delay); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), minus_1_ms_delay); +} + +TEST_F(QuicSentPacketManagerTest, + PeerMaxAckDelayUpdatedFromInOrderAckFrequencyFrames) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()) + .Times(AnyNumber()); + + auto initial_peer_max_ack_delay = manager_.peer_max_ack_delay(); + auto one_ms = QuicTime::Delta::FromMilliseconds(1); + auto extra_1_ms = initial_peer_max_ack_delay + one_ms; + auto extra_2_ms = initial_peer_max_ack_delay + 2 * one_ms; + auto extra_3_ms = initial_peer_max_ack_delay + 3 * one_ms; + SerializedPacket packet1 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/1, /*ack_frequency_sequence_number=*/1, extra_1_ms); + SerializedPacket packet2 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/2, /*ack_frequency_sequence_number=*/2, extra_3_ms); + SerializedPacket packet3 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/3, /*ack_frequency_sequence_number=*/3, extra_2_ms); + + // Send frame1, farme2, frame3. + manager_.OnPacketSent(&packet1, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_1_ms); + manager_.OnPacketSent(&packet2, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_3_ms); + manager_.OnPacketSent(&packet3, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_3_ms); + + // Ack frame1, farme2, frame3. + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_3_ms); + manager_.OnAckFrameStart(QuicPacketNumber(2), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(3)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_3_ms); + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(4)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_2_ms); +} + +TEST_F(QuicSentPacketManagerTest, + PeerMaxAckDelayUpdatedFromOutOfOrderAckedAckFrequencyFrames) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()) + .Times(AnyNumber()); + + auto initial_peer_max_ack_delay = manager_.peer_max_ack_delay(); + auto one_ms = QuicTime::Delta::FromMilliseconds(1); + auto extra_1_ms = initial_peer_max_ack_delay + one_ms; + auto extra_2_ms = initial_peer_max_ack_delay + 2 * one_ms; + auto extra_3_ms = initial_peer_max_ack_delay + 3 * one_ms; + auto extra_4_ms = initial_peer_max_ack_delay + 4 * one_ms; + SerializedPacket packet1 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/1, /*ack_frequency_sequence_number=*/1, extra_4_ms); + SerializedPacket packet2 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/2, /*ack_frequency_sequence_number=*/2, extra_3_ms); + SerializedPacket packet3 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/3, /*ack_frequency_sequence_number=*/3, extra_2_ms); + SerializedPacket packet4 = MakePacketWithAckFrequencyFrame( + /*packet_number=*/4, /*ack_frequency_sequence_number=*/4, extra_1_ms); + + // Send frame1, farme2, frame3, frame4. + manager_.OnPacketSent(&packet1, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + manager_.OnPacketSent(&packet2, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + manager_.OnPacketSent(&packet3, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + manager_.OnPacketSent(&packet4, clock_.Now(), NOT_RETRANSMISSION, + NO_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_4_ms); + + // Ack frame3. + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_2_ms); + // Acking frame1 do not affect peer_max_ack_delay after frame3 is acked. + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_2_ms); + // Acking frame2 do not affect peer_max_ack_delay after frame3 is acked. + manager_.OnAckFrameStart(QuicPacketNumber(3), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(4)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_2_ms); + // Acking frame4 updates peer_max_ack_delay. + manager_.OnAckFrameStart(QuicPacketNumber(4), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(5)); + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_FORWARD_SECURE, kEmptyCounts); + EXPECT_EQ(manager_.peer_max_ack_delay(), extra_1_ms); +} + +TEST_F(QuicSentPacketManagerTest, ClearDataInMessageFrameAfterPacketSent) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + + QuicMessageFrame* message_frame = nullptr; + { + quiche::QuicheMemSlice slice(quiche::QuicheBuffer(&allocator_, 1024)); + message_frame = new QuicMessageFrame(/*message_id=*/1, std::move(slice)); + EXPECT_FALSE(message_frame->message_data.empty()); + EXPECT_EQ(message_frame->message_length, 1024); + + SerializedPacket packet(QuicPacketNumber(1), PACKET_4BYTE_PACKET_NUMBER, + /*encrypted_buffer=*/nullptr, kDefaultLength, + /*has_ack=*/false, + /*has_stop_waiting*/ false); + packet.encryption_level = ENCRYPTION_FORWARD_SECURE; + packet.retransmittable_frames.push_back(QuicFrame(message_frame)); + packet.has_message = true; + manager_.OnPacketSent(&packet, clock_.Now(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, /*measure_rtt=*/true, + ECN_NOT_ECT); + } + + EXPECT_TRUE(message_frame->message_data.empty()); + EXPECT_EQ(message_frame->message_length, 0); +} + +TEST_F(QuicSentPacketManagerTest, BuildAckFrequencyFrame) { + SetQuicReloadableFlag(quic_can_send_ack_frequency, true); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + QuicConfig config; + QuicConfigPeer::SetReceivedMinAckDelayMs(&config, /*min_ack_delay_ms=*/1); + manager_.SetFromConfig(config); + manager_.SetHandshakeConfirmed(); + + // Set up RTTs. + auto* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(80), + /*ack_delay=*/QuicTime::Delta::Zero(), + /*now=*/QuicTime::Zero()); + // Make sure srtt and min_rtt are different. + rtt_stats->UpdateRtt( + QuicTime::Delta::FromMilliseconds(160), + /*ack_delay=*/QuicTime::Delta::Zero(), + /*now=*/QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(24)); + + auto frame = manager_.GetUpdatedAckFrequencyFrame(); + EXPECT_EQ(frame.max_ack_delay, + std::max(rtt_stats->min_rtt() * 0.25, + QuicTime::Delta::FromMilliseconds(1u))); + EXPECT_EQ(frame.packet_tolerance, 10u); +} + +TEST_F(QuicSentPacketManagerTest, SmoothedRttIgnoreAckDelay) { + QuicConfig config; + QuicTagVector options; + options.push_back(kMAD0); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + manager_.SetFromConfig(config); + + SendDataPacket(1); + // Ack 1. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(300)); + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), + QuicTime::Delta::FromMilliseconds(100), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + // Verify that ack_delay is ignored in the first measurement. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->smoothed_rtt()); + + SendDataPacket(2); + // Ack 2. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(300)); + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), + QuicTime::Delta::FromMilliseconds(100), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->smoothed_rtt()); + + SendDataPacket(3); + // Ack 3. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(300)); + ExpectAck(3); + manager_.OnAckFrameStart(QuicPacketNumber(3), + QuicTime::Delta::FromMilliseconds(50), clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->smoothed_rtt()); + + SendDataPacket(4); + // Ack 4. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(200)); + ExpectAck(4); + manager_.OnAckFrameStart(QuicPacketNumber(4), + QuicTime::Delta::FromMilliseconds(300), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(5)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(4), + ENCRYPTION_INITIAL, kEmptyCounts)); + // Verify that large erroneous ack_delay does not change Smoothed RTT. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(287500), + manager_.GetRttStats()->smoothed_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, IgnorePeerMaxAckDelayDuringHandshake) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + // 100ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(100); + + // Server sends INITIAL 1 and HANDSHAKE 2. + SendDataPacket(1, ENCRYPTION_INITIAL); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + + // Receive client ACK for INITIAL 1 after one RTT. + clock_.AdvanceTime(kTestRTT); + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL, kEmptyCounts)); + EXPECT_EQ(kTestRTT, manager_.GetRttStats()->latest_rtt()); + + // Assume the cert verification on client takes 50ms, such that the HANDSHAKE + // packet is queued for 50ms. + const QuicTime::Delta queuing_delay = QuicTime::Delta::FromMilliseconds(50); + clock_.AdvanceTime(queuing_delay); + // Ack 2. + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), queuing_delay, clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_HANDSHAKE, kEmptyCounts)); + EXPECT_EQ(kTestRTT, manager_.GetRttStats()->latest_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, BuildAckFrequencyFrameWithSRTT) { + SetQuicReloadableFlag(quic_can_send_ack_frequency, true); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + QuicConfig config; + QuicConfigPeer::SetReceivedMinAckDelayMs(&config, /*min_ack_delay_ms=*/1); + QuicTagVector quic_tag_vector; + quic_tag_vector.push_back(kAFF1); // SRTT enabling tag. + QuicConfigPeer::SetReceivedConnectionOptions(&config, quic_tag_vector); + manager_.SetFromConfig(config); + manager_.SetHandshakeConfirmed(); + + // Set up RTTs. + auto* rtt_stats = const_cast(manager_.GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(80), + /*ack_delay=*/QuicTime::Delta::Zero(), + /*now=*/QuicTime::Zero()); + // Make sure srtt and min_rtt are different. + rtt_stats->UpdateRtt( + QuicTime::Delta::FromMilliseconds(160), + /*ack_delay=*/QuicTime::Delta::Zero(), + /*now=*/QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(24)); + + auto frame = manager_.GetUpdatedAckFrequencyFrame(); + EXPECT_EQ(frame.max_ack_delay, + std::max(rtt_stats->SmoothedOrInitialRtt() * 0.25, + QuicTime::Delta::FromMilliseconds(1u))); +} + +TEST_F(QuicSentPacketManagerTest, SetInitialRtt) { + // Upper bounds. + manager_.SetInitialRtt( + QuicTime::Delta::FromMicroseconds(kMaxInitialRoundTripTimeUs + 1), false); + EXPECT_EQ(manager_.GetRttStats()->initial_rtt().ToMicroseconds(), + kMaxInitialRoundTripTimeUs); + + manager_.SetInitialRtt( + QuicTime::Delta::FromMicroseconds(kMaxInitialRoundTripTimeUs + 1), true); + EXPECT_EQ(manager_.GetRttStats()->initial_rtt().ToMicroseconds(), + kMaxInitialRoundTripTimeUs); + + EXPECT_GT(kMinUntrustedInitialRoundTripTimeUs, + kMinTrustedInitialRoundTripTimeUs); + + // Lower bounds for untrusted rtt. + manager_.SetInitialRtt(QuicTime::Delta::FromMicroseconds( + kMinUntrustedInitialRoundTripTimeUs - 1), + false); + EXPECT_EQ(manager_.GetRttStats()->initial_rtt().ToMicroseconds(), + kMinUntrustedInitialRoundTripTimeUs); + + // Lower bounds for trusted rtt. + manager_.SetInitialRtt(QuicTime::Delta::FromMicroseconds( + kMinUntrustedInitialRoundTripTimeUs - 1), + true); + EXPECT_EQ(manager_.GetRttStats()->initial_rtt().ToMicroseconds(), + kMinUntrustedInitialRoundTripTimeUs - 1); + + manager_.SetInitialRtt( + QuicTime::Delta::FromMicroseconds(kMinTrustedInitialRoundTripTimeUs - 1), + true); + EXPECT_EQ(manager_.GetRttStats()->initial_rtt().ToMicroseconds(), + kMinTrustedInitialRoundTripTimeUs); +} + +TEST_F(QuicSentPacketManagerTest, GetAvailableCongestionWindow) { + SendDataPacket(1); + EXPECT_EQ(kDefaultLength, manager_.GetBytesInFlight()); + + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillOnce(Return(kDefaultLength + 10)); + EXPECT_EQ(10u, manager_.GetAvailableCongestionWindowInBytes()); + + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillOnce(Return(kDefaultLength)); + EXPECT_EQ(0u, manager_.GetAvailableCongestionWindowInBytes()); + + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillOnce(Return(kDefaultLength - 10)); + EXPECT_EQ(0u, manager_.GetAvailableCongestionWindowInBytes()); +} + +TEST_F(QuicSentPacketManagerTest, EcnCountsAreStored) { + absl::optional ecn_counts1, ecn_counts2, ecn_counts3; + ecn_counts1 = {1, 2, 3}; + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), ENCRYPTION_INITIAL, + ecn_counts1); + ecn_counts2 = {0, 3, 1}; + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_HANDSHAKE, ecn_counts2); + ecn_counts3 = {0, 2, 0}; + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_FORWARD_SECURE, ecn_counts3); + EXPECT_EQ( + *QuicSentPacketManagerPeer::GetPeerEcnCounts(&manager_, INITIAL_DATA), + ecn_counts1); + EXPECT_EQ( + *QuicSentPacketManagerPeer::GetPeerEcnCounts(&manager_, HANDSHAKE_DATA), + ecn_counts2); + EXPECT_EQ( + *QuicSentPacketManagerPeer::GetPeerEcnCounts(&manager_, APPLICATION_DATA), + ecn_counts3); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_server_id.cc b/quiche/quic/core/quic_server_id.cc new file mode 100644 index 000000000000..17233ff1d5e3 --- /dev/null +++ b/quiche/quic/core/quic_server_id.cc @@ -0,0 +1,108 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_server_id.h" + +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "url/third_party/mozilla/url_parse.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +// static +absl::optional QuicServerId::ParseFromHostPortString( + absl::string_view host_port_string) { + url::Component username_component; + url::Component password_component; + url::Component host_component; + url::Component port_component; + + url::ParseAuthority(host_port_string.data(), + url::Component(0, host_port_string.size()), + &username_component, &password_component, &host_component, + &port_component); + + // Only support "host:port" and nothing more or less. + if (username_component.is_valid() || password_component.is_valid() || + !host_component.is_nonempty() || !port_component.is_nonempty()) { + QUICHE_DVLOG(1) << "QuicServerId could not be parsed: " << host_port_string; + return absl::nullopt; + } + + std::string hostname(host_port_string.data() + host_component.begin, + host_component.len); + + int parsed_port_number = + url::ParsePort(host_port_string.data(), port_component); + // Negative result is either invalid or unspecified, either of which is + // disallowed for this parse. Port 0 is technically valid but reserved and not + // really usable in practice, so easiest to just disallow it here. + if (parsed_port_number <= 0) { + QUICHE_DVLOG(1) + << "Port could not be parsed while parsing QuicServerId from: " + << host_port_string; + return absl::nullopt; + } + QUICHE_DCHECK_LE(parsed_port_number, std::numeric_limits::max()); + + return QuicServerId(std::move(hostname), + static_cast(parsed_port_number)); +} + +QuicServerId::QuicServerId() : QuicServerId("", 0, false) {} + +QuicServerId::QuicServerId(std::string host, uint16_t port) + : QuicServerId(std::move(host), port, false) {} + +QuicServerId::QuicServerId(std::string host, uint16_t port, + bool privacy_mode_enabled) + : host_(std::move(host)), + port_(port), + privacy_mode_enabled_(privacy_mode_enabled) {} + +QuicServerId::~QuicServerId() {} + +bool QuicServerId::operator<(const QuicServerId& other) const { + return std::tie(port_, host_, privacy_mode_enabled_) < + std::tie(other.port_, other.host_, other.privacy_mode_enabled_); +} + +bool QuicServerId::operator==(const QuicServerId& other) const { + return privacy_mode_enabled_ == other.privacy_mode_enabled_ && + host_ == other.host_ && port_ == other.port_; +} + +bool QuicServerId::operator!=(const QuicServerId& other) const { + return !(*this == other); +} + +std::string QuicServerId::ToHostPortString() const { + return absl::StrCat(GetHostWithIpv6Brackets(), ":", port_); +} + +absl::string_view QuicServerId::GetHostWithoutIpv6Brackets() const { + if (host_.length() > 2 && host_.front() == '[' && host_.back() == ']') { + return absl::string_view(host_.data() + 1, host_.length() - 2); + } else { + return host_; + } +} + +std::string QuicServerId::GetHostWithIpv6Brackets() const { + if (!absl::StrContains(host_, ':') || host_.length() <= 2 || + (host_.front() == '[' && host_.back() == ']')) { + return host_; + } else { + return absl::StrCat("[", host_, "]"); + } +} + +} // namespace quic diff --git a/quiche/quic/core/quic_server_id.h b/quiche/quic/core/quic_server_id.h new file mode 100644 index 000000000000..51d055b1a2ae --- /dev/null +++ b/quiche/quic/core/quic_server_id.h @@ -0,0 +1,73 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_SERVER_ID_H_ +#define QUICHE_QUIC_CORE_QUIC_SERVER_ID_H_ + +#include +#include + +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// The id used to identify sessions. Includes the hostname, port, scheme and +// privacy_mode. +class QUIC_EXPORT_PRIVATE QuicServerId { + public: + // Attempts to parse a QuicServerId from a "host:port" string. Returns nullopt + // if input could not be parsed. Requires input to contain both host and port + // and no other components of a URL authority. + static absl::optional ParseFromHostPortString( + absl::string_view host_port_string); + + QuicServerId(); + QuicServerId(std::string host, uint16_t port); + QuicServerId(std::string host, uint16_t port, bool privacy_mode_enabled); + ~QuicServerId(); + + // Needed to be an element of an ordered container. + bool operator<(const QuicServerId& other) const; + bool operator==(const QuicServerId& other) const; + + bool operator!=(const QuicServerId& other) const; + + const std::string& host() const { return host_; } + + uint16_t port() const { return port_; } + + bool privacy_mode_enabled() const { return privacy_mode_enabled_; } + + // Returns a "host:port" representation. IPv6 literal hosts will always be + // bracketed in result. + std::string ToHostPortString() const; + + // If host is an IPv6 literal surrounded by [], returns the substring without + // []. Otherwise, returns host as is. + absl::string_view GetHostWithoutIpv6Brackets() const; + + // If host is an IPv6 literal without surrounding [], returns host wrapped in + // []. Otherwise, returns host as is. + std::string GetHostWithIpv6Brackets() const; + + template + friend H AbslHashValue(H h, const QuicServerId& server_id) { + return H::combine(std::move(h), server_id.host(), server_id.port(), + server_id.privacy_mode_enabled()); + } + + private: + std::string host_; + uint16_t port_; + bool privacy_mode_enabled_; +}; + +using QuicServerIdHash = absl::Hash; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_SERVER_ID_H_ diff --git a/quiche/quic/core/quic_server_id_test.cc b/quiche/quic/core/quic_server_id_test.cc new file mode 100644 index 000000000000..1ca7a70e12ac --- /dev/null +++ b/quiche/quic/core/quic_server_id_test.cc @@ -0,0 +1,225 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_server_id.h" + +#include + +#include "absl/types/optional.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic::test { + +namespace { + +using ::testing::Optional; +using ::testing::Property; + +class QuicServerIdTest : public QuicTest {}; + +TEST_F(QuicServerIdTest, Constructor) { + QuicServerId google_server_id("google.com", 10, false); + EXPECT_EQ("google.com", google_server_id.host()); + EXPECT_EQ(10, google_server_id.port()); + EXPECT_FALSE(google_server_id.privacy_mode_enabled()); + + QuicServerId private_server_id("mail.google.com", 12, true); + EXPECT_EQ("mail.google.com", private_server_id.host()); + EXPECT_EQ(12, private_server_id.port()); + EXPECT_TRUE(private_server_id.privacy_mode_enabled()); +} + +TEST_F(QuicServerIdTest, LessThan) { + QuicServerId a_10_https("a.com", 10, false); + QuicServerId a_11_https("a.com", 11, false); + QuicServerId b_10_https("b.com", 10, false); + QuicServerId b_11_https("b.com", 11, false); + + QuicServerId a_10_https_private("a.com", 10, true); + QuicServerId a_11_https_private("a.com", 11, true); + QuicServerId b_10_https_private("b.com", 10, true); + QuicServerId b_11_https_private("b.com", 11, true); + + // Test combinations of host, port, and privacy being same on left and + // right side of less than. + EXPECT_FALSE(a_10_https < a_10_https); + EXPECT_TRUE(a_10_https < a_10_https_private); + EXPECT_FALSE(a_10_https_private < a_10_https); + EXPECT_FALSE(a_10_https_private < a_10_https_private); + + // Test with either host, port or https being different on left and right side + // of less than. + bool left_privacy; + bool right_privacy; + for (int i = 0; i < 4; i++) { + left_privacy = (i / 2 == 0); + right_privacy = (i % 2 == 0); + QuicServerId a_10_https_left_private("a.com", 10, left_privacy); + QuicServerId a_10_https_right_private("a.com", 10, right_privacy); + QuicServerId a_11_https_left_private("a.com", 11, left_privacy); + QuicServerId a_11_https_right_private("a.com", 11, right_privacy); + + QuicServerId b_10_https_left_private("b.com", 10, left_privacy); + QuicServerId b_10_https_right_private("b.com", 10, right_privacy); + QuicServerId b_11_https_left_private("b.com", 11, left_privacy); + QuicServerId b_11_https_right_private("b.com", 11, right_privacy); + + EXPECT_TRUE(a_10_https_left_private < a_11_https_right_private); + EXPECT_TRUE(a_10_https_left_private < b_10_https_right_private); + EXPECT_TRUE(a_10_https_left_private < b_11_https_right_private); + EXPECT_FALSE(a_11_https_left_private < a_10_https_right_private); + EXPECT_FALSE(a_11_https_left_private < b_10_https_right_private); + EXPECT_TRUE(a_11_https_left_private < b_11_https_right_private); + EXPECT_FALSE(b_10_https_left_private < a_10_https_right_private); + EXPECT_TRUE(b_10_https_left_private < a_11_https_right_private); + EXPECT_TRUE(b_10_https_left_private < b_11_https_right_private); + EXPECT_FALSE(b_11_https_left_private < a_10_https_right_private); + EXPECT_FALSE(b_11_https_left_private < a_11_https_right_private); + EXPECT_FALSE(b_11_https_left_private < b_10_https_right_private); + } +} + +TEST_F(QuicServerIdTest, Equals) { + bool left_privacy; + bool right_privacy; + for (int i = 0; i < 2; i++) { + left_privacy = right_privacy = (i == 0); + QuicServerId a_10_https_right_private("a.com", 10, right_privacy); + QuicServerId a_11_https_right_private("a.com", 11, right_privacy); + QuicServerId b_10_https_right_private("b.com", 10, right_privacy); + QuicServerId b_11_https_right_private("b.com", 11, right_privacy); + + EXPECT_NE(a_10_https_right_private, a_11_https_right_private); + EXPECT_NE(a_10_https_right_private, b_10_https_right_private); + EXPECT_NE(a_10_https_right_private, b_11_https_right_private); + + QuicServerId new_a_10_https_left_private("a.com", 10, left_privacy); + QuicServerId new_a_11_https_left_private("a.com", 11, left_privacy); + QuicServerId new_b_10_https_left_private("b.com", 10, left_privacy); + QuicServerId new_b_11_https_left_private("b.com", 11, left_privacy); + + EXPECT_EQ(new_a_10_https_left_private, a_10_https_right_private); + EXPECT_EQ(new_a_11_https_left_private, a_11_https_right_private); + EXPECT_EQ(new_b_10_https_left_private, b_10_https_right_private); + EXPECT_EQ(new_b_11_https_left_private, b_11_https_right_private); + } + + for (int i = 0; i < 2; i++) { + right_privacy = (i == 0); + QuicServerId a_10_https_right_private("a.com", 10, right_privacy); + QuicServerId a_11_https_right_private("a.com", 11, right_privacy); + QuicServerId b_10_https_right_private("b.com", 10, right_privacy); + QuicServerId b_11_https_right_private("b.com", 11, right_privacy); + + QuicServerId new_a_10_https_left_private("a.com", 10, false); + + EXPECT_NE(new_a_10_https_left_private, a_11_https_right_private); + EXPECT_NE(new_a_10_https_left_private, b_10_https_right_private); + EXPECT_NE(new_a_10_https_left_private, b_11_https_right_private); + } + QuicServerId a_10_https_private("a.com", 10, true); + QuicServerId new_a_10_https_no_private("a.com", 10, false); + EXPECT_NE(new_a_10_https_no_private, a_10_https_private); +} + +TEST_F(QuicServerIdTest, Parse) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString("host.test:500"); + + EXPECT_THAT(server_id, Optional(Property(&QuicServerId::host, "host.test"))); + EXPECT_THAT(server_id, Optional(Property(&QuicServerId::port, 500))); + EXPECT_THAT(server_id, + Optional(Property(&QuicServerId::privacy_mode_enabled, false))); +} + +TEST_F(QuicServerIdTest, CannotParseMissingPort) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString("host.test"); + + EXPECT_EQ(server_id, absl::nullopt); +} + +TEST_F(QuicServerIdTest, CannotParseEmptyPort) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString("host.test:"); + + EXPECT_EQ(server_id, absl::nullopt); +} + +TEST_F(QuicServerIdTest, CannotParseEmptyHost) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString(":500"); + + EXPECT_EQ(server_id, absl::nullopt); +} + +TEST_F(QuicServerIdTest, CannotParseUserInfo) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString("userinfo@host.test:500"); + + EXPECT_EQ(server_id, absl::nullopt); +} + +TEST_F(QuicServerIdTest, ParseIpv6Literal) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString("[::1]:400"); + + EXPECT_THAT(server_id, Optional(Property(&QuicServerId::host, "[::1]"))); + EXPECT_THAT(server_id, Optional(Property(&QuicServerId::port, 400))); + EXPECT_THAT(server_id, + Optional(Property(&QuicServerId::privacy_mode_enabled, false))); +} + +TEST_F(QuicServerIdTest, ParseUnbracketedIpv6Literal) { + absl::optional server_id = + QuicServerId::ParseFromHostPortString("::1:400"); + + EXPECT_THAT(server_id, Optional(Property(&QuicServerId::host, "::1"))); + EXPECT_THAT(server_id, Optional(Property(&QuicServerId::port, 400))); + EXPECT_THAT(server_id, + Optional(Property(&QuicServerId::privacy_mode_enabled, false))); +} + +TEST_F(QuicServerIdTest, AddBracketsToIpv6) { + QuicServerId server_id("::1", 100); + + EXPECT_EQ(server_id.GetHostWithIpv6Brackets(), "[::1]"); + EXPECT_EQ(server_id.ToHostPortString(), "[::1]:100"); +} + +TEST_F(QuicServerIdTest, AddBracketsAlreadyIncluded) { + QuicServerId server_id("[::1]", 100); + + EXPECT_EQ(server_id.GetHostWithIpv6Brackets(), "[::1]"); + EXPECT_EQ(server_id.ToHostPortString(), "[::1]:100"); +} + +TEST_F(QuicServerIdTest, AddBracketsNotAddedToNonIpv6) { + QuicServerId server_id("host.test", 100); + + EXPECT_EQ(server_id.GetHostWithIpv6Brackets(), "host.test"); + EXPECT_EQ(server_id.ToHostPortString(), "host.test:100"); +} + +TEST_F(QuicServerIdTest, RemoveBracketsFromIpv6) { + QuicServerId server_id("[::1]", 100); + + EXPECT_EQ(server_id.GetHostWithoutIpv6Brackets(), "::1"); +} + +TEST_F(QuicServerIdTest, RemoveBracketsNotIncluded) { + QuicServerId server_id("::1", 100); + + EXPECT_EQ(server_id.GetHostWithoutIpv6Brackets(), "::1"); +} + +TEST_F(QuicServerIdTest, RemoveBracketsFromNonIpv6) { + QuicServerId server_id("host.test", 100); + + EXPECT_EQ(server_id.GetHostWithoutIpv6Brackets(), "host.test"); +} + +} // namespace + +} // namespace quic::test diff --git a/quiche/quic/core/quic_session.cc b/quiche/quic/core/quic_session.cc new file mode 100644 index 000000000000..a5664acda70f --- /dev/null +++ b/quiche/quic/core/quic_session.cc @@ -0,0 +1,2728 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_session.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/frames/quic_window_update_frame.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_connection_context.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_flow_controller.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/quic_write_blocked_list.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_server_stats.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +namespace { + +class ClosedStreamsCleanUpDelegate : public QuicAlarm::Delegate { + public: + explicit ClosedStreamsCleanUpDelegate(QuicSession* session) + : session_(session) {} + ClosedStreamsCleanUpDelegate(const ClosedStreamsCleanUpDelegate&) = delete; + ClosedStreamsCleanUpDelegate& operator=(const ClosedStreamsCleanUpDelegate&) = + delete; + + QuicConnectionContext* GetConnectionContext() override { + return (session_->connection() == nullptr) + ? nullptr + : session_->connection()->context(); + } + + void OnAlarm() override { session_->CleanUpClosedStreams(); } + + private: + QuicSession* session_; +}; + +} // namespace + +#define ENDPOINT \ + (perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") + +QuicSession::QuicSession( + QuicConnection* connection, Visitor* owner, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicStreamCount num_expected_unidirectional_static_streams) + : QuicSession(connection, owner, config, supported_versions, + num_expected_unidirectional_static_streams, nullptr) {} + +QuicSession::QuicSession( + QuicConnection* connection, Visitor* owner, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicStreamCount num_expected_unidirectional_static_streams, + std::unique_ptr datagram_observer) + : connection_(connection), + perspective_(connection->perspective()), + visitor_(owner), + write_blocked_streams_(std::make_unique()), + config_(config), + stream_id_manager_(perspective(), connection->transport_version(), + kDefaultMaxStreamsPerConnection, + config_.GetMaxBidirectionalStreamsToSend()), + ietf_streamid_manager_(perspective(), connection->version(), this, 0, + num_expected_unidirectional_static_streams, + config_.GetMaxBidirectionalStreamsToSend(), + config_.GetMaxUnidirectionalStreamsToSend() + + num_expected_unidirectional_static_streams), + num_draining_streams_(0), + num_outgoing_draining_streams_(0), + num_static_streams_(0), + num_zombie_streams_(0), + flow_controller_( + this, QuicUtils::GetInvalidStreamId(connection->transport_version()), + /*is_connection_flow_controller*/ true, + connection->version().AllowsLowFlowControlLimits() + ? 0 + : kMinimumFlowControlSendWindow, + config_.GetInitialSessionFlowControlWindowToSend(), + kSessionReceiveWindowLimit, perspective() == Perspective::IS_SERVER, + nullptr), + currently_writing_stream_id_(0), + transport_goaway_sent_(false), + transport_goaway_received_(false), + control_frame_manager_(this), + last_message_id_(0), + datagram_queue_(this, std::move(datagram_observer)), + closed_streams_clean_up_alarm_(nullptr), + supported_versions_(supported_versions), + is_configured_(false), + was_zero_rtt_rejected_(false), + liveness_testing_in_progress_(false) { + closed_streams_clean_up_alarm_ = + absl::WrapUnique(connection_->alarm_factory()->CreateAlarm( + new ClosedStreamsCleanUpDelegate(this))); + if (VersionHasIetfQuicFrames(transport_version())) { + config_.SetMaxUnidirectionalStreamsToSend( + config_.GetMaxUnidirectionalStreamsToSend() + + num_expected_unidirectional_static_streams); + } +} + +void QuicSession::Initialize() { + connection_->set_visitor(this); + connection_->SetSessionNotifier(this); + connection_->SetDataProducer(this); + connection_->SetUnackedMapInitialCapacity(); + connection_->SetFromConfig(config_); + if (perspective_ == Perspective::IS_CLIENT) { + if (config_.HasClientRequestedIndependentOption(kAFFE, perspective_) && + version().HasIetfQuicFrames()) { + connection_->set_can_receive_ack_frequency_frame(); + config_.SetMinAckDelayMs(kDefaultMinAckDelayTimeMs); + } + } + if (perspective() == Perspective::IS_SERVER && + connection_->version().handshake_protocol == PROTOCOL_TLS1_3) { + config_.SetStatelessResetTokenToSend(GetStatelessResetToken()); + } + + connection_->CreateConnectionIdManager(); + + // On the server side, version negotiation has been done by the dispatcher, + // and the server session is created with the right version. + if (perspective() == Perspective::IS_SERVER) { + connection_->OnSuccessfulVersionNegotiation(); + } + + if (QuicVersionUsesCryptoFrames(transport_version())) { + return; + } + + QUICHE_DCHECK_EQ(QuicUtils::GetCryptoStreamId(transport_version()), + GetMutableCryptoStream()->id()); +} + +QuicSession::~QuicSession() { + if (closed_streams_clean_up_alarm_ != nullptr) { + closed_streams_clean_up_alarm_->PermanentCancel(); + } +} + +PendingStream* QuicSession::PendingStreamOnStreamFrame( + const QuicStreamFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + QuicStreamId stream_id = frame.stream_id; + + PendingStream* pending = GetOrCreatePendingStream(stream_id); + + if (!pending) { + if (frame.fin) { + QuicStreamOffset final_byte_offset = frame.offset + frame.data_length; + OnFinalByteOffsetReceived(stream_id, final_byte_offset); + } + return nullptr; + } + + pending->OnStreamFrame(frame); + if (!connection()->connected()) { + return nullptr; + } + return pending; +} + +void QuicSession::MaybeProcessPendingStream(PendingStream* pending) { + QUICHE_DCHECK(pending != nullptr); + QuicStreamId stream_id = pending->id(); + absl::optional stop_sending_error_code = + pending->GetStopSendingErrorCode(); + QuicStream* stream = ProcessPendingStream(pending); + if (stream != nullptr) { + // The pending stream should now be in the scope of normal streams. + QUICHE_DCHECK(IsClosedStream(stream_id) || IsOpenStream(stream_id)) + << "Stream " << stream_id << " not created"; + pending_stream_map_.erase(stream_id); + if (stop_sending_error_code) { + stream->OnStopSending(*stop_sending_error_code); + if (!connection()->connected()) { + return; + } + } + stream->OnStreamCreatedFromPendingStream(); + return; + } + // At this point, none of the bytes has been successfully consumed by the + // application layer. We should close the pending stream even if it is + // bidirectionl as no application will be able to write in a bidirectional + // stream with zero byte as input. + if (pending->sequencer()->IsClosed()) { + ClosePendingStream(stream_id); + } +} + +void QuicSession::PendingStreamOnWindowUpdateFrame( + const QuicWindowUpdateFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + PendingStream* pending = GetOrCreatePendingStream(frame.stream_id); + if (pending) { + pending->OnWindowUpdateFrame(frame); + } +} + +void QuicSession::PendingStreamOnStopSendingFrame( + const QuicStopSendingFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + PendingStream* pending = GetOrCreatePendingStream(frame.stream_id); + if (pending) { + pending->OnStopSending(frame.error()); + } +} + +void QuicSession::OnStreamFrame(const QuicStreamFrame& frame) { + QuicStreamId stream_id = frame.stream_id; + if (stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Received data for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (ShouldProcessFrameByPendingStream(STREAM_FRAME, stream_id)) { + PendingStream* pending = PendingStreamOnStreamFrame(frame); + if (pending != nullptr && ShouldProcessPendingStreamImmediately()) { + MaybeProcessPendingStream(pending); + } + return; + } + + QuicStream* stream = GetOrCreateStream(stream_id); + + if (!stream) { + // The stream no longer exists, but we may still be interested in the + // final stream byte offset sent by the peer. A frame with a FIN can give + // us this offset. + if (frame.fin) { + QuicStreamOffset final_byte_offset = frame.offset + frame.data_length; + OnFinalByteOffsetReceived(stream_id, final_byte_offset); + } + return; + } + stream->OnStreamFrame(frame); +} + +void QuicSession::OnCryptoFrame(const QuicCryptoFrame& frame) { + GetMutableCryptoStream()->OnCryptoFrame(frame); +} + +void QuicSession::OnStopSendingFrame(const QuicStopSendingFrame& frame) { + // STOP_SENDING is in IETF QUIC only. + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + QUICHE_DCHECK(QuicVersionUsesCryptoFrames(transport_version())); + + QuicStreamId stream_id = frame.stream_id; + // If Stream ID is invalid then close the connection. + // TODO(ianswett): This check is redundant to checks for IsClosedStream, + // but removing it requires removing multiple QUICHE_DCHECKs. + // TODO(ianswett): Multiple QUIC_DVLOGs could be QUIC_PEER_BUGs. + if (stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { + QUIC_DVLOG(1) << ENDPOINT + << "Received STOP_SENDING with invalid stream_id: " + << stream_id << " Closing connection"; + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Received STOP_SENDING for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + // If stream_id is READ_UNIDIRECTIONAL, close the connection. + if (QuicUtils::GetStreamType(stream_id, perspective(), + IsIncomingStream(stream_id), + version()) == READ_UNIDIRECTIONAL) { + QUIC_DVLOG(1) << ENDPOINT + << "Received STOP_SENDING for a read-only stream_id: " + << stream_id << "."; + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Received STOP_SENDING for a read-only stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (visitor_) { + visitor_->OnStopSendingReceived(frame); + } + if (ShouldProcessFrameByPendingStream(STOP_SENDING_FRAME, stream_id)) { + PendingStreamOnStopSendingFrame(frame); + return; + } + + QuicStream* stream = GetOrCreateStream(stream_id); + if (!stream) { + // Errors are handled by GetOrCreateStream. + return; + } + + stream->OnStopSending(frame.error()); +} + +void QuicSession::OnPacketDecrypted(EncryptionLevel level) { + GetMutableCryptoStream()->OnPacketDecrypted(level); + if (liveness_testing_in_progress_) { + liveness_testing_in_progress_ = false; + OnCanCreateNewOutgoingStream(/*unidirectional=*/false); + } +} + +void QuicSession::OnOneRttPacketAcknowledged() { + GetMutableCryptoStream()->OnOneRttPacketAcknowledged(); +} + +void QuicSession::OnHandshakePacketSent() { + GetMutableCryptoStream()->OnHandshakePacketSent(); +} + +std::unique_ptr +QuicSession::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return GetMutableCryptoStream()->AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr QuicSession::CreateCurrentOneRttEncrypter() { + return GetMutableCryptoStream()->CreateCurrentOneRttEncrypter(); +} + +void QuicSession::PendingStreamOnRstStream(const QuicRstStreamFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + QuicStreamId stream_id = frame.stream_id; + + PendingStream* pending = GetOrCreatePendingStream(stream_id); + + if (!pending) { + HandleRstOnValidNonexistentStream(frame); + return; + } + + pending->OnRstStreamFrame(frame); + // At this point, none of the bytes has been consumed by the application + // layer. It is safe to close the pending stream even if it is bidirectionl as + // no application will be able to write in a bidirectional stream with zero + // byte as input. + ClosePendingStream(stream_id); +} + +void QuicSession::OnRstStream(const QuicRstStreamFrame& frame) { + QuicStreamId stream_id = frame.stream_id; + if (stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Received data for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (VersionHasIetfQuicFrames(transport_version()) && + QuicUtils::GetStreamType(stream_id, perspective(), + IsIncomingStream(stream_id), + version()) == WRITE_UNIDIRECTIONAL) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Received RESET_STREAM for a write-only stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (visitor_) { + visitor_->OnRstStreamReceived(frame); + } + + if (ShouldProcessFrameByPendingStream(RST_STREAM_FRAME, stream_id)) { + PendingStreamOnRstStream(frame); + return; + } + + QuicStream* stream = GetOrCreateStream(stream_id); + + if (!stream) { + HandleRstOnValidNonexistentStream(frame); + return; // Errors are handled by GetOrCreateStream. + } + stream->OnStreamReset(frame); +} + +void QuicSession::OnGoAway(const QuicGoAwayFrame& /*frame*/) { + QUIC_BUG_IF(quic_bug_12435_1, version().UsesHttp3()) + << "gQUIC GOAWAY received on version " << version(); + + transport_goaway_received_ = true; +} + +void QuicSession::OnMessageReceived(absl::string_view message) { + QUIC_DVLOG(1) << ENDPOINT << "Received message of length " + << message.length(); + QUIC_DVLOG(2) << ENDPOINT << "Contents of message of length " + << message.length() << ":" << std::endl + << quiche::QuicheTextUtils::HexDump(message); +} + +void QuicSession::OnHandshakeDoneReceived() { + QUIC_DVLOG(1) << ENDPOINT << "OnHandshakeDoneReceived"; + GetMutableCryptoStream()->OnHandshakeDoneReceived(); +} + +void QuicSession::OnNewTokenReceived(absl::string_view token) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + GetMutableCryptoStream()->OnNewTokenReceived(token); +} + +// static +void QuicSession::RecordConnectionCloseAtServer(QuicErrorCode error, + ConnectionCloseSource source) { + if (error != QUIC_NO_ERROR) { + if (source == ConnectionCloseSource::FROM_SELF) { + QUIC_SERVER_HISTOGRAM_ENUM( + "quic_server_connection_close_errors", error, QUIC_LAST_ERROR, + "QuicErrorCode for server-closed connections."); + } else { + QUIC_SERVER_HISTOGRAM_ENUM( + "quic_client_connection_close_errors", error, QUIC_LAST_ERROR, + "QuicErrorCode for client-closed connections."); + } + } +} + +void QuicSession::OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) { + QUICHE_DCHECK(!connection_->connected()); + if (perspective() == Perspective::IS_SERVER) { + RecordConnectionCloseAtServer(frame.quic_error_code, source); + } + + if (on_closed_frame_.quic_error_code == QUIC_NO_ERROR) { + // Save all of the connection close information + on_closed_frame_ = frame; + source_ = source; + } + + GetMutableCryptoStream()->OnConnectionClosed(frame.quic_error_code, source); + + PerformActionOnActiveStreams([this, frame, source](QuicStream* stream) { + QuicStreamId id = stream->id(); + stream->OnConnectionClosed(frame.quic_error_code, source); + auto it = stream_map_.find(id); + if (it != stream_map_.end()) { + QUIC_BUG_IF(quic_bug_12435_2, !it->second->IsZombie()) + << ENDPOINT << "Non-zombie stream " << id + << " failed to close under OnConnectionClosed"; + } + return true; + }); + + closed_streams_clean_up_alarm_->Cancel(); + + if (visitor_) { + visitor_->OnConnectionClosed(connection_->GetOneActiveServerConnectionId(), + frame.quic_error_code, frame.error_details, + source); + } +} + +void QuicSession::OnWriteBlocked() { + if (!connection_->connected()) { + return; + } + if (visitor_) { + visitor_->OnWriteBlocked(connection_); + } +} + +void QuicSession::OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& /*version*/) {} + +void QuicSession::OnPacketReceived(const QuicSocketAddress& /*self_address*/, + const QuicSocketAddress& peer_address, + bool is_connectivity_probe) { + if (is_connectivity_probe && perspective() == Perspective::IS_SERVER) { + // Server only sends back a connectivity probe after received a + // connectivity probe from a new peer address. + connection_->SendConnectivityProbingPacket(nullptr, peer_address); + } +} + +void QuicSession::OnPathDegrading() {} + +void QuicSession::OnForwardProgressMadeAfterPathDegrading() {} + +bool QuicSession::AllowSelfAddressChange() const { return false; } + +void QuicSession::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { + // Stream may be closed by the time we receive a WINDOW_UPDATE, so we can't + // assume that it still exists. + QuicStreamId stream_id = frame.stream_id; + if (stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { + // This is a window update that applies to the connection, rather than an + // individual stream. + QUIC_DVLOG(1) << ENDPOINT + << "Received connection level flow control window " + "update with max data: " + << frame.max_data; + flow_controller_.UpdateSendWindowOffset(frame.max_data); + return; + } + + if (VersionHasIetfQuicFrames(transport_version()) && + QuicUtils::GetStreamType(stream_id, perspective(), + IsIncomingStream(stream_id), + version()) == READ_UNIDIRECTIONAL) { + connection()->CloseConnection( + QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM, + "WindowUpdateFrame received on READ_UNIDIRECTIONAL stream.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (ShouldProcessFrameByPendingStream(WINDOW_UPDATE_FRAME, stream_id)) { + PendingStreamOnWindowUpdateFrame(frame); + return; + } + + QuicStream* stream = GetOrCreateStream(stream_id); + if (stream != nullptr) { + stream->OnWindowUpdateFrame(frame); + } +} + +void QuicSession::OnBlockedFrame(const QuicBlockedFrame& frame) { + // TODO(rjshade): Compare our flow control receive windows for specified + // streams: if we have a large window then maybe something + // had gone wrong with the flow control accounting. + QUIC_DLOG(INFO) << ENDPOINT << "Received BLOCKED frame with stream id: " + << frame.stream_id << ", offset: " << frame.offset; +} + +bool QuicSession::CheckStreamNotBusyLooping(QuicStream* stream, + uint64_t previous_bytes_written, + bool previous_fin_sent) { + if ( // Stream should not be closed. + !stream->write_side_closed() && + // Not connection flow control blocked. + !flow_controller_.IsBlocked() && + // Detect lack of forward progress. + previous_bytes_written == stream->stream_bytes_written() && + previous_fin_sent == stream->fin_sent()) { + stream->set_busy_counter(stream->busy_counter() + 1); + QUIC_DVLOG(1) << ENDPOINT << "Suspected busy loop on stream id " + << stream->id() << " stream_bytes_written " + << stream->stream_bytes_written() << " fin " + << stream->fin_sent() << " count " << stream->busy_counter(); + // Wait a few iterations before firing, the exact count is + // arbitrary, more than a few to cover a few test-only false + // positives. + if (stream->busy_counter() > 20) { + QUIC_LOG(ERROR) << ENDPOINT << "Detected busy loop on stream id " + << stream->id() << " stream_bytes_written " + << stream->stream_bytes_written() << " fin " + << stream->fin_sent(); + return false; + } + } else { + stream->set_busy_counter(0); + } + return true; +} + +bool QuicSession::CheckStreamWriteBlocked(QuicStream* stream) const { + if (!stream->write_side_closed() && stream->HasBufferedData() && + !stream->IsFlowControlBlocked() && + !write_blocked_streams_->IsStreamBlocked(stream->id())) { + QUIC_DLOG(ERROR) << ENDPOINT << "stream " << stream->id() + << " has buffered " << stream->BufferedDataBytes() + << " bytes, and is not flow control blocked, " + "but it is not in the write block list."; + return false; + } + return true; +} + +void QuicSession::OnCanWrite() { + if (connection_->framer().is_processing_packet()) { + // Do not write data in the middle of packet processing because rest + // frames in the packet may change the data to write. For example, lost + // data could be acknowledged. Also, connection is going to emit + // OnCanWrite signal post packet processing. + QUIC_BUG(session_write_mid_packet_processing) + << ENDPOINT << "Try to write mid packet processing."; + return; + } + if (!RetransmitLostData()) { + // Cannot finish retransmitting lost data, connection is write blocked. + QUIC_DVLOG(1) << ENDPOINT + << "Cannot finish retransmitting lost data, connection is " + "write blocked."; + return; + } + // We limit the number of writes to the number of pending streams. If more + // streams become pending, WillingAndAbleToWrite will be true, which will + // cause the connection to request resumption before yielding to other + // connections. + // If we are connection level flow control blocked, then only allow the + // crypto and headers streams to try writing as all other streams will be + // blocked. + size_t num_writes = flow_controller_.IsBlocked() + ? write_blocked_streams_->NumBlockedSpecialStreams() + : write_blocked_streams_->NumBlockedStreams(); + if (num_writes == 0 && !control_frame_manager_.WillingToWrite() && + datagram_queue_.empty() && + (!QuicVersionUsesCryptoFrames(transport_version()) || + !GetCryptoStream()->HasBufferedCryptoFrames())) { + return; + } + + QuicConnection::ScopedPacketFlusher flusher(connection_); + if (QuicVersionUsesCryptoFrames(transport_version())) { + QuicCryptoStream* crypto_stream = GetMutableCryptoStream(); + if (crypto_stream->HasBufferedCryptoFrames()) { + crypto_stream->WriteBufferedCryptoFrames(); + } + if ((GetQuicReloadableFlag( + quic_no_write_control_frame_upon_connection_close) && + !connection_->connected()) || + crypto_stream->HasBufferedCryptoFrames()) { + // Cannot finish writing buffered crypto frames, connection is either + // write blocked or closed. + return; + } + } + if (control_frame_manager_.WillingToWrite()) { + control_frame_manager_.OnCanWrite(); + } + if (version().UsesTls() && GetHandshakeState() != HANDSHAKE_CONFIRMED && + connection_->in_probe_time_out()) { + QUIC_CODE_COUNT(quic_donot_pto_stream_data_before_handshake_confirmed); + // Do not PTO stream data before handshake gets confirmed. + return; + } + // TODO(b/147146815): this makes all datagrams go before stream data. We + // should have a better priority scheme for this. + if (!datagram_queue_.empty()) { + size_t written = datagram_queue_.SendDatagrams(); + QUIC_DVLOG(1) << ENDPOINT << "Sent " << written << " datagrams"; + if (!datagram_queue_.empty()) { + return; + } + } + std::vector last_writing_stream_ids; + for (size_t i = 0; i < num_writes; ++i) { + if (!(write_blocked_streams_->HasWriteBlockedSpecialStream() || + write_blocked_streams_->HasWriteBlockedDataStreams())) { + // Writing one stream removed another!? Something's broken. + QUIC_BUG(quic_bug_10866_1) + << "WriteBlockedStream is missing, num_writes: " << num_writes + << ", finished_writes: " << i + << ", connected: " << connection_->connected() + << ", connection level flow control blocked: " + << flow_controller_.IsBlocked(); + for (QuicStreamId id : last_writing_stream_ids) { + QUIC_LOG(WARNING) << "last_writing_stream_id: " << id; + } + connection_->CloseConnection(QUIC_INTERNAL_ERROR, + "WriteBlockedStream is missing", + ConnectionCloseBehavior::SILENT_CLOSE); + return; + } + if (!CanWriteStreamData()) { + return; + } + currently_writing_stream_id_ = write_blocked_streams_->PopFront(); + last_writing_stream_ids.push_back(currently_writing_stream_id_); + QUIC_DVLOG(1) << ENDPOINT << "Removing stream " + << currently_writing_stream_id_ << " from write-blocked list"; + QuicStream* stream = GetOrCreateStream(currently_writing_stream_id_); + if (stream != nullptr && !stream->IsFlowControlBlocked()) { + // If the stream can't write all bytes it'll re-add itself to the blocked + // list. + uint64_t previous_bytes_written = stream->stream_bytes_written(); + bool previous_fin_sent = stream->fin_sent(); + QUIC_DVLOG(1) << ENDPOINT << "stream " << stream->id() + << " bytes_written " << previous_bytes_written << " fin " + << previous_fin_sent; + stream->OnCanWrite(); + QUICHE_DCHECK(CheckStreamWriteBlocked(stream)); + QUICHE_DCHECK(CheckStreamNotBusyLooping(stream, previous_bytes_written, + previous_fin_sent)); + } + currently_writing_stream_id_ = 0; + } +} + +bool QuicSession::WillingAndAbleToWrite() const { + // Schedule a write when: + // 1) control frame manager has pending or new control frames, or + // 2) any stream has pending retransmissions, or + // 3) If the crypto or headers streams are blocked, or + // 4) connection is not flow control blocked and there are write blocked + // streams. + if (QuicVersionUsesCryptoFrames(transport_version())) { + if (HasPendingHandshake()) { + return true; + } + if (!IsEncryptionEstablished()) { + return false; + } + } + if (control_frame_manager_.WillingToWrite() || + !streams_with_pending_retransmission_.empty()) { + return true; + } + if (flow_controller_.IsBlocked()) { + if (VersionUsesHttp3(transport_version())) { + return false; + } + // Crypto and headers streams are not blocked by connection level flow + // control. + return write_blocked_streams_->HasWriteBlockedSpecialStream(); + } + return write_blocked_streams_->HasWriteBlockedSpecialStream() || + write_blocked_streams_->HasWriteBlockedDataStreams(); +} + +std::string QuicSession::GetStreamsInfoForLogging() const { + std::string info = absl::StrCat( + "num_active_streams: ", GetNumActiveStreams(), + ", num_pending_streams: ", pending_streams_size(), + ", num_outgoing_draining_streams: ", num_outgoing_draining_streams(), + " "); + // Log info for up to 5 streams. + size_t i = 5; + for (const auto& it : stream_map_) { + if (it.second->is_static()) { + continue; + } + // Calculate the stream creation delay. + const QuicTime::Delta delay = + connection_->clock()->ApproximateNow() - it.second->creation_time(); + absl::StrAppend( + &info, "{", it.second->id(), ":", delay.ToDebuggingValue(), ";", + it.second->stream_bytes_written(), ",", it.second->fin_sent(), ",", + it.second->HasBufferedData(), ",", it.second->fin_buffered(), ";", + it.second->stream_bytes_read(), ",", it.second->fin_received(), "}"); + --i; + if (i == 0) { + break; + } + } + return info; +} + +bool QuicSession::HasPendingHandshake() const { + if (QuicVersionUsesCryptoFrames(transport_version())) { + return GetCryptoStream()->HasPendingCryptoRetransmission() || + GetCryptoStream()->HasBufferedCryptoFrames(); + } + return streams_with_pending_retransmission_.contains( + QuicUtils::GetCryptoStreamId(transport_version())) || + write_blocked_streams_->IsStreamBlocked( + QuicUtils::GetCryptoStreamId(transport_version())); +} + +void QuicSession::ProcessUdpPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) { + QuicConnectionContextSwitcher cs(connection_->context()); + connection_->ProcessUdpPacket(self_address, peer_address, packet); +} + +std::string QuicSession::on_closed_frame_string() const { + std::stringstream ss; + ss << on_closed_frame_; + if (source_.has_value()) { + ss << " " << ConnectionCloseSourceToString(source_.value()); + } + return ss.str(); +} + +QuicConsumedData QuicSession::WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, + TransmissionType type, + EncryptionLevel level) { + QUIC_BUG_IF(session writevdata when disconnected, !connection()->connected()) + << ENDPOINT << "Try to write stream data when connection is closed: " + << on_closed_frame_string(); + if (!IsEncryptionEstablished() && + !QuicUtils::IsCryptoStreamId(transport_version(), id)) { + // Do not let streams write without encryption. The calling stream will end + // up write blocked until OnCanWrite is next called. + if (was_zero_rtt_rejected_ && !OneRttKeysAvailable()) { + QUICHE_DCHECK(version().UsesTls() && + perspective() == Perspective::IS_CLIENT); + QUIC_DLOG(INFO) << ENDPOINT + << "Suppress the write while 0-RTT gets rejected and " + "1-RTT keys are not available. Version: " + << ParsedQuicVersionToString(version()); + } else if (version().UsesTls() || perspective() == Perspective::IS_SERVER) { + QUIC_BUG(quic_bug_10866_2) + << ENDPOINT << "Try to send data of stream " << id + << " before encryption is established. Version: " + << ParsedQuicVersionToString(version()); + } else { + // In QUIC crypto, this could happen when the client sends full CHLO and + // 0-RTT request, then receives an inchoate REJ and sends an inchoate + // CHLO. The client then gets the ACK of the inchoate CHLO or the client + // gets the full REJ and needs to verify the proof (before it sends the + // full CHLO), such that there is no outstanding crypto data. + // Retransmission alarm fires in TLP mode which tries to retransmit the + // 0-RTT request (without encryption). + QUIC_DLOG(INFO) << ENDPOINT << "Try to send data of stream " << id + << " before encryption is established."; + } + return QuicConsumedData(0, false); + } + + SetTransmissionType(type); + QuicConnection::ScopedEncryptionLevelContext context(connection(), level); + + QuicConsumedData data = + connection_->SendStreamData(id, write_length, offset, state); + if (type == NOT_RETRANSMISSION) { + // This is new stream data. + write_blocked_streams_->UpdateBytesForStream(id, data.bytes_consumed); + } + + return data; +} + +size_t QuicSession::SendCryptoData(EncryptionLevel level, size_t write_length, + QuicStreamOffset offset, + TransmissionType type) { + QUICHE_DCHECK(QuicVersionUsesCryptoFrames(transport_version())); + if (!connection()->framer().HasEncrypterOfEncryptionLevel(level)) { + const std::string error_details = absl::StrCat( + "Try to send crypto data with missing keys of encryption level: ", + EncryptionLevelToString(level)); + QUIC_BUG(quic_bug_10866_3) << ENDPOINT << error_details; + connection()->CloseConnection( + QUIC_MISSING_WRITE_KEYS, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return 0; + } + SetTransmissionType(type); + QuicConnection::ScopedEncryptionLevelContext context(connection(), level); + const auto bytes_consumed = + connection_->SendCryptoData(level, write_length, offset); + return bytes_consumed; +} + +void QuicSession::OnControlFrameManagerError(QuicErrorCode error_code, + std::string error_details) { + connection_->CloseConnection( + error_code, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +bool QuicSession::WriteControlFrame(const QuicFrame& frame, + TransmissionType type) { + QUIC_BUG_IF(quic_bug_12435_11, !connection()->connected()) + << ENDPOINT + << absl::StrCat("Try to write control frame: ", QuicFrameToString(frame), + " when connection is closed: ") + << on_closed_frame_string(); + if (!IsEncryptionEstablished()) { + // Suppress the write before encryption gets established. + return false; + } + SetTransmissionType(type); + QuicConnection::ScopedEncryptionLevelContext context( + connection(), GetEncryptionLevelToSendApplicationData()); + return connection_->SendControlFrame(frame); +} + +void QuicSession::ResetStream(QuicStreamId id, QuicRstStreamErrorCode error) { + QuicStream* stream = GetStream(id); + if (stream != nullptr && stream->is_static()) { + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Try to reset a static stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (stream != nullptr) { + stream->Reset(error); + return; + } + + QuicConnection::ScopedPacketFlusher flusher(connection()); + MaybeSendStopSendingFrame(id, QuicResetStreamError::FromInternal(error)); + MaybeSendRstStreamFrame(id, QuicResetStreamError::FromInternal(error), 0); +} + +void QuicSession::MaybeSendRstStreamFrame(QuicStreamId id, + QuicResetStreamError error, + QuicStreamOffset bytes_written) { + if (!connection()->connected()) { + return; + } + if (!VersionHasIetfQuicFrames(transport_version()) || + QuicUtils::GetStreamType(id, perspective(), IsIncomingStream(id), + version()) != READ_UNIDIRECTIONAL) { + control_frame_manager_.WriteOrBufferRstStream(id, error, bytes_written); + } + + connection_->OnStreamReset(id, error.internal_code()); +} + +void QuicSession::MaybeSendStopSendingFrame(QuicStreamId id, + QuicResetStreamError error) { + if (!connection()->connected()) { + return; + } + if (VersionHasIetfQuicFrames(transport_version()) && + QuicUtils::GetStreamType(id, perspective(), IsIncomingStream(id), + version()) != WRITE_UNIDIRECTIONAL) { + control_frame_manager_.WriteOrBufferStopSending(error, id); + } +} + +void QuicSession::SendGoAway(QuicErrorCode error_code, + const std::string& reason) { + // GOAWAY frame is not supported in IETF QUIC. + QUICHE_DCHECK(!VersionHasIetfQuicFrames(transport_version())); + if (!IsEncryptionEstablished()) { + QUIC_CODE_COUNT(quic_goaway_before_encryption_established); + connection_->CloseConnection( + error_code, reason, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + if (transport_goaway_sent_) { + return; + } + transport_goaway_sent_ = true; + + QUICHE_DCHECK_EQ(perspective(), Perspective::IS_SERVER); + control_frame_manager_.WriteOrBufferGoAway( + error_code, + QuicUtils::GetMaxClientInitiatedBidirectionalStreamId( + transport_version()), + reason); +} + +void QuicSession::SendBlocked(QuicStreamId id, QuicStreamOffset byte_offset) { + control_frame_manager_.WriteOrBufferBlocked(id, byte_offset); +} + +void QuicSession::SendWindowUpdate(QuicStreamId id, + QuicStreamOffset byte_offset) { + control_frame_manager_.WriteOrBufferWindowUpdate(id, byte_offset); +} + +void QuicSession::OnStreamError(QuicErrorCode error_code, + std::string error_details) { + connection_->CloseConnection( + error_code, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicSession::OnStreamError(QuicErrorCode error_code, + QuicIetfTransportErrorCodes ietf_error, + std::string error_details) { + connection_->CloseConnection( + error_code, ietf_error, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicSession::SendMaxStreams(QuicStreamCount stream_count, + bool unidirectional) { + if (!is_configured_) { + QUIC_BUG(quic_bug_10866_5) + << "Try to send max streams before config negotiated."; + return; + } + control_frame_manager_.WriteOrBufferMaxStreams(stream_count, unidirectional); +} + +void QuicSession::InsertLocallyClosedStreamsHighestOffset( + const QuicStreamId id, QuicStreamOffset offset) { + locally_closed_streams_highest_offset_[id] = offset; +} + +void QuicSession::OnStreamClosed(QuicStreamId stream_id) { + QUIC_DVLOG(1) << ENDPOINT << "Closing stream: " << stream_id; + StreamMap::iterator it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + QUIC_BUG(quic_bug_10866_6) + << ENDPOINT << "Stream is already closed: " << stream_id; + return; + } + QuicStream* stream = it->second.get(); + StreamType type = stream->type(); + + const bool stream_waiting_for_acks = stream->IsWaitingForAcks(); + if (stream_waiting_for_acks) { + // The stream needs to be kept alive because it's waiting for acks. + ++num_zombie_streams_; + } else { + closed_streams_.push_back(std::move(it->second)); + stream_map_.erase(it); + // Do not retransmit data of a closed stream. + streams_with_pending_retransmission_.erase(stream_id); + if (!closed_streams_clean_up_alarm_->IsSet()) { + closed_streams_clean_up_alarm_->Set( + connection_->clock()->ApproximateNow()); + } + connection_->QuicBugIfHasPendingFrames(stream_id); + } + + if (!stream->HasReceivedFinalOffset()) { + // If we haven't received a FIN or RST for this stream, we need to keep + // track of the how many bytes the stream's flow controller believes it has + // received, for accurate connection level flow control accounting. + // If this is an outgoing stream, it is technically open from peer's + // perspective. Do not inform stream Id manager yet. + QUICHE_DCHECK(!stream->was_draining()); + InsertLocallyClosedStreamsHighestOffset( + stream_id, stream->highest_received_byte_offset()); + return; + } + + const bool stream_was_draining = stream->was_draining(); + QUIC_DVLOG_IF(1, stream_was_draining) + << ENDPOINT << "Stream " << stream_id << " was draining"; + if (stream_was_draining) { + QUIC_BUG_IF(quic_bug_12435_4, num_draining_streams_ == 0); + --num_draining_streams_; + if (!IsIncomingStream(stream_id)) { + QUIC_BUG_IF(quic_bug_12435_5, num_outgoing_draining_streams_ == 0); + --num_outgoing_draining_streams_; + } + // Stream Id manager has been informed with draining streams. + return; + } + if (!VersionHasIetfQuicFrames(transport_version())) { + stream_id_manager_.OnStreamClosed( + /*is_incoming=*/IsIncomingStream(stream_id)); + } + if (!connection_->connected()) { + return; + } + if (IsIncomingStream(stream_id)) { + // Stream Id manager is only interested in peer initiated stream IDs. + if (VersionHasIetfQuicFrames(transport_version())) { + ietf_streamid_manager_.OnStreamClosed(stream_id); + } + return; + } + if (!VersionHasIetfQuicFrames(transport_version())) { + OnCanCreateNewOutgoingStream(type != BIDIRECTIONAL); + } +} + +void QuicSession::ClosePendingStream(QuicStreamId stream_id) { + QUIC_DVLOG(1) << ENDPOINT << "Closing stream " << stream_id; + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + pending_stream_map_.erase(stream_id); + if (connection_->connected()) { + ietf_streamid_manager_.OnStreamClosed(stream_id); + } +} + +bool QuicSession::ShouldProcessFrameByPendingStream(QuicFrameType type, + QuicStreamId id) const { + return UsesPendingStreamForFrame(type, id) && + stream_map_.find(id) == stream_map_.end(); +} + +void QuicSession::OnFinalByteOffsetReceived( + QuicStreamId stream_id, QuicStreamOffset final_byte_offset) { + auto it = locally_closed_streams_highest_offset_.find(stream_id); + if (it == locally_closed_streams_highest_offset_.end()) { + return; + } + + QUIC_DVLOG(1) << ENDPOINT << "Received final byte offset " + << final_byte_offset << " for stream " << stream_id; + QuicByteCount offset_diff = final_byte_offset - it->second; + if (flow_controller_.UpdateHighestReceivedOffset( + flow_controller_.highest_received_byte_offset() + offset_diff)) { + // If the final offset violates flow control, close the connection now. + if (flow_controller_.FlowControlViolation()) { + connection_->CloseConnection( + QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, + "Connection level flow control violation", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + } + + flow_controller_.AddBytesConsumed(offset_diff); + locally_closed_streams_highest_offset_.erase(it); + if (!VersionHasIetfQuicFrames(transport_version())) { + stream_id_manager_.OnStreamClosed( + /*is_incoming=*/IsIncomingStream(stream_id)); + } + if (IsIncomingStream(stream_id)) { + if (VersionHasIetfQuicFrames(transport_version())) { + ietf_streamid_manager_.OnStreamClosed(stream_id); + } + } else if (!VersionHasIetfQuicFrames(transport_version())) { + OnCanCreateNewOutgoingStream(false); + } +} + +bool QuicSession::IsEncryptionEstablished() const { + if (GetCryptoStream() == nullptr) { + return false; + } + return GetCryptoStream()->encryption_established(); +} + +bool QuicSession::OneRttKeysAvailable() const { + if (GetCryptoStream() == nullptr) { + return false; + } + return GetCryptoStream()->one_rtt_keys_available(); +} + +void QuicSession::OnConfigNegotiated() { + // In versions with TLS, the configs will be set twice if 0-RTT is available. + // In the second config setting, 1-RTT keys are guaranteed to be available. + if (version().UsesTls() && is_configured_ && + connection_->encryption_level() != ENCRYPTION_FORWARD_SECURE) { + QUIC_BUG(quic_bug_12435_6) + << ENDPOINT + << "1-RTT keys missing when config is negotiated for the second time."; + connection_->CloseConnection( + QUIC_INTERNAL_ERROR, + "1-RTT keys missing when config is negotiated for the second time.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + QUIC_DVLOG(1) << ENDPOINT << "OnConfigNegotiated"; + connection_->SetFromConfig(config_); + + if (VersionHasIetfQuicFrames(transport_version())) { + uint32_t max_streams = 0; + if (config_.HasReceivedMaxBidirectionalStreams()) { + max_streams = config_.ReceivedMaxBidirectionalStreams(); + } + if (was_zero_rtt_rejected_ && + max_streams < + ietf_streamid_manager_.outgoing_bidirectional_stream_count()) { + connection_->CloseConnection( + QUIC_ZERO_RTT_UNRETRANSMITTABLE, + absl::StrCat( + "Server rejected 0-RTT, aborting because new bidirectional " + "initial stream limit ", + max_streams, " is less than current open streams: ", + ietf_streamid_manager_.outgoing_bidirectional_stream_count()), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QUIC_DVLOG(1) << ENDPOINT + << "Setting Bidirectional outgoing_max_streams_ to " + << max_streams; + if (perspective_ == Perspective::IS_CLIENT && + max_streams < + ietf_streamid_manager_.max_outgoing_bidirectional_streams()) { + connection_->CloseConnection( + was_zero_rtt_rejected_ ? QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED + : QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, + absl::StrCat( + was_zero_rtt_rejected_ + ? "Server rejected 0-RTT, aborting because " + : "", + "new bidirectional limit ", max_streams, + " decreases the current limit: ", + ietf_streamid_manager_.max_outgoing_bidirectional_streams()), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + if (ietf_streamid_manager_.MaybeAllowNewOutgoingBidirectionalStreams( + max_streams)) { + OnCanCreateNewOutgoingStream(/*unidirectional = */ false); + } + + max_streams = 0; + if (config_.HasReceivedMaxUnidirectionalStreams()) { + max_streams = config_.ReceivedMaxUnidirectionalStreams(); + } + + if (was_zero_rtt_rejected_ && + max_streams < + ietf_streamid_manager_.outgoing_unidirectional_stream_count()) { + connection_->CloseConnection( + QUIC_ZERO_RTT_UNRETRANSMITTABLE, + absl::StrCat( + "Server rejected 0-RTT, aborting because new unidirectional " + "initial stream limit ", + max_streams, " is less than current open streams: ", + ietf_streamid_manager_.outgoing_unidirectional_stream_count()), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + if (max_streams < + ietf_streamid_manager_.max_outgoing_unidirectional_streams()) { + connection_->CloseConnection( + was_zero_rtt_rejected_ ? QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED + : QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, + absl::StrCat( + was_zero_rtt_rejected_ + ? "Server rejected 0-RTT, aborting because " + : "", + "new unidirectional limit ", max_streams, + " decreases the current limit: ", + ietf_streamid_manager_.max_outgoing_unidirectional_streams()), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QUIC_DVLOG(1) << ENDPOINT + << "Setting Unidirectional outgoing_max_streams_ to " + << max_streams; + if (ietf_streamid_manager_.MaybeAllowNewOutgoingUnidirectionalStreams( + max_streams)) { + OnCanCreateNewOutgoingStream(/*unidirectional = */ true); + } + } else { + uint32_t max_streams = 0; + if (config_.HasReceivedMaxBidirectionalStreams()) { + max_streams = config_.ReceivedMaxBidirectionalStreams(); + } + QUIC_DVLOG(1) << ENDPOINT << "Setting max_open_outgoing_streams_ to " + << max_streams; + if (was_zero_rtt_rejected_ && + max_streams < stream_id_manager_.num_open_outgoing_streams()) { + connection_->CloseConnection( + QUIC_INTERNAL_ERROR, + absl::StrCat( + "Server rejected 0-RTT, aborting because new stream limit ", + max_streams, " is less than current open streams: ", + stream_id_manager_.num_open_outgoing_streams()), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + stream_id_manager_.set_max_open_outgoing_streams(max_streams); + } + + if (perspective() == Perspective::IS_SERVER) { + if (config_.HasReceivedConnectionOptions()) { + // The following variations change the initial receive flow control + // window sizes. + if (ContainsQuicTag(config_.ReceivedConnectionOptions(), kIFW6)) { + AdjustInitialFlowControlWindows(64 * 1024); + } + if (ContainsQuicTag(config_.ReceivedConnectionOptions(), kIFW7)) { + AdjustInitialFlowControlWindows(128 * 1024); + } + if (ContainsQuicTag(config_.ReceivedConnectionOptions(), kIFW8)) { + AdjustInitialFlowControlWindows(256 * 1024); + } + if (ContainsQuicTag(config_.ReceivedConnectionOptions(), kIFW9)) { + AdjustInitialFlowControlWindows(512 * 1024); + } + if (ContainsQuicTag(config_.ReceivedConnectionOptions(), kIFWA)) { + AdjustInitialFlowControlWindows(1024 * 1024); + } + } + + config_.SetStatelessResetTokenToSend(GetStatelessResetToken()); + } + + if (VersionHasIetfQuicFrames(transport_version())) { + ietf_streamid_manager_.SetMaxOpenIncomingBidirectionalStreams( + config_.GetMaxBidirectionalStreamsToSend()); + ietf_streamid_manager_.SetMaxOpenIncomingUnidirectionalStreams( + config_.GetMaxUnidirectionalStreamsToSend()); + } else { + // A small number of additional incoming streams beyond the limit should be + // allowed. This helps avoid early connection termination when FIN/RSTs for + // old streams are lost or arrive out of order. + // Use a minimum number of additional streams, or a percentage increase, + // whichever is larger. + uint32_t max_incoming_streams_to_send = + config_.GetMaxBidirectionalStreamsToSend(); + uint32_t max_incoming_streams = + std::max(max_incoming_streams_to_send + kMaxStreamsMinimumIncrement, + static_cast(max_incoming_streams_to_send * + kMaxStreamsMultiplier)); + stream_id_manager_.set_max_open_incoming_streams(max_incoming_streams); + } + + if (connection_->version().handshake_protocol == PROTOCOL_TLS1_3) { + // When using IETF-style TLS transport parameters, inform existing streams + // of new flow-control limits. + if (config_.HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()) { + OnNewStreamOutgoingBidirectionalFlowControlWindow( + config_.ReceivedInitialMaxStreamDataBytesOutgoingBidirectional()); + } + if (config_.HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()) { + OnNewStreamIncomingBidirectionalFlowControlWindow( + config_.ReceivedInitialMaxStreamDataBytesIncomingBidirectional()); + } + if (config_.HasReceivedInitialMaxStreamDataBytesUnidirectional()) { + OnNewStreamUnidirectionalFlowControlWindow( + config_.ReceivedInitialMaxStreamDataBytesUnidirectional()); + } + } else { // The version uses Google QUIC Crypto. + if (config_.HasReceivedInitialStreamFlowControlWindowBytes()) { + // Streams which were created before the SHLO was received (0-RTT + // requests) are now informed of the peer's initial flow control window. + OnNewStreamFlowControlWindow( + config_.ReceivedInitialStreamFlowControlWindowBytes()); + } + } + + if (config_.HasReceivedInitialSessionFlowControlWindowBytes()) { + OnNewSessionFlowControlWindow( + config_.ReceivedInitialSessionFlowControlWindowBytes()); + } + + if (perspective_ == Perspective::IS_SERVER && version().HasIetfQuicFrames() && + connection_->effective_peer_address().IsInitialized()) { + if (config_.HasClientSentConnectionOption(kSPAD, perspective_)) { + quiche::IpAddressFamily address_family = + connection_->effective_peer_address() + .Normalized() + .host() + .address_family(); + absl::optional preferred_address = + config_.GetPreferredAddressToSend(address_family); + if (preferred_address.has_value()) { + // Set connection ID and token if SPAD has received and a preferred + // address of the same address family is configured. + absl::optional frame = + connection_->MaybeIssueNewConnectionIdForPreferredAddress(); + if (frame.has_value()) { + config_.SetPreferredAddressConnectionIdAndTokenToSend( + frame->connection_id, frame->stateless_reset_token); + } + connection_->set_sent_server_preferred_address( + preferred_address.value()); + } + // Clear the alternative address of the other address family in the + // config. + config_.ClearAlternateServerAddressToSend( + address_family == quiche::IpAddressFamily::IP_V4 + ? quiche::IpAddressFamily::IP_V6 + : quiche::IpAddressFamily::IP_V4); + } else { + // Clear alternative IPv(4|6) addresses in config if the server hasn't + // received 'SPAD' connection option. + config_.ClearAlternateServerAddressToSend(quiche::IpAddressFamily::IP_V4); + config_.ClearAlternateServerAddressToSend(quiche::IpAddressFamily::IP_V6); + } + } + + is_configured_ = true; + connection()->OnConfigNegotiated(); + + // Ask flow controllers to try again since the config could have unblocked us. + // Or if this session is configured on TLS enabled QUIC versions, + // attempt to retransmit 0-RTT data if there's any. + // TODO(fayang): consider removing this OnCanWrite call. + if (!connection_->framer().is_processing_packet() && + (connection_->version().AllowsLowFlowControlLimits() || + version().UsesTls())) { + QUIC_CODE_COUNT(quic_session_on_can_write_on_config_negotiated); + OnCanWrite(); + } +} + +absl::optional QuicSession::OnAlpsData( + const uint8_t* /*alps_data*/, size_t /*alps_length*/) { + return absl::nullopt; +} + +void QuicSession::AdjustInitialFlowControlWindows(size_t stream_window) { + const float session_window_multiplier = + config_.GetInitialStreamFlowControlWindowToSend() + ? static_cast( + config_.GetInitialSessionFlowControlWindowToSend()) / + config_.GetInitialStreamFlowControlWindowToSend() + : 1.5; + + QUIC_DVLOG(1) << ENDPOINT << "Set stream receive window to " << stream_window; + config_.SetInitialStreamFlowControlWindowToSend(stream_window); + + size_t session_window = session_window_multiplier * stream_window; + QUIC_DVLOG(1) << ENDPOINT << "Set session receive window to " + << session_window; + config_.SetInitialSessionFlowControlWindowToSend(session_window); + flow_controller_.UpdateReceiveWindowSize(session_window); + // Inform all existing streams about the new window. + for (auto const& kv : stream_map_) { + kv.second->UpdateReceiveWindowSize(stream_window); + } + if (!QuicVersionUsesCryptoFrames(transport_version())) { + GetMutableCryptoStream()->UpdateReceiveWindowSize(stream_window); + } +} + +void QuicSession::HandleFrameOnNonexistentOutgoingStream( + QuicStreamId stream_id) { + QUICHE_DCHECK(!IsClosedStream(stream_id)); + // Received a frame for a locally-created stream that is not currently + // active. This is an error. + if (VersionHasIetfQuicFrames(transport_version())) { + connection()->CloseConnection( + QUIC_HTTP_STREAM_WRONG_DIRECTION, "Data for nonexistent stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Data for nonexistent stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); +} + +void QuicSession::HandleRstOnValidNonexistentStream( + const QuicRstStreamFrame& frame) { + // If the stream is neither originally in active streams nor created in + // GetOrCreateStream(), it could be a closed stream in which case its + // final received byte offset need to be updated. + if (IsClosedStream(frame.stream_id)) { + // The RST frame contains the final byte offset for the stream: we can now + // update the connection level flow controller if needed. + OnFinalByteOffsetReceived(frame.stream_id, frame.byte_offset); + } +} + +void QuicSession::OnNewStreamFlowControlWindow(QuicStreamOffset new_window) { + QUICHE_DCHECK(version().UsesQuicCrypto()); + QUIC_DVLOG(1) << ENDPOINT << "OnNewStreamFlowControlWindow " << new_window; + if (new_window < kMinimumFlowControlSendWindow) { + QUIC_LOG_FIRST_N(ERROR, 1) + << "Peer sent us an invalid stream flow control send window: " + << new_window << ", below minimum: " << kMinimumFlowControlSendWindow; + connection_->CloseConnection( + QUIC_FLOW_CONTROL_INVALID_WINDOW, "New stream window too low", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + // Inform all existing streams about the new window. + for (auto const& kv : stream_map_) { + QUIC_DVLOG(1) << ENDPOINT << "Informing stream " << kv.first + << " of new stream flow control window " << new_window; + if (!kv.second->MaybeConfigSendWindowOffset( + new_window, /* was_zero_rtt_rejected = */ false)) { + return; + } + } + if (!QuicVersionUsesCryptoFrames(transport_version())) { + QUIC_DVLOG(1) + << ENDPOINT + << "Informing crypto stream of new stream flow control window " + << new_window; + GetMutableCryptoStream()->MaybeConfigSendWindowOffset( + new_window, /* was_zero_rtt_rejected = */ false); + } +} + +void QuicSession::OnNewStreamUnidirectionalFlowControlWindow( + QuicStreamOffset new_window) { + QUICHE_DCHECK_EQ(connection_->version().handshake_protocol, PROTOCOL_TLS1_3); + QUIC_DVLOG(1) << ENDPOINT << "OnNewStreamUnidirectionalFlowControlWindow " + << new_window; + // Inform all existing outgoing unidirectional streams about the new window. + for (auto const& kv : stream_map_) { + const QuicStreamId id = kv.first; + if (!version().HasIetfQuicFrames()) { + if (kv.second->type() == BIDIRECTIONAL) { + continue; + } + } else { + if (QuicUtils::IsBidirectionalStreamId(id, version())) { + continue; + } + } + if (!QuicUtils::IsOutgoingStreamId(connection_->version(), id, + perspective())) { + continue; + } + QUIC_DVLOG(1) << ENDPOINT << "Informing unidirectional stream " << id + << " of new stream flow control window " << new_window; + if (!kv.second->MaybeConfigSendWindowOffset(new_window, + was_zero_rtt_rejected_)) { + return; + } + } +} + +void QuicSession::OnNewStreamOutgoingBidirectionalFlowControlWindow( + QuicStreamOffset new_window) { + QUICHE_DCHECK_EQ(connection_->version().handshake_protocol, PROTOCOL_TLS1_3); + QUIC_DVLOG(1) << ENDPOINT + << "OnNewStreamOutgoingBidirectionalFlowControlWindow " + << new_window; + // Inform all existing outgoing bidirectional streams about the new window. + for (auto const& kv : stream_map_) { + const QuicStreamId id = kv.first; + if (!version().HasIetfQuicFrames()) { + if (kv.second->type() != BIDIRECTIONAL) { + continue; + } + } else { + if (!QuicUtils::IsBidirectionalStreamId(id, version())) { + continue; + } + } + if (!QuicUtils::IsOutgoingStreamId(connection_->version(), id, + perspective())) { + continue; + } + QUIC_DVLOG(1) << ENDPOINT << "Informing outgoing bidirectional stream " + << id << " of new stream flow control window " << new_window; + if (!kv.second->MaybeConfigSendWindowOffset(new_window, + was_zero_rtt_rejected_)) { + return; + } + } +} + +void QuicSession::OnNewStreamIncomingBidirectionalFlowControlWindow( + QuicStreamOffset new_window) { + QUICHE_DCHECK_EQ(connection_->version().handshake_protocol, PROTOCOL_TLS1_3); + QUIC_DVLOG(1) << ENDPOINT + << "OnNewStreamIncomingBidirectionalFlowControlWindow " + << new_window; + // Inform all existing incoming bidirectional streams about the new window. + for (auto const& kv : stream_map_) { + const QuicStreamId id = kv.first; + if (!version().HasIetfQuicFrames()) { + if (kv.second->type() != BIDIRECTIONAL) { + continue; + } + } else { + if (!QuicUtils::IsBidirectionalStreamId(id, version())) { + continue; + } + } + if (QuicUtils::IsOutgoingStreamId(connection_->version(), id, + perspective())) { + continue; + } + QUIC_DVLOG(1) << ENDPOINT << "Informing incoming bidirectional stream " + << id << " of new stream flow control window " << new_window; + if (!kv.second->MaybeConfigSendWindowOffset(new_window, + was_zero_rtt_rejected_)) { + return; + } + } +} + +void QuicSession::OnNewSessionFlowControlWindow(QuicStreamOffset new_window) { + QUIC_DVLOG(1) << ENDPOINT << "OnNewSessionFlowControlWindow " << new_window; + + if (was_zero_rtt_rejected_ && new_window < flow_controller_.bytes_sent()) { + std::string error_details = absl::StrCat( + "Server rejected 0-RTT. Aborting because the client received session " + "flow control send window: ", + new_window, + ", which is below currently used: ", flow_controller_.bytes_sent()); + QUIC_LOG(ERROR) << error_details; + connection_->CloseConnection( + QUIC_ZERO_RTT_UNRETRANSMITTABLE, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + if (!connection_->version().AllowsLowFlowControlLimits() && + new_window < kMinimumFlowControlSendWindow) { + std::string error_details = absl::StrCat( + "Peer sent us an invalid session flow control send window: ", + new_window, ", below minimum: ", kMinimumFlowControlSendWindow); + QUIC_LOG_FIRST_N(ERROR, 1) << error_details; + connection_->CloseConnection( + QUIC_FLOW_CONTROL_INVALID_WINDOW, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + if (perspective_ == Perspective::IS_CLIENT && + new_window < flow_controller_.send_window_offset()) { + // The client receives a lower limit than remembered, violating + // https://tools.ietf.org/html/draft-ietf-quic-transport-27#section-7.3.1 + std::string error_details = absl::StrCat( + was_zero_rtt_rejected_ ? "Server rejected 0-RTT, aborting because " + : "", + "new session max data ", new_window, + " decreases current limit: ", flow_controller_.send_window_offset()); + QUIC_LOG(ERROR) << error_details; + connection_->CloseConnection( + was_zero_rtt_rejected_ ? QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED + : QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, + error_details, ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + + flow_controller_.UpdateSendWindowOffset(new_window); +} + +bool QuicSession::OnNewDecryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr decrypter, + bool set_alternative_decrypter, bool latch_once_used) { + if (connection_->version().handshake_protocol == PROTOCOL_TLS1_3 && + !connection()->framer().HasEncrypterOfEncryptionLevel( + QuicUtils::GetEncryptionLevelToSendAckofSpace( + QuicUtils::GetPacketNumberSpace(level)))) { + // This should never happen because connection should never decrypt a packet + // while an ACK for it cannot be encrypted. + return false; + } + if (connection()->version().KnowsWhichDecrypterToUse()) { + connection()->InstallDecrypter(level, std::move(decrypter)); + return true; + } + if (set_alternative_decrypter) { + connection()->SetAlternativeDecrypter(level, std::move(decrypter), + latch_once_used); + return true; + } + connection()->SetDecrypter(level, std::move(decrypter)); + return true; +} + +void QuicSession::OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) { + connection()->SetEncrypter(level, std::move(encrypter)); + if (connection_->version().handshake_protocol != PROTOCOL_TLS1_3) { + return; + } + + bool reset_encryption_level = false; + if (IsEncryptionEstablished() && level == ENCRYPTION_HANDSHAKE) { + // ENCRYPTION_HANDSHAKE keys are only used for the handshake. If + // ENCRYPTION_ZERO_RTT keys exist, it is possible for a client to send + // stream data, which must not be sent at the ENCRYPTION_HANDSHAKE level. + // Therefore, we avoid setting the default encryption level to + // ENCRYPTION_HANDSHAKE. + reset_encryption_level = true; + } + QUIC_DVLOG(1) << ENDPOINT << "Set default encryption level to " << level; + connection()->SetDefaultEncryptionLevel(level); + if (reset_encryption_level) { + connection()->SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + } + QUIC_BUG_IF(quic_bug_12435_7, + IsEncryptionEstablished() && + (connection()->encryption_level() == ENCRYPTION_INITIAL || + connection()->encryption_level() == ENCRYPTION_HANDSHAKE)) + << "Encryption is established, but the encryption level " << level + << " does not support sending stream data"; +} + +void QuicSession::SetDefaultEncryptionLevel(EncryptionLevel level) { + QUICHE_DCHECK_EQ(PROTOCOL_QUIC_CRYPTO, + connection_->version().handshake_protocol); + QUIC_DVLOG(1) << ENDPOINT << "Set default encryption level to " << level; + connection()->SetDefaultEncryptionLevel(level); + + switch (level) { + case ENCRYPTION_INITIAL: + break; + case ENCRYPTION_ZERO_RTT: + if (perspective() == Perspective::IS_CLIENT) { + // Retransmit old 0-RTT data (if any) with the new 0-RTT keys, since + // they can't be decrypted by the server. + connection_->MarkZeroRttPacketsForRetransmission(0); + if (!connection_->framer().is_processing_packet()) { + // TODO(fayang): consider removing this OnCanWrite call. + // Given any streams blocked by encryption a chance to write. + QUIC_CODE_COUNT( + quic_session_on_can_write_set_default_encryption_level); + OnCanWrite(); + } + } + break; + case ENCRYPTION_HANDSHAKE: + break; + case ENCRYPTION_FORWARD_SECURE: + QUIC_BUG_IF(quic_bug_12435_8, !config_.negotiated()) + << ENDPOINT << "Handshake confirmed without parameter negotiation."; + connection()->mutable_stats().handshake_completion_time = + connection()->clock()->ApproximateNow(); + break; + default: + QUIC_BUG(quic_bug_10866_7) << "Unknown encryption level: " << level; + } +} + +void QuicSession::OnTlsHandshakeComplete() { + QUICHE_DCHECK_EQ(PROTOCOL_TLS1_3, connection_->version().handshake_protocol); + QUIC_BUG_IF(quic_bug_12435_9, + !GetCryptoStream()->crypto_negotiated_params().cipher_suite) + << ENDPOINT << "Handshake completes without cipher suite negotiation."; + QUIC_BUG_IF(quic_bug_12435_10, !config_.negotiated()) + << ENDPOINT << "Handshake completes without parameter negotiation."; + connection()->mutable_stats().handshake_completion_time = + connection()->clock()->ApproximateNow(); + if (connection()->version().UsesTls() && + perspective_ == Perspective::IS_SERVER) { + // Server sends HANDSHAKE_DONE to signal confirmation of the handshake + // to the client. + control_frame_manager_.WriteOrBufferHandshakeDone(); + if (connection()->version().HasIetfQuicFrames()) { + MaybeSendAddressToken(); + } + } +} + +bool QuicSession::MaybeSendAddressToken() { + QUICHE_DCHECK(perspective_ == Perspective::IS_SERVER && + connection()->version().HasIetfQuicFrames()); + absl::optional cached_network_params = + GenerateCachedNetworkParameters(); + + std::string address_token = GetCryptoStream()->GetAddressToken( + cached_network_params.has_value() ? &cached_network_params.value() + : nullptr); + if (address_token.empty()) { + return false; + } + const size_t buf_len = address_token.length() + 1; + auto buffer = std::make_unique(buf_len); + QuicDataWriter writer(buf_len, buffer.get()); + // Add |kAddressTokenPrefix| for token sent in NEW_TOKEN frame. + writer.WriteUInt8(kAddressTokenPrefix); + writer.WriteBytes(address_token.data(), address_token.length()); + control_frame_manager_.WriteOrBufferNewToken( + absl::string_view(buffer.get(), buf_len)); + if (cached_network_params.has_value()) { + connection()->OnSendConnectionState(*cached_network_params); + } + return true; +} + +void QuicSession::DiscardOldDecryptionKey(EncryptionLevel level) { + if (!connection()->version().KnowsWhichDecrypterToUse()) { + return; + } + connection()->RemoveDecrypter(level); +} + +void QuicSession::DiscardOldEncryptionKey(EncryptionLevel level) { + QUIC_DLOG(INFO) << ENDPOINT << "Discarding " << level << " keys"; + if (connection()->version().handshake_protocol == PROTOCOL_TLS1_3) { + connection()->RemoveEncrypter(level); + } + switch (level) { + case ENCRYPTION_INITIAL: + NeuterUnencryptedData(); + break; + case ENCRYPTION_HANDSHAKE: + NeuterHandshakeData(); + break; + case ENCRYPTION_ZERO_RTT: + break; + case ENCRYPTION_FORWARD_SECURE: + QUIC_BUG(quic_bug_10866_8) + << ENDPOINT << "Discarding 1-RTT keys is not allowed"; + break; + default: + QUIC_BUG(quic_bug_10866_9) + << ENDPOINT + << "Cannot discard keys for unknown encryption level: " << level; + } +} + +void QuicSession::NeuterHandshakeData() { + GetMutableCryptoStream()->NeuterStreamDataOfEncryptionLevel( + ENCRYPTION_HANDSHAKE); + connection()->OnHandshakeComplete(); +} + +void QuicSession::OnZeroRttRejected(int reason) { + was_zero_rtt_rejected_ = true; + connection_->MarkZeroRttPacketsForRetransmission(reason); + if (connection_->encryption_level() == ENCRYPTION_FORWARD_SECURE) { + QUIC_BUG(quic_bug_10866_10) + << "1-RTT keys already available when 0-RTT is rejected."; + connection_->CloseConnection( + QUIC_INTERNAL_ERROR, + "1-RTT keys already available when 0-RTT is rejected.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + } +} + +bool QuicSession::FillTransportParameters(TransportParameters* params) { + if (version().UsesTls()) { + if (perspective() == Perspective::IS_SERVER) { + config_.SetOriginalConnectionIdToSend( + connection_->GetOriginalDestinationConnectionId()); + config_.SetInitialSourceConnectionIdToSend(connection_->connection_id()); + } else { + config_.SetInitialSourceConnectionIdToSend( + connection_->client_connection_id()); + } + } + return config_.FillTransportParameters(params); +} + +QuicErrorCode QuicSession::ProcessTransportParameters( + const TransportParameters& params, bool is_resumption, + std::string* error_details) { + return config_.ProcessTransportParameters(params, is_resumption, + error_details); +} + +void QuicSession::OnHandshakeCallbackDone() { + if (!connection_->connected()) { + return; + } + + if (!connection()->is_processing_packet()) { + connection()->MaybeProcessUndecryptablePackets(); + } +} + +bool QuicSession::PacketFlusherAttached() const { + QUICHE_DCHECK(connection_->connected()); + return connection()->packet_creator().PacketFlusherAttached(); +} + +void QuicSession::OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& /*message*/) {} + +void QuicSession::OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& /*message*/) {} + +void QuicSession::RegisterStreamPriority(QuicStreamId id, bool is_static, + const QuicStreamPriority& priority) { + write_blocked_streams()->RegisterStream(id, is_static, priority); +} + +void QuicSession::UnregisterStreamPriority(QuicStreamId id) { + write_blocked_streams()->UnregisterStream(id); +} + +void QuicSession::UpdateStreamPriority(QuicStreamId id, + const QuicStreamPriority& new_priority) { + write_blocked_streams()->UpdateStreamPriority(id, new_priority); +} + +void QuicSession::ActivateStream(std::unique_ptr stream) { + const bool should_keep_alive = ShouldKeepConnectionAlive(); + QuicStreamId stream_id = stream->id(); + bool is_static = stream->is_static(); + QUIC_DVLOG(1) << ENDPOINT << "num_streams: " << stream_map_.size() + << ". activating stream " << stream_id; + QUICHE_DCHECK(!stream_map_.contains(stream_id)); + stream_map_[stream_id] = std::move(stream); + if (is_static) { + ++num_static_streams_; + return; + } + if (!VersionHasIetfQuicFrames(transport_version())) { + // Do not inform stream ID manager of static streams. + stream_id_manager_.ActivateStream( + /*is_incoming=*/IsIncomingStream(stream_id)); + } + if (perspective() == Perspective::IS_CLIENT && + connection()->multi_port_stats() != nullptr && !should_keep_alive && + ShouldKeepConnectionAlive()) { + connection()->MaybeProbeMultiPortPath(); + } +} + +QuicStreamId QuicSession::GetNextOutgoingBidirectionalStreamId() { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.GetNextOutgoingBidirectionalStreamId(); + } + return stream_id_manager_.GetNextOutgoingStreamId(); +} + +QuicStreamId QuicSession::GetNextOutgoingUnidirectionalStreamId() { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.GetNextOutgoingUnidirectionalStreamId(); + } + return stream_id_manager_.GetNextOutgoingStreamId(); +} + +bool QuicSession::CanOpenNextOutgoingBidirectionalStream() { + if (liveness_testing_in_progress_) { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective()); + QUIC_CODE_COUNT( + quic_client_fails_to_create_stream_liveness_testing_in_progress); + return false; + } + if (!VersionHasIetfQuicFrames(transport_version())) { + if (!stream_id_manager_.CanOpenNextOutgoingStream()) { + return false; + } + } else { + if (!ietf_streamid_manager_.CanOpenNextOutgoingBidirectionalStream()) { + QUIC_CODE_COUNT( + quic_fails_to_create_stream_close_too_many_streams_created); + if (is_configured_) { + // Send STREAM_BLOCKED after config negotiated. + control_frame_manager_.WriteOrBufferStreamsBlocked( + ietf_streamid_manager_.max_outgoing_bidirectional_streams(), + /*unidirectional=*/false); + } + return false; + } + } + if (perspective() == Perspective::IS_CLIENT && + connection_->MaybeTestLiveness()) { + // Now is relatively close to the idle timeout having the risk that requests + // could be discarded at the server. + liveness_testing_in_progress_ = true; + QUIC_CODE_COUNT(quic_client_fails_to_create_stream_close_to_idle_timeout); + return false; + } + return true; +} + +bool QuicSession::CanOpenNextOutgoingUnidirectionalStream() { + if (!VersionHasIetfQuicFrames(transport_version())) { + return stream_id_manager_.CanOpenNextOutgoingStream(); + } + if (ietf_streamid_manager_.CanOpenNextOutgoingUnidirectionalStream()) { + return true; + } + if (is_configured_) { + // Send STREAM_BLOCKED after config negotiated. + control_frame_manager_.WriteOrBufferStreamsBlocked( + ietf_streamid_manager_.max_outgoing_unidirectional_streams(), + /*unidirectional=*/true); + } + return false; +} + +QuicStreamCount QuicSession::GetAdvertisedMaxIncomingBidirectionalStreams() + const { + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + return ietf_streamid_manager_.advertised_max_incoming_bidirectional_streams(); +} + +QuicStream* QuicSession::GetOrCreateStream(const QuicStreamId stream_id) { + QUICHE_DCHECK(!pending_stream_map_.contains(stream_id)); + if (QuicUtils::IsCryptoStreamId(transport_version(), stream_id)) { + return GetMutableCryptoStream(); + } + + StreamMap::iterator it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second->IsZombie() ? nullptr : it->second.get(); + } + + if (IsClosedStream(stream_id)) { + return nullptr; + } + + if (!IsIncomingStream(stream_id)) { + HandleFrameOnNonexistentOutgoingStream(stream_id); + return nullptr; + } + + // TODO(fkastenholz): If we are creating a new stream and we have sent a + // goaway, we should ignore the stream creation. Need to add code to A) test + // if goaway was sent ("if (transport_goaway_sent_)") and B) reject stream + // creation ("return nullptr") + + if (!MaybeIncreaseLargestPeerStreamId(stream_id)) { + return nullptr; + } + + if (!VersionHasIetfQuicFrames(transport_version()) && + !stream_id_manager_.CanOpenIncomingStream()) { + // Refuse to open the stream. + ResetStream(stream_id, QUIC_REFUSED_STREAM); + return nullptr; + } + + return CreateIncomingStream(stream_id); +} + +void QuicSession::StreamDraining(QuicStreamId stream_id, bool unidirectional) { + QUICHE_DCHECK(stream_map_.contains(stream_id)); + QUIC_DVLOG(1) << ENDPOINT << "Stream " << stream_id << " is draining"; + if (VersionHasIetfQuicFrames(transport_version())) { + ietf_streamid_manager_.OnStreamClosed(stream_id); + } else { + stream_id_manager_.OnStreamClosed( + /*is_incoming=*/IsIncomingStream(stream_id)); + } + ++num_draining_streams_; + if (!IsIncomingStream(stream_id)) { + ++num_outgoing_draining_streams_; + if (!VersionHasIetfQuicFrames(transport_version())) { + OnCanCreateNewOutgoingStream(unidirectional); + } + } +} + +bool QuicSession::MaybeIncreaseLargestPeerStreamId( + const QuicStreamId stream_id) { + if (VersionHasIetfQuicFrames(transport_version())) { + std::string error_details; + if (ietf_streamid_manager_.MaybeIncreaseLargestPeerStreamId( + stream_id, &error_details)) { + return true; + } + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + if (!stream_id_manager_.MaybeIncreaseLargestPeerStreamId(stream_id)) { + connection()->CloseConnection( + QUIC_TOO_MANY_AVAILABLE_STREAMS, + absl::StrCat(stream_id, " exceeds available streams ", + stream_id_manager_.MaxAvailableStreams()), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; + } + return true; +} + +bool QuicSession::ShouldYield(QuicStreamId stream_id) { + if (stream_id == currently_writing_stream_id_) { + return false; + } + return write_blocked_streams()->ShouldYield(stream_id); +} + +PendingStream* QuicSession::GetOrCreatePendingStream(QuicStreamId stream_id) { + auto it = pending_stream_map_.find(stream_id); + if (it != pending_stream_map_.end()) { + return it->second.get(); + } + + if (IsClosedStream(stream_id) || + !MaybeIncreaseLargestPeerStreamId(stream_id)) { + return nullptr; + } + + auto pending = std::make_unique(stream_id, this); + PendingStream* unowned_pending = pending.get(); + pending_stream_map_[stream_id] = std::move(pending); + return unowned_pending; +} + +void QuicSession::set_largest_peer_created_stream_id( + QuicStreamId largest_peer_created_stream_id) { + QUICHE_DCHECK(!VersionHasIetfQuicFrames(transport_version())); + stream_id_manager_.set_largest_peer_created_stream_id( + largest_peer_created_stream_id); +} + +QuicStreamId QuicSession::GetLargestPeerCreatedStreamId( + bool unidirectional) const { + // This method is only used in IETF QUIC. + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + return ietf_streamid_manager_.GetLargestPeerCreatedStreamId(unidirectional); +} + +void QuicSession::DeleteConnection() { + if (connection_) { + delete connection_; + connection_ = nullptr; + } +} + +bool QuicSession::MaybeSetStreamPriority(QuicStreamId stream_id, + const QuicStreamPriority& priority) { + auto active_stream = stream_map_.find(stream_id); + if (active_stream != stream_map_.end()) { + active_stream->second->SetPriority(priority); + return true; + } + + return false; +} + +bool QuicSession::IsClosedStream(QuicStreamId id) { + QUICHE_DCHECK_NE(QuicUtils::GetInvalidStreamId(transport_version()), id); + if (IsOpenStream(id)) { + // Stream is active + return false; + } + + if (VersionHasIetfQuicFrames(transport_version())) { + return !ietf_streamid_manager_.IsAvailableStream(id); + } + + return !stream_id_manager_.IsAvailableStream(id); +} + +bool QuicSession::IsOpenStream(QuicStreamId id) { + QUICHE_DCHECK_NE(QuicUtils::GetInvalidStreamId(transport_version()), id); + const StreamMap::iterator it = stream_map_.find(id); + if (it != stream_map_.end()) { + return !it->second->IsZombie(); + } + if (pending_stream_map_.contains(id) || + QuicUtils::IsCryptoStreamId(transport_version(), id)) { + // Stream is active + return true; + } + return false; +} + +bool QuicSession::IsStaticStream(QuicStreamId id) const { + auto it = stream_map_.find(id); + if (it == stream_map_.end()) { + return false; + } + return it->second->is_static(); +} + +size_t QuicSession::GetNumActiveStreams() const { + QUICHE_DCHECK_GE( + static_cast(stream_map_.size()), + num_static_streams_ + num_draining_streams_ + num_zombie_streams_); + return stream_map_.size() - num_draining_streams_ - num_static_streams_ - + num_zombie_streams_; +} + +void QuicSession::MarkConnectionLevelWriteBlocked(QuicStreamId id) { + if (GetOrCreateStream(id) == nullptr) { + QUIC_BUG(quic_bug_10866_11) + << "Marking unknown stream " << id << " blocked."; + QUIC_LOG_FIRST_N(ERROR, 2) << QuicStackTrace(); + } + + QUIC_DVLOG(1) << ENDPOINT << "Adding stream " << id + << " to write-blocked list"; + + write_blocked_streams_->AddStream(id); +} + +bool QuicSession::HasDataToWrite() const { + return write_blocked_streams_->HasWriteBlockedSpecialStream() || + write_blocked_streams_->HasWriteBlockedDataStreams() || + connection_->HasQueuedData() || + !streams_with_pending_retransmission_.empty() || + control_frame_manager_.WillingToWrite(); +} + +void QuicSession::OnAckNeedsRetransmittableFrame() { + flow_controller_.SendWindowUpdate(); +} + +void QuicSession::SendAckFrequency(const QuicAckFrequencyFrame& frame) { + control_frame_manager_.WriteOrBufferAckFrequency(frame); +} + +void QuicSession::SendNewConnectionId(const QuicNewConnectionIdFrame& frame) { + // Count NEW_CONNECTION_ID frames sent to client. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 1, 6); + control_frame_manager_.WriteOrBufferNewConnectionId( + frame.connection_id, frame.sequence_number, frame.retire_prior_to, + frame.stateless_reset_token); +} + +void QuicSession::SendRetireConnectionId(uint64_t sequence_number) { + // Count RETIRE_CONNECTION_ID frames sent to client. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 2, 6); + control_frame_manager_.WriteOrBufferRetireConnectionId(sequence_number); +} + +bool QuicSession::MaybeReserveConnectionId( + const QuicConnectionId& server_connection_id) { + if (visitor_) { + return visitor_->TryAddNewConnectionId( + connection_->GetOneActiveServerConnectionId(), server_connection_id); + } + return true; +} + +void QuicSession::OnServerConnectionIdRetired( + const QuicConnectionId& server_connection_id) { + if (visitor_) { + visitor_->OnConnectionIdRetired(server_connection_id); + } +} + +bool QuicSession::IsConnectionFlowControlBlocked() const { + return flow_controller_.IsBlocked(); +} + +bool QuicSession::IsStreamFlowControlBlocked() { + for (auto const& kv : stream_map_) { + if (kv.second->IsFlowControlBlocked()) { + return true; + } + } + if (!QuicVersionUsesCryptoFrames(transport_version()) && + GetMutableCryptoStream()->IsFlowControlBlocked()) { + return true; + } + return false; +} + +size_t QuicSession::MaxAvailableBidirectionalStreams() const { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.GetMaxAllowdIncomingBidirectionalStreams(); + } + return stream_id_manager_.MaxAvailableStreams(); +} + +size_t QuicSession::MaxAvailableUnidirectionalStreams() const { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.GetMaxAllowdIncomingUnidirectionalStreams(); + } + return stream_id_manager_.MaxAvailableStreams(); +} + +bool QuicSession::IsIncomingStream(QuicStreamId id) const { + if (VersionHasIetfQuicFrames(transport_version())) { + return !QuicUtils::IsOutgoingStreamId(version(), id, perspective_); + } + return stream_id_manager_.IsIncomingStream(id); +} + +void QuicSession::MaybeCloseZombieStream(QuicStreamId id) { + auto it = stream_map_.find(id); + if (it == stream_map_.end()) { + return; + } + --num_zombie_streams_; + closed_streams_.push_back(std::move(it->second)); + stream_map_.erase(it); + + if (!closed_streams_clean_up_alarm_->IsSet()) { + closed_streams_clean_up_alarm_->Set(connection_->clock()->ApproximateNow()); + } + // Do not retransmit data of a closed stream. + streams_with_pending_retransmission_.erase(id); + connection_->QuicBugIfHasPendingFrames(id); +} + +QuicStream* QuicSession::GetStream(QuicStreamId id) const { + auto active_stream = stream_map_.find(id); + if (active_stream != stream_map_.end()) { + return active_stream->second.get(); + } + + if (QuicUtils::IsCryptoStreamId(transport_version(), id)) { + return const_cast(GetCryptoStream()); + } + + return nullptr; +} + +QuicStream* QuicSession::GetActiveStream(QuicStreamId id) const { + auto stream = stream_map_.find(id); + if (stream != stream_map_.end() && !stream->second->is_static()) { + return stream->second.get(); + } + return nullptr; +} + +bool QuicSession::OnFrameAcked(const QuicFrame& frame, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) { + if (frame.type == MESSAGE_FRAME) { + OnMessageAcked(frame.message_frame->message_id, receive_timestamp); + return true; + } + if (frame.type == CRYPTO_FRAME) { + return GetMutableCryptoStream()->OnCryptoFrameAcked(*frame.crypto_frame, + ack_delay_time); + } + if (frame.type != STREAM_FRAME) { + return control_frame_manager_.OnControlFrameAcked(frame); + } + bool new_stream_data_acked = false; + QuicStream* stream = GetStream(frame.stream_frame.stream_id); + // Stream can already be reset when sent frame gets acked. + if (stream != nullptr) { + QuicByteCount newly_acked_length = 0; + new_stream_data_acked = stream->OnStreamFrameAcked( + frame.stream_frame.offset, frame.stream_frame.data_length, + frame.stream_frame.fin, ack_delay_time, receive_timestamp, + &newly_acked_length); + if (!stream->HasPendingRetransmission()) { + streams_with_pending_retransmission_.erase(stream->id()); + } + } + return new_stream_data_acked; +} + +void QuicSession::OnStreamFrameRetransmitted(const QuicStreamFrame& frame) { + QuicStream* stream = GetStream(frame.stream_id); + if (stream == nullptr) { + QUIC_BUG(quic_bug_10866_12) + << "Stream: " << frame.stream_id << " is closed when " << frame + << " is retransmitted."; + connection()->CloseConnection( + QUIC_INTERNAL_ERROR, "Attempt to retransmit frame of a closed stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + stream->OnStreamFrameRetransmitted(frame.offset, frame.data_length, + frame.fin); +} + +void QuicSession::OnFrameLost(const QuicFrame& frame) { + if (frame.type == MESSAGE_FRAME) { + OnMessageLost(frame.message_frame->message_id); + return; + } + if (frame.type == CRYPTO_FRAME) { + GetMutableCryptoStream()->OnCryptoFrameLost(frame.crypto_frame); + return; + } + if (frame.type != STREAM_FRAME) { + control_frame_manager_.OnControlFrameLost(frame); + return; + } + QuicStream* stream = GetStream(frame.stream_frame.stream_id); + if (stream == nullptr) { + return; + } + stream->OnStreamFrameLost(frame.stream_frame.offset, + frame.stream_frame.data_length, + frame.stream_frame.fin); + if (stream->HasPendingRetransmission() && + !streams_with_pending_retransmission_.contains( + frame.stream_frame.stream_id)) { + streams_with_pending_retransmission_.insert( + std::make_pair(frame.stream_frame.stream_id, true)); + } +} + +bool QuicSession::RetransmitFrames(const QuicFrames& frames, + TransmissionType type) { + QuicConnection::ScopedPacketFlusher retransmission_flusher(connection_); + for (const QuicFrame& frame : frames) { + if (frame.type == MESSAGE_FRAME) { + // Do not retransmit MESSAGE frames. + continue; + } + if (frame.type == CRYPTO_FRAME) { + if (!GetMutableCryptoStream()->RetransmitData(frame.crypto_frame, type)) { + return false; + } + continue; + } + if (frame.type != STREAM_FRAME) { + if (!control_frame_manager_.RetransmitControlFrame(frame, type)) { + return false; + } + continue; + } + QuicStream* stream = GetStream(frame.stream_frame.stream_id); + if (stream != nullptr && + !stream->RetransmitStreamData(frame.stream_frame.offset, + frame.stream_frame.data_length, + frame.stream_frame.fin, type)) { + return false; + } + } + return true; +} + +bool QuicSession::IsFrameOutstanding(const QuicFrame& frame) const { + if (frame.type == MESSAGE_FRAME) { + return false; + } + if (frame.type == CRYPTO_FRAME) { + return GetCryptoStream()->IsFrameOutstanding( + frame.crypto_frame->level, frame.crypto_frame->offset, + frame.crypto_frame->data_length); + } + if (frame.type != STREAM_FRAME) { + return control_frame_manager_.IsControlFrameOutstanding(frame); + } + QuicStream* stream = GetStream(frame.stream_frame.stream_id); + return stream != nullptr && + stream->IsStreamFrameOutstanding(frame.stream_frame.offset, + frame.stream_frame.data_length, + frame.stream_frame.fin); +} + +bool QuicSession::HasUnackedCryptoData() const { + const QuicCryptoStream* crypto_stream = GetCryptoStream(); + return crypto_stream->IsWaitingForAcks() || crypto_stream->HasBufferedData(); +} + +bool QuicSession::HasUnackedStreamData() const { + for (const auto& it : stream_map_) { + if (it.second->IsWaitingForAcks()) { + return true; + } + } + return false; +} + +HandshakeState QuicSession::GetHandshakeState() const { + return GetCryptoStream()->GetHandshakeState(); +} + +WriteStreamDataResult QuicSession::WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + QuicStream* stream = GetStream(id); + if (stream == nullptr) { + // This causes the connection to be closed because of failed to serialize + // packet. + QUIC_BUG(quic_bug_10866_13) + << "Stream " << id << " does not exist when trying to write data." + << " version:" << transport_version(); + return STREAM_MISSING; + } + if (stream->WriteStreamData(offset, data_length, writer)) { + return WRITE_SUCCESS; + } + return WRITE_FAILED; +} + +bool QuicSession::WriteCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + return GetMutableCryptoStream()->WriteCryptoFrame(level, offset, data_length, + writer); +} + +StatelessResetToken QuicSession::GetStatelessResetToken() const { + return QuicUtils::GenerateStatelessResetToken(connection_->connection_id()); +} + +bool QuicSession::CanWriteStreamData() const { + // Don't write stream data if there are queued data packets. + if (connection_->HasQueuedPackets()) { + return false; + } + // Immediately write handshake data. + if (HasPendingHandshake()) { + return true; + } + return connection_->CanWrite(HAS_RETRANSMITTABLE_DATA); +} + +bool QuicSession::RetransmitLostData() { + QuicConnection::ScopedPacketFlusher retransmission_flusher(connection_); + // Retransmit crypto data first. + bool uses_crypto_frames = QuicVersionUsesCryptoFrames(transport_version()); + QuicCryptoStream* crypto_stream = GetMutableCryptoStream(); + if (uses_crypto_frames && crypto_stream->HasPendingCryptoRetransmission()) { + crypto_stream->WritePendingCryptoRetransmission(); + } + // Retransmit crypto data in stream 1 frames (version < 47). + if (!uses_crypto_frames && + streams_with_pending_retransmission_.contains( + QuicUtils::GetCryptoStreamId(transport_version()))) { + // Retransmit crypto data first. + QuicStream* crypto_stream = + GetStream(QuicUtils::GetCryptoStreamId(transport_version())); + crypto_stream->OnCanWrite(); + QUICHE_DCHECK(CheckStreamWriteBlocked(crypto_stream)); + if (crypto_stream->HasPendingRetransmission()) { + // Connection is write blocked. + return false; + } else { + streams_with_pending_retransmission_.erase( + QuicUtils::GetCryptoStreamId(transport_version())); + } + } + if (control_frame_manager_.HasPendingRetransmission()) { + control_frame_manager_.OnCanWrite(); + if (control_frame_manager_.HasPendingRetransmission()) { + return false; + } + } + while (!streams_with_pending_retransmission_.empty()) { + if (!CanWriteStreamData()) { + break; + } + // Retransmit lost data on headers and data streams. + const QuicStreamId id = streams_with_pending_retransmission_.begin()->first; + QuicStream* stream = GetStream(id); + if (stream != nullptr) { + stream->OnCanWrite(); + QUICHE_DCHECK(CheckStreamWriteBlocked(stream)); + if (stream->HasPendingRetransmission()) { + // Connection is write blocked. + break; + } else if (!streams_with_pending_retransmission_.empty() && + streams_with_pending_retransmission_.begin()->first == id) { + // Retransmit lost data may cause connection close. If this stream + // has not yet sent fin, a RST_STREAM will be sent and it will be + // removed from streams_with_pending_retransmission_. + streams_with_pending_retransmission_.pop_front(); + } + } else { + QUIC_BUG(quic_bug_10866_14) + << "Try to retransmit data of a closed stream"; + streams_with_pending_retransmission_.pop_front(); + } + } + + return streams_with_pending_retransmission_.empty(); +} + +void QuicSession::NeuterUnencryptedData() { + QuicCryptoStream* crypto_stream = GetMutableCryptoStream(); + crypto_stream->NeuterUnencryptedStreamData(); + if (!crypto_stream->HasPendingRetransmission() && + !QuicVersionUsesCryptoFrames(transport_version())) { + streams_with_pending_retransmission_.erase( + QuicUtils::GetCryptoStreamId(transport_version())); + } + connection_->NeuterUnencryptedPackets(); +} + +void QuicSession::SetTransmissionType(TransmissionType type) { + connection_->SetTransmissionType(type); +} + +MessageResult QuicSession::SendMessage( + absl::Span message) { + return SendMessage(message, /*flush=*/false); +} + +MessageResult QuicSession::SendMessage(quiche::QuicheMemSlice message) { + return SendMessage(absl::MakeSpan(&message, 1), /*flush=*/false); +} + +MessageResult QuicSession::SendMessage( + absl::Span message, bool flush) { + QUICHE_DCHECK(connection_->connected()) + << ENDPOINT << "Try to write messages when connection is closed."; + if (!IsEncryptionEstablished()) { + return {MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED, 0}; + } + QuicConnection::ScopedEncryptionLevelContext context( + connection(), GetEncryptionLevelToSendApplicationData()); + MessageStatus result = + connection_->SendMessage(last_message_id_ + 1, message, flush); + if (result == MESSAGE_STATUS_SUCCESS) { + return {result, ++last_message_id_}; + } + return {result, 0}; +} + +void QuicSession::OnMessageAcked(QuicMessageId message_id, + QuicTime /*receive_timestamp*/) { + QUIC_DVLOG(1) << ENDPOINT << "message " << message_id << " gets acked."; +} + +void QuicSession::OnMessageLost(QuicMessageId message_id) { + QUIC_DVLOG(1) << ENDPOINT << "message " << message_id + << " is considered lost"; +} + +void QuicSession::CleanUpClosedStreams() { closed_streams_.clear(); } + +QuicPacketLength QuicSession::GetCurrentLargestMessagePayload() const { + return connection_->GetCurrentLargestMessagePayload(); +} + +QuicPacketLength QuicSession::GetGuaranteedLargestMessagePayload() const { + return connection_->GetGuaranteedLargestMessagePayload(); +} + +QuicStreamId QuicSession::next_outgoing_bidirectional_stream_id() const { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.next_outgoing_bidirectional_stream_id(); + } + return stream_id_manager_.next_outgoing_stream_id(); +} + +QuicStreamId QuicSession::next_outgoing_unidirectional_stream_id() const { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.next_outgoing_unidirectional_stream_id(); + } + return stream_id_manager_.next_outgoing_stream_id(); +} + +bool QuicSession::OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) { + const bool allow_new_streams = + frame.unidirectional + ? ietf_streamid_manager_.MaybeAllowNewOutgoingUnidirectionalStreams( + frame.stream_count) + : ietf_streamid_manager_.MaybeAllowNewOutgoingBidirectionalStreams( + frame.stream_count); + if (allow_new_streams) { + OnCanCreateNewOutgoingStream(frame.unidirectional); + } + + return true; +} + +bool QuicSession::OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) { + std::string error_details; + if (ietf_streamid_manager_.OnStreamsBlockedFrame(frame, &error_details)) { + return true; + } + connection_->CloseConnection( + QUIC_STREAMS_BLOCKED_ERROR, error_details, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; +} + +size_t QuicSession::max_open_incoming_bidirectional_streams() const { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.GetMaxAllowdIncomingBidirectionalStreams(); + } + return stream_id_manager_.max_open_incoming_streams(); +} + +size_t QuicSession::max_open_incoming_unidirectional_streams() const { + if (VersionHasIetfQuicFrames(transport_version())) { + return ietf_streamid_manager_.GetMaxAllowdIncomingUnidirectionalStreams(); + } + return stream_id_manager_.max_open_incoming_streams(); +} + +std::vector::const_iterator QuicSession::SelectAlpn( + const std::vector& alpns) const { + const std::string alpn = AlpnForVersion(connection()->version()); + return std::find(alpns.cbegin(), alpns.cend(), alpn); +} + +void QuicSession::OnAlpnSelected(absl::string_view alpn) { + QUIC_DLOG(INFO) << (perspective() == Perspective::IS_SERVER ? "Server: " + : "Client: ") + << "ALPN selected: " << alpn; +} + +void QuicSession::NeuterCryptoDataOfEncryptionLevel(EncryptionLevel level) { + GetMutableCryptoStream()->NeuterStreamDataOfEncryptionLevel(level); +} + +void QuicSession::PerformActionOnActiveStreams( + std::function action) { + std::vector active_streams; + for (const auto& it : stream_map_) { + if (!it.second->is_static() && !it.second->IsZombie()) { + active_streams.push_back(it.second.get()); + } + } + + for (QuicStream* stream : active_streams) { + if (!action(stream)) { + return; + } + } +} + +void QuicSession::PerformActionOnActiveStreams( + std::function action) const { + for (const auto& it : stream_map_) { + if (!it.second->is_static() && !it.second->IsZombie() && + !action(it.second.get())) { + return; + } + } +} + +EncryptionLevel QuicSession::GetEncryptionLevelToSendApplicationData() const { + return connection_->framer().GetEncryptionLevelToSendApplicationData(); +} + +void QuicSession::ProcessAllPendingStreams() { + std::vector pending_streams; + pending_streams.reserve(pending_stream_map_.size()); + for (auto it = pending_stream_map_.cbegin(); it != pending_stream_map_.cend(); + ++it) { + pending_streams.push_back(it->second.get()); + } + for (auto* pending_stream : pending_streams) { + MaybeProcessPendingStream(pending_stream); + if (!connection()->connected()) { + return; + } + } +} + +void QuicSession::ValidatePath( + std::unique_ptr context, + std::unique_ptr result_delegate, + PathValidationReason reason) { + connection_->ValidatePath(std::move(context), std::move(result_delegate), + reason); +} + +bool QuicSession::HasPendingPathValidation() const { + return connection_->HasPendingPathValidation(); +} + +bool QuicSession::MigratePath(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicPacketWriter* writer, bool owns_writer) { + return connection_->MigratePath(self_address, peer_address, writer, + owns_writer); +} + +bool QuicSession::ValidateToken(absl::string_view token) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_SERVER); + if (GetQuicFlag(quic_reject_retry_token_in_initial_packet)) { + return false; + } + if (token.empty() || token[0] != kAddressTokenPrefix) { + // Validate the prefix for token received in NEW_TOKEN frame. + return false; + } + const bool valid = GetCryptoStream()->ValidateAddressToken( + absl::string_view(token.data() + 1, token.length() - 1)); + if (valid) { + const CachedNetworkParameters* cached_network_params = + GetCryptoStream()->PreviousCachedNetworkParams(); + if (cached_network_params != nullptr && + cached_network_params->timestamp() > 0) { + connection()->OnReceiveConnectionState(*cached_network_params); + } + } + return valid; +} + +void QuicSession::OnServerPreferredAddressAvailable( + const QuicSocketAddress& server_preferred_address) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + if (visitor_ != nullptr) { + visitor_->OnServerPreferredAddressAvailable(server_preferred_address); + } +} + +#undef ENDPOINT // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_session.h b/quiche/quic/core/quic_session.h new file mode 100644 index 000000000000..e1c1932ab1cc --- /dev/null +++ b/quiche/quic/core/quic_session.h @@ -0,0 +1,1036 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A QuicSession, which demuxes a single connection to individual streams. + +#ifndef QUICHE_QUIC_CORE_QUIC_SESSION_H_ +#define QUICHE_QUIC_CORE_QUIC_SESSION_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "quiche/quic/core/crypto/tls_connection.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/frames/quic_stop_sending_frame.h" +#include "quiche/quic/core/frames/quic_window_update_frame.h" +#include "quiche/quic/core/handshaker_delegate_interface.h" +#include "quiche/quic/core/legacy_quic_stream_id_manager.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_control_frame_manager.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_datagram_queue.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_write_blocked_list.h" +#include "quiche/quic/core/session_notifier_interface.h" +#include "quiche/quic/core/stream_delegate_interface.h" +#include "quiche/quic/core/uber_quic_stream_id_manager.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +class QuicCryptoStream; +class QuicFlowController; +class QuicStream; +class QuicStreamIdManager; + +namespace test { +class QuicSessionPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE QuicSession + : public QuicConnectionVisitorInterface, + public SessionNotifierInterface, + public QuicStreamFrameDataProducer, + public QuicStreamIdManager::DelegateInterface, + public HandshakerDelegateInterface, + public StreamDelegateInterface, + public QuicControlFrameManager::DelegateInterface { + public: + // An interface from the session to the entity owning the session. + // This lets the session notify its owner when the connection + // is closed, blocked, etc. + // TODO(danzh): split this visitor to separate visitors for client and server + // respectively as not all methods in this class are interesting to both + // perspectives. + class QUIC_EXPORT_PRIVATE Visitor { + public: + virtual ~Visitor() {} + + // Called when the connection is closed after the streams have been closed. + virtual void OnConnectionClosed(QuicConnectionId server_connection_id, + QuicErrorCode error, + const std::string& error_details, + ConnectionCloseSource source) = 0; + + // Called when the session has become write blocked. + virtual void OnWriteBlocked(QuicBlockedWriterInterface* blocked_writer) = 0; + + // Called when the session receives reset on a stream from the peer. + virtual void OnRstStreamReceived(const QuicRstStreamFrame& frame) = 0; + + // Called when the session receives a STOP_SENDING for a stream from the + // peer. + virtual void OnStopSendingReceived(const QuicStopSendingFrame& frame) = 0; + + // Called when on whether a NewConnectionId frame can been sent. + virtual bool TryAddNewConnectionId( + const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id) = 0; + + // Called when a ConnectionId has been retired. + virtual void OnConnectionIdRetired( + const QuicConnectionId& server_connection_id) = 0; + + virtual void OnServerPreferredAddressAvailable( + const QuicSocketAddress& /*server_preferred_address*/) = 0; + }; + + // Does not take ownership of |connection| or |visitor|. + QuicSession(QuicConnection* connection, Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicStreamCount num_expected_unidirectional_static_streams); + QuicSession(QuicConnection* connection, Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicStreamCount num_expected_unidirectional_static_streams, + std::unique_ptr datagram_observer); + QuicSession(const QuicSession&) = delete; + QuicSession& operator=(const QuicSession&) = delete; + + ~QuicSession() override; + + virtual void Initialize(); + + // Return the reserved crypto stream as a constant pointer. + virtual const QuicCryptoStream* GetCryptoStream() const = 0; + + // QuicConnectionVisitorInterface methods: + void OnStreamFrame(const QuicStreamFrame& frame) override; + void OnCryptoFrame(const QuicCryptoFrame& frame) override; + void OnRstStream(const QuicRstStreamFrame& frame) override; + void OnGoAway(const QuicGoAwayFrame& frame) override; + void OnMessageReceived(absl::string_view message) override; + void OnHandshakeDoneReceived() override; + void OnNewTokenReceived(absl::string_view token) override; + void OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override; + void OnBlockedFrame(const QuicBlockedFrame& frame) override; + void OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) override; + void OnWriteBlocked() override; + void OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& version) override; + void OnPacketReceived(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + bool is_connectivity_probe) override; + void OnCanWrite() override; + void OnCongestionWindowChange(QuicTime /*now*/) override {} + void OnConnectionMigration(AddressChangeType /*type*/) override {} + // Adds a connection level WINDOW_UPDATE frame. + void OnAckNeedsRetransmittableFrame() override; + void SendAckFrequency(const QuicAckFrequencyFrame& frame) override; + void SendNewConnectionId(const QuicNewConnectionIdFrame& frame) override; + void SendRetireConnectionId(uint64_t sequence_number) override; + // Returns true if server_connection_id can be issued. If returns true, + // |visitor_| may establish a mapping from |server_connection_id| to this + // session, if that's not desired, + // OnServerConnectionIdRetired(server_connection_id) can be used to remove the + // mapping. + bool MaybeReserveConnectionId( + const QuicConnectionId& server_connection_id) override; + void OnServerConnectionIdRetired( + const QuicConnectionId& server_connection_id) override; + bool WillingAndAbleToWrite() const override; + std::string GetStreamsInfoForLogging() const override; + void OnPathDegrading() override; + void OnForwardProgressMadeAfterPathDegrading() override; + bool AllowSelfAddressChange() const override; + HandshakeState GetHandshakeState() const override; + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override; + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override; + void OnStopSendingFrame(const QuicStopSendingFrame& frame) override; + void OnPacketDecrypted(EncryptionLevel level) override; + void OnOneRttPacketAcknowledged() override; + void OnHandshakePacketSent() override; + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + void BeforeConnectionCloseSent() override {} + bool ValidateToken(absl::string_view token) override; + bool MaybeSendAddressToken() override; + void OnBandwidthUpdateTimeout() override {} + std::unique_ptr CreateContextForMultiPortPath() + override { + return nullptr; + } + void MigrateToMultiPortPath( + std::unique_ptr /*context*/) override {} + void OnServerPreferredAddressAvailable( + const QuicSocketAddress& /*server_preferred_address*/) override; + + // QuicStreamFrameDataProducer + WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + + // SessionNotifierInterface methods: + bool OnFrameAcked(const QuicFrame& frame, QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) override; + void OnStreamFrameRetransmitted(const QuicStreamFrame& frame) override; + void OnFrameLost(const QuicFrame& frame) override; + bool RetransmitFrames(const QuicFrames& frames, + TransmissionType type) override; + bool IsFrameOutstanding(const QuicFrame& frame) const override; + bool HasUnackedCryptoData() const override; + bool HasUnackedStreamData() const override; + + void SendMaxStreams(QuicStreamCount stream_count, + bool unidirectional) override; + // The default implementation does nothing. Subclasses should override if + // for example they queue up stream requests. + virtual void OnCanCreateNewOutgoingStream(bool /*unidirectional*/) {} + + // Called on every incoming packet. Passes |packet| through to |connection_|. + virtual void ProcessUdpPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet); + + // Sends |message| as a QUIC DATAGRAM frame (QUIC MESSAGE frame in gQUIC). + // See for + // more details. + // + // Returns a MessageResult struct which includes the status of the write + // operation and a message ID. The message ID (not sent on the wire) can be + // used to track the message; OnMessageAcked and OnMessageLost are called when + // a specific message gets acked or lost. + // + // If the write operation is successful, all of the slices in |message| are + // consumed, leaving them empty. If MESSAGE_STATUS_INTERNAL_ERROR is + // returned, the slices in question may or may not be consumed; it is no + // longer safe to access those. For all other status codes, |message| is kept + // intact. + // + // Note that SendMessage will fail with status = MESSAGE_STATUS_BLOCKED + // if the connection is congestion control blocked or the underlying socket is + // write blocked. In this case the caller can retry sending message again when + // connection becomes available, for example after getting OnCanWrite() + // callback. + // + // SendMessage flushes the current packet even it is not full; if the + // application needs to bundle other data in the same packet, consider using + // QuicConnection::ScopedPacketFlusher around the relevant write operations. + MessageResult SendMessage(absl::Span message); + + // Same as above SendMessage, except caller can specify if the given |message| + // should be flushed even if the underlying connection is deemed unwritable. + MessageResult SendMessage(absl::Span message, + bool flush); + + // Single-slice version of SendMessage(). Unlike the version above, this + // version always takes ownership of the slice. + MessageResult SendMessage(quiche::QuicheMemSlice message); + + // Called when message with |message_id| gets acked. + virtual void OnMessageAcked(QuicMessageId message_id, + QuicTime receive_timestamp); + + // Called when message with |message_id| is considered as lost. + virtual void OnMessageLost(QuicMessageId message_id); + + // QuicControlFrameManager::DelegateInterface + // Close the connection on error. + void OnControlFrameManagerError(QuicErrorCode error_code, + std::string error_details) override; + // Called by control frame manager when it wants to write control frames to + // the peer. Returns true if |frame| is consumed, false otherwise. The frame + // will be sent in specified transmission |type|. + bool WriteControlFrame(const QuicFrame& frame, + TransmissionType type) override; + + // Called to send RST_STREAM (and STOP_SENDING) and close stream. If stream + // |id| does not exist, just send RST_STREAM (and STOP_SENDING). + virtual void ResetStream(QuicStreamId id, QuicRstStreamErrorCode error); + + // Called when the session wants to go away and not accept any new streams. + virtual void SendGoAway(QuicErrorCode error_code, const std::string& reason); + + // Sends a BLOCKED frame. + virtual void SendBlocked(QuicStreamId id, QuicStreamOffset byte_offset); + + // Sends a WINDOW_UPDATE frame. + virtual void SendWindowUpdate(QuicStreamId id, QuicStreamOffset byte_offset); + + // Called by stream |stream_id| when it gets closed. + virtual void OnStreamClosed(QuicStreamId stream_id); + + // Returns true if outgoing packets will be encrypted, even if the server + // hasn't confirmed the handshake yet. + virtual bool IsEncryptionEstablished() const; + + // Returns true if 1RTT keys are available. + bool OneRttKeysAvailable() const; + + // Called by the QuicCryptoStream when a new QuicConfig has been negotiated. + virtual void OnConfigNegotiated(); + + // Called by the TLS handshaker when ALPS data is received. + // Returns an error message if an error has occurred, or nullopt otherwise. + virtual absl::optional OnAlpsData(const uint8_t* alps_data, + size_t alps_length); + + // From HandshakerDelegateInterface + bool OnNewDecryptionKeyAvailable(EncryptionLevel level, + std::unique_ptr decrypter, + bool set_alternative_decrypter, + bool latch_once_used) override; + void OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) override; + void SetDefaultEncryptionLevel(EncryptionLevel level) override; + void OnTlsHandshakeComplete() override; + void DiscardOldDecryptionKey(EncryptionLevel level) override; + void DiscardOldEncryptionKey(EncryptionLevel level) override; + void NeuterUnencryptedData() override; + void NeuterHandshakeData() override; + void OnZeroRttRejected(int reason) override; + bool FillTransportParameters(TransportParameters* params) override; + QuicErrorCode ProcessTransportParameters(const TransportParameters& params, + bool is_resumption, + std::string* error_details) override; + void OnHandshakeCallbackDone() override; + bool PacketFlusherAttached() const override; + ParsedQuicVersion parsed_version() const override { return version(); } + + // Implement StreamDelegateInterface. + void OnStreamError(QuicErrorCode error_code, + std::string error_details) override; + void OnStreamError(QuicErrorCode error_code, + QuicIetfTransportErrorCodes ietf_error, + std::string error_details) override; + // Sets priority in the write blocked list. + void RegisterStreamPriority(QuicStreamId id, bool is_static, + const QuicStreamPriority& priority) override; + // Clears priority from the write blocked list. + void UnregisterStreamPriority(QuicStreamId id) override; + // Updates priority on the write blocked list. + void UpdateStreamPriority(QuicStreamId id, + const QuicStreamPriority& new_priority) override; + + // Called by streams when they want to write data to the peer. + // Returns a pair with the number of bytes consumed from data, and a boolean + // indicating if the fin bit was consumed. This does not indicate the data + // has been sent on the wire: it may have been turned into a packet and queued + // if the socket was unexpectedly blocked. + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, + TransmissionType type, + EncryptionLevel level) override; + + size_t SendCryptoData(EncryptionLevel level, size_t write_length, + QuicStreamOffset offset, + TransmissionType type) override; + + // Called by the QuicCryptoStream when a handshake message is sent. + virtual void OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message); + + // Called by the QuicCryptoStream when a handshake message is received. + virtual void OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message); + + // Returns mutable config for this session. Returned config is owned + // by QuicSession. + QuicConfig* config() { return &config_; } + const QuicConfig* config() const { return &config_; } + + // Returns true if the stream existed previously and has been closed. + // Returns false if the stream is still active or if the stream has + // not yet been created. + bool IsClosedStream(QuicStreamId id); + + QuicConnection* connection() { return connection_; } + const QuicConnection* connection() const { return connection_; } + const QuicSocketAddress& peer_address() const { + return connection_->peer_address(); + } + const QuicSocketAddress& self_address() const { + return connection_->self_address(); + } + QuicConnectionId connection_id() const { + return connection_->connection_id(); + } + + // Returns the number of currently open streams, excluding static streams, and + // never counting unfinished streams. + size_t GetNumActiveStreams() const; + + // Add the stream to the session's write-blocked list because it is blocked by + // connection-level flow control but not by its own stream-level flow control. + // The stream will be given a chance to write when a connection-level + // WINDOW_UPDATE arrives. + virtual void MarkConnectionLevelWriteBlocked(QuicStreamId id); + + // Called to close zombie stream |id|. + void MaybeCloseZombieStream(QuicStreamId id); + + // Returns true if there is pending handshake data in the crypto stream. + // TODO(ianswett): Make this private or remove. + bool HasPendingHandshake() const; + + // Returns true if the session has data to be sent, either queued in the + // connection, or in a write-blocked stream. + bool HasDataToWrite() const; + + // Initiates a path validation on the path described in the given context, + // asynchronously calls |result_delegate| upon success or failure. + // The initiator should extend QuicPathValidationContext to provide the writer + // and ResultDelegate to react upon the validation result. + // Example implementations of these for path validation for connection + // migration could be: + // class QUIC_EXPORT_PRIVATE PathMigrationContext + // : public QuicPathValidationContext { + // public: + // PathMigrationContext(std::unique_ptr writer, + // const QuicSocketAddress& self_address, + // const QuicSocketAddress& peer_address) + // : QuicPathValidationContext(self_address, peer_address), + // alternative_writer_(std::move(writer)) {} + // + // QuicPacketWriter* WriterToUse() override { + // return alternative_writer_.get(); + // } + // + // QuicPacketWriter* ReleaseWriter() { + // return alternative_writer_.release(); + // } + // + // private: + // std::unique_ptr alternative_writer_; + // }; + // + // class PathMigrationValidationResultDelegate + // : public QuicPathValidator::ResultDelegate { + // public: + // PathMigrationValidationResultDelegate(QuicConnection* connection) + // : QuicPathValidator::ResultDelegate(), connection_(connection) {} + // + // void OnPathValidationSuccess( + // std::unique_ptr context) override { + // // Do some work to prepare for migration. + // // ... + // + // // Actually migrate to the validated path. + // auto migration_context = std::unique_ptr( + // static_cast(context.release())); + // connection_->MigratePath(migration_context->self_address(), + // migration_context->peer_address(), + // migration_context->ReleaseWriter(), + // /*owns_writer=*/true); + // + // // Post-migration actions + // // ... + // } + // + // void OnPathValidationFailure( + // std::unique_ptr /*context*/) override { + // // Handle validation failure. + // } + // + // private: + // QuicConnection* connection_; + // }; + void ValidatePath( + std::unique_ptr context, + std::unique_ptr result_delegate, + PathValidationReason reason); + + // Return true if there is a path being validated. + bool HasPendingPathValidation() const; + + // Switch to the path described in |context| without validating the path. + bool MigratePath(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicPacketWriter* writer, bool owns_writer); + + // Returns the largest payload that will fit into a single MESSAGE frame. + // Because overhead can vary during a connection, this method should be + // checked for every message. + QuicPacketLength GetCurrentLargestMessagePayload() const; + + // Returns the largest payload that will fit into a single MESSAGE frame at + // any point during the connection. This assumes the version and + // connection ID lengths do not change. + QuicPacketLength GetGuaranteedLargestMessagePayload() const; + + bool transport_goaway_sent() const { return transport_goaway_sent_; } + + bool transport_goaway_received() const { return transport_goaway_received_; } + + // Returns the Google QUIC error code + QuicErrorCode error() const { return on_closed_frame_.quic_error_code; } + const std::string& error_details() const { + return on_closed_frame_.error_details; + } + uint64_t transport_close_frame_type() const { + return on_closed_frame_.transport_close_frame_type; + } + QuicConnectionCloseType close_type() const { + return on_closed_frame_.close_type; + } + + Perspective perspective() const { return perspective_; } + + QuicFlowController* flow_controller() { return &flow_controller_; } + + // Returns true if connection is flow controller blocked. + bool IsConnectionFlowControlBlocked() const; + + // Returns true if any stream is flow controller blocked. + bool IsStreamFlowControlBlocked(); + + size_t max_open_incoming_bidirectional_streams() const; + size_t max_open_incoming_unidirectional_streams() const; + + size_t MaxAvailableBidirectionalStreams() const; + size_t MaxAvailableUnidirectionalStreams() const; + + // Returns existing stream with id = |stream_id|. If no + // such stream exists, and |stream_id| is a peer-created stream id, + // then a new stream is created and returned. In all other cases, nullptr is + // returned. + // Caller does not own the returned stream. + QuicStream* GetOrCreateStream(const QuicStreamId stream_id); + + // Mark a stream as draining. + void StreamDraining(QuicStreamId id, bool unidirectional); + + // Returns true if this stream should yield writes to another blocked stream. + virtual bool ShouldYield(QuicStreamId stream_id); + + // Clean up closed_streams_. + void CleanUpClosedStreams(); + + const ParsedQuicVersionVector& supported_versions() const { + return supported_versions_; + } + + QuicStreamId next_outgoing_bidirectional_stream_id() const; + QuicStreamId next_outgoing_unidirectional_stream_id() const; + + // Return true if given stream is peer initiated. + bool IsIncomingStream(QuicStreamId id) const; + + // Record errors when a connection is closed at the server side, should only + // be called from server's perspective. + // Noop if |error| is QUIC_NO_ERROR. + static void RecordConnectionCloseAtServer(QuicErrorCode error, + ConnectionCloseSource source); + + QuicTransportVersion transport_version() const { + return connection_->transport_version(); + } + + ParsedQuicVersion version() const { return connection_->version(); } + + bool is_configured() const { return is_configured_; } + + // Called to neuter crypto data of encryption |level|. + void NeuterCryptoDataOfEncryptionLevel(EncryptionLevel level); + + // Returns the ALPN values to negotiate on this session. + virtual std::vector GetAlpnsToOffer() const { + // TODO(vasilvv): this currently sets HTTP/3 by default. Switch all + // non-HTTP applications to appropriate ALPNs. + return std::vector({AlpnForVersion(connection()->version())}); + } + + // Provided a list of ALPNs offered by the client, selects an ALPN from the + // list, or alpns.end() if none of the ALPNs are acceptable. + virtual std::vector::const_iterator SelectAlpn( + const std::vector& alpns) const; + + // Called when the ALPN of the connection is established for a connection that + // uses TLS handshake. + virtual void OnAlpnSelected(absl::string_view alpn); + + // Called on clients by the crypto handshaker to provide application state + // necessary for sending application data in 0-RTT. The state provided here is + // the same state that was provided to the crypto handshaker in + // QuicCryptoStream::SetServerApplicationStateForResumption on a previous + // connection. Application protocols that require state to be carried over + // from the previous connection to support 0-RTT data must implement this + // method to ingest this state. For example, an HTTP/3 QuicSession would + // implement this function to process the remembered server SETTINGS and apply + // those SETTINGS to 0-RTT data. This function returns true if the application + // state has been successfully processed, and false if there was an error + // processing the cached state and the connection should be closed. + virtual bool ResumeApplicationState(ApplicationState* /*cached_state*/) { + return true; + } + + // Does actual work of sending RESET_STREAM, if the stream type allows. + // Also informs the connection so that pending stream frames can be flushed. + virtual void MaybeSendRstStreamFrame(QuicStreamId id, + QuicResetStreamError error, + QuicStreamOffset bytes_written); + + // Sends a STOP_SENDING frame if the stream type allows. + virtual void MaybeSendStopSendingFrame(QuicStreamId id, + QuicResetStreamError error); + + // Returns the encryption level to send application data. + EncryptionLevel GetEncryptionLevelToSendApplicationData() const; + + const absl::optional user_agent_id() const { + return user_agent_id_; + } + + // TODO(wub): remove saving user-agent to QuicSession. + void SetUserAgentId(std::string user_agent_id) { + user_agent_id_ = std::move(user_agent_id); + connection()->OnUserAgentIdKnown(user_agent_id_.value()); + } + + void SetSourceAddressTokenToSend(absl::string_view token) { + connection()->SetSourceAddressTokenToSend(token); + } + + const QuicClock* GetClock() const { + return connection()->helper()->GetClock(); + } + + bool liveness_testing_in_progress() const { + return liveness_testing_in_progress_; + } + + virtual QuicSSLConfig GetSSLConfig() const { return QuicSSLConfig(); } + + // Try converting all pending streams to normal streams. + void ProcessAllPendingStreams(); + + const ParsedQuicVersionVector& client_original_supported_versions() const { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + return client_original_supported_versions_; + } + void set_client_original_supported_versions( + const ParsedQuicVersionVector& client_original_supported_versions) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + client_original_supported_versions_ = client_original_supported_versions; + } + + // Controls whether the default datagram queue used by the session actually + // queues the datagram. If set to true, the datagrams in the default queue + // will be forcefully flushed, potentially bypassing congestion control and + // other limitations. + void SetForceFlushForDefaultQueue(bool force_flush) { + datagram_queue_.SetForceFlush(force_flush); + } + + // Find stream with |id|, returns nullptr if the stream does not exist or + // closed. static streams and zombie streams are not considered active + // streams. + QuicStream* GetActiveStream(QuicStreamId id) const; + + // Returns the priority type used by the streams in the session. + QuicPriorityType priority_type() const { return QuicPriorityType::kHttp; } + + protected: + using StreamMap = + absl::flat_hash_map>; + + using PendingStreamMap = + absl::flat_hash_map>; + + using ClosedStreams = std::vector>; + + using ZombieStreamMap = + absl::flat_hash_map>; + + std::string on_closed_frame_string() const; + + // Creates a new stream to handle a peer-initiated stream. + // Caller does not own the returned stream. + // Returns nullptr and does error handling if the stream can not be created. + virtual QuicStream* CreateIncomingStream(QuicStreamId id) = 0; + virtual QuicStream* CreateIncomingStream(PendingStream* pending) = 0; + + // Return the reserved crypto stream. + virtual QuicCryptoStream* GetMutableCryptoStream() = 0; + + // Adds |stream| to the stream map. + virtual void ActivateStream(std::unique_ptr stream); + + // Set transmission type of next sending packets. + void SetTransmissionType(TransmissionType type); + + // Returns the stream ID for a new outgoing bidirectional/unidirectional + // stream, and increments the underlying counter. + QuicStreamId GetNextOutgoingBidirectionalStreamId(); + QuicStreamId GetNextOutgoingUnidirectionalStreamId(); + + // Indicates whether the next outgoing bidirectional/unidirectional stream ID + // can be allocated or not. The test for version-99/IETF QUIC is whether it + // will exceed the maximum-stream-id or not. For non-version-99 (Google) QUIC + // it checks whether the next stream would exceed the limit on the number of + // open streams. + bool CanOpenNextOutgoingBidirectionalStream(); + bool CanOpenNextOutgoingUnidirectionalStream(); + + // Returns the maximum bidirectional streams parameter sent with the handshake + // as a transport parameter, or in the most recent MAX_STREAMS frame. + QuicStreamCount GetAdvertisedMaxIncomingBidirectionalStreams() const; + + // When a stream is closed locally, it may not yet know how many bytes the + // peer sent on that stream. + // When this data arrives (via stream frame w. FIN, trailing headers, or RST) + // this method is called, and correctly updates the connection level flow + // controller. + virtual void OnFinalByteOffsetReceived(QuicStreamId id, + QuicStreamOffset final_byte_offset); + + // Returns true if a frame with the given type and id can be prcoessed by a + // PendingStream. However, the frame will always be processed by a QuicStream + // if one exists with the given stream_id. + virtual bool UsesPendingStreamForFrame(QuicFrameType /*type*/, + QuicStreamId /*stream_id*/) const { + return false; + } + + // Returns true if a pending stream should be converted to a real stream after + // a corresponding STREAM_FRAME is received. + virtual bool ShouldProcessPendingStreamImmediately() const { return true; } + + spdy::SpdyPriority GetSpdyPriorityofStream(QuicStreamId stream_id) const { + return write_blocked_streams_->GetPriorityOfStream(stream_id) + .http() + .urgency; + } + + size_t pending_streams_size() const { return pending_stream_map_.size(); } + + ClosedStreams* closed_streams() { return &closed_streams_; } + + void set_largest_peer_created_stream_id( + QuicStreamId largest_peer_created_stream_id); + + QuicWriteBlockedListInterface* write_blocked_streams() { + return write_blocked_streams_.get(); + } + + // Returns true if the stream is still active. + bool IsOpenStream(QuicStreamId id); + + // Returns true if the stream is a static stream. + bool IsStaticStream(QuicStreamId id) const; + + // Close connection when receive a frame for a locally-created nonexistent + // stream. + // Prerequisite: IsClosedStream(stream_id) == false + // Server session might need to override this method to allow server push + // stream to be promised before creating an active stream. + virtual void HandleFrameOnNonexistentOutgoingStream(QuicStreamId stream_id); + + virtual bool MaybeIncreaseLargestPeerStreamId(const QuicStreamId stream_id); + + void InsertLocallyClosedStreamsHighestOffset(const QuicStreamId id, + QuicStreamOffset offset); + // If stream is a locally closed stream, this RST will update FIN offset. + // Otherwise stream is a preserved stream and the behavior of it depends on + // derived class's own implementation. + virtual void HandleRstOnValidNonexistentStream( + const QuicRstStreamFrame& frame); + + // Returns a stateless reset token which will be included in the public reset + // packet. + virtual StatelessResetToken GetStatelessResetToken() const; + + QuicControlFrameManager& control_frame_manager() { + return control_frame_manager_; + } + + const LegacyQuicStreamIdManager& stream_id_manager() const { + return stream_id_manager_; + } + + QuicDatagramQueue* datagram_queue() { return &datagram_queue_; } + + size_t num_static_streams() const { return num_static_streams_; } + + size_t num_zombie_streams() const { return num_zombie_streams_; } + + bool was_zero_rtt_rejected() const { return was_zero_rtt_rejected_; } + + size_t num_outgoing_draining_streams() const { + return num_outgoing_draining_streams_; + } + + size_t num_draining_streams() const { return num_draining_streams_; } + + // Processes the stream type information of |pending| depending on + // different kinds of sessions' own rules. If the pending stream has been + // converted to a normal stream, returns a pointer to the new stream; + // otherwise, returns nullptr. + virtual QuicStream* ProcessPendingStream(PendingStream* /*pending*/) { + return nullptr; + } + + // Called by applications to perform |action| on active streams. + // Stream iteration will be stopped if action returns false. + void PerformActionOnActiveStreams(std::function action); + void PerformActionOnActiveStreams( + std::function action) const; + + // Return the largest peer created stream id depending on directionality + // indicated by |unidirectional|. + QuicStreamId GetLargestPeerCreatedStreamId(bool unidirectional) const; + + // Deletes the connection and sets it to nullptr, so calling it mulitiple + // times is safe. + void DeleteConnection(); + + // Call SetPriority() on stream id |id| and return true if stream is active. + bool MaybeSetStreamPriority(QuicStreamId stream_id, + const QuicStreamPriority& priority); + + void SetLossDetectionTuner( + std::unique_ptr tuner) { + connection()->SetLossDetectionTuner(std::move(tuner)); + } + + const UberQuicStreamIdManager& ietf_streamid_manager() const { + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + return ietf_streamid_manager_; + } + + // Only called at a server session. Generate a CachedNetworkParameters that + // can be sent to the client as part of the address token, based on the latest + // bandwidth/rtt information. If return absl::nullopt, address token will not + // contain the CachedNetworkParameters. + virtual absl::optional + GenerateCachedNetworkParameters() const { + return absl::nullopt; + } + + private: + friend class test::QuicSessionPeer; + + // Called in OnConfigNegotiated when we receive a new stream level flow + // control window in a negotiated config. Closes the connection if invalid. + void OnNewStreamFlowControlWindow(QuicStreamOffset new_window); + + // Called in OnConfigNegotiated when we receive a new unidirectional stream + // flow control window in a negotiated config. + void OnNewStreamUnidirectionalFlowControlWindow(QuicStreamOffset new_window); + + // Called in OnConfigNegotiated when we receive a new outgoing bidirectional + // stream flow control window in a negotiated config. + void OnNewStreamOutgoingBidirectionalFlowControlWindow( + QuicStreamOffset new_window); + + // Called in OnConfigNegotiated when we receive a new incoming bidirectional + // stream flow control window in a negotiated config. + void OnNewStreamIncomingBidirectionalFlowControlWindow( + QuicStreamOffset new_window); + + // Called in OnConfigNegotiated when we receive a new connection level flow + // control window in a negotiated config. Closes the connection if invalid. + void OnNewSessionFlowControlWindow(QuicStreamOffset new_window); + + // Debug helper for |OnCanWrite()|, check that OnStreamWrite() makes + // forward progress. Returns false if busy loop detected. + bool CheckStreamNotBusyLooping(QuicStream* stream, + uint64_t previous_bytes_written, + bool previous_fin_sent); + + // Debug helper for OnCanWrite. Check that after QuicStream::OnCanWrite(), + // if stream has buffered data and is not stream level flow control blocked, + // it has to be in the write blocked list. + bool CheckStreamWriteBlocked(QuicStream* stream) const; + + // Called in OnConfigNegotiated for Finch trials to measure performance of + // starting with larger flow control receive windows. + void AdjustInitialFlowControlWindows(size_t stream_window); + + // Find stream with |id|, returns nullptr if the stream does not exist or + // closed. + QuicStream* GetStream(QuicStreamId id) const; + + // Can return NULL, e.g., if the stream has been closed before. + PendingStream* GetOrCreatePendingStream(QuicStreamId stream_id); + + // Let streams and control frame managers retransmit lost data, returns true + // if all lost data is retransmitted. Returns false otherwise. + bool RetransmitLostData(); + + // Returns true if stream data should be written. + bool CanWriteStreamData() const; + + // Closes the pending stream |stream_id| before it has been created. + void ClosePendingStream(QuicStreamId stream_id); + + // Whether the frame with given type and id should be feed to a pending + // stream. + bool ShouldProcessFrameByPendingStream(QuicFrameType type, + QuicStreamId id) const; + + // Process the pending stream if possible. + void MaybeProcessPendingStream(PendingStream* pending); + + // Creates or gets pending stream, feeds it with |frame|, and returns the + // pending stream. Can return NULL, e.g., if the stream ID is invalid. + PendingStream* PendingStreamOnStreamFrame(const QuicStreamFrame& frame); + + // Creates or gets pending strea, feed it with |frame|, and closes the pending + // stream. + void PendingStreamOnRstStream(const QuicRstStreamFrame& frame); + + // Creates or gets pending stream, feeds it with |frame|, and records the + // max_data in the pending stream. + void PendingStreamOnWindowUpdateFrame(const QuicWindowUpdateFrame& frame); + + // Creates or gets pending stream, feeds it with |frame|, and records the + // ietf_error_code in the pending stream. + void PendingStreamOnStopSendingFrame(const QuicStopSendingFrame& frame); + + // Keep track of highest received byte offset of locally closed streams, while + // waiting for a definitive final highest offset from the peer. + absl::flat_hash_map + locally_closed_streams_highest_offset_; + + QuicConnection* connection_; + + // Store perspective on QuicSession during the constructor as it may be needed + // during our destructor when connection_ may have already been destroyed. + Perspective perspective_; + + // May be null. + Visitor* visitor_; + + // A list of streams which need to write more data. Stream register + // themselves in their constructor, and unregisterm themselves in their + // destructors, so the write blocked list must outlive all streams. + std::unique_ptr write_blocked_streams_; + + ClosedStreams closed_streams_; + + QuicConfig config_; + + // Map from StreamId to pointers to streams. Owns the streams. + StreamMap stream_map_; + + // Map from StreamId to PendingStreams for peer-created unidirectional streams + // which are waiting for the first byte of payload to arrive. + PendingStreamMap pending_stream_map_; + + // TODO(fayang): Consider moving LegacyQuicStreamIdManager into + // UberQuicStreamIdManager. + // Manages stream IDs for Google QUIC. + LegacyQuicStreamIdManager stream_id_manager_; + + // Manages stream IDs for version99/IETF QUIC + UberQuicStreamIdManager ietf_streamid_manager_; + + // A counter for streams which have sent and received FIN but waiting for + // application to consume data. + size_t num_draining_streams_; + + // A counter for self initiated streams which have sent and received FIN but + // waiting for application to consume data. + size_t num_outgoing_draining_streams_; + + // A counter for static streams which are in stream_map_. + size_t num_static_streams_; + + // A counter for streams which have done reading and writing, but are waiting + // for acks. + size_t num_zombie_streams_; + + // Received information for a connection close. + QuicConnectionCloseFrame on_closed_frame_; + absl::optional source_; + + // Used for connection-level flow control. + QuicFlowController flow_controller_; + + // The stream id which was last popped in OnCanWrite, or 0, if not under the + // call stack of OnCanWrite. + QuicStreamId currently_writing_stream_id_; + + // Whether a transport layer GOAWAY frame has been sent. + // Such a frame only exists in Google QUIC, therefore |transport_goaway_sent_| + // is always false when using IETF QUIC. + bool transport_goaway_sent_; + + // Whether a transport layer GOAWAY frame has been received. + // Such a frame only exists in Google QUIC, therefore + // |transport_goaway_received_| is always false when using IETF QUIC. + bool transport_goaway_received_; + + QuicControlFrameManager control_frame_manager_; + + // Id of latest successfully sent message. + QuicMessageId last_message_id_; + + // The buffer used to queue the DATAGRAM frames. + QuicDatagramQueue datagram_queue_; + + // TODO(fayang): switch to linked_hash_set when chromium supports it. The bool + // is not used here. + // List of streams with pending retransmissions. + quiche::QuicheLinkedHashMap + streams_with_pending_retransmission_; + + // Clean up closed_streams_ when this alarm fires. + std::unique_ptr closed_streams_clean_up_alarm_; + + // Supported version list used by the crypto handshake only. Please note, this + // list may be a superset of the connection framer's supported versions. + ParsedQuicVersionVector supported_versions_; + + // Only non-empty on the client after receiving a version negotiation packet, + // contains the configured versions from the original session before version + // negotiation was received. + ParsedQuicVersionVector client_original_supported_versions_; + + absl::optional user_agent_id_; + + // Initialized to false. Set to true when the session has been properly + // configured and is ready for general operation. + bool is_configured_; + + // Whether the session has received a 0-RTT rejection (QUIC+TLS only). + bool was_zero_rtt_rejected_; + + // This indicates a liveness testing is in progress, and push back the + // creation of new outgoing bidirectional streams. + bool liveness_testing_in_progress_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_SESSION_H_ diff --git a/quiche/quic/core/quic_session_test.cc b/quiche/quic/core/quic_session_test.cc new file mode 100644 index 000000000000..bdf74b309123 --- /dev/null +++ b/quiche/quic/core/quic_session_test.cc @@ -0,0 +1,3318 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_session.h" + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/frames/quic_max_streams_frame.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_quic_session_visitor.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_id_manager_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_send_buffer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_mem_slice_storage.h" + +using spdy::kV3HighestPriority; +using spdy::SpdyPriority; +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::AtLeast; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::WithArg; + +namespace quic { +namespace test { +namespace { + +class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { + public: + explicit TestCryptoStream(QuicSession* session) + : QuicCryptoStream(session), + QuicCryptoHandshaker(this, session), + encryption_established_(false), + one_rtt_keys_available_(false), + params_(new QuicCryptoNegotiatedParameters) { + // Simulate a negotiated cipher_suite with a fake value. + params_->cipher_suite = 1; + } + + void EstablishZeroRttEncryption() { + encryption_established_ = true; + session()->connection()->SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(session()->perspective())); + } + + void OnHandshakeMessage(const CryptoHandshakeMessage& /*message*/) override { + encryption_established_ = true; + one_rtt_keys_available_ = true; + QuicErrorCode error; + std::string error_details; + session()->config()->SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + session()->config()->SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + if (session()->version().UsesTls()) { + if (session()->perspective() == Perspective::IS_CLIENT) { + session()->config()->SetOriginalConnectionIdToSend( + session()->connection()->connection_id()); + session()->config()->SetInitialSourceConnectionIdToSend( + session()->connection()->connection_id()); + } else { + session()->config()->SetInitialSourceConnectionIdToSend( + session()->connection()->client_connection_id()); + } + TransportParameters transport_parameters; + EXPECT_TRUE( + session()->config()->FillTransportParameters(&transport_parameters)); + error = session()->config()->ProcessTransportParameters( + transport_parameters, /* is_resumption = */ false, &error_details); + } else { + CryptoHandshakeMessage msg; + session()->config()->ToHandshakeMessage(&msg, transport_version()); + error = + session()->config()->ProcessPeerHello(msg, CLIENT, &error_details); + } + EXPECT_THAT(error, IsQuicNoError()); + session()->OnNewEncryptionKeyAvailable( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(session()->perspective())); + session()->OnConfigNegotiated(); + if (session()->connection()->version().handshake_protocol == + PROTOCOL_TLS1_3) { + session()->OnTlsHandshakeComplete(); + } else { + session()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + session()->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); + } + + // QuicCryptoStream implementation + ssl_early_data_reason_t EarlyDataReason() const override { + return ssl_early_data_unknown; + } + bool encryption_established() const override { + return encryption_established_; + } + bool one_rtt_keys_available() const override { + return one_rtt_keys_available_; + } + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override { + return *params_; + } + CryptoMessageParser* crypto_message_parser() override { + return QuicCryptoHandshaker::crypto_message_parser(); + } + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnHandshakeDoneReceived() override {} + void OnNewTokenReceived(absl::string_view /*token*/) override {} + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } + bool ValidateAddressToken(absl::string_view /*token*/) const override { + return true; + } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} + HandshakeState GetHandshakeState() const override { + return one_rtt_keys_available() ? HANDSHAKE_COMPLETE : HANDSHAKE_START; + } + void SetServerApplicationStateForResumption( + std::unique_ptr /*application_state*/) override {} + MOCK_METHOD(std::unique_ptr, + AdvanceKeysAndCreateCurrentOneRttDecrypter, (), (override)); + MOCK_METHOD(std::unique_ptr, CreateCurrentOneRttEncrypter, (), + (override)); + + MOCK_METHOD(void, OnCanWrite, (), (override)); + bool HasPendingCryptoRetransmission() const override { return false; } + + MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); + + void OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) override {} + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + + SSL* GetSsl() const override { return nullptr; } + + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override { + return level != ENCRYPTION_ZERO_RTT; + } + + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } + } + + private: + using QuicCryptoStream::session; + + bool encryption_established_; + bool one_rtt_keys_available_; + quiche::QuicheReferenceCountedPointer params_; +}; + +class TestStream : public QuicStream { + public: + TestStream(QuicStreamId id, QuicSession* session, StreamType type) + : TestStream(id, session, /*is_static=*/false, type) {} + + TestStream(QuicStreamId id, QuicSession* session, bool is_static, + StreamType type) + : QuicStream(id, session, is_static, type) {} + + TestStream(PendingStream* pending, QuicSession* session) + : QuicStream(pending, session, /*is_static=*/false) {} + + using QuicStream::CloseWriteSide; + using QuicStream::WriteMemSlices; + + void OnDataAvailable() override {} + + MOCK_METHOD(void, OnCanWrite, (), (override)); + MOCK_METHOD(bool, RetransmitStreamData, + (QuicStreamOffset, QuicByteCount, bool, TransmissionType), + (override)); + + MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); +}; + +class TestSession : public QuicSession { + public: + explicit TestSession(QuicConnection* connection, + MockQuicSessionVisitor* session_visitor) + : QuicSession(connection, session_visitor, DefaultQuicConfig(), + CurrentSupportedVersions(), + /*num_expected_unidirectional_static_streams = */ 0), + crypto_stream_(this), + writev_consumes_all_data_(false), + uses_pending_streams_(false), + num_incoming_streams_created_(0) { + Initialize(); + this->connection()->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection->perspective())); + if (this->connection()->version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(this->connection()); + } + } + + ~TestSession() override { DeleteConnection(); } + + TestCryptoStream* GetMutableCryptoStream() override { + return &crypto_stream_; + } + + const TestCryptoStream* GetCryptoStream() const override { + return &crypto_stream_; + } + + TestStream* CreateOutgoingBidirectionalStream() { + QuicStreamId id = GetNextOutgoingBidirectionalStreamId(); + if (id == + QuicUtils::GetInvalidStreamId(connection()->transport_version())) { + return nullptr; + } + TestStream* stream = new TestStream(id, this, BIDIRECTIONAL); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + TestStream* CreateOutgoingUnidirectionalStream() { + TestStream* stream = new TestStream(GetNextOutgoingUnidirectionalStreamId(), + this, WRITE_UNIDIRECTIONAL); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + + TestStream* CreateIncomingStream(QuicStreamId id) override { + // Enforce the limit on the number of open streams. + if (!VersionHasIetfQuicFrames(connection()->transport_version()) && + stream_id_manager().num_open_incoming_streams() + 1 > + max_open_incoming_bidirectional_streams()) { + // No need to do this test for version 99; it's done by + // QuicSession::GetOrCreateStream. + connection()->CloseConnection( + QUIC_TOO_MANY_OPEN_STREAMS, "Too many streams!", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return nullptr; + } + + TestStream* stream = new TestStream( + id, this, + DetermineStreamType(id, connection()->version(), perspective(), + /*is_incoming=*/true, BIDIRECTIONAL)); + ActivateStream(absl::WrapUnique(stream)); + ++num_incoming_streams_created_; + return stream; + } + + TestStream* CreateIncomingStream(PendingStream* pending) override { + TestStream* stream = new TestStream(pending, this); + ActivateStream(absl::WrapUnique(stream)); + ++num_incoming_streams_created_; + return stream; + } + + // QuicSession doesn't do anything in this method. So it's overridden here to + // test that the session handles pending streams correctly in terms of + // receiving stream frames. + QuicStream* ProcessPendingStream(PendingStream* pending) override { + if (pending->is_bidirectional()) { + return CreateIncomingStream(pending); + } + struct iovec iov; + if (pending->sequencer()->GetReadableRegion(&iov)) { + // Create TestStream once the first byte is received. + return CreateIncomingStream(pending); + } + return nullptr; + } + + bool IsClosedStream(QuicStreamId id) { + return QuicSession::IsClosedStream(id); + } + + QuicStream* GetOrCreateStream(QuicStreamId stream_id) { + return QuicSession::GetOrCreateStream(stream_id); + } + + bool ShouldKeepConnectionAlive() const override { + return GetNumActiveStreams() > 0; + } + + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, + TransmissionType type, + EncryptionLevel level) override { + bool fin = state != NO_FIN; + QuicConsumedData consumed(write_length, fin); + if (!writev_consumes_all_data_) { + consumed = + QuicSession::WritevData(id, write_length, offset, state, type, level); + } + QuicSessionPeer::GetWriteBlockedStreams(this)->UpdateBytesForStream( + id, consumed.bytes_consumed); + return consumed; + } + + MOCK_METHOD(void, OnCanCreateNewOutgoingStream, (bool unidirectional), + (override)); + + void set_writev_consumes_all_data(bool val) { + writev_consumes_all_data_ = val; + } + + QuicConsumedData SendStreamData(QuicStream* stream) { + if (!QuicUtils::IsCryptoStreamId(connection()->transport_version(), + stream->id()) && + this->connection()->encryption_level() != ENCRYPTION_FORWARD_SECURE) { + this->connection()->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + } + QuicStreamPeer::SendBuffer(stream).SaveStreamData("not empty"); + QuicConsumedData consumed = + WritevData(stream->id(), 9, 0, FIN, NOT_RETRANSMISSION, + GetEncryptionLevelToSendApplicationData()); + QuicStreamPeer::SendBuffer(stream).OnStreamDataConsumed( + consumed.bytes_consumed); + return consumed; + } + + const QuicFrame& save_frame() { return save_frame_; } + + bool SaveFrame(const QuicFrame& frame) { + save_frame_ = frame; + DeleteFrame(&const_cast(frame)); + return true; + } + + QuicConsumedData SendLargeFakeData(QuicStream* stream, int bytes) { + QUICHE_DCHECK(writev_consumes_all_data_); + return WritevData(stream->id(), bytes, 0, FIN, NOT_RETRANSMISSION, + GetEncryptionLevelToSendApplicationData()); + } + + bool UsesPendingStreamForFrame(QuicFrameType type, + QuicStreamId stream_id) const override { + if (!uses_pending_streams_) { + return false; + } + // Uses pending stream for STREAM/RST_STREAM frames with unidirectional read + // stream and uses pending stream for + // STREAM/RST_STREAM/STOP_SENDING/WINDOW_UPDATE frames with bidirectional + // stream. + bool is_incoming_stream = IsIncomingStream(stream_id); + StreamType stream_type = QuicUtils::GetStreamType( + stream_id, perspective(), is_incoming_stream, version()); + switch (type) { + case STREAM_FRAME: + ABSL_FALLTHROUGH_INTENDED; + case RST_STREAM_FRAME: + return is_incoming_stream; + case STOP_SENDING_FRAME: + ABSL_FALLTHROUGH_INTENDED; + case WINDOW_UPDATE_FRAME: + return stream_type == BIDIRECTIONAL; + default: + return false; + } + } + + bool ShouldProcessPendingStreamImmediately() const override { + return process_pending_stream_immediately_; + } + + void set_uses_pending_streams(bool uses_pending_streams) { + uses_pending_streams_ = uses_pending_streams; + } + + void set_process_pending_stream_immediately( + bool process_pending_stream_immediately) { + process_pending_stream_immediately_ = process_pending_stream_immediately; + } + + int num_incoming_streams_created() const { + return num_incoming_streams_created_; + } + + using QuicSession::ActivateStream; + using QuicSession::CanOpenNextOutgoingBidirectionalStream; + using QuicSession::CanOpenNextOutgoingUnidirectionalStream; + using QuicSession::closed_streams; + using QuicSession::GetNextOutgoingBidirectionalStreamId; + using QuicSession::GetNextOutgoingUnidirectionalStreamId; + + private: + StrictMock crypto_stream_; + + bool writev_consumes_all_data_; + bool uses_pending_streams_; + bool process_pending_stream_immediately_ = true; + QuicFrame save_frame_; + int num_incoming_streams_created_; +}; + +class QuicSessionTestBase : public QuicTestWithParam { + protected: + QuicSessionTestBase(Perspective perspective, bool configure_session) + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective, + SupportedVersions(GetParam()))), + session_(connection_, &session_visitor_), + configure_session_(configure_session) { + session_.config()->SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + session_.config()->SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + + if (configure_session) { + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(1); + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(true)).Times(1); + } + QuicConfigPeer::SetReceivedMaxBidirectionalStreams( + session_.config(), kDefaultMaxStreamsPerConnection); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams( + session_.config(), kDefaultMaxStreamsPerConnection); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_.config(), kMinimumFlowControlSendWindow); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + session_.OnConfigNegotiated(); + } + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()) + .Times(testing::AnyNumber()); + testing::Mock::VerifyAndClearExpectations(&session_); + } + + ~QuicSessionTestBase() { + if (configure_session_) { + EXPECT_TRUE(session_.is_configured()); + } + } + + void CheckClosedStreams() { + QuicStreamId first_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + connection_->transport_version(), Perspective::IS_CLIENT); + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + first_stream_id = + QuicUtils::GetCryptoStreamId(connection_->transport_version()); + } + for (QuicStreamId i = first_stream_id; i < 100; i++) { + if (closed_streams_.find(i) == closed_streams_.end()) { + EXPECT_FALSE(session_.IsClosedStream(i)) << " stream id: " << i; + } else { + EXPECT_TRUE(session_.IsClosedStream(i)) << " stream id: " << i; + } + } + } + + void CloseStream(QuicStreamId id) { + if (VersionHasIetfQuicFrames(transport_version())) { + if (QuicUtils::GetStreamType( + id, session_.perspective(), session_.IsIncomingStream(id), + connection_->version()) == READ_UNIDIRECTIONAL) { + // Verify STOP_SENDING but no RESET_STREAM is sent for + // READ_UNIDIRECTIONAL streams. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(1) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(id, _)).Times(1); + } else if (QuicUtils::GetStreamType( + id, session_.perspective(), session_.IsIncomingStream(id), + connection_->version()) == WRITE_UNIDIRECTIONAL) { + // Verify RESET_STREAM but not STOP_SENDING is sent for write-only + // stream. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(1) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(id, _)); + } else { + // Verify RESET_STREAM and STOP_SENDING are sent for BIDIRECTIONAL + // streams. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(id, _)); + } + } else { + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(id, _)); + } + session_.ResetStream(id, QUIC_STREAM_CANCELLED); + closed_streams_.insert(id); + } + + void CompleteHandshake() { + CryptoHandshakeMessage msg; + if (connection_->version().UsesTls() && + connection_->perspective() == Perspective::IS_SERVER) { + // HANDSHAKE_DONE frame. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + } + session_.GetMutableCryptoStream()->OnHandshakeMessage(msg); + } + + QuicTransportVersion transport_version() const { + return connection_->transport_version(); + } + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return QuicUtils::GetFirstBidirectionalStreamId( + connection_->transport_version(), Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(connection_->transport_version()) * n; + } + + QuicStreamId GetNthClientInitiatedUnidirectionalId(int n) { + return QuicUtils::GetFirstUnidirectionalStreamId( + connection_->transport_version(), Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(connection_->transport_version()) * n; + } + + QuicStreamId GetNthServerInitiatedBidirectionalId(int n) { + return QuicUtils::GetFirstBidirectionalStreamId( + connection_->transport_version(), Perspective::IS_SERVER) + + QuicUtils::StreamIdDelta(connection_->transport_version()) * n; + } + + QuicStreamId GetNthServerInitiatedUnidirectionalId(int n) { + return QuicUtils::GetFirstUnidirectionalStreamId( + connection_->transport_version(), Perspective::IS_SERVER) + + QuicUtils::StreamIdDelta(connection_->transport_version()) * n; + } + + QuicStreamId StreamCountToId(QuicStreamCount stream_count, + Perspective perspective, bool bidirectional) { + // Calculate and build up stream ID rather than use + // GetFirst... because tests that rely on this method + // needs to do the stream count where #1 is 0/1/2/3, and not + // take into account that stream 0 is special. + QuicStreamId id = + ((stream_count - 1) * QuicUtils::StreamIdDelta(transport_version())); + if (!bidirectional) { + id |= 0x2; + } + if (perspective == Perspective::IS_SERVER) { + id |= 0x1; + } + return id; + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + NiceMock session_visitor_; + StrictMock* connection_; + TestSession session_; + std::set closed_streams_; + bool configure_session_; +}; + +class QuicSessionTestServer : public QuicSessionTestBase { + public: + // CheckMultiPathResponse validates that a written packet + // contains both expected path responses. + WriteResult CheckMultiPathResponse(const char* buffer, size_t buf_len, + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, + PerPacketOptions* /*options*/) { + QuicEncryptedPacket packet(buffer, buf_len); + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)) + .WillOnce( + WithArg<0>(Invoke([this](const QuicPathResponseFrame& frame) { + EXPECT_EQ(path_frame_buffer1_, frame.data_buffer); + return true; + }))); + EXPECT_CALL(framer_visitor_, OnPathResponseFrame(_)) + .WillOnce( + WithArg<0>(Invoke([this](const QuicPathResponseFrame& frame) { + EXPECT_EQ(path_frame_buffer2_, frame.data_buffer); + return true; + }))); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + client_framer_.ProcessPacket(packet); + return WriteResult(WRITE_STATUS_OK, 0); + } + + protected: + QuicSessionTestServer() + : QuicSessionTestBase(Perspective::IS_SERVER, /*configure_session=*/true), + path_frame_buffer1_({0, 1, 2, 3, 4, 5, 6, 7}), + path_frame_buffer2_({8, 9, 10, 11, 12, 13, 14, 15}), + client_framer_(SupportedVersions(GetParam()), QuicTime::Zero(), + Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength) { + client_framer_.set_visitor(&framer_visitor_); + client_framer_.SetInitialObfuscators(TestConnectionId()); + if (client_framer_.version().KnowsWhichDecrypterToUse()) { + client_framer_.InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(Perspective::IS_CLIENT)); + } + } + + QuicPathFrameBuffer path_frame_buffer1_; + QuicPathFrameBuffer path_frame_buffer2_; + StrictMock framer_visitor_; + // Framer used to process packets sent by server. + QuicFramer client_framer_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSessionTestServer, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSessionTestServer, PeerAddress) { + EXPECT_EQ(QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort), + session_.peer_address()); +} + +TEST_P(QuicSessionTestServer, SelfAddress) { + EXPECT_TRUE(session_.self_address().IsInitialized()); +} + +TEST_P(QuicSessionTestServer, DontCallOnWriteBlockedForDisconnectedConnection) { + EXPECT_CALL(*connection_, CloseConnection(_, _, _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + connection_->CloseConnection(QUIC_NO_ERROR, "Everything is fine.", + ConnectionCloseBehavior::SILENT_CLOSE); + ASSERT_FALSE(connection_->connected()); + + EXPECT_CALL(session_visitor_, OnWriteBlocked(_)).Times(0); + session_.OnWriteBlocked(); +} + +TEST_P(QuicSessionTestServer, OneRttKeysAvailable) { + EXPECT_FALSE(session_.OneRttKeysAvailable()); + CompleteHandshake(); + EXPECT_TRUE(session_.OneRttKeysAvailable()); +} + +TEST_P(QuicSessionTestServer, IsClosedStreamDefault) { + // Ensure that no streams are initially closed. + QuicStreamId first_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + connection_->transport_version(), Perspective::IS_CLIENT); + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + first_stream_id = + QuicUtils::GetCryptoStreamId(connection_->transport_version()); + } + for (QuicStreamId i = first_stream_id; i < 100; i++) { + EXPECT_FALSE(session_.IsClosedStream(i)) << "stream id: " << i; + } +} + +TEST_P(QuicSessionTestServer, AvailableBidirectionalStreams) { + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(3)) != nullptr); + // Smaller bidirectional streams should be available. + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedBidirectionalId(1))); + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedBidirectionalId(2))); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(2)) != nullptr); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(1)) != nullptr); +} + +TEST_P(QuicSessionTestServer, AvailableUnidirectionalStreams) { + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedUnidirectionalId(3)) != nullptr); + // Smaller unidirectional streams should be available. + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedUnidirectionalId(1))); + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedUnidirectionalId(2))); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedUnidirectionalId(2)) != nullptr); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthClientInitiatedUnidirectionalId(1)) != nullptr); +} + +TEST_P(QuicSessionTestServer, MaxAvailableBidirectionalStreams) { + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_EQ(session_.max_open_incoming_bidirectional_streams(), + session_.MaxAvailableBidirectionalStreams()); + } else { + // The protocol specification requires that there can be at least 10 times + // as many available streams as the connection's maximum open streams. + EXPECT_EQ(session_.max_open_incoming_bidirectional_streams() * + kMaxAvailableStreamsMultiplier, + session_.MaxAvailableBidirectionalStreams()); + } +} + +TEST_P(QuicSessionTestServer, MaxAvailableUnidirectionalStreams) { + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_EQ(session_.max_open_incoming_unidirectional_streams(), + session_.MaxAvailableUnidirectionalStreams()); + } else { + // The protocol specification requires that there can be at least 10 times + // as many available streams as the connection's maximum open streams. + EXPECT_EQ(session_.max_open_incoming_unidirectional_streams() * + kMaxAvailableStreamsMultiplier, + session_.MaxAvailableUnidirectionalStreams()); + } +} + +TEST_P(QuicSessionTestServer, IsClosedBidirectionalStreamLocallyCreated) { + CompleteHandshake(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_EQ(GetNthServerInitiatedBidirectionalId(0), stream2->id()); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_EQ(GetNthServerInitiatedBidirectionalId(1), stream4->id()); + + CheckClosedStreams(); + CloseStream(GetNthServerInitiatedBidirectionalId(0)); + CheckClosedStreams(); + CloseStream(GetNthServerInitiatedBidirectionalId(1)); + CheckClosedStreams(); +} + +TEST_P(QuicSessionTestServer, IsClosedUnidirectionalStreamLocallyCreated) { + CompleteHandshake(); + TestStream* stream2 = session_.CreateOutgoingUnidirectionalStream(); + EXPECT_EQ(GetNthServerInitiatedUnidirectionalId(0), stream2->id()); + TestStream* stream4 = session_.CreateOutgoingUnidirectionalStream(); + EXPECT_EQ(GetNthServerInitiatedUnidirectionalId(1), stream4->id()); + + CheckClosedStreams(); + CloseStream(GetNthServerInitiatedUnidirectionalId(0)); + CheckClosedStreams(); + CloseStream(GetNthServerInitiatedUnidirectionalId(1)); + CheckClosedStreams(); +} + +TEST_P(QuicSessionTestServer, IsClosedBidirectionalStreamPeerCreated) { + CompleteHandshake(); + QuicStreamId stream_id1 = GetNthClientInitiatedBidirectionalId(0); + QuicStreamId stream_id2 = GetNthClientInitiatedBidirectionalId(1); + session_.GetOrCreateStream(stream_id1); + session_.GetOrCreateStream(stream_id2); + + CheckClosedStreams(); + CloseStream(stream_id1); + CheckClosedStreams(); + CloseStream(stream_id2); + // Create a stream, and make another available. + QuicStream* stream3 = session_.GetOrCreateStream( + stream_id2 + + 2 * QuicUtils::StreamIdDelta(connection_->transport_version())); + CheckClosedStreams(); + // Close one, but make sure the other is still not closed + CloseStream(stream3->id()); + CheckClosedStreams(); +} + +TEST_P(QuicSessionTestServer, IsClosedUnidirectionalStreamPeerCreated) { + CompleteHandshake(); + QuicStreamId stream_id1 = GetNthClientInitiatedUnidirectionalId(0); + QuicStreamId stream_id2 = GetNthClientInitiatedUnidirectionalId(1); + session_.GetOrCreateStream(stream_id1); + session_.GetOrCreateStream(stream_id2); + + CheckClosedStreams(); + CloseStream(stream_id1); + CheckClosedStreams(); + CloseStream(stream_id2); + // Create a stream, and make another available. + QuicStream* stream3 = session_.GetOrCreateStream( + stream_id2 + + 2 * QuicUtils::StreamIdDelta(connection_->transport_version())); + CheckClosedStreams(); + // Close one, but make sure the other is still not closed + CloseStream(stream3->id()); + CheckClosedStreams(); +} + +TEST_P(QuicSessionTestServer, MaximumAvailableOpenedBidirectionalStreams) { + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + session_.GetOrCreateStream(stream_id); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_NE(nullptr, + session_.GetOrCreateStream(GetNthClientInitiatedBidirectionalId( + session_.max_open_incoming_bidirectional_streams() - 1))); +} + +TEST_P(QuicSessionTestServer, MaximumAvailableOpenedUnidirectionalStreams) { + QuicStreamId stream_id = GetNthClientInitiatedUnidirectionalId(0); + session_.GetOrCreateStream(stream_id); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_NE(nullptr, + session_.GetOrCreateStream(GetNthClientInitiatedUnidirectionalId( + session_.max_open_incoming_unidirectional_streams() - 1))); +} + +TEST_P(QuicSessionTestServer, TooManyAvailableBidirectionalStreams) { + QuicStreamId stream_id1 = GetNthClientInitiatedBidirectionalId(0); + QuicStreamId stream_id2; + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id1)); + // A stream ID which is too large to create. + stream_id2 = GetNthClientInitiatedBidirectionalId( + session_.MaxAvailableBidirectionalStreams() + 2); + if (VersionHasIetfQuicFrames(transport_version())) { + // IETF QUIC terminates the connection with invalid stream id + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_STREAM_ID, _, _)); + } else { + // other versions terminate the connection with + // QUIC_TOO_MANY_AVAILABLE_STREAMS. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_TOO_MANY_AVAILABLE_STREAMS, _, _)); + } + EXPECT_EQ(nullptr, session_.GetOrCreateStream(stream_id2)); +} + +TEST_P(QuicSessionTestServer, TooManyAvailableUnidirectionalStreams) { + QuicStreamId stream_id1 = GetNthClientInitiatedUnidirectionalId(0); + QuicStreamId stream_id2; + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id1)); + // A stream ID which is too large to create. + stream_id2 = GetNthClientInitiatedUnidirectionalId( + session_.MaxAvailableUnidirectionalStreams() + 2); + if (VersionHasIetfQuicFrames(transport_version())) { + // IETF QUIC terminates the connection with invalid stream id + EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_STREAM_ID, _, _)); + } else { + // other versions terminate the connection with + // QUIC_TOO_MANY_AVAILABLE_STREAMS. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_TOO_MANY_AVAILABLE_STREAMS, _, _)); + } + EXPECT_EQ(nullptr, session_.GetOrCreateStream(stream_id2)); +} + +TEST_P(QuicSessionTestServer, ManyAvailableBidirectionalStreams) { + // When max_open_streams_ is 200, should be able to create 200 streams + // out-of-order, that is, creating the one with the largest stream ID first. + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, 200); + // Smaller limit on unidirectional streams to help detect crossed wires. + QuicSessionPeer::SetMaxOpenIncomingUnidirectionalStreams(&session_, 50); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, 200); + } + // Create a stream at the start of the range. + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalId(0); + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id)); + + // Create the largest stream ID of a threatened total of 200 streams. + // GetNth... starts at 0, so for 200 streams, get the 199th. + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_NE(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(199))); + + if (VersionHasIetfQuicFrames(transport_version())) { + // If IETF QUIC, check to make sure that creating bidirectional + // streams does not mess up the unidirectional streams. + stream_id = GetNthClientInitiatedUnidirectionalId(0); + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id)); + // Now try to get the last possible unidirectional stream. + EXPECT_NE(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedUnidirectionalId(49))); + // and this should fail because it exceeds the unidirectional limit + // (but not the bi-) + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Stream id 798 would exceed stream count limit 50", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)) + .Times(1); + EXPECT_EQ(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedUnidirectionalId(199))); + } +} + +TEST_P(QuicSessionTestServer, ManyAvailableUnidirectionalStreams) { + // When max_open_streams_ is 200, should be able to create 200 streams + // out-of-order, that is, creating the one with the largest stream ID first. + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingUnidirectionalStreams(&session_, 200); + // Smaller limit on unidirectional streams to help detect crossed wires. + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, 50); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, 200); + } + // Create one stream. + QuicStreamId stream_id = GetNthClientInitiatedUnidirectionalId(0); + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id)); + + // Create the largest stream ID of a threatened total of 200 streams. + // GetNth... starts at 0, so for 200 streams, get the 199th. + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_NE(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedUnidirectionalId(199))); + if (VersionHasIetfQuicFrames(transport_version())) { + // If IETF QUIC, check to make sure that creating unidirectional + // streams does not mess up the bidirectional streams. + stream_id = GetNthClientInitiatedBidirectionalId(0); + EXPECT_NE(nullptr, session_.GetOrCreateStream(stream_id)); + // Now try to get the last possible bidirectional stream. + EXPECT_NE(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(49))); + // and this should fail because it exceeds the bnidirectional limit + // (but not the uni-) + std::string error_detail; + if (QuicVersionUsesCryptoFrames(transport_version())) { + error_detail = "Stream id 796 would exceed stream count limit 50"; + } else { + error_detail = "Stream id 800 would exceed stream count limit 50"; + } + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, error_detail, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)) + .Times(1); + EXPECT_EQ(nullptr, session_.GetOrCreateStream( + GetNthClientInitiatedBidirectionalId(199))); + } +} + +TEST_P(QuicSessionTestServer, DebugDFatalIfMarkingClosedStreamWriteBlocked) { + CompleteHandshake(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId closed_stream_id = stream2->id(); + // Close the stream. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(closed_stream_id, _)); + stream2->Reset(QUIC_BAD_APPLICATION_PAYLOAD); + std::string msg = + absl::StrCat("Marking unknown stream ", closed_stream_id, " blocked."); + EXPECT_QUIC_BUG(session_.MarkConnectionLevelWriteBlocked(closed_stream_id), + msg); +} + +// SpdySession::OnCanWrite() queries QuicWriteBlockedList for the number of +// streams that are marked as connection level write blocked, then queries +// QuicWriteBlockedList that many times for what stream to write data on. This +// can result in some streams writing multiple times in a single +// SpdySession::OnCanWrite() call while other streams not getting a turn. +TEST_P(QuicSessionTestServer, OnCanWrite) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + InSequence s; + + // Reregister, to test the loop limit. + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + })); + + if (!GetQuicReloadableFlag(quic_disable_batch_write) || + GetQuicReloadableFlag(quic_priority_respect_incremental)) { + // If batched writes are enabled, stream 2 will write again. Also, streams + // are non-incremental by default, so if the incremental flag is respected, + // then stream 2 will write again. (If it is not respected, then every + // stream is treated as incremental.) + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream6, OnCanWrite()).WillOnce(Invoke([this, stream6]() { + session_.SendStreamData(stream6); + })); + } else { + EXPECT_CALL(*stream6, OnCanWrite()).WillOnce(Invoke([this, stream6]() { + session_.SendStreamData(stream6); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + } + + // Stream 4 will not get called, as we exceeded the loop limit. + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, TestBatchedWrites) { + session_.set_writev_consumes_all_data(true); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + const QuicStreamPriority priority( + HttpStreamPriority{HttpStreamPriority::kDefaultUrgency, + /* incremental = */ true}); + stream2->SetPriority(priority); + stream4->SetPriority(priority); + stream6->SetPriority(priority); + + session_.set_writev_consumes_all_data(true); + // Tell the session that stream2 and stream4 have data to write. + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + // With two sessions blocked, we should get two write calls. + InSequence s; + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendLargeFakeData(stream2, 6000); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + })); + if (GetQuicReloadableFlag(quic_disable_batch_write)) { + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendLargeFakeData(stream4, 6000); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + })); + } else { + // Since stream2 only wrote 6 kB and marked itself blocked again, + // the second write happens on the same stream. + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendLargeFakeData(stream2, 6000); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + })); + } + session_.OnCanWrite(); + + // If batched write is enabled, stream2 can write a third time in a row. + // If batched write is disabled, stream2 has a turn again after stream4. + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendLargeFakeData(stream2, 6000); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendLargeFakeData(stream4, 6000); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + })); + session_.OnCanWrite(); + + // The next write adds a block for stream 6. + stream6->SetPriority(QuicStreamPriority(HttpStreamPriority{ + kV3HighestPriority, HttpStreamPriority::kDefaultIncremental})); + if (GetQuicReloadableFlag(quic_disable_batch_write)) { + EXPECT_CALL(*stream2, OnCanWrite()) + .WillOnce(Invoke([this, stream2, stream6]() { + session_.SendLargeFakeData(stream2, 6000); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + })); + } else { + EXPECT_CALL(*stream4, OnCanWrite()) + .WillOnce(Invoke([this, stream4, stream6]() { + session_.SendLargeFakeData(stream4, 6000); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + })); + } + // Stream 6 will write next, because it has higher priority. + // It does not mark itself as blocked. + EXPECT_CALL(*stream6, OnCanWrite()) + .WillOnce(Invoke([this, stream4, stream6]() { + session_.SendStreamData(stream6); + session_.SendLargeFakeData(stream4, 6000); + })); + session_.OnCanWrite(); + + // If batched write is enabled, stream4 can continue to write, but will + // exhaust its write limit, so the last write is on stream2. + // If batched write is disabled, stream4 has a turn again, then stream2. + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendLargeFakeData(stream4, 12000); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + })); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendLargeFakeData(stream2, 6000); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + })); + session_.OnCanWrite(); +} + +TEST_P(QuicSessionTestServer, OnCanWriteBundlesStreams) { + // Encryption needs to be established before data can be sent. + CompleteHandshake(); + MockPacketWriter* writer = static_cast( + QuicConnectionPeer::GetWriter(session_.connection())); + + // Drive congestion control manually. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm, GetCongestionWindow()) + .WillRepeatedly(Return(kMaxOutgoingPacketSize * 10)); + EXPECT_CALL(*send_algorithm, InRecovery()).WillRepeatedly(Return(false)); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + EXPECT_CALL(*stream6, OnCanWrite()).WillOnce(Invoke([this, stream6]() { + session_.SendStreamData(stream6); + })); + + // Expect that we only send one packet, the writes from different streams + // should be bundled together. + EXPECT_CALL(*writer, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + EXPECT_CALL(*send_algorithm, OnPacketSent(_, _, _, _, _)); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, OnCanWriteCongestionControlBlocks) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + InSequence s; + + // Drive congestion control manually. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*send_algorithm, GetCongestionWindow()).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream6, OnCanWrite()).WillOnce(Invoke([this, stream6]() { + session_.SendStreamData(stream6); + })); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(false)); + // stream4->OnCanWrite is not called. + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // Still congestion-control blocked. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(false)); + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // stream4->OnCanWrite is called once the connection stops being + // congestion-control blocked. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, OnCanWriteWriterBlocks) { + CompleteHandshake(); + // Drive congestion control manually in order to ensure that + // application-limited signaling is handled correctly. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(true)); + + // Drive packet writer manually. + MockPacketWriter* writer = static_cast( + QuicConnectionPeer::GetWriter(session_.connection())); + EXPECT_CALL(*writer, IsWriteBlocked()).WillRepeatedly(Return(true)); + EXPECT_CALL(*writer, WritePacket(_, _, _, _, _)).Times(0); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + + EXPECT_CALL(*stream2, OnCanWrite()).Times(0); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)).Times(0); + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, SendStreamsBlocked) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + CompleteHandshake(); + for (size_t i = 0; i < kDefaultMaxStreamsPerConnection; ++i) { + ASSERT_TRUE(session_.CanOpenNextOutgoingBidirectionalStream()); + session_.GetNextOutgoingBidirectionalStreamId(); + } + // Next checking causes STREAMS_BLOCKED to be sent. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke([](const QuicFrame& frame) { + EXPECT_FALSE(frame.streams_blocked_frame.unidirectional); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + frame.streams_blocked_frame.stream_count); + ClearControlFrame(frame); + return true; + })); + EXPECT_FALSE(session_.CanOpenNextOutgoingBidirectionalStream()); + + for (size_t i = 0; i < kDefaultMaxStreamsPerConnection; ++i) { + ASSERT_TRUE(session_.CanOpenNextOutgoingUnidirectionalStream()); + session_.GetNextOutgoingUnidirectionalStreamId(); + } + // Next checking causes STREAM_BLOCKED to be sent. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke([](const QuicFrame& frame) { + EXPECT_TRUE(frame.streams_blocked_frame.unidirectional); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + frame.streams_blocked_frame.stream_count); + ClearControlFrame(frame); + return true; + })); + EXPECT_FALSE(session_.CanOpenNextOutgoingUnidirectionalStream()); +} + +TEST_P(QuicSessionTestServer, BufferedHandshake) { + // This test is testing behavior of crypto stream flow control, but when + // CRYPTO frames are used, there is no flow control for the crypto handshake. + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + session_.set_writev_consumes_all_data(true); + EXPECT_FALSE(session_.HasPendingHandshake()); // Default value. + + // Test that blocking other streams does not change our status. + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + EXPECT_FALSE(session_.HasPendingHandshake()); + + TestStream* stream3 = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream3->id()); + EXPECT_FALSE(session_.HasPendingHandshake()); + + // Blocking (due to buffering of) the Crypto stream is detected. + session_.MarkConnectionLevelWriteBlocked( + QuicUtils::GetCryptoStreamId(connection_->transport_version())); + EXPECT_TRUE(session_.HasPendingHandshake()); + + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + EXPECT_TRUE(session_.HasPendingHandshake()); + + InSequence s; + // Force most streams to re-register, which is common scenario when we block + // the Crypto stream, and only the crypto stream can "really" write. + + // Due to prioritization, we *should* be asked to write the crypto stream + // first. + // Don't re-register the crypto stream (which signals complete writing). + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, OnCanWrite()); + + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream3, OnCanWrite()).WillOnce(Invoke([this, stream3]() { + session_.SendStreamData(stream3); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + })); + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + EXPECT_FALSE(session_.HasPendingHandshake()); // Crypto stream wrote. +} + +TEST_P(QuicSessionTestServer, OnCanWriteWithClosedStream) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + CloseStream(stream6->id()); + + InSequence s; + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce(Invoke([this, stream2]() { + session_.SendStreamData(stream2); + })); + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce(Invoke([this, stream4]() { + session_.SendStreamData(stream4); + })); + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, OnCanWriteLimitsNumWritesIfFlowControlBlocked) { + // Drive congestion control manually in order to ensure that + // application-limited signaling is handled correctly. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(true)); + + // Ensure connection level flow control blockage. + QuicFlowControllerPeer::SetSendWindowOffset(session_.flow_controller(), 0); + EXPECT_TRUE(session_.flow_controller()->IsBlocked()); + EXPECT_TRUE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + + // Mark the crypto and headers streams as write blocked, we expect them to be + // allowed to write later. + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + session_.MarkConnectionLevelWriteBlocked( + QuicUtils::GetCryptoStreamId(connection_->transport_version())); + } + + // Create a data stream, and although it is write blocked we never expect it + // to be allowed to write as we are connection level flow control blocked. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + session_.MarkConnectionLevelWriteBlocked(stream->id()); + EXPECT_CALL(*stream, OnCanWrite()).Times(0); + + // The crypto and headers streams should be called even though we are + // connection flow control blocked. + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_CALL(*crypto_stream, OnCanWrite()); + } + + // After the crypto and header streams perform a write, the connection will be + // blocked by the flow control, hence it should become application-limited. + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, SendGoAway) { + if (VersionHasIetfQuicFrames(transport_version())) { + // In IETF QUIC, GOAWAY lives up in the HTTP layer. + return; + } + CompleteHandshake(); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + MockPacketWriter* writer = static_cast( + QuicConnectionPeer::GetWriter(session_.connection())); + EXPECT_CALL(*writer, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallySendControlFrame)); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); + EXPECT_TRUE(session_.transport_goaway_sent()); + + const QuicStreamId kTestStreamId = 5u; + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(0); + EXPECT_CALL(*connection_, + OnStreamReset(kTestStreamId, QUIC_STREAM_PEER_GOING_AWAY)) + .Times(0); + EXPECT_TRUE(session_.GetOrCreateStream(kTestStreamId)); +} + +TEST_P(QuicSessionTestServer, DoNotSendGoAwayTwice) { + CompleteHandshake(); + if (VersionHasIetfQuicFrames(transport_version())) { + // In IETF QUIC, GOAWAY lives up in the HTTP layer. + return; + } + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); + EXPECT_TRUE(session_.transport_goaway_sent()); + session_.SendGoAway(QUIC_PEER_GOING_AWAY, "Going Away."); +} + +TEST_P(QuicSessionTestServer, InvalidGoAway) { + if (VersionHasIetfQuicFrames(transport_version())) { + // In IETF QUIC, GOAWAY lives up in the HTTP layer. + return; + } + QuicGoAwayFrame go_away(kInvalidControlFrameId, QUIC_PEER_GOING_AWAY, + session_.next_outgoing_bidirectional_stream_id(), ""); + session_.OnGoAway(go_away); +} + +// Test that server session will send a connectivity probe in response to a +// connectivity probe on the same path. +TEST_P(QuicSessionTestServer, ServerReplyToConnectivityProbe) { + if (VersionHasIetfQuicFrames(transport_version())) { + return; + } + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicSocketAddress old_peer_address = + QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort); + EXPECT_EQ(old_peer_address, session_.peer_address()); + + QuicSocketAddress new_peer_address = + QuicSocketAddress(QuicIpAddress::Loopback4(), kTestPort + 1); + + MockPacketWriter* writer = static_cast( + QuicConnectionPeer::GetWriter(session_.connection())); + EXPECT_CALL(*writer, WritePacket(_, _, _, new_peer_address, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + + EXPECT_CALL(*connection_, SendConnectivityProbingPacket(_, _)) + .WillOnce( + Invoke(connection_, + &MockQuicConnection::ReallySendConnectivityProbingPacket)); + session_.OnPacketReceived(session_.self_address(), new_peer_address, + /*is_connectivity_probe=*/true); + EXPECT_EQ(old_peer_address, session_.peer_address()); +} + +TEST_P(QuicSessionTestServer, IncreasedTimeoutAfterCryptoHandshake) { + EXPECT_EQ(kInitialIdleTimeoutSecs + 3, + QuicConnectionPeer::GetNetworkTimeout(connection_).ToSeconds()); + CompleteHandshake(); + EXPECT_EQ(kMaximumIdleTimeoutSecs + 3, + QuicConnectionPeer::GetNetworkTimeout(connection_).ToSeconds()); +} + +TEST_P(QuicSessionTestServer, OnStreamFrameFinStaticStreamId) { + if (VersionUsesHttp3(connection_->transport_version())) { + // The test relies on headers stream, which no longer exists in IETF QUIC. + return; + } + QuicStreamId headers_stream_id = + QuicUtils::GetHeadersStreamId(connection_->transport_version()); + std::unique_ptr fake_headers_stream = + std::make_unique(headers_stream_id, &session_, + /*is_static*/ true, BIDIRECTIONAL); + QuicSessionPeer::ActivateStream(&session_, std::move(fake_headers_stream)); + // Send two bytes of payload. + QuicStreamFrame data1(headers_stream_id, true, 0, absl::string_view("HT")); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_STREAM_ID, "Attempt to close a static stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSessionTestServer, OnStreamFrameInvalidStreamId) { + // Send two bytes of payload. + QuicStreamFrame data1( + QuicUtils::GetInvalidStreamId(connection_->transport_version()), true, 0, + absl::string_view("HT")); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_STREAM_ID, "Received data for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSessionTestServer, OnRstStreamInvalidStreamId) { + // Send two bytes of payload. + QuicRstStreamFrame rst1( + kInvalidControlFrameId, + QuicUtils::GetInvalidStreamId(connection_->transport_version()), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_INVALID_STREAM_ID, "Received data for an invalid stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + session_.OnRstStream(rst1); +} + +TEST_P(QuicSessionTestServer, HandshakeUnblocksFlowControlBlockedStream) { + if (connection_->version().handshake_protocol == PROTOCOL_TLS1_3) { + // This test requires Google QUIC crypto because it assumes streams start + // off unblocked. + return; + } + // Test that if a stream is flow control blocked, then on receipt of the SHLO + // containing a suitable send window offset, the stream becomes unblocked. + + // Ensure that Writev consumes all the data it is given (simulate no socket + // blocking). + session_.set_writev_consumes_all_data(true); + session_.GetMutableCryptoStream()->EstablishZeroRttEncryption(); + + // Create a stream, and send enough data to make it flow control blocked. + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + std::string body(kMinimumFlowControlSendWindow, '.'); + EXPECT_FALSE(stream2->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(AtLeast(1)); + stream2->WriteOrBufferData(body, false, nullptr); + EXPECT_TRUE(stream2->IsFlowControlBlocked()); + EXPECT_TRUE(session_.IsConnectionFlowControlBlocked()); + EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); + + // Now complete the crypto handshake, resulting in an increased flow control + // send window. + CompleteHandshake(); + EXPECT_TRUE(QuicSessionPeer::IsStreamWriteBlocked(&session_, stream2->id())); + // Stream is now unblocked. + EXPECT_FALSE(stream2->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); +} + +TEST_P(QuicSessionTestServer, ConnectionFlowControlAccountingRstOutOfOrder) { + CompleteHandshake(); + // Test that when we receive an out of order stream RST we correctly adjust + // our connection level flow control receive window. + // On close, the stream should mark as consumed all bytes between the highest + // byte consumed so far and the final byte offset from the RST frame. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + + const QuicStreamOffset kByteOffset = + 1 + kInitialSessionFlowControlWindowForTest / 2; + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream->id(), + QUIC_STREAM_CANCELLED, kByteOffset); + session_.OnRstStream(rst_frame); + if (VersionHasIetfQuicFrames(transport_version())) { + // The test requires the stream to be fully closed in both directions. For + // IETF QUIC, the RST_STREAM only closes one side. + QuicStopSendingFrame frame(kInvalidControlFrameId, stream->id(), + QUIC_STREAM_CANCELLED); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStopSendingFrame(frame); + } + EXPECT_EQ(kByteOffset, session_.flow_controller()->bytes_consumed()); +} + +TEST_P(QuicSessionTestServer, ConnectionFlowControlAccountingFinAndLocalReset) { + CompleteHandshake(); + // Test the situation where we receive a FIN on a stream, and before we fully + // consume all the data from the sequencer buffer we locally RST the stream. + // The bytes between highest consumed byte, and the final byte offset that we + // determined when the FIN arrived, should be marked as consumed at the + // connection level flow controller when the stream is reset. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + + const QuicStreamOffset kByteOffset = + kInitialSessionFlowControlWindowForTest / 2 - 1; + QuicStreamFrame frame(stream->id(), true, kByteOffset, "."); + session_.OnStreamFrame(frame); + EXPECT_TRUE(connection_->connected()); + + EXPECT_EQ(0u, session_.flow_controller()->bytes_consumed()); + EXPECT_EQ(kByteOffset + frame.data_length, + stream->highest_received_byte_offset()); + + // Reset stream locally. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + stream->Reset(QUIC_STREAM_CANCELLED); + EXPECT_EQ(kByteOffset + frame.data_length, + session_.flow_controller()->bytes_consumed()); +} + +TEST_P(QuicSessionTestServer, ConnectionFlowControlAccountingFinAfterRst) { + CompleteHandshake(); + // Test that when we RST the stream (and tear down stream state), and then + // receive a FIN from the peer, we correctly adjust our connection level flow + // control receive window. + + // Connection starts with some non-zero highest received byte offset, + // due to other active streams. + const uint64_t kInitialConnectionBytesConsumed = 567; + const uint64_t kInitialConnectionHighestReceivedOffset = 1234; + EXPECT_LT(kInitialConnectionBytesConsumed, + kInitialConnectionHighestReceivedOffset); + session_.flow_controller()->UpdateHighestReceivedOffset( + kInitialConnectionHighestReceivedOffset); + session_.flow_controller()->AddBytesConsumed(kInitialConnectionBytesConsumed); + + // Reset our stream: this results in the stream being closed locally. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + stream->Reset(QUIC_STREAM_CANCELLED); + + // Now receive a response from the peer with a FIN. We should handle this by + // adjusting the connection level flow control receive window to take into + // account the total number of bytes sent by the peer. + const QuicStreamOffset kByteOffset = 5678; + std::string body = "hello"; + QuicStreamFrame frame(stream->id(), true, kByteOffset, + absl::string_view(body)); + session_.OnStreamFrame(frame); + + QuicStreamOffset total_stream_bytes_sent_by_peer = + kByteOffset + body.length(); + EXPECT_EQ(kInitialConnectionBytesConsumed + total_stream_bytes_sent_by_peer, + session_.flow_controller()->bytes_consumed()); + EXPECT_EQ( + kInitialConnectionHighestReceivedOffset + total_stream_bytes_sent_by_peer, + session_.flow_controller()->highest_received_byte_offset()); +} + +TEST_P(QuicSessionTestServer, ConnectionFlowControlAccountingRstAfterRst) { + CompleteHandshake(); + // Test that when we RST the stream (and tear down stream state), and then + // receive a RST from the peer, we correctly adjust our connection level flow + // control receive window. + + // Connection starts with some non-zero highest received byte offset, + // due to other active streams. + const uint64_t kInitialConnectionBytesConsumed = 567; + const uint64_t kInitialConnectionHighestReceivedOffset = 1234; + EXPECT_LT(kInitialConnectionBytesConsumed, + kInitialConnectionHighestReceivedOffset); + session_.flow_controller()->UpdateHighestReceivedOffset( + kInitialConnectionHighestReceivedOffset); + session_.flow_controller()->AddBytesConsumed(kInitialConnectionBytesConsumed); + + // Reset our stream: this results in the stream being closed locally. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + stream->Reset(QUIC_STREAM_CANCELLED); + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream)); + + // Now receive a RST from the peer. We should handle this by adjusting the + // connection level flow control receive window to take into account the total + // number of bytes sent by the peer. + const QuicStreamOffset kByteOffset = 5678; + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream->id(), + QUIC_STREAM_CANCELLED, kByteOffset); + session_.OnRstStream(rst_frame); + + EXPECT_EQ(kInitialConnectionBytesConsumed + kByteOffset, + session_.flow_controller()->bytes_consumed()); + EXPECT_EQ(kInitialConnectionHighestReceivedOffset + kByteOffset, + session_.flow_controller()->highest_received_byte_offset()); +} + +TEST_P(QuicSessionTestServer, InvalidStreamFlowControlWindowInHandshake) { + // Test that receipt of an invalid (< default) stream flow control window from + // the peer results in the connection being torn down. + const uint32_t kInvalidWindow = kMinimumFlowControlSendWindow - 1; + QuicConfigPeer::SetReceivedInitialStreamFlowControlWindow(session_.config(), + kInvalidWindow); + + if (connection_->version().handshake_protocol != PROTOCOL_TLS1_3) { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_INVALID_WINDOW, _, _)); + } else { + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + } + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); +} + +// Test negotiation of custom server initial flow control window. +TEST_P(QuicSessionTestServer, CustomFlowControlWindow) { + QuicTagVector copt; + copt.push_back(kIFW7); + QuicConfigPeer::SetReceivedConnectionOptions(session_.config(), copt); + + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); + EXPECT_EQ(192 * 1024u, QuicFlowControllerPeer::ReceiveWindowSize( + session_.flow_controller())); +} + +TEST_P(QuicSessionTestServer, FlowControlWithInvalidFinalOffset) { + CompleteHandshake(); + // Test that if we receive a stream RST with a highest byte offset that + // violates flow control, that we close the connection. + const uint64_t kLargeOffset = kInitialSessionFlowControlWindowForTest + 1; + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)) + .Times(2); + + // Check that stream frame + FIN results in connection close. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + stream->Reset(QUIC_STREAM_CANCELLED); + QuicStreamFrame frame(stream->id(), true, kLargeOffset, absl::string_view()); + session_.OnStreamFrame(frame); + + // Check that RST results in connection close. + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream->id(), + QUIC_STREAM_CANCELLED, kLargeOffset); + session_.OnRstStream(rst_frame); +} + +TEST_P(QuicSessionTestServer, TooManyUnfinishedStreamsCauseServerRejectStream) { + CompleteHandshake(); + // If a buggy/malicious peer creates too many streams that are not ended + // with a FIN or RST then we send an RST to refuse streams. For IETF QUIC the + // connection is closed. + const QuicStreamId kMaxStreams = 5; + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, + kMaxStreams); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, kMaxStreams); + } + const QuicStreamId kFirstStreamId = GetNthClientInitiatedBidirectionalId(0); + const QuicStreamId kFinalStreamId = + GetNthClientInitiatedBidirectionalId(kMaxStreams); + // Create kMaxStreams data streams, and close them all without receiving a + // FIN or a RST_STREAM from the client. + for (QuicStreamId i = kFirstStreamId; i < kFinalStreamId; + i += QuicUtils::StreamIdDelta(connection_->transport_version())) { + QuicStreamFrame data1(i, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); + CloseStream(i); + } + + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Stream id 20 would exceed stream count limit 5", _)); + } else { + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(1); + EXPECT_CALL(*connection_, + OnStreamReset(kFinalStreamId, QUIC_REFUSED_STREAM)) + .Times(1); + } + // Create one more data streams to exceed limit of open stream. + QuicStreamFrame data1(kFinalStreamId, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); +} + +TEST_P(QuicSessionTestServer, DrainingStreamsDoNotCountAsOpenedOutgoing) { + // Verify that a draining stream (which has received a FIN but not consumed + // it) does not count against the open quota (because it is closed from the + // protocol point of view). + CompleteHandshake(); + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId stream_id = stream->id(); + QuicStreamFrame data1(stream_id, true, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); + if (!VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(1); + } + session_.StreamDraining(stream_id, /*unidirectional=*/false); +} + +TEST_P(QuicSessionTestServer, NoPendingStreams) { + session_.set_uses_pending_streams(false); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + + QuicStreamFrame data2(stream_id, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_EQ(1, session_.num_incoming_streams_created()); +} + +TEST_P(QuicSessionTestServer, PendingStreams) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(true); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + QuicStreamFrame data2(stream_id, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); +} + +TEST_P(QuicSessionTestServer, BufferAllIncomingStreams) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + // Read unidirectional stream is still buffered when the first byte arrives. + QuicStreamFrame data2(stream_id, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + // Bidirectional stream is buffered. + QuicStreamId bidirectional_stream_id = + QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT); + QuicStreamFrame data3(bidirectional_stream_id, false, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data3); + EXPECT_TRUE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + session_.ProcessAllPendingStreams(); + // Both bidirectional and read-unidirectional streams are unbuffered. + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_FALSE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(2, session_.num_incoming_streams_created()); +} + +TEST_P(QuicSessionTestServer, RstPendingStreams) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + QuicRstStreamFrame rst1(kInvalidControlFrameId, stream_id, + QUIC_ERROR_PROCESSING_STREAM, 12); + session_.OnRstStream(rst1); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + QuicStreamFrame data2(stream_id, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + session_.ProcessAllPendingStreams(); + // Bidirectional stream is buffered. + QuicStreamId bidirectional_stream_id = + QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT); + QuicStreamFrame data3(bidirectional_stream_id, false, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data3); + EXPECT_TRUE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + // Bidirectional pending stream is removed after RST_STREAM is received. + QuicRstStreamFrame rst2(kInvalidControlFrameId, bidirectional_stream_id, + QUIC_ERROR_PROCESSING_STREAM, 12); + session_.OnRstStream(rst2); + EXPECT_FALSE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); +} + +TEST_P(QuicSessionTestServer, OnFinPendingStreams) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(true); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data(stream_id, true, 0, ""); + session_.OnStreamFrame(data); + + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + session_.set_process_pending_stream_immediately(false); + // Bidirectional pending stream remains after Fin is received. + // Bidirectional stream is buffered. + QuicStreamId bidirectional_stream_id = + QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT); + QuicStreamFrame data2(bidirectional_stream_id, true, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_TRUE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + session_.ProcessAllPendingStreams(); + EXPECT_FALSE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + QuicStream* bidirectional_stream = + QuicSessionPeer::GetStream(&session_, bidirectional_stream_id); + EXPECT_TRUE(bidirectional_stream->fin_received()); +} + +TEST_P(QuicSessionTestServer, UnidirectionalPendingStreamOnWindowUpdate) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + QuicWindowUpdateFrame window_update_frame(kInvalidControlFrameId, stream_id, + 0); + EXPECT_CALL( + *connection_, + CloseConnection( + QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM, + "WindowUpdateFrame received on READ_UNIDIRECTIONAL stream.", _)); + session_.OnWindowUpdateFrame(window_update_frame); +} + +TEST_P(QuicSessionTestServer, BidirectionalPendingStreamOnWindowUpdate) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data); + QuicWindowUpdateFrame window_update_frame(kInvalidControlFrameId, stream_id, + kDefaultFlowControlSendWindow * 2); + session_.OnWindowUpdateFrame(window_update_frame); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + session_.ProcessAllPendingStreams(); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + QuicStream* bidirectional_stream = + QuicSessionPeer::GetStream(&session_, stream_id); + QuicByteCount send_window = + QuicStreamPeer::SendWindowSize(bidirectional_stream); + EXPECT_EQ(send_window, kDefaultFlowControlSendWindow * 2); +} + +TEST_P(QuicSessionTestServer, UnidirectionalPendingStreamOnStopSending) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + QuicStopSendingFrame stop_sending_frame(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received STOP_SENDING for a read-only stream", _)); + session_.OnStopSendingFrame(stop_sending_frame); +} + +TEST_P(QuicSessionTestServer, BidirectionalPendingStreamOnStopSending) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data(stream_id, true, 0, absl::string_view("HT")); + session_.OnStreamFrame(data); + QuicStopSendingFrame stop_sending_frame(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED); + session_.OnStopSendingFrame(stop_sending_frame); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + EXPECT_CALL(*connection_, OnStreamReset(stream_id, _)); + session_.ProcessAllPendingStreams(); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + QuicStream* bidirectional_stream = + QuicSessionPeer::GetStream(&session_, stream_id); + EXPECT_TRUE(bidirectional_stream->write_side_closed()); +} + +TEST_P(QuicSessionTestServer, DrainingStreamsDoNotCountAsOpened) { + // Verify that a draining stream (which has received a FIN but not consumed + // it) does not count against the open quota (because it is closed from the + // protocol point of view). + CompleteHandshake(); + if (VersionHasIetfQuicFrames(transport_version())) { + // On IETF QUIC, we will expect to see a MAX_STREAMS go out when there are + // not enough streams to create the next one. + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(1); + } else { + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(0); + } + EXPECT_CALL(*connection_, OnStreamReset(_, QUIC_REFUSED_STREAM)).Times(0); + const QuicStreamId kMaxStreams = 5; + if (VersionHasIetfQuicFrames(transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams(&session_, + kMaxStreams); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(&session_, kMaxStreams); + } + + // Create kMaxStreams + 1 data streams, and mark them draining. + const QuicStreamId kFirstStreamId = GetNthClientInitiatedBidirectionalId(0); + const QuicStreamId kFinalStreamId = + GetNthClientInitiatedBidirectionalId(2 * kMaxStreams + 1); + for (QuicStreamId i = kFirstStreamId; i < kFinalStreamId; + i += QuicUtils::StreamIdDelta(connection_->transport_version())) { + QuicStreamFrame data1(i, true, 0, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + session_.StreamDraining(i, /*unidirectional=*/false); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + } +} + +class QuicSessionTestClient : public QuicSessionTestBase { + protected: + QuicSessionTestClient() + : QuicSessionTestBase(Perspective::IS_CLIENT, + /*configure_session=*/true) {} +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSessionTestClient, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSessionTestClient, AvailableBidirectionalStreamsClient) { + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedBidirectionalId(2)) != nullptr); + // Smaller bidirectional streams should be available. + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthServerInitiatedBidirectionalId(0))); + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthServerInitiatedBidirectionalId(1))); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedBidirectionalId(0)) != nullptr); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedBidirectionalId(1)) != nullptr); + // And 5 should be not available. + EXPECT_FALSE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedBidirectionalId(1))); +} + +TEST_P(QuicSessionTestClient, NewStreamCreationResumesMultiPortProbing) { + session_.config()->SetConnectionOptionsToSend({kRVCM}); + session_.config()->SetClientConnectionOptions({kMPQC}); + session_.Initialize(); + connection_->CreateConnectionIdManager(); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_->OnHandshakeComplete(); + session_.OnConfigNegotiated(); + + if (!connection_->connection_migration_use_new_cid()) { + return; + } + + EXPECT_CALL(*connection_, MaybeProbeMultiPortPath()); + session_.CreateOutgoingBidirectionalStream(); +} + +TEST_P(QuicSessionTestClient, InvalidSessionFlowControlWindowInHandshake) { + // Test that receipt of an invalid (< default for gQUIC, < current for TLS) + // session flow control window from the peer results in the connection being + // torn down. + const uint32_t kInvalidWindow = kMinimumFlowControlSendWindow - 1; + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow(session_.config(), + kInvalidWindow); + EXPECT_CALL( + *connection_, + CloseConnection(connection_->version().AllowsLowFlowControlLimits() + ? QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED + : QUIC_FLOW_CONTROL_INVALID_WINDOW, + _, _)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); +} + +TEST_P(QuicSessionTestClient, InvalidBidiStreamLimitInHandshake) { + // IETF QUIC only feature. + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + QuicConfigPeer::SetReceivedMaxBidirectionalStreams( + session_.config(), kDefaultMaxStreamsPerConnection - 1); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, _, _)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); +} + +TEST_P(QuicSessionTestClient, InvalidUniStreamLimitInHandshake) { + // IETF QUIC only feature. + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams( + session_.config(), kDefaultMaxStreamsPerConnection - 1); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, _, _)); + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); +} + +TEST_P(QuicSessionTestClient, InvalidStreamFlowControlWindowInHandshake) { + // IETF QUIC only feature. + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + session_.CreateOutgoingBidirectionalStream(); + session_.CreateOutgoingBidirectionalStream(); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_.config(), kMinimumFlowControlSendWindow - 1); + + EXPECT_CALL(*connection_, CloseConnection(_, _, _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); + + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); +} + +TEST_P(QuicSessionTestClient, OnMaxStreamFrame) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + QuicMaxStreamsFrame frame; + frame.unidirectional = false; + frame.stream_count = 120; + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(1); + session_.OnMaxStreamsFrame(frame); + + QuicMaxStreamsFrame frame2; + frame2.unidirectional = false; + frame2.stream_count = 110; + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(0); + session_.OnMaxStreamsFrame(frame2); +} + +TEST_P(QuicSessionTestClient, AvailableUnidirectionalStreamsClient) { + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedUnidirectionalId(2)) != nullptr); + // Smaller unidirectional streams should be available. + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthServerInitiatedUnidirectionalId(0))); + EXPECT_TRUE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthServerInitiatedUnidirectionalId(1))); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedUnidirectionalId(0)) != nullptr); + ASSERT_TRUE(session_.GetOrCreateStream( + GetNthServerInitiatedUnidirectionalId(1)) != nullptr); + // And 5 should be not available. + EXPECT_FALSE(QuicSessionPeer::IsStreamAvailable( + &session_, GetNthClientInitiatedUnidirectionalId(1))); +} + +TEST_P(QuicSessionTestClient, RecordFinAfterReadSideClosed) { + CompleteHandshake(); + // Verify that an incoming FIN is recorded in a stream object even if the read + // side has been closed. This prevents an entry from being made in + // locally_closed_streams_highest_offset_ (which will never be deleted). + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId stream_id = stream->id(); + + // Close the read side manually. + QuicStreamPeer::CloseReadSide(stream); + + // Receive a stream data frame with FIN. + QuicStreamFrame frame(stream_id, true, 0, absl::string_view()); + session_.OnStreamFrame(frame); + EXPECT_TRUE(stream->fin_received()); + + // Reset stream locally. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + stream->Reset(QUIC_STREAM_CANCELLED); + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream)); + + EXPECT_TRUE(connection_->connected()); + EXPECT_TRUE(QuicSessionPeer::IsStreamClosed(&session_, stream_id)); + EXPECT_FALSE(QuicSessionPeer::IsStreamCreated(&session_, stream_id)); + + // The stream is not waiting for the arrival of the peer's final offset as it + // was received with the FIN earlier. + EXPECT_EQ( + 0u, + QuicSessionPeer::GetLocallyClosedStreamsHighestOffset(&session_).size()); +} + +TEST_P(QuicSessionTestClient, IncomingStreamWithClientInitiatedStreamId) { + const QuicErrorCode expected_error = + VersionHasIetfQuicFrames(transport_version()) + ? QUIC_HTTP_STREAM_WRONG_DIRECTION + : QUIC_INVALID_STREAM_ID; + EXPECT_CALL( + *connection_, + CloseConnection(expected_error, "Data for nonexistent stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + + QuicStreamFrame frame(GetNthClientInitiatedBidirectionalId(1), + /* fin = */ false, /* offset = */ 0, + absl::string_view("foo")); + session_.OnStreamFrame(frame); +} + +TEST_P(QuicSessionTestClient, MinAckDelaySetOnTheClientQuicConfig) { + if (!session_.version().HasIetfQuicFrames()) { + return; + } + session_.config()->SetClientConnectionOptions({kAFFE}); + session_.Initialize(); + ASSERT_EQ(session_.config()->GetMinAckDelayToSendMs(), + kDefaultMinAckDelayTimeMs); + ASSERT_TRUE(session_.connection()->can_receive_ack_frequency_frame()); +} + +TEST_P(QuicSessionTestClient, FailedToCreateStreamIfTooCloseToIdleTimeout) { + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + EXPECT_TRUE(session_.CanOpenNextOutgoingBidirectionalStream()); + QuicTime deadline = QuicConnectionPeer::GetIdleNetworkDeadline(connection_); + ASSERT_TRUE(deadline.IsInitialized()); + QuicTime::Delta timeout = deadline - helper_.GetClock()->ApproximateNow(); + // Advance time to very close idle timeout. + connection_->AdvanceTime(timeout - QuicTime::Delta::FromMilliseconds(1)); + // Verify creation of new stream gets pushed back and connectivity probing + // packet gets sent. + EXPECT_CALL(*connection_, SendConnectivityProbingPacket(_, _)).Times(1); + EXPECT_FALSE(session_.CanOpenNextOutgoingBidirectionalStream()); + + // New packet gets received, idle deadline gets extended. + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)); + QuicConnectionPeer::GetIdleNetworkDetector(connection_) + .OnPacketReceived(helper_.GetClock()->ApproximateNow()); + session_.OnPacketDecrypted(ENCRYPTION_FORWARD_SECURE); + + EXPECT_TRUE(session_.CanOpenNextOutgoingBidirectionalStream()); +} + +TEST_P(QuicSessionTestServer, ZombieStreams) { + CompleteHandshake(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + QuicStreamPeer::SetStreamBytesWritten(3, stream2); + EXPECT_TRUE(stream2->IsWaitingForAcks()); + + CloseStream(stream2->id()); + ASSERT_EQ(1u, session_.closed_streams()->size()); + EXPECT_EQ(stream2->id(), session_.closed_streams()->front()->id()); + session_.MaybeCloseZombieStream(stream2->id()); + EXPECT_EQ(1u, session_.closed_streams()->size()); + EXPECT_EQ(stream2->id(), session_.closed_streams()->front()->id()); +} + +TEST_P(QuicSessionTestServer, RstStreamReceivedAfterRstStreamSent) { + CompleteHandshake(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + QuicStreamPeer::SetStreamBytesWritten(3, stream2); + EXPECT_TRUE(stream2->IsWaitingForAcks()); + + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream2->id(), _)); + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(0); + stream2->Reset(quic::QUIC_STREAM_CANCELLED); + + QuicRstStreamFrame rst1(kInvalidControlFrameId, stream2->id(), + QUIC_ERROR_PROCESSING_STREAM, 0); + if (!VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(1); + } + session_.OnRstStream(rst1); +} + +// Regression test of b/71548958. +TEST_P(QuicSessionTestServer, TestZombieStreams) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + std::string body(100, '.'); + stream2->WriteOrBufferData(body, false, nullptr); + EXPECT_TRUE(stream2->IsWaitingForAcks()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream2).size()); + + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream2->id(), + QUIC_STREAM_CANCELLED, 1234); + // Just for the RST_STREAM + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*connection_, + OnStreamReset(stream2->id(), QUIC_STREAM_CANCELLED)); + } else { + EXPECT_CALL(*connection_, + OnStreamReset(stream2->id(), QUIC_RST_ACKNOWLEDGEMENT)); + } + stream2->OnStreamReset(rst_frame); + + if (VersionHasIetfQuicFrames(transport_version())) { + // The test requires the stream to be fully closed in both directions. For + // IETF QUIC, the RST_STREAM only closes one side. + QuicStopSendingFrame frame(kInvalidControlFrameId, stream2->id(), + QUIC_STREAM_CANCELLED); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStopSendingFrame(frame); + } + ASSERT_EQ(1u, session_.closed_streams()->size()); + EXPECT_EQ(stream2->id(), session_.closed_streams()->front()->id()); + + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + if (VersionHasIetfQuicFrames(transport_version())) { + // Once for the RST_STREAM, once for the STOP_SENDING + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrame)); + } else { + // Just for the RST_STREAM + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(1); + } + EXPECT_CALL(*connection_, + OnStreamReset(stream4->id(), QUIC_STREAM_CANCELLED)); + stream4->WriteOrBufferData(body, false, nullptr); + // Note well: Reset() actually closes the stream in both directions. For + // GOOGLE QUIC it sends a RST_STREAM (which does a 2-way close), for IETF + // QUIC it sends both a RST_STREAM and a STOP_SENDING (each of which + // closes in only one direction). + stream4->Reset(QUIC_STREAM_CANCELLED); + EXPECT_EQ(2u, session_.closed_streams()->size()); +} + +TEST_P(QuicSessionTestServer, OnStreamFrameLost) { + CompleteHandshake(); + InSequence s; + + // Drive congestion control manually. + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + + QuicStreamFrame frame1; + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + frame1 = QuicStreamFrame( + QuicUtils::GetCryptoStreamId(connection_->transport_version()), false, + 0, 1300); + } + QuicStreamFrame frame2(stream2->id(), false, 0, 9); + QuicStreamFrame frame3(stream4->id(), false, 0, 9); + + // Lost data on cryption stream, streams 2 and 4. + EXPECT_CALL(*stream4, HasPendingRetransmission()).WillOnce(Return(true)); + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()) + .WillOnce(Return(true)); + } + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(true)); + session_.OnFrameLost(QuicFrame(frame3)); + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + session_.OnFrameLost(QuicFrame(frame1)); + } else { + QuicCryptoFrame crypto_frame(ENCRYPTION_INITIAL, 0, 1300); + session_.OnFrameLost(QuicFrame(&crypto_frame)); + } + session_.OnFrameLost(QuicFrame(frame2)); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // Mark streams 2 and 4 write blocked. + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + + // Lost data is retransmitted before new data, and retransmissions for crypto + // stream go first. + // Do not check congestion window when crypto stream has lost data. + EXPECT_CALL(*send_algorithm, CanSend(_)).Times(0); + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + EXPECT_CALL(*crypto_stream, OnCanWrite()); + EXPECT_CALL(*crypto_stream, HasPendingRetransmission()) + .WillOnce(Return(false)); + } + // Check congestion window for non crypto streams. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream4, OnCanWrite()); + EXPECT_CALL(*stream4, HasPendingRetransmission()).WillOnce(Return(false)); + // Connection is blocked. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillRepeatedly(Return(false)); + + session_.OnCanWrite(); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + // Unblock connection. + // Stream 2 retransmits lost data. + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(false)); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + // Stream 2 sends new data. + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*send_algorithm, CanSend(_)).WillOnce(Return(true)); + EXPECT_CALL(*stream4, OnCanWrite()); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + + session_.OnCanWrite(); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +TEST_P(QuicSessionTestServer, DonotRetransmitDataOfClosedStreams) { + CompleteHandshake(); + InSequence s; + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + + QuicStreamFrame frame1(stream2->id(), false, 0, 9); + QuicStreamFrame frame2(stream4->id(), false, 0, 9); + QuicStreamFrame frame3(stream6->id(), false, 0, 9); + + EXPECT_CALL(*stream6, HasPendingRetransmission()).WillOnce(Return(true)); + EXPECT_CALL(*stream4, HasPendingRetransmission()).WillOnce(Return(true)); + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(true)); + session_.OnFrameLost(QuicFrame(frame3)); + session_.OnFrameLost(QuicFrame(frame2)); + session_.OnFrameLost(QuicFrame(frame1)); + + session_.MarkConnectionLevelWriteBlocked(stream2->id()); + session_.MarkConnectionLevelWriteBlocked(stream4->id()); + session_.MarkConnectionLevelWriteBlocked(stream6->id()); + + // Reset stream 4 locally. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream4->id(), _)); + stream4->Reset(QUIC_STREAM_CANCELLED); + + // Verify stream 4 is removed from streams with lost data list. + EXPECT_CALL(*stream6, OnCanWrite()); + EXPECT_CALL(*stream6, HasPendingRetransmission()).WillOnce(Return(false)); + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*stream2, HasPendingRetransmission()).WillOnce(Return(false)); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*stream2, OnCanWrite()); + EXPECT_CALL(*stream6, OnCanWrite()); + session_.OnCanWrite(); +} + +TEST_P(QuicSessionTestServer, RetransmitFrames) { + CompleteHandshake(); + MockSendAlgorithm* send_algorithm = new StrictMock; + QuicConnectionPeer::SetSendAlgorithm(session_.connection(), send_algorithm); + InSequence s; + + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream4 = session_.CreateOutgoingBidirectionalStream(); + TestStream* stream6 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + session_.SendWindowUpdate(stream2->id(), 9); + + QuicStreamFrame frame1(stream2->id(), false, 0, 9); + QuicStreamFrame frame2(stream4->id(), false, 0, 9); + QuicStreamFrame frame3(stream6->id(), false, 0, 9); + QuicWindowUpdateFrame window_update(1, stream2->id(), 9); + QuicFrames frames; + frames.push_back(QuicFrame(frame1)); + frames.push_back(QuicFrame(window_update)); + frames.push_back(QuicFrame(frame2)); + frames.push_back(QuicFrame(frame3)); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); + + EXPECT_CALL(*stream2, RetransmitStreamData(_, _, _, _)) + .WillOnce(Return(true)); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*stream4, RetransmitStreamData(_, _, _, _)) + .WillOnce(Return(true)); + EXPECT_CALL(*stream6, RetransmitStreamData(_, _, _, _)) + .WillOnce(Return(true)); + EXPECT_CALL(*send_algorithm, OnApplicationLimited(_)); + session_.RetransmitFrames(frames, PTO_RETRANSMISSION); +} + +// Regression test of b/110082001. +TEST_P(QuicSessionTestServer, RetransmitLostDataCausesConnectionClose) { + CompleteHandshake(); + // This test mimics the scenario when a dynamic stream retransmits lost data + // and causes connection close. + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamFrame frame(stream->id(), false, 0, 9); + + EXPECT_CALL(*stream, HasPendingRetransmission()) + .Times(2) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + session_.OnFrameLost(QuicFrame(frame)); + // Retransmit stream data causes connection close. Stream has not sent fin + // yet, so an RST is sent. + EXPECT_CALL(*stream, OnCanWrite()).WillOnce(Invoke([this, stream]() { + session_.ResetStream(stream->id(), QUIC_STREAM_CANCELLED); + })); + if (VersionHasIetfQuicFrames(transport_version())) { + // Once for the RST_STREAM, once for the STOP_SENDING + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(2) + .WillRepeatedly(Invoke(&session_, &TestSession::SaveFrame)); + } else { + // Just for the RST_STREAM + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(Invoke(&session_, &TestSession::SaveFrame)); + } + EXPECT_CALL(*connection_, OnStreamReset(stream->id(), _)); + session_.OnCanWrite(); +} + +TEST_P(QuicSessionTestServer, SendMessage) { + // Cannot send message when encryption is not established. + EXPECT_FALSE(session_.OneRttKeysAvailable()); + EXPECT_EQ(MessageResult(MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED, 0), + session_.SendMessage(MemSliceFromString(""))); + + CompleteHandshake(); + EXPECT_TRUE(session_.OneRttKeysAvailable()); + + EXPECT_CALL(*connection_, SendMessage(1, _, false)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, 1), + session_.SendMessage(MemSliceFromString(""))); + // Verify message_id increases. + EXPECT_CALL(*connection_, SendMessage(2, _, false)) + .WillOnce(Return(MESSAGE_STATUS_TOO_LARGE)); + EXPECT_EQ(MessageResult(MESSAGE_STATUS_TOO_LARGE, 0), + session_.SendMessage(MemSliceFromString(""))); + // Verify unsent message does not consume a message_id. + EXPECT_CALL(*connection_, SendMessage(2, _, false)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, 2), + session_.SendMessage(MemSliceFromString(""))); + + QuicMessageFrame frame(1); + QuicMessageFrame frame2(2); + EXPECT_FALSE(session_.IsFrameOutstanding(QuicFrame(&frame))); + EXPECT_FALSE(session_.IsFrameOutstanding(QuicFrame(&frame2))); + + // Lost message 2. + session_.OnMessageLost(2); + EXPECT_FALSE(session_.IsFrameOutstanding(QuicFrame(&frame2))); + + // message 1 gets acked. + session_.OnMessageAcked(1, QuicTime::Zero()); + EXPECT_FALSE(session_.IsFrameOutstanding(QuicFrame(&frame))); +} + +// Regression test of b/115323618. +TEST_P(QuicSessionTestServer, LocallyResetZombieStreams) { + CompleteHandshake(); + session_.set_writev_consumes_all_data(true); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + std::string body(100, '.'); + QuicStreamPeer::CloseReadSide(stream2); + stream2->WriteOrBufferData(body, true, nullptr); + EXPECT_TRUE(stream2->IsWaitingForAcks()); + // Verify stream2 is a zombie streams. + auto& stream_map = QuicSessionPeer::stream_map(&session_); + ASSERT_TRUE(stream_map.contains(stream2->id())); + auto* stream = stream_map.find(stream2->id())->second.get(); + EXPECT_TRUE(stream->IsZombie()); + + QuicStreamFrame frame(stream2->id(), true, 0, 100); + EXPECT_CALL(*stream2, HasPendingRetransmission()) + .WillRepeatedly(Return(true)); + session_.OnFrameLost(QuicFrame(frame)); + + // Reset stream2 locally. + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(stream2->id(), _)); + stream2->Reset(QUIC_STREAM_CANCELLED); + + // Verify stream 2 gets closed. + EXPECT_TRUE(session_.IsClosedStream(stream2->id())); + EXPECT_CALL(*stream2, OnCanWrite()).Times(0); + session_.OnCanWrite(); +} + +TEST_P(QuicSessionTestServer, CleanUpClosedStreamsAlarm) { + CompleteHandshake(); + EXPECT_FALSE( + QuicSessionPeer::GetCleanUpClosedStreamsAlarm(&session_)->IsSet()); + + session_.set_writev_consumes_all_data(true); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_FALSE(stream2->IsWaitingForAcks()); + + CloseStream(stream2->id()); + EXPECT_EQ(1u, session_.closed_streams()->size()); + EXPECT_TRUE( + QuicSessionPeer::GetCleanUpClosedStreamsAlarm(&session_)->IsSet()); + + alarm_factory_.FireAlarm( + QuicSessionPeer::GetCleanUpClosedStreamsAlarm(&session_)); + EXPECT_TRUE(session_.closed_streams()->empty()); +} + +TEST_P(QuicSessionTestServer, WriteUnidirectionalStream) { + session_.set_writev_consumes_all_data(true); + TestStream* stream4 = new TestStream(GetNthServerInitiatedUnidirectionalId(1), + &session_, WRITE_UNIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream4)); + std::string body(100, '.'); + stream4->WriteOrBufferData(body, false, nullptr); + stream4->WriteOrBufferData(body, true, nullptr); + auto& stream_map = QuicSessionPeer::stream_map(&session_); + ASSERT_TRUE(stream_map.contains(stream4->id())); + auto* stream = stream_map.find(stream4->id())->second.get(); + EXPECT_TRUE(stream->IsZombie()); +} + +TEST_P(QuicSessionTestServer, ReceivedDataOnWriteUnidirectionalStream) { + TestStream* stream4 = new TestStream(GetNthServerInitiatedUnidirectionalId(1), + &session_, WRITE_UNIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream4)); + + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM, _, _)) + .Times(1); + QuicStreamFrame stream_frame(GetNthServerInitiatedUnidirectionalId(1), false, + 0, 2); + session_.OnStreamFrame(stream_frame); +} + +TEST_P(QuicSessionTestServer, ReadUnidirectionalStream) { + TestStream* stream4 = new TestStream(GetNthClientInitiatedUnidirectionalId(1), + &session_, READ_UNIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream4)); + EXPECT_FALSE(stream4->IsWaitingForAcks()); + // Discard all incoming data. + stream4->StopReading(); + + std::string data(100, '.'); + QuicStreamFrame stream_frame(GetNthClientInitiatedUnidirectionalId(1), false, + 0, data); + stream4->OnStreamFrame(stream_frame); + EXPECT_TRUE(session_.closed_streams()->empty()); + + QuicStreamFrame stream_frame2(GetNthClientInitiatedUnidirectionalId(1), true, + 100, data); + stream4->OnStreamFrame(stream_frame2); + EXPECT_EQ(1u, session_.closed_streams()->size()); +} + +TEST_P(QuicSessionTestServer, WriteOrBufferDataOnReadUnidirectionalStream) { + TestStream* stream4 = new TestStream(GetNthClientInitiatedUnidirectionalId(1), + &session_, READ_UNIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream4)); + + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM, _, _)) + .Times(1); + std::string body(100, '.'); + stream4->WriteOrBufferData(body, false, nullptr); +} + +TEST_P(QuicSessionTestServer, WritevDataOnReadUnidirectionalStream) { + TestStream* stream4 = new TestStream(GetNthClientInitiatedUnidirectionalId(1), + &session_, READ_UNIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream4)); + + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM, _, _)) + .Times(1); + std::string body(100, '.'); + struct iovec iov = {const_cast(body.data()), body.length()}; + quiche::QuicheMemSliceStorage storage( + &iov, 1, session_.connection()->helper()->GetStreamSendBufferAllocator(), + 1024); + stream4->WriteMemSlices(storage.ToSpan(), false); +} + +TEST_P(QuicSessionTestServer, WriteMemSlicesOnReadUnidirectionalStream) { + TestStream* stream4 = new TestStream(GetNthClientInitiatedUnidirectionalId(1), + &session_, READ_UNIDIRECTIONAL); + session_.ActivateStream(absl::WrapUnique(stream4)); + + EXPECT_CALL(*connection_, + CloseConnection( + QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM, _, _)) + .Times(1); + std::string data(1024, 'a'); + std::vector buffers; + buffers.push_back(MemSliceFromString(data)); + buffers.push_back(MemSliceFromString(data)); + stream4->WriteMemSlices(absl::MakeSpan(buffers), false); +} + +// Test code that tests that an incoming stream frame with a new (not previously +// seen) stream id is acceptable. The ID must not be larger than has been +// advertised. It may be equal to what has been advertised. These tests +// invoke QuicStreamIdManager::MaybeIncreaseLargestPeerStreamId by calling +// QuicSession::OnStreamFrame in order to check that all the steps are connected +// properly and that nothing in the call path interferes with the check. +// First test make sure that streams with ids below the limit are accepted. +TEST_P(QuicSessionTestServer, NewStreamIdBelowLimit) { + if (!VersionHasIetfQuicFrames(transport_version())) { + // Applicable only to IETF QUIC + return; + } + QuicStreamId bidirectional_stream_id = StreamCountToId( + QuicSessionPeer::ietf_streamid_manager(&session_) + ->advertised_max_incoming_bidirectional_streams() - + 1, + Perspective::IS_CLIENT, + /*bidirectional=*/true); + + QuicStreamFrame bidirectional_stream_frame(bidirectional_stream_id, false, 0, + "Random String"); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStreamFrame(bidirectional_stream_frame); + + QuicStreamId unidirectional_stream_id = StreamCountToId( + QuicSessionPeer::ietf_streamid_manager(&session_) + ->advertised_max_incoming_unidirectional_streams() - + 1, + Perspective::IS_CLIENT, + /*bidirectional=*/false); + QuicStreamFrame unidirectional_stream_frame(unidirectional_stream_id, false, + 0, "Random String"); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStreamFrame(unidirectional_stream_frame); +} + +// Accept a stream with an ID that equals the limit. +TEST_P(QuicSessionTestServer, NewStreamIdAtLimit) { + if (!VersionHasIetfQuicFrames(transport_version())) { + // Applicable only to IETF QUIC + return; + } + QuicStreamId bidirectional_stream_id = + StreamCountToId(QuicSessionPeer::ietf_streamid_manager(&session_) + ->advertised_max_incoming_bidirectional_streams(), + Perspective::IS_CLIENT, /*bidirectional=*/true); + QuicStreamFrame bidirectional_stream_frame(bidirectional_stream_id, false, 0, + "Random String"); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStreamFrame(bidirectional_stream_frame); + + QuicStreamId unidirectional_stream_id = + StreamCountToId(QuicSessionPeer::ietf_streamid_manager(&session_) + ->advertised_max_incoming_unidirectional_streams(), + Perspective::IS_CLIENT, /*bidirectional=*/false); + QuicStreamFrame unidirectional_stream_frame(unidirectional_stream_id, false, + 0, "Random String"); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStreamFrame(unidirectional_stream_frame); +} + +// Close the connection if the id exceeds the limit. +TEST_P(QuicSessionTestServer, NewStreamIdAboveLimit) { + if (!VersionHasIetfQuicFrames(transport_version())) { + // Applicable only to IETF QUIC + return; + } + + QuicStreamId bidirectional_stream_id = StreamCountToId( + QuicSessionPeer::ietf_streamid_manager(&session_) + ->advertised_max_incoming_bidirectional_streams() + + 1, + Perspective::IS_CLIENT, /*bidirectional=*/true); + QuicStreamFrame bidirectional_stream_frame(bidirectional_stream_id, false, 0, + "Random String"); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Stream id 400 would exceed stream count limit 100", _)); + session_.OnStreamFrame(bidirectional_stream_frame); + + QuicStreamId unidirectional_stream_id = StreamCountToId( + QuicSessionPeer::ietf_streamid_manager(&session_) + ->advertised_max_incoming_unidirectional_streams() + + 1, + Perspective::IS_CLIENT, /*bidirectional=*/false); + QuicStreamFrame unidirectional_stream_frame(unidirectional_stream_id, false, + 0, "Random String"); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Stream id 402 would exceed stream count limit 100", _)); + session_.OnStreamFrame(unidirectional_stream_frame); +} + +// Checks that invalid stream ids are handled. +TEST_P(QuicSessionTestServer, OnStopSendingInvalidStreamId) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + // Check that "invalid" stream ids are rejected. + QuicStopSendingFrame frame(1, -1, QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received STOP_SENDING for an invalid stream", _)); + session_.OnStopSendingFrame(frame); +} + +TEST_P(QuicSessionTestServer, OnStopSendingReadUnidirectional) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + // It's illegal to send STOP_SENDING with a stream ID that is read-only. + QuicStopSendingFrame frame(1, GetNthClientInitiatedUnidirectionalId(1), + QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received STOP_SENDING for a read-only stream", _)); + session_.OnStopSendingFrame(frame); +} + +// Static streams ignore STOP_SENDING. +TEST_P(QuicSessionTestServer, OnStopSendingStaticStreams) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + QuicStreamId stream_id = 0; + std::unique_ptr fake_static_stream = std::make_unique( + stream_id, &session_, /*is_static*/ true, BIDIRECTIONAL); + QuicSessionPeer::ActivateStream(&session_, std::move(fake_static_stream)); + // Check that a stream id in the static stream map is ignored. + QuicStopSendingFrame frame(1, stream_id, QUIC_STREAM_CANCELLED); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received STOP_SENDING for a static stream", _)); + session_.OnStopSendingFrame(frame); +} + +// If stream is write closed, do not send a RESET_STREAM frame. +TEST_P(QuicSessionTestServer, OnStopSendingForWriteClosedStream) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId stream_id = stream->id(); + QuicStreamPeer::SetFinSent(stream); + stream->CloseWriteSide(); + EXPECT_TRUE(stream->write_side_closed()); + QuicStopSendingFrame frame(1, stream_id, QUIC_STREAM_CANCELLED); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStopSendingFrame(frame); +} + +// If stream is closed, return true and do not close the connection. +TEST_P(QuicSessionTestServer, OnStopSendingClosedStream) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + CompleteHandshake(); + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamId stream_id = stream->id(); + CloseStream(stream_id); + QuicStopSendingFrame frame(1, stream_id, QUIC_STREAM_CANCELLED); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStopSendingFrame(frame); +} + +// If stream id is a nonexistent local stream, return false and close the +// connection. +TEST_P(QuicSessionTestServer, OnStopSendingInputNonExistentLocalStream) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + + QuicStopSendingFrame frame(1, GetNthServerInitiatedBidirectionalId(123456), + QUIC_STREAM_CANCELLED); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_STREAM_WRONG_DIRECTION, + "Data for nonexistent stream", _)) + .Times(1); + session_.OnStopSendingFrame(frame); +} + +// If a STOP_SENDING is received for a peer initiated stream, the new stream +// will be created. +TEST_P(QuicSessionTestServer, OnStopSendingNewStream) { + CompleteHandshake(); + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + QuicStopSendingFrame frame(1, GetNthClientInitiatedBidirectionalId(1), + QUIC_STREAM_CANCELLED); + + // A Rst will be sent as a response for STOP_SENDING. + EXPECT_CALL(*connection_, SendControlFrame(_)).Times(1); + EXPECT_CALL(*connection_, OnStreamReset(_, _)).Times(1); + session_.OnStopSendingFrame(frame); + + QuicStream* stream = + session_.GetOrCreateStream(GetNthClientInitiatedBidirectionalId(1)); + EXPECT_TRUE(stream); + EXPECT_TRUE(stream->write_side_closed()); +} + +// For a valid stream, ensure that all works +TEST_P(QuicSessionTestServer, OnStopSendingInputValidStream) { + CompleteHandshake(); + if (!VersionHasIetfQuicFrames(transport_version())) { + // Applicable only to IETF QUIC + return; + } + + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + + // Ensure that the stream starts out open in both directions. + EXPECT_FALSE(stream->write_side_closed()); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream)); + + QuicStreamId stream_id = stream->id(); + QuicStopSendingFrame frame(1, stream_id, QUIC_STREAM_CANCELLED); + // Expect a reset to come back out. + EXPECT_CALL(*connection_, SendControlFrame(_)); + EXPECT_CALL(*connection_, OnStreamReset(stream_id, QUIC_STREAM_CANCELLED)); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + session_.OnStopSendingFrame(frame); + + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream)); + EXPECT_TRUE(stream->write_side_closed()); +} + +TEST_P(QuicSessionTestServer, WriteBufferedCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + std::string data(1350, 'a'); + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + // Only consumed 1000 bytes. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Return(1000)); + crypto_stream->WriteCryptoData(ENCRYPTION_INITIAL, data); + EXPECT_TRUE(session_.HasPendingHandshake()); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)).Times(0); + connection_->SetEncrypter( + ENCRYPTION_ZERO_RTT, + std::make_unique(connection_->perspective())); + crypto_stream->WriteCryptoData(ENCRYPTION_ZERO_RTT, data); + + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 350, 1000)) + .WillOnce(Return(350)); + EXPECT_CALL( + *connection_, + SendCryptoData(crypto_stream->GetEncryptionLevelToSendCryptoDataOfSpace( + QuicUtils::GetPacketNumberSpace(ENCRYPTION_ZERO_RTT)), + 1350, 0)) + .WillOnce(Return(1350)); + session_.OnCanWrite(); + EXPECT_FALSE(session_.HasPendingHandshake()); + EXPECT_FALSE(session_.WillingAndAbleToWrite()); +} + +// Regression test for +// https://bugs.chromium.org/p/chromium/issues/detail?id=1002119 +TEST_P(QuicSessionTestServer, StreamFrameReceivedAfterFin) { + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + QuicStreamFrame frame(stream->id(), true, 0, ","); + session_.OnStreamFrame(frame); + + QuicStreamFrame frame1(stream->id(), false, 1, ","); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_STREAM_DATA_BEYOND_CLOSE_OFFSET, _, _)); + session_.OnStreamFrame(frame1); +} + +TEST_P(QuicSessionTestServer, ResetForIETFStreamTypes) { + CompleteHandshake(); + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + + QuicStreamId read_only = GetNthClientInitiatedUnidirectionalId(0); + + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(1) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(read_only, _)); + session_.ResetStream(read_only, QUIC_STREAM_CANCELLED); + + QuicStreamId write_only = GetNthServerInitiatedUnidirectionalId(0); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(1) + .WillOnce(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(write_only, _)); + session_.ResetStream(write_only, QUIC_STREAM_CANCELLED); + + QuicStreamId bidirectional = GetNthClientInitiatedBidirectionalId(0); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .Times(2) + .WillRepeatedly(Invoke(&ClearControlFrame)); + EXPECT_CALL(*connection_, OnStreamReset(bidirectional, _)); + session_.ResetStream(bidirectional, QUIC_STREAM_CANCELLED); +} + +TEST_P(QuicSessionTestServer, DecryptionKeyAvailableBeforeEncryptionKey) { + if (connection_->version().handshake_protocol != PROTOCOL_TLS1_3) { + return; + } + ASSERT_FALSE(connection_->framer().HasEncrypterOfEncryptionLevel( + ENCRYPTION_HANDSHAKE)); + EXPECT_FALSE(session_.OnNewDecryptionKeyAvailable( + ENCRYPTION_HANDSHAKE, /*decrypter=*/nullptr, + /*set_alternative_decrypter=*/false, /*latch_once_used=*/false)); +} + +TEST_P(QuicSessionTestServer, IncomingStreamWithServerInitiatedStreamId) { + const QuicErrorCode expected_error = + VersionHasIetfQuicFrames(transport_version()) + ? QUIC_HTTP_STREAM_WRONG_DIRECTION + : QUIC_INVALID_STREAM_ID; + EXPECT_CALL( + *connection_, + CloseConnection(expected_error, "Data for nonexistent stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); + + QuicStreamFrame frame(GetNthServerInitiatedBidirectionalId(1), + /* fin = */ false, /* offset = */ 0, + absl::string_view("foo")); + session_.OnStreamFrame(frame); +} + +// Regression test for b/235204908. +TEST_P(QuicSessionTestServer, BlockedFrameCausesWriteError) { + CompleteHandshake(); + MockPacketWriter* writer = static_cast( + QuicConnectionPeer::GetWriter(session_.connection())); + EXPECT_CALL(*writer, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + // Set a small connection level flow control limit. + const uint64_t kWindow = 36; + QuicFlowControllerPeer::SetSendWindowOffset(session_.flow_controller(), + kWindow); + auto stream = + session_.GetOrCreateStream(GetNthClientInitiatedBidirectionalId(0)); + // Try to send more data than the flow control limit allows. + const uint64_t kOverflow = 15; + std::string body(kWindow + kOverflow, 'a'); + EXPECT_CALL(*connection_, SendControlFrame(_)) + .WillOnce(testing::InvokeWithoutArgs([this]() { + connection_->ReallyCloseConnection( + QUIC_PACKET_WRITE_ERROR, "write error", + ConnectionCloseBehavior::SILENT_CLOSE); + return false; + })); + stream->WriteOrBufferData(body, false, nullptr); +} + +TEST_P(QuicSessionTestServer, BufferedCryptoFrameCausesWriteError) { + if (!VersionHasIetfQuicFrames(transport_version())) { + return; + } + std::string data(1350, 'a'); + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + // Only consumed 1000 bytes. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_FORWARD_SECURE, 1350, 0)) + .WillOnce(Return(1000)); + crypto_stream->WriteCryptoData(ENCRYPTION_FORWARD_SECURE, data); + EXPECT_TRUE(session_.HasPendingHandshake()); + EXPECT_TRUE(session_.WillingAndAbleToWrite()); + + EXPECT_CALL(*connection_, + SendCryptoData(ENCRYPTION_FORWARD_SECURE, 350, 1000)) + .WillOnce(Return(0)); + // Buffer the HANDSHAKE_DONE frame. + EXPECT_CALL(*connection_, SendControlFrame(_)).WillOnce(Return(false)); + CryptoHandshakeMessage msg; + session_.GetMutableCryptoStream()->OnHandshakeMessage(msg); + + // Flush both frames. + EXPECT_CALL(*connection_, + SendCryptoData(ENCRYPTION_FORWARD_SECURE, 350, 1000)) + .WillOnce(testing::InvokeWithoutArgs([this]() { + connection_->ReallyCloseConnection( + QUIC_PACKET_WRITE_ERROR, "write error", + ConnectionCloseBehavior::SILENT_CLOSE); + return 350; + })); + if (!GetQuicReloadableFlag( + quic_no_write_control_frame_upon_connection_close)) { + EXPECT_CALL(*connection_, SendControlFrame(_)).WillOnce(Return(false)); + EXPECT_QUIC_BUG(session_.OnCanWrite(), "Try to write control frame"); + } else { + session_.OnCanWrite(); + } +} + +TEST_P(QuicSessionTestServer, DonotPtoStreamDataBeforeHandshakeConfirmed) { + if (!session_.version().UsesTls()) { + return; + } + EXPECT_NE(HANDSHAKE_CONFIRMED, session_.GetHandshakeState()); + + TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); + EXPECT_FALSE(crypto_stream->HasBufferedCryptoFrames()); + std::string data(1350, 'a'); + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, 1350, 0)) + .WillOnce(Return(1000)); + crypto_stream->WriteCryptoData(ENCRYPTION_INITIAL, data); + ASSERT_TRUE(crypto_stream->HasBufferedCryptoFrames()); + + TestStream* stream = session_.CreateOutgoingBidirectionalStream(); + + session_.MarkConnectionLevelWriteBlocked(stream->id()); + // Buffered crypto data gets sent. + EXPECT_CALL(*connection_, SendCryptoData(ENCRYPTION_INITIAL, _, _)) + .WillOnce(Return(350)); + // Verify stream data is not sent on PTO before handshake confirmed. + EXPECT_CALL(*stream, OnCanWrite()).Times(0); + + // Fire PTO. + QuicConnectionPeer::SetInProbeTimeOut(connection_, true); + session_.OnCanWrite(); + EXPECT_FALSE(crypto_stream->HasBufferedCryptoFrames()); +} + +TEST_P(QuicSessionTestServer, SetStatelessResetTokenToSend) { + if (!session_.version().HasIetfQuicFrames()) { + return; + } + EXPECT_TRUE(session_.config()->HasStatelessResetTokenToSend()); +} + +TEST_P(QuicSessionTestServer, + SetServerPreferredAddressAccordingToAddressFamily) { + if (!session_.version().HasIetfQuicFrames()) { + return; + } + EXPECT_EQ(quiche::IpAddressFamily::IP_V4, + connection_->peer_address().host().address_family()); + QuicConnectionPeer::SetEffectivePeerAddress(connection_, + connection_->peer_address()); + QuicTagVector copt; + copt.push_back(kSPAD); + QuicConfigPeer::SetReceivedConnectionOptions(session_.config(), copt); + QuicSocketAddress preferred_address(QuicIpAddress::Loopback4(), 12345); + session_.config()->SetIPv4AlternateServerAddressToSend(preferred_address); + session_.config()->SetIPv6AlternateServerAddressToSend( + QuicSocketAddress(QuicIpAddress::Loopback6(), 12345)); + + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); + EXPECT_EQ(QuicSocketAddress(QuicIpAddress::Loopback4(), 12345), + session_.config() + ->GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V4) + .value()); + EXPECT_FALSE(session_.config() + ->GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V6) + .has_value()); + EXPECT_EQ(preferred_address, + QuicConnectionPeer::GetSentServerPreferredAddress(connection_)); +} + +TEST_P(QuicSessionTestServer, NoServerPreferredAddressIfAddressFamilyMismatch) { + if (!session_.version().HasIetfQuicFrames()) { + return; + } + EXPECT_EQ(quiche::IpAddressFamily::IP_V4, + connection_->peer_address().host().address_family()); + QuicConnectionPeer::SetEffectivePeerAddress(connection_, + connection_->peer_address()); + QuicTagVector copt; + copt.push_back(kSPAD); + QuicConfigPeer::SetReceivedConnectionOptions(session_.config(), copt); + session_.config()->SetIPv6AlternateServerAddressToSend( + QuicSocketAddress(QuicIpAddress::Loopback6(), 12345)); + + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + session_.OnConfigNegotiated(); + EXPECT_FALSE(session_.config() + ->GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V4) + .has_value()); + EXPECT_FALSE(session_.config() + ->GetPreferredAddressToSend(quiche::IpAddressFamily::IP_V6) + .has_value()); + EXPECT_FALSE(QuicConnectionPeer::GetSentServerPreferredAddress(connection_) + .IsInitialized()); +} + +// A client test class that can be used when the automatic configuration is not +// desired. +class QuicSessionTestClientUnconfigured : public QuicSessionTestBase { + protected: + QuicSessionTestClientUnconfigured() + : QuicSessionTestBase(Perspective::IS_CLIENT, + /*configure_session=*/false) {} +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSessionTestClientUnconfigured, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSessionTestClientUnconfigured, StreamInitiallyBlockedThenUnblocked) { + if (!connection_->version().AllowsLowFlowControlLimits()) { + return; + } + // Create a stream before negotiating the config and verify it starts off + // blocked. + QuicSessionPeer::SetMaxOpenOutgoingBidirectionalStreams(&session_, 10); + TestStream* stream2 = session_.CreateOutgoingBidirectionalStream(); + EXPECT_TRUE(stream2->IsFlowControlBlocked()); + EXPECT_TRUE(session_.IsConnectionFlowControlBlocked()); + EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); + + // Negotiate the config with higher received limits. + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_.config(), kMinimumFlowControlSendWindow); + session_.OnConfigNegotiated(); + + // Stream is now unblocked. + EXPECT_FALSE(stream2->IsFlowControlBlocked()); + EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); + EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_socket_address_coder.cc b/quiche/quic/core/quic_socket_address_coder.cc new file mode 100644 index 000000000000..9bf85b2ea9e7 --- /dev/null +++ b/quiche/quic/core/quic_socket_address_coder.cc @@ -0,0 +1,92 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_socket_address_coder.h" + +#include +#include +#include + +#include "quiche/quic/platform/api/quic_ip_address_family.h" + +namespace quic { + +namespace { + +// For convenience, the values of these constants match the values of AF_INET +// and AF_INET6 on Linux. +const uint16_t kIPv4 = 2; +const uint16_t kIPv6 = 10; + +} // namespace + +QuicSocketAddressCoder::QuicSocketAddressCoder() {} + +QuicSocketAddressCoder::QuicSocketAddressCoder(const QuicSocketAddress& address) + : address_(address) {} + +QuicSocketAddressCoder::~QuicSocketAddressCoder() {} + +std::string QuicSocketAddressCoder::Encode() const { + std::string serialized; + uint16_t address_family; + switch (address_.host().address_family()) { + case IpAddressFamily::IP_V4: + address_family = kIPv4; + break; + case IpAddressFamily::IP_V6: + address_family = kIPv6; + break; + default: + return serialized; + } + serialized.append(reinterpret_cast(&address_family), + sizeof(address_family)); + serialized.append(address_.host().ToPackedString()); + uint16_t port = address_.port(); + serialized.append(reinterpret_cast(&port), sizeof(port)); + return serialized; +} + +bool QuicSocketAddressCoder::Decode(const char* data, size_t length) { + uint16_t address_family; + if (length < sizeof(address_family)) { + return false; + } + memcpy(&address_family, data, sizeof(address_family)); + data += sizeof(address_family); + length -= sizeof(address_family); + + size_t ip_length; + switch (address_family) { + case kIPv4: + ip_length = QuicIpAddress::kIPv4AddressSize; + break; + case kIPv6: + ip_length = QuicIpAddress::kIPv6AddressSize; + break; + default: + return false; + } + if (length < ip_length) { + return false; + } + std::vector ip(ip_length); + memcpy(&ip[0], data, ip_length); + data += ip_length; + length -= ip_length; + + uint16_t port; + if (length != sizeof(port)) { + return false; + } + memcpy(&port, data, length); + + QuicIpAddress ip_address; + ip_address.FromPackedString(reinterpret_cast(&ip[0]), ip_length); + address_ = QuicSocketAddress(ip_address, port); + return true; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_socket_address_coder.h b/quiche/quic/core/quic_socket_address_coder.h new file mode 100644 index 000000000000..b56ee6a2d565 --- /dev/null +++ b/quiche/quic/core/quic_socket_address_coder.h @@ -0,0 +1,42 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_SOCKET_ADDRESS_CODER_H_ +#define QUICHE_QUIC_CORE_QUIC_SOCKET_ADDRESS_CODER_H_ + +#include +#include + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// Serializes and parses a socket address (IP address and port), to be used in +// the kCADR tag in the ServerHello handshake message and the Public Reset +// packet. +class QUIC_EXPORT_PRIVATE QuicSocketAddressCoder { + public: + QuicSocketAddressCoder(); + explicit QuicSocketAddressCoder(const QuicSocketAddress& address); + QuicSocketAddressCoder(const QuicSocketAddressCoder&) = delete; + QuicSocketAddressCoder& operator=(const QuicSocketAddressCoder&) = delete; + ~QuicSocketAddressCoder(); + + std::string Encode() const; + + bool Decode(const char* data, size_t length); + + QuicIpAddress ip() const { return address_.host(); } + + uint16_t port() const { return address_.port(); } + + private: + QuicSocketAddress address_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_SOCKET_ADDRESS_CODER_H_ diff --git a/quiche/quic/core/quic_socket_address_coder_test.cc b/quiche/quic/core/quic_socket_address_coder_test.cc new file mode 100644 index 000000000000..32f3570c450b --- /dev/null +++ b/quiche/quic/core/quic_socket_address_coder_test.cc @@ -0,0 +1,130 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_socket_address_coder.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class QuicSocketAddressCoderTest : public QuicTest {}; + +TEST_F(QuicSocketAddressCoderTest, EncodeIPv4) { + QuicIpAddress ip; + ip.FromString("4.31.198.44"); + QuicSocketAddressCoder coder(QuicSocketAddress(ip, 0x1234)); + std::string serialized = coder.Encode(); + std::string expected("\x02\x00\x04\x1f\xc6\x2c\x34\x12", 8); + EXPECT_EQ(expected, serialized); +} + +TEST_F(QuicSocketAddressCoderTest, EncodeIPv6) { + QuicIpAddress ip; + ip.FromString("2001:700:300:1800::f"); + QuicSocketAddressCoder coder(QuicSocketAddress(ip, 0x5678)); + std::string serialized = coder.Encode(); + std::string expected( + "\x0a\x00" + "\x20\x01\x07\x00\x03\x00\x18\x00" + "\x00\x00\x00\x00\x00\x00\x00\x0f" + "\x78\x56", + 20); + EXPECT_EQ(expected, serialized); +} + +TEST_F(QuicSocketAddressCoderTest, DecodeIPv4) { + std::string serialized("\x02\x00\x04\x1f\xc6\x2c\x34\x12", 8); + QuicSocketAddressCoder coder; + ASSERT_TRUE(coder.Decode(serialized.data(), serialized.length())); + EXPECT_EQ(IpAddressFamily::IP_V4, coder.ip().address_family()); + std::string expected_addr("\x04\x1f\xc6\x2c"); + EXPECT_EQ(expected_addr, coder.ip().ToPackedString()); + EXPECT_EQ(0x1234, coder.port()); +} + +TEST_F(QuicSocketAddressCoderTest, DecodeIPv6) { + std::string serialized( + "\x0a\x00" + "\x20\x01\x07\x00\x03\x00\x18\x00" + "\x00\x00\x00\x00\x00\x00\x00\x0f" + "\x78\x56", + 20); + QuicSocketAddressCoder coder; + ASSERT_TRUE(coder.Decode(serialized.data(), serialized.length())); + EXPECT_EQ(IpAddressFamily::IP_V6, coder.ip().address_family()); + std::string expected_addr( + "\x20\x01\x07\x00\x03\x00\x18\x00" + "\x00\x00\x00\x00\x00\x00\x00\x0f", + 16); + EXPECT_EQ(expected_addr, coder.ip().ToPackedString()); + EXPECT_EQ(0x5678, coder.port()); +} + +TEST_F(QuicSocketAddressCoderTest, DecodeBad) { + std::string serialized( + "\x0a\x00" + "\x20\x01\x07\x00\x03\x00\x18\x00" + "\x00\x00\x00\x00\x00\x00\x00\x0f" + "\x78\x56", + 20); + QuicSocketAddressCoder coder; + EXPECT_TRUE(coder.Decode(serialized.data(), serialized.length())); + // Append junk. + serialized.push_back('\0'); + EXPECT_FALSE(coder.Decode(serialized.data(), serialized.length())); + // Undo. + serialized.resize(20); + EXPECT_TRUE(coder.Decode(serialized.data(), serialized.length())); + + // Set an unknown address family. + serialized[0] = '\x03'; + EXPECT_FALSE(coder.Decode(serialized.data(), serialized.length())); + // Undo. + serialized[0] = '\x0a'; + EXPECT_TRUE(coder.Decode(serialized.data(), serialized.length())); + + // Truncate. + size_t len = serialized.length(); + for (size_t i = 0; i < len; i++) { + ASSERT_FALSE(serialized.empty()); + serialized.erase(serialized.length() - 1); + EXPECT_FALSE(coder.Decode(serialized.data(), serialized.length())); + } + EXPECT_TRUE(serialized.empty()); +} + +TEST_F(QuicSocketAddressCoderTest, EncodeAndDecode) { + struct { + const char* ip_literal; + uint16_t port; + } test_case[] = { + {"93.184.216.119", 0x1234}, + {"199.204.44.194", 80}, + {"149.20.4.69", 443}, + {"127.0.0.1", 8080}, + {"2001:700:300:1800::", 0x5678}, + {"::1", 65534}, + }; + + for (size_t i = 0; i < ABSL_ARRAYSIZE(test_case); i++) { + QuicIpAddress ip; + ASSERT_TRUE(ip.FromString(test_case[i].ip_literal)); + QuicSocketAddressCoder encoder(QuicSocketAddress(ip, test_case[i].port)); + std::string serialized = encoder.Encode(); + + QuicSocketAddressCoder decoder; + ASSERT_TRUE(decoder.Decode(serialized.data(), serialized.length())); + EXPECT_EQ(encoder.ip(), decoder.ip()); + EXPECT_EQ(encoder.port(), decoder.port()); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_stream.cc b/quiche/quic/core/quic_stream.cc new file mode 100644 index 000000000000..e570b443c69e --- /dev/null +++ b/quiche/quic/core/quic_stream.cc @@ -0,0 +1,1438 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_flow_controller.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +using spdy::SpdyPriority; + +namespace quic { + +#define ENDPOINT \ + (perspective_ == Perspective::IS_SERVER ? "Server: " : "Client: ") + +namespace { + +QuicByteCount DefaultFlowControlWindow(ParsedQuicVersion version) { + if (!version.AllowsLowFlowControlLimits()) { + return kDefaultFlowControlSendWindow; + } + return 0; +} + +QuicByteCount GetInitialStreamFlowControlWindowToSend(QuicSession* session, + QuicStreamId stream_id) { + ParsedQuicVersion version = session->connection()->version(); + if (version.handshake_protocol != PROTOCOL_TLS1_3) { + return session->config()->GetInitialStreamFlowControlWindowToSend(); + } + + // Unidirectional streams (v99 only). + if (VersionHasIetfQuicFrames(version.transport_version) && + !QuicUtils::IsBidirectionalStreamId(stream_id, version)) { + return session->config() + ->GetInitialMaxStreamDataBytesUnidirectionalToSend(); + } + + if (QuicUtils::IsOutgoingStreamId(version, stream_id, + session->perspective())) { + return session->config() + ->GetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(); + } + + return session->config() + ->GetInitialMaxStreamDataBytesIncomingBidirectionalToSend(); +} + +QuicByteCount GetReceivedFlowControlWindow(QuicSession* session, + QuicStreamId stream_id) { + ParsedQuicVersion version = session->connection()->version(); + if (version.handshake_protocol != PROTOCOL_TLS1_3) { + if (session->config()->HasReceivedInitialStreamFlowControlWindowBytes()) { + return session->config()->ReceivedInitialStreamFlowControlWindowBytes(); + } + + return DefaultFlowControlWindow(version); + } + + // Unidirectional streams (v99 only). + if (VersionHasIetfQuicFrames(version.transport_version) && + !QuicUtils::IsBidirectionalStreamId(stream_id, version)) { + if (session->config() + ->HasReceivedInitialMaxStreamDataBytesUnidirectional()) { + return session->config() + ->ReceivedInitialMaxStreamDataBytesUnidirectional(); + } + + return DefaultFlowControlWindow(version); + } + + if (QuicUtils::IsOutgoingStreamId(version, stream_id, + session->perspective())) { + if (session->config() + ->HasReceivedInitialMaxStreamDataBytesOutgoingBidirectional()) { + return session->config() + ->ReceivedInitialMaxStreamDataBytesOutgoingBidirectional(); + } + + return DefaultFlowControlWindow(version); + } + + if (session->config() + ->HasReceivedInitialMaxStreamDataBytesIncomingBidirectional()) { + return session->config() + ->ReceivedInitialMaxStreamDataBytesIncomingBidirectional(); + } + + return DefaultFlowControlWindow(version); +} + +} // namespace + +PendingStream::PendingStream(QuicStreamId id, QuicSession* session) + : id_(id), + version_(session->version()), + stream_delegate_(session), + stream_bytes_read_(0), + fin_received_(false), + is_bidirectional_(QuicUtils::GetStreamType(id, session->perspective(), + /*peer_initiated = */ true, + session->version()) == + BIDIRECTIONAL), + connection_flow_controller_(session->flow_controller()), + flow_controller_(session, id, + /*is_connection_flow_controller*/ false, + GetReceivedFlowControlWindow(session, id), + GetInitialStreamFlowControlWindowToSend(session, id), + kStreamReceiveWindowLimit, + session->flow_controller()->auto_tune_receive_window(), + session->flow_controller()), + sequencer_(this) {} + +void PendingStream::OnDataAvailable() { + // Data should be kept in the sequencer so that + // QuicSession::ProcessPendingStream() can read it. +} + +void PendingStream::OnFinRead() { QUICHE_DCHECK(sequencer_.IsClosed()); } + +void PendingStream::AddBytesConsumed(QuicByteCount bytes) { + // It will be called when the metadata of the stream is consumed. + flow_controller_.AddBytesConsumed(bytes); + connection_flow_controller_->AddBytesConsumed(bytes); +} + +void PendingStream::ResetWithError(QuicResetStreamError /*error*/) { + // Currently PendingStream is only read-unidirectional. It shouldn't send + // Reset. + QUICHE_NOTREACHED(); +} + +void PendingStream::OnUnrecoverableError(QuicErrorCode error, + const std::string& details) { + stream_delegate_->OnStreamError(error, details); +} + +void PendingStream::OnUnrecoverableError(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) { + stream_delegate_->OnStreamError(error, ietf_error, details); +} + +QuicStreamId PendingStream::id() const { return id_; } + +ParsedQuicVersion PendingStream::version() const { return version_; } + +void PendingStream::OnStreamFrame(const QuicStreamFrame& frame) { + QUICHE_DCHECK_EQ(frame.stream_id, id_); + + bool is_stream_too_long = + (frame.offset > kMaxStreamLength) || + (kMaxStreamLength - frame.offset < frame.data_length); + if (is_stream_too_long) { + // Close connection if stream becomes too long. + QUIC_PEER_BUG(quic_peer_bug_12570_1) + << "Receive stream frame reaches max stream length. frame offset " + << frame.offset << " length " << frame.data_length; + OnUnrecoverableError(QUIC_STREAM_LENGTH_OVERFLOW, + "Peer sends more data than allowed on this stream."); + return; + } + + if (frame.offset + frame.data_length > sequencer_.close_offset()) { + OnUnrecoverableError( + QUIC_STREAM_DATA_BEYOND_CLOSE_OFFSET, + absl::StrCat( + "Stream ", id_, + " received data with offset: ", frame.offset + frame.data_length, + ", which is beyond close offset: ", sequencer()->close_offset())); + return; + } + + if (frame.fin) { + fin_received_ = true; + } + + // This count includes duplicate data received. + QuicByteCount frame_payload_size = frame.data_length; + stream_bytes_read_ += frame_payload_size; + + // Flow control is interested in tracking highest received offset. + // Only interested in received frames that carry data. + if (frame_payload_size > 0 && + MaybeIncreaseHighestReceivedOffset(frame.offset + frame_payload_size)) { + // As the highest received offset has changed, check to see if this is a + // violation of flow control. + if (flow_controller_.FlowControlViolation() || + connection_flow_controller_->FlowControlViolation()) { + OnUnrecoverableError(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, + "Flow control violation after increasing offset"); + return; + } + } + + sequencer_.OnStreamFrame(frame); +} + +void PendingStream::OnRstStreamFrame(const QuicRstStreamFrame& frame) { + QUICHE_DCHECK_EQ(frame.stream_id, id_); + + if (frame.byte_offset > kMaxStreamLength) { + // Peer are not suppose to write bytes more than maxium allowed. + OnUnrecoverableError(QUIC_STREAM_LENGTH_OVERFLOW, + "Reset frame stream offset overflow."); + return; + } + + const QuicStreamOffset kMaxOffset = + std::numeric_limits::max(); + if (sequencer()->close_offset() != kMaxOffset && + frame.byte_offset != sequencer()->close_offset()) { + OnUnrecoverableError( + QUIC_STREAM_MULTIPLE_OFFSET, + absl::StrCat("Stream ", id_, + " received new final offset: ", frame.byte_offset, + ", which is different from close offset: ", + sequencer()->close_offset())); + return; + } + + MaybeIncreaseHighestReceivedOffset(frame.byte_offset); + if (flow_controller_.FlowControlViolation() || + connection_flow_controller_->FlowControlViolation()) { + OnUnrecoverableError(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, + "Flow control violation after increasing offset"); + return; + } +} + +void PendingStream::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { + QUICHE_DCHECK(is_bidirectional_); + flow_controller_.UpdateSendWindowOffset(frame.max_data); +} + +bool PendingStream::MaybeIncreaseHighestReceivedOffset( + QuicStreamOffset new_offset) { + uint64_t increment = + new_offset - flow_controller_.highest_received_byte_offset(); + if (!flow_controller_.UpdateHighestReceivedOffset(new_offset)) { + return false; + } + + // If |new_offset| increased the stream flow controller's highest received + // offset, increase the connection flow controller's value by the incremental + // difference. + connection_flow_controller_->UpdateHighestReceivedOffset( + connection_flow_controller_->highest_received_byte_offset() + increment); + return true; +} + +void PendingStream::OnStopSending( + QuicResetStreamError stop_sending_error_code) { + if (!stop_sending_error_code_) { + stop_sending_error_code_ = stop_sending_error_code; + } +} + +void PendingStream::MarkConsumed(QuicByteCount num_bytes) { + sequencer_.MarkConsumed(num_bytes); +} + +void PendingStream::StopReading() { + QUIC_DVLOG(1) << "Stop reading from pending stream " << id(); + sequencer_.StopReading(); +} + +QuicStream::QuicStream(PendingStream* pending, QuicSession* session, + bool is_static) + : QuicStream(pending->id_, session, std::move(pending->sequencer_), + is_static, + QuicUtils::GetStreamType(pending->id_, session->perspective(), + /*peer_initiated = */ true, + session->version()), + pending->stream_bytes_read_, pending->fin_received_, + std::move(pending->flow_controller_), + pending->connection_flow_controller_) { + QUICHE_DCHECK(session->version().HasIetfQuicFrames()); + sequencer_.set_stream(this); +} + +namespace { + +absl::optional FlowController(QuicStreamId id, + QuicSession* session, + StreamType type) { + if (type == CRYPTO) { + // The only QuicStream with a StreamType of CRYPTO is QuicCryptoStream, when + // it is using crypto frames instead of stream frames. The QuicCryptoStream + // doesn't have any flow control in that case, so we don't create a + // QuicFlowController for it. + return absl::nullopt; + } + return QuicFlowController( + session, id, + /*is_connection_flow_controller*/ false, + GetReceivedFlowControlWindow(session, id), + GetInitialStreamFlowControlWindowToSend(session, id), + kStreamReceiveWindowLimit, + session->flow_controller()->auto_tune_receive_window(), + session->flow_controller()); +} + +} // namespace + +QuicStream::QuicStream(QuicStreamId id, QuicSession* session, bool is_static, + StreamType type) + : QuicStream(id, session, QuicStreamSequencer(this), is_static, type, 0, + false, FlowController(id, session, type), + session->flow_controller()) {} + +QuicStream::QuicStream(QuicStreamId id, QuicSession* session, + QuicStreamSequencer sequencer, bool is_static, + StreamType type, uint64_t stream_bytes_read, + bool fin_received, + absl::optional flow_controller, + QuicFlowController* connection_flow_controller) + : sequencer_(std::move(sequencer)), + id_(id), + session_(session), + stream_delegate_(session), + priority_(QuicStreamPriority::Default(session->priority_type())), + stream_bytes_read_(stream_bytes_read), + stream_error_(QuicResetStreamError::NoError()), + connection_error_(QUIC_NO_ERROR), + read_side_closed_(false), + write_side_closed_(false), + write_side_data_recvd_state_notified_(false), + fin_buffered_(false), + fin_sent_(false), + fin_outstanding_(false), + fin_lost_(false), + fin_received_(fin_received), + rst_sent_(false), + rst_received_(false), + stop_sending_sent_(false), + flow_controller_(std::move(flow_controller)), + connection_flow_controller_(connection_flow_controller), + stream_contributes_to_connection_flow_control_(true), + busy_counter_(0), + add_random_padding_after_fin_(false), + send_buffer_( + session->connection()->helper()->GetStreamSendBufferAllocator()), + buffered_data_threshold_(GetQuicFlag(quic_buffered_data_threshold)), + is_static_(is_static), + deadline_(QuicTime::Zero()), + was_draining_(false), + type_(VersionHasIetfQuicFrames(session->transport_version()) && + type != CRYPTO + ? QuicUtils::GetStreamType(id_, session->perspective(), + session->IsIncomingStream(id_), + session->version()) + : type), + creation_time_(session->connection()->clock()->ApproximateNow()), + perspective_(session->perspective()) { + if (type_ == WRITE_UNIDIRECTIONAL) { + fin_received_ = true; + CloseReadSide(); + } else if (type_ == READ_UNIDIRECTIONAL) { + fin_sent_ = true; + CloseWriteSide(); + } + if (type_ != CRYPTO) { + stream_delegate_->RegisterStreamPriority(id, is_static_, priority_); + } +} + +QuicStream::~QuicStream() { + if (session_ != nullptr && IsWaitingForAcks()) { + QUIC_DVLOG(1) + << ENDPOINT << "Stream " << id_ + << " gets destroyed while waiting for acks. stream_bytes_outstanding = " + << send_buffer_.stream_bytes_outstanding() + << ", fin_outstanding: " << fin_outstanding_; + } + if (stream_delegate_ != nullptr && type_ != CRYPTO) { + stream_delegate_->UnregisterStreamPriority(id()); + } +} + +void QuicStream::OnStreamFrame(const QuicStreamFrame& frame) { + QUICHE_DCHECK_EQ(frame.stream_id, id_); + + QUICHE_DCHECK(!(read_side_closed_ && write_side_closed_)); + + if (frame.fin && is_static_) { + OnUnrecoverableError(QUIC_INVALID_STREAM_ID, + "Attempt to close a static stream"); + return; + } + + if (type_ == WRITE_UNIDIRECTIONAL) { + OnUnrecoverableError(QUIC_DATA_RECEIVED_ON_WRITE_UNIDIRECTIONAL_STREAM, + "Data received on write unidirectional stream"); + return; + } + + bool is_stream_too_long = + (frame.offset > kMaxStreamLength) || + (kMaxStreamLength - frame.offset < frame.data_length); + if (is_stream_too_long) { + // Close connection if stream becomes too long. + QUIC_PEER_BUG(quic_peer_bug_10586_1) + << "Receive stream frame on stream " << id_ + << " reaches max stream length. frame offset " << frame.offset + << " length " << frame.data_length << ". " << sequencer_.DebugString(); + OnUnrecoverableError( + QUIC_STREAM_LENGTH_OVERFLOW, + absl::StrCat("Peer sends more data than allowed on stream ", id_, + ". frame: offset = ", frame.offset, ", length = ", + frame.data_length, ". ", sequencer_.DebugString())); + return; + } + + if (frame.offset + frame.data_length > sequencer_.close_offset()) { + OnUnrecoverableError( + QUIC_STREAM_DATA_BEYOND_CLOSE_OFFSET, + absl::StrCat( + "Stream ", id_, + " received data with offset: ", frame.offset + frame.data_length, + ", which is beyond close offset: ", sequencer_.close_offset())); + return; + } + + if (frame.fin && !fin_received_) { + fin_received_ = true; + if (fin_sent_) { + QUICHE_DCHECK(!was_draining_); + session_->StreamDraining(id_, + /*unidirectional=*/type_ != BIDIRECTIONAL); + was_draining_ = true; + } + } + + if (read_side_closed_) { + QUIC_DLOG(INFO) + << ENDPOINT << "Stream " << frame.stream_id + << " is closed for reading. Ignoring newly received stream data."; + // The subclass does not want to read data: blackhole the data. + return; + } + + // This count includes duplicate data received. + QuicByteCount frame_payload_size = frame.data_length; + stream_bytes_read_ += frame_payload_size; + + // Flow control is interested in tracking highest received offset. + // Only interested in received frames that carry data. + if (frame_payload_size > 0 && + MaybeIncreaseHighestReceivedOffset(frame.offset + frame_payload_size)) { + // As the highest received offset has changed, check to see if this is a + // violation of flow control. + QUIC_BUG_IF(quic_bug_12570_2, !flow_controller_.has_value()) + << ENDPOINT << "OnStreamFrame called on stream without flow control"; + if ((flow_controller_.has_value() && + flow_controller_->FlowControlViolation()) || + connection_flow_controller_->FlowControlViolation()) { + OnUnrecoverableError(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, + "Flow control violation after increasing offset"); + return; + } + } + + sequencer_.OnStreamFrame(frame); +} + +bool QuicStream::OnStopSending(QuicResetStreamError error) { + // Do not reset the stream if all data has been sent and acknowledged. + if (write_side_closed() && !IsWaitingForAcks()) { + QUIC_DVLOG(1) << ENDPOINT + << "Ignoring STOP_SENDING for a write closed stream, id: " + << id_; + return false; + } + + if (is_static_) { + QUIC_DVLOG(1) << ENDPOINT + << "Received STOP_SENDING for a static stream, id: " << id_ + << " Closing connection"; + OnUnrecoverableError(QUIC_INVALID_STREAM_ID, + "Received STOP_SENDING for a static stream"); + return false; + } + + stream_error_ = error; + MaybeSendRstStream(error); + return true; +} + +int QuicStream::num_frames_received() const { + return sequencer_.num_frames_received(); +} + +int QuicStream::num_duplicate_frames_received() const { + return sequencer_.num_duplicate_frames_received(); +} + +void QuicStream::OnStreamReset(const QuicRstStreamFrame& frame) { + rst_received_ = true; + if (frame.byte_offset > kMaxStreamLength) { + // Peer are not suppose to write bytes more than maxium allowed. + OnUnrecoverableError(QUIC_STREAM_LENGTH_OVERFLOW, + "Reset frame stream offset overflow."); + return; + } + + const QuicStreamOffset kMaxOffset = + std::numeric_limits::max(); + if (sequencer()->close_offset() != kMaxOffset && + frame.byte_offset != sequencer()->close_offset()) { + OnUnrecoverableError( + QUIC_STREAM_MULTIPLE_OFFSET, + absl::StrCat("Stream ", id_, + " received new final offset: ", frame.byte_offset, + ", which is different from close offset: ", + sequencer_.close_offset())); + return; + } + + MaybeIncreaseHighestReceivedOffset(frame.byte_offset); + QUIC_BUG_IF(quic_bug_12570_3, !flow_controller_.has_value()) + << ENDPOINT << "OnStreamReset called on stream without flow control"; + if ((flow_controller_.has_value() && + flow_controller_->FlowControlViolation()) || + connection_flow_controller_->FlowControlViolation()) { + OnUnrecoverableError(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, + "Flow control violation after increasing offset"); + return; + } + + stream_error_ = frame.error(); + // Google QUIC closes both sides of the stream in response to a + // RESET_STREAM, IETF QUIC closes only the read side. + if (!VersionHasIetfQuicFrames(transport_version())) { + CloseWriteSide(); + } + CloseReadSide(); +} + +void QuicStream::OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource /*source*/) { + if (read_side_closed_ && write_side_closed_) { + return; + } + if (error != QUIC_NO_ERROR) { + stream_error_ = + QuicResetStreamError::FromInternal(QUIC_STREAM_CONNECTION_ERROR); + connection_error_ = error; + } + + CloseWriteSide(); + CloseReadSide(); +} + +void QuicStream::OnFinRead() { + QUICHE_DCHECK(sequencer_.IsClosed()); + // OnFinRead can be called due to a FIN flag in a headers block, so there may + // have been no OnStreamFrame call with a FIN in the frame. + fin_received_ = true; + // If fin_sent_ is true, then CloseWriteSide has already been called, and the + // stream will be destroyed by CloseReadSide, so don't need to call + // StreamDraining. + CloseReadSide(); +} + +void QuicStream::SetFinSent() { + QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); + fin_sent_ = true; +} + +void QuicStream::Reset(QuicRstStreamErrorCode error) { + ResetWithError(QuicResetStreamError::FromInternal(error)); +} + +void QuicStream::ResetWithError(QuicResetStreamError error) { + stream_error_ = error; + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); + MaybeSendStopSending(error); + MaybeSendRstStream(error); + + if (read_side_closed_ && write_side_closed_ && !IsWaitingForAcks()) { + session()->MaybeCloseZombieStream(id_); + } +} + +void QuicStream::ResetWriteSide(QuicResetStreamError error) { + stream_error_ = error; + MaybeSendRstStream(error); + + if (read_side_closed_ && write_side_closed_ && !IsWaitingForAcks()) { + session()->MaybeCloseZombieStream(id_); + } +} + +void QuicStream::SendStopSending(QuicResetStreamError error) { + stream_error_ = error; + MaybeSendStopSending(error); + + if (read_side_closed_ && write_side_closed_ && !IsWaitingForAcks()) { + session()->MaybeCloseZombieStream(id_); + } +} + +void QuicStream::OnUnrecoverableError(QuicErrorCode error, + const std::string& details) { + stream_delegate_->OnStreamError(error, details); +} + +void QuicStream::OnUnrecoverableError(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) { + stream_delegate_->OnStreamError(error, ietf_error, details); +} + +const QuicStreamPriority& QuicStream::priority() const { return priority_; } + +void QuicStream::SetPriority(const QuicStreamPriority& priority) { + priority_ = priority; + + MaybeSendPriorityUpdateFrame(); + + stream_delegate_->UpdateStreamPriority(id(), priority); +} + +void QuicStream::WriteOrBufferData( + absl::string_view data, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener) { + QUIC_BUG_IF(quic_bug_12570_4, + QuicUtils::IsCryptoStreamId(transport_version(), id_)) + << ENDPOINT + << "WriteOrBufferData is used to send application data, use " + "WriteOrBufferDataAtLevel to send crypto data."; + return WriteOrBufferDataAtLevel( + data, fin, session()->GetEncryptionLevelToSendApplicationData(), + ack_listener); +} + +void QuicStream::WriteOrBufferDataAtLevel( + absl::string_view data, bool fin, EncryptionLevel level, + quiche::QuicheReferenceCountedPointer + ack_listener) { + if (data.empty() && !fin) { + QUIC_BUG(quic_bug_10586_2) << "data.empty() && !fin"; + return; + } + + if (fin_buffered_) { + QUIC_BUG(quic_bug_10586_3) << "Fin already buffered"; + return; + } + if (write_side_closed_) { + QUIC_DLOG(ERROR) << ENDPOINT + << "Attempt to write when the write side is closed"; + if (type_ == READ_UNIDIRECTIONAL) { + OnUnrecoverableError(QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM, + "Try to send data on read unidirectional stream"); + } + return; + } + + fin_buffered_ = fin; + + bool had_buffered_data = HasBufferedData(); + // Do not respect buffered data upper limit as WriteOrBufferData guarantees + // all data to be consumed. + if (data.length() > 0) { + QuicStreamOffset offset = send_buffer_.stream_offset(); + if (kMaxStreamLength - offset < data.length()) { + QUIC_BUG(quic_bug_10586_4) << "Write too many data via stream " << id_; + OnUnrecoverableError( + QUIC_STREAM_LENGTH_OVERFLOW, + absl::StrCat("Write too many data via stream ", id_)); + return; + } + send_buffer_.SaveStreamData(data); + OnDataBuffered(offset, data.length(), ack_listener); + } + if (!had_buffered_data && (HasBufferedData() || fin_buffered_)) { + // Write data if there is no buffered data before. + WriteBufferedData(level); + } +} + +void QuicStream::OnCanWrite() { + if (HasDeadlinePassed()) { + OnDeadlinePassed(); + return; + } + if (HasPendingRetransmission()) { + WritePendingRetransmission(); + // Exit early to allow other streams to write pending retransmissions if + // any. + return; + } + + if (write_side_closed_) { + QUIC_DLOG(ERROR) + << ENDPOINT << "Stream " << id() + << " attempting to write new data when the write side is closed"; + return; + } + if (HasBufferedData() || (fin_buffered_ && !fin_sent_)) { + WriteBufferedData(session()->GetEncryptionLevelToSendApplicationData()); + } + if (!fin_buffered_ && !fin_sent_ && CanWriteNewData()) { + // Notify upper layer to write new data when buffered data size is below + // low water mark. + OnCanWriteNewData(); + } +} + +void QuicStream::MaybeSendBlocked() { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_5) + << ENDPOINT << "MaybeSendBlocked called on stream without flow control"; + return; + } + flow_controller_->MaybeSendBlocked(); + if (!stream_contributes_to_connection_flow_control_) { + return; + } + connection_flow_controller_->MaybeSendBlocked(); + + // If the stream is blocked by connection-level flow control but not by + // stream-level flow control, add the stream to the write blocked list so that + // the stream will be given a chance to write when a connection-level + // WINDOW_UPDATE arrives. + if (!write_side_closed_ && connection_flow_controller_->IsBlocked() && + !flow_controller_->IsBlocked()) { + session_->MarkConnectionLevelWriteBlocked(id()); + } +} + +QuicConsumedData QuicStream::WriteMemSlice(quiche::QuicheMemSlice span, + bool fin) { + return WriteMemSlices(absl::MakeSpan(&span, 1), fin); +} + +QuicConsumedData QuicStream::WriteMemSlices( + absl::Span span, bool fin) { + QuicConsumedData consumed_data(0, false); + if (span.empty() && !fin) { + QUIC_BUG(quic_bug_10586_6) << "span.empty() && !fin"; + return consumed_data; + } + + if (fin_buffered_) { + QUIC_BUG(quic_bug_10586_7) << "Fin already buffered"; + return consumed_data; + } + + if (write_side_closed_) { + QUIC_DLOG(ERROR) << ENDPOINT << "Stream " << id() + << " attempting to write when the write side is closed"; + if (type_ == READ_UNIDIRECTIONAL) { + OnUnrecoverableError(QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM, + "Try to send data on read unidirectional stream"); + } + return consumed_data; + } + + bool had_buffered_data = HasBufferedData(); + if (CanWriteNewData() || span.empty()) { + consumed_data.fin_consumed = fin; + if (!span.empty()) { + // Buffer all data if buffered data size is below limit. + QuicStreamOffset offset = send_buffer_.stream_offset(); + consumed_data.bytes_consumed = send_buffer_.SaveMemSliceSpan(span); + if (offset > send_buffer_.stream_offset() || + kMaxStreamLength < send_buffer_.stream_offset()) { + QUIC_BUG(quic_bug_10586_8) << "Write too many data via stream " << id_; + OnUnrecoverableError( + QUIC_STREAM_LENGTH_OVERFLOW, + absl::StrCat("Write too many data via stream ", id_)); + return consumed_data; + } + OnDataBuffered(offset, consumed_data.bytes_consumed, nullptr); + } + } + fin_buffered_ = consumed_data.fin_consumed; + + if (!had_buffered_data && (HasBufferedData() || fin_buffered_)) { + // Write data if there is no buffered data before. + WriteBufferedData(session()->GetEncryptionLevelToSendApplicationData()); + } + + return consumed_data; +} + +bool QuicStream::HasPendingRetransmission() const { + return send_buffer_.HasPendingRetransmission() || fin_lost_; +} + +bool QuicStream::IsStreamFrameOutstanding(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin) const { + return send_buffer_.IsStreamDataOutstanding(offset, data_length) || + (fin && fin_outstanding_); +} + +void QuicStream::CloseReadSide() { + if (read_side_closed_) { + return; + } + QUIC_DVLOG(1) << ENDPOINT << "Done reading from stream " << id(); + + read_side_closed_ = true; + sequencer_.ReleaseBuffer(); + + if (write_side_closed_) { + QUIC_DVLOG(1) << ENDPOINT << "Closing stream " << id(); + session_->OnStreamClosed(id()); + OnClose(); + } +} + +void QuicStream::CloseWriteSide() { + if (write_side_closed_) { + return; + } + QUIC_DVLOG(1) << ENDPOINT << "Done writing to stream " << id(); + + write_side_closed_ = true; + if (read_side_closed_) { + QUIC_DVLOG(1) << ENDPOINT << "Closing stream " << id(); + session_->OnStreamClosed(id()); + OnClose(); + } +} + +void QuicStream::MaybeSendStopSending(QuicResetStreamError error) { + if (stop_sending_sent_) { + return; + } + + if (!session()->version().UsesHttp3() && !error.ok()) { + // In gQUIC, RST with error closes both read and write side. + return; + } + + if (session()->version().UsesHttp3()) { + session()->MaybeSendStopSendingFrame(id(), error); + } else { + QUICHE_DCHECK_EQ(QUIC_STREAM_NO_ERROR, error.internal_code()); + session()->MaybeSendRstStreamFrame(id(), QuicResetStreamError::NoError(), + stream_bytes_written()); + } + stop_sending_sent_ = true; + CloseReadSide(); +} + +void QuicStream::MaybeSendRstStream(QuicResetStreamError error) { + if (rst_sent_) { + return; + } + + if (!session()->version().UsesHttp3()) { + QUIC_BUG_IF(quic_bug_12570_5, error.ok()); + stop_sending_sent_ = true; + CloseReadSide(); + } + session()->MaybeSendRstStreamFrame(id(), error, stream_bytes_written()); + rst_sent_ = true; + CloseWriteSide(); +} + +bool QuicStream::HasBufferedData() const { + QUICHE_DCHECK_GE(send_buffer_.stream_offset(), stream_bytes_written()); + return send_buffer_.stream_offset() > stream_bytes_written(); +} + +ParsedQuicVersion QuicStream::version() const { return session_->version(); } + +QuicTransportVersion QuicStream::transport_version() const { + return session_->transport_version(); +} + +HandshakeProtocol QuicStream::handshake_protocol() const { + return session_->connection()->version().handshake_protocol; +} + +void QuicStream::StopReading() { + QUIC_DVLOG(1) << ENDPOINT << "Stop reading from stream " << id(); + sequencer_.StopReading(); +} + +void QuicStream::OnClose() { + QUICHE_DCHECK(read_side_closed_ && write_side_closed_); + + if (!fin_sent_ && !rst_sent_) { + QUIC_BUG_IF(quic_bug_12570_6, session()->connection()->connected() && + session()->version().UsesHttp3()) + << "The stream should've already sent RST in response to " + "STOP_SENDING"; + // For flow control accounting, tell the peer how many bytes have been + // written on this stream before termination. Done here if needed, using a + // RST_STREAM frame. + MaybeSendRstStream(QUIC_RST_ACKNOWLEDGEMENT); + session_->MaybeCloseZombieStream(id_); + } + + if (!flow_controller_.has_value() || + flow_controller_->FlowControlViolation() || + connection_flow_controller_->FlowControlViolation()) { + return; + } + // The stream is being closed and will not process any further incoming bytes. + // As there may be more bytes in flight, to ensure that both endpoints have + // the same connection level flow control state, mark all unreceived or + // buffered bytes as consumed. + QuicByteCount bytes_to_consume = + flow_controller_->highest_received_byte_offset() - + flow_controller_->bytes_consumed(); + AddBytesConsumed(bytes_to_consume); +} + +void QuicStream::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { + if (type_ == READ_UNIDIRECTIONAL) { + OnUnrecoverableError( + QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM, + "WindowUpdateFrame received on READ_UNIDIRECTIONAL stream."); + return; + } + + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_9) + << ENDPOINT + << "OnWindowUpdateFrame called on stream without flow control"; + return; + } + + if (flow_controller_->UpdateSendWindowOffset(frame.max_data)) { + // Let session unblock this stream. + session_->MarkConnectionLevelWriteBlocked(id_); + } +} + +bool QuicStream::MaybeIncreaseHighestReceivedOffset( + QuicStreamOffset new_offset) { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_10) + << ENDPOINT + << "MaybeIncreaseHighestReceivedOffset called on stream without " + "flow control"; + return false; + } + uint64_t increment = + new_offset - flow_controller_->highest_received_byte_offset(); + if (!flow_controller_->UpdateHighestReceivedOffset(new_offset)) { + return false; + } + + // If |new_offset| increased the stream flow controller's highest received + // offset, increase the connection flow controller's value by the incremental + // difference. + if (stream_contributes_to_connection_flow_control_) { + connection_flow_controller_->UpdateHighestReceivedOffset( + connection_flow_controller_->highest_received_byte_offset() + + increment); + } + return true; +} + +void QuicStream::AddBytesSent(QuicByteCount bytes) { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_11) + << ENDPOINT << "AddBytesSent called on stream without flow control"; + return; + } + flow_controller_->AddBytesSent(bytes); + if (stream_contributes_to_connection_flow_control_) { + connection_flow_controller_->AddBytesSent(bytes); + } +} + +void QuicStream::AddBytesConsumed(QuicByteCount bytes) { + if (type_ == CRYPTO) { + // A stream with type CRYPTO has no flow control, so there's nothing this + // function needs to do. This function still gets called by the + // QuicStreamSequencers used by QuicCryptoStream. + return; + } + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_12570_7) + << ENDPOINT + << "AddBytesConsumed called on non-crypto stream without flow control"; + return; + } + // Only adjust stream level flow controller if still reading. + if (!read_side_closed_) { + flow_controller_->AddBytesConsumed(bytes); + } + + if (stream_contributes_to_connection_flow_control_) { + connection_flow_controller_->AddBytesConsumed(bytes); + } +} + +bool QuicStream::MaybeConfigSendWindowOffset(QuicStreamOffset new_offset, + bool was_zero_rtt_rejected) { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_12) + << ENDPOINT + << "ConfigSendWindowOffset called on stream without flow control"; + return false; + } + + // The validation code below is for QUIC with TLS only. + if (new_offset < flow_controller_->send_window_offset()) { + QUICHE_DCHECK(session()->version().UsesTls()); + if (was_zero_rtt_rejected && new_offset < flow_controller_->bytes_sent()) { + // The client is given flow control window lower than what's written in + // 0-RTT. This QUIC implementation is unable to retransmit them. + QUIC_BUG_IF(quic_bug_12570_8, perspective_ == Perspective::IS_SERVER) + << "Server streams' flow control should never be configured twice."; + OnUnrecoverableError( + QUIC_ZERO_RTT_UNRETRANSMITTABLE, + absl::StrCat( + "Server rejected 0-RTT, aborting because new stream max data ", + new_offset, " for stream ", id_, " is less than currently used: ", + flow_controller_->bytes_sent())); + return false; + } else if (session()->version().AllowsLowFlowControlLimits()) { + // In IETF QUIC, if the client receives flow control limit lower than what + // was resumed from 0-RTT, depending on 0-RTT status, it's either the + // peer's fault or our implementation's fault. + QUIC_BUG_IF(quic_bug_12570_9, perspective_ == Perspective::IS_SERVER) + << "Server streams' flow control should never be configured twice."; + OnUnrecoverableError( + was_zero_rtt_rejected ? QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED + : QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, + absl::StrCat( + was_zero_rtt_rejected ? "Server rejected 0-RTT, aborting because " + : "", + "new stream max data ", new_offset, " decreases current limit: ", + flow_controller_->send_window_offset())); + return false; + } + } + + if (flow_controller_->UpdateSendWindowOffset(new_offset)) { + // Let session unblock this stream. + session_->MarkConnectionLevelWriteBlocked(id_); + } + return true; +} + +void QuicStream::AddRandomPaddingAfterFin() { + add_random_padding_after_fin_ = true; +} + +bool QuicStream::OnStreamFrameAcked(QuicStreamOffset offset, + QuicByteCount data_length, bool fin_acked, + QuicTime::Delta /*ack_delay_time*/, + QuicTime /*receive_timestamp*/, + QuicByteCount* newly_acked_length) { + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " Acking " + << "[" << offset << ", " << offset + data_length << "]" + << " fin = " << fin_acked; + *newly_acked_length = 0; + if (!send_buffer_.OnStreamDataAcked(offset, data_length, + newly_acked_length)) { + OnUnrecoverableError(QUIC_INTERNAL_ERROR, "Trying to ack unsent data."); + return false; + } + if (!fin_sent_ && fin_acked) { + OnUnrecoverableError(QUIC_INTERNAL_ERROR, "Trying to ack unsent fin."); + return false; + } + // Indicates whether ack listener's OnPacketAcked should be called. + const bool new_data_acked = + *newly_acked_length > 0 || (fin_acked && fin_outstanding_); + if (fin_acked) { + fin_outstanding_ = false; + fin_lost_ = false; + } + if (!IsWaitingForAcks() && write_side_closed_ && + !write_side_data_recvd_state_notified_) { + OnWriteSideInDataRecvdState(); + write_side_data_recvd_state_notified_ = true; + } + if (!IsWaitingForAcks() && read_side_closed_ && write_side_closed_) { + session_->MaybeCloseZombieStream(id_); + } + return new_data_acked; +} + +void QuicStream::OnStreamFrameRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_retransmitted) { + send_buffer_.OnStreamDataRetransmitted(offset, data_length); + if (fin_retransmitted) { + fin_lost_ = false; + } +} + +void QuicStream::OnStreamFrameLost(QuicStreamOffset offset, + QuicByteCount data_length, bool fin_lost) { + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " Losting " + << "[" << offset << ", " << offset + data_length << "]" + << " fin = " << fin_lost; + if (data_length > 0) { + send_buffer_.OnStreamDataLost(offset, data_length); + } + if (fin_lost && fin_outstanding_) { + fin_lost_ = true; + } +} + +bool QuicStream::RetransmitStreamData(QuicStreamOffset offset, + QuicByteCount data_length, bool fin, + TransmissionType type) { + QUICHE_DCHECK(type == PTO_RETRANSMISSION); + if (HasDeadlinePassed()) { + OnDeadlinePassed(); + return true; + } + QuicIntervalSet retransmission(offset, + offset + data_length); + retransmission.Difference(bytes_acked()); + bool retransmit_fin = fin && fin_outstanding_; + if (retransmission.Empty() && !retransmit_fin) { + return true; + } + QuicConsumedData consumed(0, false); + for (const auto& interval : retransmission) { + QuicStreamOffset retransmission_offset = interval.min(); + QuicByteCount retransmission_length = interval.max() - interval.min(); + const bool can_bundle_fin = + retransmit_fin && (retransmission_offset + retransmission_length == + stream_bytes_written()); + consumed = stream_delegate_->WritevData( + id_, retransmission_length, retransmission_offset, + can_bundle_fin ? FIN : NO_FIN, type, + session()->GetEncryptionLevelToSendApplicationData()); + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ + << " is forced to retransmit stream data [" + << retransmission_offset << ", " + << retransmission_offset + retransmission_length + << ") and fin: " << can_bundle_fin + << ", consumed: " << consumed; + OnStreamFrameRetransmitted(retransmission_offset, consumed.bytes_consumed, + consumed.fin_consumed); + if (can_bundle_fin) { + retransmit_fin = !consumed.fin_consumed; + } + if (consumed.bytes_consumed < retransmission_length || + (can_bundle_fin && !consumed.fin_consumed)) { + // Connection is write blocked. + return false; + } + } + if (retransmit_fin) { + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ + << " retransmits fin only frame."; + consumed = stream_delegate_->WritevData( + id_, 0, stream_bytes_written(), FIN, type, + session()->GetEncryptionLevelToSendApplicationData()); + if (!consumed.fin_consumed) { + return false; + } + } + return true; +} + +bool QuicStream::IsWaitingForAcks() const { + return (!rst_sent_ || stream_error_.ok()) && + (send_buffer_.stream_bytes_outstanding() || fin_outstanding_); +} + +QuicByteCount QuicStream::ReadableBytes() const { + return sequencer_.ReadableBytes(); +} + +bool QuicStream::WriteStreamData(QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + QUICHE_DCHECK_LT(0u, data_length); + QUIC_DVLOG(2) << ENDPOINT << "Write stream " << id_ << " data from offset " + << offset << " length " << data_length; + return send_buffer_.WriteStreamData(offset, data_length, writer); +} + +void QuicStream::WriteBufferedData(EncryptionLevel level) { + QUICHE_DCHECK(!write_side_closed_ && (HasBufferedData() || fin_buffered_)); + + if (session_->ShouldYield(id())) { + session_->MarkConnectionLevelWriteBlocked(id()); + return; + } + + // Size of buffered data. + QuicByteCount write_length = BufferedDataBytes(); + + // A FIN with zero data payload should not be flow control blocked. + bool fin_with_zero_data = (fin_buffered_ && write_length == 0); + + bool fin = fin_buffered_; + + // How much data flow control permits to be written. + QuicByteCount send_window; + if (flow_controller_.has_value()) { + send_window = flow_controller_->SendWindowSize(); + } else { + send_window = std::numeric_limits::max(); + QUIC_BUG(quic_bug_10586_13) + << ENDPOINT + << "WriteBufferedData called on stream without flow control"; + } + if (stream_contributes_to_connection_flow_control_) { + send_window = + std::min(send_window, connection_flow_controller_->SendWindowSize()); + } + + if (send_window == 0 && !fin_with_zero_data) { + // Quick return if nothing can be sent. + MaybeSendBlocked(); + return; + } + + if (write_length > send_window) { + // Don't send the FIN unless all the data will be sent. + fin = false; + + // Writing more data would be a violation of flow control. + write_length = send_window; + QUIC_DVLOG(1) << "stream " << id() << " shortens write length to " + << write_length << " due to flow control"; + } + + StreamSendingState state = fin ? FIN : NO_FIN; + if (fin && add_random_padding_after_fin_) { + state = FIN_AND_PADDING; + } + QuicConsumedData consumed_data = + stream_delegate_->WritevData(id(), write_length, stream_bytes_written(), + state, NOT_RETRANSMISSION, level); + + OnStreamDataConsumed(consumed_data.bytes_consumed); + + AddBytesSent(consumed_data.bytes_consumed); + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " sends " + << stream_bytes_written() << " bytes " + << " and has buffered data " << BufferedDataBytes() << " bytes." + << " fin is sent: " << consumed_data.fin_consumed + << " fin is buffered: " << fin_buffered_; + + // The write may have generated a write error causing this stream to be + // closed. If so, simply return without marking the stream write blocked. + if (write_side_closed_) { + return; + } + + if (consumed_data.bytes_consumed == write_length) { + if (!fin_with_zero_data) { + MaybeSendBlocked(); + } + if (fin && consumed_data.fin_consumed) { + QUICHE_DCHECK(!fin_sent_); + fin_sent_ = true; + fin_outstanding_ = true; + if (fin_received_) { + QUICHE_DCHECK(!was_draining_); + session_->StreamDraining(id_, + /*unidirectional=*/type_ != BIDIRECTIONAL); + was_draining_ = true; + } + CloseWriteSide(); + } else if (fin && !consumed_data.fin_consumed && !write_side_closed_) { + session_->MarkConnectionLevelWriteBlocked(id()); + } + } else { + session_->MarkConnectionLevelWriteBlocked(id()); + } + if (consumed_data.bytes_consumed > 0 || consumed_data.fin_consumed) { + busy_counter_ = 0; + } +} + +uint64_t QuicStream::BufferedDataBytes() const { + QUICHE_DCHECK_GE(send_buffer_.stream_offset(), stream_bytes_written()); + return send_buffer_.stream_offset() - stream_bytes_written(); +} + +bool QuicStream::CanWriteNewData() const { + return BufferedDataBytes() < buffered_data_threshold_; +} + +bool QuicStream::CanWriteNewDataAfterData(QuicByteCount length) const { + return (BufferedDataBytes() + length) < buffered_data_threshold_; +} + +uint64_t QuicStream::stream_bytes_written() const { + return send_buffer_.stream_bytes_written(); +} + +const QuicIntervalSet& QuicStream::bytes_acked() const { + return send_buffer_.bytes_acked(); +} + +void QuicStream::OnStreamDataConsumed(QuicByteCount bytes_consumed) { + send_buffer_.OnStreamDataConsumed(bytes_consumed); +} + +void QuicStream::WritePendingRetransmission() { + while (HasPendingRetransmission()) { + QuicConsumedData consumed(0, false); + if (!send_buffer_.HasPendingRetransmission()) { + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ + << " retransmits fin only frame."; + consumed = stream_delegate_->WritevData( + id_, 0, stream_bytes_written(), FIN, LOSS_RETRANSMISSION, + session()->GetEncryptionLevelToSendApplicationData()); + fin_lost_ = !consumed.fin_consumed; + if (fin_lost_) { + // Connection is write blocked. + return; + } + } else { + StreamPendingRetransmission pending = + send_buffer_.NextPendingRetransmission(); + // Determine whether the lost fin can be bundled with the data. + const bool can_bundle_fin = + fin_lost_ && + (pending.offset + pending.length == stream_bytes_written()); + consumed = stream_delegate_->WritevData( + id_, pending.length, pending.offset, can_bundle_fin ? FIN : NO_FIN, + LOSS_RETRANSMISSION, + session()->GetEncryptionLevelToSendApplicationData()); + QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ + << " tries to retransmit stream data [" << pending.offset + << ", " << pending.offset + pending.length + << ") and fin: " << can_bundle_fin + << ", consumed: " << consumed; + OnStreamFrameRetransmitted(pending.offset, consumed.bytes_consumed, + consumed.fin_consumed); + if (consumed.bytes_consumed < pending.length || + (can_bundle_fin && !consumed.fin_consumed)) { + // Connection is write blocked. + return; + } + } + } +} + +bool QuicStream::MaybeSetTtl(QuicTime::Delta ttl) { + if (is_static_) { + QUIC_BUG(quic_bug_10586_14) << "Cannot set TTL of a static stream."; + return false; + } + if (deadline_.IsInitialized()) { + QUIC_DLOG(WARNING) << "Deadline has already been set."; + return false; + } + QuicTime now = session()->connection()->clock()->ApproximateNow(); + deadline_ = now + ttl; + return true; +} + +bool QuicStream::HasDeadlinePassed() const { + if (!deadline_.IsInitialized()) { + // No deadline has been set. + return false; + } + QuicTime now = session()->connection()->clock()->ApproximateNow(); + if (now < deadline_) { + return false; + } + // TTL expired. + QUIC_DVLOG(1) << "stream " << id() << " deadline has passed"; + return true; +} + +void QuicStream::OnDeadlinePassed() { Reset(QUIC_STREAM_TTL_EXPIRED); } + +bool QuicStream::IsFlowControlBlocked() const { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_15) + << "Trying to access non-existent flow controller."; + return false; + } + return flow_controller_->IsBlocked(); +} + +QuicStreamOffset QuicStream::highest_received_byte_offset() const { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_16) + << "Trying to access non-existent flow controller."; + return 0; + } + return flow_controller_->highest_received_byte_offset(); +} + +void QuicStream::UpdateReceiveWindowSize(QuicStreamOffset size) { + if (!flow_controller_.has_value()) { + QUIC_BUG(quic_bug_10586_17) + << "Trying to access non-existent flow controller."; + return; + } + flow_controller_->UpdateReceiveWindowSize(size); +} + +absl::optional QuicStream::GetSendWindow() const { + return flow_controller_.has_value() + ? absl::optional(flow_controller_->SendWindowSize()) + : absl::nullopt; +} + +absl::optional QuicStream::GetReceiveWindow() const { + return flow_controller_.has_value() + ? absl::optional( + flow_controller_->receive_window_size()) + : absl::nullopt; +} + +void QuicStream::OnStreamCreatedFromPendingStream() { + sequencer()->SetUnblocked(); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_stream.h b/quiche/quic/core/quic_stream.h new file mode 100644 index 000000000000..b4145285aacb --- /dev/null +++ b/quiche/quic/core/quic_stream.h @@ -0,0 +1,610 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// The base class for client/server QUIC streams. + +// It does not contain the entire interface needed by an application to interact +// with a QUIC stream. Some parts of the interface must be obtained by +// accessing the owning session object. A subclass of QuicStream +// connects the object and the application that generates and consumes the data +// of the stream. + +// The QuicStream object has a dependent QuicStreamSequencer object, +// which is given the stream frames as they arrive, and provides stream data in +// order by invoking ProcessRawData(). + +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "quiche/quic/core/frames/quic_rst_stream_frame.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_flow_controller.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_stream_send_buffer.h" +#include "quiche/quic/core/quic_stream_sequencer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/session_notifier_interface.h" +#include "quiche/quic/core/stream_delegate_interface.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_reference_counted.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace quic { + +namespace test { +class QuicStreamPeer; +} // namespace test + +class QuicSession; +class QuicStream; + +// Buffers frames for a stream until the first byte of that frame arrives. +class QUIC_EXPORT_PRIVATE PendingStream + : public QuicStreamSequencer::StreamInterface { + public: + PendingStream(QuicStreamId id, QuicSession* session); + PendingStream(const PendingStream&) = delete; + PendingStream(PendingStream&&) = default; + ~PendingStream() override = default; + + // QuicStreamSequencer::StreamInterface + void OnDataAvailable() override; + void OnFinRead() override; + void AddBytesConsumed(QuicByteCount bytes) override; + void ResetWithError(QuicResetStreamError error) override; + void OnUnrecoverableError(QuicErrorCode error, + const std::string& details) override; + void OnUnrecoverableError(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) override; + QuicStreamId id() const override; + ParsedQuicVersion version() const override; + + // Buffers the contents of |frame|. Frame must have a non-zero offset. + // If the data violates flow control, the connection will be closed. + void OnStreamFrame(const QuicStreamFrame& frame); + + bool is_bidirectional() const { return is_bidirectional_; } + + // Stores the final byte offset from |frame|. + // If the final offset violates flow control, the connection will be closed. + void OnRstStreamFrame(const QuicRstStreamFrame& frame); + + void OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame); + + void OnStopSending(QuicResetStreamError stop_sending_error_code); + + // The error code received from QuicStopSendingFrame (if any). + const absl::optional& GetStopSendingErrorCode() const { + return stop_sending_error_code_; + } + + // Returns the number of bytes read on this stream. + uint64_t stream_bytes_read() { return stream_bytes_read_; } + + const QuicStreamSequencer* sequencer() const { return &sequencer_; } + + void MarkConsumed(QuicByteCount num_bytes); + + // Tells the sequencer to ignore all incoming data itself and not call + // OnDataAvailable(). + void StopReading(); + + private: + friend class QuicStream; + + bool MaybeIncreaseHighestReceivedOffset(QuicStreamOffset new_offset); + + // ID of this stream. + QuicStreamId id_; + + // QUIC version being used by this stream. + ParsedQuicVersion version_; + + // |stream_delegate_| must outlive this stream. + StreamDelegateInterface* stream_delegate_; + + // Bytes read refers to payload bytes only: they do not include framing, + // encryption overhead etc. + uint64_t stream_bytes_read_; + + // True if a frame containing a fin has been received. + bool fin_received_; + + // True if this pending stream is backing a bidirectional stream. + bool is_bidirectional_; + + // Connection-level flow controller. Owned by the session. + QuicFlowController* connection_flow_controller_; + // Stream-level flow controller. + QuicFlowController flow_controller_; + // Stores the buffered frames. + QuicStreamSequencer sequencer_; + // The error code received from QuicStopSendingFrame (if any). + absl::optional stop_sending_error_code_; +}; + +class QUIC_EXPORT_PRIVATE QuicStream + : public QuicStreamSequencer::StreamInterface { + public: + // Creates a new stream with stream_id |id| associated with |session|. If + // |is_static| is true, then the stream will be given precedence + // over other streams when determing what streams should write next. + // |type| indicates whether the stream is bidirectional, read unidirectional + // or write unidirectional. + // TODO(fayang): Remove |type| when IETF stream ID numbering fully kicks in. + QuicStream(QuicStreamId id, QuicSession* session, bool is_static, + StreamType type); + QuicStream(PendingStream* pending, QuicSession* session, bool is_static); + QuicStream(const QuicStream&) = delete; + QuicStream& operator=(const QuicStream&) = delete; + + virtual ~QuicStream(); + + // QuicStreamSequencer::StreamInterface implementation. + QuicStreamId id() const override { return id_; } + ParsedQuicVersion version() const override; + // Called by the stream subclass after it has consumed the final incoming + // data. + void OnFinRead() override; + + // Called by the subclass or the sequencer to reset the stream from this + // end. + void ResetWithError(QuicResetStreamError error) override; + // Convenience wrapper for the method above. + // TODO(b/200606367): switch all calls to using QuicResetStreamError + // interface. + void Reset(QuicRstStreamErrorCode error); + + // Reset() sends both RESET_STREAM and STOP_SENDING; the two methods below + // allow to send only one of those. + void ResetWriteSide(QuicResetStreamError error); + void SendStopSending(QuicResetStreamError error); + + // Called by the subclass or the sequencer to close the entire connection from + // this end. + void OnUnrecoverableError(QuicErrorCode error, + const std::string& details) override; + void OnUnrecoverableError(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) override; + + // Called by the session when a (potentially duplicate) stream frame has been + // received for this stream. + virtual void OnStreamFrame(const QuicStreamFrame& frame); + + // Called by the session when the connection becomes writeable to allow the + // stream to write any pending data. + virtual void OnCanWrite(); + + // Called by the session when the endpoint receives a RST_STREAM from the + // peer. + virtual void OnStreamReset(const QuicRstStreamFrame& frame); + + // Called by the session when the endpoint receives or sends a connection + // close, and should immediately close the stream. + virtual void OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source); + + const QuicStreamPriority& priority() const; + + // Send PRIORITY_UPDATE frame if application protocol supports it. + virtual void MaybeSendPriorityUpdateFrame() {} + + // Sets |priority_| to priority. This should only be called before bytes are + // written to the server. For a server stream, this is called when a + // PRIORITY_UPDATE frame is received. This calls + // MaybeSendPriorityUpdateFrame(), which for a client stream might send a + // PRIORITY_UPDATE frame. + void SetPriority(const QuicStreamPriority& priority); + + // Returns true if this stream is still waiting for acks of sent data. + // This will return false if all data has been acked, or if the stream + // is no longer interested in data being acked (which happens when + // a stream is reset because of an error). + bool IsWaitingForAcks() const; + + // Number of bytes available to read. + QuicByteCount ReadableBytes() const; + + QuicRstStreamErrorCode stream_error() const { + return stream_error_.internal_code(); + } + QuicErrorCode connection_error() const { return connection_error_; } + + bool reading_stopped() const { + return sequencer_.ignore_read_data() || read_side_closed_; + } + bool write_side_closed() const { return write_side_closed_; } + bool read_side_closed() const { return read_side_closed_; } + + bool IsZombie() const { + return read_side_closed_ && write_side_closed_ && IsWaitingForAcks(); + } + + bool rst_received() const { return rst_received_; } + bool rst_sent() const { return rst_sent_; } + bool fin_received() const { return fin_received_; } + bool fin_sent() const { return fin_sent_; } + bool fin_outstanding() const { return fin_outstanding_; } + bool fin_lost() const { return fin_lost_; } + + uint64_t BufferedDataBytes() const; + + uint64_t stream_bytes_read() const { return stream_bytes_read_; } + uint64_t stream_bytes_written() const; + + size_t busy_counter() const { return busy_counter_; } + void set_busy_counter(size_t busy_counter) { busy_counter_ = busy_counter; } + + // Adjust the flow control window according to new offset in |frame|. + virtual void OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame); + + int num_frames_received() const; + int num_duplicate_frames_received() const; + + // Flow controller related methods. + bool IsFlowControlBlocked() const; + QuicStreamOffset highest_received_byte_offset() const; + void UpdateReceiveWindowSize(QuicStreamOffset size); + + // Called when endpoint receives a frame which could increase the highest + // offset. + // Returns true if the highest offset did increase. + bool MaybeIncreaseHighestReceivedOffset(QuicStreamOffset new_offset); + + // Set the flow controller's send window offset from session config. + // |was_zero_rtt_rejected| is true if this config is from a rejected IETF QUIC + // 0-RTT attempt. Closes the connection and returns false if |new_offset| is + // not valid. + bool MaybeConfigSendWindowOffset(QuicStreamOffset new_offset, + bool was_zero_rtt_rejected); + + // Returns true if the stream has received either a RST_STREAM or a FIN - + // either of which gives a definitive number of bytes which the peer has + // sent. If this is not true on deletion of the stream object, the session + // must keep track of the stream's byte offset until a definitive final value + // arrives. + bool HasReceivedFinalOffset() const { return fin_received_ || rst_received_; } + + // Returns true if the stream has queued data waiting to write. + bool HasBufferedData() const; + + // Returns the version of QUIC being used for this stream. + QuicTransportVersion transport_version() const; + + // Returns the crypto handshake protocol that was used on this stream's + // connection. + HandshakeProtocol handshake_protocol() const; + + // Sets the sequencer to consume all incoming data itself and not call + // OnDataAvailable(). + // When the FIN is received, the stream will be notified automatically (via + // OnFinRead()) (which may happen during the call of StopReading()). + // TODO(dworley): There should be machinery to send a RST_STREAM/NO_ERROR and + // stop sending stream-level flow-control updates when this end sends FIN. + virtual void StopReading(); + + // Sends as much of |data| to the connection on the application encryption + // level as the connection will consume, and then buffers any remaining data + // in the send buffer. If fin is true: if it is immediately passed on to the + // session, write_side_closed() becomes true, otherwise fin_buffered_ becomes + // true. + void WriteOrBufferData( + absl::string_view data, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener); + + // Sends |data| to connection with specified |level|. + void WriteOrBufferDataAtLevel( + absl::string_view data, bool fin, EncryptionLevel level, + quiche::QuicheReferenceCountedPointer + ack_listener); + + // Adds random padding after the fin is consumed for this stream. + void AddRandomPaddingAfterFin(); + + // Write |data_length| of data starts at |offset| from send buffer. + bool WriteStreamData(QuicStreamOffset offset, QuicByteCount data_length, + QuicDataWriter* writer); + + // Called when data [offset, offset + data_length) is acked. |fin_acked| + // indicates whether the fin is acked. Returns true and updates + // |newly_acked_length| if any new stream data (including fin) gets acked. + virtual bool OnStreamFrameAcked(QuicStreamOffset offset, + QuicByteCount data_length, bool fin_acked, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp, + QuicByteCount* newly_acked_length); + + // Called when data [offset, offset + data_length) was retransmitted. + // |fin_retransmitted| indicates whether fin was retransmitted. + virtual void OnStreamFrameRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length, + bool fin_retransmitted); + + // Called when data [offset, offset + data_length) is considered as lost. + // |fin_lost| indicates whether the fin is considered as lost. + virtual void OnStreamFrameLost(QuicStreamOffset offset, + QuicByteCount data_length, bool fin_lost); + + // Called to retransmit outstanding portion in data [offset, offset + + // data_length) and |fin| with Transmission |type|. + // Returns true if all data gets retransmitted. + virtual bool RetransmitStreamData(QuicStreamOffset offset, + QuicByteCount data_length, bool fin, + TransmissionType type); + + // Sets deadline of this stream to be now + |ttl|, returns true if the setting + // succeeds. + bool MaybeSetTtl(QuicTime::Delta ttl); + + // Commits data into the stream write buffer, and potentially sends it over + // the wire. This method has all-or-nothing semantics: if the write buffer is + // not full, all of the memslices in |span| are moved into it; otherwise, + // nothing happens. + QuicConsumedData WriteMemSlices(absl::Span span, + bool fin); + QuicConsumedData WriteMemSlice(quiche::QuicheMemSlice span, bool fin); + + // Returns true if any stream data is lost (including fin) and needs to be + // retransmitted. + virtual bool HasPendingRetransmission() const; + + // Returns true if any portion of data [offset, offset + data_length) is + // outstanding or fin is outstanding (if |fin| is true). Returns false + // otherwise. + bool IsStreamFrameOutstanding(QuicStreamOffset offset, + QuicByteCount data_length, bool fin) const; + + StreamType type() const { return type_; } + + // Handle received StopSending frame. Returns true if the processing finishes + // gracefully. + virtual bool OnStopSending(QuicResetStreamError error); + + // Returns true if the stream is static. + bool is_static() const { return is_static_; } + + bool was_draining() const { return was_draining_; } + + QuicTime creation_time() const { return creation_time_; } + + bool fin_buffered() const { return fin_buffered_; } + + // True if buffered data in send buffer is below buffered_data_threshold_. + bool CanWriteNewData() const; + + // Called immediately after the stream is created from a pending stream, + // indicating it can start processing data. + void OnStreamCreatedFromPendingStream(); + + void DisableConnectionFlowControlForThisStream() { + stream_contributes_to_connection_flow_control_ = false; + } + + protected: + // Called when data of [offset, offset + data_length] is buffered in send + // buffer. + virtual void OnDataBuffered( + QuicStreamOffset /*offset*/, QuicByteCount /*data_length*/, + const quiche::QuicheReferenceCountedPointer& + /*ack_listener*/) {} + + // Called just before the object is destroyed. + // The object should not be accessed after OnClose is called. + // Sends a RST_STREAM with code QUIC_RST_ACKNOWLEDGEMENT if neither a FIN nor + // a RST_STREAM has been sent. + virtual void OnClose(); + + // True if buffered data in send buffer is still below + // buffered_data_threshold_ even after writing |length| bytes. + bool CanWriteNewDataAfterData(QuicByteCount length) const; + + // Called when upper layer can write new data. + virtual void OnCanWriteNewData() {} + + // Called when |bytes_consumed| bytes has been consumed. + virtual void OnStreamDataConsumed(QuicByteCount bytes_consumed); + + // Called by the stream sequencer as bytes are consumed from the buffer. + // If the receive window has dropped below the threshold, then send a + // WINDOW_UPDATE frame. + void AddBytesConsumed(QuicByteCount bytes) override; + + // Writes pending retransmissions if any. + virtual void WritePendingRetransmission(); + + // This is called when stream tries to retransmit data after deadline_. Make + // this virtual so that subclasses can implement their own logics. + virtual void OnDeadlinePassed(); + + // Called to set fin_sent_. This is only used by Google QUIC while body is + // empty. + void SetFinSent(); + + // Send STOP_SENDING if it hasn't been sent yet. + void MaybeSendStopSending(QuicResetStreamError error); + + // Send RESET_STREAM if it hasn't been sent yet. + void MaybeSendRstStream(QuicResetStreamError error); + + // Convenience warppers for two methods above. + void MaybeSendRstStream(QuicRstStreamErrorCode error) { + MaybeSendRstStream(QuicResetStreamError::FromInternal(error)); + } + void MaybeSendStopSending(QuicRstStreamErrorCode error) { + MaybeSendStopSending(QuicResetStreamError::FromInternal(error)); + } + + // Close the write side of the socket. Further writes will fail. + // Can be called by the subclass or internally. + // Does not send a FIN. May cause the stream to be closed. + virtual void CloseWriteSide(); + + void set_rst_received(bool rst_received) { rst_received_ = rst_received; } + void set_stream_error(QuicResetStreamError error) { stream_error_ = error; } + + StreamDelegateInterface* stream_delegate() { return stream_delegate_; } + + const QuicSession* session() const { return session_; } + QuicSession* session() { return session_; } + + const QuicStreamSequencer* sequencer() const { return &sequencer_; } + QuicStreamSequencer* sequencer() { return &sequencer_; } + + const QuicIntervalSet& bytes_acked() const; + + const QuicStreamSendBuffer& send_buffer() const { return send_buffer_; } + + QuicStreamSendBuffer& send_buffer() { return send_buffer_; } + + // Called when the write side of the stream is closed, and all of the outgoing + // data has been acknowledged. This corresponds to the "Data Recvd" state of + // RFC 9000. + virtual void OnWriteSideInDataRecvdState() {} + + // Return the current flow control send window in bytes. + absl::optional GetSendWindow() const; + absl::optional GetReceiveWindow() const; + + private: + friend class test::QuicStreamPeer; + friend class QuicStreamUtils; + + QuicStream(QuicStreamId id, QuicSession* session, + QuicStreamSequencer sequencer, bool is_static, StreamType type, + uint64_t stream_bytes_read, bool fin_received, + absl::optional flow_controller, + QuicFlowController* connection_flow_controller); + + // Calls MaybeSendBlocked on the stream's flow controller and the connection + // level flow controller. If the stream is flow control blocked by the + // connection-level flow controller but not by the stream-level flow + // controller, marks this stream as connection-level write blocked. + void MaybeSendBlocked(); + + // Write buffered data (in send buffer) at |level|. + void WriteBufferedData(EncryptionLevel level); + + // Close the read side of the stream. May cause the stream to be closed. + void CloseReadSide(); + + // Called when bytes are sent to the peer. + void AddBytesSent(QuicByteCount bytes); + + // Returns true if deadline_ has passed. + bool HasDeadlinePassed() const; + + QuicStreamSequencer sequencer_; + QuicStreamId id_; + // Pointer to the owning QuicSession object. + // TODO(b/136274541): Remove session pointer from streams. + QuicSession* session_; + StreamDelegateInterface* stream_delegate_; + // The priority of the stream, once parsed. + QuicStreamPriority priority_; + // Bytes read refers to payload bytes only: they do not include framing, + // encryption overhead etc. + uint64_t stream_bytes_read_; + + // Stream error code received from a RstStreamFrame or error code sent by the + // visitor or sequencer in the RstStreamFrame. + QuicResetStreamError stream_error_; + // Connection error code due to which the stream was closed. |stream_error_| + // is set to |QUIC_STREAM_CONNECTION_ERROR| when this happens and consumers + // should check |connection_error_|. + QuicErrorCode connection_error_; + + // True if the read side is closed and further frames should be rejected. + bool read_side_closed_; + // True if the write side is closed, and further writes should fail. + bool write_side_closed_; + + // True if OnWriteSideInDataRecvdState() has already been called. + bool write_side_data_recvd_state_notified_; + + // True if the subclass has written a FIN with WriteOrBufferData, but it was + // buffered in queued_data_ rather than being sent to the session. + bool fin_buffered_; + // True if a FIN has been sent to the session. + bool fin_sent_; + // True if a FIN is waiting to be acked. + bool fin_outstanding_; + // True if a FIN is lost. + bool fin_lost_; + + // True if this stream has received (and the sequencer has accepted) a + // StreamFrame with the FIN set. + bool fin_received_; + + // True if an RST_STREAM has been sent to the session. + // In combination with fin_sent_, used to ensure that a FIN and/or a + // RST_STREAM is always sent to terminate the stream. + bool rst_sent_; + + // True if this stream has received a RST_STREAM frame. + bool rst_received_; + + // True if the stream has sent STOP_SENDING to the session. + bool stop_sending_sent_; + + absl::optional flow_controller_; + + // The connection level flow controller. Not owned. + QuicFlowController* connection_flow_controller_; + + // Special streams, such as the crypto and headers streams, do not respect + // connection level flow control limits (but are stream level flow control + // limited). + bool stream_contributes_to_connection_flow_control_; + + // A counter incremented when OnCanWrite() is called and no progress is made. + // For debugging only. + size_t busy_counter_; + + // Indicates whether paddings will be added after the fin is consumed for this + // stream. + bool add_random_padding_after_fin_; + + // Send buffer of this stream. Send buffer is cleaned up when data gets acked + // or discarded. + QuicStreamSendBuffer send_buffer_; + + // Latched value of quic_buffered_data_threshold. + const QuicByteCount buffered_data_threshold_; + + // If true, then this stream has precedence over other streams for write + // scheduling. + const bool is_static_; + + // If initialized, reset this stream at this deadline. + QuicTime deadline_; + + // True if this stream has entered draining state. + bool was_draining_; + + // Indicates whether this stream is bidirectional, read unidirectional or + // write unidirectional. + const StreamType type_; + + // Creation time of this stream, as reported by the QuicClock. + const QuicTime creation_time_; + + Perspective perspective_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_H_ diff --git a/quiche/quic/core/quic_stream_frame_data_producer.h b/quiche/quic/core/quic_stream_frame_data_producer.h new file mode 100644 index 000000000000..5dc12b7912d4 --- /dev/null +++ b/quiche/quic/core/quic_stream_frame_data_producer.h @@ -0,0 +1,38 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_FRAME_DATA_PRODUCER_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_FRAME_DATA_PRODUCER_H_ + +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +class QuicDataWriter; + +// Pure virtual class to retrieve stream data. +class QUIC_EXPORT_PRIVATE QuicStreamFrameDataProducer { + public: + virtual ~QuicStreamFrameDataProducer() {} + + // Let |writer| write |data_length| data with |offset| of stream |id|. The + // write fails when either stream is closed or corresponding data is failed to + // be retrieved. This method allows writing a single stream frame from data + // that spans multiple buffers. + virtual WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) = 0; + + // Writes the data for a CRYPTO frame to |writer| for a frame at encryption + // level |level| starting at offset |offset| for |data_length| bytes. Returns + // whether writing the data was successful. + virtual bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_FRAME_DATA_PRODUCER_H_ diff --git a/quiche/quic/core/quic_stream_id_manager.cc b/quiche/quic/core/quic_stream_id_manager.cc new file mode 100644 index 000000000000..1443211786c4 --- /dev/null +++ b/quiche/quic/core/quic_stream_id_manager.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "quiche/quic/core/quic_stream_id_manager.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +#define ENDPOINT \ + (perspective_ == Perspective::IS_SERVER ? " Server: " : " Client: ") + +QuicStreamIdManager::QuicStreamIdManager( + DelegateInterface* delegate, bool unidirectional, Perspective perspective, + ParsedQuicVersion version, QuicStreamCount max_allowed_outgoing_streams, + QuicStreamCount max_allowed_incoming_streams) + : delegate_(delegate), + unidirectional_(unidirectional), + perspective_(perspective), + version_(version), + outgoing_max_streams_(max_allowed_outgoing_streams), + next_outgoing_stream_id_(GetFirstOutgoingStreamId()), + outgoing_stream_count_(0), + incoming_actual_max_streams_(max_allowed_incoming_streams), + incoming_advertised_max_streams_(max_allowed_incoming_streams), + incoming_initial_max_open_streams_(max_allowed_incoming_streams), + incoming_stream_count_(0), + largest_peer_created_stream_id_( + QuicUtils::GetInvalidStreamId(version.transport_version)) {} + +QuicStreamIdManager::~QuicStreamIdManager() {} + +bool QuicStreamIdManager::OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& frame, std::string* error_details) { + QUICHE_DCHECK_EQ(frame.unidirectional, unidirectional_); + if (frame.stream_count > incoming_advertised_max_streams_) { + // Peer thinks it can send more streams that we've told it. + *error_details = absl::StrCat( + "StreamsBlockedFrame's stream count ", frame.stream_count, + " exceeds incoming max stream ", incoming_advertised_max_streams_); + return false; + } + QUICHE_DCHECK_LE(incoming_advertised_max_streams_, + incoming_actual_max_streams_); + if (incoming_advertised_max_streams_ == incoming_actual_max_streams_) { + // We have told peer about current max. + return true; + } + if (frame.stream_count < incoming_actual_max_streams_) { + // Peer thinks it's blocked on a stream count that is less than our current + // max. Inform the peer of the correct stream count. + SendMaxStreamsFrame(); + } + return true; +} + +bool QuicStreamIdManager::MaybeAllowNewOutgoingStreams( + QuicStreamCount max_open_streams) { + if (max_open_streams <= outgoing_max_streams_) { + // Only update the stream count if it would increase the limit. + return false; + } + + // This implementation only supports 32 bit Stream IDs, so limit max streams + // if it would exceed the max 32 bits can express. + outgoing_max_streams_ = + std::min(max_open_streams, QuicUtils::GetMaxStreamCount()); + + return true; +} + +void QuicStreamIdManager::SetMaxOpenIncomingStreams( + QuicStreamCount max_open_streams) { + QUIC_BUG_IF(quic_bug_12413_1, incoming_stream_count_ > 0) + << "non-zero incoming stream count " << incoming_stream_count_ + << " when setting max incoming stream to " << max_open_streams; + QUIC_DLOG_IF(WARNING, incoming_initial_max_open_streams_ != max_open_streams) + << absl::StrCat(unidirectional_ ? "unidirectional " : "bidirectional: ", + "incoming stream limit changed from ", + incoming_initial_max_open_streams_, " to ", + max_open_streams); + incoming_actual_max_streams_ = max_open_streams; + incoming_advertised_max_streams_ = max_open_streams; + incoming_initial_max_open_streams_ = max_open_streams; +} + +void QuicStreamIdManager::MaybeSendMaxStreamsFrame() { + int divisor = GetQuicFlag(quic_max_streams_window_divisor); + + if (divisor > 0) { + if ((incoming_advertised_max_streams_ - incoming_stream_count_) > + (incoming_initial_max_open_streams_ / divisor)) { + // window too large, no advertisement + return; + } + } + SendMaxStreamsFrame(); +} + +void QuicStreamIdManager::SendMaxStreamsFrame() { + QUIC_BUG_IF(quic_bug_12413_2, + incoming_advertised_max_streams_ >= incoming_actual_max_streams_); + incoming_advertised_max_streams_ = incoming_actual_max_streams_; + delegate_->SendMaxStreams(incoming_advertised_max_streams_, unidirectional_); +} + +void QuicStreamIdManager::OnStreamClosed(QuicStreamId stream_id) { + QUICHE_DCHECK_NE(QuicUtils::IsBidirectionalStreamId(stream_id, version_), + unidirectional_); + if (QuicUtils::IsOutgoingStreamId(version_, stream_id, perspective_)) { + // Nothing to do for outgoing streams. + return; + } + // If the stream is inbound, we can increase the actual stream limit and maybe + // advertise the new limit to the peer. + if (incoming_actual_max_streams_ == QuicUtils::GetMaxStreamCount()) { + // Reached the maximum stream id value that the implementation + // supports. Nothing can be done here. + return; + } + // One stream closed, and another one can be opened. + incoming_actual_max_streams_++; + MaybeSendMaxStreamsFrame(); +} + +QuicStreamId QuicStreamIdManager::GetNextOutgoingStreamId() { + QUIC_BUG_IF(quic_bug_12413_3, outgoing_stream_count_ >= outgoing_max_streams_) + << "Attempt to allocate a new outgoing stream that would exceed the " + "limit (" + << outgoing_max_streams_ << ")"; + QuicStreamId id = next_outgoing_stream_id_; + next_outgoing_stream_id_ += + QuicUtils::StreamIdDelta(version_.transport_version); + outgoing_stream_count_++; + return id; +} + +bool QuicStreamIdManager::CanOpenNextOutgoingStream() const { + QUICHE_DCHECK(VersionHasIetfQuicFrames(version_.transport_version)); + return outgoing_stream_count_ < outgoing_max_streams_; +} + +bool QuicStreamIdManager::MaybeIncreaseLargestPeerStreamId( + const QuicStreamId stream_id, std::string* error_details) { + // |stream_id| must be an incoming stream of the right directionality. + QUICHE_DCHECK_NE(QuicUtils::IsBidirectionalStreamId(stream_id, version_), + unidirectional_); + QUICHE_DCHECK_NE(QuicUtils::IsServerInitiatedStreamId( + version_.transport_version, stream_id), + perspective_ == Perspective::IS_SERVER); + if (available_streams_.erase(stream_id) == 1) { + // stream_id is available. + return true; + } + + if (largest_peer_created_stream_id_ != + QuicUtils::GetInvalidStreamId(version_.transport_version)) { + QUICHE_DCHECK_GT(stream_id, largest_peer_created_stream_id_); + } + + // Calculate increment of incoming_stream_count_ by creating stream_id. + const QuicStreamCount delta = + QuicUtils::StreamIdDelta(version_.transport_version); + const QuicStreamId least_new_stream_id = + largest_peer_created_stream_id_ == + QuicUtils::GetInvalidStreamId(version_.transport_version) + ? GetFirstIncomingStreamId() + : largest_peer_created_stream_id_ + delta; + const QuicStreamCount stream_count_increment = + (stream_id - least_new_stream_id) / delta + 1; + + if (incoming_stream_count_ + stream_count_increment > + incoming_advertised_max_streams_) { + QUIC_DLOG(INFO) << ENDPOINT + << "Failed to create a new incoming stream with id:" + << stream_id << ", reaching MAX_STREAMS limit: " + << incoming_advertised_max_streams_ << "."; + *error_details = absl::StrCat("Stream id ", stream_id, + " would exceed stream count limit ", + incoming_advertised_max_streams_); + return false; + } + + for (QuicStreamId id = least_new_stream_id; id < stream_id; id += delta) { + available_streams_.insert(id); + } + incoming_stream_count_ += stream_count_increment; + largest_peer_created_stream_id_ = stream_id; + return true; +} + +bool QuicStreamIdManager::IsAvailableStream(QuicStreamId id) const { + QUICHE_DCHECK_NE(QuicUtils::IsBidirectionalStreamId(id, version_), + unidirectional_); + if (QuicUtils::IsOutgoingStreamId(version_, id, perspective_)) { + // Stream IDs under next_ougoing_stream_id_ are either open or previously + // open but now closed. + return id >= next_outgoing_stream_id_; + } + // For peer created streams, we also need to consider available streams. + return largest_peer_created_stream_id_ == + QuicUtils::GetInvalidStreamId(version_.transport_version) || + id > largest_peer_created_stream_id_ || + available_streams_.contains(id); +} + +QuicStreamId QuicStreamIdManager::GetFirstOutgoingStreamId() const { + return (unidirectional_) ? QuicUtils::GetFirstUnidirectionalStreamId( + version_.transport_version, perspective_) + : QuicUtils::GetFirstBidirectionalStreamId( + version_.transport_version, perspective_); +} + +QuicStreamId QuicStreamIdManager::GetFirstIncomingStreamId() const { + return (unidirectional_) ? QuicUtils::GetFirstUnidirectionalStreamId( + version_.transport_version, + QuicUtils::InvertPerspective(perspective_)) + : QuicUtils::GetFirstBidirectionalStreamId( + version_.transport_version, + QuicUtils::InvertPerspective(perspective_)); +} + +QuicStreamCount QuicStreamIdManager::available_incoming_streams() const { + return incoming_advertised_max_streams_ - incoming_stream_count_; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_stream_id_manager.h b/quiche/quic/core/quic_stream_id_manager.h new file mode 100644 index 000000000000..eaad296e4ea9 --- /dev/null +++ b/quiche/quic/core/quic_stream_id_manager.h @@ -0,0 +1,184 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_ID_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_ID_MANAGER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace test { +class QuicSessionPeer; +class QuicStreamIdManagerPeer; +} // namespace test + +// This class manages the stream ids for IETF QUIC. +class QUIC_EXPORT_PRIVATE QuicStreamIdManager { + public: + class QUIC_EXPORT_PRIVATE DelegateInterface { + public: + virtual ~DelegateInterface() = default; + + // Send a MAX_STREAMS frame. + virtual void SendMaxStreams(QuicStreamCount stream_count, + bool unidirectional) = 0; + }; + + QuicStreamIdManager(DelegateInterface* delegate, bool unidirectional, + Perspective perspective, ParsedQuicVersion version, + QuicStreamCount max_allowed_outgoing_streams, + QuicStreamCount max_allowed_incoming_streams); + + ~QuicStreamIdManager(); + + // Generate a string suitable for sending to the log/etc to show current state + // of the stream ID manager. + std::string DebugString() const { + return absl::StrCat( + " { unidirectional_: ", unidirectional_, + ", perspective: ", perspective_, + ", outgoing_max_streams_: ", outgoing_max_streams_, + ", next_outgoing_stream_id_: ", next_outgoing_stream_id_, + ", outgoing_stream_count_: ", outgoing_stream_count_, + ", incoming_actual_max_streams_: ", incoming_actual_max_streams_, + ", incoming_advertised_max_streams_: ", + incoming_advertised_max_streams_, + ", incoming_stream_count_: ", incoming_stream_count_, + ", available_streams_.size(): ", available_streams_.size(), + ", largest_peer_created_stream_id_: ", largest_peer_created_stream_id_, + " }"); + } + + // Processes the STREAMS_BLOCKED frame. If error is encountered, populates + // |error_details| and returns false. + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame, + std::string* error_details); + + // Returns whether the next outgoing stream ID can be allocated or not. + bool CanOpenNextOutgoingStream() const; + + // Generate and send a MAX_STREAMS frame. + void SendMaxStreamsFrame(); + + // Invoked to deal with releasing a stream. Does nothing if the stream is + // outgoing. If the stream is incoming, the number of streams that the peer + // can open will be updated and a MAX_STREAMS frame, informing the peer of + // the additional streams, may be sent. + void OnStreamClosed(QuicStreamId stream_id); + + // Returns the next outgoing stream id. Applications must call + // CanOpenNextOutgoingStream() first. + QuicStreamId GetNextOutgoingStreamId(); + + void SetMaxOpenIncomingStreams(QuicStreamCount max_open_streams); + + // Called on |max_open_streams| outgoing streams can be created because of 1) + // config negotiated or 2) MAX_STREAMS received. Returns true if new + // streams can be created. + bool MaybeAllowNewOutgoingStreams(QuicStreamCount max_open_streams); + + // Checks if the incoming stream ID exceeds the MAX_STREAMS limit. If the + // limit is exceeded, populates |error_detials| and returns false. + bool MaybeIncreaseLargestPeerStreamId(const QuicStreamId stream_id, + std::string* error_details); + + // Returns true if |id| is still available. + bool IsAvailableStream(QuicStreamId id) const; + + QuicStreamCount incoming_initial_max_open_streams() const { + return incoming_initial_max_open_streams_; + } + + QuicStreamId next_outgoing_stream_id() const { + return next_outgoing_stream_id_; + } + + // Number of streams that the peer believes that it can still create. + QuicStreamCount available_incoming_streams() const; + + QuicStreamId largest_peer_created_stream_id() const { + return largest_peer_created_stream_id_; + } + + QuicStreamCount outgoing_max_streams() const { return outgoing_max_streams_; } + QuicStreamCount incoming_actual_max_streams() const { + return incoming_actual_max_streams_; + } + QuicStreamCount incoming_advertised_max_streams() const { + return incoming_advertised_max_streams_; + } + QuicStreamCount outgoing_stream_count() const { + return outgoing_stream_count_; + } + + private: + friend class test::QuicSessionPeer; + friend class test::QuicStreamIdManagerPeer; + + // Check whether the MAX_STREAMS window has opened up enough and, if so, + // generate and send a MAX_STREAMS frame. + void MaybeSendMaxStreamsFrame(); + + // Get what should be the first incoming/outgoing stream ID that + // this stream id manager will manage, taking into account directionality and + // client/server perspective. + QuicStreamId GetFirstOutgoingStreamId() const; + QuicStreamId GetFirstIncomingStreamId() const; + + // Back reference to the session containing this Stream ID Manager. + DelegateInterface* delegate_; + + // Whether this stream id manager is for unidrectional (true) or bidirectional + // (false) streams. + const bool unidirectional_; + + // Is this manager a client or a server. + const Perspective perspective_; + + // QUIC version used for this manager. + const ParsedQuicVersion version_; + + // The number of streams that this node can initiate. + // This limit is first set when config is negotiated, but may be updated upon + // receiving MAX_STREAMS frame. + QuicStreamCount outgoing_max_streams_; + + // The ID to use for the next outgoing stream. + QuicStreamId next_outgoing_stream_id_; + + // The number of outgoing streams that have ever been opened, including those + // that have been closed. This number must never be larger than + // outgoing_max_streams_. + QuicStreamCount outgoing_stream_count_; + + // FOR INCOMING STREAMS + + // The actual maximum number of streams that can be opened by the peer. + QuicStreamCount incoming_actual_max_streams_; + // Max incoming stream number that has been advertised to the peer and is <= + // incoming_actual_max_streams_. It is set to incoming_actual_max_streams_ + // when a MAX_STREAMS is sent. + QuicStreamCount incoming_advertised_max_streams_; + + // Initial maximum on the number of open streams allowed. + QuicStreamCount incoming_initial_max_open_streams_; + + // The number of streams that have been created, including open ones and + // closed ones. + QuicStreamCount incoming_stream_count_; + + // Set of stream ids that are less than the largest stream id that has been + // received, but are nonetheless available to be created. + absl::flat_hash_set available_streams_; + + QuicStreamId largest_peer_created_stream_id_; +}; +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_ID_MANAGER_H_ diff --git a/quiche/quic/core/quic_stream_id_manager_test.cc b/quiche/quic/core/quic_stream_id_manager_test.cc new file mode 100644 index 000000000000..5b131123ee2d --- /dev/null +++ b/quiche/quic/core/quic_stream_id_manager_test.cc @@ -0,0 +1,472 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "quiche/quic/core/quic_stream_id_manager.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_stream_id_manager_peer.h" + +using testing::_; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class MockDelegate : public QuicStreamIdManager::DelegateInterface { + public: + MOCK_METHOD(void, SendMaxStreams, + (QuicStreamCount stream_count, bool unidirectional), (override)); +}; + +struct TestParams { + TestParams(ParsedQuicVersion version, Perspective perspective, + bool is_unidirectional) + : version(version), + perspective(perspective), + is_unidirectional(is_unidirectional) {} + + ParsedQuicVersion version; + Perspective perspective; + bool is_unidirectional; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + ParsedQuicVersionToString(p.version), "_", + (p.perspective == Perspective::IS_CLIENT ? "Client" : "Server"), + (p.is_unidirectional ? "Unidirectional" : "Bidirectional")); +} + +std::vector GetTestParams() { + std::vector params; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (!version.HasIetfQuicFrames()) { + continue; + } + for (Perspective perspective : + {Perspective::IS_CLIENT, Perspective::IS_SERVER}) { + for (bool is_unidirectional : {true, false}) { + params.push_back(TestParams(version, perspective, is_unidirectional)); + } + } + } + return params; +} + +class QuicStreamIdManagerTest : public QuicTestWithParam { + protected: + QuicStreamIdManagerTest() + : stream_id_manager_(&delegate_, IsUnidirectional(), perspective(), + GetParam().version, 0, + kDefaultMaxStreamsPerConnection) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + } + + QuicTransportVersion transport_version() const { + return GetParam().version.transport_version; + } + + // Returns the stream ID for the Nth incoming stream (created by the peer) + // of the corresponding directionality of this manager. + QuicStreamId GetNthIncomingStreamId(int n) { + return QuicUtils::StreamIdDelta(transport_version()) * n + + (IsUnidirectional() + ? QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), + QuicUtils::InvertPerspective(perspective())) + : QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), + QuicUtils::InvertPerspective(perspective()))); + } + + bool IsUnidirectional() { return GetParam().is_unidirectional; } + Perspective perspective() { return GetParam().perspective; } + + StrictMock delegate_; + QuicStreamIdManager stream_id_manager_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicStreamIdManagerTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicStreamIdManagerTest, Initialization) { + EXPECT_EQ(0u, stream_id_manager_.outgoing_max_streams()); + + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + stream_id_manager_.incoming_actual_max_streams()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + stream_id_manager_.incoming_advertised_max_streams()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + stream_id_manager_.incoming_initial_max_open_streams()); +} + +// This test checks that the stream advertisement window is set to 1 +// if the number of stream ids is 1. This is a special case in the code. +TEST_P(QuicStreamIdManagerTest, CheckMaxStreamsWindowForSingleStream) { + stream_id_manager_.SetMaxOpenIncomingStreams(1); + EXPECT_EQ(1u, stream_id_manager_.incoming_initial_max_open_streams()); + EXPECT_EQ(1u, stream_id_manager_.incoming_actual_max_streams()); +} + +TEST_P(QuicStreamIdManagerTest, CheckMaxStreamsBadValuesOverMaxFailsOutgoing) { + QuicStreamCount implementation_max = QuicUtils::GetMaxStreamCount(); + // Ensure that the limit is less than the implementation maximum. + EXPECT_LT(stream_id_manager_.outgoing_max_streams(), implementation_max); + + EXPECT_TRUE( + stream_id_manager_.MaybeAllowNewOutgoingStreams(implementation_max + 1)); + // Should be pegged at the max. + EXPECT_EQ(implementation_max, stream_id_manager_.outgoing_max_streams()); +} + +// Check the case of the stream count in a STREAMS_BLOCKED frame is less than +// the count most recently advertised in a MAX_STREAMS frame. +TEST_P(QuicStreamIdManagerTest, ProcessStreamsBlockedOk) { + QuicStreamCount stream_count = + stream_id_manager_.incoming_initial_max_open_streams(); + QuicStreamsBlockedFrame frame(0, stream_count - 1, IsUnidirectional()); + // We have notified peer about current max. + EXPECT_CALL(delegate_, SendMaxStreams(stream_count, IsUnidirectional())) + .Times(0); + std::string error_details; + EXPECT_TRUE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); +} + +// Check the case of the stream count in a STREAMS_BLOCKED frame is equal to the +// count most recently advertised in a MAX_STREAMS frame. No MAX_STREAMS +// should be generated. +TEST_P(QuicStreamIdManagerTest, ProcessStreamsBlockedNoOp) { + QuicStreamCount stream_count = + stream_id_manager_.incoming_initial_max_open_streams(); + QuicStreamsBlockedFrame frame(0, stream_count, IsUnidirectional()); + EXPECT_CALL(delegate_, SendMaxStreams(_, _)).Times(0); +} + +// Check the case of the stream count in a STREAMS_BLOCKED frame is greater than +// the count most recently advertised in a MAX_STREAMS frame. Expect a +// connection close with an error. +TEST_P(QuicStreamIdManagerTest, ProcessStreamsBlockedTooBig) { + EXPECT_CALL(delegate_, SendMaxStreams(_, _)).Times(0); + QuicStreamCount stream_count = + stream_id_manager_.incoming_initial_max_open_streams() + 1; + QuicStreamsBlockedFrame frame(0, stream_count, IsUnidirectional()); + std::string error_details; + EXPECT_FALSE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); + EXPECT_EQ( + error_details, + "StreamsBlockedFrame's stream count 101 exceeds incoming max stream 100"); +} + +// Same basic tests as above, but calls +// QuicStreamIdManager::MaybeIncreaseLargestPeerStreamId directly, avoiding the +// call chain. The intent is that if there is a problem, the following tests +// will point to either the stream ID manager or the call chain. They also +// provide specific, small scale, tests of a public QuicStreamIdManager method. +// First test make sure that streams with ids below the limit are accepted. +TEST_P(QuicStreamIdManagerTest, IsIncomingStreamIdValidBelowLimit) { + QuicStreamId stream_id = GetNthIncomingStreamId( + stream_id_manager_.incoming_actual_max_streams() - 2); + EXPECT_TRUE( + stream_id_manager_.MaybeIncreaseLargestPeerStreamId(stream_id, nullptr)); +} + +// Accept a stream with an ID that equals the limit. +TEST_P(QuicStreamIdManagerTest, IsIncomingStreamIdValidAtLimit) { + QuicStreamId stream_id = GetNthIncomingStreamId( + stream_id_manager_.incoming_actual_max_streams() - 1); + EXPECT_TRUE( + stream_id_manager_.MaybeIncreaseLargestPeerStreamId(stream_id, nullptr)); +} + +// Close the connection if the id exceeds the limit. +TEST_P(QuicStreamIdManagerTest, IsIncomingStreamIdInValidAboveLimit) { + QuicStreamId stream_id = + GetNthIncomingStreamId(stream_id_manager_.incoming_actual_max_streams()); + std::string error_details; + EXPECT_FALSE(stream_id_manager_.MaybeIncreaseLargestPeerStreamId( + stream_id, &error_details)); + EXPECT_EQ(error_details, + absl::StrCat("Stream id ", stream_id, + " would exceed stream count limit 100")); +} + +TEST_P(QuicStreamIdManagerTest, OnStreamsBlockedFrame) { + // Get the current maximum allowed incoming stream count. + QuicStreamCount advertised_stream_count = + stream_id_manager_.incoming_advertised_max_streams(); + + QuicStreamsBlockedFrame frame; + + frame.unidirectional = IsUnidirectional(); + + // If the peer is saying it's blocked on the stream count that + // we've advertised, it's a noop since the peer has the correct information. + frame.stream_count = advertised_stream_count; + std::string error_details; + EXPECT_TRUE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); + + // If the peer is saying it's blocked on a stream count that is larger + // than what we've advertised, the connection should get closed. + frame.stream_count = advertised_stream_count + 1; + EXPECT_FALSE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); + EXPECT_EQ( + error_details, + "StreamsBlockedFrame's stream count 101 exceeds incoming max stream 100"); + + // If the peer is saying it's blocked on a count that is less than + // our actual count, we send a MAX_STREAMS frame and update + // the advertised value. + // First, need to bump up the actual max so there is room for the MAX + // STREAMS frame to send a larger ID. + QuicStreamCount actual_stream_count = + stream_id_manager_.incoming_actual_max_streams(); + + // Closing a stream will result in the ability to initiate one more + // stream + stream_id_manager_.OnStreamClosed( + QuicStreamIdManagerPeer::GetFirstIncomingStreamId(&stream_id_manager_)); + EXPECT_EQ(actual_stream_count + 1u, + stream_id_manager_.incoming_actual_max_streams()); + EXPECT_EQ(stream_id_manager_.incoming_actual_max_streams(), + stream_id_manager_.incoming_advertised_max_streams() + 1u); + + // Now simulate receiving a STREAMS_BLOCKED frame... + // Changing the actual maximum, above, forces a MAX_STREAMS frame to be + // sent, so the logic for that (SendMaxStreamsFrame(), etc) is tested. + + // The STREAMS_BLOCKED frame contains the previous advertised count, + // not the one that the peer would have received as a result of the + // MAX_STREAMS sent earler. + frame.stream_count = advertised_stream_count; + + EXPECT_CALL(delegate_, + SendMaxStreams(stream_id_manager_.incoming_actual_max_streams(), + IsUnidirectional())); + + EXPECT_TRUE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); + // Check that the saved frame is correct. + EXPECT_EQ(stream_id_manager_.incoming_actual_max_streams(), + stream_id_manager_.incoming_advertised_max_streams()); +} + +TEST_P(QuicStreamIdManagerTest, GetNextOutgoingStream) { + // Number of streams we can open and the first one we should get when + // opening... + size_t number_of_streams = kDefaultMaxStreamsPerConnection; + + EXPECT_TRUE( + stream_id_manager_.MaybeAllowNewOutgoingStreams(number_of_streams)); + + QuicStreamId stream_id = IsUnidirectional() + ? QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), perspective()) + : QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), perspective()); + + EXPECT_EQ(number_of_streams, stream_id_manager_.outgoing_max_streams()); + while (number_of_streams) { + EXPECT_TRUE(stream_id_manager_.CanOpenNextOutgoingStream()); + EXPECT_EQ(stream_id, stream_id_manager_.GetNextOutgoingStreamId()); + stream_id += QuicUtils::StreamIdDelta(transport_version()); + number_of_streams--; + } + + // If we try to check that the next outgoing stream id is available it should + // fail. + EXPECT_FALSE(stream_id_manager_.CanOpenNextOutgoingStream()); + + // If we try to get the next id (above the limit), it should cause a quic-bug. + EXPECT_QUIC_BUG( + stream_id_manager_.GetNextOutgoingStreamId(), + "Attempt to allocate a new outgoing stream that would exceed the limit"); +} + +TEST_P(QuicStreamIdManagerTest, MaybeIncreaseLargestPeerStreamId) { + QuicStreamId max_stream_id = GetNthIncomingStreamId( + stream_id_manager_.incoming_actual_max_streams() - 1); + EXPECT_TRUE(stream_id_manager_.MaybeIncreaseLargestPeerStreamId(max_stream_id, + nullptr)); + + QuicStreamId first_stream_id = GetNthIncomingStreamId(0); + EXPECT_TRUE(stream_id_manager_.MaybeIncreaseLargestPeerStreamId( + first_stream_id, nullptr)); + // A bad stream ID results in a closed connection. + std::string error_details; + EXPECT_FALSE(stream_id_manager_.MaybeIncreaseLargestPeerStreamId( + max_stream_id + QuicUtils::StreamIdDelta(transport_version()), + &error_details)); + EXPECT_EQ(error_details, + absl::StrCat( + "Stream id ", + max_stream_id + QuicUtils::StreamIdDelta(transport_version()), + " would exceed stream count limit 100")); +} + +TEST_P(QuicStreamIdManagerTest, MaxStreamsWindow) { + // Open and then close a number of streams to get close to the threshold of + // sending a MAX_STREAM_FRAME. + int stream_count = stream_id_manager_.incoming_initial_max_open_streams() / + GetQuicFlag(quic_max_streams_window_divisor) - + 1; + + // Should not get a control-frame transmission since the peer should have + // "plenty" of stream IDs to use. + EXPECT_CALL(delegate_, SendMaxStreams(_, _)).Times(0); + + // Get the first incoming stream ID to try and allocate. + QuicStreamId stream_id = GetNthIncomingStreamId(0); + size_t old_available_incoming_streams = + stream_id_manager_.available_incoming_streams(); + auto i = stream_count; + while (i) { + EXPECT_TRUE(stream_id_manager_.MaybeIncreaseLargestPeerStreamId(stream_id, + nullptr)); + + // This node should think that the peer believes it has one fewer + // stream it can create. + old_available_incoming_streams--; + EXPECT_EQ(old_available_incoming_streams, + stream_id_manager_.available_incoming_streams()); + + i--; + stream_id += QuicUtils::StreamIdDelta(transport_version()); + } + + // Now close them, still should get no MAX_STREAMS + stream_id = GetNthIncomingStreamId(0); + QuicStreamCount expected_actual_max = + stream_id_manager_.incoming_actual_max_streams(); + QuicStreamCount expected_advertised_max_streams = + stream_id_manager_.incoming_advertised_max_streams(); + while (stream_count) { + stream_id_manager_.OnStreamClosed(stream_id); + stream_count--; + stream_id += QuicUtils::StreamIdDelta(transport_version()); + expected_actual_max++; + EXPECT_EQ(expected_actual_max, + stream_id_manager_.incoming_actual_max_streams()); + // Advertised maximum should remain the same. + EXPECT_EQ(expected_advertised_max_streams, + stream_id_manager_.incoming_advertised_max_streams()); + } + + // This should not change. + EXPECT_EQ(old_available_incoming_streams, + stream_id_manager_.available_incoming_streams()); + + // Now whenever we close a stream we should get a MAX_STREAMS frame. + // Above code closed all the open streams, so we have to open/close + // EXPECT_CALL(delegate_, + // SendMaxStreams(stream_id_manager_.incoming_actual_max_streams(), + // IsUnidirectional())); + EXPECT_CALL(delegate_, SendMaxStreams(_, IsUnidirectional())); + EXPECT_TRUE( + stream_id_manager_.MaybeIncreaseLargestPeerStreamId(stream_id, nullptr)); + stream_id_manager_.OnStreamClosed(stream_id); +} + +TEST_P(QuicStreamIdManagerTest, StreamsBlockedEdgeConditions) { + QuicStreamsBlockedFrame frame; + frame.unidirectional = IsUnidirectional(); + + // Check that receipt of a STREAMS BLOCKED with stream-count = 0 does nothing + // when max_allowed_incoming_streams is 0. + EXPECT_CALL(delegate_, SendMaxStreams(_, _)).Times(0); + stream_id_manager_.SetMaxOpenIncomingStreams(0); + frame.stream_count = 0; + std::string error_details; + EXPECT_TRUE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); + + // Check that receipt of a STREAMS BLOCKED with stream-count = 0 invokes a + // MAX STREAMS, count = 123, when the MaxOpen... is set to 123. + EXPECT_CALL(delegate_, SendMaxStreams(123u, IsUnidirectional())); + QuicStreamIdManagerPeer::set_incoming_actual_max_streams(&stream_id_manager_, + 123); + frame.stream_count = 0; + EXPECT_TRUE(stream_id_manager_.OnStreamsBlockedFrame(frame, &error_details)); +} + +// Test that a MAX_STREAMS frame is generated when half the stream ids become +// available. This has a useful side effect of testing that when streams are +// closed, the number of available stream ids increases. +TEST_P(QuicStreamIdManagerTest, MaxStreamsSlidingWindow) { + QuicStreamCount first_advert = + stream_id_manager_.incoming_advertised_max_streams(); + + // Open/close enough streams to shrink the window without causing a MAX + // STREAMS to be generated. The loop + // will make that many stream IDs available, so the last CloseStream should + // cause a MAX STREAMS frame to be generated. + int i = + static_cast(stream_id_manager_.incoming_initial_max_open_streams() / + GetQuicFlag(quic_max_streams_window_divisor)); + QuicStreamId id = + QuicStreamIdManagerPeer::GetFirstIncomingStreamId(&stream_id_manager_); + EXPECT_CALL(delegate_, SendMaxStreams(first_advert + i, IsUnidirectional())); + while (i) { + EXPECT_TRUE( + stream_id_manager_.MaybeIncreaseLargestPeerStreamId(id, nullptr)); + stream_id_manager_.OnStreamClosed(id); + i--; + id += QuicUtils::StreamIdDelta(transport_version()); + } +} + +TEST_P(QuicStreamIdManagerTest, NewStreamDoesNotExceedLimit) { + EXPECT_TRUE(stream_id_manager_.MaybeAllowNewOutgoingStreams(100)); + + size_t stream_count = stream_id_manager_.outgoing_max_streams(); + EXPECT_NE(0u, stream_count); + + while (stream_count) { + EXPECT_TRUE(stream_id_manager_.CanOpenNextOutgoingStream()); + stream_id_manager_.GetNextOutgoingStreamId(); + stream_count--; + } + + EXPECT_EQ(stream_id_manager_.outgoing_stream_count(), + stream_id_manager_.outgoing_max_streams()); + // Create another, it should fail. + EXPECT_FALSE(stream_id_manager_.CanOpenNextOutgoingStream()); +} + +TEST_P(QuicStreamIdManagerTest, AvailableStreams) { + stream_id_manager_.MaybeIncreaseLargestPeerStreamId(GetNthIncomingStreamId(3), + nullptr); + + EXPECT_TRUE(stream_id_manager_.IsAvailableStream(GetNthIncomingStreamId(1))); + EXPECT_TRUE(stream_id_manager_.IsAvailableStream(GetNthIncomingStreamId(2))); + EXPECT_FALSE(stream_id_manager_.IsAvailableStream(GetNthIncomingStreamId(3))); + EXPECT_TRUE(stream_id_manager_.IsAvailableStream(GetNthIncomingStreamId(4))); +} + +// Tests that if MaybeIncreaseLargestPeerStreamId is given an extremely +// large stream ID (larger than the limit) it is rejected. +// This is a regression for Chromium bugs 909987 and 910040 +TEST_P(QuicStreamIdManagerTest, ExtremeMaybeIncreaseLargestPeerStreamId) { + QuicStreamId too_big_stream_id = GetNthIncomingStreamId( + stream_id_manager_.incoming_actual_max_streams() + 20); + + std::string error_details; + EXPECT_FALSE(stream_id_manager_.MaybeIncreaseLargestPeerStreamId( + too_big_stream_id, &error_details)); + EXPECT_EQ(error_details, + absl::StrCat("Stream id ", too_big_stream_id, + " would exceed stream count limit 100")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_stream_priority.cc b/quiche/quic/core/quic_stream_priority.cc new file mode 100644 index 000000000000..5199a438bdd8 --- /dev/null +++ b/quiche/quic/core/quic_stream_priority.cc @@ -0,0 +1,85 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_priority.h" + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/structured_headers.h" + +namespace quic { + +std::string SerializePriorityFieldValue(HttpStreamPriority priority) { + quiche::structured_headers::Dictionary dictionary; + + // TODO(b/266722347): Never send `urgency` if value equals default value. + if ((priority.urgency != HttpStreamPriority::kDefaultUrgency || + priority.incremental != HttpStreamPriority::kDefaultIncremental) && + priority.urgency >= HttpStreamPriority::kMinimumUrgency && + priority.urgency <= HttpStreamPriority::kMaximumUrgency) { + dictionary[HttpStreamPriority::kUrgencyKey] = + quiche::structured_headers::ParameterizedMember( + quiche::structured_headers::Item( + static_cast(priority.urgency)), + {}); + } + + if (priority.incremental != HttpStreamPriority::kDefaultIncremental) { + dictionary[HttpStreamPriority::kIncrementalKey] = + quiche::structured_headers::ParameterizedMember( + quiche::structured_headers::Item(priority.incremental), {}); + } + + absl::optional priority_field_value = + quiche::structured_headers::SerializeDictionary(dictionary); + if (!priority_field_value.has_value()) { + QUICHE_BUG(priority_field_value_serialization_failed); + return ""; + } + + return *priority_field_value; +} + +absl::optional ParsePriorityFieldValue( + absl::string_view priority_field_value) { + absl::optional parsed_dictionary = + quiche::structured_headers::ParseDictionary(priority_field_value); + if (!parsed_dictionary.has_value()) { + return absl::nullopt; + } + + uint8_t urgency = HttpStreamPriority::kDefaultUrgency; + bool incremental = HttpStreamPriority::kDefaultIncremental; + + for (const auto& [name, value] : *parsed_dictionary) { + if (value.member_is_inner_list) { + continue; + } + + const std::vector& member = + value.member; + if (member.size() != 1) { + // If `member_is_inner_list` is false above, + // then `member` should have exactly one element. + QUICHE_BUG(priority_field_value_parsing_internal_error); + continue; + } + + const quiche::structured_headers::Item item = member[0].item; + if (name == HttpStreamPriority::kUrgencyKey && item.is_integer()) { + int parsed_urgency = item.GetInteger(); + // Ignore out-of-range values. + if (parsed_urgency >= HttpStreamPriority::kMinimumUrgency && + parsed_urgency <= HttpStreamPriority::kMaximumUrgency) { + urgency = parsed_urgency; + } + } else if (name == HttpStreamPriority::kIncrementalKey && + item.is_boolean()) { + incremental = item.GetBoolean(); + } + } + + return HttpStreamPriority{urgency, incremental}; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_stream_priority.h b/quiche/quic/core/quic_stream_priority.h new file mode 100644 index 000000000000..7a4b9d09d6a4 --- /dev/null +++ b/quiche/quic/core/quic_stream_priority.h @@ -0,0 +1,142 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_PRIORITY_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_PRIORITY_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quic { + +// Represents HTTP priorities as defined by RFC 9218. +struct QUICHE_EXPORT HttpStreamPriority { + static constexpr int kMinimumUrgency = 0; + static constexpr int kMaximumUrgency = 7; + static constexpr int kDefaultUrgency = 3; + static constexpr bool kDefaultIncremental = false; + + // Parameter names for Priority Field Value. + static constexpr absl::string_view kUrgencyKey = "u"; + static constexpr absl::string_view kIncrementalKey = "i"; + + int urgency = kDefaultUrgency; + bool incremental = kDefaultIncremental; + + bool operator==(const HttpStreamPriority& other) const { + return std::tie(urgency, incremental) == + std::tie(other.urgency, other.incremental); + } + + bool operator!=(const HttpStreamPriority& other) const { + return !(*this == other); + } +}; + +// Represents WebTransport priorities as defined by +// . +struct QUICHE_EXPORT WebTransportStreamPriority { + enum class StreamType : uint8_t { + // WebTransport data streams. + kData = 0, + // Regular HTTP traffic. Since we're currently only supporting dedicated + // HTTP/3 transport, this means that all HTTP traffic is control traffic, + // and thus should always go first. + kHttp = 1, + // Streams that the QUIC stack declares as static. + kStatic = 2, + }; + + // Allows prioritizing control streams over the data streams. + StreamType stream_type = StreamType::kData; + // https://w3c.github.io/webtransport/#dom-webtransportsendstreamoptions-sendorder + int64_t send_order = 0; + + bool operator==(const WebTransportStreamPriority& other) const { + return stream_type == other.stream_type && send_order == other.send_order; + } + bool operator!=(const WebTransportStreamPriority& other) const { + return !(*this == other); + } +}; + +// A class that wraps different types of priorities that can be used for +// scheduling QUIC streams. +class QUICHE_EXPORT QuicStreamPriority { + public: + explicit QuicStreamPriority(HttpStreamPriority priority) : value_(priority) {} + explicit QuicStreamPriority(WebTransportStreamPriority priority) + : value_(priority) {} + + static QuicStreamPriority Default(QuicPriorityType type) { + switch (type) { + case QuicPriorityType::kHttp: + return QuicStreamPriority(HttpStreamPriority()); + case QuicPriorityType::kWebTransport: + return QuicStreamPriority(WebTransportStreamPriority()); + } + + QUICHE_BUG(unhandled_quic_priority_type_518918225) + << "Tried to create QuicStreamPriority for unknown QuicPriorityType " + << type; + return QuicStreamPriority(HttpStreamPriority()); + } + + QuicPriorityType type() const { return absl::visit(TypeExtractor(), value_); } + + HttpStreamPriority http() const { + if (absl::holds_alternative(value_)) { + return absl::get(value_); + } + QUICHE_BUG(invalid_priority_type_http) + << "Tried to access HTTP priority for a priority type" << type(); + return HttpStreamPriority(); + } + WebTransportStreamPriority web_transport() const { + if (absl::holds_alternative(value_)) { + return absl::get(value_); + } + QUICHE_BUG(invalid_priority_type_wt) + << "Tried to access WebTransport priority for a priority type" + << type(); + return WebTransportStreamPriority(); + } + + bool operator==(const QuicStreamPriority& other) const { + return value_ == other.value_; + } + + private: + struct TypeExtractor { + QuicPriorityType operator()(const HttpStreamPriority&) { + return QuicPriorityType::kHttp; + } + QuicPriorityType operator()(const WebTransportStreamPriority&) { + return QuicPriorityType::kWebTransport; + } + }; + + absl::variant value_; +}; + +// Serializes the Priority Field Value for a PRIORITY_UPDATE frame. +QUICHE_EXPORT std::string SerializePriorityFieldValue( + HttpStreamPriority priority); + +// Parses the Priority Field Value field of a PRIORITY_UPDATE frame. +// Returns nullopt on failure. +QUICHE_EXPORT absl::optional ParsePriorityFieldValue( + absl::string_view priority_field_value); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_PRIORITY_H_ diff --git a/quiche/quic/core/quic_stream_priority_test.cc b/quiche/quic/core/quic_stream_priority_test.cc new file mode 100644 index 000000000000..db5d1a9c884f --- /dev/null +++ b/quiche/quic/core/quic_stream_priority_test.cc @@ -0,0 +1,160 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_priority.h" + +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quic::test { + +TEST(HttpStreamPriority, DefaultConstructed) { + HttpStreamPriority priority; + + EXPECT_EQ(HttpStreamPriority::kDefaultUrgency, priority.urgency); + EXPECT_EQ(HttpStreamPriority::kDefaultIncremental, priority.incremental); +} + +TEST(HttpStreamPriority, Equals) { + EXPECT_EQ((HttpStreamPriority()), + (HttpStreamPriority{HttpStreamPriority::kDefaultUrgency, + HttpStreamPriority::kDefaultIncremental})); + EXPECT_EQ((HttpStreamPriority{5, true}), (HttpStreamPriority{5, true})); + EXPECT_EQ((HttpStreamPriority{2, false}), (HttpStreamPriority{2, false})); + EXPECT_EQ((HttpStreamPriority{11, true}), (HttpStreamPriority{11, true})); + + EXPECT_NE((HttpStreamPriority{1, true}), (HttpStreamPriority{3, true})); + EXPECT_NE((HttpStreamPriority{4, false}), (HttpStreamPriority{4, true})); + EXPECT_NE((HttpStreamPriority{6, true}), (HttpStreamPriority{2, false})); + EXPECT_NE((HttpStreamPriority{12, true}), (HttpStreamPriority{9, true})); + EXPECT_NE((HttpStreamPriority{2, false}), (HttpStreamPriority{8, false})); +} + +TEST(WebTransportStreamPriority, DefaultConstructed) { + WebTransportStreamPriority priority; + + EXPECT_EQ(priority.stream_type, + WebTransportStreamPriority::StreamType::kData); + EXPECT_EQ(priority.send_order, 0); +} + +TEST(WebTransportStreamPriority, Equals) { + EXPECT_EQ(WebTransportStreamPriority(), + (WebTransportStreamPriority{ + WebTransportStreamPriority::StreamType::kData, 0})); + EXPECT_NE(WebTransportStreamPriority(), + (WebTransportStreamPriority{ + WebTransportStreamPriority::StreamType::kData, 1})); + EXPECT_NE(WebTransportStreamPriority(), + (WebTransportStreamPriority{ + WebTransportStreamPriority::StreamType::kHttp, 0})); +} + +TEST(QuicStreamPriority, Default) { + EXPECT_EQ(QuicStreamPriority::Default(QuicPriorityType::kHttp).http(), + HttpStreamPriority()); + EXPECT_EQ(QuicStreamPriority::Default(QuicPriorityType::kWebTransport) + .web_transport(), + WebTransportStreamPriority()); +} + +TEST(QuicStreamPriority, Equals) { + EXPECT_EQ(QuicStreamPriority::Default(QuicPriorityType::kHttp), + QuicStreamPriority(HttpStreamPriority())); + EXPECT_EQ(QuicStreamPriority::Default(QuicPriorityType::kWebTransport), + QuicStreamPriority(WebTransportStreamPriority())); +} + +TEST(QuicStreamPriority, Type) { + EXPECT_EQ(QuicStreamPriority(HttpStreamPriority()).type(), + QuicPriorityType::kHttp); + EXPECT_EQ(QuicStreamPriority(WebTransportStreamPriority()).type(), + QuicPriorityType::kWebTransport); +} + +TEST(SerializePriorityFieldValueTest, SerializePriorityFieldValue) { + // Default value is omitted. + EXPECT_EQ("", SerializePriorityFieldValue( + {/* urgency = */ 3, /* incremental = */ false})); + EXPECT_EQ("u=5", SerializePriorityFieldValue( + {/* urgency = */ 5, /* incremental = */ false})); + // TODO(b/266722347): Never send `urgency` if value equals default value. + EXPECT_EQ("u=3, i", SerializePriorityFieldValue( + {/* urgency = */ 3, /* incremental = */ true})); + EXPECT_EQ("u=0, i", SerializePriorityFieldValue( + {/* urgency = */ 0, /* incremental = */ true})); + // Out-of-bound value is ignored. + EXPECT_EQ("i", SerializePriorityFieldValue( + {/* urgency = */ 9, /* incremental = */ true})); +} + +TEST(ParsePriorityFieldValueTest, ParsePriorityFieldValue) { + // Default values + absl::optional result = ParsePriorityFieldValue(""); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(3, result->urgency); + EXPECT_FALSE(result->incremental); + + result = ParsePriorityFieldValue("i=?1"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(3, result->urgency); + EXPECT_TRUE(result->incremental); + + result = ParsePriorityFieldValue("u=5"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(5, result->urgency); + EXPECT_FALSE(result->incremental); + + result = ParsePriorityFieldValue("u=5, i"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(5, result->urgency); + EXPECT_TRUE(result->incremental); + + result = ParsePriorityFieldValue("i, u=1"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(1, result->urgency); + EXPECT_TRUE(result->incremental); + + // Duplicate values are allowed. + result = ParsePriorityFieldValue("u=5, i=?1, i=?0, u=2"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(2, result->urgency); + EXPECT_FALSE(result->incremental); + + // Unknown parameters MUST be ignored. + result = ParsePriorityFieldValue("a=42, u=4, i=?0"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(4, result->urgency); + EXPECT_FALSE(result->incremental); + + // Out-of-range values MUST be ignored. + result = ParsePriorityFieldValue("u=-2, i"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(3, result->urgency); + EXPECT_TRUE(result->incremental); + + // Values of unexpected types MUST be ignored. + result = ParsePriorityFieldValue("u=4.2, i=\"foo\""); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(3, result->urgency); + EXPECT_FALSE(result->incremental); + + // Values of the right type but different names are ignored. + result = ParsePriorityFieldValue("a=4, b=?1"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(3, result->urgency); + EXPECT_FALSE(result->incremental); + + // Cannot be parsed as structured headers. + result = ParsePriorityFieldValue("000"); + EXPECT_FALSE(result.has_value()); + + // Inner list dictionary values are ignored. + result = ParsePriorityFieldValue("a=(1 2), u=1"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(1, result->urgency); + EXPECT_FALSE(result->incremental); +} + +} // namespace quic::test diff --git a/quiche/quic/core/quic_stream_send_buffer.cc b/quiche/quic/core/quic_stream_send_buffer.cc new file mode 100644 index 000000000000..a8657b56a365 --- /dev/null +++ b/quiche/quic/core/quic_stream_send_buffer.cc @@ -0,0 +1,293 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_send_buffer.h" + +#include + +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quic { + +namespace { + +struct CompareOffset { + bool operator()(const BufferedSlice& slice, QuicStreamOffset offset) const { + return slice.offset + slice.slice.length() < offset; + } +}; + +} // namespace + +BufferedSlice::BufferedSlice(quiche::QuicheMemSlice mem_slice, + QuicStreamOffset offset) + : slice(std::move(mem_slice)), offset(offset) {} + +BufferedSlice::BufferedSlice(BufferedSlice&& other) = default; + +BufferedSlice& BufferedSlice::operator=(BufferedSlice&& other) = default; + +BufferedSlice::~BufferedSlice() {} + +QuicInterval BufferedSlice::interval() const { + const std::size_t length = slice.length(); + return QuicInterval(offset, offset + length); +} + +bool StreamPendingRetransmission::operator==( + const StreamPendingRetransmission& other) const { + return offset == other.offset && length == other.length; +} + +QuicStreamSendBuffer::QuicStreamSendBuffer( + quiche::QuicheBufferAllocator* allocator) + : current_end_offset_(0), + stream_offset_(0), + allocator_(allocator), + stream_bytes_written_(0), + stream_bytes_outstanding_(0), + write_index_(-1) {} + +QuicStreamSendBuffer::~QuicStreamSendBuffer() {} + +void QuicStreamSendBuffer::SaveStreamData(absl::string_view data) { + QUICHE_DCHECK(!data.empty()); + + // Latch the maximum data slice size. + const QuicByteCount max_data_slice_size = + GetQuicFlag(quic_send_buffer_max_data_slice_size); + while (!data.empty()) { + auto slice_len = std::min( + data.length(), max_data_slice_size); + auto buffer = + quiche::QuicheBuffer::Copy(allocator_, data.substr(0, slice_len)); + SaveMemSlice(quiche::QuicheMemSlice(std::move(buffer))); + + data = data.substr(slice_len); + } +} + +void QuicStreamSendBuffer::SaveMemSlice(quiche::QuicheMemSlice slice) { + QUIC_DVLOG(2) << "Save slice offset " << stream_offset_ << " length " + << slice.length(); + if (slice.empty()) { + QUIC_BUG(quic_bug_10853_1) << "Try to save empty MemSlice to send buffer."; + return; + } + size_t length = slice.length(); + // Need to start the offsets at the right interval. + if (interval_deque_.Empty()) { + const QuicStreamOffset end = stream_offset_ + length; + current_end_offset_ = std::max(current_end_offset_, end); + } + BufferedSlice bs = BufferedSlice(std::move(slice), stream_offset_); + interval_deque_.PushBack(std::move(bs)); + stream_offset_ += length; +} + +QuicByteCount QuicStreamSendBuffer::SaveMemSliceSpan( + absl::Span span) { + QuicByteCount total = 0; + for (quiche::QuicheMemSlice& slice : span) { + if (slice.length() == 0) { + // Skip empty slices. + continue; + } + total += slice.length(); + SaveMemSlice(std::move(slice)); + } + return total; +} + +void QuicStreamSendBuffer::OnStreamDataConsumed(size_t bytes_consumed) { + stream_bytes_written_ += bytes_consumed; + stream_bytes_outstanding_ += bytes_consumed; +} + +bool QuicStreamSendBuffer::WriteStreamData(QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + QUIC_BUG_IF(quic_bug_12823_1, current_end_offset_ < offset) + << "Tried to write data out of sequence. last_offset_end:" + << current_end_offset_ << ", offset:" << offset; + // The iterator returned from |interval_deque_| will automatically advance + // the internal write index for the QuicIntervalDeque. The incrementing is + // done in operator++. + for (auto slice_it = interval_deque_.DataAt(offset); + slice_it != interval_deque_.DataEnd(); ++slice_it) { + if (data_length == 0 || offset < slice_it->offset) { + break; + } + + QuicByteCount slice_offset = offset - slice_it->offset; + QuicByteCount available_bytes_in_slice = + slice_it->slice.length() - slice_offset; + QuicByteCount copy_length = std::min(data_length, available_bytes_in_slice); + if (!writer->WriteBytes(slice_it->slice.data() + slice_offset, + copy_length)) { + QUIC_BUG(quic_bug_10853_2) << "Writer fails to write."; + return false; + } + offset += copy_length; + data_length -= copy_length; + const QuicStreamOffset new_end = + slice_it->offset + slice_it->slice.length(); + current_end_offset_ = std::max(current_end_offset_, new_end); + } + return data_length == 0; +} + +bool QuicStreamSendBuffer::OnStreamDataAcked( + QuicStreamOffset offset, QuicByteCount data_length, + QuicByteCount* newly_acked_length) { + *newly_acked_length = 0; + if (data_length == 0) { + return true; + } + if (bytes_acked_.Empty() || offset >= bytes_acked_.rbegin()->max() || + bytes_acked_.IsDisjoint( + QuicInterval(offset, offset + data_length))) { + // Optimization for the typical case, when all data is newly acked. + if (stream_bytes_outstanding_ < data_length) { + return false; + } + bytes_acked_.AddOptimizedForAppend(offset, offset + data_length); + *newly_acked_length = data_length; + stream_bytes_outstanding_ -= data_length; + pending_retransmissions_.Difference(offset, offset + data_length); + if (!FreeMemSlices(offset, offset + data_length)) { + return false; + } + CleanUpBufferedSlices(); + return true; + } + // Exit if no new data gets acked. + if (bytes_acked_.Contains(offset, offset + data_length)) { + return true; + } + // Execute the slow path if newly acked data fill in existing holes. + QuicIntervalSet newly_acked(offset, offset + data_length); + newly_acked.Difference(bytes_acked_); + for (const auto& interval : newly_acked) { + *newly_acked_length += (interval.max() - interval.min()); + } + if (stream_bytes_outstanding_ < *newly_acked_length) { + return false; + } + stream_bytes_outstanding_ -= *newly_acked_length; + bytes_acked_.Add(offset, offset + data_length); + pending_retransmissions_.Difference(offset, offset + data_length); + if (newly_acked.Empty()) { + return true; + } + if (!FreeMemSlices(newly_acked.begin()->min(), newly_acked.rbegin()->max())) { + return false; + } + CleanUpBufferedSlices(); + return true; +} + +void QuicStreamSendBuffer::OnStreamDataLost(QuicStreamOffset offset, + QuicByteCount data_length) { + if (data_length == 0) { + return; + } + QuicIntervalSet bytes_lost(offset, offset + data_length); + bytes_lost.Difference(bytes_acked_); + if (bytes_lost.Empty()) { + return; + } + for (const auto& lost : bytes_lost) { + pending_retransmissions_.Add(lost.min(), lost.max()); + } +} + +void QuicStreamSendBuffer::OnStreamDataRetransmitted( + QuicStreamOffset offset, QuicByteCount data_length) { + if (data_length == 0) { + return; + } + pending_retransmissions_.Difference(offset, offset + data_length); +} + +bool QuicStreamSendBuffer::HasPendingRetransmission() const { + return !pending_retransmissions_.Empty(); +} + +StreamPendingRetransmission QuicStreamSendBuffer::NextPendingRetransmission() + const { + if (HasPendingRetransmission()) { + const auto pending = pending_retransmissions_.begin(); + return {pending->min(), pending->max() - pending->min()}; + } + QUIC_BUG(quic_bug_10853_3) + << "NextPendingRetransmission is called unexpected with no " + "pending retransmissions."; + return {0, 0}; +} + +bool QuicStreamSendBuffer::FreeMemSlices(QuicStreamOffset start, + QuicStreamOffset end) { + auto it = interval_deque_.DataBegin(); + if (it == interval_deque_.DataEnd() || it->slice.empty()) { + QUIC_BUG(quic_bug_10853_4) + << "Trying to ack stream data [" << start << ", " << end << "), " + << (it == interval_deque_.DataEnd() + ? "and there is no outstanding data." + : "and the first slice is empty."); + return false; + } + if (!it->interval().Contains(start)) { + // Slow path that not the earliest outstanding data gets acked. + it = std::lower_bound(interval_deque_.DataBegin(), + interval_deque_.DataEnd(), start, CompareOffset()); + } + if (it == interval_deque_.DataEnd() || it->slice.empty()) { + QUIC_BUG(quic_bug_10853_5) + << "Offset " << start << " with iterator offset: " << it->offset + << (it == interval_deque_.DataEnd() ? " does not exist." + : " has already been acked."); + return false; + } + for (; it != interval_deque_.DataEnd(); ++it) { + if (it->offset >= end) { + break; + } + if (!it->slice.empty() && + bytes_acked_.Contains(it->offset, it->offset + it->slice.length())) { + it->slice.Reset(); + } + } + return true; +} + +void QuicStreamSendBuffer::CleanUpBufferedSlices() { + while (!interval_deque_.Empty() && + interval_deque_.DataBegin()->slice.empty()) { + QUIC_BUG_IF(quic_bug_12823_2, + interval_deque_.DataBegin()->offset > current_end_offset_) + << "Fail to pop front from interval_deque_. Front element contained " + "a slice whose data has not all be written. Front offset " + << interval_deque_.DataBegin()->offset << " length " + << interval_deque_.DataBegin()->slice.length(); + interval_deque_.PopFront(); + } +} + +bool QuicStreamSendBuffer::IsStreamDataOutstanding( + QuicStreamOffset offset, QuicByteCount data_length) const { + return data_length > 0 && + !bytes_acked_.Contains(offset, offset + data_length); +} + +size_t QuicStreamSendBuffer::size() const { return interval_deque_.Size(); } + +} // namespace quic diff --git a/quiche/quic/core/quic_stream_send_buffer.h b/quiche/quic/core/quic_stream_send_buffer.h new file mode 100644 index 000000000000..b74968a1926d --- /dev/null +++ b/quiche/quic/core/quic_stream_send_buffer.h @@ -0,0 +1,171 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_SEND_BUFFER_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_SEND_BUFFER_H_ + +#include "absl/types/span.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/quic_interval_deque.h" +#include "quiche/quic/core/quic_interval_set.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +namespace test { +class QuicStreamSendBufferPeer; +class QuicStreamPeer; +} // namespace test + +class QuicDataWriter; + +// BufferedSlice comprises information of a piece of stream data stored in +// contiguous memory space. Please note, BufferedSlice is constructed when +// stream data is saved in send buffer and is removed when stream data is fully +// acked. It is move-only. +struct QUIC_EXPORT_PRIVATE BufferedSlice { + BufferedSlice(quiche::QuicheMemSlice mem_slice, QuicStreamOffset offset); + BufferedSlice(BufferedSlice&& other); + BufferedSlice& operator=(BufferedSlice&& other); + + BufferedSlice(const BufferedSlice& other) = delete; + BufferedSlice& operator=(const BufferedSlice& other) = delete; + ~BufferedSlice(); + + // Return an interval representing the offset and length. + QuicInterval interval() const; + + // Stream data of this data slice. + quiche::QuicheMemSlice slice; + // Location of this data slice in the stream. + QuicStreamOffset offset; +}; + +struct QUIC_EXPORT_PRIVATE StreamPendingRetransmission { + constexpr StreamPendingRetransmission(QuicStreamOffset offset, + QuicByteCount length) + : offset(offset), length(length) {} + + // Starting offset of this pending retransmission. + QuicStreamOffset offset; + // Length of this pending retransmission. + QuicByteCount length; + + bool operator==(const StreamPendingRetransmission& other) const; +}; + +// QuicStreamSendBuffer contains a list of QuicStreamDataSlices. New data slices +// are added to the tail of the list. Data slices are removed from the head of +// the list when they get fully acked. Stream data can be retrieved and acked +// across slice boundaries. +class QUIC_EXPORT_PRIVATE QuicStreamSendBuffer { + public: + explicit QuicStreamSendBuffer(quiche::QuicheBufferAllocator* allocator); + QuicStreamSendBuffer(const QuicStreamSendBuffer& other) = delete; + QuicStreamSendBuffer(QuicStreamSendBuffer&& other) = delete; + ~QuicStreamSendBuffer(); + + // Save |data| to send buffer. + void SaveStreamData(absl::string_view data); + + // Save |slice| to send buffer. + void SaveMemSlice(quiche::QuicheMemSlice slice); + + // Save all slices in |span| to send buffer. Return total bytes saved. + QuicByteCount SaveMemSliceSpan(absl::Span span); + + // Called when |bytes_consumed| bytes has been consumed by the stream. + void OnStreamDataConsumed(size_t bytes_consumed); + + // Write |data_length| of data starts at |offset|. + bool WriteStreamData(QuicStreamOffset offset, QuicByteCount data_length, + QuicDataWriter* writer); + + // Called when data [offset, offset + data_length) is acked or removed as + // stream is canceled. Removes fully acked data slice from send buffer. Set + // |newly_acked_length|. Returns false if trying to ack unsent data. + bool OnStreamDataAcked(QuicStreamOffset offset, QuicByteCount data_length, + QuicByteCount* newly_acked_length); + + // Called when data [offset, offset + data_length) is considered as lost. + void OnStreamDataLost(QuicStreamOffset offset, QuicByteCount data_length); + + // Called when data [offset, offset + length) was retransmitted. + void OnStreamDataRetransmitted(QuicStreamOffset offset, + QuicByteCount data_length); + + // Returns true if there is pending retransmissions. + bool HasPendingRetransmission() const; + + // Returns next pending retransmissions. + StreamPendingRetransmission NextPendingRetransmission() const; + + // Returns true if data [offset, offset + data_length) is outstanding and + // waiting to be acked. Returns false otherwise. + bool IsStreamDataOutstanding(QuicStreamOffset offset, + QuicByteCount data_length) const; + + // Number of data slices in send buffer. + size_t size() const; + + QuicStreamOffset stream_offset() const { return stream_offset_; } + + uint64_t stream_bytes_written() const { return stream_bytes_written_; } + + uint64_t stream_bytes_outstanding() const { + return stream_bytes_outstanding_; + } + + const QuicIntervalSet& bytes_acked() const { + return bytes_acked_; + } + + const QuicIntervalSet& pending_retransmissions() const { + return pending_retransmissions_; + } + + private: + friend class test::QuicStreamSendBufferPeer; + friend class test::QuicStreamPeer; + + // Called when data within offset [start, end) gets acked. Frees fully + // acked buffered slices if any. Returns false if the corresponding data does + // not exist or has been acked. + bool FreeMemSlices(QuicStreamOffset start, QuicStreamOffset end); + + // Cleanup empty slices in order from buffered_slices_. + void CleanUpBufferedSlices(); + + // |current_end_offset_| stores the end offset of the current slice to ensure + // data isn't being written out of order when using the |interval_deque_|. + QuicStreamOffset current_end_offset_; + QuicIntervalDeque interval_deque_; + + // Offset of next inserted byte. + QuicStreamOffset stream_offset_; + + quiche::QuicheBufferAllocator* allocator_; + + // Bytes that have been consumed by the stream. + uint64_t stream_bytes_written_; + + // Bytes that have been consumed and are waiting to be acked. + uint64_t stream_bytes_outstanding_; + + // Offsets of data that has been acked. + QuicIntervalSet bytes_acked_; + + // Data considered as lost and needs to be retransmitted. + QuicIntervalSet pending_retransmissions_; + + // Index of slice which contains data waiting to be written for the first + // time. -1 if send buffer is empty or all data has been written. + int32_t write_index_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_SEND_BUFFER_H_ diff --git a/quiche/quic/core/quic_stream_send_buffer_test.cc b/quiche/quic/core/quic_stream_send_buffer_test.cc new file mode 100644 index 000000000000..f4d6b50b94a8 --- /dev/null +++ b/quiche/quic/core/quic_stream_send_buffer_test.cc @@ -0,0 +1,345 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_send_buffer.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_stream_send_buffer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { +namespace test { +namespace { + +class QuicStreamSendBufferTest : public QuicTest { + public: + QuicStreamSendBufferTest() : send_buffer_(&allocator_) { + EXPECT_EQ(0u, send_buffer_.size()); + EXPECT_EQ(0u, send_buffer_.stream_bytes_written()); + EXPECT_EQ(0u, send_buffer_.stream_bytes_outstanding()); + // The stream offset should be 0 since nothing is written. + EXPECT_EQ(0u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + + std::string data1 = absl::StrCat( + std::string(1536, 'a'), std::string(256, 'b'), std::string(256, 'c')); + + quiche::QuicheBuffer buffer1(&allocator_, 1024); + memset(buffer1.data(), 'c', buffer1.size()); + quiche::QuicheMemSlice slice1(std::move(buffer1)); + + quiche::QuicheBuffer buffer2(&allocator_, 768); + memset(buffer2.data(), 'd', buffer2.size()); + quiche::QuicheMemSlice slice2(std::move(buffer2)); + + // `data` will be split into two BufferedSlices. + SetQuicFlag(quic_send_buffer_max_data_slice_size, 1024); + send_buffer_.SaveStreamData(data1); + + send_buffer_.SaveMemSlice(std::move(slice1)); + EXPECT_TRUE(slice1.empty()); + send_buffer_.SaveMemSlice(std::move(slice2)); + EXPECT_TRUE(slice2.empty()); + + EXPECT_EQ(4u, send_buffer_.size()); + // At this point, `send_buffer_.interval_deque_` looks like this: + // BufferedSlice1: 'a' * 1024 + // BufferedSlice2: 'a' * 512 + 'b' * 256 + 'c' * 256 + // BufferedSlice3: 'c' * 1024 + // BufferedSlice4: 'd' * 768 + } + + void WriteAllData() { + // Write all data. + char buf[4000]; + QuicDataWriter writer(4000, buf, quiche::HOST_BYTE_ORDER); + send_buffer_.WriteStreamData(0, 3840u, &writer); + + send_buffer_.OnStreamDataConsumed(3840u); + EXPECT_EQ(3840u, send_buffer_.stream_bytes_written()); + EXPECT_EQ(3840u, send_buffer_.stream_bytes_outstanding()); + } + + quiche::SimpleBufferAllocator allocator_; + QuicStreamSendBuffer send_buffer_; +}; + +TEST_F(QuicStreamSendBufferTest, CopyDataToBuffer) { + char buf[4000]; + QuicDataWriter writer(4000, buf, quiche::HOST_BYTE_ORDER); + std::string copy1(1024, 'a'); + std::string copy2 = + std::string(512, 'a') + std::string(256, 'b') + std::string(256, 'c'); + std::string copy3(1024, 'c'); + std::string copy4(768, 'd'); + + ASSERT_TRUE(send_buffer_.WriteStreamData(0, 1024, &writer)); + EXPECT_EQ(copy1, absl::string_view(buf, 1024)); + ASSERT_TRUE(send_buffer_.WriteStreamData(1024, 1024, &writer)); + EXPECT_EQ(copy2, absl::string_view(buf + 1024, 1024)); + ASSERT_TRUE(send_buffer_.WriteStreamData(2048, 1024, &writer)); + EXPECT_EQ(copy3, absl::string_view(buf + 2048, 1024)); + ASSERT_TRUE(send_buffer_.WriteStreamData(3072, 768, &writer)); + EXPECT_EQ(copy4, absl::string_view(buf + 3072, 768)); + + // Test data piece across boundries. + QuicDataWriter writer2(4000, buf, quiche::HOST_BYTE_ORDER); + std::string copy5 = + std::string(536, 'a') + std::string(256, 'b') + std::string(232, 'c'); + ASSERT_TRUE(send_buffer_.WriteStreamData(1000, 1024, &writer2)); + EXPECT_EQ(copy5, absl::string_view(buf, 1024)); + ASSERT_TRUE(send_buffer_.WriteStreamData(2500, 1024, &writer2)); + std::string copy6 = std::string(572, 'c') + std::string(452, 'd'); + EXPECT_EQ(copy6, absl::string_view(buf + 1024, 1024)); + + // Invalid data copy. + QuicDataWriter writer3(4000, buf, quiche::HOST_BYTE_ORDER); + EXPECT_FALSE(send_buffer_.WriteStreamData(3000, 1024, &writer3)); + EXPECT_QUIC_BUG(send_buffer_.WriteStreamData(0, 4000, &writer3), + "Writer fails to write."); + + send_buffer_.OnStreamDataConsumed(3840); + EXPECT_EQ(3840u, send_buffer_.stream_bytes_written()); + EXPECT_EQ(3840u, send_buffer_.stream_bytes_outstanding()); +} + +// Regression test for b/143491027. +TEST_F(QuicStreamSendBufferTest, + WriteStreamDataContainsBothRetransmissionAndNewData) { + std::string copy1(1024, 'a'); + std::string copy2 = + std::string(512, 'a') + std::string(256, 'b') + std::string(256, 'c'); + std::string copy3 = std::string(1024, 'c') + std::string(100, 'd'); + char buf[6000]; + QuicDataWriter writer(6000, buf, quiche::HOST_BYTE_ORDER); + // Write more than one slice. + EXPECT_EQ(0, QuicStreamSendBufferPeer::write_index(&send_buffer_)); + ASSERT_TRUE(send_buffer_.WriteStreamData(0, 1024, &writer)); + EXPECT_EQ(copy1, absl::string_view(buf, 1024)); + EXPECT_EQ(1, QuicStreamSendBufferPeer::write_index(&send_buffer_)); + + // Retransmit the first frame and also send new data. + ASSERT_TRUE(send_buffer_.WriteStreamData(0, 2048, &writer)); + EXPECT_EQ(copy1 + copy2, absl::string_view(buf + 1024, 2048)); + + // Write new data. + EXPECT_EQ(2048u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + ASSERT_TRUE(send_buffer_.WriteStreamData(2048, 50, &writer)); + EXPECT_EQ(std::string(50, 'c'), absl::string_view(buf + 1024 + 2048, 50)); + EXPECT_EQ(3072u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + ASSERT_TRUE(send_buffer_.WriteStreamData(2048, 1124, &writer)); + EXPECT_EQ(copy3, absl::string_view(buf + 1024 + 2048 + 50, 1124)); + EXPECT_EQ(3840u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); +} + +TEST_F(QuicStreamSendBufferTest, RemoveStreamFrame) { + WriteAllData(); + + QuicByteCount newly_acked_length; + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(1024, 1024, &newly_acked_length)); + EXPECT_EQ(1024u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(2048, 1024, &newly_acked_length)); + EXPECT_EQ(1024u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(0, 1024, &newly_acked_length)); + EXPECT_EQ(1024u, newly_acked_length); + + // Send buffer is cleaned up in order. + EXPECT_EQ(1u, send_buffer_.size()); + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(3072, 768, &newly_acked_length)); + EXPECT_EQ(768u, newly_acked_length); + EXPECT_EQ(0u, send_buffer_.size()); +} + +TEST_F(QuicStreamSendBufferTest, RemoveStreamFrameAcrossBoundries) { + WriteAllData(); + + QuicByteCount newly_acked_length; + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(2024, 576, &newly_acked_length)); + EXPECT_EQ(576u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(0, 1000, &newly_acked_length)); + EXPECT_EQ(1000u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(1000, 1024, &newly_acked_length)); + EXPECT_EQ(1024u, newly_acked_length); + // Send buffer is cleaned up in order. + EXPECT_EQ(2u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(2600, 1024, &newly_acked_length)); + EXPECT_EQ(1024u, newly_acked_length); + EXPECT_EQ(1u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(3624, 216, &newly_acked_length)); + EXPECT_EQ(216u, newly_acked_length); + EXPECT_EQ(0u, send_buffer_.size()); +} + +TEST_F(QuicStreamSendBufferTest, AckStreamDataMultipleTimes) { + WriteAllData(); + QuicByteCount newly_acked_length; + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(100, 1500, &newly_acked_length)); + EXPECT_EQ(1500u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(2000, 500, &newly_acked_length)); + EXPECT_EQ(500u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(0, 2600, &newly_acked_length)); + EXPECT_EQ(600u, newly_acked_length); + // Send buffer is cleaned up in order. + EXPECT_EQ(2u, send_buffer_.size()); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(2200, 1640, &newly_acked_length)); + EXPECT_EQ(1240u, newly_acked_length); + EXPECT_EQ(0u, send_buffer_.size()); + + EXPECT_FALSE(send_buffer_.OnStreamDataAcked(4000, 100, &newly_acked_length)); +} + +TEST_F(QuicStreamSendBufferTest, AckStreamDataOutOfOrder) { + WriteAllData(); + QuicByteCount newly_acked_length; + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(500, 1000, &newly_acked_length)); + EXPECT_EQ(1000u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + EXPECT_EQ(3840u, QuicStreamSendBufferPeer::TotalLength(&send_buffer_)); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(1200, 1000, &newly_acked_length)); + EXPECT_EQ(700u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + // Slice 2 gets fully acked. + EXPECT_EQ(2816u, QuicStreamSendBufferPeer::TotalLength(&send_buffer_)); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(2000, 1840, &newly_acked_length)); + EXPECT_EQ(1640u, newly_acked_length); + EXPECT_EQ(4u, send_buffer_.size()); + // Slices 3 and 4 get fully acked. + EXPECT_EQ(1024u, QuicStreamSendBufferPeer::TotalLength(&send_buffer_)); + + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(0, 1000, &newly_acked_length)); + EXPECT_EQ(500u, newly_acked_length); + EXPECT_EQ(0u, send_buffer_.size()); + EXPECT_EQ(0u, QuicStreamSendBufferPeer::TotalLength(&send_buffer_)); +} + +TEST_F(QuicStreamSendBufferTest, PendingRetransmission) { + WriteAllData(); + EXPECT_TRUE(send_buffer_.IsStreamDataOutstanding(0, 3840)); + EXPECT_FALSE(send_buffer_.HasPendingRetransmission()); + // Lost data [0, 1200). + send_buffer_.OnStreamDataLost(0, 1200); + // Lost data [1500, 2000). + send_buffer_.OnStreamDataLost(1500, 500); + EXPECT_TRUE(send_buffer_.HasPendingRetransmission()); + + EXPECT_EQ(StreamPendingRetransmission(0, 1200), + send_buffer_.NextPendingRetransmission()); + // Retransmit data [0, 500). + send_buffer_.OnStreamDataRetransmitted(0, 500); + EXPECT_TRUE(send_buffer_.IsStreamDataOutstanding(0, 500)); + EXPECT_EQ(StreamPendingRetransmission(500, 700), + send_buffer_.NextPendingRetransmission()); + // Ack data [500, 1200). + QuicByteCount newly_acked_length = 0; + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(500, 700, &newly_acked_length)); + EXPECT_FALSE(send_buffer_.IsStreamDataOutstanding(500, 700)); + EXPECT_TRUE(send_buffer_.HasPendingRetransmission()); + EXPECT_EQ(StreamPendingRetransmission(1500, 500), + send_buffer_.NextPendingRetransmission()); + // Retransmit data [1500, 2000). + send_buffer_.OnStreamDataRetransmitted(1500, 500); + EXPECT_FALSE(send_buffer_.HasPendingRetransmission()); + + // Lost [200, 800). + send_buffer_.OnStreamDataLost(200, 600); + EXPECT_TRUE(send_buffer_.HasPendingRetransmission()); + // Verify [200, 500) is considered as lost, as [500, 800) has been acked. + EXPECT_EQ(StreamPendingRetransmission(200, 300), + send_buffer_.NextPendingRetransmission()); + + // Verify 0 length data is not outstanding. + EXPECT_FALSE(send_buffer_.IsStreamDataOutstanding(100, 0)); + // Verify partially acked data is outstanding. + EXPECT_TRUE(send_buffer_.IsStreamDataOutstanding(400, 800)); +} + +TEST_F(QuicStreamSendBufferTest, EndOffset) { + char buf[4000]; + QuicDataWriter writer(4000, buf, quiche::HOST_BYTE_ORDER); + + EXPECT_EQ(1024u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + ASSERT_TRUE(send_buffer_.WriteStreamData(0, 1024, &writer)); + // Last offset we've seen is 1024 + EXPECT_EQ(1024u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + + ASSERT_TRUE(send_buffer_.WriteStreamData(1024, 512, &writer)); + // Last offset is now 2048 as that's the end of the next slice. + EXPECT_EQ(2048u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + send_buffer_.OnStreamDataConsumed(1024); + + // If data in 1st slice gets ACK'ed, it shouldn't change the indexed slice + QuicByteCount newly_acked_length; + EXPECT_TRUE(send_buffer_.OnStreamDataAcked(0, 1024, &newly_acked_length)); + // Last offset is still 2048. + EXPECT_EQ(2048u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + + ASSERT_TRUE( + send_buffer_.WriteStreamData(1024 + 512, 3840 - 1024 - 512, &writer)); + + // Last offset is end offset of last slice. + EXPECT_EQ(3840u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); + quiche::QuicheBuffer buffer(&allocator_, 60); + memset(buffer.data(), 'e', buffer.size()); + quiche::QuicheMemSlice slice(std::move(buffer)); + send_buffer_.SaveMemSlice(std::move(slice)); + + EXPECT_EQ(3840u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); +} + +TEST_F(QuicStreamSendBufferTest, SaveMemSliceSpan) { + quiche::SimpleBufferAllocator allocator; + QuicStreamSendBuffer send_buffer(&allocator); + + std::string data(1024, 'a'); + std::vector buffers; + for (size_t i = 0; i < 10; ++i) { + buffers.push_back(MemSliceFromString(data)); + } + + EXPECT_EQ(10 * 1024u, send_buffer.SaveMemSliceSpan(absl::MakeSpan(buffers))); + EXPECT_EQ(10u, send_buffer.size()); +} + +TEST_F(QuicStreamSendBufferTest, SaveEmptyMemSliceSpan) { + quiche::SimpleBufferAllocator allocator; + QuicStreamSendBuffer send_buffer(&allocator); + + std::string data(1024, 'a'); + std::vector buffers; + for (size_t i = 0; i < 10; ++i) { + buffers.push_back(MemSliceFromString(data)); + } + + EXPECT_EQ(10 * 1024u, send_buffer.SaveMemSliceSpan(absl::MakeSpan(buffers))); + // Verify the empty slice does not get saved. + EXPECT_EQ(10u, send_buffer.size()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_stream_sequencer.cc b/quiche/quic/core/quic_stream_sequencer.cc new file mode 100644 index 000000000000..540c8c7ac4d6 --- /dev/null +++ b/quiche/quic/core/quic_stream_sequencer.cc @@ -0,0 +1,315 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_sequencer.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_stream_sequencer_buffer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" + +namespace quic { + +QuicStreamSequencer::QuicStreamSequencer(StreamInterface* quic_stream) + : stream_(quic_stream), + buffered_frames_(kStreamReceiveWindowLimit), + highest_offset_(0), + close_offset_(std::numeric_limits::max()), + blocked_(false), + num_frames_received_(0), + num_duplicate_frames_received_(0), + ignore_read_data_(false), + level_triggered_(false) {} + +QuicStreamSequencer::~QuicStreamSequencer() { + if (stream_ == nullptr) { + QUIC_BUG(quic_bug_10858_1) << "Double free'ing QuicStreamSequencer at " + << this << ". " << QuicStackTrace(); + } + stream_ = nullptr; +} + +void QuicStreamSequencer::OnStreamFrame(const QuicStreamFrame& frame) { + QUICHE_DCHECK_LE(frame.offset + frame.data_length, close_offset_); + ++num_frames_received_; + const QuicStreamOffset byte_offset = frame.offset; + const size_t data_len = frame.data_length; + + if (frame.fin && + (!CloseStreamAtOffset(frame.offset + data_len) || data_len == 0)) { + return; + } + if (stream_->version().HasIetfQuicFrames() && data_len == 0) { + QUICHE_DCHECK(!frame.fin); + // Ignore empty frame with no fin. + return; + } + OnFrameData(byte_offset, data_len, frame.data_buffer); +} + +void QuicStreamSequencer::OnCryptoFrame(const QuicCryptoFrame& frame) { + ++num_frames_received_; + if (frame.data_length == 0) { + // Ignore empty crypto frame. + return; + } + OnFrameData(frame.offset, frame.data_length, frame.data_buffer); +} + +void QuicStreamSequencer::OnFrameData(QuicStreamOffset byte_offset, + size_t data_len, + const char* data_buffer) { + highest_offset_ = std::max(highest_offset_, byte_offset + data_len); + const size_t previous_readable_bytes = buffered_frames_.ReadableBytes(); + size_t bytes_written; + std::string error_details; + QuicErrorCode result = buffered_frames_.OnStreamData( + byte_offset, absl::string_view(data_buffer, data_len), &bytes_written, + &error_details); + if (result != QUIC_NO_ERROR) { + std::string details = + absl::StrCat("Stream ", stream_->id(), ": ", + QuicErrorCodeToString(result), ": ", error_details); + QUIC_LOG_FIRST_N(WARNING, 50) << QuicErrorCodeToString(result); + QUIC_LOG_FIRST_N(WARNING, 50) << details; + stream_->OnUnrecoverableError(result, details); + return; + } + + if (bytes_written == 0) { + ++num_duplicate_frames_received_; + // Silently ignore duplicates. + return; + } + + if (blocked_) { + return; + } + + if (level_triggered_) { + if (buffered_frames_.ReadableBytes() > previous_readable_bytes) { + // Readable bytes has changed, let stream decide if to inform application + // or not. + if (ignore_read_data_) { + FlushBufferedFrames(); + } else { + stream_->OnDataAvailable(); + } + } + return; + } + const bool stream_unblocked = + previous_readable_bytes == 0 && buffered_frames_.ReadableBytes() > 0; + if (stream_unblocked) { + if (ignore_read_data_) { + FlushBufferedFrames(); + } else { + stream_->OnDataAvailable(); + } + } +} + +bool QuicStreamSequencer::CloseStreamAtOffset(QuicStreamOffset offset) { + const QuicStreamOffset kMaxOffset = + std::numeric_limits::max(); + + // If there is a scheduled close, the new offset should match it. + if (close_offset_ != kMaxOffset && offset != close_offset_) { + stream_->OnUnrecoverableError( + QUIC_STREAM_SEQUENCER_INVALID_STATE, + absl::StrCat( + "Stream ", stream_->id(), " received new final offset: ", offset, + ", which is different from close offset: ", close_offset_)); + return false; + } + + // The final offset should be no less than the highest offset that is + // received. + if (offset < highest_offset_) { + stream_->OnUnrecoverableError( + QUIC_STREAM_SEQUENCER_INVALID_STATE, + absl::StrCat( + "Stream ", stream_->id(), " received fin with offset: ", offset, + ", which reduces current highest offset: ", highest_offset_)); + return false; + } + + close_offset_ = offset; + + MaybeCloseStream(); + return true; +} + +void QuicStreamSequencer::MaybeCloseStream() { + if (blocked_ || !IsClosed()) { + return; + } + + QUIC_DVLOG(1) << "Passing up termination, as we've processed " + << buffered_frames_.BytesConsumed() << " of " << close_offset_ + << " bytes."; + // This will cause the stream to consume the FIN. + // Technically it's an error if |num_bytes_consumed| isn't exactly + // equal to |close_offset|, but error handling seems silly at this point. + if (ignore_read_data_) { + // The sequencer is discarding stream data and must notify the stream on + // receipt of a FIN because the consumer won't. + stream_->OnFinRead(); + } else { + stream_->OnDataAvailable(); + } + buffered_frames_.Clear(); +} + +int QuicStreamSequencer::GetReadableRegions(iovec* iov, size_t iov_len) const { + QUICHE_DCHECK(!blocked_); + return buffered_frames_.GetReadableRegions(iov, iov_len); +} + +bool QuicStreamSequencer::GetReadableRegion(iovec* iov) const { + QUICHE_DCHECK(!blocked_); + return buffered_frames_.GetReadableRegion(iov); +} + +bool QuicStreamSequencer::PeekRegion(QuicStreamOffset offset, + iovec* iov) const { + QUICHE_DCHECK(!blocked_); + return buffered_frames_.PeekRegion(offset, iov); +} + +void QuicStreamSequencer::Read(std::string* buffer) { + QUICHE_DCHECK(!blocked_); + buffer->resize(buffer->size() + ReadableBytes()); + iovec iov; + iov.iov_len = ReadableBytes(); + iov.iov_base = &(*buffer)[buffer->size() - iov.iov_len]; + Readv(&iov, 1); +} + +size_t QuicStreamSequencer::Readv(const struct iovec* iov, size_t iov_len) { + QUICHE_DCHECK(!blocked_); + std::string error_details; + size_t bytes_read; + QuicErrorCode read_error = + buffered_frames_.Readv(iov, iov_len, &bytes_read, &error_details); + if (read_error != QUIC_NO_ERROR) { + std::string details = + absl::StrCat("Stream ", stream_->id(), ": ", error_details); + stream_->OnUnrecoverableError(read_error, details); + return bytes_read; + } + + stream_->AddBytesConsumed(bytes_read); + return bytes_read; +} + +bool QuicStreamSequencer::HasBytesToRead() const { + return buffered_frames_.HasBytesToRead(); +} + +size_t QuicStreamSequencer::ReadableBytes() const { + return buffered_frames_.ReadableBytes(); +} + +bool QuicStreamSequencer::IsClosed() const { + return buffered_frames_.BytesConsumed() >= close_offset_; +} + +void QuicStreamSequencer::MarkConsumed(size_t num_bytes_consumed) { + QUICHE_DCHECK(!blocked_); + bool result = buffered_frames_.MarkConsumed(num_bytes_consumed); + if (!result) { + QUIC_BUG(quic_bug_10858_2) + << "Invalid argument to MarkConsumed." + << " expect to consume: " << num_bytes_consumed + << ", but not enough bytes available. " << DebugString(); + stream_->ResetWithError( + QuicResetStreamError::FromInternal(QUIC_ERROR_PROCESSING_STREAM)); + return; + } + stream_->AddBytesConsumed(num_bytes_consumed); +} + +void QuicStreamSequencer::SetBlockedUntilFlush() { blocked_ = true; } + +void QuicStreamSequencer::SetUnblocked() { + blocked_ = false; + if (IsClosed() || HasBytesToRead()) { + stream_->OnDataAvailable(); + } +} + +void QuicStreamSequencer::StopReading() { + if (ignore_read_data_) { + return; + } + ignore_read_data_ = true; + FlushBufferedFrames(); +} + +void QuicStreamSequencer::ReleaseBuffer() { + buffered_frames_.ReleaseWholeBuffer(); +} + +void QuicStreamSequencer::ReleaseBufferIfEmpty() { + if (buffered_frames_.Empty()) { + buffered_frames_.ReleaseWholeBuffer(); + } +} + +void QuicStreamSequencer::FlushBufferedFrames() { + QUICHE_DCHECK(ignore_read_data_); + size_t bytes_flushed = buffered_frames_.FlushBufferedFrames(); + QUIC_DVLOG(1) << "Flushing buffered data at offset " + << buffered_frames_.BytesConsumed() << " length " + << bytes_flushed << " for stream " << stream_->id(); + stream_->AddBytesConsumed(bytes_flushed); + MaybeCloseStream(); +} + +size_t QuicStreamSequencer::NumBytesBuffered() const { + return buffered_frames_.BytesBuffered(); +} + +QuicStreamOffset QuicStreamSequencer::NumBytesConsumed() const { + return buffered_frames_.BytesConsumed(); +} + +bool QuicStreamSequencer::IsAllDataAvailable() const { + QUICHE_DCHECK_LE(NumBytesConsumed() + NumBytesBuffered(), close_offset_); + return NumBytesConsumed() + NumBytesBuffered() >= close_offset_; +} + +std::string QuicStreamSequencer::DebugString() const { + // clang-format off + return absl::StrCat( + "QuicStreamSequencer: bytes buffered: ", NumBytesBuffered(), + "\n bytes consumed: ", NumBytesConsumed(), + "\n first missing byte: ", buffered_frames_.FirstMissingByte(), + "\n next expected byte: ", buffered_frames_.NextExpectedByte(), + "\n received frames: ", buffered_frames_.ReceivedFramesDebugString(), + "\n has bytes to read: ", HasBytesToRead() ? "true" : "false", + "\n frames received: ", num_frames_received(), + "\n close offset bytes: ", close_offset_, + "\n is closed: ", IsClosed() ? "true" : "false"); + // clang-format on +} + +} // namespace quic diff --git a/quiche/quic/core/quic_stream_sequencer.h b/quiche/quic/core/quic_stream_sequencer.h new file mode 100644 index 000000000000..f0e3ab394ab4 --- /dev/null +++ b/quiche/quic/core/quic_stream_sequencer.h @@ -0,0 +1,220 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_SEQUENCER_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_SEQUENCER_H_ + +#include +#include +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_sequencer_buffer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +namespace test { +class QuicStreamSequencerPeer; +} // namespace test + +// Buffers frames until we have something which can be passed +// up to the next layer. +class QUIC_EXPORT_PRIVATE QuicStreamSequencer final { + public: + // Interface that thie Sequencer uses to communicate with the Stream. + class QUIC_EXPORT_PRIVATE StreamInterface { + public: + virtual ~StreamInterface() = default; + + // Called when new data is available to be read from the sequencer. + virtual void OnDataAvailable() = 0; + // Called when the end of the stream has been read. + virtual void OnFinRead() = 0; + // Called when bytes have been consumed from the sequencer. + virtual void AddBytesConsumed(QuicByteCount bytes) = 0; + // Called when an error has occurred which should result in the stream + // being reset. + virtual void ResetWithError(QuicResetStreamError error) = 0; + // Called when an error has occurred which should result in the connection + // being closed. + virtual void OnUnrecoverableError(QuicErrorCode error, + const std::string& details) = 0; + // Called when an error has occurred which should result in the connection + // being closed, specifying the wire error code |ietf_error| explicitly. + virtual void OnUnrecoverableError(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) = 0; + // Returns the stream id of this stream. + virtual QuicStreamId id() const = 0; + + // Returns the QUIC version being used by this stream. + virtual ParsedQuicVersion version() const = 0; + }; + + explicit QuicStreamSequencer(StreamInterface* quic_stream); + QuicStreamSequencer(const QuicStreamSequencer&) = delete; + QuicStreamSequencer(QuicStreamSequencer&&) = default; + QuicStreamSequencer& operator=(const QuicStreamSequencer&) = delete; + QuicStreamSequencer& operator=(QuicStreamSequencer&&) = default; + ~QuicStreamSequencer(); + + // If the frame is the next one we need in order to process in-order data, + // ProcessData will be immediately called on the stream until all buffered + // data is processed or the stream fails to consume data. Any unconsumed + // data will be buffered. If the frame is not the next in line, it will be + // buffered. + void OnStreamFrame(const QuicStreamFrame& frame); + + // If the frame is the next one we need in order to process in-order data, + // ProcessData will be immediately called on the crypto stream until all + // buffered data is processed or the crypto stream fails to consume data. Any + // unconsumed data will be buffered. If the frame is not the next in line, it + // will be buffered. + void OnCryptoFrame(const QuicCryptoFrame& frame); + + // Once data is buffered, it's up to the stream to read it when the stream + // can handle more data. The following three functions make that possible. + + // Fills in up to iov_len iovecs with the next readable regions. Returns the + // number of iovs used. Non-destructive of the underlying data. + int GetReadableRegions(iovec* iov, size_t iov_len) const; + + // Fills in one iovec with the next readable region. Returns false if there + // is no readable region available. + bool GetReadableRegion(iovec* iov) const; + + // Fills in one iovec with the region starting at |offset| and returns true. + // Returns false if no readable region is available, either because data has + // not been received yet or has already been consumed. + bool PeekRegion(QuicStreamOffset offset, iovec* iov) const; + + // Copies the data into the iov_len buffers provided. Returns the number of + // bytes read. Any buffered data no longer in use will be released. + // TODO(rch): remove this method and instead implement it as a helper method + // based on GetReadableRegions and MarkConsumed. + size_t Readv(const struct iovec* iov, size_t iov_len); + + // Consumes |num_bytes| data. Used in conjunction with |GetReadableRegions| + // to do zero-copy reads. + void MarkConsumed(size_t num_bytes); + + // Appends all of the readable data to |buffer| and marks all of the appended + // data as consumed. + void Read(std::string* buffer); + + // Returns true if the sequncer has bytes available for reading. + bool HasBytesToRead() const; + + // Number of bytes available to read. + size_t ReadableBytes() const; + + // Returns true if the sequencer has delivered the fin. + bool IsClosed() const; + + // Calls |OnDataAvailable| on |stream_| if there is buffered data that can + // be processed, and causes |OnDataAvailable| to be called as new data + // arrives. + void SetUnblocked(); + + // Blocks processing of frames until |SetUnblocked| is called. + void SetBlockedUntilFlush(); + + // Sets the sequencer to discard all incoming data itself and not call + // |stream_->OnDataAvailable()|. |stream_->OnFinRead()| will be called + // automatically when the FIN is consumed (which may be immediately). + void StopReading(); + + // Free the memory of underlying buffer. + void ReleaseBuffer(); + + // Free the memory of underlying buffer when no bytes remain in it. + void ReleaseBufferIfEmpty(); + + // Number of bytes in the buffer right now. + size_t NumBytesBuffered() const; + + // Number of bytes has been consumed. + QuicStreamOffset NumBytesConsumed() const; + + // Returns true if all of the data within the stream up until the FIN is + // available. + bool IsAllDataAvailable() const; + + QuicStreamOffset close_offset() const { return close_offset_; } + + int num_frames_received() const { return num_frames_received_; } + + int num_duplicate_frames_received() const { + return num_duplicate_frames_received_; + } + + bool ignore_read_data() const { return ignore_read_data_; } + + void set_level_triggered(bool level_triggered) { + level_triggered_ = level_triggered; + } + + bool level_triggered() const { return level_triggered_; } + + void set_stream(StreamInterface* stream) { stream_ = stream; } + + // Returns string describing internal state. + std::string DebugString() const; + + private: + friend class test::QuicStreamSequencerPeer; + + // Deletes and records as consumed any buffered data that is now in-sequence. + // (To be called only after StopReading has been called.) + void FlushBufferedFrames(); + + // Wait until we've seen 'offset' bytes, and then terminate the stream. + // Returns true if |stream_| is still available to receive data, and false if + // |stream_| is reset. + bool CloseStreamAtOffset(QuicStreamOffset offset); + + // If we've received a FIN and have processed all remaining data, then inform + // the stream of FIN, and clear buffers. + void MaybeCloseStream(); + + // Shared implementation between OnStreamFrame and OnCryptoFrame. + void OnFrameData(QuicStreamOffset byte_offset, size_t data_len, + const char* data_buffer); + + // The stream which owns this sequencer. + StreamInterface* stream_; + + // Stores received data in offset order. + QuicStreamSequencerBuffer buffered_frames_; + + // The highest offset that is received so far. + QuicStreamOffset highest_offset_; + + // The offset, if any, we got a stream termination for. When this many bytes + // have been processed, the sequencer will be closed. + QuicStreamOffset close_offset_; + + // If true, the sequencer is blocked from passing data to the stream and will + // buffer all new incoming data until FlushBufferedFrames is called. + bool blocked_; + + // Count of the number of frames received. + int num_frames_received_; + + // Count of the number of duplicate frames received. + int num_duplicate_frames_received_; + + // If true, all incoming data will be discarded. + bool ignore_read_data_; + + // If false, only call OnDataAvailable() when it becomes newly unblocked. + // Otherwise, call OnDataAvailable() when number of readable bytes changes. + bool level_triggered_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_SEQUENCER_H_ diff --git a/quiche/quic/core/quic_stream_sequencer_buffer.cc b/quiche/quic/core/quic_stream_sequencer_buffer.cc new file mode 100644 index 000000000000..d364d61bcf16 --- /dev/null +++ b/quiche/quic/core/quic_stream_sequencer_buffer.cc @@ -0,0 +1,542 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_sequencer_buffer.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace { + +size_t CalculateBlockCount(size_t max_capacity_bytes) { + return (max_capacity_bytes + QuicStreamSequencerBuffer::kBlockSizeBytes - 1) / + QuicStreamSequencerBuffer::kBlockSizeBytes; +} + +// Upper limit of how many gaps allowed in buffer, which ensures a reasonable +// number of iterations needed to find the right gap to fill when a frame +// arrives. +const size_t kMaxNumDataIntervalsAllowed = 2 * kMaxPacketGap; + +// Number of blocks allocated initially. +constexpr size_t kInitialBlockCount = 8u; + +// How fast block pointers container grow in size. +// Choose 4 to reduce the amount of reallocation. +constexpr int kBlocksGrowthFactor = 4; + +} // namespace + +QuicStreamSequencerBuffer::QuicStreamSequencerBuffer(size_t max_capacity_bytes) + : max_buffer_capacity_bytes_(max_capacity_bytes), + max_blocks_count_(CalculateBlockCount(max_capacity_bytes)), + current_blocks_count_(0u), + total_bytes_read_(0), + blocks_(nullptr) { + QUICHE_DCHECK_GE(max_blocks_count_, kInitialBlockCount); + Clear(); +} + +QuicStreamSequencerBuffer::~QuicStreamSequencerBuffer() { Clear(); } + +void QuicStreamSequencerBuffer::Clear() { + if (blocks_ != nullptr) { + for (size_t i = 0; i < current_blocks_count_; ++i) { + if (blocks_[i] != nullptr) { + RetireBlock(i); + } + } + } + num_bytes_buffered_ = 0; + bytes_received_.Clear(); + bytes_received_.Add(0, total_bytes_read_); +} + +bool QuicStreamSequencerBuffer::RetireBlock(size_t index) { + if (blocks_[index] == nullptr) { + QUIC_BUG(quic_bug_10610_1) << "Try to retire block twice"; + return false; + } + delete blocks_[index]; + blocks_[index] = nullptr; + QUIC_DVLOG(1) << "Retired block with index: " << index; + return true; +} + +void QuicStreamSequencerBuffer::MaybeAddMoreBlocks( + QuicStreamOffset next_expected_byte) { + if (current_blocks_count_ == max_blocks_count_) { + return; + } + QuicStreamOffset last_byte = next_expected_byte - 1; + size_t num_of_blocks_needed; + // As long as last_byte does not wrap around, its index plus one blocks are + // needed. Otherwise, block_count_ blocks are needed. + if (last_byte < max_buffer_capacity_bytes_) { + num_of_blocks_needed = + std::max(GetBlockIndex(last_byte) + 1, kInitialBlockCount); + } else { + num_of_blocks_needed = max_blocks_count_; + } + if (current_blocks_count_ >= num_of_blocks_needed) { + return; + } + size_t new_block_count = kBlocksGrowthFactor * current_blocks_count_; + new_block_count = std::min(std::max(new_block_count, num_of_blocks_needed), + max_blocks_count_); + auto new_blocks = std::make_unique(new_block_count); + if (blocks_ != nullptr) { + memcpy(new_blocks.get(), blocks_.get(), + current_blocks_count_ * sizeof(BufferBlock*)); + } + blocks_ = std::move(new_blocks); + current_blocks_count_ = new_block_count; +} + +QuicErrorCode QuicStreamSequencerBuffer::OnStreamData( + QuicStreamOffset starting_offset, absl::string_view data, + size_t* const bytes_buffered, std::string* error_details) { + *bytes_buffered = 0; + size_t size = data.size(); + if (size == 0) { + *error_details = "Received empty stream frame without FIN."; + return QUIC_EMPTY_STREAM_FRAME_NO_FIN; + } + // Write beyond the current range this buffer is covering. + if (starting_offset + size > total_bytes_read_ + max_buffer_capacity_bytes_ || + starting_offset + size < starting_offset) { + *error_details = "Received data beyond available range."; + return QUIC_INTERNAL_ERROR; + } + + if (bytes_received_.Empty() || + starting_offset >= bytes_received_.rbegin()->max() || + bytes_received_.IsDisjoint(QuicInterval( + starting_offset, starting_offset + size))) { + // Optimization for the typical case, when all data is newly received. + bytes_received_.AddOptimizedForAppend(starting_offset, + starting_offset + size); + if (bytes_received_.Size() >= kMaxNumDataIntervalsAllowed) { + // This frame is going to create more intervals than allowed. Stop + // processing. + *error_details = "Too many data intervals received for this stream."; + return QUIC_TOO_MANY_STREAM_DATA_INTERVALS; + } + MaybeAddMoreBlocks(starting_offset + size); + + size_t bytes_copy = 0; + if (!CopyStreamData(starting_offset, data, &bytes_copy, error_details)) { + return QUIC_STREAM_SEQUENCER_INVALID_STATE; + } + *bytes_buffered += bytes_copy; + num_bytes_buffered_ += *bytes_buffered; + return QUIC_NO_ERROR; + } + // Slow path, received data overlaps with received data. + QuicIntervalSet newly_received(starting_offset, + starting_offset + size); + newly_received.Difference(bytes_received_); + if (newly_received.Empty()) { + return QUIC_NO_ERROR; + } + bytes_received_.Add(starting_offset, starting_offset + size); + if (bytes_received_.Size() >= kMaxNumDataIntervalsAllowed) { + // This frame is going to create more intervals than allowed. Stop + // processing. + *error_details = "Too many data intervals received for this stream."; + return QUIC_TOO_MANY_STREAM_DATA_INTERVALS; + } + MaybeAddMoreBlocks(starting_offset + size); + for (const auto& interval : newly_received) { + const QuicStreamOffset copy_offset = interval.min(); + const QuicByteCount copy_length = interval.max() - interval.min(); + size_t bytes_copy = 0; + if (!CopyStreamData(copy_offset, + data.substr(copy_offset - starting_offset, copy_length), + &bytes_copy, error_details)) { + return QUIC_STREAM_SEQUENCER_INVALID_STATE; + } + *bytes_buffered += bytes_copy; + } + num_bytes_buffered_ += *bytes_buffered; + return QUIC_NO_ERROR; +} + +bool QuicStreamSequencerBuffer::CopyStreamData(QuicStreamOffset offset, + absl::string_view data, + size_t* bytes_copy, + std::string* error_details) { + *bytes_copy = 0; + size_t source_remaining = data.size(); + if (source_remaining == 0) { + return true; + } + const char* source = data.data(); + // Write data block by block. If corresponding block has not created yet, + // create it first. + // Stop when all data are written or reaches the logical end of the buffer. + while (source_remaining > 0) { + const size_t write_block_num = GetBlockIndex(offset); + const size_t write_block_offset = GetInBlockOffset(offset); + size_t current_blocks_count = current_blocks_count_; + QUICHE_DCHECK_GT(current_blocks_count, write_block_num); + + size_t block_capacity = GetBlockCapacity(write_block_num); + size_t bytes_avail = block_capacity - write_block_offset; + + // If this write meets the upper boundary of the buffer, + // reduce the available free bytes. + if (offset + bytes_avail > total_bytes_read_ + max_buffer_capacity_bytes_) { + bytes_avail = total_bytes_read_ + max_buffer_capacity_bytes_ - offset; + } + + if (write_block_num >= current_blocks_count) { + *error_details = absl::StrCat( + "QuicStreamSequencerBuffer error: OnStreamData() exceed array bounds." + "write offset = ", + offset, " write_block_num = ", write_block_num, + " current_blocks_count_ = ", current_blocks_count); + return false; + } + if (blocks_ == nullptr) { + *error_details = + "QuicStreamSequencerBuffer error: OnStreamData() blocks_ is null"; + return false; + } + if (blocks_[write_block_num] == nullptr) { + // TODO(danzh): Investigate if using a freelist would improve performance. + // Same as RetireBlock(). + blocks_[write_block_num] = new BufferBlock(); + } + + const size_t bytes_to_copy = + std::min(bytes_avail, source_remaining); + char* dest = blocks_[write_block_num]->buffer + write_block_offset; + QUIC_DVLOG(1) << "Write at offset: " << offset + << " length: " << bytes_to_copy; + + if (dest == nullptr || source == nullptr) { + *error_details = absl::StrCat( + "QuicStreamSequencerBuffer error: OnStreamData()" + " dest == nullptr: ", + (dest == nullptr), " source == nullptr: ", (source == nullptr), + " Writing at offset ", offset, + " Received frames: ", ReceivedFramesDebugString(), + " total_bytes_read_ = ", total_bytes_read_); + return false; + } + memcpy(dest, source, bytes_to_copy); + source += bytes_to_copy; + source_remaining -= bytes_to_copy; + offset += bytes_to_copy; + *bytes_copy += bytes_to_copy; + } + return true; +} + +QuicErrorCode QuicStreamSequencerBuffer::Readv(const iovec* dest_iov, + size_t dest_count, + size_t* bytes_read, + std::string* error_details) { + *bytes_read = 0; + for (size_t i = 0; i < dest_count && ReadableBytes() > 0; ++i) { + char* dest = reinterpret_cast(dest_iov[i].iov_base); + QUICHE_DCHECK(dest != nullptr); + size_t dest_remaining = dest_iov[i].iov_len; + while (dest_remaining > 0 && ReadableBytes() > 0) { + size_t block_idx = NextBlockToRead(); + size_t start_offset_in_block = ReadOffset(); + size_t block_capacity = GetBlockCapacity(block_idx); + size_t bytes_available_in_block = std::min( + ReadableBytes(), block_capacity - start_offset_in_block); + size_t bytes_to_copy = + std::min(bytes_available_in_block, dest_remaining); + QUICHE_DCHECK_GT(bytes_to_copy, 0u); + if (blocks_[block_idx] == nullptr || dest == nullptr) { + *error_details = absl::StrCat( + "QuicStreamSequencerBuffer error:" + " Readv() dest == nullptr: ", + (dest == nullptr), " blocks_[", block_idx, + "] == nullptr: ", (blocks_[block_idx] == nullptr), + " Received frames: ", ReceivedFramesDebugString(), + " total_bytes_read_ = ", total_bytes_read_); + return QUIC_STREAM_SEQUENCER_INVALID_STATE; + } + memcpy(dest, blocks_[block_idx]->buffer + start_offset_in_block, + bytes_to_copy); + dest += bytes_to_copy; + dest_remaining -= bytes_to_copy; + num_bytes_buffered_ -= bytes_to_copy; + total_bytes_read_ += bytes_to_copy; + *bytes_read += bytes_to_copy; + + // Retire the block if all the data is read out and no other data is + // stored in this block. + // In case of failing to retire a block which is ready to retire, return + // immediately. + if (bytes_to_copy == bytes_available_in_block) { + bool retire_successfully = RetireBlockIfEmpty(block_idx); + if (!retire_successfully) { + *error_details = absl::StrCat( + "QuicStreamSequencerBuffer error: fail to retire block ", + block_idx, + " as the block is already released, total_bytes_read_ = ", + total_bytes_read_, + " Received frames: ", ReceivedFramesDebugString()); + return QUIC_STREAM_SEQUENCER_INVALID_STATE; + } + } + } + } + + return QUIC_NO_ERROR; +} + +int QuicStreamSequencerBuffer::GetReadableRegions(struct iovec* iov, + int iov_len) const { + QUICHE_DCHECK(iov != nullptr); + QUICHE_DCHECK_GT(iov_len, 0); + + if (ReadableBytes() == 0) { + iov[0].iov_base = nullptr; + iov[0].iov_len = 0; + return 0; + } + + size_t start_block_idx = NextBlockToRead(); + QuicStreamOffset readable_offset_end = FirstMissingByte() - 1; + QUICHE_DCHECK_GE(readable_offset_end + 1, total_bytes_read_); + size_t end_block_offset = GetInBlockOffset(readable_offset_end); + size_t end_block_idx = GetBlockIndex(readable_offset_end); + + // If readable region is within one block, deal with it seperately. + if (start_block_idx == end_block_idx && ReadOffset() <= end_block_offset) { + iov[0].iov_base = blocks_[start_block_idx]->buffer + ReadOffset(); + iov[0].iov_len = ReadableBytes(); + QUIC_DVLOG(1) << "Got only a single block with index: " << start_block_idx; + return 1; + } + + // Get first block + iov[0].iov_base = blocks_[start_block_idx]->buffer + ReadOffset(); + iov[0].iov_len = GetBlockCapacity(start_block_idx) - ReadOffset(); + QUIC_DVLOG(1) << "Got first block " << start_block_idx << " with len " + << iov[0].iov_len; + QUICHE_DCHECK_GT(readable_offset_end + 1, total_bytes_read_ + iov[0].iov_len) + << "there should be more available data"; + + // Get readable regions of the rest blocks till either 2nd to last block + // before gap is met or |iov| is filled. For these blocks, one whole block is + // a region. + int iov_used = 1; + size_t block_idx = (start_block_idx + iov_used) % max_blocks_count_; + while (block_idx != end_block_idx && iov_used < iov_len) { + QUICHE_DCHECK(nullptr != blocks_[block_idx]); + iov[iov_used].iov_base = blocks_[block_idx]->buffer; + iov[iov_used].iov_len = GetBlockCapacity(block_idx); + QUIC_DVLOG(1) << "Got block with index: " << block_idx; + ++iov_used; + block_idx = (start_block_idx + iov_used) % max_blocks_count_; + } + + // Deal with last block if |iov| can hold more. + if (iov_used < iov_len) { + QUICHE_DCHECK(nullptr != blocks_[block_idx]); + iov[iov_used].iov_base = blocks_[end_block_idx]->buffer; + iov[iov_used].iov_len = end_block_offset + 1; + QUIC_DVLOG(1) << "Got last block with index: " << end_block_idx; + ++iov_used; + } + return iov_used; +} + +bool QuicStreamSequencerBuffer::GetReadableRegion(iovec* iov) const { + return GetReadableRegions(iov, 1) == 1; +} + +bool QuicStreamSequencerBuffer::PeekRegion(QuicStreamOffset offset, + iovec* iov) const { + QUICHE_DCHECK(iov); + + if (offset < total_bytes_read_) { + // Data at |offset| has already been consumed. + return false; + } + + if (offset >= FirstMissingByte()) { + // Data at |offset| has not been received yet. + return false; + } + + // Beginning of region. + size_t block_idx = GetBlockIndex(offset); + size_t block_offset = GetInBlockOffset(offset); + iov->iov_base = blocks_[block_idx]->buffer + block_offset; + + // Determine if entire block has been received. + size_t end_block_idx = GetBlockIndex(FirstMissingByte()); + if (block_idx == end_block_idx) { + // Only read part of block before FirstMissingByte(). + iov->iov_len = GetInBlockOffset(FirstMissingByte()) - block_offset; + } else { + // Read entire block. + iov->iov_len = GetBlockCapacity(block_idx) - block_offset; + } + + return true; +} + +bool QuicStreamSequencerBuffer::MarkConsumed(size_t bytes_consumed) { + if (bytes_consumed > ReadableBytes()) { + return false; + } + size_t bytes_to_consume = bytes_consumed; + while (bytes_to_consume > 0) { + size_t block_idx = NextBlockToRead(); + size_t offset_in_block = ReadOffset(); + size_t bytes_available = std::min( + ReadableBytes(), GetBlockCapacity(block_idx) - offset_in_block); + size_t bytes_read = std::min(bytes_to_consume, bytes_available); + total_bytes_read_ += bytes_read; + num_bytes_buffered_ -= bytes_read; + bytes_to_consume -= bytes_read; + // If advanced to the end of current block and end of buffer hasn't wrapped + // to this block yet. + if (bytes_available == bytes_read) { + RetireBlockIfEmpty(block_idx); + } + } + + return true; +} + +size_t QuicStreamSequencerBuffer::FlushBufferedFrames() { + size_t prev_total_bytes_read = total_bytes_read_; + total_bytes_read_ = NextExpectedByte(); + Clear(); + return total_bytes_read_ - prev_total_bytes_read; +} + +void QuicStreamSequencerBuffer::ReleaseWholeBuffer() { + Clear(); + current_blocks_count_ = 0; + blocks_.reset(nullptr); +} + +size_t QuicStreamSequencerBuffer::ReadableBytes() const { + return FirstMissingByte() - total_bytes_read_; +} + +bool QuicStreamSequencerBuffer::HasBytesToRead() const { + return ReadableBytes() > 0; +} + +QuicStreamOffset QuicStreamSequencerBuffer::BytesConsumed() const { + return total_bytes_read_; +} + +size_t QuicStreamSequencerBuffer::BytesBuffered() const { + return num_bytes_buffered_; +} + +size_t QuicStreamSequencerBuffer::GetBlockIndex(QuicStreamOffset offset) const { + return (offset % max_buffer_capacity_bytes_) / kBlockSizeBytes; +} + +size_t QuicStreamSequencerBuffer::GetInBlockOffset( + QuicStreamOffset offset) const { + return (offset % max_buffer_capacity_bytes_) % kBlockSizeBytes; +} + +size_t QuicStreamSequencerBuffer::ReadOffset() const { + return GetInBlockOffset(total_bytes_read_); +} + +size_t QuicStreamSequencerBuffer::NextBlockToRead() const { + return GetBlockIndex(total_bytes_read_); +} + +bool QuicStreamSequencerBuffer::RetireBlockIfEmpty(size_t block_index) { + QUICHE_DCHECK(ReadableBytes() == 0 || + GetInBlockOffset(total_bytes_read_) == 0) + << "RetireBlockIfEmpty() should only be called when advancing to next " + << "block or a gap has been reached."; + // If the whole buffer becomes empty, the last piece of data has been read. + if (Empty()) { + return RetireBlock(block_index); + } + + // Check where the logical end of this buffer is. + // Not empty if the end of circular buffer has been wrapped to this block. + if (GetBlockIndex(NextExpectedByte() - 1) == block_index) { + return true; + } + + // Read index remains in this block, which means a gap has been reached. + if (NextBlockToRead() == block_index) { + if (bytes_received_.Size() > 1) { + auto it = bytes_received_.begin(); + ++it; + if (GetBlockIndex(it->min()) == block_index) { + // Do not retire the block if next data interval is in this block. + return true; + } + } else { + QUIC_BUG(quic_bug_10610_2) << "Read stopped at where it shouldn't."; + return false; + } + } + return RetireBlock(block_index); +} + +bool QuicStreamSequencerBuffer::Empty() const { + return bytes_received_.Empty() || + (bytes_received_.Size() == 1 && total_bytes_read_ > 0 && + bytes_received_.begin()->max() == total_bytes_read_); +} + +size_t QuicStreamSequencerBuffer::GetBlockCapacity(size_t block_index) const { + if ((block_index + 1) == max_blocks_count_) { + size_t result = max_buffer_capacity_bytes_ % kBlockSizeBytes; + if (result == 0) { // whole block + result = kBlockSizeBytes; + } + return result; + } else { + return kBlockSizeBytes; + } +} + +std::string QuicStreamSequencerBuffer::ReceivedFramesDebugString() const { + return bytes_received_.ToString(); +} + +QuicStreamOffset QuicStreamSequencerBuffer::FirstMissingByte() const { + if (bytes_received_.Empty() || bytes_received_.begin()->min() > 0) { + // Offset 0 is not received yet. + return 0; + } + return bytes_received_.begin()->max(); +} + +QuicStreamOffset QuicStreamSequencerBuffer::NextExpectedByte() const { + if (bytes_received_.Empty()) { + return 0; + } + return bytes_received_.rbegin()->max(); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_stream_sequencer_buffer.h b/quiche/quic/core/quic_stream_sequencer_buffer.h new file mode 100644 index 000000000000..9b36ee14a055 --- /dev/null +++ b/quiche/quic/core/quic_stream_sequencer_buffer.h @@ -0,0 +1,241 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_STREAM_SEQUENCER_BUFFER_H_ +#define QUICHE_QUIC_CORE_QUIC_STREAM_SEQUENCER_BUFFER_H_ + +// QuicStreamSequencerBuffer is a circular stream buffer with random write and +// in-sequence read. It consists of a vector of pointers pointing +// to memory blocks created as needed and an interval set recording received +// data. +// - Data are written in with offset indicating where it should be in the +// stream, and the buffer grown as needed (up to the maximum buffer capacity), +// without expensive copying (extra blocks are allocated). +// - Data can be read from the buffer if there is no gap before it, +// and the buffer shrinks as the data are consumed. +// - An upper limit on the number of blocks in the buffer provides an upper +// bound on memory use. +// +// This class is thread-unsafe. +// +// QuicStreamSequencerBuffer maintains a concept of the readable region, which +// contains all written data that has not been read. +// It promises stability of the underlying memory addresses in the readable +// region, so pointers into it can be maintained, and the offset of a pointer +// from the start of the read region can be calculated. +// +// Expected Use: +// QuicStreamSequencerBuffer buffer(2.5 * 8 * 1024); +// std::string source(1024, 'a'); +// absl::string_view string_piece(source.data(), source.size()); +// size_t written = 0; +// buffer.OnStreamData(800, string_piece, GetEpollClockNow(), &written); +// source = std::string{800, 'b'}; +// absl::string_view string_piece1(source.data(), 800); +// // Try to write to [1, 801), but should fail due to overlapping, +// // res should be QUIC_INVALID_STREAM_DATA +// auto res = buffer.OnStreamData(1, string_piece1, &written)); +// // write to [0, 800), res should be QUIC_NO_ERROR +// auto res = buffer.OnStreamData(0, string_piece1, GetEpollClockNow(), +// &written); +// +// // Read into a iovec array with total capacity of 120 bytes. +// char dest[120]; +// iovec iovecs[3]{iovec{dest, 40}, iovec{dest + 40, 40}, +// iovec{dest + 80, 40}}; +// size_t read = buffer.Readv(iovecs, 3); +// +// // Get single readable region. +// iovec iov; +// buffer.GetReadableRegion(iov); +// +// // Get readable regions from [256, 1024) and consume some of it. +// iovec iovs[2]; +// int iov_count = buffer.GetReadableRegions(iovs, 2); +// // Consume some bytes in iovs, returning number of bytes having been +// consumed. +// size_t consumed = consume_iovs(iovs, iov_count); +// buffer.MarkConsumed(consumed); + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_interval_set.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/common/platform/api/quiche_iovec.h" + +namespace quic { + +namespace test { +class QuicStreamSequencerBufferPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE QuicStreamSequencerBuffer { + public: + // Size of blocks used by this buffer. + // Choose 8K to make block large enough to hold multiple frames, each of + // which could be up to 1.5 KB. + static const size_t kBlockSizeBytes = 8 * 1024; // 8KB + + // The basic storage block used by this buffer. + struct QUIC_EXPORT_PRIVATE BufferBlock { + char buffer[kBlockSizeBytes]; + }; + + explicit QuicStreamSequencerBuffer(size_t max_capacity_bytes); + QuicStreamSequencerBuffer(const QuicStreamSequencerBuffer&) = delete; + QuicStreamSequencerBuffer(QuicStreamSequencerBuffer&&) = default; + QuicStreamSequencerBuffer& operator=(const QuicStreamSequencerBuffer&) = + delete; + QuicStreamSequencerBuffer& operator=(QuicStreamSequencerBuffer&&) = default; + ~QuicStreamSequencerBuffer(); + + // Free the space used to buffer data. + void Clear(); + + // Returns true if there is nothing to read in this buffer. + bool Empty() const; + + // Called to buffer new data received for this stream. If the data was + // successfully buffered, returns QUIC_NO_ERROR and stores the number of + // bytes buffered in |bytes_buffered|. Returns an error otherwise. + QuicErrorCode OnStreamData(QuicStreamOffset offset, absl::string_view data, + size_t* bytes_buffered, + std::string* error_details); + + // Reads from this buffer into given iovec array, up to number of iov_len + // iovec objects and returns the number of bytes read. + QuicErrorCode Readv(const struct iovec* dest_iov, size_t dest_count, + size_t* bytes_read, std::string* error_details); + + // Returns the readable region of valid data in iovec format. The readable + // region is the buffer region where there is valid data not yet read by + // client. + // Returns the number of iovec entries in |iov| which were populated. + // If the region is empty, one iovec entry with 0 length + // is returned, and the function returns 0. If there are more readable + // regions than |iov_size|, the function only processes the first + // |iov_size| of them. + int GetReadableRegions(struct iovec* iov, int iov_len) const; + + // Fills in one iovec with data from the next readable region. + // Returns false if there is no readable region available. + bool GetReadableRegion(iovec* iov) const; + + // Returns true and sets |*iov| to point to a region starting at |offset|. + // Returns false if no data can be read at |offset|, which can be because data + // has not been received yet or it is already consumed. + // Does not consume data. + bool PeekRegion(QuicStreamOffset offset, iovec* iov) const; + + // Called after GetReadableRegions() to free up |bytes_consumed| space if + // these bytes are processed. + // Pre-requisite: bytes_consumed <= available bytes to read. + bool MarkConsumed(size_t bytes_consumed); + + // Deletes and records as consumed any buffered data and clear the buffer. + // (To be called only after sequencer's StopReading has been called.) + size_t FlushBufferedFrames(); + + // Free the memory of buffered data. + void ReleaseWholeBuffer(); + + // Whether there are bytes can be read out. + bool HasBytesToRead() const; + + // Count how many bytes have been consumed (read out of buffer). + QuicStreamOffset BytesConsumed() const; + + // Count how many bytes are in buffer at this moment. + size_t BytesBuffered() const; + + // Returns number of bytes available to be read out. + size_t ReadableBytes() const; + + // Returns offset of first missing byte. + QuicStreamOffset FirstMissingByte() const; + + // Returns offset of highest received byte + 1. + QuicStreamOffset NextExpectedByte() const; + + // Return all received frames as a string. + std::string ReceivedFramesDebugString() const; + + private: + friend class test::QuicStreamSequencerBufferPeer; + + // Copies |data| to blocks_, sets |bytes_copy|. Returns true if the copy is + // successful. Otherwise, sets |error_details| and returns false. + bool CopyStreamData(QuicStreamOffset offset, absl::string_view data, + size_t* bytes_copy, std::string* error_details); + + // Dispose the given buffer block. + // After calling this method, blocks_[index] is set to nullptr + // in order to indicate that no memory set is allocated for that block. + // Returns true on success, false otherwise. + bool RetireBlock(size_t index); + + // Should only be called after the indexed block is read till the end of the + // block or missing data has been reached. + // If the block at |block_index| contains no buffered data, the block + // should be retired. + // Returns true on success, or false otherwise. + bool RetireBlockIfEmpty(size_t block_index); + + // Calculate the capacity of block at specified index. + // Return value should be either kBlockSizeBytes for non-trailing blocks and + // max_buffer_capacity % kBlockSizeBytes for trailing block. + size_t GetBlockCapacity(size_t index) const; + + // Does not check if offset is within reasonable range. + size_t GetBlockIndex(QuicStreamOffset offset) const; + + // Given an offset in the stream, return the offset from the beginning of the + // block which contains this data. + size_t GetInBlockOffset(QuicStreamOffset offset) const; + + // Get offset relative to index 0 in logical 1st block to start next read. + size_t ReadOffset() const; + + // Get the index of the logical 1st block to start next read. + size_t NextBlockToRead() const; + + // Resize blocks_ if more blocks are needed to accomodate bytes before + // next_expected_byte. + void MaybeAddMoreBlocks(QuicStreamOffset next_expected_byte); + + // The maximum total capacity of this buffer in byte, as constructed. + size_t max_buffer_capacity_bytes_; + + // Number of blocks this buffer would have when it reaches full capacity, + // i.e., maximal number of blocks in blocks_. + size_t max_blocks_count_; + + // Number of blocks this buffer currently has. + size_t current_blocks_count_; + + // Number of bytes read out of buffer. + QuicStreamOffset total_bytes_read_; + + // An ordered, variable-length list of blocks, with the length limited + // such that the number of blocks never exceeds max_blocks_count_. + // Each list entry can hold up to kBlockSizeBytes bytes. + std::unique_ptr blocks_; + + // Number of bytes in buffer. + size_t num_bytes_buffered_; + + // Currently received data. + QuicIntervalSet bytes_received_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_STREAM_SEQUENCER_BUFFER_H_ diff --git a/quiche/quic/core/quic_stream_sequencer_buffer_test.cc b/quiche/quic/core/quic_stream_sequencer_buffer_test.cc new file mode 100644 index 000000000000..d1cdf341c1c1 --- /dev/null +++ b/quiche/quic/core/quic_stream_sequencer_buffer_test.cc @@ -0,0 +1,1139 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_sequencer_buffer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +absl::string_view IovecToStringPiece(iovec iov) { + return absl::string_view(reinterpret_cast(iov.iov_base), + iov.iov_len); +} + +char GetCharFromIOVecs(size_t offset, iovec iov[], size_t count) { + size_t start_offset = 0; + for (size_t i = 0; i < count; i++) { + if (iov[i].iov_len == 0) { + continue; + } + size_t end_offset = start_offset + iov[i].iov_len - 1; + if (offset >= start_offset && offset <= end_offset) { + const char* buf = reinterpret_cast(iov[i].iov_base); + return buf[offset - start_offset]; + } + start_offset += iov[i].iov_len; + } + QUIC_LOG(ERROR) << "Could not locate char at offset " << offset << " in " + << count << " iovecs"; + for (size_t i = 0; i < count; ++i) { + QUIC_LOG(ERROR) << " iov[" << i << "].iov_len = " << iov[i].iov_len; + } + return '\0'; +} + +const size_t kMaxNumGapsAllowed = 2 * kMaxPacketGap; + +static const size_t kBlockSizeBytes = + QuicStreamSequencerBuffer::kBlockSizeBytes; +using BufferBlock = QuicStreamSequencerBuffer::BufferBlock; + +namespace { + +class QuicStreamSequencerBufferTest : public QuicTest { + public: + void SetUp() override { Initialize(); } + + void ResetMaxCapacityBytes(size_t max_capacity_bytes) { + max_capacity_bytes_ = max_capacity_bytes; + Initialize(); + } + + protected: + void Initialize() { + buffer_ = + std::make_unique((max_capacity_bytes_)); + helper_ = std::make_unique((buffer_.get())); + } + + // Use 8.5 here to make sure that the buffer has more than + // QuicStreamSequencerBuffer::kInitialBlockCount block and its end doesn't + // align with the end of a block in order to test all the offset calculation. + size_t max_capacity_bytes_ = 8.5 * kBlockSizeBytes; + + std::unique_ptr buffer_; + std::unique_ptr helper_; + size_t written_ = 0; + std::string error_details_; +}; + +TEST_F(QuicStreamSequencerBufferTest, InitializeWithMaxRecvWindowSize) { + ResetMaxCapacityBytes(16 * 1024 * 1024); // 16MB + EXPECT_EQ(2 * 1024u, // 16MB / 8KB = 2K + helper_->max_blocks_count()); + EXPECT_EQ(max_capacity_bytes_, helper_->max_buffer_capacity()); + EXPECT_TRUE(helper_->CheckInitialState()); +} + +TEST_F(QuicStreamSequencerBufferTest, InitializationWithDifferentSizes) { + const size_t kCapacity = 16 * QuicStreamSequencerBuffer::kBlockSizeBytes; + ResetMaxCapacityBytes(kCapacity); + EXPECT_EQ(max_capacity_bytes_, helper_->max_buffer_capacity()); + EXPECT_TRUE(helper_->CheckInitialState()); + + const size_t kCapacity1 = 32 * QuicStreamSequencerBuffer::kBlockSizeBytes; + ResetMaxCapacityBytes(kCapacity1); + EXPECT_EQ(kCapacity1, helper_->max_buffer_capacity()); + EXPECT_TRUE(helper_->CheckInitialState()); +} + +TEST_F(QuicStreamSequencerBufferTest, ClearOnEmpty) { + buffer_->Clear(); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamData0length) { + QuicErrorCode error = + buffer_->OnStreamData(800, "", &written_, &error_details_); + EXPECT_THAT(error, IsError(QUIC_EMPTY_STREAM_FRAME_NO_FIN)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataWithinBlock) { + EXPECT_FALSE(helper_->IsBufferAllocated()); + std::string source(1024, 'a'); + EXPECT_THAT(buffer_->OnStreamData(800, source, &written_, &error_details_), + IsQuicNoError()); + BufferBlock* block_ptr = helper_->GetBlock(0); + for (size_t i = 0; i < source.size(); ++i) { + ASSERT_EQ('a', block_ptr->buffer[helper_->GetInBlockOffset(800) + i]); + } + EXPECT_EQ(2, helper_->IntervalSize()); + EXPECT_EQ(0u, helper_->ReadableBytes()); + EXPECT_EQ(1u, helper_->bytes_received().Size()); + EXPECT_EQ(800u, helper_->bytes_received().begin()->min()); + EXPECT_EQ(1824u, helper_->bytes_received().begin()->max()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + EXPECT_TRUE(helper_->IsBufferAllocated()); +} + +TEST_F(QuicStreamSequencerBufferTest, Move) { + EXPECT_FALSE(helper_->IsBufferAllocated()); + std::string source(1024, 'a'); + EXPECT_THAT(buffer_->OnStreamData(800, source, &written_, &error_details_), + IsQuicNoError()); + BufferBlock* block_ptr = helper_->GetBlock(0); + for (size_t i = 0; i < source.size(); ++i) { + ASSERT_EQ('a', block_ptr->buffer[helper_->GetInBlockOffset(800) + i]); + } + + QuicStreamSequencerBuffer buffer2(std::move(*buffer_)); + QuicStreamSequencerBufferPeer helper2(&buffer2); + + EXPECT_FALSE(helper_->IsBufferAllocated()); + + EXPECT_EQ(2, helper2.IntervalSize()); + EXPECT_EQ(0u, helper2.ReadableBytes()); + EXPECT_EQ(1u, helper2.bytes_received().Size()); + EXPECT_EQ(800u, helper2.bytes_received().begin()->min()); + EXPECT_EQ(1824u, helper2.bytes_received().begin()->max()); + EXPECT_TRUE(helper2.CheckBufferInvariants()); + EXPECT_TRUE(helper2.IsBufferAllocated()); +} + +TEST_F(QuicStreamSequencerBufferTest, DISABLED_OnStreamDataInvalidSource) { + // Pass in an invalid source, expects to return error. + absl::string_view source; + source = absl::string_view(nullptr, 1024); + EXPECT_THAT(buffer_->OnStreamData(800, source, &written_, &error_details_), + IsError(QUIC_STREAM_SEQUENCER_INVALID_STATE)); + EXPECT_EQ(0u, error_details_.find(absl::StrCat( + "QuicStreamSequencerBuffer error: OnStreamData() " + "dest == nullptr: ", + false, " source == nullptr: ", true))); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataWithOverlap) { + std::string source(1024, 'a'); + // Write something into [800, 1824) + EXPECT_THAT(buffer_->OnStreamData(800, source, &written_, &error_details_), + IsQuicNoError()); + // Try to write to [0, 1024) and [1024, 2048). + EXPECT_THAT(buffer_->OnStreamData(0, source, &written_, &error_details_), + IsQuicNoError()); + EXPECT_THAT(buffer_->OnStreamData(1024, source, &written_, &error_details_), + IsQuicNoError()); +} + +TEST_F(QuicStreamSequencerBufferTest, + OnStreamDataOverlapAndDuplicateCornerCases) { + std::string source(1024, 'a'); + // Write something into [800, 1824) + buffer_->OnStreamData(800, source, &written_, &error_details_); + source = std::string(800, 'b'); + std::string one_byte = "c"; + // Write [1, 801). + EXPECT_THAT(buffer_->OnStreamData(1, source, &written_, &error_details_), + IsQuicNoError()); + // Write [0, 800). + EXPECT_THAT(buffer_->OnStreamData(0, source, &written_, &error_details_), + IsQuicNoError()); + // Write [1823, 1824). + EXPECT_THAT(buffer_->OnStreamData(1823, one_byte, &written_, &error_details_), + IsQuicNoError()); + EXPECT_EQ(0u, written_); + // write one byte to [1824, 1825) + EXPECT_THAT(buffer_->OnStreamData(1824, one_byte, &written_, &error_details_), + IsQuicNoError()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataWithoutOverlap) { + std::string source(1024, 'a'); + // Write something into [800, 1824). + EXPECT_THAT(buffer_->OnStreamData(800, source, &written_, &error_details_), + IsQuicNoError()); + source = std::string(100, 'b'); + // Write something into [kBlockSizeBytes * 2 - 20, kBlockSizeBytes * 2 + 80). + EXPECT_THAT(buffer_->OnStreamData(kBlockSizeBytes * 2 - 20, source, &written_, + &error_details_), + IsQuicNoError()); + EXPECT_EQ(3, helper_->IntervalSize()); + EXPECT_EQ(1024u + 100u, buffer_->BytesBuffered()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataInLongStreamWithOverlap) { + // Assume a stream has already buffered almost 4GB. + uint64_t total_bytes_read = pow(2, 32) - 1; + helper_->set_total_bytes_read(total_bytes_read); + helper_->AddBytesReceived(0, total_bytes_read); + + // Three new out of order frames arrive. + const size_t kBytesToWrite = 100; + std::string source(kBytesToWrite, 'a'); + // Frame [2^32 + 500, 2^32 + 600). + QuicStreamOffset offset = pow(2, 32) + 500; + EXPECT_THAT(buffer_->OnStreamData(offset, source, &written_, &error_details_), + IsQuicNoError()); + EXPECT_EQ(2, helper_->IntervalSize()); + + // Frame [2^32 + 700, 2^32 + 800). + offset = pow(2, 32) + 700; + EXPECT_THAT(buffer_->OnStreamData(offset, source, &written_, &error_details_), + IsQuicNoError()); + EXPECT_EQ(3, helper_->IntervalSize()); + + // Another frame [2^32 + 300, 2^32 + 400). + offset = pow(2, 32) + 300; + EXPECT_THAT(buffer_->OnStreamData(offset, source, &written_, &error_details_), + IsQuicNoError()); + EXPECT_EQ(4, helper_->IntervalSize()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataTillEnd) { + // Write 50 bytes to the end. + const size_t kBytesToWrite = 50; + std::string source(kBytesToWrite, 'a'); + EXPECT_THAT(buffer_->OnStreamData(max_capacity_bytes_ - kBytesToWrite, source, + &written_, &error_details_), + IsQuicNoError()); + EXPECT_EQ(50u, buffer_->BytesBuffered()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataTillEndCorner) { + // Write 1 byte to the end. + const size_t kBytesToWrite = 1; + std::string source(kBytesToWrite, 'a'); + EXPECT_THAT(buffer_->OnStreamData(max_capacity_bytes_ - kBytesToWrite, source, + &written_, &error_details_), + IsQuicNoError()); + EXPECT_EQ(1u, buffer_->BytesBuffered()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, OnStreamDataBeyondCapacity) { + std::string source(60, 'a'); + EXPECT_THAT(buffer_->OnStreamData(max_capacity_bytes_ - 50, source, &written_, + &error_details_), + IsError(QUIC_INTERNAL_ERROR)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + + source = "b"; + EXPECT_THAT(buffer_->OnStreamData(max_capacity_bytes_, source, &written_, + &error_details_), + IsError(QUIC_INTERNAL_ERROR)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + + EXPECT_THAT(buffer_->OnStreamData(max_capacity_bytes_ * 1000, source, + &written_, &error_details_), + IsError(QUIC_INTERNAL_ERROR)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + + // Disallow current_gap != gaps_.end() + EXPECT_THAT(buffer_->OnStreamData(static_cast(-1), source, + &written_, &error_details_), + IsError(QUIC_INTERNAL_ERROR)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + + // Disallow offset + size overflow + source = "bbb"; + EXPECT_THAT(buffer_->OnStreamData(static_cast(-2), source, + &written_, &error_details_), + IsError(QUIC_INTERNAL_ERROR)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + EXPECT_EQ(0u, buffer_->BytesBuffered()); +} + +TEST_F(QuicStreamSequencerBufferTest, Readv100Bytes) { + std::string source(1024, 'a'); + // Write something into [kBlockSizeBytes, kBlockSizeBytes + 1024). + buffer_->OnStreamData(kBlockSizeBytes, source, &written_, &error_details_); + EXPECT_FALSE(buffer_->HasBytesToRead()); + source = std::string(100, 'b'); + // Write something into [0, 100). + buffer_->OnStreamData(0, source, &written_, &error_details_); + EXPECT_TRUE(buffer_->HasBytesToRead()); + // Read into a iovec array with total capacity of 120 bytes. + char dest[120]; + iovec iovecs[3]{iovec{dest, 40}, iovec{dest + 40, 40}, iovec{dest + 80, 40}}; + size_t read; + EXPECT_THAT(buffer_->Readv(iovecs, 3, &read, &error_details_), + IsQuicNoError()); + QUIC_LOG(ERROR) << error_details_; + EXPECT_EQ(100u, read); + EXPECT_EQ(100u, buffer_->BytesConsumed()); + EXPECT_EQ(source, absl::string_view(dest, read)); + // The first block should be released as its data has been read out. + EXPECT_EQ(nullptr, helper_->GetBlock(0)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, ReadvAcrossBlocks) { + std::string source(kBlockSizeBytes + 50, 'a'); + // Write 1st block to full and extand 50 bytes to next block. + buffer_->OnStreamData(0, source, &written_, &error_details_); + EXPECT_EQ(source.size(), helper_->ReadableBytes()); + // Iteratively read 512 bytes from buffer_-> Overwrite dest[] each time. + char dest[512]; + while (helper_->ReadableBytes()) { + std::fill(dest, dest + 512, 0); + iovec iovecs[2]{iovec{dest, 256}, iovec{dest + 256, 256}}; + size_t read; + EXPECT_THAT(buffer_->Readv(iovecs, 2, &read, &error_details_), + IsQuicNoError()); + } + // The last read only reads the rest 50 bytes in 2nd block. + EXPECT_EQ(std::string(50, 'a'), std::string(dest, 50)); + EXPECT_EQ(0, dest[50]) << "Dest[50] shouln't be filled."; + EXPECT_EQ(source.size(), buffer_->BytesConsumed()); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, ClearAfterRead) { + std::string source(kBlockSizeBytes + 50, 'a'); + // Write 1st block to full with 'a'. + buffer_->OnStreamData(0, source, &written_, &error_details_); + // Read first 512 bytes from buffer to make space at the beginning. + char dest[512]{0}; + const iovec iov{dest, 512}; + size_t read; + EXPECT_THAT(buffer_->Readv(&iov, 1, &read, &error_details_), IsQuicNoError()); + // Clear() should make buffer empty while preserving BytesConsumed() + buffer_->Clear(); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, + OnStreamDataAcrossLastBlockAndFillCapacity) { + std::string source(kBlockSizeBytes + 50, 'a'); + // Write 1st block to full with 'a'. + buffer_->OnStreamData(0, source, &written_, &error_details_); + // Read first 512 bytes from buffer to make space at the beginning. + char dest[512]{0}; + const iovec iov{dest, 512}; + size_t read; + EXPECT_THAT(buffer_->Readv(&iov, 1, &read, &error_details_), IsQuicNoError()); + EXPECT_EQ(source.size(), written_); + + // Write more than half block size of bytes in the last block with 'b', which + // will wrap to the beginning and reaches the full capacity. + source = std::string(0.5 * kBlockSizeBytes + 512, 'b'); + EXPECT_THAT(buffer_->OnStreamData(2 * kBlockSizeBytes, source, &written_, + &error_details_), + IsQuicNoError()); + EXPECT_EQ(source.size(), written_); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, + OnStreamDataAcrossLastBlockAndExceedCapacity) { + std::string source(kBlockSizeBytes + 50, 'a'); + // Write 1st block to full. + buffer_->OnStreamData(0, source, &written_, &error_details_); + // Read first 512 bytes from buffer to make space at the beginning. + char dest[512]{0}; + const iovec iov{dest, 512}; + size_t read; + EXPECT_THAT(buffer_->Readv(&iov, 1, &read, &error_details_), IsQuicNoError()); + + // Try to write from [max_capacity_bytes_ - 0.5 * kBlockSizeBytes, + // max_capacity_bytes_ + 512 + 1). But last bytes exceeds current capacity. + source = std::string(0.5 * kBlockSizeBytes + 512 + 1, 'b'); + EXPECT_THAT(buffer_->OnStreamData(8 * kBlockSizeBytes, source, &written_, + &error_details_), + IsError(QUIC_INTERNAL_ERROR)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, ReadvAcrossLastBlock) { + // Write to full capacity and read out 512 bytes at beginning and continue + // appending 256 bytes. + std::string source(max_capacity_bytes_, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[512]{0}; + const iovec iov{dest, 512}; + size_t read; + EXPECT_THAT(buffer_->Readv(&iov, 1, &read, &error_details_), IsQuicNoError()); + source = std::string(256, 'b'); + buffer_->OnStreamData(max_capacity_bytes_, source, &written_, + &error_details_); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + + // Read all data out. + std::unique_ptr dest1{new char[max_capacity_bytes_]}; + dest1[0] = 0; + const iovec iov1{dest1.get(), max_capacity_bytes_}; + EXPECT_THAT(buffer_->Readv(&iov1, 1, &read, &error_details_), + IsQuicNoError()); + EXPECT_EQ(max_capacity_bytes_ - 512 + 256, read); + EXPECT_EQ(max_capacity_bytes_ + 256, buffer_->BytesConsumed()); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, ReadvEmpty) { + char dest[512]{0}; + iovec iov{dest, 512}; + size_t read; + EXPECT_THAT(buffer_->Readv(&iov, 1, &read, &error_details_), IsQuicNoError()); + EXPECT_EQ(0u, read); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionsEmpty) { + iovec iovs[2]; + int iov_count = buffer_->GetReadableRegions(iovs, 2); + EXPECT_EQ(0, iov_count); + EXPECT_EQ(nullptr, iovs[iov_count].iov_base); + EXPECT_EQ(0u, iovs[iov_count].iov_len); +} + +TEST_F(QuicStreamSequencerBufferTest, ReleaseWholeBuffer) { + // Tests that buffer is not deallocated unless ReleaseWholeBuffer() is called. + std::string source(100, 'b'); + // Write something into [0, 100). + buffer_->OnStreamData(0, source, &written_, &error_details_); + EXPECT_TRUE(buffer_->HasBytesToRead()); + char dest[120]; + iovec iovecs[3]{iovec{dest, 40}, iovec{dest + 40, 40}, iovec{dest + 80, 40}}; + size_t read; + EXPECT_THAT(buffer_->Readv(iovecs, 3, &read, &error_details_), + IsQuicNoError()); + EXPECT_EQ(100u, read); + EXPECT_EQ(100u, buffer_->BytesConsumed()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + EXPECT_TRUE(helper_->IsBufferAllocated()); + buffer_->ReleaseWholeBuffer(); + EXPECT_FALSE(helper_->IsBufferAllocated()); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionsBlockedByGap) { + // Write into [1, 1024). + std::string source(1023, 'a'); + buffer_->OnStreamData(1, source, &written_, &error_details_); + // Try to get readable regions, but none is there. + iovec iovs[2]; + int iov_count = buffer_->GetReadableRegions(iovs, 2); + EXPECT_EQ(0, iov_count); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionsTillEndOfBlock) { + // Write first block to full with [0, 256) 'a' and the rest 'b' then read out + // [0, 256) + std::string source(kBlockSizeBytes, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[256]; + helper_->Read(dest, 256); + // Get readable region from [256, 1024) + iovec iovs[2]; + int iov_count = buffer_->GetReadableRegions(iovs, 2); + EXPECT_EQ(1, iov_count); + EXPECT_EQ(std::string(kBlockSizeBytes - 256, 'a'), + IovecToStringPiece(iovs[0])); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionsWithinOneBlock) { + // Write into [0, 1024) and then read out [0, 256) + std::string source(1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[256]; + helper_->Read(dest, 256); + // Get readable region from [256, 1024) + iovec iovs[2]; + int iov_count = buffer_->GetReadableRegions(iovs, 2); + EXPECT_EQ(1, iov_count); + EXPECT_EQ(std::string(1024 - 256, 'a'), IovecToStringPiece(iovs[0])); +} + +TEST_F(QuicStreamSequencerBufferTest, + GetReadableRegionsAcrossBlockWithLongIOV) { + // Write into [0, 2 * kBlockSizeBytes + 1024) and then read out [0, 1024) + std::string source(2 * kBlockSizeBytes + 1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[1024]; + helper_->Read(dest, 1024); + + iovec iovs[4]; + int iov_count = buffer_->GetReadableRegions(iovs, 4); + EXPECT_EQ(3, iov_count); + EXPECT_EQ(kBlockSizeBytes - 1024, iovs[0].iov_len); + EXPECT_EQ(kBlockSizeBytes, iovs[1].iov_len); + EXPECT_EQ(1024u, iovs[2].iov_len); +} + +TEST_F(QuicStreamSequencerBufferTest, + GetReadableRegionsWithMultipleIOVsAcrossEnd) { + // Write into [0, 8.5 * kBlockSizeBytes - 1024) and then read out [0, 1024) + // and then append 1024 + 512 bytes. + std::string source(8.5 * kBlockSizeBytes - 1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[1024]; + helper_->Read(dest, 1024); + // Write across the end. + source = std::string(1024 + 512, 'b'); + buffer_->OnStreamData(8.5 * kBlockSizeBytes - 1024, source, &written_, + &error_details_); + // Use short iovec's. + iovec iovs[2]; + int iov_count = buffer_->GetReadableRegions(iovs, 2); + EXPECT_EQ(2, iov_count); + EXPECT_EQ(kBlockSizeBytes - 1024, iovs[0].iov_len); + EXPECT_EQ(kBlockSizeBytes, iovs[1].iov_len); + // Use long iovec's and wrap the end of buffer. + iovec iovs1[11]; + EXPECT_EQ(10, buffer_->GetReadableRegions(iovs1, 11)); + EXPECT_EQ(0.5 * kBlockSizeBytes, iovs1[8].iov_len); + EXPECT_EQ(512u, iovs1[9].iov_len); + EXPECT_EQ(std::string(512, 'b'), IovecToStringPiece(iovs1[9])); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionEmpty) { + iovec iov; + EXPECT_FALSE(buffer_->GetReadableRegion(&iov)); + EXPECT_EQ(nullptr, iov.iov_base); + EXPECT_EQ(0u, iov.iov_len); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionBeforeGap) { + // Write into [1, 1024). + std::string source(1023, 'a'); + buffer_->OnStreamData(1, source, &written_, &error_details_); + // GetReadableRegion should return false because range [0,1) hasn't been + // filled yet. + iovec iov; + EXPECT_FALSE(buffer_->GetReadableRegion(&iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionTillEndOfBlock) { + // Write into [0, kBlockSizeBytes + 1) and then read out [0, 256) + std::string source(kBlockSizeBytes + 1, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[256]; + helper_->Read(dest, 256); + // Get readable region from [256, 1024) + iovec iov; + EXPECT_TRUE(buffer_->GetReadableRegion(&iov)); + EXPECT_EQ(std::string(kBlockSizeBytes - 256, 'a'), IovecToStringPiece(iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, GetReadableRegionTillGap) { + // Write into [0, kBlockSizeBytes - 1) and then read out [0, 256) + std::string source(kBlockSizeBytes - 1, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[256]; + helper_->Read(dest, 256); + // Get readable region from [256, 1023) + iovec iov; + EXPECT_TRUE(buffer_->GetReadableRegion(&iov)); + EXPECT_EQ(std::string(kBlockSizeBytes - 1 - 256, 'a'), + IovecToStringPiece(iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, PeekEmptyBuffer) { + iovec iov; + EXPECT_FALSE(buffer_->PeekRegion(0, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(1, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(100, &iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, PeekSingleBlock) { + std::string source(kBlockSizeBytes, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + + iovec iov; + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source, IovecToStringPiece(iov)); + + // Peeking again gives the same result. + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source, IovecToStringPiece(iov)); + + // Peek at a different offset. + EXPECT_TRUE(buffer_->PeekRegion(100, &iov)); + EXPECT_EQ(absl::string_view(source).substr(100), IovecToStringPiece(iov)); + + // Peeking at or after FirstMissingByte() returns false. + EXPECT_FALSE(buffer_->PeekRegion(kBlockSizeBytes, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(kBlockSizeBytes + 1, &iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, PeekTwoWritesInSingleBlock) { + const size_t length1 = 1024; + std::string source1(length1, 'a'); + buffer_->OnStreamData(0, source1, &written_, &error_details_); + + iovec iov; + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source1, IovecToStringPiece(iov)); + + // The second frame goes into the same block. + const size_t length2 = 800; + std::string source2(length2, 'b'); + buffer_->OnStreamData(length1, source2, &written_, &error_details_); + + EXPECT_TRUE(buffer_->PeekRegion(length1, &iov)); + EXPECT_EQ(source2, IovecToStringPiece(iov)); + + // Peek with an offset inside the first write. + const QuicStreamOffset offset1 = 500; + EXPECT_TRUE(buffer_->PeekRegion(offset1, &iov)); + EXPECT_EQ(absl::string_view(source1).substr(offset1), + IovecToStringPiece(iov).substr(0, length1 - offset1)); + EXPECT_EQ(absl::string_view(source2), + IovecToStringPiece(iov).substr(length1 - offset1)); + + // Peek with an offset inside the second write. + const QuicStreamOffset offset2 = 1500; + EXPECT_TRUE(buffer_->PeekRegion(offset2, &iov)); + EXPECT_EQ(absl::string_view(source2).substr(offset2 - length1), + IovecToStringPiece(iov)); + + // Peeking at or after FirstMissingByte() returns false. + EXPECT_FALSE(buffer_->PeekRegion(length1 + length2, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(length1 + length2 + 1, &iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, PeekBufferWithMultipleBlocks) { + const size_t length1 = 1024; + std::string source1(length1, 'a'); + buffer_->OnStreamData(0, source1, &written_, &error_details_); + + iovec iov; + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source1, IovecToStringPiece(iov)); + + const size_t length2 = kBlockSizeBytes + 2; + std::string source2(length2, 'b'); + buffer_->OnStreamData(length1, source2, &written_, &error_details_); + + // Peek with offset 0 returns the entire block. + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(kBlockSizeBytes, iov.iov_len); + EXPECT_EQ(source1, IovecToStringPiece(iov).substr(0, length1)); + EXPECT_EQ(absl::string_view(source2).substr(0, kBlockSizeBytes - length1), + IovecToStringPiece(iov).substr(length1)); + + EXPECT_TRUE(buffer_->PeekRegion(length1, &iov)); + EXPECT_EQ(absl::string_view(source2).substr(0, kBlockSizeBytes - length1), + IovecToStringPiece(iov)); + + EXPECT_TRUE(buffer_->PeekRegion(kBlockSizeBytes, &iov)); + EXPECT_EQ(absl::string_view(source2).substr(kBlockSizeBytes - length1), + IovecToStringPiece(iov)); + + // Peeking at or after FirstMissingByte() returns false. + EXPECT_FALSE(buffer_->PeekRegion(length1 + length2, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(length1 + length2 + 1, &iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, PeekAfterConsumed) { + std::string source1(kBlockSizeBytes, 'a'); + buffer_->OnStreamData(0, source1, &written_, &error_details_); + + iovec iov; + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source1, IovecToStringPiece(iov)); + + // Consume some data. + EXPECT_TRUE(buffer_->MarkConsumed(1024)); + + // Peeking into consumed data fails. + EXPECT_FALSE(buffer_->PeekRegion(0, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(512, &iov)); + + EXPECT_TRUE(buffer_->PeekRegion(1024, &iov)); + EXPECT_EQ(absl::string_view(source1).substr(1024), IovecToStringPiece(iov)); + + EXPECT_TRUE(buffer_->PeekRegion(1500, &iov)); + EXPECT_EQ(absl::string_view(source1).substr(1500), IovecToStringPiece(iov)); + + // Consume rest of block. + EXPECT_TRUE(buffer_->MarkConsumed(kBlockSizeBytes - 1024)); + + // Read new data. + std::string source2(300, 'b'); + buffer_->OnStreamData(kBlockSizeBytes, source2, &written_, &error_details_); + + // Peek into new data. + EXPECT_TRUE(buffer_->PeekRegion(kBlockSizeBytes, &iov)); + EXPECT_EQ(source2, IovecToStringPiece(iov)); + + EXPECT_TRUE(buffer_->PeekRegion(kBlockSizeBytes + 128, &iov)); + EXPECT_EQ(absl::string_view(source2).substr(128), IovecToStringPiece(iov)); + + // Peeking into consumed data still fails. + EXPECT_FALSE(buffer_->PeekRegion(0, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(512, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(1024, &iov)); + EXPECT_FALSE(buffer_->PeekRegion(1500, &iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, PeekContinously) { + std::string source1(kBlockSizeBytes, 'a'); + buffer_->OnStreamData(0, source1, &written_, &error_details_); + + iovec iov; + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source1, IovecToStringPiece(iov)); + + std::string source2(kBlockSizeBytes, 'b'); + buffer_->OnStreamData(kBlockSizeBytes, source2, &written_, &error_details_); + + EXPECT_TRUE(buffer_->PeekRegion(kBlockSizeBytes, &iov)); + EXPECT_EQ(source2, IovecToStringPiece(iov)); + + // First block is still there. + EXPECT_TRUE(buffer_->PeekRegion(0, &iov)); + EXPECT_EQ(source1, IovecToStringPiece(iov)); +} + +TEST_F(QuicStreamSequencerBufferTest, MarkConsumedInOneBlock) { + // Write into [0, 1024) and then read out [0, 256) + std::string source(1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[256]; + helper_->Read(dest, 256); + + EXPECT_TRUE(buffer_->MarkConsumed(512)); + EXPECT_EQ(256u + 512u, buffer_->BytesConsumed()); + EXPECT_EQ(256u, helper_->ReadableBytes()); + buffer_->MarkConsumed(256); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, MarkConsumedNotEnoughBytes) { + // Write into [0, 1024) and then read out [0, 256) + std::string source(1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[256]; + helper_->Read(dest, 256); + + // Consume 1st 512 bytes + EXPECT_TRUE(buffer_->MarkConsumed(512)); + EXPECT_EQ(256u + 512u, buffer_->BytesConsumed()); + EXPECT_EQ(256u, helper_->ReadableBytes()); + // Try to consume one bytes more than available. Should return false. + EXPECT_FALSE(buffer_->MarkConsumed(257)); + EXPECT_EQ(256u + 512u, buffer_->BytesConsumed()); + iovec iov; + EXPECT_TRUE(buffer_->GetReadableRegion(&iov)); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, MarkConsumedAcrossBlock) { + // Write into [0, 2 * kBlockSizeBytes + 1024) and then read out [0, 1024) + std::string source(2 * kBlockSizeBytes + 1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[1024]; + helper_->Read(dest, 1024); + + buffer_->MarkConsumed(2 * kBlockSizeBytes); + EXPECT_EQ(source.size(), buffer_->BytesConsumed()); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, MarkConsumedAcrossEnd) { + // Write into [0, 8.5 * kBlockSizeBytes - 1024) and then read out [0, 1024) + // and then append 1024 + 512 bytes. + std::string source(8.5 * kBlockSizeBytes - 1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[1024]; + helper_->Read(dest, 1024); + source = std::string(1024 + 512, 'b'); + buffer_->OnStreamData(8.5 * kBlockSizeBytes - 1024, source, &written_, + &error_details_); + EXPECT_EQ(1024u, buffer_->BytesConsumed()); + + // Consume to the end of 8th block. + buffer_->MarkConsumed(8 * kBlockSizeBytes - 1024); + EXPECT_EQ(8 * kBlockSizeBytes, buffer_->BytesConsumed()); + // Consume across the physical end of buffer + buffer_->MarkConsumed(0.5 * kBlockSizeBytes + 500); + EXPECT_EQ(max_capacity_bytes_ + 500, buffer_->BytesConsumed()); + EXPECT_EQ(12u, helper_->ReadableBytes()); + // Consume to the logical end of buffer + buffer_->MarkConsumed(12); + EXPECT_EQ(max_capacity_bytes_ + 512, buffer_->BytesConsumed()); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, FlushBufferedFrames) { + // Write into [0, 8.5 * kBlockSizeBytes - 1024) and then read out [0, 1024). + std::string source(max_capacity_bytes_ - 1024, 'a'); + buffer_->OnStreamData(0, source, &written_, &error_details_); + char dest[1024]; + helper_->Read(dest, 1024); + EXPECT_EQ(1024u, buffer_->BytesConsumed()); + // Write [1024, 512) to the physical beginning. + source = std::string(512, 'b'); + buffer_->OnStreamData(max_capacity_bytes_, source, &written_, + &error_details_); + EXPECT_EQ(512u, written_); + EXPECT_EQ(max_capacity_bytes_ - 1024 + 512, buffer_->FlushBufferedFrames()); + EXPECT_EQ(max_capacity_bytes_ + 512, buffer_->BytesConsumed()); + EXPECT_TRUE(buffer_->Empty()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); + // Clear buffer at this point should still preserve BytesConsumed(). + buffer_->Clear(); + EXPECT_EQ(max_capacity_bytes_ + 512, buffer_->BytesConsumed()); + EXPECT_TRUE(helper_->CheckBufferInvariants()); +} + +TEST_F(QuicStreamSequencerBufferTest, TooManyGaps) { + // Make sure max capacity is large enough that it is possible to have more + // than |kMaxNumGapsAllowed| number of gaps. + max_capacity_bytes_ = 3 * kBlockSizeBytes; + // Feed buffer with 1-byte discontiguous frames. e.g. [1,2), [3,4), [5,6)... + for (QuicStreamOffset begin = 1; begin <= max_capacity_bytes_; begin += 2) { + QuicErrorCode rs = + buffer_->OnStreamData(begin, "a", &written_, &error_details_); + + QuicStreamOffset last_straw = 2 * kMaxNumGapsAllowed - 1; + if (begin == last_straw) { + EXPECT_THAT(rs, IsError(QUIC_TOO_MANY_STREAM_DATA_INTERVALS)); + EXPECT_EQ("Too many data intervals received for this stream.", + error_details_); + break; + } + } +} + +class QuicStreamSequencerBufferRandomIOTest + : public QuicStreamSequencerBufferTest { + public: + using OffsetSizePair = std::pair; + + void SetUp() override { + // Test against a larger capacity then above tests. Also make sure the last + // block is partially available to use. + max_capacity_bytes_ = 8.25 * kBlockSizeBytes; + // Stream to be buffered should be larger than the capacity to test wrap + // around. + bytes_to_buffer_ = 2 * max_capacity_bytes_; + Initialize(); + + uint64_t seed = QuicRandom::GetInstance()->RandUint64(); + QUIC_LOG(INFO) << "**** The current seed is " << seed << " ****"; + rng_.set_seed(seed); + } + + // Create an out-of-order source stream with given size to populate + // shuffled_buf_. + void CreateSourceAndShuffle(size_t max_chunk_size_bytes) { + max_chunk_size_bytes_ = max_chunk_size_bytes; + std::unique_ptr chopped_stream( + new OffsetSizePair[bytes_to_buffer_]); + + // Split stream into small chunks with random length. chopped_stream will be + // populated with segmented stream chunks. + size_t start_chopping_offset = 0; + size_t iterations = 0; + while (start_chopping_offset < bytes_to_buffer_) { + size_t max_chunk = std::min( + max_chunk_size_bytes_, bytes_to_buffer_ - start_chopping_offset); + size_t chunk_size = rng_.RandUint64() % max_chunk + 1; + chopped_stream[iterations] = + OffsetSizePair(start_chopping_offset, chunk_size); + start_chopping_offset += chunk_size; + ++iterations; + } + QUICHE_DCHECK(start_chopping_offset == bytes_to_buffer_); + size_t chunk_num = iterations; + + // Randomly change the sequence of in-ordered OffsetSizePairs to make a + // out-of-order array of OffsetSizePairs. + for (int i = chunk_num - 1; i >= 0; --i) { + size_t random_idx = rng_.RandUint64() % (i + 1); + QUIC_DVLOG(1) << "chunk offset " << chopped_stream[random_idx].first + << " size " << chopped_stream[random_idx].second; + shuffled_buf_.push_front(chopped_stream[random_idx]); + chopped_stream[random_idx] = chopped_stream[i]; + } + } + + // Write the currently first chunk of data in the out-of-order stream into + // QuicStreamSequencerBuffer. If current chuck cannot be written into buffer + // because it goes beyond current capacity, move it to the end of + // shuffled_buf_ and write it later. + void WriteNextChunkToBuffer() { + OffsetSizePair& chunk = shuffled_buf_.front(); + QuicStreamOffset offset = chunk.first; + const size_t num_to_write = chunk.second; + std::unique_ptr write_buf{new char[max_chunk_size_bytes_]}; + for (size_t i = 0; i < num_to_write; ++i) { + write_buf[i] = (offset + i) % 256; + } + absl::string_view string_piece_w(write_buf.get(), num_to_write); + auto result = buffer_->OnStreamData(offset, string_piece_w, &written_, + &error_details_); + if (result == QUIC_NO_ERROR) { + shuffled_buf_.pop_front(); + total_bytes_written_ += num_to_write; + } else { + // This chunk offset exceeds window size. + shuffled_buf_.push_back(chunk); + shuffled_buf_.pop_front(); + } + QUIC_DVLOG(1) << " write at offset: " << offset + << " len to write: " << num_to_write + << " write result: " << result + << " left over: " << shuffled_buf_.size(); + } + + protected: + std::list shuffled_buf_; + size_t max_chunk_size_bytes_; + QuicStreamOffset bytes_to_buffer_; + size_t total_bytes_written_ = 0; + size_t total_bytes_read_ = 0; + SimpleRandom rng_; +}; + +TEST_F(QuicStreamSequencerBufferRandomIOTest, RandomWriteAndReadv) { + // Set kMaxReadSize larger than kBlockSizeBytes to test both small and large + // read. + const size_t kMaxReadSize = kBlockSizeBytes * 2; + // kNumReads is larger than 1 to test how multiple read destinations work. + const size_t kNumReads = 2; + // Since write and read operation have equal possibility to be called. Bytes + // to be written into and read out of should roughly the same. + const size_t kMaxWriteSize = kNumReads * kMaxReadSize; + size_t iterations = 0; + + CreateSourceAndShuffle(kMaxWriteSize); + + while ((!shuffled_buf_.empty() || total_bytes_read_ < bytes_to_buffer_) && + iterations <= 2 * bytes_to_buffer_) { + uint8_t next_action = + shuffled_buf_.empty() ? uint8_t{1} : rng_.RandUint64() % 2; + QUIC_DVLOG(1) << "iteration: " << iterations; + switch (next_action) { + case 0: { // write + WriteNextChunkToBuffer(); + ASSERT_TRUE(helper_->CheckBufferInvariants()); + break; + } + case 1: { // readv + std::unique_ptr read_buf{ + new char[kNumReads][kMaxReadSize]}; + iovec dest_iov[kNumReads]; + size_t num_to_read = 0; + for (size_t i = 0; i < kNumReads; ++i) { + dest_iov[i].iov_base = + reinterpret_cast(const_cast(read_buf[i])); + dest_iov[i].iov_len = rng_.RandUint64() % kMaxReadSize; + num_to_read += dest_iov[i].iov_len; + } + size_t actually_read; + EXPECT_THAT(buffer_->Readv(dest_iov, kNumReads, &actually_read, + &error_details_), + IsQuicNoError()); + ASSERT_LE(actually_read, num_to_read); + QUIC_DVLOG(1) << " read from offset: " << total_bytes_read_ + << " size: " << num_to_read + << " actual read: " << actually_read; + for (size_t i = 0; i < actually_read; ++i) { + char ch = (i + total_bytes_read_) % 256; + ASSERT_EQ(ch, GetCharFromIOVecs(i, dest_iov, kNumReads)) + << " at iteration " << iterations; + } + total_bytes_read_ += actually_read; + ASSERT_EQ(total_bytes_read_, buffer_->BytesConsumed()); + ASSERT_TRUE(helper_->CheckBufferInvariants()); + break; + } + } + ++iterations; + ASSERT_LE(total_bytes_read_, total_bytes_written_); + } + EXPECT_LT(iterations, bytes_to_buffer_) << "runaway test"; + EXPECT_LE(bytes_to_buffer_, total_bytes_read_) + << "iterations: " << iterations; + EXPECT_LE(bytes_to_buffer_, total_bytes_written_); +} + +TEST_F(QuicStreamSequencerBufferRandomIOTest, RandomWriteAndConsumeInPlace) { + // The value 4 is chosen such that the max write size is no larger than the + // maximum buffer capacity. + const size_t kMaxNumReads = 4; + // Adjust write amount be roughly equal to that GetReadableRegions() can get. + const size_t kMaxWriteSize = kMaxNumReads * kBlockSizeBytes; + ASSERT_LE(kMaxWriteSize, max_capacity_bytes_); + size_t iterations = 0; + + CreateSourceAndShuffle(kMaxWriteSize); + + while ((!shuffled_buf_.empty() || total_bytes_read_ < bytes_to_buffer_) && + iterations <= 2 * bytes_to_buffer_) { + uint8_t next_action = + shuffled_buf_.empty() ? uint8_t{1} : rng_.RandUint64() % 2; + QUIC_DVLOG(1) << "iteration: " << iterations; + switch (next_action) { + case 0: { // write + WriteNextChunkToBuffer(); + ASSERT_TRUE(helper_->CheckBufferInvariants()); + break; + } + case 1: { // GetReadableRegions and then MarkConsumed + size_t num_read = rng_.RandUint64() % kMaxNumReads + 1; + iovec dest_iov[kMaxNumReads]; + ASSERT_TRUE(helper_->CheckBufferInvariants()); + size_t actually_num_read = + buffer_->GetReadableRegions(dest_iov, num_read); + ASSERT_LE(actually_num_read, num_read); + size_t avail_bytes = 0; + for (size_t i = 0; i < actually_num_read; ++i) { + avail_bytes += dest_iov[i].iov_len; + } + // process random number of bytes (check the value of each byte). + size_t bytes_to_process = rng_.RandUint64() % (avail_bytes + 1); + size_t bytes_processed = 0; + for (size_t i = 0; i < actually_num_read; ++i) { + size_t bytes_in_block = std::min( + bytes_to_process - bytes_processed, dest_iov[i].iov_len); + if (bytes_in_block == 0) { + break; + } + for (size_t j = 0; j < bytes_in_block; ++j) { + ASSERT_LE(bytes_processed, bytes_to_process); + char char_expected = + (buffer_->BytesConsumed() + bytes_processed) % 256; + ASSERT_EQ(char_expected, + reinterpret_cast(dest_iov[i].iov_base)[j]) + << " at iteration " << iterations; + ++bytes_processed; + } + } + + buffer_->MarkConsumed(bytes_processed); + + QUIC_DVLOG(1) << "iteration " << iterations << ": try to get " + << num_read << " readable regions, actually get " + << actually_num_read + << " from offset: " << total_bytes_read_ + << "\nprocesse bytes: " << bytes_processed; + total_bytes_read_ += bytes_processed; + ASSERT_EQ(total_bytes_read_, buffer_->BytesConsumed()); + ASSERT_TRUE(helper_->CheckBufferInvariants()); + break; + } + } + ++iterations; + ASSERT_LE(total_bytes_read_, total_bytes_written_); + } + EXPECT_LT(iterations, bytes_to_buffer_) << "runaway test"; + EXPECT_LE(bytes_to_buffer_, total_bytes_read_) + << "iterations: " << iterations; + EXPECT_LE(bytes_to_buffer_, total_bytes_written_); +} + +TEST_F(QuicStreamSequencerBufferTest, GrowBlockSizeOnDemand) { + max_capacity_bytes_ = 1024 * kBlockSizeBytes; + std::string source_of_one_block(kBlockSizeBytes, 'a'); + Initialize(); + + ASSERT_EQ(helper_->current_blocks_count(), 0u); + + // A minimum of 8 blocks are allocated + buffer_->OnStreamData(0, source_of_one_block, &written_, &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 8u); + + // Number of blocks doesn't grow if the data is within the capacity. + buffer_->OnStreamData(kBlockSizeBytes * 7, source_of_one_block, &written_, + &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 8u); + + // Number of blocks grows by a factor of 4 normally. + buffer_->OnStreamData(kBlockSizeBytes * 8, "a", &written_, &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 32u); + + // Number of blocks grow to the demanded size of 140 instead of 128 since + // that's not enough. + buffer_->OnStreamData(kBlockSizeBytes * 139, source_of_one_block, &written_, + &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 140u); + + // Number of blocks grows by a factor of 4 normally. + buffer_->OnStreamData(kBlockSizeBytes * 140, source_of_one_block, &written_, + &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 560u); + + // max_capacity_bytes is reached and number of blocks is capped. + buffer_->OnStreamData(kBlockSizeBytes * 560, source_of_one_block, &written_, + &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 1024u); + + // max_capacity_bytes is reached and number of blocks is capped. + buffer_->OnStreamData(kBlockSizeBytes * 1025, source_of_one_block, &written_, + &error_details_); + ASSERT_EQ(helper_->current_blocks_count(), 1024u); +} + +} // anonymous namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/core/quic_stream_sequencer_test.cc b/quiche/quic/core/quic_stream_sequencer_test.cc new file mode 100644 index 000000000000..3ac6d88ba8d3 --- /dev/null +++ b/quiche/quic/core/quic_stream_sequencer_test.cc @@ -0,0 +1,782 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream_sequencer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_stream_sequencer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::AnyNumber; +using testing::InSequence; + +namespace quic { +namespace test { + +class MockStream : public QuicStreamSequencer::StreamInterface { + public: + MOCK_METHOD(void, OnFinRead, (), (override)); + MOCK_METHOD(void, OnDataAvailable, (), (override)); + MOCK_METHOD(void, OnUnrecoverableError, + (QuicErrorCode error, const std::string& details), (override)); + MOCK_METHOD(void, OnUnrecoverableError, + (QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details), + (override)); + MOCK_METHOD(void, ResetWithError, (QuicResetStreamError error), (override)); + MOCK_METHOD(void, AddBytesConsumed, (QuicByteCount bytes), (override)); + + QuicStreamId id() const override { return 1; } + ParsedQuicVersion version() const override { + return CurrentSupportedVersions()[0]; + } +}; + +namespace { + +static const char kPayload[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + +class QuicStreamSequencerTest : public QuicTest { + public: + void ConsumeData(size_t num_bytes) { + char buffer[1024]; + ASSERT_GT(ABSL_ARRAYSIZE(buffer), num_bytes); + struct iovec iov; + iov.iov_base = buffer; + iov.iov_len = num_bytes; + ASSERT_EQ(num_bytes, sequencer_->Readv(&iov, 1)); + } + + protected: + QuicStreamSequencerTest() + : stream_(), sequencer_(new QuicStreamSequencer(&stream_)) {} + + // Verify that the data in first region match with the expected[0]. + bool VerifyReadableRegion(const std::vector& expected) { + return VerifyReadableRegion(*sequencer_, expected); + } + + // Verify that the data in each of currently readable regions match with each + // item given in |expected|. + bool VerifyReadableRegions(const std::vector& expected) { + return VerifyReadableRegions(*sequencer_, expected); + } + + bool VerifyIovecs(iovec* iovecs, size_t num_iovecs, + const std::vector& expected) { + return VerifyIovecs(*sequencer_, iovecs, num_iovecs, expected); + } + + bool VerifyReadableRegion(const QuicStreamSequencer& sequencer, + const std::vector& expected) { + iovec iovecs[1]; + if (sequencer.GetReadableRegions(iovecs, 1)) { + return (VerifyIovecs(sequencer, iovecs, 1, + std::vector{expected[0]})); + } + return false; + } + + // Verify that the data in each of currently readable regions match with each + // item given in |expected|. + bool VerifyReadableRegions(const QuicStreamSequencer& sequencer, + const std::vector& expected) { + iovec iovecs[5]; + size_t num_iovecs = + sequencer.GetReadableRegions(iovecs, ABSL_ARRAYSIZE(iovecs)); + return VerifyReadableRegion(sequencer, expected) && + VerifyIovecs(sequencer, iovecs, num_iovecs, expected); + } + + bool VerifyIovecs(const QuicStreamSequencer& /*sequencer*/, iovec* iovecs, + size_t num_iovecs, + const std::vector& expected) { + int start_position = 0; + for (size_t i = 0; i < num_iovecs; ++i) { + if (!VerifyIovec(iovecs[i], + expected[0].substr(start_position, iovecs[i].iov_len))) { + return false; + } + start_position += iovecs[i].iov_len; + } + return true; + } + + bool VerifyIovec(const iovec& iovec, absl::string_view expected) { + if (iovec.iov_len != expected.length()) { + QUIC_LOG(ERROR) << "Invalid length: " << iovec.iov_len << " vs " + << expected.length(); + return false; + } + if (memcmp(iovec.iov_base, expected.data(), expected.length()) != 0) { + QUIC_LOG(ERROR) << "Invalid data: " << static_cast(iovec.iov_base) + << " vs " << expected; + return false; + } + return true; + } + + void OnFinFrame(QuicStreamOffset byte_offset, const char* data) { + QuicStreamFrame frame; + frame.stream_id = 1; + frame.offset = byte_offset; + frame.data_buffer = data; + frame.data_length = strlen(data); + frame.fin = true; + sequencer_->OnStreamFrame(frame); + } + + void OnFrame(QuicStreamOffset byte_offset, const char* data) { + QuicStreamFrame frame; + frame.stream_id = 1; + frame.offset = byte_offset; + frame.data_buffer = data; + frame.data_length = strlen(data); + frame.fin = false; + sequencer_->OnStreamFrame(frame); + } + + size_t NumBufferedBytes() { + return QuicStreamSequencerPeer::GetNumBufferedBytes(sequencer_.get()); + } + + testing::StrictMock stream_; + std::unique_ptr sequencer_; +}; + +// TODO(rch): reorder these tests so they build on each other. + +TEST_F(QuicStreamSequencerTest, RejectOldFrame) { + EXPECT_CALL(stream_, AddBytesConsumed(3)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + + OnFrame(0, "abc"); + + EXPECT_EQ(0u, NumBufferedBytes()); + EXPECT_EQ(3u, sequencer_->NumBytesConsumed()); + // Ignore this - it matches a past packet number and we should not see it + // again. + OnFrame(0, "def"); + EXPECT_EQ(0u, NumBufferedBytes()); +} + +TEST_F(QuicStreamSequencerTest, RejectBufferedFrame) { + EXPECT_CALL(stream_, OnDataAvailable()); + + OnFrame(0, "abc"); + EXPECT_EQ(3u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); + + // Ignore this - it matches a buffered frame. + // Right now there's no checking that the payload is consistent. + OnFrame(0, "def"); + EXPECT_EQ(3u, NumBufferedBytes()); +} + +TEST_F(QuicStreamSequencerTest, FullFrameConsumed) { + EXPECT_CALL(stream_, AddBytesConsumed(3)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + + OnFrame(0, "abc"); + EXPECT_EQ(0u, NumBufferedBytes()); + EXPECT_EQ(3u, sequencer_->NumBytesConsumed()); +} + +TEST_F(QuicStreamSequencerTest, BlockedThenFullFrameConsumed) { + sequencer_->SetBlockedUntilFlush(); + + OnFrame(0, "abc"); + EXPECT_EQ(3u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); + + EXPECT_CALL(stream_, AddBytesConsumed(3)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + sequencer_->SetUnblocked(); + EXPECT_EQ(0u, NumBufferedBytes()); + EXPECT_EQ(3u, sequencer_->NumBytesConsumed()); + + EXPECT_CALL(stream_, AddBytesConsumed(3)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + EXPECT_FALSE(sequencer_->IsClosed()); + EXPECT_FALSE(sequencer_->IsAllDataAvailable()); + OnFinFrame(3, "def"); + EXPECT_TRUE(sequencer_->IsClosed()); + EXPECT_TRUE(sequencer_->IsAllDataAvailable()); +} + +TEST_F(QuicStreamSequencerTest, BlockedThenFullFrameAndFinConsumed) { + sequencer_->SetBlockedUntilFlush(); + + OnFinFrame(0, "abc"); + EXPECT_EQ(3u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); + + EXPECT_CALL(stream_, AddBytesConsumed(3)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + EXPECT_FALSE(sequencer_->IsClosed()); + EXPECT_TRUE(sequencer_->IsAllDataAvailable()); + sequencer_->SetUnblocked(); + EXPECT_TRUE(sequencer_->IsClosed()); + EXPECT_EQ(0u, NumBufferedBytes()); + EXPECT_EQ(3u, sequencer_->NumBytesConsumed()); +} + +TEST_F(QuicStreamSequencerTest, EmptyFrame) { + if (!stream_.version().HasIetfQuicFrames()) { + EXPECT_CALL(stream_, + OnUnrecoverableError(QUIC_EMPTY_STREAM_FRAME_NO_FIN, _)); + } + OnFrame(0, ""); + EXPECT_EQ(0u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); +} + +TEST_F(QuicStreamSequencerTest, EmptyFinFrame) { + EXPECT_CALL(stream_, OnDataAvailable()); + OnFinFrame(0, ""); + EXPECT_EQ(0u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); + EXPECT_TRUE(sequencer_->IsAllDataAvailable()); +} + +TEST_F(QuicStreamSequencerTest, PartialFrameConsumed) { + EXPECT_CALL(stream_, AddBytesConsumed(2)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(2); + })); + + OnFrame(0, "abc"); + EXPECT_EQ(1u, NumBufferedBytes()); + EXPECT_EQ(2u, sequencer_->NumBytesConsumed()); +} + +TEST_F(QuicStreamSequencerTest, NextxFrameNotConsumed) { + EXPECT_CALL(stream_, OnDataAvailable()); + + OnFrame(0, "abc"); + EXPECT_EQ(3u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); +} + +TEST_F(QuicStreamSequencerTest, FutureFrameNotProcessed) { + OnFrame(3, "abc"); + EXPECT_EQ(3u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); +} + +TEST_F(QuicStreamSequencerTest, OutOfOrderFrameProcessed) { + // Buffer the first + OnFrame(6, "ghi"); + EXPECT_EQ(3u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(3u, sequencer_->NumBytesBuffered()); + // Buffer the second + OnFrame(3, "def"); + EXPECT_EQ(6u, NumBufferedBytes()); + EXPECT_EQ(0u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(6u, sequencer_->NumBytesBuffered()); + + EXPECT_CALL(stream_, AddBytesConsumed(9)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(9); + })); + + // Now process all of them at once. + OnFrame(0, "abc"); + EXPECT_EQ(9u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(0u, sequencer_->NumBytesBuffered()); + + EXPECT_EQ(0u, NumBufferedBytes()); +} + +TEST_F(QuicStreamSequencerTest, BasicHalfCloseOrdered) { + InSequence s; + + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + EXPECT_CALL(stream_, AddBytesConsumed(3)); + OnFinFrame(0, "abc"); + + EXPECT_EQ(3u, QuicStreamSequencerPeer::GetCloseOffset(sequencer_.get())); +} + +TEST_F(QuicStreamSequencerTest, BasicHalfCloseUnorderedWithFlush) { + OnFinFrame(6, ""); + EXPECT_EQ(6u, QuicStreamSequencerPeer::GetCloseOffset(sequencer_.get())); + + OnFrame(3, "def"); + EXPECT_CALL(stream_, AddBytesConsumed(6)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(6); + })); + EXPECT_FALSE(sequencer_->IsClosed()); + OnFrame(0, "abc"); + EXPECT_TRUE(sequencer_->IsClosed()); +} + +TEST_F(QuicStreamSequencerTest, BasicHalfUnordered) { + OnFinFrame(3, ""); + EXPECT_EQ(3u, QuicStreamSequencerPeer::GetCloseOffset(sequencer_.get())); + + EXPECT_CALL(stream_, AddBytesConsumed(3)); + EXPECT_CALL(stream_, OnDataAvailable()).WillOnce(testing::Invoke([this]() { + ConsumeData(3); + })); + EXPECT_FALSE(sequencer_->IsClosed()); + OnFrame(0, "abc"); + EXPECT_TRUE(sequencer_->IsClosed()); +} + +TEST_F(QuicStreamSequencerTest, TerminateWithReadv) { + char buffer[3]; + + OnFinFrame(3, ""); + EXPECT_EQ(3u, QuicStreamSequencerPeer::GetCloseOffset(sequencer_.get())); + + EXPECT_FALSE(sequencer_->IsClosed()); + + EXPECT_CALL(stream_, OnDataAvailable()); + OnFrame(0, "abc"); + + EXPECT_CALL(stream_, AddBytesConsumed(3)); + iovec iov = {&buffer[0], 3}; + int bytes_read = sequencer_->Readv(&iov, 1); + EXPECT_EQ(3, bytes_read); + EXPECT_TRUE(sequencer_->IsClosed()); +} + +TEST_F(QuicStreamSequencerTest, MultipleOffsets) { + OnFinFrame(3, ""); + EXPECT_EQ(3u, QuicStreamSequencerPeer::GetCloseOffset(sequencer_.get())); + + EXPECT_CALL(stream_, OnUnrecoverableError( + QUIC_STREAM_SEQUENCER_INVALID_STATE, + "Stream 1 received new final offset: 1, which is " + "different from close offset: 3")); + OnFinFrame(1, ""); +} + +class QuicSequencerRandomTest : public QuicStreamSequencerTest { + public: + using Frame = std::pair; + using FrameList = std::vector; + + void CreateFrames() { + int payload_size = ABSL_ARRAYSIZE(kPayload) - 1; + int remaining_payload = payload_size; + while (remaining_payload != 0) { + int size = std::min(OneToN(6), remaining_payload); + int index = payload_size - remaining_payload; + list_.push_back( + std::make_pair(index, std::string(kPayload + index, size))); + remaining_payload -= size; + } + } + + QuicSequencerRandomTest() { + uint64_t seed = QuicRandom::GetInstance()->RandUint64(); + QUIC_LOG(INFO) << "**** The current seed is " << seed << " ****"; + random_.set_seed(seed); + + CreateFrames(); + } + + int OneToN(int n) { return random_.RandUint64() % n + 1; } + + void ReadAvailableData() { + // Read all available data + char output[ABSL_ARRAYSIZE(kPayload) + 1]; + iovec iov; + iov.iov_base = output; + iov.iov_len = ABSL_ARRAYSIZE(output); + int bytes_read = sequencer_->Readv(&iov, 1); + EXPECT_NE(0, bytes_read); + output_.append(output, bytes_read); + } + + std::string output_; + // Data which peek at using GetReadableRegion if we back up. + std::string peeked_; + SimpleRandom random_; + FrameList list_; +}; + +// All frames are processed as soon as we have sequential data. +// Infinite buffering, so all frames are acked right away. +TEST_F(QuicSequencerRandomTest, RandomFramesNoDroppingNoBackup) { + EXPECT_CALL(stream_, OnDataAvailable()) + .Times(AnyNumber()) + .WillRepeatedly( + Invoke(this, &QuicSequencerRandomTest::ReadAvailableData)); + QuicByteCount total_bytes_consumed = 0; + EXPECT_CALL(stream_, AddBytesConsumed(_)) + .Times(AnyNumber()) + .WillRepeatedly( + testing::Invoke([&total_bytes_consumed](QuicByteCount bytes) { + total_bytes_consumed += bytes; + })); + + while (!list_.empty()) { + int index = OneToN(list_.size()) - 1; + QUIC_LOG(ERROR) << "Sending index " << index << " " << list_[index].second; + OnFrame(list_[index].first, list_[index].second.data()); + + list_.erase(list_.begin() + index); + } + + ASSERT_EQ(ABSL_ARRAYSIZE(kPayload) - 1, output_.size()); + EXPECT_EQ(kPayload, output_); + EXPECT_EQ(ABSL_ARRAYSIZE(kPayload) - 1, total_bytes_consumed); +} + +TEST_F(QuicSequencerRandomTest, RandomFramesNoDroppingBackup) { + char buffer[10]; + iovec iov[2]; + iov[0].iov_base = &buffer[0]; + iov[0].iov_len = 5; + iov[1].iov_base = &buffer[5]; + iov[1].iov_len = 5; + + EXPECT_CALL(stream_, OnDataAvailable()).Times(AnyNumber()); + QuicByteCount total_bytes_consumed = 0; + EXPECT_CALL(stream_, AddBytesConsumed(_)) + .Times(AnyNumber()) + .WillRepeatedly( + testing::Invoke([&total_bytes_consumed](QuicByteCount bytes) { + total_bytes_consumed += bytes; + })); + + while (output_.size() != ABSL_ARRAYSIZE(kPayload) - 1) { + if (!list_.empty() && OneToN(2) == 1) { // Send data + int index = OneToN(list_.size()) - 1; + OnFrame(list_[index].first, list_[index].second.data()); + list_.erase(list_.begin() + index); + } else { // Read data + bool has_bytes = sequencer_->HasBytesToRead(); + iovec peek_iov[20]; + int iovs_peeked = sequencer_->GetReadableRegions(peek_iov, 20); + if (has_bytes) { + ASSERT_LT(0, iovs_peeked); + ASSERT_TRUE(sequencer_->GetReadableRegion(peek_iov)); + } else { + ASSERT_EQ(0, iovs_peeked); + ASSERT_FALSE(sequencer_->GetReadableRegion(peek_iov)); + } + int total_bytes_to_peek = ABSL_ARRAYSIZE(buffer); + for (int i = 0; i < iovs_peeked; ++i) { + int bytes_to_peek = + std::min(peek_iov[i].iov_len, total_bytes_to_peek); + peeked_.append(static_cast(peek_iov[i].iov_base), bytes_to_peek); + total_bytes_to_peek -= bytes_to_peek; + if (total_bytes_to_peek == 0) { + break; + } + } + int bytes_read = sequencer_->Readv(iov, 2); + output_.append(buffer, bytes_read); + ASSERT_EQ(output_.size(), peeked_.size()); + } + } + EXPECT_EQ(std::string(kPayload), output_); + EXPECT_EQ(std::string(kPayload), peeked_); + EXPECT_EQ(ABSL_ARRAYSIZE(kPayload) - 1, total_bytes_consumed); +} + +// Same as above, just using a different method for reading. +TEST_F(QuicStreamSequencerTest, MarkConsumed) { + InSequence s; + EXPECT_CALL(stream_, OnDataAvailable()); + + OnFrame(0, "abc"); + OnFrame(3, "def"); + OnFrame(6, "ghi"); + + // abcdefghi buffered. + EXPECT_EQ(9u, sequencer_->NumBytesBuffered()); + + // Peek into the data. + std::vector expected = {"abcdefghi"}; + ASSERT_TRUE(VerifyReadableRegions(expected)); + + // Consume 1 byte. + EXPECT_CALL(stream_, AddBytesConsumed(1)); + sequencer_->MarkConsumed(1); + // Verify data. + std::vector expected2 = {"bcdefghi"}; + ASSERT_TRUE(VerifyReadableRegions(expected2)); + EXPECT_EQ(8u, sequencer_->NumBytesBuffered()); + + // Consume 2 bytes. + EXPECT_CALL(stream_, AddBytesConsumed(2)); + sequencer_->MarkConsumed(2); + // Verify data. + std::vector expected3 = {"defghi"}; + ASSERT_TRUE(VerifyReadableRegions(expected3)); + EXPECT_EQ(6u, sequencer_->NumBytesBuffered()); + + // Consume 5 bytes. + EXPECT_CALL(stream_, AddBytesConsumed(5)); + sequencer_->MarkConsumed(5); + // Verify data. + std::vector expected4{"i"}; + ASSERT_TRUE(VerifyReadableRegions(expected4)); + EXPECT_EQ(1u, sequencer_->NumBytesBuffered()); +} + +TEST_F(QuicStreamSequencerTest, MarkConsumedError) { + EXPECT_CALL(stream_, OnDataAvailable()); + + OnFrame(0, "abc"); + OnFrame(9, "jklmnopqrstuvwxyz"); + + // Peek into the data. Only the first chunk should be readable because of the + // missing data. + std::vector expected{"abc"}; + ASSERT_TRUE(VerifyReadableRegions(expected)); + + // Now, attempt to mark consumed more data than was readable and expect the + // stream to be closed. + EXPECT_QUIC_BUG( + { + EXPECT_CALL(stream_, ResetWithError(QuicResetStreamError::FromInternal( + QUIC_ERROR_PROCESSING_STREAM))); + sequencer_->MarkConsumed(4); + }, + "Invalid argument to MarkConsumed." + " expect to consume: 4, but not enough bytes available."); +} + +TEST_F(QuicStreamSequencerTest, MarkConsumedWithMissingPacket) { + InSequence s; + EXPECT_CALL(stream_, OnDataAvailable()); + + OnFrame(0, "abc"); + OnFrame(3, "def"); + // Missing packet: 6, ghi. + OnFrame(9, "jkl"); + + std::vector expected = {"abcdef"}; + ASSERT_TRUE(VerifyReadableRegions(expected)); + + EXPECT_CALL(stream_, AddBytesConsumed(6)); + sequencer_->MarkConsumed(6); +} + +TEST_F(QuicStreamSequencerTest, Move) { + InSequence s; + EXPECT_CALL(stream_, OnDataAvailable()); + + OnFrame(0, "abc"); + OnFrame(3, "def"); + OnFrame(6, "ghi"); + + // abcdefghi buffered. + EXPECT_EQ(9u, sequencer_->NumBytesBuffered()); + + // Peek into the data. + std::vector expected = {"abcdefghi"}; + ASSERT_TRUE(VerifyReadableRegions(expected)); + + QuicStreamSequencer sequencer2(std::move(*sequencer_)); + ASSERT_TRUE(VerifyReadableRegions(sequencer2, expected)); +} + +TEST_F(QuicStreamSequencerTest, OverlappingFramesReceived) { + // The peer should never send us non-identical stream frames which contain + // overlapping byte ranges - if they do, we close the connection. + QuicStreamId id = 1; + + QuicStreamFrame frame1(id, false, 1, absl::string_view("hello")); + sequencer_->OnStreamFrame(frame1); + + QuicStreamFrame frame2(id, false, 2, absl::string_view("hello")); + EXPECT_CALL(stream_, OnUnrecoverableError(QUIC_OVERLAPPING_STREAM_DATA, _)) + .Times(0); + sequencer_->OnStreamFrame(frame2); +} + +TEST_F(QuicStreamSequencerTest, DataAvailableOnOverlappingFrames) { + QuicStreamId id = 1; + const std::string data(1000, '.'); + + // Received [0, 1000). + QuicStreamFrame frame1(id, false, 0, data); + EXPECT_CALL(stream_, OnDataAvailable()); + sequencer_->OnStreamFrame(frame1); + // Consume [0, 500). + EXPECT_CALL(stream_, AddBytesConsumed(500)); + QuicStreamSequencerTest::ConsumeData(500); + EXPECT_EQ(500u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(500u, sequencer_->NumBytesBuffered()); + + // Received [500, 1500). + QuicStreamFrame frame2(id, false, 500, data); + // Do not call OnDataAvailable as there are readable bytes left in the buffer. + EXPECT_CALL(stream_, OnDataAvailable()).Times(0); + sequencer_->OnStreamFrame(frame2); + // Consume [1000, 1500). + EXPECT_CALL(stream_, AddBytesConsumed(1000)); + QuicStreamSequencerTest::ConsumeData(1000); + EXPECT_EQ(1500u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(0u, sequencer_->NumBytesBuffered()); + + // Received [1498, 1503). + QuicStreamFrame frame3(id, false, 1498, absl::string_view("hello")); + EXPECT_CALL(stream_, OnDataAvailable()); + sequencer_->OnStreamFrame(frame3); + EXPECT_CALL(stream_, AddBytesConsumed(3)); + QuicStreamSequencerTest::ConsumeData(3); + EXPECT_EQ(1503u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(0u, sequencer_->NumBytesBuffered()); + + // Received [1000, 1005). + QuicStreamFrame frame4(id, false, 1000, absl::string_view("hello")); + EXPECT_CALL(stream_, OnDataAvailable()).Times(0); + sequencer_->OnStreamFrame(frame4); + EXPECT_EQ(1503u, sequencer_->NumBytesConsumed()); + EXPECT_EQ(0u, sequencer_->NumBytesBuffered()); +} + +TEST_F(QuicStreamSequencerTest, OnDataAvailableWhenReadableBytesIncrease) { + sequencer_->set_level_triggered(true); + QuicStreamId id = 1; + + // Received [0, 5). + QuicStreamFrame frame1(id, false, 0, "hello"); + EXPECT_CALL(stream_, OnDataAvailable()); + sequencer_->OnStreamFrame(frame1); + EXPECT_EQ(5u, sequencer_->NumBytesBuffered()); + + // Without consuming the buffer bytes, continue receiving [5, 11). + QuicStreamFrame frame2(id, false, 5, " world"); + // OnDataAvailable should still be called because there are more data to read. + EXPECT_CALL(stream_, OnDataAvailable()); + sequencer_->OnStreamFrame(frame2); + EXPECT_EQ(11u, sequencer_->NumBytesBuffered()); + + // Without consuming the buffer bytes, continue receiving [12, 13). + QuicStreamFrame frame3(id, false, 5, "a"); + // OnDataAvailable shouldn't be called becasue there are still only 11 bytes + // available. + EXPECT_CALL(stream_, OnDataAvailable()).Times(0); + sequencer_->OnStreamFrame(frame3); + EXPECT_EQ(11u, sequencer_->NumBytesBuffered()); +} + +TEST_F(QuicStreamSequencerTest, ReadSingleFrame) { + EXPECT_CALL(stream_, OnDataAvailable()); + OnFrame(0u, "abc"); + std::string actual; + EXPECT_CALL(stream_, AddBytesConsumed(3)); + sequencer_->Read(&actual); + EXPECT_EQ("abc", actual); + EXPECT_EQ(0u, sequencer_->NumBytesBuffered()); +} + +TEST_F(QuicStreamSequencerTest, ReadMultipleFramesWithMissingFrame) { + EXPECT_CALL(stream_, OnDataAvailable()); + OnFrame(0u, "abc"); + OnFrame(3u, "def"); + OnFrame(6u, "ghi"); + OnFrame(10u, "xyz"); // Byte 9 is missing. + std::string actual; + EXPECT_CALL(stream_, AddBytesConsumed(9)); + sequencer_->Read(&actual); + EXPECT_EQ("abcdefghi", actual); + EXPECT_EQ(3u, sequencer_->NumBytesBuffered()); +} + +TEST_F(QuicStreamSequencerTest, ReadAndAppendToString) { + EXPECT_CALL(stream_, OnDataAvailable()); + OnFrame(0u, "def"); + OnFrame(3u, "ghi"); + std::string actual = "abc"; + EXPECT_CALL(stream_, AddBytesConsumed(6)); + sequencer_->Read(&actual); + EXPECT_EQ("abcdefghi", actual); + EXPECT_EQ(0u, sequencer_->NumBytesBuffered()); +} + +TEST_F(QuicStreamSequencerTest, StopReading) { + EXPECT_CALL(stream_, OnDataAvailable()).Times(0); + EXPECT_CALL(stream_, OnFinRead()); + + EXPECT_CALL(stream_, AddBytesConsumed(0)); + sequencer_->StopReading(); + + EXPECT_CALL(stream_, AddBytesConsumed(3)); + OnFrame(0u, "abc"); + EXPECT_CALL(stream_, AddBytesConsumed(3)); + OnFrame(3u, "def"); + EXPECT_CALL(stream_, AddBytesConsumed(3)); + OnFinFrame(6u, "ghi"); +} + +TEST_F(QuicStreamSequencerTest, StopReadingWithLevelTriggered) { + EXPECT_CALL(stream_, AddBytesConsumed(0)); + EXPECT_CALL(stream_, AddBytesConsumed(3)).Times(3); + EXPECT_CALL(stream_, OnDataAvailable()).Times(0); + EXPECT_CALL(stream_, OnFinRead()); + + sequencer_->set_level_triggered(true); + sequencer_->StopReading(); + + OnFrame(0u, "abc"); + OnFrame(3u, "def"); + OnFinFrame(6u, "ghi"); +} + +// Regression test for https://crbug.com/992486. +TEST_F(QuicStreamSequencerTest, CorruptFinFrames) { + EXPECT_CALL(stream_, OnUnrecoverableError( + QUIC_STREAM_SEQUENCER_INVALID_STATE, + "Stream 1 received new final offset: 1, which is " + "different from close offset: 2")); + + OnFinFrame(2u, ""); + OnFinFrame(0u, "a"); + EXPECT_FALSE(sequencer_->HasBytesToRead()); +} + +// Regression test for crbug.com/1015693 +TEST_F(QuicStreamSequencerTest, ReceiveFinLessThanHighestOffset) { + EXPECT_CALL(stream_, OnDataAvailable()).Times(1); + EXPECT_CALL(stream_, OnUnrecoverableError( + QUIC_STREAM_SEQUENCER_INVALID_STATE, + "Stream 1 received fin with offset: 0, which " + "reduces current highest offset: 3")); + OnFrame(0u, "abc"); + OnFinFrame(0u, ""); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_stream_test.cc b/quiche/quic/core/quic_stream_test.cc new file mode 100644 index 000000000000..7afbb1954833 --- /dev/null +++ b/quiche/quic/core/quic_stream_test.cc @@ -0,0 +1,1752 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_stream.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/frames/quic_rst_stream_frame.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/quic_write_blocked_list.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_stream_sequencer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/quiche_mem_slice_storage.h" + +using testing::_; +using testing::AnyNumber; +using testing::AtLeast; +using testing::InSequence; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +const char kData1[] = "FooAndBar"; +const char kData2[] = "EepAndBaz"; +const QuicByteCount kDataLen = 9; +const uint8_t kPacket0ByteConnectionId = 0; +const uint8_t kPacket8ByteConnectionId = 8; + +class TestStream : public QuicStream { + public: + TestStream(QuicStreamId id, QuicSession* session, StreamType type) + : QuicStream(id, session, /*is_static=*/false, type) { + sequencer()->set_level_triggered(true); + } + + TestStream(PendingStream* pending, QuicSession* session, bool is_static) + : QuicStream(pending, session, is_static) {} + + MOCK_METHOD(void, OnDataAvailable, (), (override)); + + MOCK_METHOD(void, OnCanWriteNewData, (), (override)); + + MOCK_METHOD(void, OnWriteSideInDataRecvdState, (), (override)); + + using QuicStream::CanWriteNewData; + using QuicStream::CanWriteNewDataAfterData; + using QuicStream::CloseWriteSide; + using QuicStream::fin_buffered; + using QuicStream::MaybeSendStopSending; + using QuicStream::OnClose; + using QuicStream::WriteMemSlices; + using QuicStream::WriteOrBufferData; + + private: + std::string data_; +}; + +class QuicStreamTest : public QuicTestWithParam { + public: + QuicStreamTest() + : zero_(QuicTime::Delta::Zero()), + supported_versions_(AllSupportedVersions()) {} + + void Initialize(Perspective perspective = Perspective::IS_SERVER) { + ParsedQuicVersionVector version_vector; + version_vector.push_back(GetParam()); + connection_ = new StrictMock( + &helper_, &alarm_factory_, perspective, version_vector); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + session_ = std::make_unique>(connection_); + session_->Initialize(); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_->config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(session_->config(), 10); + session_->OnConfigNegotiated(); + + stream_ = new StrictMock(kTestStreamId, session_.get(), + BIDIRECTIONAL); + EXPECT_NE(nullptr, stream_); + EXPECT_CALL(*session_, ShouldKeepConnectionAlive()) + .WillRepeatedly(Return(true)); + // session_ now owns stream_. + session_->ActivateStream(absl::WrapUnique(stream_)); + // Ignore resetting when session_ is terminated. + EXPECT_CALL(*session_, MaybeSendStopSendingFrame(kTestStreamId, _)) + .Times(AnyNumber()); + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(kTestStreamId, _, _)) + .Times(AnyNumber()); + write_blocked_list_ = + QuicSessionPeer::GetWriteBlockedStreams(session_.get()); + } + + bool fin_sent() { return stream_->fin_sent(); } + bool rst_sent() { return stream_->rst_sent(); } + + bool HasWriteBlockedStreams() { + return write_blocked_list_->HasWriteBlockedSpecialStream() || + write_blocked_list_->HasWriteBlockedDataStreams(); + } + + QuicConsumedData CloseStreamOnWriteError( + QuicStreamId id, QuicByteCount /*write_length*/, + QuicStreamOffset /*offset*/, StreamSendingState /*state*/, + TransmissionType /*type*/, absl::optional /*level*/) { + session_->ResetStream(id, QUIC_STREAM_CANCELLED); + return QuicConsumedData(1, false); + } + + bool ClearResetStreamFrame(const QuicFrame& frame) { + EXPECT_EQ(RST_STREAM_FRAME, frame.type); + DeleteFrame(&const_cast(frame)); + return true; + } + + bool ClearStopSendingFrame(const QuicFrame& frame) { + EXPECT_EQ(STOP_SENDING_FRAME, frame.type); + DeleteFrame(&const_cast(frame)); + return true; + } + + protected: + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnection* connection_; + std::unique_ptr session_; + StrictMock* stream_; + QuicWriteBlockedListInterface* write_blocked_list_; + QuicTime::Delta zero_; + ParsedQuicVersionVector supported_versions_; + QuicStreamId kTestStreamId = GetNthClientInitiatedBidirectionalStreamId( + GetParam().transport_version, 1); + const QuicStreamId kTestPendingStreamId = + GetNthClientInitiatedUnidirectionalStreamId(GetParam().transport_version, + 1); +}; + +INSTANTIATE_TEST_SUITE_P(QuicStreamTests, QuicStreamTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +using PendingStreamTest = QuicStreamTest; + +INSTANTIATE_TEST_SUITE_P(PendingStreamTests, PendingStreamTest, + ::testing::ValuesIn(CurrentSupportedHttp3Versions()), + ::testing::PrintToStringParamName()); + +TEST_P(PendingStreamTest, PendingStreamStaticness) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); + TestStream stream(&pending, session_.get(), false); + EXPECT_FALSE(stream.is_static()); + + PendingStream pending2(kTestPendingStreamId + 4, session_.get()); + TestStream stream2(&pending2, session_.get(), true); + EXPECT_TRUE(stream2.is_static()); +} + +TEST_P(PendingStreamTest, PendingStreamType) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); + TestStream stream(&pending, session_.get(), false); + EXPECT_EQ(stream.type(), READ_UNIDIRECTIONAL); +} + +TEST_P(PendingStreamTest, PendingStreamTypeOnClient) { + Initialize(Perspective::IS_CLIENT); + + QuicStreamId server_initiated_pending_stream_id = + GetNthServerInitiatedUnidirectionalStreamId(session_->transport_version(), + 1); + PendingStream pending(server_initiated_pending_stream_id, session_.get()); + TestStream stream(&pending, session_.get(), false); + EXPECT_EQ(stream.type(), READ_UNIDIRECTIONAL); +} + +TEST_P(PendingStreamTest, PendingStreamTooMuchData) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); + // Receive a stream frame that violates flow control: the byte offset is + // higher than the receive window offset. + QuicStreamFrame frame(kTestPendingStreamId, false, + kInitialSessionFlowControlWindowForTest + 1, "."); + + // Stream should not accept the frame, and the connection should be closed. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + pending.OnStreamFrame(frame); +} + +TEST_P(PendingStreamTest, PendingStreamTooMuchDataInRstStream) { + Initialize(); + + PendingStream pending1(kTestPendingStreamId, session_.get()); + // Receive a rst stream frame that violates flow control: the byte offset is + // higher than the receive window offset. + QuicRstStreamFrame frame1(kInvalidControlFrameId, kTestPendingStreamId, + QUIC_STREAM_CANCELLED, + kInitialSessionFlowControlWindowForTest + 1); + + // Pending stream should not accept the frame, and the connection should be + // closed. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + pending1.OnRstStreamFrame(frame1); + + QuicStreamId bidirection_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + session_->transport_version(), Perspective::IS_CLIENT); + PendingStream pending2(bidirection_stream_id, session_.get()); + // Receive a rst stream frame that violates flow control: the byte offset is + // higher than the receive window offset. + QuicRstStreamFrame frame2(kInvalidControlFrameId, bidirection_stream_id, + QUIC_STREAM_CANCELLED, + kInitialSessionFlowControlWindowForTest + 1); + // Bidirectional Pending stream should not accept the frame, and the + // connection should be closed. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + pending2.OnRstStreamFrame(frame2); +} + +TEST_P(PendingStreamTest, PendingStreamRstStream) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); + QuicStreamOffset final_byte_offset = 7; + QuicRstStreamFrame frame(kInvalidControlFrameId, kTestPendingStreamId, + QUIC_STREAM_CANCELLED, final_byte_offset); + + // Pending stream should accept the frame and not close the connection. + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + pending.OnRstStreamFrame(frame); +} + +TEST_P(PendingStreamTest, PendingStreamWindowUpdate) { + Initialize(); + + QuicStreamId bidirection_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + session_->transport_version(), Perspective::IS_CLIENT); + PendingStream pending(bidirection_stream_id, session_.get()); + QuicWindowUpdateFrame frame(kInvalidControlFrameId, bidirection_stream_id, + kDefaultFlowControlSendWindow * 2); + pending.OnWindowUpdateFrame(frame); + TestStream stream(&pending, session_.get(), false); + + EXPECT_EQ(QuicStreamPeer::SendWindowSize(&stream), + kDefaultFlowControlSendWindow * 2); +} + +TEST_P(PendingStreamTest, PendingStreamStopSending) { + Initialize(); + + QuicStreamId bidirection_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + session_->transport_version(), Perspective::IS_CLIENT); + PendingStream pending(bidirection_stream_id, session_.get()); + QuicResetStreamError error = + QuicResetStreamError::FromInternal(QUIC_STREAM_INTERNAL_ERROR); + pending.OnStopSending(error); + EXPECT_TRUE(pending.GetStopSendingErrorCode()); + auto actual_error = *pending.GetStopSendingErrorCode(); + EXPECT_EQ(actual_error, error); +} + +TEST_P(PendingStreamTest, FromPendingStream) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); + + QuicStreamFrame frame(kTestPendingStreamId, false, 2, "."); + pending.OnStreamFrame(frame); + pending.OnStreamFrame(frame); + QuicStreamFrame frame2(kTestPendingStreamId, true, 3, "."); + pending.OnStreamFrame(frame2); + + TestStream stream(&pending, session_.get(), false); + EXPECT_EQ(3, stream.num_frames_received()); + EXPECT_EQ(3u, stream.stream_bytes_read()); + EXPECT_EQ(1, stream.num_duplicate_frames_received()); + EXPECT_EQ(true, stream.fin_received()); + EXPECT_EQ(frame2.offset + 1, stream.highest_received_byte_offset()); + EXPECT_EQ(frame2.offset + 1, + session_->flow_controller()->highest_received_byte_offset()); +} + +TEST_P(PendingStreamTest, FromPendingStreamThenData) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); + + QuicStreamFrame frame(kTestPendingStreamId, false, 2, "."); + pending.OnStreamFrame(frame); + + auto stream = new TestStream(&pending, session_.get(), false); + session_->ActivateStream(absl::WrapUnique(stream)); + + QuicStreamFrame frame2(kTestPendingStreamId, true, 3, "."); + stream->OnStreamFrame(frame2); + + EXPECT_EQ(2, stream->num_frames_received()); + EXPECT_EQ(2u, stream->stream_bytes_read()); + EXPECT_EQ(true, stream->fin_received()); + EXPECT_EQ(frame2.offset + 1, stream->highest_received_byte_offset()); + EXPECT_EQ(frame2.offset + 1, + session_->flow_controller()->highest_received_byte_offset()); +} + +TEST_P(QuicStreamTest, WriteAllData) { + Initialize(); + + QuicByteCount length = + 1 + QuicPacketCreator::StreamFramePacketOverhead( + connection_->transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, + !kIncludeDiversificationNonce, PACKET_4BYTE_PACKET_NUMBER, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0u); + connection_->SetMaxPacketLength(length); + + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_FALSE(HasWriteBlockedStreams()); +} + +TEST_P(QuicStreamTest, NoBlockingIfNoDataOrFin) { + Initialize(); + + // Write no data and no fin. If we consume nothing we should not be write + // blocked. + EXPECT_QUIC_BUG( + stream_->WriteOrBufferData(absl::string_view(), false, nullptr), ""); + EXPECT_FALSE(HasWriteBlockedStreams()); +} + +TEST_P(QuicStreamTest, BlockIfOnlySomeDataConsumed) { + Initialize(); + + // Write some data and no fin. If we consume some but not all of the data, + // we should be write blocked a not all the data was consumed. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 1u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(absl::string_view(kData1, 2), false, nullptr); + EXPECT_TRUE(session_->HasUnackedStreamData()); + ASSERT_EQ(1u, write_blocked_list_->NumBlockedStreams()); + EXPECT_EQ(1u, stream_->BufferedDataBytes()); +} + +TEST_P(QuicStreamTest, BlockIfFinNotConsumedWithData) { + Initialize(); + + // Write some data and no fin. If we consume all the data but not the fin, + // we should be write blocked because the fin was not consumed. + // (This should never actually happen as the fin should be sent out with the + // last data) + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 2u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(absl::string_view(kData1, 2), true, nullptr); + EXPECT_TRUE(session_->HasUnackedStreamData()); + ASSERT_EQ(1u, write_blocked_list_->NumBlockedStreams()); +} + +TEST_P(QuicStreamTest, BlockIfSoloFinNotConsumed) { + Initialize(); + + // Write no data and a fin. If we consume nothing we should be write blocked, + // as the fin was not consumed. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(Return(QuicConsumedData(0, false))); + stream_->WriteOrBufferData(absl::string_view(), true, nullptr); + ASSERT_EQ(1u, write_blocked_list_->NumBlockedStreams()); +} + +TEST_P(QuicStreamTest, CloseOnPartialWrite) { + Initialize(); + + // Write some data and no fin. However, while writing the data + // close the stream and verify that MarkConnectionLevelWriteBlocked does not + // crash with an unknown stream. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(Invoke(this, &QuicStreamTest::CloseStreamOnWriteError)); + stream_->WriteOrBufferData(absl::string_view(kData1, 2), false, nullptr); + ASSERT_EQ(0u, write_blocked_list_->NumBlockedStreams()); +} + +TEST_P(QuicStreamTest, WriteOrBufferData) { + Initialize(); + + EXPECT_FALSE(HasWriteBlockedStreams()); + QuicByteCount length = + 1 + QuicPacketCreator::StreamFramePacketOverhead( + connection_->transport_version(), kPacket8ByteConnectionId, + kPacket0ByteConnectionId, !kIncludeVersion, + !kIncludeDiversificationNonce, PACKET_4BYTE_PACKET_NUMBER, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, + quiche::VARIABLE_LENGTH_INTEGER_LENGTH_0, 0u); + connection_->SetMaxPacketLength(length); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), kDataLen - 1, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(kData1, false, nullptr); + + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(1u, stream_->BufferedDataBytes()); + EXPECT_TRUE(HasWriteBlockedStreams()); + + // Queue a bytes_consumed write. + stream_->WriteOrBufferData(kData2, false, nullptr); + EXPECT_EQ(10u, stream_->BufferedDataBytes()); + // Make sure we get the tail of the first write followed by the bytes_consumed + InSequence s; + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), kDataLen - 1, kDataLen - 1, + NO_FIN, NOT_RETRANSMISSION, absl::nullopt); + })); + EXPECT_CALL(*stream_, OnCanWriteNewData()); + stream_->OnCanWrite(); + EXPECT_TRUE(session_->HasUnackedStreamData()); + + // And finally the end of the bytes_consumed. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 2u, 2 * kDataLen - 2, + NO_FIN, NOT_RETRANSMISSION, absl::nullopt); + })); + EXPECT_CALL(*stream_, OnCanWriteNewData()); + stream_->OnCanWrite(); + EXPECT_TRUE(session_->HasUnackedStreamData()); +} + +TEST_P(QuicStreamTest, WriteOrBufferDataReachStreamLimit) { + Initialize(); + std::string data("aaaaa"); + QuicStreamPeer::SetStreamBytesWritten(kMaxStreamLength - data.length(), + stream_); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->WriteOrBufferData(data, false, nullptr); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)); + stream_->WriteOrBufferData("a", false, nullptr); + }, + "Write too many data via stream"); +} + +TEST_P(QuicStreamTest, ConnectionCloseAfterStreamClose) { + Initialize(); + + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + if (VersionHasIetfQuicFrames(session_->transport_version())) { + // Create and inject a STOP SENDING frame to complete the close + // of the stream. This is only needed for version 99/IETF QUIC. + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED); + session_->OnStopSendingFrame(stop_sending); + } + EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_STREAM_CANCELLED)); + EXPECT_THAT(stream_->connection_error(), IsQuicNoError()); + stream_->OnConnectionClosed(QUIC_INTERNAL_ERROR, + ConnectionCloseSource::FROM_SELF); + EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_STREAM_CANCELLED)); + EXPECT_THAT(stream_->connection_error(), IsQuicNoError()); +} + +TEST_P(QuicStreamTest, RstAlwaysSentIfNoFinSent) { + // For flow control accounting, a stream must send either a FIN or a RST frame + // before termination. + // Test that if no FIN has been sent, we send a RST. + + Initialize(); + EXPECT_FALSE(fin_sent()); + EXPECT_FALSE(rst_sent()); + + // Write some data, with no FIN. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 1u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(absl::string_view(kData1, 1), false, nullptr); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_FALSE(fin_sent()); + EXPECT_FALSE(rst_sent()); + + // Now close the stream, and expect that we send a RST. + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(kTestStreamId, _, _)); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + if (VersionHasIetfQuicFrames(session_->transport_version())) { + // Create and inject a STOP SENDING frame to complete the close + // of the stream. This is only needed for version 99/IETF QUIC. + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED); + session_->OnStopSendingFrame(stop_sending); + } + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_FALSE(fin_sent()); + EXPECT_TRUE(rst_sent()); +} + +TEST_P(QuicStreamTest, RstNotSentIfFinSent) { + // For flow control accounting, a stream must send either a FIN or a RST frame + // before termination. + // Test that if a FIN has been sent, we don't also send a RST. + + Initialize(); + EXPECT_FALSE(fin_sent()); + EXPECT_FALSE(rst_sent()); + + // Write some data, with FIN. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 1u, 0u, FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(absl::string_view(kData1, 1), true, nullptr); + EXPECT_TRUE(fin_sent()); + EXPECT_FALSE(rst_sent()); + + // Now close the stream, and expect that we do not send a RST. + QuicStreamPeer::CloseReadSide(stream_); + stream_->CloseWriteSide(); + EXPECT_TRUE(fin_sent()); + EXPECT_FALSE(rst_sent()); +} + +TEST_P(QuicStreamTest, OnlySendOneRst) { + // For flow control accounting, a stream must send either a FIN or a RST frame + // before termination. + // Test that if a stream sends a RST, it doesn't send an additional RST during + // OnClose() (this shouldn't be harmful, but we shouldn't do it anyway...) + + Initialize(); + EXPECT_FALSE(fin_sent()); + EXPECT_FALSE(rst_sent()); + + // Reset the stream. + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(kTestStreamId, _, _)).Times(1); + stream_->Reset(QUIC_STREAM_CANCELLED); + EXPECT_FALSE(fin_sent()); + EXPECT_TRUE(rst_sent()); + + // Now close the stream (any further resets being sent would break the + // expectation above). + QuicStreamPeer::CloseReadSide(stream_); + stream_->CloseWriteSide(); + EXPECT_FALSE(fin_sent()); + EXPECT_TRUE(rst_sent()); +} + +TEST_P(QuicStreamTest, StreamFlowControlMultipleWindowUpdates) { + Initialize(); + + // If we receive multiple WINDOW_UPDATES (potentially out of order), then we + // want to make sure we latch the largest offset we see. + + // Initially should be default. + EXPECT_EQ(kMinimumFlowControlSendWindow, + QuicStreamPeer::SendWindowOffset(stream_)); + + // Check a single WINDOW_UPDATE results in correct offset. + QuicWindowUpdateFrame window_update_1(kInvalidControlFrameId, stream_->id(), + kMinimumFlowControlSendWindow + 5); + stream_->OnWindowUpdateFrame(window_update_1); + EXPECT_EQ(window_update_1.max_data, + QuicStreamPeer::SendWindowOffset(stream_)); + + // Now send a few more WINDOW_UPDATES and make sure that only the largest is + // remembered. + QuicWindowUpdateFrame window_update_2(kInvalidControlFrameId, stream_->id(), + 1); + QuicWindowUpdateFrame window_update_3(kInvalidControlFrameId, stream_->id(), + kMinimumFlowControlSendWindow + 10); + QuicWindowUpdateFrame window_update_4(kInvalidControlFrameId, stream_->id(), + 5678); + stream_->OnWindowUpdateFrame(window_update_2); + stream_->OnWindowUpdateFrame(window_update_3); + stream_->OnWindowUpdateFrame(window_update_4); + EXPECT_EQ(window_update_3.max_data, + QuicStreamPeer::SendWindowOffset(stream_)); +} + +TEST_P(QuicStreamTest, FrameStats) { + Initialize(); + + EXPECT_EQ(0, stream_->num_frames_received()); + EXPECT_EQ(0, stream_->num_duplicate_frames_received()); + QuicStreamFrame frame(stream_->id(), false, 0, "."); + EXPECT_CALL(*stream_, OnDataAvailable()).Times(2); + stream_->OnStreamFrame(frame); + EXPECT_EQ(1, stream_->num_frames_received()); + EXPECT_EQ(0, stream_->num_duplicate_frames_received()); + stream_->OnStreamFrame(frame); + EXPECT_EQ(2, stream_->num_frames_received()); + EXPECT_EQ(1, stream_->num_duplicate_frames_received()); + QuicStreamFrame frame2(stream_->id(), false, 1, "abc"); + stream_->OnStreamFrame(frame2); +} + +// Verify that when we receive a packet which violates flow control (i.e. sends +// too much data on the stream) that the stream sequencer never sees this frame, +// as we check for violation and close the connection early. +TEST_P(QuicStreamTest, StreamSequencerNeverSeesPacketsViolatingFlowControl) { + Initialize(); + + // Receive a stream frame that violates flow control: the byte offset is + // higher than the receive window offset. + QuicStreamFrame frame(stream_->id(), false, + kInitialSessionFlowControlWindowForTest + 1, "."); + EXPECT_GT(frame.offset, QuicStreamPeer::ReceiveWindowOffset(stream_)); + + // Stream should not accept the frame, and the connection should be closed. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + stream_->OnStreamFrame(frame); +} + +// Verify that after the consumer calls StopReading(), the stream still sends +// flow control updates. +TEST_P(QuicStreamTest, StopReadingSendsFlowControl) { + Initialize(); + + stream_->StopReading(); + + // Connection should not get terminated due to flow control errors. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)) + .Times(0); + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(AtLeast(1)) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + + std::string data(1000, 'x'); + for (QuicStreamOffset offset = 0; + offset < 2 * kInitialStreamFlowControlWindowForTest; + offset += data.length()) { + QuicStreamFrame frame(stream_->id(), false, offset, data); + stream_->OnStreamFrame(frame); + } + EXPECT_LT(kInitialStreamFlowControlWindowForTest, + QuicStreamPeer::ReceiveWindowOffset(stream_)); +} + +TEST_P(QuicStreamTest, FinalByteOffsetFromFin) { + Initialize(); + + EXPECT_FALSE(stream_->HasReceivedFinalOffset()); + + QuicStreamFrame stream_frame_no_fin(stream_->id(), false, 1234, "."); + stream_->OnStreamFrame(stream_frame_no_fin); + EXPECT_FALSE(stream_->HasReceivedFinalOffset()); + + QuicStreamFrame stream_frame_with_fin(stream_->id(), true, 1234, "."); + stream_->OnStreamFrame(stream_frame_with_fin); + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); +} + +TEST_P(QuicStreamTest, FinalByteOffsetFromRst) { + Initialize(); + + EXPECT_FALSE(stream_->HasReceivedFinalOffset()); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); +} + +TEST_P(QuicStreamTest, InvalidFinalByteOffsetFromRst) { + Initialize(); + + EXPECT_FALSE(stream_->HasReceivedFinalOffset()); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 0xFFFFFFFFFFFF); + // Stream should not accept the frame, and the connection should be closed. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + stream_->OnStreamReset(rst_frame); + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); +} + +TEST_P(QuicStreamTest, FinalByteOffsetFromZeroLengthStreamFrame) { + // When receiving Trailers, an empty stream frame is created with the FIN set, + // and is passed to OnStreamFrame. The Trailers may be sent in advance of + // queued body bytes being sent, and thus the final byte offset may exceed + // current flow control limits. Flow control should only be concerned with + // data that has actually been sent/received, so verify that flow control + // ignores such a stream frame. + Initialize(); + + EXPECT_FALSE(stream_->HasReceivedFinalOffset()); + const QuicStreamOffset kByteOffsetExceedingFlowControlWindow = + kInitialSessionFlowControlWindowForTest + 1; + const QuicStreamOffset current_stream_flow_control_offset = + QuicStreamPeer::ReceiveWindowOffset(stream_); + const QuicStreamOffset current_connection_flow_control_offset = + QuicFlowControllerPeer::ReceiveWindowOffset(session_->flow_controller()); + ASSERT_GT(kByteOffsetExceedingFlowControlWindow, + current_stream_flow_control_offset); + ASSERT_GT(kByteOffsetExceedingFlowControlWindow, + current_connection_flow_control_offset); + QuicStreamFrame zero_length_stream_frame_with_fin( + stream_->id(), /*fin=*/true, kByteOffsetExceedingFlowControlWindow, + absl::string_view()); + EXPECT_EQ(0, zero_length_stream_frame_with_fin.data_length); + + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + stream_->OnStreamFrame(zero_length_stream_frame_with_fin); + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); + + // The flow control receive offset values should not have changed. + EXPECT_EQ(current_stream_flow_control_offset, + QuicStreamPeer::ReceiveWindowOffset(stream_)); + EXPECT_EQ( + current_connection_flow_control_offset, + QuicFlowControllerPeer::ReceiveWindowOffset(session_->flow_controller())); +} + +TEST_P(QuicStreamTest, OnStreamResetOffsetOverflow) { + Initialize(); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, kMaxStreamLength + 1); + EXPECT_CALL(*connection_, CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)); + stream_->OnStreamReset(rst_frame); +} + +TEST_P(QuicStreamTest, OnStreamFrameUpperLimit) { + Initialize(); + + // Modify receive window offset and sequencer buffer total_bytes_read_ to + // avoid flow control violation. + QuicStreamPeer::SetReceiveWindowOffset(stream_, kMaxStreamLength + 5u); + QuicFlowControllerPeer::SetReceiveWindowOffset(session_->flow_controller(), + kMaxStreamLength + 5u); + QuicStreamSequencerPeer::SetFrameBufferTotalBytesRead( + QuicStreamPeer::sequencer(stream_), kMaxStreamLength - 10u); + + EXPECT_CALL(*connection_, CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)) + .Times(0); + QuicStreamFrame stream_frame(stream_->id(), false, kMaxStreamLength - 1, "."); + stream_->OnStreamFrame(stream_frame); + QuicStreamFrame stream_frame2(stream_->id(), true, kMaxStreamLength, ""); + stream_->OnStreamFrame(stream_frame2); +} + +TEST_P(QuicStreamTest, StreamTooLong) { + Initialize(); + QuicStreamFrame stream_frame(stream_->id(), false, kMaxStreamLength, "."); + EXPECT_QUIC_PEER_BUG( + { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)) + .Times(1); + stream_->OnStreamFrame(stream_frame); + }, + absl::StrCat("Receive stream frame on stream ", stream_->id(), + " reaches max stream length")); +} + +TEST_P(QuicStreamTest, SetDrainingIncomingOutgoing) { + // Don't have incoming data consumed. + Initialize(); + + // Incoming data with FIN. + QuicStreamFrame stream_frame_with_fin(stream_->id(), true, 1234, "."); + stream_->OnStreamFrame(stream_frame_with_fin); + // The FIN has been received but not consumed. + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_FALSE(stream_->reading_stopped()); + + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Outgoing data with FIN. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 2u, 0u, FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(absl::string_view(kData1, 2), true, nullptr); + EXPECT_TRUE(stream_->write_side_closed()); + + EXPECT_EQ(1u, QuicSessionPeer::GetNumDrainingStreams(session_.get())); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); +} + +TEST_P(QuicStreamTest, SetDrainingOutgoingIncoming) { + // Don't have incoming data consumed. + Initialize(); + + // Outgoing data with FIN. + EXPECT_CALL(*session_, WritevData(kTestStreamId, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 2u, 0u, FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(absl::string_view(kData1, 2), true, nullptr); + EXPECT_TRUE(stream_->write_side_closed()); + + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Incoming data with FIN. + QuicStreamFrame stream_frame_with_fin(stream_->id(), true, 1234, "."); + stream_->OnStreamFrame(stream_frame_with_fin); + // The FIN has been received but not consumed. + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_FALSE(stream_->reading_stopped()); + + EXPECT_EQ(1u, QuicSessionPeer::GetNumDrainingStreams(session_.get())); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); +} + +TEST_P(QuicStreamTest, EarlyResponseFinHandling) { + // Verify that if the server completes the response before reading the end of + // the request, the received FIN is recorded. + + Initialize(); + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + + // Receive data for the request. + EXPECT_CALL(*stream_, OnDataAvailable()).Times(1); + QuicStreamFrame frame1(stream_->id(), false, 0, "Start"); + stream_->OnStreamFrame(frame1); + // When QuicSimpleServerStream sends the response, it calls + // QuicStream::CloseReadSide() first. + QuicStreamPeer::CloseReadSide(stream_); + // Send data and FIN for the response. + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream_)); + // Receive remaining data and FIN for the request. + QuicStreamFrame frame2(stream_->id(), true, 0, "End"); + stream_->OnStreamFrame(frame2); + EXPECT_TRUE(stream_->fin_received()); + EXPECT_TRUE(stream_->HasReceivedFinalOffset()); +} + +TEST_P(QuicStreamTest, StreamWaitsForAcks) { + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + // Stream is not waiting for acks initially. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + + // Send kData1. + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + QuicByteCount newly_acked_length = 0; + EXPECT_TRUE(stream_->OnStreamFrameAcked(0, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(9u, newly_acked_length); + // Stream is not waiting for acks as all sent data is acked. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + // Send kData2. + stream_->WriteOrBufferData(kData2, false, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Send FIN. + stream_->WriteOrBufferData("", true, nullptr); + // Fin only frame is not stored in send buffer. + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + + // kData2 is retransmitted. + stream_->OnStreamFrameRetransmitted(9, 9, false); + + // kData2 is acked. + EXPECT_TRUE(stream_->OnStreamFrameAcked(9, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(9u, newly_acked_length); + // Stream is waiting for acks as FIN is not acked. + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + // FIN is acked. + EXPECT_CALL(*stream_, OnWriteSideInDataRecvdState()); + EXPECT_TRUE(stream_->OnStreamFrameAcked(18, 0, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(0u, newly_acked_length); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); +} + +TEST_P(QuicStreamTest, StreamDataGetAckedOutOfOrder) { + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + // Send data. + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData("", true, nullptr); + EXPECT_EQ(3u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + QuicByteCount newly_acked_length = 0; + EXPECT_TRUE(stream_->OnStreamFrameAcked(9, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(9u, newly_acked_length); + EXPECT_EQ(3u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->OnStreamFrameAcked(18, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(9u, newly_acked_length); + EXPECT_EQ(3u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->OnStreamFrameAcked(0, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(9u, newly_acked_length); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + // FIN is not acked yet. + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_CALL(*stream_, OnWriteSideInDataRecvdState()); + EXPECT_TRUE(stream_->OnStreamFrameAcked(27, 0, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(0u, newly_acked_length); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); +} + +TEST_P(QuicStreamTest, CancelStream) { + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Cancel stream. + stream_->MaybeSendStopSending(QUIC_STREAM_NO_ERROR); + // stream still waits for acks as the error code is QUIC_STREAM_NO_ERROR, and + // data is going to be retransmitted. + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_STREAM_CANCELLED)); + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .Times(AtLeast(1)) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(_, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + session_->ReallyMaybeSendRstStreamFrame( + stream_->id(), QUIC_STREAM_CANCELLED, + stream_->stream_bytes_written()); + })); + + stream_->Reset(QUIC_STREAM_CANCELLED); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Stream stops waiting for acks as data is not going to be retransmitted. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); +} + +TEST_P(QuicStreamTest, RstFrameReceivedStreamNotFinishSending) { + if (VersionHasIetfQuicFrames(GetParam().transport_version)) { + // In IETF QUIC, receiving a RESET_STREAM will only close the read side. The + // stream itself is not closed and will not send reset. + return; + } + + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + + // RST_STREAM received. + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 9); + + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), 9)); + stream_->OnStreamReset(rst_frame); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Stream stops waiting for acks as it does not finish sending and rst is + // sent. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); +} + +TEST_P(QuicStreamTest, RstFrameReceivedStreamFinishSending) { + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + stream_->WriteOrBufferData(kData1, true, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + + // RST_STREAM received. + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + // Stream still waits for acks as it finishes sending and has unacked data. + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); +} + +TEST_P(QuicStreamTest, ConnectionClosed) { + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), 9)); + QuicConnectionPeer::SetConnectionClose(connection_); + stream_->OnConnectionClosed(QUIC_INTERNAL_ERROR, + ConnectionCloseSource::FROM_SELF); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Stream stops waiting for acks as connection is going to close. + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); +} + +TEST_P(QuicStreamTest, CanWriteNewDataAfterData) { + SetQuicFlag(quic_buffered_data_threshold, 100); + Initialize(); + EXPECT_TRUE(stream_->CanWriteNewDataAfterData(99)); + EXPECT_FALSE(stream_->CanWriteNewDataAfterData(100)); +} + +TEST_P(QuicStreamTest, WriteBufferedData) { + // Set buffered data low water mark to be 100. + SetQuicFlag(quic_buffered_data_threshold, 100); + + Initialize(); + std::string data(1024, 'a'); + EXPECT_TRUE(stream_->CanWriteNewData()); + + // Testing WriteOrBufferData. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 100u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->WriteOrBufferData(data, false, nullptr); + stream_->WriteOrBufferData(data, false, nullptr); + stream_->WriteOrBufferData(data, false, nullptr); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + + // Verify all data is saved. + EXPECT_EQ(3 * data.length() - 100, stream_->BufferedDataBytes()); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 100, 100u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + // Buffered data size > threshold, do not ask upper layer for more data. + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(0); + stream_->OnCanWrite(); + EXPECT_EQ(3 * data.length() - 200, stream_->BufferedDataBytes()); + EXPECT_FALSE(stream_->CanWriteNewData()); + + // Send buffered data to make buffered data size < threshold. + QuicByteCount data_to_write = + 3 * data.length() - 200 - GetQuicFlag(quic_buffered_data_threshold) + 1; + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this, data_to_write]() { + return session_->ConsumeData(stream_->id(), data_to_write, 200u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + // Buffered data size < threshold, ask upper layer for more data. + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(1); + stream_->OnCanWrite(); + EXPECT_EQ( + static_cast(GetQuicFlag(quic_buffered_data_threshold) - 1), + stream_->BufferedDataBytes()); + EXPECT_TRUE(stream_->CanWriteNewData()); + + // Flush all buffered data. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(1); + stream_->OnCanWrite(); + EXPECT_EQ(0u, stream_->BufferedDataBytes()); + EXPECT_FALSE(stream_->HasBufferedData()); + EXPECT_TRUE(stream_->CanWriteNewData()); + + // Testing Writev. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Return(QuicConsumedData(0, false))); + struct iovec iov = {const_cast(data.data()), data.length()}; + quiche::QuicheMemSliceStorage storage( + &iov, 1, session_->connection()->helper()->GetStreamSendBufferAllocator(), + 1024); + QuicConsumedData consumed = stream_->WriteMemSlices(storage.ToSpan(), false); + + // There is no buffered data before, all data should be consumed without + // respecting buffered data upper limit. + EXPECT_EQ(data.length(), consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_EQ(data.length(), stream_->BufferedDataBytes()); + EXPECT_FALSE(stream_->CanWriteNewData()); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(0); + quiche::QuicheMemSliceStorage storage2( + &iov, 1, session_->connection()->helper()->GetStreamSendBufferAllocator(), + 1024); + consumed = stream_->WriteMemSlices(storage2.ToSpan(), false); + // No Data can be consumed as buffered data is beyond upper limit. + EXPECT_EQ(0u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_EQ(data.length(), stream_->BufferedDataBytes()); + + data_to_write = data.length() - GetQuicFlag(quic_buffered_data_threshold) + 1; + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this, data_to_write]() { + return session_->ConsumeData(stream_->id(), data_to_write, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(1); + stream_->OnCanWrite(); + EXPECT_EQ( + static_cast(GetQuicFlag(quic_buffered_data_threshold) - 1), + stream_->BufferedDataBytes()); + EXPECT_TRUE(stream_->CanWriteNewData()); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(0); + // All data can be consumed as buffered data is below upper limit. + quiche::QuicheMemSliceStorage storage3( + &iov, 1, session_->connection()->helper()->GetStreamSendBufferAllocator(), + 1024); + consumed = stream_->WriteMemSlices(storage3.ToSpan(), false); + EXPECT_EQ(data.length(), consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_EQ(data.length() + GetQuicFlag(quic_buffered_data_threshold) - 1, + stream_->BufferedDataBytes()); + EXPECT_FALSE(stream_->CanWriteNewData()); +} + +TEST_P(QuicStreamTest, WritevDataReachStreamLimit) { + Initialize(); + std::string data("aaaaa"); + QuicStreamPeer::SetStreamBytesWritten(kMaxStreamLength - data.length(), + stream_); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + struct iovec iov = {const_cast(data.data()), 5u}; + quiche::QuicheMemSliceStorage storage( + &iov, 1, session_->connection()->helper()->GetStreamSendBufferAllocator(), + 1024); + QuicConsumedData consumed = stream_->WriteMemSlices(storage.ToSpan(), false); + EXPECT_EQ(data.length(), consumed.bytes_consumed); + struct iovec iov2 = {const_cast(data.data()), 1u}; + quiche::QuicheMemSliceStorage storage2( + &iov2, 1, + session_->connection()->helper()->GetStreamSendBufferAllocator(), 1024); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)); + stream_->WriteMemSlices(storage2.ToSpan(), false); + }, + "Write too many data via stream"); +} + +TEST_P(QuicStreamTest, WriteMemSlices) { + // Set buffered data low water mark to be 100. + SetQuicFlag(quic_buffered_data_threshold, 100); + + Initialize(); + constexpr QuicByteCount kDataSize = 1024; + quiche::QuicheBufferAllocator* allocator = + connection_->helper()->GetStreamSendBufferAllocator(); + std::vector vector1; + vector1.push_back( + quiche::QuicheMemSlice(quiche::QuicheBuffer(allocator, kDataSize))); + vector1.push_back( + quiche::QuicheMemSlice(quiche::QuicheBuffer(allocator, kDataSize))); + std::vector vector2; + vector2.push_back( + quiche::QuicheMemSlice(quiche::QuicheBuffer(allocator, kDataSize))); + vector2.push_back( + quiche::QuicheMemSlice(quiche::QuicheBuffer(allocator, kDataSize))); + absl::Span span1(vector1); + absl::Span span2(vector2); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 100u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + // There is no buffered data before, all data should be consumed. + QuicConsumedData consumed = stream_->WriteMemSlices(span1, false); + EXPECT_EQ(2048u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_EQ(2 * kDataSize - 100, stream_->BufferedDataBytes()); + EXPECT_FALSE(stream_->fin_buffered()); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(0); + // No Data can be consumed as buffered data is beyond upper limit. + consumed = stream_->WriteMemSlices(span2, true); + EXPECT_EQ(0u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_EQ(2 * kDataSize - 100, stream_->BufferedDataBytes()); + EXPECT_FALSE(stream_->fin_buffered()); + + QuicByteCount data_to_write = + 2 * kDataSize - 100 - GetQuicFlag(quic_buffered_data_threshold) + 1; + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this, data_to_write]() { + return session_->ConsumeData(stream_->id(), data_to_write, 100u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(1); + stream_->OnCanWrite(); + EXPECT_EQ( + static_cast(GetQuicFlag(quic_buffered_data_threshold) - 1), + stream_->BufferedDataBytes()); + // Try to write slices2 again. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(0); + consumed = stream_->WriteMemSlices(span2, true); + EXPECT_EQ(2048u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_EQ(2 * kDataSize + GetQuicFlag(quic_buffered_data_threshold) - 1, + stream_->BufferedDataBytes()); + EXPECT_TRUE(stream_->fin_buffered()); + + // Flush all buffered data. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->OnCanWrite(); + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(0); + EXPECT_FALSE(stream_->HasBufferedData()); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicStreamTest, WriteMemSlicesReachStreamLimit) { + Initialize(); + QuicStreamPeer::SetStreamBytesWritten(kMaxStreamLength - 5u, stream_); + std::vector> buffers; + quiche::QuicheMemSlice slice1 = MemSliceFromString("12345"); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 5u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + // There is no buffered data before, all data should be consumed. + QuicConsumedData consumed = stream_->WriteMemSlice(std::move(slice1), false); + EXPECT_EQ(5u, consumed.bytes_consumed); + + quiche::QuicheMemSlice slice2 = MemSliceFromString("6"); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)); + stream_->WriteMemSlice(std::move(slice2), false); + }, + "Write too many data via stream"); +} + +TEST_P(QuicStreamTest, StreamDataGetAckedMultipleTimes) { + Initialize(); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + + // Send [0, 27) and fin. + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData(kData1, true, nullptr); + EXPECT_EQ(3u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + // Ack [0, 9), [5, 22) and [18, 26) + // Verify [0, 9) 9 bytes are acked. + QuicByteCount newly_acked_length = 0; + EXPECT_TRUE(stream_->OnStreamFrameAcked(0, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(9u, newly_acked_length); + EXPECT_EQ(2u, QuicStreamPeer::SendBuffer(stream_).size()); + // Verify [9, 22) 13 bytes are acked. + EXPECT_TRUE(stream_->OnStreamFrameAcked(5, 17, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(13u, newly_acked_length); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + // Verify [22, 26) 4 bytes are acked. + EXPECT_TRUE(stream_->OnStreamFrameAcked(18, 8, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(4u, newly_acked_length); + EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + + // Ack [0, 27). Verify [26, 27) 1 byte is acked. + EXPECT_TRUE(stream_->OnStreamFrameAcked(26, 1, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(1u, newly_acked_length); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_TRUE(stream_->IsWaitingForAcks()); + EXPECT_TRUE(session_->HasUnackedStreamData()); + + // Ack Fin. + EXPECT_CALL(*stream_, OnWriteSideInDataRecvdState()).Times(1); + EXPECT_TRUE(stream_->OnStreamFrameAcked(27, 0, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(0u, newly_acked_length); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); + + // Ack [10, 27) and fin. No new data is acked. + EXPECT_FALSE( + stream_->OnStreamFrameAcked(10, 17, true, QuicTime::Delta::Zero(), + QuicTime::Zero(), &newly_acked_length)); + EXPECT_EQ(0u, newly_acked_length); + EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); + EXPECT_FALSE(stream_->IsWaitingForAcks()); + EXPECT_FALSE(session_->HasUnackedStreamData()); +} + +TEST_P(QuicStreamTest, OnStreamFrameLost) { + Initialize(); + + // Send [0, 9). + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->WriteOrBufferData(kData1, false, nullptr); + EXPECT_FALSE(stream_->HasBufferedData()); + EXPECT_TRUE(stream_->IsStreamFrameOutstanding(0, 9, false)); + + // Try to send [9, 27), but connection is blocked. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Return(QuicConsumedData(0, false))); + stream_->WriteOrBufferData(kData2, false, nullptr); + stream_->WriteOrBufferData(kData2, false, nullptr); + EXPECT_TRUE(stream_->HasBufferedData()); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + + // Lost [0, 9). When stream gets a chance to write, only lost data is + // transmitted. + stream_->OnStreamFrameLost(0, 9, false); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_CALL(*stream_, OnCanWriteNewData()).Times(1); + stream_->OnCanWrite(); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + EXPECT_TRUE(stream_->HasBufferedData()); + + // This OnCanWrite causes [9, 27) to be sent. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->OnCanWrite(); + EXPECT_FALSE(stream_->HasBufferedData()); + + // Send a fin only frame. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->WriteOrBufferData("", true, nullptr); + + // Lost [9, 27) and fin. + stream_->OnStreamFrameLost(9, 18, false); + stream_->OnStreamFrameLost(27, 0, true); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + + // Ack [9, 18). + QuicByteCount newly_acked_length = 0; + EXPECT_TRUE(stream_->OnStreamFrameAcked(9, 9, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), + &newly_acked_length)); + EXPECT_EQ(9u, newly_acked_length); + EXPECT_FALSE(stream_->IsStreamFrameOutstanding(9, 3, false)); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + // This OnCanWrite causes [18, 27) and fin to be retransmitted. Verify fin can + // be bundled with data. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 9u, 18u, FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + stream_->OnCanWrite(); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + // Lost [9, 18) again, but it is not considered as lost because kData2 + // has been acked. + stream_->OnStreamFrameLost(9, 9, false); + EXPECT_FALSE(stream_->HasPendingRetransmission()); + EXPECT_TRUE(stream_->IsStreamFrameOutstanding(27, 0, true)); +} + +TEST_P(QuicStreamTest, CannotBundleLostFin) { + Initialize(); + + // Send [0, 18) and fin. + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData(kData2, true, nullptr); + + // Lost [0, 9) and fin. + stream_->OnStreamFrameLost(0, 9, false); + stream_->OnStreamFrameLost(18, 0, true); + + // Retransmit lost data. Verify [0, 9) and fin are retransmitted in two + // frames. + InSequence s; + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 9u, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillOnce(Return(QuicConsumedData(0, true))); + stream_->OnCanWrite(); +} + +TEST_P(QuicStreamTest, MarkConnectionLevelWriteBlockedOnWindowUpdateFrame) { + Initialize(); + + // Set the config to a small value so that a newly created stream has small + // send flow control window. + QuicConfigPeer::SetReceivedInitialStreamFlowControlWindow(session_->config(), + 100); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_->config(), 100); + auto stream = new TestStream(GetNthClientInitiatedBidirectionalStreamId( + GetParam().transport_version, 2), + session_.get(), BIDIRECTIONAL); + session_->ActivateStream(absl::WrapUnique(stream)); + + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_CALL(*session_, SendBlocked(_, _)).Times(1); + std::string data(1024, '.'); + stream->WriteOrBufferData(data, false, nullptr); + EXPECT_FALSE(HasWriteBlockedStreams()); + + QuicWindowUpdateFrame window_update(kInvalidControlFrameId, stream_->id(), + 1234); + + stream->OnWindowUpdateFrame(window_update); + // Verify stream is marked connection level write blocked. + EXPECT_TRUE(HasWriteBlockedStreams()); + EXPECT_TRUE(stream->HasBufferedData()); +} + +// Regression test for b/73282665. +TEST_P(QuicStreamTest, + MarkConnectionLevelWriteBlockedOnWindowUpdateFrameWithNoBufferedData) { + Initialize(); + + // Set the config to a small value so that a newly created stream has small + // send flow control window. + QuicConfigPeer::SetReceivedInitialStreamFlowControlWindow(session_->config(), + 100); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_->config(), 100); + auto stream = new TestStream(GetNthClientInitiatedBidirectionalStreamId( + GetParam().transport_version, 2), + session_.get(), BIDIRECTIONAL); + session_->ActivateStream(absl::WrapUnique(stream)); + + std::string data(100, '.'); + EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_CALL(*session_, SendBlocked(_, _)).Times(1); + stream->WriteOrBufferData(data, false, nullptr); + EXPECT_FALSE(HasWriteBlockedStreams()); + + QuicWindowUpdateFrame window_update(kInvalidControlFrameId, stream_->id(), + 120); + stream->OnWindowUpdateFrame(window_update); + EXPECT_FALSE(stream->HasBufferedData()); + // Verify stream is marked as blocked although there is no buffered data. + EXPECT_TRUE(HasWriteBlockedStreams()); +} + +TEST_P(QuicStreamTest, RetransmitStreamData) { + Initialize(); + InSequence s; + + // Send [0, 18) with fin. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(2) + .WillRepeatedly(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + stream_->WriteOrBufferData(kData1, false, nullptr); + stream_->WriteOrBufferData(kData1, true, nullptr); + // Ack [10, 13). + QuicByteCount newly_acked_length = 0; + stream_->OnStreamFrameAcked(10, 3, false, QuicTime::Delta::Zero(), + QuicTime::Zero(), &newly_acked_length); + EXPECT_EQ(3u, newly_acked_length); + // Retransmit [0, 18) with fin, and only [0, 8) is consumed. + EXPECT_CALL(*session_, WritevData(stream_->id(), 10, 0, NO_FIN, _, _)) + .WillOnce(InvokeWithoutArgs([this]() { + return session_->ConsumeData(stream_->id(), 8, 0u, NO_FIN, + NOT_RETRANSMISSION, absl::nullopt); + })); + EXPECT_FALSE(stream_->RetransmitStreamData(0, 18, true, PTO_RETRANSMISSION)); + + // Retransmit [0, 18) with fin, and all is consumed. + EXPECT_CALL(*session_, WritevData(stream_->id(), 10, 0, NO_FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_CALL(*session_, WritevData(stream_->id(), 5, 13, FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_TRUE(stream_->RetransmitStreamData(0, 18, true, PTO_RETRANSMISSION)); + + // Retransmit [0, 8) with fin, and all is consumed. + EXPECT_CALL(*session_, WritevData(stream_->id(), 8, 0, NO_FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_CALL(*session_, WritevData(stream_->id(), 0, 18, FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_TRUE(stream_->RetransmitStreamData(0, 8, true, PTO_RETRANSMISSION)); +} + +TEST_P(QuicStreamTest, ResetStreamOnTtlExpiresRetransmitLostData) { + Initialize(); + + EXPECT_CALL(*session_, WritevData(stream_->id(), 200, 0, FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + std::string body(200, 'a'); + stream_->WriteOrBufferData(body, true, nullptr); + + // Set TTL to be 1 s. + QuicTime::Delta ttl = QuicTime::Delta::FromSeconds(1); + ASSERT_TRUE(stream_->MaybeSetTtl(ttl)); + // Verify data gets retransmitted because TTL does not expire. + EXPECT_CALL(*session_, WritevData(stream_->id(), 100, 0, NO_FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + EXPECT_TRUE(stream_->RetransmitStreamData(0, 100, false, PTO_RETRANSMISSION)); + stream_->OnStreamFrameLost(100, 100, true); + EXPECT_TRUE(stream_->HasPendingRetransmission()); + + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + // Verify stream gets reset because TTL expires. + if (session_->version().UsesHttp3()) { + EXPECT_CALL(*session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_TTL_EXPIRED))) + .Times(1); + } + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_TTL_EXPIRED), _)) + .Times(1); + stream_->OnCanWrite(); +} + +TEST_P(QuicStreamTest, ResetStreamOnTtlExpiresEarlyRetransmitData) { + Initialize(); + + EXPECT_CALL(*session_, WritevData(stream_->id(), 200, 0, FIN, _, _)) + .WillOnce(Invoke(session_.get(), &MockQuicSession::ConsumeData)); + std::string body(200, 'a'); + stream_->WriteOrBufferData(body, true, nullptr); + + // Set TTL to be 1 s. + QuicTime::Delta ttl = QuicTime::Delta::FromSeconds(1); + ASSERT_TRUE(stream_->MaybeSetTtl(ttl)); + + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + // Verify stream gets reset because TTL expires. + if (session_->version().UsesHttp3()) { + EXPECT_CALL(*session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_TTL_EXPIRED))) + .Times(1); + } + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_TTL_EXPIRED), _)) + .Times(1); + stream_->RetransmitStreamData(0, 100, false, PTO_RETRANSMISSION); +} + +// Test that OnStreamReset does one-way (read) closes if version 99, two way +// (read and write) if not version 99. +TEST_P(QuicStreamTest, OnStreamResetReadOrReadWrite) { + Initialize(); + EXPECT_FALSE(stream_->write_side_closed()); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + // Version 99/IETF QUIC should close just the read side. + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_FALSE(stream_->write_side_closed()); + } else { + // Google QUIC should close both sides of the stream. + EXPECT_TRUE(stream_->write_side_closed()); + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream_)); + } +} + +TEST_P(QuicStreamTest, WindowUpdateForReadOnlyStream) { + Initialize(); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + connection_->transport_version(), Perspective::IS_CLIENT); + TestStream stream(stream_id, session_.get(), READ_UNIDIRECTIONAL); + QuicWindowUpdateFrame window_update_frame(kInvalidControlFrameId, stream_id, + 0); + EXPECT_CALL( + *connection_, + CloseConnection( + QUIC_WINDOW_UPDATE_RECEIVED_ON_READ_UNIDIRECTIONAL_STREAM, + "WindowUpdateFrame received on READ_UNIDIRECTIONAL stream.", _)); + stream.OnWindowUpdateFrame(window_update_frame); +} + +TEST_P(QuicStreamTest, RstStreamFrameChangesCloseOffset) { + Initialize(); + + QuicStreamFrame stream_frame(stream_->id(), true, 0, "abc"); + EXPECT_CALL(*stream_, OnDataAvailable()); + stream_->OnStreamFrame(stream_frame); + QuicRstStreamFrame rst(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 0u); + + EXPECT_CALL(*connection_, CloseConnection(QUIC_STREAM_MULTIPLE_OFFSET, _, _)); + stream_->OnStreamReset(rst); +} + +// Regression test for b/176073284. +TEST_P(QuicStreamTest, EmptyStreamFrameWithNoFin) { + Initialize(); + QuicStreamFrame empty_stream_frame(stream_->id(), false, 0, ""); + if (stream_->version().HasIetfQuicFrames()) { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_EMPTY_STREAM_FRAME_NO_FIN, _, _)) + .Times(0); + } else { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_EMPTY_STREAM_FRAME_NO_FIN, _, _)); + } + EXPECT_CALL(*stream_, OnDataAvailable()).Times(0); + stream_->OnStreamFrame(empty_stream_frame); +} + +TEST_P(QuicStreamTest, SendRstWithCustomIetfCode) { + Initialize(); + QuicResetStreamError error(QUIC_STREAM_CANCELLED, 0x1234abcd); + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(kTestStreamId, error, _)) + .Times(1); + stream_->ResetWithError(error); + EXPECT_TRUE(rst_sent()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_sustained_bandwidth_recorder.cc b/quiche/quic/core/quic_sustained_bandwidth_recorder.cc new file mode 100644 index 000000000000..810820eece96 --- /dev/null +++ b/quiche/quic/core/quic_sustained_bandwidth_recorder.cc @@ -0,0 +1,59 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_sustained_bandwidth_recorder.h" + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +QuicSustainedBandwidthRecorder::QuicSustainedBandwidthRecorder() + : has_estimate_(false), + is_recording_(false), + bandwidth_estimate_recorded_during_slow_start_(false), + bandwidth_estimate_(QuicBandwidth::Zero()), + max_bandwidth_estimate_(QuicBandwidth::Zero()), + max_bandwidth_timestamp_(0), + start_time_(QuicTime::Zero()) {} + +void QuicSustainedBandwidthRecorder::RecordEstimate( + bool in_recovery, bool in_slow_start, QuicBandwidth bandwidth, + QuicTime estimate_time, QuicWallTime wall_time, QuicTime::Delta srtt) { + if (in_recovery) { + is_recording_ = false; + QUIC_DVLOG(1) << "Stopped recording at: " + << estimate_time.ToDebuggingValue(); + return; + } + + if (!is_recording_) { + // This is the first estimate of a new recording period. + start_time_ = estimate_time; + is_recording_ = true; + QUIC_DVLOG(1) << "Started recording at: " << start_time_.ToDebuggingValue(); + return; + } + + // If we have been recording for at least 3 * srtt, then record the latest + // bandwidth estimate as a valid sustained bandwidth estimate. + if (estimate_time - start_time_ >= 3 * srtt) { + has_estimate_ = true; + bandwidth_estimate_recorded_during_slow_start_ = in_slow_start; + bandwidth_estimate_ = bandwidth; + QUIC_DVLOG(1) << "New sustained bandwidth estimate (KBytes/s): " + << bandwidth_estimate_.ToKBytesPerSecond(); + } + + // Check for an increase in max bandwidth. + if (bandwidth > max_bandwidth_estimate_) { + max_bandwidth_estimate_ = bandwidth; + max_bandwidth_timestamp_ = wall_time.ToUNIXSeconds(); + QUIC_DVLOG(1) << "New max bandwidth estimate (KBytes/s): " + << max_bandwidth_estimate_.ToKBytesPerSecond(); + } +} + +} // namespace quic diff --git a/quiche/quic/core/quic_sustained_bandwidth_recorder.h b/quiche/quic/core/quic_sustained_bandwidth_recorder.h new file mode 100644 index 000000000000..63ed6cb03c14 --- /dev/null +++ b/quiche/quic/core/quic_sustained_bandwidth_recorder.h @@ -0,0 +1,92 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_SUSTAINED_BANDWIDTH_RECORDER_H_ +#define QUICHE_QUIC_CORE_QUIC_SUSTAINED_BANDWIDTH_RECORDER_H_ + +#include + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace test { +class QuicSustainedBandwidthRecorderPeer; +} // namespace test + +// This class keeps track of a sustained bandwidth estimate to ultimately send +// to the client in a server config update message. A sustained bandwidth +// estimate is only marked as valid if the QuicSustainedBandwidthRecorder has +// been given uninterrupted reliable estimates over a certain period of time. +class QUIC_EXPORT_PRIVATE QuicSustainedBandwidthRecorder { + public: + QuicSustainedBandwidthRecorder(); + QuicSustainedBandwidthRecorder(const QuicSustainedBandwidthRecorder&) = + delete; + QuicSustainedBandwidthRecorder& operator=( + const QuicSustainedBandwidthRecorder&) = delete; + + // As long as |in_recovery| is consistently false, multiple calls to this + // method over a 3 * srtt period results in storage of a valid sustained + // bandwidth estimate. + // |time_now| is used as a max bandwidth timestamp if needed. + void RecordEstimate(bool in_recovery, bool in_slow_start, + QuicBandwidth bandwidth, QuicTime estimate_time, + QuicWallTime wall_time, QuicTime::Delta srtt); + + bool HasEstimate() const { return has_estimate_; } + + QuicBandwidth BandwidthEstimate() const { + QUICHE_DCHECK(has_estimate_); + return bandwidth_estimate_; + } + + QuicBandwidth MaxBandwidthEstimate() const { + QUICHE_DCHECK(has_estimate_); + return max_bandwidth_estimate_; + } + + int64_t MaxBandwidthTimestamp() const { + QUICHE_DCHECK(has_estimate_); + return max_bandwidth_timestamp_; + } + + bool EstimateRecordedDuringSlowStart() const { + QUICHE_DCHECK(has_estimate_); + return bandwidth_estimate_recorded_during_slow_start_; + } + + private: + friend class test::QuicSustainedBandwidthRecorderPeer; + + // True if we have been able to calculate sustained bandwidth, over at least + // one recording period (3 * rtt). + bool has_estimate_; + + // True if the last call to RecordEstimate had a reliable estimate. + bool is_recording_; + + // True if the current sustained bandwidth estimate was generated while in + // slow start. + bool bandwidth_estimate_recorded_during_slow_start_; + + // The latest sustained bandwidth estimate. + QuicBandwidth bandwidth_estimate_; + + // The maximum sustained bandwidth seen over the lifetime of the connection. + QuicBandwidth max_bandwidth_estimate_; + + // Timestamp indicating when the max_bandwidth_estimate_ was seen. + int64_t max_bandwidth_timestamp_; + + // Timestamp marking the beginning of the latest recording period. + QuicTime start_time_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_SUSTAINED_BANDWIDTH_RECORDER_H_ diff --git a/quiche/quic/core/quic_sustained_bandwidth_recorder_test.cc b/quiche/quic/core/quic_sustained_bandwidth_recorder_test.cc new file mode 100644 index 000000000000..6f53350683da --- /dev/null +++ b/quiche/quic/core/quic_sustained_bandwidth_recorder_test.cc @@ -0,0 +1,133 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_sustained_bandwidth_recorder.h" + +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +class QuicSustainedBandwidthRecorderTest : public QuicTest {}; + +TEST_F(QuicSustainedBandwidthRecorderTest, BandwidthEstimates) { + QuicSustainedBandwidthRecorder recorder; + EXPECT_FALSE(recorder.HasEstimate()); + + QuicTime estimate_time = QuicTime::Zero(); + QuicWallTime wall_time = QuicWallTime::Zero(); + QuicTime::Delta srtt = QuicTime::Delta::FromMilliseconds(150); + const int kBandwidthBitsPerSecond = 12345678; + QuicBandwidth bandwidth = + QuicBandwidth::FromBitsPerSecond(kBandwidthBitsPerSecond); + + bool in_recovery = false; + bool in_slow_start = false; + + // This triggers recording, but should not yield a valid estimate yet. + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + EXPECT_FALSE(recorder.HasEstimate()); + + // Send a second reading, again this should not result in a valid estimate, + // as not enough time has passed. + estimate_time = estimate_time + srtt; + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + EXPECT_FALSE(recorder.HasEstimate()); + + // Now 3 * kSRTT has elapsed since first recording, expect a valid estimate. + estimate_time = estimate_time + srtt; + estimate_time = estimate_time + srtt; + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + EXPECT_TRUE(recorder.HasEstimate()); + EXPECT_EQ(recorder.BandwidthEstimate(), bandwidth); + EXPECT_EQ(recorder.BandwidthEstimate(), recorder.MaxBandwidthEstimate()); + + // Resetting, and sending a different estimate will only change output after + // a further 3 * kSRTT has passed. + QuicBandwidth second_bandwidth = + QuicBandwidth::FromBitsPerSecond(2 * kBandwidthBitsPerSecond); + // Reset the recorder by passing in a measurement while in recovery. + in_recovery = true; + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + in_recovery = false; + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + EXPECT_EQ(recorder.BandwidthEstimate(), bandwidth); + + estimate_time = estimate_time + 3 * srtt; + const int64_t kSeconds = 556677; + QuicWallTime second_bandwidth_wall_time = + QuicWallTime::FromUNIXSeconds(kSeconds); + recorder.RecordEstimate(in_recovery, in_slow_start, second_bandwidth, + estimate_time, second_bandwidth_wall_time, srtt); + EXPECT_EQ(recorder.BandwidthEstimate(), second_bandwidth); + EXPECT_EQ(recorder.BandwidthEstimate(), recorder.MaxBandwidthEstimate()); + EXPECT_EQ(recorder.MaxBandwidthTimestamp(), kSeconds); + + // Reset again, this time recording a lower bandwidth than before. + QuicBandwidth third_bandwidth = + QuicBandwidth::FromBitsPerSecond(0.5 * kBandwidthBitsPerSecond); + // Reset the recorder by passing in an unreliable measurement. + recorder.RecordEstimate(in_recovery, in_slow_start, third_bandwidth, + estimate_time, wall_time, srtt); + recorder.RecordEstimate(in_recovery, in_slow_start, third_bandwidth, + estimate_time, wall_time, srtt); + EXPECT_EQ(recorder.BandwidthEstimate(), third_bandwidth); + + estimate_time = estimate_time + 3 * srtt; + recorder.RecordEstimate(in_recovery, in_slow_start, third_bandwidth, + estimate_time, wall_time, srtt); + EXPECT_EQ(recorder.BandwidthEstimate(), third_bandwidth); + + // Max bandwidth should not have changed. + EXPECT_LT(third_bandwidth, second_bandwidth); + EXPECT_EQ(recorder.MaxBandwidthEstimate(), second_bandwidth); + EXPECT_EQ(recorder.MaxBandwidthTimestamp(), kSeconds); +} + +TEST_F(QuicSustainedBandwidthRecorderTest, SlowStart) { + // Verify that slow start status is correctly recorded. + QuicSustainedBandwidthRecorder recorder; + EXPECT_FALSE(recorder.HasEstimate()); + + QuicTime estimate_time = QuicTime::Zero(); + QuicWallTime wall_time = QuicWallTime::Zero(); + QuicTime::Delta srtt = QuicTime::Delta::FromMilliseconds(150); + const int kBandwidthBitsPerSecond = 12345678; + QuicBandwidth bandwidth = + QuicBandwidth::FromBitsPerSecond(kBandwidthBitsPerSecond); + + bool in_recovery = false; + bool in_slow_start = true; + + // This triggers recording, but should not yield a valid estimate yet. + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + + // Now 3 * kSRTT has elapsed since first recording, expect a valid estimate. + estimate_time = estimate_time + 3 * srtt; + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + EXPECT_TRUE(recorder.HasEstimate()); + EXPECT_TRUE(recorder.EstimateRecordedDuringSlowStart()); + + // Now send another estimate, this time not in slow start. + estimate_time = estimate_time + 3 * srtt; + in_slow_start = false; + recorder.RecordEstimate(in_recovery, in_slow_start, bandwidth, estimate_time, + wall_time, srtt); + EXPECT_TRUE(recorder.HasEstimate()); + EXPECT_FALSE(recorder.EstimateRecordedDuringSlowStart()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_syscall_wrapper.cc b/quiche/quic/core/quic_syscall_wrapper.cc new file mode 100644 index 000000000000..613483b6b4e2 --- /dev/null +++ b/quiche/quic/core/quic_syscall_wrapper.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_syscall_wrapper.h" + +#include +#include + +namespace quic { +namespace { +std::atomic global_syscall_wrapper(new QuicSyscallWrapper); +} // namespace + +ssize_t QuicSyscallWrapper::Sendmsg(int sockfd, const msghdr* msg, int flags) { + return ::sendmsg(sockfd, msg, flags); +} + +int QuicSyscallWrapper::Sendmmsg(int sockfd, mmsghdr* msgvec, unsigned int vlen, + int flags) { +#if defined(__linux__) && !defined(__ANDROID__) + return ::sendmmsg(sockfd, msgvec, vlen, flags); +#else + errno = ENOSYS; + return -1; +#endif +} + +QuicSyscallWrapper* GetGlobalSyscallWrapper() { + return global_syscall_wrapper.load(); +} + +void SetGlobalSyscallWrapper(QuicSyscallWrapper* wrapper) { + global_syscall_wrapper.store(wrapper); +} + +ScopedGlobalSyscallWrapperOverride::ScopedGlobalSyscallWrapperOverride( + QuicSyscallWrapper* wrapper_in_scope) + : original_wrapper_(GetGlobalSyscallWrapper()) { + SetGlobalSyscallWrapper(wrapper_in_scope); +} + +ScopedGlobalSyscallWrapperOverride::~ScopedGlobalSyscallWrapperOverride() { + SetGlobalSyscallWrapper(original_wrapper_); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_syscall_wrapper.h b/quiche/quic/core/quic_syscall_wrapper.h new file mode 100644 index 000000000000..4f4ffb0d6aa8 --- /dev/null +++ b/quiche/quic/core/quic_syscall_wrapper.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_SYSCALL_WRAPPER_H_ +#define QUICHE_QUIC_CORE_QUIC_SYSCALL_WRAPPER_H_ + +#include +#include + +#include "quiche/quic/platform/api/quic_export.h" + +struct mmsghdr; +namespace quic { + +// QuicSyscallWrapper is a pass-through proxy to the real syscalls. +class QUIC_EXPORT_PRIVATE QuicSyscallWrapper { + public: + virtual ~QuicSyscallWrapper() = default; + + virtual ssize_t Sendmsg(int sockfd, const msghdr* msg, int flags); + + virtual int Sendmmsg(int sockfd, mmsghdr* msgvec, unsigned int vlen, + int flags); +}; + +// A global instance of QuicSyscallWrapper, used by some socket util functions. +QuicSyscallWrapper* GetGlobalSyscallWrapper(); + +// Change the global QuicSyscallWrapper to |wrapper|, for testing. +void SetGlobalSyscallWrapper(QuicSyscallWrapper* wrapper); + +// ScopedGlobalSyscallWrapperOverride changes the global QuicSyscallWrapper +// during its lifetime, for testing. +class QUIC_EXPORT_PRIVATE ScopedGlobalSyscallWrapperOverride { + public: + explicit ScopedGlobalSyscallWrapperOverride( + QuicSyscallWrapper* wrapper_in_scope); + ~ScopedGlobalSyscallWrapperOverride(); + + private: + QuicSyscallWrapper* original_wrapper_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_SYSCALL_WRAPPER_H_ diff --git a/quiche/quic/core/quic_tag.cc b/quiche/quic/core/quic_tag.cc new file mode 100644 index 000000000000..b762e33acc58 --- /dev/null +++ b/quiche/quic/core/quic_tag.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_tag.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_split.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +bool FindMutualQuicTag(const QuicTagVector& our_tags, + const QuicTagVector& their_tags, QuicTag* out_result, + size_t* out_index) { + const size_t num_our_tags = our_tags.size(); + const size_t num_their_tags = their_tags.size(); + for (size_t i = 0; i < num_our_tags; i++) { + for (size_t j = 0; j < num_their_tags; j++) { + if (our_tags[i] == their_tags[j]) { + *out_result = our_tags[i]; + if (out_index != nullptr) { + *out_index = j; + } + return true; + } + } + } + + return false; +} + +std::string QuicTagToString(QuicTag tag) { + if (tag == 0) { + return "0"; + } + char chars[sizeof tag]; + bool ascii = true; + const QuicTag orig_tag = tag; + + for (size_t i = 0; i < ABSL_ARRAYSIZE(chars); i++) { + chars[i] = static_cast(tag); + if ((chars[i] == 0 || chars[i] == '\xff') && + i == ABSL_ARRAYSIZE(chars) - 1) { + chars[i] = ' '; + } + if (!isprint(static_cast(chars[i]))) { + ascii = false; + break; + } + tag >>= 8; + } + + if (ascii) { + return std::string(chars, sizeof(chars)); + } + + return absl::BytesToHexString(absl::string_view( + reinterpret_cast(&orig_tag), sizeof(orig_tag))); +} + +uint32_t MakeQuicTag(uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return static_cast(a) | static_cast(b) << 8 | + static_cast(c) << 16 | static_cast(d) << 24; +} + +bool ContainsQuicTag(const QuicTagVector& tag_vector, QuicTag tag) { + return std::find(tag_vector.begin(), tag_vector.end(), tag) != + tag_vector.end(); +} + +QuicTag ParseQuicTag(absl::string_view tag_string) { + quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&tag_string); + std::string tag_bytes; + if (tag_string.length() == 8) { + tag_bytes = absl::HexStringToBytes(tag_string); + tag_string = tag_bytes; + } + QuicTag tag = 0; + // Iterate over every character from right to left. + for (auto it = tag_string.rbegin(); it != tag_string.rend(); ++it) { + // The cast here is required on platforms where char is signed. + unsigned char token_char = static_cast(*it); + tag <<= 8; + tag |= token_char; + } + return tag; +} + +QuicTagVector ParseQuicTagVector(absl::string_view tags_string) { + QuicTagVector tag_vector; + quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&tags_string); + if (!tags_string.empty()) { + std::vector tag_strings = + absl::StrSplit(tags_string, ','); + for (absl::string_view tag_string : tag_strings) { + tag_vector.push_back(ParseQuicTag(tag_string)); + } + } + return tag_vector; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_tag.h b/quiche/quic/core/quic_tag.h new file mode 100644 index 000000000000..17c596810769 --- /dev/null +++ b/quiche/quic/core/quic_tag.h @@ -0,0 +1,67 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_TAG_H_ +#define QUICHE_QUIC_CORE_QUIC_TAG_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A QuicTag is a 32-bit used as identifiers in the QUIC handshake. The use of +// a uint32_t seeks to provide a balance between the tyranny of magic number +// registries and the verbosity of strings. As far as the wire protocol is +// concerned, these are opaque, 32-bit values. +// +// Tags will often be referred to by their ASCII equivalent, e.g. EXMP. This is +// just a mnemonic for the value 0x504d5845 (little-endian version of the ASCII +// string E X M P). +using QuicTag = uint32_t; +using QuicTagValueMap = std::map; +using QuicTagVector = std::vector; + +// MakeQuicTag returns a value given the four bytes. For example: +// MakeQuicTag('C', 'H', 'L', 'O'); +QUIC_EXPORT_PRIVATE QuicTag MakeQuicTag(uint8_t a, uint8_t b, uint8_t c, + uint8_t d); + +// Returns true if |tag_vector| contains |tag|. +QUIC_EXPORT_PRIVATE bool ContainsQuicTag(const QuicTagVector& tag_vector, + QuicTag tag); + +// Sets |out_result| to the first tag in |our_tags| that is also in |their_tags| +// and returns true. If there is no intersection it returns false. +// +// If |out_index| is non-nullptr and a match is found then the index of that +// match in |their_tags| is written to |out_index|. +QUIC_EXPORT_PRIVATE bool FindMutualQuicTag(const QuicTagVector& our_tags, + const QuicTagVector& their_tags, + QuicTag* out_result, + size_t* out_index); + +// A utility function that converts a tag to a string. It will try to maintain +// the human friendly name if possible (i.e. kABCD -> "ABCD"), or will just +// treat it as a number if not. +QUIC_EXPORT_PRIVATE std::string QuicTagToString(QuicTag tag); + +// Utility function that converts a string of the form "ABCD" to its +// corresponding QuicTag. Note that tags that are less than four characters +// long are right-padded with zeroes. Tags that contain non-ASCII characters +// are represented as 8-character-long hexadecimal strings. +QUIC_EXPORT_PRIVATE QuicTag ParseQuicTag(absl::string_view tag_string); + +// Utility function that converts a string of the form "ABCD,EFGH" to a vector +// of the form {kABCD,kEFGH}. Note the caveats on ParseQuicTag. +QUIC_EXPORT_PRIVATE QuicTagVector +ParseQuicTagVector(absl::string_view tags_string); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TAG_H_ diff --git a/quiche/quic/core/quic_tag_test.cc b/quiche/quic/core/quic_tag_test.cc new file mode 100644 index 000000000000..b3e6510559ba --- /dev/null +++ b/quiche/quic/core/quic_tag_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_tag.h" + +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +class QuicTagTest : public QuicTest {}; + +TEST_F(QuicTagTest, TagToString) { + EXPECT_EQ("SCFG", QuicTagToString(kSCFG)); + EXPECT_EQ("SNO ", QuicTagToString(kServerNonceTag)); + EXPECT_EQ("CRT ", QuicTagToString(kCertificateTag)); + EXPECT_EQ("CHLO", QuicTagToString(MakeQuicTag('C', 'H', 'L', 'O'))); + // A tag that contains a non-printing character will be printed as hex. + EXPECT_EQ("43484c1f", QuicTagToString(MakeQuicTag('C', 'H', 'L', '\x1f'))); +} + +TEST_F(QuicTagTest, MakeQuicTag) { + QuicTag tag = MakeQuicTag('A', 'B', 'C', 'D'); + char bytes[4]; + memcpy(bytes, &tag, 4); + EXPECT_EQ('A', bytes[0]); + EXPECT_EQ('B', bytes[1]); + EXPECT_EQ('C', bytes[2]); + EXPECT_EQ('D', bytes[3]); +} + +TEST_F(QuicTagTest, ParseQuicTag) { + QuicTag tag_abcd = MakeQuicTag('A', 'B', 'C', 'D'); + EXPECT_EQ(ParseQuicTag("ABCD"), tag_abcd); + EXPECT_EQ(ParseQuicTag("ABCDE"), tag_abcd); + QuicTag tag_efgh = MakeQuicTag('E', 'F', 'G', 'H'); + EXPECT_EQ(ParseQuicTag("EFGH"), tag_efgh); + QuicTag tag_ijk = MakeQuicTag('I', 'J', 'K', 0); + EXPECT_EQ(ParseQuicTag("IJK"), tag_ijk); + QuicTag tag_l = MakeQuicTag('L', 0, 0, 0); + EXPECT_EQ(ParseQuicTag("L"), tag_l); + QuicTag tag_hex = MakeQuicTag('M', 'N', 'O', static_cast(255)); + EXPECT_EQ(ParseQuicTag("4d4e4fff"), tag_hex); + EXPECT_EQ(ParseQuicTag("4D4E4FFF"), tag_hex); + QuicTag tag_with_numbers = MakeQuicTag('P', 'Q', '1', '2'); + EXPECT_EQ(ParseQuicTag("PQ12"), tag_with_numbers); + QuicTag tag_with_custom_chars = MakeQuicTag('r', '$', '_', '7'); + EXPECT_EQ(ParseQuicTag("r$_7"), tag_with_custom_chars); + QuicTag tag_zero = 0; + EXPECT_EQ(ParseQuicTag(""), tag_zero); + QuicTagVector tag_vector; + EXPECT_EQ(ParseQuicTagVector(""), tag_vector); + EXPECT_EQ(ParseQuicTagVector(" "), tag_vector); + tag_vector.push_back(tag_abcd); + EXPECT_EQ(ParseQuicTagVector("ABCD"), tag_vector); + tag_vector.push_back(tag_efgh); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH"), tag_vector); + tag_vector.push_back(tag_ijk); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH,IJK"), tag_vector); + tag_vector.push_back(tag_l); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH,IJK,L"), tag_vector); + tag_vector.push_back(tag_hex); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH,IJK,L,4d4e4fff"), tag_vector); + tag_vector.push_back(tag_with_numbers); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH,IJK,L,4d4e4fff,PQ12"), tag_vector); + tag_vector.push_back(tag_with_custom_chars); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH,IJK,L,4d4e4fff,PQ12,r$_7"), + tag_vector); + tag_vector.push_back(tag_zero); + EXPECT_EQ(ParseQuicTagVector("ABCD,EFGH,IJK,L,4d4e4fff,PQ12,r$_7,"), + tag_vector); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_time.cc b/quiche/quic/core/quic_time.cc new file mode 100644 index 000000000000..afbca5bf2d40 --- /dev/null +++ b/quiche/quic/core/quic_time.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_time.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" + +namespace quic { + +std::string QuicTime::Delta::ToDebuggingValue() const { + constexpr int64_t kMillisecondInMicroseconds = 1000; + constexpr int64_t kSecondInMicroseconds = 1000 * kMillisecondInMicroseconds; + + int64_t absolute_value = std::abs(time_offset_); + + // For debugging purposes, always display the value with the highest precision + // available. + if (absolute_value >= kSecondInMicroseconds && + absolute_value % kSecondInMicroseconds == 0) { + return absl::StrCat(time_offset_ / kSecondInMicroseconds, "s"); + } + if (absolute_value >= kMillisecondInMicroseconds && + absolute_value % kMillisecondInMicroseconds == 0) { + return absl::StrCat(time_offset_ / kMillisecondInMicroseconds, "ms"); + } + return absl::StrCat(time_offset_, "us"); +} + +uint64_t QuicWallTime::ToUNIXSeconds() const { return microseconds_ / 1000000; } + +uint64_t QuicWallTime::ToUNIXMicroseconds() const { return microseconds_; } + +bool QuicWallTime::IsAfter(QuicWallTime other) const { + return microseconds_ > other.microseconds_; +} + +bool QuicWallTime::IsBefore(QuicWallTime other) const { + return microseconds_ < other.microseconds_; +} + +bool QuicWallTime::IsZero() const { return microseconds_ == 0; } + +QuicTime::Delta QuicWallTime::AbsoluteDifference(QuicWallTime other) const { + uint64_t d; + + if (microseconds_ > other.microseconds_) { + d = microseconds_ - other.microseconds_; + } else { + d = other.microseconds_ - microseconds_; + } + + if (d > static_cast(std::numeric_limits::max())) { + d = std::numeric_limits::max(); + } + return QuicTime::Delta::FromMicroseconds(d); +} + +QuicWallTime QuicWallTime::Add(QuicTime::Delta delta) const { + uint64_t microseconds = microseconds_ + delta.ToMicroseconds(); + if (microseconds < microseconds_) { + microseconds = std::numeric_limits::max(); + } + return QuicWallTime(microseconds); +} + +// TODO(ianswett) Test this. +QuicWallTime QuicWallTime::Subtract(QuicTime::Delta delta) const { + uint64_t microseconds = microseconds_ - delta.ToMicroseconds(); + if (microseconds > microseconds_) { + microseconds = 0; + } + return QuicWallTime(microseconds); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_time.h b/quiche/quic/core/quic_time.h new file mode 100644 index 000000000000..ddf0307c17d7 --- /dev/null +++ b/quiche/quic/core/quic_time.h @@ -0,0 +1,295 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_TIME_H_ +#define QUICHE_QUIC_CORE_QUIC_TIME_H_ + +#include +#include +#include +#include +#include + +#include "absl/time/time.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QuicClock; +class QuicTime; + +// A 64-bit signed integer type that stores a time duration as +// a number of microseconds. QUIC does not use absl::Duration, since the Abseil +// type is 128-bit, which would adversely affect certain performance-sensitive +// QUIC data structures. +class QUIC_EXPORT_PRIVATE QuicTimeDelta { + public: + // Creates a QuicTimeDelta from an absl::Duration. Note that this inherently + // loses precision, since absl::Duration is nanoseconds, and QuicTimeDelta is + // microseconds. + explicit QuicTimeDelta(absl::Duration duration) + : time_offset_((duration == absl::InfiniteDuration()) + ? kInfiniteTimeUs + : absl::ToInt64Microseconds(duration)) {} + + // Create a object with an offset of 0. + static constexpr QuicTimeDelta Zero() { return QuicTimeDelta(0); } + + // Create a object with infinite offset time. + static constexpr QuicTimeDelta Infinite() { + return QuicTimeDelta(kInfiniteTimeUs); + } + + // Converts a number of seconds to a time offset. + static constexpr QuicTimeDelta FromSeconds(int64_t secs) { + return QuicTimeDelta(secs * 1000 * 1000); + } + + // Converts a number of milliseconds to a time offset. + static constexpr QuicTimeDelta FromMilliseconds(int64_t ms) { + return QuicTimeDelta(ms * 1000); + } + + // Converts a number of microseconds to a time offset. + static constexpr QuicTimeDelta FromMicroseconds(int64_t us) { + return QuicTimeDelta(us); + } + + // Converts the time offset to a rounded number of seconds. + constexpr int64_t ToSeconds() const { return time_offset_ / 1000 / 1000; } + + // Converts the time offset to a rounded number of milliseconds. + constexpr int64_t ToMilliseconds() const { return time_offset_ / 1000; } + + // Converts the time offset to a rounded number of microseconds. + constexpr int64_t ToMicroseconds() const { return time_offset_; } + + // Converts the time offset to an Abseil duration. + constexpr absl::Duration ToAbsl() { + if (ABSL_PREDICT_FALSE(IsInfinite())) { + return absl::InfiniteDuration(); + } + return absl::Microseconds(time_offset_); + } + + constexpr bool IsZero() const { return time_offset_ == 0; } + + constexpr bool IsInfinite() const { return time_offset_ == kInfiniteTimeUs; } + + std::string ToDebuggingValue() const; + + private: + friend inline bool operator==(QuicTimeDelta lhs, QuicTimeDelta rhs); + friend inline bool operator<(QuicTimeDelta lhs, QuicTimeDelta rhs); + friend inline QuicTimeDelta operator<<(QuicTimeDelta lhs, size_t rhs); + friend inline QuicTimeDelta operator>>(QuicTimeDelta lhs, size_t rhs); + + friend inline constexpr QuicTimeDelta operator+(QuicTimeDelta lhs, + QuicTimeDelta rhs); + friend inline constexpr QuicTimeDelta operator-(QuicTimeDelta lhs, + QuicTimeDelta rhs); + friend inline constexpr QuicTimeDelta operator*(QuicTimeDelta lhs, int rhs); + // Not constexpr since std::llround() is not constexpr. + friend inline QuicTimeDelta operator*(QuicTimeDelta lhs, double rhs); + + friend inline QuicTime operator+(QuicTime lhs, QuicTimeDelta rhs); + friend inline QuicTime operator-(QuicTime lhs, QuicTimeDelta rhs); + friend inline QuicTimeDelta operator-(QuicTime lhs, QuicTime rhs); + + static constexpr int64_t kInfiniteTimeUs = + std::numeric_limits::max(); + + explicit constexpr QuicTimeDelta(int64_t time_offset) + : time_offset_(time_offset) {} + + int64_t time_offset_; + friend class QuicTime; +}; + +// A microsecond precision timestamp returned by a QuicClock. It is +// usually either a Unix timestamp or a timestamp returned by the +// platform-specific monotonic clock. QuicClock has a method to convert QuicTime +// to the wall time. +class QUIC_EXPORT_PRIVATE QuicTime { + public: + using Delta = QuicTimeDelta; + + // Creates a new QuicTime with an internal value of 0. IsInitialized() + // will return false for these times. + static constexpr QuicTime Zero() { return QuicTime(0); } + + // Creates a new QuicTime with an infinite time. + static constexpr QuicTime Infinite() { + return QuicTime(Delta::kInfiniteTimeUs); + } + + QuicTime(const QuicTime& other) = default; + + QuicTime& operator=(const QuicTime& other) { + time_ = other.time_; + return *this; + } + + // Produce the internal value to be used when logging. This value + // represents the number of microseconds since some epoch. It may + // be the UNIX epoch on some platforms. On others, it may + // be a CPU ticks based value. + int64_t ToDebuggingValue() const { return time_; } + + bool IsInitialized() const { return 0 != time_; } + + private: + friend class QuicClock; + + friend inline bool operator==(QuicTime lhs, QuicTime rhs); + friend inline bool operator<(QuicTime lhs, QuicTime rhs); + friend inline QuicTime operator+(QuicTime lhs, QuicTimeDelta rhs); + friend inline QuicTime operator-(QuicTime lhs, QuicTimeDelta rhs); + friend inline QuicTimeDelta operator-(QuicTime lhs, QuicTime rhs); + + explicit constexpr QuicTime(int64_t time) : time_(time) {} + + int64_t time_; +}; + +// A UNIX timestamp. +// +// TODO(vasilvv): evaluate whether this can be replaced with absl::Time. +class QUIC_EXPORT_PRIVATE QuicWallTime { + public: + // FromUNIXSeconds constructs a QuicWallTime from a count of the seconds + // since the UNIX epoch. + static constexpr QuicWallTime FromUNIXSeconds(uint64_t seconds) { + return QuicWallTime(seconds * 1000000); + } + + static constexpr QuicWallTime FromUNIXMicroseconds(uint64_t microseconds) { + return QuicWallTime(microseconds); + } + + // Zero returns a QuicWallTime set to zero. IsZero will return true for this + // value. + static constexpr QuicWallTime Zero() { return QuicWallTime(0); } + + // Returns the number of seconds since the UNIX epoch. + uint64_t ToUNIXSeconds() const; + // Returns the number of microseconds since the UNIX epoch. + uint64_t ToUNIXMicroseconds() const; + + bool IsAfter(QuicWallTime other) const; + bool IsBefore(QuicWallTime other) const; + + // IsZero returns true if this object is the result of calling |Zero|. + bool IsZero() const; + + // AbsoluteDifference returns the absolute value of the time difference + // between |this| and |other|. + QuicTimeDelta AbsoluteDifference(QuicWallTime other) const; + + // Add returns a new QuicWallTime that represents the time of |this| plus + // |delta|. + [[nodiscard]] QuicWallTime Add(QuicTimeDelta delta) const; + + // Subtract returns a new QuicWallTime that represents the time of |this| + // minus |delta|. + [[nodiscard]] QuicWallTime Subtract(QuicTimeDelta delta) const; + + bool operator==(const QuicWallTime& other) const { + return microseconds_ == other.microseconds_; + } + + QuicTimeDelta operator-(const QuicWallTime& rhs) const { + return QuicTimeDelta::FromMicroseconds(microseconds_ - rhs.microseconds_); + } + + private: + explicit constexpr QuicWallTime(uint64_t microseconds) + : microseconds_(microseconds) {} + + uint64_t microseconds_; +}; + +// Non-member relational operators for QuicTimeDelta. +inline bool operator==(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return lhs.time_offset_ == rhs.time_offset_; +} +inline bool operator!=(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return !(lhs == rhs); +} +inline bool operator<(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return lhs.time_offset_ < rhs.time_offset_; +} +inline bool operator>(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return rhs < lhs; +} +inline bool operator<=(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return !(rhs < lhs); +} +inline bool operator>=(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return !(lhs < rhs); +} +inline QuicTimeDelta operator<<(QuicTimeDelta lhs, size_t rhs) { + return QuicTimeDelta(lhs.time_offset_ << rhs); +} +inline QuicTimeDelta operator>>(QuicTimeDelta lhs, size_t rhs) { + return QuicTimeDelta(lhs.time_offset_ >> rhs); +} + +// Non-member relational operators for QuicTime. +inline bool operator==(QuicTime lhs, QuicTime rhs) { + return lhs.time_ == rhs.time_; +} +inline bool operator!=(QuicTime lhs, QuicTime rhs) { return !(lhs == rhs); } +inline bool operator<(QuicTime lhs, QuicTime rhs) { + return lhs.time_ < rhs.time_; +} +inline bool operator>(QuicTime lhs, QuicTime rhs) { return rhs < lhs; } +inline bool operator<=(QuicTime lhs, QuicTime rhs) { return !(rhs < lhs); } +inline bool operator>=(QuicTime lhs, QuicTime rhs) { return !(lhs < rhs); } + +// Override stream output operator for gtest or QUICHE_CHECK macros. +inline std::ostream& operator<<(std::ostream& output, const QuicTime t) { + output << t.ToDebuggingValue(); + return output; +} + +// Non-member arithmetic operators for QuicTimeDelta. +inline constexpr QuicTimeDelta operator+(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return QuicTimeDelta(lhs.time_offset_ + rhs.time_offset_); +} +inline constexpr QuicTimeDelta operator-(QuicTimeDelta lhs, QuicTimeDelta rhs) { + return QuicTimeDelta(lhs.time_offset_ - rhs.time_offset_); +} +inline constexpr QuicTimeDelta operator*(QuicTimeDelta lhs, int rhs) { + return QuicTimeDelta(lhs.time_offset_ * rhs); +} +inline QuicTimeDelta operator*(QuicTimeDelta lhs, double rhs) { + return QuicTimeDelta(static_cast( + std::llround(static_cast(lhs.time_offset_) * rhs))); +} +inline QuicTimeDelta operator*(int lhs, QuicTimeDelta rhs) { return rhs * lhs; } +inline QuicTimeDelta operator*(double lhs, QuicTimeDelta rhs) { + return rhs * lhs; +} + +// Non-member arithmetic operators for QuicTime and QuicTimeDelta. +inline QuicTime operator+(QuicTime lhs, QuicTimeDelta rhs) { + return QuicTime(lhs.time_ + rhs.time_offset_); +} +inline QuicTime operator-(QuicTime lhs, QuicTimeDelta rhs) { + return QuicTime(lhs.time_ - rhs.time_offset_); +} +inline QuicTimeDelta operator-(QuicTime lhs, QuicTime rhs) { + return QuicTimeDelta(lhs.time_ - rhs.time_); +} + +// Override stream output operator for gtest. +inline std::ostream& operator<<(std::ostream& output, + const QuicTimeDelta delta) { + output << delta.ToDebuggingValue(); + return output; +} +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TIME_H_ diff --git a/quiche/quic/core/quic_time_accumulator.h b/quiche/quic/core/quic_time_accumulator.h new file mode 100644 index 000000000000..480e79747e3f --- /dev/null +++ b/quiche/quic/core/quic_time_accumulator.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_TIME_ACCUMULATOR_H_ +#define QUICHE_QUIC_CORE_QUIC_TIME_ACCUMULATOR_H_ + +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// QuicTimeAccumulator accumulates elapsed times between Start(s) and Stop(s). +class QUIC_EXPORT_PRIVATE QuicTimeAccumulator { + // TODO(wub): Switch to a data member called kNotRunningSentinel after c++17. + static constexpr QuicTime NotRunningSentinel() { + return QuicTime::Infinite(); + } + + public: + // True if Started and not Stopped. + bool IsRunning() const { return last_start_time_ != NotRunningSentinel(); } + + void Start(QuicTime now) { + QUICHE_DCHECK(!IsRunning()); + last_start_time_ = now; + QUICHE_DCHECK(IsRunning()); + } + + void Stop(QuicTime now) { + QUICHE_DCHECK(IsRunning()); + if (now > last_start_time_) { + total_elapsed_ = total_elapsed_ + (now - last_start_time_); + } + last_start_time_ = NotRunningSentinel(); + QUICHE_DCHECK(!IsRunning()); + } + + // Get total elapsed time between COMPLETED Start/Stop pairs. + QuicTime::Delta GetTotalElapsedTime() const { return total_elapsed_; } + + // Get total elapsed time between COMPLETED Start/Stop pairs, plus, if it is + // running, the elapsed time between |last_start_time_| and |now|. + QuicTime::Delta GetTotalElapsedTime(QuicTime now) const { + if (!IsRunning()) { + return total_elapsed_; + } + if (now <= last_start_time_) { + return total_elapsed_; + } + return total_elapsed_ + (now - last_start_time_); + } + + private: + // + // |last_start_time_| + // | + // V + // Start => Stop => Start => Stop => Start + // | | | | + // |___________| + |___________| = |total_elapsed_| + QuicTime::Delta total_elapsed_ = QuicTime::Delta::Zero(); + QuicTime last_start_time_ = NotRunningSentinel(); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TIME_ACCUMULATOR_H_ diff --git a/quiche/quic/core/quic_time_accumulator_test.cc b/quiche/quic/core/quic_time_accumulator_test.cc new file mode 100644 index 000000000000..fd56df4ab65b --- /dev/null +++ b/quiche/quic/core/quic_time_accumulator_test.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_time_accumulator.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +TEST(QuicTimeAccumulator, DefaultConstruct) { + MockClock clock; + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicTimeAccumulator acc; + EXPECT_FALSE(acc.IsRunning()); + + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(QuicTime::Delta::Zero(), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::Zero(), acc.GetTotalElapsedTime(clock.Now())); +} + +TEST(QuicTimeAccumulator, StartStop) { + MockClock clock; + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicTimeAccumulator acc; + acc.Start(clock.Now()); + EXPECT_TRUE(acc.IsRunning()); + + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + acc.Stop(clock.Now()); + EXPECT_FALSE(acc.IsRunning()); + + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), + acc.GetTotalElapsedTime(clock.Now())); + + acc.Start(clock.Now()); + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(15), + acc.GetTotalElapsedTime(clock.Now())); + + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), + acc.GetTotalElapsedTime(clock.Now())); + + acc.Stop(clock.Now()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(20), + acc.GetTotalElapsedTime(clock.Now())); +} + +TEST(QuicTimeAccumulator, ClockStepBackwards) { + MockClock clock; + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + + QuicTimeAccumulator acc; + acc.Start(clock.Now()); + + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(-10)); + acc.Stop(clock.Now()); + EXPECT_EQ(QuicTime::Delta::Zero(), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::Zero(), acc.GetTotalElapsedTime(clock.Now())); + + acc.Start(clock.Now()); + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(50)); + acc.Stop(clock.Now()); + + acc.Start(clock.Now()); + clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(-80)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(50), acc.GetTotalElapsedTime()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(50), + acc.GetTotalElapsedTime(clock.Now())); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_time_test.cc b/quiche/quic/core/quic_time_test.cc new file mode 100644 index 000000000000..19bf03ac10c2 --- /dev/null +++ b/quiche/quic/core/quic_time_test.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_time.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +class QuicTimeDeltaTest : public QuicTest {}; + +TEST_F(QuicTimeDeltaTest, Zero) { + EXPECT_TRUE(QuicTime::Delta::Zero().IsZero()); + EXPECT_FALSE(QuicTime::Delta::Zero().IsInfinite()); + EXPECT_FALSE(QuicTime::Delta::FromMilliseconds(1).IsZero()); +} + +TEST_F(QuicTimeDeltaTest, Infinite) { + EXPECT_TRUE(QuicTime::Delta::Infinite().IsInfinite()); + EXPECT_FALSE(QuicTime::Delta::Zero().IsInfinite()); + EXPECT_FALSE(QuicTime::Delta::FromMilliseconds(1).IsInfinite()); +} + +TEST_F(QuicTimeDeltaTest, FromTo) { + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(1), + QuicTime::Delta::FromMicroseconds(1000)); + EXPECT_EQ(QuicTime::Delta::FromSeconds(1), + QuicTime::Delta::FromMilliseconds(1000)); + EXPECT_EQ(QuicTime::Delta::FromSeconds(1), + QuicTime::Delta::FromMicroseconds(1000000)); + + EXPECT_EQ(1, QuicTime::Delta::FromMicroseconds(1000).ToMilliseconds()); + EXPECT_EQ(2, QuicTime::Delta::FromMilliseconds(2000).ToSeconds()); + EXPECT_EQ(1000, QuicTime::Delta::FromMilliseconds(1).ToMicroseconds()); + EXPECT_EQ(1, QuicTime::Delta::FromMicroseconds(1000).ToMilliseconds()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(2000).ToMicroseconds(), + QuicTime::Delta::FromSeconds(2).ToMicroseconds()); +} + +TEST_F(QuicTimeDeltaTest, Add) { + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(2000), + QuicTime::Delta::Zero() + QuicTime::Delta::FromMilliseconds(2)); +} + +TEST_F(QuicTimeDeltaTest, Subtract) { + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(1000), + QuicTime::Delta::FromMilliseconds(2) - + QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicTimeDeltaTest, Multiply) { + int i = 2; + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(4000), + QuicTime::Delta::FromMilliseconds(2) * i); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(4000), + i * QuicTime::Delta::FromMilliseconds(2)); + double d = 2; + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(4000), + QuicTime::Delta::FromMilliseconds(2) * d); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(4000), + d * QuicTime::Delta::FromMilliseconds(2)); + + // Ensure we are rounding correctly within a single-bit level of precision. + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(5), + QuicTime::Delta::FromMicroseconds(9) * 0.5); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(2), + QuicTime::Delta::FromMicroseconds(12) * 0.2); +} + +TEST_F(QuicTimeDeltaTest, Max) { + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(2000), + std::max(QuicTime::Delta::FromMicroseconds(1000), + QuicTime::Delta::FromMicroseconds(2000))); +} + +TEST_F(QuicTimeDeltaTest, NotEqual) { + EXPECT_TRUE(QuicTime::Delta::FromSeconds(0) != + QuicTime::Delta::FromSeconds(1)); + EXPECT_FALSE(QuicTime::Delta::FromSeconds(0) != + QuicTime::Delta::FromSeconds(0)); +} + +TEST_F(QuicTimeDeltaTest, DebuggingValue) { + const QuicTime::Delta one_us = QuicTime::Delta::FromMicroseconds(1); + const QuicTime::Delta one_ms = QuicTime::Delta::FromMilliseconds(1); + const QuicTime::Delta one_s = QuicTime::Delta::FromSeconds(1); + + EXPECT_EQ("1s", one_s.ToDebuggingValue()); + EXPECT_EQ("3s", (3 * one_s).ToDebuggingValue()); + EXPECT_EQ("1ms", one_ms.ToDebuggingValue()); + EXPECT_EQ("3ms", (3 * one_ms).ToDebuggingValue()); + EXPECT_EQ("1us", one_us.ToDebuggingValue()); + EXPECT_EQ("3us", (3 * one_us).ToDebuggingValue()); + + EXPECT_EQ("3001us", (3 * one_ms + one_us).ToDebuggingValue()); + EXPECT_EQ("3001ms", (3 * one_s + one_ms).ToDebuggingValue()); + EXPECT_EQ("3000001us", (3 * one_s + one_us).ToDebuggingValue()); +} + +class QuicTimeTest : public QuicTest { + protected: + MockClock clock_; +}; + +TEST_F(QuicTimeTest, Initialized) { + EXPECT_FALSE(QuicTime::Zero().IsInitialized()); + EXPECT_TRUE((QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(1)) + .IsInitialized()); +} + +TEST_F(QuicTimeTest, CopyConstruct) { + QuicTime time_1 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1234); + EXPECT_NE(time_1, QuicTime(QuicTime::Zero())); + EXPECT_EQ(time_1, QuicTime(time_1)); +} + +TEST_F(QuicTimeTest, CopyAssignment) { + QuicTime time_1 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1234); + QuicTime time_2 = QuicTime::Zero(); + EXPECT_NE(time_1, time_2); + time_2 = time_1; + EXPECT_EQ(time_1, time_2); +} + +TEST_F(QuicTimeTest, Add) { + QuicTime time_1 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1); + QuicTime time_2 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(2); + + QuicTime::Delta diff = time_2 - time_1; + + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(1), diff); + EXPECT_EQ(1000, diff.ToMicroseconds()); + EXPECT_EQ(1, diff.ToMilliseconds()); +} + +TEST_F(QuicTimeTest, Subtract) { + QuicTime time_1 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1); + QuicTime time_2 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(2); + + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(1), time_2 - time_1); +} + +TEST_F(QuicTimeTest, SubtractDelta) { + QuicTime time = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(2); + EXPECT_EQ(QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1), + time - QuicTime::Delta::FromMilliseconds(1)); +} + +TEST_F(QuicTimeTest, Max) { + QuicTime time_1 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1); + QuicTime time_2 = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(2); + + EXPECT_EQ(time_2, std::max(time_1, time_2)); +} + +TEST_F(QuicTimeTest, MockClock) { + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + + QuicTime now = clock_.ApproximateNow(); + QuicTime time = QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(1000); + + EXPECT_EQ(now, time); + + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + now = clock_.ApproximateNow(); + + EXPECT_NE(now, time); + + time = time + QuicTime::Delta::FromMilliseconds(1); + EXPECT_EQ(now, time); +} + +TEST_F(QuicTimeTest, LE) { + const QuicTime zero = QuicTime::Zero(); + const QuicTime one = zero + QuicTime::Delta::FromSeconds(1); + EXPECT_TRUE(zero <= zero); + EXPECT_TRUE(zero <= one); + EXPECT_TRUE(one <= one); + EXPECT_FALSE(one <= zero); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_time_wait_list_manager.cc b/quiche/quic/core/quic_time_wait_list_manager.cc new file mode 100644 index 000000000000..d4799738cccb --- /dev/null +++ b/quiche/quic/core/quic_time_wait_list_manager.cc @@ -0,0 +1,486 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_time_wait_list_manager.h" + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +// A very simple alarm that just informs the QuicTimeWaitListManager to clean +// up old connection_ids. This alarm should be cancelled and deleted before +// the QuicTimeWaitListManager is deleted. +class ConnectionIdCleanUpAlarm : public QuicAlarm::DelegateWithoutContext { + public: + explicit ConnectionIdCleanUpAlarm( + QuicTimeWaitListManager* time_wait_list_manager) + : time_wait_list_manager_(time_wait_list_manager) {} + ConnectionIdCleanUpAlarm(const ConnectionIdCleanUpAlarm&) = delete; + ConnectionIdCleanUpAlarm& operator=(const ConnectionIdCleanUpAlarm&) = delete; + + void OnAlarm() override { + time_wait_list_manager_->CleanUpOldConnectionIds(); + } + + private: + // Not owned. + QuicTimeWaitListManager* time_wait_list_manager_; +}; + +TimeWaitConnectionInfo::TimeWaitConnectionInfo( + bool ietf_quic, + std::vector>* termination_packets, + std::vector active_connection_ids) + : TimeWaitConnectionInfo(ietf_quic, termination_packets, + std::move(active_connection_ids), + QuicTime::Delta::Zero()) {} + +TimeWaitConnectionInfo::TimeWaitConnectionInfo( + bool ietf_quic, + std::vector>* termination_packets, + std::vector active_connection_ids, QuicTime::Delta srtt) + : ietf_quic(ietf_quic), + active_connection_ids(std::move(active_connection_ids)), + srtt(srtt) { + if (termination_packets != nullptr) { + this->termination_packets.swap(*termination_packets); + } +} + +QuicTimeWaitListManager::QuicTimeWaitListManager( + QuicPacketWriter* writer, Visitor* visitor, const QuicClock* clock, + QuicAlarmFactory* alarm_factory) + : time_wait_period_(QuicTime::Delta::FromSeconds( + GetQuicFlag(quic_time_wait_list_seconds))), + connection_id_clean_up_alarm_( + alarm_factory->CreateAlarm(new ConnectionIdCleanUpAlarm(this))), + clock_(clock), + writer_(writer), + visitor_(visitor) { + SetConnectionIdCleanUpAlarm(); +} + +QuicTimeWaitListManager::~QuicTimeWaitListManager() { + connection_id_clean_up_alarm_->Cancel(); +} + +QuicTimeWaitListManager::ConnectionIdMap::iterator +QuicTimeWaitListManager::FindConnectionIdDataInMap( + const QuicConnectionId& connection_id) { + auto it = indirect_connection_id_map_.find(connection_id); + if (it == indirect_connection_id_map_.end()) { + return connection_id_map_.end(); + } + return connection_id_map_.find(it->second); +} + +void QuicTimeWaitListManager::AddConnectionIdDataToMap( + const QuicConnectionId& canonical_connection_id, int num_packets, + TimeWaitAction action, TimeWaitConnectionInfo info) { + for (const auto& cid : info.active_connection_ids) { + indirect_connection_id_map_[cid] = canonical_connection_id; + } + ConnectionIdData data(num_packets, clock_->ApproximateNow(), action, + std::move(info)); + connection_id_map_.emplace( + std::make_pair(canonical_connection_id, std::move(data))); +} + +void QuicTimeWaitListManager::RemoveConnectionDataFromMap( + ConnectionIdMap::iterator it) { + for (const auto& cid : it->second.info.active_connection_ids) { + indirect_connection_id_map_.erase(cid); + } + connection_id_map_.erase(it); +} + +void QuicTimeWaitListManager::AddConnectionIdToTimeWait( + TimeWaitAction action, TimeWaitConnectionInfo info) { + QUICHE_DCHECK(!info.active_connection_ids.empty()); + const QuicConnectionId& canonical_connection_id = + info.active_connection_ids.front(); + QUICHE_DCHECK(action != SEND_TERMINATION_PACKETS || + !info.termination_packets.empty()); + QUICHE_DCHECK(action != DO_NOTHING || info.ietf_quic); + int num_packets = 0; + auto it = FindConnectionIdDataInMap(canonical_connection_id); + const bool new_connection_id = it == connection_id_map_.end(); + if (!new_connection_id) { // Replace record if it is reinserted. + num_packets = it->second.num_packets; + RemoveConnectionDataFromMap(it); + } + TrimTimeWaitListIfNeeded(); + int64_t max_connections = GetQuicFlag(quic_time_wait_list_max_connections); + QUICHE_DCHECK(connection_id_map_.empty() || + num_connections() < static_cast(max_connections)); + if (new_connection_id) { + for (const auto& cid : info.active_connection_ids) { + visitor_->OnConnectionAddedToTimeWaitList(cid); + } + } + AddConnectionIdDataToMap(canonical_connection_id, num_packets, action, + std::move(info)); +} + +bool QuicTimeWaitListManager::IsConnectionIdInTimeWait( + QuicConnectionId connection_id) const { + return indirect_connection_id_map_.contains(connection_id); +} + +void QuicTimeWaitListManager::OnBlockedWriterCanWrite() { + writer_->SetWritable(); + while (!pending_packets_queue_.empty()) { + QueuedPacket* queued_packet = pending_packets_queue_.front().get(); + if (!WriteToWire(queued_packet)) { + return; + } + pending_packets_queue_.pop_front(); + } +} + +void QuicTimeWaitListManager::ProcessPacket( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, QuicConnectionId connection_id, + PacketHeaderFormat header_format, size_t received_packet_length, + std::unique_ptr packet_context) { + QUICHE_DCHECK(IsConnectionIdInTimeWait(connection_id)); + // TODO(satyamshekhar): Think about handling packets from different peer + // addresses. + auto it = FindConnectionIdDataInMap(connection_id); + QUICHE_DCHECK(it != connection_id_map_.end()); + // Increment the received packet count. + ConnectionIdData* connection_data = &it->second; + ++(connection_data->num_packets); + const QuicTime now = clock_->ApproximateNow(); + QuicTime::Delta delta = QuicTime::Delta::Zero(); + if (now > connection_data->time_added) { + delta = now - connection_data->time_added; + } + OnPacketReceivedForKnownConnection(connection_data->num_packets, delta, + connection_data->info.srtt); + + if (!ShouldSendResponse(connection_data->num_packets)) { + QUIC_DLOG(INFO) << "Processing " << connection_id << " in time wait state: " + << "throttled"; + return; + } + + QUIC_DLOG(INFO) << "Processing " << connection_id << " in time wait state: " + << "header format=" << header_format + << " ietf=" << connection_data->info.ietf_quic + << ", action=" << connection_data->action + << ", number termination packets=" + << connection_data->info.termination_packets.size(); + switch (connection_data->action) { + case SEND_TERMINATION_PACKETS: + if (connection_data->info.termination_packets.empty()) { + QUIC_BUG(quic_bug_10608_1) << "There are no termination packets."; + return; + } + switch (header_format) { + case IETF_QUIC_LONG_HEADER_PACKET: + if (!connection_data->info.ietf_quic) { + QUIC_CODE_COUNT(quic_received_long_header_packet_for_gquic); + } + break; + case IETF_QUIC_SHORT_HEADER_PACKET: + if (!connection_data->info.ietf_quic) { + QUIC_CODE_COUNT(quic_received_short_header_packet_for_gquic); + } + // Send stateless reset in response to short header packets. + SendPublicReset(self_address, peer_address, connection_id, + connection_data->info.ietf_quic, + received_packet_length, std::move(packet_context)); + return; + case GOOGLE_QUIC_PACKET: + if (connection_data->info.ietf_quic) { + QUIC_CODE_COUNT(quic_received_gquic_packet_for_ietf_quic); + } + break; + } + + for (const auto& packet : connection_data->info.termination_packets) { + SendOrQueuePacket(std::make_unique( + self_address, peer_address, packet->Clone()), + packet_context.get()); + } + return; + + case SEND_CONNECTION_CLOSE_PACKETS: + if (connection_data->info.termination_packets.empty()) { + QUIC_BUG(quic_bug_10608_2) << "There are no termination packets."; + return; + } + for (const auto& packet : connection_data->info.termination_packets) { + SendOrQueuePacket(std::make_unique( + self_address, peer_address, packet->Clone()), + packet_context.get()); + } + return; + + case SEND_STATELESS_RESET: + if (header_format == IETF_QUIC_LONG_HEADER_PACKET) { + QUIC_CODE_COUNT(quic_stateless_reset_long_header_packet); + } + SendPublicReset(self_address, peer_address, connection_id, + connection_data->info.ietf_quic, received_packet_length, + std::move(packet_context)); + return; + case DO_NOTHING: + QUIC_CODE_COUNT(quic_time_wait_list_do_nothing); + QUICHE_DCHECK(connection_data->info.ietf_quic); + } +} + +void QuicTimeWaitListManager::SendVersionNegotiationPacket( + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, + bool use_length_prefix, const ParsedQuicVersionVector& supported_versions, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + std::unique_ptr packet_context) { + std::unique_ptr version_packet = + QuicFramer::BuildVersionNegotiationPacket( + server_connection_id, client_connection_id, ietf_quic, + use_length_prefix, supported_versions); + QUIC_DVLOG(2) << "Dispatcher sending version negotiation packet {" + << ParsedQuicVersionVectorToString(supported_versions) << "}, " + << (ietf_quic ? "" : "!") << "ietf_quic, " + << (use_length_prefix ? "" : "!") + << "use_length_prefix:" << std::endl + << quiche::QuicheTextUtils::HexDump(absl::string_view( + version_packet->data(), version_packet->length())); + SendOrQueuePacket(std::make_unique(self_address, peer_address, + std::move(version_packet)), + packet_context.get()); +} + +// Returns true if the number of packets received for this connection_id is a +// power of 2 to throttle the number of public reset packets we send to a peer. +bool QuicTimeWaitListManager::ShouldSendResponse(int received_packet_count) { + return (received_packet_count & (received_packet_count - 1)) == 0; +} + +void QuicTimeWaitListManager::SendPublicReset( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, QuicConnectionId connection_id, + bool ietf_quic, size_t received_packet_length, + std::unique_ptr packet_context) { + if (ietf_quic) { + std::unique_ptr ietf_reset_packet = + BuildIetfStatelessResetPacket(connection_id, received_packet_length); + if (ietf_reset_packet == nullptr) { + // This could happen when trying to reject a short header packet of + // a connection which is in the time wait list (and with no termination + // packet). + return; + } + QUIC_DVLOG(2) << "Dispatcher sending IETF reset packet for " + << connection_id << std::endl + << quiche::QuicheTextUtils::HexDump( + absl::string_view(ietf_reset_packet->data(), + ietf_reset_packet->length())); + SendOrQueuePacket( + std::make_unique(self_address, peer_address, + std::move(ietf_reset_packet)), + packet_context.get()); + return; + } + // Google QUIC public resets donot elicit resets in response. + QuicPublicResetPacket packet; + packet.connection_id = connection_id; + // TODO(satyamshekhar): generate a valid nonce for this connection_id. + packet.nonce_proof = 1010101; + // TODO(wub): This is wrong for proxied sessions. Fix it. + packet.client_address = peer_address; + GetEndpointId(&packet.endpoint_id); + // Takes ownership of the packet. + std::unique_ptr reset_packet = BuildPublicReset(packet); + QUIC_DVLOG(2) << "Dispatcher sending reset packet for " << connection_id + << std::endl + << quiche::QuicheTextUtils::HexDump(absl::string_view( + reset_packet->data(), reset_packet->length())); + SendOrQueuePacket(std::make_unique(self_address, peer_address, + std::move(reset_packet)), + packet_context.get()); +} + +void QuicTimeWaitListManager::SendPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicEncryptedPacket& packet) { + SendOrQueuePacket(std::make_unique(self_address, peer_address, + packet.Clone()), + nullptr); +} + +std::unique_ptr QuicTimeWaitListManager::BuildPublicReset( + const QuicPublicResetPacket& packet) { + return QuicFramer::BuildPublicResetPacket(packet); +} + +std::unique_ptr +QuicTimeWaitListManager::BuildIetfStatelessResetPacket( + QuicConnectionId connection_id, size_t received_packet_length) { + return QuicFramer::BuildIetfStatelessResetPacket( + connection_id, received_packet_length, + GetStatelessResetToken(connection_id)); +} + +// Either sends the packet and deletes it or makes pending queue the +// owner of the packet. +bool QuicTimeWaitListManager::SendOrQueuePacket( + std::unique_ptr packet, + const QuicPerPacketContext* /*packet_context*/) { + if (packet == nullptr) { + QUIC_LOG(ERROR) << "Tried to send or queue a null packet"; + return true; + } + if (pending_packets_queue_.size() >= + GetQuicFlag(quic_time_wait_list_max_pending_packets)) { + // There are too many pending packets. + QUIC_CODE_COUNT(quic_too_many_pending_packets_in_time_wait); + return true; + } + if (WriteToWire(packet.get())) { + // Allow the packet to be deleted upon leaving this function. + return true; + } + pending_packets_queue_.push_back(std::move(packet)); + return false; +} + +bool QuicTimeWaitListManager::WriteToWire(QueuedPacket* queued_packet) { + if (writer_->IsWriteBlocked()) { + visitor_->OnWriteBlocked(this); + return false; + } + WriteResult result = writer_->WritePacket( + queued_packet->packet()->data(), queued_packet->packet()->length(), + queued_packet->self_address().host(), queued_packet->peer_address(), + nullptr); + + // If using a batch writer and the packet is buffered, flush it. + if (writer_->IsBatchMode() && result.status == WRITE_STATUS_OK && + result.bytes_written == 0) { + result = writer_->Flush(); + } + + if (IsWriteBlockedStatus(result.status)) { + // If blocked and unbuffered, return false to retry sending. + QUICHE_DCHECK(writer_->IsWriteBlocked()); + visitor_->OnWriteBlocked(this); + return result.status == WRITE_STATUS_BLOCKED_DATA_BUFFERED; + } else if (IsWriteError(result.status)) { + QUIC_LOG_FIRST_N(WARNING, 1) + << "Received unknown error while sending termination packet to " + << queued_packet->peer_address().ToString() << ": " + << strerror(result.error_code); + } + return true; +} + +void QuicTimeWaitListManager::SetConnectionIdCleanUpAlarm() { + QuicTime::Delta next_alarm_interval = QuicTime::Delta::Zero(); + if (!connection_id_map_.empty()) { + QuicTime oldest_connection_id = + connection_id_map_.begin()->second.time_added; + QuicTime now = clock_->ApproximateNow(); + if (now - oldest_connection_id < time_wait_period_) { + next_alarm_interval = oldest_connection_id + time_wait_period_ - now; + } else { + QUIC_LOG(ERROR) + << "ConnectionId lingered for longer than time_wait_period_"; + } + } else { + // No connection_ids added so none will expire before time_wait_period_. + next_alarm_interval = time_wait_period_; + } + + connection_id_clean_up_alarm_->Update( + clock_->ApproximateNow() + next_alarm_interval, QuicTime::Delta::Zero()); +} + +bool QuicTimeWaitListManager::MaybeExpireOldestConnection( + QuicTime expiration_time) { + if (connection_id_map_.empty()) { + return false; + } + auto it = connection_id_map_.begin(); + QuicTime oldest_connection_id_time = it->second.time_added; + if (oldest_connection_id_time > expiration_time) { + // Too recent, don't retire. + return false; + } + // This connection_id has lived its age, retire it now. + QUIC_DLOG(INFO) << "Connection " << it->first + << " expired from time wait list"; + RemoveConnectionDataFromMap(it); + if (expiration_time == QuicTime::Infinite()) { + QUIC_CODE_COUNT(quic_time_wait_list_trim_full); + } else { + QUIC_CODE_COUNT(quic_time_wait_list_expire_connections); + } + return true; +} + +void QuicTimeWaitListManager::CleanUpOldConnectionIds() { + QuicTime now = clock_->ApproximateNow(); + QuicTime expiration = now - time_wait_period_; + + while (MaybeExpireOldestConnection(expiration)) { + } + + SetConnectionIdCleanUpAlarm(); +} + +void QuicTimeWaitListManager::TrimTimeWaitListIfNeeded() { + const int64_t kMaxConnections = + GetQuicFlag(quic_time_wait_list_max_connections); + if (kMaxConnections < 0) { + return; + } + while (!connection_id_map_.empty() && + num_connections() >= static_cast(kMaxConnections)) { + MaybeExpireOldestConnection(QuicTime::Infinite()); + } +} + +QuicTimeWaitListManager::ConnectionIdData::ConnectionIdData( + int num_packets, QuicTime time_added, TimeWaitAction action, + TimeWaitConnectionInfo info) + : num_packets(num_packets), + time_added(time_added), + action(action), + info(std::move(info)) {} + +QuicTimeWaitListManager::ConnectionIdData::ConnectionIdData( + ConnectionIdData&& other) = default; + +QuicTimeWaitListManager::ConnectionIdData::~ConnectionIdData() = default; + +StatelessResetToken QuicTimeWaitListManager::GetStatelessResetToken( + QuicConnectionId connection_id) const { + return QuicUtils::GenerateStatelessResetToken(connection_id); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_time_wait_list_manager.h b/quiche/quic/core/quic_time_wait_list_manager.h new file mode 100644 index 000000000000..676e15e99d62 --- /dev/null +++ b/quiche/quic/core/quic_time_wait_list_manager.h @@ -0,0 +1,331 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Handles packets for connection_ids in time wait state by discarding the +// packet and sending the peers termination packets with exponential backoff. + +#ifndef QUICHE_QUIC_CORE_QUIC_TIME_WAIT_LIST_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_TIME_WAIT_LIST_MANAGER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/quic_blocked_writer_interface.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +namespace test { +class QuicDispatcherPeer; +class QuicTimeWaitListManagerPeer; +} // namespace test + +// TimeWaitConnectionInfo comprises information of a connection which is in the +// time wait list. +struct QUIC_NO_EXPORT TimeWaitConnectionInfo { + TimeWaitConnectionInfo( + bool ietf_quic, + std::vector>* termination_packets, + std::vector active_connection_ids); + TimeWaitConnectionInfo( + bool ietf_quic, + std::vector>* termination_packets, + std::vector active_connection_ids, + QuicTime::Delta srtt); + + TimeWaitConnectionInfo(const TimeWaitConnectionInfo& other) = delete; + TimeWaitConnectionInfo(TimeWaitConnectionInfo&& other) = default; + + ~TimeWaitConnectionInfo() = default; + + bool ietf_quic; + std::vector> termination_packets; + std::vector active_connection_ids; + QuicTime::Delta srtt; +}; + +// Maintains a list of all connection_ids that have been recently closed. A +// connection_id lives in this state for time_wait_period_. All packets received +// for connection_ids in this state are handed over to the +// QuicTimeWaitListManager by the QuicDispatcher. Decides whether to send a +// public reset packet, a copy of the previously sent connection close packet, +// or nothing to the peer which sent a packet with the connection_id in time +// wait state. After the connection_id expires its time wait period, a new +// connection/session will be created if a packet is received for this +// connection_id. +class QUIC_NO_EXPORT QuicTimeWaitListManager + : public QuicBlockedWriterInterface { + public: + // Specifies what the time wait list manager should do when processing packets + // of a time wait connection. + enum TimeWaitAction : uint8_t { + // Send specified termination packets, error if termination packet is + // unavailable. + SEND_TERMINATION_PACKETS, + // The same as SEND_TERMINATION_PACKETS except that the corresponding + // termination packets are provided by the connection. + SEND_CONNECTION_CLOSE_PACKETS, + // Send stateless reset (public reset for GQUIC). + SEND_STATELESS_RESET, + + DO_NOTHING, + }; + + class QUIC_NO_EXPORT Visitor : public QuicSession::Visitor { + public: + // Called after the given connection is added to the time-wait list. + virtual void OnConnectionAddedToTimeWaitList( + QuicConnectionId connection_id) = 0; + }; + + // writer - the entity that writes to the socket. (Owned by the caller) + // visitor - the entity that manages blocked writers. (Owned by the caller) + // clock - provide a clock (Owned by the caller) + // alarm_factory - used to run clean up alarms. (Owned by the caller) + QuicTimeWaitListManager(QuicPacketWriter* writer, Visitor* visitor, + const QuicClock* clock, + QuicAlarmFactory* alarm_factory); + QuicTimeWaitListManager(const QuicTimeWaitListManager&) = delete; + QuicTimeWaitListManager& operator=(const QuicTimeWaitListManager&) = delete; + ~QuicTimeWaitListManager() override; + + // Adds the connection IDs in info to time wait state for time_wait_period_. + // If |info|.termination_packets are provided, copies of these packets will be + // sent when a packet with one of these connection IDs is processed. Any + // termination packets will be move from |info|.termination_packets and will + // become owned by the manager. |action| specifies what the time wait list + // manager should do when processing packets of the connection. + virtual void AddConnectionIdToTimeWait(TimeWaitAction action, + TimeWaitConnectionInfo info); + + // Returns true if the connection_id is in time wait state, false otherwise. + // Packets received for this connection_id should not lead to creation of new + // QuicSessions. + bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) const; + + // Called when a packet is received for a connection_id that is in time wait + // state. Sends a public reset packet to the peer which sent this + // connection_id. Sending of the public reset packet is throttled by using + // exponential back off. QUICHE_DCHECKs for the connection_id to be in time + // wait state. virtual to override in tests. + // TODO(fayang): change ProcessPacket and SendPublicReset to take + // ReceivedPacketInfo. + virtual void ProcessPacket( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, QuicConnectionId connection_id, + PacketHeaderFormat header_format, size_t received_packet_length, + std::unique_ptr packet_context); + + // Called by the dispatcher when the underlying socket becomes writable again, + // since we might need to send pending public reset packets which we didn't + // send because the underlying socket was write blocked. + void OnBlockedWriterCanWrite() override; + + bool IsWriterBlocked() const override { + return writer_ != nullptr && writer_->IsWriteBlocked(); + } + + // Used to delete connection_id entries that have outlived their time wait + // period. + void CleanUpOldConnectionIds(); + + // If necessary, trims the oldest connections from the time-wait list until + // the size is under the configured maximum. + void TrimTimeWaitListIfNeeded(); + + // The number of connections on the time-wait list. + size_t num_connections() const { return connection_id_map_.size(); } + + // Sends a version negotiation packet for |server_connection_id| and + // |client_connection_id| announcing support for |supported_versions| to + // |peer_address| from |self_address|. + virtual void SendVersionNegotiationPacket( + QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, + bool use_length_prefix, const ParsedQuicVersionVector& supported_versions, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + std::unique_ptr packet_context); + + // Creates a public reset packet and sends it or queues it to be sent later. + virtual void SendPublicReset( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, QuicConnectionId connection_id, + bool ietf_quic, size_t received_packet_length, + std::unique_ptr packet_context); + + // Called to send |packet|. + virtual void SendPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicEncryptedPacket& packet); + + // Return a non-owning pointer to the packet writer. + QuicPacketWriter* writer() { return writer_; } + + protected: + virtual std::unique_ptr BuildPublicReset( + const QuicPublicResetPacket& packet); + + virtual void GetEndpointId(std::string* /*endpoint_id*/) {} + + // Returns a stateless reset token which will be included in the public reset + // packet. + virtual StatelessResetToken GetStatelessResetToken( + QuicConnectionId connection_id) const; + + // Internal structure to store pending termination packets. + class QUIC_NO_EXPORT QueuedPacket { + public: + QueuedPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + std::unique_ptr packet) + : self_address_(self_address), + peer_address_(peer_address), + packet_(std::move(packet)) {} + QueuedPacket(const QueuedPacket&) = delete; + QueuedPacket& operator=(const QueuedPacket&) = delete; + + const QuicSocketAddress& self_address() const { return self_address_; } + const QuicSocketAddress& peer_address() const { return peer_address_; } + QuicEncryptedPacket* packet() { return packet_.get(); } + + private: + // Server address on which a packet was received for a connection_id in + // time wait state. + const QuicSocketAddress self_address_; + // Address of the peer to send this packet to. + const QuicSocketAddress peer_address_; + // The pending termination packet that is to be sent to the peer. + std::unique_ptr packet_; + }; + + // Called right after |packet| is serialized. Either sends the packet and + // deletes it or makes pending_packets_queue_ the owner of the packet. + // Subclasses overriding this method should call this class's base + // implementation at the end of the override. + // Return true if |packet| is sent, false if it is queued. + virtual bool SendOrQueuePacket(std::unique_ptr packet, + const QuicPerPacketContext* packet_context); + + const quiche::QuicheCircularDeque>& + pending_packets_queue() const { + return pending_packets_queue_; + } + + private: + friend class test::QuicDispatcherPeer; + friend class test::QuicTimeWaitListManagerPeer; + + // Decides if a packet should be sent for this connection_id based on the + // number of received packets. + bool ShouldSendResponse(int received_packet_count); + + // Sends the packet out. Returns true if the packet was successfully consumed. + // If the writer got blocked and did not buffer the packet, we'll need to keep + // the packet and retry sending. In case of all other errors we drop the + // packet. + bool WriteToWire(QueuedPacket* packet); + + // Register the alarm server to wake up at appropriate time. + void SetConnectionIdCleanUpAlarm(); + + // Removes the oldest connection from the time-wait list if it was added prior + // to "expiration_time". To unconditionally remove the oldest connection, use + // a QuicTime::Delta:Infinity(). This function modifies the + // connection_id_map_. If you plan to call this function in a loop, any + // iterators that you hold before the call to this function may be invalid + // afterward. Returns true if the oldest connection was expired. Returns + // false if the map is empty or the oldest connection has not expired. + bool MaybeExpireOldestConnection(QuicTime expiration_time); + + // Called when a packet is received for a connection in this time wait list. + virtual void OnPacketReceivedForKnownConnection( + int /*num_packets*/, QuicTime::Delta /*delta*/, + QuicTime::Delta /*srtt*/) const {} + + std::unique_ptr BuildIetfStatelessResetPacket( + QuicConnectionId connection_id, size_t received_packet_length); + + // A map from a recently closed connection_id to the number of packets + // received after the termination of the connection bound to the + // connection_id. + struct QUIC_NO_EXPORT ConnectionIdData { + ConnectionIdData(int num_packets, QuicTime time_added, + TimeWaitAction action, TimeWaitConnectionInfo info); + + ConnectionIdData(const ConnectionIdData& other) = delete; + ConnectionIdData(ConnectionIdData&& other); + + ~ConnectionIdData(); + + int num_packets; + QuicTime time_added; + TimeWaitAction action; + TimeWaitConnectionInfo info; + }; + + // QuicheLinkedHashMap allows lookup by ConnectionId + // and traversal in add order. + using ConnectionIdMap = + quiche::QuicheLinkedHashMap; + // Do not use find/emplace/erase on this map directly. Use + // FindConnectionIdDataInMap, AddConnectionIdDateToMap, + // RemoveConnectionDataFromMap instead. + ConnectionIdMap connection_id_map_; + + // TODO(haoyuewang) Consider making connection_id_map_ a map of shared pointer + // and remove the indirect map. + // A connection can have multiple unretired ConnectionIds when it is closed. + // These Ids have the same ConnectionIdData entry in connection_id_map_. To + // find the entry, look up the cannoical ConnectionId in + // indirect_connection_id_map_ first, and look up connection_id_map_ with the + // cannoical ConnectionId. + absl::flat_hash_map + indirect_connection_id_map_; + + // Find an iterator for the given connection_id. Returns + // connection_id_map_.end() if none found. + ConnectionIdMap::iterator FindConnectionIdDataInMap( + const QuicConnectionId& connection_id); + // Inserts a ConnectionIdData entry to connection_id_map_. + void AddConnectionIdDataToMap(const QuicConnectionId& canonical_connection_id, + int num_packets, TimeWaitAction action, + TimeWaitConnectionInfo info); + // Removes a ConnectionIdData entry in connection_id_map_. + void RemoveConnectionDataFromMap(ConnectionIdMap::iterator it); + + // Pending termination packets that need to be sent out to the peer when we + // are given a chance to write by the dispatcher. + quiche::QuicheCircularDeque> + pending_packets_queue_; + + // Time period for which connection_ids should remain in time wait state. + const QuicTime::Delta time_wait_period_; + + // Alarm to clean up connection_ids that have out lived their duration in + // time wait state. + std::unique_ptr connection_id_clean_up_alarm_; + + // Clock to efficiently measure approximate time. + const QuicClock* clock_; + + // Interface that writes given buffer to the socket. + QuicPacketWriter* writer_; + + // Interface that manages blocked writers. + Visitor* visitor_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TIME_WAIT_LIST_MANAGER_H_ diff --git a/quiche/quic/core/quic_time_wait_list_manager_test.cc b/quiche/quic/core/quic_time_wait_list_manager_test.cc new file mode 100644 index 000000000000..9567519adf09 --- /dev/null +++ b/quiche/quic/core/quic_time_wait_list_manager_test.cc @@ -0,0 +1,781 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_time_wait_list_manager.h" + +#include +#include +#include +#include + +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_quic_session_visitor.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/quic_time_wait_list_manager_peer.h" + +using testing::_; +using testing::Args; +using testing::Assign; +using testing::DoAll; +using testing::Matcher; +using testing::NiceMock; +using testing::Return; +using testing::ReturnPointee; +using testing::StrictMock; +using testing::Truly; + +namespace quic { +namespace test { +namespace { + +const size_t kTestPacketSize = 100; + +class FramerVisitorCapturingPublicReset : public NoOpFramerVisitor { + public: + FramerVisitorCapturingPublicReset(QuicConnectionId connection_id) + : connection_id_(connection_id) {} + ~FramerVisitorCapturingPublicReset() override = default; + + void OnPublicResetPacket(const QuicPublicResetPacket& public_reset) override { + public_reset_packet_ = public_reset; + } + + const QuicPublicResetPacket public_reset_packet() { + return public_reset_packet_; + } + + bool IsValidStatelessResetToken( + const StatelessResetToken& token) const override { + return token == QuicUtils::GenerateStatelessResetToken(connection_id_); + } + + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& packet) override { + stateless_reset_packet_ = packet; + } + + const QuicIetfStatelessResetPacket stateless_reset_packet() { + return stateless_reset_packet_; + } + + private: + QuicPublicResetPacket public_reset_packet_; + QuicIetfStatelessResetPacket stateless_reset_packet_; + QuicConnectionId connection_id_; +}; + +class MockAlarmFactory; +class MockAlarm : public QuicAlarm { + public: + explicit MockAlarm(QuicArenaScopedPtr delegate, int alarm_index, + MockAlarmFactory* factory) + : QuicAlarm(std::move(delegate)), + alarm_index_(alarm_index), + factory_(factory) {} + virtual ~MockAlarm() {} + + void SetImpl() override; + void CancelImpl() override; + + private: + int alarm_index_; + MockAlarmFactory* factory_; +}; + +class MockAlarmFactory : public QuicAlarmFactory { + public: + ~MockAlarmFactory() override {} + + // Creates a new platform-specific alarm which will be configured to notify + // |delegate| when the alarm fires. Returns an alarm allocated on the heap. + // Caller takes ownership of the new alarm, which will not yet be "set" to + // fire. + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override { + return new MockAlarm(QuicArenaScopedPtr(delegate), + alarm_index_++, this); + } + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override { + if (arena != nullptr) { + return arena->New(std::move(delegate), alarm_index_++, this); + } + return QuicArenaScopedPtr( + new MockAlarm(std::move(delegate), alarm_index_++, this)); + } + MOCK_METHOD(void, OnAlarmSet, (int, QuicTime), ()); + MOCK_METHOD(void, OnAlarmCancelled, (int), ()); + + private: + int alarm_index_ = 0; +}; + +void MockAlarm::SetImpl() { factory_->OnAlarmSet(alarm_index_, deadline()); } + +void MockAlarm::CancelImpl() { factory_->OnAlarmCancelled(alarm_index_); } + +class QuicTimeWaitListManagerTest : public QuicTest { + protected: + QuicTimeWaitListManagerTest() + : time_wait_list_manager_(&writer_, &visitor_, &clock_, &alarm_factory_), + connection_id_(TestConnectionId(45)), + peer_address_(TestPeerIPAddress(), kTestPort), + writer_is_blocked_(false) {} + + ~QuicTimeWaitListManagerTest() override = default; + + void SetUp() override { + EXPECT_CALL(writer_, IsWriteBlocked()) + .WillRepeatedly(ReturnPointee(&writer_is_blocked_)); + } + + void AddConnectionId(QuicConnectionId connection_id, + QuicTimeWaitListManager::TimeWaitAction action) { + AddConnectionId(connection_id, QuicVersionMax(), action, nullptr); + } + + void AddStatelessConnectionId(QuicConnectionId connection_id) { + std::vector> termination_packets; + termination_packets.push_back(std::unique_ptr( + new QuicEncryptedPacket(nullptr, 0, false))); + time_wait_list_manager_.AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + TimeWaitConnectionInfo(false, &termination_packets, {connection_id})); + } + + void AddConnectionId( + QuicConnectionId connection_id, ParsedQuicVersion version, + QuicTimeWaitListManager::TimeWaitAction action, + std::vector>* packets) { + time_wait_list_manager_.AddConnectionIdToTimeWait( + action, TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), + packets, {connection_id})); + } + + bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) { + return time_wait_list_manager_.IsConnectionIdInTimeWait(connection_id); + } + + void ProcessPacket(QuicConnectionId connection_id) { + time_wait_list_manager_.ProcessPacket( + self_address_, peer_address_, connection_id, GOOGLE_QUIC_PACKET, + kTestPacketSize, std::make_unique()); + } + + QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, uint64_t packet_number) { + return quic::test::ConstructEncryptedPacket(destination_connection_id, + source_connection_id, false, + false, packet_number, "data"); + } + + MockClock clock_; + MockAlarmFactory alarm_factory_; + NiceMock writer_; + StrictMock visitor_; + QuicTimeWaitListManager time_wait_list_manager_; + QuicConnectionId connection_id_; + QuicSocketAddress self_address_; + QuicSocketAddress peer_address_; + bool writer_is_blocked_; +}; + +bool ValidPublicResetPacketPredicate( + QuicConnectionId expected_connection_id, + const std::tuple& packet_buffer) { + FramerVisitorCapturingPublicReset visitor(expected_connection_id); + QuicFramer framer(AllSupportedVersions(), QuicTime::Zero(), + Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength); + framer.set_visitor(&visitor); + QuicEncryptedPacket encrypted(std::get<0>(packet_buffer), + std::get<1>(packet_buffer)); + framer.ProcessPacket(encrypted); + QuicPublicResetPacket packet = visitor.public_reset_packet(); + bool public_reset_is_valid = + expected_connection_id == packet.connection_id && + TestPeerIPAddress() == packet.client_address.host() && + kTestPort == packet.client_address.port(); + + QuicIetfStatelessResetPacket stateless_reset = + visitor.stateless_reset_packet(); + + StatelessResetToken expected_stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(expected_connection_id); + + bool stateless_reset_is_valid = + stateless_reset.stateless_reset_token == expected_stateless_reset_token; + + return public_reset_is_valid || stateless_reset_is_valid; +} + +Matcher> PublicResetPacketEq( + QuicConnectionId connection_id) { + return Truly( + [connection_id](const std::tuple packet_buffer) { + return ValidPublicResetPacketPredicate(connection_id, packet_buffer); + }); +} + +TEST_F(QuicTimeWaitListManagerTest, CheckConnectionIdInTimeWait) { + EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_)); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + AddConnectionId(connection_id_, QuicTimeWaitListManager::DO_NOTHING); + EXPECT_EQ(1u, time_wait_list_manager_.num_connections()); + EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); +} + +TEST_F(QuicTimeWaitListManagerTest, CheckStatelessConnectionIdInTimeWait) { + EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_)); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + AddStatelessConnectionId(connection_id_); + EXPECT_EQ(1u, time_wait_list_manager_.num_connections()); + EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); +} + +TEST_F(QuicTimeWaitListManagerTest, SendVersionNegotiationPacket) { + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), /*ietf_quic=*/false, + /*use_length_prefix=*/false, AllSupportedVersions())); + EXPECT_CALL(writer_, WritePacket(_, packet->length(), self_address_.host(), + peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + time_wait_list_manager_.SendVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), /*ietf_quic=*/false, + /*use_length_prefix=*/false, AllSupportedVersions(), self_address_, + peer_address_, std::make_unique()); + EXPECT_EQ(0u, time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, + SendIetfVersionNegotiationPacketWithoutLengthPrefix) { + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), /*ietf_quic=*/true, + /*use_length_prefix=*/false, AllSupportedVersions())); + EXPECT_CALL(writer_, WritePacket(_, packet->length(), self_address_.host(), + peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + time_wait_list_manager_.SendVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), /*ietf_quic=*/true, + /*use_length_prefix=*/false, AllSupportedVersions(), self_address_, + peer_address_, std::make_unique()); + EXPECT_EQ(0u, time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, SendIetfVersionNegotiationPacket) { + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), /*ietf_quic=*/true, + /*use_length_prefix=*/true, AllSupportedVersions())); + EXPECT_CALL(writer_, WritePacket(_, packet->length(), self_address_.host(), + peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + time_wait_list_manager_.SendVersionNegotiationPacket( + connection_id_, EmptyQuicConnectionId(), /*ietf_quic=*/true, + /*use_length_prefix=*/true, AllSupportedVersions(), self_address_, + peer_address_, std::make_unique()); + EXPECT_EQ(0u, time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, + SendIetfVersionNegotiationPacketWithClientConnectionId) { + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + connection_id_, TestConnectionId(0x33), /*ietf_quic=*/true, + /*use_length_prefix=*/true, AllSupportedVersions())); + EXPECT_CALL(writer_, WritePacket(_, packet->length(), self_address_.host(), + peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + time_wait_list_manager_.SendVersionNegotiationPacket( + connection_id_, TestConnectionId(0x33), /*ietf_quic=*/true, + /*use_length_prefix=*/true, AllSupportedVersions(), self_address_, + peer_address_, std::make_unique()); + EXPECT_EQ(0u, time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, SendConnectionClose) { + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + AddConnectionId(connection_id_, QuicVersionMax(), + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + &termination_packets); + EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, + self_address_.host(), peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + ProcessPacket(connection_id_); +} + +TEST_F(QuicTimeWaitListManagerTest, SendTwoConnectionCloses) { + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + AddConnectionId(connection_id_, QuicVersionMax(), + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + &termination_packets); + EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, + self_address_.host(), peer_address_, _)) + .Times(2) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + ProcessPacket(connection_id_); +} + +TEST_F(QuicTimeWaitListManagerTest, SendPublicReset) { + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + AddConnectionId(connection_id_, + QuicTimeWaitListManager::SEND_STATELESS_RESET); + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(connection_id_))) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + + ProcessPacket(connection_id_); +} + +TEST_F(QuicTimeWaitListManagerTest, SendPublicResetWithExponentialBackOff) { + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + AddConnectionId(connection_id_, + QuicTimeWaitListManager::SEND_STATELESS_RESET); + EXPECT_EQ(1u, time_wait_list_manager_.num_connections()); + for (int packet_number = 1; packet_number < 101; ++packet_number) { + if ((packet_number & (packet_number - 1)) == 0) { + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + } + ProcessPacket(connection_id_); + // Send public reset with exponential back off. + if ((packet_number & (packet_number - 1)) == 0) { + EXPECT_TRUE(QuicTimeWaitListManagerPeer::ShouldSendResponse( + &time_wait_list_manager_, packet_number)); + } else { + EXPECT_FALSE(QuicTimeWaitListManagerPeer::ShouldSendResponse( + &time_wait_list_manager_, packet_number)); + } + } +} + +TEST_F(QuicTimeWaitListManagerTest, NoPublicResetForStatelessConnections) { + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + AddStatelessConnectionId(connection_id_); + + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + ProcessPacket(connection_id_); +} + +TEST_F(QuicTimeWaitListManagerTest, CleanUpOldConnectionIds) { + const size_t kConnectionIdCount = 100; + const size_t kOldConnectionIdCount = 31; + + // Add connection_ids such that their expiry time is time_wait_period_. + for (uint64_t conn_id = 1; conn_id <= kOldConnectionIdCount; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id)); + AddConnectionId(connection_id, QuicTimeWaitListManager::DO_NOTHING); + } + EXPECT_EQ(kOldConnectionIdCount, time_wait_list_manager_.num_connections()); + + // Add remaining connection_ids such that their add time is + // 2 * time_wait_period_. + const QuicTime::Delta time_wait_period = + QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); + clock_.AdvanceTime(time_wait_period); + for (uint64_t conn_id = kOldConnectionIdCount + 1; + conn_id <= kConnectionIdCount; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id)); + AddConnectionId(connection_id, QuicTimeWaitListManager::DO_NOTHING); + } + EXPECT_EQ(kConnectionIdCount, time_wait_list_manager_.num_connections()); + + QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39); + // Now set the current time as time_wait_period + offset usecs. + clock_.AdvanceTime(offset); + // After all the old connection_ids are cleaned up, check the next alarm + // interval. + QuicTime next_alarm_time = clock_.Now() + time_wait_period - offset; + EXPECT_CALL(alarm_factory_, OnAlarmSet(_, next_alarm_time)); + + time_wait_list_manager_.CleanUpOldConnectionIds(); + for (uint64_t conn_id = 1; conn_id <= kConnectionIdCount; ++conn_id) { + QuicConnectionId connection_id = TestConnectionId(conn_id); + EXPECT_EQ(conn_id > kOldConnectionIdCount, + IsConnectionIdInTimeWait(connection_id)) + << "kOldConnectionIdCount: " << kOldConnectionIdCount + << " connection_id: " << connection_id; + } + EXPECT_EQ(kConnectionIdCount - kOldConnectionIdCount, + time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, + CleanUpOldConnectionIdsForMultipleConnectionIdsPerConnection) { + connection_id_ = TestConnectionId(7); + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(TestConnectionId(8))); + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + + // Add a CONNECTION_CLOSE termination packet. + std::vector active_connection_ids{connection_id_, + TestConnectionId(8)}; + time_wait_list_manager_.AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + active_connection_ids, QuicTime::Delta::Zero())); + + EXPECT_TRUE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7))); + EXPECT_TRUE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8))); + + // Remove these IDs. + const QuicTime::Delta time_wait_period = + QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); + clock_.AdvanceTime(time_wait_period); + time_wait_list_manager_.CleanUpOldConnectionIds(); + + EXPECT_FALSE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(7))); + EXPECT_FALSE( + time_wait_list_manager_.IsConnectionIdInTimeWait(TestConnectionId(8))); +} + +TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) { + QuicConnectionId connection_id = TestConnectionId(1); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id)); + AddConnectionId(connection_id, QuicTimeWaitListManager::SEND_STATELESS_RESET); + std::unique_ptr packet(ConstructEncryptedPacket( + connection_id, EmptyQuicConnectionId(), /*packet_number=*/234)); + // Let first write through. + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(connection_id))) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length()))); + ProcessPacket(connection_id); + + // write block for the next packet. + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(connection_id))) + .WillOnce(DoAll(Assign(&writer_is_blocked_, true), + Return(WriteResult(WRITE_STATUS_BLOCKED, EAGAIN)))); + EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_)); + ProcessPacket(connection_id); + // 3rd packet. No public reset should be sent; + ProcessPacket(connection_id); + + // write packet should not be called since we are write blocked but the + // should be queued. + QuicConnectionId other_connection_id = TestConnectionId(2); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(other_connection_id)); + AddConnectionId(other_connection_id, + QuicTimeWaitListManager::SEND_STATELESS_RESET); + std::unique_ptr other_packet(ConstructEncryptedPacket( + other_connection_id, EmptyQuicConnectionId(), /*packet_number=*/23423)); + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _)).Times(0); + EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_)); + ProcessPacket(other_connection_id); + EXPECT_EQ(2u, time_wait_list_manager_.num_connections()); + + // Now expect all the write blocked public reset packets to be sent again. + writer_is_blocked_ = false; + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(connection_id))) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length()))); + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(other_connection_id))) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length()))); + time_wait_list_manager_.OnBlockedWriterCanWrite(); +} + +TEST_F(QuicTimeWaitListManagerTest, AddConnectionIdTwice) { + // Add connection_ids such that their expiry time is time_wait_period_. + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + AddConnectionId(connection_id_, QuicTimeWaitListManager::DO_NOTHING); + EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); + const size_t kConnectionCloseLength = 100; + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + AddConnectionId(connection_id_, QuicVersionMax(), + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + &termination_packets); + EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); + EXPECT_EQ(1u, time_wait_list_manager_.num_connections()); + + EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, + self_address_.host(), peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + ProcessPacket(connection_id_); + + const QuicTime::Delta time_wait_period = + QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); + + QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39); + clock_.AdvanceTime(offset + time_wait_period); + // Now set the current time as time_wait_period + offset usecs. + QuicTime next_alarm_time = clock_.Now() + time_wait_period; + EXPECT_CALL(alarm_factory_, OnAlarmSet(_, next_alarm_time)); + + time_wait_list_manager_.CleanUpOldConnectionIds(); + EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_)); + EXPECT_EQ(0u, time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, ConnectionIdsOrderedByTime) { + // Simple randomization: the values of connection_ids are randomly swapped. + // If the container is broken, the test will be 50% flaky. + const uint64_t conn_id1 = QuicRandom::GetInstance()->RandUint64() % 2; + const QuicConnectionId connection_id1 = TestConnectionId(conn_id1); + const QuicConnectionId connection_id2 = TestConnectionId(1 - conn_id1); + + // 1 will hash lower than 2, but we add it later. They should come out in the + // add order, not hash order. + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id1)); + AddConnectionId(connection_id1, QuicTimeWaitListManager::DO_NOTHING); + clock_.AdvanceTime(QuicTime::Delta::FromMicroseconds(10)); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id2)); + AddConnectionId(connection_id2, QuicTimeWaitListManager::DO_NOTHING); + EXPECT_EQ(2u, time_wait_list_manager_.num_connections()); + + const QuicTime::Delta time_wait_period = + QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); + clock_.AdvanceTime(time_wait_period - QuicTime::Delta::FromMicroseconds(9)); + + EXPECT_CALL(alarm_factory_, OnAlarmSet(_, _)); + + time_wait_list_manager_.CleanUpOldConnectionIds(); + EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id1)); + EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id2)); + EXPECT_EQ(1u, time_wait_list_manager_.num_connections()); +} + +TEST_F(QuicTimeWaitListManagerTest, MaxConnectionsTest) { + // Basically, shut off time-based eviction. + SetQuicFlag(quic_time_wait_list_seconds, 10000000000); + SetQuicFlag(quic_time_wait_list_max_connections, 5); + + uint64_t current_conn_id = 0; + const int64_t kMaxConnections = + GetQuicFlag(quic_time_wait_list_max_connections); + // Add exactly the maximum number of connections + for (int64_t i = 0; i < kMaxConnections; ++i) { + ++current_conn_id; + QuicConnectionId current_connection_id = TestConnectionId(current_conn_id); + EXPECT_FALSE(IsConnectionIdInTimeWait(current_connection_id)); + EXPECT_CALL(visitor_, + OnConnectionAddedToTimeWaitList(current_connection_id)); + AddConnectionId(current_connection_id, QuicTimeWaitListManager::DO_NOTHING); + EXPECT_EQ(current_conn_id, time_wait_list_manager_.num_connections()); + EXPECT_TRUE(IsConnectionIdInTimeWait(current_connection_id)); + } + + // Now keep adding. Since we're already at the max, every new connection-id + // will evict the oldest one. + for (int64_t i = 0; i < kMaxConnections; ++i) { + ++current_conn_id; + QuicConnectionId current_connection_id = TestConnectionId(current_conn_id); + const QuicConnectionId id_to_evict = + TestConnectionId(current_conn_id - kMaxConnections); + EXPECT_TRUE(IsConnectionIdInTimeWait(id_to_evict)); + EXPECT_FALSE(IsConnectionIdInTimeWait(current_connection_id)); + EXPECT_CALL(visitor_, + OnConnectionAddedToTimeWaitList(current_connection_id)); + AddConnectionId(current_connection_id, QuicTimeWaitListManager::DO_NOTHING); + EXPECT_EQ(static_cast(kMaxConnections), + time_wait_list_manager_.num_connections()); + EXPECT_FALSE(IsConnectionIdInTimeWait(id_to_evict)); + EXPECT_TRUE(IsConnectionIdInTimeWait(current_connection_id)); + } +} + +TEST_F(QuicTimeWaitListManagerTest, ZeroMaxConnections) { + // Basically, shut off time-based eviction. + SetQuicFlag(quic_time_wait_list_seconds, 10000000000); + // Keep time wait list empty. + SetQuicFlag(quic_time_wait_list_max_connections, 0); + + uint64_t current_conn_id = 0; + // Add exactly the maximum number of connections + for (int64_t i = 0; i < 10; ++i) { + ++current_conn_id; + QuicConnectionId current_connection_id = TestConnectionId(current_conn_id); + EXPECT_FALSE(IsConnectionIdInTimeWait(current_connection_id)); + EXPECT_CALL(visitor_, + OnConnectionAddedToTimeWaitList(current_connection_id)); + AddConnectionId(current_connection_id, QuicTimeWaitListManager::DO_NOTHING); + // Verify time wait list always has 1 connection. + EXPECT_EQ(1u, time_wait_list_manager_.num_connections()); + EXPECT_TRUE(IsConnectionIdInTimeWait(current_connection_id)); + } +} + +// Regression test for b/116200989. +TEST_F(QuicTimeWaitListManagerTest, + SendStatelessResetInResponseToShortHeaders) { + // This test mimics a scenario where an ENCRYPTION_INITIAL connection close is + // added as termination packet for an IETF connection ID. However, a short + // header packet is received later. + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + time_wait_list_manager_.AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + {connection_id_})); + + // Termination packet is not encrypted, instead, send stateless reset. + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(connection_id_))) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); + // Processes IETF short header packet. + time_wait_list_manager_.ProcessPacket( + self_address_, peer_address_, connection_id_, + IETF_QUIC_SHORT_HEADER_PACKET, kTestPacketSize, + std::make_unique()); +} + +TEST_F(QuicTimeWaitListManagerTest, + SendConnectionClosePacketsInResponseToShortHeaders) { + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + // Add a CONNECTION_CLOSE termination packet. + time_wait_list_manager_.AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + {connection_id_})); + EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, + self_address_.host(), peer_address_, _)) + .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); + + // Processes IETF short header packet. + time_wait_list_manager_.ProcessPacket( + self_address_, peer_address_, connection_id_, + IETF_QUIC_SHORT_HEADER_PACKET, kTestPacketSize, + std::make_unique()); +} + +TEST_F(QuicTimeWaitListManagerTest, + SendConnectionClosePacketsForMultipleConnectionIds) { + connection_id_ = TestConnectionId(7); + const size_t kConnectionCloseLength = 100; + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); + EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(TestConnectionId(8))); + std::vector> termination_packets; + termination_packets.push_back( + std::unique_ptr(new QuicEncryptedPacket( + new char[kConnectionCloseLength], kConnectionCloseLength, true))); + + // Add a CONNECTION_CLOSE termination packet. + std::vector active_connection_ids{connection_id_, + TestConnectionId(8)}; + time_wait_list_manager_.AddConnectionIdToTimeWait( + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, + active_connection_ids, QuicTime::Delta::Zero())); + + EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, + self_address_.host(), peer_address_, _)) + .Times(2) + .WillRepeatedly(Return(WriteResult(WRITE_STATUS_OK, 1))); + // Processes IETF short header packet. + for (auto const& cid : active_connection_ids) { + time_wait_list_manager_.ProcessPacket( + self_address_, peer_address_, cid, IETF_QUIC_SHORT_HEADER_PACKET, + kTestPacketSize, std::make_unique()); + } +} + +// Regression test for b/184053898. +TEST_F(QuicTimeWaitListManagerTest, DonotCrashOnNullStatelessReset) { + // Received a packet with length < + // QuicFramer::GetMinStatelessResetPacketLength(), and this will result in a + // null stateless reset. + time_wait_list_manager_.SendPublicReset( + self_address_, peer_address_, TestConnectionId(1), + /*ietf_quic=*/true, + /*received_packet_length=*/ + QuicFramer::GetMinStatelessResetPacketLength() - 1, + /*packet_context=*/nullptr); +} + +TEST_F(QuicTimeWaitListManagerTest, SendOrQueueNullPacket) { + QuicTimeWaitListManagerPeer::SendOrQueuePacket(&time_wait_list_manager_, + nullptr, nullptr); +} + +TEST_F(QuicTimeWaitListManagerTest, TooManyPendingPackets) { + SetQuicFlag(quic_time_wait_list_max_pending_packets, 5); + const size_t kNumOfUnProcessablePackets = 2048; + EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_)) + .Times(testing::AnyNumber()); + // Write block for the next packets. + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(TestConnectionId(1)))) + .WillOnce(DoAll(Assign(&writer_is_blocked_, true), + Return(WriteResult(WRITE_STATUS_BLOCKED, EAGAIN)))); + for (size_t i = 0; i < kNumOfUnProcessablePackets; ++i) { + time_wait_list_manager_.SendPublicReset( + self_address_, peer_address_, TestConnectionId(1), + /*ietf_quic=*/true, + /*received_packet_length=*/ + QuicFramer::GetMinStatelessResetPacketLength() + 1, + /*packet_context=*/nullptr); + } + // Verify pending packet queue size is limited. + EXPECT_EQ(5u, QuicTimeWaitListManagerPeer::PendingPacketsQueueSize( + &time_wait_list_manager_)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_trace_visitor.cc b/quiche/quic/core/quic_trace_visitor.cc new file mode 100644 index 000000000000..04bbda64451e --- /dev/null +++ b/quiche/quic/core/quic_trace_visitor.cc @@ -0,0 +1,341 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_trace_visitor.h" + +#include + +#include "quiche/quic/core/quic_types.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +quic_trace::EncryptionLevel EncryptionLevelToProto(EncryptionLevel level) { + switch (level) { + case ENCRYPTION_INITIAL: + return quic_trace::ENCRYPTION_INITIAL; + case ENCRYPTION_HANDSHAKE: + return quic_trace::ENCRYPTION_HANDSHAKE; + case ENCRYPTION_ZERO_RTT: + return quic_trace::ENCRYPTION_0RTT; + case ENCRYPTION_FORWARD_SECURE: + return quic_trace::ENCRYPTION_1RTT; + case NUM_ENCRYPTION_LEVELS: + QUIC_BUG(quic_bug_10284_1) << "Invalid encryption level specified"; + return quic_trace::ENCRYPTION_UNKNOWN; + } +} + +QuicTraceVisitor::QuicTraceVisitor(const QuicConnection* connection) + : connection_(connection), + start_time_(connection_->clock()->ApproximateNow()) { + std::string binary_connection_id(connection->connection_id().data(), + connection->connection_id().length()); + // We assume that the connection ID in gQUIC is equivalent to the + // server-chosen client-selected ID. + switch (connection->perspective()) { + case Perspective::IS_CLIENT: + trace_.set_destination_connection_id(binary_connection_id); + break; + case Perspective::IS_SERVER: + trace_.set_source_connection_id(binary_connection_id); + break; + } +} + +void QuicTraceVisitor::OnPacketSent( + QuicPacketNumber packet_number, QuicPacketLength packet_length, + bool /*has_crypto_handshake*/, TransmissionType /*transmission_type*/, + EncryptionLevel encryption_level, const QuicFrames& retransmittable_frames, + const QuicFrames& /*nonretransmittable_frames*/, QuicTime sent_time) { + quic_trace::Event* event = trace_.add_events(); + event->set_event_type(quic_trace::PACKET_SENT); + event->set_time_us(ConvertTimestampToRecordedFormat(sent_time)); + event->set_packet_number(packet_number.ToUint64()); + event->set_packet_size(packet_length); + event->set_encryption_level(EncryptionLevelToProto(encryption_level)); + + for (const QuicFrame& frame : retransmittable_frames) { + switch (frame.type) { + case STREAM_FRAME: + case RST_STREAM_FRAME: + case CONNECTION_CLOSE_FRAME: + case WINDOW_UPDATE_FRAME: + case BLOCKED_FRAME: + case PING_FRAME: + case HANDSHAKE_DONE_FRAME: + case ACK_FREQUENCY_FRAME: + PopulateFrameInfo(frame, event->add_frames()); + break; + + case PADDING_FRAME: + case MTU_DISCOVERY_FRAME: + case STOP_WAITING_FRAME: + case ACK_FRAME: + QUIC_BUG(quic_bug_12732_1) + << "Frames of type are not retransmittable and are not supposed " + "to be in retransmittable_frames"; + break; + + // New IETF frames, not used in current gQUIC version. + case NEW_CONNECTION_ID_FRAME: + case RETIRE_CONNECTION_ID_FRAME: + case MAX_STREAMS_FRAME: + case STREAMS_BLOCKED_FRAME: + case PATH_RESPONSE_FRAME: + case PATH_CHALLENGE_FRAME: + case STOP_SENDING_FRAME: + case MESSAGE_FRAME: + case CRYPTO_FRAME: + case NEW_TOKEN_FRAME: + break; + + // Ignore gQUIC-specific frames. + case GOAWAY_FRAME: + break; + + case NUM_FRAME_TYPES: + QUIC_BUG(quic_bug_10284_2) << "Unknown frame type encountered"; + break; + } + } + + // Output PCC DebugState on packet sent for analysis. + if (connection_->sent_packet_manager() + .GetSendAlgorithm() + ->GetCongestionControlType() == kPCC) { + PopulateTransportState(event->mutable_transport_state()); + } +} + +void QuicTraceVisitor::PopulateFrameInfo(const QuicFrame& frame, + quic_trace::Frame* frame_record) { + switch (frame.type) { + case STREAM_FRAME: { + frame_record->set_frame_type(quic_trace::STREAM); + + quic_trace::StreamFrameInfo* info = + frame_record->mutable_stream_frame_info(); + info->set_stream_id(frame.stream_frame.stream_id); + info->set_fin(frame.stream_frame.fin); + info->set_offset(frame.stream_frame.offset); + info->set_length(frame.stream_frame.data_length); + break; + } + + case ACK_FRAME: { + frame_record->set_frame_type(quic_trace::ACK); + + quic_trace::AckInfo* info = frame_record->mutable_ack_info(); + info->set_ack_delay_us(frame.ack_frame->ack_delay_time.ToMicroseconds()); + for (const auto& interval : frame.ack_frame->packets) { + quic_trace::AckBlock* block = info->add_acked_packets(); + // We record intervals as [a, b], whereas the in-memory representation + // we currently use is [a, b). + block->set_first_packet(interval.min().ToUint64()); + block->set_last_packet(interval.max().ToUint64() - 1); + } + break; + } + + case RST_STREAM_FRAME: { + frame_record->set_frame_type(quic_trace::RESET_STREAM); + + quic_trace::ResetStreamInfo* info = + frame_record->mutable_reset_stream_info(); + info->set_stream_id(frame.rst_stream_frame->stream_id); + info->set_final_offset(frame.rst_stream_frame->byte_offset); + info->set_application_error_code(frame.rst_stream_frame->error_code); + break; + } + + case CONNECTION_CLOSE_FRAME: { + frame_record->set_frame_type(quic_trace::CONNECTION_CLOSE); + + quic_trace::CloseInfo* info = frame_record->mutable_close_info(); + info->set_error_code(frame.connection_close_frame->quic_error_code); + info->set_reason_phrase(frame.connection_close_frame->error_details); + info->set_close_type(static_cast( + frame.connection_close_frame->close_type)); + info->set_transport_close_frame_type( + frame.connection_close_frame->transport_close_frame_type); + break; + } + + case GOAWAY_FRAME: + // Do not bother logging this since the frame in question is + // gQUIC-specific. + break; + + case WINDOW_UPDATE_FRAME: { + bool is_connection = frame.window_update_frame.stream_id == 0; + frame_record->set_frame_type(is_connection ? quic_trace::MAX_DATA + : quic_trace::MAX_STREAM_DATA); + + quic_trace::FlowControlInfo* info = + frame_record->mutable_flow_control_info(); + info->set_max_data(frame.window_update_frame.max_data); + if (!is_connection) { + info->set_stream_id(frame.window_update_frame.stream_id); + } + break; + } + + case BLOCKED_FRAME: { + bool is_connection = frame.blocked_frame.stream_id == 0; + frame_record->set_frame_type(is_connection ? quic_trace::BLOCKED + : quic_trace::STREAM_BLOCKED); + + quic_trace::FlowControlInfo* info = + frame_record->mutable_flow_control_info(); + if (!is_connection) { + info->set_stream_id(frame.window_update_frame.stream_id); + } + break; + } + + case PING_FRAME: + case MTU_DISCOVERY_FRAME: + case HANDSHAKE_DONE_FRAME: + frame_record->set_frame_type(quic_trace::PING); + break; + + case PADDING_FRAME: + frame_record->set_frame_type(quic_trace::PADDING); + break; + + case STOP_WAITING_FRAME: + // We're going to pretend those do not exist. + break; + + // New IETF frames, not used in current gQUIC version. + case NEW_CONNECTION_ID_FRAME: + case RETIRE_CONNECTION_ID_FRAME: + case MAX_STREAMS_FRAME: + case STREAMS_BLOCKED_FRAME: + case PATH_RESPONSE_FRAME: + case PATH_CHALLENGE_FRAME: + case STOP_SENDING_FRAME: + case MESSAGE_FRAME: + case CRYPTO_FRAME: + case NEW_TOKEN_FRAME: + case ACK_FREQUENCY_FRAME: + break; + + case NUM_FRAME_TYPES: + QUIC_BUG(quic_bug_10284_3) << "Unknown frame type encountered"; + break; + } +} + +void QuicTraceVisitor::OnIncomingAck( + QuicPacketNumber /*ack_packet_number*/, EncryptionLevel ack_decrypted_level, + const QuicAckFrame& ack_frame, QuicTime ack_receive_time, + QuicPacketNumber /*largest_observed*/, bool /*rtt_updated*/, + QuicPacketNumber /*least_unacked_sent_packet*/) { + quic_trace::Event* event = trace_.add_events(); + event->set_time_us(ConvertTimestampToRecordedFormat(ack_receive_time)); + event->set_packet_number(connection_->GetLargestReceivedPacket().ToUint64()); + event->set_event_type(quic_trace::PACKET_RECEIVED); + event->set_encryption_level(EncryptionLevelToProto(ack_decrypted_level)); + + // TODO(vasilvv): consider removing this copy. + QuicAckFrame copy_of_ack = ack_frame; + PopulateFrameInfo(QuicFrame(©_of_ack), event->add_frames()); + PopulateTransportState(event->mutable_transport_state()); +} + +void QuicTraceVisitor::OnPacketLoss(QuicPacketNumber lost_packet_number, + EncryptionLevel encryption_level, + TransmissionType /*transmission_type*/, + QuicTime detection_time) { + quic_trace::Event* event = trace_.add_events(); + event->set_time_us(ConvertTimestampToRecordedFormat(detection_time)); + event->set_event_type(quic_trace::PACKET_LOST); + event->set_packet_number(lost_packet_number.ToUint64()); + PopulateTransportState(event->mutable_transport_state()); + event->set_encryption_level(EncryptionLevelToProto(encryption_level)); +} + +void QuicTraceVisitor::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame, + const QuicTime& receive_time) { + quic_trace::Event* event = trace_.add_events(); + event->set_time_us(ConvertTimestampToRecordedFormat(receive_time)); + event->set_event_type(quic_trace::PACKET_RECEIVED); + event->set_packet_number(connection_->GetLargestReceivedPacket().ToUint64()); + + PopulateFrameInfo(QuicFrame(frame), event->add_frames()); +} + +void QuicTraceVisitor::OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& version) { + uint32_t tag = + quiche::QuicheEndian::HostToNet32(CreateQuicVersionLabel(version)); + std::string binary_tag(reinterpret_cast(&tag), sizeof(tag)); + trace_.set_protocol_version(binary_tag); +} + +void QuicTraceVisitor::OnApplicationLimited() { + quic_trace::Event* event = trace_.add_events(); + event->set_time_us( + ConvertTimestampToRecordedFormat(connection_->clock()->ApproximateNow())); + event->set_event_type(quic_trace::APPLICATION_LIMITED); +} + +void QuicTraceVisitor::OnAdjustNetworkParameters(QuicBandwidth bandwidth, + QuicTime::Delta rtt, + QuicByteCount /*old_cwnd*/, + QuicByteCount /*new_cwnd*/) { + quic_trace::Event* event = trace_.add_events(); + event->set_time_us( + ConvertTimestampToRecordedFormat(connection_->clock()->ApproximateNow())); + event->set_event_type(quic_trace::EXTERNAL_PARAMETERS); + + quic_trace::ExternalNetworkParameters* parameters = + event->mutable_external_network_parameters(); + if (!bandwidth.IsZero()) { + parameters->set_bandwidth_bps(bandwidth.ToBitsPerSecond()); + } + if (!rtt.IsZero()) { + parameters->set_rtt_us(rtt.ToMicroseconds()); + } +} + +uint64_t QuicTraceVisitor::ConvertTimestampToRecordedFormat( + QuicTime timestamp) { + if (timestamp < start_time_) { + QUIC_BUG(quic_bug_10284_4) + << "Timestamp went back in time while recording a trace"; + return 0; + } + + return (timestamp - start_time_).ToMicroseconds(); +} + +void QuicTraceVisitor::PopulateTransportState( + quic_trace::TransportState* state) { + const RttStats* rtt_stats = connection_->sent_packet_manager().GetRttStats(); + state->set_min_rtt_us(rtt_stats->min_rtt().ToMicroseconds()); + state->set_smoothed_rtt_us(rtt_stats->smoothed_rtt().ToMicroseconds()); + state->set_last_rtt_us(rtt_stats->latest_rtt().ToMicroseconds()); + + state->set_cwnd_bytes( + connection_->sent_packet_manager().GetCongestionWindowInBytes()); + QuicByteCount in_flight = + connection_->sent_packet_manager().GetBytesInFlight(); + state->set_in_flight_bytes(in_flight); + state->set_pacing_rate_bps(connection_->sent_packet_manager() + .GetSendAlgorithm() + ->PacingRate(in_flight) + .ToBitsPerSecond()); + + if (connection_->sent_packet_manager() + .GetSendAlgorithm() + ->GetCongestionControlType() == kPCC) { + state->set_congestion_control_state( + connection_->sent_packet_manager().GetSendAlgorithm()->GetDebugState()); + } +} + +} // namespace quic diff --git a/quiche/quic/core/quic_trace_visitor.h b/quiche/quic/core/quic_trace_visitor.h new file mode 100644 index 000000000000..6a8c442e8c14 --- /dev/null +++ b/quiche/quic/core/quic_trace_visitor.h @@ -0,0 +1,75 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_TRACE_VISITOR_H_ +#define QUICHE_QUIC_CORE_QUIC_TRACE_VISITOR_H_ + +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_types.h" +#include "quic_trace/quic_trace.pb.h" + +namespace quic { + +// Records a QUIC trace protocol buffer for a QuicConnection. It's the +// responsibility of the user of this visitor to process or store the resulting +// trace, which can be accessed via trace(). +class QUIC_NO_EXPORT QuicTraceVisitor : public QuicConnectionDebugVisitor { + public: + explicit QuicTraceVisitor(const QuicConnection* connection); + + void OnPacketSent(QuicPacketNumber packet_number, + QuicPacketLength packet_length, bool has_crypto_handshake, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + QuicTime sent_time) override; + + void OnIncomingAck(QuicPacketNumber ack_packet_number, + EncryptionLevel ack_decrypted_level, + const QuicAckFrame& ack_frame, QuicTime ack_receive_time, + QuicPacketNumber largest_observed, bool rtt_updated, + QuicPacketNumber least_unacked_sent_packet) override; + + void OnPacketLoss(QuicPacketNumber lost_packet_number, + EncryptionLevel encryption_level, + TransmissionType transmission_type, + QuicTime detection_time) override; + + void OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame, + const QuicTime& receive_time) override; + + void OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& version) override; + + void OnApplicationLimited() override; + + void OnAdjustNetworkParameters(QuicBandwidth bandwidth, QuicTime::Delta rtt, + QuicByteCount old_cwnd, + QuicByteCount new_cwnd) override; + + // Returns a mutable pointer to the trace. The trace is owned by the + // visitor, but can be moved using Swap() method after the connection is + // finished. + quic_trace::Trace* trace() { return &trace_; } + + private: + // Converts QuicTime into a microsecond delta w.r.t. the beginning of the + // connection. + uint64_t ConvertTimestampToRecordedFormat(QuicTime timestamp); + // Populates a quic_trace::Frame message from |frame|. + void PopulateFrameInfo(const QuicFrame& frame, + quic_trace::Frame* frame_record); + // Populates a quic_trace::TransportState message from the associated + // connection. + void PopulateTransportState(quic_trace::TransportState* state); + + quic_trace::Trace trace_; + const QuicConnection* connection_; + const QuicTime start_time_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TRACE_VISITOR_H_ diff --git a/quiche/quic/core/quic_trace_visitor_test.cc b/quiche/quic/core/quic_trace_visitor_test.cc new file mode 100644 index 000000000000..5584ebe65e49 --- /dev/null +++ b/quiche/quic/core/quic_trace_visitor_test.cc @@ -0,0 +1,184 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_trace_visitor.h" + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/test_tools/simulator/switch.h" + +namespace quic::test { +namespace { + +const QuicByteCount kTransferSize = 1000 * kMaxOutgoingPacketSize; +const QuicByteCount kTestStreamNumber = 3; +const QuicTime::Delta kDelay = QuicTime::Delta::FromMilliseconds(20); + +// The trace for this test is generated using a simulator transfer. +class QuicTraceVisitorTest : public QuicTest { + public: + QuicTraceVisitorTest() { + QuicConnectionId connection_id = test::TestConnectionId(); + simulator::Simulator simulator; + simulator::QuicEndpoint client(&simulator, "Client", "Server", + Perspective::IS_CLIENT, connection_id); + simulator::QuicEndpoint server(&simulator, "Server", "Client", + Perspective::IS_SERVER, connection_id); + + const QuicBandwidth kBandwidth = QuicBandwidth::FromKBitsPerSecond(1000); + const QuicByteCount kBdp = kBandwidth * (2 * kDelay); + + // Create parameters such that some loss is observed. + simulator::Switch network_switch(&simulator, "Switch", 8, 0.5 * kBdp); + simulator::SymmetricLink client_link(&client, network_switch.port(1), + 2 * kBandwidth, kDelay); + simulator::SymmetricLink server_link(&server, network_switch.port(2), + kBandwidth, kDelay); + + QuicTraceVisitor visitor(client.connection()); + client.connection()->set_debug_visitor(&visitor); + + // Transfer about a megabyte worth of data from client to server. + const QuicTime::Delta kDeadline = + 3 * kBandwidth.TransferTime(kTransferSize); + client.AddBytesToTransfer(kTransferSize); + bool simulator_result = simulator.RunUntilOrTimeout( + [&]() { return server.bytes_received() >= kTransferSize; }, kDeadline); + QUICHE_CHECK(simulator_result); + + // Save the trace and ensure some loss was observed. + trace_.Swap(visitor.trace()); + QUICHE_CHECK_NE(0u, client.connection()->GetStats().packets_retransmitted); + packets_sent_ = client.connection()->GetStats().packets_sent; + } + + std::vector AllEventsWithType( + quic_trace::EventType event_type) { + std::vector result; + for (const auto& event : trace_.events()) { + if (event.event_type() == event_type) { + result.push_back(event); + } + } + return result; + } + + protected: + quic_trace::Trace trace_; + QuicPacketCount packets_sent_; +}; + +TEST_F(QuicTraceVisitorTest, ConnectionId) { + char expected_cid[] = {0, 0, 0, 0, 0, 0, 0, 42}; + EXPECT_EQ(std::string(expected_cid, sizeof(expected_cid)), + trace_.destination_connection_id()); +} + +TEST_F(QuicTraceVisitorTest, Version) { + std::string version = trace_.protocol_version(); + ASSERT_EQ(4u, version.size()); + // Ensure version isn't all-zeroes. + EXPECT_TRUE(version[0] != 0 || version[1] != 0 || version[2] != 0 || + version[3] != 0); +} + +// Check that basic metadata about sent packets is recorded. +TEST_F(QuicTraceVisitorTest, SentPacket) { + auto sent_packets = AllEventsWithType(quic_trace::PACKET_SENT); + EXPECT_EQ(packets_sent_, sent_packets.size()); + ASSERT_GT(sent_packets.size(), 0u); + + EXPECT_EQ(sent_packets[0].packet_size(), kDefaultMaxPacketSize); + EXPECT_EQ(sent_packets[0].packet_number(), 1u); +} + +// Ensure that every stream frame that was sent is recorded. +TEST_F(QuicTraceVisitorTest, SentStream) { + auto sent_packets = AllEventsWithType(quic_trace::PACKET_SENT); + + QuicIntervalSet offsets; + for (const quic_trace::Event& packet : sent_packets) { + for (const quic_trace::Frame& frame : packet.frames()) { + if (frame.frame_type() != quic_trace::STREAM) { + continue; + } + + const quic_trace::StreamFrameInfo& info = frame.stream_frame_info(); + if (info.stream_id() != kTestStreamNumber) { + continue; + } + + ASSERT_GT(info.length(), 0u); + offsets.Add(info.offset(), info.offset() + info.length()); + } + } + + ASSERT_EQ(1u, offsets.Size()); + EXPECT_EQ(0u, offsets.begin()->min()); + EXPECT_EQ(kTransferSize, offsets.rbegin()->max()); +} + +// Ensure that all packets are either acknowledged or lost. +TEST_F(QuicTraceVisitorTest, AckPackets) { + QuicIntervalSet packets; + for (const quic_trace::Event& packet : trace_.events()) { + if (packet.event_type() == quic_trace::PACKET_RECEIVED) { + for (const quic_trace::Frame& frame : packet.frames()) { + if (frame.frame_type() != quic_trace::ACK) { + continue; + } + + const quic_trace::AckInfo& info = frame.ack_info(); + for (const auto& block : info.acked_packets()) { + packets.Add(QuicPacketNumber(block.first_packet()), + QuicPacketNumber(block.last_packet()) + 1); + } + } + } + if (packet.event_type() == quic_trace::PACKET_LOST) { + packets.Add(QuicPacketNumber(packet.packet_number()), + QuicPacketNumber(packet.packet_number()) + 1); + } + } + + ASSERT_EQ(1u, packets.Size()); + EXPECT_EQ(QuicPacketNumber(1u), packets.begin()->min()); + // We leave some room (20 packets) for the packets which did not receive + // conclusive status at the end of simulation. + EXPECT_GT(packets.rbegin()->max(), QuicPacketNumber(packets_sent_ - 20)); +} + +TEST_F(QuicTraceVisitorTest, TransportState) { + auto acks = AllEventsWithType(quic_trace::PACKET_RECEIVED); + ASSERT_EQ(1, acks[0].frames_size()); + ASSERT_EQ(quic_trace::ACK, acks[0].frames(0).frame_type()); + + // Check that min-RTT at the end is a reasonable approximation. + EXPECT_LE((4 * kDelay).ToMicroseconds() * 1., + acks.rbegin()->transport_state().min_rtt_us()); + EXPECT_GE((4 * kDelay).ToMicroseconds() * 1.25, + acks.rbegin()->transport_state().min_rtt_us()); +} + +TEST_F(QuicTraceVisitorTest, EncryptionLevels) { + for (const auto& event : trace_.events()) { + switch (event.event_type()) { + case quic_trace::PACKET_SENT: + case quic_trace::PACKET_RECEIVED: + case quic_trace::PACKET_LOST: + ASSERT_TRUE(event.has_encryption_level()); + ASSERT_NE(event.encryption_level(), quic_trace::ENCRYPTION_UNKNOWN); + break; + + default: + break; + } + } +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/core/quic_transmission_info.cc b/quiche/quic/core/quic_transmission_info.cc new file mode 100644 index 000000000000..6263216e23a3 --- /dev/null +++ b/quiche/quic/core/quic_transmission_info.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_transmission_info.h" + +#include "absl/strings/str_cat.h" + +namespace quic { + +QuicTransmissionInfo::QuicTransmissionInfo() + : sent_time(QuicTime::Zero()), + bytes_sent(0), + encryption_level(ENCRYPTION_INITIAL), + transmission_type(NOT_RETRANSMISSION), + in_flight(false), + state(OUTSTANDING), + has_crypto_handshake(false), + has_ack_frequency(false), + ecn_codepoint(ECN_NOT_ECT) {} + +QuicTransmissionInfo::QuicTransmissionInfo( + EncryptionLevel level, TransmissionType transmission_type, + QuicTime sent_time, QuicPacketLength bytes_sent, bool has_crypto_handshake, + bool has_ack_frequency, QuicEcnCodepoint ecn_codepoint) + : sent_time(sent_time), + bytes_sent(bytes_sent), + encryption_level(level), + transmission_type(transmission_type), + in_flight(false), + state(OUTSTANDING), + has_crypto_handshake(has_crypto_handshake), + has_ack_frequency(has_ack_frequency), + ecn_codepoint(ecn_codepoint) {} + +QuicTransmissionInfo::QuicTransmissionInfo(const QuicTransmissionInfo& other) = + default; + +QuicTransmissionInfo::~QuicTransmissionInfo() {} + +std::string QuicTransmissionInfo::DebugString() const { + return absl::StrCat( + "{sent_time: ", sent_time.ToDebuggingValue(), + ", bytes_sent: ", bytes_sent, + ", encryption_level: ", EncryptionLevelToString(encryption_level), + ", transmission_type: ", TransmissionTypeToString(transmission_type), + ", in_flight: ", in_flight, ", state: ", state, + ", has_crypto_handshake: ", has_crypto_handshake, + ", has_ack_frequency: ", has_ack_frequency, + ", first_sent_after_loss: ", first_sent_after_loss.ToString(), + ", largest_acked: ", largest_acked.ToString(), + ", retransmittable_frames: ", QuicFramesToString(retransmittable_frames), + "}"); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_transmission_info.h b/quiche/quic/core/quic_transmission_info.h new file mode 100644 index 000000000000..2b81078dc38d --- /dev/null +++ b/quiche/quic/core/quic_transmission_info.h @@ -0,0 +1,66 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_TRANSMISSION_INFO_H_ +#define QUICHE_QUIC_CORE_QUIC_TRANSMISSION_INFO_H_ + +#include + +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_ack_listener_interface.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Stores details of a single sent packet. +struct QUIC_EXPORT_PRIVATE QuicTransmissionInfo { + // Used by STL when assigning into a map. + QuicTransmissionInfo(); + + // Constructs a Transmission with a new all_transmissions set + // containing |packet_number|. + QuicTransmissionInfo(EncryptionLevel level, + TransmissionType transmission_type, QuicTime sent_time, + QuicPacketLength bytes_sent, bool has_crypto_handshake, + bool has_ack_frequency, QuicEcnCodepoint ecn_codepoint); + + QuicTransmissionInfo(const QuicTransmissionInfo& other); + + ~QuicTransmissionInfo(); + + std::string DebugString() const; + + QuicFrames retransmittable_frames; + QuicTime sent_time; + QuicPacketLength bytes_sent; + EncryptionLevel encryption_level; + // Reason why this packet was transmitted. + TransmissionType transmission_type; + // In flight packets have not been abandoned or lost. + bool in_flight; + // State of this packet. + SentPacketState state; + // True if the packet contains stream data from the crypto stream. + bool has_crypto_handshake; + // True if the packet contains ack frequency frame. + bool has_ack_frequency; + // Records the first sent packet after this packet was detected lost. Zero if + // this packet has not been detected lost. This is used to keep lost packet + // for another RTT (for potential spurious loss detection) + QuicPacketNumber first_sent_after_loss; + // The largest_acked in the ack frame, if the packet contains an ack. + QuicPacketNumber largest_acked; + // The ECN codepoint with which this packet was sent. + QuicEcnCodepoint ecn_codepoint; +}; +// TODO(ianswett): Add static_assert when size of this struct is reduced below +// 64 bytes. +// NOTE(vlovich): Existing static_assert removed because padding differences on +// 64-bit iOS resulted in an 88-byte struct that is greater than the 84-byte +// limit on other platforms. Removing per ianswett's request. + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TRANSMISSION_INFO_H_ diff --git a/quiche/quic/core/quic_types.cc b/quiche/quic/core/quic_types.cc new file mode 100644 index 000000000000..3981256f6744 --- /dev/null +++ b/quiche/quic/core/quic_types.cc @@ -0,0 +1,465 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_types.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/common/print_elements.h" + +namespace quic { + +static_assert(sizeof(StatelessResetToken) == kStatelessResetTokenLength, + "bad size"); + +std::ostream& operator<<(std::ostream& os, const QuicConsumedData& s) { + os << "bytes_consumed: " << s.bytes_consumed + << " fin_consumed: " << s.fin_consumed; + return os; +} + +std::string PerspectiveToString(Perspective perspective) { + if (perspective == Perspective::IS_SERVER) { + return "IS_SERVER"; + } + if (perspective == Perspective::IS_CLIENT) { + return "IS_CLIENT"; + } + return absl::StrCat("Unknown(", static_cast(perspective), ")"); +} + +std::ostream& operator<<(std::ostream& os, const Perspective& perspective) { + os << PerspectiveToString(perspective); + return os; +} + +std::string ConnectionCloseSourceToString( + ConnectionCloseSource connection_close_source) { + if (connection_close_source == ConnectionCloseSource::FROM_PEER) { + return "FROM_PEER"; + } + if (connection_close_source == ConnectionCloseSource::FROM_SELF) { + return "FROM_SELF"; + } + return absl::StrCat("Unknown(", static_cast(connection_close_source), + ")"); +} + +std::ostream& operator<<(std::ostream& os, + const ConnectionCloseSource& connection_close_source) { + os << ConnectionCloseSourceToString(connection_close_source); + return os; +} + +std::string ConnectionCloseBehaviorToString( + ConnectionCloseBehavior connection_close_behavior) { + if (connection_close_behavior == ConnectionCloseBehavior::SILENT_CLOSE) { + return "SILENT_CLOSE"; + } + if (connection_close_behavior == + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET) { + return "SEND_CONNECTION_CLOSE_PACKET"; + } + return absl::StrCat("Unknown(", static_cast(connection_close_behavior), + ")"); +} + +std::ostream& operator<<( + std::ostream& os, + const ConnectionCloseBehavior& connection_close_behavior) { + os << ConnectionCloseBehaviorToString(connection_close_behavior); + return os; +} + +std::ostream& operator<<(std::ostream& os, const AckedPacket& acked_packet) { + os << "{ packet_number: " << acked_packet.packet_number + << ", bytes_acked: " << acked_packet.bytes_acked << ", receive_timestamp: " + << acked_packet.receive_timestamp.ToDebuggingValue() << "} "; + return os; +} + +std::ostream& operator<<(std::ostream& os, const LostPacket& lost_packet) { + os << "{ packet_number: " << lost_packet.packet_number + << ", bytes_lost: " << lost_packet.bytes_lost << "} "; + return os; +} + +std::string HistogramEnumString(WriteStatus enum_value) { + switch (enum_value) { + case WRITE_STATUS_OK: + return "OK"; + case WRITE_STATUS_BLOCKED: + return "BLOCKED"; + case WRITE_STATUS_BLOCKED_DATA_BUFFERED: + return "BLOCKED_DATA_BUFFERED"; + case WRITE_STATUS_ERROR: + return "ERROR"; + case WRITE_STATUS_MSG_TOO_BIG: + return "MSG_TOO_BIG"; + case WRITE_STATUS_FAILED_TO_COALESCE_PACKET: + return "WRITE_STATUS_FAILED_TO_COALESCE_PACKET"; + case WRITE_STATUS_NUM_VALUES: + return "NUM_VALUES"; + } + QUIC_DLOG(ERROR) << "Invalid WriteStatus value: " + << static_cast(enum_value); + return ""; +} + +std::ostream& operator<<(std::ostream& os, const WriteStatus& status) { + os << HistogramEnumString(status); + return os; +} + +std::ostream& operator<<(std::ostream& os, const WriteResult& s) { + os << "{ status: " << s.status; + if (s.status == WRITE_STATUS_OK) { + os << ", bytes_written: " << s.bytes_written; + } else { + os << ", error_code: " << s.error_code; + } + os << " }"; + return os; +} + +MessageResult::MessageResult(MessageStatus status, QuicMessageId message_id) + : status(status), message_id(message_id) {} + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x; + +std::string QuicFrameTypeToString(QuicFrameType t) { + switch (t) { + RETURN_STRING_LITERAL(PADDING_FRAME) + RETURN_STRING_LITERAL(RST_STREAM_FRAME) + RETURN_STRING_LITERAL(CONNECTION_CLOSE_FRAME) + RETURN_STRING_LITERAL(GOAWAY_FRAME) + RETURN_STRING_LITERAL(WINDOW_UPDATE_FRAME) + RETURN_STRING_LITERAL(BLOCKED_FRAME) + RETURN_STRING_LITERAL(STOP_WAITING_FRAME) + RETURN_STRING_LITERAL(PING_FRAME) + RETURN_STRING_LITERAL(CRYPTO_FRAME) + RETURN_STRING_LITERAL(HANDSHAKE_DONE_FRAME) + RETURN_STRING_LITERAL(STREAM_FRAME) + RETURN_STRING_LITERAL(ACK_FRAME) + RETURN_STRING_LITERAL(MTU_DISCOVERY_FRAME) + RETURN_STRING_LITERAL(NEW_CONNECTION_ID_FRAME) + RETURN_STRING_LITERAL(MAX_STREAMS_FRAME) + RETURN_STRING_LITERAL(STREAMS_BLOCKED_FRAME) + RETURN_STRING_LITERAL(PATH_RESPONSE_FRAME) + RETURN_STRING_LITERAL(PATH_CHALLENGE_FRAME) + RETURN_STRING_LITERAL(STOP_SENDING_FRAME) + RETURN_STRING_LITERAL(MESSAGE_FRAME) + RETURN_STRING_LITERAL(NEW_TOKEN_FRAME) + RETURN_STRING_LITERAL(RETIRE_CONNECTION_ID_FRAME) + RETURN_STRING_LITERAL(ACK_FREQUENCY_FRAME) + RETURN_STRING_LITERAL(NUM_FRAME_TYPES) + } + return absl::StrCat("Unknown(", static_cast(t), ")"); +} + +std::ostream& operator<<(std::ostream& os, const QuicFrameType& t) { + os << QuicFrameTypeToString(t); + return os; +} + +std::string QuicIetfFrameTypeString(QuicIetfFrameType t) { + if (IS_IETF_STREAM_FRAME(t)) { + return "IETF_STREAM"; + } + + switch (t) { + RETURN_STRING_LITERAL(IETF_PADDING); + RETURN_STRING_LITERAL(IETF_PING); + RETURN_STRING_LITERAL(IETF_ACK); + RETURN_STRING_LITERAL(IETF_ACK_ECN); + RETURN_STRING_LITERAL(IETF_RST_STREAM); + RETURN_STRING_LITERAL(IETF_STOP_SENDING); + RETURN_STRING_LITERAL(IETF_CRYPTO); + RETURN_STRING_LITERAL(IETF_NEW_TOKEN); + RETURN_STRING_LITERAL(IETF_MAX_DATA); + RETURN_STRING_LITERAL(IETF_MAX_STREAM_DATA); + RETURN_STRING_LITERAL(IETF_MAX_STREAMS_BIDIRECTIONAL); + RETURN_STRING_LITERAL(IETF_MAX_STREAMS_UNIDIRECTIONAL); + RETURN_STRING_LITERAL(IETF_DATA_BLOCKED); + RETURN_STRING_LITERAL(IETF_STREAM_DATA_BLOCKED); + RETURN_STRING_LITERAL(IETF_STREAMS_BLOCKED_BIDIRECTIONAL); + RETURN_STRING_LITERAL(IETF_STREAMS_BLOCKED_UNIDIRECTIONAL); + RETURN_STRING_LITERAL(IETF_NEW_CONNECTION_ID); + RETURN_STRING_LITERAL(IETF_RETIRE_CONNECTION_ID); + RETURN_STRING_LITERAL(IETF_PATH_CHALLENGE); + RETURN_STRING_LITERAL(IETF_PATH_RESPONSE); + RETURN_STRING_LITERAL(IETF_CONNECTION_CLOSE); + RETURN_STRING_LITERAL(IETF_APPLICATION_CLOSE); + RETURN_STRING_LITERAL(IETF_EXTENSION_MESSAGE_NO_LENGTH); + RETURN_STRING_LITERAL(IETF_EXTENSION_MESSAGE); + RETURN_STRING_LITERAL(IETF_EXTENSION_MESSAGE_NO_LENGTH_V99); + RETURN_STRING_LITERAL(IETF_EXTENSION_MESSAGE_V99); + default: + return absl::StrCat("Private value (", t, ")"); + } +} +std::ostream& operator<<(std::ostream& os, const QuicIetfFrameType& c) { + os << QuicIetfFrameTypeString(c); + return os; +} + +std::string TransmissionTypeToString(TransmissionType transmission_type) { + switch (transmission_type) { + RETURN_STRING_LITERAL(NOT_RETRANSMISSION); + RETURN_STRING_LITERAL(HANDSHAKE_RETRANSMISSION); + RETURN_STRING_LITERAL(ALL_ZERO_RTT_RETRANSMISSION); + RETURN_STRING_LITERAL(LOSS_RETRANSMISSION); + RETURN_STRING_LITERAL(PTO_RETRANSMISSION); + RETURN_STRING_LITERAL(PATH_RETRANSMISSION); + RETURN_STRING_LITERAL(ALL_INITIAL_RETRANSMISSION); + default: + // Some varz rely on this behavior for statistic collection. + if (transmission_type == LAST_TRANSMISSION_TYPE + 1) { + return "INVALID_TRANSMISSION_TYPE"; + } + return absl::StrCat("Unknown(", static_cast(transmission_type), ")"); + } +} + +std::ostream& operator<<(std::ostream& os, TransmissionType transmission_type) { + os << TransmissionTypeToString(transmission_type); + return os; +} + +std::string PacketHeaderFormatToString(PacketHeaderFormat format) { + switch (format) { + RETURN_STRING_LITERAL(IETF_QUIC_LONG_HEADER_PACKET); + RETURN_STRING_LITERAL(IETF_QUIC_SHORT_HEADER_PACKET); + RETURN_STRING_LITERAL(GOOGLE_QUIC_PACKET); + default: + return absl::StrCat("Unknown (", static_cast(format), ")"); + } +} + +std::string QuicLongHeaderTypeToString(QuicLongHeaderType type) { + switch (type) { + RETURN_STRING_LITERAL(VERSION_NEGOTIATION); + RETURN_STRING_LITERAL(INITIAL); + RETURN_STRING_LITERAL(ZERO_RTT_PROTECTED); + RETURN_STRING_LITERAL(HANDSHAKE); + RETURN_STRING_LITERAL(RETRY); + RETURN_STRING_LITERAL(INVALID_PACKET_TYPE); + default: + return absl::StrCat("Unknown (", static_cast(type), ")"); + } +} + +std::string MessageStatusToString(MessageStatus message_status) { + switch (message_status) { + RETURN_STRING_LITERAL(MESSAGE_STATUS_SUCCESS); + RETURN_STRING_LITERAL(MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED); + RETURN_STRING_LITERAL(MESSAGE_STATUS_UNSUPPORTED); + RETURN_STRING_LITERAL(MESSAGE_STATUS_BLOCKED); + RETURN_STRING_LITERAL(MESSAGE_STATUS_TOO_LARGE); + RETURN_STRING_LITERAL(MESSAGE_STATUS_INTERNAL_ERROR); + default: + return absl::StrCat("Unknown(", static_cast(message_status), ")"); + } +} + +std::string MessageResultToString(MessageResult message_result) { + if (message_result.status != MESSAGE_STATUS_SUCCESS) { + return absl::StrCat("{", MessageStatusToString(message_result.status), "}"); + } + return absl::StrCat("{MESSAGE_STATUS_SUCCESS,id=", message_result.message_id, + "}"); +} + +std::ostream& operator<<(std::ostream& os, const MessageResult& mr) { + os << MessageResultToString(mr); + return os; +} + +std::string PacketNumberSpaceToString(PacketNumberSpace packet_number_space) { + switch (packet_number_space) { + RETURN_STRING_LITERAL(INITIAL_DATA); + RETURN_STRING_LITERAL(HANDSHAKE_DATA); + RETURN_STRING_LITERAL(APPLICATION_DATA); + default: + return absl::StrCat("Unknown(", static_cast(packet_number_space), + ")"); + } +} + +std::string SerializedPacketFateToString(SerializedPacketFate fate) { + switch (fate) { + RETURN_STRING_LITERAL(DISCARD); + RETURN_STRING_LITERAL(COALESCE); + RETURN_STRING_LITERAL(BUFFER); + RETURN_STRING_LITERAL(SEND_TO_WRITER); + } + return absl::StrCat("Unknown(", static_cast(fate), ")"); +} + +std::ostream& operator<<(std::ostream& os, SerializedPacketFate fate) { + os << SerializedPacketFateToString(fate); + return os; +} + +std::string CongestionControlTypeToString(CongestionControlType cc_type) { + switch (cc_type) { + case kCubicBytes: + return "CUBIC_BYTES"; + case kRenoBytes: + return "RENO_BYTES"; + case kBBR: + return "BBR"; + case kBBRv2: + return "BBRv2"; + case kPCC: + return "PCC"; + case kGoogCC: + return "GoogCC"; + } + return absl::StrCat("Unknown(", static_cast(cc_type), ")"); +} + +std::string EncryptionLevelToString(EncryptionLevel level) { + switch (level) { + RETURN_STRING_LITERAL(ENCRYPTION_INITIAL); + RETURN_STRING_LITERAL(ENCRYPTION_HANDSHAKE); + RETURN_STRING_LITERAL(ENCRYPTION_ZERO_RTT); + RETURN_STRING_LITERAL(ENCRYPTION_FORWARD_SECURE); + default: + return absl::StrCat("Unknown(", static_cast(level), ")"); + } +} + +std::ostream& operator<<(std::ostream& os, EncryptionLevel level) { + os << EncryptionLevelToString(level); + return os; +} + +absl::string_view ClientCertModeToString(ClientCertMode mode) { +#define RETURN_REASON_LITERAL(x) \ + case ClientCertMode::x: \ + return #x + switch (mode) { + RETURN_REASON_LITERAL(kNone); + RETURN_REASON_LITERAL(kRequest); + RETURN_REASON_LITERAL(kRequire); + default: + return ""; + } +#undef RETURN_REASON_LITERAL +} + +std::ostream& operator<<(std::ostream& os, ClientCertMode mode) { + os << ClientCertModeToString(mode); + return os; +} + +std::string QuicConnectionCloseTypeString(QuicConnectionCloseType type) { + switch (type) { + RETURN_STRING_LITERAL(GOOGLE_QUIC_CONNECTION_CLOSE); + RETURN_STRING_LITERAL(IETF_QUIC_TRANSPORT_CONNECTION_CLOSE); + RETURN_STRING_LITERAL(IETF_QUIC_APPLICATION_CONNECTION_CLOSE); + default: + return absl::StrCat("Unknown(", static_cast(type), ")"); + } +} + +std::ostream& operator<<(std::ostream& os, const QuicConnectionCloseType type) { + os << QuicConnectionCloseTypeString(type); + return os; +} + +std::string AddressChangeTypeToString(AddressChangeType type) { + using IntType = typename std::underlying_type::type; + switch (type) { + RETURN_STRING_LITERAL(NO_CHANGE); + RETURN_STRING_LITERAL(PORT_CHANGE); + RETURN_STRING_LITERAL(IPV4_SUBNET_CHANGE); + RETURN_STRING_LITERAL(IPV4_TO_IPV4_CHANGE); + RETURN_STRING_LITERAL(IPV4_TO_IPV6_CHANGE); + RETURN_STRING_LITERAL(IPV6_TO_IPV4_CHANGE); + RETURN_STRING_LITERAL(IPV6_TO_IPV6_CHANGE); + default: + return absl::StrCat("Unknown(", static_cast(type), ")"); + } +} + +std::ostream& operator<<(std::ostream& os, AddressChangeType type) { + os << AddressChangeTypeToString(type); + return os; +} + +std::string KeyUpdateReasonString(KeyUpdateReason reason) { +#define RETURN_REASON_LITERAL(x) \ + case KeyUpdateReason::x: \ + return #x + switch (reason) { + RETURN_REASON_LITERAL(kInvalid); + RETURN_REASON_LITERAL(kRemote); + RETURN_REASON_LITERAL(kLocalForTests); + RETURN_REASON_LITERAL(kLocalForInteropRunner); + RETURN_REASON_LITERAL(kLocalAeadConfidentialityLimit); + RETURN_REASON_LITERAL(kLocalKeyUpdateLimitOverride); + default: + return absl::StrCat("Unknown(", static_cast(reason), ")"); + } +#undef RETURN_REASON_LITERAL +} + +std::ostream& operator<<(std::ostream& os, const KeyUpdateReason reason) { + os << KeyUpdateReasonString(reason); + return os; +} + +bool operator==(const ParsedClientHello& a, const ParsedClientHello& b) { + return a.sni == b.sni && a.uaid == b.uaid && a.alpns == b.alpns && + a.retry_token == b.retry_token && + a.resumption_attempted == b.resumption_attempted && + a.early_data_attempted == b.early_data_attempted; +} + +std::ostream& operator<<(std::ostream& os, + const ParsedClientHello& parsed_chlo) { + os << "{ sni:" << parsed_chlo.sni << ", uaid:" << parsed_chlo.uaid + << ", alpns:" << quiche::PrintElements(parsed_chlo.alpns) + << ", len(retry_token):" << parsed_chlo.retry_token.size() << " }"; + return os; +} + +QUICHE_EXPORT std::string QuicPriorityTypeToString(QuicPriorityType type) { + switch (type) { + case quic::QuicPriorityType::kHttp: + return "HTTP (RFC 9218)"; + case quic::QuicPriorityType::kWebTransport: + return "WebTransport (W3C API)"; + } + return "(unknown)"; +} +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + QuicPriorityType type) { + os << QuicPriorityTypeToString(type); + return os; +} + +std::string EcnCodepointToString(QuicEcnCodepoint ecn) { + switch (ecn) { + case ECN_NOT_ECT: + return "Not-ECT"; + case ECN_ECT0: + return "ECT(0)"; + case ECN_ECT1: + return "ECT(1)"; + case ECN_CE: + return "CE"; + } + return ""; // Handle compilation on windows for invalid enums +} + +#undef RETURN_STRING_LITERAL // undef for jumbo builds + +} // namespace quic diff --git a/quiche/quic/core/quic_types.h b/quiche/quic/core/quic_types.h new file mode 100644 index 000000000000..b2a100b3b927 --- /dev/null +++ b/quiche/quic/core/quic_types.h @@ -0,0 +1,928 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_TYPES_H_ +#define QUICHE_QUIC_CORE_QUIC_TYPES_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +using QuicPacketLength = uint16_t; +using QuicControlFrameId = uint32_t; +using QuicMessageId = uint32_t; + +// IMPORTANT: IETF QUIC defines stream IDs and stream counts as being unsigned +// 62-bit numbers. However, we have decided to only support up to 2^32-1 streams +// in order to reduce the size of data structures such as QuicStreamFrame +// and QuicTransmissionInfo, as that allows them to fit in cache lines and has +// visible perfomance impact. +using QuicStreamId = uint32_t; + +// Count of stream IDs. Used in MAX_STREAMS and STREAMS_BLOCKED frames. +using QuicStreamCount = QuicStreamId; + +using QuicByteCount = uint64_t; +using QuicPacketCount = uint64_t; +using QuicPublicResetNonceProof = uint64_t; +using QuicStreamOffset = uint64_t; +using DiversificationNonce = std::array; +using PacketTimeVector = std::vector>; + +enum : size_t { kStatelessResetTokenLength = 16 }; +using StatelessResetToken = std::array; + +// WebTransport session IDs are stream IDs. +using WebTransportSessionId = uint64_t; +// WebTransport stream reset codes are 8-bit. +using WebTransportStreamError = uint8_t; +// WebTransport session error codes are 32-bit. +using WebTransportSessionError = uint32_t; + +enum : size_t { kQuicPathFrameBufferSize = 8 }; +using QuicPathFrameBuffer = std::array; + +// The connection id sequence number specifies the order that connection +// ids must be used in. This is also the sequence number carried in +// the IETF QUIC NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames. +using QuicConnectionIdSequenceNumber = uint64_t; + +// A custom data that represents application-specific settings. +// In HTTP/3 for example, it includes the encoded SETTINGS. +using ApplicationState = std::vector; + +// A struct for functions which consume data payloads and fins. +struct QUIC_EXPORT_PRIVATE QuicConsumedData { + constexpr QuicConsumedData(size_t bytes_consumed, bool fin_consumed) + : bytes_consumed(bytes_consumed), fin_consumed(fin_consumed) {} + + // By default, gtest prints the raw bytes of an object. The bool data + // member causes this object to have padding bytes, which causes the + // default gtest object printer to read uninitialize memory. So we need + // to teach gtest how to print this object. + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( + std::ostream& os, const QuicConsumedData& s); + + // How many bytes were consumed. + size_t bytes_consumed; + + // True if an incoming fin was consumed. + bool fin_consumed; +}; + +// QuicAsyncStatus enumerates the possible results of an asynchronous +// operation. +enum QuicAsyncStatus { + QUIC_SUCCESS = 0, + QUIC_FAILURE = 1, + // QUIC_PENDING results from an operation that will occur asynchronously. When + // the operation is complete, a callback's |Run| method will be called. + QUIC_PENDING = 2, +}; + +// TODO(wtc): see if WriteStatus can be replaced by QuicAsyncStatus. +enum WriteStatus : int16_t { + WRITE_STATUS_OK, + // Write is blocked, caller needs to retry. + WRITE_STATUS_BLOCKED, + // Write is blocked but the packet data is buffered, caller should not retry. + WRITE_STATUS_BLOCKED_DATA_BUFFERED, + // To make the IsWriteError(WriteStatus) function work properly: + // - Non-errors MUST be added before WRITE_STATUS_ERROR. + // - Errors MUST be added after WRITE_STATUS_ERROR. + WRITE_STATUS_ERROR, + WRITE_STATUS_MSG_TOO_BIG, + WRITE_STATUS_FAILED_TO_COALESCE_PACKET, + WRITE_STATUS_NUM_VALUES, +}; + +std::string HistogramEnumString(WriteStatus enum_value); +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const WriteStatus& status); + +inline std::string HistogramEnumDescription(WriteStatus /*dummy*/) { + return "status"; +} + +inline bool IsWriteBlockedStatus(WriteStatus status) { + return status == WRITE_STATUS_BLOCKED || + status == WRITE_STATUS_BLOCKED_DATA_BUFFERED; +} + +inline bool IsWriteError(WriteStatus status) { + return status >= WRITE_STATUS_ERROR; +} + +// A struct used to return the result of write calls including either the number +// of bytes written or the error code, depending upon the status. +struct QUIC_EXPORT_PRIVATE WriteResult { + constexpr WriteResult(WriteStatus status, int bytes_written_or_error_code) + : status(status), bytes_written(bytes_written_or_error_code) {} + + constexpr WriteResult() : WriteResult(WRITE_STATUS_ERROR, 0) {} + + bool operator==(const WriteResult& other) const { + if (status != other.status) { + return false; + } + switch (status) { + case WRITE_STATUS_OK: + return bytes_written == other.bytes_written; + case WRITE_STATUS_BLOCKED: + case WRITE_STATUS_BLOCKED_DATA_BUFFERED: + return true; + default: + return error_code == other.error_code; + } + } + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<(std::ostream& os, + const WriteResult& s); + + WriteStatus status; + // Number of packets dropped as a result of this write. + // Only used by batch writers. Otherwise always 0. + uint16_t dropped_packets = 0; + // The delta between a packet's ideal and actual send time: + // actual_send_time = ideal_send_time + send_time_offset + // = (now + release_time_delay) + send_time_offset + // Only valid if |status| is WRITE_STATUS_OK. + QuicTime::Delta send_time_offset = QuicTime::Delta::Zero(); + // TODO(wub): In some cases, WRITE_STATUS_ERROR may set an error_code and + // WRITE_STATUS_BLOCKED_DATA_BUFFERED may set bytes_written. This may need + // some cleaning up so that perhaps both values can be set and valid. + union { + int bytes_written; // only valid when status is WRITE_STATUS_OK + int error_code; // only valid when status is WRITE_STATUS_ERROR + }; +}; + +enum TransmissionType : int8_t { + NOT_RETRANSMISSION, + FIRST_TRANSMISSION_TYPE = NOT_RETRANSMISSION, + HANDSHAKE_RETRANSMISSION, // Retransmits due to handshake timeouts. + ALL_ZERO_RTT_RETRANSMISSION, // Retransmits all packets encrypted with 0-RTT + // key. + LOSS_RETRANSMISSION, // Retransmits due to loss detection. + PTO_RETRANSMISSION, // Retransmission due to probe timeout. + PATH_RETRANSMISSION, // Retransmission proactively due to underlying + // network change. + ALL_INITIAL_RETRANSMISSION, // Retransmit all packets encrypted with INITIAL + // key. + LAST_TRANSMISSION_TYPE = ALL_INITIAL_RETRANSMISSION, +}; + +QUIC_EXPORT_PRIVATE std::string TransmissionTypeToString( + TransmissionType transmission_type); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, TransmissionType transmission_type); + +enum HasRetransmittableData : uint8_t { + NO_RETRANSMITTABLE_DATA, + HAS_RETRANSMITTABLE_DATA, +}; + +enum IsHandshake : uint8_t { NOT_HANDSHAKE, IS_HANDSHAKE }; + +enum class Perspective : uint8_t { IS_SERVER, IS_CLIENT }; + +QUIC_EXPORT_PRIVATE std::string PerspectiveToString(Perspective perspective); +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const Perspective& perspective); + +// Describes whether a ConnectionClose was originated by the peer. +enum class ConnectionCloseSource { FROM_PEER, FROM_SELF }; + +QUIC_EXPORT_PRIVATE std::string ConnectionCloseSourceToString( + ConnectionCloseSource connection_close_source); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ConnectionCloseSource& connection_close_source); + +// Should a connection be closed silently or not. +enum class ConnectionCloseBehavior { + SILENT_CLOSE, + SILENT_CLOSE_WITH_CONNECTION_CLOSE_PACKET_SERIALIZED, + SEND_CONNECTION_CLOSE_PACKET +}; + +QUIC_EXPORT_PRIVATE std::string ConnectionCloseBehaviorToString( + ConnectionCloseBehavior connection_close_behavior); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ConnectionCloseBehavior& connection_close_behavior); + +enum QuicFrameType : uint8_t { + // Regular frame types. The values set here cannot change without the + // introduction of a new QUIC version. + PADDING_FRAME = 0, + RST_STREAM_FRAME = 1, + CONNECTION_CLOSE_FRAME = 2, + GOAWAY_FRAME = 3, + WINDOW_UPDATE_FRAME = 4, + BLOCKED_FRAME = 5, + STOP_WAITING_FRAME = 6, + PING_FRAME = 7, + CRYPTO_FRAME = 8, + // TODO(b/157935330): stop hard coding this when deprecate T050. + HANDSHAKE_DONE_FRAME = 9, + + // STREAM and ACK frames are special frames. They are encoded differently on + // the wire and their values do not need to be stable. + STREAM_FRAME, + ACK_FRAME, + // The path MTU discovery frame is encoded as a PING frame on the wire. + MTU_DISCOVERY_FRAME, + + // These are for IETF-specific frames for which there is no mapping + // from Google QUIC frames. These are valid/allowed if and only if IETF- + // QUIC has been negotiated. Values are not important, they are not + // the values that are in the packets (see QuicIetfFrameType, below). + NEW_CONNECTION_ID_FRAME, + MAX_STREAMS_FRAME, + STREAMS_BLOCKED_FRAME, + PATH_RESPONSE_FRAME, + PATH_CHALLENGE_FRAME, + STOP_SENDING_FRAME, + MESSAGE_FRAME, + NEW_TOKEN_FRAME, + RETIRE_CONNECTION_ID_FRAME, + ACK_FREQUENCY_FRAME, + + NUM_FRAME_TYPES +}; + +// Human-readable string suitable for logging. +QUIC_EXPORT_PRIVATE std::string QuicFrameTypeToString(QuicFrameType t); +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const QuicFrameType& t); + +// Ietf frame types. These are defined in the IETF QUIC Specification. +// Explicit values are given in the enum so that we can be sure that +// the symbol will map to the correct stream type. +// All types are defined here, even if we have not yet implmented the +// quic/core/stream/.... stuff needed. +// Note: The protocol specifies that frame types are varint-62 encoded, +// further stating that the shortest encoding must be used. The current set of +// frame types all have values less than 0x40 (64) so can be encoded in a single +// byte, with the two most significant bits being 0. Thus, the following +// enumerations are valid as both the numeric values of frame types AND their +// encodings. +enum QuicIetfFrameType : uint64_t { + IETF_PADDING = 0x00, + IETF_PING = 0x01, + IETF_ACK = 0x02, + IETF_ACK_ECN = 0x03, + IETF_RST_STREAM = 0x04, + IETF_STOP_SENDING = 0x05, + IETF_CRYPTO = 0x06, + IETF_NEW_TOKEN = 0x07, + // the low-3 bits of the stream frame type value are actually flags + // declaring what parts of the frame are/are-not present, as well as + // some other control information. The code would then do something + // along the lines of "if ((frame_type & 0xf8) == 0x08)" to determine + // whether the frame is a stream frame or not, and then examine each + // bit specifically when/as needed. + IETF_STREAM = 0x08, + // 0x09 through 0x0f are various flag settings of the IETF_STREAM frame. + IETF_MAX_DATA = 0x10, + IETF_MAX_STREAM_DATA = 0x11, + IETF_MAX_STREAMS_BIDIRECTIONAL = 0x12, + IETF_MAX_STREAMS_UNIDIRECTIONAL = 0x13, + IETF_DATA_BLOCKED = 0x14, + IETF_STREAM_DATA_BLOCKED = 0x15, + IETF_STREAMS_BLOCKED_BIDIRECTIONAL = 0x16, + IETF_STREAMS_BLOCKED_UNIDIRECTIONAL = 0x17, + IETF_NEW_CONNECTION_ID = 0x18, + IETF_RETIRE_CONNECTION_ID = 0x19, + IETF_PATH_CHALLENGE = 0x1a, + IETF_PATH_RESPONSE = 0x1b, + // Both of the following are "Connection Close" frames, + // the first signals transport-layer errors, the second application-layer + // errors. + IETF_CONNECTION_CLOSE = 0x1c, + IETF_APPLICATION_CLOSE = 0x1d, + + IETF_HANDSHAKE_DONE = 0x1e, + + // The MESSAGE frame type has not yet been fully standardized. + // QUIC versions starting with 46 and before 99 use 0x20-0x21. + // IETF QUIC (v99) uses 0x30-0x31, see draft-pauly-quic-datagram. + IETF_EXTENSION_MESSAGE_NO_LENGTH = 0x20, + IETF_EXTENSION_MESSAGE = 0x21, + IETF_EXTENSION_MESSAGE_NO_LENGTH_V99 = 0x30, + IETF_EXTENSION_MESSAGE_V99 = 0x31, + + // An QUIC extension frame for sender control of acknowledgement delays + IETF_ACK_FREQUENCY = 0xaf, + + // A QUIC extension frame which augments the IETF_ACK frame definition with + // packet receive timestamps. + // TODO(ianswett): Determine a proper value to replace this temporary value. + IETF_ACK_RECEIVE_TIMESTAMPS = 0x22, +}; +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const QuicIetfFrameType& c); +QUIC_EXPORT_PRIVATE std::string QuicIetfFrameTypeString(QuicIetfFrameType t); + +// Masks for the bits that indicate the frame is a Stream frame vs the +// bits used as flags. +#define IETF_STREAM_FRAME_TYPE_MASK 0xfffffffffffffff8 +#define IETF_STREAM_FRAME_FLAG_MASK 0x07 +#define IS_IETF_STREAM_FRAME(_stype_) \ + (((_stype_)&IETF_STREAM_FRAME_TYPE_MASK) == IETF_STREAM) + +// These are the values encoded in the low-order 3 bits of the +// IETF_STREAMx frame type. +#define IETF_STREAM_FRAME_FIN_BIT 0x01 +#define IETF_STREAM_FRAME_LEN_BIT 0x02 +#define IETF_STREAM_FRAME_OFF_BIT 0x04 + +enum QuicPacketNumberLength : uint8_t { + PACKET_1BYTE_PACKET_NUMBER = 1, + PACKET_2BYTE_PACKET_NUMBER = 2, + PACKET_3BYTE_PACKET_NUMBER = 3, // Used in versions 45+. + PACKET_4BYTE_PACKET_NUMBER = 4, + IETF_MAX_PACKET_NUMBER_LENGTH = 4, + // TODO(b/145819870): Remove 6 and 8 when we remove Q043 since these values + // are not representable with later versions. + PACKET_6BYTE_PACKET_NUMBER = 6, + PACKET_8BYTE_PACKET_NUMBER = 8 +}; + +// Used to indicate a QuicSequenceNumberLength using two flag bits. +enum QuicPacketNumberLengthFlags { + PACKET_FLAGS_1BYTE_PACKET = 0, // 00 + PACKET_FLAGS_2BYTE_PACKET = 1, // 01 + PACKET_FLAGS_4BYTE_PACKET = 1 << 1, // 10 + PACKET_FLAGS_8BYTE_PACKET = 1 << 1 | 1, // 11 +}; + +// The public flags are specified in one byte. +enum QuicPacketPublicFlags { + PACKET_PUBLIC_FLAGS_NONE = 0, + + // Bit 0: Does the packet header contains version info? + PACKET_PUBLIC_FLAGS_VERSION = 1 << 0, + + // Bit 1: Is this packet a public reset packet? + PACKET_PUBLIC_FLAGS_RST = 1 << 1, + + // Bit 2: indicates the header includes a nonce. + PACKET_PUBLIC_FLAGS_NONCE = 1 << 2, + + // Bit 3: indicates whether a ConnectionID is included. + PACKET_PUBLIC_FLAGS_0BYTE_CONNECTION_ID = 0, + PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID = 1 << 3, + + // Deprecated version 32 and earlier used two bits to indicate an 8-byte + // connection ID. We send this from the client because of some broken + // middleboxes that are still checking this bit. + PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID_OLD = 1 << 3 | 1 << 2, + + // Bits 4 and 5 describe the packet number length as follows: + // --00----: 1 byte + // --01----: 2 bytes + // --10----: 4 bytes + // --11----: 6 bytes + PACKET_PUBLIC_FLAGS_1BYTE_PACKET = PACKET_FLAGS_1BYTE_PACKET << 4, + PACKET_PUBLIC_FLAGS_2BYTE_PACKET = PACKET_FLAGS_2BYTE_PACKET << 4, + PACKET_PUBLIC_FLAGS_4BYTE_PACKET = PACKET_FLAGS_4BYTE_PACKET << 4, + PACKET_PUBLIC_FLAGS_6BYTE_PACKET = PACKET_FLAGS_8BYTE_PACKET << 4, + + // Reserved, unimplemented flags: + + // Bit 7: indicates the presence of a second flags byte. + PACKET_PUBLIC_FLAGS_TWO_OR_MORE_BYTES = 1 << 7, + + // All bits set (bits 6 and 7 are not currently used): 00111111 + PACKET_PUBLIC_FLAGS_MAX = (1 << 6) - 1, +}; + +// The private flags are specified in one byte. +enum QuicPacketPrivateFlags { + PACKET_PRIVATE_FLAGS_NONE = 0, + + // Bit 0: Does this packet contain an entropy bit? + PACKET_PRIVATE_FLAGS_ENTROPY = 1 << 0, + + // (bits 1-7 are not used): 00000001 + PACKET_PRIVATE_FLAGS_MAX = (1 << 1) - 1 +}; + +// Defines for all types of congestion control algorithms that can be used in +// QUIC. Note that this is separate from the congestion feedback type - +// some congestion control algorithms may use the same feedback type +// (Reno and Cubic are the classic example for that). +enum CongestionControlType { + kCubicBytes, + kRenoBytes, + kBBR, + kPCC, + kGoogCC, + kBBRv2, +}; + +QUIC_EXPORT_PRIVATE std::string CongestionControlTypeToString( + CongestionControlType cc_type); + +// EncryptionLevel enumerates the stages of encryption that a QUIC connection +// progresses through. When retransmitting a packet, the encryption level needs +// to be specified so that it is retransmitted at a level which the peer can +// understand. +enum EncryptionLevel : int8_t { + ENCRYPTION_INITIAL = 0, + ENCRYPTION_HANDSHAKE = 1, + ENCRYPTION_ZERO_RTT = 2, + ENCRYPTION_FORWARD_SECURE = 3, + + NUM_ENCRYPTION_LEVELS, +}; + +inline bool EncryptionLevelIsValid(EncryptionLevel level) { + return ENCRYPTION_INITIAL <= level && level < NUM_ENCRYPTION_LEVELS; +} + +QUIC_EXPORT_PRIVATE std::string EncryptionLevelToString(EncryptionLevel level); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + EncryptionLevel level); + +// Enumeration of whether a server endpoint will request a client certificate, +// and whether that endpoint requires a valid client certificate to establish a +// connection. +enum class ClientCertMode : uint8_t { + kNone, // Do not request a client certificate. Default server behavior. + kRequest, // Request a certificate, but allow unauthenticated connections. + kRequire, // Require clients to provide a valid certificate. +}; + +QUIC_EXPORT_PRIVATE absl::string_view ClientCertModeToString( + ClientCertMode mode); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + ClientCertMode mode); + +enum AddressChangeType : uint8_t { + // IP address and port remain unchanged. + NO_CHANGE, + // Port changed, but IP address remains unchanged. + PORT_CHANGE, + // IPv4 address changed, but within the /24 subnet (port may have changed.) + IPV4_SUBNET_CHANGE, + // IPv4 address changed, excluding /24 subnet change (port may have changed.) + IPV4_TO_IPV4_CHANGE, + // IP address change from an IPv4 to an IPv6 address (port may have changed.) + IPV4_TO_IPV6_CHANGE, + // IP address change from an IPv6 to an IPv4 address (port may have changed.) + IPV6_TO_IPV4_CHANGE, + // IP address change from an IPv6 to an IPv6 address (port may have changed.) + IPV6_TO_IPV6_CHANGE, +}; + +QUIC_EXPORT_PRIVATE std::string AddressChangeTypeToString( + AddressChangeType type); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + AddressChangeType type); + +enum StreamSendingState { + // Sender has more data to send on this stream. + NO_FIN, + // Sender is done sending on this stream. + FIN, + // Sender is done sending on this stream and random padding needs to be + // appended after all stream frames. + FIN_AND_PADDING, +}; + +enum SentPacketState : uint8_t { + // The packet is in flight and waiting to be acked. + OUTSTANDING, + FIRST_PACKET_STATE = OUTSTANDING, + // The packet was never sent. + NEVER_SENT, + // The packet has been acked. + ACKED, + // This packet is not expected to be acked. + UNACKABLE, + // This packet has been delivered or unneeded. + NEUTERED, + + // States below are corresponding to retransmission types in TransmissionType. + + // This packet has been retransmitted when retransmission timer fires in + // HANDSHAKE mode. + HANDSHAKE_RETRANSMITTED, + // This packet is considered as lost, this is used for LOST_RETRANSMISSION. + LOST, + // This packet has been retransmitted when PTO fires. + PTO_RETRANSMITTED, + // This packet is sent on a different path or is a PING only packet. + // Do not update RTT stats and congestion control if the packet is the + // largest_acked of an incoming ACK. + NOT_CONTRIBUTING_RTT, + LAST_PACKET_STATE = NOT_CONTRIBUTING_RTT, +}; + +enum PacketHeaderFormat : uint8_t { + IETF_QUIC_LONG_HEADER_PACKET, + IETF_QUIC_SHORT_HEADER_PACKET, + GOOGLE_QUIC_PACKET, +}; + +QUIC_EXPORT_PRIVATE std::string PacketHeaderFormatToString( + PacketHeaderFormat format); + +// Information about a newly acknowledged packet. +struct QUIC_EXPORT_PRIVATE AckedPacket { + constexpr AckedPacket(QuicPacketNumber packet_number, + QuicPacketLength bytes_acked, + QuicTime receive_timestamp) + : packet_number(packet_number), + bytes_acked(bytes_acked), + receive_timestamp(receive_timestamp) {} + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const AckedPacket& acked_packet); + + QuicPacketNumber packet_number; + // Number of bytes sent in the packet that was acknowledged. + QuicPacketLength bytes_acked; + // The time |packet_number| was received by the peer, according to the + // optional timestamp the peer included in the ACK frame which acknowledged + // |packet_number|. Zero if no timestamp was available for this packet. + QuicTime receive_timestamp; +}; + +// A vector of acked packets. +using AckedPacketVector = absl::InlinedVector; + +// Information about a newly lost packet. +struct QUIC_EXPORT_PRIVATE LostPacket { + LostPacket(QuicPacketNumber packet_number, QuicPacketLength bytes_lost) + : packet_number(packet_number), bytes_lost(bytes_lost) {} + + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const LostPacket& lost_packet); + + QuicPacketNumber packet_number; + // Number of bytes sent in the packet that was lost. + QuicPacketLength bytes_lost; +}; + +// A vector of lost packets. +using LostPacketVector = absl::InlinedVector; + +// Please note, this value cannot used directly for packet serialization. +enum QuicLongHeaderType : uint8_t { + VERSION_NEGOTIATION, + INITIAL, + ZERO_RTT_PROTECTED, + HANDSHAKE, + RETRY, + + INVALID_PACKET_TYPE, +}; + +QUIC_EXPORT_PRIVATE std::string QuicLongHeaderTypeToString( + QuicLongHeaderType type); + +enum QuicPacketHeaderTypeFlags : uint8_t { + // Bit 2: Key phase bit for IETF QUIC short header packets. + FLAGS_KEY_PHASE_BIT = 1 << 2, + // Bit 3: Google QUIC Demultiplexing bit, the short header always sets this + // bit to 0, allowing to distinguish Google QUIC packets from short header + // packets. + FLAGS_DEMULTIPLEXING_BIT = 1 << 3, + // Bits 4 and 5: Reserved bits for short header. + FLAGS_SHORT_HEADER_RESERVED_1 = 1 << 4, + FLAGS_SHORT_HEADER_RESERVED_2 = 1 << 5, + // Bit 6: the 'QUIC' bit. + FLAGS_FIXED_BIT = 1 << 6, + // Bit 7: Indicates the header is long or short header. + FLAGS_LONG_HEADER = 1 << 7, +}; + +enum MessageStatus { + MESSAGE_STATUS_SUCCESS, + MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED, // Failed to send message because + // encryption is not established + // yet. + MESSAGE_STATUS_UNSUPPORTED, // Failed to send message because MESSAGE frame + // is not supported by the connection. + MESSAGE_STATUS_BLOCKED, // Failed to send message because connection is + // congestion control blocked or underlying socket is + // write blocked. + MESSAGE_STATUS_TOO_LARGE, // Failed to send message because the message is + // too large to fit into a single packet. + MESSAGE_STATUS_INTERNAL_ERROR, // Failed to send message because connection + // reaches an invalid state. +}; + +QUIC_EXPORT_PRIVATE std::string MessageStatusToString( + MessageStatus message_status); + +// Used to return the result of SendMessage calls +struct QUIC_EXPORT_PRIVATE MessageResult { + MessageResult(MessageStatus status, QuicMessageId message_id); + + bool operator==(const MessageResult& other) const { + return status == other.status && message_id == other.message_id; + } + + QUIC_EXPORT_PRIVATE friend std::ostream& operator<<(std::ostream& os, + const MessageResult& mr); + + MessageStatus status; + // Only valid when status is MESSAGE_STATUS_SUCCESS. + QuicMessageId message_id; +}; + +QUIC_EXPORT_PRIVATE std::string MessageResultToString( + MessageResult message_result); + +enum WriteStreamDataResult { + WRITE_SUCCESS, + STREAM_MISSING, // Trying to write data of a nonexistent stream (e.g. + // closed). + WRITE_FAILED, // Trying to write nonexistent data of a stream +}; + +enum StreamType : uint8_t { + // Bidirectional streams allow for data to be sent in both directions. + BIDIRECTIONAL, + + // Unidirectional streams carry data in one direction only. + WRITE_UNIDIRECTIONAL, + READ_UNIDIRECTIONAL, + // Not actually a stream type. Used only by QuicCryptoStream when it uses + // CRYPTO frames and isn't actually a QuicStream. + CRYPTO, +}; + +// A packet number space is the context in which a packet can be processed and +// acknowledged. +enum PacketNumberSpace : uint8_t { + INITIAL_DATA = 0, // Only used in IETF QUIC. + HANDSHAKE_DATA = 1, + APPLICATION_DATA = 2, + + NUM_PACKET_NUMBER_SPACES, +}; + +QUIC_EXPORT_PRIVATE std::string PacketNumberSpaceToString( + PacketNumberSpace packet_number_space); + +// Used to return the result of processing a received ACK frame. +enum AckResult { + PACKETS_NEWLY_ACKED, + NO_PACKETS_NEWLY_ACKED, + UNSENT_PACKETS_ACKED, // Peer acks unsent packets. + UNACKABLE_PACKETS_ACKED, // Peer acks packets that are not expected to be + // acked. For example, encryption is reestablished, + // and all sent encrypted packets cannot be + // decrypted by the peer. Version gets negotiated, + // and all sent packets in the different version + // cannot be processed by the peer. + PACKETS_ACKED_IN_WRONG_PACKET_NUMBER_SPACE, +}; + +// Indicates the fate of a serialized packet in WritePacket(). +enum SerializedPacketFate : uint8_t { + DISCARD, // Discard the packet. + COALESCE, // Try to coalesce packet. + BUFFER, // Buffer packet in buffered_packets_. + SEND_TO_WRITER, // Send packet to writer. +}; + +QUIC_EXPORT_PRIVATE std::string SerializedPacketFateToString( + SerializedPacketFate fate); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const SerializedPacketFate fate); + +// There are three different forms of CONNECTION_CLOSE. +enum QuicConnectionCloseType { + GOOGLE_QUIC_CONNECTION_CLOSE = 0, + IETF_QUIC_TRANSPORT_CONNECTION_CLOSE = 1, + IETF_QUIC_APPLICATION_CONNECTION_CLOSE = 2 +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicConnectionCloseType type); + +QUIC_EXPORT_PRIVATE std::string QuicConnectionCloseTypeString( + QuicConnectionCloseType type); + +// Indicate handshake state of a connection. +enum HandshakeState { + // Initial state. + HANDSHAKE_START, + // Only used in IETF QUIC with TLS handshake. State proceeds to + // HANDSHAKE_PROCESSED after a packet of HANDSHAKE packet number space + // gets successfully processed, and the initial key can be dropped. + HANDSHAKE_PROCESSED, + // In QUIC crypto, state proceeds to HANDSHAKE_COMPLETE if client receives + // SHLO or server successfully processes an ENCRYPTION_FORWARD_SECURE + // packet, such that the handshake packets can be neutered. In IETF QUIC + // with TLS handshake, state proceeds to HANDSHAKE_COMPLETE once the client + // has both 1-RTT send and receive keys. + HANDSHAKE_COMPLETE, + // Only used in IETF QUIC with TLS handshake. State proceeds to + // HANDSHAKE_CONFIRMED if 1) a client receives HANDSHAKE_DONE frame or + // acknowledgment for 1-RTT packet or 2) server has + // 1-RTT send and receive keys. + HANDSHAKE_CONFIRMED, +}; + +struct QUIC_NO_EXPORT NextReleaseTimeResult { + // The ideal release time of the packet being sent. + QuicTime release_time; + // Whether it is allowed to send the packet before release_time. + bool allow_burst; +}; + +// QuicPacketBuffer bundles a buffer and a function that releases it. Note +// it does not assume ownership of buffer, i.e. it doesn't release the buffer on +// destruction. +struct QUIC_NO_EXPORT QuicPacketBuffer { + QuicPacketBuffer() = default; + + QuicPacketBuffer(char* buffer, + std::function release_buffer) + : buffer(buffer), release_buffer(std::move(release_buffer)) {} + + char* buffer = nullptr; + std::function release_buffer; +}; + +// QuicOwnedPacketBuffer is a QuicPacketBuffer that assumes buffer ownership. +struct QUIC_NO_EXPORT QuicOwnedPacketBuffer : public QuicPacketBuffer { + QuicOwnedPacketBuffer(const QuicOwnedPacketBuffer&) = delete; + QuicOwnedPacketBuffer& operator=(const QuicOwnedPacketBuffer&) = delete; + + QuicOwnedPacketBuffer(char* buffer, + std::function release_buffer) + : QuicPacketBuffer(buffer, std::move(release_buffer)) {} + + QuicOwnedPacketBuffer(QuicOwnedPacketBuffer&& owned_buffer) + : QuicPacketBuffer(std::move(owned_buffer)) { + // |owned_buffer| does not own a buffer any more. + owned_buffer.buffer = nullptr; + } + + explicit QuicOwnedPacketBuffer(QuicPacketBuffer&& packet_buffer) + : QuicPacketBuffer(std::move(packet_buffer)) {} + + ~QuicOwnedPacketBuffer() { + if (release_buffer != nullptr && buffer != nullptr) { + release_buffer(buffer); + } + } +}; + +// These values must remain stable as they are uploaded to UMA histograms. +enum class KeyUpdateReason { + kInvalid = 0, + kRemote = 1, + kLocalForTests = 2, + kLocalForInteropRunner = 3, + kLocalAeadConfidentialityLimit = 4, + kLocalKeyUpdateLimitOverride = 5, + kMaxValue = kLocalKeyUpdateLimitOverride, +}; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const KeyUpdateReason reason); + +QUIC_EXPORT_PRIVATE std::string KeyUpdateReasonString(KeyUpdateReason reason); + +using QuicSignatureAlgorithmVector = absl::InlinedVector; + +// QuicSSLConfig contains configurations to be applied on a SSL object, which +// overrides the configurations in SSL_CTX. +struct QUIC_NO_EXPORT QuicSSLConfig { + // Whether TLS early data should be enabled. If not set, default to enabled. + absl::optional early_data_enabled; + // Whether TLS session tickets are supported. If not set, default to + // supported. + absl::optional disable_ticket_support; + // If set, used to configure the SSL object with + // SSL_set_signing_algorithm_prefs. + absl::optional signing_algorithm_prefs; + // Client certificate mode for mTLS support. Only used at server side. + ClientCertMode client_cert_mode = ClientCertMode::kNone; + // As a client, the ECHConfigList to use with ECH. If empty, ECH is not + // offered. + std::string ech_config_list; + // As a client, whether ECH GREASE is enabled. If `ech_config_list` is + // not empty, this value does nothing. + bool ech_grease_enabled = false; +}; + +// QuicDelayedSSLConfig contains a subset of SSL config that can be applied +// after BoringSSL's early select certificate callback. This overwrites all SSL +// configs applied before cert selection. +struct QUIC_NO_EXPORT QuicDelayedSSLConfig { + // Client certificate mode for mTLS support. Only used at server side. + // absl::nullopt means do not change client certificate mode. + absl::optional client_cert_mode; + // QUIC transport parameters as serialized by ProofSourceHandle. + absl::optional> quic_transport_parameters; +}; + +// ParsedClientHello contains client hello information extracted from a fully +// received client hello. +struct QUIC_NO_EXPORT ParsedClientHello { + std::string sni; // QUIC crypto and TLS. + std::string uaid; // QUIC crypto only. + std::vector alpns; // QUIC crypto and TLS. + // The unvalidated retry token from the last received packet of a potentially + // multi-packet client hello. TLS only. + std::string retry_token; + bool resumption_attempted = false; // TLS only. + bool early_data_attempted = false; // TLS only. +}; + +QUIC_EXPORT_PRIVATE bool operator==(const ParsedClientHello& a, + const ParsedClientHello& b); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ParsedClientHello& parsed_chlo); + +// The two bits in the IP header for Explicit Congestion Notification can take +// one of four values. +enum QuicEcnCodepoint { + // The NOT-ECT codepoint, indicating the packet sender is not using (or the + // network has disabled) ECN. + ECN_NOT_ECT = 0, + // The ECT(0) codepoint, indicating the packet sender is using classic ECN + // (RFC3168). + ECN_ECT0 = 1, + // The ECT(1) codepoint, indicating the packet sender is using Low Latency, + // Low Loss, Scalable Throughput (L4S) ECN (RFC9330). + ECN_ECT1 = 2, + // The CE ("Congestion Experienced") codepoint, indicating the packet sender + // is using ECN, and a router is experiencing congestion. + ECN_CE = 3, +}; + +QUICHE_EXPORT std::string EcnCodepointToString(QuicEcnCodepoint ecn); + +// This struct reports the Explicit Congestion Notification (ECN) contents of +// the ACK_ECN frame. They are the cumulative number of QUIC packets received +// for that codepoint in a given Packet Number Space. +struct QUIC_EXPORT_PRIVATE QuicEcnCounts { + QuicEcnCounts() = default; + QuicEcnCounts(QuicPacketCount ect0, QuicPacketCount ect1, QuicPacketCount ce) + : ect0(ect0), ect1(ect1), ce(ce) {} + + std::string ToString() const { + return absl::StrFormat("ECT(0): %s, ECT(1): %s, CE: %s", + std::to_string(ect0), std::to_string(ect1), + std::to_string(ce)); + } + + bool operator==(const QuicEcnCounts& other) const { + return (this->ect0 == other.ect0 && this->ect1 == other.ect1 && + this->ce == other.ce); + } + + QuicPacketCount ect0 = 0; + QuicPacketCount ect1 = 0; + QuicPacketCount ce = 0; +}; + +// Type of the priorities used by a QUIC session. +enum class QuicPriorityType : uint8_t { + // HTTP priorities as defined by RFC 9218 + kHttp, + // WebTransport priorities as defined by + kWebTransport, +}; + +QUICHE_EXPORT std::string QuicPriorityTypeToString(QuicPriorityType type); +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + QuicPriorityType type); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_TYPES_H_ diff --git a/quiche/quic/core/quic_udp_socket.h b/quiche/quic/core/quic_udp_socket.h new file mode 100644 index 000000000000..08a259578d07 --- /dev/null +++ b/quiche/quic/core/quic_udp_socket.h @@ -0,0 +1,270 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_UDP_SOCKET_H_ +#define QUICHE_QUIC_CORE_QUIC_UDP_SOCKET_H_ + +#include +#include +#include + +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +#ifndef UDP_GRO +#define UDP_GRO 104 +#endif + +namespace quic { + +using QuicUdpSocketFd = SocketFd; +inline constexpr QuicUdpSocketFd kQuicInvalidSocketFd = kInvalidSocketFd; + +inline constexpr size_t kDefaultUdpPacketControlBufferSize = 512; + +enum class QuicUdpPacketInfoBit : uint8_t { + DROPPED_PACKETS = 0, // Read + V4_SELF_IP, // Read + V6_SELF_IP, // Read + PEER_ADDRESS, // Read & Write + RECV_TIMESTAMP, // Read + TTL, // Read & Write + ECN, // Read + GOOGLE_PACKET_HEADER, // Read + NUM_BITS, + IS_GRO, // Read +}; +static_assert(static_cast(QuicUdpPacketInfoBit::NUM_BITS) <= + BitMask64::NumBits(), + "BitMask64 not wide enough to hold all bits."); + +// BufferSpan points to an unowned buffer, copying this structure only copies +// the pointer and length, not the buffer itself. +struct QUIC_EXPORT_PRIVATE BufferSpan { + BufferSpan(char* buffer, size_t buffer_len) + : buffer(buffer), buffer_len(buffer_len) {} + + BufferSpan() = default; + BufferSpan(const BufferSpan& other) = default; + BufferSpan& operator=(const BufferSpan& other) = default; + + char* buffer = nullptr; + size_t buffer_len = 0; +}; + +// QuicUdpPacketInfo contains per-packet information used for sending and +// receiving. +class QUIC_EXPORT_PRIVATE QuicUdpPacketInfo { + public: + BitMask64 bitmask() const { return bitmask_; } + + void Reset() { bitmask_.ClearAll(); } + + bool HasValue(QuicUdpPacketInfoBit bit) const { return bitmask_.IsSet(bit); } + + QuicPacketCount dropped_packets() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::DROPPED_PACKETS)); + return dropped_packets_; + } + + void SetDroppedPackets(QuicPacketCount dropped_packets) { + dropped_packets_ = dropped_packets; + bitmask_.Set(QuicUdpPacketInfoBit::DROPPED_PACKETS); + } + + void set_gso_size(size_t gso_size) { + gso_size_ = gso_size; + bitmask_.Set(QuicUdpPacketInfoBit::IS_GRO); + } + + size_t gso_size() { return gso_size_; } + + const QuicIpAddress& self_v4_ip() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::V4_SELF_IP)); + return self_v4_ip_; + } + + void SetSelfV4Ip(QuicIpAddress self_v4_ip) { + self_v4_ip_ = self_v4_ip; + bitmask_.Set(QuicUdpPacketInfoBit::V4_SELF_IP); + } + + const QuicIpAddress& self_v6_ip() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::V6_SELF_IP)); + return self_v6_ip_; + } + + void SetSelfV6Ip(QuicIpAddress self_v6_ip) { + self_v6_ip_ = self_v6_ip; + bitmask_.Set(QuicUdpPacketInfoBit::V6_SELF_IP); + } + + void SetSelfIp(QuicIpAddress self_ip) { + if (self_ip.IsIPv4()) { + SetSelfV4Ip(self_ip); + } else { + SetSelfV6Ip(self_ip); + } + } + + const QuicSocketAddress& peer_address() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)); + return peer_address_; + } + + void SetPeerAddress(QuicSocketAddress peer_address) { + peer_address_ = peer_address; + bitmask_.Set(QuicUdpPacketInfoBit::PEER_ADDRESS); + } + + QuicWallTime receive_timestamp() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::RECV_TIMESTAMP)); + return receive_timestamp_; + } + + void SetReceiveTimestamp(QuicWallTime receive_timestamp) { + receive_timestamp_ = receive_timestamp; + bitmask_.Set(QuicUdpPacketInfoBit::RECV_TIMESTAMP); + } + + int ttl() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::TTL)); + return ttl_; + } + + void SetTtl(int ttl) { + ttl_ = ttl; + bitmask_.Set(QuicUdpPacketInfoBit::TTL); + } + + BufferSpan google_packet_headers() const { + QUICHE_DCHECK(HasValue(QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER)); + return google_packet_headers_; + } + + void SetGooglePacketHeaders(BufferSpan google_packet_headers) { + google_packet_headers_ = google_packet_headers; + bitmask_.Set(QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER); + } + + QuicEcnCodepoint ecn_codepoint() const { return ecn_codepoint_; } + + void SetEcnCodepoint(const QuicEcnCodepoint ecn_codepoint) { + ecn_codepoint_ = ecn_codepoint; + bitmask_.Set(QuicUdpPacketInfoBit::ECN); + } + + private: + BitMask64 bitmask_; + QuicPacketCount dropped_packets_; + QuicIpAddress self_v4_ip_; + QuicIpAddress self_v6_ip_; + QuicSocketAddress peer_address_; + QuicWallTime receive_timestamp_ = QuicWallTime::Zero(); + int ttl_; + BufferSpan google_packet_headers_; + size_t gso_size_ = 0; + QuicEcnCodepoint ecn_codepoint_ = ECN_NOT_ECT; +}; + +// QuicUdpSocketApi provides a minimal set of apis for sending and receiving +// udp packets. The low level udp socket apis differ between kernels and kernel +// versions, the goal of QuicUdpSocketApi is to hide such differences. +// We use non-static functions because it is easier to be mocked in tests when +// needed. +class QUIC_EXPORT_PRIVATE QuicUdpSocketApi { + public: + // Creates a non-blocking udp socket, sets the receive/send buffer and enable + // receiving of self ip addresses on read. + // If address_family == AF_INET6 and ipv6_only is true, receiving of IPv4 self + // addresses is disabled. This is only necessary for IPv6 sockets on iOS - all + // other platforms can ignore this parameter. Return kQuicInvalidSocketFd if + // failed. + QuicUdpSocketFd Create(int address_family, int receive_buffer_size, + int send_buffer_size, bool ipv6_only = false); + + // Closes |fd|. No-op if |fd| equals to kQuicInvalidSocketFd. + void Destroy(QuicUdpSocketFd fd); + + // Bind |fd| to |address|. If |address|'s port number is 0, kernel will choose + // a random port to bind to. Caller can use QuicSocketAddress::FromSocket(fd) + // to get the bound random port. + bool Bind(QuicUdpSocketFd fd, QuicSocketAddress address); + + // Bind |fd| to |interface_name|. Returns true if the setsockopt call + // succeeded. Returns false if |interface_name| is empty, its length exceeds + // IFNAMSIZ, or setsockopt experienced an error. Only implemented for + // non-Android Linux. + bool BindInterface(QuicUdpSocketFd fd, const std::string& interface_name); + + // Enable receiving of various per-packet information. Return true if the + // corresponding information can be received on read. + bool EnableDroppedPacketCount(QuicUdpSocketFd fd); + bool EnableReceiveTimestamp(QuicUdpSocketFd fd); + bool EnableReceiveTtlForV4(QuicUdpSocketFd fd); + bool EnableReceiveTtlForV6(QuicUdpSocketFd fd); + + // Wait for |fd| to become readable, up to |timeout|. + // Return true if |fd| is readable upon return. + bool WaitUntilReadable(QuicUdpSocketFd fd, QuicTime::Delta timeout); + + struct QUIC_EXPORT_PRIVATE ReadPacketResult { + bool ok = false; + QuicUdpPacketInfo packet_info; + BufferSpan packet_buffer; + BufferSpan control_buffer; + + void Reset(size_t packet_buffer_length) { + ok = false; + packet_info.Reset(); + packet_buffer.buffer_len = packet_buffer_length; + } + }; + // Read a packet from |fd|: + // packet_info_interested: Bitmask indicating what information caller wants to + // receive into |result->packet_info|. + // result->packet_info: Received per packet information. + // result->packet_buffer: The packet buffer, to be filled with packet data. + // |result->packet_buffer.buffer_len| is set to the + // packet length on a successful return. + // result->control_buffer: The control buffer, used by ReadPacket internally. + // It is recommended to be + // |kDefaultUdpPacketControlBufferSize| bytes. + // result->ok: True iff a packet is successfully received. + // + // If |*result| is reused for subsequent ReadPacket() calls, caller needs to + // call result->Reset() before each ReadPacket(). + void ReadPacket(QuicUdpSocketFd fd, BitMask64 packet_info_interested, + ReadPacketResult* result); + + using ReadPacketResults = std::vector; + // Read up to |results->size()| packets from |fd|. The meaning of each element + // in |*results| has been documented on top of |ReadPacket|. + // Return the number of elements populated into |*results|, note it is + // possible for some of the populated elements to have ok=false. + size_t ReadMultiplePackets(QuicUdpSocketFd fd, + BitMask64 packet_info_interested, + ReadPacketResults* results); + + // Write a packet to |fd|. + // packet_buffer, packet_buffer_len: The packet buffer to write. + // packet_info: The per packet information to set. + WriteResult WritePacket(QuicUdpSocketFd fd, const char* packet_buffer, + size_t packet_buffer_len, + const QuicUdpPacketInfo& packet_info); + + protected: + bool SetupSocket(QuicUdpSocketFd fd, int address_family, + int receive_buffer_size, int send_buffer_size, + bool ipv6_only); + bool EnableReceiveSelfIpAddressForV4(QuicUdpSocketFd fd); + bool EnableReceiveSelfIpAddressForV6(QuicUdpSocketFd fd); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_UDP_SOCKET_H_ diff --git a/quiche/quic/core/quic_udp_socket_posix.cc b/quiche/quic/core/quic_udp_socket_posix.cc new file mode 100644 index 000000000000..26d4ef4370d8 --- /dev/null +++ b/quiche/quic/core/quic_udp_socket_posix.cc @@ -0,0 +1,711 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_types.h" +#if defined(__APPLE__) && !defined(__APPLE_USE_RFC_3542) +// This must be defined before including any system headers. +#define __APPLE_USE_RFC_3542 +#endif // defined(__APPLE__) && !defined(__APPLE_USE_RFC_3542) + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "quiche/quic/core/io/socket.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_udp_socket_platform_api.h" + +#if defined(__APPLE__) && !defined(__APPLE_USE_RFC_3542) +#error "__APPLE_USE_RFC_3542 needs to be defined." +#endif + +#if defined(__linux__) +#include +// For SO_TIMESTAMPING. +#include +#endif + +#if defined(__linux__) && !defined(__ANDROID__) +#define QUIC_UDP_SOCKET_SUPPORT_TTL 1 +#endif + +namespace quic { +namespace { + +// Explicit Congestion Notification is the last two bits of the TOS byte. +constexpr uint8_t kEcnMask = 0x03; + +#if defined(__linux__) && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) +#define QUIC_UDP_SOCKET_SUPPORT_LINUX_TIMESTAMPING 1 +// This is the structure that SO_TIMESTAMPING fills into the cmsg header. +// It is well-defined, but does not have a definition in a public header. +// See https://www.kernel.org/doc/Documentation/networking/timestamping.txt +// for more information. +struct LinuxSoTimestamping { + // The converted system time of the timestamp. + struct timespec systime; + // Deprecated; serves only as padding. + struct timespec hwtimetrans; + // The raw hardware timestamp. + struct timespec hwtimeraw; +}; +const size_t kCmsgSpaceForRecvTimestamp = + CMSG_SPACE(sizeof(LinuxSoTimestamping)); +#else +const size_t kCmsgSpaceForRecvTimestamp = 0; +#endif + +const size_t kMinCmsgSpaceForRead = + CMSG_SPACE(sizeof(uint32_t)) // Dropped packet count + + CMSG_SPACE(sizeof(in_pktinfo)) // V4 Self IP + + CMSG_SPACE(sizeof(in6_pktinfo)) // V6 Self IP + + kCmsgSpaceForRecvTimestamp + CMSG_SPACE(sizeof(int)) // TTL + + kCmsgSpaceForGooglePacketHeader; + +void SetV4SelfIpInControlMessage(const QuicIpAddress& self_address, + cmsghdr* cmsg) { + QUICHE_DCHECK(self_address.IsIPv4()); + in_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + memset(pktinfo, 0, sizeof(in_pktinfo)); + pktinfo->ipi_ifindex = 0; + std::string address_string = self_address.ToPackedString(); + memcpy(&pktinfo->ipi_spec_dst, address_string.c_str(), + address_string.length()); +} + +void SetV6SelfIpInControlMessage(const QuicIpAddress& self_address, + cmsghdr* cmsg) { + QUICHE_DCHECK(self_address.IsIPv6()); + in6_pktinfo* pktinfo = reinterpret_cast(CMSG_DATA(cmsg)); + memset(pktinfo, 0, sizeof(in6_pktinfo)); + std::string address_string = self_address.ToPackedString(); + memcpy(&pktinfo->ipi6_addr, address_string.c_str(), address_string.length()); +} + +void PopulatePacketInfoFromControlMessage(struct cmsghdr* cmsg, + QuicUdpPacketInfo* packet_info, + BitMask64 packet_info_interested) { +#ifdef SOL_UDP + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::IS_GRO) && + cmsg->cmsg_level == SOL_UDP && cmsg->cmsg_type == UDP_GRO) { + packet_info->set_gso_size(*reinterpret_cast(CMSG_DATA(cmsg))); + } +#endif + +#if defined(__linux__) && defined(SO_RXQ_OVFL) + if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SO_RXQ_OVFL) { + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::DROPPED_PACKETS)) { + packet_info->SetDroppedPackets( + *(reinterpret_cast CMSG_DATA(cmsg))); + } + return; + } +#endif + +#if defined(QUIC_UDP_SOCKET_SUPPORT_LINUX_TIMESTAMPING) + if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SO_TIMESTAMPING) { + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::RECV_TIMESTAMP)) { + LinuxSoTimestamping* linux_ts = + reinterpret_cast(CMSG_DATA(cmsg)); + timespec* ts = &linux_ts->systime; + int64_t usec = (static_cast(ts->tv_sec) * 1000 * 1000) + + (static_cast(ts->tv_nsec) / 1000); + packet_info->SetReceiveTimestamp( + QuicWallTime::FromUNIXMicroseconds(usec)); + } + return; + } +#endif + + if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::V6_SELF_IP)) { + const in6_pktinfo* info = reinterpret_cast(CMSG_DATA(cmsg)); + const char* addr_data = reinterpret_cast(&info->ipi6_addr); + int addr_len = sizeof(in6_addr); + QuicIpAddress self_v6_ip; + if (self_v6_ip.FromPackedString(addr_data, addr_len)) { + packet_info->SetSelfV6Ip(self_v6_ip); + } else { + QUIC_BUG(quic_bug_10751_1) << "QuicIpAddress::FromPackedString failed"; + } + } + return; + } + + if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::V4_SELF_IP)) { + const in_pktinfo* info = reinterpret_cast(CMSG_DATA(cmsg)); + const char* addr_data = reinterpret_cast(&info->ipi_addr); + int addr_len = sizeof(in_addr); + QuicIpAddress self_v4_ip; + if (self_v4_ip.FromPackedString(addr_data, addr_len)) { + packet_info->SetSelfV4Ip(self_v4_ip); + } else { + QUIC_BUG(quic_bug_10751_2) << "QuicIpAddress::FromPackedString failed"; + } + } + return; + } + + if ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TTL) || + (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_HOPLIMIT)) { + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::TTL)) { + packet_info->SetTtl(*(reinterpret_cast(CMSG_DATA(cmsg)))); + } + return; + } + + if ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS) || + (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_TCLASS)) { + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::ECN)) { + packet_info->SetEcnCodepoint(QuicEcnCodepoint( + *(reinterpret_cast(CMSG_DATA(cmsg))) & kEcnMask)); + } + } + + if (packet_info_interested.IsSet( + QuicUdpPacketInfoBit::GOOGLE_PACKET_HEADER)) { + BufferSpan google_packet_headers; + if (GetGooglePacketHeadersFromControlMessage( + cmsg, &google_packet_headers.buffer, + &google_packet_headers.buffer_len)) { + packet_info->SetGooglePacketHeaders(google_packet_headers); + } + } +} + +bool NextCmsg(msghdr* hdr, char* control_buffer, size_t control_buffer_len, + int cmsg_level, int cmsg_type, size_t data_size, + cmsghdr** cmsg /*in, out*/) { + // msg_controllen needs to be increased first, otherwise CMSG_NXTHDR will + // return nullptr. + hdr->msg_controllen += CMSG_SPACE(data_size); + if (hdr->msg_controllen > control_buffer_len) { + return false; + } + + if ((*cmsg) == nullptr) { + QUICHE_DCHECK_EQ(nullptr, hdr->msg_control); + memset(control_buffer, 0, control_buffer_len); + hdr->msg_control = control_buffer; + (*cmsg) = CMSG_FIRSTHDR(hdr); + } else { + QUICHE_DCHECK_NE(nullptr, hdr->msg_control); + (*cmsg) = CMSG_NXTHDR(hdr, (*cmsg)); + } + + if (nullptr == (*cmsg)) { + return false; + } + + (*cmsg)->cmsg_len = CMSG_LEN(data_size); + (*cmsg)->cmsg_level = cmsg_level; + (*cmsg)->cmsg_type = cmsg_type; + + return true; +} +} // namespace + +QuicUdpSocketFd QuicUdpSocketApi::Create(int address_family, + int receive_buffer_size, + int send_buffer_size, bool ipv6_only) { + // QUICHE_DCHECK here so the program exits early(before reading packets) in + // debug mode. This should have been a static_assert, however it can't be done + // on ios/osx because CMSG_SPACE isn't a constant expression there. + QUICHE_DCHECK_GE(kDefaultUdpPacketControlBufferSize, kMinCmsgSpaceForRead); + + absl::StatusOr socket = socket_api::CreateSocket( + quiche::FromPlatformAddressFamily(address_family), + socket_api::SocketProtocol::kUdp, + /*blocking=*/false); + + if (!socket.ok()) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "UDP non-blocking socket creation for address_family=" + << address_family << " failed: " << socket.status(); + return kQuicInvalidSocketFd; + } + + SetGoogleSocketOptions(socket.value()); + + if (!SetupSocket(socket.value(), address_family, receive_buffer_size, + send_buffer_size, ipv6_only)) { + Destroy(socket.value()); + return kQuicInvalidSocketFd; + } + + return socket.value(); +} + +bool QuicUdpSocketApi::SetupSocket(QuicUdpSocketFd fd, int address_family, + int receive_buffer_size, + int send_buffer_size, bool ipv6_only) { + // Receive buffer size. + if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &receive_buffer_size, + sizeof(receive_buffer_size)) != 0) { + QUIC_LOG_FIRST_N(ERROR, 100) << "Failed to set socket recv size"; + return false; + } + + // Send buffer size. + if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size, + sizeof(send_buffer_size)) != 0) { + QUIC_LOG_FIRST_N(ERROR, 100) << "Failed to set socket send size"; + return false; + } + + if (GetQuicRestartFlag(quic_quiche_ecn_sockets)) { + QUIC_RESTART_FLAG_COUNT(quic_quiche_ecn_sockets); + unsigned int set = 1; + if (address_family == AF_INET && + setsockopt(fd, IPPROTO_IP, IP_RECVTOS, &set, sizeof(set)) != 0) { + QUIC_LOG_FIRST_N(ERROR, 100) << "Failed to request to receive ECN on " + << "socket"; + return false; + } + if (address_family == AF_INET6 && + setsockopt(fd, IPPROTO_IPV6, IPV6_RECVTCLASS, &set, sizeof(set)) != 0) { + QUIC_LOG_FIRST_N(ERROR, 100) << "Failed to request to receive ECN on " + << "socket"; + return false; + } + } + + if (!(address_family == AF_INET6 && ipv6_only)) { + if (!EnableReceiveSelfIpAddressForV4(fd)) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Failed to enable receiving of self v4 ip"; + return false; + } + } + + if (address_family == AF_INET6) { + if (!EnableReceiveSelfIpAddressForV6(fd)) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Failed to enable receiving of self v6 ip"; + return false; + } + } + + return true; +} + +void QuicUdpSocketApi::Destroy(QuicUdpSocketFd fd) { + if (fd != kQuicInvalidSocketFd) { + absl::Status result = socket_api::Close(fd); + if (!result.ok()) { + QUIC_LOG_FIRST_N(WARNING, 100) + << "Failed to close UDP socket with error " << result; + } + } +} + +bool QuicUdpSocketApi::Bind(QuicUdpSocketFd fd, QuicSocketAddress address) { + sockaddr_storage addr = address.generic_address(); + int addr_len = + address.host().IsIPv4() ? sizeof(sockaddr_in) : sizeof(sockaddr_in6); + return 0 == bind(fd, reinterpret_cast(&addr), addr_len); +} + +bool QuicUdpSocketApi::BindInterface(QuicUdpSocketFd fd, + const std::string& interface_name) { +#if defined(__linux__) && !defined(__ANDROID_API__) + if (interface_name.empty() || interface_name.size() >= IFNAMSIZ) { + QUIC_BUG(udp_bad_interface_name) + << "interface_name must be nonempty and shorter than " << IFNAMSIZ; + return false; + } + + return 0 == setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + interface_name.c_str(), interface_name.length()); +#else + (void)fd; + (void)interface_name; + QUIC_BUG(interface_bind_not_implemented) + << "Interface binding is not implemented on this platform"; + return false; +#endif +} + +bool QuicUdpSocketApi::EnableDroppedPacketCount(QuicUdpSocketFd fd) { +#if defined(__linux__) && defined(SO_RXQ_OVFL) + int get_overflow = 1; + return 0 == setsockopt(fd, SOL_SOCKET, SO_RXQ_OVFL, &get_overflow, + sizeof(get_overflow)); +#else + (void)fd; + return false; +#endif +} + +bool QuicUdpSocketApi::EnableReceiveSelfIpAddressForV4(QuicUdpSocketFd fd) { + int get_self_ip = 1; + return 0 == setsockopt(fd, IPPROTO_IP, IP_PKTINFO, &get_self_ip, + sizeof(get_self_ip)); +} + +bool QuicUdpSocketApi::EnableReceiveSelfIpAddressForV6(QuicUdpSocketFd fd) { + int get_self_ip = 1; + return 0 == setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &get_self_ip, + sizeof(get_self_ip)); +} + +bool QuicUdpSocketApi::EnableReceiveTimestamp(QuicUdpSocketFd fd) { +#if defined(__linux__) && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) + int timestamping = SOF_TIMESTAMPING_RX_SOFTWARE | SOF_TIMESTAMPING_SOFTWARE; + return 0 == setsockopt(fd, SOL_SOCKET, SO_TIMESTAMPING, ×tamping, + sizeof(timestamping)); +#else + (void)fd; + return false; +#endif +} + +bool QuicUdpSocketApi::EnableReceiveTtlForV4(QuicUdpSocketFd fd) { +#if defined(QUIC_UDP_SOCKET_SUPPORT_TTL) + int get_ttl = 1; + return 0 == setsockopt(fd, IPPROTO_IP, IP_RECVTTL, &get_ttl, sizeof(get_ttl)); +#else + (void)fd; + return false; +#endif +} + +bool QuicUdpSocketApi::EnableReceiveTtlForV6(QuicUdpSocketFd fd) { +#if defined(QUIC_UDP_SOCKET_SUPPORT_TTL) + int get_ttl = 1; + return 0 == setsockopt(fd, IPPROTO_IPV6, IPV6_RECVHOPLIMIT, &get_ttl, + sizeof(get_ttl)); +#else + (void)fd; + return false; +#endif +} + +bool QuicUdpSocketApi::WaitUntilReadable(QuicUdpSocketFd fd, + QuicTime::Delta timeout) { + fd_set read_fds; + FD_ZERO(&read_fds); + FD_SET(fd, &read_fds); + + timeval select_timeout; + select_timeout.tv_sec = timeout.ToSeconds(); + select_timeout.tv_usec = timeout.ToMicroseconds() % 1000000; + + return 1 == select(1 + fd, &read_fds, nullptr, nullptr, &select_timeout); +} + +void QuicUdpSocketApi::ReadPacket(QuicUdpSocketFd fd, + BitMask64 packet_info_interested, + ReadPacketResult* result) { + result->ok = false; + BufferSpan& packet_buffer = result->packet_buffer; + BufferSpan& control_buffer = result->control_buffer; + QuicUdpPacketInfo* packet_info = &result->packet_info; + + QUICHE_DCHECK_GE(control_buffer.buffer_len, kMinCmsgSpaceForRead); + + struct iovec iov = {packet_buffer.buffer, packet_buffer.buffer_len}; + struct sockaddr_storage raw_peer_address; + + if (control_buffer.buffer_len > 0) { + reinterpret_cast(control_buffer.buffer)->cmsg_len = + control_buffer.buffer_len; + } + + msghdr hdr; + hdr.msg_name = &raw_peer_address; + hdr.msg_namelen = sizeof(raw_peer_address); + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + hdr.msg_flags = 0; + hdr.msg_control = control_buffer.buffer; + hdr.msg_controllen = control_buffer.buffer_len; + +#if defined(__linux__) + // If MSG_TRUNC is set on Linux, recvmsg will return the real packet size even + // if |packet_buffer| is too small to receive it. + int flags = MSG_TRUNC; +#else + int flags = 0; +#endif + + int bytes_read = recvmsg(fd, &hdr, flags); + if (bytes_read < 0) { + const int error_num = errno; + if (error_num != EAGAIN) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Error reading packet: " << strerror(error_num); + } + return; + } + + if (ABSL_PREDICT_FALSE(hdr.msg_flags & MSG_CTRUNC)) { + QUIC_BUG(quic_bug_10751_3) + << "Control buffer too small. size:" << control_buffer.buffer_len; + return; + } + + if (ABSL_PREDICT_FALSE(hdr.msg_flags & MSG_TRUNC) || + // Normally "bytes_read > packet_buffer.buffer_len" implies the MSG_TRUNC + // bit is set, but it is not the case if tested with config=android_arm64. + static_cast(bytes_read) > packet_buffer.buffer_len) { + QUIC_LOG_FIRST_N(WARNING, 100) + << "Received truncated QUIC packet: buffer size:" + << packet_buffer.buffer_len << " packet size:" << bytes_read; + return; + } + + packet_buffer.buffer_len = bytes_read; + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::PEER_ADDRESS)) { + packet_info->SetPeerAddress(QuicSocketAddress(raw_peer_address)); + } + + if (hdr.msg_controllen > 0) { + for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&hdr); cmsg != nullptr; + cmsg = CMSG_NXTHDR(&hdr, cmsg)) { + BitMask64 prior_bitmask = packet_info->bitmask(); + PopulatePacketInfoFromControlMessage(cmsg, packet_info, + packet_info_interested); + if (packet_info->bitmask() == prior_bitmask) { + QUIC_DLOG(INFO) << "Ignored cmsg_level:" << cmsg->cmsg_level + << ", cmsg_type:" << cmsg->cmsg_type; + } + } + } + + result->ok = true; +} + +size_t QuicUdpSocketApi::ReadMultiplePackets(QuicUdpSocketFd fd, + BitMask64 packet_info_interested, + ReadPacketResults* results) { +#if defined(__linux__) && !defined(__ANDROID__) + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::IS_GRO)) { + size_t num_packets = 0; + for (ReadPacketResult& result : *results) { + result.ok = false; + } + for (ReadPacketResult& result : *results) { + ReadPacket(fd, packet_info_interested, &result); + if (!result.ok) { + break; + } + ++num_packets; + } + return num_packets; + } else { + // Use recvmmsg. + size_t hdrs_size = sizeof(mmsghdr) * results->size(); + mmsghdr* hdrs = static_cast(alloca(hdrs_size)); + memset(hdrs, 0, hdrs_size); + + struct TempPerPacketData { + iovec iov; + sockaddr_storage raw_peer_address; + }; + TempPerPacketData* packet_data_array = static_cast( + alloca(sizeof(TempPerPacketData) * results->size())); + + for (size_t i = 0; i < results->size(); ++i) { + (*results)[i].ok = false; + + msghdr* hdr = &hdrs[i].msg_hdr; + TempPerPacketData* packet_data = &packet_data_array[i]; + packet_data->iov.iov_base = (*results)[i].packet_buffer.buffer; + packet_data->iov.iov_len = (*results)[i].packet_buffer.buffer_len; + + hdr->msg_name = &packet_data->raw_peer_address; + hdr->msg_namelen = sizeof(sockaddr_storage); + hdr->msg_iov = &packet_data->iov; + hdr->msg_iovlen = 1; + hdr->msg_flags = 0; + hdr->msg_control = (*results)[i].control_buffer.buffer; + hdr->msg_controllen = (*results)[i].control_buffer.buffer_len; + + QUICHE_DCHECK_GE(hdr->msg_controllen, kMinCmsgSpaceForRead); + } + // If MSG_TRUNC is set on Linux, recvmmsg will return the real packet size + // in |hdrs[i].msg_len| even if packet buffer is too small to receive it. + int packets_read = recvmmsg(fd, hdrs, results->size(), MSG_TRUNC, nullptr); + if (packets_read <= 0) { + const int error_num = errno; + if (error_num != EAGAIN) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Error reading packets: " << strerror(error_num); + } + return 0; + } + + for (int i = 0; i < packets_read; ++i) { + if (hdrs[i].msg_len == 0) { + continue; + } + + msghdr& hdr = hdrs[i].msg_hdr; + if (ABSL_PREDICT_FALSE(hdr.msg_flags & MSG_CTRUNC)) { + QUIC_BUG(quic_bug_10751_4) << "Control buffer too small. size:" + << (*results)[i].control_buffer.buffer_len + << ", need:" << hdr.msg_controllen; + continue; + } + + if (ABSL_PREDICT_FALSE(hdr.msg_flags & MSG_TRUNC)) { + QUIC_LOG_FIRST_N(WARNING, 100) + << "Received truncated QUIC packet: buffer size:" + << (*results)[i].packet_buffer.buffer_len + << " packet size:" << hdrs[i].msg_len; + continue; + } + + (*results)[i].ok = true; + (*results)[i].packet_buffer.buffer_len = hdrs[i].msg_len; + + QuicUdpPacketInfo* packet_info = &(*results)[i].packet_info; + if (packet_info_interested.IsSet(QuicUdpPacketInfoBit::PEER_ADDRESS)) { + packet_info->SetPeerAddress( + QuicSocketAddress(packet_data_array[i].raw_peer_address)); + } + + if (hdr.msg_controllen > 0) { + for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&hdr); cmsg != nullptr; + cmsg = CMSG_NXTHDR(&hdr, cmsg)) { + PopulatePacketInfoFromControlMessage(cmsg, packet_info, + packet_info_interested); + } + } + } + return packets_read; + } +#else + size_t num_packets = 0; + for (ReadPacketResult& result : *results) { + result.ok = false; + } + for (ReadPacketResult& result : *results) { + errno = 0; + ReadPacket(fd, packet_info_interested, &result); + if (!result.ok && errno == EAGAIN) { + break; + } + ++num_packets; + } + return num_packets; +#endif +} + +WriteResult QuicUdpSocketApi::WritePacket( + QuicUdpSocketFd fd, const char* packet_buffer, size_t packet_buffer_len, + const QuicUdpPacketInfo& packet_info) { + if (!packet_info.HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)) { + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + + char control_buffer[512]; + sockaddr_storage raw_peer_address = + packet_info.peer_address().generic_address(); + iovec iov = {const_cast(packet_buffer), packet_buffer_len}; + + msghdr hdr; + hdr.msg_name = &raw_peer_address; + hdr.msg_namelen = packet_info.peer_address().host().IsIPv4() + ? sizeof(sockaddr_in) + : sizeof(sockaddr_in6); + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + hdr.msg_flags = 0; + hdr.msg_control = nullptr; + hdr.msg_controllen = 0; + + cmsghdr* cmsg = nullptr; + + // Set self IP. + if (packet_info.HasValue(QuicUdpPacketInfoBit::V4_SELF_IP) && + packet_info.self_v4_ip().IsInitialized()) { + if (!NextCmsg(&hdr, control_buffer, sizeof(control_buffer), IPPROTO_IP, + IP_PKTINFO, sizeof(in_pktinfo), &cmsg)) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Not enough buffer to set self v4 ip address."; + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + SetV4SelfIpInControlMessage(packet_info.self_v4_ip(), cmsg); + } else if (packet_info.HasValue(QuicUdpPacketInfoBit::V6_SELF_IP) && + packet_info.self_v6_ip().IsInitialized()) { + if (!NextCmsg(&hdr, control_buffer, sizeof(control_buffer), IPPROTO_IPV6, + IPV6_PKTINFO, sizeof(in6_pktinfo), &cmsg)) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Not enough buffer to set self v6 ip address."; + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + SetV6SelfIpInControlMessage(packet_info.self_v6_ip(), cmsg); + } + +#if defined(QUIC_UDP_SOCKET_SUPPORT_TTL) + // Set ttl. + if (packet_info.HasValue(QuicUdpPacketInfoBit::TTL)) { + int cmsg_level = + packet_info.peer_address().host().IsIPv4() ? IPPROTO_IP : IPPROTO_IPV6; + int cmsg_type = + packet_info.peer_address().host().IsIPv4() ? IP_TTL : IPV6_HOPLIMIT; + if (!NextCmsg(&hdr, control_buffer, sizeof(control_buffer), cmsg_level, + cmsg_type, sizeof(int), &cmsg)) { + QUIC_LOG_FIRST_N(ERROR, 100) << "Not enough buffer to set ttl."; + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + *reinterpret_cast(CMSG_DATA(cmsg)) = packet_info.ttl(); + } +#endif + + // TODO(b/270584616): This code block might go away when full support for + // marking ECN is implemented. + if (packet_info.HasValue(QuicUdpPacketInfoBit::ECN)) { + int cmsg_level = + packet_info.peer_address().host().IsIPv4() ? IPPROTO_IP : IPPROTO_IPV6; + int cmsg_type; + unsigned char value_buf[20]; + socklen_t value_len = sizeof(value_buf); + if (GetQuicRestartFlag(quic_platform_tos_sockopt)) { + QUIC_RESTART_FLAG_COUNT(quic_platform_tos_sockopt); + if (GetEcnCmsgArgsPreserveDscp( + fd, packet_info.peer_address().host().address_family(), + packet_info.ecn_codepoint(), cmsg_type, value_buf, + value_len) != 0) { + QUIC_LOG_FIRST_N(ERROR, 100) + << "Could not get ECN msg type for this platform."; + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + } else { + cmsg_type = (cmsg_level == IPPROTO_IP) ? IP_TOS : IPV6_TCLASS; + *(int*)value_buf = static_cast(packet_info.ecn_codepoint()); + value_len = sizeof(int); + } + if (!NextCmsg(&hdr, control_buffer, sizeof(control_buffer), cmsg_level, + cmsg_type, value_len, &cmsg)) { + QUIC_LOG_FIRST_N(ERROR, 100) << "Not enough buffer to set ECN."; + return WriteResult(WRITE_STATUS_ERROR, EINVAL); + } + memcpy(CMSG_DATA(cmsg), value_buf, value_len); + } + + int rc; + do { + rc = sendmsg(fd, &hdr, 0); + } while (rc < 0 && errno == EINTR); + if (rc >= 0) { + return WriteResult(WRITE_STATUS_OK, rc); + } + return WriteResult((errno == EAGAIN || errno == EWOULDBLOCK) + ? WRITE_STATUS_BLOCKED + : WRITE_STATUS_ERROR, + errno); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_unacked_packet_map.cc b/quiche/quic/core/quic_unacked_packet_map.cc new file mode 100644 index 000000000000..a89cf80421c4 --- /dev/null +++ b/quiche/quic/core/quic_unacked_packet_map.cc @@ -0,0 +1,652 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_unacked_packet_map.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" + +namespace quic { + +namespace { +bool WillStreamFrameLengthSumWrapAround(QuicPacketLength lhs, + QuicPacketLength rhs) { + static_assert( + std::is_unsigned::value, + "This function assumes QuicPacketLength is an unsigned integer type."); + return std::numeric_limits::max() - lhs < rhs; +} + +enum QuicFrameTypeBitfield : uint32_t { + kInvalidFrameBitfield = 0, + kPaddingFrameBitfield = 1, + kRstStreamFrameBitfield = 1 << 1, + kConnectionCloseFrameBitfield = 1 << 2, + kGoawayFrameBitfield = 1 << 3, + kWindowUpdateFrameBitfield = 1 << 4, + kBlockedFrameBitfield = 1 << 5, + kStopWaitingFrameBitfield = 1 << 6, + kPingFrameBitfield = 1 << 7, + kCryptoFrameBitfield = 1 << 8, + kHandshakeDoneFrameBitfield = 1 << 9, + kStreamFrameBitfield = 1 << 10, + kAckFrameBitfield = 1 << 11, + kMtuDiscoveryFrameBitfield = 1 << 12, + kNewConnectionIdFrameBitfield = 1 << 13, + kMaxStreamsFrameBitfield = 1 << 14, + kStreamsBlockedFrameBitfield = 1 << 15, + kPathResponseFrameBitfield = 1 << 16, + kPathChallengeFrameBitfield = 1 << 17, + kStopSendingFrameBitfield = 1 << 18, + kMessageFrameBitfield = 1 << 19, + kNewTokenFrameBitfield = 1 << 20, + kRetireConnectionIdFrameBitfield = 1 << 21, + kAckFrequencyFrameBitfield = 1 << 22, +}; + +QuicFrameTypeBitfield GetFrameTypeBitfield(QuicFrameType type) { + switch (type) { + case PADDING_FRAME: + return kPaddingFrameBitfield; + case RST_STREAM_FRAME: + return kRstStreamFrameBitfield; + case CONNECTION_CLOSE_FRAME: + return kConnectionCloseFrameBitfield; + case GOAWAY_FRAME: + return kGoawayFrameBitfield; + case WINDOW_UPDATE_FRAME: + return kWindowUpdateFrameBitfield; + case BLOCKED_FRAME: + return kBlockedFrameBitfield; + case STOP_WAITING_FRAME: + return kStopWaitingFrameBitfield; + case PING_FRAME: + return kPingFrameBitfield; + case CRYPTO_FRAME: + return kCryptoFrameBitfield; + case HANDSHAKE_DONE_FRAME: + return kHandshakeDoneFrameBitfield; + case STREAM_FRAME: + return kStreamFrameBitfield; + case ACK_FRAME: + return kAckFrameBitfield; + case MTU_DISCOVERY_FRAME: + return kMtuDiscoveryFrameBitfield; + case NEW_CONNECTION_ID_FRAME: + return kNewConnectionIdFrameBitfield; + case MAX_STREAMS_FRAME: + return kMaxStreamsFrameBitfield; + case STREAMS_BLOCKED_FRAME: + return kStreamsBlockedFrameBitfield; + case PATH_RESPONSE_FRAME: + return kPathResponseFrameBitfield; + case PATH_CHALLENGE_FRAME: + return kPathChallengeFrameBitfield; + case STOP_SENDING_FRAME: + return kStopSendingFrameBitfield; + case MESSAGE_FRAME: + return kMessageFrameBitfield; + case NEW_TOKEN_FRAME: + return kNewTokenFrameBitfield; + case RETIRE_CONNECTION_ID_FRAME: + return kRetireConnectionIdFrameBitfield; + case ACK_FREQUENCY_FRAME: + return kAckFrequencyFrameBitfield; + case NUM_FRAME_TYPES: + QUIC_BUG(quic_bug_10518_1) << "Unexpected frame type"; + return kInvalidFrameBitfield; + } + QUIC_BUG(quic_bug_10518_2) << "Unexpected frame type"; + return kInvalidFrameBitfield; +} + +} // namespace + +QuicUnackedPacketMap::QuicUnackedPacketMap(Perspective perspective) + : perspective_(perspective), + least_unacked_(FirstSendingPacketNumber()), + bytes_in_flight_(0), + bytes_in_flight_per_packet_number_space_{0, 0, 0}, + packets_in_flight_(0), + last_inflight_packet_sent_time_(QuicTime::Zero()), + last_inflight_packets_sent_time_{ + {QuicTime::Zero()}, {QuicTime::Zero()}, {QuicTime::Zero()}}, + last_crypto_packet_sent_time_(QuicTime::Zero()), + session_notifier_(nullptr), + supports_multiple_packet_number_spaces_(false) {} + +QuicUnackedPacketMap::~QuicUnackedPacketMap() { + for (QuicTransmissionInfo& transmission_info : unacked_packets_) { + DeleteFrames(&(transmission_info.retransmittable_frames)); + } +} + +void QuicUnackedPacketMap::AddSentPacket(SerializedPacket* mutable_packet, + TransmissionType transmission_type, + QuicTime sent_time, bool set_in_flight, + bool measure_rtt, + QuicEcnCodepoint ecn_codepoint) { + const SerializedPacket& packet = *mutable_packet; + QuicPacketNumber packet_number = packet.packet_number; + QuicPacketLength bytes_sent = packet.encrypted_length; + QUIC_BUG_IF(quic_bug_12645_1, largest_sent_packet_.IsInitialized() && + largest_sent_packet_ >= packet_number) + << "largest_sent_packet_: " << largest_sent_packet_ + << ", packet_number: " << packet_number; + QUICHE_DCHECK_GE(packet_number, least_unacked_ + unacked_packets_.size()); + while (least_unacked_ + unacked_packets_.size() < packet_number) { + unacked_packets_.push_back(QuicTransmissionInfo()); + unacked_packets_.back().state = NEVER_SENT; + } + + const bool has_crypto_handshake = packet.has_crypto_handshake == IS_HANDSHAKE; + QuicTransmissionInfo info(packet.encryption_level, transmission_type, + sent_time, bytes_sent, has_crypto_handshake, + packet.has_ack_frequency, ecn_codepoint); + info.largest_acked = packet.largest_acked; + largest_sent_largest_acked_.UpdateMax(packet.largest_acked); + + if (!measure_rtt) { + QUIC_BUG_IF(quic_bug_12645_2, set_in_flight) + << "Packet " << mutable_packet->packet_number << ", transmission type " + << TransmissionTypeToString(mutable_packet->transmission_type) + << ", retransmittable frames: " + << QuicFramesToString(mutable_packet->retransmittable_frames) + << ", nonretransmittable_frames: " + << QuicFramesToString(mutable_packet->nonretransmittable_frames); + info.state = NOT_CONTRIBUTING_RTT; + } + + largest_sent_packet_ = packet_number; + if (set_in_flight) { + const PacketNumberSpace packet_number_space = + GetPacketNumberSpace(info.encryption_level); + bytes_in_flight_ += bytes_sent; + bytes_in_flight_per_packet_number_space_[packet_number_space] += bytes_sent; + ++packets_in_flight_; + info.in_flight = true; + largest_sent_retransmittable_packets_[packet_number_space] = packet_number; + last_inflight_packet_sent_time_ = sent_time; + last_inflight_packets_sent_time_[packet_number_space] = sent_time; + } + unacked_packets_.push_back(std::move(info)); + // Swap the retransmittable frames to avoid allocations. + // TODO(ianswett): Could use emplace_back when Chromium can. + if (has_crypto_handshake) { + last_crypto_packet_sent_time_ = sent_time; + } + + mutable_packet->retransmittable_frames.swap( + unacked_packets_.back().retransmittable_frames); +} + +void QuicUnackedPacketMap::RemoveObsoletePackets() { + while (!unacked_packets_.empty()) { + if (!IsPacketUseless(least_unacked_, unacked_packets_.front())) { + break; + } + DeleteFrames(&unacked_packets_.front().retransmittable_frames); + unacked_packets_.pop_front(); + ++least_unacked_; + } +} + +bool QuicUnackedPacketMap::HasRetransmittableFrames( + QuicPacketNumber packet_number) const { + QUICHE_DCHECK_GE(packet_number, least_unacked_); + QUICHE_DCHECK_LT(packet_number, least_unacked_ + unacked_packets_.size()); + return HasRetransmittableFrames( + unacked_packets_[packet_number - least_unacked_]); +} + +bool QuicUnackedPacketMap::HasRetransmittableFrames( + const QuicTransmissionInfo& info) const { + if (!QuicUtils::IsAckable(info.state)) { + return false; + } + + for (const auto& frame : info.retransmittable_frames) { + if (session_notifier_->IsFrameOutstanding(frame)) { + return true; + } + } + return false; +} + +void QuicUnackedPacketMap::RemoveRetransmittability( + QuicTransmissionInfo* info) { + DeleteFrames(&info->retransmittable_frames); + info->first_sent_after_loss.Clear(); +} + +void QuicUnackedPacketMap::RemoveRetransmittability( + QuicPacketNumber packet_number) { + QUICHE_DCHECK_GE(packet_number, least_unacked_); + QUICHE_DCHECK_LT(packet_number, least_unacked_ + unacked_packets_.size()); + QuicTransmissionInfo* info = + &unacked_packets_[packet_number - least_unacked_]; + RemoveRetransmittability(info); +} + +void QuicUnackedPacketMap::IncreaseLargestAcked( + QuicPacketNumber largest_acked) { + QUICHE_DCHECK(!largest_acked_.IsInitialized() || + largest_acked_ <= largest_acked); + largest_acked_ = largest_acked; +} + +void QuicUnackedPacketMap::MaybeUpdateLargestAckedOfPacketNumberSpace( + PacketNumberSpace packet_number_space, QuicPacketNumber packet_number) { + largest_acked_packets_[packet_number_space].UpdateMax(packet_number); +} + +bool QuicUnackedPacketMap::IsPacketUsefulForMeasuringRtt( + QuicPacketNumber packet_number, const QuicTransmissionInfo& info) const { + // Packet can be used for RTT measurement if it may yet be acked as the + // largest observed packet by the receiver. + return QuicUtils::IsAckable(info.state) && + (!largest_acked_.IsInitialized() || packet_number > largest_acked_) && + info.state != NOT_CONTRIBUTING_RTT; +} + +bool QuicUnackedPacketMap::IsPacketUsefulForCongestionControl( + const QuicTransmissionInfo& info) const { + // Packet contributes to congestion control if it is considered inflight. + return info.in_flight; +} + +bool QuicUnackedPacketMap::IsPacketUsefulForRetransmittableData( + const QuicTransmissionInfo& info) const { + // Wait for 1 RTT before giving up on the lost packet. + return info.first_sent_after_loss.IsInitialized() && + (!largest_acked_.IsInitialized() || + info.first_sent_after_loss > largest_acked_); +} + +bool QuicUnackedPacketMap::IsPacketUseless( + QuicPacketNumber packet_number, const QuicTransmissionInfo& info) const { + return !IsPacketUsefulForMeasuringRtt(packet_number, info) && + !IsPacketUsefulForCongestionControl(info) && + !IsPacketUsefulForRetransmittableData(info); +} + +bool QuicUnackedPacketMap::IsUnacked(QuicPacketNumber packet_number) const { + if (packet_number < least_unacked_ || + packet_number >= least_unacked_ + unacked_packets_.size()) { + return false; + } + return !IsPacketUseless(packet_number, + unacked_packets_[packet_number - least_unacked_]); +} + +void QuicUnackedPacketMap::RemoveFromInFlight(QuicTransmissionInfo* info) { + if (info->in_flight) { + QUIC_BUG_IF(quic_bug_12645_3, bytes_in_flight_ < info->bytes_sent); + QUIC_BUG_IF(quic_bug_12645_4, packets_in_flight_ == 0); + bytes_in_flight_ -= info->bytes_sent; + --packets_in_flight_; + + const PacketNumberSpace packet_number_space = + GetPacketNumberSpace(info->encryption_level); + if (bytes_in_flight_per_packet_number_space_[packet_number_space] < + info->bytes_sent) { + QUIC_BUG(quic_bug_10518_3) + << "bytes_in_flight: " + << bytes_in_flight_per_packet_number_space_[packet_number_space] + << " is smaller than bytes_sent: " << info->bytes_sent + << " for packet number space: " + << PacketNumberSpaceToString(packet_number_space); + bytes_in_flight_per_packet_number_space_[packet_number_space] = 0; + } else { + bytes_in_flight_per_packet_number_space_[packet_number_space] -= + info->bytes_sent; + } + if (bytes_in_flight_per_packet_number_space_[packet_number_space] == 0) { + last_inflight_packets_sent_time_[packet_number_space] = QuicTime::Zero(); + } + + info->in_flight = false; + } +} + +void QuicUnackedPacketMap::RemoveFromInFlight(QuicPacketNumber packet_number) { + QUICHE_DCHECK_GE(packet_number, least_unacked_); + QUICHE_DCHECK_LT(packet_number, least_unacked_ + unacked_packets_.size()); + QuicTransmissionInfo* info = + &unacked_packets_[packet_number - least_unacked_]; + RemoveFromInFlight(info); +} + +absl::InlinedVector +QuicUnackedPacketMap::NeuterUnencryptedPackets() { + absl::InlinedVector neutered_packets; + QuicPacketNumber packet_number = GetLeastUnacked(); + for (QuicUnackedPacketMap::iterator it = begin(); it != end(); + ++it, ++packet_number) { + if (!it->retransmittable_frames.empty() && + it->encryption_level == ENCRYPTION_INITIAL) { + QUIC_DVLOG(2) << "Neutering unencrypted packet " << packet_number; + // Once the connection swithes to forward secure, no unencrypted packets + // will be sent. The data has been abandoned in the cryto stream. Remove + // it from in flight. + RemoveFromInFlight(packet_number); + it->state = NEUTERED; + neutered_packets.push_back(packet_number); + // Notify session that the data has been delivered (but do not notify + // send algorithm). + // TODO(b/148868195): use NotifyFramesNeutered. + NotifyFramesAcked(*it, QuicTime::Delta::Zero(), QuicTime::Zero()); + QUICHE_DCHECK(!HasRetransmittableFrames(*it)); + } + } + QUICHE_DCHECK(!supports_multiple_packet_number_spaces_ || + last_inflight_packets_sent_time_[INITIAL_DATA] == + QuicTime::Zero()); + return neutered_packets; +} + +absl::InlinedVector +QuicUnackedPacketMap::NeuterHandshakePackets() { + absl::InlinedVector neutered_packets; + QuicPacketNumber packet_number = GetLeastUnacked(); + for (QuicUnackedPacketMap::iterator it = begin(); it != end(); + ++it, ++packet_number) { + if (!it->retransmittable_frames.empty() && + GetPacketNumberSpace(it->encryption_level) == HANDSHAKE_DATA) { + QUIC_DVLOG(2) << "Neutering handshake packet " << packet_number; + RemoveFromInFlight(packet_number); + // Notify session that the data has been delivered (but do not notify + // send algorithm). + it->state = NEUTERED; + neutered_packets.push_back(packet_number); + // TODO(b/148868195): use NotifyFramesNeutered. + NotifyFramesAcked(*it, QuicTime::Delta::Zero(), QuicTime::Zero()); + } + } + QUICHE_DCHECK(!supports_multiple_packet_number_spaces() || + last_inflight_packets_sent_time_[HANDSHAKE_DATA] == + QuicTime::Zero()); + return neutered_packets; +} + +bool QuicUnackedPacketMap::HasInFlightPackets() const { + return bytes_in_flight_ > 0; +} + +const QuicTransmissionInfo& QuicUnackedPacketMap::GetTransmissionInfo( + QuicPacketNumber packet_number) const { + return unacked_packets_[packet_number - least_unacked_]; +} + +QuicTransmissionInfo* QuicUnackedPacketMap::GetMutableTransmissionInfo( + QuicPacketNumber packet_number) { + return &unacked_packets_[packet_number - least_unacked_]; +} + +QuicTime QuicUnackedPacketMap::GetLastInFlightPacketSentTime() const { + return last_inflight_packet_sent_time_; +} + +QuicTime QuicUnackedPacketMap::GetLastCryptoPacketSentTime() const { + return last_crypto_packet_sent_time_; +} + +size_t QuicUnackedPacketMap::GetNumUnackedPacketsDebugOnly() const { + size_t unacked_packet_count = 0; + QuicPacketNumber packet_number = least_unacked_; + for (auto it = begin(); it != end(); ++it, ++packet_number) { + if (!IsPacketUseless(packet_number, *it)) { + ++unacked_packet_count; + } + } + return unacked_packet_count; +} + +bool QuicUnackedPacketMap::HasMultipleInFlightPackets() const { + if (bytes_in_flight_ > kDefaultTCPMSS) { + return true; + } + size_t num_in_flight = 0; + for (auto it = rbegin(); it != rend(); ++it) { + if (it->in_flight) { + ++num_in_flight; + } + if (num_in_flight > 1) { + return true; + } + } + return false; +} + +bool QuicUnackedPacketMap::HasPendingCryptoPackets() const { + return session_notifier_->HasUnackedCryptoData(); +} + +bool QuicUnackedPacketMap::HasUnackedRetransmittableFrames() const { + for (auto it = rbegin(); it != rend(); ++it) { + if (it->in_flight && HasRetransmittableFrames(*it)) { + return true; + } + } + return false; +} + +QuicPacketNumber QuicUnackedPacketMap::GetLeastUnacked() const { + return least_unacked_; +} + +void QuicUnackedPacketMap::SetSessionNotifier( + SessionNotifierInterface* session_notifier) { + session_notifier_ = session_notifier; +} + +bool QuicUnackedPacketMap::NotifyFramesAcked(const QuicTransmissionInfo& info, + QuicTime::Delta ack_delay, + QuicTime receive_timestamp) { + if (session_notifier_ == nullptr) { + return false; + } + bool new_data_acked = false; + for (const QuicFrame& frame : info.retransmittable_frames) { + if (session_notifier_->OnFrameAcked(frame, ack_delay, receive_timestamp)) { + new_data_acked = true; + } + } + return new_data_acked; +} + +void QuicUnackedPacketMap::NotifyFramesLost(const QuicTransmissionInfo& info, + TransmissionType /*type*/) { + for (const QuicFrame& frame : info.retransmittable_frames) { + session_notifier_->OnFrameLost(frame); + } +} + +bool QuicUnackedPacketMap::RetransmitFrames(const QuicFrames& frames, + TransmissionType type) { + return session_notifier_->RetransmitFrames(frames, type); +} + +void QuicUnackedPacketMap::MaybeAggregateAckedStreamFrame( + const QuicTransmissionInfo& info, QuicTime::Delta ack_delay, + QuicTime receive_timestamp) { + if (session_notifier_ == nullptr) { + return; + } + for (const auto& frame : info.retransmittable_frames) { + // Determine whether acked stream frame can be aggregated. + const bool can_aggregate = + frame.type == STREAM_FRAME && + frame.stream_frame.stream_id == aggregated_stream_frame_.stream_id && + frame.stream_frame.offset == aggregated_stream_frame_.offset + + aggregated_stream_frame_.data_length && + // We would like to increment aggregated_stream_frame_.data_length by + // frame.stream_frame.data_length, so we need to make sure their sum is + // representable by QuicPacketLength, which is the type of the former. + !WillStreamFrameLengthSumWrapAround( + aggregated_stream_frame_.data_length, + frame.stream_frame.data_length); + + if (can_aggregate) { + // Aggregate stream frame. + aggregated_stream_frame_.data_length += frame.stream_frame.data_length; + aggregated_stream_frame_.fin = frame.stream_frame.fin; + if (aggregated_stream_frame_.fin) { + // Notify session notifier aggregated stream frame gets acked if fin is + // acked. + NotifyAggregatedStreamFrameAcked(ack_delay); + } + continue; + } + + NotifyAggregatedStreamFrameAcked(ack_delay); + if (frame.type != STREAM_FRAME || frame.stream_frame.fin) { + session_notifier_->OnFrameAcked(frame, ack_delay, receive_timestamp); + continue; + } + + // Delay notifying session notifier stream frame gets acked in case it can + // be aggregated with following acked ones. + aggregated_stream_frame_.stream_id = frame.stream_frame.stream_id; + aggregated_stream_frame_.offset = frame.stream_frame.offset; + aggregated_stream_frame_.data_length = frame.stream_frame.data_length; + aggregated_stream_frame_.fin = frame.stream_frame.fin; + } +} + +void QuicUnackedPacketMap::NotifyAggregatedStreamFrameAcked( + QuicTime::Delta ack_delay) { + if (aggregated_stream_frame_.stream_id == static_cast(-1) || + session_notifier_ == nullptr) { + // Aggregated stream frame is empty. + return; + } + // Note: there is no receive_timestamp for an aggregated stream frame. The + // frames that are aggregated may not have been received at the same time. + session_notifier_->OnFrameAcked(QuicFrame(aggregated_stream_frame_), + ack_delay, + /*receive_timestamp=*/QuicTime::Zero()); + // Clear aggregated stream frame. + aggregated_stream_frame_.stream_id = -1; +} + +PacketNumberSpace QuicUnackedPacketMap::GetPacketNumberSpace( + QuicPacketNumber packet_number) const { + return GetPacketNumberSpace( + GetTransmissionInfo(packet_number).encryption_level); +} + +PacketNumberSpace QuicUnackedPacketMap::GetPacketNumberSpace( + EncryptionLevel encryption_level) const { + if (supports_multiple_packet_number_spaces_) { + return QuicUtils::GetPacketNumberSpace(encryption_level); + } + if (perspective_ == Perspective::IS_CLIENT) { + return encryption_level == ENCRYPTION_INITIAL ? HANDSHAKE_DATA + : APPLICATION_DATA; + } + return encryption_level == ENCRYPTION_FORWARD_SECURE ? APPLICATION_DATA + : HANDSHAKE_DATA; +} + +QuicPacketNumber QuicUnackedPacketMap::GetLargestAckedOfPacketNumberSpace( + PacketNumberSpace packet_number_space) const { + if (packet_number_space >= NUM_PACKET_NUMBER_SPACES) { + QUIC_BUG(quic_bug_10518_4) + << "Invalid packet number space: " << packet_number_space; + return QuicPacketNumber(); + } + return largest_acked_packets_[packet_number_space]; +} + +QuicTime QuicUnackedPacketMap::GetLastInFlightPacketSentTime( + PacketNumberSpace packet_number_space) const { + if (packet_number_space >= NUM_PACKET_NUMBER_SPACES) { + QUIC_BUG(quic_bug_10518_5) + << "Invalid packet number space: " << packet_number_space; + return QuicTime::Zero(); + } + return last_inflight_packets_sent_time_[packet_number_space]; +} + +QuicPacketNumber +QuicUnackedPacketMap::GetLargestSentRetransmittableOfPacketNumberSpace( + PacketNumberSpace packet_number_space) const { + if (packet_number_space >= NUM_PACKET_NUMBER_SPACES) { + QUIC_BUG(quic_bug_10518_6) + << "Invalid packet number space: " << packet_number_space; + return QuicPacketNumber(); + } + return largest_sent_retransmittable_packets_[packet_number_space]; +} + +const QuicTransmissionInfo* +QuicUnackedPacketMap::GetFirstInFlightTransmissionInfo() const { + QUICHE_DCHECK(HasInFlightPackets()); + for (auto it = begin(); it != end(); ++it) { + if (it->in_flight) { + return &(*it); + } + } + QUICHE_DCHECK(false); + return nullptr; +} + +const QuicTransmissionInfo* +QuicUnackedPacketMap::GetFirstInFlightTransmissionInfoOfSpace( + PacketNumberSpace packet_number_space) const { + // TODO(fayang): Optimize this part if arm 1st PTO with first in flight sent + // time works. + for (auto it = begin(); it != end(); ++it) { + if (it->in_flight && + GetPacketNumberSpace(it->encryption_level) == packet_number_space) { + return &(*it); + } + } + return nullptr; +} + +void QuicUnackedPacketMap::EnableMultiplePacketNumberSpacesSupport() { + if (supports_multiple_packet_number_spaces_) { + QUIC_BUG(quic_bug_10518_7) + << "Multiple packet number spaces has already been enabled"; + return; + } + if (largest_sent_packet_.IsInitialized()) { + QUIC_BUG(quic_bug_10518_8) + << "Try to enable multiple packet number spaces support after any " + "packet has been sent."; + return; + } + + supports_multiple_packet_number_spaces_ = true; +} + +int32_t QuicUnackedPacketMap::GetLastPacketContent() const { + if (empty()) { + // Use -1 to distinguish with packets with no retransmittable frames nor + // acks. + return -1; + } + int32_t content = 0; + const QuicTransmissionInfo& last_packet = unacked_packets_.back(); + for (const auto& frame : last_packet.retransmittable_frames) { + content |= GetFrameTypeBitfield(frame.type); + } + if (last_packet.largest_acked.IsInitialized()) { + content |= GetFrameTypeBitfield(ACK_FRAME); + } + return content; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_unacked_packet_map.h b/quiche/quic/core/quic_unacked_packet_map.h new file mode 100644 index 000000000000..143fa5f5ee62 --- /dev/null +++ b/quiche/quic/core/quic_unacked_packet_map.h @@ -0,0 +1,336 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_UNACKED_PACKET_MAP_H_ +#define QUICHE_QUIC_CORE_QUIC_UNACKED_PACKET_MAP_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_transmission_info.h" +#include "quiche/quic/core/session_notifier_interface.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { + +namespace test { +class QuicUnackedPacketMapPeer; +} // namespace test + +// Class which tracks unacked packets for three purposes: +// 1) Track retransmittable data, including multiple transmissions of frames. +// 2) Track packets and bytes in flight for congestion control. +// 3) Track sent time of packets to provide RTT measurements from acks. +class QUIC_EXPORT_PRIVATE QuicUnackedPacketMap { + public: + QuicUnackedPacketMap(Perspective perspective); + QuicUnackedPacketMap(const QuicUnackedPacketMap&) = delete; + QuicUnackedPacketMap& operator=(const QuicUnackedPacketMap&) = delete; + ~QuicUnackedPacketMap(); + + // Adds |mutable_packet| to the map and marks it as sent at |sent_time|. + // Marks the packet as in flight if |set_in_flight| is true. + // Packets marked as in flight are expected to be marked as missing when they + // don't arrive, indicating the need for retransmission. + // Any retransmittible_frames in |mutable_packet| are swapped from + // |mutable_packet| into the QuicTransmissionInfo. + void AddSentPacket(SerializedPacket* mutable_packet, + TransmissionType transmission_type, QuicTime sent_time, + bool set_in_flight, bool measure_rtt, + QuicEcnCodepoint ecn_codepoint); + + // Returns true if the packet |packet_number| is unacked. + bool IsUnacked(QuicPacketNumber packet_number) const; + + // Notifies session_notifier that frames have been acked. Returns true if any + // new data gets acked, returns false otherwise. + bool NotifyFramesAcked(const QuicTransmissionInfo& info, + QuicTime::Delta ack_delay, QuicTime receive_timestamp); + + // Notifies session_notifier that frames in |info| are considered as lost. + void NotifyFramesLost(const QuicTransmissionInfo& info, + TransmissionType type); + + // Notifies session_notifier to retransmit frames with |transmission_type|. + // Returns true if all data gets retransmitted. + bool RetransmitFrames(const QuicFrames& frames, TransmissionType type); + + // Marks |info| as no longer in flight. + void RemoveFromInFlight(QuicTransmissionInfo* info); + + // Marks |packet_number| as no longer in flight. + void RemoveFromInFlight(QuicPacketNumber packet_number); + + // Called to neuter all unencrypted packets to ensure they do not get + // retransmitted. Returns a vector of neutered packet numbers. + absl::InlinedVector NeuterUnencryptedPackets(); + + // Called to neuter packets in handshake packet number space to ensure they do + // not get retransmitted. Returns a vector of neutered packet numbers. + // TODO(fayang): Consider to combine this with NeuterUnencryptedPackets. + absl::InlinedVector NeuterHandshakePackets(); + + // Returns true if |packet_number| has retransmittable frames. This will + // return false if all frames of this packet are either non-retransmittable or + // have been acked. + bool HasRetransmittableFrames(QuicPacketNumber packet_number) const; + + // Returns true if |info| has retransmittable frames. This will return false + // if all frames of this packet are either non-retransmittable or have been + // acked. + bool HasRetransmittableFrames(const QuicTransmissionInfo& info) const; + + // Returns true if there are any unacked packets which have retransmittable + // frames. + bool HasUnackedRetransmittableFrames() const; + + // Returns true if there are no packets present in the unacked packet map. + bool empty() const { return unacked_packets_.empty(); } + + // Returns the largest packet number that has been sent. + QuicPacketNumber largest_sent_packet() const { return largest_sent_packet_; } + + QuicPacketNumber largest_sent_largest_acked() const { + return largest_sent_largest_acked_; + } + + // Returns the largest packet number that has been acked. + QuicPacketNumber largest_acked() const { return largest_acked_; } + + // Returns the sum of bytes from all packets in flight. + QuicByteCount bytes_in_flight() const { return bytes_in_flight_; } + QuicPacketCount packets_in_flight() const { return packets_in_flight_; } + + // Returns the smallest packet number of a serialized packet which has not + // been acked by the peer. If there are no unacked packets, returns 0. + QuicPacketNumber GetLeastUnacked() const; + + using const_iterator = + quiche::QuicheCircularDeque::const_iterator; + using const_reverse_iterator = + quiche::QuicheCircularDeque::const_reverse_iterator; + using iterator = quiche::QuicheCircularDeque::iterator; + + const_iterator begin() const { return unacked_packets_.begin(); } + const_iterator end() const { return unacked_packets_.end(); } + const_reverse_iterator rbegin() const { return unacked_packets_.rbegin(); } + const_reverse_iterator rend() const { return unacked_packets_.rend(); } + iterator begin() { return unacked_packets_.begin(); } + iterator end() { return unacked_packets_.end(); } + + // Returns true if there are unacked packets that are in flight. + bool HasInFlightPackets() const; + + // Returns the QuicTransmissionInfo associated with |packet_number|, which + // must be unacked. + const QuicTransmissionInfo& GetTransmissionInfo( + QuicPacketNumber packet_number) const; + + // Returns mutable QuicTransmissionInfo associated with |packet_number|, which + // must be unacked. + QuicTransmissionInfo* GetMutableTransmissionInfo( + QuicPacketNumber packet_number); + + // Returns the time that the last unacked packet was sent. + QuicTime GetLastInFlightPacketSentTime() const; + + // Returns the time that the last unacked crypto packet was sent. + QuicTime GetLastCryptoPacketSentTime() const; + + // Returns the number of unacked packets. + size_t GetNumUnackedPacketsDebugOnly() const; + + // Returns true if there are multiple packets in flight. + // TODO(fayang): Remove this method and use packets_in_flight_ instead. + bool HasMultipleInFlightPackets() const; + + // Returns true if there are any pending crypto packets. + bool HasPendingCryptoPackets() const; + + // Returns true if there is any unacked non-crypto stream data. + bool HasUnackedStreamData() const { + return session_notifier_->HasUnackedStreamData(); + } + + // Removes any retransmittable frames from this transmission or an associated + // transmission. It removes now useless transmissions, and disconnects any + // other packets from other transmissions. + void RemoveRetransmittability(QuicTransmissionInfo* info); + + // Looks up the QuicTransmissionInfo by |packet_number| and calls + // RemoveRetransmittability. + void RemoveRetransmittability(QuicPacketNumber packet_number); + + // Increases the largest acked. Any packets less or equal to + // |largest_acked| are discarded if they are only for the RTT purposes. + void IncreaseLargestAcked(QuicPacketNumber largest_acked); + + // Called when |packet_number| gets acked. Maybe increase the largest acked of + // |packet_number_space|. + void MaybeUpdateLargestAckedOfPacketNumberSpace( + PacketNumberSpace packet_number_space, QuicPacketNumber packet_number); + + // Remove any packets no longer needed for retransmission, congestion, or + // RTT measurement purposes. + void RemoveObsoletePackets(); + + // Try to aggregate acked contiguous stream frames. For noncontiguous stream + // frames or control frames, notify the session notifier they get acked + // immediately. + void MaybeAggregateAckedStreamFrame(const QuicTransmissionInfo& info, + QuicTime::Delta ack_delay, + QuicTime receive_timestamp); + + // Notify the session notifier of any stream data aggregated in + // aggregated_stream_frame_. No effect if the stream frame has an invalid + // stream id. + void NotifyAggregatedStreamFrameAcked(QuicTime::Delta ack_delay); + + // Returns packet number space that |packet_number| belongs to. Please use + // GetPacketNumberSpace(EncryptionLevel) whenever encryption level is + // available. + PacketNumberSpace GetPacketNumberSpace(QuicPacketNumber packet_number) const; + + // Returns packet number space of |encryption_level|. + PacketNumberSpace GetPacketNumberSpace( + EncryptionLevel encryption_level) const; + + // Returns largest acked packet number of |packet_number_space|. + QuicPacketNumber GetLargestAckedOfPacketNumberSpace( + PacketNumberSpace packet_number_space) const; + + // Returns largest sent retransmittable packet number of + // |packet_number_space|. + QuicPacketNumber GetLargestSentRetransmittableOfPacketNumberSpace( + PacketNumberSpace packet_number_space) const; + + // Returns largest sent packet number of |encryption_level|. + QuicPacketNumber GetLargestSentPacketOfPacketNumberSpace( + EncryptionLevel encryption_level) const; + + // Returns last in flight packet sent time of |packet_number_space|. + QuicTime GetLastInFlightPacketSentTime( + PacketNumberSpace packet_number_space) const; + + // Returns TransmissionInfo of the first in flight packet. + const QuicTransmissionInfo* GetFirstInFlightTransmissionInfo() const; + + // Returns TransmissionInfo of first in flight packet in + // |packet_number_space|. + const QuicTransmissionInfo* GetFirstInFlightTransmissionInfoOfSpace( + PacketNumberSpace packet_number_space) const; + + void SetSessionNotifier(SessionNotifierInterface* session_notifier); + + void EnableMultiplePacketNumberSpacesSupport(); + + // Returns a bitfield of retransmittable frames of last packet in + // unacked_packets_. For example, if the packet contains STREAM_FRAME, content + // & (1 << STREAM_FRAME) would be set. Returns max uint32_t if + // unacked_packets_ is empty. + int32_t GetLastPacketContent() const; + + Perspective perspective() const { return perspective_; } + + bool supports_multiple_packet_number_spaces() const { + return supports_multiple_packet_number_spaces_; + } + + void ReserveInitialCapacity(size_t initial_capacity) { + unacked_packets_.reserve(initial_capacity); + } + + std::string DebugString() const { + return absl::StrCat( + "{size: ", unacked_packets_.size(), + ", least_unacked: ", least_unacked_.ToString(), + ", largest_sent_packet: ", largest_sent_packet_.ToString(), + ", largest_acked: ", largest_acked_.ToString(), + ", bytes_in_flight: ", bytes_in_flight_, + ", packets_in_flight: ", packets_in_flight_, "}"); + } + + private: + friend class test::QuicUnackedPacketMapPeer; + + // Returns true if packet may be useful for an RTT measurement. + bool IsPacketUsefulForMeasuringRtt(QuicPacketNumber packet_number, + const QuicTransmissionInfo& info) const; + + // Returns true if packet may be useful for congestion control purposes. + bool IsPacketUsefulForCongestionControl( + const QuicTransmissionInfo& info) const; + + // Returns true if packet may be associated with retransmittable data + // directly or through retransmissions. + bool IsPacketUsefulForRetransmittableData( + const QuicTransmissionInfo& info) const; + + // Returns true if the packet no longer has a purpose in the map. + bool IsPacketUseless(QuicPacketNumber packet_number, + const QuicTransmissionInfo& info) const; + + const Perspective perspective_; + + QuicPacketNumber largest_sent_packet_; + // The largest sent packet we expect to receive an ack for per packet number + // space. + QuicPacketNumber + largest_sent_retransmittable_packets_[NUM_PACKET_NUMBER_SPACES]; + // The largest sent largest_acked in an ACK frame. + QuicPacketNumber largest_sent_largest_acked_; + // The largest received largest_acked from an ACK frame. + QuicPacketNumber largest_acked_; + // The largest received largest_acked from ACK frame per packet number space. + QuicPacketNumber largest_acked_packets_[NUM_PACKET_NUMBER_SPACES]; + + // Newly serialized retransmittable packets are added to this map, which + // contains owning pointers to any contained frames. If a packet is + // retransmitted, this map will contain entries for both the old and the new + // packet. The old packet's retransmittable frames entry will be nullptr, + // while the new packet's entry will contain the frames to retransmit. + // If the old packet is acked before the new packet, then the old entry will + // be removed from the map and the new entry's retransmittable frames will be + // set to nullptr. + quiche::QuicheCircularDeque unacked_packets_; + + // The packet at the 0th index of unacked_packets_. + QuicPacketNumber least_unacked_; + + QuicByteCount bytes_in_flight_; + // Bytes in flight per packet number space. + QuicByteCount + bytes_in_flight_per_packet_number_space_[NUM_PACKET_NUMBER_SPACES]; + QuicPacketCount packets_in_flight_; + + // Time that the last inflight packet was sent. + QuicTime last_inflight_packet_sent_time_; + // Time that the last in flight packet was sent per packet number space. + QuicTime last_inflight_packets_sent_time_[NUM_PACKET_NUMBER_SPACES]; + + // Time that the last unacked crypto packet was sent. + QuicTime last_crypto_packet_sent_time_; + + // Aggregates acked stream data across multiple acked sent packets to save CPU + // by reducing the number of calls to the session notifier. + QuicStreamFrame aggregated_stream_frame_; + + // Receives notifications of frames being retransmitted or acknowledged. + SessionNotifierInterface* session_notifier_; + + // If true, supports multiple packet number spaces. + bool supports_multiple_packet_number_spaces_; + + // Latched value of the quic_simple_inflight_time flag. + bool simple_inflight_time_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_UNACKED_PACKET_MAP_H_ diff --git a/quiche/quic/core/quic_unacked_packet_map_test.cc b/quiche/quic/core/quic_unacked_packet_map_test.cc new file mode 100644 index 000000000000..a8510db81b65 --- /dev/null +++ b/quiche/quic/core/quic_unacked_packet_map_test.cc @@ -0,0 +1,722 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_unacked_packet_map.h" + +#include +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_transmission_info.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/quic_unacked_packet_map_peer.h" + +using testing::_; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +// Default packet length. +const uint32_t kDefaultLength = 1000; + +class QuicUnackedPacketMapTest : public QuicTestWithParam { + protected: + QuicUnackedPacketMapTest() + : unacked_packets_(GetParam()), + now_(QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1000)) { + unacked_packets_.SetSessionNotifier(¬ifier_); + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(notifier_, OnStreamFrameRetransmitted(_)) + .Times(testing::AnyNumber()); + } + + ~QuicUnackedPacketMapTest() override {} + + SerializedPacket CreateRetransmittablePacket(uint64_t packet_number) { + return CreateRetransmittablePacketForStream( + packet_number, QuicUtils::GetFirstBidirectionalStreamId( + CurrentSupportedVersions()[0].transport_version, + Perspective::IS_CLIENT)); + } + + SerializedPacket CreateRetransmittablePacketForStream( + uint64_t packet_number, QuicStreamId stream_id) { + SerializedPacket packet(QuicPacketNumber(packet_number), + PACKET_1BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + false, false); + QuicStreamFrame frame; + frame.stream_id = stream_id; + packet.retransmittable_frames.push_back(QuicFrame(frame)); + return packet; + } + + SerializedPacket CreateNonRetransmittablePacket(uint64_t packet_number) { + return SerializedPacket(QuicPacketNumber(packet_number), + PACKET_1BYTE_PACKET_NUMBER, nullptr, kDefaultLength, + false, false); + } + + void VerifyInFlightPackets(uint64_t* packets, size_t num_packets) { + unacked_packets_.RemoveObsoletePackets(); + if (num_packets == 0) { + EXPECT_FALSE(unacked_packets_.HasInFlightPackets()); + EXPECT_FALSE(unacked_packets_.HasMultipleInFlightPackets()); + return; + } + if (num_packets == 1) { + EXPECT_TRUE(unacked_packets_.HasInFlightPackets()); + EXPECT_FALSE(unacked_packets_.HasMultipleInFlightPackets()); + ASSERT_TRUE(unacked_packets_.IsUnacked(QuicPacketNumber(packets[0]))); + EXPECT_TRUE( + unacked_packets_.GetTransmissionInfo(QuicPacketNumber(packets[0])) + .in_flight); + } + for (size_t i = 0; i < num_packets; ++i) { + ASSERT_TRUE(unacked_packets_.IsUnacked(QuicPacketNumber(packets[i]))); + EXPECT_TRUE( + unacked_packets_.GetTransmissionInfo(QuicPacketNumber(packets[i])) + .in_flight); + } + size_t in_flight_count = 0; + for (auto it = unacked_packets_.begin(); it != unacked_packets_.end(); + ++it) { + if (it->in_flight) { + ++in_flight_count; + } + } + EXPECT_EQ(num_packets, in_flight_count); + } + + void VerifyUnackedPackets(uint64_t* packets, size_t num_packets) { + unacked_packets_.RemoveObsoletePackets(); + if (num_packets == 0) { + EXPECT_TRUE(unacked_packets_.empty()); + EXPECT_FALSE(unacked_packets_.HasUnackedRetransmittableFrames()); + return; + } + EXPECT_FALSE(unacked_packets_.empty()); + for (size_t i = 0; i < num_packets; ++i) { + EXPECT_TRUE(unacked_packets_.IsUnacked(QuicPacketNumber(packets[i]))) + << packets[i]; + } + EXPECT_EQ(num_packets, unacked_packets_.GetNumUnackedPacketsDebugOnly()); + } + + void VerifyRetransmittablePackets(uint64_t* packets, size_t num_packets) { + unacked_packets_.RemoveObsoletePackets(); + size_t num_retransmittable_packets = 0; + for (auto it = unacked_packets_.begin(); it != unacked_packets_.end(); + ++it) { + if (unacked_packets_.HasRetransmittableFrames(*it)) { + ++num_retransmittable_packets; + } + } + EXPECT_EQ(num_packets, num_retransmittable_packets); + for (size_t i = 0; i < num_packets; ++i) { + EXPECT_TRUE(unacked_packets_.HasRetransmittableFrames( + QuicPacketNumber(packets[i]))) + << " packets[" << i << "]:" << packets[i]; + } + } + + void UpdatePacketState(uint64_t packet_number, SentPacketState state) { + unacked_packets_ + .GetMutableTransmissionInfo(QuicPacketNumber(packet_number)) + ->state = state; + } + + void RetransmitAndSendPacket(uint64_t old_packet_number, + uint64_t new_packet_number, + TransmissionType transmission_type) { + QUICHE_DCHECK(unacked_packets_.HasRetransmittableFrames( + QuicPacketNumber(old_packet_number))); + QuicTransmissionInfo* info = unacked_packets_.GetMutableTransmissionInfo( + QuicPacketNumber(old_packet_number)); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + CurrentSupportedVersions()[0].transport_version, + Perspective::IS_CLIENT); + for (const auto& frame : info->retransmittable_frames) { + if (frame.type == STREAM_FRAME) { + stream_id = frame.stream_frame.stream_id; + break; + } + } + UpdatePacketState( + old_packet_number, + QuicUtils::RetransmissionTypeToPacketState(transmission_type)); + info->first_sent_after_loss = QuicPacketNumber(new_packet_number); + SerializedPacket packet( + CreateRetransmittablePacketForStream(new_packet_number, stream_id)); + unacked_packets_.AddSentPacket(&packet, transmission_type, now_, true, true, + ECN_NOT_ECT); + } + QuicUnackedPacketMap unacked_packets_; + QuicTime now_; + StrictMock notifier_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicUnackedPacketMapTest, + ::testing::ValuesIn({Perspective::IS_CLIENT, + Perspective::IS_SERVER}), + ::testing::PrintToStringParamName()); + +TEST_P(QuicUnackedPacketMapTest, RttOnly) { + // Acks are only tracked for RTT measurement purposes. + SerializedPacket packet(CreateNonRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet, NOT_RETRANSMISSION, now_, false, true, + ECN_NOT_ECT); + + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(nullptr, 0); + VerifyRetransmittablePackets(nullptr, 0); + + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(1)); + VerifyUnackedPackets(nullptr, 0); + VerifyInFlightPackets(nullptr, 0); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_P(QuicUnackedPacketMapTest, RetransmittableInflightAndRtt) { + // Simulate a retransmittable packet being sent and acked. + SerializedPacket packet(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(unacked, ABSL_ARRAYSIZE(unacked)); + + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(1)); + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); + + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(1)); + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); + + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(1)); + VerifyUnackedPackets(nullptr, 0); + VerifyInFlightPackets(nullptr, 0); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_P(QuicUnackedPacketMapTest, StopRetransmission) { + const QuicStreamId stream_id = 2; + SerializedPacket packet(CreateRetransmittablePacketForStream(1, stream_id)); + unacked_packets_.AddSentPacket(&packet, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + uint64_t retransmittable[] = {1}; + VerifyRetransmittablePackets(retransmittable, + ABSL_ARRAYSIZE(retransmittable)); + + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_P(QuicUnackedPacketMapTest, StopRetransmissionOnOtherStream) { + const QuicStreamId stream_id = 2; + SerializedPacket packet(CreateRetransmittablePacketForStream(1, stream_id)); + unacked_packets_.AddSentPacket(&packet, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked[] = {1}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + uint64_t retransmittable[] = {1}; + VerifyRetransmittablePackets(retransmittable, + ABSL_ARRAYSIZE(retransmittable)); + + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(retransmittable, + ABSL_ARRAYSIZE(retransmittable)); +} + +TEST_P(QuicUnackedPacketMapTest, StopRetransmissionAfterRetransmission) { + const QuicStreamId stream_id = 2; + SerializedPacket packet1(CreateRetransmittablePacketForStream(1, stream_id)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + RetransmitAndSendPacket(1, 2, LOSS_RETRANSMISSION); + + uint64_t unacked[] = {1, 2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + std::vector retransmittable = {1, 2}; + VerifyRetransmittablePackets(&retransmittable[0], retransmittable.size()); + + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_P(QuicUnackedPacketMapTest, RetransmittedPacket) { + // Simulate a retransmittable packet being sent, retransmitted, and the first + // transmission being acked. + SerializedPacket packet1(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + RetransmitAndSendPacket(1, 2, LOSS_RETRANSMISSION); + + uint64_t unacked[] = {1, 2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + std::vector retransmittable = {1, 2}; + VerifyRetransmittablePackets(&retransmittable[0], retransmittable.size()); + + EXPECT_CALL(notifier_, IsFrameOutstanding(_)).WillRepeatedly(Return(false)); + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(1)); + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); + + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(2)); + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyRetransmittablePackets(nullptr, 0); + + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + uint64_t unacked2[] = {1}; + VerifyUnackedPackets(unacked2, ABSL_ARRAYSIZE(unacked2)); + VerifyInFlightPackets(unacked2, ABSL_ARRAYSIZE(unacked2)); + VerifyRetransmittablePackets(nullptr, 0); + + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(1)); + VerifyUnackedPackets(nullptr, 0); + VerifyInFlightPackets(nullptr, 0); + VerifyRetransmittablePackets(nullptr, 0); +} + +TEST_P(QuicUnackedPacketMapTest, RetransmitThreeTimes) { + // Simulate a retransmittable packet being sent and retransmitted twice. + SerializedPacket packet1(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + SerializedPacket packet2(CreateRetransmittablePacket(2)); + unacked_packets_.AddSentPacket(&packet2, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked[] = {1, 2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + uint64_t retransmittable[] = {1, 2}; + VerifyRetransmittablePackets(retransmittable, + ABSL_ARRAYSIZE(retransmittable)); + + // Early retransmit 1 as 3 and send new data as 4. + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(1)); + RetransmitAndSendPacket(1, 3, LOSS_RETRANSMISSION); + SerializedPacket packet4(CreateRetransmittablePacket(4)); + unacked_packets_.AddSentPacket(&packet4, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked2[] = {1, 3, 4}; + VerifyUnackedPackets(unacked2, ABSL_ARRAYSIZE(unacked2)); + uint64_t pending2[] = {3, 4}; + VerifyInFlightPackets(pending2, ABSL_ARRAYSIZE(pending2)); + std::vector retransmittable2 = {1, 3, 4}; + VerifyRetransmittablePackets(&retransmittable2[0], retransmittable2.size()); + + // Early retransmit 3 (formerly 1) as 5, and remove 1 from unacked. + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(4)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(4)); + RetransmitAndSendPacket(3, 5, LOSS_RETRANSMISSION); + SerializedPacket packet6(CreateRetransmittablePacket(6)); + unacked_packets_.AddSentPacket(&packet6, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + std::vector unacked3 = {3, 5, 6}; + std::vector retransmittable3 = {3, 5, 6}; + VerifyUnackedPackets(&unacked3[0], unacked3.size()); + VerifyRetransmittablePackets(&retransmittable3[0], retransmittable3.size()); + uint64_t pending3[] = {3, 5, 6}; + VerifyInFlightPackets(pending3, ABSL_ARRAYSIZE(pending3)); + + // Early retransmit 5 as 7 and ensure in flight packet 3 is not removed. + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(6)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(6)); + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(6)); + RetransmitAndSendPacket(5, 7, LOSS_RETRANSMISSION); + + std::vector unacked4 = {3, 5, 7}; + std::vector retransmittable4 = {3, 5, 7}; + VerifyUnackedPackets(&unacked4[0], unacked4.size()); + VerifyRetransmittablePackets(&retransmittable4[0], retransmittable4.size()); + uint64_t pending4[] = {3, 5, 7}; + VerifyInFlightPackets(pending4, ABSL_ARRAYSIZE(pending4)); + + // Remove the older two transmissions from in flight. + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(3)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(5)); + uint64_t pending5[] = {7}; + VerifyInFlightPackets(pending5, ABSL_ARRAYSIZE(pending5)); +} + +TEST_P(QuicUnackedPacketMapTest, RetransmitFourTimes) { + // Simulate a retransmittable packet being sent and retransmitted twice. + SerializedPacket packet1(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + SerializedPacket packet2(CreateRetransmittablePacket(2)); + unacked_packets_.AddSentPacket(&packet2, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked[] = {1, 2}; + VerifyUnackedPackets(unacked, ABSL_ARRAYSIZE(unacked)); + VerifyInFlightPackets(unacked, ABSL_ARRAYSIZE(unacked)); + uint64_t retransmittable[] = {1, 2}; + VerifyRetransmittablePackets(retransmittable, + ABSL_ARRAYSIZE(retransmittable)); + + // Early retransmit 1 as 3. + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(2)); + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(2)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(1)); + RetransmitAndSendPacket(1, 3, LOSS_RETRANSMISSION); + + uint64_t unacked2[] = {1, 3}; + VerifyUnackedPackets(unacked2, ABSL_ARRAYSIZE(unacked2)); + uint64_t pending2[] = {3}; + VerifyInFlightPackets(pending2, ABSL_ARRAYSIZE(pending2)); + std::vector retransmittable2 = {1, 3}; + VerifyRetransmittablePackets(&retransmittable2[0], retransmittable2.size()); + + // PTO 3 (formerly 1) as 4, and don't remove 1 from unacked. + RetransmitAndSendPacket(3, 4, PTO_RETRANSMISSION); + SerializedPacket packet5(CreateRetransmittablePacket(5)); + unacked_packets_.AddSentPacket(&packet5, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + + uint64_t unacked3[] = {1, 3, 4, 5}; + VerifyUnackedPackets(unacked3, ABSL_ARRAYSIZE(unacked3)); + uint64_t pending3[] = {3, 4, 5}; + VerifyInFlightPackets(pending3, ABSL_ARRAYSIZE(pending3)); + std::vector retransmittable3 = {1, 3, 4, 5}; + VerifyRetransmittablePackets(&retransmittable3[0], retransmittable3.size()); + + // Early retransmit 4 as 6 and ensure in flight packet 3 is removed. + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(5)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(5)); + unacked_packets_.RemoveRetransmittability(QuicPacketNumber(5)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(3)); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(4)); + RetransmitAndSendPacket(4, 6, LOSS_RETRANSMISSION); + + std::vector unacked4 = {4, 6}; + VerifyUnackedPackets(&unacked4[0], unacked4.size()); + uint64_t pending4[] = {6}; + VerifyInFlightPackets(pending4, ABSL_ARRAYSIZE(pending4)); + std::vector retransmittable4 = {4, 6}; + VerifyRetransmittablePackets(&retransmittable4[0], retransmittable4.size()); +} + +TEST_P(QuicUnackedPacketMapTest, SendWithGap) { + // Simulate a retransmittable packet being sent, retransmitted, and the first + // transmission being acked. + SerializedPacket packet1(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + SerializedPacket packet3(CreateRetransmittablePacket(3)); + unacked_packets_.AddSentPacket(&packet3, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + RetransmitAndSendPacket(3, 5, LOSS_RETRANSMISSION); + + EXPECT_EQ(QuicPacketNumber(1u), unacked_packets_.GetLeastUnacked()); + EXPECT_TRUE(unacked_packets_.IsUnacked(QuicPacketNumber(1))); + EXPECT_FALSE(unacked_packets_.IsUnacked(QuicPacketNumber(2))); + EXPECT_TRUE(unacked_packets_.IsUnacked(QuicPacketNumber(3))); + EXPECT_FALSE(unacked_packets_.IsUnacked(QuicPacketNumber(4))); + EXPECT_TRUE(unacked_packets_.IsUnacked(QuicPacketNumber(5))); + EXPECT_EQ(QuicPacketNumber(5u), unacked_packets_.largest_sent_packet()); +} + +TEST_P(QuicUnackedPacketMapTest, AggregateContiguousAckedStreamFrames) { + testing::InSequence s; + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(0); + unacked_packets_.NotifyAggregatedStreamFrameAcked(QuicTime::Delta::Zero()); + + QuicTransmissionInfo info1; + QuicStreamFrame stream_frame1(3, false, 0, 100); + info1.retransmittable_frames.push_back(QuicFrame(stream_frame1)); + + QuicTransmissionInfo info2; + QuicStreamFrame stream_frame2(3, false, 100, 100); + info2.retransmittable_frames.push_back(QuicFrame(stream_frame2)); + + QuicTransmissionInfo info3; + QuicStreamFrame stream_frame3(3, false, 200, 100); + info3.retransmittable_frames.push_back(QuicFrame(stream_frame3)); + + QuicTransmissionInfo info4; + QuicStreamFrame stream_frame4(3, true, 300, 0); + info4.retransmittable_frames.push_back(QuicFrame(stream_frame4)); + + // Verify stream frames are aggregated. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(0); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info1, QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(0); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info2, QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(0); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info3, QuicTime::Delta::Zero(), QuicTime::Zero()); + + // Verify aggregated stream frame gets acked since fin is acked. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(1); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info4, QuicTime::Delta::Zero(), QuicTime::Zero()); +} + +// Regression test for b/112930090. +TEST_P(QuicUnackedPacketMapTest, CannotAggregateIfDataLengthOverflow) { + QuicByteCount kMaxAggregatedDataLength = + std::numeric_limits::max(); + QuicStreamId stream_id = 2; + + // acked_stream_length=512 covers the case where a frame will cause the + // aggregated frame length to be exactly 64K. + // acked_stream_length=1300 covers the case where a frame will cause the + // aggregated frame length to exceed 64K. + for (const QuicPacketLength acked_stream_length : {512, 1300}) { + ++stream_id; + QuicStreamOffset offset = 0; + // Expected length of the aggregated stream frame. + QuicByteCount aggregated_data_length = 0; + + while (offset < 1e6) { + QuicTransmissionInfo info; + QuicStreamFrame stream_frame(stream_id, false, offset, + acked_stream_length); + info.retransmittable_frames.push_back(QuicFrame(stream_frame)); + + const QuicStreamFrame& aggregated_stream_frame = + QuicUnackedPacketMapPeer::GetAggregatedStreamFrame(unacked_packets_); + if (aggregated_stream_frame.data_length + acked_stream_length <= + kMaxAggregatedDataLength) { + // Verify the acked stream frame can be aggregated. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(0); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info, QuicTime::Delta::Zero(), QuicTime::Zero()); + aggregated_data_length += acked_stream_length; + testing::Mock::VerifyAndClearExpectations(¬ifier_); + } else { + // Verify the acked stream frame cannot be aggregated because + // data_length is overflow. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(1); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info, QuicTime::Delta::Zero(), QuicTime::Zero()); + aggregated_data_length = acked_stream_length; + testing::Mock::VerifyAndClearExpectations(¬ifier_); + } + + EXPECT_EQ(aggregated_data_length, aggregated_stream_frame.data_length); + offset += acked_stream_length; + } + + // Ack the last frame of the stream. + QuicTransmissionInfo info; + QuicStreamFrame stream_frame(stream_id, true, offset, acked_stream_length); + info.retransmittable_frames.push_back(QuicFrame(stream_frame)); + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(1); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info, QuicTime::Delta::Zero(), QuicTime::Zero()); + testing::Mock::VerifyAndClearExpectations(¬ifier_); + } +} + +TEST_P(QuicUnackedPacketMapTest, CannotAggregateAckedControlFrames) { + testing::InSequence s; + QuicWindowUpdateFrame window_update(1, 5, 100); + QuicStreamFrame stream_frame1(3, false, 0, 100); + QuicStreamFrame stream_frame2(3, false, 100, 100); + QuicBlockedFrame blocked(2, 5, 0); + QuicGoAwayFrame go_away(3, QUIC_PEER_GOING_AWAY, 5, "Going away."); + + QuicTransmissionInfo info1; + info1.retransmittable_frames.push_back(QuicFrame(window_update)); + info1.retransmittable_frames.push_back(QuicFrame(stream_frame1)); + info1.retransmittable_frames.push_back(QuicFrame(stream_frame2)); + + QuicTransmissionInfo info2; + info2.retransmittable_frames.push_back(QuicFrame(blocked)); + info2.retransmittable_frames.push_back(QuicFrame(&go_away)); + + // Verify 2 contiguous stream frames are aggregated. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(1); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info1, QuicTime::Delta::Zero(), QuicTime::Zero()); + // Verify aggregated stream frame gets acked. + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(3); + unacked_packets_.MaybeAggregateAckedStreamFrame( + info2, QuicTime::Delta::Zero(), QuicTime::Zero()); + + EXPECT_CALL(notifier_, OnFrameAcked(_, _, _)).Times(0); + unacked_packets_.NotifyAggregatedStreamFrameAcked(QuicTime::Delta::Zero()); +} + +TEST_P(QuicUnackedPacketMapTest, LargestSentPacketMultiplePacketNumberSpaces) { + unacked_packets_.EnableMultiplePacketNumberSpacesSupport(); + EXPECT_FALSE( + unacked_packets_ + .GetLargestSentRetransmittableOfPacketNumberSpace(INITIAL_DATA) + .IsInitialized()); + // Send packet 1. + SerializedPacket packet1(CreateRetransmittablePacket(1)); + packet1.encryption_level = ENCRYPTION_INITIAL; + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + EXPECT_EQ(QuicPacketNumber(1u), unacked_packets_.largest_sent_packet()); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_FALSE( + unacked_packets_ + .GetLargestSentRetransmittableOfPacketNumberSpace(HANDSHAKE_DATA) + .IsInitialized()); + // Send packet 2. + SerializedPacket packet2(CreateRetransmittablePacket(2)); + packet2.encryption_level = ENCRYPTION_HANDSHAKE; + unacked_packets_.AddSentPacket(&packet2, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + EXPECT_EQ(QuicPacketNumber(2u), unacked_packets_.largest_sent_packet()); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_EQ(QuicPacketNumber(2), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + HANDSHAKE_DATA)); + EXPECT_FALSE( + unacked_packets_ + .GetLargestSentRetransmittableOfPacketNumberSpace(APPLICATION_DATA) + .IsInitialized()); + // Send packet 3. + SerializedPacket packet3(CreateRetransmittablePacket(3)); + packet3.encryption_level = ENCRYPTION_ZERO_RTT; + unacked_packets_.AddSentPacket(&packet3, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + EXPECT_EQ(QuicPacketNumber(3u), unacked_packets_.largest_sent_packet()); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_EQ(QuicPacketNumber(2), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + HANDSHAKE_DATA)); + EXPECT_EQ(QuicPacketNumber(3), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + APPLICATION_DATA)); + // Verify forward secure belongs to the same packet number space as encryption + // zero rtt. + EXPECT_EQ(QuicPacketNumber(3), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + APPLICATION_DATA)); + + // Send packet 4. + SerializedPacket packet4(CreateRetransmittablePacket(4)); + packet4.encryption_level = ENCRYPTION_FORWARD_SECURE; + unacked_packets_.AddSentPacket(&packet4, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + EXPECT_EQ(QuicPacketNumber(4u), unacked_packets_.largest_sent_packet()); + EXPECT_EQ(QuicPacketNumber(1), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + INITIAL_DATA)); + EXPECT_EQ(QuicPacketNumber(2), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + HANDSHAKE_DATA)); + EXPECT_EQ(QuicPacketNumber(4), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + APPLICATION_DATA)); + // Verify forward secure belongs to the same packet number space as encryption + // zero rtt. + EXPECT_EQ(QuicPacketNumber(4), + unacked_packets_.GetLargestSentRetransmittableOfPacketNumberSpace( + APPLICATION_DATA)); + EXPECT_TRUE(unacked_packets_.GetLastPacketContent() & (1 << STREAM_FRAME)); + EXPECT_FALSE(unacked_packets_.GetLastPacketContent() & (1 << ACK_FRAME)); +} + +TEST_P(QuicUnackedPacketMapTest, ReserveInitialCapacityTest) { + QuicUnackedPacketMap unacked_packets(GetParam()); + ASSERT_EQ(QuicUnackedPacketMapPeer::GetCapacity(unacked_packets), 0u); + unacked_packets.ReserveInitialCapacity(16); + QuicStreamId stream_id(1); + SerializedPacket packet(CreateRetransmittablePacketForStream(1, stream_id)); + unacked_packets.AddSentPacket(&packet, TransmissionType::NOT_RETRANSMISSION, + now_, true, true, ECN_NOT_ECT); + ASSERT_EQ(QuicUnackedPacketMapPeer::GetCapacity(unacked_packets), 16u); +} + +TEST_P(QuicUnackedPacketMapTest, DebugString) { + EXPECT_EQ(unacked_packets_.DebugString(), + "{size: 0, least_unacked: 1, largest_sent_packet: uninitialized, " + "largest_acked: uninitialized, bytes_in_flight: 0, " + "packets_in_flight: 0}"); + + SerializedPacket packet1(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + EXPECT_EQ( + unacked_packets_.DebugString(), + "{size: 1, least_unacked: 1, largest_sent_packet: 1, largest_acked: " + "uninitialized, bytes_in_flight: 1000, packets_in_flight: 1}"); + + SerializedPacket packet2(CreateRetransmittablePacket(2)); + unacked_packets_.AddSentPacket(&packet2, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + unacked_packets_.RemoveFromInFlight(QuicPacketNumber(1)); + unacked_packets_.IncreaseLargestAcked(QuicPacketNumber(1)); + unacked_packets_.RemoveObsoletePackets(); + EXPECT_EQ( + unacked_packets_.DebugString(), + "{size: 1, least_unacked: 2, largest_sent_packet: 2, largest_acked: 1, " + "bytes_in_flight: 1000, packets_in_flight: 1}"); +} + +TEST_P(QuicUnackedPacketMapTest, EcnInfoStored) { + SerializedPacket packet1(CreateRetransmittablePacket(1)); + unacked_packets_.AddSentPacket(&packet1, NOT_RETRANSMISSION, now_, true, true, + ECN_NOT_ECT); + SerializedPacket packet2(CreateRetransmittablePacket(2)); + unacked_packets_.AddSentPacket(&packet2, NOT_RETRANSMISSION, now_, true, true, + ECN_ECT0); + SerializedPacket packet3(CreateRetransmittablePacket(3)); + unacked_packets_.AddSentPacket(&packet3, NOT_RETRANSMISSION, now_, true, true, + ECN_ECT1); + EXPECT_EQ( + unacked_packets_.GetTransmissionInfo(QuicPacketNumber(1)).ecn_codepoint, + ECN_NOT_ECT); + EXPECT_EQ( + unacked_packets_.GetTransmissionInfo(QuicPacketNumber(2)).ecn_codepoint, + ECN_ECT0); + EXPECT_EQ( + unacked_packets_.GetTransmissionInfo(QuicPacketNumber(3)).ecn_codepoint, + ECN_ECT1); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_utils.cc b/quiche/quic/core/quic_utils.cc new file mode 100644 index 000000000000..5a86673880f2 --- /dev/null +++ b/quiche/quic/core/quic_utils.cc @@ -0,0 +1,630 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_utils.h" + +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/base/optimization.h" +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "openssl/sha.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { +namespace { + +// We know that >= GCC 4.8 and Clang have a __uint128_t intrinsic. Other +// compilers don't necessarily, notably MSVC. +#if defined(__x86_64__) && \ + ((defined(__GNUC__) && \ + (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || \ + defined(__clang__)) +#define QUIC_UTIL_HAS_UINT128 1 +#endif + +#ifdef QUIC_UTIL_HAS_UINT128 +absl::uint128 IncrementalHashFast(absl::uint128 uhash, absl::string_view data) { + // This code ends up faster than the naive implementation for 2 reasons: + // 1. absl::uint128 is sufficiently complicated that the compiler + // cannot transform the multiplication by kPrime into a shift-multiply-add; + // it has go through all of the instructions for a 128-bit multiply. + // 2. Because there are so fewer instructions (around 13), the hot loop fits + // nicely in the instruction queue of many Intel CPUs. + // kPrime = 309485009821345068724781371 + static const absl::uint128 kPrime = + (static_cast(16777216) << 64) + 315; + auto hi = absl::Uint128High64(uhash); + auto lo = absl::Uint128Low64(uhash); + absl::uint128 xhash = (static_cast(hi) << 64) + lo; + const uint8_t* octets = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.length(); ++i) { + xhash = (xhash ^ static_cast(octets[i])) * kPrime; + } + return absl::MakeUint128(absl::Uint128High64(xhash), + absl::Uint128Low64(xhash)); +} +#endif + +#ifndef QUIC_UTIL_HAS_UINT128 +// Slow implementation of IncrementalHash. In practice, only used by Chromium. +absl::uint128 IncrementalHashSlow(absl::uint128 hash, absl::string_view data) { + // kPrime = 309485009821345068724781371 + static const absl::uint128 kPrime = absl::MakeUint128(16777216, 315); + const uint8_t* octets = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.length(); ++i) { + hash = hash ^ absl::MakeUint128(0, octets[i]); + hash = hash * kPrime; + } + return hash; +} +#endif + +absl::uint128 IncrementalHash(absl::uint128 hash, absl::string_view data) { +#ifdef QUIC_UTIL_HAS_UINT128 + return IncrementalHashFast(hash, data); +#else + return IncrementalHashSlow(hash, data); +#endif +} + +} // namespace + +// static +uint64_t QuicUtils::FNV1a_64_Hash(absl::string_view data) { + static const uint64_t kOffset = UINT64_C(14695981039346656037); + static const uint64_t kPrime = UINT64_C(1099511628211); + + const uint8_t* octets = reinterpret_cast(data.data()); + + uint64_t hash = kOffset; + + for (size_t i = 0; i < data.length(); ++i) { + hash = hash ^ octets[i]; + hash = hash * kPrime; + } + + return hash; +} + +// static +absl::uint128 QuicUtils::FNV1a_128_Hash(absl::string_view data) { + return FNV1a_128_Hash_Three(data, absl::string_view(), absl::string_view()); +} + +// static +absl::uint128 QuicUtils::FNV1a_128_Hash_Two(absl::string_view data1, + absl::string_view data2) { + return FNV1a_128_Hash_Three(data1, data2, absl::string_view()); +} + +// static +absl::uint128 QuicUtils::FNV1a_128_Hash_Three(absl::string_view data1, + absl::string_view data2, + absl::string_view data3) { + // The two constants are defined as part of the hash algorithm. + // see http://www.isthe.com/chongo/tech/comp/fnv/ + // kOffset = 144066263297769815596495629667062367629 + const absl::uint128 kOffset = absl::MakeUint128( + UINT64_C(7809847782465536322), UINT64_C(7113472399480571277)); + + absl::uint128 hash = IncrementalHash(kOffset, data1); + if (data2.empty()) { + return hash; + } + + hash = IncrementalHash(hash, data2); + if (data3.empty()) { + return hash; + } + return IncrementalHash(hash, data3); +} + +// static +void QuicUtils::SerializeUint128Short(absl::uint128 v, uint8_t* out) { + const uint64_t lo = absl::Uint128Low64(v); + const uint64_t hi = absl::Uint128High64(v); + // This assumes that the system is little-endian. + memcpy(out, &lo, sizeof(lo)); + memcpy(out + sizeof(lo), &hi, sizeof(hi) / 2); +} + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x; + +std::string QuicUtils::AddressChangeTypeToString(AddressChangeType type) { + switch (type) { + RETURN_STRING_LITERAL(NO_CHANGE); + RETURN_STRING_LITERAL(PORT_CHANGE); + RETURN_STRING_LITERAL(IPV4_SUBNET_CHANGE); + RETURN_STRING_LITERAL(IPV4_TO_IPV6_CHANGE); + RETURN_STRING_LITERAL(IPV6_TO_IPV4_CHANGE); + RETURN_STRING_LITERAL(IPV6_TO_IPV6_CHANGE); + RETURN_STRING_LITERAL(IPV4_TO_IPV4_CHANGE); + } + return "INVALID_ADDRESS_CHANGE_TYPE"; +} + +const char* QuicUtils::SentPacketStateToString(SentPacketState state) { + switch (state) { + RETURN_STRING_LITERAL(OUTSTANDING); + RETURN_STRING_LITERAL(NEVER_SENT); + RETURN_STRING_LITERAL(ACKED); + RETURN_STRING_LITERAL(UNACKABLE); + RETURN_STRING_LITERAL(NEUTERED); + RETURN_STRING_LITERAL(HANDSHAKE_RETRANSMITTED); + RETURN_STRING_LITERAL(LOST); + RETURN_STRING_LITERAL(PTO_RETRANSMITTED); + RETURN_STRING_LITERAL(NOT_CONTRIBUTING_RTT); + } + return "INVALID_SENT_PACKET_STATE"; +} + +// static +const char* QuicUtils::QuicLongHeaderTypetoString(QuicLongHeaderType type) { + switch (type) { + RETURN_STRING_LITERAL(VERSION_NEGOTIATION); + RETURN_STRING_LITERAL(INITIAL); + RETURN_STRING_LITERAL(RETRY); + RETURN_STRING_LITERAL(HANDSHAKE); + RETURN_STRING_LITERAL(ZERO_RTT_PROTECTED); + default: + return "INVALID_PACKET_TYPE"; + } +} + +// static +const char* QuicUtils::AckResultToString(AckResult result) { + switch (result) { + RETURN_STRING_LITERAL(PACKETS_NEWLY_ACKED); + RETURN_STRING_LITERAL(NO_PACKETS_NEWLY_ACKED); + RETURN_STRING_LITERAL(UNSENT_PACKETS_ACKED); + RETURN_STRING_LITERAL(UNACKABLE_PACKETS_ACKED); + RETURN_STRING_LITERAL(PACKETS_ACKED_IN_WRONG_PACKET_NUMBER_SPACE); + } + return "INVALID_ACK_RESULT"; +} + +// static +AddressChangeType QuicUtils::DetermineAddressChangeType( + const QuicSocketAddress& old_address, + const QuicSocketAddress& new_address) { + if (!old_address.IsInitialized() || !new_address.IsInitialized() || + old_address == new_address) { + return NO_CHANGE; + } + + if (old_address.host() == new_address.host()) { + return PORT_CHANGE; + } + + bool old_ip_is_ipv4 = old_address.host().IsIPv4() ? true : false; + bool migrating_ip_is_ipv4 = new_address.host().IsIPv4() ? true : false; + if (old_ip_is_ipv4 && !migrating_ip_is_ipv4) { + return IPV4_TO_IPV6_CHANGE; + } + + if (!old_ip_is_ipv4) { + return migrating_ip_is_ipv4 ? IPV6_TO_IPV4_CHANGE : IPV6_TO_IPV6_CHANGE; + } + + const int kSubnetMaskLength = 24; + if (old_address.host().InSameSubnet(new_address.host(), kSubnetMaskLength)) { + // Subnet part does not change (here, we use /24), which is considered to be + // caused by NATs. + return IPV4_SUBNET_CHANGE; + } + + return IPV4_TO_IPV4_CHANGE; +} + +// static +bool QuicUtils::IsAckable(SentPacketState state) { + return state != NEVER_SENT && state != ACKED && state != UNACKABLE; +} + +// static +bool QuicUtils::IsRetransmittableFrame(QuicFrameType type) { + switch (type) { + case ACK_FRAME: + case PADDING_FRAME: + case STOP_WAITING_FRAME: + case MTU_DISCOVERY_FRAME: + case PATH_CHALLENGE_FRAME: + case PATH_RESPONSE_FRAME: + return false; + default: + return true; + } +} + +// static +bool QuicUtils::IsHandshakeFrame(const QuicFrame& frame, + QuicTransportVersion transport_version) { + if (!QuicVersionUsesCryptoFrames(transport_version)) { + return frame.type == STREAM_FRAME && + frame.stream_frame.stream_id == GetCryptoStreamId(transport_version); + } else { + return frame.type == CRYPTO_FRAME; + } +} + +// static +bool QuicUtils::ContainsFrameType(const QuicFrames& frames, + QuicFrameType type) { + for (const QuicFrame& frame : frames) { + if (frame.type == type) { + return true; + } + } + return false; +} + +// static +SentPacketState QuicUtils::RetransmissionTypeToPacketState( + TransmissionType retransmission_type) { + switch (retransmission_type) { + case ALL_ZERO_RTT_RETRANSMISSION: + return UNACKABLE; + case HANDSHAKE_RETRANSMISSION: + return HANDSHAKE_RETRANSMITTED; + case LOSS_RETRANSMISSION: + return LOST; + case PTO_RETRANSMISSION: + return PTO_RETRANSMITTED; + case PATH_RETRANSMISSION: + return NOT_CONTRIBUTING_RTT; + case ALL_INITIAL_RETRANSMISSION: + return UNACKABLE; + default: + QUIC_BUG(quic_bug_10839_2) + << retransmission_type << " is not a retransmission_type"; + return UNACKABLE; + } +} + +// static +bool QuicUtils::IsIetfPacketHeader(uint8_t first_byte) { + return (first_byte & FLAGS_LONG_HEADER) || (first_byte & FLAGS_FIXED_BIT) || + !(first_byte & FLAGS_DEMULTIPLEXING_BIT); +} + +// static +bool QuicUtils::IsIetfPacketShortHeader(uint8_t first_byte) { + return IsIetfPacketHeader(first_byte) && !(first_byte & FLAGS_LONG_HEADER); +} + +// static +QuicStreamId QuicUtils::GetInvalidStreamId(QuicTransportVersion version) { + return VersionHasIetfQuicFrames(version) + ? std::numeric_limits::max() + : 0; +} + +// static +QuicStreamId QuicUtils::GetCryptoStreamId(QuicTransportVersion version) { + QUIC_BUG_IF(quic_bug_12982_1, QuicVersionUsesCryptoFrames(version)) + << "CRYPTO data aren't in stream frames; they have no stream ID."; + return QuicVersionUsesCryptoFrames(version) ? GetInvalidStreamId(version) : 1; +} + +// static +bool QuicUtils::IsCryptoStreamId(QuicTransportVersion version, + QuicStreamId stream_id) { + if (QuicVersionUsesCryptoFrames(version)) { + return false; + } + return stream_id == GetCryptoStreamId(version); +} + +// static +QuicStreamId QuicUtils::GetHeadersStreamId(QuicTransportVersion version) { + QUICHE_DCHECK(!VersionUsesHttp3(version)); + return GetFirstBidirectionalStreamId(version, Perspective::IS_CLIENT); +} + +// static +bool QuicUtils::IsClientInitiatedStreamId(QuicTransportVersion version, + QuicStreamId id) { + if (id == GetInvalidStreamId(version)) { + return false; + } + return VersionHasIetfQuicFrames(version) ? id % 2 == 0 : id % 2 != 0; +} + +// static +bool QuicUtils::IsServerInitiatedStreamId(QuicTransportVersion version, + QuicStreamId id) { + if (id == GetInvalidStreamId(version)) { + return false; + } + return VersionHasIetfQuicFrames(version) ? id % 2 != 0 : id % 2 == 0; +} + +// static +bool QuicUtils::IsOutgoingStreamId(ParsedQuicVersion version, QuicStreamId id, + Perspective perspective) { + // Streams are outgoing streams, iff: + // - we are the server and the stream is server-initiated + // - we are the client and the stream is client-initiated. + const bool perspective_is_server = perspective == Perspective::IS_SERVER; + const bool stream_is_server = + QuicUtils::IsServerInitiatedStreamId(version.transport_version, id); + return perspective_is_server == stream_is_server; +} + +// static +bool QuicUtils::IsBidirectionalStreamId(QuicStreamId id, + ParsedQuicVersion version) { + QUICHE_DCHECK(version.HasIetfQuicFrames()); + return id % 4 < 2; +} + +// static +StreamType QuicUtils::GetStreamType(QuicStreamId id, Perspective perspective, + bool peer_initiated, + ParsedQuicVersion version) { + QUICHE_DCHECK(version.HasIetfQuicFrames()); + if (IsBidirectionalStreamId(id, version)) { + return BIDIRECTIONAL; + } + + if (peer_initiated) { + if (perspective == Perspective::IS_SERVER) { + QUICHE_DCHECK_EQ(2u, id % 4); + } else { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective); + QUICHE_DCHECK_EQ(3u, id % 4); + } + return READ_UNIDIRECTIONAL; + } + + if (perspective == Perspective::IS_SERVER) { + QUICHE_DCHECK_EQ(3u, id % 4); + } else { + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective); + QUICHE_DCHECK_EQ(2u, id % 4); + } + return WRITE_UNIDIRECTIONAL; +} + +// static +QuicStreamId QuicUtils::StreamIdDelta(QuicTransportVersion version) { + return VersionHasIetfQuicFrames(version) ? 4 : 2; +} + +// static +QuicStreamId QuicUtils::GetFirstBidirectionalStreamId( + QuicTransportVersion version, Perspective perspective) { + if (VersionHasIetfQuicFrames(version)) { + return perspective == Perspective::IS_CLIENT ? 0 : 1; + } else if (QuicVersionUsesCryptoFrames(version)) { + return perspective == Perspective::IS_CLIENT ? 1 : 2; + } + return perspective == Perspective::IS_CLIENT ? 3 : 2; +} + +// static +QuicStreamId QuicUtils::GetFirstUnidirectionalStreamId( + QuicTransportVersion version, Perspective perspective) { + if (VersionHasIetfQuicFrames(version)) { + return perspective == Perspective::IS_CLIENT ? 2 : 3; + } else if (QuicVersionUsesCryptoFrames(version)) { + return perspective == Perspective::IS_CLIENT ? 1 : 2; + } + return perspective == Perspective::IS_CLIENT ? 3 : 2; +} + +// static +QuicStreamId QuicUtils::GetMaxClientInitiatedBidirectionalStreamId( + QuicTransportVersion version) { + if (VersionHasIetfQuicFrames(version)) { + // Client initiated bidirectional streams have stream IDs divisible by 4. + return std::numeric_limits::max() - 3; + } + + // Client initiated bidirectional streams have odd stream IDs. + return std::numeric_limits::max(); +} + +// static +QuicConnectionId QuicUtils::CreateRandomConnectionId() { + return CreateRandomConnectionId(kQuicDefaultConnectionIdLength, + QuicRandom::GetInstance()); +} + +// static +QuicConnectionId QuicUtils::CreateRandomConnectionId(QuicRandom* random) { + return CreateRandomConnectionId(kQuicDefaultConnectionIdLength, random); +} +// static +QuicConnectionId QuicUtils::CreateRandomConnectionId( + uint8_t connection_id_length) { + return CreateRandomConnectionId(connection_id_length, + QuicRandom::GetInstance()); +} + +// static +QuicConnectionId QuicUtils::CreateRandomConnectionId( + uint8_t connection_id_length, QuicRandom* random) { + QuicConnectionId connection_id; + connection_id.set_length(connection_id_length); + if (connection_id.length() > 0) { + random->RandBytes(connection_id.mutable_data(), connection_id.length()); + } + return connection_id; +} + +// static +QuicConnectionId QuicUtils::CreateZeroConnectionId( + QuicTransportVersion version) { + if (!VersionAllowsVariableLengthConnectionIds(version)) { + char connection_id_bytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + return QuicConnectionId(static_cast(connection_id_bytes), + ABSL_ARRAYSIZE(connection_id_bytes)); + } + return EmptyQuicConnectionId(); +} + +// static +bool QuicUtils::IsConnectionIdLengthValidForVersion( + size_t connection_id_length, QuicTransportVersion transport_version) { + // No version of QUIC can support lengths that do not fit in an uint8_t. + if (connection_id_length > + static_cast(std::numeric_limits::max())) { + return false; + } + + if (transport_version == QUIC_VERSION_UNSUPPORTED || + transport_version == QUIC_VERSION_RESERVED_FOR_NEGOTIATION) { + // Unknown versions could allow connection ID lengths up to 255. + return true; + } + + const uint8_t connection_id_length8 = + static_cast(connection_id_length); + // Versions that do not support variable lengths only support length 8. + if (!VersionAllowsVariableLengthConnectionIds(transport_version)) { + return connection_id_length8 == kQuicDefaultConnectionIdLength; + } + return connection_id_length8 <= kQuicMaxConnectionIdWithLengthPrefixLength; +} + +// static +bool QuicUtils::IsConnectionIdValidForVersion( + QuicConnectionId connection_id, QuicTransportVersion transport_version) { + return IsConnectionIdLengthValidForVersion(connection_id.length(), + transport_version); +} + +StatelessResetToken QuicUtils::GenerateStatelessResetToken( + QuicConnectionId connection_id) { + static_assert(sizeof(absl::uint128) == sizeof(StatelessResetToken), + "bad size"); + static_assert(alignof(absl::uint128) >= alignof(StatelessResetToken), + "bad alignment"); + absl::uint128 hash = FNV1a_128_Hash( + absl::string_view(connection_id.data(), connection_id.length())); + return *reinterpret_cast(&hash); +} + +// static +QuicStreamCount QuicUtils::GetMaxStreamCount() { + return (kMaxQuicStreamCount >> 2) + 1; +} + +// static +PacketNumberSpace QuicUtils::GetPacketNumberSpace( + EncryptionLevel encryption_level) { + switch (encryption_level) { + case ENCRYPTION_INITIAL: + return INITIAL_DATA; + case ENCRYPTION_HANDSHAKE: + return HANDSHAKE_DATA; + case ENCRYPTION_ZERO_RTT: + case ENCRYPTION_FORWARD_SECURE: + return APPLICATION_DATA; + default: + QUIC_BUG(quic_bug_10839_3) + << "Try to get packet number space of encryption level: " + << encryption_level; + return NUM_PACKET_NUMBER_SPACES; + } +} + +// static +EncryptionLevel QuicUtils::GetEncryptionLevelToSendAckofSpace( + PacketNumberSpace packet_number_space) { + switch (packet_number_space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } +} + +// static +bool QuicUtils::IsProbingFrame(QuicFrameType type) { + switch (type) { + case PATH_CHALLENGE_FRAME: + case PATH_RESPONSE_FRAME: + case NEW_CONNECTION_ID_FRAME: + case PADDING_FRAME: + return true; + default: + return false; + } +} + +// static +bool QuicUtils::IsAckElicitingFrame(QuicFrameType type) { + switch (type) { + case PADDING_FRAME: + case STOP_WAITING_FRAME: + case ACK_FRAME: + case CONNECTION_CLOSE_FRAME: + return false; + default: + return true; + } +} + +// static +bool QuicUtils::AreStatelessResetTokensEqual( + const StatelessResetToken& token1, const StatelessResetToken& token2) { + char byte = 0; + for (size_t i = 0; i < kStatelessResetTokenLength; i++) { + // This avoids compiler optimizations that could make us stop comparing + // after we find a byte that doesn't match. + byte |= (token1[i] ^ token2[i]); + } + return byte == 0; +} + +bool IsValidWebTransportSessionId(WebTransportSessionId id, + ParsedQuicVersion version) { + QUICHE_DCHECK(version.UsesHttp3()); + return (id <= std::numeric_limits::max()) && + QuicUtils::IsBidirectionalStreamId(id, version) && + QuicUtils::IsClientInitiatedStreamId(version.transport_version, id); +} + +QuicByteCount MemSliceSpanTotalSize(absl::Span span) { + QuicByteCount total = 0; + for (const quiche::QuicheMemSlice& slice : span) { + total += slice.length(); + } + return total; +} + +std::string RawSha256(absl::string_view input) { + std::string raw_hash; + raw_hash.resize(SHA256_DIGEST_LENGTH); + SHA256(reinterpret_cast(input.data()), input.size(), + reinterpret_cast(&raw_hash[0])); + return raw_hash; +} + +#undef RETURN_STRING_LITERAL // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_utils.h b/quiche/quic/core/quic_utils.h new file mode 100644 index 000000000000..79212fd389bd --- /dev/null +++ b/quiche/quic/core/quic_utils.h @@ -0,0 +1,290 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_UTILS_H_ +#define QUICHE_QUIC_CORE_QUIC_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QuicUtils { + public: + QuicUtils() = delete; + + // Returns the 64 bit FNV1a hash of the data. See + // http://www.isthe.com/chongo/tech/comp/fnv/index.html#FNV-param + static uint64_t FNV1a_64_Hash(absl::string_view data); + + // Returns the 128 bit FNV1a hash of the data. See + // http://www.isthe.com/chongo/tech/comp/fnv/index.html#FNV-param + static absl::uint128 FNV1a_128_Hash(absl::string_view data); + + // Returns the 128 bit FNV1a hash of the two sequences of data. See + // http://www.isthe.com/chongo/tech/comp/fnv/index.html#FNV-param + static absl::uint128 FNV1a_128_Hash_Two(absl::string_view data1, + absl::string_view data2); + + // Returns the 128 bit FNV1a hash of the three sequences of data. See + // http://www.isthe.com/chongo/tech/comp/fnv/index.html#FNV-param + static absl::uint128 FNV1a_128_Hash_Three(absl::string_view data1, + absl::string_view data2, + absl::string_view data3); + + // SerializeUint128 writes the first 96 bits of |v| in little-endian form + // to |out|. + static void SerializeUint128Short(absl::uint128 v, uint8_t* out); + + // Returns AddressChangeType as a string. + static std::string AddressChangeTypeToString(AddressChangeType type); + + // Returns SentPacketState as a char*. + static const char* SentPacketStateToString(SentPacketState state); + + // Returns QuicLongHeaderType as a char*. + static const char* QuicLongHeaderTypetoString(QuicLongHeaderType type); + + // Returns AckResult as a char*. + static const char* AckResultToString(AckResult result); + + // Determines and returns change type of address change from |old_address| to + // |new_address|. + static AddressChangeType DetermineAddressChangeType( + const QuicSocketAddress& old_address, + const QuicSocketAddress& new_address); + + // Returns the opposite Perspective of the |perspective| passed in. + static constexpr Perspective InvertPerspective(Perspective perspective) { + return perspective == Perspective::IS_CLIENT ? Perspective::IS_SERVER + : Perspective::IS_CLIENT; + } + + // Returns true if a packet is ackable. A packet is unackable if it can never + // be acked. Occurs when a packet is never sent, after it is acknowledged + // once, or if it's a crypto packet we never expect to receive an ack for. + static bool IsAckable(SentPacketState state); + + // Returns true if frame with |type| is retransmittable. A retransmittable + // frame should be retransmitted if it is detected as lost. + static bool IsRetransmittableFrame(QuicFrameType type); + + // Returns true if |frame| is a handshake frame in version |version|. + static bool IsHandshakeFrame(const QuicFrame& frame, + QuicTransportVersion transport_version); + + // Return true if any frame in |frames| is of |type|. + static bool ContainsFrameType(const QuicFrames& frames, QuicFrameType type); + + // Returns packet state corresponding to |retransmission_type|. + static SentPacketState RetransmissionTypeToPacketState( + TransmissionType retransmission_type); + + // Returns true if header with |first_byte| is considered as an IETF QUIC + // packet header. This only works on the server. + static bool IsIetfPacketHeader(uint8_t first_byte); + + // Returns true if header with |first_byte| is considered as an IETF QUIC + // short packet header. + static bool IsIetfPacketShortHeader(uint8_t first_byte); + + // Returns ID to denote an invalid stream of |version|. + static QuicStreamId GetInvalidStreamId(QuicTransportVersion version); + + // Returns crypto stream ID of |version|. + static QuicStreamId GetCryptoStreamId(QuicTransportVersion version); + + // Returns whether |id| is the stream ID for the crypto stream. If |version| + // is a version where crypto data doesn't go over stream frames, this function + // will always return false. + static bool IsCryptoStreamId(QuicTransportVersion version, QuicStreamId id); + + // Returns headers stream ID of |version|. + static QuicStreamId GetHeadersStreamId(QuicTransportVersion version); + + // Returns true if |id| is considered as client initiated stream ID. + static bool IsClientInitiatedStreamId(QuicTransportVersion version, + QuicStreamId id); + + // Returns true if |id| is considered as server initiated stream ID. + static bool IsServerInitiatedStreamId(QuicTransportVersion version, + QuicStreamId id); + + // Returns true if the stream ID represents a stream initiated by the + // provided perspective. + static bool IsOutgoingStreamId(ParsedQuicVersion version, QuicStreamId id, + Perspective perspective); + + // Returns true if |id| is considered as bidirectional stream ID. Only used in + // v99. + static bool IsBidirectionalStreamId(QuicStreamId id, + ParsedQuicVersion version); + + // Returns stream type. Either |perspective| or |peer_initiated| would be + // enough together with |id|. This method enforces that the three parameters + // are consistent. Only used in v99. + static StreamType GetStreamType(QuicStreamId id, Perspective perspective, + bool peer_initiated, + ParsedQuicVersion version); + + // Returns the delta between consecutive stream IDs of the same type. + static QuicStreamId StreamIdDelta(QuicTransportVersion version); + + // Returns the first initiated bidirectional stream ID of |perspective|. + static QuicStreamId GetFirstBidirectionalStreamId( + QuicTransportVersion version, Perspective perspective); + + // Returns the first initiated unidirectional stream ID of |perspective|. + static QuicStreamId GetFirstUnidirectionalStreamId( + QuicTransportVersion version, Perspective perspective); + + // Returns the largest possible client initiated bidirectional stream ID. + static QuicStreamId GetMaxClientInitiatedBidirectionalStreamId( + QuicTransportVersion version); + + // Generates a random 64bit connection ID. + static QuicConnectionId CreateRandomConnectionId(); + + // Generates a random 64bit connection ID using the provided QuicRandom. + static QuicConnectionId CreateRandomConnectionId(QuicRandom* random); + + // Generates a random connection ID of the given length. + static QuicConnectionId CreateRandomConnectionId( + uint8_t connection_id_length); + + // Generates a random connection ID of the given length using the provided + // QuicRandom. + static QuicConnectionId CreateRandomConnectionId(uint8_t connection_id_length, + QuicRandom* random); + + // Returns true if the connection ID length is valid for this QUIC version. + static bool IsConnectionIdLengthValidForVersion( + size_t connection_id_length, QuicTransportVersion transport_version); + + // Returns true if the connection ID is valid for this QUIC version. + static bool IsConnectionIdValidForVersion( + QuicConnectionId connection_id, QuicTransportVersion transport_version); + + // Returns a connection ID suitable for QUIC use-cases that do not need the + // connection ID for multiplexing. If the version allows variable lengths, + // a connection of length zero is returned, otherwise 64bits set to zero. + static QuicConnectionId CreateZeroConnectionId(QuicTransportVersion version); + + // Generates a 128bit stateless reset token based on a connection ID. + static StatelessResetToken GenerateStatelessResetToken( + QuicConnectionId connection_id); + + // Determines packet number space from |encryption_level|. + static PacketNumberSpace GetPacketNumberSpace( + EncryptionLevel encryption_level); + + // Determines encryption level to send ACK in |packet_number_space|. + static EncryptionLevel GetEncryptionLevelToSendAckofSpace( + PacketNumberSpace packet_number_space); + + // Get the maximum value for a V99/IETF QUIC stream count. If a count + // exceeds this value, it will result in a stream ID that exceeds the + // implementation limit on stream ID size. + static QuicStreamCount GetMaxStreamCount(); + + // Return true if this frame is an IETF probing frame. + static bool IsProbingFrame(QuicFrameType type); + + // Return true if the two stateless reset tokens are equal. Performs the + // comparison in constant time. + static bool AreStatelessResetTokensEqual(const StatelessResetToken& token1, + const StatelessResetToken& token2); + + // Return ture if this frame is an ack-eliciting frame. + static bool IsAckElicitingFrame(QuicFrameType type); +}; + +// Returns true if the specific ID is a valid WebTransport session ID that our +// implementation can process. +bool IsValidWebTransportSessionId(WebTransportSessionId id, + ParsedQuicVersion transport_version); + +QuicByteCount MemSliceSpanTotalSize(absl::Span span); + +// Computes a SHA-256 hash and returns the raw bytes of the hash. +QUIC_EXPORT_PRIVATE std::string RawSha256(absl::string_view input); + +template +class QUIC_EXPORT_PRIVATE BitMask { + public: + // explicit to prevent (incorrect) usage like "BitMask bitmask = 0;". + template + explicit BitMask(Bits... bits) { + mask_ = MakeMask(bits...); + } + + BitMask() = default; + BitMask(const BitMask& other) = default; + BitMask& operator=(const BitMask& other) = default; + + template + void Set(Bits... bits) { + mask_ |= MakeMask(bits...); + } + + template + bool IsSet(Bit bit) const { + return (MakeMask(bit) & mask_) != 0; + } + + void ClearAll() { mask_ = 0; } + + static constexpr size_t NumBits() { return 8 * sizeof(Mask); } + + friend bool operator==(const BitMask& lhs, const BitMask& rhs) { + return lhs.mask_ == rhs.mask_; + } + + std::string DebugString() const { + std::ostringstream oss; + oss << "0x" << std::hex << mask_; + return oss.str(); + } + + private: + template + static std::enable_if_t::value, Mask> MakeMask(Bit bit) { + using IntType = typename std::underlying_type::type; + return Mask(1) << static_cast(bit); + } + + template + static std::enable_if_t::value, Mask> MakeMask(Bit bit) { + return Mask(1) << bit; + } + + template + static Mask MakeMask(Bit first_bit, Bits... other_bits) { + return MakeMask(first_bit) | MakeMask(other_bits...); + } + + Mask mask_ = 0; +}; + +using BitMask64 = BitMask; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_UTILS_H_ diff --git a/quiche/quic/core/quic_utils_test.cc b/quiche/quic/core/quic_utils_test.cc new file mode 100644 index 000000000000..543c8ead0b74 --- /dev/null +++ b/quiche/quic/core/quic_utils_test.cc @@ -0,0 +1,320 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_utils.h" + +#include + +#include "absl/base/macros.h" +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +class QuicUtilsTest : public QuicTest {}; + +TEST_F(QuicUtilsTest, DetermineAddressChangeType) { + const std::string kIPv4String1 = "1.2.3.4"; + const std::string kIPv4String2 = "1.2.3.5"; + const std::string kIPv4String3 = "1.1.3.5"; + const std::string kIPv6String1 = "2001:700:300:1800::f"; + const std::string kIPv6String2 = "2001:700:300:1800:1:1:1:f"; + QuicSocketAddress old_address; + QuicSocketAddress new_address; + QuicIpAddress address; + + EXPECT_EQ(NO_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + ASSERT_TRUE(address.FromString(kIPv4String1)); + old_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(NO_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + new_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(NO_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + + new_address = QuicSocketAddress(address, 5678); + EXPECT_EQ(PORT_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + ASSERT_TRUE(address.FromString(kIPv6String1)); + old_address = QuicSocketAddress(address, 1234); + new_address = QuicSocketAddress(address, 5678); + EXPECT_EQ(PORT_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + + ASSERT_TRUE(address.FromString(kIPv4String1)); + old_address = QuicSocketAddress(address, 1234); + ASSERT_TRUE(address.FromString(kIPv6String1)); + new_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(IPV4_TO_IPV6_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + + old_address = QuicSocketAddress(address, 1234); + ASSERT_TRUE(address.FromString(kIPv4String1)); + new_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(IPV6_TO_IPV4_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + + ASSERT_TRUE(address.FromString(kIPv6String2)); + new_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(IPV6_TO_IPV6_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + + ASSERT_TRUE(address.FromString(kIPv4String1)); + old_address = QuicSocketAddress(address, 1234); + ASSERT_TRUE(address.FromString(kIPv4String2)); + new_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(IPV4_SUBNET_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); + ASSERT_TRUE(address.FromString(kIPv4String3)); + new_address = QuicSocketAddress(address, 1234); + EXPECT_EQ(IPV4_TO_IPV4_CHANGE, + QuicUtils::DetermineAddressChangeType(old_address, new_address)); +} + +absl::uint128 IncrementalHashReference(const void* data, size_t len) { + // The two constants are defined as part of the hash algorithm. + // see http://www.isthe.com/chongo/tech/comp/fnv/ + // hash = 144066263297769815596495629667062367629 + absl::uint128 hash = absl::MakeUint128(UINT64_C(7809847782465536322), + UINT64_C(7113472399480571277)); + // kPrime = 309485009821345068724781371 + const absl::uint128 kPrime = absl::MakeUint128(16777216, 315); + const uint8_t* octets = reinterpret_cast(data); + for (size_t i = 0; i < len; ++i) { + hash = hash ^ absl::MakeUint128(0, octets[i]); + hash = hash * kPrime; + } + return hash; +} + +TEST_F(QuicUtilsTest, ReferenceTest) { + std::vector data(32); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = i % 255; + } + EXPECT_EQ(IncrementalHashReference(data.data(), data.size()), + QuicUtils::FNV1a_128_Hash(absl::string_view( + reinterpret_cast(data.data()), data.size()))); +} + +TEST_F(QuicUtilsTest, IsUnackable) { + for (size_t i = FIRST_PACKET_STATE; i <= LAST_PACKET_STATE; ++i) { + if (i == NEVER_SENT || i == ACKED || i == UNACKABLE) { + EXPECT_FALSE(QuicUtils::IsAckable(static_cast(i))); + } else { + EXPECT_TRUE(QuicUtils::IsAckable(static_cast(i))); + } + } +} + +TEST_F(QuicUtilsTest, RetransmissionTypeToPacketState) { + for (size_t i = FIRST_TRANSMISSION_TYPE; i <= LAST_TRANSMISSION_TYPE; ++i) { + if (i == NOT_RETRANSMISSION) { + continue; + } + SentPacketState state = QuicUtils::RetransmissionTypeToPacketState( + static_cast(i)); + if (i == HANDSHAKE_RETRANSMISSION) { + EXPECT_EQ(HANDSHAKE_RETRANSMITTED, state); + } else if (i == LOSS_RETRANSMISSION) { + EXPECT_EQ(LOST, state); + } else if (i == ALL_ZERO_RTT_RETRANSMISSION) { + EXPECT_EQ(UNACKABLE, state); + } else if (i == PTO_RETRANSMISSION) { + EXPECT_EQ(PTO_RETRANSMITTED, state); + } else if (i == PATH_RETRANSMISSION) { + EXPECT_EQ(NOT_CONTRIBUTING_RTT, state); + } else if (i == ALL_INITIAL_RETRANSMISSION) { + EXPECT_EQ(UNACKABLE, state); + } else { + QUICHE_DCHECK(false) + << "No corresponding packet state according to transmission type: " + << i; + } + } +} + +TEST_F(QuicUtilsTest, IsIetfPacketHeader) { + // IETF QUIC short header + uint8_t first_byte = 0; + EXPECT_TRUE(QuicUtils::IsIetfPacketHeader(first_byte)); + EXPECT_TRUE(QuicUtils::IsIetfPacketShortHeader(first_byte)); + + // IETF QUIC long header + first_byte |= (FLAGS_LONG_HEADER | FLAGS_DEMULTIPLEXING_BIT); + EXPECT_TRUE(QuicUtils::IsIetfPacketHeader(first_byte)); + EXPECT_FALSE(QuicUtils::IsIetfPacketShortHeader(first_byte)); + + // IETF QUIC long header, version negotiation. + first_byte = 0; + first_byte |= FLAGS_LONG_HEADER; + EXPECT_TRUE(QuicUtils::IsIetfPacketHeader(first_byte)); + EXPECT_FALSE(QuicUtils::IsIetfPacketShortHeader(first_byte)); + + // GQUIC + first_byte = 0; + first_byte |= PACKET_PUBLIC_FLAGS_8BYTE_CONNECTION_ID; + EXPECT_FALSE(QuicUtils::IsIetfPacketHeader(first_byte)); + EXPECT_FALSE(QuicUtils::IsIetfPacketShortHeader(first_byte)); +} + +TEST_F(QuicUtilsTest, RandomConnectionId) { + MockRandom random(33); + QuicConnectionId connection_id = QuicUtils::CreateRandomConnectionId(&random); + EXPECT_EQ(connection_id.length(), sizeof(uint64_t)); + char connection_id_bytes[sizeof(uint64_t)]; + random.RandBytes(connection_id_bytes, ABSL_ARRAYSIZE(connection_id_bytes)); + EXPECT_EQ(connection_id, + QuicConnectionId(static_cast(connection_id_bytes), + ABSL_ARRAYSIZE(connection_id_bytes))); + EXPECT_NE(connection_id, EmptyQuicConnectionId()); + EXPECT_NE(connection_id, TestConnectionId()); + EXPECT_NE(connection_id, TestConnectionId(1)); + EXPECT_NE(connection_id, TestConnectionIdNineBytesLong(1)); + EXPECT_EQ(QuicUtils::CreateRandomConnectionId().length(), + kQuicDefaultConnectionIdLength); +} + +TEST_F(QuicUtilsTest, RandomConnectionIdVariableLength) { + MockRandom random(1337); + const uint8_t connection_id_length = 9; + QuicConnectionId connection_id = + QuicUtils::CreateRandomConnectionId(connection_id_length, &random); + EXPECT_EQ(connection_id.length(), connection_id_length); + char connection_id_bytes[connection_id_length]; + random.RandBytes(connection_id_bytes, ABSL_ARRAYSIZE(connection_id_bytes)); + EXPECT_EQ(connection_id, + QuicConnectionId(static_cast(connection_id_bytes), + ABSL_ARRAYSIZE(connection_id_bytes))); + EXPECT_NE(connection_id, EmptyQuicConnectionId()); + EXPECT_NE(connection_id, TestConnectionId()); + EXPECT_NE(connection_id, TestConnectionId(1)); + EXPECT_NE(connection_id, TestConnectionIdNineBytesLong(1)); + EXPECT_EQ(QuicUtils::CreateRandomConnectionId(connection_id_length).length(), + connection_id_length); +} + +TEST_F(QuicUtilsTest, VariableLengthConnectionId) { + EXPECT_FALSE(VersionAllowsVariableLengthConnectionIds(QUIC_VERSION_43)); + EXPECT_TRUE(QuicUtils::IsConnectionIdValidForVersion( + QuicUtils::CreateZeroConnectionId(QUIC_VERSION_43), QUIC_VERSION_43)); + EXPECT_TRUE(QuicUtils::IsConnectionIdValidForVersion( + QuicUtils::CreateZeroConnectionId(QUIC_VERSION_50), QUIC_VERSION_50)); + EXPECT_NE(QuicUtils::CreateZeroConnectionId(QUIC_VERSION_43), + EmptyQuicConnectionId()); + EXPECT_EQ(QuicUtils::CreateZeroConnectionId(QUIC_VERSION_50), + EmptyQuicConnectionId()); + EXPECT_FALSE(QuicUtils::IsConnectionIdValidForVersion(EmptyQuicConnectionId(), + QUIC_VERSION_43)); +} + +TEST_F(QuicUtilsTest, StatelessResetToken) { + QuicConnectionId connection_id1a = test::TestConnectionId(1); + QuicConnectionId connection_id1b = test::TestConnectionId(1); + QuicConnectionId connection_id2 = test::TestConnectionId(2); + StatelessResetToken token1a = + QuicUtils::GenerateStatelessResetToken(connection_id1a); + StatelessResetToken token1b = + QuicUtils::GenerateStatelessResetToken(connection_id1b); + StatelessResetToken token2 = + QuicUtils::GenerateStatelessResetToken(connection_id2); + EXPECT_EQ(token1a, token1b); + EXPECT_NE(token1a, token2); + EXPECT_TRUE(QuicUtils::AreStatelessResetTokensEqual(token1a, token1b)); + EXPECT_FALSE(QuicUtils::AreStatelessResetTokensEqual(token1a, token2)); +} + +TEST_F(QuicUtilsTest, EcnCodepointToString) { + EXPECT_EQ(EcnCodepointToString(ECN_NOT_ECT), "Not-ECT"); + EXPECT_EQ(EcnCodepointToString(ECN_ECT0), "ECT(0)"); + EXPECT_EQ(EcnCodepointToString(ECN_ECT1), "ECT(1)"); + EXPECT_EQ(EcnCodepointToString(ECN_CE), "CE"); +} + +enum class TestEnumClassBit : uint8_t { + BIT_ZERO = 0, + BIT_ONE, + BIT_TWO, +}; + +enum TestEnumBit { + TEST_BIT_0 = 0, + TEST_BIT_1, + TEST_BIT_2, +}; + +TEST(QuicBitMaskTest, EnumClass) { + BitMask64 mask(TestEnumClassBit::BIT_ZERO, TestEnumClassBit::BIT_TWO); + EXPECT_TRUE(mask.IsSet(TestEnumClassBit::BIT_ZERO)); + EXPECT_FALSE(mask.IsSet(TestEnumClassBit::BIT_ONE)); + EXPECT_TRUE(mask.IsSet(TestEnumClassBit::BIT_TWO)); + + mask.ClearAll(); + EXPECT_FALSE(mask.IsSet(TestEnumClassBit::BIT_ZERO)); + EXPECT_FALSE(mask.IsSet(TestEnumClassBit::BIT_ONE)); + EXPECT_FALSE(mask.IsSet(TestEnumClassBit::BIT_TWO)); +} + +TEST(QuicBitMaskTest, Enum) { + BitMask64 mask(TEST_BIT_1, TEST_BIT_2); + EXPECT_FALSE(mask.IsSet(TEST_BIT_0)); + EXPECT_TRUE(mask.IsSet(TEST_BIT_1)); + EXPECT_TRUE(mask.IsSet(TEST_BIT_2)); + + mask.ClearAll(); + EXPECT_FALSE(mask.IsSet(TEST_BIT_0)); + EXPECT_FALSE(mask.IsSet(TEST_BIT_1)); + EXPECT_FALSE(mask.IsSet(TEST_BIT_2)); +} + +TEST(QuicBitMaskTest, Integer) { + BitMask64 mask(1, 3); + mask.Set(3); + mask.Set(5, 7, 9); + EXPECT_FALSE(mask.IsSet(0)); + EXPECT_TRUE(mask.IsSet(1)); + EXPECT_FALSE(mask.IsSet(2)); + EXPECT_TRUE(mask.IsSet(3)); + EXPECT_FALSE(mask.IsSet(4)); + EXPECT_TRUE(mask.IsSet(5)); + EXPECT_FALSE(mask.IsSet(6)); + EXPECT_TRUE(mask.IsSet(7)); + EXPECT_FALSE(mask.IsSet(8)); + EXPECT_TRUE(mask.IsSet(9)); +} + +TEST(QuicBitMaskTest, NumBits) { + EXPECT_EQ(64u, BitMask64::NumBits()); + EXPECT_EQ(32u, BitMask::NumBits()); +} + +TEST(QuicBitMaskTest, Constructor) { + BitMask64 empty_mask; + for (size_t bit = 0; bit < empty_mask.NumBits(); ++bit) { + EXPECT_FALSE(empty_mask.IsSet(bit)); + } + + BitMask64 mask(1, 3); + BitMask64 mask2 = mask; + BitMask64 mask3(mask2); + + for (size_t bit = 0; bit < mask.NumBits(); ++bit) { + EXPECT_EQ(mask.IsSet(bit), mask2.IsSet(bit)); + EXPECT_EQ(mask.IsSet(bit), mask3.IsSet(bit)); + } + + EXPECT_TRUE(std::is_trivially_copyable::value); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_version_manager.cc b/quiche/quic/core/quic_version_manager.cc new file mode 100644 index 000000000000..ec2b9ef0089b --- /dev/null +++ b/quiche/quic/core/quic_version_manager.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_version_manager.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +QuicVersionManager::QuicVersionManager( + ParsedQuicVersionVector supported_versions) + : allowed_supported_versions_(std::move(supported_versions)) {} + +QuicVersionManager::~QuicVersionManager() {} + +const ParsedQuicVersionVector& QuicVersionManager::GetSupportedVersions() { + MaybeRefilterSupportedVersions(); + return filtered_supported_versions_; +} + +const ParsedQuicVersionVector& +QuicVersionManager::GetSupportedVersionsWithOnlyHttp3() { + MaybeRefilterSupportedVersions(); + return filtered_supported_versions_with_http3_; +} + +const std::vector& QuicVersionManager::GetSupportedAlpns() { + MaybeRefilterSupportedVersions(); + return filtered_supported_alpns_; +} + +void QuicVersionManager::MaybeRefilterSupportedVersions() { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + if (enable_version_2_draft_08_ != + GetQuicReloadableFlag(quic_enable_version_2_draft_08) || + disable_version_rfcv1_ != + GetQuicReloadableFlag(quic_disable_version_rfcv1) || + disable_version_draft_29_ != + GetQuicReloadableFlag(quic_disable_version_draft_29) || + disable_version_q050_ != + GetQuicReloadableFlag(quic_disable_version_q050) || + disable_version_q046_ != + GetQuicReloadableFlag(quic_disable_version_q046) || + disable_version_q043_ != + GetQuicReloadableFlag(quic_disable_version_q043)) { + enable_version_2_draft_08_ = + GetQuicReloadableFlag(quic_enable_version_2_draft_08); + disable_version_rfcv1_ = GetQuicReloadableFlag(quic_disable_version_rfcv1); + disable_version_draft_29_ = + GetQuicReloadableFlag(quic_disable_version_draft_29); + disable_version_q050_ = GetQuicReloadableFlag(quic_disable_version_q050); + disable_version_q046_ = GetQuicReloadableFlag(quic_disable_version_q046); + disable_version_q043_ = GetQuicReloadableFlag(quic_disable_version_q043); + + RefilterSupportedVersions(); + } +} + +void QuicVersionManager::RefilterSupportedVersions() { + filtered_supported_versions_ = + FilterSupportedVersions(allowed_supported_versions_); + filtered_supported_versions_with_http3_.clear(); + filtered_transport_versions_.clear(); + filtered_supported_alpns_.clear(); + for (const ParsedQuicVersion& version : filtered_supported_versions_) { + auto transport_version = version.transport_version; + if (std::find(filtered_transport_versions_.begin(), + filtered_transport_versions_.end(), + transport_version) == filtered_transport_versions_.end()) { + filtered_transport_versions_.push_back(transport_version); + } + if (version.UsesHttp3()) { + filtered_supported_versions_with_http3_.push_back(version); + } + if (std::find(filtered_supported_alpns_.begin(), + filtered_supported_alpns_.end(), + AlpnForVersion(version)) == filtered_supported_alpns_.end()) { + filtered_supported_alpns_.emplace_back(AlpnForVersion(version)); + } + } +} + +void QuicVersionManager::AddCustomAlpn(const std::string& alpn) { + filtered_supported_alpns_.push_back(alpn); +} + +} // namespace quic diff --git a/quiche/quic/core/quic_version_manager.h b/quiche/quic/core/quic_version_manager.h new file mode 100644 index 000000000000..c664fc4c8a6a --- /dev/null +++ b/quiche/quic/core/quic_version_manager.h @@ -0,0 +1,95 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_VERSION_MANAGER_H_ +#define QUICHE_QUIC_CORE_QUIC_VERSION_MANAGER_H_ + +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Used to generate filtered supported versions based on flags. +class QUIC_EXPORT_PRIVATE QuicVersionManager { + public: + // |supported_versions| should be sorted in the order of preference (typically + // highest supported version to the lowest supported version). + explicit QuicVersionManager(ParsedQuicVersionVector supported_versions); + virtual ~QuicVersionManager(); + + // Returns currently supported QUIC versions. This vector has the same order + // as the versions passed to the constructor. + const ParsedQuicVersionVector& GetSupportedVersions(); + + // Returns currently supported versions using HTTP/3. + const ParsedQuicVersionVector& GetSupportedVersionsWithOnlyHttp3(); + + // Returns the list of supported ALPNs, based on the current supported + // versions and any custom additions by subclasses. + const std::vector& GetSupportedAlpns(); + + protected: + // If the value of any reloadable flag is different from the cached value, + // re-filter |filtered_supported_versions_| and update the cached flag values. + // Otherwise, does nothing. + // TODO(dschinazi): Make private when deprecating + // FLAGS_gfe2_restart_flag_quic_disable_old_alt_svc_format. + void MaybeRefilterSupportedVersions(); + + // Refilters filtered_supported_versions_. + virtual void RefilterSupportedVersions(); + + // RefilterSupportedVersions() must be called before calling this method. + // TODO(dschinazi): Remove when deprecating + // FLAGS_gfe2_restart_flag_quic_disable_old_alt_svc_format. + const QuicTransportVersionVector& filtered_transport_versions() const { + return filtered_transport_versions_; + } + + // Subclasses may add custom ALPNs to the supported list by overriding + // RefilterSupportedVersions() to first call + // QuicVersionManager::RefilterSupportedVersions() then AddCustomAlpn(). + // Must not be called elsewhere. + void AddCustomAlpn(const std::string& alpn); + + private: + // Cached value of reloadable flags. + // quic_enable_version_2_draft_08 flag + bool enable_version_2_draft_08_ = false; + // quic_disable_version_rfcv1 flag + bool disable_version_rfcv1_ = true; + // quic_disable_version_draft_29 flag + bool disable_version_draft_29_ = true; + // quic_disable_version_q050 flag + bool disable_version_q050_ = true; + // quic_disable_version_q046 flag + bool disable_version_q046_ = true; + // quic_disable_version_q043 flag + bool disable_version_q043_ = true; + + // The list of versions that may be supported. + const ParsedQuicVersionVector allowed_supported_versions_; + + // The following vectors are calculated from reloadable flags by + // RefilterSupportedVersions(). It is performed lazily when first needed, and + // after that, since the calculation is relatively expensive, only if the flag + // values change. + + // This vector contains QUIC versions which are currently supported based on + // flags. + ParsedQuicVersionVector filtered_supported_versions_; + // Currently supported versions using HTTP/3. + ParsedQuicVersionVector filtered_supported_versions_with_http3_; + // This vector contains the transport versions from + // |filtered_supported_versions_|. No guarantees are made that the same + // transport version isn't repeated. + QuicTransportVersionVector filtered_transport_versions_; + // Contains the list of ALPNs corresponding to filtered_supported_versions_ + // with custom ALPNs added. + std::vector filtered_supported_alpns_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_VERSION_MANAGER_H_ diff --git a/quiche/quic/core/quic_version_manager_test.cc b/quiche/quic/core/quic_version_manager_test.cc new file mode 100644 index 000000000000..3c3a00ea2f68 --- /dev/null +++ b/quiche/quic/core/quic_version_manager_test.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_version_manager.h" + +#include "absl/base/macros.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" + +using ::testing::ElementsAre; + +namespace quic { +namespace test { +namespace { + +class QuicVersionManagerTest : public QuicTest {}; + +TEST_F(QuicVersionManagerTest, QuicVersionManager) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + QuicEnableVersion(version); + } + QuicDisableVersion(ParsedQuicVersion::V2Draft08()); + QuicDisableVersion(ParsedQuicVersion::RFCv1()); + QuicDisableVersion(ParsedQuicVersion::Draft29()); + QuicVersionManager manager(AllSupportedVersions()); + + ParsedQuicVersionVector expected_parsed_versions; + expected_parsed_versions.push_back(ParsedQuicVersion::Q050()); + expected_parsed_versions.push_back(ParsedQuicVersion::Q046()); + expected_parsed_versions.push_back(ParsedQuicVersion::Q043()); + + EXPECT_EQ(expected_parsed_versions, manager.GetSupportedVersions()); + + EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), + manager.GetSupportedVersions()); + EXPECT_TRUE(manager.GetSupportedVersionsWithOnlyHttp3().empty()); + EXPECT_THAT(manager.GetSupportedAlpns(), + ElementsAre("h3-Q050", "h3-Q046", "h3-Q043")); + + QuicEnableVersion(ParsedQuicVersion::Draft29()); + expected_parsed_versions.insert(expected_parsed_versions.begin(), + ParsedQuicVersion::Draft29()); + EXPECT_EQ(expected_parsed_versions, manager.GetSupportedVersions()); + EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), + manager.GetSupportedVersions()); + EXPECT_EQ(1u, manager.GetSupportedVersionsWithOnlyHttp3().size()); + EXPECT_EQ(CurrentSupportedHttp3Versions(), + manager.GetSupportedVersionsWithOnlyHttp3()); + EXPECT_THAT(manager.GetSupportedAlpns(), + ElementsAre("h3-29", "h3-Q050", "h3-Q046", "h3-Q043")); + + QuicEnableVersion(ParsedQuicVersion::RFCv1()); + expected_parsed_versions.insert(expected_parsed_versions.begin(), + ParsedQuicVersion::RFCv1()); + EXPECT_EQ(expected_parsed_versions, manager.GetSupportedVersions()); + EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), + manager.GetSupportedVersions()); + EXPECT_EQ(2u, manager.GetSupportedVersionsWithOnlyHttp3().size()); + EXPECT_EQ(CurrentSupportedHttp3Versions(), + manager.GetSupportedVersionsWithOnlyHttp3()); + EXPECT_THAT(manager.GetSupportedAlpns(), + ElementsAre("h3", "h3-29", "h3-Q050", "h3-Q046", "h3-Q043")); + + QuicEnableVersion(ParsedQuicVersion::V2Draft08()); + expected_parsed_versions.insert(expected_parsed_versions.begin(), + ParsedQuicVersion::V2Draft08()); + EXPECT_EQ(expected_parsed_versions, manager.GetSupportedVersions()); + EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), + manager.GetSupportedVersions()); + EXPECT_EQ(3u, manager.GetSupportedVersionsWithOnlyHttp3().size()); + EXPECT_EQ(CurrentSupportedHttp3Versions(), + manager.GetSupportedVersionsWithOnlyHttp3()); + EXPECT_THAT(manager.GetSupportedAlpns(), + ElementsAre("h3", "h3-29", "h3-Q050", "h3-Q046", "h3-Q043")); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_versions.cc b/quiche/quic/core/quic_versions.cc new file mode 100644 index 000000000000..1f03b704e15c --- /dev/null +++ b/quiche/quic/core/quic_versions.cc @@ -0,0 +1,664 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_versions.h" + +#include + +#include "absl/base/macros.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_tag.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { + +QuicVersionLabel CreateRandomVersionLabelForNegotiation() { + QuicVersionLabel result; + if (!GetQuicFlag(quic_disable_version_negotiation_grease_randomness)) { + QuicRandom::GetInstance()->RandBytes(&result, sizeof(result)); + } else { + result = MakeVersionLabel(0xd1, 0x57, 0x38, 0x3f); + } + result &= 0xf0f0f0f0; + result |= 0x0a0a0a0a; + return result; +} + +void SetVersionFlag(const ParsedQuicVersion& version, bool should_enable) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + const bool enable = should_enable; + const bool disable = !should_enable; + if (version == ParsedQuicVersion::V2Draft08()) { + SetQuicReloadableFlag(quic_enable_version_2_draft_08, enable); + } else if (version == ParsedQuicVersion::RFCv1()) { + SetQuicReloadableFlag(quic_disable_version_rfcv1, disable); + } else if (version == ParsedQuicVersion::Draft29()) { + SetQuicReloadableFlag(quic_disable_version_draft_29, disable); + } else if (version == ParsedQuicVersion::Q050()) { + SetQuicReloadableFlag(quic_disable_version_q050, disable); + } else if (version == ParsedQuicVersion::Q046()) { + SetQuicReloadableFlag(quic_disable_version_q046, disable); + } else if (version == ParsedQuicVersion::Q043()) { + SetQuicReloadableFlag(quic_disable_version_q043, disable); + } else { + QUIC_BUG(quic_bug_10589_1) + << "Cannot " << (enable ? "en" : "dis") << "able version " << version; + } +} + +} // namespace + +bool ParsedQuicVersion::IsKnown() const { + QUICHE_DCHECK(ParsedQuicVersionIsValid(handshake_protocol, transport_version)) + << QuicVersionToString(transport_version) << " " + << HandshakeProtocolToString(handshake_protocol); + return transport_version != QUIC_VERSION_UNSUPPORTED; +} + +bool ParsedQuicVersion::KnowsWhichDecrypterToUse() const { + QUICHE_DCHECK(IsKnown()); + return transport_version > QUIC_VERSION_46; +} + +bool ParsedQuicVersion::UsesInitialObfuscators() const { + QUICHE_DCHECK(IsKnown()); + // Initial obfuscators were added in version 50. + return transport_version > QUIC_VERSION_46; +} + +bool ParsedQuicVersion::AllowsLowFlowControlLimits() const { + QUICHE_DCHECK(IsKnown()); + // Low flow-control limits are used for all IETF versions. + return UsesHttp3(); +} + +bool ParsedQuicVersion::HasHeaderProtection() const { + QUICHE_DCHECK(IsKnown()); + // Header protection was added in version 50. + return transport_version > QUIC_VERSION_46; +} + +bool ParsedQuicVersion::SupportsRetry() const { + QUICHE_DCHECK(IsKnown()); + // Retry was added in version 47. + return transport_version > QUIC_VERSION_46; +} + +bool ParsedQuicVersion::SendsVariableLengthPacketNumberInLongHeader() const { + QUICHE_DCHECK(IsKnown()); + return transport_version > QUIC_VERSION_46; +} + +bool ParsedQuicVersion::AllowsVariableLengthConnectionIds() const { + QUICHE_DCHECK(IsKnown()); + return VersionAllowsVariableLengthConnectionIds(transport_version); +} + +bool ParsedQuicVersion::SupportsClientConnectionIds() const { + QUICHE_DCHECK(IsKnown()); + // Client connection IDs were added in version 49. + return transport_version > QUIC_VERSION_46; +} + +bool ParsedQuicVersion::HasLengthPrefixedConnectionIds() const { + QUICHE_DCHECK(IsKnown()); + return VersionHasLengthPrefixedConnectionIds(transport_version); +} + +bool ParsedQuicVersion::SupportsAntiAmplificationLimit() const { + QUICHE_DCHECK(IsKnown()); + // The anti-amplification limit is used for all IETF versions. + return UsesHttp3(); +} + +bool ParsedQuicVersion::CanSendCoalescedPackets() const { + QUICHE_DCHECK(IsKnown()); + return HasLongHeaderLengths() && UsesTls(); +} + +bool ParsedQuicVersion::SupportsGoogleAltSvcFormat() const { + QUICHE_DCHECK(IsKnown()); + return VersionSupportsGoogleAltSvcFormat(transport_version); +} + +bool ParsedQuicVersion::HasIetfInvariantHeader() const { + QUICHE_DCHECK(IsKnown()); + return VersionHasIetfInvariantHeader(transport_version); +} + +bool ParsedQuicVersion::SupportsMessageFrames() const { + QUICHE_DCHECK(IsKnown()); + return VersionSupportsMessageFrames(transport_version); +} + +bool ParsedQuicVersion::UsesHttp3() const { + QUICHE_DCHECK(IsKnown()); + return VersionUsesHttp3(transport_version); +} + +bool ParsedQuicVersion::HasLongHeaderLengths() const { + QUICHE_DCHECK(IsKnown()); + return QuicVersionHasLongHeaderLengths(transport_version); +} + +bool ParsedQuicVersion::UsesCryptoFrames() const { + QUICHE_DCHECK(IsKnown()); + return QuicVersionUsesCryptoFrames(transport_version); +} + +bool ParsedQuicVersion::HasIetfQuicFrames() const { + QUICHE_DCHECK(IsKnown()); + return VersionHasIetfQuicFrames(transport_version); +} + +bool ParsedQuicVersion::UsesLegacyTlsExtension() const { + QUICHE_DCHECK(IsKnown()); + return UsesTls() && transport_version <= QUIC_VERSION_IETF_DRAFT_29; +} + +bool ParsedQuicVersion::UsesTls() const { + QUICHE_DCHECK(IsKnown()); + return handshake_protocol == PROTOCOL_TLS1_3; +} + +bool ParsedQuicVersion::UsesQuicCrypto() const { + QUICHE_DCHECK(IsKnown()); + return handshake_protocol == PROTOCOL_QUIC_CRYPTO; +} + +bool ParsedQuicVersion::UsesV2PacketTypes() const { + QUICHE_DCHECK(IsKnown()); + return transport_version == QUIC_VERSION_IETF_2_DRAFT_08; +} + +bool ParsedQuicVersion::AlpnDeferToRFCv1() const { + QUICHE_DCHECK(IsKnown()); + return transport_version == QUIC_VERSION_IETF_2_DRAFT_08; +} + +bool VersionHasLengthPrefixedConnectionIds( + QuicTransportVersion transport_version) { + QUICHE_DCHECK(transport_version != QUIC_VERSION_UNSUPPORTED); + // Length-prefixed connection IDs were added in version 49. + return transport_version > QUIC_VERSION_46; +} + +std::ostream& operator<<(std::ostream& os, const ParsedQuicVersion& version) { + os << ParsedQuicVersionToString(version); + return os; +} + +std::ostream& operator<<(std::ostream& os, + const ParsedQuicVersionVector& versions) { + os << ParsedQuicVersionVectorToString(versions); + return os; +} + +QuicVersionLabel MakeVersionLabel(uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return MakeQuicTag(d, c, b, a); +} + +std::ostream& operator<<(std::ostream& os, + const QuicVersionLabelVector& version_labels) { + os << QuicVersionLabelVectorToString(version_labels); + return os; +} + +std::ostream& operator<<(std::ostream& os, + const QuicTransportVersionVector& transport_versions) { + os << QuicTransportVersionVectorToString(transport_versions); + return os; +} + +QuicVersionLabel CreateQuicVersionLabel(ParsedQuicVersion parsed_version) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + if (parsed_version == ParsedQuicVersion::V2Draft08()) { + return MakeVersionLabel(0x6b, 0x33, 0x43, 0xcf); + } else if (parsed_version == ParsedQuicVersion::RFCv1()) { + return MakeVersionLabel(0x00, 0x00, 0x00, 0x01); + } else if (parsed_version == ParsedQuicVersion::Draft29()) { + return MakeVersionLabel(0xff, 0x00, 0x00, 29); + } else if (parsed_version == ParsedQuicVersion::Q050()) { + return MakeVersionLabel('Q', '0', '5', '0'); + } else if (parsed_version == ParsedQuicVersion::Q046()) { + return MakeVersionLabel('Q', '0', '4', '6'); + } else if (parsed_version == ParsedQuicVersion::Q043()) { + return MakeVersionLabel('Q', '0', '4', '3'); + } else if (parsed_version == ParsedQuicVersion::ReservedForNegotiation()) { + return CreateRandomVersionLabelForNegotiation(); + } + QUIC_BUG(quic_bug_10589_2) + << "Unsupported version " + << QuicVersionToString(parsed_version.transport_version) << " " + << HandshakeProtocolToString(parsed_version.handshake_protocol); + return 0; +} + +QuicVersionLabelVector CreateQuicVersionLabelVector( + const ParsedQuicVersionVector& versions) { + QuicVersionLabelVector out; + out.reserve(versions.size()); + for (const auto& version : versions) { + out.push_back(CreateQuicVersionLabel(version)); + } + return out; +} + +ParsedQuicVersionVector AllSupportedVersionsWithQuicCrypto() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + versions.push_back(version); + } + } + QUIC_BUG_IF(quic_bug_10589_3, versions.empty()) + << "No version with QUIC crypto found."; + return versions; +} + +ParsedQuicVersionVector CurrentSupportedVersionsWithQuicCrypto() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : CurrentSupportedVersions()) { + if (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + versions.push_back(version); + } + } + QUIC_BUG_IF(quic_bug_10589_4, versions.empty()) + << "No version with QUIC crypto found."; + return versions; +} + +ParsedQuicVersionVector AllSupportedVersionsWithTls() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.UsesTls()) { + versions.push_back(version); + } + } + QUIC_BUG_IF(quic_bug_10589_5, versions.empty()) + << "No version with TLS handshake found."; + return versions; +} + +ParsedQuicVersionVector CurrentSupportedVersionsWithTls() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : CurrentSupportedVersions()) { + if (version.UsesTls()) { + versions.push_back(version); + } + } + QUIC_BUG_IF(quic_bug_10589_6, versions.empty()) + << "No version with TLS handshake found."; + return versions; +} + +ParsedQuicVersionVector CurrentSupportedHttp3Versions() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : CurrentSupportedVersions()) { + if (version.UsesHttp3()) { + versions.push_back(version); + } + } + QUIC_BUG_IF(no_version_uses_http3, versions.empty()) + << "No version speaking Http3 found."; + return versions; +} + +ParsedQuicVersion ParseQuicVersionLabel(QuicVersionLabel version_label) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version_label == CreateQuicVersionLabel(version)) { + return version; + } + } + // Reading from the client so this should not be considered an ERROR. + QUIC_DLOG(INFO) << "Unsupported QuicVersionLabel version: " + << QuicVersionLabelToString(version_label); + return UnsupportedQuicVersion(); +} + +ParsedQuicVersionVector ParseQuicVersionLabelVector( + const QuicVersionLabelVector& version_labels) { + ParsedQuicVersionVector parsed_versions; + for (const QuicVersionLabel& version_label : version_labels) { + ParsedQuicVersion parsed_version = ParseQuicVersionLabel(version_label); + if (parsed_version.IsKnown()) { + parsed_versions.push_back(parsed_version); + } + } + return parsed_versions; +} + +ParsedQuicVersion ParseQuicVersionString(absl::string_view version_string) { + if (version_string.empty()) { + return UnsupportedQuicVersion(); + } + const ParsedQuicVersionVector supported_versions = AllSupportedVersions(); + for (const ParsedQuicVersion& version : supported_versions) { + if (version_string == ParsedQuicVersionToString(version) || + (version_string == AlpnForVersion(version) && + !version.AlpnDeferToRFCv1()) || + (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO && + version_string == QuicVersionToString(version.transport_version))) { + return version; + } + } + for (const ParsedQuicVersion& version : supported_versions) { + if (version.UsesHttp3() && + version_string == + QuicVersionLabelToString(CreateQuicVersionLabel(version))) { + return version; + } + } + int quic_version_number = 0; + if (absl::SimpleAtoi(version_string, &quic_version_number) && + quic_version_number > 0) { + QuicTransportVersion transport_version = + static_cast(quic_version_number); + if (!ParsedQuicVersionIsValid(PROTOCOL_QUIC_CRYPTO, transport_version)) { + return UnsupportedQuicVersion(); + } + ParsedQuicVersion version(PROTOCOL_QUIC_CRYPTO, transport_version); + if (std::find(supported_versions.begin(), supported_versions.end(), + version) != supported_versions.end()) { + return version; + } + return UnsupportedQuicVersion(); + } + // Reading from the client so this should not be considered an ERROR. + QUIC_DLOG(INFO) << "Unsupported QUIC version string: \"" << version_string + << "\"."; + return UnsupportedQuicVersion(); +} + +ParsedQuicVersionVector ParseQuicVersionVectorString( + absl::string_view versions_string) { + ParsedQuicVersionVector versions; + std::vector version_strings = + absl::StrSplit(versions_string, ','); + for (absl::string_view version_string : version_strings) { + quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace( + &version_string); + ParsedQuicVersion version = ParseQuicVersionString(version_string); + if (!version.IsKnown() || std::find(versions.begin(), versions.end(), + version) != versions.end()) { + continue; + } + versions.push_back(version); + } + return versions; +} + +QuicTransportVersionVector AllSupportedTransportVersions() { + QuicTransportVersionVector transport_versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (std::find(transport_versions.begin(), transport_versions.end(), + version.transport_version) == transport_versions.end()) { + transport_versions.push_back(version.transport_version); + } + } + return transport_versions; +} + +ParsedQuicVersionVector AllSupportedVersions() { + constexpr auto supported_versions = SupportedVersions(); + return ParsedQuicVersionVector(supported_versions.begin(), + supported_versions.end()); +} + +ParsedQuicVersionVector CurrentSupportedVersions() { + return FilterSupportedVersions(AllSupportedVersions()); +} + +ParsedQuicVersionVector FilterSupportedVersions( + ParsedQuicVersionVector versions) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + ParsedQuicVersionVector filtered_versions; + filtered_versions.reserve(versions.size()); + for (const ParsedQuicVersion& version : versions) { + if (version == ParsedQuicVersion::V2Draft08()) { + if (GetQuicReloadableFlag(quic_enable_version_2_draft_08)) { + filtered_versions.push_back(version); + } + } else if (version == ParsedQuicVersion::RFCv1()) { + if (!GetQuicReloadableFlag(quic_disable_version_rfcv1)) { + filtered_versions.push_back(version); + } + } else if (version == ParsedQuicVersion::Draft29()) { + if (!GetQuicReloadableFlag(quic_disable_version_draft_29)) { + filtered_versions.push_back(version); + } + } else if (version == ParsedQuicVersion::Q050()) { + if (!GetQuicReloadableFlag(quic_disable_version_q050)) { + filtered_versions.push_back(version); + } + } else if (version == ParsedQuicVersion::Q046()) { + if (!GetQuicReloadableFlag(quic_disable_version_q046)) { + filtered_versions.push_back(version); + } + } else if (version == ParsedQuicVersion::Q043()) { + if (!GetQuicReloadableFlag(quic_disable_version_q043)) { + filtered_versions.push_back(version); + } + } else { + QUIC_BUG(quic_bug_10589_7) + << "QUIC version " << version << " has no flag protection"; + filtered_versions.push_back(version); + } + } + return filtered_versions; +} + +ParsedQuicVersionVector ParsedVersionOfIndex( + const ParsedQuicVersionVector& versions, int index) { + ParsedQuicVersionVector version; + int version_count = versions.size(); + if (index >= 0 && index < version_count) { + version.push_back(versions[index]); + } else { + version.push_back(UnsupportedQuicVersion()); + } + return version; +} + +std::string QuicVersionLabelToString(QuicVersionLabel version_label) { + return QuicTagToString(quiche::QuicheEndian::HostToNet32(version_label)); +} + +ParsedQuicVersion ParseQuicVersionLabelString( + absl::string_view version_label_string) { + const ParsedQuicVersionVector supported_versions = AllSupportedVersions(); + for (const ParsedQuicVersion& version : supported_versions) { + if (version_label_string == + QuicVersionLabelToString(CreateQuicVersionLabel(version))) { + return version; + } + } + return UnsupportedQuicVersion(); +} + +std::string QuicVersionLabelVectorToString( + const QuicVersionLabelVector& version_labels, const std::string& separator, + size_t skip_after_nth_version) { + std::string result; + for (size_t i = 0; i < version_labels.size(); ++i) { + if (i != 0) { + result.append(separator); + } + + if (i > skip_after_nth_version) { + result.append("..."); + break; + } + result.append(QuicVersionLabelToString(version_labels[i])); + } + return result; +} + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x + +std::string QuicVersionToString(QuicTransportVersion transport_version) { + switch (transport_version) { + RETURN_STRING_LITERAL(QUIC_VERSION_43); + RETURN_STRING_LITERAL(QUIC_VERSION_46); + RETURN_STRING_LITERAL(QUIC_VERSION_50); + RETURN_STRING_LITERAL(QUIC_VERSION_IETF_DRAFT_29); + RETURN_STRING_LITERAL(QUIC_VERSION_IETF_RFC_V1); + RETURN_STRING_LITERAL(QUIC_VERSION_IETF_2_DRAFT_08); + RETURN_STRING_LITERAL(QUIC_VERSION_UNSUPPORTED); + RETURN_STRING_LITERAL(QUIC_VERSION_RESERVED_FOR_NEGOTIATION); + } + return absl::StrCat("QUIC_VERSION_UNKNOWN(", + static_cast(transport_version), ")"); +} + +std::string HandshakeProtocolToString(HandshakeProtocol handshake_protocol) { + switch (handshake_protocol) { + RETURN_STRING_LITERAL(PROTOCOL_UNSUPPORTED); + RETURN_STRING_LITERAL(PROTOCOL_QUIC_CRYPTO); + RETURN_STRING_LITERAL(PROTOCOL_TLS1_3); + } + return absl::StrCat("PROTOCOL_UNKNOWN(", static_cast(handshake_protocol), + ")"); +} + +std::string ParsedQuicVersionToString(ParsedQuicVersion version) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + if (version == UnsupportedQuicVersion()) { + return "0"; + } else if (version == ParsedQuicVersion::V2Draft08()) { + QUICHE_DCHECK(version.UsesHttp3()); + return "V2Draft08"; + } else if (version == ParsedQuicVersion::RFCv1()) { + QUICHE_DCHECK(version.UsesHttp3()); + return "RFCv1"; + } else if (version == ParsedQuicVersion::Draft29()) { + QUICHE_DCHECK(version.UsesHttp3()); + return "draft29"; + } + + return QuicVersionLabelToString(CreateQuicVersionLabel(version)); +} + +std::string QuicTransportVersionVectorToString( + const QuicTransportVersionVector& versions) { + std::string result = ""; + for (size_t i = 0; i < versions.size(); ++i) { + if (i != 0) { + result.append(","); + } + result.append(QuicVersionToString(versions[i])); + } + return result; +} + +std::string ParsedQuicVersionVectorToString( + const ParsedQuicVersionVector& versions, const std::string& separator, + size_t skip_after_nth_version) { + std::string result; + for (size_t i = 0; i < versions.size(); ++i) { + if (i != 0) { + result.append(separator); + } + if (i > skip_after_nth_version) { + result.append("..."); + break; + } + result.append(ParsedQuicVersionToString(versions[i])); + } + return result; +} + +bool VersionSupportsGoogleAltSvcFormat(QuicTransportVersion transport_version) { + return transport_version <= QUIC_VERSION_46; +} + +bool VersionAllowsVariableLengthConnectionIds( + QuicTransportVersion transport_version) { + QUICHE_DCHECK_NE(transport_version, QUIC_VERSION_UNSUPPORTED); + return transport_version > QUIC_VERSION_46; +} + +bool QuicVersionLabelUses4BitConnectionIdLength( + QuicVersionLabel version_label) { + // As we deprecate old versions, we still need the ability to send valid + // version negotiation packets for those versions. This function keeps track + // of the versions that ever supported the 4bit connection ID length encoding + // that we know about. Google QUIC 43 and earlier used a different encoding, + // and Google QUIC 49 and later use the new length prefixed encoding. + // Similarly, only IETF drafts 11 to 21 used this encoding. + + // Check Q044, Q045, Q046, Q047 and Q048. + for (uint8_t c = '4'; c <= '8'; ++c) { + if (version_label == MakeVersionLabel('Q', '0', '4', c)) { + return true; + } + } + // Check T048. + if (version_label == MakeVersionLabel('T', '0', '4', '8')) { + return true; + } + // Check IETF draft versions in [11,21]. + for (uint8_t draft_number = 11; draft_number <= 21; ++draft_number) { + if (version_label == MakeVersionLabel(0xff, 0x00, 0x00, draft_number)) { + return true; + } + } + return false; +} + +ParsedQuicVersion UnsupportedQuicVersion() { + return ParsedQuicVersion::Unsupported(); +} + +ParsedQuicVersion QuicVersionReservedForNegotiation() { + return ParsedQuicVersion::ReservedForNegotiation(); +} + +std::string AlpnForVersion(ParsedQuicVersion parsed_version) { + if (parsed_version == ParsedQuicVersion::V2Draft08()) { + return "h3"; + } else if (parsed_version == ParsedQuicVersion::RFCv1()) { + return "h3"; + } else if (parsed_version == ParsedQuicVersion::Draft29()) { + return "h3-29"; + } + return "h3-" + ParsedQuicVersionToString(parsed_version); +} + +void QuicVersionInitializeSupportForIetfDraft() { + // Enable necessary flags. + SetQuicRestartFlag(quic_receive_ecn, true); +} + +void QuicEnableVersion(const ParsedQuicVersion& version) { + SetVersionFlag(version, /*should_enable=*/true); +} + +void QuicDisableVersion(const ParsedQuicVersion& version) { + SetVersionFlag(version, /*should_enable=*/false); +} + +bool QuicVersionIsEnabled(const ParsedQuicVersion& version) { + ParsedQuicVersionVector current = CurrentSupportedVersions(); + return std::find(current.begin(), current.end(), version) != current.end(); +} + +#undef RETURN_STRING_LITERAL // undef for jumbo builds +} // namespace quic diff --git a/quiche/quic/core/quic_versions.h b/quiche/quic/core/quic_versions.h new file mode 100644 index 000000000000..a8cfab371a0a --- /dev/null +++ b/quiche/quic/core/quic_versions.h @@ -0,0 +1,649 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Definitions and utility functions related to handling of QUIC versions. +// +// QUIC versions are encoded over the wire as an opaque 32bit field. The wire +// encoding is represented in memory as a QuicVersionLabel type (which is an +// alias to uint32_t). Conceptual versions are represented in memory as +// ParsedQuicVersion. +// +// We currently support two kinds of QUIC versions, GoogleQUIC and IETF QUIC. +// +// All GoogleQUIC versions use a wire encoding that matches the following regex +// when converted to ASCII: "[QT]0\d\d" (e.g. Q050). Q or T distinguishes the +// type of handshake used (Q for the QUIC_CRYPTO handshake, T for the QUIC+TLS +// handshake), and the two digits at the end contain the numeric value of +// the transport version used. +// +// All IETF QUIC versions use the wire encoding described in: +// https://tools.ietf.org/html/draft-ietf-quic-transport + +#ifndef QUICHE_QUIC_CORE_QUIC_VERSIONS_H_ +#define QUICHE_QUIC_CORE_QUIC_VERSIONS_H_ + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_tag.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// The list of existing QUIC transport versions. Note that QUIC versions are +// sent over the wire as an encoding of ParsedQuicVersion, which requires a +// QUIC transport version and handshake protocol. For transport versions of the +// form QUIC_VERSION_XX where XX is decimal, the enum numeric value is +// guaranteed to match the name. Older deprecated transport versions are +// documented in comments below. +enum QuicTransportVersion { + // Special case to indicate unknown/unsupported QUIC version. + QUIC_VERSION_UNSUPPORTED = 0, + + // Version 1 was the first version of QUIC that supported versioning. + // Version 2 decoupled versioning of non-cryptographic parameters from the + // SCFG. + // Version 3 moved public flags into the beginning of the packet. + // Version 4 added support for variable-length connection IDs. + // Version 5 made specifying FEC groups optional. + // Version 6 introduced variable-length packet numbers. + // Version 7 introduced a lower-overhead encoding for stream frames. + // Version 8 made salt length equal to digest length for the RSA-PSS + // signatures. + // Version 9 added stream priority. + // Version 10 redid the frame type numbering. + // Version 11 reduced the length of null encryption authentication tag + // from 16 to 12 bytes. + // Version 12 made the sequence numbers in the ACK frames variable-sized. + // Version 13 added the dedicated header stream. + // Version 14 added byte_offset to RST_STREAM frame. + // Version 15 added a list of packets recovered using FEC to the ACK frame. + // Version 16 added STOP_WAITING frame. + // Version 17 added per-stream flow control. + // Version 18 added PING frame. + // Version 19 added connection-level flow control + // Version 20 allowed to set stream- and connection-level flow control windows + // to different values. + // Version 21 made header and crypto streams flow-controlled. + // Version 22 added support for SCUP (server config update) messages. + // Version 23 added timestamps into the ACK frame. + // Version 24 added SPDY/4 header compression. + // Version 25 added support for SPDY/4 header keys and removed error_details + // from RST_STREAM frame. + // Version 26 added XLCT (expected leaf certificate) tag into CHLO. + // Version 27 added a nonce into SHLO. + // Version 28 allowed receiver to refuse creating a requested stream. + // Version 29 added support for QUIC_STREAM_NO_ERROR. + // Version 30 added server-side support for certificate transparency. + // Version 31 incorporated the hash of CHLO into the crypto proof supplied by + // the server. + // Version 32 removed FEC-related fields from wire format. + // Version 33 added diversification nonces. + // Version 34 removed entropy bits from packets and ACK frames, removed + // private flag from packet header and changed the ACK format to + // specify ranges of packets acknowledged rather than missing + // ranges. + // Version 35 allows endpoints to independently set stream limit. + // Version 36 added support for forced head-of-line blocking experiments. + // Version 37 added perspective into null encryption. + // Version 38 switched to IETF padding frame format and support for NSTP (no + // stop waiting frame) connection option. + + // Version 39 writes integers and floating numbers in big endian, stops acking + // acks, sends a connection level WINDOW_UPDATE every 20 sent packets which do + // not contain retransmittable frames. + + // Version 40 was an attempt to convert QUIC to IETF frame format; it was + // never shipped due to a bug. + // Version 41 was a bugfix for version 40. The working group changed the wire + // format before it shipped, which caused it to be never shipped + // and all the changes from it to be reverted. No changes from v40 + // or v41 are present in subsequent versions. + // Version 42 allowed receiving overlapping stream data. + + QUIC_VERSION_43 = 43, // PRIORITY frames are sent by client and accepted by + // server. + // Version 44 used IETF header format from draft-ietf-quic-invariants-05. + + // Version 45 added MESSAGE frame. + + QUIC_VERSION_46 = 46, // Use IETF draft-17 header format with demultiplexing + // bit. + // Version 47 added variable-length QUIC server connection IDs. + // Version 48 added CRYPTO frames for the handshake. + // Version 49 added client connection IDs, long header lengths, and the IETF + // header format from draft-ietf-quic-invariants-06 + QUIC_VERSION_50 = 50, // Header protection and initial obfuscators. + // Number 51 was T051 which used draft-29 features but with GoogleQUIC frames. + // Number 70 used to represent draft-ietf-quic-transport-25. + // Number 71 used to represent draft-ietf-quic-transport-27. + // Number 72 used to represent draft-ietf-quic-transport-28. + QUIC_VERSION_IETF_DRAFT_29 = 73, // draft-ietf-quic-transport-29. + QUIC_VERSION_IETF_RFC_V1 = 80, // RFC 9000. + // Number 81 used to represent draft-ietf-quic-v2-01. + QUIC_VERSION_IETF_2_DRAFT_08 = 82, // draft-ietf-quic-v2-08. + // Version 99 was a dumping ground for IETF QUIC changes which were not yet + // ready for production between 2018-02 and 2020-02. + + // QUIC_VERSION_RESERVED_FOR_NEGOTIATION is sent over the wire as ?a?a?a?a + // which is part of a range reserved by the IETF for version negotiation + // testing (see the "Versions" section of draft-ietf-quic-transport). + // This version is intentionally meant to never be supported to trigger + // version negotiation when proposed by clients and to prevent client + // ossification when sent by servers. + QUIC_VERSION_RESERVED_FOR_NEGOTIATION = 999, +}; + +// Helper function which translates from a QuicTransportVersion to a string. +// Returns strings corresponding to enum names (e.g. QUIC_VERSION_6). +QUIC_EXPORT_PRIVATE std::string QuicVersionToString( + QuicTransportVersion transport_version); + +// The crypto handshake protocols that can be used with QUIC. +// We are planning on eventually deprecating PROTOCOL_QUIC_CRYPTO in favor of +// PROTOCOL_TLS1_3. +enum HandshakeProtocol { + PROTOCOL_UNSUPPORTED, + PROTOCOL_QUIC_CRYPTO, + PROTOCOL_TLS1_3, +}; + +// Helper function which translates from a HandshakeProtocol to a string. +QUIC_EXPORT_PRIVATE std::string HandshakeProtocolToString( + HandshakeProtocol handshake_protocol); + +// Returns whether |transport_version| uses CRYPTO frames for the handshake +// instead of stream 1. +QUIC_EXPORT_PRIVATE constexpr bool QuicVersionUsesCryptoFrames( + QuicTransportVersion transport_version) { + // CRYPTO frames were added in version 48. + return transport_version > QUIC_VERSION_46; +} + +// Returns whether this combination of handshake protocol and transport +// version is allowed. For example, {PROTOCOL_TLS1_3, QUIC_VERSION_43} is NOT +// allowed as TLS requires crypto frames which v43 does not support. Note that +// UnsupportedQuicVersion is a valid version. +QUIC_EXPORT_PRIVATE constexpr bool ParsedQuicVersionIsValid( + HandshakeProtocol handshake_protocol, + QuicTransportVersion transport_version) { + bool transport_version_is_valid = false; + constexpr QuicTransportVersion valid_transport_versions[] = { + QUIC_VERSION_IETF_2_DRAFT_08, + QUIC_VERSION_IETF_RFC_V1, + QUIC_VERSION_IETF_DRAFT_29, + QUIC_VERSION_50, + QUIC_VERSION_46, + QUIC_VERSION_43, + QUIC_VERSION_RESERVED_FOR_NEGOTIATION, + QUIC_VERSION_UNSUPPORTED, + }; + for (size_t i = 0; i < ABSL_ARRAYSIZE(valid_transport_versions); ++i) { + if (transport_version == valid_transport_versions[i]) { + transport_version_is_valid = true; + break; + } + } + if (!transport_version_is_valid) { + return false; + } + switch (handshake_protocol) { + case PROTOCOL_UNSUPPORTED: + return transport_version == QUIC_VERSION_UNSUPPORTED; + case PROTOCOL_QUIC_CRYPTO: + return transport_version != QUIC_VERSION_UNSUPPORTED && + transport_version != QUIC_VERSION_RESERVED_FOR_NEGOTIATION && + transport_version != QUIC_VERSION_IETF_DRAFT_29 && + transport_version != QUIC_VERSION_IETF_RFC_V1 && + transport_version != QUIC_VERSION_IETF_2_DRAFT_08; + case PROTOCOL_TLS1_3: + return transport_version != QUIC_VERSION_UNSUPPORTED && + transport_version != QUIC_VERSION_50 && + QuicVersionUsesCryptoFrames(transport_version); + } + return false; +} + +// A parsed QUIC version label which determines that handshake protocol +// and the transport version. +struct QUIC_EXPORT_PRIVATE ParsedQuicVersion { + HandshakeProtocol handshake_protocol; + QuicTransportVersion transport_version; + + constexpr ParsedQuicVersion(HandshakeProtocol handshake_protocol, + QuicTransportVersion transport_version) + : handshake_protocol(handshake_protocol), + transport_version(transport_version) { + QUICHE_DCHECK( + ParsedQuicVersionIsValid(handshake_protocol, transport_version)) + << QuicVersionToString(transport_version) << " " + << HandshakeProtocolToString(handshake_protocol); + } + + constexpr ParsedQuicVersion(const ParsedQuicVersion& other) + : ParsedQuicVersion(other.handshake_protocol, other.transport_version) {} + + ParsedQuicVersion& operator=(const ParsedQuicVersion& other) { + QUICHE_DCHECK(ParsedQuicVersionIsValid(other.handshake_protocol, + other.transport_version)) + << QuicVersionToString(other.transport_version) << " " + << HandshakeProtocolToString(other.handshake_protocol); + if (this != &other) { + handshake_protocol = other.handshake_protocol; + transport_version = other.transport_version; + } + return *this; + } + + bool operator==(const ParsedQuicVersion& other) const { + return handshake_protocol == other.handshake_protocol && + transport_version == other.transport_version; + } + + bool operator!=(const ParsedQuicVersion& other) const { + return handshake_protocol != other.handshake_protocol || + transport_version != other.transport_version; + } + + static constexpr ParsedQuicVersion V2Draft08() { + return ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_IETF_2_DRAFT_08); + } + + static constexpr ParsedQuicVersion RFCv1() { + return ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_IETF_RFC_V1); + } + + static constexpr ParsedQuicVersion Draft29() { + return ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_IETF_DRAFT_29); + } + + static constexpr ParsedQuicVersion Q050() { + return ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, QUIC_VERSION_50); + } + + static constexpr ParsedQuicVersion Q046() { + return ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, QUIC_VERSION_46); + } + + static constexpr ParsedQuicVersion Q043() { + return ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, QUIC_VERSION_43); + } + + static constexpr ParsedQuicVersion Unsupported() { + return ParsedQuicVersion(PROTOCOL_UNSUPPORTED, QUIC_VERSION_UNSUPPORTED); + } + + static constexpr ParsedQuicVersion ReservedForNegotiation() { + return ParsedQuicVersion(PROTOCOL_TLS1_3, + QUIC_VERSION_RESERVED_FOR_NEGOTIATION); + } + + // Returns whether our codebase understands this version. This should only be + // called on valid versions, see ParsedQuicVersionIsValid. Assuming the + // version is valid, IsKnown returns whether the version is not + // UnsupportedQuicVersion. + bool IsKnown() const; + + bool KnowsWhichDecrypterToUse() const; + + // Returns whether this version uses keys derived from the Connection ID for + // ENCRYPTION_INITIAL keys (instead of NullEncrypter/NullDecrypter). + bool UsesInitialObfuscators() const; + + // Indicates that this QUIC version does not have an enforced minimum value + // for flow control values negotiated during the handshake. + bool AllowsLowFlowControlLimits() const; + + // Returns whether header protection is used in this version of QUIC. + bool HasHeaderProtection() const; + + // Returns whether this version supports IETF RETRY packets. + bool SupportsRetry() const; + + // Returns true if this version sends variable length packet number in long + // header. + bool SendsVariableLengthPacketNumberInLongHeader() const; + + // Returns whether this version allows server connection ID lengths + // that are not 64 bits. + bool AllowsVariableLengthConnectionIds() const; + + // Returns whether this version supports client connection ID. + bool SupportsClientConnectionIds() const; + + // Returns whether this version supports long header 8-bit encoded + // connection ID lengths as described in draft-ietf-quic-invariants-06 and + // draft-ietf-quic-transport-22. + bool HasLengthPrefixedConnectionIds() const; + + // Returns whether this version supports IETF style anti-amplification limit, + // i.e., server will send no more than FLAGS_quic_anti_amplification_factor + // times received bytes until address can be validated. + bool SupportsAntiAmplificationLimit() const; + + // Returns true if this version can send coalesced packets. + bool CanSendCoalescedPackets() const; + + // Returns true if this version supports the old Google-style Alt-Svc + // advertisement format. + bool SupportsGoogleAltSvcFormat() const; + + // Returns true if |transport_version| uses IETF invariant headers. + bool HasIetfInvariantHeader() const; + + // Returns true if |transport_version| supports MESSAGE frames. + bool SupportsMessageFrames() const; + + // If true, HTTP/3 instead of gQUIC will be used at the HTTP layer. + // Notable changes are: + // * Headers stream no longer exists. + // * PRIORITY, HEADERS are moved from headers stream to HTTP/3 control stream. + // * PUSH_PROMISE is moved to request stream. + // * Unidirectional streams will have their first byte as a stream type. + // * HEADERS frames are compressed using QPACK. + // * DATA frame has frame headers. + // * GOAWAY is moved to HTTP layer. + bool UsesHttp3() const; + + // Returns whether the transport_version supports the variable length integer + // length field as defined by IETF QUIC draft-13 and later. + bool HasLongHeaderLengths() const; + + // Returns whether |transport_version| uses CRYPTO frames for the handshake + // instead of stream 1. + bool UsesCryptoFrames() const; + + // Returns whether |transport_version| makes use of IETF QUIC + // frames or not. + bool HasIetfQuicFrames() const; + + // Returns whether this version uses the legacy TLS extension codepoint. + bool UsesLegacyTlsExtension() const; + + // Returns whether this version uses PROTOCOL_TLS1_3. + bool UsesTls() const; + + // Returns whether this version uses PROTOCOL_QUIC_CRYPTO. + bool UsesQuicCrypto() const; + + // Returns whether this version uses the QUICv2 Long Header Packet Types. + bool UsesV2PacketTypes() const; + + // Returns true if this shares ALPN codes with RFCv1, and endpoints should + // choose RFCv1 when presented with a v1 ALPN. Note that this is false for + // RFCv1. + bool AlpnDeferToRFCv1() const; +}; + +QUIC_EXPORT_PRIVATE ParsedQuicVersion UnsupportedQuicVersion(); + +QUIC_EXPORT_PRIVATE ParsedQuicVersion QuicVersionReservedForNegotiation(); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const ParsedQuicVersion& version); + +using ParsedQuicVersionVector = std::vector; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ParsedQuicVersionVector& versions); + +// Representation of the on-the-wire QUIC version number. Will be written/read +// to the wire in network-byte-order. +using QuicVersionLabel = uint32_t; +using QuicVersionLabelVector = std::vector; + +// Constructs a version label from the 4 bytes such that the on-the-wire +// order will be: d, c, b, a. +QUIC_EXPORT_PRIVATE QuicVersionLabel MakeVersionLabel(uint8_t a, uint8_t b, + uint8_t c, uint8_t d); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicVersionLabelVector& version_labels); + +// This vector contains all crypto handshake protocols that are supported. +constexpr std::array SupportedHandshakeProtocols() { + return {PROTOCOL_TLS1_3, PROTOCOL_QUIC_CRYPTO}; +} + +constexpr std::array SupportedVersions() { + return { + ParsedQuicVersion::V2Draft08(), ParsedQuicVersion::RFCv1(), + ParsedQuicVersion::Draft29(), ParsedQuicVersion::Q050(), + ParsedQuicVersion::Q046(), ParsedQuicVersion::Q043(), + }; +} + +using QuicTransportVersionVector = std::vector; + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const QuicTransportVersionVector& transport_versions); + +// Returns a vector of supported QUIC versions. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector AllSupportedVersions(); + +// Returns a vector of supported QUIC versions, with any versions disabled by +// flags excluded. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector CurrentSupportedVersions(); + +// Returns a vector of QUIC versions from |versions| which exclude any versions +// which are disabled by flags. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +FilterSupportedVersions(ParsedQuicVersionVector versions); + +// Returns a subset of AllSupportedVersions() with +// handshake_protocol == PROTOCOL_QUIC_CRYPTO, in the same order. +// Deprecated; only to be used in components that do not yet support +// PROTOCOL_TLS1_3. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +AllSupportedVersionsWithQuicCrypto(); + +// Returns a subset of CurrentSupportedVersions() with +// handshake_protocol == PROTOCOL_QUIC_CRYPTO, in the same order. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +CurrentSupportedVersionsWithQuicCrypto(); + +// Returns a subset of AllSupportedVersions() with +// handshake_protocol == PROTOCOL_TLS1_3, in the same order. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector AllSupportedVersionsWithTls(); + +// Returns a subset of CurrentSupportedVersions() with handshake_protocol == +// PROTOCOL_TLS1_3. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector CurrentSupportedVersionsWithTls(); + +// Returns a subset of CurrentSupportedVersions() using HTTP/3 at the HTTP +// layer. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector CurrentSupportedHttp3Versions(); + +// Returns QUIC version of |index| in result of |versions|. Returns +// UnsupportedQuicVersion() if |index| is out of bounds. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +ParsedVersionOfIndex(const ParsedQuicVersionVector& versions, int index); + +// QuicVersionLabel is written to and read from the wire, but we prefer to use +// the more readable ParsedQuicVersion at other levels. +// Helper function which translates from a QuicVersionLabel to a +// ParsedQuicVersion. +QUIC_EXPORT_PRIVATE ParsedQuicVersion +ParseQuicVersionLabel(QuicVersionLabel version_label); + +// Helper function that translates from a QuicVersionLabelVector to a +// ParsedQuicVersionVector. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +ParseQuicVersionLabelVector(const QuicVersionLabelVector& version_labels); + +// Parses a QUIC version string such as "Q043" or "T051". Also supports parsing +// ALPN such as "h3-29" or "h3-Q050". For PROTOCOL_QUIC_CRYPTO versions, also +// supports parsing numbers such as "46". +QUIC_EXPORT_PRIVATE ParsedQuicVersion +ParseQuicVersionString(absl::string_view version_string); + +// Parses a comma-separated list of QUIC version strings. Supports parsing by +// label, ALPN and numbers for PROTOCOL_QUIC_CRYPTO. Skips unknown versions. +// For example: "h3-29,Q050,46". +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +ParseQuicVersionVectorString(absl::string_view versions_string); + +// Constructs a QuicVersionLabel from the provided ParsedQuicVersion. +// QuicVersionLabel is written to and read from the wire, but we prefer to use +// the more readable ParsedQuicVersion at other levels. +// Helper function which translates from a ParsedQuicVersion to a +// QuicVersionLabel. Returns 0 if |parsed_version| is unsupported. +QUIC_EXPORT_PRIVATE QuicVersionLabel +CreateQuicVersionLabel(ParsedQuicVersion parsed_version); + +// Constructs a QuicVersionLabelVector from the provided +// ParsedQuicVersionVector. +QUIC_EXPORT_PRIVATE QuicVersionLabelVector +CreateQuicVersionLabelVector(const ParsedQuicVersionVector& versions); + +// Helper function which translates from a QuicVersionLabel to a string. +QUIC_EXPORT_PRIVATE std::string QuicVersionLabelToString( + QuicVersionLabel version_label); + +// Helper function which translates from a QuicVersionLabel string to a +// ParsedQuicVersion. The version label string must be of the form returned +// by QuicVersionLabelToString, for example, "00000001" or "Q046", but not +// "51303433" (the hex encoding of the Q064 version label). Returns +// the ParsedQuicVersion which matches the label or UnsupportedQuicVersion() +// otherwise. +QUIC_EXPORT_PRIVATE ParsedQuicVersion +ParseQuicVersionLabelString(absl::string_view version_label_string); + +// Returns |separator|-separated list of string representations of +// QuicVersionLabel values in the supplied |version_labels| vector. The values +// after the (0-based) |skip_after_nth_version|'th are skipped. +QUIC_EXPORT_PRIVATE std::string QuicVersionLabelVectorToString( + const QuicVersionLabelVector& version_labels, const std::string& separator, + size_t skip_after_nth_version); + +// Returns comma separated list of string representations of QuicVersionLabel +// values in the supplied |version_labels| vector. +QUIC_EXPORT_PRIVATE inline std::string QuicVersionLabelVectorToString( + const QuicVersionLabelVector& version_labels) { + return QuicVersionLabelVectorToString(version_labels, ",", + std::numeric_limits::max()); +} + +// Helper function which translates from a ParsedQuicVersion to a string. +// Returns strings corresponding to the on-the-wire tag. +QUIC_EXPORT_PRIVATE std::string ParsedQuicVersionToString( + ParsedQuicVersion version); + +// Returns a vector of supported QUIC transport versions. DEPRECATED, use +// AllSupportedVersions instead. +QUIC_EXPORT_PRIVATE QuicTransportVersionVector AllSupportedTransportVersions(); + +// Returns comma separated list of string representations of +// QuicTransportVersion enum values in the supplied |versions| vector. +QUIC_EXPORT_PRIVATE std::string QuicTransportVersionVectorToString( + const QuicTransportVersionVector& versions); + +// Returns comma separated list of string representations of ParsedQuicVersion +// values in the supplied |versions| vector. +QUIC_EXPORT_PRIVATE std::string ParsedQuicVersionVectorToString( + const ParsedQuicVersionVector& versions); + +// Returns |separator|-separated list of string representations of +// ParsedQuicVersion values in the supplied |versions| vector. The values after +// the (0-based) |skip_after_nth_version|'th are skipped. +QUIC_EXPORT_PRIVATE std::string ParsedQuicVersionVectorToString( + const ParsedQuicVersionVector& versions, const std::string& separator, + size_t skip_after_nth_version); + +// Returns comma separated list of string representations of ParsedQuicVersion +// values in the supplied |versions| vector. +QUIC_EXPORT_PRIVATE inline std::string ParsedQuicVersionVectorToString( + const ParsedQuicVersionVector& versions) { + return ParsedQuicVersionVectorToString(versions, ",", + std::numeric_limits::max()); +} + +// Returns true if |transport_version| uses IETF invariant headers. +QUIC_EXPORT_PRIVATE constexpr bool VersionHasIetfInvariantHeader( + QuicTransportVersion transport_version) { + return transport_version > QUIC_VERSION_43; +} + +// Returns true if |transport_version| supports MESSAGE frames. +QUIC_EXPORT_PRIVATE constexpr bool VersionSupportsMessageFrames( + QuicTransportVersion transport_version) { + // MESSAGE frames were added in version 45. + return transport_version > QUIC_VERSION_43; +} + +// If true, HTTP/3 instead of gQUIC will be used at the HTTP layer. +// Notable changes are: +// * Headers stream no longer exists. +// * PRIORITY, HEADERS are moved from headers stream to HTTP/3 control stream. +// * PUSH_PROMISE is moved to request stream. +// * Unidirectional streams will have their first byte as a stream type. +// * HEADERS frames are compressed using QPACK. +// * DATA frame has frame headers. +// * GOAWAY is moved to HTTP layer. +QUIC_EXPORT_PRIVATE constexpr bool VersionUsesHttp3( + QuicTransportVersion transport_version) { + return transport_version >= QUIC_VERSION_IETF_DRAFT_29; +} + +// Returns whether the transport_version supports the variable length integer +// length field as defined by IETF QUIC draft-13 and later. +QUIC_EXPORT_PRIVATE constexpr bool QuicVersionHasLongHeaderLengths( + QuicTransportVersion transport_version) { + // Long header lengths were added in version 49. + return transport_version > QUIC_VERSION_46; +} + +// Returns whether |transport_version| makes use of IETF QUIC +// frames or not. +QUIC_EXPORT_PRIVATE constexpr bool VersionHasIetfQuicFrames( + QuicTransportVersion transport_version) { + return VersionUsesHttp3(transport_version); +} + +// Returns whether this version supports long header 8-bit encoded +// connection ID lengths as described in draft-ietf-quic-invariants-06 and +// draft-ietf-quic-transport-22. +QUIC_EXPORT_PRIVATE bool VersionHasLengthPrefixedConnectionIds( + QuicTransportVersion transport_version); + +// Returns true if this version supports the old Google-style Alt-Svc +// advertisement format. +QUIC_EXPORT_PRIVATE bool VersionSupportsGoogleAltSvcFormat( + QuicTransportVersion transport_version); + +// Returns whether this version allows server connection ID lengths that are +// not 64 bits. +QUIC_EXPORT_PRIVATE bool VersionAllowsVariableLengthConnectionIds( + QuicTransportVersion transport_version); + +// Returns whether this version label supports long header 4-bit encoded +// connection ID lengths as described in draft-ietf-quic-invariants-05 and +// draft-ietf-quic-transport-21. +QUIC_EXPORT_PRIVATE bool QuicVersionLabelUses4BitConnectionIdLength( + QuicVersionLabel version_label); + +// Returns the ALPN string to use in TLS for this version of QUIC. +QUIC_EXPORT_PRIVATE std::string AlpnForVersion( + ParsedQuicVersion parsed_version); + +// Initializes support for the provided IETF draft version by setting the +// correct flags. +QUIC_EXPORT_PRIVATE void QuicVersionInitializeSupportForIetfDraft(); + +// Configures the flags required to enable support for this version of QUIC. +QUIC_EXPORT_PRIVATE void QuicEnableVersion(const ParsedQuicVersion& version); + +// Configures the flags required to disable support for this version of QUIC. +QUIC_EXPORT_PRIVATE void QuicDisableVersion(const ParsedQuicVersion& version); + +// Returns whether support for this version of QUIC is currently enabled. +QUIC_EXPORT_PRIVATE bool QuicVersionIsEnabled(const ParsedQuicVersion& version); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_VERSIONS_H_ diff --git a/quiche/quic/core/quic_versions_test.cc b/quiche/quic/core/quic_versions_test.cc new file mode 100644 index 000000000000..745b58a7401f --- /dev/null +++ b/quiche/quic/core/quic_versions_test.cc @@ -0,0 +1,523 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_versions.h" + +#include "absl/base/macros.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(QuicVersionsTest, CreateQuicVersionLabelUnsupported) { + EXPECT_QUIC_BUG( + CreateQuicVersionLabel(UnsupportedQuicVersion()), + "Unsupported version QUIC_VERSION_UNSUPPORTED PROTOCOL_UNSUPPORTED"); +} + +TEST(QuicVersionsTest, KnownAndValid) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_TRUE(version.IsKnown()); + EXPECT_TRUE(ParsedQuicVersionIsValid(version.handshake_protocol, + version.transport_version)); + } + ParsedQuicVersion unsupported = UnsupportedQuicVersion(); + EXPECT_FALSE(unsupported.IsKnown()); + EXPECT_TRUE(ParsedQuicVersionIsValid(unsupported.handshake_protocol, + unsupported.transport_version)); + ParsedQuicVersion reserved = QuicVersionReservedForNegotiation(); + EXPECT_TRUE(reserved.IsKnown()); + EXPECT_TRUE(ParsedQuicVersionIsValid(reserved.handshake_protocol, + reserved.transport_version)); + // Check that invalid combinations are not valid. + EXPECT_FALSE(ParsedQuicVersionIsValid(PROTOCOL_TLS1_3, QUIC_VERSION_43)); + EXPECT_FALSE(ParsedQuicVersionIsValid(PROTOCOL_QUIC_CRYPTO, + QUIC_VERSION_IETF_DRAFT_29)); + // Check that deprecated versions are not valid. + EXPECT_FALSE(ParsedQuicVersionIsValid(PROTOCOL_QUIC_CRYPTO, + static_cast(33))); + EXPECT_FALSE(ParsedQuicVersionIsValid(PROTOCOL_QUIC_CRYPTO, + static_cast(99))); + EXPECT_FALSE(ParsedQuicVersionIsValid(PROTOCOL_TLS1_3, + static_cast(99))); +} + +TEST(QuicVersionsTest, Features) { + ParsedQuicVersion parsed_version_q043 = ParsedQuicVersion::Q043(); + ParsedQuicVersion parsed_version_draft_29 = ParsedQuicVersion::Draft29(); + + EXPECT_TRUE(parsed_version_q043.IsKnown()); + EXPECT_FALSE(parsed_version_q043.KnowsWhichDecrypterToUse()); + EXPECT_FALSE(parsed_version_q043.UsesInitialObfuscators()); + EXPECT_FALSE(parsed_version_q043.AllowsLowFlowControlLimits()); + EXPECT_FALSE(parsed_version_q043.HasHeaderProtection()); + EXPECT_FALSE(parsed_version_q043.SupportsRetry()); + EXPECT_FALSE( + parsed_version_q043.SendsVariableLengthPacketNumberInLongHeader()); + EXPECT_FALSE(parsed_version_q043.AllowsVariableLengthConnectionIds()); + EXPECT_FALSE(parsed_version_q043.SupportsClientConnectionIds()); + EXPECT_FALSE(parsed_version_q043.HasLengthPrefixedConnectionIds()); + EXPECT_FALSE(parsed_version_q043.SupportsAntiAmplificationLimit()); + EXPECT_FALSE(parsed_version_q043.CanSendCoalescedPackets()); + EXPECT_TRUE(parsed_version_q043.SupportsGoogleAltSvcFormat()); + EXPECT_FALSE(parsed_version_q043.HasIetfInvariantHeader()); + EXPECT_FALSE(parsed_version_q043.SupportsMessageFrames()); + EXPECT_FALSE(parsed_version_q043.UsesHttp3()); + EXPECT_FALSE(parsed_version_q043.HasLongHeaderLengths()); + EXPECT_FALSE(parsed_version_q043.UsesCryptoFrames()); + EXPECT_FALSE(parsed_version_q043.HasIetfQuicFrames()); + EXPECT_FALSE(parsed_version_q043.UsesTls()); + EXPECT_TRUE(parsed_version_q043.UsesQuicCrypto()); + + EXPECT_TRUE(parsed_version_draft_29.IsKnown()); + EXPECT_TRUE(parsed_version_draft_29.KnowsWhichDecrypterToUse()); + EXPECT_TRUE(parsed_version_draft_29.UsesInitialObfuscators()); + EXPECT_TRUE(parsed_version_draft_29.AllowsLowFlowControlLimits()); + EXPECT_TRUE(parsed_version_draft_29.HasHeaderProtection()); + EXPECT_TRUE(parsed_version_draft_29.SupportsRetry()); + EXPECT_TRUE( + parsed_version_draft_29.SendsVariableLengthPacketNumberInLongHeader()); + EXPECT_TRUE(parsed_version_draft_29.AllowsVariableLengthConnectionIds()); + EXPECT_TRUE(parsed_version_draft_29.SupportsClientConnectionIds()); + EXPECT_TRUE(parsed_version_draft_29.HasLengthPrefixedConnectionIds()); + EXPECT_TRUE(parsed_version_draft_29.SupportsAntiAmplificationLimit()); + EXPECT_TRUE(parsed_version_draft_29.CanSendCoalescedPackets()); + EXPECT_FALSE(parsed_version_draft_29.SupportsGoogleAltSvcFormat()); + EXPECT_TRUE(parsed_version_draft_29.HasIetfInvariantHeader()); + EXPECT_TRUE(parsed_version_draft_29.SupportsMessageFrames()); + EXPECT_TRUE(parsed_version_draft_29.UsesHttp3()); + EXPECT_TRUE(parsed_version_draft_29.HasLongHeaderLengths()); + EXPECT_TRUE(parsed_version_draft_29.UsesCryptoFrames()); + EXPECT_TRUE(parsed_version_draft_29.HasIetfQuicFrames()); + EXPECT_TRUE(parsed_version_draft_29.UsesTls()); + EXPECT_FALSE(parsed_version_draft_29.UsesQuicCrypto()); +} + +TEST(QuicVersionsTest, ParseQuicVersionLabel) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + EXPECT_EQ(ParsedQuicVersion::Q043(), + ParseQuicVersionLabel(MakeVersionLabel('Q', '0', '4', '3'))); + EXPECT_EQ(ParsedQuicVersion::Q046(), + ParseQuicVersionLabel(MakeVersionLabel('Q', '0', '4', '6'))); + EXPECT_EQ(ParsedQuicVersion::Q050(), + ParseQuicVersionLabel(MakeVersionLabel('Q', '0', '5', '0'))); + EXPECT_EQ(ParsedQuicVersion::Draft29(), + ParseQuicVersionLabel(MakeVersionLabel(0xff, 0x00, 0x00, 0x1d))); + EXPECT_EQ(ParsedQuicVersion::RFCv1(), + ParseQuicVersionLabel(MakeVersionLabel(0x00, 0x00, 0x00, 0x01))); + EXPECT_EQ(ParsedQuicVersion::V2Draft08(), + ParseQuicVersionLabel(MakeVersionLabel(0x6b, 0x33, 0x43, 0xcf))); + EXPECT_EQ((ParsedQuicVersionVector{ParsedQuicVersion::V2Draft08(), + ParsedQuicVersion::RFCv1(), + ParsedQuicVersion::Draft29()}), + ParseQuicVersionLabelVector(QuicVersionLabelVector{ + MakeVersionLabel(0x6b, 0x33, 0x43, 0xcf), + MakeVersionLabel(0x00, 0x00, 0x00, 0x01), + MakeVersionLabel(0xaa, 0xaa, 0xaa, 0xaa), + MakeVersionLabel(0xff, 0x00, 0x00, 0x1d)})); + + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_EQ(version, ParseQuicVersionLabel(CreateQuicVersionLabel(version))); + } +} + +TEST(QuicVersionsTest, ParseQuicVersionString) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + EXPECT_EQ(ParsedQuicVersion::Q043(), ParseQuicVersionString("Q043")); + EXPECT_EQ(ParsedQuicVersion::Q046(), + ParseQuicVersionString("QUIC_VERSION_46")); + EXPECT_EQ(ParsedQuicVersion::Q046(), ParseQuicVersionString("46")); + EXPECT_EQ(ParsedQuicVersion::Q046(), ParseQuicVersionString("Q046")); + EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionString("Q050")); + EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionString("50")); + EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionString("h3-Q050")); + + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("Q 46")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("Q046 ")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("99")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("70")); + + EXPECT_EQ(ParsedQuicVersion::Draft29(), ParseQuicVersionString("ff00001d")); + EXPECT_EQ(ParsedQuicVersion::Draft29(), ParseQuicVersionString("draft29")); + EXPECT_EQ(ParsedQuicVersion::Draft29(), ParseQuicVersionString("h3-29")); + + EXPECT_EQ(ParsedQuicVersion::RFCv1(), ParseQuicVersionString("00000001")); + EXPECT_EQ(ParsedQuicVersion::RFCv1(), ParseQuicVersionString("h3")); + + // QUICv2 will never be the result for "h3". + + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_EQ(version, + ParseQuicVersionString(ParsedQuicVersionToString(version))); + EXPECT_EQ(version, ParseQuicVersionString(QuicVersionLabelToString( + CreateQuicVersionLabel(version)))); + if (!version.AlpnDeferToRFCv1()) { + EXPECT_EQ(version, ParseQuicVersionString(AlpnForVersion(version))); + } + } +} + +TEST(QuicVersionsTest, ParseQuicVersionVectorString) { + ParsedQuicVersion version_q046 = ParsedQuicVersion::Q046(); + ParsedQuicVersion version_q050 = ParsedQuicVersion::Q050(); + ParsedQuicVersion version_draft_29 = ParsedQuicVersion::Draft29(); + + EXPECT_THAT(ParseQuicVersionVectorString(""), IsEmpty()); + + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50"), + ElementsAre(version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050"), + ElementsAre(version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, h3-29"), + ElementsAre(version_q050, version_draft_29)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-Q050,h3-29"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-Q050, h3-29"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29, h3-Q050"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50,h3-29"), + ElementsAre(version_q050, version_draft_29)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29,QUIC_VERSION_50"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, h3-29"), + ElementsAre(version_q050, version_draft_29)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29, QUIC_VERSION_50"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50,QUIC_VERSION_46"), + ElementsAre(version_q050, version_q046)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_46,QUIC_VERSION_50"), + ElementsAre(version_q046, version_q050)); + + // Regression test for https://crbug.com/1044952. + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, QUIC_VERSION_50"), + ElementsAre(version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, h3-Q050"), + ElementsAre(version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, QUIC_VERSION_50"), + ElementsAre(version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString( + "QUIC_VERSION_50, h3-Q050, QUIC_VERSION_50, h3-Q050"), + ElementsAre(version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, h3-29, h3-Q050"), + ElementsAre(version_q050, version_draft_29)); + + EXPECT_THAT(ParseQuicVersionVectorString("99"), IsEmpty()); + EXPECT_THAT(ParseQuicVersionVectorString("70"), IsEmpty()); + EXPECT_THAT(ParseQuicVersionVectorString("h3-01"), IsEmpty()); + EXPECT_THAT(ParseQuicVersionVectorString("h3-01,h3-29"), + ElementsAre(version_draft_29)); +} + +// Do not use MakeVersionLabel() to generate expectations, because +// CreateQuicVersionLabel() uses MakeVersionLabel() internally, +// in case it has a bug. +TEST(QuicVersionsTest, CreateQuicVersionLabel) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + EXPECT_EQ(0x51303433u, CreateQuicVersionLabel(ParsedQuicVersion::Q043())); + EXPECT_EQ(0x51303436u, CreateQuicVersionLabel(ParsedQuicVersion::Q046())); + EXPECT_EQ(0x51303530u, CreateQuicVersionLabel(ParsedQuicVersion::Q050())); + EXPECT_EQ(0xff00001du, CreateQuicVersionLabel(ParsedQuicVersion::Draft29())); + EXPECT_EQ(0x00000001u, CreateQuicVersionLabel(ParsedQuicVersion::RFCv1())); + EXPECT_EQ(0x6b3343cfu, + CreateQuicVersionLabel(ParsedQuicVersion::V2Draft08())); + + // Make sure the negotiation reserved version is in the IETF reserved space. + EXPECT_EQ( + 0xda5a3a3au & 0x0f0f0f0f, + CreateQuicVersionLabel(ParsedQuicVersion::ReservedForNegotiation()) & + 0x0f0f0f0f); + + // Make sure that disabling randomness works. + SetQuicFlag(quic_disable_version_negotiation_grease_randomness, true); + EXPECT_EQ(0xda5a3a3au, CreateQuicVersionLabel( + ParsedQuicVersion::ReservedForNegotiation())); +} + +TEST(QuicVersionsTest, QuicVersionLabelToString) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + EXPECT_EQ("Q043", QuicVersionLabelToString( + CreateQuicVersionLabel(ParsedQuicVersion::Q043()))); + EXPECT_EQ("Q046", QuicVersionLabelToString( + CreateQuicVersionLabel(ParsedQuicVersion::Q046()))); + EXPECT_EQ("Q050", QuicVersionLabelToString( + CreateQuicVersionLabel(ParsedQuicVersion::Q050()))); + EXPECT_EQ("ff00001d", QuicVersionLabelToString(CreateQuicVersionLabel( + ParsedQuicVersion::Draft29()))); + EXPECT_EQ("00000001", QuicVersionLabelToString(CreateQuicVersionLabel( + ParsedQuicVersion::RFCv1()))); + EXPECT_EQ("6b3343cf", QuicVersionLabelToString(CreateQuicVersionLabel( + ParsedQuicVersion::V2Draft08()))); + + QuicVersionLabelVector version_labels = { + MakeVersionLabel('Q', '0', '3', '5'), + MakeVersionLabel('T', '0', '3', '8'), + MakeVersionLabel(0xff, 0, 0, 7), + }; + + EXPECT_EQ("Q035", QuicVersionLabelToString(version_labels[0])); + EXPECT_EQ("T038", QuicVersionLabelToString(version_labels[1])); + EXPECT_EQ("ff000007", QuicVersionLabelToString(version_labels[2])); + + EXPECT_EQ("Q035,T038,ff000007", + QuicVersionLabelVectorToString(version_labels)); + EXPECT_EQ("Q035:T038:ff000007", + QuicVersionLabelVectorToString(version_labels, ":", 2)); + EXPECT_EQ("Q035|T038|...", + QuicVersionLabelVectorToString(version_labels, "|", 1)); + + std::ostringstream os; + os << version_labels; + EXPECT_EQ("Q035,T038,ff000007", os.str()); +} + +TEST(QuicVersionsTest, ParseQuicVersionLabelString) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + // Explicitly test known QUIC version label strings. + EXPECT_EQ(ParsedQuicVersion::Q043(), ParseQuicVersionLabelString("Q043")); + EXPECT_EQ(ParsedQuicVersion::Q046(), ParseQuicVersionLabelString("Q046")); + EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionLabelString("Q050")); + EXPECT_EQ(ParsedQuicVersion::Draft29(), + ParseQuicVersionLabelString("ff00001d")); + EXPECT_EQ(ParsedQuicVersion::RFCv1(), + ParseQuicVersionLabelString("00000001")); + EXPECT_EQ(ParsedQuicVersion::V2Draft08(), + ParseQuicVersionLabelString("6b3343cf")); + + // Sanity check that a variety of other serialization formats are ignored. + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionLabelString("1")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionLabelString("46")); + EXPECT_EQ(UnsupportedQuicVersion(), + ParseQuicVersionLabelString("QUIC_VERSION_46")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionLabelString("h3")); + EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionLabelString("h3-29")); + + // Test round-trips between QuicVersionLabelToString and + // ParseQuicVersionLabelString. + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_EQ(version, ParseQuicVersionLabelString(QuicVersionLabelToString( + CreateQuicVersionLabel(version)))); + } +} + +TEST(QuicVersionsTest, QuicVersionToString) { + EXPECT_EQ("QUIC_VERSION_UNSUPPORTED", + QuicVersionToString(QUIC_VERSION_UNSUPPORTED)); + + QuicTransportVersion single_version[] = {QUIC_VERSION_43}; + QuicTransportVersionVector versions_vector; + for (size_t i = 0; i < ABSL_ARRAYSIZE(single_version); ++i) { + versions_vector.push_back(single_version[i]); + } + EXPECT_EQ("QUIC_VERSION_43", + QuicTransportVersionVectorToString(versions_vector)); + + QuicTransportVersion multiple_versions[] = {QUIC_VERSION_UNSUPPORTED, + QUIC_VERSION_43}; + versions_vector.clear(); + for (size_t i = 0; i < ABSL_ARRAYSIZE(multiple_versions); ++i) { + versions_vector.push_back(multiple_versions[i]); + } + EXPECT_EQ("QUIC_VERSION_UNSUPPORTED,QUIC_VERSION_43", + QuicTransportVersionVectorToString(versions_vector)); + + // Make sure that all supported versions are present in QuicVersionToString. + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_NE("QUIC_VERSION_UNSUPPORTED", + QuicVersionToString(version.transport_version)); + } + + std::ostringstream os; + os << versions_vector; + EXPECT_EQ("QUIC_VERSION_UNSUPPORTED,QUIC_VERSION_43", os.str()); +} + +TEST(QuicVersionsTest, ParsedQuicVersionToString) { + EXPECT_EQ("0", ParsedQuicVersionToString(ParsedQuicVersion::Unsupported())); + EXPECT_EQ("Q043", ParsedQuicVersionToString(ParsedQuicVersion::Q043())); + EXPECT_EQ("Q046", ParsedQuicVersionToString(ParsedQuicVersion::Q046())); + EXPECT_EQ("Q050", ParsedQuicVersionToString(ParsedQuicVersion::Q050())); + EXPECT_EQ("draft29", ParsedQuicVersionToString(ParsedQuicVersion::Draft29())); + EXPECT_EQ("RFCv1", ParsedQuicVersionToString(ParsedQuicVersion::RFCv1())); + EXPECT_EQ("V2Draft08", + ParsedQuicVersionToString(ParsedQuicVersion::V2Draft08())); + + ParsedQuicVersionVector versions_vector = {ParsedQuicVersion::Q043()}; + EXPECT_EQ("Q043", ParsedQuicVersionVectorToString(versions_vector)); + + versions_vector = {ParsedQuicVersion::Unsupported(), + ParsedQuicVersion::Q043()}; + EXPECT_EQ("0,Q043", ParsedQuicVersionVectorToString(versions_vector)); + EXPECT_EQ("0:Q043", ParsedQuicVersionVectorToString(versions_vector, ":", + versions_vector.size())); + EXPECT_EQ("0|...", ParsedQuicVersionVectorToString(versions_vector, "|", 0)); + + // Make sure that all supported versions are present in + // ParsedQuicVersionToString. + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_NE("0", ParsedQuicVersionToString(version)); + } + + std::ostringstream os; + os << versions_vector; + EXPECT_EQ("0,Q043", os.str()); +} + +TEST(QuicVersionsTest, FilterSupportedVersionsAllVersions) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + QuicEnableVersion(version); + } + ParsedQuicVersionVector expected_parsed_versions; + for (const ParsedQuicVersion& version : SupportedVersions()) { + expected_parsed_versions.push_back(version); + } + EXPECT_EQ(expected_parsed_versions, + FilterSupportedVersions(AllSupportedVersions())); + EXPECT_EQ(expected_parsed_versions, AllSupportedVersions()); +} + +TEST(QuicVersionsTest, FilterSupportedVersionsWithoutFirstVersion) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + QuicEnableVersion(version); + } + QuicDisableVersion(AllSupportedVersions().front()); + ParsedQuicVersionVector expected_parsed_versions; + for (const ParsedQuicVersion& version : SupportedVersions()) { + expected_parsed_versions.push_back(version); + } + expected_parsed_versions.erase(expected_parsed_versions.begin()); + EXPECT_EQ(expected_parsed_versions, + FilterSupportedVersions(AllSupportedVersions())); +} + +TEST(QuicVersionsTest, LookUpParsedVersionByIndex) { + ParsedQuicVersionVector all_versions = AllSupportedVersions(); + int version_count = all_versions.size(); + for (int i = -5; i <= version_count + 1; ++i) { + ParsedQuicVersionVector index = ParsedVersionOfIndex(all_versions, i); + if (i >= 0 && i < version_count) { + EXPECT_EQ(all_versions[i], index[0]); + } else { + EXPECT_EQ(UnsupportedQuicVersion(), index[0]); + } + } +} + +// This test may appear to be so simplistic as to be unnecessary, +// yet a typo was made in doing the #defines and it was caught +// only in some test far removed from here... Better safe than sorry. +TEST(QuicVersionsTest, CheckTransportVersionNumbersForTypos) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + EXPECT_EQ(QUIC_VERSION_43, 43); + EXPECT_EQ(QUIC_VERSION_46, 46); + EXPECT_EQ(QUIC_VERSION_50, 50); + EXPECT_EQ(QUIC_VERSION_IETF_DRAFT_29, 73); + EXPECT_EQ(QUIC_VERSION_IETF_RFC_V1, 80); + EXPECT_EQ(QUIC_VERSION_IETF_2_DRAFT_08, 82); +} + +TEST(QuicVersionsTest, AlpnForVersion) { + static_assert(SupportedVersions().size() == 6u, + "Supported versions out of sync"); + EXPECT_EQ("h3-Q043", AlpnForVersion(ParsedQuicVersion::Q043())); + EXPECT_EQ("h3-Q046", AlpnForVersion(ParsedQuicVersion::Q046())); + EXPECT_EQ("h3-Q050", AlpnForVersion(ParsedQuicVersion::Q050())); + EXPECT_EQ("h3-29", AlpnForVersion(ParsedQuicVersion::Draft29())); + EXPECT_EQ("h3", AlpnForVersion(ParsedQuicVersion::RFCv1())); + EXPECT_EQ("h3", AlpnForVersion(ParsedQuicVersion::V2Draft08())); +} + +TEST(QuicVersionsTest, QuicVersionEnabling) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + QuicFlagSaver flag_saver; + QuicDisableVersion(version); + EXPECT_FALSE(QuicVersionIsEnabled(version)); + QuicEnableVersion(version); + EXPECT_TRUE(QuicVersionIsEnabled(version)); + } +} + +TEST(QuicVersionsTest, ReservedForNegotiation) { + EXPECT_EQ(QUIC_VERSION_RESERVED_FOR_NEGOTIATION, + QuicVersionReservedForNegotiation().transport_version); + // QUIC_VERSION_RESERVED_FOR_NEGOTIATION MUST NOT be supported. + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + EXPECT_NE(QUIC_VERSION_RESERVED_FOR_NEGOTIATION, version.transport_version); + } +} + +TEST(QuicVersionsTest, SupportedVersionsHasCorrectList) { + size_t index = 0; + for (HandshakeProtocol handshake_protocol : SupportedHandshakeProtocols()) { + for (int trans_vers = 255; trans_vers > 0; trans_vers--) { + QuicTransportVersion transport_version = + static_cast(trans_vers); + SCOPED_TRACE(index); + if (ParsedQuicVersionIsValid(handshake_protocol, transport_version)) { + ParsedQuicVersion version = SupportedVersions()[index]; + EXPECT_EQ(version, + ParsedQuicVersion(handshake_protocol, transport_version)); + index++; + } + } + } + EXPECT_EQ(SupportedVersions().size(), index); +} + +TEST(QuicVersionsTest, SupportedVersionsAllDistinct) { + for (size_t index1 = 0; index1 < SupportedVersions().size(); ++index1) { + ParsedQuicVersion version1 = SupportedVersions()[index1]; + for (size_t index2 = index1 + 1; index2 < SupportedVersions().size(); + ++index2) { + ParsedQuicVersion version2 = SupportedVersions()[index2]; + EXPECT_NE(version1, version2) << version1 << " " << version2; + EXPECT_NE(CreateQuicVersionLabel(version1), + CreateQuicVersionLabel(version2)) + << version1 << " " << version2; + // The one pair where ALPNs are the same. + if ((version1 != ParsedQuicVersion::V2Draft08()) && + (version2 != ParsedQuicVersion::RFCv1())) { + EXPECT_NE(AlpnForVersion(version1), AlpnForVersion(version2)) + << version1 << " " << version2; + } + } + } +} + +TEST(QuicVersionsTest, CurrentSupportedHttp3Versions) { + ParsedQuicVersionVector h3_versions = CurrentSupportedHttp3Versions(); + ParsedQuicVersionVector all_current_supported_versions = + CurrentSupportedVersions(); + for (auto& version : all_current_supported_versions) { + bool version_is_h3 = false; + for (auto& h3_version : h3_versions) { + if (version == h3_version) { + EXPECT_TRUE(version.UsesHttp3()); + version_is_h3 = true; + break; + } + } + if (!version_is_h3) { + EXPECT_FALSE(version.UsesHttp3()); + } + } +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/quic_write_blocked_list.cc b/quiche/quic/core/quic_write_blocked_list.cc new file mode 100644 index 000000000000..475bdc12d0d7 --- /dev/null +++ b/quiche/quic/core/quic_write_blocked_list.cc @@ -0,0 +1,212 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_write_blocked_list.h" + +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +QuicWriteBlockedList::QuicWriteBlockedList() + : last_priority_popped_(0), + respect_incremental_( + GetQuicReloadableFlag(quic_priority_respect_incremental)), + disable_batch_write_(GetQuicReloadableFlag(quic_disable_batch_write)) { + memset(batch_write_stream_id_, 0, sizeof(batch_write_stream_id_)); + memset(bytes_left_for_batch_write_, 0, sizeof(bytes_left_for_batch_write_)); +} + +bool QuicWriteBlockedList::ShouldYield(QuicStreamId id) const { + for (const auto& stream : static_stream_collection_) { + if (stream.id == id) { + // Static streams should never yield to data streams, or to lower + // priority static stream. + return false; + } + if (stream.is_blocked) { + return true; // All data streams yield to static streams. + } + } + + return priority_write_scheduler_.ShouldYield(id); +} + +QuicStreamId QuicWriteBlockedList::PopFront() { + QuicStreamId static_stream_id; + if (static_stream_collection_.UnblockFirstBlocked(&static_stream_id)) { + return static_stream_id; + } + + const auto [id, priority] = + priority_write_scheduler_.PopNextReadyStreamAndPriority(); + const spdy::SpdyPriority urgency = priority.urgency; + const bool incremental = priority.incremental; + + last_priority_popped_ = urgency; + + if (disable_batch_write_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_disable_batch_write, 1, 3); + + // Writes on incremental streams are not batched. Not setting + // `batch_write_stream_id_` if the current write is incremental allows the + // write on the last non-incremental stream to continue if only incremental + // writes happened within this urgency bucket while that stream had no data + // to write. + if (!respect_incremental_ || !incremental) { + batch_write_stream_id_[urgency] = id; + } + + return id; + } + + if (!priority_write_scheduler_.HasReadyStreams()) { + // If no streams are blocked, don't bother latching. This stream will be + // the first popped for its urgency anyway. + batch_write_stream_id_[urgency] = 0; + } else if (batch_write_stream_id_[urgency] != id) { + // If newly latching this batch write stream, let it write 16k. + batch_write_stream_id_[urgency] = id; + bytes_left_for_batch_write_[urgency] = 16000; + } + + return id; +} + +void QuicWriteBlockedList::RegisterStream(QuicStreamId stream_id, + bool is_static_stream, + const QuicStreamPriority& priority) { + QUICHE_DCHECK(!priority_write_scheduler_.StreamRegistered(stream_id)) + << "stream " << stream_id << " already registered"; + if (is_static_stream) { + static_stream_collection_.Register(stream_id); + return; + } + + priority_write_scheduler_.RegisterStream(stream_id, priority.http()); +} + +void QuicWriteBlockedList::UnregisterStream(QuicStreamId stream_id) { + if (static_stream_collection_.Unregister(stream_id)) { + return; + } + priority_write_scheduler_.UnregisterStream(stream_id); +} + +void QuicWriteBlockedList::UpdateStreamPriority( + QuicStreamId stream_id, const QuicStreamPriority& new_priority) { + QUICHE_DCHECK(!static_stream_collection_.IsRegistered(stream_id)); + priority_write_scheduler_.UpdateStreamPriority(stream_id, + new_priority.http()); +} + +void QuicWriteBlockedList::UpdateBytesForStream(QuicStreamId stream_id, + size_t bytes) { + if (disable_batch_write_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_disable_batch_write, 2, 3); + return; + } + + if (batch_write_stream_id_[last_priority_popped_] == stream_id) { + // If this was the last data stream popped by PopFront, update the + // bytes remaining in its batch write. + bytes_left_for_batch_write_[last_priority_popped_] -= + std::min(bytes_left_for_batch_write_[last_priority_popped_], bytes); + } +} + +void QuicWriteBlockedList::AddStream(QuicStreamId stream_id) { + if (static_stream_collection_.SetBlocked(stream_id)) { + return; + } + + if (respect_incremental_) { + QUIC_RELOADABLE_FLAG_COUNT(quic_priority_respect_incremental); + if (!priority_write_scheduler_.GetStreamPriority(stream_id).incremental) { + const bool push_front = + stream_id == batch_write_stream_id_[last_priority_popped_]; + priority_write_scheduler_.MarkStreamReady(stream_id, push_front); + return; + } + } + + if (disable_batch_write_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_disable_batch_write, 3, 3); + priority_write_scheduler_.MarkStreamReady(stream_id, + /* push_front = */ false); + return; + } + + const bool push_front = + stream_id == batch_write_stream_id_[last_priority_popped_] && + bytes_left_for_batch_write_[last_priority_popped_] > 0; + + priority_write_scheduler_.MarkStreamReady(stream_id, push_front); +} + +bool QuicWriteBlockedList::IsStreamBlocked(QuicStreamId stream_id) const { + for (const auto& stream : static_stream_collection_) { + if (stream.id == stream_id) { + return stream.is_blocked; + } + } + + return priority_write_scheduler_.IsStreamReady(stream_id); +} + +void QuicWriteBlockedList::StaticStreamCollection::Register(QuicStreamId id) { + QUICHE_DCHECK(!IsRegistered(id)); + streams_.push_back({id, false}); +} + +bool QuicWriteBlockedList::StaticStreamCollection::IsRegistered( + QuicStreamId id) const { + for (const auto& stream : streams_) { + if (stream.id == id) { + return true; + } + } + return false; +} + +bool QuicWriteBlockedList::StaticStreamCollection::Unregister(QuicStreamId id) { + for (auto it = streams_.begin(); it != streams_.end(); ++it) { + if (it->id == id) { + if (it->is_blocked) { + --num_blocked_; + } + streams_.erase(it); + return true; + } + } + return false; +} + +bool QuicWriteBlockedList::StaticStreamCollection::SetBlocked(QuicStreamId id) { + for (auto& stream : streams_) { + if (stream.id == id) { + if (!stream.is_blocked) { + stream.is_blocked = true; + ++num_blocked_; + } + return true; + } + } + return false; +} + +bool QuicWriteBlockedList::StaticStreamCollection::UnblockFirstBlocked( + QuicStreamId* id) { + for (auto& stream : streams_) { + if (stream.is_blocked) { + --num_blocked_; + stream.is_blocked = false; + *id = stream.id; + return true; + } + } + return false; +} + +} // namespace quic diff --git a/quiche/quic/core/quic_write_blocked_list.h b/quiche/quic/core/quic_write_blocked_list.h new file mode 100644 index 000000000000..1c7f25a4c60f --- /dev/null +++ b/quiche/quic/core/quic_write_blocked_list.h @@ -0,0 +1,220 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_WRITE_BLOCKED_LIST_H_ +#define QUICHE_QUIC_CORE_QUIC_WRITE_BLOCKED_LIST_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "quiche/http2/core/priority_write_scheduler.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace quic { + +// Keeps tracks of the order of QUIC streams that have data to write. +// Static streams come first, in the order they were registered with +// QuicWriteBlockedList. They are followed by non-static streams, ordered by +// priority. +class QUICHE_EXPORT QuicWriteBlockedListInterface { + public: + virtual ~QuicWriteBlockedListInterface() = default; + + virtual bool HasWriteBlockedDataStreams() const = 0; + virtual size_t NumBlockedSpecialStreams() const = 0; + virtual size_t NumBlockedStreams() const = 0; + bool HasWriteBlockedSpecialStream() const { + return NumBlockedSpecialStreams() > 0; + } + + // Returns true if there is another stream with higher priority in the queue. + virtual bool ShouldYield(QuicStreamId id) const = 0; + + // Returns the priority of the specified stream. + virtual QuicStreamPriority GetPriorityOfStream(QuicStreamId id) const = 0; + + // Pops the highest priority stream, special casing static streams. Latches + // the most recently popped data stream for batch writing purposes. + virtual QuicStreamId PopFront() = 0; + + // Register a stream with given priority. + // `priority` is ignored for static streams. + virtual void RegisterStream(QuicStreamId stream_id, bool is_static_stream, + const QuicStreamPriority& priority) = 0; + + // Unregister a stream. `stream_id` must be registered, either as a static + // stream or as a non-static stream. + virtual void UnregisterStream(QuicStreamId stream_id) = 0; + + // Updates the stored priority of a stream. Must not be called for static + // streams. + virtual void UpdateStreamPriority(QuicStreamId stream_id, + const QuicStreamPriority& new_priority) = 0; + + // TODO(b/147306124): Remove when deprecating + // reloadable_flag_quic_disable_batch_write. + virtual void UpdateBytesForStream(QuicStreamId stream_id, size_t bytes) = 0; + + // Pushes a stream to the back of the list for its priority level *unless* it + // is latched for doing batched writes in which case it goes to the front of + // the list for its priority level. + // Static streams are special cased to always resume first. + // Stream must already be registered. + virtual void AddStream(QuicStreamId stream_id) = 0; + + // Returns true if stream with |stream_id| is write blocked. + virtual bool IsStreamBlocked(QuicStreamId stream_id) const = 0; +}; + +// Default implementation of QuicWriteBlockedListInterface. +class QUIC_EXPORT_PRIVATE QuicWriteBlockedList + : public QuicWriteBlockedListInterface { + public: + explicit QuicWriteBlockedList(); + QuicWriteBlockedList(const QuicWriteBlockedList&) = delete; + QuicWriteBlockedList& operator=(const QuicWriteBlockedList&) = delete; + + bool HasWriteBlockedDataStreams() const override { + return priority_write_scheduler_.HasReadyStreams(); + } + + size_t NumBlockedSpecialStreams() const override { + return static_stream_collection_.num_blocked(); + } + + size_t NumBlockedStreams() const override { + return NumBlockedSpecialStreams() + + priority_write_scheduler_.NumReadyStreams(); + } + + bool ShouldYield(QuicStreamId id) const override; + + QuicStreamPriority GetPriorityOfStream(QuicStreamId id) const override { + return QuicStreamPriority(priority_write_scheduler_.GetStreamPriority(id)); + } + + // Pops the highest priority stream, special casing static streams. Latches + // the most recently popped data stream for batch writing purposes. + QuicStreamId PopFront() override; + + // Register a stream with given priority. + // `priority` is ignored for static streams. + void RegisterStream(QuicStreamId stream_id, bool is_static_stream, + const QuicStreamPriority& priority) override; + + // Unregister a stream. `stream_id` must be registered, either as a static + // stream or as a non-static stream. + void UnregisterStream(QuicStreamId stream_id) override; + + // Updates the stored priority of a stream. Must not be called for static + // streams. + void UpdateStreamPriority(QuicStreamId stream_id, + const QuicStreamPriority& new_priority) override; + + // TODO(b/147306124): Remove when deprecating + // reloadable_flag_quic_disable_batch_write. + void UpdateBytesForStream(QuicStreamId stream_id, size_t bytes) override; + + // Pushes a stream to the back of the list for its priority level *unless* it + // is latched for doing batched writes in which case it goes to the front of + // the list for its priority level. + // Static streams are special cased to always resume first. + // Stream must already be registered. + void AddStream(QuicStreamId stream_id) override; + + // Returns true if stream with |stream_id| is write blocked. + bool IsStreamBlocked(QuicStreamId stream_id) const override; + + private: + struct QUICHE_EXPORT HttpStreamPriorityToInt { + int operator()(const HttpStreamPriority& priority) { + return priority.urgency; + } + }; + + struct QUICHE_EXPORT IntToHttpStreamPriority { + HttpStreamPriority operator()(int urgency) { + return HttpStreamPriority{urgency}; + } + }; + http2::PriorityWriteScheduler + priority_write_scheduler_; + + // If performing batch writes, this will be the stream ID of the stream doing + // batch writes for this priority level. We will allow this stream to write + // until it has written kBatchWriteSize bytes, it has no more data to write, + // or a higher priority stream preempts. + QuicStreamId batch_write_stream_id_[spdy::kV3LowestPriority + 1]; + // Set to kBatchWriteSize when we set a new batch_write_stream_id_ for a given + // priority. This is decremented with each write the stream does until it is + // done with its batch write. + // TODO(b/147306124): Remove when deprecating + // reloadable_flag_quic_disable_batch_write. + size_t bytes_left_for_batch_write_[spdy::kV3LowestPriority + 1]; + // Tracks the last priority popped for UpdateBytesForStream() and AddStream(). + spdy::SpdyPriority last_priority_popped_; + + // A StaticStreamCollection is a vector of pairs plus a + // eagerly-computed number of blocked static streams. + class QUIC_EXPORT_PRIVATE StaticStreamCollection { + public: + struct QUIC_EXPORT_PRIVATE StreamIdBlockedPair { + QuicStreamId id; + bool is_blocked; + }; + + // Optimized for the typical case of 2 static streams per session. + using StreamsVector = absl::InlinedVector; + + StreamsVector::const_iterator begin() const { return streams_.cbegin(); } + + StreamsVector::const_iterator end() const { return streams_.cend(); } + + size_t num_blocked() const { return num_blocked_; } + + // Add |id| to the collection in unblocked state. + void Register(QuicStreamId id); + + // True if |id| is in the collection, regardless of its state. + bool IsRegistered(QuicStreamId id) const; + + // Remove |id| from the collection. If it is in the blocked state, reduce + // |num_blocked_| by 1. Returns true if |id| was in the collection. + bool Unregister(QuicStreamId id); + + // Set |id| to be blocked. If |id| is not already blocked, increase + // |num_blocked_| by 1. + // Return true if |id| is in the collection. + bool SetBlocked(QuicStreamId id); + + // Unblock the first blocked stream in the collection. + // If no stream is blocked, return false. Otherwise return true, set *id to + // the unblocked stream id and reduce |num_blocked_| by 1. + bool UnblockFirstBlocked(QuicStreamId* id); + + private: + size_t num_blocked_ = 0; + StreamsVector streams_; + }; + + StaticStreamCollection static_stream_collection_; + + // Latched value of reloadable_flag_quic_priority_respect_incremental. + const bool respect_incremental_; + // Latched value of reloadable_flag_quic_disable_batch_write. + const bool disable_batch_write_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_WRITE_BLOCKED_LIST_H_ diff --git a/quiche/quic/core/quic_write_blocked_list_test.cc b/quiche/quic/core/quic_write_blocked_list_test.cc new file mode 100644 index 000000000000..ba6569e11c76 --- /dev/null +++ b/quiche/quic/core/quic_write_blocked_list_test.cc @@ -0,0 +1,678 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/quic_write_blocked_list.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/platform/api/quiche_expect_bug.h" + +using spdy::kV3HighestPriority; +using spdy::kV3LowestPriority; + +namespace quic { +namespace test { +namespace { + +constexpr bool kStatic = true; +constexpr bool kNotStatic = false; + +constexpr bool kIncremental = true; +constexpr bool kNotIncremental = false; + +class QuicWriteBlockedListTest : public QuicTest { + protected: + void SetUp() override { + // Delay construction of QuicWriteBlockedList object to allow constructor of + // derived test classes to manipulate reloadable flags that are latched in + // QuicWriteBlockedList constructor. + write_blocked_list_.emplace(); + } + + bool HasWriteBlockedDataStreams() const { + return write_blocked_list_->HasWriteBlockedDataStreams(); + } + + bool HasWriteBlockedSpecialStream() const { + return write_blocked_list_->HasWriteBlockedSpecialStream(); + } + + size_t NumBlockedSpecialStreams() const { + return write_blocked_list_->NumBlockedSpecialStreams(); + } + + size_t NumBlockedStreams() const { + return write_blocked_list_->NumBlockedStreams(); + } + + bool ShouldYield(QuicStreamId id) const { + return write_blocked_list_->ShouldYield(id); + } + + QuicStreamPriority GetPriorityOfStream(QuicStreamId id) const { + return write_blocked_list_->GetPriorityOfStream(id); + } + + QuicStreamId PopFront() { return write_blocked_list_->PopFront(); } + + void RegisterStream(QuicStreamId stream_id, bool is_static_stream, + const HttpStreamPriority& priority) { + write_blocked_list_->RegisterStream(stream_id, is_static_stream, + QuicStreamPriority(priority)); + } + + void UnregisterStream(QuicStreamId stream_id) { + write_blocked_list_->UnregisterStream(stream_id); + } + + void UpdateStreamPriority(QuicStreamId stream_id, + const HttpStreamPriority& new_priority) { + write_blocked_list_->UpdateStreamPriority(stream_id, + QuicStreamPriority(new_priority)); + } + + void UpdateBytesForStream(QuicStreamId stream_id, size_t bytes) { + write_blocked_list_->UpdateBytesForStream(stream_id, bytes); + } + + void AddStream(QuicStreamId stream_id) { + write_blocked_list_->AddStream(stream_id); + } + + bool IsStreamBlocked(QuicStreamId stream_id) const { + return write_blocked_list_->IsStreamBlocked(stream_id); + } + + private: + absl::optional write_blocked_list_; +}; + +TEST_F(QuicWriteBlockedListTest, PriorityOrder) { + // Mark streams blocked in roughly reverse priority order, and + // verify that streams are sorted. + RegisterStream(40, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(23, kNotStatic, {kV3HighestPriority, kIncremental}); + RegisterStream(17, kNotStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(1, kStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(3, kStatic, {kV3HighestPriority, kNotIncremental}); + + EXPECT_EQ(kV3LowestPriority, GetPriorityOfStream(40).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(40).http().incremental); + + EXPECT_EQ(kV3HighestPriority, GetPriorityOfStream(23).http().urgency); + EXPECT_EQ(kIncremental, GetPriorityOfStream(23).http().incremental); + + EXPECT_EQ(kV3HighestPriority, GetPriorityOfStream(17).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(17).http().incremental); + + AddStream(40); + EXPECT_TRUE(IsStreamBlocked(40)); + AddStream(23); + EXPECT_TRUE(IsStreamBlocked(23)); + AddStream(17); + EXPECT_TRUE(IsStreamBlocked(17)); + AddStream(3); + EXPECT_TRUE(IsStreamBlocked(3)); + AddStream(1); + EXPECT_TRUE(IsStreamBlocked(1)); + + EXPECT_EQ(5u, NumBlockedStreams()); + EXPECT_TRUE(HasWriteBlockedSpecialStream()); + EXPECT_EQ(2u, NumBlockedSpecialStreams()); + EXPECT_TRUE(HasWriteBlockedDataStreams()); + + // Static streams are highest priority, regardless of priority value. + EXPECT_EQ(1u, PopFront()); + EXPECT_EQ(1u, NumBlockedSpecialStreams()); + EXPECT_FALSE(IsStreamBlocked(1)); + + EXPECT_EQ(3u, PopFront()); + EXPECT_EQ(0u, NumBlockedSpecialStreams()); + EXPECT_FALSE(IsStreamBlocked(3)); + + // Streams with same priority are popped in the order they were inserted. + EXPECT_EQ(23u, PopFront()); + EXPECT_FALSE(IsStreamBlocked(23)); + EXPECT_EQ(17u, PopFront()); + EXPECT_FALSE(IsStreamBlocked(17)); + + // Low priority stream appears last. + EXPECT_EQ(40u, PopFront()); + EXPECT_FALSE(IsStreamBlocked(40)); + + EXPECT_EQ(0u, NumBlockedStreams()); + EXPECT_FALSE(HasWriteBlockedSpecialStream()); + EXPECT_FALSE(HasWriteBlockedDataStreams()); +} + +TEST_F(QuicWriteBlockedListTest, SingleStaticStream) { + RegisterStream(5, kStatic, {kV3HighestPriority, kNotIncremental}); + AddStream(5); + + EXPECT_EQ(1u, NumBlockedStreams()); + EXPECT_TRUE(HasWriteBlockedSpecialStream()); + EXPECT_EQ(5u, PopFront()); + EXPECT_EQ(0u, NumBlockedStreams()); + EXPECT_FALSE(HasWriteBlockedSpecialStream()); +} + +TEST_F(QuicWriteBlockedListTest, StaticStreamsComeFirst) { + RegisterStream(5, kNotStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(3, kStatic, {kV3LowestPriority, kNotIncremental}); + AddStream(5); + AddStream(3); + + EXPECT_EQ(2u, NumBlockedStreams()); + EXPECT_TRUE(HasWriteBlockedSpecialStream()); + EXPECT_TRUE(HasWriteBlockedDataStreams()); + + EXPECT_EQ(3u, PopFront()); + EXPECT_EQ(5u, PopFront()); + + EXPECT_EQ(0u, NumBlockedStreams()); + EXPECT_FALSE(HasWriteBlockedSpecialStream()); + EXPECT_FALSE(HasWriteBlockedDataStreams()); +} + +TEST_F(QuicWriteBlockedListTest, NoDuplicateEntries) { + // Test that QuicWriteBlockedList doesn't allow duplicate entries. + // Try to add a stream to the write blocked list multiple times at the same + // priority. + const QuicStreamId kBlockedId = 5; + RegisterStream(kBlockedId, kNotStatic, {kV3HighestPriority, kNotIncremental}); + AddStream(kBlockedId); + AddStream(kBlockedId); + AddStream(kBlockedId); + + // This should only result in one blocked stream being added. + EXPECT_EQ(1u, NumBlockedStreams()); + EXPECT_TRUE(HasWriteBlockedDataStreams()); + + // There should only be one stream to pop off the front. + EXPECT_EQ(kBlockedId, PopFront()); + EXPECT_EQ(0u, NumBlockedStreams()); + EXPECT_FALSE(HasWriteBlockedDataStreams()); +} + +TEST_F(QuicWriteBlockedListTest, IncrementalStreamsRoundRobin) { + const QuicStreamId id1 = 5; + const QuicStreamId id2 = 7; + const QuicStreamId id3 = 9; + RegisterStream(id1, kNotStatic, {kV3LowestPriority, kIncremental}); + RegisterStream(id2, kNotStatic, {kV3LowestPriority, kIncremental}); + RegisterStream(id3, kNotStatic, {kV3LowestPriority, kIncremental}); + + AddStream(id1); + AddStream(id2); + AddStream(id3); + + EXPECT_EQ(id1, PopFront()); + const size_t kLargeWriteSize = 1000 * 1000 * 1000; + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + EXPECT_EQ(id3, PopFront()); + UpdateBytesForStream(id3, kLargeWriteSize); + + AddStream(id3); + AddStream(id2); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + EXPECT_EQ(id3, PopFront()); + UpdateBytesForStream(id3, kLargeWriteSize); + AddStream(id3); + + EXPECT_EQ(id2, PopFront()); + EXPECT_EQ(id3, PopFront()); +} + +class QuicWriteBlockedListParameterizedTest + : public QuicWriteBlockedListTest, + public ::testing::WithParamInterface> { + protected: + QuicWriteBlockedListParameterizedTest() + : priority_respect_incremental_(std::get<0>(GetParam())), + disable_batch_write_(std::get<1>(GetParam())) { + SetQuicReloadableFlag(quic_priority_respect_incremental, + priority_respect_incremental_); + SetQuicReloadableFlag(quic_disable_batch_write, disable_batch_write_); + } + + const bool priority_respect_incremental_; + const bool disable_batch_write_; +}; + +INSTANTIATE_TEST_SUITE_P( + BatchWrite, QuicWriteBlockedListParameterizedTest, + ::testing::Combine(::testing::Bool(), ::testing::Bool()), + [](const testing::TestParamInfo< + QuicWriteBlockedListParameterizedTest::ParamType>& info) { + return absl::StrCat(std::get<0>(info.param) ? "RespectIncrementalTrue" + : "RespectIncrementalFalse", + std::get<1>(info.param) ? "DisableBatchWriteTrue" + : "DisableBatchWriteFalse"); + }); + +// If reloadable_flag_quic_disable_batch_write is false, writes are batched. +TEST_P(QuicWriteBlockedListParameterizedTest, BatchingWrites) { + if (disable_batch_write_) { + return; + } + + const QuicStreamId id1 = 5; + const QuicStreamId id2 = 7; + const QuicStreamId id3 = 9; + RegisterStream(id1, kNotStatic, {kV3LowestPriority, kIncremental}); + RegisterStream(id2, kNotStatic, {kV3LowestPriority, kIncremental}); + RegisterStream(id3, kNotStatic, {kV3HighestPriority, kIncremental}); + + AddStream(id1); + AddStream(id2); + EXPECT_EQ(2u, NumBlockedStreams()); + + // The first stream we push back should stay at the front until 16k is + // written. + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, 15999); + AddStream(id1); + EXPECT_EQ(2u, NumBlockedStreams()); + EXPECT_EQ(id1, PopFront()); + + // Once 16k is written the first stream will yield to the next. + UpdateBytesForStream(id1, 1); + AddStream(id1); + EXPECT_EQ(2u, NumBlockedStreams()); + EXPECT_EQ(id2, PopFront()); + + // Set the new stream to have written all but one byte. + UpdateBytesForStream(id2, 15999); + AddStream(id2); + EXPECT_EQ(2u, NumBlockedStreams()); + + // Ensure higher priority streams are popped first. + AddStream(id3); + EXPECT_EQ(id3, PopFront()); + + // Higher priority streams will always be popped first, even if using their + // byte quota + UpdateBytesForStream(id3, 20000); + AddStream(id3); + EXPECT_EQ(id3, PopFront()); + + // Once the higher priority stream is out of the way, id2 will resume its 16k + // write, with only 1 byte remaining of its guaranteed write allocation. + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, 1); + AddStream(id2); + EXPECT_EQ(2u, NumBlockedStreams()); + EXPECT_EQ(id1, PopFront()); +} + +// If reloadable_flag_quic_disable_batch_write is true, writes are performed +// round-robin regardless of how little data is written on each stream. +TEST_P(QuicWriteBlockedListParameterizedTest, RoundRobin) { + if (!disable_batch_write_) { + return; + } + + const QuicStreamId id1 = 5; + const QuicStreamId id2 = 7; + const QuicStreamId id3 = 9; + RegisterStream(id1, kNotStatic, {kV3LowestPriority, kIncremental}); + RegisterStream(id2, kNotStatic, {kV3LowestPriority, kIncremental}); + RegisterStream(id3, kNotStatic, {kV3LowestPriority, kIncremental}); + + AddStream(id1); + AddStream(id2); + AddStream(id3); + + EXPECT_EQ(id1, PopFront()); + AddStream(id1); + + EXPECT_EQ(id2, PopFront()); + EXPECT_EQ(id3, PopFront()); + + AddStream(id3); + AddStream(id2); + + EXPECT_EQ(id1, PopFront()); + EXPECT_EQ(id3, PopFront()); + AddStream(id3); + + EXPECT_EQ(id2, PopFront()); + EXPECT_EQ(id3, PopFront()); +} + +TEST_P(QuicWriteBlockedListParameterizedTest, + NonIncrementalStreamsKeepWriting) { + if (!priority_respect_incremental_) { + return; + } + + const QuicStreamId id1 = 1; + const QuicStreamId id2 = 2; + const QuicStreamId id3 = 3; + const QuicStreamId id4 = 4; + RegisterStream(id1, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(id2, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(id3, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(id4, kNotStatic, {kV3HighestPriority, kNotIncremental}); + + AddStream(id1); + AddStream(id2); + AddStream(id3); + + // A non-incremental stream can continue writing as long as it has data. + EXPECT_EQ(id1, PopFront()); + const size_t kLargeWriteSize = 1000 * 1000 * 1000; + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + // A higher priority stream takes precedence. + AddStream(id4); + EXPECT_EQ(id4, PopFront()); + + // When it is the turn of the lower urgency bucket again, writing of the first + // stream will continue. + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + + // When there is no more data on the first stream, write can start on the + // second stream. + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + AddStream(id2); + + // Write continues without limit. + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + AddStream(id2); + + // Stream 1 is not the most recently written one, therefore it gets to the end + // of the dequeue. + AddStream(id1); + + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + + EXPECT_EQ(id3, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + AddStream(id3); + + EXPECT_EQ(id3, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + + // When there is no data to write either on stream 2 or stream 3, stream 1 can + // resume. + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); +} + +TEST_P(QuicWriteBlockedListParameterizedTest, + IncrementalAndNonIncrementalStreams) { + if (!priority_respect_incremental_) { + return; + } + + const QuicStreamId id1 = 1; + const QuicStreamId id2 = 2; + RegisterStream(id1, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(id2, kNotStatic, {kV3LowestPriority, kIncremental}); + + AddStream(id1); + AddStream(id2); + + // A non-incremental stream can continue writing as long as it has data. + EXPECT_EQ(id1, PopFront()); + const size_t kSmallWriteSize = 1000; + UpdateBytesForStream(id1, kSmallWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kSmallWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kSmallWriteSize); + + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kSmallWriteSize); + AddStream(id2); + AddStream(id1); + + if (!disable_batch_write_) { + // Small writes do not exceed the batch limit. + // Writes continue even on an incremental stream. + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kSmallWriteSize); + AddStream(id2); + + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kSmallWriteSize); + } + + EXPECT_EQ(id1, PopFront()); + const size_t kLargeWriteSize = 1000 * 1000 * 1000; + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id1); + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); + AddStream(id2); + AddStream(id1); + + // When batch writing is disabled, stream 2 immediately yields to stream 1, + // which is the non-incremental stream with most recent writes. + // When batch writing is enabled, stream 2 only yields to stream 1 after + // exceeding the batching limit. + if (!disable_batch_write_) { + EXPECT_EQ(id2, PopFront()); + UpdateBytesForStream(id2, kLargeWriteSize); + AddStream(id2); + } + + EXPECT_EQ(id1, PopFront()); + UpdateBytesForStream(id1, kLargeWriteSize); +} + +TEST_F(QuicWriteBlockedListTest, Ceding) { + RegisterStream(15, kNotStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(16, kNotStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(5, kNotStatic, {5, kNotIncremental}); + RegisterStream(4, kNotStatic, {5, kNotIncremental}); + RegisterStream(7, kNotStatic, {7, kNotIncremental}); + RegisterStream(1, kStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(3, kStatic, {kV3HighestPriority, kNotIncremental}); + + // When nothing is on the list, nothing yields. + EXPECT_FALSE(ShouldYield(5)); + + AddStream(5); + // 5 should not yield to itself. + EXPECT_FALSE(ShouldYield(5)); + // 4 and 7 are equal or lower priority and should yield to 5. + EXPECT_TRUE(ShouldYield(4)); + EXPECT_TRUE(ShouldYield(7)); + // Stream 15 and static streams should preempt 5. + EXPECT_FALSE(ShouldYield(15)); + EXPECT_FALSE(ShouldYield(3)); + EXPECT_FALSE(ShouldYield(1)); + + // Block a high priority stream. + AddStream(15); + // 16 should yield (same priority) but static streams will still not. + EXPECT_TRUE(ShouldYield(16)); + EXPECT_FALSE(ShouldYield(3)); + EXPECT_FALSE(ShouldYield(1)); + + // Block a static stream. All non-static streams should yield. + AddStream(3); + EXPECT_TRUE(ShouldYield(16)); + EXPECT_TRUE(ShouldYield(15)); + EXPECT_FALSE(ShouldYield(3)); + EXPECT_FALSE(ShouldYield(1)); + + // Block the other static stream. All other streams should yield. + AddStream(1); + EXPECT_TRUE(ShouldYield(16)); + EXPECT_TRUE(ShouldYield(15)); + EXPECT_TRUE(ShouldYield(3)); + EXPECT_FALSE(ShouldYield(1)); +} + +TEST_F(QuicWriteBlockedListTest, UnregisterStream) { + RegisterStream(40, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(23, kNotStatic, {6, kNotIncremental}); + RegisterStream(12, kNotStatic, {3, kNotIncremental}); + RegisterStream(17, kNotStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(1, kStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(3, kStatic, {kV3HighestPriority, kNotIncremental}); + + AddStream(40); + AddStream(23); + AddStream(12); + AddStream(17); + AddStream(1); + AddStream(3); + + UnregisterStream(23); + UnregisterStream(1); + + EXPECT_EQ(3u, PopFront()); + EXPECT_EQ(17u, PopFront()); + EXPECT_EQ(12u, PopFront()); + EXPECT_EQ(40, PopFront()); +} + +TEST_F(QuicWriteBlockedListTest, UnregisterNotRegisteredStream) { + EXPECT_QUICHE_BUG(UnregisterStream(1), "Stream 1 not registered"); + + RegisterStream(2, kNotStatic, {kV3HighestPriority, kIncremental}); + UnregisterStream(2); + EXPECT_QUICHE_BUG(UnregisterStream(2), "Stream 2 not registered"); +} + +TEST_F(QuicWriteBlockedListTest, UpdateStreamPriority) { + RegisterStream(40, kNotStatic, {kV3LowestPriority, kNotIncremental}); + RegisterStream(23, kNotStatic, {6, kIncremental}); + RegisterStream(17, kNotStatic, {kV3HighestPriority, kNotIncremental}); + RegisterStream(1, kStatic, {2, kNotIncremental}); + RegisterStream(3, kStatic, {kV3HighestPriority, kNotIncremental}); + + EXPECT_EQ(kV3LowestPriority, GetPriorityOfStream(40).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(40).http().incremental); + + EXPECT_EQ(6, GetPriorityOfStream(23).http().urgency); + EXPECT_EQ(kIncremental, GetPriorityOfStream(23).http().incremental); + + EXPECT_EQ(kV3HighestPriority, GetPriorityOfStream(17).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(17).http().incremental); + + UpdateStreamPriority(40, {3, kIncremental}); + UpdateStreamPriority(23, {kV3HighestPriority, kNotIncremental}); + UpdateStreamPriority(17, {5, kNotIncremental}); + + EXPECT_EQ(3, GetPriorityOfStream(40).http().urgency); + EXPECT_EQ(kIncremental, GetPriorityOfStream(40).http().incremental); + + EXPECT_EQ(kV3HighestPriority, GetPriorityOfStream(23).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(23).http().incremental); + + EXPECT_EQ(5, GetPriorityOfStream(17).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(17).http().incremental); + + AddStream(40); + AddStream(23); + AddStream(17); + AddStream(1); + AddStream(3); + + EXPECT_EQ(1u, PopFront()); + EXPECT_EQ(3u, PopFront()); + EXPECT_EQ(23u, PopFront()); + EXPECT_EQ(40u, PopFront()); + EXPECT_EQ(17u, PopFront()); +} + +// UpdateStreamPriority() must not be called for static streams. +TEST_F(QuicWriteBlockedListTest, UpdateStaticStreamPriority) { + RegisterStream(2, kStatic, {kV3LowestPriority, kNotIncremental}); + EXPECT_QUICHE_DEBUG_DEATH( + UpdateStreamPriority(2, {kV3HighestPriority, kNotIncremental}), + "IsRegistered"); +} + +TEST_F(QuicWriteBlockedListTest, UpdateStreamPrioritySameUrgency) { + // Streams with same urgency are returned by PopFront() in the order they were + // added by AddStream(). + RegisterStream(1, kNotStatic, {6, kNotIncremental}); + RegisterStream(2, kNotStatic, {6, kNotIncremental}); + + AddStream(1); + AddStream(2); + + EXPECT_EQ(1u, PopFront()); + EXPECT_EQ(2u, PopFront()); + + // Calling UpdateStreamPriority() on the first stream does not change the + // order. + RegisterStream(3, kNotStatic, {6, kNotIncremental}); + RegisterStream(4, kNotStatic, {6, kNotIncremental}); + + EXPECT_EQ(6, GetPriorityOfStream(3).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(3).http().incremental); + + UpdateStreamPriority(3, {6, kIncremental}); + + EXPECT_EQ(6, GetPriorityOfStream(3).http().urgency); + EXPECT_EQ(kIncremental, GetPriorityOfStream(3).http().incremental); + + AddStream(3); + AddStream(4); + + EXPECT_EQ(3u, PopFront()); + EXPECT_EQ(4u, PopFront()); + + // Calling UpdateStreamPriority() on the second stream does not change the + // order. + RegisterStream(5, kNotStatic, {6, kIncremental}); + RegisterStream(6, kNotStatic, {6, kIncremental}); + + EXPECT_EQ(6, GetPriorityOfStream(6).http().urgency); + EXPECT_EQ(kIncremental, GetPriorityOfStream(6).http().incremental); + + UpdateStreamPriority(6, {6, kNotIncremental}); + + EXPECT_EQ(6, GetPriorityOfStream(6).http().urgency); + EXPECT_EQ(kNotIncremental, GetPriorityOfStream(6).http().incremental); + + AddStream(5); + AddStream(6); + + EXPECT_EQ(5u, PopFront()); + EXPECT_EQ(6u, PopFront()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/session_notifier_interface.h b/quiche/quic/core/session_notifier_interface.h new file mode 100644 index 000000000000..ee4453a867ac --- /dev/null +++ b/quiche/quic/core/session_notifier_interface.h @@ -0,0 +1,48 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_SESSION_NOTIFIER_INTERFACE_H_ +#define QUICHE_QUIC_CORE_SESSION_NOTIFIER_INTERFACE_H_ + +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_time.h" + +namespace quic { + +// Pure virtual class to be notified when a packet containing a frame is acked +// or lost. +class QUIC_EXPORT_PRIVATE SessionNotifierInterface { + public: + virtual ~SessionNotifierInterface() {} + + // Called when |frame| is acked. Returns true if any new data gets acked, + // returns false otherwise. + virtual bool OnFrameAcked(const QuicFrame& frame, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) = 0; + + // Called when |frame| is retransmitted. + virtual void OnStreamFrameRetransmitted(const QuicStreamFrame& frame) = 0; + + // Called when |frame| is considered as lost. + virtual void OnFrameLost(const QuicFrame& frame) = 0; + + // Called to retransmit |frames| with transmission |type|. Returns true if all + // data gets retransmitted. + virtual bool RetransmitFrames(const QuicFrames& frames, + TransmissionType type) = 0; + + // Returns true if |frame| is outstanding and waiting to be acked. + virtual bool IsFrameOutstanding(const QuicFrame& frame) const = 0; + + // Returns true if crypto stream is waiting for acks. + virtual bool HasUnackedCryptoData() const = 0; + + // Returns true if any stream is waiting for acks. + virtual bool HasUnackedStreamData() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_SESSION_NOTIFIER_INTERFACE_H_ diff --git a/quiche/quic/core/socket_factory.h b/quiche/quic/core/socket_factory.h new file mode 100644 index 000000000000..708ff7d283f2 --- /dev/null +++ b/quiche/quic/core/socket_factory.h @@ -0,0 +1,47 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_SOCKET_FACTORY_H_ +#define QUICHE_QUIC_CORE_SOCKET_FACTORY_H_ + +#include + +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace quic { + +// A factory to create objects of type Socket and derived interfaces. +class QUICHE_EXPORT SocketFactory { + public: + virtual ~SocketFactory() = default; + + // Will use platform default buffer size if `receive_buffer_size` or + // `send_buffer_size` is zero. If `async_visitor` is null, async operations + // must not be called on the created socket. If `async_visitor` is non-null, + // it must outlive the created socket. + virtual std::unique_ptr CreateTcpClientSocket( + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor) = 0; + + // Will use platform default buffer size if `receive_buffer_size` or + // `send_buffer_size` is zero. If `async_visitor` is null, async operations + // must not be called on the created socket. If `async_visitor` is non-null, + // it must outlive the created socket. + // + // TODO(ericorth): Consider creating a sub-interface for connecting UDP + // sockets with additional functionality, e.g. sendto, if needed. + virtual std::unique_ptr + CreateConnectingUdpClientSocket( + const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_SOCKET_FACTORY_H_ diff --git a/quiche/quic/core/stream_delegate_interface.h b/quiche/quic/core/stream_delegate_interface.h new file mode 100644 index 000000000000..a5f03f0f635e --- /dev/null +++ b/quiche/quic/core/stream_delegate_interface.h @@ -0,0 +1,55 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_STREAM_DELEGATE_INTERFACE_H_ +#define QUICHE_QUIC_CORE_STREAM_DELEGATE_INTERFACE_H_ + +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +class QuicStream; + +// Pure virtual class to get notified when particular QuicStream events +// occurred. +class QUIC_EXPORT_PRIVATE StreamDelegateInterface { + public: + virtual ~StreamDelegateInterface() {} + + // Called when the stream has encountered errors that it can't handle. + virtual void OnStreamError(QuicErrorCode error_code, + std::string error_details) = 0; + // Called when the stream has encountered errors that it can't handle, + // specifying the wire error code |ietf_error| explicitly. + virtual void OnStreamError(QuicErrorCode error_code, + QuicIetfTransportErrorCodes ietf_error, + std::string error_details) = 0; + // Called when the stream needs to write data at specified |level| and + // transmission |type|. + virtual QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, + TransmissionType type, + EncryptionLevel level) = 0; + // Called to write crypto data. + virtual size_t SendCryptoData(EncryptionLevel level, size_t write_length, + QuicStreamOffset offset, + TransmissionType type) = 0; + // Called on stream creation. + virtual void RegisterStreamPriority(QuicStreamId id, bool is_static, + const QuicStreamPriority& priority) = 0; + // Called on stream destruction to clear priority. + virtual void UnregisterStreamPriority(QuicStreamId id) = 0; + // Called by the stream on SetPriority to update priority. + virtual void UpdateStreamPriority(QuicStreamId id, + const QuicStreamPriority& new_priority) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_STREAM_DELEGATE_INTERFACE_H_ diff --git a/quiche/quic/core/tls_chlo_extractor.cc b/quiche/quic/core/tls_chlo_extractor.cc new file mode 100644 index 000000000000..9a741d019294 --- /dev/null +++ b/quiche/quic/core/tls_chlo_extractor.cc @@ -0,0 +1,429 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/tls_chlo_extractor.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/frames/quic_crypto_frame.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { + +namespace { +bool HasExtension(const SSL_CLIENT_HELLO* client_hello, uint16_t extension) { + const uint8_t* unused_extension_bytes; + size_t unused_extension_len; + return 1 == SSL_early_callback_ctx_extension_get(client_hello, extension, + &unused_extension_bytes, + &unused_extension_len); +} +} // namespace + +TlsChloExtractor::TlsChloExtractor() + : crypto_stream_sequencer_(this), + state_(State::kInitial), + parsed_crypto_frame_in_this_packet_(false) {} + +TlsChloExtractor::TlsChloExtractor(TlsChloExtractor&& other) + : TlsChloExtractor() { + *this = std::move(other); +} + +TlsChloExtractor& TlsChloExtractor::operator=(TlsChloExtractor&& other) { + framer_ = std::move(other.framer_); + if (framer_) { + framer_->set_visitor(this); + } + crypto_stream_sequencer_ = std::move(other.crypto_stream_sequencer_); + crypto_stream_sequencer_.set_stream(this); + ssl_ = std::move(other.ssl_); + if (ssl_) { + std::pair shared_handles = GetSharedSslHandles(); + int ex_data_index = shared_handles.second; + const int rv = SSL_set_ex_data(ssl_.get(), ex_data_index, this); + QUICHE_CHECK_EQ(rv, 1) << "Internal allocation failure in SSL_set_ex_data"; + } + state_ = other.state_; + error_details_ = std::move(other.error_details_); + parsed_crypto_frame_in_this_packet_ = + other.parsed_crypto_frame_in_this_packet_; + alpns_ = std::move(other.alpns_); + server_name_ = std::move(other.server_name_); + client_hello_bytes_ = std::move(other.client_hello_bytes_); + return *this; +} + +void TlsChloExtractor::IngestPacket(const ParsedQuicVersion& version, + const QuicReceivedPacket& packet) { + if (state_ == State::kUnrecoverableFailure) { + QUIC_DLOG(ERROR) << "Not ingesting packet after unrecoverable error"; + return; + } + if (version == UnsupportedQuicVersion()) { + QUIC_DLOG(ERROR) << "Not ingesting packet with unsupported version"; + return; + } + if (version.handshake_protocol != PROTOCOL_TLS1_3) { + QUIC_DLOG(ERROR) << "Not ingesting packet with non-TLS version " << version; + return; + } + if (framer_) { + // This is not the first packet we have ingested, check if version matches. + if (!framer_->IsSupportedVersion(version)) { + QUIC_DLOG(ERROR) + << "Not ingesting packet with version mismatch, expected " + << framer_->version() << ", got " << version; + return; + } + } else { + // This is the first packet we have ingested, setup parser. + framer_ = std::make_unique( + ParsedQuicVersionVector{version}, QuicTime::Zero(), + Perspective::IS_SERVER, /*expected_server_connection_id_length=*/0); + // Note that expected_server_connection_id_length only matters for short + // headers and we explicitly drop those so we can pass any value here. + framer_->set_visitor(this); + } + + // When the framer parses |packet|, if it sees a CRYPTO frame it will call + // OnCryptoFrame below and that will set parsed_crypto_frame_in_this_packet_ + // to true. + parsed_crypto_frame_in_this_packet_ = false; + const bool parse_success = framer_->ProcessPacket(packet); + if (state_ == State::kInitial && parsed_crypto_frame_in_this_packet_) { + // If we parsed a CRYPTO frame but didn't advance the state from initial, + // then it means that we will need more packets to reassemble the full CHLO, + // so we advance the state here. This can happen when the first packet + // received is not the first one in the crypto stream. This allows us to + // differentiate our state between single-packet CHLO and multi-packet CHLO. + state_ = State::kParsedPartialChloFragment; + } + + if (!parse_success) { + // This could be due to the packet being non-initial for example. + QUIC_DLOG(ERROR) << "Failed to process packet"; + return; + } +} + +// This is called when the framer parsed the unencrypted parts of the header. +bool TlsChloExtractor::OnUnauthenticatedPublicHeader( + const QuicPacketHeader& header) { + if (header.form != IETF_QUIC_LONG_HEADER_PACKET) { + QUIC_DLOG(ERROR) << "Not parsing non-long-header packet " << header; + return false; + } + if (header.long_packet_type != INITIAL) { + QUIC_DLOG(ERROR) << "Not parsing non-initial packet " << header; + return false; + } + // QuicFramer is constructed without knowledge of the server's connection ID + // so it needs to be set up here in order to decrypt the packet. + framer_->SetInitialObfuscators(header.destination_connection_id); + return true; +} + +// This is called by the framer if it detects a change in version during +// parsing. +bool TlsChloExtractor::OnProtocolVersionMismatch(ParsedQuicVersion version) { + // This should never be called because we already check versions in + // IngestPacket. + QUIC_BUG(quic_bug_10855_1) << "Unexpected version mismatch, expected " + << framer_->version() << ", got " << version; + return false; +} + +// This is called by the QuicStreamSequencer if it encounters an unrecoverable +// error that will prevent it from reassembling the crypto stream data. +void TlsChloExtractor::OnUnrecoverableError(QuicErrorCode error, + const std::string& details) { + HandleUnrecoverableError(absl::StrCat( + "Crypto stream error ", QuicErrorCodeToString(error), ": ", details)); +} + +void TlsChloExtractor::OnUnrecoverableError( + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details) { + HandleUnrecoverableError(absl::StrCat( + "Crypto stream error ", QuicErrorCodeToString(error), "(", + QuicIetfTransportErrorCodeString(ietf_error), "): ", details)); +} + +// This is called by the framer if it sees a CRYPTO frame during parsing. +bool TlsChloExtractor::OnCryptoFrame(const QuicCryptoFrame& frame) { + if (frame.level != ENCRYPTION_INITIAL) { + // Since we drop non-INITIAL packets in OnUnauthenticatedPublicHeader, + // we should never receive any CRYPTO frames at other encryption levels. + QUIC_BUG(quic_bug_10855_2) << "Parsed bad-level CRYPTO frame " << frame; + return false; + } + // parsed_crypto_frame_in_this_packet_ is checked in IngestPacket to allow + // advancing our state to track the difference between single-packet CHLO + // and multi-packet CHLO. + parsed_crypto_frame_in_this_packet_ = true; + crypto_stream_sequencer_.OnCryptoFrame(frame); + return true; +} + +// Called by the QuicStreamSequencer when it receives a CRYPTO frame that +// advances the amount of contiguous data we now have starting from offset 0. +void TlsChloExtractor::OnDataAvailable() { + // Lazily set up BoringSSL handle. + SetupSslHandle(); + + // Get data from the stream sequencer and pass it to BoringSSL. + struct iovec iov; + while (crypto_stream_sequencer_.GetReadableRegion(&iov)) { + const int rv = SSL_provide_quic_data( + ssl_.get(), ssl_encryption_initial, + reinterpret_cast(iov.iov_base), iov.iov_len); + if (rv != 1) { + HandleUnrecoverableError("SSL_provide_quic_data failed"); + return; + } + crypto_stream_sequencer_.MarkConsumed(iov.iov_len); + } + + // Instruct BoringSSL to attempt parsing a full CHLO from the provided data. + // We ignore the return value since we know the handshake is going to fail + // because we explicitly cancel processing once we've parsed the CHLO. + (void)SSL_do_handshake(ssl_.get()); +} + +// static +TlsChloExtractor* TlsChloExtractor::GetInstanceFromSSL(SSL* ssl) { + std::pair shared_handles = GetSharedSslHandles(); + int ex_data_index = shared_handles.second; + return reinterpret_cast( + SSL_get_ex_data(ssl, ex_data_index)); +} + +// static +int TlsChloExtractor::SetReadSecretCallback( + SSL* ssl, enum ssl_encryption_level_t /*level*/, + const SSL_CIPHER* /*cipher*/, const uint8_t* /*secret*/, + size_t /*secret_length*/) { + GetInstanceFromSSL(ssl)->HandleUnexpectedCallback("SetReadSecretCallback"); + return 0; +} + +// static +int TlsChloExtractor::SetWriteSecretCallback( + SSL* ssl, enum ssl_encryption_level_t /*level*/, + const SSL_CIPHER* /*cipher*/, const uint8_t* /*secret*/, + size_t /*secret_length*/) { + GetInstanceFromSSL(ssl)->HandleUnexpectedCallback("SetWriteSecretCallback"); + return 0; +} + +// static +int TlsChloExtractor::WriteMessageCallback( + SSL* ssl, enum ssl_encryption_level_t /*level*/, const uint8_t* /*data*/, + size_t /*len*/) { + GetInstanceFromSSL(ssl)->HandleUnexpectedCallback("WriteMessageCallback"); + return 0; +} + +// static +int TlsChloExtractor::FlushFlightCallback(SSL* ssl) { + GetInstanceFromSSL(ssl)->HandleUnexpectedCallback("FlushFlightCallback"); + return 0; +} + +void TlsChloExtractor::HandleUnexpectedCallback( + const std::string& callback_name) { + std::string error_details = + absl::StrCat("Unexpected callback ", callback_name); + QUIC_BUG(quic_bug_10855_3) << error_details; + HandleUnrecoverableError(error_details); +} + +// static +int TlsChloExtractor::SendAlertCallback(SSL* ssl, + enum ssl_encryption_level_t /*level*/, + uint8_t desc) { + GetInstanceFromSSL(ssl)->SendAlert(desc); + return 0; +} + +void TlsChloExtractor::SendAlert(uint8_t tls_alert_value) { + if (tls_alert_value == SSL3_AD_HANDSHAKE_FAILURE && HasParsedFullChlo()) { + // This is the most common scenario. Since we return an error from + // SelectCertCallback in order to cancel further processing, BoringSSL will + // try to send this alert to tell the client that the handshake failed. + return; + } + HandleUnrecoverableError(absl::StrCat( + "BoringSSL attempted to send alert ", static_cast(tls_alert_value), + " ", SSL_alert_desc_string_long(tls_alert_value))); + if (state_ == State::kUnrecoverableFailure) { + tls_alert_ = tls_alert_value; + } +} + +// static +enum ssl_select_cert_result_t TlsChloExtractor::SelectCertCallback( + const SSL_CLIENT_HELLO* client_hello) { + GetInstanceFromSSL(client_hello->ssl)->HandleParsedChlo(client_hello); + // Always return an error to cancel any further processing in BoringSSL. + return ssl_select_cert_error; +} + +// Extracts the server name and ALPN from the parsed ClientHello. +void TlsChloExtractor::HandleParsedChlo(const SSL_CLIENT_HELLO* client_hello) { + const char* server_name = + SSL_get_servername(client_hello->ssl, TLSEXT_NAMETYPE_host_name); + if (server_name) { + server_name_ = std::string(server_name); + } + + resumption_attempted_ = + HasExtension(client_hello, TLSEXT_TYPE_pre_shared_key); + early_data_attempted_ = HasExtension(client_hello, TLSEXT_TYPE_early_data); + + QUICHE_DCHECK(client_hello_bytes_.empty()); + client_hello_bytes_.assign( + client_hello->client_hello, + client_hello->client_hello + client_hello->client_hello_len); + + const uint8_t* alpn_data; + size_t alpn_len; + int rv = SSL_early_callback_ctx_extension_get( + client_hello, TLSEXT_TYPE_application_layer_protocol_negotiation, + &alpn_data, &alpn_len); + if (rv == 1) { + QuicDataReader alpns_reader(reinterpret_cast(alpn_data), + alpn_len); + absl::string_view alpns_payload; + if (!alpns_reader.ReadStringPiece16(&alpns_payload)) { + HandleUnrecoverableError("Failed to read alpns_payload"); + return; + } + QuicDataReader alpns_payload_reader(alpns_payload); + while (!alpns_payload_reader.IsDoneReading()) { + absl::string_view alpn_payload; + if (!alpns_payload_reader.ReadStringPiece8(&alpn_payload)) { + HandleUnrecoverableError("Failed to read alpn_payload"); + return; + } + alpns_.emplace_back(std::string(alpn_payload)); + } + } + + // Update our state now that we've parsed a full CHLO. + if (state_ == State::kInitial) { + state_ = State::kParsedFullSinglePacketChlo; + } else if (state_ == State::kParsedPartialChloFragment) { + state_ = State::kParsedFullMultiPacketChlo; + } else { + QUIC_BUG(quic_bug_10855_4) + << "Unexpected state on successful parse " << StateToString(state_); + } +} + +// static +std::pair TlsChloExtractor::GetSharedSslHandles() { + // Use a lambda to benefit from C++11 guarantee that static variables are + // initialized lazily in a thread-safe manner. |shared_handles| is therefore + // guaranteed to be initialized exactly once and never destructed. + static std::pair* shared_handles = []() { + CRYPTO_library_init(); + SSL_CTX* ssl_ctx = SSL_CTX_new(TLS_with_buffers_method()); + SSL_CTX_set_min_proto_version(ssl_ctx, TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ssl_ctx, TLS1_3_VERSION); + static const SSL_QUIC_METHOD kQuicCallbacks{ + TlsChloExtractor::SetReadSecretCallback, + TlsChloExtractor::SetWriteSecretCallback, + TlsChloExtractor::WriteMessageCallback, + TlsChloExtractor::FlushFlightCallback, + TlsChloExtractor::SendAlertCallback}; + SSL_CTX_set_quic_method(ssl_ctx, &kQuicCallbacks); + SSL_CTX_set_select_certificate_cb(ssl_ctx, + TlsChloExtractor::SelectCertCallback); + int ex_data_index = + SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + return new std::pair(ssl_ctx, ex_data_index); + }(); + return *shared_handles; +} + +// Sets up the per-instance SSL handle needed by BoringSSL. +void TlsChloExtractor::SetupSslHandle() { + if (ssl_) { + // Handles have already been set up. + return; + } + + std::pair shared_handles = GetSharedSslHandles(); + SSL_CTX* ssl_ctx = shared_handles.first; + int ex_data_index = shared_handles.second; + + ssl_ = bssl::UniquePtr(SSL_new(ssl_ctx)); + const int rv = SSL_set_ex_data(ssl_.get(), ex_data_index, this); + QUICHE_CHECK_EQ(rv, 1) << "Internal allocation failure in SSL_set_ex_data"; + SSL_set_accept_state(ssl_.get()); + + // Make sure we use the right TLS extension codepoint. + int use_legacy_extension = 0; + if (framer_->version().UsesLegacyTlsExtension()) { + use_legacy_extension = 1; + } + SSL_set_quic_use_legacy_codepoint(ssl_.get(), use_legacy_extension); +} + +// Called by other methods to record any unrecoverable failures they experience. +void TlsChloExtractor::HandleUnrecoverableError( + const std::string& error_details) { + if (HasParsedFullChlo()) { + // Ignore errors if we've parsed everything successfully. + QUIC_DLOG(ERROR) << "Ignoring error: " << error_details; + return; + } + QUIC_DLOG(ERROR) << "Handling error: " << error_details; + + state_ = State::kUnrecoverableFailure; + + if (error_details_.empty()) { + error_details_ = error_details; + } else { + error_details_ = absl::StrCat(error_details_, "; ", error_details); + } +} + +// static +std::string TlsChloExtractor::StateToString(State state) { + switch (state) { + case State::kInitial: + return "Initial"; + case State::kParsedFullSinglePacketChlo: + return "ParsedFullSinglePacketChlo"; + case State::kParsedFullMultiPacketChlo: + return "ParsedFullMultiPacketChlo"; + case State::kParsedPartialChloFragment: + return "ParsedPartialChloFragment"; + case State::kUnrecoverableFailure: + return "UnrecoverableFailure"; + } + return absl::StrCat("Unknown(", static_cast(state), ")"); +} + +std::ostream& operator<<(std::ostream& os, + const TlsChloExtractor::State& state) { + os << TlsChloExtractor::StateToString(state); + return os; +} + +} // namespace quic diff --git a/quiche/quic/core/tls_chlo_extractor.h b/quiche/quic/core/tls_chlo_extractor.h new file mode 100644 index 000000000000..d8c6b5594d29 --- /dev/null +++ b/quiche/quic/core/tls_chlo_extractor.h @@ -0,0 +1,280 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_TLS_CHLO_EXTRACTOR_H_ +#define QUICHE_QUIC_CORE_TLS_CHLO_EXTRACTOR_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_sequencer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// Utility class that allows extracting information from a QUIC-TLS Client +// Hello. This class creates a QuicFramer to parse the packet, and implements +// QuicFramerVisitorInterface to access the frames parsed by the QuicFramer. It +// then uses a QuicStreamSequencer to reassemble the contents of the crypto +// stream, and implements QuicStreamSequencer::StreamInterface to access the +// reassembled data. +class QUIC_NO_EXPORT TlsChloExtractor + : public QuicFramerVisitorInterface, + public QuicStreamSequencer::StreamInterface { + public: + TlsChloExtractor(); + TlsChloExtractor(const TlsChloExtractor&) = delete; + TlsChloExtractor(TlsChloExtractor&&); + TlsChloExtractor& operator=(const TlsChloExtractor&) = delete; + TlsChloExtractor& operator=(TlsChloExtractor&&); + + enum class State : uint8_t { + kInitial = 0, + kParsedFullSinglePacketChlo = 1, + kParsedFullMultiPacketChlo = 2, + kParsedPartialChloFragment = 3, + kUnrecoverableFailure = 4, + }; + + State state() const { return state_; } + std::vector alpns() const { return alpns_; } + std::string server_name() const { return server_name_; } + bool resumption_attempted() const { return resumption_attempted_; } + bool early_data_attempted() const { return early_data_attempted_; } + absl::Span client_hello_bytes() const { + return client_hello_bytes_; + } + + // Converts |state| to a human-readable string suitable for logging. + static std::string StateToString(State state); + + // Ingests |packet| and attempts to parse out the CHLO. + void IngestPacket(const ParsedQuicVersion& version, + const QuicReceivedPacket& packet); + + // Returns whether the ingested packets have allowed parsing a complete CHLO. + bool HasParsedFullChlo() const { + return state_ == State::kParsedFullSinglePacketChlo || + state_ == State::kParsedFullMultiPacketChlo; + } + + // Returns the TLS alert that caused the unrecoverable error, if any. + absl::optional tls_alert() const { + QUICHE_DCHECK(!tls_alert_.has_value() || + state_ == State::kUnrecoverableFailure); + return tls_alert_; + } + + // Methods from QuicFramerVisitorInterface. + void OnError(QuicFramer* /*framer*/) override {} + bool OnProtocolVersionMismatch(ParsedQuicVersion version) override; + void OnPacket() override {} + void OnPublicResetPacket(const QuicPublicResetPacket& /*packet*/) override {} + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& /*packet*/) override {} + void OnRetryPacket(QuicConnectionId /*original_connection_id*/, + QuicConnectionId /*new_connection_id*/, + absl::string_view /*retry_token*/, + absl::string_view /*retry_integrity_tag*/, + absl::string_view /*retry_without_tag*/) override {} + bool OnUnauthenticatedPublicHeader(const QuicPacketHeader& header) override; + bool OnUnauthenticatedHeader(const QuicPacketHeader& /*header*/) override { + return true; + } + void OnDecryptedPacket(size_t /*packet_length*/, + EncryptionLevel /*level*/) override {} + bool OnPacketHeader(const QuicPacketHeader& /*header*/) override { + return true; + } + void OnCoalescedPacket(const QuicEncryptedPacket& /*packet*/) override {} + void OnUndecryptablePacket(const QuicEncryptedPacket& /*packet*/, + EncryptionLevel /*decryption_level*/, + bool /*has_decryption_key*/) override {} + bool OnStreamFrame(const QuicStreamFrame& /*frame*/) override { return true; } + bool OnCryptoFrame(const QuicCryptoFrame& frame) override; + bool OnAckFrameStart(QuicPacketNumber /*largest_acked*/, + QuicTime::Delta /*ack_delay_time*/) override { + return true; + } + bool OnAckRange(QuicPacketNumber /*start*/, + QuicPacketNumber /*end*/) override { + return true; + } + bool OnAckTimestamp(QuicPacketNumber /*packet_number*/, + QuicTime /*timestamp*/) override { + return true; + } + bool OnAckFrameEnd( + QuicPacketNumber /*start*/, + const absl::optional& /*ecn_counts*/) override { + return true; + } + bool OnStopWaitingFrame(const QuicStopWaitingFrame& /*frame*/) override { + return true; + } + bool OnPingFrame(const QuicPingFrame& /*frame*/) override { return true; } + bool OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) override { + return true; + } + bool OnConnectionCloseFrame( + const QuicConnectionCloseFrame& /*frame*/) override { + return true; + } + bool OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& /*frame*/) override { + return true; + } + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& /*frame*/) override { + return true; + } + bool OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) override { + return true; + } + bool OnStopSendingFrame(const QuicStopSendingFrame& /*frame*/) override { + return true; + } + bool OnPathChallengeFrame(const QuicPathChallengeFrame& /*frame*/) override { + return true; + } + bool OnPathResponseFrame(const QuicPathResponseFrame& /*frame*/) override { + return true; + } + bool OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) override { return true; } + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& /*frame*/) override { + return true; + } + bool OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& /*frame*/) override { + return true; + } + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& /*frame*/) override { + return true; + } + bool OnBlockedFrame(const QuicBlockedFrame& /*frame*/) override { + return true; + } + bool OnPaddingFrame(const QuicPaddingFrame& /*frame*/) override { + return true; + } + bool OnMessageFrame(const QuicMessageFrame& /*frame*/) override { + return true; + } + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& /*frame*/) override { + return true; + } + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& /*frame*/) override { + return true; + } + void OnPacketComplete() override {} + bool IsValidStatelessResetToken( + const StatelessResetToken& /*token*/) const override { + return true; + } + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& /*packet*/) override {} + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} + void OnDecryptedFirstPacketInKeyPhase() override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + + // Methods from QuicStreamSequencer::StreamInterface. + void OnDataAvailable() override; + void OnFinRead() override {} + void AddBytesConsumed(QuicByteCount /*bytes*/) override {} + void ResetWithError(QuicResetStreamError /*error*/) override {} + void OnUnrecoverableError(QuicErrorCode error, + const std::string& details) override; + void OnUnrecoverableError(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) override; + QuicStreamId id() const override { return 0; } + ParsedQuicVersion version() const override { return framer_->version(); } + + private: + // Parses the length of the CHLO message by looking at the first four bytes. + // Returns whether we have received enough data to parse the full CHLO now. + bool MaybeAttemptToParseChloLength(); + // Parses the full CHLO message if enough data has been received. + void AttemptToParseFullChlo(); + // Moves to the failed state and records the error details. + void HandleUnrecoverableError(const std::string& error_details); + // Lazily sets up shared SSL handles if needed. + static std::pair GetSharedSslHandles(); + // Lazily sets up the per-instance SSL handle if needed. + void SetupSslHandle(); + // Extract the TlsChloExtractor instance from |ssl|. + static TlsChloExtractor* GetInstanceFromSSL(SSL* ssl); + + // BoringSSL static TLS callbacks. + static enum ssl_select_cert_result_t SelectCertCallback( + const SSL_CLIENT_HELLO* client_hello); + static int SetReadSecretCallback(SSL* ssl, enum ssl_encryption_level_t level, + const SSL_CIPHER* cipher, + const uint8_t* secret, size_t secret_length); + static int SetWriteSecretCallback(SSL* ssl, enum ssl_encryption_level_t level, + const SSL_CIPHER* cipher, + const uint8_t* secret, + size_t secret_length); + static int WriteMessageCallback(SSL* ssl, enum ssl_encryption_level_t level, + const uint8_t* data, size_t len); + static int FlushFlightCallback(SSL* ssl); + static int SendAlertCallback(SSL* ssl, enum ssl_encryption_level_t level, + uint8_t desc); + + // Called by SelectCertCallback. + void HandleParsedChlo(const SSL_CLIENT_HELLO* client_hello); + // Called by callbacks that should never be called. + void HandleUnexpectedCallback(const std::string& callback_name); + // Called by SendAlertCallback. + void SendAlert(uint8_t tls_alert_value); + + // Used to parse received packets to extract single frames. + std::unique_ptr framer_; + // Used to reassemble the crypto stream from received CRYPTO frames. + QuicStreamSequencer crypto_stream_sequencer_; + // BoringSSL handle required to parse the CHLO. + bssl::UniquePtr ssl_; + // State of this TlsChloExtractor. + State state_; + // Detail string that can be logged in the presence of unrecoverable errors. + std::string error_details_; + // Whether a CRYPTO frame was parsed in this packet. + bool parsed_crypto_frame_in_this_packet_; + // Array of ALPNs parsed from the CHLO. + std::vector alpns_; + // SNI parsed from the CHLO. + std::string server_name_; + // Whether resumption is attempted from the CHLO, indicated by the + // 'pre_shared_key' TLS extension. + bool resumption_attempted_ = false; + // Whether early data is attempted from the CHLO, indicated by the + // 'early_data' TLS extension. + bool early_data_attempted_ = false; + // If set, contains the TLS alert that caused an unrecoverable error, which is + // an AlertDescription value defined in go/rfc/8446#appendix-B.2. + absl::optional tls_alert_; + // Exact TLS message bytes. + std::vector client_hello_bytes_; +}; + +// Convenience method to facilitate logging TlsChloExtractor::State. +QUIC_NO_EXPORT std::ostream& operator<<(std::ostream& os, + const TlsChloExtractor::State& state); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_TLS_CHLO_EXTRACTOR_H_ diff --git a/quiche/quic/core/tls_chlo_extractor_test.cc b/quiche/quic/core/tls_chlo_extractor_test.cc new file mode 100644 index 000000000000..44a1e634856b --- /dev/null +++ b/quiche/quic/core/tls_chlo_extractor_test.cc @@ -0,0 +1,291 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/tls_chlo_extractor.h" + +#include + +#include "openssl/ssl.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/first_flight.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_session_cache.h" + +namespace quic { +namespace test { +namespace { + +using testing::_; +using testing::AnyNumber; + +class TlsChloExtractorTest : public QuicTestWithParam { + protected: + TlsChloExtractorTest() : version_(GetParam()), server_id_(TestServerId()) {} + + void Initialize() { + AnnotatedPackets packets = + GetAnnotatedFirstFlightOfPackets(version_, config_); + packets_ = std::move(packets.packets); + crypto_stream_size_ = packets.crypto_stream_size; + } + + void Initialize(std::unique_ptr crypto_config) { + AnnotatedPackets packets = GetAnnotatedFirstFlightOfPackets( + version_, config_, TestConnectionId(), EmptyQuicConnectionId(), + std::move(crypto_config)); + packets_ = std::move(packets.packets); + crypto_stream_size_ = packets.crypto_stream_size; + } + + // Perform a full handshake in order to insert a SSL_SESSION into + // crypto_config->session_cache(), which can be used by a TLS resumption. + void PerformFullHandshake(QuicCryptoClientConfig* crypto_config) const { + ASSERT_NE(crypto_config->session_cache(), nullptr); + MockQuicConnectionHelper client_helper, server_helper; + MockAlarmFactory alarm_factory; + ParsedQuicVersionVector supported_versions = {version_}; + PacketSavingConnection* client_connection = + new PacketSavingConnection(&client_helper, &alarm_factory, + Perspective::IS_CLIENT, supported_versions); + // Advance the time, because timers do not like uninitialized times. + client_connection->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + QuicClientPushPromiseIndex push_promise_index; + QuicSpdyClientSession client_session(config_, supported_versions, + client_connection, server_id_, + crypto_config, &push_promise_index); + client_session.Initialize(); + + std::unique_ptr server_crypto_config = + crypto_test_utils::CryptoServerConfigForTesting(); + QuicConfig server_config; + + EXPECT_CALL(*client_connection, SendCryptoData(_, _, _)).Times(AnyNumber()); + client_session.GetMutableCryptoStream()->CryptoConnect(); + + crypto_test_utils::HandshakeWithFakeServer( + &server_config, server_crypto_config.get(), &server_helper, + &alarm_factory, client_connection, + client_session.GetMutableCryptoStream(), + AlpnForVersion(client_connection->version())); + + // For some reason, the test client can not receive the server settings and + // the SSL_SESSION will not be inserted to client's session_cache. We create + // a dummy settings and call SetServerApplicationStateForResumption manually + // to ensure the SSL_SESSION is cached. + // TODO(wub): Fix crypto_test_utils::HandshakeWithFakeServer to make sure a + // SSL_SESSION is cached at the client, and remove the rest of the function. + SettingsFrame server_settings; + server_settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = + kDefaultQpackMaxDynamicTableCapacity; + std::string settings_frame = + HttpEncoder::SerializeSettingsFrame(server_settings); + client_session.GetMutableCryptoStream() + ->SetServerApplicationStateForResumption( + std::make_unique( + settings_frame.data(), + settings_frame.data() + settings_frame.length())); + } + + void IngestPackets() { + for (const std::unique_ptr& packet : packets_) { + ReceivedPacketInfo packet_info( + QuicSocketAddress(TestPeerIPAddress(), kTestPort), + QuicSocketAddress(TestPeerIPAddress(), kTestPort), *packet); + std::string detailed_error; + absl::optional retry_token; + const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( + *packet, /*expected_destination_connection_id_length=*/0, + &packet_info.form, &packet_info.long_packet_type, + &packet_info.version_flag, &packet_info.use_length_prefix, + &packet_info.version_label, &packet_info.version, + &packet_info.destination_connection_id, + &packet_info.source_connection_id, &retry_token, &detailed_error); + ASSERT_THAT(error, IsQuicNoError()) << detailed_error; + tls_chlo_extractor_.IngestPacket(packet_info.version, packet_info.packet); + } + packets_.clear(); + } + + void ValidateChloDetails(const TlsChloExtractor* extractor = nullptr) const { + if (extractor == nullptr) { + extractor = &tls_chlo_extractor_; + } + + EXPECT_TRUE(extractor->HasParsedFullChlo()); + std::vector alpns = extractor->alpns(); + ASSERT_EQ(alpns.size(), 1u); + EXPECT_EQ(alpns[0], AlpnForVersion(version_)); + EXPECT_EQ(extractor->server_name(), TestHostname()); + // Crypto stream has one frame in the following format: + // CRYPTO Frame { + // Type (i) = 0x06, + // Offset (i), + // Length (i), + // Crypto Data (..), + // } + // + // Type is 1 byte long, Offset is zero and also 1 byte long, and + // all generated ClientHello messages have 2 byte length. So + // the header is 4 bytes total. + EXPECT_EQ(extractor->client_hello_bytes().size(), crypto_stream_size_ - 4); + } + + void IncreaseSizeOfChlo() { + // Add a 2000-byte custom parameter to increase the length of the CHLO. + constexpr auto kCustomParameterId = + static_cast(0xff33); + std::string kCustomParameterValue(2000, '-'); + config_.custom_transport_parameters_to_send()[kCustomParameterId] = + kCustomParameterValue; + } + + ParsedQuicVersion version_; + QuicServerId server_id_; + TlsChloExtractor tls_chlo_extractor_; + QuicConfig config_; + std::vector> packets_; + uint64_t crypto_stream_size_; +}; + +INSTANTIATE_TEST_SUITE_P(TlsChloExtractorTests, TlsChloExtractorTest, + ::testing::ValuesIn(AllSupportedVersionsWithTls()), + ::testing::PrintToStringParamName()); + +TEST_P(TlsChloExtractorTest, Simple) { + Initialize(); + EXPECT_EQ(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); + EXPECT_FALSE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ResumptionOnly) { + auto crypto_client_config = std::make_unique( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique()); + PerformFullHandshake(crypto_client_config.get()); + + SSL_CTX_set_early_data_enabled(crypto_client_config->ssl_ctx(), 0); + Initialize(std::move(crypto_client_config)); + EXPECT_GE(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); + EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ZeroRtt) { + auto crypto_client_config = std::make_unique( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique()); + PerformFullHandshake(crypto_client_config.get()); + + IncreaseSizeOfChlo(); + Initialize(std::move(crypto_client_config)); + EXPECT_GE(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullMultiPacketChlo); + EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_TRUE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, MultiPacket) { + IncreaseSizeOfChlo(); + Initialize(); + EXPECT_EQ(packets_.size(), 2u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullMultiPacketChlo); +} + +TEST_P(TlsChloExtractorTest, MultiPacketReordered) { + IncreaseSizeOfChlo(); + Initialize(); + ASSERT_EQ(packets_.size(), 2u); + // Artifically reorder both packets. + std::swap(packets_[0], packets_[1]); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullMultiPacketChlo); +} + +TEST_P(TlsChloExtractorTest, MoveAssignment) { + Initialize(); + EXPECT_EQ(packets_.size(), 1u); + TlsChloExtractor other_extractor; + tls_chlo_extractor_ = std::move(other_extractor); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); +} + +TEST_P(TlsChloExtractorTest, MoveAssignmentAfterExtraction) { + Initialize(); + EXPECT_EQ(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); + + TlsChloExtractor other_extractor = std::move(tls_chlo_extractor_); + + EXPECT_EQ(other_extractor.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); + ValidateChloDetails(&other_extractor); +} + +TEST_P(TlsChloExtractorTest, MoveAssignmentBetweenPackets) { + IncreaseSizeOfChlo(); + Initialize(); + ASSERT_EQ(packets_.size(), 2u); + TlsChloExtractor other_extractor; + + // Have |other_extractor| parse the first packet. + ReceivedPacketInfo packet_info( + QuicSocketAddress(TestPeerIPAddress(), kTestPort), + QuicSocketAddress(TestPeerIPAddress(), kTestPort), *packets_[0]); + std::string detailed_error; + absl::optional retry_token; + const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( + *packets_[0], /*expected_destination_connection_id_length=*/0, + &packet_info.form, &packet_info.long_packet_type, + &packet_info.version_flag, &packet_info.use_length_prefix, + &packet_info.version_label, &packet_info.version, + &packet_info.destination_connection_id, &packet_info.source_connection_id, + &retry_token, &detailed_error); + ASSERT_THAT(error, IsQuicNoError()) << detailed_error; + other_extractor.IngestPacket(packet_info.version, packet_info.packet); + // Remove the first packet from the list. + packets_.erase(packets_.begin()); + EXPECT_EQ(packets_.size(), 1u); + + // Move the extractor. + tls_chlo_extractor_ = std::move(other_extractor); + + // Have |tls_chlo_extractor_| parse the second packet. + IngestPackets(); + + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullMultiPacketChlo); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/tls_client_handshaker.cc b/quiche/quic/core/tls_client_handshaker.cc new file mode 100644 index 000000000000..74537f754dff --- /dev/null +++ b/quiche/quic/core/tls_client_handshaker.cc @@ -0,0 +1,665 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/tls_client_handshaker.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_hostname_utils.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { + +TlsClientHandshaker::TlsClientHandshaker( + const QuicServerId& server_id, QuicCryptoStream* stream, + QuicSession* session, std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, + QuicCryptoClientStream::ProofHandler* proof_handler, + bool has_application_state) + : TlsHandshaker(stream, session), + session_(session), + server_id_(server_id), + proof_verifier_(crypto_config->proof_verifier()), + verify_context_(std::move(verify_context)), + proof_handler_(proof_handler), + session_cache_(crypto_config->session_cache()), + user_agent_id_(crypto_config->user_agent_id()), + pre_shared_key_(crypto_config->pre_shared_key()), + crypto_negotiated_params_(new QuicCryptoNegotiatedParameters), + has_application_state_(has_application_state), + tls_connection_(crypto_config->ssl_ctx(), this, session->GetSSLConfig()) { + if (crypto_config->tls_signature_algorithms().has_value()) { + SSL_set1_sigalgs_list(ssl(), + crypto_config->tls_signature_algorithms()->c_str()); + } + if (crypto_config->proof_source() != nullptr) { + const ClientProofSource::CertAndKey* cert_and_key = + crypto_config->proof_source()->GetCertAndKey(server_id.host()); + if (cert_and_key != nullptr) { + QUIC_DVLOG(1) << "Setting client cert and key for " << server_id.host(); + tls_connection_.SetCertChain(cert_and_key->chain->ToCryptoBuffers().value, + cert_and_key->private_key.private_key()); + } + } +} + +TlsClientHandshaker::~TlsClientHandshaker() {} + +bool TlsClientHandshaker::CryptoConnect() { + if (!pre_shared_key_.empty()) { + // TODO(b/154162689) add PSK support to QUIC+TLS. + std::string error_details = + "QUIC client pre-shared keys not yet supported with TLS"; + QUIC_BUG(quic_bug_10576_1) << error_details; + CloseConnection(QUIC_HANDSHAKE_FAILED, error_details); + return false; + } + + // Make sure we use the right TLS extension codepoint. + int use_legacy_extension = 0; + if (session()->version().UsesLegacyTlsExtension()) { + use_legacy_extension = 1; + } + SSL_set_quic_use_legacy_codepoint(ssl(), use_legacy_extension); + + // TODO(b/193650832) Add SetFromConfig to QUIC handshakers and remove reliance + // on session pointer. +#if BORINGSSL_API_VERSION >= 16 + // Ask BoringSSL to randomize the order of TLS extensions. + SSL_set_permute_extensions(ssl(), true); +#endif // BORINGSSL_API_VERSION + + // Set the SNI to send, if any. + SSL_set_connect_state(ssl()); + if (QUIC_DLOG_INFO_IS_ON() && + !QuicHostnameUtils::IsValidSNI(server_id_.host())) { + QUIC_DLOG(INFO) << "Client configured with invalid hostname \"" + << server_id_.host() << "\", not sending as SNI"; + } + if (!server_id_.host().empty() && + (QuicHostnameUtils::IsValidSNI(server_id_.host()) || + allow_invalid_sni_for_tests_) && + SSL_set_tlsext_host_name(ssl(), server_id_.host().c_str()) != 1) { + return false; + } + + if (!SetAlpn()) { + CloseConnection(QUIC_HANDSHAKE_FAILED, "Client failed to set ALPN"); + return false; + } + + // Set the Transport Parameters to send in the ClientHello + if (!SetTransportParameters()) { + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Client failed to set Transport Parameters"); + return false; + } + + // Set a session to resume, if there is one. + if (session_cache_) { + cached_state_ = session_cache_->Lookup( + server_id_, session()->GetClock()->WallNow(), SSL_get_SSL_CTX(ssl())); + } + if (cached_state_) { + SSL_set_session(ssl(), cached_state_->tls_session.get()); + if (!cached_state_->token.empty()) { + session()->SetSourceAddressTokenToSend(cached_state_->token); + } + } + + SSL_set_enable_ech_grease(ssl(), + tls_connection_.ssl_config().ech_grease_enabled); + if (!tls_connection_.ssl_config().ech_config_list.empty() && + !SSL_set1_ech_config_list( + ssl(), + reinterpret_cast( + tls_connection_.ssl_config().ech_config_list.data()), + tls_connection_.ssl_config().ech_config_list.size())) { + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Client failed to set ECHConfigList"); + return false; + } + + // Start the handshake. + AdvanceHandshake(); + return session()->connection()->connected(); +} + +bool TlsClientHandshaker::PrepareZeroRttConfig( + QuicResumptionState* cached_state) { + std::string error_details; + if (!cached_state->transport_params || + handshaker_delegate()->ProcessTransportParameters( + *(cached_state->transport_params), + /*is_resumption = */ true, &error_details) != QUIC_NO_ERROR) { + QUIC_BUG(quic_bug_10576_2) + << "Unable to parse cached transport parameters."; + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Client failed to parse cached Transport Parameters."); + return false; + } + + session()->connection()->OnTransportParametersResumed( + *(cached_state->transport_params)); + session()->OnConfigNegotiated(); + + if (has_application_state_) { + if (!cached_state->application_state || + !session()->ResumeApplicationState( + cached_state->application_state.get())) { + QUIC_BUG(quic_bug_10576_3) << "Unable to parse cached application state."; + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Client failed to parse cached application state."); + return false; + } + } + return true; +} + +static bool IsValidAlpn(const std::string& alpn_string) { + return alpn_string.length() <= std::numeric_limits::max(); +} + +bool TlsClientHandshaker::SetAlpn() { + std::vector alpns = session()->GetAlpnsToOffer(); + if (alpns.empty()) { + if (allow_empty_alpn_for_tests_) { + return true; + } + + QUIC_BUG(quic_bug_10576_4) << "ALPN missing"; + return false; + } + if (!std::all_of(alpns.begin(), alpns.end(), IsValidAlpn)) { + QUIC_BUG(quic_bug_10576_5) << "ALPN too long"; + return false; + } + + // SSL_set_alpn_protos expects a sequence of one-byte-length-prefixed + // strings. + uint8_t alpn[1024]; + QuicDataWriter alpn_writer(sizeof(alpn), reinterpret_cast(alpn)); + bool success = true; + for (const std::string& alpn_string : alpns) { + success = success && alpn_writer.WriteUInt8(alpn_string.size()) && + alpn_writer.WriteStringPiece(alpn_string); + } + success = + success && (SSL_set_alpn_protos(ssl(), alpn, alpn_writer.length()) == 0); + if (!success) { + QUIC_BUG(quic_bug_10576_6) + << "Failed to set ALPN: " + << quiche::QuicheTextUtils::HexDump( + absl::string_view(alpn_writer.data(), alpn_writer.length())); + return false; + } + + // Enable ALPS only for versions that use HTTP/3 frames. + for (const std::string& alpn_string : alpns) { + for (const ParsedQuicVersion& version : session()->supported_versions()) { + if (!version.UsesHttp3() || AlpnForVersion(version) != alpn_string) { + continue; + } + if (SSL_add_application_settings( + ssl(), reinterpret_cast(alpn_string.data()), + alpn_string.size(), nullptr, /* settings_len = */ 0) != 1) { + QUIC_BUG(quic_bug_10576_7) << "Failed to enable ALPS."; + return false; + } + break; + } + } + + QUIC_DLOG(INFO) << "Client using ALPN: '" << alpns[0] << "'"; + return true; +} + +bool TlsClientHandshaker::SetTransportParameters() { + TransportParameters params; + params.perspective = Perspective::IS_CLIENT; + params.legacy_version_information = + TransportParameters::LegacyVersionInformation(); + params.legacy_version_information.value().version = + CreateQuicVersionLabel(session()->supported_versions().front()); + params.version_information = TransportParameters::VersionInformation(); + const QuicVersionLabel version = CreateQuicVersionLabel(session()->version()); + params.version_information.value().chosen_version = version; + params.version_information.value().other_versions.push_back(version); + + if (!handshaker_delegate()->FillTransportParameters(¶ms)) { + return false; + } + + // Notify QuicConnectionDebugVisitor. + session()->connection()->OnTransportParametersSent(params); + + std::vector param_bytes; + return SerializeTransportParameters(params, ¶m_bytes) && + SSL_set_quic_transport_params(ssl(), param_bytes.data(), + param_bytes.size()) == 1; +} + +bool TlsClientHandshaker::ProcessTransportParameters( + std::string* error_details) { + received_transport_params_ = std::make_unique(); + const uint8_t* param_bytes; + size_t param_bytes_len; + SSL_get_peer_quic_transport_params(ssl(), ¶m_bytes, ¶m_bytes_len); + if (param_bytes_len == 0) { + *error_details = "Server's transport parameters are missing"; + return false; + } + std::string parse_error_details; + if (!ParseTransportParameters( + session()->connection()->version(), Perspective::IS_SERVER, + param_bytes, param_bytes_len, received_transport_params_.get(), + &parse_error_details)) { + QUICHE_DCHECK(!parse_error_details.empty()); + *error_details = + "Unable to parse server's transport parameters: " + parse_error_details; + return false; + } + + // Notify QuicConnectionDebugVisitor. + session()->connection()->OnTransportParametersReceived( + *received_transport_params_); + + if (received_transport_params_->legacy_version_information.has_value()) { + if (received_transport_params_->legacy_version_information.value() + .version != + CreateQuicVersionLabel(session()->connection()->version())) { + *error_details = "Version mismatch detected"; + return false; + } + if (CryptoUtils::ValidateServerHelloVersions( + received_transport_params_->legacy_version_information.value() + .supported_versions, + session()->connection()->server_supported_versions(), + error_details) != QUIC_NO_ERROR) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + } + if (received_transport_params_->version_information.has_value()) { + if (!CryptoUtils::ValidateChosenVersion( + received_transport_params_->version_information.value() + .chosen_version, + session()->version(), error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + if (!CryptoUtils::CryptoUtils::ValidateServerVersions( + received_transport_params_->version_information.value() + .other_versions, + session()->version(), + session()->client_original_supported_versions(), error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + } + + if (handshaker_delegate()->ProcessTransportParameters( + *received_transport_params_, /* is_resumption = */ false, + error_details) != QUIC_NO_ERROR) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + + session()->OnConfigNegotiated(); + if (is_connection_closed()) { + *error_details = + "Session closed the connection when parsing negotiated config."; + return false; + } + return true; +} + +int TlsClientHandshaker::num_sent_client_hellos() const { return 0; } + +bool TlsClientHandshaker::IsResumption() const { + QUIC_BUG_IF(quic_bug_12736_1, !one_rtt_keys_available()); + return SSL_session_reused(ssl()) == 1; +} + +bool TlsClientHandshaker::EarlyDataAccepted() const { + QUIC_BUG_IF(quic_bug_12736_2, !one_rtt_keys_available()); + return SSL_early_data_accepted(ssl()) == 1; +} + +ssl_early_data_reason_t TlsClientHandshaker::EarlyDataReason() const { + return TlsHandshaker::EarlyDataReason(); +} + +bool TlsClientHandshaker::ReceivedInchoateReject() const { + QUIC_BUG_IF(quic_bug_12736_3, !one_rtt_keys_available()); + // REJ messages are a QUIC crypto feature, so TLS always returns false. + return false; +} + +int TlsClientHandshaker::num_scup_messages_received() const { + // SCUP messages aren't sent or received when using the TLS handshake. + return 0; +} + +std::string TlsClientHandshaker::chlo_hash() const { return ""; } + +bool TlsClientHandshaker::ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + return ExportKeyingMaterialForLabel(label, context, result_len, result); +} + +bool TlsClientHandshaker::encryption_established() const { + return encryption_established_; +} + +bool TlsClientHandshaker::IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const { + return level != ENCRYPTION_ZERO_RTT; +} + +EncryptionLevel TlsClientHandshaker::GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } +} + +bool TlsClientHandshaker::one_rtt_keys_available() const { + return state_ >= HANDSHAKE_COMPLETE; +} + +const QuicCryptoNegotiatedParameters& +TlsClientHandshaker::crypto_negotiated_params() const { + return *crypto_negotiated_params_; +} + +CryptoMessageParser* TlsClientHandshaker::crypto_message_parser() { + return TlsHandshaker::crypto_message_parser(); +} + +HandshakeState TlsClientHandshaker::GetHandshakeState() const { return state_; } + +size_t TlsClientHandshaker::BufferSizeLimitForLevel( + EncryptionLevel level) const { + return TlsHandshaker::BufferSizeLimitForLevel(level); +} + +std::unique_ptr +TlsClientHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr +TlsClientHandshaker::CreateCurrentOneRttEncrypter() { + return TlsHandshaker::CreateCurrentOneRttEncrypter(); +} + +void TlsClientHandshaker::OnOneRttPacketAcknowledged() { + OnHandshakeConfirmed(); +} + +void TlsClientHandshaker::OnHandshakePacketSent() { + if (initial_keys_dropped_) { + return; + } + initial_keys_dropped_ = true; + handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); + handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_INITIAL); +} + +void TlsClientHandshaker::OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) { + TlsHandshaker::OnConnectionClosed(error, source); +} + +void TlsClientHandshaker::OnHandshakeDoneReceived() { + if (!one_rtt_keys_available()) { + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Unexpected handshake done received"); + return; + } + OnHandshakeConfirmed(); +} + +void TlsClientHandshaker::OnNewTokenReceived(absl::string_view token) { + if (token.empty()) { + return; + } + if (session_cache_ != nullptr) { + session_cache_->OnNewTokenReceived(server_id_, token); + } +} + +void TlsClientHandshaker::SetWriteSecret( + EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span write_secret) { + if (is_connection_closed()) { + return; + } + if (level == ENCRYPTION_FORWARD_SECURE || level == ENCRYPTION_ZERO_RTT) { + encryption_established_ = true; + } + TlsHandshaker::SetWriteSecret(level, cipher, write_secret); + if (level == ENCRYPTION_FORWARD_SECURE) { + handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_ZERO_RTT); + } +} + +void TlsClientHandshaker::OnHandshakeConfirmed() { + QUICHE_DCHECK(one_rtt_keys_available()); + if (state_ >= HANDSHAKE_CONFIRMED) { + return; + } + state_ = HANDSHAKE_CONFIRMED; + handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_HANDSHAKE); + handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_HANDSHAKE); +} + +QuicAsyncStatus TlsClientHandshaker::VerifyCertChain( + const std::vector& certs, std::string* error_details, + std::unique_ptr* details, uint8_t* out_alert, + std::unique_ptr callback) { + const uint8_t* ocsp_response_raw; + size_t ocsp_response_len; + SSL_get0_ocsp_response(ssl(), &ocsp_response_raw, &ocsp_response_len); + std::string ocsp_response(reinterpret_cast(ocsp_response_raw), + ocsp_response_len); + const uint8_t* sct_list_raw; + size_t sct_list_len; + SSL_get0_signed_cert_timestamp_list(ssl(), &sct_list_raw, &sct_list_len); + std::string sct_list(reinterpret_cast(sct_list_raw), + sct_list_len); + + return proof_verifier_->VerifyCertChain( + server_id_.host(), server_id_.port(), certs, ocsp_response, sct_list, + verify_context_.get(), error_details, details, out_alert, + std::move(callback)); +} + +void TlsClientHandshaker::OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) { + proof_handler_->OnProofVerifyDetailsAvailable(verify_details); +} + +void TlsClientHandshaker::FinishHandshake() { + FillNegotiatedParams(); + + QUICHE_CHECK(!SSL_in_early_data(ssl())); + + QUIC_LOG(INFO) << "Client: handshake finished"; + + std::string error_details; + if (!ProcessTransportParameters(&error_details)) { + QUICHE_DCHECK(!error_details.empty()); + CloseConnection(QUIC_HANDSHAKE_FAILED, error_details); + return; + } + + const uint8_t* alpn_data = nullptr; + unsigned alpn_length = 0; + SSL_get0_alpn_selected(ssl(), &alpn_data, &alpn_length); + + if (alpn_length == 0) { + QUIC_DLOG(ERROR) << "Client: server did not select ALPN"; + // TODO(b/130164908) this should send no_application_protocol + // instead of QUIC_HANDSHAKE_FAILED. + CloseConnection(QUIC_HANDSHAKE_FAILED, "Server did not select ALPN"); + return; + } + + std::string received_alpn_string(reinterpret_cast(alpn_data), + alpn_length); + std::vector offered_alpns = session()->GetAlpnsToOffer(); + if (std::find(offered_alpns.begin(), offered_alpns.end(), + received_alpn_string) == offered_alpns.end()) { + QUIC_LOG(ERROR) << "Client: received mismatched ALPN '" + << received_alpn_string; + // TODO(b/130164908) this should send no_application_protocol + // instead of QUIC_HANDSHAKE_FAILED. + CloseConnection(QUIC_HANDSHAKE_FAILED, "Client received mismatched ALPN"); + return; + } + session()->OnAlpnSelected(received_alpn_string); + QUIC_DLOG(INFO) << "Client: server selected ALPN: '" << received_alpn_string + << "'"; + + // Parse ALPS extension. + const uint8_t* alps_data; + size_t alps_length; + SSL_get0_peer_application_settings(ssl(), &alps_data, &alps_length); + if (alps_length > 0) { + auto error = session()->OnAlpsData(alps_data, alps_length); + if (error) { + // Calling CloseConnection() is safe even in case OnAlpsData() has + // already closed the connection. + CloseConnection( + QUIC_HANDSHAKE_FAILED, + absl::StrCat("Error processing ALPS data: ", error.value())); + return; + } + } + + state_ = HANDSHAKE_COMPLETE; + handshaker_delegate()->OnTlsHandshakeComplete(); +} + +void TlsClientHandshaker::OnEnterEarlyData() { + QUICHE_DCHECK(SSL_in_early_data(ssl())); + + // TODO(wub): It might be unnecessary to FillNegotiatedParams() at this time, + // because we fill it again when handshake completes. + FillNegotiatedParams(); + + // If we're attempting a 0-RTT handshake, then we need to let the transport + // and application know what state to apply to early data. + PrepareZeroRttConfig(cached_state_.get()); +} + +void TlsClientHandshaker::FillNegotiatedParams() { + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + if (cipher) { + crypto_negotiated_params_->cipher_suite = + SSL_CIPHER_get_protocol_id(cipher); + } + crypto_negotiated_params_->key_exchange_group = SSL_get_curve_id(ssl()); + crypto_negotiated_params_->peer_signature_algorithm = + SSL_get_peer_signature_algorithm(ssl()); + crypto_negotiated_params_->encrypted_client_hello = SSL_ech_accepted(ssl()); +} + +void TlsClientHandshaker::ProcessPostHandshakeMessage() { + int rv = SSL_process_quic_post_handshake(ssl()); + if (rv != 1) { + CloseConnection(QUIC_HANDSHAKE_FAILED, "Unexpected post-handshake data"); + } +} + +bool TlsClientHandshaker::ShouldCloseConnectionOnUnexpectedError( + int ssl_error) { + if (ssl_error != SSL_ERROR_EARLY_DATA_REJECTED) { + return true; + } + HandleZeroRttReject(); + return false; +} + +void TlsClientHandshaker::HandleZeroRttReject() { + QUIC_LOG(INFO) << "0-RTT handshake attempted but was rejected by the server"; + QUICHE_DCHECK(session_cache_); + // Disable encrytion to block outgoing data until 1-RTT keys are available. + encryption_established_ = false; + handshaker_delegate()->OnZeroRttRejected(EarlyDataReason()); + SSL_reset_early_data_reject(ssl()); + session_cache_->ClearEarlyData(server_id_); + AdvanceHandshake(); +} + +void TlsClientHandshaker::InsertSession(bssl::UniquePtr session) { + if (!received_transport_params_) { + QUIC_BUG(quic_bug_10576_8) << "Transport parameters isn't received"; + return; + } + if (session_cache_ == nullptr) { + QUIC_DVLOG(1) << "No session cache, not inserting a session"; + return; + } + if (has_application_state_ && !received_application_state_) { + // Application state is not received yet. cache the sessions. + if (cached_tls_sessions_[0] != nullptr) { + cached_tls_sessions_[1] = std::move(cached_tls_sessions_[0]); + } + cached_tls_sessions_[0] = std::move(session); + return; + } + session_cache_->Insert(server_id_, std::move(session), + *received_transport_params_, + received_application_state_.get()); +} + +void TlsClientHandshaker::WriteMessage(EncryptionLevel level, + absl::string_view data) { + if (level == ENCRYPTION_HANDSHAKE && state_ < HANDSHAKE_PROCESSED) { + state_ = HANDSHAKE_PROCESSED; + } + TlsHandshaker::WriteMessage(level, data); +} + +void TlsClientHandshaker::SetServerApplicationStateForResumption( + std::unique_ptr application_state) { + QUICHE_DCHECK(one_rtt_keys_available()); + received_application_state_ = std::move(application_state); + // At least one tls session is cached before application state is received. So + // insert now. + if (session_cache_ != nullptr && cached_tls_sessions_[0] != nullptr) { + if (cached_tls_sessions_[1] != nullptr) { + // Insert the older session first. + session_cache_->Insert(server_id_, std::move(cached_tls_sessions_[1]), + *received_transport_params_, + received_application_state_.get()); + } + session_cache_->Insert(server_id_, std::move(cached_tls_sessions_[0]), + *received_transport_params_, + received_application_state_.get()); + } +} + +} // namespace quic diff --git a/quiche/quic/core/tls_client_handshaker.h b/quiche/quic/core/tls_client_handshaker.h new file mode 100644 index 000000000000..06581b54db4e --- /dev/null +++ b/quiche/quic/core/tls_client_handshaker.h @@ -0,0 +1,175 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_ +#define QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/crypto/tls_client_connection.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/tls_handshaker.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// An implementation of QuicCryptoClientStream::HandshakerInterface which uses +// TLS 1.3 for the crypto handshake protocol. +class QUIC_EXPORT_PRIVATE TlsClientHandshaker + : public TlsHandshaker, + public QuicCryptoClientStream::HandshakerInterface, + public TlsClientConnection::Delegate { + public: + // |crypto_config| must outlive TlsClientHandshaker. + TlsClientHandshaker(const QuicServerId& server_id, QuicCryptoStream* stream, + QuicSession* session, + std::unique_ptr verify_context, + QuicCryptoClientConfig* crypto_config, + QuicCryptoClientStream::ProofHandler* proof_handler, + bool has_application_state); + TlsClientHandshaker(const TlsClientHandshaker&) = delete; + TlsClientHandshaker& operator=(const TlsClientHandshaker&) = delete; + + ~TlsClientHandshaker() override; + + // From QuicCryptoClientStream::HandshakerInterface + bool CryptoConnect() override; + int num_sent_client_hellos() const override; + bool IsResumption() const override; + bool EarlyDataAccepted() const override; + ssl_early_data_reason_t EarlyDataReason() const override; + bool ReceivedInchoateReject() const override; + int num_scup_messages_received() const override; + std::string chlo_hash() const override; + bool ExportKeyingMaterial(absl::string_view label, absl::string_view context, + size_t result_len, std::string* result) override; + + // From QuicCryptoClientStream::HandshakerInterface and TlsHandshaker + bool encryption_established() const override; + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override; + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override; + bool one_rtt_keys_available() const override; + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override; + CryptoMessageParser* crypto_message_parser() override; + HandshakeState GetHandshakeState() const override; + size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + void OnOneRttPacketAcknowledged() override; + void OnHandshakePacketSent() override; + void OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) override; + void OnHandshakeDoneReceived() override; + void OnNewTokenReceived(absl::string_view token) override; + void SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span write_secret) override; + + // Override to drop initial keys if trying to write ENCRYPTION_HANDSHAKE data. + void WriteMessage(EncryptionLevel level, absl::string_view data) override; + + void SetServerApplicationStateForResumption( + std::unique_ptr application_state) override; + + void AllowEmptyAlpnForTests() { allow_empty_alpn_for_tests_ = true; } + void AllowInvalidSNIForTests() { allow_invalid_sni_for_tests_ = true; } + + // Make the SSL object from BoringSSL publicly accessible. + using TlsHandshaker::ssl; + + protected: + const TlsConnection* tls_connection() const override { + return &tls_connection_; + } + + void FinishHandshake() override; + void OnEnterEarlyData() override; + void FillNegotiatedParams(); + void ProcessPostHandshakeMessage() override; + bool ShouldCloseConnectionOnUnexpectedError(int ssl_error) override; + QuicAsyncStatus VerifyCertChain( + const std::vector& certs, std::string* error_details, + std::unique_ptr* details, uint8_t* out_alert, + std::unique_ptr callback) override; + void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) override; + + // TlsClientConnection::Delegate implementation: + TlsConnection::Delegate* ConnectionDelegate() override { return this; } + + private: + bool SetAlpn(); + bool SetTransportParameters(); + bool ProcessTransportParameters(std::string* error_details); + void HandleZeroRttReject(); + + // Called when server completes handshake (i.e., either handshake done is + // received or 1-RTT packet gets acknowledged). + void OnHandshakeConfirmed(); + + void InsertSession(bssl::UniquePtr session) override; + + bool PrepareZeroRttConfig(QuicResumptionState* cached_state); + + QuicSession* session() { return session_; } + QuicSession* session_; + + QuicServerId server_id_; + + // Objects used for verifying the server's certificate chain. + // |proof_verifier_| is owned by the caller of TlsHandshaker's constructor. + ProofVerifier* proof_verifier_; + std::unique_ptr verify_context_; + + // Unowned pointer to the proof handler which has the + // OnProofVerifyDetailsAvailable callback to use for notifying the result of + // certificate verification. + QuicCryptoClientStream::ProofHandler* proof_handler_; + + // Used for session resumption. |session_cache_| is owned by the + // QuicCryptoClientConfig passed into TlsClientHandshaker's constructor. + SessionCache* session_cache_; + + std::string user_agent_id_; + + // Pre-shared key used during the handshake. + std::string pre_shared_key_; + + HandshakeState state_ = HANDSHAKE_START; + bool encryption_established_ = false; + bool initial_keys_dropped_ = false; + quiche::QuicheReferenceCountedPointer + crypto_negotiated_params_; + + bool allow_empty_alpn_for_tests_ = false; + bool allow_invalid_sni_for_tests_ = false; + + const bool has_application_state_; + // Contains the state for performing a resumption, if one is attempted. This + // will always be non-null if a 0-RTT resumption is attempted. + std::unique_ptr cached_state_; + + TlsClientConnection tls_connection_; + + // If |has_application_state_|, stores the tls session tickets before + // application state is received. The latest one is put in the front. + bssl::UniquePtr cached_tls_sessions_[2] = {}; + + std::unique_ptr received_transport_params_ = nullptr; + std::unique_ptr received_application_state_ = nullptr; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_TLS_CLIENT_HANDSHAKER_H_ diff --git a/quiche/quic/core/tls_client_handshaker_test.cc b/quiche/quic/core/tls_client_handshaker_test.cc new file mode 100644 index 000000000000..0459de3acf0a --- /dev/null +++ b/quiche/quic/core/tls_client_handshaker_test.cc @@ -0,0 +1,863 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include + +#include "absl/base/macros.h" +#include "openssl/hpke.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_session_cache.h" +#include "quiche/quic/tools/fake_proof_verifier.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +using testing::_; + +namespace quic { +namespace test { +namespace { + +constexpr char kServerHostname[] = "test.example.com"; +constexpr uint16_t kServerPort = 443; + +// TestProofVerifier wraps ProofVerifierForTesting, except for VerifyCertChain +// which, if TestProofVerifier is active, always returns QUIC_PENDING. (If this +// test proof verifier is not active, it delegates VerifyCertChain to the +// ProofVerifierForTesting.) The pending VerifyCertChain operation can be +// completed by calling InvokePendingCallback. This allows for testing +// asynchronous VerifyCertChain operations. +class TestProofVerifier : public ProofVerifier { + public: + TestProofVerifier() + : verifier_(crypto_test_utils::ProofVerifierForTesting()) {} + + QuicAsyncStatus VerifyProof( + const std::string& hostname, const uint16_t port, + const std::string& server_config, QuicTransportVersion quic_version, + absl::string_view chlo_hash, const std::vector& certs, + const std::string& cert_sct, const std::string& signature, + const ProofVerifyContext* context, std::string* error_details, + std::unique_ptr* details, + std::unique_ptr callback) override { + return verifier_->VerifyProof( + hostname, port, server_config, quic_version, chlo_hash, certs, cert_sct, + signature, context, error_details, details, std::move(callback)); + } + + QuicAsyncStatus VerifyCertChain( + const std::string& hostname, const uint16_t port, + const std::vector& certs, const std::string& ocsp_response, + const std::string& cert_sct, const ProofVerifyContext* context, + std::string* error_details, std::unique_ptr* details, + uint8_t* out_alert, + std::unique_ptr callback) override { + if (!active_) { + return verifier_->VerifyCertChain( + hostname, port, certs, ocsp_response, cert_sct, context, + error_details, details, out_alert, std::move(callback)); + } + pending_ops_.push_back(std::make_unique( + hostname, port, certs, ocsp_response, cert_sct, context, error_details, + details, out_alert, std::move(callback), verifier_.get())); + return QUIC_PENDING; + } + + std::unique_ptr CreateDefaultContext() override { + return nullptr; + } + + void Activate() { active_ = true; } + + size_t NumPendingCallbacks() const { return pending_ops_.size(); } + + void InvokePendingCallback(size_t n) { + ASSERT_GT(NumPendingCallbacks(), n); + pending_ops_[n]->Run(); + auto it = pending_ops_.begin() + n; + pending_ops_.erase(it); + } + + private: + // Implementation of ProofVerifierCallback that fails if the callback is ever + // run. + class FailingProofVerifierCallback : public ProofVerifierCallback { + public: + void Run(bool /*ok*/, const std::string& /*error_details*/, + std::unique_ptr* /*details*/) override { + FAIL(); + } + }; + + class VerifyChainPendingOp { + public: + VerifyChainPendingOp(const std::string& hostname, const uint16_t port, + const std::vector& certs, + const std::string& ocsp_response, + const std::string& cert_sct, + const ProofVerifyContext* context, + std::string* error_details, + std::unique_ptr* details, + uint8_t* out_alert, + std::unique_ptr callback, + ProofVerifier* delegate) + : hostname_(hostname), + port_(port), + certs_(certs), + ocsp_response_(ocsp_response), + cert_sct_(cert_sct), + context_(context), + error_details_(error_details), + details_(details), + out_alert_(out_alert), + callback_(std::move(callback)), + delegate_(delegate) {} + + void Run() { + // TestProofVerifier depends on crypto_test_utils::ProofVerifierForTesting + // running synchronously. It passes a FailingProofVerifierCallback and + // runs the original callback after asserting that the verification ran + // synchronously. + QuicAsyncStatus status = delegate_->VerifyCertChain( + hostname_, port_, certs_, ocsp_response_, cert_sct_, context_, + error_details_, details_, out_alert_, + std::make_unique()); + ASSERT_NE(status, QUIC_PENDING); + callback_->Run(status == QUIC_SUCCESS, *error_details_, details_); + } + + private: + std::string hostname_; + const uint16_t port_; + std::vector certs_; + std::string ocsp_response_; + std::string cert_sct_; + const ProofVerifyContext* context_; + std::string* error_details_; + std::unique_ptr* details_; + uint8_t* out_alert_; + std::unique_ptr callback_; + ProofVerifier* delegate_; + }; + + std::unique_ptr verifier_; + bool active_ = false; + std::vector> pending_ops_; +}; + +class TlsClientHandshakerTest : public QuicTestWithParam { + public: + TlsClientHandshakerTest() + : supported_versions_({GetParam()}), + server_id_(kServerHostname, kServerPort, false), + server_compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) { + crypto_config_ = std::make_unique( + std::make_unique(), + std::make_unique()); + server_crypto_config_ = crypto_test_utils::CryptoServerConfigForTesting(); + CreateConnection(); + } + + void CreateSession() { + session_ = std::make_unique( + connection_, DefaultQuicConfig(), supported_versions_, server_id_, + crypto_config_.get(), ssl_config_); + EXPECT_CALL(*session_, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector( + {AlpnForVersion(connection_->version())}))); + } + + void CreateConnection() { + connection_ = + new PacketSavingConnection(&client_helper_, &alarm_factory_, + Perspective::IS_CLIENT, supported_versions_); + // Advance the time, because timers do not like uninitialized times. + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + CreateSession(); + } + + void CompleteCryptoHandshake() { + CompleteCryptoHandshakeWithServerALPN( + AlpnForVersion(connection_->version())); + } + + void CompleteCryptoHandshakeWithServerALPN(const std::string& alpn) { + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)) + .Times(testing::AnyNumber()); + stream()->CryptoConnect(); + QuicConfig config; + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &server_helper_, &alarm_factory_, + connection_, stream(), alpn); + } + + QuicCryptoClientStream* stream() { + return session_->GetMutableCryptoStream(); + } + + QuicCryptoServerStreamBase* server_stream() { + return server_session_->GetMutableCryptoStream(); + } + + // Initializes a fake server, and all its associated state, for testing. + void InitializeFakeServer() { + TestQuicSpdyServerSession* server_session = nullptr; + CreateServerSessionForTest( + server_id_, QuicTime::Delta::FromSeconds(100000), supported_versions_, + &server_helper_, &alarm_factory_, server_crypto_config_.get(), + &server_compressed_certs_cache_, &server_connection_, &server_session); + server_session_.reset(server_session); + std::string alpn = AlpnForVersion(connection_->version()); + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillRepeatedly([alpn](const std::vector& alpns) { + return std::find(alpns.cbegin(), alpns.cend(), alpn); + }); + } + + static bssl::UniquePtr MakeTestEchKeys( + const char* public_name, size_t max_name_len, + std::string* ech_config_list) { + bssl::ScopedEVP_HPKE_KEY key; + if (!EVP_HPKE_KEY_generate(key.get(), EVP_hpke_x25519_hkdf_sha256())) { + return nullptr; + } + + uint8_t* ech_config; + size_t ech_config_len; + if (!SSL_marshal_ech_config(&ech_config, &ech_config_len, + /*config_id=*/1, key.get(), public_name, + max_name_len)) { + return nullptr; + } + bssl::UniquePtr scoped_ech_config(ech_config); + + uint8_t* ech_config_list_raw; + size_t ech_config_list_len; + bssl::UniquePtr keys(SSL_ECH_KEYS_new()); + if (!keys || + !SSL_ECH_KEYS_add(keys.get(), /*is_retry_config=*/1, ech_config, + ech_config_len, key.get()) || + !SSL_ECH_KEYS_marshal_retry_configs(keys.get(), &ech_config_list_raw, + &ech_config_list_len)) { + return nullptr; + } + bssl::UniquePtr scoped_ech_config_list(ech_config_list_raw); + + ech_config_list->assign(ech_config_list_raw, + ech_config_list_raw + ech_config_list_len); + return keys; + } + + MockQuicConnectionHelper server_helper_; + MockQuicConnectionHelper client_helper_; + MockAlarmFactory alarm_factory_; + PacketSavingConnection* connection_; + ParsedQuicVersionVector supported_versions_; + std::unique_ptr session_; + QuicServerId server_id_; + CryptoHandshakeMessage message_; + std::unique_ptr crypto_config_; + absl::optional ssl_config_; + + // Server state. + std::unique_ptr server_crypto_config_; + PacketSavingConnection* server_connection_; + std::unique_ptr server_session_; + QuicCompressedCertsCache server_compressed_certs_cache_; +}; + +INSTANTIATE_TEST_SUITE_P(TlsHandshakerTests, TlsClientHandshakerTest, + ::testing::ValuesIn(AllSupportedVersionsWithTls()), + ::testing::PrintToStringParamName()); + +TEST_P(TlsClientHandshakerTest, NotInitiallyConnected) { + EXPECT_FALSE(stream()->encryption_established()); + EXPECT_FALSE(stream()->one_rtt_keys_available()); +} + +TEST_P(TlsClientHandshakerTest, ConnectedAfterHandshake) { + CompleteCryptoHandshake(); + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); +} + +TEST_P(TlsClientHandshakerTest, ConnectionClosedOnTlsError) { + // Have client send ClientHello. + stream()->CryptoConnect(); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _)); + + // Send a zero-length ServerHello from server to client. + char bogus_handshake_message[] = { + // Handshake struct (RFC 8446 appendix B.3) + 2, // HandshakeType server_hello + 0, 0, 0, // uint24 length + }; + stream()->crypto_message_parser()->ProcessInput( + absl::string_view(bogus_handshake_message, + ABSL_ARRAYSIZE(bogus_handshake_message)), + ENCRYPTION_INITIAL); + + EXPECT_FALSE(stream()->one_rtt_keys_available()); +} + +TEST_P(TlsClientHandshakerTest, ProofVerifyDetailsAvailableAfterHandshake) { + EXPECT_CALL(*session_, OnProofVerifyDetailsAvailable(testing::_)); + stream()->CryptoConnect(); + QuicConfig config; + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &server_helper_, &alarm_factory_, + connection_, stream(), AlpnForVersion(connection_->version())); + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); +} + +TEST_P(TlsClientHandshakerTest, HandshakeWithAsyncProofVerifier) { + InitializeFakeServer(); + + // Enable TestProofVerifier to capture call to VerifyCertChain and run it + // asynchronously. + TestProofVerifier* proof_verifier = + static_cast(crypto_config_->proof_verifier()); + proof_verifier->Activate(); + + stream()->CryptoConnect(); + // Exchange handshake messages. + std::pair moved_message_counts = + crypto_test_utils::AdvanceHandshake( + connection_, stream(), 0, server_connection_, server_stream(), 0); + + ASSERT_EQ(proof_verifier->NumPendingCallbacks(), 1u); + proof_verifier->InvokePendingCallback(0); + + // Exchange more handshake messages. + crypto_test_utils::AdvanceHandshake( + connection_, stream(), moved_message_counts.first, server_connection_, + server_stream(), moved_message_counts.second); + + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); +} + +TEST_P(TlsClientHandshakerTest, Resumption) { + // Disable 0-RTT on the server so that we're only testing 1-RTT resumption: + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection + CreateConnection(); + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_TRUE(stream()->IsResumption()); +} + +TEST_P(TlsClientHandshakerTest, ResumptionRejection) { + // Disable 0-RTT on the server before the first connection so the client + // doesn't attempt a 0-RTT resumption, only a 1-RTT resumption. + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection, but disable resumption on the server. + SSL_CTX_set_options(server_crypto_config_->ssl_ctx(), SSL_OP_NO_TICKET); + CreateConnection(); + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + EXPECT_FALSE(stream()->EarlyDataAccepted()); + EXPECT_EQ(stream()->EarlyDataReason(), + ssl_early_data_unsupported_for_session); +} + +TEST_P(TlsClientHandshakerTest, ZeroRttResumption) { + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection + CreateConnection(); + // OnConfigNegotiated should be called twice - once when processing saved + // 0-RTT transport parameters, and then again when receiving transport + // parameters from the server. + EXPECT_CALL(*session_, OnConfigNegotiated()).Times(2); + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)) + .Times(testing::AnyNumber()); + // Start the second handshake and confirm we have keys before receiving any + // messages from the server. + stream()->CryptoConnect(); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_NE(stream()->crypto_negotiated_params().cipher_suite, 0); + EXPECT_NE(stream()->crypto_negotiated_params().key_exchange_group, 0); + EXPECT_NE(stream()->crypto_negotiated_params().peer_signature_algorithm, 0); + // Finish the handshake with the server. + QuicConfig config; + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &server_helper_, &alarm_factory_, + connection_, stream(), AlpnForVersion(connection_->version())); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_TRUE(stream()->IsResumption()); + EXPECT_TRUE(stream()->EarlyDataAccepted()); + EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_accepted); +} + +// Regression test for b/186438140. +TEST_P(TlsClientHandshakerTest, ZeroRttResumptionWithAyncProofVerifier) { + // Finish establishing the first connection, so the second connection can + // resume. + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection. + CreateConnection(); + InitializeFakeServer(); + EXPECT_CALL(*session_, OnConfigNegotiated()); + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)) + .Times(testing::AnyNumber()); + // Enable TestProofVerifier to capture the call to VerifyCertChain and run it + // asynchronously. + TestProofVerifier* proof_verifier = + static_cast(crypto_config_->proof_verifier()); + proof_verifier->Activate(); + // Start the second handshake. + stream()->CryptoConnect(); + + ASSERT_EQ(proof_verifier->NumPendingCallbacks(), 1u); + + // Advance the handshake with the server. Since cert verification has not + // finished yet, client cannot derive HANDSHAKE and 1-RTT keys. + crypto_test_utils::AdvanceHandshake(connection_, stream(), 0, + server_connection_, server_stream(), 0); + + EXPECT_FALSE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); + + // Finish cert verification after receiving packets from server. + proof_verifier->InvokePendingCallback(0); + + QuicFramer* framer = QuicConnectionPeer::GetFramer(connection_); + // Verify client has derived HANDSHAKE key. + EXPECT_NE(nullptr, + QuicFramerPeer::GetEncrypter(framer, ENCRYPTION_HANDSHAKE)); + + // Ideally, we should also verify that the process_undecryptable_packets_alarm + // is set and processing the undecryptable packets can advance the handshake + // to completion. Unfortunately, the test facilities used in this test does + // not support queuing and processing undecryptable packets. +} + +TEST_P(TlsClientHandshakerTest, ZeroRttRejection) { + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection, but disable 0-RTT on the server. + SSL_CTX_set_early_data_enabled(server_crypto_config_->ssl_ctx(), false); + CreateConnection(); + + // OnConfigNegotiated should be called twice - once when processing saved + // 0-RTT transport parameters, and then again when receiving transport + // parameters from the server. + EXPECT_CALL(*session_, OnConfigNegotiated()).Times(2); + + // 4 packets will be sent in this connection: initial handshake packet, 0-RTT + // packet containing SETTINGS, handshake packet upon 0-RTT rejection, 0-RTT + // packet retransmission. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_INITIAL, NOT_RETRANSMISSION)); + if (VersionUsesHttp3(session_->transport_version())) { + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_ZERO_RTT, NOT_RETRANSMISSION)); + } + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_HANDSHAKE, NOT_RETRANSMISSION)); + if (VersionUsesHttp3(session_->transport_version())) { + // TODO(b/158027651): change transmission type to + // ALL_ZERO_RTT_RETRANSMISSION. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_FORWARD_SECURE, LOSS_RETRANSMISSION)); + } + + CompleteCryptoHandshake(); + + QuicFramer* framer = QuicConnectionPeer::GetFramer(connection_); + EXPECT_EQ(nullptr, QuicFramerPeer::GetEncrypter(framer, ENCRYPTION_ZERO_RTT)); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_TRUE(stream()->IsResumption()); + EXPECT_FALSE(stream()->EarlyDataAccepted()); + EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_peer_declined); +} + +TEST_P(TlsClientHandshakerTest, ZeroRttAndResumptionRejection) { + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection, but disable resumption on the server. + SSL_CTX_set_options(server_crypto_config_->ssl_ctx(), SSL_OP_NO_TICKET); + CreateConnection(); + + // OnConfigNegotiated should be called twice - once when processing saved + // 0-RTT transport parameters, and then again when receiving transport + // parameters from the server. + EXPECT_CALL(*session_, OnConfigNegotiated()).Times(2); + + // 4 packets will be sent in this connection: initial handshake packet, 0-RTT + // packet containing SETTINGS, handshake packet upon 0-RTT rejection, 0-RTT + // packet retransmission. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_INITIAL, NOT_RETRANSMISSION)); + if (VersionUsesHttp3(session_->transport_version())) { + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_ZERO_RTT, NOT_RETRANSMISSION)); + } + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_HANDSHAKE, NOT_RETRANSMISSION)); + if (VersionUsesHttp3(session_->transport_version())) { + // TODO(b/158027651): change transmission type to + // ALL_ZERO_RTT_RETRANSMISSION. + EXPECT_CALL(*connection_, + OnPacketSent(ENCRYPTION_FORWARD_SECURE, LOSS_RETRANSMISSION)); + } + + CompleteCryptoHandshake(); + + QuicFramer* framer = QuicConnectionPeer::GetFramer(connection_); + EXPECT_EQ(nullptr, QuicFramerPeer::GetEncrypter(framer, ENCRYPTION_ZERO_RTT)); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + EXPECT_FALSE(stream()->EarlyDataAccepted()); + EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_session_not_resumed); +} + +TEST_P(TlsClientHandshakerTest, ClientSendsNoSNI) { + // Reconfigure client to sent an empty server hostname. The crypto config also + // needs to be recreated to use a FakeProofVerifier since the server's cert + // won't match the empty hostname. + server_id_ = QuicServerId("", 443); + crypto_config_.reset(new QuicCryptoClientConfig( + std::make_unique(), nullptr)); + CreateConnection(); + InitializeFakeServer(); + + stream()->CryptoConnect(); + crypto_test_utils::CommunicateHandshakeMessages( + connection_, stream(), server_connection_, server_stream()); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + + EXPECT_EQ(server_stream()->crypto_negotiated_params().sni, ""); +} + +TEST_P(TlsClientHandshakerTest, ClientSendingTooManyALPNs) { + std::string long_alpn(250, 'A'); + EXPECT_QUIC_BUG( + { + EXPECT_CALL(*session_, GetAlpnsToOffer()) + .WillOnce(testing::Return(std::vector({ + long_alpn + "1", + long_alpn + "2", + long_alpn + "3", + long_alpn + "4", + long_alpn + "5", + long_alpn + "6", + long_alpn + "7", + long_alpn + "8", + }))); + stream()->CryptoConnect(); + }, + "Failed to set ALPN"); +} + +TEST_P(TlsClientHandshakerTest, ServerRequiresCustomALPN) { + InitializeFakeServer(); + const std::string kTestAlpn = "An ALPN That Client Did Not Offer"; + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillOnce([kTestAlpn](const std::vector& alpns) { + return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn); + }); + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, + static_cast( + CRYPTO_ERROR_FIRST + 120), + "TLS handshake failure (ENCRYPTION_INITIAL) 120: " + "no application protocol", + _)); + + stream()->CryptoConnect(); + crypto_test_utils::AdvanceHandshake(connection_, stream(), 0, + server_connection_, server_stream(), 0); + + EXPECT_FALSE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->encryption_established()); + EXPECT_FALSE(server_stream()->encryption_established()); +} + +TEST_P(TlsClientHandshakerTest, ZeroRTTNotAttemptedOnALPNChange) { + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection + CreateConnection(); + // Override the ALPN to send on the second connection. + const std::string kTestAlpn = "Test ALPN"; + EXPECT_CALL(*session_, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector({kTestAlpn}))); + // OnConfigNegotiated should only be called once: when transport parameters + // are received from the server. + EXPECT_CALL(*session_, OnConfigNegotiated()).Times(1); + + CompleteCryptoHandshakeWithServerALPN(kTestAlpn); + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->EarlyDataAccepted()); + EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_alpn_mismatch); +} + +TEST_P(TlsClientHandshakerTest, InvalidSNI) { + // Test that a client will skip sending SNI if configured to send an invalid + // hostname. In this case, the inclusion of '!' is invalid. + server_id_ = QuicServerId("invalid!.example.com", 443); + crypto_config_.reset(new QuicCryptoClientConfig( + std::make_unique(), nullptr)); + CreateConnection(); + InitializeFakeServer(); + + stream()->CryptoConnect(); + crypto_test_utils::CommunicateHandshakeMessages( + connection_, stream(), server_connection_, server_stream()); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + + EXPECT_EQ(server_stream()->crypto_negotiated_params().sni, ""); +} + +TEST_P(TlsClientHandshakerTest, BadTransportParams) { + if (!connection_->version().UsesHttp3()) { + return; + } + // Finish establishing the first connection: + CompleteCryptoHandshake(); + + // Create a second connection + CreateConnection(); + + stream()->CryptoConnect(); + auto* id_manager = QuicSessionPeer::ietf_streamid_manager(session_.get()); + EXPECT_EQ(kDefaultMaxStreamsPerConnection, + id_manager->max_outgoing_bidirectional_streams()); + QuicConfig config; + config.SetMaxBidirectionalStreamsToSend( + config.GetMaxBidirectionalStreamsToSend() - 1); + + EXPECT_CALL(*connection_, + CloseConnection(QUIC_ZERO_RTT_REJECTION_LIMIT_REDUCED, _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection)); + // Close connection will be called again in the handshaker, but this will be + // no-op as the connection is already closed. + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); + + crypto_test_utils::HandshakeWithFakeServer( + &config, server_crypto_config_.get(), &server_helper_, &alarm_factory_, + connection_, stream(), AlpnForVersion(connection_->version())); +} + +TEST_P(TlsClientHandshakerTest, ECH) { + ssl_config_.emplace(); + bssl::UniquePtr ech_keys = + MakeTestEchKeys("public-name.example", /*max_name_len=*/64, + &ssl_config_->ech_config_list); + ASSERT_TRUE(ech_keys); + + // Configure the server to use the test ECH keys. + ASSERT_TRUE( + SSL_CTX_set1_ech_keys(server_crypto_config_->ssl_ctx(), ech_keys.get())); + + // Recreate the client to pick up the new `ssl_config_`. + CreateConnection(); + + // The handshake should complete and negotiate ECH. + CompleteCryptoHandshake(); + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_TRUE(stream()->crypto_negotiated_params().encrypted_client_hello); +} + +TEST_P(TlsClientHandshakerTest, ECHWithConfigAndGREASE) { + ssl_config_.emplace(); + bssl::UniquePtr ech_keys = + MakeTestEchKeys("public-name.example", /*max_name_len=*/64, + &ssl_config_->ech_config_list); + ASSERT_TRUE(ech_keys); + ssl_config_->ech_grease_enabled = true; + + // Configure the server to use the test ECH keys. + ASSERT_TRUE( + SSL_CTX_set1_ech_keys(server_crypto_config_->ssl_ctx(), ech_keys.get())); + + // Recreate the client to pick up the new `ssl_config_`. + CreateConnection(); + + // When both ECH and ECH GREASE are enabled, ECH should take precedence. + // The handshake should complete and negotiate ECH. + CompleteCryptoHandshake(); + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_TRUE(stream()->crypto_negotiated_params().encrypted_client_hello); +} + +TEST_P(TlsClientHandshakerTest, ECHInvalidConfig) { + // An invalid ECHConfigList should fail before sending a ClientHello. + ssl_config_.emplace(); + ssl_config_->ech_config_list = "invalid config"; + CreateConnection(); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); + stream()->CryptoConnect(); +} + +TEST_P(TlsClientHandshakerTest, ECHWrongKeys) { + ssl_config_.emplace(); + bssl::UniquePtr ech_keys1 = + MakeTestEchKeys("public-name.example", /*max_name_len=*/64, + &ssl_config_->ech_config_list); + ASSERT_TRUE(ech_keys1); + + std::string ech_config_list2; + bssl::UniquePtr ech_keys2 = MakeTestEchKeys( + "public-name.example", /*max_name_len=*/64, &ech_config_list2); + ASSERT_TRUE(ech_keys2); + + // Configure the server to use different keys from what the client has. + ASSERT_TRUE( + SSL_CTX_set1_ech_keys(server_crypto_config_->ssl_ctx(), ech_keys2.get())); + + // Recreate the client to pick up the new `ssl_config_`. + CreateConnection(); + + // TODO(crbug.com/1287248): This should instead output sufficient information + // to run the recovery flow. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, + static_cast( + CRYPTO_ERROR_FIRST + SSL_AD_ECH_REQUIRED), + _, _)) + .WillOnce(testing::Invoke(connection_, + &MockQuicConnection::ReallyCloseConnection4)); + + // The handshake should complete and negotiate ECH. + CompleteCryptoHandshake(); +} + +// Test that ECH GREASE can be configured. +TEST_P(TlsClientHandshakerTest, ECHGrease) { + ssl_config_.emplace(); + ssl_config_->ech_grease_enabled = true; + CreateConnection(); + + // Add a DoS callback on the server, to test that the client sent a GREASE + // message. This is a bit of a hack. TlsServerHandshaker already configures + // the certificate selection callback, but does not usefully expose any way + // for tests to inspect the ClientHello. So, instead, we register a different + // callback that also gets the ClientHello. + static bool callback_ran; + callback_ran = false; + SSL_CTX_set_dos_protection_cb( + server_crypto_config_->ssl_ctx(), + [](const SSL_CLIENT_HELLO* client_hello) -> int { + const uint8_t* data; + size_t len; + EXPECT_TRUE(SSL_early_callback_ctx_extension_get( + client_hello, TLSEXT_TYPE_encrypted_client_hello, &data, &len)); + callback_ran = true; + return 1; + }); + + CompleteCryptoHandshake(); + EXPECT_TRUE(callback_ran); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + // Sending an ignored ECH GREASE extension does not count as negotiating ECH. + EXPECT_FALSE(stream()->crypto_negotiated_params().encrypted_client_hello); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/tls_handshaker.cc b/quiche/quic/core/tls_handshaker.cc new file mode 100644 index 000000000000..486457dea515 --- /dev/null +++ b/quiche/quic/core/tls_handshaker.cc @@ -0,0 +1,406 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/tls_handshaker.h" + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/crypto.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" + +namespace quic { + +#define ENDPOINT (SSL_is_server(ssl()) ? "TlsServer: " : "TlsClient: ") + +TlsHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl( + TlsHandshaker* parent) + : parent_(parent) {} + +TlsHandshaker::ProofVerifierCallbackImpl::~ProofVerifierCallbackImpl() {} + +void TlsHandshaker::ProofVerifierCallbackImpl::Run( + bool ok, const std::string& /*error_details*/, + std::unique_ptr* details) { + if (parent_ == nullptr) { + return; + } + + parent_->verify_details_ = std::move(*details); + parent_->verify_result_ = ok ? ssl_verify_ok : ssl_verify_invalid; + parent_->set_expected_ssl_error(SSL_ERROR_WANT_READ); + parent_->proof_verify_callback_ = nullptr; + if (parent_->verify_details_) { + parent_->OnProofVerifyDetailsAvailable(*parent_->verify_details_); + } + parent_->AdvanceHandshake(); +} + +void TlsHandshaker::ProofVerifierCallbackImpl::Cancel() { parent_ = nullptr; } + +TlsHandshaker::TlsHandshaker(QuicCryptoStream* stream, QuicSession* session) + : stream_(stream), handshaker_delegate_(session) {} + +TlsHandshaker::~TlsHandshaker() { + if (proof_verify_callback_) { + proof_verify_callback_->Cancel(); + } +} + +bool TlsHandshaker::ProcessInput(absl::string_view input, + EncryptionLevel level) { + if (parser_error_ != QUIC_NO_ERROR) { + return false; + } + // TODO(nharper): Call SSL_quic_read_level(ssl()) and check whether the + // encryption level BoringSSL expects matches the encryption level that we + // just received input at. If they mismatch, should ProcessInput return true + // or false? If data is for a future encryption level, it should be queued for + // later? + if (SSL_provide_quic_data(ssl(), TlsConnection::BoringEncryptionLevel(level), + reinterpret_cast(input.data()), + input.size()) != 1) { + // SSL_provide_quic_data can fail for 3 reasons: + // - API misuse (calling it before SSL_set_custom_quic_method, which we + // call in the TlsHandshaker c'tor) + // - Memory exhaustion when appending data to its buffer + // - Data provided at the wrong encryption level + // + // Of these, the only sensible error to handle is data provided at the wrong + // encryption level. + // + // Note: the error provided below has a good-sounding enum value, although + // it doesn't match the description as it's a QUIC Crypto specific error. + parser_error_ = QUIC_INVALID_CRYPTO_MESSAGE_TYPE; + parser_error_detail_ = "TLS stack failed to receive data"; + return false; + } + AdvanceHandshake(); + return true; +} + +void TlsHandshaker::AdvanceHandshake() { + if (is_connection_closed()) { + return; + } + if (GetHandshakeState() >= HANDSHAKE_COMPLETE) { + ProcessPostHandshakeMessage(); + return; + } + + QUICHE_BUG_IF( + quic_tls_server_async_done_no_flusher, + SSL_is_server(ssl()) && !handshaker_delegate_->PacketFlusherAttached()) + << "is_server:" << SSL_is_server(ssl()); + + QUIC_VLOG(1) << ENDPOINT << "Continuing handshake"; + last_tls_alert_.reset(); + int rv = SSL_do_handshake(ssl()); + + if (is_connection_closed()) { + return; + } + + // If SSL_do_handshake return success(1) and we are in early data, it is + // possible that we have provided ServerHello to BoringSSL but it hasn't been + // processed. Retry SSL_do_handshake once will advance the handshake more in + // that case. If there are no unprocessed ServerHello, the retry will return a + // non-positive number. + if (rv == 1 && SSL_in_early_data(ssl())) { + OnEnterEarlyData(); + rv = SSL_do_handshake(ssl()); + + if (is_connection_closed()) { + return; + } + + QUIC_VLOG(1) << ENDPOINT + << "SSL_do_handshake returned when entering early data. After " + << "retry, rv=" << rv + << ", SSL_in_early_data=" << SSL_in_early_data(ssl()); + // The retry should either + // - Return <= 0 if the handshake is still pending, likely still in early + // data. + // - Return 1 if the handshake has _actually_ finished. i.e. + // SSL_in_early_data should be false. + // + // In either case, it should not both return 1 and stay in early data. + if (rv == 1 && SSL_in_early_data(ssl()) && !is_connection_closed()) { + QUIC_BUG(quic_handshaker_stay_in_early_data) + << "The original and the retry of SSL_do_handshake both returned " + "success and in early data"; + CloseConnection(QUIC_HANDSHAKE_FAILED, + "TLS handshake failed: Still in early data after retry"); + return; + } + } + + if (rv == 1) { + FinishHandshake(); + return; + } + int ssl_error = SSL_get_error(ssl(), rv); + if (ssl_error == expected_ssl_error_) { + return; + } + if (ShouldCloseConnectionOnUnexpectedError(ssl_error) && + !is_connection_closed()) { + QUIC_VLOG(1) << "SSL_do_handshake failed; SSL_get_error returns " + << ssl_error; + ERR_print_errors_fp(stderr); + if (dont_close_connection_in_tls_alert_callback_ && + last_tls_alert_.has_value()) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_dont_close_connection_in_tls_alert_callback, 2, 2); + std::string error_details = + absl::StrCat("TLS handshake failure (", + EncryptionLevelToString(last_tls_alert_->level), ") ", + static_cast(last_tls_alert_->desc), ": ", + SSL_alert_desc_string_long(last_tls_alert_->desc)); + QUIC_DLOG(ERROR) << error_details; + CloseConnection(TlsAlertToQuicErrorCode(last_tls_alert_->desc), + static_cast( + CRYPTO_ERROR_FIRST + last_tls_alert_->desc), + error_details); + } else { + CloseConnection(QUIC_HANDSHAKE_FAILED, "TLS handshake failed"); + } + } +} + +void TlsHandshaker::CloseConnection(QuicErrorCode error, + const std::string& reason_phrase) { + QUICHE_DCHECK(!reason_phrase.empty()); + stream()->OnUnrecoverableError(error, reason_phrase); + is_connection_closed_ = true; +} + +void TlsHandshaker::CloseConnection(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& reason_phrase) { + QUICHE_DCHECK(!reason_phrase.empty()); + stream()->OnUnrecoverableError(error, ietf_error, reason_phrase); + is_connection_closed_ = true; +} + +void TlsHandshaker::OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) { + is_connection_closed_ = true; +} + +bool TlsHandshaker::ShouldCloseConnectionOnUnexpectedError(int /*ssl_error*/) { + return true; +} + +size_t TlsHandshaker::BufferSizeLimitForLevel(EncryptionLevel level) const { + return SSL_quic_max_handshake_flight_len( + ssl(), TlsConnection::BoringEncryptionLevel(level)); +} + +ssl_early_data_reason_t TlsHandshaker::EarlyDataReason() const { + return SSL_get_early_data_reason(ssl()); +} + +const EVP_MD* TlsHandshaker::Prf(const SSL_CIPHER* cipher) { + return EVP_get_digestbynid(SSL_CIPHER_get_prf_nid(cipher)); +} + +enum ssl_verify_result_t TlsHandshaker::VerifyCert(uint8_t* out_alert) { + if (verify_result_ != ssl_verify_retry || + expected_ssl_error() == SSL_ERROR_WANT_CERTIFICATE_VERIFY) { + enum ssl_verify_result_t result = verify_result_; + verify_result_ = ssl_verify_retry; + *out_alert = cert_verify_tls_alert_; + return result; + } + const STACK_OF(CRYPTO_BUFFER)* cert_chain = SSL_get0_peer_certificates(ssl()); + if (cert_chain == nullptr) { + *out_alert = SSL_AD_INTERNAL_ERROR; + return ssl_verify_invalid; + } + // TODO(nharper): Pass the CRYPTO_BUFFERs into the QUIC stack to avoid copies. + std::vector certs; + for (CRYPTO_BUFFER* cert : cert_chain) { + certs.push_back( + std::string(reinterpret_cast(CRYPTO_BUFFER_data(cert)), + CRYPTO_BUFFER_len(cert))); + } + QUIC_DVLOG(1) << "VerifyCert: peer cert_chain length: " << certs.size(); + + ProofVerifierCallbackImpl* proof_verify_callback = + new ProofVerifierCallbackImpl(this); + + cert_verify_tls_alert_ = *out_alert; + QuicAsyncStatus verify_result = VerifyCertChain( + certs, &cert_verify_error_details_, &verify_details_, + &cert_verify_tls_alert_, + std::unique_ptr(proof_verify_callback)); + switch (verify_result) { + case QUIC_SUCCESS: + if (verify_details_) { + OnProofVerifyDetailsAvailable(*verify_details_); + } + return ssl_verify_ok; + case QUIC_PENDING: + proof_verify_callback_ = proof_verify_callback; + set_expected_ssl_error(SSL_ERROR_WANT_CERTIFICATE_VERIFY); + return ssl_verify_retry; + case QUIC_FAILURE: + default: + *out_alert = cert_verify_tls_alert_; + QUIC_LOG(INFO) << "Cert chain verification failed: " + << cert_verify_error_details_; + return ssl_verify_invalid; + } +} + +void TlsHandshaker::SetWriteSecret(EncryptionLevel level, + const SSL_CIPHER* cipher, + absl::Span write_secret) { + QUIC_DVLOG(1) << ENDPOINT << "SetWriteSecret level=" << level; + std::unique_ptr encrypter = + QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + const EVP_MD* prf = Prf(cipher); + CryptoUtils::SetKeyAndIV(prf, write_secret, + handshaker_delegate_->parsed_version(), + encrypter.get()); + std::vector header_protection_key = + CryptoUtils::GenerateHeaderProtectionKey( + prf, write_secret, handshaker_delegate_->parsed_version(), + encrypter->GetKeySize()); + encrypter->SetHeaderProtectionKey( + absl::string_view(reinterpret_cast(header_protection_key.data()), + header_protection_key.size())); + if (level == ENCRYPTION_FORWARD_SECURE) { + QUICHE_DCHECK(latest_write_secret_.empty()); + latest_write_secret_.assign(write_secret.begin(), write_secret.end()); + one_rtt_write_header_protection_key_ = header_protection_key; + } + handshaker_delegate_->OnNewEncryptionKeyAvailable(level, + std::move(encrypter)); +} + +bool TlsHandshaker::SetReadSecret(EncryptionLevel level, + const SSL_CIPHER* cipher, + absl::Span read_secret) { + QUIC_DVLOG(1) << ENDPOINT << "SetReadSecret level=" << level; + std::unique_ptr decrypter = + QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + const EVP_MD* prf = Prf(cipher); + CryptoUtils::SetKeyAndIV(prf, read_secret, + handshaker_delegate_->parsed_version(), + decrypter.get()); + std::vector header_protection_key = + CryptoUtils::GenerateHeaderProtectionKey( + prf, read_secret, handshaker_delegate_->parsed_version(), + decrypter->GetKeySize()); + decrypter->SetHeaderProtectionKey( + absl::string_view(reinterpret_cast(header_protection_key.data()), + header_protection_key.size())); + if (level == ENCRYPTION_FORWARD_SECURE) { + QUICHE_DCHECK(latest_read_secret_.empty()); + latest_read_secret_.assign(read_secret.begin(), read_secret.end()); + one_rtt_read_header_protection_key_ = header_protection_key; + } + return handshaker_delegate_->OnNewDecryptionKeyAvailable( + level, std::move(decrypter), + /*set_alternative_decrypter=*/false, + /*latch_once_used=*/false); +} + +std::unique_ptr +TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + if (latest_read_secret_.empty() || latest_write_secret_.empty() || + one_rtt_read_header_protection_key_.empty() || + one_rtt_write_header_protection_key_.empty()) { + std::string error_details = "1-RTT secret(s) not set yet."; + QUIC_BUG(quic_bug_10312_1) << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details); + return nullptr; + } + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + const EVP_MD* prf = Prf(cipher); + latest_read_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret( + prf, handshaker_delegate_->parsed_version(), latest_read_secret_); + latest_write_secret_ = CryptoUtils::GenerateNextKeyPhaseSecret( + prf, handshaker_delegate_->parsed_version(), latest_write_secret_); + + std::unique_ptr decrypter = + QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + CryptoUtils::SetKeyAndIV(prf, latest_read_secret_, + handshaker_delegate_->parsed_version(), + decrypter.get()); + decrypter->SetHeaderProtectionKey(absl::string_view( + reinterpret_cast(one_rtt_read_header_protection_key_.data()), + one_rtt_read_header_protection_key_.size())); + + return decrypter; +} + +std::unique_ptr TlsHandshaker::CreateCurrentOneRttEncrypter() { + if (latest_write_secret_.empty() || + one_rtt_write_header_protection_key_.empty()) { + std::string error_details = "1-RTT write secret not set yet."; + QUIC_BUG(quic_bug_10312_2) << error_details; + CloseConnection(QUIC_INTERNAL_ERROR, error_details); + return nullptr; + } + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + std::unique_ptr encrypter = + QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); + CryptoUtils::SetKeyAndIV(Prf(cipher), latest_write_secret_, + handshaker_delegate_->parsed_version(), + encrypter.get()); + encrypter->SetHeaderProtectionKey(absl::string_view( + reinterpret_cast(one_rtt_write_header_protection_key_.data()), + one_rtt_write_header_protection_key_.size())); + return encrypter; +} + +bool TlsHandshaker::ExportKeyingMaterialForLabel(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + if (result == nullptr) { + return false; + } + result->resize(result_len); + return SSL_export_keying_material( + ssl(), reinterpret_cast(&*result->begin()), result_len, + label.data(), label.size(), + reinterpret_cast(context.data()), context.size(), + !context.empty()) == 1; +} + +void TlsHandshaker::WriteMessage(EncryptionLevel level, + absl::string_view data) { + stream_->WriteCryptoData(level, data); +} + +void TlsHandshaker::FlushFlight() {} + +void TlsHandshaker::SendAlert(EncryptionLevel level, uint8_t desc) { + if (dont_close_connection_in_tls_alert_callback_) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_dont_close_connection_in_tls_alert_callback, 1, 2); + TlsAlert tls_alert; + tls_alert.level = level; + tls_alert.desc = desc; + last_tls_alert_ = tls_alert; + } else { + std::string error_details = absl::StrCat( + "TLS handshake failure (", EncryptionLevelToString(level), ") ", + static_cast(desc), ": ", SSL_alert_desc_string_long(desc)); + QUIC_DLOG(ERROR) << error_details; + CloseConnection( + TlsAlertToQuicErrorCode(desc), + static_cast(CRYPTO_ERROR_FIRST + desc), + error_details); + } +} + +} // namespace quic diff --git a/quiche/quic/core/tls_handshaker.h b/quiche/quic/core/tls_handshaker.h new file mode 100644 index 000000000000..03b6b9e48a0c --- /dev/null +++ b/quiche/quic/core/tls_handshaker.h @@ -0,0 +1,230 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_TLS_HANDSHAKER_H_ +#define QUICHE_QUIC_CORE_TLS_HANDSHAKER_H_ + +#include "absl/strings/string_view.h" +#include "openssl/base.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_message_parser.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/tls_connection.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +class QuicCryptoStream; + +// Base class for TlsClientHandshaker and TlsServerHandshaker. TlsHandshaker +// provides functionality common to both the client and server, such as moving +// messages between the TLS stack and the QUIC crypto stream, and handling +// derivation of secrets. +class QUIC_EXPORT_PRIVATE TlsHandshaker : public TlsConnection::Delegate, + public CryptoMessageParser { + public: + // TlsHandshaker does not take ownership of any of its arguments; they must + // outlive the TlsHandshaker. + TlsHandshaker(QuicCryptoStream* stream, QuicSession* session); + TlsHandshaker(const TlsHandshaker&) = delete; + TlsHandshaker& operator=(const TlsHandshaker&) = delete; + + ~TlsHandshaker() override; + + // From CryptoMessageParser + bool ProcessInput(absl::string_view input, EncryptionLevel level) override; + size_t InputBytesRemaining() const override { return 0; } + QuicErrorCode error() const override { return parser_error_; } + const std::string& error_detail() const override { + return parser_error_detail_; + } + + // The following methods provide implementations to subclasses of + // TlsHandshaker which use them to implement methods of QuicCryptoStream. + CryptoMessageParser* crypto_message_parser() { return this; } + size_t BufferSizeLimitForLevel(EncryptionLevel level) const; + ssl_early_data_reason_t EarlyDataReason() const; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter(); + std::unique_ptr CreateCurrentOneRttEncrypter(); + virtual HandshakeState GetHandshakeState() const = 0; + bool ExportKeyingMaterialForLabel(absl::string_view label, + absl::string_view context, + size_t result_len, std::string* result); + + protected: + // Called when a new message is received on the crypto stream and is available + // for the TLS stack to read. + virtual void AdvanceHandshake(); + + void CloseConnection(QuicErrorCode error, const std::string& reason_phrase); + // Closes the connection, specifying the wire error code |ietf_error| + // explicitly. + void CloseConnection(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& reason_phrase); + + void OnConnectionClosed(QuicErrorCode error, ConnectionCloseSource source); + + bool is_connection_closed() const { return is_connection_closed_; } + + // Called when |SSL_do_handshake| returns 1, indicating that the handshake has + // finished. Note that a handshake only finishes once, entering early data + // does not count. + virtual void FinishHandshake() = 0; + + // Called when |SSL_do_handshake| returns 1 and the connection is in early + // data. In that case, |AdvanceHandshake| will call |OnEnterEarlyData| and + // retry |SSL_do_handshake| once. + virtual void OnEnterEarlyData() { + // By default, do nothing but check the preconditions. + QUICHE_DCHECK(SSL_in_early_data(ssl())); + } + + // Called when a handshake message is received after the handshake is + // complete. + virtual void ProcessPostHandshakeMessage() = 0; + + // Called when an unexpected error code is received from |SSL_get_error|. If a + // subclass can expect more than just a single error (as provided by + // |set_expected_ssl_error|), it can override this method to handle that case. + virtual bool ShouldCloseConnectionOnUnexpectedError(int ssl_error); + + void set_expected_ssl_error(int ssl_error) { + expected_ssl_error_ = ssl_error; + } + int expected_ssl_error() const { return expected_ssl_error_; } + + // Called to verify a cert chain. This can be implemented as a simple wrapper + // around ProofVerifier, which optionally gathers additional arguments to pass + // into their VerifyCertChain method. This class retains a non-owning pointer + // to |callback|; the callback must live until this function returns + // QUIC_SUCCESS or QUIC_FAILURE, or until the callback is run. + // + // If certificate verification fails, |*out_alert| may be set to a TLS alert + // that will be sent when closing the connection; it defaults to + // certificate_unknown. Implementations of VerifyCertChain may retain the + // |out_alert| pointer while performing an async operation. + virtual QuicAsyncStatus VerifyCertChain( + const std::vector& certs, std::string* error_details, + std::unique_ptr* details, uint8_t* out_alert, + std::unique_ptr callback) = 0; + // Called when certificate verification is completed. + virtual void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) = 0; + + // Returns the PRF used by the cipher suite negotiated in the TLS handshake. + const EVP_MD* Prf(const SSL_CIPHER* cipher); + + virtual const TlsConnection* tls_connection() const = 0; + + SSL* ssl() const { return tls_connection()->ssl(); } + + QuicCryptoStream* stream() { return stream_; } + HandshakerDelegateInterface* handshaker_delegate() { + return handshaker_delegate_; + } + + enum ssl_verify_result_t VerifyCert(uint8_t* out_alert) override; + + // SetWriteSecret provides the encryption secret used to encrypt messages at + // encryption level |level|. The secret provided here is the one from the TLS + // 1.3 key schedule (RFC 8446 section 7.1), in particular the handshake + // traffic secrets and application traffic secrets. The provided write secret + // must be used with the provided cipher suite |cipher|. + void SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span write_secret) override; + + // SetReadSecret is similar to SetWriteSecret, except that it is used for + // decrypting messages. SetReadSecret at a particular level is always called + // after SetWriteSecret for that level, except for ENCRYPTION_ZERO_RTT, where + // the EncryptionLevel for SetWriteSecret is ENCRYPTION_FORWARD_SECURE. + bool SetReadSecret(EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span read_secret) override; + + // WriteMessage is called when there is |data| from the TLS stack ready for + // the QUIC stack to write in a crypto frame. The data must be transmitted at + // encryption level |level|. + void WriteMessage(EncryptionLevel level, absl::string_view data) override; + + // FlushFlight is called to signal that the current flight of + // messages have all been written (via calls to WriteMessage) and can be + // flushed to the underlying transport. + void FlushFlight() override; + + // SendAlert causes this TlsHandshaker to close the QUIC connection with an + // error code corresponding to the TLS alert description |desc|. + void SendAlert(EncryptionLevel level, uint8_t desc) override; + + // Informational callback from BoringSSL. Subclasses can override it to do + // logging, tracing, etc. + // See |SSL_CTX_set_info_callback| for the meaning of |type| and |value|. + void InfoCallback(int /*type*/, int /*value*/) override {} + + private: + // ProofVerifierCallbackImpl handles the result of an asynchronous certificate + // verification operation. + class QUIC_EXPORT_PRIVATE ProofVerifierCallbackImpl + : public ProofVerifierCallback { + public: + explicit ProofVerifierCallbackImpl(TlsHandshaker* parent); + ~ProofVerifierCallbackImpl() override; + + // ProofVerifierCallback interface. + void Run(bool ok, const std::string& error_details, + std::unique_ptr* details) override; + + // If called, Cancel causes the pending callback to be a no-op. + void Cancel(); + + private: + // Non-owning pointer to the TlsHandshaker responsible for this callback. + // |parent_| must be valid for the life of this callback or until |Cancel| + // is called. + TlsHandshaker* parent_; + }; + + // ProofVerifierCallback used for async certificate verification. Ownership of + // this object is transferred to |VerifyCertChain|; + ProofVerifierCallbackImpl* proof_verify_callback_ = nullptr; + std::unique_ptr verify_details_; + enum ssl_verify_result_t verify_result_ = ssl_verify_retry; + uint8_t cert_verify_tls_alert_ = SSL_AD_CERTIFICATE_UNKNOWN; + std::string cert_verify_error_details_; + + int expected_ssl_error_ = SSL_ERROR_WANT_READ; + bool is_connection_closed_ = false; + + QuicCryptoStream* stream_; + HandshakerDelegateInterface* handshaker_delegate_; + + QuicErrorCode parser_error_ = QUIC_NO_ERROR; + std::string parser_error_detail_; + + // The most recently derived 1-RTT read and write secrets, which are updated + // on each key update. + std::vector latest_read_secret_; + std::vector latest_write_secret_; + // 1-RTT header protection keys, which are not changed during key update. + std::vector one_rtt_read_header_protection_key_; + std::vector one_rtt_write_header_protection_key_; + + struct TlsAlert { + EncryptionLevel level; + // The TLS alert code as listed in + // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-6 + uint8_t desc; + }; + absl::optional last_tls_alert_; + const bool dont_close_connection_in_tls_alert_callback_ = + GetQuicReloadableFlag(quic_dont_close_connection_in_tls_alert_callback); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_TLS_HANDSHAKER_H_ diff --git a/quiche/quic/core/tls_server_handshaker.cc b/quiche/quic/core/tls_server_handshaker.cc new file mode 100644 index 000000000000..f4bbd3df8ff9 --- /dev/null +++ b/quiche/quic/core/tls_server_handshaker.cc @@ -0,0 +1,1185 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/tls_server_handshaker.h" + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "openssl/pool.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/http/http_frames.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_hostname_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_server_stats.h" + +#define RECORD_LATENCY_IN_US(stat_name, latency, comment) \ + do { \ + const int64_t latency_in_us = (latency).ToMicroseconds(); \ + QUIC_DVLOG(1) << "Recording " stat_name ": " << latency_in_us; \ + QUIC_SERVER_HISTOGRAM_COUNTS(stat_name, latency_in_us, 1, 10000000, 50, \ + comment); \ + } while (0) + +namespace quic { + +namespace { + +// Default port for HTTP/3. +uint16_t kDefaultPort = 443; + +} // namespace + +TlsServerHandshaker::DefaultProofSourceHandle::DefaultProofSourceHandle( + TlsServerHandshaker* handshaker, ProofSource* proof_source) + : handshaker_(handshaker), proof_source_(proof_source) {} + +TlsServerHandshaker::DefaultProofSourceHandle::~DefaultProofSourceHandle() { + CloseHandle(); +} + +void TlsServerHandshaker::DefaultProofSourceHandle::CloseHandle() { + QUIC_DVLOG(1) << "CloseHandle. is_signature_pending=" + << (signature_callback_ != nullptr); + if (signature_callback_) { + signature_callback_->Cancel(); + signature_callback_ = nullptr; + } +} + +QuicAsyncStatus +TlsServerHandshaker::DefaultProofSourceHandle::SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const QuicConnectionId& /*original_connection_id*/, + absl::string_view /*ssl_capabilities*/, const std::string& hostname, + absl::string_view /*client_hello*/, const std::string& /*alpn*/, + absl::optional /*alps*/, + const std::vector& /*quic_transport_params*/, + const absl::optional>& /*early_data_context*/, + const QuicSSLConfig& /*ssl_config*/) { + if (!handshaker_ || !proof_source_) { + QUIC_BUG(quic_bug_10341_1) + << "SelectCertificate called on a detached handle"; + return QUIC_FAILURE; + } + + bool cert_matched_sni; + quiche::QuicheReferenceCountedPointer chain = + proof_source_->GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); + + handshaker_->OnSelectCertificateDone( + /*ok=*/true, /*is_sync=*/true, chain.get(), + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), cert_matched_sni, + QuicDelayedSSLConfig()); + if (!handshaker_->select_cert_status().has_value()) { + QUIC_BUG(quic_bug_12423_1) + << "select_cert_status() has no value after a synchronous select cert"; + // Return success to continue the handshake. + return QUIC_SUCCESS; + } + return handshaker_->select_cert_status().value(); +} + +QuicAsyncStatus TlsServerHandshaker::DefaultProofSourceHandle::ComputeSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + size_t max_signature_size) { + if (!handshaker_ || !proof_source_) { + QUIC_BUG(quic_bug_10341_2) + << "ComputeSignature called on a detached handle"; + return QUIC_FAILURE; + } + + if (signature_callback_) { + QUIC_BUG(quic_bug_10341_3) << "ComputeSignature called while pending"; + return QUIC_FAILURE; + } + + signature_callback_ = new DefaultSignatureCallback(this); + proof_source_->ComputeTlsSignature( + server_address, client_address, hostname, signature_algorithm, in, + std::unique_ptr(signature_callback_)); + + if (signature_callback_) { + QUIC_DVLOG(1) << "ComputeTlsSignature is pending"; + signature_callback_->set_is_sync(false); + return QUIC_PENDING; + } + + bool success = handshaker_->HasValidSignature(max_signature_size); + QUIC_DVLOG(1) << "ComputeTlsSignature completed synchronously. success:" + << success; + // OnComputeSignatureDone should have been called by signature_callback_->Run. + return success ? QUIC_SUCCESS : QUIC_FAILURE; +} + +TlsServerHandshaker::DecryptCallback::DecryptCallback( + TlsServerHandshaker* handshaker) + : handshaker_(handshaker) {} + +void TlsServerHandshaker::DecryptCallback::Run(std::vector plaintext) { + if (handshaker_ == nullptr) { + // The callback was cancelled before we could run. + return; + } + + TlsServerHandshaker* handshaker = handshaker_; + handshaker_ = nullptr; + + handshaker->decrypted_session_ticket_ = std::move(plaintext); + const bool is_async = + (handshaker->expected_ssl_error() == SSL_ERROR_PENDING_TICKET); + + absl::optional context_switcher; + + if (is_async) { + context_switcher.emplace(handshaker->connection_context()); + } + QUIC_TRACESTRING( + absl::StrCat("TLS ticket decryption done. len(decrypted_ticket):", + handshaker->decrypted_session_ticket_.size())); + + // DecryptCallback::Run could be called synchronously. When that happens, we + // are currently in the middle of a call to AdvanceHandshake. + // (AdvanceHandshake called SSL_do_handshake, which through some layers + // called SessionTicketOpen, which called TicketCrypter::Decrypt, which + // synchronously called this function.) In that case, the handshake will + // continue to be processed when this function returns. + // + // When this callback is called asynchronously (i.e. the ticket decryption + // is pending), TlsServerHandshaker is not actively processing handshake + // messages. We need to have it resume processing handshake messages by + // calling AdvanceHandshake. + if (is_async) { + handshaker->AdvanceHandshakeFromCallback(); + } + + handshaker->ticket_decryption_callback_ = nullptr; +} + +void TlsServerHandshaker::DecryptCallback::Cancel() { + QUICHE_DCHECK(handshaker_); + handshaker_ = nullptr; +} + +TlsServerHandshaker::TlsServerHandshaker( + QuicSession* session, const QuicCryptoServerConfig* crypto_config) + : TlsHandshaker(this, session), + QuicCryptoServerStreamBase(session), + proof_source_(crypto_config->proof_source()), + pre_shared_key_(crypto_config->pre_shared_key()), + crypto_negotiated_params_(new QuicCryptoNegotiatedParameters), + tls_connection_(crypto_config->ssl_ctx(), this, session->GetSSLConfig()), + crypto_config_(crypto_config) { + QUIC_DVLOG(1) << "TlsServerHandshaker: client_cert_mode initial value: " + << client_cert_mode(); + + QUICHE_DCHECK_EQ(PROTOCOL_TLS1_3, + session->connection()->version().handshake_protocol); + + // Configure the SSL to be a server. + SSL_set_accept_state(ssl()); + + // Make sure we use the right TLS extension codepoint. + int use_legacy_extension = 0; + if (session->version().UsesLegacyTlsExtension()) { + use_legacy_extension = 1; + } + SSL_set_quic_use_legacy_codepoint(ssl(), use_legacy_extension); + + if (session->connection()->context()->tracer) { + tls_connection_.EnableInfoCallback(); + } +} + +TlsServerHandshaker::~TlsServerHandshaker() { CancelOutstandingCallbacks(); } + +void TlsServerHandshaker::CancelOutstandingCallbacks() { + if (proof_source_handle_) { + proof_source_handle_->CloseHandle(); + } + if (ticket_decryption_callback_) { + ticket_decryption_callback_->Cancel(); + ticket_decryption_callback_ = nullptr; + } +} + +void TlsServerHandshaker::InfoCallback(int type, int value) { + QuicConnectionTracer* tracer = + session()->connection()->context()->tracer.get(); + + if (tracer == nullptr) { + return; + } + + if (type & SSL_CB_LOOP) { + tracer->PrintString( + absl::StrCat("SSL:ACCEPT_LOOP:", SSL_state_string_long(ssl()))); + } else if (type & SSL_CB_ALERT) { + const char* prefix = + (type & SSL_CB_READ) ? "SSL:READ_ALERT:" : "SSL:WRITE_ALERT:"; + tracer->PrintString(absl::StrCat(prefix, SSL_alert_type_string_long(value), + ":", SSL_alert_desc_string_long(value))); + } else if (type & SSL_CB_EXIT) { + const char* prefix = + (value == 1) ? "SSL:ACCEPT_EXIT_OK:" : "SSL:ACCEPT_EXIT_FAIL:"; + tracer->PrintString(absl::StrCat(prefix, SSL_state_string_long(ssl()))); + } else if (type & SSL_CB_HANDSHAKE_START) { + tracer->PrintString( + absl::StrCat("SSL:HANDSHAKE_START:", SSL_state_string_long(ssl()))); + } else if (type & SSL_CB_HANDSHAKE_DONE) { + tracer->PrintString( + absl::StrCat("SSL:HANDSHAKE_DONE:", SSL_state_string_long(ssl()))); + } else { + QUIC_DLOG(INFO) << "Unknown event type " << type << ": " + << SSL_state_string_long(ssl()); + tracer->PrintString( + absl::StrCat("SSL:unknown:", value, ":", SSL_state_string_long(ssl()))); + } +} + +std::unique_ptr +TlsServerHandshaker::MaybeCreateProofSourceHandle() { + return std::make_unique(this, proof_source_); +} + +bool TlsServerHandshaker::GetBase64SHA256ClientChannelID( + std::string* /*output*/) const { + // Channel ID is not supported when TLS is used in QUIC. + return false; +} + +void TlsServerHandshaker::SendServerConfigUpdate( + const CachedNetworkParameters* /*cached_network_params*/) { + // SCUP messages aren't supported when using the TLS handshake. +} + +bool TlsServerHandshaker::DisableResumption() { + if (!can_disable_resumption_ || !session()->connection()->connected()) { + return false; + } + tls_connection_.DisableTicketSupport(); + return true; +} + +bool TlsServerHandshaker::IsZeroRtt() const { + return SSL_early_data_accepted(ssl()); +} + +bool TlsServerHandshaker::IsResumption() const { + return SSL_session_reused(ssl()); +} + +bool TlsServerHandshaker::ResumptionAttempted() const { + return ticket_received_; +} + +bool TlsServerHandshaker::EarlyDataAttempted() const { + QUIC_BUG_IF(quic_tls_early_data_attempted_too_early, + !select_cert_status_.has_value()) + << "EarlyDataAttempted must be called after EarlySelectCertCallback is " + "started"; + return early_data_attempted_; +} + +int TlsServerHandshaker::NumServerConfigUpdateMessagesSent() const { + // SCUP messages aren't supported when using the TLS handshake. + return 0; +} + +const CachedNetworkParameters* +TlsServerHandshaker::PreviousCachedNetworkParams() const { + return last_received_cached_network_params_.get(); +} + +void TlsServerHandshaker::SetPreviousCachedNetworkParams( + CachedNetworkParameters cached_network_params) { + last_received_cached_network_params_ = + std::make_unique(cached_network_params); +} + +void TlsServerHandshaker::OnPacketDecrypted(EncryptionLevel level) { + if (level == ENCRYPTION_HANDSHAKE && state_ < HANDSHAKE_PROCESSED) { + state_ = HANDSHAKE_PROCESSED; + handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_INITIAL); + handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_INITIAL); + } +} + +void TlsServerHandshaker::OnHandshakeDoneReceived() { QUICHE_DCHECK(false); } + +void TlsServerHandshaker::OnNewTokenReceived(absl::string_view /*token*/) { + QUICHE_DCHECK(false); +} + +std::string TlsServerHandshaker::GetAddressToken( + const CachedNetworkParameters* cached_network_params) const { + SourceAddressTokens empty_previous_tokens; + const QuicConnection* connection = session()->connection(); + return crypto_config_->NewSourceAddressToken( + crypto_config_->source_address_token_boxer(), empty_previous_tokens, + connection->effective_peer_address().host(), + connection->random_generator(), connection->clock()->WallNow(), + cached_network_params); +} + +bool TlsServerHandshaker::ValidateAddressToken(absl::string_view token) const { + SourceAddressTokens tokens; + HandshakeFailureReason reason = crypto_config_->ParseSourceAddressToken( + crypto_config_->source_address_token_boxer(), token, tokens); + if (reason != HANDSHAKE_OK) { + QUIC_DLOG(WARNING) << "Failed to parse source address token: " + << CryptoUtils::HandshakeFailureReasonToString(reason); + return false; + } + auto cached_network_params = std::make_unique(); + reason = crypto_config_->ValidateSourceAddressTokens( + tokens, session()->connection()->effective_peer_address().host(), + session()->connection()->clock()->WallNow(), cached_network_params.get()); + if (reason != HANDSHAKE_OK) { + QUIC_DLOG(WARNING) << "Failed to validate source address token: " + << CryptoUtils::HandshakeFailureReasonToString(reason); + return false; + } + + last_received_cached_network_params_ = std::move(cached_network_params); + return true; +} + +bool TlsServerHandshaker::ShouldSendExpectCTHeader() const { return false; } + +bool TlsServerHandshaker::DidCertMatchSni() const { return cert_matched_sni_; } + +const ProofSource::Details* TlsServerHandshaker::ProofSourceDetails() const { + return proof_source_details_.get(); +} + +bool TlsServerHandshaker::ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + return ExportKeyingMaterialForLabel(label, context, result_len, result); +} + +void TlsServerHandshaker::OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) { + TlsHandshaker::OnConnectionClosed(error, source); +} + +ssl_early_data_reason_t TlsServerHandshaker::EarlyDataReason() const { + return TlsHandshaker::EarlyDataReason(); +} + +bool TlsServerHandshaker::encryption_established() const { + return encryption_established_; +} + +bool TlsServerHandshaker::one_rtt_keys_available() const { + return state_ == HANDSHAKE_CONFIRMED; +} + +const QuicCryptoNegotiatedParameters& +TlsServerHandshaker::crypto_negotiated_params() const { + return *crypto_negotiated_params_; +} + +CryptoMessageParser* TlsServerHandshaker::crypto_message_parser() { + return TlsHandshaker::crypto_message_parser(); +} + +HandshakeState TlsServerHandshaker::GetHandshakeState() const { return state_; } + +void TlsServerHandshaker::SetServerApplicationStateForResumption( + std::unique_ptr state) { + application_state_ = std::move(state); +} + +size_t TlsServerHandshaker::BufferSizeLimitForLevel( + EncryptionLevel level) const { + return TlsHandshaker::BufferSizeLimitForLevel(level); +} + +std::unique_ptr +TlsServerHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter() { + return TlsHandshaker::AdvanceKeysAndCreateCurrentOneRttDecrypter(); +} + +std::unique_ptr +TlsServerHandshaker::CreateCurrentOneRttEncrypter() { + return TlsHandshaker::CreateCurrentOneRttEncrypter(); +} + +void TlsServerHandshaker::OverrideQuicConfigDefaults(QuicConfig* /*config*/) {} + +void TlsServerHandshaker::AdvanceHandshakeFromCallback() { + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); + + AdvanceHandshake(); + if (!is_connection_closed()) { + handshaker_delegate()->OnHandshakeCallbackDone(); + } +} + +bool TlsServerHandshaker::ProcessTransportParameters( + const SSL_CLIENT_HELLO* client_hello, std::string* error_details) { + TransportParameters client_params; + const uint8_t* client_params_bytes; + size_t params_bytes_len; + + // Make sure we use the right TLS extension codepoint. + uint16_t extension_type = TLSEXT_TYPE_quic_transport_parameters_standard; + if (session()->version().UsesLegacyTlsExtension()) { + extension_type = TLSEXT_TYPE_quic_transport_parameters_legacy; + } + // When using early select cert callback, SSL_get_peer_quic_transport_params + // can not be used to retrieve the client's transport parameters, but we can + // use SSL_early_callback_ctx_extension_get to do that. + if (!SSL_early_callback_ctx_extension_get(client_hello, extension_type, + &client_params_bytes, + ¶ms_bytes_len)) { + params_bytes_len = 0; + } + + if (params_bytes_len == 0) { + *error_details = "Client's transport parameters are missing"; + return false; + } + std::string parse_error_details; + if (!ParseTransportParameters(session()->connection()->version(), + Perspective::IS_CLIENT, client_params_bytes, + params_bytes_len, &client_params, + &parse_error_details)) { + QUICHE_DCHECK(!parse_error_details.empty()); + *error_details = + "Unable to parse client's transport parameters: " + parse_error_details; + return false; + } + + // Notify QuicConnectionDebugVisitor. + session()->connection()->OnTransportParametersReceived(client_params); + + if (client_params.legacy_version_information.has_value() && + CryptoUtils::ValidateClientHelloVersion( + client_params.legacy_version_information.value().version, + session()->connection()->version(), session()->supported_versions(), + error_details) != QUIC_NO_ERROR) { + return false; + } + + if (client_params.version_information.has_value() && + !CryptoUtils::ValidateChosenVersion( + client_params.version_information.value().chosen_version, + session()->version(), error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + + if (handshaker_delegate()->ProcessTransportParameters( + client_params, /* is_resumption = */ false, error_details) != + QUIC_NO_ERROR) { + return false; + } + + ProcessAdditionalTransportParameters(client_params); + + return true; +} + +TlsServerHandshaker::SetTransportParametersResult +TlsServerHandshaker::SetTransportParameters() { + SetTransportParametersResult result; + QUICHE_DCHECK(!result.success); + + server_params_.perspective = Perspective::IS_SERVER; + server_params_.legacy_version_information = + TransportParameters::LegacyVersionInformation(); + server_params_.legacy_version_information.value().supported_versions = + CreateQuicVersionLabelVector(session()->supported_versions()); + server_params_.legacy_version_information.value().version = + CreateQuicVersionLabel(session()->connection()->version()); + server_params_.version_information = + TransportParameters::VersionInformation(); + server_params_.version_information.value().chosen_version = + CreateQuicVersionLabel(session()->version()); + server_params_.version_information.value().other_versions = + CreateQuicVersionLabelVector(session()->supported_versions()); + + if (!handshaker_delegate()->FillTransportParameters(&server_params_)) { + return result; + } + + // Notify QuicConnectionDebugVisitor. + session()->connection()->OnTransportParametersSent(server_params_); + + { // Ensure |server_params_bytes| is not accessed out of the scope. + std::vector server_params_bytes; + if (!SerializeTransportParameters(server_params_, &server_params_bytes) || + SSL_set_quic_transport_params(ssl(), server_params_bytes.data(), + server_params_bytes.size()) != 1) { + return result; + } + result.quic_transport_params = std::move(server_params_bytes); + } + + if (application_state_) { + std::vector early_data_context; + if (!SerializeTransportParametersForTicket( + server_params_, *application_state_, &early_data_context)) { + QUIC_BUG(quic_bug_10341_4) + << "Failed to serialize Transport Parameters for ticket."; + result.early_data_context = std::vector(); + return result; + } + SSL_set_quic_early_data_context(ssl(), early_data_context.data(), + early_data_context.size()); + result.early_data_context = std::move(early_data_context); + application_state_.reset(nullptr); + } + result.success = true; + return result; +} + +bool TlsServerHandshaker::TransportParametersMatch( + absl::Span serialized_params) const { + TransportParameters params; + std::string error_details; + + bool parse_ok = ParseTransportParameters( + session()->version(), Perspective::IS_SERVER, serialized_params.data(), + serialized_params.size(), ¶ms, &error_details); + + if (!parse_ok) { + return false; + } + + DegreaseTransportParameters(params); + + return params == server_params_; +} + +void TlsServerHandshaker::SetWriteSecret( + EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span write_secret) { + if (is_connection_closed()) { + return; + } + if (level == ENCRYPTION_FORWARD_SECURE) { + encryption_established_ = true; + // Fill crypto_negotiated_params_: + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + if (cipher) { + crypto_negotiated_params_->cipher_suite = + SSL_CIPHER_get_protocol_id(cipher); + } + crypto_negotiated_params_->key_exchange_group = SSL_get_curve_id(ssl()); + crypto_negotiated_params_->encrypted_client_hello = SSL_ech_accepted(ssl()); + } + TlsHandshaker::SetWriteSecret(level, cipher, write_secret); +} + +std::string TlsServerHandshaker::GetAcceptChValueForHostname( + const std::string& /*hostname*/) const { + return {}; +} + +void TlsServerHandshaker::FinishHandshake() { + QUICHE_DCHECK(!SSL_in_early_data(ssl())); + + if (!valid_alpn_received_) { + QUIC_DLOG(ERROR) + << "Server: handshake finished without receiving a known ALPN"; + // TODO(b/130164908) this should send no_application_protocol + // instead of QUIC_HANDSHAKE_FAILED. + CloseConnection(QUIC_HANDSHAKE_FAILED, + "Server did not receive a known ALPN"); + return; + } + + ssl_early_data_reason_t reason_code = EarlyDataReason(); + QUIC_DLOG(INFO) << "Server: handshake finished. Early data reason " + << reason_code << " (" + << CryptoUtils::EarlyDataReasonToString(reason_code) << ")"; + state_ = HANDSHAKE_CONFIRMED; + + handshaker_delegate()->OnTlsHandshakeComplete(); + handshaker_delegate()->DiscardOldEncryptionKey(ENCRYPTION_HANDSHAKE); + handshaker_delegate()->DiscardOldDecryptionKey(ENCRYPTION_HANDSHAKE); + // ENCRYPTION_ZERO_RTT decryption key is not discarded here as "Servers MAY + // temporarily retain 0-RTT keys to allow decrypting reordered packets + // without requiring their contents to be retransmitted with 1-RTT keys." + // It is expected that QuicConnection will discard the key at an + // appropriate time. +} + +QuicAsyncStatus TlsServerHandshaker::VerifyCertChain( + const std::vector& /*certs*/, std::string* /*error_details*/, + std::unique_ptr* /*details*/, uint8_t* /*out_alert*/, + std::unique_ptr /*callback*/) { + QUIC_DVLOG(1) << "VerifyCertChain returning success"; + + // No real verification here. A subclass can override this function to verify + // the client cert if needed. + return QUIC_SUCCESS; +} + +void TlsServerHandshaker::OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& /*verify_details*/) {} + +ssl_private_key_result_t TlsServerHandshaker::PrivateKeySign( + uint8_t* out, size_t* out_len, size_t max_out, uint16_t sig_alg, + absl::string_view in) { + QUICHE_DCHECK_EQ(expected_ssl_error(), SSL_ERROR_WANT_READ); + + QuicAsyncStatus status = proof_source_handle_->ComputeSignature( + session()->connection()->self_address(), + session()->connection()->peer_address(), crypto_negotiated_params_->sni, + sig_alg, in, max_out); + if (status == QUIC_PENDING) { + set_expected_ssl_error(SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); + if (async_op_timer_.has_value()) { + QUIC_CODE_COUNT( + quic_tls_server_computing_signature_while_another_op_pending); + } + async_op_timer_ = QuicTimeAccumulator(); + async_op_timer_->Start(now()); + } + return PrivateKeyComplete(out, out_len, max_out); +} + +ssl_private_key_result_t TlsServerHandshaker::PrivateKeyComplete( + uint8_t* out, size_t* out_len, size_t max_out) { + if (expected_ssl_error() == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) { + return ssl_private_key_retry; + } + + const bool success = HasValidSignature(max_out); + QuicConnectionStats::TlsServerOperationStats compute_signature_stats; + compute_signature_stats.success = success; + if (async_op_timer_.has_value()) { + async_op_timer_->Stop(now()); + compute_signature_stats.async_latency = + async_op_timer_->GetTotalElapsedTime(); + async_op_timer_.reset(); + RECORD_LATENCY_IN_US("tls_server_async_compute_signature_latency_us", + compute_signature_stats.async_latency, + "Async compute signature latency in microseconds"); + } + connection_stats().tls_server_compute_signature_stats = + std::move(compute_signature_stats); + + if (!success) { + return ssl_private_key_failure; + } + *out_len = cert_verify_sig_.size(); + memcpy(out, cert_verify_sig_.data(), *out_len); + cert_verify_sig_.clear(); + cert_verify_sig_.shrink_to_fit(); + return ssl_private_key_success; +} + +void TlsServerHandshaker::OnComputeSignatureDone( + bool ok, bool is_sync, std::string signature, + std::unique_ptr details) { + QUIC_DVLOG(1) << "OnComputeSignatureDone. ok:" << ok + << ", is_sync:" << is_sync + << ", len(signature):" << signature.size(); + absl::optional context_switcher; + + if (!is_sync) { + context_switcher.emplace(connection_context()); + } + + QUIC_TRACESTRING(absl::StrCat("TLS compute signature done. ok:", ok, + ", len(signature):", signature.size())); + + if (ok) { + cert_verify_sig_ = std::move(signature); + proof_source_details_ = std::move(details); + } + const int last_expected_ssl_error = expected_ssl_error(); + set_expected_ssl_error(SSL_ERROR_WANT_READ); + if (!is_sync) { + QUICHE_DCHECK_EQ(last_expected_ssl_error, + SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); + AdvanceHandshakeFromCallback(); + } +} + +bool TlsServerHandshaker::HasValidSignature(size_t max_signature_size) const { + return !cert_verify_sig_.empty() && + cert_verify_sig_.size() <= max_signature_size; +} + +size_t TlsServerHandshaker::SessionTicketMaxOverhead() { + QUICHE_DCHECK(proof_source_->GetTicketCrypter()); + return proof_source_->GetTicketCrypter()->MaxOverhead(); +} + +int TlsServerHandshaker::SessionTicketSeal(uint8_t* out, size_t* out_len, + size_t max_out_len, + absl::string_view in) { + QUICHE_DCHECK(proof_source_->GetTicketCrypter()); + std::vector ticket = + proof_source_->GetTicketCrypter()->Encrypt(in, ticket_encryption_key_); + if (GetQuicReloadableFlag( + quic_send_placeholder_ticket_when_encrypt_ticket_fails) && + ticket.empty()) { + QUIC_CODE_COUNT(quic_tls_server_handshaker_send_placeholder_ticket); + const absl::string_view kTicketFailurePlaceholder = "TICKET FAILURE"; + const absl::string_view kTicketWithSizeLimit = + kTicketFailurePlaceholder.substr(0, max_out_len); + ticket.assign(kTicketWithSizeLimit.begin(), kTicketWithSizeLimit.end()); + } + if (max_out_len < ticket.size()) { + QUIC_BUG(quic_bug_12423_2) + << "TicketCrypter returned " << ticket.size() + << " bytes of ciphertext, which is larger than its max overhead of " + << max_out_len; + return 0; // failure + } + *out_len = ticket.size(); + memcpy(out, ticket.data(), ticket.size()); + QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_sealed); + return 1; // success +} + +ssl_ticket_aead_result_t TlsServerHandshaker::SessionTicketOpen( + uint8_t* out, size_t* out_len, size_t max_out_len, absl::string_view in) { + QUICHE_DCHECK(proof_source_->GetTicketCrypter()); + + if (ignore_ticket_open_) { + // SetIgnoreTicketOpen has been called. Typically this means the caller is + // using handshake hints and expect the hints to contain ticket decryption + // results. + QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_ignored_1); + return ssl_ticket_aead_ignore_ticket; + } + + if (!ticket_decryption_callback_) { + ticket_decryption_callback_ = std::make_shared(this); + proof_source_->GetTicketCrypter()->Decrypt(in, ticket_decryption_callback_); + + // Decrypt can run the callback synchronously. In that case, the callback + // will clear the ticket_decryption_callback_ pointer, and instead of + // returning ssl_ticket_aead_retry, we should continue processing to + // return the decrypted ticket. + // + // If the callback is not run synchronously, return ssl_ticket_aead_retry + // and when the callback is complete this function will be run again to + // return the result. + if (ticket_decryption_callback_) { + QUICHE_DCHECK(!ticket_decryption_callback_->IsDone()); + set_expected_ssl_error(SSL_ERROR_PENDING_TICKET); + if (async_op_timer_.has_value()) { + QUIC_CODE_COUNT( + quic_tls_server_decrypting_ticket_while_another_op_pending); + } + async_op_timer_ = QuicTimeAccumulator(); + async_op_timer_->Start(now()); + } + } + + // If the async ticket decryption is pending, either started by this + // SessionTicketOpen call or one that happened earlier, return + // ssl_ticket_aead_retry. + if (ticket_decryption_callback_ && !ticket_decryption_callback_->IsDone()) { + return ssl_ticket_aead_retry; + } + + ssl_ticket_aead_result_t result = + FinalizeSessionTicketOpen(out, out_len, max_out_len); + + QuicConnectionStats::TlsServerOperationStats decrypt_ticket_stats; + decrypt_ticket_stats.success = (result == ssl_ticket_aead_success); + if (async_op_timer_.has_value()) { + async_op_timer_->Stop(now()); + decrypt_ticket_stats.async_latency = async_op_timer_->GetTotalElapsedTime(); + async_op_timer_.reset(); + RECORD_LATENCY_IN_US("tls_server_async_decrypt_ticket_latency_us", + decrypt_ticket_stats.async_latency, + "Async decrypt ticket latency in microseconds"); + } + connection_stats().tls_server_decrypt_ticket_stats = + std::move(decrypt_ticket_stats); + + return result; +} + +ssl_ticket_aead_result_t TlsServerHandshaker::FinalizeSessionTicketOpen( + uint8_t* out, size_t* out_len, size_t max_out_len) { + ticket_decryption_callback_ = nullptr; + set_expected_ssl_error(SSL_ERROR_WANT_READ); + if (decrypted_session_ticket_.empty()) { + QUIC_DLOG(ERROR) << "Session ticket decryption failed; ignoring ticket"; + // Ticket decryption failed. Ignore the ticket. + QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_ignored_2); + return ssl_ticket_aead_ignore_ticket; + } + if (max_out_len < decrypted_session_ticket_.size()) { + return ssl_ticket_aead_error; + } + memcpy(out, decrypted_session_ticket_.data(), + decrypted_session_ticket_.size()); + *out_len = decrypted_session_ticket_.size(); + + QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_opened); + return ssl_ticket_aead_success; +} + +ssl_select_cert_result_t TlsServerHandshaker::EarlySelectCertCallback( + const SSL_CLIENT_HELLO* client_hello) { + // EarlySelectCertCallback can be called twice from BoringSSL: If the first + // call returns ssl_select_cert_retry, when cert selection completes, + // SSL_do_handshake will call it again. + + if (select_cert_status_.has_value()) { + // This is the second call, return the result directly. + QUIC_DVLOG(1) << "EarlySelectCertCallback called to continue handshake, " + "returning directly. success:" + << (select_cert_status_.value() == QUIC_SUCCESS); + return (select_cert_status_.value() == QUIC_SUCCESS) + ? ssl_select_cert_success + : ssl_select_cert_error; + } + + // This is the first call. + select_cert_status_ = QUIC_PENDING; + proof_source_handle_ = MaybeCreateProofSourceHandle(); + + if (!pre_shared_key_.empty()) { + // TODO(b/154162689) add PSK support to QUIC+TLS. + QUIC_BUG(quic_bug_10341_6) + << "QUIC server pre-shared keys not yet supported with TLS"; + return ssl_select_cert_error; + } + + { + const uint8_t* unused_extension_bytes; + size_t unused_extension_len; + ticket_received_ = SSL_early_callback_ctx_extension_get( + client_hello, TLSEXT_TYPE_pre_shared_key, &unused_extension_bytes, + &unused_extension_len); + + early_data_attempted_ = SSL_early_callback_ctx_extension_get( + client_hello, TLSEXT_TYPE_early_data, &unused_extension_bytes, + &unused_extension_len); + } + + // This callback is called very early by Boring SSL, most of the SSL_get_foo + // function do not work at this point, but SSL_get_servername does. + const char* hostname = SSL_get_servername(ssl(), TLSEXT_NAMETYPE_host_name); + if (hostname) { + crypto_negotiated_params_->sni = + QuicHostnameUtils::NormalizeHostname(hostname); + if (!ValidateHostname(hostname)) { + return ssl_select_cert_error; + } + if (hostname != crypto_negotiated_params_->sni) { + QUIC_CODE_COUNT(quic_tls_server_hostname_diff); + QUIC_LOG_EVERY_N_SEC(WARNING, 300) + << "Raw and normalized hostnames differ, but both are valid SNIs. " + "raw hostname:" + << hostname << ", normalized:" << crypto_negotiated_params_->sni; + } else { + QUIC_CODE_COUNT(quic_tls_server_hostname_same); + } + } else { + QUIC_LOG(INFO) << "No hostname indicated in SNI"; + } + + std::string error_details; + if (!ProcessTransportParameters(client_hello, &error_details)) { + CloseConnection(QUIC_HANDSHAKE_FAILED, error_details); + return ssl_select_cert_error; + } + OverrideQuicConfigDefaults(session()->config()); + session()->OnConfigNegotiated(); + + auto set_transport_params_result = SetTransportParameters(); + if (!set_transport_params_result.success) { + QUIC_LOG(ERROR) << "Failed to set transport parameters"; + return ssl_select_cert_error; + } + + bssl::UniquePtr ssl_capabilities; + size_t ssl_capabilities_len = 0; + absl::string_view ssl_capabilities_view; + + if (CryptoUtils::GetSSLCapabilities(ssl(), &ssl_capabilities, + &ssl_capabilities_len)) { + ssl_capabilities_view = + absl::string_view(reinterpret_cast(ssl_capabilities.get()), + ssl_capabilities_len); + } + + // Enable ALPS for the session's ALPN. + SetApplicationSettingsResult alps_result = + SetApplicationSettings(AlpnForVersion(session()->version())); + if (!alps_result.success) { + return ssl_select_cert_error; + } + + if (!session()->connection()->connected()) { + select_cert_status_ = QUIC_FAILURE; + return ssl_select_cert_error; + } + + can_disable_resumption_ = false; + const QuicAsyncStatus status = proof_source_handle_->SelectCertificate( + session()->connection()->self_address().Normalized(), + session()->connection()->peer_address().Normalized(), + session()->connection()->GetOriginalDestinationConnectionId(), + ssl_capabilities_view, crypto_negotiated_params_->sni, + absl::string_view( + reinterpret_cast(client_hello->client_hello), + client_hello->client_hello_len), + AlpnForVersion(session()->version()), std::move(alps_result.alps_buffer), + set_transport_params_result.quic_transport_params, + set_transport_params_result.early_data_context, + tls_connection_.ssl_config()); + + QUICHE_DCHECK_EQ(status, select_cert_status().value()); + + if (status == QUIC_PENDING) { + set_expected_ssl_error(SSL_ERROR_PENDING_CERTIFICATE); + if (async_op_timer_.has_value()) { + QUIC_CODE_COUNT(quic_tls_server_selecting_cert_while_another_op_pending); + } + async_op_timer_ = QuicTimeAccumulator(); + async_op_timer_->Start(now()); + return ssl_select_cert_retry; + } + + if (status == QUIC_FAILURE) { + return ssl_select_cert_error; + } + + return ssl_select_cert_success; +} + +void TlsServerHandshaker::OnSelectCertificateDone( + bool ok, bool is_sync, const ProofSource::Chain* chain, + absl::string_view handshake_hints, absl::string_view ticket_encryption_key, + bool cert_matched_sni, QuicDelayedSSLConfig delayed_ssl_config) { + QUIC_DVLOG(1) << "OnSelectCertificateDone. ok:" << ok + << ", is_sync:" << is_sync + << ", len(handshake_hints):" << handshake_hints.size() + << ", len(ticket_encryption_key):" + << ticket_encryption_key.size(); + absl::optional context_switcher; + if (!is_sync) { + context_switcher.emplace(connection_context()); + } + + QUIC_TRACESTRING(absl::StrCat( + "TLS select certificate done: ok:", ok, + ", certs_found:", (chain != nullptr && !chain->certs.empty()), + ", len(handshake_hints):", handshake_hints.size(), + ", len(ticket_encryption_key):", ticket_encryption_key.size())); + + ticket_encryption_key_ = std::string(ticket_encryption_key); + select_cert_status_ = QUIC_FAILURE; + cert_matched_sni_ = cert_matched_sni; + + if (delayed_ssl_config.quic_transport_parameters.has_value()) { + // In case of any error the SSL object is still valid. Handshaker may need + // to call ComputeSignature but otherwise can proceed. + if (TransportParametersMatch( + absl::MakeSpan(*delayed_ssl_config.quic_transport_parameters))) { + if (SSL_set_quic_transport_params( + ssl(), delayed_ssl_config.quic_transport_parameters->data(), + delayed_ssl_config.quic_transport_parameters->size()) != 1) { + QUIC_DVLOG(1) << "SSL_set_quic_transport_params override failed"; + } + } else { + QUIC_DVLOG(1) + << "QUIC transport parameters mismatch with ProofSourceHandle"; + } + } + + if (delayed_ssl_config.client_cert_mode.has_value()) { + tls_connection_.SetClientCertMode(*delayed_ssl_config.client_cert_mode); + QUIC_DVLOG(1) << "client_cert_mode after cert selection: " + << client_cert_mode(); + } + + if (ok) { + if (chain && !chain->certs.empty()) { + tls_connection_.SetCertChain(chain->ToCryptoBuffers().value); + if (!handshake_hints.empty() && + !SSL_set_handshake_hints( + ssl(), reinterpret_cast(handshake_hints.data()), + handshake_hints.size())) { + // If |SSL_set_handshake_hints| fails, the ssl() object will remain + // intact, it is as if we didn't call it. The handshaker will + // continue to compute signature/decrypt ticket as normal. + QUIC_CODE_COUNT(quic_tls_server_set_handshake_hints_failed); + QUIC_DVLOG(1) << "SSL_set_handshake_hints failed"; + } + select_cert_status_ = QUIC_SUCCESS; + } else { + QUIC_DLOG(ERROR) << "No certs provided for host '" + << crypto_negotiated_params_->sni << "', server_address:" + << session()->connection()->self_address() + << ", client_address:" + << session()->connection()->peer_address(); + } + } + + QuicConnectionStats::TlsServerOperationStats select_cert_stats; + select_cert_stats.success = (select_cert_status_ == QUIC_SUCCESS); + QUICHE_DCHECK_NE(is_sync, async_op_timer_.has_value()); + if (async_op_timer_.has_value()) { + async_op_timer_->Stop(now()); + select_cert_stats.async_latency = async_op_timer_->GetTotalElapsedTime(); + async_op_timer_.reset(); + RECORD_LATENCY_IN_US("tls_server_async_select_cert_latency_us", + select_cert_stats.async_latency, + "Async select cert latency in microseconds"); + } + connection_stats().tls_server_select_cert_stats = + std::move(select_cert_stats); + + const int last_expected_ssl_error = expected_ssl_error(); + set_expected_ssl_error(SSL_ERROR_WANT_READ); + if (!is_sync) { + QUICHE_DCHECK_EQ(last_expected_ssl_error, SSL_ERROR_PENDING_CERTIFICATE); + AdvanceHandshakeFromCallback(); + } +} + +bool TlsServerHandshaker::WillNotCallComputeSignature() const { + return SSL_can_release_private_key(ssl()); +} + +bool TlsServerHandshaker::ValidateHostname(const std::string& hostname) const { + if (!QuicHostnameUtils::IsValidSNI(hostname)) { + // TODO(b/151676147): Include this error string in the CONNECTION_CLOSE + // frame. + QUIC_DLOG(ERROR) << "Invalid SNI provided: \"" << hostname << "\""; + return false; + } + return true; +} + +int TlsServerHandshaker::TlsExtServernameCallback(int* /*out_alert*/) { + // SSL_TLSEXT_ERR_OK causes the server_name extension to be acked in + // ServerHello. + return SSL_TLSEXT_ERR_OK; +} + +int TlsServerHandshaker::SelectAlpn(const uint8_t** out, uint8_t* out_len, + const uint8_t* in, unsigned in_len) { + // |in| contains a sequence of 1-byte-length-prefixed values. + *out_len = 0; + *out = nullptr; + if (in_len == 0) { + QUIC_DLOG(ERROR) << "No ALPN provided by client"; + return SSL_TLSEXT_ERR_NOACK; + } + + CBS all_alpns; + CBS_init(&all_alpns, in, in_len); + + std::vector alpns; + while (CBS_len(&all_alpns) > 0) { + CBS alpn; + if (!CBS_get_u8_length_prefixed(&all_alpns, &alpn)) { + QUIC_DLOG(ERROR) << "Failed to parse ALPN length"; + return SSL_TLSEXT_ERR_NOACK; + } + + const size_t alpn_length = CBS_len(&alpn); + if (alpn_length == 0) { + QUIC_DLOG(ERROR) << "Received invalid zero-length ALPN"; + return SSL_TLSEXT_ERR_NOACK; + } + + alpns.emplace_back(reinterpret_cast(CBS_data(&alpn)), + alpn_length); + } + + // TODO(wub): Remove QuicSession::SelectAlpn. QuicSessions should know the + // ALPN on construction. + auto selected_alpn = session()->SelectAlpn(alpns); + if (selected_alpn == alpns.end()) { + QUIC_DLOG(ERROR) << "No known ALPN provided by client"; + return SSL_TLSEXT_ERR_NOACK; + } + + session()->OnAlpnSelected(*selected_alpn); + valid_alpn_received_ = true; + *out_len = selected_alpn->size(); + *out = reinterpret_cast(selected_alpn->data()); + return SSL_TLSEXT_ERR_OK; +} + +TlsServerHandshaker::SetApplicationSettingsResult +TlsServerHandshaker::SetApplicationSettings(absl::string_view alpn) { + TlsServerHandshaker::SetApplicationSettingsResult result; + + const std::string& hostname = crypto_negotiated_params_->sni; + std::string accept_ch_value = GetAcceptChValueForHostname(hostname); + std::string origin = absl::StrCat("https://", hostname); + uint16_t port = session()->self_address().port(); + if (port != kDefaultPort) { + // This should be rare in production, but useful for test servers. + QUIC_CODE_COUNT(quic_server_alps_non_default_port); + absl::StrAppend(&origin, ":", port); + } + + if (!accept_ch_value.empty()) { + AcceptChFrame frame{{{std::move(origin), std::move(accept_ch_value)}}}; + result.alps_buffer = HttpEncoder::SerializeAcceptChFrame(frame); + } + + const std::string& alps = result.alps_buffer; + if (SSL_add_application_settings( + ssl(), reinterpret_cast(alpn.data()), alpn.size(), + reinterpret_cast(alps.data()), alps.size()) != 1) { + QUIC_DLOG(ERROR) << "Failed to enable ALPS"; + result.success = false; + } else { + result.success = true; + } + return result; +} + +SSL* TlsServerHandshaker::GetSsl() const { return ssl(); } + +bool TlsServerHandshaker::IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const { + return level != ENCRYPTION_ZERO_RTT; +} + +EncryptionLevel TlsServerHandshaker::GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } +} + +} // namespace quic diff --git a/quiche/quic/core/tls_server_handshaker.h b/quiche/quic/core/tls_server_handshaker.h new file mode 100644 index 000000000000..9863ea7446d7 --- /dev/null +++ b/quiche/quic/core/tls_server_handshaker.h @@ -0,0 +1,386 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_TLS_SERVER_HANDSHAKER_H_ +#define QUICHE_QUIC_CORE_TLS_SERVER_HANDSHAKER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "openssl/pool.h" +#include "openssl/ssl.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/tls_server_connection.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_time_accumulator.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/tls_handshaker.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_flag_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +// An implementation of QuicCryptoServerStreamBase which uses +// TLS 1.3 for the crypto handshake protocol. +class QUIC_EXPORT_PRIVATE TlsServerHandshaker + : public TlsHandshaker, + public TlsServerConnection::Delegate, + public ProofSourceHandleCallback, + public QuicCryptoServerStreamBase { + public: + // |crypto_config| must outlive TlsServerHandshaker. + TlsServerHandshaker(QuicSession* session, + const QuicCryptoServerConfig* crypto_config); + TlsServerHandshaker(const TlsServerHandshaker&) = delete; + TlsServerHandshaker& operator=(const TlsServerHandshaker&) = delete; + + ~TlsServerHandshaker() override; + + // From QuicCryptoServerStreamBase + void CancelOutstandingCallbacks() override; + bool GetBase64SHA256ClientChannelID(std::string* output) const override; + void SendServerConfigUpdate( + const CachedNetworkParameters* cached_network_params) override; + bool DisableResumption() override; + bool IsZeroRtt() const override; + bool IsResumption() const override; + bool ResumptionAttempted() const override; + // Must be called after EarlySelectCertCallback is started. + bool EarlyDataAttempted() const override; + int NumServerConfigUpdateMessagesSent() const override; + const CachedNetworkParameters* PreviousCachedNetworkParams() const override; + void SetPreviousCachedNetworkParams( + CachedNetworkParameters cached_network_params) override; + void OnPacketDecrypted(EncryptionLevel level) override; + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnConnectionClosed(QuicErrorCode error, + ConnectionCloseSource source) override; + void OnHandshakeDoneReceived() override; + std::string GetAddressToken( + const CachedNetworkParameters* cached_network_params) const override; + bool ValidateAddressToken(absl::string_view token) const override; + void OnNewTokenReceived(absl::string_view token) override; + bool ShouldSendExpectCTHeader() const override; + bool DidCertMatchSni() const override; + const ProofSource::Details* ProofSourceDetails() const override; + bool ExportKeyingMaterial(absl::string_view label, absl::string_view context, + size_t result_len, std::string* result) override; + SSL* GetSsl() const override; + bool IsCryptoFrameExpectedForEncryptionLevel( + EncryptionLevel level) const override; + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override; + + // From QuicCryptoServerStreamBase and TlsHandshaker + ssl_early_data_reason_t EarlyDataReason() const override; + bool encryption_established() const override; + bool one_rtt_keys_available() const override; + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override; + CryptoMessageParser* crypto_message_parser() override; + HandshakeState GetHandshakeState() const override; + void SetServerApplicationStateForResumption( + std::unique_ptr state) override; + size_t BufferSizeLimitForLevel(EncryptionLevel level) const override; + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override; + std::unique_ptr CreateCurrentOneRttEncrypter() override; + void SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, + absl::Span write_secret) override; + + // Called with normalized SNI hostname as |hostname|. Return value will be + // sent in an ACCEPT_CH frame in the TLS ALPS extension, unless empty. + virtual std::string GetAcceptChValueForHostname( + const std::string& hostname) const; + + // Get the ClientCertMode that is currently in effect on this handshaker. + ClientCertMode client_cert_mode() const { + return tls_connection_.ssl_config().client_cert_mode; + } + + protected: + // Override for tracing. + void InfoCallback(int type, int value) override; + + // Creates a proof source handle for selecting cert and computing signature. + virtual std::unique_ptr MaybeCreateProofSourceHandle(); + + // Hook to allow the server to override parts of the QuicConfig based on SNI + // before we generate transport parameters. + virtual void OverrideQuicConfigDefaults(QuicConfig* config); + + virtual bool ValidateHostname(const std::string& hostname) const; + + const TlsConnection* tls_connection() const override { + return &tls_connection_; + } + + virtual void ProcessAdditionalTransportParameters( + const TransportParameters& /*params*/) {} + + // Called when a potentially async operation is done and the done callback + // needs to advance the handshake. + void AdvanceHandshakeFromCallback(); + + // TlsHandshaker implementation: + void FinishHandshake() override; + void ProcessPostHandshakeMessage() override {} + QuicAsyncStatus VerifyCertChain( + const std::vector& certs, std::string* error_details, + std::unique_ptr* details, uint8_t* out_alert, + std::unique_ptr callback) override; + void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) override; + + // TlsServerConnection::Delegate implementation: + // Used to select certificates and process transport parameters. + ssl_select_cert_result_t EarlySelectCertCallback( + const SSL_CLIENT_HELLO* client_hello) override; + int TlsExtServernameCallback(int* out_alert) override; + int SelectAlpn(const uint8_t** out, uint8_t* out_len, const uint8_t* in, + unsigned in_len) override; + ssl_private_key_result_t PrivateKeySign(uint8_t* out, size_t* out_len, + size_t max_out, uint16_t sig_alg, + absl::string_view in) override; + ssl_private_key_result_t PrivateKeyComplete(uint8_t* out, size_t* out_len, + size_t max_out) override; + size_t SessionTicketMaxOverhead() override; + int SessionTicketSeal(uint8_t* out, size_t* out_len, size_t max_out_len, + absl::string_view in) override; + ssl_ticket_aead_result_t SessionTicketOpen(uint8_t* out, size_t* out_len, + size_t max_out_len, + absl::string_view in) override; + // Called when ticket_decryption_callback_ is done to determine a final + // decryption result. + ssl_ticket_aead_result_t FinalizeSessionTicketOpen(uint8_t* out, + size_t* out_len, + size_t max_out_len); + TlsConnection::Delegate* ConnectionDelegate() override { return this; } + + // The status of cert selection. nullopt means it hasn't started. + const absl::optional& select_cert_status() const { + return select_cert_status_; + } + // Whether |cert_verify_sig_| contains a valid signature. + // NOTE: BoringSSL queries the result of a async signature operation using + // PrivateKeyComplete(), a successful PrivateKeyComplete() will clear the + // content of |cert_verify_sig_|, this function should not be called after + // that. + bool HasValidSignature(size_t max_signature_size) const; + + // ProofSourceHandleCallback implementation: + void OnSelectCertificateDone( + bool ok, bool is_sync, const ProofSource::Chain* chain, + absl::string_view handshake_hints, + absl::string_view ticket_encryption_key, bool cert_matched_sni, + QuicDelayedSSLConfig delayed_ssl_config) override; + + void OnComputeSignatureDone( + bool ok, bool is_sync, std::string signature, + std::unique_ptr details) override; + + void set_encryption_established(bool encryption_established) { + encryption_established_ = encryption_established; + } + + bool WillNotCallComputeSignature() const override; + + void SetIgnoreTicketOpen(bool value) { ignore_ticket_open_ = value; } + + private: + class QUIC_EXPORT_PRIVATE DecryptCallback + : public ProofSource::DecryptCallback { + public: + explicit DecryptCallback(TlsServerHandshaker* handshaker); + void Run(std::vector plaintext) override; + + // If called, Cancel causes the pending callback to be a no-op. + void Cancel(); + + // Return true if either + // - Cancel() has been called. + // - Run() has been called, or is in the middle of it. + bool IsDone() const { return handshaker_ == nullptr; } + + private: + TlsServerHandshaker* handshaker_; + }; + + // DefaultProofSourceHandle delegates all operations to the shared proof + // source. + class QUIC_EXPORT_PRIVATE DefaultProofSourceHandle + : public ProofSourceHandle { + public: + DefaultProofSourceHandle(TlsServerHandshaker* handshaker, + ProofSource* proof_source); + + ~DefaultProofSourceHandle() override; + + // Close the handle. Cancel the pending signature operation, if any. + void CloseHandle() override; + + // Delegates to proof_source_->GetCertChain. + // Returns QUIC_SUCCESS or QUIC_FAILURE. Never returns QUIC_PENDING. + QuicAsyncStatus SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const QuicConnectionId& original_connection_id, + absl::string_view ssl_capabilities, const std::string& hostname, + absl::string_view client_hello, const std::string& alpn, + absl::optional alps, + const std::vector& quic_transport_params, + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) override; + + // Delegates to proof_source_->ComputeTlsSignature. + // Returns QUIC_SUCCESS, QUIC_FAILURE or QUIC_PENDING. + QuicAsyncStatus ComputeSignature(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) override; + + protected: + ProofSourceHandleCallback* callback() override { return handshaker_; } + + private: + class QUIC_EXPORT_PRIVATE DefaultSignatureCallback + : public ProofSource::SignatureCallback { + public: + explicit DefaultSignatureCallback(DefaultProofSourceHandle* handle) + : handle_(handle) {} + + void Run(bool ok, std::string signature, + std::unique_ptr details) override { + if (handle_ == nullptr) { + // Operation has been canceled, or Run has been called. + return; + } + + DefaultProofSourceHandle* handle = handle_; + handle_ = nullptr; + + handle->signature_callback_ = nullptr; + if (handle->handshaker_ != nullptr) { + handle->handshaker_->OnComputeSignatureDone( + ok, is_sync_, std::move(signature), std::move(details)); + } + } + + // If called, Cancel causes the pending callback to be a no-op. + void Cancel() { handle_ = nullptr; } + + void set_is_sync(bool is_sync) { is_sync_ = is_sync; } + + private: + DefaultProofSourceHandle* handle_; + // Set to false if handle_->ComputeSignature returns QUIC_PENDING. + bool is_sync_ = true; + }; + + // Not nullptr on construction. Set to nullptr when cancelled. + TlsServerHandshaker* handshaker_; // Not owned. + ProofSource* proof_source_; // Not owned. + DefaultSignatureCallback* signature_callback_ = nullptr; + }; + + struct QUIC_NO_EXPORT SetTransportParametersResult { + bool success = false; + // Empty vector if QUIC transport params are not set successfully. + std::vector quic_transport_params; + // absl::nullopt if there is no application state to begin with. + // Empty vector if application state is not set successfully. + absl::optional> early_data_context; + }; + + SetTransportParametersResult SetTransportParameters(); + bool ProcessTransportParameters(const SSL_CLIENT_HELLO* client_hello, + std::string* error_details); + // Compares |serialized_params| with |server_params_|. + // Returns true if handshaker serialization is equivalent. + bool TransportParametersMatch( + absl::Span serialized_params) const; + + struct QUIC_NO_EXPORT SetApplicationSettingsResult { + bool success = false; + // TODO(b/239676439): Change type to absl::optional and make + // sure SetApplicationSettings() returns nullopt if no ALPS data. + std::string alps_buffer; + }; + SetApplicationSettingsResult SetApplicationSettings(absl::string_view alpn); + + QuicConnectionStats& connection_stats() { + return session()->connection()->mutable_stats(); + } + QuicTime now() const { return session()->GetClock()->Now(); } + + QuicConnectionContext* connection_context() { + return session()->connection()->context(); + } + + std::unique_ptr proof_source_handle_; + ProofSource* proof_source_; + + // State to handle potentially asynchronous session ticket decryption. + // |ticket_decryption_callback_| points to the non-owned callback that was + // passed to ProofSource::TicketCrypter::Decrypt but hasn't finished running + // yet. + std::shared_ptr ticket_decryption_callback_; + // |decrypted_session_ticket_| contains the decrypted session ticket after the + // callback has run but before it is passed to BoringSSL. + std::vector decrypted_session_ticket_; + // |ticket_received_| tracks whether we received a resumption ticket from the + // client. It does not matter whether we were able to decrypt said ticket or + // if we actually resumed a session with it - the presence of this ticket + // indicates that the client attempted a resumption. + bool ticket_received_ = false; + + // True if the "early_data" extension is in the client hello. + bool early_data_attempted_ = false; + + // Force SessionTicketOpen to return ssl_ticket_aead_ignore_ticket if called. + bool ignore_ticket_open_ = false; + + // nullopt means select cert hasn't started. + absl::optional select_cert_status_; + + std::string cert_verify_sig_; + std::unique_ptr proof_source_details_; + + // Count the duration of the current async operation, if any. + absl::optional async_op_timer_; + + std::unique_ptr application_state_; + + // Pre-shared key used during the handshake. + std::string pre_shared_key_; + + // (optional) Key to use for encrypting TLS resumption tickets. + std::string ticket_encryption_key_; + + HandshakeState state_ = HANDSHAKE_START; + bool encryption_established_ = false; + bool valid_alpn_received_ = false; + bool can_disable_resumption_ = true; + quiche::QuicheReferenceCountedPointer + crypto_negotiated_params_; + TlsServerConnection tls_connection_; + const QuicCryptoServerConfig* crypto_config_; // Unowned. + // The last received CachedNetworkParameters from a validated address token. + mutable std::unique_ptr + last_received_cached_network_params_; + + bool cert_matched_sni_ = false; + TransportParameters server_params_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_TLS_SERVER_HANDSHAKER_H_ diff --git a/quiche/quic/core/tls_server_handshaker_test.cc b/quiche/quic/core/tls_server_handshaker_test.cc new file mode 100644 index 000000000000..e68ea65edab5 --- /dev/null +++ b/quiche/quic/core/tls_server_handshaker_test.cc @@ -0,0 +1,1168 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/tls_server_handshaker.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/certificate_util.h" +#include "quiche/quic/core/crypto/client_proof_source.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/tls_client_handshaker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/failing_proof_source.h" +#include "quiche/quic/test_tools/fake_proof_source.h" +#include "quiche/quic/test_tools/fake_proof_source_handle.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_session_cache.h" +#include "quiche/quic/test_tools/test_certificates.h" +#include "quiche/quic/test_tools/test_ticket_crypter.h" + +namespace quic { +class QuicConnection; +class QuicStream; +} // namespace quic + +using testing::_; +using testing::NiceMock; +using testing::Return; + +namespace quic { +namespace test { + +namespace { + +const char kServerHostname[] = "test.example.com"; +const uint16_t kServerPort = 443; + +struct TestParams { + ParsedQuicVersion version; + bool disable_resumption; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + ParsedQuicVersionToString(p.version), "_", + (p.disable_resumption ? "ResumptionDisabled" : "ResumptionEnabled")); +} + +// Constructs test permutations. +std::vector GetTestParams() { + std::vector params; + for (const auto& version : AllSupportedVersionsWithTls()) { + for (bool disable_resumption : {false, true}) { + params.push_back(TestParams{version, disable_resumption}); + } + } + return params; +} + +class TestTlsServerHandshaker : public TlsServerHandshaker { + public: + TestTlsServerHandshaker(QuicSession* session, + const QuicCryptoServerConfig* crypto_config) + : TlsServerHandshaker(session, crypto_config), + proof_source_(crypto_config->proof_source()) { + ON_CALL(*this, MaybeCreateProofSourceHandle()) + .WillByDefault(testing::Invoke( + this, &TestTlsServerHandshaker::RealMaybeCreateProofSourceHandle)); + + ON_CALL(*this, OverrideQuicConfigDefaults(_)) + .WillByDefault(testing::Invoke( + this, &TestTlsServerHandshaker::RealOverrideQuicConfigDefaults)); + } + + MOCK_METHOD(std::unique_ptr, MaybeCreateProofSourceHandle, + (), (override)); + + MOCK_METHOD(void, OverrideQuicConfigDefaults, (QuicConfig * config), + (override)); + + void SetupProofSourceHandle( + FakeProofSourceHandle::Action select_cert_action, + FakeProofSourceHandle::Action compute_signature_action, + QuicDelayedSSLConfig dealyed_ssl_config = QuicDelayedSSLConfig()) { + EXPECT_CALL(*this, MaybeCreateProofSourceHandle()) + .WillOnce( + testing::Invoke([this, select_cert_action, compute_signature_action, + dealyed_ssl_config]() { + auto handle = std::make_unique( + proof_source_, this, select_cert_action, + compute_signature_action, dealyed_ssl_config); + fake_proof_source_handle_ = handle.get(); + return handle; + })); + } + + FakeProofSourceHandle* fake_proof_source_handle() { + return fake_proof_source_handle_; + } + + bool received_client_cert() const { return received_client_cert_; } + + using TlsServerHandshaker::AdvanceHandshake; + using TlsServerHandshaker::expected_ssl_error; + + protected: + QuicAsyncStatus VerifyCertChain( + const std::vector& certs, std::string* error_details, + std::unique_ptr* details, uint8_t* out_alert, + std::unique_ptr callback) override { + received_client_cert_ = true; + return TlsServerHandshaker::VerifyCertChain(certs, error_details, details, + out_alert, std::move(callback)); + } + + private: + std::unique_ptr RealMaybeCreateProofSourceHandle() { + return TlsServerHandshaker::MaybeCreateProofSourceHandle(); + } + + void RealOverrideQuicConfigDefaults(QuicConfig* config) { + return TlsServerHandshaker::OverrideQuicConfigDefaults(config); + } + + // Owned by TlsServerHandshaker. + FakeProofSourceHandle* fake_proof_source_handle_ = nullptr; + ProofSource* proof_source_ = nullptr; + bool received_client_cert_ = false; +}; + +class TlsServerHandshakerTestSession : public TestQuicSpdyServerSession { + public: + using TestQuicSpdyServerSession::TestQuicSpdyServerSession; + + std::unique_ptr CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* /*compressed_certs_cache*/) override { + if (connection()->version().handshake_protocol == PROTOCOL_TLS1_3) { + return std::make_unique>(this, + crypto_config); + } + + QUICHE_CHECK(false) << "Unsupported handshake protocol: " + << connection()->version().handshake_protocol; + return nullptr; + } +}; + +class TlsServerHandshakerTest : public QuicTestWithParam { + public: + TlsServerHandshakerTest() + : server_compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), + server_id_(kServerHostname, kServerPort, false), + supported_versions_({GetParam().version}) { + SetQuicFlag(quic_disable_server_tls_resumption, + GetParam().disable_resumption); + client_crypto_config_ = std::make_unique( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique()); + InitializeServerConfig(); + InitializeServer(); + InitializeFakeClient(); + } + + ~TlsServerHandshakerTest() override { + // Ensure that anything that might reference |helpers_| is destroyed before + // |helpers_| is destroyed. + server_session_.reset(); + client_session_.reset(); + helpers_.clear(); + alarm_factories_.clear(); + } + + void InitializeServerConfig() { + auto ticket_crypter = std::make_unique(); + ticket_crypter_ = ticket_crypter.get(); + auto proof_source = std::make_unique(); + proof_source_ = proof_source.get(); + proof_source_->SetTicketCrypter(std::move(ticket_crypter)); + server_crypto_config_ = std::make_unique( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + std::move(proof_source), KeyExchangeSource::Default()); + } + + void InitializeServerConfigWithFailingProofSource() { + server_crypto_config_ = std::make_unique( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + std::make_unique(), KeyExchangeSource::Default()); + } + + void CreateTlsServerHandshakerTestSession(MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory) { + server_connection_ = new PacketSavingConnection( + helper, alarm_factory, Perspective::IS_SERVER, + ParsedVersionOfIndex(supported_versions_, 0)); + + TlsServerHandshakerTestSession* server_session = + new TlsServerHandshakerTestSession( + server_connection_, DefaultQuicConfig(), supported_versions_, + server_crypto_config_.get(), &server_compressed_certs_cache_); + server_session->set_client_cert_mode(initial_client_cert_mode_); + server_session->Initialize(); + + // We advance the clock initially because the default time is zero and the + // strike register worries that we've just overflowed a uint32_t time. + server_connection_->AdvanceTime(QuicTime::Delta::FromSeconds(100000)); + + QUICHE_CHECK(server_session); + server_session_.reset(server_session); + } + + void InitializeServerWithFakeProofSourceHandle() { + helpers_.push_back(std::make_unique>()); + alarm_factories_.push_back(std::make_unique()); + CreateTlsServerHandshakerTestSession(helpers_.back().get(), + alarm_factories_.back().get()); + server_handshaker_ = static_cast*>( + server_session_->GetMutableCryptoStream()); + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillRepeatedly([this](const std::vector& alpns) { + return std::find( + alpns.cbegin(), alpns.cend(), + AlpnForVersion(server_session_->connection()->version())); + }); + crypto_test_utils::SetupCryptoServerConfigForTest( + server_connection_->clock(), server_connection_->random_generator(), + server_crypto_config_.get()); + } + + // Initializes the crypto server stream state for testing. May be + // called multiple times. + void InitializeServer() { + TestQuicSpdyServerSession* server_session = nullptr; + helpers_.push_back(std::make_unique>()); + alarm_factories_.push_back(std::make_unique()); + CreateServerSessionForTest( + server_id_, QuicTime::Delta::FromSeconds(100000), supported_versions_, + helpers_.back().get(), alarm_factories_.back().get(), + server_crypto_config_.get(), &server_compressed_certs_cache_, + &server_connection_, &server_session); + QUICHE_CHECK(server_session); + server_session_.reset(server_session); + server_handshaker_ = nullptr; + EXPECT_CALL(*server_session_->helper(), CanAcceptClientHello(_, _, _, _, _)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillRepeatedly([this](const std::vector& alpns) { + return std::find( + alpns.cbegin(), alpns.cend(), + AlpnForVersion(server_session_->connection()->version())); + }); + crypto_test_utils::SetupCryptoServerConfigForTest( + server_connection_->clock(), server_connection_->random_generator(), + server_crypto_config_.get()); + } + + QuicCryptoServerStreamBase* server_stream() { + return server_session_->GetMutableCryptoStream(); + } + + QuicCryptoClientStream* client_stream() { + return client_session_->GetMutableCryptoStream(); + } + + // Initializes a fake client, and all its associated state, for + // testing. May be called multiple times. + void InitializeFakeClient() { + TestQuicSpdyClientSession* client_session = nullptr; + helpers_.push_back(std::make_unique>()); + alarm_factories_.push_back(std::make_unique()); + CreateClientSessionForTest( + server_id_, QuicTime::Delta::FromSeconds(100000), supported_versions_, + helpers_.back().get(), alarm_factories_.back().get(), + client_crypto_config_.get(), &client_connection_, &client_session); + const std::string default_alpn = + AlpnForVersion(client_connection_->version()); + ON_CALL(*client_session, GetAlpnsToOffer()) + .WillByDefault(Return(std::vector({default_alpn}))); + QUICHE_CHECK(client_session); + client_session_.reset(client_session); + moved_messages_counts_ = {0, 0}; + } + + void CompleteCryptoHandshake() { + while (!client_stream()->one_rtt_keys_available() || + !server_stream()->one_rtt_keys_available()) { + auto previous_moved_messages_counts = moved_messages_counts_; + AdvanceHandshakeWithFakeClient(); + // Check that the handshake has made forward progress + ASSERT_NE(previous_moved_messages_counts, moved_messages_counts_); + } + } + + // Performs a single round of handshake message-exchange between the + // client and server. + void AdvanceHandshakeWithFakeClient() { + QUICHE_CHECK(server_connection_); + QUICHE_CHECK(client_session_ != nullptr); + + EXPECT_CALL(*client_session_, OnProofValid(_)).Times(testing::AnyNumber()); + EXPECT_CALL(*client_session_, OnProofVerifyDetailsAvailable(_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*client_connection_, OnCanWrite()).Times(testing::AnyNumber()); + EXPECT_CALL(*server_connection_, OnCanWrite()).Times(testing::AnyNumber()); + // Call CryptoConnect if we haven't moved any client messages yet. + if (moved_messages_counts_.first == 0) { + client_stream()->CryptoConnect(); + } + moved_messages_counts_ = crypto_test_utils::AdvanceHandshake( + client_connection_, client_stream(), moved_messages_counts_.first, + server_connection_, server_stream(), moved_messages_counts_.second); + } + + void ExpectHandshakeSuccessful() { + EXPECT_TRUE(client_stream()->one_rtt_keys_available()); + EXPECT_TRUE(client_stream()->encryption_established()); + EXPECT_TRUE(server_stream()->one_rtt_keys_available()); + EXPECT_TRUE(server_stream()->encryption_established()); + EXPECT_EQ(HANDSHAKE_COMPLETE, client_stream()->GetHandshakeState()); + EXPECT_EQ(HANDSHAKE_CONFIRMED, server_stream()->GetHandshakeState()); + + const auto& client_crypto_params = + client_stream()->crypto_negotiated_params(); + const auto& server_crypto_params = + server_stream()->crypto_negotiated_params(); + // The TLS params should be filled in on the client. + EXPECT_NE(0, client_crypto_params.cipher_suite); + EXPECT_NE(0, client_crypto_params.key_exchange_group); + EXPECT_NE(0, client_crypto_params.peer_signature_algorithm); + + // The cipher suite and key exchange group should match on the client and + // server. + EXPECT_EQ(client_crypto_params.cipher_suite, + server_crypto_params.cipher_suite); + EXPECT_EQ(client_crypto_params.key_exchange_group, + server_crypto_params.key_exchange_group); + // We don't support client certs on the server (yet), so the server + // shouldn't have a peer signature algorithm to report. + EXPECT_EQ(0, server_crypto_params.peer_signature_algorithm); + } + + // Should only be called when using FakeProofSourceHandle. + FakeProofSourceHandle::SelectCertArgs last_select_cert_args() const { + QUICHE_CHECK(server_handshaker_ && + server_handshaker_->fake_proof_source_handle()); + QUICHE_CHECK(!server_handshaker_->fake_proof_source_handle() + ->all_select_cert_args() + .empty()); + return server_handshaker_->fake_proof_source_handle() + ->all_select_cert_args() + .back(); + } + + // Should only be called when using FakeProofSourceHandle. + FakeProofSourceHandle::ComputeSignatureArgs last_compute_signature_args() + const { + QUICHE_CHECK(server_handshaker_ && + server_handshaker_->fake_proof_source_handle()); + QUICHE_CHECK(!server_handshaker_->fake_proof_source_handle() + ->all_compute_signature_args() + .empty()); + return server_handshaker_->fake_proof_source_handle() + ->all_compute_signature_args() + .back(); + } + + protected: + // Setup the client to send a (self-signed) client cert to the server, if + // requested. InitializeFakeClient() must be called after this to take effect. + bool SetupClientCert() { + auto client_proof_source = std::make_unique(); + + CertificatePrivateKey client_cert_key( + MakeKeyPairForSelfSignedCertificate()); + + CertificateOptions options; + options.subject = "CN=subject"; + options.serial_number = 0x12345678; + options.validity_start = {2020, 1, 1, 0, 0, 0}; + options.validity_end = {2049, 12, 31, 0, 0, 0}; + std::string der_cert = + CreateSelfSignedCertificate(*client_cert_key.private_key(), options); + + quiche::QuicheReferenceCountedPointer + client_cert_chain(new ClientProofSource::Chain({der_cert})); + + if (!client_proof_source->AddCertAndKey({"*"}, client_cert_chain, + std::move(client_cert_key))) { + return false; + } + + client_crypto_config_->set_proof_source(std::move(client_proof_source)); + return true; + } + + // Every connection gets its own MockQuicConnectionHelper and + // MockAlarmFactory, tracked separately from the server and client state so + // their lifetimes persist through the whole test. + std::vector> helpers_; + std::vector> alarm_factories_; + + // Server state. + PacketSavingConnection* server_connection_; + std::unique_ptr server_session_; + // Only set when initialized with InitializeServerWithFakeProofSourceHandle. + NiceMock* server_handshaker_ = nullptr; + TestTicketCrypter* ticket_crypter_; // owned by proof_source_ + FakeProofSource* proof_source_; // owned by server_crypto_config_ + std::unique_ptr server_crypto_config_; + QuicCompressedCertsCache server_compressed_certs_cache_; + QuicServerId server_id_; + ClientCertMode initial_client_cert_mode_ = ClientCertMode::kNone; + + // Client state. + PacketSavingConnection* client_connection_; + std::unique_ptr client_crypto_config_; + std::unique_ptr client_session_; + + crypto_test_utils::FakeClientOptions client_options_; + // How many handshake messages have been moved from client to server and + // server to client. + std::pair moved_messages_counts_ = {0, 0}; + + // Which QUIC versions the client and server support. + ParsedQuicVersionVector supported_versions_; +}; + +INSTANTIATE_TEST_SUITE_P(TlsServerHandshakerTests, TlsServerHandshakerTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(TlsServerHandshakerTest, NotInitiallyConected) { + EXPECT_FALSE(server_stream()->encryption_established()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); +} + +TEST_P(TlsServerHandshakerTest, ConnectedAfterTlsHandshake) { + CompleteCryptoHandshake(); + EXPECT_EQ(PROTOCOL_TLS1_3, server_stream()->handshake_protocol()); + ExpectHandshakeSuccessful(); +} + +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSelectCertSuccess) { + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + + ExpectHandshakeSuccessful(); +} + +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSelectCertFailure) { + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::FAIL_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + // Check that the server didn't send any handshake messages, because it failed + // to handshake. + EXPECT_EQ(moved_messages_counts_.second, 0u); +} + +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSelectCertAndSignature) { + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_ASYNC); + + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + // A select cert operation is now pending. + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + EXPECT_EQ(server_handshaker_->expected_ssl_error(), + SSL_ERROR_PENDING_CERTIFICATE); + + // Complete the pending select cert. It should advance the handshake to + // compute a signature, which will also be saved as a pending operation. + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + // A compute signature operation is now pending. + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + EXPECT_EQ(server_handshaker_->expected_ssl_error(), + SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); + + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + + ExpectHandshakeSuccessful(); +} + +TEST_P(TlsServerHandshakerTest, HandshakeWithAsyncSignature) { + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + // Enable FakeProofSource to capture call to ComputeTlsSignature and run it + // asynchronously. + proof_source_->Activate(); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_EQ(proof_source_->NumPendingCallbacks(), 1); + proof_source_->InvokePendingCallback(0); + + CompleteCryptoHandshake(); + + ExpectHandshakeSuccessful(); +} + +TEST_P(TlsServerHandshakerTest, CancelPendingSelectCert) { + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->CancelOutstandingCallbacks(); + ASSERT_FALSE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + // CompletePendingOperation should be noop. + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); +} + +TEST_P(TlsServerHandshakerTest, CancelPendingSignature) { + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + // Enable FakeProofSource to capture call to ComputeTlsSignature and run it + // asynchronously. + proof_source_->Activate(); + + // Start handshake. + AdvanceHandshakeWithFakeClient(); + + ASSERT_EQ(proof_source_->NumPendingCallbacks(), 1); + server_session_ = nullptr; + + proof_source_->InvokePendingCallback(0); +} + +TEST_P(TlsServerHandshakerTest, ExtractSNI) { + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + EXPECT_EQ(server_stream()->crypto_negotiated_params().sni, + "test.example.com"); +} + +TEST_P(TlsServerHandshakerTest, ServerConnectionIdPassedToSelectCert) { + InitializeServerWithFakeProofSourceHandle(); + + // Disable early data. + server_session_->set_early_data_enabled(false); + + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + EXPECT_EQ(last_select_cert_args().original_connection_id, TestConnectionId()); +} + +TEST_P(TlsServerHandshakerTest, HostnameForCertSelectionAndComputeSignature) { + // Client uses upper case letters in hostname. It is considered valid by + // QuicHostnameUtils::IsValidSNI, but it should be normalized for cert + // selection. + server_id_ = QuicServerId("tEsT.EXAMPLE.CoM", kServerPort, false); + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + EXPECT_EQ(server_stream()->crypto_negotiated_params().sni, + "test.example.com"); + + EXPECT_EQ(last_select_cert_args().hostname, "test.example.com"); + EXPECT_EQ(last_compute_signature_args().hostname, "test.example.com"); +} + +TEST_P(TlsServerHandshakerTest, SSLConfigForCertSelection) { + InitializeServerWithFakeProofSourceHandle(); + + // Disable early data. + server_session_->set_early_data_enabled(false); + + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + EXPECT_FALSE(last_select_cert_args().ssl_config.early_data_enabled); +} + +TEST_P(TlsServerHandshakerTest, ConnectionClosedOnTlsError) { + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _)); + + // Send a zero-length ClientHello from client to server. + char bogus_handshake_message[] = { + // Handshake struct (RFC 8446 appendix B.3) + 1, // HandshakeType client_hello + 0, 0, 0, // uint24 length + }; + + // Install a packet flusher such that the packets generated by + // |server_connection_| in response to this handshake message are more likely + // to be coalesced and/or batched in the writer. + // + // This is required by TlsServerHandshaker because without the flusher, it + // tends to generate many small, uncoalesced packets, one per + // TlsHandshaker::WriteMessage. + QuicConnection::ScopedPacketFlusher flusher(server_connection_); + server_stream()->crypto_message_parser()->ProcessInput( + absl::string_view(bogus_handshake_message, + ABSL_ARRAYSIZE(bogus_handshake_message)), + ENCRYPTION_INITIAL); + + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); +} + +TEST_P(TlsServerHandshakerTest, ClientSendingBadALPN) { + const std::string kTestBadClientAlpn = "bad-client-alpn"; + EXPECT_CALL(*client_session_, GetAlpnsToOffer()) + .WillOnce(Return(std::vector({kTestBadClientAlpn}))); + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, + static_cast( + CRYPTO_ERROR_FIRST + 120), + "TLS handshake failure (ENCRYPTION_INITIAL) 120: " + "no application protocol", + _)); + + AdvanceHandshakeWithFakeClient(); + + EXPECT_FALSE(client_stream()->one_rtt_keys_available()); + EXPECT_FALSE(client_stream()->encryption_established()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); + EXPECT_FALSE(server_stream()->encryption_established()); +} + +TEST_P(TlsServerHandshakerTest, CustomALPNNegotiation) { + EXPECT_CALL(*client_connection_, CloseConnection(_, _, _)).Times(0); + EXPECT_CALL(*server_connection_, CloseConnection(_, _, _)).Times(0); + + const std::string kTestAlpn = "A Custom ALPN Value"; + const std::vector kTestAlpns( + {"foo", "bar", kTestAlpn, "something else"}); + EXPECT_CALL(*client_session_, GetAlpnsToOffer()) + .WillRepeatedly(Return(kTestAlpns)); + EXPECT_CALL(*server_session_, SelectAlpn(_)) + .WillOnce( + [kTestAlpn, kTestAlpns](const std::vector& alpns) { + EXPECT_THAT(alpns, testing::ElementsAreArray(kTestAlpns)); + return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn); + }); + EXPECT_CALL(*client_session_, OnAlpnSelected(absl::string_view(kTestAlpn))); + EXPECT_CALL(*server_session_, OnAlpnSelected(absl::string_view(kTestAlpn))); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); +} + +TEST_P(TlsServerHandshakerTest, RejectInvalidSNI) { + server_id_ = QuicServerId("invalid!.example.com", kServerPort, false); + InitializeFakeClient(); + static_cast( + QuicCryptoClientStreamPeer::GetHandshaker(client_stream())) + ->AllowInvalidSNIForTests(); + + // Run the handshake and expect it to fail. + AdvanceHandshakeWithFakeClient(); + EXPECT_FALSE(server_stream()->encryption_established()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); +} + +TEST_P(TlsServerHandshakerTest, Resumption) { + // Do the first handshake + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->ResumptionAttempted()); + + // Now do another handshake + InitializeServer(); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_NE(client_stream()->IsResumption(), GetParam().disable_resumption); + EXPECT_NE(server_stream()->IsResumption(), GetParam().disable_resumption); + EXPECT_NE(server_stream()->ResumptionAttempted(), + GetParam().disable_resumption); +} + +TEST_P(TlsServerHandshakerTest, ResumptionWithAsyncDecryptCallback) { + // Do the first handshake + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + ticket_crypter_->SetRunCallbacksAsync(true); + // Now do another handshake + InitializeServer(); + InitializeFakeClient(); + + AdvanceHandshakeWithFakeClient(); + if (GetParam().disable_resumption) { + ASSERT_EQ(ticket_crypter_->NumPendingCallbacks(), 0u); + return; + } + // Test that the DecryptCallback will be run asynchronously, and then run it. + ASSERT_EQ(ticket_crypter_->NumPendingCallbacks(), 1u); + ticket_crypter_->RunPendingCallback(0); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_TRUE(client_stream()->IsResumption()); + EXPECT_TRUE(server_stream()->IsResumption()); + EXPECT_TRUE(server_stream()->ResumptionAttempted()); +} + +TEST_P(TlsServerHandshakerTest, ResumptionWithPlaceholderTicket) { + // Do the first handshake + InitializeFakeClient(); + + ticket_crypter_->set_fail_encrypt(true); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->ResumptionAttempted()); + + // Now do another handshake. It should end up with a full handshake because + // the placeholder ticket is undecryptable. + InitializeServer(); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsResumption()); + EXPECT_NE(server_stream()->ResumptionAttempted(), + GetParam().disable_resumption); +} + +TEST_P(TlsServerHandshakerTest, AdvanceHandshakeDuringAsyncDecryptCallback) { + if (GetParam().disable_resumption) { + return; + } + + // Do the first handshake + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + ticket_crypter_->SetRunCallbacksAsync(true); + // Now do another handshake + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + InitializeFakeClient(); + + AdvanceHandshakeWithFakeClient(); + + // Ensure an async DecryptCallback is now pending. + ASSERT_EQ(ticket_crypter_->NumPendingCallbacks(), 1u); + + { + QuicConnection::ScopedPacketFlusher flusher(server_connection_); + server_handshaker_->AdvanceHandshake(); + } + + // This will delete |server_handshaker_|. + server_session_ = nullptr; + + ticket_crypter_->RunPendingCallback(0); // Should not crash. +} + +TEST_P(TlsServerHandshakerTest, ResumptionWithFailingDecryptCallback) { + if (GetParam().disable_resumption) { + return; + } + + // Do the first handshake + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + ticket_crypter_->set_fail_decrypt(true); + // Now do another handshake + InitializeServer(); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsResumption()); + EXPECT_TRUE(server_stream()->ResumptionAttempted()); +} + +TEST_P(TlsServerHandshakerTest, ResumptionWithFailingAsyncDecryptCallback) { + if (GetParam().disable_resumption) { + return; + } + + // Do the first handshake + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + ticket_crypter_->set_fail_decrypt(true); + ticket_crypter_->SetRunCallbacksAsync(true); + // Now do another handshake + InitializeServer(); + InitializeFakeClient(); + + AdvanceHandshakeWithFakeClient(); + // Test that the DecryptCallback will be run asynchronously, and then run it. + ASSERT_EQ(ticket_crypter_->NumPendingCallbacks(), 1u); + ticket_crypter_->RunPendingCallback(0); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsResumption()); + EXPECT_TRUE(server_stream()->ResumptionAttempted()); +} + +TEST_P(TlsServerHandshakerTest, HandshakeFailsWithFailingProofSource) { + InitializeServerConfigWithFailingProofSource(); + InitializeServer(); + InitializeFakeClient(); + + // Attempt handshake. + AdvanceHandshakeWithFakeClient(); + // Check that the server didn't send any handshake messages, because it failed + // to handshake. + EXPECT_EQ(moved_messages_counts_.second, 0u); +} + +TEST_P(TlsServerHandshakerTest, ZeroRttResumption) { + std::vector application_state = {0, 1, 2, 3}; + + // Do the first handshake + server_stream()->SetServerApplicationStateForResumption( + std::make_unique(application_state)); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsZeroRtt()); + + // Now do another handshake + InitializeServer(); + server_stream()->SetServerApplicationStateForResumption( + std::make_unique(application_state)); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_NE(client_stream()->IsResumption(), GetParam().disable_resumption); + EXPECT_NE(server_stream()->IsZeroRtt(), GetParam().disable_resumption); +} + +TEST_P(TlsServerHandshakerTest, ZeroRttRejectOnApplicationStateChange) { + std::vector original_application_state = {1, 2}; + std::vector new_application_state = {3, 4}; + + // Do the first handshake + server_stream()->SetServerApplicationStateForResumption( + std::make_unique(original_application_state)); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(client_stream()->IsResumption()); + EXPECT_FALSE(server_stream()->IsZeroRtt()); + + // Do another handshake, but change the application state + InitializeServer(); + server_stream()->SetServerApplicationStateForResumption( + std::make_unique(new_application_state)); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_NE(client_stream()->IsResumption(), GetParam().disable_resumption); + EXPECT_FALSE(server_stream()->IsZeroRtt()); +} + +TEST_P(TlsServerHandshakerTest, RequestClientCert) { + ASSERT_TRUE(SetupClientCert()); + InitializeFakeClient(); + + initial_client_cert_mode_ = ClientCertMode::kRequest; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_TRUE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, + SetInvalidServerTransportParamsByDelayedSslConfig) { + ASSERT_TRUE(SetupClientCert()); + InitializeFakeClient(); + + QuicDelayedSSLConfig delayed_ssl_config; + delayed_ssl_config.quic_transport_parameters = {1, 2, 3}; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + delayed_ssl_config); + + AdvanceHandshakeWithFakeClient(); + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(server_handshaker_->fake_proof_source_handle() + ->all_compute_signature_args() + .empty()); +} + +TEST_P(TlsServerHandshakerTest, + SetValidServerTransportParamsByDelayedSslConfig) { + ParsedQuicVersion version = GetParam().version; + + TransportParameters server_params; + std::string error_details; + server_params.perspective = quic::Perspective::IS_SERVER; + server_params.legacy_version_information = + TransportParameters::LegacyVersionInformation(); + server_params.legacy_version_information.value().supported_versions = + quic::CreateQuicVersionLabelVector( + quic::ParsedQuicVersionVector{version}); + server_params.legacy_version_information.value().version = + quic::CreateQuicVersionLabel(version); + server_params.version_information = TransportParameters::VersionInformation(); + server_params.version_information.value().chosen_version = + quic::CreateQuicVersionLabel(version); + server_params.version_information.value().other_versions = + quic::CreateQuicVersionLabelVector( + quic::ParsedQuicVersionVector{version}); + + ASSERT_TRUE(server_params.AreValid(&error_details)) << error_details; + + std::vector server_params_bytes; + ASSERT_TRUE( + SerializeTransportParameters(server_params, &server_params_bytes)); + + ASSERT_TRUE(SetupClientCert()); + InitializeFakeClient(); + + QuicDelayedSSLConfig delayed_ssl_config; + delayed_ssl_config.quic_transport_parameters = server_params_bytes; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + delayed_ssl_config); + + AdvanceHandshakeWithFakeClient(); + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(server_handshaker_->fake_proof_source_handle() + ->all_compute_signature_args() + .empty()); +} + +TEST_P(TlsServerHandshakerTest, RequestClientCertByDelayedSslConfig) { + ASSERT_TRUE(SetupClientCert()); + InitializeFakeClient(); + + QuicDelayedSSLConfig delayed_ssl_config; + delayed_ssl_config.client_cert_mode = ClientCertMode::kRequest; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + delayed_ssl_config); + + AdvanceHandshakeWithFakeClient(); + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_TRUE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, RequestClientCert_NoCert) { + initial_client_cert_mode_ = ClientCertMode::kRequest; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, RequestAndRequireClientCert) { + ASSERT_TRUE(SetupClientCert()); + InitializeFakeClient(); + + initial_client_cert_mode_ = ClientCertMode::kRequire; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_TRUE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, RequestAndRequireClientCertByDelayedSslConfig) { + ASSERT_TRUE(SetupClientCert()); + InitializeFakeClient(); + + QuicDelayedSSLConfig delayed_ssl_config; + delayed_ssl_config.client_cert_mode = ClientCertMode::kRequire; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + delayed_ssl_config); + + AdvanceHandshakeWithFakeClient(); + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_TRUE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, RequestAndRequireClientCert_NoCert) { + initial_client_cert_mode_ = ClientCertMode::kRequire; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_TLS_CERTIFICATE_REQUIRED, _, _, _)); + + AdvanceHandshakeWithFakeClient(); + AdvanceHandshakeWithFakeClient(); + EXPECT_FALSE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, CloseConnectionBeforeSelectCert) { + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action:: + FAIL_SYNC_DO_NOT_CHECK_CLOSED, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + FAIL_SYNC_DO_NOT_CHECK_CLOSED); + + EXPECT_CALL(*server_handshaker_, OverrideQuicConfigDefaults(_)) + .WillOnce(testing::Invoke([](QuicConfig* config) { + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(config, + /*max_streams=*/0); + })); + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, _, _)) + .WillOnce(testing::Invoke( + [this](QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) { + server_connection_->ReallyCloseConnection( + error, details, connection_close_behavior); + ASSERT_FALSE(server_connection_->connected()); + })); + + AdvanceHandshakeWithFakeClient(); + + EXPECT_TRUE(server_handshaker_->fake_proof_source_handle() + ->all_select_cert_args() + .empty()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/uber_quic_stream_id_manager.cc b/quiche/quic/core/uber_quic_stream_id_manager.cc new file mode 100644 index 000000000000..2c8b1a5a650f --- /dev/null +++ b/quiche/quic/core/uber_quic_stream_id_manager.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/uber_quic_stream_id_manager.h" + +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_utils.h" + +namespace quic { + +UberQuicStreamIdManager::UberQuicStreamIdManager( + Perspective perspective, ParsedQuicVersion version, + QuicStreamIdManager::DelegateInterface* delegate, + QuicStreamCount max_open_outgoing_bidirectional_streams, + QuicStreamCount max_open_outgoing_unidirectional_streams, + QuicStreamCount max_open_incoming_bidirectional_streams, + QuicStreamCount max_open_incoming_unidirectional_streams) + : version_(version), + bidirectional_stream_id_manager_(delegate, + /*unidirectional=*/false, perspective, + version, + max_open_outgoing_bidirectional_streams, + max_open_incoming_bidirectional_streams), + unidirectional_stream_id_manager_( + delegate, + /*unidirectional=*/true, perspective, version, + max_open_outgoing_unidirectional_streams, + max_open_incoming_unidirectional_streams) {} + +bool UberQuicStreamIdManager::MaybeAllowNewOutgoingBidirectionalStreams( + QuicStreamCount max_open_streams) { + return bidirectional_stream_id_manager_.MaybeAllowNewOutgoingStreams( + max_open_streams); +} +bool UberQuicStreamIdManager::MaybeAllowNewOutgoingUnidirectionalStreams( + QuicStreamCount max_open_streams) { + return unidirectional_stream_id_manager_.MaybeAllowNewOutgoingStreams( + max_open_streams); +} +void UberQuicStreamIdManager::SetMaxOpenIncomingBidirectionalStreams( + QuicStreamCount max_open_streams) { + bidirectional_stream_id_manager_.SetMaxOpenIncomingStreams(max_open_streams); +} +void UberQuicStreamIdManager::SetMaxOpenIncomingUnidirectionalStreams( + QuicStreamCount max_open_streams) { + unidirectional_stream_id_manager_.SetMaxOpenIncomingStreams(max_open_streams); +} + +bool UberQuicStreamIdManager::CanOpenNextOutgoingBidirectionalStream() const { + return bidirectional_stream_id_manager_.CanOpenNextOutgoingStream(); +} + +bool UberQuicStreamIdManager::CanOpenNextOutgoingUnidirectionalStream() const { + return unidirectional_stream_id_manager_.CanOpenNextOutgoingStream(); +} + +QuicStreamId UberQuicStreamIdManager::GetNextOutgoingBidirectionalStreamId() { + return bidirectional_stream_id_manager_.GetNextOutgoingStreamId(); +} + +QuicStreamId UberQuicStreamIdManager::GetNextOutgoingUnidirectionalStreamId() { + return unidirectional_stream_id_manager_.GetNextOutgoingStreamId(); +} + +bool UberQuicStreamIdManager::MaybeIncreaseLargestPeerStreamId( + QuicStreamId id, std::string* error_details) { + if (QuicUtils::IsBidirectionalStreamId(id, version_)) { + return bidirectional_stream_id_manager_.MaybeIncreaseLargestPeerStreamId( + id, error_details); + } + return unidirectional_stream_id_manager_.MaybeIncreaseLargestPeerStreamId( + id, error_details); +} + +void UberQuicStreamIdManager::OnStreamClosed(QuicStreamId id) { + if (QuicUtils::IsBidirectionalStreamId(id, version_)) { + bidirectional_stream_id_manager_.OnStreamClosed(id); + return; + } + unidirectional_stream_id_manager_.OnStreamClosed(id); +} + +bool UberQuicStreamIdManager::OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& frame, std::string* error_details) { + if (frame.unidirectional) { + return unidirectional_stream_id_manager_.OnStreamsBlockedFrame( + frame, error_details); + } + return bidirectional_stream_id_manager_.OnStreamsBlockedFrame(frame, + error_details); +} + +bool UberQuicStreamIdManager::IsAvailableStream(QuicStreamId id) const { + if (QuicUtils::IsBidirectionalStreamId(id, version_)) { + return bidirectional_stream_id_manager_.IsAvailableStream(id); + } + return unidirectional_stream_id_manager_.IsAvailableStream(id); +} + +QuicStreamCount +UberQuicStreamIdManager::GetMaxAllowdIncomingBidirectionalStreams() const { + return bidirectional_stream_id_manager_.incoming_initial_max_open_streams(); +} + +QuicStreamCount +UberQuicStreamIdManager::GetMaxAllowdIncomingUnidirectionalStreams() const { + return unidirectional_stream_id_manager_.incoming_initial_max_open_streams(); +} + +QuicStreamId UberQuicStreamIdManager::GetLargestPeerCreatedStreamId( + bool unidirectional) const { + if (unidirectional) { + return unidirectional_stream_id_manager_.largest_peer_created_stream_id(); + } + return bidirectional_stream_id_manager_.largest_peer_created_stream_id(); +} + +QuicStreamId UberQuicStreamIdManager::next_outgoing_bidirectional_stream_id() + const { + return bidirectional_stream_id_manager_.next_outgoing_stream_id(); +} + +QuicStreamId UberQuicStreamIdManager::next_outgoing_unidirectional_stream_id() + const { + return unidirectional_stream_id_manager_.next_outgoing_stream_id(); +} + +QuicStreamCount UberQuicStreamIdManager::max_outgoing_bidirectional_streams() + const { + return bidirectional_stream_id_manager_.outgoing_max_streams(); +} + +QuicStreamCount UberQuicStreamIdManager::max_outgoing_unidirectional_streams() + const { + return unidirectional_stream_id_manager_.outgoing_max_streams(); +} + +QuicStreamCount UberQuicStreamIdManager::max_incoming_bidirectional_streams() + const { + return bidirectional_stream_id_manager_.incoming_actual_max_streams(); +} + +QuicStreamCount UberQuicStreamIdManager::max_incoming_unidirectional_streams() + const { + return unidirectional_stream_id_manager_.incoming_actual_max_streams(); +} + +QuicStreamCount +UberQuicStreamIdManager::advertised_max_incoming_bidirectional_streams() const { + return bidirectional_stream_id_manager_.incoming_advertised_max_streams(); +} + +QuicStreamCount +UberQuicStreamIdManager::advertised_max_incoming_unidirectional_streams() + const { + return unidirectional_stream_id_manager_.incoming_advertised_max_streams(); +} + +QuicStreamCount UberQuicStreamIdManager::outgoing_bidirectional_stream_count() + const { + return bidirectional_stream_id_manager_.outgoing_stream_count(); +} + +QuicStreamCount UberQuicStreamIdManager::outgoing_unidirectional_stream_count() + const { + return unidirectional_stream_id_manager_.outgoing_stream_count(); +} + +} // namespace quic diff --git a/quiche/quic/core/uber_quic_stream_id_manager.h b/quiche/quic/core/uber_quic_stream_id_manager.h new file mode 100644 index 000000000000..e6b73d00cccb --- /dev/null +++ b/quiche/quic/core/uber_quic_stream_id_manager.h @@ -0,0 +1,106 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_UBER_QUIC_STREAM_ID_MANAGER_H_ +#define QUICHE_QUIC_CORE_UBER_QUIC_STREAM_ID_MANAGER_H_ + +#include "quiche/quic/core/quic_stream_id_manager.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +namespace test { +class QuicSessionPeer; +class UberQuicStreamIdManagerPeer; +} // namespace test + +class QuicSession; + +// This class comprises two QuicStreamIdManagers, which manage bidirectional and +// unidirectional stream IDs, respectively. +class QUIC_EXPORT_PRIVATE UberQuicStreamIdManager { + public: + UberQuicStreamIdManager( + Perspective perspective, ParsedQuicVersion version, + QuicStreamIdManager::DelegateInterface* delegate, + QuicStreamCount max_open_outgoing_bidirectional_streams, + QuicStreamCount max_open_outgoing_unidirectional_streams, + QuicStreamCount max_open_incoming_bidirectional_streams, + QuicStreamCount max_open_incoming_unidirectional_streams); + + // Called on |max_open_streams| outgoing streams can be created because of 1) + // config negotiated or 2) MAX_STREAMS received. Returns true if new + // streams can be created. + bool MaybeAllowNewOutgoingBidirectionalStreams( + QuicStreamCount max_open_streams); + bool MaybeAllowNewOutgoingUnidirectionalStreams( + QuicStreamCount max_open_streams); + + // Sets the limits to max_open_streams. + void SetMaxOpenIncomingBidirectionalStreams(QuicStreamCount max_open_streams); + void SetMaxOpenIncomingUnidirectionalStreams( + QuicStreamCount max_open_streams); + + // Returns true if next outgoing bidirectional stream ID can be allocated. + bool CanOpenNextOutgoingBidirectionalStream() const; + + // Returns true if next outgoing unidirectional stream ID can be allocated. + bool CanOpenNextOutgoingUnidirectionalStream() const; + + // Returns the next outgoing bidirectional stream id. + QuicStreamId GetNextOutgoingBidirectionalStreamId(); + + // Returns the next outgoing unidirectional stream id. + QuicStreamId GetNextOutgoingUnidirectionalStreamId(); + + // Returns true if the incoming |id| is within the limit. + bool MaybeIncreaseLargestPeerStreamId(QuicStreamId id, + std::string* error_details); + + // Called when |id| is released. + void OnStreamClosed(QuicStreamId id); + + // Called when a STREAMS_BLOCKED frame is received. + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame, + std::string* error_details); + + // Returns true if |id| is still available. + bool IsAvailableStream(QuicStreamId id) const; + + QuicStreamCount GetMaxAllowdIncomingBidirectionalStreams() const; + + QuicStreamCount GetMaxAllowdIncomingUnidirectionalStreams() const; + + QuicStreamId GetLargestPeerCreatedStreamId(bool unidirectional) const; + + QuicStreamId next_outgoing_bidirectional_stream_id() const; + QuicStreamId next_outgoing_unidirectional_stream_id() const; + + QuicStreamCount max_outgoing_bidirectional_streams() const; + QuicStreamCount max_outgoing_unidirectional_streams() const; + + QuicStreamCount max_incoming_bidirectional_streams() const; + QuicStreamCount max_incoming_unidirectional_streams() const; + + QuicStreamCount advertised_max_incoming_bidirectional_streams() const; + QuicStreamCount advertised_max_incoming_unidirectional_streams() const; + + QuicStreamCount outgoing_bidirectional_stream_count() const; + QuicStreamCount outgoing_unidirectional_stream_count() const; + + private: + friend class test::QuicSessionPeer; + friend class test::UberQuicStreamIdManagerPeer; + + ParsedQuicVersion version_; + // Manages stream IDs of bidirectional streams. + QuicStreamIdManager bidirectional_stream_id_manager_; + + // Manages stream IDs of unidirectional streams. + QuicStreamIdManager unidirectional_stream_id_manager_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_UBER_QUIC_STREAM_ID_MANAGER_H_ diff --git a/quiche/quic/core/uber_quic_stream_id_manager_test.cc b/quiche/quic/core/uber_quic_stream_id_manager_test.cc new file mode 100644 index 000000000000..dee343c17e48 --- /dev/null +++ b/quiche/quic/core/uber_quic_stream_id_manager_test.cc @@ -0,0 +1,332 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/uber_quic_stream_id_manager.h" + +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_stream_id_manager_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using testing::_; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +struct TestParams { + explicit TestParams(ParsedQuicVersion version, Perspective perspective) + : version(version), perspective(perspective) {} + + ParsedQuicVersion version; + Perspective perspective; +}; + +// Used by ::testing::PrintToStringParamName(). +std::string PrintToString(const TestParams& p) { + return absl::StrCat( + ParsedQuicVersionToString(p.version), "_", + (p.perspective == Perspective::IS_CLIENT ? "client" : "server")); +} + +std::vector GetTestParams() { + std::vector params; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (!version.HasIetfQuicFrames()) { + continue; + } + params.push_back(TestParams(version, Perspective::IS_CLIENT)); + params.push_back(TestParams(version, Perspective::IS_SERVER)); + } + return params; +} + +class MockDelegate : public QuicStreamIdManager::DelegateInterface { + public: + MOCK_METHOD(void, SendMaxStreams, + (QuicStreamCount stream_count, bool unidirectional), (override)); +}; + +class UberQuicStreamIdManagerTest : public QuicTestWithParam { + protected: + UberQuicStreamIdManagerTest() + : manager_(perspective(), version(), &delegate_, 0, 0, + kDefaultMaxStreamsPerConnection, + kDefaultMaxStreamsPerConnection) {} + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(transport_version()) * n; + } + + QuicStreamId GetNthClientInitiatedUnidirectionalId(int n) { + return QuicUtils::GetFirstUnidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(transport_version()) * n; + } + + QuicStreamId GetNthServerInitiatedBidirectionalId(int n) { + return QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_SERVER) + + QuicUtils::StreamIdDelta(transport_version()) * n; + } + + QuicStreamId GetNthServerInitiatedUnidirectionalId(int n) { + return QuicUtils::GetFirstUnidirectionalStreamId(transport_version(), + Perspective::IS_SERVER) + + QuicUtils::StreamIdDelta(transport_version()) * n; + } + + QuicStreamId GetNthPeerInitiatedBidirectionalStreamId(int n) { + return ((perspective() == Perspective::IS_SERVER) + ? GetNthClientInitiatedBidirectionalId(n) + : GetNthServerInitiatedBidirectionalId(n)); + } + QuicStreamId GetNthPeerInitiatedUnidirectionalStreamId(int n) { + return ((perspective() == Perspective::IS_SERVER) + ? GetNthClientInitiatedUnidirectionalId(n) + : GetNthServerInitiatedUnidirectionalId(n)); + } + QuicStreamId GetNthSelfInitiatedBidirectionalStreamId(int n) { + return ((perspective() == Perspective::IS_CLIENT) + ? GetNthClientInitiatedBidirectionalId(n) + : GetNthServerInitiatedBidirectionalId(n)); + } + QuicStreamId GetNthSelfInitiatedUnidirectionalStreamId(int n) { + return ((perspective() == Perspective::IS_CLIENT) + ? GetNthClientInitiatedUnidirectionalId(n) + : GetNthServerInitiatedUnidirectionalId(n)); + } + + QuicStreamId StreamCountToId(QuicStreamCount stream_count, + Perspective perspective, bool bidirectional) { + return ((bidirectional) ? QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), perspective) + : QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), perspective)) + + ((stream_count - 1) * QuicUtils::StreamIdDelta(transport_version())); + } + + ParsedQuicVersion version() { return GetParam().version; } + QuicTransportVersion transport_version() { + return version().transport_version; + } + + Perspective perspective() { return GetParam().perspective; } + + testing::StrictMock delegate_; + UberQuicStreamIdManager manager_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, UberQuicStreamIdManagerTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(UberQuicStreamIdManagerTest, Initialization) { + EXPECT_EQ(GetNthSelfInitiatedBidirectionalStreamId(0), + manager_.next_outgoing_bidirectional_stream_id()); + EXPECT_EQ(GetNthSelfInitiatedUnidirectionalStreamId(0), + manager_.next_outgoing_unidirectional_stream_id()); +} + +TEST_P(UberQuicStreamIdManagerTest, SetMaxOpenOutgoingStreams) { + const size_t kNumMaxOutgoingStream = 123; + // Set the uni- and bi- directional limits to different values to ensure + // that they are managed separately. + EXPECT_TRUE(manager_.MaybeAllowNewOutgoingBidirectionalStreams( + kNumMaxOutgoingStream)); + EXPECT_TRUE(manager_.MaybeAllowNewOutgoingUnidirectionalStreams( + kNumMaxOutgoingStream + 1)); + EXPECT_EQ(kNumMaxOutgoingStream, + manager_.max_outgoing_bidirectional_streams()); + EXPECT_EQ(kNumMaxOutgoingStream + 1, + manager_.max_outgoing_unidirectional_streams()); + // Check that, for each directionality, we can open the correct number of + // streams. + int i = kNumMaxOutgoingStream; + while (i) { + EXPECT_TRUE(manager_.CanOpenNextOutgoingBidirectionalStream()); + manager_.GetNextOutgoingBidirectionalStreamId(); + EXPECT_TRUE(manager_.CanOpenNextOutgoingUnidirectionalStream()); + manager_.GetNextOutgoingUnidirectionalStreamId(); + i--; + } + // One more unidirectional + EXPECT_TRUE(manager_.CanOpenNextOutgoingUnidirectionalStream()); + manager_.GetNextOutgoingUnidirectionalStreamId(); + + // Both should be exhausted... + EXPECT_FALSE(manager_.CanOpenNextOutgoingUnidirectionalStream()); + EXPECT_FALSE(manager_.CanOpenNextOutgoingBidirectionalStream()); +} + +TEST_P(UberQuicStreamIdManagerTest, SetMaxOpenIncomingStreams) { + const size_t kNumMaxIncomingStreams = 456; + manager_.SetMaxOpenIncomingUnidirectionalStreams(kNumMaxIncomingStreams); + // Do +1 for bidirectional to ensure that uni- and bi- get properly set. + manager_.SetMaxOpenIncomingBidirectionalStreams(kNumMaxIncomingStreams + 1); + EXPECT_EQ(kNumMaxIncomingStreams + 1, + manager_.GetMaxAllowdIncomingBidirectionalStreams()); + EXPECT_EQ(kNumMaxIncomingStreams, + manager_.GetMaxAllowdIncomingUnidirectionalStreams()); + EXPECT_EQ(manager_.max_incoming_bidirectional_streams(), + manager_.advertised_max_incoming_bidirectional_streams()); + EXPECT_EQ(manager_.max_incoming_unidirectional_streams(), + manager_.advertised_max_incoming_unidirectional_streams()); + // Make sure that we can create kNumMaxIncomingStreams incoming unidirectional + // streams and kNumMaxIncomingStreams+1 incoming bidirectional streams. + size_t i; + for (i = 0; i < kNumMaxIncomingStreams; i++) { + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedUnidirectionalStreamId(i), nullptr)); + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedBidirectionalStreamId(i), nullptr)); + } + // Should be able to open the next bidirectional stream + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedBidirectionalStreamId(i), nullptr)); + + // We should have exhausted the counts, the next streams should fail + std::string error_details; + EXPECT_FALSE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedUnidirectionalStreamId(i), &error_details)); + EXPECT_EQ(error_details, + absl::StrCat( + "Stream id ", GetNthPeerInitiatedUnidirectionalStreamId(i), + " would exceed stream count limit ", kNumMaxIncomingStreams)); + EXPECT_FALSE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedBidirectionalStreamId(i + 1), &error_details)); + EXPECT_EQ(error_details, + absl::StrCat("Stream id ", + GetNthPeerInitiatedBidirectionalStreamId(i + 1), + " would exceed stream count limit ", + kNumMaxIncomingStreams + 1)); +} + +TEST_P(UberQuicStreamIdManagerTest, GetNextOutgoingStreamId) { + EXPECT_TRUE(manager_.MaybeAllowNewOutgoingBidirectionalStreams(10)); + EXPECT_TRUE(manager_.MaybeAllowNewOutgoingUnidirectionalStreams(10)); + EXPECT_EQ(GetNthSelfInitiatedBidirectionalStreamId(0), + manager_.GetNextOutgoingBidirectionalStreamId()); + EXPECT_EQ(GetNthSelfInitiatedBidirectionalStreamId(1), + manager_.GetNextOutgoingBidirectionalStreamId()); + EXPECT_EQ(GetNthSelfInitiatedUnidirectionalStreamId(0), + manager_.GetNextOutgoingUnidirectionalStreamId()); + EXPECT_EQ(GetNthSelfInitiatedUnidirectionalStreamId(1), + manager_.GetNextOutgoingUnidirectionalStreamId()); +} + +TEST_P(UberQuicStreamIdManagerTest, AvailableStreams) { + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedBidirectionalStreamId(3), nullptr)); + EXPECT_TRUE( + manager_.IsAvailableStream(GetNthPeerInitiatedBidirectionalStreamId(1))); + EXPECT_TRUE( + manager_.IsAvailableStream(GetNthPeerInitiatedBidirectionalStreamId(2))); + + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + GetNthPeerInitiatedUnidirectionalStreamId(3), nullptr)); + EXPECT_TRUE( + manager_.IsAvailableStream(GetNthPeerInitiatedUnidirectionalStreamId(1))); + EXPECT_TRUE( + manager_.IsAvailableStream(GetNthPeerInitiatedUnidirectionalStreamId(2))); +} + +TEST_P(UberQuicStreamIdManagerTest, MaybeIncreaseLargestPeerStreamId) { + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + StreamCountToId(manager_.max_incoming_bidirectional_streams(), + QuicUtils::InvertPerspective(perspective()), + /* bidirectional=*/true), + nullptr)); + EXPECT_TRUE(manager_.MaybeIncreaseLargestPeerStreamId( + StreamCountToId(manager_.max_incoming_bidirectional_streams(), + QuicUtils::InvertPerspective(perspective()), + /* bidirectional=*/false), + nullptr)); + + std::string expected_error_details = + perspective() == Perspective::IS_SERVER + ? "Stream id 400 would exceed stream count limit 100" + : "Stream id 401 would exceed stream count limit 100"; + std::string error_details; + + EXPECT_FALSE(manager_.MaybeIncreaseLargestPeerStreamId( + StreamCountToId(manager_.max_incoming_bidirectional_streams() + 1, + QuicUtils::InvertPerspective(perspective()), + /* bidirectional=*/true), + &error_details)); + EXPECT_EQ(expected_error_details, error_details); + expected_error_details = + perspective() == Perspective::IS_SERVER + ? "Stream id 402 would exceed stream count limit 100" + : "Stream id 403 would exceed stream count limit 100"; + + EXPECT_FALSE(manager_.MaybeIncreaseLargestPeerStreamId( + StreamCountToId(manager_.max_incoming_bidirectional_streams() + 1, + QuicUtils::InvertPerspective(perspective()), + /* bidirectional=*/false), + &error_details)); + EXPECT_EQ(expected_error_details, error_details); +} + +TEST_P(UberQuicStreamIdManagerTest, OnStreamsBlockedFrame) { + QuicStreamCount stream_count = + manager_.advertised_max_incoming_bidirectional_streams() - 1; + + QuicStreamsBlockedFrame frame(kInvalidControlFrameId, stream_count, + /*unidirectional=*/false); + EXPECT_CALL(delegate_, + SendMaxStreams(manager_.max_incoming_bidirectional_streams(), + frame.unidirectional)) + .Times(0); + EXPECT_TRUE(manager_.OnStreamsBlockedFrame(frame, nullptr)); + + stream_count = manager_.advertised_max_incoming_unidirectional_streams() - 1; + frame.stream_count = stream_count; + frame.unidirectional = true; + + EXPECT_CALL(delegate_, + SendMaxStreams(manager_.max_incoming_unidirectional_streams(), + frame.unidirectional)) + .Times(0); + EXPECT_TRUE(manager_.OnStreamsBlockedFrame(frame, nullptr)); +} + +TEST_P(UberQuicStreamIdManagerTest, SetMaxOpenOutgoingStreamsPlusFrame) { + const size_t kNumMaxOutgoingStream = 123; + // Set the uni- and bi- directional limits to different values to ensure + // that they are managed separately. + EXPECT_TRUE(manager_.MaybeAllowNewOutgoingBidirectionalStreams( + kNumMaxOutgoingStream)); + EXPECT_TRUE(manager_.MaybeAllowNewOutgoingUnidirectionalStreams( + kNumMaxOutgoingStream + 1)); + EXPECT_EQ(kNumMaxOutgoingStream, + manager_.max_outgoing_bidirectional_streams()); + EXPECT_EQ(kNumMaxOutgoingStream + 1, + manager_.max_outgoing_unidirectional_streams()); + // Check that, for each directionality, we can open the correct number of + // streams. + int i = kNumMaxOutgoingStream; + while (i) { + EXPECT_TRUE(manager_.CanOpenNextOutgoingBidirectionalStream()); + manager_.GetNextOutgoingBidirectionalStreamId(); + EXPECT_TRUE(manager_.CanOpenNextOutgoingUnidirectionalStream()); + manager_.GetNextOutgoingUnidirectionalStreamId(); + i--; + } + // One more unidirectional + EXPECT_TRUE(manager_.CanOpenNextOutgoingUnidirectionalStream()); + manager_.GetNextOutgoingUnidirectionalStreamId(); + + // Both should be exhausted... + EXPECT_FALSE(manager_.CanOpenNextOutgoingUnidirectionalStream()); + EXPECT_FALSE(manager_.CanOpenNextOutgoingBidirectionalStream()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/uber_received_packet_manager.cc b/quiche/quic/core/uber_received_packet_manager.cc new file mode 100644 index 000000000000..4efe7885b719 --- /dev/null +++ b/quiche/quic/core/uber_received_packet_manager.cc @@ -0,0 +1,246 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/uber_received_packet_manager.h" + +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +UberReceivedPacketManager::UberReceivedPacketManager(QuicConnectionStats* stats) + : supports_multiple_packet_number_spaces_(false) { + for (auto& received_packet_manager : received_packet_managers_) { + received_packet_manager.set_connection_stats(stats); + } +} + +UberReceivedPacketManager::~UberReceivedPacketManager() {} + +void UberReceivedPacketManager::SetFromConfig(const QuicConfig& config, + Perspective perspective) { + for (auto& received_packet_manager : received_packet_managers_) { + received_packet_manager.SetFromConfig(config, perspective); + } +} + +bool UberReceivedPacketManager::IsAwaitingPacket( + EncryptionLevel decrypted_packet_level, + QuicPacketNumber packet_number) const { + if (!supports_multiple_packet_number_spaces_) { + return received_packet_managers_[0].IsAwaitingPacket(packet_number); + } + return received_packet_managers_[QuicUtils::GetPacketNumberSpace( + decrypted_packet_level)] + .IsAwaitingPacket(packet_number); +} + +const QuicFrame UberReceivedPacketManager::GetUpdatedAckFrame( + PacketNumberSpace packet_number_space, QuicTime approximate_now) { + if (!supports_multiple_packet_number_spaces_) { + return received_packet_managers_[0].GetUpdatedAckFrame(approximate_now); + } + return received_packet_managers_[packet_number_space].GetUpdatedAckFrame( + approximate_now); +} + +void UberReceivedPacketManager::RecordPacketReceived( + EncryptionLevel decrypted_packet_level, const QuicPacketHeader& header, + QuicTime receipt_time, QuicEcnCodepoint ecn_codepoint) { + if (!supports_multiple_packet_number_spaces_) { + received_packet_managers_[0].RecordPacketReceived(header, receipt_time, + ecn_codepoint); + return; + } + received_packet_managers_[QuicUtils::GetPacketNumberSpace( + decrypted_packet_level)] + .RecordPacketReceived(header, receipt_time, ecn_codepoint); +} + +void UberReceivedPacketManager::DontWaitForPacketsBefore( + EncryptionLevel decrypted_packet_level, QuicPacketNumber least_unacked) { + if (!supports_multiple_packet_number_spaces_) { + received_packet_managers_[0].DontWaitForPacketsBefore(least_unacked); + return; + } + received_packet_managers_[QuicUtils::GetPacketNumberSpace( + decrypted_packet_level)] + .DontWaitForPacketsBefore(least_unacked); +} + +void UberReceivedPacketManager::MaybeUpdateAckTimeout( + bool should_last_packet_instigate_acks, + EncryptionLevel decrypted_packet_level, + QuicPacketNumber last_received_packet_number, + QuicTime last_packet_receipt_time, QuicTime now, + const RttStats* rtt_stats) { + if (!supports_multiple_packet_number_spaces_) { + received_packet_managers_[0].MaybeUpdateAckTimeout( + should_last_packet_instigate_acks, last_received_packet_number, + last_packet_receipt_time, now, rtt_stats); + return; + } + received_packet_managers_[QuicUtils::GetPacketNumberSpace( + decrypted_packet_level)] + .MaybeUpdateAckTimeout(should_last_packet_instigate_acks, + last_received_packet_number, + last_packet_receipt_time, now, rtt_stats); +} + +void UberReceivedPacketManager::ResetAckStates( + EncryptionLevel encryption_level) { + if (!supports_multiple_packet_number_spaces_) { + received_packet_managers_[0].ResetAckStates(); + return; + } + received_packet_managers_[QuicUtils::GetPacketNumberSpace(encryption_level)] + .ResetAckStates(); + if (encryption_level == ENCRYPTION_INITIAL) { + // After one Initial ACK is sent, the others should be sent 'immediately'. + received_packet_managers_[INITIAL_DATA].set_local_max_ack_delay( + kAlarmGranularity); + } +} + +void UberReceivedPacketManager::EnableMultiplePacketNumberSpacesSupport( + Perspective perspective) { + if (supports_multiple_packet_number_spaces_) { + QUIC_BUG(quic_bug_10495_1) + << "Multiple packet number spaces has already been enabled"; + return; + } + if (received_packet_managers_[0].GetLargestObserved().IsInitialized()) { + QUIC_BUG(quic_bug_10495_2) + << "Try to enable multiple packet number spaces support after any " + "packet has been received."; + return; + } + // In IETF QUIC, the peer is expected to acknowledge packets in Initial and + // Handshake packets with minimal delay. + if (perspective == Perspective::IS_CLIENT) { + // Delay the first server ACK, because server ACKs are padded to + // full size and count towards the amplification limit. + received_packet_managers_[INITIAL_DATA].set_local_max_ack_delay( + kAlarmGranularity); + } + received_packet_managers_[HANDSHAKE_DATA].set_local_max_ack_delay( + kAlarmGranularity); + + supports_multiple_packet_number_spaces_ = true; +} + +bool UberReceivedPacketManager::IsAckFrameUpdated() const { + if (!supports_multiple_packet_number_spaces_) { + return received_packet_managers_[0].ack_frame_updated(); + } + for (const auto& received_packet_manager : received_packet_managers_) { + if (received_packet_manager.ack_frame_updated()) { + return true; + } + } + return false; +} + +QuicPacketNumber UberReceivedPacketManager::GetLargestObserved( + EncryptionLevel decrypted_packet_level) const { + if (!supports_multiple_packet_number_spaces_) { + return received_packet_managers_[0].GetLargestObserved(); + } + return received_packet_managers_[QuicUtils::GetPacketNumberSpace( + decrypted_packet_level)] + .GetLargestObserved(); +} + +QuicTime UberReceivedPacketManager::GetAckTimeout( + PacketNumberSpace packet_number_space) const { + if (!supports_multiple_packet_number_spaces_) { + return received_packet_managers_[0].ack_timeout(); + } + return received_packet_managers_[packet_number_space].ack_timeout(); +} + +QuicTime UberReceivedPacketManager::GetEarliestAckTimeout() const { + QuicTime ack_timeout = QuicTime::Zero(); + // Returns the earliest non-zero ack timeout. + for (const auto& received_packet_manager : received_packet_managers_) { + const QuicTime timeout = received_packet_manager.ack_timeout(); + if (!ack_timeout.IsInitialized()) { + ack_timeout = timeout; + continue; + } + if (timeout.IsInitialized()) { + ack_timeout = std::min(ack_timeout, timeout); + } + } + return ack_timeout; +} + +bool UberReceivedPacketManager::IsAckFrameEmpty( + PacketNumberSpace packet_number_space) const { + if (!supports_multiple_packet_number_spaces_) { + return received_packet_managers_[0].IsAckFrameEmpty(); + } + return received_packet_managers_[packet_number_space].IsAckFrameEmpty(); +} + +QuicPacketNumber UberReceivedPacketManager::peer_least_packet_awaiting_ack() + const { + QUICHE_DCHECK(!supports_multiple_packet_number_spaces_); + return received_packet_managers_[0].peer_least_packet_awaiting_ack(); +} + +size_t UberReceivedPacketManager::min_received_before_ack_decimation() const { + return received_packet_managers_[0].min_received_before_ack_decimation(); +} + +void UberReceivedPacketManager::set_min_received_before_ack_decimation( + size_t new_value) { + for (auto& received_packet_manager : received_packet_managers_) { + received_packet_manager.set_min_received_before_ack_decimation(new_value); + } +} + +void UberReceivedPacketManager::set_ack_frequency(size_t new_value) { + for (auto& received_packet_manager : received_packet_managers_) { + received_packet_manager.set_ack_frequency(new_value); + } +} + +const QuicAckFrame& UberReceivedPacketManager::ack_frame() const { + QUICHE_DCHECK(!supports_multiple_packet_number_spaces_); + return received_packet_managers_[0].ack_frame(); +} + +const QuicAckFrame& UberReceivedPacketManager::GetAckFrame( + PacketNumberSpace packet_number_space) const { + QUICHE_DCHECK(supports_multiple_packet_number_spaces_); + return received_packet_managers_[packet_number_space].ack_frame(); +} + +void UberReceivedPacketManager::set_max_ack_ranges(size_t max_ack_ranges) { + for (auto& received_packet_manager : received_packet_managers_) { + received_packet_manager.set_max_ack_ranges(max_ack_ranges); + } +} + +void UberReceivedPacketManager::set_save_timestamps(bool save_timestamps) { + for (auto& received_packet_manager : received_packet_managers_) { + received_packet_manager.set_save_timestamps( + save_timestamps, supports_multiple_packet_number_spaces_); + } +} + +void UberReceivedPacketManager::OnAckFrequencyFrame( + const QuicAckFrequencyFrame& frame) { + if (!supports_multiple_packet_number_spaces_) { + QUIC_BUG(quic_bug_10495_3) + << "Received AckFrequencyFrame when multiple packet number spaces " + "is not supported"; + return; + } + received_packet_managers_[APPLICATION_DATA].OnAckFrequencyFrame(frame); +} + +} // namespace quic diff --git a/quiche/quic/core/uber_received_packet_manager.h b/quiche/quic/core/uber_received_packet_manager.h new file mode 100644 index 000000000000..0e436c033c37 --- /dev/null +++ b/quiche/quic/core/uber_received_packet_manager.h @@ -0,0 +1,112 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_UBER_RECEIVED_PACKET_MANAGER_H_ +#define QUICHE_QUIC_CORE_UBER_RECEIVED_PACKET_MANAGER_H_ + +#include "quiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "quiche/quic/core/quic_received_packet_manager.h" + +namespace quic { + +// This class comprises multiple received packet managers, one per packet number +// space. Please note, if multiple packet number spaces is not supported, only +// one received packet manager will be used. +class QUIC_EXPORT_PRIVATE UberReceivedPacketManager { + public: + explicit UberReceivedPacketManager(QuicConnectionStats* stats); + UberReceivedPacketManager(const UberReceivedPacketManager&) = delete; + UberReceivedPacketManager& operator=(const UberReceivedPacketManager&) = + delete; + virtual ~UberReceivedPacketManager(); + + void SetFromConfig(const QuicConfig& config, Perspective perspective); + + // Checks if we are still waiting for the packet with |packet_number|. + bool IsAwaitingPacket(EncryptionLevel decrypted_packet_level, + QuicPacketNumber packet_number) const; + + // Called after a packet has been successfully decrypted and its header has + // been parsed. + void RecordPacketReceived(EncryptionLevel decrypted_packet_level, + const QuicPacketHeader& header, + QuicTime receipt_time, + QuicEcnCodepoint ecn_codepoint); + + // Retrieves a frame containing a QuicAckFrame. The ack frame must be + // serialized before another packet is received, or it will change. + const QuicFrame GetUpdatedAckFrame(PacketNumberSpace packet_number_space, + QuicTime approximate_now); + + // Stop ACKing packets before |least_unacked|. + void DontWaitForPacketsBefore(EncryptionLevel decrypted_packet_level, + QuicPacketNumber least_unacked); + + // Called after header of last received packet has been successfully processed + // to update ACK timeout. + void MaybeUpdateAckTimeout(bool should_last_packet_instigate_acks, + EncryptionLevel decrypted_packet_level, + QuicPacketNumber last_received_packet_number, + QuicTime last_packet_receipt_time, QuicTime now, + const RttStats* rtt_stats); + + // Resets ACK related states, called after an ACK is successfully sent. + void ResetAckStates(EncryptionLevel encryption_level); + + // Called to enable multiple packet number support. + void EnableMultiplePacketNumberSpacesSupport(Perspective perspective); + + // Returns true if ACK frame has been updated since GetUpdatedAckFrame was + // last called. + bool IsAckFrameUpdated() const; + + // Returns the largest received packet number. + QuicPacketNumber GetLargestObserved( + EncryptionLevel decrypted_packet_level) const; + + // Returns ACK timeout of |packet_number_space|. + QuicTime GetAckTimeout(PacketNumberSpace packet_number_space) const; + + // Get the earliest ack_timeout of all packet number spaces. + QuicTime GetEarliestAckTimeout() const; + + // Return true if ack frame of |packet_number_space| is empty. + bool IsAckFrameEmpty(PacketNumberSpace packet_number_space) const; + + QuicPacketNumber peer_least_packet_awaiting_ack() const; + + size_t min_received_before_ack_decimation() const; + void set_min_received_before_ack_decimation(size_t new_value); + + void set_ack_frequency(size_t new_value); + + bool supports_multiple_packet_number_spaces() const { + return supports_multiple_packet_number_spaces_; + } + + // For logging purposes. + const QuicAckFrame& ack_frame() const; + const QuicAckFrame& GetAckFrame(PacketNumberSpace packet_number_space) const; + + void set_max_ack_ranges(size_t max_ack_ranges); + + void OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame); + + void set_save_timestamps(bool save_timestamps); + + private: + friend class test::QuicConnectionPeer; + friend class test::UberReceivedPacketManagerPeer; + + // One received packet manager per packet number space. If + // supports_multiple_packet_number_spaces_ is false, only the first (0 index) + // received_packet_manager is used. + QuicReceivedPacketManager received_packet_managers_[NUM_PACKET_NUMBER_SPACES]; + + bool supports_multiple_packet_number_spaces_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_UBER_RECEIVED_PACKET_MANAGER_H_ diff --git a/quiche/quic/core/uber_received_packet_manager_test.cc b/quiche/quic/core/uber_received_packet_manager_test.cc new file mode 100644 index 000000000000..97b2a80d7a76 --- /dev/null +++ b/quiche/quic/core/uber_received_packet_manager_test.cc @@ -0,0 +1,568 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/core/uber_received_packet_manager.h" + +#include + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +class UberReceivedPacketManagerPeer { + public: + static void SetAckDecimationDelay(UberReceivedPacketManager* manager, + float ack_decimation_delay) { + for (auto& received_packet_manager : manager->received_packet_managers_) { + received_packet_manager.ack_decimation_delay_ = ack_decimation_delay; + } + } +}; + +namespace { + +const bool kInstigateAck = true; +const QuicTime::Delta kMinRttMs = QuicTime::Delta::FromMilliseconds(40); +const QuicTime::Delta kDelayedAckTime = + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + +EncryptionLevel GetEncryptionLevel(PacketNumberSpace packet_number_space) { + switch (packet_number_space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } +} + +class UberReceivedPacketManagerTest : public QuicTest { + protected: + UberReceivedPacketManagerTest() { + manager_ = std::make_unique(&stats_); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + rtt_stats_.UpdateRtt(kMinRttMs, QuicTime::Delta::Zero(), QuicTime::Zero()); + manager_->set_save_timestamps(true); + } + + void RecordPacketReceipt(uint64_t packet_number) { + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, packet_number); + } + + void RecordPacketReceipt(uint64_t packet_number, QuicTime receipt_time) { + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, packet_number, receipt_time); + } + + void RecordPacketReceipt(EncryptionLevel decrypted_packet_level, + uint64_t packet_number) { + RecordPacketReceipt(decrypted_packet_level, packet_number, + QuicTime::Zero()); + } + + void RecordPacketReceipt(EncryptionLevel decrypted_packet_level, + uint64_t packet_number, QuicTime receipt_time) { + QuicPacketHeader header; + header.packet_number = QuicPacketNumber(packet_number); + manager_->RecordPacketReceived(decrypted_packet_level, header, receipt_time, + ECN_NOT_ECT); + } + + bool HasPendingAck() { + if (!manager_->supports_multiple_packet_number_spaces()) { + return manager_->GetAckTimeout(APPLICATION_DATA).IsInitialized(); + } + return manager_->GetEarliestAckTimeout().IsInitialized(); + } + + void MaybeUpdateAckTimeout(bool should_last_packet_instigate_acks, + uint64_t last_received_packet_number) { + MaybeUpdateAckTimeout(should_last_packet_instigate_acks, + ENCRYPTION_FORWARD_SECURE, + last_received_packet_number); + } + + void MaybeUpdateAckTimeout(bool should_last_packet_instigate_acks, + EncryptionLevel decrypted_packet_level, + uint64_t last_received_packet_number) { + manager_->MaybeUpdateAckTimeout( + should_last_packet_instigate_acks, decrypted_packet_level, + QuicPacketNumber(last_received_packet_number), clock_.ApproximateNow(), + clock_.ApproximateNow(), &rtt_stats_); + } + + void CheckAckTimeout(QuicTime time) { + QUICHE_DCHECK(HasPendingAck()); + if (!manager_->supports_multiple_packet_number_spaces()) { + QUICHE_DCHECK(manager_->GetAckTimeout(APPLICATION_DATA) == time); + if (time <= clock_.ApproximateNow()) { + // ACK timeout expires, send an ACK. + manager_->ResetAckStates(ENCRYPTION_FORWARD_SECURE); + QUICHE_DCHECK(!HasPendingAck()); + } + return; + } + QUICHE_DCHECK(manager_->GetEarliestAckTimeout() == time); + // Send all expired ACKs. + for (int8_t i = INITIAL_DATA; i < NUM_PACKET_NUMBER_SPACES; ++i) { + const QuicTime ack_timeout = + manager_->GetAckTimeout(static_cast(i)); + if (!ack_timeout.IsInitialized() || + ack_timeout > clock_.ApproximateNow()) { + continue; + } + manager_->ResetAckStates( + GetEncryptionLevel(static_cast(i))); + } + } + + MockClock clock_; + RttStats rtt_stats_; + QuicConnectionStats stats_; + std::unique_ptr manager_; +}; + +TEST_F(UberReceivedPacketManagerTest, DontWaitForPacketsBefore) { + EXPECT_TRUE(manager_->IsAckFrameEmpty(APPLICATION_DATA)); + RecordPacketReceipt(2); + EXPECT_FALSE(manager_->IsAckFrameEmpty(APPLICATION_DATA)); + RecordPacketReceipt(7); + EXPECT_TRUE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(3u))); + EXPECT_TRUE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(6u))); + manager_->DontWaitForPacketsBefore(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(4)); + EXPECT_FALSE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(3u))); + EXPECT_TRUE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(6u))); +} + +TEST_F(UberReceivedPacketManagerTest, GetUpdatedAckFrame) { + QuicTime two_ms = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(2); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + RecordPacketReceipt(2, two_ms); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + + QuicFrame ack = + manager_->GetUpdatedAckFrame(APPLICATION_DATA, QuicTime::Zero()); + manager_->ResetAckStates(ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + // When UpdateReceivedPacketInfo with a time earlier than the time of the + // largest observed packet, make sure that the delta is 0, not negative. + EXPECT_EQ(QuicTime::Delta::Zero(), ack.ack_frame->ack_delay_time); + EXPECT_EQ(1u, ack.ack_frame->received_packet_times.size()); + + QuicTime four_ms = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(4); + ack = manager_->GetUpdatedAckFrame(APPLICATION_DATA, four_ms); + manager_->ResetAckStates(ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + // When UpdateReceivedPacketInfo after not having received a new packet, + // the delta should still be accurate. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(2), + ack.ack_frame->ack_delay_time); + // And received packet times won't have change. + EXPECT_EQ(1u, ack.ack_frame->received_packet_times.size()); + + RecordPacketReceipt(999, two_ms); + RecordPacketReceipt(4, two_ms); + RecordPacketReceipt(1000, two_ms); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + ack = manager_->GetUpdatedAckFrame(APPLICATION_DATA, two_ms); + manager_->ResetAckStates(ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + // UpdateReceivedPacketInfo should discard any times which can't be + // expressed on the wire. + EXPECT_EQ(2u, ack.ack_frame->received_packet_times.size()); +} + +TEST_F(UberReceivedPacketManagerTest, UpdateReceivedConnectionStats) { + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + RecordPacketReceipt(1); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + RecordPacketReceipt(6); + RecordPacketReceipt(2, + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1)); + + EXPECT_EQ(4u, stats_.max_sequence_reordering); + EXPECT_EQ(1000, stats_.max_time_reordering_us); + EXPECT_EQ(1u, stats_.packets_reordered); +} + +TEST_F(UberReceivedPacketManagerTest, LimitAckRanges) { + manager_->set_max_ack_ranges(10); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + for (int i = 0; i < 100; ++i) { + RecordPacketReceipt(1 + 2 * i); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + manager_->GetUpdatedAckFrame(APPLICATION_DATA, QuicTime::Zero()); + EXPECT_GE(10u, manager_->ack_frame().packets.NumIntervals()); + EXPECT_EQ(QuicPacketNumber(1u + 2 * i), + manager_->ack_frame().packets.Max()); + for (int j = 0; j < std::min(10, i + 1); ++j) { + ASSERT_GE(i, j); + EXPECT_TRUE(manager_->ack_frame().packets.Contains( + QuicPacketNumber(1 + (i - j) * 2))); + if (i > j) { + EXPECT_FALSE(manager_->ack_frame().packets.Contains( + QuicPacketNumber((i - j) * 2))); + } + } + } +} + +TEST_F(UberReceivedPacketManagerTest, IgnoreOutOfOrderTimestamps) { + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + RecordPacketReceipt(1, QuicTime::Zero()); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + EXPECT_EQ(1u, manager_->ack_frame().received_packet_times.size()); + RecordPacketReceipt(2, + QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1)); + EXPECT_EQ(2u, manager_->ack_frame().received_packet_times.size()); + RecordPacketReceipt(3, QuicTime::Zero()); + EXPECT_EQ(2u, manager_->ack_frame().received_packet_times.size()); +} + +TEST_F(UberReceivedPacketManagerTest, OutOfOrderReceiptCausesAckSent) { + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(3, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 3); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 2); + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 1); + // Should ack immediately, since this fills the last hole. + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(4, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 4); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); +} + +TEST_F(UberReceivedPacketManagerTest, OutOfOrderAckReceiptCausesNoAck) { + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 2); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 1); + EXPECT_FALSE(HasPendingAck()); +} + +TEST_F(UberReceivedPacketManagerTest, AckReceiptCausesAckSend) { + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(1, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 1); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(2, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 2); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(3, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 3); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + clock_.AdvanceTime(kDelayedAckTime); + CheckAckTimeout(clock_.ApproximateNow()); + + RecordPacketReceipt(4, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 4); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(5, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(!kInstigateAck, 5); + EXPECT_FALSE(HasPendingAck()); +} + +TEST_F(UberReceivedPacketManagerTest, AckSentEveryNthPacket) { + EXPECT_FALSE(HasPendingAck()); + manager_->set_ack_frequency(3); + + // Receives packets 1 - 39. + for (size_t i = 1; i <= 39; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 3 == 0) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } +} + +TEST_F(UberReceivedPacketManagerTest, AckDecimationReducesAcks) { + EXPECT_FALSE(HasPendingAck()); + + // Start ack decimation from 10th packet. + manager_->set_min_received_before_ack_decimation(10); + + // Receives packets 1 - 29. + for (size_t i = 1; i <= 29; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i <= 10) { + // For packets 1-10, ack every 2 packets. + if (i % 2 == 0) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + continue; + } + // ack at 20. + if (i == 20) { + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kMinRttMs * 0.25); + } + } + + // We now receive the 30th packet, and so we send an ack. + RecordPacketReceipt(30, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, 30); + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(UberReceivedPacketManagerTest, SendDelayedAckDecimation) { + EXPECT_FALSE(HasPendingAck()); + // The ack time should be based on min_rtt * 1/4, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + kMinRttMs * 0.25; + + // Process all the packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // The 10th received packet causes an ack to be sent. + for (uint64_t i = 1; i < 10; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(UberReceivedPacketManagerTest, + SendDelayedAckDecimationUnlimitedAggregation) { + EXPECT_FALSE(HasPendingAck()); + QuicConfig config; + QuicTagVector connection_options; + // No limit on the number of packets received before sending an ack. + connection_options.push_back(kAKDU); + config.SetConnectionOptionsToSend(connection_options); + manager_->SetFromConfig(config, Perspective::IS_CLIENT); + + // The ack time should be based on min_rtt/4, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + kMinRttMs * 0.25; + + // Process all the initial packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // 18 packets will not cause an ack to be sent. 19 will because when + // stop waiting frames are in use, we ack every 20 packets no matter what. + for (int i = 1; i <= 18; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(ack_time); +} + +TEST_F(UberReceivedPacketManagerTest, SendDelayedAckDecimationEighthRtt) { + EXPECT_FALSE(HasPendingAck()); + UberReceivedPacketManagerPeer::SetAckDecimationDelay(manager_.get(), 0.125); + + // The ack time should be based on min_rtt/8, since it's less than the + // default delayed ack time. + QuicTime ack_time = clock_.ApproximateNow() + kMinRttMs * 0.125; + + // Process all the packets in order so there aren't missing packets. + uint64_t kFirstDecimatedPacket = 101; + for (uint64_t i = 1; i < kFirstDecimatedPacket; ++i) { + RecordPacketReceipt(i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, i); + if (i % 2 == 0) { + // Ack every 2 packets by default. + CheckAckTimeout(clock_.ApproximateNow()); + } else { + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + } + } + + RecordPacketReceipt(kFirstDecimatedPacket, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket); + CheckAckTimeout(ack_time); + + // The 10th received packet causes an ack to be sent. + for (uint64_t i = 1; i < 10; ++i) { + RecordPacketReceipt(kFirstDecimatedPacket + i, clock_.ApproximateNow()); + MaybeUpdateAckTimeout(kInstigateAck, kFirstDecimatedPacket + i); + } + CheckAckTimeout(clock_.ApproximateNow()); +} + +TEST_F(UberReceivedPacketManagerTest, + DontWaitForPacketsBeforeMultiplePacketNumberSpaces) { + manager_->EnableMultiplePacketNumberSpacesSupport(Perspective::IS_CLIENT); + EXPECT_FALSE( + manager_->GetLargestObserved(ENCRYPTION_HANDSHAKE).IsInitialized()); + EXPECT_FALSE( + manager_->GetLargestObserved(ENCRYPTION_FORWARD_SECURE).IsInitialized()); + RecordPacketReceipt(ENCRYPTION_HANDSHAKE, 2); + RecordPacketReceipt(ENCRYPTION_HANDSHAKE, 4); + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, 3); + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, 7); + EXPECT_EQ(QuicPacketNumber(4), + manager_->GetLargestObserved(ENCRYPTION_HANDSHAKE)); + EXPECT_EQ(QuicPacketNumber(7), + manager_->GetLargestObserved(ENCRYPTION_FORWARD_SECURE)); + + EXPECT_TRUE( + manager_->IsAwaitingPacket(ENCRYPTION_HANDSHAKE, QuicPacketNumber(3))); + EXPECT_FALSE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(3))); + EXPECT_TRUE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(4))); + + manager_->DontWaitForPacketsBefore(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(5)); + EXPECT_TRUE( + manager_->IsAwaitingPacket(ENCRYPTION_HANDSHAKE, QuicPacketNumber(3))); + EXPECT_FALSE(manager_->IsAwaitingPacket(ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(4))); +} + +TEST_F(UberReceivedPacketManagerTest, AckSendingDifferentPacketNumberSpaces) { + manager_->EnableMultiplePacketNumberSpacesSupport(Perspective::IS_SERVER); + EXPECT_FALSE(HasPendingAck()); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + + RecordPacketReceipt(ENCRYPTION_INITIAL, 3); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_INITIAL, 3); + EXPECT_TRUE(HasPendingAck()); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(25)); + // Send delayed handshake data ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(25)); + CheckAckTimeout(clock_.ApproximateNow()); + EXPECT_FALSE(HasPendingAck()); + + // Second delayed ack should have a shorter delay. + RecordPacketReceipt(ENCRYPTION_INITIAL, 4); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_INITIAL, 4); + EXPECT_TRUE(HasPendingAck()); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(1)); + // Send delayed handshake data ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + CheckAckTimeout(clock_.ApproximateNow()); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(ENCRYPTION_HANDSHAKE, 3); + EXPECT_TRUE(manager_->IsAckFrameUpdated()); + MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_HANDSHAKE, 3); + EXPECT_TRUE(HasPendingAck()); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(1)); + // Send delayed handshake data ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + CheckAckTimeout(clock_.ApproximateNow()); + EXPECT_FALSE(HasPendingAck()); + + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, 3); + MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_FORWARD_SECURE, 3); + EXPECT_TRUE(HasPendingAck()); + // Delayed ack is scheduled. + CheckAckTimeout(clock_.ApproximateNow() + kDelayedAckTime); + + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, 2); + MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_FORWARD_SECURE, 2); + // Application data ACK should be sent immediately. + CheckAckTimeout(clock_.ApproximateNow()); + EXPECT_FALSE(HasPendingAck()); +} + +TEST_F(UberReceivedPacketManagerTest, + AckTimeoutForPreviouslyUndecryptablePackets) { + manager_->EnableMultiplePacketNumberSpacesSupport(Perspective::IS_SERVER); + EXPECT_FALSE(HasPendingAck()); + EXPECT_FALSE(manager_->IsAckFrameUpdated()); + + // Received undecryptable 1-RTT packet 4. + const QuicTime packet_receipt_time4 = clock_.ApproximateNow(); + // 1-RTT keys become available after 10ms because HANDSHAKE 5 gets received. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + RecordPacketReceipt(ENCRYPTION_HANDSHAKE, 5); + MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_HANDSHAKE, 5); + EXPECT_TRUE(HasPendingAck()); + RecordPacketReceipt(ENCRYPTION_FORWARD_SECURE, 4); + manager_->MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_FORWARD_SECURE, + QuicPacketNumber(4), packet_receipt_time4, + clock_.ApproximateNow(), &rtt_stats_); + + // Send delayed handshake ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + CheckAckTimeout(clock_.ApproximateNow()); + + EXPECT_TRUE(HasPendingAck()); + // Verify ACK delay is based on packet receipt time. + CheckAckTimeout(clock_.ApproximateNow() - + QuicTime::Delta::FromMilliseconds(11) + kDelayedAckTime); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/core/web_transport_interface.h b/quiche/quic/core/web_transport_interface.h new file mode 100644 index 000000000000..dc145f7f55c3 --- /dev/null +++ b/quiche/quic/core/web_transport_interface.h @@ -0,0 +1,53 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This header contains interfaces that abstract away different backing +// protocols for WebTransport. + +#ifndef QUICHE_QUIC_CORE_WEB_TRANSPORT_INTERFACE_H_ +#define QUICHE_QUIC_CORE_WEB_TRANSPORT_INTERFACE_H_ + +#include "quiche/quic/core/quic_types.h" +#include "quiche/web_transport/web_transport.h" + +namespace quic { + +using WebTransportSessionError = webtransport::SessionErrorCode; +using WebTransportStreamError = webtransport::StreamErrorCode; + +using WebTransportStreamVisitor = webtransport::StreamVisitor; +using WebTransportStream = webtransport::Stream; +using WebTransportVisitor = webtransport::SessionVisitor; +using WebTransportSession = webtransport::Session; + +inline webtransport::DatagramStatus MessageStatusToWebTransportStatus( + MessageStatus status) { + switch (status) { + case MESSAGE_STATUS_SUCCESS: + return webtransport::DatagramStatus( + webtransport::DatagramStatusCode::kSuccess, ""); + case MESSAGE_STATUS_BLOCKED: + return webtransport::DatagramStatus( + webtransport::DatagramStatusCode::kBlocked, + "QUIC connection write-blocked"); + case MESSAGE_STATUS_TOO_LARGE: + return webtransport::DatagramStatus( + webtransport::DatagramStatusCode::kTooBig, + "Datagram payload exceeded maximum allowed size"); + case MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED: + case MESSAGE_STATUS_INTERNAL_ERROR: + case MESSAGE_STATUS_UNSUPPORTED: + return webtransport::DatagramStatus( + webtransport::DatagramStatusCode::kInternalError, + absl::StrCat("Internal error: ", MessageStatusToString(status))); + default: + return webtransport::DatagramStatus( + webtransport::DatagramStatusCode::kInternalError, + absl::StrCat("Unknown status: ", MessageStatusToString(status))); + } +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_WEB_TRANSPORT_INTERFACE_H_ diff --git a/quiche/quic/load_balancer/load_balancer_config.cc b/quiche/quic/load_balancer/load_balancer_config.cc new file mode 100644 index 000000000000..62d8898b9f36 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_config.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_config.h" + +#include +#include + +#include "openssl/aes.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +namespace { + +// Validates all non-key parts of the input. +bool CommonValidation(const uint8_t config_id, const uint8_t server_id_len, + const uint8_t nonce_len) { + if (config_id >= kNumLoadBalancerConfigs || server_id_len == 0 || + nonce_len < kLoadBalancerMinNonceLen || + nonce_len > kLoadBalancerMaxNonceLen || + server_id_len > + (kQuicMaxConnectionIdWithLengthPrefixLength - nonce_len - 1)) { + QUIC_BUG(quic_bug_433862549_01) + << "Invalid LoadBalancerConfig " + << "Config ID " << static_cast(config_id) << " Server ID Length " + << static_cast(server_id_len) << " Nonce Length " + << static_cast(nonce_len); + return false; + } + return true; +} + +// Initialize the key in the constructor +absl::optional BuildKey(absl::string_view key, bool encrypt) { + if (key.empty()) { + return absl::optional(); + } + AES_KEY raw_key; + if (encrypt) { + if (AES_set_encrypt_key(reinterpret_cast(key.data()), + key.size() * 8, &raw_key) < 0) { + return absl::optional(); + } + } else if (AES_set_decrypt_key(reinterpret_cast(key.data()), + key.size() * 8, &raw_key) < 0) { + return absl::optional(); + } + return raw_key; +} + +// Functions to handle 4-pass encryption/decryption. +// TakePlaintextFrom{Left,Right}() reads the left or right half of 'from' and +// expands it into a full encryption block ('to') in accordance with the +// internet-draft. +void TakePlaintextFromLeft(const uint8_t *from, const uint8_t plaintext_len, + const uint8_t index, uint8_t *to) { + uint8_t half = plaintext_len / 2; + + to[0] = plaintext_len; + to[1] = index; + memcpy(to + 2, from, half); + if (plaintext_len % 2) { + to[2 + half] = from[half] & 0xf0; + half++; + } + memset(to + 2 + half, 0, kLoadBalancerBlockSize - 2 - half); +} + +void TakePlaintextFromRight(const uint8_t *from, const uint8_t plaintext_len, + const uint8_t index, uint8_t *to) { + uint8_t half = plaintext_len / 2; + + to[0] = plaintext_len; + to[1] = index; + memcpy(to + 2, from + half, half + (plaintext_len % 2)); + if (plaintext_len % 2) { + to[2] &= 0x0f; + half++; + } + memset(to + 2 + half, 0, kLoadBalancerBlockSize - 2 - half); +} + +// CiphertextXorWith{Left,Right}() takes the relevant end of the ciphertext in +// 'from' and XORs it with half of the ConnectionId stored at 'to', in +// accordance with the internet-draft. +void CiphertextXorWithLeft(const uint8_t *from, const uint8_t plaintext_len, + uint8_t *to) { + uint8_t half = plaintext_len / 2; + for (int i = 0; i < half; i++) { + to[i] ^= from[i]; + } + if (plaintext_len % 2) { + to[half] ^= (from[half] & 0xf0); + } +} + +void CiphertextXorWithRight(const uint8_t *from, const uint8_t plaintext_len, + uint8_t *to) { + uint8_t half = plaintext_len / 2; + int i = 0; + if (plaintext_len % 2) { + to[half] ^= (from[0] & 0x0f); + i++; + } + while ((half + i) < plaintext_len) { + to[half + i] ^= from[i]; + i++; + } +} + +} // namespace + +absl::optional LoadBalancerConfig::Create( + const uint8_t config_id, const uint8_t server_id_len, + const uint8_t nonce_len, const absl::string_view key) { + // Check for valid parameters. + if (key.size() != kLoadBalancerKeyLen) { + QUIC_BUG(quic_bug_433862549_02) + << "Invalid LoadBalancerConfig Key Length: " << key.size(); + return absl::optional(); + } + if (!CommonValidation(config_id, server_id_len, nonce_len)) { + return absl::optional(); + } + auto new_config = + LoadBalancerConfig(config_id, server_id_len, nonce_len, key); + if (!new_config.IsEncrypted()) { + // Something went wrong in assigning the key! + QUIC_BUG(quic_bug_433862549_03) << "Something went wrong in initializing " + "the load balancing key."; + return absl::optional(); + } + return new_config; +} + +// Creates an unencrypted config. +absl::optional LoadBalancerConfig::CreateUnencrypted( + const uint8_t config_id, const uint8_t server_id_len, + const uint8_t nonce_len) { + return CommonValidation(config_id, server_id_len, nonce_len) + ? LoadBalancerConfig(config_id, server_id_len, nonce_len, "") + : absl::optional(); +} + +bool LoadBalancerConfig::EncryptionPass(absl::Span target, + const uint8_t index) const { + uint8_t buf[kLoadBalancerBlockSize]; + if (!key_.has_value() || target.size() < plaintext_len()) { + return false; + } + if (index % 2) { // Odd indices go from left to right + TakePlaintextFromLeft(target.data(), plaintext_len(), index, buf); + } else { + TakePlaintextFromRight(target.data(), plaintext_len(), index, buf); + } + if (!BlockEncrypt(buf, buf)) { + return false; + } + // XOR bits over the correct half. + if (index % 2) { + CiphertextXorWithRight(buf, plaintext_len(), target.data()); + } else { + CiphertextXorWithLeft(buf, plaintext_len(), target.data()); + } + return true; +} + +bool LoadBalancerConfig::BlockEncrypt( + const uint8_t plaintext[kLoadBalancerBlockSize], + uint8_t ciphertext[kLoadBalancerBlockSize]) const { + if (!key_.has_value()) { + return false; + } + AES_encrypt(plaintext, ciphertext, &key_.value()); + return true; +} + +bool LoadBalancerConfig::BlockDecrypt( + const uint8_t ciphertext[kLoadBalancerBlockSize], + uint8_t plaintext[kLoadBalancerBlockSize]) const { + if (!block_decrypt_key_.has_value()) { + return false; + } + AES_decrypt(ciphertext, plaintext, &block_decrypt_key_.value()); + return true; +} + +LoadBalancerConfig::LoadBalancerConfig(const uint8_t config_id, + const uint8_t server_id_len, + const uint8_t nonce_len, + const absl::string_view key) + : config_id_(config_id), + server_id_len_(server_id_len), + nonce_len_(nonce_len), + key_(BuildKey(key, /* encrypt = */ true)), + block_decrypt_key_((server_id_len + nonce_len == kLoadBalancerBlockSize) + ? BuildKey(key, /* encrypt = */ false) + : absl::optional()) {} + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_config.h b/quiche/quic/load_balancer/load_balancer_config.h new file mode 100644 index 000000000000..bc92d5fc1840 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_config.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_CONFIG_H_ +#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_CONFIG_H_ + +#include "openssl/aes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +inline constexpr uint8_t kNumLoadBalancerConfigs = 3; +inline constexpr uint8_t kLoadBalancerKeyLen = 16; +// Regardless of key length, the AES block size is always 16 Bytes. +inline constexpr uint8_t kLoadBalancerBlockSize = 16; +// The spec says nonces can be 18 bytes, but 16 lets it be a uint128. +inline constexpr uint8_t kLoadBalancerMaxNonceLen = 16; +inline constexpr uint8_t kLoadBalancerMinNonceLen = 4; +inline constexpr uint8_t kNumLoadBalancerCryptoPasses = 4; + +// This the base class for QUIC-LB configuration. It contains configuration +// elements usable by both encoders (servers) and decoders (load balancers). +// Confusingly, it is called "LoadBalancerConfig" because it pertains to objects +// that both servers and load balancers use to interact with each other. +class QUIC_EXPORT_PRIVATE LoadBalancerConfig { + public: + // This factory function initializes an encrypted LoadBalancerConfig and + // returns it in absl::optional, which is empty if the config is invalid. + // config_id: The first two bits of the Connection Id. Must be no larger than + // 2. + // server_id_len: Expected length of the server ids associated with this + // config. Must be greater than 0 and less than 16. + // nonce_len: Length of the nonce. Must be at least 4 and no larger than 16. + // Further the server_id_len + nonce_len must be no larger than 19. + // key: The encryption key must be 16B long. + static absl::optional Create(const uint8_t config_id, + const uint8_t server_id_len, + const uint8_t nonce_len, + const absl::string_view key); + + // Creates an unencrypted config. + static absl::optional CreateUnencrypted( + const uint8_t config_id, const uint8_t server_id_len, + const uint8_t nonce_len); + + // Handles one pass of 4-pass encryption. Encoder and decoder use of this + // function varies substantially, so they are not implemented here. + // Returns false if the config is not encrypted, or if |target| isn't long + // enough. + ABSL_MUST_USE_RESULT bool EncryptionPass(absl::Span target, + const uint8_t index) const; + // Use the key to do a block encryption, which is used both in all cases of + // encrypted configs. Returns false if there's no key. + ABSL_MUST_USE_RESULT bool BlockEncrypt( + const uint8_t plaintext[kLoadBalancerBlockSize], + uint8_t ciphertext[kLoadBalancerBlockSize]) const; + // Returns false if the config does not require block decryption. + ABSL_MUST_USE_RESULT bool BlockDecrypt( + const uint8_t ciphertext[kLoadBalancerBlockSize], + uint8_t plaintext[kLoadBalancerBlockSize]) const; + + uint8_t config_id() const { return config_id_; } + uint8_t server_id_len() const { return server_id_len_; } + uint8_t nonce_len() const { return nonce_len_; } + // Returns length of all but the first octet. + uint8_t plaintext_len() const { return server_id_len_ + nonce_len_; } + // Returns length of the entire connection ID. + uint8_t total_len() const { return server_id_len_ + nonce_len_ + 1; } + bool IsEncrypted() const { return key_.has_value(); } + + private: + // Constructor is private because it doesn't validate input. + LoadBalancerConfig(uint8_t config_id, uint8_t server_id_len, + uint8_t nonce_len, absl::string_view key); + + uint8_t config_id_; + uint8_t server_id_len_; + uint8_t nonce_len_; + // All Connection ID encryption and decryption uses the AES_encrypt function + // at root, so there is a single key for all of it. This is empty if the + // config is not encrypted. + absl::optional key_; + // The one exception is that when total_len == 16, connection ID decryption + // uses AES_decrypt. The bytes that comprise the key are the same, but + // AES_decrypt requires an AES_KEY that is initialized differently. In all + // other cases, block_decrypt_key_ is empty. + absl::optional block_decrypt_key_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_CONFIG_H_ diff --git a/quiche/quic/load_balancer/load_balancer_config_test.cc b/quiche/quic/load_balancer/load_balancer_config_test.cc new file mode 100644 index 000000000000..bb3d36748bf7 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_config_test.cc @@ -0,0 +1,190 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_config.h" + +#include + +#include "absl/types/span.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +namespace { + +constexpr char raw_key[] = { + 0xfd, 0xf7, 0x26, 0xa9, 0x89, 0x3e, 0xc0, 0x5c, + 0x06, 0x32, 0xd3, 0x95, 0x66, 0x80, 0xba, 0xf0, +}; + +class LoadBalancerConfigTest : public QuicTest {}; + +TEST_F(LoadBalancerConfigTest, InvalidParams) { + // Bogus config_id. + EXPECT_QUIC_BUG( + EXPECT_FALSE(LoadBalancerConfig::CreateUnencrypted(3, 4, 10).has_value()), + "Invalid LoadBalancerConfig Config ID 3 Server ID Length 4 " + "Nonce Length 10"); + // Bad Server ID lengths. + EXPECT_QUIC_BUG(EXPECT_FALSE(LoadBalancerConfig::Create( + 2, 0, 10, absl::string_view(raw_key, 16)) + .has_value()), + "Invalid LoadBalancerConfig Config ID 2 Server ID Length 0 " + "Nonce Length 10"); + EXPECT_QUIC_BUG( + EXPECT_FALSE(LoadBalancerConfig::CreateUnencrypted(2, 16, 4).has_value()), + "Invalid LoadBalancerConfig Config ID 2 Server ID Length 16 " + "Nonce Length 4"); + // Bad Nonce lengths. + EXPECT_QUIC_BUG( + EXPECT_FALSE(LoadBalancerConfig::CreateUnencrypted(2, 4, 2).has_value()), + "Invalid LoadBalancerConfig Config ID 2 Server ID Length 4 " + "Nonce Length 2"); + EXPECT_QUIC_BUG( + EXPECT_FALSE(LoadBalancerConfig::CreateUnencrypted(2, 1, 17).has_value()), + "Invalid LoadBalancerConfig Config ID 2 Server ID Length 1 " + "Nonce Length 17"); + // Bad key lengths. + EXPECT_QUIC_BUG( + EXPECT_FALSE(LoadBalancerConfig::Create(2, 3, 4, "").has_value()), + "Invalid LoadBalancerConfig Key Length: 0"); + EXPECT_QUIC_BUG(EXPECT_FALSE(LoadBalancerConfig::Create( + 2, 3, 4, absl::string_view(raw_key, 10)) + .has_value()), + "Invalid LoadBalancerConfig Key Length: 10"); + EXPECT_QUIC_BUG(EXPECT_FALSE(LoadBalancerConfig::Create( + 0, 3, 4, absl::string_view(raw_key, 17)) + .has_value()), + "Invalid LoadBalancerConfig Key Length: 17"); +} + +TEST_F(LoadBalancerConfigTest, ValidParams) { + // Test valid configurations and accessors + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + EXPECT_TRUE(config.has_value()); + EXPECT_EQ(config->config_id(), 0); + EXPECT_EQ(config->server_id_len(), 3); + EXPECT_EQ(config->nonce_len(), 4); + EXPECT_EQ(config->plaintext_len(), 7); + EXPECT_EQ(config->total_len(), 8); + EXPECT_FALSE(config->IsEncrypted()); + auto config2 = + LoadBalancerConfig::Create(2, 6, 7, absl::string_view(raw_key, 16)); + EXPECT_TRUE(config.has_value()); + EXPECT_EQ(config2->config_id(), 2); + EXPECT_EQ(config2->server_id_len(), 6); + EXPECT_EQ(config2->nonce_len(), 7); + EXPECT_EQ(config2->plaintext_len(), 13); + EXPECT_EQ(config2->total_len(), 14); + EXPECT_TRUE(config2->IsEncrypted()); +} + +// Compare EncryptionPass() results to the example in +// draft-ietf-quic-load-balancers-15, Section 4.3.2. +TEST_F(LoadBalancerConfigTest, TestEncryptionPassExample) { + auto config = + LoadBalancerConfig::Create(0, 3, 4, absl::string_view(raw_key, 16)); + EXPECT_TRUE(config.has_value()); + EXPECT_TRUE(config->IsEncrypted()); + std::array bytes = {0x31, 0x44, 0x1a, 0x9c, 0x69, 0xc2, 0x75}; + std::array pass1 = {0x31, 0x44, 0x1a, 0x9f, 0x1a, 0x5b, 0x6b}; + std::array pass2 = {0x02, 0x8e, 0x1b, 0x5f, 0x1a, 0x5b, 0x6b}; + std::array pass3 = {0x02, 0x8e, 0x1b, 0x54, 0x94, 0x97, 0x62}; + std::array pass4 = {0x8e, 0x9a, 0x91, 0xf4, 0x94, 0x97, 0x62}; + + // Input is too short. + EXPECT_FALSE(config->EncryptionPass(absl::Span(bytes.data(), 6), 0)); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 1)); + EXPECT_EQ(bytes, pass1); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 2)); + EXPECT_EQ(bytes, pass2); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 3)); + EXPECT_EQ(bytes, pass3); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 4)); + EXPECT_EQ(bytes, pass4); +} + +TEST_F(LoadBalancerConfigTest, EncryptionPassPlaintext) { + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + std::array bytes = {0x31, 0x44, 0x1a, 0x9c, 0x69, 0xc2, 0x75}; + EXPECT_FALSE(config->EncryptionPass(absl::Span(bytes), 1)); +} + +// Check that the encryption pass code can decode its own ciphertext. Various +// pointer errors could cause the code to overwrite bits that contain +// important information. +TEST_F(LoadBalancerConfigTest, EncryptionPassesAreReversible) { + auto config = + LoadBalancerConfig::Create(0, 3, 4, absl::string_view(raw_key, 16)); + std::array bytes = { + 0x31, 0x44, 0x1a, 0x9c, 0x69, 0xc2, 0x75, + }; + std::array orig_bytes; + memcpy(orig_bytes.data(), bytes.data(), bytes.size()); + // Work left->right and right->left passes. + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 1)); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 2)); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 2)); + EXPECT_TRUE(config->EncryptionPass(absl::Span(bytes), 1)); + EXPECT_EQ(bytes, orig_bytes); +} + +TEST_F(LoadBalancerConfigTest, InvalidBlockEncryption) { + uint8_t pt[kLoadBalancerBlockSize], ct[kLoadBalancerBlockSize]; + auto pt_config = LoadBalancerConfig::CreateUnencrypted(0, 8, 8); + EXPECT_FALSE(pt_config->BlockEncrypt(pt, ct)); + EXPECT_FALSE(pt_config->BlockDecrypt(ct, pt)); + EXPECT_FALSE(pt_config->EncryptionPass(absl::Span(pt), 0)); + auto small_cid_config = + LoadBalancerConfig::Create(0, 3, 4, absl::string_view(raw_key, 16)); + EXPECT_TRUE(small_cid_config->BlockEncrypt(pt, ct)); + EXPECT_FALSE(small_cid_config->BlockDecrypt(ct, pt)); + auto block_config = + LoadBalancerConfig::Create(0, 8, 8, absl::string_view(raw_key, 16)); + EXPECT_TRUE(block_config->BlockEncrypt(pt, ct)); + EXPECT_TRUE(block_config->BlockDecrypt(ct, pt)); +} + +// Block decrypt test from the Test Vector in +// draft-ietf-quic-load-balancers-15, Appendix B. +TEST_F(LoadBalancerConfigTest, BlockEncryptionExample) { + const uint8_t ptext[] = {0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, + 0xee, 0x08, 0x0d, 0xbf, 0x48, 0xc0, 0xd1, 0xe5}; + const uint8_t ctext[] = {0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, 0xb2, + 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, 0xc3}; + const char key[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, + 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; + uint8_t result[sizeof(ptext)]; + auto config = LoadBalancerConfig::Create(0, 8, 8, absl::string_view(key, 16)); + EXPECT_TRUE(config->BlockEncrypt(ptext, result)); + EXPECT_EQ(memcmp(result, ctext, sizeof(ctext)), 0); + EXPECT_TRUE(config->BlockDecrypt(ctext, result)); + EXPECT_EQ(memcmp(result, ptext, sizeof(ptext)), 0); +} + +TEST_F(LoadBalancerConfigTest, ConfigIsCopyable) { + const uint8_t ptext[] = {0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, + 0xee, 0x08, 0x0d, 0xbf, 0x48, 0xc0, 0xd1, 0xe5}; + const uint8_t ctext[] = {0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, 0xb2, + 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, 0xc3}; + const char key[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, + 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; + uint8_t result[sizeof(ptext)]; + auto config = LoadBalancerConfig::Create(0, 8, 8, absl::string_view(key, 16)); + auto config2 = config; + EXPECT_TRUE(config->BlockEncrypt(ptext, result)); + EXPECT_EQ(memcmp(result, ctext, sizeof(ctext)), 0); + EXPECT_TRUE(config2->BlockEncrypt(ptext, result)); + EXPECT_EQ(memcmp(result, ctext, sizeof(ctext)), 0); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_decoder.cc b/quiche/quic/load_balancer/load_balancer_decoder.cc new file mode 100644 index 000000000000..cdff31445682 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_decoder.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_decoder.h" + +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +bool LoadBalancerDecoder::AddConfig(const LoadBalancerConfig& config) { + if (config_[config.config_id()].has_value()) { + return false; + } + config_[config.config_id()] = config; + return true; +} + +void LoadBalancerDecoder::DeleteConfig(uint8_t config_id) { + if (config_id >= kNumLoadBalancerConfigs) { + QUIC_BUG(quic_bug_438896865_01) + << "Decoder deleting config with invalid config_id " + << static_cast(config_id); + return; + } + config_[config_id].reset(); +} + +// This is the core logic to extract a server ID given a valid config and +// connection ID of sufficient length. +absl::optional LoadBalancerDecoder::GetServerId( + const QuicConnectionId& connection_id) const { + absl::optional config_id = GetConfigId(connection_id); + if (!config_id.has_value()) { + return absl::optional(); + } + absl::optional config = config_[*config_id]; + if (!config.has_value()) { + return absl::optional(); + } + if (connection_id.length() < config->total_len()) { + // Connection ID wasn't long enough + return absl::optional(); + } + // The first byte is complete. Finish the rest. + const uint8_t* data = + reinterpret_cast(connection_id.data()) + 1; + if (!config->IsEncrypted()) { // It's a Plaintext CID. + return LoadBalancerServerId::Create( + absl::Span(data, config->server_id_len())); + } + uint8_t result[kQuicMaxConnectionIdWithLengthPrefixLength]; + if (config->plaintext_len() == kLoadBalancerKeyLen) { // single pass + if (!config->BlockDecrypt(data, result)) { + return absl::optional(); + } + } else { + // Do 3 or 4 passes. Only 3 are necessary if the server_id is short enough + // to fit in the first half of the connection ID (the decoder doesn't need + // to extract the nonce). + memcpy(result, data, config->plaintext_len()); + uint8_t end = (config->server_id_len() > config->nonce_len()) ? 1 : 2; + for (uint8_t i = kNumLoadBalancerCryptoPasses; i >= end; i--) { + if (!config->EncryptionPass(absl::Span(result), i)) { + return absl::optional(); + } + } + } + return LoadBalancerServerId::Create( + absl::Span(result, config->server_id_len())); +} + +absl::optional LoadBalancerDecoder::GetConfigId( + const QuicConnectionId& connection_id) { + if (connection_id.IsEmpty()) { + return absl::optional(); + } + return GetConfigId(*reinterpret_cast(connection_id.data())); +} + +absl::optional LoadBalancerDecoder::GetConfigId( + const uint8_t connection_id_first_byte) { + uint8_t codepoint = (connection_id_first_byte >> 6); + if (codepoint < kNumLoadBalancerConfigs) { + return codepoint; + } + return absl::optional(); +} + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_decoder.h b/quiche/quic/load_balancer/load_balancer_decoder.h new file mode 100644 index 000000000000..f5ba6f3a3826 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_decoder.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_DECODER_H_ +#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_DECODER_H_ + +#include "quiche/quic/load_balancer/load_balancer_config.h" +#include "quiche/quic/load_balancer/load_balancer_server_id.h" + +namespace quic { + +// Manages QUIC-LB configurations to extract a server ID from a properly +// encoded connection ID, usually on behalf of a load balancer. +class QUIC_EXPORT_PRIVATE LoadBalancerDecoder { + public: + // Returns false if the config_id codepoint is already occupied. + bool AddConfig(const LoadBalancerConfig& config); + + // Remove support for a config. Does nothing if there is no config for + // |config_id|. Does nothing and creates a bug if |config_id| is greater than + // 2. + void DeleteConfig(uint8_t config_id); + + // Return the config for |config_id|, or nullptr if not found. + const LoadBalancerConfig* GetConfig(const uint8_t config_id) const { + if (config_id >= kNumLoadBalancerConfigs || + !config_[config_id].has_value()) { + return nullptr; + } + + return &config_[config_id].value(); + } + + // Extract a server ID from |connection_id|. If there is no config for the + // codepoint, |connection_id| is too short, or there's a decrypt error, + // returns empty. Will accept |connection_id| that is longer than necessary + // without error. + absl::optional GetServerId( + const QuicConnectionId& connection_id) const; + + // Returns the config ID stored in the first two bits of |connection_id|, or + // empty if |connection_id| is empty, or the first two bits of the first byte + // of |connection_id| are 0b11. + static absl::optional GetConfigId( + const QuicConnectionId& connection_id); + + // Returns the config ID stored in the first two bits of + // |connection_id_first_byte|, or empty if the first two bits are 0b11. + static absl::optional GetConfigId(uint8_t connection_id_first_byte); + + private: + // Decoders can support up to 3 configs at once. + absl::optional config_[kNumLoadBalancerConfigs]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_DECODER_H_ diff --git a/quiche/quic/load_balancer/load_balancer_decoder_test.cc b/quiche/quic/load_balancer/load_balancer_decoder_test.cc new file mode 100644 index 000000000000..b2bdfd3a05bc --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_decoder_test.cc @@ -0,0 +1,242 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_decoder.h" + +#include "quiche/quic/load_balancer/load_balancer_server_id.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +namespace { + +class LoadBalancerDecoderTest : public QuicTest {}; + +// Convenience function to shorten the code. Does not check if |array| is long +// enough or |length| is valid for a server ID. +inline LoadBalancerServerId MakeServerId(const uint8_t array[], + const uint8_t length) { + return *LoadBalancerServerId::Create( + absl::Span(array, length)); +} + +constexpr char kRawKey[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, + 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; +constexpr absl::string_view kKey(kRawKey, kLoadBalancerKeyLen); +constexpr uint8_t kServerId[] = {0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, + 0xab, 0x65, 0xba, 0x04, 0xc3, 0x33, 0x0a}; + +struct LoadBalancerDecoderTestCase { + LoadBalancerConfig config; + QuicConnectionId connection_id; + LoadBalancerServerId server_id; +}; + +TEST_F(LoadBalancerDecoderTest, UnencryptedConnectionIdTestVectors) { + const struct LoadBalancerDecoderTestCase test_vectors[2] = { + { + *LoadBalancerConfig::CreateUnencrypted(0, 3, 4), + QuicConnectionId({0x07, 0xed, 0x79, 0x3a, 0x80, 0x49, 0x71, 0x8a}), + MakeServerId(kServerId, 3), + }, + { + *LoadBalancerConfig::CreateUnencrypted(1, 8, 5), + QuicConnectionId({0x4d, 0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, + 0x5f, 0xee, 0x15, 0xda, 0x27, 0xc4}), + MakeServerId(kServerId, 8), + }}; + for (const auto& test : test_vectors) { + LoadBalancerDecoder decoder; + EXPECT_TRUE(decoder.AddConfig(test.config)); + EXPECT_EQ(decoder.GetServerId(test.connection_id), test.server_id); + } +} + +// Compare test vectors from Appendix B of draft-ietf-quic-load-balancers-15. +TEST_F(LoadBalancerDecoderTest, DecoderTestVectors) { + // Try (1) the "standard" CID length of 8 + // (2) server_id_len > nonce_len, so there is a fourth decryption pass + // (3) the single-pass encryption case + // (4) An even total length. + const struct LoadBalancerDecoderTestCase test_vectors[4] = { + { + *LoadBalancerConfig::Create(0, 3, 4, kKey), + QuicConnectionId({0x07, 0x41, 0x26, 0xee, 0x38, 0xbf, 0x54, 0x54}), + MakeServerId(kServerId, 3), + }, + { + *LoadBalancerConfig::Create(1, 10, 5, kKey), + QuicConnectionId({0x4f, 0xcd, 0x3f, 0x57, 0x2d, 0x4e, 0xef, 0xb0, + 0x46, 0xfd, 0xb5, 0x1d, 0x16, 0x4e, 0xfc, 0xcc}), + MakeServerId(kServerId, 10), + }, + { + *LoadBalancerConfig::Create(2, 8, 8, kKey), + QuicConnectionId({0x90, 0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, + 0xb2, 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, + 0xc3}), + MakeServerId(kServerId, 8), + }, + { + *LoadBalancerConfig::Create(0, 9, 9, kKey), + QuicConnectionId({0x12, 0x12, 0x4d, 0x1e, 0xb8, 0xfb, 0xb2, 0x1e, + 0x4a, 0x49, 0x0c, 0xa5, 0x3c, 0xfe, 0x21, 0xd0, + 0x4a, 0xe6, 0x3a}), + MakeServerId(kServerId, 9), + }, + }; + for (const auto& test : test_vectors) { + LoadBalancerDecoder decoder; + EXPECT_TRUE(decoder.AddConfig(test.config)); + EXPECT_EQ(decoder.GetServerId(test.connection_id), test.server_id); + } +} + +TEST_F(LoadBalancerDecoderTest, NoServerIdEntry) { + auto server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + EXPECT_TRUE(server_id.has_value()); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(0, 3, 4))); + QuicConnectionId no_server_id_entry( + {0x00, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}); + EXPECT_TRUE(decoder.GetServerId(no_server_id_entry).has_value()); +} + +TEST_F(LoadBalancerDecoderTest, InvalidConfigId) { + auto server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + EXPECT_TRUE(server_id.has_value()); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(1, 3, 4))); + QuicConnectionId wrong_config_id( + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}); + EXPECT_FALSE(decoder + .GetServerId(QuicConnectionId( + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})) + .has_value()); +} + +TEST_F(LoadBalancerDecoderTest, UnroutableCodepoint) { + auto server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + EXPECT_TRUE(server_id.has_value()); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(1, 3, 4))); + EXPECT_FALSE(decoder + .GetServerId(QuicConnectionId( + {0xc0, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})) + .has_value()); +} + +TEST_F(LoadBalancerDecoderTest, UnroutableCodepointAnyLength) { + auto server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + EXPECT_TRUE(server_id.has_value()); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(1, 3, 4))); + EXPECT_FALSE(decoder.GetServerId(QuicConnectionId({0xff})).has_value()); +} + +TEST_F(LoadBalancerDecoderTest, ConnectionIdTooShort) { + auto server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + EXPECT_TRUE(server_id.has_value()); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(0, 3, 4))); + EXPECT_FALSE(decoder + .GetServerId(QuicConnectionId( + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06})) + .has_value()); +} + +TEST_F(LoadBalancerDecoderTest, ConnectionIdTooLongIsOK) { + auto server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(0, 3, 4))); + auto server_id_result = decoder.GetServerId( + QuicConnectionId({0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08})); + EXPECT_TRUE(server_id_result.has_value()); + EXPECT_EQ(server_id_result, server_id); +} + +TEST_F(LoadBalancerDecoderTest, DeleteConfigBadId) { + LoadBalancerDecoder decoder; + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(2, 3, 4)); + decoder.DeleteConfig(0); + EXPECT_QUIC_BUG(decoder.DeleteConfig(3), + "Decoder deleting config with invalid config_id 3"); + EXPECT_TRUE(decoder + .GetServerId(QuicConnectionId( + {0x80, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})) + .has_value()); +} + +TEST_F(LoadBalancerDecoderTest, DeleteConfigGoodId) { + LoadBalancerDecoder decoder; + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(2, 3, 4)); + decoder.DeleteConfig(2); + EXPECT_FALSE(decoder + .GetServerId(QuicConnectionId( + {0x80, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})) + .has_value()); +} + +// Create two server IDs and make sure the decoder decodes the correct one. +TEST_F(LoadBalancerDecoderTest, TwoServerIds) { + auto server_id1 = LoadBalancerServerId::Create({0x01, 0x02, 0x03}); + EXPECT_TRUE(server_id1.has_value()); + auto server_id2 = LoadBalancerServerId::Create({0x04, 0x05, 0x06}); + LoadBalancerDecoder decoder; + EXPECT_TRUE( + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(0, 3, 4))); + EXPECT_EQ(decoder.GetServerId(QuicConnectionId( + {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})), + server_id1); + EXPECT_EQ(decoder.GetServerId(QuicConnectionId( + {0x00, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a})), + server_id2); +} + +TEST_F(LoadBalancerDecoderTest, GetConfigId) { + EXPECT_FALSE( + LoadBalancerDecoder::GetConfigId(QuicConnectionId()).has_value()); + for (uint8_t i = 0; i < 3; i++) { + const QuicConnectionId connection_id({static_cast(i << 6)}); + auto config_id = LoadBalancerDecoder::GetConfigId(connection_id); + EXPECT_EQ(config_id, + LoadBalancerDecoder::GetConfigId(connection_id.data()[0])); + EXPECT_TRUE(config_id.has_value()); + EXPECT_EQ(*config_id, i); + } + EXPECT_FALSE( + LoadBalancerDecoder::GetConfigId(QuicConnectionId({0xc0})).has_value()); +} + +TEST_F(LoadBalancerDecoderTest, GetConfig) { + LoadBalancerDecoder decoder; + decoder.AddConfig(*LoadBalancerConfig::CreateUnencrypted(2, 3, 4)); + + EXPECT_EQ(decoder.GetConfig(0), nullptr); + EXPECT_EQ(decoder.GetConfig(1), nullptr); + EXPECT_EQ(decoder.GetConfig(3), nullptr); + EXPECT_EQ(decoder.GetConfig(4), nullptr); + + const LoadBalancerConfig* config = decoder.GetConfig(2); + ASSERT_NE(config, nullptr); + EXPECT_EQ(config->server_id_len(), 3); + EXPECT_EQ(config->nonce_len(), 4); + EXPECT_FALSE(config->IsEncrypted()); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_encoder.cc b/quiche/quic/load_balancer/load_balancer_encoder.cc new file mode 100644 index 000000000000..9e71d6c4abbb --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_encoder.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_encoder.h" + +#include "absl/numeric/int128.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_packet_number.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/load_balancer/load_balancer_config.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +namespace { + +// Returns the number of nonces given a certain |nonce_len|. +absl::uint128 NumberOfNonces(uint8_t nonce_len) { + return (static_cast(1) << (nonce_len * 8)); +} + +// Writes the |size| least significant bytes from |in| to |out| in host byte +// order. Returns false if |out| does not have enough space. +bool WriteUint128(const absl::uint128 in, uint8_t size, QuicDataWriter &out) { + if (out.remaining() < size) { + QUIC_BUG(quic_bug_435375038_05) + << "Call to WriteUint128() does not have enough space in |out|"; + return false; + } + uint64_t num64 = absl::Uint128Low64(in); + if (size <= sizeof(num64)) { + out.WriteBytes(&num64, size); + } else { + out.WriteBytes(&num64, sizeof(num64)); + num64 = absl::Uint128High64(in); + out.WriteBytes(&num64, size - sizeof(num64)); + } + return true; +} + +} // namespace + +absl::optional LoadBalancerEncoder::Create( + QuicRandom &random, LoadBalancerEncoderVisitorInterface *const visitor, + const bool len_self_encoded, const uint8_t unroutable_connection_id_len) { + if (unroutable_connection_id_len == 0 || + unroutable_connection_id_len > + kQuicMaxConnectionIdWithLengthPrefixLength) { + QUIC_BUG(quic_bug_435375038_01) + << "Invalid unroutable_connection_id_len = " + << static_cast(unroutable_connection_id_len); + return absl::optional(); + } + return LoadBalancerEncoder(random, visitor, len_self_encoded, + unroutable_connection_id_len); +} + +bool LoadBalancerEncoder::UpdateConfig(const LoadBalancerConfig &config, + const LoadBalancerServerId server_id) { + if (config_.has_value() && config_->config_id() == config.config_id()) { + QUIC_BUG(quic_bug_435375038_02) + << "Attempting to change config with same ID"; + return false; + } + if (server_id.length() != config.server_id_len()) { + QUIC_BUG(quic_bug_435375038_03) + << "Server ID length " << static_cast(server_id.length()) + << " does not match configured value of " + << static_cast(config.server_id_len()); + return false; + } + if (visitor_ != nullptr) { + if (config_.has_value()) { + visitor_->OnConfigChanged(config_->config_id(), config.config_id()); + } else { + visitor_->OnConfigAdded(config.config_id()); + } + } + config_ = config; + server_id_ = server_id; + + seed_ = absl::MakeUint128(random_.RandUint64(), random_.RandUint64()) % + NumberOfNonces(config.nonce_len()); + num_nonces_left_ = NumberOfNonces(config.nonce_len()); + connection_id_lengths_[config.config_id()] = config.total_len(); + return true; +} + +void LoadBalancerEncoder::DeleteConfig() { + if (visitor_ != nullptr && config_.has_value()) { + visitor_->OnConfigDeleted(config_->config_id()); + } + config_.reset(); + server_id_.reset(); + num_nonces_left_ = 0; +} + +QuicConnectionId LoadBalancerEncoder::GenerateConnectionId() { + uint8_t config_id = config_.has_value() ? config_->config_id() + : kLoadBalancerUnroutableConfigId; + uint8_t shifted_config_id = config_id << 6; + uint8_t length = connection_id_lengths_[config_id]; + if (config_.has_value() != server_id_.has_value()) { + QUIC_BUG(quic_bug_435375038_04) + << "Existence of config and server_id are out of sync"; + return QuicConnectionId(); + } + uint8_t first_byte; + // first byte + if (len_self_encoded_) { + first_byte = shifted_config_id | (length - 1); + } else { + random_.RandBytes(static_cast(&first_byte), 1); + first_byte = shifted_config_id | (first_byte & kLoadBalancerLengthMask); + } + if (!config_.has_value()) { + return MakeUnroutableConnectionId(first_byte); + } + QuicConnectionId id; + id.set_length(length); + QuicDataWriter writer(length, id.mutable_data(), quiche::HOST_BYTE_ORDER); + writer.WriteUInt8(first_byte); + absl::uint128 next_nonce = + (seed_ + num_nonces_left_--) % NumberOfNonces(config_->nonce_len()); + writer.WriteBytes(server_id_->data().data(), server_id_->length()); + if (!WriteUint128(next_nonce, config_->nonce_len(), writer)) { + return QuicConnectionId(); + } + uint8_t *block_start = reinterpret_cast(writer.data() + 1); + if (!config_->IsEncrypted()) { + // Fill the nonce field with a hash of the Connection ID to avoid the nonce + // visibly increasing by one. This would allow observers to correlate + // connection IDs as being sequential and likely from the same connection, + // not just the same server. + absl::uint128 nonce_hash = + QuicUtils::FNV1a_128_Hash(absl::string_view(writer.data(), length)); + QuicDataWriter rewriter(config_->nonce_len(), + id.mutable_data() + config_->server_id_len() + 1, + quiche::HOST_BYTE_ORDER); + if (!WriteUint128(nonce_hash, config_->nonce_len(), rewriter)) { + return QuicConnectionId(); + } + } else if (config_->plaintext_len() == kLoadBalancerBlockSize) { + // Use one encryption pass. + if (!config_->BlockEncrypt(block_start, block_start)) { + QUIC_LOG(ERROR) << "Block encryption failed"; + return QuicConnectionId(); + } + } else { + for (uint8_t i = 1; i <= kNumLoadBalancerCryptoPasses; i++) { + if (!config_->EncryptionPass(absl::Span(block_start, length - 1), + i)) { + QUIC_LOG(ERROR) << "Block encryption failed"; + return QuicConnectionId(); + } + } + } + if (num_nonces_left_ == 0) { + DeleteConfig(); + } + return id; +} + +absl::optional LoadBalancerEncoder::GenerateNextConnectionId( + [[maybe_unused]] const QuicConnectionId &original) { + // Do not allow new connection IDs if linkable. + return (IsEncoding() && !IsEncrypted()) ? absl::optional() + : GenerateConnectionId(); +} + +absl::optional LoadBalancerEncoder::MaybeReplaceConnectionId( + const QuicConnectionId &original, const ParsedQuicVersion &version) { + // Pre-IETF versions of QUIC can respond poorly to new connection IDs issued + // during the handshake. + uint8_t needed_length = config_.has_value() + ? config_->total_len() + : connection_id_lengths_[kNumLoadBalancerConfigs]; + return (!version.HasIetfQuicFrames() && original.length() == needed_length) + ? absl::optional() + : GenerateConnectionId(); +} + +uint8_t LoadBalancerEncoder::ConnectionIdLength(uint8_t first_byte) const { + if (len_self_encoded()) { + return (first_byte &= kLoadBalancerLengthMask) + 1; + } + return connection_id_lengths_[first_byte >> 6]; +} + +QuicConnectionId LoadBalancerEncoder::MakeUnroutableConnectionId( + uint8_t first_byte) { + QuicConnectionId id; + id.set_length(connection_id_lengths_[kLoadBalancerUnroutableConfigId]); + id.mutable_data()[0] = first_byte; + random_.RandBytes(&id.mutable_data()[1], connection_id_lengths_[3] - 1); + return id; +} + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_encoder.h b/quiche/quic/load_balancer/load_balancer_encoder.h new file mode 100644 index 000000000000..1099d7ac6936 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_encoder.h @@ -0,0 +1,156 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_ENCODER_H_ +#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_ENCODER_H_ + +#include "quiche/quic/core/connection_id_generator.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/load_balancer/load_balancer_config.h" +#include "quiche/quic/load_balancer/load_balancer_server_id.h" + +namespace quic { + +namespace test { +class LoadBalancerEncoderPeer; +} + +// Default length of a 4-tuple connection ID. +inline constexpr uint8_t kLoadBalancerUnroutableLen = 8; +// When the encoder is self-encoding the connection ID length, these are the +// bits of the first byte that do so. +constexpr uint8_t kLoadBalancerLengthMask = 0x3f; +// The bits of the connection ID first byte that encode the config ID. +constexpr uint8_t kLoadBalancerConfigIdMask = 0xc0; +// The config ID that means the connection ID does not contain routing +// information. +constexpr uint8_t kLoadBalancerUnroutableConfigId = kNumLoadBalancerConfigs; +// The bits of the connection ID first byte that correspond to a connection ID +// that does not contain routing information. +constexpr uint8_t kLoadBalancerUnroutablePrefix = + kLoadBalancerUnroutableConfigId << 6; + +// Interface which receives notifications when the current config is updated. +class QUIC_EXPORT_PRIVATE LoadBalancerEncoderVisitorInterface { + public: + virtual ~LoadBalancerEncoderVisitorInterface() {} + + // Called when a config is added where none existed. + // + // Connections that support address migration should retire unroutable + // connection IDs and replace them with routable ones using the new config, + // while avoiding sending a sudden storm of packets containing + // RETIRE_CONNECTION_ID and NEW_CONNECTION_ID frames. + virtual void OnConfigAdded(const uint8_t config_id) = 0; + // Called when the config is changed. + // + // Existing routable connection IDs should be retired before the decoder stops + // supporting that config. The timing of this event is deployment-dependent + // and might be tied to the arrival of a new config at the encoder. + virtual void OnConfigChanged(const uint8_t old_config_id, + const uint8_t new_config_id) = 0; + // Called when a config is deleted. The encoder will generate unroutable + // connection IDs from now on. + // + // New connections will not be able to support address migration until a new + // config arrives. Existing connections can retain connection IDs that use the + // deleted config, which will only become unroutable once the decoder also + // deletes it. The time of that deletion is deployment-dependent and might be + // tied to the arrival of a new config at the encoder. + virtual void OnConfigDeleted(const uint8_t config_id) = 0; +}; + +// Manages QUIC-LB configurations to properly encode a given server ID in a +// QUIC Connection ID. +class QUIC_EXPORT_PRIVATE LoadBalancerEncoder + : public ConnectionIdGeneratorInterface { + public: + LoadBalancerEncoder(QuicRandom& random, + LoadBalancerEncoderVisitorInterface* const visitor, + const bool len_self_encoded) + : LoadBalancerEncoder(random, visitor, len_self_encoded, + kLoadBalancerUnroutableLen) {} + ~LoadBalancerEncoder() override {} + + // Returns a newly created encoder with no active config, if + // |unroutable_connection_id_length| is valid. |visitor| specifies an optional + // interface to receive callbacks when config status changes. + // If |len_self_encoded| is true, then the first byte of any generated + // connection ids will encode the length. Otherwise, those bits will be + // random. |unroutable_connection_id_length| specifies the length of + // connection IDs to be generated when there is no active config. It must not + // be 0 and must not be larger than the RFC9000 maximum of 20. + static absl::optional Create( + QuicRandom& random, LoadBalancerEncoderVisitorInterface* const visitor, + const bool len_self_encoded, + const uint8_t unroutable_connection_id_len = kLoadBalancerUnroutableLen); + + // Attempts to replace the current config and server_id with |config| and + // |server_id|. If the length |server_id| does not match the server_id_length + // of |config| or the ID of |config| matches the ID of the current config, + // returns false and leaves the current config unchanged. Otherwise, returns + // true. When the encoder runs out of nonces, it will delete the config and + // begin generating unroutable connection IDs. + bool UpdateConfig(const LoadBalancerConfig& config, + const LoadBalancerServerId server_id); + + // Delete the current config and generate unroutable connection IDs from now + // on. + void DeleteConfig(); + + // Returns the number of additional connection IDs that can be generated with + // the current config, or 0 if there is no current config. + absl::uint128 num_nonces_left() const { return num_nonces_left_; } + + // Functions below are declared virtual to enable mocking. + // Returns true if there is an active configuration. + virtual bool IsEncoding() const { return config_.has_value(); } + // Returns true if there is an active configuration that uses encryption. + virtual bool IsEncrypted() const { + return config_.has_value() && config_->IsEncrypted(); + } + virtual bool len_self_encoded() const { return len_self_encoded_; } + + // If there's an active config, generates a connection ID using it. If not, + // generates an unroutable connection_id. If there's an error, returns a zero- + // length Connection ID. + QuicConnectionId GenerateConnectionId(); + + // Functions from ConnectionIdGeneratorInterface + absl::optional GenerateNextConnectionId( + const QuicConnectionId& original) override; + absl::optional MaybeReplaceConnectionId( + const QuicConnectionId& original, + const ParsedQuicVersion& version) override; + uint8_t ConnectionIdLength(uint8_t first_byte) const override; + + protected: + LoadBalancerEncoder(QuicRandom& random, + LoadBalancerEncoderVisitorInterface* const visitor, + const bool len_self_encoded, + const uint8_t unroutable_connection_id_len) + : random_(random), + len_self_encoded_(len_self_encoded), + visitor_(visitor) { + std::fill_n(connection_id_lengths_, 4, unroutable_connection_id_len); + } + + private: + friend class test::LoadBalancerEncoderPeer; + + QuicConnectionId MakeUnroutableConnectionId(uint8_t first_byte); + + QuicRandom& random_; + const bool len_self_encoded_; + LoadBalancerEncoderVisitorInterface* const visitor_; + + absl::optional config_; + absl::uint128 seed_, num_nonces_left_ = 0; + absl::optional server_id_; + uint8_t connection_id_lengths_[kNumLoadBalancerConfigs + 1]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_ENCODER_H_ diff --git a/quiche/quic/load_balancer/load_balancer_encoder_test.cc b/quiche/quic/load_balancer/load_balancer_encoder_test.cc new file mode 100644 index 000000000000..88d09f448625 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_encoder_test.cc @@ -0,0 +1,451 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_encoder.h" + +#include + +#include "absl/numeric/int128.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +class LoadBalancerEncoderPeer { + public: + static void SetNumNoncesLeft(LoadBalancerEncoder &encoder, + uint64_t nonces_remaining) { + encoder.num_nonces_left_ = absl::uint128(nonces_remaining); + } +}; + +namespace { + +class TestLoadBalancerEncoderVisitor + : public LoadBalancerEncoderVisitorInterface { + public: + ~TestLoadBalancerEncoderVisitor() override {} + + void OnConfigAdded(const uint8_t config_id) override { + num_adds_++; + current_config_id_ = config_id; + } + + void OnConfigChanged(const uint8_t old_config_id, + const uint8_t new_config_id) override { + num_adds_++; + num_deletes_++; + EXPECT_EQ(old_config_id, current_config_id_); + current_config_id_ = new_config_id; + } + + void OnConfigDeleted(const uint8_t config_id) override { + EXPECT_EQ(config_id, current_config_id_); + current_config_id_.reset(); + num_deletes_++; + } + + uint32_t num_adds() const { return num_adds_; } + uint32_t num_deletes() const { return num_deletes_; } + + private: + uint32_t num_adds_ = 0, num_deletes_ = 0; + absl::optional current_config_id_ = absl::optional(); +}; + +// Allows the caller to specify the exact results in 64-bit chunks. +class TestRandom : public QuicRandom { + public: + uint64_t RandUint64() override { + if (next_values_.empty()) { + return base_; + } + uint64_t value = next_values_.front(); + next_values_.pop(); + return value; + } + + void RandBytes(void *data, size_t len) override { + size_t written = 0; + uint8_t *ptr = static_cast(data); + while (written < len) { + uint64_t result = RandUint64(); + size_t to_write = (len - written > sizeof(uint64_t)) ? sizeof(uint64_t) + : (len - written); + memcpy(ptr + written, &result, to_write); + written += to_write; + } + } + + void InsecureRandBytes(void *data, size_t len) override { + RandBytes(data, len); + } + + uint64_t InsecureRandUint64() override { return RandUint64(); } + + void AddNextValues(uint64_t hi, uint64_t lo) { + next_values_.push(hi); + next_values_.push(lo); + } + + private: + std::queue next_values_; + uint64_t base_ = 0xDEADBEEFDEADBEEF; +}; + +class LoadBalancerEncoderTest : public QuicTest { + public: + TestRandom random_; +}; + +// Convenience function to shorten the code. Does not check if |array| is long +// enough or |length| is valid for a server ID. +LoadBalancerServerId MakeServerId(const uint8_t array[], const uint8_t length) { + return *LoadBalancerServerId::Create( + absl::Span(array, length)); +} + +constexpr char kRawKey[] = {0x8f, 0x95, 0xf0, 0x92, 0x45, 0x76, 0x5f, 0x80, + 0x25, 0x69, 0x34, 0xe5, 0x0c, 0x66, 0x20, 0x7f}; +constexpr absl::string_view kKey(kRawKey, kLoadBalancerKeyLen); +constexpr uint64_t kNonceLow = 0xe5d1c048bf0d08ee; +constexpr uint64_t kNonceHigh = 0x9321e7e34dde525d; +constexpr uint8_t kServerId[] = {0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, 0x5f, + 0xab, 0x65, 0xba, 0x04, 0xc3, 0x33, 0x0a}; + +TEST_F(LoadBalancerEncoderTest, BadUnroutableLength) { + EXPECT_QUIC_BUG( + EXPECT_FALSE( + LoadBalancerEncoder::Create(random_, nullptr, false, 0).has_value()), + "Invalid unroutable_connection_id_len = 0"); + EXPECT_QUIC_BUG( + EXPECT_FALSE( + LoadBalancerEncoder::Create(random_, nullptr, false, 21).has_value()), + "Invalid unroutable_connection_id_len = 21"); +} + +TEST_F(LoadBalancerEncoderTest, BadServerIdLength) { + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true); + ASSERT_TRUE(encoder.has_value()); + // Expects a 3 byte server ID and got 4. + auto config = LoadBalancerConfig::CreateUnencrypted(1, 3, 4); + ASSERT_TRUE(config.has_value()); + EXPECT_QUIC_BUG( + EXPECT_FALSE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 4))), + "Server ID length 4 does not match configured value of 3"); + EXPECT_FALSE(encoder->IsEncoding()); +} + +TEST_F(LoadBalancerEncoderTest, FailToUpdateConfigWithSameId) { + TestLoadBalancerEncoderVisitor visitor; + auto encoder = LoadBalancerEncoder::Create(random_, &visitor, true); + ASSERT_TRUE(encoder.has_value()); + auto config = LoadBalancerConfig::CreateUnencrypted(1, 3, 4); + ASSERT_TRUE(config.has_value()); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))); + EXPECT_EQ(visitor.num_adds(), 1u); + EXPECT_QUIC_BUG( + EXPECT_FALSE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))), + "Attempting to change config with same ID"); + EXPECT_EQ(visitor.num_adds(), 1u); +} + +struct LoadBalancerEncoderTestCase { + LoadBalancerConfig config; + QuicConnectionId connection_id; + LoadBalancerServerId server_id; +}; + +TEST_F(LoadBalancerEncoderTest, UnencryptedConnectionIdTestVectors) { + const struct LoadBalancerEncoderTestCase test_vectors[2] = { + { + *LoadBalancerConfig::CreateUnencrypted(0, 3, 4), + QuicConnectionId({0x07, 0xed, 0x79, 0x3a, 0x80, 0x49, 0x71, 0x8a}), + MakeServerId(kServerId, 3), + }, + { + *LoadBalancerConfig::CreateUnencrypted(1, 8, 5), + QuicConnectionId({0x4d, 0xed, 0x79, 0x3a, 0x51, 0xd4, 0x9b, 0x8f, + 0x5f, 0xee, 0x15, 0xda, 0x27, 0xc4}), + MakeServerId(kServerId, 8), + }, + }; + for (const auto &test : test_vectors) { + random_.AddNextValues(kNonceHigh, kNonceLow); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true, 8); + EXPECT_TRUE(encoder->UpdateConfig(test.config, test.server_id)); + absl::uint128 nonces_left = encoder->num_nonces_left(); + EXPECT_EQ(encoder->GenerateConnectionId(), test.connection_id); + EXPECT_EQ(encoder->num_nonces_left(), nonces_left - 1); + } +} + +// Follow example in draft-ietf-quic-load-balancers-15. +TEST_F(LoadBalancerEncoderTest, FollowSpecExample) { + const uint8_t config_id = 0, server_id_len = 3, nonce_len = 4; + const uint8_t raw_server_id[] = { + 0x31, + 0x44, + 0x1a, + }; + const char raw_key[] = { + 0xfd, 0xf7, 0x26, 0xa9, 0x89, 0x3e, 0xc0, 0x5c, + 0x06, 0x32, 0xd3, 0x95, 0x66, 0x80, 0xba, 0xf0, + }; + random_.AddNextValues(0, 0x75c2699c); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true, 8); + ASSERT_TRUE(encoder.has_value()); + auto config = LoadBalancerConfig::Create(config_id, server_id_len, nonce_len, + absl::string_view(raw_key)); + ASSERT_TRUE(config.has_value()); + EXPECT_TRUE(encoder->UpdateConfig( + *config, *LoadBalancerServerId::Create(raw_server_id))); + EXPECT_TRUE(encoder->IsEncoding()); + const char raw_connection_id[] = {0x07, 0x8e, 0x9a, 0x91, + 0xf4, 0x94, 0x97, 0x62}; + auto expected = + QuicConnectionId(raw_connection_id, 1 + server_id_len + nonce_len); + EXPECT_EQ(encoder->GenerateConnectionId(), expected); +} + +// Compare test vectors from Appendix B of draft-ietf-quic-load-balancers-15. +TEST_F(LoadBalancerEncoderTest, EncoderTestVectors) { + // Try (1) the "standard" ConnectionId length of 8 + // (2) server_id_len > nonce_len, so there is a fourth decryption pass + // (3) the single-pass encryption case + // (4) An even total length. + const LoadBalancerEncoderTestCase test_vectors[4] = { + { + *LoadBalancerConfig::Create(0, 3, 4, kKey), + QuicConnectionId({0x07, 0x41, 0x26, 0xee, 0x38, 0xbf, 0x54, 0x54}), + MakeServerId(kServerId, 3), + }, + { + *LoadBalancerConfig::Create(1, 10, 5, kKey), + QuicConnectionId({0x4f, 0xcd, 0x3f, 0x57, 0x2d, 0x4e, 0xef, 0xb0, + 0x46, 0xfd, 0xb5, 0x1d, 0x16, 0x4e, 0xfc, 0xcc}), + MakeServerId(kServerId, 10), + }, + { + *LoadBalancerConfig::Create(2, 8, 8, kKey), + QuicConnectionId({0x90, 0x4d, 0xd2, 0xd0, 0x5a, 0x7b, 0x0d, 0xe9, + 0xb2, 0xb9, 0x90, 0x7a, 0xfb, 0x5e, 0xcf, 0x8c, + 0xc3}), + MakeServerId(kServerId, 8), + }, + { + *LoadBalancerConfig::Create(0, 9, 9, kKey), + QuicConnectionId({0x12, 0x12, 0x4d, 0x1e, 0xb8, 0xfb, 0xb2, 0x1e, + 0x4a, 0x49, 0x0c, 0xa5, 0x3c, 0xfe, 0x21, 0xd0, + 0x4a, 0xe6, 0x3a}), + MakeServerId(kServerId, 9), + }, + }; + for (const auto &test : test_vectors) { + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true, 8); + ASSERT_TRUE(encoder.has_value()); + random_.AddNextValues(kNonceHigh, kNonceLow); + EXPECT_TRUE(encoder->UpdateConfig(test.config, test.server_id)); + EXPECT_EQ(encoder->GenerateConnectionId(), test.connection_id); + } +} + +TEST_F(LoadBalancerEncoderTest, RunOutOfNonces) { + const uint8_t server_id_len = 3; + TestLoadBalancerEncoderVisitor visitor; + auto encoder = LoadBalancerEncoder::Create(random_, &visitor, true, 8); + ASSERT_TRUE(encoder.has_value()); + auto config = LoadBalancerConfig::Create(0, server_id_len, 4, kKey); + ASSERT_TRUE(config.has_value()); + EXPECT_TRUE( + encoder->UpdateConfig(*config, MakeServerId(kServerId, server_id_len))); + EXPECT_EQ(visitor.num_adds(), 1u); + LoadBalancerEncoderPeer::SetNumNoncesLeft(*encoder, 2); + EXPECT_EQ(encoder->num_nonces_left(), 2); + EXPECT_EQ(encoder->GenerateConnectionId(), + QuicConnectionId({0x07, 0x1d, 0x4a, 0xb8, 0xc6, 0x1d, 0xd6, 0x5d})); + EXPECT_EQ(encoder->num_nonces_left(), 1); + encoder->GenerateConnectionId(); + EXPECT_EQ(encoder->IsEncoding(), false); + // No retire_calls except for the initial UpdateConfig. + EXPECT_EQ(visitor.num_deletes(), 1u); +} + +TEST_F(LoadBalancerEncoderTest, UnroutableConnectionId) { + random_.AddNextValues(0x83, kNonceHigh); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, false); + ASSERT_TRUE(encoder.has_value()); + EXPECT_EQ(encoder->num_nonces_left(), 0); + auto connection_id = encoder->GenerateConnectionId(); + // The first byte is the config_id (0xc0) xored with (0x83 & 0x3f). + // The remaining bytes are random, and therefore match kNonceHigh. + QuicConnectionId expected({0xc3, 0x5d, 0x52, 0xde, 0x4d, 0xe3, 0xe7, 0x21}); + EXPECT_EQ(expected, connection_id); +} + +TEST_F(LoadBalancerEncoderTest, NonDefaultUnroutableConnectionIdLength) { + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true, 9); + ASSERT_TRUE(encoder.has_value()); + QuicConnectionId connection_id = encoder->GenerateConnectionId(); + EXPECT_EQ(connection_id.length(), 9); +} + +TEST_F(LoadBalancerEncoderTest, DeleteConfigWhenNoConfigExists) { + TestLoadBalancerEncoderVisitor visitor; + auto encoder = LoadBalancerEncoder::Create(random_, &visitor, true); + ASSERT_TRUE(encoder.has_value()); + encoder->DeleteConfig(); + EXPECT_EQ(visitor.num_deletes(), 0u); +} + +TEST_F(LoadBalancerEncoderTest, AddConfig) { + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + ASSERT_TRUE(config.has_value()); + TestLoadBalancerEncoderVisitor visitor; + auto encoder = LoadBalancerEncoder::Create(random_, &visitor, true); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))); + EXPECT_EQ(visitor.num_adds(), 1u); + absl::uint128 left = encoder->num_nonces_left(); + EXPECT_EQ(left, (0x1ull << 32)); + EXPECT_TRUE(encoder->IsEncoding()); + EXPECT_FALSE(encoder->IsEncrypted()); + encoder->GenerateConnectionId(); + EXPECT_EQ(encoder->num_nonces_left(), left - 1); + EXPECT_EQ(visitor.num_deletes(), 0u); +} + +TEST_F(LoadBalancerEncoderTest, UpdateConfig) { + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + ASSERT_TRUE(config.has_value()); + TestLoadBalancerEncoderVisitor visitor; + auto encoder = LoadBalancerEncoder::Create(random_, &visitor, true); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))); + config = LoadBalancerConfig::Create(1, 4, 4, kKey); + ASSERT_TRUE(config.has_value()); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 4))); + EXPECT_EQ(visitor.num_adds(), 2u); + EXPECT_EQ(visitor.num_deletes(), 1u); + EXPECT_TRUE(encoder->IsEncoding()); + EXPECT_TRUE(encoder->IsEncrypted()); +} + +TEST_F(LoadBalancerEncoderTest, DeleteConfig) { + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + ASSERT_TRUE(config.has_value()); + TestLoadBalancerEncoderVisitor visitor; + auto encoder = LoadBalancerEncoder::Create(random_, &visitor, true); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))); + encoder->DeleteConfig(); + EXPECT_EQ(visitor.num_adds(), 1u); + EXPECT_EQ(visitor.num_deletes(), 1u); + EXPECT_FALSE(encoder->IsEncoding()); + EXPECT_FALSE(encoder->IsEncrypted()); + EXPECT_EQ(encoder->num_nonces_left(), 0); +} + +TEST_F(LoadBalancerEncoderTest, DeleteConfigNoVisitor) { + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + ASSERT_TRUE(config.has_value()); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))); + encoder->DeleteConfig(); + EXPECT_FALSE(encoder->IsEncoding()); + EXPECT_FALSE(encoder->IsEncrypted()); + EXPECT_EQ(encoder->num_nonces_left(), 0); +} + +TEST_F(LoadBalancerEncoderTest, MaybeReplaceConnectionIdReturnsNoChange) { + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, false); + ASSERT_TRUE(encoder.has_value()); + EXPECT_EQ(encoder->MaybeReplaceConnectionId(TestConnectionId(1), + ParsedQuicVersion::Q050()), + absl::nullopt); +} + +TEST_F(LoadBalancerEncoderTest, MaybeReplaceConnectionIdReturnsChange) { + random_.AddNextValues(0x83, kNonceHigh); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, false); + ASSERT_TRUE(encoder.has_value()); + // The first byte is the config_id (0xc0) xored with (0x83 & 0x3f). + // The remaining bytes are random, and therefore match kNonceHigh. + QuicConnectionId expected({0xc3, 0x5d, 0x52, 0xde, 0x4d, 0xe3, 0xe7, 0x21}); + EXPECT_EQ(*encoder->MaybeReplaceConnectionId(TestConnectionId(1), + ParsedQuicVersion::RFCv1()), + expected); +} + +TEST_F(LoadBalancerEncoderTest, GenerateNextConnectionIdReturnsNoChange) { + auto config = LoadBalancerConfig::CreateUnencrypted(0, 3, 4); + ASSERT_TRUE(config.has_value()); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, true); + EXPECT_TRUE(encoder->UpdateConfig(*config, MakeServerId(kServerId, 3))); + EXPECT_EQ(encoder->GenerateNextConnectionId(TestConnectionId(1)), + absl::nullopt); +} + +TEST_F(LoadBalancerEncoderTest, GenerateNextConnectionIdReturnsChange) { + random_.AddNextValues(0x83, kNonceHigh); + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, false); + ASSERT_TRUE(encoder.has_value()); + // The first byte is the config_id (0xc0) xored with (0x83 & 0x3f). + // The remaining bytes are random, and therefore match kNonceHigh. + QuicConnectionId expected({0xc3, 0x5d, 0x52, 0xde, 0x4d, 0xe3, 0xe7, 0x21}); + EXPECT_EQ(*encoder->GenerateNextConnectionId(TestConnectionId(1)), expected); +} + +TEST_F(LoadBalancerEncoderTest, ConnectionIdLengthsEncoded) { + // The first byte literally encodes the length. + auto len_encoder = LoadBalancerEncoder::Create(random_, nullptr, true); + ASSERT_TRUE(len_encoder.has_value()); + EXPECT_EQ(len_encoder->ConnectionIdLength(0xc8), 9); + EXPECT_EQ(len_encoder->ConnectionIdLength(0x4a), 11); + EXPECT_EQ(len_encoder->ConnectionIdLength(0x09), 10); + // The length is not self-encoded anymore. + auto encoder = LoadBalancerEncoder::Create(random_, nullptr, false); + ASSERT_TRUE(encoder.has_value()); + EXPECT_EQ(encoder->ConnectionIdLength(0xc8), kQuicDefaultConnectionIdLength); + EXPECT_EQ(encoder->ConnectionIdLength(0x4a), kQuicDefaultConnectionIdLength); + EXPECT_EQ(encoder->ConnectionIdLength(0x09), kQuicDefaultConnectionIdLength); + // Add config ID 0, so that ID now returns a different length. + uint8_t config_id = 0; + uint8_t server_id_len = 3; + uint8_t nonce_len = 6; + uint8_t config_0_len = server_id_len + nonce_len + 1; + auto config0 = LoadBalancerConfig::CreateUnencrypted(config_id, server_id_len, + nonce_len); + ASSERT_TRUE(config0.has_value()); + EXPECT_TRUE( + encoder->UpdateConfig(*config0, MakeServerId(kServerId, server_id_len))); + EXPECT_EQ(encoder->ConnectionIdLength(0xc8), kQuicDefaultConnectionIdLength); + EXPECT_EQ(encoder->ConnectionIdLength(0x4a), kQuicDefaultConnectionIdLength); + EXPECT_EQ(encoder->ConnectionIdLength(0x09), config_0_len); + // Replace config ID 0 with 1. There are probably still packets with config + // ID 0 arriving, so keep that length in memory. + config_id = 1; + nonce_len++; + uint8_t config_1_len = server_id_len + nonce_len + 1; + auto config1 = LoadBalancerConfig::CreateUnencrypted(config_id, server_id_len, + nonce_len); + ASSERT_TRUE(config1.has_value()); + // Old config length still there after replacement + EXPECT_TRUE( + encoder->UpdateConfig(*config1, MakeServerId(kServerId, server_id_len))); + EXPECT_EQ(encoder->ConnectionIdLength(0xc8), kQuicDefaultConnectionIdLength); + EXPECT_EQ(encoder->ConnectionIdLength(0x4a), config_1_len); + EXPECT_EQ(encoder->ConnectionIdLength(0x09), config_0_len); + // Old config length still there after delete + encoder->DeleteConfig(); + EXPECT_EQ(encoder->ConnectionIdLength(0xc8), kQuicDefaultConnectionIdLength); + EXPECT_EQ(encoder->ConnectionIdLength(0x4a), config_1_len); + EXPECT_EQ(encoder->ConnectionIdLength(0x09), config_0_len); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_server_id.cc b/quiche/quic/load_balancer/load_balancer_server_id.cc new file mode 100644 index 000000000000..5805b7fe5c0d --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_server_id.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_server_id.h" + +#include "absl/strings/escaping.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +namespace { + +// Helper to allow setting the const array during initialization. +std::array MakeArray( + const absl::Span data, const uint8_t length) { + std::array array; + memcpy(array.data(), data.data(), length); + return array; +} + +} // namespace + +absl::optional LoadBalancerServerId::Create( + const absl::Span data) { + if (data.length() == 0 || data.length() > kLoadBalancerMaxServerIdLen) { + QUIC_BUG(quic_bug_433312504_01) + << "Attempted to create LoadBalancerServerId with length " + << data.length(); + return absl::optional(); + } + return LoadBalancerServerId(data); +} + +std::string LoadBalancerServerId::ToString() const { + return absl::BytesToHexString( + absl::string_view(reinterpret_cast(data_.data()), length_)); +} + +LoadBalancerServerId::LoadBalancerServerId(const absl::Span data) + : data_(MakeArray(data, data.length())), length_(data.length()) {} + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_server_id.h b/quiche/quic/load_balancer/load_balancer_server_id.h new file mode 100644 index 000000000000..cdd3c9568905 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_server_id.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_H_ +#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_H_ + +#include + +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// The maximum number of bytes in a LoadBalancerServerId. +inline constexpr uint8_t kLoadBalancerMaxServerIdLen = 15; + +// LoadBalancerServerId is the globally understood identifier for a given pool +// member. It is unique to any given QUIC-LB configuration. See +// draft-ietf-quic-load-balancers. +// Note: this has nothing to do with QuicServerID. It's an unfortunate collision +// between an internal term for the destination identifiers for a particular +// deployment (QuicServerID) and the object of a load balancing decision +// (LoadBalancerServerId). +class QUIC_EXPORT_PRIVATE LoadBalancerServerId { + public: + // Copies all the bytes from |data| into a new LoadBalancerServerId. + static absl::optional Create( + absl::Span data); + + // For callers with a string_view at hand. + static absl::optional Create(absl::string_view data) { + return Create(absl::MakeSpan(reinterpret_cast(data.data()), + data.length())); + } + + // Server IDs are opaque bytes, but defining these operators allows us to sort + // them into a tree and define ranges. + bool operator<(const LoadBalancerServerId& other) const { + return data() < other.data(); + } + bool operator==(const LoadBalancerServerId& other) const { + return data() == other.data(); + } + + // Hash function to allow use as a key in unordered maps. + template + friend H AbslHashValue(H h, const LoadBalancerServerId& server_id) { + return H::combine_contiguous(std::move(h), server_id.data().data(), + server_id.length()); + } + + absl::Span data() const { + return absl::MakeConstSpan(data_.data(), length_); + } + uint8_t length() const { return length_; } + + // Returns the server ID in hex format. + std::string ToString() const; + + private: + // The constructor is private because it can't validate the input. + LoadBalancerServerId(const absl::Span data); + + std::array data_; + uint8_t length_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_H_ diff --git a/quiche/quic/load_balancer/load_balancer_server_id_map.h b/quiche/quic/load_balancer/load_balancer_server_id_map.h new file mode 100644 index 000000000000..2e79e6642917 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_server_id_map.h @@ -0,0 +1,104 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_ +#define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_ + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/load_balancer/load_balancer_server_id.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +// This class wraps an absl::flat_hash_map which associates server IDs to an +// arbitrary type T. It validates that all server ids are of the same fixed +// length. This might be used by a load balancer to connect a server ID with a +// pool member data structure. +template +class QUIC_EXPORT_PRIVATE LoadBalancerServerIdMap { + public: + // Returns a newly created pool for server IDs of length |server_id_len|, or + // nullptr if |server_id_len| is invalid. + static std::shared_ptr Create( + const uint8_t server_id_len); + + // Returns the entry associated with |server_id|, if present. For small |T|, + // use Lookup. For large |T|, use LookupNoCopy. + absl::optional Lookup(const LoadBalancerServerId server_id) const; + const T* LookupNoCopy(const LoadBalancerServerId server_id) const; + + // Updates the table so that |value| is associated with |server_id|. Sets + // QUIC_BUG if the length is incorrect for this map. + void AddOrReplace(const LoadBalancerServerId server_id, T value); + + // Removes the entry associated with |server_id|. + void Erase(const LoadBalancerServerId server_id) { + server_id_table_.erase(server_id); + } + + uint8_t server_id_len() const { return server_id_len_; } + + private: + LoadBalancerServerIdMap(uint8_t server_id_len) + : server_id_len_(server_id_len) {} + + const uint8_t server_id_len_; // All server IDs must be of this length. + absl::flat_hash_map server_id_table_; +}; + +template +std::shared_ptr> LoadBalancerServerIdMap::Create( + const uint8_t server_id_len) { + if (server_id_len == 0 || server_id_len > kLoadBalancerMaxServerIdLen) { + QUIC_BUG(quic_bug_434893339_01) + << "Tried to configure map with server ID length " + << static_cast(server_id_len); + return nullptr; + } + return std::make_shared>( + LoadBalancerServerIdMap(server_id_len)); +} + +template +absl::optional LoadBalancerServerIdMap::Lookup( + const LoadBalancerServerId server_id) const { + if (server_id.length() != server_id_len_) { + QUIC_BUG(quic_bug_434893339_02) + << "Lookup with a " << static_cast(server_id.length()) + << " byte server ID, map requires " << static_cast(server_id_len_); + return absl::optional(); + } + auto it = server_id_table_.find(server_id); + return (it != server_id_table_.end()) ? it->second + : absl::optional(); +} + +template +const T* LoadBalancerServerIdMap::LookupNoCopy( + const LoadBalancerServerId server_id) const { + if (server_id.length() != server_id_len_) { + QUIC_BUG(quic_bug_434893339_02) + << "Lookup with a " << static_cast(server_id.length()) + << " byte server ID, map requires " << static_cast(server_id_len_); + return nullptr; + } + auto it = server_id_table_.find(server_id); + return (it != server_id_table_.end()) ? &it->second : nullptr; +} + +template +void LoadBalancerServerIdMap::AddOrReplace( + const LoadBalancerServerId server_id, T value) { + if (server_id.length() == server_id_len_) { + server_id_table_[server_id] = value; + } else { + QUIC_BUG(quic_bug_434893339_03) + << "Server ID of " << static_cast(server_id.length()) + << " bytes; this map requires " << static_cast(server_id_len_); + } +} + +} // namespace quic + +#endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_ diff --git a/quiche/quic/load_balancer/load_balancer_server_id_map_test.cc b/quiche/quic/load_balancer/load_balancer_server_id_map_test.cc new file mode 100644 index 000000000000..d8fcdac79648 --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_server_id_map_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_server_id_map.h" + +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +namespace { + +constexpr uint8_t kServerId[] = {0xed, 0x79, 0x3a, 0x51}; + +class LoadBalancerServerIdMapTest : public QuicTest { + public: + const LoadBalancerServerId valid_server_id_ = + *LoadBalancerServerId::Create(kServerId); + const LoadBalancerServerId invalid_server_id_ = + *LoadBalancerServerId::Create(absl::Span(kServerId, 3)); +}; + +TEST_F(LoadBalancerServerIdMapTest, CreateWithBadServerIdLength) { + EXPECT_QUIC_BUG(EXPECT_EQ(LoadBalancerServerIdMap::Create(0), nullptr), + "Tried to configure map with server ID length 0"); + EXPECT_QUIC_BUG(EXPECT_EQ(LoadBalancerServerIdMap::Create(16), nullptr), + "Tried to configure map with server ID length 16"); +} + +TEST_F(LoadBalancerServerIdMapTest, AddOrReplaceWithBadServerIdLength) { + int record = 1; + auto pool = LoadBalancerServerIdMap::Create(4); + EXPECT_NE(pool, nullptr); + EXPECT_QUIC_BUG(pool->AddOrReplace(invalid_server_id_, record), + "Server ID of 3 bytes; this map requires 4"); +} + +TEST_F(LoadBalancerServerIdMapTest, LookupWithBadServerIdLength) { + int record = 1; + auto pool = LoadBalancerServerIdMap::Create(4); + EXPECT_NE(pool, nullptr); + pool->AddOrReplace(valid_server_id_, record); + EXPECT_QUIC_BUG(EXPECT_FALSE(pool->Lookup(invalid_server_id_).has_value()), + "Lookup with a 3 byte server ID, map requires 4"); + EXPECT_QUIC_BUG(EXPECT_EQ(pool->LookupNoCopy(invalid_server_id_), nullptr), + "Lookup with a 3 byte server ID, map requires 4"); +} + +TEST_F(LoadBalancerServerIdMapTest, LookupWhenEmpty) { + auto pool = LoadBalancerServerIdMap::Create(4); + EXPECT_NE(pool, nullptr); + EXPECT_EQ(pool->LookupNoCopy(valid_server_id_), nullptr); + absl::optional result = pool->Lookup(valid_server_id_); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(LoadBalancerServerIdMapTest, AddLookup) { + int record1 = 1, record2 = 2; + auto pool = LoadBalancerServerIdMap::Create(4); + EXPECT_NE(pool, nullptr); + auto other_server_id = LoadBalancerServerId::Create({0x01, 0x02, 0x03, 0x04}); + EXPECT_TRUE(other_server_id.has_value()); + pool->AddOrReplace(valid_server_id_, record1); + pool->AddOrReplace(*other_server_id, record2); + absl::optional result = pool->Lookup(valid_server_id_); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, record1); + auto result_ptr = pool->LookupNoCopy(valid_server_id_); + EXPECT_NE(result_ptr, nullptr); + EXPECT_EQ(*result_ptr, record1); + result = pool->Lookup(*other_server_id); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, record2); +} + +TEST_F(LoadBalancerServerIdMapTest, AddErase) { + int record = 1; + auto pool = LoadBalancerServerIdMap::Create(4); + EXPECT_NE(pool, nullptr); + pool->AddOrReplace(valid_server_id_, record); + EXPECT_EQ(*pool->LookupNoCopy(valid_server_id_), record); + pool->Erase(valid_server_id_); + EXPECT_EQ(pool->LookupNoCopy(valid_server_id_), nullptr); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/load_balancer/load_balancer_server_id_test.cc b/quiche/quic/load_balancer/load_balancer_server_id_test.cc new file mode 100644 index 000000000000..91fcc57a07dc --- /dev/null +++ b/quiche/quic/load_balancer/load_balancer_server_id_test.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/load_balancer/load_balancer_server_id.h" + +#include "absl/hash/hash_testing.h" + +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +namespace { + +class LoadBalancerServerIdTest : public QuicTest {}; + +constexpr uint8_t kRawServerId[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f}; + +TEST_F(LoadBalancerServerIdTest, CreateReturnsNullIfTooLong) { + EXPECT_QUIC_BUG(EXPECT_FALSE(LoadBalancerServerId::Create( + absl::Span(kRawServerId, 16)) + .has_value()), + "Attempted to create LoadBalancerServerId with length 16"); + EXPECT_QUIC_BUG( + EXPECT_FALSE(LoadBalancerServerId::Create(absl::Span()) + .has_value()), + "Attempted to create LoadBalancerServerId with length 0"); +} + +TEST_F(LoadBalancerServerIdTest, CompareIdenticalExceptLength) { + auto server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 15)); + ASSERT_TRUE(server_id.has_value()); + EXPECT_EQ(server_id->length(), 15); + auto shorter_server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 5)); + ASSERT_TRUE(shorter_server_id.has_value()); + EXPECT_EQ(shorter_server_id->length(), 5); + // Shorter comes before longer if all bits match + EXPECT_TRUE(shorter_server_id < server_id); + EXPECT_FALSE(server_id < shorter_server_id); + // Different lengths are never equal. + EXPECT_FALSE(shorter_server_id == server_id); +} + +TEST_F(LoadBalancerServerIdTest, AccessorFunctions) { + auto server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 5)); + EXPECT_TRUE(server_id.has_value()); + EXPECT_EQ(server_id->length(), 5); + EXPECT_EQ(memcmp(server_id->data().data(), kRawServerId, 5), 0); + EXPECT_EQ(server_id->ToString(), "0001020304"); +} + +TEST_F(LoadBalancerServerIdTest, CompareDifferentServerIds) { + auto server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 5)); + ASSERT_TRUE(server_id.has_value()); + auto reverse = LoadBalancerServerId::Create({0x0f, 0x0e, 0x0d, 0x0c, 0x0b}); + ASSERT_TRUE(reverse.has_value()); + EXPECT_TRUE(server_id < reverse); + auto long_server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 15)); + EXPECT_TRUE(long_server_id < reverse); +} + +TEST_F(LoadBalancerServerIdTest, EqualityOperators) { + auto server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 15)); + ASSERT_TRUE(server_id.has_value()); + auto shorter_server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 5)); + ASSERT_TRUE(shorter_server_id.has_value()); + EXPECT_FALSE(server_id == shorter_server_id); + auto server_id2 = server_id; + EXPECT_TRUE(server_id == server_id2); +} + +TEST_F(LoadBalancerServerIdTest, SupportsHash) { + auto server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 15)); + ASSERT_TRUE(server_id.has_value()); + auto shorter_server_id = + LoadBalancerServerId::Create(absl::Span(kRawServerId, 5)); + ASSERT_TRUE(shorter_server_id.has_value()); + auto different_server_id = + LoadBalancerServerId::Create({0x0f, 0x0e, 0x0d, 0x0c, 0x0b}); + ASSERT_TRUE(different_server_id.has_value()); + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + *server_id, + *shorter_server_id, + *different_server_id, + })); +} + +} // namespace + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/masque/README.md b/quiche/quic/masque/README.md new file mode 100644 index 000000000000..6bfc08ee0f8c --- /dev/null +++ b/quiche/quic/masque/README.md @@ -0,0 +1,4 @@ +# MASQUE + +The files in this directory implement MASQUE as described in +. diff --git a/quiche/quic/masque/masque_client.cc b/quiche/quic/masque/masque_client.cc new file mode 100644 index 000000000000..b30597ba0420 --- /dev/null +++ b/quiche/quic/masque/masque_client.cc @@ -0,0 +1,106 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_client.h" + +#include + +#include "absl/memory/memory.h" +#include "quiche/quic/masque/masque_client_session.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/tools/quic_name_lookup.h" +#include "quiche/quic/tools/quic_url.h" + +namespace quic { + +MasqueClient::MasqueClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + MasqueMode masque_mode, QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + const std::string& uri_template) + : QuicDefaultClient(server_address, server_id, MasqueSupportedVersions(), + event_loop, std::move(proof_verifier)), + masque_mode_(masque_mode), + uri_template_(uri_template) {} + +std::unique_ptr MasqueClient::CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) { + QUIC_DLOG(INFO) << "Creating MASQUE session for " + << connection->connection_id(); + return std::make_unique( + masque_mode_, uri_template_, *config(), supported_versions, connection, + server_id(), crypto_config(), push_promise_index(), this); +} + +MasqueClientSession* MasqueClient::masque_client_session() { + return static_cast(QuicDefaultClient::session()); +} + +QuicConnectionId MasqueClient::connection_id() { + return masque_client_session()->connection_id(); +} + +std::string MasqueClient::authority() const { + QuicUrl url(uri_template_); + return absl::StrCat(url.host(), ":", url.port()); +} + +// static +std::unique_ptr MasqueClient::Create( + const std::string& uri_template, MasqueMode masque_mode, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier) { + QuicUrl url(uri_template); + std::string host = url.host(); + uint16_t port = url.port(); + // Build the masque_client, and try to connect. + QuicSocketAddress addr = tools::LookupAddress(host, absl::StrCat(port)); + if (!addr.IsInitialized()) { + QUIC_LOG(ERROR) << "Unable to resolve address: " << host; + return nullptr; + } + QuicServerId server_id(host, port); + // Use absl::WrapUnique(new MasqueClient(...)) instead of + // std::make_unique(...) because the constructor for + // MasqueClient is private and therefore not accessible from make_unique. + auto masque_client = absl::WrapUnique( + new MasqueClient(addr, server_id, masque_mode, event_loop, + std::move(proof_verifier), uri_template)); + + if (masque_client == nullptr) { + QUIC_LOG(ERROR) << "Failed to create masque_client"; + return nullptr; + } + + masque_client->set_initial_max_packet_length(kMasqueMaxOuterPacketSize); + masque_client->set_drop_response_body(false); + if (!masque_client->Initialize()) { + QUIC_LOG(ERROR) << "Failed to initialize masque_client"; + return nullptr; + } + if (!masque_client->Connect()) { + QuicErrorCode error = masque_client->session()->error(); + QUIC_LOG(ERROR) << "Failed to connect to " << host << ":" << port + << ". Error: " << QuicErrorCodeToString(error); + return nullptr; + } + + if (!masque_client->WaitUntilSettingsReceived()) { + QUIC_LOG(ERROR) << "Failed to receive settings"; + return nullptr; + } + + return masque_client; +} + +void MasqueClient::OnSettingsReceived() { settings_received_ = true; } + +bool MasqueClient::WaitUntilSettingsReceived() { + while (connected() && !settings_received_) { + network_helper()->RunEventLoop(); + } + return connected() && settings_received_; +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_client.h b/quiche/quic/masque/masque_client.h new file mode 100644 index 000000000000..d3332acc2b22 --- /dev/null +++ b/quiche/quic/masque/masque_client.h @@ -0,0 +1,67 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_CLIENT_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_CLIENT_H_ + +#include + +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/masque/masque_client_session.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/tools/quic_default_client.h" + +namespace quic { + +// QUIC client that implements MASQUE. +class QUIC_NO_EXPORT MasqueClient : public QuicDefaultClient, + public MasqueClientSession::Owner { + public: + // Constructs a MasqueClient, performs a synchronous DNS lookup. + static std::unique_ptr Create( + const std::string& uri_template, MasqueMode masque_mode, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier); + + // From QuicClient. + std::unique_ptr CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) override; + + // Client session for this client. + MasqueClientSession* masque_client_session(); + + // Convenience accessor for the underlying connection ID. + QuicConnectionId connection_id(); + + // From MasqueClientSession::Owner. + void OnSettingsReceived() override; + + MasqueMode masque_mode() const { return masque_mode_; } + + private: + // Constructor is private, use Create() instead. + MasqueClient(QuicSocketAddress server_address, const QuicServerId& server_id, + MasqueMode masque_mode, QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + const std::string& uri_template); + + // Wait synchronously until we receive the peer's settings. Returns whether + // they were received. + bool WaitUntilSettingsReceived(); + + std::string authority() const; + + // Disallow copy and assign. + MasqueClient(const MasqueClient&) = delete; + MasqueClient& operator=(const MasqueClient&) = delete; + + MasqueMode masque_mode_; + std::string uri_template_; + bool settings_received_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_CLIENT_H_ diff --git a/quiche/quic/masque/masque_client_bin.cc b/quiche/quic/masque/masque_client_bin.cc new file mode 100644 index 000000000000..445e56b28f03 --- /dev/null +++ b/quiche/quic/masque/masque_client_bin.cc @@ -0,0 +1,258 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file is reponsible for the masque_client binary. It allows testing +// our MASQUE client code by connecting to a MASQUE proxy and then sending +// HTTP/3 requests to web servers tunnelled over that MASQUE connection. +// e.g.: masque_client $PROXY_HOST:$PROXY_PORT $URL1 $URL2 + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "url/third_party/mozilla/url_parse.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/masque/masque_client.h" +#include "quiche/quic/masque/masque_client_tools.h" +#include "quiche/quic/masque/masque_encapsulated_client.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_default_proof_providers.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/fake_proof_verifier.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_system_event_loop.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, disable_certificate_verification, false, + "If true, don't verify the server certificate."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int, address_family, 0, + "IP address family to use. Must be 0, 4 or 6. " + "Defaults to 0 which means any."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, masque_mode, "", + "Allows setting MASQUE mode, currently only valid value is \"open\"."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, bring_up_tun, false, + "If set to true, no URLs need to be specified and instead a TUN device " + "is brought up with the assigned IP from the MASQUE CONNECT-IP server"); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, dns_on_client, false, + "If set to true, masque_client will perform DNS for encapsulated URLs and " + "send the IP litteral in the CONNECT request. If set to false, " + "masque_client send the hostname in the CONNECT request."); + +namespace quic { + +namespace { + +using ::quiche::AddressAssignCapsule; +using ::quiche::AddressRequestCapsule; +using ::quiche::RouteAdvertisementCapsule; + +class MasqueTunSession : public MasqueClientSession::EncapsulatedIpSession, + public QuicSocketEventListener { + public: + MasqueTunSession(QuicEventLoop* event_loop, MasqueClientSession* session) + : event_loop_(event_loop), session_(session) {} + ~MasqueTunSession() override = default; + // MasqueClientSession::EncapsulatedIpSession + void ProcessIpPacket(absl::string_view packet) override { + QUIC_LOG(INFO) << " Received IP packets of length " << packet.length(); + if (fd_ == -1) { + // TUN not open, early return + return; + } + if (write(fd_, packet.data(), packet.size()) == -1) { + QUIC_LOG(FATAL) << "Failed to write"; + } + } + void CloseIpSession(const std::string& details) override { + QUIC_LOG(ERROR) << "Was asked to close IP session: " << details; + } + bool OnAddressAssignCapsule(const AddressAssignCapsule& capsule) override { + for (auto assigned_address : capsule.assigned_addresses) { + if (assigned_address.ip_prefix.address().IsIPv4()) { + QUIC_LOG(INFO) << "MasqueTunSession saving local IPv4 address " + << assigned_address.ip_prefix.address(); + local_address_ = assigned_address.ip_prefix.address(); + break; + } + } + // Bring up the TUN + QUIC_LOG(ERROR) << "Bringing up tun with address " << local_address_; + fd_ = CreateTunInterface(local_address_, false); + if (fd_ < 0) { + QUIC_LOG(FATAL) << "Failed to create TUN interface"; + } + if (!event_loop_->RegisterSocket(fd_, kSocketEventReadable, this)) { + QUIC_LOG(FATAL) << "Failed to register TUN fd with the event loop"; + } + return true; + } + bool OnAddressRequestCapsule( + const AddressRequestCapsule& /*capsule*/) override { + // Always ignore the address request capsule from the server. + return true; + } + bool OnRouteAdvertisementCapsule( + const RouteAdvertisementCapsule& /*capsule*/) override { + // Consider installing routes. + return true; + } + + // QuicSocketEventListener + void OnSocketEvent(QuicEventLoop* /*event_loop*/, QuicUdpSocketFd fd, + QuicSocketEventMask events) override { + if ((events & kSocketEventReadable) == 0) { + QUIC_DVLOG(1) << "Ignoring OnEvent fd " << fd << " event mask " << events; + return; + } + char datagram[1501]; + while (true) { + ssize_t read_size = read(fd, datagram, sizeof(datagram)); + if (read_size < 0) { + break; + } + // Packet received from the TUN. Write it to the MASQUE CONNECT-IP + // session. + session_->SendIpPacket(absl::string_view(datagram, read_size), this); + } + if (!event_loop_->SupportsEdgeTriggered()) { + if (!event_loop_->RearmSocket(fd, kSocketEventReadable)) { + QUIC_BUG(MasqueServerSession_ConnectIp_OnSocketEvent_Rearm) + << "Failed to re-arm socket " << fd << " for reading"; + } + } + } + + private: + QuicEventLoop* event_loop_; + MasqueClientSession* session_; + QuicIpAddress local_address_; + int fd_ = -1; +}; + +int RunMasqueClient(int argc, char* argv[]) { + quiche::QuicheSystemEventLoop system_event_loop("masque_client"); + const char* usage = "Usage: masque_client [options] "; + + // The first non-flag argument is the URI template of the MASQUE server. + // All subsequent ones are interpreted as URLs to fetch via the MASQUE server. + // Note that the URI template expansion currently only supports string + // replacement of {target_host} and {target_port}, not + // {?target_host,target_port}. + std::vector urls = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + bool bring_up_tun = quiche::GetQuicheCommandLineFlag(FLAGS_bring_up_tun); + if (urls.empty() && !bring_up_tun) { + quiche::QuichePrintCommandLineFlagHelp(usage); + return 1; + } + + const bool disable_certificate_verification = + quiche::GetQuicheCommandLineFlag(FLAGS_disable_certificate_verification); + std::unique_ptr event_loop = + GetDefaultEventLoop()->Create(QuicDefaultClock::Get()); + + std::string uri_template = urls[0]; + if (!absl::StrContains(uri_template, '/')) { + // If an authority is passed in instead of a URI template, use the default + // URI template. + uri_template = + absl::StrCat("https://", uri_template, + "/.well-known/masque/udp/{target_host}/{target_port}/"); + } + url::Parsed parsed_uri_template; + url::ParseStandardURL(uri_template.c_str(), uri_template.length(), + &parsed_uri_template); + if (!parsed_uri_template.scheme.is_nonempty() || + !parsed_uri_template.host.is_nonempty() || + !parsed_uri_template.path.is_nonempty()) { + std::cerr << "Failed to parse MASQUE URI template \"" << urls[0] << "\"" + << std::endl; + return 1; + } + std::string host = uri_template.substr(parsed_uri_template.host.begin, + parsed_uri_template.host.len); + std::unique_ptr proof_verifier; + if (disable_certificate_verification) { + proof_verifier = std::make_unique(); + } else { + proof_verifier = CreateDefaultProofVerifier(host); + } + MasqueMode masque_mode = MasqueMode::kOpen; + std::string mode_string = quiche::GetQuicheCommandLineFlag(FLAGS_masque_mode); + if (!mode_string.empty()) { + if (mode_string == "open") { + masque_mode = MasqueMode::kOpen; + } else if (mode_string == "connectip" || mode_string == "connect-ip") { + masque_mode = MasqueMode::kConnectIp; + } else { + std::cerr << "Invalid masque_mode \"" << mode_string << "\"" << std::endl; + return 1; + } + } + const int address_family = + quiche::GetQuicheCommandLineFlag(FLAGS_address_family); + int address_family_for_lookup; + if (address_family == 0) { + address_family_for_lookup = AF_UNSPEC; + } else if (address_family == 4) { + address_family_for_lookup = AF_INET; + } else if (address_family == 6) { + address_family_for_lookup = AF_INET6; + } else { + std::cerr << "Invalid address_family " << address_family << std::endl; + return 1; + } + std::unique_ptr masque_client = MasqueClient::Create( + uri_template, masque_mode, event_loop.get(), std::move(proof_verifier)); + if (masque_client == nullptr) { + return 1; + } + + std::cerr << "MASQUE is connected " << masque_client->connection_id() + << " in " << masque_mode << " mode" << std::endl; + + if (bring_up_tun) { + std::cerr << "Bringing up tun" << std::endl; + MasqueTunSession tun_session(event_loop.get(), + masque_client->masque_client_session()); + masque_client->masque_client_session()->SendIpPacket( + absl::string_view("asdf"), &tun_session); + while (true) { + event_loop->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(50)); + } + QUICHE_NOTREACHED(); + } + + const bool dns_on_client = + quiche::GetQuicheCommandLineFlag(FLAGS_dns_on_client); + + for (size_t i = 1; i < urls.size(); ++i) { + if (!tools::SendEncapsulatedMasqueRequest( + masque_client.get(), event_loop.get(), urls[i], + disable_certificate_verification, address_family_for_lookup, + dns_on_client)) { + return 1; + } + } + + return 0; +} + +} // namespace + +} // namespace quic + +int main(int argc, char* argv[]) { return quic::RunMasqueClient(argc, argv); } diff --git a/quiche/quic/masque/masque_client_session.cc b/quiche/quic/masque/masque_client_session.cc new file mode 100644 index 000000000000..83f164843029 --- /dev/null +++ b/quiche/quic/masque/masque_client_session.cc @@ -0,0 +1,524 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_client_session.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "url/url_canon.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/common/platform/api/quiche_url_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace { + +using ::quiche::AddressAssignCapsule; +using ::quiche::AddressRequestCapsule; +using ::quiche::RouteAdvertisementCapsule; + +constexpr uint64_t kConnectIpPayloadContextId = 0; +} // namespace + +MasqueClientSession::MasqueClientSession( + MasqueMode masque_mode, const std::string& uri_template, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, Owner* owner) + : QuicSpdyClientSession(config, supported_versions, connection, server_id, + crypto_config, push_promise_index), + masque_mode_(masque_mode), + uri_template_(uri_template), + owner_(owner) {} + +void MasqueClientSession::OnMessageAcked(QuicMessageId message_id, + QuicTime /*receive_timestamp*/) { + QUIC_DVLOG(1) << "Received ack for DATAGRAM frame " << message_id; +} + +void MasqueClientSession::OnMessageLost(QuicMessageId message_id) { + QUIC_DVLOG(1) << "We believe DATAGRAM frame " << message_id << " was lost"; +} + +const MasqueClientSession::ConnectUdpClientState* +MasqueClientSession::GetOrCreateConnectUdpClientState( + const QuicSocketAddress& target_server_address, + EncapsulatedClientSession* encapsulated_client_session) { + for (const ConnectUdpClientState& client_state : connect_udp_client_states_) { + if (client_state.target_server_address() == target_server_address && + client_state.encapsulated_client_session() == + encapsulated_client_session) { + // Found existing CONNECT-UDP request. + return &client_state; + } + } + // No CONNECT-UDP request found, create a new one. + std::string target_host; + auto it = fake_addresses_.find(target_server_address.host().ToPackedString()); + if (it != fake_addresses_.end()) { + target_host = it->second; + } else { + target_host = target_server_address.host().ToString(); + } + + url::Parsed parsed_uri_template; + url::ParseStandardURL(uri_template_.c_str(), uri_template_.length(), + &parsed_uri_template); + if (!parsed_uri_template.path.is_nonempty()) { + QUIC_BUG(bad URI template path) + << "Cannot parse path from URI template " << uri_template_; + return nullptr; + } + std::string path = uri_template_.substr(parsed_uri_template.path.begin, + parsed_uri_template.path.len); + if (parsed_uri_template.query.is_valid()) { + absl::StrAppend(&path, "?", + uri_template_.substr(parsed_uri_template.query.begin, + parsed_uri_template.query.len)); + } + absl::flat_hash_map parameters; + parameters["target_host"] = target_host; + parameters["target_port"] = absl::StrCat(target_server_address.port()); + std::string expanded_path; + absl::flat_hash_set vars_found; + bool expanded = + quiche::ExpandURITemplate(path, parameters, &expanded_path, &vars_found); + if (!expanded || vars_found.find("target_host") == vars_found.end() || + vars_found.find("target_port") == vars_found.end()) { + QUIC_DLOG(ERROR) << "Failed to expand URI template \"" << uri_template_ + << "\" for " << target_host << " port " + << target_server_address.port(); + return nullptr; + } + + url::Component expanded_path_component(0, expanded_path.length()); + url::RawCanonOutput<1024> canonicalized_path_output; + url::Component canonicalized_path_component; + bool canonicalized = url::CanonicalizePath( + expanded_path.c_str(), expanded_path_component, + &canonicalized_path_output, &canonicalized_path_component); + if (!canonicalized || !canonicalized_path_component.is_nonempty()) { + QUIC_DLOG(ERROR) << "Failed to canonicalize URI template \"" + << uri_template_ << "\" for " << target_host << " port " + << target_server_address.port(); + return nullptr; + } + std::string canonicalized_path( + canonicalized_path_output.data() + canonicalized_path_component.begin, + canonicalized_path_component.len); + + QuicSpdyClientStream* stream = CreateOutgoingBidirectionalStream(); + if (stream == nullptr) { + // Stream flow control limits prevented us from opening a new stream. + QUIC_DLOG(ERROR) << "Failed to open CONNECT-UDP stream"; + return nullptr; + } + + QuicUrl url(uri_template_); + std::string scheme = url.scheme(); + std::string authority = url.HostPort(); + + QUIC_DLOG(INFO) << "Sending CONNECT-UDP request for " << target_host + << " port " << target_server_address.port() << " on stream " + << stream->id() << " scheme=\"" << scheme << "\" authority=\"" + << authority << "\" path=\"" << canonicalized_path << "\""; + + // Send the request. + spdy::Http2HeaderBlock headers; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "connect-udp"; + headers[":scheme"] = scheme; + headers[":authority"] = authority; + headers[":path"] = canonicalized_path; + headers["connect-udp-version"] = "12"; + size_t bytes_sent = + stream->SendRequest(std::move(headers), /*body=*/"", /*fin=*/false); + if (bytes_sent == 0) { + QUIC_DLOG(ERROR) << "Failed to send CONNECT-UDP request"; + return nullptr; + } + + connect_udp_client_states_.push_back(ConnectUdpClientState( + stream, encapsulated_client_session, this, target_server_address)); + return &connect_udp_client_states_.back(); +} + +const MasqueClientSession::ConnectIpClientState* +MasqueClientSession::GetOrCreateConnectIpClientState( + MasqueClientSession::EncapsulatedIpSession* encapsulated_ip_session) { + for (const ConnectIpClientState& client_state : connect_ip_client_states_) { + if (client_state.encapsulated_ip_session() == encapsulated_ip_session) { + // Found existing CONNECT-IP request. + return &client_state; + } + } + // No CONNECT-IP request found, create a new one. + QuicSpdyClientStream* stream = CreateOutgoingBidirectionalStream(); + if (stream == nullptr) { + // Stream flow control limits prevented us from opening a new stream. + QUIC_DLOG(ERROR) << "Failed to open CONNECT-IP stream"; + return nullptr; + } + + QuicUrl url(uri_template_); + std::string scheme = url.scheme(); + std::string authority = url.HostPort(); + std::string path = "/.well-known/masque/ip/*/*/"; + + QUIC_DLOG(INFO) << "Sending CONNECT-IP request on stream " << stream->id() + << " scheme=\"" << scheme << "\" authority=\"" << authority + << "\" path=\"" << path << "\""; + + // Send the request. + spdy::Http2HeaderBlock headers; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "connect-ip"; + headers[":scheme"] = scheme; + headers[":authority"] = authority; + headers[":path"] = path; + headers["connect-ip-version"] = "3"; + size_t bytes_sent = + stream->SendRequest(std::move(headers), /*body=*/"", /*fin=*/false); + if (bytes_sent == 0) { + QUIC_DLOG(ERROR) << "Failed to send CONNECT-IP request"; + return nullptr; + } + + connect_ip_client_states_.push_back( + ConnectIpClientState(stream, encapsulated_ip_session, this)); + return &connect_ip_client_states_.back(); +} + +void MasqueClientSession::SendIpPacket( + absl::string_view packet, + MasqueClientSession::EncapsulatedIpSession* encapsulated_ip_session) { + const ConnectIpClientState* connect_ip = + GetOrCreateConnectIpClientState(encapsulated_ip_session); + if (connect_ip == nullptr) { + QUIC_DLOG(ERROR) << "Failed to create CONNECT-IP request"; + return; + } + + std::string http_payload; + http_payload.resize( + QuicDataWriter::GetVarInt62Len(kConnectIpPayloadContextId) + + packet.size()); + QuicDataWriter writer(http_payload.size(), http_payload.data()); + if (!writer.WriteVarInt62(kConnectIpPayloadContextId)) { + QUIC_BUG(IP context write fail) << "Failed to write CONNECT-IP context ID"; + return; + } + if (!writer.WriteStringPiece(packet)) { + QUIC_BUG(IP packet write fail) << "Failed to write CONNECT-IP packet"; + return; + } + MessageStatus message_status = + SendHttp3Datagram(connect_ip->stream()->id(), http_payload); + + QUIC_DVLOG(1) << "Sent encapsulated IP packet of length " << packet.size() + << " with stream ID " << connect_ip->stream()->id() + << " and got message status " + << MessageStatusToString(message_status); +} + +void MasqueClientSession::SendPacket( + absl::string_view packet, const QuicSocketAddress& target_server_address, + EncapsulatedClientSession* encapsulated_client_session) { + const ConnectUdpClientState* connect_udp = GetOrCreateConnectUdpClientState( + target_server_address, encapsulated_client_session); + if (connect_udp == nullptr) { + QUIC_DLOG(ERROR) << "Failed to create CONNECT-UDP request"; + return; + } + + std::string http_payload; + http_payload.resize(1 + packet.size()); + http_payload[0] = 0; + memcpy(&http_payload[1], packet.data(), packet.size()); + MessageStatus message_status = + SendHttp3Datagram(connect_udp->stream()->id(), http_payload); + + QUIC_DVLOG(1) << "Sent packet to " << target_server_address + << " compressed with stream ID " << connect_udp->stream()->id() + << " and got message status " + << MessageStatusToString(message_status); +} + +void MasqueClientSession::CloseConnectUdpStream( + EncapsulatedClientSession* encapsulated_client_session) { + for (auto it = connect_udp_client_states_.begin(); + it != connect_udp_client_states_.end();) { + if (it->encapsulated_client_session() == encapsulated_client_session) { + QUIC_DLOG(INFO) << "Removing CONNECT-UDP state for stream ID " + << it->stream()->id(); + auto* stream = it->stream(); + it = connect_udp_client_states_.erase(it); + if (!stream->write_side_closed()) { + stream->Reset(QUIC_STREAM_CANCELLED); + } + } else { + ++it; + } + } +} + +void MasqueClientSession::CloseConnectIpStream( + EncapsulatedIpSession* encapsulated_ip_session) { + for (auto it = connect_ip_client_states_.begin(); + it != connect_ip_client_states_.end();) { + if (it->encapsulated_ip_session() == encapsulated_ip_session) { + QUIC_DLOG(INFO) << "Removing CONNECT-IP state for stream ID " + << it->stream()->id(); + auto* stream = it->stream(); + it = connect_ip_client_states_.erase(it); + if (!stream->write_side_closed()) { + stream->Reset(QUIC_STREAM_CANCELLED); + } + } else { + ++it; + } + } +} + +void MasqueClientSession::OnConnectionClosed( + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { + QuicSpdyClientSession::OnConnectionClosed(frame, source); + // Close all encapsulated sessions. + for (const auto& client_state : connect_udp_client_states_) { + client_state.encapsulated_client_session()->CloseConnection( + QUIC_CONNECTION_CANCELLED, "Underlying MASQUE connection was closed", + ConnectionCloseBehavior::SILENT_CLOSE); + } + for (const auto& client_state : connect_ip_client_states_) { + client_state.encapsulated_ip_session()->CloseIpSession( + "Underlying MASQUE connection was closed"); + } +} + +void MasqueClientSession::OnStreamClosed(QuicStreamId stream_id) { + if (QuicUtils::IsBidirectionalStreamId(stream_id, version()) && + QuicUtils::IsClientInitiatedStreamId(transport_version(), stream_id)) { + QuicSpdyClientStream* stream = + reinterpret_cast(GetActiveStream(stream_id)); + if (stream != nullptr) { + QUIC_DLOG(INFO) << "Stream " << stream_id + << " closed, got response headers:" + << stream->response_headers().DebugString(); + } + } + for (auto it = connect_udp_client_states_.begin(); + it != connect_udp_client_states_.end();) { + if (it->stream()->id() == stream_id) { + QUIC_DLOG(INFO) << "Stream " << stream_id + << " was closed, removing CONNECT-UDP state"; + auto* encapsulated_client_session = it->encapsulated_client_session(); + it = connect_udp_client_states_.erase(it); + encapsulated_client_session->CloseConnection( + QUIC_CONNECTION_CANCELLED, + "Underlying MASQUE CONNECT-UDP stream was closed", + ConnectionCloseBehavior::SILENT_CLOSE); + } else { + ++it; + } + } + for (auto it = connect_ip_client_states_.begin(); + it != connect_ip_client_states_.end();) { + if (it->stream()->id() == stream_id) { + QUIC_DLOG(INFO) << "Stream " << stream_id + << " was closed, removing CONNECT-IP state"; + auto* encapsulated_ip_session = it->encapsulated_ip_session(); + it = connect_ip_client_states_.erase(it); + encapsulated_ip_session->CloseIpSession( + "Underlying MASQUE CONNECT-IP stream was closed"); + } else { + ++it; + } + } + + QuicSpdyClientSession::OnStreamClosed(stream_id); +} + +bool MasqueClientSession::OnSettingsFrame(const SettingsFrame& frame) { + QUIC_DLOG(INFO) << "Received SETTINGS: " << frame; + if (!QuicSpdyClientSession::OnSettingsFrame(frame)) { + QUIC_DLOG(ERROR) << "Failed to parse received settings"; + return false; + } + if (!SupportsH3Datagram()) { + QUIC_DLOG(ERROR) << "Refusing to use MASQUE without HTTP/3 Datagrams"; + return false; + } + QUIC_DLOG(INFO) << "Using HTTP Datagram: " << http_datagram_support(); + owner_->OnSettingsReceived(); + return true; +} + +MasqueClientSession::ConnectUdpClientState::ConnectUdpClientState( + QuicSpdyClientStream* stream, + EncapsulatedClientSession* encapsulated_client_session, + MasqueClientSession* masque_session, + const QuicSocketAddress& target_server_address) + : stream_(stream), + encapsulated_client_session_(encapsulated_client_session), + masque_session_(masque_session), + target_server_address_(target_server_address) { + QUICHE_DCHECK_NE(masque_session_, nullptr); + this->stream()->RegisterHttp3DatagramVisitor(this); +} + +MasqueClientSession::ConnectUdpClientState::~ConnectUdpClientState() { + if (stream() != nullptr) { + stream()->UnregisterHttp3DatagramVisitor(); + } +} + +MasqueClientSession::ConnectUdpClientState::ConnectUdpClientState( + MasqueClientSession::ConnectUdpClientState&& other) { + *this = std::move(other); +} + +MasqueClientSession::ConnectUdpClientState& +MasqueClientSession::ConnectUdpClientState::operator=( + MasqueClientSession::ConnectUdpClientState&& other) { + stream_ = other.stream_; + encapsulated_client_session_ = other.encapsulated_client_session_; + masque_session_ = other.masque_session_; + target_server_address_ = other.target_server_address_; + other.stream_ = nullptr; + if (stream() != nullptr) { + stream()->ReplaceHttp3DatagramVisitor(this); + } + return *this; +} + +void MasqueClientSession::ConnectUdpClientState::OnHttp3Datagram( + QuicStreamId stream_id, absl::string_view payload) { + QUICHE_DCHECK_EQ(stream_id, stream()->id()); + QuicDataReader reader(payload); + uint64_t context_id; + if (!reader.ReadVarInt62(&context_id)) { + QUIC_DLOG(ERROR) << "Failed to read context ID"; + return; + } + if (context_id != 0) { + QUIC_DLOG(ERROR) << "Ignoring HTTP Datagram with unexpected context ID " + << context_id; + return; + } + absl::string_view http_payload = reader.ReadRemainingPayload(); + encapsulated_client_session_->ProcessPacket(http_payload, + target_server_address_); + QUIC_DVLOG(1) << "Sent " << http_payload.size() + << " bytes to connection for stream ID " << stream_id; +} + +MasqueClientSession::ConnectIpClientState::ConnectIpClientState( + QuicSpdyClientStream* stream, + EncapsulatedIpSession* encapsulated_ip_session, + MasqueClientSession* masque_session) + : stream_(stream), + encapsulated_ip_session_(encapsulated_ip_session), + masque_session_(masque_session) { + QUICHE_DCHECK_NE(masque_session_, nullptr); + this->stream()->RegisterHttp3DatagramVisitor(this); + this->stream()->RegisterConnectIpVisitor(this); +} + +MasqueClientSession::ConnectIpClientState::~ConnectIpClientState() { + if (stream() != nullptr) { + stream()->UnregisterHttp3DatagramVisitor(); + stream()->UnregisterConnectIpVisitor(); + } +} + +MasqueClientSession::ConnectIpClientState::ConnectIpClientState( + MasqueClientSession::ConnectIpClientState&& other) { + *this = std::move(other); +} + +MasqueClientSession::ConnectIpClientState& +MasqueClientSession::ConnectIpClientState::operator=( + MasqueClientSession::ConnectIpClientState&& other) { + stream_ = other.stream_; + encapsulated_ip_session_ = other.encapsulated_ip_session_; + masque_session_ = other.masque_session_; + other.stream_ = nullptr; + if (stream() != nullptr) { + stream()->ReplaceHttp3DatagramVisitor(this); + stream()->ReplaceConnectIpVisitor(this); + } + return *this; +} + +void MasqueClientSession::ConnectIpClientState::OnHttp3Datagram( + QuicStreamId stream_id, absl::string_view payload) { + QUICHE_DCHECK_EQ(stream_id, stream()->id()); + QuicDataReader reader(payload); + uint64_t context_id; + if (!reader.ReadVarInt62(&context_id)) { + QUIC_DLOG(ERROR) << "Failed to read context ID"; + return; + } + if (context_id != kConnectIpPayloadContextId) { + QUIC_DLOG(ERROR) << "Ignoring HTTP Datagram with unexpected context ID " + << context_id; + return; + } + absl::string_view http_payload = reader.ReadRemainingPayload(); + encapsulated_ip_session_->ProcessIpPacket(http_payload); + QUIC_DVLOG(1) << "Sent " << http_payload.size() + << " IP bytes to connection for stream ID " << stream_id; +} + +bool MasqueClientSession::ConnectIpClientState::OnAddressAssignCapsule( + const AddressAssignCapsule& capsule) { + return encapsulated_ip_session_->OnAddressAssignCapsule(capsule); +} + +bool MasqueClientSession::ConnectIpClientState::OnAddressRequestCapsule( + const AddressRequestCapsule& capsule) { + return encapsulated_ip_session_->OnAddressRequestCapsule(capsule); +} + +bool MasqueClientSession::ConnectIpClientState::OnRouteAdvertisementCapsule( + const RouteAdvertisementCapsule& capsule) { + return encapsulated_ip_session_->OnRouteAdvertisementCapsule(capsule); +} + +void MasqueClientSession::ConnectIpClientState::OnHeadersWritten() {} + +quiche::QuicheIpAddress MasqueClientSession::GetFakeAddress( + absl::string_view hostname) { + quiche::QuicheIpAddress address; + uint8_t address_bytes[16] = {0xFD}; + quiche::QuicheRandom::GetInstance()->RandBytes(&address_bytes[1], + sizeof(address_bytes) - 1); + address.FromPackedString(reinterpret_cast(address_bytes), + sizeof(address_bytes)); + std::string address_bytes_string(reinterpret_cast(address_bytes), + sizeof(address_bytes)); + fake_addresses_[address_bytes_string] = std::string(hostname); + return address; +} + +void MasqueClientSession::RemoveFakeAddress( + const quiche::QuicheIpAddress& fake_address) { + fake_addresses_.erase(fake_address.ToPackedString()); +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_client_session.h b/quiche/quic/masque/masque_client_session.h new file mode 100644 index 000000000000..4e317eba461f --- /dev/null +++ b/quiche/quic/masque/masque_client_session.h @@ -0,0 +1,238 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_CLIENT_SESSION_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_CLIENT_SESSION_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// QUIC client session for connection to MASQUE proxy. This session establishes +// a connection to a MASQUE proxy and handles sending and receiving DATAGRAM +// frames for operation of the MASQUE protocol. Multiple end-to-end encapsulated +// sessions can then coexist inside this session. Once these are created, they +// need to be registered with this session. +class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { + public: + // Interface meant to be implemented by the owner of the + // MasqueClientSession instance. + class QUIC_NO_EXPORT Owner { + public: + virtual ~Owner() {} + + // Notifies the owner that a settings frame has been received. + virtual void OnSettingsReceived() = 0; + }; + + // Interface meant to be implemented by client sessions encapsulated inside + // CONNECT-UDP, i.e. the end-to-end QUIC client sessions that run inside + // CONNECT-UDP encapsulation. + class QUIC_NO_EXPORT EncapsulatedClientSession { + public: + virtual ~EncapsulatedClientSession() {} + + // Process UDP packet that was just decapsulated. |packet| contains the UDP + // payload. + virtual void ProcessPacket(absl::string_view packet, + QuicSocketAddress target_server_address) = 0; + + // Close the encapsulated connection. + virtual void CloseConnection( + QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) = 0; + }; + + // Interface meant to be implemented by client sessions encapsulated inside + // CONNECT-IP, i.e. the end-to-end QUIC client sessions that run inside + // CONNECT-IP encapsulation. + class QUIC_NO_EXPORT EncapsulatedIpSession { + public: + virtual ~EncapsulatedIpSession() {} + + // Process packet that was just decapsulated. |packet| contains the IP + // header and payload. + virtual void ProcessIpPacket(absl::string_view packet) = 0; + + // Close the encapsulated connection. + virtual void CloseIpSession(const std::string& details) = 0; + + virtual bool OnAddressAssignCapsule( + const quiche::AddressAssignCapsule& capsule) = 0; + virtual bool OnAddressRequestCapsule( + const quiche::AddressRequestCapsule& capsule) = 0; + virtual bool OnRouteAdvertisementCapsule( + const quiche::RouteAdvertisementCapsule& capsule) = 0; + }; + + // Takes ownership of |connection|, but not of |crypto_config| or + // |push_promise_index| or |owner|. All pointers must be non-null. Caller + // must ensure that |push_promise_index| and |owner| stay valid for the + // lifetime of the newly created MasqueClientSession. + MasqueClientSession(MasqueMode masque_mode, const std::string& uri_template, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, + Owner* owner); + + // Disallow copy and assign. + MasqueClientSession(const MasqueClientSession&) = delete; + MasqueClientSession& operator=(const MasqueClientSession&) = delete; + + // From QuicSession. + void OnMessageAcked(QuicMessageId message_id, + QuicTime receive_timestamp) override; + void OnMessageLost(QuicMessageId message_id) override; + void OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) override; + void OnStreamClosed(QuicStreamId stream_id) override; + + // From QuicSpdySession. + bool OnSettingsFrame(const SettingsFrame& frame) override; + + // Send encapsulated UDP packet. |packet| contains the UDP payload. + void SendPacket(absl::string_view packet, + const QuicSocketAddress& target_server_address, + EncapsulatedClientSession* encapsulated_client_session); + + // Send encapsulated IP packet. |packet| contains the IP header and payload. + void SendIpPacket(absl::string_view packet, + EncapsulatedIpSession* encapsulated_ip_session); + + // Close CONNECT-UDP stream tied to this encapsulated client session. + void CloseConnectUdpStream( + EncapsulatedClientSession* encapsulated_client_session); + + // Close CONNECT-IP stream tied to this encapsulated client session. + void CloseConnectIpStream(EncapsulatedIpSession* encapsulated_ip_session); + + // Generate a random Unique Local Address and register a mapping from + // that address to the corresponding hostname. The returned address should be + // removed by calling RemoveFakeAddress() once it is no longer needed. + quiche::QuicheIpAddress GetFakeAddress(absl::string_view hostname); + + // Removes a fake address that was previously created by GetFakeAddress(). + void RemoveFakeAddress(const quiche::QuicheIpAddress& fake_address); + + private: + // State that the MasqueClientSession keeps for each CONNECT-UDP request. + class QUIC_NO_EXPORT ConnectUdpClientState + : public QuicSpdyStream::Http3DatagramVisitor { + public: + // |stream| and |encapsulated_client_session| must be valid for the lifetime + // of the ConnectUdpClientState. + explicit ConnectUdpClientState( + QuicSpdyClientStream* stream, + EncapsulatedClientSession* encapsulated_client_session, + MasqueClientSession* masque_session, + const QuicSocketAddress& target_server_address); + + ~ConnectUdpClientState(); + + // Disallow copy but allow move. + ConnectUdpClientState(const ConnectUdpClientState&) = delete; + ConnectUdpClientState(ConnectUdpClientState&&); + ConnectUdpClientState& operator=(const ConnectUdpClientState&) = delete; + ConnectUdpClientState& operator=(ConnectUdpClientState&&); + + QuicSpdyClientStream* stream() const { return stream_; } + EncapsulatedClientSession* encapsulated_client_session() const { + return encapsulated_client_session_; + } + const QuicSocketAddress& target_server_address() const { + return target_server_address_; + } + + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + void OnUnknownCapsule(QuicStreamId /*stream_id*/, + const quiche::UnknownCapsule& /*capsule*/) override {} + + private: + QuicSpdyClientStream* stream_; // Unowned. + EncapsulatedClientSession* encapsulated_client_session_; // Unowned. + MasqueClientSession* masque_session_; // Unowned. + QuicSocketAddress target_server_address_; + }; + + // State that the MasqueClientSession keeps for each CONNECT-IP request. + class QUIC_NO_EXPORT ConnectIpClientState + : public QuicSpdyStream::Http3DatagramVisitor, + public QuicSpdyStream::ConnectIpVisitor { + public: + // |stream| and |encapsulated_client_session| must be valid for the lifetime + // of the ConnectUdpClientState. + explicit ConnectIpClientState( + QuicSpdyClientStream* stream, + EncapsulatedIpSession* encapsulated_ip_session, + MasqueClientSession* masque_session); + + ~ConnectIpClientState(); + + // Disallow copy but allow move. + ConnectIpClientState(const ConnectIpClientState&) = delete; + ConnectIpClientState(ConnectIpClientState&&); + ConnectIpClientState& operator=(const ConnectIpClientState&) = delete; + ConnectIpClientState& operator=(ConnectIpClientState&&); + + QuicSpdyClientStream* stream() const { return stream_; } + EncapsulatedIpSession* encapsulated_ip_session() const { + return encapsulated_ip_session_; + } + + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + void OnUnknownCapsule(QuicStreamId /*stream_id*/, + const quiche::UnknownCapsule& /*capsule*/) override {} + + // From QuicSpdyStream::ConnectIpVisitor. + bool OnAddressAssignCapsule( + const quiche::AddressAssignCapsule& capsule) override; + bool OnAddressRequestCapsule( + const quiche::AddressRequestCapsule& capsule) override; + bool OnRouteAdvertisementCapsule( + const quiche::RouteAdvertisementCapsule& capsule) override; + void OnHeadersWritten() override; + + private: + QuicSpdyClientStream* stream_; // Unowned. + EncapsulatedIpSession* encapsulated_ip_session_; // Unowned. + MasqueClientSession* masque_session_; // Unowned. + }; + + HttpDatagramSupport LocalHttpDatagramSupport() override { + return HttpDatagramSupport::kRfc; + } + + const ConnectUdpClientState* GetOrCreateConnectUdpClientState( + const QuicSocketAddress& target_server_address, + EncapsulatedClientSession* encapsulated_client_session); + + const ConnectIpClientState* GetOrCreateConnectIpClientState( + EncapsulatedIpSession* encapsulated_ip_session); + + MasqueMode masque_mode_; + std::string uri_template_; + std::list connect_udp_client_states_; + std::list connect_ip_client_states_; + // Maps fake addresses generated by GetFakeAddress() to their corresponding + // hostnames. + absl::flat_hash_map fake_addresses_; + Owner* owner_; // Unowned; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_CLIENT_SESSION_H_ diff --git a/quiche/quic/masque/masque_client_tools.cc b/quiche/quic/masque/masque_client_tools.cc new file mode 100644 index 000000000000..8410529117e3 --- /dev/null +++ b/quiche/quic/masque/masque_client_tools.cc @@ -0,0 +1,151 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_client_tools.h" + +#include "absl/types/optional.h" +#include "quiche/quic/masque/masque_encapsulated_client.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_default_proof_providers.h" +#include "quiche/quic/tools/fake_proof_verifier.h" +#include "quiche/quic/tools/quic_name_lookup.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace tools { + +namespace { + +// Helper class to ensure a fake address gets properly removed when this goes +// out of scope. +class FakeAddressRemover { + public: + FakeAddressRemover() = default; + void IngestFakeAddress(const quiche::QuicheIpAddress& fake_address, + MasqueClientSession* masque_client_session) { + QUICHE_CHECK(masque_client_session != nullptr); + QUICHE_CHECK(!fake_address_.has_value()); + fake_address_ = fake_address; + masque_client_session_ = masque_client_session; + } + ~FakeAddressRemover() { + if (fake_address_.has_value()) { + masque_client_session_->RemoveFakeAddress(*fake_address_); + } + } + + private: + absl::optional fake_address_; + MasqueClientSession* masque_client_session_ = nullptr; +}; + +} // namespace + +bool SendEncapsulatedMasqueRequest(MasqueClient* masque_client, + QuicEventLoop* event_loop, + std::string url_string, + bool disable_certificate_verification, + int address_family_for_lookup, + bool dns_on_client) { + const QuicUrl url(url_string, "https"); + std::unique_ptr proof_verifier; + if (disable_certificate_verification) { + proof_verifier = std::make_unique(); + } else { + proof_verifier = CreateDefaultProofVerifier(url.host()); + } + + // Build the client, and try to connect. + QuicSocketAddress addr; + FakeAddressRemover fake_address_remover; + if (dns_on_client) { + addr = LookupAddress(address_family_for_lookup, url.host(), + absl::StrCat(url.port())); + if (!addr.IsInitialized()) { + QUIC_LOG(ERROR) << "Unable to resolve address: " << url.host(); + return false; + } + } else { + quiche::QuicheIpAddress fake_address = + masque_client->masque_client_session()->GetFakeAddress(url.host()); + fake_address_remover.IngestFakeAddress( + fake_address, masque_client->masque_client_session()); + addr = QuicSocketAddress(fake_address, url.port()); + QUICHE_CHECK(addr.IsInitialized()); + } + const QuicServerId server_id(url.host(), url.port()); + auto client = std::make_unique( + addr, server_id, event_loop, std::move(proof_verifier), masque_client); + + if (client == nullptr) { + QUIC_LOG(ERROR) << "Failed to create MasqueEncapsulatedClient for " + << url_string; + return false; + } + + client->set_initial_max_packet_length(kMasqueMaxEncapsulatedPacketSize); + client->set_drop_response_body(false); + if (!client->Initialize()) { + QUIC_LOG(ERROR) << "Failed to initialize MasqueEncapsulatedClient for " + << url_string; + return false; + } + + if (!client->Connect()) { + QuicErrorCode error = client->session()->error(); + QUIC_LOG(ERROR) << "Failed to connect with client " + << client->session()->connection()->client_connection_id() + << " server " << client->session()->connection_id() + << " to " << url.HostPort() + << ". Error: " << QuicErrorCodeToString(error); + return false; + } + + QUIC_LOG(INFO) << "Connected client " + << client->session()->connection()->client_connection_id() + << " server " << client->session()->connection_id() << " for " + << url_string; + + // Construct the string body from flags, if provided. + // TODO(dschinazi) Add support for HTTP POST and non-empty bodies. + const std::string body = ""; + + // Construct a GET or POST request for supplied URL. + spdy::Http2HeaderBlock header_block; + header_block[":method"] = "GET"; + header_block[":scheme"] = url.scheme(); + header_block[":authority"] = url.HostPort(); + header_block[":path"] = url.PathParamsQuery(); + + // Make sure to store the response, for later output. + client->set_store_response(true); + + // Send the MASQUE init request. + client->SendRequestAndWaitForResponse(header_block, body, + /*fin=*/true); + + if (!client->connected()) { + QUIC_LOG(ERROR) << "Request for " << url_string + << " caused connection failure. Error: " + << QuicErrorCodeToString(client->session()->error()); + return false; + } + + const int response_code = client->latest_response_code(); + if (response_code < 200 || response_code >= 300) { + QUIC_LOG(ERROR) << "Request for " << url_string + << " failed with HTTP response code " << response_code; + return false; + } + + const std::string response_body = client->latest_response_body(); + QUIC_LOG(INFO) << "Request succeeded for " << url_string << std::endl + << response_body; + + return true; +} + +} // namespace tools +} // namespace quic diff --git a/quiche/quic/masque/masque_client_tools.h b/quiche/quic/masque/masque_client_tools.h new file mode 100644 index 000000000000..932fdfac78d1 --- /dev/null +++ b/quiche/quic/masque/masque_client_tools.h @@ -0,0 +1,27 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_CLIENT_TOOLS_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_CLIENT_TOOLS_H_ + +#include "quiche/quic/masque/masque_client.h" + +namespace quic { +namespace tools { + +// Sends an HTTP GET request for |url_string|, proxied over the MASQUE +// connection represented by |masque_client|. A valid and owned |event_loop| +// is required. |disable_certificate_verification| allows disabling verification +// of the HTTP server's TLS certificate. +bool SendEncapsulatedMasqueRequest(MasqueClient* masque_client, + QuicEventLoop* event_loop, + std::string url_string, + bool disable_certificate_verification, + int address_family_for_lookup, + bool dns_on_client); + +} // namespace tools +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_CLIENT_TOOLS_H_ diff --git a/quiche/quic/masque/masque_dispatcher.cc b/quiche/quic/masque/masque_dispatcher.cc new file mode 100644 index 000000000000..40ba112d2b16 --- /dev/null +++ b/quiche/quic/masque/masque_dispatcher.cc @@ -0,0 +1,49 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_dispatcher.h" + +#include "quiche/quic/masque/masque_server_session.h" + +namespace quic { + +MasqueDispatcher::MasqueDispatcher( + MasqueMode masque_mode, const QuicConfig* config, + const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, QuicEventLoop* event_loop, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + MasqueServerBackend* masque_server_backend, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& generator) + : QuicSimpleDispatcher(config, crypto_config, version_manager, + std::move(helper), std::move(session_helper), + std::move(alarm_factory), masque_server_backend, + expected_server_connection_id_length, generator), + masque_mode_(masque_mode), + event_loop_(event_loop), + masque_server_backend_(masque_server_backend) {} + +std::unique_ptr MasqueDispatcher::CreateQuicSession( + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, + const ParsedQuicVersion& version, + const ParsedClientHello& /*parsed_chlo*/) { + // The MasqueServerSession takes ownership of |connection| below. + QuicConnection* connection = new QuicConnection( + connection_id, self_address, peer_address, helper(), alarm_factory(), + writer(), + /*owns_writer=*/false, Perspective::IS_SERVER, + ParsedQuicVersionVector{version}, connection_id_generator()); + + auto session = std::make_unique( + masque_mode_, config(), GetSupportedVersions(), connection, this, + event_loop_, session_helper(), crypto_config(), compressed_certs_cache(), + masque_server_backend_); + session->Initialize(); + return session; +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_dispatcher.h b/quiche/quic/masque/masque_dispatcher.h new file mode 100644 index 000000000000..8d5a07b2e0bb --- /dev/null +++ b/quiche/quic/masque/masque_dispatcher.h @@ -0,0 +1,52 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_DISPATCHER_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_DISPATCHER_H_ + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/masque/masque_server_backend.h" +#include "quiche/quic/masque/masque_server_session.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/tools/quic_simple_dispatcher.h" + +namespace quic { + +// QUIC dispatcher that handles new MASQUE connections and can proxy traffic +// between MASQUE clients and QUIC servers. +class QUIC_NO_EXPORT MasqueDispatcher : public QuicSimpleDispatcher { + public: + explicit MasqueDispatcher( + MasqueMode masque_mode, const QuicConfig* config, + const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, QuicEventLoop* event_loop, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + MasqueServerBackend* masque_server_backend, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& generator); + + // Disallow copy and assign. + MasqueDispatcher(const MasqueDispatcher&) = delete; + MasqueDispatcher& operator=(const MasqueDispatcher&) = delete; + + // From QuicSimpleDispatcher. + std::unique_ptr CreateQuicSession( + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const quic::ParsedClientHello& parsed_chlo) override; + + private: + MasqueMode masque_mode_; + QuicEventLoop* event_loop_; // Unowned. + MasqueServerBackend* masque_server_backend_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_DISPATCHER_H_ diff --git a/quiche/quic/masque/masque_encapsulated_client.cc b/quiche/quic/masque/masque_encapsulated_client.cc new file mode 100644 index 000000000000..d3843638f7c8 --- /dev/null +++ b/quiche/quic/masque/masque_encapsulated_client.cc @@ -0,0 +1,262 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_encapsulated_client.h" + +#include + +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/masque/masque_client.h" +#include "quiche/quic/masque/masque_client_session.h" +#include "quiche/quic/masque/masque_encapsulated_client_session.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/tools/quic_client_default_network_helper.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_data_writer.h" + +namespace quic { + +namespace { + +class ChecksumWriter { + public: + explicit ChecksumWriter(quiche::QuicheDataWriter& writer) : writer_(writer) {} + void IngestUInt16(uint16_t val) { accumulator_ += val; } + void IngestUInt8(uint8_t val) { + uint16_t val16 = odd_ ? val : (val << 8); + accumulator_ += val16; + odd_ = !odd_; + } + bool IngestData(size_t offset, size_t length) { + quiche::QuicheDataReader reader( + writer_.data(), std::min(offset + length, writer_.capacity())); + if (!reader.Seek(offset) || reader.BytesRemaining() < length) { + return false; + } + // Handle any potentially off first byte. + uint8_t first_byte; + if (odd_ && reader.ReadUInt8(&first_byte)) { + IngestUInt8(first_byte); + } + // Handle each 16-bit word at a time. + while (reader.BytesRemaining() >= sizeof(uint16_t)) { + uint16_t word; + if (!reader.ReadUInt16(&word)) { + return false; + } + IngestUInt16(word); + } + // Handle any leftover odd byte. + uint8_t last_byte; + if (reader.ReadUInt8(&last_byte)) { + IngestUInt8(last_byte); + } + return true; + } + bool WriteChecksumAtOffset(size_t offset) { + while (accumulator_ >> 16 > 0) { + accumulator_ = (accumulator_ & 0xffff) + (accumulator_ >> 16); + } + accumulator_ = 0xffff & ~accumulator_; + quiche::QuicheDataWriter writer2(writer_.capacity(), writer_.data()); + return writer2.Seek(offset) && writer2.WriteUInt16(accumulator_); + } + + private: + quiche::QuicheDataWriter& writer_; + uint32_t accumulator_ = 0xffff; + bool odd_ = false; +}; + +// Custom packet writer that allows getting all of a connection's outgoing +// packets. +class MasquePacketWriter : public QuicPacketWriter { + public: + explicit MasquePacketWriter(MasqueEncapsulatedClient* client) + : client_(client) {} + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& peer_address, + PerPacketOptions* /*options*/) override { + QUICHE_DCHECK(peer_address.IsInitialized()); + QUIC_DVLOG(1) << "MasquePacketWriter trying to write " << buf_len + << " bytes to " << peer_address; + if (client_->masque_client()->masque_mode() == MasqueMode::kConnectIp) { + constexpr size_t kIPv4HeaderSize = 20; + constexpr size_t kIPv4ChecksumOffset = 10; + constexpr size_t kIPv6HeaderSize = 40; + constexpr size_t kUdpHeaderSize = 8; + const size_t udp_length = kUdpHeaderSize + buf_len; + std::string packet; + packet.resize( + (peer_address.host().IsIPv6() ? kIPv6HeaderSize : kIPv4HeaderSize) + + udp_length); + quiche::QuicheDataWriter writer(packet.size(), packet.data()); + if (peer_address.host().IsIPv6()) { + // Write IPv6 header. + QUICHE_CHECK(writer.WriteUInt8(0x60)); // Version = 6 and DSCP. + QUICHE_CHECK(writer.WriteUInt8(0)); // DSCP/ECN and flow label. + QUICHE_CHECK(writer.WriteUInt16(0)); // Flow label. + QUICHE_CHECK(writer.WriteUInt16(udp_length)); // Payload Length. + QUICHE_CHECK(writer.WriteUInt8(17)); // Next header = UDP. + QUICHE_CHECK(writer.WriteUInt8(64)); // Hop limit = 64. + in6_addr source_address = {}; + if (client_->masque_encapsulated_client_session() + ->local_v6_address() + .IsIPv6()) { + source_address = client_->masque_encapsulated_client_session() + ->local_v6_address() + .GetIPv6(); + } + QUICHE_CHECK( + writer.WriteBytes(&source_address, sizeof(source_address))); + in6_addr destination_address = peer_address.host().GetIPv6(); + QUICHE_CHECK(writer.WriteBytes(&destination_address, + sizeof(destination_address))); + } else { + // Write IPv4 header. + QUICHE_CHECK(writer.WriteUInt8(0x45)); // Version = 4, IHL = 5. + QUICHE_CHECK(writer.WriteUInt8(0)); // DSCP/ECN. + QUICHE_CHECK(writer.WriteUInt16(packet.size())); // Total Length. + QUICHE_CHECK(writer.WriteUInt32(0)); // No fragmentation. + QUICHE_CHECK(writer.WriteUInt8(64)); // TTL = 64. + QUICHE_CHECK(writer.WriteUInt8(17)); // IP Protocol = UDP. + QUICHE_CHECK(writer.WriteUInt16(0)); // Checksum = 0 initially. + in_addr source_address = {}; + if (client_->masque_encapsulated_client_session() + ->local_v4_address() + .IsIPv4()) { + source_address = client_->masque_encapsulated_client_session() + ->local_v4_address() + .GetIPv4(); + } + QUICHE_CHECK( + writer.WriteBytes(&source_address, sizeof(source_address))); + in_addr destination_address = peer_address.host().GetIPv4(); + QUICHE_CHECK(writer.WriteBytes(&destination_address, + sizeof(destination_address))); + ChecksumWriter ip_checksum_writer(writer); + QUICHE_CHECK(ip_checksum_writer.IngestData(0, kIPv4HeaderSize)); + QUICHE_CHECK( + ip_checksum_writer.WriteChecksumAtOffset(kIPv4ChecksumOffset)); + } + // Write UDP header. + QUICHE_CHECK(writer.WriteUInt16(0x1234)); // Source port. + QUICHE_CHECK( + writer.WriteUInt16(peer_address.port())); // Destination port. + QUICHE_CHECK(writer.WriteUInt16(udp_length)); // UDP length. + QUICHE_CHECK(writer.WriteUInt16(0)); // Checksum = 0 initially. + // Write UDP payload. + QUICHE_CHECK(writer.WriteBytes(buffer, buf_len)); + ChecksumWriter udp_checksum_writer(writer); + if (peer_address.host().IsIPv6()) { + QUICHE_CHECK(udp_checksum_writer.IngestData(8, 32)); // IP addresses. + udp_checksum_writer.IngestUInt16(0); // High bits of UDP length. + udp_checksum_writer.IngestUInt16( + udp_length); // Low bits of UDP length. + udp_checksum_writer.IngestUInt16(0); // Zeroes. + udp_checksum_writer.IngestUInt8(0); // Zeroes. + udp_checksum_writer.IngestUInt8(17); // Next header = UDP. + QUICHE_CHECK(udp_checksum_writer.IngestData( + kIPv6HeaderSize, udp_length)); // UDP header and data. + QUICHE_CHECK( + udp_checksum_writer.WriteChecksumAtOffset(kIPv6HeaderSize + 6)); + } else { + QUICHE_CHECK(udp_checksum_writer.IngestData(12, 8)); // IP addresses. + udp_checksum_writer.IngestUInt8(0); // Zeroes. + udp_checksum_writer.IngestUInt8(17); // IP Protocol = UDP. + udp_checksum_writer.IngestUInt16(udp_length); // UDP length. + QUICHE_CHECK(udp_checksum_writer.IngestData( + kIPv4HeaderSize, udp_length)); // UDP header and data. + QUICHE_CHECK( + udp_checksum_writer.WriteChecksumAtOffset(kIPv4HeaderSize + 6)); + } + client_->masque_client()->masque_client_session()->SendIpPacket( + packet, client_->masque_encapsulated_client_session()); + } else { + absl::string_view packet(buffer, buf_len); + client_->masque_client()->masque_client_session()->SendPacket( + packet, peer_address, client_->masque_encapsulated_client_session()); + } + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + bool IsWriteBlocked() const override { return false; } + + void SetWritable() override {} + + absl::optional MessageTooBigErrorCode() const override { + return absl::nullopt; + } + + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const override { + return kMasqueMaxEncapsulatedPacketSize; + } + + bool SupportsReleaseTime() const override { return false; } + + bool IsBatchMode() const override { return false; } + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) override { + return {nullptr, nullptr}; + } + + WriteResult Flush() override { return WriteResult(WRITE_STATUS_OK, 0); } + + private: + MasqueEncapsulatedClient* client_; // Unowned. +}; + +// Custom network helper that allows injecting a custom packet writer in order +// to get all of a connection's outgoing packets. +class MasqueClientDefaultNetworkHelper : public QuicClientDefaultNetworkHelper { + public: + MasqueClientDefaultNetworkHelper(QuicEventLoop* event_loop, + MasqueEncapsulatedClient* client) + : QuicClientDefaultNetworkHelper(event_loop, client), client_(client) {} + QuicPacketWriter* CreateQuicPacketWriter() override { + return new MasquePacketWriter(client_); + } + + private: + MasqueEncapsulatedClient* client_; // Unowned. +}; + +} // namespace + +MasqueEncapsulatedClient::MasqueEncapsulatedClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier, + MasqueClient* masque_client) + : QuicDefaultClient( + server_address, server_id, MasqueSupportedVersions(), + MasqueEncapsulatedConfig(), event_loop, + std::make_unique(event_loop, this), + std::move(proof_verifier)), + masque_client_(masque_client) {} + +MasqueEncapsulatedClient::~MasqueEncapsulatedClient() { + masque_client_->masque_client_session()->CloseConnectUdpStream( + masque_encapsulated_client_session()); +} + +std::unique_ptr MasqueEncapsulatedClient::CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) { + QUIC_DLOG(INFO) << "Creating MASQUE encapsulated session for " + << connection->connection_id(); + return std::make_unique( + *config(), supported_versions, connection, server_id(), crypto_config(), + push_promise_index(), masque_client_->masque_client_session()); +} + +MasqueEncapsulatedClientSession* +MasqueEncapsulatedClient::masque_encapsulated_client_session() { + return static_cast( + QuicDefaultClient::session()); +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_encapsulated_client.h b/quiche/quic/masque/masque_encapsulated_client.h new file mode 100644 index 000000000000..42fda869e68d --- /dev/null +++ b/quiche/quic/masque/masque_encapsulated_client.h @@ -0,0 +1,47 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_ENCAPSULATED_CLIENT_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_ENCAPSULATED_CLIENT_H_ + +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/masque/masque_client.h" +#include "quiche/quic/masque/masque_encapsulated_client_session.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/tools/quic_default_client.h" + +namespace quic { + +// QUIC client for QUIC encapsulated in MASQUE. +class QUIC_NO_EXPORT MasqueEncapsulatedClient : public QuicDefaultClient { + public: + MasqueEncapsulatedClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + MasqueClient* masque_client); + ~MasqueEncapsulatedClient() override; + + // Disallow copy and assign. + MasqueEncapsulatedClient(const MasqueEncapsulatedClient&) = delete; + MasqueEncapsulatedClient& operator=(const MasqueEncapsulatedClient&) = delete; + + // From QuicClient. + std::unique_ptr CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) override; + + // MASQUE client that this client is encapsulated in. + MasqueClient* masque_client() { return masque_client_; } + + // Client session for this client. + MasqueEncapsulatedClientSession* masque_encapsulated_client_session(); + + private: + MasqueClient* masque_client_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_ENCAPSULATED_CLIENT_H_ diff --git a/quiche/quic/masque/masque_encapsulated_client_session.cc b/quiche/quic/masque/masque_encapsulated_client_session.cc new file mode 100644 index 000000000000..d9e77babab6e --- /dev/null +++ b/quiche/quic/masque/masque_encapsulated_client_session.cc @@ -0,0 +1,255 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_encapsulated_client_session.h" + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_ip_address.h" + +namespace quic { + +using ::quiche::AddressAssignCapsule; +using ::quiche::AddressRequestCapsule; +using ::quiche::RouteAdvertisementCapsule; + +MasqueEncapsulatedClientSession::MasqueEncapsulatedClientSession( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, + MasqueClientSession* masque_client_session) + : QuicSpdyClientSession(config, supported_versions, connection, server_id, + crypto_config, push_promise_index), + masque_client_session_(masque_client_session) {} + +void MasqueEncapsulatedClientSession::ProcessPacket( + absl::string_view packet, QuicSocketAddress server_address) { + QuicTime now = connection()->clock()->ApproximateNow(); + QuicReceivedPacket received_packet(packet.data(), packet.length(), now); + connection()->ProcessUdpPacket(connection()->self_address(), server_address, + received_packet); +} + +void MasqueEncapsulatedClientSession::CloseConnection( + QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) { + connection()->CloseConnection(error, details, connection_close_behavior); +} + +void MasqueEncapsulatedClientSession::OnConnectionClosed( + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { + QuicSpdyClientSession::OnConnectionClosed(frame, source); + masque_client_session_->CloseConnectUdpStream(this); +} + +void MasqueEncapsulatedClientSession::ProcessIpPacket( + absl::string_view packet) { + quiche::QuicheDataReader reader(packet); + uint8_t first_byte; + if (!reader.ReadUInt8(&first_byte)) { + QUIC_DLOG(ERROR) << "Dropping empty CONNECT-IP packet"; + return; + } + const uint8_t ip_version = first_byte >> 4; + quiche::QuicheIpAddress server_ip; + if (ip_version == 6) { + if (!reader.Seek(5)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP IPv6 start" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + uint8_t next_header = 0; + if (!reader.ReadUInt8(&next_header)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP next header" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (next_header != 17) { + // Note that this drops packets with IPv6 extension headers, since we + // do not expect to see them in practice. + QUIC_DLOG(ERROR) + << "Dropping CONNECT-IP packet with unexpected next header " + << static_cast(next_header) << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (!reader.Seek(1)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP hop limit" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + absl::string_view source_ip; + if (!reader.ReadStringPiece(&source_ip, 16)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP source IPv6" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + server_ip.FromPackedString(source_ip.data(), source_ip.length()); + if (!reader.Seek(16)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP destination IPv6" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + } else if (ip_version == 4) { + uint8_t ihl = first_byte & 0xF; + if (ihl < 5) { + QUICHE_DLOG(ERROR) << "Dropping CONNECT-IP packet with invalid IHL " + << static_cast(ihl) << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (!reader.Seek(8)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP IPv4 start" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + uint8_t ip_proto = 0; + if (!reader.ReadUInt8(&ip_proto)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP ip_proto" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (ip_proto != 17) { + QUIC_DLOG(ERROR) << "Dropping CONNECT-IP packet with unexpected IP proto " + << static_cast(ip_proto) << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (!reader.Seek(2)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP IP checksum" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + absl::string_view source_ip; + if (!reader.ReadStringPiece(&source_ip, 4)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP source IPv4" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + server_ip.FromPackedString(source_ip.data(), source_ip.length()); + if (!reader.Seek(4)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP destination IPv4" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + uint8_t ip_options_length = (ihl - 5) * 4; + if (!reader.Seek(ip_options_length)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP IP options of length " + << static_cast(ip_options_length) << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + } else { + QUIC_DLOG(ERROR) << "Dropping CONNECT-IP packet with unexpected IP version " + << static_cast(ip_version) << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + // Parse UDP header. + uint16_t server_port; + if (!reader.ReadUInt16(&server_port)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP source port" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (!reader.Seek(2)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP destination port" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + uint16_t udp_length; + if (!reader.ReadUInt16(&udp_length)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP UDP length" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (udp_length < 8) { + QUICHE_DLOG(ERROR) << "Dropping CONNECT-IP packet with invalid UDP length " + << udp_length << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (!reader.Seek(2)) { + QUICHE_DLOG(ERROR) << "Failed to seek CONNECT-IP UDP checksum" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + absl::string_view quic_packet; + if (!reader.ReadStringPiece(&quic_packet, udp_length - 8)) { + QUICHE_DLOG(ERROR) << "Failed to read CONNECT-IP UDP payload" + << "\n" + << quiche::QuicheTextUtils::HexDump(packet); + return; + } + if (!reader.IsDoneReading()) { + QUICHE_DLOG(INFO) << "Received CONNECT-IP UDP packet with " + << reader.BytesRemaining() + << " extra bytes after payload\n" + << quiche::QuicheTextUtils::HexDump(packet); + } + QUIC_DLOG(INFO) << "Received CONNECT-IP encapsulated packet of length " + << quic_packet.size(); + QuicTime now = connection()->clock()->ApproximateNow(); + QuicReceivedPacket received_packet(quic_packet.data(), quic_packet.size(), + now); + QuicSocketAddress server_address = QuicSocketAddress(server_ip, server_port); + connection()->ProcessUdpPacket(connection()->self_address(), server_address, + received_packet); +} + +void MasqueEncapsulatedClientSession::CloseIpSession( + const std::string& details) { + connection()->CloseConnection(QUIC_CONNECTION_CANCELLED, details, + ConnectionCloseBehavior::SILENT_CLOSE); +} + +bool MasqueEncapsulatedClientSession::OnAddressAssignCapsule( + const AddressAssignCapsule& capsule) { + QUIC_DLOG(INFO) << "Received capsule " << capsule.ToString(); + for (auto assigned_address : capsule.assigned_addresses) { + if (assigned_address.ip_prefix.address().IsIPv4() && + !local_v4_address_.IsInitialized()) { + QUIC_LOG(INFO) + << "MasqueEncapsulatedClientSession saving local IPv4 address " + << assigned_address.ip_prefix.address(); + local_v4_address_ = assigned_address.ip_prefix.address(); + } else if (assigned_address.ip_prefix.address().IsIPv6() && + !local_v6_address_.IsInitialized()) { + QUIC_LOG(INFO) + << "MasqueEncapsulatedClientSession saving local IPv6 address " + << assigned_address.ip_prefix.address(); + local_v6_address_ = assigned_address.ip_prefix.address(); + } + } + return true; +} + +bool MasqueEncapsulatedClientSession::OnAddressRequestCapsule( + const AddressRequestCapsule& capsule) { + QUIC_DLOG(INFO) << "Ignoring received capsule " << capsule.ToString(); + return true; +} + +bool MasqueEncapsulatedClientSession::OnRouteAdvertisementCapsule( + const RouteAdvertisementCapsule& capsule) { + QUIC_DLOG(INFO) << "Ignoring received capsule " << capsule.ToString(); + return true; +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_encapsulated_client_session.h b/quiche/quic/masque/masque_encapsulated_client_session.h new file mode 100644 index 000000000000..323c16dbd57c --- /dev/null +++ b/quiche/quic/masque/masque_encapsulated_client_session.h @@ -0,0 +1,78 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_ENCAPSULATED_CLIENT_SESSION_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_ENCAPSULATED_CLIENT_SESSION_H_ + +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/masque/masque_client_session.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// QUIC client session for QUIC encapsulated in MASQUE. This client session is +// maintained end-to-end between the client and the web-server (the MASQUE +// session does not have access to the cryptographic keys for the end-to-end +// session), but its packets are sent encapsulated inside DATAGRAM frames in a +// MASQUE session, as opposed to regular QUIC packets. Multiple encapsulated +// sessions can coexist inside a MASQUE session. +class QUIC_NO_EXPORT MasqueEncapsulatedClientSession + : public QuicSpdyClientSession, + public MasqueClientSession::EncapsulatedClientSession, + public MasqueClientSession::EncapsulatedIpSession { + public: + // Takes ownership of |connection|, but not of |crypto_config| or + // |push_promise_index| or |masque_client_session|. All pointers must be + // non-null. Caller must ensure that |push_promise_index| and + // |masque_client_session| stay valid for the lifetime of the newly created + // MasqueEncapsulatedClientSession. + MasqueEncapsulatedClientSession( + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, + MasqueClientSession* masque_client_session); + + // Disallow copy and assign. + MasqueEncapsulatedClientSession(const MasqueEncapsulatedClientSession&) = + delete; + MasqueEncapsulatedClientSession& operator=( + const MasqueEncapsulatedClientSession&) = delete; + + // From MasqueClientSession::EncapsulatedClientSession. + void ProcessPacket(absl::string_view packet, + QuicSocketAddress server_address) override; + void CloseConnection( + QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) override; + + // From MasqueClientSession::EncapsulatedIpSession. + void ProcessIpPacket(absl::string_view packet) override; + void CloseIpSession(const std::string& details) override; + bool OnAddressAssignCapsule( + const quiche::AddressAssignCapsule& capsule) override; + bool OnAddressRequestCapsule( + const quiche::AddressRequestCapsule& capsule) override; + bool OnRouteAdvertisementCapsule( + const quiche::RouteAdvertisementCapsule& capsule) override; + + // From QuicSession. + void OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) override; + + // For CONNECT-IP. + QuicIpAddress local_v4_address() const { return local_v4_address_; } + QuicIpAddress local_v6_address() const { return local_v6_address_; } + + private: + MasqueClientSession* masque_client_session_; // Unowned. + // For CONNECT-IP. + QuicIpAddress local_v4_address_; + QuicIpAddress local_v6_address_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_ENCAPSULATED_CLIENT_SESSION_H_ diff --git a/quiche/quic/masque/masque_server.cc b/quiche/quic/masque/masque_server.cc new file mode 100644 index 000000000000..b05e306672fc --- /dev/null +++ b/quiche/quic/masque/masque_server.cc @@ -0,0 +1,31 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_server.h" + +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/masque/masque_dispatcher.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_default_proof_providers.h" +#include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h" + +namespace quic { + +MasqueServer::MasqueServer(MasqueMode masque_mode, + MasqueServerBackend* masque_server_backend) + : QuicServer(CreateDefaultProofSource(), masque_server_backend, + MasqueSupportedVersions()), + masque_mode_(masque_mode), + masque_server_backend_(masque_server_backend) {} + +QuicDispatcher* MasqueServer::CreateQuicDispatcher() { + return new MasqueDispatcher( + masque_mode_, &config(), &crypto_config(), version_manager(), + event_loop(), std::make_unique(), + std::make_unique(), + event_loop()->CreateAlarmFactory(), masque_server_backend_, + expected_server_connection_id_length(), connection_id_generator()); +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_server.h b/quiche/quic/masque/masque_server.h new file mode 100644 index 000000000000..0d09d7fe6b3f --- /dev/null +++ b/quiche/quic/masque/masque_server.h @@ -0,0 +1,35 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_SERVER_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_SERVER_H_ + +#include "quiche/quic/masque/masque_server_backend.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/tools/quic_server.h" + +namespace quic { + +// QUIC server that implements MASQUE. +class QUIC_NO_EXPORT MasqueServer : public QuicServer { + public: + explicit MasqueServer(MasqueMode masque_mode, + MasqueServerBackend* masque_server_backend); + + // Disallow copy and assign. + MasqueServer(const MasqueServer&) = delete; + MasqueServer& operator=(const MasqueServer&) = delete; + + // From QuicServer. + QuicDispatcher* CreateQuicDispatcher() override; + + private: + MasqueMode masque_mode_; + MasqueServerBackend* masque_server_backend_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_SERVER_H_ diff --git a/quiche/quic/masque/masque_server_backend.cc b/quiche/quic/masque/masque_server_backend.cc new file mode 100644 index 000000000000..d7ccb9aceab8 --- /dev/null +++ b/quiche/quic/masque/masque_server_backend.cc @@ -0,0 +1,153 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_server_backend.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace quic { + +MasqueServerBackend::MasqueServerBackend(MasqueMode masque_mode, + const std::string& server_authority, + const std::string& cache_directory) + : masque_mode_(masque_mode), server_authority_(server_authority) { + // Start with client IP 10.1.1.2. + connect_ip_next_client_ip_[0] = 10; + connect_ip_next_client_ip_[1] = 1; + connect_ip_next_client_ip_[2] = 1; + connect_ip_next_client_ip_[3] = 2; + + if (!cache_directory.empty()) { + QuicMemoryCacheBackend::InitializeBackend(cache_directory); + } +} + +bool MasqueServerBackend::MaybeHandleMasqueRequest( + const spdy::Http2HeaderBlock& request_headers, + QuicSimpleServerBackend::RequestHandler* request_handler) { + auto method_pair = request_headers.find(":method"); + if (method_pair == request_headers.end()) { + // Request is missing a method. + return false; + } + absl::string_view method = method_pair->second; + std::string masque_path = ""; + auto protocol_pair = request_headers.find(":protocol"); + if (method != "CONNECT" || protocol_pair == request_headers.end() || + (protocol_pair->second != "connect-udp" && + protocol_pair->second != "connect-ip")) { + // This is not a MASQUE request. + return false; + } + + if (!server_authority_.empty()) { + auto authority_pair = request_headers.find(":authority"); + if (authority_pair == request_headers.end()) { + // Cannot enforce missing authority. + return false; + } + absl::string_view authority = authority_pair->second; + if (server_authority_ != authority) { + // This request does not match server_authority. + return false; + } + } + + auto it = backend_client_states_.find(request_handler->connection_id()); + if (it == backend_client_states_.end()) { + QUIC_LOG(ERROR) << "Could not find backend client for " << masque_path + << request_headers.DebugString(); + return false; + } + + BackendClient* backend_client = it->second.backend_client; + + std::unique_ptr response = + backend_client->HandleMasqueRequest(request_headers, request_handler); + if (response == nullptr) { + QUIC_LOG(ERROR) << "Backend client did not process request for " + << masque_path << request_headers.DebugString(); + return false; + } + + QUIC_DLOG(INFO) << "Sending MASQUE response for " + << request_headers.DebugString(); + + request_handler->OnResponseBackendComplete(response.get()); + it->second.responses.emplace_back(std::move(response)); + + return true; +} + +void MasqueServerBackend::FetchResponseFromBackend( + const spdy::Http2HeaderBlock& request_headers, + const std::string& request_body, + QuicSimpleServerBackend::RequestHandler* request_handler) { + if (MaybeHandleMasqueRequest(request_headers, request_handler)) { + // Request was handled as a MASQUE request. + return; + } + QUIC_DLOG(INFO) << "Fetching non-MASQUE response for " + << request_headers.DebugString(); + QuicMemoryCacheBackend::FetchResponseFromBackend( + request_headers, request_body, request_handler); +} + +void MasqueServerBackend::HandleConnectHeaders( + const spdy::Http2HeaderBlock& request_headers, + RequestHandler* request_handler) { + if (MaybeHandleMasqueRequest(request_headers, request_handler)) { + // Request was handled as a MASQUE request. + return; + } + QUIC_DLOG(INFO) << "Fetching non-MASQUE CONNECT response for " + << request_headers.DebugString(); + QuicMemoryCacheBackend::HandleConnectHeaders(request_headers, + request_handler); +} + +void MasqueServerBackend::CloseBackendResponseStream( + QuicSimpleServerBackend::RequestHandler* request_handler) { + QUIC_DLOG(INFO) << "Closing response stream"; + QuicMemoryCacheBackend::CloseBackendResponseStream(request_handler); +} + +void MasqueServerBackend::RegisterBackendClient(QuicConnectionId connection_id, + BackendClient* backend_client) { + QUIC_DLOG(INFO) << "Registering backend client for " << connection_id; + QUIC_BUG_IF(quic_bug_12005_1, backend_client_states_.find(connection_id) != + backend_client_states_.end()) + << connection_id << " already in backend clients map"; + backend_client_states_[connection_id] = + BackendClientState{backend_client, {}}; +} + +void MasqueServerBackend::RemoveBackendClient(QuicConnectionId connection_id) { + QUIC_DLOG(INFO) << "Removing backend client for " << connection_id; + backend_client_states_.erase(connection_id); +} + +QuicIpAddress MasqueServerBackend::GetNextClientIpAddress() { + // Makes sure all addresses are in 10.(1-254).(1-254).(2-254) + QuicIpAddress address; + address.FromPackedString( + reinterpret_cast(&connect_ip_next_client_ip_[0]), + sizeof(connect_ip_next_client_ip_)); + connect_ip_next_client_ip_[3]++; + if (connect_ip_next_client_ip_[3] >= 255) { + connect_ip_next_client_ip_[3] = 2; + connect_ip_next_client_ip_[2]++; + if (connect_ip_next_client_ip_[2] >= 255) { + connect_ip_next_client_ip_[2] = 1; + connect_ip_next_client_ip_[1]++; + if (connect_ip_next_client_ip_[1] >= 255) { + QUIC_LOG(FATAL) << "Ran out of IP addresses, restarting process."; + } + } + } + return address; +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_server_backend.h b/quiche/quic/masque/masque_server_backend.h new file mode 100644 index 000000000000..50c5a021538a --- /dev/null +++ b/quiche/quic/masque/masque_server_backend.h @@ -0,0 +1,80 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_SERVER_BACKEND_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_SERVER_BACKEND_H_ + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// QUIC server backend that understands MASQUE requests, but otherwise answers +// HTTP queries using an in-memory cache. +class QUIC_NO_EXPORT MasqueServerBackend : public QuicMemoryCacheBackend { + public: + // Interface meant to be implemented by the owner of the MasqueServerBackend + // instance. + class QUIC_NO_EXPORT BackendClient { + public: + virtual std::unique_ptr HandleMasqueRequest( + const spdy::Http2HeaderBlock& request_headers, + QuicSimpleServerBackend::RequestHandler* request_handler) = 0; + virtual ~BackendClient() = default; + }; + + explicit MasqueServerBackend(MasqueMode masque_mode, + const std::string& server_authority, + const std::string& cache_directory); + + // Disallow copy and assign. + MasqueServerBackend(const MasqueServerBackend&) = delete; + MasqueServerBackend& operator=(const MasqueServerBackend&) = delete; + + // From QuicMemoryCacheBackend. + void FetchResponseFromBackend( + const spdy::Http2HeaderBlock& request_headers, + const std::string& request_body, + QuicSimpleServerBackend::RequestHandler* request_handler) override; + void HandleConnectHeaders(const spdy::Http2HeaderBlock& request_headers, + RequestHandler* request_handler) override; + + void CloseBackendResponseStream( + QuicSimpleServerBackend::RequestHandler* request_handler) override; + + // Register backend client that can handle MASQUE requests. + void RegisterBackendClient(QuicConnectionId connection_id, + BackendClient* backend_client); + + // Unregister backend client. + void RemoveBackendClient(QuicConnectionId connection_id); + + // Provides a unique client IP address for each CONNECT-IP client. + QuicIpAddress GetNextClientIpAddress(); + + private: + // Handle MASQUE request. + bool MaybeHandleMasqueRequest( + const spdy::Http2HeaderBlock& request_headers, + QuicSimpleServerBackend::RequestHandler* request_handler); + + MasqueMode masque_mode_; + std::string server_authority_; + + struct QUIC_NO_EXPORT BackendClientState { + BackendClient* backend_client; + std::vector> responses; + }; + absl::flat_hash_map + backend_client_states_; + uint8_t connect_ip_next_client_ip_[4]; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_SERVER_BACKEND_H_ diff --git a/quiche/quic/masque/masque_server_bin.cc b/quiche/quic/masque/masque_server_bin.cc new file mode 100644 index 000000000000..148d640b8fd6 --- /dev/null +++ b/quiche/quic/masque/masque_server_bin.cc @@ -0,0 +1,71 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file is reponsible for the masque_server binary. It allows testing +// our MASQUE server code by creating a MASQUE proxy that relays HTTP/3 +// requests to web servers tunnelled over MASQUE connections. +// e.g.: masque_server + +#include + +#include "quiche/quic/masque/masque_server.h" +#include "quiche/quic/masque/masque_server_backend.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_system_event_loop.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, port, 9661, + "The port the MASQUE server will listen on."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, cache_dir, "", + "Specifies the directory used during QuicHttpResponseCache " + "construction to seed the cache. Cache directory can be " + "generated using `wget -p --save-headers `"); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, server_authority, "", + "Specifies the authority over which the server will accept MASQUE " + "requests. Defaults to empty which allows all authorities."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, masque_mode, "", + "Allows setting MASQUE mode, currently only valid value is \"open\"."); + +int main(int argc, char* argv[]) { + quiche::QuicheSystemEventLoop event_loop("masque_server"); + const char* usage = "Usage: masque_server [options]"; + std::vector non_option_args = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + if (!non_option_args.empty()) { + quiche::QuichePrintCommandLineFlagHelp(usage); + return 0; + } + + quic::MasqueMode masque_mode = quic::MasqueMode::kOpen; + std::string mode_string = quiche::GetQuicheCommandLineFlag(FLAGS_masque_mode); + if (!mode_string.empty() && mode_string != "open") { + std::cerr << "Invalid masque_mode \"" << mode_string << "\"" << std::endl; + return 1; + } + + auto backend = std::make_unique( + masque_mode, quiche::GetQuicheCommandLineFlag(FLAGS_server_authority), + quiche::GetQuicheCommandLineFlag(FLAGS_cache_dir)); + + auto server = + std::make_unique(masque_mode, backend.get()); + + if (!server->CreateUDPSocketAndListen(quic::QuicSocketAddress( + quic::QuicIpAddress::Any6(), + quiche::GetQuicheCommandLineFlag(FLAGS_port)))) { + return 1; + } + + std::cerr << "Started " << masque_mode << " MASQUE server" << std::endl; + server->HandleEventsForever(); + return 0; +} diff --git a/quiche/quic/masque/masque_server_session.cc b/quiche/quic/masque/masque_server_session.cc new file mode 100644 index 000000000000..493b9f6cfdbd --- /dev/null +++ b/quiche/quic/masque/masque_server_session.cc @@ -0,0 +1,642 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_server_session.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/common/capsule.h" +#include "quiche/common/platform/api/quiche_url_utils.h" +#include "quiche/common/quiche_ip_address.h" + +namespace quic { + +namespace { + +using ::quiche::AddressAssignCapsule; +using ::quiche::AddressRequestCapsule; +using ::quiche::Capsule; +using ::quiche::IpAddressRange; +using ::quiche::PrefixWithId; +using ::quiche::RouteAdvertisementCapsule; + +// RAII wrapper for QuicUdpSocketFd. +class FdWrapper { + public: + // Takes ownership of |fd| and closes the file descriptor on destruction. + explicit FdWrapper(int address_family) { + QuicUdpSocketApi socket_api; + fd_ = + socket_api.Create(address_family, + /*receive_buffer_size =*/kDefaultSocketReceiveBuffer, + /*send_buffer_size =*/kDefaultSocketReceiveBuffer); + } + + ~FdWrapper() { + if (fd_ == kQuicInvalidSocketFd) { + return; + } + QuicUdpSocketApi socket_api; + socket_api.Destroy(fd_); + } + + // Hands ownership of the file descriptor to the caller. + QuicUdpSocketFd extract_fd() { + QuicUdpSocketFd fd = fd_; + fd_ = kQuicInvalidSocketFd; + return fd; + } + + // Keeps ownership of the file descriptor. + QuicUdpSocketFd fd() { return fd_; } + + // Disallow copy and move. + FdWrapper(const FdWrapper&) = delete; + FdWrapper(FdWrapper&&) = delete; + FdWrapper& operator=(const FdWrapper&) = delete; + FdWrapper& operator=(FdWrapper&&) = delete; + + private: + QuicUdpSocketFd fd_; +}; + +std::unique_ptr CreateBackendErrorResponse( + absl::string_view status, absl::string_view error_details) { + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = status; + response_headers["masque-debug-info"] = error_details; + auto response = std::make_unique(); + response->set_response_type(QuicBackendResponse::REGULAR_RESPONSE); + response->set_headers(std::move(response_headers)); + return response; +} + +} // namespace + +MasqueServerSession::MasqueServerSession( + MasqueMode masque_mode, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + QuicEventLoop* event_loop, QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + MasqueServerBackend* masque_server_backend) + : QuicSimpleServerSession(config, supported_versions, connection, visitor, + helper, crypto_config, compressed_certs_cache, + masque_server_backend), + masque_server_backend_(masque_server_backend), + event_loop_(event_loop), + masque_mode_(masque_mode) { + // Artificially increase the max packet length to 1350 to ensure we can fit + // QUIC packets inside DATAGRAM frames. + // TODO(b/181606597) Remove this workaround once we use PMTUD. + connection->SetMaxPacketLength(kMasqueMaxOuterPacketSize); + + masque_server_backend_->RegisterBackendClient(connection_id(), this); + QUICHE_DCHECK_NE(event_loop_, nullptr); +} + +void MasqueServerSession::OnMessageAcked(QuicMessageId message_id, + QuicTime /*receive_timestamp*/) { + QUIC_DVLOG(1) << "Received ack for DATAGRAM frame " << message_id; +} + +void MasqueServerSession::OnMessageLost(QuicMessageId message_id) { + QUIC_DVLOG(1) << "We believe DATAGRAM frame " << message_id << " was lost"; +} + +void MasqueServerSession::OnConnectionClosed( + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { + QuicSimpleServerSession::OnConnectionClosed(frame, source); + QUIC_DLOG(INFO) << "Closing connection for " << connection_id(); + masque_server_backend_->RemoveBackendClient(connection_id()); + // Clearing this state will close all sockets. + connect_udp_server_states_.clear(); +} + +void MasqueServerSession::OnStreamClosed(QuicStreamId stream_id) { + connect_udp_server_states_.remove_if( + [stream_id](const ConnectUdpServerState& connect_udp) { + return connect_udp.stream()->id() == stream_id; + }); + connect_ip_server_states_.remove_if( + [stream_id](const ConnectIpServerState& connect_ip) { + return connect_ip.stream()->id() == stream_id; + }); + + QuicSimpleServerSession::OnStreamClosed(stream_id); +} + +std::unique_ptr MasqueServerSession::HandleMasqueRequest( + const spdy::Http2HeaderBlock& request_headers, + QuicSimpleServerBackend::RequestHandler* request_handler) { + auto path_pair = request_headers.find(":path"); + auto scheme_pair = request_headers.find(":scheme"); + auto method_pair = request_headers.find(":method"); + auto protocol_pair = request_headers.find(":protocol"); + auto authority_pair = request_headers.find(":authority"); + if (path_pair == request_headers.end()) { + QUIC_DLOG(ERROR) << "MASQUE request is missing :path"; + return CreateBackendErrorResponse("400", "Missing :path"); + } + if (scheme_pair == request_headers.end()) { + QUIC_DLOG(ERROR) << "MASQUE request is missing :scheme"; + return CreateBackendErrorResponse("400", "Missing :scheme"); + } + if (method_pair == request_headers.end()) { + QUIC_DLOG(ERROR) << "MASQUE request is missing :method"; + return CreateBackendErrorResponse("400", "Missing :method"); + } + if (protocol_pair == request_headers.end()) { + QUIC_DLOG(ERROR) << "MASQUE request is missing :protocol"; + return CreateBackendErrorResponse("400", "Missing :protocol"); + } + if (authority_pair == request_headers.end()) { + QUIC_DLOG(ERROR) << "MASQUE request is missing :authority"; + return CreateBackendErrorResponse("400", "Missing :authority"); + } + absl::string_view path = path_pair->second; + absl::string_view scheme = scheme_pair->second; + absl::string_view method = method_pair->second; + absl::string_view protocol = protocol_pair->second; + absl::string_view authority = authority_pair->second; + if (path.empty()) { + QUIC_DLOG(ERROR) << "MASQUE request with empty path"; + return CreateBackendErrorResponse("400", "Empty path"); + } + if (scheme.empty()) { + return CreateBackendErrorResponse("400", "Empty scheme"); + } + if (method != "CONNECT") { + QUIC_DLOG(ERROR) << "MASQUE request with bad method \"" << method << "\""; + return CreateBackendErrorResponse("400", "Bad method"); + } + if (protocol != "connect-udp" && protocol != "connect-ip") { + QUIC_DLOG(ERROR) << "MASQUE request with bad protocol \"" << protocol + << "\""; + return CreateBackendErrorResponse("400", "Bad protocol"); + } + if (protocol == "connect-ip") { + QuicSpdyStream* stream = static_cast( + GetActiveStream(request_handler->stream_id())); + if (stream == nullptr) { + QUIC_BUG(bad masque server stream type) + << "Unexpected stream type for stream ID " + << request_handler->stream_id(); + return CreateBackendErrorResponse("500", "Bad stream type"); + } + QuicIpAddress client_ip = masque_server_backend_->GetNextClientIpAddress(); + QUIC_DLOG(INFO) << "Using client IP " << client_ip.ToString() + << " for CONNECT-IP stream ID " + << request_handler->stream_id(); + int fd = CreateTunInterface(client_ip); + if (fd < 0) { + QUIC_LOG(ERROR) << "Failed to create TUN interface for stream ID " + << request_handler->stream_id(); + return CreateBackendErrorResponse("500", + "Failed to create TUN interface"); + } + if (!event_loop_->RegisterSocket(fd, kSocketEventReadable, this)) { + QUIC_DLOG(ERROR) << "Failed to register TUN fd with the event loop"; + close(fd); + return CreateBackendErrorResponse("500", "Registering TUN socket failed"); + } + connect_ip_server_states_.push_back( + ConnectIpServerState(client_ip, stream, fd, this)); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + auto response = std::make_unique(); + response->set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + response->set_headers(std::move(response_headers)); + response->set_body(""); + + return response; + } + // Extract target host and port from path using default template. + std::vector path_split = absl::StrSplit(path, '/'); + if (path_split.size() != 7 || !path_split[0].empty() || + path_split[1] != ".well-known" || path_split[2] != "masque" || + path_split[3] != "udp" || path_split[4].empty() || + path_split[5].empty() || !path_split[6].empty()) { + QUIC_DLOG(ERROR) << "MASQUE request with bad path \"" << path << "\""; + return CreateBackendErrorResponse("400", "Bad path"); + } + absl::optional host = quiche::AsciiUrlDecode(path_split[4]); + if (!host.has_value()) { + QUIC_DLOG(ERROR) << "Failed to decode host \"" << path_split[4] << "\""; + return CreateBackendErrorResponse("500", "Failed to decode host"); + } + absl::optional port = quiche::AsciiUrlDecode(path_split[5]); + if (!port.has_value()) { + QUIC_DLOG(ERROR) << "Failed to decode port \"" << path_split[5] << "\""; + return CreateBackendErrorResponse("500", "Failed to decode port"); + } + + // Perform DNS resolution. + addrinfo hint = {}; + hint.ai_protocol = IPPROTO_UDP; + + addrinfo* info_list = nullptr; + int result = getaddrinfo(host.value().c_str(), port.value().c_str(), &hint, + &info_list); + if (result != 0 || info_list == nullptr) { + QUIC_DLOG(ERROR) << "Failed to resolve " << authority << ": " + << gai_strerror(result); + return CreateBackendErrorResponse("500", "DNS resolution failed"); + } + + std::unique_ptr info_list_owned(info_list, + freeaddrinfo); + QuicSocketAddress target_server_address(info_list->ai_addr, + info_list->ai_addrlen); + QUIC_DLOG(INFO) << "Got CONNECT_UDP request on stream ID " + << request_handler->stream_id() << " target_server_address=\"" + << target_server_address << "\""; + + FdWrapper fd_wrapper(target_server_address.host().AddressFamilyToInt()); + if (fd_wrapper.fd() == kQuicInvalidSocketFd) { + QUIC_DLOG(ERROR) << "Socket creation failed"; + return CreateBackendErrorResponse("500", "Socket creation failed"); + } + QuicSocketAddress empty_address(QuicIpAddress::Any6(), 0); + if (target_server_address.host().IsIPv4()) { + empty_address = QuicSocketAddress(QuicIpAddress::Any4(), 0); + } + QuicUdpSocketApi socket_api; + if (!socket_api.Bind(fd_wrapper.fd(), empty_address)) { + QUIC_DLOG(ERROR) << "Socket bind failed"; + return CreateBackendErrorResponse("500", "Socket bind failed"); + } + if (!event_loop_->RegisterSocket(fd_wrapper.fd(), kSocketEventReadable, + this)) { + QUIC_DLOG(ERROR) << "Failed to register socket with the event loop"; + return CreateBackendErrorResponse("500", "Registering socket failed"); + } + + QuicSpdyStream* stream = + static_cast(GetActiveStream(request_handler->stream_id())); + if (stream == nullptr) { + QUIC_BUG(bad masque server stream type) + << "Unexpected stream type for stream ID " + << request_handler->stream_id(); + return CreateBackendErrorResponse("500", "Bad stream type"); + } + connect_udp_server_states_.push_back(ConnectUdpServerState( + stream, target_server_address, fd_wrapper.extract_fd(), this)); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + auto response = std::make_unique(); + response->set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + response->set_headers(std::move(response_headers)); + response->set_body(""); + + return response; +} + +void MasqueServerSession::OnSocketEvent(QuicEventLoop* /*event_loop*/, + QuicUdpSocketFd fd, + QuicSocketEventMask events) { + if ((events & kSocketEventReadable) == 0) { + QUIC_DVLOG(1) << "Ignoring OnEvent fd " << fd << " event mask " << events; + return; + } + auto it = absl::c_find_if(connect_udp_server_states_, + [fd](const ConnectUdpServerState& connect_udp) { + return connect_udp.fd() == fd; + }); + if (it == connect_udp_server_states_.end()) { + auto it2 = absl::c_find_if(connect_ip_server_states_, + [fd](const ConnectIpServerState& connect_ip) { + return connect_ip.fd() == fd; + }); + if (it2 == connect_ip_server_states_.end()) { + QUIC_BUG(quic_bug_10974_1) + << "Got unexpected event mask " << events << " on unknown fd " << fd; + return; + } + + char datagram[1501]; + datagram[0] = 0; // Context ID. + while (true) { + ssize_t read_size = read(fd, datagram + 1, sizeof(datagram) - 1); + if (read_size < 0) { + break; + } + MessageStatus message_status = it2->stream()->SendHttp3Datagram( + absl::string_view(datagram, 1 + read_size)); + QUIC_DVLOG(1) << "Encapsulated IP packet of length " << read_size + << " with stream ID " << it2->stream()->id() + << " and got message status " + << MessageStatusToString(message_status); + } + if (!event_loop_->SupportsEdgeTriggered()) { + if (!event_loop_->RearmSocket(fd, kSocketEventReadable)) { + QUIC_BUG(MasqueServerSession_ConnectIp_OnSocketEvent_Rearm) + << "Failed to re-arm socket " << fd << " for reading"; + } + } + + return; + } + + auto rearm = absl::MakeCleanup([&]() { + if (!event_loop_->SupportsEdgeTriggered()) { + if (!event_loop_->RearmSocket(fd, kSocketEventReadable)) { + QUIC_BUG(MasqueServerSession_OnSocketEvent_Rearm) + << "Failed to re-arm socket " << fd << " for reading"; + } + } + }); + + QuicSocketAddress expected_target_server_address = + it->target_server_address(); + QUICHE_DCHECK(expected_target_server_address.IsInitialized()); + QUIC_DVLOG(1) << "Received readable event on fd " << fd << " (mask " << events + << ") stream ID " << it->stream()->id() << " server " + << expected_target_server_address; + QuicUdpSocketApi socket_api; + BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS); + char packet_buffer[1 + kMaxIncomingPacketSize]; + packet_buffer[0] = 0; // context ID. + char control_buffer[kDefaultUdpPacketControlBufferSize]; + while (true) { + QuicUdpSocketApi::ReadPacketResult read_result; + read_result.packet_buffer = {packet_buffer + 1, sizeof(packet_buffer) - 1}; + read_result.control_buffer = {control_buffer, sizeof(control_buffer)}; + socket_api.ReadPacket(fd, packet_info_interested, &read_result); + if (!read_result.ok) { + // Most likely there is nothing left to read, break out of read loop. + break; + } + if (!read_result.packet_info.HasValue(QuicUdpPacketInfoBit::PEER_ADDRESS)) { + QUIC_BUG(quic_bug_10974_2) + << "Missing peer address when reading from fd " << fd; + continue; + } + if (read_result.packet_info.peer_address() != + expected_target_server_address) { + QUIC_DLOG(ERROR) << "Ignoring UDP packet on fd " << fd + << " from unexpected server address " + << read_result.packet_info.peer_address() + << " (expected " << expected_target_server_address + << ")"; + continue; + } + if (!connection()->connected()) { + QUIC_BUG(quic_bug_10974_3) + << "Unexpected incoming UDP packet on fd " << fd << " from " + << expected_target_server_address + << " because MASQUE connection is closed"; + return; + } + // The packet is valid, send it to the client in a DATAGRAM frame. + MessageStatus message_status = + it->stream()->SendHttp3Datagram(absl::string_view( + packet_buffer, read_result.packet_buffer.buffer_len + 1)); + QUIC_DVLOG(1) << "Sent UDP packet from " << expected_target_server_address + << " of length " << read_result.packet_buffer.buffer_len + << " with stream ID " << it->stream()->id() + << " and got message status " + << MessageStatusToString(message_status); + } +} + +bool MasqueServerSession::OnSettingsFrame(const SettingsFrame& frame) { + QUIC_DLOG(INFO) << "Received SETTINGS: " << frame; + if (!QuicSimpleServerSession::OnSettingsFrame(frame)) { + return false; + } + if (!SupportsH3Datagram()) { + QUIC_DLOG(ERROR) << "Refusing to use MASQUE without HTTP Datagrams"; + return false; + } + QUIC_DLOG(INFO) << "Using HTTP Datagram: " << http_datagram_support(); + return true; +} + +MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState( + QuicSpdyStream* stream, const QuicSocketAddress& target_server_address, + QuicUdpSocketFd fd, MasqueServerSession* masque_session) + : stream_(stream), + target_server_address_(target_server_address), + fd_(fd), + masque_session_(masque_session) { + QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd); + QUICHE_DCHECK_NE(masque_session_, nullptr); + this->stream()->RegisterHttp3DatagramVisitor(this); +} + +MasqueServerSession::ConnectUdpServerState::~ConnectUdpServerState() { + if (stream() != nullptr) { + stream()->UnregisterHttp3DatagramVisitor(); + } + if (fd_ == kQuicInvalidSocketFd) { + return; + } + QuicUdpSocketApi socket_api; + QUIC_DLOG(INFO) << "Closing fd " << fd_; + if (!masque_session_->event_loop()->UnregisterSocket(fd_)) { + QUIC_DLOG(ERROR) << "Failed to unregister FD " << fd_; + } + socket_api.Destroy(fd_); +} + +MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState( + MasqueServerSession::ConnectUdpServerState&& other) { + fd_ = kQuicInvalidSocketFd; + *this = std::move(other); +} + +MasqueServerSession::ConnectUdpServerState& +MasqueServerSession::ConnectUdpServerState::operator=( + MasqueServerSession::ConnectUdpServerState&& other) { + if (fd_ != kQuicInvalidSocketFd) { + QuicUdpSocketApi socket_api; + QUIC_DLOG(INFO) << "Closing fd " << fd_; + if (!masque_session_->event_loop()->UnregisterSocket(fd_)) { + QUIC_DLOG(ERROR) << "Failed to unregister FD " << fd_; + } + socket_api.Destroy(fd_); + } + stream_ = other.stream_; + other.stream_ = nullptr; + target_server_address_ = other.target_server_address_; + fd_ = other.fd_; + masque_session_ = other.masque_session_; + other.fd_ = kQuicInvalidSocketFd; + if (stream() != nullptr) { + stream()->ReplaceHttp3DatagramVisitor(this); + } + return *this; +} + +void MasqueServerSession::ConnectUdpServerState::OnHttp3Datagram( + QuicStreamId stream_id, absl::string_view payload) { + QUICHE_DCHECK_EQ(stream_id, stream()->id()); + QuicDataReader reader(payload); + uint64_t context_id; + if (!reader.ReadVarInt62(&context_id)) { + QUIC_DLOG(ERROR) << "Failed to read context ID"; + return; + } + if (context_id != 0) { + QUIC_DLOG(ERROR) << "Ignoring HTTP Datagram with unexpected context ID " + << context_id; + return; + } + absl::string_view http_payload = reader.ReadRemainingPayload(); + QuicUdpSocketApi socket_api; + QuicUdpPacketInfo packet_info; + packet_info.SetPeerAddress(target_server_address_); + WriteResult write_result = socket_api.WritePacket( + fd_, http_payload.data(), http_payload.length(), packet_info); + QUIC_DVLOG(1) << "Wrote packet of length " << http_payload.length() << " to " + << target_server_address_ << " with result " << write_result; +} + +MasqueServerSession::ConnectIpServerState::ConnectIpServerState( + QuicIpAddress client_ip, QuicSpdyStream* stream, QuicUdpSocketFd fd, + MasqueServerSession* masque_session) + : client_ip_(client_ip), + stream_(stream), + fd_(fd), + masque_session_(masque_session) { + QUICHE_DCHECK(client_ip_.IsIPv4()); + QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd); + QUICHE_DCHECK_NE(masque_session_, nullptr); + this->stream()->RegisterHttp3DatagramVisitor(this); + this->stream()->RegisterConnectIpVisitor(this); +} + +MasqueServerSession::ConnectIpServerState::~ConnectIpServerState() { + if (stream() != nullptr) { + stream()->UnregisterHttp3DatagramVisitor(); + stream()->UnregisterConnectIpVisitor(); + } + if (fd_ == kQuicInvalidSocketFd) { + return; + } + QuicUdpSocketApi socket_api; + QUIC_DLOG(INFO) << "Closing fd " << fd_; + if (!masque_session_->event_loop()->UnregisterSocket(fd_)) { + QUIC_DLOG(ERROR) << "Failed to unregister FD " << fd_; + } + socket_api.Destroy(fd_); +} + +MasqueServerSession::ConnectIpServerState::ConnectIpServerState( + MasqueServerSession::ConnectIpServerState&& other) { + fd_ = kQuicInvalidSocketFd; + *this = std::move(other); +} + +MasqueServerSession::ConnectIpServerState& +MasqueServerSession::ConnectIpServerState::operator=( + MasqueServerSession::ConnectIpServerState&& other) { + if (fd_ != kQuicInvalidSocketFd) { + QuicUdpSocketApi socket_api; + QUIC_DLOG(INFO) << "Closing fd " << fd_; + if (!masque_session_->event_loop()->UnregisterSocket(fd_)) { + QUIC_DLOG(ERROR) << "Failed to unregister FD " << fd_; + } + socket_api.Destroy(fd_); + } + client_ip_ = other.client_ip_; + stream_ = other.stream_; + other.stream_ = nullptr; + fd_ = other.fd_; + masque_session_ = other.masque_session_; + other.fd_ = kQuicInvalidSocketFd; + if (stream() != nullptr) { + stream()->ReplaceHttp3DatagramVisitor(this); + stream()->ReplaceConnectIpVisitor(this); + } + return *this; +} + +void MasqueServerSession::ConnectIpServerState::OnHttp3Datagram( + QuicStreamId stream_id, absl::string_view payload) { + QUICHE_DCHECK_EQ(stream_id, stream()->id()); + QuicDataReader reader(payload); + uint64_t context_id; + if (!reader.ReadVarInt62(&context_id)) { + QUIC_DLOG(ERROR) << "Failed to read context ID"; + return; + } + if (context_id != 0) { + QUIC_DLOG(ERROR) << "Ignoring HTTP Datagram with unexpected context ID " + << context_id; + return; + } + absl::string_view ip_packet = reader.ReadRemainingPayload(); + ssize_t written = write(fd(), ip_packet.data(), ip_packet.size()); + if (written != static_cast(ip_packet.size())) { + QUIC_DLOG(ERROR) << "Failed to write CONNECT-IP packet of length " + << ip_packet.size(); + } else { + QUIC_DLOG(INFO) << "Decapsulated CONNECT-IP packet of length " + << ip_packet.size(); + } +} + +bool MasqueServerSession::ConnectIpServerState::OnAddressAssignCapsule( + const AddressAssignCapsule& capsule) { + QUIC_DLOG(INFO) << "Ignoring received capsule " << capsule.ToString(); + return true; +} + +bool MasqueServerSession::ConnectIpServerState::OnAddressRequestCapsule( + const AddressRequestCapsule& capsule) { + QUIC_DLOG(INFO) << "Ignoring received capsule " << capsule.ToString(); + return true; +} + +bool MasqueServerSession::ConnectIpServerState::OnRouteAdvertisementCapsule( + const RouteAdvertisementCapsule& capsule) { + QUIC_DLOG(INFO) << "Ignoring received capsule " << capsule.ToString(); + return true; +} + +void MasqueServerSession::ConnectIpServerState::OnHeadersWritten() { + QUICHE_DCHECK(client_ip_.IsIPv4()) << client_ip_.ToString(); + Capsule address_assign_capsule = Capsule::AddressAssign(); + PrefixWithId assigned_address; + assigned_address.ip_prefix = quiche::QuicheIpPrefix(client_ip_, 32); + assigned_address.request_id = 0; + address_assign_capsule.address_assign_capsule().assigned_addresses.push_back( + assigned_address); + stream()->WriteCapsule(address_assign_capsule); + IpAddressRange default_route; + default_route.start_ip_address.FromString("0.0.0.0"); + default_route.end_ip_address.FromString("255.255.255.255"); + default_route.ip_protocol = 0; + Capsule route_advertisement = Capsule::RouteAdvertisement(); + route_advertisement.route_advertisement_capsule().ip_address_ranges.push_back( + default_route); + stream()->WriteCapsule(route_advertisement); +} + +} // namespace quic diff --git a/quiche/quic/masque/masque_server_session.h b/quiche/quic/masque/masque_server_session.h new file mode 100644 index 000000000000..2dbed18fdc00 --- /dev/null +++ b/quiche/quic/masque/masque_server_session.h @@ -0,0 +1,155 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_ + +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/masque/masque_server_backend.h" +#include "quiche/quic/masque/masque_utils.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/tools/quic_simple_server_session.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// QUIC server session for connection to MASQUE proxy. +class QUIC_NO_EXPORT MasqueServerSession + : public QuicSimpleServerSession, + public MasqueServerBackend::BackendClient, + public QuicSocketEventListener { + public: + explicit MasqueServerSession( + MasqueMode masque_mode, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + QuicEventLoop* event_loop, QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + MasqueServerBackend* masque_server_backend); + + // Disallow copy and assign. + MasqueServerSession(const MasqueServerSession&) = delete; + MasqueServerSession& operator=(const MasqueServerSession&) = delete; + + // From QuicSession. + void OnMessageAcked(QuicMessageId message_id, + QuicTime receive_timestamp) override; + void OnMessageLost(QuicMessageId message_id) override; + void OnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) override; + void OnStreamClosed(QuicStreamId stream_id) override; + + // From MasqueServerBackend::BackendClient. + std::unique_ptr HandleMasqueRequest( + const spdy::Http2HeaderBlock& request_headers, + QuicSimpleServerBackend::RequestHandler* request_handler) override; + + // From QuicSocketEventListener. + void OnSocketEvent(QuicEventLoop* event_loop, QuicUdpSocketFd fd, + QuicSocketEventMask events) override; + + QuicEventLoop* event_loop() const { return event_loop_; } + + private: + // State that the MasqueServerSession keeps for each CONNECT-UDP request. + class QUIC_NO_EXPORT ConnectUdpServerState + : public QuicSpdyStream::Http3DatagramVisitor { + public: + // ConnectUdpServerState takes ownership of |fd|. It will unregister it + // from |event_loop| and close the file descriptor when destructed. + explicit ConnectUdpServerState( + QuicSpdyStream* stream, const QuicSocketAddress& target_server_address, + QuicUdpSocketFd fd, MasqueServerSession* masque_session); + + ~ConnectUdpServerState(); + + // Disallow copy but allow move. + ConnectUdpServerState(const ConnectUdpServerState&) = delete; + ConnectUdpServerState(ConnectUdpServerState&&); + ConnectUdpServerState& operator=(const ConnectUdpServerState&) = delete; + ConnectUdpServerState& operator=(ConnectUdpServerState&&); + + QuicSpdyStream* stream() const { return stream_; } + const QuicSocketAddress& target_server_address() const { + return target_server_address_; + } + QuicUdpSocketFd fd() const { return fd_; } + + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + void OnUnknownCapsule(QuicStreamId /*stream_id*/, + const quiche::UnknownCapsule& /*capsule*/) override {} + + private: + QuicSpdyStream* stream_; + QuicSocketAddress target_server_address_; + QuicUdpSocketFd fd_; // Owned. + MasqueServerSession* masque_session_; // Unowned. + }; + + // State that the MasqueServerSession keeps for each CONNECT-IP request. + class QUIC_NO_EXPORT ConnectIpServerState + : public QuicSpdyStream::Http3DatagramVisitor, + public QuicSpdyStream::ConnectIpVisitor { + public: + // ConnectIpServerState takes ownership of |fd|. It will unregister it + // from |event_loop| and close the file descriptor when destructed. + explicit ConnectIpServerState(QuicIpAddress client_ip, + QuicSpdyStream* stream, QuicUdpSocketFd fd, + MasqueServerSession* masque_session); + + ~ConnectIpServerState(); + + // Disallow copy but allow move. + ConnectIpServerState(const ConnectIpServerState&) = delete; + ConnectIpServerState(ConnectIpServerState&&); + ConnectIpServerState& operator=(const ConnectIpServerState&) = delete; + ConnectIpServerState& operator=(ConnectIpServerState&&); + + QuicSpdyStream* stream() const { return stream_; } + QuicUdpSocketFd fd() const { return fd_; } + + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + void OnUnknownCapsule(QuicStreamId /*stream_id*/, + const quiche::UnknownCapsule& /*capsule*/) override {} + + // From QuicSpdyStream::ConnectIpVisitor. + bool OnAddressAssignCapsule( + const quiche::AddressAssignCapsule& capsule) override; + bool OnAddressRequestCapsule( + const quiche::AddressRequestCapsule& capsule) override; + bool OnRouteAdvertisementCapsule( + const quiche::RouteAdvertisementCapsule& capsule) override; + void OnHeadersWritten() override; + + private: + QuicIpAddress client_ip_; + QuicSpdyStream* stream_; + QuicUdpSocketFd fd_; // Owned. + MasqueServerSession* masque_session_; // Unowned. + }; + + // From QuicSpdySession. + bool OnSettingsFrame(const SettingsFrame& frame) override; + HttpDatagramSupport LocalHttpDatagramSupport() override { + return HttpDatagramSupport::kRfc; + } + + MasqueServerBackend* masque_server_backend_; // Unowned. + QuicEventLoop* event_loop_; // Unowned. + MasqueMode masque_mode_; + std::list connect_udp_server_states_; + std::list connect_ip_server_states_; + bool masque_initialized_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_SERVER_SESSION_H_ diff --git a/quiche/quic/masque/masque_utils.cc b/quiche/quic/masque/masque_utils.cc new file mode 100644 index 000000000000..b31eccb77dab --- /dev/null +++ b/quiche/quic/masque/masque_utils.cc @@ -0,0 +1,150 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/masque/masque_utils.h" + +#if defined(__linux__) +#include +#include +#include +#include +#endif // defined(__linux__) + +namespace quic { + +ParsedQuicVersionVector MasqueSupportedVersions() { + QuicVersionInitializeSupportForIetfDraft(); + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + // Use all versions that support IETF QUIC except QUICv2. + if (version.UsesHttp3() && !version.AlpnDeferToRFCv1()) { + QuicEnableVersion(version); + versions.push_back(version); + } + } + QUICHE_CHECK(!versions.empty()); + return versions; +} + +QuicConfig MasqueEncapsulatedConfig() { + QuicConfig config; + config.SetMaxPacketSizeToSend(kMasqueMaxEncapsulatedPacketSize); + return config; +} + +std::string MasqueModeToString(MasqueMode masque_mode) { + switch (masque_mode) { + case MasqueMode::kInvalid: + return "Invalid"; + case MasqueMode::kOpen: + return "Open"; + case MasqueMode::kConnectIp: + return "CONNECT-IP"; + } + return absl::StrCat("Unknown(", static_cast(masque_mode), ")"); +} + +std::ostream& operator<<(std::ostream& os, const MasqueMode& masque_mode) { + os << MasqueModeToString(masque_mode); + return os; +} + +#if defined(__linux__) +int CreateTunInterface(const QuicIpAddress& client_address, bool server) { + if (!client_address.IsIPv4()) { + QUIC_LOG(ERROR) << "CreateTunInterface currently only supports IPv4"; + return -1; + } + int tun_fd = open("/dev/net/tun", O_RDWR); + int ip_fd = -1; + do { + if (tun_fd < 0) { + QUIC_PLOG(ERROR) << "Failed to open clone device"; + break; + } + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TUN | IFF_NO_PI; + // If we want to pick a specific device name, we can set it via + // ifr.ifr_name. Otherwise, the kernel will pick the next available tunX + // name. + int err = ioctl(tun_fd, TUNSETIFF, &ifr); + if (err < 0) { + QUIC_PLOG(ERROR) << "TUNSETIFF failed"; + break; + } + ip_fd = socket(AF_INET, SOCK_DGRAM, 0); + if (ip_fd < 0) { + QUIC_PLOG(ERROR) << "Failed to open IP configuration socket"; + break; + } + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + // Local address, unused but needs to be set. We use the same address as the + // client address, but with last byte set to 1. + addr.sin_addr = client_address.GetIPv4(); + if (server) { + addr.sin_addr.s_addr &= htonl(0xffffff00); + addr.sin_addr.s_addr |= htonl(0x00000001); + } + memcpy(&ifr.ifr_addr, &addr, sizeof(addr)); + err = ioctl(ip_fd, SIOCSIFADDR, &ifr); + if (err < 0) { + QUIC_PLOG(ERROR) << "SIOCSIFADDR failed"; + break; + } + // Peer address, needs to match source IP address of sent packets. + addr.sin_addr = client_address.GetIPv4(); + if (!server) { + addr.sin_addr.s_addr &= htonl(0xffffff00); + addr.sin_addr.s_addr |= htonl(0x00000001); + } + memcpy(&ifr.ifr_addr, &addr, sizeof(addr)); + err = ioctl(ip_fd, SIOCSIFDSTADDR, &ifr); + if (err < 0) { + QUIC_PLOG(ERROR) << "SIOCSIFDSTADDR failed"; + break; + } + if (!server) { + // Set MTU, to 1280 for now which should always fit (fingers crossed) + ifr.ifr_mtu = 1280; + err = ioctl(ip_fd, SIOCSIFMTU, &ifr); + if (err < 0) { + QUIC_PLOG(ERROR) << "SIOCSIFMTU failed"; + break; + } + } + + err = ioctl(ip_fd, SIOCGIFFLAGS, &ifr); + if (err < 0) { + QUIC_PLOG(ERROR) << "SIOCGIFFLAGS failed"; + break; + } + ifr.ifr_flags |= (IFF_UP | IFF_RUNNING); + err = ioctl(ip_fd, SIOCSIFFLAGS, &ifr); + if (err < 0) { + QUIC_PLOG(ERROR) << "SIOCSIFFLAGS failed"; + break; + } + close(ip_fd); + QUIC_DLOG(INFO) << "Successfully created TUN interface " << ifr.ifr_name + << " with fd " << tun_fd; + return tun_fd; + } while (false); + if (tun_fd >= 0) { + close(tun_fd); + } + if (ip_fd >= 0) { + close(ip_fd); + } + return -1; +} +#else +int CreateTunInterface(const QuicIpAddress& /*client_address*/, + bool /*server*/) { + // Unsupported. + return -1; +} +#endif // defined(__linux__) + +} // namespace quic diff --git a/quiche/quic/masque/masque_utils.h b/quiche/quic/masque/masque_utils.h new file mode 100644 index 000000000000..1743092a1f42 --- /dev/null +++ b/quiche/quic/masque/masque_utils.h @@ -0,0 +1,48 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_MASQUE_MASQUE_UTILS_H_ +#define QUICHE_QUIC_MASQUE_MASQUE_UTILS_H_ + +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" + +namespace quic { + +// List of QUIC versions that support MASQUE. Currently restricted to IETF QUIC. +QUIC_NO_EXPORT ParsedQuicVersionVector MasqueSupportedVersions(); + +// Default QuicConfig for use with MASQUE. Sets a custom max_packet_size. +QUIC_NO_EXPORT QuicConfig MasqueEncapsulatedConfig(); + +// Maximum packet size for encapsulated connections. +enum : QuicByteCount { + kMasqueMaxEncapsulatedPacketSize = 1250, + kMasqueMaxOuterPacketSize = 1350, +}; + +// Mode that MASQUE is operating in. +enum class MasqueMode : uint8_t { + kInvalid = 0, // Should never be used. + kOpen = 2, // Open mode uses the MASQUE HTTP CONNECT-UDP method as documented + // in . This mode allows + // unauthenticated clients (a more restricted mode will be added to this enum + // at a later date). + kConnectIp = + 1, // ConnectIp mode uses MASQUE HTTP CONNECT-IP as documented in + // . This + // mode also allows unauthenticated clients. +}; + +QUIC_NO_EXPORT std::string MasqueModeToString(MasqueMode masque_mode); +QUIC_NO_EXPORT std::ostream& operator<<(std::ostream& os, + const MasqueMode& masque_mode); + +// Create a TUN interface, with the specified `client_address`. Requires root. +int CreateTunInterface(const QuicIpAddress& client_address, bool server = true); + +} // namespace quic + +#endif // QUICHE_QUIC_MASQUE_MASQUE_UTILS_H_ diff --git a/quiche/quic/platform/README.md b/quiche/quic/platform/README.md new file mode 100644 index 000000000000..6538de108cf7 --- /dev/null +++ b/quiche/quic/platform/README.md @@ -0,0 +1,12 @@ +# QUIC platform + +This platform/ directory exists in order to allow QUIC code to be built on +numerous platforms. It contains two subdirectories: + +- api/ contains platform independent class definitions for fundamental data + structures (e.g., IPAddress, SocketAddress, etc.). +- impl/ contains platform specific implementations of these data structures. + The content of files in impl/ will vary depending on the platform. + +Code in the parent quic/ directory should not depend on any platform specific +code, other than that found in impl/. diff --git a/quiche/quic/platform/api/README.md b/quiche/quic/platform/api/README.md new file mode 100644 index 000000000000..117e424d00c9 --- /dev/null +++ b/quiche/quic/platform/api/README.md @@ -0,0 +1,72 @@ +# QUIC platform API + +This directory contains the infrastructure blocks needed to support QUIC in +certain platform. These APIs act as interaction layers between QUIC core and +either the upper layer application (i.e. Chrome, Envoy) or the platform's own +infrastructure (i.e. logging, test framework and system IO). QUIC core needs the +implementations of these APIs to build and function appropriately. There is +unidirectional dependency from QUIC core to most of the APIs here, such as +QUIC_LOG and QuicMutex, but a few APIs also depend back on QUIC core's basic +QUIC data types, such as QuicClock and QuicSleep. + +- APIs used by QUIC core: + + Most APIs are used by QUIC core to interact with platform infrastructure + (i.e. QUIC_LOG) or to wrap around platform dependent data types (i.e. + QuicThread), the dependency is: + +``` +application -> quic_core -> quic_platform_api + | | + v v +platform_infrastructure <- quic_platform_impl +``` + +- APIs used by applications: + + Some APIs are used by applications to interact with QUIC core (i.e. + QuicMemSlice). For such APIs, their dependency model is: + +``` +application -> quic_core -> quic_platform_api + | ^ + | | + -------------------> quic_platform_impl + | | + | v + -------------------> platform_infrastructure +``` + +        An example for such dependency +is QuicClock. + +        Or + +``` +application -> quic_core -> quic_platform_api + | ^ + | | + | v + -------------------> quic_platform_impl + | | + | v + -------------------> platform_infrastructure +``` + +        An example for such dependency +is QuicMemSlice. + +# Documentation of each API and its usage. + +QuicMemSlice +: QuicMemSlice is used to wrap application data and pass to QUIC stream's + write interface. It refers to a memory block of data which should be around + till QuicMemSlice::Reset() is called. It's upto each platform, to implement + it as reference counted or not. + +QuicClock +: QuicClock is used by QUIC core to get current time. Its instance is created + by applications and passed into QuicDispatcher and + QuicConnectionHelperInterface. + +TODO(b/131224336) add document for other APIs diff --git a/quiche/quic/platform/api/quic_bug_tracker.h b/quiche/quic/platform/api/quic_bug_tracker.h new file mode 100644 index 000000000000..f84d220630e3 --- /dev/null +++ b/quiche/quic/platform/api/quic_bug_tracker.h @@ -0,0 +1,15 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_BUG_TRACKER_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_BUG_TRACKER_H_ + +#include "quiche/common/platform/api/quiche_bug_tracker.h" + +#define QUIC_BUG QUICHE_BUG +#define QUIC_BUG_IF QUICHE_BUG_IF +#define QUIC_PEER_BUG QUICHE_PEER_BUG +#define QUIC_PEER_BUG_IF QUICHE_PEER_BUG_IF + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_BUG_TRACKER_H_ diff --git a/quiche/quic/platform/api/quic_client_stats.h b/quiche/quic/platform/api/quic_client_stats.h new file mode 100644 index 000000000000..d18d6142acd5 --- /dev/null +++ b/quiche/quic/platform/api/quic_client_stats.h @@ -0,0 +1,87 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_CLIENT_STATS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_CLIENT_STATS_H_ + +#include + +#include "quiche/common/platform/api/quiche_client_stats.h" + +namespace quic { + +//------------------------------------------------------------------------------ +// Enumeration histograms. +// +// Sample usage: +// // In Chrome, these values are persisted to logs. Entries should not be +// // renumbered and numeric values should never be reused. +// enum class MyEnum { +// FIRST_VALUE = 0, +// SECOND_VALUE = 1, +// ... +// FINAL_VALUE = N, +// COUNT +// }; +// QUIC_CLIENT_HISTOGRAM_ENUM("My.Enumeration", MyEnum::SOME_VALUE, +// MyEnum::COUNT, "Number of time $foo equals to some enum value"); +// +// Note: The value in |sample| must be strictly less than |enum_size|. + +#define QUIC_CLIENT_HISTOGRAM_ENUM(name, sample, enum_size, docstring) \ + QUICHE_CLIENT_HISTOGRAM_ENUM(name, sample, enum_size, docstring) + +//------------------------------------------------------------------------------ +// Histogram for boolean values. + +// Sample usage: +// QUIC_CLIENT_HISTOGRAM_BOOL("My.Boolean", bool, +// "Number of times $foo is true or false"); +#define QUIC_CLIENT_HISTOGRAM_BOOL(name, sample, docstring) \ + QUICHE_CLIENT_HISTOGRAM_BOOL(name, sample, docstring) + +//------------------------------------------------------------------------------ +// Timing histograms. These are used for collecting timing data (generally +// latencies). + +// These macros create exponentially sized histograms (lengths of the bucket +// ranges exponentially increase as the sample range increases). The units for +// sample and max are unspecified, but they must be the same for one histogram. + +// Sample usage: +// QUIC_CLIENT_HISTOGRAM_TIMES("Very.Long.Timing.Histogram", time_delta, +// QuicTime::Delta::FromSeconds(1), QuicTime::Delta::FromSecond(3600 * +// 24), 100, "Time spent in doing operation."); +#define QUIC_CLIENT_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_CLIENT_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, docstring) + +//------------------------------------------------------------------------------ +// Count histograms. These are used for collecting numeric data. + +// These macros default to exponential histograms - i.e. the lengths of the +// bucket ranges exponentially increase as the sample range increases. + +// All of these macros must be called with |name| as a runtime constant. + +// Any data outside the range here will be put in underflow and overflow +// buckets. Min values should be >=1 as emitted 0s will still go into the +// underflow bucket. + +// Sample usage: +// UMA_CLIENT_HISTOGRAM_CUSTOM_COUNTS("My.Histogram", 1, 100000000, 100, +// "Counters of hitting certian code."); + +#define QUIC_CLIENT_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_CLIENT_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring) + +inline void QuicClientSparseHistogram(const std::string& name, int sample) { + quiche::QuicheClientSparseHistogram(name, sample); +} + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_CLIENT_STATS_H_ diff --git a/quiche/quic/platform/api/quic_default_proof_providers.h b/quiche/quic/platform/api/quic_default_proof_providers.h new file mode 100644 index 000000000000..59d9052baa99 --- /dev/null +++ b/quiche/quic/platform/api/quic_default_proof_providers.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_DEFAULT_PROOF_PROVIDERS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_DEFAULT_PROOF_PROVIDERS_H_ + +#include +#include + +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/common/platform/api/quiche_default_proof_providers.h" + +namespace quic { + +// Provides a default proof verifier that can verify a cert chain for |host|. +// The verifier has to do a good faith attempt at verifying the certificate +// against a reasonable root store, and not just always return success. +inline std::unique_ptr CreateDefaultProofVerifier( + const std::string& host) { + return quiche::CreateDefaultProofVerifier(host); +} + +// Provides a default proof source for CLI-based tools. The actual certificates +// used in the proof source should be confifgurable via command-line flags. +inline std::unique_ptr CreateDefaultProofSource() { + return quiche::CreateDefaultProofSource(); +} + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_DEFAULT_PROOF_PROVIDERS_H_ diff --git a/quiche/quic/platform/api/quic_expect_bug.h b/quiche/quic/platform/api/quic_expect_bug.h new file mode 100644 index 000000000000..02ba38a20ea3 --- /dev/null +++ b/quiche/quic/platform/api/quic_expect_bug.h @@ -0,0 +1,14 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_EXPECT_BUG_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_EXPECT_BUG_H_ + +#include "quiche/common/platform/api/quiche_expect_bug.h" + +#define EXPECT_QUIC_BUG EXPECT_QUICHE_BUG +#define EXPECT_QUIC_PEER_BUG(statement, regex) \ + EXPECT_QUICHE_PEER_BUG(statement, regex) + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_EXPECT_BUG_H_ diff --git a/quiche/quic/platform/api/quic_export.h b/quiche/quic/platform/api/quic_export.h new file mode 100644 index 000000000000..13e43ca1de36 --- /dev/null +++ b/quiche/quic/platform/api/quic_export.h @@ -0,0 +1,21 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_EXPORT_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_EXPORT_H_ + +#include "quiche/common/platform/api/quiche_export.h" + +// QUIC_EXPORT is not meant to be used. +#define QUIC_EXPORT QUICHE_EXPORT + +// QUIC_EXPORT_PRIVATE is meant for QUIC functionality that is built in Chromium +// as part of //net, and not fully contained in headers. +#define QUIC_EXPORT_PRIVATE QUICHE_EXPORT + +// QUIC_NO_EXPORT is meant for QUIC functionality that is either fully defined +// in a header, or is built in Chromium as part of tests or tools. +#define QUIC_NO_EXPORT QUICHE_NO_EXPORT + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_EXPORT_H_ diff --git a/quiche/quic/platform/api/quic_exported_stats.h b/quiche/quic/platform/api/quic_exported_stats.h new file mode 100644 index 000000000000..3e7f4a757fca --- /dev/null +++ b/quiche/quic/platform/api/quic_exported_stats.h @@ -0,0 +1,96 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_EXPORTED_STATS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_EXPORTED_STATS_H_ + +#include "quiche/quic/platform/api/quic_client_stats.h" +#include "quiche/quic/platform/api/quic_server_stats.h" + +namespace quic { + +// TODO(wub): Add support for counters. Only histograms are supported for now. + +//------------------------------------------------------------------------------ +// Enumeration histograms. +// +// Sample usage: +// // In Chrome, these values are persisted to logs. Entries should not be +// // renumbered and numeric values should never be reused. +// enum class MyEnum { +// FIRST_VALUE = 0, +// SECOND_VALUE = 1, +// ... +// FINAL_VALUE = N, +// COUNT +// }; +// QUIC_HISTOGRAM_ENUM("My.Enumeration", MyEnum::SOME_VALUE, MyEnum::COUNT, +// "Number of time $foo equals to some enum value"); +// +// Note: The value in |sample| must be strictly less than |enum_size|. + +#define QUIC_HISTOGRAM_ENUM(name, sample, enum_size, docstring) \ + do { \ + QUIC_CLIENT_HISTOGRAM_ENUM(name, sample, enum_size, docstring); \ + QUIC_SERVER_HISTOGRAM_ENUM(name, sample, enum_size, docstring); \ + } while (0) + +//------------------------------------------------------------------------------ +// Histogram for boolean values. + +// Sample usage: +// QUIC_HISTOGRAM_BOOL("My.Boolean", bool, +// "Number of times $foo is true or false"); +#define QUIC_HISTOGRAM_BOOL(name, sample, docstring) \ + do { \ + QUIC_CLIENT_HISTOGRAM_BOOL(name, sample, docstring); \ + QUIC_SERVER_HISTOGRAM_BOOL(name, sample, docstring); \ + } while (0) + +//------------------------------------------------------------------------------ +// Timing histograms. These are used for collecting timing data (generally +// latencies). + +// These macros create exponentially sized histograms (lengths of the bucket +// ranges exponentially increase as the sample range increases). The units for +// sample and max are unspecified, but they must be the same for one histogram. + +// Sample usage: +// QUIC_HISTOGRAM_TIMES("My.Timing.Histogram.InMs", time_delta, +// QuicTime::Delta::FromSeconds(1), QuicTime::Delta::FromSecond(3600 * +// 24), 100, "Time spent in doing operation."); + +#define QUIC_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, docstring) \ + do { \ + QUIC_CLIENT_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, \ + docstring); \ + QUIC_SERVER_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, \ + docstring); \ + } while (0) + +//------------------------------------------------------------------------------ +// Count histograms. These are used for collecting numeric data. + +// These macros default to exponential histograms - i.e. the lengths of the +// bucket ranges exponentially increase as the sample range increases. + +// All of these macros must be called with |name| as a runtime constant. + +// Sample usage: +// QUIC_HISTOGRAM_COUNTS("My.Histogram", +// sample, // Number of something in this event. +// 1000, // Record up to 1K of something. +// "Number of something."); + +#define QUIC_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, docstring) \ + do { \ + QUIC_CLIENT_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring); \ + QUIC_SERVER_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring); \ + } while (0) + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_EXPORTED_STATS_H_ diff --git a/quiche/quic/platform/api/quic_flag_utils.h b/quiche/quic/platform/api/quic_flag_utils.h new file mode 100644 index 000000000000..c8e179dede65 --- /dev/null +++ b/quiche/quic/platform/api/quic_flag_utils.h @@ -0,0 +1,19 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_FLAG_UTILS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_FLAG_UTILS_H_ + +#include "quiche/common/platform/api/quiche_flag_utils.h" + +#define QUIC_RELOADABLE_FLAG_COUNT QUICHE_RELOADABLE_FLAG_COUNT +#define QUIC_RELOADABLE_FLAG_COUNT_N QUICHE_RELOADABLE_FLAG_COUNT_N + +#define QUIC_RESTART_FLAG_COUNT QUICHE_RESTART_FLAG_COUNT +#define QUIC_RESTART_FLAG_COUNT_N QUICHE_RESTART_FLAG_COUNT_N + +#define QUIC_CODE_COUNT QUICHE_CODE_COUNT +#define QUIC_CODE_COUNT_N QUICHE_CODE_COUNT_N + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_FLAG_UTILS_H_ diff --git a/quiche/quic/platform/api/quic_flags.h b/quiche/quic/platform/api/quic_flags.h new file mode 100644 index 000000000000..3b36c8fafb05 --- /dev/null +++ b/quiche/quic/platform/api/quic_flags.h @@ -0,0 +1,21 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_FLAGS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_FLAGS_H_ + +#include +#include + +#include "quiche/common/platform/api/quiche_flags.h" + +#define GetQuicReloadableFlag(flag) GetQuicheReloadableFlag(quic, flag) +#define SetQuicReloadableFlag(flag, value) \ + SetQuicheReloadableFlag(quic, flag, value) +#define GetQuicRestartFlag(flag) GetQuicheRestartFlag(quic, flag) +#define SetQuicRestartFlag(flag, value) SetQuicheRestartFlag(quic, flag, value) +#define GetQuicFlag(flag) GetQuicheFlag(flag) +#define SetQuicFlag(flag, value) SetQuicheFlag(flag, value) + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_FLAGS_H_ diff --git a/quiche/quic/platform/api/quic_hostname_utils.h b/quiche/quic/platform/api/quic_hostname_utils.h new file mode 100644 index 000000000000..69072d56cdef --- /dev/null +++ b/quiche/quic/platform/api/quic_hostname_utils.h @@ -0,0 +1,16 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_HOSTNAME_UTILS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_HOSTNAME_UTILS_H_ + +#include "quiche/common/platform/api/quiche_hostname_utils.h" + +namespace quic { + +using QuicHostnameUtils = quiche::QuicheHostnameUtils; + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_HOSTNAME_UTILS_H_ diff --git a/quiche/quic/platform/api/quic_ip_address.h b/quiche/quic/platform/api/quic_ip_address.h new file mode 100644 index 000000000000..7f2ed20d5717 --- /dev/null +++ b/quiche/quic/platform/api/quic_ip_address.h @@ -0,0 +1,16 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_IP_ADDRESS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_IP_ADDRESS_H_ + +#include "quiche/common/quiche_ip_address.h" + +namespace quic { + +using QuicIpAddress = ::quiche::QuicheIpAddress; + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_IP_ADDRESS_H_ diff --git a/quiche/quic/platform/api/quic_ip_address_family.h b/quiche/quic/platform/api/quic_ip_address_family.h new file mode 100644 index 000000000000..0aeac174d5aa --- /dev/null +++ b/quiche/quic/platform/api/quic_ip_address_family.h @@ -0,0 +1,16 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_IP_ADDRESS_FAMILY_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_IP_ADDRESS_FAMILY_H_ + +#include "quiche/common/quiche_ip_address_family.h" + +namespace quic { + +using IpAddressFamily = ::quiche::IpAddressFamily; + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_IP_ADDRESS_FAMILY_H_ diff --git a/quiche/quic/platform/api/quic_logging.h b/quiche/quic/platform/api/quic_logging.h new file mode 100644 index 000000000000..c44e53f54a47 --- /dev/null +++ b/quiche/quic/platform/api/quic_logging.h @@ -0,0 +1,32 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_LOGGING_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_LOGGING_H_ + +#include "quiche/common/platform/api/quiche_logging.h" + +// Please note following QUIC_LOG are platform dependent: +// INFO severity can be degraded (to VLOG(1) or DVLOG(1)). +// Some platforms may not support QUIC_LOG_FIRST_N or QUIC_LOG_EVERY_N_SEC, and +// they would simply be translated to LOG. + +#define QUIC_DVLOG QUICHE_DVLOG +#define QUIC_DVLOG_IF QUICHE_DVLOG_IF +#define QUIC_DLOG QUICHE_DLOG +#define QUIC_DLOG_IF QUICHE_DLOG_IF +#define QUIC_VLOG QUICHE_VLOG +#define QUIC_LOG QUICHE_LOG +#define QUIC_LOG_FIRST_N QUICHE_LOG_FIRST_N +#define QUIC_LOG_EVERY_N_SEC QUICHE_LOG_EVERY_N_SEC +#define QUIC_LOG_IF QUICHE_LOG_IF + +#define QUIC_PLOG QUICHE_PLOG + +#define QUIC_DLOG_INFO_IS_ON QUICHE_DLOG_INFO_IS_ON +#define QUIC_LOG_INFO_IS_ON QUICHE_LOG_INFO_IS_ON +#define QUIC_LOG_WARNING_IS_ON QUICHE_LOG_WARNING_IS_ON +#define QUIC_LOG_ERROR_IS_ON QUICHE_LOG_ERROR_IS_ON + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_LOGGING_H_ diff --git a/quiche/quic/platform/api/quic_mutex.h b/quiche/quic/platform/api/quic_mutex.h new file mode 100644 index 000000000000..1e6f36af4b3a --- /dev/null +++ b/quiche/quic/platform/api/quic_mutex.h @@ -0,0 +1,32 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// TODO(b/178613777): Remove this file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_MUTEX_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_MUTEX_H_ + +#include "quiche/common/platform/api/quiche_mutex.h" + +#define QUIC_EXCLUSIVE_LOCKS_REQUIRED QUICHE_EXCLUSIVE_LOCKS_REQUIRED +#define QUIC_GUARDED_BY QUICHE_GUARDED_BY +#define QUIC_LOCKABLE QUICHE_LOCKABLE +#define QUIC_LOCKS_EXCLUDED QUICHE_LOCKS_EXCLUDED +#define QUIC_SHARED_LOCKS_REQUIRED QUICHE_SHARED_LOCKS_REQUIRED +#define QUIC_EXCLUSIVE_LOCK_FUNCTION QUICHE_EXCLUSIVE_LOCK_FUNCTION +#define QUIC_UNLOCK_FUNCTION QUICHE_UNLOCK_FUNCTION +#define QUIC_SHARED_LOCK_FUNCTION QUICHE_SHARED_LOCK_FUNCTION +#define QUIC_SCOPED_LOCKABLE QUICHE_SCOPED_LOCKABLE +#define QUIC_ASSERT_SHARED_LOCK QUICHE_ASSERT_SHARED_LOCK + +namespace quic { + +using QuicMutex = quiche::QuicheMutex; +using QuicReaderMutexLock = quiche::QuicheReaderMutexLock; +using QuicWriterMutexLock = quiche::QuicheWriterMutexLock; +using QuicNotification = quiche::QuicheNotification; + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_MUTEX_H_ diff --git a/quiche/quic/platform/api/quic_server_stats.h b/quiche/quic/platform/api/quic_server_stats.h new file mode 100644 index 000000000000..8e373464b529 --- /dev/null +++ b/quiche/quic/platform/api/quic_server_stats.h @@ -0,0 +1,25 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_SERVER_STATS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_SERVER_STATS_H_ + +#include "quiche/common/platform/api/quiche_server_stats.h" + +#define QUIC_SERVER_HISTOGRAM_ENUM(name, sample, enum_size, docstring) \ + QUICHE_SERVER_HISTOGRAM_ENUM(name, sample, enum_size, docstring) + +#define QUIC_SERVER_HISTOGRAM_BOOL(name, sample, docstring) \ + QUICHE_SERVER_HISTOGRAM_BOOL(name, sample, docstring) + +#define QUIC_SERVER_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_SERVER_HISTOGRAM_TIMES(name, sample, min, max, bucket_count, docstring) + +#define QUIC_SERVER_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring) \ + QUICHE_SERVER_HISTOGRAM_COUNTS(name, sample, min, max, bucket_count, \ + docstring) + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_SERVER_STATS_H_ diff --git a/quiche/quic/platform/api/quic_socket_address.cc b/quiche/quic/platform/api/quic_socket_address.cc new file mode 100644 index 000000000000..f14faea0f5ab --- /dev/null +++ b/quiche/quic/platform/api/quic_socket_address.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/platform/api/quic_socket_address.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" + +namespace quic { + +namespace { + +uint32_t HashIP(const QuicIpAddress& ip) { + if (ip.IsIPv4()) { + return ip.GetIPv4().s_addr; + } + if (ip.IsIPv6()) { + auto v6addr = ip.GetIPv6(); + const uint32_t* v6_as_ints = + reinterpret_cast(&v6addr.s6_addr); + return v6_as_ints[0] ^ v6_as_ints[1] ^ v6_as_ints[2] ^ v6_as_ints[3]; + } + return 0; +} + +} // namespace + +QuicSocketAddress::QuicSocketAddress(QuicIpAddress address, uint16_t port) + : host_(address), port_(port) {} + +QuicSocketAddress::QuicSocketAddress(const struct sockaddr_storage& saddr) { + switch (saddr.ss_family) { + case AF_INET: { + const sockaddr_in* v4 = reinterpret_cast(&saddr); + host_ = QuicIpAddress(v4->sin_addr); + port_ = ntohs(v4->sin_port); + break; + } + case AF_INET6: { + const sockaddr_in6* v6 = reinterpret_cast(&saddr); + host_ = QuicIpAddress(v6->sin6_addr); + port_ = ntohs(v6->sin6_port); + break; + } + default: + QUIC_BUG(quic_bug_10075_1) + << "Unknown address family passed: " << saddr.ss_family; + break; + } +} + +QuicSocketAddress::QuicSocketAddress(const sockaddr* saddr, socklen_t len) { + sockaddr_storage storage; + static_assert(std::numeric_limits::max() >= sizeof(storage), + "Cannot cast sizeof(storage) to socklen_t as it does not fit"); + if (len < static_cast(sizeof(sockaddr)) || + (saddr->sa_family == AF_INET && + len < static_cast(sizeof(sockaddr_in))) || + (saddr->sa_family == AF_INET6 && + len < static_cast(sizeof(sockaddr_in6))) || + len > static_cast(sizeof(storage))) { + QUIC_BUG(quic_bug_10075_2) << "Socket address of invalid length provided"; + return; + } + memcpy(&storage, saddr, len); + *this = QuicSocketAddress(storage); +} + +bool operator==(const QuicSocketAddress& lhs, const QuicSocketAddress& rhs) { + return lhs.host_ == rhs.host_ && lhs.port_ == rhs.port_; +} + +bool operator!=(const QuicSocketAddress& lhs, const QuicSocketAddress& rhs) { + return !(lhs == rhs); +} + +bool QuicSocketAddress::IsInitialized() const { return host_.IsInitialized(); } + +std::string QuicSocketAddress::ToString() const { + switch (host_.address_family()) { + case IpAddressFamily::IP_V4: + return absl::StrCat(host_.ToString(), ":", port_); + case IpAddressFamily::IP_V6: + return absl::StrCat("[", host_.ToString(), "]:", port_); + default: + return ""; + } +} + +int QuicSocketAddress::FromSocket(int fd) { + sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + int result = getsockname(fd, reinterpret_cast(&addr), &addr_len); + + bool success = result == 0 && addr_len > 0 && + static_cast(addr_len) <= sizeof(addr); + if (success) { + *this = QuicSocketAddress(addr); + return 0; + } + return -1; +} + +QuicSocketAddress QuicSocketAddress::Normalized() const { + return QuicSocketAddress(host_.Normalized(), port_); +} + +QuicIpAddress QuicSocketAddress::host() const { return host_; } + +uint16_t QuicSocketAddress::port() const { return port_; } + +sockaddr_storage QuicSocketAddress::generic_address() const { + union { + sockaddr_storage storage; + sockaddr_in v4; + sockaddr_in6 v6; + } result; + memset(&result.storage, 0, sizeof(result.storage)); + + switch (host_.address_family()) { + case IpAddressFamily::IP_V4: + result.v4.sin_family = AF_INET; + result.v4.sin_addr = host_.GetIPv4(); + result.v4.sin_port = htons(port_); + break; + case IpAddressFamily::IP_V6: + result.v6.sin6_family = AF_INET6; + result.v6.sin6_addr = host_.GetIPv6(); + result.v6.sin6_port = htons(port_); + break; + default: + result.storage.ss_family = AF_UNSPEC; + break; + } + return result.storage; +} + +uint32_t QuicSocketAddress::Hash() const { + uint32_t value = 0; + value ^= HashIP(host_); + value ^= port_ | (port_ << 16); + return value; +} + +} // namespace quic diff --git a/quiche/quic/platform/api/quic_socket_address.h b/quiche/quic/platform/api/quic_socket_address.h new file mode 100644 index 000000000000..626a54de86f1 --- /dev/null +++ b/quiche/quic/platform/api/quic_socket_address.h @@ -0,0 +1,67 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_SOCKET_ADDRESS_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_SOCKET_ADDRESS_H_ + +#include + +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/platform/api/quic_ip_address.h" + +namespace quic { + +// A class representing a socket endpoint address (i.e., IP address plus a +// port) in QUIC. +class QUIC_EXPORT_PRIVATE QuicSocketAddress { + public: + QuicSocketAddress() {} + QuicSocketAddress(QuicIpAddress address, uint16_t port); + explicit QuicSocketAddress(const struct sockaddr_storage& saddr); + explicit QuicSocketAddress(const sockaddr* saddr, socklen_t len); + QuicSocketAddress(const QuicSocketAddress& other) = default; + QuicSocketAddress& operator=(const QuicSocketAddress& other) = default; + QuicSocketAddress& operator=(QuicSocketAddress&& other) = default; + QUIC_EXPORT_PRIVATE friend bool operator==(const QuicSocketAddress& lhs, + const QuicSocketAddress& rhs); + QUIC_EXPORT_PRIVATE friend bool operator!=(const QuicSocketAddress& lhs, + const QuicSocketAddress& rhs); + + bool IsInitialized() const; + std::string ToString() const; + + // TODO(ericorth): Convert usage over to socket_api::GetSocketAddress() and + // remove. + int FromSocket(int fd); + + QuicSocketAddress Normalized() const; + + QuicIpAddress host() const; + uint16_t port() const; + sockaddr_storage generic_address() const; + + // Hashes this address to an uint32_t. + uint32_t Hash() const; + + private: + QuicIpAddress host_; + uint16_t port_ = 0; +}; + +inline std::ostream& operator<<(std::ostream& os, + const QuicSocketAddress address) { + os << address.ToString(); + return os; +} + +class QUIC_EXPORT_PRIVATE QuicSocketAddressHash { + public: + size_t operator()(QuicSocketAddress const& address) const noexcept { + return address.Hash(); + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_SOCKET_ADDRESS_H_ diff --git a/quiche/quic/platform/api/quic_socket_address_test.cc b/quiche/quic/platform/api/quic_socket_address_test.cc new file mode 100644 index 000000000000..9512c9f227a4 --- /dev/null +++ b/quiche/quic/platform/api/quic_socket_address_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/platform/api/quic_socket_address.h" + +#include +#include + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace { + +TEST(QuicSocketAddress, Uninitialized) { + QuicSocketAddress uninitialized; + EXPECT_FALSE(uninitialized.IsInitialized()); +} + +TEST(QuicSocketAddress, ExplicitConstruction) { + QuicSocketAddress ipv4_address(QuicIpAddress::Loopback4(), 443); + QuicSocketAddress ipv6_address(QuicIpAddress::Loopback6(), 443); + EXPECT_TRUE(ipv4_address.IsInitialized()); + EXPECT_EQ("127.0.0.1:443", ipv4_address.ToString()); + EXPECT_EQ("[::1]:443", ipv6_address.ToString()); + EXPECT_EQ(QuicIpAddress::Loopback4(), ipv4_address.host()); + EXPECT_EQ(QuicIpAddress::Loopback6(), ipv6_address.host()); + EXPECT_EQ(443, ipv4_address.port()); +} + +TEST(QuicSocketAddress, OutputToStream) { + QuicSocketAddress ipv4_address(QuicIpAddress::Loopback4(), 443); + std::stringstream stream; + stream << ipv4_address; + EXPECT_EQ("127.0.0.1:443", stream.str()); +} + +TEST(QuicSocketAddress, FromSockaddrIPv4) { + union { + sockaddr_storage storage; + sockaddr addr; + sockaddr_in v4; + } address; + + memset(&address, 0, sizeof(address)); + address.v4.sin_family = AF_INET; + address.v4.sin_addr = QuicIpAddress::Loopback4().GetIPv4(); + address.v4.sin_port = htons(443); + EXPECT_EQ("127.0.0.1:443", + QuicSocketAddress(&address.addr, sizeof(address.v4)).ToString()); + EXPECT_EQ("127.0.0.1:443", QuicSocketAddress(address.storage).ToString()); +} + +TEST(QuicSocketAddress, FromSockaddrIPv6) { + union { + sockaddr_storage storage; + sockaddr addr; + sockaddr_in6 v6; + } address; + + memset(&address, 0, sizeof(address)); + address.v6.sin6_family = AF_INET6; + address.v6.sin6_addr = QuicIpAddress::Loopback6().GetIPv6(); + address.v6.sin6_port = htons(443); + EXPECT_EQ("[::1]:443", + QuicSocketAddress(&address.addr, sizeof(address.v6)).ToString()); + EXPECT_EQ("[::1]:443", QuicSocketAddress(address.storage).ToString()); +} + +TEST(QuicSocketAddres, ToSockaddrIPv4) { + union { + sockaddr_storage storage; + sockaddr_in v4; + } address; + + address.storage = + QuicSocketAddress(QuicIpAddress::Loopback4(), 443).generic_address(); + ASSERT_EQ(AF_INET, address.v4.sin_family); + EXPECT_EQ(QuicIpAddress::Loopback4(), QuicIpAddress(address.v4.sin_addr)); + EXPECT_EQ(htons(443), address.v4.sin_port); +} + +TEST(QuicSocketAddress, Normalize) { + QuicIpAddress dual_stacked; + ASSERT_TRUE(dual_stacked.FromString("::ffff:127.0.0.1")); + ASSERT_TRUE(dual_stacked.IsIPv6()); + QuicSocketAddress not_normalized(dual_stacked, 443); + QuicSocketAddress normalized = not_normalized.Normalized(); + EXPECT_EQ("[::ffff:127.0.0.1]:443", not_normalized.ToString()); + EXPECT_EQ("127.0.0.1:443", normalized.ToString()); +} + +// TODO(vasilvv): either ensure this works on all platforms, or deprecate and +// remove this API. +#if defined(__linux__) && !defined(ANDROID) +#include +#include +#include + +TEST(QuicSocketAddress, FromSocket) { + int fd; + QuicSocketAddress address; + bool bound = false; + for (int port = 50000; port < 50400; port++) { + fd = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + ASSERT_GT(fd, 0); + + address = QuicSocketAddress(QuicIpAddress::Loopback6(), port); + sockaddr_storage raw_address = address.generic_address(); + int bind_result = bind(fd, reinterpret_cast(&raw_address), + sizeof(sockaddr_in6)); + + if (bind_result < 0 && errno == EADDRINUSE) { + close(fd); + continue; + } + + ASSERT_EQ(0, bind_result); + bound = true; + break; + } + ASSERT_TRUE(bound); + + QuicSocketAddress real_address; + ASSERT_EQ(0, real_address.FromSocket(fd)); + ASSERT_TRUE(real_address.IsInitialized()); + EXPECT_EQ(real_address, address); + close(fd); +} +#endif + +} // namespace +} // namespace quic diff --git a/quiche/quic/platform/api/quic_stack_trace.h b/quiche/quic/platform/api/quic_stack_trace.h new file mode 100644 index 000000000000..eb8beab42ecd --- /dev/null +++ b/quiche/quic/platform/api/quic_stack_trace.h @@ -0,0 +1,18 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_STACK_TRACE_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_STACK_TRACE_H_ + +#include + +#include "quiche/common/platform/api/quiche_stack_trace.h" + +namespace quic { + +inline std::string QuicStackTrace() { return quiche::QuicheStackTrace(); } + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_STACK_TRACE_H_ diff --git a/quiche/quic/platform/api/quic_test.h b/quiche/quic/platform/api/quic_test.h new file mode 100644 index 000000000000..0d1d2b0b1663 --- /dev/null +++ b/quiche/quic/platform/api/quic_test.h @@ -0,0 +1,26 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_TEST_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_TEST_H_ + +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quic::test { + +using QuicFlagSaver = quiche::test::QuicheFlagSaver; + +// Defines the base classes to be used in QUIC tests. +using QuicTest = quiche::test::QuicheTest; +template +using QuicTestWithParam = quiche::test::QuicheTestWithParam; + +} // namespace quic::test + +#define QUIC_TEST_DISABLED_IN_CHROME(name) QUICHE_TEST_DISABLED_IN_CHROME(name) + +#define QUIC_SLOW_TEST(test) QUICHE_SLOW_TEST(test) + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_TEST_H_ diff --git a/quiche/quic/platform/api/quic_test_loopback.h b/quiche/quic/platform/api/quic_test_loopback.h new file mode 100644 index 000000000000..2e9e42f04b41 --- /dev/null +++ b/quiche/quic/platform/api/quic_test_loopback.h @@ -0,0 +1,38 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_TEST_LOOPBACK_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_TEST_LOOPBACK_H_ + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/common/platform/api/quiche_test_loopback.h" + +namespace quic { + +// Returns the address family (IPv4 or IPv6) used to run test under. +inline IpAddressFamily AddressFamilyUnderTest() { + return quiche::AddressFamilyUnderTest(); +} + +// Returns an IPv4 loopback address. +inline QuicIpAddress TestLoopback4() { return quiche::TestLoopback4(); } + +// Returns the only IPv6 loopback address. +inline QuicIpAddress TestLoopback6() { return quiche::TestLoopback6(); } + +// Returns an appropriate IPv4/Ipv6 loopback address based upon whether the +// test's environment. +inline QuicIpAddress TestLoopback() { return quiche::TestLoopback(); } + +// If address family under test is IPv4, returns an indexed IPv4 loopback +// address. If address family under test is IPv6, the address returned is +// platform-dependent. +inline QuicIpAddress TestLoopback(int index) { + return quiche::TestLoopback(index); +} + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_TEST_LOOPBACK_H_ diff --git a/quiche/quic/platform/api/quic_test_output.h b/quiche/quic/platform/api/quic_test_output.h new file mode 100644 index 000000000000..8d4b54a72c84 --- /dev/null +++ b/quiche/quic/platform/api/quic_test_output.h @@ -0,0 +1,28 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_TEST_OUTPUT_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_TEST_OUTPUT_H_ + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test_output.h" + +namespace quic { + +inline void QuicSaveTestOutput(absl::string_view filename, + absl::string_view data) { + quiche::QuicheSaveTestOutput(filename, data); +} + +inline bool QuicLoadTestOutput(absl::string_view filename, std::string* data) { + return quiche::QuicheLoadTestOutput(filename, data); +} + +inline void QuicRecordTrace(absl::string_view identifier, + absl::string_view data) { + quiche::QuicheRecordTrace(identifier, data); +} + +} // namespace quic +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_TEST_OUTPUT_H_ diff --git a/quiche/quic/platform/api/quic_testvalue.h b/quiche/quic/platform/api/quic_testvalue.h new file mode 100644 index 000000000000..aa6074c11d0a --- /dev/null +++ b/quiche/quic/platform/api/quic_testvalue.h @@ -0,0 +1,22 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// TODO(b/178613777): Remove this file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_TESTVALUE_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_TESTVALUE_H_ + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_testvalue.h" + +namespace quic { + +template +void AdjustTestValue(absl::string_view label, T* var) { + quiche::AdjustTestValue(label, var); +} + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_TESTVALUE_H_ diff --git a/quiche/quic/platform/api/quic_thread.h b/quiche/quic/platform/api/quic_thread.h new file mode 100644 index 000000000000..4633306604db --- /dev/null +++ b/quiche/quic/platform/api/quic_thread.h @@ -0,0 +1,16 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_THREAD_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_THREAD_H_ + +#include "quiche/common/platform/api/quiche_thread.h" + +namespace quic { + +using QuicThread = quiche::QuicheThread; + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_THREAD_H_ diff --git a/quiche/quic/platform/api/quic_udp_socket_platform_api.h b/quiche/quic/platform/api/quic_udp_socket_platform_api.h new file mode 100644 index 000000000000..16b1025e39dc --- /dev/null +++ b/quiche/quic/platform/api/quic_udp_socket_platform_api.h @@ -0,0 +1,27 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_PLATFORM_API_QUIC_UDP_SOCKET_PLATFORM_API_H_ +#define QUICHE_QUIC_PLATFORM_API_QUIC_UDP_SOCKET_PLATFORM_API_H_ + +#include "quiche/common/platform/api/quiche_udp_socket_platform_api.h" + +namespace quic { + +const size_t kCmsgSpaceForGooglePacketHeader = + quiche::kCmsgSpaceForGooglePacketHeader; + +inline bool GetGooglePacketHeadersFromControlMessage( + struct ::cmsghdr* cmsg, char** packet_headers, size_t* packet_headers_len) { + return quiche::GetGooglePacketHeadersFromControlMessage(cmsg, packet_headers, + packet_headers_len); +} + +inline void SetGoogleSocketOptions(int fd) { + quiche::SetGoogleSocketOptions(fd); +} + +} // namespace quic + +#endif // QUICHE_QUIC_PLATFORM_API_QUIC_UDP_SOCKET_PLATFORM_API_H_ diff --git a/quiche/quic/qbone/bonnet/icmp_reachable.cc b/quiche/quic/qbone/bonnet/icmp_reachable.cc new file mode 100644 index 000000000000..8f85190c0ee3 --- /dev/null +++ b/quiche/quic/qbone/bonnet/icmp_reachable.cc @@ -0,0 +1,209 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/icmp_reachable.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/qbone/platform/icmp_packet.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { + +constexpr QuicSocketEventMask kEventMask = + kSocketEventReadable | kSocketEventWritable; +constexpr size_t kMtu = 1280; + +constexpr size_t kIPv6AddrSize = sizeof(in6_addr); + +} // namespace + +const char kUnknownSource[] = "UNKNOWN"; +const char kNoSource[] = "N/A"; + +IcmpReachable::IcmpReachable(QuicIpAddress source, QuicIpAddress destination, + QuicTime::Delta timeout, KernelInterface* kernel, + QuicEventLoop* event_loop, StatsInterface* stats) + : timeout_(timeout), + event_loop_(event_loop), + clock_(event_loop->GetClock()), + alarm_factory_(event_loop->CreateAlarmFactory()), + cb_(this), + alarm_(alarm_factory_->CreateAlarm(new AlarmCallback(this))), + kernel_(kernel), + stats_(stats), + send_fd_(0), + recv_fd_(0) { + src_.sin6_family = AF_INET6; + dst_.sin6_family = AF_INET6; + + memcpy(&src_.sin6_addr, source.ToPackedString().data(), kIPv6AddrSize); + memcpy(&dst_.sin6_addr, destination.ToPackedString().data(), kIPv6AddrSize); +} + +IcmpReachable::~IcmpReachable() { + if (send_fd_ > 0) { + kernel_->close(send_fd_); + } + if (recv_fd_ > 0) { + bool success = event_loop_->UnregisterSocket(recv_fd_); + QUICHE_DCHECK(success); + + kernel_->close(recv_fd_); + } +} + +bool IcmpReachable::Init() { + send_fd_ = kernel_->socket(PF_INET6, SOCK_RAW | SOCK_NONBLOCK, IPPROTO_RAW); + if (send_fd_ < 0) { + QUIC_LOG(ERROR) << "Unable to open socket: " << errno; + return false; + } + + if (kernel_->bind(send_fd_, reinterpret_cast(&src_), + sizeof(sockaddr_in6)) < 0) { + QUIC_LOG(ERROR) << "Unable to bind socket: " << errno; + return false; + } + + recv_fd_ = + kernel_->socket(PF_INET6, SOCK_RAW | SOCK_NONBLOCK, IPPROTO_ICMPV6); + if (recv_fd_ < 0) { + QUIC_LOG(ERROR) << "Unable to open socket: " << errno; + return false; + } + + if (kernel_->bind(recv_fd_, reinterpret_cast(&src_), + sizeof(sockaddr_in6)) < 0) { + QUIC_LOG(ERROR) << "Unable to bind socket: " << errno; + return false; + } + + icmp6_filter filter; + ICMP6_FILTER_SETBLOCKALL(&filter); + ICMP6_FILTER_SETPASS(ICMP6_ECHO_REPLY, &filter); + if (kernel_->setsockopt(recv_fd_, SOL_ICMPV6, ICMP6_FILTER, &filter, + sizeof(filter)) < 0) { + QUIC_LOG(ERROR) << "Unable to set ICMP6 filter."; + return false; + } + + if (!event_loop_->RegisterSocket(recv_fd_, kEventMask, &cb_)) { + QUIC_LOG(ERROR) << "Unable to register recv ICMP socket"; + return false; + } + alarm_->Set(clock_->Now()); + + QuicWriterMutexLock mu(&header_lock_); + icmp_header_.icmp6_type = ICMP6_ECHO_REQUEST; + icmp_header_.icmp6_code = 0; + + QuicRandom::GetInstance()->RandBytes(&icmp_header_.icmp6_id, + sizeof(uint16_t)); + + return true; +} + +bool IcmpReachable::OnEvent(int fd) { + char buffer[kMtu]; + + sockaddr_in6 source_addr{}; + socklen_t source_addr_len = sizeof(source_addr); + + ssize_t size = kernel_->recvfrom(fd, &buffer, kMtu, 0, + reinterpret_cast(&source_addr), + &source_addr_len); + + if (size < 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + stats_->OnReadError(errno); + } + return false; + } + + QUIC_VLOG(2) << quiche::QuicheTextUtils::HexDump( + absl::string_view(buffer, size)); + + auto* header = reinterpret_cast(&buffer); + QuicWriterMutexLock mu(&header_lock_); + if (header->icmp6_data32[0] != icmp_header_.icmp6_data32[0]) { + QUIC_VLOG(2) << "Unexpected response. id: " << header->icmp6_id + << " seq: " << header->icmp6_seq + << " Expected id: " << icmp_header_.icmp6_id + << " seq: " << icmp_header_.icmp6_seq; + return true; + } + end_ = clock_->Now(); + QUIC_VLOG(1) << "Received ping response in " << (end_ - start_); + + std::string source; + QuicIpAddress source_ip; + if (!source_ip.FromPackedString( + reinterpret_cast(&source_addr.sin6_addr), sizeof(in6_addr))) { + QUIC_LOG(WARNING) << "Unable to parse source address."; + source = kUnknownSource; + } else { + source = source_ip.ToString(); + } + stats_->OnEvent({Status::REACHABLE, end_ - start_, source}); + return true; +} + +void IcmpReachable::OnAlarm() { + QuicWriterMutexLock mu(&header_lock_); + + if (end_ < start_) { + QUIC_VLOG(1) << "Timed out on sequence: " << icmp_header_.icmp6_seq; + stats_->OnEvent({Status::UNREACHABLE, QuicTime::Delta::Zero(), kNoSource}); + } + + icmp_header_.icmp6_seq++; + CreateIcmpPacket(src_.sin6_addr, dst_.sin6_addr, icmp_header_, "", + [this](absl::string_view packet) { + QUIC_VLOG(2) << quiche::QuicheTextUtils::HexDump(packet); + + ssize_t size = kernel_->sendto( + send_fd_, packet.data(), packet.size(), 0, + reinterpret_cast(&dst_), + sizeof(sockaddr_in6)); + + if (size < packet.size()) { + stats_->OnWriteError(errno); + } + start_ = clock_->Now(); + }); + + alarm_->Set(clock_->ApproximateNow() + timeout_); +} + +absl::string_view IcmpReachable::StatusName(IcmpReachable::Status status) { + switch (status) { + case REACHABLE: + return "REACHABLE"; + case UNREACHABLE: + return "UNREACHABLE"; + default: + return "UNKNOWN"; + } +} + +void IcmpReachable::EpollCallback::OnSocketEvent(QuicEventLoop* event_loop, + QuicUdpSocketFd fd, + QuicSocketEventMask events) { + bool can_read_more = reachable_->OnEvent(fd); + if (can_read_more) { + bool success = + event_loop->ArtificiallyNotifyEvent(fd, kSocketEventReadable); + QUICHE_DCHECK(success); + } +} + +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/icmp_reachable.h b/quiche/quic/qbone/bonnet/icmp_reachable.h new file mode 100644 index 000000000000..529ccc713278 --- /dev/null +++ b/quiche/quic/qbone/bonnet/icmp_reachable.h @@ -0,0 +1,146 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_ICMP_REACHABLE_H_ +#define QUICHE_QUIC_QBONE_BONNET_ICMP_REACHABLE_H_ + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/qbone/bonnet/icmp_reachable_interface.h" +#include "quiche/quic/qbone/platform/kernel_interface.h" + +namespace quic { + +extern const char kUnknownSource[]; +extern const char kNoSource[]; + +// IcmpReachable schedules itself with an EpollServer, periodically sending +// ICMPv6 Echo Requests to the given |destination| on the interface that the +// given |source| is bound to. Echo Requests are sent once every |timeout|. +// On Echo Replies, timeouts, and I/O errors, the given |stats| object will +// be called back with details of the event. +class IcmpReachable : public IcmpReachableInterface { + public: + enum Status { REACHABLE, UNREACHABLE }; + + struct ReachableEvent { + Status status; + QuicTime::Delta response_time; + std::string source; + }; + + class StatsInterface { + public: + StatsInterface() = default; + + StatsInterface(const StatsInterface&) = delete; + StatsInterface& operator=(const StatsInterface&) = delete; + + StatsInterface(StatsInterface&&) = delete; + StatsInterface& operator=(StatsInterface&&) = delete; + + virtual ~StatsInterface() = default; + + virtual void OnEvent(ReachableEvent event) = 0; + + virtual void OnReadError(int error) = 0; + + virtual void OnWriteError(int error) = 0; + }; + + // |source| is the IPv6 address bound to the interface that IcmpReachable will + // send Echo Requests on. + // |destination| is the IPv6 address of the destination of the Echo Requests. + // |timeout| is the duration IcmpReachable will wait between Echo Requests. + // If no Echo Response is received by the next Echo Request, it will + // be considered a timeout. + // |kernel| is not owned, but should outlive this instance. + // |epoll_server| is not owned, but should outlive this instance. + // IcmpReachable's Init() must be called from within the Epoll + // Server's thread. + // |stats| is not owned, but should outlive this instance. It will be called + // back on Echo Replies, timeouts, and I/O errors. + IcmpReachable(QuicIpAddress source, QuicIpAddress destination, + QuicTime::Delta timeout, KernelInterface* kernel, + QuicEventLoop* event_loop, StatsInterface* stats); + + ~IcmpReachable() override; + + // Initializes this reachability probe. Must be called from within the + // |epoll_server|'s thread. + bool Init() QUIC_LOCKS_EXCLUDED(header_lock_) override; + + void OnAlarm() QUIC_LOCKS_EXCLUDED(header_lock_); + + static absl::string_view StatusName(Status status); + + private: + class EpollCallback : public QuicSocketEventListener { + public: + explicit EpollCallback(IcmpReachable* reachable) : reachable_(reachable) {} + + EpollCallback(const EpollCallback&) = delete; + EpollCallback& operator=(const EpollCallback&) = delete; + + EpollCallback(EpollCallback&&) = delete; + EpollCallback& operator=(EpollCallback&&) = delete; + + void OnSocketEvent(QuicEventLoop* event_loop, QuicUdpSocketFd fd, + QuicSocketEventMask events) override; + + private: + IcmpReachable* reachable_; + }; + + class AlarmCallback : public QuicAlarm::DelegateWithoutContext { + public: + explicit AlarmCallback(IcmpReachable* reachable) : reachable_(reachable) {} + + void OnAlarm() override { reachable_->OnAlarm(); } + + private: + IcmpReachable* reachable_; + }; + + bool OnEvent(int fd) QUIC_LOCKS_EXCLUDED(header_lock_); + + const QuicTime::Delta timeout_; + + QuicEventLoop* event_loop_; + const QuicClock* clock_; + std::unique_ptr alarm_factory_; + + EpollCallback cb_; + std::unique_ptr alarm_; + + sockaddr_in6 src_{}; + sockaddr_in6 dst_{}; + + KernelInterface* kernel_; + + StatsInterface* stats_; + + int send_fd_; + int recv_fd_; + + QuicMutex header_lock_; + icmp6_hdr icmp_header_ QUIC_GUARDED_BY(header_lock_){}; + + QuicTime start_ = QuicTime::Zero(); + QuicTime end_ = QuicTime::Zero(); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_ICMP_REACHABLE_H_ diff --git a/quiche/quic/qbone/bonnet/icmp_reachable_interface.h b/quiche/quic/qbone/bonnet/icmp_reachable_interface.h new file mode 100644 index 000000000000..2426670e4453 --- /dev/null +++ b/quiche/quic/qbone/bonnet/icmp_reachable_interface.h @@ -0,0 +1,27 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_ICMP_REACHABLE_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_BONNET_ICMP_REACHABLE_INTERFACE_H_ + +namespace quic { + +class IcmpReachableInterface { + public: + IcmpReachableInterface() = default; + virtual ~IcmpReachableInterface() = default; + + IcmpReachableInterface(const IcmpReachableInterface&) = delete; + IcmpReachableInterface& operator=(const IcmpReachableInterface&) = delete; + + IcmpReachableInterface(IcmpReachableInterface&&) = delete; + IcmpReachableInterface& operator=(IcmpReachableInterface&&) = delete; + + // Initializes this reachability probe. + virtual bool Init() = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_ICMP_REACHABLE_INTERFACE_H_ diff --git a/quiche/quic/qbone/bonnet/icmp_reachable_test.cc b/quiche/quic/qbone/bonnet/icmp_reachable_test.cc new file mode 100644 index 000000000000..ae48ddc9bb69 --- /dev/null +++ b/quiche/quic/qbone/bonnet/icmp_reachable_test.cc @@ -0,0 +1,261 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/icmp_reachable.h" + +#include + +#include + +#include "absl/container/node_hash_map.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/platform/mock_kernel.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::StrictMock; + +constexpr char kSourceAddress[] = "fe80:1:2:3:4::1"; +constexpr char kDestinationAddress[] = "fe80:4:3:2:1::1"; + +constexpr int kFakeWriteFd = 0; + +icmp6_hdr GetHeaderFromPacket(const void* buf, size_t len) { + QUICHE_CHECK_GE(len, sizeof(ip6_hdr) + sizeof(icmp6_hdr)); + + auto* buffer = reinterpret_cast(buf); + return *reinterpret_cast(&buffer[sizeof(ip6_hdr)]); +} + +class StatsInterface : public IcmpReachable::StatsInterface { + public: + void OnEvent(IcmpReachable::ReachableEvent event) override { + switch (event.status) { + case IcmpReachable::REACHABLE: { + reachable_count_++; + break; + } + case IcmpReachable::UNREACHABLE: { + unreachable_count_++; + break; + } + } + current_source_ = event.source; + } + + void OnReadError(int error) override { read_errors_[error]++; } + + void OnWriteError(int error) override { write_errors_[error]++; } + + bool HasWriteErrors() { return !write_errors_.empty(); } + + int WriteErrorCount(int error) { return write_errors_[error]; } + + bool HasReadErrors() { return !read_errors_.empty(); } + + int ReadErrorCount(int error) { return read_errors_[error]; } + + int reachable_count() { return reachable_count_; } + + int unreachable_count() { return unreachable_count_; } + + std::string current_source() { return current_source_; } + + private: + int reachable_count_ = 0; + int unreachable_count_ = 0; + + std::string current_source_{}; + + absl::node_hash_map read_errors_; + absl::node_hash_map write_errors_; +}; + +class IcmpReachableTest : public QuicTest { + public: + IcmpReachableTest() + : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())) { + QUICHE_CHECK(source_.FromString(kSourceAddress)); + QUICHE_CHECK(destination_.FromString(kDestinationAddress)); + + int pipe_fds[2]; + QUICHE_CHECK(pipe(pipe_fds) >= 0) << "pipe() failed"; + + read_fd_ = pipe_fds[0]; + read_src_fd_ = pipe_fds[1]; + } + + void SetFdExpectations() { + InSequence seq; + EXPECT_CALL(kernel_, socket(_, _, _)).WillOnce(Return(kFakeWriteFd)); + EXPECT_CALL(kernel_, bind(kFakeWriteFd, _, _)).WillOnce(Return(0)); + + EXPECT_CALL(kernel_, socket(_, _, _)).WillOnce(Return(read_fd_)); + EXPECT_CALL(kernel_, bind(read_fd_, _, _)).WillOnce(Return(0)); + + EXPECT_CALL(kernel_, setsockopt(read_fd_, SOL_ICMPV6, ICMP6_FILTER, _, _)); + + EXPECT_CALL(kernel_, close(read_fd_)).WillOnce(Invoke([](int fd) { + return close(fd); + })); + } + + protected: + QuicIpAddress source_; + QuicIpAddress destination_; + + int read_fd_; + int read_src_fd_; + + StrictMock kernel_; + std::unique_ptr event_loop_; + StatsInterface stats_; +}; + +TEST_F(IcmpReachableTest, SendsPings) { + IcmpReachable reachable(source_, destination_, QuicTime::Delta::Zero(), + &kernel_, event_loop_.get(), &stats_); + + SetFdExpectations(); + ASSERT_TRUE(reachable.Init()); + + EXPECT_CALL(kernel_, sendto(kFakeWriteFd, _, _, _, _, _)) + .WillOnce(Invoke([](int sockfd, const void* buf, size_t len, int flags, + const struct sockaddr* dest_addr, socklen_t addrlen) { + auto icmp_header = GetHeaderFromPacket(buf, len); + EXPECT_EQ(icmp_header.icmp6_type, ICMP6_ECHO_REQUEST); + EXPECT_EQ(icmp_header.icmp6_seq, 1); + return len; + })); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_FALSE(stats_.HasWriteErrors()); +} + +TEST_F(IcmpReachableTest, HandlesUnreachableEvents) { + IcmpReachable reachable(source_, destination_, QuicTime::Delta::Zero(), + &kernel_, event_loop_.get(), &stats_); + + SetFdExpectations(); + ASSERT_TRUE(reachable.Init()); + + EXPECT_CALL(kernel_, sendto(kFakeWriteFd, _, _, _, _, _)) + .Times(2) + .WillRepeatedly(Invoke([](int sockfd, const void* buf, size_t len, + int flags, const struct sockaddr* dest_addr, + socklen_t addrlen) { return len; })); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_EQ(stats_.unreachable_count(), 0); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_FALSE(stats_.HasWriteErrors()); + EXPECT_EQ(stats_.unreachable_count(), 1); + EXPECT_EQ(stats_.current_source(), kNoSource); +} + +TEST_F(IcmpReachableTest, HandlesReachableEvents) { + IcmpReachable reachable(source_, destination_, QuicTime::Delta::Zero(), + &kernel_, event_loop_.get(), &stats_); + + SetFdExpectations(); + ASSERT_TRUE(reachable.Init()); + + icmp6_hdr last_request_hdr{}; + EXPECT_CALL(kernel_, sendto(kFakeWriteFd, _, _, _, _, _)) + .Times(2) + .WillRepeatedly( + Invoke([&last_request_hdr]( + int sockfd, const void* buf, size_t len, int flags, + const struct sockaddr* dest_addr, socklen_t addrlen) { + last_request_hdr = GetHeaderFromPacket(buf, len); + return len; + })); + + sockaddr_in6 source_addr{}; + std::string packed_source = source_.ToPackedString(); + memcpy(&source_addr.sin6_addr, packed_source.data(), packed_source.size()); + + EXPECT_CALL(kernel_, recvfrom(read_fd_, _, _, _, _, _)) + .WillOnce( + Invoke([&source_addr](int sockfd, void* buf, size_t len, int flags, + struct sockaddr* src_addr, socklen_t* addrlen) { + *reinterpret_cast(src_addr) = source_addr; + return read(sockfd, buf, len); + })); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_EQ(stats_.reachable_count(), 0); + + icmp6_hdr response = last_request_hdr; + response.icmp6_type = ICMP6_ECHO_REPLY; + + write(read_src_fd_, reinterpret_cast(&response), + sizeof(icmp6_hdr)); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_FALSE(stats_.HasReadErrors()); + EXPECT_FALSE(stats_.HasWriteErrors()); + EXPECT_EQ(stats_.reachable_count(), 1); + EXPECT_EQ(stats_.current_source(), source_.ToString()); +} + +TEST_F(IcmpReachableTest, HandlesWriteErrors) { + IcmpReachable reachable(source_, destination_, QuicTime::Delta::Zero(), + &kernel_, event_loop_.get(), &stats_); + + SetFdExpectations(); + ASSERT_TRUE(reachable.Init()); + + EXPECT_CALL(kernel_, sendto(kFakeWriteFd, _, _, _, _, _)) + .WillOnce(Invoke([](int sockfd, const void* buf, size_t len, int flags, + const struct sockaddr* dest_addr, socklen_t addrlen) { + errno = EAGAIN; + return 0; + })); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_EQ(stats_.WriteErrorCount(EAGAIN), 1); +} + +TEST_F(IcmpReachableTest, HandlesReadErrors) { + IcmpReachable reachable(source_, destination_, QuicTime::Delta::Zero(), + &kernel_, event_loop_.get(), &stats_); + + SetFdExpectations(); + ASSERT_TRUE(reachable.Init()); + + EXPECT_CALL(kernel_, sendto(kFakeWriteFd, _, _, _, _, _)) + .WillOnce(Invoke([](int sockfd, const void* buf, size_t len, int flags, + const struct sockaddr* dest_addr, + socklen_t addrlen) { return len; })); + + EXPECT_CALL(kernel_, recvfrom(read_fd_, _, _, _, _, _)) + .WillOnce(Invoke([](int sockfd, void* buf, size_t len, int flags, + struct sockaddr* src_addr, socklen_t* addrlen) { + errno = EIO; + return -1; + })); + + icmp6_hdr response{}; + + write(read_src_fd_, reinterpret_cast(&response), + sizeof(icmp6_hdr)); + + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromSeconds(1)); + EXPECT_EQ(stats_.reachable_count(), 0); + EXPECT_EQ(stats_.ReadErrorCount(EIO), 1); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/qbone/bonnet/mock_icmp_reachable.h b/quiche/quic/qbone/bonnet/mock_icmp_reachable.h new file mode 100644 index 000000000000..13c294fddfe1 --- /dev/null +++ b/quiche/quic/qbone/bonnet/mock_icmp_reachable.h @@ -0,0 +1,20 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_MOCK_ICMP_REACHABLE_H_ +#define QUICHE_QUIC_QBONE_BONNET_MOCK_ICMP_REACHABLE_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/icmp_reachable_interface.h" + +namespace quic { + +class MockIcmpReachable : public IcmpReachableInterface { + public: + MOCK_METHOD(bool, Init, (), (override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_MOCK_ICMP_REACHABLE_H_ diff --git a/quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h b/quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h new file mode 100644 index 000000000000..004879024345 --- /dev/null +++ b/quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h @@ -0,0 +1,27 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_MOCK_PACKET_EXCHANGER_STATS_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_BONNET_MOCK_PACKET_EXCHANGER_STATS_INTERFACE_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h" + +namespace quic { + +class MockPacketExchangerStatsInterface + : public TunDevicePacketExchanger::StatsInterface { + public: + MOCK_METHOD(void, OnPacketRead, (size_t), (override)); + MOCK_METHOD(void, OnPacketWritten, (size_t), (override)); + MOCK_METHOD(void, OnReadError, (std::string*), (override)); + MOCK_METHOD(void, OnWriteError, (std::string*), (override)); + + MOCK_METHOD(int64_t, PacketsRead, (), (const, override)); + MOCK_METHOD(int64_t, PacketsWritten, (), (const, override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_MOCK_PACKET_EXCHANGER_STATS_INTERFACE_H_ diff --git a/quiche/quic/qbone/bonnet/mock_qbone_tunnel.h b/quiche/quic/qbone/bonnet/mock_qbone_tunnel.h new file mode 100644 index 000000000000..409bc4764201 --- /dev/null +++ b/quiche/quic/qbone/bonnet/mock_qbone_tunnel.h @@ -0,0 +1,45 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_MOCK_QBONE_TUNNEL_H_ +#define QUICHE_QUIC_QBONE_BONNET_MOCK_QBONE_TUNNEL_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/qbone_tunnel_interface.h" + +namespace quic { + +class MockQboneTunnel : public QboneTunnelInterface { + public: + MockQboneTunnel() = default; + + MOCK_METHOD(bool, WaitForEvents, (), (override)); + + MOCK_METHOD(void, Wake, (), (override)); + + MOCK_METHOD(void, ResetTunnel, (), (override)); + + MOCK_METHOD(State, Disconnect, (), (override)); + + MOCK_METHOD(void, OnControlRequest, (const quic::QboneClientRequest&), + (override)); + + MOCK_METHOD(void, OnControlError, (), (override)); + + MOCK_METHOD(bool, AwaitConnection, ()); + + MOCK_METHOD(std::string, StateToString, (State), (override)); + + MOCK_METHOD(quic::QboneClient*, client, (), (override)); + + MOCK_METHOD(State, state, ()); + + MOCK_METHOD(std::string, HealthString, ()); + + MOCK_METHOD(std::string, ServerRegionString, ()); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_MOCK_QBONE_TUNNEL_H_ diff --git a/quiche/quic/qbone/bonnet/mock_tun_device.h b/quiche/quic/qbone/bonnet/mock_tun_device.h new file mode 100644 index 000000000000..b712db5a8d1b --- /dev/null +++ b/quiche/quic/qbone/bonnet/mock_tun_device.h @@ -0,0 +1,28 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_MOCK_TUN_DEVICE_H_ +#define QUICHE_QUIC_QBONE_BONNET_MOCK_TUN_DEVICE_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/tun_device_interface.h" + +namespace quic { + +class MockTunDevice : public TunDeviceInterface { + public: + MOCK_METHOD(bool, Init, (), (override)); + + MOCK_METHOD(bool, Up, (), (override)); + + MOCK_METHOD(bool, Down, (), (override)); + + MOCK_METHOD(void, CloseDevice, (), (override)); + + MOCK_METHOD(int, GetFileDescriptor, (), (const, override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_MOCK_TUN_DEVICE_H_ diff --git a/quiche/quic/qbone/bonnet/mock_tun_device_controller.h b/quiche/quic/qbone/bonnet/mock_tun_device_controller.h new file mode 100644 index 000000000000..60ea0b66ca55 --- /dev/null +++ b/quiche/quic/qbone/bonnet/mock_tun_device_controller.h @@ -0,0 +1,27 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_MOCK_TUN_DEVICE_CONTROLLER_H_ +#define QUICHE_QUIC_QBONE_BONNET_MOCK_TUN_DEVICE_CONTROLLER_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/tun_device_controller.h" + +namespace quic { + +class MockTunDeviceController : public TunDeviceController { + public: + MockTunDeviceController() : TunDeviceController("", true, nullptr) {} + + MOCK_METHOD(bool, UpdateAddress, (const IpRange&), (override)); + + MOCK_METHOD(bool, UpdateRoutes, (const IpRange&, const std::vector&), + (override)); + + MOCK_METHOD(QuicIpAddress, current_address, (), (override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_MOCK_TUN_DEVICE_CONTROLLER_H_ diff --git a/quiche/quic/qbone/bonnet/qbone_tunnel_info.cc b/quiche/quic/qbone/bonnet/qbone_tunnel_info.cc new file mode 100644 index 000000000000..09589567293e --- /dev/null +++ b/quiche/quic/qbone/bonnet/qbone_tunnel_info.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/qbone_tunnel_info.h" + +namespace quic { + +QuicIpAddress QboneTunnelInfo::GetAddress() { + QuicIpAddress no_address; + + NetlinkInterface::LinkInfo link_info{}; + if (!netlink_->GetLinkInfo(ifname_, &link_info)) { + return no_address; + } + + std::vector addresses; + if (!netlink_->GetAddresses(link_info.index, 0, &addresses, nullptr)) { + return no_address; + } + + quic::QuicIpAddress link_local_subnet; + if (!link_local_subnet.FromString("FE80::")) { + return no_address; + } + + for (const auto& address : addresses) { + if (address.interface_address.IsInitialized() && + !link_local_subnet.InSameSubnet(address.interface_address, 10)) { + return address.interface_address; + } + } + + return no_address; +} + +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/qbone_tunnel_info.h b/quiche/quic/qbone/bonnet/qbone_tunnel_info.h new file mode 100644 index 000000000000..d928ba7caa64 --- /dev/null +++ b/quiche/quic/qbone/bonnet/qbone_tunnel_info.h @@ -0,0 +1,29 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_INFO_H_ +#define QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_INFO_H_ + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/qbone/platform/netlink_interface.h" + +namespace quic { + +class QboneTunnelInfo { + public: + QboneTunnelInfo(std::string ifname, NetlinkInterface* netlink) + : ifname_(std::move(ifname)), netlink_(netlink) {} + + // Returns the current QBONE tunnel address. Callers must use IsInitialized() + // to ensure the returned address is valid. + QuicIpAddress GetAddress(); + + private: + const std::string ifname_; + NetlinkInterface* netlink_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_INFO_H_ diff --git a/quiche/quic/qbone/bonnet/qbone_tunnel_interface.h b/quiche/quic/qbone/bonnet/qbone_tunnel_interface.h new file mode 100644 index 000000000000..c4bd7b4bfa4e --- /dev/null +++ b/quiche/quic/qbone/bonnet/qbone_tunnel_interface.h @@ -0,0 +1,70 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_INTERFACE_H_ + +#include "quiche/quic/qbone/qbone_client.h" + +namespace quic { + +// Interface for establishing bidirectional communication between a network +// device and a QboneClient. +class QboneTunnelInterface : public quic::QboneClientControlStream::Handler { + public: + QboneTunnelInterface() = default; + + QboneTunnelInterface(const QboneTunnelInterface&) = delete; + QboneTunnelInterface& operator=(const QboneTunnelInterface&) = delete; + + QboneTunnelInterface(QboneTunnelInterface&&) = delete; + QboneTunnelInterface& operator=(QboneTunnelInterface&&) = delete; + + enum State { + UNINITIALIZED, + IP_RANGE_REQUESTED, + START_REQUESTED, + STARTED, + LAME_DUCK_REQUESTED, + END_REQUESTED, + ENDED, + FAILED, + }; + + // Wait and handle any events which occur. + // Returns true if there are any outstanding requests. + virtual bool WaitForEvents() = 0; + + // Wakes the tunnel if it is currently in WaitForEvents. + virtual void Wake() = 0; + + // Disconnect the tunnel, resetting it to an uninitialized state. This will + // force ConnectIfNeeded to reconnect on the next epoll cycle. + virtual void ResetTunnel() = 0; + + // Disconnect from the QBONE server. + virtual State Disconnect() = 0; + + // Callback handling responses from the QBONE server. + void OnControlRequest(const QboneClientRequest& request) override = 0; + + // Callback handling bad responses from the QBONE server. Currently, this is + // only called when the response is unparsable. + void OnControlError() override = 0; + + // Returns a string value of the given state. + virtual std::string StateToString(State state) = 0; + + virtual QboneClient* client() = 0; + + virtual State state() = 0; + + virtual std::string HealthString() = 0; + + virtual std::string ServerRegionString() = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_INTERFACE_H_ diff --git a/quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc b/quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc new file mode 100644 index 000000000000..1448d109ebdc --- /dev/null +++ b/quiche/quic/qbone/bonnet/qbone_tunnel_silo.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/qbone_tunnel_silo.h" + +namespace quic { + +void QboneTunnelSilo::Run() { + while (ShouldRun()) { + tunnel_->WaitForEvents(); + } + + QUIC_LOG(INFO) << "Tunnel has disconnected in state: " + << tunnel_->StateToString(tunnel_->Disconnect()); +} + +void QboneTunnelSilo::Quit() { + QUIC_LOG(INFO) << "Quit called on QboneTunnelSilo"; + quitting_.Notify(); + tunnel_->Wake(); +} + +bool QboneTunnelSilo::ShouldRun() { + bool post_init_shutdown_ready = + only_setup_tun_ && + tunnel_->state() == quic::QboneTunnelInterface::STARTED; + return !quitting_.HasBeenNotified() && !post_init_shutdown_ready; +} + +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/qbone_tunnel_silo.h b/quiche/quic/qbone/bonnet/qbone_tunnel_silo.h new file mode 100644 index 000000000000..5e34783abf9f --- /dev/null +++ b/quiche/quic/qbone/bonnet/qbone_tunnel_silo.h @@ -0,0 +1,48 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_SILO_H_ +#define QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_SILO_H_ + +#include "absl/synchronization/notification.h" +#include "quiche/quic/platform/api/quic_thread.h" +#include "quiche/quic/qbone/bonnet/qbone_tunnel_interface.h" + +namespace quic { + +// QboneTunnelSilo is a thread that initializes and evaluates a QboneTunnel's +// event loop. +class QboneTunnelSilo : public QuicThread { + public: + // Does not take ownership of |tunnel| + explicit QboneTunnelSilo(QboneTunnelInterface* tunnel, bool only_setup_tun) + : QuicThread("QboneTunnelSilo"), + tunnel_(tunnel), + only_setup_tun_(only_setup_tun) {} + + QboneTunnelSilo(const QboneTunnelSilo&) = delete; + QboneTunnelSilo& operator=(const QboneTunnelSilo&) = delete; + + QboneTunnelSilo(QboneTunnelSilo&&) = delete; + QboneTunnelSilo& operator=(QboneTunnelSilo&&) = delete; + + // Terminates the tunnel's event loop. This silo must still be joined. + void Quit(); + + protected: + void Run() override; + + private: + bool ShouldRun(); + + QboneTunnelInterface* tunnel_; + + absl::Notification quitting_; + + const bool only_setup_tun_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_QBONE_TUNNEL_SILO_H_ diff --git a/quiche/quic/qbone/bonnet/qbone_tunnel_silo_test.cc b/quiche/quic/qbone/bonnet/qbone_tunnel_silo_test.cc new file mode 100644 index 000000000000..4fe967ad3c32 --- /dev/null +++ b/quiche/quic/qbone/bonnet/qbone_tunnel_silo_test.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/qbone_tunnel_silo.h" + +#include "absl/synchronization/notification.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/mock_qbone_tunnel.h" + +namespace quic { +namespace { + +using ::testing::Eq; +using ::testing::Invoke; +using ::testing::Return; + +TEST(QboneTunnelSiloTest, SiloRunsEventLoop) { + MockQboneTunnel mock_tunnel; + + absl::Notification event_loop_run; + EXPECT_CALL(mock_tunnel, WaitForEvents) + .WillRepeatedly(Invoke([&event_loop_run]() { + if (!event_loop_run.HasBeenNotified()) { + event_loop_run.Notify(); + } + return false; + })); + + QboneTunnelSilo silo(&mock_tunnel, false); + silo.Start(); + + event_loop_run.WaitForNotification(); + + absl::Notification client_disconnected; + EXPECT_CALL(mock_tunnel, Disconnect) + .WillOnce(Invoke([&client_disconnected]() { + client_disconnected.Notify(); + return QboneTunnelInterface::ENDED; + })); + + silo.Quit(); + client_disconnected.WaitForNotification(); + + silo.Join(); +} + +TEST(QboneTunnelSiloTest, SiloCanShutDownAfterInit) { + MockQboneTunnel mock_tunnel; + + int iteration_count = 0; + EXPECT_CALL(mock_tunnel, WaitForEvents) + .WillRepeatedly(Invoke([&iteration_count]() { + iteration_count++; + return false; + })); + + EXPECT_CALL(mock_tunnel, state) + .WillOnce(Return(QboneTunnelInterface::START_REQUESTED)) + .WillOnce(Return(QboneTunnelInterface::STARTED)); + + absl::Notification client_disconnected; + EXPECT_CALL(mock_tunnel, Disconnect) + .WillOnce(Invoke([&client_disconnected]() { + client_disconnected.Notify(); + return QboneTunnelInterface::ENDED; + })); + + QboneTunnelSilo silo(&mock_tunnel, true); + silo.Start(); + + client_disconnected.WaitForNotification(); + silo.Join(); + EXPECT_THAT(iteration_count, Eq(1)); +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/tun_device.cc b/quiche/quic/qbone/bonnet/tun_device.cc new file mode 100644 index 000000000000..9b11d99e8c6d --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/tun_device.h" + +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/qbone/platform/kernel_interface.h" + +ABSL_FLAG(std::string, qbone_client_tun_device_path, "/dev/net/tun", + "The path to the QBONE client's TUN device."); + +namespace quic { + +const int kInvalidFd = -1; + +TunTapDevice::TunTapDevice(const std::string& interface_name, int mtu, + bool persist, bool setup_tun, bool is_tap, + KernelInterface* kernel) + : interface_name_(interface_name), + mtu_(mtu), + persist_(persist), + setup_tun_(setup_tun), + is_tap_(is_tap), + file_descriptor_(kInvalidFd), + kernel_(*kernel) {} + +TunTapDevice::~TunTapDevice() { + if (!persist_) { + Down(); + } + CloseDevice(); +} + +bool TunTapDevice::Init() { + if (interface_name_.empty() || interface_name_.size() >= IFNAMSIZ) { + QUIC_BUG(quic_bug_10995_1) + << "interface_name must be nonempty and shorter than " << IFNAMSIZ; + return false; + } + + if (!OpenDevice()) { + return false; + } + + if (!ConfigureInterface()) { + return false; + } + + return true; +} + +// TODO(pengg): might be better to use netlink socket, once we have a library to +// use +bool TunTapDevice::Up() { + if (!setup_tun_) { + return true; + } + struct ifreq if_request; + memset(&if_request, 0, sizeof(if_request)); + // copy does not zero-terminate the result string, but we've memset the + // entire struct. + interface_name_.copy(if_request.ifr_name, IFNAMSIZ); + if_request.ifr_flags = IFF_UP; + + return NetdeviceIoctl(SIOCSIFFLAGS, reinterpret_cast(&if_request)); +} + +// TODO(pengg): might be better to use netlink socket, once we have a library to +// use +bool TunTapDevice::Down() { + if (!setup_tun_) { + return true; + } + struct ifreq if_request; + memset(&if_request, 0, sizeof(if_request)); + // copy does not zero-terminate the result string, but we've memset the + // entire struct. + interface_name_.copy(if_request.ifr_name, IFNAMSIZ); + if_request.ifr_flags = 0; + + return NetdeviceIoctl(SIOCSIFFLAGS, reinterpret_cast(&if_request)); +} + +int TunTapDevice::GetFileDescriptor() const { return file_descriptor_; } + +bool TunTapDevice::OpenDevice() { + if (file_descriptor_ != kInvalidFd) { + CloseDevice(); + } + + struct ifreq if_request; + memset(&if_request, 0, sizeof(if_request)); + // copy does not zero-terminate the result string, but we've memset the entire + // struct. + interface_name_.copy(if_request.ifr_name, IFNAMSIZ); + + // Always set IFF_MULTI_QUEUE since a persistent device does not allow this + // flag to be flipped when re-opening it. The only way to flip this flag is to + // destroy the device and create a new one, but that deletes any existing + // routing associated with the interface, which makes the meaning of the + // 'persist' bit ambiguous. + if_request.ifr_flags = IFF_MULTI_QUEUE | IFF_NO_PI; + if (is_tap_) { + if_request.ifr_flags |= IFF_TAP; + } else { + if_request.ifr_flags |= IFF_TUN; + } + + // When the device is running with IFF_MULTI_QUEUE set, each call to open will + // create a queue which can be used to read/write packets from/to the device. + bool successfully_opened = false; + auto cleanup = absl::MakeCleanup([this, &successfully_opened]() { + if (!successfully_opened) { + CloseDevice(); + } + }); + + const std::string tun_device_path = + absl::GetFlag(FLAGS_qbone_client_tun_device_path); + int fd = kernel_.open(tun_device_path.c_str(), O_RDWR); + if (fd < 0) { + QUIC_PLOG(WARNING) << "Failed to open " << tun_device_path; + return successfully_opened; + } + file_descriptor_ = fd; + if (!CheckFeatures(fd)) { + return successfully_opened; + } + + if (kernel_.ioctl(fd, TUNSETIFF, reinterpret_cast(&if_request)) != 0) { + QUIC_PLOG(WARNING) << "Failed to TUNSETIFF on fd(" << fd << ")"; + return successfully_opened; + } + + if (kernel_.ioctl( + fd, TUNSETPERSIST, + persist_ ? reinterpret_cast(&if_request) : nullptr) != 0) { + QUIC_PLOG(WARNING) << "Failed to TUNSETPERSIST on fd(" << fd << ")"; + return successfully_opened; + } + + successfully_opened = true; + return successfully_opened; +} + +// TODO(pengg): might be better to use netlink socket, once we have a library to +// use +bool TunTapDevice::ConfigureInterface() { + if (!setup_tun_) { + return true; + } + + struct ifreq if_request; + memset(&if_request, 0, sizeof(if_request)); + // copy does not zero-terminate the result string, but we've memset the entire + // struct. + interface_name_.copy(if_request.ifr_name, IFNAMSIZ); + if_request.ifr_mtu = mtu_; + + if (!NetdeviceIoctl(SIOCSIFMTU, reinterpret_cast(&if_request))) { + CloseDevice(); + return false; + } + + return true; +} + +bool TunTapDevice::CheckFeatures(int tun_device_fd) { + unsigned int actual_features; + if (kernel_.ioctl(tun_device_fd, TUNGETFEATURES, &actual_features) != 0) { + QUIC_PLOG(WARNING) << "Failed to TUNGETFEATURES"; + return false; + } + unsigned int required_features = IFF_TUN | IFF_NO_PI; + if ((required_features & actual_features) != required_features) { + QUIC_LOG(WARNING) + << "Required feature does not exist. required_features: 0x" << std::hex + << required_features << " vs actual_features: 0x" << std::hex + << actual_features; + return false; + } + return true; +} + +bool TunTapDevice::NetdeviceIoctl(int request, void* argp) { + int fd = kernel_.socket(AF_INET6, SOCK_DGRAM, 0); + if (fd < 0) { + QUIC_PLOG(WARNING) << "Failed to create AF_INET6 socket."; + return false; + } + + if (kernel_.ioctl(fd, request, argp) != 0) { + QUIC_PLOG(WARNING) << "Failed ioctl request: " << request; + kernel_.close(fd); + return false; + } + kernel_.close(fd); + return true; +} + +void TunTapDevice::CloseDevice() { + if (file_descriptor_ != kInvalidFd) { + kernel_.close(file_descriptor_); + file_descriptor_ = kInvalidFd; + } +} + +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/tun_device.h b/quiche/quic/qbone/bonnet/tun_device.h new file mode 100644 index 000000000000..ccc22cbe9a11 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device.h @@ -0,0 +1,82 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_H_ +#define QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_H_ + +#include +#include + +#include "quiche/quic/qbone/bonnet/tun_device_interface.h" +#include "quiche/quic/qbone/platform/kernel_interface.h" + +namespace quic { + +class TunTapDevice : public TunDeviceInterface { + public: + // This represents a tun device created in the OS kernel, which is a virtual + // network interface that any packets sent to it can be read by a user space + // program that owns it. The routing rule that routes packets to this + // interface should be defined somewhere else. + // + // Standard read/write system calls can be used to receive/send packets + // from/to this interface. The file descriptor is owned by this class. + // + // If persist is set to true, the device won't be deleted even after + // destructing. The device will be picked up when initializing this class with + // the same interface_name on the next time. + // + // Persisting the device is useful if one wants to keep the routing rules + // since once a tun device is destroyed by the kernel, all the associated + // routing rules go away. + // + // The caller should own kernel and make sure it outlives this. + TunTapDevice(const std::string& interface_name, int mtu, bool persist, + bool setup_tun, bool is_tap, KernelInterface* kernel); + + ~TunTapDevice() override; + + // Actually creates/reopens and configures the device. + bool Init() override; + + // Marks the interface up to start receiving packets. + bool Up() override; + + // Marks the interface down to stop receiving packets. + bool Down() override; + + // Closes the open file descriptor for the TUN device (if one exists). + // It is safe to reinitialize and reuse this TunTapDevice after calling + // CloseDevice. + void CloseDevice() override; + + // Gets the file descriptor that can be used to send/receive packets. + // This returns -1 when the TUN device is in an invalid state. + int GetFileDescriptor() const override; + + private: + // Creates or reopens the tun device. + bool OpenDevice(); + + // Configure the interface. + bool ConfigureInterface(); + + // Checks if the required kernel features exists. + bool CheckFeatures(int tun_device_fd); + + // Opens a socket and makes netdevice ioctl call + bool NetdeviceIoctl(int request, void* argp); + + const std::string interface_name_; + const int mtu_; + const bool persist_; + const bool setup_tun_; + const bool is_tap_; + int file_descriptor_; + KernelInterface& kernel_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_H_ diff --git a/quiche/quic/qbone/bonnet/tun_device_controller.cc b/quiche/quic/qbone/bonnet/tun_device_controller.cc new file mode 100644 index 000000000000..b6e3d7684fc6 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_controller.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/tun_device_controller.h" + +#include + +#include "absl/time/clock.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/qbone/qbone_constants.h" + +ABSL_FLAG(bool, qbone_tun_device_replace_default_routing_rules, true, + "If true, will define a rule that points packets sourced from the " + "qbone interface to the qbone table. This is unnecessary in " + "environments with no other ipv6 route."); + +ABSL_FLAG(int, qbone_route_init_cwnd, + quic::NetlinkInterface::kUnspecifiedInitCwnd, + "If non-zero, will add initcwnd to QBONE routing rules. Setting " + "a value below 10 is dangerous and not recommended."); + +namespace quic { + +bool TunDeviceController::UpdateAddress(const IpRange& desired_range) { + if (!setup_tun_) { + return true; + } + + NetlinkInterface::LinkInfo link_info{}; + if (!netlink_->GetLinkInfo(ifname_, &link_info)) { + return false; + } + + std::vector addresses; + if (!netlink_->GetAddresses(link_info.index, 0, &addresses, nullptr)) { + return false; + } + + QuicIpAddress desired_address = desired_range.FirstAddressInRange(); + + for (const auto& address : addresses) { + if (!netlink_->ChangeLocalAddress( + link_info.index, NetlinkInterface::Verb::kRemove, + address.interface_address, address.prefix_length, 0, 0, {})) { + return false; + } + } + + bool address_updated = netlink_->ChangeLocalAddress( + link_info.index, NetlinkInterface::Verb::kAdd, desired_address, + desired_range.prefix_length(), IFA_F_PERMANENT | IFA_F_NODAD, + RT_SCOPE_LINK, {}); + + if (address_updated) { + current_address_ = desired_address; + + for (const auto& cb : address_update_cbs_) { + cb(current_address_); + } + } + + return address_updated; +} + +bool TunDeviceController::UpdateRoutes( + const IpRange& desired_range, const std::vector& desired_routes) { + if (!setup_tun_) { + return true; + } + + NetlinkInterface::LinkInfo link_info{}; + if (!netlink_->GetLinkInfo(ifname_, &link_info)) { + QUIC_LOG(ERROR) << "Could not get link info for interface <" << ifname_ + << ">"; + return false; + } + + std::vector routing_rules; + if (!netlink_->GetRouteInfo(&routing_rules)) { + QUIC_LOG(ERROR) << "Unable to get route info"; + return false; + } + + for (const auto& rule : routing_rules) { + if (rule.out_interface == link_info.index && + rule.table == QboneConstants::kQboneRouteTableId) { + if (!netlink_->ChangeRoute(NetlinkInterface::Verb::kRemove, rule.table, + rule.destination_subnet, rule.scope, + rule.preferred_source, rule.out_interface, + rule.init_cwnd)) { + QUIC_LOG(ERROR) << "Unable to remove old route to <" + << rule.destination_subnet.ToString() << ">"; + return false; + } + } + } + + if (!UpdateRules(desired_range)) { + return false; + } + + QuicIpAddress desired_address = desired_range.FirstAddressInRange(); + + std::vector routes(desired_routes.begin(), desired_routes.end()); + routes.emplace_back(*QboneConstants::TerminatorLocalAddressRange()); + + for (const auto& route : routes) { + if (!netlink_->ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, route, + RT_SCOPE_LINK, desired_address, link_info.index, + absl::GetFlag(FLAGS_qbone_route_init_cwnd))) { + QUIC_LOG(ERROR) << "Unable to add route <" << route.ToString() << ">"; + return false; + } + } + + return true; +} + +bool TunDeviceController::UpdateRoutesWithRetries( + const IpRange& desired_range, const std::vector& desired_routes, + int retries) { + while (retries-- > 0) { + if (UpdateRoutes(desired_range, desired_routes)) { + return true; + } + absl::SleepFor(absl::Milliseconds(100)); + } + return false; +} + +bool TunDeviceController::UpdateRules(IpRange desired_range) { + if (!absl::GetFlag(FLAGS_qbone_tun_device_replace_default_routing_rules)) { + return true; + } + + std::vector ip_rules; + if (!netlink_->GetRuleInfo(&ip_rules)) { + QUIC_LOG(ERROR) << "Unable to get rule info"; + return false; + } + + for (const auto& rule : ip_rules) { + if (rule.table == QboneConstants::kQboneRouteTableId) { + if (!netlink_->ChangeRule(NetlinkInterface::Verb::kRemove, rule.table, + rule.source_range)) { + QUIC_LOG(ERROR) << "Unable to remove old rule for table <" << rule.table + << "> from source <" << rule.source_range.ToString() + << ">"; + return false; + } + } + } + + if (!netlink_->ChangeRule(NetlinkInterface::Verb::kAdd, + QboneConstants::kQboneRouteTableId, + desired_range)) { + QUIC_LOG(ERROR) << "Unable to add rule for <" << desired_range.ToString() + << ">"; + return false; + } + + return true; +} + +QuicIpAddress TunDeviceController::current_address() { + return current_address_; +} + +void TunDeviceController::RegisterAddressUpdateCallback( + const std::function& cb) { + address_update_cbs_.push_back(cb); +} + +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/tun_device_controller.h b/quiche/quic/qbone/bonnet/tun_device_controller.h new file mode 100644 index 000000000000..30435832fc4c --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_controller.h @@ -0,0 +1,73 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_CONTROLLER_H_ +#define QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_CONTROLLER_H_ + +#include "quiche/quic/qbone/bonnet/tun_device.h" +#include "quiche/quic/qbone/platform/netlink_interface.h" +#include "quiche/quic/qbone/qbone_control.pb.h" +#include "quiche/quic/qbone/qbone_control_stream.h" + +namespace quic { + +// TunDeviceController consumes control stream messages from a Qbone server +// and applies the given updates to the TUN device. +class TunDeviceController { + public: + // |ifname| is the interface name of the TUN device to be managed. This does + // not take ownership of |netlink|. + TunDeviceController(std::string ifname, bool setup_tun, + NetlinkInterface* netlink) + : ifname_(std::move(ifname)), setup_tun_(setup_tun), netlink_(netlink) {} + + TunDeviceController(const TunDeviceController&) = delete; + TunDeviceController& operator=(const TunDeviceController&) = delete; + + TunDeviceController(TunDeviceController&&) = delete; + TunDeviceController& operator=(TunDeviceController&&) = delete; + + virtual ~TunDeviceController() = default; + + // Updates the local address of the TUN device to be the first address in the + // given |response.ip_range()|. + virtual bool UpdateAddress(const IpRange& desired_range); + + // Updates the set of routes that the TUN device will provide. All current + // routes for the tunnel that do not exist in the |response| will be removed. + virtual bool UpdateRoutes(const IpRange& desired_range, + const std::vector& desired_routes); + + // Same as UpdateRoutes, but will wait and retry up to the number of times + // given by |retries| before giving up. This is an unpleasant workaround to + // deal with older kernels that aren't always able to set a route with a + // source address immediately after adding the address to the interface. + // + // TODO(b/179430548): Remove this once we've root-caused the underlying issue. + virtual bool UpdateRoutesWithRetries( + const IpRange& desired_range, const std::vector& desired_routes, + int retries); + + virtual void RegisterAddressUpdateCallback( + const std::function& cb); + + virtual QuicIpAddress current_address(); + + private: + // Update the IP Rules, this should only be used by UpdateRoutes. + bool UpdateRules(IpRange desired_range); + + const std::string ifname_; + const bool setup_tun_; + + NetlinkInterface* netlink_; + + QuicIpAddress current_address_; + + std::vector> address_update_cbs_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_CONTROLLER_H_ diff --git a/quiche/quic/qbone/bonnet/tun_device_controller_test.cc b/quiche/quic/qbone/bonnet/tun_device_controller_test.cc new file mode 100644 index 000000000000..18488435da32 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_controller_test.cc @@ -0,0 +1,263 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/tun_device_controller.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/platform/mock_netlink.h" +#include "quiche/quic/qbone/qbone_constants.h" + +ABSL_DECLARE_FLAG(bool, qbone_tun_device_replace_default_routing_rules); +ABSL_DECLARE_FLAG(int, qbone_route_init_cwnd); + +namespace quic::test { +namespace { +using ::testing::Eq; + +constexpr int kIfindex = 42; +constexpr char kIfname[] = "qbone0"; + +const IpRange kIpRange = []() { + IpRange range; + QCHECK(range.FromString("2604:31c0:2::/64")); + return range; +}(); + +constexpr char kOldAddress[] = "1.2.3.4"; +constexpr int kOldPrefixLen = 24; + +using ::testing::_; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::StrictMock; + +MATCHER_P(IpRangeEq, range, + absl::StrCat("expected IpRange to equal ", range.ToString())) { + return arg == range; +} + +class TunDeviceControllerTest : public QuicTest { + public: + TunDeviceControllerTest() + : controller_(kIfname, true, &netlink_), + link_local_range_(*QboneConstants::TerminatorLocalAddressRange()) { + controller_.RegisterAddressUpdateCallback( + [this](QuicIpAddress address) { notified_address_ = address; }); + } + + protected: + void ExpectLinkInfo(const std::string& interface_name, int ifindex) { + EXPECT_CALL(netlink_, GetLinkInfo(interface_name, _)) + .WillOnce(Invoke([ifindex](absl::string_view ifname, + NetlinkInterface::LinkInfo* link_info) { + link_info->index = ifindex; + return true; + })); + } + + MockNetlink netlink_; + TunDeviceController controller_; + QuicIpAddress notified_address_; + + IpRange link_local_range_; +}; + +TEST_F(TunDeviceControllerTest, AddressAppliedWhenNoneExisted) { + ExpectLinkInfo(kIfname, kIfindex); + + EXPECT_CALL(netlink_, GetAddresses(kIfindex, _, _, _)).WillOnce(Return(true)); + + EXPECT_CALL(netlink_, + ChangeLocalAddress( + kIfindex, NetlinkInterface::Verb::kAdd, + kIpRange.FirstAddressInRange(), kIpRange.prefix_length(), + IFA_F_PERMANENT | IFA_F_NODAD, RT_SCOPE_LINK, _)) + .WillOnce(Return(true)); + + EXPECT_TRUE(controller_.UpdateAddress(kIpRange)); + EXPECT_THAT(notified_address_, Eq(kIpRange.FirstAddressInRange())); +} + +TEST_F(TunDeviceControllerTest, OldAddressesAreRemoved) { + ExpectLinkInfo(kIfname, kIfindex); + + EXPECT_CALL(netlink_, GetAddresses(kIfindex, _, _, _)) + .WillOnce(Invoke([](int interface_index, uint8_t unwanted_flags, + std::vector* addresses, + int* num_ipv6_nodad_dadfailed_addresses) { + NetlinkInterface::AddressInfo info{}; + info.interface_address.FromString(kOldAddress); + info.prefix_length = kOldPrefixLen; + addresses->emplace_back(info); + return true; + })); + + QuicIpAddress old_address; + old_address.FromString(kOldAddress); + + EXPECT_CALL(netlink_, + ChangeLocalAddress(kIfindex, NetlinkInterface::Verb::kRemove, + old_address, kOldPrefixLen, _, _, _)) + .WillOnce(Return(true)); + + EXPECT_CALL(netlink_, + ChangeLocalAddress( + kIfindex, NetlinkInterface::Verb::kAdd, + kIpRange.FirstAddressInRange(), kIpRange.prefix_length(), + IFA_F_PERMANENT | IFA_F_NODAD, RT_SCOPE_LINK, _)) + .WillOnce(Return(true)); + + EXPECT_TRUE(controller_.UpdateAddress(kIpRange)); + EXPECT_THAT(notified_address_, Eq(kIpRange.FirstAddressInRange())); +} + +TEST_F(TunDeviceControllerTest, UpdateRoutesRemovedOldRoutes) { + ExpectLinkInfo(kIfname, kIfindex); + + const int num_matching_routes = 3; + EXPECT_CALL(netlink_, GetRouteInfo(_)) + .WillOnce( + Invoke([](std::vector* routing_rules) { + NetlinkInterface::RoutingRule non_matching_route{}; + non_matching_route.table = QboneConstants::kQboneRouteTableId; + non_matching_route.out_interface = kIfindex + 1; + routing_rules->push_back(non_matching_route); + + NetlinkInterface::RoutingRule matching_route{}; + matching_route.table = QboneConstants::kQboneRouteTableId; + matching_route.out_interface = kIfindex; + matching_route.init_cwnd = NetlinkInterface::kUnspecifiedInitCwnd; + for (int i = 0; i < num_matching_routes; i++) { + routing_rules->push_back(matching_route); + } + + NetlinkInterface::RoutingRule non_matching_table{}; + non_matching_table.table = QboneConstants::kQboneRouteTableId + 1; + non_matching_table.out_interface = kIfindex; + routing_rules->push_back(non_matching_table); + return true; + })); + + EXPECT_CALL(netlink_, + ChangeRoute(NetlinkInterface::Verb::kRemove, + QboneConstants::kQboneRouteTableId, _, _, _, kIfindex, + NetlinkInterface::kUnspecifiedInitCwnd)) + .Times(num_matching_routes) + .WillRepeatedly(Return(true)); + + EXPECT_CALL(netlink_, GetRuleInfo(_)).WillOnce(Return(true)); + + EXPECT_CALL(netlink_, ChangeRule(NetlinkInterface::Verb::kAdd, + QboneConstants::kQboneRouteTableId, + IpRangeEq(kIpRange))) + .WillOnce(Return(true)); + + EXPECT_CALL(netlink_, + ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, + IpRangeEq(link_local_range_), _, _, kIfindex, _)) + .WillOnce(Return(true)); + + EXPECT_TRUE(controller_.UpdateRoutes(kIpRange, {})); +} + +TEST_F(TunDeviceControllerTest, UpdateRoutesAddsNewRoutes) { + ExpectLinkInfo(kIfname, kIfindex); + + EXPECT_CALL(netlink_, GetRouteInfo(_)).WillOnce(Return(true)); + + EXPECT_CALL(netlink_, GetRuleInfo(_)).WillOnce(Return(true)); + + absl::SetFlag(&FLAGS_qbone_route_init_cwnd, 32); + EXPECT_CALL(netlink_, ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, + IpRangeEq(kIpRange), _, _, kIfindex, + absl::GetFlag(FLAGS_qbone_route_init_cwnd))) + .Times(2) + .WillRepeatedly(Return(true)) + .RetiresOnSaturation(); + + EXPECT_CALL(netlink_, ChangeRule(NetlinkInterface::Verb::kAdd, + QboneConstants::kQboneRouteTableId, + IpRangeEq(kIpRange))) + .WillOnce(Return(true)); + + EXPECT_CALL(netlink_, + ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, + IpRangeEq(link_local_range_), _, _, kIfindex, _)) + .WillOnce(Return(true)); + + EXPECT_TRUE(controller_.UpdateRoutes(kIpRange, {kIpRange, kIpRange})); +} + +TEST_F(TunDeviceControllerTest, EmptyUpdateRouteKeepsLinkLocalRoute) { + ExpectLinkInfo(kIfname, kIfindex); + + EXPECT_CALL(netlink_, GetRouteInfo(_)).WillOnce(Return(true)); + + EXPECT_CALL(netlink_, GetRuleInfo(_)).WillOnce(Return(true)); + + EXPECT_CALL(netlink_, ChangeRule(NetlinkInterface::Verb::kAdd, + QboneConstants::kQboneRouteTableId, + IpRangeEq(kIpRange))) + .WillOnce(Return(true)); + + EXPECT_CALL(netlink_, + ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, + IpRangeEq(link_local_range_), _, _, kIfindex, _)) + .WillOnce(Return(true)); + + EXPECT_TRUE(controller_.UpdateRoutes(kIpRange, {})); +} + +TEST_F(TunDeviceControllerTest, DisablingRoutingRulesSkipsRuleCreation) { + absl::SetFlag(&FLAGS_qbone_tun_device_replace_default_routing_rules, false); + ExpectLinkInfo(kIfname, kIfindex); + + EXPECT_CALL(netlink_, GetRouteInfo(_)).WillOnce(Return(true)); + + EXPECT_CALL(netlink_, ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, + IpRangeEq(kIpRange), _, _, kIfindex, _)) + .Times(2) + .WillRepeatedly(Return(true)) + .RetiresOnSaturation(); + + EXPECT_CALL(netlink_, + ChangeRoute(NetlinkInterface::Verb::kReplace, + QboneConstants::kQboneRouteTableId, + IpRangeEq(link_local_range_), _, _, kIfindex, _)) + .WillOnce(Return(true)); + + EXPECT_TRUE(controller_.UpdateRoutes(kIpRange, {kIpRange, kIpRange})); +} + +class DisabledTunDeviceControllerTest : public QuicTest { + public: + DisabledTunDeviceControllerTest() + : controller_(kIfname, false, &netlink_), + link_local_range_(*QboneConstants::TerminatorLocalAddressRange()) {} + + StrictMock netlink_; + TunDeviceController controller_; + + IpRange link_local_range_; +}; + +TEST_F(DisabledTunDeviceControllerTest, UpdateRoutesIsNop) { + EXPECT_THAT(controller_.UpdateRoutes(kIpRange, {}), Eq(true)); +} + +TEST_F(DisabledTunDeviceControllerTest, UpdateAddressIsNop) { + EXPECT_THAT(controller_.UpdateAddress(kIpRange), Eq(true)); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/qbone/bonnet/tun_device_interface.h b/quiche/quic/qbone/bonnet/tun_device_interface.h new file mode 100644 index 000000000000..e88efa97186a --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_interface.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_INTERFACE_H_ + +#include + +namespace quic { + +// An interface with methods for interacting with a TUN device. +class TunDeviceInterface { + public: + virtual ~TunDeviceInterface() = default; + + // Actually creates/reopens and configures the device. + virtual bool Init() = 0; + + // Marks the interface up to start receiving packets. + virtual bool Up() = 0; + + // Marks the interface down to stop receiving packets. + virtual bool Down() = 0; + + // Closes the open file descriptor for the TUN device (if one exists). + // It is safe to reinitialize and reuse this TunTapDevice after calling + // CloseDevice. + virtual void CloseDevice() = 0; + + // Gets the file descriptor that can be used to send/receive packets. + // This returns -1 when the TUN device is in an invalid state. + virtual int GetFileDescriptor() const = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_INTERFACE_H_ diff --git a/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc b/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc new file mode 100644 index 000000000000..6c9cb062bd88 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc @@ -0,0 +1,230 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h" + +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/qbone/platform/icmp_packet.h" +#include "quiche/quic/qbone/platform/netlink_interface.h" +#include "quiche/quic/qbone/qbone_constants.h" + +namespace quic { + +TunDevicePacketExchanger::TunDevicePacketExchanger( + size_t mtu, KernelInterface* kernel, NetlinkInterface* netlink, + QbonePacketExchanger::Visitor* visitor, size_t max_pending_packets, + bool is_tap, StatsInterface* stats, absl::string_view ifname) + : QbonePacketExchanger(visitor, max_pending_packets), + mtu_(mtu), + kernel_(kernel), + netlink_(netlink), + ifname_(ifname), + is_tap_(is_tap), + stats_(stats) { + if (is_tap_) { + mtu_ += ETH_HLEN; + } +} + +bool TunDevicePacketExchanger::WritePacket(const char* packet, size_t size, + bool* blocked, std::string* error) { + *blocked = false; + if (fd_ < 0) { + *error = absl::StrCat("Invalid file descriptor of the TUN device: ", fd_); + stats_->OnWriteError(error); + return false; + } + + auto buffer = std::make_unique(packet, size); + if (is_tap_) { + buffer = ApplyL2Headers(*buffer); + } + int result = kernel_->write(fd_, buffer->data(), buffer->length()); + if (result == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + // The tunnel is blocked. Note that this does not mean the receive buffer + // of a TCP connection is filled. This simply means the TUN device itself + // is blocked on handing packets to the rest part of the kernel. + *error = absl::StrCat("Write to the TUN device was blocked: ", errno); + *blocked = true; + stats_->OnWriteError(error); + } + return false; + } + stats_->OnPacketWritten(result); + + return true; +} + +std::unique_ptr TunDevicePacketExchanger::ReadPacket( + bool* blocked, std::string* error) { + *blocked = false; + if (fd_ < 0) { + *error = absl::StrCat("Invalid file descriptor of the TUN device: ", fd_); + stats_->OnReadError(error); + return nullptr; + } + // Reading on a TUN device returns a packet at a time. If the packet is longer + // than the buffer, it's truncated. + auto read_buffer = std::make_unique(mtu_); + int result = kernel_->read(fd_, read_buffer.get(), mtu_); + // Note that 0 means end of file, but we're talking about a TUN device - there + // is no end of file. Therefore 0 also indicates error. + if (result <= 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + *error = absl::StrCat("Read from the TUN device was blocked: ", errno); + *blocked = true; + stats_->OnReadError(error); + } + return nullptr; + } + + auto buffer = std::make_unique(read_buffer.release(), result, true); + if (is_tap_) { + buffer = ConsumeL2Headers(*buffer); + } + if (buffer) { + stats_->OnPacketRead(buffer->length()); + } + return buffer; +} + +void TunDevicePacketExchanger::set_file_descriptor(int fd) { fd_ = fd; } + +const TunDevicePacketExchanger::StatsInterface* +TunDevicePacketExchanger::stats_interface() const { + return stats_; +} + +std::unique_ptr TunDevicePacketExchanger::ApplyL2Headers( + const QuicData& l3_packet) { + if (is_tap_ && !mac_initialized_) { + NetlinkInterface::LinkInfo link_info{}; + if (netlink_->GetLinkInfo(ifname_, &link_info)) { + memcpy(tap_mac_, link_info.hardware_address, ETH_ALEN); + mac_initialized_ = true; + } else { + QUIC_LOG_EVERY_N_SEC(ERROR, 30) + << "Unable to get link info for: " << ifname_; + } + } + + const auto l2_packet_size = l3_packet.length() + ETH_HLEN; + auto l2_buffer = std::make_unique(l2_packet_size); + + // Populate the Ethernet header + auto* hdr = reinterpret_cast(l2_buffer.get()); + // Set src & dst to my own address + memcpy(hdr->h_dest, tap_mac_, ETH_ALEN); + memcpy(hdr->h_source, tap_mac_, ETH_ALEN); + // Assume ipv6 for now + // TODO(b/195113643): Support additional protocols. + hdr->h_proto = absl::ghtons(ETH_P_IPV6); + + // Copy the l3 packet into buffer, just after the ethernet header. + memcpy(l2_buffer.get() + ETH_HLEN, l3_packet.data(), l3_packet.length()); + + return std::make_unique(l2_buffer.release(), l2_packet_size, true); +} + +std::unique_ptr TunDevicePacketExchanger::ConsumeL2Headers( + const QuicData& l2_packet) { + if (l2_packet.length() < ETH_HLEN) { + // Packet is too short for ethernet headers. Drop it. + return nullptr; + } + auto* hdr = reinterpret_cast(l2_packet.data()); + if (hdr->h_proto != absl::ghtons(ETH_P_IPV6)) { + return nullptr; + } + constexpr auto kIp6PrefixLen = ETH_HLEN + sizeof(ip6_hdr); + constexpr auto kIcmp6PrefixLen = kIp6PrefixLen + sizeof(icmp6_hdr); + if (l2_packet.length() < kIp6PrefixLen) { + // Packet is too short to be ipv6. Drop it. + return nullptr; + } + auto* ip_hdr = reinterpret_cast(l2_packet.data() + ETH_HLEN); + const bool is_icmp = ip_hdr->ip6_ctlun.ip6_un1.ip6_un1_nxt == IPPROTO_ICMPV6; + + bool is_neighbor_solicit = false; + if (is_icmp) { + if (l2_packet.length() < kIcmp6PrefixLen) { + // Packet is too short to be icmp6. Drop it. + return nullptr; + } + is_neighbor_solicit = + reinterpret_cast(l2_packet.data() + kIp6PrefixLen) + ->icmp6_type == ND_NEIGHBOR_SOLICIT; + } + + if (is_neighbor_solicit) { + // If we've received a neighbor solicitation, craft an advertisement to + // respond with and write it back to the local interface. + auto* icmp6_payload = l2_packet.data() + kIcmp6PrefixLen; + + QuicIpAddress target_address( + *reinterpret_cast(icmp6_payload)); + if (target_address != *QboneConstants::GatewayAddress()) { + // Only respond to solicitations for our gateway address + return nullptr; + } + + // Neighbor Advertisement crafted per: + // https://datatracker.ietf.org/doc/html/rfc4861#section-4.4 + // + // Using the Target link-layer address option defined at: + // https://datatracker.ietf.org/doc/html/rfc4861#section-4.6.1 + constexpr size_t kIcmpv6OptionSize = 8; + const int payload_size = sizeof(in6_addr) + kIcmpv6OptionSize; + auto payload = std::make_unique(payload_size); + // Place the solicited IPv6 address at the beginning of the response payload + memcpy(payload.get(), icmp6_payload, sizeof(in6_addr)); + // Setup the Target link-layer address option: + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Link-Layer Address ... + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + int pos = sizeof(in6_addr); + payload[pos++] = ND_OPT_TARGET_LINKADDR; // Type + payload[pos++] = 1; // Length in units of 8 octets + memcpy(&payload[pos], tap_mac_, ETH_ALEN); // This interfaces' MAC address + + // Populate the ICMPv6 header + icmp6_hdr response_hdr{}; + response_hdr.icmp6_type = ND_NEIGHBOR_ADVERT; + // Set the solicited bit to true + response_hdr.icmp6_dataun.icmp6_un_data8[0] = 64; + // Craft the full ICMPv6 packet and then ship it off to WritePacket + // to have it frame it with L2 headers and send it back to the requesting + // neighbor. + CreateIcmpPacket(ip_hdr->ip6_src, ip_hdr->ip6_src, response_hdr, + absl::string_view(payload.get(), payload_size), + [this](absl::string_view packet) { + bool blocked; + std::string error; + WritePacket(packet.data(), packet.size(), &blocked, + &error); + }); + // Do not forward the neighbor solicitation through the tunnel since it's + // link-local. + return nullptr; + } + + // If this isn't a Neighbor Solicitation, remove the L2 headers and forward + // it as though it were an L3 packet. + const auto l3_packet_size = l2_packet.length() - ETH_HLEN; + auto shift_buffer = std::make_unique(l3_packet_size); + memcpy(shift_buffer.get(), l2_packet.data() + ETH_HLEN, l3_packet_size); + + return std::make_unique(shift_buffer.release(), l3_packet_size, + true); +} + +} // namespace quic diff --git a/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h b/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h new file mode 100644 index 000000000000..2417d5051f84 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h @@ -0,0 +1,86 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_PACKET_EXCHANGER_H_ +#define QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_PACKET_EXCHANGER_H_ + +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/qbone/platform/kernel_interface.h" +#include "quiche/quic/qbone/platform/netlink_interface.h" +#include "quiche/quic/qbone/qbone_client_interface.h" +#include "quiche/quic/qbone/qbone_packet_exchanger.h" + +namespace quic { + +class TunDevicePacketExchanger : public QbonePacketExchanger { + public: + class StatsInterface { + public: + StatsInterface() = default; + + StatsInterface(const StatsInterface&) = delete; + StatsInterface& operator=(const StatsInterface&) = delete; + + StatsInterface(StatsInterface&&) = delete; + StatsInterface& operator=(StatsInterface&&) = delete; + + virtual ~StatsInterface() = default; + + virtual void OnPacketRead(size_t count) = 0; + virtual void OnPacketWritten(size_t count) = 0; + virtual void OnReadError(std::string* error) = 0; + virtual void OnWriteError(std::string* error) = 0; + + ABSL_MUST_USE_RESULT virtual int64_t PacketsRead() const = 0; + ABSL_MUST_USE_RESULT virtual int64_t PacketsWritten() const = 0; + }; + + // |mtu| is the mtu of the TUN device. + // |kernel| is not owned but should out live objects of this class. + // |visitor| is not owned but should out live objects of this class. + // |max_pending_packets| controls the number of packets to be queued should + // the TUN device become blocked. + // |stats| is notified about packet read/write statistics. It is not owned, + // but should outlive objects of this class. + TunDevicePacketExchanger(size_t mtu, KernelInterface* kernel, + NetlinkInterface* netlink, + QbonePacketExchanger::Visitor* visitor, + size_t max_pending_packets, bool is_tap, + StatsInterface* stats, absl::string_view ifname); + + void set_file_descriptor(int fd); + + ABSL_MUST_USE_RESULT const StatsInterface* stats_interface() const; + + private: + // From QbonePacketExchanger. + std::unique_ptr ReadPacket(bool* blocked, + std::string* error) override; + + // From QbonePacketExchanger. + bool WritePacket(const char* packet, size_t size, bool* blocked, + std::string* error) override; + + std::unique_ptr ApplyL2Headers(const QuicData& l3_packet); + + std::unique_ptr ConsumeL2Headers(const QuicData& l2_packet); + + int fd_ = -1; + size_t mtu_; + KernelInterface* kernel_; + NetlinkInterface* netlink_; + const std::string ifname_; + + const bool is_tap_; + uint8_t tap_mac_[ETH_ALEN]{}; + bool mac_initialized_ = false; + + StatsInterface* stats_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_PACKET_EXCHANGER_H_ diff --git a/quiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc b/quiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc new file mode 100644 index 000000000000..a6d3a39bfcf7 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/tun_device_packet_exchanger.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/bonnet/mock_packet_exchanger_stats_interface.h" +#include "quiche/quic/qbone/mock_qbone_client.h" +#include "quiche/quic/qbone/platform/mock_kernel.h" + +namespace quic::test { +namespace { + +const size_t kMtu = 1000; +const size_t kMaxPendingPackets = 5; +const int kFd = 15; + +using ::testing::_; +using ::testing::Invoke; +using ::testing::StrEq; +using ::testing::StrictMock; + +class MockVisitor : public QbonePacketExchanger::Visitor { + public: + MOCK_METHOD(void, OnReadError, (const std::string&), (override)); + MOCK_METHOD(void, OnWriteError, (const std::string&), (override)); +}; + +class TunDevicePacketExchangerTest : public QuicTest { + protected: + TunDevicePacketExchangerTest() + : exchanger_(kMtu, &mock_kernel_, nullptr, &mock_visitor_, + kMaxPendingPackets, false, &mock_stats_, + absl::string_view()) { + exchanger_.set_file_descriptor(kFd); + } + + ~TunDevicePacketExchangerTest() override = default; + + MockKernel mock_kernel_; + StrictMock mock_visitor_; + StrictMock mock_client_; + StrictMock mock_stats_; + TunDevicePacketExchanger exchanger_; +}; + +TEST_F(TunDevicePacketExchangerTest, WritePacketReturnsFalseOnError) { + std::string packet = "fake packet"; + EXPECT_CALL(mock_kernel_, write(kFd, _, packet.size())) + .WillOnce(Invoke([](int fd, const void* buf, size_t count) { + errno = ECOMM; + return -1; + })); + + EXPECT_CALL(mock_visitor_, OnWriteError(_)); + exchanger_.WritePacketToNetwork(packet.data(), packet.size()); +} + +TEST_F(TunDevicePacketExchangerTest, + WritePacketReturnFalseAndBlockedOnBlockedTunnel) { + std::string packet = "fake packet"; + EXPECT_CALL(mock_kernel_, write(kFd, _, packet.size())) + .WillOnce(Invoke([](int fd, const void* buf, size_t count) { + errno = EAGAIN; + return -1; + })); + + EXPECT_CALL(mock_stats_, OnWriteError(_)).Times(1); + exchanger_.WritePacketToNetwork(packet.data(), packet.size()); +} + +TEST_F(TunDevicePacketExchangerTest, WritePacketReturnsTrueOnSuccessfulWrite) { + std::string packet = "fake packet"; + EXPECT_CALL(mock_kernel_, write(kFd, _, packet.size())) + .WillOnce(Invoke([packet](int fd, const void* buf, size_t count) { + EXPECT_THAT(reinterpret_cast(buf), StrEq(packet)); + return count; + })); + + EXPECT_CALL(mock_stats_, OnPacketWritten(_)).Times(1); + exchanger_.WritePacketToNetwork(packet.data(), packet.size()); +} + +TEST_F(TunDevicePacketExchangerTest, ReadPacketReturnsNullOnError) { + EXPECT_CALL(mock_kernel_, read(kFd, _, kMtu)) + .WillOnce(Invoke([](int fd, void* buf, size_t count) { + errno = ECOMM; + return -1; + })); + EXPECT_CALL(mock_visitor_, OnReadError(_)); + exchanger_.ReadAndDeliverPacket(&mock_client_); +} + +TEST_F(TunDevicePacketExchangerTest, ReadPacketReturnsNullOnBlockedRead) { + EXPECT_CALL(mock_kernel_, read(kFd, _, kMtu)) + .WillOnce(Invoke([](int fd, void* buf, size_t count) { + errno = EAGAIN; + return -1; + })); + EXPECT_CALL(mock_stats_, OnReadError(_)).Times(1); + EXPECT_FALSE(exchanger_.ReadAndDeliverPacket(&mock_client_)); +} + +TEST_F(TunDevicePacketExchangerTest, + ReadPacketReturnsThePacketOnSuccessfulRead) { + std::string packet = "fake_packet"; + EXPECT_CALL(mock_kernel_, read(kFd, _, kMtu)) + .WillOnce(Invoke([packet](int fd, void* buf, size_t count) { + memcpy(buf, packet.data(), packet.size()); + return packet.size(); + })); + EXPECT_CALL(mock_client_, ProcessPacketFromNetwork(StrEq(packet))); + EXPECT_CALL(mock_stats_, OnPacketRead(_)).Times(1); + EXPECT_TRUE(exchanger_.ReadAndDeliverPacket(&mock_client_)); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/qbone/bonnet/tun_device_test.cc b/quiche/quic/qbone/bonnet/tun_device_test.cc new file mode 100644 index 000000000000..cc82a2f5ba55 --- /dev/null +++ b/quiche/quic/qbone/bonnet/tun_device_test.cc @@ -0,0 +1,211 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/bonnet/tun_device.h" + +#include +#include +#include + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/platform/mock_kernel.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::StrEq; +using ::testing::Unused; + +const char kDeviceName[] = "tun0"; +const int kSupportedFeatures = + IFF_TUN | IFF_TAP | IFF_MULTI_QUEUE | IFF_ONE_QUEUE | IFF_NO_PI; + +// Quite a bit of EXPECT_CALL().Times(AnyNumber()).WillRepeatedly() are used to +// make sure we can correctly set common expectations and override the +// expectation with later call to EXPECT_CALL(). ON_CALL cannot be used here +// since when EPXECT_CALL overrides ON_CALL, it ignores the parameter matcher +// which results in unexpected call even if ON_CALL exists. +class TunDeviceTest : public QuicTest { + protected: + void SetUp() override { + EXPECT_CALL(mock_kernel_, socket(AF_INET6, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](Unused, Unused, Unused) { + EXPECT_CALL(mock_kernel_, close(next_fd_)).WillOnce(Return(0)); + return next_fd_++; + })); + } + + // Set the expectations for calling Init(). + void SetInitExpectations(int mtu, bool persist) { + EXPECT_CALL(mock_kernel_, open(StrEq("/dev/net/tun"), _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](Unused, Unused) { + EXPECT_CALL(mock_kernel_, close(next_fd_)).WillOnce(Return(0)); + return next_fd_++; + })); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNGETFEATURES, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([](Unused, Unused, void* argp) { + auto* actual_flags = reinterpret_cast(argp); + *actual_flags = kSupportedFeatures; + return 0; + })); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNSETIFF, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([](Unused, Unused, void* argp) { + auto* ifr = reinterpret_cast(argp); + EXPECT_EQ(IFF_TUN | IFF_MULTI_QUEUE | IFF_NO_PI, ifr->ifr_flags); + EXPECT_THAT(ifr->ifr_name, StrEq(kDeviceName)); + return 0; + })); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNSETPERSIST, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([persist](Unused, Unused, void* argp) { + auto* ifr = reinterpret_cast(argp); + if (persist) { + EXPECT_THAT(ifr->ifr_name, StrEq(kDeviceName)); + } else { + EXPECT_EQ(nullptr, ifr); + } + return 0; + })); + EXPECT_CALL(mock_kernel_, ioctl(_, SIOCSIFMTU, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([mtu](Unused, Unused, void* argp) { + auto* ifr = reinterpret_cast(argp); + EXPECT_EQ(mtu, ifr->ifr_mtu); + EXPECT_THAT(ifr->ifr_name, StrEq(kDeviceName)); + return 0; + })); + } + + // Expect that Up() will be called. Force the call to fail when fail == true. + void ExpectUp(bool fail) { + EXPECT_CALL(mock_kernel_, ioctl(_, SIOCSIFFLAGS, _)) + .WillOnce(Invoke([fail](Unused, Unused, void* argp) { + auto* ifr = reinterpret_cast(argp); + EXPECT_TRUE(ifr->ifr_flags & IFF_UP); + EXPECT_THAT(ifr->ifr_name, StrEq(kDeviceName)); + if (fail) { + return -1; + } else { + return 0; + } + })); + } + + // Expect that Down() will be called *after* the interface is up. Force the + // call to fail when fail == true. + void ExpectDown(bool fail) { + EXPECT_CALL(mock_kernel_, ioctl(_, SIOCSIFFLAGS, _)) + .WillOnce(Invoke([fail](Unused, Unused, void* argp) { + auto* ifr = reinterpret_cast(argp); + EXPECT_FALSE(ifr->ifr_flags & IFF_UP); + EXPECT_THAT(ifr->ifr_name, StrEq(kDeviceName)); + if (fail) { + return -1; + } else { + return 0; + } + })); + } + + MockKernel mock_kernel_; + int next_fd_ = 100; +}; + +// A TunTapDevice can be initialized and up +TEST_F(TunDeviceTest, BasicWorkFlow) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); + EXPECT_TRUE(tun_device.Init()); + EXPECT_GT(tun_device.GetFileDescriptor(), -1); + + ExpectUp(/* fail = */ false); + EXPECT_TRUE(tun_device.Up()); + ExpectDown(/* fail = */ false); +} + +TEST_F(TunDeviceTest, FailToOpenTunDevice) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); + EXPECT_CALL(mock_kernel_, open(StrEq("/dev/net/tun"), _)) + .WillOnce(Return(-1)); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); + ExpectDown(false); +} + +TEST_F(TunDeviceTest, FailToCheckFeature) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNGETFEATURES, _)).WillOnce(Return(-1)); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); + ExpectDown(false); +} + +TEST_F(TunDeviceTest, TooFewFeature) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNGETFEATURES, _)) + .WillOnce(Invoke([](Unused, Unused, void* argp) { + int* actual_features = reinterpret_cast(argp); + *actual_features = IFF_TUN | IFF_ONE_QUEUE; + return 0; + })); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); + ExpectDown(false); +} + +TEST_F(TunDeviceTest, FailToSetFlag) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNSETIFF, _)).WillOnce(Return(-1)); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); +} + +TEST_F(TunDeviceTest, FailToPersistDevice) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); + EXPECT_CALL(mock_kernel_, ioctl(_, TUNSETPERSIST, _)).WillOnce(Return(-1)); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); +} + +TEST_F(TunDeviceTest, FailToOpenSocket) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); + EXPECT_CALL(mock_kernel_, socket(AF_INET6, _, _)).WillOnce(Return(-1)); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); +} + +TEST_F(TunDeviceTest, FailToSetMtu) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); + EXPECT_CALL(mock_kernel_, ioctl(_, SIOCSIFMTU, _)).WillOnce(Return(-1)); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); + EXPECT_FALSE(tun_device.Init()); + EXPECT_EQ(tun_device.GetFileDescriptor(), -1); +} + +TEST_F(TunDeviceTest, FailToUp) { + SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); + EXPECT_TRUE(tun_device.Init()); + EXPECT_GT(tun_device.GetFileDescriptor(), -1); + + ExpectUp(/* fail = */ true); + EXPECT_FALSE(tun_device.Up()); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/qbone/mock_qbone_client.h b/quiche/quic/qbone/mock_qbone_client.h new file mode 100644 index 000000000000..8c278e915e24 --- /dev/null +++ b/quiche/quic/qbone/mock_qbone_client.h @@ -0,0 +1,22 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_MOCK_QBONE_CLIENT_H_ +#define QUICHE_QUIC_QBONE_MOCK_QBONE_CLIENT_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/qbone_client_interface.h" + +namespace quic { + +class MockQboneClient : public QboneClientInterface { + public: + MOCK_METHOD(void, ProcessPacketFromNetwork, (absl::string_view packet), + (override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_MOCK_QBONE_CLIENT_H_ diff --git a/quiche/quic/qbone/mock_qbone_server_session.h b/quiche/quic/qbone/mock_qbone_server_session.h new file mode 100644 index 000000000000..eb0e4ea85ce0 --- /dev/null +++ b/quiche/quic/qbone/mock_qbone_server_session.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_MOCK_QBONE_SERVER_SESSION_H_ +#define QUICHE_QUIC_QBONE_MOCK_QBONE_SERVER_SESSION_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/qbone_server_session.h" + +namespace quic { + +class MockQboneServerSession : public QboneServerSession { + public: + explicit MockQboneServerSession(QuicConnection* connection) + : QboneServerSession(CurrentSupportedVersions(), connection, + /*owner=*/nullptr, + /*config=*/{}, + /*quic_crypto_server_config=*/nullptr, + /*compressed_certs_cache=*/nullptr, + /*writer=*/nullptr, + /*self_ip=*/QuicIpAddress::Loopback6(), + /*client_ip=*/QuicIpAddress::Loopback6(), + /*client_ip_subnet_length=*/0, + /*handler=*/nullptr) {} + + MOCK_METHOD(bool, SendClientRequest, (const QboneClientRequest&), (override)); + + MOCK_METHOD(void, ProcessPacketFromNetwork, (absl::string_view), (override)); + MOCK_METHOD(void, ProcessPacketFromPeer, (absl::string_view), (override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_MOCK_QBONE_SERVER_SESSION_H_ diff --git a/quiche/quic/qbone/platform/icmp_packet.cc b/quiche/quic/qbone/platform/icmp_packet.cc new file mode 100644 index 000000000000..3f15bb71f228 --- /dev/null +++ b/quiche/quic/qbone/platform/icmp_packet.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/icmp_packet.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/qbone/platform/internet_checksum.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { +namespace { + +constexpr size_t kIPv6AddressSize = sizeof(in6_addr); +constexpr size_t kIPv6HeaderSize = sizeof(ip6_hdr); +constexpr size_t kICMPv6HeaderSize = sizeof(icmp6_hdr); +constexpr size_t kIPv6MinPacketSize = 1280; + +// Hop limit set to 255 to satisfy: +// https://datatracker.ietf.org/doc/html/rfc4861#section-11.2 +constexpr size_t kIcmpTtl = 255; +constexpr size_t kICMPv6BodyMaxSize = + kIPv6MinPacketSize - kIPv6HeaderSize - kICMPv6HeaderSize; + +struct ICMPv6Packet { + ip6_hdr ip_header; + icmp6_hdr icmp_header; + uint8_t body[kICMPv6BodyMaxSize]; +}; + +// pseudo header as described in RFC 2460 Section 8.1 (excluding addresses) +struct IPv6PseudoHeader { + uint32_t payload_size{}; + uint8_t zeros[3] = {0, 0, 0}; + uint8_t next_header = IPPROTO_ICMPV6; +}; + +} // namespace + +void CreateIcmpPacket(in6_addr src, in6_addr dst, const icmp6_hdr& icmp_header, + absl::string_view body, + const std::function& cb) { + const size_t body_size = std::min(body.size(), kICMPv6BodyMaxSize); + const size_t payload_size = kICMPv6HeaderSize + body_size; + + ICMPv6Packet icmp_packet{}; + // Set version to 6. + icmp_packet.ip_header.ip6_vfc = 0x6 << 4; + // Set the payload size, protocol and TTL. + icmp_packet.ip_header.ip6_plen = + quiche::QuicheEndian::HostToNet16(payload_size); + icmp_packet.ip_header.ip6_nxt = IPPROTO_ICMPV6; + icmp_packet.ip_header.ip6_hops = kIcmpTtl; + // Set the source address to the specified self IP. + icmp_packet.ip_header.ip6_src = src; + icmp_packet.ip_header.ip6_dst = dst; + + icmp_packet.icmp_header = icmp_header; + // Per RFC 4443 Section 2.3, set checksum field to 0 prior to computing it + icmp_packet.icmp_header.icmp6_cksum = 0; + + IPv6PseudoHeader pseudo_header{}; + pseudo_header.payload_size = quiche::QuicheEndian::HostToNet32(payload_size); + + InternetChecksum checksum; + // Pseudoheader. + checksum.Update(icmp_packet.ip_header.ip6_src.s6_addr, kIPv6AddressSize); + checksum.Update(icmp_packet.ip_header.ip6_dst.s6_addr, kIPv6AddressSize); + checksum.Update(reinterpret_cast(&pseudo_header), + sizeof(pseudo_header)); + // ICMP header. + checksum.Update(reinterpret_cast(&icmp_packet.icmp_header), + sizeof(icmp_packet.icmp_header)); + // Body. + checksum.Update(body.data(), body_size); + icmp_packet.icmp_header.icmp6_cksum = checksum.Value(); + + memcpy(icmp_packet.body, body.data(), body_size); + + const char* packet = reinterpret_cast(&icmp_packet); + const size_t packet_size = offsetof(ICMPv6Packet, body) + body_size; + + cb(absl::string_view(packet, packet_size)); +} + +} // namespace quic diff --git a/quiche/quic/qbone/platform/icmp_packet.h b/quiche/quic/qbone/platform/icmp_packet.h new file mode 100644 index 000000000000..c83513c5bb18 --- /dev/null +++ b/quiche/quic/qbone/platform/icmp_packet.h @@ -0,0 +1,27 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_ICMP_PACKET_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_ICMP_PACKET_H_ + +#include +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_ip_address.h" + +namespace quic { + +// Creates an ICMPv6 packet, returning a packed string representation of the +// packet to |cb|. The resulting packet is given to a callback because it's +// stack allocated inside CreateIcmpPacket. +void CreateIcmpPacket(in6_addr src, in6_addr dst, const icmp6_hdr& icmp_header, + absl::string_view body, + const std::function& cb); + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_ICMP_PACKET_H_ diff --git a/quiche/quic/qbone/platform/icmp_packet_test.cc b/quiche/quic/qbone/platform/icmp_packet_test.cc new file mode 100644 index 000000000000..d94475103c3b --- /dev/null +++ b/quiche/quic/qbone/platform/icmp_packet_test.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/icmp_packet.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { + +constexpr char kReferenceSourceAddress[] = "fe80:1:2:3:4::1"; +constexpr char kReferenceDestinationAddress[] = "fe80:4:3:2:1::1"; + +// clang-format off +constexpr uint8_t kReferenceICMPMessageBody[] { + 0xd2, 0x61, 0x29, 0x5b, 0x00, 0x00, 0x00, 0x00, + 0x0d, 0x59, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37 +}; + +constexpr uint8_t kReferenceICMPPacket[] = { + // START IPv6 Header + // IPv6 with zero TOS and flow label. + 0x60, 0x00, 0x00, 0x00, + // Payload is 64 bytes + 0x00, 0x40, + // Next header is 58 + 0x3a, + // Hop limit is 255 + 0xFF, + // Source address of fe80:1:2:3:4::1 + 0xfe, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, + 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Destination address of fe80:4:3:2:1::1 + 0xfe, 0x80, 0x00, 0x04, 0x00, 0x03, 0x00, 0x02, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // END IPv6 Header + // START ICMPv6 Header + // Echo Request, zero code + 0x80, 0x00, + // Checksum + 0xec, 0x00, + // Identifier + 0xcb, 0x82, + // Sequence Number + 0x00, 0x01, + // END ICMPv6 Header + // Message body + 0xd2, 0x61, 0x29, 0x5b, 0x00, 0x00, 0x00, 0x00, + 0x0d, 0x59, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37 +}; +// clang-format on + +} // namespace + +TEST(IcmpPacketTest, CreatedPacketMatchesReference) { + QuicIpAddress src; + ASSERT_TRUE(src.FromString(kReferenceSourceAddress)); + in6_addr src_addr; + memcpy(src_addr.s6_addr, src.ToPackedString().data(), sizeof(in6_addr)); + + QuicIpAddress dst; + ASSERT_TRUE(dst.FromString(kReferenceDestinationAddress)); + in6_addr dst_addr; + memcpy(dst_addr.s6_addr, dst.ToPackedString().data(), sizeof(in6_addr)); + + icmp6_hdr icmp_header{}; + icmp_header.icmp6_type = ICMP6_ECHO_REQUEST; + icmp_header.icmp6_id = 0x82cb; + icmp_header.icmp6_seq = 0x0100; + + absl::string_view message_body = absl::string_view( + reinterpret_cast(kReferenceICMPMessageBody), 56); + absl::string_view expected_packet = absl::string_view( + reinterpret_cast(kReferenceICMPPacket), 104); + CreateIcmpPacket(src_addr, dst_addr, icmp_header, message_body, + [&expected_packet](absl::string_view packet) { + QUIC_LOG(INFO) << quiche::QuicheTextUtils::HexDump(packet); + ASSERT_EQ(packet, expected_packet); + }); +} + +TEST(IcmpPacketTest, NonZeroChecksumIsIgnored) { + QuicIpAddress src; + ASSERT_TRUE(src.FromString(kReferenceSourceAddress)); + in6_addr src_addr; + memcpy(src_addr.s6_addr, src.ToPackedString().data(), sizeof(in6_addr)); + + QuicIpAddress dst; + ASSERT_TRUE(dst.FromString(kReferenceDestinationAddress)); + in6_addr dst_addr; + memcpy(dst_addr.s6_addr, dst.ToPackedString().data(), sizeof(in6_addr)); + + icmp6_hdr icmp_header{}; + icmp_header.icmp6_type = ICMP6_ECHO_REQUEST; + icmp_header.icmp6_id = 0x82cb; + icmp_header.icmp6_seq = 0x0100; + // Set the checksum to a bogus value + icmp_header.icmp6_cksum = 0x1234; + + absl::string_view message_body = absl::string_view( + reinterpret_cast(kReferenceICMPMessageBody), 56); + absl::string_view expected_packet = absl::string_view( + reinterpret_cast(kReferenceICMPPacket), 104); + CreateIcmpPacket(src_addr, dst_addr, icmp_header, message_body, + [&expected_packet](absl::string_view packet) { + QUIC_LOG(INFO) << quiche::QuicheTextUtils::HexDump(packet); + ASSERT_EQ(packet, expected_packet); + }); +} + +} // namespace quic diff --git a/quiche/quic/qbone/platform/internet_checksum.cc b/quiche/quic/qbone/platform/internet_checksum.cc new file mode 100644 index 000000000000..f9901e5c8ecf --- /dev/null +++ b/quiche/quic/qbone/platform/internet_checksum.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/internet_checksum.h" + +#include +#include + +namespace quic { + +void InternetChecksum::Update(const char* data, size_t size) { + const char* current; + for (current = data; current + 1 < data + size; current += 2) { + uint16_t v; + memcpy(&v, current, sizeof(v)); + accumulator_ += v; + } + if (current < data + size) { + accumulator_ += *reinterpret_cast(current); + } +} + +void InternetChecksum::Update(const uint8_t* data, size_t size) { + Update(reinterpret_cast(data), size); +} + +uint16_t InternetChecksum::Value() const { + uint32_t total = accumulator_; + while (total & 0xffff0000u) { + total = (total >> 16u) + (total & 0xffffu); + } + return ~static_cast(total); +} + +} // namespace quic diff --git a/quiche/quic/qbone/platform/internet_checksum.h b/quiche/quic/qbone/platform/internet_checksum.h new file mode 100644 index 000000000000..85d24155f29d --- /dev/null +++ b/quiche/quic/qbone/platform/internet_checksum.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_INTERNET_CHECKSUM_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_INTERNET_CHECKSUM_H_ + +#include +#include + +namespace quic { + +// Incrementally compute an Internet header checksum as described in RFC 1071. +class InternetChecksum { + public: + // Update the checksum with the specified data. Note that while the checksum + // is commutative, the data has to be supplied in the units of two-byte words. + // If there is an extra byte at the end, the function has to be called on it + // last. + void Update(const char* data, size_t size); + + void Update(const uint8_t* data, size_t size); + + uint16_t Value() const; + + private: + uint32_t accumulator_ = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_INTERNET_CHECKSUM_H_ diff --git a/quiche/quic/qbone/platform/internet_checksum_test.cc b/quiche/quic/qbone/platform/internet_checksum_test.cc new file mode 100644 index 000000000000..8033c2ea6b19 --- /dev/null +++ b/quiche/quic/qbone/platform/internet_checksum_test.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/internet_checksum.h" + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace { + +// From the Numerical Example described in RFC 1071 +// https://tools.ietf.org/html/rfc1071#section-3 +TEST(InternetChecksumTest, MatchesRFC1071Example) { + uint8_t data[] = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7}; + + InternetChecksum checksum; + checksum.Update(data, 8); + uint16_t result = checksum.Value(); + auto* result_bytes = reinterpret_cast(&result); + ASSERT_EQ(0x22, result_bytes[0]); + ASSERT_EQ(0x0d, result_bytes[1]); +} + +// Same as above, except 7 bytes. Should behave as if there was an 8th byte +// that equals 0. +TEST(InternetChecksumTest, MatchesRFC1071ExampleWithOddByteCount) { + uint8_t data[] = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6}; + + InternetChecksum checksum; + checksum.Update(data, 7); + uint16_t result = checksum.Value(); + auto* result_bytes = reinterpret_cast(&result); + ASSERT_EQ(0x23, result_bytes[0]); + ASSERT_EQ(0x04, result_bytes[1]); +} + +// From the example described at: +// http://www.cs.berkeley.edu/~kfall/EE122/lec06/tsld023.htm +TEST(InternetChecksumTest, MatchesBerkleyExample) { + uint8_t data[] = {0xe3, 0x4f, 0x23, 0x96, 0x44, 0x27, 0x99, 0xf3}; + + InternetChecksum checksum; + checksum.Update(data, 8); + uint16_t result = checksum.Value(); + auto* result_bytes = reinterpret_cast(&result); + ASSERT_EQ(0x1a, result_bytes[0]); + ASSERT_EQ(0xff, result_bytes[1]); +} + +TEST(InternetChecksumTest, ChecksumRequiringMultipleCarriesInLittleEndian) { + uint8_t data[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00}; + + // Data will accumulate to 0x0002FFFF + // Summing lower and upper halves gives 0x00010001 + // Second sum of lower and upper halves gives 0x0002 + // One's complement gives 0xfffd, or [0xfd, 0xff] in network byte order + InternetChecksum checksum; + checksum.Update(data, 8); + uint16_t result = checksum.Value(); + auto* result_bytes = reinterpret_cast(&result); + EXPECT_EQ(0xfd, result_bytes[0]); + EXPECT_EQ(0xff, result_bytes[1]); +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/qbone/platform/ip_range.cc b/quiche/quic/qbone/platform/ip_range.cc new file mode 100644 index 000000000000..2ef1ef7e65e2 --- /dev/null +++ b/quiche/quic/qbone/platform/ip_range.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/ip_range.h" + +#include "quiche/common/quiche_endian.h" + +namespace quic { + +namespace { + +constexpr size_t kIPv4Size = 32; +constexpr size_t kIPv6Size = 128; + +QuicIpAddress TruncateToLength(const QuicIpAddress& input, + size_t* prefix_length) { + QuicIpAddress output; + if (input.IsIPv4()) { + if (*prefix_length > kIPv4Size) { + *prefix_length = kIPv4Size; + return input; + } + uint32_t raw_address = + *reinterpret_cast(input.ToPackedString().data()); + raw_address = quiche::QuicheEndian::NetToHost32(raw_address); + raw_address &= ~0U << (kIPv4Size - *prefix_length); + raw_address = quiche::QuicheEndian::HostToNet32(raw_address); + output.FromPackedString(reinterpret_cast(&raw_address), + sizeof(raw_address)); + return output; + } + if (input.IsIPv6()) { + if (*prefix_length > kIPv6Size) { + *prefix_length = kIPv6Size; + return input; + } + uint64_t raw_address[2]; + memcpy(raw_address, input.ToPackedString().data(), sizeof(raw_address)); + // raw_address[0] holds higher 8 bytes in big endian and raw_address[1] + // holds lower 8 bytes. Converting each to little endian for us to mask bits + // out. + // The endianess between raw_address[0] and raw_address[1] is handled + // explicitly by handling lower and higher bytes separately. + raw_address[0] = quiche::QuicheEndian::NetToHost64(raw_address[0]); + raw_address[1] = quiche::QuicheEndian::NetToHost64(raw_address[1]); + if (*prefix_length <= kIPv6Size / 2) { + raw_address[0] &= ~uint64_t{0} << (kIPv6Size / 2 - *prefix_length); + raw_address[1] = 0; + } else { + raw_address[1] &= ~uint64_t{0} << (kIPv6Size - *prefix_length); + } + raw_address[0] = quiche::QuicheEndian::HostToNet64(raw_address[0]); + raw_address[1] = quiche::QuicheEndian::HostToNet64(raw_address[1]); + output.FromPackedString(reinterpret_cast(raw_address), + sizeof(raw_address)); + return output; + } + return output; +} + +} // namespace + +IpRange::IpRange(const QuicIpAddress& prefix, size_t prefix_length) + : prefix_(prefix), prefix_length_(prefix_length) { + prefix_ = TruncateToLength(prefix_, &prefix_length_); +} + +bool IpRange::operator==(IpRange other) const { + return prefix_ == other.prefix_ && prefix_length_ == other.prefix_length_; +} + +bool IpRange::operator!=(IpRange other) const { return !(*this == other); } + +bool IpRange::FromString(const std::string& range) { + size_t slash_pos = range.find('/'); + if (slash_pos == std::string::npos) { + return false; + } + QuicIpAddress prefix; + bool success = prefix.FromString(range.substr(0, slash_pos)); + if (!success) { + return false; + } + uint64_t num_processed = 0; + size_t prefix_length = std::stoi(range.substr(slash_pos + 1), &num_processed); + if (num_processed + 1 + slash_pos != range.length()) { + return false; + } + prefix_ = TruncateToLength(prefix, &prefix_length); + prefix_length_ = prefix_length; + return true; +} + +QuicIpAddress IpRange::FirstAddressInRange() const { return prefix(); } + +} // namespace quic diff --git a/quiche/quic/qbone/platform/ip_range.h b/quiche/quic/qbone/platform/ip_range.h new file mode 100644 index 000000000000..fcffea6502b6 --- /dev/null +++ b/quiche/quic/qbone/platform/ip_range.h @@ -0,0 +1,61 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_IP_RANGE_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_IP_RANGE_H_ + +#include "absl/strings/str_cat.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" + +namespace quic { + +class IpRange { + public: + // Default constructor to have an uninitialized IpRange. + IpRange() : prefix_length_(0) {} + + // prefix will be automatically truncated to prefix_length, so that any bit + // after prefix_length are zero. + IpRange(const QuicIpAddress& prefix, size_t prefix_length); + + bool operator==(IpRange other) const; + bool operator!=(IpRange other) const; + + // Parses range that looks like "10.0.0.1/8". Tailing bits will be set to zero + // after prefix_length. Return false if the parsing failed. + bool FromString(const std::string& range); + + // Returns the string representation of this object. + std::string ToString() const { + if (IsInitialized()) { + return absl::StrCat(prefix_.ToString(), "/", prefix_length_); + } + return "(uninitialized)"; + } + + // Whether this object is initialized. + bool IsInitialized() const { return prefix_.IsInitialized(); } + + // Returns the first available IP address in this IpRange. The resulting + // address will be uninitialized if there is no available address. + QuicIpAddress FirstAddressInRange() const; + + // The address family of this IpRange. + IpAddressFamily address_family() const { return prefix_.address_family(); } + + // The subnet's prefix address. + QuicIpAddress prefix() const { return prefix_; } + + // The subnet's prefix length. + size_t prefix_length() const { return prefix_length_; } + + private: + QuicIpAddress prefix_; + size_t prefix_length_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_IP_RANGE_H_ diff --git a/quiche/quic/qbone/platform/ip_range_test.cc b/quiche/quic/qbone/platform/ip_range_test.cc new file mode 100644 index 000000000000..a0444eab9ee1 --- /dev/null +++ b/quiche/quic/qbone/platform/ip_range_test.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/ip_range.h" + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace { + +TEST(IpRangeTest, TruncateWorksIPv4) { + QuicIpAddress before_truncate; + before_truncate.FromString("255.255.255.255"); + EXPECT_EQ("128.0.0.0/1", IpRange(before_truncate, 1).ToString()); + EXPECT_EQ("192.0.0.0/2", IpRange(before_truncate, 2).ToString()); + EXPECT_EQ("255.224.0.0/11", IpRange(before_truncate, 11).ToString()); + EXPECT_EQ("255.255.255.224/27", IpRange(before_truncate, 27).ToString()); + EXPECT_EQ("255.255.255.254/31", IpRange(before_truncate, 31).ToString()); + EXPECT_EQ("255.255.255.255/32", IpRange(before_truncate, 32).ToString()); + EXPECT_EQ("255.255.255.255/32", IpRange(before_truncate, 33).ToString()); +} + +TEST(IpRangeTest, TruncateWorksIPv6) { + QuicIpAddress before_truncate; + before_truncate.FromString("ffff:ffff:ffff:ffff:f903::5"); + EXPECT_EQ("fe00::/7", IpRange(before_truncate, 7).ToString()); + EXPECT_EQ("ffff:ffff:ffff::/48", IpRange(before_truncate, 48).ToString()); + EXPECT_EQ("ffff:ffff:ffff:ffff::/64", + IpRange(before_truncate, 64).ToString()); + EXPECT_EQ("ffff:ffff:ffff:ffff:8000::/65", + IpRange(before_truncate, 65).ToString()); + EXPECT_EQ("ffff:ffff:ffff:ffff:f903::4/127", + IpRange(before_truncate, 127).ToString()); +} + +TEST(IpRangeTest, FromStringWorksIPv4) { + IpRange range; + ASSERT_TRUE(range.FromString("127.0.3.249/26")); + EXPECT_EQ("127.0.3.192/26", range.ToString()); +} + +TEST(IpRangeTest, FromStringWorksIPv6) { + IpRange range; + ASSERT_TRUE(range.FromString("ff01:8f21:77f9::/33")); + EXPECT_EQ("ff01:8f21::/33", range.ToString()); +} + +TEST(IpRangeTest, FirstAddressWorksIPv6) { + IpRange range; + ASSERT_TRUE(range.FromString("ffff:ffff::/64")); + QuicIpAddress first_address = range.FirstAddressInRange(); + EXPECT_EQ("ffff:ffff::", first_address.ToString()); +} + +TEST(IpRangeTest, FirstAddressWorksIPv4) { + IpRange range; + ASSERT_TRUE(range.FromString("10.0.0.0/24")); + QuicIpAddress first_address = range.FirstAddressInRange(); + EXPECT_EQ("10.0.0.0", first_address.ToString()); +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/qbone/platform/kernel_interface.h b/quiche/quic/qbone/platform/kernel_interface.h new file mode 100644 index 000000000000..b7eb5d06eb77 --- /dev/null +++ b/quiche/quic/qbone/platform/kernel_interface.h @@ -0,0 +1,149 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_KERNEL_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_KERNEL_INTERFACE_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace quic { + +// A wrapper for making syscalls to the kernel, so that syscalls can be +// mocked during testing. +class KernelInterface { + public: + virtual ~KernelInterface() {} + virtual int bind(int fd, const struct sockaddr* addr, socklen_t addr_len) = 0; + virtual int close(int fd) = 0; + virtual int ioctl(int fd, int request, void* argp) = 0; + virtual int open(const char* pathname, int flags) = 0; + virtual ssize_t read(int fd, void* buf, size_t count) = 0; + virtual ssize_t recvfrom(int sockfd, void* buf, size_t len, int flags, + struct sockaddr* src_addr, socklen_t* addrlen) = 0; + virtual ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags) = 0; + virtual ssize_t sendto(int sockfd, const void* buf, size_t len, int flags, + const struct sockaddr* dest_addr, + socklen_t addrlen) = 0; + virtual int socket(int domain, int type, int protocol) = 0; + virtual int setsockopt(int fd, int level, int optname, const void* optval, + socklen_t optlen) = 0; + virtual ssize_t write(int fd, const void* buf, size_t count) = 0; +}; + +// It is unfortunate to have R here, but std::result_of cannot be used. +template +auto SyscallRetryOnError(R r, F f, Params&&... params) + -> decltype(f(std::forward(params)...)) { + static_assert( + std::is_same(params)...)), R>::value, + "Return type does not match"); + decltype(f(std::forward(params)...)) result; + do { + result = f(std::forward(params)...); + } while (result == r && errno == EINTR); + return result; +} + +template +auto SyscallRetry(F f, Params&&... params) + -> decltype(f(std::forward(params)...)) { + return SyscallRetryOnError(-1, f, std::forward(params)...); +} + +template +class ParametrizedKernel final : public KernelInterface { + public: + static_assert(std::is_trivially_destructible::value, + "Runner is used as static, must be trivially destructible"); + + ~ParametrizedKernel() override {} + + int bind(int fd, const struct sockaddr* addr, socklen_t addr_len) override { + static Runner syscall("bind"); + return syscall.Retry(&::bind, fd, addr, addr_len); + } + int close(int fd) override { + static Runner syscall("close"); + return syscall.Retry(&::close, fd); + } + int ioctl(int fd, int request, void* argp) override { + static Runner syscall("ioctl"); + return syscall.Retry(&::ioctl, fd, request, argp); + } + int open(const char* pathname, int flags) override { + static Runner syscall("open"); + return syscall.Retry(&::open, pathname, flags); + } + ssize_t read(int fd, void* buf, size_t count) override { + static Runner syscall("read"); + return syscall.Run(&::read, fd, buf, count); + } + ssize_t recvfrom(int sockfd, void* buf, size_t len, int flags, + struct sockaddr* src_addr, socklen_t* addrlen) override { + static Runner syscall("recvfrom"); + return syscall.RetryOnError(&::recvfrom, static_cast(-1), sockfd, + buf, len, flags, src_addr, addrlen); + } + ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags) override { + static Runner syscall("sendmsg"); + return syscall.RetryOnError(&::sendmsg, static_cast(-1), sockfd, + msg, flags); + } + ssize_t sendto(int sockfd, const void* buf, size_t len, int flags, + const struct sockaddr* dest_addr, socklen_t addrlen) override { + static Runner syscall("sendto"); + return syscall.RetryOnError(&::sendto, static_cast(-1), sockfd, + buf, len, flags, dest_addr, addrlen); + } + int socket(int domain, int type, int protocol) override { + static Runner syscall("socket"); + return syscall.Retry(&::socket, domain, type, protocol); + } + int setsockopt(int fd, int level, int optname, const void* optval, + socklen_t optlen) override { + static Runner syscall("setsockopt"); + return syscall.Retry(&::setsockopt, fd, level, optname, optval, optlen); + } + ssize_t write(int fd, const void* buf, size_t count) override { + static Runner syscall("write"); + return syscall.Run(&::write, fd, buf, count); + } +}; + +class DefaultKernelRunner { + public: + explicit DefaultKernelRunner(const char* name) {} + + template + static auto RetryOnError(F f, R r, Params&&... params) + -> decltype(f(std::forward(params)...)) { + return SyscallRetryOnError(r, f, std::forward(params)...); + } + + template + static auto Retry(F f, Params&&... params) + -> decltype(f(std::forward(params)...)) { + return SyscallRetry(f, std::forward(params)...); + } + + template + static auto Run(F f, Params&&... params) + -> decltype(f(std::forward(params)...)) { + return f(std::forward(params)...); + } +}; + +using Kernel = ParametrizedKernel; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_KERNEL_INTERFACE_H_ diff --git a/quiche/quic/qbone/platform/mock_kernel.h b/quiche/quic/qbone/platform/mock_kernel.h new file mode 100644 index 000000000000..a69446e2a623 --- /dev/null +++ b/quiche/quic/qbone/platform/mock_kernel.h @@ -0,0 +1,41 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_MOCK_KERNEL_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_MOCK_KERNEL_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/platform/kernel_interface.h" + +namespace quic { + +class MockKernel : public KernelInterface { + public: + MockKernel() {} + + MOCK_METHOD(int, bind, (int fd, const struct sockaddr*, socklen_t addr_len), + (override)); + MOCK_METHOD(int, close, (int fd), (override)); + MOCK_METHOD(int, ioctl, (int fd, int request, void*), (override)); + MOCK_METHOD(int, open, (const char*, int flags), (override)); + MOCK_METHOD(ssize_t, read, (int fd, void*, size_t count), (override)); + MOCK_METHOD(ssize_t, recvfrom, + (int sockfd, void*, size_t len, int flags, struct sockaddr*, + socklen_t*), + (override)); + MOCK_METHOD(ssize_t, sendmsg, (int sockfd, const struct msghdr*, int flags), + (override)); + MOCK_METHOD(ssize_t, sendto, + (int sockfd, const void*, size_t len, int flags, + const struct sockaddr*, socklen_t addrlen), + (override)); + MOCK_METHOD(int, socket, (int domain, int type, int protocol), (override)); + MOCK_METHOD(int, setsockopt, (int, int, int, const void*, socklen_t), + (override)); + MOCK_METHOD(ssize_t, write, (int fd, const void*, size_t count), (override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_MOCK_KERNEL_H_ diff --git a/quiche/quic/qbone/platform/mock_netlink.h b/quiche/quic/qbone/platform/mock_netlink.h new file mode 100644 index 000000000000..72e3b666fa77 --- /dev/null +++ b/quiche/quic/qbone/platform/mock_netlink.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_MOCK_NETLINK_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_MOCK_NETLINK_H_ + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/platform/netlink_interface.h" + +namespace quic { + +class MockNetlink : public NetlinkInterface { + public: + MOCK_METHOD(bool, GetLinkInfo, (const std::string&, LinkInfo*), (override)); + + MOCK_METHOD(bool, GetAddresses, + (int, uint8_t, std::vector*, int*), (override)); + + MOCK_METHOD(bool, ChangeLocalAddress, + (uint32_t, Verb, const QuicIpAddress&, uint8_t, uint8_t, uint8_t, + const std::vector&), + (override)); + + MOCK_METHOD(bool, GetRouteInfo, (std::vector*), (override)); + + MOCK_METHOD(bool, ChangeRoute, + (Verb, uint32_t, const IpRange&, uint8_t, QuicIpAddress, int32_t, + uint32_t), + (override)); + + MOCK_METHOD(bool, GetRuleInfo, (std::vector*), (override)); + + MOCK_METHOD(bool, ChangeRule, (Verb, uint32_t, IpRange), (override)); + + MOCK_METHOD(bool, Send, (struct iovec*, size_t), (override)); + + MOCK_METHOD(bool, Recv, (uint32_t, NetlinkParserInterface*), (override)); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_MOCK_NETLINK_H_ diff --git a/quiche/quic/qbone/platform/netlink.cc b/quiche/quic/qbone/platform/netlink.cc new file mode 100644 index 000000000000..0f20576286ae --- /dev/null +++ b/quiche/quic/qbone/platform/netlink.cc @@ -0,0 +1,856 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/netlink.h" + +#include + +#include + +#include "absl/base/attributes.h" +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/qbone/platform/rtnetlink_message.h" +#include "quiche/quic/qbone/qbone_constants.h" + +namespace quic { + +Netlink::Netlink(KernelInterface* kernel) : kernel_(kernel) { + seq_ = QuicRandom::GetInstance()->RandUint64(); +} + +Netlink::~Netlink() { CloseSocket(); } + +void Netlink::ResetRecvBuf(size_t size) { + if (size != 0) { + recvbuf_ = std::make_unique(size); + } else { + recvbuf_ = nullptr; + } + recvbuf_length_ = size; +} + +bool Netlink::OpenSocket() { + if (socket_fd_ >= 0) { + return true; + } + + socket_fd_ = kernel_->socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + + if (socket_fd_ < 0) { + QUIC_PLOG(ERROR) << "can't open netlink socket"; + return false; + } + + QUIC_LOG(INFO) << "Opened a new netlink socket fd = " << socket_fd_; + + // bind a local address to the socket + sockaddr_nl myaddr; + memset(&myaddr, 0, sizeof(myaddr)); + myaddr.nl_family = AF_NETLINK; + if (kernel_->bind(socket_fd_, reinterpret_cast(&myaddr), + sizeof(myaddr)) < 0) { + QUIC_LOG(INFO) << "can't bind address to socket"; + CloseSocket(); + return false; + } + + return true; +} + +void Netlink::CloseSocket() { + if (socket_fd_ >= 0) { + QUIC_LOG(INFO) << "Closing netlink socket fd = " << socket_fd_; + kernel_->close(socket_fd_); + } + ResetRecvBuf(0); + socket_fd_ = -1; +} + +namespace { + +class LinkInfoParser : public NetlinkParserInterface { + public: + LinkInfoParser(std::string interface_name, Netlink::LinkInfo* link_info) + : interface_name_(std::move(interface_name)), link_info_(link_info) {} + + void Run(struct nlmsghdr* netlink_message) override { + if (netlink_message->nlmsg_type != RTM_NEWLINK) { + QUIC_LOG(INFO) << absl::StrCat( + "Unexpected nlmsg_type: ", netlink_message->nlmsg_type, + " expected: ", RTM_NEWLINK); + return; + } + + struct ifinfomsg* interface_info = + reinterpret_cast(NLMSG_DATA(netlink_message)); + + // make sure interface_info is what we asked for. + if (interface_info->ifi_family != AF_UNSPEC) { + QUIC_LOG(INFO) << absl::StrCat( + "Unexpected ifi_family: ", interface_info->ifi_family, + " expected: ", AF_UNSPEC); + return; + } + + char hardware_address[kHwAddrSize]; + size_t hardware_address_length = 0; + char broadcast_address[kHwAddrSize]; + size_t broadcast_address_length = 0; + std::string name; + + // loop through the attributes + struct rtattr* rta; + int payload_length = IFLA_PAYLOAD(netlink_message); + for (rta = IFLA_RTA(interface_info); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + int attribute_length; + switch (rta->rta_type) { + case IFLA_ADDRESS: { + attribute_length = RTA_PAYLOAD(rta); + if (attribute_length > kHwAddrSize) { + QUIC_VLOG(2) << "IFLA_ADDRESS too long: " << attribute_length; + break; + } + memmove(hardware_address, RTA_DATA(rta), attribute_length); + hardware_address_length = attribute_length; + break; + } + case IFLA_BROADCAST: { + attribute_length = RTA_PAYLOAD(rta); + if (attribute_length > kHwAddrSize) { + QUIC_VLOG(2) << "IFLA_BROADCAST too long: " << attribute_length; + break; + } + memmove(broadcast_address, RTA_DATA(rta), attribute_length); + broadcast_address_length = attribute_length; + break; + } + case IFLA_IFNAME: { + name = std::string(reinterpret_cast(RTA_DATA(rta)), + RTA_PAYLOAD(rta)); + // The name maybe a 0 terminated c string. + name = name.substr(0, name.find('\0')); + break; + } + } + } + + QUIC_VLOG(2) << "interface name: " << name + << ", index: " << interface_info->ifi_index; + + if (name == interface_name_) { + link_info_->index = interface_info->ifi_index; + link_info_->type = interface_info->ifi_type; + link_info_->hardware_address_length = hardware_address_length; + if (hardware_address_length > 0) { + memmove(&link_info_->hardware_address, hardware_address, + hardware_address_length); + } + link_info_->broadcast_address_length = broadcast_address_length; + if (broadcast_address_length > 0) { + memmove(&link_info_->broadcast_address, broadcast_address, + broadcast_address_length); + } + found_link_ = true; + } + } + + bool found_link() { return found_link_; } + + private: + const std::string interface_name_; + Netlink::LinkInfo* const link_info_; + bool found_link_ = false; +}; + +} // namespace + +bool Netlink::GetLinkInfo(const std::string& interface_name, + LinkInfo* link_info) { + auto message = LinkMessage::New(RtnetlinkMessage::Operation::GET, + NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST, + seq_, getpid(), nullptr); + + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed."; + return false; + } + + // Pass the parser to the receive routine. It may be called multiple times + // since there may be multiple reply packets each with multiple reply + // messages. + LinkInfoParser parser(interface_name, link_info); + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "recv failed."; + return false; + } + + return parser.found_link(); +} + +namespace { + +class LocalAddressParser : public NetlinkParserInterface { + public: + LocalAddressParser(int interface_index, uint8_t unwanted_flags, + std::vector* local_addresses, + int* num_ipv6_nodad_dadfailed_addresses) + : interface_index_(interface_index), + unwanted_flags_(unwanted_flags), + local_addresses_(local_addresses), + num_ipv6_nodad_dadfailed_addresses_( + num_ipv6_nodad_dadfailed_addresses) {} + + void Run(struct nlmsghdr* netlink_message) override { + // each nlmsg contains a header and multiple address attributes. + if (netlink_message->nlmsg_type != RTM_NEWADDR) { + QUIC_LOG(INFO) << "Unexpected nlmsg_type: " << netlink_message->nlmsg_type + << " expected: " << RTM_NEWADDR; + return; + } + + struct ifaddrmsg* interface_address = + reinterpret_cast(NLMSG_DATA(netlink_message)); + + // Make sure this is for an address family we're interested in. + if (interface_address->ifa_family != AF_INET && + interface_address->ifa_family != AF_INET6) { + QUIC_VLOG(2) << absl::StrCat("uninteresting ifa family: ", + interface_address->ifa_family); + return; + } + + // Keep track of addresses with both 'nodad' and 'dadfailed', this really + // should't be possible and is likely a kernel bug. + if (num_ipv6_nodad_dadfailed_addresses_ != nullptr && + (interface_address->ifa_flags & IFA_F_NODAD) && + (interface_address->ifa_flags & IFA_F_DADFAILED)) { + ++(*num_ipv6_nodad_dadfailed_addresses_); + } + + uint8_t unwanted_flags = interface_address->ifa_flags & unwanted_flags_; + if (unwanted_flags != 0) { + QUIC_VLOG(2) << absl::StrCat("unwanted ifa flags: ", unwanted_flags); + return; + } + + // loop through the attributes + struct rtattr* rta; + int payload_length = IFA_PAYLOAD(netlink_message); + Netlink::AddressInfo address_info; + for (rta = IFA_RTA(interface_address); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + // There's quite a lot of confusion in Linux over the use of IFA_LOCAL and + // IFA_ADDRESS (source and destination address). For broadcast links, such + // as Ethernet, they are identical (see ), but the kernel + // sometimes uses only one or the other. We'll return both so that the + // caller can decide which to use. + if (rta->rta_type != IFA_LOCAL && rta->rta_type != IFA_ADDRESS) { + QUIC_VLOG(2) << "Ignoring uninteresting rta_type: " << rta->rta_type; + continue; + } + + switch (interface_address->ifa_family) { + case AF_INET: + ABSL_FALLTHROUGH_INTENDED; + case AF_INET6: + // QuicIpAddress knows how to parse ip from raw bytes as long as they + // are in network byte order. + if (RTA_PAYLOAD(rta) == sizeof(struct in_addr) || + RTA_PAYLOAD(rta) == sizeof(struct in6_addr)) { + auto* raw_ip = reinterpret_cast(RTA_DATA(rta)); + if (rta->rta_type == IFA_LOCAL) { + address_info.local_address.FromPackedString(raw_ip, + RTA_PAYLOAD(rta)); + } else { + address_info.interface_address.FromPackedString(raw_ip, + RTA_PAYLOAD(rta)); + } + } + break; + default: + QUIC_LOG(ERROR) << absl::StrCat("Unknown address family: ", + interface_address->ifa_family); + } + } + + QUIC_VLOG(2) << "local_address: " << address_info.local_address.ToString() + << " interface_address: " + << address_info.interface_address.ToString() + << " index: " << interface_address->ifa_index; + if (interface_address->ifa_index != interface_index_) { + return; + } + + address_info.prefix_length = interface_address->ifa_prefixlen; + address_info.scope = interface_address->ifa_scope; + if (address_info.local_address.IsInitialized() || + address_info.interface_address.IsInitialized()) { + local_addresses_->push_back(address_info); + } + } + + private: + const int interface_index_; + const uint8_t unwanted_flags_; + std::vector* const local_addresses_; + int* const num_ipv6_nodad_dadfailed_addresses_; +}; + +} // namespace + +bool Netlink::GetAddresses(int interface_index, uint8_t unwanted_flags, + std::vector* addresses, + int* num_ipv6_nodad_dadfailed_addresses) { + // the message doesn't contain the index, we'll have to do the filtering while + // parsing the reply. This is because NLM_F_MATCH, which only returns entries + // that matches the request criteria, is not yet implemented (see man 3 + // netlink). + auto message = AddressMessage::New(RtnetlinkMessage::Operation::GET, + NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST, + seq_, getpid(), nullptr); + + // the send routine returns the socket to listen on. + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed."; + return false; + } + + addresses->clear(); + if (num_ipv6_nodad_dadfailed_addresses != nullptr) { + *num_ipv6_nodad_dadfailed_addresses = 0; + } + + LocalAddressParser parser(interface_index, unwanted_flags, addresses, + num_ipv6_nodad_dadfailed_addresses); + // Pass the parser to the receive routine. It may be called multiple times + // since there may be multiple reply packets each with multiple reply + // messages. + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "recv failed"; + return false; + } + return true; +} + +namespace { + +class UnknownParser : public NetlinkParserInterface { + public: + void Run(struct nlmsghdr* netlink_message) override { + QUIC_LOG(INFO) << "nlmsg reply type: " << netlink_message->nlmsg_type; + } +}; + +} // namespace + +bool Netlink::ChangeLocalAddress( + uint32_t interface_index, Verb verb, const QuicIpAddress& address, + uint8_t prefix_length, uint8_t ifa_flags, uint8_t ifa_scope, + const std::vector& additional_attributes) { + if (verb == Verb::kReplace) { + return false; + } + auto operation = verb == Verb::kAdd ? RtnetlinkMessage::Operation::NEW + : RtnetlinkMessage::Operation::DEL; + uint8_t address_family; + if (address.address_family() == IpAddressFamily::IP_V4) { + address_family = AF_INET; + } else if (address.address_family() == IpAddressFamily::IP_V6) { + address_family = AF_INET6; + } else { + return false; + } + + struct ifaddrmsg address_header = {address_family, prefix_length, ifa_flags, + ifa_scope, interface_index}; + + auto message = AddressMessage::New(operation, NLM_F_REQUEST | NLM_F_ACK, seq_, + getpid(), &address_header); + + for (const auto& attribute : additional_attributes) { + if (attribute->rta_type == IFA_LOCAL) { + continue; + } + message.AppendAttribute(attribute->rta_type, RTA_DATA(attribute), + RTA_PAYLOAD(attribute)); + } + + message.AppendAttribute(IFA_LOCAL, address.ToPackedString().c_str(), + address.ToPackedString().size()); + + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed"; + return false; + } + + UnknownParser parser; + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "receive failed."; + return false; + } + return true; +} + +namespace { + +class RoutingRuleParser : public NetlinkParserInterface { + public: + explicit RoutingRuleParser(std::vector* routing_rules) + : routing_rules_(routing_rules) {} + + void Run(struct nlmsghdr* netlink_message) override { + if (netlink_message->nlmsg_type != RTM_NEWROUTE) { + QUIC_LOG(WARNING) << absl::StrCat( + "Unexpected nlmsg_type: ", netlink_message->nlmsg_type, + " expected: ", RTM_NEWROUTE); + return; + } + + auto* route = reinterpret_cast(NLMSG_DATA(netlink_message)); + int payload_length = RTM_PAYLOAD(netlink_message); + + if (route->rtm_family != AF_INET && route->rtm_family != AF_INET6) { + QUIC_VLOG(2) << absl::StrCat("Uninteresting family: ", route->rtm_family); + return; + } + + Netlink::RoutingRule rule; + rule.scope = route->rtm_scope; + rule.table = route->rtm_table; + rule.init_cwnd = Netlink::kUnspecifiedInitCwnd; + + struct rtattr* rta; + for (rta = RTM_RTA(route); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case RTA_TABLE: { + rule.table = *reinterpret_cast(RTA_DATA(rta)); + break; + } + case RTA_DST: { + QuicIpAddress destination; + destination.FromPackedString(reinterpret_cast RTA_DATA(rta), + RTA_PAYLOAD(rta)); + rule.destination_subnet = IpRange(destination, route->rtm_dst_len); + break; + } + case RTA_PREFSRC: { + QuicIpAddress preferred_source; + rule.preferred_source.FromPackedString( + reinterpret_cast RTA_DATA(rta), RTA_PAYLOAD(rta)); + break; + } + case RTA_OIF: { + rule.out_interface = *reinterpret_cast(RTA_DATA(rta)); + break; + } + case RTA_METRICS: { + struct rtattr* rtax; + int rta_payload_length = RTA_PAYLOAD(rta); + for (rtax = reinterpret_cast(RTA_DATA(rta)); + RTA_OK(rtax, rta_payload_length); + rtax = RTA_NEXT(rtax, rta_payload_length)) { + switch (rtax->rta_type) { + case RTAX_INITCWND: { + rule.init_cwnd = *reinterpret_cast(RTA_DATA(rtax)); + break; + } + default: { + QUIC_VLOG(2) << absl::StrCat( + "Uninteresting RTA_METRICS attribute: ", rtax->rta_type); + } + } + } + break; + } + default: { + QUIC_VLOG(2) << absl::StrCat("Uninteresting attribute: ", + rta->rta_type); + } + } + } + routing_rules_->push_back(rule); + } + + private: + std::vector* routing_rules_; +}; + +} // namespace + +bool Netlink::GetRouteInfo(std::vector* routing_rules) { + rtmsg route_message{}; + // Only manipulate main routing table. + route_message.rtm_table = RT_TABLE_MAIN; + + auto message = RouteMessage::New(RtnetlinkMessage::Operation::GET, + NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH, + seq_, getpid(), &route_message); + + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed"; + return false; + } + + RoutingRuleParser parser(routing_rules); + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "recv failed"; + return false; + } + + return true; +} + +bool Netlink::ChangeRoute(Netlink::Verb verb, uint32_t table, + const IpRange& destination_subnet, uint8_t scope, + QuicIpAddress preferred_source, + int32_t interface_index, uint32_t init_cwnd) { + if (!destination_subnet.prefix().IsInitialized()) { + return false; + } + if (destination_subnet.address_family() != IpAddressFamily::IP_V4 && + destination_subnet.address_family() != IpAddressFamily::IP_V6) { + return false; + } + if (preferred_source.IsInitialized() && + preferred_source.address_family() != + destination_subnet.address_family()) { + return false; + } + + RtnetlinkMessage::Operation operation; + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK; + switch (verb) { + case Verb::kAdd: + operation = RtnetlinkMessage::Operation::NEW; + // Setting NLM_F_EXCL so that an existing entry for this subnet will fail + // the request. NLM_F_CREATE is necessary to indicate this is trying to + // create a new entry - simply having RTM_NEWROUTE is not enough even the + // name suggests so. + flags |= NLM_F_EXCL | NLM_F_CREATE; + break; + case Verb::kRemove: + operation = RtnetlinkMessage::Operation::DEL; + break; + case Verb::kReplace: + operation = RtnetlinkMessage::Operation::NEW; + // Setting NLM_F_REPLACE to tell the kernel that existing entry for this + // subnet should be replaced. + flags |= NLM_F_REPLACE | NLM_F_CREATE; + break; + } + + struct rtmsg route_message; + memset(&route_message, 0, sizeof(route_message)); + route_message.rtm_family = + destination_subnet.address_family() == IpAddressFamily::IP_V4 ? AF_INET + : AF_INET6; + // rtm_dst_len and rtm_src_len are actually the subnet prefix lengths. Poor + // naming. + route_message.rtm_dst_len = destination_subnet.prefix_length(); + // 0 means no source subnet for this rule. + route_message.rtm_src_len = 0; + // Only program the main table. Other tables are intended for the kernel to + // manage. + route_message.rtm_table = RT_TABLE_MAIN; + // Use RTPROT_UNSPEC to match all the different protocol. Rules added by + // kernel have RTPROT_KERNEL. Rules added by the root user have RTPROT_STATIC + // instead. + route_message.rtm_protocol = + verb == Verb::kRemove ? RTPROT_UNSPEC : RTPROT_STATIC; + route_message.rtm_scope = scope; + // Only add unicast routing rule. + route_message.rtm_type = RTN_UNICAST; + auto message = + RouteMessage::New(operation, flags, seq_, getpid(), &route_message); + + message.AppendAttribute(RTA_TABLE, &table, sizeof(table)); + + if (init_cwnd != kUnspecifiedInitCwnd) { + char data[RTA_LENGTH(sizeof(uint32_t))]; + struct rtattr* rta = reinterpret_cast(data); + rta->rta_type = RTAX_INITCWND; + rta->rta_len = sizeof(data); + *reinterpret_cast(RTA_DATA(rta)) = init_cwnd; + message.AppendAttribute(RTA_METRICS, data, sizeof(data)); + } + + // RTA_OIF is the target interface for this rule. + message.AppendAttribute(RTA_OIF, &interface_index, sizeof(interface_index)); + // The actual destination subnet must be truncated of all the tailing zeros. + message.AppendAttribute( + RTA_DST, + reinterpret_cast( + destination_subnet.prefix().ToPackedString().c_str()), + destination_subnet.prefix().ToPackedString().size()); + // This is the source address to use in the IP packet should this routing rule + // is used. + if (preferred_source.IsInitialized()) { + auto src_str = preferred_source.ToPackedString(); + message.AppendAttribute(RTA_PREFSRC, + reinterpret_cast(src_str.c_str()), + src_str.size()); + } + + if (verb != Verb::kRemove) { + auto gateway_str = QboneConstants::GatewayAddress()->ToPackedString(); + message.AppendAttribute(RTA_GATEWAY, + reinterpret_cast(gateway_str.c_str()), + gateway_str.size()); + } + + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed"; + return false; + } + + UnknownParser parser; + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "receive failed."; + return false; + } + return true; +} + +namespace { + +class IpRuleParser : public NetlinkParserInterface { + public: + explicit IpRuleParser(std::vector* ip_rules) + : ip_rules_(ip_rules) {} + + void Run(struct nlmsghdr* netlink_message) override { + if (netlink_message->nlmsg_type != RTM_NEWRULE) { + QUIC_LOG(WARNING) << absl::StrCat( + "Unexpected nlmsg_type: ", netlink_message->nlmsg_type, + " expected: ", RTM_NEWRULE); + return; + } + + auto* rule = reinterpret_cast(NLMSG_DATA(netlink_message)); + int payload_length = RTM_PAYLOAD(netlink_message); + + if (rule->rtm_family != AF_INET6) { + QUIC_LOG(ERROR) << absl::StrCat("Unexpected family: ", rule->rtm_family); + return; + } + + Netlink::IpRule ip_rule; + ip_rule.table = rule->rtm_table; + + struct rtattr* rta; + for (rta = RTM_RTA(rule); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case RTA_TABLE: { + ip_rule.table = *reinterpret_cast(RTA_DATA(rta)); + break; + } + case RTA_SRC: { + QuicIpAddress src_addr; + src_addr.FromPackedString(reinterpret_cast(RTA_DATA(rta)), + RTA_PAYLOAD(rta)); + IpRange src_range(src_addr, rule->rtm_src_len); + ip_rule.source_range = src_range; + break; + } + default: { + QUIC_VLOG(2) << absl::StrCat("Uninteresting attribute: ", + rta->rta_type); + } + } + } + ip_rules_->emplace_back(ip_rule); + } + + private: + std::vector* ip_rules_; +}; + +} // namespace + +bool Netlink::GetRuleInfo(std::vector* ip_rules) { + rtmsg rule_message{}; + rule_message.rtm_family = AF_INET6; + + auto message = RuleMessage::New(RtnetlinkMessage::Operation::GET, + NLM_F_REQUEST | NLM_F_DUMP, seq_, getpid(), + &rule_message); + + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed"; + return false; + } + + IpRuleParser parser(ip_rules); + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "receive failed."; + return false; + } + return true; +} + +bool Netlink::ChangeRule(Verb verb, uint32_t table, IpRange source_range) { + RtnetlinkMessage::Operation operation; + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK; + + rtmsg rule_message{}; + rule_message.rtm_family = AF_INET6; + rule_message.rtm_protocol = RTPROT_STATIC; + rule_message.rtm_scope = RT_SCOPE_UNIVERSE; + rule_message.rtm_table = RT_TABLE_UNSPEC; + + rule_message.rtm_flags |= FIB_RULE_FIND_SADDR; + + switch (verb) { + case Verb::kAdd: + if (!source_range.IsInitialized()) { + QUIC_LOG(ERROR) << "Source range must be initialized."; + return false; + } + operation = RtnetlinkMessage::Operation::NEW; + flags |= NLM_F_EXCL | NLM_F_CREATE; + rule_message.rtm_type = FRA_DST; + rule_message.rtm_src_len = source_range.prefix_length(); + break; + case Verb::kRemove: + operation = RtnetlinkMessage::Operation::DEL; + break; + case Verb::kReplace: + QUIC_LOG(ERROR) << "Unsupported verb: kReplace"; + return false; + } + auto message = + RuleMessage::New(operation, flags, seq_, getpid(), &rule_message); + + message.AppendAttribute(RTA_TABLE, &table, sizeof(table)); + + if (source_range.IsInitialized()) { + std::string packed_src = source_range.prefix().ToPackedString(); + message.AppendAttribute(RTA_SRC, + reinterpret_cast(packed_src.c_str()), + packed_src.size()); + } + + if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { + QUIC_LOG(ERROR) << "send failed"; + return false; + } + + UnknownParser parser; + if (!Recv(seq_++, &parser)) { + QUIC_LOG(ERROR) << "receive failed."; + return false; + } + return true; +} + +bool Netlink::Send(struct iovec* iov, size_t iovlen) { + if (!OpenSocket()) { + QUIC_LOG(ERROR) << "can't open socket"; + return false; + } + + // an address for communicating with the kernel netlink code + sockaddr_nl netlink_address; + memset(&netlink_address, 0, sizeof(netlink_address)); + netlink_address.nl_family = AF_NETLINK; + netlink_address.nl_pid = 0; // destination is kernel + netlink_address.nl_groups = 0; // no multicast + + struct msghdr msg = { + &netlink_address, sizeof(netlink_address), iov, iovlen, nullptr, 0, 0}; + + if (kernel_->sendmsg(socket_fd_, &msg, 0) < 0) { + QUIC_LOG(ERROR) << "sendmsg failed"; + CloseSocket(); + return false; + } + + return true; +} + +bool Netlink::Recv(uint32_t seq, NetlinkParserInterface* parser) { + sockaddr_nl netlink_address; + + // replies can span multiple packets + for (;;) { + socklen_t address_length = sizeof(netlink_address); + + // First, call recvfrom with buffer size of 0 and MSG_PEEK | MSG_TRUNC set + // so that we know the size of the incoming packet before actually receiving + // it. + int next_packet_size = kernel_->recvfrom( + socket_fd_, recvbuf_.get(), /* len = */ 0, MSG_PEEK | MSG_TRUNC, + reinterpret_cast(&netlink_address), &address_length); + if (next_packet_size < 0) { + QUIC_LOG(ERROR) + << "error recvfrom with MSG_PEEK | MSG_TRUNC to get packet length."; + CloseSocket(); + return false; + } + QUIC_VLOG(3) << "netlink packet size: " << next_packet_size; + if (next_packet_size > recvbuf_length_) { + QUIC_VLOG(2) << "resizing recvbuf to " << next_packet_size; + ResetRecvBuf(next_packet_size); + } + + // Get the packet for real. + memset(recvbuf_.get(), 0, recvbuf_length_); + int len = kernel_->recvfrom( + socket_fd_, recvbuf_.get(), recvbuf_length_, /* flags = */ 0, + reinterpret_cast(&netlink_address), &address_length); + QUIC_VLOG(3) << "recvfrom returned: " << len; + if (len < 0) { + QUIC_LOG(INFO) << "can't receive netlink packet"; + CloseSocket(); + return false; + } + + // there may be multiple nlmsg's in each reply packet + struct nlmsghdr* netlink_message; + for (netlink_message = reinterpret_cast(recvbuf_.get()); + NLMSG_OK(netlink_message, len); + netlink_message = NLMSG_NEXT(netlink_message, len)) { + QUIC_VLOG(3) << "netlink_message->nlmsg_type = " + << netlink_message->nlmsg_type; + // make sure this is to us + if (netlink_message->nlmsg_seq != seq) { + QUIC_LOG(INFO) << "netlink_message not meant for us." + << " seq: " << seq + << " nlmsg_seq: " << netlink_message->nlmsg_seq; + continue; + } + + // done with this whole reply (not just this particular packet) + if (netlink_message->nlmsg_type == NLMSG_DONE) { + return true; + } + if (netlink_message->nlmsg_type == NLMSG_ERROR) { + struct nlmsgerr* err = + reinterpret_cast(NLMSG_DATA(netlink_message)); + if (netlink_message->nlmsg_len < + NLMSG_LENGTH(sizeof(struct nlmsgerr))) { + QUIC_LOG(INFO) << "netlink_message ERROR truncated"; + } else { + // an ACK + if (err->error == 0) { + QUIC_VLOG(3) << "Netlink sent an ACK"; + return true; + } + QUIC_LOG(INFO) << "netlink_message ERROR: " << err->error; + } + return false; + } + + parser->Run(netlink_message); + } + } +} + +} // namespace quic diff --git a/quiche/quic/qbone/platform/netlink.h b/quiche/quic/qbone/platform/netlink.h new file mode 100644 index 000000000000..857b8bff4ac4 --- /dev/null +++ b/quiche/quic/qbone/platform/netlink.h @@ -0,0 +1,138 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_NETLINK_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_NETLINK_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/qbone/platform/ip_range.h" +#include "quiche/quic/qbone/platform/kernel_interface.h" +#include "quiche/quic/qbone/platform/netlink_interface.h" + +namespace quic { + +// A wrapper class to provide convenient methods of manipulating IP address and +// routing table using netlink (man 7 netlink) socket. More specifically, +// rtnetlink is used (man 7 rtnetlink). +// +// This class is not thread safe, but thread compatible, as long as callers can +// make sure Send and Recv pairs are executed in sequence for a particular +// query. +class Netlink : public NetlinkInterface { + public: + explicit Netlink(KernelInterface* kernel); + ~Netlink() override; + + // Gets the link information for the interface referred by the given + // interface_name. + // + // This is a synchronous communication. That should not be a problem since the + // kernel should answer immediately. + bool GetLinkInfo(const std::string& interface_name, + LinkInfo* link_info) override; + + // Gets the addresses for the given interface referred by the given + // interface_index. + // + // This is a synchronous communication. This should not be a problem since the + // kernel should answer immediately. + bool GetAddresses(int interface_index, uint8_t unwanted_flags, + std::vector* addresses, + int* num_ipv6_nodad_dadfailed_addresses) override; + + // Performs the given verb that modifies local addresses on the given + // interface_index. + // + // additional_attributes are RTAs (man 7 rtnelink) that will be sent together + // with the netlink message. Note that rta_len in each RTA is used to decide + // the length of the payload. The caller is responsible for making sure + // payload bytes are accessible after the RTA header. + bool ChangeLocalAddress( + uint32_t interface_index, Verb verb, const QuicIpAddress& address, + uint8_t prefix_length, uint8_t ifa_flags, uint8_t ifa_scope, + const std::vector& additional_attributes) override; + + // Gets the list of routing rules from the main routing table (RT_TABLE_MAIN), + // which is programmable. + // + // This is a synchronous communication. This should not be a problem since the + // kernel should answer immediately. + bool GetRouteInfo(std::vector* routing_rules) override; + + // Performs the given Verb on the matching rule in the main routing table + // (RT_TABLE_MAIN). + // + // preferred_source can be !IsInitialized(), in which case it will be omitted. + // + // init_cwnd will be left unspecified if set to 0. + // + // For Verb::kRemove, rule matching is done by (destination_subnet, scope, + // preferred_source, interface_index). Return true if a matching rule is + // found. interface_index can be 0 for wilecard. + // + // For Verb::kAdd, rule matching is done by destination_subnet. If a rule for + // the given destination_subnet already exists, nothing will happen and false + // is returned. + // + // For Verb::kReplace, rule matching is done by destination_subnet. If no + // matching rule is found, a new entry will be created. + bool ChangeRoute(Netlink::Verb verb, uint32_t table, + const IpRange& destination_subnet, uint8_t scope, + QuicIpAddress preferred_source, int32_t interface_index, + uint32_t init_cwnd) override; + + // Returns the set of all rules in the routing policy database. + bool GetRuleInfo(std::vector* ip_rules) override; + + // Performs the give verb on the matching rule in the routing policy database. + // When deleting a rule, the |source_range| may be unspecified, in which case + // the lowest priority rule from |table| will be removed. When adding a rule, + // the |source_address| must be specified. + bool ChangeRule(Verb verb, uint32_t table, IpRange source_range) override; + + // Sends a netlink message to the kernel. iov and iovlen represents an array + // of struct iovec to be fed into sendmsg. The caller needs to make sure the + // message conform to what's expected by NLMSG_* macros. + // + // This can be useful if more flexibility is needed than the provided + // convenient methods can provide. + bool Send(struct iovec* iov, size_t iovlen) override; + + // Receives a netlink message from the kernel. + // parser will be called on the caller's stack. + // + // This can be useful if more flexibility is needed than the provided + // convenient methods can provide. + // TODO(b/69412655): vectorize this. + bool Recv(uint32_t seq, NetlinkParserInterface* parser) override; + + private: + // Reset the size of recvbuf_ to size. If size is 0, recvbuf_ will be nullptr. + void ResetRecvBuf(size_t size); + + // Opens a netlink socket if not already opened. + bool OpenSocket(); + + // Closes the opened netlink socket. Noop if no netlink socket is opened. + void CloseSocket(); + + KernelInterface* kernel_; + int socket_fd_ = -1; + std::unique_ptr recvbuf_ = nullptr; + size_t recvbuf_length_ = 0; + uint32_t seq_; // next msg sequence number +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_NETLINK_H_ diff --git a/quiche/quic/qbone/platform/netlink_interface.h b/quiche/quic/qbone/platform/netlink_interface.h new file mode 100644 index 000000000000..c4fc42c6f69b --- /dev/null +++ b/quiche/quic/qbone/platform/netlink_interface.h @@ -0,0 +1,144 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_NETLINK_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_NETLINK_INTERFACE_H_ + +#include + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/qbone/platform/ip_range.h" + +namespace quic { + +constexpr int kHwAddrSize = 6; + +class NetlinkParserInterface { + public: + virtual ~NetlinkParserInterface() {} + virtual void Run(struct nlmsghdr* netlink_message) = 0; +}; + +// An interface providing convenience methods for manipulating IP address and +// routing table using netlink (man 7 netlink) socket. +class NetlinkInterface { + public: + virtual ~NetlinkInterface() = default; + + // Link information returned from GetLinkInfo. + struct LinkInfo { + int index; + uint8_t type; + uint8_t hardware_address[kHwAddrSize]; + uint8_t broadcast_address[kHwAddrSize]; + size_t hardware_address_length; // 0 if no hardware address found + size_t broadcast_address_length; // 0 if no broadcast address found + }; + + // Gets the link information for the interface referred by the given + // interface_name. + virtual bool GetLinkInfo(const std::string& interface_name, + LinkInfo* link_info) = 0; + + // Address information reported back from GetAddresses. + struct AddressInfo { + QuicIpAddress local_address; + QuicIpAddress interface_address; + uint8_t prefix_length = 0; + uint8_t scope = 0; + }; + + // Gets the addresses for the given interface referred by the given + // interface_index. + virtual bool GetAddresses(int interface_index, uint8_t unwanted_flags, + std::vector* addresses, + int* num_ipv6_nodad_dadfailed_addresses) = 0; + + enum class Verb { + kAdd, + kRemove, + kReplace, + }; + + // Performs the given verb that modifies local addresses on the given + // interface_index. + // + // additional_attributes are RTAs (man 7 rtnelink) that will be sent together + // with the netlink message. Note that rta_len in each RTA is used to decide + // the length of the payload. The caller is responsible for making sure + // payload bytes are accessible after the RTA header. + virtual bool ChangeLocalAddress( + uint32_t interface_index, Verb verb, const QuicIpAddress& address, + uint8_t prefix_length, uint8_t ifa_flags, uint8_t ifa_scope, + const std::vector& additional_attributes) = 0; + + static constexpr uint32_t kUnspecifiedInitCwnd = 0; + + // Routing rule reported back from GetRouteInfo. + struct RoutingRule { + uint32_t table; + IpRange destination_subnet; + QuicIpAddress preferred_source; + uint8_t scope; + int out_interface; + uint32_t init_cwnd; // kUnspecifiedInitCwnd if unspecified + }; + + struct IpRule { + uint32_t table; + IpRange source_range; + }; + + // Gets the list of routing rules from the main routing table (RT_TABLE_MAIN), + // which is programmable. + virtual bool GetRouteInfo(std::vector* routing_rules) = 0; + + // Performs the given Verb on the matching rule in the main routing table + // (RT_TABLE_MAIN). + // + // preferred_source can be !IsInitialized(), in which case it will be omitted. + // + // For Verb::kRemove, rule matching is done by (destination_subnet, scope, + // preferred_source, interface_index). Return true if a matching rule is + // found. interface_index can be 0 for wilecard. + // + // For Verb::kAdd, rule matching is done by destination_subnet. If a rule for + // the given destination_subnet already exists, nothing will happen and false + // is returned. + // + // For Verb::kReplace, rule matching is done by destination_subnet. If no + // matching rule is found, a new entry will be created. + virtual bool ChangeRoute(Verb verb, uint32_t table, + const IpRange& destination_subnet, uint8_t scope, + QuicIpAddress preferred_source, + int32_t interface_index, uint32_t init_cwnd) = 0; + + // Returns the set of all rules in the routing policy database. + virtual bool GetRuleInfo(std::vector* ip_rules) = 0; + + // Performs the give verb on the matching rule in the routing policy database. + // When deleting a rule, the |source_range| may be unspecified, in which case + // the lowest priority rule from |table| will be removed. When adding a rule, + // the |source_address| must be specified. + virtual bool ChangeRule(Verb verb, uint32_t table, IpRange source_range) = 0; + + // Sends a netlink message to the kernel. iov and iovlen represents an array + // of struct iovec to be fed into sendmsg. The caller needs to make sure the + // message conform to what's expected by NLMSG_* macros. + // + // This can be useful if more flexibility is needed than the provided + // convenient methods can provide. + virtual bool Send(struct iovec* iov, size_t iovlen) = 0; + + // Receives a netlink message from the kernel. + // parser will be called on the caller's stack. + // + // This can be useful if more flexibility is needed than the provided + // convenient methods can provide. + virtual bool Recv(uint32_t seq, NetlinkParserInterface* parser) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_NETLINK_INTERFACE_H_ diff --git a/quiche/quic/qbone/platform/netlink_test.cc b/quiche/quic/qbone/platform/netlink_test.cc new file mode 100644 index 000000000000..b0ea1f0b8d11 --- /dev/null +++ b/quiche/quic/qbone/platform/netlink_test.cc @@ -0,0 +1,788 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/netlink.h" + +#include + +#include "absl/container/node_hash_set.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/platform/mock_kernel.h" +#include "quiche/quic/qbone/qbone_constants.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::Contains; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::Unused; + +const int kSocketFd = 101; + +class NetlinkTest : public QuicTest { + protected: + NetlinkTest() { + ON_CALL(mock_kernel_, socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)) + .WillByDefault(Invoke([this](Unused, Unused, Unused) { + EXPECT_CALL(mock_kernel_, close(kSocketFd)).WillOnce(Return(0)); + return kSocketFd; + })); + } + + void ExpectNetlinkPacket( + uint16_t type, uint16_t flags, + const std::function& + recv_callback, + const std::function& send_callback = + nullptr) { + static int seq = -1; + InSequence s; + + EXPECT_CALL(mock_kernel_, sendmsg(kSocketFd, _, _)) + .WillOnce(Invoke([type, flags, send_callback]( + Unused, const struct msghdr* msg, int) { + EXPECT_EQ(sizeof(struct sockaddr_nl), msg->msg_namelen); + auto* nl_addr = + reinterpret_cast(msg->msg_name); + EXPECT_EQ(AF_NETLINK, nl_addr->nl_family); + EXPECT_EQ(0, nl_addr->nl_pid); + EXPECT_EQ(0, nl_addr->nl_groups); + + EXPECT_GE(msg->msg_iovlen, 1); + EXPECT_GE(msg->msg_iov[0].iov_len, sizeof(struct nlmsghdr)); + + std::string buf; + for (int i = 0; i < msg->msg_iovlen; i++) { + buf.append( + std::string(reinterpret_cast(msg->msg_iov[i].iov_base), + msg->msg_iov[i].iov_len)); + } + + auto* netlink_message = + reinterpret_cast(buf.c_str()); + EXPECT_EQ(type, netlink_message->nlmsg_type); + EXPECT_EQ(flags, netlink_message->nlmsg_flags); + EXPECT_GE(buf.size(), netlink_message->nlmsg_len); + + if (send_callback != nullptr) { + send_callback(buf.c_str(), buf.size()); + } + + QUICHE_CHECK_EQ(seq, -1); + seq = netlink_message->nlmsg_seq; + return buf.size(); + })); + + EXPECT_CALL(mock_kernel_, + recvfrom(kSocketFd, _, 0, MSG_PEEK | MSG_TRUNC, _, _)) + .WillOnce(Invoke([this, recv_callback](Unused, Unused, Unused, Unused, + struct sockaddr* src_addr, + socklen_t* addrlen) { + auto* nl_addr = reinterpret_cast(src_addr); + nl_addr->nl_family = AF_NETLINK; + nl_addr->nl_pid = 0; // from kernel + nl_addr->nl_groups = 0; // no multicast + + int ret = recv_callback(reply_packet_, sizeof(reply_packet_), seq); + QUICHE_CHECK_LE(ret, sizeof(reply_packet_)); + return ret; + })); + + EXPECT_CALL(mock_kernel_, recvfrom(kSocketFd, _, _, _, _, _)) + .WillOnce(Invoke([recv_callback](Unused, void* buf, size_t len, Unused, + struct sockaddr* src_addr, + socklen_t* addrlen) { + auto* nl_addr = reinterpret_cast(src_addr); + nl_addr->nl_family = AF_NETLINK; + nl_addr->nl_pid = 0; // from kernel + nl_addr->nl_groups = 0; // no multicast + + int ret = recv_callback(buf, len, seq); + EXPECT_GE(len, ret); + seq = -1; + return ret; + })); + } + + char reply_packet_[4096]; + MockKernel mock_kernel_; +}; + +void AddRTA(struct nlmsghdr* netlink_message, uint16_t type, const void* data, + size_t len) { + auto* next_header_ptr = reinterpret_cast(netlink_message) + + NLMSG_ALIGN(netlink_message->nlmsg_len); + + auto* rta = reinterpret_cast(next_header_ptr); + rta->rta_type = type; + rta->rta_len = RTA_LENGTH(len); + memcpy(RTA_DATA(rta), data, len); + + netlink_message->nlmsg_len = + NLMSG_ALIGN(netlink_message->nlmsg_len) + RTA_LENGTH(len); +} + +void CreateIfinfomsg(struct nlmsghdr* netlink_message, + const std::string& interface_name, uint16_t type, + int index, unsigned int flags, unsigned int change, + uint8_t address[], int address_len, uint8_t broadcast[], + int broadcast_len) { + auto* interface_info = + reinterpret_cast(NLMSG_DATA(netlink_message)); + interface_info->ifi_family = AF_UNSPEC; + interface_info->ifi_type = type; + interface_info->ifi_index = index; + interface_info->ifi_flags = flags; + interface_info->ifi_change = change; + netlink_message->nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg)); + + // Add address + AddRTA(netlink_message, IFLA_ADDRESS, address, address_len); + + // Add broadcast address + AddRTA(netlink_message, IFLA_BROADCAST, broadcast, broadcast_len); + + // Add name + AddRTA(netlink_message, IFLA_IFNAME, interface_name.c_str(), + interface_name.size()); +} + +struct nlmsghdr* CreateNetlinkMessage(void* buf, // NOLINT + struct nlmsghdr* previous_netlink_message, + uint16_t type, int seq) { + auto* next_header_ptr = reinterpret_cast(buf); + if (previous_netlink_message != nullptr) { + next_header_ptr = reinterpret_cast(previous_netlink_message) + + NLMSG_ALIGN(previous_netlink_message->nlmsg_len); + } + auto* netlink_message = reinterpret_cast(next_header_ptr); + netlink_message->nlmsg_len = NLMSG_LENGTH(0); + netlink_message->nlmsg_type = type; + netlink_message->nlmsg_flags = NLM_F_MULTI; + netlink_message->nlmsg_pid = 0; // from the kernel + netlink_message->nlmsg_seq = seq; + + return netlink_message; +} + +void CreateIfaddrmsg(struct nlmsghdr* nlm, int interface_index, + unsigned char prefixlen, unsigned char flags, + unsigned char scope, QuicIpAddress ip) { + QUICHE_CHECK(ip.IsInitialized()); + unsigned char family; + switch (ip.address_family()) { + case IpAddressFamily::IP_V4: + family = AF_INET; + break; + case IpAddressFamily::IP_V6: + family = AF_INET6; + break; + default: + QUIC_BUG(quic_bug_11034_1) + << absl::StrCat("unexpected address family: ", ip.address_family()); + family = AF_UNSPEC; + } + auto* msg = reinterpret_cast(NLMSG_DATA(nlm)); + msg->ifa_family = family; + msg->ifa_prefixlen = prefixlen; + msg->ifa_flags = flags; + msg->ifa_scope = scope; + msg->ifa_index = interface_index; + nlm->nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); + + // Add local address + AddRTA(nlm, IFA_LOCAL, ip.ToPackedString().c_str(), + ip.ToPackedString().size()); +} + +void CreateRtmsg(struct nlmsghdr* nlm, unsigned char family, + unsigned char destination_length, unsigned char source_length, + unsigned char tos, unsigned char table, unsigned char protocol, + unsigned char scope, unsigned char type, unsigned int flags, + QuicIpAddress destination, int interface_index, + int init_cwnd) { + auto* msg = reinterpret_cast(NLMSG_DATA(nlm)); + msg->rtm_family = family; + msg->rtm_dst_len = destination_length; + msg->rtm_src_len = source_length; + msg->rtm_tos = tos; + msg->rtm_table = table; + msg->rtm_protocol = protocol; + msg->rtm_scope = scope; + msg->rtm_type = type; + msg->rtm_flags = flags; + nlm->nlmsg_len = NLMSG_LENGTH(sizeof(struct rtmsg)); + + // Add destination + AddRTA(nlm, RTA_DST, destination.ToPackedString().c_str(), + destination.ToPackedString().size()); + + // Add egress interface + AddRTA(nlm, RTA_OIF, &interface_index, sizeof(interface_index)); + + // Add initcwnd + if (init_cwnd > 0) { + char data[RTA_LENGTH(sizeof(uint32_t))]; + struct rtattr* rta = reinterpret_cast(data); + rta->rta_len = sizeof(data); + rta->rta_type = RTA_METRICS; + *reinterpret_cast(RTA_DATA(rta)) = init_cwnd; + AddRTA(nlm, RTA_METRICS, data, sizeof(data)); + } +} + +TEST_F(NetlinkTest, GetLinkInfoWorks) { + auto netlink = std::make_unique(&mock_kernel_); + + uint8_t hwaddr[] = {'a', 'b', 'c', 'd', 'e', 'f'}; + uint8_t bcaddr[] = {'c', 'b', 'a', 'f', 'e', 'd'}; + + ExpectNetlinkPacket( + RTM_GETLINK, NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST, + [&hwaddr, &bcaddr](void* buf, size_t len, int seq) { + int ret = 0; + + struct nlmsghdr* netlink_message = + CreateNetlinkMessage(buf, nullptr, RTM_NEWLINK, seq); + CreateIfinfomsg(netlink_message, "tun0", /* type = */ 1, + /* index = */ 7, + /* flags = */ 0, + /* change = */ 0xFFFFFFFF, hwaddr, 6, bcaddr, 6); + ret += NLMSG_ALIGN(netlink_message->nlmsg_len); + + netlink_message = + CreateNetlinkMessage(buf, netlink_message, NLMSG_DONE, seq); + ret += NLMSG_ALIGN(netlink_message->nlmsg_len); + + return ret; + }); + + Netlink::LinkInfo link_info; + EXPECT_TRUE(netlink->GetLinkInfo("tun0", &link_info)); + + EXPECT_EQ(7, link_info.index); + EXPECT_EQ(1, link_info.type); + + for (int i = 0; i < link_info.hardware_address_length; ++i) { + EXPECT_EQ(hwaddr[i], link_info.hardware_address[i]); + } + for (int i = 0; i < link_info.broadcast_address_length; ++i) { + EXPECT_EQ(bcaddr[i], link_info.broadcast_address[i]); + } +} + +TEST_F(NetlinkTest, GetAddressesWorks) { + auto netlink = std::make_unique(&mock_kernel_); + + absl::node_hash_set addresses = { + QuicIpAddress::Any4().ToString(), QuicIpAddress::Any6().ToString()}; + + ExpectNetlinkPacket( + RTM_GETADDR, NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST, + [&addresses](void* buf, size_t len, int seq) { + int ret = 0; + + struct nlmsghdr* nlm = nullptr; + + for (const auto& address : addresses) { + QuicIpAddress ip; + ip.FromString(address); + nlm = CreateNetlinkMessage(buf, nlm, RTM_NEWADDR, seq); + CreateIfaddrmsg(nlm, /* interface_index = */ 7, /* prefixlen = */ 24, + /* flags = */ 0, /* scope = */ RT_SCOPE_UNIVERSE, ip); + + ret += NLMSG_ALIGN(nlm->nlmsg_len); + } + + // Create IPs with unwanted flags. + { + QuicIpAddress ip; + ip.FromString("10.0.0.1"); + nlm = CreateNetlinkMessage(buf, nlm, RTM_NEWADDR, seq); + CreateIfaddrmsg(nlm, /* interface_index = */ 7, /* prefixlen = */ 16, + /* flags = */ IFA_F_OPTIMISTIC, /* scope = */ + RT_SCOPE_UNIVERSE, ip); + + ret += NLMSG_ALIGN(nlm->nlmsg_len); + + ip.FromString("10.0.0.2"); + nlm = CreateNetlinkMessage(buf, nlm, RTM_NEWADDR, seq); + CreateIfaddrmsg(nlm, /* interface_index = */ 7, /* prefixlen = */ 16, + /* flags = */ IFA_F_TENTATIVE, /* scope = */ + RT_SCOPE_UNIVERSE, ip); + + ret += NLMSG_ALIGN(nlm->nlmsg_len); + } + + nlm = CreateNetlinkMessage(buf, nlm, NLMSG_DONE, seq); + ret += NLMSG_ALIGN(nlm->nlmsg_len); + + return ret; + }); + + std::vector reported_addresses; + int num_ipv6_nodad_dadfailed_addresses = 0; + EXPECT_TRUE(netlink->GetAddresses(7, IFA_F_TENTATIVE | IFA_F_OPTIMISTIC, + &reported_addresses, + &num_ipv6_nodad_dadfailed_addresses)); + + for (const auto& reported_address : reported_addresses) { + EXPECT_TRUE(reported_address.local_address.IsInitialized()); + EXPECT_FALSE(reported_address.interface_address.IsInitialized()); + EXPECT_THAT(addresses, Contains(reported_address.local_address.ToString())); + addresses.erase(reported_address.local_address.ToString()); + + EXPECT_EQ(24, reported_address.prefix_length); + } + + EXPECT_TRUE(addresses.empty()); +} + +TEST_F(NetlinkTest, ChangeLocalAddressAdd) { + auto netlink = std::make_unique(&mock_kernel_); + + QuicIpAddress ip = QuicIpAddress::Any6(); + ExpectNetlinkPacket( + RTM_NEWADDR, NLM_F_ACK | NLM_F_REQUEST, + [](void* buf, size_t len, int seq) { + struct nlmsghdr* netlink_message = + CreateNetlinkMessage(buf, nullptr, NLMSG_ERROR, seq); + auto* err = + reinterpret_cast(NLMSG_DATA(netlink_message)); + // Ack the request + err->error = 0; + netlink_message->nlmsg_len = NLMSG_LENGTH(sizeof(struct nlmsgerr)); + return netlink_message->nlmsg_len; + }, + [ip](const void* buf, size_t len) { + auto* netlink_message = reinterpret_cast(buf); + auto* ifa = reinterpret_cast( + NLMSG_DATA(netlink_message)); + EXPECT_EQ(19, ifa->ifa_prefixlen); + EXPECT_EQ(RT_SCOPE_UNIVERSE, ifa->ifa_scope); + EXPECT_EQ(IFA_F_PERMANENT, ifa->ifa_flags); + EXPECT_EQ(7, ifa->ifa_index); + EXPECT_EQ(AF_INET6, ifa->ifa_family); + + const struct rtattr* rta; + int payload_length = IFA_PAYLOAD(netlink_message); + int num_rta = 0; + for (rta = IFA_RTA(ifa); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case IFA_LOCAL: { + EXPECT_EQ(ip.ToPackedString().size(), RTA_PAYLOAD(rta)); + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(ip, address); + break; + } + case IFA_CACHEINFO: { + EXPECT_EQ(sizeof(struct ifa_cacheinfo), RTA_PAYLOAD(rta)); + const auto* cache_info = + reinterpret_cast(RTA_DATA(rta)); + EXPECT_EQ(8, cache_info->ifa_prefered); // common_typos_disable + EXPECT_EQ(6, cache_info->ifa_valid); + EXPECT_EQ(4, cache_info->cstamp); + EXPECT_EQ(2, cache_info->tstamp); + break; + } + default: + EXPECT_TRUE(false) << "Seeing rtattr that should not exist"; + } + ++num_rta; + } + EXPECT_EQ(2, num_rta); + }); + + struct { + struct rtattr rta; + struct ifa_cacheinfo cache_info; + } additional_rta; + + additional_rta.rta.rta_type = IFA_CACHEINFO; + additional_rta.rta.rta_len = RTA_LENGTH(sizeof(struct ifa_cacheinfo)); + additional_rta.cache_info.ifa_prefered = 8; + additional_rta.cache_info.ifa_valid = 6; + additional_rta.cache_info.cstamp = 4; + additional_rta.cache_info.tstamp = 2; + + EXPECT_TRUE(netlink->ChangeLocalAddress(7, Netlink::Verb::kAdd, ip, 19, + IFA_F_PERMANENT, RT_SCOPE_UNIVERSE, + {&additional_rta.rta})); +} + +TEST_F(NetlinkTest, ChangeLocalAddressRemove) { + auto netlink = std::make_unique(&mock_kernel_); + + QuicIpAddress ip = QuicIpAddress::Any4(); + ExpectNetlinkPacket( + RTM_DELADDR, NLM_F_ACK | NLM_F_REQUEST, + [](void* buf, size_t len, int seq) { + struct nlmsghdr* netlink_message = + CreateNetlinkMessage(buf, nullptr, NLMSG_ERROR, seq); + auto* err = + reinterpret_cast(NLMSG_DATA(netlink_message)); + // Ack the request + err->error = 0; + netlink_message->nlmsg_len = NLMSG_LENGTH(sizeof(struct nlmsgerr)); + return netlink_message->nlmsg_len; + }, + [ip](const void* buf, size_t len) { + auto* netlink_message = reinterpret_cast(buf); + auto* ifa = reinterpret_cast( + NLMSG_DATA(netlink_message)); + EXPECT_EQ(32, ifa->ifa_prefixlen); + EXPECT_EQ(RT_SCOPE_UNIVERSE, ifa->ifa_scope); + EXPECT_EQ(0, ifa->ifa_flags); + EXPECT_EQ(7, ifa->ifa_index); + EXPECT_EQ(AF_INET, ifa->ifa_family); + + const struct rtattr* rta; + int payload_length = IFA_PAYLOAD(netlink_message); + int num_rta = 0; + for (rta = IFA_RTA(ifa); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case IFA_LOCAL: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(in_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(ip, address); + break; + } + default: + EXPECT_TRUE(false) << "Seeing rtattr that should not exist"; + } + ++num_rta; + } + EXPECT_EQ(1, num_rta); + }); + + EXPECT_TRUE(netlink->ChangeLocalAddress(7, Netlink::Verb::kRemove, ip, 32, 0, + RT_SCOPE_UNIVERSE, {})); +} + +TEST_F(NetlinkTest, GetRouteInfoWorks) { + auto netlink = std::make_unique(&mock_kernel_); + + QuicIpAddress destination; + ASSERT_TRUE(destination.FromString("f800::2")); + ExpectNetlinkPacket(RTM_GETROUTE, NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST, + [destination](void* buf, size_t len, int seq) { + int ret = 0; + struct nlmsghdr* netlink_message = CreateNetlinkMessage( + buf, nullptr, RTM_NEWROUTE, seq); + CreateRtmsg(netlink_message, AF_INET6, 48, 0, 0, + RT_TABLE_MAIN, RTPROT_STATIC, RT_SCOPE_LINK, + RTN_UNICAST, 0, destination, 7, 0); + ret += NLMSG_ALIGN(netlink_message->nlmsg_len); + + netlink_message = CreateNetlinkMessage( + buf, netlink_message, NLMSG_DONE, seq); + ret += NLMSG_ALIGN(netlink_message->nlmsg_len); + + QUIC_LOG(INFO) << "ret: " << ret; + return ret; + }); + + std::vector routing_rules; + EXPECT_TRUE(netlink->GetRouteInfo(&routing_rules)); + + ASSERT_EQ(1, routing_rules.size()); + EXPECT_EQ(RT_SCOPE_LINK, routing_rules[0].scope); + EXPECT_EQ(IpRange(destination, 48).ToString(), + routing_rules[0].destination_subnet.ToString()); + EXPECT_FALSE(routing_rules[0].preferred_source.IsInitialized()); + EXPECT_EQ(7, routing_rules[0].out_interface); + EXPECT_EQ(0, routing_rules[0].init_cwnd); +} + +TEST_F(NetlinkTest, ChangeRouteAdd) { + auto netlink = std::make_unique(&mock_kernel_); + + QuicIpAddress preferred_ip; + preferred_ip.FromString("ff80:dead:beef::1"); + IpRange subnet; + subnet.FromString("ff80:dead:beef::/48"); + int egress_interface_index = 7; + uint32_t init_cwnd = 32; + ExpectNetlinkPacket( + RTM_NEWROUTE, NLM_F_ACK | NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL, + [](void* buf, size_t len, int seq) { + struct nlmsghdr* netlink_message = + CreateNetlinkMessage(buf, nullptr, NLMSG_ERROR, seq); + auto* err = + reinterpret_cast(NLMSG_DATA(netlink_message)); + // Ack the request + err->error = 0; + netlink_message->nlmsg_len = NLMSG_LENGTH(sizeof(struct nlmsgerr)); + return netlink_message->nlmsg_len; + }, + [preferred_ip, subnet, egress_interface_index, init_cwnd](const void* buf, + size_t len) { + auto* netlink_message = reinterpret_cast(buf); + auto* rtm = + reinterpret_cast(NLMSG_DATA(netlink_message)); + EXPECT_EQ(AF_INET6, rtm->rtm_family); + EXPECT_EQ(48, rtm->rtm_dst_len); + EXPECT_EQ(0, rtm->rtm_src_len); + EXPECT_EQ(RT_TABLE_MAIN, rtm->rtm_table); + EXPECT_EQ(RTPROT_STATIC, rtm->rtm_protocol); + EXPECT_EQ(RT_SCOPE_LINK, rtm->rtm_scope); + EXPECT_EQ(RTN_UNICAST, rtm->rtm_type); + + const struct rtattr* rta; + int payload_length = RTM_PAYLOAD(netlink_message); + int num_rta = 0; + for (rta = RTM_RTA(rtm); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case RTA_PREFSRC: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(preferred_ip, address); + break; + } + case RTA_GATEWAY: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(*QboneConstants::GatewayAddress(), address); + break; + } + case RTA_OIF: { + ASSERT_EQ(sizeof(int), RTA_PAYLOAD(rta)); + const auto* interface_index = + reinterpret_cast(RTA_DATA(rta)); + EXPECT_EQ(egress_interface_index, *interface_index); + break; + } + case RTA_DST: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(subnet.ToString(), + IpRange(address, rtm->rtm_dst_len).ToString()); + break; + } + case RTA_TABLE: { + ASSERT_EQ(*reinterpret_cast(RTA_DATA(rta)), + QboneConstants::kQboneRouteTableId); + break; + } + case RTA_METRICS: { + struct rtattr* rtax = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(rtax->rta_type, RTAX_INITCWND); + ASSERT_EQ(rtax->rta_len, RTA_LENGTH(sizeof(uint32_t))); + ASSERT_EQ(*reinterpret_cast(RTA_DATA(rtax)), + init_cwnd); + break; + } + default: + EXPECT_TRUE(false) << "Seeing rtattr that should not be sent"; + } + ++num_rta; + } + EXPECT_EQ(6, num_rta); + }); + EXPECT_TRUE(netlink->ChangeRoute( + Netlink::Verb::kAdd, QboneConstants::kQboneRouteTableId, subnet, + RT_SCOPE_LINK, preferred_ip, egress_interface_index, init_cwnd)); +} + +TEST_F(NetlinkTest, ChangeRouteRemove) { + auto netlink = std::make_unique(&mock_kernel_); + + QuicIpAddress preferred_ip; + preferred_ip.FromString("ff80:dead:beef::1"); + IpRange subnet; + subnet.FromString("ff80:dead:beef::/48"); + int egress_interface_index = 7; + ExpectNetlinkPacket( + RTM_DELROUTE, NLM_F_ACK | NLM_F_REQUEST, + [](void* buf, size_t len, int seq) { + struct nlmsghdr* netlink_message = + CreateNetlinkMessage(buf, nullptr, NLMSG_ERROR, seq); + auto* err = + reinterpret_cast(NLMSG_DATA(netlink_message)); + // Ack the request + err->error = 0; + netlink_message->nlmsg_len = NLMSG_LENGTH(sizeof(struct nlmsgerr)); + return netlink_message->nlmsg_len; + }, + [preferred_ip, subnet, egress_interface_index](const void* buf, + size_t len) { + auto* netlink_message = reinterpret_cast(buf); + auto* rtm = + reinterpret_cast(NLMSG_DATA(netlink_message)); + EXPECT_EQ(AF_INET6, rtm->rtm_family); + EXPECT_EQ(48, rtm->rtm_dst_len); + EXPECT_EQ(0, rtm->rtm_src_len); + EXPECT_EQ(RT_TABLE_MAIN, rtm->rtm_table); + EXPECT_EQ(RTPROT_UNSPEC, rtm->rtm_protocol); + EXPECT_EQ(RT_SCOPE_LINK, rtm->rtm_scope); + EXPECT_EQ(RTN_UNICAST, rtm->rtm_type); + + const struct rtattr* rta; + int payload_length = RTM_PAYLOAD(netlink_message); + int num_rta = 0; + for (rta = RTM_RTA(rtm); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case RTA_PREFSRC: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(preferred_ip, address); + break; + } + case RTA_OIF: { + ASSERT_EQ(sizeof(int), RTA_PAYLOAD(rta)); + const auto* interface_index = + reinterpret_cast(RTA_DATA(rta)); + EXPECT_EQ(egress_interface_index, *interface_index); + break; + } + case RTA_DST: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(subnet.ToString(), + IpRange(address, rtm->rtm_dst_len).ToString()); + break; + } + case RTA_TABLE: { + ASSERT_EQ(*reinterpret_cast(RTA_DATA(rta)), + QboneConstants::kQboneRouteTableId); + break; + } + default: + EXPECT_TRUE(false) << "Seeing rtattr that should not be sent"; + } + ++num_rta; + } + EXPECT_EQ(4, num_rta); + }); + EXPECT_TRUE(netlink->ChangeRoute( + Netlink::Verb::kRemove, QboneConstants::kQboneRouteTableId, subnet, + RT_SCOPE_LINK, preferred_ip, egress_interface_index, + Netlink::kUnspecifiedInitCwnd)); +} + +TEST_F(NetlinkTest, ChangeRouteReplace) { + auto netlink = std::make_unique(&mock_kernel_); + + QuicIpAddress preferred_ip; + preferred_ip.FromString("ff80:dead:beef::1"); + IpRange subnet; + subnet.FromString("ff80:dead:beef::/48"); + int egress_interface_index = 7; + ExpectNetlinkPacket( + RTM_NEWROUTE, NLM_F_ACK | NLM_F_REQUEST | NLM_F_CREATE | NLM_F_REPLACE, + [](void* buf, size_t len, int seq) { + struct nlmsghdr* netlink_message = + CreateNetlinkMessage(buf, nullptr, NLMSG_ERROR, seq); + auto* err = + reinterpret_cast(NLMSG_DATA(netlink_message)); + // Ack the request + err->error = 0; + netlink_message->nlmsg_len = NLMSG_LENGTH(sizeof(struct nlmsgerr)); + return netlink_message->nlmsg_len; + }, + [preferred_ip, subnet, egress_interface_index](const void* buf, + size_t len) { + auto* netlink_message = reinterpret_cast(buf); + auto* rtm = + reinterpret_cast(NLMSG_DATA(netlink_message)); + EXPECT_EQ(AF_INET6, rtm->rtm_family); + EXPECT_EQ(48, rtm->rtm_dst_len); + EXPECT_EQ(0, rtm->rtm_src_len); + EXPECT_EQ(RT_TABLE_MAIN, rtm->rtm_table); + EXPECT_EQ(RTPROT_STATIC, rtm->rtm_protocol); + EXPECT_EQ(RT_SCOPE_LINK, rtm->rtm_scope); + EXPECT_EQ(RTN_UNICAST, rtm->rtm_type); + + const struct rtattr* rta; + int payload_length = RTM_PAYLOAD(netlink_message); + int num_rta = 0; + for (rta = RTM_RTA(rtm); RTA_OK(rta, payload_length); + rta = RTA_NEXT(rta, payload_length)) { + switch (rta->rta_type) { + case RTA_PREFSRC: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(preferred_ip, address); + break; + } + case RTA_GATEWAY: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(*QboneConstants::GatewayAddress(), address); + break; + } + case RTA_OIF: { + ASSERT_EQ(sizeof(int), RTA_PAYLOAD(rta)); + const auto* interface_index = + reinterpret_cast(RTA_DATA(rta)); + EXPECT_EQ(egress_interface_index, *interface_index); + break; + } + case RTA_DST: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(subnet.ToString(), + IpRange(address, rtm->rtm_dst_len).ToString()); + break; + } + case RTA_TABLE: { + ASSERT_EQ(*reinterpret_cast(RTA_DATA(rta)), + QboneConstants::kQboneRouteTableId); + break; + } + default: + EXPECT_TRUE(false) << "Seeing rtattr that should not be sent"; + } + ++num_rta; + } + EXPECT_EQ(5, num_rta); + }); + EXPECT_TRUE(netlink->ChangeRoute( + Netlink::Verb::kReplace, QboneConstants::kQboneRouteTableId, subnet, + RT_SCOPE_LINK, preferred_ip, egress_interface_index, + Netlink::kUnspecifiedInitCwnd)); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/qbone/platform/rtnetlink_message.cc b/quiche/quic/qbone/platform/rtnetlink_message.cc new file mode 100644 index 000000000000..c85bf6f135ce --- /dev/null +++ b/quiche/quic/qbone/platform/rtnetlink_message.cc @@ -0,0 +1,162 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/rtnetlink_message.h" + +#include + +namespace quic { + +RtnetlinkMessage::RtnetlinkMessage(uint16_t type, uint16_t flags, uint32_t seq, + uint32_t pid, const void* payload_header, + size_t payload_header_length) { + auto* buf = new uint8_t[NLMSG_SPACE(payload_header_length)]; + memset(buf, 0, NLMSG_SPACE(payload_header_length)); + + auto* message_header = reinterpret_cast(buf); + message_header->nlmsg_len = NLMSG_LENGTH(payload_header_length); + message_header->nlmsg_type = type; + message_header->nlmsg_flags = flags; + message_header->nlmsg_seq = seq; + message_header->nlmsg_pid = pid; + + if (payload_header != nullptr) { + memcpy(NLMSG_DATA(message_header), payload_header, payload_header_length); + } + message_.push_back({buf, NLMSG_SPACE(payload_header_length)}); +} + +RtnetlinkMessage::~RtnetlinkMessage() { + for (const auto& iov : message_) { + delete[] reinterpret_cast(iov.iov_base); + } +} + +void RtnetlinkMessage::AppendAttribute(uint16_t type, const void* data, + uint16_t data_length) { + auto* buf = new uint8_t[RTA_SPACE(data_length)]; + memset(buf, 0, RTA_SPACE(data_length)); + + auto* rta = reinterpret_cast(buf); + static_assert(sizeof(uint16_t) == sizeof(rta->rta_len), + "struct rtattr uses unsigned short, it's no longer 16bits"); + static_assert(sizeof(uint16_t) == sizeof(rta->rta_type), + "struct rtattr uses unsigned short, it's no longer 16bits"); + + rta->rta_len = RTA_LENGTH(data_length); + rta->rta_type = type; + memcpy(RTA_DATA(rta), data, data_length); + + message_.push_back({buf, RTA_SPACE(data_length)}); + AdjustMessageLength(rta->rta_len); +} + +std::unique_ptr RtnetlinkMessage::BuildIoVec() const { + auto message = std::make_unique(message_.size()); + int idx = 0; + for (const auto& vec : message_) { + message[idx++] = vec; + } + return message; +} + +size_t RtnetlinkMessage::IoVecSize() const { return message_.size(); } + +void RtnetlinkMessage::AdjustMessageLength(size_t additional_data_length) { + MessageHeader()->nlmsg_len = + NLMSG_ALIGN(MessageHeader()->nlmsg_len) + additional_data_length; +} + +struct nlmsghdr* RtnetlinkMessage::MessageHeader() { + return reinterpret_cast(message_[0].iov_base); +} + +LinkMessage LinkMessage::New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct ifinfomsg* interface_info_header) { + uint16_t request_type; + switch (request_operation) { + case RtnetlinkMessage::Operation::NEW: + request_type = RTM_NEWLINK; + break; + case RtnetlinkMessage::Operation::DEL: + request_type = RTM_DELLINK; + break; + case RtnetlinkMessage::Operation::GET: + request_type = RTM_GETLINK; + break; + } + bool is_get = request_type == RTM_GETLINK; + + if (is_get) { + struct rtgenmsg g = {AF_UNSPEC}; + return LinkMessage(request_type, flags, seq, pid, &g, sizeof(g)); + } + return LinkMessage(request_type, flags, seq, pid, interface_info_header, + sizeof(struct ifinfomsg)); +} + +AddressMessage AddressMessage::New( + RtnetlinkMessage::Operation request_operation, uint16_t flags, uint32_t seq, + uint32_t pid, const struct ifaddrmsg* interface_address_header) { + uint16_t request_type; + switch (request_operation) { + case RtnetlinkMessage::Operation::NEW: + request_type = RTM_NEWADDR; + break; + case RtnetlinkMessage::Operation::DEL: + request_type = RTM_DELADDR; + break; + case RtnetlinkMessage::Operation::GET: + request_type = RTM_GETADDR; + break; + } + bool is_get = request_type == RTM_GETADDR; + + if (is_get) { + struct rtgenmsg g = {AF_UNSPEC}; + return AddressMessage(request_type, flags, seq, pid, &g, sizeof(g)); + } + return AddressMessage(request_type, flags, seq, pid, interface_address_header, + sizeof(struct ifaddrmsg)); +} + +RouteMessage RouteMessage::New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct rtmsg* route_message_header) { + uint16_t request_type; + switch (request_operation) { + case RtnetlinkMessage::Operation::NEW: + request_type = RTM_NEWROUTE; + break; + case RtnetlinkMessage::Operation::DEL: + request_type = RTM_DELROUTE; + break; + case RtnetlinkMessage::Operation::GET: + request_type = RTM_GETROUTE; + break; + } + return RouteMessage(request_type, flags, seq, pid, route_message_header, + sizeof(struct rtmsg)); +} + +RuleMessage RuleMessage::New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct rtmsg* rule_message_header) { + uint16_t request_type; + switch (request_operation) { + case RtnetlinkMessage::Operation::NEW: + request_type = RTM_NEWRULE; + break; + case RtnetlinkMessage::Operation::DEL: + request_type = RTM_DELRULE; + break; + case RtnetlinkMessage::Operation::GET: + request_type = RTM_GETRULE; + break; + } + return RuleMessage(request_type, flags, seq, pid, rule_message_header, + sizeof(rtmsg)); +} +} // namespace quic diff --git a/quiche/quic/qbone/platform/rtnetlink_message.h b/quiche/quic/qbone/platform/rtnetlink_message.h new file mode 100644 index 000000000000..a5870e8c1206 --- /dev/null +++ b/quiche/quic/qbone/platform/rtnetlink_message.h @@ -0,0 +1,112 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_RTNETLINK_MESSAGE_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_RTNETLINK_MESSAGE_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +// This base class is used to construct an array struct iovec that represents a +// rtnetlink message as defined in man 7 rtnet. Padding for message header +// alignment to conform NLMSG_* and RTA_* macros is added at the end of each +// iovec::iov_base. +class RtnetlinkMessage { + public: + virtual ~RtnetlinkMessage(); + + enum class Operation { + NEW, + DEL, + GET, + }; + + // Appends a struct rtattr to the message. nlmsg_len and rta_len is handled + // properly. + // Override this to perform check on type. + virtual void AppendAttribute(uint16_t type, const void* data, + uint16_t data_length); + + // Builds the array of iovec that can be fed into sendmsg directly. + std::unique_ptr BuildIoVec() const; + + // The size of the array of iovec if BuildIovec is called. + size_t IoVecSize() const; + + protected: + // Subclass should add their own message header immediately after the + // nlmsghdr. Make this private to force the creation of such header. + RtnetlinkMessage(uint16_t type, uint16_t flags, uint32_t seq, uint32_t pid, + const void* payload_header, size_t payload_header_length); + + // Adjusts nlmsg_len in the header assuming additional_data_length is appended + // at the end. + void AdjustMessageLength(size_t additional_data_length); + + private: + // Convenient function for accessing the nlmsghdr. + struct nlmsghdr* MessageHeader(); + + std::vector message_; +}; + +// Message for manipulating link level configuration as defined in man 7 +// rtnetlink. RTM_NEWLINK, RTM_DELLINK and RTM_GETLINK are supported. +class LinkMessage : public RtnetlinkMessage { + public: + static LinkMessage New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct ifinfomsg* interface_info_header); + + private: + using RtnetlinkMessage::RtnetlinkMessage; +}; + +// Message for manipulating address level configuration as defined in man 7 +// rtnetlink. RTM_NEWADDR, RTM_NEWADDR and RTM_GETADDR are supported. +class AddressMessage : public RtnetlinkMessage { + public: + static AddressMessage New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct ifaddrmsg* interface_address_header); + + private: + using RtnetlinkMessage::RtnetlinkMessage; +}; + +// Message for manipulating routing table as defined in man 7 rtnetlink. +// RTM_NEWROUTE, RTM_DELROUTE and RTM_GETROUTE are supported. +class RouteMessage : public RtnetlinkMessage { + public: + static RouteMessage New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct rtmsg* route_message_header); + + private: + using RtnetlinkMessage::RtnetlinkMessage; +}; + +class RuleMessage : public RtnetlinkMessage { + public: + static RuleMessage New(RtnetlinkMessage::Operation request_operation, + uint16_t flags, uint32_t seq, uint32_t pid, + const struct rtmsg* rule_message_header); + + private: + using RtnetlinkMessage::RtnetlinkMessage; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_RTNETLINK_MESSAGE_H_ diff --git a/quiche/quic/qbone/platform/rtnetlink_message_test.cc b/quiche/quic/qbone/platform/rtnetlink_message_test.cc new file mode 100644 index 000000000000..5757e88d33d4 --- /dev/null +++ b/quiche/quic/qbone/platform/rtnetlink_message_test.cc @@ -0,0 +1,229 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/rtnetlink_message.h" + +#include + +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace { + +using ::testing::StrEq; + +TEST(RtnetlinkMessageTest, LinkMessageCanBeCreatedForGetOperation) { + uint16_t flags = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH; + uint32_t seq = 42; + uint32_t pid = 7; + auto message = LinkMessage::New(RtnetlinkMessage::Operation::GET, flags, seq, + pid, nullptr); + + // No rtattr appended. + EXPECT_EQ(1, message.IoVecSize()); + + // nlmsghdr is built properly. + auto iov = message.BuildIoVec(); + EXPECT_EQ(NLMSG_SPACE(sizeof(struct rtgenmsg)), iov[0].iov_len); + auto* netlink_message = reinterpret_cast(iov[0].iov_base); + EXPECT_EQ(NLMSG_LENGTH(sizeof(struct rtgenmsg)), netlink_message->nlmsg_len); + EXPECT_EQ(RTM_GETLINK, netlink_message->nlmsg_type); + EXPECT_EQ(flags, netlink_message->nlmsg_flags); + EXPECT_EQ(seq, netlink_message->nlmsg_seq); + EXPECT_EQ(pid, netlink_message->nlmsg_pid); + + // We actually included rtgenmsg instead of the passed in ifinfomsg since this + // is a GET operation. + EXPECT_EQ(NLMSG_LENGTH(sizeof(struct rtgenmsg)), netlink_message->nlmsg_len); +} + +TEST(RtnetlinkMessageTest, LinkMessageCanBeCreatedForNewOperation) { + struct ifinfomsg interface_info_header = {AF_INET, /* pad */ 0, ARPHRD_TUNNEL, + 3, 0, 0xffffffff}; + uint16_t flags = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH; + uint32_t seq = 42; + uint32_t pid = 7; + auto message = LinkMessage::New(RtnetlinkMessage::Operation::NEW, flags, seq, + pid, &interface_info_header); + + std::string device_name = "device0"; + message.AppendAttribute(IFLA_IFNAME, device_name.c_str(), device_name.size()); + + // One rtattr appended. + EXPECT_EQ(2, message.IoVecSize()); + + // nlmsghdr is built properly. + auto iov = message.BuildIoVec(); + EXPECT_EQ(NLMSG_ALIGN(NLMSG_LENGTH(sizeof(struct ifinfomsg))), + iov[0].iov_len); + auto* netlink_message = reinterpret_cast(iov[0].iov_base); + EXPECT_EQ(NLMSG_ALIGN(NLMSG_LENGTH(sizeof(struct ifinfomsg))) + + RTA_LENGTH(device_name.size()), + netlink_message->nlmsg_len); + EXPECT_EQ(RTM_NEWLINK, netlink_message->nlmsg_type); + EXPECT_EQ(flags, netlink_message->nlmsg_flags); + EXPECT_EQ(seq, netlink_message->nlmsg_seq); + EXPECT_EQ(pid, netlink_message->nlmsg_pid); + + // ifinfomsg is included properly. + auto* parsed_header = + reinterpret_cast(NLMSG_DATA(netlink_message)); + EXPECT_EQ(interface_info_header.ifi_family, parsed_header->ifi_family); + EXPECT_EQ(interface_info_header.ifi_type, parsed_header->ifi_type); + EXPECT_EQ(interface_info_header.ifi_index, parsed_header->ifi_index); + EXPECT_EQ(interface_info_header.ifi_flags, parsed_header->ifi_flags); + EXPECT_EQ(interface_info_header.ifi_change, parsed_header->ifi_change); + + // rtattr is handled properly. + EXPECT_EQ(RTA_SPACE(device_name.size()), iov[1].iov_len); + auto* rta = reinterpret_cast(iov[1].iov_base); + EXPECT_EQ(IFLA_IFNAME, rta->rta_type); + EXPECT_EQ(RTA_LENGTH(device_name.size()), rta->rta_len); + EXPECT_THAT(device_name, + StrEq(std::string(reinterpret_cast(RTA_DATA(rta)), + RTA_PAYLOAD(rta)))); +} + +TEST(RtnetlinkMessageTest, AddressMessageCanBeCreatedForGetOperation) { + uint16_t flags = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH; + uint32_t seq = 42; + uint32_t pid = 7; + auto message = AddressMessage::New(RtnetlinkMessage::Operation::GET, flags, + seq, pid, nullptr); + + // No rtattr appended. + EXPECT_EQ(1, message.IoVecSize()); + + // nlmsghdr is built properly. + auto iov = message.BuildIoVec(); + EXPECT_EQ(NLMSG_SPACE(sizeof(struct rtgenmsg)), iov[0].iov_len); + auto* netlink_message = reinterpret_cast(iov[0].iov_base); + EXPECT_EQ(NLMSG_LENGTH(sizeof(struct rtgenmsg)), netlink_message->nlmsg_len); + EXPECT_EQ(RTM_GETADDR, netlink_message->nlmsg_type); + EXPECT_EQ(flags, netlink_message->nlmsg_flags); + EXPECT_EQ(seq, netlink_message->nlmsg_seq); + EXPECT_EQ(pid, netlink_message->nlmsg_pid); + + // We actually included rtgenmsg instead of the passed in ifinfomsg since this + // is a GET operation. + EXPECT_EQ(NLMSG_LENGTH(sizeof(struct rtgenmsg)), netlink_message->nlmsg_len); +} + +TEST(RtnetlinkMessageTest, AddressMessageCanBeCreatedForNewOperation) { + struct ifaddrmsg interface_address_header = {AF_INET, + /* prefixlen */ 24, + /* flags */ 0, + /* scope */ RT_SCOPE_LINK, + /* index */ 4}; + uint16_t flags = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH; + uint32_t seq = 42; + uint32_t pid = 7; + auto message = AddressMessage::New(RtnetlinkMessage::Operation::NEW, flags, + seq, pid, &interface_address_header); + + QuicIpAddress ip; + QUICHE_CHECK(ip.FromString("10.0.100.3")); + message.AppendAttribute(IFA_ADDRESS, ip.ToPackedString().c_str(), + ip.ToPackedString().size()); + + // One rtattr is appended. + EXPECT_EQ(2, message.IoVecSize()); + + // nlmsghdr is built properly. + auto iov = message.BuildIoVec(); + EXPECT_EQ(NLMSG_ALIGN(NLMSG_LENGTH(sizeof(struct ifaddrmsg))), + iov[0].iov_len); + auto* netlink_message = reinterpret_cast(iov[0].iov_base); + EXPECT_EQ(NLMSG_ALIGN(NLMSG_LENGTH(sizeof(struct ifaddrmsg))) + + RTA_LENGTH(ip.ToPackedString().size()), + netlink_message->nlmsg_len); + EXPECT_EQ(RTM_NEWADDR, netlink_message->nlmsg_type); + EXPECT_EQ(flags, netlink_message->nlmsg_flags); + EXPECT_EQ(seq, netlink_message->nlmsg_seq); + EXPECT_EQ(pid, netlink_message->nlmsg_pid); + + // ifaddrmsg is included properly. + auto* parsed_header = + reinterpret_cast(NLMSG_DATA(netlink_message)); + EXPECT_EQ(interface_address_header.ifa_family, parsed_header->ifa_family); + EXPECT_EQ(interface_address_header.ifa_prefixlen, + parsed_header->ifa_prefixlen); + EXPECT_EQ(interface_address_header.ifa_flags, parsed_header->ifa_flags); + EXPECT_EQ(interface_address_header.ifa_scope, parsed_header->ifa_scope); + EXPECT_EQ(interface_address_header.ifa_index, parsed_header->ifa_index); + + // rtattr is handled properly. + EXPECT_EQ(RTA_SPACE(ip.ToPackedString().size()), iov[1].iov_len); + auto* rta = reinterpret_cast(iov[1].iov_base); + EXPECT_EQ(IFA_ADDRESS, rta->rta_type); + EXPECT_EQ(RTA_LENGTH(ip.ToPackedString().size()), rta->rta_len); + EXPECT_THAT(ip.ToPackedString(), + StrEq(std::string(reinterpret_cast(RTA_DATA(rta)), + RTA_PAYLOAD(rta)))); +} + +TEST(RtnetlinkMessageTest, RouteMessageCanBeCreatedFromNewOperation) { + struct rtmsg route_message_header = {AF_INET6, + /* rtm_dst_len */ 48, + /* rtm_src_len */ 0, + /* rtm_tos */ 0, + /* rtm_table */ RT_TABLE_MAIN, + /* rtm_protocol */ RTPROT_STATIC, + /* rtm_scope */ RT_SCOPE_LINK, + /* rtm_type */ RTN_LOCAL, + /* rtm_flags */ 0}; + uint16_t flags = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH; + uint32_t seq = 42; + uint32_t pid = 7; + auto message = RouteMessage::New(RtnetlinkMessage::Operation::NEW, flags, seq, + pid, &route_message_header); + + QuicIpAddress preferred_source; + QUICHE_CHECK(preferred_source.FromString("ff80::1")); + message.AppendAttribute(RTA_PREFSRC, + preferred_source.ToPackedString().c_str(), + preferred_source.ToPackedString().size()); + + // One rtattr is appended. + EXPECT_EQ(2, message.IoVecSize()); + + // nlmsghdr is built properly + auto iov = message.BuildIoVec(); + EXPECT_EQ(NLMSG_ALIGN(NLMSG_LENGTH(sizeof(struct rtmsg))), iov[0].iov_len); + auto* netlink_message = reinterpret_cast(iov[0].iov_base); + EXPECT_EQ(NLMSG_ALIGN(NLMSG_LENGTH(sizeof(struct rtmsg))) + + RTA_LENGTH(preferred_source.ToPackedString().size()), + netlink_message->nlmsg_len); + EXPECT_EQ(RTM_NEWROUTE, netlink_message->nlmsg_type); + EXPECT_EQ(flags, netlink_message->nlmsg_flags); + EXPECT_EQ(seq, netlink_message->nlmsg_seq); + EXPECT_EQ(pid, netlink_message->nlmsg_pid); + + // rtmsg is included properly. + auto* parsed_header = + reinterpret_cast(NLMSG_DATA(netlink_message)); + EXPECT_EQ(route_message_header.rtm_family, parsed_header->rtm_family); + EXPECT_EQ(route_message_header.rtm_dst_len, parsed_header->rtm_dst_len); + EXPECT_EQ(route_message_header.rtm_src_len, parsed_header->rtm_src_len); + EXPECT_EQ(route_message_header.rtm_tos, parsed_header->rtm_tos); + EXPECT_EQ(route_message_header.rtm_table, parsed_header->rtm_table); + EXPECT_EQ(route_message_header.rtm_protocol, parsed_header->rtm_protocol); + EXPECT_EQ(route_message_header.rtm_scope, parsed_header->rtm_scope); + EXPECT_EQ(route_message_header.rtm_type, parsed_header->rtm_type); + EXPECT_EQ(route_message_header.rtm_flags, parsed_header->rtm_flags); + + // rtattr is handled properly. + EXPECT_EQ(RTA_SPACE(preferred_source.ToPackedString().size()), + iov[1].iov_len); + auto* rta = reinterpret_cast(iov[1].iov_base); + EXPECT_EQ(RTA_PREFSRC, rta->rta_type); + EXPECT_EQ(RTA_LENGTH(preferred_source.ToPackedString().size()), rta->rta_len); + EXPECT_THAT(preferred_source.ToPackedString(), + StrEq(std::string(reinterpret_cast(RTA_DATA(rta)), + RTA_PAYLOAD(rta)))); +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/qbone/platform/tcp_packet.cc b/quiche/quic/qbone/platform/tcp_packet.cc new file mode 100644 index 000000000000..e73e284a0105 --- /dev/null +++ b/quiche/quic/qbone/platform/tcp_packet.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/tcp_packet.h" + +#include + +#include "absl/base/optimization.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/qbone/platform/internet_checksum.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { +namespace { + +constexpr size_t kIPv6AddressSize = sizeof(in6_addr); +constexpr size_t kTcpTtl = 64; + +struct TCPv6Packet { + ip6_hdr ip_header; + tcphdr tcp_header; +}; + +struct TCPv6PseudoHeader { + uint32_t payload_size{}; + uint8_t zeros[3] = {0, 0, 0}; + uint8_t next_header = IPPROTO_TCP; +}; + +} // namespace + +void CreateTcpResetPacket(absl::string_view original_packet, + const std::function& cb) { + // By the time this method is called, original_packet should be fairly + // strongly validated. However, it's better to be more paranoid than not, so + // here are a bunch of very obvious checks. + if (ABSL_PREDICT_FALSE(original_packet.size() < sizeof(ip6_hdr))) { + return; + } + auto* ip6_header = reinterpret_cast(original_packet.data()); + if (ABSL_PREDICT_FALSE(ip6_header->ip6_vfc >> 4 != 6)) { + return; + } + if (ABSL_PREDICT_FALSE(ip6_header->ip6_nxt != IPPROTO_TCP)) { + return; + } + if (ABSL_PREDICT_FALSE(quiche::QuicheEndian::NetToHost16( + ip6_header->ip6_plen) < sizeof(tcphdr))) { + return; + } + auto* tcp_header = reinterpret_cast(ip6_header + 1); + + // Now that the original packet has been confirmed to be well-formed, it's + // time to make the TCP RST packet. + TCPv6Packet tcp_packet{}; + + const size_t payload_size = sizeof(tcphdr); + + // Set version to 6. + tcp_packet.ip_header.ip6_vfc = 0x6 << 4; + // Set the payload size, protocol and TTL. + tcp_packet.ip_header.ip6_plen = + quiche::QuicheEndian::HostToNet16(payload_size); + tcp_packet.ip_header.ip6_nxt = IPPROTO_TCP; + tcp_packet.ip_header.ip6_hops = kTcpTtl; + // Since the TCP RST is impersonating the endpoint, flip the source and + // destination addresses from the original packet. + tcp_packet.ip_header.ip6_src = ip6_header->ip6_dst; + tcp_packet.ip_header.ip6_dst = ip6_header->ip6_src; + + // The same is true about the TCP ports + tcp_packet.tcp_header.dest = tcp_header->source; + tcp_packet.tcp_header.source = tcp_header->dest; + + // There are no extensions in this header, so size is trivial + tcp_packet.tcp_header.doff = sizeof(tcphdr) >> 2; + // Checksum is 0 before it is computed + tcp_packet.tcp_header.check = 0; + + // Per RFC 793, TCP RST comes in one of 3 flavors: + // + // * connection CLOSED + // * connection in non-synchronized state (LISTEN, SYN-SENT, SYN-RECEIVED) + // * connection in synchronized state (ESTABLISHED, FIN-WAIT-1, etc.) + // + // QBONE is acting like a firewall, so the RFC text of interest is the CLOSED + // state. Note, however, that it is possible for a connection to actually be + // in the FIN-WAIT-1 state on the remote end, but the processing logic does + // not change. + tcp_packet.tcp_header.rst = 1; + + // If the incoming segment has an ACK field, the reset takes its sequence + // number from the ACK field of the segment, + if (tcp_header->ack) { + tcp_packet.tcp_header.seq = tcp_header->ack_seq; + } else { + // Otherwise the reset has sequence number zero and the ACK field is set to + // the sum of the sequence number and segment length of the incoming segment + tcp_packet.tcp_header.ack = 1; + tcp_packet.tcp_header.seq = 0; + tcp_packet.tcp_header.ack_seq = quiche::QuicheEndian::HostToNet32( + quiche::QuicheEndian::NetToHost32(tcp_header->seq) + 1); + } + + TCPv6PseudoHeader pseudo_header{}; + pseudo_header.payload_size = quiche::QuicheEndian::HostToNet32(payload_size); + + InternetChecksum checksum; + // Pseudoheader. + checksum.Update(tcp_packet.ip_header.ip6_src.s6_addr, kIPv6AddressSize); + checksum.Update(tcp_packet.ip_header.ip6_dst.s6_addr, kIPv6AddressSize); + checksum.Update(reinterpret_cast(&pseudo_header), + sizeof(pseudo_header)); + // TCP header. + checksum.Update(reinterpret_cast(&tcp_packet.tcp_header), + sizeof(tcp_packet.tcp_header)); + // There is no body. + tcp_packet.tcp_header.check = checksum.Value(); + + const char* packet = reinterpret_cast(&tcp_packet); + + cb(absl::string_view(packet, sizeof(tcp_packet))); +} + +} // namespace quic diff --git a/quiche/quic/qbone/platform/tcp_packet.h b/quiche/quic/qbone/platform/tcp_packet.h new file mode 100644 index 000000000000..7e9550fb8706 --- /dev/null +++ b/quiche/quic/qbone/platform/tcp_packet.h @@ -0,0 +1,25 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_PLATFORM_TCP_PACKET_H_ +#define QUICHE_QUIC_QBONE_PLATFORM_TCP_PACKET_H_ + +#include +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_ip_address.h" + +namespace quic { + +// Creates an TCPv6 RST packet, returning a packed string representation of the +// packet to |cb|. +void CreateTcpResetPacket(absl::string_view original_packet, + const std::function& cb); + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_PLATFORM_TCP_PACKET_H_ diff --git a/quiche/quic/qbone/platform/tcp_packet_test.cc b/quiche/quic/qbone/platform/tcp_packet_test.cc new file mode 100644 index 000000000000..c9ada0c77914 --- /dev/null +++ b/quiche/quic/qbone/platform/tcp_packet_test.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/platform/tcp_packet.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace { + +// clang-format off +constexpr uint8_t kReferenceTCPSYNPacket[] = { + // START IPv6 Header + // IPv6 with zero ToS and flow label + 0x60, 0x00, 0x00, 0x00, + // Payload is 40 bytes + 0x00, 0x28, + // Next header is TCP (6) + 0x06, + // Hop limit is 64 + 0x40, + // Source address of ::1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Destination address of ::1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // END IPv6 Header + // START TCPv6 Header + // Source port + 0xac, 0x1e, + // Destination port + 0x27, 0x0f, + // Sequence number + 0x4b, 0x01, 0xe8, 0x99, + // Acknowledgement Sequence number, + 0x00, 0x00, 0x00, 0x00, + // Offset + 0xa0, + // Flags + 0x02, + // Window + 0xaa, 0xaa, + // Checksum + 0x2e, 0x21, + // Urgent + 0x00, 0x00, + // END TCPv6 Header + // Options + 0x02, 0x04, 0xff, 0xc4, 0x04, 0x02, 0x08, 0x0a, + 0x1b, 0xb8, 0x52, 0xa1, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x03, 0x03, 0x07, +}; + +constexpr uint8_t kReferenceTCPRSTPacket[] = { + // START IPv6 Header + // IPv6 with zero ToS and flow label + 0x60, 0x00, 0x00, 0x00, + // Payload is 20 bytes + 0x00, 0x14, + // Next header is TCP (6) + 0x06, + // Hop limit is 64 + 0x40, + // Source address of ::1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Destination address of ::1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // END IPv6 Header + // START TCPv6 Header + // Source port + 0x27, 0x0f, + // Destination port + 0xac, 0x1e, + // Sequence number + 0x00, 0x00, 0x00, 0x00, + // Acknowledgement Sequence number, + 0x4b, 0x01, 0xe8, 0x9a, + // Offset + 0x50, + // Flags + 0x14, + // Window + 0x00, 0x00, + // Checksum + 0xa9, 0x05, + // Urgent + 0x00, 0x00, + // END TCPv6 Header +}; +// clang-format on + +} // namespace + +TEST(TcpPacketTest, CreatedPacketMatchesReference) { + absl::string_view syn = + absl::string_view(reinterpret_cast(kReferenceTCPSYNPacket), + sizeof(kReferenceTCPSYNPacket)); + absl::string_view expected_packet = + absl::string_view(reinterpret_cast(kReferenceTCPRSTPacket), + sizeof(kReferenceTCPRSTPacket)); + CreateTcpResetPacket(syn, [&expected_packet](absl::string_view packet) { + QUIC_LOG(INFO) << quiche::QuicheTextUtils::HexDump(packet); + ASSERT_EQ(packet, expected_packet); + }); +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_client.cc b/quiche/quic/qbone/qbone_client.cc new file mode 100644 index 000000000000..ffcfcf0faf37 --- /dev/null +++ b/quiche/quic/qbone/qbone_client.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_client.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/platform/api/quic_testvalue.h" +#include "quiche/quic/tools/quic_client_default_network_helper.h" + +namespace quic { +namespace { +std::unique_ptr CreateNetworkHelper( + QuicEventLoop* event_loop, QboneClient* client) { + std::unique_ptr helper = + std::make_unique(event_loop, client); + quic::AdjustTestValue("QboneClient/network_helper", &helper); + return helper; +} +} // namespace + +QboneClient::QboneClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicSession::Visitor* session_owner, + const QuicConfig& config, QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + QbonePacketWriter* qbone_writer, + QboneClientControlStream::Handler* qbone_handler) + : QuicClientBase(server_id, supported_versions, config, + new QuicDefaultConnectionHelper(), + event_loop->CreateAlarmFactory().release(), + CreateNetworkHelper(event_loop, this), + std::move(proof_verifier), nullptr), + qbone_writer_(qbone_writer), + qbone_handler_(qbone_handler), + session_owner_(session_owner) { + set_server_address(server_address); + crypto_config()->set_alpn("qbone"); +} + +QboneClient::~QboneClient() { ResetSession(); } + +QboneClientSession* QboneClient::qbone_session() { + return static_cast(QuicClientBase::session()); +} + +void QboneClient::ProcessPacketFromNetwork(absl::string_view packet) { + qbone_session()->ProcessPacketFromNetwork(packet); +} + +bool QboneClient::EarlyDataAccepted() { + return qbone_session()->EarlyDataAccepted(); +} + +bool QboneClient::ReceivedInchoateReject() { + return qbone_session()->ReceivedInchoateReject(); +} + +int QboneClient::GetNumSentClientHellosFromSession() { + return qbone_session()->GetNumSentClientHellos(); +} + +int QboneClient::GetNumReceivedServerConfigUpdatesFromSession() { + return qbone_session()->GetNumReceivedServerConfigUpdates(); +} + +void QboneClient::ResendSavedData() { + // no op. +} + +void QboneClient::ClearDataToResend() { + // no op. +} + +bool QboneClient::HasActiveRequests() { + return qbone_session()->HasActiveRequests(); +} + +class QboneClientSessionWithConnection : public QboneClientSession { + public: + using QboneClientSession::QboneClientSession; + + ~QboneClientSessionWithConnection() override { DeleteConnection(); } +}; + +// Takes ownership of |connection|. +std::unique_ptr QboneClient::CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) { + return std::make_unique( + connection, crypto_config(), session_owner(), *config(), + supported_versions, server_id(), qbone_writer_, qbone_handler_); +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_client.h b/quiche/quic/qbone/qbone_client.h new file mode 100644 index 000000000000..cda71da33261 --- /dev/null +++ b/quiche/quic/qbone/qbone_client.h @@ -0,0 +1,74 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_CLIENT_H_ +#define QUICHE_QUIC_QBONE_QBONE_CLIENT_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/qbone/qbone_client_interface.h" +#include "quiche/quic/qbone/qbone_client_session.h" +#include "quiche/quic/qbone/qbone_packet_writer.h" +#include "quiche/quic/tools/quic_client_base.h" + +namespace quic { +// A QboneClient encapsulates connecting to a server via an event loop +// and setting up a QBONE tunnel. See the QboneTestClient in qbone_client_test +// for usage. +class QboneClient : public QuicClientBase, public QboneClientInterface { + public: + // Note that the event loop, QBONE writer, and handler are owned + // by the caller. + QboneClient(QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicSession::Visitor* session_owner, const QuicConfig& config, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + QbonePacketWriter* qbone_writer, + QboneClientControlStream::Handler* qbone_handler); + ~QboneClient() override; + QboneClientSession* qbone_session(); + + // From QboneClientInterface. Accepts a given packet from the network and + // sends the packet down to the QBONE connection. + void ProcessPacketFromNetwork(absl::string_view packet) override; + + bool EarlyDataAccepted() override; + bool ReceivedInchoateReject() override; + + protected: + int GetNumSentClientHellosFromSession() override; + int GetNumReceivedServerConfigUpdatesFromSession() override; + + // This client does not resend saved data. This will be a no-op. + void ResendSavedData() override; + + // This client does not resend saved data. This will be a no-op. + void ClearDataToResend() override; + + // Takes ownership of |connection|. + std::unique_ptr CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) override; + + QbonePacketWriter* qbone_writer() { return qbone_writer_; } + + QboneClientControlStream::Handler* qbone_control_handler() { + return qbone_handler_; + } + + QuicSession::Visitor* session_owner() { return session_owner_; } + + bool HasActiveRequests() override; + + private: + QbonePacketWriter* qbone_writer_; + QboneClientControlStream::Handler* qbone_handler_; + + QuicSession::Visitor* session_owner_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_CLIENT_H_ diff --git a/quiche/quic/qbone/qbone_client_interface.h b/quiche/quic/qbone/qbone_client_interface.h new file mode 100644 index 000000000000..8b31cceb65da --- /dev/null +++ b/quiche/quic/qbone/qbone_client_interface.h @@ -0,0 +1,25 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_CLIENT_INTERFACE_H_ +#define QUICHE_QUIC_QBONE_QBONE_CLIENT_INTERFACE_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace quic { + +// An interface that includes methods to interact with a QBONE client. +class QboneClientInterface { + public: + virtual ~QboneClientInterface() {} + // Accepts a given packet from the network and sends the packet down to the + // QBONE connection. + virtual void ProcessPacketFromNetwork(absl::string_view packet) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_CLIENT_INTERFACE_H_ diff --git a/quiche/quic/qbone/qbone_client_session.cc b/quiche/quic/qbone/qbone_client_session.cc new file mode 100644 index 000000000000..cfc6001abd41 --- /dev/null +++ b/quiche/quic/qbone/qbone_client_session.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_client_session.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/qbone/qbone_constants.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, qbone_client_defer_control_stream_creation, true, + "If true, control stream in QBONE client session is created after " + "encryption established."); + +namespace quic { + +QboneClientSession::QboneClientSession( + QuicConnection* connection, + QuicCryptoClientConfig* quic_crypto_client_config, + QuicSession::Visitor* owner, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicServerId& server_id, QbonePacketWriter* writer, + QboneClientControlStream::Handler* handler) + : QboneSessionBase(connection, owner, config, supported_versions, writer), + server_id_(server_id), + quic_crypto_client_config_(quic_crypto_client_config), + handler_(handler) {} + +QboneClientSession::~QboneClientSession() {} + +std::unique_ptr QboneClientSession::CreateCryptoStream() { + return std::make_unique( + server_id_, this, nullptr, quic_crypto_client_config_, this, + /*has_application_state = */ true); +} + +void QboneClientSession::CreateControlStream() { + if (control_stream_ != nullptr) { + return; + } + // Register the reserved control stream. + QuicStreamId next_id = GetNextOutgoingBidirectionalStreamId(); + QUICHE_DCHECK_EQ(next_id, + QboneConstants::GetControlStreamId(transport_version())); + auto control_stream = + std::make_unique(this, handler_); + control_stream_ = control_stream.get(); + ActivateStream(std::move(control_stream)); +} + +void QboneClientSession::Initialize() { + // Initialize must be called first, as that's what generates the crypto + // stream. + QboneSessionBase::Initialize(); + static_cast(GetMutableCryptoStream()) + ->CryptoConnect(); + if (!quiche::GetQuicheCommandLineFlag( + FLAGS_qbone_client_defer_control_stream_creation)) { + CreateControlStream(); + } +} + +void QboneClientSession::SetDefaultEncryptionLevel( + quic::EncryptionLevel level) { + QboneSessionBase::SetDefaultEncryptionLevel(level); + if (quiche::GetQuicheCommandLineFlag( + FLAGS_qbone_client_defer_control_stream_creation) && + level == quic::ENCRYPTION_FORWARD_SECURE) { + CreateControlStream(); + } +} + +int QboneClientSession::GetNumSentClientHellos() const { + return static_cast(GetCryptoStream()) + ->num_sent_client_hellos(); +} + +bool QboneClientSession::EarlyDataAccepted() const { + return static_cast(GetCryptoStream()) + ->EarlyDataAccepted(); +} + +bool QboneClientSession::ReceivedInchoateReject() const { + return static_cast(GetCryptoStream()) + ->ReceivedInchoateReject(); +} + +int QboneClientSession::GetNumReceivedServerConfigUpdates() const { + return static_cast(GetCryptoStream()) + ->num_scup_messages_received(); +} + +bool QboneClientSession::SendServerRequest(const QboneServerRequest& request) { + if (!control_stream_) { + QUIC_BUG(quic_bug_11056_1) + << "Cannot send server request before control stream is created."; + return false; + } + return control_stream_->SendRequest(request); +} + +void QboneClientSession::ProcessPacketFromNetwork(absl::string_view packet) { + SendPacketToPeer(packet); +} + +void QboneClientSession::ProcessPacketFromPeer(absl::string_view packet) { + writer_->WritePacketToNetwork(packet.data(), packet.size()); +} + +void QboneClientSession::OnProofValid( + const QuicCryptoClientConfig::CachedState& cached) {} + +void QboneClientSession::OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) {} + +bool QboneClientSession::HasActiveRequests() const { + return GetNumActiveStreams() + num_draining_streams() > 0; +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_client_session.h b/quiche/quic/qbone/qbone_client_session.h new file mode 100644 index 000000000000..b0bd5cf5048f --- /dev/null +++ b/quiche/quic/qbone/qbone_client_session.h @@ -0,0 +1,93 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_CLIENT_SESSION_H_ +#define QUICHE_QUIC_QBONE_QBONE_CLIENT_SESSION_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/qbone/qbone_control.pb.h" +#include "quiche/quic/qbone/qbone_control_stream.h" +#include "quiche/quic/qbone/qbone_packet_writer.h" +#include "quiche/quic/qbone/qbone_session_base.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QboneClientSession + : public QboneSessionBase, + public QuicCryptoClientStream::ProofHandler { + public: + QboneClientSession(QuicConnection* connection, + QuicCryptoClientConfig* quic_crypto_client_config, + QuicSession::Visitor* owner, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicServerId& server_id, QbonePacketWriter* writer, + QboneClientControlStream::Handler* handler); + QboneClientSession(const QboneClientSession&) = delete; + QboneClientSession& operator=(const QboneClientSession&) = delete; + ~QboneClientSession() override; + + // QuicSession overrides. This will initiate the crypto stream. + void Initialize() override; + // Override to create control stream at FORWARD_SECURE encryption level. + void SetDefaultEncryptionLevel(quic::EncryptionLevel level) override; + + // Returns the number of client hello messages that have been sent on the + // crypto stream. If the handshake has completed then this is one greater + // than the number of round-trips needed for the handshake. + int GetNumSentClientHellos() const; + + // Returns true if early data (0-RTT data) was sent and the server accepted + // it. + bool EarlyDataAccepted() const; + + // Returns true if the handshake was delayed one round trip by the server + // because the server wanted proof the client controls its source address + // before progressing further. In Google QUIC, this would be due to an + // inchoate REJ in the QUIC Crypto handshake; in IETF QUIC this would be due + // to a Retry packet. + // TODO(nharper): Consider a better name for this method. + bool ReceivedInchoateReject() const; + + int GetNumReceivedServerConfigUpdates() const; + + bool SendServerRequest(const QboneServerRequest& request); + + void ProcessPacketFromNetwork(absl::string_view packet) override; + void ProcessPacketFromPeer(absl::string_view packet) override; + + // Returns true if there are active requests on this session. + bool HasActiveRequests() const; + + protected: + // QboneSessionBase interface implementation. + std::unique_ptr CreateCryptoStream() override; + + // Instantiate QboneClientControlStream. + void CreateControlStream(); + + // ProofHandler interface implementation. + void OnProofValid(const QuicCryptoClientConfig::CachedState& cached) override; + void OnProofVerifyDetailsAvailable( + const ProofVerifyDetails& verify_details) override; + + QuicServerId server_id() { return server_id_; } + QuicCryptoClientConfig* crypto_client_config() { + return quic_crypto_client_config_; + } + + private: + QuicServerId server_id_; + // Config for QUIC crypto client stream, used by the client. + QuicCryptoClientConfig* quic_crypto_client_config_; + // Passed to the control stream. + QboneClientControlStream::Handler* handler_; + // The unowned control stream. + QboneClientControlStream* control_stream_ = nullptr; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_CLIENT_SESSION_H_ diff --git a/quiche/quic/qbone/qbone_client_test.cc b/quiche/quic/qbone/qbone_client_test.cc new file mode 100644 index 000000000000..e893ddbf5b63 --- /dev/null +++ b/quiche/quic/qbone/qbone_client_test.cc @@ -0,0 +1,264 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Sets up a dispatcher and sends requests via the QboneClient. + +#include "quiche/quic/qbone/qbone_client.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/qbone/qbone_packet_processor_test_tools.h" +#include "quiche/quic/qbone/qbone_server_session.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_dispatcher_peer.h" +#include "quiche/quic/test_tools/quic_server_peer.h" +#include "quiche/quic/test_tools/server_thread.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/quic/tools/quic_server.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::ElementsAre; + +ParsedQuicVersionVector GetTestParams() { + ParsedQuicVersionVector test_versions; + SetQuicReloadableFlag(quic_disable_version_q046, false); + // TODO(b/113130636): Make QBONE work with TLS. + for (const auto& version : CurrentSupportedVersionsWithQuicCrypto()) { + // QBONE requires MESSAGE frames + if (!version.SupportsMessageFrames()) { + continue; + } + test_versions.push_back(version); + } + + return test_versions; +} + +std::string TestPacketIn(const std::string& body) { + return PrependIPv6HeaderForTest(body, 5); +} + +std::string TestPacketOut(const std::string& body) { + return PrependIPv6HeaderForTest(body, 4); +} + +class DataSavingQbonePacketWriter : public QbonePacketWriter { + public: + void WritePacketToNetwork(const char* packet, size_t size) override { + QuicWriterMutexLock lock(&mu_); + data_.push_back(std::string(packet, size)); + } + + std::vector data() { + QuicWriterMutexLock lock(&mu_); + return data_; + } + + private: + QuicMutex mu_; + std::vector data_; +}; + +// A subclass of a QBONE session that will own the connection passed in. +class ConnectionOwningQboneServerSession : public QboneServerSession { + public: + ConnectionOwningQboneServerSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, Visitor* owner, const QuicConfig& config, + const QuicCryptoServerConfig* quic_crypto_server_config, + QuicCompressedCertsCache* compressed_certs_cache, + QbonePacketWriter* writer) + : QboneServerSession(supported_versions, connection, owner, config, + quic_crypto_server_config, compressed_certs_cache, + writer, TestLoopback6(), TestLoopback6(), 64, + nullptr), + connection_(connection) {} + + private: + // Note that we don't expect the QboneServerSession or any of its parent + // classes to do anything with the connection_ in their destructors. + std::unique_ptr connection_; +}; + +class QuicQboneDispatcher : public QuicDispatcher { + public: + QuicQboneDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QbonePacketWriter* writer, ConnectionIdGeneratorInterface& generator) + : QuicDispatcher(config, crypto_config, version_manager, + std::move(helper), std::move(session_helper), + std::move(alarm_factory), kQuicDefaultConnectionIdLength, + generator), + writer_(writer) {} + + std::unique_ptr CreateQuicSession( + QuicConnectionId id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const ParsedClientHello& /*parsed_chlo*/) override { + QUICHE_CHECK_EQ(alpn, "qbone"); + QuicConnection* connection = new QuicConnection( + id, self_address, peer_address, helper(), alarm_factory(), writer(), + /* owns_writer= */ false, Perspective::IS_SERVER, + ParsedQuicVersionVector{version}, connection_id_generator()); + // The connection owning wrapper owns the connection created. + auto session = std::make_unique( + GetSupportedVersions(), connection, this, config(), crypto_config(), + compressed_certs_cache(), writer_); + session->Initialize(); + return session; + } + + private: + QbonePacketWriter* writer_; +}; + +class QboneTestServer : public QuicServer { + public: + explicit QboneTestServer(std::unique_ptr proof_source, + quic::QuicMemoryCacheBackend* response_cache) + : QuicServer(std::move(proof_source), response_cache) {} + QuicDispatcher* CreateQuicDispatcher() override { + return new QuicQboneDispatcher( + &config(), &crypto_config(), version_manager(), + std::make_unique(), + std::make_unique(), + event_loop()->CreateAlarmFactory(), &writer_, + connection_id_generator()); + } + + std::vector data() { return writer_.data(); } + + private: + DataSavingQbonePacketWriter writer_; +}; + +class QboneTestClient : public QboneClient { + public: + QboneTestClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier) + : QboneClient(server_address, server_id, supported_versions, + /*session_owner=*/nullptr, QuicConfig(), event_loop, + std::move(proof_verifier), &qbone_writer_, nullptr) {} + + ~QboneTestClient() override {} + + void SendData(const std::string& data) { + qbone_session()->ProcessPacketFromNetwork(data); + } + + void WaitForWriteToFlush() { + while (connected() && session()->HasDataToWrite()) { + WaitForEvents(); + } + } + + // Returns true when the data size is reached or false on timeouts. + bool WaitForDataSize(int n, QuicTime::Delta timeout) { + const QuicClock* clock = + quic::test::QuicConnectionPeer::GetHelper(session()->connection()) + ->GetClock(); + const QuicTime deadline = clock->Now() + timeout; + while (data().size() < n) { + if (clock->Now() > deadline) { + return false; + } + WaitForEvents(); + } + return true; + } + + std::vector data() { return qbone_writer_.data(); } + + private: + DataSavingQbonePacketWriter qbone_writer_; +}; + +class QboneClientTest : public QuicTestWithParam {}; + +INSTANTIATE_TEST_SUITE_P(Tests, QboneClientTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QboneClientTest, SendDataFromClient) { + quic::QuicMemoryCacheBackend server_backend; + auto server = std::make_unique( + crypto_test_utils::ProofSourceForTesting(), &server_backend); + QboneTestServer* server_ptr = server.get(); + QuicSocketAddress server_address(TestLoopback(), 0); + ServerThread server_thread(std::move(server), server_address); + server_thread.Initialize(); + server_address = + QuicSocketAddress(server_address.host(), server_thread.GetPort()); + server_thread.Start(); + + std::unique_ptr event_loop = + GetDefaultEventLoop()->Create(quic::QuicDefaultClock::Get()); + QboneTestClient client( + server_address, + QuicServerId("test.example.com", server_address.port(), false), + ParsedQuicVersionVector{GetParam()}, event_loop.get(), + crypto_test_utils::ProofVerifierForTesting()); + ASSERT_TRUE(client.Initialize()); + ASSERT_TRUE(client.Connect()); + ASSERT_TRUE(client.WaitForOneRttKeysAvailable()); + client.SendData(TestPacketIn("hello")); + client.SendData(TestPacketIn("world")); + client.WaitForWriteToFlush(); + + // Wait until the server has received at least two packets, timeout after 5s. + ASSERT_TRUE( + server_thread.WaitUntil([&] { return server_ptr->data().size() >= 2; }, + QuicTime::Delta::FromSeconds(5))); + + // Pretend the server gets data. + std::string long_data(1000, 'A'); + server_thread.Schedule([server_ptr, &long_data]() { + EXPECT_THAT(server_ptr->data(), + ElementsAre(TestPacketOut("hello"), TestPacketOut("world"))); + auto server_session = static_cast( + QuicDispatcherPeer::GetFirstSessionIfAny( + QuicServerPeer::GetDispatcher(server_ptr))); + server_session->ProcessPacketFromNetwork( + TestPacketIn("Somethingsomething")); + server_session->ProcessPacketFromNetwork(TestPacketIn(long_data)); + server_session->ProcessPacketFromNetwork(TestPacketIn(long_data)); + }); + + EXPECT_TRUE(client.WaitForDataSize(3, QuicTime::Delta::FromSeconds(5))); + EXPECT_THAT(client.data(), + ElementsAre(TestPacketOut("Somethingsomething"), + TestPacketOut(long_data), TestPacketOut(long_data))); + + client.Disconnect(); + server_thread.Quit(); + server_thread.Join(); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/qbone/qbone_constants.cc b/quiche/quic/qbone/qbone_constants.cc new file mode 100644 index 000000000000..79a6b11530dc --- /dev/null +++ b/quiche/quic/qbone/qbone_constants.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_constants.h" + +#include "quiche/quic/core/quic_utils.h" + +namespace quic { + +constexpr char QboneConstants::kQboneAlpn[]; +const QuicByteCount QboneConstants::kMaxQbonePacketBytes; +const uint32_t QboneConstants::kQboneRouteTableId; + +QuicStreamId QboneConstants::GetControlStreamId(QuicTransportVersion version) { + return QuicUtils::GetFirstBidirectionalStreamId(version, + Perspective::IS_CLIENT); +} + +const QuicIpAddress* QboneConstants::TerminatorLocalAddress() { + static auto* terminator_address = []() { + auto* address = new QuicIpAddress; + // 0x71 0x62 0x6f 0x6e 0x65 is 'qbone' in ascii. + address->FromString("fe80::71:626f:6e65"); + return address; + }(); + return terminator_address; +} + +const IpRange* QboneConstants::TerminatorLocalAddressRange() { + static auto* range = + new quic::IpRange(*quic::QboneConstants::TerminatorLocalAddress(), 128); + return range; +} + +const QuicIpAddress* QboneConstants::GatewayAddress() { + static auto* gateway_address = []() { + auto* address = new QuicIpAddress; + address->FromString("fe80::1"); + return address; + }(); + return gateway_address; +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_constants.h b/quiche/quic/qbone/qbone_constants.h new file mode 100644 index 000000000000..3ac75d4c3cb2 --- /dev/null +++ b/quiche/quic/qbone/qbone_constants.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_CONSTANTS_H_ +#define QUICHE_QUIC_QBONE_QBONE_CONSTANTS_H_ + +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/qbone/platform/ip_range.h" + +namespace quic { + +struct QboneConstants { + // QBONE's ALPN + static constexpr char kQboneAlpn[] = "qbone"; + // The maximum number of bytes allowed in a QBONE packet. + static const QuicByteCount kMaxQbonePacketBytes = 2000; + // The table id for QBONE's routing table. 'bone' in ascii. + static const uint32_t kQboneRouteTableId = 0x626F6E65; + // The stream ID of the control channel. + static QuicStreamId GetControlStreamId(QuicTransportVersion version); + // The link-local address of the Terminator + static const QuicIpAddress* TerminatorLocalAddress(); + // The IPRange containing the TerminatorLocalAddress + static const IpRange* TerminatorLocalAddressRange(); + // The gateway address to provide when configuring routes to the QBONE + // interface + static const QuicIpAddress* GatewayAddress(); +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_CONSTANTS_H_ diff --git a/quiche/quic/qbone/qbone_control.proto b/quiche/quic/qbone/qbone_control.proto new file mode 100644 index 000000000000..f0090d6cf8a8 --- /dev/null +++ b/quiche/quic/qbone/qbone_control.proto @@ -0,0 +1,13 @@ +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package quic; + +message QboneServerRequest { + extensions 1000 to max; +}; + +message QboneClientRequest { + extensions 1000 to max; +}; diff --git a/quiche/quic/qbone/qbone_control_placeholder.proto b/quiche/quic/qbone/qbone_control_placeholder.proto new file mode 100644 index 000000000000..af993406d8e8 --- /dev/null +++ b/quiche/quic/qbone/qbone_control_placeholder.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package quic; + +import "quiche/quic/qbone/qbone_control.proto"; + +// These provide fields for QboneServerRequest and QboneClientRequest that are +// used to test the control channel. Once the control channel actually has real +// data to pass they can be removed. +// TODO(b/62139999): Remove this file in favor of testing actual configuration. + +extend QboneServerRequest { + optional string server_placeholder = 179838467; +} + +extend QboneClientRequest { + optional string client_placeholder = 179838467; +} diff --git a/quiche/quic/qbone/qbone_control_stream.cc b/quiche/quic/qbone/qbone_control_stream.cc new file mode 100644 index 000000000000..5ded99555eba --- /dev/null +++ b/quiche/quic/qbone/qbone_control_stream.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_control_stream.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/qbone/qbone_constants.h" + +namespace quic { + +namespace { +static constexpr size_t kRequestSizeBytes = sizeof(uint16_t); +} // namespace + +QboneControlStreamBase::QboneControlStreamBase(QuicSession* session) + : QuicStream( + QboneConstants::GetControlStreamId(session->transport_version()), + session, + /*is_static=*/true, BIDIRECTIONAL), + pending_message_size_(0) {} + +QboneControlStreamBase::QboneControlStreamBase(quic::PendingStream* pending, + QuicSession* session) + : QuicStream(pending, session, /*is_static=*/true), + pending_message_size_(0) { + QUICHE_DCHECK_EQ(pending->id(), QboneConstants::GetControlStreamId( + session->transport_version())); +} + +void QboneControlStreamBase::OnDataAvailable() { + sequencer()->Read(&buffer_); + while (true) { + if (pending_message_size_ == 0) { + // Start of a message. + if (buffer_.size() < kRequestSizeBytes) { + return; + } + memcpy(&pending_message_size_, buffer_.data(), kRequestSizeBytes); + buffer_.erase(0, kRequestSizeBytes); + } + // Continuation of a message. + if (buffer_.size() < pending_message_size_) { + return; + } + std::string tmp = buffer_.substr(0, pending_message_size_); + buffer_.erase(0, pending_message_size_); + pending_message_size_ = 0; + OnMessage(tmp); + } +} + +bool QboneControlStreamBase::SendMessage(const proto2::Message& proto) { + std::string tmp; + if (!proto.SerializeToString(&tmp)) { + QUIC_BUG(quic_bug_11023_1) << "Failed to serialize QboneControlRequest"; + return false; + } + if (tmp.size() > std::numeric_limits::max()) { + QUIC_BUG(quic_bug_11023_2) + << "QboneControlRequest too large: " << tmp.size() << " > " + << std::numeric_limits::max(); + return false; + } + uint16_t size = tmp.size(); + char size_str[kRequestSizeBytes]; + memcpy(size_str, &size, kRequestSizeBytes); + WriteOrBufferData(absl::string_view(size_str, kRequestSizeBytes), false, + nullptr); + WriteOrBufferData(tmp, false, nullptr); + return true; +} + +void QboneControlStreamBase::OnStreamReset( + const QuicRstStreamFrame& /*frame*/) { + stream_delegate()->OnStreamError(QUIC_INVALID_STREAM_ID, + "Attempt to reset control stream"); +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_control_stream.h b/quiche/quic/qbone/qbone_control_stream.h new file mode 100644 index 000000000000..379d735b49ac --- /dev/null +++ b/quiche/quic/qbone/qbone_control_stream.h @@ -0,0 +1,82 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_CONTROL_STREAM_H_ +#define QUICHE_QUIC_QBONE_QBONE_CONTROL_STREAM_H_ + +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/qbone/qbone_control.pb.h" + +namespace quic { + +class QboneSessionBase; + +class QUIC_EXPORT_PRIVATE QboneControlStreamBase : public QuicStream { + public: + explicit QboneControlStreamBase(QuicSession* session); + QboneControlStreamBase(quic::PendingStream* pending, QuicSession* session); + + void OnDataAvailable() override; + + void OnStreamReset(const QuicRstStreamFrame& frame) override; + + protected: + virtual void OnMessage(const std::string& data) = 0; + bool SendMessage(const proto2::Message& proto); + + private: + uint16_t pending_message_size_; + std::string buffer_; +}; + +template +class QUIC_EXPORT_PRIVATE QboneControlHandler { + public: + virtual ~QboneControlHandler() {} + + virtual void OnControlRequest(const T& request) = 0; + virtual void OnControlError() = 0; +}; + +template +class QUIC_EXPORT_PRIVATE QboneControlStream : public QboneControlStreamBase { + public: + using Handler = QboneControlHandler; + + QboneControlStream(QuicSession* session, Handler* handler) + : QboneControlStreamBase(session), handler_(handler) {} + QboneControlStream(quic::PendingStream* pending, QuicSession* session, + Handler* handler) + : QboneControlStreamBase(pending, session), handler_(handler) {} + + bool SendRequest(const Outgoing& request) { return SendMessage(request); } + + protected: + void OnMessage(const std::string& data) override { + Incoming request; + if (!request.ParseFromString(data)) { + QUIC_LOG(ERROR) << "Failed to parse incoming request"; + if (handler_ != nullptr) { + handler_->OnControlError(); + } + return; + } + if (handler_ != nullptr) { + handler_->OnControlRequest(request); + } + } + + private: + Handler* handler_; +}; + +using QboneServerControlStream = + QboneControlStream; +using QboneClientControlStream = + QboneControlStream; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_CONTROL_STREAM_H_ diff --git a/quiche/quic/qbone/qbone_packet_exchanger.cc b/quiche/quic/qbone/qbone_packet_exchanger.cc new file mode 100644 index 000000000000..f582d6ed0d9c --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_exchanger.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_packet_exchanger.h" + +#include + +namespace quic { + +bool QbonePacketExchanger::ReadAndDeliverPacket( + QboneClientInterface* qbone_client) { + bool blocked = false; + std::string error; + std::unique_ptr packet = ReadPacket(&blocked, &error); + if (packet == nullptr) { + if (!blocked && visitor_) { + visitor_->OnReadError(error); + } + return false; + } + qbone_client->ProcessPacketFromNetwork(packet->AsStringPiece()); + return true; +} + +void QbonePacketExchanger::WritePacketToNetwork(const char* packet, + size_t size) { + bool blocked = false; + std::string error; + if (packet_queue_.empty() && !write_blocked_) { + if (WritePacket(packet, size, &blocked, &error)) { + return; + } + if (blocked) { + write_blocked_ = true; + } else { + QUIC_LOG_EVERY_N_SEC(ERROR, 60) << "Packet write failed: " << error; + if (visitor_) { + visitor_->OnWriteError(error); + } + } + } + + // Drop the packet on the floor if the queue if full. + if (packet_queue_.size() >= max_pending_packets_) { + return; + } + + auto data_copy = new char[size]; + memcpy(data_copy, packet, size); + packet_queue_.push_back( + std::make_unique(data_copy, size, /* owns_buffer = */ true)); +} + +void QbonePacketExchanger::SetWritable() { + write_blocked_ = false; + while (!packet_queue_.empty()) { + bool blocked = false; + std::string error; + if (WritePacket(packet_queue_.front()->data(), + packet_queue_.front()->length(), &blocked, &error)) { + packet_queue_.pop_front(); + } else { + if (!blocked && visitor_) { + visitor_->OnWriteError(error); + } + write_blocked_ = blocked; + return; + } + } +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_packet_exchanger.h b/quiche/quic/qbone/qbone_packet_exchanger.h new file mode 100644 index 000000000000..4fd617b6eb2e --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_exchanger.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_PACKET_EXCHANGER_H_ +#define QUICHE_QUIC_QBONE_QBONE_PACKET_EXCHANGER_H_ + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/qbone/qbone_client_interface.h" +#include "quiche/quic/qbone/qbone_packet_writer.h" + +namespace quic { + +// Handles reading and writing on the local network and exchange packets between +// the local network with a QBONE connection. +class QbonePacketExchanger : public QbonePacketWriter { + public: + // The owner might want to receive notifications when read or write fails. + class Visitor { + public: + virtual ~Visitor() {} + virtual void OnReadError(const std::string& error) {} + virtual void OnWriteError(const std::string& error) {} + }; + // Does not take ownership of visitor. + QbonePacketExchanger(Visitor* visitor, size_t max_pending_packets) + : visitor_(visitor), max_pending_packets_(max_pending_packets) {} + + QbonePacketExchanger(const QbonePacketExchanger&) = delete; + QbonePacketExchanger& operator=(const QbonePacketExchanger&) = delete; + + QbonePacketExchanger(QbonePacketExchanger&&) = delete; + QbonePacketExchanger& operator=(QbonePacketExchanger&&) = delete; + + ~QbonePacketExchanger() = default; + + // Returns true if there may be more packets to read. + // Implementations handles the actual raw read and delivers the packet to + // qbone_client. + bool ReadAndDeliverPacket(QboneClientInterface* qbone_client); + + // From QbonePacketWriter. + // Writes a packet to the local network. If the write would be blocked, the + // packet will be queued if the queue is smaller than max_pending_packets_. + void WritePacketToNetwork(const char* packet, size_t size) override; + + // The caller signifies that the local network is no longer blocked. + void SetWritable(); + + private: + // The actual implementation that reads a packet from the local network. + // Returns the packet if one is successfully read. This might nullptr when a) + // there is no packet to read, b) the read failed. In the former case, blocked + // is set to true. error contains the error message. + virtual std::unique_ptr ReadPacket(bool* blocked, + std::string* error) = 0; + + // The actual implementation that writes a packet to the local network. + // Returns true if the write succeeds. blocked will be set to true if the + // write failure is caused by the local network being blocked. error contains + // the error message. + virtual bool WritePacket(const char* packet, size_t size, bool* blocked, + std::string* error) = 0; + + std::list> packet_queue_; + + Visitor* visitor_; + + // The maximum number of packets that could be queued up when writing to local + // network is blocked. + size_t max_pending_packets_; + + bool write_blocked_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_PACKET_EXCHANGER_H_ diff --git a/quiche/quic/qbone/qbone_packet_exchanger_test.cc b/quiche/quic/qbone/qbone_packet_exchanger_test.cc new file mode 100644 index 000000000000..be6084159a92 --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_exchanger_test.cc @@ -0,0 +1,269 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_packet_exchanger.h" + +#include + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/mock_qbone_client.h" + +namespace quic { +namespace { + +using ::testing::StrEq; +using ::testing::StrictMock; + +const size_t kMaxPendingPackets = 2; + +class MockVisitor : public QbonePacketExchanger::Visitor { + public: + MOCK_METHOD(void, OnReadError, (const std::string&), (override)); + MOCK_METHOD(void, OnWriteError, (const std::string&), (override)); +}; + +class FakeQbonePacketExchanger : public QbonePacketExchanger { + public: + using QbonePacketExchanger::QbonePacketExchanger; + + // Adds a packet to the end of list of packets to be returned by ReadPacket. + // When the list is empty, ReadPacket returns nullptr to signify error as + // defined by QbonePacketExchanger. If SetReadError is not called or called + // with empty error string, ReadPacket sets blocked to true. + void AddPacketToBeRead(std::unique_ptr packet) { + packets_to_be_read_.push_back(std::move(packet)); + } + + // Sets the error to be returned by ReadPacket when the list of packets is + // empty. If error is empty string, blocked is set by ReadPacket. + void SetReadError(const std::string& error) { read_error_ = error; } + + // Force WritePacket to fail with the given status. WritePacket returns true + // when blocked == true and error is empty. + void ForceWriteFailure(bool blocked, const std::string& error) { + write_blocked_ = blocked; + write_error_ = error; + } + + // Packets that have been successfully written by WritePacket. + const std::vector& packets_written() const { + return packets_written_; + } + + private: + // Implements QbonePacketExchanger::ReadPacket. + std::unique_ptr ReadPacket(bool* blocked, + std::string* error) override { + *blocked = false; + + if (packets_to_be_read_.empty()) { + *blocked = read_error_.empty(); + *error = read_error_; + return nullptr; + } + + std::unique_ptr packet = std::move(packets_to_be_read_.front()); + packets_to_be_read_.pop_front(); + return packet; + } + + // Implements QbonePacketExchanger::WritePacket. + bool WritePacket(const char* packet, size_t size, bool* blocked, + std::string* error) override { + *blocked = false; + + if (write_blocked_ || !write_error_.empty()) { + *blocked = write_blocked_; + *error = write_error_; + return false; + } + + packets_written_.push_back(std::string(packet, size)); + return true; + } + + std::string read_error_; + std::list> packets_to_be_read_; + + std::string write_error_; + bool write_blocked_ = false; + std::vector packets_written_; +}; + +TEST(QbonePacketExchangerTest, + ReadAndDeliverPacketDeliversPacketToQboneClient) { + StrictMock visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + StrictMock client; + + std::string packet = "data"; + exchanger.AddPacketToBeRead( + std::make_unique(packet.data(), packet.length())); + EXPECT_CALL(client, ProcessPacketFromNetwork(StrEq("data"))); + + EXPECT_TRUE(exchanger.ReadAndDeliverPacket(&client)); +} + +TEST(QbonePacketExchangerTest, + ReadAndDeliverPacketNotifiesVisitorOnReadFailure) { + MockVisitor visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + MockQboneClient client; + + // Force read error. + std::string io_error = "I/O error"; + exchanger.SetReadError(io_error); + EXPECT_CALL(visitor, OnReadError(StrEq(io_error))).Times(1); + + EXPECT_FALSE(exchanger.ReadAndDeliverPacket(&client)); +} + +TEST(QbonePacketExchangerTest, + ReadAndDeliverPacketDoesNotNotifyVisitorOnBlockedIO) { + MockVisitor visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + MockQboneClient client; + + // No more packets to read. + EXPECT_FALSE(exchanger.ReadAndDeliverPacket(&client)); +} + +TEST(QbonePacketExchangerTest, + WritePacketToNetworkWritesDirectlyToNetworkWhenNotBlocked) { + MockVisitor visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + MockQboneClient client; + + std::string packet = "data"; + exchanger.WritePacketToNetwork(packet.data(), packet.length()); + + ASSERT_EQ(exchanger.packets_written().size(), 1); + EXPECT_THAT(exchanger.packets_written()[0], StrEq(packet)); +} + +TEST(QbonePacketExchangerTest, + WritePacketToNetworkQueuesPacketsAndProcessThemLater) { + MockVisitor visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + MockQboneClient client; + + // Force write to be blocked so that packets are queued. + exchanger.ForceWriteFailure(true, ""); + std::vector packets = {"packet0", "packet1"}; + for (int i = 0; i < packets.size(); i++) { + exchanger.WritePacketToNetwork(packets[i].data(), packets[i].length()); + } + + // Nothing should have been written because of blockage. + ASSERT_TRUE(exchanger.packets_written().empty()); + + // Remove blockage and start proccessing queued packets. + exchanger.ForceWriteFailure(false, ""); + exchanger.SetWritable(); + + // Queued packets are processed. + ASSERT_EQ(exchanger.packets_written().size(), 2); + for (int i = 0; i < packets.size(); i++) { + EXPECT_THAT(exchanger.packets_written()[i], StrEq(packets[i])); + } +} + +TEST(QbonePacketExchangerTest, + SetWritableContinuesProcessingPacketIfPreviousCallBlocked) { + MockVisitor visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + MockQboneClient client; + + // Force write to be blocked so that packets are queued. + exchanger.ForceWriteFailure(true, ""); + std::vector packets = {"packet0", "packet1"}; + for (int i = 0; i < packets.size(); i++) { + exchanger.WritePacketToNetwork(packets[i].data(), packets[i].length()); + } + + // Nothing should have been written because of blockage. + ASSERT_TRUE(exchanger.packets_written().empty()); + + // Start processing packets, but since writes are still blocked, nothing + // should have been written. + exchanger.SetWritable(); + ASSERT_TRUE(exchanger.packets_written().empty()); + + // Remove blockage and start processing packets again. + exchanger.ForceWriteFailure(false, ""); + exchanger.SetWritable(); + + ASSERT_EQ(exchanger.packets_written().size(), 2); + for (int i = 0; i < packets.size(); i++) { + EXPECT_THAT(exchanger.packets_written()[i], StrEq(packets[i])); + } +} + +TEST(QbonePacketExchangerTest, WritePacketToNetworkDropsPacketIfQueueIfFull) { + std::vector packets = {"packet0", "packet1", "packet2"}; + size_t queue_size = packets.size() - 1; + MockVisitor visitor; + // exchanger has smaller queue than number of packets. + FakeQbonePacketExchanger exchanger(&visitor, queue_size); + MockQboneClient client; + + exchanger.ForceWriteFailure(true, ""); + for (int i = 0; i < packets.size(); i++) { + exchanger.WritePacketToNetwork(packets[i].data(), packets[i].length()); + } + + // Blocked writes cause packets to be queued or dropped. + ASSERT_TRUE(exchanger.packets_written().empty()); + + exchanger.ForceWriteFailure(false, ""); + exchanger.SetWritable(); + + ASSERT_EQ(exchanger.packets_written().size(), queue_size); + for (int i = 0; i < queue_size; i++) { + EXPECT_THAT(exchanger.packets_written()[i], StrEq(packets[i])); + } +} + +TEST(QbonePacketExchangerTest, WriteErrorsGetNotified) { + MockVisitor visitor; + FakeQbonePacketExchanger exchanger(&visitor, kMaxPendingPackets); + MockQboneClient client; + std::string packet = "data"; + + // Write error is delivered to visitor during WritePacketToNetwork. + std::string io_error = "I/O error"; + exchanger.ForceWriteFailure(false, io_error); + EXPECT_CALL(visitor, OnWriteError(StrEq(io_error))).Times(1); + exchanger.WritePacketToNetwork(packet.data(), packet.length()); + ASSERT_TRUE(exchanger.packets_written().empty()); + + // Write error is delivered to visitor during SetWritable. + exchanger.ForceWriteFailure(true, ""); + exchanger.WritePacketToNetwork(packet.data(), packet.length()); + + std::string sys_error = "sys error"; + exchanger.ForceWriteFailure(false, sys_error); + EXPECT_CALL(visitor, OnWriteError(StrEq(sys_error))).Times(1); + exchanger.SetWritable(); + ASSERT_TRUE(exchanger.packets_written().empty()); +} + +TEST(QbonePacketExchangerTest, NullVisitorDoesntCrash) { + FakeQbonePacketExchanger exchanger(nullptr, kMaxPendingPackets); + MockQboneClient client; + std::string packet = "data"; + + // Force read error. + std::string io_error = "I/O error"; + exchanger.SetReadError(io_error); + EXPECT_FALSE(exchanger.ReadAndDeliverPacket(&client)); + + // Force write error + exchanger.ForceWriteFailure(false, io_error); + exchanger.WritePacketToNetwork(packet.data(), packet.length()); + EXPECT_TRUE(exchanger.packets_written().empty()); +} + +} // namespace +} // namespace quic diff --git a/quiche/quic/qbone/qbone_packet_processor.cc b/quiche/quic/qbone/qbone_packet_processor.cc new file mode 100644 index 000000000000..40228dd13e25 --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_processor.cc @@ -0,0 +1,291 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_packet_processor.h" + +#include +#include +#include + +#include + +#include "absl/base/optimization.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_ip_address_family.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/qbone/platform/icmp_packet.h" +#include "quiche/quic/qbone/platform/internet_checksum.h" +#include "quiche/quic/qbone/platform/tcp_packet.h" +#include "quiche/common/quiche_endian.h" + +namespace { + +constexpr size_t kIPv6AddressSize = 16; +constexpr size_t kIPv6MinPacketSize = 1280; +constexpr size_t kIcmpTtl = 64; +constexpr size_t kICMPv6DestinationUnreachableDueToSourcePolicy = 5; +constexpr size_t kIPv6DestinationOffset = 8; + +} // namespace + +namespace quic { + +const QuicIpAddress QbonePacketProcessor::kInvalidIpAddress = + QuicIpAddress::Any6(); + +QbonePacketProcessor::QbonePacketProcessor(QuicIpAddress self_ip, + QuicIpAddress client_ip, + size_t client_ip_subnet_length, + OutputInterface* output, + StatsInterface* stats) + : client_ip_(client_ip), + output_(output), + stats_(stats), + filter_(new Filter) { + memcpy(self_ip_.s6_addr, self_ip.ToPackedString().data(), kIPv6AddressSize); + QUICHE_DCHECK_LE(client_ip_subnet_length, kIPv6AddressSize * 8); + client_ip_subnet_length_ = client_ip_subnet_length; + + QUICHE_DCHECK(IpAddressFamily::IP_V6 == self_ip.address_family()); + QUICHE_DCHECK(IpAddressFamily::IP_V6 == client_ip.address_family()); + QUICHE_DCHECK(self_ip != kInvalidIpAddress); +} + +QbonePacketProcessor::OutputInterface::~OutputInterface() {} +QbonePacketProcessor::StatsInterface::~StatsInterface() {} +QbonePacketProcessor::Filter::~Filter() {} + +QbonePacketProcessor::ProcessingResult +QbonePacketProcessor::Filter::FilterPacket(Direction direction, + absl::string_view full_packet, + absl::string_view payload, + icmp6_hdr* icmp_header, + OutputInterface* output) { + return ProcessingResult::OK; +} + +void QbonePacketProcessor::ProcessPacket(std::string* packet, + Direction direction) { + if (ABSL_PREDICT_FALSE(!IsValid())) { + QUIC_BUG(quic_bug_11024_1) + << "QuicPacketProcessor is invoked in an invalid state."; + stats_->OnPacketDroppedSilently(direction); + return; + } + + uint8_t transport_protocol; + char* transport_data; + icmp6_hdr icmp_header; + memset(&icmp_header, 0, sizeof(icmp_header)); + ProcessingResult result = ProcessIPv6HeaderAndFilter( + packet, direction, &transport_protocol, &transport_data, &icmp_header); + + in6_addr dst; + // TODO(b/70339814): ensure this is actually a unicast address. + memcpy(&dst, &packet->data()[kIPv6DestinationOffset], kIPv6AddressSize); + + switch (result) { + case ProcessingResult::OK: + switch (direction) { + case Direction::FROM_OFF_NETWORK: + output_->SendPacketToNetwork(*packet); + break; + case Direction::FROM_NETWORK: + output_->SendPacketToClient(*packet); + break; + } + stats_->OnPacketForwarded(direction); + break; + case ProcessingResult::SILENT_DROP: + stats_->OnPacketDroppedSilently(direction); + break; + case ProcessingResult::DEFER: + stats_->OnPacketDeferred(direction); + break; + case ProcessingResult::ICMP: + if (icmp_header.icmp6_type == ICMP6_ECHO_REPLY) { + // If this is an ICMP6 ECHO REPLY, the payload should be the same as the + // ICMP6 ECHO REQUEST that this came from, not the entire packet. So we + // need to take off both the IPv6 header and the ICMP6 header. + auto icmp_body = absl::string_view(*packet).substr(sizeof(ip6_hdr) + + sizeof(icmp6_hdr)); + SendIcmpResponse(dst, &icmp_header, icmp_body, direction); + } else { + SendIcmpResponse(dst, &icmp_header, *packet, direction); + } + stats_->OnPacketDroppedWithIcmp(direction); + break; + case ProcessingResult::ICMP_AND_TCP_RESET: + SendIcmpResponse(dst, &icmp_header, *packet, direction); + stats_->OnPacketDroppedWithIcmp(direction); + SendTcpReset(*packet, direction); + stats_->OnPacketDroppedWithTcpReset(direction); + break; + case ProcessingResult::TCP_RESET: + SendTcpReset(*packet, direction); + stats_->OnPacketDroppedWithTcpReset(direction); + break; + } +} + +QbonePacketProcessor::ProcessingResult +QbonePacketProcessor::ProcessIPv6HeaderAndFilter(std::string* packet, + Direction direction, + uint8_t* transport_protocol, + char** transport_data, + icmp6_hdr* icmp_header) { + ProcessingResult result = ProcessIPv6Header( + packet, direction, transport_protocol, transport_data, icmp_header); + + if (result == ProcessingResult::OK) { + char* packet_data = &*packet->begin(); + size_t header_size = *transport_data - packet_data; + // Sanity-check the bounds. + if (packet_data >= *transport_data || header_size > packet->size() || + header_size < kIPv6HeaderSize) { + QUIC_BUG(quic_bug_11024_2) + << "Invalid pointers encountered in " + "QbonePacketProcessor::ProcessPacket. Dropping the packet"; + return ProcessingResult::SILENT_DROP; + } + + result = filter_->FilterPacket( + direction, *packet, + absl::string_view(*transport_data, packet->size() - header_size), + icmp_header, output_); + } + + // Do not send ICMP error messages in response to ICMP errors. + if (result == ProcessingResult::ICMP) { + const uint8_t* header = reinterpret_cast(packet->data()); + + constexpr size_t kIPv6NextHeaderOffset = 6; + constexpr size_t kIcmpMessageTypeOffset = kIPv6HeaderSize + 0; + constexpr size_t kIcmpMessageTypeMaxError = 127; + if ( + // Check size. + packet->size() >= (kIPv6HeaderSize + kICMPv6HeaderSize) && + // Check that the packet is in fact ICMP. + header[kIPv6NextHeaderOffset] == IPPROTO_ICMPV6 && + // Check that ICMP message type is an error. + header[kIcmpMessageTypeOffset] < kIcmpMessageTypeMaxError) { + result = ProcessingResult::SILENT_DROP; + } + } + + return result; +} + +QbonePacketProcessor::ProcessingResult QbonePacketProcessor::ProcessIPv6Header( + std::string* packet, Direction direction, uint8_t* transport_protocol, + char** transport_data, icmp6_hdr* icmp_header) { + // Check if the packet is big enough to have IPv6 header. + if (packet->size() < kIPv6HeaderSize) { + QUIC_DVLOG(1) << "Dropped malformed packet: IPv6 header too short"; + return ProcessingResult::SILENT_DROP; + } + + // Check version field. + ip6_hdr* header = reinterpret_cast(&*packet->begin()); + if (header->ip6_vfc >> 4 != 6) { + QUIC_DVLOG(1) << "Dropped malformed packet: IP version is not IPv6"; + return ProcessingResult::SILENT_DROP; + } + + // Check payload size. + const size_t declared_payload_size = + quiche::QuicheEndian::NetToHost16(header->ip6_plen); + const size_t actual_payload_size = packet->size() - kIPv6HeaderSize; + if (declared_payload_size != actual_payload_size) { + QUIC_DVLOG(1) + << "Dropped malformed packet: incorrect packet length specified"; + return ProcessingResult::SILENT_DROP; + } + + // Check that the address of the client is in the packet. + QuicIpAddress address_to_check; + uint8_t address_reject_code; + bool ip_parse_result; + switch (direction) { + case Direction::FROM_OFF_NETWORK: + // Expect the source IP to match the client. + ip_parse_result = address_to_check.FromPackedString( + reinterpret_cast(&header->ip6_src), + sizeof(header->ip6_src)); + address_reject_code = kICMPv6DestinationUnreachableDueToSourcePolicy; + break; + case Direction::FROM_NETWORK: + // Expect the destination IP to match the client. + ip_parse_result = address_to_check.FromPackedString( + reinterpret_cast(&header->ip6_dst), + sizeof(header->ip6_src)); + address_reject_code = ICMP6_DST_UNREACH_NOROUTE; + break; + } + QUICHE_DCHECK(ip_parse_result); + if (!client_ip_.InSameSubnet(address_to_check, client_ip_subnet_length_)) { + QUIC_DVLOG(1) + << "Dropped packet: source/destination address is not client's"; + icmp_header->icmp6_type = ICMP6_DST_UNREACH; + icmp_header->icmp6_code = address_reject_code; + return ProcessingResult::ICMP; + } + + // Check and decrement TTL. + if (header->ip6_hops <= 1) { + icmp_header->icmp6_type = ICMP6_TIME_EXCEEDED; + icmp_header->icmp6_code = ICMP6_TIME_EXCEED_TRANSIT; + return ProcessingResult::ICMP; + } + header->ip6_hops--; + + // Check and extract IP headers. + switch (header->ip6_nxt) { + case IPPROTO_TCP: + case IPPROTO_UDP: + case IPPROTO_ICMPV6: + *transport_protocol = header->ip6_nxt; + *transport_data = (&*packet->begin()) + kIPv6HeaderSize; + break; + default: + icmp_header->icmp6_type = ICMP6_PARAM_PROB; + icmp_header->icmp6_code = ICMP6_PARAMPROB_NEXTHEADER; + return ProcessingResult::ICMP; + } + + return ProcessingResult::OK; +} + +void QbonePacketProcessor::SendIcmpResponse(in6_addr dst, + icmp6_hdr* icmp_header, + absl::string_view payload, + Direction original_direction) { + CreateIcmpPacket(self_ip_, dst, *icmp_header, payload, + [this, original_direction](absl::string_view packet) { + SendResponse(original_direction, packet); + }); +} + +void QbonePacketProcessor::SendTcpReset(absl::string_view original_packet, + Direction original_direction) { + CreateTcpResetPacket(original_packet, + [this, original_direction](absl::string_view packet) { + SendResponse(original_direction, packet); + }); +} + +void QbonePacketProcessor::SendResponse(Direction original_direction, + absl::string_view packet) { + switch (original_direction) { + case Direction::FROM_OFF_NETWORK: + output_->SendPacketToClient(packet); + break; + case Direction::FROM_NETWORK: + output_->SendPacketToNetwork(packet); + break; + } +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_packet_processor.h b/quiche/quic/qbone/qbone_packet_processor.h new file mode 100644 index 000000000000..77eb13ed6dfa --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_processor.h @@ -0,0 +1,200 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_PACKET_PROCESSOR_H_ +#define QUICHE_QUIC_QBONE_QBONE_PACKET_PROCESSOR_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_ip_address.h" + +namespace quic { + +enum : size_t { + kIPv6HeaderSize = 40, + kICMPv6HeaderSize = sizeof(icmp6_hdr), + kTotalICMPv6HeaderSize = kIPv6HeaderSize + kICMPv6HeaderSize, +}; + +// QBONE packet processor accepts packets destined in either direction +// (client-to-network or network-to-client). It inspects them and makes +// decisions on whether they should be forwarded or dropped, replying with ICMP +// messages as appropriate. +class QbonePacketProcessor { + public: + enum class Direction { + // Packet is going from the QBONE client into the network behind the QBONE. + FROM_OFF_NETWORK = 0, + // Packet is going from the network begin QBONE to the client. + FROM_NETWORK = 1 + }; + + enum class ProcessingResult { + OK = 0, + SILENT_DROP = 1, + ICMP = 2, + // Equivalent to |SILENT_DROP| at the moment, but indicates that the + // downstream filter has buffered the packet and deferred its processing. + // The packet may be emitted at a later time. + DEFER = 3, + // In addition to sending an ICMP message, also send a TCP RST. This option + // requires the incoming packet to have been a valid TCP packet, as a TCP + // RST requires information from the current connection state to be + // well-formed. + ICMP_AND_TCP_RESET = 4, + // Send a TCP RST. + TCP_RESET = 5, + }; + + class OutputInterface { + public: + virtual ~OutputInterface(); + + virtual void SendPacketToClient(absl::string_view packet) = 0; + virtual void SendPacketToNetwork(absl::string_view packet) = 0; + }; + + class StatsInterface { + public: + virtual ~StatsInterface(); + + virtual void OnPacketForwarded(Direction direction) = 0; + virtual void OnPacketDroppedSilently(Direction direction) = 0; + virtual void OnPacketDroppedWithIcmp(Direction direction) = 0; + virtual void OnPacketDroppedWithTcpReset(Direction direction) = 0; + virtual void OnPacketDeferred(Direction direction) = 0; + }; + + // Allows to implement a custom packet filter on top of the filtering done by + // the packet processor itself. + class Filter { + public: + virtual ~Filter(); + // The main interface function. The following arguments are supplied: + // - |direction|, to indicate direction of the packet. + // - |full_packet|, which includes the IPv6 header and possibly the IPv6 + // options that were understood by the processor. + // - |payload|, the contents of the IPv6 packet, i.e. a TCP, a UDP or an + // ICMP packet. + // - |icmp_header|, an output argument which allows the filter to specify + // the ICMP message with which the packet is to be rejected. + // The method is called only on packets which were already verified as valid + // IPv6 packets. + // + // The implementer of this method has four options to return: + // - OK will cause the filter to pass the packet through + // - SILENT_DROP will cause the filter to drop the packet silently + // - ICMP will cause the filter to drop the packet and send an ICMP + // response. + // - DEFER will cause the packet to be not forwarded; the filter is + // responsible for sending (or not sending) it later using |output|. + // + // Note that |output| should not be used except in the DEFER case, as the + // processor will perform the necessary writes itself. + virtual ProcessingResult FilterPacket(Direction direction, + absl::string_view full_packet, + absl::string_view payload, + icmp6_hdr* icmp_header, + OutputInterface* output); + + protected: + // Helper methods that allow to easily extract information that is required + // for filtering from the |ipv6_header| argument. All of those assume that + // the header is of valid size, which is true for everything passed into + // FilterPacket(). + uint8_t TransportProtocolFromHeader(absl::string_view ipv6_header) { + return ipv6_header[6]; + } + QuicIpAddress SourceIpFromHeader(absl::string_view ipv6_header) { + QuicIpAddress address; + address.FromPackedString(&ipv6_header[8], + QuicIpAddress::kIPv6AddressSize); + return address; + } + QuicIpAddress DestinationIpFromHeader(absl::string_view ipv6_header) { + QuicIpAddress address; + address.FromPackedString(&ipv6_header[24], + QuicIpAddress::kIPv6AddressSize); + return address; + } + }; + + // |self_ip| is the IP address from which the processor will originate ICMP + // messages. |client_ip| is the expected IP address of the client, used for + // packet validation. + // + // |output| and |stats| are the visitor interfaces used by the processor. + // |output| gets notified whenever the processor decides to send a packet, and + // |stats| gets notified about any decisions that processor makes, without a + // reference to which packet that decision was made about. + QbonePacketProcessor(QuicIpAddress self_ip, QuicIpAddress client_ip, + size_t client_ip_subnet_length, OutputInterface* output, + StatsInterface* stats); + QbonePacketProcessor(const QbonePacketProcessor&) = delete; + QbonePacketProcessor& operator=(const QbonePacketProcessor&) = delete; + + // Accepts an IPv6 packet and handles it accordingly by either forwarding it, + // replying with an ICMP packet or silently dropping it. |packet| will be + // modified in the process, by having the TTL field decreased. + void ProcessPacket(std::string* packet, Direction direction); + + void set_filter(std::unique_ptr filter) { + filter_ = std::move(filter); + } + + void set_client_ip(QuicIpAddress client_ip) { client_ip_ = client_ip; } + void set_client_ip_subnet_length(size_t client_ip_subnet_length) { + client_ip_subnet_length_ = client_ip_subnet_length; + } + + static const QuicIpAddress kInvalidIpAddress; + + protected: + // Processes the header and returns what should be done with the packet. + // After that, calls an external packet filter if registered. TTL of the + // packet may be decreased in the process. + ProcessingResult ProcessIPv6HeaderAndFilter(std::string* packet, + Direction direction, + uint8_t* transport_protocol, + char** transport_data, + icmp6_hdr* icmp_header); + + void SendIcmpResponse(in6_addr dst, icmp6_hdr* icmp_header, + absl::string_view payload, + Direction original_direction); + + void SendTcpReset(absl::string_view original_packet, + Direction original_direction); + + bool IsValid() const { return client_ip_ != kInvalidIpAddress; } + + // IP address of the server. Used to send ICMP messages. + in6_addr self_ip_; + // IP address range of the VPN client. + QuicIpAddress client_ip_; + size_t client_ip_subnet_length_; + + OutputInterface* output_; + StatsInterface* stats_; + std::unique_ptr filter_; + + private: + // Performs basic sanity and permission checks on the packet, and decreases + // the TTL. + ProcessingResult ProcessIPv6Header(std::string* packet, Direction direction, + uint8_t* transport_protocol, + char** transport_data, + icmp6_hdr* icmp_header); + + void SendResponse(Direction original_direction, absl::string_view packet); + + in6_addr GetDestinationFromPacket(absl::string_view packet); +}; + +} // namespace quic +#endif // QUICHE_QUIC_QBONE_QBONE_PACKET_PROCESSOR_H_ diff --git a/quiche/quic/qbone/qbone_packet_processor_test.cc b/quiche/quic/qbone/qbone_packet_processor_test.cc new file mode 100644 index 000000000000..519ef43f69c0 --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_processor_test.cc @@ -0,0 +1,388 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_packet_processor.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/qbone_packet_processor_test_tools.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic::test { +namespace { + +using Direction = QbonePacketProcessor::Direction; +using ProcessingResult = QbonePacketProcessor::ProcessingResult; +using OutputInterface = QbonePacketProcessor::OutputInterface; +using ::testing::_; +using ::testing::Eq; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::WithArgs; + +// clang-format off +static const char kReferenceClientPacketData[] = { + // IPv6 with zero TOS and flow label. + 0x60, 0x00, 0x00, 0x00, + // Payload size is 8 bytes. + 0x00, 0x08, + // Next header is UDP + 17, + // TTL is 50. + 50, + // IP address of the sender is fd00:0:0:1::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // IP address of the receiver is fd00:0:0:5::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Source port 12345 + 0x30, 0x39, + // Destination port 443 + 0x01, 0xbb, + // UDP content length is zero + 0x00, 0x00, + // Checksum is not actually checked in any of the tests, so we leave it as + // zero + 0x00, 0x00, +}; + +static const char kReferenceNetworkPacketData[] = { + // IPv6 with zero TOS and flow label. + 0x60, 0x00, 0x00, 0x00, + // Payload size is 8 bytes. + 0x00, 0x08, + // Next header is UDP + 17, + // TTL is 50. + 50, + // IP address of the sender is fd00:0:0:5::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // IP address of the receiver is fd00:0:0:1::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Source port 443 + 0x01, 0xbb, + // Destination port 12345 + 0x30, 0x39, + // UDP content length is zero + 0x00, 0x00, + // Checksum is not actually checked in any of the tests, so we leave it as + // zero + 0x00, 0x00, +}; + +static const char kReferenceClientSubnetPacketData[] = { + // IPv6 with zero TOS and flow label. + 0x60, 0x00, 0x00, 0x00, + // Payload size is 8 bytes. + 0x00, 0x08, + // Next header is UDP + 17, + // TTL is 50. + 50, + // IP address of the sender is fd00:0:0:2::1, which is within the /62 of the + // client. + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // IP address of the receiver is fd00:0:0:5::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // Source port 12345 + 0x30, 0x39, + // Destination port 443 + 0x01, 0xbb, + // UDP content length is zero + 0x00, 0x00, + // Checksum is not actually checked in any of the tests, so we leave it as + // zero + 0x00, 0x00, +}; + +static const char kReferenceEchoRequestData[] = { + // IPv6 with zero TOS and flow label. + 0x60, 0x00, 0x00, 0x00, + // Payload size is 64 bytes. + 0x00, 64, + // Next header is ICMP + 58, + // TTL is 127. + 127, + // IP address of the sender is fd00:0:0:1::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // IP address of the receiver is fe80::71:626f:6e6f + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x71, 0x62, 0x6f, 0x6e, 0x6f, + // ICMP Type ping request + 128, + // ICMP Code 0 + 0, + // Checksum is not actually checked in any of the tests, so we leave it as + // zero + 0x00, 0x00, + // ICMP Identifier (0xcafe to be memorable) + 0xca, 0xfe, + // Sequence number + 0x00, 0x01, + // Data, starting with unix timeval then 0x10..0x37 + 0x67, 0x37, 0x8a, 0x63, 0x00, 0x00, 0x00, 0x00, + 0x96, 0x58, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, +}; + +static const char kReferenceEchoReplyData[] = { + // IPv6 with zero TOS and flow label. + 0x60, 0x00, 0x00, 0x00, + // Payload size is 64 bytes. + 0x00, 64, + // Next header is ICMP + 58, + // TTL is 255. + 255, + // IP address of the sender is fd00:4:0:1::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // IP address of the receiver is fd00:0:0:1::1 + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // ICMP Type ping reply + 129, + // ICMP Code 0 + 0, + // Checksum + 0x66, 0xb6, + // ICMP Identifier (0xcafe to be memorable) + 0xca, 0xfe, + // Sequence number + 0x00, 0x01, + // Data, starting with unix timeval then 0x10..0x37 + 0x67, 0x37, 0x8a, 0x63, 0x00, 0x00, 0x00, 0x00, + 0x96, 0x58, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, +}; + +// clang-format on + +static const absl::string_view kReferenceClientPacket( + kReferenceClientPacketData, ABSL_ARRAYSIZE(kReferenceClientPacketData)); + +static const absl::string_view kReferenceNetworkPacket( + kReferenceNetworkPacketData, ABSL_ARRAYSIZE(kReferenceNetworkPacketData)); + +static const absl::string_view kReferenceClientSubnetPacket( + kReferenceClientSubnetPacketData, + ABSL_ARRAYSIZE(kReferenceClientSubnetPacketData)); + +static const absl::string_view kReferenceEchoRequest( + kReferenceEchoRequestData, ABSL_ARRAYSIZE(kReferenceEchoRequestData)); + +MATCHER_P(IsIcmpMessage, icmp_type, + "Checks whether the argument is an ICMP message of supplied type") { + if (arg.size() < kTotalICMPv6HeaderSize) { + return false; + } + + return arg[40] == icmp_type; +} + +class MockPacketFilter : public QbonePacketProcessor::Filter { + public: + MOCK_METHOD(ProcessingResult, FilterPacket, + (Direction, absl::string_view, absl::string_view, icmp6_hdr*, + OutputInterface*), + (override)); +}; + +class QbonePacketProcessorTest : public QuicTest { + protected: + QbonePacketProcessorTest() { + QUICHE_CHECK(client_ip_.FromString("fd00:0:0:1::1")); + QUICHE_CHECK(self_ip_.FromString("fd00:0:0:4::1")); + QUICHE_CHECK(network_ip_.FromString("fd00:0:0:5::1")); + + processor_ = std::make_unique( + self_ip_, client_ip_, /*client_ip_subnet_length=*/62, &output_, + &stats_); + } + + void SendPacketFromClient(absl::string_view packet) { + std::string packet_buffer(packet.data(), packet.size()); + processor_->ProcessPacket(&packet_buffer, Direction::FROM_OFF_NETWORK); + } + + void SendPacketFromNetwork(absl::string_view packet) { + std::string packet_buffer(packet.data(), packet.size()); + processor_->ProcessPacket(&packet_buffer, Direction::FROM_NETWORK); + } + + QuicIpAddress client_ip_; + QuicIpAddress self_ip_; + QuicIpAddress network_ip_; + + std::unique_ptr processor_; + testing::StrictMock output_; + testing::StrictMock stats_; +}; + +TEST_F(QbonePacketProcessorTest, EmptyPacket) { + EXPECT_CALL(stats_, OnPacketDroppedSilently(Direction::FROM_OFF_NETWORK)); + SendPacketFromClient(""); + + EXPECT_CALL(stats_, OnPacketDroppedSilently(Direction::FROM_NETWORK)); + SendPacketFromNetwork(""); +} + +TEST_F(QbonePacketProcessorTest, RandomGarbage) { + EXPECT_CALL(stats_, OnPacketDroppedSilently(Direction::FROM_OFF_NETWORK)); + SendPacketFromClient(std::string(1280, 'a')); + + EXPECT_CALL(stats_, OnPacketDroppedSilently(Direction::FROM_NETWORK)); + SendPacketFromNetwork(std::string(1280, 'a')); +} + +TEST_F(QbonePacketProcessorTest, RandomGarbageWithCorrectLengthFields) { + std::string packet(40, 'a'); + packet[4] = 0; + packet[5] = 0; + + EXPECT_CALL(stats_, OnPacketDroppedWithIcmp(Direction::FROM_OFF_NETWORK)); + EXPECT_CALL(output_, SendPacketToClient(IsIcmpMessage(ICMP6_DST_UNREACH))); + SendPacketFromClient(packet); +} + +TEST_F(QbonePacketProcessorTest, GoodPacketFromClient) { + EXPECT_CALL(stats_, OnPacketForwarded(Direction::FROM_OFF_NETWORK)); + EXPECT_CALL(output_, SendPacketToNetwork(_)); + SendPacketFromClient(kReferenceClientPacket); +} + +TEST_F(QbonePacketProcessorTest, GoodPacketFromClientSubnet) { + EXPECT_CALL(stats_, OnPacketForwarded(Direction::FROM_OFF_NETWORK)); + EXPECT_CALL(output_, SendPacketToNetwork(_)); + SendPacketFromClient(kReferenceClientSubnetPacket); +} + +TEST_F(QbonePacketProcessorTest, GoodPacketFromNetwork) { + EXPECT_CALL(stats_, OnPacketForwarded(Direction::FROM_NETWORK)); + EXPECT_CALL(output_, SendPacketToClient(_)); + SendPacketFromNetwork(kReferenceNetworkPacket); +} + +TEST_F(QbonePacketProcessorTest, GoodPacketFromNetworkWrongDirection) { + EXPECT_CALL(stats_, OnPacketDroppedWithIcmp(Direction::FROM_OFF_NETWORK)); + EXPECT_CALL(output_, SendPacketToClient(IsIcmpMessage(ICMP6_DST_UNREACH))); + SendPacketFromClient(kReferenceNetworkPacket); +} + +TEST_F(QbonePacketProcessorTest, TtlExpired) { + std::string packet(kReferenceNetworkPacket); + packet[7] = 1; + + EXPECT_CALL(stats_, OnPacketDroppedWithIcmp(Direction::FROM_NETWORK)); + EXPECT_CALL(output_, SendPacketToNetwork(IsIcmpMessage(ICMP6_TIME_EXCEEDED))); + SendPacketFromNetwork(packet); +} + +TEST_F(QbonePacketProcessorTest, UnknownProtocol) { + std::string packet(kReferenceNetworkPacket); + packet[6] = IPPROTO_SCTP; + + EXPECT_CALL(stats_, OnPacketDroppedWithIcmp(Direction::FROM_NETWORK)); + EXPECT_CALL(output_, SendPacketToNetwork(IsIcmpMessage(ICMP6_PARAM_PROB))); + SendPacketFromNetwork(packet); +} + +TEST_F(QbonePacketProcessorTest, FilterFromClient) { + auto filter = std::make_unique(); + EXPECT_CALL(*filter, FilterPacket(_, _, _, _, _)) + .WillRepeatedly(Return(ProcessingResult::SILENT_DROP)); + processor_->set_filter(std::move(filter)); + + EXPECT_CALL(stats_, OnPacketDroppedSilently(Direction::FROM_OFF_NETWORK)); + SendPacketFromClient(kReferenceClientPacket); +} + +class TestFilter : public QbonePacketProcessor::Filter { + public: + TestFilter(QuicIpAddress client_ip, QuicIpAddress network_ip) + : client_ip_(client_ip), network_ip_(network_ip) {} + ProcessingResult FilterPacket(Direction direction, + absl::string_view full_packet, + absl::string_view payload, + icmp6_hdr* icmp_header, + OutputInterface* output) override { + EXPECT_EQ(kIPv6HeaderSize, full_packet.size() - payload.size()); + EXPECT_EQ(IPPROTO_UDP, TransportProtocolFromHeader(full_packet)); + EXPECT_EQ(client_ip_, SourceIpFromHeader(full_packet)); + EXPECT_EQ(network_ip_, DestinationIpFromHeader(full_packet)); + + called_++; + return ProcessingResult::SILENT_DROP; + } + + int called() const { return called_; } + + private: + int called_ = 0; + + QuicIpAddress client_ip_; + QuicIpAddress network_ip_; +}; + +// Verify that the parameters are passed correctly into the filter, and that the +// helper functions of the filter class work. +TEST_F(QbonePacketProcessorTest, FilterHelperFunctions) { + auto filter_owned = std::make_unique(client_ip_, network_ip_); + TestFilter* filter = filter_owned.get(); + processor_->set_filter(std::move(filter_owned)); + + EXPECT_CALL(stats_, OnPacketDroppedSilently(Direction::FROM_OFF_NETWORK)); + SendPacketFromClient(kReferenceClientPacket); + ASSERT_EQ(1, filter->called()); +} + +TEST_F(QbonePacketProcessorTest, Icmp6EchoResponseHasRightPayload) { + auto filter = std::make_unique(); + EXPECT_CALL(*filter, FilterPacket(_, _, _, _, _)) + .WillOnce(WithArgs<2, 3>( + Invoke([](absl::string_view payload, icmp6_hdr* icmp_header) { + icmp_header->icmp6_type = ICMP6_ECHO_REPLY; + icmp_header->icmp6_code = 0; + auto* request_header = + reinterpret_cast(payload.data()); + icmp_header->icmp6_id = request_header->icmp6_id; + icmp_header->icmp6_seq = request_header->icmp6_seq; + return ProcessingResult::ICMP; + }))); + processor_->set_filter(std::move(filter)); + + EXPECT_CALL(stats_, OnPacketDroppedWithIcmp(Direction::FROM_OFF_NETWORK)); + EXPECT_CALL(output_, SendPacketToClient(_)) + .WillOnce(Invoke([](absl::string_view packet) { + // Explicit conversion because otherwise it is treated as a null + // terminated string. + absl::string_view expected = absl::string_view( + kReferenceEchoReplyData, sizeof(kReferenceEchoReplyData)); + + EXPECT_THAT(packet, Eq(expected)); + QUIC_LOG(INFO) << "ICMP response:\n" + << quiche::QuicheTextUtils::HexDump(packet); + })); + SendPacketFromClient(kReferenceEchoRequest); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/qbone/qbone_packet_processor_test_tools.cc b/quiche/quic/qbone/qbone_packet_processor_test_tools.cc new file mode 100644 index 000000000000..9a731887fbf7 --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_processor_test_tools.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_packet_processor_test_tools.h" + +#include + +namespace quic { + +std::string PrependIPv6HeaderForTest(const std::string& body, int hops) { + ip6_hdr header; + memset(&header, 0, sizeof(header)); + + header.ip6_vfc = 6 << 4; + header.ip6_plen = htons(body.size()); + header.ip6_nxt = IPPROTO_UDP; + header.ip6_hops = hops; + header.ip6_src = in6addr_loopback; + header.ip6_dst = in6addr_loopback; + + std::string packet(sizeof(header) + body.size(), '\0'); + memcpy(&packet[0], &header, sizeof(header)); + memcpy(&packet[sizeof(header)], body.data(), body.size()); + return packet; +} + +bool DecrementIPv6HopLimit(std::string& packet) { + if (packet.size() < sizeof(ip6_hdr)) { + return false; + } + ip6_hdr* header = reinterpret_cast(&packet[0]); + if (header->ip6_vfc >> 4 != 6 || header->ip6_hops == 0) { + return false; + } + header->ip6_hops--; + return true; +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_packet_processor_test_tools.h b/quiche/quic/qbone/qbone_packet_processor_test_tools.h new file mode 100644 index 000000000000..3d3492dba585 --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_processor_test_tools.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_PACKET_PROCESSOR_TEST_TOOLS_H_ +#define QUICHE_QUIC_QBONE_QBONE_PACKET_PROCESSOR_TEST_TOOLS_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/qbone/qbone_packet_processor.h" + +namespace quic { + +class MockPacketProcessorOutput : public QbonePacketProcessor::OutputInterface { + public: + MockPacketProcessorOutput() {} + + MOCK_METHOD(void, SendPacketToClient, (absl::string_view), (override)); + MOCK_METHOD(void, SendPacketToNetwork, (absl::string_view), (override)); +}; + +class MockPacketProcessorStats : public QbonePacketProcessor::StatsInterface { + public: + MockPacketProcessorStats() {} + + MOCK_METHOD(void, OnPacketForwarded, (QbonePacketProcessor::Direction), + (override)); + MOCK_METHOD(void, OnPacketDroppedSilently, (QbonePacketProcessor::Direction), + (override)); + MOCK_METHOD(void, OnPacketDroppedWithIcmp, (QbonePacketProcessor::Direction), + (override)); + MOCK_METHOD(void, OnPacketDroppedWithTcpReset, + (QbonePacketProcessor::Direction), (override)); + MOCK_METHOD(void, OnPacketDeferred, (QbonePacketProcessor::Direction), + (override)); +}; + +std::string PrependIPv6HeaderForTest(const std::string& body, int hops); + +// Returns true if the hop limit was decremented. Returns false if the packet is +// too short, not IPv6, or already has a hop limit of zero. +bool DecrementIPv6HopLimit(std::string& packet); + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_PACKET_PROCESSOR_TEST_TOOLS_H_ diff --git a/quiche/quic/qbone/qbone_packet_writer.h b/quiche/quic/qbone/qbone_packet_writer.h new file mode 100644 index 000000000000..1ed8a46fa048 --- /dev/null +++ b/quiche/quic/qbone/qbone_packet_writer.h @@ -0,0 +1,24 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_PACKET_WRITER_H_ +#define QUICHE_QUIC_QBONE_QBONE_PACKET_WRITER_H_ + +#include + +namespace quic { + +// QbonePacketWriter expects only one function to be defined, +// WritePacketToNetwork, which is called when a packet is received via QUIC +// and should be sent out on the network. This is the complete packet, +// and not just a fragment. +class QbonePacketWriter { + public: + virtual ~QbonePacketWriter() {} + virtual void WritePacketToNetwork(const char* packet, size_t size) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_PACKET_WRITER_H_ diff --git a/quiche/quic/qbone/qbone_server_session.cc b/quiche/quic/qbone/qbone_server_session.cc new file mode 100644 index 000000000000..cbe9583635bf --- /dev/null +++ b/quiche/quic/qbone/qbone_server_session.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_server_session.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/qbone/qbone_constants.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +namespace quic { + +bool QboneCryptoServerStreamHelper::CanAcceptClientHello( + const CryptoHandshakeMessage& chlo, const QuicSocketAddress& client_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& self_address, std::string* error_details) const { + absl::string_view alpn; + chlo.GetStringPiece(quic::kALPN, &alpn); + if (alpn != QboneConstants::kQboneAlpn) { + *error_details = "ALPN-indicated protocol is not qbone"; + return false; + } + return true; +} + +QboneServerSession::QboneServerSession( + const quic::ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, Visitor* owner, const QuicConfig& config, + const QuicCryptoServerConfig* quic_crypto_server_config, + QuicCompressedCertsCache* compressed_certs_cache, QbonePacketWriter* writer, + QuicIpAddress self_ip, QuicIpAddress client_ip, + size_t client_ip_subnet_length, QboneServerControlStream::Handler* handler) + : QboneSessionBase(connection, owner, config, supported_versions, writer), + processor_(self_ip, client_ip, client_ip_subnet_length, this, this), + quic_crypto_server_config_(quic_crypto_server_config), + compressed_certs_cache_(compressed_certs_cache), + handler_(handler) {} + +QboneServerSession::~QboneServerSession() {} + +std::unique_ptr QboneServerSession::CreateCryptoStream() { + return CreateCryptoServerStream(quic_crypto_server_config_, + compressed_certs_cache_, this, + &stream_helper_); +} + +void QboneServerSession::CreateControlStream() { + if (control_stream_ != nullptr) { + return; + } + // Register the reserved control stream. + auto control_stream = + std::make_unique(this, handler_); + control_stream_ = control_stream.get(); + ActivateStream(std::move(control_stream)); +} + +QuicStream* QboneServerSession::CreateControlStreamFromPendingStream( + PendingStream* pending) { + QUICHE_DCHECK(control_stream_ == nullptr); + // Register the reserved control stream. + auto control_stream = + std::make_unique(pending, this, handler_); + control_stream_ = control_stream.get(); + ActivateStream(std::move(control_stream)); + return control_stream_; +} + +void QboneServerSession::SetDefaultEncryptionLevel( + quic::EncryptionLevel level) { + QboneSessionBase::SetDefaultEncryptionLevel(level); + if (level == quic::ENCRYPTION_FORWARD_SECURE) { + CreateControlStream(); + } +} + +bool QboneServerSession::SendClientRequest(const QboneClientRequest& request) { + if (!control_stream_) { + QUIC_BUG(quic_bug_11026_1) + << "Cannot send client request before control stream is created."; + return false; + } + return control_stream_->SendRequest(request); +} + +void QboneServerSession::ProcessPacketFromNetwork(absl::string_view packet) { + std::string buffer = std::string(packet); + processor_.ProcessPacket(&buffer, + QbonePacketProcessor::Direction::FROM_NETWORK); +} + +void QboneServerSession::ProcessPacketFromPeer(absl::string_view packet) { + std::string buffer = std::string(packet); + processor_.ProcessPacket(&buffer, + QbonePacketProcessor::Direction::FROM_OFF_NETWORK); +} + +void QboneServerSession::SendPacketToClient(absl::string_view packet) { + SendPacketToPeer(packet); +} + +void QboneServerSession::SendPacketToNetwork(absl::string_view packet) { + QUICHE_DCHECK(writer_ != nullptr); + writer_->WritePacketToNetwork(packet.data(), packet.size()); +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_server_session.h b/quiche/quic/qbone/qbone_server_session.h new file mode 100644 index 000000000000..c3138160cdff --- /dev/null +++ b/quiche/quic/qbone/qbone_server_session.h @@ -0,0 +1,101 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_SERVER_SESSION_H_ +#define QUICHE_QUIC_QBONE_QBONE_SERVER_SESSION_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/qbone/qbone_control.pb.h" +#include "quiche/quic/qbone/qbone_control_stream.h" +#include "quiche/quic/qbone/qbone_packet_processor.h" +#include "quiche/quic/qbone/qbone_packet_writer.h" +#include "quiche/quic/qbone/qbone_session_base.h" + +namespace quic { + +// A helper class is used by the QuicCryptoServerStream. +class QboneCryptoServerStreamHelper + : public QuicCryptoServerStreamBase::Helper { + public: + // This will look for the QBONE alpn. + bool CanAcceptClientHello(const CryptoHandshakeMessage& chlo, + const QuicSocketAddress& client_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& self_address, + std::string* error_details) const override; +}; + +class QUIC_EXPORT_PRIVATE QboneServerSession + : public QboneSessionBase, + public QbonePacketProcessor::OutputInterface, + public QbonePacketProcessor::StatsInterface { + public: + QboneServerSession(const quic::ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, Visitor* owner, + const QuicConfig& config, + const QuicCryptoServerConfig* quic_crypto_server_config, + QuicCompressedCertsCache* compressed_certs_cache, + QbonePacketWriter* writer, QuicIpAddress self_ip, + QuicIpAddress client_ip, size_t client_ip_subnet_length, + QboneServerControlStream::Handler* handler); + QboneServerSession(const QboneServerSession&) = delete; + QboneServerSession& operator=(const QboneServerSession&) = delete; + ~QboneServerSession() override; + + // Override to create control stream at FORWARD_SECURE encryption level. + void SetDefaultEncryptionLevel(quic::EncryptionLevel level) override; + + virtual bool SendClientRequest(const QboneClientRequest& request); + + void ProcessPacketFromNetwork(absl::string_view packet) override; + void ProcessPacketFromPeer(absl::string_view packet) override; + + // QbonePacketProcessor::OutputInterface implementation. + void SendPacketToClient(absl::string_view packet) override; + void SendPacketToNetwork(absl::string_view packet) override; + + // QbonePacketProcessor::StatsInterface implementation. + void OnPacketForwarded(QbonePacketProcessor::Direction direction) override {} + void OnPacketDroppedSilently( + QbonePacketProcessor::Direction direction) override {} + void OnPacketDroppedWithIcmp( + QbonePacketProcessor::Direction direction) override {} + void OnPacketDroppedWithTcpReset( + QbonePacketProcessor::Direction direction) override {} + void OnPacketDeferred(QbonePacketProcessor::Direction direction) override {} + + protected: + // QboneSessionBase interface implementation. + std::unique_ptr CreateCryptoStream() override; + + // Instantiates QboneServerControlStream. + virtual void CreateControlStream(); + + // Instantiates QboneServerControlStream from the pending stream and returns a + // pointer to it. + QuicStream* CreateControlStreamFromPendingStream(PendingStream* pending); + + // The packet processor. + QbonePacketProcessor processor_; + + // Config for QUIC crypto server stream, used by the server. + const QuicCryptoServerConfig* quic_crypto_server_config_; + + private: + // Used by QUIC crypto server stream to track most recently compressed certs. + QuicCompressedCertsCache* compressed_certs_cache_; + // This helper is needed when create QuicCryptoServerStream. + QboneCryptoServerStreamHelper stream_helper_; + // Passed to the control stream. + QboneServerControlStream::Handler* handler_; + // The unowned control stream. + QboneServerControlStream* control_stream_ = nullptr; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_SERVER_SESSION_H_ diff --git a/quiche/quic/qbone/qbone_session_base.cc b/quiche/quic/qbone/qbone_session_base.cc new file mode 100644 index 000000000000..36517b75f5cf --- /dev/null +++ b/quiche/quic/qbone/qbone_session_base.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_session_base.h" + +#include +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_exported_stats.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_testvalue.h" +#include "quiche/quic/qbone/platform/icmp_packet.h" +#include "quiche/quic/qbone/qbone_constants.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_buffer_allocator.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, qbone_close_ephemeral_frames, true, + "If true, we'll call CloseStream even when we receive ephemeral frames."); + +namespace quic { + +#define ENDPOINT \ + (perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") + +QboneSessionBase::QboneSessionBase( + QuicConnection* connection, Visitor* owner, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QbonePacketWriter* writer) + : QuicSession(connection, owner, config, supported_versions, + /*num_expected_unidirectional_static_streams = */ 0) { + set_writer(writer); + const uint32_t max_streams = + (std::numeric_limits::max() / kMaxAvailableStreamsMultiplier) - + 1; + this->config()->SetMaxBidirectionalStreamsToSend(max_streams); + if (VersionHasIetfQuicFrames(transport_version())) { + this->config()->SetMaxUnidirectionalStreamsToSend(max_streams); + } +} + +QboneSessionBase::~QboneSessionBase() {} + +void QboneSessionBase::Initialize() { + crypto_stream_ = CreateCryptoStream(); + QuicSession::Initialize(); +} + +const QuicCryptoStream* QboneSessionBase::GetCryptoStream() const { + return crypto_stream_.get(); +} + +QuicCryptoStream* QboneSessionBase::GetMutableCryptoStream() { + return crypto_stream_.get(); +} + +QuicStream* QboneSessionBase::CreateOutgoingStream() { + return ActivateDataStream( + CreateDataStream(GetNextOutgoingUnidirectionalStreamId())); +} + +void QboneSessionBase::OnStreamFrame(const QuicStreamFrame& frame) { + if (frame.offset == 0 && frame.fin && frame.data_length > 0) { + ++num_ephemeral_packets_; + ProcessPacketFromPeer( + absl::string_view(frame.data_buffer, frame.data_length)); + flow_controller()->AddBytesConsumed(frame.data_length); + // TODO(b/147817422): Add a counter for how many streams were actually + // closed here. + if (quiche::GetQuicheCommandLineFlag(FLAGS_qbone_close_ephemeral_frames)) { + ResetStream(frame.stream_id, QUIC_STREAM_CANCELLED); + } + return; + } + QuicSession::OnStreamFrame(frame); +} + +void QboneSessionBase::OnMessageReceived(absl::string_view message) { + ++num_message_packets_; + ProcessPacketFromPeer(message); +} + +QuicStream* QboneSessionBase::CreateIncomingStream(QuicStreamId id) { + return ActivateDataStream(CreateDataStream(id)); +} + +QuicStream* QboneSessionBase::CreateIncomingStream(PendingStream* /*pending*/) { + QUICHE_NOTREACHED(); + return nullptr; +} + +bool QboneSessionBase::ShouldKeepConnectionAlive() const { + // QBONE connections stay alive until they're explicitly closed. + return true; +} + +std::unique_ptr QboneSessionBase::CreateDataStream( + QuicStreamId id) { + if (!IsEncryptionEstablished()) { + // Encryption not active so no stream created + return nullptr; + } + + if (IsIncomingStream(id)) { + ++num_streamed_packets_; + return std::make_unique(id, this); + } + + return std::make_unique(id, this); +} + +QuicStream* QboneSessionBase::ActivateDataStream( + std::unique_ptr stream) { + // Transfer ownership of the data stream to the session via ActivateStream(). + QuicStream* raw = stream.get(); + if (stream) { + // Make QuicSession take ownership of the stream. + ActivateStream(std::move(stream)); + } + return raw; +} + +void QboneSessionBase::SendPacketToPeer(absl::string_view packet) { + if (crypto_stream_ == nullptr) { + QUIC_BUG(quic_bug_10987_1) + << "Attempting to send packet before encryption established"; + return; + } + + if (send_packets_as_messages_) { + quiche::QuicheMemSlice slice(quiche::QuicheBuffer::Copy( + connection()->helper()->GetStreamSendBufferAllocator(), packet)); + switch (SendMessage(absl::MakeSpan(&slice, 1), /*flush=*/true).status) { + case MESSAGE_STATUS_SUCCESS: + break; + case MESSAGE_STATUS_TOO_LARGE: { + if (packet.size() < sizeof(ip6_hdr)) { + QUIC_BUG(quic_bug_10987_2) + << "Dropped malformed packet: IPv6 header too short"; + break; + } + auto* header = reinterpret_cast(packet.begin()); + icmp6_hdr icmp_header{}; + icmp_header.icmp6_type = ICMP6_PACKET_TOO_BIG; + icmp_header.icmp6_mtu = + connection()->GetGuaranteedLargestMessagePayload(); + + CreateIcmpPacket(header->ip6_dst, header->ip6_src, icmp_header, packet, + [this](absl::string_view icmp_packet) { + writer_->WritePacketToNetwork(icmp_packet.data(), + icmp_packet.size()); + }); + break; + } + case MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED: + QUIC_BUG(quic_bug_10987_3) + << "MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED"; + break; + case MESSAGE_STATUS_UNSUPPORTED: + QUIC_BUG(quic_bug_10987_4) << "MESSAGE_STATUS_UNSUPPORTED"; + break; + case MESSAGE_STATUS_BLOCKED: + QUIC_BUG(quic_bug_10987_5) << "MESSAGE_STATUS_BLOCKED"; + break; + case MESSAGE_STATUS_INTERNAL_ERROR: + QUIC_BUG(quic_bug_10987_6) << "MESSAGE_STATUS_INTERNAL_ERROR"; + break; + } + return; + } + + // QBONE streams are ephemeral. + QuicStream* stream = CreateOutgoingStream(); + if (!stream) { + QUIC_BUG(quic_bug_10987_7) << "Failed to create an outgoing QBONE stream."; + return; + } + + QboneWriteOnlyStream* qbone_stream = + static_cast(stream); + qbone_stream->WritePacketToQuicStream(packet); +} + +uint64_t QboneSessionBase::GetNumEphemeralPackets() const { + return num_ephemeral_packets_; +} + +uint64_t QboneSessionBase::GetNumStreamedPackets() const { + return num_streamed_packets_; +} + +uint64_t QboneSessionBase::GetNumMessagePackets() const { + return num_message_packets_; +} + +uint64_t QboneSessionBase::GetNumFallbackToStream() const { + return num_fallback_to_stream_; +} + +void QboneSessionBase::set_writer(QbonePacketWriter* writer) { + writer_ = writer; + quic::AdjustTestValue("quic_QbonePacketWriter", &writer_); +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_session_base.h b/quiche/quic/qbone/qbone_session_base.h new file mode 100644 index 000000000000..e643f0d942bc --- /dev/null +++ b/quiche/quic/qbone/qbone_session_base.h @@ -0,0 +1,111 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_SESSION_BASE_H_ +#define QUICHE_QUIC_QBONE_QBONE_SESSION_BASE_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/platform/api/quic_export.h" +#include "quiche/quic/qbone/qbone_packet_writer.h" +#include "quiche/quic/qbone/qbone_stream.h" + +namespace quic { + +class QUIC_EXPORT_PRIVATE QboneSessionBase : public QuicSession { + public: + QboneSessionBase(QuicConnection* connection, Visitor* owner, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QbonePacketWriter* writer); + QboneSessionBase(const QboneSessionBase&) = delete; + QboneSessionBase& operator=(const QboneSessionBase&) = delete; + ~QboneSessionBase() override; + + // Overrides from QuicSession. + // This will ensure that the crypto session is created. + void Initialize() override; + // This will check if the packet is wholly contained. + void OnStreamFrame(const QuicStreamFrame& frame) override; + // Called whenever a MESSAGE frame is received. + void OnMessageReceived(absl::string_view message) override; + + virtual void ProcessPacketFromNetwork(absl::string_view packet) = 0; + virtual void ProcessPacketFromPeer(absl::string_view packet) = 0; + + // Returns the number of QBONE network packets that were received + // that fit into a single QuicStreamFrame and elided the creation of + // a QboneReadOnlyStream. + uint64_t GetNumEphemeralPackets() const; + + // Returns the number of QBONE network packets that were via + // multiple packets, requiring the creation of a QboneReadOnlyStream. + uint64_t GetNumStreamedPackets() const; + + // Returns the number of QBONE network packets that were received using QUIC + // MESSAGE frame. + uint64_t GetNumMessagePackets() const; + + // Returns the number of times sending a MESSAGE frame failed, and the session + // used an ephemeral stream instead. + uint64_t GetNumFallbackToStream() const; + + void set_writer(QbonePacketWriter* writer); + void set_send_packets_as_messages(bool send_packets_as_messages) { + send_packets_as_messages_ = send_packets_as_messages; + } + + protected: + virtual std::unique_ptr CreateCryptoStream() = 0; + + // QuicSession interface implementation. + QuicCryptoStream* GetMutableCryptoStream() override; + const QuicCryptoStream* GetCryptoStream() const override; + QuicStream* CreateIncomingStream(QuicStreamId id) override; + QuicStream* CreateIncomingStream(PendingStream* pending) override; + bool ShouldKeepConnectionAlive() const override; + + bool MaybeIncreaseLargestPeerStreamId(const QuicStreamId stream_id) override { + return true; + } + + QuicStream* CreateOutgoingStream(); + std::unique_ptr CreateDataStream(QuicStreamId id); + // Activates a QuicStream. The session takes ownership of the stream, but + // returns an unowned pointer to the stream for convenience. + QuicStream* ActivateDataStream(std::unique_ptr stream); + + // Accepts a given packet from the network and writes it out + // to the QUIC stream. This will create an ephemeral stream per + // packet. This function will return true if a stream was created + // and the packet sent. It will return false if the stream could not + // be created. + void SendPacketToPeer(absl::string_view packet); + + QbonePacketWriter* writer_; + + // If true, send QUIC DATAGRAM (aka MESSAGE) frames instead of ephemeral + // streams. Note that receiving DATAGRAM frames is always supported. + bool send_packets_as_messages_ = true; + + private: + // Used for the crypto handshake. + std::unique_ptr crypto_stream_; + + // Statistics for the packets received by the session. + uint64_t num_ephemeral_packets_ = 0; + uint64_t num_message_packets_ = 0; + uint64_t num_streamed_packets_ = 0; + + // Number of times the connection has failed to send packets as MESSAGE frame + // and used streams as a fallback. + uint64_t num_fallback_to_stream_ = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_SESSION_BASE_H_ diff --git a/quiche/quic/qbone/qbone_session_test.cc b/quiche/quic/qbone/qbone_session_test.cc new file mode 100644 index 000000000000..3e2e8d11ee50 --- /dev/null +++ b/quiche/quic/qbone/qbone_session_test.cc @@ -0,0 +1,636 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/qbone/platform/icmp_packet.h" +#include "quiche/quic/qbone/qbone_client_session.h" +#include "quiche/quic/qbone/qbone_constants.h" +#include "quiche/quic/qbone/qbone_control_placeholder.pb.h" +#include "quiche/quic/qbone/qbone_packet_processor_test_tools.h" +#include "quiche/quic/qbone/qbone_server_session.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { +namespace { + +using ::testing::_; +using ::testing::Contains; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Invoke; +using ::testing::NiceMock; +using ::testing::Not; + +std::string TestPacketIn(const std::string& body) { + return PrependIPv6HeaderForTest(body, 5); +} + +std::string TestPacketOut(const std::string& body) { + return PrependIPv6HeaderForTest(body, 4); +} + +ParsedQuicVersionVector GetTestParams() { + SetQuicReloadableFlag(quic_disable_version_q046, false); + ParsedQuicVersionVector test_versions; + + // TODO(b/113130636): Make QBONE work with TLS. + for (const auto& version : CurrentSupportedVersionsWithQuicCrypto()) { + // QBONE requires MESSAGE frames + if (!version.SupportsMessageFrames()) { + continue; + } + test_versions.push_back(version); + } + + return test_versions; +} + +// Used by QuicCryptoServerConfig to provide server credentials, passes +// everything through to ProofSourceForTesting if success is true, +// and fails otherwise. +class IndirectionProofSource : public ProofSource { + public: + explicit IndirectionProofSource(bool success) { + if (success) { + proof_source_ = crypto_test_utils::ProofSourceForTesting(); + } + } + + // ProofSource override. + void GetProof(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, const std::string& server_config, + QuicTransportVersion transport_version, + absl::string_view chlo_hash, + std::unique_ptr callback) override { + if (!proof_source_) { + QuicCryptoProof proof; + quiche::QuicheReferenceCountedPointer chain = + GetCertChain(server_address, client_address, hostname, + &proof.cert_matched_sni); + callback->Run(/*ok=*/false, chain, proof, /*details=*/nullptr); + return; + } + proof_source_->GetProof(server_address, client_address, hostname, + server_config, transport_version, chlo_hash, + std::move(callback)); + } + + quiche::QuicheReferenceCountedPointer GetCertChain( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override { + if (!proof_source_) { + return quiche::QuicheReferenceCountedPointer(); + } + return proof_source_->GetCertChain(server_address, client_address, hostname, + cert_matched_sni); + } + + void ComputeTlsSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) override { + if (!proof_source_) { + callback->Run(/*ok=*/true, "Signature", /*details=*/nullptr); + return; + } + proof_source_->ComputeTlsSignature(server_address, client_address, hostname, + signature_algorithm, in, + std::move(callback)); + } + + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override { + if (!proof_source_) { + return {}; + } + return proof_source_->SupportedTlsSignatureAlgorithms(); + } + + TicketCrypter* GetTicketCrypter() override { return nullptr; } + + private: + std::unique_ptr proof_source_; +}; + +// Used by QuicCryptoClientConfig to verify server credentials, passes +// everything through to ProofVerifierForTesting is success is true, +// otherwise returns a canned response of QUIC_FAILURE. +class IndirectionProofVerifier : public ProofVerifier { + public: + explicit IndirectionProofVerifier(bool success) { + if (success) { + proof_verifier_ = crypto_test_utils::ProofVerifierForTesting(); + } + } + + // ProofVerifier override + QuicAsyncStatus VerifyProof( + const std::string& hostname, const uint16_t port, + const std::string& server_config, QuicTransportVersion transport_version, + absl::string_view chlo_hash, const std::vector& certs, + const std::string& cert_sct, const std::string& signature, + const ProofVerifyContext* context, std::string* error_details, + std::unique_ptr* verify_details, + std::unique_ptr callback) override { + if (!proof_verifier_) { + return QUIC_FAILURE; + } + return proof_verifier_->VerifyProof( + hostname, port, server_config, transport_version, chlo_hash, certs, + cert_sct, signature, context, error_details, verify_details, + std::move(callback)); + } + + QuicAsyncStatus VerifyCertChain( + const std::string& hostname, const uint16_t port, + const std::vector& certs, const std::string& ocsp_response, + const std::string& cert_sct, const ProofVerifyContext* context, + std::string* error_details, std::unique_ptr* details, + uint8_t* out_alert, + std::unique_ptr callback) override { + if (!proof_verifier_) { + return QUIC_FAILURE; + } + return proof_verifier_->VerifyCertChain( + hostname, port, certs, ocsp_response, cert_sct, context, error_details, + details, out_alert, std::move(callback)); + } + + std::unique_ptr CreateDefaultContext() override { + if (!proof_verifier_) { + return nullptr; + } + return proof_verifier_->CreateDefaultContext(); + } + + private: + std::unique_ptr proof_verifier_; +}; + +class DataSavingQbonePacketWriter : public QbonePacketWriter { + public: + void WritePacketToNetwork(const char* packet, size_t size) override { + data_.push_back(std::string(packet, size)); + } + + const std::vector& data() { return data_; } + + private: + std::vector data_; +}; + +template +class DataSavingQboneControlHandler : public QboneControlHandler { + public: + void OnControlRequest(const T& request) override { data_.push_back(request); } + + void OnControlError() override { error_ = true; } + + const std::vector& data() { return data_; } + bool error() { return error_; } + + private: + std::vector data_; + bool error_ = false; +}; + +// Single-threaded scheduled task runner based on a MockClock. +// +// Simulates asynchronous execution on a single thread by holding scheduled +// tasks until Run() is called. Performs no synchronization, assumes that +// Schedule() and Run() are called on the same thread. +class FakeTaskRunner { + public: + explicit FakeTaskRunner(MockQuicConnectionHelper* helper) + : tasks_([](const TaskType& l, const TaskType& r) { + // Items at a later time should run after items at an earlier time. + // Priority queue comparisons should return true if l appears after r. + return l->time() > r->time(); + }), + helper_(helper) {} + + // Runs all tasks in time order. Executes tasks scheduled at + // the same in an arbitrary order. + void Run() { + while (!tasks_.empty()) { + tasks_.top()->Run(); + tasks_.pop(); + } + } + + private: + class InnerTask { + public: + InnerTask(std::function task, QuicTime time) + : task_(std::move(task)), time_(time) {} + + void Cancel() { cancelled_ = true; } + + void Run() { + if (!cancelled_) { + task_(); + } + } + + QuicTime time() const { return time_; } + + private: + bool cancelled_ = false; + std::function task_; + QuicTime time_; + }; + + public: + // Schedules a function to run immediately and advances the time. + void Schedule(std::function task) { + tasks_.push(std::shared_ptr( + new InnerTask(std::move(task), helper_->GetClock()->Now()))); + helper_->AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); + } + + private: + using TaskType = std::shared_ptr; + std::priority_queue, + std::function> + tasks_; + MockQuicConnectionHelper* helper_; +}; + +class QboneSessionTest : public QuicTestWithParam { + public: + QboneSessionTest() + : supported_versions_({GetParam()}), + runner_(&helper_), + compressed_certs_cache_(100) {} + + ~QboneSessionTest() override { + delete client_connection_; + delete server_connection_; + } + + const MockClock* GetClock() const { + return static_cast(helper_.GetClock()); + } + + // The parameters are used to control whether the handshake will success or + // not. + void CreateClientAndServerSessions(bool client_handshake_success = true, + bool server_handshake_success = true, + bool send_qbone_alpn = true) { + // Quic crashes if packets are sent at time 0, and the clock defaults to 0. + helper_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1000)); + event_loop_ = GetDefaultEventLoop()->Create(QuicDefaultClock::Get()); + alarm_factory_ = event_loop_->CreateAlarmFactory(); + client_writer_ = std::make_unique(); + server_writer_ = std::make_unique(); + client_handler_ = + std::make_unique>(); + server_handler_ = + std::make_unique>(); + QuicSocketAddress server_address(TestLoopback(), 0); + QuicSocketAddress client_address; + if (server_address.host().address_family() == IpAddressFamily::IP_V4) { + client_address = QuicSocketAddress(QuicIpAddress::Any4(), 0); + } else { + client_address = QuicSocketAddress(QuicIpAddress::Any6(), 0); + } + + { + client_connection_ = new QuicConnection( + TestConnectionId(), client_address, server_address, &helper_, + alarm_factory_.get(), new NiceMock(), true, + Perspective::IS_CLIENT, supported_versions_, + connection_id_generator_); + client_connection_->SetSelfAddress(client_address); + QuicConfig config; + client_crypto_config_ = std::make_unique( + std::make_unique(client_handshake_success)); + if (send_qbone_alpn) { + client_crypto_config_->set_alpn("qbone"); + } + client_peer_ = std::make_unique( + client_connection_, client_crypto_config_.get(), + /*owner=*/nullptr, config, supported_versions_, + QuicServerId("test.example.com", 1234, false), client_writer_.get(), + client_handler_.get()); + } + + { + server_connection_ = new QuicConnection( + TestConnectionId(), server_address, client_address, &helper_, + alarm_factory_.get(), new NiceMock(), true, + Perspective::IS_SERVER, supported_versions_, + connection_id_generator_); + server_connection_->SetSelfAddress(server_address); + QuicConfig config; + server_crypto_config_ = std::make_unique( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + std::make_unique(server_handshake_success), + KeyExchangeSource::Default()); + QuicCryptoServerConfig::ConfigOptions options; + QuicServerConfigProtobuf primary_config = + server_crypto_config_->GenerateConfig(QuicRandom::GetInstance(), + GetClock(), options); + std::unique_ptr message( + server_crypto_config_->AddConfig(primary_config, + GetClock()->WallNow())); + + server_peer_ = std::make_unique( + supported_versions_, server_connection_, nullptr, config, + server_crypto_config_.get(), &compressed_certs_cache_, + server_writer_.get(), TestLoopback6(), TestLoopback6(), 64, + server_handler_.get()); + } + + // Hook everything up! + MockPacketWriter* client_writer = static_cast( + QuicConnectionPeer::GetWriter(client_peer_->connection())); + ON_CALL(*client_writer, WritePacket(_, _, _, _, _)) + .WillByDefault(Invoke([this](const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) { + char* copy = new char[1024 * 1024]; + memcpy(copy, buffer, buf_len); + runner_.Schedule([this, copy, buf_len] { + QuicReceivedPacket packet(copy, buf_len, GetClock()->Now()); + server_peer_->ProcessUdpPacket(server_connection_->self_address(), + client_connection_->self_address(), + packet); + delete[] copy; + }); + return WriteResult(WRITE_STATUS_OK, buf_len); + })); + MockPacketWriter* server_writer = static_cast( + QuicConnectionPeer::GetWriter(server_peer_->connection())); + ON_CALL(*server_writer, WritePacket(_, _, _, _, _)) + .WillByDefault(Invoke([this](const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) { + char* copy = new char[1024 * 1024]; + memcpy(copy, buffer, buf_len); + runner_.Schedule([this, copy, buf_len] { + QuicReceivedPacket packet(copy, buf_len, GetClock()->Now()); + client_peer_->ProcessUdpPacket(client_connection_->self_address(), + server_connection_->self_address(), + packet); + delete[] copy; + }); + return WriteResult(WRITE_STATUS_OK, buf_len); + })); + } + + void StartHandshake() { + server_peer_->Initialize(); + client_peer_->Initialize(); + runner_.Run(); + } + + void ExpectICMPTooBigResponse(const std::vector& written_packets, + const int mtu, const std::string& packet) { + auto* header = reinterpret_cast(packet.data()); + icmp6_hdr icmp_header{}; + icmp_header.icmp6_type = ICMP6_PACKET_TOO_BIG; + icmp_header.icmp6_mtu = mtu; + + std::string expected; + CreateIcmpPacket(header->ip6_dst, header->ip6_src, icmp_header, packet, + [&expected](absl::string_view icmp_packet) { + expected = std::string(icmp_packet); + }); + + EXPECT_THAT(written_packets, Contains(expected)); + } + + // Test handshake establishment and sending/receiving of data for two + // directions. + void TestStreamConnection(bool use_messages) { + ASSERT_TRUE(server_peer_->OneRttKeysAvailable()); + ASSERT_TRUE(client_peer_->OneRttKeysAvailable()); + ASSERT_TRUE(server_peer_->IsEncryptionEstablished()); + ASSERT_TRUE(client_peer_->IsEncryptionEstablished()); + + // Create an outgoing stream from the client and say hello. + QUIC_LOG(INFO) << "Sending client -> server"; + client_peer_->ProcessPacketFromNetwork(TestPacketIn("hello")); + client_peer_->ProcessPacketFromNetwork(TestPacketIn("world")); + runner_.Run(); + // The server should see the data, the client hasn't received + // anything yet. + EXPECT_THAT(server_writer_->data(), + ElementsAre(TestPacketOut("hello"), TestPacketOut("world"))); + EXPECT_TRUE(client_writer_->data().empty()); + EXPECT_EQ(0u, server_peer_->GetNumActiveStreams()); + EXPECT_EQ(0u, client_peer_->GetNumActiveStreams()); + + // Let's pretend some service responds. + QUIC_LOG(INFO) << "Sending server -> client"; + server_peer_->ProcessPacketFromNetwork(TestPacketIn("Hello Again")); + server_peer_->ProcessPacketFromNetwork(TestPacketIn("Again")); + runner_.Run(); + EXPECT_THAT(server_writer_->data(), + ElementsAre(TestPacketOut("hello"), TestPacketOut("world"))); + EXPECT_THAT( + client_writer_->data(), + ElementsAre(TestPacketOut("Hello Again"), TestPacketOut("Again"))); + EXPECT_EQ(0u, server_peer_->GetNumActiveStreams()); + EXPECT_EQ(0u, client_peer_->GetNumActiveStreams()); + + // Try to send long payloads that are larger than the QUIC MTU but + // smaller than the QBONE max size. + // This should trigger the non-ephemeral stream code path. + std::string long_data( + QboneConstants::kMaxQbonePacketBytes - sizeof(ip6_hdr) - 1, 'A'); + QUIC_LOG(INFO) << "Sending server -> client long data"; + server_peer_->ProcessPacketFromNetwork(TestPacketIn(long_data)); + runner_.Run(); + if (use_messages) { + ExpectICMPTooBigResponse( + server_writer_->data(), + server_peer_->connection()->GetGuaranteedLargestMessagePayload(), + TestPacketOut(long_data)); + } else { + EXPECT_THAT(client_writer_->data(), Contains(TestPacketOut(long_data))); + } + EXPECT_THAT(server_writer_->data(), + Not(Contains(TestPacketOut(long_data)))); + EXPECT_EQ(0u, server_peer_->GetNumActiveStreams()); + EXPECT_EQ(0u, client_peer_->GetNumActiveStreams()); + + QUIC_LOG(INFO) << "Sending client -> server long data"; + client_peer_->ProcessPacketFromNetwork(TestPacketIn(long_data)); + runner_.Run(); + if (use_messages) { + ExpectICMPTooBigResponse( + client_writer_->data(), + client_peer_->connection()->GetGuaranteedLargestMessagePayload(), + TestPacketIn(long_data)); + } else { + EXPECT_THAT(server_writer_->data(), Contains(TestPacketOut(long_data))); + } + EXPECT_FALSE(client_peer_->EarlyDataAccepted()); + EXPECT_FALSE(client_peer_->ReceivedInchoateReject()); + EXPECT_THAT(client_peer_->GetNumReceivedServerConfigUpdates(), Eq(0)); + + if (!use_messages) { + EXPECT_THAT(client_peer_->GetNumStreamedPackets(), Eq(1)); + EXPECT_THAT(server_peer_->GetNumStreamedPackets(), Eq(1)); + } + + if (use_messages) { + EXPECT_THAT(client_peer_->GetNumEphemeralPackets(), Eq(0)); + EXPECT_THAT(server_peer_->GetNumEphemeralPackets(), Eq(0)); + EXPECT_THAT(client_peer_->GetNumMessagePackets(), Eq(2)); + EXPECT_THAT(server_peer_->GetNumMessagePackets(), Eq(2)); + } else { + EXPECT_THAT(client_peer_->GetNumEphemeralPackets(), Eq(2)); + EXPECT_THAT(server_peer_->GetNumEphemeralPackets(), Eq(2)); + EXPECT_THAT(client_peer_->GetNumMessagePackets(), Eq(0)); + EXPECT_THAT(server_peer_->GetNumMessagePackets(), Eq(0)); + } + + // All streams are ephemeral and should be gone. + EXPECT_EQ(0u, server_peer_->GetNumActiveStreams()); + EXPECT_EQ(0u, client_peer_->GetNumActiveStreams()); + } + + // Test that client and server are not connected after handshake failure. + void TestDisconnectAfterFailedHandshake() { + EXPECT_FALSE(client_peer_->IsEncryptionEstablished()); + EXPECT_FALSE(client_peer_->OneRttKeysAvailable()); + + EXPECT_FALSE(server_peer_->IsEncryptionEstablished()); + EXPECT_FALSE(server_peer_->OneRttKeysAvailable()); + } + + protected: + const ParsedQuicVersionVector supported_versions_; + std::unique_ptr event_loop_; + std::unique_ptr alarm_factory_; + FakeTaskRunner runner_; + MockQuicConnectionHelper helper_; + QuicConnection* client_connection_; + QuicConnection* server_connection_; + QuicCompressedCertsCache compressed_certs_cache_; + + std::unique_ptr client_crypto_config_; + std::unique_ptr server_crypto_config_; + std::unique_ptr client_writer_; + std::unique_ptr server_writer_; + std::unique_ptr> + client_handler_; + std::unique_ptr> + server_handler_; + + std::unique_ptr server_peer_; + std::unique_ptr client_peer_; + MockConnectionIdGenerator connection_id_generator_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QboneSessionTest, + ::testing::ValuesIn(GetTestParams()), + ::testing::PrintToStringParamName()); + +TEST_P(QboneSessionTest, StreamConnection) { + CreateClientAndServerSessions(); + client_peer_->set_send_packets_as_messages(false); + server_peer_->set_send_packets_as_messages(false); + StartHandshake(); + TestStreamConnection(false); +} + +TEST_P(QboneSessionTest, Messages) { + CreateClientAndServerSessions(); + client_peer_->set_send_packets_as_messages(true); + server_peer_->set_send_packets_as_messages(true); + StartHandshake(); + TestStreamConnection(true); +} + +TEST_P(QboneSessionTest, ClientRejection) { + CreateClientAndServerSessions(false /*client_handshake_success*/, + true /*server_handshake_success*/, + true /*send_qbone_alpn*/); + StartHandshake(); + TestDisconnectAfterFailedHandshake(); +} + +TEST_P(QboneSessionTest, BadAlpn) { + CreateClientAndServerSessions(true /*client_handshake_success*/, + true /*server_handshake_success*/, + false /*send_qbone_alpn*/); + StartHandshake(); + TestDisconnectAfterFailedHandshake(); +} + +TEST_P(QboneSessionTest, ServerRejection) { + CreateClientAndServerSessions(true /*client_handshake_success*/, + false /*server_handshake_success*/, + true /*send_qbone_alpn*/); + StartHandshake(); + TestDisconnectAfterFailedHandshake(); +} + +// Test that data streams are not created before handshake. +TEST_P(QboneSessionTest, CannotCreateDataStreamBeforeHandshake) { + CreateClientAndServerSessions(); + EXPECT_QUIC_BUG(client_peer_->ProcessPacketFromNetwork(TestPacketIn("hello")), + "Attempting to send packet before encryption established"); + EXPECT_QUIC_BUG(server_peer_->ProcessPacketFromNetwork(TestPacketIn("hello")), + "Attempting to send packet before encryption established"); + EXPECT_EQ(0u, server_peer_->GetNumActiveStreams()); + EXPECT_EQ(0u, client_peer_->GetNumActiveStreams()); +} + +TEST_P(QboneSessionTest, ControlRequests) { + CreateClientAndServerSessions(); + StartHandshake(); + EXPECT_TRUE(client_handler_->data().empty()); + EXPECT_FALSE(client_handler_->error()); + EXPECT_TRUE(server_handler_->data().empty()); + EXPECT_FALSE(server_handler_->error()); + + QboneClientRequest client_request; + client_request.SetExtension(client_placeholder, "hello from the server"); + EXPECT_TRUE(server_peer_->SendClientRequest(client_request)); + runner_.Run(); + ASSERT_FALSE(client_handler_->data().empty()); + EXPECT_THAT(client_handler_->data()[0].GetExtension(client_placeholder), + Eq("hello from the server")); + EXPECT_FALSE(client_handler_->error()); + + QboneServerRequest server_request; + server_request.SetExtension(server_placeholder, "hello from the client"); + EXPECT_TRUE(client_peer_->SendServerRequest(server_request)); + runner_.Run(); + ASSERT_FALSE(server_handler_->data().empty()); + EXPECT_THAT(server_handler_->data()[0].GetExtension(server_placeholder), + Eq("hello from the client")); + EXPECT_FALSE(server_handler_->error()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/qbone/qbone_stream.cc b/quiche/quic/qbone/qbone_stream.cc new file mode 100644 index 000000000000..6465d5953ce7 --- /dev/null +++ b/quiche/quic/qbone/qbone_stream.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_stream.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/qbone/qbone_constants.h" +#include "quiche/quic/qbone/qbone_session_base.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int, qbone_stream_ttl_secs, 3, + "The QBONE Stream TTL in seconds."); + +namespace quic { + +QboneWriteOnlyStream::QboneWriteOnlyStream(QuicStreamId id, + QuicSession* session) + : QuicStream(id, session, /*is_static=*/false, WRITE_UNIDIRECTIONAL) { + // QBONE uses a LIFO queue to try to always make progress. An individual + // packet may persist for upto to qbone_stream_ttl_secs seconds in memory. + MaybeSetTtl(QuicTime::Delta::FromSeconds( + quiche::GetQuicheCommandLineFlag(FLAGS_qbone_stream_ttl_secs))); +} + +void QboneWriteOnlyStream::WritePacketToQuicStream(absl::string_view packet) { + // Streams are one way and ephemeral. This function should only be + // called once. + WriteOrBufferData(packet, /* fin= */ true, nullptr); +} + +QboneReadOnlyStream::QboneReadOnlyStream(QuicStreamId id, + QboneSessionBase* session) + : QuicStream(id, session, + /*is_static=*/false, READ_UNIDIRECTIONAL), + session_(session) { + // QBONE uses a LIFO queue to try to always make progress. An individual + // packet may persist for upto to qbone_stream_ttl_secs seconds in memory. + MaybeSetTtl(QuicTime::Delta::FromSeconds( + quiche::GetQuicheCommandLineFlag(FLAGS_qbone_stream_ttl_secs))); +} + +void QboneReadOnlyStream::OnDataAvailable() { + // Read in data and buffer it, attempt to frame to see if there's a packet. + sequencer()->Read(&buffer_); + if (sequencer()->IsClosed()) { + session_->ProcessPacketFromPeer(buffer_); + OnFinRead(); + return; + } + if (buffer_.size() > QboneConstants::kMaxQbonePacketBytes) { + if (!rst_sent()) { + Reset(QUIC_BAD_APPLICATION_PAYLOAD); + } + StopReading(); + } +} + +} // namespace quic diff --git a/quiche/quic/qbone/qbone_stream.h b/quiche/quic/qbone/qbone_stream.h new file mode 100644 index 000000000000..bc5551a6b633 --- /dev/null +++ b/quiche/quic/qbone/qbone_stream.h @@ -0,0 +1,56 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_QBONE_QBONE_STREAM_H_ +#define QUICHE_QUIC_QBONE_QBONE_STREAM_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +class QboneSessionBase; + +// QboneWriteOnlyStream is responsible for sending data for a single +// packet to the other side. +// Note that the stream will be created HalfClosed (reads will be closed). +class QUIC_EXPORT_PRIVATE QboneWriteOnlyStream : public QuicStream { + public: + QboneWriteOnlyStream(QuicStreamId id, QuicSession* session); + + // QuicStream implementation. QBONE writers are ephemeral and don't + // read any data. + void OnDataAvailable() override {} + + // Write a network packet over the quic stream. + void WritePacketToQuicStream(absl::string_view packet); +}; + +// QboneReadOnlyStream will be used if we find an incoming stream that +// isn't fully contained. It will buffer the data when available and +// attempt to parse it as a packet to send to the network when a FIN +// is found. +// Note that the stream will be created HalfClosed (writes will be closed). +class QUIC_EXPORT_PRIVATE QboneReadOnlyStream : public QuicStream { + public: + QboneReadOnlyStream(QuicStreamId id, QboneSessionBase* session); + + ~QboneReadOnlyStream() override = default; + + // QuicStream overrides. + // OnDataAvailable is called when there is data in the quic stream buffer. + // This will copy the buffer locally and attempt to parse it to write out + // packets to the network. + void OnDataAvailable() override; + + private: + std::string buffer_; + QboneSessionBase* session_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_QBONE_QBONE_STREAM_H_ diff --git a/quiche/quic/qbone/qbone_stream_test.cc b/quiche/quic/qbone/qbone_stream_test.cc new file mode 100644 index 000000000000..b7cc198bf1c0 --- /dev/null +++ b/quiche/quic/qbone/qbone_stream_test.cc @@ -0,0 +1,262 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/qbone/qbone_stream.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/qbone/qbone_constants.h" +#include "quiche/quic/qbone/qbone_session_base.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { + +namespace { + +using ::testing::_; +using ::testing::StrictMock; + +// MockQuicSession that does not create streams and writes data from +// QuicStream to a string. +class MockQuicSession : public QboneSessionBase { + public: + MockQuicSession(QuicConnection* connection, const QuicConfig& config) + : QboneSessionBase(connection, nullptr /*visitor*/, config, + CurrentSupportedVersions(), nullptr /*writer*/) {} + + ~MockQuicSession() override {} + + // Writes outgoing data from QuicStream to a string. + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, + TransmissionType type, + EncryptionLevel level) override { + if (!writable_) { + return QuicConsumedData(0, false); + } + + return QuicConsumedData(write_length, state != StreamSendingState::NO_FIN); + } + + QboneReadOnlyStream* CreateIncomingStream(QuicStreamId id) override { + return nullptr; + } + + // Called by QuicStream when they want to close stream. + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written), + (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); + + // Sets whether data is written to buffer, or else if this is write blocked. + void set_writable(bool writable) { writable_ = writable; } + + // Tracks whether the stream is write blocked and its priority. + void RegisterReliableStream(QuicStreamId stream_id) { + // The priority effectively does not matter. Put all streams on the same + // priority. + write_blocked_streams()->RegisterStream( + stream_id, + /* is_static_stream = */ false, + QuicStreamPriority::Default(priority_type())); + } + + // The session take ownership of the stream. + void ActivateReliableStream(std::unique_ptr stream) { + ActivateStream(std::move(stream)); + } + + std::unique_ptr CreateCryptoStream() override { + return std::make_unique(this); + } + + MOCK_METHOD(void, ProcessPacketFromPeer, (absl::string_view), (override)); + MOCK_METHOD(void, ProcessPacketFromNetwork, (absl::string_view), (override)); + + private: + // Whether data is written to write_buffer_. + bool writable_ = true; +}; + +// Packet writer that does nothing. This is required for QuicConnection but +// isn't used for writing data. +class DummyPacketWriter : public QuicPacketWriter { + public: + DummyPacketWriter() {} + + // QuicPacketWriter overrides. + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override { + return WriteResult(WRITE_STATUS_ERROR, 0); + } + + bool IsWriteBlocked() const override { return false; }; + + void SetWritable() override {} + + absl::optional MessageTooBigErrorCode() const override { + return absl::nullopt; + } + + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& peer_address) const override { + return 0; + } + + bool SupportsReleaseTime() const override { return false; } + + bool IsBatchMode() const override { return false; } + + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address) override { + return {nullptr, nullptr}; + } + + WriteResult Flush() override { return WriteResult(WRITE_STATUS_OK, 0); } +}; + +class QboneReadOnlyStreamTest : public ::testing::Test, + public QuicConnectionHelperInterface { + public: + void CreateReliableQuicStream() { + // Arbitrary values for QuicConnection. + Perspective perspective = Perspective::IS_SERVER; + bool owns_writer = true; + + alarm_factory_ = std::make_unique(); + + connection_.reset(new QuicConnection( + test::TestConnectionId(0), QuicSocketAddress(TestLoopback(), 0), + QuicSocketAddress(TestLoopback(), 0), + this /*QuicConnectionHelperInterface*/, alarm_factory_.get(), + new DummyPacketWriter(), owns_writer, perspective, + ParsedVersionOfIndex(CurrentSupportedVersions(), 0), + connection_id_generator_)); + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + session_ = std::make_unique>(connection_.get(), + QuicConfig()); + session_->Initialize(); + stream_ = new QboneReadOnlyStream(kStreamId, session_.get()); + session_->ActivateReliableStream( + std::unique_ptr(stream_)); + } + + ~QboneReadOnlyStreamTest() override {} + + const QuicClock* GetClock() const override { return &clock_; } + + QuicRandom* GetRandomGenerator() override { + return QuicRandom::GetInstance(); + } + + quiche::QuicheBufferAllocator* GetStreamSendBufferAllocator() override { + return &buffer_allocator_; + } + + protected: + // The QuicSession will take the ownership. + QboneReadOnlyStream* stream_; + std::unique_ptr> session_; + std::unique_ptr alarm_factory_; + std::unique_ptr connection_; + // Used to implement the QuicConnectionHelperInterface. + quiche::SimpleBufferAllocator buffer_allocator_; + MockClock clock_; + const QuicStreamId kStreamId = QuicUtils::GetFirstUnidirectionalStreamId( + CurrentSupportedVersions()[0].transport_version, Perspective::IS_CLIENT); + quic::test::MockConnectionIdGenerator connection_id_generator_; +}; + +// Read an entire string. +TEST_F(QboneReadOnlyStreamTest, ReadDataWhole) { + std::string packet = "Stuff"; + CreateReliableQuicStream(); + QuicStreamFrame frame(kStreamId, true, 0, packet); + EXPECT_CALL(*session_, ProcessPacketFromPeer("Stuff")); + stream_->OnStreamFrame(frame); +} + +// Test buffering. +TEST_F(QboneReadOnlyStreamTest, ReadBuffered) { + CreateReliableQuicStream(); + std::string packet = "Stuf"; + { + QuicStreamFrame frame(kStreamId, false, 0, packet); + stream_->OnStreamFrame(frame); + } + // We didn't write 5 bytes yet... + + packet = "f"; + EXPECT_CALL(*session_, ProcessPacketFromPeer("Stuff")); + { + QuicStreamFrame frame(kStreamId, true, 4, packet); + stream_->OnStreamFrame(frame); + } +} + +TEST_F(QboneReadOnlyStreamTest, ReadOutOfOrder) { + CreateReliableQuicStream(); + std::string packet = "f"; + { + QuicStreamFrame frame(kStreamId, true, 4, packet); + stream_->OnStreamFrame(frame); + } + + packet = "S"; + { + QuicStreamFrame frame(kStreamId, false, 0, packet); + stream_->OnStreamFrame(frame); + } + + packet = "tuf"; + EXPECT_CALL(*session_, ProcessPacketFromPeer("Stuff")); + { + QuicStreamFrame frame(kStreamId, false, 1, packet); + stream_->OnStreamFrame(frame); + } +} + +// Test buffering too many bytes. +TEST_F(QboneReadOnlyStreamTest, ReadBufferedTooLarge) { + CreateReliableQuicStream(); + std::string packet = "0123456789"; + int iterations = (QboneConstants::kMaxQbonePacketBytes / packet.size()) + 2; + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + kStreamId, QuicResetStreamError::FromInternal( + QUIC_BAD_APPLICATION_PAYLOAD))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + kStreamId, + QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), _)); + for (int i = 0; i < iterations; ++i) { + QuicStreamFrame frame(kStreamId, i == (iterations - 1), i * packet.size(), + packet); + if (!stream_->reading_stopped()) { + stream_->OnStreamFrame(frame); + } + } + // We should have nothing written to the network and the stream + // should have stopped reading. + EXPECT_TRUE(stream_->reading_stopped()); +} + +} // namespace + +} // namespace quic diff --git a/quiche/quic/test_tools/bad_packet_writer.cc b/quiche/quic/test_tools/bad_packet_writer.cc new file mode 100644 index 000000000000..1e146c4376ff --- /dev/null +++ b/quiche/quic/test_tools/bad_packet_writer.cc @@ -0,0 +1,35 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/bad_packet_writer.h" + +namespace quic { +namespace test { + +BadPacketWriter::BadPacketWriter(size_t packet_causing_write_error, + int error_code) + : packet_causing_write_error_(packet_causing_write_error), + error_code_(error_code) {} + +BadPacketWriter::~BadPacketWriter() {} + +WriteResult BadPacketWriter::WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) { + if (error_code_ == 0 || packet_causing_write_error_ > 0) { + if (packet_causing_write_error_ > 0) { + --packet_causing_write_error_; + } + return QuicPacketWriterWrapper::WritePacket(buffer, buf_len, self_address, + peer_address, options); + } + // It's time to cause write error. + int error_code = error_code_; + error_code_ = 0; + return WriteResult(WRITE_STATUS_ERROR, error_code); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/bad_packet_writer.h b/quiche/quic/test_tools/bad_packet_writer.h new file mode 100644 index 000000000000..bcf12f5f4dbf --- /dev/null +++ b/quiche/quic/test_tools/bad_packet_writer.h @@ -0,0 +1,35 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_BAD_PACKET_WRITER_H_ +#define QUICHE_QUIC_TEST_TOOLS_BAD_PACKET_WRITER_H_ + +#include "quiche/quic/core/quic_packet_writer_wrapper.h" + +namespace quic { + +namespace test { +// This packet writer allows causing packet write error with specified error +// code when writing a particular packet. +class BadPacketWriter : public QuicPacketWriterWrapper { + public: + BadPacketWriter(size_t packet_causing_write_error, int error_code); + + ~BadPacketWriter() override; + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + private: + size_t packet_causing_write_error_; + int error_code_; +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_BAD_PACKET_WRITER_H_ diff --git a/quiche/quic/test_tools/crypto_test_utils.cc b/quiche/quic/test_tools/crypto_test_utils.cc new file mode 100644 index 000000000000..4df4f8f3cb2a --- /dev/null +++ b/quiche/quic/test_tools/crypto_test_utils.cc @@ -0,0 +1,979 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/crypto_test_utils.h" + +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "openssl/bn.h" +#include "openssl/ec.h" +#include "openssl/ecdsa.h" +#include "openssl/nid.h" +#include "openssl/sha.h" +#include "quiche/quic/core/crypto/certificate_view.h" +#include "quiche/quic/core/crypto/channel_id.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/proof_source_x509.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_crypto_client_stream.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_hostname_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_quic_framer.h" +#include "quiche/quic/test_tools/test_certificates.h" +#include "quiche/common/test_tools/quiche_test_utils.h" + +namespace quic { +namespace test { + +namespace crypto_test_utils { + +namespace { + +using testing::_; + +// CryptoFramerVisitor is a framer visitor that records handshake messages. +class CryptoFramerVisitor : public CryptoFramerVisitorInterface { + public: + CryptoFramerVisitor() : error_(false) {} + + void OnError(CryptoFramer* /*framer*/) override { error_ = true; } + + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override { + messages_.push_back(message); + } + + bool error() const { return error_; } + + const std::vector& messages() const { + return messages_; + } + + private: + bool error_; + std::vector messages_; +}; + +// HexChar parses |c| as a hex character. If valid, it sets |*value| to the +// value of the hex character and returns true. Otherwise it returns false. +bool HexChar(char c, uint8_t* value) { + if (c >= '0' && c <= '9') { + *value = c - '0'; + return true; + } + if (c >= 'a' && c <= 'f') { + *value = c - 'a' + 10; + return true; + } + if (c >= 'A' && c <= 'F') { + *value = c - 'A' + 10; + return true; + } + return false; +} + +} // anonymous namespace + +FakeClientOptions::FakeClientOptions() {} + +FakeClientOptions::~FakeClientOptions() {} + +namespace { +// This class is used by GenerateFullCHLO() to extract SCID and STK from +// REJ and to construct a full CHLO with these fields and given inchoate +// CHLO. +class FullChloGenerator { + public: + FullChloGenerator( + QuicCryptoServerConfig* crypto_config, QuicSocketAddress server_addr, + QuicSocketAddress client_addr, const QuicClock* clock, + ParsedQuicVersion version, + quiche::QuicheReferenceCountedPointer + signed_config, + QuicCompressedCertsCache* compressed_certs_cache, + CryptoHandshakeMessage* out) + : crypto_config_(crypto_config), + server_addr_(server_addr), + client_addr_(client_addr), + clock_(clock), + version_(version), + signed_config_(signed_config), + compressed_certs_cache_(compressed_certs_cache), + out_(out), + params_(new QuicCryptoNegotiatedParameters) {} + + class ValidateClientHelloCallback : public ValidateClientHelloResultCallback { + public: + explicit ValidateClientHelloCallback(FullChloGenerator* generator) + : generator_(generator) {} + void Run(quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr /* details */) override { + generator_->ValidateClientHelloDone(std::move(result)); + } + + private: + FullChloGenerator* generator_; + }; + + std::unique_ptr + GetValidateClientHelloCallback() { + return std::make_unique(this); + } + + private: + void ValidateClientHelloDone(quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result) { + result_ = result; + crypto_config_->ProcessClientHello( + result_, /*reject_only=*/false, TestConnectionId(1), server_addr_, + client_addr_, version_, {version_}, clock_, QuicRandom::GetInstance(), + compressed_certs_cache_, params_, signed_config_, + /*total_framing_overhead=*/50, kDefaultMaxPacketSize, + GetProcessClientHelloCallback()); + } + + class ProcessClientHelloCallback : public ProcessClientHelloResultCallback { + public: + explicit ProcessClientHelloCallback(FullChloGenerator* generator) + : generator_(generator) {} + void Run(QuicErrorCode error, const std::string& error_details, + std::unique_ptr message, + std::unique_ptr /*diversification_nonce*/, + std::unique_ptr /*proof_source_details*/) + override { + ASSERT_TRUE(message) << QuicErrorCodeToString(error) << " " + << error_details; + generator_->ProcessClientHelloDone(std::move(message)); + } + + private: + FullChloGenerator* generator_; + }; + + std::unique_ptr GetProcessClientHelloCallback() { + return std::make_unique(this); + } + + void ProcessClientHelloDone(std::unique_ptr rej) { + // Verify output is a REJ. + EXPECT_THAT(rej->tag(), testing::Eq(kREJ)); + + QUIC_VLOG(1) << "Extract valid STK and SCID from\n" << rej->DebugString(); + absl::string_view srct; + ASSERT_TRUE(rej->GetStringPiece(kSourceAddressTokenTag, &srct)); + + absl::string_view scfg; + ASSERT_TRUE(rej->GetStringPiece(kSCFG, &scfg)); + std::unique_ptr server_config( + CryptoFramer::ParseMessage(scfg)); + + absl::string_view scid; + ASSERT_TRUE(server_config->GetStringPiece(kSCID, &scid)); + + *out_ = result_->client_hello; + out_->SetStringPiece(kSCID, scid); + out_->SetStringPiece(kSourceAddressTokenTag, srct); + uint64_t xlct = LeafCertHashForTesting(); + out_->SetValue(kXLCT, xlct); + } + + protected: + QuicCryptoServerConfig* crypto_config_; + QuicSocketAddress server_addr_; + QuicSocketAddress client_addr_; + const QuicClock* clock_; + ParsedQuicVersion version_; + quiche::QuicheReferenceCountedPointer signed_config_; + QuicCompressedCertsCache* compressed_certs_cache_; + CryptoHandshakeMessage* out_; + + quiche::QuicheReferenceCountedPointer params_; + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result_; +}; + +} // namespace + +std::unique_ptr CryptoServerConfigForTesting() { + return std::make_unique( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + ProofSourceForTesting(), KeyExchangeSource::Default()); +} + +int HandshakeWithFakeServer(QuicConfig* server_quic_config, + QuicCryptoServerConfig* crypto_config, + MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory, + PacketSavingConnection* client_conn, + QuicCryptoClientStreamBase* client, + std::string alpn) { + auto* server_conn = new testing::NiceMock( + helper, alarm_factory, Perspective::IS_SERVER, + ParsedVersionOfIndex(client_conn->supported_versions(), 0)); + + QuicCompressedCertsCache compressed_certs_cache( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize); + SetupCryptoServerConfigForTest( + server_conn->clock(), server_conn->random_generator(), crypto_config); + + TestQuicSpdyServerSession server_session( + server_conn, *server_quic_config, client_conn->supported_versions(), + crypto_config, &compressed_certs_cache); + // Call SetServerApplicationStateForResumption so that the fake server + // supports 0-RTT in TLS. + server_session.Initialize(); + server_session.GetMutableCryptoStream() + ->SetServerApplicationStateForResumption( + std::make_unique()); + EXPECT_CALL(*server_session.helper(), + CanAcceptClientHello(testing::_, testing::_, testing::_, + testing::_, testing::_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*server_conn, OnCanWrite()).Times(testing::AnyNumber()); + EXPECT_CALL(*client_conn, OnCanWrite()).Times(testing::AnyNumber()); + EXPECT_CALL(*server_conn, SendCryptoData(_, _, _)) + .Times(testing::AnyNumber()); + EXPECT_CALL(server_session, SelectAlpn(_)) + .WillRepeatedly([alpn](const std::vector& alpns) { + return std::find(alpns.cbegin(), alpns.cend(), alpn); + }); + + // The client's handshake must have been started already. + QUICHE_CHECK_NE(0u, client_conn->encrypted_packets_.size()); + + CommunicateHandshakeMessages(client_conn, client, server_conn, + server_session.GetMutableCryptoStream()); + if (client_conn->connected() && server_conn->connected()) { + CompareClientAndServerKeys(client, server_session.GetMutableCryptoStream()); + } + + return client->num_sent_client_hellos(); +} + +int HandshakeWithFakeClient(MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory, + PacketSavingConnection* server_conn, + QuicCryptoServerStreamBase* server, + const QuicServerId& server_id, + const FakeClientOptions& options, + std::string alpn) { + // This function does not do version negotiation; read the supported versions + // directly from the server connection instead. + ParsedQuicVersionVector supported_versions = + server_conn->supported_versions(); + if (options.only_tls_versions) { + supported_versions.erase( + std::remove_if(supported_versions.begin(), supported_versions.end(), + [](const ParsedQuicVersion& version) { + return version.handshake_protocol != PROTOCOL_TLS1_3; + }), + supported_versions.end()); + QUICHE_CHECK(!options.only_quic_crypto_versions); + } else if (options.only_quic_crypto_versions) { + supported_versions.erase( + std::remove_if(supported_versions.begin(), supported_versions.end(), + [](const ParsedQuicVersion& version) { + return version.handshake_protocol != + PROTOCOL_QUIC_CRYPTO; + }), + supported_versions.end()); + } + PacketSavingConnection* client_conn = new PacketSavingConnection( + helper, alarm_factory, Perspective::IS_CLIENT, supported_versions); + // Advance the time, because timers do not like uninitialized times. + client_conn->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + + QuicCryptoClientConfig crypto_config(ProofVerifierForTesting()); + TestQuicSpdyClientSession client_session(client_conn, DefaultQuicConfig(), + supported_versions, server_id, + &crypto_config); + + EXPECT_CALL(client_session, OnProofValid(testing::_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(client_session, OnProofVerifyDetailsAvailable(testing::_)) + .Times(testing::AnyNumber()); + EXPECT_CALL(*client_conn, OnCanWrite()).Times(testing::AnyNumber()); + if (!alpn.empty()) { + EXPECT_CALL(client_session, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector({alpn}))); + } else { + EXPECT_CALL(client_session, GetAlpnsToOffer()) + .WillRepeatedly(testing::Return(std::vector( + {AlpnForVersion(client_conn->version())}))); + } + client_session.GetMutableCryptoStream()->CryptoConnect(); + QUICHE_CHECK_EQ(1u, client_conn->encrypted_packets_.size()); + + CommunicateHandshakeMessages(client_conn, + client_session.GetMutableCryptoStream(), + server_conn, server); + + if (server->one_rtt_keys_available() && server->encryption_established()) { + CompareClientAndServerKeys(client_session.GetMutableCryptoStream(), server); + } + + return client_session.GetCryptoStream()->num_sent_client_hellos(); +} + +void SetupCryptoServerConfigForTest(const QuicClock* clock, QuicRandom* rand, + QuicCryptoServerConfig* crypto_config) { + QuicCryptoServerConfig::ConfigOptions options; + options.channel_id_enabled = true; + std::unique_ptr scfg = + crypto_config->AddDefaultConfig(rand, clock, options); +} + +void SendHandshakeMessageToStream(QuicCryptoStream* stream, + const CryptoHandshakeMessage& message, + Perspective /*perspective*/) { + const QuicData& data = message.GetSerialized(); + QuicSession* session = QuicStreamPeer::session(stream); + if (!QuicVersionUsesCryptoFrames(session->transport_version())) { + QuicStreamFrame frame( + QuicUtils::GetCryptoStreamId(session->transport_version()), false, + stream->crypto_bytes_read(), data.AsStringPiece()); + stream->OnStreamFrame(frame); + } else { + EncryptionLevel level = session->connection()->last_decrypted_level(); + QuicCryptoFrame frame(level, stream->BytesReadOnLevel(level), + data.AsStringPiece()); + stream->OnCryptoFrame(frame); + } +} + +void CommunicateHandshakeMessages(PacketSavingConnection* client_conn, + QuicCryptoStream* client, + PacketSavingConnection* server_conn, + QuicCryptoStream* server) { + size_t client_i = 0, server_i = 0; + while (client_conn->connected() && server_conn->connected() && + (!client->one_rtt_keys_available() || + !server->one_rtt_keys_available())) { + ASSERT_GT(client_conn->encrypted_packets_.size(), client_i); + QUIC_LOG(INFO) << "Processing " + << client_conn->encrypted_packets_.size() - client_i + << " packets client->server"; + MovePackets(client_conn, &client_i, server, server_conn, + Perspective::IS_SERVER, /*process_stream_data=*/false); + + if (client->one_rtt_keys_available() && server->one_rtt_keys_available() && + server_conn->encrypted_packets_.size() == server_i) { + break; + } + ASSERT_GT(server_conn->encrypted_packets_.size(), server_i); + QUIC_LOG(INFO) << "Processing " + << server_conn->encrypted_packets_.size() - server_i + << " packets server->client"; + MovePackets(server_conn, &server_i, client, client_conn, + Perspective::IS_CLIENT, /*process_stream_data=*/false); + } +} + +bool CommunicateHandshakeMessagesUntil(PacketSavingConnection* client_conn, + QuicCryptoStream* client, + std::function client_condition, + PacketSavingConnection* server_conn, + QuicCryptoStream* server, + std::function server_condition, + bool process_stream_data) { + size_t client_next_packet_to_deliver = + client_conn->number_of_packets_delivered_; + size_t server_next_packet_to_deliver = + server_conn->number_of_packets_delivered_; + while ( + client_conn->connected() && server_conn->connected() && + (!client_condition() || !server_condition()) && + (client_conn->encrypted_packets_.size() > client_next_packet_to_deliver || + server_conn->encrypted_packets_.size() > + server_next_packet_to_deliver)) { + if (!server_condition()) { + QUIC_LOG(INFO) << "Processing " + << client_conn->encrypted_packets_.size() - + client_next_packet_to_deliver + << " packets client->server"; + MovePackets(client_conn, &client_next_packet_to_deliver, server, + server_conn, Perspective::IS_SERVER, process_stream_data); + } + if (!client_condition()) { + QUIC_LOG(INFO) << "Processing " + << server_conn->encrypted_packets_.size() - + server_next_packet_to_deliver + << " packets server->client"; + MovePackets(server_conn, &server_next_packet_to_deliver, client, + client_conn, Perspective::IS_CLIENT, process_stream_data); + } + } + client_conn->number_of_packets_delivered_ = client_next_packet_to_deliver; + server_conn->number_of_packets_delivered_ = server_next_packet_to_deliver; + bool result = client_condition() && server_condition(); + if (!result) { + QUIC_LOG(INFO) << "CommunicateHandshakeMessagesUnti failed with state: " + "client connected? " + << client_conn->connected() << " server connected? " + << server_conn->connected() << " client condition met? " + << client_condition() << " server condition met? " + << server_condition(); + } + return result; +} + +std::pair AdvanceHandshake(PacketSavingConnection* client_conn, + QuicCryptoStream* client, + size_t client_i, + PacketSavingConnection* server_conn, + QuicCryptoStream* server, + size_t server_i) { + if (client_conn->encrypted_packets_.size() != client_i) { + QUIC_LOG(INFO) << "Processing " + << client_conn->encrypted_packets_.size() - client_i + << " packets client->server"; + MovePackets(client_conn, &client_i, server, server_conn, + Perspective::IS_SERVER, /*process_stream_data=*/false); + } + + if (server_conn->encrypted_packets_.size() != server_i) { + QUIC_LOG(INFO) << "Processing " + << server_conn->encrypted_packets_.size() - server_i + << " packets server->client"; + MovePackets(server_conn, &server_i, client, client_conn, + Perspective::IS_CLIENT, /*process_stream_data=*/false); + } + + return std::make_pair(client_i, server_i); +} + +std::string GetValueForTag(const CryptoHandshakeMessage& message, QuicTag tag) { + auto it = message.tag_value_map().find(tag); + if (it == message.tag_value_map().end()) { + return std::string(); + } + return it->second; +} + +uint64_t LeafCertHashForTesting() { + quiche::QuicheReferenceCountedPointer chain; + QuicSocketAddress server_address(QuicIpAddress::Any4(), 42); + QuicSocketAddress client_address(QuicIpAddress::Any4(), 43); + QuicCryptoProof proof; + std::unique_ptr proof_source(ProofSourceForTesting()); + + class Callback : public ProofSource::Callback { + public: + Callback(bool* ok, + quiche::QuicheReferenceCountedPointer* chain) + : ok_(ok), chain_(chain) {} + + void Run( + bool ok, + const quiche::QuicheReferenceCountedPointer& chain, + const QuicCryptoProof& /* proof */, + std::unique_ptr /* details */) override { + *ok_ = ok; + *chain_ = chain; + } + + private: + bool* ok_; + quiche::QuicheReferenceCountedPointer* chain_; + }; + + // Note: relies on the callback being invoked synchronously + bool ok = false; + proof_source->GetProof( + server_address, client_address, "", "", + AllSupportedVersionsWithQuicCrypto().front().transport_version, "", + std::unique_ptr(new Callback(&ok, &chain))); + if (!ok || chain->certs.empty()) { + QUICHE_DCHECK(false) << "Proof generation failed"; + return 0; + } + + return QuicUtils::FNV1a_64_Hash(chain->certs.at(0)); +} + +void FillInDummyReject(CryptoHandshakeMessage* rej) { + rej->set_tag(kREJ); + + // Minimum SCFG that passes config validation checks. + // clang-format off + unsigned char scfg[] = { + // SCFG + 0x53, 0x43, 0x46, 0x47, + // num entries + 0x01, 0x00, + // padding + 0x00, 0x00, + // EXPY + 0x45, 0x58, 0x50, 0x59, + // EXPY end offset + 0x08, 0x00, 0x00, 0x00, + // Value + '1', '2', '3', '4', + '5', '6', '7', '8' + }; + // clang-format on + rej->SetValue(kSCFG, scfg); + rej->SetStringPiece(kServerNonceTag, "SERVER_NONCE"); + int64_t ttl = 2 * 24 * 60 * 60; + rej->SetValue(kSTTL, ttl); + std::vector reject_reasons; + reject_reasons.push_back(CLIENT_NONCE_INVALID_FAILURE); + rej->SetVector(kRREJ, reject_reasons); +} + +namespace { + +#define RETURN_STRING_LITERAL(x) \ + case x: \ + return #x + +std::string EncryptionLevelString(EncryptionLevel level) { + switch (level) { + RETURN_STRING_LITERAL(ENCRYPTION_INITIAL); + RETURN_STRING_LITERAL(ENCRYPTION_HANDSHAKE); + RETURN_STRING_LITERAL(ENCRYPTION_ZERO_RTT); + RETURN_STRING_LITERAL(ENCRYPTION_FORWARD_SECURE); + default: + return ""; + } +} + +void CompareCrypters(const QuicEncrypter* encrypter, + const QuicDecrypter* decrypter, std::string label) { + if (encrypter == nullptr || decrypter == nullptr) { + ADD_FAILURE() << "Expected non-null crypters; have " << encrypter << " and " + << decrypter << " for " << label; + return; + } + absl::string_view encrypter_key = encrypter->GetKey(); + absl::string_view encrypter_iv = encrypter->GetNoncePrefix(); + absl::string_view decrypter_key = decrypter->GetKey(); + absl::string_view decrypter_iv = decrypter->GetNoncePrefix(); + quiche::test::CompareCharArraysWithHexError( + label + " key", encrypter_key.data(), encrypter_key.length(), + decrypter_key.data(), decrypter_key.length()); + quiche::test::CompareCharArraysWithHexError( + label + " iv", encrypter_iv.data(), encrypter_iv.length(), + decrypter_iv.data(), decrypter_iv.length()); +} + +} // namespace + +void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client, + QuicCryptoServerStreamBase* server) { + QuicFramer* client_framer = QuicConnectionPeer::GetFramer( + QuicStreamPeer::session(client)->connection()); + QuicFramer* server_framer = QuicConnectionPeer::GetFramer( + QuicStreamPeer::session(server)->connection()); + for (EncryptionLevel level : + {ENCRYPTION_HANDSHAKE, ENCRYPTION_ZERO_RTT, ENCRYPTION_FORWARD_SECURE}) { + SCOPED_TRACE(EncryptionLevelString(level)); + const QuicEncrypter* client_encrypter( + QuicFramerPeer::GetEncrypter(client_framer, level)); + const QuicDecrypter* server_decrypter( + QuicFramerPeer::GetDecrypter(server_framer, level)); + if (level == ENCRYPTION_FORWARD_SECURE || + !((level == ENCRYPTION_HANDSHAKE || level == ENCRYPTION_ZERO_RTT || + client_encrypter == nullptr) && + (level == ENCRYPTION_ZERO_RTT || server_decrypter == nullptr))) { + CompareCrypters(client_encrypter, server_decrypter, + "client " + EncryptionLevelString(level) + " write"); + } + const QuicEncrypter* server_encrypter( + QuicFramerPeer::GetEncrypter(server_framer, level)); + const QuicDecrypter* client_decrypter( + QuicFramerPeer::GetDecrypter(client_framer, level)); + if (level == ENCRYPTION_FORWARD_SECURE || + !(server_encrypter == nullptr && + (level == ENCRYPTION_HANDSHAKE || level == ENCRYPTION_ZERO_RTT || + client_decrypter == nullptr))) { + CompareCrypters(server_encrypter, client_decrypter, + "server " + EncryptionLevelString(level) + " write"); + } + } + + absl::string_view client_subkey_secret = + client->crypto_negotiated_params().subkey_secret; + absl::string_view server_subkey_secret = + server->crypto_negotiated_params().subkey_secret; + quiche::test::CompareCharArraysWithHexError( + "subkey secret", client_subkey_secret.data(), + client_subkey_secret.length(), server_subkey_secret.data(), + server_subkey_secret.length()); +} + +QuicTag ParseTag(const char* tagstr) { + const size_t len = strlen(tagstr); + QUICHE_CHECK_NE(0u, len); + + QuicTag tag = 0; + + if (tagstr[0] == '#') { + QUICHE_CHECK_EQ(static_cast(1 + 2 * 4), len); + tagstr++; + + for (size_t i = 0; i < 8; i++) { + tag <<= 4; + + uint8_t v = 0; + QUICHE_CHECK(HexChar(tagstr[i], &v)); + tag |= v; + } + + return tag; + } + + QUICHE_CHECK_LE(len, 4u); + for (size_t i = 0; i < 4; i++) { + tag >>= 8; + if (i < len) { + tag |= static_cast(tagstr[i]) << 24; + } + } + + return tag; +} + +CryptoHandshakeMessage CreateCHLO( + std::vector> tags_and_values) { + return CreateCHLO(tags_and_values, -1); +} + +CryptoHandshakeMessage CreateCHLO( + std::vector> tags_and_values, + int minimum_size_bytes) { + CryptoHandshakeMessage msg; + msg.set_tag(MakeQuicTag('C', 'H', 'L', 'O')); + + if (minimum_size_bytes > 0) { + msg.set_minimum_size(minimum_size_bytes); + } + + for (const auto& tag_and_value : tags_and_values) { + const std::string& tag = tag_and_value.first; + const std::string& value = tag_and_value.second; + + const QuicTag quic_tag = ParseTag(tag.c_str()); + + size_t value_len = value.length(); + if (value_len > 0 && value[0] == '#') { + // This is ascii encoded hex. + std::string hex_value = + absl::HexStringToBytes(absl::string_view(&value[1])); + msg.SetStringPiece(quic_tag, hex_value); + continue; + } + msg.SetStringPiece(quic_tag, value); + } + + // The CryptoHandshakeMessage needs to be serialized and parsed to ensure + // that any padding is included. + std::unique_ptr bytes = + CryptoFramer::ConstructHandshakeMessage(msg); + std::unique_ptr parsed( + CryptoFramer::ParseMessage(bytes->AsStringPiece())); + QUICHE_CHECK(parsed); + + return *parsed; +} + +void MovePackets(PacketSavingConnection* source_conn, + size_t* inout_packet_index, QuicCryptoStream* dest_stream, + PacketSavingConnection* dest_conn, + Perspective dest_perspective, bool process_stream_data) { + SimpleQuicFramer framer(source_conn->supported_versions(), dest_perspective); + QuicFramerPeer::SetLastSerializedServerConnectionId(framer.framer(), + TestConnectionId()); + + SimpleQuicFramer null_encryption_framer(source_conn->supported_versions(), + dest_perspective); + QuicFramerPeer::SetLastSerializedServerConnectionId( + null_encryption_framer.framer(), TestConnectionId()); + + size_t index = *inout_packet_index; + for (; index < source_conn->encrypted_packets_.size(); index++) { + if (!dest_conn->connected()) { + QUIC_LOG(INFO) + << "Destination connection disconnected. Skipping packet at index " + << index; + continue; + } + // In order to properly test the code we need to perform encryption and + // decryption so that the crypters latch when expected. The crypters are in + // |dest_conn|, but we don't want to try and use them there. Instead we swap + // them into |framer|, perform the decryption with them, and then swap ther + // back. + QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer()); + QuicConnectionPeer::AddBytesReceived( + dest_conn, source_conn->encrypted_packets_[index]->length()); + if (!framer.ProcessPacket(*source_conn->encrypted_packets_[index])) { + // The framer will be unable to decrypt zero-rtt packets sent during + // handshake or forward-secure packets sent after the handshake is + // complete. Don't treat them as handshake packets. + QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer()); + continue; + } + QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer()); + + // Install a packet flusher such that the packets generated by |dest_conn| + // in response to this packet are more likely to be coalesced and/or batched + // in the writer. + QuicConnection::ScopedPacketFlusher flusher(dest_conn); + + dest_conn->OnDecryptedPacket( + source_conn->encrypted_packets_[index]->length(), + framer.last_decrypted_level()); + + if (dest_stream->handshake_protocol() == PROTOCOL_TLS1_3) { + // Try to process the packet with a framer that only has the NullDecrypter + // for decryption. If ProcessPacket succeeds, that means the packet was + // encrypted with the NullEncrypter. With the TLS handshaker in use, no + // packets should ever be encrypted with the NullEncrypter, instead + // they're encrypted with an obfuscation cipher based on QUIC version and + // connection ID. + QUIC_LOG(INFO) << "Attempting to decrypt with NullDecrypter: " + "expect a decryption failure on the next log line."; + ASSERT_FALSE(null_encryption_framer.ProcessPacket( + *source_conn->encrypted_packets_[index])) + << "No TLS packets should be encrypted with the NullEncrypter"; + } + + // Since we're using QuicFramers separate from the connections to move + // packets, the QuicConnection never gets notified about what level the last + // packet was decrypted at. This is needed by TLS to know what encryption + // level was used for the data it's receiving, so we plumb this information + // from the SimpleQuicFramer back into the connection. + dest_conn->OnDecryptedPacket( + source_conn->encrypted_packets_[index]->length(), + framer.last_decrypted_level()); + + QuicConnectionPeer::SetCurrentPacket( + dest_conn, source_conn->encrypted_packets_[index]->AsStringPiece()); + for (const auto& stream_frame : framer.stream_frames()) { + if (process_stream_data && + dest_stream->handshake_protocol() == PROTOCOL_TLS1_3) { + // Deliver STREAM_FRAME such that application state is available and can + // be stored along with resumption ticket in session cache, + dest_conn->OnStreamFrame(*stream_frame); + } else { + // Ignore stream frames that are sent on other streams in the crypto + // event. + if (stream_frame->stream_id == dest_stream->id()) { + dest_stream->OnStreamFrame(*stream_frame); + } + } + } + for (const auto& crypto_frame : framer.crypto_frames()) { + dest_stream->OnCryptoFrame(*crypto_frame); + } + if (!framer.connection_close_frames().empty() && dest_conn->connected()) { + dest_conn->OnConnectionCloseFrame(framer.connection_close_frames()[0]); + } + } + *inout_packet_index = index; + + QuicConnectionPeer::SetCurrentPacket(dest_conn, + absl::string_view(nullptr, 0)); +} + +CryptoHandshakeMessage GenerateDefaultInchoateCHLO( + const QuicClock* clock, QuicTransportVersion version, + QuicCryptoServerConfig* crypto_config) { + // clang-format off + return CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"PUBS", GenerateClientPublicValuesHex().c_str()}, + {"NONC", GenerateClientNonceHex(clock, crypto_config).c_str()}, + {"VER\0", QuicVersionLabelToString( + CreateQuicVersionLabel( + ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, version))).c_str()}}, + kClientHelloMinimumSize); + // clang-format on +} + +std::string GenerateClientNonceHex(const QuicClock* clock, + QuicCryptoServerConfig* crypto_config) { + QuicCryptoServerConfig::ConfigOptions old_config_options; + QuicCryptoServerConfig::ConfigOptions new_config_options; + old_config_options.id = "old-config-id"; + crypto_config->AddDefaultConfig(QuicRandom::GetInstance(), clock, + old_config_options); + QuicServerConfigProtobuf primary_config = crypto_config->GenerateConfig( + QuicRandom::GetInstance(), clock, new_config_options); + primary_config.set_primary_time(clock->WallNow().ToUNIXSeconds()); + std::unique_ptr msg = + crypto_config->AddConfig(primary_config, clock->WallNow()); + absl::string_view orbit; + QUICHE_CHECK(msg->GetStringPiece(kORBT, &orbit)); + std::string nonce; + CryptoUtils::GenerateNonce(clock->WallNow(), QuicRandom::GetInstance(), orbit, + &nonce); + return ("#" + absl::BytesToHexString(nonce)); +} + +std::string GenerateClientPublicValuesHex() { + char public_value[32]; + memset(public_value, 42, sizeof(public_value)); + return ("#" + absl::BytesToHexString( + absl::string_view(public_value, sizeof(public_value)))); +} + +void GenerateFullCHLO( + const CryptoHandshakeMessage& inchoate_chlo, + QuicCryptoServerConfig* crypto_config, QuicSocketAddress server_addr, + QuicSocketAddress client_addr, QuicTransportVersion transport_version, + const QuicClock* clock, + quiche::QuicheReferenceCountedPointer signed_config, + QuicCompressedCertsCache* compressed_certs_cache, + CryptoHandshakeMessage* out) { + // Pass a inchoate CHLO. + FullChloGenerator generator( + crypto_config, server_addr, client_addr, clock, + ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, transport_version), signed_config, + compressed_certs_cache, out); + crypto_config->ValidateClientHello( + inchoate_chlo, client_addr, server_addr, transport_version, clock, + signed_config, generator.GetValidateClientHelloCallback()); +} + +namespace { + +constexpr char kTestProofHostname[] = "test.example.com"; + +class TestProofSource : public ProofSourceX509 { + public: + TestProofSource() + : ProofSourceX509( + quiche::QuicheReferenceCountedPointer( + new ProofSource::Chain( + std::vector{std::string(kTestCertificate)})), + std::move(*CertificatePrivateKey::LoadFromDer( + kTestCertificatePrivateKey))) { + QUICHE_DCHECK(valid()); + } + + protected: + void MaybeAddSctsForHostname(absl::string_view /*hostname*/, + std::string& leaf_cert_scts) override { + leaf_cert_scts = "Certificate Transparency is really nice"; + } +}; + +class TestProofVerifier : public ProofVerifier { + public: + TestProofVerifier() + : certificate_(std::move( + *CertificateView::ParseSingleCertificate(kTestCertificate))) {} + + class Details : public ProofVerifyDetails { + public: + ProofVerifyDetails* Clone() const override { return new Details(*this); } + }; + + QuicAsyncStatus VerifyProof( + const std::string& hostname, const uint16_t port, + const std::string& server_config, + QuicTransportVersion /*transport_version*/, absl::string_view chlo_hash, + const std::vector& certs, const std::string& cert_sct, + const std::string& signature, const ProofVerifyContext* context, + std::string* error_details, std::unique_ptr* details, + std::unique_ptr callback) override { + absl::optional payload = + CryptoUtils::GenerateProofPayloadToBeSigned(chlo_hash, server_config); + if (!payload.has_value()) { + *error_details = "Failed to serialize signed payload"; + return QUIC_FAILURE; + } + if (!certificate_.VerifySignature(*payload, signature, + SSL_SIGN_RSA_PSS_RSAE_SHA256)) { + *error_details = "Invalid signature"; + return QUIC_FAILURE; + } + + uint8_t out_alert; + return VerifyCertChain(hostname, port, certs, /*ocsp_response=*/"", + cert_sct, context, error_details, details, + &out_alert, std::move(callback)); + } + + QuicAsyncStatus VerifyCertChain( + const std::string& hostname, const uint16_t /*port*/, + const std::vector& certs, + const std::string& /*ocsp_response*/, const std::string& /*cert_sct*/, + const ProofVerifyContext* /*context*/, std::string* error_details, + std::unique_ptr* details, uint8_t* /*out_alert*/, + std::unique_ptr /*callback*/) override { + std::string normalized_hostname = + QuicHostnameUtils::NormalizeHostname(hostname); + if (normalized_hostname != kTestProofHostname) { + *error_details = absl::StrCat("Invalid hostname, expected ", + kTestProofHostname, " got ", hostname); + return QUIC_FAILURE; + } + if (certs.empty() || certs.front() != kTestCertificate) { + *error_details = "Received certificate different from the expected"; + return QUIC_FAILURE; + } + *details = std::make_unique
(); + return QUIC_SUCCESS; + } + + std::unique_ptr CreateDefaultContext() override { + return nullptr; + } + + private: + CertificateView certificate_; +}; + +} // namespace + +std::unique_ptr ProofSourceForTesting() { + return std::make_unique(); +} + +std::unique_ptr ProofVerifierForTesting() { + return std::make_unique(); +} + +std::string CertificateHostnameForTesting() { return kTestProofHostname; } + +std::unique_ptr ProofVerifyContextForTesting() { + return nullptr; +} + +} // namespace crypto_test_utils +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/crypto_test_utils.h b/quiche/quic/test_tools/crypto_test_utils.h new file mode 100644 index 000000000000..d65070614263 --- /dev/null +++ b/quiche/quic/test_tools/crypto_test_utils.h @@ -0,0 +1,222 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_CRYPTO_TEST_UTILS_H_ +#define QUICHE_QUIC_TEST_TOOLS_CRYPTO_TEST_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "openssl/evp.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +class ProofSource; +class ProofVerifier; +class ProofVerifyContext; +class QuicClock; +class QuicConfig; +class QuicCryptoClientStream; +class QuicCryptoServerConfig; +class QuicCryptoServerStreamBase; +class QuicCryptoStream; +class QuicServerId; + +namespace test { + +class PacketSavingConnection; + +namespace crypto_test_utils { + +// An interface for a source of callbacks. This is used for invoking +// callbacks asynchronously. +// +// Call the RunPendingCallbacks method regularly to run the callbacks from +// this source. +class CallbackSource { + public: + virtual ~CallbackSource() {} + + // Runs pending callbacks from this source. If there is no pending + // callback, does nothing. + virtual void RunPendingCallbacks() = 0; +}; + +// FakeClientOptions bundles together a number of options for configuring +// HandshakeWithFakeClient. +struct FakeClientOptions { + FakeClientOptions(); + ~FakeClientOptions(); + + // If only_tls_versions is set, then the client will only use TLS for the + // crypto handshake. + bool only_tls_versions = false; + + // If only_quic_crypto_versions is set, then the client will only use + // PROTOCOL_QUIC_CRYPTO for the crypto handshake. + bool only_quic_crypto_versions = false; +}; + +// Returns a QuicCryptoServerConfig that is in a reasonable configuration to +// pass into HandshakeWithFakeServer. +std::unique_ptr CryptoServerConfigForTesting(); + +// returns: the number of client hellos that the client sent. +int HandshakeWithFakeServer(QuicConfig* server_quic_config, + QuicCryptoServerConfig* crypto_config, + MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory, + PacketSavingConnection* client_conn, + QuicCryptoClientStreamBase* client, + std::string alpn); + +// returns: the number of client hellos that the client sent. +int HandshakeWithFakeClient(MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory, + PacketSavingConnection* server_conn, + QuicCryptoServerStreamBase* server, + const QuicServerId& server_id, + const FakeClientOptions& options, std::string alpn); + +// SetupCryptoServerConfigForTest configures |crypto_config| +// with sensible defaults for testing. +void SetupCryptoServerConfigForTest(const QuicClock* clock, QuicRandom* rand, + QuicCryptoServerConfig* crypto_config); + +// Sends the handshake message |message| to stream |stream| with the perspective +// that the message is coming from |perspective|. +void SendHandshakeMessageToStream(QuicCryptoStream* stream, + const CryptoHandshakeMessage& message, + Perspective perspective); + +// CommunicateHandshakeMessages moves messages from |client| to |server| and +// back until |clients|'s handshake has completed. +void CommunicateHandshakeMessages(PacketSavingConnection* client_conn, + QuicCryptoStream* client, + PacketSavingConnection* server_conn, + QuicCryptoStream* server); + +// CommunicateHandshakeMessagesUntil: +// 1) Moves messages from |client| to |server| until |server_condition| is met. +// 2) Moves messages from |server| to |client| until |client_condition| is met. +// 3) For IETF QUIC, if `process_stream_data` is true, STREAM_FRAME within the +// packet containing crypto messages is also processed. +// 4) Returns true if both conditions are met. +// 5) Returns false if either connection is closed or there is no more packet to +// deliver before both conditions are met. +bool CommunicateHandshakeMessagesUntil(PacketSavingConnection* client_conn, + QuicCryptoStream* client, + std::function client_condition, + PacketSavingConnection* server_conn, + QuicCryptoStream* server, + std::function server_condition, + bool process_stream_data); + +// AdvanceHandshake attempts to moves messages from |client| to |server| and +// |server| to |client|. Returns the number of messages moved. +std::pair AdvanceHandshake(PacketSavingConnection* client_conn, + QuicCryptoStream* client, + size_t client_i, + PacketSavingConnection* server_conn, + QuicCryptoStream* server, + size_t server_i); + +// Returns the value for the tag |tag| in the tag value map of |message|. +std::string GetValueForTag(const CryptoHandshakeMessage& message, QuicTag tag); + +// Returns a new |ProofSource| that serves up test certificates. +std::unique_ptr ProofSourceForTesting(); + +// Returns a new |ProofVerifier| that uses the QUIC testing root CA. +std::unique_ptr ProofVerifierForTesting(); + +// Returns the hostname used by the proof source and the proof verifier above. +std::string CertificateHostnameForTesting(); + +// Returns a hash of the leaf test certificate. +uint64_t LeafCertHashForTesting(); + +// Returns a |ProofVerifyContext| that must be used with the verifier +// returned by |ProofVerifierForTesting|. +std::unique_ptr ProofVerifyContextForTesting(); + +// Creates a minimal dummy reject message that will pass the client-config +// validation tests. This will include a server config, but no certs, proof +// source address token, or server nonce. +void FillInDummyReject(CryptoHandshakeMessage* rej); + +// ParseTag returns a QuicTag from parsing |tagstr|. |tagstr| may either be +// in the format "EXMP" (i.e. ASCII format), or "#11223344" (an explicit hex +// format). It QUICHE_CHECK fails if there's a parse error. +QuicTag ParseTag(const char* tagstr); + +// Message constructs a CHLO message from a provided vector of tag/value pairs. +// The first of each pair is the tag of a tag/value and is given as an argument +// to |ParseTag|. The second is the value of the tag/value pair and is either a +// hex dump, preceeded by a '#', or a raw value. If minimum_size_bytes is +// provided then the message will be padded to this minimum size. +// +// CreateCHLO( +// {{"NOCE", "#11223344"}, +// {"SNI", "www.example.com"}}, +// optional_minimum_size_bytes); +CryptoHandshakeMessage CreateCHLO( + std::vector> tags_and_values); +CryptoHandshakeMessage CreateCHLO( + std::vector> tags_and_values, + int minimum_size_bytes); + +// MovePackets parses crypto handshake messages from packet number +// |*inout_packet_index| through to the last packet (or until a packet fails +// to decrypt) and has |dest_stream| process them. |*inout_packet_index| is +// updated with an index one greater than the last packet processed. For IETF +// QUIC, if `process_stream_data` is true, STREAM_FRAME within the packet +// containing crypto messages is also processed. +void MovePackets(PacketSavingConnection* source_conn, + size_t* inout_packet_index, QuicCryptoStream* dest_stream, + PacketSavingConnection* dest_conn, + Perspective dest_perspective, bool process_stream_data); + +// Return an inchoate CHLO with some basic tag value pairs. +CryptoHandshakeMessage GenerateDefaultInchoateCHLO( + const QuicClock* clock, QuicTransportVersion version, + QuicCryptoServerConfig* crypto_config); + +// Takes a inchoate CHLO, returns a full CHLO in |out| which can pass +// |crypto_config|'s validation. +void GenerateFullCHLO( + const CryptoHandshakeMessage& inchoate_chlo, + QuicCryptoServerConfig* crypto_config, QuicSocketAddress server_addr, + QuicSocketAddress client_addr, QuicTransportVersion transport_version, + const QuicClock* clock, + quiche::QuicheReferenceCountedPointer signed_config, + QuicCompressedCertsCache* compressed_certs_cache, + CryptoHandshakeMessage* out); + +void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client, + QuicCryptoServerStreamBase* server); + +// Return a CHLO nonce in hexadecimal. +std::string GenerateClientNonceHex(const QuicClock* clock, + QuicCryptoServerConfig* crypto_config); + +// Return a CHLO PUBS in hexadecimal. +std::string GenerateClientPublicValuesHex(); + +} // namespace crypto_test_utils + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_CRYPTO_TEST_UTILS_H_ diff --git a/quiche/quic/test_tools/crypto_test_utils_test.cc b/quiche/quic/test_tools/crypto_test_utils_test.cc new file mode 100644 index 000000000000..05221756a62e --- /dev/null +++ b/quiche/quic/test_tools/crypto_test_utils_test.cc @@ -0,0 +1,187 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/crypto_test_utils.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/proto/crypto_server_config_proto.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +class ShloVerifier { + public: + ShloVerifier(QuicCryptoServerConfig* crypto_config, + QuicSocketAddress server_addr, QuicSocketAddress client_addr, + const QuicClock* clock, + quiche::QuicheReferenceCountedPointer + signed_config, + QuicCompressedCertsCache* compressed_certs_cache, + ParsedQuicVersion version) + : crypto_config_(crypto_config), + server_addr_(server_addr), + client_addr_(client_addr), + clock_(clock), + signed_config_(signed_config), + compressed_certs_cache_(compressed_certs_cache), + params_(new QuicCryptoNegotiatedParameters), + version_(version) {} + + class ValidateClientHelloCallback : public ValidateClientHelloResultCallback { + public: + explicit ValidateClientHelloCallback(ShloVerifier* shlo_verifier) + : shlo_verifier_(shlo_verifier) {} + void Run(quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result, + std::unique_ptr /* details */) override { + shlo_verifier_->ValidateClientHelloDone(result); + } + + private: + ShloVerifier* shlo_verifier_; + }; + + std::unique_ptr + GetValidateClientHelloCallback() { + return std::make_unique(this); + } + + private: + void ValidateClientHelloDone( + const quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result>& result) { + result_ = result; + crypto_config_->ProcessClientHello( + result_, /*reject_only=*/false, + /*connection_id=*/TestConnectionId(1), server_addr_, client_addr_, + version_, AllSupportedVersions(), clock_, QuicRandom::GetInstance(), + compressed_certs_cache_, params_, signed_config_, + /*total_framing_overhead=*/50, kDefaultMaxPacketSize, + GetProcessClientHelloCallback()); + } + + class ProcessClientHelloCallback : public ProcessClientHelloResultCallback { + public: + explicit ProcessClientHelloCallback(ShloVerifier* shlo_verifier) + : shlo_verifier_(shlo_verifier) {} + void Run(QuicErrorCode /*error*/, const std::string& /*error_details*/, + std::unique_ptr message, + std::unique_ptr /*diversification_nonce*/, + std::unique_ptr /*proof_source_details*/) + override { + shlo_verifier_->ProcessClientHelloDone(std::move(message)); + } + + private: + ShloVerifier* shlo_verifier_; + }; + + std::unique_ptr GetProcessClientHelloCallback() { + return std::make_unique(this); + } + + void ProcessClientHelloDone(std::unique_ptr message) { + // Verify output is a SHLO. + EXPECT_EQ(message->tag(), kSHLO) + << "Fail to pass validation. Get " << message->DebugString(); + } + + QuicCryptoServerConfig* crypto_config_; + QuicSocketAddress server_addr_; + QuicSocketAddress client_addr_; + const QuicClock* clock_; + quiche::QuicheReferenceCountedPointer signed_config_; + QuicCompressedCertsCache* compressed_certs_cache_; + + quiche::QuicheReferenceCountedPointer params_; + quiche::QuicheReferenceCountedPointer< + ValidateClientHelloResultCallback::Result> + result_; + + const ParsedQuicVersion version_; +}; + +class CryptoTestUtilsTest : public QuicTest {}; + +TEST_F(CryptoTestUtilsTest, TestGenerateFullCHLO) { + MockClock clock; + QuicCryptoServerConfig crypto_config( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + crypto_test_utils::ProofSourceForTesting(), KeyExchangeSource::Default()); + QuicSocketAddress server_addr(QuicIpAddress::Any4(), 5); + QuicSocketAddress client_addr(QuicIpAddress::Loopback4(), 1); + quiche::QuicheReferenceCountedPointer signed_config( + new QuicSignedServerConfig); + QuicCompressedCertsCache compressed_certs_cache( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize); + CryptoHandshakeMessage full_chlo; + + QuicCryptoServerConfig::ConfigOptions old_config_options; + old_config_options.id = "old-config-id"; + crypto_config.AddDefaultConfig(QuicRandom::GetInstance(), &clock, + old_config_options); + QuicCryptoServerConfig::ConfigOptions new_config_options; + QuicServerConfigProtobuf primary_config = crypto_config.GenerateConfig( + QuicRandom::GetInstance(), &clock, new_config_options); + primary_config.set_primary_time(clock.WallNow().ToUNIXSeconds()); + std::unique_ptr msg = + crypto_config.AddConfig(primary_config, clock.WallNow()); + absl::string_view orbit; + ASSERT_TRUE(msg->GetStringPiece(kORBT, &orbit)); + std::string nonce; + CryptoUtils::GenerateNonce(clock.WallNow(), QuicRandom::GetInstance(), orbit, + &nonce); + std::string nonce_hex = "#" + absl::BytesToHexString(nonce); + + char public_value[32]; + memset(public_value, 42, sizeof(public_value)); + std::string pub_hex = "#" + absl::BytesToHexString(absl::string_view( + public_value, sizeof(public_value))); + + // The methods below use a PROTOCOL_QUIC_CRYPTO version so we pick the + // first one from the list of supported versions. + QuicTransportVersion transport_version = QUIC_VERSION_UNSUPPORTED; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { + transport_version = version.transport_version; + break; + } + } + ASSERT_NE(QUIC_VERSION_UNSUPPORTED, transport_version); + + CryptoHandshakeMessage inchoate_chlo = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, + {"AEAD", "AESG"}, + {"KEXS", "C255"}, + {"COPT", "SREJ"}, + {"PUBS", pub_hex}, + {"NONC", nonce_hex}, + {"VER\0", + QuicVersionLabelToString(CreateQuicVersionLabel( + ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, transport_version)))}}, + kClientHelloMinimumSize); + + crypto_test_utils::GenerateFullCHLO(inchoate_chlo, &crypto_config, + server_addr, client_addr, + transport_version, &clock, signed_config, + &compressed_certs_cache, &full_chlo); + // Verify that full_chlo can pass crypto_config's verification. + ShloVerifier shlo_verifier( + &crypto_config, server_addr, client_addr, &clock, signed_config, + &compressed_certs_cache, + ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, transport_version)); + crypto_config.ValidateClientHello( + full_chlo, client_addr, server_addr, transport_version, &clock, + signed_config, shlo_verifier.GetValidateClientHelloCallback()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/failing_proof_source.cc b/quiche/quic/test_tools/failing_proof_source.cc new file mode 100644 index 000000000000..55ae06eaa170 --- /dev/null +++ b/quiche/quic/test_tools/failing_proof_source.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/failing_proof_source.h" + +#include "absl/strings/string_view.h" + +namespace quic { +namespace test { + +void FailingProofSource::GetProof(const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + const std::string& /*hostname*/, + const std::string& /*server_config*/, + QuicTransportVersion /*transport_version*/, + absl::string_view /*chlo_hash*/, + std::unique_ptr callback) { + callback->Run(false, nullptr, QuicCryptoProof(), nullptr); +} + +quiche::QuicheReferenceCountedPointer +FailingProofSource::GetCertChain(const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + const std::string& /*hostname*/, + bool* cert_matched_sni) { + *cert_matched_sni = false; + return quiche::QuicheReferenceCountedPointer(); +} + +void FailingProofSource::ComputeTlsSignature( + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, + const std::string& /*hostname*/, uint16_t /*signature_algorithm*/, + absl::string_view /*in*/, std::unique_ptr callback) { + callback->Run(false, "", nullptr); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/failing_proof_source.h b/quiche/quic/test_tools/failing_proof_source.h new file mode 100644 index 000000000000..f9fe973ac27c --- /dev/null +++ b/quiche/quic/test_tools/failing_proof_source.h @@ -0,0 +1,45 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_FAILING_PROOF_SOURCE_H_ +#define QUICHE_QUIC_TEST_TOOLS_FAILING_PROOF_SOURCE_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/proof_source.h" + +namespace quic { +namespace test { + +class FailingProofSource : public ProofSource { + public: + void GetProof(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, const std::string& server_config, + QuicTransportVersion transport_version, + absl::string_view chlo_hash, + std::unique_ptr callback) override; + + quiche::QuicheReferenceCountedPointer GetCertChain( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; + + void ComputeTlsSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) override; + + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override { + return {}; + } + + TicketCrypter* GetTicketCrypter() override { return nullptr; } +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_FAILING_PROOF_SOURCE_H_ diff --git a/quiche/quic/test_tools/fake_proof_source.cc b/quiche/quic/test_tools/fake_proof_source.cc new file mode 100644 index 000000000000..43c5a7257b07 --- /dev/null +++ b/quiche/quic/test_tools/fake_proof_source.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/fake_proof_source.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" + +namespace quic { +namespace test { + +FakeProofSource::FakeProofSource() + : delegate_(crypto_test_utils::ProofSourceForTesting()) {} + +FakeProofSource::~FakeProofSource() {} + +FakeProofSource::PendingOp::~PendingOp() = default; + +FakeProofSource::GetProofOp::GetProofOp( + const QuicSocketAddress& server_addr, + const QuicSocketAddress& client_address, std::string hostname, + std::string server_config, QuicTransportVersion transport_version, + std::string chlo_hash, std::unique_ptr callback, + ProofSource* delegate) + : server_address_(server_addr), + client_address_(client_address), + hostname_(std::move(hostname)), + server_config_(std::move(server_config)), + transport_version_(transport_version), + chlo_hash_(std::move(chlo_hash)), + callback_(std::move(callback)), + delegate_(delegate) {} + +FakeProofSource::GetProofOp::~GetProofOp() = default; + +void FakeProofSource::GetProofOp::Run() { + // Note: relies on the callback being invoked synchronously + delegate_->GetProof(server_address_, client_address_, hostname_, + server_config_, transport_version_, chlo_hash_, + std::move(callback_)); +} + +FakeProofSource::ComputeSignatureOp::ComputeSignatureOp( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, std::string hostname, + uint16_t sig_alg, absl::string_view in, + std::unique_ptr callback, + ProofSource* delegate) + : server_address_(server_address), + client_address_(client_address), + hostname_(std::move(hostname)), + sig_alg_(sig_alg), + in_(in), + callback_(std::move(callback)), + delegate_(delegate) {} + +FakeProofSource::ComputeSignatureOp::~ComputeSignatureOp() = default; + +void FakeProofSource::ComputeSignatureOp::Run() { + delegate_->ComputeTlsSignature(server_address_, client_address_, hostname_, + sig_alg_, in_, std::move(callback_)); +} + +void FakeProofSource::Activate() { active_ = true; } + +void FakeProofSource::GetProof( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + const std::string& server_config, QuicTransportVersion transport_version, + absl::string_view chlo_hash, + std::unique_ptr callback) { + if (!active_) { + delegate_->GetProof(server_address, client_address, hostname, server_config, + transport_version, chlo_hash, std::move(callback)); + return; + } + + pending_ops_.push_back(std::make_unique( + server_address, client_address, hostname, server_config, + transport_version, std::string(chlo_hash), std::move(callback), + delegate_.get())); +} + +quiche::QuicheReferenceCountedPointer +FakeProofSource::GetCertChain(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + bool* cert_matched_sni) { + return delegate_->GetCertChain(server_address, client_address, hostname, + cert_matched_sni); +} + +void FakeProofSource::ComputeTlsSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) { + QUIC_LOG(INFO) << "FakeProofSource::ComputeTlsSignature"; + if (!active_) { + QUIC_LOG(INFO) << "Not active - directly calling delegate"; + delegate_->ComputeTlsSignature(server_address, client_address, hostname, + signature_algorithm, in, + std::move(callback)); + return; + } + + QUIC_LOG(INFO) << "Adding pending op"; + pending_ops_.push_back(std::make_unique( + server_address, client_address, hostname, signature_algorithm, in, + std::move(callback), delegate_.get())); +} + +absl::InlinedVector +FakeProofSource::SupportedTlsSignatureAlgorithms() const { + return delegate_->SupportedTlsSignatureAlgorithms(); +} + +ProofSource::TicketCrypter* FakeProofSource::GetTicketCrypter() { + if (ticket_crypter_) { + return ticket_crypter_.get(); + } + return delegate_->GetTicketCrypter(); +} + +void FakeProofSource::SetTicketCrypter( + std::unique_ptr ticket_crypter) { + ticket_crypter_ = std::move(ticket_crypter); +} + +int FakeProofSource::NumPendingCallbacks() const { return pending_ops_.size(); } + +void FakeProofSource::InvokePendingCallback(int n) { + QUICHE_CHECK(NumPendingCallbacks() > n); + + pending_ops_[n]->Run(); + + auto it = pending_ops_.begin() + n; + pending_ops_.erase(it); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/fake_proof_source.h b/quiche/quic/test_tools/fake_proof_source.h new file mode 100644 index 000000000000..b0069af376ac --- /dev/null +++ b/quiche/quic/test_tools/fake_proof_source.h @@ -0,0 +1,129 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_ +#define QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/proof_source.h" + +namespace quic { +namespace test { + +// Implementation of ProofSource which delegates to a ProofSourceForTesting, but +// allows for overriding certain functionality. FakeProofSource allows +// intercepting calls to GetProof and ComputeTlsSignature to force them to run +// asynchronously, and allow the caller to see that the call is pending and +// resume the operation at the caller's choosing. FakeProofSource also allows +// the caller to replace the TicketCrypter provided by +// FakeProofSource::GetTicketCrypter. +class FakeProofSource : public ProofSource { + public: + FakeProofSource(); + ~FakeProofSource() override; + + // Before this object is "active", all calls to GetProof will be delegated + // immediately. Once "active", the async ones will be intercepted. This + // distinction is necessary to ensure that GetProof can be called without + // interference during test case setup. + void Activate(); + + // ProofSource interface + void GetProof(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, const std::string& server_config, + QuicTransportVersion transport_version, + absl::string_view chlo_hash, + std::unique_ptr callback) override; + quiche::QuicheReferenceCountedPointer GetCertChain( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; + void ComputeTlsSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + std::unique_ptr callback) override; + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override; + TicketCrypter* GetTicketCrypter() override; + + // Sets the TicketCrypter to use. If nullptr, the TicketCrypter from + // ProofSourceForTesting will be returned instead. + void SetTicketCrypter(std::unique_ptr ticket_crypter); + + // Get the number of callbacks which are pending + int NumPendingCallbacks() const; + + // Invoke a pending callback. The index refers to the position in + // pending_ops_ of the callback to be completed. + void InvokePendingCallback(int n); + + private: + std::unique_ptr delegate_; + std::unique_ptr ticket_crypter_; + bool active_ = false; + + class PendingOp { + public: + virtual ~PendingOp(); + virtual void Run() = 0; + }; + + class GetProofOp : public PendingOp { + public: + GetProofOp(const QuicSocketAddress& server_addr, + const QuicSocketAddress& client_address, std::string hostname, + std::string server_config, + QuicTransportVersion transport_version, std::string chlo_hash, + std::unique_ptr callback, + ProofSource* delegate); + ~GetProofOp() override; + + void Run() override; + + private: + QuicSocketAddress server_address_; + QuicSocketAddress client_address_; + std::string hostname_; + std::string server_config_; + QuicTransportVersion transport_version_; + std::string chlo_hash_; + std::unique_ptr callback_; + ProofSource* delegate_; + }; + + class ComputeSignatureOp : public PendingOp { + public: + ComputeSignatureOp(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + std::string hostname, uint16_t sig_alg, + absl::string_view in, + std::unique_ptr callback, + ProofSource* delegate); + ~ComputeSignatureOp() override; + + void Run() override; + + private: + QuicSocketAddress server_address_; + QuicSocketAddress client_address_; + std::string hostname_; + uint16_t sig_alg_; + std::string in_; + std::unique_ptr callback_; + ProofSource* delegate_; + }; + + std::vector> pending_ops_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_H_ diff --git a/quiche/quic/test_tools/fake_proof_source_handle.cc b/quiche/quic/test_tools/fake_proof_source_handle.cc new file mode 100644 index 000000000000..ac1f3a5138ee --- /dev/null +++ b/quiche/quic/test_tools/fake_proof_source_handle.cc @@ -0,0 +1,235 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/fake_proof_source_handle.h" + +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { +namespace test { +namespace { + +struct QUIC_EXPORT_PRIVATE ComputeSignatureResult { + bool ok; + std::string signature; + std::unique_ptr details; +}; + +class QUIC_EXPORT_PRIVATE ResultSavingSignatureCallback + : public ProofSource::SignatureCallback { + public: + explicit ResultSavingSignatureCallback( + absl::optional* result) + : result_(result) { + QUICHE_DCHECK(!result_->has_value()); + } + void Run(bool ok, std::string signature, + std::unique_ptr details) override { + result_->emplace( + ComputeSignatureResult{ok, std::move(signature), std::move(details)}); + } + + private: + absl::optional* result_; +}; + +ComputeSignatureResult ComputeSignatureNow( + ProofSource* delegate, const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in) { + absl::optional result; + delegate->ComputeTlsSignature( + server_address, client_address, hostname, signature_algorithm, in, + std::make_unique(&result)); + QUICHE_CHECK(result.has_value()) + << "delegate->ComputeTlsSignature must computes a " + "signature immediately"; + return std::move(result.value()); +} +} // namespace + +FakeProofSourceHandle::FakeProofSourceHandle( + ProofSource* delegate, ProofSourceHandleCallback* callback, + Action select_cert_action, Action compute_signature_action, + QuicDelayedSSLConfig dealyed_ssl_config) + : delegate_(delegate), + callback_(callback), + select_cert_action_(select_cert_action), + compute_signature_action_(compute_signature_action), + dealyed_ssl_config_(dealyed_ssl_config) {} + +void FakeProofSourceHandle::CloseHandle() { + select_cert_op_.reset(); + compute_signature_op_.reset(); + closed_ = true; +} + +QuicAsyncStatus FakeProofSourceHandle::SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const QuicConnectionId& original_connection_id, + absl::string_view ssl_capabilities, const std::string& hostname, + absl::string_view client_hello, const std::string& alpn, + absl::optional alps, + const std::vector& quic_transport_params, + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) { + if (select_cert_action_ != Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + QUICHE_CHECK(!closed_); + } + all_select_cert_args_.push_back( + SelectCertArgs(server_address, client_address, original_connection_id, + ssl_capabilities, hostname, client_hello, alpn, alps, + quic_transport_params, early_data_context, ssl_config)); + + if (select_cert_action_ == Action::DELEGATE_ASYNC || + select_cert_action_ == Action::FAIL_ASYNC) { + select_cert_op_.emplace(delegate_, callback_, select_cert_action_, + all_select_cert_args_.back(), dealyed_ssl_config_); + return QUIC_PENDING; + } else if (select_cert_action_ == Action::FAIL_SYNC || + select_cert_action_ == Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + callback()->OnSelectCertificateDone( + /*ok=*/false, + /*is_sync=*/true, nullptr, /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/false, dealyed_ssl_config_); + return QUIC_FAILURE; + } + + QUICHE_DCHECK(select_cert_action_ == Action::DELEGATE_SYNC); + bool cert_matched_sni; + quiche::QuicheReferenceCountedPointer chain = + delegate_->GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); + + bool ok = chain && !chain->certs.empty(); + callback_->OnSelectCertificateDone( + ok, /*is_sync=*/true, chain.get(), + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/cert_matched_sni, dealyed_ssl_config_); + return ok ? QUIC_SUCCESS : QUIC_FAILURE; +} + +QuicAsyncStatus FakeProofSourceHandle::ComputeSignature( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, + size_t max_signature_size) { + if (compute_signature_action_ != Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + QUICHE_CHECK(!closed_); + } + all_compute_signature_args_.push_back( + ComputeSignatureArgs(server_address, client_address, hostname, + signature_algorithm, in, max_signature_size)); + + if (compute_signature_action_ == Action::DELEGATE_ASYNC || + compute_signature_action_ == Action::FAIL_ASYNC) { + compute_signature_op_.emplace(delegate_, callback_, + compute_signature_action_, + all_compute_signature_args_.back()); + return QUIC_PENDING; + } else if (compute_signature_action_ == Action::FAIL_SYNC || + compute_signature_action_ == + Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + callback()->OnComputeSignatureDone(/*ok=*/false, /*is_sync=*/true, + /*signature=*/"", /*details=*/nullptr); + return QUIC_FAILURE; + } + + QUICHE_DCHECK(compute_signature_action_ == Action::DELEGATE_SYNC); + ComputeSignatureResult result = + ComputeSignatureNow(delegate_, server_address, client_address, hostname, + signature_algorithm, in); + callback_->OnComputeSignatureDone( + result.ok, /*is_sync=*/true, result.signature, std::move(result.details)); + return result.ok ? QUIC_SUCCESS : QUIC_FAILURE; +} + +ProofSourceHandleCallback* FakeProofSourceHandle::callback() { + return callback_; +} + +bool FakeProofSourceHandle::HasPendingOperation() const { + int num_pending_operations = NumPendingOperations(); + return num_pending_operations > 0; +} + +void FakeProofSourceHandle::CompletePendingOperation() { + QUICHE_DCHECK_LE(NumPendingOperations(), 1); + + if (select_cert_op_.has_value()) { + select_cert_op_->Run(); + select_cert_op_.reset(); + } else if (compute_signature_op_.has_value()) { + compute_signature_op_->Run(); + compute_signature_op_.reset(); + } +} + +int FakeProofSourceHandle::NumPendingOperations() const { + return static_cast(select_cert_op_.has_value()) + + static_cast(compute_signature_op_.has_value()); +} + +FakeProofSourceHandle::SelectCertOperation::SelectCertOperation( + ProofSource* delegate, ProofSourceHandleCallback* callback, Action action, + SelectCertArgs args, QuicDelayedSSLConfig dealyed_ssl_config) + : PendingOperation(delegate, callback, action), + args_(std::move(args)), + dealyed_ssl_config_(dealyed_ssl_config) {} + +void FakeProofSourceHandle::SelectCertOperation::Run() { + if (action_ == Action::FAIL_ASYNC) { + callback_->OnSelectCertificateDone( + /*ok=*/false, + /*is_sync=*/false, nullptr, + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/false, dealyed_ssl_config_); + } else if (action_ == Action::DELEGATE_ASYNC) { + bool cert_matched_sni; + quiche::QuicheReferenceCountedPointer chain = + delegate_->GetCertChain(args_.server_address, args_.client_address, + args_.hostname, &cert_matched_sni); + bool ok = chain && !chain->certs.empty(); + callback_->OnSelectCertificateDone( + ok, /*is_sync=*/false, chain.get(), + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/cert_matched_sni, dealyed_ssl_config_); + } else { + QUIC_BUG(quic_bug_10139_1) + << "Unexpected action: " << static_cast(action_); + } +} + +FakeProofSourceHandle::ComputeSignatureOperation::ComputeSignatureOperation( + ProofSource* delegate, ProofSourceHandleCallback* callback, Action action, + ComputeSignatureArgs args) + : PendingOperation(delegate, callback, action), args_(std::move(args)) {} + +void FakeProofSourceHandle::ComputeSignatureOperation::Run() { + if (action_ == Action::FAIL_ASYNC) { + callback_->OnComputeSignatureDone( + /*ok=*/false, /*is_sync=*/false, + /*signature=*/"", /*details=*/nullptr); + } else if (action_ == Action::DELEGATE_ASYNC) { + ComputeSignatureResult result = ComputeSignatureNow( + delegate_, args_.server_address, args_.client_address, args_.hostname, + args_.signature_algorithm, args_.in); + callback_->OnComputeSignatureDone(result.ok, /*is_sync=*/false, + result.signature, + std::move(result.details)); + } else { + QUIC_BUG(quic_bug_10139_2) + << "Unexpected action: " << static_cast(action_); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/fake_proof_source_handle.h b/quiche/quic/test_tools/fake_proof_source_handle.h new file mode 100644 index 000000000000..599a1fa538a1 --- /dev/null +++ b/quiche/quic/test_tools/fake_proof_source_handle.h @@ -0,0 +1,198 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_HANDLE_H_ +#define QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_HANDLE_H_ + +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/quic_connection_id.h" + +namespace quic { +namespace test { + +// FakeProofSourceHandle allows its behavior to be scripted for testing. +class FakeProofSourceHandle : public ProofSourceHandle { + public: + // What would an operation return when it is called. + enum class Action { + // Delegate the operation to |delegate_| immediately. + DELEGATE_SYNC = 0, + // Handle the operation asynchronously. Delegate the operation to + // |delegate_| when the caller calls CompletePendingOperation(). + DELEGATE_ASYNC, + // Fail the operation immediately. + FAIL_SYNC, + // Handle the operation asynchronously. Fail the operation when the caller + // calls CompletePendingOperation(). + FAIL_ASYNC, + // Similar to FAIL_SYNC, but do not QUICHE_CHECK(!closed_) when invoked. + FAIL_SYNC_DO_NOT_CHECK_CLOSED, + }; + + // |delegate| must do cert selection and signature synchronously. + // |dealyed_ssl_config| is the config passed to OnSelectCertificateDone. + FakeProofSourceHandle( + ProofSource* delegate, ProofSourceHandleCallback* callback, + Action select_cert_action, Action compute_signature_action, + QuicDelayedSSLConfig dealyed_ssl_config = QuicDelayedSSLConfig()); + + ~FakeProofSourceHandle() override = default; + + void CloseHandle() override; + + QuicAsyncStatus SelectCertificate( + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const QuicConnectionId& original_connection_id, + absl::string_view ssl_capabilities, const std::string& hostname, + absl::string_view client_hello, const std::string& alpn, + absl::optional alps, + const std::vector& quic_transport_params, + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) override; + + QuicAsyncStatus ComputeSignature(const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const std::string& hostname, + uint16_t signature_algorithm, + absl::string_view in, + size_t max_signature_size) override; + + ProofSourceHandleCallback* callback() override; + + // Whether there's a pending operation in |this|. + bool HasPendingOperation() const; + void CompletePendingOperation(); + + struct SelectCertArgs { + SelectCertArgs(QuicSocketAddress server_address, + QuicSocketAddress client_address, + QuicConnectionId original_connection_id, + absl::string_view ssl_capabilities, std::string hostname, + absl::string_view client_hello, std::string alpn, + absl::optional alps, + std::vector quic_transport_params, + absl::optional> early_data_context, + QuicSSLConfig ssl_config) + : server_address(server_address), + client_address(client_address), + original_connection_id(original_connection_id), + ssl_capabilities(ssl_capabilities), + hostname(hostname), + client_hello(client_hello), + alpn(alpn), + alps(alps), + quic_transport_params(quic_transport_params), + early_data_context(early_data_context), + ssl_config(ssl_config) {} + + QuicSocketAddress server_address; + QuicSocketAddress client_address; + QuicConnectionId original_connection_id; + std::string ssl_capabilities; + std::string hostname; + std::string client_hello; + std::string alpn; + absl::optional alps; + std::vector quic_transport_params; + absl::optional> early_data_context; + QuicSSLConfig ssl_config; + }; + + struct ComputeSignatureArgs { + ComputeSignatureArgs(QuicSocketAddress server_address, + QuicSocketAddress client_address, std::string hostname, + uint16_t signature_algorithm, absl::string_view in, + size_t max_signature_size) + : server_address(server_address), + client_address(client_address), + hostname(hostname), + signature_algorithm(signature_algorithm), + in(in), + max_signature_size(max_signature_size) {} + + QuicSocketAddress server_address; + QuicSocketAddress client_address; + std::string hostname; + uint16_t signature_algorithm; + std::string in; + size_t max_signature_size; + }; + + std::vector all_select_cert_args() const { + return all_select_cert_args_; + } + + std::vector all_compute_signature_args() const { + return all_compute_signature_args_; + } + + private: + class PendingOperation { + public: + PendingOperation(ProofSource* delegate, ProofSourceHandleCallback* callback, + Action action) + : delegate_(delegate), callback_(callback), action_(action) {} + virtual ~PendingOperation() = default; + virtual void Run() = 0; + + protected: + ProofSource* delegate_; + ProofSourceHandleCallback* callback_; + Action action_; + }; + + class SelectCertOperation : public PendingOperation { + public: + SelectCertOperation(ProofSource* delegate, + ProofSourceHandleCallback* callback, Action action, + SelectCertArgs args, + QuicDelayedSSLConfig dealyed_ssl_config); + + ~SelectCertOperation() override = default; + + void Run() override; + + private: + const SelectCertArgs args_; + const QuicDelayedSSLConfig dealyed_ssl_config_; + }; + + class ComputeSignatureOperation : public PendingOperation { + public: + ComputeSignatureOperation(ProofSource* delegate, + ProofSourceHandleCallback* callback, + Action action, ComputeSignatureArgs args); + + ~ComputeSignatureOperation() override = default; + + void Run() override; + + private: + const ComputeSignatureArgs args_; + }; + + private: + int NumPendingOperations() const; + + bool closed_ = false; + ProofSource* delegate_; + ProofSourceHandleCallback* callback_; + // Action for the next select cert operation. + Action select_cert_action_ = Action::DELEGATE_SYNC; + // Action for the next compute signature operation. + Action compute_signature_action_ = Action::DELEGATE_SYNC; + const QuicDelayedSSLConfig dealyed_ssl_config_; + absl::optional select_cert_op_; + absl::optional compute_signature_op_; + + // Save all the select cert and compute signature args for tests to inspect. + std::vector all_select_cert_args_; + std::vector all_compute_signature_args_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_FAKE_PROOF_SOURCE_HANDLE_H_ diff --git a/quiche/quic/test_tools/first_flight.cc b/quiche/quic/test_tools/first_flight.cc new file mode 100644 index 000000000000..dcdeac3113ac --- /dev/null +++ b/quiche/quic/test_tools/first_flight.cc @@ -0,0 +1,191 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/first_flight.h" + +#include +#include + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/http/quic_client_push_promise_index.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +// Utility class that creates a custom HTTP/3 session and QUIC connection in +// order to extract the first flight of packets it sends. This is meant to only +// be used by GetFirstFlightOfPackets() below. +class FirstFlightExtractor : public DelegatedPacketWriter::Delegate { + public: + FirstFlightExtractor(const ParsedQuicVersion& version, + const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config) + : version_(version), + server_connection_id_(server_connection_id), + client_connection_id_(client_connection_id), + writer_(this), + config_(config), + crypto_config_(std::move(crypto_config)) { + EXPECT_NE(version_, UnsupportedQuicVersion()); + } + + FirstFlightExtractor(const ParsedQuicVersion& version, + const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) + : FirstFlightExtractor( + version, config, server_connection_id, client_connection_id, + std::make_unique( + crypto_test_utils::ProofVerifierForTesting())) {} + + void GenerateFirstFlight() { + crypto_config_->set_alpn(AlpnForVersion(version_)); + connection_ = new QuicConnection( + server_connection_id_, + /*initial_self_address=*/QuicSocketAddress(), + QuicSocketAddress(TestPeerIPAddress(), kTestPort), &connection_helper_, + &alarm_factory_, &writer_, + /*owns_writer=*/false, Perspective::IS_CLIENT, + ParsedQuicVersionVector{version_}, connection_id_generator_); + connection_->set_client_connection_id(client_connection_id_); + session_ = std::make_unique( + config_, ParsedQuicVersionVector{version_}, + connection_, // session_ takes ownership of connection_ here. + TestServerId(), crypto_config_.get(), &push_promise_index_); + session_->Initialize(); + session_->CryptoConnect(); + } + + void OnDelegatedPacket(const char* buffer, size_t buf_len, + const QuicIpAddress& /*self_client_address*/, + const QuicSocketAddress& /*peer_client_address*/, + PerPacketOptions* /*options*/) override { + packets_.emplace_back( + QuicReceivedPacket(buffer, buf_len, + connection_helper_.GetClock()->ApproximateNow(), + /*owns_buffer=*/false) + .Clone()); + } + + std::vector>&& ConsumePackets() { + return std::move(packets_); + } + + uint64_t GetCryptoStreamBytesWritten() const { + QUICHE_DCHECK(session_); + QUICHE_DCHECK(session_->GetCryptoStream()); + return session_->GetCryptoStream()->BytesSentOnLevel( + EncryptionLevel::ENCRYPTION_INITIAL); + } + + private: + ParsedQuicVersion version_; + QuicConnectionId server_connection_id_; + QuicConnectionId client_connection_id_; + MockQuicConnectionHelper connection_helper_; + MockAlarmFactory alarm_factory_; + DelegatedPacketWriter writer_; + QuicConfig config_; + std::unique_ptr crypto_config_; + QuicClientPushPromiseIndex push_promise_index_; + QuicConnection* connection_; // Owned by session_. + std::unique_ptr session_; + std::vector> packets_; + MockConnectionIdGenerator connection_id_generator_; +}; + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config) { + FirstFlightExtractor first_flight_extractor( + version, config, server_connection_id, client_connection_id, + std::move(crypto_config)); + first_flight_extractor.GenerateFirstFlight(); + return first_flight_extractor.ConsumePackets(); +} + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) { + FirstFlightExtractor first_flight_extractor( + version, config, server_connection_id, client_connection_id); + first_flight_extractor.GenerateFirstFlight(); + return first_flight_extractor.ConsumePackets(); +} + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id) { + return GetFirstFlightOfPackets(version, config, server_connection_id, + EmptyQuicConnectionId()); +} + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config) { + return GetFirstFlightOfPackets(version, config, TestConnectionId()); +} + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) { + return GetFirstFlightOfPackets(version, DefaultQuicConfig(), + server_connection_id, client_connection_id); +} + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id) { + return GetFirstFlightOfPackets(version, DefaultQuicConfig(), + server_connection_id, EmptyQuicConnectionId()); +} + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version) { + return GetFirstFlightOfPackets(version, DefaultQuicConfig(), + TestConnectionId()); +} + +AnnotatedPackets GetAnnotatedFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config) { + FirstFlightExtractor first_flight_extractor( + version, config, server_connection_id, client_connection_id, + std::move(crypto_config)); + first_flight_extractor.GenerateFirstFlight(); + return AnnotatedPackets{first_flight_extractor.ConsumePackets(), + first_flight_extractor.GetCryptoStreamBytesWritten()}; +} + +AnnotatedPackets GetAnnotatedFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config) { + FirstFlightExtractor first_flight_extractor( + version, config, TestConnectionId(), EmptyQuicConnectionId()); + first_flight_extractor.GenerateFirstFlight(); + return AnnotatedPackets{first_flight_extractor.ConsumePackets(), + first_flight_extractor.GetCryptoStreamBytesWritten()}; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/first_flight.h b/quiche/quic/test_tools/first_flight.h new file mode 100644 index 000000000000..389603174714 --- /dev/null +++ b/quiche/quic/test_tools/first_flight.h @@ -0,0 +1,133 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_FIRST_FLIGHT_H_ +#define QUICHE_QUIC_TEST_TOOLS_FIRST_FLIGHT_H_ + +#include +#include + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { +namespace test { + +// Implementation of QuicPacketWriter that sends all packets to a delegate. +class QUIC_NO_EXPORT DelegatedPacketWriter : public QuicPacketWriter { + public: + class QUIC_NO_EXPORT Delegate { + public: + virtual ~Delegate() {} + // Note that |buffer| may be released after this call completes so overrides + // that want to use the data after the call is complete MUST copy it. + virtual void OnDelegatedPacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_client_address, + const QuicSocketAddress& peer_client_address, + PerPacketOptions* options) = 0; + }; + + // |delegate| MUST be valid for the duration of the DelegatedPacketWriter's + // lifetime. + explicit DelegatedPacketWriter(Delegate* delegate) : delegate_(delegate) { + QUICHE_CHECK_NE(delegate_, nullptr); + } + + // Overrides for QuicPacketWriter. + bool IsWriteBlocked() const override { return false; } + void SetWritable() override {} + absl::optional MessageTooBigErrorCode() const override { + return absl::nullopt; + } + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const override { + return kMaxOutgoingPacketSize; + } + bool SupportsReleaseTime() const override { return false; } + bool IsBatchMode() const override { return false; } + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) override { + return {nullptr, nullptr}; + } + WriteResult Flush() override { return WriteResult(WRITE_STATUS_OK, 0); } + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_client_address, + const QuicSocketAddress& peer_client_address, + PerPacketOptions* options) override { + delegate_->OnDelegatedPacket(buffer, buf_len, self_client_address, + peer_client_address, options); + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + private: + Delegate* delegate_; // Unowned. +}; + +// Returns an array of packets that represent the first flight of a real +// HTTP/3 connection. In most cases, this array will only contain one packet +// that carries the CHLO. +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config); + +// Below are various convenience overloads that use default values for the +// omitted parameters: +// |config| = DefaultQuicConfig(), +// |server_connection_id| = TestConnectionId(), +// |client_connection_id| = EmptyQuicConnectionId(). +// |crypto_config| = +// QuicCryptoClientConfig(crypto_test_utils::ProofVerifierForTesting()) +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id); + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id); + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id); + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id); + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config); + +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version); + +// Functions that also provide additional information about the session. +struct AnnotatedPackets { + std::vector> packets; + uint64_t crypto_stream_size; +}; + +AnnotatedPackets GetAnnotatedFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config); + +AnnotatedPackets GetAnnotatedFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config); + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_FIRST_FLIGHT_H_ diff --git a/quiche/quic/test_tools/fuzzing/README.md b/quiche/quic/test_tools/fuzzing/README.md new file mode 100644 index 000000000000..d914fb95093f --- /dev/null +++ b/quiche/quic/test_tools/fuzzing/README.md @@ -0,0 +1,22 @@ +This directory contains several fuzz tests for QUIC code: + +- quic_framer_fuzzer: A test for CryptoFramer::ParseMessage and + QuicFramer::ProcessPacket using random packet data. +- quic_framer_process_data_packet_fuzzer: A test for QuicFramer::ProcessPacket + where the packet has a valid public header, is decryptable, and contains + random QUIC payload. + +To build and run the fuzz tests, using quic_framer_fuzzer as an example: + +```sh +$ blaze build --config=asan-fuzzer //gfe/quic/test_tools/fuzzing/... +$ CORPUS_DIR=`mktemp -d` && echo ${CORPUS_DIR} +$ ./blaze-bin/gfe/quic/test_tools/fuzzing/quic_framer_fuzzer ${CORPUS_DIR} -use_counters=0 +``` + +By default this fuzzes with 64 byte chunks, to test the framer with more +realistic size input, try 1350 (max payload size of a QUIC packet): + +```sh +$ ./blaze-bin/gfe/quic/test_tools/fuzzing/quic_framer_fuzzer ${CORPUS_DIR} -use_counters=0 -max_len=1350 +``` diff --git a/quiche/quic/test_tools/fuzzing/quic_framer_fuzzer.cc b/quiche/quic/test_tools/fuzzing/quic_framer_fuzzer.cc new file mode 100644 index 000000000000..7b1e09f68e5e --- /dev/null +++ b/quiche/quic/test_tools/fuzzing/quic_framer_fuzzer.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + quic::QuicFramer framer(quic::AllSupportedVersions(), quic::QuicTime::Zero(), + quic::Perspective::IS_SERVER, + quic::kQuicDefaultConnectionIdLength); + const char* const packet_bytes = reinterpret_cast(data); + + // Test the CryptoFramer. + absl::string_view crypto_input(packet_bytes, size); + std::unique_ptr handshake_message( + quic::CryptoFramer::ParseMessage(crypto_input)); + + // Test the regular QuicFramer with the same input. + quic::test::NoOpFramerVisitor visitor; + framer.set_visitor(&visitor); + quic::QuicEncryptedPacket packet(packet_bytes, size); + framer.ProcessPacket(packet); + + return 0; +} diff --git a/quiche/quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc b/quiche/quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc new file mode 100644 index 000000000000..d3aebf7ac839 --- /dev/null +++ b/quiche/quic/test_tools/fuzzing/quic_framer_process_data_packet_fuzzer.cc @@ -0,0 +1,285 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using quic::DiversificationNonce; +using quic::EncryptionLevel; +using quic::FirstSendingPacketNumber; +using quic::GetPacketHeaderSize; +using quic::kEthernetMTU; +using quic::kQuicDefaultConnectionIdLength; +using quic::NullDecrypter; +using quic::NullEncrypter; +using quic::PacketHeaderFormat; +using quic::ParsedQuicVersion; +using quic::ParsedQuicVersionVector; +using quic::Perspective; +using quic::QuicConnectionId; +using quic::QuicDataReader; +using quic::QuicDataWriter; +using quic::QuicEncryptedPacket; +using quic::QuicFramer; +using quic::QuicFramerVisitorInterface; +using quic::QuicLongHeaderType; +using quic::QuicPacketHeader; +using quic::QuicPacketNumber; +using quic::QuicTime; +using quic::QuicTransportVersion; +using quic::test::NoOpFramerVisitor; +using quic::test::QuicFramerPeer; + +PacketHeaderFormat ConsumePacketHeaderFormat(FuzzedDataProvider* provider, + ParsedQuicVersion version) { + if (!version.HasIetfInvariantHeader()) { + return quic::GOOGLE_QUIC_PACKET; + } + return provider->ConsumeBool() ? quic::IETF_QUIC_LONG_HEADER_PACKET + : quic::IETF_QUIC_SHORT_HEADER_PACKET; +} + +ParsedQuicVersion ConsumeParsedQuicVersion(FuzzedDataProvider* provider) { + // TODO(wub): Add support for v49+. + const QuicTransportVersion transport_versions[] = { + quic::QUIC_VERSION_43, + quic::QUIC_VERSION_46, + }; + + return ParsedQuicVersion( + quic::PROTOCOL_QUIC_CRYPTO, + transport_versions[provider->ConsumeIntegralInRange( + 0, ABSL_ARRAYSIZE(transport_versions) - 1)]); +} + +// QuicSelfContainedPacketHeader is a QuicPacketHeader with built-in stroage for +// diversification nonce. +struct QuicSelfContainedPacketHeader : public QuicPacketHeader { + DiversificationNonce nonce_storage; +}; + +// Construct a random data packet header that 1) can be successfully serialized +// at sender, and 2) the serialzied buffer can pass the receiver framer's +// ProcessPublicHeader and DecryptPayload functions. +QuicSelfContainedPacketHeader ConsumeQuicPacketHeader( + FuzzedDataProvider* provider, Perspective receiver_perspective) { + QuicSelfContainedPacketHeader header; + + header.version = ConsumeParsedQuicVersion(provider); + + header.form = ConsumePacketHeaderFormat(provider, header.version); + + const std::string cid_bytes = + provider->ConsumeBytesAsString(kQuicDefaultConnectionIdLength); + if (receiver_perspective == Perspective::IS_SERVER) { + header.destination_connection_id = + QuicConnectionId(cid_bytes.c_str(), cid_bytes.size()); + header.destination_connection_id_included = quic::CONNECTION_ID_PRESENT; + header.source_connection_id_included = quic::CONNECTION_ID_ABSENT; + } else { + header.source_connection_id = + QuicConnectionId(cid_bytes.c_str(), cid_bytes.size()); + header.source_connection_id_included = quic::CONNECTION_ID_PRESENT; + header.destination_connection_id_included = quic::CONNECTION_ID_ABSENT; + } + + header.version_flag = receiver_perspective == Perspective::IS_SERVER; + header.reset_flag = false; + + header.packet_number = + QuicPacketNumber(provider->ConsumeIntegral()); + if (header.packet_number < FirstSendingPacketNumber()) { + header.packet_number = FirstSendingPacketNumber(); + } + header.packet_number_length = quic::PACKET_4BYTE_PACKET_NUMBER; + + header.remaining_packet_length = 0; + + if (header.form != quic::GOOGLE_QUIC_PACKET && header.version_flag) { + header.long_packet_type = static_cast( + provider->ConsumeIntegralInRange( + // INITIAL, ZERO_RTT_PROTECTED, or HANDSHAKE. + static_cast(quic::INITIAL), + static_cast(quic::HANDSHAKE))); + } else { + header.long_packet_type = quic::INVALID_PACKET_TYPE; + } + + if (header.form == quic::IETF_QUIC_LONG_HEADER_PACKET && + header.long_packet_type == quic::ZERO_RTT_PROTECTED && + receiver_perspective == Perspective::IS_CLIENT && + header.version.handshake_protocol == quic::PROTOCOL_QUIC_CRYPTO) { + for (size_t i = 0; i < header.nonce_storage.size(); ++i) { + header.nonce_storage[i] = provider->ConsumeIntegral(); + } + header.nonce = &header.nonce_storage; + } else { + header.nonce = nullptr; + } + + return header; +} + +void SetupFramer(QuicFramer* framer, QuicFramerVisitorInterface* visitor) { + framer->set_visitor(visitor); + for (EncryptionLevel level : + {quic::ENCRYPTION_INITIAL, quic::ENCRYPTION_HANDSHAKE, + quic::ENCRYPTION_ZERO_RTT, quic::ENCRYPTION_FORWARD_SECURE}) { + framer->SetEncrypter( + level, std::make_unique(framer->perspective())); + if (framer->version().KnowsWhichDecrypterToUse()) { + framer->InstallDecrypter( + level, std::make_unique(framer->perspective())); + } + } + + if (!framer->version().KnowsWhichDecrypterToUse()) { + framer->SetDecrypter( + quic::ENCRYPTION_INITIAL, + std::make_unique(framer->perspective())); + } +} + +class FuzzingFramerVisitor : public NoOpFramerVisitor { + public: + // Called after a successful ProcessPublicHeader. + bool OnUnauthenticatedPublicHeader( + const QuicPacketHeader& /*header*/) override { + ++process_public_header_success_count_; + return true; + } + + // Called after a successful DecryptPayload. + bool OnPacketHeader(const QuicPacketHeader& /*header*/) override { + ++decrypted_packet_count_; + return true; + } + + uint64_t process_public_header_success_count_ = 0; + uint64_t decrypted_packet_count_ = 0; +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + FuzzedDataProvider data_provider(data, size); + + const QuicTime creation_time = + QuicTime::Zero() + QuicTime::Delta::FromMicroseconds( + data_provider.ConsumeIntegral()); + Perspective receiver_perspective = data_provider.ConsumeBool() + ? Perspective::IS_CLIENT + : Perspective::IS_SERVER; + Perspective sender_perspective = + (receiver_perspective == Perspective::IS_CLIENT) ? Perspective::IS_SERVER + : Perspective::IS_CLIENT; + + QuicSelfContainedPacketHeader header = + ConsumeQuicPacketHeader(&data_provider, receiver_perspective); + + NoOpFramerVisitor sender_framer_visitor; + ParsedQuicVersionVector framer_versions = {header.version}; + QuicFramer sender_framer(framer_versions, creation_time, sender_perspective, + kQuicDefaultConnectionIdLength); + SetupFramer(&sender_framer, &sender_framer_visitor); + + FuzzingFramerVisitor receiver_framer_visitor; + QuicFramer receiver_framer(framer_versions, creation_time, + receiver_perspective, + kQuicDefaultConnectionIdLength); + SetupFramer(&receiver_framer, &receiver_framer_visitor); + if (receiver_perspective == Perspective::IS_CLIENT) { + QuicFramerPeer::SetLastSerializedServerConnectionId( + &receiver_framer, header.source_connection_id); + } else { + QuicFramerPeer::SetLastSerializedClientConnectionId( + &receiver_framer, header.source_connection_id); + } + + std::array packet_buffer; + while (data_provider.remaining_bytes() > 16) { + const size_t last_remaining_bytes = data_provider.remaining_bytes(); + + // Get a randomized packet size. + uint16_t max_payload_size = static_cast( + std::min(data_provider.remaining_bytes(), 1350u)); + uint16_t min_payload_size = std::min(16u, max_payload_size); + uint16_t payload_size = data_provider.ConsumeIntegralInRange( + min_payload_size, max_payload_size); + + QUICHE_CHECK_NE(last_remaining_bytes, data_provider.remaining_bytes()) + << "Check fail to avoid an infinite loop. ConsumeIntegralInRange(" + << min_payload_size << ", " << max_payload_size + << ") did not consume any bytes. remaining_bytes:" + << last_remaining_bytes; + + std::vector payload_buffer = + data_provider.ConsumeBytes(payload_size); + QUICHE_CHECK_GE( + packet_buffer.size(), + GetPacketHeaderSize(sender_framer.transport_version(), header) + + payload_buffer.size()); + + // Serialize the null-encrypted packet into |packet_buffer|. + QuicDataWriter writer(packet_buffer.size(), packet_buffer.data()); + size_t length_field_offset = 0; + QUICHE_CHECK(sender_framer.AppendPacketHeader(header, &writer, + &length_field_offset)); + + QUICHE_CHECK( + writer.WriteBytes(payload_buffer.data(), payload_buffer.size())); + + EncryptionLevel encryption_level = + quic::test::HeaderToEncryptionLevel(header); + QUICHE_CHECK(sender_framer.WriteIetfLongHeaderLength( + header, &writer, length_field_offset, encryption_level)); + + size_t encrypted_length = sender_framer.EncryptInPlace( + encryption_level, header.packet_number, + GetStartOfEncryptedData(sender_framer.transport_version(), header), + writer.length(), packet_buffer.size(), packet_buffer.data()); + QUICHE_CHECK_NE(encrypted_length, 0u); + + // Use receiver's framer to process the packet. Ensure both + // ProcessPublicHeader and DecryptPayload were called and succeeded. + QuicEncryptedPacket packet(packet_buffer.data(), encrypted_length); + QuicDataReader reader(packet.data(), packet.length()); + + const uint64_t process_public_header_success_count = + receiver_framer_visitor.process_public_header_success_count_; + const uint64_t decrypted_packet_count = + receiver_framer_visitor.decrypted_packet_count_; + + receiver_framer.ProcessPacket(packet); + + QUICHE_DCHECK_EQ( + process_public_header_success_count + 1, + receiver_framer_visitor.process_public_header_success_count_) + << "ProcessPublicHeader failed. error:" + << QuicErrorCodeToString(receiver_framer.error()) + << ", error_detail:" << receiver_framer.detailed_error() + << ". header:" << header; + QUICHE_DCHECK_EQ(decrypted_packet_count + 1, + receiver_framer_visitor.decrypted_packet_count_) + << "Packet was not decrypted. error:" + << QuicErrorCodeToString(receiver_framer.error()) + << ", error_detail:" << receiver_framer.detailed_error() + << ". header:" << header; + } + return 0; +} diff --git a/quiche/quic/test_tools/limited_mtu_test_writer.cc b/quiche/quic/test_tools/limited_mtu_test_writer.cc new file mode 100644 index 000000000000..fa46be157fbd --- /dev/null +++ b/quiche/quic/test_tools/limited_mtu_test_writer.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/limited_mtu_test_writer.h" + +namespace quic { +namespace test { + +LimitedMtuTestWriter::LimitedMtuTestWriter(QuicByteCount mtu) : mtu_(mtu) {} + +LimitedMtuTestWriter::~LimitedMtuTestWriter() = default; + +WriteResult LimitedMtuTestWriter::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + if (buf_len > mtu_) { + // Drop the packet. + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + return QuicPacketWriterWrapper::WritePacket(buffer, buf_len, self_address, + peer_address, options); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/limited_mtu_test_writer.h b/quiche/quic/test_tools/limited_mtu_test_writer.h new file mode 100644 index 000000000000..96cc82800f0f --- /dev/null +++ b/quiche/quic/test_tools/limited_mtu_test_writer.h @@ -0,0 +1,36 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_LIMITED_MTU_TEST_WRITER_H_ +#define QUICHE_QUIC_TEST_TOOLS_LIMITED_MTU_TEST_WRITER_H_ + +#include "quiche/quic/core/quic_packet_writer_wrapper.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { +namespace test { + +// Simulates a connection over a link with fixed MTU. Drops packets which +// exceed the MTU and passes the rest of them as-is. +class LimitedMtuTestWriter : public QuicPacketWriterWrapper { + public: + explicit LimitedMtuTestWriter(QuicByteCount mtu); + LimitedMtuTestWriter(const LimitedMtuTestWriter&) = delete; + LimitedMtuTestWriter& operator=(const LimitedMtuTestWriter&) = delete; + ~LimitedMtuTestWriter() override; + + // Inherited from QuicPacketWriterWrapper. + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + private: + QuicByteCount mtu_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_LIMITED_MTU_TEST_WRITER_H_ diff --git a/quiche/quic/test_tools/mock_clock.cc b/quiche/quic/test_tools/mock_clock.cc new file mode 100644 index 000000000000..da5f604578d6 --- /dev/null +++ b/quiche/quic/test_tools/mock_clock.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { + +MockClock::MockClock() : now_(QuicTime::Zero()) {} + +MockClock::~MockClock() {} + +void MockClock::AdvanceTime(QuicTime::Delta delta) { now_ = now_ + delta; } + +void MockClock::Reset() { now_ = QuicTime::Zero(); } + +QuicTime MockClock::Now() const { return now_; } + +QuicTime MockClock::ApproximateNow() const { return now_; } + +QuicWallTime MockClock::WallNow() const { + return QuicWallTime::FromUNIXSeconds((now_ - QuicTime::Zero()).ToSeconds()); +} + +} // namespace quic diff --git a/quiche/quic/test_tools/mock_clock.h b/quiche/quic/test_tools/mock_clock.h new file mode 100644 index 000000000000..cbb08c00f67d --- /dev/null +++ b/quiche/quic/test_tools/mock_clock.h @@ -0,0 +1,36 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_CLOCK_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_CLOCK_H_ + +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_time.h" + +namespace quic { + +class MockClock : public QuicClock { + public: + MockClock(); + MockClock(const MockClock&) = delete; + MockClock& operator=(const MockClock&) = delete; + ~MockClock() override; + + // QuicClock implementation: + QuicTime Now() const override; + QuicTime ApproximateNow() const override; + QuicWallTime WallNow() const override; + + // Advances the current time by |delta|, which may be negative. + void AdvanceTime(QuicTime::Delta delta); + // Resets time back to zero. + void Reset(); + + private: + QuicTime now_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_CLOCK_H_ diff --git a/quiche/quic/test_tools/mock_connection_id_generator.h b/quiche/quic/test_tools/mock_connection_id_generator.h new file mode 100644 index 000000000000..42209d687873 --- /dev/null +++ b/quiche/quic/test_tools/mock_connection_id_generator.h @@ -0,0 +1,31 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_CONNECTION_ID_GENERATOR_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_CONNECTION_ID_GENERATOR_H_ + +#include "quiche/quic/core/connection_id_generator.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class MockConnectionIdGenerator : public quic::ConnectionIdGeneratorInterface { + public: + MOCK_METHOD(absl::optional, GenerateNextConnectionId, + (const quic::QuicConnectionId& original), (override)); + + MOCK_METHOD(absl::optional, MaybeReplaceConnectionId, + (const quic::QuicConnectionId& original, + const quic::ParsedQuicVersion& version), + (override)); + + MOCK_METHOD(uint8_t, ConnectionIdLength, (uint8_t first_byte), + (const, override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_CONNECTION_ID_GENERATOR_H_ diff --git a/quiche/quic/test_tools/mock_quic_client_promised_info.cc b/quiche/quic/test_tools/mock_quic_client_promised_info.cc new file mode 100644 index 000000000000..6729819e240c --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_client_promised_info.cc @@ -0,0 +1,17 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_quic_client_promised_info.h" + +namespace quic { +namespace test { + +MockQuicClientPromisedInfo::MockQuicClientPromisedInfo( + QuicSpdyClientSessionBase* session, QuicStreamId id, std::string url) + : QuicClientPromisedInfo(session, id, url) {} + +MockQuicClientPromisedInfo::~MockQuicClientPromisedInfo() {} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/mock_quic_client_promised_info.h b/quiche/quic/test_tools/mock_quic_client_promised_info.h new file mode 100644 index 000000000000..acaefeff9c42 --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_client_promised_info.h @@ -0,0 +1,33 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_CLIENT_PROMISED_INFO_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_CLIENT_PROMISED_INFO_H_ + +#include + +#include "quiche/quic/core/http/quic_client_promised_info.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace test { + +class MockQuicClientPromisedInfo : public QuicClientPromisedInfo { + public: + MockQuicClientPromisedInfo(QuicSpdyClientSessionBase* session, + QuicStreamId id, std::string url); + ~MockQuicClientPromisedInfo() override; + + MOCK_METHOD(QuicAsyncStatus, HandleClientRequest, + (const spdy::Http2HeaderBlock& headers, + QuicClientPushPromiseIndex::Delegate*), + (override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_CLIENT_PROMISED_INFO_H_ diff --git a/quiche/quic/test_tools/mock_quic_dispatcher.cc b/quiche/quic/test_tools/mock_quic_dispatcher.cc new file mode 100644 index 000000000000..527047e9c7ca --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_dispatcher.cc @@ -0,0 +1,28 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_quic_dispatcher.h" + +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +MockQuicDispatcher::MockQuicDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QuicSimpleServerBackend* quic_simple_server_backend, + ConnectionIdGeneratorInterface& generator) + : QuicSimpleDispatcher(config, crypto_config, version_manager, + std::move(helper), std::move(session_helper), + std::move(alarm_factory), quic_simple_server_backend, + kQuicDefaultConnectionIdLength, generator) {} + +MockQuicDispatcher::~MockQuicDispatcher() {} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/mock_quic_dispatcher.h b/quiche/quic/test_tools/mock_quic_dispatcher.h new file mode 100644 index 000000000000..32b4c8504155 --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_dispatcher.h @@ -0,0 +1,43 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_DISPATCHER_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_DISPATCHER_H_ + +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/tools/quic_simple_dispatcher.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" + +namespace quic { +namespace test { + +class MockQuicDispatcher : public QuicSimpleDispatcher { + public: + MockQuicDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QuicSimpleServerBackend* quic_simple_server_backend, + ConnectionIdGeneratorInterface& generator); + MockQuicDispatcher(const MockQuicDispatcher&) = delete; + MockQuicDispatcher& operator=(const MockQuicDispatcher&) = delete; + + ~MockQuicDispatcher() override; + + MOCK_METHOD(void, ProcessPacket, + (const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + const QuicReceivedPacket& packet), + (override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_DISPATCHER_H_ diff --git a/quiche/quic/test_tools/mock_quic_session_visitor.cc b/quiche/quic/test_tools/mock_quic_session_visitor.cc new file mode 100644 index 000000000000..5f1f75a96624 --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_session_visitor.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_quic_session_visitor.h" + +namespace quic { +namespace test { + +MockQuicSessionVisitor::MockQuicSessionVisitor() = default; + +MockQuicSessionVisitor::~MockQuicSessionVisitor() = default; + +MockQuicCryptoServerStreamHelper::MockQuicCryptoServerStreamHelper() = default; + +MockQuicCryptoServerStreamHelper::~MockQuicCryptoServerStreamHelper() = default; + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/mock_quic_session_visitor.h b/quiche/quic/test_tools/mock_quic_session_visitor.h new file mode 100644 index 000000000000..ab230b466446 --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_session_visitor.h @@ -0,0 +1,62 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_SESSION_VISITOR_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_SESSION_VISITOR_H_ + +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_time_wait_list_manager.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class MockQuicSessionVisitor : public QuicTimeWaitListManager::Visitor { + public: + MockQuicSessionVisitor(); + MockQuicSessionVisitor(const MockQuicSessionVisitor&) = delete; + MockQuicSessionVisitor& operator=(const MockQuicSessionVisitor&) = delete; + ~MockQuicSessionVisitor() override; + MOCK_METHOD(void, OnConnectionClosed, + (QuicConnectionId connection_id, QuicErrorCode error, + const std::string& error_details, ConnectionCloseSource source), + (override)); + MOCK_METHOD(void, OnWriteBlocked, (QuicBlockedWriterInterface*), (override)); + MOCK_METHOD(void, OnRstStreamReceived, (const QuicRstStreamFrame& frame), + (override)); + MOCK_METHOD(void, OnStopSendingReceived, (const QuicStopSendingFrame& frame), + (override)); + MOCK_METHOD(bool, TryAddNewConnectionId, + (const QuicConnectionId& server_connection_id, + const QuicConnectionId& new_connection_id), + (override)); + MOCK_METHOD(void, OnConnectionIdRetired, + (const quic::QuicConnectionId& server_connection_id), (override)); + MOCK_METHOD(void, OnConnectionAddedToTimeWaitList, + (QuicConnectionId connection_id), (override)); + MOCK_METHOD(void, OnServerPreferredAddressAvailable, + (const QuicSocketAddress& server_preferred_address), (override)); +}; + +class MockQuicCryptoServerStreamHelper + : public QuicCryptoServerStreamBase::Helper { + public: + MockQuicCryptoServerStreamHelper(); + MockQuicCryptoServerStreamHelper(const MockQuicCryptoServerStreamHelper&) = + delete; + MockQuicCryptoServerStreamHelper& operator=( + const MockQuicCryptoServerStreamHelper&) = delete; + ~MockQuicCryptoServerStreamHelper() override; + MOCK_METHOD(bool, CanAcceptClientHello, + (const CryptoHandshakeMessage& message, + const QuicSocketAddress& client_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& self_address, std::string*), + (const, override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_SESSION_VISITOR_H_ diff --git a/quiche/quic/test_tools/mock_quic_spdy_client_stream.cc b/quiche/quic/test_tools/mock_quic_spdy_client_stream.cc new file mode 100644 index 000000000000..ac9604fdb2dc --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_spdy_client_stream.cc @@ -0,0 +1,17 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_quic_spdy_client_stream.h" + +namespace quic { +namespace test { + +MockQuicSpdyClientStream::MockQuicSpdyClientStream( + QuicStreamId id, QuicSpdyClientSession* session, StreamType type) + : QuicSpdyClientStream(id, session, type) {} + +MockQuicSpdyClientStream::~MockQuicSpdyClientStream() {} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/mock_quic_spdy_client_stream.h b/quiche/quic/test_tools/mock_quic_spdy_client_stream.h new file mode 100644 index 000000000000..10c9c6df29bf --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_spdy_client_stream.h @@ -0,0 +1,33 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_SPDY_CLIENT_STREAM_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_SPDY_CLIENT_STREAM_H_ + +#include "quiche/quic/core/http/quic_header_list.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class MockQuicSpdyClientStream : public QuicSpdyClientStream { + public: + MockQuicSpdyClientStream(QuicStreamId id, QuicSpdyClientSession* session, + StreamType type); + ~MockQuicSpdyClientStream() override; + + MOCK_METHOD(void, OnStreamFrame, (const QuicStreamFrame& frame), (override)); + MOCK_METHOD(void, OnPromiseHeaderList, + (QuicStreamId promised_stream_id, size_t frame_len, + const QuicHeaderList& list), + (override)); + MOCK_METHOD(void, OnDataAvailable, (), (override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_SPDY_CLIENT_STREAM_H_ diff --git a/quiche/quic/test_tools/mock_quic_time_wait_list_manager.cc b/quiche/quic/test_tools/mock_quic_time_wait_list_manager.cc new file mode 100644 index 000000000000..c345bf3b7df6 --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_time_wait_list_manager.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_quic_time_wait_list_manager.h" + +using testing::_; +using testing::Invoke; + +namespace quic { +namespace test { + +MockTimeWaitListManager::MockTimeWaitListManager( + QuicPacketWriter* writer, Visitor* visitor, const QuicClock* clock, + QuicAlarmFactory* alarm_factory) + : QuicTimeWaitListManager(writer, visitor, clock, alarm_factory) { + // Though AddConnectionIdToTimeWait is mocked, we want to retain its + // functionality. + EXPECT_CALL(*this, AddConnectionIdToTimeWait(_, _)) + .Times(testing::AnyNumber()); + ON_CALL(*this, AddConnectionIdToTimeWait(_, _)) + .WillByDefault( + Invoke(this, &MockTimeWaitListManager:: + QuicTimeWaitListManager_AddConnectionIdToTimeWait)); +} + +MockTimeWaitListManager::~MockTimeWaitListManager() = default; + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/mock_quic_time_wait_list_manager.h b/quiche/quic/test_tools/mock_quic_time_wait_list_manager.h new file mode 100644 index 000000000000..5218f2f66342 --- /dev/null +++ b/quiche/quic/test_tools/mock_quic_time_wait_list_manager.h @@ -0,0 +1,63 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_TIME_WAIT_LIST_MANAGER_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_TIME_WAIT_LIST_MANAGER_H_ + +#include "quiche/quic/core/quic_time_wait_list_manager.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class MockTimeWaitListManager : public QuicTimeWaitListManager { + public: + MockTimeWaitListManager(QuicPacketWriter* writer, Visitor* visitor, + const QuicClock* clock, + QuicAlarmFactory* alarm_factory); + ~MockTimeWaitListManager() override; + + MOCK_METHOD(void, AddConnectionIdToTimeWait, + (QuicTimeWaitListManager::TimeWaitAction action, + quic::TimeWaitConnectionInfo info), + (override)); + + void QuicTimeWaitListManager_AddConnectionIdToTimeWait( + QuicTimeWaitListManager::TimeWaitAction action, + quic::TimeWaitConnectionInfo info) { + QuicTimeWaitListManager::AddConnectionIdToTimeWait(action, std::move(info)); + } + + MOCK_METHOD(void, ProcessPacket, + (const QuicSocketAddress&, const QuicSocketAddress&, + QuicConnectionId, PacketHeaderFormat, size_t, + std::unique_ptr), + (override)); + + MOCK_METHOD(void, SendVersionNegotiationPacket, + (QuicConnectionId server_connection_id, + QuicConnectionId client_connection_id, bool ietf_quic, + bool has_length_prefix, + const ParsedQuicVersionVector& supported_versions, + const QuicSocketAddress& server_address, + const QuicSocketAddress& client_address, + std::unique_ptr packet_context), + (override)); + + MOCK_METHOD(void, SendPublicReset, + (const QuicSocketAddress&, const QuicSocketAddress&, + QuicConnectionId, bool, size_t, + std::unique_ptr), + (override)); + + MOCK_METHOD(void, SendPacket, + (const QuicSocketAddress&, const QuicSocketAddress&, + const QuicEncryptedPacket&), + (override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_QUIC_TIME_WAIT_LIST_MANAGER_H_ diff --git a/quiche/quic/test_tools/mock_random.cc b/quiche/quic/test_tools/mock_random.cc new file mode 100644 index 000000000000..2c45e65de3da --- /dev/null +++ b/quiche/quic/test_tools/mock_random.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/mock_random.h" + +#include + +namespace quic { +namespace test { + +using testing::_; +using testing::Invoke; + +MockRandom::MockRandom() : MockRandom(0xDEADBEEF) {} + +MockRandom::MockRandom(uint32_t base) : base_(base), increment_(0) { + ON_CALL(*this, RandBytes(_, _)) + .WillByDefault(Invoke(this, &MockRandom::DefaultRandBytes)); + ON_CALL(*this, RandUint64()) + .WillByDefault(Invoke(this, &MockRandom::DefaultRandUint64)); + ON_CALL(*this, InsecureRandBytes(_, _)) + .WillByDefault(Invoke(this, &MockRandom::DefaultInsecureRandBytes)); + ON_CALL(*this, InsecureRandUint64()) + .WillByDefault(Invoke(this, &MockRandom::DefaultInsecureRandUint64)); +} + +void MockRandom::DefaultRandBytes(void* data, size_t len) { + memset(data, increment_ + static_cast('r'), len); +} + +uint64_t MockRandom::DefaultRandUint64() { return base_ + increment_; } + +void MockRandom::DefaultInsecureRandBytes(void* data, size_t len) { + DefaultRandBytes(data, len); +} + +uint64_t MockRandom::DefaultInsecureRandUint64() { return DefaultRandUint64(); } + +void MockRandom::ChangeValue() { increment_++; } + +void MockRandom::ResetBase(uint32_t base) { + base_ = base; + increment_ = 0; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/mock_random.h b/quiche/quic/test_tools/mock_random.h new file mode 100644 index 000000000000..0a4918d0ddc0 --- /dev/null +++ b/quiche/quic/test_tools/mock_random.h @@ -0,0 +1,57 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_MOCK_RANDOM_H_ +#define QUICHE_QUIC_TEST_TOOLS_MOCK_RANDOM_H_ + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class MockRandom : public QuicRandom { + public: + // Initializes base_ to 0xDEADBEEF. + MockRandom(); + explicit MockRandom(uint32_t base); + MockRandom(const MockRandom&) = delete; + MockRandom& operator=(const MockRandom&) = delete; + + MOCK_METHOD(void, RandBytes, (void* data, size_t len), (override)); + MOCK_METHOD(uint64_t, RandUint64, (), (override)); + MOCK_METHOD(void, InsecureRandBytes, (void* data, size_t len), (override)); + MOCK_METHOD(uint64_t, InsecureRandUint64, (), (override)); + + // Default QuicRandom implementations. They are used if the caller does not + // setup the MockRandom via EXPECT_CALLs. + + // Fills the |data| buffer with a repeating byte, initially 'r'. + void DefaultRandBytes(void* data, size_t len); + // Returns base + the current increment. + uint64_t DefaultRandUint64(); + + // InsecureRandBytes behaves equivalently to RandBytes. + void DefaultInsecureRandBytes(void* data, size_t len); + // InsecureRandUint64 behaves equivalently to RandUint64. + uint64_t DefaultInsecureRandUint64(); + + // ChangeValue increments |increment_|. This causes the value returned by + // |RandUint64| and the byte that |RandBytes| fills with, to change. + // Used by the Default implementations. + void ChangeValue(); + + // Sets the base to |base| and resets increment to zero. + // Used by the Default implementations. + void ResetBase(uint32_t base); + + private: + uint32_t base_; + uint8_t increment_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_MOCK_RANDOM_H_ diff --git a/quiche/quic/test_tools/packet_dropping_test_writer.cc b/quiche/quic/test_tools/packet_dropping_test_writer.cc new file mode 100644 index 000000000000..69aae1456f21 --- /dev/null +++ b/quiche/quic/test_tools/packet_dropping_test_writer.cc @@ -0,0 +1,252 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/packet_dropping_test_writer.h" + +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace test { + +// Every dropped packet must be followed by this number of succesfully written +// packets. This is to avoid flaky test failures and timeouts, for example, in +// case both the client and the server drop every other packet (which is +// statistically possible even if drop percentage is less than 50%). +const int32_t kMinSuccesfulWritesAfterPacketLoss = 2; + +// An alarm that is scheduled if a blocked socket is simulated to indicate +// it's writable again. +class WriteUnblockedAlarm : public QuicAlarm::DelegateWithoutContext { + public: + explicit WriteUnblockedAlarm(PacketDroppingTestWriter* writer) + : writer_(writer) {} + + void OnAlarm() override { + QUIC_DLOG(INFO) << "Unblocking socket."; + writer_->OnCanWrite(); + } + + private: + PacketDroppingTestWriter* writer_; +}; + +// An alarm that is scheduled every time a new packet is to be written at a +// later point. +class DelayAlarm : public QuicAlarm::DelegateWithoutContext { + public: + explicit DelayAlarm(PacketDroppingTestWriter* writer) : writer_(writer) {} + + void OnAlarm() override { + QuicTime new_deadline = writer_->ReleaseOldPackets(); + if (new_deadline.IsInitialized()) { + writer_->SetDelayAlarm(new_deadline); + } + } + + private: + PacketDroppingTestWriter* writer_; +}; + +PacketDroppingTestWriter::PacketDroppingTestWriter() + : clock_(nullptr), + cur_buffer_size_(0), + num_calls_to_write_(0), + // Do not require any number of successful writes before the first dropped + // packet. + num_consecutive_succesful_writes_(kMinSuccesfulWritesAfterPacketLoss), + fake_packet_loss_percentage_(0), + fake_drop_first_n_packets_(0), + fake_blocked_socket_percentage_(0), + fake_packet_reorder_percentage_(0), + fake_packet_delay_(QuicTime::Delta::Zero()), + fake_bandwidth_(QuicBandwidth::Zero()), + buffer_size_(0) { + uint64_t seed = QuicRandom::GetInstance()->RandUint64(); + QUIC_LOG(INFO) << "Seeding packet loss with " << seed; + simple_random_.set_seed(seed); +} + +PacketDroppingTestWriter::~PacketDroppingTestWriter() { + if (write_unblocked_alarm_ != nullptr) { + write_unblocked_alarm_->PermanentCancel(); + } + if (delay_alarm_ != nullptr) { + delay_alarm_->PermanentCancel(); + } +} + +void PacketDroppingTestWriter::Initialize( + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + std::unique_ptr on_can_write) { + clock_ = helper->GetClock(); + write_unblocked_alarm_.reset( + alarm_factory->CreateAlarm(new WriteUnblockedAlarm(this))); + delay_alarm_.reset(alarm_factory->CreateAlarm(new DelayAlarm(this))); + on_can_write_ = std::move(on_can_write); +} + +WriteResult PacketDroppingTestWriter::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + ++num_calls_to_write_; + ReleaseOldPackets(); + + QuicWriterMutexLock lock(&config_mutex_); + if (fake_drop_first_n_packets_ > 0 && + num_calls_to_write_ <= + static_cast(fake_drop_first_n_packets_)) { + QUIC_DVLOG(1) << "Dropping first " << fake_drop_first_n_packets_ + << " packets (packet number " << num_calls_to_write_ << ")"; + num_consecutive_succesful_writes_ = 0; + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + // Drop every packet at 100%, otherwise always succeed for at least + // kMinSuccesfulWritesAfterPacketLoss packets between two dropped ones. + if (fake_packet_loss_percentage_ == 100 || + (fake_packet_loss_percentage_ > 0 && + num_consecutive_succesful_writes_ >= + kMinSuccesfulWritesAfterPacketLoss && + (simple_random_.RandUint64() % 100 < + static_cast(fake_packet_loss_percentage_)))) { + QUIC_DVLOG(1) << "Dropping packet " << num_calls_to_write_; + num_consecutive_succesful_writes_ = 0; + return WriteResult(WRITE_STATUS_OK, buf_len); + } else { + ++num_consecutive_succesful_writes_; + } + + if (fake_blocked_socket_percentage_ > 0 && + simple_random_.RandUint64() % 100 < + static_cast(fake_blocked_socket_percentage_)) { + QUICHE_CHECK(on_can_write_ != nullptr); + QUIC_DVLOG(1) << "Blocking socket for packet " << num_calls_to_write_; + if (!write_unblocked_alarm_->IsSet()) { + // Set the alarm to fire immediately. + write_unblocked_alarm_->Set(clock_->ApproximateNow()); + } + + // Dropping this packet on retry could result in PTO timeout, + // make sure to avoid this. + num_consecutive_succesful_writes_ = 0; + + return WriteResult(WRITE_STATUS_BLOCKED, EAGAIN); + } + + if (!fake_packet_delay_.IsZero() || !fake_bandwidth_.IsZero()) { + if (buffer_size_ > 0 && buf_len + cur_buffer_size_ > buffer_size_) { + // Drop packets which do not fit into the buffer. + QUIC_DVLOG(1) << "Dropping packet because the buffer is full."; + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + // Queue it to be sent. + QuicTime send_time = clock_->ApproximateNow() + fake_packet_delay_; + if (!fake_bandwidth_.IsZero()) { + // Calculate a time the bandwidth limit would impose. + QuicTime::Delta bandwidth_delay = QuicTime::Delta::FromMicroseconds( + (buf_len * kNumMicrosPerSecond) / fake_bandwidth_.ToBytesPerSecond()); + send_time = delayed_packets_.empty() + ? send_time + bandwidth_delay + : delayed_packets_.back().send_time + bandwidth_delay; + } + std::unique_ptr delayed_options; + if (options != nullptr) { + delayed_options = options->Clone(); + } + delayed_packets_.push_back( + DelayedWrite(buffer, buf_len, self_address, peer_address, + std::move(delayed_options), send_time)); + cur_buffer_size_ += buf_len; + + // Set the alarm if it's not yet set. + if (!delay_alarm_->IsSet()) { + delay_alarm_->Set(send_time); + } + + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + return QuicPacketWriterWrapper::WritePacket(buffer, buf_len, self_address, + peer_address, options); +} + +bool PacketDroppingTestWriter::IsWriteBlocked() const { + if (write_unblocked_alarm_ != nullptr && write_unblocked_alarm_->IsSet()) { + return true; + } + return QuicPacketWriterWrapper::IsWriteBlocked(); +} + +void PacketDroppingTestWriter::SetWritable() { + if (write_unblocked_alarm_ != nullptr && write_unblocked_alarm_->IsSet()) { + write_unblocked_alarm_->Cancel(); + } + QuicPacketWriterWrapper::SetWritable(); +} + +QuicTime PacketDroppingTestWriter::ReleaseNextPacket() { + if (delayed_packets_.empty()) { + return QuicTime::Zero(); + } + QuicReaderMutexLock lock(&config_mutex_); + auto iter = delayed_packets_.begin(); + // Determine if we should re-order. + if (delayed_packets_.size() > 1 && fake_packet_reorder_percentage_ > 0 && + simple_random_.RandUint64() % 100 < + static_cast(fake_packet_reorder_percentage_)) { + QUIC_DLOG(INFO) << "Reordering packets."; + ++iter; + // Swap the send times when re-ordering packets. + delayed_packets_.begin()->send_time = iter->send_time; + } + + QUIC_DVLOG(1) << "Releasing packet. " << (delayed_packets_.size() - 1) + << " remaining."; + // Grab the next one off the queue and send it. + QuicPacketWriterWrapper::WritePacket( + iter->buffer.data(), iter->buffer.length(), iter->self_address, + iter->peer_address, iter->options.get()); + QUICHE_DCHECK_GE(cur_buffer_size_, iter->buffer.length()); + cur_buffer_size_ -= iter->buffer.length(); + delayed_packets_.erase(iter); + + // If there are others, find the time for the next to be sent. + if (delayed_packets_.empty()) { + return QuicTime::Zero(); + } + return delayed_packets_.begin()->send_time; +} + +QuicTime PacketDroppingTestWriter::ReleaseOldPackets() { + while (!delayed_packets_.empty()) { + QuicTime next_send_time = delayed_packets_.front().send_time; + if (next_send_time > clock_->Now()) { + return next_send_time; + } + ReleaseNextPacket(); + } + return QuicTime::Zero(); +} + +void PacketDroppingTestWriter::SetDelayAlarm(QuicTime new_deadline) { + delay_alarm_->Set(new_deadline); +} + +void PacketDroppingTestWriter::OnCanWrite() { on_can_write_->OnCanWrite(); } + +PacketDroppingTestWriter::DelayedWrite::DelayedWrite( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + std::unique_ptr options, QuicTime send_time) + : buffer(buffer, buf_len), + self_address(self_address), + peer_address(peer_address), + options(std::move(options)), + send_time(send_time) {} + +PacketDroppingTestWriter::DelayedWrite::~DelayedWrite() = default; + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/packet_dropping_test_writer.h b/quiche/quic/test_tools/packet_dropping_test_writer.h new file mode 100644 index 000000000000..a7e91d3465c9 --- /dev/null +++ b/quiche/quic/test_tools/packet_dropping_test_writer.h @@ -0,0 +1,185 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_PACKET_DROPPING_TEST_WRITER_H_ +#define QUICHE_QUIC_TEST_TOOLS_PACKET_DROPPING_TEST_WRITER_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_packet_writer_wrapper.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +// Simulates a connection that drops packets a configured percentage of the time +// and has a blocked socket a configured percentage of the time. Also provides +// the options to delay packets and reorder packets if delay is enabled. +class PacketDroppingTestWriter : public QuicPacketWriterWrapper { + public: + class Delegate { + public: + virtual ~Delegate() {} + virtual void OnCanWrite() = 0; + }; + + PacketDroppingTestWriter(); + PacketDroppingTestWriter(const PacketDroppingTestWriter&) = delete; + PacketDroppingTestWriter& operator=(const PacketDroppingTestWriter&) = delete; + + ~PacketDroppingTestWriter() override; + + // Must be called before blocking, reordering or delaying (loss is OK). May be + // called after connecting if the helper is not available before. + // |on_can_write| will be triggered when fake-unblocking. + void Initialize(QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + std::unique_ptr on_can_write); + + // QuicPacketWriter methods: + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + bool IsWriteBlocked() const override; + + void SetWritable() override; + + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) override { + // If the wrapped writer supports zero-copy, disable it, because it is not + // compatible with delayed writes in this class. + return {nullptr, nullptr}; + } + + // Writes out any packet which should have been sent by now + // to the contained writer and returns the time + // for the next delayed packet to be written. + QuicTime ReleaseOldPackets(); + + // Sets |delay_alarm_| to fire at |new_deadline|. + void SetDelayAlarm(QuicTime new_deadline); + + void OnCanWrite(); + + // The percent of time a packet is simulated as being lost. + // If |fake_packet_loss_percentage| is 100, then all packages are lost. + // Otherwise actual percentage will be lower than + // |fake_packet_loss_percentage|, because every dropped package is followed by + // a minimum number of successfully written packets. + void set_fake_packet_loss_percentage(int32_t fake_packet_loss_percentage) { + QuicWriterMutexLock lock(&config_mutex_); + fake_packet_loss_percentage_ = fake_packet_loss_percentage; + } + + // Simulate dropping the first n packets unconditionally. + // Subsequent packets will be lost at fake_packet_loss_percentage_ if set. + void set_fake_drop_first_n_packets(int32_t fake_drop_first_n_packets) { + QuicWriterMutexLock lock(&config_mutex_); + fake_drop_first_n_packets_ = fake_drop_first_n_packets; + } + + // The percent of time WritePacket will block and set WriteResult's status + // to WRITE_STATUS_BLOCKED. + void set_fake_blocked_socket_percentage( + int32_t fake_blocked_socket_percentage) { + QUICHE_DCHECK(clock_); + QuicWriterMutexLock lock(&config_mutex_); + fake_blocked_socket_percentage_ = fake_blocked_socket_percentage; + } + + // The percent of time a packet is simulated as being reordered. + void set_fake_reorder_percentage(int32_t fake_packet_reorder_percentage) { + QUICHE_DCHECK(clock_); + QuicWriterMutexLock lock(&config_mutex_); + QUICHE_DCHECK(!fake_packet_delay_.IsZero()); + fake_packet_reorder_percentage_ = fake_packet_reorder_percentage; + } + + // The delay before writing this packet. + void set_fake_packet_delay(QuicTime::Delta fake_packet_delay) { + QUICHE_DCHECK(clock_); + QuicWriterMutexLock lock(&config_mutex_); + fake_packet_delay_ = fake_packet_delay; + } + + // The maximum bandwidth and buffer size of the connection. When these are + // set, packets will be delayed until a connection with that bandwidth would + // transmit it. Once the |buffer_size| is reached, all new packets are + // dropped. + void set_max_bandwidth_and_buffer_size(QuicBandwidth fake_bandwidth, + QuicByteCount buffer_size) { + QUICHE_DCHECK(clock_); + QuicWriterMutexLock lock(&config_mutex_); + fake_bandwidth_ = fake_bandwidth; + buffer_size_ = buffer_size; + } + + // Useful for reproducing very flaky issues. + ABSL_ATTRIBUTE_UNUSED void set_seed(uint64_t seed) { + simple_random_.set_seed(seed); + } + + private: + // Writes out the next packet to the contained writer and returns the time + // for the next delayed packet to be written. + QuicTime ReleaseNextPacket(); + + // A single packet which will be sent at the supplied send_time. + struct DelayedWrite { + public: + DelayedWrite(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + std::unique_ptr options, QuicTime send_time); + DelayedWrite(const DelayedWrite&) = delete; + DelayedWrite(DelayedWrite&&) = default; + DelayedWrite& operator=(const DelayedWrite&) = delete; + DelayedWrite& operator=(DelayedWrite&&) = default; + ~DelayedWrite(); + + std::string buffer; + QuicIpAddress self_address; + QuicSocketAddress peer_address; + std::unique_ptr options; + QuicTime send_time; + }; + + using DelayedPacketList = std::list; + + const QuicClock* clock_; + std::unique_ptr write_unblocked_alarm_; + std::unique_ptr delay_alarm_; + std::unique_ptr on_can_write_; + SimpleRandom simple_random_; + // Stored packets delayed by fake packet delay or bandwidth restrictions. + DelayedPacketList delayed_packets_; + QuicByteCount cur_buffer_size_; + uint64_t num_calls_to_write_; + int32_t num_consecutive_succesful_writes_; + + QuicMutex config_mutex_; + int32_t fake_packet_loss_percentage_ QUIC_GUARDED_BY(config_mutex_); + int32_t fake_drop_first_n_packets_ QUIC_GUARDED_BY(config_mutex_); + int32_t fake_blocked_socket_percentage_ QUIC_GUARDED_BY(config_mutex_); + int32_t fake_packet_reorder_percentage_ QUIC_GUARDED_BY(config_mutex_); + QuicTime::Delta fake_packet_delay_ QUIC_GUARDED_BY(config_mutex_); + QuicBandwidth fake_bandwidth_ QUIC_GUARDED_BY(config_mutex_); + QuicByteCount buffer_size_ QUIC_GUARDED_BY(config_mutex_); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_PACKET_DROPPING_TEST_WRITER_H_ diff --git a/quiche/quic/test_tools/packet_reordering_writer.cc b/quiche/quic/test_tools/packet_reordering_writer.cc new file mode 100644 index 000000000000..8eb8573a7937 --- /dev/null +++ b/quiche/quic/test_tools/packet_reordering_writer.cc @@ -0,0 +1,51 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/packet_reordering_writer.h" + +namespace quic { +namespace test { + +PacketReorderingWriter::PacketReorderingWriter() = default; + +PacketReorderingWriter::~PacketReorderingWriter() = default; + +WriteResult PacketReorderingWriter::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions* options) { + if (!delay_next_) { + QUIC_VLOG(2) << "Writing a non-delayed packet"; + WriteResult wr = QuicPacketWriterWrapper::WritePacket( + buffer, buf_len, self_address, peer_address, options); + --num_packets_to_wait_; + if (num_packets_to_wait_ == 0) { + QUIC_VLOG(2) << "Writing a delayed packet"; + // It's time to write the delayed packet. + QuicPacketWriterWrapper::WritePacket( + delayed_data_.data(), delayed_data_.length(), delayed_self_address_, + delayed_peer_address_, delayed_options_.get()); + } + return wr; + } + // Still have packet to wait. + QUICHE_DCHECK_LT(0u, num_packets_to_wait_) + << "Only allow one packet to be delayed"; + delayed_data_ = std::string(buffer, buf_len); + delayed_self_address_ = self_address; + delayed_peer_address_ = peer_address; + if (options != nullptr) { + delayed_options_ = options->Clone(); + } + delay_next_ = false; + return WriteResult(WRITE_STATUS_OK, buf_len); +} + +void PacketReorderingWriter::SetDelay(size_t num_packets_to_wait) { + QUICHE_DCHECK_GT(num_packets_to_wait, 0u); + num_packets_to_wait_ = num_packets_to_wait; + delay_next_ = true; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/packet_reordering_writer.h b/quiche/quic/test_tools/packet_reordering_writer.h new file mode 100644 index 000000000000..53204c54a816 --- /dev/null +++ b/quiche/quic/test_tools/packet_reordering_writer.h @@ -0,0 +1,44 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_PACKET_REORDERING_WRITER_H_ +#define QUICHE_QUIC_TEST_TOOLS_PACKET_REORDERING_WRITER_H_ + +#include "quiche/quic/core/quic_packet_writer_wrapper.h" + +namespace quic { + +namespace test { + +// This packet writer allows delaying writing the next packet after +// SetDelay(num_packets_to_wait) +// is called and buffer this packet and write it after it writes next +// |num_packets_to_wait| packets. It doesn't support delaying a packet while +// there is already a packet delayed. +class PacketReorderingWriter : public QuicPacketWriterWrapper { + public: + PacketReorderingWriter(); + + ~PacketReorderingWriter() override; + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + void SetDelay(size_t num_packets_to_wait); + + private: + bool delay_next_ = false; + size_t num_packets_to_wait_ = 0; + std::string delayed_data_; + QuicIpAddress delayed_self_address_; + QuicSocketAddress delayed_peer_address_; + std::unique_ptr delayed_options_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_PACKET_REORDERING_WRITER_H_ diff --git a/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc b/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc new file mode 100644 index 000000000000..2b39dcdf7c35 --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +TestHeadersHandler::TestHeadersHandler() + : decoding_completed_(false), decoding_error_detected_(false) {} + +void TestHeadersHandler::OnHeaderDecoded(absl::string_view name, + absl::string_view value) { + ASSERT_FALSE(decoding_completed_); + ASSERT_FALSE(decoding_error_detected_); + + header_list_.AppendValueOrAddHeader(name, value); +} + +void TestHeadersHandler::OnDecodingCompleted() { + ASSERT_FALSE(decoding_completed_); + ASSERT_FALSE(decoding_error_detected_); + + decoding_completed_ = true; +} + +void TestHeadersHandler::OnDecodingErrorDetected( + QuicErrorCode /*error_code*/, absl::string_view error_message) { + ASSERT_FALSE(decoding_completed_); + ASSERT_FALSE(decoding_error_detected_); + + decoding_error_detected_ = true; + error_message_.assign(error_message.data(), error_message.size()); +} + +spdy::Http2HeaderBlock TestHeadersHandler::ReleaseHeaderList() { + QUICHE_DCHECK(decoding_completed_); + QUICHE_DCHECK(!decoding_error_detected_); + + return std::move(header_list_); +} + +bool TestHeadersHandler::decoding_completed() const { + return decoding_completed_; +} + +bool TestHeadersHandler::decoding_error_detected() const { + return decoding_error_detected_; +} + +const std::string& TestHeadersHandler::error_message() const { + QUICHE_DCHECK(decoding_error_detected_); + return error_message_; +} + +void QpackDecode( + uint64_t maximum_dynamic_table_capacity, uint64_t maximum_blocked_streams, + QpackDecoder::EncoderStreamErrorDelegate* encoder_stream_error_delegate, + QpackStreamSenderDelegate* decoder_stream_sender_delegate, + QpackProgressiveDecoder::HeadersHandlerInterface* handler, + const FragmentSizeGenerator& fragment_size_generator, + absl::string_view data) { + QpackDecoder decoder(maximum_dynamic_table_capacity, maximum_blocked_streams, + encoder_stream_error_delegate); + decoder.set_qpack_stream_sender_delegate(decoder_stream_sender_delegate); + auto progressive_decoder = + decoder.CreateProgressiveDecoder(/* stream_id = */ 1, handler); + while (!data.empty()) { + size_t fragment_size = std::min(fragment_size_generator(), data.size()); + progressive_decoder->Decode(data.substr(0, fragment_size)); + data = data.substr(fragment_size); + } + progressive_decoder->EndHeaderBlock(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h b/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h new file mode 100644 index 000000000000..f746550bc6a3 --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h @@ -0,0 +1,101 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_DECODER_TEST_UTILS_H_ +#define QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_DECODER_TEST_UTILS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/qpack/qpack_progressive_decoder.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace test { + +// Mock QpackDecoder::EncoderStreamErrorDelegate implementation. +class MockEncoderStreamErrorDelegate + : public QpackDecoder::EncoderStreamErrorDelegate { + public: + ~MockEncoderStreamErrorDelegate() override = default; + + MOCK_METHOD(void, OnEncoderStreamError, + (QuicErrorCode error_code, absl::string_view error_message), + (override)); +}; + +// HeadersHandlerInterface implementation that collects decoded headers +// into a Http2HeaderBlock. +class TestHeadersHandler + : public QpackProgressiveDecoder::HeadersHandlerInterface { + public: + TestHeadersHandler(); + ~TestHeadersHandler() override = default; + + // HeadersHandlerInterface implementation: + void OnHeaderDecoded(absl::string_view name, + absl::string_view value) override; + void OnDecodingCompleted() override; + void OnDecodingErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) override; + + // Release decoded header list. Must only be called if decoding is complete + // and no errors have been detected. + spdy::Http2HeaderBlock ReleaseHeaderList(); + + bool decoding_completed() const; + bool decoding_error_detected() const; + const std::string& error_message() const; + + private: + spdy::Http2HeaderBlock header_list_; + bool decoding_completed_; + bool decoding_error_detected_; + std::string error_message_; +}; + +class MockHeadersHandler + : public QpackProgressiveDecoder::HeadersHandlerInterface { + public: + MockHeadersHandler() = default; + MockHeadersHandler(const MockHeadersHandler&) = delete; + MockHeadersHandler& operator=(const MockHeadersHandler&) = delete; + ~MockHeadersHandler() override = default; + + MOCK_METHOD(void, OnHeaderDecoded, + (absl::string_view name, absl::string_view value), (override)); + MOCK_METHOD(void, OnDecodingCompleted, (), (override)); + MOCK_METHOD(void, OnDecodingErrorDetected, + (QuicErrorCode error_code, absl::string_view error_message), + (override)); +}; + +class NoOpHeadersHandler + : public QpackProgressiveDecoder::HeadersHandlerInterface { + public: + ~NoOpHeadersHandler() override = default; + + void OnHeaderDecoded(absl::string_view /*name*/, + absl::string_view /*value*/) override {} + void OnDecodingCompleted() override {} + void OnDecodingErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override {} +}; + +void QpackDecode( + uint64_t maximum_dynamic_table_capacity, uint64_t maximum_blocked_streams, + QpackDecoder::EncoderStreamErrorDelegate* encoder_stream_error_delegate, + QpackStreamSenderDelegate* decoder_stream_sender_delegate, + QpackProgressiveDecoder::HeadersHandlerInterface* handler, + const FragmentSizeGenerator& fragment_size_generator, + absl::string_view data); + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_DECODER_TEST_UTILS_H_ diff --git a/quiche/quic/test_tools/qpack/qpack_encoder_peer.cc b/quiche/quic/test_tools/qpack/qpack_encoder_peer.cc new file mode 100644 index 000000000000..73894370346b --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_encoder_peer.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/qpack/qpack_encoder_peer.h" + +#include "quiche/quic/core/qpack/qpack_encoder.h" + +namespace quic { +namespace test { + +// static +QpackEncoderHeaderTable* QpackEncoderPeer::header_table(QpackEncoder* encoder) { + return &encoder->header_table_; +} + +// static +uint64_t QpackEncoderPeer::maximum_blocked_streams( + const QpackEncoder* encoder) { + return encoder->maximum_blocked_streams_; +} + +// static +uint64_t QpackEncoderPeer::smallest_blocking_index( + const QpackEncoder* encoder) { + return encoder->blocking_manager_.smallest_blocking_index(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/qpack/qpack_encoder_peer.h b/quiche/quic/test_tools/qpack/qpack_encoder_peer.h new file mode 100644 index 000000000000..94a308af9806 --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_encoder_peer.h @@ -0,0 +1,30 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_ENCODER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_ENCODER_PEER_H_ + +#include + +namespace quic { + +class QpackEncoder; +class QpackEncoderHeaderTable; + +namespace test { + +class QpackEncoderPeer { + public: + QpackEncoderPeer() = delete; + + static QpackEncoderHeaderTable* header_table(QpackEncoder* encoder); + static uint64_t maximum_blocked_streams(const QpackEncoder* encoder); + static uint64_t smallest_blocking_index(const QpackEncoder* encoder); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_ENCODER_PEER_H_ diff --git a/quiche/quic/test_tools/qpack/qpack_offline_decoder.cc b/quiche/quic/test_tools/qpack/qpack_offline_decoder.cc new file mode 100644 index 000000000000..282e33f5c92f --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_offline_decoder.cc @@ -0,0 +1,336 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Decoder to test QPACK Offline Interop corpus +// +// See https://github.com/quicwg/base-drafts/wiki/QPACK-Offline-Interop for +// description of test data format. +// +// Example usage +// +// cd $TEST_DATA +// git clone https://github.com/qpackers/qifs.git +// TEST_ENCODED_DATA=$TEST_DATA/qifs/encoded/qpack-06 +// TEST_QIF_DATA=$TEST_DATA/qifs/qifs +// $BIN/qpack_offline_decoder \ +// $TEST_ENCODED_DATA/f5/fb-req.qifencoded.4096.100.0 \ +// $TEST_QIF_DATA/fb-req.qif +// $TEST_ENCODED_DATA/h2o/fb-req-hq.out.512.0.1 \ +// $TEST_QIF_DATA/fb-req-hq.qif +// $TEST_ENCODED_DATA/ls-qpack/fb-resp-hq.out.0.0.0 \ +// $TEST_QIF_DATA/fb-resp-hq.qif +// $TEST_ENCODED_DATA/proxygen/netbsd.qif.proxygen.out.4096.0.0 \ +// $TEST_QIF_DATA/netbsd.qif +// + +#include "quiche/quic/test_tools/qpack/qpack_offline_decoder.h" + +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/common/platform/api/quiche_file_utils.h" +#include "quiche/common/quiche_endian.h" + +namespace quic { + +QpackOfflineDecoder::QpackOfflineDecoder() + : encoder_stream_error_detected_(false) {} + +bool QpackOfflineDecoder::DecodeAndVerifyOfflineData( + absl::string_view input_filename, + absl::string_view expected_headers_filename) { + if (!ParseInputFilename(input_filename)) { + QUIC_LOG(ERROR) << "Error parsing input filename " << input_filename; + return false; + } + + if (!DecodeHeaderBlocksFromFile(input_filename)) { + QUIC_LOG(ERROR) << "Error decoding header blocks in " << input_filename; + return false; + } + + if (!VerifyDecodedHeaderLists(expected_headers_filename)) { + QUIC_LOG(ERROR) << "Header lists decoded from " << input_filename + << " to not match expected headers parsed from " + << expected_headers_filename; + return false; + } + + return true; +} + +void QpackOfflineDecoder::OnEncoderStreamError( + QuicErrorCode error_code, absl::string_view error_message) { + QUIC_LOG(ERROR) << "Encoder stream error: " + << QuicErrorCodeToString(error_code) << " " << error_message; + encoder_stream_error_detected_ = true; +} + +bool QpackOfflineDecoder::ParseInputFilename(absl::string_view input_filename) { + std::vector pieces = absl::StrSplit(input_filename, '.'); + + if (pieces.size() < 3) { + QUIC_LOG(ERROR) << "Not enough fields in input filename " << input_filename; + return false; + } + + auto piece_it = pieces.rbegin(); + + // Acknowledgement mode: 1 for immediate, 0 for none. + if (*piece_it != "0" && *piece_it != "1") { + QUIC_LOG(ERROR) + << "Header acknowledgement field must be 0 or 1 in input filename " + << input_filename; + return false; + } + + ++piece_it; + + // Maximum allowed number of blocked streams. + uint64_t max_blocked_streams = 0; + if (!absl::SimpleAtoi(*piece_it, &max_blocked_streams)) { + QUIC_LOG(ERROR) << "Error parsing part of input filename \"" << *piece_it + << "\" as an integer."; + return false; + } + + ++piece_it; + + // Maximum Dynamic Table Capacity in bytes + uint64_t maximum_dynamic_table_capacity = 0; + if (!absl::SimpleAtoi(*piece_it, &maximum_dynamic_table_capacity)) { + QUIC_LOG(ERROR) << "Error parsing part of input filename \"" << *piece_it + << "\" as an integer."; + return false; + } + qpack_decoder_ = std::make_unique( + maximum_dynamic_table_capacity, max_blocked_streams, this); + qpack_decoder_->set_qpack_stream_sender_delegate( + &decoder_stream_sender_delegate_); + + // The initial dynamic table capacity is zero according to + // https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#eviction. + // However, for historical reasons, offline interop encoders use + // |maximum_dynamic_table_capacity| as initial capacity. + qpack_decoder_->OnSetDynamicTableCapacity(maximum_dynamic_table_capacity); + + return true; +} + +bool QpackOfflineDecoder::DecodeHeaderBlocksFromFile( + absl::string_view input_filename) { + // Store data in |input_data_storage|; use a absl::string_view to + // efficiently keep track of remaining portion yet to be decoded. + absl::optional input_data_storage = + quiche::ReadFileContents(input_filename); + QUICHE_DCHECK(input_data_storage.has_value()); + absl::string_view input_data(*input_data_storage); + + while (!input_data.empty()) { + // Parse stream_id and length. + if (input_data.size() < sizeof(uint64_t) + sizeof(uint32_t)) { + QUIC_LOG(ERROR) << "Unexpected end of input file."; + return false; + } + + uint64_t stream_id = quiche::QuicheEndian::NetToHost64( + *reinterpret_cast(input_data.data())); + input_data = input_data.substr(sizeof(uint64_t)); + + uint32_t length = quiche::QuicheEndian::NetToHost32( + *reinterpret_cast(input_data.data())); + input_data = input_data.substr(sizeof(uint32_t)); + + if (input_data.size() < length) { + QUIC_LOG(ERROR) << "Unexpected end of input file."; + return false; + } + + // Parse data. + absl::string_view data = input_data.substr(0, length); + input_data = input_data.substr(length); + + // Process data. + if (stream_id == 0) { + qpack_decoder_->encoder_stream_receiver()->Decode(data); + + if (encoder_stream_error_detected_) { + QUIC_LOG(ERROR) << "Error detected on encoder stream."; + return false; + } + } else { + auto headers_handler = std::make_unique(); + auto progressive_decoder = qpack_decoder_->CreateProgressiveDecoder( + stream_id, headers_handler.get()); + + progressive_decoder->Decode(data); + progressive_decoder->EndHeaderBlock(); + + if (headers_handler->decoding_error_detected()) { + QUIC_LOG(ERROR) << "Sync decoding error on stream " << stream_id << ": " + << headers_handler->error_message(); + return false; + } + + decoders_.push_back({std::move(headers_handler), + std::move(progressive_decoder), stream_id}); + } + + // Move decoded header lists from TestHeadersHandlers and append them to + // |decoded_header_lists_| while preserving the order in |decoders_|. + while (!decoders_.empty() && + decoders_.front().headers_handler->decoding_completed()) { + Decoder* decoder = &decoders_.front(); + + if (decoder->headers_handler->decoding_error_detected()) { + QUIC_LOG(ERROR) << "Async decoding error on stream " + << decoder->stream_id << ": " + << decoder->headers_handler->error_message(); + return false; + } + + if (!decoder->headers_handler->decoding_completed()) { + QUIC_LOG(ERROR) << "Decoding incomplete after reading entire" + " file, on stream " + << decoder->stream_id; + return false; + } + + decoded_header_lists_.push_back( + decoder->headers_handler->ReleaseHeaderList()); + decoders_.pop_front(); + } + } + + if (!decoders_.empty()) { + QUICHE_DCHECK(!decoders_.front().headers_handler->decoding_completed()); + + QUIC_LOG(ERROR) << "Blocked decoding uncomplete after reading entire" + " file, on stream " + << decoders_.front().stream_id; + return false; + } + + return true; +} + +bool QpackOfflineDecoder::VerifyDecodedHeaderLists( + absl::string_view expected_headers_filename) { + // Store data in |expected_headers_data_storage|; use a + // absl::string_view to efficiently keep track of remaining portion + // yet to be decoded. + absl::optional expected_headers_data_storage = + quiche::ReadFileContents(expected_headers_filename); + QUICHE_DCHECK(expected_headers_data_storage.has_value()); + absl::string_view expected_headers_data(*expected_headers_data_storage); + + while (!decoded_header_lists_.empty()) { + spdy::Http2HeaderBlock decoded_header_list = + std::move(decoded_header_lists_.front()); + decoded_header_lists_.pop_front(); + + spdy::Http2HeaderBlock expected_header_list; + if (!ReadNextExpectedHeaderList(&expected_headers_data, + &expected_header_list)) { + QUIC_LOG(ERROR) + << "Error parsing expected header list to match next decoded " + "header list."; + return false; + } + + if (!CompareHeaderBlocks(std::move(decoded_header_list), + std::move(expected_header_list))) { + QUIC_LOG(ERROR) << "Decoded header does not match expected header."; + return false; + } + } + + if (!expected_headers_data.empty()) { + QUIC_LOG(ERROR) + << "Not enough encoded header lists to match expected ones."; + return false; + } + + return true; +} + +bool QpackOfflineDecoder::ReadNextExpectedHeaderList( + absl::string_view* expected_headers_data, + spdy::Http2HeaderBlock* expected_header_list) { + while (true) { + absl::string_view::size_type endline = expected_headers_data->find('\n'); + + // Even last header list must be followed by an empty line. + if (endline == absl::string_view::npos) { + QUIC_LOG(ERROR) << "Unexpected end of expected header list file."; + return false; + } + + if (endline == 0) { + // Empty line indicates end of header list. + *expected_headers_data = expected_headers_data->substr(1); + return true; + } + + absl::string_view header_field = expected_headers_data->substr(0, endline); + std::vector pieces = absl::StrSplit(header_field, '\t'); + + if (pieces.size() != 2) { + QUIC_LOG(ERROR) << "Header key and value must be separated by TAB."; + return false; + } + + expected_header_list->AppendValueOrAddHeader(pieces[0], pieces[1]); + + *expected_headers_data = expected_headers_data->substr(endline + 1); + } +} + +bool QpackOfflineDecoder::CompareHeaderBlocks( + spdy::Http2HeaderBlock decoded_header_list, + spdy::Http2HeaderBlock expected_header_list) { + if (decoded_header_list == expected_header_list) { + return true; + } + + // The h2o decoder reshuffles the "content-length" header and pseudo-headers, + // see + // https://github.com/qpackers/qifs/blob/master/encoded/qpack-03/h2o/README.md. + // Remove such headers one by one if they match. + const char* kContentLength = "content-length"; + const char* kPseudoHeaderPrefix = ":"; + for (spdy::Http2HeaderBlock::iterator decoded_it = + decoded_header_list.begin(); + decoded_it != decoded_header_list.end();) { + const absl::string_view key = decoded_it->first; + if (key != kContentLength && !absl::StartsWith(key, kPseudoHeaderPrefix)) { + ++decoded_it; + continue; + } + spdy::Http2HeaderBlock::iterator expected_it = + expected_header_list.find(key); + if (expected_it == expected_header_list.end() || + decoded_it->second != expected_it->second) { + ++decoded_it; + continue; + } + // Http2HeaderBlock does not support erasing by iterator, only by key. + ++decoded_it; + expected_header_list.erase(key); + // This will invalidate |key|. + decoded_header_list.erase(key); + } + + return decoded_header_list == expected_header_list; +} + +} // namespace quic diff --git a/quiche/quic/test_tools/qpack/qpack_offline_decoder.h b/quiche/quic/test_tools/qpack/qpack_offline_decoder.h new file mode 100644 index 000000000000..07fa2769356b --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_offline_decoder.h @@ -0,0 +1,88 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_OFFLINE_DECODER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_OFFLINE_DECODER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_decoder.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// A decoder to read encoded data from a file, decode it, and compare to +// a list of expected header lists read from another file. File format is +// described at +// https://github.com/quicwg/base-drafts/wiki/QPACK-Offline-Interop. +class QpackOfflineDecoder : public QpackDecoder::EncoderStreamErrorDelegate { + public: + QpackOfflineDecoder(); + ~QpackOfflineDecoder() override = default; + + // Read encoded header blocks and encoder stream data from |input_filename| + // and decode them, read expected header lists from + // |expected_headers_filename|, and compare decoded header lists to expected + // ones. Returns true if there is an equal number of them and the + // corresponding ones match, false otherwise. + bool DecodeAndVerifyOfflineData(absl::string_view input_filename, + absl::string_view expected_headers_filename); + + // QpackDecoder::EncoderStreamErrorDelegate implementation: + void OnEncoderStreamError(QuicErrorCode error_code, + absl::string_view error_message) override; + + private: + // Data structure to hold TestHeadersHandler and QpackProgressiveDecoder until + // decoding of a header header block (and all preceding header blocks) is + // complete. + struct Decoder { + std::unique_ptr headers_handler; + std::unique_ptr progressive_decoder; + uint64_t stream_id; + }; + + // Parse decoder parameters from |input_filename| and set up |qpack_decoder_| + // accordingly. + bool ParseInputFilename(absl::string_view input_filename); + + // Read encoded header blocks and encoder stream data from |input_filename|, + // pass them to |qpack_decoder_| for decoding, and add decoded header lists to + // |decoded_header_lists_|. + bool DecodeHeaderBlocksFromFile(absl::string_view input_filename); + + // Read expected header lists from |expected_headers_filename| and verify + // decoded header lists in |decoded_header_lists_| against them. + bool VerifyDecodedHeaderLists(absl::string_view expected_headers_filename); + + // Parse next header list from |*expected_headers_data| into + // |*expected_header_list|, removing consumed data from the beginning of + // |*expected_headers_data|. Returns true on success, false if parsing fails. + bool ReadNextExpectedHeaderList(absl::string_view* expected_headers_data, + spdy::Http2HeaderBlock* expected_header_list); + + // Compare two header lists. Allow for different orders of certain headers as + // described at + // https://github.com/qpackers/qifs/blob/master/encoded/qpack-03/h2o/README.md. + bool CompareHeaderBlocks(spdy::Http2HeaderBlock decoded_header_list, + spdy::Http2HeaderBlock expected_header_list); + + bool encoder_stream_error_detected_; + test::NoopQpackStreamSenderDelegate decoder_stream_sender_delegate_; + std::unique_ptr qpack_decoder_; + + // Objects necessary for decoding, one list element for each header block. + std::list decoders_; + + // Decoded header lists. + std::list decoded_header_lists_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_OFFLINE_DECODER_H_ diff --git a/quiche/quic/test_tools/qpack/qpack_test_utils.cc b/quiche/quic/test_tools/qpack/qpack_test_utils.cc new file mode 100644 index 000000000000..3d7f20c1fc77 --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_test_utils.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/qpack/qpack_test_utils.h" + +#include + +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { +namespace test { + +FragmentSizeGenerator FragmentModeToFragmentSizeGenerator( + FragmentMode fragment_mode) { + switch (fragment_mode) { + case FragmentMode::kSingleChunk: + return []() { return std::numeric_limits::max(); }; + case FragmentMode::kOctetByOctet: + return []() { return 1; }; + } + QUIC_BUG(quic_bug_10259_1) + << "Unknown FragmentMode " << static_cast(fragment_mode); + return []() { return 0; }; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/qpack/qpack_test_utils.h b/quiche/quic/test_tools/qpack/qpack_test_utils.h new file mode 100644 index 000000000000..c6f1c657d58b --- /dev/null +++ b/quiche/quic/test_tools/qpack/qpack_test_utils.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_TEST_UTILS_H_ +#define QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_TEST_UTILS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/qpack/qpack_stream_sender_delegate.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +// Called repeatedly to determine the size of each fragment when encoding or +// decoding. Must return a positive value. +using FragmentSizeGenerator = std::function; + +enum class FragmentMode { + kSingleChunk, + kOctetByOctet, +}; + +FragmentSizeGenerator FragmentModeToFragmentSizeGenerator( + FragmentMode fragment_mode); + +// Mock QpackUnidirectionalStreamSenderDelegate implementation. +class MockQpackStreamSenderDelegate : public QpackStreamSenderDelegate { + public: + ~MockQpackStreamSenderDelegate() override = default; + + MOCK_METHOD(void, WriteStreamData, (absl::string_view data), (override)); + MOCK_METHOD(uint64_t, NumBytesBuffered, (), (const, override)); +}; + +class NoopQpackStreamSenderDelegate : public QpackStreamSenderDelegate { + public: + ~NoopQpackStreamSenderDelegate() override = default; + + void WriteStreamData(absl::string_view /*data*/) override {} + + uint64_t NumBytesBuffered() const override { return 0; } +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QPACK_QPACK_TEST_UTILS_H_ diff --git a/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc b/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc new file mode 100644 index 000000000000..4d11806fe59d --- /dev/null +++ b/quiche/quic/test_tools/quic_buffered_packet_store_peer.cc @@ -0,0 +1,25 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_buffered_packet_store_peer.h" + +#include "quiche/quic/core/quic_buffered_packet_store.h" + +namespace quic { +namespace test { + +// static +QuicAlarm* QuicBufferedPacketStorePeer::expiration_alarm( + QuicBufferedPacketStore* store) { + return store->expiration_alarm_.get(); +} + +// static +void QuicBufferedPacketStorePeer::set_clock(QuicBufferedPacketStore* store, + const QuicClock* clock) { + store->clock_ = clock; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_buffered_packet_store_peer.h b/quiche/quic/test_tools/quic_buffered_packet_store_peer.h new file mode 100644 index 000000000000..06102747cf28 --- /dev/null +++ b/quiche/quic/test_tools/quic_buffered_packet_store_peer.h @@ -0,0 +1,32 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_BUFFERED_PACKET_STORE_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_BUFFERED_PACKET_STORE_PEER_H_ + +#include + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/core/quic_clock.h" + +namespace quic { + +class QuicBufferedPacketStore; + +namespace test { + +class QuicBufferedPacketStorePeer { + public: + QuicBufferedPacketStorePeer() = delete; + + static QuicAlarm* expiration_alarm(QuicBufferedPacketStore* store); + + static void set_clock(QuicBufferedPacketStore* store, const QuicClock* clock); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_BUFFERED_PACKET_STORE_PEER_H_ diff --git a/quiche/quic/test_tools/quic_client_promised_info_peer.cc b/quiche/quic/test_tools/quic_client_promised_info_peer.cc new file mode 100644 index 000000000000..5ea080c58b68 --- /dev/null +++ b/quiche/quic/test_tools/quic_client_promised_info_peer.cc @@ -0,0 +1,17 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_client_promised_info_peer.h" + +namespace quic { +namespace test { + +// static +QuicAlarm* QuicClientPromisedInfoPeer::GetAlarm( + QuicClientPromisedInfo* promised_stream) { + return promised_stream->cleanup_alarm_.get(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_client_promised_info_peer.h b/quiche/quic/test_tools/quic_client_promised_info_peer.h new file mode 100644 index 000000000000..596200be79d1 --- /dev/null +++ b/quiche/quic/test_tools/quic_client_promised_info_peer.h @@ -0,0 +1,22 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_PROMISED_INFO_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_PROMISED_INFO_PEER_H_ + +#include "quiche/quic/core/http/quic_client_promised_info.h" + +namespace quic { +namespace test { + +class QuicClientPromisedInfoPeer { + public: + QuicClientPromisedInfoPeer() = delete; + + static QuicAlarm* GetAlarm(QuicClientPromisedInfo* promised_stream); +}; +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_PROMISED_INFO_PEER_H_ diff --git a/quiche/quic/test_tools/quic_client_session_cache_peer.h b/quiche/quic/test_tools/quic_client_session_cache_peer.h new file mode 100644 index 000000000000..b070ef69933d --- /dev/null +++ b/quiche/quic/test_tools/quic_client_session_cache_peer.h @@ -0,0 +1,33 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_SESSION_CACHE_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_SESSION_CACHE_PEER_H_ + +#include "quiche/quic/core/crypto/quic_client_session_cache.h" + +namespace quic { +namespace test { + +class QuicClientSessionCachePeer { + public: + static std::string GetToken(QuicClientSessionCache* cache, + const QuicServerId& server_id) { + auto iter = cache->cache_.Lookup(server_id); + if (iter == cache->cache_.end()) { + return {}; + } + return iter->second->token; + } + + static bool HasEntry(QuicClientSessionCache* cache, + const QuicServerId& server_id) { + return cache->cache_.Lookup(server_id) != cache->cache_.end(); + } +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_SESSION_CACHE_PEER_H_ diff --git a/quiche/quic/test_tools/quic_coalesced_packet_peer.cc b/quiche/quic/test_tools/quic_coalesced_packet_peer.cc new file mode 100644 index 000000000000..eeb16212484e --- /dev/null +++ b/quiche/quic/test_tools/quic_coalesced_packet_peer.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_coalesced_packet_peer.h" + +namespace quic { +namespace test { + +// static +void QuicCoalescedPacketPeer::SetMaxPacketLength( + QuicCoalescedPacket& coalesced_packet, QuicPacketLength length) { + coalesced_packet.max_packet_length_ = length; +} + +// static +std::string* QuicCoalescedPacketPeer::GetMutableEncryptedBuffer( + QuicCoalescedPacket& coalesced_packet, EncryptionLevel encryption_level) { + return &coalesced_packet.encrypted_buffers_[encryption_level]; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_coalesced_packet_peer.h b/quiche/quic/test_tools/quic_coalesced_packet_peer.h new file mode 100644 index 000000000000..c84c37c5a3d8 --- /dev/null +++ b/quiche/quic/test_tools/quic_coalesced_packet_peer.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_COALESCED_PACKET_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_COALESCED_PACKET_PEER_H_ + +#include "quiche/quic/core/quic_coalesced_packet.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { +namespace test { + +class QuicCoalescedPacketPeer { + public: + static void SetMaxPacketLength(QuicCoalescedPacket& coalesced_packet, + QuicPacketLength length); + + static std::string* GetMutableEncryptedBuffer( + QuicCoalescedPacket& coalesced_packet, EncryptionLevel encryption_level); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_COALESCED_PACKET_PEER_H_ diff --git a/quiche/quic/test_tools/quic_config_peer.cc b/quiche/quic/test_tools/quic_config_peer.cc new file mode 100644 index 000000000000..a3e9acc549ef --- /dev/null +++ b/quiche/quic/test_tools/quic_config_peer.cc @@ -0,0 +1,156 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_config_peer.h" + +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_id.h" + +namespace quic { +namespace test { + +// static +void QuicConfigPeer::SetReceivedInitialStreamFlowControlWindow( + QuicConfig* config, uint32_t window_bytes) { + config->initial_stream_flow_control_window_bytes_.SetReceivedValue( + window_bytes); +} + +// static +void QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + QuicConfig* config, uint32_t window_bytes) { + config->initial_max_stream_data_bytes_incoming_bidirectional_ + .SetReceivedValue(window_bytes); +} + +// static +void QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + QuicConfig* config, uint32_t window_bytes) { + config->initial_max_stream_data_bytes_outgoing_bidirectional_ + .SetReceivedValue(window_bytes); +} + +// static +void QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + QuicConfig* config, uint32_t window_bytes) { + config->initial_max_stream_data_bytes_unidirectional_.SetReceivedValue( + window_bytes); +} + +// static +void QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + QuicConfig* config, uint32_t window_bytes) { + config->initial_session_flow_control_window_bytes_.SetReceivedValue( + window_bytes); +} + +// static +void QuicConfigPeer::SetReceivedConnectionOptions( + QuicConfig* config, const QuicTagVector& options) { + config->connection_options_.SetReceivedValues(options); +} + +// static +void QuicConfigPeer::SetReceivedBytesForConnectionId(QuicConfig* config, + uint32_t bytes) { + QUICHE_DCHECK(bytes == 0 || bytes == 8); + config->bytes_for_connection_id_.SetReceivedValue(bytes); +} + +// static +void QuicConfigPeer::SetReceivedDisableConnectionMigration(QuicConfig* config) { + config->connection_migration_disabled_.SetReceivedValue(1); +} + +// static +void QuicConfigPeer::SetReceivedMaxBidirectionalStreams(QuicConfig* config, + uint32_t max_streams) { + config->max_bidirectional_streams_.SetReceivedValue(max_streams); +} +// static +void QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(QuicConfig* config, + uint32_t max_streams) { + config->max_unidirectional_streams_.SetReceivedValue(max_streams); +} + +// static +void QuicConfigPeer::SetConnectionOptionsToSend(QuicConfig* config, + const QuicTagVector& options) { + config->SetConnectionOptionsToSend(options); +} + +// static +void QuicConfigPeer::SetReceivedStatelessResetToken( + QuicConfig* config, const StatelessResetToken& token) { + config->stateless_reset_token_.SetReceivedValue(token); +} + +// static +void QuicConfigPeer::SetReceivedMaxPacketSize(QuicConfig* config, + uint32_t max_udp_payload_size) { + config->max_udp_payload_size_.SetReceivedValue(max_udp_payload_size); +} + +// static +void QuicConfigPeer::SetReceivedMinAckDelayMs(QuicConfig* config, + uint32_t min_ack_delay_ms) { + config->min_ack_delay_ms_.SetReceivedValue(min_ack_delay_ms); +} + +// static +void QuicConfigPeer::SetNegotiated(QuicConfig* config, bool negotiated) { + config->negotiated_ = negotiated; +} + +// static +void QuicConfigPeer::SetReceivedOriginalConnectionId( + QuicConfig* config, + const QuicConnectionId& original_destination_connection_id) { + config->received_original_destination_connection_id_ = + original_destination_connection_id; +} + +// static +void QuicConfigPeer::SetReceivedInitialSourceConnectionId( + QuicConfig* config, const QuicConnectionId& initial_source_connection_id) { + config->received_initial_source_connection_id_ = initial_source_connection_id; +} + +// static +void QuicConfigPeer::SetReceivedRetrySourceConnectionId( + QuicConfig* config, const QuicConnectionId& retry_source_connection_id) { + config->received_retry_source_connection_id_ = retry_source_connection_id; +} + +// static +void QuicConfigPeer::SetReceivedMaxDatagramFrameSize( + QuicConfig* config, uint64_t max_datagram_frame_size) { + config->max_datagram_frame_size_.SetReceivedValue(max_datagram_frame_size); +} + +// static +void QuicConfigPeer::SetReceivedAlternateServerAddress( + QuicConfig* config, const QuicSocketAddress& server_address) { + switch (server_address.host().address_family()) { + case quiche::IpAddressFamily::IP_V4: + config->alternate_server_address_ipv4_.SetReceivedValue(server_address); + break; + case quiche::IpAddressFamily::IP_V6: + config->alternate_server_address_ipv6_.SetReceivedValue(server_address); + break; + case quiche::IpAddressFamily::IP_UNSPEC: + break; + } +} + +// static +void QuicConfigPeer::SetPreferredAddressConnectionIdAndToken( + QuicConfig* config, QuicConnectionId connection_id, + const StatelessResetToken& token) { + config->preferred_address_connection_id_and_token_ = + std::make_pair(connection_id, token); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_config_peer.h b/quiche/quic/test_tools/quic_config_peer.h new file mode 100644 index 000000000000..25481e0b7cd6 --- /dev/null +++ b/quiche/quic/test_tools/quic_config_peer.h @@ -0,0 +1,90 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CONFIG_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CONFIG_PEER_H_ + +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +class QuicConfig; + +namespace test { + +class QuicConfigPeer { + public: + QuicConfigPeer() = delete; + + static void SetReceivedInitialStreamFlowControlWindow(QuicConfig* config, + uint32_t window_bytes); + + static void SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + QuicConfig* config, uint32_t window_bytes); + + static void SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + QuicConfig* config, uint32_t window_bytes); + + static void SetReceivedInitialMaxStreamDataBytesUnidirectional( + QuicConfig* config, uint32_t window_bytes); + + static void SetReceivedInitialSessionFlowControlWindow(QuicConfig* config, + uint32_t window_bytes); + + static void SetReceivedConnectionOptions(QuicConfig* config, + const QuicTagVector& options); + + static void SetReceivedBytesForConnectionId(QuicConfig* config, + uint32_t bytes); + + static void SetReceivedDisableConnectionMigration(QuicConfig* config); + + static void SetReceivedMaxBidirectionalStreams(QuicConfig* config, + uint32_t max_streams); + static void SetReceivedMaxUnidirectionalStreams(QuicConfig* config, + uint32_t max_streams); + + static void SetConnectionOptionsToSend(QuicConfig* config, + const QuicTagVector& options); + + static void SetReceivedStatelessResetToken(QuicConfig* config, + const StatelessResetToken& token); + + static void SetReceivedMaxPacketSize(QuicConfig* config, + uint32_t max_udp_payload_size); + + static void SetReceivedMinAckDelayMs(QuicConfig* config, + uint32_t min_ack_delay_ms); + + static void SetNegotiated(QuicConfig* config, bool negotiated); + + static void SetReceivedOriginalConnectionId( + QuicConfig* config, + const QuicConnectionId& original_destination_connection_id); + + static void SetReceivedInitialSourceConnectionId( + QuicConfig* config, const QuicConnectionId& initial_source_connection_id); + + static void SetReceivedRetrySourceConnectionId( + QuicConfig* config, const QuicConnectionId& retry_source_connection_id); + + static void SetReceivedMaxDatagramFrameSize(QuicConfig* config, + uint64_t max_datagram_frame_size); + + static void SetReceivedAlternateServerAddress( + QuicConfig* config, const QuicSocketAddress& server_address); + + static void SetPreferredAddressConnectionIdAndToken( + QuicConfig* config, QuicConnectionId connection_id, + const StatelessResetToken& token); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CONFIG_PEER_H_ diff --git a/quiche/quic/test_tools/quic_connection_id_manager_peer.h b/quiche/quic/test_tools/quic_connection_id_manager_peer.h new file mode 100644 index 000000000000..e7ac33021612 --- /dev/null +++ b/quiche/quic/test_tools/quic_connection_id_manager_peer.h @@ -0,0 +1,29 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CONNECTION_ID_MANAGER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CONNECTION_ID_MANAGER_PEER_H_ + +#include "quiche/quic/core/quic_connection_id_manager.h" + +namespace quic { +namespace test { + +class QuicConnectionIdManagerPeer { + public: + static QuicAlarm* GetRetirePeerIssuedConnectionIdAlarm( + QuicPeerIssuedConnectionIdManager* manager) { + return manager->retire_connection_id_alarm_.get(); + } + + static QuicAlarm* GetRetireSelfIssuedConnectionIdAlarm( + QuicSelfIssuedConnectionIdManager* manager) { + return manager->retire_connection_id_alarm_.get(); + } +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CONNECTION_ID_MANAGER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_connection_peer.cc b/quiche/quic/test_tools/quic_connection_peer.cc new file mode 100644 index 000000000000..8bddb19ab015 --- /dev/null +++ b/quiche/quic/test_tools/quic_connection_peer.cc @@ -0,0 +1,628 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_connection_peer.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_received_packet_manager.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/test_tools/quic_connection_id_manager_peer.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" + +namespace quic { +namespace test { + +// static +void QuicConnectionPeer::SetSendAlgorithm( + QuicConnection* connection, SendAlgorithmInterface* send_algorithm) { + GetSentPacketManager(connection)->SetSendAlgorithm(send_algorithm); +} + +// static +void QuicConnectionPeer::SetLossAlgorithm( + QuicConnection* connection, LossDetectionInterface* loss_algorithm) { + GetSentPacketManager(connection)->loss_algorithm_ = loss_algorithm; +} + +// static +void QuicConnectionPeer::PopulateStopWaitingFrame( + QuicConnection* connection, QuicStopWaitingFrame* stop_waiting) { + connection->PopulateStopWaitingFrame(stop_waiting); +} + +// static +QuicPacketCreator* QuicConnectionPeer::GetPacketCreator( + QuicConnection* connection) { + return &connection->packet_creator_; +} + +// static +QuicSentPacketManager* QuicConnectionPeer::GetSentPacketManager( + QuicConnection* connection) { + return &connection->sent_packet_manager_; +} + +// static +QuicTime::Delta QuicConnectionPeer::GetNetworkTimeout( + QuicConnection* connection) { + return connection->idle_network_detector_.idle_network_timeout_; +} + +// static +QuicTime::Delta QuicConnectionPeer::GetHandshakeTimeout( + QuicConnection* connection) { + return connection->idle_network_detector_.handshake_timeout_; +} + +// static +QuicTime::Delta QuicConnectionPeer::GetBandwidthUpdateTimeout( + QuicConnection* connection) { + return connection->idle_network_detector_.bandwidth_update_timeout_; +} + +// static +void QuicConnectionPeer::DisableBandwidthUpdate(QuicConnection* connection) { + if (connection->idle_network_detector_.bandwidth_update_timeout_ + .IsInfinite()) { + return; + } + connection->idle_network_detector_.bandwidth_update_timeout_ = + QuicTime::Delta::Infinite(); + connection->idle_network_detector_.SetAlarm(); +} + +// static +void QuicConnectionPeer::SetPerspective(QuicConnection* connection, + Perspective perspective) { + connection->perspective_ = perspective; + QuicFramerPeer::SetPerspective(&connection->framer_, perspective); + connection->ping_manager_.perspective_ = perspective; +} + +// static +void QuicConnectionPeer::SetSelfAddress(QuicConnection* connection, + const QuicSocketAddress& self_address) { + connection->default_path_.self_address = self_address; +} + +// static +void QuicConnectionPeer::SetPeerAddress(QuicConnection* connection, + const QuicSocketAddress& peer_address) { + connection->UpdatePeerAddress(peer_address); +} + +// static +void QuicConnectionPeer::SetDirectPeerAddress( + QuicConnection* connection, const QuicSocketAddress& direct_peer_address) { + connection->direct_peer_address_ = direct_peer_address; +} + +// static +void QuicConnectionPeer::SetEffectivePeerAddress( + QuicConnection* connection, + const QuicSocketAddress& effective_peer_address) { + connection->default_path_.peer_address = effective_peer_address; +} + +// static +void QuicConnectionPeer::SwapCrypters(QuicConnection* connection, + QuicFramer* framer) { + QuicFramerPeer::SwapCrypters(framer, &connection->framer_); +} + +// static +void QuicConnectionPeer::SetCurrentPacket(QuicConnection* connection, + absl::string_view current_packet) { + connection->current_packet_data_ = current_packet.data(); + connection->last_received_packet_info_.length = current_packet.size(); +} + +// static +QuicConnectionHelperInterface* QuicConnectionPeer::GetHelper( + QuicConnection* connection) { + return connection->helper_; +} + +// static +QuicAlarmFactory* QuicConnectionPeer::GetAlarmFactory( + QuicConnection* connection) { + return connection->alarm_factory_; +} + +// static +QuicFramer* QuicConnectionPeer::GetFramer(QuicConnection* connection) { + return &connection->framer_; +} + +// static +QuicAlarm* QuicConnectionPeer::GetAckAlarm(QuicConnection* connection) { + return connection->ack_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetPingAlarm(QuicConnection* connection) { + return connection->ping_manager_.alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetRetransmissionAlarm( + QuicConnection* connection) { + return connection->retransmission_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetSendAlarm(QuicConnection* connection) { + return connection->send_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetMtuDiscoveryAlarm( + QuicConnection* connection) { + return connection->mtu_discovery_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetProcessUndecryptablePacketsAlarm( + QuicConnection* connection) { + return connection->process_undecryptable_packets_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetDiscardPreviousOneRttKeysAlarm( + QuicConnection* connection) { + return connection->discard_previous_one_rtt_keys_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetDiscardZeroRttDecryptionKeysAlarm( + QuicConnection* connection) { + return connection->discard_zero_rtt_decryption_keys_alarm_.get(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetRetirePeerIssuedConnectionIdAlarm( + QuicConnection* connection) { + if (connection->peer_issued_cid_manager_ == nullptr) { + return nullptr; + } + return QuicConnectionIdManagerPeer::GetRetirePeerIssuedConnectionIdAlarm( + connection->peer_issued_cid_manager_.get()); +} +// static +QuicAlarm* QuicConnectionPeer::GetRetireSelfIssuedConnectionIdAlarm( + QuicConnection* connection) { + if (connection->self_issued_cid_manager_ == nullptr) { + return nullptr; + } + return QuicConnectionIdManagerPeer::GetRetireSelfIssuedConnectionIdAlarm( + connection->self_issued_cid_manager_.get()); +} + +// static +QuicPacketWriter* QuicConnectionPeer::GetWriter(QuicConnection* connection) { + return connection->writer_; +} + +// static +void QuicConnectionPeer::SetWriter(QuicConnection* connection, + QuicPacketWriter* writer, bool owns_writer) { + if (connection->owns_writer_) { + delete connection->writer_; + } + connection->writer_ = writer; + connection->owns_writer_ = owns_writer; +} + +// static +void QuicConnectionPeer::TearDownLocalConnectionState( + QuicConnection* connection) { + connection->connected_ = false; +} + +// static +QuicEncryptedPacket* QuicConnectionPeer::GetConnectionClosePacket( + QuicConnection* connection) { + if (connection->termination_packets_ == nullptr || + connection->termination_packets_->empty()) { + return nullptr; + } + return (*connection->termination_packets_)[0].get(); +} + +// static +QuicPacketHeader* QuicConnectionPeer::GetLastHeader( + QuicConnection* connection) { + return &connection->last_received_packet_info_.header; +} + +// static +QuicConnectionStats* QuicConnectionPeer::GetStats(QuicConnection* connection) { + return &connection->stats_; +} + +// static +QuicPacketCount QuicConnectionPeer::GetPacketsBetweenMtuProbes( + QuicConnection* connection) { + return connection->mtu_discoverer_.packets_between_probes(); +} + +// static +void QuicConnectionPeer::ReInitializeMtuDiscoverer( + QuicConnection* connection, QuicPacketCount packets_between_probes_base, + QuicPacketNumber next_probe_at) { + connection->mtu_discoverer_ = + QuicConnectionMtuDiscoverer(packets_between_probes_base, next_probe_at); +} + +// static +void QuicConnectionPeer::SetAckDecimationDelay(QuicConnection* connection, + float ack_decimation_delay) { + for (auto& received_packet_manager : + connection->uber_received_packet_manager_.received_packet_managers_) { + received_packet_manager.ack_decimation_delay_ = ack_decimation_delay; + } +} + +// static +bool QuicConnectionPeer::HasRetransmittableFrames(QuicConnection* connection, + uint64_t packet_number) { + return QuicSentPacketManagerPeer::HasRetransmittableFrames( + GetSentPacketManager(connection), packet_number); +} + +// static +bool QuicConnectionPeer::GetNoStopWaitingFrames(QuicConnection* connection) { + return connection->no_stop_waiting_frames_; +} + +// static +void QuicConnectionPeer::SetNoStopWaitingFrames(QuicConnection* connection, + bool no_stop_waiting_frames) { + connection->no_stop_waiting_frames_ = no_stop_waiting_frames; +} + +// static +void QuicConnectionPeer::SetMaxTrackedPackets( + QuicConnection* connection, QuicPacketCount max_tracked_packets) { + connection->max_tracked_packets_ = max_tracked_packets; +} + +// static +void QuicConnectionPeer::SetNegotiatedVersion(QuicConnection* connection) { + connection->version_negotiated_ = true; + if (connection->perspective() == Perspective::IS_SERVER && + !QuicFramerPeer::infer_packet_header_type_from_version( + &connection->framer_)) { + connection->framer_.InferPacketHeaderTypeFromVersion(); + } +} + +// static +void QuicConnectionPeer::SetMaxConsecutiveNumPacketsWithNoRetransmittableFrames( + QuicConnection* connection, size_t new_value) { + connection->max_consecutive_num_packets_with_no_retransmittable_frames_ = + new_value; +} + +// static +bool QuicConnectionPeer::SupportsReleaseTime(QuicConnection* connection) { + return connection->supports_release_time_; +} + +// static +QuicConnection::PacketContent QuicConnectionPeer::GetCurrentPacketContent( + QuicConnection* connection) { + return connection->current_packet_content_; +} + +// static +void QuicConnectionPeer::AddBytesReceived(QuicConnection* connection, + size_t length) { + if (connection->EnforceAntiAmplificationLimit()) { + connection->default_path_.bytes_received_before_address_validation += + length; + } +} + +// static +void QuicConnectionPeer::SetAddressValidated(QuicConnection* connection) { + connection->default_path_.validated = true; +} + +// static +void QuicConnectionPeer::SendConnectionClosePacket( + QuicConnection* connection, QuicIetfTransportErrorCodes ietf_error, + QuicErrorCode error, const std::string& details) { + connection->SendConnectionClosePacket(error, ietf_error, details); +} + +// static +size_t QuicConnectionPeer::GetNumEncryptionLevels(QuicConnection* connection) { + size_t count = 0; + for (EncryptionLevel level : + {ENCRYPTION_INITIAL, ENCRYPTION_HANDSHAKE, ENCRYPTION_ZERO_RTT, + ENCRYPTION_FORWARD_SECURE}) { + if (connection->framer_.HasEncrypterOfEncryptionLevel(level)) { + ++count; + } + } + return count; +} + +// static +QuicNetworkBlackholeDetector& QuicConnectionPeer::GetBlackholeDetector( + QuicConnection* connection) { + return connection->blackhole_detector_; +} + +// static +QuicAlarm* QuicConnectionPeer::GetBlackholeDetectorAlarm( + QuicConnection* connection) { + return connection->blackhole_detector_.alarm_.get(); +} + +// static +QuicTime QuicConnectionPeer::GetPathDegradingDeadline( + QuicConnection* connection) { + return connection->blackhole_detector_.path_degrading_deadline_; +} + +// static +QuicTime QuicConnectionPeer::GetBlackholeDetectionDeadline( + QuicConnection* connection) { + return connection->blackhole_detector_.blackhole_deadline_; +} + +// static +QuicTime QuicConnectionPeer::GetPathMtuReductionDetectionDeadline( + QuicConnection* connection) { + return connection->blackhole_detector_.path_mtu_reduction_deadline_; +} + +// static +QuicTime QuicConnectionPeer::GetIdleNetworkDeadline( + QuicConnection* connection) { + return connection->idle_network_detector_.GetIdleNetworkDeadline(); +} + +// static +QuicAlarm* QuicConnectionPeer::GetIdleNetworkDetectorAlarm( + QuicConnection* connection) { + return connection->idle_network_detector_.alarm_.get(); +} + +// static +QuicIdleNetworkDetector& QuicConnectionPeer::GetIdleNetworkDetector( + QuicConnection* connection) { + return connection->idle_network_detector_; +} + +// static +QuicAlarm* QuicConnectionPeer::GetMultiPortProbingAlarm( + QuicConnection* connection) { + return connection->multi_port_probing_alarm_.get(); +} + +// static +void QuicConnectionPeer::SetServerConnectionId( + QuicConnection* connection, const QuicConnectionId& server_connection_id) { + connection->default_path_.server_connection_id = server_connection_id; + connection->InstallInitialCrypters(server_connection_id); +} + +// static +size_t QuicConnectionPeer::NumUndecryptablePackets(QuicConnection* connection) { + return connection->undecryptable_packets_.size(); +} + +void QuicConnectionPeer::SetConnectionClose(QuicConnection* connection) { + connection->connected_ = false; +} + +// static +void QuicConnectionPeer::SendPing(QuicConnection* connection) { + connection->SendPingAtLevel(connection->encryption_level()); +} + +// static +void QuicConnectionPeer::SetLastPacketDestinationAddress( + QuicConnection* connection, const QuicSocketAddress& address) { + connection->last_received_packet_info_.destination_address = address; +} + +// static +QuicPathValidator* QuicConnectionPeer::path_validator( + QuicConnection* connection) { + return &connection->path_validator_; +} + +// static +QuicByteCount QuicConnectionPeer::BytesReceivedOnDefaultPath( + QuicConnection* connection) { + return connection->default_path_.bytes_received_before_address_validation; +} + +// static +QuicByteCount QuicConnectionPeer::BytesSentOnAlternativePath( + QuicConnection* connection) { + return connection->alternative_path_.bytes_sent_before_address_validation; +} + +// static +QuicByteCount QuicConnectionPeer::BytesReceivedOnAlternativePath( + QuicConnection* connection) { + return connection->alternative_path_.bytes_received_before_address_validation; +} + +// static +QuicConnectionId QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + const QuicConnection* connection) { + return connection->alternative_path_.client_connection_id; +} + +// static +QuicConnectionId QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + const QuicConnection* connection) { + return connection->alternative_path_.server_connection_id; +} + +// static +bool QuicConnectionPeer::IsAlternativePathValidated( + QuicConnection* connection) { + return connection->alternative_path_.validated; +} + +// static +bool QuicConnectionPeer::IsAlternativePath( + QuicConnection* connection, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) { + return connection->IsAlternativePath(self_address, peer_address); +} + +// static +QuicByteCount QuicConnectionPeer::BytesReceivedBeforeAddressValidation( + QuicConnection* connection) { + return connection->default_path_.bytes_received_before_address_validation; +} + +// static +void QuicConnectionPeer::ResetPeerIssuedConnectionIdManager( + QuicConnection* connection) { + connection->peer_issued_cid_manager_ = nullptr; +} + +// static +QuicConnection::PathState* QuicConnectionPeer::GetDefaultPath( + QuicConnection* connection) { + return &connection->default_path_; +} + +// static +bool QuicConnectionPeer::IsDefaultPath(QuicConnection* connection, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) { + return connection->IsDefaultPath(self_address, peer_address); +} + +// static +QuicConnection::PathState* QuicConnectionPeer::GetAlternativePath( + QuicConnection* connection) { + return &connection->alternative_path_; +} + +// static +void QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath( + QuicConnection* connection) { + connection->RetirePeerIssuedConnectionIdsNoLongerOnPath(); +} + +// static +bool QuicConnectionPeer::HasUnusedPeerIssuedConnectionId( + const QuicConnection* connection) { + return connection->peer_issued_cid_manager_->HasUnusedConnectionId(); +} + +// static +bool QuicConnectionPeer::HasSelfIssuedConnectionIdToConsume( + const QuicConnection* connection) { + return connection->self_issued_cid_manager_->HasConnectionIdToConsume(); +} + +// static +QuicSelfIssuedConnectionIdManager* +QuicConnectionPeer::GetSelfIssuedConnectionIdManager( + QuicConnection* connection) { + return connection->self_issued_cid_manager_.get(); +} + +// static +std::unique_ptr +QuicConnectionPeer::MakeSelfIssuedConnectionIdManager( + QuicConnection* connection) { + return connection->MakeSelfIssuedConnectionIdManager(); +} + +// static +void QuicConnectionPeer::SetLastDecryptedLevel(QuicConnection* connection, + EncryptionLevel level) { + connection->last_received_packet_info_.decrypted_level = level; +} + +// static +QuicCoalescedPacket& QuicConnectionPeer::GetCoalescedPacket( + QuicConnection* connection) { + return connection->coalesced_packet_; +} + +// static +void QuicConnectionPeer::FlushCoalescedPacket(QuicConnection* connection) { + connection->FlushCoalescedPacket(); +} + +// static +void QuicConnectionPeer::SetInProbeTimeOut(QuicConnection* connection, + bool value) { + connection->in_probe_time_out_ = value; +} + +// static +QuicSocketAddress QuicConnectionPeer::GetReceivedServerPreferredAddress( + QuicConnection* connection) { + return connection->received_server_preferred_address_; +} + +// static +QuicSocketAddress QuicConnectionPeer::GetSentServerPreferredAddress( + QuicConnection* connection) { + return connection->sent_server_preferred_address_; +} + +// static +bool QuicConnectionPeer::TestLastReceivedPacketInfoDefaults() { + QuicConnection::ReceivedPacketInfo info{QuicTime::Zero()}; + QUIC_DVLOG(2) + << "QuicConnectionPeer::TestLastReceivedPacketInfoDefaults" + << " dest_addr passed: " + << (info.destination_address == QuicSocketAddress()) + << " source_addr passed: " << (info.source_address == QuicSocketAddress()) + << " receipt_time passed: " << (info.receipt_time == QuicTime::Zero()) + << " received_bytes_counted passed: " << !info.received_bytes_counted + << " destination_connection_id passed: " + << (info.destination_connection_id == QuicConnectionId()) + << " length passed: " << (info.length == 0) + << " decrypted passed: " << !info.decrypted << " decrypted_level passed: " + << (info.decrypted_level == ENCRYPTION_INITIAL) + << " frames.empty passed: " << info.frames.empty() + << " ecn_codepoint passed: " << (info.ecn_codepoint == ECN_NOT_ECT) + << " sizeof(ReceivedPacketInfo) passed: " + << (sizeof(size_t) != 8 || + sizeof(QuicConnection::ReceivedPacketInfo) == 280); + return info.destination_address == QuicSocketAddress() && + info.source_address == QuicSocketAddress() && + info.receipt_time == QuicTime::Zero() && + !info.received_bytes_counted && info.length == 0 && + info.destination_connection_id == QuicConnectionId() && + !info.decrypted && info.decrypted_level == ENCRYPTION_INITIAL && + // There's no simple way to compare all the values of QuicPacketHeader. + info.frames.empty() && info.ecn_codepoint == ECN_NOT_ECT && + info.actual_destination_address == QuicSocketAddress() && + // If the condition below fails, the contents of ReceivedPacketInfo + // have changed. Please add the relevant conditions and update the + // length below. + (sizeof(size_t) != 8 || + sizeof(QuicConnection::ReceivedPacketInfo) == 280); +} + +// static +void QuicConnectionPeer::DisableEcnCodepointValidation( + QuicConnection* connection) { + connection->disable_ecn_codepoint_validation_ = true; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_connection_peer.h b/quiche/quic/test_tools/quic_connection_peer.h new file mode 100644 index 000000000000..cf8a5a6cf96a --- /dev/null +++ b/quiche/quic/test_tools/quic_connection_peer.h @@ -0,0 +1,253 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CONNECTION_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CONNECTION_PEER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_connection_stats.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +struct QuicPacketHeader; +class QuicAlarm; +class QuicConnectionHelperInterface; +class QuicConnectionVisitorInterface; +class QuicEncryptedPacket; +class QuicFramer; +class QuicPacketCreator; +class QuicPacketWriter; +class QuicSentPacketManager; +class SendAlgorithmInterface; + +namespace test { + +// Peer to make public a number of otherwise private QuicConnection methods. +class QuicConnectionPeer { + public: + QuicConnectionPeer() = delete; + + static void SetSendAlgorithm(QuicConnection* connection, + SendAlgorithmInterface* send_algorithm); + + static void SetLossAlgorithm(QuicConnection* connection, + LossDetectionInterface* loss_algorithm); + + static void PopulateStopWaitingFrame(QuicConnection* connection, + QuicStopWaitingFrame* stop_waiting); + + static QuicPacketCreator* GetPacketCreator(QuicConnection* connection); + + static QuicSentPacketManager* GetSentPacketManager( + QuicConnection* connection); + + static QuicTime::Delta GetNetworkTimeout(QuicConnection* connection); + + static QuicTime::Delta GetHandshakeTimeout(QuicConnection* connection); + + static QuicTime::Delta GetBandwidthUpdateTimeout(QuicConnection* connection); + + static void DisableBandwidthUpdate(QuicConnection* connection); + + static void SetPerspective(QuicConnection* connection, + Perspective perspective); + + static void SetSelfAddress(QuicConnection* connection, + const QuicSocketAddress& self_address); + + static void SetPeerAddress(QuicConnection* connection, + const QuicSocketAddress& peer_address); + + static void SetDirectPeerAddress( + QuicConnection* connection, const QuicSocketAddress& direct_peer_address); + + static void SetEffectivePeerAddress( + QuicConnection* connection, + const QuicSocketAddress& effective_peer_address); + + static void SwapCrypters(QuicConnection* connection, QuicFramer* framer); + + static void SetCurrentPacket(QuicConnection* connection, + absl::string_view current_packet); + + static QuicConnectionHelperInterface* GetHelper(QuicConnection* connection); + + static QuicAlarmFactory* GetAlarmFactory(QuicConnection* connection); + + static QuicFramer* GetFramer(QuicConnection* connection); + + static QuicAlarm* GetAckAlarm(QuicConnection* connection); + static QuicAlarm* GetPingAlarm(QuicConnection* connection); + static QuicAlarm* GetRetransmissionAlarm(QuicConnection* connection); + static QuicAlarm* GetSendAlarm(QuicConnection* connection); + static QuicAlarm* GetMtuDiscoveryAlarm(QuicConnection* connection); + static QuicAlarm* GetProcessUndecryptablePacketsAlarm( + QuicConnection* connection); + static QuicAlarm* GetDiscardPreviousOneRttKeysAlarm( + QuicConnection* connection); + static QuicAlarm* GetDiscardZeroRttDecryptionKeysAlarm( + QuicConnection* connection); + static QuicAlarm* GetRetirePeerIssuedConnectionIdAlarm( + QuicConnection* connection); + static QuicAlarm* GetRetireSelfIssuedConnectionIdAlarm( + QuicConnection* connection); + + static QuicPacketWriter* GetWriter(QuicConnection* connection); + // If |owns_writer| is true, takes ownership of |writer|. + static void SetWriter(QuicConnection* connection, QuicPacketWriter* writer, + bool owns_writer); + static void TearDownLocalConnectionState(QuicConnection* connection); + static QuicEncryptedPacket* GetConnectionClosePacket( + QuicConnection* connection); + + static QuicPacketHeader* GetLastHeader(QuicConnection* connection); + + static QuicConnectionStats* GetStats(QuicConnection* connection); + + static QuicPacketCount GetPacketsBetweenMtuProbes(QuicConnection* connection); + + static void ReInitializeMtuDiscoverer( + QuicConnection* connection, QuicPacketCount packets_between_probes_base, + QuicPacketNumber next_probe_at); + static void SetAckDecimationDelay(QuicConnection* connection, + float ack_decimation_delay); + static bool HasRetransmittableFrames(QuicConnection* connection, + uint64_t packet_number); + static bool GetNoStopWaitingFrames(QuicConnection* connection); + static void SetNoStopWaitingFrames(QuicConnection* connection, + bool no_stop_waiting_frames); + static void SetMaxTrackedPackets(QuicConnection* connection, + QuicPacketCount max_tracked_packets); + static void SetNegotiatedVersion(QuicConnection* connection); + static void SetMaxConsecutiveNumPacketsWithNoRetransmittableFrames( + QuicConnection* connection, size_t new_value); + static bool SupportsReleaseTime(QuicConnection* connection); + static QuicConnection::PacketContent GetCurrentPacketContent( + QuicConnection* connection); + static void AddBytesReceived(QuicConnection* connection, size_t length); + static void SetAddressValidated(QuicConnection* connection); + + static void SendConnectionClosePacket(QuicConnection* connection, + QuicIetfTransportErrorCodes ietf_error, + QuicErrorCode error, + const std::string& details); + + static size_t GetNumEncryptionLevels(QuicConnection* connection); + + static QuicNetworkBlackholeDetector& GetBlackholeDetector( + QuicConnection* connection); + + static QuicAlarm* GetBlackholeDetectorAlarm(QuicConnection* connection); + + static QuicTime GetPathDegradingDeadline(QuicConnection* connection); + + static QuicTime GetBlackholeDetectionDeadline(QuicConnection* connection); + + static QuicTime GetPathMtuReductionDetectionDeadline( + QuicConnection* connection); + + static QuicAlarm* GetIdleNetworkDetectorAlarm(QuicConnection* connection); + + static QuicTime GetIdleNetworkDeadline(QuicConnection* connection); + + static QuicIdleNetworkDetector& GetIdleNetworkDetector( + QuicConnection* connection); + + static void SetServerConnectionId( + QuicConnection* connection, const QuicConnectionId& server_connection_id); + + static size_t NumUndecryptablePackets(QuicConnection* connection); + + static void SetConnectionClose(QuicConnection* connection); + + static void SendPing(QuicConnection* connection); + + static void SetLastPacketDestinationAddress(QuicConnection* connection, + const QuicSocketAddress& address); + + static QuicPathValidator* path_validator(QuicConnection* connection); + + static QuicByteCount BytesReceivedOnDefaultPath(QuicConnection* connection); + + static QuicByteCount BytesSentOnAlternativePath(QuicConnection* connection); + + static QuicByteCount BytesReceivedOnAlternativePath( + QuicConnection* connection); + + static QuicConnectionId GetClientConnectionIdOnAlternativePath( + const QuicConnection* connection); + + static QuicConnectionId GetServerConnectionIdOnAlternativePath( + const QuicConnection* connection); + + static bool IsAlternativePath(QuicConnection* connection, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address); + + static bool IsAlternativePathValidated(QuicConnection* connection); + + static QuicByteCount BytesReceivedBeforeAddressValidation( + QuicConnection* connection); + + static void ResetPeerIssuedConnectionIdManager(QuicConnection* connection); + + static QuicConnection::PathState* GetDefaultPath(QuicConnection* connection); + + static bool IsDefaultPath(QuicConnection* connection, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address); + + static QuicConnection::PathState* GetAlternativePath( + QuicConnection* connection); + + static void RetirePeerIssuedConnectionIdsNoLongerOnPath( + QuicConnection* connection); + + static bool HasUnusedPeerIssuedConnectionId(const QuicConnection* connection); + + static bool HasSelfIssuedConnectionIdToConsume( + const QuicConnection* connection); + + static QuicSelfIssuedConnectionIdManager* GetSelfIssuedConnectionIdManager( + QuicConnection* connection); + + static std::unique_ptr + MakeSelfIssuedConnectionIdManager(QuicConnection* connection); + + static void SetLastDecryptedLevel(QuicConnection* connection, + EncryptionLevel level); + + static QuicCoalescedPacket& GetCoalescedPacket(QuicConnection* connection); + + static void FlushCoalescedPacket(QuicConnection* connection); + + static QuicAlarm* GetMultiPortProbingAlarm(QuicConnection* connection); + + static void SetInProbeTimeOut(QuicConnection* connection, bool value); + + static QuicSocketAddress GetReceivedServerPreferredAddress( + QuicConnection* connection); + + static QuicSocketAddress GetSentServerPreferredAddress( + QuicConnection* connection); + + static bool TestLastReceivedPacketInfoDefaults(); + + // Overrides restrictions on sending ECN for test purposes. + static void DisableEcnCodepointValidation(QuicConnection* connection); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CONNECTION_PEER_H_ diff --git a/quiche/quic/test_tools/quic_crypto_server_config_peer.cc b/quiche/quic/test_tools/quic_crypto_server_config_peer.cc new file mode 100644 index 000000000000..48c8aedde696 --- /dev/null +++ b/quiche/quic/test_tools/quic_crypto_server_config_peer.cc @@ -0,0 +1,152 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_crypto_server_config_peer.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +quiche::QuicheReferenceCountedPointer +QuicCryptoServerConfigPeer::GetPrimaryConfig() { + QuicReaderMutexLock locked(&server_config_->configs_lock_); + return quiche::QuicheReferenceCountedPointer( + server_config_->primary_config_); +} + +quiche::QuicheReferenceCountedPointer +QuicCryptoServerConfigPeer::GetConfig(std::string config_id) { + QuicReaderMutexLock locked(&server_config_->configs_lock_); + if (config_id == "") { + return quiche::QuicheReferenceCountedPointer< + QuicCryptoServerConfig::Config>(server_config_->primary_config_); + } else { + return server_config_->GetConfigWithScid(config_id); + } +} + +ProofSource* QuicCryptoServerConfigPeer::GetProofSource() const { + return server_config_->proof_source_.get(); +} + +void QuicCryptoServerConfigPeer::ResetProofSource( + std::unique_ptr proof_source) { + server_config_->proof_source_ = std::move(proof_source); +} + +std::string QuicCryptoServerConfigPeer::NewSourceAddressToken( + std::string config_id, SourceAddressTokens previous_tokens, + const QuicIpAddress& ip, QuicRandom* rand, QuicWallTime now, + CachedNetworkParameters* cached_network_params) { + return server_config_->NewSourceAddressToken( + *GetConfig(config_id)->source_address_token_boxer, previous_tokens, ip, + rand, now, cached_network_params); +} + +HandshakeFailureReason QuicCryptoServerConfigPeer::ValidateSourceAddressTokens( + std::string config_id, absl::string_view srct, const QuicIpAddress& ip, + QuicWallTime now, CachedNetworkParameters* cached_network_params) { + SourceAddressTokens tokens; + HandshakeFailureReason reason = server_config_->ParseSourceAddressToken( + *GetConfig(config_id)->source_address_token_boxer, srct, tokens); + if (reason != HANDSHAKE_OK) { + return reason; + } + + return server_config_->ValidateSourceAddressTokens(tokens, ip, now, + cached_network_params); +} + +HandshakeFailureReason +QuicCryptoServerConfigPeer::ValidateSingleSourceAddressToken( + absl::string_view token, const QuicIpAddress& ip, QuicWallTime now) { + SourceAddressTokens tokens; + HandshakeFailureReason parse_status = server_config_->ParseSourceAddressToken( + *GetPrimaryConfig()->source_address_token_boxer, token, tokens); + if (HANDSHAKE_OK != parse_status) { + return parse_status; + } + EXPECT_EQ(1, tokens.tokens_size()); + return server_config_->ValidateSingleSourceAddressToken(tokens.tokens(0), ip, + now); +} + +void QuicCryptoServerConfigPeer::CheckConfigs( + std::vector> expected_ids_and_status) { + QuicReaderMutexLock locked(&server_config_->configs_lock_); + + ASSERT_EQ(expected_ids_and_status.size(), server_config_->configs_.size()) + << ConfigsDebug(); + + for (const std::pair>& i : + server_config_->configs_) { + bool found = false; + for (std::pair& j : expected_ids_and_status) { + if (i.first == j.first && i.second->is_primary == j.second) { + found = true; + j.first.clear(); + break; + } + } + + ASSERT_TRUE(found) << "Failed to find match for " << i.first + << " in configs:\n" + << ConfigsDebug(); + } +} + +// ConfigsDebug returns a std::string that contains debugging information about +// the set of Configs loaded in |server_config_| and their status. +std::string QuicCryptoServerConfigPeer::ConfigsDebug() { + if (server_config_->configs_.empty()) { + return "No Configs in QuicCryptoServerConfig"; + } + + std::string s; + + for (const auto& i : server_config_->configs_) { + const quiche::QuicheReferenceCountedPointer + config = i.second; + if (config->is_primary) { + s += "(primary) "; + } else { + s += " "; + } + s += config->id; + s += "\n"; + } + + return s; +} + +void QuicCryptoServerConfigPeer::SelectNewPrimaryConfig(int seconds) { + QuicWriterMutexLock locked(&server_config_->configs_lock_); + server_config_->SelectNewPrimaryConfig( + QuicWallTime::FromUNIXSeconds(seconds)); +} + +std::string QuicCryptoServerConfigPeer::CompressChain( + QuicCompressedCertsCache* compressed_certs_cache, + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes) { + return QuicCryptoServerConfig::CompressChain(compressed_certs_cache, chain, + client_cached_cert_hashes); +} + +uint32_t QuicCryptoServerConfigPeer::source_address_token_future_secs() { + return server_config_->source_address_token_future_secs_; +} + +uint32_t QuicCryptoServerConfigPeer::source_address_token_lifetime_secs() { + return server_config_->source_address_token_lifetime_secs_; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_crypto_server_config_peer.h b/quiche/quic/test_tools/quic_crypto_server_config_peer.h new file mode 100644 index 000000000000..f8a1ee54e7c6 --- /dev/null +++ b/quiche/quic/test_tools/quic_crypto_server_config_peer.h @@ -0,0 +1,88 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CRYPTO_SERVER_CONFIG_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CRYPTO_SERVER_CONFIG_PEER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" + +namespace quic { +namespace test { + +// Peer for accessing otherwise private members of a QuicCryptoServerConfig. +class QuicCryptoServerConfigPeer { + public: + explicit QuicCryptoServerConfigPeer(QuicCryptoServerConfig* server_config) + : server_config_(server_config) {} + + // Returns the primary config. + quiche::QuicheReferenceCountedPointer + GetPrimaryConfig(); + + // Returns the config associated with |config_id|. + quiche::QuicheReferenceCountedPointer + GetConfig(std::string config_id); + + // Returns a pointer to the ProofSource object. + ProofSource* GetProofSource() const; + + // Reset the proof_source_ member. + void ResetProofSource(std::unique_ptr proof_source); + + // Generates a new valid source address token. + std::string NewSourceAddressToken( + std::string config_id, SourceAddressTokens previous_tokens, + const QuicIpAddress& ip, QuicRandom* rand, QuicWallTime now, + CachedNetworkParameters* cached_network_params); + + // Attempts to validate the tokens in |srct|. + HandshakeFailureReason ValidateSourceAddressTokens( + std::string config_id, absl::string_view srct, const QuicIpAddress& ip, + QuicWallTime now, CachedNetworkParameters* cached_network_params); + + // Attempts to validate the single source address token in |token|. + HandshakeFailureReason ValidateSingleSourceAddressToken( + absl::string_view token, const QuicIpAddress& ip, QuicWallTime now); + + // CheckConfigs compares the state of the Configs in |server_config_| to the + // description given as arguments. + // The first of each pair is the server config ID of a Config. The second is a + // boolean describing whether the config is the primary. For example: + // CheckConfigs(std::vector>()); // checks + // that no Configs are loaded. + // + // // Checks that exactly three Configs are loaded with the given IDs and + // // status. + // CheckConfigs( + // {{"id1", false}, + // {"id2", true}, + // {"id3", false}}); + void CheckConfigs( + std::vector> expected_ids_and_status); + + // ConfigsDebug returns a std::string that contains debugging information + // about the set of Configs loaded in |server_config_| and their status. + std::string ConfigsDebug() + QUIC_SHARED_LOCKS_REQUIRED(server_config_->configs_lock_); + + void SelectNewPrimaryConfig(int seconds); + + static std::string CompressChain( + QuicCompressedCertsCache* compressed_certs_cache, + const quiche::QuicheReferenceCountedPointer& chain, + const std::string& client_cached_cert_hashes); + + uint32_t source_address_token_future_secs(); + + uint32_t source_address_token_lifetime_secs(); + + private: + QuicCryptoServerConfig* server_config_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CRYPTO_SERVER_CONFIG_PEER_H_ diff --git a/quiche/quic/test_tools/quic_dispatcher_peer.cc b/quiche/quic/test_tools/quic_dispatcher_peer.cc new file mode 100644 index 000000000000..8a803c5d5240 --- /dev/null +++ b/quiche/quic/test_tools/quic_dispatcher_peer.cc @@ -0,0 +1,136 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_dispatcher_peer.h" + +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/core/quic_packet_writer_wrapper.h" + +namespace quic { +namespace test { + +// static +QuicTimeWaitListManager* QuicDispatcherPeer::GetTimeWaitListManager( + QuicDispatcher* dispatcher) { + return dispatcher->time_wait_list_manager_.get(); +} + +// static +void QuicDispatcherPeer::SetTimeWaitListManager( + QuicDispatcher* dispatcher, + QuicTimeWaitListManager* time_wait_list_manager) { + dispatcher->time_wait_list_manager_.reset(time_wait_list_manager); +} + +// static +void QuicDispatcherPeer::UseWriter(QuicDispatcher* dispatcher, + QuicPacketWriterWrapper* writer) { + writer->set_writer(dispatcher->writer_.release()); + dispatcher->writer_.reset(writer); +} + +// static +QuicPacketWriter* QuicDispatcherPeer::GetWriter(QuicDispatcher* dispatcher) { + return dispatcher->writer_.get(); +} + +// static +QuicCompressedCertsCache* QuicDispatcherPeer::GetCache( + QuicDispatcher* dispatcher) { + return dispatcher->compressed_certs_cache(); +} + +// static +QuicConnectionHelperInterface* QuicDispatcherPeer::GetHelper( + QuicDispatcher* dispatcher) { + return dispatcher->helper_.get(); +} + +// static +QuicAlarmFactory* QuicDispatcherPeer::GetAlarmFactory( + QuicDispatcher* dispatcher) { + return dispatcher->alarm_factory_.get(); +} + +// static +QuicDispatcher::WriteBlockedList* QuicDispatcherPeer::GetWriteBlockedList( + QuicDispatcher* dispatcher) { + return &dispatcher->write_blocked_list_; +} + +// static +QuicErrorCode QuicDispatcherPeer::GetAndClearLastError( + QuicDispatcher* dispatcher) { + QuicErrorCode ret = dispatcher->last_error_; + dispatcher->last_error_ = QUIC_NO_ERROR; + return ret; +} + +// static +QuicBufferedPacketStore* QuicDispatcherPeer::GetBufferedPackets( + QuicDispatcher* dispatcher) { + return &(dispatcher->buffered_packets_); +} + +// static +void QuicDispatcherPeer::set_new_sessions_allowed_per_event_loop( + QuicDispatcher* dispatcher, size_t num_session_allowed) { + dispatcher->new_sessions_allowed_per_event_loop_ = num_session_allowed; +} + +// static +void QuicDispatcherPeer::SendPublicReset( + QuicDispatcher* dispatcher, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, QuicConnectionId connection_id, + bool ietf_quic, size_t received_packet_length, + std::unique_ptr packet_context) { + dispatcher->time_wait_list_manager()->SendPublicReset( + self_address, peer_address, connection_id, ietf_quic, + received_packet_length, std::move(packet_context)); +} + +// static +std::unique_ptr QuicDispatcherPeer::GetPerPacketContext( + QuicDispatcher* dispatcher) { + return dispatcher->GetPerPacketContext(); +} + +// static +void QuicDispatcherPeer::RestorePerPacketContext( + QuicDispatcher* dispatcher, std::unique_ptr context) { + dispatcher->RestorePerPacketContext(std::move(context)); +} + +// static +std::string QuicDispatcherPeer::SelectAlpn( + QuicDispatcher* dispatcher, const std::vector& alpns) { + return dispatcher->SelectAlpn(alpns); +} + +// static +QuicSession* QuicDispatcherPeer::GetFirstSessionIfAny( + QuicDispatcher* dispatcher) { + if (dispatcher->reference_counted_session_map_.empty()) { + return nullptr; + } + return dispatcher->reference_counted_session_map_.begin()->second.get(); +} + +// static +const QuicSession* QuicDispatcherPeer::FindSession( + const QuicDispatcher* dispatcher, QuicConnectionId id) { + auto it = dispatcher->reference_counted_session_map_.find(id); + return (it == dispatcher->reference_counted_session_map_.end()) + ? nullptr + : it->second.get(); +} + +// static +QuicAlarm* QuicDispatcherPeer::GetClearResetAddressesAlarm( + QuicDispatcher* dispatcher) { + return dispatcher->clear_stateless_reset_addresses_alarm_.get(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_dispatcher_peer.h b/quiche/quic/test_tools/quic_dispatcher_peer.h new file mode 100644 index 000000000000..1238a89b185c --- /dev/null +++ b/quiche/quic/test_tools/quic_dispatcher_peer.h @@ -0,0 +1,82 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_ + +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_dispatcher.h" + +namespace quic { + +class QuicPacketWriterWrapper; + +namespace test { + +class QuicDispatcherPeer { + public: + QuicDispatcherPeer() = delete; + + static QuicTimeWaitListManager* GetTimeWaitListManager( + QuicDispatcher* dispatcher); + + static void SetTimeWaitListManager( + QuicDispatcher* dispatcher, + QuicTimeWaitListManager* time_wait_list_manager); + + // Injects |writer| into |dispatcher| as the shared writer. + static void UseWriter(QuicDispatcher* dispatcher, + QuicPacketWriterWrapper* writer); + + static QuicPacketWriter* GetWriter(QuicDispatcher* dispatcher); + + static QuicCompressedCertsCache* GetCache(QuicDispatcher* dispatcher); + + static QuicConnectionHelperInterface* GetHelper(QuicDispatcher* dispatcher); + + static QuicAlarmFactory* GetAlarmFactory(QuicDispatcher* dispatcher); + + static QuicDispatcher::WriteBlockedList* GetWriteBlockedList( + QuicDispatcher* dispatcher); + + // Get the dispatcher's record of the last error reported to its framer + // visitor's OnError() method. Then set that record to QUIC_NO_ERROR. + static QuicErrorCode GetAndClearLastError(QuicDispatcher* dispatcher); + + static QuicBufferedPacketStore* GetBufferedPackets( + QuicDispatcher* dispatcher); + + static void set_new_sessions_allowed_per_event_loop( + QuicDispatcher* dispatcher, size_t num_session_allowed); + + static void SendPublicReset( + QuicDispatcher* dispatcher, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, QuicConnectionId connection_id, + bool ietf_quic, size_t received_packet_length, + std::unique_ptr packet_context); + + static std::unique_ptr GetPerPacketContext( + QuicDispatcher* dispatcher); + + static void RestorePerPacketContext(QuicDispatcher* dispatcher, + std::unique_ptr); + + static std::string SelectAlpn(QuicDispatcher* dispatcher, + const std::vector& alpns); + + // Get the first session in the session map. Returns nullptr if the map is + // empty. + static QuicSession* GetFirstSessionIfAny(QuicDispatcher* dispatcher); + + // Find the corresponding session if exsits. + static const QuicSession* FindSession(const QuicDispatcher* dispatcher, + QuicConnectionId id); + + static QuicAlarm* GetClearResetAddressesAlarm(QuicDispatcher* dispatcher); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_DISPATCHER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_flow_controller_peer.cc b/quiche/quic/test_tools/quic_flow_controller_peer.cc new file mode 100644 index 000000000000..70b3007e5d47 --- /dev/null +++ b/quiche/quic/test_tools/quic_flow_controller_peer.cc @@ -0,0 +1,65 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" + +#include + +#include "quiche/quic/core/quic_flow_controller.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { +namespace test { + +// static +void QuicFlowControllerPeer::SetSendWindowOffset( + QuicFlowController* flow_controller, QuicStreamOffset offset) { + flow_controller->send_window_offset_ = offset; +} + +// static +void QuicFlowControllerPeer::SetReceiveWindowOffset( + QuicFlowController* flow_controller, QuicStreamOffset offset) { + flow_controller->receive_window_offset_ = offset; +} + +// static +void QuicFlowControllerPeer::SetMaxReceiveWindow( + QuicFlowController* flow_controller, QuicByteCount window_size) { + flow_controller->receive_window_size_ = window_size; +} + +// static +QuicStreamOffset QuicFlowControllerPeer::SendWindowOffset( + QuicFlowController* flow_controller) { + return flow_controller->send_window_offset_; +} + +// static +QuicByteCount QuicFlowControllerPeer::SendWindowSize( + QuicFlowController* flow_controller) { + return flow_controller->SendWindowSize(); +} + +// static +QuicStreamOffset QuicFlowControllerPeer::ReceiveWindowOffset( + QuicFlowController* flow_controller) { + return flow_controller->receive_window_offset_; +} + +// static +QuicByteCount QuicFlowControllerPeer::ReceiveWindowSize( + QuicFlowController* flow_controller) { + return flow_controller->receive_window_offset_ - + flow_controller->highest_received_byte_offset_; +} + +// static +QuicByteCount QuicFlowControllerPeer::WindowUpdateThreshold( + QuicFlowController* flow_controller) { + return flow_controller->WindowUpdateThreshold(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_flow_controller_peer.h b/quiche/quic/test_tools/quic_flow_controller_peer.h new file mode 100644 index 000000000000..7c5e42622edd --- /dev/null +++ b/quiche/quic/test_tools/quic_flow_controller_peer.h @@ -0,0 +1,45 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_FLOW_CONTROLLER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_FLOW_CONTROLLER_PEER_H_ + +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +class QuicFlowController; + +namespace test { + +class QuicFlowControllerPeer { + public: + QuicFlowControllerPeer() = delete; + + static void SetSendWindowOffset(QuicFlowController* flow_controller, + QuicStreamOffset offset); + + static void SetReceiveWindowOffset(QuicFlowController* flow_controller, + QuicStreamOffset offset); + + static void SetMaxReceiveWindow(QuicFlowController* flow_controller, + QuicByteCount window_size); + + static QuicStreamOffset SendWindowOffset(QuicFlowController* flow_controller); + + static QuicByteCount SendWindowSize(QuicFlowController* flow_controller); + + static QuicStreamOffset ReceiveWindowOffset( + QuicFlowController* flow_controller); + + static QuicByteCount ReceiveWindowSize(QuicFlowController* flow_controller); + + static QuicByteCount WindowUpdateThreshold( + QuicFlowController* flow_controller); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_FLOW_CONTROLLER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_framer_peer.cc b/quiche/quic/test_tools/quic_framer_peer.cc new file mode 100644 index 000000000000..9819a8aa202e --- /dev/null +++ b/quiche/quic/test_tools/quic_framer_peer.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_framer_peer.h" + +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { +namespace test { + +// static +uint64_t QuicFramerPeer::CalculatePacketNumberFromWire( + QuicFramer* framer, QuicPacketNumberLength packet_number_length, + QuicPacketNumber last_packet_number, uint64_t packet_number) { + return framer->CalculatePacketNumberFromWire( + packet_number_length, last_packet_number, packet_number); +} + +// static +void QuicFramerPeer::SetLastSerializedServerConnectionId( + QuicFramer* framer, QuicConnectionId server_connection_id) { + framer->last_serialized_server_connection_id_ = server_connection_id; +} + +// static +void QuicFramerPeer::SetLastSerializedClientConnectionId( + QuicFramer* framer, QuicConnectionId client_connection_id) { + framer->last_serialized_client_connection_id_ = client_connection_id; +} + +// static +void QuicFramerPeer::SetLastWrittenPacketNumberLength( + QuicFramer* framer, size_t packet_number_length) { + framer->last_written_packet_number_length_ = packet_number_length; +} + +// static +void QuicFramerPeer::SetLargestPacketNumber(QuicFramer* framer, + QuicPacketNumber packet_number) { + framer->largest_packet_number_ = packet_number; +} + +// static +void QuicFramerPeer::SetPerspective(QuicFramer* framer, + Perspective perspective) { + framer->perspective_ = perspective; + framer->infer_packet_header_type_from_version_ = + perspective == Perspective::IS_CLIENT; +} + +// static +void QuicFramerPeer::SwapCrypters(QuicFramer* framer1, QuicFramer* framer2) { + for (int i = ENCRYPTION_INITIAL; i < NUM_ENCRYPTION_LEVELS; i++) { + framer1->encrypter_[i].swap(framer2->encrypter_[i]); + framer1->decrypter_[i].swap(framer2->decrypter_[i]); + } + + EncryptionLevel framer2_level = framer2->decrypter_level_; + framer2->decrypter_level_ = framer1->decrypter_level_; + framer1->decrypter_level_ = framer2_level; + framer2_level = framer2->alternative_decrypter_level_; + framer2->alternative_decrypter_level_ = framer1->alternative_decrypter_level_; + framer1->alternative_decrypter_level_ = framer2_level; + + const bool framer2_latch = framer2->alternative_decrypter_latch_; + framer2->alternative_decrypter_latch_ = framer1->alternative_decrypter_latch_; + framer1->alternative_decrypter_latch_ = framer2_latch; +} + +// static +QuicEncrypter* QuicFramerPeer::GetEncrypter(QuicFramer* framer, + EncryptionLevel level) { + return framer->encrypter_[level].get(); +} + +// static +QuicDecrypter* QuicFramerPeer::GetDecrypter(QuicFramer* framer, + EncryptionLevel level) { + return framer->decrypter_[level].get(); +} + +// static +void QuicFramerPeer::SetFirstSendingPacketNumber(QuicFramer* framer, + uint64_t packet_number) { + *const_cast(&framer->first_sending_packet_number_) = + QuicPacketNumber(packet_number); +} + +// static +void QuicFramerPeer::SetExpectedServerConnectionIDLength( + QuicFramer* framer, uint8_t expected_server_connection_id_length) { + *const_cast(&framer->expected_server_connection_id_length_) = + expected_server_connection_id_length; +} + +// static +QuicPacketNumber QuicFramerPeer::GetLargestDecryptedPacketNumber( + QuicFramer* framer, PacketNumberSpace packet_number_space) { + return framer->largest_decrypted_packet_numbers_[packet_number_space]; +} + +// static +bool QuicFramerPeer::ProcessAndValidateIetfConnectionIdLength( + QuicDataReader* reader, ParsedQuicVersion version, Perspective perspective, + bool should_update_expected_server_connection_id_length, + uint8_t* expected_server_connection_id_length, + uint8_t* destination_connection_id_length, + uint8_t* source_connection_id_length, std::string* detailed_error) { + return QuicFramer::ProcessAndValidateIetfConnectionIdLength( + reader, version, perspective, + should_update_expected_server_connection_id_length, + expected_server_connection_id_length, destination_connection_id_length, + source_connection_id_length, detailed_error); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_framer_peer.h b/quiche/quic/test_tools/quic_framer_peer.h new file mode 100644 index 000000000000..aa383947662b --- /dev/null +++ b/quiche/quic/test_tools/quic_framer_peer.h @@ -0,0 +1,69 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_FRAMER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_FRAMER_PEER_H_ + +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +namespace test { + +class QuicFramerPeer { + public: + QuicFramerPeer() = delete; + + static uint64_t CalculatePacketNumberFromWire( + QuicFramer* framer, QuicPacketNumberLength packet_number_length, + QuicPacketNumber last_packet_number, uint64_t packet_number); + static void SetLastSerializedServerConnectionId( + QuicFramer* framer, QuicConnectionId server_connection_id); + static void SetLastSerializedClientConnectionId( + QuicFramer* framer, QuicConnectionId client_connection_id); + static void SetLastWrittenPacketNumberLength(QuicFramer* framer, + size_t packet_number_length); + static void SetLargestPacketNumber(QuicFramer* framer, + QuicPacketNumber packet_number); + static void SetPerspective(QuicFramer* framer, Perspective perspective); + + // SwapCrypters exchanges the state of the crypters of |framer1| with + // |framer2|. + static void SwapCrypters(QuicFramer* framer1, QuicFramer* framer2); + + static QuicEncrypter* GetEncrypter(QuicFramer* framer, EncryptionLevel level); + static QuicDecrypter* GetDecrypter(QuicFramer* framer, EncryptionLevel level); + + static void SetFirstSendingPacketNumber(QuicFramer* framer, + uint64_t packet_number); + static void SetExpectedServerConnectionIDLength( + QuicFramer* framer, uint8_t expected_server_connection_id_length); + static QuicPacketNumber GetLargestDecryptedPacketNumber( + QuicFramer* framer, PacketNumberSpace packet_number_space); + + static bool ProcessAndValidateIetfConnectionIdLength( + QuicDataReader* reader, ParsedQuicVersion version, + Perspective perspective, + bool should_update_expected_server_connection_id_length, + uint8_t* expected_server_connection_id_length, + uint8_t* destination_connection_id_length, + uint8_t* source_connection_id_length, std::string* detailed_error); + + static void set_current_received_frame_type( + QuicFramer* framer, uint64_t current_received_frame_type) { + framer->current_received_frame_type_ = current_received_frame_type; + } + + static bool infer_packet_header_type_from_version(QuicFramer* framer) { + return framer->infer_packet_header_type_from_version_; + } +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_FRAMER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/index.html b/quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/index.html new file mode 100644 index 000000000000..5edaf9af7b50 --- /dev/null +++ b/quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/index.html @@ -0,0 +1,63 @@ +HTTP/1.1 200 OK +Date: Tue, 28 Aug 2012 15:08:56 GMT +Server: Apache/2.2.3 (CentOS) +X-Powered-By: PHP/5.1.6 +Set-Cookie: bblastvisit=1346166536; expires=Wed, 28-Aug-2013 15:08:56 GMT; path=/; domain=.nasioc.com +Set-Cookie: bblastactivity=0; expires=Wed, 28-Aug-2013 15:08:56 GMT; path=/; domain=.nasioc.com +Expires: 0 +Cache-Control: private, post-check=0, pre-check=0, max-age=0 +Pragma: no-cache +X-UA-Compatible: IE=7 +Connection: close +Content-Type: text/html; charset=ISO-8859-1 + + + + + Example Domain + + + + + + + + +
+

Example Domain

+

This domain is established to be used for illustrative examples in documents. You may use this + domain in examples without prior coordination or asking for permission.

+

More information...

+
+ + diff --git a/quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/map.html b/quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/map.html new file mode 100644 index 000000000000..b34c3b085f07 --- /dev/null +++ b/quiche/quic/test_tools/quic_http_response_cache_data/test.example.com/map.html @@ -0,0 +1,65 @@ +HTTP/1.1 200 OK +Date: Tue, 28 Aug 2012 15:08:56 GMT +Server: Apache/2.2.3 (CentOS) +X-Powered-By: PHP/5.1.6 +Set-Cookie: bblastvisit=1346166536; expires=Wed, 28-Aug-2013 15:08:56 GMT; path=/; domain=.nasioc.com +Set-Cookie: bblastactivity=0; expires=Wed, 28-Aug-2013 15:08:56 GMT; path=/; domain=.nasioc.com +Expires: 0 +Cache-Control: private, post-check=0, pre-check=0, max-age=0 +Pragma: no-cache +X-UA-Compatible: IE=7 +Connection: close +Content-Type: text/html; charset=ISO-8859-1 +X-Original-Url: http://test.example.com/site_map.html + + + + + + Example Domain + + + + + + + + +
+

Example Domain

+

This domain is established to be used for illustrative examples in documents. You may use this + domain in examples without prior coordination or asking for permission.

+

More information...

+
+ + diff --git a/quiche/quic/test_tools/quic_interval_deque_peer.h b/quiche/quic/test_tools/quic_interval_deque_peer.h new file mode 100644 index 000000000000..c522457f2791 --- /dev/null +++ b/quiche/quic/test_tools/quic_interval_deque_peer.h @@ -0,0 +1,35 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_INTERVAL_DEQUE_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_INTERVAL_DEQUE_PEER_H_ + +#include "quiche/quic/core/quic_interval_deque.h" + +namespace quic { + +namespace test { + +class QuicIntervalDequePeer { + public: + template + static int32_t GetCachedIndex(QuicIntervalDeque* interval_deque) { + if (!interval_deque->cached_index_.has_value()) { + return -1; + } + return interval_deque->cached_index_.value(); + } + + template + static T* GetItem(QuicIntervalDeque* interval_deque, + const std::size_t index) { + return &interval_deque->container_[index]; + } +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_INTERVAL_DEQUE_PEER_H_ diff --git a/quiche/quic/test_tools/quic_mock_syscall_wrapper.cc b/quiche/quic/test_tools/quic_mock_syscall_wrapper.cc new file mode 100644 index 000000000000..f934198f73dd --- /dev/null +++ b/quiche/quic/test_tools/quic_mock_syscall_wrapper.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_mock_syscall_wrapper.h" + +using testing::_; +using testing::Invoke; + +namespace quic { +namespace test { + +MockQuicSyscallWrapper::MockQuicSyscallWrapper(QuicSyscallWrapper* delegate) { + ON_CALL(*this, Sendmsg(_, _, _)) + .WillByDefault(Invoke(delegate, &QuicSyscallWrapper::Sendmsg)); + + ON_CALL(*this, Sendmmsg(_, _, _, _)) + .WillByDefault(Invoke(delegate, &QuicSyscallWrapper::Sendmmsg)); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_mock_syscall_wrapper.h b/quiche/quic/test_tools/quic_mock_syscall_wrapper.h new file mode 100644 index 000000000000..58b021dc6112 --- /dev/null +++ b/quiche/quic/test_tools/quic_mock_syscall_wrapper.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_MOCK_SYSCALL_WRAPPER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_MOCK_SYSCALL_WRAPPER_H_ + +#include "quiche/quic/core/quic_syscall_wrapper.h" +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class MockQuicSyscallWrapper : public QuicSyscallWrapper { + public: + // Create a standard mock object. + MockQuicSyscallWrapper() = default; + + // Create a 'mockable' object that delegates everything to |delegate| by + // default. + explicit MockQuicSyscallWrapper(QuicSyscallWrapper* delegate); + + MOCK_METHOD(ssize_t, Sendmsg, (int sockfd, const msghdr*, int flags), + (override)); + + MOCK_METHOD(int, Sendmmsg, + (int sockfd, mmsghdr*, unsigned int vlen, int flags), (override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_MOCK_SYSCALL_WRAPPER_H_ diff --git a/quiche/quic/test_tools/quic_packet_creator_peer.cc b/quiche/quic/test_tools/quic_packet_creator_peer.cc new file mode 100644 index 000000000000..2ab9fb0d8551 --- /dev/null +++ b/quiche/quic/test_tools/quic_packet_creator_peer.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_packet_creator_peer.h" + +#include "quiche/quic/core/frames/quic_frame.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { +namespace test { + +// static +bool QuicPacketCreatorPeer::SendVersionInPacket(QuicPacketCreator* creator) { + return creator->IncludeVersionInHeader(); +} + +// static +void QuicPacketCreatorPeer::SetSendVersionInPacket( + QuicPacketCreator* creator, bool send_version_in_packet) { + ParsedQuicVersion version = creator->framer_->version(); + if (!VersionHasIetfQuicFrames(version.transport_version) && + version.handshake_protocol != PROTOCOL_TLS1_3) { + creator->send_version_in_packet_ = send_version_in_packet; + return; + } + if (!send_version_in_packet) { + creator->packet_.encryption_level = ENCRYPTION_FORWARD_SECURE; + return; + } + QUICHE_DCHECK(creator->packet_.encryption_level < ENCRYPTION_FORWARD_SECURE); +} + +// static +void QuicPacketCreatorPeer::SetPacketNumberLength( + QuicPacketCreator* creator, QuicPacketNumberLength packet_number_length) { + creator->packet_.packet_number_length = packet_number_length; +} + +// static +QuicPacketNumberLength QuicPacketCreatorPeer::GetPacketNumberLength( + QuicPacketCreator* creator) { + return creator->GetPacketNumberLength(); +} + +// static +quiche::QuicheVariableLengthIntegerLength +QuicPacketCreatorPeer::GetRetryTokenLengthLength(QuicPacketCreator* creator) { + return creator->GetRetryTokenLengthLength(); +} + +// static +quiche::QuicheVariableLengthIntegerLength +QuicPacketCreatorPeer::GetLengthLength(QuicPacketCreator* creator) { + return creator->GetLengthLength(); +} + +void QuicPacketCreatorPeer::SetPacketNumber(QuicPacketCreator* creator, + uint64_t s) { + QUICHE_DCHECK_NE(0u, s); + creator->packet_.packet_number = QuicPacketNumber(s); +} + +void QuicPacketCreatorPeer::SetPacketNumber(QuicPacketCreator* creator, + QuicPacketNumber num) { + creator->packet_.packet_number = num; +} + +// static +void QuicPacketCreatorPeer::ClearPacketNumber(QuicPacketCreator* creator) { + creator->packet_.packet_number.Clear(); +} + +// static +void QuicPacketCreatorPeer::FillPacketHeader(QuicPacketCreator* creator, + QuicPacketHeader* header) { + creator->FillPacketHeader(header); +} + +// static +void QuicPacketCreatorPeer::CreateStreamFrame(QuicPacketCreator* creator, + QuicStreamId id, + size_t data_length, + QuicStreamOffset offset, bool fin, + QuicFrame* frame) { + creator->CreateStreamFrame(id, data_length, offset, fin, frame); +} + +// static +bool QuicPacketCreatorPeer::CreateCryptoFrame(QuicPacketCreator* creator, + EncryptionLevel level, + size_t write_length, + QuicStreamOffset offset, + QuicFrame* frame) { + return creator->CreateCryptoFrame(level, write_length, offset, frame); +} + +// static +SerializedPacket QuicPacketCreatorPeer::SerializeAllFrames( + QuicPacketCreator* creator, const QuicFrames& frames, char* buffer, + size_t buffer_len) { + QUICHE_DCHECK(creator->queued_frames_.empty()); + QUICHE_DCHECK(!frames.empty()); + for (const QuicFrame& frame : frames) { + bool success = creator->AddFrame(frame, NOT_RETRANSMISSION); + QUICHE_DCHECK(success); + } + const bool success = + creator->SerializePacket(QuicOwnedPacketBuffer(buffer, nullptr), + buffer_len, /*allow_padding=*/true); + QUICHE_DCHECK(success); + SerializedPacket packet = std::move(creator->packet_); + // The caller takes ownership of the QuicEncryptedPacket. + creator->packet_.encrypted_buffer = nullptr; + return packet; +} + +// static +std::unique_ptr +QuicPacketCreatorPeer::SerializeConnectivityProbingPacket( + QuicPacketCreator* creator) { + return creator->SerializeConnectivityProbingPacket(); +} + +// static +std::unique_ptr +QuicPacketCreatorPeer::SerializePathChallengeConnectivityProbingPacket( + QuicPacketCreator* creator, const QuicPathFrameBuffer& payload) { + return creator->SerializePathChallengeConnectivityProbingPacket(payload); +} + +// static +EncryptionLevel QuicPacketCreatorPeer::GetEncryptionLevel( + QuicPacketCreator* creator) { + return creator->packet_.encryption_level; +} + +// static +QuicFramer* QuicPacketCreatorPeer::framer(QuicPacketCreator* creator) { + return creator->framer_; +} + +// static +std::string QuicPacketCreatorPeer::GetRetryToken(QuicPacketCreator* creator) { + return creator->retry_token_; +} + +// static +QuicFrames& QuicPacketCreatorPeer::QueuedFrames(QuicPacketCreator* creator) { + return creator->queued_frames_; +} + +// static +void QuicPacketCreatorPeer::SetRandom(QuicPacketCreator* creator, + QuicRandom* random) { + creator->random_ = random; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_packet_creator_peer.h b/quiche/quic/test_tools/quic_packet_creator_peer.h new file mode 100644 index 000000000000..b7833a456d05 --- /dev/null +++ b/quiche/quic/test_tools/quic_packet_creator_peer.h @@ -0,0 +1,64 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_PACKET_CREATOR_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_PACKET_CREATOR_PEER_H_ + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { +class QuicFramer; +class QuicPacketCreator; + +namespace test { + +class QuicPacketCreatorPeer { + public: + QuicPacketCreatorPeer() = delete; + + static bool SendVersionInPacket(QuicPacketCreator* creator); + + static void SetSendVersionInPacket(QuicPacketCreator* creator, + bool send_version_in_packet); + static void SetPacketNumberLength( + QuicPacketCreator* creator, QuicPacketNumberLength packet_number_length); + static QuicPacketNumberLength GetPacketNumberLength( + QuicPacketCreator* creator); + static quiche::QuicheVariableLengthIntegerLength GetRetryTokenLengthLength( + QuicPacketCreator* creator); + static quiche::QuicheVariableLengthIntegerLength GetLengthLength( + QuicPacketCreator* creator); + static void SetPacketNumber(QuicPacketCreator* creator, uint64_t s); + static void SetPacketNumber(QuicPacketCreator* creator, QuicPacketNumber num); + static void ClearPacketNumber(QuicPacketCreator* creator); + static void FillPacketHeader(QuicPacketCreator* creator, + QuicPacketHeader* header); + static void CreateStreamFrame(QuicPacketCreator* creator, QuicStreamId id, + size_t data_length, QuicStreamOffset offset, + bool fin, QuicFrame* frame); + static bool CreateCryptoFrame(QuicPacketCreator* creator, + EncryptionLevel level, size_t write_length, + QuicStreamOffset offset, QuicFrame* frame); + static SerializedPacket SerializeAllFrames(QuicPacketCreator* creator, + const QuicFrames& frames, + char* buffer, size_t buffer_len); + static std::unique_ptr SerializeConnectivityProbingPacket( + QuicPacketCreator* creator); + static std::unique_ptr + SerializePathChallengeConnectivityProbingPacket( + QuicPacketCreator* creator, const QuicPathFrameBuffer& payload); + + static EncryptionLevel GetEncryptionLevel(QuicPacketCreator* creator); + static QuicFramer* framer(QuicPacketCreator* creator); + static std::string GetRetryToken(QuicPacketCreator* creator); + static QuicFrames& QueuedFrames(QuicPacketCreator* creator); + static void SetRandom(QuicPacketCreator* creator, QuicRandom* random); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_PACKET_CREATOR_PEER_H_ diff --git a/quiche/quic/test_tools/quic_path_validator_peer.cc b/quiche/quic/test_tools/quic_path_validator_peer.cc new file mode 100644 index 000000000000..42c92cf1cb72 --- /dev/null +++ b/quiche/quic/test_tools/quic_path_validator_peer.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_path_validator_peer.h" + +namespace quic { +namespace test { +// static +QuicAlarm* QuicPathValidatorPeer::retry_timer(QuicPathValidator* validator) { + return validator->retry_timer_.get(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_path_validator_peer.h b/quiche/quic/test_tools/quic_path_validator_peer.h new file mode 100644 index 000000000000..00139830fbb4 --- /dev/null +++ b/quiche/quic/test_tools/quic_path_validator_peer.h @@ -0,0 +1,20 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_PATH_VALIDATOR_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_PATH_VALIDATOR_PEER_H_ + +#include "quiche/quic/core/quic_path_validator.h" + +namespace quic { +namespace test { + +class QuicPathValidatorPeer { + public: + static QuicAlarm* retry_timer(QuicPathValidator* validator); +}; + +} // namespace test +} // namespace quic +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_PATH_VALIDATOR_PEER_H_ diff --git a/quiche/quic/test_tools/quic_sent_packet_manager_peer.cc b/quiche/quic/test_tools/quic_sent_packet_manager_peer.cc new file mode 100644 index 000000000000..c8954151c749 --- /dev/null +++ b/quiche/quic/test_tools/quic_sent_packet_manager_peer.cc @@ -0,0 +1,186 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" + +#include "quiche/quic/core/congestion_control/loss_detection_interface.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_sent_packet_manager.h" +#include "quiche/quic/test_tools/quic_unacked_packet_map_peer.h" + +namespace quic { +namespace test { + + +// static +void QuicSentPacketManagerPeer::SetPerspective( + QuicSentPacketManager* sent_packet_manager, Perspective perspective) { + QuicUnackedPacketMapPeer::SetPerspective( + &sent_packet_manager->unacked_packets_, perspective); +} + +// static +SendAlgorithmInterface* QuicSentPacketManagerPeer::GetSendAlgorithm( + const QuicSentPacketManager& sent_packet_manager) { + return sent_packet_manager.send_algorithm_.get(); +} + +// static +void QuicSentPacketManagerPeer::SetSendAlgorithm( + QuicSentPacketManager* sent_packet_manager, + SendAlgorithmInterface* send_algorithm) { + sent_packet_manager->SetSendAlgorithm(send_algorithm); +} + +// static +const LossDetectionInterface* QuicSentPacketManagerPeer::GetLossAlgorithm( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->loss_algorithm_; +} + +// static +void QuicSentPacketManagerPeer::SetLossAlgorithm( + QuicSentPacketManager* sent_packet_manager, + LossDetectionInterface* loss_detector) { + sent_packet_manager->loss_algorithm_ = loss_detector; +} + +// static +RttStats* QuicSentPacketManagerPeer::GetRttStats( + QuicSentPacketManager* sent_packet_manager) { + return &sent_packet_manager->rtt_stats_; +} + +// static +bool QuicSentPacketManagerPeer::IsRetransmission( + QuicSentPacketManager* sent_packet_manager, uint64_t packet_number) { + QUICHE_DCHECK(HasRetransmittableFrames(sent_packet_manager, packet_number)); + if (!HasRetransmittableFrames(sent_packet_manager, packet_number)) { + return false; + } + return sent_packet_manager->unacked_packets_ + .GetTransmissionInfo(QuicPacketNumber(packet_number)) + .transmission_type != NOT_RETRANSMISSION; +} + +// static +void QuicSentPacketManagerPeer::MarkForRetransmission( + QuicSentPacketManager* sent_packet_manager, uint64_t packet_number, + TransmissionType transmission_type) { + sent_packet_manager->MarkForRetransmission(QuicPacketNumber(packet_number), + transmission_type); +} + +// static +size_t QuicSentPacketManagerPeer::GetNumRetransmittablePackets( + const QuicSentPacketManager* sent_packet_manager) { + size_t num_unacked_packets = 0; + for (auto it = sent_packet_manager->unacked_packets_.begin(); + it != sent_packet_manager->unacked_packets_.end(); ++it) { + if (sent_packet_manager->unacked_packets_.HasRetransmittableFrames(*it)) { + ++num_unacked_packets; + } + } + return num_unacked_packets; +} + +// static +void QuicSentPacketManagerPeer::SetConsecutivePtoCount( + QuicSentPacketManager* sent_packet_manager, size_t count) { + sent_packet_manager->consecutive_pto_count_ = count; +} + +// static +QuicSustainedBandwidthRecorder& QuicSentPacketManagerPeer::GetBandwidthRecorder( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->sustained_bandwidth_recorder_; +} + +// static +bool QuicSentPacketManagerPeer::UsingPacing( + const QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->using_pacing_; +} + +// static +void QuicSentPacketManagerPeer::SetUsingPacing( + QuicSentPacketManager* sent_packet_manager, bool using_pacing) { + sent_packet_manager->using_pacing_ = using_pacing; +} + +// static +bool QuicSentPacketManagerPeer::HasRetransmittableFrames( + QuicSentPacketManager* sent_packet_manager, uint64_t packet_number) { + return sent_packet_manager->unacked_packets_.HasRetransmittableFrames( + QuicPacketNumber(packet_number)); +} + +// static +QuicUnackedPacketMap* QuicSentPacketManagerPeer::GetUnackedPacketMap( + QuicSentPacketManager* sent_packet_manager) { + return &sent_packet_manager->unacked_packets_; +} + +// static +void QuicSentPacketManagerPeer::DisablePacerBursts( + QuicSentPacketManager* sent_packet_manager) { + sent_packet_manager->pacing_sender_.burst_tokens_ = 0; + sent_packet_manager->pacing_sender_.initial_burst_size_ = 0; +} + +// static +int QuicSentPacketManagerPeer::GetPacerInitialBurstSize( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->pacing_sender_.initial_burst_size_; +} + +// static +void QuicSentPacketManagerPeer::SetNextPacedPacketTime( + QuicSentPacketManager* sent_packet_manager, QuicTime time) { + sent_packet_manager->pacing_sender_.ideal_next_packet_send_time_ = time; +} + +// static +int QuicSentPacketManagerPeer::GetReorderingShift( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->uber_loss_algorithm_.general_loss_algorithms_[0] + .reordering_shift(); +} + +// static +bool QuicSentPacketManagerPeer::AdaptiveReorderingThresholdEnabled( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->uber_loss_algorithm_.general_loss_algorithms_[0] + .use_adaptive_reordering_threshold(); +} + +// static +bool QuicSentPacketManagerPeer::AdaptiveTimeThresholdEnabled( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->uber_loss_algorithm_.general_loss_algorithms_[0] + .use_adaptive_time_threshold(); +} + +// static +bool QuicSentPacketManagerPeer::UsePacketThresholdForRuntPackets( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->uber_loss_algorithm_.general_loss_algorithms_[0] + .use_packet_threshold_for_runt_packets(); +} + +// static +int QuicSentPacketManagerPeer::GetNumPtosForPathDegrading( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->num_ptos_for_path_degrading_; +} + +// static +QuicEcnCounts* QuicSentPacketManagerPeer::GetPeerEcnCounts( + QuicSentPacketManager* sent_packet_manager, PacketNumberSpace space) { + return &(sent_packet_manager->peer_ack_ecn_counts_[space]); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_sent_packet_manager_peer.h b/quiche/quic/test_tools/quic_sent_packet_manager_peer.h new file mode 100644 index 000000000000..e9606194b940 --- /dev/null +++ b/quiche/quic/test_tools/quic_sent_packet_manager_peer.h @@ -0,0 +1,97 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SENT_PACKET_MANAGER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SENT_PACKET_MANAGER_PEER_H_ + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_sent_packet_manager.h" + +namespace quic { + +class SendAlgorithmInterface; + +namespace test { + +class QuicSentPacketManagerPeer { + public: + QuicSentPacketManagerPeer() = delete; + + + static void SetPerspective(QuicSentPacketManager* sent_packet_manager, + Perspective perspective); + + static SendAlgorithmInterface* GetSendAlgorithm( + const QuicSentPacketManager& sent_packet_manager); + + static void SetSendAlgorithm(QuicSentPacketManager* sent_packet_manager, + SendAlgorithmInterface* send_algorithm); + + static const LossDetectionInterface* GetLossAlgorithm( + QuicSentPacketManager* sent_packet_manager); + + static void SetLossAlgorithm(QuicSentPacketManager* sent_packet_manager, + LossDetectionInterface* loss_detector); + + static RttStats* GetRttStats(QuicSentPacketManager* sent_packet_manager); + + // Returns true if |packet_number| is a retransmission of a packet. + static bool IsRetransmission(QuicSentPacketManager* sent_packet_manager, + uint64_t packet_number); + + static void MarkForRetransmission(QuicSentPacketManager* sent_packet_manager, + uint64_t packet_number, + TransmissionType transmission_type); + + static size_t GetNumRetransmittablePackets( + const QuicSentPacketManager* sent_packet_manager); + + static void SetConsecutivePtoCount(QuicSentPacketManager* sent_packet_manager, + size_t count); + + static QuicSustainedBandwidthRecorder& GetBandwidthRecorder( + QuicSentPacketManager* sent_packet_manager); + + static void SetUsingPacing(QuicSentPacketManager* sent_packet_manager, + bool using_pacing); + + static bool UsingPacing(const QuicSentPacketManager* sent_packet_manager); + + static bool HasRetransmittableFrames( + QuicSentPacketManager* sent_packet_manager, uint64_t packet_number); + + static QuicUnackedPacketMap* GetUnackedPacketMap( + QuicSentPacketManager* sent_packet_manager); + + static void DisablePacerBursts(QuicSentPacketManager* sent_packet_manager); + + static int GetPacerInitialBurstSize( + QuicSentPacketManager* sent_packet_manager); + + static void SetNextPacedPacketTime(QuicSentPacketManager* sent_packet_manager, + QuicTime time); + + static int GetReorderingShift(QuicSentPacketManager* sent_packet_manager); + + static bool AdaptiveReorderingThresholdEnabled( + QuicSentPacketManager* sent_packet_manager); + + static bool AdaptiveTimeThresholdEnabled( + QuicSentPacketManager* sent_packet_manager); + + static bool UsePacketThresholdForRuntPackets( + QuicSentPacketManager* sent_packet_manager); + + static int GetNumPtosForPathDegrading( + QuicSentPacketManager* sent_packet_manager); + + static QuicEcnCounts* GetPeerEcnCounts( + QuicSentPacketManager* sent_packet_manager, PacketNumberSpace space); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SENT_PACKET_MANAGER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_server_peer.cc b/quiche/quic/test_tools/quic_server_peer.cc new file mode 100644 index 000000000000..6f6c8f9069f1 --- /dev/null +++ b/quiche/quic/test_tools/quic_server_peer.cc @@ -0,0 +1,32 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_server_peer.h" + +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/core/quic_packet_reader.h" +#include "quiche/quic/tools/quic_server.h" + +namespace quic { +namespace test { + +// static +bool QuicServerPeer::SetSmallSocket(QuicServer* server) { + int size = 1024 * 10; + return setsockopt(server->fd_, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) != + -1; +} + +// static +QuicDispatcher* QuicServerPeer::GetDispatcher(QuicServer* server) { + return server->dispatcher_.get(); +} + +// static +void QuicServerPeer::SetReader(QuicServer* server, QuicPacketReader* reader) { + server->packet_reader_.reset(reader); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_server_peer.h b/quiche/quic/test_tools/quic_server_peer.h new file mode 100644 index 000000000000..29a36d45065f --- /dev/null +++ b/quiche/quic/test_tools/quic_server_peer.h @@ -0,0 +1,28 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SERVER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SERVER_PEER_H_ + +namespace quic { + +class QuicDispatcher; +class QuicServer; +class QuicPacketReader; + +namespace test { + +class QuicServerPeer { + public: + QuicServerPeer() = delete; + + static bool SetSmallSocket(QuicServer* server); + static QuicDispatcher* GetDispatcher(QuicServer* server); + static void SetReader(QuicServer* server, QuicPacketReader* reader); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SERVER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_server_session_base_peer.h b/quiche/quic/test_tools/quic_server_session_base_peer.h new file mode 100644 index 000000000000..c6b60c0d6d34 --- /dev/null +++ b/quiche/quic/test_tools/quic_server_session_base_peer.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SERVER_SESSION_BASE_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SERVER_SESSION_BASE_PEER_H_ + +#include "quiche/quic/core/http/quic_server_session_base.h" +#include "quiche/quic/core/quic_utils.h" + +namespace quic { +namespace test { + +class QuicServerSessionBasePeer { + public: + static QuicStream* GetOrCreateStream(QuicServerSessionBase* s, + QuicStreamId id) { + return s->GetOrCreateStream(id); + } + static void SetCryptoStream(QuicServerSessionBase* s, + QuicCryptoServerStreamBase* crypto_stream) { + s->crypto_stream_.reset(crypto_stream); + } + static bool IsBandwidthResumptionEnabled(QuicServerSessionBase* s) { + return s->bandwidth_resumption_enabled_; + } +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SERVER_SESSION_BASE_PEER_H_ diff --git a/quiche/quic/test_tools/quic_session_peer.cc b/quiche/quic/test_tools/quic_session_peer.cc new file mode 100644 index 000000000000..3986a0e6c780 --- /dev/null +++ b/quiche/quic/test_tools/quic_session_peer.cc @@ -0,0 +1,246 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_session_peer.h" + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_utils.h" + +namespace quic { +namespace test { + +// static +QuicStreamId QuicSessionPeer::GetNextOutgoingBidirectionalStreamId( + QuicSession* session) { + return session->GetNextOutgoingBidirectionalStreamId(); +} + +// static +QuicStreamId QuicSessionPeer::GetNextOutgoingUnidirectionalStreamId( + QuicSession* session) { + return session->GetNextOutgoingUnidirectionalStreamId(); +} + +// static +void QuicSessionPeer::SetNextOutgoingBidirectionalStreamId(QuicSession* session, + QuicStreamId id) { + if (VersionHasIetfQuicFrames(session->transport_version())) { + session->ietf_streamid_manager_.bidirectional_stream_id_manager_ + .next_outgoing_stream_id_ = id; + return; + } + session->stream_id_manager_.next_outgoing_stream_id_ = id; +} + +// static +void QuicSessionPeer::SetMaxOpenIncomingStreams(QuicSession* session, + uint32_t max_streams) { + if (VersionHasIetfQuicFrames(session->transport_version())) { + QUIC_BUG(quic_bug_10193_1) + << "SetmaxOpenIncomingStreams deprecated for IETF QUIC"; + session->ietf_streamid_manager_.SetMaxOpenIncomingUnidirectionalStreams( + max_streams); + session->ietf_streamid_manager_.SetMaxOpenIncomingBidirectionalStreams( + max_streams); + return; + } + session->stream_id_manager_.set_max_open_incoming_streams(max_streams); +} + +// static +void QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams( + QuicSession* session, uint32_t max_streams) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(session->transport_version())) + << "SetmaxOpenIncomingBidirectionalStreams not supported for Google " + "QUIC"; + session->ietf_streamid_manager_.SetMaxOpenIncomingBidirectionalStreams( + max_streams); +} +// static +void QuicSessionPeer::SetMaxOpenIncomingUnidirectionalStreams( + QuicSession* session, uint32_t max_streams) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(session->transport_version())) + << "SetmaxOpenIncomingUnidirectionalStreams not supported for Google " + "QUIC"; + session->ietf_streamid_manager_.SetMaxOpenIncomingUnidirectionalStreams( + max_streams); +} + +// static +void QuicSessionPeer::SetMaxOpenOutgoingStreams(QuicSession* session, + uint32_t max_streams) { + if (VersionHasIetfQuicFrames(session->transport_version())) { + QUIC_BUG(quic_bug_10193_2) + << "SetmaxOpenOutgoingStreams deprecated for IETF QUIC"; + return; + } + session->stream_id_manager_.set_max_open_outgoing_streams(max_streams); +} + +// static +void QuicSessionPeer::SetMaxOpenOutgoingBidirectionalStreams( + QuicSession* session, uint32_t max_streams) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(session->transport_version())) + << "SetmaxOpenOutgoingBidirectionalStreams not supported for Google " + "QUIC"; + session->ietf_streamid_manager_.MaybeAllowNewOutgoingBidirectionalStreams( + max_streams); +} +// static +void QuicSessionPeer::SetMaxOpenOutgoingUnidirectionalStreams( + QuicSession* session, uint32_t max_streams) { + QUICHE_DCHECK(VersionHasIetfQuicFrames(session->transport_version())) + << "SetmaxOpenOutgoingUnidirectionalStreams not supported for Google " + "QUIC"; + session->ietf_streamid_manager_.MaybeAllowNewOutgoingUnidirectionalStreams( + max_streams); +} + +// static +QuicCryptoStream* QuicSessionPeer::GetMutableCryptoStream( + QuicSession* session) { + return session->GetMutableCryptoStream(); +} + +// static +QuicWriteBlockedListInterface* QuicSessionPeer::GetWriteBlockedStreams( + QuicSession* session) { + return session->write_blocked_streams(); +} + +// static +QuicStream* QuicSessionPeer::GetOrCreateStream(QuicSession* session, + QuicStreamId stream_id) { + return session->GetOrCreateStream(stream_id); +} + +// static +absl::flat_hash_map& +QuicSessionPeer::GetLocallyClosedStreamsHighestOffset(QuicSession* session) { + return session->locally_closed_streams_highest_offset_; +} + +// static +QuicSession::StreamMap& QuicSessionPeer::stream_map(QuicSession* session) { + return session->stream_map_; +} + +// static +const QuicSession::ClosedStreams& QuicSessionPeer::closed_streams( + QuicSession* session) { + return *session->closed_streams(); +} + +// static +void QuicSessionPeer::ActivateStream(QuicSession* session, + std::unique_ptr stream) { + return session->ActivateStream(std::move(stream)); +} + +// static +bool QuicSessionPeer::IsStreamClosed(QuicSession* session, QuicStreamId id) { + return session->IsClosedStream(id); +} + +// static +bool QuicSessionPeer::IsStreamCreated(QuicSession* session, QuicStreamId id) { + return session->stream_map_.contains(id); +} + +// static +bool QuicSessionPeer::IsStreamAvailable(QuicSession* session, QuicStreamId id) { + if (VersionHasIetfQuicFrames(session->transport_version())) { + if (id % QuicUtils::StreamIdDelta(session->transport_version()) < 2) { + return session->ietf_streamid_manager_.bidirectional_stream_id_manager_ + .available_streams_.contains(id); + } + return session->ietf_streamid_manager_.unidirectional_stream_id_manager_ + .available_streams_.contains(id); + } + return session->stream_id_manager_.available_streams_.contains(id); +} + +// static +QuicStream* QuicSessionPeer::GetStream(QuicSession* session, QuicStreamId id) { + return session->GetStream(id); +} + +// static +bool QuicSessionPeer::IsStreamWriteBlocked(QuicSession* session, + QuicStreamId id) { + return session->write_blocked_streams()->IsStreamBlocked(id); +} + +// static +QuicAlarm* QuicSessionPeer::GetCleanUpClosedStreamsAlarm(QuicSession* session) { + return session->closed_streams_clean_up_alarm_.get(); +} + +// static +LegacyQuicStreamIdManager* QuicSessionPeer::GetStreamIdManager( + QuicSession* session) { + return &session->stream_id_manager_; +} + +// static +UberQuicStreamIdManager* QuicSessionPeer::ietf_streamid_manager( + QuicSession* session) { + return &session->ietf_streamid_manager_; +} + +// static +QuicStreamIdManager* QuicSessionPeer::ietf_bidirectional_stream_id_manager( + QuicSession* session) { + return &session->ietf_streamid_manager_.bidirectional_stream_id_manager_; +} + +// static +QuicStreamIdManager* QuicSessionPeer::ietf_unidirectional_stream_id_manager( + QuicSession* session) { + return &session->ietf_streamid_manager_.unidirectional_stream_id_manager_; +} + +// static +PendingStream* QuicSessionPeer::GetPendingStream(QuicSession* session, + QuicStreamId stream_id) { + auto it = session->pending_stream_map_.find(stream_id); + return it == session->pending_stream_map_.end() ? nullptr : it->second.get(); +} + +// static +void QuicSessionPeer::set_is_configured(QuicSession* session, bool value) { + session->is_configured_ = value; +} + +// static +void QuicSessionPeer::SetPerspective(QuicSession* session, + Perspective perspective) { + session->perspective_ = perspective; +} + +// static +size_t QuicSessionPeer::GetNumOpenDynamicStreams(QuicSession* session) { + size_t result = 0; + for (const auto& it : session->stream_map_) { + if (!it.second->is_static()) { + ++result; + } + } + // Exclude draining streams. + result -= session->num_draining_streams_; + // Add locally closed streams. + result += session->locally_closed_streams_highest_offset_.size(); + + return result; +} + +// static +size_t QuicSessionPeer::GetNumDrainingStreams(QuicSession* session) { + return session->num_draining_streams_; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_session_peer.h b/quiche/quic/test_tools/quic_session_peer.h new file mode 100644 index 000000000000..f0e83c90041a --- /dev/null +++ b/quiche/quic/test_tools/quic_session_peer.h @@ -0,0 +1,96 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SESSION_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SESSION_PEER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/core/quic_write_blocked_list.h" + +namespace quic { + +class QuicCryptoStream; +class QuicSession; +class QuicStream; + +namespace test { + +class QuicSessionPeer { + public: + QuicSessionPeer() = delete; + + static QuicStreamId GetNextOutgoingBidirectionalStreamId( + QuicSession* session); + static QuicStreamId GetNextOutgoingUnidirectionalStreamId( + QuicSession* session); + static void SetNextOutgoingBidirectionalStreamId(QuicSession* session, + QuicStreamId id); + // Following is only for Google-QUIC, will QUIC_BUG if called for IETF + // QUIC. + static void SetMaxOpenIncomingStreams(QuicSession* session, + uint32_t max_streams); + // Following two are only for IETF-QUIC, will QUIC_BUG if called for Google + // QUIC. + static void SetMaxOpenIncomingBidirectionalStreams(QuicSession* session, + uint32_t max_streams); + static void SetMaxOpenIncomingUnidirectionalStreams(QuicSession* session, + uint32_t max_streams); + + static void SetMaxOpenOutgoingStreams(QuicSession* session, + uint32_t max_streams); + static void SetMaxOpenOutgoingBidirectionalStreams(QuicSession* session, + uint32_t max_streams); + static void SetMaxOpenOutgoingUnidirectionalStreams(QuicSession* session, + uint32_t max_streams); + + static QuicCryptoStream* GetMutableCryptoStream(QuicSession* session); + static QuicWriteBlockedListInterface* GetWriteBlockedStreams( + QuicSession* session); + static QuicStream* GetOrCreateStream(QuicSession* session, + QuicStreamId stream_id); + static absl::flat_hash_map& + GetLocallyClosedStreamsHighestOffset(QuicSession* session); + static QuicSession::StreamMap& stream_map(QuicSession* session); + static const QuicSession::ClosedStreams& closed_streams(QuicSession* session); + static void ActivateStream(QuicSession* session, + std::unique_ptr stream); + + // Discern the state of a stream. Exactly one of these should be true at a + // time for any stream id > 0 (other than the special streams 1 and 3). + static bool IsStreamClosed(QuicSession* session, QuicStreamId id); + static bool IsStreamCreated(QuicSession* session, QuicStreamId id); + static bool IsStreamAvailable(QuicSession* session, QuicStreamId id); + + static QuicStream* GetStream(QuicSession* session, QuicStreamId id); + static bool IsStreamWriteBlocked(QuicSession* session, QuicStreamId id); + static QuicAlarm* GetCleanUpClosedStreamsAlarm(QuicSession* session); + static LegacyQuicStreamIdManager* GetStreamIdManager(QuicSession* session); + static UberQuicStreamIdManager* ietf_streamid_manager(QuicSession* session); + static QuicStreamIdManager* ietf_bidirectional_stream_id_manager( + QuicSession* session); + static QuicStreamIdManager* ietf_unidirectional_stream_id_manager( + QuicSession* session); + static PendingStream* GetPendingStream(QuicSession* session, + QuicStreamId stream_id); + static void set_is_configured(QuicSession* session, bool value); + static void SetPerspective(QuicSession* session, Perspective perspective); + static size_t GetNumOpenDynamicStreams(QuicSession* session); + static size_t GetNumDrainingStreams(QuicSession* session); + static QuicStreamId GetLargestPeerCreatedStreamId(QuicSession* session, + bool unidirectional) { + return session->GetLargestPeerCreatedStreamId(unidirectional); + } +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SESSION_PEER_H_ diff --git a/quiche/quic/test_tools/quic_spdy_session_peer.cc b/quiche/quic/test_tools/quic_spdy_session_peer.cc new file mode 100644 index 000000000000..ea9206709556 --- /dev/null +++ b/quiche/quic/test_tools/quic_spdy_session_peer.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" + +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/qpack/qpack_receive_stream.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace quic { +namespace test { + +// static +QuicHeadersStream* QuicSpdySessionPeer::GetHeadersStream( + QuicSpdySession* session) { + QUICHE_DCHECK(!VersionUsesHttp3(session->transport_version())); + return session->headers_stream(); +} + +void QuicSpdySessionPeer::SetHeadersStream(QuicSpdySession* session, + QuicHeadersStream* headers_stream) { + QUICHE_DCHECK(!VersionUsesHttp3(session->transport_version())); + for (auto& it : QuicSessionPeer::stream_map(session)) { + if (it.first == + QuicUtils::GetHeadersStreamId(session->transport_version())) { + it.second.reset(headers_stream); + session->headers_stream_ = static_cast(it.second.get()); + break; + } + } +} + +// static +spdy::SpdyFramer* QuicSpdySessionPeer::GetSpdyFramer(QuicSpdySession* session) { + return &session->spdy_framer_; +} + +void QuicSpdySessionPeer::SetMaxInboundHeaderListSize( + QuicSpdySession* session, size_t max_inbound_header_size) { + session->set_max_inbound_header_list_size(max_inbound_header_size); +} + +// static +size_t QuicSpdySessionPeer::WriteHeadersOnHeadersStream( + QuicSpdySession* session, QuicStreamId id, spdy::Http2HeaderBlock headers, + bool fin, const spdy::SpdyStreamPrecedence& precedence, + quiche::QuicheReferenceCountedPointer + ack_listener) { + return session->WriteHeadersOnHeadersStream( + id, std::move(headers), fin, precedence, std::move(ack_listener)); +} + +// static +QuicStreamId QuicSpdySessionPeer::GetNextOutgoingUnidirectionalStreamId( + QuicSpdySession* session) { + return session->GetNextOutgoingUnidirectionalStreamId(); +} + +// static +QuicReceiveControlStream* QuicSpdySessionPeer::GetReceiveControlStream( + QuicSpdySession* session) { + return session->receive_control_stream_; +} + +// static +QuicSendControlStream* QuicSpdySessionPeer::GetSendControlStream( + QuicSpdySession* session) { + return session->send_control_stream_; +} + +// static +QpackSendStream* QuicSpdySessionPeer::GetQpackDecoderSendStream( + QuicSpdySession* session) { + return session->qpack_decoder_send_stream_; +} + +// static +QpackSendStream* QuicSpdySessionPeer::GetQpackEncoderSendStream( + QuicSpdySession* session) { + return session->qpack_encoder_send_stream_; +} + +// static +QpackReceiveStream* QuicSpdySessionPeer::GetQpackDecoderReceiveStream( + QuicSpdySession* session) { + return session->qpack_decoder_receive_stream_; +} + +// static +QpackReceiveStream* QuicSpdySessionPeer::GetQpackEncoderReceiveStream( + QuicSpdySession* session) { + return session->qpack_encoder_receive_stream_; +} + +// static +void QuicSpdySessionPeer::SetHttpDatagramSupport( + QuicSpdySession* session, HttpDatagramSupport http_datagram_support) { + session->http_datagram_support_ = http_datagram_support; +} + +// static +HttpDatagramSupport QuicSpdySessionPeer::LocalHttpDatagramSupport( + QuicSpdySession* session) { + return session->LocalHttpDatagramSupport(); +} + +// static +void QuicSpdySessionPeer::EnableWebTransport(QuicSpdySession* session) { + QUICHE_DCHECK(session->WillNegotiateWebTransport()); + SetHttpDatagramSupport(session, HttpDatagramSupport::kDraft04); + session->peer_supports_webtransport_ = true; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_spdy_session_peer.h b/quiche/quic/test_tools/quic_spdy_session_peer.h new file mode 100644 index 000000000000..87d38eee2ddb --- /dev/null +++ b/quiche/quic/test_tools/quic_spdy_session_peer.h @@ -0,0 +1,62 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SPDY_SESSION_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SPDY_SESSION_PEER_H_ + +#include "quiche/quic/core/http/quic_receive_control_stream.h" +#include "quiche/quic/core/http/quic_send_control_stream.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/qpack/qpack_receive_stream.h" +#include "quiche/quic/core/qpack/qpack_send_stream.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_write_blocked_list.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +class QuicHeadersStream; + +namespace test { + +class QuicSpdySessionPeer { + public: + QuicSpdySessionPeer() = delete; + + static QuicHeadersStream* GetHeadersStream(QuicSpdySession* session); + static void SetHeadersStream(QuicSpdySession* session, + QuicHeadersStream* headers_stream); + static spdy::SpdyFramer* GetSpdyFramer(QuicSpdySession* session); + // Must be called before Initialize(). + static void SetMaxInboundHeaderListSize(QuicSpdySession* session, + size_t max_inbound_header_size); + static size_t WriteHeadersOnHeadersStream( + QuicSpdySession* session, QuicStreamId id, spdy::Http2HeaderBlock headers, + bool fin, const spdy::SpdyStreamPrecedence& precedence, + quiche::QuicheReferenceCountedPointer + ack_listener); + // |session| can't be nullptr. + static QuicStreamId GetNextOutgoingUnidirectionalStreamId( + QuicSpdySession* session); + static QuicReceiveControlStream* GetReceiveControlStream( + QuicSpdySession* session); + static QuicSendControlStream* GetSendControlStream(QuicSpdySession* session); + static QpackSendStream* GetQpackDecoderSendStream(QuicSpdySession* session); + static QpackSendStream* GetQpackEncoderSendStream(QuicSpdySession* session); + static QpackReceiveStream* GetQpackDecoderReceiveStream( + QuicSpdySession* session); + static QpackReceiveStream* GetQpackEncoderReceiveStream( + QuicSpdySession* session); + static void SetHttpDatagramSupport(QuicSpdySession* session, + HttpDatagramSupport http_datagram_support); + static HttpDatagramSupport LocalHttpDatagramSupport(QuicSpdySession* session); + static void EnableWebTransport(QuicSpdySession* session); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SPDY_SESSION_PEER_H_ diff --git a/quiche/quic/test_tools/quic_spdy_stream_peer.cc b/quiche/quic/test_tools/quic_spdy_stream_peer.cc new file mode 100644 index 000000000000..15806b372ab0 --- /dev/null +++ b/quiche/quic/test_tools/quic_spdy_stream_peer.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_spdy_stream_peer.h" + +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { +namespace test { + +// static +void QuicSpdyStreamPeer::set_ack_listener( + QuicSpdyStream* stream, + quiche::QuicheReferenceCountedPointer + ack_listener) { + stream->set_ack_listener(std::move(ack_listener)); +} + +// static +const QuicIntervalSet& +QuicSpdyStreamPeer::unacked_frame_headers_offsets(QuicSpdyStream* stream) { + return stream->unacked_frame_headers_offsets_; +} + +// static +bool QuicSpdyStreamPeer::OnHeadersFrameEnd(QuicSpdyStream* stream) { + return stream->OnHeadersFrameEnd(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_spdy_stream_peer.h b/quiche/quic/test_tools/quic_spdy_stream_peer.h new file mode 100644 index 000000000000..8b5083d26200 --- /dev/null +++ b/quiche/quic/test_tools/quic_spdy_stream_peer.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SPDY_STREAM_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SPDY_STREAM_PEER_H_ + +#include "quiche/quic/core/quic_ack_listener_interface.h" +#include "quiche/quic/core/quic_interval_set.h" + +namespace quic { + +class QpackDecodedHeadersAccumulator; +class QuicSpdyStream; + +namespace test { + +class QuicSpdyStreamPeer { + public: + static void set_ack_listener( + QuicSpdyStream* stream, + quiche::QuicheReferenceCountedPointer + ack_listener); + static const QuicIntervalSet& unacked_frame_headers_offsets( + QuicSpdyStream* stream); + static bool OnHeadersFrameEnd(QuicSpdyStream* stream); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SPDY_STREAM_PEER_H_ diff --git a/quiche/quic/test_tools/quic_stream_id_manager_peer.cc b/quiche/quic/test_tools/quic_stream_id_manager_peer.cc new file mode 100644 index 000000000000..eaf3b00f7f27 --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_id_manager_peer.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "quiche/quic/test_tools/quic_stream_id_manager_peer.h" + +#include "quiche/quic/core/quic_stream_id_manager.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/uber_quic_stream_id_manager.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace test { + +// static +void QuicStreamIdManagerPeer::set_incoming_actual_max_streams( + QuicStreamIdManager* stream_id_manager, QuicStreamCount count) { + stream_id_manager->incoming_actual_max_streams_ = count; +} + +// static +void QuicStreamIdManagerPeer::set_outgoing_max_streams( + QuicStreamIdManager* stream_id_manager, QuicStreamCount count) { + stream_id_manager->outgoing_max_streams_ = count; +} + +// static +QuicStreamId QuicStreamIdManagerPeer::GetFirstIncomingStreamId( + QuicStreamIdManager* stream_id_manager) { + return stream_id_manager->GetFirstIncomingStreamId(); +} + +// static +bool QuicStreamIdManagerPeer::get_unidirectional( + QuicStreamIdManager* stream_id_manager) { + return stream_id_manager->unidirectional_; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_stream_id_manager_peer.h b/quiche/quic/test_tools/quic_stream_id_manager_peer.h new file mode 100644 index 000000000000..509d062e14d6 --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_id_manager_peer.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_ID_MANAGER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_ID_MANAGER_PEER_H_ + +#include + +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +class QuicStreamIdManager; +class UberQuicStreamIdManager; + +namespace test { + +class QuicStreamIdManagerPeer { + public: + QuicStreamIdManagerPeer() = delete; + + static void set_incoming_actual_max_streams( + QuicStreamIdManager* stream_id_manager, QuicStreamCount count); + static void set_outgoing_max_streams(QuicStreamIdManager* stream_id_manager, + QuicStreamCount count); + + static QuicStreamId GetFirstIncomingStreamId( + QuicStreamIdManager* stream_id_manager); + + static bool get_unidirectional(QuicStreamIdManager* stream_id_manager); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_ID_MANAGER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_stream_peer.cc b/quiche/quic/test_tools/quic_stream_peer.cc new file mode 100644 index 000000000000..0fdaa58dd9ea --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_peer.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_stream_peer.h" + +#include + +#include "quiche/quic/core/quic_stream.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/test_tools/quic_flow_controller_peer.h" +#include "quiche/quic/test_tools/quic_stream_send_buffer_peer.h" + +namespace quic { +namespace test { + +// static +void QuicStreamPeer::SetWriteSideClosed(bool value, QuicStream* stream) { + stream->write_side_closed_ = value; +} + +// static +void QuicStreamPeer::SetStreamBytesWritten( + QuicStreamOffset stream_bytes_written, QuicStream* stream) { + stream->send_buffer_.stream_bytes_written_ = stream_bytes_written; + stream->send_buffer_.stream_bytes_outstanding_ = stream_bytes_written; + QuicStreamSendBufferPeer::SetStreamOffset(&stream->send_buffer_, + stream_bytes_written); +} + +// static +void QuicStreamPeer::SetSendWindowOffset(QuicStream* stream, + QuicStreamOffset offset) { + QuicFlowControllerPeer::SetSendWindowOffset(&*stream->flow_controller_, + offset); +} + +// static +QuicByteCount QuicStreamPeer::bytes_consumed(QuicStream* stream) { + return stream->flow_controller_->bytes_consumed(); +} + +// static +void QuicStreamPeer::SetReceiveWindowOffset(QuicStream* stream, + QuicStreamOffset offset) { + QuicFlowControllerPeer::SetReceiveWindowOffset(&*stream->flow_controller_, + offset); +} + +// static +void QuicStreamPeer::SetMaxReceiveWindow(QuicStream* stream, + QuicStreamOffset size) { + QuicFlowControllerPeer::SetMaxReceiveWindow(&*stream->flow_controller_, size); +} + +// static +QuicByteCount QuicStreamPeer::SendWindowSize(QuicStream* stream) { + return stream->flow_controller_->SendWindowSize(); +} + +// static +QuicStreamOffset QuicStreamPeer::ReceiveWindowOffset(QuicStream* stream) { + return QuicFlowControllerPeer::ReceiveWindowOffset( + &*stream->flow_controller_); +} + +// static +QuicByteCount QuicStreamPeer::ReceiveWindowSize(QuicStream* stream) { + return QuicFlowControllerPeer::ReceiveWindowSize(&*stream->flow_controller_); +} + +// static +QuicStreamOffset QuicStreamPeer::SendWindowOffset(QuicStream* stream) { + return stream->flow_controller_->send_window_offset(); +} + +// static +bool QuicStreamPeer::read_side_closed(QuicStream* stream) { + return stream->read_side_closed_; +} + +// static +void QuicStreamPeer::CloseReadSide(QuicStream* stream) { + stream->CloseReadSide(); +} + +// static +bool QuicStreamPeer::StreamContributesToConnectionFlowControl( + QuicStream* stream) { + return stream->stream_contributes_to_connection_flow_control_; +} + +// static +QuicStreamSequencer* QuicStreamPeer::sequencer(QuicStream* stream) { + return &(stream->sequencer_); +} + +// static +QuicSession* QuicStreamPeer::session(QuicStream* stream) { + return stream->session(); +} + +// static +QuicStreamSendBuffer& QuicStreamPeer::SendBuffer(QuicStream* stream) { + return stream->send_buffer_; +} + +// static +void QuicStreamPeer::SetFinReceived(QuicStream* stream) { + stream->fin_received_ = true; +} + +// static +void QuicStreamPeer::SetFinSent(QuicStream* stream) { + stream->fin_sent_ = true; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_stream_peer.h b/quiche/quic/test_tools/quic_stream_peer.h new file mode 100644 index 000000000000..3525debcad0d --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_peer.h @@ -0,0 +1,55 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_PEER_H_ + +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_send_buffer.h" +#include "quiche/quic/core/quic_stream_sequencer.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +class QuicStream; +class QuicSession; + +namespace test { + +class QuicStreamPeer { + public: + QuicStreamPeer() = delete; + + static void SetWriteSideClosed(bool value, QuicStream* stream); + static void SetStreamBytesWritten(QuicStreamOffset stream_bytes_written, + QuicStream* stream); + static void SetSendWindowOffset(QuicStream* stream, QuicStreamOffset offset); + static void SetReceiveWindowOffset(QuicStream* stream, + QuicStreamOffset offset); + static void SetMaxReceiveWindow(QuicStream* stream, QuicStreamOffset size); + static bool read_side_closed(QuicStream* stream); + static void CloseReadSide(QuicStream* stream); + static QuicByteCount bytes_consumed(QuicStream* stream); + static QuicByteCount ReceiveWindowSize(QuicStream* stream); + static QuicByteCount SendWindowSize(QuicStream* stream); + static QuicStreamOffset SendWindowOffset(QuicStream* stream); + static QuicStreamOffset ReceiveWindowOffset(QuicStream* stream); + + static bool StreamContributesToConnectionFlowControl(QuicStream* stream); + + static QuicStreamSequencer* sequencer(QuicStream* stream); + static QuicSession* session(QuicStream* stream); + static void SetFinReceived(QuicStream* stream); + static void SetFinSent(QuicStream* stream); + + static QuicStreamSendBuffer& SendBuffer(QuicStream* stream); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_PEER_H_ diff --git a/quiche/quic/test_tools/quic_stream_send_buffer_peer.cc b/quiche/quic/test_tools/quic_stream_send_buffer_peer.cc new file mode 100644 index 000000000000..c81b45ec5b76 --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_send_buffer_peer.cc @@ -0,0 +1,54 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_stream_send_buffer_peer.h" + +#include "quiche/quic/test_tools/quic_interval_deque_peer.h" + +namespace quic { + +namespace test { + +// static +void QuicStreamSendBufferPeer::SetStreamOffset( + QuicStreamSendBuffer* send_buffer, QuicStreamOffset stream_offset) { + send_buffer->stream_offset_ = stream_offset; +} + +// static +const BufferedSlice* QuicStreamSendBufferPeer::CurrentWriteSlice( + QuicStreamSendBuffer* send_buffer) { + auto wi = write_index(send_buffer); + + if (wi == -1) { + return nullptr; + } + return QuicIntervalDequePeer::GetItem(&send_buffer->interval_deque_, wi); +} + +QuicStreamOffset QuicStreamSendBufferPeer::EndOffset( + QuicStreamSendBuffer* send_buffer) { + return send_buffer->current_end_offset_; +} + +// static +QuicByteCount QuicStreamSendBufferPeer::TotalLength( + QuicStreamSendBuffer* send_buffer) { + QuicByteCount length = 0; + for (auto slice = send_buffer->interval_deque_.DataBegin(); + slice != send_buffer->interval_deque_.DataEnd(); ++slice) { + length += slice->slice.length(); + } + return length; +} + +// static +int32_t QuicStreamSendBufferPeer::write_index( + QuicStreamSendBuffer* send_buffer) { + return QuicIntervalDequePeer::GetCachedIndex(&send_buffer->interval_deque_); +} + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/test_tools/quic_stream_send_buffer_peer.h b/quiche/quic/test_tools/quic_stream_send_buffer_peer.h new file mode 100644 index 000000000000..979229023c1d --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_send_buffer_peer.h @@ -0,0 +1,33 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEND_BUFFER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEND_BUFFER_PEER_H_ + +#include "quiche/quic/core/quic_stream_send_buffer.h" + +namespace quic { + +namespace test { + +class QuicStreamSendBufferPeer { + public: + static void SetStreamOffset(QuicStreamSendBuffer* send_buffer, + QuicStreamOffset stream_offset); + + static const BufferedSlice* CurrentWriteSlice( + QuicStreamSendBuffer* send_buffer); + + static QuicStreamOffset EndOffset(QuicStreamSendBuffer* send_buffer); + + static QuicByteCount TotalLength(QuicStreamSendBuffer* send_buffer); + + static int32_t write_index(QuicStreamSendBuffer* send_buffer); +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEND_BUFFER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc b/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc new file mode 100644 index 000000000000..679bd91ac540 --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc @@ -0,0 +1,163 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h" + +#include + +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +using BufferBlock = quic::QuicStreamSequencerBuffer::BufferBlock; + +static const size_t kBlockSizeBytes = + quic::QuicStreamSequencerBuffer::kBlockSizeBytes; + +namespace quic { +namespace test { + +QuicStreamSequencerBufferPeer::QuicStreamSequencerBufferPeer( + QuicStreamSequencerBuffer* buffer) + : buffer_(buffer) {} + +// Read from this buffer_ into the given destination buffer_ up to the +// size of the destination. Returns the number of bytes read. Reading from +// an empty buffer_->returns 0. +size_t QuicStreamSequencerBufferPeer::Read(char* dest_buffer, size_t size) { + iovec dest; + dest.iov_base = dest_buffer, dest.iov_len = size; + size_t bytes_read; + std::string error_details; + EXPECT_THAT(buffer_->Readv(&dest, 1, &bytes_read, &error_details), + IsQuicNoError()); + return bytes_read; +} + +// If buffer is empty, the blocks_ array must be empty, which means all +// blocks are deallocated. +bool QuicStreamSequencerBufferPeer::CheckEmptyInvariants() { + return !buffer_->Empty() || IsBlockArrayEmpty(); +} + +bool QuicStreamSequencerBufferPeer::IsBlockArrayEmpty() { + if (buffer_->blocks_ == nullptr) { + return true; + } + + size_t count = current_blocks_count(); + for (size_t i = 0; i < count; i++) { + if (buffer_->blocks_[i] != nullptr) { + return false; + } + } + return true; +} + +bool QuicStreamSequencerBufferPeer::CheckInitialState() { + EXPECT_TRUE(buffer_->Empty() && buffer_->total_bytes_read_ == 0 && + buffer_->num_bytes_buffered_ == 0); + return CheckBufferInvariants(); +} + +bool QuicStreamSequencerBufferPeer::CheckBufferInvariants() { + QuicStreamOffset data_span = + buffer_->NextExpectedByte() - buffer_->total_bytes_read_; + bool capacity_sane = data_span <= buffer_->max_buffer_capacity_bytes_ && + data_span >= buffer_->num_bytes_buffered_; + if (!capacity_sane) { + QUIC_LOG(ERROR) << "data span is larger than capacity."; + QUIC_LOG(ERROR) << "total read: " << buffer_->total_bytes_read_ + << " last byte: " << buffer_->NextExpectedByte(); + } + bool total_read_sane = + buffer_->FirstMissingByte() >= buffer_->total_bytes_read_; + if (!total_read_sane) { + QUIC_LOG(ERROR) << "read across 1st gap."; + } + bool read_offset_sane = buffer_->ReadOffset() < kBlockSizeBytes; + if (!capacity_sane) { + QUIC_LOG(ERROR) << "read offset go beyond 1st block"; + } + bool block_match_capacity = + (buffer_->max_buffer_capacity_bytes_ <= + buffer_->max_blocks_count_ * kBlockSizeBytes) && + (buffer_->max_buffer_capacity_bytes_ > + (buffer_->max_blocks_count_ - 1) * kBlockSizeBytes); + if (!capacity_sane) { + QUIC_LOG(ERROR) << "block number not match capcaity."; + } + bool block_retired_when_empty = CheckEmptyInvariants(); + if (!block_retired_when_empty) { + QUIC_LOG(ERROR) << "block is not retired after use."; + } + return capacity_sane && total_read_sane && read_offset_sane && + block_match_capacity && block_retired_when_empty; +} + +size_t QuicStreamSequencerBufferPeer::GetInBlockOffset( + QuicStreamOffset offset) { + return buffer_->GetInBlockOffset(offset); +} + +BufferBlock* QuicStreamSequencerBufferPeer::GetBlock(size_t index) { + return buffer_->blocks_[index]; +} + +int QuicStreamSequencerBufferPeer::IntervalSize() { + if (buffer_->bytes_received_.Empty()) { + return 1; + } + int gap_size = buffer_->bytes_received_.Size() + 1; + if (buffer_->bytes_received_.Empty()) { + return gap_size; + } + if (buffer_->bytes_received_.begin()->min() == 0) { + --gap_size; + } + if (buffer_->bytes_received_.rbegin()->max() == + std::numeric_limits::max()) { + --gap_size; + } + return gap_size; +} + +size_t QuicStreamSequencerBufferPeer::max_buffer_capacity() { + return buffer_->max_buffer_capacity_bytes_; +} + +size_t QuicStreamSequencerBufferPeer::ReadableBytes() { + return buffer_->ReadableBytes(); +} + +void QuicStreamSequencerBufferPeer::set_total_bytes_read( + QuicStreamOffset total_bytes_read) { + buffer_->total_bytes_read_ = total_bytes_read; +} + +void QuicStreamSequencerBufferPeer::AddBytesReceived(QuicStreamOffset offset, + QuicByteCount length) { + buffer_->bytes_received_.Add(offset, offset + length); +} + +bool QuicStreamSequencerBufferPeer::IsBufferAllocated() { + return buffer_->blocks_ != nullptr; +} + +size_t QuicStreamSequencerBufferPeer::max_blocks_count() { + return buffer_->max_blocks_count_; +} + +size_t QuicStreamSequencerBufferPeer::current_blocks_count() { + return buffer_->current_blocks_count_; +} + +const QuicIntervalSet& +QuicStreamSequencerBufferPeer::bytes_received() { + return buffer_->bytes_received_; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h b/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h new file mode 100644 index 000000000000..eac892a6923f --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h @@ -0,0 +1,65 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEQUENCER_BUFFER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEQUENCER_BUFFER_PEER_H_ + +#include "quiche/quic/core/quic_stream_sequencer_buffer.h" + +namespace quic { + +namespace test { + +class QuicStreamSequencerBufferPeer { + public: + explicit QuicStreamSequencerBufferPeer(QuicStreamSequencerBuffer* buffer); + QuicStreamSequencerBufferPeer(const QuicStreamSequencerBufferPeer&) = delete; + QuicStreamSequencerBufferPeer& operator=( + const QuicStreamSequencerBufferPeer&) = delete; + + // Read from this buffer_ into the given destination buffer_ up to the + // size of the destination. Returns the number of bytes read. Reading from + // an empty buffer_->returns 0. + size_t Read(char* dest_buffer, size_t size); + + // If buffer is empty, the blocks_ array must be empty, which means all + // blocks are deallocated. + bool CheckEmptyInvariants(); + + bool IsBlockArrayEmpty(); + + bool CheckInitialState(); + + bool CheckBufferInvariants(); + + size_t GetInBlockOffset(QuicStreamOffset offset); + + QuicStreamSequencerBuffer::BufferBlock* GetBlock(size_t index); + + int IntervalSize(); + + size_t max_buffer_capacity(); + + size_t ReadableBytes(); + + void set_total_bytes_read(QuicStreamOffset total_bytes_read); + + void AddBytesReceived(QuicStreamOffset offset, QuicByteCount length); + + bool IsBufferAllocated(); + + size_t max_blocks_count(); + + size_t current_blocks_count(); + + const QuicIntervalSet& bytes_received(); + + private: + QuicStreamSequencerBuffer* buffer_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEQUENCER_BUFFER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_stream_sequencer_peer.cc b/quiche/quic/test_tools/quic_stream_sequencer_peer.cc new file mode 100644 index 000000000000..3ba4f626e595 --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_sequencer_peer.cc @@ -0,0 +1,39 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_stream_sequencer_peer.h" + +#include "quiche/quic/core/quic_stream_sequencer.h" +#include "quiche/quic/test_tools/quic_stream_sequencer_buffer_peer.h" + +namespace quic { +namespace test { + +// static +size_t QuicStreamSequencerPeer::GetNumBufferedBytes( + QuicStreamSequencer* sequencer) { + return sequencer->buffered_frames_.BytesBuffered(); +} + +// static +QuicStreamOffset QuicStreamSequencerPeer::GetCloseOffset( + QuicStreamSequencer* sequencer) { + return sequencer->close_offset_; +} + +// static +bool QuicStreamSequencerPeer::IsUnderlyingBufferAllocated( + QuicStreamSequencer* sequencer) { + QuicStreamSequencerBufferPeer buffer_peer(&(sequencer->buffered_frames_)); + return buffer_peer.IsBufferAllocated(); +} + +// static +void QuicStreamSequencerPeer::SetFrameBufferTotalBytesRead( + QuicStreamSequencer* sequencer, QuicStreamOffset total_bytes_read) { + QuicStreamSequencerBufferPeer buffer_peer(&(sequencer->buffered_frames_)); + buffer_peer.set_total_bytes_read(total_bytes_read); +} +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_stream_sequencer_peer.h b/quiche/quic/test_tools/quic_stream_sequencer_peer.h new file mode 100644 index 000000000000..4be113f2f872 --- /dev/null +++ b/quiche/quic/test_tools/quic_stream_sequencer_peer.h @@ -0,0 +1,33 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEQUENCER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEQUENCER_PEER_H_ + +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +class QuicStreamSequencer; + +namespace test { + +class QuicStreamSequencerPeer { + public: + QuicStreamSequencerPeer() = delete; + + static size_t GetNumBufferedBytes(QuicStreamSequencer* sequencer); + + static QuicStreamOffset GetCloseOffset(QuicStreamSequencer* sequencer); + + static bool IsUnderlyingBufferAllocated(QuicStreamSequencer* sequencer); + + static void SetFrameBufferTotalBytesRead(QuicStreamSequencer* sequencer, + QuicStreamOffset total_bytes_read); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_STREAM_SEQUENCER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc b/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc new file mode 100644 index 000000000000..46fa83c94209 --- /dev/null +++ b/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.cc @@ -0,0 +1,34 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h" + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_sustained_bandwidth_recorder.h" + +namespace quic { +namespace test { + +// static +void QuicSustainedBandwidthRecorderPeer::SetBandwidthEstimate( + QuicSustainedBandwidthRecorder* bandwidth_recorder, + int32_t bandwidth_estimate_kbytes_per_second) { + bandwidth_recorder->has_estimate_ = true; + bandwidth_recorder->bandwidth_estimate_ = + QuicBandwidth::FromKBytesPerSecond(bandwidth_estimate_kbytes_per_second); +} + +// static +void QuicSustainedBandwidthRecorderPeer::SetMaxBandwidthEstimate( + QuicSustainedBandwidthRecorder* bandwidth_recorder, + int32_t max_bandwidth_estimate_kbytes_per_second, + int32_t max_bandwidth_timestamp) { + bandwidth_recorder->max_bandwidth_estimate_ = + QuicBandwidth::FromKBytesPerSecond( + max_bandwidth_estimate_kbytes_per_second); + bandwidth_recorder->max_bandwidth_timestamp_ = max_bandwidth_timestamp; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h b/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h new file mode 100644 index 000000000000..b60412c5790d --- /dev/null +++ b/quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h @@ -0,0 +1,35 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_SUSTAINED_BANDWIDTH_RECORDER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_SUSTAINED_BANDWIDTH_RECORDER_PEER_H_ + +#include + +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +class QuicSustainedBandwidthRecorder; + +namespace test { + +class QuicSustainedBandwidthRecorderPeer { + public: + QuicSustainedBandwidthRecorderPeer() = delete; + + static void SetBandwidthEstimate( + QuicSustainedBandwidthRecorder* bandwidth_recorder, + int32_t bandwidth_estimate_kbytes_per_second); + + static void SetMaxBandwidthEstimate( + QuicSustainedBandwidthRecorder* bandwidth_recorder, + int32_t max_bandwidth_estimate_kbytes_per_second, + int32_t max_bandwidth_timestamp); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_SUSTAINED_BANDWIDTH_RECORDER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_test_backend.cc b/quiche/quic/test_tools/quic_test_backend.cc new file mode 100644 index 000000000000..983db8e44813 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_backend.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_test_backend.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/test_tools/web_transport_resets_backend.h" +#include "quiche/quic/tools/web_transport_test_visitors.h" + +namespace quic { +namespace test { + +namespace { + +// SessionCloseVisitor implements the "/session-close" endpoint. If the client +// sends a unidirectional stream of format "code message" to this endpoint, it +// will close the session with the corresponding error code and error message. +// For instance, sending "42 test error" will cause it to be closed with code 42 +// and message "test error". +class SessionCloseVisitor : public WebTransportVisitor { + public: + SessionCloseVisitor(WebTransportSession* session) : session_(session) {} + + void OnSessionReady(const spdy::Http2HeaderBlock& /*headers*/) override {} + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + + void OnIncomingBidirectionalStreamAvailable() override {} + void OnIncomingUnidirectionalStreamAvailable() override { + WebTransportStream* stream = session_->AcceptIncomingUnidirectionalStream(); + if (stream == nullptr) { + return; + } + stream->SetVisitor( + std::make_unique( + stream, [this](const std::string& data) { + std::pair parsed = + absl::StrSplit(data, absl::MaxSplits(' ', 1)); + WebTransportSessionError error_code = 0; + bool success = absl::SimpleAtoi(parsed.first, &error_code); + QUICHE_DCHECK(success) << data; + session_->CloseSession(error_code, parsed.second); + })); + stream->visitor()->OnCanRead(); + } + + void OnDatagramReceived(absl::string_view /*datagram*/) override {} + + void OnCanCreateNewOutgoingBidirectionalStream() override {} + void OnCanCreateNewOutgoingUnidirectionalStream() override {} + + private: + WebTransportSession* session_; // Not owned. +}; + +} // namespace + +QuicSimpleServerBackend::WebTransportResponse +QuicTestBackend::ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session) { + if (!SupportsWebTransport()) { + return QuicSimpleServerBackend::ProcessWebTransportRequest(request_headers, + session); + } + + auto path_it = request_headers.find(":path"); + if (path_it == request_headers.end()) { + WebTransportResponse response; + response.response_headers[":status"] = "400"; + return response; + } + absl::string_view path = path_it->second; + // Match any "/echo.*" pass, e.g. "/echo_foobar" + if (absl::StartsWith(path, "/echo")) { + WebTransportResponse response; + response.response_headers[":status"] = "200"; + // Add response headers if the paramer has "set-header=XXX:YYY" query. + GURL url = GURL(absl::StrCat("https://localhost", path)); + const std::vector& params = absl::StrSplit(url.query(), '&'); + for (const auto& param : params) { + absl::string_view param_view = param; + if (absl::ConsumePrefix(¶m_view, "set-header=")) { + const std::vector header_value = + absl::StrSplit(param_view, ':'); + if (header_value.size() == 2 && + !absl::StartsWith(header_value[0], ":")) { + response.response_headers[header_value[0]] = header_value[1]; + } + } + } + + response.visitor = + std::make_unique(session); + return response; + } + if (path == "/resets") { + return WebTransportResetsBackend(request_headers, session); + } + if (path == "/session-close") { + WebTransportResponse response; + response.response_headers[":status"] = "200"; + response.visitor = std::make_unique(session); + return response; + } + + WebTransportResponse response; + response.response_headers[":status"] = "404"; + return response; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_test_backend.h b/quiche/quic/test_tools/quic_test_backend.h new file mode 100644 index 000000000000..e59eb9124af4 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_backend.h @@ -0,0 +1,44 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_BACKEND_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_BACKEND_H_ + +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace test { + +// QuicTestBackend is a QuicSimpleServer backend usable in tests. It has extra +// WebTransport endpoints on top of what QuicMemoryCacheBackend already +// provides. +class QuicTestBackend : public QuicMemoryCacheBackend { + public: + WebTransportResponse ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session) override; + bool SupportsWebTransport() override { return enable_webtransport_; } + + void set_enable_webtransport(bool enable_webtransport) { + QUICHE_DCHECK(!enable_webtransport || enable_extended_connect_); + enable_webtransport_ = enable_webtransport; + } + + bool SupportsExtendedConnect() override { return enable_extended_connect_; } + + void set_enable_extended_connect(bool enable_extended_connect) { + enable_extended_connect_ = enable_extended_connect; + } + + private: + bool enable_webtransport_ = false; + bool enable_extended_connect_ = true; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_BACKEND_H_ diff --git a/quiche/quic/test_tools/quic_test_client.cc b/quiche/quic/test_tools/quic_test_client.cc new file mode 100644 index 000000000000..a617348754c9 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_client.cc @@ -0,0 +1,932 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_test_client.h" + +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "openssl/x509.h" +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_packet_writer_wrapper.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_stack_trace.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace test { +namespace { + +// RecordingProofVerifier accepts any certificate chain and records the common +// name of the leaf and then delegates the actual verification to an actual +// verifier. If no optional verifier is provided, then VerifyProof will return +// success. +class RecordingProofVerifier : public ProofVerifier { + public: + explicit RecordingProofVerifier(std::unique_ptr verifier) + : verifier_(std::move(verifier)) {} + + // ProofVerifier interface. + QuicAsyncStatus VerifyProof( + const std::string& hostname, const uint16_t port, + const std::string& server_config, QuicTransportVersion transport_version, + absl::string_view chlo_hash, const std::vector& certs, + const std::string& cert_sct, const std::string& signature, + const ProofVerifyContext* context, std::string* error_details, + std::unique_ptr* details, + std::unique_ptr callback) override { + QuicAsyncStatus process_certs_result = ProcessCerts(certs, cert_sct); + if (process_certs_result != QUIC_SUCCESS) { + return process_certs_result; + } + + if (!verifier_) { + return QUIC_SUCCESS; + } + + return verifier_->VerifyProof(hostname, port, server_config, + transport_version, chlo_hash, certs, cert_sct, + signature, context, error_details, details, + std::move(callback)); + } + + QuicAsyncStatus VerifyCertChain( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::vector& certs, + const std::string& /*ocsp_response*/, const std::string& cert_sct, + const ProofVerifyContext* /*context*/, std::string* /*error_details*/, + std::unique_ptr* /*details*/, uint8_t* /*out_alert*/, + std::unique_ptr /*callback*/) override { + return ProcessCerts(certs, cert_sct); + } + + std::unique_ptr CreateDefaultContext() override { + return verifier_ != nullptr ? verifier_->CreateDefaultContext() : nullptr; + } + + const std::string& common_name() const { return common_name_; } + + const std::string& cert_sct() const { return cert_sct_; } + + private: + QuicAsyncStatus ProcessCerts(const std::vector& certs, + const std::string& cert_sct) { + common_name_.clear(); + if (certs.empty()) { + return QUIC_FAILURE; + } + + // Parse the cert into an X509 structure. + const uint8_t* data; + data = reinterpret_cast(certs[0].data()); + bssl::UniquePtr cert(d2i_X509(nullptr, &data, certs[0].size())); + if (!cert.get()) { + return QUIC_FAILURE; + } + + // Extract the CN field + X509_NAME* subject = X509_get_subject_name(cert.get()); + const int index = X509_NAME_get_index_by_NID(subject, NID_commonName, -1); + if (index < 0) { + return QUIC_FAILURE; + } + ASN1_STRING* name_data = + X509_NAME_ENTRY_get_data(X509_NAME_get_entry(subject, index)); + if (name_data == nullptr) { + return QUIC_FAILURE; + } + + // Convert the CN to UTF8, in case the cert represents it in a different + // format. + unsigned char* buf = nullptr; + const int len = ASN1_STRING_to_UTF8(&buf, name_data); + if (len <= 0) { + return QUIC_FAILURE; + } + bssl::UniquePtr deleter(buf); + + common_name_.assign(reinterpret_cast(buf), len); + cert_sct_ = cert_sct; + return QUIC_SUCCESS; + } + + std::unique_ptr verifier_; + std::string common_name_; + std::string cert_sct_; +}; +} // namespace + +class MockableQuicClientDefaultNetworkHelper + : public QuicClientDefaultNetworkHelper { + public: + using QuicClientDefaultNetworkHelper::QuicClientDefaultNetworkHelper; + ~MockableQuicClientDefaultNetworkHelper() override = default; + + void ProcessPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) override { + QuicClientDefaultNetworkHelper::ProcessPacket(self_address, peer_address, + packet); + if (track_last_incoming_packet_) { + last_incoming_packet_ = packet.Clone(); + } + } + + QuicPacketWriter* CreateQuicPacketWriter() override { + QuicPacketWriter* writer = + QuicClientDefaultNetworkHelper::CreateQuicPacketWriter(); + if (!test_writer_) { + return writer; + } + test_writer_->set_writer(writer); + return test_writer_; + } + + const QuicReceivedPacket* last_incoming_packet() { + return last_incoming_packet_.get(); + } + + void set_track_last_incoming_packet(bool track) { + track_last_incoming_packet_ = track; + } + + void UseWriter(QuicPacketWriterWrapper* writer) { + QUICHE_CHECK(test_writer_ == nullptr); + test_writer_ = writer; + } + + void set_peer_address(const QuicSocketAddress& address) { + QUICHE_CHECK(test_writer_ != nullptr); + test_writer_->set_peer_address(address); + } + + private: + QuicPacketWriterWrapper* test_writer_ = nullptr; + // The last incoming packet, iff |track_last_incoming_packet_| is true. + std::unique_ptr last_incoming_packet_; + // If true, copy each packet from ProcessPacket into |last_incoming_packet_| + bool track_last_incoming_packet_ = false; +}; + +MockableQuicClient::MockableQuicClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop) + : MockableQuicClient(server_address, server_id, QuicConfig(), + supported_versions, event_loop) {} + +MockableQuicClient::MockableQuicClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop) + : MockableQuicClient(server_address, server_id, config, supported_versions, + event_loop, nullptr) {} + +MockableQuicClient::MockableQuicClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier) + : MockableQuicClient(server_address, server_id, config, supported_versions, + event_loop, std::move(proof_verifier), nullptr) {} + +MockableQuicClient::MockableQuicClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : QuicDefaultClient( + server_address, server_id, supported_versions, config, event_loop, + std::make_unique(event_loop, + this), + std::make_unique(std::move(proof_verifier)), + std::move(session_cache)), + override_client_connection_id_(EmptyQuicConnectionId()), + client_connection_id_overridden_(false) {} + +MockableQuicClient::~MockableQuicClient() { + if (connected()) { + Disconnect(); + } +} + +MockableQuicClientDefaultNetworkHelper* +MockableQuicClient::mockable_network_helper() { + return static_cast( + default_network_helper()); +} + +const MockableQuicClientDefaultNetworkHelper* +MockableQuicClient::mockable_network_helper() const { + return static_cast( + default_network_helper()); +} + +QuicConnectionId MockableQuicClient::GetClientConnectionId() { + if (client_connection_id_overridden_) { + return override_client_connection_id_; + } + if (override_client_connection_id_length_ >= 0) { + return QuicUtils::CreateRandomConnectionId( + override_client_connection_id_length_); + } + return QuicDefaultClient::GetClientConnectionId(); +} + +void MockableQuicClient::UseClientConnectionId( + QuicConnectionId client_connection_id) { + client_connection_id_overridden_ = true; + override_client_connection_id_ = client_connection_id; +} + +void MockableQuicClient::UseClientConnectionIdLength( + int client_connection_id_length) { + override_client_connection_id_length_ = client_connection_id_length; +} + +void MockableQuicClient::UseWriter(QuicPacketWriterWrapper* writer) { + mockable_network_helper()->UseWriter(writer); +} + +void MockableQuicClient::set_peer_address(const QuicSocketAddress& address) { + mockable_network_helper()->set_peer_address(address); + if (client_session() != nullptr) { + client_session()->connection()->AddKnownServerAddress(address); + } +} + +const QuicReceivedPacket* MockableQuicClient::last_incoming_packet() { + return mockable_network_helper()->last_incoming_packet(); +} + +void MockableQuicClient::set_track_last_incoming_packet(bool track) { + mockable_network_helper()->set_track_last_incoming_packet(track); +} + +QuicTestClient::QuicTestClient( + QuicSocketAddress server_address, const std::string& server_hostname, + const ParsedQuicVersionVector& supported_versions) + : QuicTestClient(server_address, server_hostname, QuicConfig(), + supported_versions) {} + +QuicTestClient::QuicTestClient( + QuicSocketAddress server_address, const std::string& server_hostname, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions) + : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())), + client_(new MockableQuicClient( + server_address, + QuicServerId(server_hostname, server_address.port(), false), config, + supported_versions, event_loop_.get())) { + Initialize(); +} + +QuicTestClient::QuicTestClient( + QuicSocketAddress server_address, const std::string& server_hostname, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + std::unique_ptr proof_verifier) + : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())), + client_(new MockableQuicClient( + server_address, + QuicServerId(server_hostname, server_address.port(), false), config, + supported_versions, event_loop_.get(), std::move(proof_verifier))) { + Initialize(); +} + +QuicTestClient::QuicTestClient( + QuicSocketAddress server_address, const std::string& server_hostname, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())), + client_(new MockableQuicClient( + server_address, + QuicServerId(server_hostname, server_address.port(), false), config, + supported_versions, event_loop_.get(), std::move(proof_verifier), + std::move(session_cache))) { + Initialize(); +} + +QuicTestClient::QuicTestClient( + QuicSocketAddress server_address, const std::string& server_hostname, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache, + std::unique_ptr event_loop) + : event_loop_(std::move(event_loop)), + client_(new MockableQuicClient( + server_address, + QuicServerId(server_hostname, server_address.port(), false), config, + supported_versions, event_loop_.get(), std::move(proof_verifier), + std::move(session_cache))) { + Initialize(); +} + +QuicTestClient::QuicTestClient() = default; + +QuicTestClient::~QuicTestClient() { + for (std::pair stream : open_streams_) { + stream.second->set_visitor(nullptr); + } +} + +void QuicTestClient::Initialize() { + priority_ = 3; + connect_attempted_ = false; + auto_reconnect_ = false; + buffer_body_ = true; + num_requests_ = 0; + num_responses_ = 0; + ClearPerConnectionState(); + // As chrome will generally do this, we want it to be the default when it's + // not overridden. + if (!client_->config()->HasSetBytesForConnectionIdToSend()) { + client_->config()->SetBytesForConnectionIdToSend(0); + } +} + +void QuicTestClient::SetUserAgentID(const std::string& user_agent_id) { + client_->SetUserAgentID(user_agent_id); +} + +ssize_t QuicTestClient::SendRequest(const std::string& uri) { + spdy::Http2HeaderBlock headers; + if (!PopulateHeaderBlockFromUrl(uri, &headers)) { + return 0; + } + return SendMessage(headers, ""); +} + +ssize_t QuicTestClient::SendRequestAndRstTogether(const std::string& uri) { + spdy::Http2HeaderBlock headers; + if (!PopulateHeaderBlockFromUrl(uri, &headers)) { + return 0; + } + + QuicSpdyClientSession* session = client()->client_session(); + QuicConnection::ScopedPacketFlusher flusher(session->connection()); + ssize_t ret = SendMessage(headers, "", /*fin=*/true, /*flush=*/false); + + QuicStreamId stream_id = GetNthClientInitiatedBidirectionalStreamId( + session->transport_version(), 0); + session->ResetStream(stream_id, QUIC_STREAM_CANCELLED); + return ret; +} + +void QuicTestClient::SendRequestsAndWaitForResponses( + const std::vector& url_list) { + for (const std::string& url : url_list) { + SendRequest(url); + } + while (client()->WaitForEvents()) { + } +} + +ssize_t QuicTestClient::GetOrCreateStreamAndSendRequest( + const spdy::Http2HeaderBlock* headers, absl::string_view body, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener) { + if (headers) { + QuicClientPushPromiseIndex::TryHandle* handle; + QuicAsyncStatus rv = + client()->push_promise_index()->Try(*headers, this, &handle); + if (rv == QUIC_SUCCESS) return 1; + if (rv == QUIC_PENDING) { + // May need to retry request if asynchronous rendezvous fails. + std::unique_ptr new_headers( + new spdy::Http2HeaderBlock(headers->Clone())); + push_promise_data_to_resend_ = std::make_unique( + std::move(new_headers), body, fin, this, std::move(ack_listener)); + return 1; + } + } + + // Maybe it's better just to overload this. it's just that we need + // for the GetOrCreateStream function to call something else...which + // is icky and complicated, but maybe not worse than this. + QuicSpdyClientStream* stream = GetOrCreateStream(); + if (stream == nullptr) { + return 0; + } + QuicSpdyStreamPeer::set_ack_listener(stream, ack_listener); + + ssize_t ret = 0; + if (headers != nullptr) { + spdy::Http2HeaderBlock spdy_headers(headers->Clone()); + if (spdy_headers[":authority"].as_string().empty()) { + spdy_headers[":authority"] = client_->server_id().host(); + } + ret = stream->SendRequest(std::move(spdy_headers), body, fin); + ++num_requests_; + } else { + stream->WriteOrBufferBody(std::string(body), fin); + ret = body.length(); + } + return ret; +} + +ssize_t QuicTestClient::SendMessage(const spdy::Http2HeaderBlock& headers, + absl::string_view body) { + return SendMessage(headers, body, /*fin=*/true); +} + +ssize_t QuicTestClient::SendMessage(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin) { + return SendMessage(headers, body, fin, /*flush=*/true); +} + +ssize_t QuicTestClient::SendMessage(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin, + bool flush) { + // Always force creation of a stream for SendMessage. + latest_created_stream_ = nullptr; + + ssize_t ret = GetOrCreateStreamAndSendRequest(&headers, body, fin, nullptr); + + if (flush) { + WaitForWriteToFlush(); + } + return ret; +} + +ssize_t QuicTestClient::SendData(const std::string& data, bool last_data) { + return SendData(data, last_data, nullptr); +} + +ssize_t QuicTestClient::SendData( + const std::string& data, bool last_data, + quiche::QuicheReferenceCountedPointer + ack_listener) { + return GetOrCreateStreamAndSendRequest(nullptr, absl::string_view(data), + last_data, std::move(ack_listener)); +} + +bool QuicTestClient::response_complete() const { return response_complete_; } + +int64_t QuicTestClient::response_body_size() const { + return response_body_size_; +} + +bool QuicTestClient::buffer_body() const { return buffer_body_; } + +void QuicTestClient::set_buffer_body(bool buffer_body) { + buffer_body_ = buffer_body; +} + +const std::string& QuicTestClient::response_body() const { return response_; } + +std::string QuicTestClient::SendCustomSynchronousRequest( + const spdy::Http2HeaderBlock& headers, const std::string& body) { + // Clear connection state here and only track this synchronous request. + ClearPerConnectionState(); + if (SendMessage(headers, body) == 0) { + QUIC_DLOG(ERROR) << "Failed the request for: " << headers.DebugString(); + // Set the response_ explicitly. Otherwise response_ will contain the + // response from the previously successful request. + response_ = ""; + } else { + WaitForResponse(); + } + return response_; +} + +std::string QuicTestClient::SendSynchronousRequest(const std::string& uri) { + spdy::Http2HeaderBlock headers; + if (!PopulateHeaderBlockFromUrl(uri, &headers)) { + return ""; + } + return SendCustomSynchronousRequest(headers, ""); +} + +void QuicTestClient::SendConnectivityProbing() { + QuicConnection* connection = client()->client_session()->connection(); + connection->SendConnectivityProbingPacket(connection->writer(), + connection->peer_address()); +} + +void QuicTestClient::SetLatestCreatedStream(QuicSpdyClientStream* stream) { + latest_created_stream_ = stream; + if (latest_created_stream_ != nullptr) { + open_streams_[stream->id()] = stream; + stream->set_visitor(this); + } +} + +QuicSpdyClientStream* QuicTestClient::GetOrCreateStream() { + if (!connect_attempted_ || auto_reconnect_) { + if (!connected()) { + Connect(); + } + if (!connected()) { + return nullptr; + } + } + if (open_streams_.empty()) { + ClearPerConnectionState(); + } + if (!latest_created_stream_) { + SetLatestCreatedStream(client_->CreateClientStream()); + if (latest_created_stream_) { + latest_created_stream_->SetPriority(QuicStreamPriority( + HttpStreamPriority{priority_, /* incremental = */ false})); + } + } + + return latest_created_stream_; +} + +QuicErrorCode QuicTestClient::connection_error() const { + return client()->connection_error(); +} + +const std::string& QuicTestClient::cert_common_name() const { + return reinterpret_cast(client_->proof_verifier()) + ->common_name(); +} + +const std::string& QuicTestClient::cert_sct() const { + return reinterpret_cast(client_->proof_verifier()) + ->cert_sct(); +} + +const QuicTagValueMap& QuicTestClient::GetServerConfig() const { + QuicCryptoClientConfig* config = client_->crypto_config(); + const QuicCryptoClientConfig::CachedState* state = + config->LookupOrCreate(client_->server_id()); + const CryptoHandshakeMessage* handshake_msg = state->GetServerConfig(); + return handshake_msg->tag_value_map(); +} + +bool QuicTestClient::connected() const { return client_->connected(); } + +void QuicTestClient::Connect() { + if (connected()) { + QUIC_BUG(quic_bug_10133_1) << "Cannot connect already-connected client"; + return; + } + if (!connect_attempted_) { + client_->Initialize(); + } + + // If we've been asked to override SNI, set it now + if (override_sni_set_) { + client_->set_server_id( + QuicServerId(override_sni_, address().port(), false)); + } + + client_->Connect(); + connect_attempted_ = true; +} + +void QuicTestClient::ResetConnection() { + Disconnect(); + Connect(); +} + +void QuicTestClient::Disconnect() { + ClearPerConnectionState(); + if (client_->initialized()) { + client_->Disconnect(); + } + connect_attempted_ = false; +} + +QuicSocketAddress QuicTestClient::local_address() const { + return client_->network_helper()->GetLatestClientAddress(); +} + +void QuicTestClient::ClearPerRequestState() { + stream_error_ = QUIC_STREAM_NO_ERROR; + response_ = ""; + response_complete_ = false; + response_headers_complete_ = false; + preliminary_headers_.clear(); + response_headers_.clear(); + response_trailers_.clear(); + bytes_read_ = 0; + bytes_written_ = 0; + response_body_size_ = 0; +} + +bool QuicTestClient::HaveActiveStream() { + return push_promise_data_to_resend_.get() || !open_streams_.empty(); +} + +bool QuicTestClient::WaitUntil(int timeout_ms, std::function trigger) { + QuicTime::Delta timeout = QuicTime::Delta::FromMilliseconds(timeout_ms); + const QuicClock* clock = client()->session()->connection()->clock(); + QuicTime end_waiting_time = clock->Now() + timeout; + while (connected() && !(trigger && trigger()) && + (timeout_ms < 0 || clock->Now() < end_waiting_time)) { + event_loop_->RunEventLoopOnce(timeout); + client_->WaitForEventsPostprocessing(); + } + ReadNextResponse(); + if (trigger && !trigger()) { + QUIC_VLOG(1) << "Client WaitUntil returning with trigger returning false."; + return false; + } + return true; +} + +ssize_t QuicTestClient::Send(absl::string_view data) { + return SendData(std::string(data), false); +} + +bool QuicTestClient::response_headers_complete() const { + for (std::pair stream : open_streams_) { + if (stream.second->headers_decompressed()) { + return true; + } + } + return response_headers_complete_; +} + +const spdy::Http2HeaderBlock* QuicTestClient::response_headers() const { + for (std::pair stream : open_streams_) { + if (stream.second->headers_decompressed()) { + response_headers_ = stream.second->response_headers().Clone(); + break; + } + } + return &response_headers_; +} + +const spdy::Http2HeaderBlock* QuicTestClient::preliminary_headers() const { + for (std::pair stream : open_streams_) { + size_t bytes_read = + stream.second->stream_bytes_read() + stream.second->header_bytes_read(); + if (bytes_read > 0) { + preliminary_headers_ = stream.second->preliminary_headers().Clone(); + break; + } + } + return &preliminary_headers_; +} + +const spdy::Http2HeaderBlock& QuicTestClient::response_trailers() const { + return response_trailers_; +} + +int64_t QuicTestClient::response_size() const { return bytes_read(); } + +size_t QuicTestClient::bytes_read() const { + for (std::pair stream : open_streams_) { + size_t bytes_read = stream.second->total_body_bytes_read() + + stream.second->header_bytes_read(); + if (bytes_read > 0) { + return bytes_read; + } + } + return bytes_read_; +} + +size_t QuicTestClient::bytes_written() const { + for (std::pair stream : open_streams_) { + size_t bytes_written = stream.second->stream_bytes_written() + + stream.second->header_bytes_written(); + if (bytes_written > 0) { + return bytes_written; + } + } + return bytes_written_; +} + +absl::string_view QuicTestClient::partial_response_body() const { + return latest_created_stream_ == nullptr ? "" + : latest_created_stream_->data(); +} + +void QuicTestClient::OnClose(QuicSpdyStream* stream) { + if (stream == nullptr) { + return; + } + // Always close the stream, regardless of whether it was the last stream + // written. + client()->OnClose(stream); + ++num_responses_; + if (open_streams_.find(stream->id()) == open_streams_.end()) { + return; + } + if (latest_created_stream_ == stream) { + latest_created_stream_ = nullptr; + } + QuicSpdyClientStream* client_stream = + static_cast(stream); + QuicStreamId id = client_stream->id(); + closed_stream_states_.insert(std::make_pair( + id, + PerStreamState( + // Set response_complete to true iff stream is closed while connected. + client_stream->stream_error(), connected(), + client_stream->headers_decompressed(), + client_stream->response_headers(), + client_stream->preliminary_headers(), + (buffer_body() ? std::string(client_stream->data()) : ""), + client_stream->received_trailers(), + // Use NumBytesConsumed to avoid counting retransmitted stream frames. + client_stream->total_body_bytes_read() + + client_stream->header_bytes_read(), + client_stream->stream_bytes_written() + + client_stream->header_bytes_written(), + client_stream->data().size()))); + open_streams_.erase(id); +} + +bool QuicTestClient::CheckVary( + const spdy::Http2HeaderBlock& /*client_request*/, + const spdy::Http2HeaderBlock& /*promise_request*/, + const spdy::Http2HeaderBlock& /*promise_response*/) { + return true; +} + +void QuicTestClient::OnRendezvousResult(QuicSpdyStream* stream) { + std::unique_ptr data_to_resend = + std::move(push_promise_data_to_resend_); + SetLatestCreatedStream(static_cast(stream)); + if (stream) { + stream->OnBodyAvailable(); + } else if (data_to_resend) { + data_to_resend->Resend(); + } +} + +void QuicTestClient::UseWriter(QuicPacketWriterWrapper* writer) { + client_->UseWriter(writer); +} + +void QuicTestClient::UseConnectionId(QuicConnectionId server_connection_id) { + QUICHE_DCHECK(!connected()); + client_->set_server_connection_id_override(server_connection_id); +} + +void QuicTestClient::UseConnectionIdLength( + uint8_t server_connection_id_length) { + QUICHE_DCHECK(!connected()); + client_->set_server_connection_id_length(server_connection_id_length); +} + +void QuicTestClient::UseClientConnectionId( + QuicConnectionId client_connection_id) { + QUICHE_DCHECK(!connected()); + client_->UseClientConnectionId(client_connection_id); +} + +void QuicTestClient::UseClientConnectionIdLength( + uint8_t client_connection_id_length) { + QUICHE_DCHECK(!connected()); + client_->UseClientConnectionIdLength(client_connection_id_length); +} + +bool QuicTestClient::MigrateSocket(const QuicIpAddress& new_host) { + return client_->MigrateSocket(new_host); +} + +bool QuicTestClient::MigrateSocketWithSpecifiedPort( + const QuicIpAddress& new_host, int port) { + client_->set_local_port(port); + return client_->MigrateSocket(new_host); +} + +QuicIpAddress QuicTestClient::bind_to_address() const { + return client_->bind_to_address(); +} + +void QuicTestClient::set_bind_to_address(QuicIpAddress address) { + client_->set_bind_to_address(address); +} + +const QuicSocketAddress& QuicTestClient::address() const { + return client_->server_address(); +} + +void QuicTestClient::WaitForWriteToFlush() { + while (connected() && client()->session()->HasDataToWrite()) { + client_->WaitForEvents(); + } +} + +QuicTestClient::TestClientDataToResend::TestClientDataToResend( + std::unique_ptr headers, absl::string_view body, + bool fin, QuicTestClient* test_client, + quiche::QuicheReferenceCountedPointer + ack_listener) + : QuicDefaultClient::QuicDataToResend(std::move(headers), body, fin), + test_client_(test_client), + ack_listener_(std::move(ack_listener)) {} + +QuicTestClient::TestClientDataToResend::~TestClientDataToResend() = default; + +void QuicTestClient::TestClientDataToResend::Resend() { + test_client_->GetOrCreateStreamAndSendRequest(headers_.get(), body_, fin_, + ack_listener_); + headers_.reset(); +} + +QuicTestClient::PerStreamState::PerStreamState(const PerStreamState& other) + : stream_error(other.stream_error), + response_complete(other.response_complete), + response_headers_complete(other.response_headers_complete), + response_headers(other.response_headers.Clone()), + preliminary_headers(other.preliminary_headers.Clone()), + response(other.response), + response_trailers(other.response_trailers.Clone()), + bytes_read(other.bytes_read), + bytes_written(other.bytes_written), + response_body_size(other.response_body_size) {} + +QuicTestClient::PerStreamState::PerStreamState( + QuicRstStreamErrorCode stream_error, bool response_complete, + bool response_headers_complete, + const spdy::Http2HeaderBlock& response_headers, + const spdy::Http2HeaderBlock& preliminary_headers, + const std::string& response, + const spdy::Http2HeaderBlock& response_trailers, uint64_t bytes_read, + uint64_t bytes_written, int64_t response_body_size) + : stream_error(stream_error), + response_complete(response_complete), + response_headers_complete(response_headers_complete), + response_headers(response_headers.Clone()), + preliminary_headers(preliminary_headers.Clone()), + response(response), + response_trailers(response_trailers.Clone()), + bytes_read(bytes_read), + bytes_written(bytes_written), + response_body_size(response_body_size) {} + +QuicTestClient::PerStreamState::~PerStreamState() = default; + +bool QuicTestClient::PopulateHeaderBlockFromUrl( + const std::string& uri, spdy::Http2HeaderBlock* headers) { + std::string url; + if (absl::StartsWith(uri, "https://") || absl::StartsWith(uri, "http://")) { + url = uri; + } else if (uri[0] == '/') { + url = "https://" + client_->server_id().host() + uri; + } else { + url = "https://" + uri; + } + return SpdyUtils::PopulateHeaderBlockFromUrl(url, headers); +} + +void QuicTestClient::ReadNextResponse() { + if (closed_stream_states_.empty()) { + return; + } + + PerStreamState state(closed_stream_states_.front().second); + + stream_error_ = state.stream_error; + response_ = state.response; + response_complete_ = state.response_complete; + response_headers_complete_ = state.response_headers_complete; + preliminary_headers_ = state.preliminary_headers.Clone(); + response_headers_ = state.response_headers.Clone(); + response_trailers_ = state.response_trailers.Clone(); + bytes_read_ = state.bytes_read; + bytes_written_ = state.bytes_written; + response_body_size_ = state.response_body_size; + + closed_stream_states_.pop_front(); +} + +void QuicTestClient::ClearPerConnectionState() { + ClearPerRequestState(); + open_streams_.clear(); + closed_stream_states_.clear(); + latest_created_stream_ = nullptr; +} + +void QuicTestClient::WaitForDelayedAcks() { + // kWaitDuration is a period of time that is long enough for all delayed + // acks to be sent and received on the other end. + const QuicTime::Delta kWaitDuration = + 4 * QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs); + + const QuicClock* clock = client()->client_session()->connection()->clock(); + + QuicTime wait_until = clock->ApproximateNow() + kWaitDuration; + while (connected() && clock->ApproximateNow() < wait_until) { + // This waits for up to 50 ms. + client()->WaitForEvents(); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_test_client.h b/quiche/quic/test_tools/quic_test_client.h new file mode 100644 index 000000000000..9abab23e1201 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_client.h @@ -0,0 +1,444 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_CLIENT_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_CLIENT_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/tools/quic_default_client.h" +#include "quiche/common/quiche_linked_hash_map.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +class ProofVerifier; +class QuicPacketWriterWrapper; + +namespace test { + +class MockableQuicClientDefaultNetworkHelper; + +// A quic client which allows mocking out reads and writes. +class MockableQuicClient : public QuicDefaultClient { + public: + MockableQuicClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop); + + MockableQuicClient(QuicSocketAddress server_address, + const QuicServerId& server_id, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop); + + MockableQuicClient(QuicSocketAddress server_address, + const QuicServerId& server_id, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier); + + MockableQuicClient(QuicSocketAddress server_address, + const QuicServerId& server_id, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + MockableQuicClient(const MockableQuicClient&) = delete; + MockableQuicClient& operator=(const MockableQuicClient&) = delete; + + ~MockableQuicClient() override; + + QuicConnectionId GetClientConnectionId() override; + void UseClientConnectionId(QuicConnectionId client_connection_id); + void UseClientConnectionIdLength(int client_connection_id_length); + + void UseWriter(QuicPacketWriterWrapper* writer); + void set_peer_address(const QuicSocketAddress& address); + // The last incoming packet, iff |track_last_incoming_packet| is true. + const QuicReceivedPacket* last_incoming_packet(); + // If true, copy each packet from ProcessPacket into |last_incoming_packet| + void set_track_last_incoming_packet(bool track); + + // Casts the network helper to a MockableQuicClientDefaultNetworkHelper. + MockableQuicClientDefaultNetworkHelper* mockable_network_helper(); + const MockableQuicClientDefaultNetworkHelper* mockable_network_helper() const; + + private: + // Client connection ID to use, if client_connection_id_overridden_. + // TODO(wub): Move client_connection_id_(length_) overrides to QuicClientBase. + QuicConnectionId override_client_connection_id_; + bool client_connection_id_overridden_; + int override_client_connection_id_length_ = -1; + CachedNetworkParameters cached_network_paramaters_; +}; + +// A toy QUIC client used for testing. +class QuicTestClient : public QuicSpdyStream::Visitor, + public QuicClientPushPromiseIndex::Delegate { + public: + QuicTestClient(QuicSocketAddress server_address, + const std::string& server_hostname, + const ParsedQuicVersionVector& supported_versions); + QuicTestClient(QuicSocketAddress server_address, + const std::string& server_hostname, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions); + QuicTestClient(QuicSocketAddress server_address, + const std::string& server_hostname, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + std::unique_ptr proof_verifier); + QuicTestClient(QuicSocketAddress server_address, + const std::string& server_hostname, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + QuicTestClient(QuicSocketAddress server_address, + const std::string& server_hostname, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache, + std::unique_ptr event_loop); + + ~QuicTestClient() override; + + // Sets the |user_agent_id| of the |client_|. + void SetUserAgentID(const std::string& user_agent_id); + + // Wraps data in a quic packet and sends it. + ssize_t SendData(const std::string& data, bool last_data); + // As above, but |delegate| will be notified when |data| is ACKed. + ssize_t SendData( + const std::string& data, bool last_data, + quiche::QuicheReferenceCountedPointer + ack_listener); + + // Clears any outstanding state and sends a simple GET of 'uri' to the + // server. Returns 0 if the request failed and no bytes were written. + ssize_t SendRequest(const std::string& uri); + // Send a request R and a RST_FRAME which resets R, in the same packet. + ssize_t SendRequestAndRstTogether(const std::string& uri); + // Sends requests for all the urls and waits for the responses. To process + // the individual responses as they are returned, the caller should use the + // set the response_listener on the client(). + void SendRequestsAndWaitForResponses( + const std::vector& url_list); + // Sends a request containing |headers| and |body| and returns the number of + // bytes sent (the size of the serialized request headers and body). + ssize_t SendMessage(const spdy::Http2HeaderBlock& headers, + absl::string_view body); + // Sends a request containing |headers| and |body| with the fin bit set to + // |fin| and returns the number of bytes sent (the size of the serialized + // request headers and body). + ssize_t SendMessage(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin); + // Sends a request containing |headers| and |body| with the fin bit set to + // |fin| and returns the number of bytes sent (the size of the serialized + // request headers and body). If |flush| is true, will wait for the message to + // be flushed before returning. + ssize_t SendMessage(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin, bool flush); + // Sends a request containing |headers| and |body|, waits for the response, + // and returns the response body. + std::string SendCustomSynchronousRequest( + const spdy::Http2HeaderBlock& headers, const std::string& body); + // Sends a GET request for |uri|, waits for the response, and returns the + // response body. + std::string SendSynchronousRequest(const std::string& uri); + void SendConnectivityProbing(); + void Connect(); + void ResetConnection(); + void Disconnect(); + QuicSocketAddress local_address() const; + void ClearPerRequestState(); + bool WaitUntil(int timeout_ms, std::function trigger); + ssize_t Send(absl::string_view data); + bool connected() const; + bool buffer_body() const; + void set_buffer_body(bool buffer_body); + + // Getters for stream state that only get updated once a complete response is + // received. + const spdy::Http2HeaderBlock& response_trailers() const; + bool response_complete() const; + int64_t response_body_size() const; + const std::string& response_body() const; + // Getters for stream state that return state of the oldest active stream that + // have received a partial response. + bool response_headers_complete() const; + const spdy::Http2HeaderBlock* response_headers() const; + const spdy::Http2HeaderBlock* preliminary_headers() const; + int64_t response_size() const; + size_t bytes_read() const; + size_t bytes_written() const; + + // Returns response body received so far by the stream that has been most + // recently opened among currently open streams. To query response body + // received by a stream that is already closed, use `response_body()` instead. + absl::string_view partial_response_body() const; + + // Returns once at least one complete response or a connection close has been + // received from the server. If responses are received for multiple (say 2) + // streams, next WaitForResponse will return immediately. + void WaitForResponse() { WaitForResponseForMs(-1); } + + // Returns once some data is received on any open streams or at least one + // complete response is received from the server. + void WaitForInitialResponse() { WaitForInitialResponseForMs(-1); } + + // Returns once at least one complete response or a connection close has been + // received from the server, or once the timeout expires. + // Passing in a timeout value of -1 disables the timeout. If multiple + // responses are received while the client is waiting, subsequent calls to + // this function will return immediately. + void WaitForResponseForMs(int timeout_ms) { + WaitUntil(timeout_ms, [this]() { + return !HaveActiveStream() || !closed_stream_states_.empty(); + }); + if (response_complete()) { + QUIC_VLOG(1) << "Client received response:" + << response_headers()->DebugString() << response_body(); + } + } + + // Returns once some data is received on any open streams or at least one + // complete response is received from the server, or once the timeout + // expires. -1 means no timeout. + void WaitForInitialResponseForMs(int timeout_ms) { + WaitUntil(timeout_ms, + [this]() { return !HaveActiveStream() || response_size() != 0; }); + } + + // Migrate local address to <|new_host|, a random port>. + // Return whether the migration succeeded. + bool MigrateSocket(const QuicIpAddress& new_host); + // Migrate local address to <|new_host|, |port|>. + // Return whether the migration succeeded. + bool MigrateSocketWithSpecifiedPort(const QuicIpAddress& new_host, int port); + QuicIpAddress bind_to_address() const; + void set_bind_to_address(QuicIpAddress address); + const QuicSocketAddress& address() const; + + // From QuicSpdyStream::Visitor + void OnClose(QuicSpdyStream* stream) override; + + // From QuicClientPushPromiseIndex::Delegate + bool CheckVary(const spdy::Http2HeaderBlock& client_request, + const spdy::Http2HeaderBlock& promise_request, + const spdy::Http2HeaderBlock& promise_response) override; + void OnRendezvousResult(QuicSpdyStream*) override; + + // Configures client_ to take ownership of and use the writer. + // Must be called before initial connect. + void UseWriter(QuicPacketWriterWrapper* writer); + // Configures client_ to use a specific server connection ID instead of a + // random one. + void UseConnectionId(QuicConnectionId server_connection_id); + // Configures client_ to use a specific server connection ID length instead + // of the default of kQuicDefaultConnectionIdLength. + void UseConnectionIdLength(uint8_t server_connection_id_length); + // Configures client_ to use a specific client connection ID instead of an + // empty one. + void UseClientConnectionId(QuicConnectionId client_connection_id); + // Configures client_ to use a specific client connection ID length instead + // of the default of zero. + void UseClientConnectionIdLength(uint8_t client_connection_id_length); + + // Returns nullptr if the maximum number of streams have already been created. + QuicSpdyClientStream* GetOrCreateStream(); + + // Calls GetOrCreateStream(), sends the request on the stream, and + // stores the request in case it needs to be resent. If |headers| is + // null, only the body will be sent on the stream. + ssize_t GetOrCreateStreamAndSendRequest( + const spdy::Http2HeaderBlock* headers, absl::string_view body, bool fin, + quiche::QuicheReferenceCountedPointer + ack_listener); + + QuicRstStreamErrorCode stream_error() { return stream_error_; } + QuicErrorCode connection_error() const; + + MockableQuicClient* client() { return client_.get(); } + const MockableQuicClient* client() const { return client_.get(); } + + // cert_common_name returns the common name value of the server's certificate, + // or the empty std::string if no certificate was presented. + const std::string& cert_common_name() const; + + // cert_sct returns the signed timestamp of the server's certificate, + // or the empty std::string if no signed timestamp was presented. + const std::string& cert_sct() const; + + // Get the server config map. Server config must exist. + const QuicTagValueMap& GetServerConfig() const; + + void set_auto_reconnect(bool reconnect) { auto_reconnect_ = reconnect; } + + void set_priority(spdy::SpdyPriority priority) { priority_ = priority; } + + void WaitForWriteToFlush(); + + QuicEventLoop* event_loop() { return event_loop_.get(); } + + size_t num_requests() const { return num_requests_; } + + size_t num_responses() const { return num_responses_; } + + void set_server_address(const QuicSocketAddress& server_address) { + client_->set_server_address(server_address); + } + + void set_peer_address(const QuicSocketAddress& address) { + client_->set_peer_address(address); + } + + // Explicitly set the SNI value for this client, overriding the default + // behavior which extracts the SNI value from the request URL. + void OverrideSni(const std::string& sni) { + override_sni_set_ = true; + override_sni_ = sni; + } + + void Initialize(); + + void set_client(MockableQuicClient* client) { client_.reset(client); } + + // Given |uri|, populates the fields in |headers| for a simple GET + // request. If |uri| is a relative URL, the QuicServerId will be + // use to specify the authority. + bool PopulateHeaderBlockFromUrl(const std::string& uri, + spdy::Http2HeaderBlock* headers); + + // Waits for a period of time that is long enough to receive all delayed acks + // sent by peer. + void WaitForDelayedAcks(); + + QuicSpdyClientStream* latest_created_stream() { + return latest_created_stream_; + } + + protected: + QuicTestClient(); + QuicTestClient(const QuicTestClient&) = delete; + QuicTestClient(const QuicTestClient&&) = delete; + QuicTestClient& operator=(const QuicTestClient&) = delete; + QuicTestClient& operator=(const QuicTestClient&&) = delete; + + private: + class TestClientDataToResend : public QuicDefaultClient::QuicDataToResend { + public: + TestClientDataToResend( + std::unique_ptr headers, absl::string_view body, + bool fin, QuicTestClient* test_client, + quiche::QuicheReferenceCountedPointer + ack_listener); + + ~TestClientDataToResend() override; + + void Resend() override; + + protected: + QuicTestClient* test_client_; + quiche::QuicheReferenceCountedPointer + ack_listener_; + }; + + // PerStreamState of a stream is updated when it is closed. + struct PerStreamState { + PerStreamState(const PerStreamState& other); + PerStreamState(QuicRstStreamErrorCode stream_error, bool response_complete, + bool response_headers_complete, + const spdy::Http2HeaderBlock& response_headers, + const spdy::Http2HeaderBlock& preliminary_headers, + const std::string& response, + const spdy::Http2HeaderBlock& response_trailers, + uint64_t bytes_read, uint64_t bytes_written, + int64_t response_body_size); + ~PerStreamState(); + + QuicRstStreamErrorCode stream_error; + bool response_complete; + bool response_headers_complete; + spdy::Http2HeaderBlock response_headers; + spdy::Http2HeaderBlock preliminary_headers; + std::string response; + spdy::Http2HeaderBlock response_trailers; + uint64_t bytes_read; + uint64_t bytes_written; + int64_t response_body_size; + }; + + bool HaveActiveStream(); + + // Read oldest received response and remove it from closed_stream_states_. + void ReadNextResponse(); + + // Clear open_streams_, closed_stream_states_ and reset + // latest_created_stream_. + void ClearPerConnectionState(); + + // Update latest_created_stream_, add |stream| to open_streams_ and starts + // tracking its state. + void SetLatestCreatedStream(QuicSpdyClientStream* stream); + + std::unique_ptr event_loop_; + std::unique_ptr client_; // The actual client + QuicSpdyClientStream* latest_created_stream_; + std::map open_streams_; + // Received responses of closed streams. + quiche::QuicheLinkedHashMap + closed_stream_states_; + + QuicRstStreamErrorCode stream_error_; + + bool response_complete_; + bool response_headers_complete_; + mutable spdy::Http2HeaderBlock preliminary_headers_; + mutable spdy::Http2HeaderBlock response_headers_; + + // Parsed response trailers (if present), copied from the stream in OnClose. + spdy::Http2HeaderBlock response_trailers_; + + spdy::SpdyPriority priority_; + std::string response_; + // bytes_read_ and bytes_written_ are updated only when stream_ is released; + // prefer bytes_read() and bytes_written() member functions. + uint64_t bytes_read_; + uint64_t bytes_written_; + // The number of HTTP body bytes received. + int64_t response_body_size_; + // True if we tried to connect already since the last call to Disconnect(). + bool connect_attempted_; + // The client will auto-connect exactly once before sending data. If + // something causes a connection reset, it will not automatically reconnect + // unless auto_reconnect_ is true. + bool auto_reconnect_; + // Should we buffer the response body? Defaults to true. + bool buffer_body_; + // For async push promise rendezvous, validation may fail in which + // case the request should be retried. + std::unique_ptr push_promise_data_to_resend_; + // Number of requests/responses this client has sent/received. + size_t num_requests_; + size_t num_responses_; + + // If set, this value is used for the connection SNI, overriding the usual + // logic which extracts the SNI from the request URL. + bool override_sni_set_ = false; + std::string override_sni_; +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_CLIENT_H_ diff --git a/quiche/quic/test_tools/quic_test_server.cc b/quiche/quic/test_tools/quic_test_server.cc new file mode 100644 index 000000000000..ad97b4a70c3b --- /dev/null +++ b/quiche/quic/test_tools/quic_test_server.cc @@ -0,0 +1,260 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_test_server.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h" +#include "quiche/quic/tools/quic_simple_dispatcher.h" +#include "quiche/quic/tools/quic_simple_server_session.h" + +namespace quic { + +namespace test { + +class CustomStreamSession : public QuicSimpleServerSession { + public: + CustomStreamSession( + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicTestServer::StreamFactory* stream_factory, + QuicTestServer::CryptoStreamFactory* crypto_stream_factory, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSimpleServerSession(config, supported_versions, connection, visitor, + helper, crypto_config, compressed_certs_cache, + quic_simple_server_backend), + stream_factory_(stream_factory), + crypto_stream_factory_(crypto_stream_factory) {} + + QuicSpdyStream* CreateIncomingStream(QuicStreamId id) override { + if (!ShouldCreateIncomingStream(id)) { + return nullptr; + } + if (stream_factory_) { + QuicSpdyStream* stream = + stream_factory_->CreateStream(id, this, server_backend()); + ActivateStream(absl::WrapUnique(stream)); + return stream; + } + return QuicSimpleServerSession::CreateIncomingStream(id); + } + + std::unique_ptr CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) override { + if (crypto_stream_factory_) { + return crypto_stream_factory_->CreateCryptoStream(crypto_config, this); + } + return QuicSimpleServerSession::CreateQuicCryptoServerStream( + crypto_config, compressed_certs_cache); + } + + private: + QuicTestServer::StreamFactory* stream_factory_; // Not owned. + QuicTestServer::CryptoStreamFactory* crypto_stream_factory_; // Not owned. +}; + +class QuicTestDispatcher : public QuicSimpleDispatcher { + public: + QuicTestDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& generator) + : QuicSimpleDispatcher(config, crypto_config, version_manager, + std::move(helper), std::move(session_helper), + std::move(alarm_factory), + quic_simple_server_backend, + expected_server_connection_id_length, generator), + session_factory_(nullptr), + stream_factory_(nullptr), + crypto_stream_factory_(nullptr) {} + + std::unique_ptr CreateQuicSession( + QuicConnectionId id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, + const ParsedQuicVersion& version, + const ParsedClientHello& /*parsed_chlo*/) override { + QuicReaderMutexLock lock(&factory_lock_); + // The QuicServerSessionBase takes ownership of |connection| below. + QuicConnection* connection = new QuicConnection( + id, self_address, peer_address, helper(), alarm_factory(), writer(), + /* owns_writer= */ false, Perspective::IS_SERVER, + ParsedQuicVersionVector{version}, connection_id_generator()); + + std::unique_ptr session; + if (session_factory_ == nullptr && stream_factory_ == nullptr && + crypto_stream_factory_ == nullptr) { + session = std::make_unique( + config(), GetSupportedVersions(), connection, this, session_helper(), + crypto_config(), compressed_certs_cache(), server_backend()); + } else if (stream_factory_ != nullptr || + crypto_stream_factory_ != nullptr) { + session = std::make_unique( + config(), GetSupportedVersions(), connection, this, session_helper(), + crypto_config(), compressed_certs_cache(), stream_factory_, + crypto_stream_factory_, server_backend()); + } else { + session = session_factory_->CreateSession( + config(), connection, this, session_helper(), crypto_config(), + compressed_certs_cache(), server_backend()); + } + if (VersionUsesHttp3(version.transport_version) && + GetQuicReloadableFlag(quic_verify_request_headers_2)) { + QUICHE_DCHECK(session->allow_extended_connect()); + // Do not allow extended CONNECT request if the backend doesn't support + // it. + session->set_allow_extended_connect( + server_backend()->SupportsExtendedConnect()); + } + session->Initialize(); + return session; + } + + void SetSessionFactory(QuicTestServer::SessionFactory* factory) { + QuicWriterMutexLock lock(&factory_lock_); + QUICHE_DCHECK(session_factory_ == nullptr); + QUICHE_DCHECK(stream_factory_ == nullptr); + QUICHE_DCHECK(crypto_stream_factory_ == nullptr); + session_factory_ = factory; + } + + void SetStreamFactory(QuicTestServer::StreamFactory* factory) { + QuicWriterMutexLock lock(&factory_lock_); + QUICHE_DCHECK(session_factory_ == nullptr); + QUICHE_DCHECK(stream_factory_ == nullptr); + stream_factory_ = factory; + } + + void SetCryptoStreamFactory(QuicTestServer::CryptoStreamFactory* factory) { + QuicWriterMutexLock lock(&factory_lock_); + QUICHE_DCHECK(session_factory_ == nullptr); + QUICHE_DCHECK(crypto_stream_factory_ == nullptr); + crypto_stream_factory_ = factory; + } + + private: + QuicMutex factory_lock_; + QuicTestServer::SessionFactory* session_factory_; // Not owned. + QuicTestServer::StreamFactory* stream_factory_; // Not owned. + QuicTestServer::CryptoStreamFactory* crypto_stream_factory_; // Not owned. +}; + +QuicTestServer::QuicTestServer( + std::unique_ptr proof_source, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicServer(std::move(proof_source), quic_simple_server_backend) {} + +QuicTestServer::QuicTestServer( + std::unique_ptr proof_source, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicTestServer(std::move(proof_source), config, supported_versions, + quic_simple_server_backend, + kQuicDefaultConnectionIdLength) {} + +QuicTestServer::QuicTestServer( + std::unique_ptr proof_source, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length) + : QuicServer(std::move(proof_source), config, + QuicCryptoServerConfig::ConfigOptions(), supported_versions, + quic_simple_server_backend, + expected_server_connection_id_length) {} + +QuicDispatcher* QuicTestServer::CreateQuicDispatcher() { + return new QuicTestDispatcher( + &config(), &crypto_config(), version_manager(), + std::make_unique(), + std::unique_ptr( + new QuicSimpleCryptoServerStreamHelper()), + event_loop()->CreateAlarmFactory(), server_backend(), + expected_server_connection_id_length(), connection_id_generator()); +} + +void QuicTestServer::SetSessionFactory(SessionFactory* factory) { + QUICHE_DCHECK(dispatcher()); + static_cast(dispatcher())->SetSessionFactory(factory); +} + +void QuicTestServer::SetSpdyStreamFactory(StreamFactory* factory) { + static_cast(dispatcher())->SetStreamFactory(factory); +} + +void QuicTestServer::SetCryptoStreamFactory(CryptoStreamFactory* factory) { + static_cast(dispatcher()) + ->SetCryptoStreamFactory(factory); +} + +void QuicTestServer::SetEventLoopFactory(QuicEventLoopFactory* factory) { + event_loop_factory_ = factory; +} + +std::unique_ptr QuicTestServer::CreateEventLoop() { + QuicEventLoopFactory* factory = event_loop_factory_; + if (factory == nullptr) { + factory = GetDefaultEventLoop(); + } + return factory->Create(QuicDefaultClock::Get()); +} + +/////////////////////////// TEST SESSIONS /////////////////////////////// + +ImmediateGoAwaySession::ImmediateGoAwaySession( + const QuicConfig& config, QuicConnection* connection, + QuicSession::Visitor* visitor, QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSimpleServerSession( + config, CurrentSupportedVersions(), connection, visitor, helper, + crypto_config, compressed_certs_cache, quic_simple_server_backend) {} + +void ImmediateGoAwaySession::OnStreamFrame(const QuicStreamFrame& frame) { + if (VersionUsesHttp3(transport_version())) { + SendHttp3GoAway(QUIC_PEER_GOING_AWAY, ""); + } else { + SendGoAway(QUIC_PEER_GOING_AWAY, ""); + } + QuicSimpleServerSession::OnStreamFrame(frame); +} + +void ImmediateGoAwaySession::OnCryptoFrame(const QuicCryptoFrame& frame) { + // In IETF QUIC, GOAWAY lives up in HTTP/3 layer. It's sent in a QUIC stream + // and requires encryption. Thus the sending is done in + // OnNewEncryptionKeyAvailable(). + if (!VersionUsesHttp3(transport_version())) { + SendGoAway(QUIC_PEER_GOING_AWAY, ""); + } + QuicSimpleServerSession::OnCryptoFrame(frame); +} + +void ImmediateGoAwaySession::OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) { + QuicSimpleServerSession::OnNewEncryptionKeyAvailable(level, + std::move(encrypter)); + if (VersionUsesHttp3(transport_version())) { + if (IsEncryptionEstablished() && !goaway_sent()) { + SendHttp3GoAway(QUIC_PEER_GOING_AWAY, ""); + } + } +} + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/test_tools/quic_test_server.h b/quiche/quic/test_tools/quic_test_server.h new file mode 100644 index 000000000000..c0be6f0cf62c --- /dev/null +++ b/quiche/quic/test_tools/quic_test_server.h @@ -0,0 +1,126 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_SERVER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_SERVER_H_ + +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/core/quic_session.h" +#include "quiche/quic/tools/quic_server.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/quic/tools/quic_simple_server_session.h" +#include "quiche/quic/tools/quic_simple_server_stream.h" + +namespace quic { + +namespace test { + +// A test server which enables easy creation of custom QuicServerSessions +// +// Eventually this may be extended to allow custom QuicConnections etc. +class QuicTestServer : public QuicServer { + public: + // Factory for creating QuicServerSessions. + class SessionFactory { + public: + virtual ~SessionFactory() {} + + // Returns a new session owned by the caller. + virtual std::unique_ptr CreateSession( + const QuicConfig& config, QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend) = 0; + }; + + // Factory for creating QuicSimpleServerStreams. + class StreamFactory { + public: + virtual ~StreamFactory() {} + + // Returns a new stream owned by the caller. + virtual QuicSimpleServerStream* CreateStream( + QuicStreamId id, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) = 0; + + virtual QuicSimpleServerStream* CreateStream( + PendingStream* pending, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) = 0; + }; + + class CryptoStreamFactory { + public: + virtual ~CryptoStreamFactory() {} + + // Returns a new QuicCryptoServerStreamBase owned by the caller + virtual std::unique_ptr CreateCryptoStream( + const QuicCryptoServerConfig* crypto_config, + QuicServerSessionBase* session) = 0; + }; + + QuicTestServer(std::unique_ptr proof_source, + QuicSimpleServerBackend* quic_simple_server_backend); + QuicTestServer(std::unique_ptr proof_source, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicSimpleServerBackend* quic_simple_server_backend); + QuicTestServer(std::unique_ptr proof_source, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length); + + // Create a custom dispatcher which creates custom sessions. + QuicDispatcher* CreateQuicDispatcher() override; + + // Sets a custom session factory, owned by the caller, for easy custom + // session logic. This is incompatible with setting a stream factory or a + // crypto stream factory. + void SetSessionFactory(SessionFactory* factory); + + // Sets a custom stream factory, owned by the caller, for easy custom + // stream logic. This is incompatible with setting a session factory. + void SetSpdyStreamFactory(StreamFactory* factory); + + // Sets a custom crypto stream factory, owned by the caller, for easy custom + // crypto logic. This is incompatible with setting a session factory. + void SetCryptoStreamFactory(CryptoStreamFactory* factory); + + // Sets the override for the default event loop factory used by the server. + void SetEventLoopFactory(QuicEventLoopFactory* factory); + + protected: + std::unique_ptr CreateEventLoop() override; + + private: + QuicEventLoopFactory* event_loop_factory_ = nullptr; +}; + +// Useful test sessions for the QuicTestServer. + +// Test session which sends a GOAWAY immedaitely on creation, before crypto +// credentials have even been established. +class ImmediateGoAwaySession : public QuicSimpleServerSession { + public: + ImmediateGoAwaySession(const QuicConfig& config, QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend); + + // Override to send GoAway. + void OnStreamFrame(const QuicStreamFrame& frame) override; + void OnCryptoFrame(const QuicCryptoFrame& frame) override; + void OnNewEncryptionKeyAvailable( + EncryptionLevel level, std::unique_ptr encrypter) override; +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_SERVER_H_ diff --git a/quiche/quic/test_tools/quic_test_utils.cc b/quiche/quic/test_tools/quic_test_utils.cc new file mode 100644 index 000000000000..e622ed816fa5 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_utils.cc @@ -0,0 +1,1515 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_test_utils.h" + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "openssl/chacha.h" +#include "openssl/sha.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_creator.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/common/quiche_buffer_allocator.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/spdy/core/spdy_frame_builder.h" + +using testing::_; +using testing::Invoke; + +namespace quic { +namespace test { + +QuicConnectionId TestConnectionId() { + // Chosen by fair dice roll. + // Guaranteed to be random. + return TestConnectionId(42); +} + +QuicConnectionId TestConnectionId(uint64_t connection_number) { + const uint64_t connection_id64_net = + quiche::QuicheEndian::HostToNet64(connection_number); + return QuicConnectionId(reinterpret_cast(&connection_id64_net), + sizeof(connection_id64_net)); +} + +QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) { + const uint64_t connection_number_net = + quiche::QuicheEndian::HostToNet64(connection_number); + char connection_id_bytes[9] = {}; + static_assert( + sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net), + "bad lengths"); + memcpy(connection_id_bytes + 1, &connection_number_net, + sizeof(connection_number_net)); + return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes)); +} + +uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) { + QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength); + uint64_t connection_id64_net = 0; + memcpy(&connection_id64_net, connection_id.data(), + std::min(static_cast(connection_id.length()), + sizeof(connection_id64_net))); + return quiche::QuicheEndian::NetToHost64(connection_id64_net); +} + +std::vector CreateStatelessResetTokenForTest() { + static constexpr uint8_t kStatelessResetTokenDataForTest[16] = { + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F}; + return std::vector(kStatelessResetTokenDataForTest, + kStatelessResetTokenDataForTest + + sizeof(kStatelessResetTokenDataForTest)); +} + +std::string TestHostname() { return "test.example.com"; } + +QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); } + +QuicAckFrame InitAckFrame(const std::vector& ack_blocks) { + QUICHE_DCHECK_GT(ack_blocks.size(), 0u); + + QuicAckFrame ack; + QuicPacketNumber end_of_previous_block(1); + for (const QuicAckBlock& block : ack_blocks) { + QUICHE_DCHECK_GE(block.start, end_of_previous_block); + QUICHE_DCHECK_GT(block.limit, block.start); + ack.packets.AddRange(block.start, block.limit); + end_of_previous_block = block.limit; + } + + ack.largest_acked = ack.packets.Max(); + + return ack; +} + +QuicAckFrame InitAckFrame(uint64_t largest_acked) { + return InitAckFrame(QuicPacketNumber(largest_acked)); +} + +QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) { + return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}}); +} + +QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks, + uint64_t least_unacked) { + QuicAckFrame ack; + ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked); + // Add enough received packets to get num_ack_blocks ack blocks. + for (QuicPacketNumber i = QuicPacketNumber(2); + i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) { + ack.packets.Add(i + least_unacked); + } + return ack; +} + +QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps, + uint64_t largest_acked) { + QuicAckFrame ack; + ack.largest_acked = QuicPacketNumber(largest_acked); + ack.packets.Add(QuicPacketNumber(largest_acked)); + for (size_t i = 0; i < max_num_gaps; ++i) { + if (largest_acked <= gap_size) { + break; + } + largest_acked -= gap_size; + ack.packets.Add(QuicPacketNumber(largest_acked)); + } + return ack; +} + +EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) { + if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + return ENCRYPTION_FORWARD_SECURE; + } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) { + if (header.long_packet_type == HANDSHAKE) { + return ENCRYPTION_HANDSHAKE; + } else if (header.long_packet_type == ZERO_RTT_PROTECTED) { + return ENCRYPTION_ZERO_RTT; + } + } + return ENCRYPTION_INITIAL; +} + +std::unique_ptr BuildUnsizedDataPacket( + QuicFramer* framer, const QuicPacketHeader& header, + const QuicFrames& frames) { + const size_t max_plaintext_size = + framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize); + size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header); + for (size_t i = 0; i < frames.size(); ++i) { + QUICHE_DCHECK_LE(packet_size, max_plaintext_size); + bool first_frame = i == 0; + bool last_frame = i == frames.size() - 1; + const size_t frame_size = framer->GetSerializedFrameLength( + frames[i], max_plaintext_size - packet_size, first_frame, last_frame, + header.packet_number_length); + QUICHE_DCHECK(frame_size); + packet_size += frame_size; + } + return BuildUnsizedDataPacket(framer, header, frames, packet_size); +} + +std::unique_ptr BuildUnsizedDataPacket( + QuicFramer* framer, const QuicPacketHeader& header, + const QuicFrames& frames, size_t packet_size) { + char* buffer = new char[packet_size]; + EncryptionLevel level = HeaderToEncryptionLevel(header); + size_t length = + framer->BuildDataPacket(header, frames, buffer, packet_size, level); + + if (length == 0) { + delete[] buffer; + return nullptr; + } + // Re-construct the data packet with data ownership. + return std::make_unique( + buffer, length, /* owns_buffer */ true, + GetIncludedDestinationConnectionIdLength(header), + GetIncludedSourceConnectionIdLength(header), header.version_flag, + header.nonce != nullptr, header.packet_number_length, + header.retry_token_length_length, header.retry_token.length(), + header.length_length); +} + +std::string Sha1Hash(absl::string_view data) { + char buffer[SHA_DIGEST_LENGTH]; + SHA1(reinterpret_cast(data.data()), data.size(), + reinterpret_cast(buffer)); + return std::string(buffer, ABSL_ARRAYSIZE(buffer)); +} + +bool ClearControlFrame(const QuicFrame& frame) { + DeleteFrame(&const_cast(frame)); + return true; +} + +bool ClearControlFrameWithTransmissionType(const QuicFrame& frame, + TransmissionType /*type*/) { + return ClearControlFrame(frame); +} + +uint64_t SimpleRandom::RandUint64() { + uint64_t result; + RandBytes(&result, sizeof(result)); + return result; +} + +void SimpleRandom::RandBytes(void* data, size_t len) { + uint8_t* data_bytes = reinterpret_cast(data); + while (len > 0) { + const size_t buffer_left = sizeof(buffer_) - buffer_offset_; + const size_t to_copy = std::min(buffer_left, len); + memcpy(data_bytes, buffer_ + buffer_offset_, to_copy); + data_bytes += to_copy; + buffer_offset_ += to_copy; + len -= to_copy; + + if (buffer_offset_ == sizeof(buffer_)) { + FillBuffer(); + } + } +} + +void SimpleRandom::InsecureRandBytes(void* data, size_t len) { + RandBytes(data, len); +} + +uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); } + +void SimpleRandom::FillBuffer() { + uint8_t nonce[12]; + memcpy(nonce, buffer_, sizeof(nonce)); + CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0); + buffer_offset_ = 0; +} + +void SimpleRandom::set_seed(uint64_t seed) { + static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits"); + SHA256(reinterpret_cast(&seed), sizeof(seed), key_); + + memset(buffer_, 0, sizeof(buffer_)); + FillBuffer(); +} + +MockFramerVisitor::MockFramerVisitor() { + // By default, we want to accept packets. + ON_CALL(*this, OnProtocolVersionMismatch(_)) + .WillByDefault(testing::Return(false)); + + // By default, we want to accept packets. + ON_CALL(*this, OnUnauthenticatedHeader(_)) + .WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnUnauthenticatedPublicHeader(_)) + .WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnConnectionCloseFrame(_)) + .WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true)); + + ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true)); +} + +MockFramerVisitor::~MockFramerVisitor() {} + +bool NoOpFramerVisitor::OnProtocolVersionMismatch( + ParsedQuicVersion /*version*/) { + return false; +} + +bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader( + const QuicPacketHeader& /*header*/) { + return true; +} + +bool NoOpFramerVisitor::OnUnauthenticatedHeader( + const QuicPacketHeader& /*header*/) { + return true; +} + +bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) { + return true; +} + +void NoOpFramerVisitor::OnCoalescedPacket( + const QuicEncryptedPacket& /*packet*/) {} + +void NoOpFramerVisitor::OnUndecryptablePacket( + const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/, + bool /*has_decryption_key*/) {} + +bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/, + QuicTime::Delta /*ack_delay_time*/) { + return true; +} + +bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/, + QuicPacketNumber /*end*/) { + return true; +} + +bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/, + QuicTime /*timestamp*/) { + return true; +} + +bool NoOpFramerVisitor::OnAckFrameEnd( + QuicPacketNumber /*start*/, + const absl::optional& /*ecn_counts*/) { + return true; +} + +bool NoOpFramerVisitor::OnStopWaitingFrame( + const QuicStopWaitingFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnConnectionCloseFrame( + const QuicConnectionCloseFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnStopSendingFrame( + const QuicStopSendingFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnPathChallengeFrame( + const QuicPathChallengeFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnPathResponseFrame( + const QuicPathResponseFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnMaxStreamsFrame( + const QuicMaxStreamsFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnWindowUpdateFrame( + const QuicWindowUpdateFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnHandshakeDoneFrame( + const QuicHandshakeDoneFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::OnAckFrequencyFrame( + const QuicAckFrequencyFrame& /*frame*/) { + return true; +} + +bool NoOpFramerVisitor::IsValidStatelessResetToken( + const StatelessResetToken& /*token*/) const { + return false; +} + +MockQuicConnectionVisitor::MockQuicConnectionVisitor() {} + +MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {} + +MockQuicConnectionHelper::MockQuicConnectionHelper() {} + +MockQuicConnectionHelper::~MockQuicConnectionHelper() {} + +const QuicClock* MockQuicConnectionHelper::GetClock() const { return &clock_; } + +QuicClock* MockQuicConnectionHelper::GetClock() { return &clock_; } + +QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() { + return &random_generator_; +} + +QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) { + return new MockAlarmFactory::TestAlarm( + QuicArenaScopedPtr(delegate)); +} + +QuicArenaScopedPtr MockAlarmFactory::CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) { + if (arena != nullptr) { + return arena->New(std::move(delegate)); + } else { + return QuicArenaScopedPtr(new TestAlarm(std::move(delegate))); + } +} + +quiche::QuicheBufferAllocator* +MockQuicConnectionHelper::GetStreamSendBufferAllocator() { + return &buffer_allocator_; +} + +void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) { + clock_.AdvanceTime(delta); +} + +MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + Perspective perspective) + : MockQuicConnection(TestConnectionId(), + QuicSocketAddress(TestPeerIPAddress(), kTestPort), + helper, alarm_factory, perspective, + ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {} + +MockQuicConnection::MockQuicConnection(QuicSocketAddress address, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + Perspective perspective) + : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory, + perspective, + ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {} + +MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + Perspective perspective) + : MockQuicConnection(connection_id, + QuicSocketAddress(TestPeerIPAddress(), kTestPort), + helper, alarm_factory, perspective, + ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {} + +MockQuicConnection::MockQuicConnection( + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + Perspective perspective, const ParsedQuicVersionVector& supported_versions) + : MockQuicConnection( + TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort), + helper, alarm_factory, perspective, supported_versions) {} + +MockQuicConnection::MockQuicConnection( + QuicConnectionId connection_id, QuicSocketAddress initial_peer_address, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + Perspective perspective, const ParsedQuicVersionVector& supported_versions) + : QuicConnection( + connection_id, + /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5), + initial_peer_address, helper, alarm_factory, + new testing::NiceMock(), + /* owns_writer= */ true, perspective, supported_versions, + connection_id_generator_) { + ON_CALL(*this, OnError(_)) + .WillByDefault( + Invoke(this, &PacketSavingConnection::QuicConnection_OnError)); + ON_CALL(*this, SendCryptoData(_, _, _)) + .WillByDefault( + Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData)); + + SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5)); +} + +MockQuicConnection::~MockQuicConnection() {} + +void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) { + static_cast(helper())->AdvanceTime(delta); +} + +bool MockQuicConnection::OnProtocolVersionMismatch( + ParsedQuicVersion /*version*/) { + return false; +} + +PacketSavingConnection::PacketSavingConnection( + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + Perspective perspective) + : MockQuicConnection(helper, alarm_factory, perspective) {} + +PacketSavingConnection::PacketSavingConnection( + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + Perspective perspective, const ParsedQuicVersionVector& supported_versions) + : MockQuicConnection(helper, alarm_factory, perspective, + supported_versions) {} + +PacketSavingConnection::~PacketSavingConnection() {} + +SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate( + bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) { + return SEND_TO_WRITER; +} + +void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) { + encrypted_packets_.push_back(std::make_unique( + CopyBuffer(packet), packet.encrypted_length, true)); + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + // Transfer ownership of the packet to the SentPacketManager and the + // ack notifier to the AckNotifierManager. + OnPacketSent(packet.encryption_level, packet.transmission_type); + QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent( + &packet, clock_.ApproximateNow(), NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT); +} + +MockQuicSession::MockQuicSession(QuicConnection* connection) + : MockQuicSession(connection, true) {} + +MockQuicSession::MockQuicSession(QuicConnection* connection, + bool create_mock_crypto_stream) + : QuicSession(connection, nullptr, DefaultQuicConfig(), + connection->supported_versions(), + /*num_expected_unidirectional_static_streams = */ 0) { + if (create_mock_crypto_stream) { + crypto_stream_ = std::make_unique(this); + } + ON_CALL(*this, WritevData(_, _, _, _, _, _)) + .WillByDefault(testing::Return(QuicConsumedData(0, false))); +} + +MockQuicSession::~MockQuicSession() { DeleteConnection(); } + +QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() { + return crypto_stream_.get(); +} + +const QuicCryptoStream* MockQuicSession::GetCryptoStream() const { + return crypto_stream_.get(); +} + +void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) { + crypto_stream_.reset(crypto_stream); +} + +QuicConsumedData MockQuicSession::ConsumeData( + QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType /*type*/, + absl::optional /*level*/) { + if (write_length > 0) { + auto buf = std::make_unique(write_length); + QuicStream* stream = GetOrCreateStream(id); + QUICHE_DCHECK(stream); + QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER); + stream->WriteStreamData(offset, write_length, &writer); + } else { + QUICHE_DCHECK(state != NO_FIN); + } + return QuicConsumedData(write_length, state != NO_FIN); +} + +MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session) + : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {} + +MockQuicCryptoStream::~MockQuicCryptoStream() {} + +ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const { + return ssl_early_data_unknown; +} + +bool MockQuicCryptoStream::encryption_established() const { return false; } + +bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; } + +const QuicCryptoNegotiatedParameters& +MockQuicCryptoStream::crypto_negotiated_params() const { + return *params_; +} + +CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() { + return &crypto_framer_; +} + +MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection) + : MockQuicSpdySession(connection, true) {} + +MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection, + bool create_mock_crypto_stream) + : QuicSpdySession(connection, nullptr, DefaultQuicConfig(), + connection->supported_versions()) { + if (create_mock_crypto_stream) { + crypto_stream_ = std::make_unique(this); + } + + ON_CALL(*this, WritevData(_, _, _, _, _, _)) + .WillByDefault(testing::Return(QuicConsumedData(0, false))); + + ON_CALL(*this, SendWindowUpdate(_, _)) + .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) { + return QuicSpdySession::SendWindowUpdate(id, byte_offset); + }); + + ON_CALL(*this, SendBlocked(_, _)) + .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) { + return QuicSpdySession::SendBlocked(id, byte_offset); + }); + + ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return()); +} + +MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); } + +QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() { + return crypto_stream_.get(); +} + +const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const { + return crypto_stream_.get(); +} + +void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) { + crypto_stream_.reset(crypto_stream); +} + +QuicConsumedData MockQuicSpdySession::ConsumeData( + QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType /*type*/, + absl::optional /*level*/) { + if (write_length > 0) { + auto buf = std::make_unique(write_length); + QuicStream* stream = GetOrCreateStream(id); + QUICHE_DCHECK(stream); + QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER); + stream->WriteStreamData(offset, write_length, &writer); + } else { + QUICHE_DCHECK(state != NO_FIN); + } + return QuicConsumedData(write_length, state != NO_FIN); +} + +TestQuicSpdyServerSession::TestQuicSpdyServerSession( + QuicConnection* connection, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) + : QuicServerSessionBase(config, supported_versions, connection, &visitor_, + &helper_, crypto_config, compressed_certs_cache) { + ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _)) + .WillByDefault(testing::Return(true)); +} + +TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); } + +std::unique_ptr +TestQuicSpdyServerSession::CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) { + return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this, + &helper_); +} + +QuicCryptoServerStreamBase* +TestQuicSpdyServerSession::GetMutableCryptoStream() { + return QuicServerSessionBase::GetMutableCryptoStream(); +} + +const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream() + const { + return QuicServerSessionBase::GetCryptoStream(); +} + +TestQuicSpdyClientSession::TestQuicSpdyClientSession( + QuicConnection* connection, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, + absl::optional ssl_config) + : QuicSpdyClientSessionBase(connection, nullptr, &push_promise_index_, + config, supported_versions), + ssl_config_(std::move(ssl_config)) { + // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption + // calls in tests and set |has_application_state| to true. + crypto_stream_ = std::make_unique( + server_id, this, crypto_test_utils::ProofVerifyContextForTesting(), + crypto_config, this, /*has_application_state = */ false); + Initialize(); + ON_CALL(*this, OnConfigNegotiated()) + .WillByDefault( + Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated)); +} + +TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {} + +bool TestQuicSpdyClientSession::IsAuthorized(const std::string& /*authority*/) { + return true; +} + +QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() { + return crypto_stream_.get(); +} + +const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream() + const { + return crypto_stream_.get(); +} + +void TestQuicSpdyClientSession::RealOnConfigNegotiated() { + QuicSpdyClientSessionBase::OnConfigNegotiated(); +} + +TestPushPromiseDelegate::TestPushPromiseDelegate(bool match) + : match_(match), rendezvous_fired_(false), rendezvous_stream_(nullptr) {} + +bool TestPushPromiseDelegate::CheckVary( + const spdy::Http2HeaderBlock& /*client_request*/, + const spdy::Http2HeaderBlock& /*promise_request*/, + const spdy::Http2HeaderBlock& /*promise_response*/) { + QUIC_DVLOG(1) << "match " << match_; + return match_; +} + +void TestPushPromiseDelegate::OnRendezvousResult(QuicSpdyStream* stream) { + rendezvous_fired_ = true; + rendezvous_stream_ = stream; +} + +MockPacketWriter::MockPacketWriter() { + ON_CALL(*this, GetMaxPacketSize(_)) + .WillByDefault(testing::Return(kMaxOutgoingPacketSize)); + ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false)); + ON_CALL(*this, GetNextWriteLocation(_, _)) + .WillByDefault(testing::Return(QuicPacketBuffer())); + ON_CALL(*this, Flush()) + .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0))); + ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false)); +} + +MockPacketWriter::~MockPacketWriter() {} + +MockSendAlgorithm::MockSendAlgorithm() { + ON_CALL(*this, PacingRate(_)) + .WillByDefault(testing::Return(QuicBandwidth::Zero())); + ON_CALL(*this, BandwidthEstimate()) + .WillByDefault(testing::Return(QuicBandwidth::Zero())); +} + +MockSendAlgorithm::~MockSendAlgorithm() {} + +MockLossAlgorithm::MockLossAlgorithm() {} + +MockLossAlgorithm::~MockLossAlgorithm() {} + +MockAckListener::MockAckListener() {} + +MockAckListener::~MockAckListener() {} + +MockNetworkChangeVisitor::MockNetworkChangeVisitor() {} + +MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {} + +QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); } + +ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); } + +ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); } + +void DisableQuicVersionsWithTls() { + for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) { + QuicDisableVersion(version); + } +} + +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data) { + return ConstructEncryptedPacket( + destination_connection_id, source_connection_id, version_flag, reset_flag, + packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT, + PACKET_4BYTE_PACKET_NUMBER); +} + +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length) { + return ConstructEncryptedPacket( + destination_connection_id, source_connection_id, version_flag, reset_flag, + packet_number, data, destination_connection_id_included, + source_connection_id_included, packet_number_length, nullptr); +} + +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, + ParsedQuicVersionVector* versions) { + return ConstructEncryptedPacket( + destination_connection_id, source_connection_id, version_flag, reset_flag, + packet_number, data, false, destination_connection_id_included, + source_connection_id_included, packet_number_length, versions, + Perspective::IS_CLIENT); +} + +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, bool full_padding, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, + ParsedQuicVersionVector* versions) { + return ConstructEncryptedPacket( + destination_connection_id, source_connection_id, version_flag, reset_flag, + packet_number, data, full_padding, destination_connection_id_included, + source_connection_id_included, packet_number_length, versions, + Perspective::IS_CLIENT); +} + +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, bool full_padding, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, + ParsedQuicVersionVector* versions, Perspective perspective) { + QuicPacketHeader header; + header.destination_connection_id = destination_connection_id; + header.destination_connection_id_included = + destination_connection_id_included; + header.source_connection_id = source_connection_id; + header.source_connection_id_included = source_connection_id_included; + header.version_flag = version_flag; + header.reset_flag = reset_flag; + header.packet_number_length = packet_number_length; + header.packet_number = QuicPacketNumber(packet_number); + ParsedQuicVersionVector supported_versions = CurrentSupportedVersions(); + if (!versions) { + versions = &supported_versions; + } + EXPECT_FALSE(versions->empty()); + ParsedQuicVersion version = (*versions)[0]; + if (QuicVersionHasLongHeaderLengths(version.transport_version) && + version_flag) { + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + + QuicFrames frames; + QuicFramer framer(*versions, QuicTime::Zero(), perspective, + kQuicDefaultConnectionIdLength); + framer.SetInitialObfuscators(destination_connection_id); + EncryptionLevel level = + header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE; + if (level != ENCRYPTION_INITIAL) { + framer.SetEncrypter(level, std::make_unique(level)); + } + if (!QuicVersionUsesCryptoFrames(version.transport_version)) { + QuicFrame frame( + QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version), + false, 0, absl::string_view(data))); + frames.push_back(frame); + } else { + QuicFrame frame(new QuicCryptoFrame(level, 0, data)); + frames.push_back(frame); + } + if (full_padding) { + frames.push_back(QuicFrame(QuicPaddingFrame(-1))); + } else { + // We need a minimum number of bytes of encrypted payload. This will + // guarantee that we have at least that much. (It ignores the overhead of + // the stream/crypto framing, so it overpads slightly.) + size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize( + version, packet_number_length); + if (data.length() < min_plaintext_size) { + size_t padding_length = min_plaintext_size - data.length(); + frames.push_back(QuicFrame(QuicPaddingFrame(padding_length))); + } + } + + std::unique_ptr packet( + BuildUnsizedDataPacket(&framer, header, frames)); + EXPECT_TRUE(packet != nullptr); + char* buffer = new char[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet, + buffer, kMaxOutgoingPacketSize); + EXPECT_NE(0u, encrypted_length); + DeleteFrames(&frames); + return new QuicEncryptedPacket(buffer, encrypted_length, true); +} + +std::unique_ptr GetUndecryptableEarlyPacket( + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id) { + QuicPacketHeader header; + header.destination_connection_id = server_connection_id; + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + header.source_connection_id = EmptyQuicConnectionId(); + header.source_connection_id_included = CONNECTION_ID_PRESENT; + if (!version.SupportsClientConnectionIds()) { + header.source_connection_id_included = CONNECTION_ID_ABSENT; + } + header.version_flag = true; + header.reset_flag = false; + header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header.packet_number = QuicPacketNumber(33); + header.long_packet_type = ZERO_RTT_PROTECTED; + if (version.HasLongHeaderLengths()) { + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + + QuicFrames frames; + frames.push_back(QuicFrame(QuicPingFrame())); + frames.push_back(QuicFrame(QuicPaddingFrame(100))); + QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT, + kQuicDefaultConnectionIdLength); + framer.SetInitialObfuscators(server_connection_id); + + framer.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(ENCRYPTION_ZERO_RTT)); + std::unique_ptr packet( + BuildUnsizedDataPacket(&framer, header, frames)); + EXPECT_TRUE(packet != nullptr); + char* buffer = new char[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet, + buffer, kMaxOutgoingPacketSize); + EXPECT_NE(0u, encrypted_length); + DeleteFrames(&frames); + return std::make_unique(buffer, encrypted_length, + /*owns_buffer=*/true); +} + +QuicReceivedPacket* ConstructReceivedPacket( + const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) { + char* buffer = new char[encrypted_packet.length()]; + memcpy(buffer, encrypted_packet.data(), encrypted_packet.length()); + return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time, + true); +} + +QuicEncryptedPacket* ConstructMisFramedEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, ParsedQuicVersion version, + Perspective perspective) { + QuicPacketHeader header; + header.destination_connection_id = destination_connection_id; + header.destination_connection_id_included = + destination_connection_id_included; + header.source_connection_id = source_connection_id; + header.source_connection_id_included = source_connection_id_included; + header.version_flag = version_flag; + header.reset_flag = reset_flag; + header.packet_number_length = packet_number_length; + header.packet_number = QuicPacketNumber(packet_number); + if (QuicVersionHasLongHeaderLengths(version.transport_version) && + version_flag) { + header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1; + header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2; + } + QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data))); + QuicFrames frames; + frames.push_back(frame); + QuicFramer framer({version}, QuicTime::Zero(), perspective, + kQuicDefaultConnectionIdLength); + framer.SetInitialObfuscators(destination_connection_id); + EncryptionLevel level = + version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE; + if (level != ENCRYPTION_INITIAL) { + framer.SetEncrypter(level, std::make_unique(level)); + } + // We need a minimum of 7 bytes of encrypted payload. This will guarantee that + // we have at least that much. (It ignores the overhead of the stream/crypto + // framing, so it overpads slightly.) + if (data.length() < 7) { + size_t padding_length = 7 - data.length(); + frames.push_back(QuicFrame(QuicPaddingFrame(padding_length))); + } + + std::unique_ptr packet( + BuildUnsizedDataPacket(&framer, header, frames)); + EXPECT_TRUE(packet != nullptr); + + // Now set the frame type to 0x1F, which is an invalid frame type. + reinterpret_cast( + packet->mutable_data())[GetStartOfEncryptedData( + framer.transport_version(), + GetIncludedDestinationConnectionIdLength(header), + GetIncludedSourceConnectionIdLength(header), version_flag, + false /* no diversification nonce */, packet_number_length, + header.retry_token_length_length, 0, header.length_length)] = 0x1F; + + char* buffer = new char[kMaxOutgoingPacketSize]; + size_t encrypted_length = + framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet, + buffer, kMaxOutgoingPacketSize); + EXPECT_NE(0u, encrypted_length); + return new QuicEncryptedPacket(buffer, encrypted_length, true); +} + +QuicConfig DefaultQuicConfig() { + QuicConfig config; + config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + kInitialStreamFlowControlWindowForTest); + config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend( + kInitialStreamFlowControlWindowForTest); + config.SetInitialMaxStreamDataBytesUnidirectionalToSend( + kInitialStreamFlowControlWindowForTest); + config.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + QuicConfigPeer::SetReceivedMaxBidirectionalStreams( + &config, kDefaultMaxStreamsPerConnection); + // Default enable NSTP. + // This is unnecessary for versions > 44 + if (!config.HasClientSentConnectionOption(quic::kNSTP, + quic::Perspective::IS_CLIENT)) { + quic::QuicTagVector connection_options; + connection_options.push_back(quic::kNSTP); + config.SetConnectionOptionsToSend(connection_options); + } + return config; +} + +ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) { + ParsedQuicVersionVector versions; + versions.push_back(version); + return versions; +} + +MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {} + +MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {} + +MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats) + : QuicReceivedPacketManager(stats) {} + +MockReceivedPacketManager::~MockReceivedPacketManager() {} + +MockPacketCreatorDelegate::MockPacketCreatorDelegate() {} +MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {} + +MockSessionNotifier::MockSessionNotifier() {} +MockSessionNotifier::~MockSessionNotifier() {} + +// static +QuicCryptoClientStream::HandshakerInterface* +QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) { + return stream->handshaker_.get(); +} + +void CreateClientSessionForTest( + QuicServerId server_id, QuicTime::Delta connection_start_time, + const ParsedQuicVersionVector& supported_versions, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + QuicCryptoClientConfig* crypto_client_config, + PacketSavingConnection** client_connection, + TestQuicSpdyClientSession** client_session) { + QUICHE_CHECK(crypto_client_config); + QUICHE_CHECK(client_connection); + QUICHE_CHECK(client_session); + QUICHE_CHECK(!connection_start_time.IsZero()) + << "Connections must start at non-zero times, otherwise the " + << "strike-register will be unhappy."; + + QuicConfig config = DefaultQuicConfig(); + *client_connection = new PacketSavingConnection( + helper, alarm_factory, Perspective::IS_CLIENT, supported_versions); + *client_session = new TestQuicSpdyClientSession(*client_connection, config, + supported_versions, server_id, + crypto_client_config); + (*client_connection)->AdvanceTime(connection_start_time); +} + +void CreateServerSessionForTest( + QuicServerId /*server_id*/, QuicTime::Delta connection_start_time, + ParsedQuicVersionVector supported_versions, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + QuicCryptoServerConfig* server_crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + PacketSavingConnection** server_connection, + TestQuicSpdyServerSession** server_session) { + QUICHE_CHECK(server_crypto_config); + QUICHE_CHECK(server_connection); + QUICHE_CHECK(server_session); + QUICHE_CHECK(!connection_start_time.IsZero()) + << "Connections must start at non-zero times, otherwise the " + << "strike-register will be unhappy."; + + *server_connection = + new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER, + ParsedVersionOfIndex(supported_versions, 0)); + *server_session = new TestQuicSpdyServerSession( + *server_connection, DefaultQuicConfig(), supported_versions, + server_crypto_config, compressed_certs_cache); + (*server_session)->Initialize(); + + // We advance the clock initially because the default time is zero and the + // strike register worries that we've just overflowed a uint32_t time. + (*server_connection)->AdvanceTime(connection_start_time); +} + +QuicStreamId GetNthClientInitiatedBidirectionalStreamId( + QuicTransportVersion version, int n) { + int num = n; + if (!VersionUsesHttp3(version)) { + num++; + } + return QuicUtils::GetFirstBidirectionalStreamId(version, + Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(version) * num; +} + +QuicStreamId GetNthServerInitiatedBidirectionalStreamId( + QuicTransportVersion version, int n) { + return QuicUtils::GetFirstBidirectionalStreamId(version, + Perspective::IS_SERVER) + + QuicUtils::StreamIdDelta(version) * n; +} + +QuicStreamId GetNthServerInitiatedUnidirectionalStreamId( + QuicTransportVersion version, int n) { + return QuicUtils::GetFirstUnidirectionalStreamId(version, + Perspective::IS_SERVER) + + QuicUtils::StreamIdDelta(version) * n; +} + +QuicStreamId GetNthClientInitiatedUnidirectionalStreamId( + QuicTransportVersion version, int n) { + return QuicUtils::GetFirstUnidirectionalStreamId(version, + Perspective::IS_CLIENT) + + QuicUtils::StreamIdDelta(version) * n; +} + +StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version, + Perspective perspective, bool is_incoming, + StreamType default_type) { + return version.HasIetfQuicFrames() + ? QuicUtils::GetStreamType(id, perspective, is_incoming, version) + : default_type; +} + +quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) { + if (data.empty()) { + return quiche::QuicheMemSlice(); + } + + static quiche::SimpleBufferAllocator* allocator = + new quiche::SimpleBufferAllocator(); + return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data)); +} + +bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/, + absl::string_view /*associated_data*/, + absl::string_view plaintext, char* output, + size_t* output_length, + size_t max_output_length) { + const size_t len = plaintext.size() + kTagSize; + if (max_output_length < len) { + return false; + } + // Memmove is safe for inplace encryption. + memmove(output, plaintext.data(), plaintext.size()); + output += plaintext.size(); + memset(output, tag_, kTagSize); + *output_length = len; + return true; +} + +bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/, + absl::string_view /*associated_data*/, + absl::string_view ciphertext, char* output, + size_t* output_length, + size_t /*max_output_length*/) { + if (ciphertext.size() < kTagSize) { + return false; + } + if (!CheckTag(ciphertext, GetTag(ciphertext))) { + return false; + } + *output_length = ciphertext.size() - kTagSize; + memcpy(output, ciphertext.data(), *output_length); + return true; +} + +bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) { + for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) { + if (ciphertext.data()[i] != tag) { + return false; + } + } + + return true; +} + +TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock, + Perspective perspective) + : version_(version), + framer_(SupportedVersions(version_), + QuicUtils::InvertPerspective(perspective)), + clock_(clock) { + QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(), + TestConnectionId()); + framer_.framer()->SetInitialObfuscators(TestConnectionId()); + + for (int i = 0; i < 128; ++i) { + PacketBuffer* p = new PacketBuffer(); + packet_buffer_pool_.push_back(p); + packet_buffer_pool_index_[p->buffer] = p; + packet_buffer_free_list_.push_back(p); + } +} + +TestPacketWriter::~TestPacketWriter() { + EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size()) + << packet_buffer_pool_.size() - packet_buffer_free_list_.size() + << " out of " << packet_buffer_pool_.size() + << " packet buffers have been leaked."; + for (auto p : packet_buffer_pool_) { + delete p; + } +} + +WriteResult TestPacketWriter::WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) { + last_write_source_address_ = self_address; + last_write_peer_address_ = peer_address; + // If the buffer is allocated from the pool, return it back to the pool. + // Note the buffer content doesn't change. + if (packet_buffer_pool_index_.find(const_cast(buffer)) != + packet_buffer_pool_index_.end()) { + FreePacketBuffer(buffer); + } + + QuicEncryptedPacket packet(buffer, buf_len); + ++packets_write_attempts_; + + if (packet.length() >= sizeof(final_bytes_of_last_packet_)) { + final_bytes_of_previous_packet_ = final_bytes_of_last_packet_; + memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4, + sizeof(final_bytes_of_last_packet_)); + } + if (framer_.framer()->version().KnowsWhichDecrypterToUse()) { + framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique()); + framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique()); + framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique()); + } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel( + ENCRYPTION_FORWARD_SECURE) && + !framer_.framer()->HasDecrypterOfEncryptionLevel( + ENCRYPTION_ZERO_RTT)) { + framer_.framer()->SetAlternativeDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(ENCRYPTION_FORWARD_SECURE), + false); + } + EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet)) + << framer_.framer()->detailed_error() << " perspective " + << framer_.framer()->perspective(); + next_packet_processable_ = true; + if (block_on_next_write_) { + write_blocked_ = true; + block_on_next_write_ = false; + } + if (next_packet_too_large_) { + next_packet_too_large_ = false; + return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode()); + } + if (always_get_packet_too_large_) { + return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode()); + } + if (IsWriteBlocked()) { + return WriteResult(is_write_blocked_data_buffered_ + ? WRITE_STATUS_BLOCKED_DATA_BUFFERED + : WRITE_STATUS_BLOCKED, + 0); + } + + if (ShouldWriteFail()) { + return WriteResult(WRITE_STATUS_ERROR, write_error_code_); + } + + last_packet_size_ = packet.length(); + total_bytes_written_ += packet.length(); + last_packet_header_ = framer_.header(); + if (!framer_.connection_close_frames().empty()) { + ++connection_close_packets_; + } + if (!write_pause_time_delta_.IsZero()) { + clock_->AdvanceTime(write_pause_time_delta_); + } + if (is_batch_mode_) { + bytes_buffered_ += last_packet_size_; + return WriteResult(WRITE_STATUS_OK, 0); + } + last_ecn_sent_ = (options == nullptr) ? ECN_NOT_ECT : options->ecn_codepoint; + return WriteResult(WRITE_STATUS_OK, last_packet_size_); +} + +QuicPacketBuffer TestPacketWriter::GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) { + return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }}; +} + +WriteResult TestPacketWriter::Flush() { + flush_attempts_++; + if (block_on_next_flush_) { + block_on_next_flush_ = false; + SetWriteBlocked(); + return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1); + } + if (write_should_fail_) { + return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1); + } + int bytes_flushed = bytes_buffered_; + bytes_buffered_ = 0; + return WriteResult(WRITE_STATUS_OK, bytes_flushed); +} + +char* TestPacketWriter::AllocPacketBuffer() { + PacketBuffer* p = packet_buffer_free_list_.front(); + EXPECT_FALSE(p->in_use); + p->in_use = true; + packet_buffer_free_list_.pop_front(); + return p->buffer; +} + +void TestPacketWriter::FreePacketBuffer(const char* buffer) { + auto iter = packet_buffer_pool_index_.find(const_cast(buffer)); + ASSERT_TRUE(iter != packet_buffer_pool_index_.end()); + PacketBuffer* p = iter->second; + ASSERT_TRUE(p->in_use); + p->in_use = false; + packet_buffer_free_list_.push_back(p); +} + +bool WriteServerVersionNegotiationProbeResponse( + char* packet_bytes, size_t* packet_length_out, + const char* source_connection_id_bytes, + uint8_t source_connection_id_length) { + if (packet_bytes == nullptr) { + QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes"; + return false; + } + if (packet_length_out == nullptr) { + QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out"; + return false; + } + QuicConnectionId source_connection_id(source_connection_id_bytes, + source_connection_id_length); + std::unique_ptr encrypted_packet = + QuicFramer::BuildVersionNegotiationPacket( + source_connection_id, EmptyQuicConnectionId(), + /*ietf_quic=*/true, /*use_length_prefix=*/true, + ParsedQuicVersionVector{}); + if (!encrypted_packet) { + QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet"; + return false; + } + if (*packet_length_out < encrypted_packet->length()) { + QUIC_BUG(quic_bug_10256_4) + << "Invalid *packet_length_out " << *packet_length_out << " < " + << encrypted_packet->length(); + return false; + } + *packet_length_out = encrypted_packet->length(); + memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out); + return true; +} + +bool ParseClientVersionNegotiationProbePacket( + const char* packet_bytes, size_t packet_length, + char* destination_connection_id_bytes, + uint8_t* destination_connection_id_length_out) { + if (packet_bytes == nullptr) { + QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes"; + return false; + } + if (packet_length < kMinPacketSizeForVersionNegotiation || + packet_length > 65535) { + QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length"; + return false; + } + if (destination_connection_id_bytes == nullptr) { + QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes"; + return false; + } + if (destination_connection_id_length_out == nullptr) { + QUIC_BUG(quic_bug_10256_8) + << "Invalid destination_connection_id_length_out"; + return false; + } + + QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length); + PacketHeaderFormat format; + QuicLongHeaderType long_packet_type; + bool version_present, has_length_prefix; + QuicVersionLabel version_label; + ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); + QuicConnectionId destination_connection_id, source_connection_id; + absl::optional retry_token; + std::string detailed_error; + QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( + encrypted_packet, + /*expected_destination_connection_id_length=*/0, &format, + &long_packet_type, &version_present, &has_length_prefix, &version_label, + &parsed_version, &destination_connection_id, &source_connection_id, + &retry_token, &detailed_error); + if (error != QUIC_NO_ERROR) { + QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error; + return false; + } + if (!version_present) { + QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header"; + return false; + } + if (*destination_connection_id_length_out < + destination_connection_id.length()) { + QUIC_BUG(quic_bug_10256_11) + << "destination_connection_id_length_out too small"; + return false; + } + *destination_connection_id_length_out = destination_connection_id.length(); + memcpy(destination_connection_id_bytes, destination_connection_id.data(), + *destination_connection_id_length_out); + return true; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_test_utils.h b/quiche/quic/test_tools/quic_test_utils.h new file mode 100644 index 000000000000..2835ac167a44 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_utils.h @@ -0,0 +1,2197 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Common utilities for Quic tests + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_UTILS_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/congestion_control/loss_detection_interface.h" +#include "quiche/quic/core/congestion_control/send_algorithm_interface.h" +#include "quiche/quic/core/crypto/transport_parameters.h" +#include "quiche/quic/core/http/http_decoder.h" +#include "quiche/quic/core/http/quic_client_push_promise_index.h" +#include "quiche/quic/core/http/quic_server_session_base.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_sent_packet_manager.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/test_tools/mock_clock.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/mock_quic_session_visitor.h" +#include "quiche/quic/test_tools/mock_random.h" +#include "quiche/quic/test_tools/quic_framer_peer.h" +#include "quiche/quic/test_tools/simple_quic_framer.h" +#include "quiche/common/capsule.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace test { + +// A generic predictable connection ID suited for testing. +QuicConnectionId TestConnectionId(); + +// A generic predictable connection ID suited for testing, generated from a +// given number, such as an index. +QuicConnectionId TestConnectionId(uint64_t connection_number); + +// A generic predictable connection ID suited for testing, generated from a +// given number, such as an index. Guaranteed to be 9 bytes long. +QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number); + +// Extracts the connection number passed to TestConnectionId(). +uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id); + +enum : uint16_t { kTestPort = 12345 }; +enum : uint32_t { + kMaxDatagramFrameSizeForTest = 1333, + kMaxPacketSizeForTest = 9001, + kInitialStreamFlowControlWindowForTest = 1024 * 1024, // 1 MB + kInitialSessionFlowControlWindowForTest = 1536 * 1024, // 1.5 MB +}; + +enum : uint64_t { + kAckDelayExponentForTest = 10, + kMaxAckDelayForTest = 51, // ms + kActiveConnectionIdLimitForTest = 52, + kMinAckDelayUsForTest = 1000 +}; + +// Create an arbitrary stateless reset token, same across multiple calls. +std::vector CreateStatelessResetTokenForTest(); + +// A hostname useful for testing, returns "test.example.org". +std::string TestHostname(); + +// A server ID useful for testing, returns test.example.org:12345. +QuicServerId TestServerId(); + +// Returns the test peer IP address. +QuicIpAddress TestPeerIPAddress(); + +// Upper limit on versions we support. +ParsedQuicVersion QuicVersionMax(); + +// Lower limit on versions we support. +ParsedQuicVersion QuicVersionMin(); + +// Disables all flags that enable QUIC versions that use TLS. +// This is only meant as a temporary measure to prevent some broken tests +// from running with TLS. +void DisableQuicVersionsWithTls(); + +// Create an encrypted packet for testing. +// If versions == nullptr, uses &AllSupportedVersions(). +// Note that the packet is encrypted with NullEncrypter, so to decrypt the +// constructed packet, the framer must be set to use NullDecrypter. +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, bool full_padding, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, + ParsedQuicVersionVector* versions, Perspective perspective); + +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, bool full_padding, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, + ParsedQuicVersionVector* versions); + +// Create an encrypted packet for testing. +// If versions == nullptr, uses &AllSupportedVersions(). +// Note that the packet is encrypted with NullEncrypter, so to decrypt the +// constructed packet, the framer must be set to use NullDecrypter. +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, + ParsedQuicVersionVector* versions); + +// This form assumes |versions| == nullptr. +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length); + +// This form assumes |connection_id_length| == PACKET_8BYTE_CONNECTION_ID, +// |packet_number_length| == PACKET_4BYTE_PACKET_NUMBER and +// |versions| == nullptr. +QuicEncryptedPacket* ConstructEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data); + +// Creates a client-to-server ZERO-RTT packet that will fail to decrypt. +std::unique_ptr GetUndecryptableEarlyPacket( + const ParsedQuicVersion& version, + const QuicConnectionId& server_connection_id); + +// Constructs a received packet for testing. The caller must take ownership +// of the returned pointer. +QuicReceivedPacket* ConstructReceivedPacket( + const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time); + +// Create an encrypted packet for testing whose data portion erroneous. +// The specific way the data portion is erroneous is not specified, but +// it is an error that QuicFramer detects. +// Note that the packet is encrypted with NullEncrypter, so to decrypt the +// constructed packet, the framer must be set to use NullDecrypter. +QuicEncryptedPacket* ConstructMisFramedEncryptedPacket( + QuicConnectionId destination_connection_id, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, + QuicConnectionIdIncluded destination_connection_id_included, + QuicConnectionIdIncluded source_connection_id_included, + QuicPacketNumberLength packet_number_length, ParsedQuicVersion version, + Perspective perspective); + +// Returns QuicConfig set to default values. +QuicConfig DefaultQuicConfig(); + +ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version); + +struct QuicAckBlock { + QuicPacketNumber start; // Included + QuicPacketNumber limit; // Excluded +}; + +// Testing convenience method to construct a QuicAckFrame with arbitrary ack +// blocks. Each block is given by a (closed-open) range of packet numbers. e.g.: +// InitAckFrame({{1, 10}}) +// => 1 ack block acking packet numbers 1 to 9. +// +// InitAckFrame({{1, 2}, {3, 4}}) +// => 2 ack blocks acking packet 1 and 3. Packet 2 is missing. +QuicAckFrame InitAckFrame(const std::vector& ack_blocks); + +// Testing convenience method to construct a QuicAckFrame with 1 ack block which +// covers packet number range [1, |largest_acked| + 1). +// Equivalent to InitAckFrame({{1, largest_acked + 1}}) +QuicAckFrame InitAckFrame(uint64_t largest_acked); +QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked); + +// Testing convenience method to construct a QuicAckFrame with |num_ack_blocks| +// ack blocks of width 1 packet, starting from |least_unacked| + 2. +QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks, + uint64_t least_unacked); + +// Testing convenice method to construct a QuicAckFrame with |largest_acked|, +// ack blocks of width 1 packet and |gap_size|. +QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps, + uint64_t largest_acked); + +// Returns the encryption level that corresponds to the header type in +// |header|. If the header is for GOOGLE_QUIC_PACKET instead of an +// IETF-invariants packet, this function returns ENCRYPTION_INITIAL. +EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header); + +// Returns a QuicPacket that is owned by the caller, and +// is populated with the fields in |header| and |frames|, or is nullptr if the +// packet could not be created. +std::unique_ptr BuildUnsizedDataPacket( + QuicFramer* framer, const QuicPacketHeader& header, + const QuicFrames& frames); +// Returns a QuicPacket that is owned by the caller, and of size |packet_size|. +std::unique_ptr BuildUnsizedDataPacket( + QuicFramer* framer, const QuicPacketHeader& header, + const QuicFrames& frames, size_t packet_size); + +// Compute SHA-1 hash of the supplied std::string. +std::string Sha1Hash(absl::string_view data); + +// Delete |frame| and return true. +bool ClearControlFrame(const QuicFrame& frame); +bool ClearControlFrameWithTransmissionType(const QuicFrame& frame, + TransmissionType type); + +// Simple random number generator used to compute random numbers suitable +// for pseudo-randomly dropping packets in tests. +class SimpleRandom : public QuicRandom { + public: + SimpleRandom() { set_seed(0); } + SimpleRandom(const SimpleRandom&) = delete; + SimpleRandom& operator=(const SimpleRandom&) = delete; + ~SimpleRandom() override {} + + // Generates |len| random bytes in the |data| buffer. + void RandBytes(void* data, size_t len) override; + // Returns a random number in the range [0, kuint64max]. + uint64_t RandUint64() override; + + // InsecureRandBytes behaves equivalently to RandBytes. + void InsecureRandBytes(void* data, size_t len) override; + // InsecureRandUint64 behaves equivalently to RandUint64. + uint64_t InsecureRandUint64() override; + + void set_seed(uint64_t seed); + + private: + uint8_t buffer_[4096]; + size_t buffer_offset_ = 0; + uint8_t key_[32]; + + void FillBuffer(); +}; + +class MockFramerVisitor : public QuicFramerVisitorInterface { + public: + MockFramerVisitor(); + MockFramerVisitor(const MockFramerVisitor&) = delete; + MockFramerVisitor& operator=(const MockFramerVisitor&) = delete; + ~MockFramerVisitor() override; + + MOCK_METHOD(void, OnError, (QuicFramer*), (override)); + // The constructor sets this up to return false by default. + MOCK_METHOD(bool, OnProtocolVersionMismatch, (ParsedQuicVersion version), + (override)); + MOCK_METHOD(void, OnPacket, (), (override)); + MOCK_METHOD(void, OnPublicResetPacket, (const QuicPublicResetPacket& header), + (override)); + MOCK_METHOD(void, OnVersionNegotiationPacket, + (const QuicVersionNegotiationPacket& packet), (override)); + MOCK_METHOD(void, OnRetryPacket, + (QuicConnectionId original_connection_id, + QuicConnectionId new_connection_id, + absl::string_view retry_token, + absl::string_view retry_integrity_tag, + absl::string_view retry_without_tag), + (override)); + // The constructor sets this up to return true by default. + MOCK_METHOD(bool, OnUnauthenticatedHeader, (const QuicPacketHeader& header), + (override)); + // The constructor sets this up to return true by default. + MOCK_METHOD(bool, OnUnauthenticatedPublicHeader, + (const QuicPacketHeader& header), (override)); + MOCK_METHOD(void, OnDecryptedPacket, (size_t length, EncryptionLevel level), + (override)); + MOCK_METHOD(bool, OnPacketHeader, (const QuicPacketHeader& header), + (override)); + MOCK_METHOD(void, OnCoalescedPacket, (const QuicEncryptedPacket& packet), + (override)); + MOCK_METHOD(void, OnUndecryptablePacket, + (const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, bool has_decryption_key), + (override)); + MOCK_METHOD(bool, OnStreamFrame, (const QuicStreamFrame& frame), (override)); + MOCK_METHOD(bool, OnCryptoFrame, (const QuicCryptoFrame& frame), (override)); + MOCK_METHOD(bool, OnAckFrameStart, (QuicPacketNumber, QuicTime::Delta), + (override)); + MOCK_METHOD(bool, OnAckRange, (QuicPacketNumber, QuicPacketNumber), + (override)); + MOCK_METHOD(bool, OnAckTimestamp, (QuicPacketNumber, QuicTime), (override)); + MOCK_METHOD(bool, OnAckFrameEnd, + (QuicPacketNumber, const absl::optional&), + (override)); + MOCK_METHOD(bool, OnStopWaitingFrame, (const QuicStopWaitingFrame& frame), + (override)); + MOCK_METHOD(bool, OnPaddingFrame, (const QuicPaddingFrame& frame), + (override)); + MOCK_METHOD(bool, OnPingFrame, (const QuicPingFrame& frame), (override)); + MOCK_METHOD(bool, OnRstStreamFrame, (const QuicRstStreamFrame& frame), + (override)); + MOCK_METHOD(bool, OnConnectionCloseFrame, + (const QuicConnectionCloseFrame& frame), (override)); + MOCK_METHOD(bool, OnNewConnectionIdFrame, + (const QuicNewConnectionIdFrame& frame), (override)); + MOCK_METHOD(bool, OnRetireConnectionIdFrame, + (const QuicRetireConnectionIdFrame& frame), (override)); + MOCK_METHOD(bool, OnNewTokenFrame, (const QuicNewTokenFrame& frame), + (override)); + MOCK_METHOD(bool, OnStopSendingFrame, (const QuicStopSendingFrame& frame), + (override)); + MOCK_METHOD(bool, OnPathChallengeFrame, (const QuicPathChallengeFrame& frame), + (override)); + MOCK_METHOD(bool, OnPathResponseFrame, (const QuicPathResponseFrame& frame), + (override)); + MOCK_METHOD(bool, OnGoAwayFrame, (const QuicGoAwayFrame& frame), (override)); + MOCK_METHOD(bool, OnMaxStreamsFrame, (const QuicMaxStreamsFrame& frame), + (override)); + MOCK_METHOD(bool, OnStreamsBlockedFrame, + (const QuicStreamsBlockedFrame& frame), (override)); + MOCK_METHOD(bool, OnWindowUpdateFrame, (const QuicWindowUpdateFrame& frame), + (override)); + MOCK_METHOD(bool, OnBlockedFrame, (const QuicBlockedFrame& frame), + (override)); + MOCK_METHOD(bool, OnMessageFrame, (const QuicMessageFrame& frame), + (override)); + MOCK_METHOD(bool, OnHandshakeDoneFrame, (const QuicHandshakeDoneFrame& frame), + (override)); + MOCK_METHOD(bool, OnAckFrequencyFrame, (const QuicAckFrequencyFrame& frame), + (override)); + MOCK_METHOD(void, OnPacketComplete, (), (override)); + MOCK_METHOD(bool, IsValidStatelessResetToken, (const StatelessResetToken&), + (const, override)); + MOCK_METHOD(void, OnAuthenticatedIetfStatelessResetPacket, + (const QuicIetfStatelessResetPacket&), (override)); + MOCK_METHOD(void, OnKeyUpdate, (KeyUpdateReason), (override)); + MOCK_METHOD(void, OnDecryptedFirstPacketInKeyPhase, (), (override)); + MOCK_METHOD(std::unique_ptr, + AdvanceKeysAndCreateCurrentOneRttDecrypter, (), (override)); + MOCK_METHOD(std::unique_ptr, CreateCurrentOneRttEncrypter, (), + (override)); +}; + +class NoOpFramerVisitor : public QuicFramerVisitorInterface { + public: + NoOpFramerVisitor() {} + NoOpFramerVisitor(const NoOpFramerVisitor&) = delete; + NoOpFramerVisitor& operator=(const NoOpFramerVisitor&) = delete; + + void OnError(QuicFramer* /*framer*/) override {} + void OnPacket() override {} + void OnPublicResetPacket(const QuicPublicResetPacket& /*packet*/) override {} + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& /*packet*/) override {} + void OnRetryPacket(QuicConnectionId /*original_connection_id*/, + QuicConnectionId /*new_connection_id*/, + absl::string_view /*retry_token*/, + absl::string_view /*retry_integrity_tag*/, + absl::string_view /*retry_without_tag*/) override {} + bool OnProtocolVersionMismatch(ParsedQuicVersion version) override; + bool OnUnauthenticatedHeader(const QuicPacketHeader& header) override; + bool OnUnauthenticatedPublicHeader(const QuicPacketHeader& header) override; + void OnDecryptedPacket(size_t /*length*/, + EncryptionLevel /*level*/) override {} + bool OnPacketHeader(const QuicPacketHeader& header) override; + void OnCoalescedPacket(const QuicEncryptedPacket& packet) override; + void OnUndecryptablePacket(const QuicEncryptedPacket& packet, + EncryptionLevel decryption_level, + bool has_decryption_key) override; + bool OnStreamFrame(const QuicStreamFrame& frame) override; + bool OnCryptoFrame(const QuicCryptoFrame& frame) override; + bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) override; + bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) override; + bool OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) override; + bool OnAckFrameEnd(QuicPacketNumber start, + const absl::optional& ecn_counts) override; + bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) override; + bool OnPaddingFrame(const QuicPaddingFrame& frame) override; + bool OnPingFrame(const QuicPingFrame& frame) override; + bool OnRstStreamFrame(const QuicRstStreamFrame& frame) override; + bool OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override; + bool OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame) override; + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) override; + bool OnNewTokenFrame(const QuicNewTokenFrame& frame) override; + bool OnStopSendingFrame(const QuicStopSendingFrame& frame) override; + bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) override; + bool OnPathResponseFrame(const QuicPathResponseFrame& frame) override; + bool OnGoAwayFrame(const QuicGoAwayFrame& frame) override; + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override; + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override; + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override; + bool OnBlockedFrame(const QuicBlockedFrame& frame) override; + bool OnMessageFrame(const QuicMessageFrame& frame) override; + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) override; + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) override; + void OnPacketComplete() override {} + bool IsValidStatelessResetToken( + const StatelessResetToken& token) const override; + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& /*packet*/) override {} + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} + void OnDecryptedFirstPacketInKeyPhase() override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } +}; + +class MockQuicConnectionVisitor : public QuicConnectionVisitorInterface { + public: + MockQuicConnectionVisitor(); + MockQuicConnectionVisitor(const MockQuicConnectionVisitor&) = delete; + MockQuicConnectionVisitor& operator=(const MockQuicConnectionVisitor&) = + delete; + ~MockQuicConnectionVisitor() override; + + MOCK_METHOD(void, OnStreamFrame, (const QuicStreamFrame& frame), (override)); + MOCK_METHOD(void, OnCryptoFrame, (const QuicCryptoFrame& frame), (override)); + MOCK_METHOD(void, OnWindowUpdateFrame, (const QuicWindowUpdateFrame& frame), + (override)); + MOCK_METHOD(void, OnBlockedFrame, (const QuicBlockedFrame& frame), + (override)); + MOCK_METHOD(void, OnRstStream, (const QuicRstStreamFrame& frame), (override)); + MOCK_METHOD(void, OnGoAway, (const QuicGoAwayFrame& frame), (override)); + MOCK_METHOD(void, OnMessageReceived, (absl::string_view message), (override)); + MOCK_METHOD(void, OnHandshakeDoneReceived, (), (override)); + MOCK_METHOD(void, OnNewTokenReceived, (absl::string_view token), (override)); + MOCK_METHOD(void, OnConnectionClosed, + (const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source), + (override)); + MOCK_METHOD(void, OnWriteBlocked, (), (override)); + MOCK_METHOD(void, OnCanWrite, (), (override)); + MOCK_METHOD(void, OnCongestionWindowChange, (QuicTime now), (override)); + MOCK_METHOD(void, OnConnectionMigration, (AddressChangeType type), + (override)); + MOCK_METHOD(void, OnPathDegrading, (), (override)); + MOCK_METHOD(void, OnForwardProgressMadeAfterPathDegrading, (), (override)); + MOCK_METHOD(bool, WillingAndAbleToWrite, (), (const, override)); + MOCK_METHOD(bool, ShouldKeepConnectionAlive, (), (const, override)); + MOCK_METHOD(std::string, GetStreamsInfoForLogging, (), (const, override)); + MOCK_METHOD(void, OnSuccessfulVersionNegotiation, + (const ParsedQuicVersion& version), (override)); + MOCK_METHOD(void, OnPacketReceived, + (const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + bool is_connectivity_probe), + (override)); + MOCK_METHOD(void, OnAckNeedsRetransmittableFrame, (), (override)); + MOCK_METHOD(void, SendAckFrequency, (const QuicAckFrequencyFrame& frame), + (override)); + MOCK_METHOD(void, SendNewConnectionId, + (const QuicNewConnectionIdFrame& frame), (override)); + MOCK_METHOD(void, SendRetireConnectionId, (uint64_t sequence_number), + (override)); + MOCK_METHOD(bool, MaybeReserveConnectionId, + (const QuicConnectionId& server_connection_id), (override)); + MOCK_METHOD(void, OnServerConnectionIdRetired, + (const QuicConnectionId& server_connection_id), (override)); + MOCK_METHOD(bool, AllowSelfAddressChange, (), (const, override)); + MOCK_METHOD(HandshakeState, GetHandshakeState, (), (const, override)); + MOCK_METHOD(bool, OnMaxStreamsFrame, (const QuicMaxStreamsFrame& frame), + (override)); + MOCK_METHOD(bool, OnStreamsBlockedFrame, + (const QuicStreamsBlockedFrame& frame), (override)); + MOCK_METHOD(void, OnStopSendingFrame, (const QuicStopSendingFrame& frame), + (override)); + MOCK_METHOD(void, OnPacketDecrypted, (EncryptionLevel), (override)); + MOCK_METHOD(void, OnOneRttPacketAcknowledged, (), (override)); + MOCK_METHOD(void, OnHandshakePacketSent, (), (override)); + MOCK_METHOD(void, OnKeyUpdate, (KeyUpdateReason), (override)); + MOCK_METHOD(std::unique_ptr, + AdvanceKeysAndCreateCurrentOneRttDecrypter, (), (override)); + MOCK_METHOD(std::unique_ptr, CreateCurrentOneRttEncrypter, (), + (override)); + MOCK_METHOD(void, BeforeConnectionCloseSent, (), (override)); + MOCK_METHOD(bool, ValidateToken, (absl::string_view), (override)); + MOCK_METHOD(bool, MaybeSendAddressToken, (), (override)); + MOCK_METHOD(std::unique_ptr, + CreateContextForMultiPortPath, (), (override)); + MOCK_METHOD(void, MigrateToMultiPortPath, + (std::unique_ptr), (override)); + MOCK_METHOD(void, OnServerPreferredAddressAvailable, + (const QuicSocketAddress&), (override)); + void OnBandwidthUpdateTimeout() override {} +}; + +class MockQuicConnectionHelper : public QuicConnectionHelperInterface { + public: + MockQuicConnectionHelper(); + MockQuicConnectionHelper(const MockQuicConnectionHelper&) = delete; + MockQuicConnectionHelper& operator=(const MockQuicConnectionHelper&) = delete; + ~MockQuicConnectionHelper() override; + const QuicClock* GetClock() const override; + QuicClock* GetClock(); + QuicRandom* GetRandomGenerator() override; + quiche::QuicheBufferAllocator* GetStreamSendBufferAllocator() override; + void AdvanceTime(QuicTime::Delta delta); + + private: + MockClock clock_; + testing::NiceMock random_generator_; + quiche::SimpleBufferAllocator buffer_allocator_; +}; + +class MockAlarmFactory : public QuicAlarmFactory { + public: + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override; + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override; + + // No-op alarm implementation + class TestAlarm : public QuicAlarm { + public: + explicit TestAlarm(QuicArenaScopedPtr delegate) + : QuicAlarm(std::move(delegate)) {} + + void SetImpl() override {} + void CancelImpl() override {} + + using QuicAlarm::Fire; + }; + + void FireAlarm(QuicAlarm* alarm) { + reinterpret_cast(alarm)->Fire(); + } +}; + +class TestAlarmFactory : public QuicAlarmFactory { + public: + class TestAlarm : public QuicAlarm { + public: + explicit TestAlarm(QuicArenaScopedPtr delegate) + : QuicAlarm(std::move(delegate)) {} + + void SetImpl() override {} + void CancelImpl() override {} + using QuicAlarm::Fire; + }; + + TestAlarmFactory() {} + TestAlarmFactory(const TestAlarmFactory&) = delete; + TestAlarmFactory& operator=(const TestAlarmFactory&) = delete; + + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override { + return new TestAlarm(QuicArenaScopedPtr(delegate)); + } + + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override { + return arena->New(std::move(delegate)); + } +}; + +class MockQuicConnection : public QuicConnection { + public: + // Uses a ConnectionId of 42 and 127.0.0.1:123. + MockQuicConnection(QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, Perspective perspective); + + // Uses a ConnectionId of 42. + MockQuicConnection(QuicSocketAddress address, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, Perspective perspective); + + // Uses 127.0.0.1:123. + MockQuicConnection(QuicConnectionId connection_id, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, Perspective perspective); + + // Uses a ConnectionId of 42, and 127.0.0.1:123. + MockQuicConnection(QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, Perspective perspective, + const ParsedQuicVersionVector& supported_versions); + + MockQuicConnection(QuicConnectionId connection_id, QuicSocketAddress address, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, Perspective perspective, + const ParsedQuicVersionVector& supported_versions); + MockQuicConnection(const MockQuicConnection&) = delete; + MockQuicConnection& operator=(const MockQuicConnection&) = delete; + + ~MockQuicConnection() override; + + // If the constructor that uses a QuicConnectionHelperInterface has been used + // then this method will advance the time of the MockClock. + void AdvanceTime(QuicTime::Delta delta); + + MOCK_METHOD(void, ProcessUdpPacket, + (const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet), + (override)); + MOCK_METHOD(void, CloseConnection, + (QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior), + (override)); + MOCK_METHOD(void, CloseConnection, + (QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details, + ConnectionCloseBehavior connection_close_behavior), + (override)); + MOCK_METHOD(void, SendConnectionClosePacket, + (QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details), + (override)); + MOCK_METHOD(void, OnCanWrite, (), (override)); + MOCK_METHOD(bool, SendConnectivityProbingPacket, + (QuicPacketWriter*, const QuicSocketAddress& peer_address), + (override)); + MOCK_METHOD(void, MaybeProbeMultiPortPath, (), (override)); + + MOCK_METHOD(void, OnSendConnectionState, (const CachedNetworkParameters&), + (override)); + MOCK_METHOD(void, ResumeConnectionState, + (const CachedNetworkParameters&, bool), (override)); + MOCK_METHOD(void, SetMaxPacingRate, (QuicBandwidth), (override)); + + MOCK_METHOD(void, OnStreamReset, (QuicStreamId, QuicRstStreamErrorCode), + (override)); + MOCK_METHOD(bool, SendControlFrame, (const QuicFrame& frame), (override)); + MOCK_METHOD(MessageStatus, SendMessage, + (QuicMessageId, absl::Span, bool), + (override)); + MOCK_METHOD(bool, SendPathChallenge, + (const QuicPathFrameBuffer&, const QuicSocketAddress&, + const QuicSocketAddress&, const QuicSocketAddress&, + QuicPacketWriter*), + (override)); + + MOCK_METHOD(void, OnError, (QuicFramer*), (override)); + void QuicConnection_OnError(QuicFramer* framer) { + QuicConnection::OnError(framer); + } + + void ReallyOnCanWrite() { QuicConnection::OnCanWrite(); } + + void ReallyCloseConnection( + QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) { + // Call the 4-param method directly instead of the 3-param method, so that + // it doesn't invoke the virtual 4-param method causing the mock 4-param + // method to trigger. + QuicConnection::CloseConnection(error, NO_IETF_QUIC_ERROR, details, + connection_close_behavior); + } + + void ReallyCloseConnection4( + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, + const std::string& details, + ConnectionCloseBehavior connection_close_behavior) { + QuicConnection::CloseConnection(error, ietf_error, details, + connection_close_behavior); + } + + void ReallySendConnectionClosePacket(QuicErrorCode error, + QuicIetfTransportErrorCodes ietf_error, + const std::string& details) { + QuicConnection::SendConnectionClosePacket(error, ietf_error, details); + } + + void ReallyProcessUdpPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) { + QuicConnection::ProcessUdpPacket(self_address, peer_address, packet); + } + + bool OnProtocolVersionMismatch(ParsedQuicVersion version) override; + void OnIdleNetworkDetected() override {} + + bool ReallySendControlFrame(const QuicFrame& frame) { + return QuicConnection::SendControlFrame(frame); + } + + bool ReallySendConnectivityProbingPacket( + QuicPacketWriter* probing_writer, const QuicSocketAddress& peer_address) { + return QuicConnection::SendConnectivityProbingPacket(probing_writer, + peer_address); + } + + bool ReallyOnPathResponseFrame(const QuicPathResponseFrame& frame) { + return QuicConnection::OnPathResponseFrame(frame); + } + + MOCK_METHOD(bool, OnPathResponseFrame, (const QuicPathResponseFrame&), + (override)); + MOCK_METHOD(bool, OnStopSendingFrame, (const QuicStopSendingFrame& frame), + (override)); + MOCK_METHOD(size_t, SendCryptoData, + (EncryptionLevel, size_t, QuicStreamOffset), (override)); + size_t QuicConnection_SendCryptoData(EncryptionLevel level, + size_t write_length, + QuicStreamOffset offset) { + return QuicConnection::SendCryptoData(level, write_length, offset); + } + + MockConnectionIdGenerator& connection_id_generator() { + return connection_id_generator_; + } + + private: + // It would be more correct to pass the generator as an argument to the + // constructor, particularly in dispatcher tests that keep their own + // reference to a generator. But there are many, many instances of derived + // test classes that would have to declare a generator. As this object is + // public, it is straightforward for the caller to use it as an argument to + // EXPECT_CALL. + MockConnectionIdGenerator connection_id_generator_; +}; + +class PacketSavingConnection : public MockQuicConnection { + public: + PacketSavingConnection(QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + Perspective perspective); + + PacketSavingConnection(QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + Perspective perspective, + const ParsedQuicVersionVector& supported_versions); + PacketSavingConnection(const PacketSavingConnection&) = delete; + PacketSavingConnection& operator=(const PacketSavingConnection&) = delete; + + ~PacketSavingConnection() override; + + SerializedPacketFate GetSerializedPacketFate( + bool is_mtu_discovery, EncryptionLevel encryption_level) override; + + void SendOrQueuePacket(SerializedPacket packet) override; + + MOCK_METHOD(void, OnPacketSent, (EncryptionLevel, TransmissionType)); + + std::vector> encrypted_packets_; + // Number of packets in encrypted_packets that has been delivered to the peer + // connection. + size_t number_of_packets_delivered_ = 0; + MockClock clock_; +}; + +class MockQuicSession : public QuicSession { + public: + // Takes ownership of |connection|. + MockQuicSession(QuicConnection* connection, bool create_mock_crypto_stream); + + // Takes ownership of |connection|. + explicit MockQuicSession(QuicConnection* connection); + MockQuicSession(const MockQuicSession&) = delete; + MockQuicSession& operator=(const MockQuicSession&) = delete; + ~MockQuicSession() override; + + QuicCryptoStream* GetMutableCryptoStream() override; + const QuicCryptoStream* GetCryptoStream() const override; + void SetCryptoStream(QuicCryptoStream* crypto_stream); + + MOCK_METHOD(void, OnConnectionClosed, + (const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source), + (override)); + MOCK_METHOD(QuicStream*, CreateIncomingStream, (QuicStreamId id), (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), + (override)); + MOCK_METHOD(QuicConsumedData, WritevData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + EncryptionLevel level), + (override)); + MOCK_METHOD(bool, WriteControlFrame, + (const QuicFrame& frame, TransmissionType type), (override)); + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written), + (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); + MOCK_METHOD(void, SendBlocked, + (QuicStreamId stream_id, QuicStreamOffset offset), (override)); + + MOCK_METHOD(bool, ShouldKeepConnectionAlive, (), (const, override)); + MOCK_METHOD(std::vector, GetAlpnsToOffer, (), (const, override)); + MOCK_METHOD(std::vector::const_iterator, SelectAlpn, + (const std::vector&), (const, override)); + MOCK_METHOD(void, OnAlpnSelected, (absl::string_view), (override)); + + using QuicSession::ActivateStream; + + // Returns a QuicConsumedData that indicates all of |write_length| (and |fin| + // if set) has been consumed. + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + absl::optional level); + + void ReallyMaybeSendRstStreamFrame(QuicStreamId id, + QuicRstStreamErrorCode error, + QuicStreamOffset bytes_written) { + QuicSession::MaybeSendRstStreamFrame( + id, QuicResetStreamError::FromInternal(error), bytes_written); + } + + private: + std::unique_ptr crypto_stream_; +}; + +class MockQuicCryptoStream : public QuicCryptoStream { + public: + explicit MockQuicCryptoStream(QuicSession* session); + + ~MockQuicCryptoStream() override; + + ssl_early_data_reason_t EarlyDataReason() const override; + bool encryption_established() const override; + bool one_rtt_keys_available() const override; + const QuicCryptoNegotiatedParameters& crypto_negotiated_params() + const override; + CryptoMessageParser* crypto_message_parser() override; + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnHandshakeDoneReceived() override {} + void OnNewTokenReceived(absl::string_view /*token*/) override {} + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } + bool ValidateAddressToken(absl::string_view /*token*/) const override { + return true; + } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} + void OnConnectionClosed(QuicErrorCode /*error*/, + ConnectionCloseSource /*source*/) override {} + HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; } + void SetServerApplicationStateForResumption( + std::unique_ptr /*application_state*/) override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + SSL* GetSsl() const override { return nullptr; } + bool IsCryptoFrameExpectedForEncryptionLevel( + quic::EncryptionLevel level) const override { + return level != ENCRYPTION_ZERO_RTT; + } + EncryptionLevel GetEncryptionLevelToSendCryptoDataOfSpace( + PacketNumberSpace space) const override { + switch (space) { + case INITIAL_DATA: + return ENCRYPTION_INITIAL; + case HANDSHAKE_DATA: + return ENCRYPTION_HANDSHAKE; + case APPLICATION_DATA: + return ENCRYPTION_FORWARD_SECURE; + default: + QUICHE_DCHECK(false); + return NUM_ENCRYPTION_LEVELS; + } + } + + private: + quiche::QuicheReferenceCountedPointer params_; + CryptoFramer crypto_framer_; +}; + +class MockQuicSpdySession : public QuicSpdySession { + public: + // Takes ownership of |connection|. + explicit MockQuicSpdySession(QuicConnection* connection); + // Takes ownership of |connection|. + MockQuicSpdySession(QuicConnection* connection, + bool create_mock_crypto_stream); + MockQuicSpdySession(const MockQuicSpdySession&) = delete; + MockQuicSpdySession& operator=(const MockQuicSpdySession&) = delete; + ~MockQuicSpdySession() override; + + QuicCryptoStream* GetMutableCryptoStream() override; + const QuicCryptoStream* GetCryptoStream() const override; + void SetCryptoStream(QuicCryptoStream* crypto_stream); + + void ReallyOnConnectionClosed(const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) { + QuicSession::OnConnectionClosed(frame, source); + } + + // From QuicSession. + MOCK_METHOD(void, OnConnectionClosed, + (const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), + (override)); + MOCK_METHOD(bool, ShouldCreateIncomingStream, (QuicStreamId id), (override)); + MOCK_METHOD(bool, ShouldCreateOutgoingBidirectionalStream, (), (override)); + MOCK_METHOD(bool, ShouldCreateOutgoingUnidirectionalStream, (), (override)); + MOCK_METHOD(QuicConsumedData, WritevData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + EncryptionLevel level), + (override)); + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written), + (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); + MOCK_METHOD(void, SendWindowUpdate, + (QuicStreamId id, QuicStreamOffset byte_offset), (override)); + MOCK_METHOD(void, SendBlocked, + (QuicStreamId id, QuicStreamOffset byte_offset), (override)); + MOCK_METHOD(void, OnStreamHeadersPriority, + (QuicStreamId stream_id, + const spdy::SpdyStreamPrecedence& precedence), + (override)); + MOCK_METHOD(void, OnStreamHeaderList, + (QuicStreamId stream_id, bool fin, size_t frame_len, + const QuicHeaderList& header_list), + (override)); + MOCK_METHOD(void, OnPromiseHeaderList, + (QuicStreamId stream_id, QuicStreamId promised_stream_id, + size_t frame_len, const QuicHeaderList& header_list), + (override)); + MOCK_METHOD(void, OnPriorityFrame, + (QuicStreamId id, const spdy::SpdyStreamPrecedence& precedence), + (override)); + MOCK_METHOD(void, OnCongestionWindowChange, (QuicTime now), (override)); + + // Returns a QuicConsumedData that indicates all of |write_length| (and |fin| + // if set) has been consumed. + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + absl::optional level); + + using QuicSession::ActivateStream; + + private: + std::unique_ptr crypto_stream_; +}; + +class MockHttp3DebugVisitor : public Http3DebugVisitor { + public: + MOCK_METHOD(void, OnControlStreamCreated, (QuicStreamId), (override)); + MOCK_METHOD(void, OnQpackEncoderStreamCreated, (QuicStreamId), (override)); + MOCK_METHOD(void, OnQpackDecoderStreamCreated, (QuicStreamId), (override)); + MOCK_METHOD(void, OnPeerControlStreamCreated, (QuicStreamId), (override)); + MOCK_METHOD(void, OnPeerQpackEncoderStreamCreated, (QuicStreamId), + (override)); + MOCK_METHOD(void, OnPeerQpackDecoderStreamCreated, (QuicStreamId), + (override)); + + MOCK_METHOD(void, OnSettingsFrameReceivedViaAlps, (const SettingsFrame&), + (override)); + + MOCK_METHOD(void, OnAcceptChFrameReceivedViaAlps, (const AcceptChFrame&), + (override)); + + MOCK_METHOD(void, OnSettingsFrameReceived, (const SettingsFrame&), + (override)); + MOCK_METHOD(void, OnGoAwayFrameReceived, (const GoAwayFrame&), (override)); + MOCK_METHOD(void, OnPriorityUpdateFrameReceived, (const PriorityUpdateFrame&), + (override)); + MOCK_METHOD(void, OnAcceptChFrameReceived, (const AcceptChFrame&), + (override)); + + MOCK_METHOD(void, OnDataFrameReceived, (QuicStreamId, QuicByteCount), + (override)); + MOCK_METHOD(void, OnHeadersFrameReceived, (QuicStreamId, QuicByteCount), + (override)); + MOCK_METHOD(void, OnHeadersDecoded, (QuicStreamId, QuicHeaderList), + (override)); + MOCK_METHOD(void, OnUnknownFrameReceived, + (QuicStreamId, uint64_t, QuicByteCount), (override)); + + MOCK_METHOD(void, OnSettingsFrameSent, (const SettingsFrame&), (override)); + MOCK_METHOD(void, OnGoAwayFrameSent, (QuicStreamId), (override)); + MOCK_METHOD(void, OnPriorityUpdateFrameSent, (const PriorityUpdateFrame&), + (override)); + + MOCK_METHOD(void, OnDataFrameSent, (QuicStreamId, QuicByteCount), (override)); + MOCK_METHOD(void, OnHeadersFrameSent, + (QuicStreamId, const spdy::Http2HeaderBlock&), (override)); + MOCK_METHOD(void, OnSettingsFrameResumed, (const SettingsFrame&), (override)); +}; + +class TestQuicSpdyServerSession : public QuicServerSessionBase { + public: + // Takes ownership of |connection|. + TestQuicSpdyServerSession(QuicConnection* connection, + const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache); + TestQuicSpdyServerSession(const TestQuicSpdyServerSession&) = delete; + TestQuicSpdyServerSession& operator=(const TestQuicSpdyServerSession&) = + delete; + ~TestQuicSpdyServerSession() override; + + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), + (override)); + MOCK_METHOD(std::vector::const_iterator, SelectAlpn, + (const std::vector&), (const, override)); + MOCK_METHOD(void, OnAlpnSelected, (absl::string_view), (override)); + std::unique_ptr CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) override; + + QuicCryptoServerStreamBase* GetMutableCryptoStream() override; + + const QuicCryptoServerStreamBase* GetCryptoStream() const override; + + MockQuicCryptoServerStreamHelper* helper() { return &helper_; } + + QuicSSLConfig GetSSLConfig() const override { + QuicSSLConfig ssl_config = QuicServerSessionBase::GetSSLConfig(); + if (early_data_enabled_.has_value()) { + ssl_config.early_data_enabled = *early_data_enabled_; + } + if (client_cert_mode_.has_value()) { + ssl_config.client_cert_mode = *client_cert_mode_; + } + + return ssl_config; + } + + void set_early_data_enabled(bool enabled) { early_data_enabled_ = enabled; } + + void set_client_cert_mode(ClientCertMode mode) { client_cert_mode_ = mode; } + + private: + MockQuicSessionVisitor visitor_; + MockQuicCryptoServerStreamHelper helper_; + // If not nullopt, override the early_data_enabled value from base class' + // ssl_config. + absl::optional early_data_enabled_; + // If not nullopt, override the client_cert_mode value from base class' + // ssl_config. + absl::optional client_cert_mode_; +}; + +// A test implementation of QuicClientPushPromiseIndex::Delegate. +class TestPushPromiseDelegate : public QuicClientPushPromiseIndex::Delegate { + public: + // |match| sets the validation result for checking whether designated header + // fields match for promise request and client request. + explicit TestPushPromiseDelegate(bool match); + + bool CheckVary(const spdy::Http2HeaderBlock& client_request, + const spdy::Http2HeaderBlock& promise_request, + const spdy::Http2HeaderBlock& promise_response) override; + + void OnRendezvousResult(QuicSpdyStream* stream) override; + + QuicSpdyStream* rendezvous_stream() { return rendezvous_stream_; } + bool rendezvous_fired() { return rendezvous_fired_; } + + private: + bool match_; + bool rendezvous_fired_; + QuicSpdyStream* rendezvous_stream_; +}; + +class TestQuicSpdyClientSession : public QuicSpdyClientSessionBase { + public: + TestQuicSpdyClientSession( + QuicConnection* connection, const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, + absl::optional ssl_config = absl::nullopt); + TestQuicSpdyClientSession(const TestQuicSpdyClientSession&) = delete; + TestQuicSpdyClientSession& operator=(const TestQuicSpdyClientSession&) = + delete; + ~TestQuicSpdyClientSession() override; + + bool IsAuthorized(const std::string& authority) override; + + // QuicSpdyClientSessionBase + MOCK_METHOD(void, OnProofValid, + (const QuicCryptoClientConfig::CachedState& cached), (override)); + MOCK_METHOD(void, OnProofVerifyDetailsAvailable, + (const ProofVerifyDetails& verify_details), (override)); + + // TestQuicSpdyClientSession + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), + (override)); + MOCK_METHOD(bool, ShouldCreateIncomingStream, (QuicStreamId id), (override)); + MOCK_METHOD(bool, ShouldCreateOutgoingBidirectionalStream, (), (override)); + MOCK_METHOD(bool, ShouldCreateOutgoingUnidirectionalStream, (), (override)); + MOCK_METHOD(std::vector, GetAlpnsToOffer, (), (const, override)); + MOCK_METHOD(void, OnAlpnSelected, (absl::string_view), (override)); + MOCK_METHOD(void, OnConfigNegotiated, (), (override)); + + QuicCryptoClientStream* GetMutableCryptoStream() override; + const QuicCryptoClientStream* GetCryptoStream() const override; + + QuicSSLConfig GetSSLConfig() const override { + return ssl_config_.has_value() ? *ssl_config_ + : QuicSpdyClientSessionBase::GetSSLConfig(); + } + + // Override to save sent crypto handshake messages. + void OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message) override { + sent_crypto_handshake_messages_.push_back(message); + } + + const std::vector& sent_crypto_handshake_messages() + const { + return sent_crypto_handshake_messages_; + } + + private: + // Calls the parent class's OnConfigNegotiated method. Used to set the default + // mock behavior for OnConfigNegotiated. + void RealOnConfigNegotiated(); + + std::unique_ptr crypto_stream_; + QuicClientPushPromiseIndex push_promise_index_; + std::vector sent_crypto_handshake_messages_; + absl::optional ssl_config_; +}; + +class MockPacketWriter : public QuicPacketWriter { + public: + MockPacketWriter(); + MockPacketWriter(const MockPacketWriter&) = delete; + MockPacketWriter& operator=(const MockPacketWriter&) = delete; + ~MockPacketWriter() override; + + MOCK_METHOD(WriteResult, WritePacket, + (const char*, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions*), + (override)); + MOCK_METHOD(bool, IsWriteBlocked, (), (const, override)); + MOCK_METHOD(void, SetWritable, (), (override)); + MOCK_METHOD(absl::optional, MessageTooBigErrorCode, (), + (const, override)); + MOCK_METHOD(QuicByteCount, GetMaxPacketSize, + (const QuicSocketAddress& peer_address), (const, override)); + MOCK_METHOD(bool, SupportsReleaseTime, (), (const, override)); + MOCK_METHOD(bool, IsBatchMode, (), (const, override)); + MOCK_METHOD(QuicPacketBuffer, GetNextWriteLocation, + (const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address), + (override)); + MOCK_METHOD(WriteResult, Flush, (), (override)); +}; + +class MockSendAlgorithm : public SendAlgorithmInterface { + public: + MockSendAlgorithm(); + MockSendAlgorithm(const MockSendAlgorithm&) = delete; + MockSendAlgorithm& operator=(const MockSendAlgorithm&) = delete; + ~MockSendAlgorithm() override; + + MOCK_METHOD(void, SetFromConfig, + (const QuicConfig& config, Perspective perspective), (override)); + MOCK_METHOD(void, ApplyConnectionOptions, + (const QuicTagVector& connection_options), (override)); + MOCK_METHOD(void, SetInitialCongestionWindowInPackets, + (QuicPacketCount packets), (override)); + MOCK_METHOD(void, OnCongestionEvent, + (bool rtt_updated, QuicByteCount bytes_in_flight, + QuicTime event_time, const AckedPacketVector& acked_packets, + const LostPacketVector& lost_packets, QuicPacketCount num_ect, + QuicPacketCount num_ce), + (override)); + MOCK_METHOD(void, OnPacketSent, + (QuicTime, QuicByteCount, QuicPacketNumber, QuicByteCount, + HasRetransmittableData), + (override)); + MOCK_METHOD(void, OnPacketNeutered, (QuicPacketNumber), (override)); + MOCK_METHOD(void, OnRetransmissionTimeout, (bool), (override)); + MOCK_METHOD(void, OnConnectionMigration, (), (override)); + MOCK_METHOD(bool, CanSend, (QuicByteCount), (override)); + MOCK_METHOD(QuicBandwidth, PacingRate, (QuicByteCount), (const, override)); + MOCK_METHOD(QuicBandwidth, BandwidthEstimate, (), (const, override)); + MOCK_METHOD(bool, HasGoodBandwidthEstimateForResumption, (), + (const, override)); + MOCK_METHOD(QuicByteCount, GetCongestionWindow, (), (const, override)); + MOCK_METHOD(std::string, GetDebugState, (), (const, override)); + MOCK_METHOD(bool, InSlowStart, (), (const, override)); + MOCK_METHOD(bool, InRecovery, (), (const, override)); + MOCK_METHOD(QuicByteCount, GetSlowStartThreshold, (), (const, override)); + MOCK_METHOD(CongestionControlType, GetCongestionControlType, (), + (const, override)); + MOCK_METHOD(void, AdjustNetworkParameters, (const NetworkParams&), + (override)); + MOCK_METHOD(void, OnApplicationLimited, (QuicByteCount), (override)); + MOCK_METHOD(void, PopulateConnectionStats, (QuicConnectionStats*), + (const, override)); + MOCK_METHOD(bool, SupportsECT0, (), (const, override)); + MOCK_METHOD(bool, SupportsECT1, (), (const, override)); +}; + +class MockLossAlgorithm : public LossDetectionInterface { + public: + MockLossAlgorithm(); + MockLossAlgorithm(const MockLossAlgorithm&) = delete; + MockLossAlgorithm& operator=(const MockLossAlgorithm&) = delete; + ~MockLossAlgorithm() override; + + MOCK_METHOD(void, SetFromConfig, + (const QuicConfig& config, Perspective perspective), (override)); + + MOCK_METHOD(DetectionStats, DetectLosses, + (const QuicUnackedPacketMap& unacked_packets, QuicTime time, + const RttStats& rtt_stats, + QuicPacketNumber largest_recently_acked, + const AckedPacketVector& packets_acked, LostPacketVector*), + (override)); + MOCK_METHOD(QuicTime, GetLossTimeout, (), (const, override)); + MOCK_METHOD(void, SpuriousLossDetected, + (const QuicUnackedPacketMap&, const RttStats&, QuicTime, + QuicPacketNumber, QuicPacketNumber), + (override)); + + MOCK_METHOD(void, OnConfigNegotiated, (), (override)); + MOCK_METHOD(void, OnMinRttAvailable, (), (override)); + MOCK_METHOD(void, OnUserAgentIdKnown, (), (override)); + MOCK_METHOD(void, OnConnectionClosed, (), (override)); + MOCK_METHOD(void, OnReorderingDetected, (), (override)); +}; + +class MockAckListener : public QuicAckListenerInterface { + public: + MockAckListener(); + MockAckListener(const MockAckListener&) = delete; + MockAckListener& operator=(const MockAckListener&) = delete; + + MOCK_METHOD(void, OnPacketAcked, + (int acked_bytes, QuicTime::Delta ack_delay_time), (override)); + + MOCK_METHOD(void, OnPacketRetransmitted, (int retransmitted_bytes), + (override)); + + protected: + // Object is ref counted. + ~MockAckListener() override; +}; + +class MockNetworkChangeVisitor + : public QuicSentPacketManager::NetworkChangeVisitor { + public: + MockNetworkChangeVisitor(); + MockNetworkChangeVisitor(const MockNetworkChangeVisitor&) = delete; + MockNetworkChangeVisitor& operator=(const MockNetworkChangeVisitor&) = delete; + ~MockNetworkChangeVisitor() override; + + MOCK_METHOD(void, OnCongestionChange, (), (override)); + MOCK_METHOD(void, OnPathMtuIncreased, (QuicPacketLength), (override)); +}; + +class MockQuicConnectionDebugVisitor : public QuicConnectionDebugVisitor { + public: + MockQuicConnectionDebugVisitor(); + ~MockQuicConnectionDebugVisitor() override; + + MOCK_METHOD(void, OnPacketSent, + (QuicPacketNumber, QuicPacketLength, bool, TransmissionType, + EncryptionLevel, const QuicFrames&, const QuicFrames&, QuicTime), + (override)); + + MOCK_METHOD(void, OnCoalescedPacketSent, (const QuicCoalescedPacket&, size_t), + (override)); + + MOCK_METHOD(void, OnPingSent, (), (override)); + + MOCK_METHOD(void, OnPacketReceived, + (const QuicSocketAddress&, const QuicSocketAddress&, + const QuicEncryptedPacket&), + (override)); + + MOCK_METHOD(void, OnIncorrectConnectionId, (QuicConnectionId), (override)); + + MOCK_METHOD(void, OnProtocolVersionMismatch, (ParsedQuicVersion), (override)); + + MOCK_METHOD(void, OnPacketHeader, + (const QuicPacketHeader& header, QuicTime receive_time, + EncryptionLevel level), + (override)); + + MOCK_METHOD(void, OnSuccessfulVersionNegotiation, (const ParsedQuicVersion&), + (override)); + + MOCK_METHOD(void, OnStreamFrame, (const QuicStreamFrame&), (override)); + + MOCK_METHOD(void, OnCryptoFrame, (const QuicCryptoFrame&), (override)); + + MOCK_METHOD(void, OnStopWaitingFrame, (const QuicStopWaitingFrame&), + (override)); + + MOCK_METHOD(void, OnRstStreamFrame, (const QuicRstStreamFrame&), (override)); + + MOCK_METHOD(void, OnConnectionCloseFrame, (const QuicConnectionCloseFrame&), + (override)); + + MOCK_METHOD(void, OnBlockedFrame, (const QuicBlockedFrame&), (override)); + + MOCK_METHOD(void, OnNewConnectionIdFrame, (const QuicNewConnectionIdFrame&), + (override)); + + MOCK_METHOD(void, OnRetireConnectionIdFrame, + (const QuicRetireConnectionIdFrame&), (override)); + + MOCK_METHOD(void, OnNewTokenFrame, (const QuicNewTokenFrame&), (override)); + + MOCK_METHOD(void, OnMessageFrame, (const QuicMessageFrame&), (override)); + + MOCK_METHOD(void, OnStopSendingFrame, (const QuicStopSendingFrame&), + (override)); + + MOCK_METHOD(void, OnPathChallengeFrame, (const QuicPathChallengeFrame&), + (override)); + + MOCK_METHOD(void, OnPathResponseFrame, (const QuicPathResponseFrame&), + (override)); + + MOCK_METHOD(void, OnPublicResetPacket, (const QuicPublicResetPacket&), + (override)); + + MOCK_METHOD(void, OnVersionNegotiationPacket, + (const QuicVersionNegotiationPacket&), (override)); + + MOCK_METHOD(void, OnTransportParametersSent, (const TransportParameters&), + (override)); + + MOCK_METHOD(void, OnTransportParametersReceived, (const TransportParameters&), + (override)); + + MOCK_METHOD(void, OnZeroRttRejected, (int), (override)); + MOCK_METHOD(void, OnZeroRttPacketAcked, (), (override)); +}; + +class MockReceivedPacketManager : public QuicReceivedPacketManager { + public: + explicit MockReceivedPacketManager(QuicConnectionStats* stats); + ~MockReceivedPacketManager() override; + + MOCK_METHOD(void, RecordPacketReceived, + (const QuicPacketHeader& header, QuicTime receipt_time, + const QuicEcnCodepoint ecn), + (override)); + MOCK_METHOD(bool, IsMissing, (QuicPacketNumber packet_number), (override)); + MOCK_METHOD(bool, IsAwaitingPacket, (QuicPacketNumber packet_number), + (const, override)); + MOCK_METHOD(bool, HasNewMissingPackets, (), (const, override)); + MOCK_METHOD(bool, ack_frame_updated, (), (const, override)); +}; + +class MockPacketCreatorDelegate : public QuicPacketCreator::DelegateInterface { + public: + MockPacketCreatorDelegate(); + MockPacketCreatorDelegate(const MockPacketCreatorDelegate&) = delete; + MockPacketCreatorDelegate& operator=(const MockPacketCreatorDelegate&) = + delete; + ~MockPacketCreatorDelegate() override; + + MOCK_METHOD(QuicPacketBuffer, GetPacketBuffer, (), (override)); + MOCK_METHOD(void, OnSerializedPacket, (SerializedPacket), (override)); + MOCK_METHOD(void, OnUnrecoverableError, (QuicErrorCode, const std::string&), + (override)); + MOCK_METHOD(bool, ShouldGeneratePacket, + (HasRetransmittableData retransmittable, IsHandshake handshake), + (override)); + MOCK_METHOD(const QuicFrames, MaybeBundleAckOpportunistically, (), + (override)); + MOCK_METHOD(SerializedPacketFate, GetSerializedPacketFate, + (bool, EncryptionLevel), (override)); +}; + +class MockSessionNotifier : public SessionNotifierInterface { + public: + MockSessionNotifier(); + ~MockSessionNotifier() override; + + MOCK_METHOD(bool, OnFrameAcked, (const QuicFrame&, QuicTime::Delta, QuicTime), + (override)); + MOCK_METHOD(void, OnStreamFrameRetransmitted, (const QuicStreamFrame&), + (override)); + MOCK_METHOD(void, OnFrameLost, (const QuicFrame&), (override)); + MOCK_METHOD(bool, RetransmitFrames, + (const QuicFrames&, TransmissionType type), (override)); + MOCK_METHOD(bool, IsFrameOutstanding, (const QuicFrame&), (const, override)); + MOCK_METHOD(bool, HasUnackedCryptoData, (), (const, override)); + MOCK_METHOD(bool, HasUnackedStreamData, (), (const, override)); +}; + +class MockQuicPathValidationContext : public QuicPathValidationContext { + public: + MockQuicPathValidationContext(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& effective_peer_address, + QuicPacketWriter* writer) + : QuicPathValidationContext(self_address, peer_address, + effective_peer_address), + writer_(writer) {} + QuicPacketWriter* WriterToUse() override { return writer_; } + + private: + QuicPacketWriter* writer_; +}; + +class MockQuicPathValidationResultDelegate + : public QuicPathValidator::ResultDelegate { + public: + MOCK_METHOD(void, OnPathValidationSuccess, + (std::unique_ptr, QuicTime), + (override)); + + MOCK_METHOD(void, OnPathValidationFailure, + (std::unique_ptr), (override)); +}; + +class MockHttpDecoderVisitor : public HttpDecoder::Visitor { + public: + ~MockHttpDecoderVisitor() override = default; + + // Called if an error is detected. + MOCK_METHOD(void, OnError, (HttpDecoder*), (override)); + + MOCK_METHOD(bool, OnMaxPushIdFrame, (), (override)); + MOCK_METHOD(bool, OnGoAwayFrame, (const GoAwayFrame& frame), (override)); + MOCK_METHOD(bool, OnSettingsFrameStart, (QuicByteCount header_length), + (override)); + MOCK_METHOD(bool, OnSettingsFrame, (const SettingsFrame& frame), (override)); + + MOCK_METHOD(bool, OnDataFrameStart, + (QuicByteCount header_length, QuicByteCount payload_length), + (override)); + MOCK_METHOD(bool, OnDataFramePayload, (absl::string_view payload), + (override)); + MOCK_METHOD(bool, OnDataFrameEnd, (), (override)); + + MOCK_METHOD(bool, OnHeadersFrameStart, + (QuicByteCount header_length, QuicByteCount payload_length), + (override)); + MOCK_METHOD(bool, OnHeadersFramePayload, (absl::string_view payload), + (override)); + MOCK_METHOD(bool, OnHeadersFrameEnd, (), (override)); + + MOCK_METHOD(bool, OnPriorityUpdateFrameStart, (QuicByteCount header_length), + (override)); + MOCK_METHOD(bool, OnPriorityUpdateFrame, (const PriorityUpdateFrame& frame), + (override)); + + MOCK_METHOD(bool, OnAcceptChFrameStart, (QuicByteCount header_length), + (override)); + MOCK_METHOD(bool, OnAcceptChFrame, (const AcceptChFrame& frame), (override)); + MOCK_METHOD(void, OnWebTransportStreamFrameType, + (QuicByteCount header_length, WebTransportSessionId session_id), + (override)); + + MOCK_METHOD(bool, OnUnknownFrameStart, + (uint64_t frame_type, QuicByteCount header_length, + QuicByteCount payload_length), + (override)); + MOCK_METHOD(bool, OnUnknownFramePayload, (absl::string_view payload), + (override)); + MOCK_METHOD(bool, OnUnknownFrameEnd, (), (override)); +}; + +class QuicCryptoClientStreamPeer { + public: + QuicCryptoClientStreamPeer() = delete; + + static QuicCryptoClientStream::HandshakerInterface* GetHandshaker( + QuicCryptoClientStream* stream); +}; + +// Creates a client session for testing. +// +// server_id: The server id associated with this stream. +// connection_start_time: The time to set for the connection clock. +// Needed for strike-register nonce verification. The client +// connection_start_time should be synchronized witht the server +// start time, otherwise nonce verification will fail. +// supported_versions: Set of QUIC versions this client supports. +// helper: Pointer to the MockQuicConnectionHelper to use for the session. +// crypto_client_config: Pointer to the crypto client config. +// client_connection: Pointer reference for newly created +// connection. This object will be owned by the +// client_session. +// client_session: Pointer reference for the newly created client +// session. The new object will be owned by the caller. +void CreateClientSessionForTest( + QuicServerId server_id, QuicTime::Delta connection_start_time, + const ParsedQuicVersionVector& supported_versions, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + QuicCryptoClientConfig* crypto_client_config, + PacketSavingConnection** client_connection, + TestQuicSpdyClientSession** client_session); + +// Creates a server session for testing. +// +// server_id: The server id associated with this stream. +// connection_start_time: The time to set for the connection clock. +// Needed for strike-register nonce verification. The server +// connection_start_time should be synchronized witht the client +// start time, otherwise nonce verification will fail. +// supported_versions: Set of QUIC versions this server supports. +// helper: Pointer to the MockQuicConnectionHelper to use for the session. +// server_crypto_config: Pointer to the crypto server config. +// server_connection: Pointer reference for newly created +// connection. This object will be owned by the +// server_session. +// server_session: Pointer reference for the newly created server +// session. The new object will be owned by the caller. +void CreateServerSessionForTest( + QuicServerId server_id, QuicTime::Delta connection_start_time, + ParsedQuicVersionVector supported_versions, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + QuicCryptoServerConfig* server_crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + PacketSavingConnection** server_connection, + TestQuicSpdyServerSession** server_session); + +// Verifies that the relative error of |actual| with respect to |expected| is +// no more than |margin|. +// Please use EXPECT_APPROX_EQ, a wrapper around this function, for better error +// report. +template +void ExpectApproxEq(T expected, T actual, float relative_margin) { + // If |relative_margin| > 1 and T is an unsigned type, the comparison will + // underflow. + ASSERT_LE(relative_margin, 1); + ASSERT_GE(relative_margin, 0); + + T absolute_margin = expected * relative_margin; + + EXPECT_GE(expected + absolute_margin, actual) << "actual value too big"; + EXPECT_LE(expected - absolute_margin, actual) << "actual value too small"; +} + +#define EXPECT_APPROX_EQ(expected, actual, relative_margin) \ + do { \ + SCOPED_TRACE(testing::Message() << "relative_margin:" << relative_margin); \ + quic::test::ExpectApproxEq(expected, actual, relative_margin); \ + } while (0) + +template +QuicHeaderList AsHeaderList(const T& container) { + QuicHeaderList l; + l.OnHeaderBlockStart(); + size_t total_size = 0; + for (auto p : container) { + total_size += p.first.size() + p.second.size(); + l.OnHeader(p.first, p.second); + } + l.OnHeaderBlockEnd(total_size, total_size); + return l; +} + +// Helper functions for stream ids, to allow test logic to abstract over the +// HTTP stream numbering scheme (i.e. whether one or two QUIC streams are used +// per HTTP transaction). +QuicStreamId GetNthClientInitiatedBidirectionalStreamId( + QuicTransportVersion version, int n); +QuicStreamId GetNthServerInitiatedBidirectionalStreamId( + QuicTransportVersion version, int n); +QuicStreamId GetNthServerInitiatedUnidirectionalStreamId( + QuicTransportVersion version, int n); +QuicStreamId GetNthClientInitiatedUnidirectionalStreamId( + QuicTransportVersion version, int n); + +StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version, + Perspective perspective, bool is_incoming, + StreamType default_type); + +// Creates a MemSlice using a singleton trivial buffer allocator. Performs a +// copy. +quiche::QuicheMemSlice MemSliceFromString(absl::string_view data); + +// Used to compare ReceivedPacketInfo. +MATCHER_P(ReceivedPacketInfoEquals, info, "") { + return info.ToString() == arg.ToString(); +} + +MATCHER_P(ReceivedPacketInfoConnectionIdEquals, destination_connection_id, "") { + return arg.destination_connection_id == destination_connection_id; +} + +MATCHER_P2(InRange, min, max, "") { return arg >= min && arg <= max; } + +// A GMock matcher that prints expected and actual QuicErrorCode strings +// upon failure. Example usage: +// EXPECT_THAT(stream_->connection_error(), IsError(QUIC_INTERNAL_ERROR)); +MATCHER_P(IsError, expected, + absl::StrCat(negation ? "isn't equal to " : "is equal to ", + QuicErrorCodeToString(expected))) { + *result_listener << QuicErrorCodeToString(static_cast(arg)); + return arg == expected; +} + +// Shorthand for IsError(QUIC_NO_ERROR). +// Example usage: EXPECT_THAT(stream_->connection_error(), IsQuicNoError()); +MATCHER(IsQuicNoError, + absl::StrCat(negation ? "isn't equal to " : "is equal to ", + QuicErrorCodeToString(QUIC_NO_ERROR))) { + *result_listener << QuicErrorCodeToString(arg); + return arg == QUIC_NO_ERROR; +} + +// A GMock matcher that prints expected and actual QuicRstStreamErrorCode +// strings upon failure. Example usage: +// EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_INTERNAL_ERROR)); +MATCHER_P(IsStreamError, expected, + absl::StrCat(negation ? "isn't equal to " : "is equal to ", + QuicRstStreamErrorCodeToString(expected))) { + *result_listener << QuicRstStreamErrorCodeToString(arg); + return arg == expected; +} + +// Shorthand for IsStreamError(QUIC_STREAM_NO_ERROR). Example usage: +// EXPECT_THAT(stream_->stream_error(), IsQuicStreamNoError()); +MATCHER(IsQuicStreamNoError, + absl::StrCat(negation ? "isn't equal to " : "is equal to ", + QuicRstStreamErrorCodeToString(QUIC_STREAM_NO_ERROR))) { + *result_listener << QuicRstStreamErrorCodeToString(arg); + return arg == QUIC_STREAM_NO_ERROR; +} + +// TaggingEncrypter appends kTagSize bytes of |tag| to the end of each message. +class TaggingEncrypter : public QuicEncrypter { + public: + explicit TaggingEncrypter(uint8_t tag) : tag_(tag) {} + TaggingEncrypter(const TaggingEncrypter&) = delete; + TaggingEncrypter& operator=(const TaggingEncrypter&) = delete; + + ~TaggingEncrypter() override {} + + // QuicEncrypter interface. + bool SetKey(absl::string_view /*key*/) override { return true; } + + bool SetNoncePrefix(absl::string_view /*nonce_prefix*/) override { + return true; + } + + bool SetIV(absl::string_view /*iv*/) override { return true; } + + bool SetHeaderProtectionKey(absl::string_view /*key*/) override { + return true; + } + + bool EncryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, size_t max_output_length) override; + + std::string GenerateHeaderProtectionMask( + absl::string_view /*sample*/) override { + return std::string(5, 0); + } + + size_t GetKeySize() const override { return 0; } + size_t GetNoncePrefixSize() const override { return 0; } + size_t GetIVSize() const override { return 0; } + + size_t GetMaxPlaintextSize(size_t ciphertext_size) const override { + return ciphertext_size - kTagSize; + } + + size_t GetCiphertextSize(size_t plaintext_size) const override { + return plaintext_size + kTagSize; + } + + QuicPacketCount GetConfidentialityLimit() const override { + return std::numeric_limits::max(); + } + + absl::string_view GetKey() const override { return absl::string_view(); } + + absl::string_view GetNoncePrefix() const override { + return absl::string_view(); + } + + private: + enum { + kTagSize = 16, + }; + + const uint8_t tag_; +}; + +// TaggingDecrypter ensures that the final kTagSize bytes of the message all +// have the same value and then removes them. +class TaggingDecrypter : public QuicDecrypter { + public: + ~TaggingDecrypter() override {} + + // QuicDecrypter interface + bool SetKey(absl::string_view /*key*/) override { return true; } + + bool SetNoncePrefix(absl::string_view /*nonce_prefix*/) override { + return true; + } + + bool SetIV(absl::string_view /*iv*/) override { return true; } + + bool SetHeaderProtectionKey(absl::string_view /*key*/) override { + return true; + } + + bool SetPreliminaryKey(absl::string_view /*key*/) override { + QUIC_BUG(quic_bug_10230_1) << "should not be called"; + return false; + } + + bool SetDiversificationNonce(const DiversificationNonce& /*key*/) override { + return true; + } + + bool DecryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, size_t max_output_length) override; + + std::string GenerateHeaderProtectionMask( + QuicDataReader* /*sample_reader*/) override { + return std::string(5, 0); + } + + size_t GetKeySize() const override { return 0; } + size_t GetNoncePrefixSize() const override { return 0; } + size_t GetIVSize() const override { return 0; } + absl::string_view GetKey() const override { return absl::string_view(); } + absl::string_view GetNoncePrefix() const override { + return absl::string_view(); + } + // Use a distinct value starting with 0xFFFFFF, which is never used by TLS. + uint32_t cipher_id() const override { return 0xFFFFFFF0; } + QuicPacketCount GetIntegrityLimit() const override { + return std::numeric_limits::max(); + } + + protected: + virtual uint8_t GetTag(absl::string_view ciphertext) { + return ciphertext.data()[ciphertext.size() - 1]; + } + + private: + enum { + kTagSize = 16, + }; + + bool CheckTag(absl::string_view ciphertext, uint8_t tag); +}; + +// StringTaggingDecrypter ensures that the final kTagSize bytes of the message +// match the expected value. +class StrictTaggingDecrypter : public TaggingDecrypter { + public: + explicit StrictTaggingDecrypter(uint8_t tag) : tag_(tag) {} + ~StrictTaggingDecrypter() override {} + + // TaggingQuicDecrypter + uint8_t GetTag(absl::string_view /*ciphertext*/) override { return tag_; } + + // Use a distinct value starting with 0xFFFFFF, which is never used by TLS. + uint32_t cipher_id() const override { return 0xFFFFFFF1; } + + private: + const uint8_t tag_; +}; + +class TestPacketWriter : public QuicPacketWriter { + struct PacketBuffer { + ABSL_CACHELINE_ALIGNED char buffer[1500]; + bool in_use = false; + }; + + public: + TestPacketWriter(ParsedQuicVersion version, MockClock* clock, + Perspective perspective); + + TestPacketWriter(const TestPacketWriter&) = delete; + TestPacketWriter& operator=(const TestPacketWriter&) = delete; + + ~TestPacketWriter() override; + + // QuicPacketWriter interface + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + + bool ShouldWriteFail() { return write_should_fail_; } + + bool IsWriteBlocked() const override { return write_blocked_; } + + absl::optional MessageTooBigErrorCode() const override { return 0x1234; } + + void SetWriteBlocked() { write_blocked_ = true; } + + void SetWritable() override { write_blocked_ = false; } + + void SetShouldWriteFail() { write_should_fail_ = true; } + + void SetWriteError(int error_code) { write_error_code_ = error_code; } + + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const override { + return max_packet_size_; + } + + bool SupportsReleaseTime() const override { return supports_release_time_; } + + bool IsBatchMode() const override { return is_batch_mode_; } + + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) override; + + WriteResult Flush() override; + + void BlockOnNextFlush() { block_on_next_flush_ = true; } + + void BlockOnNextWrite() { block_on_next_write_ = true; } + + void SimulateNextPacketTooLarge() { next_packet_too_large_ = true; } + + void ExpectNextPacketUnprocessable() { next_packet_processable_ = false; } + + void AlwaysGetPacketTooLarge() { always_get_packet_too_large_ = true; } + + // Sets the amount of time that the writer should before the actual write. + void SetWritePauseTimeDelta(QuicTime::Delta delta) { + write_pause_time_delta_ = delta; + } + + void SetBatchMode(bool new_value) { is_batch_mode_ = new_value; } + + const QuicPacketHeader& header() { return framer_.header(); } + + size_t frame_count() const { return framer_.num_frames(); } + + const std::vector& ack_frames() const { + return framer_.ack_frames(); + } + + const std::vector& stop_waiting_frames() const { + return framer_.stop_waiting_frames(); + } + + const std::vector& connection_close_frames() const { + return framer_.connection_close_frames(); + } + + const std::vector& rst_stream_frames() const { + return framer_.rst_stream_frames(); + } + + const std::vector>& stream_frames() const { + return framer_.stream_frames(); + } + + const std::vector>& crypto_frames() const { + return framer_.crypto_frames(); + } + + const std::vector& ping_frames() const { + return framer_.ping_frames(); + } + + const std::vector& message_frames() const { + return framer_.message_frames(); + } + + const std::vector& window_update_frames() const { + return framer_.window_update_frames(); + } + + const std::vector& padding_frames() const { + return framer_.padding_frames(); + } + + const std::vector& path_challenge_frames() const { + return framer_.path_challenge_frames(); + } + + const std::vector& path_response_frames() const { + return framer_.path_response_frames(); + } + + const QuicEncryptedPacket* coalesced_packet() const { + return framer_.coalesced_packet(); + } + + size_t last_packet_size() const { return last_packet_size_; } + + size_t total_bytes_written() const { return total_bytes_written_; } + + const QuicPacketHeader& last_packet_header() const { + return last_packet_header_; + } + + const QuicVersionNegotiationPacket* version_negotiation_packet() { + return framer_.version_negotiation_packet(); + } + + void set_is_write_blocked_data_buffered(bool buffered) { + is_write_blocked_data_buffered_ = buffered; + } + + void set_perspective(Perspective perspective) { + // We invert perspective here, because the framer needs to parse packets + // we send. + QuicFramerPeer::SetPerspective(framer_.framer(), + QuicUtils::InvertPerspective(perspective)); + framer_.framer()->SetInitialObfuscators(TestConnectionId()); + } + + // final_bytes_of_last_packet_ returns the last four bytes of the previous + // packet as a little-endian, uint32_t. This is intended to be used with a + // TaggingEncrypter so that tests can determine which encrypter was used for + // a given packet. + uint32_t final_bytes_of_last_packet() { return final_bytes_of_last_packet_; } + + // Returns the final bytes of the second to last packet. + uint32_t final_bytes_of_previous_packet() { + return final_bytes_of_previous_packet_; + } + + uint32_t packets_write_attempts() const { return packets_write_attempts_; } + + uint32_t flush_attempts() const { return flush_attempts_; } + + uint32_t connection_close_packets() const { + return connection_close_packets_; + } + + void Reset() { framer_.Reset(); } + + void SetSupportedVersions(const ParsedQuicVersionVector& versions) { + framer_.SetSupportedVersions(versions); + } + + void set_max_packet_size(QuicByteCount max_packet_size) { + max_packet_size_ = max_packet_size; + } + + void set_supports_release_time(bool supports_release_time) { + supports_release_time_ = supports_release_time; + } + + SimpleQuicFramer* framer() { return &framer_; } + + const QuicIpAddress& last_write_source_address() const { + return last_write_source_address_; + } + + const QuicSocketAddress& last_write_peer_address() const { + return last_write_peer_address_; + } + + QuicEcnCodepoint last_ecn_sent() const { return last_ecn_sent_; } + + private: + char* AllocPacketBuffer(); + + void FreePacketBuffer(const char* buffer); + + ParsedQuicVersion version_; + SimpleQuicFramer framer_; + size_t last_packet_size_ = 0; + size_t total_bytes_written_ = 0; + QuicPacketHeader last_packet_header_; + bool write_blocked_ = false; + bool write_should_fail_ = false; + bool block_on_next_flush_ = false; + bool block_on_next_write_ = false; + bool next_packet_too_large_ = false; + bool next_packet_processable_ = true; + bool always_get_packet_too_large_ = false; + bool is_write_blocked_data_buffered_ = false; + bool is_batch_mode_ = false; + // Number of times Flush() was called. + uint32_t flush_attempts_ = 0; + // (Batch mode only) Number of bytes buffered in writer. It is used as the + // return value of a successful Flush(). + uint32_t bytes_buffered_ = 0; + uint32_t final_bytes_of_last_packet_ = 0; + uint32_t final_bytes_of_previous_packet_ = 0; + uint32_t packets_write_attempts_ = 0; + uint32_t connection_close_packets_ = 0; + MockClock* clock_ = nullptr; + // If non-zero, the clock will pause during WritePacket for this amount of + // time. + QuicTime::Delta write_pause_time_delta_ = QuicTime::Delta::Zero(); + QuicByteCount max_packet_size_ = kMaxOutgoingPacketSize; + bool supports_release_time_ = false; + // Used to verify writer-allocated packet buffers are properly released. + std::vector packet_buffer_pool_; + // Buffer address => Address of the owning PacketBuffer. + absl::flat_hash_map> + packet_buffer_pool_index_; + // Indices in packet_buffer_pool_ that are not allocated. + std::list packet_buffer_free_list_; + // The soruce/peer address passed into WritePacket(). + QuicIpAddress last_write_source_address_; + QuicSocketAddress last_write_peer_address_; + int write_error_code_{0}; + QuicEcnCodepoint last_ecn_sent_ = ECN_NOT_ECT; +}; + +// Parses a packet generated by +// QuicFramer::WriteClientVersionNegotiationProbePacket. +// |packet_bytes| must point to |packet_length| bytes in memory which represent +// the packet. This method will fill in |destination_connection_id_bytes| +// which must point to at least |*destination_connection_id_length_out| bytes in +// memory. |*destination_connection_id_length_out| will contain the length of +// the received destination connection ID, which on success will match the +// contents of the destination connection ID passed in to +// WriteClientVersionNegotiationProbePacket. +bool ParseClientVersionNegotiationProbePacket( + const char* packet_bytes, size_t packet_length, + char* destination_connection_id_bytes, + uint8_t* destination_connection_id_length_out); + +// Writes an array of bytes that correspond to a QUIC version negotiation packet +// that a QUIC server would send in response to a probe created by +// QuicFramer::WriteClientVersionNegotiationProbePacket. +// The bytes will be written to |packet_bytes|, which must point to +// |*packet_length_out| bytes of memory. |*packet_length_out| will contain the +// length of the created packet. |source_connection_id_bytes| will be sent as +// the source connection ID, and must point to |source_connection_id_length| +// bytes of memory. +bool WriteServerVersionNegotiationProbeResponse( + char* packet_bytes, size_t* packet_length_out, + const char* source_connection_id_bytes, + uint8_t source_connection_id_length); + +// Implementation of Http3DatagramVisitor which saves all received datagrams. +class SavingHttp3DatagramVisitor : public QuicSpdyStream::Http3DatagramVisitor { + public: + struct SavedHttp3Datagram { + QuicStreamId stream_id; + std::string payload; + bool operator==(const SavedHttp3Datagram& o) const { + return stream_id == o.stream_id && payload == o.payload; + } + }; + struct SavedUnknownCapsule { + QuicStreamId stream_id; + uint64_t type; + std::string payload; + bool operator==(const SavedUnknownCapsule& o) const { + return stream_id == o.stream_id && type == o.type && payload == o.payload; + } + }; + const std::vector& received_h3_datagrams() const { + return received_h3_datagrams_; + } + const std::vector& received_unknown_capsules() const { + return received_unknown_capsules_; + } + + // Override from QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override { + received_h3_datagrams_.push_back( + SavedHttp3Datagram{stream_id, std::string(payload)}); + } + void OnUnknownCapsule(QuicStreamId stream_id, + const quiche::UnknownCapsule& capsule) override { + received_unknown_capsules_.push_back(SavedUnknownCapsule{ + stream_id, capsule.type, std::string(capsule.payload)}); + } + + private: + std::vector received_h3_datagrams_; + std::vector received_unknown_capsules_; +}; + +// Implementation of ConnectIpVisitor which saves all received capsules. +class SavingConnectIpVisitor : public QuicSpdyStream::ConnectIpVisitor { + public: + const std::vector& + received_address_assign_capsules() const { + return received_address_assign_capsules_; + } + const std::vector& + received_address_request_capsules() const { + return received_address_request_capsules_; + } + const std::vector& + received_route_advertisement_capsules() const { + return received_route_advertisement_capsules_; + } + bool headers_written() const { return headers_written_; } + + // From QuicSpdyStream::ConnectIpVisitor. + bool OnAddressAssignCapsule( + const quiche::AddressAssignCapsule& capsule) override { + received_address_assign_capsules_.push_back(capsule); + return true; + } + bool OnAddressRequestCapsule( + const quiche::AddressRequestCapsule& capsule) override { + received_address_request_capsules_.push_back(capsule); + return true; + } + bool OnRouteAdvertisementCapsule( + const quiche::RouteAdvertisementCapsule& capsule) override { + received_route_advertisement_capsules_.push_back(capsule); + return true; + } + void OnHeadersWritten() override { headers_written_ = true; } + + private: + std::vector received_address_assign_capsules_; + std::vector received_address_request_capsules_; + std::vector + received_route_advertisement_capsules_; + bool headers_written_ = false; +}; + +inline std::string EscapeTestParamName(absl::string_view name) { + std::string result(name); + // Escape all characters that are not allowed by gtest ([a-zA-Z0-9_]). + for (char& c : result) { + bool valid = absl::ascii_isalnum(c) || c == '_'; + if (!valid) { + c = '_'; + } + } + return result; +} + +struct TestPerPacketOptions : PerPacketOptions { + public: + std::unique_ptr Clone() const override { + return std::make_unique(*this); + } +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_UTILS_H_ diff --git a/quiche/quic/test_tools/quic_test_utils_test.cc b/quiche/quic/test_tools/quic_test_utils_test.cc new file mode 100644 index 000000000000..16ca977e9943 --- /dev/null +++ b/quiche/quic/test_tools/quic_test_utils_test.cc @@ -0,0 +1,79 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_test_utils.h" + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { + +class QuicTestUtilsTest : public QuicTest {}; + +TEST_F(QuicTestUtilsTest, ConnectionId) { + EXPECT_NE(EmptyQuicConnectionId(), TestConnectionId()); + EXPECT_NE(EmptyQuicConnectionId(), TestConnectionId(1)); + EXPECT_EQ(TestConnectionId(), TestConnectionId()); + EXPECT_EQ(TestConnectionId(33), TestConnectionId(33)); + EXPECT_NE(TestConnectionId(0xdead), TestConnectionId(0xbeef)); + EXPECT_EQ(0x1337u, TestConnectionIdToUInt64(TestConnectionId(0x1337))); + EXPECT_NE(0xdeadu, TestConnectionIdToUInt64(TestConnectionId(0xbeef))); +} + +TEST_F(QuicTestUtilsTest, BasicApproxEq) { + EXPECT_APPROX_EQ(10, 10, 1e-6f); + EXPECT_APPROX_EQ(1000, 1001, 0.01f); + EXPECT_NONFATAL_FAILURE(EXPECT_APPROX_EQ(1000, 1100, 0.01f), ""); + + EXPECT_APPROX_EQ(64, 31, 0.55f); + EXPECT_NONFATAL_FAILURE(EXPECT_APPROX_EQ(31, 64, 0.55f), ""); +} + +TEST_F(QuicTestUtilsTest, QuicTimeDelta) { + EXPECT_APPROX_EQ(QuicTime::Delta::FromMicroseconds(1000), + QuicTime::Delta::FromMicroseconds(1003), 0.01f); + EXPECT_NONFATAL_FAILURE( + EXPECT_APPROX_EQ(QuicTime::Delta::FromMicroseconds(1000), + QuicTime::Delta::FromMicroseconds(1200), 0.01f), + ""); +} + +TEST_F(QuicTestUtilsTest, QuicBandwidth) { + EXPECT_APPROX_EQ(QuicBandwidth::FromBytesPerSecond(1000), + QuicBandwidth::FromBitsPerSecond(8005), 0.01f); + EXPECT_NONFATAL_FAILURE( + EXPECT_APPROX_EQ(QuicBandwidth::FromBytesPerSecond(1000), + QuicBandwidth::FromBitsPerSecond(9005), 0.01f), + ""); +} + +// Ensure that SimpleRandom does not change its output for a fixed seed. +TEST_F(QuicTestUtilsTest, SimpleRandomStability) { + SimpleRandom rng; + rng.set_seed(UINT64_C(0x1234567800010001)); + EXPECT_EQ(UINT64_C(12589383305231984671), rng.RandUint64()); + EXPECT_EQ(UINT64_C(17775425089941798664), rng.RandUint64()); +} + +// Ensure that the output of SimpleRandom does not depend on the size of the +// read calls. +TEST_F(QuicTestUtilsTest, SimpleRandomChunks) { + SimpleRandom rng; + std::string reference(16 * 1024, '\0'); + rng.RandBytes(&reference[0], reference.size()); + + for (size_t chunk_size : {3, 4, 7, 4096}) { + rng.set_seed(0); + size_t chunks = reference.size() / chunk_size; + std::string buffer(chunks * chunk_size, '\0'); + for (size_t i = 0; i < chunks; i++) { + rng.RandBytes(&buffer[i * chunk_size], chunk_size); + } + EXPECT_EQ(reference.substr(0, buffer.size()), buffer) + << "Failed for chunk_size = " << chunk_size; + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_time_wait_list_manager_peer.cc b/quiche/quic/test_tools/quic_time_wait_list_manager_peer.cc new file mode 100644 index 000000000000..b8c38caa6e7b --- /dev/null +++ b/quiche/quic/test_tools/quic_time_wait_list_manager_peer.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_time_wait_list_manager_peer.h" + +namespace quic { +namespace test { + +bool QuicTimeWaitListManagerPeer::ShouldSendResponse( + QuicTimeWaitListManager* manager, int received_packet_count) { + return manager->ShouldSendResponse(received_packet_count); +} + +QuicTime::Delta QuicTimeWaitListManagerPeer::time_wait_period( + QuicTimeWaitListManager* manager) { + return manager->time_wait_period_; +} + +QuicAlarm* QuicTimeWaitListManagerPeer::expiration_alarm( + QuicTimeWaitListManager* manager) { + return manager->connection_id_clean_up_alarm_.get(); +} + +void QuicTimeWaitListManagerPeer::set_clock(QuicTimeWaitListManager* manager, + const QuicClock* clock) { + manager->clock_ = clock; +} + +// static +bool QuicTimeWaitListManagerPeer::SendOrQueuePacket( + QuicTimeWaitListManager* manager, + std::unique_ptr packet, + const QuicPerPacketContext* packet_context) { + return manager->SendOrQueuePacket(std::move(packet), packet_context); +} + +// static +size_t QuicTimeWaitListManagerPeer::PendingPacketsQueueSize( + QuicTimeWaitListManager* manager) { + return manager->pending_packets_queue_.size(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_time_wait_list_manager_peer.h b/quiche/quic/test_tools/quic_time_wait_list_manager_peer.h new file mode 100644 index 000000000000..a7aed4719fdd --- /dev/null +++ b/quiche/quic/test_tools/quic_time_wait_list_manager_peer.h @@ -0,0 +1,36 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_TIME_WAIT_LIST_MANAGER_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_TIME_WAIT_LIST_MANAGER_PEER_H_ + +#include "quiche/quic/core/quic_time_wait_list_manager.h" + +namespace quic { +namespace test { + +class QuicTimeWaitListManagerPeer { + public: + static bool ShouldSendResponse(QuicTimeWaitListManager* manager, + int received_packet_count); + + static QuicTime::Delta time_wait_period(QuicTimeWaitListManager* manager); + + static QuicAlarm* expiration_alarm(QuicTimeWaitListManager* manager); + + static void set_clock(QuicTimeWaitListManager* manager, + const QuicClock* clock); + + static bool SendOrQueuePacket( + QuicTimeWaitListManager* manager, + std::unique_ptr packet, + const QuicPerPacketContext* packet_context); + + static size_t PendingPacketsQueueSize(QuicTimeWaitListManager* manager); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_TIME_WAIT_LIST_MANAGER_PEER_H_ diff --git a/quiche/quic/test_tools/quic_unacked_packet_map_peer.cc b/quiche/quic/test_tools/quic_unacked_packet_map_peer.cc new file mode 100644 index 000000000000..4df89b03ead4 --- /dev/null +++ b/quiche/quic/test_tools/quic_unacked_packet_map_peer.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/quic_unacked_packet_map_peer.h" + +namespace quic { +namespace test { + +// static +const QuicStreamFrame& QuicUnackedPacketMapPeer::GetAggregatedStreamFrame( + const QuicUnackedPacketMap& unacked_packets) { + return unacked_packets.aggregated_stream_frame_; +} + +// static +void QuicUnackedPacketMapPeer::SetPerspective( + QuicUnackedPacketMap* unacked_packets, Perspective perspective) { + *const_cast(&unacked_packets->perspective_) = perspective; +} + +// static +size_t QuicUnackedPacketMapPeer::GetCapacity( + const QuicUnackedPacketMap& unacked_packets) { + return unacked_packets.unacked_packets_.capacity(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/quic_unacked_packet_map_peer.h b/quiche/quic/test_tools/quic_unacked_packet_map_peer.h new file mode 100644 index 000000000000..5525506f1be2 --- /dev/null +++ b/quiche/quic/test_tools/quic_unacked_packet_map_peer.h @@ -0,0 +1,27 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_UNACKED_PACKET_MAP_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_UNACKED_PACKET_MAP_PEER_H_ + +#include "quiche/quic/core/quic_unacked_packet_map.h" + +namespace quic { +namespace test { + +class QuicUnackedPacketMapPeer { + public: + static const QuicStreamFrame& GetAggregatedStreamFrame( + const QuicUnackedPacketMap& unacked_packets); + + static void SetPerspective(QuicUnackedPacketMap* unacked_packets, + Perspective perspective); + + static size_t GetCapacity(const QuicUnackedPacketMap& unacked_packets); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_UNACKED_PACKET_MAP_PEER_H_ diff --git a/quiche/quic/test_tools/rtt_stats_peer.cc b/quiche/quic/test_tools/rtt_stats_peer.cc new file mode 100644 index 000000000000..8bf1f990fef2 --- /dev/null +++ b/quiche/quic/test_tools/rtt_stats_peer.cc @@ -0,0 +1,21 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/rtt_stats_peer.h" + +namespace quic { +namespace test { + +// static +void RttStatsPeer::SetSmoothedRtt(RttStats* rtt_stats, QuicTime::Delta rtt_ms) { + rtt_stats->smoothed_rtt_ = rtt_ms; +} + +// static +void RttStatsPeer::SetMinRtt(RttStats* rtt_stats, QuicTime::Delta rtt_ms) { + rtt_stats->min_rtt_ = rtt_ms; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/rtt_stats_peer.h b/quiche/quic/test_tools/rtt_stats_peer.h new file mode 100644 index 000000000000..5e7f473fb354 --- /dev/null +++ b/quiche/quic/test_tools/rtt_stats_peer.h @@ -0,0 +1,26 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_RTT_STATS_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_RTT_STATS_PEER_H_ + +#include "quiche/quic/core/congestion_control/rtt_stats.h" +#include "quiche/quic/core/quic_time.h" + +namespace quic { +namespace test { + +class RttStatsPeer { + public: + RttStatsPeer() = delete; + + static void SetSmoothedRtt(RttStats* rtt_stats, QuicTime::Delta rtt_ms); + + static void SetMinRtt(RttStats* rtt_stats, QuicTime::Delta rtt_ms); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_RTT_STATS_PEER_H_ diff --git a/quiche/quic/test_tools/send_algorithm_test_result.proto b/quiche/quic/test_tools/send_algorithm_test_result.proto new file mode 100644 index 000000000000..a836c474ba04 --- /dev/null +++ b/quiche/quic/test_tools/send_algorithm_test_result.proto @@ -0,0 +1,15 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package quic; + +message SendAlgorithmTestResult { + optional string test_name = 1; + optional uint64 random_seed = 2; + optional int64 simulated_duration_micros = 3; +} diff --git a/quiche/quic/test_tools/send_algorithm_test_utils.cc b/quiche/quic/test_tools/send_algorithm_test_utils.cc new file mode 100644 index 000000000000..2ca17919a076 --- /dev/null +++ b/quiche/quic/test_tools/send_algorithm_test_utils.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/send_algorithm_test_utils.h" + +#include "absl/strings/str_cat.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_output.h" + +namespace quic { +namespace test { + +bool LoadSendAlgorithmTestResult(SendAlgorithmTestResult* result) { + std::string test_result_file_content; + if (!QuicLoadTestOutput(GetSendAlgorithmTestResultFilename(), + &test_result_file_content)) { + return false; + } + return result->ParseFromString(test_result_file_content); +} + +void RecordSendAlgorithmTestResult(uint64_t random_seed, + int64_t simulated_duration_micros) { + SendAlgorithmTestResult result; + result.set_test_name(GetFullSendAlgorithmTestName()); + result.set_random_seed(random_seed); + result.set_simulated_duration_micros(simulated_duration_micros); + + QuicSaveTestOutput(GetSendAlgorithmTestResultFilename(), + result.SerializeAsString()); +} + +void CompareSendAlgorithmTestResult(int64_t actual_simulated_duration_micros) { + SendAlgorithmTestResult expected; + ASSERT_TRUE(LoadSendAlgorithmTestResult(&expected)); + QUIC_LOG(INFO) << "Loaded expected test result: " + << expected.ShortDebugString(); + + EXPECT_GE(expected.simulated_duration_micros(), + actual_simulated_duration_micros); +} + +std::string GetFullSendAlgorithmTestName() { + const auto* test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + const std::string type_param = + test_info->type_param() ? test_info->type_param() : ""; + const std::string value_param = + test_info->value_param() ? test_info->value_param() : ""; + return absl::StrCat(test_info->test_suite_name(), ".", test_info->name(), "_", + type_param, "_", value_param); +} + +std::string GetSendAlgorithmTestResultFilename() { + return GetFullSendAlgorithmTestName() + ".test_result"; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/send_algorithm_test_utils.h b/quiche/quic/test_tools/send_algorithm_test_utils.h new file mode 100644 index 000000000000..9516b9aa46fc --- /dev/null +++ b/quiche/quic/test_tools/send_algorithm_test_utils.h @@ -0,0 +1,29 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SEND_ALGORITHM_TEST_UTILS_H_ +#define QUICHE_QUIC_TEST_TOOLS_SEND_ALGORITHM_TEST_UTILS_H_ + +#include "quiche/quic/test_tools/send_algorithm_test_result.pb.h" + +namespace quic { +namespace test { + +bool LoadSendAlgorithmTestResult(SendAlgorithmTestResult* result); + +void RecordSendAlgorithmTestResult(uint64_t random_seed, + int64_t simulated_duration_micros); + +// Load the expected test result with LoadSendAlgorithmTestResult(), and compare +// it with the actual results provided in the arguments. +void CompareSendAlgorithmTestResult(int64_t actual_simulated_duration_micros); + +std::string GetFullSendAlgorithmTestName(); + +std::string GetSendAlgorithmTestResultFilename(); + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SEND_ALGORITHM_TEST_UTILS_H_ diff --git a/quiche/quic/test_tools/server_thread.cc b/quiche/quic/test_tools/server_thread.cc new file mode 100644 index 000000000000..c0d32ea40d5a --- /dev/null +++ b/quiche/quic/test_tools/server_thread.cc @@ -0,0 +1,143 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/server_thread.h" + +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_dispatcher_peer.h" +#include "quiche/quic/test_tools/quic_server_peer.h" + +namespace quic { +namespace test { + +ServerThread::ServerThread(std::unique_ptr server, + const QuicSocketAddress& address) + : QuicThread("server_thread"), + server_(std::move(server)), + clock_(QuicDefaultClock::Get()), + address_(address), + port_(0), + initialized_(false) {} + +ServerThread::~ServerThread() = default; + +void ServerThread::Initialize() { + if (initialized_) { + return; + } + if (!server_->CreateUDPSocketAndListen(address_)) { + return; + } + + QuicWriterMutexLock lock(&port_lock_); + port_ = server_->port(); + + initialized_ = true; +} + +void ServerThread::Run() { + if (!initialized_) { + Initialize(); + } + + while (!quit_.HasBeenNotified()) { + if (pause_.HasBeenNotified() && !resume_.HasBeenNotified()) { + paused_.Notify(); + resume_.WaitForNotification(); + } + server_->WaitForEvents(); + ExecuteScheduledActions(); + MaybeNotifyOfHandshakeConfirmation(); + } + + server_->Shutdown(); +} + +int ServerThread::GetPort() { + QuicReaderMutexLock lock(&port_lock_); + int rc = port_; + return rc; +} + +void ServerThread::Schedule(std::function action) { + QUICHE_DCHECK(!quit_.HasBeenNotified()); + QuicWriterMutexLock lock(&scheduled_actions_lock_); + scheduled_actions_.push_back(std::move(action)); +} + +void ServerThread::WaitForCryptoHandshakeConfirmed() { + confirmed_.WaitForNotification(); +} + +bool ServerThread::WaitUntil(std::function termination_predicate, + QuicTime::Delta timeout) { + const QuicTime deadline = clock_->Now() + timeout; + while (clock_->Now() < deadline) { + QuicNotification done_checking; + bool should_terminate = false; + Schedule([&] { + should_terminate = termination_predicate(); + done_checking.Notify(); + }); + done_checking.WaitForNotification(); + if (should_terminate) { + return true; + } + } + return false; +} + +void ServerThread::Pause() { + QUICHE_DCHECK(!pause_.HasBeenNotified()); + pause_.Notify(); + paused_.WaitForNotification(); +} + +void ServerThread::Resume() { + QUICHE_DCHECK(!resume_.HasBeenNotified()); + QUICHE_DCHECK(pause_.HasBeenNotified()); + resume_.Notify(); +} + +void ServerThread::Quit() { + if (pause_.HasBeenNotified() && !resume_.HasBeenNotified()) { + resume_.Notify(); + } + if (!quit_.HasBeenNotified()) { + quit_.Notify(); + } +} + +void ServerThread::MaybeNotifyOfHandshakeConfirmation() { + if (confirmed_.HasBeenNotified()) { + // Only notify once. + return; + } + QuicDispatcher* dispatcher = QuicServerPeer::GetDispatcher(server()); + if (dispatcher->NumSessions() == 0) { + // Wait for a session to be created. + return; + } + QuicSession* session = QuicDispatcherPeer::GetFirstSessionIfAny(dispatcher); + if (session->OneRttKeysAvailable()) { + confirmed_.Notify(); + } +} + +void ServerThread::ExecuteScheduledActions() { + quiche::QuicheCircularDeque> actions; + { + QuicWriterMutexLock lock(&scheduled_actions_lock_); + actions.swap(scheduled_actions_); + } + while (!actions.empty()) { + actions.front()(); + actions.pop_front(); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/server_thread.h b/quiche/quic/test_tools/server_thread.h new file mode 100644 index 000000000000..c29d6331aab0 --- /dev/null +++ b/quiche/quic/test_tools/server_thread.h @@ -0,0 +1,96 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SERVER_THREAD_H_ +#define QUICHE_QUIC_TEST_TOOLS_SERVER_THREAD_H_ + +#include + +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_thread.h" +#include "quiche/quic/tools/quic_server.h" + +namespace quic { +namespace test { + +// Simple wrapper class to run QuicServer in a dedicated thread. +class ServerThread : public QuicThread { + public: + ServerThread(std::unique_ptr server, + const QuicSocketAddress& address); + ServerThread(const ServerThread&) = delete; + ServerThread& operator=(const ServerThread&) = delete; + + ~ServerThread() override; + + // Prepares the server, but does not start accepting connections. Useful for + // injecting mocks. + void Initialize(); + + // Runs the event loop. Will initialize if necessary. + void Run() override; + + // Schedules the given action for execution in the event loop. + void Schedule(std::function action); + + // Waits for the handshake to be confirmed for the first session created. + void WaitForCryptoHandshakeConfirmed(); + + // Wait until |termination_predicate| returns true in server thread, or + // reached |timeout|. Must be called from an external thread. + // Return whether the function returned after |termination_predicate| become + // true. + bool WaitUntil(std::function termination_predicate, + QuicTime::Delta timeout); + + // Pauses execution of the server until Resume() is called. May only be + // called once. + void Pause(); + + // Resumes execution of the server after Pause() has been called. May only + // be called once. + void Resume(); + + // Stops the server from executing and shuts it down, destroying all + // server objects. + void Quit(); + + // Returns the underlying server. Care must be taken to avoid data races + // when accessing the server. It is always safe to access the server + // after calling Pause() and before calling Resume(). + QuicServer* server() { return server_.get(); } + + // Returns the port that the server is listening on. + int GetPort(); + + private: + void MaybeNotifyOfHandshakeConfirmation(); + void ExecuteScheduledActions(); + + QuicNotification + confirmed_; // Notified when the first handshake is confirmed. + QuicNotification pause_; // Notified when the server should pause. + QuicNotification paused_; // Notitied when the server has paused + QuicNotification resume_; // Notified when the server should resume. + QuicNotification quit_; // Notified when the server should quit. + + std::unique_ptr server_; + QuicClock* clock_; + QuicSocketAddress address_; + mutable QuicMutex port_lock_; + int port_ QUIC_GUARDED_BY(port_lock_); + + bool initialized_; + + QuicMutex scheduled_actions_lock_; + quiche::QuicheCircularDeque> scheduled_actions_ + QUIC_GUARDED_BY(scheduled_actions_lock_); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SERVER_THREAD_H_ diff --git a/quiche/quic/test_tools/simple_data_producer.cc b/quiche/quic/test_tools/simple_data_producer.cc new file mode 100644 index 000000000000..f0adc68b45ac --- /dev/null +++ b/quiche/quic/test_tools/simple_data_producer.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simple_data_producer.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" + +namespace quic { + +namespace test { + +SimpleDataProducer::SimpleDataProducer() {} + +SimpleDataProducer::~SimpleDataProducer() {} + +void SimpleDataProducer::SaveStreamData(QuicStreamId id, + absl::string_view data) { + if (data.empty()) { + return; + } + if (!send_buffer_map_.contains(id)) { + send_buffer_map_[id] = std::make_unique(&allocator_); + } + send_buffer_map_[id]->SaveStreamData(data); +} + +void SimpleDataProducer::SaveCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + absl::string_view data) { + auto key = std::make_pair(level, offset); + crypto_buffer_map_[key] = std::string(data); +} + +WriteStreamDataResult SimpleDataProducer::WriteStreamData( + QuicStreamId id, QuicStreamOffset offset, QuicByteCount data_length, + QuicDataWriter* writer) { + auto iter = send_buffer_map_.find(id); + if (iter == send_buffer_map_.end()) { + return STREAM_MISSING; + } + if (iter->second->WriteStreamData(offset, data_length, writer)) { + return WRITE_SUCCESS; + } + return WRITE_FAILED; +} + +bool SimpleDataProducer::WriteCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + auto it = crypto_buffer_map_.find(std::make_pair(level, offset)); + if (it == crypto_buffer_map_.end() || it->second.length() < data_length) { + return false; + } + return writer->WriteStringPiece( + absl::string_view(it->second.data(), data_length)); +} + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/test_tools/simple_data_producer.h b/quiche/quic/test_tools/simple_data_producer.h new file mode 100644 index 000000000000..96952d744a87 --- /dev/null +++ b/quiche/quic/test_tools/simple_data_producer.h @@ -0,0 +1,73 @@ +// Copyright (c) 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMPLE_DATA_PRODUCER_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMPLE_DATA_PRODUCER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_stream_send_buffer.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { + +namespace test { + +// A simple data producer which copies stream data into a map from stream +// id to send buffer. +class SimpleDataProducer : public QuicStreamFrameDataProducer { + public: + SimpleDataProducer(); + ~SimpleDataProducer() override; + + // Saves `data` to be provided when WriteStreamData() is called. Multiple + // calls to SaveStreamData() for the same stream ID append to the buffer for + // that stream. + void SaveStreamData(QuicStreamId id, absl::string_view data); + + void SaveCryptoData(EncryptionLevel level, QuicStreamOffset offset, + absl::string_view data); + + // QuicStreamFrameDataProducer + WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + + private: + using SendBufferMap = + absl::flat_hash_map>; + + using CryptoBufferMap = + absl::flat_hash_map, + std::string>; + + quiche::SimpleBufferAllocator allocator_; + + SendBufferMap send_buffer_map_; + + // |crypto_buffer_map_| stores data provided by SaveCryptoData to later write + // in WriteCryptoData. The level and data passed into SaveCryptoData are used + // as the key to identify the data when WriteCryptoData is called. + // WriteCryptoData will only succeed if there is data in the map for the + // provided level and offset, and the data in the map matches the data_length + // passed into WriteCryptoData. + // + // Unlike SaveStreamData/WriteStreamData which uses a map of + // QuicStreamSendBuffers (for each stream ID), this map provides data for + // specific offsets. Using a QuicStreamSendBuffer requires that all data + // before an offset exist, whereas this allows providing data that exists at + // arbitrary offsets for testing. + CryptoBufferMap crypto_buffer_map_; +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMPLE_DATA_PRODUCER_H_ diff --git a/quiche/quic/test_tools/simple_quic_framer.cc b/quiche/quic/test_tools/simple_quic_framer.cc new file mode 100644 index 000000000000..33a56effc29e --- /dev/null +++ b/quiche/quic/test_tools/simple_quic_framer.cc @@ -0,0 +1,439 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simple_quic_framer.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_decrypter.h" +#include "quiche/quic/core/crypto/quic_encrypter.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { +namespace test { + +class SimpleFramerVisitor : public QuicFramerVisitorInterface { + public: + SimpleFramerVisitor() : error_(QUIC_NO_ERROR) {} + SimpleFramerVisitor(const SimpleFramerVisitor&) = delete; + SimpleFramerVisitor& operator=(const SimpleFramerVisitor&) = delete; + + ~SimpleFramerVisitor() override {} + + void OnError(QuicFramer* framer) override { error_ = framer->error(); } + + bool OnProtocolVersionMismatch(ParsedQuicVersion /*version*/) override { + return false; + } + + void OnPacket() override {} + void OnPublicResetPacket(const QuicPublicResetPacket& packet) override { + public_reset_packet_ = std::make_unique((packet)); + } + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& packet) override { + version_negotiation_packet_ = + std::make_unique((packet)); + } + + void OnRetryPacket(QuicConnectionId /*original_connection_id*/, + QuicConnectionId /*new_connection_id*/, + absl::string_view /*retry_token*/, + absl::string_view /*retry_integrity_tag*/, + absl::string_view /*retry_without_tag*/) override {} + + bool OnUnauthenticatedPublicHeader( + const QuicPacketHeader& /*header*/) override { + return true; + } + bool OnUnauthenticatedHeader(const QuicPacketHeader& /*header*/) override { + return true; + } + void OnDecryptedPacket(size_t /*length*/, EncryptionLevel level) override { + last_decrypted_level_ = level; + } + bool OnPacketHeader(const QuicPacketHeader& header) override { + has_header_ = true; + header_ = header; + return true; + } + + void OnCoalescedPacket(const QuicEncryptedPacket& packet) override { + coalesced_packet_ = packet.Clone(); + } + + void OnUndecryptablePacket(const QuicEncryptedPacket& /*packet*/, + EncryptionLevel /*decryption_level*/, + bool /*has_decryption_key*/) override {} + + bool OnStreamFrame(const QuicStreamFrame& frame) override { + // Save a copy of the data so it is valid after the packet is processed. + std::string* string_data = + new std::string(frame.data_buffer, frame.data_length); + stream_data_.push_back(absl::WrapUnique(string_data)); + // TODO(ianswett): A pointer isn't necessary with emplace_back. + stream_frames_.push_back(std::make_unique( + frame.stream_id, frame.fin, frame.offset, + absl::string_view(*string_data))); + return true; + } + + bool OnCryptoFrame(const QuicCryptoFrame& frame) override { + // Save a copy of the data so it is valid after the packet is processed. + std::string* string_data = + new std::string(frame.data_buffer, frame.data_length); + crypto_data_.push_back(absl::WrapUnique(string_data)); + crypto_frames_.push_back(std::make_unique( + frame.level, frame.offset, absl::string_view(*string_data))); + return true; + } + + bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta ack_delay_time) override { + QuicAckFrame ack_frame; + ack_frame.largest_acked = largest_acked; + ack_frame.ack_delay_time = ack_delay_time; + ack_frames_.push_back(ack_frame); + return true; + } + + bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) override { + QUICHE_DCHECK(!ack_frames_.empty()); + ack_frames_[ack_frames_.size() - 1].packets.AddRange(start, end); + return true; + } + + bool OnAckTimestamp(QuicPacketNumber /*packet_number*/, + QuicTime /*timestamp*/) override { + return true; + } + + bool OnAckFrameEnd( + QuicPacketNumber /*start*/, + const absl::optional& /*ecn_counts*/) override { + return true; + } + + bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) override { + stop_waiting_frames_.push_back(frame); + return true; + } + + bool OnPaddingFrame(const QuicPaddingFrame& frame) override { + padding_frames_.push_back(frame); + return true; + } + + bool OnPingFrame(const QuicPingFrame& frame) override { + ping_frames_.push_back(frame); + return true; + } + + bool OnRstStreamFrame(const QuicRstStreamFrame& frame) override { + rst_stream_frames_.push_back(frame); + return true; + } + + bool OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override { + connection_close_frames_.push_back(frame); + return true; + } + + bool OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame) override { + new_connection_id_frames_.push_back(frame); + return true; + } + + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) override { + retire_connection_id_frames_.push_back(frame); + return true; + } + + bool OnNewTokenFrame(const QuicNewTokenFrame& frame) override { + new_token_frames_.push_back(frame); + return true; + } + + bool OnStopSendingFrame(const QuicStopSendingFrame& frame) override { + stop_sending_frames_.push_back(frame); + return true; + } + + bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) override { + path_challenge_frames_.push_back(frame); + return true; + } + + bool OnPathResponseFrame(const QuicPathResponseFrame& frame) override { + path_response_frames_.push_back(frame); + return true; + } + + bool OnGoAwayFrame(const QuicGoAwayFrame& frame) override { + goaway_frames_.push_back(frame); + return true; + } + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override { + max_streams_frames_.push_back(frame); + return true; + } + + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override { + streams_blocked_frames_.push_back(frame); + return true; + } + + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override { + window_update_frames_.push_back(frame); + return true; + } + + bool OnBlockedFrame(const QuicBlockedFrame& frame) override { + blocked_frames_.push_back(frame); + return true; + } + + bool OnMessageFrame(const QuicMessageFrame& frame) override { + message_frames_.emplace_back(frame.data, frame.message_length); + return true; + } + + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) override { + handshake_done_frames_.push_back(frame); + return true; + } + + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) override { + ack_frequency_frames_.push_back(frame); + return true; + } + + void OnPacketComplete() override {} + + bool IsValidStatelessResetToken( + const StatelessResetToken& /*token*/) const override { + return false; + } + + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& packet) override { + stateless_reset_packet_ = + std::make_unique(packet); + } + + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} + void OnDecryptedFirstPacketInKeyPhase() override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + + const QuicPacketHeader& header() const { return header_; } + const std::vector& ack_frames() const { return ack_frames_; } + const std::vector& connection_close_frames() const { + return connection_close_frames_; + } + + const std::vector& goaway_frames() const { + return goaway_frames_; + } + const std::vector& max_streams_frames() const { + return max_streams_frames_; + } + const std::vector& streams_blocked_frames() const { + return streams_blocked_frames_; + } + const std::vector& rst_stream_frames() const { + return rst_stream_frames_; + } + const std::vector>& stream_frames() const { + return stream_frames_; + } + const std::vector>& crypto_frames() const { + return crypto_frames_; + } + const std::vector& stop_waiting_frames() const { + return stop_waiting_frames_; + } + const std::vector& ping_frames() const { return ping_frames_; } + const std::vector& message_frames() const { + return message_frames_; + } + const std::vector& window_update_frames() const { + return window_update_frames_; + } + const std::vector& padding_frames() const { + return padding_frames_; + } + const std::vector& path_challenge_frames() const { + return path_challenge_frames_; + } + const std::vector& path_response_frames() const { + return path_response_frames_; + } + const QuicVersionNegotiationPacket* version_negotiation_packet() const { + return version_negotiation_packet_.get(); + } + EncryptionLevel last_decrypted_level() const { return last_decrypted_level_; } + const QuicEncryptedPacket* coalesced_packet() const { + return coalesced_packet_.get(); + } + + private: + QuicErrorCode error_; + bool has_header_; + QuicPacketHeader header_; + std::unique_ptr version_negotiation_packet_; + std::unique_ptr public_reset_packet_; + std::unique_ptr stateless_reset_packet_; + std::vector ack_frames_; + std::vector stop_waiting_frames_; + std::vector padding_frames_; + std::vector ping_frames_; + std::vector> stream_frames_; + std::vector> crypto_frames_; + std::vector rst_stream_frames_; + std::vector goaway_frames_; + std::vector streams_blocked_frames_; + std::vector max_streams_frames_; + std::vector connection_close_frames_; + std::vector stop_sending_frames_; + std::vector path_challenge_frames_; + std::vector path_response_frames_; + std::vector window_update_frames_; + std::vector blocked_frames_; + std::vector new_connection_id_frames_; + std::vector retire_connection_id_frames_; + std::vector new_token_frames_; + std::vector message_frames_; + std::vector handshake_done_frames_; + std::vector ack_frequency_frames_; + std::vector> stream_data_; + std::vector> crypto_data_; + EncryptionLevel last_decrypted_level_; + std::unique_ptr coalesced_packet_; +}; + +SimpleQuicFramer::SimpleQuicFramer() + : framer_(AllSupportedVersions(), QuicTime::Zero(), Perspective::IS_SERVER, + kQuicDefaultConnectionIdLength) {} + +SimpleQuicFramer::SimpleQuicFramer( + const ParsedQuicVersionVector& supported_versions) + : framer_(supported_versions, QuicTime::Zero(), Perspective::IS_SERVER, + kQuicDefaultConnectionIdLength) {} + +SimpleQuicFramer::SimpleQuicFramer( + const ParsedQuicVersionVector& supported_versions, Perspective perspective) + : framer_(supported_versions, QuicTime::Zero(), perspective, + kQuicDefaultConnectionIdLength) {} + +SimpleQuicFramer::~SimpleQuicFramer() {} + +bool SimpleQuicFramer::ProcessPacket(const QuicEncryptedPacket& packet) { + visitor_ = std::make_unique(); + framer_.set_visitor(visitor_.get()); + return framer_.ProcessPacket(packet); +} + +void SimpleQuicFramer::Reset() { + visitor_ = std::make_unique(); +} + +const QuicPacketHeader& SimpleQuicFramer::header() const { + return visitor_->header(); +} + +const QuicVersionNegotiationPacket* +SimpleQuicFramer::version_negotiation_packet() const { + return visitor_->version_negotiation_packet(); +} + +EncryptionLevel SimpleQuicFramer::last_decrypted_level() const { + return visitor_->last_decrypted_level(); +} + +QuicFramer* SimpleQuicFramer::framer() { return &framer_; } + +size_t SimpleQuicFramer::num_frames() const { + return ack_frames().size() + goaway_frames().size() + + rst_stream_frames().size() + stop_waiting_frames().size() + + path_challenge_frames().size() + path_response_frames().size() + + stream_frames().size() + ping_frames().size() + + connection_close_frames().size() + padding_frames().size() + + crypto_frames().size(); +} + +const std::vector& SimpleQuicFramer::ack_frames() const { + return visitor_->ack_frames(); +} + +const std::vector& SimpleQuicFramer::stop_waiting_frames() + const { + return visitor_->stop_waiting_frames(); +} + +const std::vector& +SimpleQuicFramer::path_challenge_frames() const { + return visitor_->path_challenge_frames(); +} +const std::vector& +SimpleQuicFramer::path_response_frames() const { + return visitor_->path_response_frames(); +} + +const std::vector& SimpleQuicFramer::ping_frames() const { + return visitor_->ping_frames(); +} + +const std::vector& SimpleQuicFramer::message_frames() const { + return visitor_->message_frames(); +} + +const std::vector& +SimpleQuicFramer::window_update_frames() const { + return visitor_->window_update_frames(); +} + +const std::vector>& +SimpleQuicFramer::stream_frames() const { + return visitor_->stream_frames(); +} + +const std::vector>& +SimpleQuicFramer::crypto_frames() const { + return visitor_->crypto_frames(); +} + +const std::vector& SimpleQuicFramer::rst_stream_frames() + const { + return visitor_->rst_stream_frames(); +} + +const std::vector& SimpleQuicFramer::goaway_frames() const { + return visitor_->goaway_frames(); +} + +const std::vector& +SimpleQuicFramer::connection_close_frames() const { + return visitor_->connection_close_frames(); +} + +const std::vector& SimpleQuicFramer::padding_frames() const { + return visitor_->padding_frames(); +} + +const QuicEncryptedPacket* SimpleQuicFramer::coalesced_packet() const { + return visitor_->coalesced_packet(); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/simple_quic_framer.h b/quiche/quic/test_tools/simple_quic_framer.h new file mode 100644 index 000000000000..a748f4c400b0 --- /dev/null +++ b/quiche/quic/test_tools/simple_quic_framer.h @@ -0,0 +1,70 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMPLE_QUIC_FRAMER_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMPLE_QUIC_FRAMER_H_ + +#include +#include + +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_packets.h" + +namespace quic { + +struct QuicAckFrame; + +namespace test { + +class SimpleFramerVisitor; + +// Peer to make public a number of otherwise private QuicFramer methods. +class SimpleQuicFramer { + public: + SimpleQuicFramer(); + explicit SimpleQuicFramer(const ParsedQuicVersionVector& supported_versions); + SimpleQuicFramer(const ParsedQuicVersionVector& supported_versions, + Perspective perspective); + SimpleQuicFramer(const SimpleQuicFramer&) = delete; + SimpleQuicFramer& operator=(const SimpleQuicFramer&) = delete; + ~SimpleQuicFramer(); + + bool ProcessPacket(const QuicEncryptedPacket& packet); + void Reset(); + + const QuicPacketHeader& header() const; + size_t num_frames() const; + const std::vector& ack_frames() const; + const std::vector& connection_close_frames() const; + const std::vector& stop_waiting_frames() const; + const std::vector& path_challenge_frames() const; + const std::vector& path_response_frames() const; + const std::vector& ping_frames() const; + const std::vector& message_frames() const; + const std::vector& window_update_frames() const; + const std::vector& goaway_frames() const; + const std::vector& rst_stream_frames() const; + const std::vector>& stream_frames() const; + const std::vector>& crypto_frames() const; + const std::vector& padding_frames() const; + const QuicVersionNegotiationPacket* version_negotiation_packet() const; + EncryptionLevel last_decrypted_level() const; + const QuicEncryptedPacket* coalesced_packet() const; + + QuicFramer* framer(); + + void SetSupportedVersions(const ParsedQuicVersionVector& versions) { + framer_.SetSupportedVersions(versions); + } + + private: + QuicFramer framer_; + std::unique_ptr visitor_; +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMPLE_QUIC_FRAMER_H_ diff --git a/quiche/quic/test_tools/simple_session_cache.cc b/quiche/quic/test_tools/simple_session_cache.cc new file mode 100644 index 000000000000..05f433d5388c --- /dev/null +++ b/quiche/quic/test_tools/simple_session_cache.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simple_session_cache.h" + +#include + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" + +namespace quic { +namespace test { + +void SimpleSessionCache::Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) { + auto it = cache_entries_.find(server_id); + if (it == cache_entries_.end()) { + it = cache_entries_.insert(std::make_pair(server_id, Entry())).first; + } + if (session != nullptr) { + it->second.session = std::move(session); + } + if (application_state != nullptr) { + it->second.application_state = + std::make_unique(*application_state); + } + it->second.params = std::make_unique(params); +} + +std::unique_ptr SimpleSessionCache::Lookup( + const QuicServerId& server_id, QuicWallTime /*now*/, + const SSL_CTX* /*ctx*/) { + auto it = cache_entries_.find(server_id); + if (it == cache_entries_.end()) { + return nullptr; + } + + if (!it->second.session) { + cache_entries_.erase(it); + return nullptr; + } + + auto state = std::make_unique(); + state->tls_session = std::move(it->second.session); + if (it->second.application_state != nullptr) { + state->application_state = + std::make_unique(*it->second.application_state); + } + state->transport_params = + std::make_unique(*it->second.params); + state->token = it->second.token; + return state; +} + +void SimpleSessionCache::ClearEarlyData(const QuicServerId& /*server_id*/) { + // The simple session cache only stores 1 SSL ticket per entry, so no need to + // do anything here. +} + +void SimpleSessionCache::OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) { + auto it = cache_entries_.find(server_id); + if (it == cache_entries_.end()) { + return; + } + it->second.token = std::string(token); +} + +void SimpleSessionCache::RemoveExpiredEntries(QuicWallTime /*now*/) { + // The simple session cache does not support removing expired entries. +} + +void SimpleSessionCache::Clear() { cache_entries_.clear(); } + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/simple_session_cache.h b/quiche/quic/test_tools/simple_session_cache.h new file mode 100644 index 000000000000..90558007ddb4 --- /dev/null +++ b/quiche/quic/test_tools/simple_session_cache.h @@ -0,0 +1,53 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_CACHE_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_CACHE_H_ + +#include + +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/crypto/transport_parameters.h" + +namespace quic { +namespace test { + +// SimpleSessionCache provides a simple implementation of SessionCache that +// stores only one QuicResumptionState per QuicServerId. No limit is placed on +// the total number of entries in the cache. When Lookup is called, if a cache +// entry exists for the provided QuicServerId, the entry will be removed from +// the cached when it is returned. +// TODO(fayang): Remove SimpleSessionCache by using QuicClientSessionCache. +class SimpleSessionCache : public SessionCache { + public: + SimpleSessionCache() = default; + ~SimpleSessionCache() override = default; + + void Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) override; + std::unique_ptr Lookup(const QuicServerId& server_id, + QuicWallTime now, + const SSL_CTX* ctx) override; + void ClearEarlyData(const QuicServerId& server_id) override; + void OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) override; + void RemoveExpiredEntries(QuicWallTime now) override; + void Clear() override; + + private: + struct Entry { + bssl::UniquePtr session; + std::unique_ptr params; + std::unique_ptr application_state; + std::string token; + }; + std::map cache_entries_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_CACHE_H_ diff --git a/quiche/quic/test_tools/simple_session_notifier.cc b/quiche/quic/test_tools/simple_session_notifier.cc new file mode 100644 index 000000000000..7a2e70581564 --- /dev/null +++ b/quiche/quic/test_tools/simple_session_notifier.cc @@ -0,0 +1,768 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simple_session_notifier.h" + +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/test_tools/quic_test_utils.h" + +namespace quic { + +namespace test { + +SimpleSessionNotifier::SimpleSessionNotifier(QuicConnection* connection) + : last_control_frame_id_(kInvalidControlFrameId), + least_unacked_(1), + least_unsent_(1), + connection_(connection) {} + +SimpleSessionNotifier::~SimpleSessionNotifier() { + while (!control_frames_.empty()) { + DeleteFrame(&control_frames_.front()); + control_frames_.pop_front(); + } +} + +SimpleSessionNotifier::StreamState::StreamState() + : bytes_total(0), + bytes_sent(0), + fin_buffered(false), + fin_sent(false), + fin_outstanding(false), + fin_lost(false) {} + +SimpleSessionNotifier::StreamState::~StreamState() {} + +QuicConsumedData SimpleSessionNotifier::WriteOrBufferData( + QuicStreamId id, QuicByteCount data_length, StreamSendingState state) { + return WriteOrBufferData(id, data_length, state, NOT_RETRANSMISSION); +} + +QuicConsumedData SimpleSessionNotifier::WriteOrBufferData( + QuicStreamId id, QuicByteCount data_length, StreamSendingState state, + TransmissionType transmission_type) { + if (!stream_map_.contains(id)) { + stream_map_[id] = StreamState(); + } + StreamState& stream_state = stream_map_.find(id)->second; + const bool had_buffered_data = + HasBufferedStreamData() || HasBufferedControlFrames(); + QuicStreamOffset offset = stream_state.bytes_sent; + QUIC_DVLOG(1) << "WriteOrBuffer stream_id: " << id << " [" << offset << ", " + << offset + data_length << "), fin: " << (state != NO_FIN); + stream_state.bytes_total += data_length; + stream_state.fin_buffered = state != NO_FIN; + if (had_buffered_data) { + QUIC_DLOG(WARNING) << "Connection is write blocked"; + return {0, false}; + } + const size_t length = stream_state.bytes_total - stream_state.bytes_sent; + connection_->SetTransmissionType(transmission_type); + QuicConsumedData consumed = + connection_->SendStreamData(id, length, stream_state.bytes_sent, state); + QUIC_DVLOG(1) << "consumed: " << consumed; + OnStreamDataConsumed(id, stream_state.bytes_sent, consumed.bytes_consumed, + consumed.fin_consumed); + return consumed; +} + +void SimpleSessionNotifier::OnStreamDataConsumed(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + bool fin) { + StreamState& state = stream_map_.find(id)->second; + if (QuicUtils::IsCryptoStreamId(connection_->transport_version(), id) && + data_length > 0) { + crypto_bytes_transferred_[connection_->encryption_level()].Add( + offset, offset + data_length); + } + state.bytes_sent += data_length; + state.fin_sent = fin; + state.fin_outstanding = fin; +} + +size_t SimpleSessionNotifier::WriteCryptoData(EncryptionLevel level, + QuicByteCount data_length, + QuicStreamOffset offset) { + crypto_state_[level].bytes_total += data_length; + size_t bytes_written = + connection_->SendCryptoData(level, data_length, offset); + crypto_state_[level].bytes_sent += bytes_written; + crypto_bytes_transferred_[level].Add(offset, offset + bytes_written); + return bytes_written; +} + +void SimpleSessionNotifier::WriteOrBufferRstStream( + QuicStreamId id, QuicRstStreamErrorCode error, + QuicStreamOffset bytes_written) { + QUIC_DVLOG(1) << "Writing RST_STREAM_FRAME"; + const bool had_buffered_data = + HasBufferedStreamData() || HasBufferedControlFrames(); + control_frames_.emplace_back((QuicFrame(new QuicRstStreamFrame( + ++last_control_frame_id_, id, error, bytes_written)))); + if (error != QUIC_STREAM_NO_ERROR) { + // Delete stream to avoid retransmissions. + stream_map_.erase(id); + } + if (had_buffered_data) { + QUIC_DLOG(WARNING) << "Connection is write blocked"; + return; + } + WriteBufferedControlFrames(); +} + +void SimpleSessionNotifier::WriteOrBufferWindowUpate( + QuicStreamId id, QuicStreamOffset byte_offset) { + QUIC_DVLOG(1) << "Writing WINDOW_UPDATE"; + const bool had_buffered_data = + HasBufferedStreamData() || HasBufferedControlFrames(); + QuicControlFrameId control_frame_id = ++last_control_frame_id_; + control_frames_.emplace_back( + (QuicFrame(QuicWindowUpdateFrame(control_frame_id, id, byte_offset)))); + if (had_buffered_data) { + QUIC_DLOG(WARNING) << "Connection is write blocked"; + return; + } + WriteBufferedControlFrames(); +} + +void SimpleSessionNotifier::WriteOrBufferPing() { + QUIC_DVLOG(1) << "Writing PING_FRAME"; + const bool had_buffered_data = + HasBufferedStreamData() || HasBufferedControlFrames(); + control_frames_.emplace_back( + (QuicFrame(QuicPingFrame(++last_control_frame_id_)))); + if (had_buffered_data) { + QUIC_DLOG(WARNING) << "Connection is write blocked"; + return; + } + WriteBufferedControlFrames(); +} + +void SimpleSessionNotifier::WriteOrBufferAckFrequency( + const QuicAckFrequencyFrame& ack_frequency_frame) { + QUIC_DVLOG(1) << "Writing ACK_FREQUENCY"; + const bool had_buffered_data = + HasBufferedStreamData() || HasBufferedControlFrames(); + QuicControlFrameId control_frame_id = ++last_control_frame_id_; + control_frames_.emplace_back(( + QuicFrame(new QuicAckFrequencyFrame(control_frame_id, + /*sequence_number=*/control_frame_id, + ack_frequency_frame.packet_tolerance, + ack_frequency_frame.max_ack_delay)))); + if (had_buffered_data) { + QUIC_DLOG(WARNING) << "Connection is write blocked"; + return; + } + WriteBufferedControlFrames(); +} + +void SimpleSessionNotifier::NeuterUnencryptedData() { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + for (const auto& interval : crypto_bytes_transferred_[ENCRYPTION_INITIAL]) { + QuicCryptoFrame crypto_frame(ENCRYPTION_INITIAL, interval.min(), + interval.max() - interval.min()); + OnFrameAcked(QuicFrame(&crypto_frame), QuicTime::Delta::Zero(), + QuicTime::Zero()); + } + return; + } + for (const auto& interval : crypto_bytes_transferred_[ENCRYPTION_INITIAL]) { + QuicStreamFrame stream_frame( + QuicUtils::GetCryptoStreamId(connection_->transport_version()), false, + interval.min(), interval.max() - interval.min()); + OnFrameAcked(QuicFrame(stream_frame), QuicTime::Delta::Zero(), + QuicTime::Zero()); + } +} + +void SimpleSessionNotifier::OnCanWrite() { + if (connection_->framer().is_processing_packet()) { + // Do not write data in the middle of packet processing because rest + // frames in the packet may change the data to write. For example, lost + // data could be acknowledged. Also, connection is going to emit + // OnCanWrite signal post packet processing. + QUIC_BUG(simple_notifier_write_mid_packet_processing) + << "Try to write mid packet processing."; + return; + } + if (!RetransmitLostCryptoData() || !RetransmitLostControlFrames() || + !RetransmitLostStreamData()) { + return; + } + if (!WriteBufferedCryptoData() || !WriteBufferedControlFrames()) { + return; + } + // Write new data. + for (const auto& pair : stream_map_) { + const auto& state = pair.second; + if (!StreamHasBufferedData(pair.first)) { + continue; + } + + const size_t length = state.bytes_total - state.bytes_sent; + const bool can_bundle_fin = + state.fin_buffered && (state.bytes_sent + length == state.bytes_total); + connection_->SetTransmissionType(NOT_RETRANSMISSION); + QuicConnection::ScopedEncryptionLevelContext context( + connection_, + connection_->framer().GetEncryptionLevelToSendApplicationData()); + QuicConsumedData consumed = connection_->SendStreamData( + pair.first, length, state.bytes_sent, can_bundle_fin ? FIN : NO_FIN); + QUIC_DVLOG(1) << "Tries to write stream_id: " << pair.first << " [" + << state.bytes_sent << ", " << state.bytes_sent + length + << "), fin: " << can_bundle_fin + << ", and consumed: " << consumed; + OnStreamDataConsumed(pair.first, state.bytes_sent, consumed.bytes_consumed, + consumed.fin_consumed); + if (length != consumed.bytes_consumed || + (can_bundle_fin && !consumed.fin_consumed)) { + break; + } + } +} + +void SimpleSessionNotifier::OnStreamReset(QuicStreamId id, + QuicRstStreamErrorCode error) { + if (error != QUIC_STREAM_NO_ERROR) { + // Delete stream to avoid retransmissions. + stream_map_.erase(id); + } +} + +bool SimpleSessionNotifier::WillingToWrite() const { + QUIC_DVLOG(1) << "has_buffered_control_frames: " << HasBufferedControlFrames() + << " as_lost_control_frames: " << !lost_control_frames_.empty() + << " has_buffered_stream_data: " << HasBufferedStreamData() + << " has_lost_stream_data: " << HasLostStreamData(); + return HasBufferedControlFrames() || !lost_control_frames_.empty() || + HasBufferedStreamData() || HasLostStreamData(); +} + +QuicByteCount SimpleSessionNotifier::StreamBytesSent() const { + QuicByteCount bytes_sent = 0; + for (const auto& pair : stream_map_) { + const auto& state = pair.second; + bytes_sent += state.bytes_sent; + } + return bytes_sent; +} + +QuicByteCount SimpleSessionNotifier::StreamBytesToSend() const { + QuicByteCount bytes_to_send = 0; + for (const auto& pair : stream_map_) { + const auto& state = pair.second; + bytes_to_send += (state.bytes_total - state.bytes_sent); + } + return bytes_to_send; +} + +bool SimpleSessionNotifier::OnFrameAcked(const QuicFrame& frame, + QuicTime::Delta /*ack_delay_time*/, + QuicTime /*receive_timestamp*/) { + QUIC_DVLOG(1) << "Acking " << frame; + if (frame.type == CRYPTO_FRAME) { + StreamState* state = &crypto_state_[frame.crypto_frame->level]; + QuicStreamOffset offset = frame.crypto_frame->offset; + QuicByteCount data_length = frame.crypto_frame->data_length; + QuicIntervalSet newly_acked(offset, offset + data_length); + newly_acked.Difference(state->bytes_acked); + if (newly_acked.Empty()) { + return false; + } + state->bytes_acked.Add(offset, offset + data_length); + state->pending_retransmissions.Difference(offset, offset + data_length); + return true; + } + if (frame.type != STREAM_FRAME) { + return OnControlFrameAcked(frame); + } + if (!stream_map_.contains(frame.stream_frame.stream_id)) { + return false; + } + auto* state = &stream_map_.find(frame.stream_frame.stream_id)->second; + QuicStreamOffset offset = frame.stream_frame.offset; + QuicByteCount data_length = frame.stream_frame.data_length; + QuicIntervalSet newly_acked(offset, offset + data_length); + newly_acked.Difference(state->bytes_acked); + const bool fin_newly_acked = frame.stream_frame.fin && state->fin_outstanding; + if (newly_acked.Empty() && !fin_newly_acked) { + return false; + } + state->bytes_acked.Add(offset, offset + data_length); + if (fin_newly_acked) { + state->fin_outstanding = false; + state->fin_lost = false; + } + state->pending_retransmissions.Difference(offset, offset + data_length); + return true; +} + +void SimpleSessionNotifier::OnFrameLost(const QuicFrame& frame) { + QUIC_DVLOG(1) << "Losting " << frame; + if (frame.type == CRYPTO_FRAME) { + StreamState* state = &crypto_state_[frame.crypto_frame->level]; + QuicStreamOffset offset = frame.crypto_frame->offset; + QuicByteCount data_length = frame.crypto_frame->data_length; + QuicIntervalSet bytes_lost(offset, offset + data_length); + bytes_lost.Difference(state->bytes_acked); + if (bytes_lost.Empty()) { + return; + } + for (const auto& lost : bytes_lost) { + state->pending_retransmissions.Add(lost.min(), lost.max()); + } + return; + } + if (frame.type != STREAM_FRAME) { + OnControlFrameLost(frame); + return; + } + if (!stream_map_.contains(frame.stream_frame.stream_id)) { + return; + } + auto* state = &stream_map_.find(frame.stream_frame.stream_id)->second; + QuicStreamOffset offset = frame.stream_frame.offset; + QuicByteCount data_length = frame.stream_frame.data_length; + QuicIntervalSet bytes_lost(offset, offset + data_length); + bytes_lost.Difference(state->bytes_acked); + const bool fin_lost = state->fin_outstanding && frame.stream_frame.fin; + if (bytes_lost.Empty() && !fin_lost) { + return; + } + for (const auto& lost : bytes_lost) { + state->pending_retransmissions.Add(lost.min(), lost.max()); + } + state->fin_lost = fin_lost; +} + +bool SimpleSessionNotifier::RetransmitFrames(const QuicFrames& frames, + TransmissionType type) { + QuicConnection::ScopedPacketFlusher retransmission_flusher(connection_); + connection_->SetTransmissionType(type); + for (const QuicFrame& frame : frames) { + if (frame.type == CRYPTO_FRAME) { + const StreamState& state = crypto_state_[frame.crypto_frame->level]; + const EncryptionLevel current_encryption_level = + connection_->encryption_level(); + QuicIntervalSet retransmission( + frame.crypto_frame->offset, + frame.crypto_frame->offset + frame.crypto_frame->data_length); + retransmission.Difference(state.bytes_acked); + for (const auto& interval : retransmission) { + QuicStreamOffset offset = interval.min(); + QuicByteCount length = interval.max() - interval.min(); + connection_->SetDefaultEncryptionLevel(frame.crypto_frame->level); + size_t consumed = connection_->SendCryptoData(frame.crypto_frame->level, + length, offset); + if (consumed < length) { + return false; + } + } + connection_->SetDefaultEncryptionLevel(current_encryption_level); + } + if (frame.type != STREAM_FRAME) { + if (GetControlFrameId(frame) == kInvalidControlFrameId) { + continue; + } + QuicFrame copy = CopyRetransmittableControlFrame(frame); + if (!connection_->SendControlFrame(copy)) { + // Connection is write blocked. + DeleteFrame(©); + return false; + } + continue; + } + if (!stream_map_.contains(frame.stream_frame.stream_id)) { + continue; + } + const auto& state = stream_map_.find(frame.stream_frame.stream_id)->second; + QuicIntervalSet retransmission( + frame.stream_frame.offset, + frame.stream_frame.offset + frame.stream_frame.data_length); + EncryptionLevel retransmission_encryption_level = + connection_->encryption_level(); + if (QuicUtils::IsCryptoStreamId(connection_->transport_version(), + frame.stream_frame.stream_id)) { + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + if (retransmission.Intersects(crypto_bytes_transferred_[i])) { + retransmission_encryption_level = static_cast(i); + retransmission.Intersection(crypto_bytes_transferred_[i]); + break; + } + } + } + retransmission.Difference(state.bytes_acked); + bool retransmit_fin = frame.stream_frame.fin && state.fin_outstanding; + QuicConsumedData consumed(0, false); + for (const auto& interval : retransmission) { + QuicStreamOffset retransmission_offset = interval.min(); + QuicByteCount retransmission_length = interval.max() - interval.min(); + const bool can_bundle_fin = + retransmit_fin && + (retransmission_offset + retransmission_length == state.bytes_sent); + QuicConnection::ScopedEncryptionLevelContext context( + connection_, + QuicUtils::IsCryptoStreamId(connection_->transport_version(), + frame.stream_frame.stream_id) + ? retransmission_encryption_level + : connection_->framer() + .GetEncryptionLevelToSendApplicationData()); + consumed = connection_->SendStreamData( + frame.stream_frame.stream_id, retransmission_length, + retransmission_offset, can_bundle_fin ? FIN : NO_FIN); + QUIC_DVLOG(1) << "stream " << frame.stream_frame.stream_id + << " is forced to retransmit stream data [" + << retransmission_offset << ", " + << retransmission_offset + retransmission_length + << ") and fin: " << can_bundle_fin + << ", consumed: " << consumed; + if (can_bundle_fin) { + retransmit_fin = !consumed.fin_consumed; + } + if (consumed.bytes_consumed < retransmission_length || + (can_bundle_fin && !consumed.fin_consumed)) { + // Connection is write blocked. + return false; + } + } + if (retransmit_fin) { + QUIC_DVLOG(1) << "stream " << frame.stream_frame.stream_id + << " retransmits fin only frame."; + consumed = connection_->SendStreamData(frame.stream_frame.stream_id, 0, + state.bytes_sent, FIN); + if (!consumed.fin_consumed) { + return false; + } + } + } + return true; +} + +bool SimpleSessionNotifier::IsFrameOutstanding(const QuicFrame& frame) const { + if (frame.type == CRYPTO_FRAME) { + QuicStreamOffset offset = frame.crypto_frame->offset; + QuicByteCount data_length = frame.crypto_frame->data_length; + bool ret = data_length > 0 && + !crypto_state_[frame.crypto_frame->level].bytes_acked.Contains( + offset, offset + data_length); + return ret; + } + if (frame.type != STREAM_FRAME) { + return IsControlFrameOutstanding(frame); + } + if (!stream_map_.contains(frame.stream_frame.stream_id)) { + return false; + } + const auto& state = stream_map_.find(frame.stream_frame.stream_id)->second; + QuicStreamOffset offset = frame.stream_frame.offset; + QuicByteCount data_length = frame.stream_frame.data_length; + return (data_length > 0 && + !state.bytes_acked.Contains(offset, offset + data_length)) || + (frame.stream_frame.fin && state.fin_outstanding); +} + +bool SimpleSessionNotifier::HasUnackedCryptoData() const { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + const StreamState& state = crypto_state_[i]; + if (state.bytes_total > state.bytes_sent) { + return true; + } + QuicIntervalSet bytes_to_ack(0, state.bytes_total); + bytes_to_ack.Difference(state.bytes_acked); + if (!bytes_to_ack.Empty()) { + return true; + } + } + return false; + } + if (!stream_map_.contains( + QuicUtils::GetCryptoStreamId(connection_->transport_version()))) { + return false; + } + const auto& state = + stream_map_ + .find(QuicUtils::GetCryptoStreamId(connection_->transport_version())) + ->second; + if (state.bytes_total > state.bytes_sent) { + return true; + } + QuicIntervalSet bytes_to_ack(0, state.bytes_total); + bytes_to_ack.Difference(state.bytes_acked); + return !bytes_to_ack.Empty(); +} + +bool SimpleSessionNotifier::HasUnackedStreamData() const { + for (const auto& it : stream_map_) { + if (StreamIsWaitingForAcks(it.first)) return true; + } + return false; +} + +bool SimpleSessionNotifier::OnControlFrameAcked(const QuicFrame& frame) { + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + return false; + } + QUICHE_DCHECK(id < least_unacked_ + control_frames_.size()); + if (id < least_unacked_ || + GetControlFrameId(control_frames_.at(id - least_unacked_)) == + kInvalidControlFrameId) { + return false; + } + SetControlFrameId(kInvalidControlFrameId, + &control_frames_.at(id - least_unacked_)); + lost_control_frames_.erase(id); + while (!control_frames_.empty() && + GetControlFrameId(control_frames_.front()) == kInvalidControlFrameId) { + DeleteFrame(&control_frames_.front()); + control_frames_.pop_front(); + ++least_unacked_; + } + return true; +} + +void SimpleSessionNotifier::OnControlFrameLost(const QuicFrame& frame) { + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + return; + } + QUICHE_DCHECK(id < least_unacked_ + control_frames_.size()); + if (id < least_unacked_ || + GetControlFrameId(control_frames_.at(id - least_unacked_)) == + kInvalidControlFrameId) { + return; + } + if (!lost_control_frames_.contains(id)) { + lost_control_frames_[id] = true; + } +} + +bool SimpleSessionNotifier::IsControlFrameOutstanding( + const QuicFrame& frame) const { + QuicControlFrameId id = GetControlFrameId(frame); + if (id == kInvalidControlFrameId) { + return false; + } + return id < least_unacked_ + control_frames_.size() && id >= least_unacked_ && + GetControlFrameId(control_frames_.at(id - least_unacked_)) != + kInvalidControlFrameId; +} + +bool SimpleSessionNotifier::RetransmitLostControlFrames() { + while (!lost_control_frames_.empty()) { + QuicFrame pending = control_frames_.at(lost_control_frames_.begin()->first - + least_unacked_); + QuicFrame copy = CopyRetransmittableControlFrame(pending); + connection_->SetTransmissionType(LOSS_RETRANSMISSION); + if (!connection_->SendControlFrame(copy)) { + // Connection is write blocked. + DeleteFrame(©); + break; + } + lost_control_frames_.pop_front(); + } + return lost_control_frames_.empty(); +} + +bool SimpleSessionNotifier::RetransmitLostCryptoData() { + if (QuicVersionUsesCryptoFrames(connection_->transport_version())) { + for (EncryptionLevel level : + {ENCRYPTION_INITIAL, ENCRYPTION_HANDSHAKE, ENCRYPTION_ZERO_RTT, + ENCRYPTION_FORWARD_SECURE}) { + auto& state = crypto_state_[level]; + while (!state.pending_retransmissions.Empty()) { + connection_->SetTransmissionType(HANDSHAKE_RETRANSMISSION); + EncryptionLevel current_encryption_level = + connection_->encryption_level(); + connection_->SetDefaultEncryptionLevel(level); + QuicIntervalSet retransmission( + state.pending_retransmissions.begin()->min(), + state.pending_retransmissions.begin()->max()); + retransmission.Intersection(crypto_bytes_transferred_[level]); + QuicStreamOffset retransmission_offset = retransmission.begin()->min(); + QuicByteCount retransmission_length = + retransmission.begin()->max() - retransmission.begin()->min(); + size_t bytes_consumed = connection_->SendCryptoData( + level, retransmission_length, retransmission_offset); + // Restore encryption level. + connection_->SetDefaultEncryptionLevel(current_encryption_level); + state.pending_retransmissions.Difference( + retransmission_offset, retransmission_offset + bytes_consumed); + if (bytes_consumed < retransmission_length) { + return false; + } + } + } + return true; + } + if (!stream_map_.contains( + QuicUtils::GetCryptoStreamId(connection_->transport_version()))) { + return true; + } + auto& state = + stream_map_ + .find(QuicUtils::GetCryptoStreamId(connection_->transport_version())) + ->second; + while (!state.pending_retransmissions.Empty()) { + connection_->SetTransmissionType(HANDSHAKE_RETRANSMISSION); + QuicIntervalSet retransmission( + state.pending_retransmissions.begin()->min(), + state.pending_retransmissions.begin()->max()); + EncryptionLevel retransmission_encryption_level = ENCRYPTION_INITIAL; + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + if (retransmission.Intersects(crypto_bytes_transferred_[i])) { + retransmission_encryption_level = static_cast(i); + retransmission.Intersection(crypto_bytes_transferred_[i]); + break; + } + } + QuicStreamOffset retransmission_offset = retransmission.begin()->min(); + QuicByteCount retransmission_length = + retransmission.begin()->max() - retransmission.begin()->min(); + EncryptionLevel current_encryption_level = connection_->encryption_level(); + // Set appropriate encryption level. + connection_->SetDefaultEncryptionLevel(retransmission_encryption_level); + QuicConsumedData consumed = connection_->SendStreamData( + QuicUtils::GetCryptoStreamId(connection_->transport_version()), + retransmission_length, retransmission_offset, NO_FIN); + // Restore encryption level. + connection_->SetDefaultEncryptionLevel(current_encryption_level); + state.pending_retransmissions.Difference( + retransmission_offset, retransmission_offset + consumed.bytes_consumed); + if (consumed.bytes_consumed < retransmission_length) { + break; + } + } + return state.pending_retransmissions.Empty(); +} + +bool SimpleSessionNotifier::RetransmitLostStreamData() { + for (auto& pair : stream_map_) { + StreamState& state = pair.second; + QuicConsumedData consumed(0, false); + while (!state.pending_retransmissions.Empty() || state.fin_lost) { + connection_->SetTransmissionType(LOSS_RETRANSMISSION); + if (state.pending_retransmissions.Empty()) { + QUIC_DVLOG(1) << "stream " << pair.first + << " retransmits fin only frame."; + consumed = + connection_->SendStreamData(pair.first, 0, state.bytes_sent, FIN); + state.fin_lost = !consumed.fin_consumed; + if (state.fin_lost) { + QUIC_DLOG(INFO) << "Connection is write blocked"; + return false; + } + } else { + QuicStreamOffset offset = state.pending_retransmissions.begin()->min(); + QuicByteCount length = state.pending_retransmissions.begin()->max() - + state.pending_retransmissions.begin()->min(); + const bool can_bundle_fin = + state.fin_lost && (offset + length == state.bytes_sent); + consumed = connection_->SendStreamData(pair.first, length, offset, + can_bundle_fin ? FIN : NO_FIN); + QUIC_DVLOG(1) << "stream " << pair.first + << " tries to retransmit stream data [" << offset << ", " + << offset + length << ") and fin: " << can_bundle_fin + << ", consumed: " << consumed; + state.pending_retransmissions.Difference( + offset, offset + consumed.bytes_consumed); + if (consumed.fin_consumed) { + state.fin_lost = false; + } + if (length > consumed.bytes_consumed || + (can_bundle_fin && !consumed.fin_consumed)) { + QUIC_DVLOG(1) << "Connection is write blocked"; + break; + } + } + } + } + return !HasLostStreamData(); +} + +bool SimpleSessionNotifier::WriteBufferedCryptoData() { + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + const StreamState& state = crypto_state_[i]; + QuicIntervalSet buffered_crypto_data(0, + state.bytes_total); + buffered_crypto_data.Difference(crypto_bytes_transferred_[i]); + for (const auto& interval : buffered_crypto_data) { + size_t bytes_written = connection_->SendCryptoData( + static_cast(i), interval.Length(), interval.min()); + crypto_state_[i].bytes_sent += bytes_written; + crypto_bytes_transferred_[i].Add(interval.min(), + interval.min() + bytes_written); + if (bytes_written < interval.Length()) { + return false; + } + } + } + return true; +} + +bool SimpleSessionNotifier::WriteBufferedControlFrames() { + while (HasBufferedControlFrames()) { + QuicFrame frame_to_send = + control_frames_.at(least_unsent_ - least_unacked_); + QuicFrame copy = CopyRetransmittableControlFrame(frame_to_send); + connection_->SetTransmissionType(NOT_RETRANSMISSION); + if (!connection_->SendControlFrame(copy)) { + // Connection is write blocked. + DeleteFrame(©); + break; + } + ++least_unsent_; + } + return !HasBufferedControlFrames(); +} + +bool SimpleSessionNotifier::HasBufferedControlFrames() const { + return least_unsent_ < least_unacked_ + control_frames_.size(); +} + +bool SimpleSessionNotifier::HasBufferedStreamData() const { + for (const auto& pair : stream_map_) { + const auto& state = pair.second; + if (state.bytes_total > state.bytes_sent || + (state.fin_buffered && !state.fin_sent)) { + return true; + } + } + return false; +} + +bool SimpleSessionNotifier::StreamIsWaitingForAcks(QuicStreamId id) const { + if (!stream_map_.contains(id)) { + return false; + } + const StreamState& state = stream_map_.find(id)->second; + return !state.bytes_acked.Contains(0, state.bytes_sent) || + state.fin_outstanding; +} + +bool SimpleSessionNotifier::StreamHasBufferedData(QuicStreamId id) const { + if (!stream_map_.contains(id)) { + return false; + } + const StreamState& state = stream_map_.find(id)->second; + return state.bytes_total > state.bytes_sent || + (state.fin_buffered && !state.fin_sent); +} + +bool SimpleSessionNotifier::HasLostStreamData() const { + for (const auto& pair : stream_map_) { + const auto& state = pair.second; + if (!state.pending_retransmissions.Empty() || state.fin_lost) { + return true; + } + } + return false; +} + +} // namespace test + +} // namespace quic diff --git a/quiche/quic/test_tools/simple_session_notifier.h b/quiche/quic/test_tools/simple_session_notifier.h new file mode 100644 index 000000000000..f1af586ed15f --- /dev/null +++ b/quiche/quic/test_tools/simple_session_notifier.h @@ -0,0 +1,167 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_NOTIFIER_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_NOTIFIER_H_ + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/quic_interval_set.h" +#include "quiche/quic/core/session_notifier_interface.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +class QuicConnection; + +namespace test { + +// SimpleSessionNotifier implements the basic functionalities of a session, and +// it manages stream data and control frames. +class SimpleSessionNotifier : public SessionNotifierInterface { + public: + explicit SimpleSessionNotifier(QuicConnection* connection); + ~SimpleSessionNotifier() override; + + // Tries to write stream data and returns data consumed. + QuicConsumedData WriteOrBufferData(QuicStreamId id, QuicByteCount data_length, + StreamSendingState state); + QuicConsumedData WriteOrBufferData(QuicStreamId id, QuicByteCount data_length, + StreamSendingState state, + TransmissionType transmission_type); + + // Tries to write RST_STREAM_FRAME. + void WriteOrBufferRstStream(QuicStreamId id, QuicRstStreamErrorCode error, + QuicStreamOffset bytes_written); + + // Tries to write WINDOW_UPDATE. + void WriteOrBufferWindowUpate(QuicStreamId id, QuicStreamOffset byte_offset); + + // Tries to write PING. + void WriteOrBufferPing(); + + // Tries to write ACK_FREQUENCY. + void WriteOrBufferAckFrequency( + const QuicAckFrequencyFrame& ack_frequency_frame); + + // Tries to write CRYPTO data and returns the number of bytes written. + size_t WriteCryptoData(EncryptionLevel level, QuicByteCount data_length, + QuicStreamOffset offset); + + // Neuters unencrypted data of crypto stream. + void NeuterUnencryptedData(); + + // Called when connection_ becomes writable. + void OnCanWrite(); + + // Called to reset stream. + void OnStreamReset(QuicStreamId id, QuicRstStreamErrorCode error); + + // Returns true if there are 1) unsent control frames and stream data, or 2) + // lost control frames and stream data. + bool WillingToWrite() const; + + // Number of sent stream bytes. Please note, this does not count + // retransmissions. + QuicByteCount StreamBytesSent() const; + + // Number of stream bytes waiting to be sent for the first time. + QuicByteCount StreamBytesToSend() const; + + // Returns true if there is any stream data waiting to be sent for the first + // time. + bool HasBufferedStreamData() const; + + // Returns true if stream |id| has any outstanding data. + bool StreamIsWaitingForAcks(QuicStreamId id) const; + + // SessionNotifierInterface methods: + bool OnFrameAcked(const QuicFrame& frame, QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) override; + void OnStreamFrameRetransmitted(const QuicStreamFrame& /*frame*/) override {} + void OnFrameLost(const QuicFrame& frame) override; + bool RetransmitFrames(const QuicFrames& frames, + TransmissionType type) override; + bool IsFrameOutstanding(const QuicFrame& frame) const override; + bool HasUnackedCryptoData() const override; + bool HasUnackedStreamData() const override; + bool HasLostStreamData() const; + + private: + struct StreamState { + StreamState(); + ~StreamState(); + + // Total number of bytes. + QuicByteCount bytes_total; + // Number of sent bytes. + QuicByteCount bytes_sent; + // Record of acked offsets. + QuicIntervalSet bytes_acked; + // Data considered as lost and needs to be retransmitted. + QuicIntervalSet pending_retransmissions; + + bool fin_buffered; + bool fin_sent; + bool fin_outstanding; + bool fin_lost; + }; + + friend std::ostream& operator<<(std::ostream& os, const StreamState& s); + + using StreamMap = absl::flat_hash_map; + + void OnStreamDataConsumed(QuicStreamId id, QuicStreamOffset offset, + QuicByteCount data_length, bool fin); + + bool OnControlFrameAcked(const QuicFrame& frame); + + void OnControlFrameLost(const QuicFrame& frame); + + bool RetransmitLostControlFrames(); + + bool RetransmitLostCryptoData(); + + bool RetransmitLostStreamData(); + + bool WriteBufferedControlFrames(); + + bool WriteBufferedCryptoData(); + + bool IsControlFrameOutstanding(const QuicFrame& frame) const; + + bool HasBufferedControlFrames() const; + + bool StreamHasBufferedData(QuicStreamId id) const; + + quiche::QuicheCircularDeque control_frames_; + + quiche::QuicheLinkedHashMap lost_control_frames_; + + // Id of latest saved control frame. 0 if no control frame has been saved. + QuicControlFrameId last_control_frame_id_; + + // The control frame at the 0th index of control_frames_. + QuicControlFrameId least_unacked_; + + // ID of the least unsent control frame. + QuicControlFrameId least_unsent_; + + StreamMap stream_map_; + + // Transferred crypto bytes according to encryption levels. + QuicIntervalSet + crypto_bytes_transferred_[NUM_ENCRYPTION_LEVELS]; + + StreamState crypto_state_[NUM_ENCRYPTION_LEVELS]; + + QuicConnection* connection_; +}; + +} // namespace test + +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_NOTIFIER_H_ diff --git a/quiche/quic/test_tools/simple_session_notifier_test.cc b/quiche/quic/test_tools/simple_session_notifier_test.cc new file mode 100644 index 000000000000..513394ccbfe4 --- /dev/null +++ b/quiche/quic/test_tools/simple_session_notifier_test.cc @@ -0,0 +1,367 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simple_session_notifier.h" + +#include + +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simple_data_producer.h" + +using testing::_; +using testing::InSequence; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +class MockQuicConnectionWithSendStreamData : public MockQuicConnection { + public: + MockQuicConnectionWithSendStreamData(MockQuicConnectionHelper* helper, + MockAlarmFactory* alarm_factory, + Perspective perspective) + : MockQuicConnection(helper, alarm_factory, perspective) {} + + MOCK_METHOD(QuicConsumedData, SendStreamData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state), + (override)); +}; + +class SimpleSessionNotifierTest : public QuicTest { + public: + SimpleSessionNotifierTest() + : connection_(&helper_, &alarm_factory_, Perspective::IS_CLIENT), + notifier_(&connection_) { + connection_.set_visitor(&visitor_); + connection_.SetSessionNotifier(¬ifier_); + EXPECT_FALSE(notifier_.WillingToWrite()); + EXPECT_EQ(0u, notifier_.StreamBytesSent()); + EXPECT_FALSE(notifier_.HasBufferedStreamData()); + } + + bool ControlFrameConsumed(const QuicFrame& frame) { + DeleteFrame(&const_cast(frame)); + return true; + } + + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + MockQuicConnectionVisitor visitor_; + StrictMock connection_; + SimpleSessionNotifier notifier_; +}; + +TEST_F(SimpleSessionNotifierTest, WriteOrBufferData) { + InSequence s; + EXPECT_CALL(connection_, SendStreamData(3, 1024, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(1024, false))); + notifier_.WriteOrBufferData(3, 1024, NO_FIN); + EXPECT_EQ(0u, notifier_.StreamBytesToSend()); + EXPECT_CALL(connection_, SendStreamData(5, 512, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(512, false))); + notifier_.WriteOrBufferData(5, 512, NO_FIN); + EXPECT_FALSE(notifier_.WillingToWrite()); + // Connection is blocked. + EXPECT_CALL(connection_, SendStreamData(5, 512, 512, FIN)) + .WillOnce(Return(QuicConsumedData(256, false))); + notifier_.WriteOrBufferData(5, 512, FIN); + EXPECT_TRUE(notifier_.WillingToWrite()); + EXPECT_EQ(1792u, notifier_.StreamBytesSent()); + EXPECT_EQ(256u, notifier_.StreamBytesToSend()); + EXPECT_TRUE(notifier_.HasBufferedStreamData()); + + // New data cannot be sent as connection is blocked. + EXPECT_CALL(connection_, SendStreamData(7, 1024, 0, FIN)).Times(0); + notifier_.WriteOrBufferData(7, 1024, FIN); + EXPECT_EQ(1792u, notifier_.StreamBytesSent()); +} + +TEST_F(SimpleSessionNotifierTest, WriteOrBufferRstStream) { + InSequence s; + EXPECT_CALL(connection_, SendStreamData(5, 1024, 0, FIN)) + .WillOnce(Return(QuicConsumedData(1024, true))); + notifier_.WriteOrBufferData(5, 1024, FIN); + EXPECT_TRUE(notifier_.StreamIsWaitingForAcks(5)); + EXPECT_TRUE(notifier_.HasUnackedStreamData()); + + // Reset stream 5 with no error. + EXPECT_CALL(connection_, SendControlFrame(_)) + .WillRepeatedly( + Invoke(this, &SimpleSessionNotifierTest::ControlFrameConsumed)); + notifier_.WriteOrBufferRstStream(5, QUIC_STREAM_NO_ERROR, 1024); + // Verify stream 5 is waiting for acks. + EXPECT_TRUE(notifier_.StreamIsWaitingForAcks(5)); + EXPECT_TRUE(notifier_.HasUnackedStreamData()); + + // Reset stream 5 with error. + notifier_.WriteOrBufferRstStream(5, QUIC_ERROR_PROCESSING_STREAM, 1024); + EXPECT_FALSE(notifier_.StreamIsWaitingForAcks(5)); + EXPECT_FALSE(notifier_.HasUnackedStreamData()); +} + +TEST_F(SimpleSessionNotifierTest, WriteOrBufferPing) { + InSequence s; + // Write ping when connection is not write blocked. + EXPECT_CALL(connection_, SendControlFrame(_)) + .WillRepeatedly( + Invoke(this, &SimpleSessionNotifierTest::ControlFrameConsumed)); + notifier_.WriteOrBufferPing(); + EXPECT_EQ(0u, notifier_.StreamBytesToSend()); + EXPECT_FALSE(notifier_.WillingToWrite()); + + // Write stream data and cause the connection to be write blocked. + EXPECT_CALL(connection_, SendStreamData(3, 1024, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(1024, false))); + notifier_.WriteOrBufferData(3, 1024, NO_FIN); + EXPECT_EQ(0u, notifier_.StreamBytesToSend()); + EXPECT_CALL(connection_, SendStreamData(5, 512, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(256, false))); + notifier_.WriteOrBufferData(5, 512, NO_FIN); + EXPECT_TRUE(notifier_.WillingToWrite()); + + // Connection is blocked. + EXPECT_CALL(connection_, SendControlFrame(_)).Times(0); + notifier_.WriteOrBufferPing(); +} + +TEST_F(SimpleSessionNotifierTest, NeuterUnencryptedData) { + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + // This test writes crypto data through crypto streams. It won't work when + // crypto frames are used instead. + return; + } + InSequence s; + // Send crypto data [0, 1024) in ENCRYPTION_INITIAL. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + EXPECT_CALL(connection_, SendStreamData(QuicUtils::GetCryptoStreamId( + connection_.transport_version()), + 1024, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(1024, false))); + notifier_.WriteOrBufferData( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), 1024, + NO_FIN); + // Send crypto data [1024, 2048) in ENCRYPTION_ZERO_RTT. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_CALL(connection_, SendStreamData(QuicUtils::GetCryptoStreamId( + connection_.transport_version()), + 1024, 1024, NO_FIN)) + .WillOnce(Return(QuicConsumedData(1024, false))); + notifier_.WriteOrBufferData( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), 1024, + NO_FIN); + // Ack [1024, 2048). + QuicStreamFrame stream_frame( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), false, + 1024, 1024); + notifier_.OnFrameAcked(QuicFrame(stream_frame), QuicTime::Delta::Zero(), + QuicTime::Zero()); + EXPECT_TRUE(notifier_.StreamIsWaitingForAcks( + QuicUtils::GetCryptoStreamId(connection_.transport_version()))); + EXPECT_TRUE(notifier_.HasUnackedStreamData()); + + // Neuters unencrypted data. + notifier_.NeuterUnencryptedData(); + EXPECT_FALSE(notifier_.StreamIsWaitingForAcks( + QuicUtils::GetCryptoStreamId(connection_.transport_version()))); + EXPECT_FALSE(notifier_.HasUnackedStreamData()); +} + +TEST_F(SimpleSessionNotifierTest, OnCanWrite) { + if (QuicVersionUsesCryptoFrames(connection_.transport_version())) { + // This test writes crypto data through crypto streams. It won't work when + // crypto frames are used instead. + return; + } + InSequence s; + // Send crypto data [0, 1024) in ENCRYPTION_INITIAL. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + EXPECT_CALL(connection_, SendStreamData(QuicUtils::GetCryptoStreamId( + connection_.transport_version()), + 1024, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(1024, false))); + notifier_.WriteOrBufferData( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), 1024, + NO_FIN); + + // Send crypto data [1024, 2048) in ENCRYPTION_ZERO_RTT. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_CALL(connection_, SendStreamData(QuicUtils::GetCryptoStreamId( + connection_.transport_version()), + 1024, 1024, NO_FIN)) + .WillOnce(Return(QuicConsumedData(1024, false))); + notifier_.WriteOrBufferData( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), 1024, + NO_FIN); + // Send stream 3 [0, 1024) and connection is blocked. + EXPECT_CALL(connection_, SendStreamData(3, 1024, 0, FIN)) + .WillOnce(Return(QuicConsumedData(512, false))); + notifier_.WriteOrBufferData(3, 1024, FIN); + // Send stream 5 [0, 1024). + EXPECT_CALL(connection_, SendStreamData(5, _, _, _)).Times(0); + notifier_.WriteOrBufferData(5, 1024, NO_FIN); + // Reset stream 5 with error. + EXPECT_CALL(connection_, SendControlFrame(_)).Times(0); + notifier_.WriteOrBufferRstStream(5, QUIC_ERROR_PROCESSING_STREAM, 1024); + + // Lost crypto data [500, 1500) and stream 3 [0, 512). + QuicStreamFrame frame1( + QuicUtils::GetCryptoStreamId(connection_.transport_version()), false, 500, + 1000); + QuicStreamFrame frame2(3, false, 0, 512); + notifier_.OnFrameLost(QuicFrame(frame1)); + notifier_.OnFrameLost(QuicFrame(frame2)); + + // Connection becomes writable. + // Lost crypto data gets retransmitted as [500, 1024) and [1024, 1500), as + // they are in different encryption levels. + EXPECT_CALL(connection_, SendStreamData(QuicUtils::GetCryptoStreamId( + connection_.transport_version()), + 524, 500, NO_FIN)) + .WillOnce(Return(QuicConsumedData(524, false))); + EXPECT_CALL(connection_, SendStreamData(QuicUtils::GetCryptoStreamId( + connection_.transport_version()), + 476, 1024, NO_FIN)) + .WillOnce(Return(QuicConsumedData(476, false))); + // Lost stream 3 data gets retransmitted. + EXPECT_CALL(connection_, SendStreamData(3, 512, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(512, false))); + // Buffered control frames get sent. + EXPECT_CALL(connection_, SendControlFrame(_)) + .WillOnce(Invoke(this, &SimpleSessionNotifierTest::ControlFrameConsumed)); + // Buffered stream 3 data [512, 1024) gets sent. + EXPECT_CALL(connection_, SendStreamData(3, 512, 512, FIN)) + .WillOnce(Return(QuicConsumedData(512, true))); + notifier_.OnCanWrite(); + EXPECT_FALSE(notifier_.WillingToWrite()); +} + +TEST_F(SimpleSessionNotifierTest, OnCanWriteCryptoFrames) { + if (!QuicVersionUsesCryptoFrames(connection_.transport_version())) { + return; + } + SimpleDataProducer producer; + connection_.SetDataProducer(&producer); + InSequence s; + // Send crypto data [0, 1024) in ENCRYPTION_INITIAL. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + EXPECT_CALL(connection_, SendCryptoData(ENCRYPTION_INITIAL, 1024, 0)) + .WillOnce(Invoke(&connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + EXPECT_CALL(connection_, CloseConnection(QUIC_PACKET_WRITE_ERROR, _, _)); + std::string crypto_data1(1024, 'a'); + producer.SaveCryptoData(ENCRYPTION_INITIAL, 0, crypto_data1); + std::string crypto_data2(524, 'a'); + producer.SaveCryptoData(ENCRYPTION_INITIAL, 500, crypto_data2); + notifier_.WriteCryptoData(ENCRYPTION_INITIAL, 1024, 0); + // Send crypto data [1024, 2048) in ENCRYPTION_ZERO_RTT. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, std::make_unique( + Perspective::IS_CLIENT)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + EXPECT_CALL(connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 1024, 0)) + .WillOnce(Invoke(&connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + std::string crypto_data3(1024, 'a'); + producer.SaveCryptoData(ENCRYPTION_ZERO_RTT, 0, crypto_data3); + notifier_.WriteCryptoData(ENCRYPTION_ZERO_RTT, 1024, 0); + // Send stream 3 [0, 1024) and connection is blocked. + EXPECT_CALL(connection_, SendStreamData(3, 1024, 0, FIN)) + .WillOnce(Return(QuicConsumedData(512, false))); + notifier_.WriteOrBufferData(3, 1024, FIN); + // Send stream 5 [0, 1024). + EXPECT_CALL(connection_, SendStreamData(5, _, _, _)).Times(0); + notifier_.WriteOrBufferData(5, 1024, NO_FIN); + // Reset stream 5 with error. + EXPECT_CALL(connection_, SendControlFrame(_)).Times(0); + notifier_.WriteOrBufferRstStream(5, QUIC_ERROR_PROCESSING_STREAM, 1024); + + // Lost crypto data [500, 1500) and stream 3 [0, 512). + QuicCryptoFrame crypto_frame1(ENCRYPTION_INITIAL, 500, 524); + QuicCryptoFrame crypto_frame2(ENCRYPTION_ZERO_RTT, 0, 476); + QuicStreamFrame stream3_frame(3, false, 0, 512); + notifier_.OnFrameLost(QuicFrame(&crypto_frame1)); + notifier_.OnFrameLost(QuicFrame(&crypto_frame2)); + notifier_.OnFrameLost(QuicFrame(stream3_frame)); + + // Connection becomes writable. + // Lost crypto data gets retransmitted as [500, 1024) and [1024, 1500), as + // they are in different encryption levels. + EXPECT_CALL(connection_, SendCryptoData(ENCRYPTION_INITIAL, 524, 500)) + .WillOnce(Invoke(&connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + EXPECT_CALL(connection_, SendCryptoData(ENCRYPTION_ZERO_RTT, 476, 0)) + .WillOnce(Invoke(&connection_, + &MockQuicConnection::QuicConnection_SendCryptoData)); + // Lost stream 3 data gets retransmitted. + EXPECT_CALL(connection_, SendStreamData(3, 512, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(512, false))); + // Buffered control frames get sent. + EXPECT_CALL(connection_, SendControlFrame(_)) + .WillOnce(Invoke(this, &SimpleSessionNotifierTest::ControlFrameConsumed)); + // Buffered stream 3 data [512, 1024) gets sent. + EXPECT_CALL(connection_, SendStreamData(3, 512, 512, FIN)) + .WillOnce(Return(QuicConsumedData(512, true))); + notifier_.OnCanWrite(); + EXPECT_FALSE(notifier_.WillingToWrite()); +} + +TEST_F(SimpleSessionNotifierTest, RetransmitFrames) { + InSequence s; + connection_.SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(Perspective::IS_CLIENT)); + // Send stream 3 data [0, 10) and fin. + EXPECT_CALL(connection_, SendStreamData(3, 10, 0, FIN)) + .WillOnce(Return(QuicConsumedData(10, true))); + notifier_.WriteOrBufferData(3, 10, FIN); + QuicStreamFrame frame1(3, true, 0, 10); + // Send stream 5 [0, 10) and fin. + EXPECT_CALL(connection_, SendStreamData(5, 10, 0, FIN)) + .WillOnce(Return(QuicConsumedData(10, true))); + notifier_.WriteOrBufferData(5, 10, FIN); + QuicStreamFrame frame2(5, true, 0, 10); + // Reset stream 5 with no error. + EXPECT_CALL(connection_, SendControlFrame(_)) + .WillOnce(Invoke(this, &SimpleSessionNotifierTest::ControlFrameConsumed)); + notifier_.WriteOrBufferRstStream(5, QUIC_STREAM_NO_ERROR, 10); + + // Ack stream 3 [3, 7), and stream 5 [8, 10). + QuicStreamFrame ack_frame1(3, false, 3, 4); + QuicStreamFrame ack_frame2(5, false, 8, 2); + notifier_.OnFrameAcked(QuicFrame(ack_frame1), QuicTime::Delta::Zero(), + QuicTime::Zero()); + notifier_.OnFrameAcked(QuicFrame(ack_frame2), QuicTime::Delta::Zero(), + QuicTime::Zero()); + EXPECT_FALSE(notifier_.WillingToWrite()); + + // Force to send. + QuicRstStreamFrame rst_stream(1, 5, QUIC_STREAM_NO_ERROR, 10); + QuicFrames frames; + frames.push_back(QuicFrame(frame2)); + frames.push_back(QuicFrame(&rst_stream)); + frames.push_back(QuicFrame(frame1)); + // stream 5 data [0, 8), fin only are retransmitted. + EXPECT_CALL(connection_, SendStreamData(5, 8, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(8, false))); + EXPECT_CALL(connection_, SendStreamData(5, 0, 10, FIN)) + .WillOnce(Return(QuicConsumedData(0, true))); + // rst_stream is retransmitted. + EXPECT_CALL(connection_, SendControlFrame(_)) + .WillOnce(Invoke(this, &SimpleSessionNotifierTest::ControlFrameConsumed)); + // stream 3 data [0, 3) is retransmitted and connection is blocked. + EXPECT_CALL(connection_, SendStreamData(3, 3, 0, NO_FIN)) + .WillOnce(Return(QuicConsumedData(2, false))); + notifier_.RetransmitFrames(frames, PTO_RETRANSMISSION); + EXPECT_FALSE(notifier_.WillingToWrite()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/README.md b/quiche/quic/test_tools/simulator/README.md new file mode 100644 index 000000000000..8582962adeef --- /dev/null +++ b/quiche/quic/test_tools/simulator/README.md @@ -0,0 +1,99 @@ +# QUIC network simulator + +This directory contains a discrete event network simulator which QUIC code uses +for testing congestion control and other transmission control code that requires +a network simulation for tests on QuicConnection level of abstraction. + +## Actors + +The core of the simulator is the Simulator class, which maintains a virtual +clock and an event queue. Any object in a simulation that needs to schedule +events has to subclass Actor. Subclassing Actor involves: + +1. Calling the `Actor::Actor(Simulator*, std::string)` constructor to establish + the name of the object and the simulator it is associated with. +2. Calling `Schedule(QuicTime)` to schedule the time at which `Act()` method is + called. `Schedule` will only cause the object to be rescheduled if the time + for which it is currently scheduled is later than the new time. +3. Implementing `Act()` method with the relevant logic. The actor will be + removed from the event queue right before `Act()` is called. + +Here is a simple example of an object that outputs simulation time into the log +every 100 ms. + +```c++ +class LogClock : public Actor { + public: + LogClock(Simulator* simulator, std::string name) : Actor(simulator, name) { + Schedule(clock_->Now()); + } + ~LogClock() override {} + + void Act() override { + QUIC_LOG(INFO) << "The current time is " + << clock_->Now().ToDebuggingValue(); + Schedule(clock_->Now() + QuicTime::Delta::FromMilliseconds(100)); + } +}; +``` + +A QuicAlarm object can be used to schedule events in the simulation using +`Simulator::GetAlarmFactory()`. + +## Ports + +The simulated network transfers packets, which are modelled as an instance of +struct `Packet`. A packet consists of source and destination address (which are +just plain strings), a transmission timestamp and the UDP-layer payload. + +The simulation uses the push model: any object that wishes to transfer a packet +to another component in the simulation has to explicitly do it itself. Any +object that can accept a packet is called a *port*. There are two types of +ports: unconstrained ports, which can always accept packets, and constrained +ports, which signal when they can accept a new packet. + +An endpoint is an object that is connected to the network and can both receive +and send packets. In our model, the endpoint always receives packets as an +unconstrained port (*RX port*), and always writes packets to a constrained port +(*TX port*). + +## Links + +The `SymmetricLink` class models a symmetric duplex links with finite bandwidth +and propagation delay. It consists of a pair of identical `OneWayLink`s, which +accept packets as a constrained port (where constrain comes from the finiteness +of bandwidth) and outputs them into an unconstrained port. Two endpoints +connected via a `SymmetricLink` look like this: + +```none + Endpoint A Endpoint B ++-----------+ SymmetricLink +-----------+ +| | +------------------------------+ | | +| +---------+ | +------------------------+ | +---------+ | +| | RX port <-----| OneWayLink *<-----| TX port | | +| +---------+ | +------------------------+ | +---------+ | +| | | | | | +| +---------+ | +------------------------+ | +---------+ | +| | TX port |----->* OneWayLink |-----> RX port | | +| +---------+ | +------------------------+ | +---------+ | +| | +------------------------------+ | | ++-----------+ +-----------+ + + ( -->* denotes constrained port) +``` + +In most common scenario, one of the endpoints is going to be a QUIC endpoint, +and another is going to be a switch port. + +## Other objects + +Besides `SymmetricLink`, the simulator provides the following objects: + +* `Queue` allows to convert a constrained port into an unconstrained one by + buffering packets upon arrival. The queue has a finite size, and once the + queue is full, the packets are silently dropped. +* `Switch` simulates a multi-port learning switch with a fixed queue for each + output port. +* `QuicEndpoint` allows QuicConnection to be run over the simulated network. +* `QuicEndpointMultiplexer` allows multiple connections to share the same + network endpoint. diff --git a/quiche/quic/test_tools/simulator/actor.cc b/quiche/quic/test_tools/simulator/actor.cc new file mode 100644 index 000000000000..213d861e3204 --- /dev/null +++ b/quiche/quic/test_tools/simulator/actor.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/actor.h" + +#include "quiche/quic/test_tools/simulator/simulator.h" + +namespace quic { +namespace simulator { + +Actor::Actor(Simulator* simulator, std::string name) + : simulator_(simulator), + clock_(simulator->GetClock()), + name_(std::move(name)) { + simulator_->AddActor(this); +} + +Actor::~Actor() { simulator_->RemoveActor(this); } + +void Actor::Schedule(QuicTime next_tick) { + simulator_->Schedule(this, next_tick); +} + +void Actor::Unschedule() { simulator_->Unschedule(this); } + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/actor.h b/quiche/quic/test_tools/simulator/actor.h new file mode 100644 index 000000000000..37473223f395 --- /dev/null +++ b/quiche/quic/test_tools/simulator/actor.h @@ -0,0 +1,66 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_ACTOR_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_ACTOR_H_ + +#include + +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_time.h" + +namespace quic { +namespace simulator { + +class Simulator; + +// Actor is the base class for all participants of the simulation which can +// schedule events to be triggered at the specified time. Every actor has a +// name assigned to it, which can be used for debugging and addressing purposes. +// +// The Actor object is scheduled as follows: +// 1. Every Actor object appears at most once in the event queue, for one +// specific time. +// 2. Actor is scheduled by calling Schedule() method. +// 3. If Schedule() method is called with multiple different times specified, +// Act() method will be called at the earliest time specified. +// 4. Before Act() is called, the Actor is removed from the event queue. Act() +// will not be called again unless Schedule() is called. +class Actor { + public: + Actor(Simulator* simulator, std::string name); + virtual ~Actor(); + + // Trigger all the events the actor can potentially handle at this point. + // Before Act() is called, the actor is removed from the event queue, and has + // to schedule the next call manually. + virtual void Act() = 0; + + std::string name() const { return name_; } + Simulator* simulator() const { return simulator_; } + + protected: + // Calls Schedule() on the associated simulator. + void Schedule(QuicTime next_tick); + + // Calls Unschedule() on the associated simulator. + void Unschedule(); + + Simulator* simulator_; + const QuicClock* clock_; + std::string name_; + + private: + // Since the Actor object registers itself with a simulator using a pointer to + // itself, do not allow it to be moved. + Actor(Actor&&) = delete; + Actor(const Actor&) = delete; + Actor& operator=(const Actor&) = delete; + Actor& operator=(Actor&&) = delete; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_ACTOR_H_ diff --git a/quiche/quic/test_tools/simulator/alarm_factory.cc b/quiche/quic/test_tools/simulator/alarm_factory.cc new file mode 100644 index 000000000000..48939b53f45c --- /dev/null +++ b/quiche/quic/test_tools/simulator/alarm_factory.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/alarm_factory.h" + +#include "absl/strings/str_format.h" +#include "quiche/quic/core/quic_alarm.h" + +namespace quic { +namespace simulator { + +// Alarm is an implementation of QuicAlarm which can schedule alarms in the +// simulation timeline. +class Alarm : public QuicAlarm { + public: + Alarm(Simulator* simulator, std::string name, + QuicArenaScopedPtr delegate) + : QuicAlarm(std::move(delegate)), adapter_(simulator, name, this) {} + ~Alarm() override {} + + void SetImpl() override { + QUICHE_DCHECK(deadline().IsInitialized()); + adapter_.Set(deadline()); + } + + void CancelImpl() override { adapter_.Cancel(); } + + private: + // An adapter class triggering a QuicAlarm using a simulation time system. + // An adapter is required here because neither Actor nor QuicAlarm are pure + // interfaces. + class Adapter : public Actor { + public: + Adapter(Simulator* simulator, std::string name, Alarm* parent) + : Actor(simulator, name), parent_(parent) {} + ~Adapter() override {} + + void Set(QuicTime time) { Schedule(std::max(time, clock_->Now())); } + void Cancel() { Unschedule(); } + + void Act() override { + QUICHE_DCHECK(clock_->Now() >= parent_->deadline()); + parent_->Fire(); + } + + private: + Alarm* parent_; + }; + Adapter adapter_; +}; + +AlarmFactory::AlarmFactory(Simulator* simulator, std::string name) + : simulator_(simulator), name_(std::move(name)), counter_(0) {} + +AlarmFactory::~AlarmFactory() {} + +std::string AlarmFactory::GetNewAlarmName() { + ++counter_; + return absl::StrFormat("%s (alarm %i)", name_, counter_); +} + +QuicAlarm* AlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) { + return new Alarm(simulator_, GetNewAlarmName(), + QuicArenaScopedPtr(delegate)); +} + +QuicArenaScopedPtr AlarmFactory::CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) { + if (arena != nullptr) { + return arena->New(simulator_, GetNewAlarmName(), + std::move(delegate)); + } + return QuicArenaScopedPtr( + new Alarm(simulator_, GetNewAlarmName(), std::move(delegate))); +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/alarm_factory.h b/quiche/quic/test_tools/simulator/alarm_factory.h new file mode 100644 index 000000000000..2d6f1836acaa --- /dev/null +++ b/quiche/quic/test_tools/simulator/alarm_factory.h @@ -0,0 +1,39 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_ALARM_FACTORY_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_ALARM_FACTORY_H_ + +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/test_tools/simulator/actor.h" + +namespace quic { +namespace simulator { + +// AlarmFactory allows to schedule QuicAlarms using the simulation event queue. +class AlarmFactory : public QuicAlarmFactory { + public: + AlarmFactory(Simulator* simulator, std::string name); + AlarmFactory(const AlarmFactory&) = delete; + AlarmFactory& operator=(const AlarmFactory&) = delete; + ~AlarmFactory() override; + + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override; + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override; + + private: + // Automatically generate a name for a new alarm. + std::string GetNewAlarmName(); + + Simulator* simulator_; + std::string name_; + int counter_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_ALARM_FACTORY_H_ diff --git a/quiche/quic/test_tools/simulator/link.cc b/quiche/quic/test_tools/simulator/link.cc new file mode 100644 index 000000000000..e2a094bfe60f --- /dev/null +++ b/quiche/quic/test_tools/simulator/link.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/link.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "quiche/quic/test_tools/simulator/simulator.h" + +namespace quic { +namespace simulator { + +// Parameters for random noise delay. +const uint64_t kMaxRandomDelayUs = 10; + +OneWayLink::OneWayLink(Simulator* simulator, std::string name, + UnconstrainedPortInterface* sink, + QuicBandwidth bandwidth, + QuicTime::Delta propagation_delay) + : Actor(simulator, name), + sink_(sink), + bandwidth_(bandwidth), + propagation_delay_(propagation_delay), + next_write_at_(QuicTime::Zero()) {} + +OneWayLink::~OneWayLink() {} + +OneWayLink::QueuedPacket::QueuedPacket(std::unique_ptr packet, + QuicTime dequeue_time) + : packet(std::move(packet)), dequeue_time(dequeue_time) {} + +OneWayLink::QueuedPacket::QueuedPacket(QueuedPacket&& other) = default; + +OneWayLink::QueuedPacket::~QueuedPacket() {} + +void OneWayLink::AcceptPacket(std::unique_ptr packet) { + QUICHE_DCHECK(TimeUntilAvailable().IsZero()); + QuicTime::Delta transfer_time = bandwidth_.TransferTime(packet->size); + next_write_at_ = clock_->Now() + transfer_time; + + packets_in_transit_.emplace_back( + std::move(packet), + // Ensure that packets are delivered in order. + std::max( + next_write_at_ + propagation_delay_ + GetRandomDelay(transfer_time), + packets_in_transit_.empty() + ? QuicTime::Zero() + : packets_in_transit_.back().dequeue_time)); + ScheduleNextPacketDeparture(); +} + +QuicTime::Delta OneWayLink::TimeUntilAvailable() { + const QuicTime now = clock_->Now(); + if (next_write_at_ <= now) { + return QuicTime::Delta::Zero(); + } + + return next_write_at_ - now; +} + +void OneWayLink::Act() { + QUICHE_DCHECK(!packets_in_transit_.empty()); + QUICHE_DCHECK(packets_in_transit_.front().dequeue_time >= clock_->Now()); + + sink_->AcceptPacket(std::move(packets_in_transit_.front().packet)); + packets_in_transit_.pop_front(); + + ScheduleNextPacketDeparture(); +} + +void OneWayLink::ScheduleNextPacketDeparture() { + if (packets_in_transit_.empty()) { + return; + } + + Schedule(packets_in_transit_.front().dequeue_time); +} + +QuicTime::Delta OneWayLink::GetRandomDelay(QuicTime::Delta transfer_time) { + if (!simulator_->enable_random_delays()) { + return QuicTime::Delta::Zero(); + } + + QuicTime::Delta delta = QuicTime::Delta::FromMicroseconds( + simulator_->GetRandomGenerator()->RandUint64() % (kMaxRandomDelayUs + 1)); + // Have an upper bound on the delay to ensure packets do not go out of order. + delta = std::min(delta, transfer_time * 0.5); + return delta; +} + +SymmetricLink::SymmetricLink(Simulator* simulator, std::string name, + UnconstrainedPortInterface* sink_a, + UnconstrainedPortInterface* sink_b, + QuicBandwidth bandwidth, + QuicTime::Delta propagation_delay) + : a_to_b_link_(simulator, absl::StrCat(name, " (A-to-B)"), sink_b, + bandwidth, propagation_delay), + b_to_a_link_(simulator, absl::StrCat(name, " (B-to-A)"), sink_a, + bandwidth, propagation_delay) {} + +SymmetricLink::SymmetricLink(Endpoint* endpoint_a, Endpoint* endpoint_b, + QuicBandwidth bandwidth, + QuicTime::Delta propagation_delay) + : SymmetricLink(endpoint_a->simulator(), + absl::StrFormat("Link [%s]<->[%s]", endpoint_a->name(), + endpoint_b->name()), + endpoint_a->GetRxPort(), endpoint_b->GetRxPort(), bandwidth, + propagation_delay) { + endpoint_a->SetTxPort(&a_to_b_link_); + endpoint_b->SetTxPort(&b_to_a_link_); +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/link.h b/quiche/quic/test_tools/simulator/link.h new file mode 100644 index 000000000000..19061ddd1f4e --- /dev/null +++ b/quiche/quic/test_tools/simulator/link.h @@ -0,0 +1,97 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_LINK_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_LINK_H_ + +#include + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_bandwidth.h" +#include "quiche/quic/test_tools/simulator/actor.h" +#include "quiche/quic/test_tools/simulator/port.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { +namespace simulator { + +// A reliable simplex link between two endpoints with constrained bandwidth. A +// few microseconds of random delay are added for every packet to avoid +// synchronization issues. +class OneWayLink : public Actor, public ConstrainedPortInterface { + public: + OneWayLink(Simulator* simulator, std::string name, + UnconstrainedPortInterface* sink, QuicBandwidth bandwidth, + QuicTime::Delta propagation_delay); + OneWayLink(const OneWayLink&) = delete; + OneWayLink& operator=(const OneWayLink&) = delete; + ~OneWayLink() override; + + void AcceptPacket(std::unique_ptr packet) override; + QuicTime::Delta TimeUntilAvailable() override; + void Act() override; + + QuicBandwidth bandwidth() const { return bandwidth_; } + void set_bandwidth(QuicBandwidth new_bandwidth) { + bandwidth_ = new_bandwidth; + } + + protected: + // Get the value of a random delay imposed on each packet. By default, this + // is a short random delay in order to avoid artifical synchronization + // artifacts during the simulation. Subclasses may override this behavior + // (for example, to provide a random component of delay). + virtual QuicTime::Delta GetRandomDelay(QuicTime::Delta transfer_time); + + private: + struct QueuedPacket { + std::unique_ptr packet; + QuicTime dequeue_time; + + QueuedPacket(std::unique_ptr packet, QuicTime dequeue_time); + QueuedPacket(QueuedPacket&& other); + ~QueuedPacket(); + }; + + // Schedule the next packet to be egressed out of the link if there are + // packets on the link. + void ScheduleNextPacketDeparture(); + + UnconstrainedPortInterface* sink_; + quiche::QuicheCircularDeque packets_in_transit_; + + QuicBandwidth bandwidth_; + const QuicTime::Delta propagation_delay_; + + QuicTime next_write_at_; +}; + +// A full-duplex link between two endpoints, functionally equivalent to two +// OneWayLink objects tied together. +class SymmetricLink { + public: + SymmetricLink(Simulator* simulator, std::string name, + UnconstrainedPortInterface* sink_a, + UnconstrainedPortInterface* sink_b, QuicBandwidth bandwidth, + QuicTime::Delta propagation_delay); + SymmetricLink(Endpoint* endpoint_a, Endpoint* endpoint_b, + QuicBandwidth bandwidth, QuicTime::Delta propagation_delay); + SymmetricLink(const SymmetricLink&) = delete; + SymmetricLink& operator=(const SymmetricLink&) = delete; + + QuicBandwidth bandwidth() { return a_to_b_link_.bandwidth(); } + void set_bandwidth(QuicBandwidth new_bandwidth) { + a_to_b_link_.set_bandwidth(new_bandwidth); + b_to_a_link_.set_bandwidth(new_bandwidth); + } + + private: + OneWayLink a_to_b_link_; + OneWayLink b_to_a_link_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_LINK_H_ diff --git a/quiche/quic/test_tools/simulator/packet_filter.cc b/quiche/quic/test_tools/simulator/packet_filter.cc new file mode 100644 index 000000000000..fc07c1542a54 --- /dev/null +++ b/quiche/quic/test_tools/simulator/packet_filter.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/packet_filter.h" + +namespace quic { +namespace simulator { + +PacketFilter::PacketFilter(Simulator* simulator, std::string name, + Endpoint* input) + : Endpoint(simulator, name), input_(input) { + input_->SetTxPort(this); +} + +PacketFilter::~PacketFilter() {} + +void PacketFilter::AcceptPacket(std::unique_ptr packet) { + if (FilterPacket(*packet)) { + output_tx_port_->AcceptPacket(std::move(packet)); + } +} + +QuicTime::Delta PacketFilter::TimeUntilAvailable() { + return output_tx_port_->TimeUntilAvailable(); +} + +void PacketFilter::Act() {} + +UnconstrainedPortInterface* PacketFilter::GetRxPort() { + return input_->GetRxPort(); +} + +void PacketFilter::SetTxPort(ConstrainedPortInterface* port) { + output_tx_port_ = port; +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/packet_filter.h b/quiche/quic/test_tools/simulator/packet_filter.h new file mode 100644 index 000000000000..cf57bb0f8de3 --- /dev/null +++ b/quiche/quic/test_tools/simulator/packet_filter.h @@ -0,0 +1,75 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_PACKET_FILTER_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_PACKET_FILTER_H_ + +#include "quiche/quic/test_tools/simulator/port.h" + +namespace quic { +namespace simulator { + +// Packet filter allows subclasses to filter out the packets that enter the +// input port and exit the output port. Packets in the other direction are +// always passed through. +// +// The filter wraps around the input endpoint, and exposes the resulting +// filtered endpoint via the output() method. For example, if initially there +// are two endpoints, A and B, connected via a symmetric link: +// +// QuicEndpoint endpoint_a; +// QuicEndpoint endpoint_b; +// +// [...] +// +// SymmetricLink a_b_link(&endpoint_a, &endpoint_b, ...); +// +// and the goal is to filter the traffic from A to B, then the new invocation +// would be as follows: +// +// PacketFilter filter(&simulator, "A-to-B packet filter", endpoint_a); +// SymmetricLink a_b_link(&filter, &endpoint_b, ...); +// +// Note that the filter drops the packet instanteneously, without it ever +// reaching the output wire. This means that in a direct endpoint-to-endpoint +// scenario, whenever the packet is dropped, the link would become immediately +// available for the next packet. +class PacketFilter : public Endpoint, public ConstrainedPortInterface { + public: + // Initialize the filter by wrapping around |input|. Does not take the + // ownership of |input|. + PacketFilter(Simulator* simulator, std::string name, Endpoint* input); + PacketFilter(const PacketFilter&) = delete; + PacketFilter& operator=(const PacketFilter&) = delete; + ~PacketFilter() override; + + // Implementation of ConstrainedPortInterface. + void AcceptPacket(std::unique_ptr packet) override; + QuicTime::Delta TimeUntilAvailable() override; + + // Implementation of Endpoint interface methods. + UnconstrainedPortInterface* GetRxPort() override; + void SetTxPort(ConstrainedPortInterface* port) override; + + // Implementation of Actor interface methods. + void Act() override; + + protected: + // Returns true if the packet should be passed through, and false if it should + // be dropped. The function is called once per packet, in the order that the + // packets arrive, so it is safe for the function to alter the internal state + // of the filter. + virtual bool FilterPacket(const Packet& packet) = 0; + + private: + // The port onto which the filtered packets are egressed. + ConstrainedPortInterface* output_tx_port_; + + // The original network endpoint wrapped by the class. + Endpoint* input_; +}; + +} // namespace simulator +} // namespace quic +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_PACKET_FILTER_H_ diff --git a/quiche/quic/test_tools/simulator/port.cc b/quiche/quic/test_tools/simulator/port.cc new file mode 100644 index 000000000000..bffc086fa5ae --- /dev/null +++ b/quiche/quic/test_tools/simulator/port.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/port.h" + +namespace quic { +namespace simulator { + +Packet::Packet() + : source(), destination(), tx_timestamp(QuicTime::Zero()), size(0) {} + +Packet::~Packet() {} + +Packet::Packet(const Packet& packet) = default; + +Endpoint::Endpoint(Simulator* simulator, std::string name) + : Actor(simulator, name) {} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/port.h b/quiche/quic/test_tools/simulator/port.h new file mode 100644 index 000000000000..27bac9554c1b --- /dev/null +++ b/quiche/quic/test_tools/simulator/port.h @@ -0,0 +1,66 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_PORT_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_PORT_H_ + +#include +#include + +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/test_tools/simulator/actor.h" + +namespace quic { +namespace simulator { + +struct Packet { + Packet(); + ~Packet(); + Packet(const Packet& packet); + + std::string source; + std::string destination; + QuicTime tx_timestamp; + + std::string contents; + QuicByteCount size; +}; + +// An interface for anything that accepts packets at arbitrary rate. +class UnconstrainedPortInterface { + public: + virtual ~UnconstrainedPortInterface() {} + virtual void AcceptPacket(std::unique_ptr packet) = 0; +}; + +// An interface for any device that accepts packets at a specific rate. +// Typically one would use a Queue object in order to write into a constrained +// port. +class ConstrainedPortInterface { + public: + virtual ~ConstrainedPortInterface() {} + + // Accept a packet for a port. TimeUntilAvailable() must be zero before this + // method is called. + virtual void AcceptPacket(std::unique_ptr packet) = 0; + + // Time until write for the next port is available. Cannot be infinite. + virtual QuicTime::Delta TimeUntilAvailable() = 0; +}; + +// A convenience class for any network endpoints, i.e. the objects which can +// both accept and send packets. +class Endpoint : public Actor { + public: + virtual UnconstrainedPortInterface* GetRxPort() = 0; + virtual void SetTxPort(ConstrainedPortInterface* port) = 0; + + protected: + Endpoint(Simulator* simulator, std::string name); +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_PORT_H_ diff --git a/quiche/quic/test_tools/simulator/queue.cc b/quiche/quic/test_tools/simulator/queue.cc new file mode 100644 index 000000000000..5a1eccf61a5e --- /dev/null +++ b/quiche/quic/test_tools/simulator/queue.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/queue.h" + +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/test_tools/simulator/simulator.h" + +namespace quic { +namespace simulator { + +Queue::ListenerInterface::~ListenerInterface() {} + +Queue::Queue(Simulator* simulator, std::string name, QuicByteCount capacity) + : Actor(simulator, name), + capacity_(capacity), + bytes_queued_(0), + aggregation_threshold_(0), + aggregation_timeout_(QuicTime::Delta::Infinite()), + current_bundle_(0), + current_bundle_bytes_(0), + tx_port_(nullptr), + listener_(nullptr) { + aggregation_timeout_alarm_.reset(simulator_->GetAlarmFactory()->CreateAlarm( + new AggregationAlarmDelegate(this))); +} + +Queue::~Queue() { aggregation_timeout_alarm_->PermanentCancel(); } + +void Queue::set_tx_port(ConstrainedPortInterface* port) { tx_port_ = port; } + +void Queue::AcceptPacket(std::unique_ptr packet) { + if (packet->size + bytes_queued_ > capacity_) { + QUIC_DVLOG(1) << "Queue [" << name() << "] has received a packet from [" + << packet->source << "] to [" << packet->destination + << "] which is over capacity. Dropping it."; + QUIC_DVLOG(1) << "Queue size: " << bytes_queued_ << " out of " << capacity_ + << ". Packet size: " << packet->size; + return; + } + + bytes_queued_ += packet->size; + queue_.emplace_back(std::move(packet), current_bundle_); + + if (IsAggregationEnabled()) { + current_bundle_bytes_ += queue_.front().packet->size; + if (!aggregation_timeout_alarm_->IsSet()) { + aggregation_timeout_alarm_->Set(clock_->Now() + aggregation_timeout_); + } + if (current_bundle_bytes_ >= aggregation_threshold_) { + NextBundle(); + } + } + + ScheduleNextPacketDequeue(); +} + +void Queue::Act() { + QUICHE_DCHECK(!queue_.empty()); + if (tx_port_->TimeUntilAvailable().IsZero()) { + QUICHE_DCHECK(bytes_queued_ >= queue_.front().packet->size); + bytes_queued_ -= queue_.front().packet->size; + + tx_port_->AcceptPacket(std::move(queue_.front().packet)); + queue_.pop_front(); + if (listener_ != nullptr) { + listener_->OnPacketDequeued(); + } + } + + ScheduleNextPacketDequeue(); +} + +void Queue::EnableAggregation(QuicByteCount aggregation_threshold, + QuicTime::Delta aggregation_timeout) { + QUICHE_DCHECK_EQ(bytes_queued_, 0u); + QUICHE_DCHECK_GT(aggregation_threshold, 0u); + QUICHE_DCHECK(!aggregation_timeout.IsZero()); + QUICHE_DCHECK(!aggregation_timeout.IsInfinite()); + + aggregation_threshold_ = aggregation_threshold; + aggregation_timeout_ = aggregation_timeout; +} + +Queue::AggregationAlarmDelegate::AggregationAlarmDelegate(Queue* queue) + : queue_(queue) {} + +void Queue::AggregationAlarmDelegate::OnAlarm() { + queue_->NextBundle(); + queue_->ScheduleNextPacketDequeue(); +} + +Queue::EnqueuedPacket::EnqueuedPacket(std::unique_ptr packet, + AggregationBundleNumber bundle) + : packet(std::move(packet)), bundle(bundle) {} + +Queue::EnqueuedPacket::EnqueuedPacket(EnqueuedPacket&& other) = default; + +Queue::EnqueuedPacket::~EnqueuedPacket() = default; + +void Queue::NextBundle() { + current_bundle_++; + current_bundle_bytes_ = 0; + aggregation_timeout_alarm_->Cancel(); +} + +void Queue::ScheduleNextPacketDequeue() { + if (queue_.empty()) { + QUICHE_DCHECK_EQ(bytes_queued_, 0u); + return; + } + + if (IsAggregationEnabled() && queue_.front().bundle == current_bundle_) { + return; + } + + QuicTime::Delta time_until_available = QuicTime::Delta::Zero(); + if (tx_port_) { + time_until_available = tx_port_->TimeUntilAvailable(); + } + + Schedule(clock_->Now() + time_until_available); +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/queue.h b/quiche/quic/test_tools/simulator/queue.h new file mode 100644 index 000000000000..b81db56c2207 --- /dev/null +++ b/quiche/quic/test_tools/simulator/queue.h @@ -0,0 +1,119 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUEUE_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUEUE_H_ + +#include "quiche/quic/core/quic_alarm.h" +#include "quiche/quic/test_tools/simulator/link.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { +namespace simulator { + +// A finitely sized queue which egresses packets onto a constrained link. The +// capacity of the queue is measured in bytes as opposed to packets. +class Queue : public Actor, public UnconstrainedPortInterface { + public: + class ListenerInterface { + public: + virtual ~ListenerInterface(); + + // Called whenever a packet is removed from the queue. + virtual void OnPacketDequeued() = 0; + }; + + Queue(Simulator* simulator, std::string name, QuicByteCount capacity); + Queue(const Queue&) = delete; + Queue& operator=(const Queue&) = delete; + ~Queue() override; + + void set_tx_port(ConstrainedPortInterface* port); + + void AcceptPacket(std::unique_ptr packet) override; + + void Act() override; + + QuicByteCount capacity() const { return capacity_; } + QuicByteCount bytes_queued() const { return bytes_queued_; } + QuicPacketCount packets_queued() const { return queue_.size(); } + + void set_listener_interface(ListenerInterface* listener) { + listener_ = listener; + } + + // Enables packet aggregation on the queue. Packet aggregation makes the + // queue bundle packets up until they reach certain size. When the + // aggregation is enabled, the packets are not dequeued until the total size + // of packets in the queue reaches |aggregation_threshold|. The packets are + // automatically flushed from the queue if the oldest packet has been in it + // for |aggregation_timeout|. + // + // This method may only be called when the queue is empty. Once enabled, + // aggregation cannot be disabled. + void EnableAggregation(QuicByteCount aggregation_threshold, + QuicTime::Delta aggregation_timeout); + + private: + using AggregationBundleNumber = uint64_t; + + // In order to implement packet aggregation, each packet is tagged with a + // bundle number. The queue keeps a bundle counter, and whenever a bundle is + // ready, it increments the number of the current bundle. Only the packets + // outside of the current bundle are allowed to leave the queue. + struct EnqueuedPacket { + EnqueuedPacket(std::unique_ptr packet, + AggregationBundleNumber bundle); + EnqueuedPacket(EnqueuedPacket&& other); + ~EnqueuedPacket(); + + std::unique_ptr packet; + AggregationBundleNumber bundle; + }; + + // Alarm handler for aggregation timeout. + class AggregationAlarmDelegate : public QuicAlarm::DelegateWithoutContext { + public: + explicit AggregationAlarmDelegate(Queue* queue); + + void OnAlarm() override; + + private: + Queue* queue_; + }; + + bool IsAggregationEnabled() const { return aggregation_threshold_ > 0; } + + // Increment the bundle counter and reset the bundle state. This causes all + // packets currently in the bundle to be flushed onto the link. + void NextBundle(); + + void ScheduleNextPacketDequeue(); + + const QuicByteCount capacity_; + QuicByteCount bytes_queued_; + + QuicByteCount aggregation_threshold_; + QuicTime::Delta aggregation_timeout_; + // The number of the current aggregation bundle. Monotonically increasing. + // All packets in the previous bundles are allowed to leave the queue, and + // none of the packets in the current one are. + AggregationBundleNumber current_bundle_; + // Size of the current bundle. Whenever it exceeds |aggregation_threshold_|, + // the next bundle is created. + QuicByteCount current_bundle_bytes_; + // Alarm responsible for flushing the current bundle upon timeout. Set when + // the first packet in the bundle is enqueued. + std::unique_ptr aggregation_timeout_alarm_; + + ConstrainedPortInterface* tx_port_; + quiche::QuicheCircularDeque queue_; + + ListenerInterface* listener_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUEUE_H_ diff --git a/quiche/quic/test_tools/simulator/quic_endpoint.cc b/quiche/quic/test_tools/simulator/quic_endpoint.cc new file mode 100644 index 000000000000..d4d580985cf2 --- /dev/null +++ b/quiche/quic/test_tools/simulator/quic_endpoint.cc @@ -0,0 +1,246 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/quic_endpoint.h" + +#include +#include + +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/platform/api/quic_test_output.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/simulator.h" + +namespace quic { +namespace simulator { + +const QuicStreamId kDataStream = 3; +const QuicByteCount kWriteChunkSize = 128 * 1024; +const char kStreamDataContents = 'Q'; + +QuicEndpoint::QuicEndpoint(Simulator* simulator, std::string name, + std::string peer_name, Perspective perspective, + QuicConnectionId connection_id) + : QuicEndpointBase(simulator, name, peer_name), + bytes_to_transfer_(0), + bytes_transferred_(0), + wrong_data_received_(false), + notifier_(nullptr) { + connection_ = std::make_unique( + connection_id, GetAddressFromName(name), GetAddressFromName(peer_name), + simulator, simulator->GetAlarmFactory(), &writer_, false, perspective, + ParsedVersionOfIndex(CurrentSupportedVersions(), 0), + connection_id_generator_); + connection_->set_visitor(this); + connection_->SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique( + ENCRYPTION_FORWARD_SECURE)); + connection_->SetEncrypter(ENCRYPTION_INITIAL, nullptr); + if (connection_->version().KnowsWhichDecrypterToUse()) { + connection_->InstallDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique( + ENCRYPTION_FORWARD_SECURE)); + connection_->RemoveDecrypter(ENCRYPTION_INITIAL); + } else { + connection_->SetDecrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique( + ENCRYPTION_FORWARD_SECURE)); + } + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + connection_->OnHandshakeComplete(); + if (perspective == Perspective::IS_SERVER) { + // Skip version negotiation. + test::QuicConnectionPeer::SetNegotiatedVersion(connection_.get()); + } + test::QuicConnectionPeer::SetAddressValidated(connection_.get()); + connection_->SetDataProducer(&producer_); + connection_->SetSessionNotifier(this); + notifier_ = std::make_unique(connection_.get()); + + // Configure the connection as if it received a handshake. This is important + // primarily because + // - this enables pacing, and + // - this sets the non-handshake timeouts. + std::string error; + CryptoHandshakeMessage peer_hello; + peer_hello.SetValue(kICSL, + static_cast(kMaximumIdleTimeoutSecs - 1)); + peer_hello.SetValue(kMIBS, + static_cast(kDefaultMaxStreamsPerConnection)); + QuicConfig config; + QuicErrorCode error_code = config.ProcessPeerHello( + peer_hello, perspective == Perspective::IS_CLIENT ? SERVER : CLIENT, + &error); + QUICHE_DCHECK_EQ(error_code, QUIC_NO_ERROR) + << "Configuration failed: " << error; + if (connection_->version().UsesTls()) { + if (connection_->perspective() == Perspective::IS_CLIENT) { + test::QuicConfigPeer::SetReceivedOriginalConnectionId( + &config, connection_->connection_id()); + test::QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_->connection_id()); + } else { + test::QuicConfigPeer::SetReceivedInitialSourceConnectionId( + &config, connection_->client_connection_id()); + } + } + connection_->SetFromConfig(config); + connection_->DisableMtuDiscovery(); +} + +QuicByteCount QuicEndpoint::bytes_received() const { + QuicByteCount total = 0; + for (auto& interval : offsets_received_) { + total += interval.max() - interval.min(); + } + return total; +} + +QuicByteCount QuicEndpoint::bytes_to_transfer() const { + if (notifier_ != nullptr) { + return notifier_->StreamBytesToSend(); + } + return bytes_to_transfer_; +} + +QuicByteCount QuicEndpoint::bytes_transferred() const { + if (notifier_ != nullptr) { + return notifier_->StreamBytesSent(); + } + return bytes_transferred_; +} + +void QuicEndpoint::AddBytesToTransfer(QuicByteCount bytes) { + if (notifier_ != nullptr) { + if (notifier_->HasBufferedStreamData()) { + Schedule(clock_->Now()); + } + notifier_->WriteOrBufferData(kDataStream, bytes, NO_FIN); + return; + } + + if (bytes_to_transfer_ > 0) { + Schedule(clock_->Now()); + } + + bytes_to_transfer_ += bytes; + WriteStreamData(); +} + +void QuicEndpoint::OnStreamFrame(const QuicStreamFrame& frame) { + // Verify that the data received always matches the expected. + QUICHE_DCHECK(frame.stream_id == kDataStream); + for (size_t i = 0; i < frame.data_length; i++) { + if (frame.data_buffer[i] != kStreamDataContents) { + wrong_data_received_ = true; + } + } + offsets_received_.Add(frame.offset, frame.offset + frame.data_length); + // Sanity check against very pathological connections. + QUICHE_DCHECK_LE(offsets_received_.Size(), 1000u); +} + +void QuicEndpoint::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {} + +void QuicEndpoint::OnCanWrite() { + if (notifier_ != nullptr) { + notifier_->OnCanWrite(); + return; + } + WriteStreamData(); +} + +bool QuicEndpoint::WillingAndAbleToWrite() const { + if (notifier_ != nullptr) { + return notifier_->WillingToWrite(); + } + return bytes_to_transfer_ != 0; +} +bool QuicEndpoint::ShouldKeepConnectionAlive() const { return true; } + +bool QuicEndpoint::AllowSelfAddressChange() const { return false; } + +bool QuicEndpoint::OnFrameAcked(const QuicFrame& frame, + QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) { + if (notifier_ != nullptr) { + return notifier_->OnFrameAcked(frame, ack_delay_time, receive_timestamp); + } + return false; +} + +void QuicEndpoint::OnFrameLost(const QuicFrame& frame) { + QUICHE_DCHECK(notifier_); + notifier_->OnFrameLost(frame); +} + +bool QuicEndpoint::RetransmitFrames(const QuicFrames& frames, + TransmissionType type) { + QUICHE_DCHECK(notifier_); + return notifier_->RetransmitFrames(frames, type); +} + +bool QuicEndpoint::IsFrameOutstanding(const QuicFrame& frame) const { + QUICHE_DCHECK(notifier_); + return notifier_->IsFrameOutstanding(frame); +} + +bool QuicEndpoint::HasUnackedCryptoData() const { return false; } + +bool QuicEndpoint::HasUnackedStreamData() const { + if (notifier_ != nullptr) { + return notifier_->HasUnackedStreamData(); + } + return false; +} + +HandshakeState QuicEndpoint::GetHandshakeState() const { + return HANDSHAKE_COMPLETE; +} + +WriteStreamDataResult QuicEndpoint::DataProducer::WriteStreamData( + QuicStreamId /*id*/, QuicStreamOffset /*offset*/, QuicByteCount data_length, + QuicDataWriter* writer) { + writer->WriteRepeatedByte(kStreamDataContents, data_length); + return WRITE_SUCCESS; +} + +bool QuicEndpoint::DataProducer::WriteCryptoData(EncryptionLevel /*level*/, + QuicStreamOffset /*offset*/, + QuicByteCount /*data_length*/, + QuicDataWriter* /*writer*/) { + QUIC_BUG(quic_bug_10157_1) + << "QuicEndpoint::DataProducer::WriteCryptoData is unimplemented"; + return false; +} + +void QuicEndpoint::WriteStreamData() { + // Instantiate a flusher which would normally be here due to QuicSession. + QuicConnection::ScopedPacketFlusher flusher(connection_.get()); + + while (bytes_to_transfer_ > 0) { + // Transfer data in chunks of size at most |kWriteChunkSize|. + const size_t transmission_size = + std::min(kWriteChunkSize, bytes_to_transfer_); + + QuicConsumedData consumed_data = connection_->SendStreamData( + kDataStream, transmission_size, bytes_transferred_, NO_FIN); + + QUICHE_DCHECK(consumed_data.bytes_consumed <= transmission_size); + bytes_transferred_ += consumed_data.bytes_consumed; + bytes_to_transfer_ -= consumed_data.bytes_consumed; + if (consumed_data.bytes_consumed != transmission_size) { + return; + } + } +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/quic_endpoint.h b/quiche/quic/test_tools/simulator/quic_endpoint.h new file mode 100644 index 000000000000..6be7bbbe164f --- /dev/null +++ b/quiche/quic/test_tools/simulator/quic_endpoint.h @@ -0,0 +1,171 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUIC_ENDPOINT_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUIC_ENDPOINT_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_trace_visitor.h" +#include "quiche/quic/test_tools/simple_session_notifier.h" +#include "quiche/quic/test_tools/simulator/link.h" +#include "quiche/quic/test_tools/simulator/queue.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint_base.h" + +namespace quic { +namespace simulator { + +// A QUIC connection endpoint. Wraps around QuicConnection. In order to +// initiate a transfer, the caller has to call AddBytesToTransfer(). The data +// transferred is always the same and is always transferred on a single stream. +// The endpoint receives all packets addressed to it, and verifies that the data +// received is what it's supposed to be. +class QuicEndpoint : public QuicEndpointBase, + public QuicConnectionVisitorInterface, + public SessionNotifierInterface { + public: + QuicEndpoint(Simulator* simulator, std::string name, std::string peer_name, + Perspective perspective, QuicConnectionId connection_id); + + QuicByteCount bytes_to_transfer() const; + QuicByteCount bytes_transferred() const; + QuicByteCount bytes_received() const; + bool wrong_data_received() const { return wrong_data_received_; } + + // Send |bytes| bytes. Initiates the transfer if one is not already in + // progress. + void AddBytesToTransfer(QuicByteCount bytes); + + // Begin QuicConnectionVisitorInterface implementation. + void OnStreamFrame(const QuicStreamFrame& frame) override; + void OnCryptoFrame(const QuicCryptoFrame& frame) override; + void OnCanWrite() override; + bool WillingAndAbleToWrite() const override; + bool ShouldKeepConnectionAlive() const override; + + std::string GetStreamsInfoForLogging() const override { return ""; } + void OnWindowUpdateFrame(const QuicWindowUpdateFrame& /*frame*/) override {} + void OnBlockedFrame(const QuicBlockedFrame& /*frame*/) override {} + void OnRstStream(const QuicRstStreamFrame& /*frame*/) override {} + void OnGoAway(const QuicGoAwayFrame& /*frame*/) override {} + void OnMessageReceived(absl::string_view /*message*/) override {} + void OnHandshakeDoneReceived() override {} + void OnNewTokenReceived(absl::string_view /*token*/) override {} + void OnConnectionClosed(const QuicConnectionCloseFrame& /*frame*/, + ConnectionCloseSource /*source*/) override {} + void OnWriteBlocked() override {} + void OnSuccessfulVersionNegotiation( + const ParsedQuicVersion& /*version*/) override {} + void OnPacketReceived(const QuicSocketAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, + bool /*is_connectivity_probe*/) override {} + void OnCongestionWindowChange(QuicTime /*now*/) override {} + void OnConnectionMigration(AddressChangeType /*type*/) override {} + void OnPathDegrading() override {} + void OnForwardProgressMadeAfterPathDegrading() override {} + void OnAckNeedsRetransmittableFrame() override {} + void SendAckFrequency(const QuicAckFrequencyFrame& /*frame*/) override {} + void SendNewConnectionId(const QuicNewConnectionIdFrame& /*frame*/) override { + } + void SendRetireConnectionId(uint64_t /*sequence_number*/) override {} + bool MaybeReserveConnectionId( + const QuicConnectionId& /*server_connection_id*/) override { + return true; + } + void OnServerConnectionIdRetired( + const QuicConnectionId& /*server_connection_id*/) override {} + bool AllowSelfAddressChange() const override; + HandshakeState GetHandshakeState() const override; + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& /*frame*/) override { + return true; + } + bool OnStreamsBlockedFrame( + const QuicStreamsBlockedFrame& /*frame*/) override { + return true; + } + void OnStopSendingFrame(const QuicStopSendingFrame& /*frame*/) override {} + void OnPacketDecrypted(EncryptionLevel /*level*/) override {} + void OnOneRttPacketAcknowledged() override {} + void OnHandshakePacketSent() override {} + void OnKeyUpdate(KeyUpdateReason /*reason*/) override {} + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + return nullptr; + } + void BeforeConnectionCloseSent() override {} + bool ValidateToken(absl::string_view /*token*/) override { return true; } + bool MaybeSendAddressToken() override { return false; } + void OnBandwidthUpdateTimeout() override {} + std::unique_ptr CreateContextForMultiPortPath() + override { + return nullptr; + } + void MigrateToMultiPortPath( + std::unique_ptr /*context*/) override {} + void OnServerPreferredAddressAvailable( + const QuicSocketAddress& /*server_preferred_address*/) override {} + + // End QuicConnectionVisitorInterface implementation. + + // Begin SessionNotifierInterface methods: + bool OnFrameAcked(const QuicFrame& frame, QuicTime::Delta ack_delay_time, + QuicTime receive_timestamp) override; + void OnStreamFrameRetransmitted(const QuicStreamFrame& /*frame*/) override {} + void OnFrameLost(const QuicFrame& frame) override; + bool RetransmitFrames(const QuicFrames& frames, + TransmissionType type) override; + bool IsFrameOutstanding(const QuicFrame& frame) const override; + bool HasUnackedCryptoData() const override; + bool HasUnackedStreamData() const override; + // End SessionNotifierInterface implementation. + + private: + // The producer outputs the repetition of the same byte. That sequence is + // verified by the receiver. + class DataProducer : public QuicStreamFrameDataProducer { + public: + WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + }; + + std::unique_ptr CreateConnection( + Simulator* simulator, std::string name, std::string peer_name, + Perspective perspective, QuicConnectionId connection_id); + + // Write stream data until |bytes_to_transfer_| is zero or the connection is + // write-blocked. + void WriteStreamData(); + + DataProducer producer_; + + QuicByteCount bytes_to_transfer_; + QuicByteCount bytes_transferred_; + + // Set to true if the endpoint receives stream data different from what it + // expects. + bool wrong_data_received_; + + // Record of received offsets in the data stream. + QuicIntervalSet offsets_received_; + + std::unique_ptr notifier_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUIC_ENDPOINT_H_ diff --git a/quiche/quic/test_tools/simulator/quic_endpoint_base.cc b/quiche/quic/test_tools/simulator/quic_endpoint_base.cc new file mode 100644 index 000000000000..d466c96137d5 --- /dev/null +++ b/quiche/quic/test_tools/simulator/quic_endpoint_base.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/quic_endpoint_base.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/crypto_handshake_message.h" +#include "quiche/quic/core/crypto/crypto_protocol.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_data_writer.h" +#include "quiche/quic/platform/api/quic_test_output.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/simulator.h" + +namespace quic { +namespace simulator { + +// Takes a SHA-1 hash of the name and converts it into five 32-bit integers. +static std::vector HashNameIntoFive32BitIntegers(std::string name) { + const std::string hash = test::Sha1Hash(name); + + std::vector output; + uint32_t current_number = 0; + for (size_t i = 0; i < hash.size(); i++) { + current_number = (current_number << 8) + hash[i]; + if (i % 4 == 3) { + output.push_back(i); + current_number = 0; + } + } + + return output; +} + +QuicSocketAddress GetAddressFromName(std::string name) { + const std::vector hash = HashNameIntoFive32BitIntegers(name); + + // Generate a random port between 1025 and 65535. + const uint16_t port = 1025 + hash[0] % (65535 - 1025 + 1); + + // Generate a random 10.x.x.x address, where x is between 1 and 254. + std::string ip_address{"\xa\0\0\0", 4}; + for (size_t i = 1; i < 4; i++) { + ip_address[i] = 1 + hash[i] % 254; + } + QuicIpAddress host; + host.FromPackedString(ip_address.c_str(), ip_address.length()); + return QuicSocketAddress(host, port); +} + +QuicEndpointBase::QuicEndpointBase(Simulator* simulator, std::string name, + std::string peer_name) + : Endpoint(simulator, name), + peer_name_(peer_name), + writer_(this), + nic_tx_queue_(simulator, absl::StrCat(name, " (TX Queue)"), + kMaxOutgoingPacketSize * kTxQueueSize), + connection_(nullptr), + write_blocked_count_(0), + drop_next_packet_(false) { + nic_tx_queue_.set_listener_interface(this); +} + +QuicEndpointBase::~QuicEndpointBase() { + if (trace_visitor_ != nullptr) { + const char* perspective_prefix = + connection_->perspective() == Perspective::IS_CLIENT ? "C" : "S"; + + std::string identifier = absl::StrCat( + perspective_prefix, connection_->connection_id().ToString()); + QuicRecordTrace(identifier, trace_visitor_->trace()->SerializeAsString()); + } +} + +void QuicEndpointBase::DropNextIncomingPacket() { drop_next_packet_ = true; } + +void QuicEndpointBase::RecordTrace() { + trace_visitor_ = std::make_unique(connection_.get()); + connection_->set_debug_visitor(trace_visitor_.get()); +} + +void QuicEndpointBase::AcceptPacket(std::unique_ptr packet) { + if (packet->destination != name_) { + return; + } + if (drop_next_packet_) { + drop_next_packet_ = false; + return; + } + + QuicReceivedPacket received_packet(packet->contents.data(), + packet->contents.size(), clock_->Now()); + connection_->ProcessUdpPacket(connection_->self_address(), + connection_->peer_address(), received_packet); +} + +UnconstrainedPortInterface* QuicEndpointBase::GetRxPort() { return this; } + +void QuicEndpointBase::SetTxPort(ConstrainedPortInterface* port) { + // Any egress done by the endpoint is actually handled by a queue on an NIC. + nic_tx_queue_.set_tx_port(port); +} + +void QuicEndpointBase::OnPacketDequeued() { + if (writer_.IsWriteBlocked() && + (nic_tx_queue_.capacity() - nic_tx_queue_.bytes_queued()) >= + kMaxOutgoingPacketSize) { + writer_.SetWritable(); + connection_->OnCanWrite(); + } +} + +QuicEndpointBase::Writer::Writer(QuicEndpointBase* endpoint) + : endpoint_(endpoint), is_blocked_(false) {} + +QuicEndpointBase::Writer::~Writer() {} + +WriteResult QuicEndpointBase::Writer::WritePacket( + const char* buffer, size_t buf_len, const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/, PerPacketOptions* options) { + QUICHE_DCHECK(!IsWriteBlocked()); + QUICHE_DCHECK(options == nullptr); + QUICHE_DCHECK(buf_len <= kMaxOutgoingPacketSize); + + // Instead of losing a packet, become write-blocked when the egress queue is + // full. + if (endpoint_->nic_tx_queue_.packets_queued() > kTxQueueSize) { + is_blocked_ = true; + endpoint_->write_blocked_count_++; + return WriteResult(WRITE_STATUS_BLOCKED, 0); + } + + auto packet = std::make_unique(); + packet->source = endpoint_->name(); + packet->destination = endpoint_->peer_name_; + packet->tx_timestamp = endpoint_->clock_->Now(); + + packet->contents = std::string(buffer, buf_len); + packet->size = buf_len; + + endpoint_->nic_tx_queue_.AcceptPacket(std::move(packet)); + + return WriteResult(WRITE_STATUS_OK, buf_len); +} + +bool QuicEndpointBase::Writer::IsWriteBlocked() const { return is_blocked_; } + +void QuicEndpointBase::Writer::SetWritable() { is_blocked_ = false; } + +absl::optional QuicEndpointBase::Writer::MessageTooBigErrorCode() const { + return absl::nullopt; +} + +QuicByteCount QuicEndpointBase::Writer::GetMaxPacketSize( + const QuicSocketAddress& /*peer_address*/) const { + return kMaxOutgoingPacketSize; +} + +bool QuicEndpointBase::Writer::SupportsReleaseTime() const { return false; } + +bool QuicEndpointBase::Writer::IsBatchMode() const { return false; } + +QuicPacketBuffer QuicEndpointBase::Writer::GetNextWriteLocation( + const QuicIpAddress& /*self_address*/, + const QuicSocketAddress& /*peer_address*/) { + return {nullptr, nullptr}; +} + +WriteResult QuicEndpointBase::Writer::Flush() { + return WriteResult(WRITE_STATUS_OK, 0); +} + +QuicEndpointMultiplexer::QuicEndpointMultiplexer( + std::string name, const std::vector& endpoints) + : Endpoint((*endpoints.begin())->simulator(), name) { + for (QuicEndpointBase* endpoint : endpoints) { + mapping_.insert(std::make_pair(endpoint->name(), endpoint)); + } +} + +QuicEndpointMultiplexer::~QuicEndpointMultiplexer() {} + +void QuicEndpointMultiplexer::AcceptPacket(std::unique_ptr packet) { + auto key_value_pair_it = mapping_.find(packet->destination); + if (key_value_pair_it == mapping_.end()) { + return; + } + + key_value_pair_it->second->GetRxPort()->AcceptPacket(std::move(packet)); +} +UnconstrainedPortInterface* QuicEndpointMultiplexer::GetRxPort() { + return this; +} +void QuicEndpointMultiplexer::SetTxPort(ConstrainedPortInterface* port) { + for (auto& key_value_pair : mapping_) { + key_value_pair.second->SetTxPort(port); + } +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/quic_endpoint_base.h b/quiche/quic/test_tools/simulator/quic_endpoint_base.h new file mode 100644 index 000000000000..540b2852bc8d --- /dev/null +++ b/quiche/quic/test_tools/simulator/quic_endpoint_base.h @@ -0,0 +1,160 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUIC_ENDPOINT_BASE_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUIC_ENDPOINT_BASE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/crypto/null_decrypter.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_stream_frame_data_producer.h" +#include "quiche/quic/core/quic_trace_visitor.h" +#include "quiche/quic/test_tools/mock_connection_id_generator.h" +#include "quiche/quic/test_tools/simple_session_notifier.h" +#include "quiche/quic/test_tools/simulator/link.h" +#include "quiche/quic/test_tools/simulator/queue.h" + +namespace quic { +namespace simulator { + +// Size of the TX queue used by the kernel/NIC. 1000 is the Linux +// kernel default. +const QuicByteCount kTxQueueSize = 1000; + +// Generate a random local network host-port tuple based on the name of the +// endpoint. +QuicSocketAddress GetAddressFromName(std::string name); + +// A QUIC connection endpoint. If the specific data transmitted does not matter +// (e.g. for congestion control purposes), QuicEndpoint is the subclass that +// transmits dummy data. If the actual semantics of the connection matter, +// subclassing QuicEndpointBase is required. +class QuicEndpointBase : public Endpoint, + public UnconstrainedPortInterface, + public Queue::ListenerInterface { + public: + // Does not create the connection; the subclass has to create connection by + // itself. + QuicEndpointBase(Simulator* simulator, std::string name, + std::string peer_name); + ~QuicEndpointBase() override; + + QuicConnection* connection() { return connection_.get(); } + size_t write_blocked_count() { return write_blocked_count_; } + + // Drop the next packet upon receipt. + void DropNextIncomingPacket(); + + // UnconstrainedPortInterface method. Called whenever the endpoint receives a + // packet. + void AcceptPacket(std::unique_ptr packet) override; + + // Enables logging of the connection trace at the end of the unit test. + void RecordTrace(); + + // Begin Endpoint implementation. + UnconstrainedPortInterface* GetRxPort() override; + void SetTxPort(ConstrainedPortInterface* port) override; + // End Endpoint implementation. + + // Actor method. + void Act() override {} + + // Queue::ListenerInterface method. + void OnPacketDequeued() override; + + protected: + // A Writer object that writes into the |nic_tx_queue_|. + class Writer : public QuicPacketWriter { + public: + explicit Writer(QuicEndpointBase* endpoint); + ~Writer() override; + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override; + bool IsWriteBlocked() const override; + void SetWritable() override; + absl::optional MessageTooBigErrorCode() const override; + QuicByteCount GetMaxPacketSize( + const QuicSocketAddress& peer_address) const override; + bool SupportsReleaseTime() const override; + bool IsBatchMode() const override; + QuicPacketBuffer GetNextWriteLocation( + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address) override; + WriteResult Flush() override; + + private: + QuicEndpointBase* endpoint_; + + bool is_blocked_; + }; + + // The producer outputs the repetition of the same byte. That sequence is + // verified by the receiver. + class DataProducer : public QuicStreamFrameDataProducer { + public: + WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + }; + + std::string peer_name_; + + Writer writer_; + // The queue for the outgoing packets. In reality, this might be either on + // the network card, or in the kernel, but for concreteness we assume it's on + // the network card. + Queue nic_tx_queue_; + // Created by the subclass. + std::unique_ptr connection_; + + // Counts the number of times the writer became write-blocked. + size_t write_blocked_count_; + + // If true, drop the next packet when receiving it. + bool drop_next_packet_; + + std::unique_ptr trace_visitor_; + + test::MockConnectionIdGenerator connection_id_generator_; +}; + +// Multiplexes multiple connections at the same host on the network. +class QuicEndpointMultiplexer : public Endpoint, + public UnconstrainedPortInterface { + public: + QuicEndpointMultiplexer(std::string name, + const std::vector& endpoints); + ~QuicEndpointMultiplexer() override; + + // Receives a packet and passes it to the specified endpoint if that endpoint + // is one of the endpoints being multiplexed, otherwise ignores the packet. + void AcceptPacket(std::unique_ptr packet) override; + UnconstrainedPortInterface* GetRxPort() override; + + // Sets the egress port for all the endpoints being multiplexed. + void SetTxPort(ConstrainedPortInterface* port) override; + + void Act() override {} + + private: + absl::flat_hash_map mapping_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUIC_ENDPOINT_BASE_H_ diff --git a/quiche/quic/test_tools/simulator/quic_endpoint_test.cc b/quiche/quic/test_tools/simulator/quic_endpoint_test.cc new file mode 100644 index 000000000000..c247eb52aac2 --- /dev/null +++ b/quiche/quic/test_tools/simulator/quic_endpoint_test.cc @@ -0,0 +1,207 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/quic_endpoint.h" + +#include + +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/test_tools/simulator/switch.h" + +using ::testing::_; +using ::testing::NiceMock; +using ::testing::Return; + +namespace quic { +namespace simulator { + +const QuicBandwidth kDefaultBandwidth = + QuicBandwidth::FromKBitsPerSecond(10 * 1000); +const QuicTime::Delta kDefaultPropagationDelay = + QuicTime::Delta::FromMilliseconds(20); +const QuicByteCount kDefaultBdp = kDefaultBandwidth * kDefaultPropagationDelay; + +// A simple test harness where all hosts are connected to a switch with +// identical links. +class QuicEndpointTest : public quic::test::QuicTest { + public: + QuicEndpointTest() + : simulator_(), switch_(&simulator_, "Switch", 8, kDefaultBdp * 2) {} + + protected: + Simulator simulator_; + Switch switch_; + + std::unique_ptr Link(Endpoint* a, Endpoint* b) { + return std::make_unique(a, b, kDefaultBandwidth, + kDefaultPropagationDelay); + } + + std::unique_ptr CustomLink(Endpoint* a, Endpoint* b, + uint64_t extra_rtt_ms) { + return std::make_unique( + a, b, kDefaultBandwidth, + kDefaultPropagationDelay + + QuicTime::Delta::FromMilliseconds(extra_rtt_ms)); + } +}; + +// Test transmission from one host to another. +TEST_F(QuicEndpointTest, OneWayTransmission) { + QuicEndpoint endpoint_a(&simulator_, "Endpoint A", "Endpoint B", + Perspective::IS_CLIENT, test::TestConnectionId(42)); + QuicEndpoint endpoint_b(&simulator_, "Endpoint B", "Endpoint A", + Perspective::IS_SERVER, test::TestConnectionId(42)); + auto link_a = Link(&endpoint_a, switch_.port(1)); + auto link_b = Link(&endpoint_b, switch_.port(2)); + + // First transmit a small, packet-size chunk of data. + endpoint_a.AddBytesToTransfer(600); + QuicTime end_time = + simulator_.GetClock()->Now() + QuicTime::Delta::FromMilliseconds(1000); + simulator_.RunUntil( + [this, end_time]() { return simulator_.GetClock()->Now() >= end_time; }); + + EXPECT_EQ(600u, endpoint_a.bytes_transferred()); + ASSERT_EQ(600u, endpoint_b.bytes_received()); + EXPECT_FALSE(endpoint_a.wrong_data_received()); + EXPECT_FALSE(endpoint_b.wrong_data_received()); + + // After a small chunk succeeds, try to transfer 2 MiB. + endpoint_a.AddBytesToTransfer(2 * 1024 * 1024); + end_time = simulator_.GetClock()->Now() + QuicTime::Delta::FromSeconds(5); + simulator_.RunUntil( + [this, end_time]() { return simulator_.GetClock()->Now() >= end_time; }); + + const QuicByteCount total_bytes_transferred = 600 + 2 * 1024 * 1024; + EXPECT_EQ(total_bytes_transferred, endpoint_a.bytes_transferred()); + EXPECT_EQ(total_bytes_transferred, endpoint_b.bytes_received()); + EXPECT_EQ(0u, endpoint_a.write_blocked_count()); + EXPECT_FALSE(endpoint_a.wrong_data_received()); + EXPECT_FALSE(endpoint_b.wrong_data_received()); +} + +// Test the situation in which the writer becomes write-blocked. +TEST_F(QuicEndpointTest, WriteBlocked) { + QuicEndpoint endpoint_a(&simulator_, "Endpoint A", "Endpoint B", + Perspective::IS_CLIENT, test::TestConnectionId(42)); + QuicEndpoint endpoint_b(&simulator_, "Endpoint B", "Endpoint A", + Perspective::IS_SERVER, test::TestConnectionId(42)); + auto link_a = Link(&endpoint_a, switch_.port(1)); + auto link_b = Link(&endpoint_b, switch_.port(2)); + + // Will be owned by the sent packet manager. + auto* sender = new NiceMock(); + EXPECT_CALL(*sender, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*sender, PacingRate(_)) + .WillRepeatedly(Return(10 * kDefaultBandwidth)); + EXPECT_CALL(*sender, BandwidthEstimate()) + .WillRepeatedly(Return(10 * kDefaultBandwidth)); + EXPECT_CALL(*sender, GetCongestionWindow()) + .WillRepeatedly(Return(kMaxOutgoingPacketSize * + GetQuicFlag(quic_max_congestion_window))); + test::QuicConnectionPeer::SetSendAlgorithm(endpoint_a.connection(), sender); + + // First transmit a small, packet-size chunk of data. + QuicByteCount bytes_to_transfer = 3 * 1024 * 1024; + endpoint_a.AddBytesToTransfer(bytes_to_transfer); + QuicTime end_time = + simulator_.GetClock()->Now() + QuicTime::Delta::FromSeconds(30); + simulator_.RunUntil([this, &endpoint_b, bytes_to_transfer, end_time]() { + return endpoint_b.bytes_received() == bytes_to_transfer || + simulator_.GetClock()->Now() >= end_time; + }); + + EXPECT_EQ(bytes_to_transfer, endpoint_a.bytes_transferred()); + EXPECT_EQ(bytes_to_transfer, endpoint_b.bytes_received()); + EXPECT_GT(endpoint_a.write_blocked_count(), 0u); + EXPECT_FALSE(endpoint_a.wrong_data_received()); + EXPECT_FALSE(endpoint_b.wrong_data_received()); +} + +// Test transmission of 1 MiB of data between two hosts simultaneously in both +// directions. +TEST_F(QuicEndpointTest, TwoWayTransmission) { + QuicEndpoint endpoint_a(&simulator_, "Endpoint A", "Endpoint B", + Perspective::IS_CLIENT, test::TestConnectionId(42)); + QuicEndpoint endpoint_b(&simulator_, "Endpoint B", "Endpoint A", + Perspective::IS_SERVER, test::TestConnectionId(42)); + auto link_a = Link(&endpoint_a, switch_.port(1)); + auto link_b = Link(&endpoint_b, switch_.port(2)); + + endpoint_a.RecordTrace(); + endpoint_b.RecordTrace(); + + endpoint_a.AddBytesToTransfer(1024 * 1024); + endpoint_b.AddBytesToTransfer(1024 * 1024); + QuicTime end_time = + simulator_.GetClock()->Now() + QuicTime::Delta::FromSeconds(5); + simulator_.RunUntil( + [this, end_time]() { return simulator_.GetClock()->Now() >= end_time; }); + + EXPECT_EQ(1024u * 1024u, endpoint_a.bytes_transferred()); + EXPECT_EQ(1024u * 1024u, endpoint_b.bytes_transferred()); + EXPECT_EQ(1024u * 1024u, endpoint_a.bytes_received()); + EXPECT_EQ(1024u * 1024u, endpoint_b.bytes_received()); + EXPECT_FALSE(endpoint_a.wrong_data_received()); + EXPECT_FALSE(endpoint_b.wrong_data_received()); +} + +// Simulate three hosts trying to send data to a fourth one simultaneously. +TEST_F(QuicEndpointTest, Competition) { + auto endpoint_a = std::make_unique( + &simulator_, "Endpoint A", "Endpoint D (A)", Perspective::IS_CLIENT, + test::TestConnectionId(42)); + auto endpoint_b = std::make_unique( + &simulator_, "Endpoint B", "Endpoint D (B)", Perspective::IS_CLIENT, + test::TestConnectionId(43)); + auto endpoint_c = std::make_unique( + &simulator_, "Endpoint C", "Endpoint D (C)", Perspective::IS_CLIENT, + test::TestConnectionId(44)); + auto endpoint_d_a = std::make_unique( + &simulator_, "Endpoint D (A)", "Endpoint A", Perspective::IS_SERVER, + test::TestConnectionId(42)); + auto endpoint_d_b = std::make_unique( + &simulator_, "Endpoint D (B)", "Endpoint B", Perspective::IS_SERVER, + test::TestConnectionId(43)); + auto endpoint_d_c = std::make_unique( + &simulator_, "Endpoint D (C)", "Endpoint C", Perspective::IS_SERVER, + test::TestConnectionId(44)); + QuicEndpointMultiplexer endpoint_d( + "Endpoint D", + {endpoint_d_a.get(), endpoint_d_b.get(), endpoint_d_c.get()}); + + // Create links with slightly different RTTs in order to avoid pathological + // side-effects of packets entering the queue at the exactly same time. + auto link_a = CustomLink(endpoint_a.get(), switch_.port(1), 0); + auto link_b = CustomLink(endpoint_b.get(), switch_.port(2), 1); + auto link_c = CustomLink(endpoint_c.get(), switch_.port(3), 2); + auto link_d = Link(&endpoint_d, switch_.port(4)); + + endpoint_a->AddBytesToTransfer(2 * 1024 * 1024); + endpoint_b->AddBytesToTransfer(2 * 1024 * 1024); + endpoint_c->AddBytesToTransfer(2 * 1024 * 1024); + QuicTime end_time = + simulator_.GetClock()->Now() + QuicTime::Delta::FromSeconds(12); + simulator_.RunUntil( + [this, end_time]() { return simulator_.GetClock()->Now() >= end_time; }); + + for (QuicEndpoint* endpoint : + {endpoint_a.get(), endpoint_b.get(), endpoint_c.get()}) { + EXPECT_EQ(2u * 1024u * 1024u, endpoint->bytes_transferred()); + EXPECT_GE(endpoint->connection()->GetStats().packets_lost, 0u); + } + for (QuicEndpoint* endpoint : + {endpoint_d_a.get(), endpoint_d_b.get(), endpoint_d_c.get()}) { + EXPECT_EQ(2u * 1024u * 1024u, endpoint->bytes_received()); + EXPECT_FALSE(endpoint->wrong_data_received()); + } +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/simulator.cc b/quiche/quic/test_tools/simulator/simulator.cc new file mode 100644 index 000000000000..49a0ae220fa4 --- /dev/null +++ b/quiche/quic/test_tools/simulator/simulator.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/simulator.h" + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { +namespace simulator { + +Simulator::Simulator() : Simulator(nullptr) {} + +Simulator::Simulator(QuicRandom* random_generator) + : random_generator_(random_generator), + alarm_factory_(this, "Default Alarm Manager"), + run_for_should_stop_(false), + enable_random_delays_(false) { + run_for_alarm_.reset( + alarm_factory_.CreateAlarm(new RunForDelegate(&run_for_should_stop_))); +} + +Simulator::~Simulator() { + // Ensure that Actor under run_for_alarm_ is removed before Simulator data + // structures are destructed. + run_for_alarm_.reset(); +} + +Simulator::Clock::Clock() : now_(kStartTime) {} + +QuicTime Simulator::Clock::ApproximateNow() const { return now_; } + +QuicTime Simulator::Clock::Now() const { return now_; } + +QuicWallTime Simulator::Clock::WallNow() const { + return QuicWallTime::FromUNIXMicroseconds( + (now_ - QuicTime::Zero()).ToMicroseconds()); +} + +void Simulator::AddActor(Actor* actor) { + auto emplace_times_result = + scheduled_times_.insert(std::make_pair(actor, QuicTime::Infinite())); + auto emplace_names_result = actor_names_.insert(actor->name()); + + // Ensure that the object was actually placed into the map. + QUICHE_DCHECK(emplace_times_result.second); + QUICHE_DCHECK(emplace_names_result.second); +} + +void Simulator::RemoveActor(Actor* actor) { + auto scheduled_time_it = scheduled_times_.find(actor); + auto actor_names_it = actor_names_.find(actor->name()); + QUICHE_DCHECK(scheduled_time_it != scheduled_times_.end()); + QUICHE_DCHECK(actor_names_it != actor_names_.end()); + + QuicTime scheduled_time = scheduled_time_it->second; + if (scheduled_time != QuicTime::Infinite()) { + Unschedule(actor); + } + + scheduled_times_.erase(scheduled_time_it); + actor_names_.erase(actor_names_it); +} + +void Simulator::Schedule(Actor* actor, QuicTime new_time) { + auto scheduled_time_it = scheduled_times_.find(actor); + QUICHE_DCHECK(scheduled_time_it != scheduled_times_.end()); + QuicTime scheduled_time = scheduled_time_it->second; + + if (scheduled_time <= new_time) { + return; + } + + if (scheduled_time != QuicTime::Infinite()) { + Unschedule(actor); + } + + scheduled_time_it->second = new_time; + schedule_.insert(std::make_pair(new_time, actor)); +} + +void Simulator::Unschedule(Actor* actor) { + auto scheduled_time_it = scheduled_times_.find(actor); + QUICHE_DCHECK(scheduled_time_it != scheduled_times_.end()); + QuicTime scheduled_time = scheduled_time_it->second; + + QUICHE_DCHECK(scheduled_time != QuicTime::Infinite()); + auto range = schedule_.equal_range(scheduled_time); + for (auto it = range.first; it != range.second; ++it) { + if (it->second == actor) { + schedule_.erase(it); + scheduled_time_it->second = QuicTime::Infinite(); + return; + } + } + QUICHE_DCHECK(false); +} + +const QuicClock* Simulator::GetClock() const { return &clock_; } + +QuicRandom* Simulator::GetRandomGenerator() { + if (random_generator_ == nullptr) { + random_generator_ = QuicRandom::GetInstance(); + } + + return random_generator_; +} + +quiche::QuicheBufferAllocator* Simulator::GetStreamSendBufferAllocator() { + return &buffer_allocator_; +} + +QuicAlarmFactory* Simulator::GetAlarmFactory() { return &alarm_factory_; } + +Simulator::RunForDelegate::RunForDelegate(bool* run_for_should_stop) + : run_for_should_stop_(run_for_should_stop) {} + +void Simulator::RunForDelegate::OnAlarm() { *run_for_should_stop_ = true; } + +void Simulator::RunFor(QuicTime::Delta time_span) { + QUICHE_DCHECK(!run_for_alarm_->IsSet()); + + // RunFor() ensures that the simulation stops at the exact time specified by + // scheduling an alarm at that point and using that alarm to abort the + // simulation. An alarm is necessary because otherwise it is possible that + // nothing is scheduled at |end_time|, so the simulation will either go + // further than requested or stop before reaching |end_time|. + const QuicTime end_time = clock_.Now() + time_span; + run_for_alarm_->Set(end_time); + run_for_should_stop_ = false; + bool simulation_result = RunUntil([this]() { return run_for_should_stop_; }); + + QUICHE_DCHECK(simulation_result); + QUICHE_DCHECK(clock_.Now() == end_time); +} + +void Simulator::HandleNextScheduledActor() { + const auto current_event_it = schedule_.begin(); + QuicTime event_time = current_event_it->first; + Actor* actor = current_event_it->second; + QUIC_DVLOG(3) << "At t = " << event_time.ToDebuggingValue() << ", calling " + << actor->name(); + + Unschedule(actor); + + if (clock_.Now() > event_time) { + QUIC_BUG(quic_bug_10150_1) + << "Error: event registered by [" << actor->name() + << "] requires travelling back in time. Current time: " + << clock_.Now().ToDebuggingValue() + << ", scheduled time: " << event_time.ToDebuggingValue(); + } + clock_.now_ = event_time; + + actor->Act(); +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/simulator.h b/quiche/quic/test_tools/simulator/simulator.h new file mode 100644 index 000000000000..805d7c2dc88c --- /dev/null +++ b/quiche/quic/test_tools/simulator/simulator.h @@ -0,0 +1,166 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_SIMULATOR_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_SIMULATOR_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/test_tools/simulator/actor.h" +#include "quiche/quic/test_tools/simulator/alarm_factory.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { +namespace simulator { + +// Simulator is responsible for scheduling actors in the simulation and +// providing basic utility interfaces (clock, alarms, RNG and others). +class Simulator : public QuicConnectionHelperInterface { + public: + Simulator(); + explicit Simulator(QuicRandom* random_generator); + Simulator(const Simulator&) = delete; + Simulator& operator=(const Simulator&) = delete; + ~Simulator() override; + + // Schedule the specified actor. This method will ensure that |actor| is + // called at |new_time| at latest. If Schedule() is called multiple times + // before the Actor is called, Act() is called exactly once, at the earliest + // time requested, and the Actor has to reschedule itself manually for the + // subsequent times if they are still necessary. + void Schedule(Actor* actor, QuicTime new_time); + + // Remove the specified actor from the schedule. + void Unschedule(Actor* actor); + + // Begin QuicConnectionHelperInterface implementation. + const QuicClock* GetClock() const override; + QuicRandom* GetRandomGenerator() override; + quiche::QuicheBufferAllocator* GetStreamSendBufferAllocator() override; + // End QuicConnectionHelperInterface implementation. + + QuicAlarmFactory* GetAlarmFactory(); + + void set_random_generator(QuicRandom* random) { random_generator_ = random; } + + bool enable_random_delays() const { return enable_random_delays_; } + + // Run the simulation until either no actors are scheduled or + // |termination_predicate| returns true. Returns true if terminated due to + // predicate, and false otherwise. + template + bool RunUntil(TerminationPredicate termination_predicate); + + // Same as RunUntil, except this function also accepts a |deadline|, and will + // return false if the deadline is exceeded. + template + bool RunUntilOrTimeout(TerminationPredicate termination_predicate, + QuicTime::Delta deadline); + + // Runs the simulation for exactly the specified |time_span|. + void RunFor(QuicTime::Delta time_span); + + private: + friend class Actor; + + class Clock : public QuicClock { + public: + // Do not start at zero as certain code can treat zero as an invalid + // timestamp. + const QuicTime kStartTime = + QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(1); + + Clock(); + + QuicTime ApproximateNow() const override; + QuicTime Now() const override; + QuicWallTime WallNow() const override; + + QuicTime now_; + }; + + // The delegate used for RunFor(). + class RunForDelegate : public QuicAlarm::DelegateWithoutContext { + public: + explicit RunForDelegate(bool* run_for_should_stop); + void OnAlarm() override; + + private: + // Pointer to |run_for_should_stop_| in the parent simulator. + bool* run_for_should_stop_; + }; + + // Register an actor with the simulator. Invoked by Actor constructor. + void AddActor(Actor* actor); + + // Unregister an actor with the simulator. Invoked by Actor destructor. + void RemoveActor(Actor* actor); + + // Finds the next scheduled actor, advances time to the schedule time and + // notifies the actor. + void HandleNextScheduledActor(); + + Clock clock_; + QuicRandom* random_generator_; + quiche::SimpleBufferAllocator buffer_allocator_; + AlarmFactory alarm_factory_; + + // Alarm for RunFor() method. + std::unique_ptr run_for_alarm_; + // Flag used to stop simulations ran via RunFor(). + bool run_for_should_stop_; + + // Indicates whether the simulator should add random delays on the links in + // order to avoid synchronization issues. + bool enable_random_delays_; + + // Schedule of when the actors will be executed via an Act() call. The + // schedule is subject to the following invariants: + // - An actor cannot be scheduled for a later time than it's currently in the + // schedule. + // - An actor is removed from schedule either immediately before Act() is + // called or by explicitly calling Unschedule(). + // - Each Actor appears in the map at most once. + std::multimap schedule_; + // For each actor, maintain the time it is scheduled at. The value for + // unscheduled actors is QuicTime::Infinite(). + absl::flat_hash_map scheduled_times_; + absl::flat_hash_set actor_names_; +}; + +template +bool Simulator::RunUntil(TerminationPredicate termination_predicate) { + bool predicate_value = false; + while (true) { + predicate_value = termination_predicate(); + if (predicate_value || schedule_.empty()) { + break; + } + HandleNextScheduledActor(); + } + return predicate_value; +} + +template +bool Simulator::RunUntilOrTimeout(TerminationPredicate termination_predicate, + QuicTime::Delta timeout) { + QuicTime end_time = clock_.Now() + timeout; + bool return_value = RunUntil([end_time, &termination_predicate, this]() { + return termination_predicate() || clock_.Now() >= end_time; + }); + + if (clock_.Now() >= end_time) { + return false; + } + return return_value; +} + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_SIMULATOR_H_ diff --git a/quiche/quic/test_tools/simulator/simulator_test.cc b/quiche/quic/test_tools/simulator/simulator_test.cc new file mode 100644 index 000000000000..4ae04a7803eb --- /dev/null +++ b/quiche/quic/test_tools/simulator/simulator_test.cc @@ -0,0 +1,827 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/simulator.h" + +#include + +#include "absl/container/node_hash_map.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/alarm_factory.h" +#include "quiche/quic/test_tools/simulator/link.h" +#include "quiche/quic/test_tools/simulator/packet_filter.h" +#include "quiche/quic/test_tools/simulator/queue.h" +#include "quiche/quic/test_tools/simulator/switch.h" +#include "quiche/quic/test_tools/simulator/traffic_policer.h" + +using testing::_; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace simulator { + +// A simple counter that increments its value by 1 every specified period. +class Counter : public Actor { + public: + Counter(Simulator* simulator, std::string name, QuicTime::Delta period) + : Actor(simulator, name), value_(-1), period_(period) { + Schedule(clock_->Now()); + } + ~Counter() override {} + + inline int get_value() const { return value_; } + + void Act() override { + ++value_; + QUIC_DVLOG(1) << name_ << " has value " << value_ << " at time " + << clock_->Now().ToDebuggingValue(); + Schedule(clock_->Now() + period_); + } + + private: + int value_; + QuicTime::Delta period_; +}; + +class SimulatorTest : public quic::test::QuicTest {}; + +// Test that the basic event handling works, and that Actors can be created and +// destroyed mid-simulation. +TEST_F(SimulatorTest, Counters) { + Simulator simulator; + for (int i = 0; i < 2; ++i) { + Counter fast_counter(&simulator, "fast_counter", + QuicTime::Delta::FromSeconds(3)); + Counter slow_counter(&simulator, "slow_counter", + QuicTime::Delta::FromSeconds(10)); + + simulator.RunUntil( + [&slow_counter]() { return slow_counter.get_value() >= 10; }); + + EXPECT_EQ(10, slow_counter.get_value()); + EXPECT_EQ(10 * 10 / 3, fast_counter.get_value()); + } +} + +// A port which counts the number of packets received on it, both total and +// per-destination. +class CounterPort : public UnconstrainedPortInterface { + public: + CounterPort() { Reset(); } + ~CounterPort() override {} + + inline QuicByteCount bytes() const { return bytes_; } + inline QuicPacketCount packets() const { return packets_; } + + void AcceptPacket(std::unique_ptr packet) override { + bytes_ += packet->size; + packets_ += 1; + + per_destination_packet_counter_[packet->destination] += 1; + } + + void Reset() { + bytes_ = 0; + packets_ = 0; + per_destination_packet_counter_.clear(); + } + + QuicPacketCount CountPacketsForDestination(std::string destination) const { + auto result_it = per_destination_packet_counter_.find(destination); + if (result_it == per_destination_packet_counter_.cend()) { + return 0; + } + return result_it->second; + } + + private: + QuicByteCount bytes_; + QuicPacketCount packets_; + + absl::node_hash_map + per_destination_packet_counter_; +}; + +// Sends the packet to the specified destination at the uplink rate. Provides a +// CounterPort as an Rx interface. +class LinkSaturator : public Endpoint { + public: + LinkSaturator(Simulator* simulator, std::string name, + QuicByteCount packet_size, std::string destination) + : Endpoint(simulator, name), + packet_size_(packet_size), + destination_(std::move(destination)), + bytes_transmitted_(0), + packets_transmitted_(0) { + Schedule(clock_->Now()); + } + + void Act() override { + if (tx_port_->TimeUntilAvailable().IsZero()) { + auto packet = std::make_unique(); + packet->source = name_; + packet->destination = destination_; + packet->tx_timestamp = clock_->Now(); + packet->size = packet_size_; + + tx_port_->AcceptPacket(std::move(packet)); + + bytes_transmitted_ += packet_size_; + packets_transmitted_ += 1; + } + + Schedule(clock_->Now() + tx_port_->TimeUntilAvailable()); + } + + UnconstrainedPortInterface* GetRxPort() override { + return static_cast(&rx_port_); + } + + void SetTxPort(ConstrainedPortInterface* port) override { tx_port_ = port; } + + CounterPort* counter() { return &rx_port_; } + + inline QuicByteCount bytes_transmitted() const { return bytes_transmitted_; } + inline QuicPacketCount packets_transmitted() const { + return packets_transmitted_; + } + + void Pause() { Unschedule(); } + void Resume() { Schedule(clock_->Now()); } + + private: + QuicByteCount packet_size_; + std::string destination_; + + ConstrainedPortInterface* tx_port_; + CounterPort rx_port_; + + QuicByteCount bytes_transmitted_; + QuicPacketCount packets_transmitted_; +}; + +// Saturate a symmetric link and verify that the number of packets sent and +// received is correct. +TEST_F(SimulatorTest, DirectLinkSaturation) { + Simulator simulator; + LinkSaturator saturator_a(&simulator, "Saturator A", 1000, "Saturator B"); + LinkSaturator saturator_b(&simulator, "Saturator B", 100, "Saturator A"); + SymmetricLink link(&saturator_a, &saturator_b, + QuicBandwidth::FromKBytesPerSecond(1000), + QuicTime::Delta::FromMilliseconds(100) + + QuicTime::Delta::FromMicroseconds(1)); + + const QuicTime start_time = simulator.GetClock()->Now(); + const QuicTime after_first_50_ms = + start_time + QuicTime::Delta::FromMilliseconds(50); + simulator.RunUntil([&simulator, after_first_50_ms]() { + return simulator.GetClock()->Now() >= after_first_50_ms; + }); + EXPECT_LE(1000u * 50u, saturator_a.bytes_transmitted()); + EXPECT_GE(1000u * 51u, saturator_a.bytes_transmitted()); + EXPECT_LE(1000u * 50u, saturator_b.bytes_transmitted()); + EXPECT_GE(1000u * 51u, saturator_b.bytes_transmitted()); + EXPECT_LE(50u, saturator_a.packets_transmitted()); + EXPECT_GE(51u, saturator_a.packets_transmitted()); + EXPECT_LE(500u, saturator_b.packets_transmitted()); + EXPECT_GE(501u, saturator_b.packets_transmitted()); + EXPECT_EQ(0u, saturator_a.counter()->bytes()); + EXPECT_EQ(0u, saturator_b.counter()->bytes()); + + simulator.RunUntil([&saturator_a, &saturator_b]() { + if (saturator_a.counter()->packets() > 1000 || + saturator_b.counter()->packets() > 100) { + ADD_FAILURE() << "The simulation did not arrive at the expected " + "termination contidition. Saturator A counter: " + << saturator_a.counter()->packets() + << ", saturator B counter: " + << saturator_b.counter()->packets(); + return true; + } + + return saturator_a.counter()->packets() == 1000 && + saturator_b.counter()->packets() == 100; + }); + EXPECT_EQ(201u, saturator_a.packets_transmitted()); + EXPECT_EQ(2001u, saturator_b.packets_transmitted()); + EXPECT_EQ(201u * 1000, saturator_a.bytes_transmitted()); + EXPECT_EQ(2001u * 100, saturator_b.bytes_transmitted()); + + EXPECT_EQ(1000u, + saturator_a.counter()->CountPacketsForDestination("Saturator A")); + EXPECT_EQ(100u, + saturator_b.counter()->CountPacketsForDestination("Saturator B")); + EXPECT_EQ(0u, + saturator_a.counter()->CountPacketsForDestination("Saturator B")); + EXPECT_EQ(0u, + saturator_b.counter()->CountPacketsForDestination("Saturator A")); + + const QuicTime end_time = simulator.GetClock()->Now(); + const QuicBandwidth observed_bandwidth = QuicBandwidth::FromBytesAndTimeDelta( + saturator_a.bytes_transmitted(), end_time - start_time); + EXPECT_APPROX_EQ(link.bandwidth(), observed_bandwidth, 0.01f); +} + +// Accepts packets and stores them internally. +class PacketAcceptor : public ConstrainedPortInterface { + public: + void AcceptPacket(std::unique_ptr packet) override { + packets_.emplace_back(std::move(packet)); + } + + QuicTime::Delta TimeUntilAvailable() override { + return QuicTime::Delta::Zero(); + } + + std::vector>* packets() { return &packets_; } + + private: + std::vector> packets_; +}; + +// Ensure the queue behaves correctly with accepting packets. +TEST_F(SimulatorTest, Queue) { + Simulator simulator; + Queue queue(&simulator, "Queue", 1000); + PacketAcceptor acceptor; + queue.set_tx_port(&acceptor); + + EXPECT_EQ(0u, queue.bytes_queued()); + EXPECT_EQ(0u, queue.packets_queued()); + EXPECT_EQ(0u, acceptor.packets()->size()); + + auto first_packet = std::make_unique(); + first_packet->size = 600; + queue.AcceptPacket(std::move(first_packet)); + EXPECT_EQ(600u, queue.bytes_queued()); + EXPECT_EQ(1u, queue.packets_queued()); + EXPECT_EQ(0u, acceptor.packets()->size()); + + // The second packet does not fit and is dropped. + auto second_packet = std::make_unique(); + second_packet->size = 500; + queue.AcceptPacket(std::move(second_packet)); + EXPECT_EQ(600u, queue.bytes_queued()); + EXPECT_EQ(1u, queue.packets_queued()); + EXPECT_EQ(0u, acceptor.packets()->size()); + + auto third_packet = std::make_unique(); + third_packet->size = 400; + queue.AcceptPacket(std::move(third_packet)); + EXPECT_EQ(1000u, queue.bytes_queued()); + EXPECT_EQ(2u, queue.packets_queued()); + EXPECT_EQ(0u, acceptor.packets()->size()); + + // Run until there is nothing scheduled, so that the queue can deplete. + simulator.RunUntil([]() { return false; }); + EXPECT_EQ(0u, queue.bytes_queued()); + EXPECT_EQ(0u, queue.packets_queued()); + ASSERT_EQ(2u, acceptor.packets()->size()); + EXPECT_EQ(600u, acceptor.packets()->at(0)->size); + EXPECT_EQ(400u, acceptor.packets()->at(1)->size); +} + +// Simulate a situation where the bottleneck link is 10 times slower than the +// uplink, and they are separated by a queue. +TEST_F(SimulatorTest, QueueBottleneck) { + const QuicBandwidth local_bandwidth = + QuicBandwidth::FromKBytesPerSecond(1000); + const QuicBandwidth bottleneck_bandwidth = 0.1f * local_bandwidth; + const QuicTime::Delta local_propagation_delay = + QuicTime::Delta::FromMilliseconds(1); + const QuicTime::Delta bottleneck_propagation_delay = + QuicTime::Delta::FromMilliseconds(20); + const QuicByteCount bdp = + bottleneck_bandwidth * + (local_propagation_delay + bottleneck_propagation_delay); + + Simulator simulator; + LinkSaturator saturator(&simulator, "Saturator", 1000, "Counter"); + ASSERT_GE(bdp, 1000u); + Queue queue(&simulator, "Queue", bdp); + CounterPort counter; + + OneWayLink local_link(&simulator, "Local link", &queue, local_bandwidth, + local_propagation_delay); + OneWayLink bottleneck_link(&simulator, "Bottleneck link", &counter, + bottleneck_bandwidth, + bottleneck_propagation_delay); + saturator.SetTxPort(&local_link); + queue.set_tx_port(&bottleneck_link); + + static const QuicPacketCount packets_received = 1000; + simulator.RunUntil( + [&counter]() { return counter.packets() == packets_received; }); + const double loss_ratio = 1 - static_cast(packets_received) / + saturator.packets_transmitted(); + EXPECT_NEAR(loss_ratio, 0.9, 0.001); +} + +// Verify that the queue of exactly one packet allows the transmission to +// actually go through. +TEST_F(SimulatorTest, OnePacketQueue) { + const QuicBandwidth local_bandwidth = + QuicBandwidth::FromKBytesPerSecond(1000); + const QuicBandwidth bottleneck_bandwidth = 0.1f * local_bandwidth; + const QuicTime::Delta local_propagation_delay = + QuicTime::Delta::FromMilliseconds(1); + const QuicTime::Delta bottleneck_propagation_delay = + QuicTime::Delta::FromMilliseconds(20); + + Simulator simulator; + LinkSaturator saturator(&simulator, "Saturator", 1000, "Counter"); + Queue queue(&simulator, "Queue", 1000); + CounterPort counter; + + OneWayLink local_link(&simulator, "Local link", &queue, local_bandwidth, + local_propagation_delay); + OneWayLink bottleneck_link(&simulator, "Bottleneck link", &counter, + bottleneck_bandwidth, + bottleneck_propagation_delay); + saturator.SetTxPort(&local_link); + queue.set_tx_port(&bottleneck_link); + + static const QuicPacketCount packets_received = 10; + // The deadline here is to prevent this tests from looping infinitely in case + // the packets never reach the receiver. + const QuicTime deadline = + simulator.GetClock()->Now() + QuicTime::Delta::FromSeconds(10); + simulator.RunUntil([&simulator, &counter, deadline]() { + return counter.packets() == packets_received || + simulator.GetClock()->Now() > deadline; + }); + ASSERT_EQ(packets_received, counter.packets()); +} + +// Simulate a network where three endpoints are connected to a switch and they +// are sending traffic in circle (1 -> 2, 2 -> 3, 3 -> 1). +TEST_F(SimulatorTest, SwitchedNetwork) { + const QuicBandwidth bandwidth = QuicBandwidth::FromBytesPerSecond(10000); + const QuicTime::Delta base_propagation_delay = + QuicTime::Delta::FromMilliseconds(50); + + Simulator simulator; + LinkSaturator saturator1(&simulator, "Saturator 1", 1000, "Saturator 2"); + LinkSaturator saturator2(&simulator, "Saturator 2", 1000, "Saturator 3"); + LinkSaturator saturator3(&simulator, "Saturator 3", 1000, "Saturator 1"); + Switch network_switch(&simulator, "Switch", 8, + bandwidth * base_propagation_delay * 10); + + // For determinicity, make it so that the first packet will arrive from + // Saturator 1, then from Saturator 2, and then from Saturator 3. + SymmetricLink link1(&saturator1, network_switch.port(1), bandwidth, + base_propagation_delay); + SymmetricLink link2(&saturator2, network_switch.port(2), bandwidth, + base_propagation_delay * 2); + SymmetricLink link3(&saturator3, network_switch.port(3), bandwidth, + base_propagation_delay * 3); + + const QuicTime start_time = simulator.GetClock()->Now(); + static const QuicPacketCount bytes_received = 64 * 1000; + simulator.RunUntil([&saturator1]() { + return saturator1.counter()->bytes() >= bytes_received; + }); + const QuicTime end_time = simulator.GetClock()->Now(); + + const QuicBandwidth observed_bandwidth = QuicBandwidth::FromBytesAndTimeDelta( + bytes_received, end_time - start_time); + const double bandwidth_ratio = + static_cast(observed_bandwidth.ToBitsPerSecond()) / + bandwidth.ToBitsPerSecond(); + EXPECT_NEAR(1, bandwidth_ratio, 0.1); + + const double normalized_received_packets_for_saturator_2 = + static_cast(saturator2.counter()->packets()) / + saturator1.counter()->packets(); + const double normalized_received_packets_for_saturator_3 = + static_cast(saturator3.counter()->packets()) / + saturator1.counter()->packets(); + EXPECT_NEAR(1, normalized_received_packets_for_saturator_2, 0.1); + EXPECT_NEAR(1, normalized_received_packets_for_saturator_3, 0.1); + + // Since Saturator 1 has its packet arrive first into the switch, switch will + // always know how to route traffic to it. + EXPECT_EQ(0u, + saturator2.counter()->CountPacketsForDestination("Saturator 1")); + EXPECT_EQ(0u, + saturator3.counter()->CountPacketsForDestination("Saturator 1")); + + // Packets from the other saturators will be broadcast at least once. + EXPECT_EQ(1u, + saturator1.counter()->CountPacketsForDestination("Saturator 2")); + EXPECT_EQ(1u, + saturator3.counter()->CountPacketsForDestination("Saturator 2")); + EXPECT_EQ(1u, + saturator1.counter()->CountPacketsForDestination("Saturator 3")); + EXPECT_EQ(1u, + saturator2.counter()->CountPacketsForDestination("Saturator 3")); +} + +// Toggle an alarm on and off at the specified interval. Assumes that alarm is +// initially set and unsets it almost immediately after the object is +// instantiated. +class AlarmToggler : public Actor { + public: + AlarmToggler(Simulator* simulator, std::string name, QuicAlarm* alarm, + QuicTime::Delta interval) + : Actor(simulator, name), + alarm_(alarm), + interval_(interval), + deadline_(alarm->deadline()), + times_set_(0), + times_cancelled_(0) { + EXPECT_TRUE(alarm->IsSet()); + EXPECT_GE(alarm->deadline(), clock_->Now()); + Schedule(clock_->Now()); + } + + void Act() override { + if (deadline_ <= clock_->Now()) { + return; + } + + if (alarm_->IsSet()) { + alarm_->Cancel(); + times_cancelled_++; + } else { + alarm_->Set(deadline_); + times_set_++; + } + + Schedule(clock_->Now() + interval_); + } + + inline int times_set() { return times_set_; } + inline int times_cancelled() { return times_cancelled_; } + + private: + QuicAlarm* alarm_; + QuicTime::Delta interval_; + QuicTime deadline_; + + // Counts the number of times the alarm was set. + int times_set_; + // Counts the number of times the alarm was cancelled. + int times_cancelled_; +}; + +// Counts the number of times an alarm has fired. +class CounterDelegate : public QuicAlarm::DelegateWithoutContext { + public: + explicit CounterDelegate(size_t* counter) : counter_(counter) {} + + void OnAlarm() override { *counter_ += 1; } + + private: + size_t* counter_; +}; + +// Verifies that the alarms work correctly, even when they are repeatedly +// toggled. +TEST_F(SimulatorTest, Alarms) { + Simulator simulator; + QuicAlarmFactory* alarm_factory = simulator.GetAlarmFactory(); + + size_t fast_alarm_counter = 0; + size_t slow_alarm_counter = 0; + std::unique_ptr alarm_fast( + alarm_factory->CreateAlarm(new CounterDelegate(&fast_alarm_counter))); + std::unique_ptr alarm_slow( + alarm_factory->CreateAlarm(new CounterDelegate(&slow_alarm_counter))); + + const QuicTime start_time = simulator.GetClock()->Now(); + alarm_fast->Set(start_time + QuicTime::Delta::FromMilliseconds(100)); + alarm_slow->Set(start_time + QuicTime::Delta::FromMilliseconds(750)); + AlarmToggler toggler(&simulator, "Toggler", alarm_slow.get(), + QuicTime::Delta::FromMilliseconds(100)); + + const QuicTime end_time = + start_time + QuicTime::Delta::FromMilliseconds(1000); + EXPECT_FALSE(simulator.RunUntil([&simulator, end_time]() { + return simulator.GetClock()->Now() >= end_time; + })); + EXPECT_EQ(1u, slow_alarm_counter); + EXPECT_EQ(1u, fast_alarm_counter); + + EXPECT_EQ(4, toggler.times_set()); + EXPECT_EQ(4, toggler.times_cancelled()); +} + +// Verifies that a cancelled alarm is never fired. +TEST_F(SimulatorTest, AlarmCancelling) { + Simulator simulator; + QuicAlarmFactory* alarm_factory = simulator.GetAlarmFactory(); + + size_t alarm_counter = 0; + std::unique_ptr alarm( + alarm_factory->CreateAlarm(new CounterDelegate(&alarm_counter))); + + const QuicTime start_time = simulator.GetClock()->Now(); + const QuicTime alarm_at = start_time + QuicTime::Delta::FromMilliseconds(300); + const QuicTime end_time = start_time + QuicTime::Delta::FromMilliseconds(400); + + alarm->Set(alarm_at); + alarm->Cancel(); + EXPECT_FALSE(alarm->IsSet()); + + EXPECT_FALSE(simulator.RunUntil([&simulator, end_time]() { + return simulator.GetClock()->Now() >= end_time; + })); + + EXPECT_FALSE(alarm->IsSet()); + EXPECT_EQ(0u, alarm_counter); +} + +// Verifies that alarms can be scheduled into the past. +TEST_F(SimulatorTest, AlarmInPast) { + Simulator simulator; + QuicAlarmFactory* alarm_factory = simulator.GetAlarmFactory(); + + size_t alarm_counter = 0; + std::unique_ptr alarm( + alarm_factory->CreateAlarm(new CounterDelegate(&alarm_counter))); + + const QuicTime start_time = simulator.GetClock()->Now(); + simulator.RunFor(QuicTime::Delta::FromMilliseconds(400)); + + alarm->Set(start_time); + simulator.RunFor(QuicTime::Delta::FromMilliseconds(1)); + EXPECT_FALSE(alarm->IsSet()); + EXPECT_EQ(1u, alarm_counter); +} + +// Tests Simulator::RunUntilOrTimeout() interface. +TEST_F(SimulatorTest, RunUntilOrTimeout) { + Simulator simulator; + bool simulation_result; + + // Count the number of seconds since the beginning of the simulation. + Counter counter(&simulator, "counter", QuicTime::Delta::FromSeconds(1)); + + // Ensure that the counter reaches the value of 10 given a 20 second deadline. + simulation_result = simulator.RunUntilOrTimeout( + [&counter]() { return counter.get_value() == 10; }, + QuicTime::Delta::FromSeconds(20)); + ASSERT_TRUE(simulation_result); + + // Ensure that the counter will not reach the value of 100 given that the + // starting value is 10 and the deadline is 20 seconds. + simulation_result = simulator.RunUntilOrTimeout( + [&counter]() { return counter.get_value() == 100; }, + QuicTime::Delta::FromSeconds(20)); + ASSERT_FALSE(simulation_result); +} + +// Tests Simulator::RunFor() interface. +TEST_F(SimulatorTest, RunFor) { + Simulator simulator; + + Counter counter(&simulator, "counter", QuicTime::Delta::FromSeconds(3)); + + simulator.RunFor(QuicTime::Delta::FromSeconds(100)); + + EXPECT_EQ(33, counter.get_value()); +} + +class MockPacketFilter : public PacketFilter { + public: + MockPacketFilter(Simulator* simulator, std::string name, Endpoint* endpoint) + : PacketFilter(simulator, name, endpoint) {} + MOCK_METHOD(bool, FilterPacket, (const Packet&), (override)); +}; + +// Set up two trivial packet filters, one allowing any packets, and one dropping +// all of them. +TEST_F(SimulatorTest, PacketFilter) { + const QuicBandwidth bandwidth = + QuicBandwidth::FromBytesPerSecond(1024 * 1024); + const QuicTime::Delta base_propagation_delay = + QuicTime::Delta::FromMilliseconds(5); + + Simulator simulator; + LinkSaturator saturator_a(&simulator, "Saturator A", 1000, "Saturator B"); + LinkSaturator saturator_b(&simulator, "Saturator B", 1000, "Saturator A"); + + // Attach packets to the switch to create a delay between the point at which + // the packet is generated and the point at which it is filtered. Note that + // if the saturators were connected directly, the link would be always + // available for the endpoint which has all of its packets dropped, resulting + // in saturator looping infinitely. + Switch network_switch(&simulator, "Switch", 8, + bandwidth * base_propagation_delay * 10); + StrictMock a_to_b_filter(&simulator, "A -> B filter", + network_switch.port(1)); + StrictMock b_to_a_filter(&simulator, "B -> A filter", + network_switch.port(2)); + SymmetricLink link_a(&a_to_b_filter, &saturator_b, bandwidth, + base_propagation_delay); + SymmetricLink link_b(&b_to_a_filter, &saturator_a, bandwidth, + base_propagation_delay); + + // Allow packets from A to B, but not from B to A. + EXPECT_CALL(a_to_b_filter, FilterPacket(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(b_to_a_filter, FilterPacket(_)).WillRepeatedly(Return(false)); + + // Run the simulation for a while, and expect that only B will receive any + // packets. + simulator.RunFor(QuicTime::Delta::FromSeconds(10)); + EXPECT_GE(saturator_b.counter()->packets(), 1u); + EXPECT_EQ(saturator_a.counter()->packets(), 0u); +} + +// Set up a traffic policer in one direction that throttles at 25% of link +// bandwidth, and put two link saturators at each endpoint. +TEST_F(SimulatorTest, TrafficPolicer) { + const QuicBandwidth bandwidth = + QuicBandwidth::FromBytesPerSecond(1024 * 1024); + const QuicTime::Delta base_propagation_delay = + QuicTime::Delta::FromMilliseconds(5); + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(10); + + Simulator simulator; + LinkSaturator saturator1(&simulator, "Saturator 1", 1000, "Saturator 2"); + LinkSaturator saturator2(&simulator, "Saturator 2", 1000, "Saturator 1"); + Switch network_switch(&simulator, "Switch", 8, + bandwidth * base_propagation_delay * 10); + + static const QuicByteCount initial_burst = 1000 * 10; + static const QuicByteCount max_bucket_size = 1000 * 100; + static const QuicBandwidth target_bandwidth = bandwidth * 0.25; + TrafficPolicer policer(&simulator, "Policer", initial_burst, max_bucket_size, + target_bandwidth, network_switch.port(2)); + + SymmetricLink link1(&saturator1, network_switch.port(1), bandwidth, + base_propagation_delay); + SymmetricLink link2(&saturator2, &policer, bandwidth, base_propagation_delay); + + // Ensure the initial burst passes without being dropped at all. + bool simulator_result = simulator.RunUntilOrTimeout( + [&saturator1]() { + return saturator1.bytes_transmitted() == initial_burst; + }, + timeout); + ASSERT_TRUE(simulator_result); + saturator1.Pause(); + simulator_result = simulator.RunUntilOrTimeout( + [&saturator2]() { + return saturator2.counter()->bytes() == initial_burst; + }, + timeout); + ASSERT_TRUE(simulator_result); + saturator1.Resume(); + + // Run for some time so that the initial burst is not visible. + const QuicTime::Delta simulation_time = QuicTime::Delta::FromSeconds(10); + simulator.RunFor(simulation_time); + + // Ensure we've transmitted the amount of data we expected. + for (auto* saturator : {&saturator1, &saturator2}) { + EXPECT_APPROX_EQ(bandwidth * simulation_time, + saturator->bytes_transmitted(), 0.01f); + } + + // Check that only one direction is throttled. + EXPECT_APPROX_EQ(saturator1.bytes_transmitted() / 4, + saturator2.counter()->bytes(), 0.1f); + EXPECT_APPROX_EQ(saturator2.bytes_transmitted(), + saturator1.counter()->bytes(), 0.1f); +} + +// Ensure that a larger burst is allowed when the policed saturator exits +// quiescence. +TEST_F(SimulatorTest, TrafficPolicerBurst) { + const QuicBandwidth bandwidth = + QuicBandwidth::FromBytesPerSecond(1024 * 1024); + const QuicTime::Delta base_propagation_delay = + QuicTime::Delta::FromMilliseconds(5); + const QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(10); + + Simulator simulator; + LinkSaturator saturator1(&simulator, "Saturator 1", 1000, "Saturator 2"); + LinkSaturator saturator2(&simulator, "Saturator 2", 1000, "Saturator 1"); + Switch network_switch(&simulator, "Switch", 8, + bandwidth * base_propagation_delay * 10); + + const QuicByteCount initial_burst = 1000 * 10; + const QuicByteCount max_bucket_size = 1000 * 100; + const QuicBandwidth target_bandwidth = bandwidth * 0.25; + TrafficPolicer policer(&simulator, "Policer", initial_burst, max_bucket_size, + target_bandwidth, network_switch.port(2)); + + SymmetricLink link1(&saturator1, network_switch.port(1), bandwidth, + base_propagation_delay); + SymmetricLink link2(&saturator2, &policer, bandwidth, base_propagation_delay); + + // Ensure at least one packet is sent on each side. + bool simulator_result = simulator.RunUntilOrTimeout( + [&saturator1, &saturator2]() { + return saturator1.packets_transmitted() > 0 && + saturator2.packets_transmitted() > 0; + }, + timeout); + ASSERT_TRUE(simulator_result); + + // Wait until the bucket fills up. + saturator1.Pause(); + saturator2.Pause(); + simulator.RunFor(1.5f * target_bandwidth.TransferTime(max_bucket_size)); + + // Send a burst. + saturator1.Resume(); + simulator.RunFor(bandwidth.TransferTime(max_bucket_size)); + saturator1.Pause(); + simulator.RunFor(2 * base_propagation_delay); + + // Expect the burst to pass without losses. + EXPECT_APPROX_EQ(saturator1.bytes_transmitted(), + saturator2.counter()->bytes(), 0.1f); + + // Expect subsequent traffic to be policed. + saturator1.Resume(); + simulator.RunFor(QuicTime::Delta::FromSeconds(10)); + EXPECT_APPROX_EQ(saturator1.bytes_transmitted() / 4, + saturator2.counter()->bytes(), 0.1f); +} + +// Test that the packet aggregation support in queues work. +TEST_F(SimulatorTest, PacketAggregation) { + // Model network where the delays are dominated by transfer delay. + const QuicBandwidth bandwidth = QuicBandwidth::FromBytesPerSecond(1000); + const QuicTime::Delta base_propagation_delay = + QuicTime::Delta::FromMicroseconds(1); + const QuicByteCount aggregation_threshold = 1000; + const QuicTime::Delta aggregation_timeout = QuicTime::Delta::FromSeconds(30); + + Simulator simulator; + LinkSaturator saturator1(&simulator, "Saturator 1", 10, "Saturator 2"); + LinkSaturator saturator2(&simulator, "Saturator 2", 10, "Saturator 1"); + Switch network_switch(&simulator, "Switch", 8, 10 * aggregation_threshold); + + // Make links with asymmetric propagation delay so that Saturator 2 only + // receives packets addressed to it. + SymmetricLink link1(&saturator1, network_switch.port(1), bandwidth, + base_propagation_delay); + SymmetricLink link2(&saturator2, network_switch.port(2), bandwidth, + 2 * base_propagation_delay); + + // Enable aggregation in 1 -> 2 direction. + Queue* queue = network_switch.port_queue(2); + queue->EnableAggregation(aggregation_threshold, aggregation_timeout); + + // Enable aggregation in 2 -> 1 direction in a way that all packets are larger + // than the threshold, so that aggregation is effectively a no-op. + network_switch.port_queue(1)->EnableAggregation(5, aggregation_timeout); + + // Fill up the aggregation buffer up to 90% (900 bytes). + simulator.RunFor(0.9 * bandwidth.TransferTime(aggregation_threshold)); + EXPECT_EQ(0u, saturator2.counter()->bytes()); + + // Stop sending, ensure that given a timespan much shorter than timeout, the + // packets remain in the queue. + saturator1.Pause(); + saturator2.Pause(); + simulator.RunFor(QuicTime::Delta::FromSeconds(10)); + EXPECT_EQ(0u, saturator2.counter()->bytes()); + EXPECT_EQ(900u, queue->bytes_queued()); + + // Ensure that all packets have reached the saturator not affected by + // aggregation. Here, 10 extra bytes account for a misrouted packet in the + // beginning. + EXPECT_EQ(910u, saturator1.counter()->bytes()); + + // Send 500 more bytes. Since the aggregation threshold is 1000 bytes, and + // queue already has 900 bytes, 1000 bytes will be send and 400 will be in the + // queue. + saturator1.Resume(); + simulator.RunFor(0.5 * bandwidth.TransferTime(aggregation_threshold)); + saturator1.Pause(); + simulator.RunFor(QuicTime::Delta::FromSeconds(10)); + EXPECT_EQ(1000u, saturator2.counter()->bytes()); + EXPECT_EQ(400u, queue->bytes_queued()); + + // Actually time out, and cause all of the data to be received. + simulator.RunFor(aggregation_timeout); + EXPECT_EQ(1400u, saturator2.counter()->bytes()); + EXPECT_EQ(0u, queue->bytes_queued()); + + // Run saturator for a longer time, to ensure that the logic to cancel and + // reset alarms works correctly. + saturator1.Resume(); + simulator.RunFor(5.5 * bandwidth.TransferTime(aggregation_threshold)); + saturator1.Pause(); + simulator.RunFor(QuicTime::Delta::FromSeconds(10)); + EXPECT_EQ(6400u, saturator2.counter()->bytes()); + EXPECT_EQ(500u, queue->bytes_queued()); + + // Time out again. + simulator.RunFor(aggregation_timeout); + EXPECT_EQ(6900u, saturator2.counter()->bytes()); + EXPECT_EQ(0u, queue->bytes_queued()); +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/switch.cc b/quiche/quic/test_tools/simulator/switch.cc new file mode 100644 index 000000000000..fbd396e241a3 --- /dev/null +++ b/quiche/quic/test_tools/simulator/switch.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/switch.h" + +#include +#include + +#include "absl/strings/str_cat.h" + +namespace quic { +namespace simulator { + +Switch::Switch(Simulator* simulator, std::string name, + SwitchPortNumber port_count, QuicByteCount queue_capacity) { + for (size_t port_number = 1; port_number <= port_count; port_number++) { + ports_.emplace_back(simulator, + absl::StrCat(name, " (port ", port_number, ")"), this, + port_number, queue_capacity); + } +} + +Switch::~Switch() {} + +Switch::Port::Port(Simulator* simulator, std::string name, Switch* parent, + SwitchPortNumber port_number, QuicByteCount queue_capacity) + : Endpoint(simulator, name), + parent_(parent), + port_number_(port_number), + connected_(false), + queue_(simulator, absl::StrCat(name, " (queue)"), queue_capacity) {} + +void Switch::Port::AcceptPacket(std::unique_ptr packet) { + parent_->DispatchPacket(port_number_, std::move(packet)); +} + +void Switch::Port::EnqueuePacket(std::unique_ptr packet) { + queue_.AcceptPacket(std::move(packet)); +} + +UnconstrainedPortInterface* Switch::Port::GetRxPort() { return this; } + +void Switch::Port::SetTxPort(ConstrainedPortInterface* port) { + queue_.set_tx_port(port); + connected_ = true; +} + +void Switch::Port::Act() {} + +void Switch::DispatchPacket(SwitchPortNumber port_number, + std::unique_ptr packet) { + Port* source_port = &ports_[port_number - 1]; + const auto source_mapping_it = switching_table_.find(packet->source); + if (source_mapping_it == switching_table_.end()) { + switching_table_.insert(std::make_pair(packet->source, source_port)); + } + + const auto destination_mapping_it = + switching_table_.find(packet->destination); + if (destination_mapping_it != switching_table_.end()) { + destination_mapping_it->second->EnqueuePacket(std::move(packet)); + return; + } + + // If no mapping is available yet, broadcast the packet to all ports + // different from the source. + for (Port& egress_port : ports_) { + if (!egress_port.connected()) { + continue; + } + egress_port.EnqueuePacket(std::make_unique(*packet)); + } +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/switch.h b/quiche/quic/test_tools/simulator/switch.h new file mode 100644 index 000000000000..1bafed41493b --- /dev/null +++ b/quiche/quic/test_tools/simulator/switch.h @@ -0,0 +1,84 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_SWITCH_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_SWITCH_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/test_tools/simulator/queue.h" + +namespace quic { +namespace simulator { + +using SwitchPortNumber = size_t; + +// Simulates a network switch with simple persistent learning scheme and queues +// on every output port. +class Switch { + public: + Switch(Simulator* simulator, std::string name, SwitchPortNumber port_count, + QuicByteCount queue_capacity); + Switch(const Switch&) = delete; + Switch& operator=(const Switch&) = delete; + ~Switch(); + + // Returns Endpoint associated with the port under number |port_number|. Just + // like on most real switches, port numbering starts with 1. + Endpoint* port(SwitchPortNumber port_number) { + QUICHE_DCHECK_NE(port_number, 0u); + return &ports_[port_number - 1]; + } + + Queue* port_queue(SwitchPortNumber port_number) { + return ports_[port_number - 1].queue(); + } + + private: + class Port : public Endpoint, public UnconstrainedPortInterface { + public: + Port(Simulator* simulator, std::string name, Switch* parent, + SwitchPortNumber port_number, QuicByteCount queue_capacity); + Port(Port&&) = delete; + Port(const Port&) = delete; + Port& operator=(const Port&) = delete; + ~Port() override {} + + // Accepts packet to be routed into the switch. + void AcceptPacket(std::unique_ptr packet) override; + // Enqueue packet to be routed out of the switch. + void EnqueuePacket(std::unique_ptr packet); + + UnconstrainedPortInterface* GetRxPort() override; + void SetTxPort(ConstrainedPortInterface* port) override; + + void Act() override; + + bool connected() const { return connected_; } + Queue* queue() { return &queue_; } + + private: + Switch* parent_; + SwitchPortNumber port_number_; + bool connected_; + + Queue queue_; + }; + + // Sends the packet to the appropriate port, or to all ports if the + // appropriate port is not known. + void DispatchPacket(SwitchPortNumber port_number, + std::unique_ptr packet); + + // This cannot be a quiche::QuicheCircularDeque since pointers into this are + // assumed to be stable. + std::deque ports_; + absl::flat_hash_map switching_table_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_SWITCH_H_ diff --git a/quiche/quic/test_tools/simulator/test_harness.cc b/quiche/quic/test_tools/simulator/test_harness.cc new file mode 100644 index 000000000000..1dfc8a2470ea --- /dev/null +++ b/quiche/quic/test_tools/simulator/test_harness.cc @@ -0,0 +1,35 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/test_harness.h" + +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint_base.h" + +namespace quic::simulator { + +QuicEndpointWithConnection::QuicEndpointWithConnection( + Simulator* simulator, const std::string& name, const std::string& peer_name, + Perspective perspective, const ParsedQuicVersionVector& supported_versions) + : QuicEndpointBase(simulator, name, peer_name) { + connection_ = std::make_unique( + quic::test::TestConnectionId(0x10), GetAddressFromName(name), + GetAddressFromName(peer_name), simulator, simulator->GetAlarmFactory(), + &writer_, /*owns_writer=*/false, perspective, supported_versions, + connection_id_generator_); + connection_->SetSelfAddress(GetAddressFromName(name)); +} + +TestHarness::TestHarness() : switch_(&simulator_, "Switch", 8, 2 * kBdp) {} + +void TestHarness::WireUpEndpoints() { + client_link_.emplace(client_, switch_.port(1), kClientBandwidth, + kClientPropagationDelay); + server_link_.emplace(server_, switch_.port(2), kServerBandwidth, + kServerPropagationDelay); +} + +} // namespace quic::simulator diff --git a/quiche/quic/test_tools/simulator/test_harness.h b/quiche/quic/test_tools/simulator/test_harness.h new file mode 100644 index 000000000000..681cfa3fa4e0 --- /dev/null +++ b/quiche/quic/test_tools/simulator/test_harness.h @@ -0,0 +1,83 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_TEST_HARNESS_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_TEST_HARNESS_H_ + +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/test_tools/simulator/link.h" +#include "quiche/quic/test_tools/simulator/port.h" +#include "quiche/quic/test_tools/simulator/quic_endpoint_base.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/test_tools/simulator/switch.h" + +namespace quic::simulator { + +// A subclass of QuicEndpointBase that creates the connection object for the +// caller. Uses a fixed connection ID (0x10) and IP addresses derived from the +// names supplied. +class QuicEndpointWithConnection : public QuicEndpointBase { + public: + QuicEndpointWithConnection(Simulator* simulator, const std::string& name, + const std::string& peer_name, + Perspective perspective, + const ParsedQuicVersionVector& supported_versions); +}; + +// A test harness that provides a reasonable preset for running unit tests. +class TestHarness { + public: + // The configuration of the test harness. + static constexpr QuicBandwidth kClientBandwidth = + QuicBandwidth::FromKBitsPerSecond(10000); + static constexpr QuicTime::Delta kClientPropagationDelay = + QuicTime::Delta::FromMilliseconds(2); + static constexpr QuicBandwidth kServerBandwidth = + QuicBandwidth::FromKBitsPerSecond(4000); + static constexpr QuicTime::Delta kServerPropagationDelay = + QuicTime::Delta::FromMilliseconds(50); + static constexpr QuicTime::Delta kTransferTime = + kClientBandwidth.TransferTime(kMaxOutgoingPacketSize) + + kServerBandwidth.TransferTime(kMaxOutgoingPacketSize); + static constexpr QuicTime::Delta kRtt = + (kClientPropagationDelay + kServerPropagationDelay + kTransferTime) * 2; + static constexpr QuicByteCount kBdp = kRtt * kServerBandwidth; + + static constexpr QuicTime::Delta kDefaultTimeout = + QuicTime::Delta::FromSeconds(3); + + TestHarness(); + + Simulator& simulator() { return simulator_; } + void set_client(Endpoint* client) { client_ = client; } + void set_server(Endpoint* server) { server_ = server; } + + // Connects |client_| and |server_| to a virtual switch; must be called after + // set_client/set_server are called. + void WireUpEndpoints(); + + // A convenience wrapper around Simulator::RunUntilOrTimeout(). + template + bool RunUntilWithDefaultTimeout(TerminationPredicate termination_predicate) { + return simulator_.RunUntilOrTimeout(std::move(termination_predicate), + kDefaultTimeout); + } + + private: + Simulator simulator_; + Switch switch_; + absl::optional client_link_; + absl::optional server_link_; + + Endpoint* client_; + Endpoint* server_; +}; + +} // namespace quic::simulator + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_TEST_HARNESS_H_ diff --git a/quiche/quic/test_tools/simulator/traffic_policer.cc b/quiche/quic/test_tools/simulator/traffic_policer.cc new file mode 100644 index 000000000000..d701a16e360d --- /dev/null +++ b/quiche/quic/test_tools/simulator/traffic_policer.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/simulator/traffic_policer.h" + +#include + +namespace quic { +namespace simulator { + +TrafficPolicer::TrafficPolicer(Simulator* simulator, std::string name, + QuicByteCount initial_bucket_size, + QuicByteCount max_bucket_size, + QuicBandwidth target_bandwidth, Endpoint* input) + : PacketFilter(simulator, name, input), + initial_bucket_size_(initial_bucket_size), + max_bucket_size_(max_bucket_size), + target_bandwidth_(target_bandwidth), + last_refill_time_(clock_->Now()) {} + +TrafficPolicer::~TrafficPolicer() {} + +void TrafficPolicer::Refill() { + QuicTime::Delta time_passed = clock_->Now() - last_refill_time_; + QuicByteCount refill_size = time_passed * target_bandwidth_; + + for (auto& bucket : token_buckets_) { + bucket.second = std::min(bucket.second + refill_size, max_bucket_size_); + } + + last_refill_time_ = clock_->Now(); +} + +bool TrafficPolicer::FilterPacket(const Packet& packet) { + // Refill existing buckets. + Refill(); + + // Create a new bucket if one does not exist. + if (token_buckets_.count(packet.destination) == 0) { + token_buckets_.insert( + std::make_pair(packet.destination, initial_bucket_size_)); + } + + auto bucket = token_buckets_.find(packet.destination); + QUICHE_DCHECK(bucket != token_buckets_.end()); + + // Silently drop the packet on the floor if out of tokens + if (bucket->second < packet.size) { + return false; + } + + bucket->second -= packet.size; + return true; +} + +} // namespace simulator +} // namespace quic diff --git a/quiche/quic/test_tools/simulator/traffic_policer.h b/quiche/quic/test_tools/simulator/traffic_policer.h new file mode 100644 index 000000000000..710ad739e088 --- /dev/null +++ b/quiche/quic/test_tools/simulator/traffic_policer.h @@ -0,0 +1,52 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_SIMULATOR_TRAFFIC_POLICER_H_ +#define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_TRAFFIC_POLICER_H_ + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/test_tools/simulator/packet_filter.h" +#include "quiche/quic/test_tools/simulator/port.h" + +namespace quic { +namespace simulator { + +// Traffic policer uses a token bucket to limit the bandwidth of the traffic +// passing through. It wraps around an input port and exposes an output port. +// Only the traffic from input to the output is policed, so in case when +// bidirectional policing is desired, two policers have to be used. The flows +// are hashed by the destination only. +class TrafficPolicer : public PacketFilter { + public: + TrafficPolicer(Simulator* simulator, std::string name, + QuicByteCount initial_bucket_size, + QuicByteCount max_bucket_size, QuicBandwidth target_bandwidth, + Endpoint* input); + TrafficPolicer(const TrafficPolicer&) = delete; + TrafficPolicer& operator=(const TrafficPolicer&) = delete; + ~TrafficPolicer() override; + + protected: + bool FilterPacket(const Packet& packet) override; + + private: + // Refill the token buckets with all the tokens that have been granted since + // |last_refill_time_|. + void Refill(); + + QuicByteCount initial_bucket_size_; + QuicByteCount max_bucket_size_; + QuicBandwidth target_bandwidth_; + + // The time at which the token buckets were last refilled. + QuicTime last_refill_time_; + + // Maps each destination to the number of tokens it has left. + absl::flat_hash_map token_buckets_; +}; + +} // namespace simulator +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_SIMULATOR_TRAFFIC_POLICER_H_ diff --git a/quiche/quic/test_tools/test_certificates.cc b/quiche/quic/test_tools/test_certificates.cc new file mode 100644 index 000000000000..6d450a81cf3d --- /dev/null +++ b/quiche/quic/test_tools/test_certificates.cc @@ -0,0 +1,731 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/test_certificates.h" + +namespace quic { +namespace test { + +// A test certificate generated by //net/tools/quic/certs/generate-certs.sh. +ABSL_CONST_INIT const char kTestCertificateRaw[] = { + '\x30', '\x82', '\x03', '\xb4', '\x30', '\x82', '\x02', '\x9c', '\xa0', + '\x03', '\x02', '\x01', '\x02', '\x02', '\x01', '\x01', '\x30', '\x0d', + '\x06', '\x09', '\x2a', '\x86', '\x48', '\x86', '\xf7', '\x0d', '\x01', + '\x01', '\x0b', '\x05', '\x00', '\x30', '\x1e', '\x31', '\x1c', '\x30', + '\x1a', '\x06', '\x03', '\x55', '\x04', '\x03', '\x0c', '\x13', '\x51', + '\x55', '\x49', '\x43', '\x20', '\x53', '\x65', '\x72', '\x76', '\x65', + '\x72', '\x20', '\x52', '\x6f', '\x6f', '\x74', '\x20', '\x43', '\x41', + '\x30', '\x1e', '\x17', '\x0d', '\x32', '\x30', '\x30', '\x31', '\x33', + '\x30', '\x31', '\x38', '\x31', '\x33', '\x35', '\x39', '\x5a', '\x17', + '\x0d', '\x32', '\x30', '\x30', '\x32', '\x30', '\x32', '\x31', '\x38', + '\x31', '\x33', '\x35', '\x39', '\x5a', '\x30', '\x64', '\x31', '\x0b', + '\x30', '\x09', '\x06', '\x03', '\x55', '\x04', '\x06', '\x13', '\x02', + '\x55', '\x53', '\x31', '\x13', '\x30', '\x11', '\x06', '\x03', '\x55', + '\x04', '\x08', '\x0c', '\x0a', '\x43', '\x61', '\x6c', '\x69', '\x66', + '\x6f', '\x72', '\x6e', '\x69', '\x61', '\x31', '\x16', '\x30', '\x14', + '\x06', '\x03', '\x55', '\x04', '\x07', '\x0c', '\x0d', '\x4d', '\x6f', + '\x75', '\x6e', '\x74', '\x61', '\x69', '\x6e', '\x20', '\x56', '\x69', + '\x65', '\x77', '\x31', '\x14', '\x30', '\x12', '\x06', '\x03', '\x55', + '\x04', '\x0a', '\x0c', '\x0b', '\x51', '\x55', '\x49', '\x43', '\x20', + '\x53', '\x65', '\x72', '\x76', '\x65', '\x72', '\x31', '\x12', '\x30', + '\x10', '\x06', '\x03', '\x55', '\x04', '\x03', '\x0c', '\x09', '\x31', + '\x32', '\x37', '\x2e', '\x30', '\x2e', '\x30', '\x2e', '\x31', '\x30', + '\x82', '\x01', '\x22', '\x30', '\x0d', '\x06', '\x09', '\x2a', '\x86', + '\x48', '\x86', '\xf7', '\x0d', '\x01', '\x01', '\x01', '\x05', '\x00', + '\x03', '\x82', '\x01', '\x0f', '\x00', '\x30', '\x82', '\x01', '\x0a', + '\x02', '\x82', '\x01', '\x01', '\x00', '\xc5', '\xe2', '\x51', '\x6d', + '\x3f', '\xd6', '\x28', '\xf2', '\xad', '\x34', '\x73', '\x87', '\x64', + '\xca', '\x33', '\x19', '\x33', '\xb7', '\x75', '\x91', '\xab', '\x31', + '\x19', '\x2b', '\xe3', '\xa4', '\x26', '\x09', '\x29', '\x8b', '\x2d', + '\xf7', '\x52', '\x75', '\xa7', '\x55', '\x15', '\xf0', '\x11', '\xc7', + '\xc2', '\xc4', '\xed', '\x18', '\x1b', '\x33', '\x0b', '\x71', '\x32', + '\xe6', '\x35', '\x89', '\xcd', '\x2d', '\x5a', '\x05', '\x57', '\x4e', + '\xc2', '\x78', '\x75', '\x65', '\x72', '\x2d', '\x8a', '\x17', '\x83', + '\xd6', '\x32', '\x90', '\x85', '\xf8', '\x22', '\xe2', '\x65', '\xa9', + '\xe0', '\xa0', '\xfe', '\x19', '\xb2', '\x39', '\x2d', '\x14', '\x03', + '\x10', '\x2f', '\xcc', '\x8b', '\x5e', '\xaa', '\x25', '\x27', '\x0d', + '\xa3', '\x37', '\x10', '\x0c', '\x17', '\xec', '\xf0', '\x8b', '\xc5', + '\x6b', '\xed', '\x6b', '\x5e', '\xb2', '\xe2', '\x35', '\x3e', '\x46', + '\x3b', '\xf7', '\xf6', '\x59', '\xb1', '\xe0', '\x16', '\xa6', '\xfb', + '\x03', '\xbf', '\x84', '\x4f', '\xce', '\x64', '\x15', '\x0d', '\x59', + '\x99', '\xa6', '\xf0', '\x7f', '\x8a', '\x33', '\x4b', '\xbb', '\x0b', + '\xb8', '\xf2', '\xd1', '\x27', '\x90', '\x8f', '\x38', '\xf8', '\x5a', + '\x41', '\x82', '\x07', '\x9b', '\x0d', '\xd9', '\x52', '\xe0', '\x70', + '\xff', '\xde', '\xda', '\xd8', '\x25', '\x4e', '\x2f', '\x2d', '\x9f', + '\xaf', '\x92', '\x63', '\xc7', '\x42', '\xb4', '\xdc', '\x16', '\x95', + '\x23', '\x05', '\x02', '\x6b', '\xb0', '\xe8', '\xc5', '\xfe', '\x15', + '\x9a', '\xe8', '\x7d', '\x2f', '\xdc', '\x43', '\xf4', '\x70', '\x91', + '\x1a', '\x93', '\xbe', '\x71', '\xaf', '\x85', '\x84', '\xdb', '\xcf', + '\x6b', '\x5c', '\x80', '\xb2', '\xd3', '\xf3', '\x42', '\x6e', '\x24', + '\xec', '\x2a', '\x62', '\x99', '\xc6', '\x3c', '\xe5', '\x32', '\xe5', + '\x72', '\x37', '\x30', '\x9b', '\x0b', '\xe4', '\x06', '\xb4', '\x64', + '\x26', '\x95', '\x59', '\xba', '\xf1', '\x53', '\x83', '\x3d', '\x99', + '\x6d', '\xf0', '\x80', '\xe2', '\xdb', '\x6b', '\x34', '\x52', '\x06', + '\x77', '\x3c', '\x73', '\xbe', '\xc6', '\xe3', '\xce', '\xb2', '\x11', + '\x02', '\x03', '\x01', '\x00', '\x01', '\xa3', '\x81', '\xb6', '\x30', + '\x81', '\xb3', '\x30', '\x0c', '\x06', '\x03', '\x55', '\x1d', '\x13', + '\x01', '\x01', '\xff', '\x04', '\x02', '\x30', '\x00', '\x30', '\x1d', + '\x06', '\x03', '\x55', '\x1d', '\x0e', '\x04', '\x16', '\x04', '\x14', + '\xc8', '\x54', '\x28', '\xf6', '\xd2', '\xd5', '\x12', '\x35', '\x89', + '\x15', '\x75', '\xb8', '\xbf', '\xdd', '\xfb', '\x4a', '\xfc', '\x6c', + '\x89', '\xde', '\x30', '\x1f', '\x06', '\x03', '\x55', '\x1d', '\x23', + '\x04', '\x18', '\x30', '\x16', '\x80', '\x14', '\x50', '\xe4', '\x1d', + '\xc3', '\x1a', '\xfb', '\xfd', '\x38', '\xdd', '\xa2', '\x05', '\xfd', + '\xc8', '\xfa', '\x57', '\x0a', '\xc1', '\x06', '\x0f', '\xae', '\x30', + '\x1d', '\x06', '\x03', '\x55', '\x1d', '\x25', '\x04', '\x16', '\x30', + '\x14', '\x06', '\x08', '\x2b', '\x06', '\x01', '\x05', '\x05', '\x07', + '\x03', '\x01', '\x06', '\x08', '\x2b', '\x06', '\x01', '\x05', '\x05', + '\x07', '\x03', '\x02', '\x30', '\x44', '\x06', '\x03', '\x55', '\x1d', + '\x11', '\x04', '\x3d', '\x30', '\x3b', '\x82', '\x0f', '\x77', '\x77', + '\x77', '\x2e', '\x65', '\x78', '\x61', '\x6d', '\x70', '\x6c', '\x65', + '\x2e', '\x6f', '\x72', '\x67', '\x82', '\x10', '\x6d', '\x61', '\x69', + '\x6c', '\x2e', '\x65', '\x78', '\x61', '\x6d', '\x70', '\x6c', '\x65', + '\x2e', '\x6f', '\x72', '\x67', '\x82', '\x10', '\x6d', '\x61', '\x69', + '\x6c', '\x2e', '\x65', '\x78', '\x61', '\x6d', '\x70', '\x6c', '\x65', + '\x2e', '\x63', '\x6f', '\x6d', '\x87', '\x04', '\x7f', '\x00', '\x00', + '\x01', '\x30', '\x0d', '\x06', '\x09', '\x2a', '\x86', '\x48', '\x86', + '\xf7', '\x0d', '\x01', '\x01', '\x0b', '\x05', '\x00', '\x03', '\x82', + '\x01', '\x01', '\x00', '\x45', '\x41', '\x7a', '\x68', '\xe0', '\xa7', + '\x59', '\xa1', '\x62', '\x54', '\x73', '\x74', '\x14', '\x4f', '\xde', + '\x9c', '\x51', '\xac', '\x25', '\x97', '\x70', '\xf7', '\x09', '\x51', + '\x39', '\x72', '\x39', '\x3c', '\xd0', '\x31', '\xe1', '\xc3', '\x02', + '\x91', '\x14', '\x4d', '\x8f', '\x1d', '\x31', '\xab', '\x98', '\x7e', + '\xe6', '\xbb', '\xab', '\x6a', '\xd9', '\xc5', '\x86', '\xaa', '\x4e', + '\x6a', '\x48', '\xe9', '\xf8', '\xd7', '\xb3', '\x1d', '\xa0', '\xc5', + '\xe6', '\xbf', '\x4c', '\x5a', '\x9b', '\xb5', '\x78', '\x01', '\xa3', + '\x39', '\x7b', '\x5f', '\xbc', '\xb8', '\xa7', '\xc2', '\x71', '\xb0', + '\x7b', '\xdd', '\xa1', '\x87', '\xa6', '\x54', '\x9c', '\xf6', '\x59', + '\x81', '\xb1', '\x2c', '\xde', '\xc5', '\x8a', '\xa2', '\x06', '\x89', + '\xb5', '\xc1', '\x7a', '\xbe', '\x0c', '\x9f', '\x3d', '\xde', '\x81', + '\x48', '\x53', '\x71', '\x7b', '\x8d', '\xc7', '\xea', '\x87', '\xd7', + '\xd1', '\xda', '\x94', '\xb4', '\xc5', '\xac', '\x1e', '\x83', '\xa3', + '\x42', '\x7d', '\xe6', '\xab', '\x3f', '\xd6', '\x1c', '\xd6', '\x65', + '\xc3', '\x60', '\xe9', '\x76', '\x54', '\x79', '\x3f', '\xeb', '\x65', + '\x85', '\x4f', '\x60', '\x7d', '\xbb', '\x96', '\x03', '\x54', '\x2e', + '\xd0', '\x1b', '\xe2', '\x6c', '\x2d', '\x91', '\xae', '\x33', '\x9c', + '\x04', '\xc4', '\x44', '\x0a', '\x7d', '\x5f', '\xbb', '\x80', '\xa2', + '\x01', '\xbc', '\x90', '\x81', '\xa5', '\xdc', '\x4a', '\xc8', '\x77', + '\xc9', '\x8d', '\x34', '\x17', '\xe6', '\x2a', '\x7d', '\x02', '\x1e', + '\x32', '\x3f', '\x7d', '\xd7', '\x0c', '\x80', '\x5b', '\xc6', '\x94', + '\x6a', '\x42', '\x36', '\x05', '\x9f', '\x9e', '\xc5', '\x85', '\x9f', + '\x60', '\xe3', '\x72', '\x73', '\x34', '\x39', '\x44', '\x75', '\x55', + '\x60', '\x24', '\x7a', '\x8b', '\x09', '\x74', '\x84', '\x72', '\xfd', + '\x91', '\x68', '\x93', '\x57', '\x9e', '\x70', '\x46', '\x4d', '\xe4', + '\x30', '\x84', '\x5f', '\x20', '\x07', '\xad', '\xfd', '\x86', '\x32', + '\xd3', '\xfb', '\xba', '\xaf', '\xd9', '\x61', '\x14', '\x3c', '\xe0', + '\xa1', '\xa9', '\x51', '\x51', '\x0f', '\xad', '\x60'}; + +ABSL_CONST_INIT const absl::string_view kTestCertificate( + kTestCertificateRaw, sizeof(kTestCertificateRaw)); + +ABSL_CONST_INIT const char kTestCertificatePem[] = + R"(Certificate: + Data: + Version: 3 (0x2) + Serial Number: 1 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN=QUIC Server Root CA + Validity + Not Before: Jan 30 18:13:59 2020 GMT + Not After : Feb 2 18:13:59 2020 GMT + Subject: C=US, ST=California, L=Mountain View, O=QUIC Server, CN=127.0.0.1 + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:c5:e2:51:6d:3f:d6:28:f2:ad:34:73:87:64:ca: + 33:19:33:b7:75:91:ab:31:19:2b:e3:a4:26:09:29: + 8b:2d:f7:52:75:a7:55:15:f0:11:c7:c2:c4:ed:18: + 1b:33:0b:71:32:e6:35:89:cd:2d:5a:05:57:4e:c2: + 78:75:65:72:2d:8a:17:83:d6:32:90:85:f8:22:e2: + 65:a9:e0:a0:fe:19:b2:39:2d:14:03:10:2f:cc:8b: + 5e:aa:25:27:0d:a3:37:10:0c:17:ec:f0:8b:c5:6b: + ed:6b:5e:b2:e2:35:3e:46:3b:f7:f6:59:b1:e0:16: + a6:fb:03:bf:84:4f:ce:64:15:0d:59:99:a6:f0:7f: + 8a:33:4b:bb:0b:b8:f2:d1:27:90:8f:38:f8:5a:41: + 82:07:9b:0d:d9:52:e0:70:ff:de:da:d8:25:4e:2f: + 2d:9f:af:92:63:c7:42:b4:dc:16:95:23:05:02:6b: + b0:e8:c5:fe:15:9a:e8:7d:2f:dc:43:f4:70:91:1a: + 93:be:71:af:85:84:db:cf:6b:5c:80:b2:d3:f3:42: + 6e:24:ec:2a:62:99:c6:3c:e5:32:e5:72:37:30:9b: + 0b:e4:06:b4:64:26:95:59:ba:f1:53:83:3d:99:6d: + f0:80:e2:db:6b:34:52:06:77:3c:73:be:c6:e3:ce: + b2:11 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: critical + CA:FALSE + X509v3 Subject Key Identifier: + C8:54:28:F6:D2:D5:12:35:89:15:75:B8:BF:DD:FB:4A:FC:6C:89:DE + X509v3 Authority Key Identifier: + keyid:50:E4:1D:C3:1A:FB:FD:38:DD:A2:05:FD:C8:FA:57:0A:C1:06:0F:AE + + X509v3 Extended Key Usage: + TLS Web Server Authentication, TLS Web Client Authentication + X509v3 Subject Alternative Name: + DNS:www.example.org, DNS:mail.example.org, DNS:mail.example.com, IP Address:127.0.0.1 + Signature Algorithm: sha256WithRSAEncryption + 45:41:7a:68:e0:a7:59:a1:62:54:73:74:14:4f:de:9c:51:ac: + 25:97:70:f7:09:51:39:72:39:3c:d0:31:e1:c3:02:91:14:4d: + 8f:1d:31:ab:98:7e:e6:bb:ab:6a:d9:c5:86:aa:4e:6a:48:e9: + f8:d7:b3:1d:a0:c5:e6:bf:4c:5a:9b:b5:78:01:a3:39:7b:5f: + bc:b8:a7:c2:71:b0:7b:dd:a1:87:a6:54:9c:f6:59:81:b1:2c: + de:c5:8a:a2:06:89:b5:c1:7a:be:0c:9f:3d:de:81:48:53:71: + 7b:8d:c7:ea:87:d7:d1:da:94:b4:c5:ac:1e:83:a3:42:7d:e6: + ab:3f:d6:1c:d6:65:c3:60:e9:76:54:79:3f:eb:65:85:4f:60: + 7d:bb:96:03:54:2e:d0:1b:e2:6c:2d:91:ae:33:9c:04:c4:44: + 0a:7d:5f:bb:80:a2:01:bc:90:81:a5:dc:4a:c8:77:c9:8d:34: + 17:e6:2a:7d:02:1e:32:3f:7d:d7:0c:80:5b:c6:94:6a:42:36: + 05:9f:9e:c5:85:9f:60:e3:72:73:34:39:44:75:55:60:24:7a: + 8b:09:74:84:72:fd:91:68:93:57:9e:70:46:4d:e4:30:84:5f: + 20:07:ad:fd:86:32:d3:fb:ba:af:d9:61:14:3c:e0:a1:a9:51: + 51:0f:ad:60 +-----BEGIN CERTIFICATE----- +MIIDtDCCApygAwIBAgIBATANBgkqhkiG9w0BAQsFADAeMRwwGgYDVQQDDBNRVUlD +IFNlcnZlciBSb290IENBMB4XDTIwMDEzMDE4MTM1OVoXDTIwMDIwMjE4MTM1OVow +ZDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDU1v +dW50YWluIFZpZXcxFDASBgNVBAoMC1FVSUMgU2VydmVyMRIwEAYDVQQDDAkxMjcu +MC4wLjEwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDF4lFtP9Yo8q00 +c4dkyjMZM7d1kasxGSvjpCYJKYst91J1p1UV8BHHwsTtGBszC3Ey5jWJzS1aBVdO +wnh1ZXItiheD1jKQhfgi4mWp4KD+GbI5LRQDEC/Mi16qJScNozcQDBfs8IvFa+1r +XrLiNT5GO/f2WbHgFqb7A7+ET85kFQ1Zmabwf4ozS7sLuPLRJ5CPOPhaQYIHmw3Z +UuBw/97a2CVOLy2fr5Jjx0K03BaVIwUCa7Doxf4Vmuh9L9xD9HCRGpO+ca+FhNvP +a1yAstPzQm4k7CpimcY85TLlcjcwmwvkBrRkJpVZuvFTgz2ZbfCA4ttrNFIGdzxz +vsbjzrIRAgMBAAGjgbYwgbMwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUyFQo9tLV +EjWJFXW4v937Svxsid4wHwYDVR0jBBgwFoAUUOQdwxr7/TjdogX9yPpXCsEGD64w +HQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMEQGA1UdEQQ9MDuCD3d3dy5l +eGFtcGxlLm9yZ4IQbWFpbC5leGFtcGxlLm9yZ4IQbWFpbC5leGFtcGxlLmNvbYcE +fwAAATANBgkqhkiG9w0BAQsFAAOCAQEARUF6aOCnWaFiVHN0FE/enFGsJZdw9wlR +OXI5PNAx4cMCkRRNjx0xq5h+5ruratnFhqpOakjp+NezHaDF5r9MWpu1eAGjOXtf +vLinwnGwe92hh6ZUnPZZgbEs3sWKogaJtcF6vgyfPd6BSFNxe43H6ofX0dqUtMWs +HoOjQn3mqz/WHNZlw2DpdlR5P+tlhU9gfbuWA1Qu0BvibC2RrjOcBMRECn1fu4Ci +AbyQgaXcSsh3yY00F+YqfQIeMj991wyAW8aUakI2BZ+exYWfYONyczQ5RHVVYCR6 +iwl0hHL9kWiTV55wRk3kMIRfIAet/YYy0/u6r9lhFDzgoalRUQ+tYA== +-----END CERTIFICATE-----)"; + +// Same leaf as above, but with an intermediary attached. +ABSL_CONST_INIT const char kTestCertificateChainPem[] = + R"(-----BEGIN CERTIFICATE----- +MIIDtDCCApygAwIBAgIBATANBgkqhkiG9w0BAQsFADAeMRwwGgYDVQQDDBNRVUlD +IFNlcnZlciBSb290IENBMB4XDTIwMDEzMDE4MTM1OVoXDTIwMDIwMjE4MTM1OVow +ZDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDU1v +dW50YWluIFZpZXcxFDASBgNVBAoMC1FVSUMgU2VydmVyMRIwEAYDVQQDDAkxMjcu +MC4wLjEwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDF4lFtP9Yo8q00 +c4dkyjMZM7d1kasxGSvjpCYJKYst91J1p1UV8BHHwsTtGBszC3Ey5jWJzS1aBVdO +wnh1ZXItiheD1jKQhfgi4mWp4KD+GbI5LRQDEC/Mi16qJScNozcQDBfs8IvFa+1r +XrLiNT5GO/f2WbHgFqb7A7+ET85kFQ1Zmabwf4ozS7sLuPLRJ5CPOPhaQYIHmw3Z +UuBw/97a2CVOLy2fr5Jjx0K03BaVIwUCa7Doxf4Vmuh9L9xD9HCRGpO+ca+FhNvP +a1yAstPzQm4k7CpimcY85TLlcjcwmwvkBrRkJpVZuvFTgz2ZbfCA4ttrNFIGdzxz +vsbjzrIRAgMBAAGjgbYwgbMwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUyFQo9tLV +EjWJFXW4v937Svxsid4wHwYDVR0jBBgwFoAUUOQdwxr7/TjdogX9yPpXCsEGD64w +HQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMEQGA1UdEQQ9MDuCD3d3dy5l +eGFtcGxlLm9yZ4IQbWFpbC5leGFtcGxlLm9yZ4IQbWFpbC5leGFtcGxlLmNvbYcE +fwAAATANBgkqhkiG9w0BAQsFAAOCAQEARUF6aOCnWaFiVHN0FE/enFGsJZdw9wlR +OXI5PNAx4cMCkRRNjx0xq5h+5ruratnFhqpOakjp+NezHaDF5r9MWpu1eAGjOXtf +vLinwnGwe92hh6ZUnPZZgbEs3sWKogaJtcF6vgyfPd6BSFNxe43H6ofX0dqUtMWs +HoOjQn3mqz/WHNZlw2DpdlR5P+tlhU9gfbuWA1Qu0BvibC2RrjOcBMRECn1fu4Ci +AbyQgaXcSsh3yY00F+YqfQIeMj991wyAW8aUakI2BZ+exYWfYONyczQ5RHVVYCR6 +iwl0hHL9kWiTV55wRk3kMIRfIAet/YYy0/u6r9lhFDzgoalRUQ+tYA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDDDCCAfSgAwIBAgIUfVS7RH+aVGqZhrjyuyD4qCnTS+MwDQYJKoZIhvcNAQEL +BQAwHjEcMBoGA1UEAwwTUVVJQyBTZXJ2ZXIgUm9vdCBDQTAeFw0yMDAxMzAxODEz +NTlaFw0yMDAyMDIxODEzNTlaMB4xHDAaBgNVBAMME1FVSUMgU2VydmVyIFJvb3Qg +Q0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCc3k0GGpCBf6jXHxia +QM4ntB6pWkT+NbaZUNHb1SkG2Cp9dN5dEKOXiqOi9306j4WNWTq/q0Ku9lCPPPFs +JTIVC3tKY8Nbiczw+mohgW4rwLgpAP5rjjVzTxSFpDWZlgkH54HpqLjJFVl4Fklg +vzSj+rYfqP+ueesi7z7KwPwzd30jjsJlpr2rlkZkidWT5vRTD3uYhNOW7IIT0lRP +MDTwdxTEU5unyxESAsZyckNuJDeNF0y1Aw5Xiw/Bww+CyRH+tX6OUcWNtA+ZSDU8 +oVH5m4rxYK/DaHAZrA672/ywvUcPQaNaRxsAWRVjhktgyGPT3pjqiHDCN8+42uhH +SgrbAgMBAAGjQjBAMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFFDkHcMa+/04 +3aIF/cj6VwrBBg+uMA4GA1UdDwEB/wQEAwIBBjANBgkqhkiG9w0BAQsFAAOCAQEA +iX+tn1Zfxx4M5YqZlPgXFB219agrJP2vM0fzW0E4zqDvA2ALaQN+lwdnFueN3tDk +3IJvxd2W5k1Qh7LqWFUbBghDAP43XffW/yNy0+nuR2n3nRYdNStSMrGQm7oywhBd +5jQl0GQUyYf1jcbD76HA5JraBjEXnQyJe6gJYHiRiMaMURWyzcngOPv5w3XBzIe3 +sRM0Rk/TTZP1Qx7fDY3ikFe1w9LzAMGbKDTKfc1+F0GZByJ3pdWakUNXZvtGFhIF +hTXMooR/wD7an6gtnXD8ixCh7bP0TyPiBhNsUb12WrvSEAm/UyciQbQlR7P+K0Z7 +Cmn1Mj4hQ+pT0t+pw/DMOw== +-----END CERTIFICATE-----)"; + +ABSL_CONST_INIT const char kTestCertWithUnknownSanTypePem[] = + R"(-----BEGIN CERTIFICATE----- +MIIEYTCCA0mgAwIBAgIJAILStmLgUUcVMA0GCSqGSIb3DQEBCwUAMHYxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNp +c2NvMQ0wCwYDVQQKDARMeWZ0MRkwFwYDVQQLDBBMeWZ0IEVuZ2luZWVyaW5nMRAw +DgYDVQQDDAdUZXN0IENBMB4XDTE4MTIxNzIwMTgwMFoXDTIwMTIxNjIwMTgwMFow +gaYxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1T +YW4gRnJhbmNpc2NvMQ0wCwYDVQQKDARMeWZ0MRkwFwYDVQQLDBBMeWZ0IEVuZ2lu +ZWVyaW5nMRowGAYDVQQDDBFUZXN0IEJhY2tlbmQgVGVhbTEkMCIGCSqGSIb3DQEJ +ARYVYmFja2VuZC10ZWFtQGx5ZnQuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAuvPdQdmwZongPAgQho/Vipd3PZWrQ6BKxIb4l/RvqtVP321IUTLs +4vVwpXoYJ+12L+XOO3jCInszs53tHjFpTI1GE8/sasmgR6LRr2krwSoVRHPqUoc9 +tzkDG1SzKP2TRTi1MTI3FO+TnLFahntO9Zstxhv1Epz5GZ/xQLE0/LLoRYzcynL/ +iflk18iL1KM8i0Hy4cKjclOaUdnh2nh753iJfxCSb5wJfx4FH1qverYHHT6FopYR +V40Cg0yYXcYo8yNwrg+EBY8QAT2JOMDokXNKbZpmVKiBlh0QYMX6BBiW249v3sYl +3Ve+fZvCkle3W0xP0xJw8PdX0NRbvGOrBQIDAQABo4HAMIG9MAwGA1UdEwEB/wQC +MAAwCwYDVR0PBAQDAgXgMB0GA1UdJQQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATBB +BgNVHREEOjA4hh5zcGlmZmU6Ly9seWZ0LmNvbS9iYWNrZW5kLXRlYW2CCGx5ZnQu +Y29tggx3d3cubHlmdC5jb20wHQYDVR0OBBYEFLHmMm0DV9jCHJSWVRwyPYpBw62r +MB8GA1UdIwQYMBaAFBQz1vaSbPuePL++7GTMqLAMtk3kMA0GCSqGSIb3DQEBCwUA +A4IBAQAwx3/M2o00W8GlQ3OT4y/hQGb5K2aytxx8QeSmJaaZTJbvaHhe0x3/fLgq +uWrW3WEWFtwasilySjOrFOtB9UNmJmNOHSJD3Bslbv5htRaWnoFPCXdwZtVMdoTq +IHIQqLoos/xj3kVD5sJSYySrveMeKaeUILTkb5ZubSivye1X2yiJLR7AtuwuiMio +CdIOqhn6xJqYhT7z0IhdKpLNPk4w1tBZSKOXqzrXS4uoJgTC67hWslWWZ2VC6IvZ +FmKuuGZamCCj6F1QF2IjMVM8evl84hEnN0ajdkA/QWnil9kcWvBm15Ho+oTvvJ7s +M8MD3RDSq/90FSiME4vbyNEyTmj0 +-----END CERTIFICATE-----)"; + +ABSL_CONST_INIT const char kTestCertificatePrivateKeyRaw[] = { + '\x30', '\x82', '\x04', '\xbc', '\x02', '\x01', '\x00', '\x30', '\x0d', + '\x06', '\x09', '\x2a', '\x86', '\x48', '\x86', '\xf7', '\x0d', '\x01', + '\x01', '\x01', '\x05', '\x00', '\x04', '\x82', '\x04', '\xa6', '\x30', + '\x82', '\x04', '\xa2', '\x02', '\x01', '\x00', '\x02', '\x82', '\x01', + '\x01', '\x00', '\xc5', '\xe2', '\x51', '\x6d', '\x3f', '\xd6', '\x28', + '\xf2', '\xad', '\x34', '\x73', '\x87', '\x64', '\xca', '\x33', '\x19', + '\x33', '\xb7', '\x75', '\x91', '\xab', '\x31', '\x19', '\x2b', '\xe3', + '\xa4', '\x26', '\x09', '\x29', '\x8b', '\x2d', '\xf7', '\x52', '\x75', + '\xa7', '\x55', '\x15', '\xf0', '\x11', '\xc7', '\xc2', '\xc4', '\xed', + '\x18', '\x1b', '\x33', '\x0b', '\x71', '\x32', '\xe6', '\x35', '\x89', + '\xcd', '\x2d', '\x5a', '\x05', '\x57', '\x4e', '\xc2', '\x78', '\x75', + '\x65', '\x72', '\x2d', '\x8a', '\x17', '\x83', '\xd6', '\x32', '\x90', + '\x85', '\xf8', '\x22', '\xe2', '\x65', '\xa9', '\xe0', '\xa0', '\xfe', + '\x19', '\xb2', '\x39', '\x2d', '\x14', '\x03', '\x10', '\x2f', '\xcc', + '\x8b', '\x5e', '\xaa', '\x25', '\x27', '\x0d', '\xa3', '\x37', '\x10', + '\x0c', '\x17', '\xec', '\xf0', '\x8b', '\xc5', '\x6b', '\xed', '\x6b', + '\x5e', '\xb2', '\xe2', '\x35', '\x3e', '\x46', '\x3b', '\xf7', '\xf6', + '\x59', '\xb1', '\xe0', '\x16', '\xa6', '\xfb', '\x03', '\xbf', '\x84', + '\x4f', '\xce', '\x64', '\x15', '\x0d', '\x59', '\x99', '\xa6', '\xf0', + '\x7f', '\x8a', '\x33', '\x4b', '\xbb', '\x0b', '\xb8', '\xf2', '\xd1', + '\x27', '\x90', '\x8f', '\x38', '\xf8', '\x5a', '\x41', '\x82', '\x07', + '\x9b', '\x0d', '\xd9', '\x52', '\xe0', '\x70', '\xff', '\xde', '\xda', + '\xd8', '\x25', '\x4e', '\x2f', '\x2d', '\x9f', '\xaf', '\x92', '\x63', + '\xc7', '\x42', '\xb4', '\xdc', '\x16', '\x95', '\x23', '\x05', '\x02', + '\x6b', '\xb0', '\xe8', '\xc5', '\xfe', '\x15', '\x9a', '\xe8', '\x7d', + '\x2f', '\xdc', '\x43', '\xf4', '\x70', '\x91', '\x1a', '\x93', '\xbe', + '\x71', '\xaf', '\x85', '\x84', '\xdb', '\xcf', '\x6b', '\x5c', '\x80', + '\xb2', '\xd3', '\xf3', '\x42', '\x6e', '\x24', '\xec', '\x2a', '\x62', + '\x99', '\xc6', '\x3c', '\xe5', '\x32', '\xe5', '\x72', '\x37', '\x30', + '\x9b', '\x0b', '\xe4', '\x06', '\xb4', '\x64', '\x26', '\x95', '\x59', + '\xba', '\xf1', '\x53', '\x83', '\x3d', '\x99', '\x6d', '\xf0', '\x80', + '\xe2', '\xdb', '\x6b', '\x34', '\x52', '\x06', '\x77', '\x3c', '\x73', + '\xbe', '\xc6', '\xe3', '\xce', '\xb2', '\x11', '\x02', '\x03', '\x01', + '\x00', '\x01', '\x02', '\x82', '\x01', '\x00', '\x39', '\x75', '\xac', + '\x1b', '\x43', '\x0c', '\x16', '\xbb', '\xd0', '\xdb', '\x88', '\x28', + '\x6a', '\x75', '\xe4', '\x3c', '\x8f', '\x2d', '\xd8', '\x6f', '\xc1', + '\xfb', '\xf1', '\xc9', '\x32', '\xc2', '\xb9', '\x60', '\xb3', '\xb5', + '\x7c', '\x55', '\x72', '\x96', '\x43', '\x4e', '\x8b', '\x9e', '\x38', + '\x2b', '\x7f', '\x3c', '\xdb', '\x73', '\xc2', '\x82', '\x21', '\xf2', + '\x6e', '\xcb', '\x36', '\x04', '\x9b', '\x95', '\x6d', '\xac', '\x5b', + '\x5b', '\xbd', '\x50', '\x69', '\x16', '\x59', '\xff', '\x2b', '\x38', + '\x04', '\xca', '\x2f', '\xc8', '\x93', '\x7e', '\x27', '\xf3', '\x01', + '\x7e', '\x40', '\x81', '\xbf', '\x07', '\x0b', '\x1f', '\x5b', '\x1d', + '\x92', '\x7e', '\x22', '\xc3', '\x0c', '\x3d', '\x22', '\xbe', '\xc3', + '\x06', '\x4c', '\xbc', '\x72', '\x66', '\x70', '\x94', '\x16', '\x8d', + '\x1f', '\x78', '\x65', '\x6a', '\x66', '\x07', '\x1f', '\x74', '\x42', + '\x6e', '\xf6', '\x7e', '\xdc', '\x03', '\xd3', '\x88', '\xb4', '\x4b', + '\x2c', '\x5c', '\x3c', '\x42', '\x59', '\x42', '\x1f', '\x01', '\x13', + '\x31', '\xc5', '\x22', '\xe7', '\x6a', '\x96', '\xf2', '\xfb', '\x66', + '\xfe', '\xc8', '\xa1', '\x7e', '\x24', '\x96', '\x5f', '\x02', '\xee', + '\x38', '\x21', '\xa5', '\x14', '\xd2', '\xa6', '\x35', '\x70', '\x6c', + '\x8d', '\xa6', '\xd8', '\x2a', '\xd2', '\x45', '\x31', '\x5f', '\x67', + '\x9e', '\x35', '\x57', '\x6a', '\xc4', '\x15', '\xe7', '\xba', '\x60', + '\x2f', '\x8e', '\x52', '\x4e', '\xfc', '\x6f', '\xa0', '\x08', '\x91', + '\x31', '\x71', '\x06', '\x68', '\x19', '\x48', '\xc7', '\x81', '\x0d', + '\x5e', '\x52', '\x93', '\x57', '\xcc', '\xfe', '\x46', '\xac', '\xa9', + '\x4f', '\xe2', '\x96', '\x4f', '\xaf', '\x12', '\xfb', '\xc2', '\x4b', + '\xc4', '\x8d', '\x3b', '\xb0', '\x38', '\xe4', '\xbb', '\x8d', '\x19', + '\x81', '\xe4', '\x74', '\x63', '\x9c', '\x8d', '\xaa', '\x84', '\x82', + '\x91', '\xdf', '\xdc', '\x45', '\xf0', '\x39', '\xb2', '\xb4', '\xac', + '\x45', '\xda', '\x3f', '\x30', '\x4d', '\x46', '\xb1', '\xe1', '\xb2', + '\x9d', '\xdf', '\xd8', '\xc4', '\xa2', '\xef', '\xe9', '\x1a', '\x97', + '\x79', '\x02', '\x81', '\x81', '\x00', '\xe5', '\x23', '\xb8', '\xd7', + '\x09', '\x54', '\x54', '\x3b', '\xb6', '\x78', '\x78', '\x67', '\x57', + '\x65', '\xc5', '\xd4', '\x74', '\xaf', '\x05', '\x4f', '\xb5', '\xc8', + '\x8c', '\x1b', '\xd1', '\x9a', '\x2c', '\xd6', '\xe4', '\x68', '\xd1', + '\xaf', '\x3d', '\x72', '\x42', '\x50', '\xc8', '\xdd', '\xb1', '\xee', + '\x77', '\x52', '\xb8', '\xb1', '\x31', '\xbe', '\xf0', '\x74', '\x78', + '\x42', '\x59', '\xea', '\x13', '\x8b', '\x82', '\x00', '\x54', '\x22', + '\xd2', '\x0a', '\x24', '\xb0', '\x1f', '\x1e', '\x76', '\x27', '\xae', + '\x63', '\xc6', '\x6b', '\x59', '\x28', '\x1d', '\xa0', '\x9f', '\x42', + '\x30', '\xf1', '\xe3', '\x59', '\x1c', '\x4f', '\x31', '\x49', '\xff', + '\x45', '\x7e', '\x6b', '\xef', '\xe9', '\x6f', '\xde', '\xaf', '\x1e', + '\x04', '\x96', '\x61', '\x4e', '\x9f', '\x58', '\xf5', '\x0d', '\x64', + '\x08', '\x48', '\x0a', '\xae', '\xac', '\xe4', '\x76', '\x91', '\xdd', + '\x6e', '\x33', '\x97', '\xc5', '\x96', '\xda', '\xff', '\xbc', '\x42', + '\x5b', '\x71', '\xb5', '\x76', '\xae', '\x01', '\xb3', '\x02', '\x81', + '\x81', '\x00', '\xdd', '\x14', '\xa5', '\x6c', '\x89', '\x2b', '\x80', + '\x78', '\xf6', '\xc3', '\x80', '\x4d', '\x53', '\x54', '\xb3', '\x2b', + '\x40', '\xce', '\x98', '\x16', '\xa0', '\xbf', '\x72', '\xf1', '\xe3', + '\xdc', '\xe9', '\x0b', '\x45', '\x23', '\x86', '\x38', '\x4c', '\x29', + '\xf1', '\xa0', '\xe0', '\x2c', '\xfa', '\x86', '\x3f', '\x01', '\x90', + '\xc5', '\x1b', '\x96', '\x10', '\x44', '\x84', '\xfb', '\xec', '\x3c', + '\x74', '\x6c', '\x0d', '\xcc', '\xc3', '\xcd', '\x1b', '\x28', '\x12', + '\xaa', '\xb4', '\x67', '\x80', '\xc8', '\xd9', '\x1b', '\x7d', '\xe7', + '\x54', '\x39', '\x03', '\x6d', '\xba', '\xaa', '\x6f', '\xf7', '\x93', + '\x1f', '\x94', '\x76', '\xd6', '\xab', '\x9b', '\xda', '\x3d', '\x89', + '\x37', '\x83', '\xfe', '\x72', '\x2a', '\xbb', '\x6f', '\x36', '\xc5', + '\xe0', '\xae', '\x65', '\xf9', '\xbb', '\xc6', '\xe2', '\x98', '\x0f', + '\xbd', '\xf6', '\x22', '\xf8', '\x35', '\x5b', '\x99', '\xe6', '\xff', + '\x6d', '\x6e', '\xb2', '\x92', '\x93', '\x64', '\x25', '\xc1', '\xe8', + '\x9c', '\x6b', '\x73', '\x2b', '\x02', '\x81', '\x80', '\x13', '\x30', + '\x1a', '\x9a', '\x67', '\x3d', '\x98', '\x90', '\x27', '\x87', '\x8f', + '\x0d', '\x98', '\x53', '\xfd', '\x6c', '\xfd', '\x18', '\x6a', '\xe9', + '\x71', '\xdf', '\x89', '\x5c', '\x0b', '\x01', '\x4e', '\x1f', '\xf0', + '\xa0', '\x96', '\x6e', '\x86', '\x46', '\xbb', '\x26', '\xe8', '\xab', + '\x27', '\xeb', '\x40', '\x32', '\xbd', '\x24', '\x99', '\x75', '\xd3', + '\xcc', '\xed', '\x05', '\x21', '\x62', '\x68', '\xa0', '\x96', '\x12', + '\x50', '\xf9', '\x59', '\x7d', '\x5f', '\xf5', '\x1f', '\xa5', '\xfd', + '\x5e', '\xf5', '\x4b', '\x85', '\xa2', '\x17', '\xa5', '\x34', '\x55', + '\xef', '\x00', '\x2b', '\xf9', '\x15', '\x80', '\xb0', '\xce', '\x30', + '\xe2', '\x71', '\x6d', '\xf0', '\x58', '\x39', '\x8e', '\xe2', '\xbf', + '\x53', '\x0a', '\xc0', '\x77', '\x97', '\x4e', '\x6e', '\x29', '\x94', + '\xdb', '\xba', '\x34', '\xb7', '\x53', '\xad', '\xac', '\xec', '\xb4', + '\xc1', '\x22', '\x39', '\xc8', '\x38', '\x3d', '\x63', '\x94', '\x93', + '\x35', '\xc0', '\x98', '\xc7', '\xbc', '\xda', '\x63', '\x57', '\xe1', + '\x02', '\x81', '\x80', '\x51', '\x71', '\x7c', '\xab', '\x6a', '\x30', + '\xe3', '\x68', '\x2c', '\x87', '\xc2', '\xe9', '\x39', '\x8c', '\x97', + '\x60', '\x94', '\xc4', '\x46', '\xd4', '\xf7', '\x2c', '\xf0', '\x1c', + '\x5a', '\x34', '\x14', '\x89', '\xf9', '\x53', '\x67', '\xeb', '\xaf', + '\x6b', '\x38', '\x3f', '\x6a', '\xb6', '\x47', '\x28', '\x53', '\x67', + '\xb1', '\x3c', '\x5b', '\xb8', '\x41', '\x8f', '\xec', '\x69', '\x9e', + '\x12', '\x7b', '\x55', '\x1f', '\x14', '\x53', '\x01', '\x69', '\x42', + '\xae', '\xf5', '\xc1', '\xf5', '\xeb', '\x44', '\x92', '\x6e', '\x85', + '\x48', '\x46', '\x07', '\xa6', '\xd2', '\xb2', '\x94', '\x7d', '\x20', + '\xf8', '\x4b', '\x06', '\xf7', '\x6c', '\x87', '\xd5', '\xa7', '\x65', + '\x49', '\xfa', '\x70', '\x9e', '\xb8', '\xd2', '\x33', '\x30', '\x7a', + '\x3e', '\x15', '\x52', '\x49', '\xf0', '\xe1', '\x13', '\x18', '\x80', + '\xaa', '\x33', '\xf1', '\xcb', '\xda', '\x22', '\x55', '\xf7', '\x71', + '\x58', '\xa1', '\xa8', '\xc9', '\x12', '\x24', '\x48', '\x1d', '\x7c', + '\xbc', '\xc3', '\x7a', '\xf5', '\xf7', '\x02', '\x81', '\x80', '\x41', + '\x7c', '\xae', '\x6e', '\x48', '\x3f', '\xb5', '\x0b', '\x99', '\xaa', + '\xc5', '\xea', '\x81', '\xad', '\x84', '\x6b', '\x29', '\x78', '\x4b', + '\x18', '\xdb', '\x0e', '\xd3', '\x3e', '\x60', '\x8b', '\xef', '\x65', + '\x4d', '\x58', '\x25', '\x3a', '\x08', '\xb5', '\x21', '\xb6', '\x61', + '\x0c', '\xfa', '\xf0', '\x69', '\x78', '\x4e', '\x68', '\x36', '\xdb', + '\x41', '\x4b', '\x50', '\xd8', '\xd3', '\x8e', '\x3d', '\x74', '\x80', + '\x8e', '\xa0', '\xe6', '\xda', '\xec', '\x70', '\x89', '\x77', '\xb2', + '\x9d', '\xd6', '\x6e', '\x0a', '\xc4', '\xbd', '\xf6', '\x9a', '\x07', + '\x15', '\xba', '\x55', '\x9f', '\xd4', '\x4d', '\x3a', '\x0f', '\x51', + '\x12', '\xa4', '\xd9', '\xc2', '\x98', '\x76', '\xc5', '\xb7', '\x29', + '\x40', '\xca', '\xf4', '\xbb', '\x74', '\x2d', '\x71', '\x03', '\x4d', + '\xe7', '\x05', '\x75', '\xc0', '\x8d', '\x96', '\x7e', '\x59', '\xa1', + '\x8b', '\x3b', '\xa3', '\x2b', '\xa5', '\xa3', '\xc8', '\xf7', '\xd3', + '\x3e', '\x6b', '\x2e', '\xfa', '\x4f', '\x4d', '\xe6', '\xbe', '\xd3', + '\x59'}; + +ABSL_CONST_INIT const absl::string_view kTestCertificatePrivateKey( + kTestCertificatePrivateKeyRaw, sizeof(kTestCertificatePrivateKeyRaw)); + +ABSL_CONST_INIT const char kTestCertificatePrivateKeyPem[] = + R"(-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDF4lFtP9Yo8q00 +c4dkyjMZM7d1kasxGSvjpCYJKYst91J1p1UV8BHHwsTtGBszC3Ey5jWJzS1aBVdO +wnh1ZXItiheD1jKQhfgi4mWp4KD+GbI5LRQDEC/Mi16qJScNozcQDBfs8IvFa+1r +XrLiNT5GO/f2WbHgFqb7A7+ET85kFQ1Zmabwf4ozS7sLuPLRJ5CPOPhaQYIHmw3Z +UuBw/97a2CVOLy2fr5Jjx0K03BaVIwUCa7Doxf4Vmuh9L9xD9HCRGpO+ca+FhNvP +a1yAstPzQm4k7CpimcY85TLlcjcwmwvkBrRkJpVZuvFTgz2ZbfCA4ttrNFIGdzxz +vsbjzrIRAgMBAAECggEAOXWsG0MMFrvQ24goanXkPI8t2G/B+/HJMsK5YLO1fFVy +lkNOi544K38823PCgiHybss2BJuVbaxbW71QaRZZ/ys4BMovyJN+J/MBfkCBvwcL +H1sdkn4iwww9Ir7DBky8cmZwlBaNH3hlamYHH3RCbvZ+3APTiLRLLFw8QllCHwET +McUi52qW8vtm/sihfiSWXwLuOCGlFNKmNXBsjabYKtJFMV9nnjVXasQV57pgL45S +TvxvoAiRMXEGaBlIx4ENXlKTV8z+RqypT+KWT68S+8JLxI07sDjku40ZgeR0Y5yN +qoSCkd/cRfA5srSsRdo/ME1GseGynd/YxKLv6RqXeQKBgQDlI7jXCVRUO7Z4eGdX +ZcXUdK8FT7XIjBvRmizW5GjRrz1yQlDI3bHud1K4sTG+8HR4QlnqE4uCAFQi0gok +sB8edieuY8ZrWSgdoJ9CMPHjWRxPMUn/RX5r7+lv3q8eBJZhTp9Y9Q1kCEgKrqzk +dpHdbjOXxZba/7xCW3G1dq4BswKBgQDdFKVsiSuAePbDgE1TVLMrQM6YFqC/cvHj +3OkLRSOGOEwp8aDgLPqGPwGQxRuWEESE++w8dGwNzMPNGygSqrRngMjZG33nVDkD +bbqqb/eTH5R21qub2j2JN4P+ciq7bzbF4K5l+bvG4pgPvfYi+DVbmeb/bW6ykpNk +JcHonGtzKwKBgBMwGppnPZiQJ4ePDZhT/Wz9GGrpcd+JXAsBTh/woJZuhka7Juir +J+tAMr0kmXXTzO0FIWJooJYSUPlZfV/1H6X9XvVLhaIXpTRV7wAr+RWAsM4w4nFt +8Fg5juK/UwrAd5dObimU27o0t1OtrOy0wSI5yDg9Y5STNcCYx7zaY1fhAoGAUXF8 +q2ow42gsh8LpOYyXYJTERtT3LPAcWjQUiflTZ+uvazg/arZHKFNnsTxbuEGP7Gme +EntVHxRTAWlCrvXB9etEkm6FSEYHptKylH0g+EsG92yH1adlSfpwnrjSMzB6PhVS +SfDhExiAqjPxy9oiVfdxWKGoyRIkSB18vMN69fcCgYBBfK5uSD+1C5mqxeqBrYRr +KXhLGNsO0z5gi+9lTVglOgi1IbZhDPrwaXhOaDbbQUtQ2NOOPXSAjqDm2uxwiXey +ndZuCsS99poHFbpVn9RNOg9REqTZwph2xbcpQMr0u3QtcQNN5wV1wI2Wflmhizuj +K6WjyPfTPmsu+k9N5r7TWQ== +-----END PRIVATE KEY-----)"; + +// The legacy version was manually generated from the one above using der2ascii. +ABSL_CONST_INIT const char kTestCertificatePrivateKeyLegacyPem[] = + R"(-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAxeJRbT/WKPKtNHOHZMozGTO3dZGrMRkr46QmCSmLLfdSdadVFfARx8LE7Rgb +MwtxMuY1ic0tWgVXTsJ4dWVyLYoXg9YykIX4IuJlqeCg/hmyOS0UAxAvzIteqiUnDaM3EAwX7PCL +xWvta16y4jU+Rjv39lmx4Bam+wO/hE/OZBUNWZmm8H+KM0u7C7jy0SeQjzj4WkGCB5sN2VLgcP/e +2tglTi8tn6+SY8dCtNwWlSMFAmuw6MX+FZrofS/cQ/RwkRqTvnGvhYTbz2tcgLLT80JuJOwqYpnG +POUy5XI3MJsL5Aa0ZCaVWbrxU4M9mW3wgOLbazRSBnc8c77G486yEQIDAQABAoIBADl1rBtDDBa7 +0NuIKGp15DyPLdhvwfvxyTLCuWCztXxVcpZDToueOCt/PNtzwoIh8m7LNgSblW2sW1u9UGkWWf8r +OATKL8iTfifzAX5Agb8HCx9bHZJ+IsMMPSK+wwZMvHJmcJQWjR94ZWpmBx90Qm72ftwD04i0Syxc +PEJZQh8BEzHFIudqlvL7Zv7IoX4kll8C7jghpRTSpjVwbI2m2CrSRTFfZ541V2rEFee6YC+OUk78 +b6AIkTFxBmgZSMeBDV5Sk1fM/kasqU/ilk+vEvvCS8SNO7A45LuNGYHkdGOcjaqEgpHf3EXwObK0 +rEXaPzBNRrHhsp3f2MSi7+kal3kCgYEA5SO41wlUVDu2eHhnV2XF1HSvBU+1yIwb0Zos1uRo0a89 +ckJQyN2x7ndSuLExvvB0eEJZ6hOLggBUItIKJLAfHnYnrmPGa1koHaCfQjDx41kcTzFJ/0V+a+/p +b96vHgSWYU6fWPUNZAhICq6s5HaR3W4zl8WW2v+8QltxtXauAbMCgYEA3RSlbIkrgHj2w4BNU1Sz +K0DOmBagv3Lx49zpC0UjhjhMKfGg4Cz6hj8BkMUblhBEhPvsPHRsDczDzRsoEqq0Z4DI2Rt951Q5 +A226qm/3kx+Udtarm9o9iTeD/nIqu282xeCuZfm7xuKYD732Ivg1W5nm/21uspKTZCXB6JxrcysC +gYATMBqaZz2YkCeHjw2YU/1s/Rhq6XHfiVwLAU4f8KCWboZGuyboqyfrQDK9JJl108ztBSFiaKCW +ElD5WX1f9R+l/V71S4WiF6U0Ve8AK/kVgLDOMOJxbfBYOY7iv1MKwHeXTm4plNu6NLdTrazstMEi +Ocg4PWOUkzXAmMe82mNX4QKBgFFxfKtqMONoLIfC6TmMl2CUxEbU9yzwHFo0FIn5U2frr2s4P2q2 +RyhTZ7E8W7hBj+xpnhJ7VR8UUwFpQq71wfXrRJJuhUhGB6bSspR9IPhLBvdsh9WnZUn6cJ640jMw +ej4VUknw4RMYgKoz8cvaIlX3cVihqMkSJEgdfLzDevX3AoGAQXyubkg/tQuZqsXqga2Eayl4Sxjb +DtM+YIvvZU1YJToItSG2YQz68Gl4Tmg220FLUNjTjj10gI6g5trscIl3sp3WbgrEvfaaBxW6VZ/U +TToPURKk2cKYdsW3KUDK9Lt0LXEDTecFdcCNln5ZoYs7oyulo8j30z5rLvpPTea+01k= +-----END RSA PRIVATE KEY-----)"; + +ABSL_CONST_INIT const char kWildcardCertificateRaw[] = { + '\x30', '\x82', '\x03', '\x5f', '\x30', '\x82', '\x02', '\x47', '\xa0', + '\x03', '\x02', '\x01', '\x02', '\x02', '\x14', '\x36', '\x1d', '\xe3', + '\xd2', '\x39', '\x35', '\x20', '\xb1', '\xae', '\x18', '\xdd', '\x71', + '\xc9', '\x5b', '\x4a', '\x17', '\xbe', '\x00', '\xb4', '\x15', '\x30', + '\x0d', '\x06', '\x09', '\x2a', '\x86', '\x48', '\x86', '\xf7', '\x0d', + '\x01', '\x01', '\x0b', '\x05', '\x00', '\x30', '\x24', '\x31', '\x0b', + '\x30', '\x09', '\x06', '\x03', '\x55', '\x04', '\x06', '\x13', '\x02', + '\x55', '\x53', '\x31', '\x15', '\x30', '\x13', '\x06', '\x03', '\x55', + '\x04', '\x03', '\x0c', '\x0c', '\x77', '\x77', '\x77', '\x2e', '\x66', + '\x6f', '\x6f', '\x2e', '\x74', '\x65', '\x73', '\x74', '\x30', '\x1e', + '\x17', '\x0d', '\x32', '\x30', '\x30', '\x34', '\x32', '\x31', '\x30', + '\x32', '\x31', '\x38', '\x34', '\x35', '\x5a', '\x17', '\x0d', '\x32', + '\x31', '\x30', '\x34', '\x32', '\x31', '\x30', '\x32', '\x31', '\x38', + '\x34', '\x35', '\x5a', '\x30', '\x24', '\x31', '\x0b', '\x30', '\x09', + '\x06', '\x03', '\x55', '\x04', '\x06', '\x13', '\x02', '\x55', '\x53', + '\x31', '\x15', '\x30', '\x13', '\x06', '\x03', '\x55', '\x04', '\x03', + '\x0c', '\x0c', '\x77', '\x77', '\x77', '\x2e', '\x66', '\x6f', '\x6f', + '\x2e', '\x74', '\x65', '\x73', '\x74', '\x30', '\x82', '\x01', '\x22', + '\x30', '\x0d', '\x06', '\x09', '\x2a', '\x86', '\x48', '\x86', '\xf7', + '\x0d', '\x01', '\x01', '\x01', '\x05', '\x00', '\x03', '\x82', '\x01', + '\x0f', '\x00', '\x30', '\x82', '\x01', '\x0a', '\x02', '\x82', '\x01', + '\x01', '\x00', '\xcc', '\xd5', '\x5d', '\xa0', '\x4a', '\x03', '\x9d', + '\x89', '\xa2', '\xae', '\x7a', '\x59', '\x15', '\xf7', '\x27', '\x67', + '\x49', '\xa4', '\xc1', '\x87', '\xcd', '\x9c', '\x02', '\x9e', '\xb9', + '\x2f', '\xd1', '\xa1', '\x0d', '\x57', '\xff', '\xd6', '\xc0', '\x6a', + '\x7b', '\xaa', '\x52', '\xb2', '\x6e', '\xa6', '\x12', '\x34', '\xcf', + '\xdc', '\xd3', '\x1e', '\x32', '\xc1', '\x8d', '\x42', '\xa3', '\x0b', + '\xd6', '\xaf', '\xe9', '\x37', '\x42', '\xf8', '\x78', '\xdc', '\xcb', + '\x2d', '\x0e', '\x42', '\x5a', '\xe2', '\xbf', '\xd2', '\xe4', '\x9c', + '\xb4', '\x34', '\x38', '\x97', '\x5e', '\x4d', '\x5e', '\x8a', '\x0b', + '\xd8', '\x42', '\x11', '\x88', '\x19', '\xa2', '\x23', '\x4b', '\xec', + '\x3b', '\x0a', '\xc9', '\x67', '\x49', '\x2c', '\x8e', '\x1c', '\x5e', + '\x7f', '\x42', '\xe7', '\x73', '\x0b', '\x86', '\x68', '\xf0', '\xaa', + '\x3f', '\x1e', '\x17', '\x3e', '\x29', '\xc4', '\x57', '\x6e', '\x34', + '\x78', '\xaf', '\x15', '\x03', '\x39', '\x32', '\x27', '\x80', '\x76', + '\xb1', '\xda', '\x08', '\xe5', '\x4d', '\x3f', '\x4c', '\xfc', '\x1e', + '\x23', '\x5a', '\xb3', '\xd4', '\x99', '\xdc', '\x5c', '\x2b', '\xf1', + '\xa8', '\xe3', '\x02', '\x0a', '\xc8', '\x4d', '\x63', '\x27', '\xb9', + '\x0d', '\x6c', '\xc2', '\x34', '\x82', '\x82', '\x5d', '\x56', '\xa8', + '\x93', '\x44', '\x8b', '\xf4', '\x8b', '\xf0', '\x63', '\xe5', '\x23', + '\x7f', '\x8d', '\x5f', '\x3a', '\x4a', '\xa5', '\x50', '\xb9', '\xc6', + '\x5c', '\xe6', '\x33', '\xe3', '\xfc', '\xc8', '\x96', '\x88', '\x88', + '\xe9', '\x53', '\xaf', '\x0d', '\xbb', '\x80', '\x9c', '\xbb', '\xed', + '\x4d', '\x06', '\xfa', '\xe9', '\x7c', '\x25', '\x1c', '\x59', '\xee', + '\x19', '\xcc', '\xa9', '\x7c', '\x1d', '\x86', '\xd9', '\x95', '\x78', + '\x2d', '\x3a', '\x95', '\x49', '\x11', '\x45', '\xfa', '\xd6', '\xef', + '\xd5', '\x07', '\x1c', '\x23', '\xeb', '\xad', '\xd3', '\x3b', '\x95', + '\xcf', '\x53', '\xa3', '\x47', '\xa9', '\xa7', '\x90', '\xde', '\x34', + '\xa4', '\xbb', '\x05', '\xdc', '\x54', '\x87', '\x97', '\x30', '\xea', + '\x25', '\xf0', '\xfd', '\xba', '\xa1', '\x1b', '\x02', '\x03', '\x01', + '\x00', '\x01', '\xa3', '\x81', '\x88', '\x30', '\x81', '\x85', '\x30', + '\x1d', '\x06', '\x03', '\x55', '\x1d', '\x0e', '\x04', '\x16', '\x04', + '\x14', '\x09', '\xfb', '\x77', '\xbb', '\xc8', '\x8f', '\xd6', '\xa4', + '\xf0', '\x74', '\xb2', '\x90', '\x46', '\x0a', '\x8d', '\x09', '\x4b', + '\x89', '\x2e', '\x41', '\x30', '\x1f', '\x06', '\x03', '\x55', '\x1d', + '\x23', '\x04', '\x18', '\x30', '\x16', '\x80', '\x14', '\x09', '\xfb', + '\x77', '\xbb', '\xc8', '\x8f', '\xd6', '\xa4', '\xf0', '\x74', '\xb2', + '\x90', '\x46', '\x0a', '\x8d', '\x09', '\x4b', '\x89', '\x2e', '\x41', + '\x30', '\x0f', '\x06', '\x03', '\x55', '\x1d', '\x13', '\x01', '\x01', + '\xff', '\x04', '\x05', '\x30', '\x03', '\x01', '\x01', '\xff', '\x30', + '\x32', '\x06', '\x03', '\x55', '\x1d', '\x11', '\x04', '\x2b', '\x30', + '\x29', '\x82', '\x08', '\x66', '\x6f', '\x6f', '\x2e', '\x74', '\x65', + '\x73', '\x74', '\x82', '\x0c', '\x77', '\x77', '\x77', '\x2e', '\x66', + '\x6f', '\x6f', '\x2e', '\x74', '\x65', '\x73', '\x74', '\x82', '\x0f', + '\x2a', '\x2e', '\x77', '\x69', '\x6c', '\x64', '\x63', '\x61', '\x72', + '\x64', '\x2e', '\x74', '\x65', '\x73', '\x74', '\x30', '\x0d', '\x06', + '\x09', '\x2a', '\x86', '\x48', '\x86', '\xf7', '\x0d', '\x01', '\x01', + '\x0b', '\x05', '\x00', '\x03', '\x82', '\x01', '\x01', '\x00', '\x93', + '\xbc', '\x33', '\x4c', '\xa4', '\xdf', '\xdc', '\xed', '\x4b', '\x4d', + '\x5e', '\xdb', '\xdd', '\x4a', '\xb7', '\xbc', '\x50', '\x1f', '\xca', + '\x66', '\x4d', '\x28', '\x96', '\x42', '\x4e', '\x84', '\x44', '\x80', + '\x25', '\x17', '\x2c', '\x05', '\x93', '\xe0', '\x2a', '\x29', '\xef', + '\xe4', '\x26', '\x19', '\x63', '\xdf', '\xb2', '\x72', '\xb1', '\x82', + '\x7e', '\x5f', '\xce', '\x82', '\x41', '\xad', '\x96', '\x78', '\x94', + '\xa8', '\x21', '\xee', '\xf2', '\x4a', '\xf5', '\x41', '\xa8', '\xfb', + '\xe0', '\xe1', '\x22', '\x89', '\xf1', '\x40', '\x85', '\x86', '\x53', + '\x61', '\x57', '\x0f', '\x31', '\xae', '\x0c', '\xc3', '\x8d', '\xe8', + '\x29', '\xac', '\xe0', '\x03', '\x2d', '\x69', '\x44', '\x3d', '\xd6', + '\x3b', '\x2b', '\x0f', '\xb3', '\xf5', '\x83', '\x1b', '\x4e', '\x65', + '\x60', '\x6b', '\xa2', '\x01', '\x03', '\x1e', '\x98', '\xca', '\xca', + '\x32', '\xd4', '\x5b', '\xde', '\x45', '\xe2', '\x35', '\xd2', '\x54', + '\x1a', '\x2a', '\x38', '\xa7', '\x42', '\xa0', '\xf3', '\xef', '\x28', + '\xe3', '\x6e', '\x23', '\x77', '\x07', '\xd5', '\xef', '\xfd', '\x30', + '\xd6', '\x31', '\xfa', '\xf2', '\x94', '\x95', '\x2f', '\x03', '\x7a', + '\x43', '\xe0', '\xb3', '\x82', '\xca', '\x7e', '\xb4', '\x00', '\xc9', + '\x08', '\x15', '\x7b', '\x2e', '\x51', '\xec', '\xab', '\x68', '\xca', + '\xc2', '\xca', '\x44', '\xe1', '\xbe', '\xe4', '\x06', '\x98', '\x87', + '\x9b', '\x58', '\xbc', '\xf1', '\xea', '\x55', '\xf6', '\x64', '\x92', + '\xe6', '\x73', '\xc9', '\xf6', '\xc5', '\x7a', '\x90', '\x42', '\x83', + '\x39', '\x9e', '\xd0', '\xca', '\x85', '\x6c', '\x53', '\x99', '\x64', + '\xbb', '\x49', '\xdc', '\xae', '\x1c', '\xe5', '\x00', '\x65', '\x13', + '\xdd', '\xdc', '\xde', '\x3f', '\xf9', '\x14', '\x91', '\x0d', '\xe6', + '\xba', '\xc1', '\x7d', '\x5f', '\xd5', '\x6d', '\xe8', '\x65', '\x9c', + '\xfb', '\xda', '\x82', '\xf7', '\x4d', '\x45', '\x81', '\x8c', '\x54', + '\xec', '\x50', '\xbb', '\x14', '\xe9', '\x06', '\xda', '\x76', '\xb3', + '\xf0', '\xb7', '\xbb', '\x58', '\x4c', '\x8f', '\x6a', '\x5d', '\x8e', + '\x93', '\x5f', '\x35'}; + +ABSL_CONST_INIT const absl::string_view kWildcardCertificate( + kWildcardCertificateRaw, sizeof(kWildcardCertificateRaw)); + +ABSL_CONST_INIT const char kWildcardCertificatePrivateKeyRaw[] = { + '\x30', '\x82', '\x04', '\xbe', '\x02', '\x01', '\x00', '\x30', '\x0d', + '\x06', '\x09', '\x2a', '\x86', '\x48', '\x86', '\xf7', '\x0d', '\x01', + '\x01', '\x01', '\x05', '\x00', '\x04', '\x82', '\x04', '\xa8', '\x30', + '\x82', '\x04', '\xa4', '\x02', '\x01', '\x00', '\x02', '\x82', '\x01', + '\x01', '\x00', '\xcc', '\xd5', '\x5d', '\xa0', '\x4a', '\x03', '\x9d', + '\x89', '\xa2', '\xae', '\x7a', '\x59', '\x15', '\xf7', '\x27', '\x67', + '\x49', '\xa4', '\xc1', '\x87', '\xcd', '\x9c', '\x02', '\x9e', '\xb9', + '\x2f', '\xd1', '\xa1', '\x0d', '\x57', '\xff', '\xd6', '\xc0', '\x6a', + '\x7b', '\xaa', '\x52', '\xb2', '\x6e', '\xa6', '\x12', '\x34', '\xcf', + '\xdc', '\xd3', '\x1e', '\x32', '\xc1', '\x8d', '\x42', '\xa3', '\x0b', + '\xd6', '\xaf', '\xe9', '\x37', '\x42', '\xf8', '\x78', '\xdc', '\xcb', + '\x2d', '\x0e', '\x42', '\x5a', '\xe2', '\xbf', '\xd2', '\xe4', '\x9c', + '\xb4', '\x34', '\x38', '\x97', '\x5e', '\x4d', '\x5e', '\x8a', '\x0b', + '\xd8', '\x42', '\x11', '\x88', '\x19', '\xa2', '\x23', '\x4b', '\xec', + '\x3b', '\x0a', '\xc9', '\x67', '\x49', '\x2c', '\x8e', '\x1c', '\x5e', + '\x7f', '\x42', '\xe7', '\x73', '\x0b', '\x86', '\x68', '\xf0', '\xaa', + '\x3f', '\x1e', '\x17', '\x3e', '\x29', '\xc4', '\x57', '\x6e', '\x34', + '\x78', '\xaf', '\x15', '\x03', '\x39', '\x32', '\x27', '\x80', '\x76', + '\xb1', '\xda', '\x08', '\xe5', '\x4d', '\x3f', '\x4c', '\xfc', '\x1e', + '\x23', '\x5a', '\xb3', '\xd4', '\x99', '\xdc', '\x5c', '\x2b', '\xf1', + '\xa8', '\xe3', '\x02', '\x0a', '\xc8', '\x4d', '\x63', '\x27', '\xb9', + '\x0d', '\x6c', '\xc2', '\x34', '\x82', '\x82', '\x5d', '\x56', '\xa8', + '\x93', '\x44', '\x8b', '\xf4', '\x8b', '\xf0', '\x63', '\xe5', '\x23', + '\x7f', '\x8d', '\x5f', '\x3a', '\x4a', '\xa5', '\x50', '\xb9', '\xc6', + '\x5c', '\xe6', '\x33', '\xe3', '\xfc', '\xc8', '\x96', '\x88', '\x88', + '\xe9', '\x53', '\xaf', '\x0d', '\xbb', '\x80', '\x9c', '\xbb', '\xed', + '\x4d', '\x06', '\xfa', '\xe9', '\x7c', '\x25', '\x1c', '\x59', '\xee', + '\x19', '\xcc', '\xa9', '\x7c', '\x1d', '\x86', '\xd9', '\x95', '\x78', + '\x2d', '\x3a', '\x95', '\x49', '\x11', '\x45', '\xfa', '\xd6', '\xef', + '\xd5', '\x07', '\x1c', '\x23', '\xeb', '\xad', '\xd3', '\x3b', '\x95', + '\xcf', '\x53', '\xa3', '\x47', '\xa9', '\xa7', '\x90', '\xde', '\x34', + '\xa4', '\xbb', '\x05', '\xdc', '\x54', '\x87', '\x97', '\x30', '\xea', + '\x25', '\xf0', '\xfd', '\xba', '\xa1', '\x1b', '\x02', '\x03', '\x01', + '\x00', '\x01', '\x02', '\x82', '\x01', '\x01', '\x00', '\xa3', '\xb3', + '\x01', '\x98', '\x50', '\x8e', '\x83', '\x20', '\xb4', '\x3a', '\xec', + '\xdc', '\xb5', '\x89', '\x48', '\x9c', '\x6b', '\x66', '\x98', '\xa4', + '\x87', '\xd5', '\xde', '\xe2', '\x2a', '\xed', '\xe4', '\x82', '\xe9', + '\xbf', '\x22', '\x5f', '\xe6', '\x77', '\x33', '\x4d', '\xf3', '\xb9', + '\x56', '\x64', '\xb2', '\xb8', '\x32', '\x47', '\x31', '\x12', '\x39', + '\x4e', '\x26', '\x2e', '\xd3', '\x4f', '\x6a', '\xcc', '\x3b', '\x7e', + '\x46', '\xaf', '\x7d', '\x28', '\x37', '\xd8', '\x52', '\x45', '\x05', + '\x8d', '\xa1', '\xf0', '\x51', '\x74', '\x4b', '\x30', '\x50', '\xe9', + '\xe8', '\x1b', '\xbd', '\x2a', '\x66', '\x3c', '\xf6', '\xd0', '\x3c', + '\x0d', '\x00', '\x5f', '\x65', '\x15', '\xee', '\x39', '\xb8', '\xac', + '\x2a', '\xf6', '\xc8', '\xbc', '\x33', '\x69', '\x51', '\x76', '\xd7', + '\xa2', '\xa6', '\x50', '\xc7', '\xc5', '\xc7', '\x9b', '\xac', '\xc7', + '\xa9', '\x69', '\x98', '\xd6', '\x22', '\x69', '\x30', '\xc3', '\x82', + '\x47', '\xfb', '\xa5', '\x46', '\x2d', '\x96', '\x05', '\xc2', '\x84', + '\xd1', '\x1d', '\xd5', '\xa7', '\x5c', '\xdb', '\x6d', '\x35', '\x7b', + '\x1b', '\x80', '\xe4', '\x42', '\x1f', '\x4d', '\x68', '\x2e', '\xbc', + '\x58', '\xb6', '\x7c', '\x7e', '\xc5', '\x07', '\xe1', '\xf5', '\x30', + '\xa9', '\x8f', '\x14', '\x76', '\xad', '\xe2', '\xdf', '\xaf', '\xd3', + '\xf1', '\xba', '\xd5', '\x98', '\xf3', '\x5e', '\x30', '\x79', '\xcb', + '\xe7', '\x7a', '\x83', '\xba', '\xf7', '\x71', '\xb0', '\xb2', '\xd1', + '\xf4', '\x34', '\x5b', '\xe1', '\xe8', '\x60', '\x39', '\x96', '\x12', + '\xdc', '\xb4', '\x0d', '\xf9', '\x8d', '\x8c', '\xd8', '\xbb', '\xb7', + '\xd2', '\x1b', '\x83', '\x10', '\xbd', '\x86', '\xef', '\x5c', '\x6c', + '\xe3', '\xb1', '\x96', '\x7f', '\xab', '\x58', '\xce', '\x87', '\xc9', + '\x48', '\x69', '\xbb', '\xb1', '\xec', '\xa4', '\x3a', '\x06', '\xa3', + '\x33', '\xad', '\x7a', '\xe5', '\x88', '\x6d', '\x32', '\x67', '\x1c', + '\x03', '\xda', '\x9d', '\x3c', '\x73', '\xe0', '\xd7', '\x6c', '\x00', + '\xe4', '\x8d', '\x7d', '\xf2', '\xac', '\xa5', '\xb8', '\x35', '\xb9', + '\xac', '\x81', '\x02', '\x81', '\x81', '\x00', '\xe8', '\xd5', '\x5b', + '\xd0', '\x4f', '\x7c', '\xfc', '\x4b', '\xe6', '\xe8', '\x3c', '\x4c', + '\x24', '\xce', '\x68', '\x73', '\x3b', '\x4b', '\xa0', '\xfb', '\x79', + '\xa5', '\x72', '\x1d', '\x77', '\xb2', '\xdf', '\x2b', '\x0a', '\x11', + '\x28', '\xe8', '\x02', '\x7f', '\x26', '\x40', '\x34', '\x8f', '\x78', + '\x18', '\xad', '\xf4', '\x11', '\x78', '\x45', '\x9f', '\x66', '\x4e', + '\x78', '\x71', '\x60', '\x40', '\xeb', '\x64', '\x28', '\x06', '\xae', + '\x9b', '\x32', '\x73', '\xb5', '\xe1', '\x7e', '\x3c', '\x07', '\x31', + '\x8d', '\x82', '\xed', '\x6a', '\xe6', '\x1e', '\x65', '\x9e', '\x81', + '\x29', '\x08', '\x56', '\x17', '\x4b', '\x31', '\xc3', '\xf5', '\x27', + '\xef', '\xb8', '\xda', '\x58', '\xff', '\x36', '\x47', '\x12', '\xb0', + '\xef', '\x14', '\x20', '\x5c', '\x48', '\xb3', '\x84', '\x0d', '\x64', + '\x22', '\x3e', '\xfe', '\x94', '\x17', '\x6c', '\x45', '\xe7', '\x3f', + '\x4c', '\x90', '\x67', '\x13', '\x1a', '\xa8', '\xbc', '\x5b', '\xd0', + '\xc1', '\x8a', '\xa9', '\x42', '\xbe', '\xe4', '\x0e', '\x59', '\x02', + '\x81', '\x81', '\x00', '\xe1', '\x36', '\xcd', '\x86', '\x1e', '\xcb', + '\x8b', '\x68', '\x65', '\x6b', '\x42', '\xec', '\x50', '\x29', '\xa0', + '\xab', '\x3a', '\xe5', '\x6f', '\xe1', '\x13', '\xe8', '\xa3', '\x6b', + '\x7c', '\x2b', '\xd3', '\x69', '\x89', '\x47', '\x07', '\x39', '\xb2', + '\x0f', '\x03', '\x4e', '\x6f', '\x28', '\x94', '\x1d', '\x1f', '\x22', + '\x47', '\xf9', '\x95', '\xff', '\x3e', '\xa4', '\x26', '\x38', '\x07', + '\x5b', '\xdd', '\xef', '\x0a', '\xa5', '\xe8', '\x99', '\xad', '\x91', + '\x68', '\x83', '\xf2', '\xf5', '\xa5', '\x3d', '\x21', '\x88', '\xa5', + '\x6a', '\x39', '\x3b', '\xca', '\x4c', '\xc9', '\xd1', '\x9a', '\x74', + '\xb2', '\xe3', '\x73', '\x5d', '\xfe', '\xbd', '\x05', '\x1b', '\x9a', + '\x13', '\x98', '\x39', '\x93', '\xf3', '\x88', '\x55', '\x61', '\x85', + '\x7a', '\x53', '\x5a', '\xd9', '\x2c', '\xdb', '\x15', '\x69', '\xa6', + '\x31', '\x09', '\xbb', '\xd1', '\xe8', '\x6e', '\x8c', '\x47', '\x77', + '\x1e', '\x9b', '\xbe', '\xb7', '\x57', '\xd4', '\xaa', '\xd5', '\x92', + '\xa1', '\xd5', '\x55', '\x04', '\x93', '\x02', '\x81', '\x80', '\x06', + '\x84', '\x01', '\xff', '\xc0', '\x59', '\xb5', '\x0d', '\xc2', '\xb6', + '\x79', '\x09', '\x80', '\x76', '\x2e', '\x42', '\x1b', '\x44', '\xb0', + '\x8a', '\x99', '\x0a', '\xe2', '\x38', '\xa4', '\xe2', '\xe2', '\x8f', + '\xe7', '\xc6', '\x37', '\x28', '\xd6', '\xf9', '\x0b', '\xee', '\xfc', + '\x09', '\x8f', '\xc8', '\xd1', '\x05', '\x65', '\x7f', '\xc2', '\x23', + '\x05', '\xcf', '\xe8', '\x5a', '\xf3', '\xe0', '\x9d', '\x35', '\xbe', + '\x51', '\x01', '\x8d', '\xe2', '\x49', '\x8e', '\xab', '\x72', '\xc6', + '\xe7', '\x44', '\xa1', '\xbb', '\x2a', '\x3d', '\xb5', '\x96', '\xe0', + '\x2d', '\x21', '\x5c', '\x2e', '\x99', '\x8a', '\x29', '\x56', '\x89', + '\x2f', '\x51', '\x20', '\xca', '\x41', '\x82', '\x00', '\x12', '\x5a', + '\xc6', '\xd1', '\x20', '\xbf', '\xa5', '\x70', '\x2f', '\xb0', '\xa6', + '\x5f', '\x61', '\x8f', '\xfb', '\xc7', '\x50', '\x09', '\x9f', '\xc4', + '\x0d', '\x06', '\x9e', '\x73', '\xe4', '\x0e', '\x8a', '\xce', '\x72', + '\x06', '\xf7', '\xbe', '\x92', '\xcc', '\xcd', '\xcb', '\x5d', '\xc2', + '\x71', '\x02', '\x81', '\x80', '\x26', '\xf3', '\xba', '\x92', '\x52', + '\xeb', '\x33', '\x7e', '\x67', '\xe4', '\x28', '\x5c', '\x04', '\xf5', + '\x5e', '\x33', '\x9f', '\x69', '\x25', '\x73', '\x91', '\x64', '\xf0', + '\x36', '\xdb', '\xf0', '\x1c', '\x8d', '\xa9', '\x4f', '\x9e', '\xa1', + '\x4c', '\xf9', '\xa9', '\xc1', '\xbc', '\x1a', '\x11', '\x9c', '\x03', + '\xd1', '\x83', '\x0f', '\x58', '\xf1', '\x1f', '\x9d', '\x76', '\x7a', + '\xc4', '\x53', '\x10', '\x4c', '\x92', '\xd3', '\xe5', '\x2a', '\x07', + '\x4a', '\x1a', '\x00', '\x90', '\x5a', '\x0a', '\x2d', '\x4b', '\x8a', + '\x7d', '\xc9', '\xa4', '\x82', '\x81', '\xd7', '\xcc', '\x24', '\x33', + '\x89', '\xb1', '\x93', '\x03', '\x56', '\x23', '\x83', '\xff', '\xc9', + '\x29', '\x59', '\xf0', '\x3f', '\x2d', '\x26', '\xb6', '\xd2', '\xc5', + '\x9e', '\x37', '\x6d', '\x09', '\x4e', '\x7c', '\xa2', '\x9b', '\xce', + '\x7d', '\x0f', '\x08', '\x36', '\xf2', '\xf4', '\x37', '\x82', '\x8d', + '\xad', '\xbd', '\x9e', '\x84', '\x5a', '\xe3', '\x97', '\x05', '\xc1', + '\x10', '\xae', '\x6a', '\xde', '\x5c', '\x7f', '\x02', '\x81', '\x81', + '\x00', '\x9b', '\x8e', '\xa4', '\x2b', '\xcf', '\xb6', '\x30', '\x1c', + '\xb5', '\x82', '\x50', '\x08', '\xc0', '\x0b', '\x57', '\xf4', '\x2d', + '\x82', '\x39', '\x11', '\x1b', '\x02', '\xe6', '\xbe', '\x14', '\x26', + '\x77', '\xd7', '\x26', '\x1f', '\x0d', '\x92', '\xc6', '\x67', '\xa0', + '\x01', '\x6c', '\xd9', '\x7a', '\xdf', '\xc3', '\x3d', '\x50', '\x8d', + '\x43', '\xef', '\x95', '\x50', '\x72', '\x25', '\x06', '\x28', '\x7a', + '\x7e', '\x99', '\xea', '\x4d', '\xe8', '\x87', '\xe5', '\xca', '\x71', + '\x36', '\x8a', '\xce', '\x18', '\x55', '\xe4', '\x87', '\x39', '\x3d', + '\xea', '\x9a', '\x22', '\x99', '\x1a', '\xab', '\xe3', '\x6f', '\x48', + '\x78', '\x49', '\x8f', '\xf6', '\xfa', '\xb1', '\xb8', '\x68', '\xae', + '\xc3', '\x47', '\x1d', '\x8f', '\x1d', '\x11', '\xa1', '\x06', '\xf5', + '\xc0', '\x0d', '\xcf', '\x7b', '\x33', '\xfe', '\x0c', '\x69', '\xca', + '\x46', '\xfe', '\x2c', '\xac', '\xd8', '\x4d', '\x02', '\x79', '\xfe', + '\x47', '\xca', '\x21', '\x30', '\x65', '\xa4', '\xe5', '\xaa', '\x4e', + '\x9c', '\xbc', '\xa5'}; + +ABSL_CONST_INIT const absl::string_view kWildcardCertificatePrivateKey( + kWildcardCertificatePrivateKeyRaw, + sizeof(kWildcardCertificatePrivateKeyRaw)); + +ABSL_CONST_INIT const char kTestEcPrivateKeyLegacyPem[] = + R"(-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIMdjXX0hg399DlccZuYFXPKq+dMGduXWmQYClDYJNDGroAoGCCqGSM49 +AwEHoUQDQgAENCuPQTywFI8hbsGo68AeN1KVWmd09buzlu/2CAtsJcNoECUmpVXH +4dwvWMv6zWn9RJ5EzI72R/5FVcO485s5MQ== +-----END EC PRIVATE KEY-----)"; + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/test_certificates.h b/quiche/quic/test_tools/test_certificates.h new file mode 100644 index 000000000000..6a7eba768f17 --- /dev/null +++ b/quiche/quic/test_tools/test_certificates.h @@ -0,0 +1,50 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_TEST_CERTIFICATES_H_ +#define QUICHE_QUIC_TEST_TOOLS_TEST_CERTIFICATES_H_ + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" + +namespace quic { +namespace test { + +// A test certificate generated by //net/tools/quic/certs/generate-certs.sh. +ABSL_CONST_INIT extern const absl::string_view kTestCertificate; + +// PEM-encoded version of |kTestCertificate|. +ABSL_CONST_INIT extern const char kTestCertificatePem[]; + +// |kTestCertificatePem| with a PEM-encoded root appended to the end. +ABSL_CONST_INIT extern const char kTestCertificateChainPem[]; + +// PEM-encoded certificate that contains a subjectAltName with an +// unknown/unsupported type. +ABSL_CONST_INIT extern const char kTestCertWithUnknownSanTypePem[]; + +// DER-encoded private key for |kTestCertificate|. +ABSL_CONST_INIT extern const absl::string_view kTestCertificatePrivateKey; + +// PEM-encoded version of |kTestCertificatePrivateKey|. +ABSL_CONST_INIT extern const char kTestCertificatePrivateKeyPem[]; + +// The legacy PEM-encoded version of |kTestCertificatePrivateKey| manually +// generated from the one above using der2ascii. +ABSL_CONST_INIT extern const char kTestCertificatePrivateKeyLegacyPem[]; + +// Another DER-encoded test certificate, valid for foo.test, www.foo.test and +// *.wildcard.test. +ABSL_CONST_INIT extern const absl::string_view kWildcardCertificate; + +// DER-encoded private key for |kWildcardCertificate|. +ABSL_CONST_INIT extern const absl::string_view kWildcardCertificatePrivateKey; + +// PEM-encoded P-256 private key using legacy OpenSSL encoding. +ABSL_CONST_INIT extern const char kTestEcPrivateKeyLegacyPem[]; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_TEST_CERTIFICATES_H_ diff --git a/quiche/quic/test_tools/test_ticket_crypter.cc b/quiche/quic/test_tools/test_ticket_crypter.cc new file mode 100644 index 000000000000..0a5ec49f41a7 --- /dev/null +++ b/quiche/quic/test_tools/test_ticket_crypter.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/test_ticket_crypter.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/crypto/quic_random.h" + +namespace quic { +namespace test { + +namespace { + +// A TicketCrypter implementation is supposed to encrypt and decrypt session +// tickets. However, the only requirement that is needed of a test +// implementation is that calling Decrypt(Encrypt(input), callback) results in +// callback being called with input. (The output of Encrypt must also not exceed +// the overhead specified by MaxOverhead.) This test implementation encrypts +// tickets by prepending kTicketPrefix to generate the ciphertext. The decrypt +// function checks that the prefix is present and strips it; otherwise it +// returns an empty vector to signal failure. +constexpr char kTicketPrefix[] = "TEST TICKET"; + +} // namespace + +TestTicketCrypter::TestTicketCrypter() + : ticket_prefix_(ABSL_ARRAYSIZE(kTicketPrefix) + 16) { + memcpy(ticket_prefix_.data(), kTicketPrefix, ABSL_ARRAYSIZE(kTicketPrefix)); + QuicRandom::GetInstance()->RandBytes( + ticket_prefix_.data() + ABSL_ARRAYSIZE(kTicketPrefix), 16); +} + +size_t TestTicketCrypter::MaxOverhead() { return ticket_prefix_.size(); } + +std::vector TestTicketCrypter::Encrypt( + absl::string_view in, absl::string_view /* encryption_key */) { + if (fail_encrypt_) { + return {}; + } + size_t prefix_len = ticket_prefix_.size(); + std::vector out(prefix_len + in.size()); + memcpy(out.data(), ticket_prefix_.data(), prefix_len); + memcpy(out.data() + prefix_len, in.data(), in.size()); + return out; +} + +std::vector TestTicketCrypter::Decrypt(absl::string_view in) { + size_t prefix_len = ticket_prefix_.size(); + if (fail_decrypt_ || in.size() < prefix_len || + memcmp(ticket_prefix_.data(), in.data(), prefix_len) != 0) { + return std::vector(); + } + return std::vector(in.begin() + prefix_len, in.end()); +} + +void TestTicketCrypter::Decrypt( + absl::string_view in, + std::shared_ptr callback) { + auto decrypted_ticket = Decrypt(in); + if (run_async_) { + pending_callbacks_.push_back({std::move(callback), decrypted_ticket}); + } else { + callback->Run(decrypted_ticket); + } +} + +void TestTicketCrypter::SetRunCallbacksAsync(bool run_async) { + run_async_ = run_async; +} + +size_t TestTicketCrypter::NumPendingCallbacks() { + return pending_callbacks_.size(); +} + +void TestTicketCrypter::RunPendingCallback(size_t n) { + const PendingCallback& callback = pending_callbacks_[n]; + callback.callback->Run(callback.decrypted_ticket); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/test_ticket_crypter.h b/quiche/quic/test_tools/test_ticket_crypter.h new file mode 100644 index 000000000000..0efd0f920366 --- /dev/null +++ b/quiche/quic/test_tools/test_ticket_crypter.h @@ -0,0 +1,54 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_TEST_TICKET_CRYPTER_H_ +#define QUICHE_QUIC_TEST_TOOLS_TEST_TICKET_CRYPTER_H_ + +#include "quiche/quic/core/crypto/proof_source.h" + +namespace quic { +namespace test { + +// Provides a simple implementation of ProofSource::TicketCrypter for testing. +// THIS IMPLEMENTATION IS NOT SECURE. It is only intended for testing purposes. +class TestTicketCrypter : public ProofSource::TicketCrypter { + public: + TestTicketCrypter(); + ~TestTicketCrypter() override = default; + + // TicketCrypter interface + size_t MaxOverhead() override; + std::vector Encrypt(absl::string_view in, + absl::string_view encryption_key) override; + void Decrypt(absl::string_view in, + std::shared_ptr callback) override; + + void SetRunCallbacksAsync(bool run_async); + size_t NumPendingCallbacks(); + void RunPendingCallback(size_t n); + + // Allows configuring this TestTicketCrypter to fail decryption. + void set_fail_decrypt(bool fail_decrypt) { fail_decrypt_ = fail_decrypt; } + void set_fail_encrypt(bool fail_encrypt) { fail_encrypt_ = fail_encrypt; } + + private: + // Performs the Decrypt operation synchronously. + std::vector Decrypt(absl::string_view in); + + struct PendingCallback { + std::shared_ptr callback; + std::vector decrypted_ticket; + }; + + bool fail_decrypt_ = false; + bool fail_encrypt_ = false; + bool run_async_ = false; + std::vector pending_callbacks_; + std::vector ticket_prefix_; +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_TEST_TICKET_CRYPTER_H_ diff --git a/quiche/quic/test_tools/web_transport_resets_backend.cc b/quiche/quic/test_tools/web_transport_resets_backend.cc new file mode 100644 index 000000000000..63bb8422bd28 --- /dev/null +++ b/quiche/quic/test_tools/web_transport_resets_backend.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/test_tools/web_transport_resets_backend.h" + +#include + +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/tools/web_transport_test_visitors.h" +#include "quiche/common/quiche_circular_deque.h" + +namespace quic { +namespace test { + +namespace { + +class ResetsVisitor; + +class BidirectionalEchoVisitorWithLogging + : public WebTransportBidirectionalEchoVisitor { + public: + BidirectionalEchoVisitorWithLogging(WebTransportStream* stream, + ResetsVisitor* session_visitor) + : WebTransportBidirectionalEchoVisitor(stream), + session_visitor_(session_visitor) {} + + void OnResetStreamReceived(WebTransportStreamError error) override; + void OnStopSendingReceived(WebTransportStreamError error) override; + + private: + ResetsVisitor* session_visitor_; // Not owned. +}; + +class ResetsVisitor : public WebTransportVisitor { + public: + ResetsVisitor(WebTransportSession* session) : session_(session) {} + + void OnSessionReady(const spdy::Http2HeaderBlock& /*headers*/) override {} + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + + void OnIncomingBidirectionalStreamAvailable() override { + while (true) { + WebTransportStream* stream = + session_->AcceptIncomingBidirectionalStream(); + if (stream == nullptr) { + return; + } + stream->SetVisitor( + std::make_unique(stream, this)); + stream->visitor()->OnCanRead(); + } + } + void OnIncomingUnidirectionalStreamAvailable() override {} + + void OnDatagramReceived(absl::string_view /*datagram*/) override {} + + void OnCanCreateNewOutgoingBidirectionalStream() override {} + void OnCanCreateNewOutgoingUnidirectionalStream() override { + MaybeSendLogsBack(); + } + + void Log(std::string line) { + log_.push_back(std::move(line)); + MaybeSendLogsBack(); + } + + private: + void MaybeSendLogsBack() { + while (!log_.empty() && + session_->CanOpenNextOutgoingUnidirectionalStream()) { + WebTransportStream* stream = session_->OpenOutgoingUnidirectionalStream(); + stream->SetVisitor( + std::make_unique( + stream, log_.front())); + log_.pop_front(); + stream->visitor()->OnCanWrite(); + } + } + + WebTransportSession* session_; // Not owned. + quiche::QuicheCircularDeque log_; +}; + +void BidirectionalEchoVisitorWithLogging::OnResetStreamReceived( + WebTransportStreamError error) { + session_visitor_->Log(absl::StrCat("Received reset for stream ", + stream()->GetStreamId(), + " with error code ", error)); + WebTransportBidirectionalEchoVisitor::OnResetStreamReceived(error); +} +void BidirectionalEchoVisitorWithLogging::OnStopSendingReceived( + WebTransportStreamError error) { + session_visitor_->Log(absl::StrCat("Received stop sending for stream ", + stream()->GetStreamId(), + " with error code ", error)); + WebTransportBidirectionalEchoVisitor::OnStopSendingReceived(error); +} + +} // namespace + +QuicSimpleServerBackend::WebTransportResponse WebTransportResetsBackend( + const spdy::Http2HeaderBlock& /*request_headers*/, + WebTransportSession* session) { + QuicSimpleServerBackend::WebTransportResponse response; + response.response_headers[":status"] = "200"; + response.visitor = std::make_unique(session); + return response; +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/test_tools/web_transport_resets_backend.h b/quiche/quic/test_tools/web_transport_resets_backend.h new file mode 100644 index 000000000000..c5ffbe0df7af --- /dev/null +++ b/quiche/quic/test_tools/web_transport_resets_backend.h @@ -0,0 +1,24 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_RESETS_BACKEND_H_ +#define QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_RESETS_BACKEND_H_ + +#include "quiche/quic/test_tools/quic_test_backend.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace test { + +// A backend for testing RESET_STREAM/STOP_SENDING behavior. Provides +// bidirectional echo streams; whenever one of those receives RESET_STREAM or +// STOP_SENDING, a log message is sent as a unidirectional stream. +QuicSimpleServerBackend::WebTransportResponse WebTransportResetsBackend( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session); + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_RESETS_BACKEND_H_ diff --git a/quiche/quic/test_tools/web_transport_test_tools.h b/quiche/quic/test_tools/web_transport_test_tools.h new file mode 100644 index 000000000000..353b48f14a78 --- /dev/null +++ b/quiche/quic/test_tools/web_transport_test_tools.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_TEST_TOOLS_H_ +#define QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_TEST_TOOLS_H_ + +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { +namespace test { + +class MockWebTransportSessionVisitor : public WebTransportVisitor { + public: + MOCK_METHOD(void, OnSessionReady, (const spdy::Http2HeaderBlock&), + (override)); + MOCK_METHOD(void, OnSessionClosed, + (WebTransportSessionError, const std::string&), (override)); + MOCK_METHOD(void, OnIncomingBidirectionalStreamAvailable, (), (override)); + MOCK_METHOD(void, OnIncomingUnidirectionalStreamAvailable, (), (override)); + MOCK_METHOD(void, OnDatagramReceived, (absl::string_view), (override)); + MOCK_METHOD(void, OnCanCreateNewOutgoingBidirectionalStream, (), (override)); + MOCK_METHOD(void, OnCanCreateNewOutgoingUnidirectionalStream, (), (override)); +}; + +class MockWebTransportStreamVisitor : public WebTransportStreamVisitor { + public: + MOCK_METHOD(void, OnCanRead, (), (override)); + MOCK_METHOD(void, OnCanWrite, (), (override)); + + MOCK_METHOD(void, OnResetStreamReceived, (WebTransportStreamError error), + (override)); + MOCK_METHOD(void, OnStopSendingReceived, (WebTransportStreamError error), + (override)); + MOCK_METHOD(void, OnWriteSideInDataRecvdState, (), (override)); +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_TEST_TOOLS_H_ diff --git a/quiche/quic/tools/connect_server_backend.cc b/quiche/quic/tools/connect_server_backend.cc new file mode 100644 index 000000000000..cc46571ea90f --- /dev/null +++ b/quiche/quic/tools/connect_server_backend.cc @@ -0,0 +1,164 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_server_backend.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/tools/connect_tunnel.h" +#include "quiche/quic/tools/connect_udp_tunnel.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace { + +void SendErrorResponse(QuicSimpleServerBackend::RequestHandler* request_handler, + absl::string_view error_code) { + spdy::Http2HeaderBlock headers; + headers[":status"] = error_code; + QuicBackendResponse response; + response.set_headers(std::move(headers)); + request_handler->OnResponseBackendComplete(&response); +} + +} // namespace + +ConnectServerBackend::ConnectServerBackend( + std::unique_ptr non_connect_backend, + absl::flat_hash_set acceptable_connect_destinations, + absl::flat_hash_set acceptable_connect_udp_targets, + std::string server_label) + : non_connect_backend_(std::move(non_connect_backend)), + acceptable_connect_destinations_( + std::move(acceptable_connect_destinations)), + acceptable_connect_udp_targets_( + std::move(acceptable_connect_udp_targets)), + server_label_(std::move(server_label)) { + QUICHE_DCHECK(non_connect_backend_); + QUICHE_DCHECK(!server_label_.empty()); +} + +ConnectServerBackend::~ConnectServerBackend() { + // Expect all streams to be closed before destroying backend. + QUICHE_DCHECK(connect_tunnels_.empty()); + QUICHE_DCHECK(connect_udp_tunnels_.empty()); +} + +bool ConnectServerBackend::InitializeBackend(const std::string&) { + return true; +} + +bool ConnectServerBackend::IsBackendInitialized() const { return true; } + +void ConnectServerBackend::SetSocketFactory(SocketFactory* socket_factory) { + QUICHE_DCHECK_NE(socket_factory_, socket_factory); + QUICHE_DCHECK(connect_tunnels_.empty()); + QUICHE_DCHECK(connect_udp_tunnels_.empty()); + socket_factory_ = socket_factory; +} + +void ConnectServerBackend::FetchResponseFromBackend( + const spdy::Http2HeaderBlock& request_headers, + const std::string& request_body, RequestHandler* request_handler) { + // Not a CONNECT request, so send to `non_connect_backend_`. + non_connect_backend_->FetchResponseFromBackend(request_headers, request_body, + request_handler); +} + +void ConnectServerBackend::HandleConnectHeaders( + const spdy::Http2HeaderBlock& request_headers, + RequestHandler* request_handler) { + QUICHE_DCHECK(request_headers.contains(":method") && + request_headers.find(":method")->second == "CONNECT"); + + if (!socket_factory_) { + QUICHE_BUG(connect_server_backend_no_socket_factory) + << "Must set socket factory before ConnectServerBackend receives " + "requests."; + SendErrorResponse(request_handler, "500"); + return; + } + + if (!request_headers.contains(":protocol")) { + // normal CONNECT + auto [tunnel_it, inserted] = connect_tunnels_.emplace( + request_handler->stream_id(), + std::make_unique(request_handler, socket_factory_, + acceptable_connect_destinations_)); + QUICHE_DCHECK(inserted); + + tunnel_it->second->OpenTunnel(request_headers); + } else if (request_headers.find(":protocol")->second == "connect-udp") { + // CONNECT-UDP + auto [tunnel_it, inserted] = connect_udp_tunnels_.emplace( + request_handler->stream_id(), + std::make_unique(request_handler, socket_factory_, + server_label_, + acceptable_connect_udp_targets_)); + QUICHE_DCHECK(inserted); + + tunnel_it->second->OpenTunnel(request_headers); + } else { + // Not a supported request. + non_connect_backend_->HandleConnectHeaders(request_headers, + request_handler); + } +} + +void ConnectServerBackend::HandleConnectData(absl::string_view data, + bool data_complete, + RequestHandler* request_handler) { + // Expect ConnectUdpTunnels to register a datagram visitor, causing the + // stream to process data as capsules. HandleConnectData() should therefore + // never be called for streams with a ConnectUdpTunnel. + QUICHE_DCHECK(!connect_udp_tunnels_.contains(request_handler->stream_id())); + + auto tunnel_it = connect_tunnels_.find(request_handler->stream_id()); + if (tunnel_it == connect_tunnels_.end()) { + // If tunnel not found, perhaps it's something being handled for + // non-CONNECT. Possible because this method could be called for anything + // with a ":method":"CONNECT" header, but this class does not handle such + // requests if they have a ":protocol" header. + non_connect_backend_->HandleConnectData(data, data_complete, + request_handler); + return; + } + + if (!data.empty()) { + tunnel_it->second->SendDataToDestination(data); + } + if (data_complete) { + tunnel_it->second->OnClientStreamClose(); + connect_tunnels_.erase(tunnel_it); + } +} + +void ConnectServerBackend::CloseBackendResponseStream( + QuicSimpleServerBackend::RequestHandler* request_handler) { + auto tunnel_it = connect_tunnels_.find(request_handler->stream_id()); + if (tunnel_it != connect_tunnels_.end()) { + tunnel_it->second->OnClientStreamClose(); + connect_tunnels_.erase(tunnel_it); + } + + auto udp_tunnel_it = connect_udp_tunnels_.find(request_handler->stream_id()); + if (udp_tunnel_it != connect_udp_tunnels_.end()) { + udp_tunnel_it->second->OnClientStreamClose(); + connect_udp_tunnels_.erase(udp_tunnel_it); + } + + non_connect_backend_->CloseBackendResponseStream(request_handler); +} + +} // namespace quic diff --git a/quiche/quic/tools/connect_server_backend.h b/quiche/quic/tools/connect_server_backend.h new file mode 100644 index 000000000000..c3dcab6cfc65 --- /dev/null +++ b/quiche/quic/tools/connect_server_backend.h @@ -0,0 +1,70 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CONNECT_PROXY_CONNECT_SERVER_BACKEND_H_ +#define QUICHE_QUIC_CONNECT_PROXY_CONNECT_SERVER_BACKEND_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/tools/connect_tunnel.h" +#include "quiche/quic/tools/connect_udp_tunnel.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" + +namespace quic { + +// QUIC server backend that handles CONNECT and CONNECT-UDP requests. +// Non-CONNECT requests are delegated to a separate backend. +class ConnectServerBackend : public QuicSimpleServerBackend { + public: + // `server_label` is an identifier (typically randomly generated) to identify + // the server or backend in error headers, per the requirements of RFC 9209, + // Section 2. + ConnectServerBackend( + std::unique_ptr non_connect_backend, + absl::flat_hash_set acceptable_connect_destinations, + absl::flat_hash_set acceptable_connect_udp_targets, + std::string server_label); + + ConnectServerBackend(const ConnectServerBackend&) = delete; + ConnectServerBackend& operator=(const ConnectServerBackend&) = delete; + + ~ConnectServerBackend() override; + + // QuicSimpleServerBackend: + bool InitializeBackend(const std::string& backend_url) override; + bool IsBackendInitialized() const override; + void SetSocketFactory(SocketFactory* socket_factory) override; + void FetchResponseFromBackend(const spdy::Http2HeaderBlock& request_headers, + const std::string& request_body, + RequestHandler* request_handler) override; + void HandleConnectHeaders(const spdy::Http2HeaderBlock& request_headers, + RequestHandler* request_handler) override; + void HandleConnectData(absl::string_view data, bool data_complete, + RequestHandler* request_handler) override; + void CloseBackendResponseStream( + QuicSimpleServerBackend::RequestHandler* request_handler) override; + + private: + std::unique_ptr non_connect_backend_; + const absl::flat_hash_set acceptable_connect_destinations_; + const absl::flat_hash_set acceptable_connect_udp_targets_; + const std::string server_label_; + + SocketFactory* socket_factory_; // unowned + absl::flat_hash_map> + connect_tunnels_; + absl::flat_hash_map> + connect_udp_tunnels_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CONNECT_PROXY_CONNECT_SERVER_BACKEND_H_ diff --git a/quiche/quic/tools/connect_tunnel.cc b/quiche/quic/tools/connect_tunnel.cc new file mode 100644 index 000000000000..d31e5403eec6 --- /dev/null +++ b/quiche/quic/tools/connect_tunnel.cc @@ -0,0 +1,295 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_tunnel.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_name_lookup.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace { + +// Arbitrarily chosen. No effort has been made to figure out an optimal size. +constexpr size_t kReadSize = 4 * 1024; + +absl::optional ValidateHeadersAndGetAuthority( + const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(request_headers.contains(":method")); + QUICHE_DCHECK(request_headers.find(":method")->second == "CONNECT"); + QUICHE_DCHECK(!request_headers.contains(":protocol")); + + auto scheme_it = request_headers.find(":scheme"); + if (scheme_it != request_headers.end()) { + QUICHE_DVLOG(1) << "CONNECT request contains unexpected scheme: " + << scheme_it->second; + return absl::nullopt; + } + + auto path_it = request_headers.find(":path"); + if (path_it != request_headers.end()) { + QUICHE_DVLOG(1) << "CONNECT request contains unexpected path: " + << path_it->second; + return absl::nullopt; + } + + auto authority_it = request_headers.find(":authority"); + if (authority_it == request_headers.end() || authority_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT request missing authority"; + return absl::nullopt; + } + + // A valid CONNECT authority must contain host and port and nothing else, per + // https://www.rfc-editor.org/rfc/rfc9110.html#name-connect. This matches the + // host and port parsing rules for QuicServerId. + absl::optional server_id = + QuicServerId::ParseFromHostPortString(authority_it->second); + if (!server_id.has_value()) { + QUICHE_DVLOG(1) << "CONNECT request authority is malformed: " + << authority_it->second; + return absl::nullopt; + } + + return server_id; +} + +bool ValidateAuthority( + const QuicServerId& authority, + const absl::flat_hash_set& acceptable_destinations) { + if (acceptable_destinations.contains(authority)) { + return true; + } + + QUICHE_DVLOG(1) << "CONNECT request authority: " + << authority.ToHostPortString() + << " is not an acceptable allow-listed destiation "; + return false; +} + +} // namespace + +ConnectTunnel::ConnectTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, + absl::flat_hash_set acceptable_destinations) + : acceptable_destinations_(std::move(acceptable_destinations)), + socket_factory_(socket_factory), + client_stream_request_handler_(client_stream_request_handler) { + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(socket_factory_); +} + +ConnectTunnel::~ConnectTunnel() { + // Expect client and destination sides of tunnel to both be closed before + // destruction. + QUICHE_DCHECK_EQ(client_stream_request_handler_, nullptr); + QUICHE_DCHECK(!IsConnectedToDestination()); + QUICHE_DCHECK(!receive_started_); +} + +void ConnectTunnel::OpenTunnel(const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(!IsConnectedToDestination()); + + absl::optional authority = + ValidateHeadersAndGetAuthority(request_headers); + if (!authority.has_value()) { + TerminateClientStream( + "invalid request headers", + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::MESSAGE_ERROR)); + return; + } + + if (!ValidateAuthority(authority.value(), acceptable_destinations_)) { + TerminateClientStream( + "disallowed request authority", + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::REQUEST_REJECTED)); + return; + } + + QuicSocketAddress address = + tools::LookupAddress(AF_UNSPEC, authority.value()); + if (!address.IsInitialized()) { + TerminateClientStream("host resolution error"); + return; + } + + destination_socket_ = + socket_factory_->CreateTcpClientSocket(address, + /*receive_buffer_size=*/0, + /*send_buffer_size=*/0, + /*async_visitor=*/this); + QUICHE_DCHECK(destination_socket_); + + absl::Status connect_result = destination_socket_->ConnectBlocking(); + if (!connect_result.ok()) { + TerminateClientStream( + "error connecting TCP socket to destination server: " + + connect_result.ToString()); + return; + } + + QUICHE_DVLOG(1) << "CONNECT tunnel opened from stream " + << client_stream_request_handler_->stream_id() << " to " + << authority.value().ToHostPortString(); + + SendConnectResponse(); + BeginAsyncReadFromDestination(); +} + +bool ConnectTunnel::IsConnectedToDestination() const { + return !!destination_socket_; +} + +void ConnectTunnel::SendDataToDestination(absl::string_view data) { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(!data.empty()); + + absl::Status send_result = + destination_socket_->SendBlocking(std::string(data)); + if (!send_result.ok()) { + TerminateClientStream("TCP error sending data to destination server: " + + send_result.ToString()); + } +} + +void ConnectTunnel::OnClientStreamClose() { + QUICHE_DCHECK(client_stream_request_handler_); + + QUICHE_DVLOG(1) << "CONNECT stream " + << client_stream_request_handler_->stream_id() << " closed"; + + client_stream_request_handler_ = nullptr; + + if (IsConnectedToDestination()) { + // TODO(ericorth): Consider just calling shutdown() on the socket rather + // than fully disconnecting in order to allow a graceful TCP FIN stream + // shutdown per + // https://www.rfc-editor.org/rfc/rfc9114.html#name-the-connect-method. + // Would require shutdown support in the socket library, and would need to + // deal with the tunnel/socket outliving the client stream. + destination_socket_->Disconnect(); + } + + // Clear socket pointer. + destination_socket_.reset(); +} + +void ConnectTunnel::ConnectComplete(absl::Status /*status*/) { + // Async connect not expected. + QUICHE_NOTREACHED(); +} + +void ConnectTunnel::ReceiveComplete( + absl::StatusOr data) { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(receive_started_); + + receive_started_ = false; + + if (!data.ok()) { + if (client_stream_request_handler_) { + TerminateClientStream("TCP error receiving data from destination server"); + } else { + // This typically just means a receive operation was cancelled on calling + // destination_socket_->Disconnect(). + QUICHE_DVLOG(1) << "TCP error receiving data from destination server " + "after stream already closed."; + } + return; + } else if (data.value().empty()) { + OnDestinationConnectionClosed(); + return; + } + + QUICHE_DCHECK(client_stream_request_handler_); + client_stream_request_handler_->SendStreamData(data.value().AsStringView(), + /*close_stream=*/false); + + BeginAsyncReadFromDestination(); +} + +void ConnectTunnel::SendComplete(absl::Status /*status*/) { + // Async send not expected. + QUICHE_NOTREACHED(); +} + +void ConnectTunnel::BeginAsyncReadFromDestination() { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(!receive_started_); + + receive_started_ = true; + destination_socket_->ReceiveAsync(kReadSize); +} + +void ConnectTunnel::OnDestinationConnectionClosed() { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(client_stream_request_handler_); + + QUICHE_DVLOG(1) << "CONNECT stream " + << client_stream_request_handler_->stream_id() + << " destination connection closed"; + destination_socket_->Disconnect(); + + // Clear socket pointer. + destination_socket_.reset(); + + // Extra check that nothing in the Disconnect could lead to terminating the + // stream. + QUICHE_DCHECK(client_stream_request_handler_); + + client_stream_request_handler_->SendStreamData("", /*close_stream=*/true); +} + +void ConnectTunnel::SendConnectResponse() { + QUICHE_DCHECK(IsConnectedToDestination()); + QUICHE_DCHECK(client_stream_request_handler_); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + + QuicBackendResponse response; + response.set_headers(std::move(response_headers)); + // Need to leave the stream open after sending the CONNECT response. + response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + + client_stream_request_handler_->OnResponseBackendComplete(&response); +} + +void ConnectTunnel::TerminateClientStream(absl::string_view error_description, + QuicResetStreamError error_code) { + QUICHE_DCHECK(client_stream_request_handler_); + + std::string error_description_str = + error_description.empty() ? "" + : absl::StrCat(" due to ", error_description); + QUICHE_DVLOG(1) << "Terminating CONNECT stream " + << client_stream_request_handler_->stream_id() + << " with error code " << error_code.ietf_application_code() + << error_description_str; + + client_stream_request_handler_->TerminateStreamWithError(error_code); +} + +} // namespace quic diff --git a/quiche/quic/tools/connect_tunnel.h b/quiche/quic/tools/connect_tunnel.h new file mode 100644 index 000000000000..259f12542e34 --- /dev/null +++ b/quiche/quic/tools/connect_tunnel.h @@ -0,0 +1,89 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ +#define QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// Manages a single connection tunneled over a CONNECT proxy. +class ConnectTunnel : public ConnectingClientSocket::AsyncVisitor { + public: + // `client_stream_request_handler` and `socket_factory` must both outlive the + // created ConnectTunnel. + ConnectTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, + absl::flat_hash_set acceptable_destinations); + ~ConnectTunnel(); + ConnectTunnel(const ConnectTunnel&) = delete; + ConnectTunnel& operator=(const ConnectTunnel&) = delete; + + // Attempts to open TCP connection to destination server and then sends + // appropriate success/error response to the request stream. `request_headers` + // must represent headers from a CONNECT request, that is ":method"="CONNECT" + // and no ":protocol". + void OpenTunnel(const spdy::Http2HeaderBlock& request_headers); + + // Returns true iff the connection to the destination server is currently open + bool IsConnectedToDestination() const; + + void SendDataToDestination(absl::string_view data); + + // Called when the client stream has been closed. Connection to destination + // server is closed if connected. The RequestHandler will no longer be + // interacted with after completion. + void OnClientStreamClose(); + + // ConnectingClientSocket::AsyncVisitor: + void ConnectComplete(absl::Status status) override; + void ReceiveComplete(absl::StatusOr data) override; + void SendComplete(absl::Status status) override; + + private: + void BeginAsyncReadFromDestination(); + void OnDataReceivedFromDestination(bool success); + + // For normal (FIN) closure. Errors (RST) should result in directly calling + // TerminateClientStream(). + void OnDestinationConnectionClosed(); + + void SendConnectResponse(); + void TerminateClientStream( + absl::string_view error_description, + QuicResetStreamError error_code = + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::CONNECT_ERROR)); + + const absl::flat_hash_set acceptable_destinations_; + SocketFactory* const socket_factory_; + + // Null when client stream closed. + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler_; + + // Null when destination connection disconnected. + std::unique_ptr destination_socket_; + + bool receive_started_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_CONNECT_TUNNEL_H_ diff --git a/quiche/quic/tools/connect_tunnel_test.cc b/quiche/quic/tools/connect_tunnel_test.cc new file mode 100644 index 000000000000..379da7c4923a --- /dev/null +++ b/quiche/quic/tools/connect_tunnel_test.cc @@ -0,0 +1,353 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_tunnel.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::AllOf; +using ::testing::AnyOf; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::InvokeWithoutArgs; +using ::testing::IsEmpty; +using ::testing::Matcher; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Property; +using ::testing::Return; +using ::testing::StrictMock; + +class MockRequestHandler : public QuicSimpleServerBackend::RequestHandler { + public: + QuicConnectionId connection_id() const override { + return TestConnectionId(41212); + } + QuicStreamId stream_id() const override { return 100; } + std::string peer_host() const override { return "127.0.0.1"; } + + MOCK_METHOD(QuicSpdyStream*, GetStream, (), (override)); + MOCK_METHOD(void, OnResponseBackendComplete, + (const QuicBackendResponse* response), (override)); + MOCK_METHOD(void, SendStreamData, (absl::string_view data, bool close_stream), + (override)); + MOCK_METHOD(void, TerminateStreamWithError, (QuicResetStreamError error), + (override)); +}; + +class MockSocketFactory : public SocketFactory { + public: + MOCK_METHOD(std::unique_ptr, CreateTcpClientSocket, + (const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor), + (override)); + MOCK_METHOD(std::unique_ptr, + CreateConnectingUdpClientSocket, + (const quic::QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor), + (override)); +}; + +class MockSocket : public ConnectingClientSocket { + public: + MOCK_METHOD(absl::Status, ConnectBlocking, (), (override)); + MOCK_METHOD(void, ConnectAsync, (), (override)); + MOCK_METHOD(void, Disconnect, (), (override)); + MOCK_METHOD(absl::StatusOr, GetLocalAddress, (), + (override)); + MOCK_METHOD(absl::StatusOr, ReceiveBlocking, + (QuicByteCount max_size), (override)); + MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (std::string data), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (quiche::QuicheMemSlice data), + (override)); + MOCK_METHOD(void, SendAsync, (std::string data), (override)); + MOCK_METHOD(void, SendAsync, (quiche::QuicheMemSlice data), (override)); +}; + +class ConnectTunnelTest : public quiche::test::QuicheTest { + public: + void SetUp() override { + auto socket = std::make_unique>(); + socket_ = socket.get(); + ON_CALL(socket_factory_, + CreateTcpClientSocket( + AnyOf(QuicSocketAddress(TestLoopback4(), kAcceptablePort), + QuicSocketAddress(TestLoopback6(), kAcceptablePort)), + _, _, &tunnel_)) + .WillByDefault(Return(ByMove(std::move(socket)))); + } + + protected: + static constexpr absl::string_view kAcceptableDestination = "localhost"; + static constexpr uint16_t kAcceptablePort = 977; + + StrictMock request_handler_; + NiceMock socket_factory_; + StrictMock* socket_; + + ConnectTunnel tunnel_{ + &request_handler_, + &socket_factory_, + /*acceptable_destinations=*/ + {{std::string(kAcceptableDestination), kAcceptablePort}, + {TestLoopback4().ToString(), kAcceptablePort}, + {absl::StrCat("[", TestLoopback6().ToString(), "]"), kAcceptablePort}}}; +}; + +TEST_F(ConnectTunnelTest, OpenTunnel) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + spdy::Http2HeaderBlock expected_response_headers; + expected_response_headers[":status"] = "200"; + QuicBackendResponse expected_response; + expected_response.set_headers(std::move(expected_response_headers)); + expected_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + EXPECT_CALL(request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + ElementsAre(Pair(":status", "200"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); +} + +TEST_F(ConnectTunnelTest, OpenTunnelToIpv4LiteralDestination) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + spdy::Http2HeaderBlock expected_response_headers; + expected_response_headers[":status"] = "200"; + QuicBackendResponse expected_response; + expected_response.set_headers(std::move(expected_response_headers)); + expected_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + EXPECT_CALL(request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + ElementsAre(Pair(":status", "200"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(TestLoopback4().ToString(), ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); +} + +TEST_F(ConnectTunnelTest, OpenTunnelToIpv6LiteralDestination) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + spdy::Http2HeaderBlock expected_response_headers; + expected_response_headers[":status"] = "200"; + QuicBackendResponse expected_response; + expected_response.set_headers(std::move(expected_response_headers)); + expected_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + EXPECT_CALL(request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + ElementsAre(Pair(":status", "200"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat("[", TestLoopback6().ToString(), "]:", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); +} + +TEST_F(ConnectTunnelTest, OpenTunnelWithMalformedRequest) { + EXPECT_CALL(request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast(QuicHttp3ErrorCode::MESSAGE_ERROR)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + // No ":authority" header. + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, OpenTunnelWithUnacceptableDestination) { + EXPECT_CALL( + request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast(QuicHttp3ErrorCode::REQUEST_REJECTED)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = "unacceptable.test:100"; + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, ReceiveFromDestination) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Ge(kData.size()))).Times(2); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + EXPECT_CALL(request_handler_, SendStreamData(kData, /*close_stream=*/false)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receiving `kData`. + tunnel_.ReceiveComplete(MemSliceFromString(kData)); + + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, SendToDestination) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, SendBlocking(Matcher(Eq(kData)))) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + tunnel_.SendDataToDestination(kData); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, DestinationDisconnect) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + EXPECT_CALL(request_handler_, SendStreamData("", /*close_stream=*/true)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receiving empty data. + tunnel_.ReceiveComplete(quiche::QuicheMemSlice()); + + EXPECT_FALSE(tunnel_.IsConnectedToDestination()); + + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectTunnelTest, DestinationTcpConnectionError) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + EXPECT_CALL(request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast(QuicHttp3ErrorCode::CONNECT_ERROR)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":authority"] = + absl::StrCat(kAcceptableDestination, ":", kAcceptablePort); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receving error. + tunnel_.ReceiveComplete(absl::UnknownError("error")); + + tunnel_.OnClientStreamClose(); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/tools/connect_udp_tunnel.cc b/quiche/quic/tools/connect_udp_tunnel.cc new file mode 100644 index 000000000000..96612fc59188 --- /dev/null +++ b/quiche/quic/tools/connect_udp_tunnel.cc @@ -0,0 +1,424 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_udp_tunnel.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "url/url_canon.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_name_lookup.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/masque/connect_udp_datagram_payload.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_url_utils.h" +#include "quiche/common/structured_headers.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace structured_headers = quiche::structured_headers; + +namespace { + +// Arbitrarily chosen. No effort has been made to figure out an optimal size. +constexpr size_t kReadSize = 4 * 1024; + +// Only support the default path +// ("/.well-known/masque/udp/{target_host}/{target_port}/") +absl::optional ValidateAndParseTargetFromPath( + absl::string_view path) { + std::string canonicalized_path_str; + url::StdStringCanonOutput canon_output(&canonicalized_path_str); + url::Component path_component; + url::CanonicalizePath(path.data(), url::Component(0, path.size()), + &canon_output, &path_component); + if (!path_component.is_nonempty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with non-canonicalizable path: " + << path; + return absl::nullopt; + } + canon_output.Complete(); + absl::string_view canonicalized_path = + absl::string_view(canonicalized_path_str) + .substr(path_component.begin, path_component.len); + + std::vector path_split = + absl::StrSplit(canonicalized_path, '/'); + if (path_split.size() != 7 || !path_split[0].empty() || + path_split[1] != ".well-known" || path_split[2] != "masque" || + path_split[3] != "udp" || path_split[4].empty() || + path_split[5].empty() || !path_split[6].empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with bad path: " + << canonicalized_path; + return absl::nullopt; + } + + absl::optional decoded_host = + quiche::AsciiUrlDecode(path_split[4]); + if (!decoded_host.has_value()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with undecodable host: " + << path_split[4]; + return absl::nullopt; + } + // Empty host checked above after path split. Expect decoding to never result + // in an empty decoded host from non-empty encoded host. + QUICHE_DCHECK(!decoded_host.value().empty()); + + absl::optional decoded_port = + quiche::AsciiUrlDecode(path_split[5]); + if (!decoded_port.has_value()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with undecodable port: " + << path_split[5]; + return absl::nullopt; + } + // Empty port checked above after path split. Expect decoding to never result + // in an empty decoded port from non-empty encoded port. + QUICHE_DCHECK(!decoded_port.value().empty()); + + int parsed_port_number = + url::ParsePort(decoded_port.value().data(), + url::Component(0, decoded_port.value().size())); + // Negative result is either invalid or unspecified, either of which is + // disallowed for this parse. Port 0 is technically valid but reserved and not + // really usable in practice, so easiest to just disallow it here. + if (parsed_port_number <= 0) { + QUICHE_DVLOG(1) << "CONNECT-UDP request with bad port: " + << decoded_port.value(); + return absl::nullopt; + } + // Expect url::ParsePort() to validate port is uint16_t and otherwise return + // negative number checked for above. + QUICHE_DCHECK_LE(parsed_port_number, std::numeric_limits::max()); + + return QuicServerId(decoded_host.value(), + static_cast(parsed_port_number)); +} + +// Validate header expectations from RFC 9298, section 3.4. +absl::optional ValidateHeadersAndGetTarget( + const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(request_headers.contains(":method")); + QUICHE_DCHECK(request_headers.find(":method")->second == "CONNECT"); + QUICHE_DCHECK(request_headers.contains(":protocol")); + QUICHE_DCHECK(request_headers.find(":protocol")->second == "connect-udp"); + + auto authority_it = request_headers.find(":authority"); + if (authority_it == request_headers.end() || authority_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request missing authority"; + return absl::nullopt; + } + // For toy server simplicity, skip validating that the authority matches the + // current server. + + auto scheme_it = request_headers.find(":scheme"); + if (scheme_it == request_headers.end() || scheme_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request missing scheme"; + return absl::nullopt; + } else if (scheme_it->second != "https") { + QUICHE_DVLOG(1) << "CONNECT-UDP request contains unexpected scheme: " + << scheme_it->second; + return absl::nullopt; + } + + auto path_it = request_headers.find(":path"); + if (path_it == request_headers.end() || path_it->second.empty()) { + QUICHE_DVLOG(1) << "CONNECT-UDP request missing path"; + return absl::nullopt; + } + absl::optional target_server_id = + ValidateAndParseTargetFromPath(path_it->second); + + return target_server_id; +} + +bool ValidateTarget( + const QuicServerId& target, + const absl::flat_hash_set& acceptable_targets) { + if (acceptable_targets.contains(target)) { + return true; + } + + QUICHE_DVLOG(1) + << "CONNECT-UDP request target is not an acceptable allow-listed target: " + << target.ToHostPortString(); + return false; +} + +} // namespace + +ConnectUdpTunnel::ConnectUdpTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, std::string server_label, + absl::flat_hash_set acceptable_targets) + : acceptable_targets_(std::move(acceptable_targets)), + socket_factory_(socket_factory), + server_label_(std::move(server_label)), + client_stream_request_handler_(client_stream_request_handler) { + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(socket_factory_); + QUICHE_DCHECK(!server_label_.empty()); +} + +ConnectUdpTunnel::~ConnectUdpTunnel() { + // Expect client and target sides of tunnel to both be closed before + // destruction. + QUICHE_DCHECK(!IsTunnelOpenToTarget()); + QUICHE_DCHECK(!receive_started_); + QUICHE_DCHECK(!datagram_visitor_registered_); +} + +void ConnectUdpTunnel::OpenTunnel( + const spdy::Http2HeaderBlock& request_headers) { + QUICHE_DCHECK(!IsTunnelOpenToTarget()); + + absl::optional target = + ValidateHeadersAndGetTarget(request_headers); + if (!target.has_value()) { + // Malformed request. + TerminateClientStream( + "invalid request headers", + QuicResetStreamError::FromIetf(QuicHttp3ErrorCode::MESSAGE_ERROR)); + return; + } + + if (!ValidateTarget(target.value(), acceptable_targets_)) { + SendErrorResponse("403", "destination_ip_prohibited", + "disallowed proxy target"); + return; + } + + // TODO(ericorth): Validate that the IP address doesn't fall into diallowed + // ranges per RFC 9298, Section 7. + QuicSocketAddress address = tools::LookupAddress(AF_UNSPEC, target.value()); + if (!address.IsInitialized()) { + SendErrorResponse("500", "dns_error", "host resolution error"); + return; + } + + target_socket_ = socket_factory_->CreateConnectingUdpClientSocket( + address, + /*receive_buffer_size=*/0, + /*send_buffer_size=*/0, + /*async_visitor=*/this); + QUICHE_DCHECK(target_socket_); + + absl::Status connect_result = target_socket_->ConnectBlocking(); + if (!connect_result.ok()) { + SendErrorResponse( + "502", "destination_ip_unroutable", + absl::StrCat("UDP socket error: ", connect_result.ToString())); + return; + } + + QUICHE_DVLOG(1) << "CONNECT-UDP tunnel opened from stream " + << client_stream_request_handler_->stream_id() << " to " + << target.value().ToHostPortString(); + + client_stream_request_handler_->GetStream()->RegisterHttp3DatagramVisitor( + this); + datagram_visitor_registered_ = true; + + SendConnectResponse(); + BeginAsyncReadFromTarget(); +} + +bool ConnectUdpTunnel::IsTunnelOpenToTarget() const { return !!target_socket_; } + +void ConnectUdpTunnel::OnClientStreamClose() { + QUICHE_CHECK(client_stream_request_handler_); + + QUICHE_DVLOG(1) << "CONNECT-UDP stream " + << client_stream_request_handler_->stream_id() << " closed"; + + if (datagram_visitor_registered_) { + client_stream_request_handler_->GetStream() + ->UnregisterHttp3DatagramVisitor(); + datagram_visitor_registered_ = false; + } + client_stream_request_handler_ = nullptr; + + if (IsTunnelOpenToTarget()) { + target_socket_->Disconnect(); + } + + // Clear socket pointer. + target_socket_.reset(); +} + +void ConnectUdpTunnel::ConnectComplete(absl::Status /*status*/) { + // Async connect not expected. + QUICHE_NOTREACHED(); +} + +void ConnectUdpTunnel::ReceiveComplete( + absl::StatusOr data) { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK(receive_started_); + + receive_started_ = false; + + if (!data.ok()) { + if (client_stream_request_handler_) { + QUICHE_LOG(WARNING) << "Error receiving CONNECT-UDP data from target: " + << data.status(); + } else { + // This typically just means a receive operation was cancelled on calling + // target_socket_->Disconnect(). + QUICHE_DVLOG(1) << "Error receiving CONNECT-UDP data from target after " + "stream already closed."; + } + return; + } + + QUICHE_DCHECK(client_stream_request_handler_); + quiche::ConnectUdpDatagramUdpPacketPayload payload( + data.value().AsStringView()); + client_stream_request_handler_->GetStream()->SendHttp3Datagram( + payload.Serialize()); + + BeginAsyncReadFromTarget(); +} + +void ConnectUdpTunnel::SendComplete(absl::Status /*status*/) { + // Async send not expected. + QUICHE_NOTREACHED(); +} + +void ConnectUdpTunnel::OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK_EQ(stream_id, client_stream_request_handler_->stream_id()); + QUICHE_DCHECK(!payload.empty()); + + std::unique_ptr parsed_payload = + quiche::ConnectUdpDatagramPayload::Parse(payload); + if (!parsed_payload) { + QUICHE_DVLOG(1) << "Ignoring HTTP Datagram payload, due to inability to " + "parse as CONNECT-UDP payload."; + return; + } + + switch (parsed_payload->GetType()) { + case quiche::ConnectUdpDatagramPayload::Type::kUdpPacket: + SendUdpPacketToTarget(parsed_payload->GetUdpProxyingPayload()); + break; + case quiche::ConnectUdpDatagramPayload::Type::kUnknown: + QUICHE_DVLOG(1) + << "Ignoring HTTP Datagram payload with unrecognized context ID."; + } +} + +void ConnectUdpTunnel::BeginAsyncReadFromTarget() { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK(client_stream_request_handler_); + QUICHE_DCHECK(!receive_started_); + + receive_started_ = true; + target_socket_->ReceiveAsync(kReadSize); +} + +void ConnectUdpTunnel::SendUdpPacketToTarget(absl::string_view packet) { + absl::Status send_result = target_socket_->SendBlocking(std::string(packet)); + if (!send_result.ok()) { + QUICHE_LOG(WARNING) << "Error sending CONNECT-UDP datagram to target: " + << send_result; + } +} + +void ConnectUdpTunnel::SendConnectResponse() { + QUICHE_DCHECK(IsTunnelOpenToTarget()); + QUICHE_DCHECK(client_stream_request_handler_); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + + absl::optional capsule_protocol_value = + structured_headers::SerializeItem(structured_headers::Item(true)); + QUICHE_CHECK(capsule_protocol_value.has_value()); + response_headers["Capsule-Protocol"] = capsule_protocol_value.value(); + + QuicBackendResponse response; + response.set_headers(std::move(response_headers)); + // Need to leave the stream open after sending the CONNECT response. + response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + + client_stream_request_handler_->OnResponseBackendComplete(&response); +} + +void ConnectUdpTunnel::SendErrorResponse(absl::string_view status, + absl::string_view proxy_status_error, + absl::string_view error_details) { + QUICHE_DCHECK(!status.empty()); + QUICHE_DCHECK(!proxy_status_error.empty()); + QUICHE_DCHECK(!error_details.empty()); + QUICHE_DCHECK(client_stream_request_handler_); + +#ifndef NDEBUG + // Expect a valid status code (number, 100 to 599 inclusive) and not a + // Successful code (200 to 299 inclusive). + int status_num = 0; + bool is_num = absl::SimpleAtoi(status, &status_num); + QUICHE_DCHECK(is_num); + QUICHE_DCHECK_GE(status_num, 100); + QUICHE_DCHECK_LT(status_num, 600); + QUICHE_DCHECK(status_num < 200 || status_num >= 300); +#endif // !NDEBUG + + spdy::Http2HeaderBlock headers; + headers[":status"] = status; + + structured_headers::Item proxy_status_item(server_label_); + structured_headers::Item proxy_status_error_item( + std::string{proxy_status_error}); + structured_headers::Item proxy_status_details_item( + std::string{error_details}); + structured_headers::ParameterizedMember proxy_status_member( + std::move(proxy_status_item), + {{"error", std::move(proxy_status_error_item)}, + {"details", std::move(proxy_status_details_item)}}); + absl::optional proxy_status_value = + structured_headers::SerializeList({proxy_status_member}); + QUICHE_CHECK(proxy_status_value.has_value()); + headers["Proxy-Status"] = proxy_status_value.value(); + + QuicBackendResponse response; + response.set_headers(std::move(headers)); + + client_stream_request_handler_->OnResponseBackendComplete(&response); +} + +void ConnectUdpTunnel::TerminateClientStream( + absl::string_view error_description, QuicResetStreamError error_code) { + QUICHE_DCHECK(client_stream_request_handler_); + + std::string error_description_str = + error_description.empty() ? "" + : absl::StrCat(" due to ", error_description); + QUICHE_DVLOG(1) << "Terminating CONNECT stream " + << client_stream_request_handler_->stream_id() + << " with error code " << error_code.ietf_application_code() + << error_description_str; + + client_stream_request_handler_->TerminateStreamWithError(error_code); +} + +} // namespace quic diff --git a/quiche/quic/tools/connect_udp_tunnel.h b/quiche/quic/tools/connect_udp_tunnel.h new file mode 100644 index 000000000000..1e800a75f7c2 --- /dev/null +++ b/quiche/quic/tools/connect_udp_tunnel.h @@ -0,0 +1,99 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_CONNECT_UDP_TUNNEL_H_ +#define QUICHE_QUIC_TOOLS_CONNECT_UDP_TUNNEL_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// Manages a single UDP tunnel for a CONNECT-UDP proxy (see RFC 9298). +class ConnectUdpTunnel : public ConnectingClientSocket::AsyncVisitor, + public QuicSpdyStream::Http3DatagramVisitor { + public: + // `client_stream_request_handler` and `socket_factory` must both outlive the + // created ConnectUdpTunnel. `server_label` is an identifier (typically + // randomly generated) to indentify the server or backend in error headers, + // per the requirements of RFC 9209, Section 2. + ConnectUdpTunnel( + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler, + SocketFactory* socket_factory, std::string server_label, + absl::flat_hash_set acceptable_targets); + ~ConnectUdpTunnel(); + ConnectUdpTunnel(const ConnectUdpTunnel&) = delete; + ConnectUdpTunnel& operator=(const ConnectUdpTunnel&) = delete; + + // Attempts to open UDP tunnel to target server and then sends appropriate + // success/error response to the request stream. `request_headers` must + // represent headers from a CONNECT-UDP request, that is ":method"="CONNECT" + // and ":protocol"="connect-udp". + void OpenTunnel(const spdy::Http2HeaderBlock& request_headers); + + // Returns true iff the tunnel to the target server is currently open + bool IsTunnelOpenToTarget() const; + + // Called when the client stream has been closed. Tunnel to target + // server is closed if open. The RequestHandler will no longer be + // interacted with after completion. + void OnClientStreamClose(); + + // ConnectingClientSocket::AsyncVisitor: + void ConnectComplete(absl::Status status) override; + void ReceiveComplete(absl::StatusOr data) override; + void SendComplete(absl::Status status) override; + + // QuicSpdyStream::Http3DatagramVisitor: + void OnHttp3Datagram(QuicStreamId stream_id, + absl::string_view payload) override; + void OnUnknownCapsule(QuicStreamId /*stream_id*/, + const quiche::UnknownCapsule& /*capsule*/) override {} + + private: + void BeginAsyncReadFromTarget(); + void OnDataReceivedFromTarget(bool success); + + void SendUdpPacketToTarget(absl::string_view packet); + + void SendConnectResponse(); + void SendErrorResponse(absl::string_view status, + absl::string_view proxy_status_error, + absl::string_view error_details); + void TerminateClientStream(absl::string_view error_description, + QuicResetStreamError error_code); + + const absl::flat_hash_set acceptable_targets_; + SocketFactory* const socket_factory_; + const std::string server_label_; + + // Null when client stream closed. + QuicSimpleServerBackend::RequestHandler* client_stream_request_handler_; + + // Null when target connection disconnected. + std::unique_ptr target_socket_; + + bool receive_started_ = false; + bool datagram_visitor_registered_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_CONNECT_UDP_TUNNEL_H_ diff --git a/quiche/quic/tools/connect_udp_tunnel_test.cc b/quiche/quic/tools/connect_udp_tunnel_test.cc new file mode 100644 index 000000000000..9d7e88c397ba --- /dev/null +++ b/quiche/quic/tools/connect_udp_tunnel_test.cc @@ -0,0 +1,362 @@ +// Copyright 2022 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/connect_udp_tunnel.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "url/url_canon_stdstring.h" +#include "url/url_util.h" +#include "quiche/quic/core/connecting_client_socket.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/masque/connect_udp_datagram_payload.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quic::test { +namespace { + +using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::HasSubstr; +using ::testing::InvokeWithoutArgs; +using ::testing::IsEmpty; +using ::testing::Matcher; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Property; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::UnorderedElementsAre; + +constexpr QuicStreamId kStreamId = 100; + +class MockStream : public QuicSpdyStream { + public: + explicit MockStream(QuicSpdySession* spdy_session) + : QuicSpdyStream(kStreamId, spdy_session, BIDIRECTIONAL) {} + + void OnBodyAvailable() override {} + + MOCK_METHOD(MessageStatus, SendHttp3Datagram, (absl::string_view data), + (override)); +}; + +class MockRequestHandler : public QuicSimpleServerBackend::RequestHandler { + public: + QuicConnectionId connection_id() const override { + return TestConnectionId(41212); + } + QuicStreamId stream_id() const override { return kStreamId; } + std::string peer_host() const override { return "127.0.0.1"; } + + MOCK_METHOD(QuicSpdyStream*, GetStream, (), (override)); + MOCK_METHOD(void, OnResponseBackendComplete, + (const QuicBackendResponse* response), (override)); + MOCK_METHOD(void, SendStreamData, (absl::string_view data, bool close_stream), + (override)); + MOCK_METHOD(void, TerminateStreamWithError, (QuicResetStreamError error), + (override)); +}; + +class MockSocketFactory : public SocketFactory { + public: + MOCK_METHOD(std::unique_ptr, CreateTcpClientSocket, + (const QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor), + (override)); + MOCK_METHOD(std::unique_ptr, + CreateConnectingUdpClientSocket, + (const QuicSocketAddress& peer_address, + QuicByteCount receive_buffer_size, + QuicByteCount send_buffer_size, + ConnectingClientSocket::AsyncVisitor* async_visitor), + (override)); +}; + +class MockSocket : public ConnectingClientSocket { + public: + MOCK_METHOD(absl::Status, ConnectBlocking, (), (override)); + MOCK_METHOD(void, ConnectAsync, (), (override)); + MOCK_METHOD(void, Disconnect, (), (override)); + MOCK_METHOD(absl::StatusOr, GetLocalAddress, (), + (override)); + MOCK_METHOD(absl::StatusOr, ReceiveBlocking, + (QuicByteCount max_size), (override)); + MOCK_METHOD(void, ReceiveAsync, (QuicByteCount max_size), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (std::string data), (override)); + MOCK_METHOD(absl::Status, SendBlocking, (quiche::QuicheMemSlice data), + (override)); + MOCK_METHOD(void, SendAsync, (std::string data), (override)); + MOCK_METHOD(void, SendAsync, (quiche::QuicheMemSlice data), (override)); +}; + +class ConnectUdpTunnelTest : public quiche::test::QuicheTest { + public: + void SetUp() override { + auto socket = std::make_unique>(); + socket_ = socket.get(); + ON_CALL(socket_factory_, + CreateConnectingUdpClientSocket( + AnyOf(QuicSocketAddress(TestLoopback4(), kAcceptablePort), + QuicSocketAddress(TestLoopback6(), kAcceptablePort)), + _, _, &tunnel_)) + .WillByDefault(Return(ByMove(std::move(socket)))); + + EXPECT_CALL(request_handler_, GetStream()).WillRepeatedly(Return(&stream_)); + } + + protected: + static constexpr absl::string_view kAcceptableTarget = "localhost"; + static constexpr uint16_t kAcceptablePort = 977; + + NiceMock connection_helper_; + NiceMock alarm_factory_; + NiceMock session_{new NiceMock( + &connection_helper_, &alarm_factory_, Perspective::IS_SERVER)}; + StrictMock stream_{&session_}; + + StrictMock request_handler_; + NiceMock socket_factory_; + StrictMock* socket_; + + ConnectUdpTunnel tunnel_{ + &request_handler_, + &socket_factory_, + "server_label", + /*acceptable_targets=*/ + {{std::string(kAcceptableTarget), kAcceptablePort}, + {TestLoopback4().ToString(), kAcceptablePort}, + {absl::StrCat("[", TestLoopback6().ToString(), "]"), kAcceptablePort}}}; +}; + +TEST_F(ConnectUdpTunnelTest, OpenTunnel) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL( + request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre(Pair(":status", "200"), + Pair("Capsule-Protocol", "?1"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", kAcceptableTarget, "/", kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelToIpv4LiteralTarget) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL( + request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre(Pair(":status", "200"), + Pair("Capsule-Protocol", "?1"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = + absl::StrCat("/.well-known/masque/udp/", TestLoopback4().ToString(), "/", + kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); +} + +std::string PercentEncode(absl::string_view input) { + std::string encoded; + url::StdStringCanonOutput canon_output(&encoded); + url::EncodeURIComponent(input.data(), input.size(), &canon_output); + canon_output.Complete(); + return encoded; +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelToIpv6LiteralTarget) { + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL( + request_handler_, + OnResponseBackendComplete( + AllOf(Property(&QuicBackendResponse::response_type, + QuicBackendResponse::INCOMPLETE_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre(Pair(":status", "200"), + Pair("Capsule-Protocol", "?1"))), + Property(&QuicBackendResponse::trailers, IsEmpty()), + Property(&QuicBackendResponse::body, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", + PercentEncode(absl::StrCat("[", TestLoopback6().ToString(), "]")), "/", + kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + EXPECT_TRUE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelWithMalformedRequest) { + EXPECT_CALL(request_handler_, + TerminateStreamWithError(Property( + &QuicResetStreamError::ietf_application_code, + static_cast(QuicHttp3ErrorCode::MESSAGE_ERROR)))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + // No ":path" header. + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectUdpTunnelTest, OpenTunnelWithUnacceptableTarget) { + EXPECT_CALL(request_handler_, + OnResponseBackendComplete(AllOf( + Property(&QuicBackendResponse::response_type, + QuicBackendResponse::REGULAR_RESPONSE), + Property(&QuicBackendResponse::headers, + UnorderedElementsAre( + Pair(":status", "403"), + Pair("Proxy-Status", + HasSubstr("destination_ip_prohibited")))), + Property(&QuicBackendResponse::trailers, IsEmpty())))); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = "/.well-known/masque/udp/unacceptable.test/100/"; + + tunnel_.OpenTunnel(request_headers); + EXPECT_FALSE(tunnel_.IsTunnelOpenToTarget()); + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectUdpTunnelTest, ReceiveFromTarget) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Ge(kData.size()))).Times(2); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + EXPECT_CALL( + stream_, + SendHttp3Datagram( + quiche::ConnectUdpDatagramUdpPacketPayload(kData).Serialize())) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", kAcceptableTarget, "/", kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + + // Simulate receiving `kData`. + tunnel_.ReceiveComplete(MemSliceFromString(kData)); + + tunnel_.OnClientStreamClose(); +} + +TEST_F(ConnectUdpTunnelTest, SendToTarget) { + static constexpr absl::string_view kData = "\x11\x22\x33\x44\x55"; + + EXPECT_CALL(*socket_, ConnectBlocking()).WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, ReceiveAsync(Gt(0))); + EXPECT_CALL(*socket_, SendBlocking(Matcher(Eq(kData)))) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*socket_, Disconnect()).WillOnce(InvokeWithoutArgs([this]() { + tunnel_.ReceiveComplete(absl::CancelledError()); + })); + + EXPECT_CALL(request_handler_, OnResponseBackendComplete(_)); + + spdy::Http2HeaderBlock request_headers; + request_headers[":method"] = "CONNECT"; + request_headers[":protocol"] = "connect-udp"; + request_headers[":authority"] = "proxy.test"; + request_headers[":scheme"] = "https"; + request_headers[":path"] = absl::StrCat( + "/.well-known/masque/udp/", kAcceptableTarget, "/", kAcceptablePort, "/"); + + tunnel_.OpenTunnel(request_headers); + tunnel_.OnHttp3Datagram( + kStreamId, quiche::ConnectUdpDatagramUdpPacketPayload(kData).Serialize()); + tunnel_.OnClientStreamClose(); +} + +} // namespace +} // namespace quic::test diff --git a/quiche/quic/tools/crypto_message_printer_bin.cc b/quiche/quic/tools/crypto_message_printer_bin.cc new file mode 100644 index 000000000000..eb7393d549ea --- /dev/null +++ b/quiche/quic/tools/crypto_message_printer_bin.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Dumps the contents of a QUIC crypto handshake message in a human readable +// format. +// +// Usage: crypto_message_printer_bin + +#include +#include + +#include "absl/strings/escaping.h" +#include "quiche/quic/core/crypto/crypto_framer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +using std::cerr; +using std::cout; +using std::endl; + +namespace quic { + +class CryptoMessagePrinter : public ::quic::CryptoFramerVisitorInterface { + public: + void OnHandshakeMessage(const CryptoHandshakeMessage& message) override { + cout << message.DebugString() << endl; + } + + void OnError(CryptoFramer* framer) override { + cerr << "Error code: " << framer->error() << endl; + cerr << "Error details: " << framer->error_detail() << endl; + } +}; + +} // namespace quic + +int main(int argc, char* argv[]) { + const char* usage = "Usage: crypto_message_printer "; + std::vector messages = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + if (messages.size() != 1) { + quiche::QuichePrintCommandLineFlagHelp(usage); + exit(0); + } + + quic::CryptoMessagePrinter printer; + quic::CryptoFramer framer; + framer.set_visitor(&printer); + framer.set_process_truncated_messages(true); + std::string input = absl::HexStringToBytes(messages[0]); + if (!framer.ProcessInput(input)) { + return 1; + } + if (framer.InputBytesRemaining() != 0) { + cerr << "Input partially consumed. " << framer.InputBytesRemaining() + << " bytes remaining." << endl; + return 2; + } + return 0; +} diff --git a/quiche/quic/tools/fake_proof_verifier.h b/quiche/quic/tools/fake_proof_verifier.h new file mode 100644 index 000000000000..9350fc6d9760 --- /dev/null +++ b/quiche/quic/tools/fake_proof_verifier.h @@ -0,0 +1,44 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_FAKE_PROOF_VERIFIER_H_ +#define QUICHE_QUIC_TOOLS_FAKE_PROOF_VERIFIER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/proof_verifier.h" + +namespace quic { + +// ProofVerifier implementation which always returns success. +class FakeProofVerifier : public ProofVerifier { + public: + ~FakeProofVerifier() override {} + QuicAsyncStatus VerifyProof( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::string& /*server_config*/, + QuicTransportVersion /*quic_version*/, absl::string_view /*chlo_hash*/, + const std::vector& /*certs*/, + const std::string& /*cert_sct*/, const std::string& /*signature*/, + const ProofVerifyContext* /*context*/, std::string* /*error_details*/, + std::unique_ptr* /*details*/, + std::unique_ptr /*callback*/) override { + return QUIC_SUCCESS; + } + QuicAsyncStatus VerifyCertChain( + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::vector& /*certs*/, + const std::string& /*ocsp_response*/, const std::string& /*cert_sct*/, + const ProofVerifyContext* /*context*/, std::string* /*error_details*/, + std::unique_ptr* /*details*/, uint8_t* /*out_alert*/, + std::unique_ptr /*callback*/) override { + return QUIC_SUCCESS; + } + std::unique_ptr CreateDefaultContext() override { + return nullptr; + } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_FAKE_PROOF_VERIFIER_H_ diff --git a/quiche/quic/tools/qpack_offline_decoder_bin.cc b/quiche/quic/tools/qpack_offline_decoder_bin.cc new file mode 100644 index 000000000000..a5cbc856dd4a --- /dev/null +++ b/quiche/quic/tools/qpack_offline_decoder_bin.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/test_tools/qpack/qpack_offline_decoder.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" + +int main(int argc, char* argv[]) { + const char* usage = + "Usage: qpack_offline_decoder input_filename expected_headers_filename " + "...."; + std::vector args = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + + if (args.size() < 2 || args.size() % 2 != 0) { + quiche::QuichePrintCommandLineFlagHelp(usage); + return 1; + } + + size_t i; + size_t success_count = 0; + for (i = 0; 2 * i < args.size(); ++i) { + const absl::string_view input_filename(args[2 * i]); + const absl::string_view expected_headers_filename(args[2 * i + 1]); + + // Every file represents a different connection, + // therefore every file needs a fresh decoding context. + quic::QpackOfflineDecoder decoder; + if (decoder.DecodeAndVerifyOfflineData(input_filename, + expected_headers_filename)) { + ++success_count; + } + } + + std::cout << "Processed " << i << " pairs of input files, " << success_count + << " passed, " << (i - success_count) << " failed." << std::endl; + + // Return success if all input files pass. + return (success_count == i) ? 0 : 1; +} diff --git a/quiche/quic/tools/quic_backend_response.cc b/quiche/quic/tools/quic_backend_response.cc new file mode 100644 index 000000000000..8a54204a49a9 --- /dev/null +++ b/quiche/quic/tools/quic_backend_response.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_backend_response.h" + +namespace quic { + +QuicBackendResponse::ServerPushInfo::ServerPushInfo( + QuicUrl request_url, spdy::Http2HeaderBlock headers, + spdy::SpdyPriority priority, std::string body) + : request_url(request_url), + headers(std::move(headers)), + priority(priority), + body(body) {} + +QuicBackendResponse::ServerPushInfo::ServerPushInfo(const ServerPushInfo& other) + : request_url(other.request_url), + headers(other.headers.Clone()), + priority(other.priority), + body(other.body) {} + +QuicBackendResponse::QuicBackendResponse() + : response_type_(REGULAR_RESPONSE), delay_(QuicTime::Delta::Zero()) {} + +QuicBackendResponse::~QuicBackendResponse() = default; + +} // namespace quic diff --git a/quiche/quic/tools/quic_backend_response.h b/quiche/quic/tools/quic_backend_response.h new file mode 100644 index 000000000000..6d1b10584f28 --- /dev/null +++ b/quiche/quic/tools/quic_backend_response.h @@ -0,0 +1,98 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_BACKEND_RESPONSE_H_ +#define QUICHE_QUIC_TOOLS_QUIC_BACKEND_RESPONSE_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_time.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace quic { + +// Container for HTTP response header/body pairs +// fetched by the QuicSimpleServerBackend +class QuicBackendResponse { + public: + // A ServerPushInfo contains path of the push request and everything needed in + // comprising a response for the push request. + // TODO(b/171463363): Remove. + struct ServerPushInfo { + ServerPushInfo(QuicUrl request_url, spdy::Http2HeaderBlock headers, + spdy::SpdyPriority priority, std::string body); + ServerPushInfo(const ServerPushInfo& other); + + QuicUrl request_url; + spdy::Http2HeaderBlock headers; + spdy::SpdyPriority priority; + std::string body; + }; + + enum SpecialResponseType { + REGULAR_RESPONSE, // Send the headers and body like a server should. + CLOSE_CONNECTION, // Close the connection (sending the close packet). + IGNORE_REQUEST, // Do nothing, expect the client to time out. + BACKEND_ERR_RESPONSE, // There was an error fetching the response from + // the backend, for example as a TCP connection + // error. + INCOMPLETE_RESPONSE, // The server will act as if there is a non-empty + // trailer but it will not be sent, as a result, FIN + // will not be sent too. + GENERATE_BYTES // Sends a response with a length equal to the number + // of bytes in the URL path. + }; + QuicBackendResponse(); + + QuicBackendResponse(const QuicBackendResponse& other) = delete; + QuicBackendResponse& operator=(const QuicBackendResponse& other) = delete; + + ~QuicBackendResponse(); + + const std::vector& early_hints() const { + return early_hints_; + } + SpecialResponseType response_type() const { return response_type_; } + const spdy::Http2HeaderBlock& headers() const { return headers_; } + const spdy::Http2HeaderBlock& trailers() const { return trailers_; } + const absl::string_view body() const { return absl::string_view(body_); } + + void AddEarlyHints(const spdy::Http2HeaderBlock& headers) { + spdy::Http2HeaderBlock hints = headers.Clone(); + hints[":status"] = "103"; + early_hints_.push_back(std::move(hints)); + } + + void set_response_type(SpecialResponseType response_type) { + response_type_ = response_type; + } + + void set_headers(spdy::Http2HeaderBlock headers) { + headers_ = std::move(headers); + } + void set_trailers(spdy::Http2HeaderBlock trailers) { + trailers_ = std::move(trailers); + } + void set_body(absl::string_view body) { + body_.assign(body.data(), body.size()); + } + + // This would simulate a delay before sending the response + // back to the client. Intended for testing purposes. + void set_delay(QuicTime::Delta delay) { delay_ = delay; } + QuicTime::Delta delay() const { return delay_; } + + private: + std::vector early_hints_; + SpecialResponseType response_type_; + spdy::Http2HeaderBlock headers_; + spdy::Http2HeaderBlock trailers_; + std::string body_; + QuicTime::Delta delay_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_BACKEND_RESPONSE_H_ diff --git a/quiche/quic/tools/quic_client_base.cc b/quiche/quic/tools/quic_client_base.cc new file mode 100644 index 000000000000..4612be0641e5 --- /dev/null +++ b/quiche/quic/tools/quic_client_base.cc @@ -0,0 +1,545 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_client_base.h" + +#include +#include + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_path_validator.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" + +namespace quic { + +namespace { + +// Implements the basic feature of a result delegate for path validation for +// connection migration. If the validation succeeds, migrate to the alternative +// path. Otherwise, stay on the current path. +class QuicClientSocketMigrationValidationResultDelegate + : public QuicPathValidator::ResultDelegate { + public: + explicit QuicClientSocketMigrationValidationResultDelegate( + QuicClientBase* client) + : QuicPathValidator::ResultDelegate(), client_(client) {} + + virtual ~QuicClientSocketMigrationValidationResultDelegate() = default; + + // QuicPathValidator::ResultDelegate + // Overridden to start migration and takes the ownership of the writer in the + // context. + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime /*start_time*/) override { + QUIC_DLOG(INFO) << "Successfully validated path from " << *context + << ". Migrate to it now."; + auto migration_context = std::unique_ptr( + static_cast(context.release())); + client_->session()->MigratePath( + migration_context->self_address(), migration_context->peer_address(), + migration_context->WriterToUse(), /*owns_writer=*/false); + QUICHE_DCHECK(migration_context->WriterToUse() != nullptr); + // Hand the ownership of the alternative writer to the client. + client_->set_writer(migration_context->ReleaseWriter()); + } + + void OnPathValidationFailure( + std::unique_ptr context) override { + QUIC_LOG(WARNING) << "Fail to validate path " << *context + << ", stop migrating."; + client_->session()->connection()->OnPathValidationFailureAtClient( + /*is_multi_port=*/false, *context); + } + + protected: + QuicClientBase* client() { return client_; } + + private: + QuicClientBase* client_; +}; + +class ServerPreferredAddressResultDelegateWithWriter + : public QuicClientSocketMigrationValidationResultDelegate { + public: + ServerPreferredAddressResultDelegateWithWriter(QuicClientBase* client) + : QuicClientSocketMigrationValidationResultDelegate(client) {} + + // Overridden to transfer the ownership of the new writer. + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime /*start_time*/) override { + client()->session()->connection()->OnServerPreferredAddressValidated( + *context, false); + auto migration_context = std::unique_ptr( + static_cast(context.release())); + client()->set_writer(migration_context->ReleaseWriter()); + } +}; + +} // namespace + +QuicClientBase::NetworkHelper::~NetworkHelper() = default; + +QuicClientBase::QuicClientBase( + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, const QuicConfig& config, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : server_id_(server_id), + initialized_(false), + local_port_(0), + config_(config), + crypto_config_(std::move(proof_verifier), std::move(session_cache)), + helper_(helper), + alarm_factory_(alarm_factory), + supported_versions_(supported_versions), + initial_max_packet_length_(0), + num_sent_client_hellos_(0), + connection_error_(QUIC_NO_ERROR), + connected_or_attempting_connect_(false), + network_helper_(std::move(network_helper)), + connection_debug_visitor_(nullptr), + server_connection_id_length_(kQuicDefaultConnectionIdLength), + client_connection_id_length_(0) {} + +QuicClientBase::~QuicClientBase() = default; + +bool QuicClientBase::Initialize() { + num_sent_client_hellos_ = 0; + connection_error_ = QUIC_NO_ERROR; + connected_or_attempting_connect_ = false; + + // If an initial flow control window has not explicitly been set, then use the + // same values that Chrome uses. + const uint32_t kSessionMaxRecvWindowSize = 15 * 1024 * 1024; // 15 MB + const uint32_t kStreamMaxRecvWindowSize = 6 * 1024 * 1024; // 6 MB + if (config()->GetInitialStreamFlowControlWindowToSend() == + kDefaultFlowControlSendWindow) { + config()->SetInitialStreamFlowControlWindowToSend(kStreamMaxRecvWindowSize); + } + if (config()->GetInitialSessionFlowControlWindowToSend() == + kDefaultFlowControlSendWindow) { + config()->SetInitialSessionFlowControlWindowToSend( + kSessionMaxRecvWindowSize); + } + + if (!network_helper_->CreateUDPSocketAndBind(server_address_, + bind_to_address_, local_port_)) { + return false; + } + + initialized_ = true; + return true; +} + +bool QuicClientBase::Connect() { + // Attempt multiple connects until the maximum number of client hellos have + // been sent. + int num_attempts = 0; + while (!connected() && + num_attempts <= QuicCryptoClientStream::kMaxClientHellos) { + StartConnect(); + while (EncryptionBeingEstablished()) { + WaitForEvents(); + } + ParsedQuicVersion version = UnsupportedQuicVersion(); + if (session() != nullptr && !CanReconnectWithDifferentVersion(&version)) { + // We've successfully created a session but we're not connected, and we + // cannot reconnect with a different version. Give up trying. + break; + } + num_attempts++; + } + if (session() == nullptr) { + QUIC_BUG(quic_bug_10906_1) << "Missing session after Connect"; + return false; + } + return session()->connection()->connected(); +} + +void QuicClientBase::StartConnect() { + QUICHE_DCHECK(initialized_); + QUICHE_DCHECK(!connected()); + QuicPacketWriter* writer = network_helper_->CreateQuicPacketWriter(); + ParsedQuicVersion mutual_version = UnsupportedQuicVersion(); + const bool can_reconnect_with_different_version = + CanReconnectWithDifferentVersion(&mutual_version); + if (connected_or_attempting_connect()) { + // Clear queued up data if client can not try to connect with a different + // version. + if (!can_reconnect_with_different_version) { + ClearDataToResend(); + } + // Before we destroy the last session and create a new one, gather its stats + // and update the stats for the overall connection. + UpdateStats(); + } + + const quic::ParsedQuicVersionVector client_supported_versions = + can_reconnect_with_different_version + ? ParsedQuicVersionVector{mutual_version} + : supported_versions(); + + session_ = CreateQuicClientSession( + client_supported_versions, + new QuicConnection(GetNextConnectionId(), QuicSocketAddress(), + server_address(), helper(), alarm_factory(), writer, + /* owns_writer= */ false, Perspective::IS_CLIENT, + client_supported_versions, connection_id_generator_)); + if (can_reconnect_with_different_version) { + session()->set_client_original_supported_versions(supported_versions()); + } + if (connection_debug_visitor_ != nullptr) { + session()->connection()->set_debug_visitor(connection_debug_visitor_); + } + session()->connection()->set_client_connection_id(GetClientConnectionId()); + if (initial_max_packet_length_ != 0) { + session()->connection()->SetMaxPacketLength(initial_max_packet_length_); + } + // Reset |writer()| after |session()| so that the old writer outlives the old + // session. + set_writer(writer); + InitializeSession(); + if (can_reconnect_with_different_version) { + // This is a reconnect using server supported |mutual_version|. + session()->connection()->SetVersionNegotiated(); + } + set_connected_or_attempting_connect(true); +} + +void QuicClientBase::InitializeSession() { session()->Initialize(); } + +void QuicClientBase::Disconnect() { + QUICHE_DCHECK(initialized_); + + initialized_ = false; + if (connected()) { + session()->connection()->CloseConnection( + QUIC_PEER_GOING_AWAY, "Client disconnecting", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + } + + ClearDataToResend(); + + network_helper_->CleanUpAllUDPSockets(); +} + +ProofVerifier* QuicClientBase::proof_verifier() const { + return crypto_config_.proof_verifier(); +} + +bool QuicClientBase::EncryptionBeingEstablished() { + return !session_->IsEncryptionEstablished() && + session_->connection()->connected(); +} + +bool QuicClientBase::WaitForEvents() { + if (!connected()) { + QUIC_BUG(quic_bug_10906_2) + << "Cannot call WaitForEvents on non-connected client"; + return false; + } + + network_helper_->RunEventLoop(); + + return WaitForEventsPostprocessing(); +} + +bool QuicClientBase::WaitForEventsPostprocessing() { + QUICHE_DCHECK(session() != nullptr); + ParsedQuicVersion version = UnsupportedQuicVersion(); + if (!connected() && CanReconnectWithDifferentVersion(&version)) { + QUIC_DLOG(INFO) << "Can reconnect with version: " << version + << ", attempting to reconnect."; + + Connect(); + } + + return HasActiveRequests(); +} + +bool QuicClientBase::MigrateSocket(const QuicIpAddress& new_host) { + return MigrateSocketWithSpecifiedPort(new_host, local_port_); +} + +bool QuicClientBase::MigrateSocketWithSpecifiedPort( + const QuicIpAddress& new_host, int port) { + if (!connected()) { + QUICHE_DVLOG(1) + << "MigrateSocketWithSpecifiedPort failed as connection has closed"; + return false; + } + + network_helper_->CleanUpAllUDPSockets(); + std::unique_ptr writer = + CreateWriterForNewNetwork(new_host, port); + if (writer == nullptr) { + QUICHE_DVLOG(1) + << "MigrateSocketWithSpecifiedPort failed from writer creation"; + return false; + } + if (!session()->MigratePath(network_helper_->GetLatestClientAddress(), + session()->connection()->peer_address(), + writer.get(), false)) { + QUICHE_DVLOG(1) + << "MigrateSocketWithSpecifiedPort failed from session()->MigratePath"; + return false; + } + set_writer(writer.release()); + return true; +} + +bool QuicClientBase::ValidateAndMigrateSocket(const QuicIpAddress& new_host) { + QUICHE_DCHECK(VersionHasIetfQuicFrames( + session_->connection()->version().transport_version)); + if (!connected()) { + return false; + } + + std::unique_ptr writer = + CreateWriterForNewNetwork(new_host, local_port_); + if (writer == nullptr) { + return false; + } + // Asynchronously start migration. + session_->ValidatePath( + std::make_unique( + std::move(writer), network_helper_->GetLatestClientAddress(), + session_->peer_address()), + std::make_unique(this), + PathValidationReason::kConnectionMigration); + return true; +} + +std::unique_ptr QuicClientBase::CreateWriterForNewNetwork( + const QuicIpAddress& new_host, int port) { + set_bind_to_address(new_host); + set_local_port(port); + if (!network_helper_->CreateUDPSocketAndBind(server_address_, + bind_to_address_, port)) { + return nullptr; + } + + QuicPacketWriter* writer = network_helper_->CreateQuicPacketWriter(); + QUIC_LOG_IF(WARNING, writer == writer_.get()) + << "The new writer is wrapped in the same wrapper as the old one, thus " + "appearing to have the same address as the old one."; + return std::unique_ptr(writer); +} + +bool QuicClientBase::ChangeEphemeralPort() { + auto current_host = network_helper_->GetLatestClientAddress().host(); + return MigrateSocketWithSpecifiedPort(current_host, 0 /*any ephemeral port*/); +} + +QuicSession* QuicClientBase::session() { return session_.get(); } + +const QuicSession* QuicClientBase::session() const { return session_.get(); } + +QuicClientBase::NetworkHelper* QuicClientBase::network_helper() { + return network_helper_.get(); +} + +const QuicClientBase::NetworkHelper* QuicClientBase::network_helper() const { + return network_helper_.get(); +} + +void QuicClientBase::WaitForStreamToClose(QuicStreamId id) { + if (!connected()) { + QUIC_BUG(quic_bug_10906_3) + << "Cannot WaitForStreamToClose on non-connected client"; + return; + } + + while (connected() && !session_->IsClosedStream(id)) { + WaitForEvents(); + } +} + +bool QuicClientBase::WaitForOneRttKeysAvailable() { + if (!connected()) { + QUIC_BUG(quic_bug_10906_4) + << "Cannot WaitForOneRttKeysAvailable on non-connected client"; + return false; + } + + while (connected() && !session_->OneRttKeysAvailable()) { + WaitForEvents(); + } + + // If the handshake fails due to a timeout, the connection will be closed. + QUIC_LOG_IF(ERROR, !connected()) << "Handshake with server failed."; + return connected(); +} + +bool QuicClientBase::WaitForHandshakeConfirmed() { + if (!session_->connection()->version().UsesTls()) { + return WaitForOneRttKeysAvailable(); + } + // Otherwise, wait for receipt of HANDSHAKE_DONE frame. + while (connected() && session_->GetHandshakeState() < HANDSHAKE_CONFIRMED) { + WaitForEvents(); + } + + // If the handshake fails due to a timeout, the connection will be closed. + QUIC_LOG_IF(ERROR, !connected()) << "Handshake with server failed."; + return connected(); +} + +bool QuicClientBase::connected() const { + return session_.get() && session_->connection() && + session_->connection()->connected(); +} + +bool QuicClientBase::goaway_received() const { + return session_ != nullptr && session_->transport_goaway_received(); +} + +int QuicClientBase::GetNumSentClientHellos() { + // If we are not actively attempting to connect, the session object + // corresponds to the previous connection and should not be used. + const int current_session_hellos = !connected_or_attempting_connect_ + ? 0 + : GetNumSentClientHellosFromSession(); + return num_sent_client_hellos_ + current_session_hellos; +} + +void QuicClientBase::UpdateStats() { + num_sent_client_hellos_ += GetNumSentClientHellosFromSession(); +} + +int QuicClientBase::GetNumReceivedServerConfigUpdates() { + // If we are not actively attempting to connect, the session object + // corresponds to the previous connection and should not be used. + return !connected_or_attempting_connect_ + ? 0 + : GetNumReceivedServerConfigUpdatesFromSession(); +} + +QuicErrorCode QuicClientBase::connection_error() const { + // Return the high-level error if there was one. Otherwise, return the + // connection error from the last session. + if (connection_error_ != QUIC_NO_ERROR) { + return connection_error_; + } + if (session_ == nullptr) { + return QUIC_NO_ERROR; + } + return session_->error(); +} + +QuicConnectionId QuicClientBase::GetNextConnectionId() { + if (server_connection_id_override_.has_value()) { + return *server_connection_id_override_; + } + return GenerateNewConnectionId(); +} + +QuicConnectionId QuicClientBase::GenerateNewConnectionId() { + return QuicUtils::CreateRandomConnectionId(server_connection_id_length_); +} + +QuicConnectionId QuicClientBase::GetClientConnectionId() { + return QuicUtils::CreateRandomConnectionId(client_connection_id_length_); +} + +bool QuicClientBase::CanReconnectWithDifferentVersion( + ParsedQuicVersion* version) const { + if (session_ == nullptr || session_->connection() == nullptr || + session_->error() != QUIC_INVALID_VERSION) { + return false; + } + + const auto& server_supported_versions = + session_->connection()->server_supported_versions(); + if (server_supported_versions.empty()) { + return false; + } + + for (const auto& client_version : supported_versions_) { + if (std::find(server_supported_versions.begin(), + server_supported_versions.end(), + client_version) != server_supported_versions.end()) { + *version = client_version; + return true; + } + } + return false; +} + +bool QuicClientBase::HasPendingPathValidation() { + return session()->HasPendingPathValidation(); +} + +class ValidationResultDelegate : public QuicPathValidator::ResultDelegate { + public: + ValidationResultDelegate(QuicClientBase* client) + : QuicPathValidator::ResultDelegate(), client_(client) {} + + void OnPathValidationSuccess( + std::unique_ptr context, + QuicTime start_time) override { + QUIC_DLOG(INFO) << "Successfully validated path from " << *context + << ", validation started at " << start_time; + client_->AddValidatedPath(std::move(context)); + } + void OnPathValidationFailure( + std::unique_ptr context) override { + QUIC_LOG(WARNING) << "Fail to validate path " << *context + << ", stop migrating."; + client_->session()->connection()->OnPathValidationFailureAtClient( + /*is_multi_port=*/false, *context); + } + + private: + QuicClientBase* client_; +}; + +void QuicClientBase::ValidateNewNetwork(const QuicIpAddress& host) { + std::unique_ptr writer = + CreateWriterForNewNetwork(host, local_port_); + auto result_delegate = std::make_unique(this); + if (writer == nullptr) { + result_delegate->OnPathValidationFailure( + std::make_unique( + nullptr, network_helper_->GetLatestClientAddress(), + session_->peer_address())); + return; + } + session()->ValidatePath( + std::make_unique( + std::move(writer), network_helper_->GetLatestClientAddress(), + session_->peer_address()), + std::move(result_delegate), PathValidationReason::kConnectionMigration); +} + +void QuicClientBase::OnServerPreferredAddressAvailable( + const QuicSocketAddress& server_preferred_address) { + const auto self_address = session_->self_address(); + if (network_helper_ == nullptr || + !network_helper_->CreateUDPSocketAndBind(server_preferred_address, + self_address.host(), 0)) { + return; + } + QuicPacketWriter* writer = network_helper_->CreateQuicPacketWriter(); + if (writer == nullptr) { + return; + } + session()->ValidatePath( + std::make_unique( + std::unique_ptr(writer), + network_helper_->GetLatestClientAddress(), server_preferred_address), + std::make_unique(this), + PathValidationReason::kServerPreferredAddressMigration); +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_client_base.h b/quiche/quic/tools/quic_client_base.h new file mode 100644 index 000000000000..d17503f41d27 --- /dev/null +++ b/quiche/quic/tools/quic_client_base.h @@ -0,0 +1,479 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A base class for the toy client, which connects to a specified port and sends +// QUIC request to that endpoint. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_CLIENT_BASE_H_ +#define QUICHE_QUIC_TOOLS_QUIC_CLIENT_BASE_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/deterministic_connection_id_generator.h" +#include "quiche/quic/core/http/quic_client_push_promise_index.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_connection_id.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +class ProofVerifier; +class QuicServerId; +class SessionCache; + +// A path context which owns the writer. +class QUIC_EXPORT_PRIVATE PathMigrationContext + : public QuicPathValidationContext { + public: + PathMigrationContext(std::unique_ptr writer, + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address) + : QuicPathValidationContext(self_address, peer_address), + alternative_writer_(std::move(writer)) {} + + QuicPacketWriter* WriterToUse() override { return alternative_writer_.get(); } + + QuicPacketWriter* ReleaseWriter() { return alternative_writer_.release(); } + + private: + std::unique_ptr alternative_writer_; +}; + +// QuicClientBase handles establishing a connection to the passed in +// server id, including ensuring that it supports the passed in versions +// and config. +// Subclasses derived from this class are responsible for creating the +// actual QuicSession instance, as well as defining functions that +// create and run the underlying network transport. +class QuicClientBase : public QuicSession::Visitor { + public: + // An interface to various network events that the QuicClient will need to + // interact with. + class NetworkHelper { + public: + virtual ~NetworkHelper(); + + // Runs one iteration of the event loop. + virtual void RunEventLoop() = 0; + + // Used during initialization: creates the UDP socket FD, sets socket + // options, and binds the socket to our address. + virtual bool CreateUDPSocketAndBind(QuicSocketAddress server_address, + QuicIpAddress bind_to_address, + int bind_to_port) = 0; + + // Unregister and close all open UDP sockets. + virtual void CleanUpAllUDPSockets() = 0; + + // If the client has at least one UDP socket, return address of the latest + // created one. Otherwise, return an empty socket address. + virtual QuicSocketAddress GetLatestClientAddress() const = 0; + + // Creates a packet writer to be used for the next connection. + virtual QuicPacketWriter* CreateQuicPacketWriter() = 0; + }; + + QuicClientBase(const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + const QuicConfig& config, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + QuicClientBase(const QuicClientBase&) = delete; + QuicClientBase& operator=(const QuicClientBase&) = delete; + + virtual ~QuicClientBase(); + + // Implmenets QuicSession::Visitor + void OnConnectionClosed(QuicConnectionId /*server_connection_id*/, + QuicErrorCode /*error*/, + const std::string& /*error_details*/, + ConnectionCloseSource /*source*/) override {} + + void OnWriteBlocked(QuicBlockedWriterInterface* /*blocked_writer*/) override { + } + void OnRstStreamReceived(const QuicRstStreamFrame& /*frame*/) override {} + void OnStopSendingReceived(const QuicStopSendingFrame& /*frame*/) override {} + bool TryAddNewConnectionId( + const QuicConnectionId& /*server_connection_id*/, + const QuicConnectionId& /*new_connection_id*/) override { + return false; + } + void OnConnectionIdRetired( + const QuicConnectionId& /*server_connection_id*/) override {} + void OnServerPreferredAddressAvailable( + const QuicSocketAddress& server_preferred_address) override; + + // Initializes the client to create a connection. Should be called exactly + // once before calling StartConnect or Connect. Returns true if the + // initialization succeeds, false otherwise. + virtual bool Initialize(); + + // "Connect" to the QUIC server, including performing synchronous crypto + // handshake. + bool Connect(); + + // Start the crypto handshake. This can be done in place of the synchronous + // Connect(), but callers are responsible for making sure the crypto handshake + // completes. + void StartConnect(); + + // Calls session()->Initialize(). Subclasses may override this if any extra + // initialization needs to be done. Subclasses should expect that session() + // is non-null and valid. + virtual void InitializeSession(); + + // Disconnects from the QUIC server. + void Disconnect(); + + // Returns true if the crypto handshake has yet to establish encryption. + // Returns false if encryption is active (even if the server hasn't confirmed + // the handshake) or if the connection has been closed. + bool EncryptionBeingEstablished(); + + // Wait for events until the stream with the given ID is closed. + void WaitForStreamToClose(QuicStreamId id); + + // Wait for 1-RTT keys become available. + // Returns true once 1-RTT keys are available, false otherwise. + ABSL_MUST_USE_RESULT bool WaitForOneRttKeysAvailable(); + + // Wait for handshake state proceeds to HANDSHAKE_CONFIRMED. + // In QUIC crypto, this does the same as WaitForOneRttKeysAvailable, while in + // TLS, this waits for HANDSHAKE_DONE frame is received. + ABSL_MUST_USE_RESULT bool WaitForHandshakeConfirmed(); + + // Wait up to 50ms, and handle any events which occur. + // Returns true if there are any outstanding requests. + bool WaitForEvents(); + + // Performs the part of WaitForEvents() that is done after the actual event + // loop call. + bool WaitForEventsPostprocessing(); + + // Migrate to a new socket (new_host) during an active connection. + bool MigrateSocket(const QuicIpAddress& new_host); + + // Migrate to a new socket (new_host, port) during an active connection. + bool MigrateSocketWithSpecifiedPort(const QuicIpAddress& new_host, int port); + + // Validate the new socket and migrate to it if the validation succeeds. + // Otherwise stay on the current socket. Return true if the validation has + // started. + bool ValidateAndMigrateSocket(const QuicIpAddress& new_host); + + // Open a new socket to change to a new ephemeral port. + bool ChangeEphemeralPort(); + + QuicSession* session(); + const QuicSession* session() const; + + bool connected() const; + virtual bool goaway_received() const; + + const QuicServerId& server_id() const { return server_id_; } + + // This should only be set before the initial Connect() + void set_server_id(const QuicServerId& server_id) { server_id_ = server_id; } + + void SetUserAgentID(const std::string& user_agent_id) { + crypto_config_.set_user_agent_id(user_agent_id); + } + + void SetTlsSignatureAlgorithms(std::string signature_algorithms) { + crypto_config_.set_tls_signature_algorithms( + std::move(signature_algorithms)); + } + + const ParsedQuicVersionVector& supported_versions() const { + return supported_versions_; + } + + void SetSupportedVersions(const ParsedQuicVersionVector& versions) { + supported_versions_ = versions; + } + + QuicConfig* config() { return &config_; } + + QuicCryptoClientConfig* crypto_config() { return &crypto_config_; } + + // Change the initial maximum packet size of the connection. Has to be called + // before Connect()/StartConnect() in order to have any effect. + void set_initial_max_packet_length(QuicByteCount initial_max_packet_length) { + initial_max_packet_length_ = initial_max_packet_length; + } + + // The number of client hellos sent. + int GetNumSentClientHellos(); + + // Returns true if early data (0-RTT data) was sent and the server accepted + // it. + virtual bool EarlyDataAccepted() = 0; + + // Returns true if the handshake was delayed one round trip by the server + // because the server wanted proof the client controls its source address + // before progressing further. In Google QUIC, this would be due to an + // inchoate REJ in the QUIC Crypto handshake; in IETF QUIC this would be due + // to a Retry packet. + // TODO(nharper): Consider a better name for this method. + virtual bool ReceivedInchoateReject() = 0; + + // Gather the stats for the last session and update the stats for the overall + // connection. + void UpdateStats(); + + // The number of server config updates received. + int GetNumReceivedServerConfigUpdates(); + + // Returns any errors that occurred at the connection-level. + QuicErrorCode connection_error() const; + void set_connection_error(QuicErrorCode connection_error) { + connection_error_ = connection_error; + } + + bool connected_or_attempting_connect() const { + return connected_or_attempting_connect_; + } + void set_connected_or_attempting_connect( + bool connected_or_attempting_connect) { + connected_or_attempting_connect_ = connected_or_attempting_connect; + } + + QuicPacketWriter* writer() { return writer_.get(); } + void set_writer(QuicPacketWriter* writer) { + if (writer_.get() != writer) { + writer_.reset(writer); + } + } + + void reset_writer() { writer_.reset(); } + + ProofVerifier* proof_verifier() const; + + void set_bind_to_address(QuicIpAddress address) { + bind_to_address_ = address; + } + + QuicIpAddress bind_to_address() const { return bind_to_address_; } + + void set_local_port(int local_port) { local_port_ = local_port; } + + int local_port() const { return local_port_; } + + const QuicSocketAddress& server_address() const { return server_address_; } + + void set_server_address(const QuicSocketAddress& server_address) { + server_address_ = server_address; + } + + QuicConnectionHelperInterface* helper() { return helper_.get(); } + + NetworkHelper* network_helper(); + const NetworkHelper* network_helper() const; + + bool initialized() const { return initialized_; } + + void SetPreSharedKey(absl::string_view key) { + crypto_config_.set_pre_shared_key(key); + } + + void set_connection_debug_visitor( + QuicConnectionDebugVisitor* connection_debug_visitor) { + connection_debug_visitor_ = connection_debug_visitor; + } + + // Sets the interface name to bind. If empty, will not attempt to bind the + // socket to that interface. Defaults to empty string. + void set_interface_name(std::string interface_name) { + interface_name_ = interface_name; + } + + std::string interface_name() const { return interface_name_; } + + void set_server_connection_id_override( + const QuicConnectionId& connection_id) { + server_connection_id_override_ = connection_id; + } + + void set_server_connection_id_length(uint8_t server_connection_id_length) { + server_connection_id_length_ = server_connection_id_length; + } + + void set_client_connection_id_length(uint8_t client_connection_id_length) { + client_connection_id_length_ = client_connection_id_length; + } + + bool HasPendingPathValidation(); + + void ValidateNewNetwork(const QuicIpAddress& host); + + void AddValidatedPath(std::unique_ptr context) { + validated_paths_.push_back(std::move(context)); + } + + const std::vector>& + validated_paths() const { + return validated_paths_; + } + + protected: + // TODO(rch): Move GetNumSentClientHellosFromSession and + // GetNumReceivedServerConfigUpdatesFromSession into a new/better + // QuicSpdyClientSession class. The current inherits dependencies from + // Spdy. When that happens this class and all its subclasses should + // work with QuicSpdyClientSession instead of QuicSession. + // That will obviate the need for the pure virtual functions below. + + // Extract the number of sent client hellos from the session. + virtual int GetNumSentClientHellosFromSession() = 0; + + // The number of server config updates received. + virtual int GetNumReceivedServerConfigUpdatesFromSession() = 0; + + // If this client supports buffering data, resend it. + virtual void ResendSavedData() = 0; + + // If this client supports buffering data, clear it. + virtual void ClearDataToResend() = 0; + + // Takes ownership of |connection|. If you override this function, + // you probably want to call ResetSession() in your destructor. + // TODO(rch): Change the connection parameter to take in a + // std::unique_ptr instead. + virtual std::unique_ptr CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) = 0; + + // Generates the next ConnectionId for |server_id_|. By default, if the + // cached server config contains a server-designated ID, that ID will be + // returned. Otherwise, the next random ID will be returned. + QuicConnectionId GetNextConnectionId(); + + // Generates a new, random connection ID (as opposed to a server-designated + // connection ID). + virtual QuicConnectionId GenerateNewConnectionId(); + + // Returns the client connection ID to use. + virtual QuicConnectionId GetClientConnectionId(); + + QuicAlarmFactory* alarm_factory() { return alarm_factory_.get(); } + + // Subclasses may need to explicitly clear the session on destruction + // if they create it with objects that will be destroyed before this is. + // You probably want to call this if you override CreateQuicSpdyClientSession. + void ResetSession() { session_.reset(); } + + // Returns true if the corresponding of this client has active requests. + virtual bool HasActiveRequests() = 0; + + // Allows derived classes to access this when creating connections. + ConnectionIdGeneratorInterface& connection_id_generator(); + + private: + // Returns true and set |version| if client can reconnect with a different + // version. + bool CanReconnectWithDifferentVersion(ParsedQuicVersion* version) const; + + std::unique_ptr CreateWriterForNewNetwork( + const QuicIpAddress& new_host, int port); + + // |server_id_| is a tuple (hostname, port, is_https) of the server. + QuicServerId server_id_; + + // Tracks if the client is initialized to connect. + bool initialized_; + + // Address of the server. + QuicSocketAddress server_address_; + + // If initialized, the address to bind to. + QuicIpAddress bind_to_address_; + + // Local port to bind to. Initialize to 0. + int local_port_; + + // config_ and crypto_config_ contain configuration and cached state about + // servers. + QuicConfig config_; + QuicCryptoClientConfig crypto_config_; + + // Helper to be used by created connections. Must outlive |session_|. + std::unique_ptr helper_; + + // Alarm factory to be used by created connections. Must outlive |session_|. + std::unique_ptr alarm_factory_; + + // Writer used to actually send packets to the wire. Must outlive |session_|. + std::unique_ptr writer_; + + // Session which manages streams. + std::unique_ptr session_; + + // This vector contains QUIC versions which we currently support. + // This should be ordered such that the highest supported version is the first + // element, with subsequent elements in descending order (versions can be + // skipped as necessary). We will always pick supported_versions_[0] as the + // initial version to use. + ParsedQuicVersionVector supported_versions_; + + // The initial value of maximum packet size of the connection. If set to + // zero, the default is used. + QuicByteCount initial_max_packet_length_; + + // The number of hellos sent during the current/latest connection. + int num_sent_client_hellos_; + + // Used to store any errors that occurred with the overall connection (as + // opposed to that associated with the last session object). + QuicErrorCode connection_error_; + + // True when the client is attempting to connect. Set to false between a call + // to Disconnect() and the subsequent call to StartConnect(). When + // connected_or_attempting_connect_ is false, the session object corresponds + // to the previous client-level connection. + bool connected_or_attempting_connect_; + + // The network helper used to create sockets and manage the event loop. + // Not owned by this class. + std::unique_ptr network_helper_; + + // The debug visitor set on the connection right after it is constructed. + // Not owned, must be valid for the lifetime of the QuicClientBase instance. + QuicConnectionDebugVisitor* connection_debug_visitor_; + + // If set, + // - GetNextConnectionId will use this as the next server connection id. + // - GenerateNewConnectionId will not be called. + absl::optional server_connection_id_override_; + + // GenerateNewConnectionId creates a random connection ID of this length. + // Defaults to 8. + uint8_t server_connection_id_length_; + + // GetClientConnectionId creates a random connection ID of this length. + // Defaults to 0. + uint8_t client_connection_id_length_; + + // Stores validated paths. + std::vector> validated_paths_; + + // Stores the interface name to bind. If empty, will not attempt to bind the + // socket to that interface. Defaults to empty string. + std::string interface_name_; + + DeterministicConnectionIdGenerator connection_id_generator_{ + kQuicDefaultConnectionIdLength}; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_CLIENT_BASE_H_ diff --git a/quiche/quic/tools/quic_client_bin.cc b/quiche/quic/tools/quic_client_bin.cc new file mode 100644 index 000000000000..ad2acf94a47f --- /dev/null +++ b/quiche/quic/tools/quic_client_bin.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A binary wrapper for QuicClient. +// Connects to a host using QUIC, sends a request to the provided URL, and +// displays the response. +// +// Some usage examples: +// +// Standard request/response: +// quic_client www.google.com +// quic_client www.google.com --quiet +// quic_client www.google.com --port=443 +// +// Use a specific version: +// quic_client www.google.com --quic_version=23 +// +// Send a POST instead of a GET: +// quic_client www.google.com --body="this is a POST body" +// +// Append additional headers to the request: +// quic_client www.google.com --headers="header-a: 1234; header-b: 5678" +// +// Connect to a host different to the URL being requested: +// quic_client mail.google.com --host=www.google.com +// +// Connect to a specific IP: +// IP=`dig www.google.com +short | head -1` +// quic_client www.google.com --host=${IP} +// +// Send repeated requests and change ephemeral port between requests +// quic_client www.google.com --num_requests=10 +// +// Try to connect to a host which does not speak QUIC: +// quic_client www.example.com +// +// This tool is available as a built binary at: +// /google/data/ro/teams/quic/tools/quic_client +// After submitting changes to this file, you will need to follow the +// instructions at go/quic_client_binary_update + +#include +#include +#include + +#include "quiche/quic/tools/quic_epoll_client_factory.h" +#include "quiche/quic/tools/quic_toy_client.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_system_event_loop.h" + +int main(int argc, char* argv[]) { + quiche::QuicheSystemEventLoop event_loop("quic_client"); + const char* usage = "Usage: quic_client [options] "; + + // All non-flag arguments should be interpreted as URLs to fetch. + std::vector urls = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + if (urls.size() != 1) { + quiche::QuichePrintCommandLineFlagHelp(usage); + exit(0); + } + + quic::QuicEpollClientFactory factory; + quic::QuicToyClient client(&factory); + return client.SendRequestsAndPrintResponses(urls); +} diff --git a/quiche/quic/tools/quic_client_default_network_helper.cc b/quiche/quic/tools/quic_client_default_network_helper.cc new file mode 100644 index 000000000000..ab9e1b1af79b --- /dev/null +++ b/quiche/quic/tools/quic_client_default_network_helper.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_client_default_network_helper.h" + +#include "absl/cleanup/cleanup.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_packet_writer.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_system_event_loop.h" + +namespace quic { + +namespace { + +// For level-triggered I/O, we need to manually rearm the kSocketEventWritable +// listener whenever the socket gets blocked. +class LevelTriggeredPacketWriter : public QuicDefaultPacketWriter { + public: + explicit LevelTriggeredPacketWriter(int fd, QuicEventLoop* event_loop) + : QuicDefaultPacketWriter(fd), event_loop_(event_loop) { + QUICHE_DCHECK(!event_loop->SupportsEdgeTriggered()); + } + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override { + WriteResult result = QuicDefaultPacketWriter::WritePacket( + buffer, buf_len, self_address, peer_address, options); + if (IsWriteBlockedStatus(result.status)) { + bool success = event_loop_->RearmSocket(fd(), kSocketEventWritable); + QUICHE_DCHECK(success); + } + return result; + } + + private: + QuicEventLoop* event_loop_; +}; + +} // namespace + +QuicClientDefaultNetworkHelper::QuicClientDefaultNetworkHelper( + QuicEventLoop* event_loop, QuicClientBase* client) + : event_loop_(event_loop), + packets_dropped_(0), + overflow_supported_(false), + packet_reader_(new QuicPacketReader()), + client_(client), + max_reads_per_event_loop_(std::numeric_limits::max()) {} + +QuicClientDefaultNetworkHelper::~QuicClientDefaultNetworkHelper() { + if (client_->connected()) { + client_->session()->connection()->CloseConnection( + QUIC_PEER_GOING_AWAY, "Client being torn down", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + } + + CleanUpAllUDPSockets(); +} + +bool QuicClientDefaultNetworkHelper::CreateUDPSocketAndBind( + QuicSocketAddress server_address, QuicIpAddress bind_to_address, + int bind_to_port) { + int fd = CreateUDPSocket(server_address, &overflow_supported_); + if (fd < 0) { + return false; + } + auto closer = absl::MakeCleanup([fd] { close(fd); }); + + QuicSocketAddress client_address; + if (bind_to_address.IsInitialized()) { + client_address = QuicSocketAddress(bind_to_address, client_->local_port()); + } else if (server_address.host().address_family() == IpAddressFamily::IP_V4) { + client_address = QuicSocketAddress(QuicIpAddress::Any4(), bind_to_port); + } else { + client_address = QuicSocketAddress(QuicIpAddress::Any6(), bind_to_port); + } + + // Some platforms expect that the addrlen given to bind() exactly matches the + // size of the associated protocol family's sockaddr struct. + // TODO(b/179430548): Revert this when affected platforms are updated to + // to support binding with an addrelen of sizeof(sockaddr_storage) + socklen_t addrlen; + switch (client_address.host().address_family()) { + case IpAddressFamily::IP_V4: + addrlen = sizeof(sockaddr_in); + break; + case IpAddressFamily::IP_V6: + addrlen = sizeof(sockaddr_in6); + break; + case IpAddressFamily::IP_UNSPEC: + addrlen = 0; + break; + } + + sockaddr_storage addr = client_address.generic_address(); + int rc = bind(fd, reinterpret_cast(&addr), addrlen); + if (rc < 0) { + QUIC_LOG(ERROR) << "Bind failed: " << strerror(errno) + << " bind_to_address:" << bind_to_address + << ", bind_to_port:" << bind_to_port + << ", client_address:" << client_address; + return false; + } + + if (client_address.FromSocket(fd) != 0) { + QUIC_LOG(ERROR) << "Unable to get self address. Error: " + << strerror(errno); + } + + if (event_loop_->RegisterSocket( + fd, kSocketEventReadable | kSocketEventWritable, this)) { + fd_address_map_[fd] = client_address; + std::move(closer).Cancel(); + return true; + } + return false; +} + +void QuicClientDefaultNetworkHelper::CleanUpUDPSocket(int fd) { + CleanUpUDPSocketImpl(fd); + fd_address_map_.erase(fd); +} + +void QuicClientDefaultNetworkHelper::CleanUpAllUDPSockets() { + for (std::pair fd_address : fd_address_map_) { + CleanUpUDPSocketImpl(fd_address.first); + } + fd_address_map_.clear(); +} + +void QuicClientDefaultNetworkHelper::CleanUpUDPSocketImpl(int fd) { + if (fd > -1) { + bool success = event_loop_->UnregisterSocket(fd); + QUICHE_DCHECK(success || fds_unregistered_externally_); + int rc = close(fd); + QUICHE_DCHECK_EQ(0, rc); + } +} + +void QuicClientDefaultNetworkHelper::RunEventLoop() { + quiche::QuicheRunSystemEventLoopIteration(); + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(50)); +} + +void QuicClientDefaultNetworkHelper::OnSocketEvent( + QuicEventLoop* /*event_loop*/, QuicUdpSocketFd fd, + QuicSocketEventMask events) { + if (events & kSocketEventReadable) { + QUIC_DVLOG(1) << "Read packets on kSocketEventReadable"; + int times_to_read = max_reads_per_event_loop_; + bool more_to_read = true; + QuicPacketCount packets_dropped = 0; + while (client_->connected() && more_to_read && times_to_read > 0) { + more_to_read = packet_reader_->ReadAndDispatchPackets( + fd, GetLatestClientAddress().port(), *client_->helper()->GetClock(), + this, overflow_supported_ ? &packets_dropped : nullptr); + --times_to_read; + } + if (packets_dropped_ < packets_dropped) { + QUIC_LOG(ERROR) + << packets_dropped - packets_dropped_ + << " more packets are dropped in the socket receive buffer."; + packets_dropped_ = packets_dropped; + } + if (client_->connected() && more_to_read) { + bool success = + event_loop_->ArtificiallyNotifyEvent(fd, kSocketEventReadable); + QUICHE_DCHECK(success); + } else if (!event_loop_->SupportsEdgeTriggered()) { + bool success = event_loop_->RearmSocket(fd, kSocketEventReadable); + QUICHE_DCHECK(success); + } + } + if (client_->connected() && (events & kSocketEventWritable)) { + client_->writer()->SetWritable(); + client_->session()->connection()->OnCanWrite(); + } +} + +QuicPacketWriter* QuicClientDefaultNetworkHelper::CreateQuicPacketWriter() { + if (event_loop_->SupportsEdgeTriggered()) { + return new QuicDefaultPacketWriter(GetLatestFD()); + } else { + return new LevelTriggeredPacketWriter(GetLatestFD(), event_loop_); + } +} + +void QuicClientDefaultNetworkHelper::SetClientPort(int port) { + fd_address_map_.back().second = + QuicSocketAddress(GetLatestClientAddress().host(), port); +} + +QuicSocketAddress QuicClientDefaultNetworkHelper::GetLatestClientAddress() + const { + if (fd_address_map_.empty()) { + return QuicSocketAddress(); + } + + return fd_address_map_.back().second; +} + +int QuicClientDefaultNetworkHelper::GetLatestFD() const { + if (fd_address_map_.empty()) { + return -1; + } + + return fd_address_map_.back().first; +} + +void QuicClientDefaultNetworkHelper::ProcessPacket( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, const QuicReceivedPacket& packet) { + client_->session()->ProcessUdpPacket(self_address, peer_address, packet); +} + +int QuicClientDefaultNetworkHelper::CreateUDPSocket( + QuicSocketAddress server_address, bool* overflow_supported) { + QuicUdpSocketApi api; + int fd = api.Create(server_address.host().AddressFamilyToInt(), + /*receive_buffer_size =*/kDefaultSocketReceiveBuffer, + /*send_buffer_size =*/kDefaultSocketReceiveBuffer); + if (fd < 0) { + return fd; + } + + *overflow_supported = api.EnableDroppedPacketCount(fd); + api.EnableReceiveTimestamp(fd); + + if (!BindInterfaceNameIfNeeded(fd)) { + CleanUpUDPSocket(fd); + return kQuicInvalidSocketFd; + } + + return fd; +} + +bool QuicClientDefaultNetworkHelper::BindInterfaceNameIfNeeded(int fd) { + QuicUdpSocketApi api; + std::string interface_name = client_->interface_name(); + if (!interface_name.empty()) { + if (!api.BindInterface(fd, interface_name)) { + QUIC_DLOG(WARNING) << "Failed to bind socket (" << fd + << ") to interface (" << interface_name << ")."; + return false; + } + } + return true; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_client_default_network_helper.h b/quiche/quic/tools/quic_client_default_network_helper.h new file mode 100644 index 000000000000..07b07533f202 --- /dev/null +++ b/quiche/quic/tools/quic_client_default_network_helper.h @@ -0,0 +1,133 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_CLIENT_DEFAULT_NETWORK_HELPER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_CLIENT_DEFAULT_NETWORK_HELPER_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_packet_reader.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/tools/quic_client_base.h" +#include "quiche/common/quiche_linked_hash_map.h" + +namespace quic { + +namespace test { +class QuicClientPeer; +} // namespace test + +// An implementation of the QuicClientBase::NetworkHelper interface that is +// based on the QuicEventLoop API. +class QuicClientDefaultNetworkHelper : public QuicClientBase::NetworkHelper, + public QuicSocketEventListener, + public ProcessPacketInterface { + public: + QuicClientDefaultNetworkHelper(QuicEventLoop* event_loop, + QuicClientBase* client); + QuicClientDefaultNetworkHelper(const QuicClientDefaultNetworkHelper&) = + delete; + QuicClientDefaultNetworkHelper& operator=( + const QuicClientDefaultNetworkHelper&) = delete; + + ~QuicClientDefaultNetworkHelper() override; + + // From QuicSocketEventListener. + void OnSocketEvent(QuicEventLoop* event_loop, QuicUdpSocketFd fd, + QuicSocketEventMask events) override; + + // From ProcessPacketInterface. This will be called for each received + // packet. + void ProcessPacket(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicReceivedPacket& packet) override; + + // From NetworkHelper. + void RunEventLoop() override; + bool CreateUDPSocketAndBind(QuicSocketAddress server_address, + QuicIpAddress bind_to_address, + int bind_to_port) override; + void CleanUpAllUDPSockets() override; + QuicSocketAddress GetLatestClientAddress() const override; + QuicPacketWriter* CreateQuicPacketWriter() override; + + // Accessors provided for convenience, not part of any interface. + QuicEventLoop* event_loop() { return event_loop_; } + const quiche::QuicheLinkedHashMap& fd_address_map() + const { + return fd_address_map_; + } + + // If the client has at least one UDP socket, return the latest created one. + // Otherwise, return -1. + int GetLatestFD() const; + + // Create socket for connection to |server_address| with default socket + // options. + // Return fd index. + virtual int CreateUDPSocket(QuicSocketAddress server_address, + bool* overflow_supported); + + QuicClientBase* client() { return client_; } + + void set_max_reads_per_event_loop(int num_reads) { + max_reads_per_event_loop_ = num_reads; + } + // If |fd| is an open UDP socket, unregister and close it. Otherwise, do + // nothing. + void CleanUpUDPSocket(int fd); + + // Used for testing. + void SetClientPort(int port); + + // Indicates that some of the FDs owned by the network helper may be + // unregistered by the external code by manually calling + // event_loop()->UnregisterSocket() (this is useful for certain scenarios + // where an external event loop is used). + void AllowFdsToBeUnregisteredExternally() { + fds_unregistered_externally_ = true; + } + + // Bind a socket to a specific network interface. + bool BindInterfaceNameIfNeeded(int fd); + + // Actually clean up |fd|. + virtual void CleanUpUDPSocketImpl(int fd); + + private: + // Listens for events on the client socket. + QuicEventLoop* event_loop_; + + // Map mapping created UDP sockets to their addresses. By using linked hash + // map, the order of socket creation can be recorded. + quiche::QuicheLinkedHashMap fd_address_map_; + + // If overflow_supported_ is true, this will be the number of packets dropped + // during the lifetime of the server. + QuicPacketCount packets_dropped_; + + // True if the kernel supports SO_RXQ_OVFL, the number of packets dropped + // because the socket would otherwise overflow. + bool overflow_supported_; + + // Point to a QuicPacketReader object on the heap. The reader allocates more + // space than allowed on the stack. + std::unique_ptr packet_reader_; + + QuicClientBase* client_; + + int max_reads_per_event_loop_; + + // If true, some of the FDs owned by the network helper may be unregistered by + // the external code. + bool fds_unregistered_externally_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_CLIENT_DEFAULT_NETWORK_HELPER_H_ diff --git a/quiche/quic/tools/quic_client_factory.h b/quiche/quic/tools/quic_client_factory.h new file mode 100644 index 000000000000..df4aab9f3919 --- /dev/null +++ b/quiche/quic/tools/quic_client_factory.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_CLIENT_FACTORY_H_ +#define QUICHE_QUIC_TOOLS_QUIC_CLIENT_FACTORY_H_ + +#include "quiche/quic/core/crypto/proof_verifier.h" +#include "quiche/quic/core/crypto/quic_crypto_client_config.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/tools/quic_spdy_client_base.h" + +namespace quic { + +// Interface implemented by Factories to create QuicClients. +class ClientFactoryInterface { + public: + virtual ~ClientFactoryInterface() = default; + + // Creates a new client configured to connect to |host_for_lookup:port| + // supporting |versions|, using |host_for_handshake| for handshake and + // |verifier| to verify proofs. + virtual std::unique_ptr CreateClient( + std::string host_for_handshake, std::string host_for_lookup, + // AF_INET, AF_INET6, or AF_UNSPEC(=don't care). + int address_family_for_lookup, uint16_t port, + ParsedQuicVersionVector versions, const QuicConfig& config, + std::unique_ptr verifier, + std::unique_ptr session_cache) = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_CLIENT_FACTORY_H_ diff --git a/quiche/quic/tools/quic_client_interop_test_bin.cc b/quiche/quic/tools/quic_client_interop_test_bin.cc new file mode 100644 index 000000000000..6fe1a511ba7b --- /dev/null +++ b/quiche/quic/tools/quic_client_interop_test_bin.cc @@ -0,0 +1,462 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/quic_client_session_cache.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/tools/fake_proof_verifier.h" +#include "quiche/quic/tools/quic_default_client.h" +#include "quiche/quic/tools/quic_name_lookup.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_system_event_loop.h" +#include "quiche/spdy/core/http2_header_block.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, host, "", + "The IP or hostname to connect to."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, quic_version, "", + "The QUIC version to use. Defaults to most recent IETF QUIC version."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, port, 0, "The port to connect to."); + +namespace quic { + +enum class Feature { + // First row of features ("table stakes") + // A version negotiation response is elicited and acted on. + kVersionNegotiation, + // The handshake completes successfully. + kHandshake, + // Stream data is being exchanged and ACK'ed. + kStreamData, + // The connection close procedcure completes with a zero error code. + kConnectionClose, + // The connection was established using TLS resumption. + kResumption, + // 0-RTT data is being sent and acted on. + kZeroRtt, + // A RETRY packet was successfully processed. + kRetry, + // A handshake using a ClientHello that spans multiple packets completed + // successfully. + kQuantum, + + // Second row of features (anything else protocol-related) + // We switched to a different port and the server migrated to it. + kRebinding, + // One endpoint can update keys and its peer responds correctly. + kKeyUpdate, + + // Third row of features (H3 tests) + // An H3 transaction succeeded. + kHttp3, + // One or both endpoints insert entries into dynamic table and subsequenly + // reference them from header blocks. + kDynamicEntryReferenced, +}; + +char MatrixLetter(Feature f) { + switch (f) { + case Feature::kVersionNegotiation: + return 'V'; + case Feature::kHandshake: + return 'H'; + case Feature::kStreamData: + return 'D'; + case Feature::kConnectionClose: + return 'C'; + case Feature::kResumption: + return 'R'; + case Feature::kZeroRtt: + return 'Z'; + case Feature::kRetry: + return 'S'; + case Feature::kQuantum: + return 'Q'; + case Feature::kRebinding: + return 'B'; + case Feature::kKeyUpdate: + return 'U'; + case Feature::kHttp3: + return '3'; + case Feature::kDynamicEntryReferenced: + return 'd'; + } +} + +class QuicClientInteropRunner : QuicConnectionDebugVisitor { + public: + QuicClientInteropRunner() {} + + void InsertFeature(Feature feature) { features_.insert(feature); } + + std::set features() const { return features_; } + + // Attempts a resumption using |client| by disconnecting and reconnecting. If + // resumption is successful, |features_| is modified to add + // Feature::kResumption to it, otherwise it is left unmodified. + void AttemptResumption(QuicDefaultClient* client, + const std::string& authority); + + void AttemptRequest(QuicSocketAddress addr, std::string authority, + QuicServerId server_id, ParsedQuicVersion version, + bool test_version_negotiation, bool attempt_rebind, + bool attempt_multi_packet_chlo, bool attempt_key_update); + + // Constructs a Http2HeaderBlock containing the pseudo-headers needed to make + // a GET request to "/" on the hostname |authority|. + spdy::Http2HeaderBlock ConstructHeaderBlock(const std::string& authority); + + // Sends an HTTP request represented by |header_block| using |client|. + void SendRequest(QuicDefaultClient* client, + const spdy::Http2HeaderBlock& header_block); + + void OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override { + switch (frame.close_type) { + case GOOGLE_QUIC_CONNECTION_CLOSE: + QUIC_LOG(ERROR) << "Received unexpected GoogleQUIC connection close"; + break; + case IETF_QUIC_TRANSPORT_CONNECTION_CLOSE: + if (frame.wire_error_code == NO_IETF_QUIC_ERROR) { + InsertFeature(Feature::kConnectionClose); + } else { + QUIC_LOG(ERROR) << "Received transport connection close " + << QuicIetfTransportErrorCodeString( + static_cast( + frame.wire_error_code)); + } + break; + case IETF_QUIC_APPLICATION_CONNECTION_CLOSE: + if (frame.wire_error_code == 0) { + InsertFeature(Feature::kConnectionClose); + } else { + QUIC_LOG(ERROR) << "Received application connection close " + << frame.wire_error_code; + } + break; + } + } + + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& /*packet*/) override { + InsertFeature(Feature::kVersionNegotiation); + } + + private: + std::set features_; +}; + +void QuicClientInteropRunner::AttemptResumption(QuicDefaultClient* client, + const std::string& authority) { + client->Disconnect(); + if (!client->Initialize()) { + QUIC_LOG(ERROR) << "Failed to reinitialize client"; + return; + } + if (!client->Connect()) { + return; + } + + bool zero_rtt_attempt = !client->session()->OneRttKeysAvailable(); + + spdy::Http2HeaderBlock header_block = ConstructHeaderBlock(authority); + SendRequest(client, header_block); + + if (!client->session()->OneRttKeysAvailable()) { + return; + } + + if (static_cast( + test::QuicSessionPeer::GetMutableCryptoStream(client->session())) + ->IsResumption()) { + InsertFeature(Feature::kResumption); + } + if (static_cast( + test::QuicSessionPeer::GetMutableCryptoStream(client->session())) + ->EarlyDataAccepted() && + zero_rtt_attempt && client->latest_response_code() != -1) { + InsertFeature(Feature::kZeroRtt); + } +} + +void QuicClientInteropRunner::AttemptRequest( + QuicSocketAddress addr, std::string authority, QuicServerId server_id, + ParsedQuicVersion version, bool test_version_negotiation, + bool attempt_rebind, bool attempt_multi_packet_chlo, + bool attempt_key_update) { + ParsedQuicVersionVector versions = {version}; + if (test_version_negotiation) { + versions.insert(versions.begin(), QuicVersionReservedForNegotiation()); + } + + auto proof_verifier = std::make_unique(); + auto session_cache = std::make_unique(); + QuicConfig config; + QuicTime::Delta timeout = QuicTime::Delta::FromSeconds(20); + config.SetIdleNetworkTimeout(timeout); + if (attempt_multi_packet_chlo) { + // Make the ClientHello span multiple packets by adding a custom transport + // parameter. + constexpr auto kCustomParameter = + static_cast(0x173E); + std::string custom_value(2000, '?'); + config.custom_transport_parameters_to_send()[kCustomParameter] = + custom_value; + } + std::unique_ptr event_loop = + GetDefaultEventLoop()->Create(QuicDefaultClock::Get()); + auto client = std::make_unique( + addr, server_id, versions, config, event_loop.get(), + std::move(proof_verifier), std::move(session_cache)); + client->set_connection_debug_visitor(this); + if (!client->Initialize()) { + QUIC_LOG(ERROR) << "Failed to initialize client"; + return; + } + const bool connect_result = client->Connect(); + QuicConnection* connection = client->session()->connection(); + if (connection == nullptr) { + QUIC_LOG(ERROR) << "No QuicConnection object"; + return; + } + QuicConnectionStats client_stats = connection->GetStats(); + if (client_stats.retry_packet_processed) { + InsertFeature(Feature::kRetry); + } + if (test_version_negotiation && connection->version() == version) { + InsertFeature(Feature::kVersionNegotiation); + } + if (test_version_negotiation && !connect_result) { + // Failed to negotiate version, retry without version negotiation. + AttemptRequest(addr, authority, server_id, version, + /*test_version_negotiation=*/false, attempt_rebind, + attempt_multi_packet_chlo, attempt_key_update); + return; + } + if (!client->session()->OneRttKeysAvailable()) { + if (attempt_multi_packet_chlo) { + // Failed to handshake with multi-packet client hello, retry without it. + AttemptRequest(addr, authority, server_id, version, + test_version_negotiation, attempt_rebind, + /*attempt_multi_packet_chlo=*/false, attempt_key_update); + return; + } + return; + } + InsertFeature(Feature::kHandshake); + if (attempt_multi_packet_chlo) { + InsertFeature(Feature::kQuantum); + } + + spdy::Http2HeaderBlock header_block = ConstructHeaderBlock(authority); + SendRequest(client.get(), header_block); + + if (!client->connected()) { + return; + } + + if (client->latest_response_code() != -1) { + InsertFeature(Feature::kHttp3); + + if (client->client_session()->dynamic_table_entry_referenced()) { + InsertFeature(Feature::kDynamicEntryReferenced); + } + + if (attempt_rebind) { + // Now make a second request after switching to a different client port. + if (client->ChangeEphemeralPort()) { + client->SendRequestAndWaitForResponse(header_block, "", /*fin=*/true); + if (!client->connected()) { + // Rebinding does not work, retry without attempting it. + AttemptRequest(addr, authority, server_id, version, + test_version_negotiation, /*attempt_rebind=*/false, + attempt_multi_packet_chlo, attempt_key_update); + return; + } + InsertFeature(Feature::kRebinding); + + if (client->client_session()->dynamic_table_entry_referenced()) { + InsertFeature(Feature::kDynamicEntryReferenced); + } + } else { + QUIC_LOG(ERROR) << "Failed to change ephemeral port"; + } + } + + if (attempt_key_update) { + if (connection->IsKeyUpdateAllowed()) { + if (connection->InitiateKeyUpdate( + KeyUpdateReason::kLocalForInteropRunner)) { + client->SendRequestAndWaitForResponse(header_block, "", /*fin=*/true); + if (!client->connected()) { + // Key update does not work, retry without attempting it. + AttemptRequest(addr, authority, server_id, version, + test_version_negotiation, attempt_rebind, + attempt_multi_packet_chlo, + /*attempt_key_update=*/false); + return; + } + InsertFeature(Feature::kKeyUpdate); + } else { + QUIC_LOG(ERROR) << "Failed to initiate key update"; + } + } else { + QUIC_LOG(ERROR) << "Key update not allowed"; + } + } + } + + if (connection->connected()) { + connection->CloseConnection( + QUIC_NO_ERROR, "Graceful close", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + InsertFeature(Feature::kConnectionClose); + } + + AttemptResumption(client.get(), authority); +} + +spdy::Http2HeaderBlock QuicClientInteropRunner::ConstructHeaderBlock( + const std::string& authority) { + // Construct and send a request. + spdy::Http2HeaderBlock header_block; + header_block[":method"] = "GET"; + header_block[":scheme"] = "https"; + header_block[":authority"] = authority; + header_block[":path"] = "/"; + return header_block; +} + +void QuicClientInteropRunner::SendRequest( + QuicDefaultClient* client, const spdy::Http2HeaderBlock& header_block) { + client->set_store_response(true); + client->SendRequestAndWaitForResponse(header_block, "", /*fin=*/true); + + QuicConnection* connection = client->session()->connection(); + if (connection == nullptr) { + QUIC_LOG(ERROR) << "No QuicConnection object"; + return; + } + QuicConnectionStats client_stats = connection->GetStats(); + QuicSentPacketManager* sent_packet_manager = + test::QuicConnectionPeer::GetSentPacketManager(connection); + const bool received_forward_secure_ack = + sent_packet_manager != nullptr && + sent_packet_manager->GetLargestAckedPacket(ENCRYPTION_FORWARD_SECURE) + .IsInitialized(); + if (client_stats.stream_bytes_received > 0 && received_forward_secure_ack) { + InsertFeature(Feature::kStreamData); + } +} + +std::set ServerSupport(std::string dns_host, std::string url_host, + int port, ParsedQuicVersion version) { + std::cout << "Attempting interop with version " << version << std::endl; + + // Build the client, and try to connect. + QuicSocketAddress addr = tools::LookupAddress(dns_host, absl::StrCat(port)); + if (!addr.IsInitialized()) { + QUIC_LOG(ERROR) << "Failed to resolve " << dns_host; + return std::set(); + } + QuicServerId server_id(url_host, port, false); + std::string authority = absl::StrCat(url_host, ":", port); + + QuicClientInteropRunner runner; + + runner.AttemptRequest(addr, authority, server_id, version, + /*test_version_negotiation=*/true, + /*attempt_rebind=*/true, + /*attempt_multi_packet_chlo=*/true, + /*attempt_key_update=*/true); + + return runner.features(); +} + +} // namespace quic + +int main(int argc, char* argv[]) { + quiche::QuicheSystemEventLoop event_loop("quic_client"); + const char* usage = "Usage: quic_client_interop_test [options] [url]"; + + std::vector args = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + if (args.size() > 1) { + quiche::QuichePrintCommandLineFlagHelp(usage); + exit(1); + } + std::string dns_host = quiche::GetQuicheCommandLineFlag(FLAGS_host); + std::string url_host = ""; + int port = quiche::GetQuicheCommandLineFlag(FLAGS_port); + + if (!args.empty()) { + quic::QuicUrl url(args[0], "https"); + url_host = url.host(); + if (dns_host.empty()) { + dns_host = url_host; + } + if (port == 0) { + port = url.port(); + } + } + if (port == 0) { + port = 443; + } + if (dns_host.empty()) { + quiche::QuichePrintCommandLineFlagHelp(usage); + exit(1); + } + if (url_host.empty()) { + url_host = dns_host; + } + + // Pick QUIC version to use. + quic::QuicVersionInitializeSupportForIetfDraft(); + quic::ParsedQuicVersion version = quic::UnsupportedQuicVersion(); + std::string quic_version_string = + quiche::GetQuicheCommandLineFlag(FLAGS_quic_version); + if (!quic_version_string.empty()) { + version = quic::ParseQuicVersionString(quic_version_string); + } else { + for (const quic::ParsedQuicVersion& vers : quic::AllSupportedVersions()) { + // Use the most recent IETF QUIC version. + if (vers.HasIetfQuicFrames() && vers.UsesHttp3() && vers.UsesTls()) { + version = vers; + break; + } + } + } + QUICHE_CHECK(version.IsKnown()); + QuicEnableVersion(version); + + auto supported_features = + quic::ServerSupport(dns_host, url_host, port, version); + std::cout << "Results for " << url_host << ":" << port << std::endl; + int current_row = 1; + for (auto feature : supported_features) { + if (current_row < 2 && feature >= quic::Feature::kRebinding) { + std::cout << std::endl; + current_row = 2; + } + if (current_row < 3 && feature >= quic::Feature::kHttp3) { + std::cout << std::endl; + current_row = 3; + } + std::cout << MatrixLetter(feature); + } + std::cout << std::endl; +} diff --git a/quiche/quic/tools/quic_default_client.cc b/quiche/quic/tools/quic_default_client.cc new file mode 100644 index 000000000000..6a15fa911cb5 --- /dev/null +++ b/quiche/quic/tools/quic_default_client.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_default_client.h" + +#include + +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_simple_client_session.h" + +namespace quic { + +QuicDefaultClient::QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier) + : QuicDefaultClient( + server_address, server_id, supported_versions, QuicConfig(), + event_loop, + std::make_unique(event_loop, this), + std::move(proof_verifier), nullptr) {} + +QuicDefaultClient::QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : QuicDefaultClient( + server_address, server_id, supported_versions, QuicConfig(), + event_loop, + std::make_unique(event_loop, this), + std::move(proof_verifier), std::move(session_cache)) {} + +QuicDefaultClient::QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, const QuicConfig& config, + QuicEventLoop* event_loop, std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : QuicDefaultClient( + server_address, server_id, supported_versions, config, event_loop, + std::make_unique(event_loop, this), + std::move(proof_verifier), std::move(session_cache)) {} + +QuicDefaultClient::QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier) + : QuicDefaultClient(server_address, server_id, supported_versions, + QuicConfig(), event_loop, std::move(network_helper), + std::move(proof_verifier), nullptr) {} + +QuicDefaultClient::QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, const QuicConfig& config, + QuicEventLoop* event_loop, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier) + : QuicDefaultClient(server_address, server_id, supported_versions, config, + event_loop, std::move(network_helper), + std::move(proof_verifier), nullptr) {} + +QuicDefaultClient::QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, const QuicConfig& config, + QuicEventLoop* event_loop, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : QuicSpdyClientBase(server_id, supported_versions, config, + new QuicDefaultConnectionHelper(), + event_loop->CreateAlarmFactory().release(), + std::move(network_helper), std::move(proof_verifier), + std::move(session_cache)) { + set_server_address(server_address); +} + +QuicDefaultClient::~QuicDefaultClient() = default; + +std::unique_ptr QuicDefaultClient::CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) { + return std::make_unique( + *config(), supported_versions, connection, this, network_helper(), + server_id(), crypto_config(), push_promise_index(), drop_response_body(), + enable_web_transport()); +} + +QuicClientDefaultNetworkHelper* QuicDefaultClient::default_network_helper() { + return static_cast(network_helper()); +} + +const QuicClientDefaultNetworkHelper* +QuicDefaultClient::default_network_helper() const { + return static_cast(network_helper()); +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_default_client.h b/quiche/quic/tools/quic_default_client.h new file mode 100644 index 000000000000..8ef170db8073 --- /dev/null +++ b/quiche/quic/tools/quic_default_client.h @@ -0,0 +1,87 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A toy client, which connects to a specified port and sends QUIC +// request to that endpoint. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_DEFAULT_CLIENT_H_ +#define QUICHE_QUIC_TOOLS_QUIC_DEFAULT_CLIENT_H_ + +#include +#include +#include + +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/tools/quic_client_default_network_helper.h" +#include "quiche/quic/tools/quic_spdy_client_base.h" + +namespace quic { + +class QuicServerId; + +namespace test { +class QuicDefaultClientPeer; +} // namespace test + +class QuicDefaultClient : public QuicSpdyClientBase { + public: + // These will create their own QuicClientDefaultNetworkHelper. + QuicDefaultClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier); + QuicDefaultClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + QuicDefaultClient(QuicSocketAddress server_address, + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + const QuicConfig& config, QuicEventLoop* event_loop, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + // This will take ownership of a passed in network primitive. + QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + QuicEventLoop* event_loop, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier); + QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + const QuicConfig& config, QuicEventLoop* event_loop, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier); + QuicDefaultClient( + QuicSocketAddress server_address, const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + const QuicConfig& config, QuicEventLoop* event_loop, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + QuicDefaultClient(const QuicDefaultClient&) = delete; + QuicDefaultClient& operator=(const QuicDefaultClient&) = delete; + + ~QuicDefaultClient() override; + + // QuicSpdyClientBase overrides. + std::unique_ptr CreateQuicClientSession( + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) override; + + // Exposed for QUIC tests. + int GetLatestFD() const { return default_network_helper()->GetLatestFD(); } + + QuicClientDefaultNetworkHelper* default_network_helper(); + const QuicClientDefaultNetworkHelper* default_network_helper() const; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_DEFAULT_CLIENT_H_ diff --git a/quiche/quic/tools/quic_default_client_test.cc b/quiche/quic/tools/quic_default_client_test.cc new file mode 100644 index 000000000000..d58a24ae0295 --- /dev/null +++ b/quiche/quic/tools/quic_default_client_test.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This unit test relies on /proc, which is not available on non-Linux based +// OSes that we support. +#if defined(__linux__) + +#include "quiche/quic/tools/quic_default_client.h" + +#include +#include + +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/common/quiche_text_utils.h" + +namespace quic { +namespace test { +namespace { + +const char* kPathToFds = "/proc/self/fd"; + +// Return the value of a symbolic link in |path|, if |path| is not found, return +// an empty string. +std::string ReadLink(const std::string& path) { + std::string result(PATH_MAX, '\0'); + ssize_t result_size = readlink(path.c_str(), &result[0], result.size()); + if (result_size < 0 && errno == ENOENT) { + return ""; + } + QUICHE_CHECK(result_size > 0 && + static_cast(result_size) < result.size()) + << "result_size:" << result_size << ", errno:" << errno + << ", path:" << path; + result.resize(result_size); + return result; +} + +// Counts the number of open sockets for the current process. +size_t NumOpenSocketFDs() { + size_t socket_count = 0; + dirent* file; + std::unique_ptr fd_directory(opendir(kPathToFds), + closedir); + while ((file = readdir(fd_directory.get())) != nullptr) { + absl::string_view name(file->d_name); + if (name == "." || name == "..") { + continue; + } + + std::string fd_path = ReadLink(absl::StrCat(kPathToFds, "/", name)); + if (absl::StartsWith(fd_path, "socket:")) { + socket_count++; + } + } + return socket_count; +} + +class QuicDefaultClientTest : public QuicTest { + public: + QuicDefaultClientTest() + : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())) { + // Creates and destroys a single client first which may open persistent + // sockets when initializing platform dependencies like certificate + // verifier. Future creation of addtional clients will deterministically + // open one socket per client. + CreateAndInitializeQuicClient(); + } + + // Creates a new QuicClient and Initializes it on an unused port. + // Caller is responsible for deletion. + std::unique_ptr CreateAndInitializeQuicClient() { + QuicSocketAddress server_address(QuicSocketAddress(TestLoopback(), 0)); + QuicServerId server_id("hostname", server_address.port(), false); + ParsedQuicVersionVector versions = AllSupportedVersions(); + auto client = std::make_unique( + server_address, server_id, versions, event_loop_.get(), + crypto_test_utils::ProofVerifierForTesting()); + EXPECT_TRUE(client->Initialize()); + return client; + } + + private: + std::unique_ptr event_loop_; +}; + +TEST_F(QuicDefaultClientTest, DoNotLeakSocketFDs) { + // Make sure that the QuicClient doesn't leak socket FDs. Doing so could cause + // port exhaustion in long running processes which repeatedly create clients. + + // Record the initial number of FDs. + size_t number_of_open_fds = NumOpenSocketFDs(); + + // Create a number of clients, initialize them, and verify this has resulted + // in additional FDs being opened. + const int kNumClients = 50; + for (int i = 0; i < kNumClients; ++i) { + EXPECT_EQ(number_of_open_fds, NumOpenSocketFDs()); + std::unique_ptr client(CreateAndInitializeQuicClient()); + // Initializing the client will create a new FD. + EXPECT_EQ(number_of_open_fds + 1, NumOpenSocketFDs()); + } + + // The FDs created by the QuicClients should now be closed. + EXPECT_EQ(number_of_open_fds, NumOpenSocketFDs()); +} + +TEST_F(QuicDefaultClientTest, CreateAndCleanUpUDPSockets) { + size_t number_of_open_fds = NumOpenSocketFDs(); + + std::unique_ptr client(CreateAndInitializeQuicClient()); + // Creating and initializing a client will result in one socket being opened. + EXPECT_EQ(number_of_open_fds + 1, NumOpenSocketFDs()); + + // Create more UDP sockets. + EXPECT_TRUE(client->default_network_helper()->CreateUDPSocketAndBind( + client->server_address(), client->bind_to_address(), + client->local_port())); + EXPECT_EQ(number_of_open_fds + 2, NumOpenSocketFDs()); + EXPECT_TRUE(client->default_network_helper()->CreateUDPSocketAndBind( + client->server_address(), client->bind_to_address(), + client->local_port())); + EXPECT_EQ(number_of_open_fds + 3, NumOpenSocketFDs()); + + // Clean up UDP sockets. + client->default_network_helper()->CleanUpUDPSocket(client->GetLatestFD()); + EXPECT_EQ(number_of_open_fds + 2, NumOpenSocketFDs()); + client->default_network_helper()->CleanUpUDPSocket(client->GetLatestFD()); + EXPECT_EQ(number_of_open_fds + 1, NumOpenSocketFDs()); +} + +} // namespace +} // namespace test +} // namespace quic + +#endif // defined(__linux__) diff --git a/quiche/quic/tools/quic_epoll_client_factory.cc b/quiche/quic/tools/quic_epoll_client_factory.cc new file mode 100644 index 000000000000..1f6677796725 --- /dev/null +++ b/quiche/quic/tools/quic_epoll_client_factory.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_epoll_client_factory.h" + +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_default_client.h" +#include "quiche/quic/tools/quic_name_lookup.h" + +namespace quic { + +QuicEpollClientFactory::QuicEpollClientFactory() + : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())) {} + +std::unique_ptr QuicEpollClientFactory::CreateClient( + std::string host_for_handshake, std::string host_for_lookup, + int address_family_for_lookup, uint16_t port, + ParsedQuicVersionVector versions, const QuicConfig& config, + std::unique_ptr verifier, + std::unique_ptr session_cache) { + QuicSocketAddress addr = tools::LookupAddress( + address_family_for_lookup, host_for_lookup, absl::StrCat(port)); + if (!addr.IsInitialized()) { + QUIC_LOG(ERROR) << "Unable to resolve address: " << host_for_lookup; + return nullptr; + } + QuicServerId server_id(host_for_handshake, port, false); + return std::make_unique( + addr, server_id, versions, config, event_loop_.get(), std::move(verifier), + std::move(session_cache)); +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_epoll_client_factory.h b/quiche/quic/tools/quic_epoll_client_factory.h new file mode 100644 index 000000000000..ab5e88203c07 --- /dev/null +++ b/quiche/quic/tools/quic_epoll_client_factory.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_EPOLL_CLIENT_FACTORY_H_ +#define QUICHE_QUIC_TOOLS_QUIC_EPOLL_CLIENT_FACTORY_H_ + +#include + +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/tools/quic_client_factory.h" + +namespace quic { + +// Factory creating QuicClient instances. +class QuicEpollClientFactory : public ClientFactoryInterface { + public: + QuicEpollClientFactory(); + + std::unique_ptr CreateClient( + std::string host_for_handshake, std::string host_for_lookup, + int address_family_for_lookup, uint16_t port, + ParsedQuicVersionVector versions, const QuicConfig& config, + std::unique_ptr verifier, + std::unique_ptr session_cache) override; + + private: + std::unique_ptr event_loop_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_EPOLL_CLIENT_FACTORY_H_ diff --git a/quiche/quic/tools/quic_memory_cache_backend.cc b/quiche/quic/tools/quic_memory_cache_backend.cc new file mode 100644 index 000000000000..0fe984d96914 --- /dev/null +++ b/quiche/quic/tools/quic_memory_cache_backend.cc @@ -0,0 +1,507 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_memory_cache_backend.h" + +#include + +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/tools/web_transport_test_visitors.h" +#include "quiche/common/platform/api/quiche_file_utils.h" +#include "quiche/common/quiche_text_utils.h" + +using spdy::Http2HeaderBlock; +using spdy::kV3LowestPriority; + +namespace quic { + +QuicMemoryCacheBackend::ResourceFile::ResourceFile(const std::string& file_name) + : file_name_(file_name) {} + +QuicMemoryCacheBackend::ResourceFile::~ResourceFile() = default; + +void QuicMemoryCacheBackend::ResourceFile::Read() { + absl::optional maybe_file_contents = + quiche::ReadFileContents(file_name_); + if (!maybe_file_contents) { + QUIC_LOG(DFATAL) << "Failed to read file for the memory cache backend: " + << file_name_; + return; + } + file_contents_ = *maybe_file_contents; + + // First read the headers. + size_t start = 0; + while (start < file_contents_.length()) { + size_t pos = file_contents_.find('\n', start); + if (pos == std::string::npos) { + QUIC_LOG(DFATAL) << "Headers invalid or empty, ignoring: " << file_name_; + return; + } + size_t len = pos - start; + // Support both dos and unix line endings for convenience. + if (file_contents_[pos - 1] == '\r') { + len -= 1; + } + absl::string_view line(file_contents_.data() + start, len); + start = pos + 1; + // Headers end with an empty line. + if (line.empty()) { + break; + } + // Extract the status from the HTTP first line. + if (line.substr(0, 4) == "HTTP") { + pos = line.find(' '); + if (pos == std::string::npos) { + QUIC_LOG(DFATAL) << "Headers invalid or empty, ignoring: " + << file_name_; + return; + } + spdy_headers_[":status"] = line.substr(pos + 1, 3); + continue; + } + // Headers are "key: value". + pos = line.find(": "); + if (pos == std::string::npos) { + QUIC_LOG(DFATAL) << "Headers invalid or empty, ignoring: " << file_name_; + return; + } + spdy_headers_.AppendValueOrAddHeader( + quiche::QuicheTextUtils::ToLower(line.substr(0, pos)), + line.substr(pos + 2)); + } + + // The connection header is prohibited in HTTP/2. + spdy_headers_.erase("connection"); + + // Override the URL with the X-Original-Url header, if present. + auto it = spdy_headers_.find("x-original-url"); + if (it != spdy_headers_.end()) { + x_original_url_ = it->second; + HandleXOriginalUrl(); + } + + // X-Push-URL header is a relatively quick way to support sever push + // in the toy server. A production server should use link=preload + // stuff as described in https://w3c.github.io/preload/. + it = spdy_headers_.find("x-push-url"); + if (it != spdy_headers_.end()) { + absl::string_view push_urls = it->second; + size_t start = 0; + while (start < push_urls.length()) { + size_t pos = push_urls.find('\0', start); + if (pos == std::string::npos) { + push_urls_.push_back(absl::string_view(push_urls.data() + start, + push_urls.length() - start)); + break; + } + push_urls_.push_back(absl::string_view(push_urls.data() + start, pos)); + start += pos + 1; + } + } + + body_ = absl::string_view(file_contents_.data() + start, + file_contents_.size() - start); +} + +void QuicMemoryCacheBackend::ResourceFile::SetHostPathFromBase( + absl::string_view base) { + QUICHE_DCHECK(base[0] != '/') << base; + size_t path_start = base.find_first_of('/'); + if (path_start == absl::string_view::npos) { + host_ = std::string(base); + path_ = ""; + return; + } + + host_ = std::string(base.substr(0, path_start)); + size_t query_start = base.find_first_of(','); + if (query_start > 0) { + path_ = std::string(base.substr(path_start, query_start - 1)); + } else { + path_ = std::string(base.substr(path_start)); + } +} + +absl::string_view QuicMemoryCacheBackend::ResourceFile::RemoveScheme( + absl::string_view url) { + if (absl::StartsWith(url, "https://")) { + url.remove_prefix(8); + } else if (absl::StartsWith(url, "http://")) { + url.remove_prefix(7); + } + return url; +} + +void QuicMemoryCacheBackend::ResourceFile::HandleXOriginalUrl() { + absl::string_view url(x_original_url_); + SetHostPathFromBase(RemoveScheme(url)); +} + +const QuicBackendResponse* QuicMemoryCacheBackend::GetResponse( + absl::string_view host, absl::string_view path) const { + QuicWriterMutexLock lock(&response_mutex_); + + auto it = responses_.find(GetKey(host, path)); + if (it == responses_.end()) { + uint64_t ignored = 0; + if (generate_bytes_response_) { + if (absl::SimpleAtoi(absl::string_view(path.data() + 1, path.size() - 1), + &ignored)) { + // The actual parsed length is ignored here and will be recomputed + // by the caller. + return generate_bytes_response_.get(); + } + } + QUIC_DVLOG(1) << "Get response for resource failed: host " << host + << " path " << path; + if (default_response_) { + return default_response_.get(); + } + return nullptr; + } + return it->second.get(); +} + +using ServerPushInfo = QuicBackendResponse::ServerPushInfo; +using SpecialResponseType = QuicBackendResponse::SpecialResponseType; + +void QuicMemoryCacheBackend::AddSimpleResponse(absl::string_view host, + absl::string_view path, + int response_code, + absl::string_view body) { + Http2HeaderBlock response_headers; + response_headers[":status"] = absl::StrCat(response_code); + response_headers["content-length"] = absl::StrCat(body.length()); + AddResponse(host, path, std::move(response_headers), body); +} + +void QuicMemoryCacheBackend::AddSimpleResponseWithServerPushResources( + absl::string_view host, absl::string_view path, int response_code, + absl::string_view body, std::list push_resources) { + AddSimpleResponse(host, path, response_code, body); + MaybeAddServerPushResources(host, path, push_resources); +} + +void QuicMemoryCacheBackend::AddDefaultResponse(QuicBackendResponse* response) { + QuicWriterMutexLock lock(&response_mutex_); + default_response_.reset(response); +} + +void QuicMemoryCacheBackend::AddResponse(absl::string_view host, + absl::string_view path, + Http2HeaderBlock response_headers, + absl::string_view response_body) { + AddResponseImpl(host, path, QuicBackendResponse::REGULAR_RESPONSE, + std::move(response_headers), response_body, + Http2HeaderBlock(), std::vector()); +} + +void QuicMemoryCacheBackend::AddResponse(absl::string_view host, + absl::string_view path, + Http2HeaderBlock response_headers, + absl::string_view response_body, + Http2HeaderBlock response_trailers) { + AddResponseImpl(host, path, QuicBackendResponse::REGULAR_RESPONSE, + std::move(response_headers), response_body, + std::move(response_trailers), + std::vector()); +} + +bool QuicMemoryCacheBackend::SetResponseDelay(absl::string_view host, + absl::string_view path, + QuicTime::Delta delay) { + QuicWriterMutexLock lock(&response_mutex_); + auto it = responses_.find(GetKey(host, path)); + if (it == responses_.end()) return false; + + it->second->set_delay(delay); + return true; +} + +void QuicMemoryCacheBackend::AddResponseWithEarlyHints( + absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, absl::string_view response_body, + const std::vector& early_hints) { + AddResponseImpl(host, path, QuicBackendResponse::REGULAR_RESPONSE, + std::move(response_headers), response_body, + Http2HeaderBlock(), early_hints); +} + +void QuicMemoryCacheBackend::AddSpecialResponse( + absl::string_view host, absl::string_view path, + SpecialResponseType response_type) { + AddResponseImpl(host, path, response_type, Http2HeaderBlock(), "", + Http2HeaderBlock(), std::vector()); +} + +void QuicMemoryCacheBackend::AddSpecialResponse( + absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, absl::string_view response_body, + SpecialResponseType response_type) { + AddResponseImpl(host, path, response_type, std::move(response_headers), + response_body, Http2HeaderBlock(), + std::vector()); +} + +QuicMemoryCacheBackend::QuicMemoryCacheBackend() : cache_initialized_(false) {} + +bool QuicMemoryCacheBackend::InitializeBackend( + const std::string& cache_directory) { + if (cache_directory.empty()) { + QUIC_BUG(quic_bug_10932_1) << "cache_directory must not be empty."; + return false; + } + QUIC_LOG(INFO) + << "Attempting to initialize QuicMemoryCacheBackend from directory: " + << cache_directory; + std::vector files; + if (!quiche::EnumerateDirectoryRecursively(cache_directory, files)) { + QUIC_BUG(QuicMemoryCacheBackend unreadable directory) + << "Can't read QuicMemoryCacheBackend directory: " << cache_directory; + return false; + } + std::list> resource_files; + for (const auto& filename : files) { + std::unique_ptr resource_file(new ResourceFile(filename)); + + // Tease apart filename into host and path. + std::string base(resource_file->file_name()); + // Transform windows path separators to URL path separators. + for (size_t i = 0; i < base.length(); ++i) { + if (base[i] == '\\') { + base[i] = '/'; + } + } + base.erase(0, cache_directory.length()); + if (base[0] == '/') { + base.erase(0, 1); + } + + resource_file->SetHostPathFromBase(base); + resource_file->Read(); + + AddResponse(resource_file->host(), resource_file->path(), + resource_file->spdy_headers().Clone(), resource_file->body()); + + resource_files.push_back(std::move(resource_file)); + } + + for (const auto& resource_file : resource_files) { + std::list push_resources; + for (const auto& push_url : resource_file->push_urls()) { + QuicUrl url(push_url); + const QuicBackendResponse* response = GetResponse(url.host(), url.path()); + if (!response) { + QUIC_BUG(quic_bug_10932_2) + << "Push URL '" << push_url << "' not found."; + return false; + } + push_resources.push_back(ServerPushInfo(url, response->headers().Clone(), + kV3LowestPriority, + (std::string(response->body())))); + } + MaybeAddServerPushResources(resource_file->host(), resource_file->path(), + push_resources); + } + + cache_initialized_ = true; + return true; +} + +void QuicMemoryCacheBackend::GenerateDynamicResponses() { + QuicWriterMutexLock lock(&response_mutex_); + // Add a generate bytes response. + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + generate_bytes_response_ = std::make_unique(); + generate_bytes_response_->set_headers(std::move(response_headers)); + generate_bytes_response_->set_response_type( + QuicBackendResponse::GENERATE_BYTES); +} + +void QuicMemoryCacheBackend::EnableWebTransport() { + enable_webtransport_ = true; +} + +bool QuicMemoryCacheBackend::IsBackendInitialized() const { + return cache_initialized_; +} + +void QuicMemoryCacheBackend::FetchResponseFromBackend( + const Http2HeaderBlock& request_headers, + const std::string& /*request_body*/, + QuicSimpleServerBackend::RequestHandler* quic_stream) { + const QuicBackendResponse* quic_response = nullptr; + // Find response in cache. If not found, send error response. + auto authority = request_headers.find(":authority"); + auto path = request_headers.find(":path"); + if (authority != request_headers.end() && path != request_headers.end()) { + quic_response = GetResponse(authority->second, path->second); + } + + std::string request_url; + if (authority != request_headers.end()) { + request_url = std::string(authority->second); + } + if (path != request_headers.end()) { + request_url += std::string(path->second); + } + QUIC_DVLOG(1) + << "Fetching QUIC response from backend in-memory cache for url " + << request_url; + quic_stream->OnResponseBackendComplete(quic_response); +} + +// The memory cache does not have a per-stream handler +void QuicMemoryCacheBackend::CloseBackendResponseStream( + QuicSimpleServerBackend::RequestHandler* /*quic_stream*/) {} + +std::list QuicMemoryCacheBackend::GetServerPushResources( + std::string request_url) { + QuicWriterMutexLock lock(&response_mutex_); + + std::list resources; + auto resource_range = server_push_resources_.equal_range(request_url); + for (auto it = resource_range.first; it != resource_range.second; ++it) { + resources.push_back(it->second); + } + QUIC_DVLOG(1) << "Found " << resources.size() << " push resources for " + << request_url; + return resources; +} + +QuicMemoryCacheBackend::WebTransportResponse +QuicMemoryCacheBackend::ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session) { + if (!SupportsWebTransport()) { + return QuicSimpleServerBackend::ProcessWebTransportRequest(request_headers, + session); + } + + auto path_it = request_headers.find(":path"); + if (path_it == request_headers.end()) { + WebTransportResponse response; + response.response_headers[":status"] = "400"; + return response; + } + absl::string_view path = path_it->second; + if (path == "/echo") { + WebTransportResponse response; + response.response_headers[":status"] = "200"; + response.visitor = + std::make_unique(session); + return response; + } + + WebTransportResponse response; + response.response_headers[":status"] = "404"; + return response; +} + +QuicMemoryCacheBackend::~QuicMemoryCacheBackend() { + { + QuicWriterMutexLock lock(&response_mutex_); + responses_.clear(); + } +} + +void QuicMemoryCacheBackend::AddResponseImpl( + absl::string_view host, absl::string_view path, + SpecialResponseType response_type, Http2HeaderBlock response_headers, + absl::string_view response_body, Http2HeaderBlock response_trailers, + const std::vector& early_hints) { + QuicWriterMutexLock lock(&response_mutex_); + + QUICHE_DCHECK(!host.empty()) + << "Host must be populated, e.g. \"www.google.com\""; + std::string key = GetKey(host, path); + if (responses_.contains(key)) { + QUIC_BUG(quic_bug_10932_3) + << "Response for '" << key << "' already exists!"; + return; + } + auto new_response = std::make_unique(); + new_response->set_response_type(response_type); + new_response->set_headers(std::move(response_headers)); + new_response->set_body(response_body); + new_response->set_trailers(std::move(response_trailers)); + for (auto& headers : early_hints) { + new_response->AddEarlyHints(headers); + } + QUIC_DVLOG(1) << "Add response with key " << key; + responses_[key] = std::move(new_response); +} + +std::string QuicMemoryCacheBackend::GetKey(absl::string_view host, + absl::string_view path) const { + std::string host_string = std::string(host); + size_t port = host_string.find(':'); + if (port != std::string::npos) + host_string = std::string(host_string.c_str(), port); + return host_string + std::string(path); +} + +void QuicMemoryCacheBackend::MaybeAddServerPushResources( + absl::string_view request_host, absl::string_view request_path, + std::list push_resources) { + std::string request_url = GetKey(request_host, request_path); + + for (const auto& push_resource : push_resources) { + if (PushResourceExistsInCache(request_url, push_resource)) { + continue; + } + + QUIC_DVLOG(1) << "Add request-resource association: request url " + << request_url << " push url " + << push_resource.request_url.ToString() + << " response headers " + << push_resource.headers.DebugString(); + { + QuicWriterMutexLock lock(&response_mutex_); + server_push_resources_.insert(std::make_pair(request_url, push_resource)); + } + std::string host = push_resource.request_url.host(); + if (host.empty()) { + host = std::string(request_host); + } + std::string path = push_resource.request_url.path(); + bool found_existing_response = false; + { + QuicWriterMutexLock lock(&response_mutex_); + found_existing_response = responses_.contains(GetKey(host, path)); + } + if (!found_existing_response) { + // Add a server push response to responses map, if it is not in the map. + absl::string_view body = push_resource.body; + QUIC_DVLOG(1) << "Add response for push resource: host " << host + << " path " << path; + AddResponse(host, path, push_resource.headers.Clone(), body); + } + } +} + +bool QuicMemoryCacheBackend::PushResourceExistsInCache( + std::string original_request_url, ServerPushInfo resource) { + QuicWriterMutexLock lock(&response_mutex_); + auto resource_range = + server_push_resources_.equal_range(original_request_url); + for (auto it = resource_range.first; it != resource_range.second; ++it) { + ServerPushInfo push_resource = it->second; + if (push_resource.request_url.ToString() == + resource.request_url.ToString()) { + return true; + } + } + return false; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_memory_cache_backend.h b/quiche/quic/tools/quic_memory_cache_backend.h new file mode 100644 index 000000000000..ccdaf59a299f --- /dev/null +++ b/quiche/quic/tools/quic_memory_cache_backend.h @@ -0,0 +1,210 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_MEMORY_CACHE_BACKEND_H_ +#define QUICHE_QUIC_TOOLS_QUIC_MEMORY_CACHE_BACKEND_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/platform/api/quic_mutex.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +// In-memory cache for HTTP responses. +// Reads from disk cache generated by: +// `wget -p --save_headers ` +class QuicMemoryCacheBackend : public QuicSimpleServerBackend { + public: + // Class to manage loading a resource file into memory. There are + // two uses: called by InitializeBackend to load resources + // from files, and recursively called when said resources specify + // server push associations. + class ResourceFile { + public: + explicit ResourceFile(const std::string& file_name); + ResourceFile(const ResourceFile&) = delete; + ResourceFile& operator=(const ResourceFile&) = delete; + virtual ~ResourceFile(); + + void Read(); + + // |base| is |file_name_| with |cache_directory| prefix stripped. + void SetHostPathFromBase(absl::string_view base); + + const std::string& file_name() { return file_name_; } + + absl::string_view host() { return host_; } + + absl::string_view path() { return path_; } + + const spdy::Http2HeaderBlock& spdy_headers() { return spdy_headers_; } + + absl::string_view body() { return body_; } + + const std::vector& push_urls() { return push_urls_; } + + private: + void HandleXOriginalUrl(); + absl::string_view RemoveScheme(absl::string_view url); + + std::string file_name_; + std::string file_contents_; + absl::string_view body_; + spdy::Http2HeaderBlock spdy_headers_; + absl::string_view x_original_url_; + std::vector push_urls_; + std::string host_; + std::string path_; + }; + + QuicMemoryCacheBackend(); + QuicMemoryCacheBackend(const QuicMemoryCacheBackend&) = delete; + QuicMemoryCacheBackend& operator=(const QuicMemoryCacheBackend&) = delete; + ~QuicMemoryCacheBackend() override; + + // Retrieve a response from this cache for a given host and path.. + // If no appropriate response exists, nullptr is returned. + const QuicBackendResponse* GetResponse(absl::string_view host, + absl::string_view path) const; + + // Adds a simple response to the cache. The response headers will + // only contain the "content-length" header with the length of |body|. + void AddSimpleResponse(absl::string_view host, absl::string_view path, + int response_code, absl::string_view body); + + // Add a simple response to the cache as AddSimpleResponse() does, and add + // some server push resources(resource path, corresponding response status and + // path) associated with it. + // Push resource implicitly come from the same host. + // TODO(b/171463363): Remove. + void AddSimpleResponseWithServerPushResources( + absl::string_view host, absl::string_view path, int response_code, + absl::string_view body, + std::list push_resources); + + // Add a response to the cache. + void AddResponse(absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, + absl::string_view response_body); + + // Add a response, with trailers, to the cache. + void AddResponse(absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, + absl::string_view response_body, + spdy::Http2HeaderBlock response_trailers); + + // Add a response, with 103 Early Hints, to the cache. + void AddResponseWithEarlyHints( + absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, absl::string_view response_body, + const std::vector& early_hints); + + // Simulate a special behavior at a particular path. + void AddSpecialResponse( + absl::string_view host, absl::string_view path, + QuicBackendResponse::SpecialResponseType response_type); + + void AddSpecialResponse( + absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, absl::string_view response_body, + QuicBackendResponse::SpecialResponseType response_type); + + // Finds a response with the given host and path, and assign it a simulated + // delay. Returns true if the requisite response was found and the delay was + // set. + bool SetResponseDelay(absl::string_view host, absl::string_view path, + QuicTime::Delta delay); + + // Sets a default response in case of cache misses. Takes ownership of + // 'response'. + void AddDefaultResponse(QuicBackendResponse* response); + + // Once called, URLs which have a numeric path will send a dynamically + // generated response of that many bytes. + void GenerateDynamicResponses(); + + void EnableWebTransport(); + + // Find all the server push resources associated with |request_url|. + // TODO(b/171463363): Remove. + std::list GetServerPushResources( + std::string request_url); + + // Implements the functions for interface QuicSimpleServerBackend + // |cache_cirectory| can be generated using `wget -p --save-headers `. + bool InitializeBackend(const std::string& cache_directory) override; + bool IsBackendInitialized() const override; + void FetchResponseFromBackend( + const spdy::Http2HeaderBlock& request_headers, + const std::string& request_body, + QuicSimpleServerBackend::RequestHandler* quic_stream) override; + void CloseBackendResponseStream( + QuicSimpleServerBackend::RequestHandler* quic_stream) override; + WebTransportResponse ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session) override; + bool SupportsWebTransport() override { return enable_webtransport_; } + + private: + void AddResponseImpl(absl::string_view host, absl::string_view path, + QuicBackendResponse::SpecialResponseType response_type, + spdy::Http2HeaderBlock response_headers, + absl::string_view response_body, + spdy::Http2HeaderBlock response_trailers, + const std::vector& early_hints); + + std::string GetKey(absl::string_view host, absl::string_view path) const; + + // Add some server push urls with given responses for specified + // request if these push resources are not associated with this request yet. + // TODO(b/171463363): Remove. + void MaybeAddServerPushResources( + absl::string_view request_host, absl::string_view request_path, + std::list push_resources); + + // Check if push resource(push_host/push_path) associated with given request + // url already exists in server push map. + // TODO(b/171463363): Remove. + bool PushResourceExistsInCache(std::string original_request_url, + QuicBackendResponse::ServerPushInfo resource); + + // Cached responses. + absl::flat_hash_map> + responses_ QUIC_GUARDED_BY(response_mutex_); + + // The default response for cache misses, if set. + std::unique_ptr default_response_ + QUIC_GUARDED_BY(response_mutex_); + + // The generate bytes response, if set. + std::unique_ptr generate_bytes_response_ + QUIC_GUARDED_BY(response_mutex_); + + // A map from request URL to associated server push responses (if any). + // TODO(b/171463363): Remove. + std::multimap + server_push_resources_ QUIC_GUARDED_BY(response_mutex_); + + // Protects against concurrent access from test threads setting responses, and + // server threads accessing those responses. + mutable QuicMutex response_mutex_; + bool cache_initialized_; + + bool enable_webtransport_ = false; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_MEMORY_CACHE_BACKEND_H_ diff --git a/quiche/quic/tools/quic_memory_cache_backend_test.cc b/quiche/quic/tools/quic_memory_cache_backend_test.cc new file mode 100644 index 000000000000..a75b46ec58c2 --- /dev/null +++ b/quiche/quic/tools/quic_memory_cache_backend_test.cc @@ -0,0 +1,264 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_memory_cache_backend.h" + +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/common/platform/api/quiche_file_utils.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace quic { +namespace test { + +namespace { +using Response = QuicBackendResponse; +using ServerPushInfo = QuicBackendResponse::ServerPushInfo; +} // namespace + +class QuicMemoryCacheBackendTest : public QuicTest { + protected: + void CreateRequest(std::string host, std::string path, + spdy::Http2HeaderBlock* headers) { + (*headers)[":method"] = "GET"; + (*headers)[":path"] = path; + (*headers)[":authority"] = host; + (*headers)[":scheme"] = "https"; + } + + std::string CacheDirectory() { + return quiche::test::QuicheGetTestMemoryCachePath(); + } + + QuicMemoryCacheBackend cache_; +}; + +TEST_F(QuicMemoryCacheBackendTest, GetResponseNoMatch) { + const Response* response = + cache_.GetResponse("mail.google.com", "/index.html"); + ASSERT_FALSE(response); +} + +TEST_F(QuicMemoryCacheBackendTest, AddSimpleResponseGetResponse) { + std::string response_body("hello response"); + cache_.AddSimpleResponse("www.google.com", "/", 200, response_body); + + spdy::Http2HeaderBlock request_headers; + CreateRequest("www.google.com", "/", &request_headers); + const Response* response = cache_.GetResponse("www.google.com", "/"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("200", response->headers().find(":status")->second); + EXPECT_EQ(response_body.size(), response->body().length()); +} + +TEST_F(QuicMemoryCacheBackendTest, AddResponse) { + const std::string kRequestHost = "www.foo.com"; + const std::string kRequestPath = "/"; + const std::string kResponseBody("hello response"); + + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + response_headers["content-length"] = absl::StrCat(kResponseBody.size()); + + spdy::Http2HeaderBlock response_trailers; + response_trailers["key-1"] = "value-1"; + response_trailers["key-2"] = "value-2"; + response_trailers["key-3"] = "value-3"; + + cache_.AddResponse(kRequestHost, "/", response_headers.Clone(), kResponseBody, + response_trailers.Clone()); + + const Response* response = cache_.GetResponse(kRequestHost, kRequestPath); + EXPECT_EQ(response->headers(), response_headers); + EXPECT_EQ(response->body(), kResponseBody); + EXPECT_EQ(response->trailers(), response_trailers); +} + +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_ReadsCacheDir DISABLED_ReadsCacheDir +#else +#define MAYBE_ReadsCacheDir ReadsCacheDir +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_ReadsCacheDir) { + cache_.InitializeBackend(CacheDirectory()); + const Response* response = + cache_.GetResponse("test.example.com", "/index.html"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("200", response->headers().find(":status")->second); + // Connection headers are not valid in HTTP/2. + EXPECT_FALSE(response->headers().contains("connection")); + EXPECT_LT(0U, response->body().length()); +} + +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_UsesOriginalUrl DISABLED_UsesOriginalUrl +#else +#define MAYBE_UsesOriginalUrl UsesOriginalUrl +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_UsesOriginalUrl) { + cache_.InitializeBackend(CacheDirectory()); + const Response* response = + cache_.GetResponse("test.example.com", "/site_map.html"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("200", response->headers().find(":status")->second); + // Connection headers are not valid in HTTP/2. + EXPECT_FALSE(response->headers().contains("connection")); + EXPECT_LT(0U, response->body().length()); +} + +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_UsesOriginalUrlOnly DISABLED_UsesOriginalUrlOnly +#else +#define MAYBE_UsesOriginalUrlOnly UsesOriginalUrlOnly +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_UsesOriginalUrlOnly) { + // Tests that if the URL cannot be inferred correctly from the path + // because the directory does not include the hostname, that the + // X-Original-Url header's value will be used. + std::string dir; + std::string path = "map.html"; + std::vector files; + ASSERT_TRUE(quiche::EnumerateDirectoryRecursively(CacheDirectory(), files)); + for (const std::string& file : files) { + if (absl::EndsWithIgnoreCase(file, "map.html")) { + dir = file; + dir.erase(dir.length() - path.length() - 1); + break; + } + } + ASSERT_NE("", dir); + + cache_.InitializeBackend(dir); + const Response* response = + cache_.GetResponse("test.example.com", "/site_map.html"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("200", response->headers().find(":status")->second); + // Connection headers are not valid in HTTP/2. + EXPECT_FALSE(response->headers().contains("connection")); + EXPECT_LT(0U, response->body().length()); +} + +TEST_F(QuicMemoryCacheBackendTest, DefaultResponse) { + // Verify GetResponse returns nullptr when no default is set. + const Response* response = cache_.GetResponse("www.google.com", "/"); + ASSERT_FALSE(response); + + // Add a default response. + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + response_headers["content-length"] = "0"; + Response* default_response = new Response; + default_response->set_headers(std::move(response_headers)); + cache_.AddDefaultResponse(default_response); + + // Now we should get the default response for the original request. + response = cache_.GetResponse("www.google.com", "/"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("200", response->headers().find(":status")->second); + + // Now add a set response for / and make sure it is returned + cache_.AddSimpleResponse("www.google.com", "/", 302, ""); + response = cache_.GetResponse("www.google.com", "/"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("302", response->headers().find(":status")->second); + + // We should get the default response for other requests. + response = cache_.GetResponse("www.google.com", "/asd"); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ("200", response->headers().find(":status")->second); +} + +TEST_F(QuicMemoryCacheBackendTest, AddSimpleResponseWithServerPushResources) { + std::string request_host = "www.foo.com"; + std::string response_body("hello response"); + const size_t kNumResources = 5; + int NumResources = 5; + std::list push_resources; + std::string scheme = "http"; + for (int i = 0; i < NumResources; ++i) { + std::string path = absl::StrCat("/server_push_src", i); + std::string url = scheme + "://" + request_host + path; + QuicUrl resource_url(url); + std::string body = + absl::StrCat("This is server push response body for ", path); + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + response_headers["content-length"] = absl::StrCat(body.size()); + push_resources.push_back( + ServerPushInfo(resource_url, response_headers.Clone(), i, body)); + } + + cache_.AddSimpleResponseWithServerPushResources( + request_host, "/", 200, response_body, push_resources); + + std::string request_url = request_host + "/"; + std::list resources = + cache_.GetServerPushResources(request_url); + ASSERT_EQ(kNumResources, resources.size()); + for (const auto& push_resource : push_resources) { + ServerPushInfo resource = resources.front(); + EXPECT_EQ(resource.request_url.ToString(), + push_resource.request_url.ToString()); + EXPECT_EQ(resource.priority, push_resource.priority); + resources.pop_front(); + } +} + +TEST_F(QuicMemoryCacheBackendTest, GetServerPushResourcesAndPushResponses) { + std::string request_host = "www.foo.com"; + std::string response_body("hello response"); + const size_t kNumResources = 4; + int NumResources = 4; + std::string scheme = "http"; + std::string push_response_status[kNumResources] = {"200", "200", "301", + "404"}; + std::list push_resources; + for (int i = 0; i < NumResources; ++i) { + std::string path = absl::StrCat("/server_push_src", i); + std::string url = scheme + "://" + request_host + path; + QuicUrl resource_url(url); + std::string body = "This is server push response body for " + path; + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = push_response_status[i]; + response_headers["content-length"] = absl::StrCat(body.size()); + push_resources.push_back( + ServerPushInfo(resource_url, response_headers.Clone(), i, body)); + } + cache_.AddSimpleResponseWithServerPushResources( + request_host, "/", 200, response_body, push_resources); + std::string request_url = request_host + "/"; + std::list resources = + cache_.GetServerPushResources(request_url); + ASSERT_EQ(kNumResources, resources.size()); + int i = 0; + for (const auto& push_resource : push_resources) { + QuicUrl url = resources.front().request_url; + std::string host = url.host(); + std::string path = url.path(); + const Response* response = cache_.GetResponse(host, path); + ASSERT_TRUE(response); + ASSERT_TRUE(response->headers().contains(":status")); + EXPECT_EQ(push_response_status[i++], + response->headers().find(":status")->second); + EXPECT_EQ(push_resource.body, response->body()); + resources.pop_front(); + } +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/quic_name_lookup.cc b/quiche/quic/tools/quic_name_lookup.cc new file mode 100644 index 000000000000..dc1d918f5bb6 --- /dev/null +++ b/quiche/quic/tools/quic_name_lookup.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_name_lookup.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" + +#if defined(_WIN32) +#include +#include +#else // else assume POSIX +#include +#include +#include +#endif + +namespace quic::tools { + +QuicSocketAddress LookupAddress(int address_family_for_lookup, std::string host, + std::string port) { + addrinfo hint; + memset(&hint, 0, sizeof(hint)); + hint.ai_family = address_family_for_lookup; + hint.ai_protocol = IPPROTO_UDP; + + addrinfo* info_list = nullptr; + int result = getaddrinfo(host.c_str(), port.c_str(), &hint, &info_list); + if (result != 0) { + QUIC_LOG(ERROR) << "Failed to look up " << host << ": " + << gai_strerror(result); + return QuicSocketAddress(); + } + + QUICHE_CHECK(info_list != nullptr); + std::unique_ptr info_list_owned( + info_list, [](addrinfo* ai) { freeaddrinfo(ai); }); + return QuicSocketAddress(info_list->ai_addr, info_list->ai_addrlen); +} + +QuicSocketAddress LookupAddress(int address_family_for_lookup, + const QuicServerId& server_id) { + return LookupAddress(address_family_for_lookup, + std::string(server_id.GetHostWithoutIpv6Brackets()), + absl::StrCat(server_id.port())); +} + +} // namespace quic::tools diff --git a/quiche/quic/tools/quic_name_lookup.h b/quiche/quic/tools/quic_name_lookup.h new file mode 100644 index 000000000000..d9fd8e8c1edc --- /dev/null +++ b/quiche/quic/tools/quic_name_lookup.h @@ -0,0 +1,35 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_NAME_LOOKUP_H_ +#define QUICHE_QUIC_TOOLS_QUIC_NAME_LOOKUP_H_ + +#include + +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +class QuicServerId; + +namespace tools { + +quic::QuicSocketAddress LookupAddress(int address_family_for_lookup, + std::string host, std::string port); + +quic::QuicSocketAddress LookupAddress(int address_family_for_lookup, + const QuicServerId& server_id); + +inline QuicSocketAddress LookupAddress(std::string host, std::string port) { + return LookupAddress(0, host, port); +} + +inline QuicSocketAddress LookupAddress(const QuicServerId& server_id) { + return LookupAddress(0, server_id); +} + +} // namespace tools +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_NAME_LOOKUP_H_ diff --git a/quiche/quic/tools/quic_packet_printer_bin.cc b/quiche/quic/tools/quic_packet_printer_bin.cc new file mode 100644 index 000000000000..92a3c4c5790f --- /dev/null +++ b/quiche/quic/tools/quic_packet_printer_bin.cc @@ -0,0 +1,287 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// clang-format off + +// Dumps out the decryptable contents of a QUIC packet in a human-readable way. +// If the packet is null encrypted, this will dump full packet contents. +// Otherwise it will dump the header, and fail with an error that the +// packet is undecryptable. +// +// Usage: quic_packet_printer server|client +// +// Example input: +// quic_packet_printer server 0c6b810308320f24c004a939a38a2e3fd6ca589917f200400201b80b0100501c0700060003023d0000001c00556e656e637279707465642073747265616d2064617461207365656e +// +// Example output: +// OnPacket +// OnUnauthenticatedPublicHeader +// OnUnauthenticatedHeader: { connection_id: 13845207862000976235, connection_id_length:8, packet_number_length:1, multipath_flag: 0, reset_flag: 0, version_flag: 0, path_id: , packet_number: 4 } +// OnDecryptedPacket +// OnPacketHeader +// OnAckFrame: largest_observed: 1 ack_delay_time: 3000 missing_packets: [ ] is_truncated: 0 received_packets: [ 1 at 466016 ] +// OnStopWaitingFrame +// OnConnectionCloseFrame: error_code { 61 } error_details { Unencrypted stream data seen } + +// clang-format on + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_framer.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/quiche_text_utils.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, quic_version, "", + "If set, specify the QUIC version to use."); + +namespace quic { + +class QuicPacketPrinter : public QuicFramerVisitorInterface { + public: + explicit QuicPacketPrinter(QuicFramer* framer) : framer_(framer) {} + + void OnError(QuicFramer* framer) override { + std::cerr << "OnError: " << QuicErrorCodeToString(framer->error()) + << " detail: " << framer->detailed_error() << "\n"; + } + bool OnProtocolVersionMismatch(ParsedQuicVersion received_version) override { + framer_->set_version(received_version); + std::cerr << "OnProtocolVersionMismatch: " + << ParsedQuicVersionToString(received_version) << "\n"; + return true; + } + void OnPacket() override { std::cerr << "OnPacket\n"; } + void OnPublicResetPacket(const QuicPublicResetPacket& /*packet*/) override { + std::cerr << "OnPublicResetPacket\n"; + } + void OnVersionNegotiationPacket( + const QuicVersionNegotiationPacket& /*packet*/) override { + std::cerr << "OnVersionNegotiationPacket\n"; + } + void OnRetryPacket(QuicConnectionId /*original_connection_id*/, + QuicConnectionId /*new_connection_id*/, + absl::string_view /*retry_token*/, + absl::string_view /*retry_integrity_tag*/, + absl::string_view /*retry_without_tag*/) override { + std::cerr << "OnRetryPacket\n"; + } + bool OnUnauthenticatedPublicHeader( + const QuicPacketHeader& /*header*/) override { + std::cerr << "OnUnauthenticatedPublicHeader\n"; + return true; + } + bool OnUnauthenticatedHeader(const QuicPacketHeader& header) override { + std::cerr << "OnUnauthenticatedHeader: " << header; + return true; + } + void OnDecryptedPacket(size_t /*length*/, EncryptionLevel level) override { + // This only currently supports "decrypting" null encrypted packets. + QUICHE_DCHECK_EQ(ENCRYPTION_INITIAL, level); + std::cerr << "OnDecryptedPacket\n"; + } + bool OnPacketHeader(const QuicPacketHeader& /*header*/) override { + std::cerr << "OnPacketHeader\n"; + return true; + } + void OnCoalescedPacket(const QuicEncryptedPacket& /*packet*/) override { + std::cerr << "OnCoalescedPacket\n"; + } + void OnUndecryptablePacket(const QuicEncryptedPacket& /*packet*/, + EncryptionLevel /*decryption_level*/, + bool /*has_decryption_key*/) override { + std::cerr << "OnUndecryptablePacket\n"; + } + bool OnStreamFrame(const QuicStreamFrame& frame) override { + std::cerr << "OnStreamFrame: " << frame; + std::cerr << " data: { " + << absl::BytesToHexString( + absl::string_view(frame.data_buffer, frame.data_length)) + << " }\n"; + return true; + } + bool OnCryptoFrame(const QuicCryptoFrame& frame) override { + std::cerr << "OnCryptoFrame: " << frame; + std::cerr << " data: { " + << absl::BytesToHexString( + absl::string_view(frame.data_buffer, frame.data_length)) + << " }\n"; + return true; + } + bool OnAckFrameStart(QuicPacketNumber largest_acked, + QuicTime::Delta /*ack_delay_time*/) override { + std::cerr << "OnAckFrameStart, largest_acked: " << largest_acked; + return true; + } + bool OnAckRange(QuicPacketNumber start, QuicPacketNumber end) override { + std::cerr << "OnAckRange: [" << start << ", " << end << ")"; + return true; + } + bool OnAckTimestamp(QuicPacketNumber packet_number, + QuicTime timestamp) override { + std::cerr << "OnAckTimestamp: [" << packet_number << ", " + << timestamp.ToDebuggingValue() << ")"; + return true; + } + bool OnAckFrameEnd(QuicPacketNumber start, + const absl::optional& ecn_counts) override { + std::cerr << "OnAckFrameEnd, start: " << start; + if (ecn_counts.has_value()) { + std::cerr << " ECN counts: " << ecn_counts->ToString(); + } + return true; + } + bool OnStopWaitingFrame(const QuicStopWaitingFrame& frame) override { + std::cerr << "OnStopWaitingFrame: " << frame; + return true; + } + bool OnPaddingFrame(const QuicPaddingFrame& frame) override { + std::cerr << "OnPaddingFrame: " << frame; + return true; + } + bool OnPingFrame(const QuicPingFrame& frame) override { + std::cerr << "OnPingFrame: " << frame; + return true; + } + bool OnRstStreamFrame(const QuicRstStreamFrame& frame) override { + std::cerr << "OnRstStreamFrame: " << frame; + return true; + } + bool OnConnectionCloseFrame(const QuicConnectionCloseFrame& frame) override { + // The frame printout will indicate whether it's a Google QUIC + // CONNECTION_CLOSE, IETF QUIC CONNECTION_CLOSE/Transport, or IETF QUIC + // CONNECTION_CLOSE/Application frame. + std::cerr << "OnConnectionCloseFrame: " << frame; + return true; + } + bool OnNewConnectionIdFrame(const QuicNewConnectionIdFrame& frame) override { + std::cerr << "OnNewConnectionIdFrame: " << frame; + return true; + } + bool OnRetireConnectionIdFrame( + const QuicRetireConnectionIdFrame& frame) override { + std::cerr << "OnRetireConnectionIdFrame: " << frame; + return true; + } + bool OnNewTokenFrame(const QuicNewTokenFrame& frame) override { + std::cerr << "OnNewTokenFrame: " << frame; + return true; + } + bool OnStopSendingFrame(const QuicStopSendingFrame& frame) override { + std::cerr << "OnStopSendingFrame: " << frame; + return true; + } + bool OnPathChallengeFrame(const QuicPathChallengeFrame& frame) override { + std::cerr << "OnPathChallengeFrame: " << frame; + return true; + } + bool OnPathResponseFrame(const QuicPathResponseFrame& frame) override { + std::cerr << "OnPathResponseFrame: " << frame; + return true; + } + bool OnGoAwayFrame(const QuicGoAwayFrame& frame) override { + std::cerr << "OnGoAwayFrame: " << frame; + return true; + } + bool OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) override { + std::cerr << "OnMaxStreamsFrame: " << frame; + return true; + } + bool OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) override { + std::cerr << "OnStreamsBlockedFrame: " << frame; + return true; + } + bool OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) override { + std::cerr << "OnWindowUpdateFrame: " << frame; + return true; + } + bool OnBlockedFrame(const QuicBlockedFrame& frame) override { + std::cerr << "OnBlockedFrame: " << frame; + return true; + } + bool OnMessageFrame(const QuicMessageFrame& frame) override { + std::cerr << "OnMessageFrame: " << frame; + return true; + } + bool OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) override { + std::cerr << "OnHandshakeDoneFrame: " << frame; + return true; + } + bool OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) override { + std::cerr << "OnAckFrequencyFrame: " << frame; + return true; + } + void OnPacketComplete() override { std::cerr << "OnPacketComplete\n"; } + bool IsValidStatelessResetToken( + const StatelessResetToken& /*token*/) const override { + std::cerr << "IsValidStatelessResetToken\n"; + return false; + } + void OnAuthenticatedIetfStatelessResetPacket( + const QuicIetfStatelessResetPacket& /*packet*/) override { + std::cerr << "OnAuthenticatedIetfStatelessResetPacket\n"; + } + void OnKeyUpdate(KeyUpdateReason reason) override { + std::cerr << "OnKeyUpdate: " << reason << "\n"; + } + void OnDecryptedFirstPacketInKeyPhase() override { + std::cerr << "OnDecryptedFirstPacketInKeyPhase\n"; + } + std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() + override { + std::cerr << "AdvanceKeysAndCreateCurrentOneRttDecrypter\n"; + return nullptr; + } + std::unique_ptr CreateCurrentOneRttEncrypter() override { + std::cerr << "CreateCurrentOneRttEncrypter\n"; + return nullptr; + } + + private: + QuicFramer* framer_; // Unowned. +}; + +} // namespace quic + +int main(int argc, char* argv[]) { + const char* usage = "Usage: quic_packet_printer client|server "; + std::vector args = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + + if (args.size() < 2) { + quiche::QuichePrintCommandLineFlagHelp(usage); + return 1; + } + + std::string perspective_string = args[0]; + quic::Perspective perspective; + if (perspective_string == "client") { + perspective = quic::Perspective::IS_CLIENT; + } else if (perspective_string == "server") { + perspective = quic::Perspective::IS_SERVER; + } else { + std::cerr << "Invalid perspective" << std::endl; + quiche::QuichePrintCommandLineFlagHelp(usage); + return 1; + } + std::string hex = absl::HexStringToBytes(args[1]); + quic::ParsedQuicVersionVector versions = quic::AllSupportedVersions(); + // Fake a time since we're not actually generating acks. + quic::QuicTime start(quic::QuicTime::Zero()); + quic::QuicFramer framer(versions, start, perspective, + quic::kQuicDefaultConnectionIdLength); + const quic::ParsedQuicVersion& version = quic::ParseQuicVersionString( + quiche::GetQuicheCommandLineFlag(FLAGS_quic_version)); + if (version != quic::ParsedQuicVersion::Unsupported()) { + framer.set_version(version); + } + quic::QuicPacketPrinter visitor(&framer); + framer.set_visitor(&visitor); + quic::QuicEncryptedPacket encrypted(hex.c_str(), hex.length()); + return framer.ProcessPacket(encrypted); +} diff --git a/quiche/quic/tools/quic_reject_reason_decoder_bin.cc b/quiche/quic/tools/quic_reject_reason_decoder_bin.cc new file mode 100644 index 000000000000..661e3d9122e9 --- /dev/null +++ b/quiche/quic/tools/quic_reject_reason_decoder_bin.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Decodes the packet HandshakeFailureReason from the chromium histogram +// Net.QuicClientHelloRejectReasons + +#include + +#include "absl/strings/numbers.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/crypto_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/quiche_text_utils.h" + +using quic::CryptoUtils; +using quic::HandshakeFailureReason; +using quic::MAX_FAILURE_REASON; + +int main(int argc, char* argv[]) { + const char* usage = "Usage: quic_reject_reason_decoder "; + std::vector args = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + + if (args.size() != 1) { + std::cerr << usage << std::endl; + return 1; + } + + uint32_t packed_error = 0; + if (!absl::SimpleAtoi(args[0], &packed_error)) { + std::cerr << "Unable to parse: " << args[0] << "\n"; + return 2; + } + + for (int i = 1; i < MAX_FAILURE_REASON; ++i) { + if ((packed_error & (1 << (i - 1))) == 0) { + continue; + } + HandshakeFailureReason reason = static_cast(i); + std::cout << CryptoUtils::HandshakeFailureReasonToString(reason) << "\n"; + } + return 0; +} diff --git a/quiche/quic/tools/quic_server.cc b/quiche/quic/tools/quic_server.cc new file mode 100644 index 000000000000..738d8c9a1aef --- /dev/null +++ b/quiche/quic/tools/quic_server.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_server.h" + +#include +#include + +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/io/event_loop_socket_factory.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_crypto_stream.h" +#include "quiche/quic/core/quic_data_reader.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/core/quic_default_packet_writer.h" +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/core/quic_packet_reader.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h" +#include "quiche/quic/tools/quic_simple_dispatcher.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/common/simple_buffer_allocator.h" + +namespace quic { + +namespace { + +const char kSourceAddressTokenSecret[] = "secret"; + +} // namespace + +const size_t kNumSessionsToCreatePerSocketEvent = 16; + +QuicServer::QuicServer(std::unique_ptr proof_source, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicServer(std::move(proof_source), quic_simple_server_backend, + AllSupportedVersions()) {} + +QuicServer::QuicServer(std::unique_ptr proof_source, + QuicSimpleServerBackend* quic_simple_server_backend, + const ParsedQuicVersionVector& supported_versions) + : QuicServer(std::move(proof_source), QuicConfig(), + QuicCryptoServerConfig::ConfigOptions(), supported_versions, + quic_simple_server_backend, kQuicDefaultConnectionIdLength) {} + +QuicServer::QuicServer( + std::unique_ptr proof_source, const QuicConfig& config, + const QuicCryptoServerConfig::ConfigOptions& crypto_config_options, + const ParsedQuicVersionVector& supported_versions, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length) + : port_(0), + fd_(-1), + packets_dropped_(0), + overflow_supported_(false), + silent_close_(false), + config_(config), + crypto_config_(kSourceAddressTokenSecret, QuicRandom::GetInstance(), + std::move(proof_source), KeyExchangeSource::Default()), + crypto_config_options_(crypto_config_options), + version_manager_(supported_versions), + packet_reader_(new QuicPacketReader()), + quic_simple_server_backend_(quic_simple_server_backend), + expected_server_connection_id_length_( + expected_server_connection_id_length), + connection_id_generator_(expected_server_connection_id_length) { + QUICHE_DCHECK(quic_simple_server_backend_); + Initialize(); +} + +void QuicServer::Initialize() { + // If an initial flow control window has not explicitly been set, then use a + // sensible value for a server: 1 MB for session, 64 KB for each stream. + const uint32_t kInitialSessionFlowControlWindow = 1 * 1024 * 1024; // 1 MB + const uint32_t kInitialStreamFlowControlWindow = 64 * 1024; // 64 KB + if (config_.GetInitialStreamFlowControlWindowToSend() == + kDefaultFlowControlSendWindow) { + config_.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindow); + } + if (config_.GetInitialSessionFlowControlWindowToSend() == + kDefaultFlowControlSendWindow) { + config_.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindow); + } + + std::unique_ptr scfg(crypto_config_.AddDefaultConfig( + QuicRandom::GetInstance(), QuicDefaultClock::Get(), + crypto_config_options_)); +} + +QuicServer::~QuicServer() { + close(fd_); + fd_ = -1; + + // Should be fine without because nothing should send requests to the backend + // after `this` is destroyed, but for extra pointer safety, clear the socket + // factory from the backend before the socket factory is destroyed. + quic_simple_server_backend_->SetSocketFactory(nullptr); +} + +bool QuicServer::CreateUDPSocketAndListen(const QuicSocketAddress& address) { + event_loop_ = CreateEventLoop(); + + socket_factory_ = std::make_unique( + event_loop_.get(), quiche::SimpleBufferAllocator::Get()); + quic_simple_server_backend_->SetSocketFactory(socket_factory_.get()); + + QuicUdpSocketApi socket_api; + fd_ = socket_api.Create(address.host().AddressFamilyToInt(), + /*receive_buffer_size =*/kDefaultSocketReceiveBuffer, + /*send_buffer_size =*/kDefaultSocketReceiveBuffer); + if (fd_ == kQuicInvalidSocketFd) { + QUIC_LOG(ERROR) << "CreateSocket() failed: " << strerror(errno); + return false; + } + + overflow_supported_ = socket_api.EnableDroppedPacketCount(fd_); + socket_api.EnableReceiveTimestamp(fd_); + + bool success = socket_api.Bind(fd_, address); + if (!success) { + QUIC_LOG(ERROR) << "Bind failed: " << strerror(errno); + return false; + } + QUIC_LOG(INFO) << "Listening on " << address.ToString(); + port_ = address.port(); + if (port_ == 0) { + QuicSocketAddress address; + if (address.FromSocket(fd_) != 0) { + QUIC_LOG(ERROR) << "Unable to get self address. Error: " + << strerror(errno); + } + port_ = address.port(); + } + + bool register_result = event_loop_->RegisterSocket( + fd_, kSocketEventReadable | kSocketEventWritable, this); + if (!register_result) { + return false; + } + dispatcher_.reset(CreateQuicDispatcher()); + dispatcher_->InitializeWithWriter(CreateWriter(fd_)); + + return true; +} + +QuicPacketWriter* QuicServer::CreateWriter(int fd) { + return new QuicDefaultPacketWriter(fd); +} + +QuicDispatcher* QuicServer::CreateQuicDispatcher() { + return new QuicSimpleDispatcher( + &config_, &crypto_config_, &version_manager_, + std::make_unique(), + std::unique_ptr( + new QuicSimpleCryptoServerStreamHelper()), + event_loop_->CreateAlarmFactory(), quic_simple_server_backend_, + expected_server_connection_id_length_, connection_id_generator_); +} + +std::unique_ptr QuicServer::CreateEventLoop() { + return GetDefaultEventLoop()->Create(QuicDefaultClock::Get()); +} + +void QuicServer::HandleEventsForever() { + while (true) { + WaitForEvents(); + } +} + +void QuicServer::WaitForEvents() { + event_loop_->RunEventLoopOnce(QuicTime::Delta::FromMilliseconds(50)); +} + +void QuicServer::Shutdown() { + if (!silent_close_) { + // Before we shut down the epoll server, give all active sessions a chance + // to notify clients that they're closing. + dispatcher_->Shutdown(); + } + + dispatcher_.reset(); + event_loop_.reset(); +} + +void QuicServer::OnSocketEvent(QuicEventLoop* /*event_loop*/, + QuicUdpSocketFd fd, QuicSocketEventMask events) { + QUICHE_DCHECK_EQ(fd, fd_); + + if (events & kSocketEventReadable) { + QUIC_DVLOG(1) << "EPOLLIN"; + + dispatcher_->ProcessBufferedChlos(kNumSessionsToCreatePerSocketEvent); + + bool more_to_read = true; + while (more_to_read) { + more_to_read = packet_reader_->ReadAndDispatchPackets( + fd_, port_, *QuicDefaultClock::Get(), dispatcher_.get(), + overflow_supported_ ? &packets_dropped_ : nullptr); + } + + if (dispatcher_->HasChlosBuffered()) { + // Register EPOLLIN event to consume buffered CHLO(s). + bool success = + event_loop_->ArtificiallyNotifyEvent(fd_, kSocketEventReadable); + QUICHE_DCHECK(success); + } + if (!event_loop_->SupportsEdgeTriggered()) { + bool success = event_loop_->RearmSocket(fd_, kSocketEventReadable); + QUICHE_DCHECK(success); + } + } + if (events & kSocketEventWritable) { + dispatcher_->OnCanWrite(); + if (!event_loop_->SupportsEdgeTriggered() && + dispatcher_->HasPendingWrites()) { + bool success = event_loop_->RearmSocket(fd_, kSocketEventWritable); + QUICHE_DCHECK(success); + } + } +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_server.h b/quiche/quic/tools/quic_server.h new file mode 100644 index 000000000000..f97a4593082c --- /dev/null +++ b/quiche/quic/tools/quic_server.h @@ -0,0 +1,174 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A toy server, which listens on a specified address for QUIC traffic and +// handles incoming responses. +// +// Note that this server is intended to verify correctness of the client and is +// in no way expected to be performant. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SERVER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SERVER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/deterministic_connection_id_generator.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/core/quic_packet_writer.h" +#include "quiche/quic/core/quic_udp_socket.h" +#include "quiche/quic/core/quic_version_manager.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/quic/tools/quic_spdy_server_base.h" + +namespace quic { + +namespace test { +class QuicServerPeer; +} // namespace test + +class QuicDispatcher; +class QuicPacketReader; + +class QuicServer : public QuicSpdyServerBase, public QuicSocketEventListener { + public: + // `quic_simple_server_backend` must outlive the created QuicServer. + QuicServer(std::unique_ptr proof_source, + QuicSimpleServerBackend* quic_simple_server_backend); + QuicServer(std::unique_ptr proof_source, + QuicSimpleServerBackend* quic_simple_server_backend, + const ParsedQuicVersionVector& supported_versions); + QuicServer(std::unique_ptr proof_source, + const QuicConfig& config, + const QuicCryptoServerConfig::ConfigOptions& crypto_config_options, + const ParsedQuicVersionVector& supported_versions, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length); + QuicServer(const QuicServer&) = delete; + QuicServer& operator=(const QuicServer&) = delete; + + ~QuicServer() override; + + // Start listening on the specified address. + bool CreateUDPSocketAndListen(const QuicSocketAddress& address) override; + // Handles all events. Does not return. + void HandleEventsForever() override; + + // Wait up to 50ms, and handle any events which occur. + void WaitForEvents(); + + // Server deletion is imminent. Start cleaning up any pending sessions. + virtual void Shutdown(); + + // QuicSocketEventListener implementation. + void OnSocketEvent(QuicEventLoop* event_loop, QuicUdpSocketFd fd, + QuicSocketEventMask events) override; + + void SetChloMultiplier(size_t multiplier) { + crypto_config_.set_chlo_multiplier(multiplier); + } + + void SetPreSharedKey(absl::string_view key) { + crypto_config_.set_pre_shared_key(key); + } + + bool overflow_supported() { return overflow_supported_; } + + QuicPacketCount packets_dropped() { return packets_dropped_; } + + int port() { return port_; } + + QuicEventLoop* event_loop() { return event_loop_.get(); } + + protected: + virtual QuicPacketWriter* CreateWriter(int fd); + + virtual QuicDispatcher* CreateQuicDispatcher(); + + virtual std::unique_ptr CreateEventLoop(); + + const QuicConfig& config() const { return config_; } + const QuicCryptoServerConfig& crypto_config() const { return crypto_config_; } + + QuicDispatcher* dispatcher() { return dispatcher_.get(); } + + QuicVersionManager* version_manager() { return &version_manager_; } + + QuicSimpleServerBackend* server_backend() { + return quic_simple_server_backend_; + } + + void set_silent_close(bool value) { silent_close_ = value; } + + uint8_t expected_server_connection_id_length() { + return expected_server_connection_id_length_; + } + + ConnectionIdGeneratorInterface& connection_id_generator() { + return connection_id_generator_; + } + + private: + friend class quic::test::QuicServerPeer; + + // Initialize the internal state of the server. + void Initialize(); + + // Schedules alarms and notifies the server of the I/O events. + std::unique_ptr event_loop_; + // Used by some backends to create additional sockets, e.g. for upstream + // destination connections for proxying. + std::unique_ptr socket_factory_; + // Accepts data from the framer and demuxes clients to sessions. + std::unique_ptr dispatcher_; + + // The port the server is listening on. + int port_; + + // Listening connection. Also used for outbound client communication. + QuicUdpSocketFd fd_; + + // If overflow_supported_ is true this will be the number of packets dropped + // during the lifetime of the server. This may overflow if enough packets + // are dropped. + QuicPacketCount packets_dropped_; + + // True if the kernel supports SO_RXQ_OVFL, the number of packets dropped + // because the socket would otherwise overflow. + bool overflow_supported_; + + // If true, do not call Shutdown on the dispatcher. Connections will close + // without sending a final connection close. + bool silent_close_; + + // config_ contains non-crypto parameters that are negotiated in the crypto + // handshake. + QuicConfig config_; + // crypto_config_ contains crypto parameters for the handshake. + QuicCryptoServerConfig crypto_config_; + // crypto_config_options_ contains crypto parameters for the handshake. + QuicCryptoServerConfig::ConfigOptions crypto_config_options_; + + // Used to generate current supported versions. + QuicVersionManager version_manager_; + + // Point to a QuicPacketReader object on the heap. The reader allocates more + // space than allowed on the stack. + std::unique_ptr packet_reader_; + + QuicSimpleServerBackend* quic_simple_server_backend_; // unowned. + + // Connection ID length expected to be read on incoming IETF short headers. + uint8_t expected_server_connection_id_length_; + + DeterministicConnectionIdGenerator connection_id_generator_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SERVER_H_ diff --git a/quiche/quic/tools/quic_server_bin.cc b/quiche/quic/tools/quic_server_bin.cc new file mode 100644 index 000000000000..d823af0f7452 --- /dev/null +++ b/quiche/quic/tools/quic_server_bin.cc @@ -0,0 +1,29 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A binary wrapper for QuicServer. It listens forever on --port +// (default 6121) until it's killed or ctrl-cd to death. + +#include + +#include "quiche/quic/tools/quic_server_factory.h" +#include "quiche/quic/tools/quic_toy_server.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_system_event_loop.h" + +int main(int argc, char* argv[]) { + quiche::QuicheSystemEventLoop event_loop("quic_server"); + const char* usage = "Usage: quic_server [options]"; + std::vector non_option_args = + quiche::QuicheParseCommandLineFlags(usage, argc, argv); + if (!non_option_args.empty()) { + quiche::QuichePrintCommandLineFlagHelp(usage); + exit(0); + } + + quic::QuicToyServer::MemoryCacheBackendFactory backend_factory; + quic::QuicServerFactory server_factory; + quic::QuicToyServer server(&backend_factory, &server_factory); + return server.Start(); +} diff --git a/quiche/quic/tools/quic_server_factory.cc b/quiche/quic/tools/quic_server_factory.cc new file mode 100644 index 000000000000..7aac48b1e3b5 --- /dev/null +++ b/quiche/quic/tools/quic_server_factory.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_server_factory.h" + +#include + +#include "quiche/quic/tools/quic_server.h" + +namespace quic { + +std::unique_ptr QuicServerFactory::CreateServer( + quic::QuicSimpleServerBackend* backend, + std::unique_ptr proof_source, + const quic::ParsedQuicVersionVector& supported_versions) { + return std::make_unique(std::move(proof_source), backend, + supported_versions); +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_server_factory.h b/quiche/quic/tools/quic_server_factory.h new file mode 100644 index 000000000000..ba6bdfceb6d4 --- /dev/null +++ b/quiche/quic/tools/quic_server_factory.h @@ -0,0 +1,23 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SERVER_FACTORY_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SERVER_FACTORY_H_ + +#include "quiche/quic/tools/quic_toy_server.h" + +namespace quic { + +// Factory creating QuicServer instances. +class QuicServerFactory : public QuicToyServer::ServerFactory { + public: + std::unique_ptr CreateServer( + QuicSimpleServerBackend* backend, + std::unique_ptr proof_source, + const quic::ParsedQuicVersionVector& supported_versions) override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SERVER_FACTORY_H_ diff --git a/quiche/quic/tools/quic_server_test.cc b/quiche/quic/tools/quic_server_test.cc new file mode 100644 index 000000000000..facd42e7b06f --- /dev/null +++ b/quiche/quic/tools/quic_server_test.cc @@ -0,0 +1,228 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_server.h" + +#include + +#include "absl/base/macros.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/deterministic_connection_id_generator.h" +#include "quiche/quic/core/io/quic_default_event_loop.h" +#include "quiche/quic/core/io/quic_event_loop.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_default_connection_helper.h" +#include "quiche/quic/core/quic_default_packet_writer.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/platform/api/quic_test_loopback.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_quic_dispatcher.h" +#include "quiche/quic/test_tools/quic_server_peer.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h" + +namespace quic { +namespace test { + +using ::testing::_; + +namespace { + +class MockQuicSimpleDispatcher : public QuicSimpleDispatcher { + public: + MockQuicSimpleDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QuicSimpleServerBackend* quic_simple_server_backend, + ConnectionIdGeneratorInterface& generator) + : QuicSimpleDispatcher(config, crypto_config, version_manager, + std::move(helper), std::move(session_helper), + std::move(alarm_factory), + quic_simple_server_backend, + kQuicDefaultConnectionIdLength, generator) {} + ~MockQuicSimpleDispatcher() override = default; + + MOCK_METHOD(void, OnCanWrite, (), (override)); + MOCK_METHOD(bool, HasPendingWrites, (), (const, override)); + MOCK_METHOD(bool, HasChlosBuffered, (), (const, override)); + MOCK_METHOD(void, ProcessBufferedChlos, (size_t), (override)); +}; + +class TestQuicServer : public QuicServer { + public: + explicit TestQuicServer(QuicEventLoopFactory* event_loop_factory, + QuicMemoryCacheBackend* quic_simple_server_backend) + : QuicServer(crypto_test_utils::ProofSourceForTesting(), + quic_simple_server_backend), + quic_simple_server_backend_(quic_simple_server_backend), + event_loop_factory_(event_loop_factory) {} + + ~TestQuicServer() override = default; + + MockQuicSimpleDispatcher* mock_dispatcher() { return mock_dispatcher_; } + + protected: + QuicDispatcher* CreateQuicDispatcher() override { + mock_dispatcher_ = new MockQuicSimpleDispatcher( + &config(), &crypto_config(), version_manager(), + std::make_unique(), + std::unique_ptr( + new QuicSimpleCryptoServerStreamHelper()), + event_loop()->CreateAlarmFactory(), quic_simple_server_backend_, + connection_id_generator()); + return mock_dispatcher_; + } + + std::unique_ptr CreateEventLoop() override { + return event_loop_factory_->Create(QuicDefaultClock::Get()); + } + + MockQuicSimpleDispatcher* mock_dispatcher_ = nullptr; + QuicMemoryCacheBackend* quic_simple_server_backend_; + QuicEventLoopFactory* event_loop_factory_; +}; + +class QuicServerEpollInTest : public QuicTestWithParam { + public: + QuicServerEpollInTest() + : server_address_(TestLoopback(), 0), + server_(GetParam(), &quic_simple_server_backend_) {} + + void StartListening() { + server_.CreateUDPSocketAndListen(server_address_); + server_address_ = QuicSocketAddress(server_address_.host(), server_.port()); + + ASSERT_TRUE(QuicServerPeer::SetSmallSocket(&server_)); + + if (!server_.overflow_supported()) { + QUIC_LOG(WARNING) << "Overflow not supported. Not testing."; + return; + } + } + + protected: + QuicSocketAddress server_address_; + QuicMemoryCacheBackend quic_simple_server_backend_; + TestQuicServer server_; +}; + +std::string GetTestParamName( + ::testing::TestParamInfo info) { + return EscapeTestParamName(info.param->GetName()); +} + +INSTANTIATE_TEST_SUITE_P(QuicServerEpollInTests, QuicServerEpollInTest, + ::testing::ValuesIn(GetAllSupportedEventLoops()), + GetTestParamName); + +// Tests that if dispatcher has CHLOs waiting for connection creation, EPOLLIN +// event should try to create connections for them. And set epoll mask with +// EPOLLIN if there are still CHLOs remaining at the end of epoll event. +TEST_P(QuicServerEpollInTest, ProcessBufferedCHLOsOnEpollin) { + // Given an EPOLLIN event, try to create session for buffered CHLOs. In first + // event, dispatcher can't create session for all of CHLOs. So listener should + // register another EPOLLIN event by itself. Even without new packet arrival, + // the rest CHLOs should be process in next epoll event. + StartListening(); + bool more_chlos = true; + MockQuicSimpleDispatcher* dispatcher_ = server_.mock_dispatcher(); + QUICHE_DCHECK(dispatcher_ != nullptr); + EXPECT_CALL(*dispatcher_, OnCanWrite()).Times(testing::AnyNumber()); + EXPECT_CALL(*dispatcher_, ProcessBufferedChlos(_)).Times(2); + EXPECT_CALL(*dispatcher_, HasPendingWrites()).Times(testing::AnyNumber()); + // Expect there are still CHLOs buffered after 1st event. But not any more + // after 2nd event. + EXPECT_CALL(*dispatcher_, HasChlosBuffered()) + .WillOnce(testing::Return(true)) + .WillOnce( + DoAll(testing::Assign(&more_chlos, false), testing::Return(false))); + + // Send a packet to trigger epoll event. + QuicUdpSocketApi socket_api; + SocketFd fd = + socket_api.Create(server_address_.host().AddressFamilyToInt(), + /*receive_buffer_size =*/kDefaultSocketReceiveBuffer, + /*send_buffer_size =*/kDefaultSocketReceiveBuffer); + ASSERT_NE(fd, kQuicInvalidSocketFd); + + char buf[1024]; + memset(buf, 0, ABSL_ARRAYSIZE(buf)); + QuicUdpPacketInfo packet_info; + packet_info.SetPeerAddress(server_address_); + WriteResult result = + socket_api.WritePacket(fd, buf, sizeof(buf), packet_info); + if (result.status != WRITE_STATUS_OK) { + QUIC_LOG(ERROR) << "Write error for UDP packet: " << result.error_code; + } + + while (more_chlos) { + server_.WaitForEvents(); + } +} + +class QuicServerDispatchPacketTest : public QuicTest { + public: + QuicServerDispatchPacketTest() + : crypto_config_("blah", QuicRandom::GetInstance(), + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()), + version_manager_(AllSupportedVersions()), + event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())), + connection_id_generator_(kQuicDefaultConnectionIdLength), + dispatcher_(&config_, &crypto_config_, &version_manager_, + std::make_unique(), + std::make_unique(), + event_loop_->CreateAlarmFactory(), + &quic_simple_server_backend_, connection_id_generator_) { + dispatcher_.InitializeWithWriter(new QuicDefaultPacketWriter(1234)); + } + + void DispatchPacket(const QuicReceivedPacket& packet) { + QuicSocketAddress client_addr, server_addr; + dispatcher_.ProcessPacket(server_addr, client_addr, packet); + } + + protected: + QuicConfig config_; + QuicCryptoServerConfig crypto_config_; + QuicVersionManager version_manager_; + std::unique_ptr event_loop_; + QuicMemoryCacheBackend quic_simple_server_backend_; + DeterministicConnectionIdGenerator connection_id_generator_; + MockQuicDispatcher dispatcher_; +}; + +TEST_F(QuicServerDispatchPacketTest, DispatchPacket) { + // clang-format off + unsigned char valid_packet[] = { + // public flags (8 byte connection_id) + 0x3C, + // connection_id + 0x10, 0x32, 0x54, 0x76, + 0x98, 0xBA, 0xDC, 0xFE, + // packet number + 0xBC, 0x9A, 0x78, 0x56, + 0x34, 0x12, + // private flags + 0x00 + }; + // clang-format on + QuicReceivedPacket encrypted_valid_packet( + reinterpret_cast(valid_packet), ABSL_ARRAYSIZE(valid_packet), + QuicTime::Zero(), false); + + EXPECT_CALL(dispatcher_, ProcessPacket(_, _, _)).Times(1); + DispatchPacket(encrypted_valid_packet); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_client_session.cc b/quiche/quic/tools/quic_simple_client_session.cc new file mode 100644 index 000000000000..23601ad3a5a6 --- /dev/null +++ b/quiche/quic/tools/quic_simple_client_session.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_client_session.h" + +#include + +#include "quiche/quic/core/quic_path_validator.h" + +namespace quic { + +QuicSimpleClientSession::QuicSimpleClientSession( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicClientBase::NetworkHelper* network_helper, + const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, bool drop_response_body, + bool enable_web_transport) + : QuicSimpleClientSession(config, supported_versions, connection, + /*visitor=*/nullptr, network_helper, server_id, + crypto_config, push_promise_index, + drop_response_body, enable_web_transport) {} + +QuicSimpleClientSession::QuicSimpleClientSession( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + QuicClientBase::NetworkHelper* network_helper, + const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, bool drop_response_body, + bool enable_web_transport) + : QuicSpdyClientSession(config, supported_versions, connection, visitor, + server_id, crypto_config, push_promise_index), + network_helper_(network_helper), + drop_response_body_(drop_response_body), + enable_web_transport_(enable_web_transport) {} + +std::unique_ptr +QuicSimpleClientSession::CreateClientStream() { + return std::make_unique( + GetNextOutgoingBidirectionalStreamId(), this, BIDIRECTIONAL, + drop_response_body_); +} + +bool QuicSimpleClientSession::ShouldNegotiateWebTransport() { + return enable_web_transport_; +} + +HttpDatagramSupport QuicSimpleClientSession::LocalHttpDatagramSupport() { + return enable_web_transport_ ? HttpDatagramSupport::kDraft04 + : HttpDatagramSupport::kNone; +} + +std::unique_ptr +QuicSimpleClientSession::CreateContextForMultiPortPath() { + if (!network_helper_ || connection()->multi_port_stats() == nullptr) { + return nullptr; + } + auto self_address = connection()->self_address(); + auto server_address = connection()->peer_address(); + if (!network_helper_->CreateUDPSocketAndBind( + server_address, self_address.host(), self_address.port() + 1)) { + return nullptr; + } + QuicPacketWriter* writer = network_helper_->CreateQuicPacketWriter(); + if (writer == nullptr) { + return nullptr; + } + return std::make_unique( + std::unique_ptr(writer), + network_helper_->GetLatestClientAddress(), peer_address()); +} + +void QuicSimpleClientSession::MigrateToMultiPortPath( + std::unique_ptr context) { + auto* path_migration_context = + static_cast(context.get()); + MigratePath(path_migration_context->self_address(), + path_migration_context->peer_address(), + path_migration_context->ReleaseWriter(), /*owns_writer=*/true); +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_client_session.h b/quiche/quic/tools/quic_simple_client_session.h new file mode 100644 index 000000000000..1770c9fc66f5 --- /dev/null +++ b/quiche/quic/tools/quic_simple_client_session.h @@ -0,0 +1,52 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CLIENT_SESSION_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CLIENT_SESSION_H_ + +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/tools/quic_client_base.h" +#include "quiche/quic/tools/quic_simple_client_stream.h" + +namespace quic { + +class QuicSimpleClientSession : public QuicSpdyClientSession { + public: + QuicSimpleClientSession(const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicClientBase::NetworkHelper* network_helper, + const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, + bool drop_response_body, bool enable_web_transport); + + QuicSimpleClientSession(const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicClientBase::NetworkHelper* network_helper, + const QuicServerId& server_id, + QuicCryptoClientConfig* crypto_config, + QuicClientPushPromiseIndex* push_promise_index, + bool drop_response_body, bool enable_web_transport); + + std::unique_ptr CreateClientStream() override; + bool ShouldNegotiateWebTransport() override; + HttpDatagramSupport LocalHttpDatagramSupport() override; + std::unique_ptr CreateContextForMultiPortPath() + override; + void MigrateToMultiPortPath( + std::unique_ptr context) override; + bool drop_response_body() const { return drop_response_body_; } + + private: + QuicClientBase::NetworkHelper* network_helper_; + const bool drop_response_body_; + const bool enable_web_transport_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CLIENT_SESSION_H_ diff --git a/quiche/quic/tools/quic_simple_client_stream.cc b/quiche/quic/tools/quic_simple_client_stream.cc new file mode 100644 index 000000000000..14de9f03c2c4 --- /dev/null +++ b/quiche/quic/tools/quic_simple_client_stream.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_client_stream.h" + +namespace quic { + +void QuicSimpleClientStream::OnBodyAvailable() { + if (!drop_response_body_) { + QuicSpdyClientStream::OnBodyAvailable(); + return; + } + + while (HasBytesToRead()) { + struct iovec iov; + if (GetReadableRegions(&iov, 1) == 0) { + break; + } + MarkConsumed(iov.iov_len); + } + if (sequencer()->IsClosed()) { + OnFinRead(); + } else { + sequencer()->SetUnblocked(); + } +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_client_stream.h b/quiche/quic/tools/quic_simple_client_stream.h new file mode 100644 index 000000000000..976ceafbfb8a --- /dev/null +++ b/quiche/quic/tools/quic_simple_client_stream.h @@ -0,0 +1,26 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CLIENT_STREAM_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CLIENT_STREAM_H_ + +#include "quiche/quic/core/http/quic_spdy_client_stream.h" + +namespace quic { + +class QuicSimpleClientStream : public QuicSpdyClientStream { + public: + QuicSimpleClientStream(QuicStreamId id, QuicSpdyClientSession* session, + StreamType type, bool drop_response_body) + : QuicSpdyClientStream(id, session, type), + drop_response_body_(drop_response_body) {} + void OnBodyAvailable() override; + + private: + const bool drop_response_body_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CLIENT_STREAM_H_ diff --git a/quiche/quic/tools/quic_simple_crypto_server_stream_helper.cc b/quiche/quic/tools/quic_simple_crypto_server_stream_helper.cc new file mode 100644 index 000000000000..08d63b7ee089 --- /dev/null +++ b/quiche/quic/tools/quic_simple_crypto_server_stream_helper.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_crypto_server_stream_helper.h" + +#include "quiche/quic/core/quic_utils.h" + +namespace quic { + +QuicSimpleCryptoServerStreamHelper::QuicSimpleCryptoServerStreamHelper() = + default; + +QuicSimpleCryptoServerStreamHelper::~QuicSimpleCryptoServerStreamHelper() = + default; + +bool QuicSimpleCryptoServerStreamHelper::CanAcceptClientHello( + const CryptoHandshakeMessage& /*message*/, + const QuicSocketAddress& /*client_address*/, + const QuicSocketAddress& /*peer_address*/, + const QuicSocketAddress& /*self_address*/, + std::string* /*error_details*/) const { + return true; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_crypto_server_stream_helper.h b/quiche/quic/tools/quic_simple_crypto_server_stream_helper.h new file mode 100644 index 000000000000..ba64226bf1c5 --- /dev/null +++ b/quiche/quic/tools/quic_simple_crypto_server_stream_helper.h @@ -0,0 +1,31 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CRYPTO_SERVER_STREAM_HELPER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CRYPTO_SERVER_STREAM_HELPER_H_ + +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" + +namespace quic { + +// Simple helper for server crypto streams which generates a new random +// connection ID for rejects. +class QuicSimpleCryptoServerStreamHelper + : public QuicCryptoServerStreamBase::Helper { + public: + QuicSimpleCryptoServerStreamHelper(); + + ~QuicSimpleCryptoServerStreamHelper() override; + + bool CanAcceptClientHello(const CryptoHandshakeMessage& message, + const QuicSocketAddress& client_address, + const QuicSocketAddress& peer_address, + const QuicSocketAddress& self_address, + std::string* error_details) const override; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_CRYPTO_SERVER_STREAM_HELPER_H_ diff --git a/quiche/quic/tools/quic_simple_dispatcher.cc b/quiche/quic/tools/quic_simple_dispatcher.cc new file mode 100644 index 000000000000..50f063697ff4 --- /dev/null +++ b/quiche/quic/tools/quic_simple_dispatcher.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_dispatcher.h" + +#include "absl/strings/string_view.h" +#include "quiche/quic/tools/quic_simple_server_session.h" + +namespace quic { + +QuicSimpleDispatcher::QuicSimpleDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& generator) + : QuicDispatcher(config, crypto_config, version_manager, std::move(helper), + std::move(session_helper), std::move(alarm_factory), + expected_server_connection_id_length, generator), + quic_simple_server_backend_(quic_simple_server_backend) {} + +QuicSimpleDispatcher::~QuicSimpleDispatcher() = default; + +int QuicSimpleDispatcher::GetRstErrorCount( + QuicRstStreamErrorCode error_code) const { + auto it = rst_error_map_.find(error_code); + if (it == rst_error_map_.end()) { + return 0; + } + return it->second; +} + +void QuicSimpleDispatcher::OnRstStreamReceived( + const QuicRstStreamFrame& frame) { + auto it = rst_error_map_.find(frame.error_code); + if (it == rst_error_map_.end()) { + rst_error_map_.insert(std::make_pair(frame.error_code, 1)); + } else { + it->second++; + } +} + +std::unique_ptr QuicSimpleDispatcher::CreateQuicSession( + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, + const ParsedQuicVersion& version, + const ParsedClientHello& /*parsed_chlo*/) { + // The QuicServerSessionBase takes ownership of |connection| below. + QuicConnection* connection = new QuicConnection( + connection_id, self_address, peer_address, helper(), alarm_factory(), + writer(), + /* owns_writer= */ false, Perspective::IS_SERVER, + ParsedQuicVersionVector{version}, connection_id_generator()); + + auto session = std::make_unique( + config(), GetSupportedVersions(), connection, this, session_helper(), + crypto_config(), compressed_certs_cache(), quic_simple_server_backend_); + session->Initialize(); + return session; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_dispatcher.h b/quiche/quic/tools/quic_simple_dispatcher.h new file mode 100644 index 000000000000..16f7fd0afc78 --- /dev/null +++ b/quiche/quic/tools/quic_simple_dispatcher.h @@ -0,0 +1,53 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_DISPATCHER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_DISPATCHER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_server_session_base.h" +#include "quiche/quic/core/quic_dispatcher.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" + +namespace quic { + +class QuicSimpleDispatcher : public QuicDispatcher { + public: + QuicSimpleDispatcher( + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, + QuicVersionManager* version_manager, + std::unique_ptr helper, + std::unique_ptr session_helper, + std::unique_ptr alarm_factory, + QuicSimpleServerBackend* quic_simple_server_backend, + uint8_t expected_server_connection_id_length, + ConnectionIdGeneratorInterface& generator); + + ~QuicSimpleDispatcher() override; + + int GetRstErrorCount(QuicRstStreamErrorCode rst_error_code) const; + + void OnRstStreamReceived(const QuicRstStreamFrame& frame) override; + + protected: + std::unique_ptr CreateQuicSession( + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const ParsedClientHello& parsed_chlo) override; + + QuicSimpleServerBackend* server_backend() { + return quic_simple_server_backend_; + } + + private: + QuicSimpleServerBackend* quic_simple_server_backend_; // Unowned. + + // The map of the reset error code with its counter. + std::map rst_error_map_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_DISPATCHER_H_ diff --git a/quiche/quic/tools/quic_simple_server_backend.h b/quiche/quic/tools/quic_simple_server_backend.h new file mode 100644 index 000000000000..a142dd85526f --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_backend.h @@ -0,0 +1,123 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_BACKEND_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_BACKEND_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/socket_factory.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// This interface implements the functionality to fetch a response +// from the backend (such as cache, http-proxy etc) to serve +// requests received by a Quic Server +class QuicSimpleServerBackend { + public: + // This interface implements the methods + // called by the QuicSimpleServerBackend implementation + // to process the request in the backend + class RequestHandler { + public: + virtual ~RequestHandler() {} + + virtual QuicConnectionId connection_id() const = 0; + virtual QuicStreamId stream_id() const = 0; + virtual std::string peer_host() const = 0; + virtual QuicSpdyStream* GetStream() = 0; + // Called when the response is ready at the backend and can be send back to + // the QUIC client. + virtual void OnResponseBackendComplete( + const QuicBackendResponse* response) = 0; + // Sends additional non-full-response data (without headers) to the request + // stream, e.g. for CONNECT data. May only be called after sending an + // incomplete response (using `QuicBackendResponse::INCOMPLETE_RESPONSE`). + // Sends the data with the FIN bit to close the stream if `close_stream` is + // true. + virtual void SendStreamData(absl::string_view data, bool close_stream) = 0; + // Abruptly terminates (resets) the request stream with `error`. + virtual void TerminateStreamWithError(QuicResetStreamError error) = 0; + }; + + struct WebTransportResponse { + spdy::Http2HeaderBlock response_headers; + std::unique_ptr visitor; + }; + + virtual ~QuicSimpleServerBackend() = default; + // This method initializes the backend instance to fetch responses + // from a backend server, in-memory cache etc. + virtual bool InitializeBackend(const std::string& backend_url) = 0; + // Returns true if the backend has been successfully initialized + // and could be used to fetch HTTP requests + virtual bool IsBackendInitialized() const = 0; + // Passes the socket factory in use by the QuicServer. Must live as long as + // incoming requests/data are still sent to the backend, or until cleared by + // calling with null. Must not be called while backend is handling requests. + virtual void SetSocketFactory(SocketFactory* /*socket_factory*/) {} + // Triggers a HTTP request to be sent to the backend server or cache + // If response is immediately available, the function synchronously calls + // the `request_handler` with the HTTP response. + // If the response has to be fetched over the network, the function + // asynchronously calls `request_handler` with the HTTP response. + // + // Not called for requests using the CONNECT method. + virtual void FetchResponseFromBackend( + const spdy::Http2HeaderBlock& request_headers, + const std::string& request_body, RequestHandler* request_handler) = 0; + + // Handles headers for requests using the CONNECT method. Called immediately + // on receiving the headers, potentially before the request is complete or + // data is received. Any response (complete or incomplete) should be sent, + // potentially asynchronously, using `request_handler`. + // + // If not overridden by backend, sends an error appropriate for a server that + // does not handle CONNECT requests. + virtual void HandleConnectHeaders( + const spdy::Http2HeaderBlock& /*request_headers*/, + RequestHandler* request_handler) { + spdy::Http2HeaderBlock headers; + headers[":status"] = "405"; + QuicBackendResponse response; + response.set_headers(std::move(headers)); + request_handler->OnResponseBackendComplete(&response); + } + // Handles data for requests using the CONNECT method. Called repeatedly + // whenever new data is available. If `data_complete` is true, data was + // received with the FIN bit, and this is the last call to this method. + // + // If not overridden by backend, abruptly terminates the stream. + virtual void HandleConnectData(absl::string_view /*data*/, + bool /*data_complete*/, + RequestHandler* request_handler) { + request_handler->TerminateStreamWithError( + QuicResetStreamError::FromInternal(QUIC_STREAM_CONNECT_ERROR)); + } + + // Clears the state of the backend instance + virtual void CloseBackendResponseStream(RequestHandler* request_handler) = 0; + + virtual WebTransportResponse ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& /*request_headers*/, + WebTransportSession* /*session*/) { + WebTransportResponse response; + response.response_headers[":status"] = "400"; + return response; + } + virtual bool SupportsWebTransport() { return false; } + virtual bool SupportsExtendedConnect() { return true; } +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_BACKEND_H_ diff --git a/quiche/quic/tools/quic_simple_server_session.cc b/quiche/quic/tools/quic_simple_server_session.cc new file mode 100644 index 000000000000..7e49470e47f6 --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_session.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_server_session.h" + +#include + +#include "absl/memory/memory.h" +#include "quiche/quic/core/http/quic_server_initiated_spdy_stream.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_stream_priority.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/tools/quic_simple_server_stream.h" + +namespace quic { + +QuicSimpleServerSession::QuicSimpleServerSession( + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicServerSessionBase(config, supported_versions, connection, visitor, + helper, crypto_config, compressed_certs_cache), + quic_simple_server_backend_(quic_simple_server_backend) { + QUICHE_DCHECK(quic_simple_server_backend_); +} + +QuicSimpleServerSession::~QuicSimpleServerSession() { DeleteConnection(); } + +std::unique_ptr +QuicSimpleServerSession::CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) { + return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this, + stream_helper()); +} + +void QuicSimpleServerSession::OnStreamFrame(const QuicStreamFrame& frame) { + if (!IsIncomingStream(frame.stream_id) && !WillNegotiateWebTransport()) { + QUIC_LOG(WARNING) << "Client shouldn't send data on server push stream"; + connection()->CloseConnection( + QUIC_INVALID_STREAM_ID, "Client sent data on server push stream", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; + } + QuicSpdySession::OnStreamFrame(frame); +} + +QuicSpdyStream* QuicSimpleServerSession::CreateIncomingStream(QuicStreamId id) { + if (!ShouldCreateIncomingStream(id)) { + return nullptr; + } + + QuicSpdyStream* stream = new QuicSimpleServerStream( + id, this, BIDIRECTIONAL, quic_simple_server_backend_); + ActivateStream(absl::WrapUnique(stream)); + return stream; +} + +QuicSpdyStream* QuicSimpleServerSession::CreateIncomingStream( + PendingStream* pending) { + QuicSpdyStream* stream = + new QuicSimpleServerStream(pending, this, quic_simple_server_backend_); + ActivateStream(absl::WrapUnique(stream)); + return stream; +} + +QuicSpdyStream* QuicSimpleServerSession::CreateOutgoingBidirectionalStream() { + if (!WillNegotiateWebTransport()) { + QUIC_BUG(QuicSimpleServerSession CreateOutgoingBidirectionalStream without + WebTransport support) + << "QuicSimpleServerSession::CreateOutgoingBidirectionalStream called " + "in a session without WebTransport support."; + return nullptr; + } + if (!ShouldCreateOutgoingBidirectionalStream()) { + return nullptr; + } + + QuicServerInitiatedSpdyStream* stream = new QuicServerInitiatedSpdyStream( + GetNextOutgoingBidirectionalStreamId(), this, BIDIRECTIONAL); + ActivateStream(absl::WrapUnique(stream)); + return stream; +} + +QuicSimpleServerStream* +QuicSimpleServerSession::CreateOutgoingUnidirectionalStream() { + if (!ShouldCreateOutgoingUnidirectionalStream()) { + return nullptr; + } + + QuicSimpleServerStream* stream = new QuicSimpleServerStream( + GetNextOutgoingUnidirectionalStreamId(), this, WRITE_UNIDIRECTIONAL, + quic_simple_server_backend_); + ActivateStream(absl::WrapUnique(stream)); + return stream; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_server_session.h b/quiche/quic/tools/quic_simple_server_session.h new file mode 100644 index 000000000000..bceb4e587cb1 --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_session.h @@ -0,0 +1,87 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A toy server specific QuicSession subclass. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_SESSION_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_SESSION_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "quiche/quic/core/http/quic_server_session_base.h" +#include "quiche/quic/core/http/quic_spdy_session.h" +#include "quiche/quic/core/quic_crypto_server_stream_base.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/quic/tools/quic_simple_server_stream.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +namespace test { +class QuicSimpleServerSessionPeer; +} // namespace test + +class QuicSimpleServerSession : public QuicServerSessionBase { + public: + // Takes ownership of |connection|. + QuicSimpleServerSession(const QuicConfig& config, + const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, + QuicSession::Visitor* visitor, + QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend); + QuicSimpleServerSession(const QuicSimpleServerSession&) = delete; + QuicSimpleServerSession& operator=(const QuicSimpleServerSession&) = delete; + + ~QuicSimpleServerSession() override; + + // Override base class to detact client sending data on server push stream. + void OnStreamFrame(const QuicStreamFrame& frame) override; + + protected: + // QuicSession methods: + QuicSpdyStream* CreateIncomingStream(QuicStreamId id) override; + QuicSpdyStream* CreateIncomingStream(PendingStream* pending) override; + QuicSpdyStream* CreateOutgoingBidirectionalStream() override; + QuicSimpleServerStream* CreateOutgoingUnidirectionalStream() override; + + // QuicServerSessionBaseMethod: + std::unique_ptr CreateQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache) override; + + QuicSimpleServerBackend* server_backend() { + return quic_simple_server_backend_; + } + + bool ShouldNegotiateWebTransport() override { + return quic_simple_server_backend_->SupportsWebTransport(); + } + HttpDatagramSupport LocalHttpDatagramSupport() override { + if (ShouldNegotiateWebTransport()) { + return HttpDatagramSupport::kRfcAndDraft04; + } + return QuicServerSessionBase::LocalHttpDatagramSupport(); + } + + private: + friend class test::QuicSimpleServerSessionPeer; + + QuicSimpleServerBackend* quic_simple_server_backend_; // Not owned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_SESSION_H_ diff --git a/quiche/quic/tools/quic_simple_server_session_test.cc b/quiche/quic/tools/quic_simple_server_session_test.cc new file mode 100644 index 000000000000..b6662d621b5a --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_session_test.cc @@ -0,0 +1,465 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_server_session.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/crypto/quic_crypto_server_config.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/proto/cached_network_parameters_proto.h" +#include "quiche/quic/core/quic_connection.h" +#include "quiche/quic/core/quic_crypto_server_stream.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/core/tls_server_handshaker.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/mock_quic_session_visitor.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_sent_packet_manager_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_sustained_bandwidth_recorder_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/quic/tools/quic_simple_server_stream.h" + +using testing::_; +using testing::AtLeast; +using testing::InSequence; +using testing::Invoke; +using testing::Return; +using testing::StrictMock; + +namespace quic { +namespace test { +namespace { + +// Data to be sent on a request stream. In Google QUIC, this is interpreted as +// DATA payload (there is no framing on request streams). In IETF QUIC, this is +// interpreted as HEADERS frame (type 0x1) with payload length 122 ('z'). Since +// no payload is included, QPACK decoder will not be invoked. +const char* const kStreamData = "\1z"; + +} // namespace + +class QuicSimpleServerSessionPeer { + public: + static void SetCryptoStream(QuicSimpleServerSession* s, + QuicCryptoServerStreamBase* crypto_stream) { + s->crypto_stream_.reset(crypto_stream); + } + + static QuicSpdyStream* CreateIncomingStream(QuicSimpleServerSession* s, + QuicStreamId id) { + return s->CreateIncomingStream(id); + } + + static QuicSimpleServerStream* CreateOutgoingUnidirectionalStream( + QuicSimpleServerSession* s) { + return s->CreateOutgoingUnidirectionalStream(); + } +}; + +namespace { + +const size_t kMaxStreamsForTest = 10; + +class MockQuicCryptoServerStream : public QuicCryptoServerStream { + public: + explicit MockQuicCryptoServerStream( + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, QuicSession* session, + QuicCryptoServerStreamBase::Helper* helper) + : QuicCryptoServerStream(crypto_config, compressed_certs_cache, session, + helper) {} + MockQuicCryptoServerStream(const MockQuicCryptoServerStream&) = delete; + MockQuicCryptoServerStream& operator=(const MockQuicCryptoServerStream&) = + delete; + ~MockQuicCryptoServerStream() override {} + + MOCK_METHOD(void, SendServerConfigUpdate, (const CachedNetworkParameters*), + (override)); + + bool encryption_established() const override { return true; } +}; + +class MockTlsServerHandshaker : public TlsServerHandshaker { + public: + explicit MockTlsServerHandshaker(QuicSession* session, + const QuicCryptoServerConfig* crypto_config) + : TlsServerHandshaker(session, crypto_config) {} + MockTlsServerHandshaker(const MockTlsServerHandshaker&) = delete; + MockTlsServerHandshaker& operator=(const MockTlsServerHandshaker&) = delete; + ~MockTlsServerHandshaker() override {} + + MOCK_METHOD(void, SendServerConfigUpdate, (const CachedNetworkParameters*), + (override)); + + bool encryption_established() const override { return true; } +}; + +class MockQuicConnectionWithSendStreamData : public MockQuicConnection { + public: + MockQuicConnectionWithSendStreamData( + MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, + Perspective perspective, + const ParsedQuicVersionVector& supported_versions) + : MockQuicConnection(helper, alarm_factory, perspective, + supported_versions) { + auto consume_all_data = [](QuicStreamId /*id*/, size_t write_length, + QuicStreamOffset /*offset*/, + StreamSendingState state) { + return QuicConsumedData(write_length, state != NO_FIN); + }; + ON_CALL(*this, SendStreamData(_, _, _, _)) + .WillByDefault(Invoke(consume_all_data)); + } + + MOCK_METHOD(QuicConsumedData, SendStreamData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state), + (override)); +}; + +class MockQuicSimpleServerSession : public QuicSimpleServerSession { + public: + MockQuicSimpleServerSession( + const QuicConfig& config, QuicConnection* connection, + QuicSession::Visitor* visitor, QuicCryptoServerStreamBase::Helper* helper, + const QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSimpleServerSession( + config, CurrentSupportedVersions(), connection, visitor, helper, + crypto_config, compressed_certs_cache, quic_simple_server_backend) { + } + MOCK_METHOD(void, SendBlocked, (QuicStreamId, QuicStreamOffset), (override)); + MOCK_METHOD(bool, WriteControlFrame, + (const QuicFrame& frame, TransmissionType type), (override)); +}; + +class QuicSimpleServerSessionTest + : public QuicTestWithParam { + public: + // The function ensures that A) the MAX_STREAMS frames get properly deleted + // (since the test uses a 'did we leak memory' check ... if we just lose the + // frame, the test fails) and B) returns true (instead of the default, false) + // which ensures that the rest of the system thinks that the frame actually + // was transmitted. + bool ClearMaxStreamsControlFrame(const QuicFrame& frame) { + if (frame.type == MAX_STREAMS_FRAME) { + DeleteFrame(&const_cast(frame)); + return true; + } + return false; + } + + protected: + QuicSimpleServerSessionTest() + : crypto_config_(QuicCryptoServerConfig::TESTING, + QuicRandom::GetInstance(), + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default()), + compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) { + config_.SetMaxBidirectionalStreamsToSend(kMaxStreamsForTest); + QuicConfigPeer::SetReceivedMaxBidirectionalStreams(&config_, + kMaxStreamsForTest); + config_.SetMaxUnidirectionalStreamsToSend(kMaxStreamsForTest); + + config_.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + config_.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend( + kInitialStreamFlowControlWindowForTest); + config_.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend( + kInitialStreamFlowControlWindowForTest); + config_.SetInitialMaxStreamDataBytesUnidirectionalToSend( + kInitialStreamFlowControlWindowForTest); + config_.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + if (VersionUsesHttp3(transport_version())) { + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams( + &config_, kMaxStreamsForTest + 3); + } else { + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(&config_, + kMaxStreamsForTest); + } + + ParsedQuicVersionVector supported_versions = SupportedVersions(version()); + connection_ = new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_SERVER, supported_versions); + connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + connection_->SetEncrypter( + ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + session_ = std::make_unique( + config_, connection_, &owner_, &stream_helper_, &crypto_config_, + &compressed_certs_cache_, &memory_cache_backend_); + MockClock clock; + handshake_message_ = crypto_config_.AddDefaultConfig( + QuicRandom::GetInstance(), &clock, + QuicCryptoServerConfig::ConfigOptions()); + session_->Initialize(); + + if (VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*session_, WriteControlFrame(_, _)) + .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); + } + session_->OnConfigNegotiated(); + } + + QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { + return GetNthClientInitiatedBidirectionalStreamId(transport_version(), n); + } + + QuicStreamId GetNthServerInitiatedUnidirectionalId(int n) { + return quic::test::GetNthServerInitiatedUnidirectionalStreamId( + transport_version(), n); + } + + ParsedQuicVersion version() const { return GetParam(); } + + QuicTransportVersion transport_version() const { + return version().transport_version; + } + + void InjectStopSending(QuicStreamId stream_id, + QuicRstStreamErrorCode rst_stream_code) { + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes + // a one-way close. + if (!VersionHasIetfQuicFrames(transport_version())) { + // Only needed for version 99/IETF QUIC. + return; + } + EXPECT_CALL(owner_, OnStopSendingReceived(_)).Times(1); + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream_id, + rst_stream_code); + // Expect the RESET_STREAM that is generated in response to receiving a + // STOP_SENDING. + EXPECT_CALL(*connection_, OnStreamReset(stream_id, rst_stream_code)); + session_->OnStopSendingFrame(stop_sending); + } + + StrictMock owner_; + StrictMock stream_helper_; + MockQuicConnectionHelper helper_; + MockAlarmFactory alarm_factory_; + StrictMock* connection_; + QuicConfig config_; + QuicCryptoServerConfig crypto_config_; + QuicCompressedCertsCache compressed_certs_cache_; + QuicMemoryCacheBackend memory_cache_backend_; + std::unique_ptr session_; + std::unique_ptr handshake_message_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSimpleServerSessionTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSimpleServerSessionTest, CloseStreamDueToReset) { + // Send some data open a stream, then reset it. + QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, + kStreamData); + session_->OnStreamFrame(data1); + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Receive a reset (and send a RST in response). + QuicRstStreamFrame rst1(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + + if (!VersionHasIetfQuicFrames(transport_version())) { + // For version 99, this is covered in InjectStopSending() + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_RST_ACKNOWLEDGEMENT)); + } + session_->OnRstStream(rst1); + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes + // a one-way close. + InjectStopSending(GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Send the same two bytes of payload in a new packet. + session_->OnStreamFrame(data1); + + // The stream should not be re-opened. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicSimpleServerSessionTest, NeverOpenStreamDueToReset) { + // Send a reset (and expect the peer to send a RST in response). + QuicRstStreamFrame rst1(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); + if (!VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + // For version 99, this is covered in InjectStopSending() + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_RST_ACKNOWLEDGEMENT)); + } + session_->OnRstStream(rst1); + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes + // a one-way close. + InjectStopSending(GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM); + + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, + kStreamData); + session_->OnStreamFrame(data1); + + // The stream should never be opened, now that the reset is received. + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicSimpleServerSessionTest, AcceptClosedStream) { + // Send some data to open two streams. + QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, + kStreamData); + QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(1), false, 0, + kStreamData); + session_->OnStreamFrame(frame1); + session_->OnStreamFrame(frame2); + EXPECT_EQ(2u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + + // Send a reset (and expect the peer to send a RST in response). + QuicRstStreamFrame rst(kInvalidControlFrameId, + GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM, 0); + EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); + if (!VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(*session_, WriteControlFrame(_, _)); + // For version 99, this is covered in InjectStopSending() + EXPECT_CALL(*connection_, + OnStreamReset(GetNthClientInitiatedBidirectionalId(0), + QUIC_RST_ACKNOWLEDGEMENT)); + } + session_->OnRstStream(rst); + // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a + // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes + // a one-way close. + InjectStopSending(GetNthClientInitiatedBidirectionalId(0), + QUIC_ERROR_PROCESSING_STREAM); + + // If we were tracking, we'd probably want to reject this because it's data + // past the reset point of stream 3. As it's a closed stream we just drop the + // data on the floor, but accept the packet because it has data for stream 5. + QuicStreamFrame frame3(GetNthClientInitiatedBidirectionalId(0), false, 2, + kStreamData); + QuicStreamFrame frame4(GetNthClientInitiatedBidirectionalId(1), false, 2, + kStreamData); + session_->OnStreamFrame(frame3); + session_->OnStreamFrame(frame4); + // The stream should never be opened, now that the reset is received. + EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); + EXPECT_TRUE(connection_->connected()); +} + +TEST_P(QuicSimpleServerSessionTest, CreateIncomingStreamDisconnected) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (version() != AllSupportedVersions()[0]) { + return; + } + + // Tests that incoming stream creation fails when connection is not connected. + size_t initial_num_open_stream = + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get()); + QuicConnectionPeer::TearDownLocalConnectionState(connection_); + EXPECT_QUIC_BUG(QuicSimpleServerSessionPeer::CreateIncomingStream( + session_.get(), GetNthClientInitiatedBidirectionalId(0)), + "ShouldCreateIncomingStream called when disconnected"); + EXPECT_EQ(initial_num_open_stream, + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); +} + +TEST_P(QuicSimpleServerSessionTest, CreateIncomingStream) { + QuicSpdyStream* stream = QuicSimpleServerSessionPeer::CreateIncomingStream( + session_.get(), GetNthClientInitiatedBidirectionalId(0)); + EXPECT_NE(nullptr, stream); + EXPECT_EQ(GetNthClientInitiatedBidirectionalId(0), stream->id()); +} + +TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamDisconnected) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (version() != AllSupportedVersions()[0]) { + return; + } + + // Tests that outgoing stream creation fails when connection is not connected. + size_t initial_num_open_stream = + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get()); + QuicConnectionPeer::TearDownLocalConnectionState(connection_); + EXPECT_QUIC_BUG( + QuicSimpleServerSessionPeer::CreateOutgoingUnidirectionalStream( + session_.get()), + "ShouldCreateOutgoingUnidirectionalStream called when disconnected"); + + EXPECT_EQ(initial_num_open_stream, + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); +} + +TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamUnencrypted) { + // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. + if (version() != AllSupportedVersions()[0]) { + return; + } + + // Tests that outgoing stream creation fails when encryption has not yet been + // established. + size_t initial_num_open_stream = + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get()); + EXPECT_QUIC_BUG( + QuicSimpleServerSessionPeer::CreateOutgoingUnidirectionalStream( + session_.get()), + "Encryption not established so no outgoing stream created."); + EXPECT_EQ(initial_num_open_stream, + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); +} + +// Tests that calling GetOrCreateStream() on an outgoing stream not promised yet +// should result close connection. +TEST_P(QuicSimpleServerSessionTest, GetEvenIncomingError) { + const size_t initial_num_open_stream = + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get()); + const QuicErrorCode expected_error = VersionUsesHttp3(transport_version()) + ? QUIC_HTTP_STREAM_WRONG_DIRECTION + : QUIC_INVALID_STREAM_ID; + EXPECT_CALL(*connection_, CloseConnection(expected_error, + "Data for nonexistent stream", _)); + EXPECT_EQ(nullptr, + QuicSessionPeer::GetOrCreateStream( + session_.get(), GetNthServerInitiatedUnidirectionalId(3))); + EXPECT_EQ(initial_num_open_stream, + QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_server_stream.cc b/quiche/quic/tools/quic_simple_server_stream.cc new file mode 100644 index 000000000000..b0cb9cb4ce7d --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_stream.cc @@ -0,0 +1,503 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_server_stream.h" + +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/quic_spdy_stream.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/http/web_transport_http3.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/quic/tools/quic_simple_server_session.h" +#include "quiche/spdy/core/spdy_protocol.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +QuicSimpleServerStream::QuicSimpleServerStream( + QuicStreamId id, QuicSpdySession* session, StreamType type, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSpdyServerStreamBase(id, session, type), + content_length_(-1), + generate_bytes_length_(0), + quic_simple_server_backend_(quic_simple_server_backend) { + QUICHE_DCHECK(quic_simple_server_backend_); +} + +QuicSimpleServerStream::QuicSimpleServerStream( + PendingStream* pending, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSpdyServerStreamBase(pending, session), + content_length_(-1), + generate_bytes_length_(0), + quic_simple_server_backend_(quic_simple_server_backend) { + QUICHE_DCHECK(quic_simple_server_backend_); +} + +QuicSimpleServerStream::~QuicSimpleServerStream() { + quic_simple_server_backend_->CloseBackendResponseStream(this); +} + +void QuicSimpleServerStream::OnInitialHeadersComplete( + bool fin, size_t frame_len, const QuicHeaderList& header_list) { + QuicSpdyStream::OnInitialHeadersComplete(fin, frame_len, header_list); + // QuicSpdyStream::OnInitialHeadersComplete() may have already sent error + // response. + if (!response_sent_ && + !SpdyUtils::CopyAndValidateHeaders(header_list, &content_length_, + &request_headers_)) { + QUIC_DVLOG(1) << "Invalid headers"; + SendErrorResponse(); + } + ConsumeHeaderList(); + + // CONNECT requests do not carry any message content but carry data after the + // headers, so they require sending the response right after parsing the + // headers even though the FIN bit has not been received on the request + // stream. + if (!fin && !response_sent_ && IsConnectRequest()) { + if (quic_simple_server_backend_ == nullptr) { + QUIC_DVLOG(1) << "Backend is missing on CONNECT headers."; + SendErrorResponse(); + return; + } + + if (web_transport() != nullptr) { + QuicSimpleServerBackend::WebTransportResponse response = + quic_simple_server_backend_->ProcessWebTransportRequest( + request_headers_, web_transport()); + if (response.response_headers[":status"] == "200") { + WriteHeaders(std::move(response.response_headers), false, nullptr); + if (response.visitor != nullptr) { + web_transport()->SetVisitor(std::move(response.visitor)); + } + web_transport()->HeadersReceived(request_headers_); + } else { + WriteHeaders(std::move(response.response_headers), true, nullptr); + } + return; + } + + quic_simple_server_backend_->HandleConnectHeaders(request_headers_, + /*request_handler=*/this); + } +} + +void QuicSimpleServerStream::OnBodyAvailable() { + while (HasBytesToRead()) { + struct iovec iov; + if (GetReadableRegions(&iov, 1) == 0) { + // No more data to read. + break; + } + QUIC_DVLOG(1) << "Stream " << id() << " processed " << iov.iov_len + << " bytes."; + body_.append(static_cast(iov.iov_base), iov.iov_len); + + if (content_length_ >= 0 && + body_.size() > static_cast(content_length_)) { + QUIC_DVLOG(1) << "Body size (" << body_.size() << ") > content length (" + << content_length_ << ")."; + SendErrorResponse(); + return; + } + MarkConsumed(iov.iov_len); + } + + if (!sequencer()->IsClosed()) { + if (IsConnectRequest()) { + HandleRequestConnectData(/*fin_received=*/false); + } + sequencer()->SetUnblocked(); + return; + } + + // If the sequencer is closed, then all the body, including the fin, has been + // consumed. + OnFinRead(); + + if (write_side_closed() || fin_buffered()) { + return; + } + + if (IsConnectRequest()) { + HandleRequestConnectData(/*fin_received=*/true); + } else { + SendResponse(); + } +} + +void QuicSimpleServerStream::HandleRequestConnectData(bool fin_received) { + QUICHE_DCHECK(IsConnectRequest()); + + if (quic_simple_server_backend_ == nullptr) { + QUIC_DVLOG(1) << "Backend is missing on CONNECT data."; + ResetWriteSide( + QuicResetStreamError::FromInternal(QUIC_STREAM_CONNECT_ERROR)); + return; + } + + // Clear `body_`, so only new data is sent to the backend next time. + std::string data = std::move(body_); + body_.clear(); + + quic_simple_server_backend_->HandleConnectData(data, + /*data_complete=*/fin_received, + this); +} + +void QuicSimpleServerStream::SendResponse() { + QUICHE_DCHECK(!IsConnectRequest()); + + if (request_headers_.empty()) { + QUIC_DVLOG(1) << "Request headers empty."; + SendErrorResponse(); + return; + } + + if (content_length_ > 0 && + static_cast(content_length_) != body_.size()) { + QUIC_DVLOG(1) << "Content length (" << content_length_ << ") != body size (" + << body_.size() << ")."; + SendErrorResponse(); + return; + } + + if (!request_headers_.contains(":authority")) { + QUIC_DVLOG(1) << "Request headers do not contain :authority."; + SendErrorResponse(); + return; + } + + if (!request_headers_.contains(":path")) { + QUIC_DVLOG(1) << "Request headers do not contain :path."; + SendErrorResponse(); + return; + } + + if (quic_simple_server_backend_ == nullptr) { + QUIC_DVLOG(1) << "Backend is missing in SendResponse()."; + SendErrorResponse(); + return; + } + + if (web_transport() != nullptr) { + QuicSimpleServerBackend::WebTransportResponse response = + quic_simple_server_backend_->ProcessWebTransportRequest( + request_headers_, web_transport()); + if (response.response_headers[":status"] == "200") { + WriteHeaders(std::move(response.response_headers), false, nullptr); + if (response.visitor != nullptr) { + web_transport()->SetVisitor(std::move(response.visitor)); + } + web_transport()->HeadersReceived(request_headers_); + } else { + WriteHeaders(std::move(response.response_headers), true, nullptr); + } + return; + } + + // Fetch the response from the backend interface and wait for callback once + // response is ready + quic_simple_server_backend_->FetchResponseFromBackend(request_headers_, body_, + this); +} + +QuicConnectionId QuicSimpleServerStream::connection_id() const { + return spdy_session()->connection_id(); +} + +QuicStreamId QuicSimpleServerStream::stream_id() const { return id(); } + +std::string QuicSimpleServerStream::peer_host() const { + return spdy_session()->peer_address().host().ToString(); +} + +QuicSpdyStream* QuicSimpleServerStream::GetStream() { return this; } + +namespace { + +class DelayedResponseAlarm : public QuicAlarm::DelegateWithContext { + public: + DelayedResponseAlarm(QuicSimpleServerStream* stream, + const QuicBackendResponse* response) + : QuicAlarm::DelegateWithContext( + stream->spdy_session()->connection()->context()), + stream_(stream), + response_(response) { + stream_ = stream; + response_ = response; + } + + ~DelayedResponseAlarm() override = default; + + void OnAlarm() override { stream_->Respond(response_); } + + private: + QuicSimpleServerStream* stream_; + const QuicBackendResponse* response_; +}; + +} // namespace + +void QuicSimpleServerStream::OnResponseBackendComplete( + const QuicBackendResponse* response) { + if (response == nullptr) { + QUIC_DVLOG(1) << "Response not found in cache."; + SendNotFoundResponse(); + return; + } + + auto delay = response->delay(); + if (delay.IsZero()) { + Respond(response); + return; + } + + auto* connection = session()->connection(); + delayed_response_alarm_.reset(connection->alarm_factory()->CreateAlarm( + new DelayedResponseAlarm(this, response))); + delayed_response_alarm_->Set(connection->clock()->Now() + delay); +} + +void QuicSimpleServerStream::Respond(const QuicBackendResponse* response) { + // Send Early Hints first. + for (const auto& headers : response->early_hints()) { + QUIC_DVLOG(1) << "Stream " << id() << " sending an Early Hints response: " + << headers.DebugString(); + WriteHeaders(headers.Clone(), false, nullptr); + } + + if (response->response_type() == QuicBackendResponse::CLOSE_CONNECTION) { + QUIC_DVLOG(1) << "Special response: closing connection."; + OnUnrecoverableError(QUIC_NO_ERROR, "Toy server forcing close"); + return; + } + + if (response->response_type() == QuicBackendResponse::IGNORE_REQUEST) { + QUIC_DVLOG(1) << "Special response: ignoring request."; + return; + } + + if (response->response_type() == QuicBackendResponse::BACKEND_ERR_RESPONSE) { + QUIC_DVLOG(1) << "Quic Proxy: Backend connection error."; + /*502 Bad Gateway + The server was acting as a gateway or proxy and received an + invalid response from the upstream server.*/ + SendErrorResponse(502); + return; + } + + // Examing response status, if it was not pure integer as typical h2 + // response status, send error response. Notice that + // QuicHttpResponseCache push urls are strictly authority + path only, + // scheme is not included (see |QuicHttpResponseCache::GetKey()|). + std::string request_url = request_headers_[":authority"].as_string() + + request_headers_[":path"].as_string(); + int response_code; + const Http2HeaderBlock& response_headers = response->headers(); + if (!ParseHeaderStatusCode(response_headers, &response_code)) { + auto status = response_headers.find(":status"); + if (status == response_headers.end()) { + QUIC_LOG(WARNING) + << ":status not present in response from cache for request " + << request_url; + } else { + QUIC_LOG(WARNING) << "Illegal (non-integer) response :status from cache: " + << status->second << " for request " << request_url; + } + SendErrorResponse(); + return; + } + + if (response->response_type() == QuicBackendResponse::INCOMPLETE_RESPONSE) { + QUIC_DVLOG(1) + << "Stream " << id() + << " sending an incomplete response, i.e. no trailer, no fin."; + SendIncompleteResponse(response->headers().Clone(), response->body()); + return; + } + + if (response->response_type() == QuicBackendResponse::GENERATE_BYTES) { + QUIC_DVLOG(1) << "Stream " << id() << " sending a generate bytes response."; + std::string path = request_headers_[":path"].as_string().substr(1); + if (!absl::SimpleAtoi(path, &generate_bytes_length_)) { + QUIC_LOG(ERROR) << "Path is not a number."; + SendNotFoundResponse(); + return; + } + Http2HeaderBlock headers = response->headers().Clone(); + headers["content-length"] = absl::StrCat(generate_bytes_length_); + + WriteHeaders(std::move(headers), false, nullptr); + QUICHE_DCHECK(!response_sent_); + response_sent_ = true; + + WriteGeneratedBytes(); + + return; + } + + QUIC_DVLOG(1) << "Stream " << id() << " sending response."; + SendHeadersAndBodyAndTrailers(response->headers().Clone(), response->body(), + response->trailers().Clone()); +} + +void QuicSimpleServerStream::SendStreamData(absl::string_view data, + bool close_stream) { + // Doesn't make sense to call this without data or `close_stream`. + QUICHE_DCHECK(!data.empty() || close_stream); + + if (close_stream) { + SendHeadersAndBodyAndTrailers( + /*response_headers=*/absl::nullopt, data, + /*response_trailers=*/spdy::Http2HeaderBlock()); + } else { + SendIncompleteResponse(/*response_headers=*/absl::nullopt, data); + } +} + +void QuicSimpleServerStream::TerminateStreamWithError( + QuicResetStreamError error) { + QUIC_DVLOG(1) << "Stream " << id() << " abruptly terminating with error " + << error.internal_code(); + ResetWriteSide(error); +} + +void QuicSimpleServerStream::OnCanWrite() { + QuicSpdyStream::OnCanWrite(); + WriteGeneratedBytes(); +} + +void QuicSimpleServerStream::WriteGeneratedBytes() { + static size_t kChunkSize = 1024; + while (!HasBufferedData() && generate_bytes_length_ > 0) { + size_t len = std::min(kChunkSize, generate_bytes_length_); + std::string data(len, 'a'); + generate_bytes_length_ -= len; + bool fin = generate_bytes_length_ == 0; + WriteOrBufferBody(data, fin); + } +} + +void QuicSimpleServerStream::SendNotFoundResponse() { + QUIC_DVLOG(1) << "Stream " << id() << " sending not found response."; + Http2HeaderBlock headers; + headers[":status"] = "404"; + headers["content-length"] = absl::StrCat(strlen(kNotFoundResponseBody)); + SendHeadersAndBody(std::move(headers), kNotFoundResponseBody); +} + +void QuicSimpleServerStream::SendErrorResponse() { SendErrorResponse(0); } + +void QuicSimpleServerStream::SendErrorResponse(int resp_code) { + QUIC_DVLOG(1) << "Stream " << id() << " sending error response."; + if (!reading_stopped()) { + StopReading(); + } + Http2HeaderBlock headers; + if (resp_code <= 0) { + headers[":status"] = "500"; + } else { + headers[":status"] = absl::StrCat(resp_code); + } + headers["content-length"] = absl::StrCat(strlen(kErrorResponseBody)); + SendHeadersAndBody(std::move(headers), kErrorResponseBody); +} + +void QuicSimpleServerStream::SendIncompleteResponse( + absl::optional response_headers, absl::string_view body) { + // Headers should be sent iff not sent in a previous response. + QUICHE_DCHECK_NE(response_headers.has_value(), response_sent_); + + if (response_headers.has_value()) { + QUIC_DLOG(INFO) << "Stream " << id() << " writing headers (fin = false) : " + << response_headers.value().DebugString(); + // Do not mark response sent for early 100 continue response. + int response_code; + if (!ParseHeaderStatusCode(*response_headers, &response_code) || + response_code != 100) { + response_sent_ = true; + } + WriteHeaders(std::move(response_headers).value(), /*fin=*/false, nullptr); + } + + QUIC_DLOG(INFO) << "Stream " << id() + << " writing body (fin = false) with size: " << body.size(); + if (!body.empty()) { + WriteOrBufferBody(body, /*fin=*/false); + } +} + +void QuicSimpleServerStream::SendHeadersAndBody( + Http2HeaderBlock response_headers, absl::string_view body) { + SendHeadersAndBodyAndTrailers(std::move(response_headers), body, + Http2HeaderBlock()); +} + +void QuicSimpleServerStream::SendHeadersAndBodyAndTrailers( + absl::optional response_headers, absl::string_view body, + Http2HeaderBlock response_trailers) { + // Headers should be sent iff not sent in a previous response. + QUICHE_DCHECK_NE(response_headers.has_value(), response_sent_); + + if (response_headers.has_value()) { + // Send the headers, with a FIN if there's nothing else to send. + bool send_fin = (body.empty() && response_trailers.empty()); + QUIC_DLOG(INFO) << "Stream " << id() + << " writing headers (fin = " << send_fin + << ") : " << response_headers.value().DebugString(); + WriteHeaders(std::move(response_headers).value(), send_fin, nullptr); + response_sent_ = true; + if (send_fin) { + // Nothing else to send. + return; + } + } + + // Send the body, with a FIN if there's no trailers to send. + bool send_fin = response_trailers.empty(); + QUIC_DLOG(INFO) << "Stream " << id() << " writing body (fin = " << send_fin + << ") with size: " << body.size(); + if (!body.empty() || send_fin) { + WriteOrBufferBody(body, send_fin); + } + if (send_fin) { + // Nothing else to send. + return; + } + + // Send the trailers. A FIN is always sent with trailers. + QUIC_DLOG(INFO) << "Stream " << id() << " writing trailers (fin = true): " + << response_trailers.DebugString(); + WriteTrailers(std::move(response_trailers), nullptr); +} + +bool QuicSimpleServerStream::IsConnectRequest() const { + auto method_it = request_headers_.find(":method"); + return method_it != request_headers_.end() && method_it->second == "CONNECT"; +} + +void QuicSimpleServerStream::OnInvalidHeaders() { + QUIC_DVLOG(1) << "Invalid headers"; + SendErrorResponse(400); +} + +const char* const QuicSimpleServerStream::kErrorResponseBody = "bad"; +const char* const QuicSimpleServerStream::kNotFoundResponseBody = + "file not found"; + +} // namespace quic diff --git a/quiche/quic/tools/quic_simple_server_stream.h b/quiche/quic/tools/quic_simple_server_stream.h new file mode 100644 index 000000000000..fe12b70fcfab --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_stream.h @@ -0,0 +1,126 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_STREAM_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_STREAM_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/http/quic_spdy_server_stream_base.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_framer.h" + +namespace quic { + +// All this does right now is aggregate data, and on fin, send an HTTP +// response. +class QuicSimpleServerStream : public QuicSpdyServerStreamBase, + public QuicSimpleServerBackend::RequestHandler { + public: + QuicSimpleServerStream(QuicStreamId id, QuicSpdySession* session, + StreamType type, + QuicSimpleServerBackend* quic_simple_server_backend); + QuicSimpleServerStream(PendingStream* pending, QuicSpdySession* session, + QuicSimpleServerBackend* quic_simple_server_backend); + QuicSimpleServerStream(const QuicSimpleServerStream&) = delete; + QuicSimpleServerStream& operator=(const QuicSimpleServerStream&) = delete; + ~QuicSimpleServerStream() override; + + // QuicSpdyStream + void OnInitialHeadersComplete(bool fin, size_t frame_len, + const QuicHeaderList& header_list) override; + void OnCanWrite() override; + + // QuicStream implementation called by the sequencer when there is + // data (or a FIN) to be read. + void OnBodyAvailable() override; + + void OnInvalidHeaders() override; + + // The response body of error responses. + static const char* const kErrorResponseBody; + static const char* const kNotFoundResponseBody; + + // Implements QuicSimpleServerBackend::RequestHandler callbacks + QuicConnectionId connection_id() const override; + QuicStreamId stream_id() const override; + std::string peer_host() const override; + QuicSpdyStream* GetStream() override; + void OnResponseBackendComplete(const QuicBackendResponse* response) override; + void SendStreamData(absl::string_view data, bool close_stream) override; + void TerminateStreamWithError(QuicResetStreamError error) override; + + void Respond(const QuicBackendResponse* response); + + protected: + // Handles fresh body data whenever received when method is CONNECT. + void HandleRequestConnectData(bool fin_received); + + // Sends a response using SendHeaders for the headers and WriteData for the + // body. + virtual void SendResponse(); + + // Sends a basic 500 response using SendHeaders for the headers and WriteData + // for the body. + virtual void SendErrorResponse(); + virtual void SendErrorResponse(int resp_code); + + // Sends a basic 404 response using SendHeaders for the headers and WriteData + // for the body. + void SendNotFoundResponse(); + + // Sends the response header (if not `absl::nullopt`) and body, but not the + // fin. + void SendIncompleteResponse( + absl::optional response_headers, + absl::string_view body); + + void SendHeadersAndBody(spdy::Http2HeaderBlock response_headers, + absl::string_view body); + void SendHeadersAndBodyAndTrailers( + absl::optional response_headers, + absl::string_view body, spdy::Http2HeaderBlock response_trailers); + + spdy::Http2HeaderBlock* request_headers() { return &request_headers_; } + + // Returns true iff the request (per saved `request_headers_`) is a CONNECT or + // Extended CONNECT request. + bool IsConnectRequest() const; + + const std::string& body() { return body_; } + + // Writes the body bytes for the GENERATE_BYTES response type. + void WriteGeneratedBytes(); + + void set_quic_simple_server_backend_for_test( + QuicSimpleServerBackend* backend) { + quic_simple_server_backend_ = backend; + } + + bool response_sent() const { return response_sent_; } + void set_response_sent() { response_sent_ = true; } + // The parsed headers received from the client. + spdy::Http2HeaderBlock request_headers_; + int64_t content_length_; + std::string body_; + + private: + uint64_t generate_bytes_length_; + // Whether response headers have already been sent. + bool response_sent_ = false; + + std::unique_ptr delayed_response_alarm_; + + QuicSimpleServerBackend* quic_simple_server_backend_; // Not owned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SIMPLE_SERVER_STREAM_H_ diff --git a/quiche/quic/tools/quic_simple_server_stream_test.cc b/quiche/quic/tools/quic_simple_server_stream_test.cc new file mode 100644 index 000000000000..ab8b60734ea5 --- /dev/null +++ b/quiche/quic/tools/quic_simple_server_stream_test.cc @@ -0,0 +1,912 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_simple_server_stream.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/quic/core/crypto/null_encrypter.h" +#include "quiche/quic/core/http/http_encoder.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_alarm_factory.h" +#include "quiche/quic/core/quic_default_clock.h" +#include "quiche/quic/core/quic_error_codes.h" +#include "quiche/quic/core/quic_types.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/platform/api/quic_expect_bug.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/crypto_test_utils.h" +#include "quiche/quic/test_tools/quic_config_peer.h" +#include "quiche/quic/test_tools/quic_connection_peer.h" +#include "quiche/quic/test_tools/quic_session_peer.h" +#include "quiche/quic/test_tools/quic_spdy_session_peer.h" +#include "quiche/quic/test_tools/quic_stream_peer.h" +#include "quiche/quic/test_tools/quic_test_utils.h" +#include "quiche/quic/test_tools/simulator/simulator.h" +#include "quiche/quic/tools/quic_backend_response.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/quic/tools/quic_simple_server_session.h" +#include "quiche/common/simple_buffer_allocator.h" + +using testing::_; +using testing::AnyNumber; +using testing::InSequence; +using testing::Invoke; +using testing::StrictMock; + +namespace quic { +namespace test { + +const size_t kFakeFrameLen = 60; +const size_t kErrorLength = strlen(QuicSimpleServerStream::kErrorResponseBody); +const size_t kDataFrameHeaderLength = 2; + +class TestStream : public QuicSimpleServerStream { + public: + TestStream(QuicStreamId stream_id, QuicSpdySession* session, StreamType type, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSimpleServerStream(stream_id, session, type, + quic_simple_server_backend) { + EXPECT_CALL(*this, WriteOrBufferBody(_, _)) + .Times(AnyNumber()) + .WillRepeatedly([this](absl::string_view data, bool fin) { + this->QuicSimpleServerStream::WriteOrBufferBody(data, fin); + }); + } + + ~TestStream() override = default; + + MOCK_METHOD(void, FireAlarmMock, (), ()); + MOCK_METHOD(void, WriteHeadersMock, (bool fin), ()); + MOCK_METHOD(void, WriteEarlyHintsHeadersMock, (bool fin), ()); + MOCK_METHOD(void, WriteOrBufferBody, (absl::string_view data, bool fin), + (override)); + + size_t WriteHeaders( + spdy::Http2HeaderBlock header_block, bool fin, + quiche::QuicheReferenceCountedPointer + /*ack_listener*/) override { + if (header_block[":status"] == "103") { + WriteEarlyHintsHeadersMock(fin); + } else { + WriteHeadersMock(fin); + } + return 0; + } + + // Expose protected QuicSimpleServerStream methods. + void DoSendResponse() { SendResponse(); } + void DoSendErrorResponse() { QuicSimpleServerStream::SendErrorResponse(); } + + spdy::Http2HeaderBlock* mutable_headers() { return &request_headers_; } + void set_body(std::string body) { body_ = std::move(body); } + const std::string& body() const { return body_; } + int content_length() const { return content_length_; } + bool send_response_was_called() const { return send_response_was_called_; } + bool send_error_response_was_called() const { + return send_error_response_was_called_; + } + + absl::string_view GetHeader(absl::string_view key) const { + auto it = request_headers_.find(key); + QUICHE_DCHECK(it != request_headers_.end()); + return it->second; + } + + void ReplaceBackend(QuicSimpleServerBackend* backend) { + set_quic_simple_server_backend_for_test(backend); + } + + protected: + void SendResponse() override { + send_response_was_called_ = true; + QuicSimpleServerStream::SendResponse(); + } + + void SendErrorResponse(int resp_code) override { + send_error_response_was_called_ = true; + QuicSimpleServerStream::SendErrorResponse(resp_code); + } + + private: + bool send_response_was_called_ = false; + bool send_error_response_was_called_ = false; +}; + +namespace { + +class MockQuicSimpleServerSession : public QuicSimpleServerSession { + public: + const size_t kMaxStreamsForTest = 100; + + MockQuicSimpleServerSession( + QuicConnection* connection, MockQuicSessionVisitor* owner, + MockQuicCryptoServerStreamHelper* helper, + QuicCryptoServerConfig* crypto_config, + QuicCompressedCertsCache* compressed_certs_cache, + QuicSimpleServerBackend* quic_simple_server_backend) + : QuicSimpleServerSession(DefaultQuicConfig(), CurrentSupportedVersions(), + connection, owner, helper, crypto_config, + compressed_certs_cache, + quic_simple_server_backend) { + if (VersionHasIetfQuicFrames(connection->transport_version())) { + QuicSessionPeer::SetMaxOpenIncomingUnidirectionalStreams( + this, kMaxStreamsForTest); + QuicSessionPeer::SetMaxOpenIncomingBidirectionalStreams( + this, kMaxStreamsForTest); + } else { + QuicSessionPeer::SetMaxOpenIncomingStreams(this, kMaxStreamsForTest); + QuicSessionPeer::SetMaxOpenOutgoingStreams(this, kMaxStreamsForTest); + } + ON_CALL(*this, WritevData(_, _, _, _, _, _)) + .WillByDefault(Invoke(this, &MockQuicSimpleServerSession::ConsumeData)); + } + + MockQuicSimpleServerSession(const MockQuicSimpleServerSession&) = delete; + MockQuicSimpleServerSession& operator=(const MockQuicSimpleServerSession&) = + delete; + ~MockQuicSimpleServerSession() override = default; + + MOCK_METHOD(void, OnConnectionClosed, + (const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source), + (override)); + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), + (override)); + MOCK_METHOD(QuicConsumedData, WritevData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + EncryptionLevel level), + (override)); + MOCK_METHOD(void, OnStreamHeaderList, + (QuicStreamId stream_id, bool fin, size_t frame_len, + const QuicHeaderList& header_list), + (override)); + MOCK_METHOD(void, OnStreamHeadersPriority, + (QuicStreamId stream_id, + const spdy::SpdyStreamPrecedence& precedence), + (override)); + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written), + (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); + + using QuicSession::ActivateStream; + + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, + TransmissionType /*type*/, + absl::optional /*level*/) { + if (write_length > 0) { + auto buf = std::make_unique(write_length); + QuicStream* stream = GetOrCreateStream(id); + QUICHE_DCHECK(stream); + QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER); + stream->WriteStreamData(offset, write_length, &writer); + } else { + QUICHE_DCHECK(state != NO_FIN); + } + return QuicConsumedData(write_length, state != NO_FIN); + } + + spdy::Http2HeaderBlock original_request_headers_; +}; + +class QuicSimpleServerStreamTest : public QuicTestWithParam { + public: + QuicSimpleServerStreamTest() + : connection_(new StrictMock( + &simulator_, simulator_.GetAlarmFactory(), Perspective::IS_SERVER, + SupportedVersions(GetParam()))), + crypto_config_(new QuicCryptoServerConfig( + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), + crypto_test_utils::ProofSourceForTesting(), + KeyExchangeSource::Default())), + compressed_certs_cache_( + QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), + session_(connection_, &session_owner_, &session_helper_, + crypto_config_.get(), &compressed_certs_cache_, + &memory_cache_backend_), + quic_response_(new QuicBackendResponse), + body_("hello world") { + connection_->set_visitor(&session_); + header_list_.OnHeaderBlockStart(); + header_list_.OnHeader(":authority", "www.google.com"); + header_list_.OnHeader(":path", "/"); + header_list_.OnHeader(":method", "POST"); + header_list_.OnHeader(":scheme", "https"); + header_list_.OnHeader("content-length", "11"); + + header_list_.OnHeaderBlockEnd(128, 128); + + // New streams rely on having the peer's flow control receive window + // negotiated in the config. + session_.config()->SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + session_.config()->SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); + session_.Initialize(); + connection_->SetEncrypter( + quic::ENCRYPTION_FORWARD_SECURE, + std::make_unique(connection_->perspective())); + if (connection_->version().SupportsAntiAmplificationLimit()) { + QuicConnectionPeer::SetAddressValidated(connection_); + } + stream_ = new StrictMock( + GetNthClientInitiatedBidirectionalStreamId( + connection_->transport_version(), 0), + &session_, BIDIRECTIONAL, &memory_cache_backend_); + // Register stream_ in dynamic_stream_map_ and pass ownership to session_. + session_.ActivateStream(absl::WrapUnique(stream_)); + QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesIncomingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( + session_.config(), kMinimumFlowControlSendWindow); + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(session_.config(), 10); + session_.OnConfigNegotiated(); + simulator_.RunFor(QuicTime::Delta::FromSeconds(1)); + } + + const std::string& StreamBody() { return stream_->body(); } + + std::string StreamHeadersValue(const std::string& key) { + return (*stream_->mutable_headers())[key].as_string(); + } + + bool UsesHttp3() const { + return VersionUsesHttp3(connection_->transport_version()); + } + + void ReplaceBackend(std::unique_ptr backend) { + replacement_backend_ = std::move(backend); + stream_->ReplaceBackend(replacement_backend_.get()); + } + + quic::simulator::Simulator simulator_; + spdy::Http2HeaderBlock response_headers_; + MockQuicConnectionHelper helper_; + StrictMock* connection_; + StrictMock session_owner_; + StrictMock session_helper_; + std::unique_ptr crypto_config_; + QuicCompressedCertsCache compressed_certs_cache_; + QuicMemoryCacheBackend memory_cache_backend_; + std::unique_ptr replacement_backend_; + StrictMock session_; + StrictMock* stream_; // Owned by session_. + std::unique_ptr quic_response_; + std::string body_; + QuicHeaderList header_list_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSimpleServerStreamTest, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicSimpleServerStreamTest, TestFraming) { + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + EXPECT_EQ("11", StreamHeadersValue("content-length")); + EXPECT_EQ("/", StreamHeadersValue(":path")); + EXPECT_EQ("POST", StreamHeadersValue(":method")); + EXPECT_EQ(body_, StreamBody()); +} + +TEST_P(QuicSimpleServerStreamTest, TestFramingOnePacket) { + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + + stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + EXPECT_EQ("11", StreamHeadersValue("content-length")); + EXPECT_EQ("/", StreamHeadersValue(":path")); + EXPECT_EQ("POST", StreamHeadersValue(":method")); + EXPECT_EQ(body_, StreamBody()); +} + +TEST_P(QuicSimpleServerStreamTest, SendQuicRstStreamNoErrorInStopReading) { + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + + EXPECT_FALSE(stream_->fin_received()); + EXPECT_FALSE(stream_->rst_received()); + + QuicStreamPeer::SetFinSent(stream_); + stream_->CloseWriteSide(); + + if (session_.version().UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))) + .Times(1); + } else { + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_NO_ERROR), _)) + .Times(1); + } + stream_->StopReading(); +} + +TEST_P(QuicSimpleServerStreamTest, TestFramingExtraData) { + InSequence seq; + std::string large_body = "hello world!!!!!!"; + + // We'll automatically write out an error (headers + body) + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + if (UsesHttp3()) { + EXPECT_CALL(session_, + WritevData(_, kDataFrameHeaderLength, _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); + + stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; + + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + // Content length is still 11. This will register as an error and we won't + // accept the bytes. + header = HttpEncoder::SerializeDataFrameHeader( + large_body.length(), quiche::SimpleBufferAllocator::Get()); + std::string data2 = UsesHttp3() + ? absl::StrCat(header.AsStringView(), large_body) + : large_body; + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/true, data.size(), data2)); + EXPECT_EQ("11", StreamHeadersValue("content-length")); + EXPECT_EQ("/", StreamHeadersValue(":path")); + EXPECT_EQ("POST", StreamHeadersValue(":method")); +} + +TEST_P(QuicSimpleServerStreamTest, SendResponseWithIllegalResponseStatus) { + // Send an illegal response with response status not supported by HTTP/2. + spdy::Http2HeaderBlock* request_headers = stream_->mutable_headers(); + (*request_headers)[":path"] = "/bar"; + (*request_headers)[":authority"] = "www.google.com"; + (*request_headers)[":method"] = "GET"; + + // HTTP/2 only supports integer responsecode, so "200 OK" is illegal. + response_headers_[":status"] = "200 OK"; + response_headers_["content-length"] = "5"; + std::string body = "Yummm"; + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + + memory_cache_backend_.AddResponse("www.google.com", "/bar", + std::move(response_headers_), body); + + QuicStreamPeer::SetFinReceived(stream_); + + InSequence s; + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + if (UsesHttp3()) { + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); + + stream_->DoSendResponse(); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, SendResponseWithIllegalResponseStatus2) { + // Send an illegal response with response status not supported by HTTP/2. + spdy::Http2HeaderBlock* request_headers = stream_->mutable_headers(); + (*request_headers)[":path"] = "/bar"; + (*request_headers)[":authority"] = "www.google.com"; + (*request_headers)[":method"] = "GET"; + + // HTTP/2 only supports 3-digit-integer, so "+200" is illegal. + response_headers_[":status"] = "+200"; + response_headers_["content-length"] = "5"; + std::string body = "Yummm"; + + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + + memory_cache_backend_.AddResponse("www.google.com", "/bar", + std::move(response_headers_), body); + + QuicStreamPeer::SetFinReceived(stream_); + + InSequence s; + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + if (UsesHttp3()) { + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); + + stream_->DoSendResponse(); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, SendResponseWithValidHeaders) { + // Add a request and response with valid headers. + spdy::Http2HeaderBlock* request_headers = stream_->mutable_headers(); + (*request_headers)[":path"] = "/bar"; + (*request_headers)[":authority"] = "www.google.com"; + (*request_headers)[":method"] = "GET"; + + response_headers_[":status"] = "200"; + response_headers_["content-length"] = "5"; + std::string body = "Yummm"; + + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + + memory_cache_backend_.AddResponse("www.google.com", "/bar", + std::move(response_headers_), body); + QuicStreamPeer::SetFinReceived(stream_); + + InSequence s; + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + if (UsesHttp3()) { + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, body.length(), _, FIN, _, _)); + + stream_->DoSendResponse(); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, SendResponseWithEarlyHints) { + std::string host = "www.google.com"; + std::string request_path = "/foo"; + std::string body = "Yummm"; + + // Add a request and response with early hints. + spdy::Http2HeaderBlock* request_headers = stream_->mutable_headers(); + (*request_headers)[":path"] = request_path; + (*request_headers)[":authority"] = host; + (*request_headers)[":method"] = "GET"; + + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + std::vector early_hints; + // Add two Early Hints. + const size_t kNumEarlyHintsResponses = 2; + for (size_t i = 0; i < kNumEarlyHintsResponses; ++i) { + spdy::Http2HeaderBlock hints; + hints["link"] = "; rel=preload; as=image"; + early_hints.push_back(std::move(hints)); + } + + response_headers_[":status"] = "200"; + response_headers_["content-length"] = "5"; + memory_cache_backend_.AddResponseWithEarlyHints( + host, request_path, std::move(response_headers_), body, early_hints); + QuicStreamPeer::SetFinReceived(stream_); + + InSequence s; + for (size_t i = 0; i < kNumEarlyHintsResponses; ++i) { + EXPECT_CALL(*stream_, WriteEarlyHintsHeadersMock(false)); + } + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + if (UsesHttp3()) { + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, body.length(), _, FIN, _, _)); + + stream_->DoSendResponse(); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->write_side_closed()); +} + +class AlarmTestDelegate : public QuicAlarm::DelegateWithoutContext { + public: + AlarmTestDelegate(TestStream* stream) : stream_(stream) {} + + void OnAlarm() override { stream_->FireAlarmMock(); } + + private: + TestStream* stream_; +}; + +TEST_P(QuicSimpleServerStreamTest, SendResponseWithDelay) { + // Add a request and response with valid headers. + spdy::Http2HeaderBlock* request_headers = stream_->mutable_headers(); + std::string host = "www.google.com"; + std::string path = "/bar"; + (*request_headers)[":path"] = path; + (*request_headers)[":authority"] = host; + (*request_headers)[":method"] = "GET"; + + response_headers_[":status"] = "200"; + response_headers_["content-length"] = "5"; + std::string body = "Yummm"; + QuicTime::Delta delay = QuicTime::Delta::FromMilliseconds(3000); + + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), quiche::SimpleBufferAllocator::Get()); + + memory_cache_backend_.AddResponse(host, path, std::move(response_headers_), + body); + auto did_delay_succeed = + memory_cache_backend_.SetResponseDelay(host, path, delay); + EXPECT_TRUE(did_delay_succeed); + auto did_invalid_delay_succeed = + memory_cache_backend_.SetResponseDelay(host, "nonsense", delay); + EXPECT_FALSE(did_invalid_delay_succeed); + std::unique_ptr alarm(connection_->alarm_factory()->CreateAlarm( + new AlarmTestDelegate(stream_))); + alarm->Set(connection_->clock()->Now() + delay); + QuicStreamPeer::SetFinReceived(stream_); + InSequence s; + EXPECT_CALL(*stream_, FireAlarmMock()); + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + + if (UsesHttp3()) { + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, body.length(), _, FIN, _, _)); + + stream_->DoSendResponse(); + simulator_.RunFor(delay); + + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, TestSendErrorResponse) { + QuicStreamPeer::SetFinReceived(stream_); + + InSequence s; + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + if (UsesHttp3()) { + EXPECT_CALL(session_, + WritevData(_, kDataFrameHeaderLength, _, NO_FIN, _, _)); + } + EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); + + stream_->DoSendErrorResponse(); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, InvalidMultipleContentLength) { + spdy::Http2HeaderBlock request_headers; + // \000 is a way to write the null byte when followed by a literal digit. + header_list_.OnHeader("content-length", absl::string_view("11\00012", 5)); + + if (session_.version().UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))); + } + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + stream_->OnStreamHeaderList(true, kFakeFrameLen, header_list_); + + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->reading_stopped()); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, InvalidLeadingNullContentLength) { + spdy::Http2HeaderBlock request_headers; + // \000 is a way to write the null byte when followed by a literal digit. + header_list_.OnHeader("content-length", absl::string_view("\00012", 3)); + + if (session_.version().UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))); + } + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + stream_->OnStreamHeaderList(true, kFakeFrameLen, header_list_); + + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->reading_stopped()); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, InvalidMultipleContentLengthII) { + spdy::Http2HeaderBlock request_headers; + // \000 is a way to write the null byte when followed by a literal digit. + header_list_.OnHeader("content-length", absl::string_view("11\00011", 5)); + + if (session_.version().UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))); + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + } + + stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); + + if (session_.version().UsesHttp3()) { + EXPECT_TRUE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_TRUE(stream_->reading_stopped()); + EXPECT_TRUE(stream_->write_side_closed()); + } else { + EXPECT_EQ(11, stream_->content_length()); + EXPECT_FALSE(QuicStreamPeer::read_side_closed(stream_)); + EXPECT_FALSE(stream_->reading_stopped()); + EXPECT_FALSE(stream_->write_side_closed()); + } +} + +TEST_P(QuicSimpleServerStreamTest, + DoNotSendQuicRstStreamNoErrorWithRstReceived) { + EXPECT_FALSE(stream_->reading_stopped()); + + if (VersionUsesHttp3(connection_->transport_version())) { + // Unidirectional stream type and then a Stream Cancellation instruction is + // sent on the QPACK decoder stream. Ignore these writes without any + // assumption on their number or size. + auto* qpack_decoder_stream = + QuicSpdySessionPeer::GetQpackDecoderSendStream(&session_); + EXPECT_CALL(session_, WritevData(qpack_decoder_stream->id(), _, _, _, _, _)) + .Times(AnyNumber()); + } + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, + session_.version().UsesHttp3() + ? QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED) + : QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), + _)) + .Times(1); + QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED, 1234); + stream_->OnStreamReset(rst_frame); + if (VersionHasIetfQuicFrames(connection_->transport_version())) { + EXPECT_CALL(session_owner_, OnStopSendingReceived(_)); + // Create and inject a STOP SENDING frame to complete the close + // of the stream. This is only needed for version 99/IETF QUIC. + QuicStopSendingFrame stop_sending(kInvalidControlFrameId, stream_->id(), + QUIC_STREAM_CANCELLED); + session_.OnStopSendingFrame(stop_sending); + } + EXPECT_TRUE(stream_->reading_stopped()); + EXPECT_TRUE(stream_->write_side_closed()); +} + +TEST_P(QuicSimpleServerStreamTest, InvalidHeadersWithFin) { + char arr[] = { + 0x3a, 0x68, 0x6f, 0x73, // :hos + 0x74, 0x00, 0x00, 0x00, // t... + 0x00, 0x00, 0x00, 0x00, // .... + 0x07, 0x3a, 0x6d, 0x65, // .:me + 0x74, 0x68, 0x6f, 0x64, // thod + 0x00, 0x00, 0x00, 0x03, // .... + 0x47, 0x45, 0x54, 0x00, // GET. + 0x00, 0x00, 0x05, 0x3a, // ...: + 0x70, 0x61, 0x74, 0x68, // path + 0x00, 0x00, 0x00, 0x04, // .... + 0x2f, 0x66, 0x6f, 0x6f, // /foo + 0x00, 0x00, 0x00, 0x07, // .... + 0x3a, 0x73, 0x63, 0x68, // :sch + 0x65, 0x6d, 0x65, 0x00, // eme. + 0x00, 0x00, 0x00, 0x00, // .... + 0x00, 0x00, 0x08, 0x3a, // ...: + 0x76, 0x65, 0x72, 0x73, // vers + '\x96', 0x6f, 0x6e, 0x00, // on. + 0x00, 0x00, 0x08, 0x48, // ...H + 0x54, 0x54, 0x50, 0x2f, // TTP/ + 0x31, 0x2e, 0x31, // 1.1 + }; + absl::string_view data(arr, ABSL_ARRAYSIZE(arr)); + QuicStreamFrame frame(stream_->id(), true, 0, data); + // Verify that we don't crash when we get a invalid headers in stream frame. + stream_->OnStreamFrame(frame); +} + +// Basic QuicSimpleServerBackend that implements its behavior through mocking. +class TestQuicSimpleServerBackend : public QuicSimpleServerBackend { + public: + TestQuicSimpleServerBackend() = default; + ~TestQuicSimpleServerBackend() override = default; + + // QuicSimpleServerBackend: + bool InitializeBackend(const std::string& /*backend_url*/) override { + return true; + } + bool IsBackendInitialized() const override { return true; } + MOCK_METHOD(void, FetchResponseFromBackend, + (const spdy::Http2HeaderBlock&, const std::string&, + RequestHandler*), + (override)); + MOCK_METHOD(void, HandleConnectHeaders, + (const spdy::Http2HeaderBlock&, RequestHandler*), (override)); + MOCK_METHOD(void, HandleConnectData, + (absl::string_view, bool, RequestHandler*), (override)); + void CloseBackendResponseStream( + RequestHandler* /*request_handler*/) override {} +}; + +ACTION_P(SendHeadersResponse, response_ptr) { + arg1->OnResponseBackendComplete(response_ptr); +} + +ACTION_P(SendStreamData, data, close_stream) { + arg2->SendStreamData(data, close_stream); +} + +ACTION_P(TerminateStream, error) { arg1->TerminateStreamWithError(error); } + +TEST_P(QuicSimpleServerStreamTest, ConnectSendsIntermediateResponses) { + auto test_backend = std::make_unique(); + TestQuicSimpleServerBackend* test_backend_ptr = test_backend.get(); + ReplaceBackend(std::move(test_backend)); + + constexpr absl::string_view kRequestBody = "\x11\x11"; + spdy::Http2HeaderBlock response_headers; + response_headers[":status"] = "200"; + QuicBackendResponse headers_response; + headers_response.set_headers(response_headers.Clone()); + headers_response.set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); + constexpr absl::string_view kBody1 = "\x22\x22"; + constexpr absl::string_view kBody2 = "\x33\x33"; + + // Expect an initial headers-only request to result in a headers-only + // incomplete response. Then a data frame without fin, resulting in stream + // data. Then a data frame with fin, resulting in stream data with fin. + InSequence s; + EXPECT_CALL(*test_backend_ptr, HandleConnectHeaders(_, _)) + .WillOnce(SendHeadersResponse(&headers_response)); + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(*test_backend_ptr, HandleConnectData(kRequestBody, false, _)) + .WillOnce(SendStreamData(kBody1, + /*close_stream=*/false)); + EXPECT_CALL(*stream_, WriteOrBufferBody(kBody1, false)); + EXPECT_CALL(*test_backend_ptr, HandleConnectData(kRequestBody, true, _)) + .WillOnce(SendStreamData(kBody2, + /*close_stream=*/true)); + EXPECT_CALL(*stream_, WriteOrBufferBody(kBody2, true)); + + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeaderBlockEnd(128, 128); + + stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, header_list); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + kRequestBody.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = UsesHttp3() + ? absl::StrCat(header.AsStringView(), kRequestBody) + : std::string(kRequestBody); + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/true, data.length(), data)); + + // Expect to not go through SendResponse(). + EXPECT_FALSE(stream_->send_response_was_called()); + EXPECT_FALSE(stream_->send_error_response_was_called()); +} + +TEST_P(QuicSimpleServerStreamTest, ErrorOnUnhandledConnect) { + // Expect single set of failure response headers with FIN in response to the + // headers. Then, expect abrupt stream termination in response to the body. + EXPECT_CALL(*stream_, WriteHeadersMock(true)); + EXPECT_CALL(session_, MaybeSendRstStreamFrame(stream_->id(), _, _)); + + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeaderBlockEnd(128, 128); + constexpr absl::string_view kRequestBody = "\x11\x11"; + + stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, header_list); + quiche::QuicheBuffer header = HttpEncoder::SerializeDataFrameHeader( + kRequestBody.length(), quiche::SimpleBufferAllocator::Get()); + std::string data = UsesHttp3() + ? absl::StrCat(header.AsStringView(), kRequestBody) + : std::string(kRequestBody); + stream_->OnStreamFrame( + QuicStreamFrame(stream_->id(), /*fin=*/true, /*offset=*/0, data)); + + // Expect failure to not go through SendResponse(). + EXPECT_FALSE(stream_->send_response_was_called()); + EXPECT_FALSE(stream_->send_error_response_was_called()); +} + +TEST_P(QuicSimpleServerStreamTest, ConnectWithInvalidHeader) { + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + // QUIC requires lower-case header names. + header_list.OnHeader("InVaLiD-HeAdEr", "Well that's just wrong!"); + header_list.OnHeaderBlockEnd(128, 128); + + if (UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))) + .Times(1); + } else { + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_NO_ERROR), _)) + .Times(1); + } + EXPECT_CALL(*stream_, WriteHeadersMock(/*fin=*/false)); + stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, header_list); + EXPECT_FALSE(stream_->send_response_was_called()); + EXPECT_TRUE(stream_->send_error_response_was_called()); +} + +TEST_P(QuicSimpleServerStreamTest, BackendCanTerminateStream) { + auto test_backend = std::make_unique(); + TestQuicSimpleServerBackend* test_backend_ptr = test_backend.get(); + ReplaceBackend(std::move(test_backend)); + + EXPECT_CALL(session_, WritevData(_, _, _, _, _, _)) + .WillRepeatedly( + Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); + + QuicResetStreamError expected_error = + QuicResetStreamError::FromInternal(QUIC_STREAM_CONNECT_ERROR); + EXPECT_CALL(*test_backend_ptr, HandleConnectHeaders(_, _)) + .WillOnce(TerminateStream(expected_error)); + EXPECT_CALL(session_, + MaybeSendRstStreamFrame(stream_->id(), expected_error, _)); + + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, header_list); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/quic_spdy_client_base.cc b/quiche/quic/tools/quic_spdy_client_base.cc new file mode 100644 index 000000000000..b46e78038b1e --- /dev/null +++ b/quiche/quic/tools/quic_spdy_client_base.cc @@ -0,0 +1,282 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_spdy_client_base.h" + +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_random.h" +#include "quiche/quic/core/http/spdy_utils.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/platform/api/quic_flags.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/quiche_text_utils.h" + +using spdy::Http2HeaderBlock; + +namespace quic { + +void QuicSpdyClientBase::ClientQuicDataToResend::Resend() { + client_->SendRequest(*headers_, body_, fin_); + headers_ = nullptr; +} + +QuicSpdyClientBase::QuicDataToResend::QuicDataToResend( + std::unique_ptr headers, absl::string_view body, bool fin) + : headers_(std::move(headers)), body_(body), fin_(fin) {} + +QuicSpdyClientBase::QuicDataToResend::~QuicDataToResend() = default; + +QuicSpdyClientBase::QuicSpdyClientBase( + const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, const QuicConfig& config, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache) + : QuicClientBase(server_id, supported_versions, config, helper, + alarm_factory, std::move(network_helper), + std::move(proof_verifier), std::move(session_cache)), + store_response_(false), + latest_response_code_(-1) {} + +QuicSpdyClientBase::~QuicSpdyClientBase() { + // We own the push promise index. We need to explicitly kill + // the session before the push promise index goes out of scope. + ResetSession(); +} + +QuicSpdyClientSession* QuicSpdyClientBase::client_session() { + return static_cast(QuicClientBase::session()); +} + +const QuicSpdyClientSession* QuicSpdyClientBase::client_session() const { + return static_cast(QuicClientBase::session()); +} + +void QuicSpdyClientBase::InitializeSession() { + if (max_inbound_header_list_size_ > 0) { + client_session()->set_max_inbound_header_list_size( + max_inbound_header_list_size_); + } + client_session()->Initialize(); + client_session()->CryptoConnect(); +} + +void QuicSpdyClientBase::OnClose(QuicSpdyStream* stream) { + QUICHE_DCHECK(stream != nullptr); + QuicSpdyClientStream* client_stream = + static_cast(stream); + + const Http2HeaderBlock& response_headers = client_stream->response_headers(); + if (response_listener_ != nullptr) { + response_listener_->OnCompleteResponse(stream->id(), response_headers, + client_stream->data()); + } + + // Store response headers and body. + if (store_response_) { + auto status = response_headers.find(":status"); + if (status == response_headers.end()) { + QUIC_LOG(ERROR) << "Missing :status response header"; + } else if (!absl::SimpleAtoi(status->second, &latest_response_code_)) { + QUIC_LOG(ERROR) << "Invalid :status response header: " << status->second; + } + latest_response_headers_ = response_headers.DebugString(); + preliminary_response_headers_ = + client_stream->preliminary_headers().DebugString(); + latest_response_header_block_ = response_headers.Clone(); + latest_response_body_ = std::string(client_stream->data()); + latest_response_trailers_ = + client_stream->received_trailers().DebugString(); + } +} + +std::unique_ptr QuicSpdyClientBase::CreateQuicClientSession( + const quic::ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) { + return std::make_unique( + *config(), supported_versions, connection, server_id(), crypto_config(), + &push_promise_index_); +} + +void QuicSpdyClientBase::SendRequest(const Http2HeaderBlock& headers, + absl::string_view body, bool fin) { + if (GetQuicFlag(quic_client_convert_http_header_name_to_lowercase)) { + QUIC_CODE_COUNT(quic_client_convert_http_header_name_to_lowercase); + Http2HeaderBlock sanitized_headers; + for (const auto& p : headers) { + sanitized_headers[quiche::QuicheTextUtils::ToLower(p.first)] = p.second; + } + + SendRequestInternal(std::move(sanitized_headers), body, fin); + } else { + SendRequestInternal(headers.Clone(), body, fin); + } +} + +void QuicSpdyClientBase::SendRequestInternal(Http2HeaderBlock sanitized_headers, + absl::string_view body, bool fin) { + QuicClientPushPromiseIndex::TryHandle* handle; + QuicAsyncStatus rv = + push_promise_index()->Try(sanitized_headers, this, &handle); + if (rv == QUIC_SUCCESS) return; + + if (rv == QUIC_PENDING) { + // May need to retry request if asynchronous rendezvous fails. + AddPromiseDataToResend(sanitized_headers, body, fin); + return; + } + + QuicSpdyClientStream* stream = CreateClientStream(); + if (stream == nullptr) { + QUIC_BUG(quic_bug_10949_1) << "stream creation failed!"; + return; + } + stream->SendRequest(std::move(sanitized_headers), body, fin); +} + +void QuicSpdyClientBase::SendRequestAndWaitForResponse( + const Http2HeaderBlock& headers, absl::string_view body, bool fin) { + SendRequest(headers, body, fin); + while (WaitForEvents()) { + } +} + +void QuicSpdyClientBase::SendRequestsAndWaitForResponse( + const std::vector& url_list) { + for (size_t i = 0; i < url_list.size(); ++i) { + Http2HeaderBlock headers; + if (!SpdyUtils::PopulateHeaderBlockFromUrl(url_list[i], &headers)) { + QUIC_BUG(quic_bug_10949_2) << "Unable to create request"; + continue; + } + SendRequest(headers, "", true); + } + while (WaitForEvents()) { + } +} + +QuicSpdyClientStream* QuicSpdyClientBase::CreateClientStream() { + if (!connected()) { + return nullptr; + } + if (VersionHasIetfQuicFrames(client_session()->transport_version())) { + // Process MAX_STREAMS from peer or wait for liveness testing succeeds. + while (!client_session()->CanOpenNextOutgoingBidirectionalStream()) { + network_helper()->RunEventLoop(); + } + } + auto* stream = static_cast( + client_session()->CreateOutgoingBidirectionalStream()); + if (stream) { + stream->set_visitor(this); + } + return stream; +} + +bool QuicSpdyClientBase::goaway_received() const { + return client_session() && client_session()->goaway_received(); +} + +bool QuicSpdyClientBase::EarlyDataAccepted() { + return client_session()->EarlyDataAccepted(); +} + +bool QuicSpdyClientBase::ReceivedInchoateReject() { + return client_session()->ReceivedInchoateReject(); +} + +int QuicSpdyClientBase::GetNumSentClientHellosFromSession() { + return client_session()->GetNumSentClientHellos(); +} + +int QuicSpdyClientBase::GetNumReceivedServerConfigUpdatesFromSession() { + return client_session()->GetNumReceivedServerConfigUpdates(); +} + +void QuicSpdyClientBase::MaybeAddQuicDataToResend( + std::unique_ptr data_to_resend) { + data_to_resend_on_connect_.push_back(std::move(data_to_resend)); +} + +void QuicSpdyClientBase::ClearDataToResend() { + data_to_resend_on_connect_.clear(); +} + +void QuicSpdyClientBase::ResendSavedData() { + // Calling Resend will re-enqueue the data, so swap out + // data_to_resend_on_connect_ before iterating. + std::vector> old_data; + old_data.swap(data_to_resend_on_connect_); + for (const auto& data : old_data) { + data->Resend(); + } +} + +void QuicSpdyClientBase::AddPromiseDataToResend(const Http2HeaderBlock& headers, + absl::string_view body, + bool fin) { + std::unique_ptr new_headers( + new Http2HeaderBlock(headers.Clone())); + push_promise_data_to_resend_.reset( + new ClientQuicDataToResend(std::move(new_headers), body, fin, this)); +} + +bool QuicSpdyClientBase::CheckVary( + const Http2HeaderBlock& /*client_request*/, + const Http2HeaderBlock& /*promise_request*/, + const Http2HeaderBlock& /*promise_response*/) { + return true; +} + +void QuicSpdyClientBase::OnRendezvousResult(QuicSpdyStream* stream) { + std::unique_ptr data_to_resend = + std::move(push_promise_data_to_resend_); + if (stream) { + stream->set_visitor(this); + stream->OnBodyAvailable(); + } else if (data_to_resend) { + data_to_resend->Resend(); + } +} + +int QuicSpdyClientBase::latest_response_code() const { + QUIC_BUG_IF(quic_bug_10949_3, !store_response_) << "Response not stored!"; + return latest_response_code_; +} + +const std::string& QuicSpdyClientBase::latest_response_headers() const { + QUIC_BUG_IF(quic_bug_10949_4, !store_response_) << "Response not stored!"; + return latest_response_headers_; +} + +const std::string& QuicSpdyClientBase::preliminary_response_headers() const { + QUIC_BUG_IF(quic_bug_10949_5, !store_response_) << "Response not stored!"; + return preliminary_response_headers_; +} + +const Http2HeaderBlock& QuicSpdyClientBase::latest_response_header_block() + const { + QUIC_BUG_IF(quic_bug_10949_6, !store_response_) << "Response not stored!"; + return latest_response_header_block_; +} + +const std::string& QuicSpdyClientBase::latest_response_body() const { + QUIC_BUG_IF(quic_bug_10949_7, !store_response_) << "Response not stored!"; + return latest_response_body_; +} + +const std::string& QuicSpdyClientBase::latest_response_trailers() const { + QUIC_BUG_IF(quic_bug_10949_8, !store_response_) << "Response not stored!"; + return latest_response_trailers_; +} + +bool QuicSpdyClientBase::HasActiveRequests() { + return client_session()->HasActiveRequestStreams(); +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_spdy_client_base.h b/quiche/quic/tools/quic_spdy_client_base.h new file mode 100644 index 000000000000..28c3e96693bc --- /dev/null +++ b/quiche/quic/tools/quic_spdy_client_base.h @@ -0,0 +1,237 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A base class for the toy client, which connects to a specified port and sends +// QUIC request to that endpoint. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SPDY_CLIENT_BASE_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SPDY_CLIENT_BASE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/crypto_handshake.h" +#include "quiche/quic/core/http/quic_client_push_promise_index.h" +#include "quiche/quic/core/http/quic_spdy_client_session.h" +#include "quiche/quic/core/http/quic_spdy_client_stream.h" +#include "quiche/quic/core/quic_config.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/quic_client_base.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +class ProofVerifier; +class QuicServerId; +class SessionCache; + +class QuicSpdyClientBase : public QuicClientBase, + public QuicClientPushPromiseIndex::Delegate, + public QuicSpdyStream::Visitor { + public: + // A ResponseListener is notified when a complete response is received. + class ResponseListener { + public: + ResponseListener() {} + virtual ~ResponseListener() {} + virtual void OnCompleteResponse( + QuicStreamId id, const spdy::Http2HeaderBlock& response_headers, + absl::string_view response_body) = 0; + }; + + // A piece of data that can be sent multiple times. For example, it can be a + // HTTP request that is resent after a connect=>version negotiation=>reconnect + // sequence. + class QuicDataToResend { + public: + // |headers| may be null, since it's possible to send data without headers. + QuicDataToResend(std::unique_ptr headers, + absl::string_view body, bool fin); + QuicDataToResend(const QuicDataToResend&) = delete; + QuicDataToResend& operator=(const QuicDataToResend&) = delete; + + virtual ~QuicDataToResend(); + + // Must be overridden by specific classes with the actual method for + // re-sending data. + virtual void Resend() = 0; + + protected: + std::unique_ptr headers_; + absl::string_view body_; + bool fin_; + }; + + QuicSpdyClientBase(const QuicServerId& server_id, + const ParsedQuicVersionVector& supported_versions, + const QuicConfig& config, + QuicConnectionHelperInterface* helper, + QuicAlarmFactory* alarm_factory, + std::unique_ptr network_helper, + std::unique_ptr proof_verifier, + std::unique_ptr session_cache); + QuicSpdyClientBase(const QuicSpdyClientBase&) = delete; + QuicSpdyClientBase& operator=(const QuicSpdyClientBase&) = delete; + + ~QuicSpdyClientBase() override; + + // QuicSpdyStream::Visitor + void OnClose(QuicSpdyStream* stream) override; + + // A spdy session has to call CryptoConnect on top of the regular + // initialization. + void InitializeSession() override; + + // Sends an HTTP request and does not wait for response before returning. + void SendRequest(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin); + + // Sends an HTTP request and waits for response before returning. + void SendRequestAndWaitForResponse(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin); + + // Sends a request simple GET for each URL in |url_list|, and then waits for + // each to complete. + void SendRequestsAndWaitForResponse(const std::vector& url_list); + + // Returns a newly created QuicSpdyClientStream. + virtual QuicSpdyClientStream* CreateClientStream(); + + // Returns a the session used for this client downcasted to a + // QuicSpdyClientSession. + QuicSpdyClientSession* client_session(); + const QuicSpdyClientSession* client_session() const; + + QuicClientPushPromiseIndex* push_promise_index() { + return &push_promise_index_; + } + + bool CheckVary(const spdy::Http2HeaderBlock& client_request, + const spdy::Http2HeaderBlock& promise_request, + const spdy::Http2HeaderBlock& promise_response) override; + void OnRendezvousResult(QuicSpdyStream*) override; + + // If the crypto handshake has not yet been confirmed, adds the data to the + // queue of data to resend if the client receives a stateless reject. + // Otherwise, deletes the data. + void MaybeAddQuicDataToResend( + std::unique_ptr data_to_resend); + + void set_store_response(bool val) { store_response_ = val; } + + int latest_response_code() const; + const std::string& latest_response_headers() const; + const std::string& preliminary_response_headers() const; + const spdy::Http2HeaderBlock& latest_response_header_block() const; + const std::string& latest_response_body() const; + const std::string& latest_response_trailers() const; + + void set_response_listener(std::unique_ptr listener) { + response_listener_ = std::move(listener); + } + + void set_drop_response_body(bool drop_response_body) { + drop_response_body_ = drop_response_body; + } + bool drop_response_body() const { return drop_response_body_; } + + void set_enable_web_transport(bool enable_web_transport) { + enable_web_transport_ = enable_web_transport; + } + bool enable_web_transport() const { return enable_web_transport_; } + + void set_use_datagram_contexts(bool use_datagram_contexts) { + use_datagram_contexts_ = use_datagram_contexts; + } + bool use_datagram_contexts() const { return use_datagram_contexts_; } + + // QuicClientBase methods. + bool goaway_received() const override; + bool EarlyDataAccepted() override; + bool ReceivedInchoateReject() override; + + void set_max_inbound_header_list_size(size_t size) { + max_inbound_header_list_size_ = size; + } + + protected: + int GetNumSentClientHellosFromSession() override; + int GetNumReceivedServerConfigUpdatesFromSession() override; + + // Takes ownership of |connection|. + std::unique_ptr CreateQuicClientSession( + const quic::ParsedQuicVersionVector& supported_versions, + QuicConnection* connection) override; + + void ClearDataToResend() override; + + void ResendSavedData() override; + + void AddPromiseDataToResend(const spdy::Http2HeaderBlock& headers, + absl::string_view body, bool fin); + bool HasActiveRequests() override; + + private: + // Specific QuicClient class for storing data to resend. + class ClientQuicDataToResend : public QuicDataToResend { + public: + ClientQuicDataToResend(std::unique_ptr headers, + absl::string_view body, bool fin, + QuicSpdyClientBase* client) + : QuicDataToResend(std::move(headers), body, fin), client_(client) { + QUICHE_DCHECK(headers_); + QUICHE_DCHECK(client); + } + + ClientQuicDataToResend(const ClientQuicDataToResend&) = delete; + ClientQuicDataToResend& operator=(const ClientQuicDataToResend&) = delete; + ~ClientQuicDataToResend() override {} + + void Resend() override; + + private: + QuicSpdyClientBase* client_; + }; + + void SendRequestInternal(spdy::Http2HeaderBlock sanitized_headers, + absl::string_view body, bool fin); + + // Index of pending promised streams. Must outlive |session_|. + QuicClientPushPromiseIndex push_promise_index_; + + // If true, store the latest response code, headers, and body. + bool store_response_; + // HTTP response code from most recent response. + int latest_response_code_; + // HTTP/2 headers from most recent response. + std::string latest_response_headers_; + // preliminary 100 Continue HTTP/2 headers from most recent response, if any. + std::string preliminary_response_headers_; + // HTTP/2 headers from most recent response. + spdy::Http2HeaderBlock latest_response_header_block_; + // Body of most recent response. + std::string latest_response_body_; + // HTTP/2 trailers from most recent response. + std::string latest_response_trailers_; + + // Listens for full responses. + std::unique_ptr response_listener_; + + // Keeps track of any data that must be resent upon a subsequent successful + // connection, in case the client receives a stateless reject. + std::vector> data_to_resend_on_connect_; + + std::unique_ptr push_promise_data_to_resend_; + + bool drop_response_body_ = false; + bool enable_web_transport_ = false; + bool use_datagram_contexts_ = false; + // If not zero, used to set client's max inbound header size before session + // initialize. + size_t max_inbound_header_list_size_ = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SPDY_CLIENT_BASE_H_ diff --git a/quiche/quic/tools/quic_spdy_server_base.h b/quiche/quic/tools/quic_spdy_server_base.h new file mode 100644 index 000000000000..39bde905de2e --- /dev/null +++ b/quiche/quic/tools/quic_spdy_server_base.h @@ -0,0 +1,30 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A toy server, which connects to a specified port and sends QUIC +// requests to that endpoint. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_SPDY_SERVER_BASE_H_ +#define QUICHE_QUIC_TOOLS_QUIC_SPDY_SERVER_BASE_H_ + +#include "quiche/quic/platform/api/quic_socket_address.h" + +namespace quic { + +// Base class for service instances to be used with QuicToyServer. +class QuicSpdyServerBase { + public: + virtual ~QuicSpdyServerBase() = default; + + // Creates a UDP socket and listens on |address|. Returns true on success + // and false otherwise. + virtual bool CreateUDPSocketAndListen(const QuicSocketAddress& address) = 0; + + // Handles incoming requests. Does not return. + virtual void HandleEventsForever() = 0; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_SPDY_SERVER_BASE_H_ diff --git a/quiche/quic/tools/quic_tcp_like_trace_converter.cc b/quiche/quic/tools/quic_tcp_like_trace_converter.cc new file mode 100644 index 000000000000..b3d5cec6ed0f --- /dev/null +++ b/quiche/quic/tools/quic_tcp_like_trace_converter.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_tcp_like_trace_converter.h" + +#include "quiche/quic/core/quic_constants.h" +#include "quiche/quic/platform/api/quic_bug_tracker.h" + +namespace quic { + +QuicTcpLikeTraceConverter::QuicTcpLikeTraceConverter() + : largest_observed_control_frame_id_(kInvalidControlFrameId), + connection_offset_(0) {} + +QuicTcpLikeTraceConverter::StreamOffsetSegment::StreamOffsetSegment() + : connection_offset(0) {} + +QuicTcpLikeTraceConverter::StreamOffsetSegment::StreamOffsetSegment( + QuicStreamOffset stream_offset, uint64_t connection_offset, + QuicByteCount data_length) + : stream_data(stream_offset, stream_offset + data_length), + connection_offset(connection_offset) {} + +QuicTcpLikeTraceConverter::StreamInfo::StreamInfo() : fin(false) {} + +QuicIntervalSet QuicTcpLikeTraceConverter::OnCryptoFrameSent( + EncryptionLevel level, QuicStreamOffset offset, QuicByteCount data_length) { + if (level >= NUM_ENCRYPTION_LEVELS) { + QUIC_BUG(quic_bug_10907_1) << "Invalid encryption level"; + return {}; + } + return OnFrameSent(offset, data_length, /*fin=*/false, + &crypto_frames_info_[level]); +} + +QuicIntervalSet QuicTcpLikeTraceConverter::OnStreamFrameSent( + QuicStreamId stream_id, QuicStreamOffset offset, QuicByteCount data_length, + bool fin) { + return OnFrameSent( + offset, data_length, fin, + &streams_info_.emplace(stream_id, StreamInfo()).first->second); +} + +QuicIntervalSet QuicTcpLikeTraceConverter::OnFrameSent( + QuicStreamOffset offset, QuicByteCount data_length, bool fin, + StreamInfo* info) { + QuicIntervalSet connection_offsets; + if (fin) { + // Stream fin consumes a connection offset. + ++data_length; + } + // Get connection offsets of retransmission data in this frame. + for (const auto& segment : info->segments) { + QuicInterval retransmission(offset, offset + data_length); + retransmission.IntersectWith(segment.stream_data); + if (retransmission.Empty()) { + continue; + } + const uint64_t connection_offset = segment.connection_offset + + retransmission.min() - + segment.stream_data.min(); + connection_offsets.Add(connection_offset, + connection_offset + retransmission.Length()); + } + + if (info->fin) { + return connection_offsets; + } + + // Get connection offsets of new data in this frame. + QuicStreamOffset least_unsent_offset = + info->segments.empty() ? 0 : info->segments.back().stream_data.max(); + if (least_unsent_offset >= offset + data_length) { + return connection_offsets; + } + // Ignore out-of-order stream data so that as connection offset increases, + // stream offset increases. + QuicStreamOffset new_data_offset = std::max(least_unsent_offset, offset); + QuicByteCount new_data_length = offset + data_length - new_data_offset; + connection_offsets.Add(connection_offset_, + connection_offset_ + new_data_length); + if (!info->segments.empty() && new_data_offset == least_unsent_offset && + connection_offset_ == info->segments.back().connection_offset + + info->segments.back().stream_data.Length()) { + // Extend the last segment if both stream and connection offsets are + // contiguous. + info->segments.back().stream_data.SetMax(new_data_offset + new_data_length); + } else { + info->segments.emplace_back(new_data_offset, connection_offset_, + new_data_length); + } + info->fin = fin; + connection_offset_ += new_data_length; + + return connection_offsets; +} + +QuicInterval QuicTcpLikeTraceConverter::OnControlFrameSent( + QuicControlFrameId control_frame_id, QuicByteCount control_frame_length) { + if (control_frame_id > largest_observed_control_frame_id_) { + // New control frame. + QuicInterval connection_offset = QuicInterval( + connection_offset_, connection_offset_ + control_frame_length); + connection_offset_ += control_frame_length; + control_frames_info_[control_frame_id] = connection_offset; + largest_observed_control_frame_id_ = control_frame_id; + return connection_offset; + } + const auto iter = control_frames_info_.find(control_frame_id); + if (iter == control_frames_info_.end()) { + // Ignore out of order control frames. + return {}; + } + return iter->second; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_tcp_like_trace_converter.h b/quiche/quic/tools/quic_tcp_like_trace_converter.h new file mode 100644 index 000000000000..3aeaa0fd5aab --- /dev/null +++ b/quiche/quic/tools/quic_tcp_like_trace_converter.h @@ -0,0 +1,85 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_TCP_LIKE_TRACE_CONVERTER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_TCP_LIKE_TRACE_CONVERTER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/quic/core/frames/quic_stream_frame.h" +#include "quiche/quic/core/quic_interval.h" +#include "quiche/quic/core/quic_interval_set.h" +#include "quiche/quic/core/quic_types.h" + +namespace quic { + +// This converter converts sent QUIC frames to connection byte offset (just like +// TCP byte sequence number). +class QuicTcpLikeTraceConverter { + public: + // StreamOffsetSegment stores a stream offset range which has contiguous + // connection offset. + struct StreamOffsetSegment { + StreamOffsetSegment(); + StreamOffsetSegment(QuicStreamOffset stream_offset, + uint64_t connection_offset, QuicByteCount data_length); + + QuicInterval stream_data; + uint64_t connection_offset; + }; + + QuicTcpLikeTraceConverter(); + QuicTcpLikeTraceConverter(const QuicTcpLikeTraceConverter& other) = delete; + QuicTcpLikeTraceConverter(QuicTcpLikeTraceConverter&& other) = delete; + + ~QuicTcpLikeTraceConverter() {} + + // Called when a crypto frame is sent. Returns the corresponding connection + // offsets. + QuicIntervalSet OnCryptoFrameSent(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length); + + // Called when a stream frame is sent. Returns the corresponding connection + // offsets. + QuicIntervalSet OnStreamFrameSent(QuicStreamId stream_id, + QuicStreamOffset offset, + QuicByteCount data_length, + bool fin); + + // Called when a control frame is sent. Returns the corresponding connection + // offsets. + QuicInterval OnControlFrameSent(QuicControlFrameId control_frame_id, + QuicByteCount control_frame_length); + + private: + struct StreamInfo { + StreamInfo(); + + // Stores contiguous connection offset pieces. + std::vector segments; + // Indicates whether fin has been sent. + bool fin; + }; + + // Called when frame with |offset|, |data_length| and |fin| has been sent. + // Update |info| and returns connection offsets. + QuicIntervalSet OnFrameSent(QuicStreamOffset offset, + QuicByteCount data_length, bool fin, + StreamInfo* info); + + StreamInfo crypto_frames_info_[NUM_ENCRYPTION_LEVELS]; + absl::flat_hash_map streams_info_; + absl::flat_hash_map> + control_frames_info_; + + QuicControlFrameId largest_observed_control_frame_id_; + + uint64_t connection_offset_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_TCP_LIKE_TRACE_CONVERTER_H_ diff --git a/quiche/quic/tools/quic_tcp_like_trace_converter_test.cc b/quiche/quic/tools/quic_tcp_like_trace_converter_test.cc new file mode 100644 index 000000000000..287a319bd4b1 --- /dev/null +++ b/quiche/quic/tools/quic_tcp_like_trace_converter_test.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_tcp_like_trace_converter.h" + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +TEST(QuicTcpLikeTraceConverterTest, BasicTest) { + QuicTcpLikeTraceConverter converter; + + EXPECT_EQ(QuicIntervalSet(0, 100), + converter.OnStreamFrameSent(1, 0, 100, false)); + EXPECT_EQ(QuicIntervalSet(100, 200), + converter.OnStreamFrameSent(3, 0, 100, false)); + EXPECT_EQ(QuicIntervalSet(200, 300), + converter.OnStreamFrameSent(3, 100, 100, false)); + EXPECT_EQ(QuicInterval(300, 450), + converter.OnControlFrameSent(2, 150)); + EXPECT_EQ(QuicIntervalSet(450, 550), + converter.OnStreamFrameSent(1, 100, 100, false)); + EXPECT_EQ(QuicInterval(550, 650), + converter.OnControlFrameSent(3, 100)); + EXPECT_EQ(QuicIntervalSet(650, 850), + converter.OnStreamFrameSent(3, 200, 200, false)); + EXPECT_EQ(QuicInterval(850, 1050), + converter.OnControlFrameSent(4, 200)); + EXPECT_EQ(QuicIntervalSet(1050, 1100), + converter.OnStreamFrameSent(1, 200, 50, false)); + EXPECT_EQ(QuicIntervalSet(1100, 1150), + converter.OnStreamFrameSent(1, 250, 50, false)); + EXPECT_EQ(QuicIntervalSet(1150, 1350), + converter.OnStreamFrameSent(3, 400, 200, false)); + + // Stream 1 retransmits [50, 300) and sends new data [300, 350) in the same + // frame. + QuicIntervalSet expected; + expected.Add(50, 100); + expected.Add(450, 550); + expected.Add(1050, 1150); + expected.Add(1350, 1401); + EXPECT_EQ(expected, converter.OnStreamFrameSent(1, 50, 300, true)); + + expected.Clear(); + // Stream 3 retransmits [150, 500). + expected.Add(250, 300); + expected.Add(650, 850); + expected.Add(1150, 1250); + EXPECT_EQ(expected, converter.OnStreamFrameSent(3, 150, 350, false)); + + // Stream 3 retransmits [300, 600) and sends new data [600, 800) in the same + // frame. + expected.Clear(); + expected.Add(750, 850); + expected.Add(1150, 1350); + expected.Add(1401, 1602); + EXPECT_EQ(expected, converter.OnStreamFrameSent(3, 300, 500, true)); + + // Stream 3 retransmits fin only frame. + expected.Clear(); + expected.Add(1601, 1602); + EXPECT_EQ(expected, converter.OnStreamFrameSent(3, 800, 0, true)); + + QuicInterval expected2; + // Ignore out of order control frames. + EXPECT_EQ(expected2, converter.OnControlFrameSent(1, 100)); + + // Ignore passed in length for retransmitted frame. + expected2 = {300, 450}; + EXPECT_EQ(expected2, converter.OnControlFrameSent(2, 200)); + + expected2 = {1602, 1702}; + EXPECT_EQ(expected2, converter.OnControlFrameSent(10, 100)); +} + +TEST(QuicTcpLikeTraceConverterTest, FuzzerTest) { + QuicTcpLikeTraceConverter converter; + // Stream does not start from offset 0. + EXPECT_EQ(QuicIntervalSet(0, 100), + converter.OnStreamFrameSent(1, 100, 100, false)); + EXPECT_EQ(QuicIntervalSet(100, 300), + converter.OnStreamFrameSent(3, 200, 200, false)); + // Stream does not send data contiguously. + EXPECT_EQ(QuicIntervalSet(300, 400), + converter.OnStreamFrameSent(1, 300, 100, false)); + + // Stream fills existing holes. + QuicIntervalSet expected; + expected.Add(0, 100); + expected.Add(300, 501); + EXPECT_EQ(expected, converter.OnStreamFrameSent(1, 0, 500, true)); + + // Stream sends frame after fin. + EXPECT_EQ(expected, converter.OnStreamFrameSent(1, 50, 600, false)); +} + +TEST(QuicTcpLikeTraceConverterTest, OnCryptoFrameSent) { + QuicTcpLikeTraceConverter converter; + + EXPECT_EQ(QuicIntervalSet(0, 100), + converter.OnCryptoFrameSent(ENCRYPTION_INITIAL, 0, 100)); + EXPECT_EQ(QuicIntervalSet(100, 200), + converter.OnStreamFrameSent(1, 0, 100, false)); + EXPECT_EQ(QuicIntervalSet(200, 300), + converter.OnStreamFrameSent(1, 100, 100, false)); + EXPECT_EQ(QuicIntervalSet(300, 400), + converter.OnCryptoFrameSent(ENCRYPTION_HANDSHAKE, 0, 100)); + EXPECT_EQ(QuicIntervalSet(400, 500), + converter.OnCryptoFrameSent(ENCRYPTION_HANDSHAKE, 100, 100)); + + // Verify crypto frame retransmission works as intended. + EXPECT_EQ(QuicIntervalSet(0, 100), + converter.OnCryptoFrameSent(ENCRYPTION_INITIAL, 0, 100)); + EXPECT_EQ(QuicIntervalSet(400, 500), + converter.OnCryptoFrameSent(ENCRYPTION_HANDSHAKE, 100, 100)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/quic_toy_client.cc b/quiche/quic/tools/quic_toy_client.cc new file mode 100644 index 000000000000..21c85d7fbcee --- /dev/null +++ b/quiche/quic/tools/quic_toy_client.cc @@ -0,0 +1,555 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A binary wrapper for QuicClient. +// Connects to a host using QUIC, sends a request to the provided URL, and +// displays the response. +// +// Some usage examples: +// +// Standard request/response: +// quic_client www.google.com +// quic_client www.google.com --quiet +// quic_client www.google.com --port=443 +// +// Use a specific version: +// quic_client www.google.com --quic_version=23 +// +// Send a POST instead of a GET: +// quic_client www.google.com --body="this is a POST body" +// +// Append additional headers to the request: +// quic_client www.google.com --headers="Header-A: 1234; Header-B: 5678" +// +// Connect to a host different to the URL being requested: +// quic_client mail.google.com --host=www.google.com +// +// Connect to a specific IP: +// IP=`dig www.google.com +short | head -1` +// quic_client www.google.com --host=${IP} +// +// Send repeated requests and change ephemeral port between requests +// quic_client www.google.com --num_requests=10 +// +// Try to connect to a host which does not speak QUIC: +// quic_client www.example.com +// +// This tool is available as a built binary at: +// /google/data/ro/teams/quic/tools/quic_client +// After submitting changes to this file, you will need to follow the +// instructions at go/quic_client_binary_update + +#include "quiche/quic/tools/quic_toy_client.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/crypto/quic_client_session_cache.h" +#include "quiche/quic/core/quic_packets.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_utils.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_default_proof_providers.h" +#include "quiche/quic/platform/api/quic_ip_address.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/fake_proof_verifier.h" +#include "quiche/quic/tools/quic_url.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace { + +using quiche::QuicheTextUtils; + +} // namespace + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, host, "", + "The IP or hostname to connect to. If not provided, the host " + "will be derived from the provided URL."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, port, 0, "The port to connect to."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, ip_version_for_host_lookup, "", + "Only used if host address lookup is needed. " + "4=ipv4; 6=ipv6; otherwise=don't care."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, body, "", + "If set, send a POST with this body."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, body_hex, "", + "If set, contents are converted from hex to ascii, before " + "sending as body of a POST. e.g. --body_hex=\"68656c6c6f\""); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, headers, "", + "A semicolon separated list of key:value pairs to " + "add to request headers."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(bool, quiet, false, + "Set to true for a quieter output experience."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, quic_version, "", + "QUIC version to speak, e.g. 21. If not set, then all available " + "versions are offered in the handshake. Also supports wire versions " + "such as Q043 or T099."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, connection_options, "", + "Connection options as ASCII tags separated by commas, " + "e.g. \"ABCD,EFGH\""); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, client_connection_options, "", + "Client connection options as ASCII tags separated by commas, " + "e.g. \"ABCD,EFGH\""); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(bool, quic_ietf_draft, false, + "Use the IETF draft version. This also enables " + "required internal QUIC flags."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, version_mismatch_ok, false, + "If true, a version mismatch in the handshake is not considered a " + "failure. Useful for probing a server to determine if it speaks " + "any version of QUIC."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, force_version_negotiation, false, + "If true, start by proposing a version that is reserved for version " + "negotiation."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, multi_packet_chlo, false, + "If true, add a transport parameter to make the ClientHello span two " + "packets. Only works with QUIC+TLS."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, redirect_is_success, true, + "If true, an HTTP response code of 3xx is considered to be a " + "successful response, otherwise a failure."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, initial_mtu, 0, + "Initial MTU of the connection."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + int32_t, num_requests, 1, + "How many sequential requests to make on a single connection."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, disable_certificate_verification, false, + "If true, don't verify the server certificate."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, default_client_cert, "", + "The path to the file containing PEM-encoded client default certificate to " + "be sent to the server, if server requested client certs."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, default_client_cert_key, "", + "The path to the file containing PEM-encoded private key of the client's " + "default certificate for signing, if server requested client certs."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, drop_response_body, false, + "If true, drop response body immediately after it is received."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, disable_port_changes, false, + "If true, do not change local port after each request."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(bool, one_connection_per_request, false, + "If true, close the connection after each " + "request. This allows testing 0-RTT."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, server_connection_id, "", + "If non-empty, the client will use the given server connection id for all " + "connections. The flag value is the hex-string of the on-wire connection id" + " bytes, e.g. '--server_connection_id=0123456789abcdef'."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + int32_t, server_connection_id_length, -1, + "Length of the server connection ID used. This flag has no effects if " + "--server_connection_id is non-empty."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, client_connection_id_length, -1, + "Length of the client connection ID used."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, max_time_before_crypto_handshake_ms, + 10000, + "Max time to wait before handshake completes."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + int32_t, max_inbound_header_list_size, 128 * 1024, + "Max inbound header list size. 0 means default."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(std::string, interface_name, "", + "Interface name to bind QUIC UDP sockets to."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, signing_algorithms_pref, "", + "A textual specification of a set of signature algorithms that can be " + "accepted by boring SSL SSL_set1_sigalgs_list()"); + +namespace quic { +namespace { + +// Creates a ClientProofSource which only contains a default client certificate. +// Return nullptr for failure. +std::unique_ptr CreateTestClientProofSource( + absl::string_view default_client_cert_file, + absl::string_view default_client_cert_key_file) { + std::ifstream cert_stream(std::string{default_client_cert_file}, + std::ios::binary); + std::vector certs = + CertificateView::LoadPemFromStream(&cert_stream); + if (certs.empty()) { + std::cerr << "Failed to load client certs." << std::endl; + return nullptr; + } + + std::ifstream key_stream(std::string{default_client_cert_key_file}, + std::ios::binary); + std::unique_ptr private_key = + CertificatePrivateKey::LoadPemFromStream(&key_stream); + if (private_key == nullptr) { + std::cerr << "Failed to load client cert key." << std::endl; + return nullptr; + } + + auto proof_source = std::make_unique(); + proof_source->AddCertAndKey( + {"*"}, + quiche::QuicheReferenceCountedPointer( + new ClientProofSource::Chain(certs)), + std::move(*private_key)); + + return proof_source; +} + +} // namespace + +QuicToyClient::QuicToyClient(ClientFactory* client_factory) + : client_factory_(client_factory) {} + +int QuicToyClient::SendRequestsAndPrintResponses( + std::vector urls) { + QuicUrl url(urls[0], "https"); + std::string host = quiche::GetQuicheCommandLineFlag(FLAGS_host); + if (host.empty()) { + host = url.host(); + } + int port = quiche::GetQuicheCommandLineFlag(FLAGS_port); + if (port == 0) { + port = url.port(); + } + + quic::ParsedQuicVersionVector versions = quic::CurrentSupportedVersions(); + + if (quiche::GetQuicheCommandLineFlag(FLAGS_quic_ietf_draft)) { + quic::QuicVersionInitializeSupportForIetfDraft(); + versions = {}; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.HasIetfQuicFrames() && + version.handshake_protocol == quic::PROTOCOL_TLS1_3) { + versions.push_back(version); + } + } + } + + std::string quic_version_string = + quiche::GetQuicheCommandLineFlag(FLAGS_quic_version); + if (!quic_version_string.empty()) { + versions = quic::ParseQuicVersionVectorString(quic_version_string); + } + + if (versions.empty()) { + std::cerr << "No known version selected." << std::endl; + return 1; + } + + for (const quic::ParsedQuicVersion& version : versions) { + quic::QuicEnableVersion(version); + } + + if (quiche::GetQuicheCommandLineFlag(FLAGS_force_version_negotiation)) { + versions.insert(versions.begin(), + quic::QuicVersionReservedForNegotiation()); + } + + const int32_t num_requests( + quiche::GetQuicheCommandLineFlag(FLAGS_num_requests)); + std::unique_ptr proof_verifier; + if (quiche::GetQuicheCommandLineFlag( + FLAGS_disable_certificate_verification)) { + proof_verifier = std::make_unique(); + } else { + proof_verifier = quic::CreateDefaultProofVerifier(url.host()); + } + std::unique_ptr session_cache; + if (num_requests > 1 && + quiche::GetQuicheCommandLineFlag(FLAGS_one_connection_per_request)) { + session_cache = std::make_unique(); + } + + QuicConfig config; + std::string connection_options_string = + quiche::GetQuicheCommandLineFlag(FLAGS_connection_options); + if (!connection_options_string.empty()) { + config.SetConnectionOptionsToSend( + ParseQuicTagVector(connection_options_string)); + } + std::string client_connection_options_string = + quiche::GetQuicheCommandLineFlag(FLAGS_client_connection_options); + if (!client_connection_options_string.empty()) { + config.SetClientConnectionOptions( + ParseQuicTagVector(client_connection_options_string)); + } + if (quiche::GetQuicheCommandLineFlag(FLAGS_multi_packet_chlo)) { + // Make the ClientHello span multiple packets by adding a custom transport + // parameter. + constexpr auto kCustomParameter = + static_cast(0x173E); + std::string custom_value(2000, '?'); + config.custom_transport_parameters_to_send()[kCustomParameter] = + custom_value; + } + config.set_max_time_before_crypto_handshake( + QuicTime::Delta::FromMilliseconds(quiche::GetQuicheCommandLineFlag( + FLAGS_max_time_before_crypto_handshake_ms))); + + int address_family_for_lookup = AF_UNSPEC; + if (quiche::GetQuicheCommandLineFlag(FLAGS_ip_version_for_host_lookup) == + "4") { + address_family_for_lookup = AF_INET; + } else if (quiche::GetQuicheCommandLineFlag( + FLAGS_ip_version_for_host_lookup) == "6") { + address_family_for_lookup = AF_INET6; + } + + // Build the client, and try to connect. + std::unique_ptr client = client_factory_->CreateClient( + url.host(), host, address_family_for_lookup, port, versions, config, + std::move(proof_verifier), std::move(session_cache)); + + if (client == nullptr) { + std::cerr << "Failed to create client." << std::endl; + return 1; + } + + if (!quiche::GetQuicheCommandLineFlag(FLAGS_default_client_cert).empty() && + !quiche::GetQuicheCommandLineFlag(FLAGS_default_client_cert_key) + .empty()) { + std::unique_ptr proof_source = + CreateTestClientProofSource( + quiche::GetQuicheCommandLineFlag(FLAGS_default_client_cert), + quiche::GetQuicheCommandLineFlag(FLAGS_default_client_cert_key)); + if (proof_source == nullptr) { + std::cerr << "Failed to create client proof source." << std::endl; + return 1; + } + client->crypto_config()->set_proof_source(std::move(proof_source)); + } + + int32_t initial_mtu = quiche::GetQuicheCommandLineFlag(FLAGS_initial_mtu); + client->set_initial_max_packet_length( + initial_mtu != 0 ? initial_mtu : quic::kDefaultMaxPacketSize); + client->set_drop_response_body( + quiche::GetQuicheCommandLineFlag(FLAGS_drop_response_body)); + const std::string server_connection_id_hex_string = + quiche::GetQuicheCommandLineFlag(FLAGS_server_connection_id); + QUICHE_CHECK(server_connection_id_hex_string.size() % 2 == 0) + << "The length of --server_connection_id must be even. It is " + << server_connection_id_hex_string.size() << "-byte long."; + if (!server_connection_id_hex_string.empty()) { + const std::string server_connection_id_bytes = + absl::HexStringToBytes(server_connection_id_hex_string); + client->set_server_connection_id_override(QuicConnectionId( + server_connection_id_bytes.data(), server_connection_id_bytes.size())); + } + const int32_t server_connection_id_length = + quiche::GetQuicheCommandLineFlag(FLAGS_server_connection_id_length); + if (server_connection_id_length >= 0) { + client->set_server_connection_id_length(server_connection_id_length); + } + const int32_t client_connection_id_length = + quiche::GetQuicheCommandLineFlag(FLAGS_client_connection_id_length); + if (client_connection_id_length >= 0) { + client->set_client_connection_id_length(client_connection_id_length); + } + const size_t max_inbound_header_list_size = + quiche::GetQuicheCommandLineFlag(FLAGS_max_inbound_header_list_size); + if (max_inbound_header_list_size > 0) { + client->set_max_inbound_header_list_size(max_inbound_header_list_size); + } + const std::string interface_name = + quiche::GetQuicheCommandLineFlag(FLAGS_interface_name); + if (!interface_name.empty()) { + client->set_interface_name(interface_name); + } + const std::string signing_algorithms_pref = + quiche::GetQuicheCommandLineFlag(FLAGS_signing_algorithms_pref); + if (!signing_algorithms_pref.empty()) { + client->SetTlsSignatureAlgorithms(signing_algorithms_pref); + } + if (!client->Initialize()) { + std::cerr << "Failed to initialize client." << std::endl; + return 1; + } + if (!client->Connect()) { + quic::QuicErrorCode error = client->session()->error(); + if (error == quic::QUIC_INVALID_VERSION) { + std::cerr << "Failed to negotiate version with " << host << ":" << port + << ". " << client->session()->error_details() << std::endl; + // 0: No error. + // 20: Failed to connect due to QUIC_INVALID_VERSION. + return quiche::GetQuicheCommandLineFlag(FLAGS_version_mismatch_ok) ? 0 + : 20; + } + std::cerr << "Failed to connect to " << host << ":" << port << ". " + << quic::QuicErrorCodeToString(error) << " " + << client->session()->error_details() << std::endl; + return 1; + } + std::cerr << "Connected to " << host << ":" << port << std::endl; + + // Construct the string body from flags, if provided. + std::string body = quiche::GetQuicheCommandLineFlag(FLAGS_body); + if (!quiche::GetQuicheCommandLineFlag(FLAGS_body_hex).empty()) { + QUICHE_DCHECK(quiche::GetQuicheCommandLineFlag(FLAGS_body).empty()) + << "Only set one of --body and --body_hex."; + body = absl::HexStringToBytes( + quiche::GetQuicheCommandLineFlag(FLAGS_body_hex)); + } + + // Construct a GET or POST request for supplied URL. + spdy::Http2HeaderBlock header_block; + header_block[":method"] = body.empty() ? "GET" : "POST"; + header_block[":scheme"] = url.scheme(); + header_block[":authority"] = url.HostPort(); + header_block[":path"] = url.PathParamsQuery(); + + // Append any additional headers supplied on the command line. + const std::string headers = quiche::GetQuicheCommandLineFlag(FLAGS_headers); + for (absl::string_view sp : absl::StrSplit(headers, ';')) { + QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&sp); + if (sp.empty()) { + continue; + } + std::vector kv = + absl::StrSplit(sp, absl::MaxSplits(':', 1)); + QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&kv[0]); + QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&kv[1]); + header_block[kv[0]] = kv[1]; + } + + // Make sure to store the response, for later output. + client->set_store_response(true); + + for (int i = 0; i < num_requests; ++i) { + // Send the request. + client->SendRequestAndWaitForResponse(header_block, body, /*fin=*/true); + + // Print request and response details. + if (!quiche::GetQuicheCommandLineFlag(FLAGS_quiet)) { + std::cout << "Request:" << std::endl; + std::cout << "headers:" << header_block.DebugString(); + if (!quiche::GetQuicheCommandLineFlag(FLAGS_body_hex).empty()) { + // Print the user provided hex, rather than binary body. + std::cout << "body:\n" + << QuicheTextUtils::HexDump(absl::HexStringToBytes( + quiche::GetQuicheCommandLineFlag(FLAGS_body_hex))) + << std::endl; + } else { + std::cout << "body: " << body << std::endl; + } + std::cout << std::endl; + + if (!client->preliminary_response_headers().empty()) { + std::cout << "Preliminary response headers: " + << client->preliminary_response_headers() << std::endl; + std::cout << std::endl; + } + + std::cout << "Response:" << std::endl; + std::cout << "headers: " << client->latest_response_headers() + << std::endl; + std::string response_body = client->latest_response_body(); + if (!quiche::GetQuicheCommandLineFlag(FLAGS_body_hex).empty()) { + // Assume response is binary data. + std::cout << "body:\n" + << QuicheTextUtils::HexDump(response_body) << std::endl; + } else { + std::cout << "body: " << response_body << std::endl; + } + std::cout << "trailers: " << client->latest_response_trailers() + << std::endl; + std::cout << "early data accepted: " << client->EarlyDataAccepted() + << std::endl; + } + + if (!client->connected()) { + std::cerr << "Request caused connection failure. Error: " + << quic::QuicErrorCodeToString(client->session()->error()) + << std::endl; + return 1; + } + + int response_code = client->latest_response_code(); + if (response_code >= 200 && response_code < 300) { + std::cout << "Request succeeded (" << response_code << ")." << std::endl; + } else if (response_code >= 300 && response_code < 400) { + if (quiche::GetQuicheCommandLineFlag(FLAGS_redirect_is_success)) { + std::cout << "Request succeeded (redirect " << response_code << ")." + << std::endl; + } else { + std::cout << "Request failed (redirect " << response_code << ")." + << std::endl; + return 1; + } + } else { + std::cout << "Request failed (" << response_code << ")." << std::endl; + return 1; + } + + if (i + 1 < num_requests) { // There are more requests to perform. + if (quiche::GetQuicheCommandLineFlag(FLAGS_one_connection_per_request)) { + std::cout << "Disconnecting client between requests." << std::endl; + client->Disconnect(); + if (!client->Initialize()) { + std::cerr << "Failed to reinitialize client between requests." + << std::endl; + return 1; + } + if (!client->Connect()) { + std::cerr << "Failed to reconnect client between requests." + << std::endl; + return 1; + } + } else if (!quiche::GetQuicheCommandLineFlag( + FLAGS_disable_port_changes)) { + // Change the ephemeral port. + if (!client->ChangeEphemeralPort()) { + std::cerr << "Failed to change ephemeral port." << std::endl; + return 1; + } + } + } + } + + return 0; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_toy_client.h b/quiche/quic/tools/quic_toy_client.h new file mode 100644 index 000000000000..50e42c4b80f7 --- /dev/null +++ b/quiche/quic/tools/quic_toy_client.h @@ -0,0 +1,35 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A toy client, which connects to a specified port and sends QUIC +// requests to that endpoint. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_TOY_CLIENT_H_ +#define QUICHE_QUIC_TOOLS_QUIC_TOY_CLIENT_H_ + +#include "quiche/quic/tools/quic_client_factory.h" + +namespace quic { + +class QuicToyClient { + public: + // Constructs a new toy client that will use |client_factory| to create the + // actual QuicSpdyClientBase instance. + QuicToyClient(ClientFactoryInterface* client_factory); + + // Connects to the QUIC server based on the various flags defined in the + // .cc file, sends requests and prints the responses. Returns 0 on success + // and non-zero otherwise. + int SendRequestsAndPrintResponses(std::vector urls); + + // Compatibility alias + using ClientFactory = ClientFactoryInterface; + + private: + ClientFactoryInterface* client_factory_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_TOY_CLIENT_H_ diff --git a/quiche/quic/tools/quic_toy_server.cc b/quiche/quic/tools/quic_toy_server.cc new file mode 100644 index 000000000000..3f23af6f7f0a --- /dev/null +++ b/quiche/quic/tools/quic_toy_server.cc @@ -0,0 +1,174 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_toy_server.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "quiche/quic/core/quic_server_id.h" +#include "quiche/quic/core/quic_versions.h" +#include "quiche/quic/platform/api/quic_default_proof_providers.h" +#include "quiche/quic/platform/api/quic_socket_address.h" +#include "quiche/quic/tools/connect_server_backend.h" +#include "quiche/quic/tools/quic_memory_cache_backend.h" +#include "quiche/common/platform/api/quiche_command_line_flags.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_random.h" + +DEFINE_QUICHE_COMMAND_LINE_FLAG(int32_t, port, 6121, + "The port the quic server will listen on."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, quic_response_cache_dir, "", + "Specifies the directory used during QuicHttpResponseCache " + "construction to seed the cache. Cache directory can be " + "generated using `wget -p --save-headers `"); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + bool, generate_dynamic_responses, false, + "If true, then URLs which have a numeric path will send a dynamically " + "generated response of that many bytes."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(bool, quic_ietf_draft, false, + "Only enable IETF draft versions. This also " + "enables required internal QUIC flags."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, quic_versions, "", + "QUIC versions to enable, e.g. \"h3-25,h3-27\". If not set, then all " + "available versions are enabled."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG(bool, enable_webtransport, false, + "If true, WebTransport support is enabled."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, connect_proxy_destinations, "", + "Specifies a comma-separated list of destinations (\"hostname:port\") to " + "which the QUIC server will allow tunneling via CONNECT."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, connect_udp_proxy_targets, "", + "Specifies a comma-separated list of target servers (\"hostname:port\") to " + "which the QUIC server will allow tunneling via CONNECT-UDP."); + +DEFINE_QUICHE_COMMAND_LINE_FLAG( + std::string, proxy_server_label, "", + "Specifies an identifier to identify the server in proxy error headers, " + "per the requirements of RFC 9209, Section 2. It should uniquely identify " + "the running service between separate running instances of the QUIC toy " + "server binary. If not specified, one will be randomly generated as " + "\"QuicToyServerN\" where N is a random uint64_t."); + +namespace quic { + +std::unique_ptr +QuicToyServer::MemoryCacheBackendFactory::CreateBackend() { + auto memory_cache_backend = std::make_unique(); + if (quiche::GetQuicheCommandLineFlag(FLAGS_generate_dynamic_responses)) { + memory_cache_backend->GenerateDynamicResponses(); + } + if (!quiche::GetQuicheCommandLineFlag(FLAGS_quic_response_cache_dir) + .empty()) { + memory_cache_backend->InitializeBackend( + quiche::GetQuicheCommandLineFlag(FLAGS_quic_response_cache_dir)); + } + if (quiche::GetQuicheCommandLineFlag(FLAGS_enable_webtransport)) { + memory_cache_backend->EnableWebTransport(); + } + + if (!quiche::GetQuicheCommandLineFlag(FLAGS_connect_proxy_destinations) + .empty() || + !quiche::GetQuicheCommandLineFlag(FLAGS_connect_udp_proxy_targets) + .empty()) { + absl::flat_hash_set connect_proxy_destinations; + for (absl::string_view destination : absl::StrSplit( + quiche::GetQuicheCommandLineFlag(FLAGS_connect_proxy_destinations), + ',', absl::SkipEmpty())) { + absl::optional destination_server_id = + QuicServerId::ParseFromHostPortString(destination); + QUICHE_CHECK(destination_server_id.has_value()); + connect_proxy_destinations.insert( + std::move(destination_server_id).value()); + } + + absl::flat_hash_set connect_udp_proxy_targets; + for (absl::string_view target : absl::StrSplit( + quiche::GetQuicheCommandLineFlag(FLAGS_connect_udp_proxy_targets), + ',', absl::SkipEmpty())) { + absl::optional target_server_id = + QuicServerId::ParseFromHostPortString(target); + QUICHE_CHECK(target_server_id.has_value()); + connect_udp_proxy_targets.insert(std::move(target_server_id).value()); + } + + QUICHE_CHECK(!connect_proxy_destinations.empty() || + !connect_udp_proxy_targets.empty()); + + std::string proxy_server_label = + quiche::GetQuicheCommandLineFlag(FLAGS_proxy_server_label); + if (proxy_server_label.empty()) { + proxy_server_label = absl::StrCat( + "QuicToyServer", + quiche::QuicheRandom::GetInstance()->InsecureRandUint64()); + } + + return std::make_unique( + std::move(memory_cache_backend), std::move(connect_proxy_destinations), + std::move(connect_udp_proxy_targets), std::move(proxy_server_label)); + } + + return memory_cache_backend; +} + +QuicToyServer::QuicToyServer(BackendFactory* backend_factory, + ServerFactory* server_factory) + : backend_factory_(backend_factory), server_factory_(server_factory) {} + +int QuicToyServer::Start() { + ParsedQuicVersionVector supported_versions; + if (quiche::GetQuicheCommandLineFlag(FLAGS_quic_ietf_draft)) { + QuicVersionInitializeSupportForIetfDraft(); + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + // Add all versions that supports IETF QUIC. + if (version.HasIetfQuicFrames() && + version.handshake_protocol == quic::PROTOCOL_TLS1_3) { + supported_versions.push_back(version); + } + } + } else { + supported_versions = AllSupportedVersions(); + } + std::string versions_string = + quiche::GetQuicheCommandLineFlag(FLAGS_quic_versions); + if (!versions_string.empty()) { + supported_versions = ParseQuicVersionVectorString(versions_string); + } + if (supported_versions.empty()) { + return 1; + } + for (const auto& version : supported_versions) { + QuicEnableVersion(version); + } + auto proof_source = quic::CreateDefaultProofSource(); + auto backend = backend_factory_->CreateBackend(); + auto server = server_factory_->CreateServer( + backend.get(), std::move(proof_source), supported_versions); + + if (!server->CreateUDPSocketAndListen(quic::QuicSocketAddress( + quic::QuicIpAddress::Any6(), + quiche::GetQuicheCommandLineFlag(FLAGS_port)))) { + return 1; + } + + server->HandleEventsForever(); + return 0; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_toy_server.h b/quiche/quic/tools/quic_toy_server.h new file mode 100644 index 000000000000..fc82ff706f25 --- /dev/null +++ b/quiche/quic/tools/quic_toy_server.h @@ -0,0 +1,63 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_TOY_SERVER_H_ +#define QUICHE_QUIC_TOOLS_QUIC_TOY_SERVER_H_ + +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/tools/quic_simple_server_backend.h" +#include "quiche/quic/tools/quic_spdy_server_base.h" + +namespace quic { + +// A binary wrapper for QuicServer. It listens forever on --port +// (default 6121) until it's killed or ctrl-cd to death. +class QuicToyServer { + public: + // A factory for creating QuicSpdyServerBase instances. + class ServerFactory { + public: + virtual ~ServerFactory() = default; + + // Creates a QuicSpdyServerBase instance using |backend| for generating + // responses, and |proof_source| for certificates. + virtual std::unique_ptr CreateServer( + QuicSimpleServerBackend* backend, + std::unique_ptr proof_source, + const ParsedQuicVersionVector& supported_versions) = 0; + }; + + // A facotry for creating QuicSimpleServerBackend instances. + class BackendFactory { + public: + virtual ~BackendFactory() = default; + + // Creates a new backend. + virtual std::unique_ptr CreateBackend() = 0; + }; + + // A factory for creating QuicMemoryCacheBackend instances, configured + // to load files from disk, if necessary. + class MemoryCacheBackendFactory : public BackendFactory { + public: + std::unique_ptr CreateBackend() override; + }; + + // Constructs a new toy server that will use |server_factory| to create the + // actual QuicSpdyServerBase instance. + QuicToyServer(BackendFactory* backend_factory, ServerFactory* server_factory); + + // Connects to the QUIC server based on the various flags defined in the + // .cc file, listends for requests and sends the responses. Returns 1 on + // failure and does not return otherwise. + int Start(); + + private: + BackendFactory* backend_factory_; // Unowned. + ServerFactory* server_factory_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_TOY_SERVER_H_ diff --git a/quiche/quic/tools/quic_url.cc b/quiche/quic/tools/quic_url.cc new file mode 100644 index 000000000000..db527209024f --- /dev/null +++ b/quiche/quic/tools/quic_url.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_url.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace quic { + +static constexpr size_t kMaxHostNameLength = 256; + +QuicUrl::QuicUrl(absl::string_view url) : url_(static_cast(url)) {} + +QuicUrl::QuicUrl(absl::string_view url, absl::string_view default_scheme) + : QuicUrl(url) { + if (url_.has_scheme()) { + return; + } + + url_ = GURL(absl::StrCat(default_scheme, "://", url)); +} + +std::string QuicUrl::ToString() const { + if (IsValid()) { + return url_.spec(); + } + return ""; +} + +bool QuicUrl::IsValid() const { + if (!url_.is_valid() || !url_.has_scheme()) { + return false; + } + + if (url_.has_host() && url_.host().length() > kMaxHostNameLength) { + return false; + } + + return true; +} + +std::string QuicUrl::HostPort() const { + if (!IsValid() || !url_.has_host()) { + return ""; + } + + std::string host = url_.host(); + int port = url_.IntPort(); + if (port == url::PORT_UNSPECIFIED) { + return host; + } + return absl::StrCat(host, ":", port); +} + +std::string QuicUrl::PathParamsQuery() const { + if (!IsValid() || !url_.has_path()) { + return "/"; + } + + return url_.PathForRequest(); +} + +std::string QuicUrl::scheme() const { + if (!IsValid()) { + return ""; + } + + return url_.scheme(); +} + +std::string QuicUrl::host() const { + if (!IsValid()) { + return ""; + } + + return url_.HostNoBrackets(); +} + +std::string QuicUrl::path() const { + if (!IsValid()) { + return ""; + } + + return url_.path(); +} + +uint16_t QuicUrl::port() const { + if (!IsValid()) { + return 0; + } + + int port = url_.EffectiveIntPort(); + if (port == url::PORT_UNSPECIFIED) { + return 0; + } + return port; +} + +} // namespace quic diff --git a/quiche/quic/tools/quic_url.h b/quiche/quic/tools/quic_url.h new file mode 100644 index 000000000000..78b21ec5ee6a --- /dev/null +++ b/quiche/quic/tools/quic_url.h @@ -0,0 +1,61 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_QUIC_URL_H_ +#define QUICHE_QUIC_TOOLS_QUIC_URL_H_ + +#include + +#include "absl/strings/string_view.h" +#include "url/gurl.h" +#include "quiche/quic/platform/api/quic_export.h" + +namespace quic { + +// A utility class that wraps GURL. +class QuicUrl { + public: + // Constructs an empty QuicUrl. + QuicUrl() = default; + + // Constructs a QuicUrl from the url string |url|. + // + // NOTE: If |url| doesn't have a scheme, it will have an empty scheme + // field. If that's not what you want, use the QuicUrlImpl(url, + // default_scheme) form below. + explicit QuicUrl(absl::string_view url); + + // Constructs a QuicUrlImpl from |url|, assuming that the scheme for the URL + // is |default_scheme| if there is no scheme specified in |url|. + QuicUrl(absl::string_view url, absl::string_view default_scheme); + + // Returns false if the URL is not valid. + bool IsValid() const; + + // Returns full text of the QuicUrl if it is valid. Return empty string + // otherwise. + std::string ToString() const; + + // Returns host:port. + // If the host is empty, it will return an empty string. + // If the host is an IPv6 address, it will be bracketed. + // If port is not present or is equal to default_port of scheme (e.g., port + // 80 for HTTP), it won't be returned. + std::string HostPort() const; + + // Returns a string assembles path, parameters and query. + std::string PathParamsQuery() const; + + std::string scheme() const; + std::string host() const; + std::string path() const; + uint16_t port() const; + + private: + GURL url_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_QUIC_URL_H_ diff --git a/quiche/quic/tools/quic_url_test.cc b/quiche/quic/tools/quic_url_test.cc new file mode 100644 index 000000000000..8f26e016df33 --- /dev/null +++ b/quiche/quic/tools/quic_url_test.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/quic_url.h" + +#include + +#include "quiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace test { +namespace { + +class QuicUrlTest : public QuicTest {}; + +TEST_F(QuicUrlTest, Basic) { + // No scheme specified. + std::string url_str = "www.example.com"; + QuicUrl url(url_str); + EXPECT_FALSE(url.IsValid()); + + // scheme is HTTP. + url_str = "http://www.example.com"; + url = QuicUrl(url_str); + EXPECT_TRUE(url.IsValid()); + EXPECT_EQ("http://www.example.com/", url.ToString()); + EXPECT_EQ("http", url.scheme()); + EXPECT_EQ("www.example.com", url.HostPort()); + EXPECT_EQ("/", url.PathParamsQuery()); + EXPECT_EQ(80u, url.port()); + + // scheme is HTTPS. + url_str = "https://www.example.com:12345/path/to/resource?a=1&campaign=2"; + url = QuicUrl(url_str); + EXPECT_TRUE(url.IsValid()); + EXPECT_EQ("https://www.example.com:12345/path/to/resource?a=1&campaign=2", + url.ToString()); + EXPECT_EQ("https", url.scheme()); + EXPECT_EQ("www.example.com:12345", url.HostPort()); + EXPECT_EQ("/path/to/resource?a=1&campaign=2", url.PathParamsQuery()); + EXPECT_EQ(12345u, url.port()); + + // scheme is FTP. + url_str = "ftp://www.example.com"; + url = QuicUrl(url_str); + EXPECT_TRUE(url.IsValid()); + EXPECT_EQ("ftp://www.example.com/", url.ToString()); + EXPECT_EQ("ftp", url.scheme()); + EXPECT_EQ("www.example.com", url.HostPort()); + EXPECT_EQ("/", url.PathParamsQuery()); + EXPECT_EQ(21u, url.port()); +} + +TEST_F(QuicUrlTest, DefaultScheme) { + // Default scheme to HTTP. + std::string url_str = "www.example.com"; + QuicUrl url(url_str, "http"); + EXPECT_EQ("http://www.example.com/", url.ToString()); + EXPECT_EQ("http", url.scheme()); + + // URL already has a scheme specified. + url_str = "http://www.example.com"; + url = QuicUrl(url_str, "https"); + EXPECT_EQ("http://www.example.com/", url.ToString()); + EXPECT_EQ("http", url.scheme()); + + // Default scheme to FTP. + url_str = "www.example.com"; + url = QuicUrl(url_str, "ftp"); + EXPECT_EQ("ftp://www.example.com/", url.ToString()); + EXPECT_EQ("ftp", url.scheme()); +} + +TEST_F(QuicUrlTest, IsValid) { + std::string url_str = + "ftp://www.example.com:12345/path/to/resource?a=1&campaign=2"; + EXPECT_TRUE(QuicUrl(url_str).IsValid()); + + // Invalid characters in host name. + url_str = "https://www%.example.com:12345/path/to/resource?a=1&campaign=2"; + EXPECT_FALSE(QuicUrl(url_str).IsValid()); + + // Invalid characters in scheme. + url_str = "%http://www.example.com:12345/path/to/resource?a=1&campaign=2"; + EXPECT_FALSE(QuicUrl(url_str).IsValid()); + + // Host name too long. + std::string host(1024, 'a'); + url_str = "https://" + host; + EXPECT_FALSE(QuicUrl(url_str).IsValid()); + + // Invalid port number. + url_str = "https://www..example.com:123456/path/to/resource?a=1&campaign=2"; + EXPECT_FALSE(QuicUrl(url_str).IsValid()); +} + +TEST_F(QuicUrlTest, HostPort) { + std::string url_str = "http://www.example.com/"; + QuicUrl url(url_str); + EXPECT_EQ("www.example.com", url.HostPort()); + EXPECT_EQ("www.example.com", url.host()); + EXPECT_EQ(80u, url.port()); + + url_str = "http://www.example.com:80/"; + url = QuicUrl(url_str); + EXPECT_EQ("www.example.com", url.HostPort()); + EXPECT_EQ("www.example.com", url.host()); + EXPECT_EQ(80u, url.port()); + + url_str = "http://www.example.com:81/"; + url = QuicUrl(url_str); + EXPECT_EQ("www.example.com:81", url.HostPort()); + EXPECT_EQ("www.example.com", url.host()); + EXPECT_EQ(81u, url.port()); + + url_str = "https://192.168.1.1:443/"; + url = QuicUrl(url_str); + EXPECT_EQ("192.168.1.1", url.HostPort()); + EXPECT_EQ("192.168.1.1", url.host()); + EXPECT_EQ(443u, url.port()); + + url_str = "http://[2001::1]:80/"; + url = QuicUrl(url_str); + EXPECT_EQ("[2001::1]", url.HostPort()); + EXPECT_EQ("2001::1", url.host()); + EXPECT_EQ(80u, url.port()); + + url_str = "http://[2001::1]:81/"; + url = QuicUrl(url_str); + EXPECT_EQ("[2001::1]:81", url.HostPort()); + EXPECT_EQ("2001::1", url.host()); + EXPECT_EQ(81u, url.port()); +} + +TEST_F(QuicUrlTest, PathParamsQuery) { + std::string url_str = + "https://www.example.com:12345/path/to/resource?a=1&campaign=2"; + QuicUrl url(url_str); + EXPECT_EQ("/path/to/resource?a=1&campaign=2", url.PathParamsQuery()); + EXPECT_EQ("/path/to/resource", url.path()); + + url_str = "https://www.example.com/?"; + url = QuicUrl(url_str); + EXPECT_EQ("/?", url.PathParamsQuery()); + EXPECT_EQ("/", url.path()); + + url_str = "https://www.example.com/"; + url = QuicUrl(url_str); + EXPECT_EQ("/", url.PathParamsQuery()); + EXPECT_EQ("/", url.path()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/simple_ticket_crypter.cc b/quiche/quic/tools/simple_ticket_crypter.cc new file mode 100644 index 000000000000..ad9fea1a0984 --- /dev/null +++ b/quiche/quic/tools/simple_ticket_crypter.cc @@ -0,0 +1,112 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/simple_ticket_crypter.h" + +#include "openssl/aead.h" +#include "openssl/rand.h" + +namespace quic { + +namespace { + +constexpr QuicTime::Delta kTicketKeyLifetime = + QuicTime::Delta::FromSeconds(60 * 60 * 24 * 7); + +// The format of an encrypted ticket is 1 byte for the key epoch, followed by +// 16 bytes of IV, followed by the output from the AES-GCM Seal operation. The +// seal operation has an overhead of 16 bytes for its auth tag. +constexpr size_t kEpochSize = 1; +constexpr size_t kIVSize = 16; +constexpr size_t kAuthTagSize = 16; + +// Offsets into the ciphertext to make message parsing easier. +constexpr size_t kIVOffset = kEpochSize; +constexpr size_t kMessageOffset = kIVOffset + kIVSize; + +} // namespace + +SimpleTicketCrypter::SimpleTicketCrypter(QuicClock* clock) : clock_(clock) { + RAND_bytes(&key_epoch_, 1); + current_key_ = NewKey(); +} + +SimpleTicketCrypter::~SimpleTicketCrypter() = default; + +size_t SimpleTicketCrypter::MaxOverhead() { + return kEpochSize + kIVSize + kAuthTagSize; +} + +std::vector SimpleTicketCrypter::Encrypt( + absl::string_view in, absl::string_view encryption_key) { + // This class is only used in Chromium, in which the |encryption_key| argument + // will never be populated and an internally-cached key should be used for + // encrypting tickets. + QUICHE_DCHECK(encryption_key.empty()); + MaybeRotateKeys(); + std::vector out(in.size() + MaxOverhead()); + out[0] = key_epoch_; + RAND_bytes(out.data() + kIVOffset, kIVSize); + size_t out_len; + const EVP_AEAD_CTX* ctx = current_key_->aead_ctx.get(); + if (!EVP_AEAD_CTX_seal(ctx, out.data() + kMessageOffset, &out_len, + out.size() - kMessageOffset, out.data() + kIVOffset, + kIVSize, reinterpret_cast(in.data()), + in.size(), nullptr, 0)) { + return std::vector(); + } + out.resize(out_len + kMessageOffset); + return out; +} + +std::vector SimpleTicketCrypter::Decrypt(absl::string_view in) { + MaybeRotateKeys(); + if (in.size() < kMessageOffset) { + return std::vector(); + } + const uint8_t* input = reinterpret_cast(in.data()); + std::vector out(in.size() - kMessageOffset); + size_t out_len; + const EVP_AEAD_CTX* ctx = current_key_->aead_ctx.get(); + if (input[0] != key_epoch_) { + if (input[0] == static_cast(key_epoch_ - 1) && previous_key_) { + ctx = previous_key_->aead_ctx.get(); + } else { + return std::vector(); + } + } + if (!EVP_AEAD_CTX_open(ctx, out.data(), &out_len, out.size(), + input + kIVOffset, kIVSize, input + kMessageOffset, + in.size() - kMessageOffset, nullptr, 0)) { + return std::vector(); + } + out.resize(out_len); + return out; +} + +void SimpleTicketCrypter::Decrypt( + absl::string_view in, + std::shared_ptr callback) { + callback->Run(Decrypt(in)); +} + +void SimpleTicketCrypter::MaybeRotateKeys() { + QuicTime now = clock_->ApproximateNow(); + if (current_key_->expiration < now) { + previous_key_ = std::move(current_key_); + current_key_ = NewKey(); + key_epoch_++; + } +} + +std::unique_ptr SimpleTicketCrypter::NewKey() { + auto key = std::make_unique(); + RAND_bytes(key->key, kKeySize); + EVP_AEAD_CTX_init(key->aead_ctx.get(), EVP_aead_aes_128_gcm(), key->key, + kKeySize, EVP_AEAD_DEFAULT_TAG_LENGTH, nullptr); + key->expiration = clock_->ApproximateNow() + kTicketKeyLifetime; + return key; +} + +} // namespace quic diff --git a/quiche/quic/tools/simple_ticket_crypter.h b/quiche/quic/tools/simple_ticket_crypter.h new file mode 100644 index 000000000000..b5ad16d6f969 --- /dev/null +++ b/quiche/quic/tools/simple_ticket_crypter.h @@ -0,0 +1,56 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_SIMPLE_TICKET_CRYPTER_H_ +#define QUICHE_QUIC_TOOLS_SIMPLE_TICKET_CRYPTER_H_ + +#include "openssl/aead.h" +#include "quiche/quic/core/crypto/proof_source.h" +#include "quiche/quic/core/quic_clock.h" +#include "quiche/quic/core/quic_time.h" + +namespace quic { + +// SimpleTicketCrypter implements the QUIC ProofSource::TicketCrypter interface. +// It generates a random key at startup and every 7 days it rotates the key, +// keeping track of the previous key used to facilitate decrypting older +// tickets. This implementation is not suitable for server setups where multiple +// servers need to share keys. +class QUIC_NO_EXPORT SimpleTicketCrypter + : public quic::ProofSource::TicketCrypter { + public: + explicit SimpleTicketCrypter(QuicClock* clock); + ~SimpleTicketCrypter() override; + + size_t MaxOverhead() override; + std::vector Encrypt(absl::string_view in, + absl::string_view encryption_key) override; + void Decrypt( + absl::string_view in, + std::shared_ptr callback) override; + + private: + std::vector Decrypt(absl::string_view in); + + void MaybeRotateKeys(); + + static constexpr size_t kKeySize = 16; + + struct Key { + uint8_t key[kKeySize]; + bssl::ScopedEVP_AEAD_CTX aead_ctx; + QuicTime expiration = QuicTime::Zero(); + }; + + std::unique_ptr NewKey(); + + std::unique_ptr current_key_; + std::unique_ptr previous_key_; + uint8_t key_epoch_ = 0; + QuicClock* clock_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_SIMPLE_TICKET_CRYPTER_H_ diff --git a/quiche/quic/tools/simple_ticket_crypter_test.cc b/quiche/quic/tools/simple_ticket_crypter_test.cc new file mode 100644 index 000000000000..0399047703ec --- /dev/null +++ b/quiche/quic/tools/simple_ticket_crypter_test.cc @@ -0,0 +1,111 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/quic/tools/simple_ticket_crypter.h" + +#include "quiche/quic/platform/api/quic_test.h" +#include "quiche/quic/test_tools/mock_clock.h" + +namespace quic { +namespace test { + +namespace { + +constexpr QuicTime::Delta kOneDay = QuicTime::Delta::FromSeconds(60 * 60 * 24); + +} // namespace + +class DecryptCallback : public quic::ProofSource::DecryptCallback { + public: + explicit DecryptCallback(std::vector* out) : out_(out) {} + + void Run(std::vector plaintext) override { *out_ = plaintext; } + + private: + std::vector* out_; +}; + +absl::string_view StringPiece(const std::vector& in) { + return absl::string_view(reinterpret_cast(in.data()), in.size()); +} + +class SimpleTicketCrypterTest : public QuicTest { + public: + SimpleTicketCrypterTest() : ticket_crypter_(&mock_clock_) {} + + protected: + MockClock mock_clock_; + SimpleTicketCrypter ticket_crypter_; +}; + +TEST_F(SimpleTicketCrypterTest, EncryptDecrypt) { + std::vector plaintext = {1, 2, 3, 4, 5}; + std::vector ciphertext = + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); + EXPECT_NE(plaintext, ciphertext); + + std::vector out_plaintext; + ticket_crypter_.Decrypt(StringPiece(ciphertext), + std::make_unique(&out_plaintext)); + EXPECT_EQ(out_plaintext, plaintext); +} + +TEST_F(SimpleTicketCrypterTest, CiphertextsDiffer) { + std::vector plaintext = {1, 2, 3, 4, 5}; + std::vector ciphertext1 = + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); + std::vector ciphertext2 = + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); + EXPECT_NE(ciphertext1, ciphertext2); +} + +TEST_F(SimpleTicketCrypterTest, DecryptionFailureWithModifiedCiphertext) { + std::vector plaintext = {1, 2, 3, 4, 5}; + std::vector ciphertext = + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); + EXPECT_NE(plaintext, ciphertext); + + // Check that a bit flip in any byte will cause a decryption failure. + for (size_t i = 0; i < ciphertext.size(); i++) { + SCOPED_TRACE(i); + std::vector munged_ciphertext = ciphertext; + munged_ciphertext[i] ^= 1; + std::vector out_plaintext; + ticket_crypter_.Decrypt(StringPiece(munged_ciphertext), + std::make_unique(&out_plaintext)); + EXPECT_TRUE(out_plaintext.empty()); + } +} + +TEST_F(SimpleTicketCrypterTest, DecryptionFailureWithEmptyCiphertext) { + std::vector out_plaintext; + ticket_crypter_.Decrypt(absl::string_view(), + std::make_unique(&out_plaintext)); + EXPECT_TRUE(out_plaintext.empty()); +} + +TEST_F(SimpleTicketCrypterTest, KeyRotation) { + std::vector plaintext = {1, 2, 3}; + std::vector ciphertext = + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); + EXPECT_FALSE(ciphertext.empty()); + + // Advance the clock 8 days, so the key used for |ciphertext| is now the + // previous key. Check that decryption still works. + mock_clock_.AdvanceTime(kOneDay * 8); + std::vector out_plaintext; + ticket_crypter_.Decrypt(StringPiece(ciphertext), + std::make_unique(&out_plaintext)); + EXPECT_EQ(out_plaintext, plaintext); + + // Advance the clock 8 more days. Now the original key should be expired and + // decryption should fail. + mock_clock_.AdvanceTime(kOneDay * 8); + ticket_crypter_.Decrypt(StringPiece(ciphertext), + std::make_unique(&out_plaintext)); + EXPECT_TRUE(out_plaintext.empty()); +} + +} // namespace test +} // namespace quic diff --git a/quiche/quic/tools/web_transport_test_visitors.h b/quiche/quic/tools/web_transport_test_visitors.h new file mode 100644 index 000000000000..0770af67700f --- /dev/null +++ b/quiche/quic/tools/web_transport_test_visitors.h @@ -0,0 +1,270 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TOOLS_WEB_TRANSPORT_TEST_VISITORS_H_ +#define QUICHE_QUIC_TOOLS_WEB_TRANSPORT_TEST_VISITORS_H_ + +#include + +#include "absl/status/status.h" +#include "quiche/quic/core/web_transport_interface.h" +#include "quiche/quic/platform/api/quic_logging.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_mem_slice.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/common/quiche_stream.h" +#include "quiche/common/simple_buffer_allocator.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace quic { + +// Discards any incoming data. +class WebTransportDiscardVisitor : public WebTransportStreamVisitor { + public: + WebTransportDiscardVisitor(WebTransportStream* stream) : stream_(stream) {} + + void OnCanRead() override { + std::string buffer; + WebTransportStream::ReadResult result = stream_->Read(&buffer); + QUIC_DVLOG(2) << "Read " << result.bytes_read + << " bytes from WebTransport stream " + << stream_->GetStreamId() << ", fin: " << result.fin; + } + + void OnCanWrite() override {} + + void OnResetStreamReceived(WebTransportStreamError /*error*/) override {} + void OnStopSendingReceived(WebTransportStreamError /*error*/) override {} + void OnWriteSideInDataRecvdState() override {} + + private: + WebTransportStream* stream_; +}; + +// Echoes any incoming data back on the same stream. +class WebTransportBidirectionalEchoVisitor : public WebTransportStreamVisitor { + public: + WebTransportBidirectionalEchoVisitor(WebTransportStream* stream) + : stream_(stream) {} + + void OnCanRead() override { + WebTransportStream::ReadResult result = stream_->Read(&buffer_); + QUIC_DVLOG(1) << "Attempted reading on WebTransport bidirectional stream " + << stream_->GetStreamId() + << ", bytes read: " << result.bytes_read; + if (result.fin) { + send_fin_ = true; + } + OnCanWrite(); + } + + void OnCanWrite() override { + if (stop_sending_received_) { + return; + } + + if (!buffer_.empty()) { + absl::Status status = quiche::WriteIntoStream(*stream_, buffer_); + QUIC_DVLOG(1) << "Attempted writing on WebTransport bidirectional stream " + << stream_->GetStreamId() << ", success: " << status; + if (!status.ok()) { + return; + } + + buffer_ = ""; + } + + if (send_fin_ && !fin_sent_) { + absl::Status status = quiche::SendFinOnStream(*stream_); + if (status.ok()) { + fin_sent_ = true; + } + } + } + + void OnResetStreamReceived(WebTransportStreamError /*error*/) override { + // Send FIN in response to a stream reset. We want to test that we can + // operate one side of the stream cleanly while the other is reset, thus + // replying with a FIN rather than a RESET_STREAM is more appropriate here. + send_fin_ = true; + OnCanWrite(); + } + void OnStopSendingReceived(WebTransportStreamError /*error*/) override { + stop_sending_received_ = true; + } + void OnWriteSideInDataRecvdState() override {} + + protected: + WebTransportStream* stream() { return stream_; } + + private: + WebTransportStream* stream_; + std::string buffer_; + bool send_fin_ = false; + bool fin_sent_ = false; + bool stop_sending_received_ = false; +}; + +// Buffers all of the data and calls |callback| with the entirety of the stream +// data. +class WebTransportUnidirectionalEchoReadVisitor + : public WebTransportStreamVisitor { + public: + using Callback = std::function; + + WebTransportUnidirectionalEchoReadVisitor(WebTransportStream* stream, + Callback callback) + : stream_(stream), callback_(std::move(callback)) {} + + void OnCanRead() override { + WebTransportStream::ReadResult result = stream_->Read(&buffer_); + QUIC_DVLOG(1) << "Attempted reading on WebTransport unidirectional stream " + << stream_->GetStreamId() + << ", bytes read: " << result.bytes_read; + if (result.fin) { + QUIC_DVLOG(1) << "Finished receiving data on a WebTransport stream " + << stream_->GetStreamId() << ", queueing up the echo"; + callback_(buffer_); + } + } + + void OnCanWrite() override { QUICHE_NOTREACHED(); } + + void OnResetStreamReceived(WebTransportStreamError /*error*/) override {} + void OnStopSendingReceived(WebTransportStreamError /*error*/) override {} + void OnWriteSideInDataRecvdState() override {} + + private: + WebTransportStream* stream_; + std::string buffer_; + Callback callback_; +}; + +// Sends supplied data. +class WebTransportUnidirectionalEchoWriteVisitor + : public WebTransportStreamVisitor { + public: + WebTransportUnidirectionalEchoWriteVisitor(WebTransportStream* stream, + const std::string& data) + : stream_(stream), data_(data) {} + + void OnCanRead() override { QUICHE_NOTREACHED(); } + void OnCanWrite() override { + if (data_.empty()) { + return; + } + absl::Status write_status = quiche::WriteIntoStream(*stream_, data_); + if (!write_status.ok()) { + QUICHE_DLOG_IF(WARNING, !absl::IsUnavailable(write_status)) + << "Failed to write into stream: " << write_status; + return; + } + data_ = ""; + absl::Status fin_status = quiche::SendFinOnStream(*stream_); + QUICHE_DVLOG(1) + << "WebTransportUnidirectionalEchoWriteVisitor finished sending data."; + QUICHE_DCHECK(fin_status.ok()); + } + + void OnResetStreamReceived(WebTransportStreamError /*error*/) override {} + void OnStopSendingReceived(WebTransportStreamError /*error*/) override {} + void OnWriteSideInDataRecvdState() override {} + + private: + WebTransportStream* stream_; + std::string data_; +}; + +// A session visitor which sets unidirectional or bidirectional stream visitors +// to echo. +class EchoWebTransportSessionVisitor : public WebTransportVisitor { + public: + EchoWebTransportSessionVisitor(WebTransportSession* session) + : session_(session) {} + + void OnSessionReady(const spdy::Http2HeaderBlock&) override { + if (session_->CanOpenNextOutgoingBidirectionalStream()) { + OnCanCreateNewOutgoingBidirectionalStream(); + } + } + + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + + void OnIncomingBidirectionalStreamAvailable() override { + while (true) { + WebTransportStream* stream = + session_->AcceptIncomingBidirectionalStream(); + if (stream == nullptr) { + return; + } + QUIC_DVLOG(1) + << "EchoWebTransportSessionVisitor received a bidirectional stream " + << stream->GetStreamId(); + stream->SetVisitor( + std::make_unique(stream)); + stream->visitor()->OnCanRead(); + } + } + + void OnIncomingUnidirectionalStreamAvailable() override { + while (true) { + WebTransportStream* stream = + session_->AcceptIncomingUnidirectionalStream(); + if (stream == nullptr) { + return; + } + QUIC_DVLOG(1) + << "EchoWebTransportSessionVisitor received a unidirectional stream"; + stream->SetVisitor( + std::make_unique( + stream, [this](const std::string& data) { + streams_to_echo_back_.push_back(data); + TrySendingUnidirectionalStreams(); + })); + stream->visitor()->OnCanRead(); + } + } + + void OnDatagramReceived(absl::string_view datagram) override { + session_->SendOrQueueDatagram(datagram); + } + + void OnCanCreateNewOutgoingBidirectionalStream() override { + if (!echo_stream_opened_) { + WebTransportStream* stream = session_->OpenOutgoingBidirectionalStream(); + stream->SetVisitor( + std::make_unique(stream)); + echo_stream_opened_ = true; + } + } + void OnCanCreateNewOutgoingUnidirectionalStream() override { + TrySendingUnidirectionalStreams(); + } + + void TrySendingUnidirectionalStreams() { + while (!streams_to_echo_back_.empty() && + session_->CanOpenNextOutgoingUnidirectionalStream()) { + QUIC_DVLOG(1) + << "EchoWebTransportServer echoed a unidirectional stream back"; + WebTransportStream* stream = session_->OpenOutgoingUnidirectionalStream(); + stream->SetVisitor( + std::make_unique( + stream, streams_to_echo_back_.front())); + streams_to_echo_back_.pop_front(); + stream->visitor()->OnCanWrite(); + } + } + + private: + WebTransportSession* session_; + quiche::SimpleBufferAllocator allocator_; + bool echo_stream_opened_ = false; + + quiche::QuicheCircularDeque streams_to_echo_back_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_TOOLS_WEB_TRANSPORT_TEST_VISITORS_H_ diff --git a/quiche/spdy/core/array_output_buffer.cc b/quiche/spdy/core/array_output_buffer.cc new file mode 100644 index 000000000000..8ceba2ca0da0 --- /dev/null +++ b/quiche/spdy/core/array_output_buffer.cc @@ -0,0 +1,21 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/array_output_buffer.h" + +namespace spdy { + +void ArrayOutputBuffer::Next(char** data, int* size) { + *data = current_; + *size = capacity_ > 0 ? capacity_ : 0; +} + +void ArrayOutputBuffer::AdvanceWritePtr(int64_t count) { + current_ += count; + capacity_ -= count; +} + +uint64_t ArrayOutputBuffer::BytesFree() const { return capacity_; } + +} // namespace spdy diff --git a/quiche/spdy/core/array_output_buffer.h b/quiche/spdy/core/array_output_buffer.h new file mode 100644 index 000000000000..edce72c6ac51 --- /dev/null +++ b/quiche/spdy/core/array_output_buffer.h @@ -0,0 +1,47 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_ARRAY_OUTPUT_BUFFER_H_ +#define QUICHE_SPDY_CORE_ARRAY_OUTPUT_BUFFER_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/zero_copy_output_buffer.h" + +namespace spdy { + +class QUICHE_EXPORT ArrayOutputBuffer : public ZeroCopyOutputBuffer { + public: + // |buffer| is pointed to the output to write to, and |size| is the capacity + // of the output. + ArrayOutputBuffer(char* buffer, int64_t size) + : current_(buffer), begin_(buffer), capacity_(size) {} + ~ArrayOutputBuffer() override {} + + ArrayOutputBuffer(const ArrayOutputBuffer&) = delete; + ArrayOutputBuffer& operator=(const ArrayOutputBuffer&) = delete; + + void Next(char** data, int* size) override; + void AdvanceWritePtr(int64_t count) override; + uint64_t BytesFree() const override; + + size_t Size() const { return current_ - begin_; } + char* Begin() const { return begin_; } + + // Resets the buffer to its original state. + void Reset() { + capacity_ += Size(); + current_ = begin_; + } + + private: + char* current_ = nullptr; + char* begin_ = nullptr; + uint64_t capacity_ = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_ARRAY_OUTPUT_BUFFER_H_ diff --git a/quiche/spdy/core/array_output_buffer_test.cc b/quiche/spdy/core/array_output_buffer_test.cc new file mode 100644 index 000000000000..0054a7228191 --- /dev/null +++ b/quiche/spdy/core/array_output_buffer_test.cc @@ -0,0 +1,49 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/array_output_buffer.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { +namespace test { + +// This test verifies that ArrayOutputBuffer is initialized properly. +TEST(ArrayOutputBufferTest, InitializedFromArray) { + char array[100]; + ArrayOutputBuffer buffer(array, sizeof(array)); + EXPECT_EQ(sizeof(array), buffer.BytesFree()); + EXPECT_EQ(0u, buffer.Size()); + EXPECT_EQ(array, buffer.Begin()); +} + +// This test verifies that Reset() causes an ArrayOutputBuffer's capacity and +// size to be reset to the initial state. +TEST(ArrayOutputBufferTest, WriteAndReset) { + char array[100]; + ArrayOutputBuffer buffer(array, sizeof(array)); + + // Let's write some bytes. + char* dst; + int size; + buffer.Next(&dst, &size); + ASSERT_GT(size, 1); + ASSERT_NE(nullptr, dst); + const int64_t written = size / 2; + memset(dst, 'x', written); + buffer.AdvanceWritePtr(written); + + // The buffer should be partially used. + EXPECT_EQ(static_cast(size) - written, buffer.BytesFree()); + EXPECT_EQ(static_cast(written), buffer.Size()); + + buffer.Reset(); + + // After a reset, the buffer should regain its full capacity. + EXPECT_EQ(sizeof(array), buffer.BytesFree()); + EXPECT_EQ(0u, buffer.Size()); +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/header_byte_listener_interface.h b/quiche/spdy/core/header_byte_listener_interface.h new file mode 100644 index 000000000000..308110ff167d --- /dev/null +++ b/quiche/spdy/core/header_byte_listener_interface.h @@ -0,0 +1,22 @@ +#ifndef QUICHE_SPDY_CORE_HEADER_BYTE_LISTENER_INTERFACE_H_ +#define QUICHE_SPDY_CORE_HEADER_BYTE_LISTENER_INTERFACE_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +// Listens for the receipt of uncompressed header bytes. +class QUICHE_EXPORT HeaderByteListenerInterface { + public: + virtual ~HeaderByteListenerInterface() {} + + // Called when a header block has been parsed, with the number of uncompressed + // header bytes parsed from the header block. + virtual void OnHeaderBytesReceived(size_t uncompressed_header_bytes) = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HEADER_BYTE_LISTENER_INTERFACE_H_ diff --git a/quiche/spdy/core/hpack/hpack_constants.cc b/quiche/spdy/core/hpack/hpack_constants.cc new file mode 100644 index 000000000000..817216e648c4 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_constants.cc @@ -0,0 +1,374 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_constants.h" + +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/hpack/hpack_static_table.h" + +namespace spdy { + +// Produced by applying the python program [1] with tables provided by [2] +// (inserted into the source of the python program) and copy-paste them into +// this file. +// +// [1] net/tools/build_hpack_constants.py in Chromium +// [2] http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-08 + +// HpackHuffmanSymbol entries are initialized as {code, length, id}. +// Codes are specified in the |length| most-significant bits of |code|. +const std::vector& HpackHuffmanCodeVector() { + static const auto* kHpackHuffmanCode = new std::vector{ + {0xffc00000ul, 13, 0}, // 11111111|11000 + {0xffffb000ul, 23, 1}, // 11111111|11111111|1011000 + {0xfffffe20ul, 28, 2}, // 11111111|11111111|11111110|0010 + {0xfffffe30ul, 28, 3}, // 11111111|11111111|11111110|0011 + {0xfffffe40ul, 28, 4}, // 11111111|11111111|11111110|0100 + {0xfffffe50ul, 28, 5}, // 11111111|11111111|11111110|0101 + {0xfffffe60ul, 28, 6}, // 11111111|11111111|11111110|0110 + {0xfffffe70ul, 28, 7}, // 11111111|11111111|11111110|0111 + {0xfffffe80ul, 28, 8}, // 11111111|11111111|11111110|1000 + {0xffffea00ul, 24, 9}, // 11111111|11111111|11101010 + {0xfffffff0ul, 30, 10}, // 11111111|11111111|11111111|111100 + {0xfffffe90ul, 28, 11}, // 11111111|11111111|11111110|1001 + {0xfffffea0ul, 28, 12}, // 11111111|11111111|11111110|1010 + {0xfffffff4ul, 30, 13}, // 11111111|11111111|11111111|111101 + {0xfffffeb0ul, 28, 14}, // 11111111|11111111|11111110|1011 + {0xfffffec0ul, 28, 15}, // 11111111|11111111|11111110|1100 + {0xfffffed0ul, 28, 16}, // 11111111|11111111|11111110|1101 + {0xfffffee0ul, 28, 17}, // 11111111|11111111|11111110|1110 + {0xfffffef0ul, 28, 18}, // 11111111|11111111|11111110|1111 + {0xffffff00ul, 28, 19}, // 11111111|11111111|11111111|0000 + {0xffffff10ul, 28, 20}, // 11111111|11111111|11111111|0001 + {0xffffff20ul, 28, 21}, // 11111111|11111111|11111111|0010 + {0xfffffff8ul, 30, 22}, // 11111111|11111111|11111111|111110 + {0xffffff30ul, 28, 23}, // 11111111|11111111|11111111|0011 + {0xffffff40ul, 28, 24}, // 11111111|11111111|11111111|0100 + {0xffffff50ul, 28, 25}, // 11111111|11111111|11111111|0101 + {0xffffff60ul, 28, 26}, // 11111111|11111111|11111111|0110 + {0xffffff70ul, 28, 27}, // 11111111|11111111|11111111|0111 + {0xffffff80ul, 28, 28}, // 11111111|11111111|11111111|1000 + {0xffffff90ul, 28, 29}, // 11111111|11111111|11111111|1001 + {0xffffffa0ul, 28, 30}, // 11111111|11111111|11111111|1010 + {0xffffffb0ul, 28, 31}, // 11111111|11111111|11111111|1011 + {0x50000000ul, 6, 32}, // ' ' 010100 + {0xfe000000ul, 10, 33}, // '!' 11111110|00 + {0xfe400000ul, 10, 34}, // '"' 11111110|01 + {0xffa00000ul, 12, 35}, // '#' 11111111|1010 + {0xffc80000ul, 13, 36}, // '$' 11111111|11001 + {0x54000000ul, 6, 37}, // '%' 010101 + {0xf8000000ul, 8, 38}, // '&' 11111000 + {0xff400000ul, 11, 39}, // ''' 11111111|010 + {0xfe800000ul, 10, 40}, // '(' 11111110|10 + {0xfec00000ul, 10, 41}, // ')' 11111110|11 + {0xf9000000ul, 8, 42}, // '*' 11111001 + {0xff600000ul, 11, 43}, // '+' 11111111|011 + {0xfa000000ul, 8, 44}, // ',' 11111010 + {0x58000000ul, 6, 45}, // '-' 010110 + {0x5c000000ul, 6, 46}, // '.' 010111 + {0x60000000ul, 6, 47}, // '/' 011000 + {0x00000000ul, 5, 48}, // '0' 00000 + {0x08000000ul, 5, 49}, // '1' 00001 + {0x10000000ul, 5, 50}, // '2' 00010 + {0x64000000ul, 6, 51}, // '3' 011001 + {0x68000000ul, 6, 52}, // '4' 011010 + {0x6c000000ul, 6, 53}, // '5' 011011 + {0x70000000ul, 6, 54}, // '6' 011100 + {0x74000000ul, 6, 55}, // '7' 011101 + {0x78000000ul, 6, 56}, // '8' 011110 + {0x7c000000ul, 6, 57}, // '9' 011111 + {0xb8000000ul, 7, 58}, // ':' 1011100 + {0xfb000000ul, 8, 59}, // ';' 11111011 + {0xfff80000ul, 15, 60}, // '<' 11111111|1111100 + {0x80000000ul, 6, 61}, // '=' 100000 + {0xffb00000ul, 12, 62}, // '>' 11111111|1011 + {0xff000000ul, 10, 63}, // '?' 11111111|00 + {0xffd00000ul, 13, 64}, // '@' 11111111|11010 + {0x84000000ul, 6, 65}, // 'A' 100001 + {0xba000000ul, 7, 66}, // 'B' 1011101 + {0xbc000000ul, 7, 67}, // 'C' 1011110 + {0xbe000000ul, 7, 68}, // 'D' 1011111 + {0xc0000000ul, 7, 69}, // 'E' 1100000 + {0xc2000000ul, 7, 70}, // 'F' 1100001 + {0xc4000000ul, 7, 71}, // 'G' 1100010 + {0xc6000000ul, 7, 72}, // 'H' 1100011 + {0xc8000000ul, 7, 73}, // 'I' 1100100 + {0xca000000ul, 7, 74}, // 'J' 1100101 + {0xcc000000ul, 7, 75}, // 'K' 1100110 + {0xce000000ul, 7, 76}, // 'L' 1100111 + {0xd0000000ul, 7, 77}, // 'M' 1101000 + {0xd2000000ul, 7, 78}, // 'N' 1101001 + {0xd4000000ul, 7, 79}, // 'O' 1101010 + {0xd6000000ul, 7, 80}, // 'P' 1101011 + {0xd8000000ul, 7, 81}, // 'Q' 1101100 + {0xda000000ul, 7, 82}, // 'R' 1101101 + {0xdc000000ul, 7, 83}, // 'S' 1101110 + {0xde000000ul, 7, 84}, // 'T' 1101111 + {0xe0000000ul, 7, 85}, // 'U' 1110000 + {0xe2000000ul, 7, 86}, // 'V' 1110001 + {0xe4000000ul, 7, 87}, // 'W' 1110010 + {0xfc000000ul, 8, 88}, // 'X' 11111100 + {0xe6000000ul, 7, 89}, // 'Y' 1110011 + {0xfd000000ul, 8, 90}, // 'Z' 11111101 + {0xffd80000ul, 13, 91}, // '[' 11111111|11011 + {0xfffe0000ul, 19, 92}, // '\' 11111111|11111110|000 + {0xffe00000ul, 13, 93}, // ']' 11111111|11100 + {0xfff00000ul, 14, 94}, // '^' 11111111|111100 + {0x88000000ul, 6, 95}, // '_' 100010 + {0xfffa0000ul, 15, 96}, // '`' 11111111|1111101 + {0x18000000ul, 5, 97}, // 'a' 00011 + {0x8c000000ul, 6, 98}, // 'b' 100011 + {0x20000000ul, 5, 99}, // 'c' 00100 + {0x90000000ul, 6, 100}, // 'd' 100100 + {0x28000000ul, 5, 101}, // 'e' 00101 + {0x94000000ul, 6, 102}, // 'f' 100101 + {0x98000000ul, 6, 103}, // 'g' 100110 + {0x9c000000ul, 6, 104}, // 'h' 100111 + {0x30000000ul, 5, 105}, // 'i' 00110 + {0xe8000000ul, 7, 106}, // 'j' 1110100 + {0xea000000ul, 7, 107}, // 'k' 1110101 + {0xa0000000ul, 6, 108}, // 'l' 101000 + {0xa4000000ul, 6, 109}, // 'm' 101001 + {0xa8000000ul, 6, 110}, // 'n' 101010 + {0x38000000ul, 5, 111}, // 'o' 00111 + {0xac000000ul, 6, 112}, // 'p' 101011 + {0xec000000ul, 7, 113}, // 'q' 1110110 + {0xb0000000ul, 6, 114}, // 'r' 101100 + {0x40000000ul, 5, 115}, // 's' 01000 + {0x48000000ul, 5, 116}, // 't' 01001 + {0xb4000000ul, 6, 117}, // 'u' 101101 + {0xee000000ul, 7, 118}, // 'v' 1110111 + {0xf0000000ul, 7, 119}, // 'w' 1111000 + {0xf2000000ul, 7, 120}, // 'x' 1111001 + {0xf4000000ul, 7, 121}, // 'y' 1111010 + {0xf6000000ul, 7, 122}, // 'z' 1111011 + {0xfffc0000ul, 15, 123}, // '{' 11111111|1111110 + {0xff800000ul, 11, 124}, // '|' 11111111|100 + {0xfff40000ul, 14, 125}, // '}' 11111111|111101 + {0xffe80000ul, 13, 126}, // '~' 11111111|11101 + {0xffffffc0ul, 28, 127}, // 11111111|11111111|11111111|1100 + {0xfffe6000ul, 20, 128}, // 11111111|11111110|0110 + {0xffff4800ul, 22, 129}, // 11111111|11111111|010010 + {0xfffe7000ul, 20, 130}, // 11111111|11111110|0111 + {0xfffe8000ul, 20, 131}, // 11111111|11111110|1000 + {0xffff4c00ul, 22, 132}, // 11111111|11111111|010011 + {0xffff5000ul, 22, 133}, // 11111111|11111111|010100 + {0xffff5400ul, 22, 134}, // 11111111|11111111|010101 + {0xffffb200ul, 23, 135}, // 11111111|11111111|1011001 + {0xffff5800ul, 22, 136}, // 11111111|11111111|010110 + {0xffffb400ul, 23, 137}, // 11111111|11111111|1011010 + {0xffffb600ul, 23, 138}, // 11111111|11111111|1011011 + {0xffffb800ul, 23, 139}, // 11111111|11111111|1011100 + {0xffffba00ul, 23, 140}, // 11111111|11111111|1011101 + {0xffffbc00ul, 23, 141}, // 11111111|11111111|1011110 + {0xffffeb00ul, 24, 142}, // 11111111|11111111|11101011 + {0xffffbe00ul, 23, 143}, // 11111111|11111111|1011111 + {0xffffec00ul, 24, 144}, // 11111111|11111111|11101100 + {0xffffed00ul, 24, 145}, // 11111111|11111111|11101101 + {0xffff5c00ul, 22, 146}, // 11111111|11111111|010111 + {0xffffc000ul, 23, 147}, // 11111111|11111111|1100000 + {0xffffee00ul, 24, 148}, // 11111111|11111111|11101110 + {0xffffc200ul, 23, 149}, // 11111111|11111111|1100001 + {0xffffc400ul, 23, 150}, // 11111111|11111111|1100010 + {0xffffc600ul, 23, 151}, // 11111111|11111111|1100011 + {0xffffc800ul, 23, 152}, // 11111111|11111111|1100100 + {0xfffee000ul, 21, 153}, // 11111111|11111110|11100 + {0xffff6000ul, 22, 154}, // 11111111|11111111|011000 + {0xffffca00ul, 23, 155}, // 11111111|11111111|1100101 + {0xffff6400ul, 22, 156}, // 11111111|11111111|011001 + {0xffffcc00ul, 23, 157}, // 11111111|11111111|1100110 + {0xffffce00ul, 23, 158}, // 11111111|11111111|1100111 + {0xffffef00ul, 24, 159}, // 11111111|11111111|11101111 + {0xffff6800ul, 22, 160}, // 11111111|11111111|011010 + {0xfffee800ul, 21, 161}, // 11111111|11111110|11101 + {0xfffe9000ul, 20, 162}, // 11111111|11111110|1001 + {0xffff6c00ul, 22, 163}, // 11111111|11111111|011011 + {0xffff7000ul, 22, 164}, // 11111111|11111111|011100 + {0xffffd000ul, 23, 165}, // 11111111|11111111|1101000 + {0xffffd200ul, 23, 166}, // 11111111|11111111|1101001 + {0xfffef000ul, 21, 167}, // 11111111|11111110|11110 + {0xffffd400ul, 23, 168}, // 11111111|11111111|1101010 + {0xffff7400ul, 22, 169}, // 11111111|11111111|011101 + {0xffff7800ul, 22, 170}, // 11111111|11111111|011110 + {0xfffff000ul, 24, 171}, // 11111111|11111111|11110000 + {0xfffef800ul, 21, 172}, // 11111111|11111110|11111 + {0xffff7c00ul, 22, 173}, // 11111111|11111111|011111 + {0xffffd600ul, 23, 174}, // 11111111|11111111|1101011 + {0xffffd800ul, 23, 175}, // 11111111|11111111|1101100 + {0xffff0000ul, 21, 176}, // 11111111|11111111|00000 + {0xffff0800ul, 21, 177}, // 11111111|11111111|00001 + {0xffff8000ul, 22, 178}, // 11111111|11111111|100000 + {0xffff1000ul, 21, 179}, // 11111111|11111111|00010 + {0xffffda00ul, 23, 180}, // 11111111|11111111|1101101 + {0xffff8400ul, 22, 181}, // 11111111|11111111|100001 + {0xffffdc00ul, 23, 182}, // 11111111|11111111|1101110 + {0xffffde00ul, 23, 183}, // 11111111|11111111|1101111 + {0xfffea000ul, 20, 184}, // 11111111|11111110|1010 + {0xffff8800ul, 22, 185}, // 11111111|11111111|100010 + {0xffff8c00ul, 22, 186}, // 11111111|11111111|100011 + {0xffff9000ul, 22, 187}, // 11111111|11111111|100100 + {0xffffe000ul, 23, 188}, // 11111111|11111111|1110000 + {0xffff9400ul, 22, 189}, // 11111111|11111111|100101 + {0xffff9800ul, 22, 190}, // 11111111|11111111|100110 + {0xffffe200ul, 23, 191}, // 11111111|11111111|1110001 + {0xfffff800ul, 26, 192}, // 11111111|11111111|11111000|00 + {0xfffff840ul, 26, 193}, // 11111111|11111111|11111000|01 + {0xfffeb000ul, 20, 194}, // 11111111|11111110|1011 + {0xfffe2000ul, 19, 195}, // 11111111|11111110|001 + {0xffff9c00ul, 22, 196}, // 11111111|11111111|100111 + {0xffffe400ul, 23, 197}, // 11111111|11111111|1110010 + {0xffffa000ul, 22, 198}, // 11111111|11111111|101000 + {0xfffff600ul, 25, 199}, // 11111111|11111111|11110110|0 + {0xfffff880ul, 26, 200}, // 11111111|11111111|11111000|10 + {0xfffff8c0ul, 26, 201}, // 11111111|11111111|11111000|11 + {0xfffff900ul, 26, 202}, // 11111111|11111111|11111001|00 + {0xfffffbc0ul, 27, 203}, // 11111111|11111111|11111011|110 + {0xfffffbe0ul, 27, 204}, // 11111111|11111111|11111011|111 + {0xfffff940ul, 26, 205}, // 11111111|11111111|11111001|01 + {0xfffff100ul, 24, 206}, // 11111111|11111111|11110001 + {0xfffff680ul, 25, 207}, // 11111111|11111111|11110110|1 + {0xfffe4000ul, 19, 208}, // 11111111|11111110|010 + {0xffff1800ul, 21, 209}, // 11111111|11111111|00011 + {0xfffff980ul, 26, 210}, // 11111111|11111111|11111001|10 + {0xfffffc00ul, 27, 211}, // 11111111|11111111|11111100|000 + {0xfffffc20ul, 27, 212}, // 11111111|11111111|11111100|001 + {0xfffff9c0ul, 26, 213}, // 11111111|11111111|11111001|11 + {0xfffffc40ul, 27, 214}, // 11111111|11111111|11111100|010 + {0xfffff200ul, 24, 215}, // 11111111|11111111|11110010 + {0xffff2000ul, 21, 216}, // 11111111|11111111|00100 + {0xffff2800ul, 21, 217}, // 11111111|11111111|00101 + {0xfffffa00ul, 26, 218}, // 11111111|11111111|11111010|00 + {0xfffffa40ul, 26, 219}, // 11111111|11111111|11111010|01 + {0xffffffd0ul, 28, 220}, // 11111111|11111111|11111111|1101 + {0xfffffc60ul, 27, 221}, // 11111111|11111111|11111100|011 + {0xfffffc80ul, 27, 222}, // 11111111|11111111|11111100|100 + {0xfffffca0ul, 27, 223}, // 11111111|11111111|11111100|101 + {0xfffec000ul, 20, 224}, // 11111111|11111110|1100 + {0xfffff300ul, 24, 225}, // 11111111|11111111|11110011 + {0xfffed000ul, 20, 226}, // 11111111|11111110|1101 + {0xffff3000ul, 21, 227}, // 11111111|11111111|00110 + {0xffffa400ul, 22, 228}, // 11111111|11111111|101001 + {0xffff3800ul, 21, 229}, // 11111111|11111111|00111 + {0xffff4000ul, 21, 230}, // 11111111|11111111|01000 + {0xffffe600ul, 23, 231}, // 11111111|11111111|1110011 + {0xffffa800ul, 22, 232}, // 11111111|11111111|101010 + {0xffffac00ul, 22, 233}, // 11111111|11111111|101011 + {0xfffff700ul, 25, 234}, // 11111111|11111111|11110111|0 + {0xfffff780ul, 25, 235}, // 11111111|11111111|11110111|1 + {0xfffff400ul, 24, 236}, // 11111111|11111111|11110100 + {0xfffff500ul, 24, 237}, // 11111111|11111111|11110101 + {0xfffffa80ul, 26, 238}, // 11111111|11111111|11111010|10 + {0xffffe800ul, 23, 239}, // 11111111|11111111|1110100 + {0xfffffac0ul, 26, 240}, // 11111111|11111111|11111010|11 + {0xfffffcc0ul, 27, 241}, // 11111111|11111111|11111100|110 + {0xfffffb00ul, 26, 242}, // 11111111|11111111|11111011|00 + {0xfffffb40ul, 26, 243}, // 11111111|11111111|11111011|01 + {0xfffffce0ul, 27, 244}, // 11111111|11111111|11111100|111 + {0xfffffd00ul, 27, 245}, // 11111111|11111111|11111101|000 + {0xfffffd20ul, 27, 246}, // 11111111|11111111|11111101|001 + {0xfffffd40ul, 27, 247}, // 11111111|11111111|11111101|010 + {0xfffffd60ul, 27, 248}, // 11111111|11111111|11111101|011 + {0xffffffe0ul, 28, 249}, // 11111111|11111111|11111111|1110 + {0xfffffd80ul, 27, 250}, // 11111111|11111111|11111101|100 + {0xfffffda0ul, 27, 251}, // 11111111|11111111|11111101|101 + {0xfffffdc0ul, 27, 252}, // 11111111|11111111|11111101|110 + {0xfffffde0ul, 27, 253}, // 11111111|11111111|11111101|111 + {0xfffffe00ul, 27, 254}, // 11111111|11111111|11111110|000 + {0xfffffb80ul, 26, 255}, // 11111111|11111111|11111011|10 + {0xfffffffcul, 30, 256}, // EOS 11111111|11111111|11111111|111111 + }; + return *kHpackHuffmanCode; +} + +// The "constructor" for a HpackStaticEntry that computes the lengths at +// compile time. +#define STATIC_ENTRY(name, value) \ + { name, ABSL_ARRAYSIZE(name) - 1, value, ABSL_ARRAYSIZE(value) - 1 } + +const std::vector& HpackStaticTableVector() { + static const auto* kHpackStaticTable = new std::vector{ + STATIC_ENTRY(":authority", ""), // 1 + STATIC_ENTRY(":method", "GET"), // 2 + STATIC_ENTRY(":method", "POST"), // 3 + STATIC_ENTRY(":path", "/"), // 4 + STATIC_ENTRY(":path", "/index.html"), // 5 + STATIC_ENTRY(":scheme", "http"), // 6 + STATIC_ENTRY(":scheme", "https"), // 7 + STATIC_ENTRY(":status", "200"), // 8 + STATIC_ENTRY(":status", "204"), // 9 + STATIC_ENTRY(":status", "206"), // 10 + STATIC_ENTRY(":status", "304"), // 11 + STATIC_ENTRY(":status", "400"), // 12 + STATIC_ENTRY(":status", "404"), // 13 + STATIC_ENTRY(":status", "500"), // 14 + STATIC_ENTRY("accept-charset", ""), // 15 + STATIC_ENTRY("accept-encoding", "gzip, deflate"), // 16 + STATIC_ENTRY("accept-language", ""), // 17 + STATIC_ENTRY("accept-ranges", ""), // 18 + STATIC_ENTRY("accept", ""), // 19 + STATIC_ENTRY("access-control-allow-origin", ""), // 20 + STATIC_ENTRY("age", ""), // 21 + STATIC_ENTRY("allow", ""), // 22 + STATIC_ENTRY("authorization", ""), // 23 + STATIC_ENTRY("cache-control", ""), // 24 + STATIC_ENTRY("content-disposition", ""), // 25 + STATIC_ENTRY("content-encoding", ""), // 26 + STATIC_ENTRY("content-language", ""), // 27 + STATIC_ENTRY("content-length", ""), // 28 + STATIC_ENTRY("content-location", ""), // 29 + STATIC_ENTRY("content-range", ""), // 30 + STATIC_ENTRY("content-type", ""), // 31 + STATIC_ENTRY("cookie", ""), // 32 + STATIC_ENTRY("date", ""), // 33 + STATIC_ENTRY("etag", ""), // 34 + STATIC_ENTRY("expect", ""), // 35 + STATIC_ENTRY("expires", ""), // 36 + STATIC_ENTRY("from", ""), // 37 + STATIC_ENTRY("host", ""), // 38 + STATIC_ENTRY("if-match", ""), // 39 + STATIC_ENTRY("if-modified-since", ""), // 40 + STATIC_ENTRY("if-none-match", ""), // 41 + STATIC_ENTRY("if-range", ""), // 42 + STATIC_ENTRY("if-unmodified-since", ""), // 43 + STATIC_ENTRY("last-modified", ""), // 44 + STATIC_ENTRY("link", ""), // 45 + STATIC_ENTRY("location", ""), // 46 + STATIC_ENTRY("max-forwards", ""), // 47 + STATIC_ENTRY("proxy-authenticate", ""), // 48 + STATIC_ENTRY("proxy-authorization", ""), // 49 + STATIC_ENTRY("range", ""), // 50 + STATIC_ENTRY("referer", ""), // 51 + STATIC_ENTRY("refresh", ""), // 52 + STATIC_ENTRY("retry-after", ""), // 53 + STATIC_ENTRY("server", ""), // 54 + STATIC_ENTRY("set-cookie", ""), // 55 + STATIC_ENTRY("strict-transport-security", ""), // 56 + STATIC_ENTRY("transfer-encoding", ""), // 57 + STATIC_ENTRY("user-agent", ""), // 58 + STATIC_ENTRY("vary", ""), // 59 + STATIC_ENTRY("via", ""), // 60 + STATIC_ENTRY("www-authenticate", ""), // 61 + }; + return *kHpackStaticTable; +} + +#undef STATIC_ENTRY + +const HpackStaticTable& ObtainHpackStaticTable() { + static const HpackStaticTable* const shared_static_table = []() { + auto* table = new HpackStaticTable(); + table->Initialize(HpackStaticTableVector().data(), + HpackStaticTableVector().size()); + QUICHE_CHECK(table->IsInitialized()); + return table; + }(); + return *shared_static_table; +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_constants.h b/quiche/spdy/core/hpack/hpack_constants.h new file mode 100644 index 000000000000..1220a38e3c8b --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_constants.h @@ -0,0 +1,88 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_CONSTANTS_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_CONSTANTS_H_ + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +// All section references below are to +// https://httpwg.org/specs/rfc7540.html and +// https://httpwg.org/specs/rfc7541.html. + +namespace spdy { + +// An HpackPrefix signifies |bits| stored in the top |bit_size| bits +// of an octet. +struct QUICHE_EXPORT HpackPrefix { + uint8_t bits; + size_t bit_size; +}; + +// Represents a symbol and its Huffman code (stored in most-significant bits). +struct QUICHE_EXPORT HpackHuffmanSymbol { + uint32_t code; + uint8_t length; + uint16_t id; +}; + +// An entry in the static table. Must be a POD in order to avoid static +// initializers, i.e. no user-defined constructors or destructors. +struct QUICHE_EXPORT HpackStaticEntry { + const char* const name; + const size_t name_len; + const char* const value; + const size_t value_len; +}; + +class HpackStaticTable; + +// RFC 7540, 6.5.2: Initial value for SETTINGS_HEADER_TABLE_SIZE. +const uint32_t kDefaultHeaderTableSizeSetting = 4096; + +// RFC 7541, 5.2: Flag for a string literal that is stored unmodified (i.e., +// without Huffman encoding). +const HpackPrefix kStringLiteralIdentityEncoded = {0x0, 1}; + +// RFC 7541, 5.2: Flag for a Huffman-coded string literal. +const HpackPrefix kStringLiteralHuffmanEncoded = {0x1, 1}; + +// RFC 7541, 6.1: Opcode for an indexed header field. +const HpackPrefix kIndexedOpcode = {0b1, 1}; + +// RFC 7541, 6.2.1: Opcode for a literal header field with incremental indexing. +const HpackPrefix kLiteralIncrementalIndexOpcode = {0b01, 2}; + +// RFC 7541, 6.2.2: Opcode for a literal header field without indexing. +const HpackPrefix kLiteralNoIndexOpcode = {0b0000, 4}; + +// RFC 7541, 6.2.3: Opcode for a literal header field which is never indexed. +// Currently unused. +// const HpackPrefix kLiteralNeverIndexOpcode = {0b0001, 4}; + +// RFC 7541, 6.3: Opcode for maximum header table size update. Begins a +// varint-encoded table size with a 5-bit prefix. +const HpackPrefix kHeaderTableSizeUpdateOpcode = {0b001, 3}; + +// RFC 7541, Appendix B: Huffman Code. +QUICHE_EXPORT const std::vector& HpackHuffmanCodeVector(); + +// RFC 7541, Appendix A: Static Table Definition. +QUICHE_EXPORT const std::vector& HpackStaticTableVector(); + +// Returns a HpackStaticTable instance initialized with |kHpackStaticTable|. +// The instance is read-only, has static lifetime, and is safe to share amoung +// threads. This function is thread-safe. +QUICHE_EXPORT const HpackStaticTable& ObtainHpackStaticTable(); + +// RFC 7541, 8.1.2.1: Pseudo-headers start with a colon. +const char kPseudoHeaderPrefix = ':'; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_CONSTANTS_H_ diff --git a/quiche/spdy/core/hpack/hpack_decoder_adapter.cc b/quiche/spdy/core/hpack/hpack_decoder_adapter.cc new file mode 100644 index 000000000000..122fe4581914 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_decoder_adapter.cc @@ -0,0 +1,164 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_decoder_adapter.h" + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { +namespace { +const size_t kMaxDecodeBufferSizeBytes = 32 * 1024; // 32 KB +} // namespace + +HpackDecoderAdapter::HpackDecoderAdapter() + : hpack_decoder_(&listener_adapter_, kMaxDecodeBufferSizeBytes), + max_decode_buffer_size_bytes_(kMaxDecodeBufferSizeBytes), + max_header_block_bytes_(0), + header_block_started_(false), + error_(http2::HpackDecodingError::kOk) {} + +HpackDecoderAdapter::~HpackDecoderAdapter() = default; + +void HpackDecoderAdapter::ApplyHeaderTableSizeSetting(size_t size_setting) { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::ApplyHeaderTableSizeSetting"; + hpack_decoder_.ApplyHeaderTableSizeSetting(size_setting); +} + +size_t HpackDecoderAdapter::GetCurrentHeaderTableSizeSetting() const { + return hpack_decoder_.GetCurrentHeaderTableSizeSetting(); +} + +void HpackDecoderAdapter::HandleControlFrameHeadersStart( + SpdyHeadersHandlerInterface* handler) { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::HandleControlFrameHeadersStart"; + QUICHE_DCHECK(!header_block_started_); + listener_adapter_.set_handler(handler); +} + +bool HpackDecoderAdapter::HandleControlFrameHeadersData( + const char* headers_data, size_t headers_data_length) { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::HandleControlFrameHeadersData: len=" + << headers_data_length; + if (!header_block_started_) { + // Initialize the decoding process here rather than in + // HandleControlFrameHeadersStart because that method is not always called. + header_block_started_ = true; + if (!hpack_decoder_.StartDecodingBlock()) { + header_block_started_ = false; + error_ = hpack_decoder_.error(); + detailed_error_ = hpack_decoder_.detailed_error(); + return false; + } + } + + // Sometimes we get a call with headers_data==nullptr and + // headers_data_length==0, in which case we need to avoid creating + // a DecodeBuffer, which would otherwise complain. + if (headers_data_length > 0) { + QUICHE_DCHECK_NE(headers_data, nullptr); + if (headers_data_length > max_decode_buffer_size_bytes_) { + QUICHE_DVLOG(1) << "max_decode_buffer_size_bytes_ < headers_data_length: " + << max_decode_buffer_size_bytes_ << " < " + << headers_data_length; + error_ = http2::HpackDecodingError::kFragmentTooLong; + detailed_error_ = ""; + return false; + } + listener_adapter_.AddToTotalHpackBytes(headers_data_length); + if (max_header_block_bytes_ != 0 && + listener_adapter_.total_hpack_bytes() > max_header_block_bytes_) { + error_ = http2::HpackDecodingError::kCompressedHeaderSizeExceedsLimit; + detailed_error_ = ""; + return false; + } + http2::DecodeBuffer db(headers_data, headers_data_length); + bool ok = hpack_decoder_.DecodeFragment(&db); + QUICHE_DCHECK(!ok || db.Empty()) << "Remaining=" << db.Remaining(); + if (!ok) { + error_ = hpack_decoder_.error(); + detailed_error_ = hpack_decoder_.detailed_error(); + } + return ok; + } + return true; +} + +bool HpackDecoderAdapter::HandleControlFrameHeadersComplete() { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::HandleControlFrameHeadersComplete"; + if (!hpack_decoder_.EndDecodingBlock()) { + QUICHE_DVLOG(3) << "EndDecodingBlock returned false"; + error_ = hpack_decoder_.error(); + detailed_error_ = hpack_decoder_.detailed_error(); + return false; + } + header_block_started_ = false; + return true; +} + +const Http2HeaderBlock& HpackDecoderAdapter::decoded_block() const { + return listener_adapter_.decoded_block(); +} + +void HpackDecoderAdapter::set_max_decode_buffer_size_bytes( + size_t max_decode_buffer_size_bytes) { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::set_max_decode_buffer_size_bytes"; + max_decode_buffer_size_bytes_ = max_decode_buffer_size_bytes; + hpack_decoder_.set_max_string_size_bytes(max_decode_buffer_size_bytes); +} + +void HpackDecoderAdapter::set_max_header_block_bytes( + size_t max_header_block_bytes) { + max_header_block_bytes_ = max_header_block_bytes; +} + +HpackDecoderAdapter::ListenerAdapter::ListenerAdapter() : handler_(nullptr) {} +HpackDecoderAdapter::ListenerAdapter::~ListenerAdapter() = default; + +void HpackDecoderAdapter::ListenerAdapter::set_handler( + SpdyHeadersHandlerInterface* handler) { + handler_ = handler; +} + +void HpackDecoderAdapter::ListenerAdapter::OnHeaderListStart() { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::ListenerAdapter::OnHeaderListStart"; + total_hpack_bytes_ = 0; + total_uncompressed_bytes_ = 0; + decoded_block_.clear(); + if (handler_ != nullptr) { + handler_->OnHeaderBlockStart(); + } +} + +void HpackDecoderAdapter::ListenerAdapter::OnHeader(const std::string& name, + const std::string& value) { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::ListenerAdapter::OnHeader:\n name: " + << name << "\n value: " << value; + total_uncompressed_bytes_ += name.size() + value.size(); + if (handler_ == nullptr) { + QUICHE_DVLOG(3) << "Adding to decoded_block"; + decoded_block_.AppendValueOrAddHeader(name, value); + } else { + QUICHE_DVLOG(3) << "Passing to handler"; + handler_->OnHeader(name, value); + } +} + +void HpackDecoderAdapter::ListenerAdapter::OnHeaderListEnd() { + QUICHE_DVLOG(2) << "HpackDecoderAdapter::ListenerAdapter::OnHeaderListEnd"; + // We don't clear the Http2HeaderBlock here to allow access to it until the + // next HPACK block is decoded. + if (handler_ != nullptr) { + handler_->OnHeaderBlockEnd(total_uncompressed_bytes_, total_hpack_bytes_); + handler_ = nullptr; + } +} + +void HpackDecoderAdapter::ListenerAdapter::OnHeaderErrorDetected( + absl::string_view error_message) { + QUICHE_VLOG(1) << error_message; +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_decoder_adapter.h b/quiche/spdy/core/hpack/hpack_decoder_adapter.h new file mode 100644 index 000000000000..d5685a5f5869 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_decoder_adapter.h @@ -0,0 +1,156 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_DECODER_ADAPTER_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_DECODER_ADAPTER_H_ + +// HpackDecoderAdapter uses http2::HpackDecoder to decode HPACK blocks into +// HTTP/2 header lists as outlined in http://tools.ietf.org/html/rfc7541. + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_decoder.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_listener.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" +#include "quiche/http2/hpack/http2_hpack_constants.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" + +namespace spdy { +namespace test { +class HpackDecoderAdapterPeer; +} // namespace test + +class QUICHE_EXPORT HpackDecoderAdapter { + public: + friend test::HpackDecoderAdapterPeer; + HpackDecoderAdapter(); + HpackDecoderAdapter(const HpackDecoderAdapter&) = delete; + HpackDecoderAdapter& operator=(const HpackDecoderAdapter&) = delete; + ~HpackDecoderAdapter(); + + // Called upon acknowledgement of SETTINGS_HEADER_TABLE_SIZE. + void ApplyHeaderTableSizeSetting(size_t size_setting); + + // Returns the most recently applied value of SETTINGS_HEADER_TABLE_SIZE. + size_t GetCurrentHeaderTableSizeSetting() const; + + // If a SpdyHeadersHandlerInterface is provided, the decoder will emit + // headers to it rather than accumulating them in a Http2HeaderBlock. + // Does not take ownership of the handler, but does use the pointer until + // the current HPACK block is completely decoded. + void HandleControlFrameHeadersStart(SpdyHeadersHandlerInterface* handler); + + // Called as HPACK block fragments arrive. Returns false if an error occurred + // while decoding the block. Does not take ownership of headers_data. + bool HandleControlFrameHeadersData(const char* headers_data, + size_t headers_data_length); + + // Called after a HPACK block has been completely delivered via + // HandleControlFrameHeadersData(). Returns false if an error occurred. + // |compressed_len| if non-null will be set to the size of the encoded + // buffered block that was accumulated in HandleControlFrameHeadersData(), + // to support subsequent calculation of compression percentage. + // Discards the handler supplied at the start of decoding the block. + bool HandleControlFrameHeadersComplete(); + + // Accessor for the most recently decoded headers block. Valid until the next + // call to HandleControlFrameHeadersData(). + // TODO(birenroy): Remove this method when all users of HpackDecoder specify + // a SpdyHeadersHandlerInterface. + const Http2HeaderBlock& decoded_block() const; + + // Returns the current dynamic table size, including the 32 bytes per entry + // overhead mentioned in RFC 7541 section 4.1. + size_t GetDynamicTableSize() const { + return hpack_decoder_.GetDynamicTableSize(); + } + + // Set how much encoded data this decoder is willing to buffer. + // TODO(jamessynge): Resolve definition of this value, as it is currently + // too tied to a single implementation. We probably want to limit one or more + // of these: individual name or value strings, header entries, the entire + // header list, or the HPACK block; we probably shouldn't care about the size + // of individual transport buffers. + void set_max_decode_buffer_size_bytes(size_t max_decode_buffer_size_bytes); + + // Specifies the maximum size of an on-the-wire header block that will be + // accepted. + void set_max_header_block_bytes(size_t max_header_block_bytes); + + // Error code if an error has occurred, Error::kOk otherwise. + http2::HpackDecodingError error() const { return error_; } + + std::string detailed_error() const { return detailed_error_; } + + private: + class QUICHE_EXPORT ListenerAdapter : public http2::HpackDecoderListener { + public: + ListenerAdapter(); + ~ListenerAdapter() override; + + // If a SpdyHeadersHandlerInterface is provided, the decoder will emit + // headers to it rather than accumulating them in a Http2HeaderBlock. + // Does not take ownership of the handler, but does use the pointer until + // the current HPACK block is completely decoded. + void set_handler(SpdyHeadersHandlerInterface* handler); + const Http2HeaderBlock& decoded_block() const { return decoded_block_; } + + // Override the HpackDecoderListener methods: + void OnHeaderListStart() override; + void OnHeader(const std::string& name, const std::string& value) override; + void OnHeaderListEnd() override; + void OnHeaderErrorDetected(absl::string_view error_message) override; + + void AddToTotalHpackBytes(size_t delta) { total_hpack_bytes_ += delta; } + size_t total_hpack_bytes() const { return total_hpack_bytes_; } + + private: + // If the caller doesn't provide a handler, the header list is stored in + // this Http2HeaderBlock. + Http2HeaderBlock decoded_block_; + + // If non-NULL, handles decoded headers. Not owned. + SpdyHeadersHandlerInterface* handler_; + + // Total bytes that have been received as input (i.e. HPACK encoded) + // in the current HPACK block. + size_t total_hpack_bytes_; + + // Total bytes of the name and value strings in the current HPACK block. + size_t total_uncompressed_bytes_; + }; + + // Converts calls to HpackDecoderListener into calls to + // SpdyHeadersHandlerInterface. + ListenerAdapter listener_adapter_; + + // The actual decoder. + http2::HpackDecoder hpack_decoder_; + + // How much encoded data this decoder is willing to buffer. + size_t max_decode_buffer_size_bytes_; + + // How much encoded data this decoder is willing to process. + size_t max_header_block_bytes_; + + // Flag to keep track of having seen the header block start. Needed at the + // moment because HandleControlFrameHeadersStart won't be called if a handler + // is not being provided by the caller. + bool header_block_started_; + + // Error code if an error has occurred, Error::kOk otherwise. + http2::HpackDecodingError error_; + std::string detailed_error_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_DECODER_ADAPTER_H_ diff --git a/quiche/spdy/core/hpack/hpack_decoder_adapter_test.cc b/quiche/spdy/core/hpack/hpack_decoder_adapter_test.cc new file mode 100644 index 000000000000..2143807695a9 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_decoder_adapter_test.cc @@ -0,0 +1,1119 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_decoder_adapter.h" + +// Tests of HpackDecoderAdapter. + +#include + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/escaping.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_state.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_tables.h" +#include "quiche/http2/test_tools/hpack_block_builder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_encoder.h" +#include "quiche/spdy/core/hpack/hpack_output_stream.h" +#include "quiche/spdy/core/recording_headers_handler.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +using ::http2::HpackEntryType; +using ::http2::HpackStringPair; +using ::http2::test::HpackBlockBuilder; +using ::http2::test::HpackDecoderPeer; +using ::testing::ElementsAre; +using ::testing::Pair; + +namespace http2 { +namespace test { + +class HpackDecoderStatePeer { + public: + static HpackDecoderTables* GetDecoderTables(HpackDecoderState* state) { + return &state->decoder_tables_; + } +}; + +class HpackDecoderPeer { + public: + static HpackDecoderState* GetDecoderState(HpackDecoder* decoder) { + return &decoder->decoder_state_; + } + static HpackDecoderTables* GetDecoderTables(HpackDecoder* decoder) { + return HpackDecoderStatePeer::GetDecoderTables(GetDecoderState(decoder)); + } +}; + +} // namespace test +} // namespace http2 + +namespace spdy { +namespace test { + +class HpackDecoderAdapterPeer { + public: + explicit HpackDecoderAdapterPeer(HpackDecoderAdapter* decoder) + : decoder_(decoder) {} + + void HandleHeaderRepresentation(const std::string& name, + const std::string& value) { + decoder_->listener_adapter_.OnHeader(name, value); + } + + http2::HpackDecoderTables* GetDecoderTables() { + return HpackDecoderPeer::GetDecoderTables(&decoder_->hpack_decoder_); + } + + const HpackStringPair* GetTableEntry(uint32_t index) { + return GetDecoderTables()->Lookup(index); + } + + size_t current_header_table_size() { + return GetDecoderTables()->current_header_table_size(); + } + + size_t header_table_size_limit() { + return GetDecoderTables()->header_table_size_limit(); + } + + void set_header_table_size_limit(size_t size) { + return GetDecoderTables()->DynamicTableSizeUpdate(size); + } + + private: + HpackDecoderAdapter* decoder_; +}; + +class HpackEncoderPeer { + public: + static void CookieToCrumbs(const HpackEncoder::Representation& cookie, + HpackEncoder::Representations* crumbs_out) { + HpackEncoder::CookieToCrumbs(cookie, crumbs_out); + } +}; + +namespace { + +const bool kNoCheckDecodedSize = false; +const char* kCookieKey = "cookie"; + +// Is HandleControlFrameHeadersStart to be called, and with what value? +enum StartChoice { START_WITH_HANDLER, START_WITHOUT_HANDLER, NO_START }; + +class HpackDecoderAdapterTest + : public quiche::test::QuicheTestWithParam> { + protected: + HpackDecoderAdapterTest() : decoder_(), decoder_peer_(&decoder_) {} + + void SetUp() override { + std::tie(start_choice_, randomly_split_input_buffer_) = GetParam(); + } + + void HandleControlFrameHeadersStart() { + bytes_passed_in_ = 0; + switch (start_choice_) { + case START_WITH_HANDLER: + decoder_.HandleControlFrameHeadersStart(&handler_); + break; + case START_WITHOUT_HANDLER: + decoder_.HandleControlFrameHeadersStart(nullptr); + break; + case NO_START: + break; + } + } + + bool HandleControlFrameHeadersData(absl::string_view str) { + QUICHE_VLOG(3) << "HandleControlFrameHeadersData:\n" + << quiche::QuicheTextUtils::HexDump(str); + bytes_passed_in_ += str.size(); + return decoder_.HandleControlFrameHeadersData(str.data(), str.size()); + } + + bool HandleControlFrameHeadersComplete() { + bool rc = decoder_.HandleControlFrameHeadersComplete(); + return rc; + } + + bool DecodeHeaderBlock(absl::string_view str, + bool check_decoded_size = true) { + // Don't call this again if HandleControlFrameHeadersData failed previously. + EXPECT_FALSE(decode_has_failed_); + HandleControlFrameHeadersStart(); + if (randomly_split_input_buffer_) { + do { + // Decode some fragment of the remaining bytes. + size_t bytes = str.size(); + if (!str.empty()) { + bytes = random_.Uniform(str.size()) + 1; + } + EXPECT_LE(bytes, str.size()); + if (!HandleControlFrameHeadersData(str.substr(0, bytes))) { + decode_has_failed_ = true; + return false; + } + str.remove_prefix(bytes); + } while (!str.empty()); + } else if (!HandleControlFrameHeadersData(str)) { + decode_has_failed_ = true; + return false; + } + if (start_choice_ == START_WITH_HANDLER) { + if (!HandleControlFrameHeadersComplete()) { + decode_has_failed_ = true; + return false; + } + EXPECT_EQ(handler_.compressed_header_bytes(), bytes_passed_in_); + } else { + if (!HandleControlFrameHeadersComplete()) { + decode_has_failed_ = true; + return false; + } + } + if (check_decoded_size && start_choice_ == START_WITH_HANDLER) { + EXPECT_EQ(handler_.uncompressed_header_bytes(), + SizeOfHeaders(decoded_block())); + } + return true; + } + + bool EncodeAndDecodeDynamicTableSizeUpdates(size_t first, size_t second) { + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(first); + if (second != first) { + hbb.AppendDynamicTableSizeUpdate(second); + } + return DecodeHeaderBlock(hbb.buffer()); + } + + const Http2HeaderBlock& decoded_block() const { + if (start_choice_ == START_WITH_HANDLER) { + return handler_.decoded_block(); + } else { + return decoder_.decoded_block(); + } + } + + static size_t SizeOfHeaders(const Http2HeaderBlock& headers) { + size_t size = 0; + for (const auto& kv : headers) { + if (kv.first == kCookieKey) { + HpackEncoder::Representations crumbs; + HpackEncoderPeer::CookieToCrumbs(kv, &crumbs); + for (const auto& crumb : crumbs) { + size += crumb.first.size() + crumb.second.size(); + } + } else { + size += kv.first.size() + kv.second.size(); + } + } + return size; + } + + const Http2HeaderBlock& DecodeBlockExpectingSuccess(absl::string_view str) { + EXPECT_TRUE(DecodeHeaderBlock(str)); + return decoded_block(); + } + + void expectEntry(size_t index, size_t size, const std::string& name, + const std::string& value) { + const HpackStringPair* entry = decoder_peer_.GetTableEntry(index); + EXPECT_EQ(name, entry->name) << "index " << index; + EXPECT_EQ(value, entry->value); + EXPECT_EQ(size, entry->size()); + } + + Http2HeaderBlock MakeHeaderBlock( + const std::vector>& headers) { + Http2HeaderBlock result; + for (const auto& kv : headers) { + result.AppendValueOrAddHeader(kv.first, kv.second); + } + return result; + } + + http2::test::Http2Random random_; + HpackDecoderAdapter decoder_; + test::HpackDecoderAdapterPeer decoder_peer_; + RecordingHeadersHandler handler_; + StartChoice start_choice_; + bool randomly_split_input_buffer_; + bool decode_has_failed_ = false; + size_t bytes_passed_in_; +}; + +INSTANTIATE_TEST_SUITE_P( + NoHandler, HpackDecoderAdapterTest, + ::testing::Combine(::testing::Values(START_WITHOUT_HANDLER, NO_START), + ::testing::Bool())); + +INSTANTIATE_TEST_SUITE_P( + WithHandler, HpackDecoderAdapterTest, + ::testing::Combine(::testing::Values(START_WITH_HANDLER), + ::testing::Bool())); + +TEST_P(HpackDecoderAdapterTest, ApplyHeaderTableSizeSetting) { + EXPECT_EQ(4096u, decoder_.GetCurrentHeaderTableSizeSetting()); + decoder_.ApplyHeaderTableSizeSetting(12 * 1024); + EXPECT_EQ(12288u, decoder_.GetCurrentHeaderTableSizeSetting()); +} + +TEST_P(HpackDecoderAdapterTest, + AddHeaderDataWithHandleControlFrameHeadersData) { + // The hpack decode buffer size is limited in size. This test verifies that + // adding encoded data under that limit is accepted, and data that exceeds the + // limit is rejected. + HandleControlFrameHeadersStart(); + const size_t kMaxBufferSizeBytes = 50; + const std::string a_value = std::string(49, 'x'); + decoder_.set_max_decode_buffer_size_bytes(kMaxBufferSizeBytes); + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(HpackEntryType::kNeverIndexedLiteralHeader, + false, "a", false, a_value); + const std::string& s = hbb.buffer(); + EXPECT_GT(s.size(), kMaxBufferSizeBytes); + + // Any one in input buffer must not exceed kMaxBufferSizeBytes. + EXPECT_TRUE(HandleControlFrameHeadersData(s.substr(0, s.size() / 2))); + EXPECT_TRUE(HandleControlFrameHeadersData(s.substr(s.size() / 2))); + + EXPECT_FALSE(HandleControlFrameHeadersData(s)); + Http2HeaderBlock expected_block = MakeHeaderBlock({{"a", a_value}}); + EXPECT_EQ(expected_block, decoded_block()); +} + +TEST_P(HpackDecoderAdapterTest, NameTooLong) { + // Verify that a name longer than the allowed size generates an error. + const size_t kMaxBufferSizeBytes = 50; + const std::string name = std::string(2 * kMaxBufferSizeBytes, 'x'); + const std::string value = "abc"; + + decoder_.set_max_decode_buffer_size_bytes(kMaxBufferSizeBytes); + + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(HpackEntryType::kNeverIndexedLiteralHeader, + false, name, false, value); + + const size_t fragment_size = (3 * kMaxBufferSizeBytes) / 2; + const std::string fragment = hbb.buffer().substr(0, fragment_size); + + HandleControlFrameHeadersStart(); + EXPECT_FALSE(HandleControlFrameHeadersData(fragment)); +} + +TEST_P(HpackDecoderAdapterTest, HeaderTooLongToBuffer) { + // Verify that a header longer than the allowed size generates an error if + // it isn't all in one input buffer. + const std::string name = "some-key"; + const std::string value = "some-value"; + const size_t kMaxBufferSizeBytes = name.size() + value.size() - 2; + decoder_.set_max_decode_buffer_size_bytes(kMaxBufferSizeBytes); + + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(HpackEntryType::kNeverIndexedLiteralHeader, + false, name, false, value); + const size_t fragment_size = hbb.size() - 1; + const std::string fragment = hbb.buffer().substr(0, fragment_size); + + HandleControlFrameHeadersStart(); + EXPECT_FALSE(HandleControlFrameHeadersData(fragment)); +} + +// Verify that a header block that exceeds the maximum length is rejected. +TEST_P(HpackDecoderAdapterTest, HeaderBlockTooLong) { + const std::string name = "some-key"; + const std::string value = "some-value"; + const size_t kMaxBufferSizeBytes = 1024; + + HpackBlockBuilder hbb; + hbb.AppendLiteralNameAndValue(HpackEntryType::kIndexedLiteralHeader, false, + name, false, value); + while (hbb.size() < kMaxBufferSizeBytes) { + hbb.AppendLiteralNameAndValue(HpackEntryType::kIndexedLiteralHeader, false, + "", false, ""); + } + // With no limit on the maximum header block size, the decoder handles the + // entire block successfully. + HandleControlFrameHeadersStart(); + EXPECT_TRUE(HandleControlFrameHeadersData(hbb.buffer())); + EXPECT_TRUE(HandleControlFrameHeadersComplete()); + + // When a total byte limit is imposed, the decoder bails before the end of the + // block. + decoder_.set_max_header_block_bytes(kMaxBufferSizeBytes); + HandleControlFrameHeadersStart(); + EXPECT_FALSE(HandleControlFrameHeadersData(hbb.buffer())); +} + +// Decode with incomplete data in buffer. +TEST_P(HpackDecoderAdapterTest, DecodeWithIncompleteData) { + HandleControlFrameHeadersStart(); + + // No need to wait for more data. + EXPECT_TRUE(HandleControlFrameHeadersData("\x82\x85\x82")); + std::vector> expected_headers = { + {":method", "GET"}, {":path", "/index.html"}, {":method", "GET"}}; + + Http2HeaderBlock expected_block1 = MakeHeaderBlock(expected_headers); + EXPECT_EQ(expected_block1, decoded_block()); + + // Full and partial headers, won't add partial to the headers. + EXPECT_TRUE( + HandleControlFrameHeadersData("\x40\x03goo" + "\x03gar\xbe\x40\x04spam")); + expected_headers.push_back({"goo", "gar"}); + expected_headers.push_back({"goo", "gar"}); + + Http2HeaderBlock expected_block2 = MakeHeaderBlock(expected_headers); + EXPECT_EQ(expected_block2, decoded_block()); + + // Add the needed data. + EXPECT_TRUE(HandleControlFrameHeadersData("\x04gggs")); + + EXPECT_TRUE(HandleControlFrameHeadersComplete()); + + expected_headers.push_back({"spam", "gggs"}); + + Http2HeaderBlock expected_block3 = MakeHeaderBlock(expected_headers); + EXPECT_EQ(expected_block3, decoded_block()); +} + +TEST_P(HpackDecoderAdapterTest, HandleHeaderRepresentation) { + // Make sure the decoder is properly initialized. + HandleControlFrameHeadersStart(); + HandleControlFrameHeadersData(""); + + // All cookie crumbs are joined. + decoder_peer_.HandleHeaderRepresentation("cookie", " part 1"); + decoder_peer_.HandleHeaderRepresentation("cookie", "part 2 "); + decoder_peer_.HandleHeaderRepresentation("cookie", "part3"); + + // Already-delimited headers are passed through. + decoder_peer_.HandleHeaderRepresentation("passed-through", + std::string("foo\0baz", 7)); + + // Other headers are joined on \0. Case matters. + decoder_peer_.HandleHeaderRepresentation("joined", "joined"); + decoder_peer_.HandleHeaderRepresentation("joineD", "value 1"); + decoder_peer_.HandleHeaderRepresentation("joineD", "value 2"); + + // Empty headers remain empty. + decoder_peer_.HandleHeaderRepresentation("empty", ""); + + // Joined empty headers work as expected. + decoder_peer_.HandleHeaderRepresentation("empty-joined", ""); + decoder_peer_.HandleHeaderRepresentation("empty-joined", "foo"); + decoder_peer_.HandleHeaderRepresentation("empty-joined", ""); + decoder_peer_.HandleHeaderRepresentation("empty-joined", ""); + + // Non-contiguous cookie crumb. + decoder_peer_.HandleHeaderRepresentation("cookie", " fin!"); + + // Finish and emit all headers. + decoder_.HandleControlFrameHeadersComplete(); + + // Resulting decoded headers are in the same order as the inputs. + EXPECT_THAT( + decoded_block(), + ElementsAre( + Pair("cookie", " part 1; part 2 ; part3; fin!"), + Pair("passed-through", absl::string_view("foo\0baz", 7)), + Pair("joined", absl::string_view("joined\0value 1\0value 2", 22)), + Pair("empty", ""), + Pair("empty-joined", absl::string_view("\0foo\0\0", 6)))); +} + +// Decoding indexed static table field should work. +TEST_P(HpackDecoderAdapterTest, IndexedHeaderStatic) { + // Reference static table entries #2 and #5. + const Http2HeaderBlock& header_set1 = DecodeBlockExpectingSuccess("\x82\x85"); + Http2HeaderBlock expected_header_set1; + expected_header_set1[":method"] = "GET"; + expected_header_set1[":path"] = "/index.html"; + EXPECT_EQ(expected_header_set1, header_set1); + + // Reference static table entry #2. + const Http2HeaderBlock& header_set2 = DecodeBlockExpectingSuccess("\x82"); + Http2HeaderBlock expected_header_set2; + expected_header_set2[":method"] = "GET"; + EXPECT_EQ(expected_header_set2, header_set2); +} + +TEST_P(HpackDecoderAdapterTest, IndexedHeaderDynamic) { + // First header block: add an entry to header table. + const Http2HeaderBlock& header_set1 = DecodeBlockExpectingSuccess( + "\x40\x03" + "foo" + "\x03" + "bar"); + Http2HeaderBlock expected_header_set1; + expected_header_set1["foo"] = "bar"; + EXPECT_EQ(expected_header_set1, header_set1); + + // Second header block: add another entry to header table. + const Http2HeaderBlock& header_set2 = DecodeBlockExpectingSuccess( + "\xbe\x40\x04" + "spam" + "\x04" + "eggs"); + Http2HeaderBlock expected_header_set2; + expected_header_set2["foo"] = "bar"; + expected_header_set2["spam"] = "eggs"; + EXPECT_EQ(expected_header_set2, header_set2); + + // Third header block: refer to most recently added entry. + const Http2HeaderBlock& header_set3 = DecodeBlockExpectingSuccess("\xbe"); + Http2HeaderBlock expected_header_set3; + expected_header_set3["spam"] = "eggs"; + EXPECT_EQ(expected_header_set3, header_set3); +} + +// Test a too-large indexed header. +TEST_P(HpackDecoderAdapterTest, InvalidIndexedHeader) { + // High-bit set, and a prefix of one more than the number of static entries. + EXPECT_FALSE(DecodeHeaderBlock("\xbe")); +} + +TEST_P(HpackDecoderAdapterTest, ContextUpdateMaximumSize) { + EXPECT_EQ(kDefaultHeaderTableSizeSetting, + decoder_peer_.header_table_size_limit()); + std::string input; + { + // Maximum-size update with size 126. Succeeds. + HpackOutputStream output_stream; + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(126); + + input = output_stream.TakeString(); + EXPECT_TRUE(DecodeHeaderBlock(input)); + EXPECT_EQ(126u, decoder_peer_.header_table_size_limit()); + } + { + // Maximum-size update with kDefaultHeaderTableSizeSetting. Succeeds. + HpackOutputStream output_stream; + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(kDefaultHeaderTableSizeSetting); + + input = output_stream.TakeString(); + EXPECT_TRUE(DecodeHeaderBlock(input)); + EXPECT_EQ(kDefaultHeaderTableSizeSetting, + decoder_peer_.header_table_size_limit()); + } + { + // Maximum-size update with kDefaultHeaderTableSizeSetting + 1. Fails. + HpackOutputStream output_stream; + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(kDefaultHeaderTableSizeSetting + 1); + + input = output_stream.TakeString(); + EXPECT_FALSE(DecodeHeaderBlock(input)); + EXPECT_EQ(kDefaultHeaderTableSizeSetting, + decoder_peer_.header_table_size_limit()); + } +} + +// Two HeaderTableSizeUpdates may appear at the beginning of the block +TEST_P(HpackDecoderAdapterTest, TwoTableSizeUpdates) { + std::string input; + { + // Should accept two table size updates, update to second one + HpackOutputStream output_stream; + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(0); + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(122); + + input = output_stream.TakeString(); + EXPECT_TRUE(DecodeHeaderBlock(input)); + EXPECT_EQ(122u, decoder_peer_.header_table_size_limit()); + } +} + +// Three HeaderTableSizeUpdates should result in an error +TEST_P(HpackDecoderAdapterTest, ThreeTableSizeUpdatesError) { + std::string input; + { + // Should reject three table size updates, update to second one + HpackOutputStream output_stream; + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(5); + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(10); + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(15); + + input = output_stream.TakeString(); + + EXPECT_FALSE(DecodeHeaderBlock(input)); + EXPECT_EQ(10u, decoder_peer_.header_table_size_limit()); + } +} + +// HeaderTableSizeUpdates may only appear at the beginning of the block +// Any other updates should result in an error +TEST_P(HpackDecoderAdapterTest, TableSizeUpdateSecondError) { + std::string input; + { + // Should reject a table size update appearing after a different entry + // The table size should remain as the default + HpackOutputStream output_stream; + output_stream.AppendBytes("\x82\x85"); + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(123); + + input = output_stream.TakeString(); + + EXPECT_FALSE(DecodeHeaderBlock(input)); + EXPECT_EQ(kDefaultHeaderTableSizeSetting, + decoder_peer_.header_table_size_limit()); + } +} + +// HeaderTableSizeUpdates may only appear at the beginning of the block +// Any other updates should result in an error +TEST_P(HpackDecoderAdapterTest, TableSizeUpdateFirstThirdError) { + std::string input; + { + // Should reject the second table size update + // if a different entry appears after the first update + // The table size should update to the first but not the second + HpackOutputStream output_stream; + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(60); + output_stream.AppendBytes("\x82\x85"); + output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream.AppendUint32(125); + + input = output_stream.TakeString(); + + EXPECT_FALSE(DecodeHeaderBlock(input)); + EXPECT_EQ(60u, decoder_peer_.header_table_size_limit()); + } +} + +// Decoding two valid encoded literal headers with no indexing should +// work. +TEST_P(HpackDecoderAdapterTest, LiteralHeaderNoIndexing) { + // First header with indexed name, second header with string literal + // name. + const char input[] = "\x04\x0c/sample/path\x00\x06:path2\x0e/sample/path/2"; + const Http2HeaderBlock& header_set = DecodeBlockExpectingSuccess( + absl::string_view(input, ABSL_ARRAYSIZE(input) - 1)); + + Http2HeaderBlock expected_header_set; + expected_header_set[":path"] = "/sample/path"; + expected_header_set[":path2"] = "/sample/path/2"; + EXPECT_EQ(expected_header_set, header_set); +} + +// Decoding two valid encoded literal headers with incremental +// indexing and string literal names should work. +TEST_P(HpackDecoderAdapterTest, LiteralHeaderIncrementalIndexing) { + const char input[] = "\x44\x0c/sample/path\x40\x06:path2\x0e/sample/path/2"; + const Http2HeaderBlock& header_set = DecodeBlockExpectingSuccess( + absl::string_view(input, ABSL_ARRAYSIZE(input) - 1)); + + Http2HeaderBlock expected_header_set; + expected_header_set[":path"] = "/sample/path"; + expected_header_set[":path2"] = "/sample/path/2"; + EXPECT_EQ(expected_header_set, header_set); +} + +TEST_P(HpackDecoderAdapterTest, LiteralHeaderWithIndexingInvalidNameIndex) { + decoder_.ApplyHeaderTableSizeSetting(0); + EXPECT_TRUE(EncodeAndDecodeDynamicTableSizeUpdates(0, 0)); + + // Name is the last static index. Works. + EXPECT_TRUE(DecodeHeaderBlock(absl::string_view("\x7d\x03ooo"))); + // Name is one beyond the last static index. Fails. + EXPECT_FALSE(DecodeHeaderBlock(absl::string_view("\x7e\x03ooo"))); +} + +TEST_P(HpackDecoderAdapterTest, LiteralHeaderNoIndexingInvalidNameIndex) { + // Name is the last static index. Works. + EXPECT_TRUE(DecodeHeaderBlock(absl::string_view("\x0f\x2e\x03ooo"))); + // Name is one beyond the last static index. Fails. + EXPECT_FALSE(DecodeHeaderBlock(absl::string_view("\x0f\x2f\x03ooo"))); +} + +TEST_P(HpackDecoderAdapterTest, LiteralHeaderNeverIndexedInvalidNameIndex) { + // Name is the last static index. Works. + EXPECT_TRUE(DecodeHeaderBlock(absl::string_view("\x1f\x2e\x03ooo"))); + // Name is one beyond the last static index. Fails. + EXPECT_FALSE(DecodeHeaderBlock(absl::string_view("\x1f\x2f\x03ooo"))); +} + +TEST_P(HpackDecoderAdapterTest, TruncatedIndex) { + // Indexed Header, varint for index requires multiple bytes, + // but only one provided. + EXPECT_FALSE(DecodeHeaderBlock("\xff")); +} + +TEST_P(HpackDecoderAdapterTest, TruncatedHuffmanLiteral) { + // Literal value, Huffman encoded, but with the last byte missing (i.e. + // drop the final ff shown below). + // + // 41 | == Literal indexed == + // | Indexed name (idx = 1) + // | :authority + // 8c | Literal value (len = 12) + // | Huffman encoded: + // f1e3 c2e5 f23a 6ba0 ab90 f4ff | .....:k..... + // | Decoded: + // | www.example.com + // | -> :authority: www.example.com + + std::string first = absl::HexStringToBytes("418cf1e3c2e5f23a6ba0ab90f4ff"); + EXPECT_TRUE(DecodeHeaderBlock(first)); + first.pop_back(); + EXPECT_FALSE(DecodeHeaderBlock(first)); +} + +TEST_P(HpackDecoderAdapterTest, HuffmanEOSError) { + // Literal value, Huffman encoded, but with an additional ff byte at the end + // of the string, i.e. an EOS that is longer than permitted. + // + // 41 | == Literal indexed == + // | Indexed name (idx = 1) + // | :authority + // 8d | Literal value (len = 13) + // | Huffman encoded: + // f1e3 c2e5 f23a 6ba0 ab90 f4ff | .....:k..... + // | Decoded: + // | www.example.com + // | -> :authority: www.example.com + + std::string first = absl::HexStringToBytes("418cf1e3c2e5f23a6ba0ab90f4ff"); + EXPECT_TRUE(DecodeHeaderBlock(first)); + first = absl::HexStringToBytes("418df1e3c2e5f23a6ba0ab90f4ffff"); + EXPECT_FALSE(DecodeHeaderBlock(first)); +} + +// Round-tripping the header set from RFC 7541 C.3.1 should work. +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.3.1 +TEST_P(HpackDecoderAdapterTest, BasicC31) { + HpackEncoder encoder; + + Http2HeaderBlock expected_header_set; + expected_header_set[":method"] = "GET"; + expected_header_set[":scheme"] = "http"; + expected_header_set[":path"] = "/"; + expected_header_set[":authority"] = "www.example.com"; + + std::string encoded_header_set = + encoder.EncodeHeaderBlock(expected_header_set); + + EXPECT_TRUE(DecodeHeaderBlock(encoded_header_set)); + EXPECT_EQ(expected_header_set, decoded_block()); +} + +// RFC 7541, Section C.4: Request Examples with Huffman Coding +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.4 +TEST_P(HpackDecoderAdapterTest, SectionC4RequestHuffmanExamples) { + // TODO(jamessynge): Use http2/hpack/tools/hpack_example.h to parse the + // example directly, instead of having it as a comment. + // + // 82 | == Indexed - Add == + // | idx = 2 + // | -> :method: GET + // 86 | == Indexed - Add == + // | idx = 6 + // | -> :scheme: http + // 84 | == Indexed - Add == + // | idx = 4 + // | -> :path: / + // 41 | == Literal indexed == + // | Indexed name (idx = 1) + // | :authority + // 8c | Literal value (len = 12) + // | Huffman encoded: + // f1e3 c2e5 f23a 6ba0 ab90 f4ff | .....:k..... + // | Decoded: + // | www.example.com + // | -> :authority: www.example.com + std::string first = + absl::HexStringToBytes("828684418cf1e3c2e5f23a6ba0ab90f4ff"); + const Http2HeaderBlock& first_header_set = DecodeBlockExpectingSuccess(first); + + EXPECT_THAT(first_header_set, + ElementsAre( + // clang-format off + Pair(":method", "GET"), + Pair(":scheme", "http"), + Pair(":path", "/"), + Pair(":authority", "www.example.com"))); + // clang-format on + + expectEntry(62, 57, ":authority", "www.example.com"); + EXPECT_EQ(57u, decoder_peer_.current_header_table_size()); + + // 82 | == Indexed - Add == + // | idx = 2 + // | -> :method: GET + // 86 | == Indexed - Add == + // | idx = 6 + // | -> :scheme: http + // 84 | == Indexed - Add == + // | idx = 4 + // | -> :path: / + // be | == Indexed - Add == + // | idx = 62 + // | -> :authority: www.example.com + // 58 | == Literal indexed == + // | Indexed name (idx = 24) + // | cache-control + // 86 | Literal value (len = 8) + // | Huffman encoded: + // a8eb 1064 9cbf | ...d.. + // | Decoded: + // | no-cache + // | -> cache-control: no-cache + + std::string second = absl::HexStringToBytes("828684be5886a8eb10649cbf"); + const Http2HeaderBlock& second_header_set = + DecodeBlockExpectingSuccess(second); + + EXPECT_THAT(second_header_set, + ElementsAre( + // clang-format off + Pair(":method", "GET"), + Pair(":scheme", "http"), + Pair(":path", "/"), + Pair(":authority", "www.example.com"), + Pair("cache-control", "no-cache"))); + // clang-format on + + expectEntry(62, 53, "cache-control", "no-cache"); + expectEntry(63, 57, ":authority", "www.example.com"); + EXPECT_EQ(110u, decoder_peer_.current_header_table_size()); + + // 82 | == Indexed - Add == + // | idx = 2 + // | -> :method: GET + // 87 | == Indexed - Add == + // | idx = 7 + // | -> :scheme: https + // 85 | == Indexed - Add == + // | idx = 5 + // | -> :path: /index.html + // bf | == Indexed - Add == + // | idx = 63 + // | -> :authority: www.example.com + // 40 | == Literal indexed == + // 88 | Literal name (len = 10) + // | Huffman encoded: + // 25a8 49e9 5ba9 7d7f | %.I.[.}. + // | Decoded: + // | custom-key + // 89 | Literal value (len = 12) + // | Huffman encoded: + // 25a8 49e9 5bb8 e8b4 bf | %.I.[.... + // | Decoded: + // | custom-value + // | -> custom-key: custom-value + std::string third = absl::HexStringToBytes( + "828785bf408825a849e95ba97d7f8925a849e95bb8e8b4bf"); + const Http2HeaderBlock& third_header_set = DecodeBlockExpectingSuccess(third); + + EXPECT_THAT( + third_header_set, + ElementsAre( + // clang-format off + Pair(":method", "GET"), + Pair(":scheme", "https"), + Pair(":path", "/index.html"), + Pair(":authority", "www.example.com"), + Pair("custom-key", "custom-value"))); + // clang-format on + + expectEntry(62, 54, "custom-key", "custom-value"); + expectEntry(63, 53, "cache-control", "no-cache"); + expectEntry(64, 57, ":authority", "www.example.com"); + EXPECT_EQ(164u, decoder_peer_.current_header_table_size()); +} + +// RFC 7541, Section C.6: Response Examples with Huffman Coding +// http://httpwg.org/specs/rfc7541.html#rfc.section.C.6 +TEST_P(HpackDecoderAdapterTest, SectionC6ResponseHuffmanExamples) { + // The example is based on a maximum dynamic table size of 256, + // which allows for testing dynamic table evictions. + decoder_peer_.set_header_table_size_limit(256); + + // 48 | == Literal indexed == + // | Indexed name (idx = 8) + // | :status + // 82 | Literal value (len = 3) + // | Huffman encoded: + // 6402 | d. + // | Decoded: + // | 302 + // | -> :status: 302 + // 58 | == Literal indexed == + // | Indexed name (idx = 24) + // | cache-control + // 85 | Literal value (len = 7) + // | Huffman encoded: + // aec3 771a 4b | ..w.K + // | Decoded: + // | private + // | -> cache-control: private + // 61 | == Literal indexed == + // | Indexed name (idx = 33) + // | date + // 96 | Literal value (len = 29) + // | Huffman encoded: + // d07a be94 1054 d444 a820 0595 040b 8166 | .z...T.D. .....f + // e082 a62d 1bff | ...-.. + // | Decoded: + // | Mon, 21 Oct 2013 20:13:21 + // | GMT + // | -> date: Mon, 21 Oct 2013 + // | 20:13:21 GMT + // 6e | == Literal indexed == + // | Indexed name (idx = 46) + // | location + // 91 | Literal value (len = 23) + // | Huffman encoded: + // 9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 | .)...c.........C + // d3 | . + // | Decoded: + // | https://www.example.com + // | -> location: https://www.e + // | xample.com + + std::string first = absl::HexStringToBytes( + "488264025885aec3771a4b6196d07abe" + "941054d444a8200595040b8166e082a6" + "2d1bff6e919d29ad171863c78f0b97c8" + "e9ae82ae43d3"); + const Http2HeaderBlock& first_header_set = DecodeBlockExpectingSuccess(first); + + EXPECT_THAT(first_header_set, + ElementsAre( + // clang-format off + Pair(":status", "302"), + Pair("cache-control", "private"), + Pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + Pair("location", "https://www.example.com"))); + // clang-format on + + expectEntry(62, 63, "location", "https://www.example.com"); + expectEntry(63, 65, "date", "Mon, 21 Oct 2013 20:13:21 GMT"); + expectEntry(64, 52, "cache-control", "private"); + expectEntry(65, 42, ":status", "302"); + EXPECT_EQ(222u, decoder_peer_.current_header_table_size()); + + // 48 | == Literal indexed == + // | Indexed name (idx = 8) + // | :status + // 83 | Literal value (len = 3) + // | Huffman encoded: + // 640e ff | d.. + // | Decoded: + // | 307 + // | - evict: :status: 302 + // | -> :status: 307 + // c1 | == Indexed - Add == + // | idx = 65 + // | -> cache-control: private + // c0 | == Indexed - Add == + // | idx = 64 + // | -> date: Mon, 21 Oct 2013 + // | 20:13:21 GMT + // bf | == Indexed - Add == + // | idx = 63 + // | -> location: + // | https://www.example.com + std::string second = absl::HexStringToBytes("4883640effc1c0bf"); + const Http2HeaderBlock& second_header_set = + DecodeBlockExpectingSuccess(second); + + EXPECT_THAT(second_header_set, + ElementsAre( + // clang-format off + Pair(":status", "307"), + Pair("cache-control", "private"), + Pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + Pair("location", "https://www.example.com"))); + // clang-format on + + expectEntry(62, 42, ":status", "307"); + expectEntry(63, 63, "location", "https://www.example.com"); + expectEntry(64, 65, "date", "Mon, 21 Oct 2013 20:13:21 GMT"); + expectEntry(65, 52, "cache-control", "private"); + EXPECT_EQ(222u, decoder_peer_.current_header_table_size()); + + // 88 | == Indexed - Add == + // | idx = 8 + // | -> :status: 200 + // c1 | == Indexed - Add == + // | idx = 65 + // | -> cache-control: private + // 61 | == Literal indexed == + // | Indexed name (idx = 33) + // | date + // 96 | Literal value (len = 22) + // | Huffman encoded: + // d07a be94 1054 d444 a820 0595 040b 8166 | .z...T.D. .....f + // e084 a62d 1bff | ...-.. + // | Decoded: + // | Mon, 21 Oct 2013 20:13:22 + // | GMT + // | - evict: cache-control: + // | private + // | -> date: Mon, 21 Oct 2013 + // | 20:13:22 GMT + // c0 | == Indexed - Add == + // | idx = 64 + // | -> location: + // | https://www.example.com + // 5a | == Literal indexed == + // | Indexed name (idx = 26) + // | content-encoding + // 83 | Literal value (len = 3) + // | Huffman encoded: + // 9bd9 ab | ... + // | Decoded: + // | gzip + // | - evict: date: Mon, 21 Oct + // | 2013 20:13:21 GMT + // | -> content-encoding: gzip + // 77 | == Literal indexed == + // | Indexed name (idx = 55) + // | set-cookie + // ad | Literal value (len = 45) + // | Huffman encoded: + // 94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 | .........5...[9` + // d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 | ..'..6r..'..)... + // 3160 65c0 03ed 4ee5 b106 3d50 07 | 1`e...N...=P. + // | Decoded: + // | foo=ASDJKHQKBZXOQWEOPIUAXQ + // | WEOIU; max-age=3600; versi + // | on=1 + // | - evict: location: + // | https://www.example.com + // | - evict: :status: 307 + // | -> set-cookie: foo=ASDJKHQ + // | KBZXOQWEOPIUAXQWEOIU; + // | max-age=3600; version=1 + std::string third = absl::HexStringToBytes( + "88c16196d07abe941054d444a8200595" + "040b8166e084a62d1bffc05a839bd9ab" + "77ad94e7821dd7f2e6c7b335dfdfcd5b" + "3960d5af27087f3672c1ab270fb5291f" + "9587316065c003ed4ee5b1063d5007"); + const Http2HeaderBlock& third_header_set = DecodeBlockExpectingSuccess(third); + + EXPECT_THAT(third_header_set, + ElementsAre( + // clang-format off + Pair(":status", "200"), + Pair("cache-control", "private"), + Pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"), + Pair("location", "https://www.example.com"), + Pair("content-encoding", "gzip"), + Pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU;" + " max-age=3600; version=1"))); + // clang-format on + + expectEntry(62, 98, "set-cookie", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU;" + " max-age=3600; version=1"); + expectEntry(63, 52, "content-encoding", "gzip"); + expectEntry(64, 65, "date", "Mon, 21 Oct 2013 20:13:22 GMT"); + EXPECT_EQ(215u, decoder_peer_.current_header_table_size()); +} + +// Regression test: Found that entries with dynamic indexed names and literal +// values caused "use after free" MSAN failures if the name was evicted as it +// was being re-used. +TEST_P(HpackDecoderAdapterTest, ReuseNameOfEvictedEntry) { + // Each entry is measured as 32 bytes plus the sum of the lengths of the name + // and the value. Set the size big enough for at most one entry, and a fairly + // small one at that (31 ASCII characters). + decoder_.ApplyHeaderTableSizeSetting(63); + + HpackBlockBuilder hbb; + hbb.AppendDynamicTableSizeUpdate(0); + hbb.AppendDynamicTableSizeUpdate(63); + + const absl::string_view name("some-name"); + const absl::string_view value1("some-value"); + const absl::string_view value2("another-value"); + const absl::string_view value3("yet-another-value"); + + // Add an entry that will become the first in the dynamic table, entry 62. + hbb.AppendLiteralNameAndValue(HpackEntryType::kIndexedLiteralHeader, false, + name, false, value1); + + // Confirm that entry has been added by re-using it. + hbb.AppendIndexedHeader(62); + + // Add another entry referring to the name of the first. This will evict the + // first. + hbb.AppendNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, 62, + false, value2); + + // Confirm that entry has been added by re-using it. + hbb.AppendIndexedHeader(62); + + // Add another entry referring to the name of the second. This will evict the + // second. + hbb.AppendNameIndexAndLiteralValue(HpackEntryType::kIndexedLiteralHeader, 62, + false, value3); + + // Confirm that entry has been added by re-using it. + hbb.AppendIndexedHeader(62); + + // Can't have DecodeHeaderBlock do the default check for size of the decoded + // data because Http2HeaderBlock will join multiple headers with the same + // name into a single entry, thus we won't see repeated occurrences of the + // name, instead seeing separators between values. + EXPECT_TRUE(DecodeHeaderBlock(hbb.buffer(), kNoCheckDecodedSize)); + + Http2HeaderBlock expected_header_set; + expected_header_set.AppendValueOrAddHeader(name, value1); + expected_header_set.AppendValueOrAddHeader(name, value1); + expected_header_set.AppendValueOrAddHeader(name, value2); + expected_header_set.AppendValueOrAddHeader(name, value2); + expected_header_set.AppendValueOrAddHeader(name, value3); + expected_header_set.AppendValueOrAddHeader(name, value3); + + // Http2HeaderBlock stores these 6 strings as '\0' separated values. + // Make sure that is what happened. + std::string joined_values = expected_header_set[name].as_string(); + EXPECT_EQ(joined_values.size(), + 2 * value1.size() + 2 * value2.size() + 2 * value3.size() + 5); + + EXPECT_EQ(expected_header_set, decoded_block()); + + if (start_choice_ == START_WITH_HANDLER) { + EXPECT_EQ(handler_.uncompressed_header_bytes(), + 6 * name.size() + 2 * value1.size() + 2 * value2.size() + + 2 * value3.size()); + } +} + +// Regression test for https://crbug.com/747395. +TEST_P(HpackDecoderAdapterTest, Cookies) { + Http2HeaderBlock expected_header_set; + expected_header_set["cookie"] = "foo; bar"; + + EXPECT_TRUE(DecodeHeaderBlock(absl::HexStringToBytes("608294e76003626172"))); + EXPECT_EQ(expected_header_set, decoded_block()); +} + +} // namespace +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_encoder.cc b/quiche/spdy/core/hpack/hpack_encoder.cc new file mode 100644 index 000000000000..3b95cfe13d23 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_encoder.cc @@ -0,0 +1,375 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_encoder.h" + +#include +#include +#include + +#include "quiche/http2/hpack/huffman/hpack_huffman_encoder.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" +#include "quiche/spdy/core/hpack/hpack_output_stream.h" + +namespace spdy { + +class HpackEncoder::RepresentationIterator { + public: + // |pseudo_headers| and |regular_headers| must outlive the iterator. + RepresentationIterator(const Representations& pseudo_headers, + const Representations& regular_headers) + : pseudo_begin_(pseudo_headers.begin()), + pseudo_end_(pseudo_headers.end()), + regular_begin_(regular_headers.begin()), + regular_end_(regular_headers.end()) {} + + // |headers| must outlive the iterator. + explicit RepresentationIterator(const Representations& headers) + : pseudo_begin_(headers.begin()), + pseudo_end_(headers.end()), + regular_begin_(headers.end()), + regular_end_(headers.end()) {} + + bool HasNext() { + return pseudo_begin_ != pseudo_end_ || regular_begin_ != regular_end_; + } + + const Representation Next() { + if (pseudo_begin_ != pseudo_end_) { + return *pseudo_begin_++; + } else { + return *regular_begin_++; + } + } + + private: + Representations::const_iterator pseudo_begin_; + Representations::const_iterator pseudo_end_; + Representations::const_iterator regular_begin_; + Representations::const_iterator regular_end_; +}; + +namespace { + +// The default header listener. +void NoOpListener(absl::string_view /*name*/, absl::string_view /*value*/) {} + +// The default HPACK indexing policy. +bool DefaultPolicy(absl::string_view name, absl::string_view /* value */) { + if (name.empty()) { + return false; + } + // :authority is always present and rarely changes, and has moderate + // length, therefore it makes a lot of sense to index (insert in the + // dynamic table). + if (name[0] == kPseudoHeaderPrefix) { + return name == ":authority"; + } + return true; +} + +} // namespace + +HpackEncoder::HpackEncoder() + : output_stream_(), + min_table_size_setting_received_(std::numeric_limits::max()), + listener_(NoOpListener), + should_index_(DefaultPolicy), + enable_compression_(true), + should_emit_table_size_(false) {} + +HpackEncoder::~HpackEncoder() = default; + +std::string HpackEncoder::EncodeHeaderBlock( + const Http2HeaderBlock& header_set) { + // Separate header set into pseudo-headers and regular headers. + Representations pseudo_headers; + Representations regular_headers; + bool found_cookie = false; + for (const auto& header : header_set) { + if (!found_cookie && header.first == "cookie") { + // Note that there can only be one "cookie" header, because header_set is + // a map. + found_cookie = true; + CookieToCrumbs(header, ®ular_headers); + } else if (!header.first.empty() && + header.first[0] == kPseudoHeaderPrefix) { + DecomposeRepresentation(header, &pseudo_headers); + } else { + DecomposeRepresentation(header, ®ular_headers); + } + } + + RepresentationIterator iter(pseudo_headers, regular_headers); + return EncodeRepresentations(&iter); +} + +void HpackEncoder::ApplyHeaderTableSizeSetting(size_t size_setting) { + if (size_setting == header_table_.settings_size_bound()) { + return; + } + if (size_setting < header_table_.settings_size_bound()) { + min_table_size_setting_received_ = + std::min(size_setting, min_table_size_setting_received_); + } + header_table_.SetSettingsHeaderTableSize(size_setting); + should_emit_table_size_ = true; +} + +std::string HpackEncoder::EncodeRepresentations(RepresentationIterator* iter) { + MaybeEmitTableSize(); + while (iter->HasNext()) { + const auto header = iter->Next(); + listener_(header.first, header.second); + if (enable_compression_) { + size_t index = + header_table_.GetByNameAndValue(header.first, header.second); + if (index != kHpackEntryNotFound) { + EmitIndex(index); + } else if (should_index_(header.first, header.second)) { + EmitIndexedLiteral(header); + } else { + EmitNonIndexedLiteral(header, enable_compression_); + } + } else { + EmitNonIndexedLiteral(header, enable_compression_); + } + } + + return output_stream_.TakeString(); +} + +void HpackEncoder::EmitIndex(size_t index) { + QUICHE_DVLOG(2) << "Emitting index " << index; + output_stream_.AppendPrefix(kIndexedOpcode); + output_stream_.AppendUint32(index); +} + +void HpackEncoder::EmitIndexedLiteral(const Representation& representation) { + QUICHE_DVLOG(2) << "Emitting indexed literal: (" << representation.first + << ", " << representation.second << ")"; + output_stream_.AppendPrefix(kLiteralIncrementalIndexOpcode); + EmitLiteral(representation); + header_table_.TryAddEntry(representation.first, representation.second); +} + +void HpackEncoder::EmitNonIndexedLiteral(const Representation& representation, + bool enable_compression) { + QUICHE_DVLOG(2) << "Emitting nonindexed literal: (" << representation.first + << ", " << representation.second << ")"; + output_stream_.AppendPrefix(kLiteralNoIndexOpcode); + size_t name_index = header_table_.GetByName(representation.first); + if (enable_compression && name_index != kHpackEntryNotFound) { + output_stream_.AppendUint32(name_index); + } else { + output_stream_.AppendUint32(0); + EmitString(representation.first); + } + EmitString(representation.second); +} + +void HpackEncoder::EmitLiteral(const Representation& representation) { + size_t name_index = header_table_.GetByName(representation.first); + if (name_index != kHpackEntryNotFound) { + output_stream_.AppendUint32(name_index); + } else { + output_stream_.AppendUint32(0); + EmitString(representation.first); + } + EmitString(representation.second); +} + +void HpackEncoder::EmitString(absl::string_view str) { + size_t encoded_size = + enable_compression_ ? http2::HuffmanSize(str) : str.size(); + if (encoded_size < str.size()) { + QUICHE_DVLOG(2) << "Emitted Huffman-encoded string of length " + << encoded_size; + output_stream_.AppendPrefix(kStringLiteralHuffmanEncoded); + output_stream_.AppendUint32(encoded_size); + http2::HuffmanEncodeFast(str, encoded_size, output_stream_.MutableString()); + } else { + QUICHE_DVLOG(2) << "Emitted literal string of length " << str.size(); + output_stream_.AppendPrefix(kStringLiteralIdentityEncoded); + output_stream_.AppendUint32(str.size()); + output_stream_.AppendBytes(str); + } +} + +void HpackEncoder::MaybeEmitTableSize() { + if (!should_emit_table_size_) { + return; + } + const size_t current_size = CurrentHeaderTableSizeSetting(); + QUICHE_DVLOG(1) << "MaybeEmitTableSize current_size=" << current_size; + QUICHE_DVLOG(1) << "MaybeEmitTableSize min_table_size_setting_received_=" + << min_table_size_setting_received_; + if (min_table_size_setting_received_ < current_size) { + output_stream_.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream_.AppendUint32(min_table_size_setting_received_); + } + output_stream_.AppendPrefix(kHeaderTableSizeUpdateOpcode); + output_stream_.AppendUint32(current_size); + min_table_size_setting_received_ = std::numeric_limits::max(); + should_emit_table_size_ = false; +} + +// static +void HpackEncoder::CookieToCrumbs(const Representation& cookie, + Representations* out) { + // See Section 8.1.2.5. "Compressing the Cookie Header Field" in the HTTP/2 + // specification at https://tools.ietf.org/html/draft-ietf-httpbis-http2-14. + // Cookie values are split into individually-encoded HPACK representations. + absl::string_view cookie_value = cookie.second; + // Consume leading and trailing whitespace if present. + absl::string_view::size_type first = cookie_value.find_first_not_of(" \t"); + absl::string_view::size_type last = cookie_value.find_last_not_of(" \t"); + if (first == absl::string_view::npos) { + cookie_value = absl::string_view(); + } else { + cookie_value = cookie_value.substr(first, (last - first) + 1); + } + for (size_t pos = 0;;) { + size_t end = cookie_value.find(';', pos); + + if (end == absl::string_view::npos) { + out->push_back(std::make_pair(cookie.first, cookie_value.substr(pos))); + break; + } + out->push_back( + std::make_pair(cookie.first, cookie_value.substr(pos, end - pos))); + + // Consume next space if present. + pos = end + 1; + if (pos != cookie_value.size() && cookie_value[pos] == ' ') { + pos++; + } + } +} + +// static +void HpackEncoder::DecomposeRepresentation(const Representation& header_field, + Representations* out) { + size_t pos = 0; + size_t end = 0; + while (end != absl::string_view::npos) { + end = header_field.second.find('\0', pos); + out->push_back(std::make_pair( + header_field.first, + header_field.second.substr( + pos, end == absl::string_view::npos ? end : end - pos))); + pos = end + 1; + } +} + +// Iteratively encodes a Http2HeaderBlock. +class HpackEncoder::Encoderator : public ProgressiveEncoder { + public: + Encoderator(const Http2HeaderBlock& header_set, HpackEncoder* encoder); + Encoderator(const Representations& representations, HpackEncoder* encoder); + + // Encoderator is neither copyable nor movable. + Encoderator(const Encoderator&) = delete; + Encoderator& operator=(const Encoderator&) = delete; + + // Returns true iff more remains to encode. + bool HasNext() const override { return has_next_; } + + // Encodes and returns up to max_encoded_bytes of the current header block. + std::string Next(size_t max_encoded_bytes) override; + + private: + HpackEncoder* encoder_; + std::unique_ptr header_it_; + Representations pseudo_headers_; + Representations regular_headers_; + bool has_next_; +}; + +HpackEncoder::Encoderator::Encoderator(const Http2HeaderBlock& header_set, + HpackEncoder* encoder) + : encoder_(encoder), has_next_(true) { + // Separate header set into pseudo-headers and regular headers. + bool found_cookie = false; + for (const auto& header : header_set) { + if (!found_cookie && header.first == "cookie") { + // Note that there can only be one "cookie" header, because header_set + // is a map. + found_cookie = true; + CookieToCrumbs(header, ®ular_headers_); + } else if (!header.first.empty() && + header.first[0] == kPseudoHeaderPrefix) { + DecomposeRepresentation(header, &pseudo_headers_); + } else { + DecomposeRepresentation(header, ®ular_headers_); + } + } + header_it_ = std::make_unique(pseudo_headers_, + regular_headers_); + + encoder_->MaybeEmitTableSize(); +} + +HpackEncoder::Encoderator::Encoderator(const Representations& representations, + HpackEncoder* encoder) + : encoder_(encoder), has_next_(true) { + for (const auto& header : representations) { + if (header.first == "cookie") { + CookieToCrumbs(header, ®ular_headers_); + } else if (!header.first.empty() && + header.first[0] == kPseudoHeaderPrefix) { + pseudo_headers_.push_back(header); + } else { + regular_headers_.push_back(header); + } + } + header_it_ = std::make_unique(pseudo_headers_, + regular_headers_); + + encoder_->MaybeEmitTableSize(); +} + +std::string HpackEncoder::Encoderator::Next(size_t max_encoded_bytes) { + QUICHE_BUG_IF(spdy_bug_61_1, !has_next_) + << "Encoderator::Next called with nothing left to encode."; + const bool enable_compression = encoder_->enable_compression_; + + // Encode up to max_encoded_bytes of headers. + while (header_it_->HasNext() && + encoder_->output_stream_.size() <= max_encoded_bytes) { + const Representation header = header_it_->Next(); + encoder_->listener_(header.first, header.second); + if (enable_compression) { + size_t index = encoder_->header_table_.GetByNameAndValue(header.first, + header.second); + if (index != kHpackEntryNotFound) { + encoder_->EmitIndex(index); + } else if (encoder_->should_index_(header.first, header.second)) { + encoder_->EmitIndexedLiteral(header); + } else { + encoder_->EmitNonIndexedLiteral(header, enable_compression); + } + } else { + encoder_->EmitNonIndexedLiteral(header, enable_compression); + } + } + + has_next_ = encoder_->output_stream_.size() > max_encoded_bytes; + return encoder_->output_stream_.BoundedTakeString(max_encoded_bytes); +} + +std::unique_ptr HpackEncoder::EncodeHeaderSet( + const Http2HeaderBlock& header_set) { + return std::make_unique(header_set, this); +} + +std::unique_ptr +HpackEncoder::EncodeRepresentations(const Representations& representations) { + return std::make_unique(representations, this); +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_encoder.h b/quiche/spdy/core/hpack/hpack_encoder.h new file mode 100644 index 000000000000..50e4711a5879 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_encoder.h @@ -0,0 +1,147 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_ENCODER_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_ENCODER_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" +#include "quiche/spdy/core/hpack/hpack_output_stream.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" + +// An HpackEncoder encodes header sets as outlined in +// http://tools.ietf.org/html/rfc7541. + +namespace spdy { + +namespace test { +class HpackEncoderPeer; +} // namespace test + +class QUICHE_EXPORT HpackEncoder { + public: + using Representation = std::pair; + using Representations = std::vector; + + // Callers may provide a HeaderListener to be informed of header name-value + // pairs processed by this encoder. + using HeaderListener = + std::function; + + // An indexing policy should return true if the provided header name-value + // pair should be inserted into the HPACK dynamic table. + using IndexingPolicy = + std::function; + + HpackEncoder(); + HpackEncoder(const HpackEncoder&) = delete; + HpackEncoder& operator=(const HpackEncoder&) = delete; + ~HpackEncoder(); + + // Encodes and returns the given header set as a string. + std::string EncodeHeaderBlock(const Http2HeaderBlock& header_set); + + class QUICHE_EXPORT ProgressiveEncoder { + public: + virtual ~ProgressiveEncoder() {} + + // Returns true iff more remains to encode. + virtual bool HasNext() const = 0; + + // Encodes and returns up to max_encoded_bytes of the current header block. + virtual std::string Next(size_t max_encoded_bytes) = 0; + }; + + // Returns a ProgressiveEncoder which must be outlived by both the given + // Http2HeaderBlock and this object. + std::unique_ptr EncodeHeaderSet( + const Http2HeaderBlock& header_set); + // Returns a ProgressiveEncoder which must be outlived by this HpackEncoder. + // The encoder will not attempt to split any \0-delimited values in + // |representations|. If such splitting is desired, it must be performed by + // the caller when constructing the list of representations. + std::unique_ptr EncodeRepresentations( + const Representations& representations); + + // Called upon a change to SETTINGS_HEADER_TABLE_SIZE. Specifically, this + // is to be called after receiving (and sending an acknowledgement for) a + // SETTINGS_HEADER_TABLE_SIZE update from the remote decoding endpoint. + void ApplyHeaderTableSizeSetting(size_t size_setting); + + // TODO(birenroy): Rename this GetDynamicTableCapacity(). + size_t CurrentHeaderTableSizeSetting() const { + return header_table_.settings_size_bound(); + } + + // This HpackEncoder will use |policy| to determine whether to insert header + // name-value pairs into the dynamic table. + void SetIndexingPolicy(IndexingPolicy policy) { should_index_ = policy; } + + // |listener| will be invoked for each header name-value pair processed by + // this encoder. + void SetHeaderListener(HeaderListener listener) { listener_ = listener; } + + void DisableCompression() { enable_compression_ = false; } + + // Returns the current dynamic table size, including the 32 bytes per entry + // overhead mentioned in RFC 7541 section 4.1. + size_t GetDynamicTableSize() const { return header_table_.size(); } + + private: + friend class test::HpackEncoderPeer; + + class RepresentationIterator; + class Encoderator; + + // Encodes a sequence of header name-value pairs as a single header block. + std::string EncodeRepresentations(RepresentationIterator* iter); + + // Emits a static/dynamic indexed representation (Section 7.1). + void EmitIndex(size_t index); + + // Emits a literal representation (Section 7.2). + void EmitIndexedLiteral(const Representation& representation); + void EmitNonIndexedLiteral(const Representation& representation, + bool enable_compression); + void EmitLiteral(const Representation& representation); + + // Emits a Huffman or identity string (whichever is smaller). + void EmitString(absl::string_view str); + + // Emits the current dynamic table size if the table size was recently + // updated and we have not yet emitted it (Section 6.3). + void MaybeEmitTableSize(); + + // Crumbles a cookie header into ";" delimited crumbs. + static void CookieToCrumbs(const Representation& cookie, + Representations* crumbs_out); + + // Crumbles other header field values at \0 delimiters. + static void DecomposeRepresentation(const Representation& header_field, + Representations* out); + + HpackHeaderTable header_table_; + HpackOutputStream output_stream_; + + size_t min_table_size_setting_received_; + HeaderListener listener_; + IndexingPolicy should_index_; + bool enable_compression_; + bool should_emit_table_size_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_ENCODER_H_ diff --git a/quiche/spdy/core/hpack/hpack_encoder_test.cc b/quiche/spdy/core/hpack/hpack_encoder_test.cc new file mode 100644 index 000000000000..fe80ffea95aa --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_encoder_test.cc @@ -0,0 +1,754 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_encoder.h" + +#include +#include + +#include "quiche/http2/hpack/huffman/hpack_huffman_encoder.h" +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/hpack/hpack_static_table.h" +#include "quiche/spdy/core/spdy_simple_arena.h" + +namespace spdy { + +namespace test { + +class HpackHeaderTablePeer { + public: + explicit HpackHeaderTablePeer(HpackHeaderTable* table) : table_(table) {} + + const HpackEntry* GetFirstStaticEntry() const { + return &table_->static_entries_.front(); + } + + HpackHeaderTable::DynamicEntryTable* dynamic_entries() { + return &table_->dynamic_entries_; + } + + private: + HpackHeaderTable* table_; +}; + +class HpackEncoderPeer { + public: + typedef HpackEncoder::Representation Representation; + typedef HpackEncoder::Representations Representations; + + explicit HpackEncoderPeer(HpackEncoder* encoder) : encoder_(encoder) {} + + bool compression_enabled() const { return encoder_->enable_compression_; } + HpackHeaderTable* table() { return &encoder_->header_table_; } + HpackHeaderTablePeer table_peer() { return HpackHeaderTablePeer(table()); } + void EmitString(absl::string_view str) { encoder_->EmitString(str); } + void TakeString(std::string* out) { + *out = encoder_->output_stream_.TakeString(); + } + static void CookieToCrumbs(absl::string_view cookie, + std::vector* out) { + Representations tmp; + HpackEncoder::CookieToCrumbs(std::make_pair("", cookie), &tmp); + + out->clear(); + for (size_t i = 0; i != tmp.size(); ++i) { + out->push_back(tmp[i].second); + } + } + static void DecomposeRepresentation(absl::string_view value, + std::vector* out) { + Representations tmp; + HpackEncoder::DecomposeRepresentation(std::make_pair("foobar", value), + &tmp); + + out->clear(); + for (size_t i = 0; i != tmp.size(); ++i) { + out->push_back(tmp[i].second); + } + } + + // TODO(dahollings): Remove or clean up these methods when deprecating + // non-incremental encoding path. + static std::string EncodeHeaderBlock(HpackEncoder* encoder, + const Http2HeaderBlock& header_set) { + return encoder->EncodeHeaderBlock(header_set); + } + + static bool EncodeIncremental(HpackEncoder* encoder, + const Http2HeaderBlock& header_set, + std::string* output) { + std::unique_ptr encoderator = + encoder->EncodeHeaderSet(header_set); + http2::test::Http2Random random; + std::string output_buffer = encoderator->Next(random.UniformInRange(0, 16)); + while (encoderator->HasNext()) { + std::string second_buffer = + encoderator->Next(random.UniformInRange(0, 16)); + output_buffer.append(second_buffer); + } + *output = std::move(output_buffer); + return true; + } + + static bool EncodeRepresentations(HpackEncoder* encoder, + const Representations& representations, + std::string* output) { + std::unique_ptr encoderator = + encoder->EncodeRepresentations(representations); + http2::test::Http2Random random; + std::string output_buffer = encoderator->Next(random.UniformInRange(0, 16)); + while (encoderator->HasNext()) { + std::string second_buffer = + encoderator->Next(random.UniformInRange(0, 16)); + output_buffer.append(second_buffer); + } + *output = std::move(output_buffer); + return true; + } + + private: + HpackEncoder* encoder_; +}; + +} // namespace test + +namespace { + +using testing::ElementsAre; +using testing::Pair; + +const size_t kStaticEntryIndex = 1; + +enum EncodeStrategy { + kDefault, + kIncremental, + kRepresentations, +}; + +class HpackEncoderTest + : public quiche::test::QuicheTestWithParam { + protected: + typedef test::HpackEncoderPeer::Representations Representations; + + HpackEncoderTest() + : peer_(&encoder_), + static_(peer_.table_peer().GetFirstStaticEntry()), + dynamic_table_insertions_(0), + headers_storage_(1024 /* block size */), + strategy_(GetParam()) {} + + void SetUp() override { + // Populate dynamic entries into the table fixture. For simplicity each + // entry has name.size() + value.size() == 10. + key_1_ = peer_.table()->TryAddEntry("key1", "value1"); + key_1_index_ = dynamic_table_insertions_++; + key_2_ = peer_.table()->TryAddEntry("key2", "value2"); + key_2_index_ = dynamic_table_insertions_++; + cookie_a_ = peer_.table()->TryAddEntry("cookie", "a=bb"); + cookie_a_index_ = dynamic_table_insertions_++; + cookie_c_ = peer_.table()->TryAddEntry("cookie", "c=dd"); + cookie_c_index_ = dynamic_table_insertions_++; + + // No further insertions may occur without evictions. + peer_.table()->SetMaxSize(peer_.table()->size()); + QUICHE_CHECK_EQ(kInitialDynamicTableSize, peer_.table()->size()); + } + + void SaveHeaders(absl::string_view name, absl::string_view value) { + absl::string_view n(headers_storage_.Memdup(name.data(), name.size()), + name.size()); + absl::string_view v(headers_storage_.Memdup(value.data(), value.size()), + value.size()); + headers_observed_.push_back(std::make_pair(n, v)); + } + + void ExpectIndex(size_t index) { + expected_.AppendPrefix(kIndexedOpcode); + expected_.AppendUint32(index); + } + void ExpectIndexedLiteral(size_t key_index, absl::string_view value) { + expected_.AppendPrefix(kLiteralIncrementalIndexOpcode); + expected_.AppendUint32(key_index); + ExpectString(&expected_, value); + } + void ExpectIndexedLiteral(absl::string_view name, absl::string_view value) { + expected_.AppendPrefix(kLiteralIncrementalIndexOpcode); + expected_.AppendUint32(0); + ExpectString(&expected_, name); + ExpectString(&expected_, value); + } + void ExpectNonIndexedLiteral(absl::string_view name, + absl::string_view value) { + expected_.AppendPrefix(kLiteralNoIndexOpcode); + expected_.AppendUint32(0); + ExpectString(&expected_, name); + ExpectString(&expected_, value); + } + void ExpectNonIndexedLiteralWithNameIndex(size_t key_index, + absl::string_view value) { + expected_.AppendPrefix(kLiteralNoIndexOpcode); + expected_.AppendUint32(key_index); + ExpectString(&expected_, value); + } + void ExpectString(HpackOutputStream* stream, absl::string_view str) { + size_t encoded_size = + peer_.compression_enabled() ? http2::HuffmanSize(str) : str.size(); + if (encoded_size < str.size()) { + expected_.AppendPrefix(kStringLiteralHuffmanEncoded); + expected_.AppendUint32(encoded_size); + http2::HuffmanEncodeFast(str, encoded_size, stream->MutableString()); + } else { + expected_.AppendPrefix(kStringLiteralIdentityEncoded); + expected_.AppendUint32(str.size()); + expected_.AppendBytes(str); + } + } + void ExpectHeaderTableSizeUpdate(uint32_t size) { + expected_.AppendPrefix(kHeaderTableSizeUpdateOpcode); + expected_.AppendUint32(size); + } + Representations MakeRepresentations(const Http2HeaderBlock& header_set) { + Representations r; + for (const auto& header : header_set) { + r.push_back(header); + } + return r; + } + void CompareWithExpectedEncoding(const Http2HeaderBlock& header_set) { + std::string actual_out; + std::string expected_out = expected_.TakeString(); + switch (strategy_) { + case kDefault: + actual_out = + test::HpackEncoderPeer::EncodeHeaderBlock(&encoder_, header_set); + break; + case kIncremental: + EXPECT_TRUE(test::HpackEncoderPeer::EncodeIncremental( + &encoder_, header_set, &actual_out)); + break; + case kRepresentations: + EXPECT_TRUE(test::HpackEncoderPeer::EncodeRepresentations( + &encoder_, MakeRepresentations(header_set), &actual_out)); + break; + } + EXPECT_EQ(expected_out, actual_out); + } + void CompareWithExpectedEncoding(const Representations& representations) { + std::string actual_out; + std::string expected_out = expected_.TakeString(); + EXPECT_TRUE(test::HpackEncoderPeer::EncodeRepresentations( + &encoder_, representations, &actual_out)); + EXPECT_EQ(expected_out, actual_out); + } + // Converts the index of a dynamic table entry to the HPACK index. + // In these test, dynamic table entries are indexed sequentially, starting + // with 0. The HPACK indexing scheme is defined at + // https://httpwg.org/specs/rfc7541.html#index.address.space. + size_t DynamicIndexToWireIndex(size_t index) { + return dynamic_table_insertions_ - index + kStaticTableSize; + } + + HpackEncoder encoder_; + test::HpackEncoderPeer peer_; + + // Calculated based on the names and values inserted in SetUp(), above. + const size_t kInitialDynamicTableSize = 4 * (10 + 32); + + const HpackEntry* static_; + const HpackEntry* key_1_; + const HpackEntry* key_2_; + const HpackEntry* cookie_a_; + const HpackEntry* cookie_c_; + size_t key_1_index_; + size_t key_2_index_; + size_t cookie_a_index_; + size_t cookie_c_index_; + size_t dynamic_table_insertions_; + + SpdySimpleArena headers_storage_; + std::vector> + headers_observed_; + + HpackOutputStream expected_; + const EncodeStrategy strategy_; +}; + +using HpackEncoderTestWithDefaultStrategy = HpackEncoderTest; + +INSTANTIATE_TEST_SUITE_P(HpackEncoderTests, HpackEncoderTestWithDefaultStrategy, + ::testing::Values(kDefault)); + +TEST_P(HpackEncoderTestWithDefaultStrategy, EncodeRepresentations) { + EXPECT_EQ(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); + encoder_.SetHeaderListener( + [this](absl::string_view name, absl::string_view value) { + this->SaveHeaders(name, value); + }); + const std::vector> + header_list = {{"cookie", "val1; val2;val3"}, + {":path", "/home"}, + {"accept", "text/html, text/plain,application/xml"}, + {"cookie", "val4"}, + {"withnul", absl::string_view("one\0two", 7)}}; + ExpectNonIndexedLiteralWithNameIndex(peer_.table()->GetByName(":path"), + "/home"); + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "val1"); + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "val2"); + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "val3"); + ExpectIndexedLiteral(peer_.table()->GetByName("accept"), + "text/html, text/plain,application/xml"); + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "val4"); + ExpectIndexedLiteral("withnul", absl::string_view("one\0two", 7)); + + CompareWithExpectedEncoding(header_list); + EXPECT_THAT( + headers_observed_, + ElementsAre(Pair(":path", "/home"), Pair("cookie", "val1"), + Pair("cookie", "val2"), Pair("cookie", "val3"), + Pair("accept", "text/html, text/plain,application/xml"), + Pair("cookie", "val4"), + Pair("withnul", absl::string_view("one\0two", 7)))); + // Insertions and evictions have happened over the course of the test. + EXPECT_GE(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); +} + +TEST_P(HpackEncoderTestWithDefaultStrategy, DynamicTableGrows) { + EXPECT_EQ(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); + peer_.table()->SetMaxSize(4096); + encoder_.SetHeaderListener( + [this](absl::string_view name, absl::string_view value) { + this->SaveHeaders(name, value); + }); + const std::vector> + header_list = {{"cookie", "val1; val2;val3"}, + {":path", "/home"}, + {"accept", "text/html, text/plain,application/xml"}, + {"cookie", "val4"}, + {"withnul", absl::string_view("one\0two", 7)}}; + std::string out; + EXPECT_TRUE(test::HpackEncoderPeer::EncodeRepresentations(&encoder_, + header_list, &out)); + + EXPECT_FALSE(out.empty()); + // Insertions have happened over the course of the test. + EXPECT_GT(encoder_.GetDynamicTableSize(), kInitialDynamicTableSize); +} + +INSTANTIATE_TEST_SUITE_P(HpackEncoderTests, HpackEncoderTest, + ::testing::Values(kDefault, kIncremental, + kRepresentations)); + +TEST_P(HpackEncoderTest, SingleDynamicIndex) { + encoder_.SetHeaderListener( + [this](absl::string_view name, absl::string_view value) { + this->SaveHeaders(name, value); + }); + + ExpectIndex(DynamicIndexToWireIndex(key_2_index_)); + + Http2HeaderBlock headers; + headers[key_2_->name()] = key_2_->value(); + CompareWithExpectedEncoding(headers); + EXPECT_THAT(headers_observed_, + ElementsAre(Pair(key_2_->name(), key_2_->value()))); +} + +TEST_P(HpackEncoderTest, SingleStaticIndex) { + ExpectIndex(kStaticEntryIndex); + + Http2HeaderBlock headers; + headers[static_->name()] = static_->value(); + CompareWithExpectedEncoding(headers); +} + +TEST_P(HpackEncoderTest, SingleStaticIndexTooLarge) { + peer_.table()->SetMaxSize(1); // Also evicts all fixtures. + ExpectIndex(kStaticEntryIndex); + + Http2HeaderBlock headers; + headers[static_->name()] = static_->value(); + CompareWithExpectedEncoding(headers); + + EXPECT_EQ(0u, peer_.table_peer().dynamic_entries()->size()); +} + +TEST_P(HpackEncoderTest, SingleLiteralWithIndexName) { + ExpectIndexedLiteral(DynamicIndexToWireIndex(key_2_index_), "value3"); + + Http2HeaderBlock headers; + headers[key_2_->name()] = "value3"; + CompareWithExpectedEncoding(headers); + + // A new entry was inserted and added to the reference set. + HpackEntry* new_entry = peer_.table_peer().dynamic_entries()->front().get(); + EXPECT_EQ(new_entry->name(), key_2_->name()); + EXPECT_EQ(new_entry->value(), "value3"); +} + +TEST_P(HpackEncoderTest, SingleLiteralWithLiteralName) { + ExpectIndexedLiteral("key3", "value3"); + + Http2HeaderBlock headers; + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); + + HpackEntry* new_entry = peer_.table_peer().dynamic_entries()->front().get(); + EXPECT_EQ(new_entry->name(), "key3"); + EXPECT_EQ(new_entry->value(), "value3"); +} + +TEST_P(HpackEncoderTest, SingleLiteralTooLarge) { + peer_.table()->SetMaxSize(1); // Also evicts all fixtures. + + ExpectIndexedLiteral("key3", "value3"); + + // A header overflowing the header table is still emitted. + // The header table is empty. + Http2HeaderBlock headers; + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); + + EXPECT_EQ(0u, peer_.table_peer().dynamic_entries()->size()); +} + +TEST_P(HpackEncoderTest, EmitThanEvict) { + // |key_1_| is toggled and placed into the reference set, + // and then immediately evicted by "key3". + ExpectIndex(DynamicIndexToWireIndex(key_1_index_)); + ExpectIndexedLiteral("key3", "value3"); + + Http2HeaderBlock headers; + headers[key_1_->name()] = key_1_->value(); + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); +} + +TEST_P(HpackEncoderTest, CookieHeaderIsCrumbled) { + ExpectIndex(DynamicIndexToWireIndex(cookie_a_index_)); + ExpectIndex(DynamicIndexToWireIndex(cookie_c_index_)); + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "e=ff"); + + Http2HeaderBlock headers; + headers["cookie"] = "a=bb; c=dd; e=ff"; + CompareWithExpectedEncoding(headers); +} + +TEST_P(HpackEncoderTest, MultiValuedHeadersNotCrumbled) { + ExpectIndexedLiteral("foo", "bar, baz"); + Http2HeaderBlock headers; + headers["foo"] = "bar, baz"; + CompareWithExpectedEncoding(headers); +} + +TEST_P(HpackEncoderTest, StringsDynamicallySelectHuffmanCoding) { + // Compactable string. Uses Huffman coding. + peer_.EmitString("feedbeef"); + expected_.AppendPrefix(kStringLiteralHuffmanEncoded); + expected_.AppendUint32(6); + expected_.AppendBytes("\x94\xA5\x92\x32\x96_"); + + // Non-compactable. Uses identity coding. + peer_.EmitString("@@@@@@"); + expected_.AppendPrefix(kStringLiteralIdentityEncoded); + expected_.AppendUint32(6); + expected_.AppendBytes("@@@@@@"); + + std::string actual_out; + std::string expected_out = expected_.TakeString(); + peer_.TakeString(&actual_out); + EXPECT_EQ(expected_out, actual_out); +} + +TEST_P(HpackEncoderTest, EncodingWithoutCompression) { + encoder_.SetHeaderListener( + [this](absl::string_view name, absl::string_view value) { + this->SaveHeaders(name, value); + }); + encoder_.DisableCompression(); + + ExpectNonIndexedLiteral(":path", "/index.html"); + ExpectNonIndexedLiteral("cookie", "foo=bar"); + ExpectNonIndexedLiteral("cookie", "baz=bing"); + if (strategy_ == kRepresentations) { + ExpectNonIndexedLiteral("hello", std::string("goodbye\0aloha", 13)); + } else { + ExpectNonIndexedLiteral("hello", "goodbye"); + ExpectNonIndexedLiteral("hello", "aloha"); + } + ExpectNonIndexedLiteral("multivalue", "value1, value2"); + + Http2HeaderBlock headers; + headers[":path"] = "/index.html"; + headers["cookie"] = "foo=bar; baz=bing"; + headers["hello"] = "goodbye"; + headers.AppendValueOrAddHeader("hello", "aloha"); + headers["multivalue"] = "value1, value2"; + + CompareWithExpectedEncoding(headers); + + if (strategy_ == kRepresentations) { + EXPECT_THAT( + headers_observed_, + ElementsAre(Pair(":path", "/index.html"), Pair("cookie", "foo=bar"), + Pair("cookie", "baz=bing"), + Pair("hello", absl::string_view("goodbye\0aloha", 13)), + Pair("multivalue", "value1, value2"))); + } else { + EXPECT_THAT( + headers_observed_, + ElementsAre(Pair(":path", "/index.html"), Pair("cookie", "foo=bar"), + Pair("cookie", "baz=bing"), Pair("hello", "goodbye"), + Pair("hello", "aloha"), + Pair("multivalue", "value1, value2"))); + } + EXPECT_EQ(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); +} + +TEST_P(HpackEncoderTest, MultipleEncodingPasses) { + encoder_.SetHeaderListener( + [this](absl::string_view name, absl::string_view value) { + this->SaveHeaders(name, value); + }); + + // Pass 1. + { + Http2HeaderBlock headers; + headers["key1"] = "value1"; + headers["cookie"] = "a=bb"; + + ExpectIndex(DynamicIndexToWireIndex(key_1_index_)); + ExpectIndex(DynamicIndexToWireIndex(cookie_a_index_)); + CompareWithExpectedEncoding(headers); + } + // Header table is: + // 65: key1: value1 + // 64: key2: value2 + // 63: cookie: a=bb + // 62: cookie: c=dd + // Pass 2. + { + Http2HeaderBlock headers; + headers["key2"] = "value2"; + headers["cookie"] = "c=dd; e=ff"; + + // "key2: value2" + ExpectIndex(DynamicIndexToWireIndex(key_2_index_)); + // "cookie: c=dd" + ExpectIndex(DynamicIndexToWireIndex(cookie_c_index_)); + // This cookie evicts |key1| from the dynamic table. + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "e=ff"); + dynamic_table_insertions_++; + + CompareWithExpectedEncoding(headers); + } + // Header table is: + // 65: key2: value2 + // 64: cookie: a=bb + // 63: cookie: c=dd + // 62: cookie: e=ff + // Pass 3. + { + Http2HeaderBlock headers; + headers["key2"] = "value2"; + headers["cookie"] = "a=bb; b=cc; c=dd"; + + // "key2: value2" + EXPECT_EQ(65u, DynamicIndexToWireIndex(key_2_index_)); + ExpectIndex(DynamicIndexToWireIndex(key_2_index_)); + // "cookie: a=bb" + EXPECT_EQ(64u, DynamicIndexToWireIndex(cookie_a_index_)); + ExpectIndex(DynamicIndexToWireIndex(cookie_a_index_)); + // This cookie evicts |key2| from the dynamic table. + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "b=cc"); + dynamic_table_insertions_++; + // "cookie: c=dd" + ExpectIndex(DynamicIndexToWireIndex(cookie_c_index_)); + + CompareWithExpectedEncoding(headers); + } + + // clang-format off + EXPECT_THAT(headers_observed_, + ElementsAre(Pair("key1", "value1"), + Pair("cookie", "a=bb"), + Pair("key2", "value2"), + Pair("cookie", "c=dd"), + Pair("cookie", "e=ff"), + Pair("key2", "value2"), + Pair("cookie", "a=bb"), + Pair("cookie", "b=cc"), + Pair("cookie", "c=dd"))); + // clang-format on +} + +TEST_P(HpackEncoderTest, PseudoHeadersFirst) { + Http2HeaderBlock headers; + // A pseudo-header that should not be indexed. + headers[":path"] = "/spam/eggs.html"; + // A pseudo-header to be indexed. + headers[":authority"] = "www.example.com"; + // A regular header which precedes ":" alphabetically, should still be encoded + // after pseudo-headers. + headers["-foo"] = "bar"; + headers["foo"] = "bar"; + headers["cookie"] = "c=dd"; + + // Headers are indexed in the order in which they were added. + // This entry pushes "cookie: a=bb" back to 63. + ExpectNonIndexedLiteralWithNameIndex(peer_.table()->GetByName(":path"), + "/spam/eggs.html"); + ExpectIndexedLiteral(peer_.table()->GetByName(":authority"), + "www.example.com"); + ExpectIndexedLiteral("-foo", "bar"); + ExpectIndexedLiteral("foo", "bar"); + ExpectIndexedLiteral(peer_.table()->GetByName("cookie"), "c=dd"); + CompareWithExpectedEncoding(headers); +} + +TEST_P(HpackEncoderTest, CookieToCrumbs) { + test::HpackEncoderPeer peer(nullptr); + std::vector out; + + // Leading and trailing whitespace is consumed. A space after ';' is consumed. + // All other spaces remain. ';' at beginning and end of string produce empty + // crumbs. + // See section 8.1.3.4 "Compressing the Cookie Header Field" in the HTTP/2 + // specification at http://tools.ietf.org/html/draft-ietf-httpbis-http2-11 + peer.CookieToCrumbs(" foo=1;bar=2 ; bar=3; bing=4; ", &out); + EXPECT_THAT(out, ElementsAre("foo=1", "bar=2 ", "bar=3", " bing=4", "")); + + peer.CookieToCrumbs(";;foo = bar ;; ;baz =bing", &out); + EXPECT_THAT(out, ElementsAre("", "", "foo = bar ", "", "", "baz =bing")); + + peer.CookieToCrumbs("baz=bing; foo=bar; baz=bing", &out); + EXPECT_THAT(out, ElementsAre("baz=bing", "foo=bar", "baz=bing")); + + peer.CookieToCrumbs("baz=bing", &out); + EXPECT_THAT(out, ElementsAre("baz=bing")); + + peer.CookieToCrumbs("", &out); + EXPECT_THAT(out, ElementsAre("")); + + peer.CookieToCrumbs("foo;bar; baz;baz;bing;", &out); + EXPECT_THAT(out, ElementsAre("foo", "bar", "baz", "baz", "bing", "")); + + peer.CookieToCrumbs(" \t foo=1;bar=2 ; bar=3;\t ", &out); + EXPECT_THAT(out, ElementsAre("foo=1", "bar=2 ", "bar=3", "")); + + peer.CookieToCrumbs(" \t foo=1;bar=2 ; bar=3 \t ", &out); + EXPECT_THAT(out, ElementsAre("foo=1", "bar=2 ", "bar=3")); +} + +TEST_P(HpackEncoderTest, DecomposeRepresentation) { + test::HpackEncoderPeer peer(nullptr); + std::vector out; + + peer.DecomposeRepresentation("", &out); + EXPECT_THAT(out, ElementsAre("")); + + peer.DecomposeRepresentation("foobar", &out); + EXPECT_THAT(out, ElementsAre("foobar")); + + peer.DecomposeRepresentation(absl::string_view("foo\0bar", 7), &out); + EXPECT_THAT(out, ElementsAre("foo", "bar")); + + peer.DecomposeRepresentation(absl::string_view("\0foo\0bar", 8), &out); + EXPECT_THAT(out, ElementsAre("", "foo", "bar")); + + peer.DecomposeRepresentation(absl::string_view("foo\0bar\0", 8), &out); + EXPECT_THAT(out, ElementsAre("foo", "bar", "")); + + peer.DecomposeRepresentation(absl::string_view("\0foo\0bar\0", 9), &out); + EXPECT_THAT(out, ElementsAre("", "foo", "bar", "")); +} + +// Test that encoded headers do not have \0-delimited multiple values, as this +// became disallowed in HTTP/2 draft-14. +TEST_P(HpackEncoderTest, CrumbleNullByteDelimitedValue) { + if (strategy_ == kRepresentations) { + // When HpackEncoder is asked to encode a list of Representations, the + // caller must crumble null-delimited values. + return; + } + Http2HeaderBlock headers; + // A header field to be crumbled: "spam: foo\0bar". + headers["spam"] = std::string("foo\0bar", 7); + + ExpectIndexedLiteral("spam", "foo"); + expected_.AppendPrefix(kLiteralIncrementalIndexOpcode); + expected_.AppendUint32(62); + expected_.AppendPrefix(kStringLiteralIdentityEncoded); + expected_.AppendUint32(3); + expected_.AppendBytes("bar"); + CompareWithExpectedEncoding(headers); +} + +TEST_P(HpackEncoderTest, HeaderTableSizeUpdate) { + encoder_.ApplyHeaderTableSizeSetting(1024); + ExpectHeaderTableSizeUpdate(1024); + ExpectIndexedLiteral("key3", "value3"); + + Http2HeaderBlock headers; + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); + + HpackEntry* new_entry = peer_.table_peer().dynamic_entries()->front().get(); + EXPECT_EQ(new_entry->name(), "key3"); + EXPECT_EQ(new_entry->value(), "value3"); +} + +TEST_P(HpackEncoderTest, HeaderTableSizeUpdateWithMin) { + const size_t starting_size = peer_.table()->settings_size_bound(); + encoder_.ApplyHeaderTableSizeSetting(starting_size - 2); + encoder_.ApplyHeaderTableSizeSetting(starting_size - 1); + // We must encode the low watermark, so the peer knows to evict entries + // if necessary. + ExpectHeaderTableSizeUpdate(starting_size - 2); + ExpectHeaderTableSizeUpdate(starting_size - 1); + ExpectIndexedLiteral("key3", "value3"); + + Http2HeaderBlock headers; + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); + + HpackEntry* new_entry = peer_.table_peer().dynamic_entries()->front().get(); + EXPECT_EQ(new_entry->name(), "key3"); + EXPECT_EQ(new_entry->value(), "value3"); +} + +TEST_P(HpackEncoderTest, HeaderTableSizeUpdateWithExistingSize) { + encoder_.ApplyHeaderTableSizeSetting(peer_.table()->settings_size_bound()); + // No encoded size update. + ExpectIndexedLiteral("key3", "value3"); + + Http2HeaderBlock headers; + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); + + HpackEntry* new_entry = peer_.table_peer().dynamic_entries()->front().get(); + EXPECT_EQ(new_entry->name(), "key3"); + EXPECT_EQ(new_entry->value(), "value3"); +} + +TEST_P(HpackEncoderTest, HeaderTableSizeUpdatesWithGreaterSize) { + const size_t starting_size = peer_.table()->settings_size_bound(); + encoder_.ApplyHeaderTableSizeSetting(starting_size + 1); + encoder_.ApplyHeaderTableSizeSetting(starting_size + 2); + // Only a single size update to the final size. + ExpectHeaderTableSizeUpdate(starting_size + 2); + ExpectIndexedLiteral("key3", "value3"); + + Http2HeaderBlock headers; + headers["key3"] = "value3"; + CompareWithExpectedEncoding(headers); + + HpackEntry* new_entry = peer_.table_peer().dynamic_entries()->front().get(); + EXPECT_EQ(new_entry->name(), "key3"); + EXPECT_EQ(new_entry->value(), "value3"); +} + +} // namespace + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_entry.cc b/quiche/spdy/core/hpack/hpack_entry.cc new file mode 100644 index 000000000000..437b5d04f16f --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_entry.cc @@ -0,0 +1,24 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_entry.h" + +#include "absl/strings/str_cat.h" + +namespace spdy { + +HpackEntry::HpackEntry(std::string name, std::string value) + : name_(std::move(name)), value_(std::move(value)) {} + +// static +size_t HpackEntry::Size(absl::string_view name, absl::string_view value) { + return name.size() + value.size() + kHpackEntrySizeOverhead; +} +size_t HpackEntry::Size() const { return Size(name(), value()); } + +std::string HpackEntry::GetDebugString() const { + return absl::StrCat("{ name: \"", name_, "\", value: \"", value_, "\" }"); +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_entry.h b/quiche/spdy/core/hpack/hpack_entry.h new file mode 100644 index 000000000000..b57203db9214 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_entry.h @@ -0,0 +1,81 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_ENTRY_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_ENTRY_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +// All section references below are to +// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-08 + +namespace spdy { + +// The constant amount added to name().size() and value().size() to +// get the size of an HpackEntry as defined in 5.1. +constexpr size_t kHpackEntrySizeOverhead = 32; + +// A structure for looking up entries in the static and dynamic tables. +struct QUICHE_EXPORT HpackLookupEntry { + absl::string_view name; + absl::string_view value; + + bool operator==(const HpackLookupEntry& other) const { + return name == other.name && value == other.value; + } + + // Abseil hashing framework extension according to absl/hash/hash.h: + template + friend H AbslHashValue(H h, const HpackLookupEntry& entry) { + return H::combine(std::move(h), entry.name, entry.value); + } +}; + +// A structure for an entry in the static table (3.3.1) +// and the header table (3.3.2). +class QUICHE_EXPORT HpackEntry { + public: + HpackEntry(std::string name, std::string value); + + // Make HpackEntry non-copyable to make sure it is always moved. + HpackEntry(const HpackEntry&) = delete; + HpackEntry& operator=(const HpackEntry&) = delete; + + HpackEntry(HpackEntry&&) = default; + HpackEntry& operator=(HpackEntry&&) = default; + + // Getters for std::string members traditionally return const std::string&. + // However, HpackHeaderTable uses string_view as keys in the maps + // static_name_index_ and dynamic_name_index_. If HpackEntry::name() returned + // const std::string&, then + // dynamic_name_index_.insert(std::make_pair(entry.name(), index)); + // would silently create a dangling reference: make_pair infers type from the + // return type of entry.name() and silently creates a temporary string copy. + // Insert creates a string_view that points to this copy, which then + // immediately goes out of scope and gets destroyed. While this is quite easy + // to avoid, for example, by explicitly specifying type as a template + // parameter to make_pair, returning string_view here is less error-prone. + absl::string_view name() const { return name_; } + absl::string_view value() const { return value_; } + + // Returns the size of an entry as defined in 5.1. + static size_t Size(absl::string_view name, absl::string_view value); + size_t Size() const; + + std::string GetDebugString() const; + + private: + std::string name_; + std::string value_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_ENTRY_H_ diff --git a/quiche/spdy/core/hpack/hpack_entry_test.cc b/quiche/spdy/core/hpack/hpack_entry_test.cc new file mode 100644 index 000000000000..faf77862bfd7 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_entry_test.cc @@ -0,0 +1,53 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_entry.h" + +#include "absl/hash/hash.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { + +namespace { + +TEST(HpackLookupEntryTest, EntryNamesDiffer) { + HpackLookupEntry entry1{"header", "value"}; + HpackLookupEntry entry2{"HEADER", "value"}; + + EXPECT_FALSE(entry1 == entry2); + EXPECT_NE(absl::Hash()(entry1), + absl::Hash()(entry2)); +} + +TEST(HpackLookupEntryTest, EntryValuesDiffer) { + HpackLookupEntry entry1{"header", "value"}; + HpackLookupEntry entry2{"header", "VALUE"}; + + EXPECT_FALSE(entry1 == entry2); + EXPECT_NE(absl::Hash()(entry1), + absl::Hash()(entry2)); +} + +TEST(HpackLookupEntryTest, EntriesEqual) { + HpackLookupEntry entry1{"name", "value"}; + HpackLookupEntry entry2{"name", "value"}; + + EXPECT_TRUE(entry1 == entry2); + EXPECT_EQ(absl::Hash()(entry1), + absl::Hash()(entry2)); +} + +TEST(HpackEntryTest, BasicEntry) { + HpackEntry entry("header-name", "header value"); + + EXPECT_EQ("header-name", entry.name()); + EXPECT_EQ("header value", entry.value()); + + EXPECT_EQ(55u, entry.Size()); + EXPECT_EQ(55u, HpackEntry::Size("header-name", "header value")); +} + +} // namespace + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_header_table.cc b/quiche/spdy/core/hpack/hpack_header_table.cc new file mode 100644 index 000000000000..e9bc9fc3b00e --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_header_table.cc @@ -0,0 +1,188 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_header_table.h" + +#include + +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_static_table.h" + +namespace spdy { + +HpackHeaderTable::HpackHeaderTable() + : static_entries_(ObtainHpackStaticTable().GetStaticEntries()), + static_index_(ObtainHpackStaticTable().GetStaticIndex()), + static_name_index_(ObtainHpackStaticTable().GetStaticNameIndex()), + settings_size_bound_(kDefaultHeaderTableSizeSetting), + size_(0), + max_size_(kDefaultHeaderTableSizeSetting), + dynamic_table_insertions_(0) {} + +HpackHeaderTable::~HpackHeaderTable() = default; + +size_t HpackHeaderTable::GetByName(absl::string_view name) { + { + auto it = static_name_index_.find(name); + if (it != static_name_index_.end()) { + return 1 + it->second; + } + } + { + NameToEntryMap::const_iterator it = dynamic_name_index_.find(name); + if (it != dynamic_name_index_.end()) { + return dynamic_table_insertions_ - it->second + kStaticTableSize; + } + } + return kHpackEntryNotFound; +} + +size_t HpackHeaderTable::GetByNameAndValue(absl::string_view name, + absl::string_view value) { + HpackLookupEntry query{name, value}; + { + auto it = static_index_.find(query); + if (it != static_index_.end()) { + return 1 + it->second; + } + } + { + auto it = dynamic_index_.find(query); + if (it != dynamic_index_.end()) { + return dynamic_table_insertions_ - it->second + kStaticTableSize; + } + } + return kHpackEntryNotFound; +} + +void HpackHeaderTable::SetMaxSize(size_t max_size) { + QUICHE_CHECK_LE(max_size, settings_size_bound_); + + max_size_ = max_size; + if (size_ > max_size_) { + Evict(EvictionCountToReclaim(size_ - max_size_)); + QUICHE_CHECK_LE(size_, max_size_); + } +} + +void HpackHeaderTable::SetSettingsHeaderTableSize(size_t settings_size) { + settings_size_bound_ = settings_size; + SetMaxSize(settings_size_bound_); +} + +void HpackHeaderTable::EvictionSet(absl::string_view name, + absl::string_view value, + DynamicEntryTable::iterator* begin_out, + DynamicEntryTable::iterator* end_out) { + size_t eviction_count = EvictionCountForEntry(name, value); + *begin_out = dynamic_entries_.end() - eviction_count; + *end_out = dynamic_entries_.end(); +} + +size_t HpackHeaderTable::EvictionCountForEntry(absl::string_view name, + absl::string_view value) const { + size_t available_size = max_size_ - size_; + size_t entry_size = HpackEntry::Size(name, value); + + if (entry_size <= available_size) { + // No evictions are required. + return 0; + } + return EvictionCountToReclaim(entry_size - available_size); +} + +size_t HpackHeaderTable::EvictionCountToReclaim(size_t reclaim_size) const { + size_t count = 0; + for (auto it = dynamic_entries_.rbegin(); + it != dynamic_entries_.rend() && reclaim_size != 0; ++it, ++count) { + reclaim_size -= std::min(reclaim_size, (*it)->Size()); + } + return count; +} + +void HpackHeaderTable::Evict(size_t count) { + for (size_t i = 0; i != count; ++i) { + QUICHE_CHECK(!dynamic_entries_.empty()); + + HpackEntry* entry = dynamic_entries_.back().get(); + const size_t index = dynamic_table_insertions_ - dynamic_entries_.size(); + + size_ -= entry->Size(); + auto it = dynamic_index_.find({entry->name(), entry->value()}); + QUICHE_DCHECK(it != dynamic_index_.end()); + // Only remove an entry from the index if its insertion index matches; + // otherwise, the index refers to another entry with the same name and + // value. + if (it->second == index) { + dynamic_index_.erase(it); + } + auto name_it = dynamic_name_index_.find(entry->name()); + QUICHE_DCHECK(name_it != dynamic_name_index_.end()); + // Only remove an entry from the literal index if its insertion index + /// matches; otherwise, the index refers to another entry with the same + // name. + if (name_it->second == index) { + dynamic_name_index_.erase(name_it); + } + dynamic_entries_.pop_back(); + } +} + +const HpackEntry* HpackHeaderTable::TryAddEntry(absl::string_view name, + absl::string_view value) { + // Since |dynamic_entries_| has iterator stability, |name| and |value| are + // valid even after evicting other entries and push_front() making room for + // the new one. + Evict(EvictionCountForEntry(name, value)); + + size_t entry_size = HpackEntry::Size(name, value); + if (entry_size > (max_size_ - size_)) { + // Entire table has been emptied, but there's still insufficient room. + QUICHE_DCHECK(dynamic_entries_.empty()); + QUICHE_DCHECK_EQ(0u, size_); + return nullptr; + } + + const size_t index = dynamic_table_insertions_; + dynamic_entries_.push_front( + std::make_unique(std::string(name), std::string(value))); + HpackEntry* new_entry = dynamic_entries_.front().get(); + auto index_result = dynamic_index_.insert(std::make_pair( + HpackLookupEntry{new_entry->name(), new_entry->value()}, index)); + if (!index_result.second) { + // An entry with the same name and value already exists in the dynamic + // index. We should replace it with the newly added entry. + QUICHE_DVLOG(1) << "Found existing entry at: " << index_result.first->second + << " replacing with: " << new_entry->GetDebugString() + << " at: " << index; + QUICHE_DCHECK_GT(index, index_result.first->second); + dynamic_index_.erase(index_result.first); + auto insert_result = dynamic_index_.insert(std::make_pair( + HpackLookupEntry{new_entry->name(), new_entry->value()}, index)); + QUICHE_CHECK(insert_result.second); + } + + auto name_result = + dynamic_name_index_.insert(std::make_pair(new_entry->name(), index)); + if (!name_result.second) { + // An entry with the same name already exists in the dynamic index. We + // should replace it with the newly added entry. + QUICHE_DVLOG(1) << "Found existing entry at: " << name_result.first->second + << " replacing with: " << new_entry->GetDebugString() + << " at: " << index; + QUICHE_DCHECK_GT(index, name_result.first->second); + dynamic_name_index_.erase(name_result.first); + auto insert_result = + dynamic_name_index_.insert(std::make_pair(new_entry->name(), index)); + QUICHE_CHECK(insert_result.second); + } + + size_ += entry_size; + ++dynamic_table_insertions_; + + return dynamic_entries_.front().get(); +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_header_table.h b/quiche/spdy/core/hpack/hpack_header_table.h new file mode 100644 index 000000000000..a22b762a761c --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_header_table.h @@ -0,0 +1,153 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_HEADER_TABLE_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_HEADER_TABLE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_circular_deque.h" +#include "quiche/spdy/core/hpack/hpack_entry.h" + +// All section references below are to http://tools.ietf.org/html/rfc7541. + +namespace spdy { + +namespace test { +class HpackHeaderTablePeer; +} // namespace test + +// Return value of GetByName() and GetByNameAndValue() if matching entry is not +// found. This value is never used in HPACK for indexing entries, see +// https://httpwg.org/specs/rfc7541.html#index.address.space. +constexpr size_t kHpackEntryNotFound = 0; + +// A data structure for the static table (2.3.1) and the dynamic table (2.3.2). +class QUICHE_EXPORT HpackHeaderTable { + public: + friend class test::HpackHeaderTablePeer; + + // Use a lightweight, memory efficient container for the static table, which + // is initialized once and never changed after. + using StaticEntryTable = std::vector; + + // HpackHeaderTable takes advantage of the deque property that references + // remain valid, so long as insertions & deletions are at the head & tail. + using DynamicEntryTable = + quiche::QuicheCircularDeque>; + + using NameValueToEntryMap = absl::flat_hash_map; + using NameToEntryMap = absl::flat_hash_map; + + HpackHeaderTable(); + HpackHeaderTable(const HpackHeaderTable&) = delete; + HpackHeaderTable& operator=(const HpackHeaderTable&) = delete; + + ~HpackHeaderTable(); + + // Last-acknowledged value of SETTINGS_HEADER_TABLE_SIZE. + size_t settings_size_bound() const { return settings_size_bound_; } + + // Current and maximum estimated byte size of the table, as described in + // 4.1. Notably, this is /not/ the number of entries in the table. + size_t size() const { return size_; } + size_t max_size() const { return max_size_; } + + // The HPACK indexing scheme used by GetByName() and GetByNameAndValue() is + // defined at https://httpwg.org/specs/rfc7541.html#index.address.space. + + // Returns the index of the lowest-index entry matching |name|, + // or kHpackEntryNotFound if no matching entry is found. + size_t GetByName(absl::string_view name); + + // Returns the index of the lowest-index entry matching |name| and |value|, + // or kHpackEntryNotFound if no matching entry is found. + size_t GetByNameAndValue(absl::string_view name, absl::string_view value); + + // Sets the maximum size of the header table, evicting entries if + // necessary as described in 5.2. + void SetMaxSize(size_t max_size); + + // Sets the SETTINGS_HEADER_TABLE_SIZE bound of the table. Will call + // SetMaxSize() as needed to preserve max_size() <= settings_size_bound(). + void SetSettingsHeaderTableSize(size_t settings_size); + + // Determine the set of entries which would be evicted by the insertion + // of |name| & |value| into the table, as per section 4.4. No eviction + // actually occurs. The set is returned via range [begin_out, end_out). + void EvictionSet(absl::string_view name, absl::string_view value, + DynamicEntryTable::iterator* begin_out, + DynamicEntryTable::iterator* end_out); + + // Adds an entry for the representation, evicting entries as needed. |name| + // and |value| must not point to an entry in |dynamic_entries_| which is about + // to be evicted, but they may point to an entry which is not. + // The added HpackEntry is returned, or NULL is returned if all entries were + // evicted and the empty table is of insufficent size for the representation. + const HpackEntry* TryAddEntry(absl::string_view name, + absl::string_view value); + + private: + // Returns number of evictions required to enter |name| & |value|. + size_t EvictionCountForEntry(absl::string_view name, + absl::string_view value) const; + + // Returns number of evictions required to reclaim |reclaim_size| table size. + size_t EvictionCountToReclaim(size_t reclaim_size) const; + + // Evicts |count| oldest entries from the table. + void Evict(size_t count); + + // |static_entries_|, |static_index_|, and |static_name_index_| are owned by + // HpackStaticTable singleton. + + // Stores HpackEntries. + const StaticEntryTable& static_entries_; + DynamicEntryTable dynamic_entries_; + + // Tracks the index of the unique HpackEntry for a given header name and + // value. Keys consist of string_views that point to strings stored in + // |static_entries_|. + const NameValueToEntryMap& static_index_; + + // Tracks the index of the first static entry for each name in the static + // table. Each key is a string_view that points to a name string stored in + // |static_entries_|. + const NameToEntryMap& static_name_index_; + + // Tracks the index of the most recently inserted HpackEntry for a given + // header name and value. Keys consist of string_views that point to strings + // stored in |dynamic_entries_|. + NameValueToEntryMap dynamic_index_; + + // Tracks the index of the most recently inserted HpackEntry for a given + // header name. Each key is a string_view that points to a name string stored + // in |dynamic_entries_|. + NameToEntryMap dynamic_name_index_; + + // Last acknowledged value for SETTINGS_HEADER_TABLE_SIZE. + size_t settings_size_bound_; + + // Estimated current and maximum byte size of the table. + // |max_size_| <= |settings_size_bound_| + size_t size_; + size_t max_size_; + + // Total number of dynamic table insertions so far + // (including entries that have been evicted). + size_t dynamic_table_insertions_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_HEADER_TABLE_H_ diff --git a/quiche/spdy/core/hpack/hpack_header_table_test.cc b/quiche/spdy/core/hpack/hpack_header_table_test.cc new file mode 100644 index 000000000000..7886b2d23d14 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_header_table_test.cc @@ -0,0 +1,392 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_header_table.h" + +#include +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_entry.h" +#include "quiche/spdy/core/hpack/hpack_static_table.h" + +namespace spdy { + +using std::distance; + +namespace test { + +class HpackHeaderTablePeer { + public: + explicit HpackHeaderTablePeer(HpackHeaderTable* table) : table_(table) {} + + const HpackHeaderTable::DynamicEntryTable& dynamic_entries() { + return table_->dynamic_entries_; + } + const HpackHeaderTable::StaticEntryTable& static_entries() { + return table_->static_entries_; + } + const HpackEntry* GetFirstStaticEntry() { + return &table_->static_entries_.front(); + } + const HpackEntry* GetLastStaticEntry() { + return &table_->static_entries_.back(); + } + std::vector EvictionSet(absl::string_view name, + absl::string_view value) { + HpackHeaderTable::DynamicEntryTable::iterator begin, end; + table_->EvictionSet(name, value, &begin, &end); + std::vector result; + for (; begin != end; ++begin) { + result.push_back(begin->get()); + } + return result; + } + size_t dynamic_table_insertions() { + return table_->dynamic_table_insertions_; + } + size_t EvictionCountForEntry(absl::string_view name, + absl::string_view value) { + return table_->EvictionCountForEntry(name, value); + } + size_t EvictionCountToReclaim(size_t reclaim_size) { + return table_->EvictionCountToReclaim(reclaim_size); + } + void Evict(size_t count) { return table_->Evict(count); } + + private: + HpackHeaderTable* table_; +}; + +} // namespace test + +namespace { + +class HpackHeaderTableTest : public quiche::test::QuicheTest { + protected: + typedef std::vector HpackEntryVector; + + HpackHeaderTableTest() : table_(), peer_(&table_) {} + + // Returns an entry whose Size() is equal to the given one. + static HpackEntry MakeEntryOfSize(uint32_t size) { + EXPECT_GE(size, kHpackEntrySizeOverhead); + std::string name((size - kHpackEntrySizeOverhead) / 2, 'n'); + std::string value(size - kHpackEntrySizeOverhead - name.size(), 'v'); + HpackEntry entry(name, value); + EXPECT_EQ(size, entry.Size()); + return entry; + } + + // Returns a vector of entries whose total size is equal to the given + // one. + static HpackEntryVector MakeEntriesOfTotalSize(uint32_t total_size) { + EXPECT_GE(total_size, kHpackEntrySizeOverhead); + uint32_t entry_size = kHpackEntrySizeOverhead; + uint32_t remaining_size = total_size; + HpackEntryVector entries; + while (remaining_size > 0) { + EXPECT_LE(entry_size, remaining_size); + entries.push_back(MakeEntryOfSize(entry_size)); + remaining_size -= entry_size; + entry_size = std::min(remaining_size, entry_size + 32); + } + return entries; + } + + // Adds the given vector of entries to the given header table, + // expecting no eviction to happen. + void AddEntriesExpectNoEviction(const HpackEntryVector& entries) { + for (auto it = entries.begin(); it != entries.end(); ++it) { + HpackHeaderTable::DynamicEntryTable::iterator begin, end; + + table_.EvictionSet(it->name(), it->value(), &begin, &end); + EXPECT_EQ(0, distance(begin, end)); + + const HpackEntry* entry = table_.TryAddEntry(it->name(), it->value()); + EXPECT_NE(entry, static_cast(nullptr)); + } + } + + HpackHeaderTable table_; + test::HpackHeaderTablePeer peer_; +}; + +TEST_F(HpackHeaderTableTest, StaticTableInitialization) { + EXPECT_EQ(0u, table_.size()); + EXPECT_EQ(kDefaultHeaderTableSizeSetting, table_.max_size()); + EXPECT_EQ(kDefaultHeaderTableSizeSetting, table_.settings_size_bound()); + + EXPECT_EQ(0u, peer_.dynamic_entries().size()); + EXPECT_EQ(0u, peer_.dynamic_table_insertions()); + + // Static entries have been populated and inserted into the table & index. + const HpackHeaderTable::StaticEntryTable& static_entries = + peer_.static_entries(); + EXPECT_EQ(kStaticTableSize, static_entries.size()); + // HPACK indexing scheme is 1-based. + size_t index = 1; + for (const HpackEntry& entry : static_entries) { + EXPECT_EQ(index, table_.GetByNameAndValue(entry.name(), entry.value())); + index++; + } +} + +TEST_F(HpackHeaderTableTest, BasicDynamicEntryInsertionAndEviction) { + EXPECT_EQ(kStaticTableSize, peer_.static_entries().size()); + + const HpackEntry* first_static_entry = peer_.GetFirstStaticEntry(); + const HpackEntry* last_static_entry = peer_.GetLastStaticEntry(); + + const HpackEntry* entry = table_.TryAddEntry("header-key", "Header Value"); + EXPECT_EQ("header-key", entry->name()); + EXPECT_EQ("Header Value", entry->value()); + + // Table counts were updated appropriately. + EXPECT_EQ(entry->Size(), table_.size()); + EXPECT_EQ(1u, peer_.dynamic_entries().size()); + EXPECT_EQ(kStaticTableSize, peer_.static_entries().size()); + + EXPECT_EQ(62u, table_.GetByNameAndValue("header-key", "Header Value")); + + // Index of static entries does not change. + EXPECT_EQ(first_static_entry, peer_.GetFirstStaticEntry()); + EXPECT_EQ(last_static_entry, peer_.GetLastStaticEntry()); + + // Evict |entry|. Table counts are again updated appropriately. + peer_.Evict(1); + EXPECT_EQ(0u, table_.size()); + EXPECT_EQ(0u, peer_.dynamic_entries().size()); + EXPECT_EQ(kStaticTableSize, peer_.static_entries().size()); + + // Index of static entries does not change. + EXPECT_EQ(first_static_entry, peer_.GetFirstStaticEntry()); + EXPECT_EQ(last_static_entry, peer_.GetLastStaticEntry()); +} + +TEST_F(HpackHeaderTableTest, EntryIndexing) { + const HpackEntry* first_static_entry = peer_.GetFirstStaticEntry(); + const HpackEntry* last_static_entry = peer_.GetLastStaticEntry(); + + // Static entries are queryable by name & value. + EXPECT_EQ(1u, table_.GetByName(first_static_entry->name())); + EXPECT_EQ(1u, table_.GetByNameAndValue(first_static_entry->name(), + first_static_entry->value())); + + // Create a mix of entries which duplicate names, and names & values of both + // dynamic and static entries. + table_.TryAddEntry(first_static_entry->name(), first_static_entry->value()); + table_.TryAddEntry(first_static_entry->name(), "Value Four"); + table_.TryAddEntry("key-1", "Value One"); + table_.TryAddEntry("key-2", "Value Three"); + table_.TryAddEntry("key-1", "Value Two"); + table_.TryAddEntry("key-2", "Value Three"); + table_.TryAddEntry("key-2", "Value Four"); + + // The following entry is identical to the one at index 68. The smaller index + // is returned by GetByNameAndValue(). + EXPECT_EQ(1u, table_.GetByNameAndValue(first_static_entry->name(), + first_static_entry->value())); + EXPECT_EQ(67u, + table_.GetByNameAndValue(first_static_entry->name(), "Value Four")); + EXPECT_EQ(66u, table_.GetByNameAndValue("key-1", "Value One")); + EXPECT_EQ(64u, table_.GetByNameAndValue("key-1", "Value Two")); + // The following entry is identical to the one at index 65. The smaller index + // is returned by GetByNameAndValue(). + EXPECT_EQ(63u, table_.GetByNameAndValue("key-2", "Value Three")); + EXPECT_EQ(62u, table_.GetByNameAndValue("key-2", "Value Four")); + + // Index of static entries does not change. + EXPECT_EQ(first_static_entry, peer_.GetFirstStaticEntry()); + EXPECT_EQ(last_static_entry, peer_.GetLastStaticEntry()); + + // Querying by name returns the most recently added matching entry. + EXPECT_EQ(64u, table_.GetByName("key-1")); + EXPECT_EQ(62u, table_.GetByName("key-2")); + EXPECT_EQ(1u, table_.GetByName(first_static_entry->name())); + EXPECT_EQ(kHpackEntryNotFound, table_.GetByName("not-present")); + + // Querying by name & value returns the lowest-index matching entry among + // static entries, and the highest-index one among dynamic entries. + EXPECT_EQ(66u, table_.GetByNameAndValue("key-1", "Value One")); + EXPECT_EQ(64u, table_.GetByNameAndValue("key-1", "Value Two")); + EXPECT_EQ(63u, table_.GetByNameAndValue("key-2", "Value Three")); + EXPECT_EQ(62u, table_.GetByNameAndValue("key-2", "Value Four")); + EXPECT_EQ(1u, table_.GetByNameAndValue(first_static_entry->name(), + first_static_entry->value())); + EXPECT_EQ(67u, + table_.GetByNameAndValue(first_static_entry->name(), "Value Four")); + EXPECT_EQ(kHpackEntryNotFound, + table_.GetByNameAndValue("key-1", "Not Present")); + EXPECT_EQ(kHpackEntryNotFound, + table_.GetByNameAndValue("not-present", "Value One")); + + // Evict |entry1|. Queries for its name & value now return the static entry. + // |entry2| remains queryable. + peer_.Evict(1); + EXPECT_EQ(1u, table_.GetByNameAndValue(first_static_entry->name(), + first_static_entry->value())); + EXPECT_EQ(67u, + table_.GetByNameAndValue(first_static_entry->name(), "Value Four")); + + // Evict |entry2|. Queries by its name & value are not found. + peer_.Evict(1); + EXPECT_EQ(kHpackEntryNotFound, + table_.GetByNameAndValue(first_static_entry->name(), "Value Four")); + + // Index of static entries does not change. + EXPECT_EQ(first_static_entry, peer_.GetFirstStaticEntry()); + EXPECT_EQ(last_static_entry, peer_.GetLastStaticEntry()); +} + +TEST_F(HpackHeaderTableTest, SetSizes) { + std::string key = "key", value = "value"; + const HpackEntry* entry1 = table_.TryAddEntry(key, value); + const HpackEntry* entry2 = table_.TryAddEntry(key, value); + const HpackEntry* entry3 = table_.TryAddEntry(key, value); + + // Set exactly large enough. No Evictions. + size_t max_size = entry1->Size() + entry2->Size() + entry3->Size(); + table_.SetMaxSize(max_size); + EXPECT_EQ(3u, peer_.dynamic_entries().size()); + + // Set just too small. One eviction. + max_size = entry1->Size() + entry2->Size() + entry3->Size() - 1; + table_.SetMaxSize(max_size); + EXPECT_EQ(2u, peer_.dynamic_entries().size()); + + // Changing SETTINGS_HEADER_TABLE_SIZE. + EXPECT_EQ(kDefaultHeaderTableSizeSetting, table_.settings_size_bound()); + // In production, the size passed to SetSettingsHeaderTableSize is never + // larger than table_.settings_size_bound(). + table_.SetSettingsHeaderTableSize(kDefaultHeaderTableSizeSetting * 3 + 1); + EXPECT_EQ(kDefaultHeaderTableSizeSetting * 3 + 1, table_.max_size()); + + // SETTINGS_HEADER_TABLE_SIZE upper-bounds |table_.max_size()|, + // and will force evictions. + max_size = entry3->Size() - 1; + table_.SetSettingsHeaderTableSize(max_size); + EXPECT_EQ(max_size, table_.max_size()); + EXPECT_EQ(max_size, table_.settings_size_bound()); + EXPECT_EQ(0u, peer_.dynamic_entries().size()); +} + +TEST_F(HpackHeaderTableTest, EvictionCountForEntry) { + std::string key = "key", value = "value"; + const HpackEntry* entry1 = table_.TryAddEntry(key, value); + const HpackEntry* entry2 = table_.TryAddEntry(key, value); + size_t entry3_size = HpackEntry::Size(key, value); + + // Just enough capacity for third entry. + table_.SetMaxSize(entry1->Size() + entry2->Size() + entry3_size); + EXPECT_EQ(0u, peer_.EvictionCountForEntry(key, value)); + EXPECT_EQ(1u, peer_.EvictionCountForEntry(key, value + "x")); + + // No extra capacity. Third entry would force evictions. + table_.SetMaxSize(entry1->Size() + entry2->Size()); + EXPECT_EQ(1u, peer_.EvictionCountForEntry(key, value)); + EXPECT_EQ(2u, peer_.EvictionCountForEntry(key, value + "x")); +} + +TEST_F(HpackHeaderTableTest, EvictionCountToReclaim) { + std::string key = "key", value = "value"; + const HpackEntry* entry1 = table_.TryAddEntry(key, value); + const HpackEntry* entry2 = table_.TryAddEntry(key, value); + + EXPECT_EQ(1u, peer_.EvictionCountToReclaim(1)); + EXPECT_EQ(1u, peer_.EvictionCountToReclaim(entry1->Size())); + EXPECT_EQ(2u, peer_.EvictionCountToReclaim(entry1->Size() + 1)); + EXPECT_EQ(2u, peer_.EvictionCountToReclaim(entry1->Size() + entry2->Size())); +} + +// Fill a header table with entries. Make sure the entries are in +// reverse order in the header table. +TEST_F(HpackHeaderTableTest, TryAddEntryBasic) { + EXPECT_EQ(0u, table_.size()); + EXPECT_EQ(table_.settings_size_bound(), table_.max_size()); + + HpackEntryVector entries = MakeEntriesOfTotalSize(table_.max_size()); + + // Most of the checks are in AddEntriesExpectNoEviction(). + AddEntriesExpectNoEviction(entries); + EXPECT_EQ(table_.max_size(), table_.size()); + EXPECT_EQ(table_.settings_size_bound(), table_.size()); +} + +// Fill a header table with entries, and then ramp the table's max +// size down to evict an entry one at a time. Make sure the eviction +// happens as expected. +TEST_F(HpackHeaderTableTest, SetMaxSize) { + HpackEntryVector entries = + MakeEntriesOfTotalSize(kDefaultHeaderTableSizeSetting / 2); + AddEntriesExpectNoEviction(entries); + + for (auto it = entries.begin(); it != entries.end(); ++it) { + size_t expected_count = distance(it, entries.end()); + EXPECT_EQ(expected_count, peer_.dynamic_entries().size()); + + table_.SetMaxSize(table_.size() + 1); + EXPECT_EQ(expected_count, peer_.dynamic_entries().size()); + + table_.SetMaxSize(table_.size()); + EXPECT_EQ(expected_count, peer_.dynamic_entries().size()); + + --expected_count; + table_.SetMaxSize(table_.size() - 1); + EXPECT_EQ(expected_count, peer_.dynamic_entries().size()); + } + EXPECT_EQ(0u, table_.size()); +} + +// Fill a header table with entries, and then add an entry just big +// enough to cause eviction of all but one entry. Make sure the +// eviction happens as expected and the long entry is inserted into +// the table. +TEST_F(HpackHeaderTableTest, TryAddEntryEviction) { + HpackEntryVector entries = MakeEntriesOfTotalSize(table_.max_size()); + AddEntriesExpectNoEviction(entries); + + // The first entry in the dynamic table. + const HpackEntry* survivor_entry = peer_.dynamic_entries().front().get(); + + HpackEntry long_entry = + MakeEntryOfSize(table_.max_size() - survivor_entry->Size()); + + // All dynamic entries but the first are to be evicted. + EXPECT_EQ(peer_.dynamic_entries().size() - 1, + peer_.EvictionSet(long_entry.name(), long_entry.value()).size()); + + table_.TryAddEntry(long_entry.name(), long_entry.value()); + EXPECT_EQ(2u, peer_.dynamic_entries().size()); + EXPECT_EQ(63u, table_.GetByNameAndValue(survivor_entry->name(), + survivor_entry->value())); + EXPECT_EQ(62u, + table_.GetByNameAndValue(long_entry.name(), long_entry.value())); +} + +// Fill a header table with entries, and then add an entry bigger than +// the entire table. Make sure no entry remains in the table. +TEST_F(HpackHeaderTableTest, TryAddTooLargeEntry) { + HpackEntryVector entries = MakeEntriesOfTotalSize(table_.max_size()); + AddEntriesExpectNoEviction(entries); + + const HpackEntry long_entry = MakeEntryOfSize(table_.max_size() + 1); + + // All entries are to be evicted. + EXPECT_EQ(peer_.dynamic_entries().size(), + peer_.EvictionSet(long_entry.name(), long_entry.value()).size()); + + const HpackEntry* new_entry = + table_.TryAddEntry(long_entry.name(), long_entry.value()); + EXPECT_EQ(new_entry, static_cast(nullptr)); + EXPECT_EQ(0u, peer_.dynamic_entries().size()); +} + +} // namespace + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_output_stream.cc b/quiche/spdy/core/hpack/hpack_output_stream.cc new file mode 100644 index 000000000000..8a0c8d670374 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_output_stream.cc @@ -0,0 +1,100 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_output_stream.h" + +#include + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { + +HpackOutputStream::HpackOutputStream() : bit_offset_(0) {} + +HpackOutputStream::~HpackOutputStream() = default; + +void HpackOutputStream::AppendBits(uint8_t bits, size_t bit_size) { + QUICHE_DCHECK_GT(bit_size, 0u); + QUICHE_DCHECK_LE(bit_size, 8u); + QUICHE_DCHECK_EQ(bits >> bit_size, 0); + size_t new_bit_offset = bit_offset_ + bit_size; + if (bit_offset_ == 0) { + // Buffer ends on a byte boundary. + QUICHE_DCHECK_LE(bit_size, 8u); + buffer_.append(1, bits << (8 - bit_size)); + } else if (new_bit_offset <= 8) { + // Buffer does not end on a byte boundary but the given bits fit + // in the remainder of the last byte. + buffer_.back() |= bits << (8 - new_bit_offset); + } else { + // Buffer does not end on a byte boundary and the given bits do + // not fit in the remainder of the last byte. + buffer_.back() |= bits >> (new_bit_offset - 8); + buffer_.append(1, bits << (16 - new_bit_offset)); + } + bit_offset_ = new_bit_offset % 8; +} + +void HpackOutputStream::AppendPrefix(HpackPrefix prefix) { + AppendBits(prefix.bits, prefix.bit_size); +} + +void HpackOutputStream::AppendBytes(absl::string_view buffer) { + QUICHE_DCHECK_EQ(bit_offset_, 0u); + buffer_.append(buffer.data(), buffer.size()); +} + +void HpackOutputStream::AppendUint32(uint32_t I) { + // The algorithm below is adapted from the pseudocode in 6.1. + size_t N = 8 - bit_offset_; + uint8_t max_first_byte = static_cast((1 << N) - 1); + if (I < max_first_byte) { + AppendBits(static_cast(I), N); + } else { + AppendBits(max_first_byte, N); + I -= max_first_byte; + while ((I & ~0x7f) != 0) { + buffer_.append(1, (I & 0x7f) | 0x80); + I >>= 7; + } + AppendBits(static_cast(I), 8); + } + QUICHE_DCHECK_EQ(bit_offset_, 0u); +} + +std::string* HpackOutputStream::MutableString() { + QUICHE_DCHECK_EQ(bit_offset_, 0u); + return &buffer_; +} + +std::string HpackOutputStream::TakeString() { + // This must hold, since all public functions cause the buffer to + // end on a byte boundary. + QUICHE_DCHECK_EQ(bit_offset_, 0u); + std::string out = std::move(buffer_); + buffer_ = {}; + bit_offset_ = 0; + return out; +} + +std::string HpackOutputStream::BoundedTakeString(size_t max_size) { + if (buffer_.size() > max_size) { + // Save off overflow bytes to temporary string (causes a copy). + std::string overflow = buffer_.substr(max_size); + + // Resize buffer down to the given limit. + buffer_.resize(max_size); + + // Give buffer to output string. + std::string out = std::move(buffer_); + + // Reset to contain overflow. + buffer_ = std::move(overflow); + return out; + } else { + return TakeString(); + } +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_output_stream.h b/quiche/spdy/core/hpack/hpack_output_stream.h new file mode 100644 index 000000000000..0d3bf8c0ed25 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_output_stream.h @@ -0,0 +1,75 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_OUTPUT_STREAM_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_OUTPUT_STREAM_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" + +// All section references below are to +// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-08 + +namespace spdy { + +// An HpackOutputStream handles all the low-level details of encoding +// header fields. +class QUICHE_EXPORT HpackOutputStream { + public: + HpackOutputStream(); + HpackOutputStream(const HpackOutputStream&) = delete; + HpackOutputStream& operator=(const HpackOutputStream&) = delete; + ~HpackOutputStream(); + + // Appends the lower |bit_size| bits of |bits| to the internal buffer. + // + // |bit_size| must be > 0 and <= 8. |bits| must not have any bits + // set other than the lower |bit_size| bits. + void AppendBits(uint8_t bits, size_t bit_size); + + // Simply forwards to AppendBits(prefix.bits, prefix.bit-size). + void AppendPrefix(HpackPrefix prefix); + + // Directly appends |buffer|. + void AppendBytes(absl::string_view buffer); + + // Appends the given integer using the representation described in + // 6.1. If the internal buffer ends on a byte boundary, the prefix + // length N is taken to be 8; otherwise, it is taken to be the + // number of bits to the next byte boundary. + // + // It is guaranteed that the internal buffer will end on a byte + // boundary after this function is called. + void AppendUint32(uint32_t I); + + // Return pointer to internal buffer. |bit_offset_| needs to be zero. + std::string* MutableString(); + + // Returns the internal buffer as a string, then resets state. + std::string TakeString(); + + // Returns up to |max_size| bytes of the internal buffer. Resets + // internal state with the overflow. + std::string BoundedTakeString(size_t max_size); + + // Size in bytes of stream's internal buffer. + size_t size() const { return buffer_.size(); } + + private: + // The internal bit buffer. + std::string buffer_; + + // If 0, the buffer ends on a byte boundary. If non-zero, the buffer + // ends on the nth most significant bit. Guaranteed to be < 8. + size_t bit_offset_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_OUTPUT_STREAM_H_ diff --git a/quiche/spdy/core/hpack/hpack_output_stream_test.cc b/quiche/spdy/core/hpack/hpack_output_stream_test.cc new file mode 100644 index 000000000000..ea1c2656fd8a --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_output_stream_test.cc @@ -0,0 +1,284 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_output_stream.h" + +#include + +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { + +namespace { + +// Make sure that AppendBits() appends bits starting from the most +// significant bit, and that it can handle crossing a byte boundary. +TEST(HpackOutputStreamTest, AppendBits) { + HpackOutputStream output_stream; + std::string expected_str; + + output_stream.AppendBits(0x1, 1); + expected_str.append(1, 0x00); + expected_str.back() |= (0x1 << 7); + + output_stream.AppendBits(0x0, 1); + + output_stream.AppendBits(0x3, 2); + *expected_str.rbegin() |= (0x3 << 4); + + output_stream.AppendBits(0x0, 2); + + // Byte-crossing append. + output_stream.AppendBits(0x7, 3); + *expected_str.rbegin() |= (0x7 >> 1); + expected_str.append(1, 0x00); + expected_str.back() |= (0x7 << 7); + + output_stream.AppendBits(0x0, 7); + + std::string str = output_stream.TakeString(); + EXPECT_EQ(expected_str, str); +} + +// Utility function to return I as a string encoded with an N-bit +// prefix. +std::string EncodeUint32(uint8_t N, uint32_t I) { + HpackOutputStream output_stream; + if (N < 8) { + output_stream.AppendBits(0x00, 8 - N); + } + output_stream.AppendUint32(I); + std::string str = output_stream.TakeString(); + return str; +} + +// The {Number}ByteIntegersEightBitPrefix tests below test that +// certain integers are encoded correctly with an 8-bit prefix in +// exactly {Number} bytes. + +TEST(HpackOutputStreamTest, OneByteIntegersEightBitPrefix) { + // Minimum. + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(8, 0x00)); + EXPECT_EQ("\x7f", EncodeUint32(8, 0x7f)); + // Maximum. + EXPECT_EQ("\xfe", EncodeUint32(8, 0xfe)); +} + +TEST(HpackOutputStreamTest, TwoByteIntegersEightBitPrefix) { + // Minimum. + EXPECT_EQ(std::string("\xff\x00", 2), EncodeUint32(8, 0xff)); + EXPECT_EQ("\xff\x01", EncodeUint32(8, 0x0100)); + // Maximum. + EXPECT_EQ("\xff\x7f", EncodeUint32(8, 0x017e)); +} + +TEST(HpackOutputStreamTest, ThreeByteIntegersEightBitPrefix) { + // Minimum. + EXPECT_EQ("\xff\x80\x01", EncodeUint32(8, 0x017f)); + EXPECT_EQ("\xff\x80\x1e", EncodeUint32(8, 0x0fff)); + // Maximum. + EXPECT_EQ("\xff\xff\x7f", EncodeUint32(8, 0x40fe)); +} + +TEST(HpackOutputStreamTest, FourByteIntegersEightBitPrefix) { + // Minimum. + EXPECT_EQ("\xff\x80\x80\x01", EncodeUint32(8, 0x40ff)); + EXPECT_EQ("\xff\x80\xfe\x03", EncodeUint32(8, 0xffff)); + // Maximum. + EXPECT_EQ("\xff\xff\xff\x7f", EncodeUint32(8, 0x002000fe)); +} + +TEST(HpackOutputStreamTest, FiveByteIntegersEightBitPrefix) { + // Minimum. + EXPECT_EQ("\xff\x80\x80\x80\x01", EncodeUint32(8, 0x002000ff)); + EXPECT_EQ("\xff\x80\xfe\xff\x07", EncodeUint32(8, 0x00ffffff)); + // Maximum. + EXPECT_EQ("\xff\xff\xff\xff\x7f", EncodeUint32(8, 0x100000fe)); +} + +TEST(HpackOutputStreamTest, SixByteIntegersEightBitPrefix) { + // Minimum. + EXPECT_EQ("\xff\x80\x80\x80\x80\x01", EncodeUint32(8, 0x100000ff)); + // Maximum. + EXPECT_EQ("\xff\x80\xfe\xff\xff\x0f", EncodeUint32(8, 0xffffffff)); +} + +// The {Number}ByteIntegersOneToSevenBitPrefix tests below test that +// certain integers are encoded correctly with an N-bit prefix in +// exactly {Number} bytes for N in {1, 2, ..., 7}. + +TEST(HpackOutputStreamTest, OneByteIntegersOneToSevenBitPrefixes) { + // Minimums. + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(7, 0x00)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(6, 0x00)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(5, 0x00)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(4, 0x00)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(3, 0x00)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(2, 0x00)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(1, 0x00)); + + // Maximums. + EXPECT_EQ("\x7e", EncodeUint32(7, 0x7e)); + EXPECT_EQ("\x3e", EncodeUint32(6, 0x3e)); + EXPECT_EQ("\x1e", EncodeUint32(5, 0x1e)); + EXPECT_EQ("\x0e", EncodeUint32(4, 0x0e)); + EXPECT_EQ("\x06", EncodeUint32(3, 0x06)); + EXPECT_EQ("\x02", EncodeUint32(2, 0x02)); + EXPECT_EQ(std::string("\x00", 1), EncodeUint32(1, 0x00)); +} + +TEST(HpackOutputStreamTest, TwoByteIntegersOneToSevenBitPrefixes) { + // Minimums. + EXPECT_EQ(std::string("\x7f\x00", 2), EncodeUint32(7, 0x7f)); + EXPECT_EQ(std::string("\x3f\x00", 2), EncodeUint32(6, 0x3f)); + EXPECT_EQ(std::string("\x1f\x00", 2), EncodeUint32(5, 0x1f)); + EXPECT_EQ(std::string("\x0f\x00", 2), EncodeUint32(4, 0x0f)); + EXPECT_EQ(std::string("\x07\x00", 2), EncodeUint32(3, 0x07)); + EXPECT_EQ(std::string("\x03\x00", 2), EncodeUint32(2, 0x03)); + EXPECT_EQ(std::string("\x01\x00", 2), EncodeUint32(1, 0x01)); + + // Maximums. + EXPECT_EQ("\x7f\x7f", EncodeUint32(7, 0xfe)); + EXPECT_EQ("\x3f\x7f", EncodeUint32(6, 0xbe)); + EXPECT_EQ("\x1f\x7f", EncodeUint32(5, 0x9e)); + EXPECT_EQ("\x0f\x7f", EncodeUint32(4, 0x8e)); + EXPECT_EQ("\x07\x7f", EncodeUint32(3, 0x86)); + EXPECT_EQ("\x03\x7f", EncodeUint32(2, 0x82)); + EXPECT_EQ("\x01\x7f", EncodeUint32(1, 0x80)); +} + +TEST(HpackOutputStreamTest, ThreeByteIntegersOneToSevenBitPrefixes) { + // Minimums. + EXPECT_EQ("\x7f\x80\x01", EncodeUint32(7, 0xff)); + EXPECT_EQ("\x3f\x80\x01", EncodeUint32(6, 0xbf)); + EXPECT_EQ("\x1f\x80\x01", EncodeUint32(5, 0x9f)); + EXPECT_EQ("\x0f\x80\x01", EncodeUint32(4, 0x8f)); + EXPECT_EQ("\x07\x80\x01", EncodeUint32(3, 0x87)); + EXPECT_EQ("\x03\x80\x01", EncodeUint32(2, 0x83)); + EXPECT_EQ("\x01\x80\x01", EncodeUint32(1, 0x81)); + + // Maximums. + EXPECT_EQ("\x7f\xff\x7f", EncodeUint32(7, 0x407e)); + EXPECT_EQ("\x3f\xff\x7f", EncodeUint32(6, 0x403e)); + EXPECT_EQ("\x1f\xff\x7f", EncodeUint32(5, 0x401e)); + EXPECT_EQ("\x0f\xff\x7f", EncodeUint32(4, 0x400e)); + EXPECT_EQ("\x07\xff\x7f", EncodeUint32(3, 0x4006)); + EXPECT_EQ("\x03\xff\x7f", EncodeUint32(2, 0x4002)); + EXPECT_EQ("\x01\xff\x7f", EncodeUint32(1, 0x4000)); +} + +TEST(HpackOutputStreamTest, FourByteIntegersOneToSevenBitPrefixes) { + // Minimums. + EXPECT_EQ("\x7f\x80\x80\x01", EncodeUint32(7, 0x407f)); + EXPECT_EQ("\x3f\x80\x80\x01", EncodeUint32(6, 0x403f)); + EXPECT_EQ("\x1f\x80\x80\x01", EncodeUint32(5, 0x401f)); + EXPECT_EQ("\x0f\x80\x80\x01", EncodeUint32(4, 0x400f)); + EXPECT_EQ("\x07\x80\x80\x01", EncodeUint32(3, 0x4007)); + EXPECT_EQ("\x03\x80\x80\x01", EncodeUint32(2, 0x4003)); + EXPECT_EQ("\x01\x80\x80\x01", EncodeUint32(1, 0x4001)); + + // Maximums. + EXPECT_EQ("\x7f\xff\xff\x7f", EncodeUint32(7, 0x20007e)); + EXPECT_EQ("\x3f\xff\xff\x7f", EncodeUint32(6, 0x20003e)); + EXPECT_EQ("\x1f\xff\xff\x7f", EncodeUint32(5, 0x20001e)); + EXPECT_EQ("\x0f\xff\xff\x7f", EncodeUint32(4, 0x20000e)); + EXPECT_EQ("\x07\xff\xff\x7f", EncodeUint32(3, 0x200006)); + EXPECT_EQ("\x03\xff\xff\x7f", EncodeUint32(2, 0x200002)); + EXPECT_EQ("\x01\xff\xff\x7f", EncodeUint32(1, 0x200000)); +} + +TEST(HpackOutputStreamTest, FiveByteIntegersOneToSevenBitPrefixes) { + // Minimums. + EXPECT_EQ("\x7f\x80\x80\x80\x01", EncodeUint32(7, 0x20007f)); + EXPECT_EQ("\x3f\x80\x80\x80\x01", EncodeUint32(6, 0x20003f)); + EXPECT_EQ("\x1f\x80\x80\x80\x01", EncodeUint32(5, 0x20001f)); + EXPECT_EQ("\x0f\x80\x80\x80\x01", EncodeUint32(4, 0x20000f)); + EXPECT_EQ("\x07\x80\x80\x80\x01", EncodeUint32(3, 0x200007)); + EXPECT_EQ("\x03\x80\x80\x80\x01", EncodeUint32(2, 0x200003)); + EXPECT_EQ("\x01\x80\x80\x80\x01", EncodeUint32(1, 0x200001)); + + // Maximums. + EXPECT_EQ("\x7f\xff\xff\xff\x7f", EncodeUint32(7, 0x1000007e)); + EXPECT_EQ("\x3f\xff\xff\xff\x7f", EncodeUint32(6, 0x1000003e)); + EXPECT_EQ("\x1f\xff\xff\xff\x7f", EncodeUint32(5, 0x1000001e)); + EXPECT_EQ("\x0f\xff\xff\xff\x7f", EncodeUint32(4, 0x1000000e)); + EXPECT_EQ("\x07\xff\xff\xff\x7f", EncodeUint32(3, 0x10000006)); + EXPECT_EQ("\x03\xff\xff\xff\x7f", EncodeUint32(2, 0x10000002)); + EXPECT_EQ("\x01\xff\xff\xff\x7f", EncodeUint32(1, 0x10000000)); +} + +TEST(HpackOutputStreamTest, SixByteIntegersOneToSevenBitPrefixes) { + // Minimums. + EXPECT_EQ("\x7f\x80\x80\x80\x80\x01", EncodeUint32(7, 0x1000007f)); + EXPECT_EQ("\x3f\x80\x80\x80\x80\x01", EncodeUint32(6, 0x1000003f)); + EXPECT_EQ("\x1f\x80\x80\x80\x80\x01", EncodeUint32(5, 0x1000001f)); + EXPECT_EQ("\x0f\x80\x80\x80\x80\x01", EncodeUint32(4, 0x1000000f)); + EXPECT_EQ("\x07\x80\x80\x80\x80\x01", EncodeUint32(3, 0x10000007)); + EXPECT_EQ("\x03\x80\x80\x80\x80\x01", EncodeUint32(2, 0x10000003)); + EXPECT_EQ("\x01\x80\x80\x80\x80\x01", EncodeUint32(1, 0x10000001)); + + // Maximums. + EXPECT_EQ("\x7f\x80\xff\xff\xff\x0f", EncodeUint32(7, 0xffffffff)); + EXPECT_EQ("\x3f\xc0\xff\xff\xff\x0f", EncodeUint32(6, 0xffffffff)); + EXPECT_EQ("\x1f\xe0\xff\xff\xff\x0f", EncodeUint32(5, 0xffffffff)); + EXPECT_EQ("\x0f\xf0\xff\xff\xff\x0f", EncodeUint32(4, 0xffffffff)); + EXPECT_EQ("\x07\xf8\xff\xff\xff\x0f", EncodeUint32(3, 0xffffffff)); + EXPECT_EQ("\x03\xfc\xff\xff\xff\x0f", EncodeUint32(2, 0xffffffff)); + EXPECT_EQ("\x01\xfe\xff\xff\xff\x0f", EncodeUint32(1, 0xffffffff)); +} + +// Test that encoding an integer with an N-bit prefix preserves the +// upper (8-N) bits of the first byte. +TEST(HpackOutputStreamTest, AppendUint32PreservesUpperBits) { + HpackOutputStream output_stream; + output_stream.AppendBits(0x7f, 7); + output_stream.AppendUint32(0x01); + std::string str = output_stream.TakeString(); + EXPECT_EQ(std::string("\xff\x00", 2), str); +} + +TEST(HpackOutputStreamTest, AppendBytes) { + HpackOutputStream output_stream; + + output_stream.AppendBytes("buffer1"); + output_stream.AppendBytes("buffer2"); + + std::string str = output_stream.TakeString(); + EXPECT_EQ("buffer1buffer2", str); +} + +TEST(HpackOutputStreamTest, BoundedTakeString) { + HpackOutputStream output_stream; + + output_stream.AppendBytes("buffer12"); + output_stream.AppendBytes("buffer456"); + + std::string str = output_stream.BoundedTakeString(9); + EXPECT_EQ("buffer12b", str); + + output_stream.AppendBits(0x7f, 7); + output_stream.AppendUint32(0x11); + str = output_stream.BoundedTakeString(9); + EXPECT_EQ("uffer456\xff", str); + + str = output_stream.BoundedTakeString(9); + EXPECT_EQ("\x10", str); +} + +TEST(HpackOutputStreamTest, MutableString) { + HpackOutputStream output_stream; + + output_stream.AppendBytes("1"); + output_stream.MutableString()->append("2"); + + output_stream.AppendBytes("foo"); + output_stream.MutableString()->append("bar"); + + std::string str = output_stream.TakeString(); + EXPECT_EQ("12foobar", str); +} + +} // namespace + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_round_trip_test.cc b/quiche/spdy/core/hpack/hpack_round_trip_test.cc new file mode 100644 index 000000000000..44a36d3b2825 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_round_trip_test.cc @@ -0,0 +1,224 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include + +#include "quiche/http2/test_tools/http2_random.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_decoder_adapter.h" +#include "quiche/spdy/core/hpack/hpack_encoder.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +namespace spdy { +namespace test { + +namespace { + +// Supports testing with the input split at every byte boundary. +enum InputSizeParam { ALL_INPUT, ONE_BYTE, ZERO_THEN_ONE_BYTE }; + +class HpackRoundTripTest + : public quiche::test::QuicheTestWithParam { + protected: + void SetUp() override { + // Use a small table size to tickle eviction handling. + encoder_.ApplyHeaderTableSizeSetting(256); + decoder_.ApplyHeaderTableSizeSetting(256); + } + + bool RoundTrip(const Http2HeaderBlock& header_set) { + std::string encoded = encoder_.EncodeHeaderBlock(header_set); + + bool success = true; + if (GetParam() == ALL_INPUT) { + // Pass all the input to the decoder at once. + success = decoder_.HandleControlFrameHeadersData(encoded.data(), + encoded.size()); + } else if (GetParam() == ONE_BYTE) { + // Pass the input to the decoder one byte at a time. + const char* data = encoded.data(); + for (size_t ndx = 0; ndx < encoded.size() && success; ++ndx) { + success = decoder_.HandleControlFrameHeadersData(data + ndx, 1); + } + } else if (GetParam() == ZERO_THEN_ONE_BYTE) { + // Pass the input to the decoder one byte at a time, but before each + // byte pass an empty buffer. + const char* data = encoded.data(); + for (size_t ndx = 0; ndx < encoded.size() && success; ++ndx) { + success = (decoder_.HandleControlFrameHeadersData(data + ndx, 0) && + decoder_.HandleControlFrameHeadersData(data + ndx, 1)); + } + } else { + ADD_FAILURE() << "Unknown param: " << GetParam(); + } + + if (success) { + success = decoder_.HandleControlFrameHeadersComplete(); + } + + EXPECT_EQ(header_set, decoder_.decoded_block()); + return success; + } + + size_t SampleExponential(size_t mean, size_t sanity_bound) { + return std::min(-std::log(random_.RandDouble()) * mean, + sanity_bound); + } + + http2::test::Http2Random random_; + HpackEncoder encoder_; + HpackDecoderAdapter decoder_; +}; + +INSTANTIATE_TEST_SUITE_P(Tests, HpackRoundTripTest, + ::testing::Values(ALL_INPUT, ONE_BYTE, + ZERO_THEN_ONE_BYTE)); + +TEST_P(HpackRoundTripTest, ResponseFixtures) { + { + Http2HeaderBlock headers; + headers[":status"] = "302"; + headers["cache-control"] = "private"; + headers["date"] = "Mon, 21 Oct 2013 20:13:21 GMT"; + headers["location"] = "https://www.example.com"; + EXPECT_TRUE(RoundTrip(headers)); + } + { + Http2HeaderBlock headers; + headers[":status"] = "200"; + headers["cache-control"] = "private"; + headers["date"] = "Mon, 21 Oct 2013 20:13:21 GMT"; + headers["location"] = "https://www.example.com"; + EXPECT_TRUE(RoundTrip(headers)); + } + { + Http2HeaderBlock headers; + headers[":status"] = "200"; + headers["cache-control"] = "private"; + headers["content-encoding"] = "gzip"; + headers["date"] = "Mon, 21 Oct 2013 20:13:22 GMT"; + headers["location"] = "https://www.example.com"; + headers["set-cookie"] = + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU;" + " max-age=3600; version=1"; + headers["multivalue"] = std::string("foo\0bar", 7); + EXPECT_TRUE(RoundTrip(headers)); + } +} + +TEST_P(HpackRoundTripTest, RequestFixtures) { + { + Http2HeaderBlock headers; + headers[":authority"] = "www.example.com"; + headers[":method"] = "GET"; + headers[":path"] = "/"; + headers[":scheme"] = "http"; + headers["cookie"] = "baz=bing; foo=bar"; + EXPECT_TRUE(RoundTrip(headers)); + } + { + Http2HeaderBlock headers; + headers[":authority"] = "www.example.com"; + headers[":method"] = "GET"; + headers[":path"] = "/"; + headers[":scheme"] = "http"; + headers["cache-control"] = "no-cache"; + headers["cookie"] = "foo=bar; spam=eggs"; + EXPECT_TRUE(RoundTrip(headers)); + } + { + Http2HeaderBlock headers; + headers[":authority"] = "www.example.com"; + headers[":method"] = "GET"; + headers[":path"] = "/index.html"; + headers[":scheme"] = "https"; + headers["custom-key"] = "custom-value"; + headers["cookie"] = "baz=bing; fizzle=fazzle; garbage"; + headers["multivalue"] = std::string("foo\0bar", 7); + EXPECT_TRUE(RoundTrip(headers)); + } +} + +TEST_P(HpackRoundTripTest, RandomizedExamples) { + // Grow vectors of names & values, which are seeded with fixtures and then + // expanded with dynamically generated data. Samples are taken using the + // exponential distribution. + std::vector pseudo_header_names, random_header_names; + pseudo_header_names.push_back(":authority"); + pseudo_header_names.push_back(":path"); + pseudo_header_names.push_back(":status"); + + // TODO(jgraettinger): Enable "cookie" as a name fixture. Crumbs may be + // reconstructed in any order, which breaks the simple validation used here. + + std::vector values; + values.push_back("/"); + values.push_back("/index.html"); + values.push_back("200"); + values.push_back("404"); + values.push_back(""); + values.push_back("baz=bing; foo=bar; garbage"); + values.push_back("baz=bing; fizzle=fazzle; garbage"); + + for (size_t i = 0; i != 2000; ++i) { + Http2HeaderBlock headers; + + // Choose a random number of headers to add, and of these a random subset + // will be HTTP/2 pseudo headers. + size_t header_count = 1 + SampleExponential(7, 50); + size_t pseudo_header_count = + std::min(header_count, 1 + SampleExponential(7, 50)); + EXPECT_LE(pseudo_header_count, header_count); + for (size_t j = 0; j != header_count; ++j) { + std::string name, value; + // Pseudo headers must be added before regular headers. + if (j < pseudo_header_count) { + // Choose one of the defined pseudo headers at random. + size_t name_index = random_.Uniform(pseudo_header_names.size()); + name = pseudo_header_names[name_index]; + } else { + // Randomly reuse an existing header name, or generate a new one. + size_t name_index = SampleExponential(20, 200); + if (name_index >= random_header_names.size()) { + name = random_.RandString(1 + SampleExponential(5, 30)); + // A regular header cannot begin with the pseudo header prefix ":". + if (name[0] == ':') { + name[0] = 'x'; + } + random_header_names.push_back(name); + } else { + name = random_header_names[name_index]; + } + } + + // Randomly reuse an existing value, or generate a new one. + size_t value_index = SampleExponential(20, 200); + if (value_index >= values.size()) { + std::string newvalue = + random_.RandString(1 + SampleExponential(15, 75)); + // Currently order is not preserved in the encoder. In particular, + // when a value is decomposed at \0 delimiters, its parts might get + // encoded out of order if some but not all of them already exist in + // the header table. For now, avoid \0 bytes in values. + std::replace(newvalue.begin(), newvalue.end(), '\x00', '\x01'); + values.push_back(newvalue); + value = values.back(); + } else { + value = values[value_index]; + } + headers[name] = value; + } + EXPECT_TRUE(RoundTrip(headers)); + } +} + +} // namespace + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_static_table.cc b/quiche/spdy/core/hpack/hpack_static_table.cc new file mode 100644 index 000000000000..a443283091ad --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_static_table.cc @@ -0,0 +1,50 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_static_table.h" + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" +#include "quiche/spdy/core/hpack/hpack_entry.h" + +namespace spdy { + +HpackStaticTable::HpackStaticTable() = default; + +HpackStaticTable::~HpackStaticTable() = default; + +void HpackStaticTable::Initialize(const HpackStaticEntry* static_entry_table, + size_t static_entry_count) { + QUICHE_CHECK(!IsInitialized()); + + static_entries_.reserve(static_entry_count); + + for (const HpackStaticEntry* it = static_entry_table; + it != static_entry_table + static_entry_count; ++it) { + std::string name(it->name, it->name_len); + std::string value(it->value, it->value_len); + static_entries_.push_back(HpackEntry(std::move(name), std::move(value))); + } + + // |static_entries_| will not be mutated any more. Therefore its entries will + // remain stable even if the container does not have iterator stability. + int insertion_count = 0; + for (const auto& entry : static_entries_) { + auto result = static_index_.insert(std::make_pair( + HpackLookupEntry{entry.name(), entry.value()}, insertion_count)); + QUICHE_CHECK(result.second); + + // Multiple static entries may have the same name, so inserts may fail. + static_name_index_.insert(std::make_pair(entry.name(), insertion_count)); + + ++insertion_count; + } +} + +bool HpackStaticTable::IsInitialized() const { + return !static_entries_.empty(); +} + +} // namespace spdy diff --git a/quiche/spdy/core/hpack/hpack_static_table.h b/quiche/spdy/core/hpack/hpack_static_table.h new file mode 100644 index 000000000000..278c242c6068 --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_static_table.h @@ -0,0 +1,56 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HPACK_HPACK_STATIC_TABLE_H_ +#define QUICHE_SPDY_CORE_HPACK_HPACK_STATIC_TABLE_H_ + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" + +namespace spdy { + +struct HpackStaticEntry; + +// Number of entries in the HPACK static table. +constexpr size_t kStaticTableSize = 61; + +// HpackStaticTable provides |static_entries_| and |static_index_| for HPACK +// encoding and decoding contexts. Once initialized, an instance is read only +// and may be accessed only through its const interface. Such an instance may +// be shared accross multiple HPACK contexts. +class QUICHE_EXPORT HpackStaticTable { + public: + HpackStaticTable(); + ~HpackStaticTable(); + + // Prepares HpackStaticTable by filling up static_entries_ and static_index_ + // from an array of struct HpackStaticEntry. Must be called exactly once. + void Initialize(const HpackStaticEntry* static_entry_table, + size_t static_entry_count); + + // Returns whether Initialize() has been called. + bool IsInitialized() const; + + // Accessors. + const HpackHeaderTable::StaticEntryTable& GetStaticEntries() const { + return static_entries_; + } + const HpackHeaderTable::NameValueToEntryMap& GetStaticIndex() const { + return static_index_; + } + const HpackHeaderTable::NameToEntryMap& GetStaticNameIndex() const { + return static_name_index_; + } + + private: + HpackHeaderTable::StaticEntryTable static_entries_; + // The following two members have string_views that point to strings stored in + // |static_entries_|. + HpackHeaderTable::NameValueToEntryMap static_index_; + HpackHeaderTable::NameToEntryMap static_name_index_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HPACK_HPACK_STATIC_TABLE_H_ diff --git a/quiche/spdy/core/hpack/hpack_static_table_test.cc b/quiche/spdy/core/hpack/hpack_static_table_test.cc new file mode 100644 index 000000000000..b781aaf4128d --- /dev/null +++ b/quiche/spdy/core/hpack/hpack_static_table_test.cc @@ -0,0 +1,63 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/hpack/hpack_static_table.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/hpack/hpack_constants.h" + +namespace spdy { + +namespace test { + +namespace { + +class HpackStaticTableTest : public quiche::test::QuicheTest { + protected: + HpackStaticTableTest() : table_() {} + + HpackStaticTable table_; +}; + +// Check that an initialized instance has the right number of entries. +TEST_F(HpackStaticTableTest, Initialize) { + EXPECT_FALSE(table_.IsInitialized()); + table_.Initialize(HpackStaticTableVector().data(), + HpackStaticTableVector().size()); + EXPECT_TRUE(table_.IsInitialized()); + + const HpackHeaderTable::StaticEntryTable& static_entries = + table_.GetStaticEntries(); + EXPECT_EQ(kStaticTableSize, static_entries.size()); + + const HpackHeaderTable::NameValueToEntryMap& static_index = + table_.GetStaticIndex(); + EXPECT_EQ(kStaticTableSize, static_index.size()); + + const HpackHeaderTable::NameToEntryMap& static_name_index = + table_.GetStaticNameIndex(); + // Count distinct names in static table. + std::set names; + for (const auto& entry : static_entries) { + names.insert(entry.name()); + } + EXPECT_EQ(names.size(), static_name_index.size()); +} + +// Test that ObtainHpackStaticTable returns the same instance every time. +TEST_F(HpackStaticTableTest, IsSingleton) { + const HpackStaticTable* static_table_one = &ObtainHpackStaticTable(); + const HpackStaticTable* static_table_two = &ObtainHpackStaticTable(); + EXPECT_EQ(static_table_one, static_table_two); +} + +} // namespace + +} // namespace test + +} // namespace spdy diff --git a/quiche/spdy/core/http2_frame_decoder_adapter.cc b/quiche/spdy/core/http2_frame_decoder_adapter.cc new file mode 100644 index 000000000000..303f037086a6 --- /dev/null +++ b/quiche/spdy/core/http2_frame_decoder_adapter.cc @@ -0,0 +1,1111 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" + +// Logging policy: If an error in the input is detected, QUICHE_VLOG(n) is used +// so that the option exists to debug the situation. Otherwise, this code mostly +// uses QUICHE_DVLOG so that the logging does not slow down production code when +// things are working OK. + +#include + +#include +#include +#include + +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/decoder/decode_status.h" +#include "quiche/http2/decoder/http2_frame_decoder.h" +#include "quiche/http2/decoder/http2_frame_decoder_listener.h" +#include "quiche/http2/http2_constants.h" +#include "quiche/http2/http2_structures.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/spdy/core/hpack/hpack_decoder_adapter.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" +#include "quiche/spdy/core/spdy_protocol.h" + +using ::spdy::ExtensionVisitorInterface; +using ::spdy::HpackDecoderAdapter; +using ::spdy::HpackHeaderTable; +using ::spdy::ParseErrorCode; +using ::spdy::ParseFrameType; +using ::spdy::SpdyAltSvcWireFormat; +using ::spdy::SpdyErrorCode; +using ::spdy::SpdyFramerDebugVisitorInterface; +using ::spdy::SpdyFramerVisitorInterface; +using ::spdy::SpdyFrameType; +using ::spdy::SpdyHeadersHandlerInterface; +using ::spdy::SpdyKnownSettingsId; +using ::spdy::SpdySettingsId; + +namespace http2 { +namespace { + +const bool kHasPriorityFields = true; +const bool kNotHasPriorityFields = false; + +bool IsPaddable(Http2FrameType type) { + return type == Http2FrameType::DATA || type == Http2FrameType::HEADERS || + type == Http2FrameType::PUSH_PROMISE; +} + +SpdyFrameType ToSpdyFrameType(Http2FrameType type) { + return ParseFrameType(static_cast(type)); +} + +uint64_t ToSpdyPingId(const Http2PingFields& ping) { + uint64_t v; + std::memcpy(&v, ping.opaque_bytes, Http2PingFields::EncodedSize()); + return quiche::QuicheEndian::NetToHost64(v); +} + +// Overwrites the fields of the header with invalid values, for the purpose +// of identifying reading of unset fields. Only takes effect for debug builds. +// In Address Sanatizer builds, it also marks the fields as un-readable. +#ifndef NDEBUG +void CorruptFrameHeader(Http2FrameHeader* header) { + // Beyond a valid payload length, which is 2^24 - 1. + header->payload_length = 0x1010dead; + // An unsupported frame type. + header->type = Http2FrameType(0x80); + QUICHE_DCHECK(!IsSupportedHttp2FrameType(header->type)); + // Frame flag bits that aren't used by any supported frame type. + header->flags = Http2FrameFlag(0xd2); + // A stream id with the reserved high-bit (R in the RFC) set. + // 2129510127 when the high-bit is cleared. + header->stream_id = 0xfeedbeef; +} +#else +void CorruptFrameHeader(Http2FrameHeader* /*header*/) {} +#endif + +Http2DecoderAdapter::SpdyFramerError HpackDecodingErrorToSpdyFramerError( + HpackDecodingError error) { + switch (error) { + case HpackDecodingError::kOk: + return Http2DecoderAdapter::SpdyFramerError::SPDY_NO_ERROR; + case HpackDecodingError::kIndexVarintError: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_INDEX_VARINT_ERROR; + case HpackDecodingError::kNameLengthVarintError: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_NAME_LENGTH_VARINT_ERROR; + case HpackDecodingError::kValueLengthVarintError: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_VALUE_LENGTH_VARINT_ERROR; + case HpackDecodingError::kNameTooLong: + return Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_NAME_TOO_LONG; + case HpackDecodingError::kValueTooLong: + return Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_VALUE_TOO_LONG; + case HpackDecodingError::kNameHuffmanError: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_NAME_HUFFMAN_ERROR; + case HpackDecodingError::kValueHuffmanError: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_VALUE_HUFFMAN_ERROR; + case HpackDecodingError::kMissingDynamicTableSizeUpdate: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE; + case HpackDecodingError::kInvalidIndex: + return Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_INVALID_INDEX; + case HpackDecodingError::kInvalidNameIndex: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_INVALID_NAME_INDEX; + case HpackDecodingError::kDynamicTableSizeUpdateNotAllowed: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED; + case HpackDecodingError::kInitialDynamicTableSizeUpdateIsAboveLowWaterMark: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK; + case HpackDecodingError::kDynamicTableSizeUpdateIsAboveAcknowledgedSetting: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING; + case HpackDecodingError::kTruncatedBlock: + return Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_TRUNCATED_BLOCK; + case HpackDecodingError::kFragmentTooLong: + return Http2DecoderAdapter::SpdyFramerError::SPDY_HPACK_FRAGMENT_TOO_LONG; + case HpackDecodingError::kCompressedHeaderSizeExceedsLimit: + return Http2DecoderAdapter::SpdyFramerError:: + SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT; + } + + return Http2DecoderAdapter::SpdyFramerError::SPDY_DECOMPRESS_FAILURE; +} + +} // namespace + +const char* Http2DecoderAdapter::StateToString(int state) { + switch (state) { + case SPDY_ERROR: + return "ERROR"; + case SPDY_FRAME_COMPLETE: + return "FRAME_COMPLETE"; + case SPDY_READY_FOR_FRAME: + return "READY_FOR_FRAME"; + case SPDY_READING_COMMON_HEADER: + return "READING_COMMON_HEADER"; + case SPDY_CONTROL_FRAME_PAYLOAD: + return "CONTROL_FRAME_PAYLOAD"; + case SPDY_READ_DATA_FRAME_PADDING_LENGTH: + return "SPDY_READ_DATA_FRAME_PADDING_LENGTH"; + case SPDY_CONSUME_PADDING: + return "SPDY_CONSUME_PADDING"; + case SPDY_IGNORE_REMAINING_PAYLOAD: + return "IGNORE_REMAINING_PAYLOAD"; + case SPDY_FORWARD_STREAM_FRAME: + return "FORWARD_STREAM_FRAME"; + case SPDY_CONTROL_FRAME_BEFORE_HEADER_BLOCK: + return "SPDY_CONTROL_FRAME_BEFORE_HEADER_BLOCK"; + case SPDY_CONTROL_FRAME_HEADER_BLOCK: + return "SPDY_CONTROL_FRAME_HEADER_BLOCK"; + case SPDY_GOAWAY_FRAME_PAYLOAD: + return "SPDY_GOAWAY_FRAME_PAYLOAD"; + case SPDY_SETTINGS_FRAME_HEADER: + return "SPDY_SETTINGS_FRAME_HEADER"; + case SPDY_SETTINGS_FRAME_PAYLOAD: + return "SPDY_SETTINGS_FRAME_PAYLOAD"; + case SPDY_ALTSVC_FRAME_PAYLOAD: + return "SPDY_ALTSVC_FRAME_PAYLOAD"; + } + return "UNKNOWN_STATE"; +} + +const char* Http2DecoderAdapter::SpdyFramerErrorToString( + SpdyFramerError spdy_framer_error) { + switch (spdy_framer_error) { + case SPDY_NO_ERROR: + return "NO_ERROR"; + case SPDY_INVALID_STREAM_ID: + return "INVALID_STREAM_ID"; + case SPDY_INVALID_CONTROL_FRAME: + return "INVALID_CONTROL_FRAME"; + case SPDY_CONTROL_PAYLOAD_TOO_LARGE: + return "CONTROL_PAYLOAD_TOO_LARGE"; + case SPDY_DECOMPRESS_FAILURE: + return "DECOMPRESS_FAILURE"; + case SPDY_INVALID_PADDING: + return "INVALID_PADDING"; + case SPDY_INVALID_DATA_FRAME_FLAGS: + return "INVALID_DATA_FRAME_FLAGS"; + case SPDY_UNEXPECTED_FRAME: + return "UNEXPECTED_FRAME"; + case SPDY_INTERNAL_FRAMER_ERROR: + return "INTERNAL_FRAMER_ERROR"; + case SPDY_INVALID_CONTROL_FRAME_SIZE: + return "INVALID_CONTROL_FRAME_SIZE"; + case SPDY_OVERSIZED_PAYLOAD: + return "OVERSIZED_PAYLOAD"; + case SPDY_HPACK_INDEX_VARINT_ERROR: + return "HPACK_INDEX_VARINT_ERROR"; + case SPDY_HPACK_NAME_LENGTH_VARINT_ERROR: + return "HPACK_NAME_LENGTH_VARINT_ERROR"; + case SPDY_HPACK_VALUE_LENGTH_VARINT_ERROR: + return "HPACK_VALUE_LENGTH_VARINT_ERROR"; + case SPDY_HPACK_NAME_TOO_LONG: + return "HPACK_NAME_TOO_LONG"; + case SPDY_HPACK_VALUE_TOO_LONG: + return "HPACK_VALUE_TOO_LONG"; + case SPDY_HPACK_NAME_HUFFMAN_ERROR: + return "HPACK_NAME_HUFFMAN_ERROR"; + case SPDY_HPACK_VALUE_HUFFMAN_ERROR: + return "HPACK_VALUE_HUFFMAN_ERROR"; + case SPDY_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE: + return "HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE"; + case SPDY_HPACK_INVALID_INDEX: + return "HPACK_INVALID_INDEX"; + case SPDY_HPACK_INVALID_NAME_INDEX: + return "HPACK_INVALID_NAME_INDEX"; + case SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED: + return "HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED"; + case SPDY_HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK: + return "HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK"; + case SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING: + return "HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING"; + case SPDY_HPACK_TRUNCATED_BLOCK: + return "HPACK_TRUNCATED_BLOCK"; + case SPDY_HPACK_FRAGMENT_TOO_LONG: + return "HPACK_FRAGMENT_TOO_LONG"; + case SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT: + return "HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT"; + case SPDY_STOP_PROCESSING: + return "STOP_PROCESSING"; + case LAST_ERROR: + return "UNKNOWN_ERROR"; + } + return "UNKNOWN_ERROR"; +} + +Http2DecoderAdapter::Http2DecoderAdapter() : frame_decoder_(this) { + QUICHE_DVLOG(1) << "Http2DecoderAdapter ctor"; + + CorruptFrameHeader(&frame_header_); + CorruptFrameHeader(&hpack_first_frame_header_); +} + +Http2DecoderAdapter::~Http2DecoderAdapter() = default; + +void Http2DecoderAdapter::set_visitor(SpdyFramerVisitorInterface* visitor) { + visitor_ = visitor; +} + +void Http2DecoderAdapter::set_debug_visitor( + SpdyFramerDebugVisitorInterface* debug_visitor) { + debug_visitor_ = debug_visitor; +} + +void Http2DecoderAdapter::set_extension_visitor( + ExtensionVisitorInterface* visitor) { + extension_ = visitor; +} + +size_t Http2DecoderAdapter::ProcessInput(const char* data, size_t len) { + size_t total_processed = 0; + while (len > 0 && spdy_state_ != SPDY_ERROR) { + // Process one at a time so that we update the adapter's internal + // state appropriately. + const size_t processed = ProcessInputFrame(data, len); + + // We had some data, and weren't in an error state, so should have + // processed/consumed at least one byte of it, even if we then ended up + // in an error state. + QUICHE_DCHECK(processed > 0) + << "processed=" << processed << " spdy_state_=" << spdy_state_ + << " spdy_framer_error_=" << spdy_framer_error_; + + data += processed; + len -= processed; + total_processed += processed; + if (processed == 0) { + break; + } + } + return total_processed; +} + +Http2DecoderAdapter::SpdyState Http2DecoderAdapter::state() const { + return spdy_state_; +} + +Http2DecoderAdapter::SpdyFramerError Http2DecoderAdapter::spdy_framer_error() + const { + return spdy_framer_error_; +} + +bool Http2DecoderAdapter::probable_http_response() const { + return latched_probable_http_response_; +} + +void Http2DecoderAdapter::StopProcessing() { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_STOP_PROCESSING, + "Ignoring further events on this connection."); +} + +void Http2DecoderAdapter::SetMaxFrameSize(size_t max_frame_size) { + max_frame_size_ = max_frame_size; + frame_decoder_.set_maximum_payload_size(max_frame_size); +} + +// =========================================================================== +// Implementations of the methods declared by Http2FrameDecoderListener. + +// Called once the common frame header has been decoded for any frame. +// This function is largely based on Http2DecoderAdapter::ValidateFrameHeader +// and some parts of Http2DecoderAdapter::ProcessCommonHeader. +bool Http2DecoderAdapter::OnFrameHeader(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnFrameHeader: " << header; + decoded_frame_header_ = true; + if (!latched_probable_http_response_) { + latched_probable_http_response_ = header.IsProbableHttpResponse(); + } + const uint8_t raw_frame_type = static_cast(header.type); + visitor()->OnCommonHeader(header.stream_id, header.payload_length, + raw_frame_type, header.flags); + if (has_expected_frame_type_ && header.type != expected_frame_type_) { + // Report an unexpected frame error and close the connection if we + // expect a known frame type (probably CONTINUATION) and receive an + // unknown frame. + QUICHE_VLOG(1) << "The framer was expecting to receive a " + << expected_frame_type_ + << " frame, but instead received an unknown frame of type " + << header.type; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_UNEXPECTED_FRAME, ""); + return false; + } + if (!IsSupportedHttp2FrameType(header.type)) { + if (extension_ != nullptr) { + // Unknown frames will be passed to the registered extension. + return true; + } + // In HTTP2 we ignore unknown frame types for extensibility, as long as + // the rest of the control frame header is valid. + // We rely on the visitor to check validity of stream_id. + bool valid_stream = + visitor()->OnUnknownFrame(header.stream_id, raw_frame_type); + if (!valid_stream) { + // Report an invalid frame error if the stream_id is not valid. + QUICHE_VLOG(1) << "Unknown control frame type " << header.type + << " received on invalid stream " << header.stream_id; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_CONTROL_FRAME, ""); + return false; + } else { + QUICHE_DVLOG(1) << "Ignoring unknown frame type " << header.type; + return true; + } + } + + SpdyFrameType frame_type = ToSpdyFrameType(header.type); + if (!IsValidHTTP2FrameStreamId(header.stream_id, frame_type)) { + QUICHE_VLOG(1) << "The framer received an invalid streamID of " + << header.stream_id << " for a frame of type " + << header.type; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_STREAM_ID, ""); + return false; + } + + if (has_expected_frame_type_ && header.type != expected_frame_type_) { + QUICHE_VLOG(1) << "Expected frame type " << expected_frame_type_ << ", not " + << header.type; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_UNEXPECTED_FRAME, ""); + return false; + } + + if (!has_expected_frame_type_ && + header.type == Http2FrameType::CONTINUATION) { + QUICHE_VLOG(1) << "Got CONTINUATION frame when not expected."; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_UNEXPECTED_FRAME, ""); + return false; + } + + if (header.type == Http2FrameType::DATA) { + // For some reason SpdyFramer still rejects invalid DATA frame flags. + uint8_t valid_flags = Http2FrameFlag::PADDED | Http2FrameFlag::END_STREAM; + if (header.HasAnyFlags(~valid_flags)) { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_DATA_FRAME_FLAGS, ""); + return false; + } + } + + return true; +} + +void Http2DecoderAdapter::OnDataStart(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnDataStart: " << header; + + if (IsOkToStartFrame(header) && HasRequiredStreamId(header)) { + frame_header_ = header; + has_frame_header_ = true; + visitor()->OnDataFrameHeader(header.stream_id, header.payload_length, + header.IsEndStream()); + } +} + +void Http2DecoderAdapter::OnDataPayload(const char* data, size_t len) { + QUICHE_DVLOG(1) << "OnDataPayload: len=" << len; + QUICHE_DCHECK(has_frame_header_); + QUICHE_DCHECK_EQ(frame_header_.type, Http2FrameType::DATA); + visitor()->OnStreamFrameData(frame_header().stream_id, data, len); +} + +void Http2DecoderAdapter::OnDataEnd() { + QUICHE_DVLOG(1) << "OnDataEnd"; + QUICHE_DCHECK(has_frame_header_); + QUICHE_DCHECK_EQ(frame_header_.type, Http2FrameType::DATA); + if (frame_header().IsEndStream()) { + visitor()->OnStreamEnd(frame_header().stream_id); + } + opt_pad_length_.reset(); +} + +void Http2DecoderAdapter::OnHeadersStart(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnHeadersStart: " << header; + if (IsOkToStartFrame(header) && HasRequiredStreamId(header)) { + frame_header_ = header; + has_frame_header_ = true; + if (header.HasPriority()) { + // Once we've got the priority fields, then we can report the arrival + // of this HEADERS frame. + on_headers_called_ = false; + return; + } + on_headers_called_ = true; + ReportReceiveCompressedFrame(header); + visitor()->OnHeaders(header.stream_id, header.payload_length, + kNotHasPriorityFields, + 0, // priority + 0, // parent_stream_id + false, // exclusive + header.IsEndStream(), header.IsEndHeaders()); + CommonStartHpackBlock(); + } +} + +void Http2DecoderAdapter::OnHeadersPriority( + const Http2PriorityFields& priority) { + QUICHE_DVLOG(1) << "OnHeadersPriority: " << priority; + QUICHE_DCHECK(has_frame_header_); + QUICHE_DCHECK_EQ(frame_type(), Http2FrameType::HEADERS) << frame_header_; + QUICHE_DCHECK(frame_header_.HasPriority()); + QUICHE_DCHECK(!on_headers_called_); + on_headers_called_ = true; + ReportReceiveCompressedFrame(frame_header_); + if (!visitor()) { + QUICHE_BUG(spdy_bug_1_1) + << "Visitor is nullptr, handling priority in headers failed." + << " priority:" << priority << " frame_header:" << frame_header_; + return; + } + visitor()->OnHeaders( + frame_header_.stream_id, frame_header_.payload_length, kHasPriorityFields, + priority.weight, priority.stream_dependency, priority.is_exclusive, + frame_header_.IsEndStream(), frame_header_.IsEndHeaders()); + CommonStartHpackBlock(); +} + +void Http2DecoderAdapter::OnHpackFragment(const char* data, size_t len) { + QUICHE_DVLOG(1) << "OnHpackFragment: len=" << len; + on_hpack_fragment_called_ = true; + auto* decoder = GetHpackDecoder(); + if (!decoder->HandleControlFrameHeadersData(data, len)) { + SetSpdyErrorAndNotify(HpackDecodingErrorToSpdyFramerError(decoder->error()), + decoder->detailed_error()); + return; + } +} + +void Http2DecoderAdapter::OnHeadersEnd() { + QUICHE_DVLOG(1) << "OnHeadersEnd"; + CommonHpackFragmentEnd(); + opt_pad_length_.reset(); +} + +void Http2DecoderAdapter::OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority) { + QUICHE_DVLOG(1) << "OnPriorityFrame: " << header + << "; priority: " << priority; + if (IsOkToStartFrame(header) && HasRequiredStreamId(header)) { + visitor()->OnPriority(header.stream_id, priority.stream_dependency, + priority.weight, priority.is_exclusive); + } +} + +void Http2DecoderAdapter::OnContinuationStart(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnContinuationStart: " << header; + if (IsOkToStartFrame(header) && HasRequiredStreamId(header)) { + QUICHE_DCHECK(has_hpack_first_frame_header_); + if (header.stream_id != hpack_first_frame_header_.stream_id) { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_UNEXPECTED_FRAME, ""); + return; + } + frame_header_ = header; + has_frame_header_ = true; + ReportReceiveCompressedFrame(header); + visitor()->OnContinuation(header.stream_id, header.payload_length, + header.IsEndHeaders()); + } +} + +void Http2DecoderAdapter::OnContinuationEnd() { + QUICHE_DVLOG(1) << "OnContinuationEnd"; + CommonHpackFragmentEnd(); +} + +void Http2DecoderAdapter::OnPadLength(size_t trailing_length) { + QUICHE_DVLOG(1) << "OnPadLength: " << trailing_length; + opt_pad_length_ = trailing_length; + QUICHE_DCHECK_LT(trailing_length, 256u); + if (frame_header_.type == Http2FrameType::DATA) { + visitor()->OnStreamPadLength(stream_id(), trailing_length); + } +} + +void Http2DecoderAdapter::OnPadding(const char* /*padding*/, + size_t skipped_length) { + QUICHE_DVLOG(1) << "OnPadding: " << skipped_length; + if (frame_header_.type == Http2FrameType::DATA) { + visitor()->OnStreamPadding(stream_id(), skipped_length); + } else { + MaybeAnnounceEmptyFirstHpackFragment(); + } +} + +void Http2DecoderAdapter::OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode http2_error_code) { + QUICHE_DVLOG(1) << "OnRstStream: " << header << "; code=" << http2_error_code; + if (IsOkToStartFrame(header) && HasRequiredStreamId(header)) { + SpdyErrorCode error_code = + ParseErrorCode(static_cast(http2_error_code)); + visitor()->OnRstStream(header.stream_id, error_code); + } +} + +void Http2DecoderAdapter::OnSettingsStart(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnSettingsStart: " << header; + if (IsOkToStartFrame(header) && HasRequiredStreamIdZero(header)) { + frame_header_ = header; + has_frame_header_ = true; + visitor()->OnSettings(); + } +} + +void Http2DecoderAdapter::OnSetting(const Http2SettingFields& setting_fields) { + QUICHE_DVLOG(1) << "OnSetting: " << setting_fields; + const auto parameter = static_cast(setting_fields.parameter); + visitor()->OnSetting(parameter, setting_fields.value); + SpdyKnownSettingsId known_id; + if (extension_ != nullptr && !spdy::ParseSettingsId(parameter, &known_id)) { + extension_->OnSetting(parameter, setting_fields.value); + } +} + +void Http2DecoderAdapter::OnSettingsEnd() { + QUICHE_DVLOG(1) << "OnSettingsEnd"; + visitor()->OnSettingsEnd(); +} + +void Http2DecoderAdapter::OnSettingsAck(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnSettingsAck: " << header; + if (IsOkToStartFrame(header) && HasRequiredStreamIdZero(header)) { + visitor()->OnSettingsAck(); + } +} + +void Http2DecoderAdapter::OnPushPromiseStart( + const Http2FrameHeader& header, const Http2PushPromiseFields& promise, + size_t total_padding_length) { + QUICHE_DVLOG(1) << "OnPushPromiseStart: " << header + << "; promise: " << promise + << "; total_padding_length: " << total_padding_length; + if (IsOkToStartFrame(header) && HasRequiredStreamId(header)) { + if (promise.promised_stream_id == 0) { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_CONTROL_FRAME, ""); + return; + } + frame_header_ = header; + has_frame_header_ = true; + ReportReceiveCompressedFrame(header); + visitor()->OnPushPromise(header.stream_id, promise.promised_stream_id, + header.IsEndHeaders()); + CommonStartHpackBlock(); + } +} + +void Http2DecoderAdapter::OnPushPromiseEnd() { + QUICHE_DVLOG(1) << "OnPushPromiseEnd"; + CommonHpackFragmentEnd(); + opt_pad_length_.reset(); +} + +void Http2DecoderAdapter::OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_DVLOG(1) << "OnPing: " << header << "; ping: " << ping; + if (IsOkToStartFrame(header) && HasRequiredStreamIdZero(header)) { + visitor()->OnPing(ToSpdyPingId(ping), false); + } +} + +void Http2DecoderAdapter::OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) { + QUICHE_DVLOG(1) << "OnPingAck: " << header << "; ping: " << ping; + if (IsOkToStartFrame(header) && HasRequiredStreamIdZero(header)) { + visitor()->OnPing(ToSpdyPingId(ping), true); + } +} + +void Http2DecoderAdapter::OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) { + QUICHE_DVLOG(1) << "OnGoAwayStart: " << header << "; goaway: " << goaway; + if (IsOkToStartFrame(header) && HasRequiredStreamIdZero(header)) { + frame_header_ = header; + has_frame_header_ = true; + SpdyErrorCode error_code = + ParseErrorCode(static_cast(goaway.error_code)); + visitor()->OnGoAway(goaway.last_stream_id, error_code); + } +} + +void Http2DecoderAdapter::OnGoAwayOpaqueData(const char* data, size_t len) { + QUICHE_DVLOG(1) << "OnGoAwayOpaqueData: len=" << len; + visitor()->OnGoAwayFrameData(data, len); +} + +void Http2DecoderAdapter::OnGoAwayEnd() { + QUICHE_DVLOG(1) << "OnGoAwayEnd"; + visitor()->OnGoAwayFrameData(nullptr, 0); +} + +void Http2DecoderAdapter::OnWindowUpdate(const Http2FrameHeader& header, + uint32_t increment) { + QUICHE_DVLOG(1) << "OnWindowUpdate: " << header + << "; increment=" << increment; + if (IsOkToStartFrame(header)) { + visitor()->OnWindowUpdate(header.stream_id, increment); + } +} + +// Per RFC7838, an ALTSVC frame on stream 0 with origin_length == 0, or one on +// a stream other than stream 0 with origin_length != 0 MUST be ignored. All +// frames are decoded by Http2DecoderAdapter, and it is left to the consumer +// (listener) to implement this behavior. +void Http2DecoderAdapter::OnAltSvcStart(const Http2FrameHeader& header, + size_t origin_length, + size_t value_length) { + QUICHE_DVLOG(1) << "OnAltSvcStart: " << header + << "; origin_length: " << origin_length + << "; value_length: " << value_length; + if (!IsOkToStartFrame(header)) { + return; + } + frame_header_ = header; + has_frame_header_ = true; + alt_svc_origin_.clear(); + alt_svc_value_.clear(); +} + +void Http2DecoderAdapter::OnAltSvcOriginData(const char* data, size_t len) { + QUICHE_DVLOG(1) << "OnAltSvcOriginData: len=" << len; + alt_svc_origin_.append(data, len); +} + +// Called when decoding the Alt-Svc-Field-Value of an ALTSVC; +// the field is uninterpreted. +void Http2DecoderAdapter::OnAltSvcValueData(const char* data, size_t len) { + QUICHE_DVLOG(1) << "OnAltSvcValueData: len=" << len; + alt_svc_value_.append(data, len); +} + +void Http2DecoderAdapter::OnAltSvcEnd() { + QUICHE_DVLOG(1) << "OnAltSvcEnd: origin.size(): " << alt_svc_origin_.size() + << "; value.size(): " << alt_svc_value_.size(); + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + if (!SpdyAltSvcWireFormat::ParseHeaderFieldValue(alt_svc_value_, + &altsvc_vector)) { + QUICHE_DLOG(ERROR) << "SpdyAltSvcWireFormat::ParseHeaderFieldValue failed."; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_CONTROL_FRAME, ""); + return; + } + visitor()->OnAltSvc(frame_header_.stream_id, alt_svc_origin_, altsvc_vector); + // We assume that ALTSVC frames are rare, so get rid of the storage. + alt_svc_origin_.clear(); + alt_svc_origin_.shrink_to_fit(); + alt_svc_value_.clear(); + alt_svc_value_.shrink_to_fit(); +} + +void Http2DecoderAdapter::OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) { + QUICHE_DVLOG(1) << "OnPriorityUpdateStart: " << header + << "; prioritized_stream_id: " + << priority_update.prioritized_stream_id; + if (IsOkToStartFrame(header) && HasRequiredStreamIdZero(header) && + HasRequiredStreamId(priority_update.prioritized_stream_id)) { + frame_header_ = header; + has_frame_header_ = true; + prioritized_stream_id_ = priority_update.prioritized_stream_id; + } +} + +void Http2DecoderAdapter::OnPriorityUpdatePayload(const char* data, + size_t len) { + QUICHE_DVLOG(1) << "OnPriorityUpdatePayload: len=" << len; + priority_field_value_.append(data, len); +} + +void Http2DecoderAdapter::OnPriorityUpdateEnd() { + QUICHE_DVLOG(1) << "OnPriorityUpdateEnd: priority_field_value.size(): " + << priority_field_value_.size(); + visitor()->OnPriorityUpdate(prioritized_stream_id_, priority_field_value_); + priority_field_value_.clear(); +} + +void Http2DecoderAdapter::OnUnknownStart(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnUnknownStart: " << header; + if (IsOkToStartFrame(header)) { + frame_header_ = header; + has_frame_header_ = true; + const uint8_t type = static_cast(header.type); + const uint8_t flags = static_cast(header.flags); + if (extension_ != nullptr) { + handling_extension_payload_ = extension_->OnFrameHeader( + header.stream_id, header.payload_length, type, flags); + } + visitor()->OnUnknownFrameStart(header.stream_id, header.payload_length, + type, flags); + } +} + +void Http2DecoderAdapter::OnUnknownPayload(const char* data, size_t len) { + if (handling_extension_payload_) { + extension_->OnFramePayload(data, len); + } else { + QUICHE_DVLOG(1) << "OnUnknownPayload: len=" << len; + } + visitor()->OnUnknownFramePayload(frame_header_.stream_id, + absl::string_view(data, len)); +} + +void Http2DecoderAdapter::OnUnknownEnd() { + QUICHE_DVLOG(1) << "OnUnknownEnd"; + handling_extension_payload_ = false; +} + +void Http2DecoderAdapter::OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) { + QUICHE_DVLOG(1) << "OnPaddingTooLong: " << header + << "; missing_length: " << missing_length; + if (header.type == Http2FrameType::DATA) { + if (header.payload_length == 0) { + QUICHE_DCHECK_EQ(1u, missing_length); + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_DATA_FRAME_FLAGS, ""); + return; + } + visitor()->OnStreamPadding(header.stream_id, 1); + } + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_PADDING, ""); +} + +void Http2DecoderAdapter::OnFrameSizeError(const Http2FrameHeader& header) { + QUICHE_DVLOG(1) << "OnFrameSizeError: " << header; + if (header.payload_length > max_frame_size_) { + if (header.type == Http2FrameType::DATA) { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_OVERSIZED_PAYLOAD, ""); + } else { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_CONTROL_PAYLOAD_TOO_LARGE, + ""); + } + return; + } + switch (header.type) { + case Http2FrameType::GOAWAY: + case Http2FrameType::ALTSVC: + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_CONTROL_FRAME, ""); + break; + default: + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_CONTROL_FRAME_SIZE, + ""); + } +} + +// Decodes the input up to the next frame boundary (i.e. at most one frame), +// stopping early if an error is detected. +size_t Http2DecoderAdapter::ProcessInputFrame(const char* data, size_t len) { + QUICHE_DCHECK_NE(spdy_state_, SpdyState::SPDY_ERROR); + DecodeBuffer db(data, len); + DecodeStatus status = frame_decoder_.DecodeFrame(&db); + if (spdy_state_ != SpdyState::SPDY_ERROR) { + DetermineSpdyState(status); + } else { + QUICHE_VLOG(1) << "ProcessInputFrame spdy_framer_error_=" + << SpdyFramerErrorToString(spdy_framer_error_); + if (spdy_framer_error_ == SpdyFramerError::SPDY_INVALID_PADDING && + has_frame_header_ && frame_type() != Http2FrameType::DATA) { + // spdy_framer_test checks that all of the available frame payload + // has been consumed, so do that. + size_t total = remaining_total_payload(); + if (total <= frame_header().payload_length) { + size_t avail = db.MinLengthRemaining(total); + QUICHE_VLOG(1) << "Skipping past " << avail << " bytes, of " << total + << " total remaining in the frame's payload."; + db.AdvanceCursor(avail); + } else { + QUICHE_BUG(spdy_bug_1_2) + << "Total remaining (" << total + << ") should not be greater than the payload length; " + << frame_header(); + } + } + } + return db.Offset(); +} + +// After decoding, determine the next SpdyState. Only called if the current +// state is NOT SpdyState::SPDY_ERROR (i.e. if none of the callback methods +// detected an error condition), because otherwise we assume that the callback +// method has set spdy_framer_error_ appropriately. +void Http2DecoderAdapter::DetermineSpdyState(DecodeStatus status) { + QUICHE_DCHECK_EQ(spdy_framer_error_, SPDY_NO_ERROR); + QUICHE_DCHECK(!HasError()) << spdy_framer_error_; + switch (status) { + case DecodeStatus::kDecodeDone: + QUICHE_DVLOG(1) << "ProcessInputFrame -> DecodeStatus::kDecodeDone"; + ResetBetweenFrames(); + break; + case DecodeStatus::kDecodeInProgress: + QUICHE_DVLOG(1) << "ProcessInputFrame -> DecodeStatus::kDecodeInProgress"; + if (decoded_frame_header_) { + if (IsDiscardingPayload()) { + set_spdy_state(SpdyState::SPDY_IGNORE_REMAINING_PAYLOAD); + } else if (has_frame_header_ && frame_type() == Http2FrameType::DATA) { + if (IsReadingPaddingLength()) { + set_spdy_state(SpdyState::SPDY_READ_DATA_FRAME_PADDING_LENGTH); + } else if (IsSkippingPadding()) { + set_spdy_state(SpdyState::SPDY_CONSUME_PADDING); + } else { + set_spdy_state(SpdyState::SPDY_FORWARD_STREAM_FRAME); + } + } else { + set_spdy_state(SpdyState::SPDY_CONTROL_FRAME_PAYLOAD); + } + } else { + set_spdy_state(SpdyState::SPDY_READING_COMMON_HEADER); + } + break; + case DecodeStatus::kDecodeError: + QUICHE_VLOG(1) << "ProcessInputFrame -> DecodeStatus::kDecodeError"; + if (IsDiscardingPayload()) { + if (remaining_total_payload() == 0) { + // Push the Http2FrameDecoder out of state kDiscardPayload now + // since doing so requires no input. + DecodeBuffer tmp("", 0); + DecodeStatus decode_status = frame_decoder_.DecodeFrame(&tmp); + if (decode_status != DecodeStatus::kDecodeDone) { + QUICHE_BUG(spdy_bug_1_3) + << "Expected to be done decoding the frame, not " + << decode_status; + SetSpdyErrorAndNotify(SPDY_INTERNAL_FRAMER_ERROR, ""); + } else if (spdy_framer_error_ != SPDY_NO_ERROR) { + QUICHE_BUG(spdy_bug_1_4) + << "Expected to have no error, not " + << SpdyFramerErrorToString(spdy_framer_error_); + } else { + ResetBetweenFrames(); + } + } else { + set_spdy_state(SpdyState::SPDY_IGNORE_REMAINING_PAYLOAD); + } + } else { + SetSpdyErrorAndNotify(SPDY_INVALID_CONTROL_FRAME, ""); + } + break; + } +} + +void Http2DecoderAdapter::ResetBetweenFrames() { + CorruptFrameHeader(&frame_header_); + decoded_frame_header_ = false; + has_frame_header_ = false; + set_spdy_state(SpdyState::SPDY_READY_FOR_FRAME); +} + +void Http2DecoderAdapter::set_spdy_state(SpdyState v) { + QUICHE_DVLOG(2) << "set_spdy_state(" << StateToString(v) << ")"; + spdy_state_ = v; +} + +void Http2DecoderAdapter::SetSpdyErrorAndNotify(SpdyFramerError error, + std::string detailed_error) { + if (HasError()) { + QUICHE_DCHECK_EQ(spdy_state_, SpdyState::SPDY_ERROR); + } else { + QUICHE_VLOG(2) << "SetSpdyErrorAndNotify(" << SpdyFramerErrorToString(error) + << ")"; + QUICHE_DCHECK_NE(error, SpdyFramerError::SPDY_NO_ERROR); + spdy_framer_error_ = error; + set_spdy_state(SpdyState::SPDY_ERROR); + frame_decoder_.set_listener(&no_op_listener_); + visitor()->OnError(error, detailed_error); + } +} + +bool Http2DecoderAdapter::HasError() const { + if (spdy_state_ == SpdyState::SPDY_ERROR) { + QUICHE_DCHECK_NE(spdy_framer_error(), SpdyFramerError::SPDY_NO_ERROR); + return true; + } else { + QUICHE_DCHECK_EQ(spdy_framer_error(), SpdyFramerError::SPDY_NO_ERROR); + return false; + } +} + +const Http2FrameHeader& Http2DecoderAdapter::frame_header() const { + QUICHE_DCHECK(has_frame_header_); + return frame_header_; +} + +uint32_t Http2DecoderAdapter::stream_id() const { + return frame_header().stream_id; +} + +Http2FrameType Http2DecoderAdapter::frame_type() const { + return frame_header().type; +} + +size_t Http2DecoderAdapter::remaining_total_payload() const { + QUICHE_DCHECK(has_frame_header_); + size_t remaining = frame_decoder_.remaining_payload(); + if (IsPaddable(frame_type()) && frame_header_.IsPadded()) { + remaining += frame_decoder_.remaining_padding(); + } + return remaining; +} + +bool Http2DecoderAdapter::IsReadingPaddingLength() { + bool result = frame_header_.IsPadded() && !opt_pad_length_; + QUICHE_DVLOG(2) << "Http2DecoderAdapter::IsReadingPaddingLength: " << result; + return result; +} +bool Http2DecoderAdapter::IsSkippingPadding() { + bool result = frame_header_.IsPadded() && opt_pad_length_ && + frame_decoder_.remaining_payload() == 0 && + frame_decoder_.remaining_padding() > 0; + QUICHE_DVLOG(2) << "Http2DecoderAdapter::IsSkippingPadding: " << result; + return result; +} +bool Http2DecoderAdapter::IsDiscardingPayload() { + bool result = decoded_frame_header_ && frame_decoder_.IsDiscardingPayload(); + QUICHE_DVLOG(2) << "Http2DecoderAdapter::IsDiscardingPayload: " << result; + return result; +} +// Called from OnXyz or OnXyzStart methods to decide whether it is OK to +// handle the callback. +bool Http2DecoderAdapter::IsOkToStartFrame(const Http2FrameHeader& header) { + QUICHE_DVLOG(3) << "IsOkToStartFrame"; + if (HasError()) { + QUICHE_VLOG(2) << "HasError()"; + return false; + } + QUICHE_DCHECK(!has_frame_header_); + if (has_expected_frame_type_ && header.type != expected_frame_type_) { + QUICHE_VLOG(1) << "Expected frame type " << expected_frame_type_ << ", not " + << header.type; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_UNEXPECTED_FRAME, ""); + return false; + } + + return true; +} + +bool Http2DecoderAdapter::HasRequiredStreamId(uint32_t stream_id) { + QUICHE_DVLOG(3) << "HasRequiredStreamId: " << stream_id; + if (HasError()) { + QUICHE_VLOG(2) << "HasError()"; + return false; + } + if (stream_id != 0) { + return true; + } + QUICHE_VLOG(1) << "Stream Id is required, but zero provided"; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_STREAM_ID, ""); + return false; +} + +bool Http2DecoderAdapter::HasRequiredStreamId(const Http2FrameHeader& header) { + return HasRequiredStreamId(header.stream_id); +} + +bool Http2DecoderAdapter::HasRequiredStreamIdZero(uint32_t stream_id) { + QUICHE_DVLOG(3) << "HasRequiredStreamIdZero: " << stream_id; + if (HasError()) { + QUICHE_VLOG(2) << "HasError()"; + return false; + } + if (stream_id == 0) { + return true; + } + QUICHE_VLOG(1) << "Stream Id was not zero, as required: " << stream_id; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INVALID_STREAM_ID, ""); + return false; +} + +bool Http2DecoderAdapter::HasRequiredStreamIdZero( + const Http2FrameHeader& header) { + return HasRequiredStreamIdZero(header.stream_id); +} + +void Http2DecoderAdapter::ReportReceiveCompressedFrame( + const Http2FrameHeader& header) { + if (debug_visitor() != nullptr) { + size_t total = header.payload_length + Http2FrameHeader::EncodedSize(); + debug_visitor()->OnReceiveCompressedFrame( + header.stream_id, ToSpdyFrameType(header.type), total); + } +} + +HpackDecoderAdapter* Http2DecoderAdapter::GetHpackDecoder() { + if (hpack_decoder_ == nullptr) { + hpack_decoder_ = std::make_unique(); + } + return hpack_decoder_.get(); +} + +void Http2DecoderAdapter::CommonStartHpackBlock() { + QUICHE_DVLOG(1) << "CommonStartHpackBlock"; + QUICHE_DCHECK(!has_hpack_first_frame_header_); + if (!frame_header_.IsEndHeaders()) { + hpack_first_frame_header_ = frame_header_; + has_hpack_first_frame_header_ = true; + } else { + CorruptFrameHeader(&hpack_first_frame_header_); + } + on_hpack_fragment_called_ = false; + SpdyHeadersHandlerInterface* handler = + visitor()->OnHeaderFrameStart(stream_id()); + if (handler == nullptr) { + QUICHE_BUG(spdy_bug_1_5) << "visitor_->OnHeaderFrameStart returned nullptr"; + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_INTERNAL_FRAMER_ERROR, ""); + return; + } + GetHpackDecoder()->HandleControlFrameHeadersStart(handler); +} + +// SpdyFramer calls HandleControlFrameHeadersData even if there are zero +// fragment bytes in the first frame, so do the same. +void Http2DecoderAdapter::MaybeAnnounceEmptyFirstHpackFragment() { + if (!on_hpack_fragment_called_) { + OnHpackFragment(nullptr, 0); + QUICHE_DCHECK(on_hpack_fragment_called_); + } +} + +void Http2DecoderAdapter::CommonHpackFragmentEnd() { + QUICHE_DVLOG(1) << "CommonHpackFragmentEnd: stream_id=" << stream_id(); + if (HasError()) { + QUICHE_VLOG(1) << "HasError(), returning"; + return; + } + QUICHE_DCHECK(has_frame_header_); + MaybeAnnounceEmptyFirstHpackFragment(); + if (frame_header_.IsEndHeaders()) { + QUICHE_DCHECK_EQ(has_hpack_first_frame_header_, + frame_type() == Http2FrameType::CONTINUATION) + << frame_header(); + has_expected_frame_type_ = false; + auto* decoder = GetHpackDecoder(); + if (decoder->HandleControlFrameHeadersComplete()) { + visitor()->OnHeaderFrameEnd(stream_id()); + } else { + SetSpdyErrorAndNotify( + HpackDecodingErrorToSpdyFramerError(decoder->error()), ""); + return; + } + const Http2FrameHeader& first = frame_type() == Http2FrameType::CONTINUATION + ? hpack_first_frame_header_ + : frame_header_; + if (first.type == Http2FrameType::HEADERS && first.IsEndStream()) { + visitor()->OnStreamEnd(first.stream_id); + } + has_hpack_first_frame_header_ = false; + CorruptFrameHeader(&hpack_first_frame_header_); + } else { + QUICHE_DCHECK(has_hpack_first_frame_header_); + has_expected_frame_type_ = true; + expected_frame_type_ = Http2FrameType::CONTINUATION; + } +} + +} // namespace http2 + +namespace spdy { + +bool SpdyFramerVisitorInterface::OnGoAwayFrameData(const char* /*goaway_data*/, + size_t /*len*/) { + return true; +} + +} // namespace spdy diff --git a/quiche/spdy/core/http2_frame_decoder_adapter.h b/quiche/spdy/core/http2_frame_decoder_adapter.h new file mode 100644 index 000000000000..adfb14c121f4 --- /dev/null +++ b/quiche/spdy/core/http2_frame_decoder_adapter.h @@ -0,0 +1,564 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HTTP2_FRAME_DECODER_ADAPTER_H_ +#define QUICHE_SPDY_CORE_HTTP2_FRAME_DECODER_ADAPTER_H_ + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "quiche/http2/decoder/http2_frame_decoder.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_decoder_adapter.h" +#include "quiche/spdy/core/hpack/hpack_header_table.h" +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace spdy { + +class SpdyFramerVisitorInterface; +class ExtensionVisitorInterface; + +} // namespace spdy + +// TODO(dahollings): Perform various renames/moves suggested in cl/164660364. + +namespace http2 { + +// Adapts SpdyFramer interface to use Http2FrameDecoder. +class QUICHE_EXPORT Http2DecoderAdapter + : public http2::Http2FrameDecoderListener { + public: + // HTTP2 states. + enum SpdyState { + SPDY_ERROR, + SPDY_READY_FOR_FRAME, // Framer is ready for reading the next frame. + SPDY_FRAME_COMPLETE, // Framer has finished reading a frame, need to reset. + SPDY_READING_COMMON_HEADER, + SPDY_CONTROL_FRAME_PAYLOAD, + SPDY_READ_DATA_FRAME_PADDING_LENGTH, + SPDY_CONSUME_PADDING, + SPDY_IGNORE_REMAINING_PAYLOAD, + SPDY_FORWARD_STREAM_FRAME, + SPDY_CONTROL_FRAME_BEFORE_HEADER_BLOCK, + SPDY_CONTROL_FRAME_HEADER_BLOCK, + SPDY_GOAWAY_FRAME_PAYLOAD, + SPDY_SETTINGS_FRAME_HEADER, + SPDY_SETTINGS_FRAME_PAYLOAD, + SPDY_ALTSVC_FRAME_PAYLOAD, + SPDY_EXTENSION_FRAME_PAYLOAD, + }; + + // Framer error codes. + enum SpdyFramerError { + SPDY_NO_ERROR, + SPDY_INVALID_STREAM_ID, // Stream ID is invalid + SPDY_INVALID_CONTROL_FRAME, // Control frame is mal-formatted. + SPDY_CONTROL_PAYLOAD_TOO_LARGE, // Control frame payload was too large. + SPDY_DECOMPRESS_FAILURE, // There was an error decompressing. + SPDY_INVALID_PADDING, // HEADERS or DATA frame padding invalid + SPDY_INVALID_DATA_FRAME_FLAGS, // Data frame has invalid flags. + SPDY_UNEXPECTED_FRAME, // Frame received out of order. + SPDY_INTERNAL_FRAMER_ERROR, // SpdyFramer was used incorrectly. + SPDY_INVALID_CONTROL_FRAME_SIZE, // Control frame not sized to spec + SPDY_OVERSIZED_PAYLOAD, // Payload size was too large + + // HttpDecoder or HttpDecoderAdapter error. + // See HpackDecodingError for description of each error code. + SPDY_HPACK_INDEX_VARINT_ERROR, + SPDY_HPACK_NAME_LENGTH_VARINT_ERROR, + SPDY_HPACK_VALUE_LENGTH_VARINT_ERROR, + SPDY_HPACK_NAME_TOO_LONG, + SPDY_HPACK_VALUE_TOO_LONG, + SPDY_HPACK_NAME_HUFFMAN_ERROR, + SPDY_HPACK_VALUE_HUFFMAN_ERROR, + SPDY_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE, + SPDY_HPACK_INVALID_INDEX, + SPDY_HPACK_INVALID_NAME_INDEX, + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED, + SPDY_HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK, + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING, + SPDY_HPACK_TRUNCATED_BLOCK, + SPDY_HPACK_FRAGMENT_TOO_LONG, + SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT, + + // Set if the visitor no longer wishes to receive events for this + // connection. + SPDY_STOP_PROCESSING, + + LAST_ERROR, // Must be the last entry in the enum. + }; + + // For debugging. + static const char* StateToString(int state); + static const char* SpdyFramerErrorToString(SpdyFramerError spdy_framer_error); + + Http2DecoderAdapter(); + ~Http2DecoderAdapter() override; + + Http2DecoderAdapter(const Http2DecoderAdapter&) = delete; + Http2DecoderAdapter& operator=(const Http2DecoderAdapter&) = delete; + + // Set callbacks to be called from the framer. A visitor must be set, or + // else the framer will likely crash. It is acceptable for the visitor + // to do nothing. If this is called multiple times, only the last visitor + // will be used. + void set_visitor(spdy::SpdyFramerVisitorInterface* visitor); + spdy::SpdyFramerVisitorInterface* visitor() const { return visitor_; } + + // Set extension callbacks to be called from the framer or decoder. Optional. + // If called multiple times, only the last visitor will be used. + void set_extension_visitor(spdy::ExtensionVisitorInterface* visitor); + spdy::ExtensionVisitorInterface* extension_visitor() const { + return extension_; + } + + // Set debug callbacks to be called from the framer. The debug visitor is + // completely optional and need not be set in order for normal operation. + // If this is called multiple times, only the last visitor will be used. + void set_debug_visitor(spdy::SpdyFramerDebugVisitorInterface* debug_visitor); + spdy::SpdyFramerDebugVisitorInterface* debug_visitor() const { + return debug_visitor_; + } + + // Decode the |len| bytes of encoded HTTP/2 starting at |*data|. Returns + // the number of bytes consumed. It is safe to pass more bytes in than + // may be consumed. Should process (or otherwise buffer) as much as + // available. + // + // If the input contains the entirety of a DATA frame payload, GOAWAY frame + // Additional Debug Data field, or unknown frame payload, then the + // corresponding SpdyFramerVisitorInterface::OnStreamFrameData(), + // OnGoAwayFrameData(), or ExtensionVisitorInterface::OnFramePayload() method + // is guaranteed to be called exactly once, with the entire payload or field. + size_t ProcessInput(const char* data, size_t len); + + // Current state of the decoder. + SpdyState state() const; + + // Current error code (NO_ERROR if state != ERROR). + SpdyFramerError spdy_framer_error() const; + + // Has any frame header looked like the start of an HTTP/1.1 (or earlier) + // response? Used to detect if a backend/server that we sent a request to + // has responded with an HTTP/1.1 (or earlier) response. + bool probable_http_response() const; + + spdy::HpackDecoderAdapter* GetHpackDecoder(); + const spdy::HpackDecoderAdapter* GetHpackDecoder() const { + return hpack_decoder_.get(); + } + + bool HasError() const; + + // A visitor may call this method to indicate it no longer wishes to receive + // events for this connection. + void StopProcessing(); + + // Sets the limit on the size of received HTTP/2 frame payloads. Corresponds + // to SETTINGS_MAX_FRAME_SIZE as advertised to the peer. + void SetMaxFrameSize(size_t max_frame_size); + + private: + bool OnFrameHeader(const Http2FrameHeader& header) override; + void OnDataStart(const Http2FrameHeader& header) override; + void OnDataPayload(const char* data, size_t len) override; + void OnDataEnd() override; + void OnHeadersStart(const Http2FrameHeader& header) override; + void OnHeadersPriority(const Http2PriorityFields& priority) override; + void OnHpackFragment(const char* data, size_t len) override; + void OnHeadersEnd() override; + void OnPriorityFrame(const Http2FrameHeader& header, + const Http2PriorityFields& priority) override; + void OnContinuationStart(const Http2FrameHeader& header) override; + void OnContinuationEnd() override; + void OnPadLength(size_t trailing_length) override; + void OnPadding(const char* padding, size_t skipped_length) override; + void OnRstStream(const Http2FrameHeader& header, + Http2ErrorCode http2_error_code) override; + void OnSettingsStart(const Http2FrameHeader& header) override; + void OnSetting(const Http2SettingFields& setting_fields) override; + void OnSettingsEnd() override; + void OnSettingsAck(const Http2FrameHeader& header) override; + void OnPushPromiseStart(const Http2FrameHeader& header, + const Http2PushPromiseFields& promise, + size_t total_padding_length) override; + void OnPushPromiseEnd() override; + void OnPing(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnPingAck(const Http2FrameHeader& header, + const Http2PingFields& ping) override; + void OnGoAwayStart(const Http2FrameHeader& header, + const Http2GoAwayFields& goaway) override; + void OnGoAwayOpaqueData(const char* data, size_t len) override; + void OnGoAwayEnd() override; + void OnWindowUpdate(const Http2FrameHeader& header, + uint32_t increment) override; + void OnAltSvcStart(const Http2FrameHeader& header, size_t origin_length, + size_t value_length) override; + void OnAltSvcOriginData(const char* data, size_t len) override; + void OnAltSvcValueData(const char* data, size_t len) override; + void OnAltSvcEnd() override; + void OnPriorityUpdateStart( + const Http2FrameHeader& header, + const Http2PriorityUpdateFields& priority_update) override; + void OnPriorityUpdatePayload(const char* data, size_t len) override; + void OnPriorityUpdateEnd() override; + void OnUnknownStart(const Http2FrameHeader& header) override; + void OnUnknownPayload(const char* data, size_t len) override; + void OnUnknownEnd() override; + void OnPaddingTooLong(const Http2FrameHeader& header, + size_t missing_length) override; + void OnFrameSizeError(const Http2FrameHeader& header) override; + + size_t ProcessInputFrame(const char* data, size_t len); + + void DetermineSpdyState(DecodeStatus status); + void ResetBetweenFrames(); + + void set_spdy_state(SpdyState v); + + void SetSpdyErrorAndNotify(SpdyFramerError error, std::string detailed_error); + + const Http2FrameHeader& frame_header() const; + + uint32_t stream_id() const; + Http2FrameType frame_type() const; + + size_t remaining_total_payload() const; + + bool IsReadingPaddingLength(); + bool IsSkippingPadding(); + bool IsDiscardingPayload(); + // Called from OnXyz or OnXyzStart methods to decide whether it is OK to + // handle the callback. + bool IsOkToStartFrame(const Http2FrameHeader& header); + bool HasRequiredStreamId(uint32_t stream_id); + + bool HasRequiredStreamId(const Http2FrameHeader& header); + + bool HasRequiredStreamIdZero(uint32_t stream_id); + + bool HasRequiredStreamIdZero(const Http2FrameHeader& header); + + void ReportReceiveCompressedFrame(const Http2FrameHeader& header); + + void CommonStartHpackBlock(); + + // SpdyFramer calls HandleControlFrameHeadersData even if there are zero + // fragment bytes in the first frame, so do the same. + void MaybeAnnounceEmptyFirstHpackFragment(); + void CommonHpackFragmentEnd(); + + // The most recently decoded frame header; invalid after we reached the end + // of that frame. + Http2FrameHeader frame_header_; + + // If decoding an HPACK block that is split across multiple frames, this holds + // the frame header of the HEADERS or PUSH_PROMISE that started the block. + Http2FrameHeader hpack_first_frame_header_; + + // Amount of trailing padding. Currently used just as an indicator of whether + // OnPadLength has been called. + absl::optional opt_pad_length_; + + // Temporary buffers for the AltSvc fields. + std::string alt_svc_origin_; + std::string alt_svc_value_; + + // Temporary buffers for PRIORITY_UPDATE fields. + uint32_t prioritized_stream_id_ = 0; + std::string priority_field_value_; + + // Listener used if we transition to an error state; the listener ignores all + // the callbacks. + Http2FrameDecoderNoOpListener no_op_listener_; + + spdy::SpdyFramerVisitorInterface* visitor_ = nullptr; + spdy::SpdyFramerDebugVisitorInterface* debug_visitor_ = nullptr; + + // If non-null, unknown frames and settings are passed to the extension. + spdy::ExtensionVisitorInterface* extension_ = nullptr; + + // The HPACK decoder to be used for this adapter. User is responsible for + // clearing if the adapter is to be used for another connection. + std::unique_ptr hpack_decoder_; + + // The HTTP/2 frame decoder. + Http2FrameDecoder frame_decoder_; + + // Next frame type expected. Currently only used for CONTINUATION frames, + // but could be used for detecting whether the first frame is a SETTINGS + // frame. + // TODO(jamessynge): Provide means to indicate that decoder should require + // SETTINGS frame as the first frame. + Http2FrameType expected_frame_type_; + + // Attempt to duplicate the SpdyState and SpdyFramerError values that + // SpdyFramer sets. Values determined by getting tests to pass. + SpdyState spdy_state_ = SpdyState::SPDY_READY_FOR_FRAME; + SpdyFramerError spdy_framer_error_ = SpdyFramerError::SPDY_NO_ERROR; + + // The limit on the size of received HTTP/2 payloads as specified in the + // SETTINGS_MAX_FRAME_SIZE advertised to peer. + size_t max_frame_size_ = spdy::kHttp2DefaultFramePayloadLimit; + + // Has OnFrameHeader been called? + bool decoded_frame_header_ = false; + + // Have we recorded an Http2FrameHeader for the current frame? + // We only do so if the decoder will make multiple callbacks for + // the frame; for example, for PING frames we don't make record + // the frame header, but for ALTSVC we do. + bool has_frame_header_ = false; + + // Have we recorded an Http2FrameHeader for the current HPACK block? + // True only for multi-frame HPACK blocks. + bool has_hpack_first_frame_header_ = false; + + // Has OnHeaders() already been called for current HEADERS block? Only + // meaningful between OnHeadersStart and OnHeadersPriority. + bool on_headers_called_ = false; + + // Has OnHpackFragment() already been called for current HPACK block? + // SpdyFramer will pass an empty buffer to the HPACK decoder if a HEADERS + // or PUSH_PROMISE has no HPACK data in it (e.g. a HEADERS frame with only + // padding). Detect that condition and replicate the behavior using this + // field. + bool on_hpack_fragment_called_ = false; + + // Have we seen a frame header that appears to be an HTTP/1 response? + bool latched_probable_http_response_ = false; + + // Is expected_frame_type_ set? + bool has_expected_frame_type_ = false; + + // Is the current frame payload destined for |extension_|? + bool handling_extension_payload_ = false; +}; + +} // namespace http2 + +namespace spdy { + +// Http2DecoderAdapter will use the given visitor implementing this +// interface to deliver event callbacks as frames are decoded. +// +// Control frames that contain HTTP2 header blocks (HEADER, and PUSH_PROMISE) +// are processed in fashion that allows the decompressed header block to be +// delivered in chunks to the visitor. +// The following steps are followed: +// 1. OnHeaders, or OnPushPromise is called. +// 2. OnHeaderFrameStart is called; visitor is expected to return an instance +// of SpdyHeadersHandlerInterface that will receive the header key-value +// pairs. +// 3. OnHeaderFrameEnd is called, indicating that the full header block has +// been delivered for the control frame. +// During step 2, if the visitor is not interested in accepting the header data, +// it should return a no-op implementation of SpdyHeadersHandlerInterface. +class QUICHE_EXPORT SpdyFramerVisitorInterface { + public: + virtual ~SpdyFramerVisitorInterface() {} + + // Called if an error is detected in the SpdyFrame protocol. + virtual void OnError(http2::Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) = 0; + + // Called when the common header for a frame is received. Validating the + // common header occurs in later processing. + virtual void OnCommonHeader(SpdyStreamId /*stream_id*/, size_t /*length*/, + uint8_t /*type*/, uint8_t /*flags*/) {} + + // Called when a data frame header is received. The frame's data payload will + // be provided via subsequent calls to OnStreamFrameData(). + // |stream_id| The stream receiving data. + // |length| The length of the payload in this DATA frame. Includes the length + // of the data itself and potential padding. + // |fin| Whether the END_STREAM flag is set in the frame header. + virtual void OnDataFrameHeader(SpdyStreamId stream_id, size_t length, + bool fin) = 0; + + // Called when data is received. + // |stream_id| The stream receiving data. + // |data| A buffer containing the data received. + // |len| The length of the data buffer. + virtual void OnStreamFrameData(SpdyStreamId stream_id, const char* data, + size_t len) = 0; + + // Called when the other side has finished sending data on this stream. + // |stream_id| The stream that was receiving data. + virtual void OnStreamEnd(SpdyStreamId stream_id) = 0; + + // Called when padding length field is received on a DATA frame. + // |stream_id| The stream receiving data. + // |value| The value of the padding length field. + virtual void OnStreamPadLength(SpdyStreamId /*stream_id*/, size_t /*value*/) { + } + + // Called when padding is received (the trailing octets, not pad_len field) on + // a DATA frame. + // |stream_id| The stream receiving data. + // |len| The number of padding octets. + virtual void OnStreamPadding(SpdyStreamId stream_id, size_t len) = 0; + + // Called just before processing the payload of a frame containing header + // data. Should return an implementation of SpdyHeadersHandlerInterface that + // will receive headers for stream |stream_id|. The caller will not take + // ownership of the headers handler. The same instance should remain live + // and be returned for all header frames comprising a logical header block + // (i.e. until OnHeaderFrameEnd() is called). + virtual SpdyHeadersHandlerInterface* OnHeaderFrameStart( + SpdyStreamId stream_id) = 0; + + // Called after processing the payload of a frame containing header data. + virtual void OnHeaderFrameEnd(SpdyStreamId stream_id) = 0; + + // Called when a RST_STREAM frame has been parsed. + virtual void OnRstStream(SpdyStreamId stream_id, + SpdyErrorCode error_code) = 0; + + // Called when a SETTINGS frame is received. + virtual void OnSettings() {} + + // Called when a complete setting within a SETTINGS frame has been parsed. + // Note that |id| may or may not be a SETTINGS ID defined in the HTTP/2 spec. + virtual void OnSetting(SpdySettingsId id, uint32_t value) = 0; + + // Called when a SETTINGS frame is received with the ACK flag set. + virtual void OnSettingsAck() {} + + // Called before and after parsing SETTINGS id and value tuples. + virtual void OnSettingsEnd() = 0; + + // Called when a PING frame has been parsed. + virtual void OnPing(SpdyPingId unique_id, bool is_ack) = 0; + + // Called when a GOAWAY frame has been parsed. + virtual void OnGoAway(SpdyStreamId last_accepted_stream_id, + SpdyErrorCode error_code) = 0; + + // Called when a HEADERS frame is received. + // Note that header block data is not included. See OnHeaderFrameStart(). + // |stream_id| The stream receiving the header. + // |payload_length| The length of the payload in this HEADERS frame. Includes + // the length of the encoded header block and potential padding. + // |has_priority| Whether or not the headers frame included a priority value, + // and stream dependency info. + // |weight| If |has_priority| is true, then weight (in the range [1, 256]) + // for the receiving stream, otherwise 0. + // |parent_stream_id| If |has_priority| is true the parent stream of the + // receiving stream, else 0. + // |exclusive| If |has_priority| is true the exclusivity of dependence on the + // parent stream, else false. + // |fin| Whether the END_STREAM flag is set in the frame header. + // |end| False if HEADERs frame is to be followed by a CONTINUATION frame, + // or true if not. + virtual void OnHeaders(SpdyStreamId stream_id, size_t payload_length, + bool has_priority, int weight, + SpdyStreamId parent_stream_id, bool exclusive, + bool fin, bool end) = 0; + + // Called when a WINDOW_UPDATE frame has been parsed. + virtual void OnWindowUpdate(SpdyStreamId stream_id, + int delta_window_size) = 0; + + // Called when a goaway frame opaque data is available. + // |goaway_data| A buffer containing the opaque GOAWAY data chunk received. + // |len| The length of the header data buffer. A length of zero indicates + // that the header data block has been completely sent. + // When this function returns true the visitor indicates that it accepted + // all of the data. Returning false indicates that that an error has + // occurred while processing the data. Default implementation returns true. + virtual bool OnGoAwayFrameData(const char* goaway_data, size_t len); + + // Called when a PUSH_PROMISE frame is received. + // Note that header block data is not included. See OnHeaderFrameStart(). + virtual void OnPushPromise(SpdyStreamId stream_id, + SpdyStreamId promised_stream_id, bool end) = 0; + + // Called when a CONTINUATION frame is received. + // Note that header block data is not included. See OnHeaderFrameStart(). + // |stream_id| The stream receiving the CONTINUATION. + // |payload_length| The length of the payload in this CONTINUATION frame. + // |end| True if this CONTINUATION frame will not be followed by another + // CONTINUATION frame. + virtual void OnContinuation(SpdyStreamId stream_id, size_t payload_length, + bool end) = 0; + + // Called when an ALTSVC frame has been parsed. + virtual void OnAltSvc( + SpdyStreamId /*stream_id*/, absl::string_view /*origin*/, + const SpdyAltSvcWireFormat::AlternativeServiceVector& /*altsvc_vector*/) { + } + + // Called when a PRIORITY frame is received. + // |stream_id| The stream to update the priority of. + // |parent_stream_id| The parent stream of |stream_id|. + // |weight| Stream weight, in the range [1, 256]. + // |exclusive| Whether |stream_id| should be an only child of + // |parent_stream_id|. + virtual void OnPriority(SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive) = 0; + + // Called when a PRIORITY_UPDATE frame is received on stream 0. + // |prioritized_stream_id| is the Prioritized Stream ID and + // |priority_field_value| is the Priority Field Value + // parsed from the frame payload. + virtual void OnPriorityUpdate(SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) = 0; + + // Called when a frame type we don't recognize is received. + // Return true if this appears to be a valid extension frame, false otherwise. + // We distinguish between extension frames and nonsense by checking + // whether the stream id is valid. + // TODO(b/239060116): Remove this callback altogether. + virtual bool OnUnknownFrame(SpdyStreamId stream_id, uint8_t frame_type) = 0; + + // Called when the common header for a non-standard frame is received. If the + // `length` is nonzero, the frame's payload will be provided via subsequent + // calls to OnUnknownFramePayload(). + // |stream_id| The stream receiving the non-standard frame. + // |length| The length of the payload of the frame. + // |type| The type of the frame. This type is non-standard. + // |flags| The flags of the frame. + virtual void OnUnknownFrameStart(SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) = 0; + + // Called when a non-empty payload chunk for a non-standard frame is received. + // The payload for a single frame may be delivered as multiple calls to + // OnUnknownFramePayload(). Since the length field is passed in + // OnUnknownFrameStart(), there is no explicit indication of the end of the + // frame payload. + // |stream_id| The stream receiving the non-standard frame. + // |payload| The payload chunk, which will be non-empty. + virtual void OnUnknownFramePayload(SpdyStreamId stream_id, + absl::string_view payload) = 0; +}; + +class QUICHE_EXPORT ExtensionVisitorInterface { + public: + virtual ~ExtensionVisitorInterface() {} + + // Called when non-standard SETTINGS are received. + virtual void OnSetting(SpdySettingsId id, uint32_t value) = 0; + + // Called when non-standard frames are received. + virtual bool OnFrameHeader(SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) = 0; + + // The payload for a single frame may be delivered as multiple calls to + // OnFramePayload. Since the length field is passed in OnFrameHeader, there is + // no explicit indication of the end of the frame payload. + virtual void OnFramePayload(const char* data, size_t len) = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HTTP2_FRAME_DECODER_ADAPTER_H_ diff --git a/quiche/spdy/core/http2_header_block.cc b/quiche/spdy/core/http2_header_block.cc new file mode 100644 index 000000000000..8dddae50a143 --- /dev/null +++ b/quiche/spdy/core/http2_header_block.cc @@ -0,0 +1,315 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/http2_header_block.h" + +#include + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { +namespace { + +// By default, linked_hash_map's internal map allocates space for 100 map +// buckets on construction, which is larger than necessary. Standard library +// unordered map implementations use a list of prime numbers to set the bucket +// count for a particular capacity. |kInitialMapBuckets| is chosen to reduce +// memory usage for small header blocks, at the cost of having to rehash for +// large header blocks. +const size_t kInitialMapBuckets = 11; + +const char kCookieKey[] = "cookie"; +const char kNullSeparator = 0; + +absl::string_view SeparatorForKey(absl::string_view key) { + if (key == kCookieKey) { + static absl::string_view cookie_separator = "; "; + return cookie_separator; + } else { + return absl::string_view(&kNullSeparator, 1); + } +} + +} // namespace + +Http2HeaderBlock::HeaderValue::HeaderValue(Http2HeaderStorage* storage, + absl::string_view key, + absl::string_view initial_value) + : storage_(storage), + fragments_({initial_value}), + pair_({key, {}}), + size_(initial_value.size()), + separator_size_(SeparatorForKey(key).size()) {} + +Http2HeaderBlock::HeaderValue::HeaderValue(HeaderValue&& other) + : storage_(other.storage_), + fragments_(std::move(other.fragments_)), + pair_(std::move(other.pair_)), + size_(other.size_), + separator_size_(other.separator_size_) {} + +Http2HeaderBlock::HeaderValue& Http2HeaderBlock::HeaderValue::operator=( + HeaderValue&& other) { + storage_ = other.storage_; + fragments_ = std::move(other.fragments_); + pair_ = std::move(other.pair_); + size_ = other.size_; + separator_size_ = other.separator_size_; + return *this; +} + +void Http2HeaderBlock::HeaderValue::set_storage(Http2HeaderStorage* storage) { + storage_ = storage; +} + +Http2HeaderBlock::HeaderValue::~HeaderValue() = default; + +absl::string_view Http2HeaderBlock::HeaderValue::ConsolidatedValue() const { + if (fragments_.empty()) { + return absl::string_view(); + } + if (fragments_.size() > 1) { + fragments_ = { + storage_->WriteFragments(fragments_, SeparatorForKey(pair_.first))}; + } + return fragments_[0]; +} + +void Http2HeaderBlock::HeaderValue::Append(absl::string_view fragment) { + size_ += (fragment.size() + separator_size_); + fragments_.push_back(fragment); +} + +const std::pair& +Http2HeaderBlock::HeaderValue::as_pair() const { + pair_.second = ConsolidatedValue(); + return pair_; +} + +Http2HeaderBlock::iterator::iterator(MapType::const_iterator it) : it_(it) {} + +Http2HeaderBlock::iterator::iterator(const iterator& other) = default; + +Http2HeaderBlock::iterator::~iterator() = default; + +Http2HeaderBlock::ValueProxy::ValueProxy( + Http2HeaderBlock* block, Http2HeaderBlock::MapType::iterator lookup_result, + const absl::string_view key, size_t* spdy_header_block_value_size) + : block_(block), + lookup_result_(lookup_result), + key_(key), + spdy_header_block_value_size_(spdy_header_block_value_size), + valid_(true) {} + +Http2HeaderBlock::ValueProxy::ValueProxy(ValueProxy&& other) + : block_(other.block_), + lookup_result_(other.lookup_result_), + key_(other.key_), + spdy_header_block_value_size_(other.spdy_header_block_value_size_), + valid_(true) { + other.valid_ = false; +} + +Http2HeaderBlock::ValueProxy& Http2HeaderBlock::ValueProxy::operator=( + Http2HeaderBlock::ValueProxy&& other) { + block_ = other.block_; + lookup_result_ = other.lookup_result_; + key_ = other.key_; + valid_ = true; + other.valid_ = false; + spdy_header_block_value_size_ = other.spdy_header_block_value_size_; + return *this; +} + +Http2HeaderBlock::ValueProxy::~ValueProxy() { + // If the ValueProxy is destroyed while lookup_result_ == block_->end(), + // the assignment operator was never used, and the block's Http2HeaderStorage + // can reclaim the memory used by the key. This makes lookup-only access to + // Http2HeaderBlock through operator[] memory-neutral. + if (valid_ && lookup_result_ == block_->map_.end()) { + block_->storage_.Rewind(key_); + } +} + +Http2HeaderBlock::ValueProxy& Http2HeaderBlock::ValueProxy::operator=( + absl::string_view value) { + *spdy_header_block_value_size_ += value.size(); + Http2HeaderStorage* storage = &block_->storage_; + if (lookup_result_ == block_->map_.end()) { + QUICHE_DVLOG(1) << "Inserting: (" << key_ << ", " << value << ")"; + lookup_result_ = + block_->map_ + .emplace(std::make_pair( + key_, HeaderValue(storage, key_, storage->Write(value)))) + .first; + } else { + QUICHE_DVLOG(1) << "Updating key: " << key_ << " with value: " << value; + *spdy_header_block_value_size_ -= lookup_result_->second.SizeEstimate(); + lookup_result_->second = HeaderValue(storage, key_, storage->Write(value)); + } + return *this; +} + +bool Http2HeaderBlock::ValueProxy::operator==(absl::string_view value) const { + if (lookup_result_ == block_->map_.end()) { + return false; + } else { + return value == lookup_result_->second.value(); + } +} + +std::string Http2HeaderBlock::ValueProxy::as_string() const { + if (lookup_result_ == block_->map_.end()) { + return ""; + } else { + return std::string(lookup_result_->second.value()); + } +} + +Http2HeaderBlock::Http2HeaderBlock() : map_(kInitialMapBuckets) {} + +Http2HeaderBlock::Http2HeaderBlock(Http2HeaderBlock&& other) + : map_(kInitialMapBuckets) { + map_.swap(other.map_); + storage_ = std::move(other.storage_); + for (auto& p : map_) { + p.second.set_storage(&storage_); + } + key_size_ = other.key_size_; + value_size_ = other.value_size_; +} + +Http2HeaderBlock::~Http2HeaderBlock() = default; + +Http2HeaderBlock& Http2HeaderBlock::operator=(Http2HeaderBlock&& other) { + map_.swap(other.map_); + storage_ = std::move(other.storage_); + for (auto& p : map_) { + p.second.set_storage(&storage_); + } + key_size_ = other.key_size_; + value_size_ = other.value_size_; + return *this; +} + +Http2HeaderBlock Http2HeaderBlock::Clone() const { + Http2HeaderBlock copy; + for (const auto& p : *this) { + copy.AppendHeader(p.first, p.second); + } + return copy; +} + +bool Http2HeaderBlock::operator==(const Http2HeaderBlock& other) const { + return size() == other.size() && std::equal(begin(), end(), other.begin()); +} + +bool Http2HeaderBlock::operator!=(const Http2HeaderBlock& other) const { + return !(operator==(other)); +} + +std::string Http2HeaderBlock::DebugString() const { + if (empty()) { + return "{}"; + } + + std::string output = "\n{\n"; + for (auto it = begin(); it != end(); ++it) { + absl::StrAppend(&output, " ", it->first, " ", it->second, "\n"); + } + absl::StrAppend(&output, "}\n"); + return output; +} + +void Http2HeaderBlock::erase(absl::string_view key) { + auto iter = map_.find(key); + if (iter != map_.end()) { + QUICHE_DVLOG(1) << "Erasing header with name: " << key; + key_size_ -= key.size(); + value_size_ -= iter->second.SizeEstimate(); + map_.erase(iter); + } +} + +void Http2HeaderBlock::clear() { + key_size_ = 0; + value_size_ = 0; + map_.clear(); + storage_.Clear(); +} + +void Http2HeaderBlock::insert(const Http2HeaderBlock::value_type& value) { + // TODO(birenroy): Write new value in place of old value, if it fits. + value_size_ += value.second.size(); + + auto iter = map_.find(value.first); + if (iter == map_.end()) { + QUICHE_DVLOG(1) << "Inserting: (" << value.first << ", " << value.second + << ")"; + AppendHeader(value.first, value.second); + } else { + QUICHE_DVLOG(1) << "Updating key: " << iter->first + << " with value: " << value.second; + value_size_ -= iter->second.SizeEstimate(); + iter->second = + HeaderValue(&storage_, iter->first, storage_.Write(value.second)); + } +} + +Http2HeaderBlock::ValueProxy Http2HeaderBlock::operator[]( + const absl::string_view key) { + QUICHE_DVLOG(2) << "Operator[] saw key: " << key; + absl::string_view out_key; + auto iter = map_.find(key); + if (iter == map_.end()) { + // We write the key first, to assure that the ValueProxy has a + // reference to a valid absl::string_view in its operator=. + out_key = WriteKey(key); + QUICHE_DVLOG(2) << "Key written as: " << std::hex + << static_cast(key.data()) << ", " << std::dec + << key.size(); + } else { + out_key = iter->first; + } + return ValueProxy(this, iter, out_key, &value_size_); +} + +void Http2HeaderBlock::AppendValueOrAddHeader(const absl::string_view key, + const absl::string_view value) { + value_size_ += value.size(); + + auto iter = map_.find(key); + if (iter == map_.end()) { + QUICHE_DVLOG(1) << "Inserting: (" << key << ", " << value << ")"; + + AppendHeader(key, value); + return; + } + QUICHE_DVLOG(1) << "Updating key: " << iter->first + << "; appending value: " << value; + value_size_ += SeparatorForKey(key).size(); + iter->second.Append(storage_.Write(value)); +} + +void Http2HeaderBlock::AppendHeader(const absl::string_view key, + const absl::string_view value) { + auto backed_key = WriteKey(key); + map_.emplace(std::make_pair( + backed_key, HeaderValue(&storage_, backed_key, storage_.Write(value)))); +} + +absl::string_view Http2HeaderBlock::WriteKey(const absl::string_view key) { + key_size_ += key.size(); + return storage_.Write(key); +} + +size_t Http2HeaderBlock::bytes_allocated() const { + return storage_.bytes_allocated(); +} + +} // namespace spdy diff --git a/quiche/spdy/core/http2_header_block.h b/quiche/spdy/core/http2_header_block.h new file mode 100644 index 000000000000..e0da8f55d3c7 --- /dev/null +++ b/quiche/spdy/core/http2_header_block.h @@ -0,0 +1,291 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_H_ +#define QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_H_ + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/quiche_linked_hash_map.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/http2_header_storage.h" + +namespace spdy { + +namespace test { +class Http2HeaderBlockPeer; +class ValueProxyPeer; +} // namespace test + +#ifndef SPDY_HEADER_DEBUG +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) +#define SPDY_HEADER_DEBUG 1 +#else // !defined(NDEBUG) || defined(ADDRESS_SANITIZER) +#define SPDY_HEADER_DEBUG 0 +#endif // !defined(NDEBUG) || defined(ADDRESS_SANITIZER) +#endif // SPDY_HEADER_DEBUG + +// This class provides a key-value map that can be used to store SPDY header +// names and values. This data structure preserves insertion order. +// +// Under the hood, this data structure uses large, contiguous blocks of memory +// to store names and values. Lookups may be performed with absl::string_view +// keys, and values are returned as absl::string_views (via ValueProxy, below). +// Value absl::string_views are valid as long as the Http2HeaderBlock exists; +// allocated memory is never freed until Http2HeaderBlock's destruction. +// +// This implementation does not make much of an effort to minimize wasted space. +// It's expected that keys are rarely deleted from a Http2HeaderBlock. +class QUICHE_EXPORT Http2HeaderBlock { + private: + // Stores a list of value fragments that can be joined later with a + // key-dependent separator. + class QUICHE_EXPORT HeaderValue { + public: + HeaderValue(Http2HeaderStorage* storage, absl::string_view key, + absl::string_view initial_value); + + // Moves are allowed. + HeaderValue(HeaderValue&& other); + HeaderValue& operator=(HeaderValue&& other); + + void set_storage(Http2HeaderStorage* storage); + + // Copies are not. + HeaderValue(const HeaderValue& other) = delete; + HeaderValue& operator=(const HeaderValue& other) = delete; + + ~HeaderValue(); + + // Consumes at most |fragment.size()| bytes of memory. + void Append(absl::string_view fragment); + + absl::string_view value() const { return as_pair().second; } + const std::pair& as_pair() const; + + // Size estimate including separators. Used when keys are erased from + // Http2HeaderBlock. + size_t SizeEstimate() const { return size_; } + + private: + // May allocate a large contiguous region of memory to hold the concatenated + // fragments and separators. + absl::string_view ConsolidatedValue() const; + + mutable Http2HeaderStorage* storage_; + mutable std::vector fragments_; + // The first element is the key; the second is the consolidated value. + mutable std::pair pair_; + size_t size_ = 0; + size_t separator_size_ = 0; + }; + + typedef quiche::QuicheLinkedHashMap + MapType; + + public: + typedef std::pair value_type; + + // Provides iteration over a sequence of std::pair, even though the underlying MapType::value_type is + // different. Dereferencing the iterator will result in memory allocation for + // multi-value headers. + class QUICHE_EXPORT iterator { + public: + // The following type definitions fulfill the requirements for iterator + // implementations. + typedef std::pair value_type; + typedef value_type& reference; + typedef value_type* pointer; + typedef std::forward_iterator_tag iterator_category; + typedef MapType::iterator::difference_type difference_type; + + // In practice, this iterator only offers access to const value_type. + typedef const value_type& const_reference; + typedef const value_type* const_pointer; + + explicit iterator(MapType::const_iterator it); + iterator(const iterator& other); + ~iterator(); + + // This will result in memory allocation if the value consists of multiple + // fragments. + const_reference operator*() const { +#if SPDY_HEADER_DEBUG + QUICHE_CHECK(!dereference_forbidden_); +#endif // SPDY_HEADER_DEBUG + return it_->second.as_pair(); + } + + const_pointer operator->() const { return &(this->operator*()); } + bool operator==(const iterator& it) const { return it_ == it.it_; } + bool operator!=(const iterator& it) const { return !(*this == it); } + + iterator& operator++() { + it_++; + return *this; + } + + iterator operator++(int) { + auto ret = *this; + this->operator++(); + return ret; + } + +#if SPDY_HEADER_DEBUG + void forbid_dereference() { dereference_forbidden_ = true; } +#endif // SPDY_HEADER_DEBUG + + private: + MapType::const_iterator it_; +#if SPDY_HEADER_DEBUG + bool dereference_forbidden_ = false; +#endif // SPDY_HEADER_DEBUG + }; + typedef iterator const_iterator; + + Http2HeaderBlock(); + Http2HeaderBlock(const Http2HeaderBlock& other) = delete; + Http2HeaderBlock(Http2HeaderBlock&& other); + ~Http2HeaderBlock(); + + Http2HeaderBlock& operator=(const Http2HeaderBlock& other) = delete; + Http2HeaderBlock& operator=(Http2HeaderBlock&& other); + Http2HeaderBlock Clone() const; + + bool operator==(const Http2HeaderBlock& other) const; + bool operator!=(const Http2HeaderBlock& other) const; + + // Provides a human readable multi-line representation of the stored header + // keys and values. + std::string DebugString() const; + + iterator begin() { return wrap_iterator(map_.begin()); } + iterator end() { return wrap_iterator(map_.end()); } + const_iterator begin() const { return wrap_const_iterator(map_.begin()); } + const_iterator end() const { return wrap_const_iterator(map_.end()); } + bool empty() const { return map_.empty(); } + size_t size() const { return map_.size(); } + iterator find(absl::string_view key) { return wrap_iterator(map_.find(key)); } + const_iterator find(absl::string_view key) const { + return wrap_const_iterator(map_.find(key)); + } + bool contains(absl::string_view key) const { return find(key) != end(); } + void erase(absl::string_view key); + + // Clears both our MapType member and the memory used to hold headers. + void clear(); + + // The next few methods copy data into our backing storage. + + // If key already exists in the block, replaces the value of that key. Else + // adds a new header to the end of the block. + void insert(const value_type& value); + + // If a header with the key is already present, then append the value to the + // existing header value, NUL ("\0") separated unless the key is cookie, in + // which case the separator is "; ". + // If there is no such key, a new header with the key and value is added. + void AppendValueOrAddHeader(const absl::string_view key, + const absl::string_view value); + + // This object provides automatic conversions that allow Http2HeaderBlock to + // be nearly a drop-in replacement for + // SpdyLinkedHashMap. + // It reads data from or writes data to a Http2HeaderStorage. + class QUICHE_EXPORT ValueProxy { + public: + ~ValueProxy(); + + // Moves are allowed. + ValueProxy(ValueProxy&& other); + ValueProxy& operator=(ValueProxy&& other); + + // Copies are not. + ValueProxy(const ValueProxy& other) = delete; + ValueProxy& operator=(const ValueProxy& other) = delete; + + // Assignment modifies the underlying Http2HeaderBlock. + ValueProxy& operator=(absl::string_view value); + + // Provides easy comparison against absl::string_view. + bool operator==(absl::string_view value) const; + + std::string as_string() const; + + private: + friend class Http2HeaderBlock; + friend class test::ValueProxyPeer; + + ValueProxy(Http2HeaderBlock* block, + Http2HeaderBlock::MapType::iterator lookup_result, + const absl::string_view key, + size_t* spdy_header_block_value_size); + + Http2HeaderBlock* block_; + Http2HeaderBlock::MapType::iterator lookup_result_; + absl::string_view key_; + size_t* spdy_header_block_value_size_; + bool valid_; + }; + + // Allows either lookup or mutation of the value associated with a key. + ABSL_MUST_USE_RESULT ValueProxy operator[](const absl::string_view key); + + size_t TotalBytesUsed() const { return key_size_ + value_size_; } + + private: + friend class test::Http2HeaderBlockPeer; + + inline iterator wrap_iterator(MapType::const_iterator inner_iterator) const { +#if SPDY_HEADER_DEBUG + iterator outer_iterator(inner_iterator); + if (inner_iterator == map_.end()) { + outer_iterator.forbid_dereference(); + } + return outer_iterator; +#else // SPDY_HEADER_DEBUG + return iterator(inner_iterator); +#endif // SPDY_HEADER_DEBUG + } + + inline const_iterator wrap_const_iterator( + MapType::const_iterator inner_iterator) const { +#if SPDY_HEADER_DEBUG + const_iterator outer_iterator(inner_iterator); + if (inner_iterator == map_.end()) { + outer_iterator.forbid_dereference(); + } + return outer_iterator; +#else // SPDY_HEADER_DEBUG + return iterator(inner_iterator); +#endif // SPDY_HEADER_DEBUG + } + + void AppendHeader(const absl::string_view key, const absl::string_view value); + absl::string_view WriteKey(const absl::string_view key); + size_t bytes_allocated() const; + + // absl::string_views held by |map_| point to memory owned by |storage_|. + MapType map_; + Http2HeaderStorage storage_; + + size_t key_size_ = 0; + size_t value_size_ = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_H_ diff --git a/quiche/spdy/core/http2_header_block_hpack_listener.h b/quiche/spdy/core/http2_header_block_hpack_listener.h new file mode 100644 index 000000000000..2733a855cb11 --- /dev/null +++ b/quiche/spdy/core/http2_header_block_hpack_listener.h @@ -0,0 +1,49 @@ +#ifndef QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_HPACK_LISTENER_H_ +#define QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_HPACK_LISTENER_H_ + +#include "absl/strings/string_view.h" +#include "quiche/http2/hpack/decoder/hpack_decoder_listener.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace spdy { + +// This class simply gathers the key-value pairs emitted by an HpackDecoder in +// a Http2HeaderBlock. +class QUICHE_EXPORT Http2HeaderBlockHpackListener + : public http2::HpackDecoderListener { + public: + Http2HeaderBlockHpackListener() {} + + void OnHeaderListStart() override { + header_block_.clear(); + hpack_error_ = false; + } + + void OnHeader(const std::string& name, const std::string& value) override { + header_block_.AppendValueOrAddHeader(name, value); + } + + void OnHeaderListEnd() override {} + + void OnHeaderErrorDetected(absl::string_view error_message) override { + QUICHE_VLOG(1) << error_message; + hpack_error_ = true; + } + + Http2HeaderBlock release_header_block() { + Http2HeaderBlock block = std::move(header_block_); + header_block_ = {}; + return block; + } + bool hpack_error() const { return hpack_error_; } + + private: + Http2HeaderBlock header_block_; + bool hpack_error_ = false; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_HPACK_LISTENER_H_ diff --git a/quiche/spdy/core/http2_header_block_test.cc b/quiche/spdy/core/http2_header_block_test.cc new file mode 100644 index 000000000000..1ad045f16cfc --- /dev/null +++ b/quiche/spdy/core/http2_header_block_test.cc @@ -0,0 +1,295 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/http2_header_block.h" + +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +using ::testing::ElementsAre; + +namespace spdy { +namespace test { + +class ValueProxyPeer { + public: + static absl::string_view key(Http2HeaderBlock::ValueProxy* p) { + return p->key_; + } +}; + +std::pair Pair(absl::string_view k, + absl::string_view v) { + return std::make_pair(k, v); +} + +// This test verifies that Http2HeaderBlock behaves correctly when empty. +TEST(Http2HeaderBlockTest, EmptyBlock) { + Http2HeaderBlock block; + EXPECT_TRUE(block.empty()); + EXPECT_EQ(0u, block.size()); + EXPECT_EQ(block.end(), block.find("foo")); + EXPECT_FALSE(block.contains("foo")); + EXPECT_TRUE(block.end() == block.begin()); + + // Should have no effect. + block.erase("bar"); +} + +TEST(Http2HeaderBlockTest, KeyMemoryReclaimedOnLookup) { + Http2HeaderBlock block; + absl::string_view copied_key1; + { + auto proxy1 = block["some key name"]; + copied_key1 = ValueProxyPeer::key(&proxy1); + } + absl::string_view copied_key2; + { + auto proxy2 = block["some other key name"]; + copied_key2 = ValueProxyPeer::key(&proxy2); + } + // Because proxy1 was never used to modify the block, the memory used for the + // key could be reclaimed and used for the second call to operator[]. + // Therefore, we expect the pointers of the two absl::string_views to be + // equal. + EXPECT_EQ(copied_key1.data(), copied_key2.data()); + + { + auto proxy1 = block["some key name"]; + block["some other key name"] = "some value"; + } + // Nothing should blow up when proxy1 is destructed, and we should be able to + // modify and access the Http2HeaderBlock. + block["key"] = "value"; + EXPECT_EQ("value", block["key"]); + EXPECT_EQ("some value", block["some other key name"]); + EXPECT_TRUE(block.find("some key name") == block.end()); +} + +// This test verifies that headers can be set in a variety of ways. +TEST(Http2HeaderBlockTest, AddHeaders) { + Http2HeaderBlock block; + block["foo"] = std::string(300, 'x'); + block["bar"] = "baz"; + block["qux"] = "qux1"; + block["qux"] = "qux2"; + block.insert(std::make_pair("key", "value")); + + EXPECT_EQ(Pair("foo", std::string(300, 'x')), *block.find("foo")); + EXPECT_EQ("baz", block["bar"]); + std::string qux("qux"); + EXPECT_EQ("qux2", block[qux]); + ASSERT_NE(block.end(), block.find("key")); + ASSERT_TRUE(block.contains("key")); + EXPECT_EQ(Pair("key", "value"), *block.find("key")); + + block.erase("key"); + EXPECT_EQ(block.end(), block.find("key")); +} + +// This test verifies that Http2HeaderBlock can be copied using Clone(). +TEST(Http2HeaderBlockTest, CopyBlocks) { + Http2HeaderBlock block1; + block1["foo"] = std::string(300, 'x'); + block1["bar"] = "baz"; + block1.insert(std::make_pair("qux", "qux1")); + + Http2HeaderBlock block2 = block1.Clone(); + Http2HeaderBlock block3(block1.Clone()); + + EXPECT_EQ(block1, block2); + EXPECT_EQ(block1, block3); +} + +TEST(Http2HeaderBlockTest, Equality) { + // Test equality and inequality operators. + Http2HeaderBlock block1; + block1["foo"] = "bar"; + + Http2HeaderBlock block2; + block2["foo"] = "bar"; + + Http2HeaderBlock block3; + block3["baz"] = "qux"; + + EXPECT_EQ(block1, block2); + EXPECT_NE(block1, block3); + + block2["baz"] = "qux"; + EXPECT_NE(block1, block2); +} + +Http2HeaderBlock ReturnTestHeaderBlock() { + Http2HeaderBlock block; + block["foo"] = "bar"; + block.insert(std::make_pair("foo2", "baz")); + return block; +} + +// Test that certain methods do not crash on moved-from instances. +TEST(Http2HeaderBlockTest, MovedFromIsValid) { + Http2HeaderBlock block1; + block1["foo"] = "bar"; + + Http2HeaderBlock block2(std::move(block1)); + EXPECT_THAT(block2, ElementsAre(Pair("foo", "bar"))); + + block1["baz"] = "qux"; // NOLINT testing post-move behavior + + Http2HeaderBlock block3(std::move(block1)); + + block1["foo"] = "bar"; // NOLINT testing post-move behavior + + Http2HeaderBlock block4(std::move(block1)); + + block1.clear(); // NOLINT testing post-move behavior + EXPECT_TRUE(block1.empty()); + + block1["foo"] = "bar"; + EXPECT_THAT(block1, ElementsAre(Pair("foo", "bar"))); + + Http2HeaderBlock block5 = ReturnTestHeaderBlock(); + block5.AppendValueOrAddHeader("foo", "bar2"); + EXPECT_THAT(block5, ElementsAre(Pair("foo", std::string("bar\0bar2", 8)), + Pair("foo2", "baz"))); +} + +// This test verifies that headers can be appended to no matter how they were +// added originally. +TEST(Http2HeaderBlockTest, AppendHeaders) { + Http2HeaderBlock block; + block["foo"] = "foo"; + block.AppendValueOrAddHeader("foo", "bar"); + EXPECT_EQ(Pair("foo", std::string("foo\0bar", 7)), *block.find("foo")); + + block.insert(std::make_pair("foo", "baz")); + EXPECT_EQ("baz", block["foo"]); + EXPECT_EQ(Pair("foo", "baz"), *block.find("foo")); + + // Try all four methods of adding an entry. + block["cookie"] = "key1=value1"; + block.AppendValueOrAddHeader("h1", "h1v1"); + block.insert(std::make_pair("h2", "h2v1")); + + block.AppendValueOrAddHeader("h3", "h3v2"); + block.AppendValueOrAddHeader("h2", "h2v2"); + block.AppendValueOrAddHeader("h1", "h1v2"); + block.AppendValueOrAddHeader("cookie", "key2=value2"); + + block.AppendValueOrAddHeader("cookie", "key3=value3"); + block.AppendValueOrAddHeader("h1", "h1v3"); + block.AppendValueOrAddHeader("h2", "h2v3"); + block.AppendValueOrAddHeader("h3", "h3v3"); + block.AppendValueOrAddHeader("h4", "singleton"); + + EXPECT_EQ("key1=value1; key2=value2; key3=value3", block["cookie"]); + EXPECT_EQ("baz", block["foo"]); + EXPECT_EQ(std::string("h1v1\0h1v2\0h1v3", 14), block["h1"]); + EXPECT_EQ(std::string("h2v1\0h2v2\0h2v3", 14), block["h2"]); + EXPECT_EQ(std::string("h3v2\0h3v3", 9), block["h3"]); + EXPECT_EQ("singleton", block["h4"]); +} + +TEST(Http2HeaderBlockTest, CompareValueToStringPiece) { + Http2HeaderBlock block; + block["foo"] = "foo"; + block.AppendValueOrAddHeader("foo", "bar"); + const auto& val = block["foo"]; + const char expected[] = "foo\0bar"; + EXPECT_TRUE(absl::string_view(expected, 7) == val); + EXPECT_TRUE(val == absl::string_view(expected, 7)); + EXPECT_FALSE(absl::string_view(expected, 3) == val); + EXPECT_FALSE(val == absl::string_view(expected, 3)); + const char not_expected[] = "foo\0barextra"; + EXPECT_FALSE(absl::string_view(not_expected, 12) == val); + EXPECT_FALSE(val == absl::string_view(not_expected, 12)); + + const auto& val2 = block["foo2"]; + EXPECT_FALSE(absl::string_view(expected, 7) == val2); + EXPECT_FALSE(val2 == absl::string_view(expected, 7)); + EXPECT_FALSE(absl::string_view("") == val2); + EXPECT_FALSE(val2 == absl::string_view("")); +} + +// This test demonstrates that the Http2HeaderBlock data structure does not +// place any limitations on the characters present in the header names. +TEST(Http2HeaderBlockTest, UpperCaseNames) { + Http2HeaderBlock block; + block["Foo"] = "foo"; + block.AppendValueOrAddHeader("Foo", "bar"); + EXPECT_NE(block.end(), block.find("foo")); + EXPECT_EQ(Pair("Foo", std::string("foo\0bar", 7)), *block.find("Foo")); + + // The map is case insensitive, so updating "foo" modifies the entry + // previously added. + block.AppendValueOrAddHeader("foo", "baz"); + EXPECT_THAT(block, + ElementsAre(Pair("Foo", std::string("foo\0bar\0baz", 11)))); +} + +namespace { +size_t Http2HeaderBlockSize(const Http2HeaderBlock& block) { + size_t size = 0; + for (const auto& pair : block) { + size += pair.first.size() + pair.second.size(); + } + return size; +} +} // namespace + +// Tests Http2HeaderBlock SizeEstimate(). +TEST(Http2HeaderBlockTest, TotalBytesUsed) { + Http2HeaderBlock block; + const size_t value_size = 300; + block["foo"] = std::string(value_size, 'x'); + EXPECT_EQ(block.TotalBytesUsed(), Http2HeaderBlockSize(block)); + block.insert(std::make_pair("key", std::string(value_size, 'x'))); + EXPECT_EQ(block.TotalBytesUsed(), Http2HeaderBlockSize(block)); + block.AppendValueOrAddHeader("abc", std::string(value_size, 'x')); + EXPECT_EQ(block.TotalBytesUsed(), Http2HeaderBlockSize(block)); + + // Replace value for existing key. + block["foo"] = std::string(value_size, 'x'); + EXPECT_EQ(block.TotalBytesUsed(), Http2HeaderBlockSize(block)); + block.insert(std::make_pair("key", std::string(value_size, 'x'))); + EXPECT_EQ(block.TotalBytesUsed(), Http2HeaderBlockSize(block)); + // Add value for existing key. + block.AppendValueOrAddHeader("abc", std::string(value_size, 'x')); + EXPECT_EQ(block.TotalBytesUsed(), Http2HeaderBlockSize(block)); + + // Copies/clones Http2HeaderBlock. + size_t block_size = block.TotalBytesUsed(); + Http2HeaderBlock block_copy = std::move(block); + EXPECT_EQ(block_size, block_copy.TotalBytesUsed()); + + // Erases key. + block_copy.erase("foo"); + EXPECT_EQ(block_copy.TotalBytesUsed(), Http2HeaderBlockSize(block_copy)); + block_copy.erase("key"); + EXPECT_EQ(block_copy.TotalBytesUsed(), Http2HeaderBlockSize(block_copy)); + block_copy.erase("abc"); + EXPECT_EQ(block_copy.TotalBytesUsed(), Http2HeaderBlockSize(block_copy)); +} + +// The order of header fields is preserved. Note that all pseudo-header fields +// must appear before regular header fields, both in HTTP/2 and HTTP/3, see +// https://www.rfc-editor.org/rfc/rfc9113.html#name-http-control-data and +// https://www.rfc-editor.org/rfc/rfc9114.html#name-http-control-data. It is +// the responsibility of the higher layer to add header fields in the correct +// order. +TEST(Http2HeaderBlockTest, OrderPreserved) { + Http2HeaderBlock block; + block[":method"] = "GET"; + block["foo"] = "bar"; + block[":path"] = "/"; + + EXPECT_THAT(block, ElementsAre(Pair(":method", "GET"), Pair("foo", "bar"), + Pair(":path", "/"))); +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/http2_header_storage.cc b/quiche/spdy/core/http2_header_storage.cc new file mode 100644 index 000000000000..653a62839629 --- /dev/null +++ b/quiche/spdy/core/http2_header_storage.cc @@ -0,0 +1,59 @@ +#include "quiche/spdy/core/http2_header_storage.h" + +#include + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { +namespace { + +// Http2HeaderStorage allocates blocks of this size by default. +const size_t kDefaultStorageBlockSize = 2048; + +} // namespace + +Http2HeaderStorage::Http2HeaderStorage() : arena_(kDefaultStorageBlockSize) {} + +absl::string_view Http2HeaderStorage::Write(const absl::string_view s) { + return absl::string_view(arena_.Memdup(s.data(), s.size()), s.size()); +} + +void Http2HeaderStorage::Rewind(const absl::string_view s) { + arena_.Free(const_cast(s.data()), s.size()); +} + +absl::string_view Http2HeaderStorage::WriteFragments( + const std::vector& fragments, + absl::string_view separator) { + if (fragments.empty()) { + return absl::string_view(); + } + size_t total_size = separator.size() * (fragments.size() - 1); + for (const absl::string_view& fragment : fragments) { + total_size += fragment.size(); + } + char* dst = arena_.Alloc(total_size); + size_t written = Join(dst, fragments, separator); + QUICHE_DCHECK_EQ(written, total_size); + return absl::string_view(dst, total_size); +} + +size_t Join(char* dst, const std::vector& fragments, + absl::string_view separator) { + if (fragments.empty()) { + return 0; + } + auto* original_dst = dst; + auto it = fragments.begin(); + memcpy(dst, it->data(), it->size()); + dst += it->size(); + for (++it; it != fragments.end(); ++it) { + memcpy(dst, separator.data(), separator.size()); + dst += separator.size(); + memcpy(dst, it->data(), it->size()); + dst += it->size(); + } + return dst - original_dst; +} + +} // namespace spdy diff --git a/quiche/spdy/core/http2_header_storage.h b/quiche/spdy/core/http2_header_storage.h new file mode 100644 index 000000000000..bc275b39b5bd --- /dev/null +++ b/quiche/spdy/core/http2_header_storage.h @@ -0,0 +1,58 @@ +#ifndef QUICHE_SPDY_CORE_HTTP2_HEADER_STORAGE_H_ +#define QUICHE_SPDY_CORE_HTTP2_HEADER_STORAGE_H_ + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/spdy_simple_arena.h" + +namespace spdy { + +// This class provides a backing store for absl::string_views. It previously +// used custom allocation logic, but now uses an UnsafeArena instead. It has the +// property that absl::string_views that refer to data in Http2HeaderStorage are +// never invalidated until the Http2HeaderStorage is deleted or Clear() is +// called. +// +// Write operations always append to the last block. If there is not enough +// space to perform the write, a new block is allocated, and any unused space +// is wasted. +class QUICHE_EXPORT Http2HeaderStorage { + public: + Http2HeaderStorage(); + + Http2HeaderStorage(const Http2HeaderStorage&) = delete; + Http2HeaderStorage& operator=(const Http2HeaderStorage&) = delete; + + Http2HeaderStorage(Http2HeaderStorage&& other) = default; + Http2HeaderStorage& operator=(Http2HeaderStorage&& other) = default; + + absl::string_view Write(absl::string_view s); + + // If |s| points to the most recent allocation from arena_, the arena will + // reclaim the memory. Otherwise, this method is a no-op. + void Rewind(absl::string_view s); + + void Clear() { arena_.Reset(); } + + // Given a list of fragments and a separator, writes the fragments joined by + // the separator to a contiguous region of memory. Returns a absl::string_view + // pointing to the region of memory. + absl::string_view WriteFragments( + const std::vector& fragments, + absl::string_view separator); + + size_t bytes_allocated() const { return arena_.status().bytes_allocated(); } + + private: + SpdySimpleArena arena_; +}; + +// Writes |fragments| to |dst|, joined by |separator|. |dst| must be large +// enough to hold the result. Returns the number of bytes written. +QUICHE_EXPORT size_t Join(char* dst, + const std::vector& fragments, + absl::string_view separator); + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HTTP2_HEADER_STORAGE_H_ diff --git a/quiche/spdy/core/http2_header_storage_test.cc b/quiche/spdy/core/http2_header_storage_test.cc new file mode 100644 index 000000000000..bfe9b5e4cf64 --- /dev/null +++ b/quiche/spdy/core/http2_header_storage_test.cc @@ -0,0 +1,35 @@ +#include "quiche/spdy/core/http2_header_storage.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { +namespace test { + +TEST(JoinTest, JoinEmpty) { + std::vector empty; + absl::string_view separator = ", "; + char buf[10] = ""; + size_t written = Join(buf, empty, separator); + EXPECT_EQ(0u, written); +} + +TEST(JoinTest, JoinOne) { + std::vector v = {"one"}; + absl::string_view separator = ", "; + char buf[15]; + size_t written = Join(buf, v, separator); + EXPECT_EQ(3u, written); + EXPECT_EQ("one", absl::string_view(buf, written)); +} + +TEST(JoinTest, JoinMultiple) { + std::vector v = {"one", "two", "three"}; + absl::string_view separator = ", "; + char buf[15]; + size_t written = Join(buf, v, separator); + EXPECT_EQ(15u, written); + EXPECT_EQ("one, two, three", absl::string_view(buf, written)); +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/metadata_extension.cc b/quiche/spdy/core/metadata_extension.cc new file mode 100644 index 000000000000..c22167f10e06 --- /dev/null +++ b/quiche/spdy/core/metadata_extension.cc @@ -0,0 +1,176 @@ +#include "quiche/spdy/core/metadata_extension.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "quiche/http2/decoder/decode_buffer.h" +#include "quiche/http2/hpack/decoder/hpack_decoder.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_header_block_hpack_listener.h" + +namespace spdy { + +// Non-standard constants related to METADATA frames. +const SpdySettingsId MetadataVisitor::kMetadataExtensionId = 0x4d44; +const uint8_t MetadataVisitor::kMetadataFrameType = 0x4d; +const uint8_t MetadataVisitor::kEndMetadataFlag = 0x4; + +namespace { + +const size_t kMaxMetadataBlockSize = 1 << 20; // 1 MB + +} // anonymous namespace + +MetadataFrameSequence::MetadataFrameSequence(SpdyStreamId stream_id, + spdy::Http2HeaderBlock payload) + : stream_id_(stream_id), payload_(std::move(payload)) { + // Metadata should not use HPACK compression. + encoder_.DisableCompression(); + HpackEncoder::Representations r; + for (const auto& kv_pair : payload_) { + r.push_back(kv_pair); + } + progressive_encoder_ = encoder_.EncodeRepresentations(r); +} + +bool MetadataFrameSequence::HasNext() const { + return progressive_encoder_->HasNext(); +} + +std::unique_ptr MetadataFrameSequence::Next() { + if (!HasNext()) { + return nullptr; + } + // METADATA frames obey the HTTP/2 maximum frame size. + std::string payload = + progressive_encoder_->Next(spdy::kHttp2DefaultFramePayloadLimit); + const bool end_metadata = !HasNext(); + const uint8_t flags = end_metadata ? MetadataVisitor::kEndMetadataFlag : 0; + return std::make_unique( + stream_id_, MetadataVisitor::kMetadataFrameType, flags, + std::move(payload)); +} + +struct MetadataVisitor::MetadataPayloadState { + MetadataPayloadState(size_t remaining, bool end) + : bytes_remaining(remaining), end_metadata(end) {} + std::list buffer; + size_t bytes_remaining; + bool end_metadata; +}; + +MetadataVisitor::MetadataVisitor(OnCompletePayload on_payload, + OnMetadataSupport on_support) + : on_payload_(std::move(on_payload)), + on_support_(std::move(on_support)), + peer_supports_metadata_(MetadataSupportState::UNSPECIFIED) {} + +MetadataVisitor::~MetadataVisitor() {} + +void MetadataVisitor::OnSetting(SpdySettingsId id, uint32_t value) { + QUICHE_VLOG(1) << "MetadataVisitor::OnSetting(" << id << ", " << value << ")"; + if (id == kMetadataExtensionId) { + if (value == 0) { + const MetadataSupportState previous_state = peer_supports_metadata_; + peer_supports_metadata_ = MetadataSupportState::NOT_SUPPORTED; + if (previous_state == MetadataSupportState::UNSPECIFIED || + previous_state == MetadataSupportState::SUPPORTED) { + on_support_(false); + } + } else if (value == 1) { + const MetadataSupportState previous_state = peer_supports_metadata_; + peer_supports_metadata_ = MetadataSupportState::SUPPORTED; + if (previous_state == MetadataSupportState::UNSPECIFIED || + previous_state == MetadataSupportState::NOT_SUPPORTED) { + on_support_(true); + } + } else { + QUICHE_LOG_EVERY_N_SEC(WARNING, 1) + << "Unrecognized value for setting " << id << ": " << value; + } + } +} + +bool MetadataVisitor::OnFrameHeader(SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + QUICHE_VLOG(1) << "OnFrameHeader(stream_id=" << stream_id + << ", length=" << length << ", type=" << static_cast(type) + << ", flags=" << static_cast(flags); + // TODO(birenroy): Consider disabling METADATA handling until our setting + // advertising METADATA support has been acked. + if (type != kMetadataFrameType) { + return false; + } + auto it = metadata_map_.find(stream_id); + if (it == metadata_map_.end()) { + auto state = std::make_unique( + length, flags & kEndMetadataFlag); + auto result = + metadata_map_.insert(std::make_pair(stream_id, std::move(state))); + QUICHE_BUG_IF(bug_if_2781_1, !result.second) << "Map insertion failed."; + it = result.first; + } else { + QUICHE_BUG_IF(bug_22051_1, it->second->end_metadata) + << "Inconsistent metadata payload state!"; + QUICHE_BUG_IF(bug_if_2781_2, it->second->bytes_remaining > 0) + << "Incomplete metadata block!"; + } + + if (it->second == nullptr) { + QUICHE_BUG(bug_2781_3) << "Null metadata payload state!"; + return false; + } + current_stream_ = stream_id; + it->second->bytes_remaining = length; + it->second->end_metadata = (flags & kEndMetadataFlag); + return true; +} + +void MetadataVisitor::OnFramePayload(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnFramePayload(stream_id=" << current_stream_ + << ", len=" << len << ")"; + auto it = metadata_map_.find(current_stream_); + if (it == metadata_map_.end() || it->second == nullptr) { + QUICHE_BUG(bug_2781_4) << "Invalid order of operations on MetadataVisitor."; + } else { + MetadataPayloadState* state = it->second.get(); // For readability. + state->buffer.push_back(std::string(data, len)); + if (len < state->bytes_remaining) { + state->bytes_remaining -= len; + } else { + QUICHE_BUG_IF(bug_22051_2, len > state->bytes_remaining) + << "Metadata payload overflow! len: " << len + << " bytes_remaining: " << state->bytes_remaining; + state->bytes_remaining = 0; + if (state->end_metadata) { + // The whole process of decoding the HPACK-encoded metadata block, + // below, is more cumbersome than it ought to be. + spdy::Http2HeaderBlockHpackListener listener; + http2::HpackDecoder decoder(&listener, kMaxMetadataBlockSize); + + // If any operations fail, the decode process should be aborted. + bool success = decoder.StartDecodingBlock(); + for (const std::string& slice : state->buffer) { + if (!success) { + break; + } + http2::DecodeBuffer buffer(slice.data(), slice.size()); + success = success && decoder.DecodeFragment(&buffer); + } + success = + success && decoder.EndDecodingBlock() && !listener.hpack_error(); + if (success) { + on_payload_(current_stream_, listener.release_header_block()); + } + // TODO(birenroy): add varz counting metadata decode successes/failures. + metadata_map_.erase(it); + } + } + } +} + +} // namespace spdy diff --git a/quiche/spdy/core/metadata_extension.h b/quiche/spdy/core/metadata_extension.h new file mode 100644 index 000000000000..9ac2b91d50e4 --- /dev/null +++ b/quiche/spdy/core/metadata_extension.h @@ -0,0 +1,122 @@ +#ifndef QUICHE_SPDY_CORE_METADATA_EXTENSION_H_ +#define QUICHE_SPDY_CORE_METADATA_EXTENSION_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_encoder.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/core/zero_copy_output_buffer.h" + +namespace spdy { + +// An implementation of the ExtensionVisitorInterface that can parse +// METADATA frames. METADATA is a non-standard HTTP/2 extension developed and +// used internally at Google. A peer advertises support for METADATA by sending +// a setting with a setting ID of kMetadataExtensionId and a value of 1. +// +// Metadata is represented as a HPACK header block with literal encoding. +class QUICHE_EXPORT MetadataVisitor : public spdy::ExtensionVisitorInterface { + public: + using MetadataPayload = spdy::Http2HeaderBlock; + + static_assert(!std::is_copy_constructible::value, + "MetadataPayload should be a move-only type!"); + + using OnMetadataSupport = std::function; + using OnCompletePayload = + std::function; + + // The HTTP/2 SETTINGS ID that is used to indicate support for METADATA + // frames. + static const spdy::SpdySettingsId kMetadataExtensionId; + + // The 8-bit frame type code for a METADATA frame. + static const uint8_t kMetadataFrameType; + + // The flag that indicates the end of a logical metadata block. Due to frame + // size limits, a single metadata block may be emitted as several HTTP/2 + // frames. + static const uint8_t kEndMetadataFlag; + + // |on_payload| is invoked whenever a complete metadata payload is received. + // |on_support| is invoked whenever the peer's advertised support for metadata + // changes. + MetadataVisitor(OnCompletePayload on_payload, OnMetadataSupport on_support); + ~MetadataVisitor() override; + + MetadataVisitor(const MetadataVisitor&) = delete; + MetadataVisitor& operator=(const MetadataVisitor&) = delete; + + // Interprets the non-standard setting indicating support for METADATA. + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + + // Returns true iff |type| indicates a METADATA frame. + bool OnFrameHeader(spdy::SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + + // Consumes a METADATA frame payload. Invokes the registered callback when a + // complete payload has been received. + void OnFramePayload(const char* data, size_t len) override; + + // Returns true if the peer has advertised support for METADATA via the + // appropriate setting. + bool PeerSupportsMetadata() const { + return peer_supports_metadata_ == MetadataSupportState::SUPPORTED; + } + + private: + enum class MetadataSupportState : uint8_t { + UNSPECIFIED, + SUPPORTED, + NOT_SUPPORTED, + }; + + struct MetadataPayloadState; + + using StreamMetadataMap = + absl::flat_hash_map>; + + OnCompletePayload on_payload_; + OnMetadataSupport on_support_; + StreamMetadataMap metadata_map_; + spdy::SpdyStreamId current_stream_; + MetadataSupportState peer_supports_metadata_; +}; + +// This class uses an HpackEncoder to serialize a METADATA block as a series of +// METADATA frames. +class QUICHE_EXPORT MetadataFrameSequence { + public: + MetadataFrameSequence(SpdyStreamId stream_id, spdy::Http2HeaderBlock payload); + + // Copies are not allowed. + MetadataFrameSequence(const MetadataFrameSequence& other) = delete; + MetadataFrameSequence& operator=(const MetadataFrameSequence& other) = delete; + + // True if Next() would return non-nullptr. + bool HasNext() const; + + // Returns the next HTTP/2 METADATA frame for this block, unless the block has + // been entirely serialized in frames returned by previous calls of Next(), in + // which case returns nullptr. + std::unique_ptr Next(); + + SpdyStreamId stream_id() const { return stream_id_; } + + private: + SpdyStreamId stream_id_; + Http2HeaderBlock payload_; + HpackEncoder encoder_; + std::unique_ptr progressive_encoder_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_METADATA_EXTENSION_H_ diff --git a/quiche/spdy/core/metadata_extension_test.cc b/quiche/spdy/core/metadata_extension_test.cc new file mode 100644 index 000000000000..ee20e6d56c61 --- /dev/null +++ b/quiche/spdy/core/metadata_extension_test.cc @@ -0,0 +1,281 @@ +#include "quiche/spdy/core/metadata_extension.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/array_output_buffer.h" +#include "quiche/spdy/core/spdy_framer.h" +#include "quiche/spdy/core/spdy_no_op_visitor.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/test_tools/mock_spdy_framer_visitor.h" + +namespace spdy { +namespace test { +namespace { + +using ::absl::bind_front; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::IsEmpty; + +const size_t kBufferSize = 64 * 1024; +char kBuffer[kBufferSize]; + +class MetadataExtensionTest : public quiche::test::QuicheTest { + protected: + MetadataExtensionTest() : test_buffer_(kBuffer, kBufferSize) {} + + void SetUp() override { + extension_ = std::make_unique( + bind_front(&MetadataExtensionTest::OnCompletePayload, this), + bind_front(&MetadataExtensionTest::OnMetadataSupport, this)); + } + + void OnCompletePayload(spdy::SpdyStreamId stream_id, + MetadataVisitor::MetadataPayload payload) { + ++received_count_; + received_payload_map_.insert(std::make_pair(stream_id, std::move(payload))); + } + + void OnMetadataSupport(bool peer_supports_metadata) { + EXPECT_EQ(peer_supports_metadata, extension_->PeerSupportsMetadata()); + received_metadata_support_.push_back(peer_supports_metadata); + } + + Http2HeaderBlock PayloadForData(absl::string_view data) { + Http2HeaderBlock block; + block["example-payload"] = data; + return block; + } + + std::unique_ptr extension_; + absl::flat_hash_map + received_payload_map_; + std::vector received_metadata_support_; + size_t received_count_ = 0; + spdy::ArrayOutputBuffer test_buffer_; +}; + +// This test verifies that the MetadataVisitor is initialized to a state where +// it believes the peer does not support metadata. +TEST_F(MetadataExtensionTest, MetadataNotSupported) { + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + EXPECT_THAT(received_metadata_support_, IsEmpty()); +} + +// This test verifies that upon receiving a specific setting, the extension +// realizes that the peer supports metadata. +TEST_F(MetadataExtensionTest, MetadataSupported) { + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + // 3 is not an appropriate value for the metadata extension key. + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 3); + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 0); + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + EXPECT_THAT(received_metadata_support_, ElementsAre(true, false)); +} + +TEST_F(MetadataExtensionTest, MetadataDeliveredToUnknownFrameCallbacks) { + const char kData[] = "some payload"; + Http2HeaderBlock payload = PayloadForData(kData); + + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + + MetadataFrameSequence sequence(3, std::move(payload)); + + http2::Http2DecoderAdapter deframer; + ::testing::StrictMock visitor; + deframer.set_visitor(&visitor); + + EXPECT_CALL(visitor, + OnCommonHeader(3, _, MetadataVisitor::kMetadataFrameType, _)); + // The Return(true) should not be necessary. http://b/36023792 + EXPECT_CALL(visitor, OnUnknownFrame(3, MetadataVisitor::kMetadataFrameType)) + .WillOnce(::testing::Return(true)); + EXPECT_CALL(visitor, + OnUnknownFrameStart(3, _, MetadataVisitor::kMetadataFrameType, + MetadataVisitor::kEndMetadataFlag)); + EXPECT_CALL(visitor, OnUnknownFramePayload(3, HasSubstr(kData))); + + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + auto frame = sequence.Next(); + ASSERT_TRUE(frame != nullptr); + while (frame != nullptr) { + const size_t frame_size = framer.SerializeFrame(*frame, &test_buffer_); + ASSERT_GT(frame_size, 0u); + ASSERT_FALSE(deframer.HasError()); + ASSERT_EQ(frame_size, test_buffer_.Size()); + EXPECT_EQ(frame_size, deframer.ProcessInput(kBuffer, frame_size)); + test_buffer_.Reset(); + frame = sequence.Next(); + } + EXPECT_FALSE(deframer.HasError()); + EXPECT_THAT(received_metadata_support_, ElementsAre(true)); +} + +// This test verifies that the METADATA frame emitted by a MetadataExtension +// can be parsed by another SpdyFramer with a MetadataVisitor. +TEST_F(MetadataExtensionTest, MetadataPayloadEndToEnd) { + Http2HeaderBlock block1; + block1["foo"] = "Some metadata value."; + Http2HeaderBlock block2; + block2["bar"] = + "The color taupe truly represents a triumph of the human spirit over " + "adversity."; + block2["baz"] = + "Or perhaps it represents abject surrender to the implacable and " + "incomprehensible forces of the universe."; + const absl::string_view binary_payload{"binary\0payload", 14}; + block2["qux"] = binary_payload; + EXPECT_EQ(binary_payload, block2["qux"]); + for (const Http2HeaderBlock& payload_block : + {std::move(block1), std::move(block2)}) { + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + + MetadataFrameSequence sequence(3, payload_block.Clone()); + http2::Http2DecoderAdapter deframer; + ::spdy::SpdyNoOpVisitor visitor; + deframer.set_visitor(&visitor); + deframer.set_extension_visitor(extension_.get()); + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + auto frame = sequence.Next(); + ASSERT_TRUE(frame != nullptr); + while (frame != nullptr) { + const size_t frame_size = framer.SerializeFrame(*frame, &test_buffer_); + ASSERT_GT(frame_size, 0u); + ASSERT_FALSE(deframer.HasError()); + ASSERT_EQ(frame_size, test_buffer_.Size()); + EXPECT_EQ(frame_size, deframer.ProcessInput(kBuffer, frame_size)); + test_buffer_.Reset(); + frame = sequence.Next(); + } + EXPECT_EQ(1u, received_count_); + auto it = received_payload_map_.find(3); + ASSERT_TRUE(it != received_payload_map_.end()); + EXPECT_EQ(payload_block, it->second); + + received_count_ = 0; + received_payload_map_.clear(); + } +} + +// This test verifies that METADATA frames for two different streams can be +// interleaved and still successfully parsed by another SpdyFramer with a +// MetadataVisitor. +TEST_F(MetadataExtensionTest, MetadataPayloadInterleaved) { + const std::string kData1 = std::string(65 * 1024, 'a'); + const std::string kData2 = std::string(65 * 1024, 'b'); + const Http2HeaderBlock payload1 = PayloadForData(kData1); + const Http2HeaderBlock payload2 = PayloadForData(kData2); + + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + + MetadataFrameSequence sequence1(3, payload1.Clone()); + MetadataFrameSequence sequence2(5, payload2.Clone()); + + http2::Http2DecoderAdapter deframer; + ::spdy::SpdyNoOpVisitor visitor; + deframer.set_visitor(&visitor); + deframer.set_extension_visitor(extension_.get()); + + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + auto frame1 = sequence1.Next(); + ASSERT_TRUE(frame1 != nullptr); + auto frame2 = sequence2.Next(); + ASSERT_TRUE(frame2 != nullptr); + while (frame1 != nullptr || frame2 != nullptr) { + for (auto frame : {frame1.get(), frame2.get()}) { + if (frame != nullptr) { + const size_t frame_size = framer.SerializeFrame(*frame, &test_buffer_); + ASSERT_GT(frame_size, 0u); + ASSERT_FALSE(deframer.HasError()); + ASSERT_EQ(frame_size, test_buffer_.Size()); + EXPECT_EQ(frame_size, deframer.ProcessInput(kBuffer, frame_size)); + test_buffer_.Reset(); + } + } + frame1 = sequence1.Next(); + frame2 = sequence2.Next(); + } + EXPECT_EQ(2u, received_count_); + auto it = received_payload_map_.find(3); + ASSERT_TRUE(it != received_payload_map_.end()); + EXPECT_EQ(payload1, it->second); + + it = received_payload_map_.find(5); + ASSERT_TRUE(it != received_payload_map_.end()); + EXPECT_EQ(payload2, it->second); +} + +// Test that an empty metadata block is serialized as a single frame with +// END_METADATA set and empty frame payload. +TEST_F(MetadataExtensionTest, EmptyBlock) { + MetadataFrameSequence sequence(1, Http2HeaderBlock{}); + + EXPECT_TRUE(sequence.HasNext()); + std::unique_ptr frame = sequence.Next(); + EXPECT_FALSE(sequence.HasNext()); + + auto* const metadata_frame = static_cast(frame.get()); + EXPECT_EQ(MetadataVisitor::kEndMetadataFlag, + metadata_frame->flags() & MetadataVisitor::kEndMetadataFlag); + EXPECT_TRUE(metadata_frame->payload().empty()); +} + +// Test that a small metadata block is serialized as a single frame with +// END_METADATA set and non-empty frame payload. +TEST_F(MetadataExtensionTest, SmallBlock) { + Http2HeaderBlock metadata_block; + metadata_block["foo"] = "bar"; + MetadataFrameSequence sequence(1, std::move(metadata_block)); + + EXPECT_TRUE(sequence.HasNext()); + std::unique_ptr frame = sequence.Next(); + EXPECT_FALSE(sequence.HasNext()); + + auto* const metadata_frame = static_cast(frame.get()); + EXPECT_EQ(MetadataVisitor::kEndMetadataFlag, + metadata_frame->flags() & MetadataVisitor::kEndMetadataFlag); + EXPECT_LT(0u, metadata_frame->payload().size()); +} + +// Test that a large metadata block is serialized as multiple frames, +// with END_METADATA set only on the last one. +TEST_F(MetadataExtensionTest, LargeBlock) { + Http2HeaderBlock metadata_block; + metadata_block["foo"] = std::string(65 * 1024, 'a'); + MetadataFrameSequence sequence(1, std::move(metadata_block)); + + int frame_count = 0; + while (sequence.HasNext()) { + std::unique_ptr frame = sequence.Next(); + ++frame_count; + + auto* const metadata_frame = static_cast(frame.get()); + EXPECT_LT(0u, metadata_frame->payload().size()); + + if (sequence.HasNext()) { + EXPECT_EQ(0u, + metadata_frame->flags() & MetadataVisitor::kEndMetadataFlag); + } else { + EXPECT_EQ(MetadataVisitor::kEndMetadataFlag, + metadata_frame->flags() & MetadataVisitor::kEndMetadataFlag); + } + } + + EXPECT_LE(2, frame_count); +} + +} // anonymous namespace +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/no_op_headers_handler.h b/quiche/spdy/core/no_op_headers_handler.h new file mode 100644 index 000000000000..2e0c82f0eaee --- /dev/null +++ b/quiche/spdy/core/no_op_headers_handler.h @@ -0,0 +1,38 @@ +#ifndef QUICHE_SPDY_CORE_NO_OP_HEADERS_HANDLER_H_ +#define QUICHE_SPDY_CORE_NO_OP_HEADERS_HANDLER_H_ + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/header_byte_listener_interface.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" + +namespace spdy { + +// Drops all header data, but passes information about header bytes parsed to +// a listener. +class QUICHE_EXPORT NoOpHeadersHandler : public SpdyHeadersHandlerInterface { + public: + // Does not take ownership of listener. + explicit NoOpHeadersHandler(HeaderByteListenerInterface* listener) + : listener_(listener) {} + NoOpHeadersHandler(const NoOpHeadersHandler&) = delete; + NoOpHeadersHandler& operator=(const NoOpHeadersHandler&) = delete; + ~NoOpHeadersHandler() override {} + + // From SpdyHeadersHandlerInterface + void OnHeaderBlockStart() override {} + void OnHeader(absl::string_view /*key*/, + absl::string_view /*value*/) override {} + void OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t /* compressed_header_bytes */) override { + if (listener_ != nullptr) { + listener_->OnHeaderBytesReceived(uncompressed_header_bytes); + } + } + + private: + HeaderByteListenerInterface* listener_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_NO_OP_HEADERS_HANDLER_H_ diff --git a/quiche/spdy/core/recording_headers_handler.cc b/quiche/spdy/core/recording_headers_handler.cc new file mode 100644 index 000000000000..7808e7f6944b --- /dev/null +++ b/quiche/spdy/core/recording_headers_handler.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/recording_headers_handler.h" + +namespace spdy { + +RecordingHeadersHandler::RecordingHeadersHandler( + SpdyHeadersHandlerInterface* wrapped) + : wrapped_(wrapped) {} + +void RecordingHeadersHandler::OnHeaderBlockStart() { + block_.clear(); + if (wrapped_ != nullptr) { + wrapped_->OnHeaderBlockStart(); + } +} + +void RecordingHeadersHandler::OnHeader(absl::string_view key, + absl::string_view value) { + block_.AppendValueOrAddHeader(key, value); + if (wrapped_ != nullptr) { + wrapped_->OnHeader(key, value); + } +} + +void RecordingHeadersHandler::OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t compressed_header_bytes) { + uncompressed_header_bytes_ = uncompressed_header_bytes; + compressed_header_bytes_ = compressed_header_bytes; + if (wrapped_ != nullptr) { + wrapped_->OnHeaderBlockEnd(uncompressed_header_bytes, + compressed_header_bytes); + } +} + +} // namespace spdy diff --git a/quiche/spdy/core/recording_headers_handler.h b/quiche/spdy/core/recording_headers_handler.h new file mode 100644 index 000000000000..54165ebb6713 --- /dev/null +++ b/quiche/spdy/core/recording_headers_handler.h @@ -0,0 +1,51 @@ +// Copyright (c) 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_RECORDING_HEADERS_HANDLER_H_ +#define QUICHE_SPDY_CORE_RECORDING_HEADERS_HANDLER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" + +namespace spdy { + +// RecordingHeadersHandler copies the headers emitted from the deframer, and +// when needed can forward events to another wrapped handler. +class QUICHE_EXPORT RecordingHeadersHandler + : public SpdyHeadersHandlerInterface { + public: + explicit RecordingHeadersHandler( + SpdyHeadersHandlerInterface* wrapped = nullptr); + RecordingHeadersHandler(const RecordingHeadersHandler&) = delete; + RecordingHeadersHandler& operator=(const RecordingHeadersHandler&) = delete; + + void OnHeaderBlockStart() override; + + void OnHeader(absl::string_view key, absl::string_view value) override; + + void OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t compressed_header_bytes) override; + + const Http2HeaderBlock& decoded_block() const { return block_; } + size_t uncompressed_header_bytes() const { + return uncompressed_header_bytes_; + } + size_t compressed_header_bytes() const { return compressed_header_bytes_; } + + private: + SpdyHeadersHandlerInterface* wrapped_ = nullptr; + Http2HeaderBlock block_; + size_t uncompressed_header_bytes_ = 0; + size_t compressed_header_bytes_ = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_RECORDING_HEADERS_HANDLER_H_ diff --git a/quiche/spdy/core/spdy_alt_svc_wire_format.cc b/quiche/spdy/core/spdy_alt_svc_wire_format.cc new file mode 100644 index 000000000000..6a68ebbba220 --- /dev/null +++ b/quiche/spdy/core/spdy_alt_svc_wire_format.cc @@ -0,0 +1,420 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { + +namespace { + +template +bool ParsePositiveIntegerImpl(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, T* value) { + *value = 0; + for (; c != end && std::isdigit(*c); ++c) { + if (*value > std::numeric_limits::max() / 10) { + return false; + } + *value *= 10; + if (*value > std::numeric_limits::max() - (*c - '0')) { + return false; + } + *value += *c - '0'; + } + return (c == end && *value > 0); +} + +} // namespace + +SpdyAltSvcWireFormat::AlternativeService::AlternativeService() = default; + +SpdyAltSvcWireFormat::AlternativeService::AlternativeService( + const std::string& protocol_id, const std::string& host, uint16_t port, + uint32_t max_age_seconds, VersionVector version) + : protocol_id(protocol_id), + host(host), + port(port), + max_age_seconds(max_age_seconds), + version(std::move(version)) {} + +SpdyAltSvcWireFormat::AlternativeService::~AlternativeService() = default; + +SpdyAltSvcWireFormat::AlternativeService::AlternativeService( + const AlternativeService& other) = default; + +// static +bool SpdyAltSvcWireFormat::ParseHeaderFieldValue( + absl::string_view value, AlternativeServiceVector* altsvc_vector) { + // Empty value is invalid according to the specification. + if (value.empty()) { + return false; + } + altsvc_vector->clear(); + if (value == absl::string_view("clear")) { + return true; + } + absl::string_view::const_iterator c = value.begin(); + while (c != value.end()) { + // Parse protocol-id. + absl::string_view::const_iterator percent_encoded_protocol_id_end = + std::find(c, value.end(), '='); + std::string protocol_id; + if (percent_encoded_protocol_id_end == c || + !PercentDecode(c, percent_encoded_protocol_id_end, &protocol_id)) { + return false; + } + // Check for IETF format for advertising QUIC: + // hq=":443";quic=51303338;quic=51303334 + const bool is_ietf_format_quic = (protocol_id == "hq"); + c = percent_encoded_protocol_id_end; + if (c == value.end()) { + return false; + } + // Parse alt-authority. + QUICHE_DCHECK_EQ('=', *c); + ++c; + if (c == value.end() || *c != '"') { + return false; + } + ++c; + absl::string_view::const_iterator alt_authority_begin = c; + for (; c != value.end() && *c != '"'; ++c) { + // Decode backslash encoding. + if (*c != '\\') { + continue; + } + ++c; + if (c == value.end()) { + return false; + } + } + if (c == alt_authority_begin || c == value.end()) { + return false; + } + QUICHE_DCHECK_EQ('"', *c); + std::string host; + uint16_t port; + if (!ParseAltAuthority(alt_authority_begin, c, &host, &port)) { + return false; + } + ++c; + // Parse parameters. + uint32_t max_age_seconds = 86400; + VersionVector version; + absl::string_view::const_iterator parameters_end = + std::find(c, value.end(), ','); + while (c != parameters_end) { + SkipWhiteSpace(&c, parameters_end); + if (c == parameters_end) { + break; + } + if (*c != ';') { + return false; + } + ++c; + SkipWhiteSpace(&c, parameters_end); + if (c == parameters_end) { + break; + } + std::string parameter_name; + for (; c != parameters_end && *c != '=' && *c != ' ' && *c != '\t'; ++c) { + parameter_name.push_back(tolower(*c)); + } + SkipWhiteSpace(&c, parameters_end); + if (c == parameters_end || *c != '=') { + return false; + } + ++c; + SkipWhiteSpace(&c, parameters_end); + absl::string_view::const_iterator parameter_value_begin = c; + for (; c != parameters_end && *c != ';' && *c != ' ' && *c != '\t'; ++c) { + } + if (c == parameter_value_begin) { + return false; + } + if (parameter_name == "ma") { + if (!ParsePositiveInteger32(parameter_value_begin, c, + &max_age_seconds)) { + return false; + } + } else if (!is_ietf_format_quic && parameter_name == "v") { + // Version is a comma separated list of positive integers enclosed in + // quotation marks. Since it can contain commas, which are not + // delineating alternative service entries, |parameters_end| and |c| can + // be invalid. + if (*parameter_value_begin != '"') { + return false; + } + c = std::find(parameter_value_begin + 1, value.end(), '"'); + if (c == value.end()) { + return false; + } + ++c; + parameters_end = std::find(c, value.end(), ','); + absl::string_view::const_iterator v_begin = parameter_value_begin + 1; + while (v_begin < c) { + absl::string_view::const_iterator v_end = v_begin; + while (v_end < c - 1 && *v_end != ',') { + ++v_end; + } + uint16_t v; + if (!ParsePositiveInteger16(v_begin, v_end, &v)) { + return false; + } + version.push_back(v); + v_begin = v_end + 1; + if (v_begin == c - 1) { + // List ends in comma. + return false; + } + } + } else if (is_ietf_format_quic && parameter_name == "quic") { + // IETF format for advertising QUIC. Version is hex encoding of QUIC + // version tag. Hex-encoded string should not include leading "0x" or + // leading zeros. + // Example for advertising QUIC versions "Q038" and "Q034": + // hq=":443";quic=51303338;quic=51303334 + if (*parameter_value_begin == '0') { + return false; + } + // Versions will be stored as the uint32_t hex decoding of the param + // value string. Example: QUIC version "Q038", which is advertised as: + // hq=":443";quic=51303338 + // ... will be stored in |versions| as 0x51303338. + uint32_t quic_version; + if (!HexDecodeToUInt32(absl::string_view(&*parameter_value_begin, + c - parameter_value_begin), + &quic_version) || + quic_version == 0) { + return false; + } + version.push_back(quic_version); + } + } + altsvc_vector->emplace_back(protocol_id, host, port, max_age_seconds, + version); + for (; c != value.end() && (*c == ' ' || *c == '\t' || *c == ','); ++c) { + } + } + return true; +} + +// static +std::string SpdyAltSvcWireFormat::SerializeHeaderFieldValue( + const AlternativeServiceVector& altsvc_vector) { + if (altsvc_vector.empty()) { + return std::string("clear"); + } + const char kNibbleToHex[] = "0123456789ABCDEF"; + std::string value; + for (const AlternativeService& altsvc : altsvc_vector) { + if (!value.empty()) { + value.push_back(','); + } + // Check for IETF format for advertising QUIC. + const bool is_ietf_format_quic = (altsvc.protocol_id == "hq"); + // Percent escape protocol id according to + // http://tools.ietf.org/html/rfc7230#section-3.2.6. + for (char c : altsvc.protocol_id) { + if (isalnum(c)) { + value.push_back(c); + continue; + } + switch (c) { + case '!': + case '#': + case '$': + case '&': + case '\'': + case '*': + case '+': + case '-': + case '.': + case '^': + case '_': + case '`': + case '|': + case '~': + value.push_back(c); + break; + default: + value.push_back('%'); + // Network byte order is big-endian. + value.push_back(kNibbleToHex[c >> 4]); + value.push_back(kNibbleToHex[c & 0x0f]); + break; + } + } + value.push_back('='); + value.push_back('"'); + for (char c : altsvc.host) { + if (c == '"' || c == '\\') { + value.push_back('\\'); + } + value.push_back(c); + } + absl::StrAppend(&value, ":", altsvc.port, "\""); + if (altsvc.max_age_seconds != 86400) { + absl::StrAppend(&value, "; ma=", altsvc.max_age_seconds); + } + if (!altsvc.version.empty()) { + if (is_ietf_format_quic) { + for (uint32_t quic_version : altsvc.version) { + absl::StrAppend(&value, "; quic=", absl::Hex(quic_version)); + } + } else { + value.append("; v=\""); + for (auto it = altsvc.version.begin(); it != altsvc.version.end(); + ++it) { + if (it != altsvc.version.begin()) { + value.append(","); + } + absl::StrAppend(&value, *it); + } + value.append("\""); + } + } + } + return value; +} + +// static +void SpdyAltSvcWireFormat::SkipWhiteSpace( + absl::string_view::const_iterator* c, + absl::string_view::const_iterator end) { + for (; *c != end && (**c == ' ' || **c == '\t'); ++*c) { + } +} + +// static +bool SpdyAltSvcWireFormat::PercentDecode(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + std::string* output) { + output->clear(); + for (; c != end; ++c) { + if (*c != '%') { + output->push_back(*c); + continue; + } + QUICHE_DCHECK_EQ('%', *c); + ++c; + if (c == end || !std::isxdigit(*c)) { + return false; + } + // Network byte order is big-endian. + char decoded = HexDigitToInt(*c) << 4; + ++c; + if (c == end || !std::isxdigit(*c)) { + return false; + } + decoded += HexDigitToInt(*c); + output->push_back(decoded); + } + return true; +} + +// static +bool SpdyAltSvcWireFormat::ParseAltAuthority( + absl::string_view::const_iterator c, absl::string_view::const_iterator end, + std::string* host, uint16_t* port) { + host->clear(); + if (c == end) { + return false; + } + if (*c == '[') { + for (; c != end && *c != ']'; ++c) { + if (*c == '"') { + // Port is mandatory. + return false; + } + host->push_back(*c); + } + if (c == end) { + return false; + } + QUICHE_DCHECK_EQ(']', *c); + host->push_back(*c); + ++c; + } else { + for (; c != end && *c != ':'; ++c) { + if (*c == '"') { + // Port is mandatory. + return false; + } + if (*c == '\\') { + ++c; + if (c == end) { + return false; + } + } + host->push_back(*c); + } + } + if (c == end || *c != ':') { + return false; + } + QUICHE_DCHECK_EQ(':', *c); + ++c; + return ParsePositiveInteger16(c, end, port); +} + +// static +bool SpdyAltSvcWireFormat::ParsePositiveInteger16( + absl::string_view::const_iterator c, absl::string_view::const_iterator end, + uint16_t* value) { + return ParsePositiveIntegerImpl(c, end, value); +} + +// static +bool SpdyAltSvcWireFormat::ParsePositiveInteger32( + absl::string_view::const_iterator c, absl::string_view::const_iterator end, + uint32_t* value) { + return ParsePositiveIntegerImpl(c, end, value); +} + +// static +char SpdyAltSvcWireFormat::HexDigitToInt(char c) { + QUICHE_DCHECK(std::isxdigit(c)); + + if (std::isdigit(c)) { + return c - '0'; + } + if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } + + return 0; +} + +// static +bool SpdyAltSvcWireFormat::HexDecodeToUInt32(absl::string_view data, + uint32_t* value) { + if (data.empty() || data.length() > 8u) { + return false; + } + + *value = 0; + for (char c : data) { + if (!std::isxdigit(c)) { + return false; + } + + *value <<= 4; + *value += HexDigitToInt(c); + } + + return true; +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_alt_svc_wire_format.h b/quiche/spdy/core/spdy_alt_svc_wire_format.h new file mode 100644 index 000000000000..ec0c124cd1b6 --- /dev/null +++ b/quiche/spdy/core/spdy_alt_svc_wire_format.h @@ -0,0 +1,104 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file contains data structures and utility functions used for serializing +// and parsing alternative service header values, common to HTTP/1.1 header +// fields and HTTP/2 and QUIC ALTSVC frames. See specification at +// https://httpwg.github.io/http-extensions/alt-svc.html. + +#ifndef QUICHE_SPDY_CORE_SPDY_ALT_SVC_WIRE_FORMAT_H_ +#define QUICHE_SPDY_CORE_SPDY_ALT_SVC_WIRE_FORMAT_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +namespace test { +class SpdyAltSvcWireFormatPeer; +} // namespace test + +class QUICHE_EXPORT SpdyAltSvcWireFormat { + public: + using VersionVector = absl::InlinedVector; + + struct QUICHE_EXPORT AlternativeService { + std::string protocol_id; + std::string host; + + // Default is 0: invalid port. + uint16_t port = 0; + // Default is one day. + uint32_t max_age_seconds = 86400; + // Default is empty: unspecified version. + VersionVector version; + + AlternativeService(); + AlternativeService(const std::string& protocol_id, const std::string& host, + uint16_t port, uint32_t max_age_seconds, + VersionVector version); + AlternativeService(const AlternativeService& other); + ~AlternativeService(); + + bool operator==(const AlternativeService& other) const { + return protocol_id == other.protocol_id && host == other.host && + port == other.port && version == other.version && + max_age_seconds == other.max_age_seconds; + } + }; + // An empty vector means alternative services should be cleared for given + // origin. Note that the wire format for this is the string "clear", not an + // empty value (which is invalid). + typedef std::vector AlternativeServiceVector; + + friend class test::SpdyAltSvcWireFormatPeer; + static bool ParseHeaderFieldValue(absl::string_view value, + AlternativeServiceVector* altsvc_vector); + static std::string SerializeHeaderFieldValue( + const AlternativeServiceVector& altsvc_vector); + + private: + // Forward |*c| over space and tab or until |end| is reached. + static void SkipWhiteSpace(absl::string_view::const_iterator* c, + absl::string_view::const_iterator end); + // Decode percent-decoded string between |c| and |end| into |*output|. + // Return true on success, false if input is invalid. + static bool PercentDecode(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + std::string* output); + // Parse the authority part of Alt-Svc between |c| and |end| into |*host| and + // |*port|. Return true on success, false if input is invalid. + static bool ParseAltAuthority(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + std::string* host, uint16_t* port); + // Parse a positive integer between |c| and |end| into |*value|. + // Return true on success, false if input is not a positive integer or it + // cannot be represented on uint16_t. + static bool ParsePositiveInteger16(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + uint16_t* value); + // Parse a positive integer between |c| and |end| into |*value|. + // Return true on success, false if input is not a positive integer or it + // cannot be represented on uint32_t. + static bool ParsePositiveInteger32(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + uint32_t* value); + // Parse |c| as hexadecimal digit, case insensitive. |c| must be [0-9a-fA-F]. + // Output is between 0 and 15. + static char HexDigitToInt(char c); + // Parse |data| as hexadecimal number into |*value|. |data| must only contain + // hexadecimal digits, no "0x" prefix. + // Return true on success, false if input is empty, not valid hexadecimal + // number, or cannot be represented on uint32_t. + static bool HexDecodeToUInt32(absl::string_view data, uint32_t* value); +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_ALT_SVC_WIRE_FORMAT_H_ diff --git a/quiche/spdy/core/spdy_alt_svc_wire_format_test.cc b/quiche/spdy/core/spdy_alt_svc_wire_format_test.cc new file mode 100644 index 000000000000..50c67a328e76 --- /dev/null +++ b/quiche/spdy/core/spdy_alt_svc_wire_format_test.cc @@ -0,0 +1,636 @@ +// Copyright (c) 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" + +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { + +namespace test { + +// Expose all private methods of class SpdyAltSvcWireFormat. +class SpdyAltSvcWireFormatPeer { + public: + static void SkipWhiteSpace(absl::string_view::const_iterator* c, + absl::string_view::const_iterator end) { + SpdyAltSvcWireFormat::SkipWhiteSpace(c, end); + } + static bool PercentDecode(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + std::string* output) { + return SpdyAltSvcWireFormat::PercentDecode(c, end, output); + } + static bool ParseAltAuthority(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + std::string* host, uint16_t* port) { + return SpdyAltSvcWireFormat::ParseAltAuthority(c, end, host, port); + } + static bool ParsePositiveInteger16(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + uint16_t* max_age_seconds) { + return SpdyAltSvcWireFormat::ParsePositiveInteger16(c, end, + max_age_seconds); + } + static bool ParsePositiveInteger32(absl::string_view::const_iterator c, + absl::string_view::const_iterator end, + uint32_t* max_age_seconds) { + return SpdyAltSvcWireFormat::ParsePositiveInteger32(c, end, + max_age_seconds); + } + static char HexDigitToInt(char c) { + return SpdyAltSvcWireFormat::HexDigitToInt(c); + } + static bool HexDecodeToUInt32(absl::string_view data, uint32_t* value) { + return SpdyAltSvcWireFormat::HexDecodeToUInt32(data, value); + } +}; + +namespace { + +// Generate header field values, possibly with multiply defined parameters and +// random case, and corresponding AlternativeService entries. +void FuzzHeaderFieldValue( + int i, std::string* header_field_value, + SpdyAltSvcWireFormat::AlternativeService* expected_altsvc) { + if (!header_field_value->empty()) { + header_field_value->push_back(','); + } + // TODO(b/77515496): use struct of bools instead of int |i| to generate the + // header field value. + bool is_ietf_format_quic = (i & 1 << 0) != 0; + if (i & 1 << 0) { + expected_altsvc->protocol_id = "hq"; + header_field_value->append("hq=\""); + } else { + expected_altsvc->protocol_id = "a=b%c"; + header_field_value->append("a%3Db%25c=\""); + } + if (i & 1 << 1) { + expected_altsvc->host = "foo\"bar\\baz"; + header_field_value->append("foo\\\"bar\\\\baz"); + } else { + expected_altsvc->host = ""; + } + expected_altsvc->port = 42; + header_field_value->append(":42\""); + if (i & 1 << 2) { + header_field_value->append(" "); + } + if (i & 3 << 3) { + expected_altsvc->max_age_seconds = 1111; + header_field_value->append(";"); + if (i & 1 << 3) { + header_field_value->append(" "); + } + header_field_value->append("mA=1111"); + if (i & 2 << 3) { + header_field_value->append(" "); + } + } + if (i & 1 << 5) { + header_field_value->append("; J=s"); + } + if (i & 1 << 6) { + if (is_ietf_format_quic) { + if (i & 1 << 7) { + expected_altsvc->version.push_back(0x923457e); + header_field_value->append("; quic=923457E"); + } else { + expected_altsvc->version.push_back(1); + expected_altsvc->version.push_back(0xFFFFFFFF); + header_field_value->append("; quic=1; quic=fFfFffFf"); + } + } else { + if (i & i << 7) { + expected_altsvc->version.push_back(24); + header_field_value->append("; v=\"24\""); + } else { + expected_altsvc->version.push_back(1); + expected_altsvc->version.push_back(65535); + header_field_value->append("; v=\"1,65535\""); + } + } + } + if (i & 1 << 8) { + expected_altsvc->max_age_seconds = 999999999; + header_field_value->append("; Ma=999999999"); + } + if (i & 1 << 9) { + header_field_value->append(";"); + } + if (i & 1 << 10) { + header_field_value->append(" "); + } + if (i & 1 << 11) { + header_field_value->append(","); + } + if (i & 1 << 12) { + header_field_value->append(" "); + } +} + +// Generate AlternativeService entries and corresponding header field values in +// canonical form, that is, what SerializeHeaderFieldValue() should output. +void FuzzAlternativeService(int i, + SpdyAltSvcWireFormat::AlternativeService* altsvc, + std::string* expected_header_field_value) { + if (!expected_header_field_value->empty()) { + expected_header_field_value->push_back(','); + } + altsvc->protocol_id = "a=b%c"; + altsvc->port = 42; + expected_header_field_value->append("a%3Db%25c=\""); + if (i & 1 << 0) { + altsvc->host = "foo\"bar\\baz"; + expected_header_field_value->append("foo\\\"bar\\\\baz"); + } + expected_header_field_value->append(":42\""); + if (i & 1 << 1) { + altsvc->max_age_seconds = 1111; + expected_header_field_value->append("; ma=1111"); + } + if (i & 1 << 2) { + altsvc->version.push_back(24); + altsvc->version.push_back(25); + expected_header_field_value->append("; v=\"24,25\""); + } +} + +// Tests of public API. + +TEST(SpdyAltSvcWireFormatTest, DefaultValues) { + SpdyAltSvcWireFormat::AlternativeService altsvc; + EXPECT_EQ("", altsvc.protocol_id); + EXPECT_EQ("", altsvc.host); + EXPECT_EQ(0u, altsvc.port); + EXPECT_EQ(86400u, altsvc.max_age_seconds); + EXPECT_TRUE(altsvc.version.empty()); +} + +TEST(SpdyAltSvcWireFormatTest, ParseInvalidEmptyHeaderFieldValue) { + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + ASSERT_FALSE(SpdyAltSvcWireFormat::ParseHeaderFieldValue("", &altsvc_vector)); +} + +TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValueClear) { + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + ASSERT_TRUE( + SpdyAltSvcWireFormat::ParseHeaderFieldValue("clear", &altsvc_vector)); + EXPECT_EQ(0u, altsvc_vector.size()); +} + +// Fuzz test of ParseHeaderFieldValue() with optional whitespaces, ignored +// parameters, duplicate parameters, trailing space, trailing alternate service +// separator, etc. Single alternative service at a time. +TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValue) { + for (int i = 0; i < 1 << 13; ++i) { + std::string header_field_value; + SpdyAltSvcWireFormat::AlternativeService expected_altsvc; + FuzzHeaderFieldValue(i, &header_field_value, &expected_altsvc); + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + ASSERT_TRUE(SpdyAltSvcWireFormat::ParseHeaderFieldValue(header_field_value, + &altsvc_vector)); + ASSERT_EQ(1u, altsvc_vector.size()); + EXPECT_EQ(expected_altsvc.protocol_id, altsvc_vector[0].protocol_id); + EXPECT_EQ(expected_altsvc.host, altsvc_vector[0].host); + EXPECT_EQ(expected_altsvc.port, altsvc_vector[0].port); + EXPECT_EQ(expected_altsvc.max_age_seconds, + altsvc_vector[0].max_age_seconds); + EXPECT_EQ(expected_altsvc.version, altsvc_vector[0].version); + + // Roundtrip test starting with |altsvc_vector|. + std::string reserialized_header_field_value = + SpdyAltSvcWireFormat::SerializeHeaderFieldValue(altsvc_vector); + SpdyAltSvcWireFormat::AlternativeServiceVector roundtrip_altsvc_vector; + ASSERT_TRUE(SpdyAltSvcWireFormat::ParseHeaderFieldValue( + reserialized_header_field_value, &roundtrip_altsvc_vector)); + ASSERT_EQ(1u, roundtrip_altsvc_vector.size()); + EXPECT_EQ(expected_altsvc.protocol_id, + roundtrip_altsvc_vector[0].protocol_id); + EXPECT_EQ(expected_altsvc.host, roundtrip_altsvc_vector[0].host); + EXPECT_EQ(expected_altsvc.port, roundtrip_altsvc_vector[0].port); + EXPECT_EQ(expected_altsvc.max_age_seconds, + roundtrip_altsvc_vector[0].max_age_seconds); + EXPECT_EQ(expected_altsvc.version, roundtrip_altsvc_vector[0].version); + } +} + +// Fuzz test of ParseHeaderFieldValue() with optional whitespaces, ignored +// parameters, duplicate parameters, trailing space, trailing alternate service +// separator, etc. Possibly multiple alternative service at a time. +TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValueMultiple) { + for (int i = 0; i < 1 << 13;) { + std::string header_field_value; + SpdyAltSvcWireFormat::AlternativeServiceVector expected_altsvc_vector; + // This will generate almost two hundred header field values with two, + // three, four, five, six, and seven alternative services each, and + // thousands with a single one. + do { + SpdyAltSvcWireFormat::AlternativeService expected_altsvc; + FuzzHeaderFieldValue(i, &header_field_value, &expected_altsvc); + expected_altsvc_vector.push_back(expected_altsvc); + ++i; + } while (i % 6 < i % 7); + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + ASSERT_TRUE(SpdyAltSvcWireFormat::ParseHeaderFieldValue(header_field_value, + &altsvc_vector)); + ASSERT_EQ(expected_altsvc_vector.size(), altsvc_vector.size()); + for (unsigned int j = 0; j < altsvc_vector.size(); ++j) { + EXPECT_EQ(expected_altsvc_vector[j].protocol_id, + altsvc_vector[j].protocol_id); + EXPECT_EQ(expected_altsvc_vector[j].host, altsvc_vector[j].host); + EXPECT_EQ(expected_altsvc_vector[j].port, altsvc_vector[j].port); + EXPECT_EQ(expected_altsvc_vector[j].max_age_seconds, + altsvc_vector[j].max_age_seconds); + EXPECT_EQ(expected_altsvc_vector[j].version, altsvc_vector[j].version); + } + + // Roundtrip test starting with |altsvc_vector|. + std::string reserialized_header_field_value = + SpdyAltSvcWireFormat::SerializeHeaderFieldValue(altsvc_vector); + SpdyAltSvcWireFormat::AlternativeServiceVector roundtrip_altsvc_vector; + ASSERT_TRUE(SpdyAltSvcWireFormat::ParseHeaderFieldValue( + reserialized_header_field_value, &roundtrip_altsvc_vector)); + ASSERT_EQ(expected_altsvc_vector.size(), roundtrip_altsvc_vector.size()); + for (unsigned int j = 0; j < roundtrip_altsvc_vector.size(); ++j) { + EXPECT_EQ(expected_altsvc_vector[j].protocol_id, + roundtrip_altsvc_vector[j].protocol_id); + EXPECT_EQ(expected_altsvc_vector[j].host, + roundtrip_altsvc_vector[j].host); + EXPECT_EQ(expected_altsvc_vector[j].port, + roundtrip_altsvc_vector[j].port); + EXPECT_EQ(expected_altsvc_vector[j].max_age_seconds, + roundtrip_altsvc_vector[j].max_age_seconds); + EXPECT_EQ(expected_altsvc_vector[j].version, + roundtrip_altsvc_vector[j].version); + } + } +} + +TEST(SpdyAltSvcWireFormatTest, SerializeEmptyHeaderFieldValue) { + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + EXPECT_EQ("clear", + SpdyAltSvcWireFormat::SerializeHeaderFieldValue(altsvc_vector)); +} + +// Test ParseHeaderFieldValue() and SerializeHeaderFieldValue() on the same pair +// of |expected_header_field_value| and |altsvc|, with and without hostname and +// each +// parameter. Single alternative service at a time. +TEST(SpdyAltSvcWireFormatTest, RoundTrip) { + for (int i = 0; i < 1 << 3; ++i) { + SpdyAltSvcWireFormat::AlternativeService altsvc; + std::string expected_header_field_value; + FuzzAlternativeService(i, &altsvc, &expected_header_field_value); + + // Test ParseHeaderFieldValue(). + SpdyAltSvcWireFormat::AlternativeServiceVector parsed_altsvc_vector; + ASSERT_TRUE(SpdyAltSvcWireFormat::ParseHeaderFieldValue( + expected_header_field_value, &parsed_altsvc_vector)); + ASSERT_EQ(1u, parsed_altsvc_vector.size()); + EXPECT_EQ(altsvc.protocol_id, parsed_altsvc_vector[0].protocol_id); + EXPECT_EQ(altsvc.host, parsed_altsvc_vector[0].host); + EXPECT_EQ(altsvc.port, parsed_altsvc_vector[0].port); + EXPECT_EQ(altsvc.max_age_seconds, parsed_altsvc_vector[0].max_age_seconds); + EXPECT_EQ(altsvc.version, parsed_altsvc_vector[0].version); + + // Test SerializeHeaderFieldValue(). + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + altsvc_vector.push_back(altsvc); + EXPECT_EQ(expected_header_field_value, + SpdyAltSvcWireFormat::SerializeHeaderFieldValue(altsvc_vector)); + } +} + +// Test ParseHeaderFieldValue() and SerializeHeaderFieldValue() on the same pair +// of |expected_header_field_value| and |altsvc|, with and without hostname and +// each +// parameter. Multiple alternative services at a time. +TEST(SpdyAltSvcWireFormatTest, RoundTripMultiple) { + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + std::string expected_header_field_value; + for (int i = 0; i < 1 << 3; ++i) { + SpdyAltSvcWireFormat::AlternativeService altsvc; + FuzzAlternativeService(i, &altsvc, &expected_header_field_value); + altsvc_vector.push_back(altsvc); + } + + // Test ParseHeaderFieldValue(). + SpdyAltSvcWireFormat::AlternativeServiceVector parsed_altsvc_vector; + ASSERT_TRUE(SpdyAltSvcWireFormat::ParseHeaderFieldValue( + expected_header_field_value, &parsed_altsvc_vector)); + ASSERT_EQ(altsvc_vector.size(), parsed_altsvc_vector.size()); + auto expected_it = altsvc_vector.begin(); + auto parsed_it = parsed_altsvc_vector.begin(); + for (; expected_it != altsvc_vector.end(); ++expected_it, ++parsed_it) { + EXPECT_EQ(expected_it->protocol_id, parsed_it->protocol_id); + EXPECT_EQ(expected_it->host, parsed_it->host); + EXPECT_EQ(expected_it->port, parsed_it->port); + EXPECT_EQ(expected_it->max_age_seconds, parsed_it->max_age_seconds); + EXPECT_EQ(expected_it->version, parsed_it->version); + } + + // Test SerializeHeaderFieldValue(). + EXPECT_EQ(expected_header_field_value, + SpdyAltSvcWireFormat::SerializeHeaderFieldValue(altsvc_vector)); +} + +// ParseHeaderFieldValue() should return false on malformed field values: +// invalid percent encoding, unmatched quotation mark, empty port, non-numeric +// characters in numeric fields. +TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValueInvalid) { + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + const char* invalid_field_value_array[] = {"a%", + "a%x", + "a%b", + "a%9z", + "a=", + "a=\"", + "a=\"b\"", + "a=\":\"", + "a=\"c:\"", + "a=\"c:foo\"", + "a=\"c:42foo\"", + "a=\"b:42\"bar", + "a=\"b:42\" ; m", + "a=\"b:42\" ; min-age", + "a=\"b:42\" ; ma", + "a=\"b:42\" ; ma=", + "a=\"b:42\" ; v=\"..\"", + "a=\"b:42\" ; ma=ma", + "a=\"b:42\" ; ma=123bar", + "a=\"b:42\" ; v=24", + "a=\"b:42\" ; v=24,25", + "a=\"b:42\" ; v=\"-3\"", + "a=\"b:42\" ; v=\"1.2\"", + "a=\"b:42\" ; v=\"24,\""}; + for (const char* invalid_field_value : invalid_field_value_array) { + EXPECT_FALSE(SpdyAltSvcWireFormat::ParseHeaderFieldValue( + invalid_field_value, &altsvc_vector)) + << invalid_field_value; + } +} + +// ParseHeaderFieldValue() should return false on a field values truncated +// before closing quotation mark, without trying to access memory beyond the end +// of the input. +TEST(SpdyAltSvcWireFormatTest, ParseTruncatedHeaderFieldValue) { + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + const char* field_value_array[] = {"a=\":137\"", "a=\"foo:137\"", + "a%25=\"foo\\\"bar\\\\baz:137\""}; + for (const absl::string_view field_value : field_value_array) { + for (size_t len = 1; len < field_value.size(); ++len) { + EXPECT_FALSE(SpdyAltSvcWireFormat::ParseHeaderFieldValue( + field_value.substr(0, len), &altsvc_vector)) + << len; + } + } +} + +// Tests of private methods. + +// Test SkipWhiteSpace(). +TEST(SpdyAltSvcWireFormatTest, SkipWhiteSpace) { + absl::string_view input("a \tb "); + absl::string_view::const_iterator c = input.begin(); + SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); + ASSERT_EQ(input.begin(), c); + ++c; + SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); + ASSERT_EQ(input.begin() + 3, c); + ++c; + SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); + ASSERT_EQ(input.end(), c); +} + +// Test PercentDecode() on valid input. +TEST(SpdyAltSvcWireFormatTest, PercentDecodeValid) { + absl::string_view input(""); + std::string output; + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)); + EXPECT_EQ("", output); + + input = absl::string_view("foo"); + output.clear(); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)); + EXPECT_EQ("foo", output); + + input = absl::string_view("%2ca%5Cb"); + output.clear(); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)); + EXPECT_EQ(",a\\b", output); +} + +// Test PercentDecode() on invalid input. +TEST(SpdyAltSvcWireFormatTest, PercentDecodeInvalid) { + const char* invalid_input_array[] = {"a%", "a%x", "a%b", "%J22", "%9z"}; + for (const char* invalid_input : invalid_input_array) { + absl::string_view input(invalid_input); + std::string output; + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)) + << input; + } +} + +// Test ParseAltAuthority() on valid input. +TEST(SpdyAltSvcWireFormatTest, ParseAltAuthorityValid) { + absl::string_view input(":42"); + std::string host; + uint16_t port; + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( + input.begin(), input.end(), &host, &port)); + EXPECT_TRUE(host.empty()); + EXPECT_EQ(42, port); + + input = absl::string_view("foo:137"); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( + input.begin(), input.end(), &host, &port)); + EXPECT_EQ("foo", host); + EXPECT_EQ(137, port); + + input = absl::string_view("[2003:8:0:16::509d:9615]:443"); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( + input.begin(), input.end(), &host, &port)); + EXPECT_EQ("[2003:8:0:16::509d:9615]", host); + EXPECT_EQ(443, port); +} + +// Test ParseAltAuthority() on invalid input: empty string, no port, zero port, +// non-digit characters following port. +TEST(SpdyAltSvcWireFormatTest, ParseAltAuthorityInvalid) { + const char* invalid_input_array[] = {"", + ":", + "foo:", + ":bar", + ":0", + "foo:0", + ":12bar", + "foo:23bar", + " ", + ":12 ", + "foo:12 ", + "[2003:8:0:16::509d:9615]", + "[2003:8:0:16::509d:9615]:", + "[2003:8:0:16::509d:9615]foo:443", + "[2003:8:0:16::509d:9615:443", + "2003:8:0:16::509d:9615]:443"}; + for (const char* invalid_input : invalid_input_array) { + absl::string_view input(invalid_input); + std::string host; + uint16_t port; + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( + input.begin(), input.end(), &host, &port)) + << input; + } +} + +// Test ParseInteger() on valid input. +TEST(SpdyAltSvcWireFormatTest, ParseIntegerValid) { + absl::string_view input("3"); + uint16_t value; + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + input.begin(), input.end(), &value)); + EXPECT_EQ(3, value); + + input = absl::string_view("1337"); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + input.begin(), input.end(), &value)); + EXPECT_EQ(1337, value); +} + +// Test ParseIntegerValid() on invalid input: empty, zero, non-numeric, trailing +// non-numeric characters. +TEST(SpdyAltSvcWireFormatTest, ParseIntegerInvalid) { + const char* invalid_input_array[] = {"", " ", "a", "0", "00", "1 ", "12b"}; + for (const char* invalid_input : invalid_input_array) { + absl::string_view input(invalid_input); + uint16_t value; + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + input.begin(), input.end(), &value)) + << input; + } +} + +// Test ParseIntegerValid() around overflow limit. +TEST(SpdyAltSvcWireFormatTest, ParseIntegerOverflow) { + // Largest possible uint16_t value. + absl::string_view input("65535"); + uint16_t value16; + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + input.begin(), input.end(), &value16)); + EXPECT_EQ(65535, value16); + + // Overflow uint16_t, ParsePositiveInteger16() should return false. + input = absl::string_view("65536"); + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + input.begin(), input.end(), &value16)); + + // However, even if overflow is not checked for, 65536 overflows to 0, which + // returns false anyway. Check for a larger number which overflows to 1. + input = absl::string_view("65537"); + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + input.begin(), input.end(), &value16)); + + // Largest possible uint32_t value. + input = absl::string_view("4294967295"); + uint32_t value32; + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( + input.begin(), input.end(), &value32)); + EXPECT_EQ(4294967295, value32); + + // Overflow uint32_t, ParsePositiveInteger32() should return false. + input = absl::string_view("4294967296"); + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( + input.begin(), input.end(), &value32)); + + // However, even if overflow is not checked for, 4294967296 overflows to 0, + // which returns false anyway. Check for a larger number which overflows to + // 1. + input = absl::string_view("4294967297"); + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( + input.begin(), input.end(), &value32)); +} + +// Test parsing an Alt-Svc entry with IP literal hostname. +// Regression test for https://crbug.com/664173. +TEST(SpdyAltSvcWireFormatTest, ParseIPLiteral) { + const char* input = + "quic=\"[2003:8:0:16::509d:9615]:443\"; v=\"36,35\"; ma=60"; + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + ASSERT_TRUE( + SpdyAltSvcWireFormat::ParseHeaderFieldValue(input, &altsvc_vector)); + EXPECT_EQ(1u, altsvc_vector.size()); + EXPECT_EQ("quic", altsvc_vector[0].protocol_id); + EXPECT_EQ("[2003:8:0:16::509d:9615]", altsvc_vector[0].host); + EXPECT_EQ(443u, altsvc_vector[0].port); + EXPECT_EQ(60u, altsvc_vector[0].max_age_seconds); + EXPECT_THAT(altsvc_vector[0].version, ::testing::ElementsAre(36, 35)); +} + +TEST(SpdyAltSvcWireFormatTest, HexDigitToInt) { + EXPECT_EQ(0, SpdyAltSvcWireFormatPeer::HexDigitToInt('0')); + EXPECT_EQ(1, SpdyAltSvcWireFormatPeer::HexDigitToInt('1')); + EXPECT_EQ(2, SpdyAltSvcWireFormatPeer::HexDigitToInt('2')); + EXPECT_EQ(3, SpdyAltSvcWireFormatPeer::HexDigitToInt('3')); + EXPECT_EQ(4, SpdyAltSvcWireFormatPeer::HexDigitToInt('4')); + EXPECT_EQ(5, SpdyAltSvcWireFormatPeer::HexDigitToInt('5')); + EXPECT_EQ(6, SpdyAltSvcWireFormatPeer::HexDigitToInt('6')); + EXPECT_EQ(7, SpdyAltSvcWireFormatPeer::HexDigitToInt('7')); + EXPECT_EQ(8, SpdyAltSvcWireFormatPeer::HexDigitToInt('8')); + EXPECT_EQ(9, SpdyAltSvcWireFormatPeer::HexDigitToInt('9')); + + EXPECT_EQ(10, SpdyAltSvcWireFormatPeer::HexDigitToInt('a')); + EXPECT_EQ(11, SpdyAltSvcWireFormatPeer::HexDigitToInt('b')); + EXPECT_EQ(12, SpdyAltSvcWireFormatPeer::HexDigitToInt('c')); + EXPECT_EQ(13, SpdyAltSvcWireFormatPeer::HexDigitToInt('d')); + EXPECT_EQ(14, SpdyAltSvcWireFormatPeer::HexDigitToInt('e')); + EXPECT_EQ(15, SpdyAltSvcWireFormatPeer::HexDigitToInt('f')); + + EXPECT_EQ(10, SpdyAltSvcWireFormatPeer::HexDigitToInt('A')); + EXPECT_EQ(11, SpdyAltSvcWireFormatPeer::HexDigitToInt('B')); + EXPECT_EQ(12, SpdyAltSvcWireFormatPeer::HexDigitToInt('C')); + EXPECT_EQ(13, SpdyAltSvcWireFormatPeer::HexDigitToInt('D')); + EXPECT_EQ(14, SpdyAltSvcWireFormatPeer::HexDigitToInt('E')); + EXPECT_EQ(15, SpdyAltSvcWireFormatPeer::HexDigitToInt('F')); +} + +TEST(SpdyAltSvcWireFormatTest, HexDecodeToUInt32) { + uint32_t out; + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("0", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("00", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("0000000", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("00000000", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("1", &out)); + EXPECT_EQ(1u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("ffffFFF", &out)); + EXPECT_EQ(0xFFFFFFFu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("fFfFffFf", &out)); + EXPECT_EQ(0xFFFFFFFFu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("01AEF", &out)); + EXPECT_EQ(0x1AEFu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("abcde", &out)); + EXPECT_EQ(0xABCDEu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("1234abcd", &out)); + EXPECT_EQ(0x1234ABCDu, out); + + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("", &out)); + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("111111111", &out)); + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("1111111111", &out)); + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("0x1111", &out)); +} + +} // namespace + +} // namespace test + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_bitmasks.h b/quiche/spdy/core/spdy_bitmasks.h new file mode 100644 index 000000000000..657bd1761e98 --- /dev/null +++ b/quiche/spdy/core/spdy_bitmasks.h @@ -0,0 +1,18 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_BITMASKS_H_ +#define QUICHE_SPDY_CORE_SPDY_BITMASKS_H_ + +namespace spdy { + +// StreamId mask from the SpdyHeader +const unsigned int kStreamIdMask = 0x7fffffff; + +// Mask the lower 24 bits. +const unsigned int kLengthMask = 0xffffff; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_BITMASKS_H_ diff --git a/quiche/spdy/core/spdy_frame_builder.cc b/quiche/spdy/core/spdy_frame_builder.cc new file mode 100644 index 000000000000..050d01ef456c --- /dev/null +++ b/quiche/spdy/core/spdy_frame_builder.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_frame_builder.h" + +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/core/zero_copy_output_buffer.h" + +namespace spdy { + +SpdyFrameBuilder::SpdyFrameBuilder(size_t size) + : buffer_(new char[size]), capacity_(size), length_(0), offset_(0) {} + +SpdyFrameBuilder::SpdyFrameBuilder(size_t size, ZeroCopyOutputBuffer* output) + : buffer_(output == nullptr ? new char[size] : nullptr), + output_(output), + capacity_(size), + length_(0), + offset_(0) {} + +SpdyFrameBuilder::~SpdyFrameBuilder() = default; + +char* SpdyFrameBuilder::GetWritableBuffer(size_t length) { + if (!CanWrite(length)) { + return nullptr; + } + return buffer_.get() + offset_ + length_; +} + +char* SpdyFrameBuilder::GetWritableOutput(size_t length, + size_t* actual_length) { + char* dest = nullptr; + int size = 0; + + if (!CanWrite(length)) { + return nullptr; + } + output_->Next(&dest, &size); + *actual_length = std::min(length, size); + return dest; +} + +bool SpdyFrameBuilder::Seek(size_t length) { + if (!CanWrite(length)) { + return false; + } + if (output_ == nullptr) { + length_ += length; + } else { + output_->AdvanceWritePtr(length); + length_ += length; + } + return true; +} + +bool SpdyFrameBuilder::BeginNewFrame(SpdyFrameType type, uint8_t flags, + SpdyStreamId stream_id) { + uint8_t raw_frame_type = SerializeFrameType(type); + QUICHE_DCHECK(IsDefinedFrameType(raw_frame_type)); + QUICHE_DCHECK_EQ(0u, stream_id & ~kStreamIdMask); + bool success = true; + if (length_ > 0) { + QUICHE_BUG(spdy_bug_73_1) + << "SpdyFrameBuilder doesn't have a clean state when BeginNewFrame" + << "is called. Leftover length_ is " << length_; + offset_ += length_; + length_ = 0; + } + + success &= WriteUInt24(capacity_ - offset_ - kFrameHeaderSize); + success &= WriteUInt8(raw_frame_type); + success &= WriteUInt8(flags); + success &= WriteUInt32(stream_id); + QUICHE_DCHECK_EQ(kDataFrameMinimumSize, length_); + return success; +} + +bool SpdyFrameBuilder::BeginNewFrame(SpdyFrameType type, uint8_t flags, + SpdyStreamId stream_id, size_t length) { + uint8_t raw_frame_type = SerializeFrameType(type); + QUICHE_DCHECK(IsDefinedFrameType(raw_frame_type)); + QUICHE_DCHECK_EQ(0u, stream_id & ~kStreamIdMask); + QUICHE_BUG_IF(spdy_bug_73_2, length > kSpdyMaxFrameSizeLimit) + << "Frame length " << length << " is longer than frame size limit."; + return BeginNewFrameInternal(raw_frame_type, flags, stream_id, length); +} + +bool SpdyFrameBuilder::BeginNewUncheckedFrame(uint8_t raw_frame_type, + uint8_t flags, + SpdyStreamId stream_id, + size_t length) { + return BeginNewFrameInternal(raw_frame_type, flags, stream_id, length); +} + +bool SpdyFrameBuilder::BeginNewFrameInternal(uint8_t raw_frame_type, + uint8_t flags, + SpdyStreamId stream_id, + size_t length) { + QUICHE_DCHECK_EQ(length, length & kLengthMask); + bool success = true; + + offset_ += length_; + length_ = 0; + + success &= WriteUInt24(length); + success &= WriteUInt8(raw_frame_type); + success &= WriteUInt8(flags); + success &= WriteUInt32(stream_id); + QUICHE_DCHECK_EQ(kDataFrameMinimumSize, length_); + return success; +} + +bool SpdyFrameBuilder::WriteStringPiece32(const absl::string_view value) { + if (!WriteUInt32(value.size())) { + return false; + } + + return WriteBytes(value.data(), value.size()); +} + +bool SpdyFrameBuilder::WriteBytes(const void* data, uint32_t data_len) { + if (!CanWrite(data_len)) { + return false; + } + + if (output_ == nullptr) { + char* dest = GetWritableBuffer(data_len); + memcpy(dest, data, data_len); + Seek(data_len); + } else { + char* dest = nullptr; + size_t size = 0; + size_t total_written = 0; + const char* data_ptr = reinterpret_cast(data); + while (data_len > 0) { + dest = GetWritableOutput(data_len, &size); + if (dest == nullptr || size == 0) { + // Unable to make progress. + return false; + } + uint32_t to_copy = std::min(data_len, size); + const char* src = data_ptr + total_written; + memcpy(dest, src, to_copy); + Seek(to_copy); + data_len -= to_copy; + total_written += to_copy; + } + } + return true; +} + +bool SpdyFrameBuilder::CanWrite(size_t length) const { + if (length > kLengthMask) { + QUICHE_DCHECK(false); + return false; + } + + if (output_ == nullptr) { + if (offset_ + length_ + length > capacity_) { + QUICHE_DLOG(FATAL) << "Requested: " << length + << " capacity: " << capacity_ + << " used: " << offset_ + length_; + return false; + } + } else { + if (length > output_->BytesFree()) { + return false; + } + } + + return true; +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_frame_builder.h b/quiche/spdy/core/spdy_frame_builder.h new file mode 100644 index 000000000000..69cf3526ea52 --- /dev/null +++ b/quiche/spdy/core/spdy_frame_builder.h @@ -0,0 +1,140 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_FRAME_BUILDER_H_ +#define QUICHE_SPDY_CORE_SPDY_FRAME_BUILDER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_endian.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/core/zero_copy_output_buffer.h" + +namespace spdy { + +namespace test { +class SpdyFrameBuilderPeer; +} // namespace test + +// This class provides facilities for basic binary value packing +// into Spdy frames. +// +// The SpdyFrameBuilder supports appending primitive values (int, string, etc) +// to a frame instance. The SpdyFrameBuilder grows its internal memory buffer +// dynamically to hold the sequence of primitive values. The internal memory +// buffer is exposed as the "data" of the SpdyFrameBuilder. +class QUICHE_EXPORT SpdyFrameBuilder { + public: + // Initializes a SpdyFrameBuilder with a buffer of given size + explicit SpdyFrameBuilder(size_t size); + // Doesn't take ownership of output. + SpdyFrameBuilder(size_t size, ZeroCopyOutputBuffer* output); + + ~SpdyFrameBuilder(); + + // Returns the total size of the SpdyFrameBuilder's data, which may include + // multiple frames. + size_t length() const { return offset_ + length_; } + + // Seeks forward by the given number of bytes. Useful in conjunction with + // GetWriteableBuffer() above. + bool Seek(size_t length); + + // Populates this frame with a HTTP2 frame prefix using length information + // from |capacity_|. The given type must be a control frame type. + bool BeginNewFrame(SpdyFrameType type, uint8_t flags, SpdyStreamId stream_id); + + // Populates this frame with a HTTP2 frame prefix with type and length + // information. |type| must be a defined frame type. + bool BeginNewFrame(SpdyFrameType type, uint8_t flags, SpdyStreamId stream_id, + size_t length); + + // Populates this frame with a HTTP2 frame prefix with type and length + // information. |raw_frame_type| may be a defined or undefined frame type. + bool BeginNewUncheckedFrame(uint8_t raw_frame_type, uint8_t flags, + SpdyStreamId stream_id, size_t length); + + // Takes the buffer from the SpdyFrameBuilder. + SpdySerializedFrame take() { + QUICHE_BUG_IF(spdy_bug_39_1, output_ != nullptr) + << "ZeroCopyOutputBuffer is used to build " + << "frames. take() shouldn't be called"; + QUICHE_BUG_IF(spdy_bug_39_2, kMaxFrameSizeLimit < length_) + << "Frame length " << length_ + << " is longer than the maximum possible allowed length."; + SpdySerializedFrame rv(buffer_.release(), length(), true); + capacity_ = 0; + length_ = 0; + offset_ = 0; + return rv; + } + + // Methods for adding to the payload. These values are appended to the end + // of the SpdyFrameBuilder payload. Note - binary integers are converted from + // host to network form. + bool WriteUInt8(uint8_t value) { return WriteBytes(&value, sizeof(value)); } + bool WriteUInt16(uint16_t value) { + value = quiche::QuicheEndian::HostToNet16(value); + return WriteBytes(&value, sizeof(value)); + } + bool WriteUInt24(uint32_t value) { + value = quiche::QuicheEndian::HostToNet32(value); + return WriteBytes(reinterpret_cast(&value) + 1, sizeof(value) - 1); + } + bool WriteUInt32(uint32_t value) { + value = quiche::QuicheEndian::HostToNet32(value); + return WriteBytes(&value, sizeof(value)); + } + bool WriteUInt64(uint64_t value) { + uint32_t upper = + quiche::QuicheEndian::HostToNet32(static_cast(value >> 32)); + uint32_t lower = + quiche::QuicheEndian::HostToNet32(static_cast(value)); + return (WriteBytes(&upper, sizeof(upper)) && + WriteBytes(&lower, sizeof(lower))); + } + bool WriteStringPiece32(const absl::string_view value); + bool WriteBytes(const void* data, uint32_t data_len); + + private: + friend class test::SpdyFrameBuilderPeer; + + // Populates this frame with a HTTP2 frame prefix with type and length + // information. + bool BeginNewFrameInternal(uint8_t raw_frame_type, uint8_t flags, + SpdyStreamId stream_id, size_t length); + + // Returns a writeable buffer of given size in bytes, to be appended to the + // currently written frame. Does bounds checking on length but does not + // increment the underlying iterator. To do so, consumers should subsequently + // call Seek(). + // In general, consumers should use Write*() calls instead of this. + // Returns NULL on failure. + char* GetWritableBuffer(size_t length); + char* GetWritableOutput(size_t desired_length, size_t* actual_length); + + // Checks to make sure that there is an appropriate amount of space for a + // write of given size, in bytes. + bool CanWrite(size_t length) const; + + // A buffer to be created whenever a new frame needs to be written. Used only + // if |output_| is nullptr. + std::unique_ptr buffer_; + // A pre-allocated buffer. If not-null, serialized frame data is written to + // this buffer. + ZeroCopyOutputBuffer* output_ = nullptr; // Does not own. + + size_t capacity_; // Allocation size of payload, set by constructor. + size_t length_; // Length of the latest frame in the buffer. + size_t offset_; // Position at which the latest frame begins. +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_FRAME_BUILDER_H_ diff --git a/quiche/spdy/core/spdy_frame_builder_test.cc b/quiche/spdy/core/spdy_frame_builder_test.cc new file mode 100644 index 000000000000..293d7d8bc530 --- /dev/null +++ b/quiche/spdy/core/spdy_frame_builder_test.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_frame_builder.h" + +#include + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/array_output_buffer.h" +#include "quiche/spdy/core/spdy_framer.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace spdy { + +namespace test { + +class QUICHE_EXPORT SpdyFrameBuilderPeer { + public: + static char* GetWritableBuffer(SpdyFrameBuilder* builder, size_t length) { + return builder->GetWritableBuffer(length); + } + + static char* GetWritableOutput(SpdyFrameBuilder* builder, + size_t desired_length, size_t* actual_length) { + return builder->GetWritableOutput(desired_length, actual_length); + } +}; + +namespace { + +const int64_t kSize = 64 * 1024; +char output_buffer[kSize] = ""; + +} // namespace + +// Verifies that SpdyFrameBuilder::GetWritableBuffer() can be used to build a +// SpdySerializedFrame. +TEST(SpdyFrameBuilderTest, GetWritableBuffer) { + const size_t kBuilderSize = 10; + SpdyFrameBuilder builder(kBuilderSize); + char* writable_buffer = + SpdyFrameBuilderPeer::GetWritableBuffer(&builder, kBuilderSize); + memset(writable_buffer, ~1, kBuilderSize); + EXPECT_TRUE(builder.Seek(kBuilderSize)); + SpdySerializedFrame frame(builder.take()); + char expected[kBuilderSize]; + memset(expected, ~1, kBuilderSize); + EXPECT_EQ(absl::string_view(expected, kBuilderSize), + absl::string_view(frame.data(), kBuilderSize)); +} + +// Verifies that SpdyFrameBuilder::GetWritableBuffer() can be used to build a +// SpdySerializedFrame to the output buffer. +TEST(SpdyFrameBuilderTest, GetWritableOutput) { + ArrayOutputBuffer output(output_buffer, kSize); + const size_t kBuilderSize = 10; + SpdyFrameBuilder builder(kBuilderSize, &output); + size_t actual_size = 0; + char* writable_buffer = SpdyFrameBuilderPeer::GetWritableOutput( + &builder, kBuilderSize, &actual_size); + memset(writable_buffer, ~1, kBuilderSize); + EXPECT_TRUE(builder.Seek(kBuilderSize)); + SpdySerializedFrame frame(output.Begin(), kBuilderSize, false); + char expected[kBuilderSize]; + memset(expected, ~1, kBuilderSize); + EXPECT_EQ(absl::string_view(expected, kBuilderSize), + absl::string_view(frame.data(), kBuilderSize)); +} + +// Verifies the case that the buffer's capacity is too small. +TEST(SpdyFrameBuilderTest, GetWritableOutputNegative) { + size_t small_cap = 1; + ArrayOutputBuffer output(output_buffer, small_cap); + const size_t kBuilderSize = 10; + SpdyFrameBuilder builder(kBuilderSize, &output); + size_t actual_size = 0; + char* writable_buffer = SpdyFrameBuilderPeer::GetWritableOutput( + &builder, kBuilderSize, &actual_size); + EXPECT_EQ(0u, actual_size); + EXPECT_EQ(nullptr, writable_buffer); +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/spdy_framer.cc b/quiche/spdy/core/spdy_framer.cc new file mode 100644 index 000000000000..8b7d69ed8033 --- /dev/null +++ b/quiche/spdy/core/spdy_framer.cc @@ -0,0 +1,1365 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_framer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/spdy_bitmasks.h" +#include "quiche/spdy/core/spdy_frame_builder.h" + +namespace spdy { + +namespace { + +// Pack parent stream ID and exclusive flag into the format used by HTTP/2 +// headers and priority frames. +uint32_t PackStreamDependencyValues(bool exclusive, + SpdyStreamId parent_stream_id) { + // Make sure the highest-order bit in the parent stream id is zeroed out. + uint32_t parent = parent_stream_id & 0x7fffffff; + // Set the one-bit exclusivity flag. + uint32_t e_bit = exclusive ? 0x80000000 : 0; + return parent | e_bit; +} + +// Used to indicate no flags in a HTTP2 flags field. +const uint8_t kNoFlags = 0; + +// Wire size of pad length field. +const size_t kPadLengthFieldSize = 1; + +// The size of one parameter in SETTINGS frame. +const size_t kOneSettingParameterSize = 6; + +size_t GetUncompressedSerializedLength(const Http2HeaderBlock& headers) { + const size_t num_name_value_pairs_size = sizeof(uint32_t); + const size_t length_of_name_size = num_name_value_pairs_size; + const size_t length_of_value_size = num_name_value_pairs_size; + + size_t total_length = num_name_value_pairs_size; + for (const auto& header : headers) { + // We add space for the length of the name and the length of the value as + // well as the length of the name and the length of the value. + total_length += length_of_name_size + header.first.size() + + length_of_value_size + header.second.size(); + } + return total_length; +} + +// Serializes the flags octet for a given SpdyHeadersIR. +uint8_t SerializeHeaderFrameFlags(const SpdyHeadersIR& header_ir, + const bool end_headers) { + uint8_t flags = 0; + if (header_ir.fin()) { + flags |= CONTROL_FLAG_FIN; + } + if (end_headers) { + flags |= HEADERS_FLAG_END_HEADERS; + } + if (header_ir.padded()) { + flags |= HEADERS_FLAG_PADDED; + } + if (header_ir.has_priority()) { + flags |= HEADERS_FLAG_PRIORITY; + } + return flags; +} + +// Serializes the flags octet for a given SpdyPushPromiseIR. +uint8_t SerializePushPromiseFrameFlags(const SpdyPushPromiseIR& push_promise_ir, + const bool end_headers) { + uint8_t flags = 0; + if (push_promise_ir.padded()) { + flags = flags | PUSH_PROMISE_FLAG_PADDED; + } + if (end_headers) { + flags |= PUSH_PROMISE_FLAG_END_PUSH_PROMISE; + } + return flags; +} + +// Serializes a HEADERS frame from the given SpdyHeadersIR and encoded header +// block. Does not need or use the Http2HeaderBlock inside SpdyHeadersIR. +// Return false if the serialization fails. |encoding| should not be empty. +bool SerializeHeadersGivenEncoding(const SpdyHeadersIR& headers, + const std::string& encoding, + const bool end_headers, + ZeroCopyOutputBuffer* output) { + const size_t frame_size = + GetHeaderFrameSizeSansBlock(headers) + encoding.size(); + SpdyFrameBuilder builder(frame_size, output); + bool ret = builder.BeginNewFrame( + SpdyFrameType::HEADERS, SerializeHeaderFrameFlags(headers, end_headers), + headers.stream_id(), frame_size - kFrameHeaderSize); + QUICHE_DCHECK_EQ(kFrameHeaderSize, builder.length()); + + if (ret && headers.padded()) { + ret &= builder.WriteUInt8(headers.padding_payload_len()); + } + + if (ret && headers.has_priority()) { + int weight = ClampHttp2Weight(headers.weight()); + ret &= builder.WriteUInt32(PackStreamDependencyValues( + headers.exclusive(), headers.parent_stream_id())); + // Per RFC 7540 section 6.3, serialized weight value is actual value - 1. + ret &= builder.WriteUInt8(weight - 1); + } + + if (ret) { + ret &= builder.WriteBytes(encoding.data(), encoding.size()); + } + + if (ret && headers.padding_payload_len() > 0) { + std::string padding(headers.padding_payload_len(), 0); + ret &= builder.WriteBytes(padding.data(), padding.length()); + } + + if (!ret) { + QUICHE_DLOG(WARNING) + << "Failed to build HEADERS. Not enough space in output"; + } + return ret; +} + +// Serializes a PUSH_PROMISE frame from the given SpdyPushPromiseIR and +// encoded header block. Does not need or use the Http2HeaderBlock inside +// SpdyPushPromiseIR. +bool SerializePushPromiseGivenEncoding(const SpdyPushPromiseIR& push_promise, + const std::string& encoding, + const bool end_headers, + ZeroCopyOutputBuffer* output) { + const size_t frame_size = + GetPushPromiseFrameSizeSansBlock(push_promise) + encoding.size(); + SpdyFrameBuilder builder(frame_size, output); + bool ok = builder.BeginNewFrame( + SpdyFrameType::PUSH_PROMISE, + SerializePushPromiseFrameFlags(push_promise, end_headers), + push_promise.stream_id(), frame_size - kFrameHeaderSize); + + if (push_promise.padded()) { + ok = ok && builder.WriteUInt8(push_promise.padding_payload_len()); + } + ok = ok && builder.WriteUInt32(push_promise.promised_stream_id()) && + builder.WriteBytes(encoding.data(), encoding.size()); + if (ok && push_promise.padding_payload_len() > 0) { + std::string padding(push_promise.padding_payload_len(), 0); + ok = builder.WriteBytes(padding.data(), padding.length()); + } + + QUICHE_DLOG_IF(ERROR, !ok) + << "Failed to write PUSH_PROMISE encoding, not enough " + << "space in output"; + return ok; +} + +bool WritePayloadWithContinuation(SpdyFrameBuilder* builder, + const std::string& hpack_encoding, + SpdyStreamId stream_id, SpdyFrameType type, + int padding_payload_len) { + uint8_t end_flag = 0; + uint8_t flags = 0; + if (type == SpdyFrameType::HEADERS) { + end_flag = HEADERS_FLAG_END_HEADERS; + } else if (type == SpdyFrameType::PUSH_PROMISE) { + end_flag = PUSH_PROMISE_FLAG_END_PUSH_PROMISE; + } else { + QUICHE_DLOG(FATAL) << "CONTINUATION frames cannot be used with frame type " + << FrameTypeToString(type); + } + + // Write all the padding payload and as much of the data payload as possible + // into the initial frame. + size_t bytes_remaining = 0; + bytes_remaining = hpack_encoding.size() - + std::min(hpack_encoding.size(), + kHttp2MaxControlFrameSendSize - builder->length() - + padding_payload_len); + bool ret = builder->WriteBytes(&hpack_encoding[0], + hpack_encoding.size() - bytes_remaining); + if (padding_payload_len > 0) { + std::string padding = std::string(padding_payload_len, 0); + ret &= builder->WriteBytes(padding.data(), padding.length()); + } + + // Tack on CONTINUATION frames for the overflow. + while (bytes_remaining > 0 && ret) { + size_t bytes_to_write = + std::min(bytes_remaining, + kHttp2MaxControlFrameSendSize - kContinuationFrameMinimumSize); + // Write CONTINUATION frame prefix. + if (bytes_remaining == bytes_to_write) { + flags |= end_flag; + } + ret &= builder->BeginNewFrame(SpdyFrameType::CONTINUATION, flags, stream_id, + bytes_to_write); + // Write payload fragment. + ret &= builder->WriteBytes( + &hpack_encoding[hpack_encoding.size() - bytes_remaining], + bytes_to_write); + bytes_remaining -= bytes_to_write; + } + return ret; +} + +void SerializeDataBuilderHelper(const SpdyDataIR& data_ir, uint8_t* flags, + int* num_padding_fields, + size_t* size_with_padding) { + if (data_ir.fin()) { + *flags = DATA_FLAG_FIN; + } + + if (data_ir.padded()) { + *flags = *flags | DATA_FLAG_PADDED; + ++*num_padding_fields; + } + + *size_with_padding = *num_padding_fields + data_ir.data_len() + + data_ir.padding_payload_len() + kDataFrameMinimumSize; +} + +void SerializeDataFrameHeaderWithPaddingLengthFieldBuilderHelper( + const SpdyDataIR& data_ir, uint8_t* flags, size_t* frame_size, + size_t* num_padding_fields) { + *flags = DATA_FLAG_NONE; + if (data_ir.fin()) { + *flags = DATA_FLAG_FIN; + } + + *frame_size = kDataFrameMinimumSize; + if (data_ir.padded()) { + *flags = *flags | DATA_FLAG_PADDED; + ++(*num_padding_fields); + *frame_size = *frame_size + *num_padding_fields; + } +} + +void SerializeSettingsBuilderHelper(const SpdySettingsIR& settings, + uint8_t* flags, const SettingsMap* values, + size_t* size) { + if (settings.is_ack()) { + *flags = *flags | SETTINGS_FLAG_ACK; + } + *size = + kSettingsFrameMinimumSize + (values->size() * kOneSettingParameterSize); +} + +void SerializeAltSvcBuilderHelper(const SpdyAltSvcIR& altsvc_ir, + std::string* value, size_t* size) { + *size = kGetAltSvcFrameMinimumSize; + *size = *size + altsvc_ir.origin().length(); + *value = SpdyAltSvcWireFormat::SerializeHeaderFieldValue( + altsvc_ir.altsvc_vector()); + *size = *size + value->length(); +} + +} // namespace + +SpdyFramer::SpdyFramer(CompressionOption option) + : debug_visitor_(nullptr), compression_option_(option) { + static_assert(kHttp2MaxControlFrameSendSize <= kHttp2DefaultFrameSizeLimit, + "Our send limit should be at most our receive limit."); +} + +SpdyFramer::~SpdyFramer() = default; + +void SpdyFramer::set_debug_visitor( + SpdyFramerDebugVisitorInterface* debug_visitor) { + debug_visitor_ = debug_visitor; +} + +SpdyFramer::SpdyFrameIterator::SpdyFrameIterator(SpdyFramer* framer) + : framer_(framer), is_first_frame_(true), has_next_frame_(true) {} + +SpdyFramer::SpdyFrameIterator::~SpdyFrameIterator() = default; + +size_t SpdyFramer::SpdyFrameIterator::NextFrame(ZeroCopyOutputBuffer* output) { + const SpdyFrameIR& frame_ir = GetIR(); + if (!has_next_frame_) { + QUICHE_BUG(spdy_bug_75_1) + << "SpdyFramer::SpdyFrameIterator::NextFrame called without " + << "a next frame."; + return false; + } + + const size_t size_without_block = + is_first_frame_ ? GetFrameSizeSansBlock() : kContinuationFrameMinimumSize; + std::string encoding = + encoder_->Next(kHttp2MaxControlFrameSendSize - size_without_block); + has_next_frame_ = encoder_->HasNext(); + + if (framer_->debug_visitor_ != nullptr) { + const auto& header_block_frame_ir = + static_cast(frame_ir); + const size_t header_list_size = + GetUncompressedSerializedLength(header_block_frame_ir.header_block()); + framer_->debug_visitor_->OnSendCompressedFrame( + frame_ir.stream_id(), + is_first_frame_ ? frame_ir.frame_type() : SpdyFrameType::CONTINUATION, + header_list_size, size_without_block + encoding.size()); + } + + const size_t free_bytes_before = output->BytesFree(); + bool ok = false; + if (is_first_frame_) { + is_first_frame_ = false; + ok = SerializeGivenEncoding(encoding, output); + } else { + SpdyContinuationIR continuation_ir(frame_ir.stream_id()); + continuation_ir.take_encoding(std::move(encoding)); + continuation_ir.set_end_headers(!has_next_frame_); + ok = framer_->SerializeContinuation(continuation_ir, output); + } + return ok ? free_bytes_before - output->BytesFree() : 0; +} + +bool SpdyFramer::SpdyFrameIterator::HasNextFrame() const { + return has_next_frame_; +} + +SpdyFramer::SpdyHeaderFrameIterator::SpdyHeaderFrameIterator( + SpdyFramer* framer, std::unique_ptr headers_ir) + : SpdyFrameIterator(framer), headers_ir_(std::move(headers_ir)) { + SetEncoder(headers_ir_.get()); +} + +SpdyFramer::SpdyHeaderFrameIterator::~SpdyHeaderFrameIterator() = default; + +const SpdyFrameIR& SpdyFramer::SpdyHeaderFrameIterator::GetIR() const { + return *headers_ir_; +} + +size_t SpdyFramer::SpdyHeaderFrameIterator::GetFrameSizeSansBlock() const { + return GetHeaderFrameSizeSansBlock(*headers_ir_); +} + +bool SpdyFramer::SpdyHeaderFrameIterator::SerializeGivenEncoding( + const std::string& encoding, ZeroCopyOutputBuffer* output) const { + return SerializeHeadersGivenEncoding(*headers_ir_, encoding, + !has_next_frame(), output); +} + +SpdyFramer::SpdyPushPromiseFrameIterator::SpdyPushPromiseFrameIterator( + SpdyFramer* framer, + std::unique_ptr push_promise_ir) + : SpdyFrameIterator(framer), push_promise_ir_(std::move(push_promise_ir)) { + SetEncoder(push_promise_ir_.get()); +} + +SpdyFramer::SpdyPushPromiseFrameIterator::~SpdyPushPromiseFrameIterator() = + default; + +const SpdyFrameIR& SpdyFramer::SpdyPushPromiseFrameIterator::GetIR() const { + return *push_promise_ir_; +} + +size_t SpdyFramer::SpdyPushPromiseFrameIterator::GetFrameSizeSansBlock() const { + return GetPushPromiseFrameSizeSansBlock(*push_promise_ir_); +} + +bool SpdyFramer::SpdyPushPromiseFrameIterator::SerializeGivenEncoding( + const std::string& encoding, ZeroCopyOutputBuffer* output) const { + return SerializePushPromiseGivenEncoding(*push_promise_ir_, encoding, + !has_next_frame(), output); +} + +SpdyFramer::SpdyControlFrameIterator::SpdyControlFrameIterator( + SpdyFramer* framer, std::unique_ptr frame_ir) + : framer_(framer), frame_ir_(std::move(frame_ir)) {} + +SpdyFramer::SpdyControlFrameIterator::~SpdyControlFrameIterator() = default; + +size_t SpdyFramer::SpdyControlFrameIterator::NextFrame( + ZeroCopyOutputBuffer* output) { + size_t size_written = framer_->SerializeFrame(*frame_ir_, output); + has_next_frame_ = false; + return size_written; +} + +bool SpdyFramer::SpdyControlFrameIterator::HasNextFrame() const { + return has_next_frame_; +} + +const SpdyFrameIR& SpdyFramer::SpdyControlFrameIterator::GetIR() const { + return *frame_ir_; +} + +std::unique_ptr SpdyFramer::CreateIterator( + SpdyFramer* framer, std::unique_ptr frame_ir) { + switch (frame_ir->frame_type()) { + case SpdyFrameType::HEADERS: { + return std::make_unique( + framer, absl::WrapUnique( + static_cast(frame_ir.release()))); + } + case SpdyFrameType::PUSH_PROMISE: { + return std::make_unique( + framer, absl::WrapUnique(static_cast( + frame_ir.release()))); + } + case SpdyFrameType::DATA: { + QUICHE_DVLOG(1) << "Serialize a stream end DATA frame for VTL"; + ABSL_FALLTHROUGH_INTENDED; + } + default: { + return std::make_unique(framer, + std::move(frame_ir)); + } + } +} + +SpdySerializedFrame SpdyFramer::SerializeData(const SpdyDataIR& data_ir) { + uint8_t flags = DATA_FLAG_NONE; + int num_padding_fields = 0; + size_t size_with_padding = 0; + SerializeDataBuilderHelper(data_ir, &flags, &num_padding_fields, + &size_with_padding); + + SpdyFrameBuilder builder(size_with_padding); + builder.BeginNewFrame(SpdyFrameType::DATA, flags, data_ir.stream_id()); + if (data_ir.padded()) { + builder.WriteUInt8(data_ir.padding_payload_len() & 0xff); + } + builder.WriteBytes(data_ir.data(), data_ir.data_len()); + if (data_ir.padding_payload_len() > 0) { + std::string padding(data_ir.padding_payload_len(), 0); + builder.WriteBytes(padding.data(), padding.length()); + } + QUICHE_DCHECK_EQ(size_with_padding, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeDataFrameHeaderWithPaddingLengthField( + const SpdyDataIR& data_ir) { + uint8_t flags = DATA_FLAG_NONE; + size_t frame_size = 0; + size_t num_padding_fields = 0; + SerializeDataFrameHeaderWithPaddingLengthFieldBuilderHelper( + data_ir, &flags, &frame_size, &num_padding_fields); + + SpdyFrameBuilder builder(frame_size); + builder.BeginNewFrame( + SpdyFrameType::DATA, flags, data_ir.stream_id(), + num_padding_fields + data_ir.data_len() + data_ir.padding_payload_len()); + if (data_ir.padded()) { + builder.WriteUInt8(data_ir.padding_payload_len() & 0xff); + } + QUICHE_DCHECK_EQ(frame_size, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeRstStream( + const SpdyRstStreamIR& rst_stream) const { + size_t expected_length = kRstStreamFrameSize; + SpdyFrameBuilder builder(expected_length); + + builder.BeginNewFrame(SpdyFrameType::RST_STREAM, 0, rst_stream.stream_id()); + + builder.WriteUInt32(rst_stream.error_code()); + + QUICHE_DCHECK_EQ(expected_length, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeSettings( + const SpdySettingsIR& settings) const { + uint8_t flags = 0; + // Size, in bytes, of this SETTINGS frame. + size_t size = 0; + const SettingsMap* values = &(settings.values()); + SerializeSettingsBuilderHelper(settings, &flags, values, &size); + SpdyFrameBuilder builder(size); + builder.BeginNewFrame(SpdyFrameType::SETTINGS, flags, 0); + + // If this is an ACK, payload should be empty. + if (settings.is_ack()) { + return builder.take(); + } + + QUICHE_DCHECK_EQ(kSettingsFrameMinimumSize, builder.length()); + for (auto it = values->begin(); it != values->end(); ++it) { + int setting_id = it->first; + QUICHE_DCHECK_GE(setting_id, 0); + builder.WriteUInt16(static_cast(setting_id)); + builder.WriteUInt32(it->second); + } + QUICHE_DCHECK_EQ(size, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializePing(const SpdyPingIR& ping) const { + SpdyFrameBuilder builder(kPingFrameSize); + uint8_t flags = 0; + if (ping.is_ack()) { + flags |= PING_FLAG_ACK; + } + builder.BeginNewFrame(SpdyFrameType::PING, flags, 0); + builder.WriteUInt64(ping.id()); + QUICHE_DCHECK_EQ(kPingFrameSize, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeGoAway( + const SpdyGoAwayIR& goaway) const { + // Compute the output buffer size, take opaque data into account. + size_t expected_length = kGoawayFrameMinimumSize; + expected_length += goaway.description().size(); + SpdyFrameBuilder builder(expected_length); + + // Serialize the GOAWAY frame. + builder.BeginNewFrame(SpdyFrameType::GOAWAY, 0, 0); + + // GOAWAY frames specify the last good stream id. + builder.WriteUInt32(goaway.last_good_stream_id()); + + // GOAWAY frames also specify the error code. + builder.WriteUInt32(goaway.error_code()); + + // GOAWAY frames may also specify opaque data. + if (!goaway.description().empty()) { + builder.WriteBytes(goaway.description().data(), + goaway.description().size()); + } + + QUICHE_DCHECK_EQ(expected_length, builder.length()); + return builder.take(); +} + +void SpdyFramer::SerializeHeadersBuilderHelper(const SpdyHeadersIR& headers, + uint8_t* flags, size_t* size, + std::string* hpack_encoding, + int* weight, + size_t* length_field) { + if (headers.fin()) { + *flags = *flags | CONTROL_FLAG_FIN; + } + // This will get overwritten if we overflow into a CONTINUATION frame. + *flags = *flags | HEADERS_FLAG_END_HEADERS; + if (headers.has_priority()) { + *flags = *flags | HEADERS_FLAG_PRIORITY; + } + if (headers.padded()) { + *flags = *flags | HEADERS_FLAG_PADDED; + } + + *size = kHeadersFrameMinimumSize; + + if (headers.padded()) { + *size = *size + kPadLengthFieldSize; + *size = *size + headers.padding_payload_len(); + } + + if (headers.has_priority()) { + *weight = ClampHttp2Weight(headers.weight()); + *size = *size + 5; + } + + *hpack_encoding = + GetHpackEncoder()->EncodeHeaderBlock(headers.header_block()); + *size = *size + hpack_encoding->size(); + if (*size > kHttp2MaxControlFrameSendSize) { + *size = *size + GetNumberRequiredContinuationFrames(*size) * + kContinuationFrameMinimumSize; + *flags = *flags & ~HEADERS_FLAG_END_HEADERS; + } + // Compute frame length field. + if (headers.padded()) { + *length_field = *length_field + kPadLengthFieldSize; + } + if (headers.has_priority()) { + *length_field = *length_field + 4; // Dependency field. + *length_field = *length_field + 1; // Weight field. + } + *length_field = *length_field + headers.padding_payload_len(); + *length_field = *length_field + hpack_encoding->size(); + // If the HEADERS frame with payload would exceed the max frame size, then + // WritePayloadWithContinuation() will serialize CONTINUATION frames as + // necessary. + *length_field = + std::min(*length_field, kHttp2MaxControlFrameSendSize - kFrameHeaderSize); +} + +SpdySerializedFrame SpdyFramer::SerializeHeaders(const SpdyHeadersIR& headers) { + uint8_t flags = 0; + // The size of this frame, including padding (if there is any) and + // variable-length header block. + size_t size = 0; + std::string hpack_encoding; + int weight = 0; + size_t length_field = 0; + SerializeHeadersBuilderHelper(headers, &flags, &size, &hpack_encoding, + &weight, &length_field); + + SpdyFrameBuilder builder(size); + builder.BeginNewFrame(SpdyFrameType::HEADERS, flags, headers.stream_id(), + length_field); + + QUICHE_DCHECK_EQ(kHeadersFrameMinimumSize, builder.length()); + + int padding_payload_len = 0; + if (headers.padded()) { + builder.WriteUInt8(headers.padding_payload_len()); + padding_payload_len = headers.padding_payload_len(); + } + if (headers.has_priority()) { + builder.WriteUInt32(PackStreamDependencyValues(headers.exclusive(), + headers.parent_stream_id())); + // Per RFC 7540 section 6.3, serialized weight value is actual value - 1. + builder.WriteUInt8(weight - 1); + } + WritePayloadWithContinuation(&builder, hpack_encoding, headers.stream_id(), + SpdyFrameType::HEADERS, padding_payload_len); + + if (debug_visitor_) { + const size_t header_list_size = + GetUncompressedSerializedLength(headers.header_block()); + debug_visitor_->OnSendCompressedFrame(headers.stream_id(), + SpdyFrameType::HEADERS, + header_list_size, builder.length()); + } + + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeWindowUpdate( + const SpdyWindowUpdateIR& window_update) { + SpdyFrameBuilder builder(kWindowUpdateFrameSize); + builder.BeginNewFrame(SpdyFrameType::WINDOW_UPDATE, kNoFlags, + window_update.stream_id()); + builder.WriteUInt32(window_update.delta()); + QUICHE_DCHECK_EQ(kWindowUpdateFrameSize, builder.length()); + return builder.take(); +} + +void SpdyFramer::SerializePushPromiseBuilderHelper( + const SpdyPushPromiseIR& push_promise, uint8_t* flags, + std::string* hpack_encoding, size_t* size) { + *flags = 0; + // This will get overwritten if we overflow into a CONTINUATION frame. + *flags = *flags | PUSH_PROMISE_FLAG_END_PUSH_PROMISE; + // The size of this frame, including variable-length name-value block. + *size = kPushPromiseFrameMinimumSize; + + if (push_promise.padded()) { + *flags = *flags | PUSH_PROMISE_FLAG_PADDED; + *size = *size + kPadLengthFieldSize; + *size = *size + push_promise.padding_payload_len(); + } + + *hpack_encoding = + GetHpackEncoder()->EncodeHeaderBlock(push_promise.header_block()); + *size = *size + hpack_encoding->size(); + if (*size > kHttp2MaxControlFrameSendSize) { + *size = *size + GetNumberRequiredContinuationFrames(*size) * + kContinuationFrameMinimumSize; + *flags = *flags & ~PUSH_PROMISE_FLAG_END_PUSH_PROMISE; + } +} + +SpdySerializedFrame SpdyFramer::SerializePushPromise( + const SpdyPushPromiseIR& push_promise) { + uint8_t flags = 0; + size_t size = 0; + std::string hpack_encoding; + SerializePushPromiseBuilderHelper(push_promise, &flags, &hpack_encoding, + &size); + + SpdyFrameBuilder builder(size); + size_t length = + std::min(size, kHttp2MaxControlFrameSendSize) - kFrameHeaderSize; + builder.BeginNewFrame(SpdyFrameType::PUSH_PROMISE, flags, + push_promise.stream_id(), length); + int padding_payload_len = 0; + if (push_promise.padded()) { + builder.WriteUInt8(push_promise.padding_payload_len()); + builder.WriteUInt32(push_promise.promised_stream_id()); + QUICHE_DCHECK_EQ(kPushPromiseFrameMinimumSize + kPadLengthFieldSize, + builder.length()); + + padding_payload_len = push_promise.padding_payload_len(); + } else { + builder.WriteUInt32(push_promise.promised_stream_id()); + QUICHE_DCHECK_EQ(kPushPromiseFrameMinimumSize, builder.length()); + } + + WritePayloadWithContinuation( + &builder, hpack_encoding, push_promise.stream_id(), + SpdyFrameType::PUSH_PROMISE, padding_payload_len); + + if (debug_visitor_) { + const size_t header_list_size = + GetUncompressedSerializedLength(push_promise.header_block()); + debug_visitor_->OnSendCompressedFrame(push_promise.stream_id(), + SpdyFrameType::PUSH_PROMISE, + header_list_size, builder.length()); + } + + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeContinuation( + const SpdyContinuationIR& continuation) const { + const std::string& encoding = continuation.encoding(); + size_t frame_size = kContinuationFrameMinimumSize + encoding.size(); + SpdyFrameBuilder builder(frame_size); + uint8_t flags = continuation.end_headers() ? HEADERS_FLAG_END_HEADERS : 0; + builder.BeginNewFrame(SpdyFrameType::CONTINUATION, flags, + continuation.stream_id()); + QUICHE_DCHECK_EQ(kFrameHeaderSize, builder.length()); + + builder.WriteBytes(encoding.data(), encoding.size()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeAltSvc(const SpdyAltSvcIR& altsvc_ir) { + std::string value; + size_t size = 0; + SerializeAltSvcBuilderHelper(altsvc_ir, &value, &size); + SpdyFrameBuilder builder(size); + builder.BeginNewFrame(SpdyFrameType::ALTSVC, kNoFlags, altsvc_ir.stream_id()); + + builder.WriteUInt16(altsvc_ir.origin().length()); + builder.WriteBytes(altsvc_ir.origin().data(), altsvc_ir.origin().length()); + builder.WriteBytes(value.data(), value.length()); + QUICHE_DCHECK_LT(kGetAltSvcFrameMinimumSize, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializePriority( + const SpdyPriorityIR& priority) const { + SpdyFrameBuilder builder(kPriorityFrameSize); + builder.BeginNewFrame(SpdyFrameType::PRIORITY, kNoFlags, + priority.stream_id()); + + builder.WriteUInt32(PackStreamDependencyValues(priority.exclusive(), + priority.parent_stream_id())); + // Per RFC 7540 section 6.3, serialized weight value is actual value - 1. + builder.WriteUInt8(priority.weight() - 1); + QUICHE_DCHECK_EQ(kPriorityFrameSize, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializePriorityUpdate( + const SpdyPriorityUpdateIR& priority_update) const { + const size_t total_size = kPriorityUpdateFrameMinimumSize + + priority_update.priority_field_value().size(); + SpdyFrameBuilder builder(total_size); + builder.BeginNewFrame(SpdyFrameType::PRIORITY_UPDATE, kNoFlags, + priority_update.stream_id()); + + builder.WriteUInt32(priority_update.prioritized_stream_id()); + builder.WriteBytes(priority_update.priority_field_value().data(), + priority_update.priority_field_value().size()); + QUICHE_DCHECK_EQ(total_size, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeAcceptCh( + const SpdyAcceptChIR& accept_ch) const { + const size_t total_size = accept_ch.size(); + SpdyFrameBuilder builder(total_size); + builder.BeginNewFrame(SpdyFrameType::ACCEPT_CH, kNoFlags, + accept_ch.stream_id()); + + for (const AcceptChOriginValuePair& entry : accept_ch.entries()) { + builder.WriteUInt16(entry.origin.size()); + builder.WriteBytes(entry.origin.data(), entry.origin.size()); + builder.WriteUInt16(entry.value.size()); + builder.WriteBytes(entry.value.data(), entry.value.size()); + } + + QUICHE_DCHECK_EQ(total_size, builder.length()); + return builder.take(); +} + +SpdySerializedFrame SpdyFramer::SerializeUnknown( + const SpdyUnknownIR& unknown) const { + const size_t total_size = kFrameHeaderSize + unknown.payload().size(); + SpdyFrameBuilder builder(total_size); + builder.BeginNewUncheckedFrame(unknown.type(), unknown.flags(), + unknown.stream_id(), unknown.length()); + builder.WriteBytes(unknown.payload().data(), unknown.payload().size()); + return builder.take(); +} + +namespace { + +class FrameSerializationVisitor : public SpdyFrameVisitor { + public: + explicit FrameSerializationVisitor(SpdyFramer* framer) + : framer_(framer), frame_() {} + ~FrameSerializationVisitor() override = default; + + SpdySerializedFrame ReleaseSerializedFrame() { return std::move(frame_); } + + void VisitData(const SpdyDataIR& data) override { + frame_ = framer_->SerializeData(data); + } + void VisitRstStream(const SpdyRstStreamIR& rst_stream) override { + frame_ = framer_->SerializeRstStream(rst_stream); + } + void VisitSettings(const SpdySettingsIR& settings) override { + frame_ = framer_->SerializeSettings(settings); + } + void VisitPing(const SpdyPingIR& ping) override { + frame_ = framer_->SerializePing(ping); + } + void VisitGoAway(const SpdyGoAwayIR& goaway) override { + frame_ = framer_->SerializeGoAway(goaway); + } + void VisitHeaders(const SpdyHeadersIR& headers) override { + frame_ = framer_->SerializeHeaders(headers); + } + void VisitWindowUpdate(const SpdyWindowUpdateIR& window_update) override { + frame_ = framer_->SerializeWindowUpdate(window_update); + } + void VisitPushPromise(const SpdyPushPromiseIR& push_promise) override { + frame_ = framer_->SerializePushPromise(push_promise); + } + void VisitContinuation(const SpdyContinuationIR& continuation) override { + frame_ = framer_->SerializeContinuation(continuation); + } + void VisitAltSvc(const SpdyAltSvcIR& altsvc) override { + frame_ = framer_->SerializeAltSvc(altsvc); + } + void VisitPriority(const SpdyPriorityIR& priority) override { + frame_ = framer_->SerializePriority(priority); + } + void VisitPriorityUpdate( + const SpdyPriorityUpdateIR& priority_update) override { + frame_ = framer_->SerializePriorityUpdate(priority_update); + } + void VisitAcceptCh(const SpdyAcceptChIR& accept_ch) override { + frame_ = framer_->SerializeAcceptCh(accept_ch); + } + void VisitUnknown(const SpdyUnknownIR& unknown) override { + frame_ = framer_->SerializeUnknown(unknown); + } + + private: + SpdyFramer* framer_; + SpdySerializedFrame frame_; +}; + +// TODO(diannahu): Use also in frame serialization. +class FlagsSerializationVisitor : public SpdyFrameVisitor { + public: + void VisitData(const SpdyDataIR& data) override { + flags_ = DATA_FLAG_NONE; + if (data.fin()) { + flags_ |= DATA_FLAG_FIN; + } + if (data.padded()) { + flags_ |= DATA_FLAG_PADDED; + } + } + + void VisitRstStream(const SpdyRstStreamIR& /*rst_stream*/) override { + flags_ = kNoFlags; + } + + void VisitSettings(const SpdySettingsIR& settings) override { + flags_ = kNoFlags; + if (settings.is_ack()) { + flags_ |= SETTINGS_FLAG_ACK; + } + } + + void VisitPing(const SpdyPingIR& ping) override { + flags_ = kNoFlags; + if (ping.is_ack()) { + flags_ |= PING_FLAG_ACK; + } + } + + void VisitGoAway(const SpdyGoAwayIR& /*goaway*/) override { + flags_ = kNoFlags; + } + + // TODO(diannahu): The END_HEADERS flag is incorrect for HEADERS that require + // CONTINUATION frames. + void VisitHeaders(const SpdyHeadersIR& headers) override { + flags_ = HEADERS_FLAG_END_HEADERS; + if (headers.fin()) { + flags_ |= CONTROL_FLAG_FIN; + } + if (headers.padded()) { + flags_ |= HEADERS_FLAG_PADDED; + } + if (headers.has_priority()) { + flags_ |= HEADERS_FLAG_PRIORITY; + } + } + + void VisitWindowUpdate(const SpdyWindowUpdateIR& /*window_update*/) override { + flags_ = kNoFlags; + } + + // TODO(diannahu): The END_PUSH_PROMISE flag is incorrect for PUSH_PROMISEs + // that require CONTINUATION frames. + void VisitPushPromise(const SpdyPushPromiseIR& push_promise) override { + flags_ = PUSH_PROMISE_FLAG_END_PUSH_PROMISE; + if (push_promise.padded()) { + flags_ |= PUSH_PROMISE_FLAG_PADDED; + } + } + + // TODO(diannahu): The END_HEADERS flag is incorrect for CONTINUATIONs that + // require CONTINUATION frames. + void VisitContinuation(const SpdyContinuationIR& /*continuation*/) override { + flags_ = HEADERS_FLAG_END_HEADERS; + } + + void VisitAltSvc(const SpdyAltSvcIR& /*altsvc*/) override { + flags_ = kNoFlags; + } + + void VisitPriority(const SpdyPriorityIR& /*priority*/) override { + flags_ = kNoFlags; + } + + void VisitPriorityUpdate( + const SpdyPriorityUpdateIR& /*priority_update*/) override { + flags_ = kNoFlags; + } + + void VisitAcceptCh(const SpdyAcceptChIR& /*accept_ch*/) override { + flags_ = kNoFlags; + } + + uint8_t flags() const { return flags_; } + + private: + uint8_t flags_ = kNoFlags; +}; + +} // namespace + +SpdySerializedFrame SpdyFramer::SerializeFrame(const SpdyFrameIR& frame) { + FrameSerializationVisitor visitor(this); + frame.Visit(&visitor); + return visitor.ReleaseSerializedFrame(); +} + +uint8_t SpdyFramer::GetSerializedFlags(const SpdyFrameIR& frame) { + FlagsSerializationVisitor visitor; + frame.Visit(&visitor); + return visitor.flags(); +} + +bool SpdyFramer::SerializeData(const SpdyDataIR& data_ir, + ZeroCopyOutputBuffer* output) const { + uint8_t flags = DATA_FLAG_NONE; + int num_padding_fields = 0; + size_t size_with_padding = 0; + SerializeDataBuilderHelper(data_ir, &flags, &num_padding_fields, + &size_with_padding); + SpdyFrameBuilder builder(size_with_padding, output); + + bool ok = + builder.BeginNewFrame(SpdyFrameType::DATA, flags, data_ir.stream_id()); + + if (data_ir.padded()) { + ok = ok && builder.WriteUInt8(data_ir.padding_payload_len() & 0xff); + } + + ok = ok && builder.WriteBytes(data_ir.data(), data_ir.data_len()); + if (data_ir.padding_payload_len() > 0) { + std::string padding; + padding = std::string(data_ir.padding_payload_len(), 0); + ok = ok && builder.WriteBytes(padding.data(), padding.length()); + } + QUICHE_DCHECK_EQ(size_with_padding, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeDataFrameHeaderWithPaddingLengthField( + const SpdyDataIR& data_ir, ZeroCopyOutputBuffer* output) const { + uint8_t flags = DATA_FLAG_NONE; + size_t frame_size = 0; + size_t num_padding_fields = 0; + SerializeDataFrameHeaderWithPaddingLengthFieldBuilderHelper( + data_ir, &flags, &frame_size, &num_padding_fields); + + SpdyFrameBuilder builder(frame_size, output); + bool ok = true; + ok = ok && + builder.BeginNewFrame(SpdyFrameType::DATA, flags, data_ir.stream_id(), + num_padding_fields + data_ir.data_len() + + data_ir.padding_payload_len()); + if (data_ir.padded()) { + ok = ok && builder.WriteUInt8(data_ir.padding_payload_len() & 0xff); + } + QUICHE_DCHECK_EQ(frame_size, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeRstStream(const SpdyRstStreamIR& rst_stream, + ZeroCopyOutputBuffer* output) const { + size_t expected_length = kRstStreamFrameSize; + SpdyFrameBuilder builder(expected_length, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::RST_STREAM, 0, + rst_stream.stream_id()); + ok = ok && builder.WriteUInt32(rst_stream.error_code()); + + QUICHE_DCHECK_EQ(expected_length, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeSettings(const SpdySettingsIR& settings, + ZeroCopyOutputBuffer* output) const { + uint8_t flags = 0; + // Size, in bytes, of this SETTINGS frame. + size_t size = 0; + const SettingsMap* values = &(settings.values()); + SerializeSettingsBuilderHelper(settings, &flags, values, &size); + SpdyFrameBuilder builder(size, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::SETTINGS, flags, 0); + + // If this is an ACK, payload should be empty. + if (settings.is_ack()) { + return ok; + } + + QUICHE_DCHECK_EQ(kSettingsFrameMinimumSize, builder.length()); + for (auto it = values->begin(); it != values->end(); ++it) { + int setting_id = it->first; + QUICHE_DCHECK_GE(setting_id, 0); + ok = ok && builder.WriteUInt16(static_cast(setting_id)) && + builder.WriteUInt32(it->second); + } + QUICHE_DCHECK_EQ(size, builder.length()); + return ok; +} + +bool SpdyFramer::SerializePing(const SpdyPingIR& ping, + ZeroCopyOutputBuffer* output) const { + SpdyFrameBuilder builder(kPingFrameSize, output); + uint8_t flags = 0; + if (ping.is_ack()) { + flags |= PING_FLAG_ACK; + } + bool ok = builder.BeginNewFrame(SpdyFrameType::PING, flags, 0); + ok = ok && builder.WriteUInt64(ping.id()); + QUICHE_DCHECK_EQ(kPingFrameSize, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeGoAway(const SpdyGoAwayIR& goaway, + ZeroCopyOutputBuffer* output) const { + // Compute the output buffer size, take opaque data into account. + size_t expected_length = kGoawayFrameMinimumSize; + expected_length += goaway.description().size(); + SpdyFrameBuilder builder(expected_length, output); + + // Serialize the GOAWAY frame. + bool ok = builder.BeginNewFrame(SpdyFrameType::GOAWAY, 0, 0); + + // GOAWAY frames specify the last good stream id. + ok = ok && builder.WriteUInt32(goaway.last_good_stream_id()) && + // GOAWAY frames also specify the error status code. + builder.WriteUInt32(goaway.error_code()); + + // GOAWAY frames may also specify opaque data. + if (!goaway.description().empty()) { + ok = ok && builder.WriteBytes(goaway.description().data(), + goaway.description().size()); + } + + QUICHE_DCHECK_EQ(expected_length, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeHeaders(const SpdyHeadersIR& headers, + ZeroCopyOutputBuffer* output) { + uint8_t flags = 0; + // The size of this frame, including padding (if there is any) and + // variable-length header block. + size_t size = 0; + std::string hpack_encoding; + int weight = 0; + size_t length_field = 0; + SerializeHeadersBuilderHelper(headers, &flags, &size, &hpack_encoding, + &weight, &length_field); + + bool ok = true; + SpdyFrameBuilder builder(size, output); + ok = ok && builder.BeginNewFrame(SpdyFrameType::HEADERS, flags, + headers.stream_id(), length_field); + QUICHE_DCHECK_EQ(kHeadersFrameMinimumSize, builder.length()); + + int padding_payload_len = 0; + if (headers.padded()) { + ok = ok && builder.WriteUInt8(headers.padding_payload_len()); + padding_payload_len = headers.padding_payload_len(); + } + if (headers.has_priority()) { + ok = ok && + builder.WriteUInt32(PackStreamDependencyValues( + headers.exclusive(), headers.parent_stream_id())) && + // Per RFC 7540 section 6.3, serialized weight value is weight - 1. + builder.WriteUInt8(weight - 1); + } + ok = ok && WritePayloadWithContinuation( + &builder, hpack_encoding, headers.stream_id(), + SpdyFrameType::HEADERS, padding_payload_len); + + if (debug_visitor_) { + const size_t header_list_size = + GetUncompressedSerializedLength(headers.header_block()); + debug_visitor_->OnSendCompressedFrame(headers.stream_id(), + SpdyFrameType::HEADERS, + header_list_size, builder.length()); + } + + return ok; +} + +bool SpdyFramer::SerializeWindowUpdate(const SpdyWindowUpdateIR& window_update, + ZeroCopyOutputBuffer* output) const { + SpdyFrameBuilder builder(kWindowUpdateFrameSize, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::WINDOW_UPDATE, kNoFlags, + window_update.stream_id()); + ok = ok && builder.WriteUInt32(window_update.delta()); + QUICHE_DCHECK_EQ(kWindowUpdateFrameSize, builder.length()); + return ok; +} + +bool SpdyFramer::SerializePushPromise(const SpdyPushPromiseIR& push_promise, + ZeroCopyOutputBuffer* output) { + uint8_t flags = 0; + size_t size = 0; + std::string hpack_encoding; + SerializePushPromiseBuilderHelper(push_promise, &flags, &hpack_encoding, + &size); + + bool ok = true; + SpdyFrameBuilder builder(size, output); + size_t length = + std::min(size, kHttp2MaxControlFrameSendSize) - kFrameHeaderSize; + ok = builder.BeginNewFrame(SpdyFrameType::PUSH_PROMISE, flags, + push_promise.stream_id(), length); + + int padding_payload_len = 0; + if (push_promise.padded()) { + ok = ok && builder.WriteUInt8(push_promise.padding_payload_len()) && + builder.WriteUInt32(push_promise.promised_stream_id()); + QUICHE_DCHECK_EQ(kPushPromiseFrameMinimumSize + kPadLengthFieldSize, + builder.length()); + + padding_payload_len = push_promise.padding_payload_len(); + } else { + ok = ok && builder.WriteUInt32(push_promise.promised_stream_id()); + QUICHE_DCHECK_EQ(kPushPromiseFrameMinimumSize, builder.length()); + } + + ok = ok && WritePayloadWithContinuation( + &builder, hpack_encoding, push_promise.stream_id(), + SpdyFrameType::PUSH_PROMISE, padding_payload_len); + + if (debug_visitor_) { + const size_t header_list_size = + GetUncompressedSerializedLength(push_promise.header_block()); + debug_visitor_->OnSendCompressedFrame(push_promise.stream_id(), + SpdyFrameType::PUSH_PROMISE, + header_list_size, builder.length()); + } + + return ok; +} + +bool SpdyFramer::SerializeContinuation(const SpdyContinuationIR& continuation, + ZeroCopyOutputBuffer* output) const { + const std::string& encoding = continuation.encoding(); + size_t frame_size = kContinuationFrameMinimumSize + encoding.size(); + SpdyFrameBuilder builder(frame_size, output); + uint8_t flags = continuation.end_headers() ? HEADERS_FLAG_END_HEADERS : 0; + bool ok = builder.BeginNewFrame(SpdyFrameType::CONTINUATION, flags, + continuation.stream_id(), + frame_size - kFrameHeaderSize); + QUICHE_DCHECK_EQ(kFrameHeaderSize, builder.length()); + + ok = ok && builder.WriteBytes(encoding.data(), encoding.size()); + return ok; +} + +bool SpdyFramer::SerializeAltSvc(const SpdyAltSvcIR& altsvc_ir, + ZeroCopyOutputBuffer* output) { + std::string value; + size_t size = 0; + SerializeAltSvcBuilderHelper(altsvc_ir, &value, &size); + SpdyFrameBuilder builder(size, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::ALTSVC, kNoFlags, + altsvc_ir.stream_id()) && + builder.WriteUInt16(altsvc_ir.origin().length()) && + builder.WriteBytes(altsvc_ir.origin().data(), + altsvc_ir.origin().length()) && + builder.WriteBytes(value.data(), value.length()); + QUICHE_DCHECK_LT(kGetAltSvcFrameMinimumSize, builder.length()); + return ok; +} + +bool SpdyFramer::SerializePriority(const SpdyPriorityIR& priority, + ZeroCopyOutputBuffer* output) const { + SpdyFrameBuilder builder(kPriorityFrameSize, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::PRIORITY, kNoFlags, + priority.stream_id()); + ok = ok && + builder.WriteUInt32(PackStreamDependencyValues( + priority.exclusive(), priority.parent_stream_id())) && + // Per RFC 7540 section 6.3, serialized weight value is actual value - 1. + builder.WriteUInt8(priority.weight() - 1); + QUICHE_DCHECK_EQ(kPriorityFrameSize, builder.length()); + return ok; +} + +bool SpdyFramer::SerializePriorityUpdate( + const SpdyPriorityUpdateIR& priority_update, + ZeroCopyOutputBuffer* output) const { + const size_t total_size = kPriorityUpdateFrameMinimumSize + + priority_update.priority_field_value().size(); + SpdyFrameBuilder builder(total_size, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::PRIORITY_UPDATE, kNoFlags, + priority_update.stream_id()); + + ok = ok && builder.WriteUInt32(priority_update.prioritized_stream_id()); + ok = ok && builder.WriteBytes(priority_update.priority_field_value().data(), + priority_update.priority_field_value().size()); + QUICHE_DCHECK_EQ(total_size, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeAcceptCh(const SpdyAcceptChIR& accept_ch, + ZeroCopyOutputBuffer* output) const { + const size_t total_size = accept_ch.size(); + SpdyFrameBuilder builder(total_size, output); + bool ok = builder.BeginNewFrame(SpdyFrameType::ACCEPT_CH, kNoFlags, + accept_ch.stream_id()); + + for (const AcceptChOriginValuePair& entry : accept_ch.entries()) { + ok = ok && builder.WriteUInt16(entry.origin.size()); + ok = ok && builder.WriteBytes(entry.origin.data(), entry.origin.size()); + ok = ok && builder.WriteUInt16(entry.value.size()); + ok = ok && builder.WriteBytes(entry.value.data(), entry.value.size()); + } + + QUICHE_DCHECK_EQ(total_size, builder.length()); + return ok; +} + +bool SpdyFramer::SerializeUnknown(const SpdyUnknownIR& unknown, + ZeroCopyOutputBuffer* output) const { + const size_t total_size = kFrameHeaderSize + unknown.payload().size(); + SpdyFrameBuilder builder(total_size, output); + bool ok = builder.BeginNewUncheckedFrame( + unknown.type(), unknown.flags(), unknown.stream_id(), unknown.length()); + ok = ok && + builder.WriteBytes(unknown.payload().data(), unknown.payload().size()); + return ok; +} + +namespace { + +class FrameSerializationVisitorWithOutput : public SpdyFrameVisitor { + public: + explicit FrameSerializationVisitorWithOutput(SpdyFramer* framer, + ZeroCopyOutputBuffer* output) + : framer_(framer), output_(output), result_(false) {} + ~FrameSerializationVisitorWithOutput() override = default; + + size_t Result() { return result_; } + + void VisitData(const SpdyDataIR& data) override { + result_ = framer_->SerializeData(data, output_); + } + void VisitRstStream(const SpdyRstStreamIR& rst_stream) override { + result_ = framer_->SerializeRstStream(rst_stream, output_); + } + void VisitSettings(const SpdySettingsIR& settings) override { + result_ = framer_->SerializeSettings(settings, output_); + } + void VisitPing(const SpdyPingIR& ping) override { + result_ = framer_->SerializePing(ping, output_); + } + void VisitGoAway(const SpdyGoAwayIR& goaway) override { + result_ = framer_->SerializeGoAway(goaway, output_); + } + void VisitHeaders(const SpdyHeadersIR& headers) override { + result_ = framer_->SerializeHeaders(headers, output_); + } + void VisitWindowUpdate(const SpdyWindowUpdateIR& window_update) override { + result_ = framer_->SerializeWindowUpdate(window_update, output_); + } + void VisitPushPromise(const SpdyPushPromiseIR& push_promise) override { + result_ = framer_->SerializePushPromise(push_promise, output_); + } + void VisitContinuation(const SpdyContinuationIR& continuation) override { + result_ = framer_->SerializeContinuation(continuation, output_); + } + void VisitAltSvc(const SpdyAltSvcIR& altsvc) override { + result_ = framer_->SerializeAltSvc(altsvc, output_); + } + void VisitPriority(const SpdyPriorityIR& priority) override { + result_ = framer_->SerializePriority(priority, output_); + } + void VisitPriorityUpdate( + const SpdyPriorityUpdateIR& priority_update) override { + result_ = framer_->SerializePriorityUpdate(priority_update, output_); + } + void VisitAcceptCh(const SpdyAcceptChIR& accept_ch) override { + result_ = framer_->SerializeAcceptCh(accept_ch, output_); + } + + void VisitUnknown(const SpdyUnknownIR& unknown) override { + result_ = framer_->SerializeUnknown(unknown, output_); + } + + private: + SpdyFramer* framer_; + ZeroCopyOutputBuffer* output_; + bool result_; +}; + +} // namespace + +size_t SpdyFramer::SerializeFrame(const SpdyFrameIR& frame, + ZeroCopyOutputBuffer* output) { + FrameSerializationVisitorWithOutput visitor(this, output); + size_t free_bytes_before = output->BytesFree(); + frame.Visit(&visitor); + return visitor.Result() ? free_bytes_before - output->BytesFree() : 0; +} + +HpackEncoder* SpdyFramer::GetHpackEncoder() { + if (hpack_encoder_ == nullptr) { + hpack_encoder_ = std::make_unique(); + if (!compression_enabled()) { + hpack_encoder_->DisableCompression(); + } + } + return hpack_encoder_.get(); +} + +void SpdyFramer::UpdateHeaderEncoderTableSize(uint32_t value) { + GetHpackEncoder()->ApplyHeaderTableSizeSetting(value); +} + +size_t SpdyFramer::header_encoder_table_size() const { + if (hpack_encoder_ == nullptr) { + return kDefaultHeaderTableSizeSetting; + } else { + return hpack_encoder_->CurrentHeaderTableSizeSetting(); + } +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_framer.h b/quiche/spdy/core/spdy_framer.h new file mode 100644 index 000000000000..ea886ca61b05 --- /dev/null +++ b/quiche/spdy/core/spdy_framer.h @@ -0,0 +1,376 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_FRAMER_H_ +#define QUICHE_SPDY_CORE_SPDY_FRAMER_H_ + +#include + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/hpack/hpack_encoder.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" +#include "quiche/spdy/core/spdy_headers_handler_interface.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/core/zero_copy_output_buffer.h" + +namespace spdy { + +namespace test { + +class SpdyFramerPeer; +class SpdyFramerTest_MultipleContinuationFramesWithIterator_Test; +class SpdyFramerTest_PushPromiseFramesWithIterator_Test; + +} // namespace test + +class QUICHE_EXPORT SpdyFrameSequence { + public: + virtual ~SpdyFrameSequence() {} + + // Serializes the next frame in the sequence to |output|. Returns the number + // of bytes written to |output|. + virtual size_t NextFrame(ZeroCopyOutputBuffer* output) = 0; + + // Returns true iff there is at least one more frame in the sequence. + virtual bool HasNextFrame() const = 0; + + // Get SpdyFrameIR of the frame to be serialized. + virtual const SpdyFrameIR& GetIR() const = 0; +}; + +class QUICHE_EXPORT SpdyFramer { + public: + enum CompressionOption { + ENABLE_COMPRESSION, + DISABLE_COMPRESSION, + }; + + // Create a SpdyFrameSequence to serialize |frame_ir|. + static std::unique_ptr CreateIterator( + SpdyFramer* framer, std::unique_ptr frame_ir); + + // Gets the serialized flags for the given |frame|. + static uint8_t GetSerializedFlags(const SpdyFrameIR& frame); + + // Serialize a data frame. + static SpdySerializedFrame SerializeData(const SpdyDataIR& data_ir); + // Serializes the data frame header and optionally padding length fields, + // excluding actual data payload and padding. + static SpdySerializedFrame SerializeDataFrameHeaderWithPaddingLengthField( + const SpdyDataIR& data_ir); + + // Serializes a WINDOW_UPDATE frame. The WINDOW_UPDATE + // frame is used to implement per stream flow control. + static SpdySerializedFrame SerializeWindowUpdate( + const SpdyWindowUpdateIR& window_update); + + explicit SpdyFramer(CompressionOption option); + + virtual ~SpdyFramer(); + + // Set debug callbacks to be called from the framer. The debug visitor is + // completely optional and need not be set in order for normal operation. + // If this is called multiple times, only the last visitor will be used. + void set_debug_visitor(SpdyFramerDebugVisitorInterface* debug_visitor); + + SpdySerializedFrame SerializeRstStream( + const SpdyRstStreamIR& rst_stream) const; + + // Serializes a SETTINGS frame. The SETTINGS frame is + // used to communicate name/value pairs relevant to the communication channel. + SpdySerializedFrame SerializeSettings(const SpdySettingsIR& settings) const; + + // Serializes a PING frame. The unique_id is used to + // identify the ping request/response. + SpdySerializedFrame SerializePing(const SpdyPingIR& ping) const; + + // Serializes a GOAWAY frame. The GOAWAY frame is used + // prior to the shutting down of the TCP connection, and includes the + // stream_id of the last stream the sender of the frame is willing to process + // to completion. + SpdySerializedFrame SerializeGoAway(const SpdyGoAwayIR& goaway) const; + + // Serializes a HEADERS frame. The HEADERS frame is used + // for sending headers. + SpdySerializedFrame SerializeHeaders(const SpdyHeadersIR& headers); + + // Serializes a PUSH_PROMISE frame. The PUSH_PROMISE frame is used + // to inform the client that it will be receiving an additional stream + // in response to the original request. The frame includes synthesized + // headers to explain the upcoming data. + SpdySerializedFrame SerializePushPromise( + const SpdyPushPromiseIR& push_promise); + + // Serializes a CONTINUATION frame. The CONTINUATION frame is used + // to continue a sequence of header block fragments. + SpdySerializedFrame SerializeContinuation( + const SpdyContinuationIR& continuation) const; + + // Serializes an ALTSVC frame. The ALTSVC frame advertises the + // availability of an alternative service to the client. + SpdySerializedFrame SerializeAltSvc(const SpdyAltSvcIR& altsvc); + + // Serializes a PRIORITY frame. The PRIORITY frame advises a change in + // the relative priority of the given stream. + SpdySerializedFrame SerializePriority(const SpdyPriorityIR& priority) const; + + // Serializes a PRIORITY_UPDATE frame. + // See https://httpwg.org/http-extensions/draft-ietf-httpbis-priority.html. + SpdySerializedFrame SerializePriorityUpdate( + const SpdyPriorityUpdateIR& priority_update) const; + + // Serializes an ACCEPT_CH frame. See + // https://tools.ietf.org/html/draft-davidben-http-client-hint-reliability-02. + SpdySerializedFrame SerializeAcceptCh(const SpdyAcceptChIR& accept_ch) const; + + // Serializes an unknown frame given a frame header and payload. + SpdySerializedFrame SerializeUnknown(const SpdyUnknownIR& unknown) const; + + // Serialize a frame of unknown type. + SpdySerializedFrame SerializeFrame(const SpdyFrameIR& frame); + + // Serialize a data frame. + bool SerializeData(const SpdyDataIR& data, + ZeroCopyOutputBuffer* output) const; + + // Serializes the data frame header and optionally padding length fields, + // excluding actual data payload and padding. + bool SerializeDataFrameHeaderWithPaddingLengthField( + const SpdyDataIR& data, ZeroCopyOutputBuffer* output) const; + + bool SerializeRstStream(const SpdyRstStreamIR& rst_stream, + ZeroCopyOutputBuffer* output) const; + + // Serializes a SETTINGS frame. The SETTINGS frame is + // used to communicate name/value pairs relevant to the communication channel. + bool SerializeSettings(const SpdySettingsIR& settings, + ZeroCopyOutputBuffer* output) const; + + // Serializes a PING frame. The unique_id is used to + // identify the ping request/response. + bool SerializePing(const SpdyPingIR& ping, + ZeroCopyOutputBuffer* output) const; + + // Serializes a GOAWAY frame. The GOAWAY frame is used + // prior to the shutting down of the TCP connection, and includes the + // stream_id of the last stream the sender of the frame is willing to process + // to completion. + bool SerializeGoAway(const SpdyGoAwayIR& goaway, + ZeroCopyOutputBuffer* output) const; + + // Serializes a HEADERS frame. The HEADERS frame is used + // for sending headers. + bool SerializeHeaders(const SpdyHeadersIR& headers, + ZeroCopyOutputBuffer* output); + + // Serializes a WINDOW_UPDATE frame. The WINDOW_UPDATE + // frame is used to implement per stream flow control. + bool SerializeWindowUpdate(const SpdyWindowUpdateIR& window_update, + ZeroCopyOutputBuffer* output) const; + + // Serializes a PUSH_PROMISE frame. The PUSH_PROMISE frame is used + // to inform the client that it will be receiving an additional stream + // in response to the original request. The frame includes synthesized + // headers to explain the upcoming data. + bool SerializePushPromise(const SpdyPushPromiseIR& push_promise, + ZeroCopyOutputBuffer* output); + + // Serializes a CONTINUATION frame. The CONTINUATION frame is used + // to continue a sequence of header block fragments. + bool SerializeContinuation(const SpdyContinuationIR& continuation, + ZeroCopyOutputBuffer* output) const; + + // Serializes an ALTSVC frame. The ALTSVC frame advertises the + // availability of an alternative service to the client. + bool SerializeAltSvc(const SpdyAltSvcIR& altsvc, + ZeroCopyOutputBuffer* output); + + // Serializes a PRIORITY frame. The PRIORITY frame advises a change in + // the relative priority of the given stream. + bool SerializePriority(const SpdyPriorityIR& priority, + ZeroCopyOutputBuffer* output) const; + + // Serializes a PRIORITY_UPDATE frame. + // See https://httpwg.org/http-extensions/draft-ietf-httpbis-priority.html. + bool SerializePriorityUpdate(const SpdyPriorityUpdateIR& priority_update, + ZeroCopyOutputBuffer* output) const; + + // Serializes an ACCEPT_CH frame. See + // https://tools.ietf.org/html/draft-davidben-http-client-hint-reliability-02. + bool SerializeAcceptCh(const SpdyAcceptChIR& accept_ch, + ZeroCopyOutputBuffer* output) const; + + // Serializes an unknown frame given a frame header and payload. + bool SerializeUnknown(const SpdyUnknownIR& unknown, + ZeroCopyOutputBuffer* output) const; + + // Serialize a frame of unknown type. + size_t SerializeFrame(const SpdyFrameIR& frame, ZeroCopyOutputBuffer* output); + + // Returns whether this SpdyFramer will compress header blocks using HPACK. + bool compression_enabled() const { + return compression_option_ == ENABLE_COMPRESSION; + } + + void SetHpackIndexingPolicy(HpackEncoder::IndexingPolicy policy) { + GetHpackEncoder()->SetIndexingPolicy(std::move(policy)); + } + + // Updates the maximum size of the header encoder compression table. + void UpdateHeaderEncoderTableSize(uint32_t value); + + // Returns the maximum size of the header encoder compression table. + size_t header_encoder_table_size() const; + + // Get (and lazily initialize) the HPACK encoder state. + HpackEncoder* GetHpackEncoder(); + + // Gets the HPACK encoder state. Returns nullptr if the encoder has not been + // initialized. + const HpackEncoder* GetHpackEncoder() const { return hpack_encoder_.get(); } + + protected: + friend class test::SpdyFramerPeer; + friend class test::SpdyFramerTest_MultipleContinuationFramesWithIterator_Test; + friend class test::SpdyFramerTest_PushPromiseFramesWithIterator_Test; + + // Iteratively converts a SpdyFrameIR into an appropriate sequence of Spdy + // frames. + // Example usage: + // std::unique_ptr it = CreateIterator(framer, frame_ir); + // while (it->HasNextFrame()) { + // if(it->NextFrame(output) == 0) { + // // Write failed; + // } + // } + class QUICHE_EXPORT SpdyFrameIterator : public SpdyFrameSequence { + public: + // Creates an iterator with the provided framer. + // Does not take ownership of |framer|. + // |framer| must outlive this instance. + explicit SpdyFrameIterator(SpdyFramer* framer); + ~SpdyFrameIterator() override; + + // Serializes the next frame in the sequence to |output|. Returns the number + // of bytes written to |output|. + size_t NextFrame(ZeroCopyOutputBuffer* output) override; + + // Returns true iff there is at least one more frame in the sequence. + bool HasNextFrame() const override; + + // SpdyFrameIterator is neither copyable nor movable. + SpdyFrameIterator(const SpdyFrameIterator&) = delete; + SpdyFrameIterator& operator=(const SpdyFrameIterator&) = delete; + + protected: + virtual size_t GetFrameSizeSansBlock() const = 0; + virtual bool SerializeGivenEncoding(const std::string& encoding, + ZeroCopyOutputBuffer* output) const = 0; + + SpdyFramer* GetFramer() const { return framer_; } + + void SetEncoder(const SpdyFrameWithHeaderBlockIR* ir) { + encoder_ = + framer_->GetHpackEncoder()->EncodeHeaderSet(ir->header_block()); + } + + bool has_next_frame() const { return has_next_frame_; } + + private: + SpdyFramer* const framer_; + std::unique_ptr encoder_; + bool is_first_frame_; + bool has_next_frame_; + }; + + // Iteratively converts a SpdyHeadersIR (with a possibly huge + // Http2HeaderBlock) into an appropriate sequence of SpdySerializedFrames, and + // write to the output. + class QUICHE_EXPORT SpdyHeaderFrameIterator : public SpdyFrameIterator { + public: + // Does not take ownership of |framer|. Take ownership of |headers_ir|. + SpdyHeaderFrameIterator(SpdyFramer* framer, + std::unique_ptr headers_ir); + + ~SpdyHeaderFrameIterator() override; + + private: + const SpdyFrameIR& GetIR() const override; + size_t GetFrameSizeSansBlock() const override; + bool SerializeGivenEncoding(const std::string& encoding, + ZeroCopyOutputBuffer* output) const override; + + const std::unique_ptr headers_ir_; + }; + + // Iteratively converts a SpdyPushPromiseIR (with a possibly huge + // Http2HeaderBlock) into an appropriate sequence of SpdySerializedFrames, and + // write to the output. + class QUICHE_EXPORT SpdyPushPromiseFrameIterator : public SpdyFrameIterator { + public: + // Does not take ownership of |framer|. Take ownership of |push_promise_ir|. + SpdyPushPromiseFrameIterator( + SpdyFramer* framer, + std::unique_ptr push_promise_ir); + + ~SpdyPushPromiseFrameIterator() override; + + private: + const SpdyFrameIR& GetIR() const override; + size_t GetFrameSizeSansBlock() const override; + bool SerializeGivenEncoding(const std::string& encoding, + ZeroCopyOutputBuffer* output) const override; + + const std::unique_ptr push_promise_ir_; + }; + + // Converts a SpdyFrameIR into one Spdy frame (a sequence of length 1), and + // write it to the output. + class QUICHE_EXPORT SpdyControlFrameIterator : public SpdyFrameSequence { + public: + SpdyControlFrameIterator(SpdyFramer* framer, + std::unique_ptr frame_ir); + ~SpdyControlFrameIterator() override; + + size_t NextFrame(ZeroCopyOutputBuffer* output) override; + + bool HasNextFrame() const override; + + const SpdyFrameIR& GetIR() const override; + + private: + SpdyFramer* const framer_; + std::unique_ptr frame_ir_; + bool has_next_frame_ = true; + }; + + private: + void SerializeHeadersBuilderHelper(const SpdyHeadersIR& headers, + uint8_t* flags, size_t* size, + std::string* hpack_encoding, int* weight, + size_t* length_field); + void SerializePushPromiseBuilderHelper(const SpdyPushPromiseIR& push_promise, + uint8_t* flags, + std::string* hpack_encoding, + size_t* size); + + std::unique_ptr hpack_encoder_; + + SpdyFramerDebugVisitorInterface* debug_visitor_; + + // Determines whether HPACK compression is used. + const CompressionOption compression_option_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_FRAMER_H_ diff --git a/quiche/spdy/core/spdy_framer_test.cc b/quiche/spdy/core/spdy_framer_test.cc new file mode 100644 index 000000000000..33decc350e0c --- /dev/null +++ b/quiche/spdy/core/spdy_framer_test.cc @@ -0,0 +1,5089 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_framer.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_text_utils.h" +#include "quiche/spdy/core/array_output_buffer.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/recording_headers_handler.h" +#include "quiche/spdy/core/spdy_bitmasks.h" +#include "quiche/spdy/core/spdy_frame_builder.h" +#include "quiche/spdy/core/spdy_protocol.h" +#include "quiche/spdy/test_tools/mock_spdy_framer_visitor.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +using ::http2::Http2DecoderAdapter; +using ::testing::_; + +namespace spdy { + +namespace test { + +namespace { + +const int64_t kSize = 1024 * 1024; +char output_buffer[kSize] = ""; + +// frame_list_char is used to hold frames to be compared with output_buffer. +const int64_t buffer_size = 64 * 1024; +char frame_list_char[buffer_size] = ""; +} // namespace + +class MockDebugVisitor : public SpdyFramerDebugVisitorInterface { + public: + MOCK_METHOD(void, OnSendCompressedFrame, + (SpdyStreamId stream_id, SpdyFrameType type, size_t payload_len, + size_t frame_len), + (override)); + + MOCK_METHOD(void, OnReceiveCompressedFrame, + (SpdyStreamId stream_id, SpdyFrameType type, size_t frame_len), + (override)); +}; + +MATCHER_P(IsFrameUnionOf, frame_list, "") { + size_t size_verified = 0; + for (const auto& frame : *frame_list) { + if (arg.size() < size_verified + frame.size()) { + QUICHE_LOG(FATAL) + << "Incremental header serialization should not lead to a " + << "higher total frame length than non-incremental method."; + return false; + } + if (memcmp(arg.data() + size_verified, frame.data(), frame.size())) { + CompareCharArraysWithHexError( + "Header serialization methods should be equivalent: ", + reinterpret_cast(arg.data() + size_verified), + frame.size(), reinterpret_cast(frame.data()), + frame.size()); + return false; + } + size_verified += frame.size(); + } + return size_verified == arg.size(); +} + +class SpdyFramerPeer { + public: + // TODO(dahollings): Remove these methods when deprecating non-incremental + // header serialization path. + static std::unique_ptr CloneSpdyHeadersIR( + const SpdyHeadersIR& headers) { + auto new_headers = std::make_unique( + headers.stream_id(), headers.header_block().Clone()); + new_headers->set_fin(headers.fin()); + new_headers->set_has_priority(headers.has_priority()); + new_headers->set_weight(headers.weight()); + new_headers->set_parent_stream_id(headers.parent_stream_id()); + new_headers->set_exclusive(headers.exclusive()); + if (headers.padded()) { + new_headers->set_padding_len(headers.padding_payload_len() + 1); + } + return new_headers; + } + + static SpdySerializedFrame SerializeHeaders(SpdyFramer* framer, + const SpdyHeadersIR& headers) { + SpdySerializedFrame serialized_headers_old_version( + framer->SerializeHeaders(headers)); + framer->hpack_encoder_.reset(nullptr); + auto* saved_debug_visitor = framer->debug_visitor_; + framer->debug_visitor_ = nullptr; + + std::vector frame_list; + ArrayOutputBuffer frame_list_buffer(frame_list_char, buffer_size); + SpdyFramer::SpdyHeaderFrameIterator it(framer, CloneSpdyHeadersIR(headers)); + while (it.HasNextFrame()) { + size_t size_before = frame_list_buffer.Size(); + EXPECT_GT(it.NextFrame(&frame_list_buffer), 0u); + frame_list.emplace_back( + SpdySerializedFrame(frame_list_buffer.Begin() + size_before, + frame_list_buffer.Size() - size_before, false)); + } + framer->debug_visitor_ = saved_debug_visitor; + + EXPECT_THAT(serialized_headers_old_version, IsFrameUnionOf(&frame_list)); + return serialized_headers_old_version; + } + + static SpdySerializedFrame SerializeHeaders(SpdyFramer* framer, + const SpdyHeadersIR& headers, + ArrayOutputBuffer* output) { + if (output == nullptr) { + return SerializeHeaders(framer, headers); + } + output->Reset(); + EXPECT_TRUE(framer->SerializeHeaders(headers, output)); + SpdySerializedFrame serialized_headers_old_version(output->Begin(), + output->Size(), false); + framer->hpack_encoder_.reset(nullptr); + auto* saved_debug_visitor = framer->debug_visitor_; + framer->debug_visitor_ = nullptr; + + std::vector frame_list; + ArrayOutputBuffer frame_list_buffer(frame_list_char, buffer_size); + SpdyFramer::SpdyHeaderFrameIterator it(framer, CloneSpdyHeadersIR(headers)); + while (it.HasNextFrame()) { + size_t size_before = frame_list_buffer.Size(); + EXPECT_GT(it.NextFrame(&frame_list_buffer), 0u); + frame_list.emplace_back( + SpdySerializedFrame(frame_list_buffer.Begin() + size_before, + frame_list_buffer.Size() - size_before, false)); + } + framer->debug_visitor_ = saved_debug_visitor; + + EXPECT_THAT(serialized_headers_old_version, IsFrameUnionOf(&frame_list)); + return serialized_headers_old_version; + } + + static std::unique_ptr CloneSpdyPushPromiseIR( + const SpdyPushPromiseIR& push_promise) { + auto new_push_promise = std::make_unique( + push_promise.stream_id(), push_promise.promised_stream_id(), + push_promise.header_block().Clone()); + new_push_promise->set_fin(push_promise.fin()); + if (push_promise.padded()) { + new_push_promise->set_padding_len(push_promise.padding_payload_len() + 1); + } + return new_push_promise; + } + + static SpdySerializedFrame SerializePushPromise( + SpdyFramer* framer, const SpdyPushPromiseIR& push_promise) { + SpdySerializedFrame serialized_headers_old_version = + framer->SerializePushPromise(push_promise); + framer->hpack_encoder_.reset(nullptr); + auto* saved_debug_visitor = framer->debug_visitor_; + framer->debug_visitor_ = nullptr; + + std::vector frame_list; + ArrayOutputBuffer frame_list_buffer(frame_list_char, buffer_size); + frame_list_buffer.Reset(); + SpdyFramer::SpdyPushPromiseFrameIterator it( + framer, CloneSpdyPushPromiseIR(push_promise)); + while (it.HasNextFrame()) { + size_t size_before = frame_list_buffer.Size(); + EXPECT_GT(it.NextFrame(&frame_list_buffer), 0u); + frame_list.emplace_back( + SpdySerializedFrame(frame_list_buffer.Begin() + size_before, + frame_list_buffer.Size() - size_before, false)); + } + framer->debug_visitor_ = saved_debug_visitor; + + EXPECT_THAT(serialized_headers_old_version, IsFrameUnionOf(&frame_list)); + return serialized_headers_old_version; + } + + static SpdySerializedFrame SerializePushPromise( + SpdyFramer* framer, const SpdyPushPromiseIR& push_promise, + ArrayOutputBuffer* output) { + if (output == nullptr) { + return SerializePushPromise(framer, push_promise); + } + output->Reset(); + EXPECT_TRUE(framer->SerializePushPromise(push_promise, output)); + SpdySerializedFrame serialized_headers_old_version(output->Begin(), + output->Size(), false); + framer->hpack_encoder_.reset(nullptr); + auto* saved_debug_visitor = framer->debug_visitor_; + framer->debug_visitor_ = nullptr; + + std::vector frame_list; + ArrayOutputBuffer frame_list_buffer(frame_list_char, buffer_size); + frame_list_buffer.Reset(); + SpdyFramer::SpdyPushPromiseFrameIterator it( + framer, CloneSpdyPushPromiseIR(push_promise)); + while (it.HasNextFrame()) { + size_t size_before = frame_list_buffer.Size(); + EXPECT_GT(it.NextFrame(&frame_list_buffer), 0u); + frame_list.emplace_back( + SpdySerializedFrame(frame_list_buffer.Begin() + size_before, + frame_list_buffer.Size() - size_before, false)); + } + framer->debug_visitor_ = saved_debug_visitor; + + EXPECT_THAT(serialized_headers_old_version, IsFrameUnionOf(&frame_list)); + return serialized_headers_old_version; + } +}; + +class TestSpdyVisitor : public SpdyFramerVisitorInterface, + public SpdyFramerDebugVisitorInterface { + public: + // This is larger than our max frame size because header blocks that + // are too long can spill over into CONTINUATION frames. + static constexpr size_t kDefaultHeaderBufferSize = 16 * 1024 * 1024; + + explicit TestSpdyVisitor(SpdyFramer::CompressionOption option) + : framer_(option), + error_count_(0), + headers_frame_count_(0), + push_promise_frame_count_(0), + goaway_count_(0), + setting_count_(0), + settings_ack_sent_(0), + settings_ack_received_(0), + continuation_count_(0), + altsvc_count_(0), + priority_count_(0), + unknown_frame_count_(0), + on_unknown_frame_result_(false), + last_window_update_stream_(0), + last_window_update_delta_(0), + last_push_promise_stream_(0), + last_push_promise_promised_stream_(0), + data_bytes_(0), + fin_frame_count_(0), + fin_flag_count_(0), + end_of_stream_count_(0), + control_frame_header_data_count_(0), + zero_length_control_frame_header_data_count_(0), + data_frame_count_(0), + last_payload_len_(0), + last_frame_len_(0), + unknown_payload_len_(0), + header_buffer_(new char[kDefaultHeaderBufferSize]), + header_buffer_length_(0), + header_buffer_size_(kDefaultHeaderBufferSize), + header_stream_id_(static_cast(-1)), + header_control_type_(SpdyFrameType::DATA), + header_buffer_valid_(false) {} + + void OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string /*detailed_error*/) override { + QUICHE_VLOG(1) << "SpdyFramer Error: " + << Http2DecoderAdapter::SpdyFramerErrorToString(error); + ++error_count_; + } + + void OnDataFrameHeader(SpdyStreamId stream_id, size_t length, + bool fin) override { + QUICHE_VLOG(1) << "OnDataFrameHeader(" << stream_id << ", " << length + << ", " << fin << ")"; + ++data_frame_count_; + header_stream_id_ = stream_id; + } + + void OnStreamFrameData(SpdyStreamId stream_id, const char* data, + size_t len) override { + QUICHE_VLOG(1) << "OnStreamFrameData(" << stream_id << ", data, " << len + << ", " + << ") data:\n" + << quiche::QuicheTextUtils::HexDump( + absl::string_view(data, len)); + EXPECT_EQ(header_stream_id_, stream_id); + + data_bytes_ += len; + } + + void OnStreamEnd(SpdyStreamId stream_id) override { + QUICHE_VLOG(1) << "OnStreamEnd(" << stream_id << ")"; + EXPECT_EQ(header_stream_id_, stream_id); + ++end_of_stream_count_; + } + + void OnStreamPadLength(SpdyStreamId stream_id, size_t value) override { + QUICHE_VLOG(1) << "OnStreamPadding(" << stream_id << ", " << value << ")\n"; + EXPECT_EQ(header_stream_id_, stream_id); + // Count the padding length field byte against total data bytes. + data_bytes_ += 1; + } + + void OnStreamPadding(SpdyStreamId stream_id, size_t len) override { + QUICHE_VLOG(1) << "OnStreamPadding(" << stream_id << ", " << len << ")\n"; + EXPECT_EQ(header_stream_id_, stream_id); + data_bytes_ += len; + } + + SpdyHeadersHandlerInterface* OnHeaderFrameStart( + SpdyStreamId /*stream_id*/) override { + if (headers_handler_ == nullptr) { + headers_handler_ = std::make_unique(); + } + return headers_handler_.get(); + } + + void OnHeaderFrameEnd(SpdyStreamId /*stream_id*/) override { + QUICHE_CHECK(headers_handler_ != nullptr); + headers_ = headers_handler_->decoded_block().Clone(); + header_bytes_received_ = headers_handler_->uncompressed_header_bytes(); + headers_handler_.reset(); + } + + void OnRstStream(SpdyStreamId stream_id, SpdyErrorCode error_code) override { + QUICHE_VLOG(1) << "OnRstStream(" << stream_id << ", " << error_code << ")"; + ++fin_frame_count_; + } + + void OnSetting(SpdySettingsId id, uint32_t value) override { + QUICHE_VLOG(1) << "OnSetting(" << id << ", " << std::hex << value << ")"; + ++setting_count_; + } + + void OnSettingsAck() override { + QUICHE_VLOG(1) << "OnSettingsAck"; + ++settings_ack_received_; + } + + void OnSettingsEnd() override { + QUICHE_VLOG(1) << "OnSettingsEnd"; + ++settings_ack_sent_; + } + + void OnPing(SpdyPingId unique_id, bool is_ack) override { + QUICHE_LOG(DFATAL) << "OnPing(" << unique_id << ", " << (is_ack ? 1 : 0) + << ")"; + } + + void OnGoAway(SpdyStreamId last_accepted_stream_id, + SpdyErrorCode error_code) override { + QUICHE_VLOG(1) << "OnGoAway(" << last_accepted_stream_id << ", " + << error_code << ")"; + ++goaway_count_; + } + + void OnHeaders(SpdyStreamId stream_id, size_t payload_length, + bool has_priority, int weight, SpdyStreamId parent_stream_id, + bool exclusive, bool fin, bool end) override { + QUICHE_VLOG(1) << "OnHeaders(" << stream_id << ", " << payload_length + << ", " << has_priority << ", " << weight << ", " + << parent_stream_id << ", " << exclusive << ", " << fin + << ", " << end << ")"; + ++headers_frame_count_; + InitHeaderStreaming(SpdyFrameType::HEADERS, stream_id); + if (fin) { + ++fin_flag_count_; + } + header_has_priority_ = has_priority; + header_parent_stream_id_ = parent_stream_id; + header_exclusive_ = exclusive; + } + + void OnWindowUpdate(SpdyStreamId stream_id, int delta_window_size) override { + QUICHE_VLOG(1) << "OnWindowUpdate(" << stream_id << ", " + << delta_window_size << ")"; + last_window_update_stream_ = stream_id; + last_window_update_delta_ = delta_window_size; + } + + void OnPushPromise(SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + bool end) override { + QUICHE_VLOG(1) << "OnPushPromise(" << stream_id << ", " + << promised_stream_id << ", " << end << ")"; + ++push_promise_frame_count_; + InitHeaderStreaming(SpdyFrameType::PUSH_PROMISE, stream_id); + last_push_promise_stream_ = stream_id; + last_push_promise_promised_stream_ = promised_stream_id; + } + + void OnContinuation(SpdyStreamId stream_id, size_t payload_size, + bool end) override { + QUICHE_VLOG(1) << "OnContinuation(" << stream_id << ", " << payload_size + << ", " << end << ")"; + ++continuation_count_; + } + + void OnAltSvc(SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& + altsvc_vector) override { + QUICHE_VLOG(1) << "OnAltSvc(" << stream_id << ", \"" << origin + << "\", altsvc_vector)"; + test_altsvc_ir_ = std::make_unique(stream_id); + if (origin.length() > 0) { + test_altsvc_ir_->set_origin(std::string(origin)); + } + for (const auto& altsvc : altsvc_vector) { + test_altsvc_ir_->add_altsvc(altsvc); + } + ++altsvc_count_; + } + + void OnPriority(SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive) override { + QUICHE_VLOG(1) << "OnPriority(" << stream_id << ", " << parent_stream_id + << ", " << weight << ", " << (exclusive ? 1 : 0) << ")"; + ++priority_count_; + } + + void OnPriorityUpdate(SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) override { + QUICHE_VLOG(1) << "OnPriorityUpdate(" << prioritized_stream_id << ", " + << priority_field_value << ")"; + } + + bool OnUnknownFrame(SpdyStreamId stream_id, uint8_t frame_type) override { + QUICHE_VLOG(1) << "OnUnknownFrame(" << stream_id << ", " << frame_type + << ")"; + return on_unknown_frame_result_; + } + + void OnUnknownFrameStart(SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override { + QUICHE_VLOG(1) << "OnUnknownFrameStart(" << stream_id << ", " << length + << ", " << static_cast(type) << ", " + << static_cast(flags) << ")"; + ++unknown_frame_count_; + } + + void OnUnknownFramePayload(SpdyStreamId stream_id, + absl::string_view payload) override { + QUICHE_VLOG(1) << "OnUnknownFramePayload(" << stream_id << ", " << payload + << ")"; + unknown_payload_len_ += payload.length(); + } + + void OnSendCompressedFrame(SpdyStreamId stream_id, SpdyFrameType type, + size_t payload_len, size_t frame_len) override { + QUICHE_VLOG(1) << "OnSendCompressedFrame(" << stream_id << ", " << type + << ", " << payload_len << ", " << frame_len << ")"; + last_payload_len_ = payload_len; + last_frame_len_ = frame_len; + } + + void OnReceiveCompressedFrame(SpdyStreamId stream_id, SpdyFrameType type, + size_t frame_len) override { + QUICHE_VLOG(1) << "OnReceiveCompressedFrame(" << stream_id << ", " << type + << ", " << frame_len << ")"; + last_frame_len_ = frame_len; + } + + // Convenience function which runs a framer simulation with particular input. + void SimulateInFramer(const unsigned char* input, size_t size) { + deframer_.set_visitor(this); + size_t input_remaining = size; + const char* input_ptr = reinterpret_cast(input); + while (input_remaining > 0 && deframer_.spdy_framer_error() == + Http2DecoderAdapter::SPDY_NO_ERROR) { + // To make the tests more interesting, we feed random (and small) chunks + // into the framer. This simulates getting strange-sized reads from + // the socket. + const size_t kMaxReadSize = 32; + size_t bytes_read = + (rand() % std::min(input_remaining, kMaxReadSize)) + 1; // NOLINT + size_t bytes_processed = deframer_.ProcessInput(input_ptr, bytes_read); + input_remaining -= bytes_processed; + input_ptr += bytes_processed; + } + } + + void InitHeaderStreaming(SpdyFrameType header_control_type, + SpdyStreamId stream_id) { + if (!IsDefinedFrameType(SerializeFrameType(header_control_type))) { + QUICHE_DLOG(FATAL) << "Attempted to init header streaming with " + << "invalid control frame type: " + << header_control_type; + } + memset(header_buffer_.get(), 0, header_buffer_size_); + header_buffer_length_ = 0; + header_stream_id_ = stream_id; + header_control_type_ = header_control_type; + header_buffer_valid_ = true; + } + + void set_extension_visitor(ExtensionVisitorInterface* extension) { + deframer_.set_extension_visitor(extension); + } + + // Override the default buffer size (16K). Call before using the framer! + void set_header_buffer_size(size_t header_buffer_size) { + header_buffer_size_ = header_buffer_size; + header_buffer_.reset(new char[header_buffer_size]); + } + + SpdyFramer framer_; + Http2DecoderAdapter deframer_; + + // Counters from the visitor callbacks. + int error_count_; + int headers_frame_count_; + int push_promise_frame_count_; + int goaway_count_; + int setting_count_; + int settings_ack_sent_; + int settings_ack_received_; + int continuation_count_; + int altsvc_count_; + int priority_count_; + std::unique_ptr test_altsvc_ir_; + int unknown_frame_count_; + bool on_unknown_frame_result_; + SpdyStreamId last_window_update_stream_; + int last_window_update_delta_; + SpdyStreamId last_push_promise_stream_; + SpdyStreamId last_push_promise_promised_stream_; + int data_bytes_; + int fin_frame_count_; // The count of RST_STREAM type frames received. + int fin_flag_count_; // The count of frames with the FIN flag set. + int end_of_stream_count_; // The count of zero-length data frames. + int control_frame_header_data_count_; // The count of chunks received. + // The count of zero-length control frame header data chunks received. + int zero_length_control_frame_header_data_count_; + int data_frame_count_; + size_t last_payload_len_; + size_t last_frame_len_; + size_t unknown_payload_len_; + + // Header block streaming state: + std::unique_ptr header_buffer_; + size_t header_buffer_length_; + size_t header_buffer_size_; + size_t header_bytes_received_; + SpdyStreamId header_stream_id_; + SpdyFrameType header_control_type_; + bool header_buffer_valid_; + std::unique_ptr headers_handler_; + Http2HeaderBlock headers_; + bool header_has_priority_; + SpdyStreamId header_parent_stream_id_; + bool header_exclusive_; +}; + +class TestExtension : public ExtensionVisitorInterface { + public: + void OnSetting(SpdySettingsId id, uint32_t value) override { + settings_received_.push_back({id, value}); + } + + // Called when non-standard frames are received. + bool OnFrameHeader(SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override { + stream_id_ = stream_id; + length_ = length; + type_ = type; + flags_ = flags; + return true; + } + + // The payload for a single frame may be delivered as multiple calls to + // OnFramePayload. + void OnFramePayload(const char* data, size_t len) override { + payload_.append(data, len); + } + + std::vector> settings_received_; + SpdyStreamId stream_id_ = 0; + size_t length_ = 0; + uint8_t type_ = 0; + uint8_t flags_ = 0; + std::string payload_; +}; + +// Exposes SpdyUnknownIR::set_length() for testing purposes. +class TestSpdyUnknownIR : public SpdyUnknownIR { + public: + using SpdyUnknownIR::set_length; + using SpdyUnknownIR::SpdyUnknownIR; +}; + +enum Output { USE, NOT_USE }; + +class SpdyFramerTest : public quiche::test::QuicheTestWithParam { + public: + SpdyFramerTest() + : output_(output_buffer, kSize), + framer_(SpdyFramer::ENABLE_COMPRESSION), + deframer_(std::make_unique()) {} + + protected: + void SetUp() override { + switch (GetParam()) { + case USE: + use_output_ = true; + break; + case NOT_USE: + // TODO(yasong): remove this case after + // gfe2_reloadable_flag_write_queue_zero_copy_buffer deprecates. + use_output_ = false; + break; + } + } + + void CompareFrame(const std::string& description, + const SpdySerializedFrame& actual_frame, + const unsigned char* expected, const int expected_len) { + const unsigned char* actual = + reinterpret_cast(actual_frame.data()); + CompareCharArraysWithHexError(description, actual, actual_frame.size(), + expected, expected_len); + } + + bool use_output_ = false; + ArrayOutputBuffer output_; + SpdyFramer framer_; + std::unique_ptr deframer_; +}; + +INSTANTIATE_TEST_SUITE_P(SpdyFramerTests, SpdyFramerTest, + ::testing::Values(USE, NOT_USE)); + +// Test that we can encode and decode a Http2HeaderBlock in serialized form. +TEST_P(SpdyFramerTest, HeaderBlockInBuffer) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + + // Encode the header block into a Headers frame. + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.SetHeader("alpha", "beta"); + headers.SetHeader("gamma", "charlie"); + headers.SetHeader("cookie", "key1=value1; key2=value2"); + SpdySerializedFrame frame( + SpdyFramerPeer::SerializeHeaders(&framer, headers, &output_)); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(reinterpret_cast(frame.data()), + frame.size()); + + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_EQ(headers.header_block(), visitor.headers_); +} + +// Test that if there's not a full frame, we fail to parse it. +TEST_P(SpdyFramerTest, UndersizedHeaderBlockInBuffer) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + + // Encode the header block into a Headers frame. + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.SetHeader("alpha", "beta"); + headers.SetHeader("gamma", "charlie"); + SpdySerializedFrame frame( + SpdyFramerPeer::SerializeHeaders(&framer, headers, &output_)); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(reinterpret_cast(frame.data()), + frame.size() - 2); + + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_THAT(visitor.headers_, testing::IsEmpty()); +} + +// Test that we can encode and decode stream dependency values in a header +// frame. +TEST_P(SpdyFramerTest, HeaderStreamDependencyValues) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + + const SpdyStreamId parent_stream_id_test_array[] = {0, 3}; + for (SpdyStreamId parent_stream_id : parent_stream_id_test_array) { + const bool exclusive_test_array[] = {true, false}; + for (bool exclusive : exclusive_test_array) { + SpdyHeadersIR headers(1); + headers.set_has_priority(true); + headers.set_parent_stream_id(parent_stream_id); + headers.set_exclusive(exclusive); + SpdySerializedFrame frame( + SpdyFramerPeer::SerializeHeaders(&framer, headers, &output_)); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(reinterpret_cast(frame.data()), + frame.size()); + + EXPECT_TRUE(visitor.header_has_priority_); + EXPECT_EQ(parent_stream_id, visitor.header_parent_stream_id_); + EXPECT_EQ(exclusive, visitor.header_exclusive_); + } + } +} + +// Test that if we receive a frame with a payload length field at the default +// max size, we do not set an error in ProcessInput. +TEST_P(SpdyFramerTest, AcceptMaxFrameSizeSetting) { + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + // DATA frame with maximum allowed payload length. + unsigned char kH2FrameData[] = { + 0x00, 0x40, 0x00, // Length: 2^14 + 0x00, // Type: DATA + 0x00, // Flags: None + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Junk payload + }; + + SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), + sizeof(kH2FrameData), false); + + EXPECT_CALL(visitor, OnCommonHeader(1, 16384, 0x0, 0x0)); + EXPECT_CALL(visitor, OnDataFrameHeader(1, 1 << 14, false)); + EXPECT_CALL(visitor, OnStreamFrameData(1, _, 4)); + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_FALSE(deframer_->HasError()); +} + +// Test that if we receive a frame with a payload length larger than the default +// max size, we set an error of SPDY_INVALID_CONTROL_FRAME_SIZE. +TEST_P(SpdyFramerTest, ExceedMaxFrameSizeSetting) { + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + // DATA frame with too large payload length. + unsigned char kH2FrameData[] = { + 0x00, 0x40, 0x01, // Length: 2^14 + 1 + 0x00, // Type: DATA + 0x00, // Flags: None + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Junk payload + }; + + SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), + sizeof(kH2FrameData), false); + + EXPECT_CALL(visitor, OnCommonHeader(1, 16385, 0x0, 0x0)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_OVERSIZED_PAYLOAD, _)); + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_OVERSIZED_PAYLOAD, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we set a larger max frame size and then receive a frame with a +// payload length at that larger size, we do not set an error in ProcessInput. +TEST_P(SpdyFramerTest, AcceptLargerMaxFrameSizeSetting) { + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + const size_t big_frame_size = (1 << 14) + 1; + deframer_->SetMaxFrameSize(big_frame_size); + + // DATA frame with larger-than-default but acceptable payload length. + unsigned char kH2FrameData[] = { + 0x00, 0x40, 0x01, // Length: 2^14 + 1 + 0x00, // Type: DATA + 0x00, // Flags: None + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Junk payload + }; + + SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), + sizeof(kH2FrameData), false); + + EXPECT_CALL(visitor, OnCommonHeader(1, big_frame_size, 0x0, 0x0)); + EXPECT_CALL(visitor, OnDataFrameHeader(1, big_frame_size, false)); + EXPECT_CALL(visitor, OnStreamFrameData(1, _, 4)); + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_FALSE(deframer_->HasError()); +} + +// Test that if we receive a DATA frame with padding length larger than the +// payload length, we set an error of SPDY_INVALID_PADDING +TEST_P(SpdyFramerTest, OversizedDataPaddingError) { + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + // DATA frame with invalid padding length. + // |kH2FrameData| has to be |unsigned char|, because Chromium on Windows uses + // MSVC, where |char| is signed by default, which would not compile because of + // the element exceeding 127. + unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x00, // Type: DATA + 0x09, // Flags: END_STREAM|PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xff, // PadLen: 255 trailing bytes (Too Long) + 0x00, 0x00, 0x00, 0x00, // Padding + }; + + SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), + sizeof(kH2FrameData), false); + + { + testing::InSequence seq; + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x0, 0x9)); + EXPECT_CALL(visitor, OnDataFrameHeader(1, 5, 1)); + EXPECT_CALL(visitor, OnStreamPadding(1, 1)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_PADDING, _)); + } + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_PADDING, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a DATA frame with padding length not larger than the +// payload length, we do not set an error of SPDY_INVALID_PADDING +TEST_P(SpdyFramerTest, CorrectlySizedDataPaddingNoError) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + // DATA frame with valid Padding length + char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x00, // Type: DATA + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x04, // PadLen: 4 trailing bytes + 0x00, 0x00, 0x00, 0x00, // Padding + }; + + SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); + + { + testing::InSequence seq; + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x0, 0x8)); + EXPECT_CALL(visitor, OnDataFrameHeader(1, 5, false)); + EXPECT_CALL(visitor, OnStreamPadLength(1, 4)); + EXPECT_CALL(visitor, OnError(_, _)).Times(0); + // Note that OnStreamFrameData(1, _, 1)) is never called + // since there is no data, only padding + EXPECT_CALL(visitor, OnStreamPadding(1, 4)); + } + + EXPECT_EQ(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_FALSE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a HEADERS frame with padding length larger than the +// payload length, we set an error of SPDY_INVALID_PADDING +TEST_P(SpdyFramerTest, OversizedHeadersPaddingError) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + // HEADERS frame with invalid padding length. + // |kH2FrameData| has to be |unsigned char|, because Chromium on Windows uses + // MSVC, where |char| is signed by default, which would not compile because of + // the element exceeding 127. + unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x01, // Type: HEADERS + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xff, // PadLen: 255 trailing bytes (Too Long) + 0x00, 0x00, 0x00, 0x00, // Padding + }; + + SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), + sizeof(kH2FrameData), false); + + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x1, 0x8)); + EXPECT_CALL(visitor, OnHeaders(1, 5, false, 0, 0, false, false, false)); + EXPECT_CALL(visitor, OnHeaderFrameStart(1)).Times(1); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_PADDING, _)); + EXPECT_EQ(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_PADDING, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a HEADERS frame with padding length not larger +// than the payload length, we do not set an error of SPDY_INVALID_PADDING +TEST_P(SpdyFramerTest, CorrectlySizedHeadersPaddingNoError) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + // HEADERS frame with invalid Padding length + char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x01, // Type: HEADERS + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x04, // PadLen: 4 trailing bytes + 0x00, 0x00, 0x00, 0x00, // Padding + }; + + SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); + + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x1, 0x8)); + EXPECT_CALL(visitor, OnHeaders(1, 5, false, 0, 0, false, false, false)); + EXPECT_CALL(visitor, OnHeaderFrameStart(1)).Times(1); + + EXPECT_EQ(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_FALSE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a DATA with stream ID zero, we signal an error +// (but don't crash). +TEST_P(SpdyFramerTest, DataWithStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + const char bytes[] = "hello"; + SpdyDataIR data_ir(/* stream_id = */ 0, bytes); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x0, _)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a HEADERS with stream ID zero, we signal an error +// (but don't crash). +TEST_P(SpdyFramerTest, HeadersWithStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyHeadersIR headers(/* stream_id = */ 0); + headers.SetHeader("alpha", "beta"); + SpdySerializedFrame frame( + SpdyFramerPeer::SerializeHeaders(&framer_, headers, &output_)); + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x1, _)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a PRIORITY with stream ID zero, we signal an error +// (but don't crash). +TEST_P(SpdyFramerTest, PriorityWithStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyPriorityIR priority_ir(/* stream_id = */ 0, + /* parent_stream_id = */ 1, + /* weight = */ 16, + /* exclusive = */ true); + SpdySerializedFrame frame(framer_.SerializeFrame(priority_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(priority_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x2, _)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a RST_STREAM with stream ID zero, we signal an error +// (but don't crash). +TEST_P(SpdyFramerTest, RstStreamWithStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyRstStreamIR rst_stream_ir(/* stream_id = */ 0, ERROR_CODE_PROTOCOL_ERROR); + SpdySerializedFrame frame(framer_.SerializeRstStream(rst_stream_ir)); + if (use_output_) { + EXPECT_TRUE(framer_.SerializeRstStream(rst_stream_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x3, _)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a SETTINGS with stream ID other than zero, +// we signal an error (but don't crash). +TEST_P(SpdyFramerTest, SettingsWithStreamIdNotZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + // Settings frame with invalid StreamID of 0x01 + char kH2FrameData[] = { + 0x00, 0x00, 0x06, // Length: 6 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x04, // Param: INITIAL_WINDOW_SIZE + 0x0a, 0x0b, 0x0c, 0x0d, // Value: 168496141 + }; + + SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(1, 6, 0x4, 0x0)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a GOAWAY with stream ID other than zero, +// we signal an error (but don't crash). +TEST_P(SpdyFramerTest, GoawayWithStreamIdNotZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + // GOAWAY frame with invalid StreamID of 0x01 + char kH2FrameData[] = { + 0x00, 0x00, 0x0a, // Length: 10 + 0x07, // Type: GOAWAY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Last: 0 + 0x00, 0x00, 0x00, 0x00, // Error: NO_ERROR + 0x47, 0x41, // Description + }; + + SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(1, 10, 0x7, 0x0)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a CONTINUATION with stream ID zero, we signal +// SPDY_INVALID_STREAM_ID. +TEST_P(SpdyFramerTest, ContinuationWithStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyContinuationIR continuation(/* stream_id = */ 0); + std::string some_nonsense_encoding = "some nonsense encoding"; + continuation.take_encoding(std::move(some_nonsense_encoding)); + continuation.set_end_headers(true); + SpdySerializedFrame frame(framer_.SerializeContinuation(continuation)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeContinuation(continuation, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x9, _)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a PUSH_PROMISE with stream ID zero, we signal +// SPDY_INVALID_STREAM_ID. +TEST_P(SpdyFramerTest, PushPromiseWithStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyPushPromiseIR push_promise(/* stream_id = */ 0, + /* promised_stream_id = */ 4); + push_promise.SetHeader("alpha", "beta"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer_, push_promise, use_output_ ? &output_ : nullptr)); + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x5, _)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test that if we receive a PUSH_PROMISE with promised stream ID zero, we +// signal SPDY_INVALID_CONTROL_FRAME. +TEST_P(SpdyFramerTest, PushPromiseWithPromisedStreamIdZero) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyPushPromiseIR push_promise(/* stream_id = */ 3, + /* promised_stream_id = */ 0); + push_promise.SetHeader("alpha", "beta"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer_, push_promise, use_output_ ? &output_ : nullptr)); + + EXPECT_CALL(visitor, OnCommonHeader(3, _, 0x5, _)); + EXPECT_CALL(visitor, + OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, _)); + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, MultiValueHeader) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + std::string value("value1\0value2", 13); + // TODO(jgraettinger): If this pattern appears again, move to test class. + Http2HeaderBlock header_set; + header_set["name"] = value; + HpackEncoder encoder; + encoder.DisableCompression(); + std::string buffer = encoder.EncodeHeaderBlock(header_set); + // Frame builder with plentiful buffer size. + SpdyFrameBuilder frame(1024); + frame.BeginNewFrame(SpdyFrameType::HEADERS, + HEADERS_FLAG_PRIORITY | HEADERS_FLAG_END_HEADERS, 3, + buffer.size() + 5 /* priority */); + frame.WriteUInt32(0); // Priority exclusivity and dependent stream. + frame.WriteUInt8(255); // Priority weight. + frame.WriteBytes(&buffer[0], buffer.size()); + + SpdySerializedFrame control_frame(frame.take()); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + + EXPECT_THAT(visitor.headers_, testing::ElementsAre(testing::Pair( + "name", absl::string_view(value)))); +} + +TEST_P(SpdyFramerTest, CompressEmptyHeaders) { + // See https://crbug.com/172383/ + SpdyHeadersIR headers(1); + headers.SetHeader("server", "SpdyServer 1.0"); + headers.SetHeader("date", "Mon 12 Jan 2009 12:12:12 PST"); + headers.SetHeader("status", "200"); + headers.SetHeader("version", "HTTP/1.1"); + headers.SetHeader("content-type", "text/html"); + headers.SetHeader("content-length", "12"); + headers.SetHeader("x-empty-header", ""); + + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + SpdySerializedFrame frame1( + SpdyFramerPeer::SerializeHeaders(&framer, headers, &output_)); +} + +TEST_P(SpdyFramerTest, Basic) { + // Send HEADERS frames with PRIORITY and END_HEADERS set. + // frame-format off + const unsigned char kH2Input[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x01, // Type: HEADERS + 0x24, // Flags: END_HEADERS|PRIORITY + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0x82, // Weight: 131 + + 0x00, 0x00, 0x01, // Length: 1 + 0x01, // Type: HEADERS + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x8c, // :status: 200 + + 0x00, 0x00, 0x0c, // Length: 12 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x05, // Length: 5 + 0x01, // Type: HEADERS + 0x24, // Flags: END_HEADERS|PRIORITY + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0x82, // Weight: 131 + + 0x00, 0x00, 0x08, // Length: 8 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x04, // Length: 4 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x08, // Error: CANCEL + + 0x00, 0x00, 0x00, // Length: 0 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x08, // Error: CANCEL + }; + // frame-format on + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kH2Input, sizeof(kH2Input)); + + EXPECT_EQ(24, visitor.data_bytes_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(2, visitor.fin_frame_count_); + + EXPECT_EQ(3, visitor.headers_frame_count_); + + EXPECT_EQ(0, visitor.fin_flag_count_); + EXPECT_EQ(0, visitor.end_of_stream_count_); + EXPECT_EQ(4, visitor.data_frame_count_); +} + +// Verifies that the decoder stops delivering events after a user error. +TEST_P(SpdyFramerTest, BasicWithError) { + // Send HEADERS frames with PRIORITY and END_HEADERS set. + // frame-format off + const unsigned char kH2Input[] = { + 0x00, 0x00, 0x01, // Length: 1 + 0x01, // Type: HEADERS + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x8c, // :status: 200 + + 0x00, 0x00, 0x0c, // Length: 12 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x06, // Length: 6 + 0x01, // Type: HEADERS + 0x24, // Flags: END_HEADERS|PRIORITY + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0x82, // Weight: 131 + 0x8c, // :status: 200 + + 0x00, 0x00, 0x08, // Length: 8 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x04, // Length: 4 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x08, // Error: CANCEL + + 0x00, 0x00, 0x00, // Length: 0 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x08, // Error: CANCEL + }; + // frame-format on + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + testing::InSequence s; + EXPECT_CALL(visitor, OnCommonHeader(1, 1, 0x1, 0x4)); + EXPECT_CALL(visitor, OnHeaders(1, 1, false, 0, 0, false, false, true)); + EXPECT_CALL(visitor, OnHeaderFrameStart(1)); + EXPECT_CALL(visitor, OnHeaderFrameEnd(1)); + EXPECT_CALL(visitor, OnCommonHeader(1, 12, 0x0, 0x0)); + EXPECT_CALL(visitor, OnDataFrameHeader(1, 12, false)); + EXPECT_CALL(visitor, OnStreamFrameData(1, _, 12)); + EXPECT_CALL(visitor, OnCommonHeader(3, 6, 0x1, 0x24)); + EXPECT_CALL(visitor, OnHeaders(3, 6, true, 131, 0, false, false, true)); + EXPECT_CALL(visitor, OnHeaderFrameStart(3)); + EXPECT_CALL(visitor, OnHeaderFrameEnd(3)); + EXPECT_CALL(visitor, OnCommonHeader(3, 8, 0x0, 0x0)); + EXPECT_CALL(visitor, OnDataFrameHeader(3, 8, false)) + .WillOnce(testing::InvokeWithoutArgs( + [this]() { deframer_->StopProcessing(); })); + // Remaining frames are not processed due to the error. + EXPECT_CALL( + visitor, + OnError(http2::Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + "Ignoring further events on this connection.")); + + size_t processed = deframer_->ProcessInput( + reinterpret_cast(kH2Input), sizeof(kH2Input)); + EXPECT_LT(processed, sizeof(kH2Input)); +} + +// Test that the FIN flag on a data frame signifies EOF. +TEST_P(SpdyFramerTest, FinOnDataFrame) { + // Send HEADERS frames with END_HEADERS set. + // frame-format off + const unsigned char kH2Input[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x01, // Type: HEADERS + 0x24, // Flags: END_HEADERS|PRIORITY + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0x82, // Weight: 131 + + 0x00, 0x00, 0x01, // Length: 1 + 0x01, // Type: HEADERS + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x8c, // :status: 200 + + 0x00, 0x00, 0x0c, // Length: 12 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x04, // Length: 4 + 0x00, // Type: DATA + 0x01, // Flags: END_STREAM + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + }; + // frame-format on + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kH2Input, sizeof(kH2Input)); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(2, visitor.headers_frame_count_); + EXPECT_EQ(16, visitor.data_bytes_); + EXPECT_EQ(0, visitor.fin_frame_count_); + EXPECT_EQ(0, visitor.fin_flag_count_); + EXPECT_EQ(1, visitor.end_of_stream_count_); + EXPECT_EQ(2, visitor.data_frame_count_); +} + +TEST_P(SpdyFramerTest, FinOnHeadersFrame) { + // Send HEADERS frames with END_HEADERS set. + // frame-format off + const unsigned char kH2Input[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x01, // Type: HEADERS + 0x24, // Flags: END_HEADERS|PRIORITY + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0x82, // Weight: 131 + + 0x00, 0x00, 0x01, // Length: 1 + 0x01, // Type: HEADERS + 0x05, // Flags: END_STREAM|END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x8c, // :status: 200 + }; + // frame-format on + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kH2Input, sizeof(kH2Input)); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(2, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.data_bytes_); + EXPECT_EQ(0, visitor.fin_frame_count_); + EXPECT_EQ(1, visitor.fin_flag_count_); + EXPECT_EQ(1, visitor.end_of_stream_count_); + EXPECT_EQ(0, visitor.data_frame_count_); +} + +// Verify we can decompress the stream even if handed over to the +// framer 1 byte at a time. +TEST_P(SpdyFramerTest, UnclosedStreamDataCompressorsOneByteAtATime) { + const char kHeader1[] = "header1"; + const char kHeader2[] = "header2"; + const char kValue1[] = "value1"; + const char kValue2[] = "value2"; + + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.SetHeader(kHeader1, kValue1); + headers.SetHeader(kHeader2, kValue2); + SpdySerializedFrame headers_frame(SpdyFramerPeer::SerializeHeaders( + &framer_, headers, use_output_ ? &output_ : nullptr)); + + const char bytes[] = "this is a test test test test test!"; + SpdyDataIR data_ir(/* stream_id = */ 1, + absl::string_view(bytes, ABSL_ARRAYSIZE(bytes))); + data_ir.set_fin(true); + SpdySerializedFrame send_frame(framer_.SerializeData(data_ir)); + + // Run the inputs through the framer. + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + const unsigned char* data; + data = reinterpret_cast(headers_frame.data()); + for (size_t idx = 0; idx < headers_frame.size(); ++idx) { + visitor.SimulateInFramer(data + idx, 1); + ASSERT_EQ(0, visitor.error_count_); + } + data = reinterpret_cast(send_frame.data()); + for (size_t idx = 0; idx < send_frame.size(); ++idx) { + visitor.SimulateInFramer(data + idx, 1); + ASSERT_EQ(0, visitor.error_count_); + } + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(ABSL_ARRAYSIZE(bytes), static_cast(visitor.data_bytes_)); + EXPECT_EQ(0, visitor.fin_frame_count_); + EXPECT_EQ(0, visitor.fin_flag_count_); + EXPECT_EQ(1, visitor.end_of_stream_count_); + EXPECT_EQ(1, visitor.data_frame_count_); +} + +TEST_P(SpdyFramerTest, WindowUpdateFrame) { + SpdyWindowUpdateIR window_update(/* stream_id = */ 1, + /* delta = */ 0x12345678); + SpdySerializedFrame frame(framer_.SerializeWindowUpdate(window_update)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeWindowUpdate(window_update, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + const char kDescription[] = "WINDOW_UPDATE frame, stream 1, delta 0x12345678"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x08, // Type: WINDOW_UPDATE + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x12, 0x34, 0x56, 0x78, // Increment: 305419896 + }; + + CompareFrame(kDescription, frame, kH2FrameData, ABSL_ARRAYSIZE(kH2FrameData)); +} + +TEST_P(SpdyFramerTest, CreateDataFrame) { + { + const char kDescription[] = "'hello' data frame, no FIN"; + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 'h', 'e', 'l', 'l', // Payload + 'o', // + }; + // frame-format on + const char bytes[] = "hello"; + + SpdyDataIR data_ir(/* stream_id = */ 1, bytes); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + + SpdyDataIR data_header_ir(/* stream_id = */ 1); + data_header_ir.SetDataShallow(bytes); + frame = + framer_.SerializeDataFrameHeaderWithPaddingLengthField(data_header_ir); + CompareCharArraysWithHexError( + kDescription, reinterpret_cast(frame.data()), + kDataFrameMinimumSize, kH2FrameData, kDataFrameMinimumSize); + } + + { + const char kDescription[] = "'hello' data frame with more padding, no FIN"; + // clang-format off + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0xfd, // Length: 253 + 0x00, // Type: DATA + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xf7, // PadLen: 247 trailing bytes + 'h', 'e', 'l', 'l', // Payload + 'o', // + // Padding of 247 0x00(s). + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + // frame-format on + // clang-format on + const char bytes[] = "hello"; + + SpdyDataIR data_ir(/* stream_id = */ 1, bytes); + // 247 zeros and the pad length field make the overall padding to be 248 + // bytes. + data_ir.set_padding_len(248); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + + frame = framer_.SerializeDataFrameHeaderWithPaddingLengthField(data_ir); + CompareCharArraysWithHexError( + kDescription, reinterpret_cast(frame.data()), + kDataFrameMinimumSize, kH2FrameData, kDataFrameMinimumSize); + } + + { + const char kDescription[] = "'hello' data frame with few padding, no FIN"; + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x0d, // Length: 13 + 0x00, // Type: DATA + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x07, // PadLen: 7 trailing bytes + 'h', 'e', 'l', 'l', // Payload + 'o', // + 0x00, 0x00, 0x00, 0x00, // Padding + 0x00, 0x00, 0x00, // Padding + }; + // frame-format on + const char bytes[] = "hello"; + + SpdyDataIR data_ir(/* stream_id = */ 1, bytes); + // 7 zeros and the pad length field make the overall padding to be 8 bytes. + data_ir.set_padding_len(8); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + + frame = framer_.SerializeDataFrameHeaderWithPaddingLengthField(data_ir); + CompareCharArraysWithHexError( + kDescription, reinterpret_cast(frame.data()), + kDataFrameMinimumSize, kH2FrameData, kDataFrameMinimumSize); + } + + { + const char kDescription[] = + "'hello' data frame with 1 byte padding, no FIN"; + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x06, // Length: 6 + 0x00, // Type: DATA + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, // PadLen: 0 trailing bytes + 'h', 'e', 'l', 'l', // Payload + 'o', // + }; + // frame-format on + const char bytes[] = "hello"; + + SpdyDataIR data_ir(/* stream_id = */ 1, bytes); + // The pad length field itself is used for the 1-byte padding and no padding + // payload is needed. + data_ir.set_padding_len(1); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + + frame = framer_.SerializeDataFrameHeaderWithPaddingLengthField(data_ir); + CompareCharArraysWithHexError( + kDescription, reinterpret_cast(frame.data()), + kDataFrameMinimumSize, kH2FrameData, kDataFrameMinimumSize); + } + + { + const char kDescription[] = "Data frame with negative data byte, no FIN"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x01, // Length: 1 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xff, // Payload + }; + SpdyDataIR data_ir(/* stream_id = */ 1, "\xff"); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "'hello' data frame, with FIN"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x00, // Type: DATA + 0x01, // Flags: END_STREAM + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x68, 0x65, 0x6c, 0x6c, // Payload + 0x6f, // + }; + SpdyDataIR data_ir(/* stream_id = */ 1, "hello"); + data_ir.set_fin(true); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "Empty data frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x00, // Length: 0 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + }; + SpdyDataIR data_ir(/* stream_id = */ 1, ""); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + + frame = framer_.SerializeDataFrameHeaderWithPaddingLengthField(data_ir); + CompareCharArraysWithHexError( + kDescription, reinterpret_cast(frame.data()), + kDataFrameMinimumSize, kH2FrameData, kDataFrameMinimumSize); + } + + { + const char kDescription[] = "Data frame with max stream ID"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x00, // Type: DATA + 0x01, // Flags: END_STREAM + 0x7f, 0xff, 0xff, 0xff, // Stream: 0x7fffffff + 0x68, 0x65, 0x6c, 0x6c, // Payload + 0x6f, // + }; + SpdyDataIR data_ir(/* stream_id = */ 0x7fffffff, "hello"); + data_ir.set_fin(true); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } +} + +TEST_P(SpdyFramerTest, CreateRstStream) { + { + const char kDescription[] = "RST_STREAM frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x01, // Error: PROTOCOL_ERROR + }; + SpdyRstStreamIR rst_stream(/* stream_id = */ 1, ERROR_CODE_PROTOCOL_ERROR); + SpdySerializedFrame frame(framer_.SerializeRstStream(rst_stream)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeRstStream(rst_stream, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "RST_STREAM frame with max stream ID"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x7f, 0xff, 0xff, 0xff, // Stream: 0x7fffffff + 0x00, 0x00, 0x00, 0x01, // Error: PROTOCOL_ERROR + }; + SpdyRstStreamIR rst_stream(/* stream_id = */ 0x7FFFFFFF, + ERROR_CODE_PROTOCOL_ERROR); + SpdySerializedFrame frame(framer_.SerializeRstStream(rst_stream)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeRstStream(rst_stream, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "RST_STREAM frame with max status code"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x7f, 0xff, 0xff, 0xff, // Stream: 0x7fffffff + 0x00, 0x00, 0x00, 0x02, // Error: INTERNAL_ERROR + }; + SpdyRstStreamIR rst_stream(/* stream_id = */ 0x7FFFFFFF, + ERROR_CODE_INTERNAL_ERROR); + SpdySerializedFrame frame(framer_.SerializeRstStream(rst_stream)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeRstStream(rst_stream, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } +} + +TEST_P(SpdyFramerTest, CreateSettings) { + { + const char kDescription[] = "Network byte order SETTINGS frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x06, // Length: 6 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x04, // Param: INITIAL_WINDOW_SIZE + 0x0a, 0x0b, 0x0c, 0x0d, // Value: 168496141 + }; + + uint32_t kValue = 0x0a0b0c0d; + SpdySettingsIR settings_ir; + + SpdyKnownSettingsId kId = SETTINGS_INITIAL_WINDOW_SIZE; + settings_ir.AddSetting(kId, kValue); + + SpdySerializedFrame frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "Basic SETTINGS frame"; + // These end up seemingly out of order because of the way that our internal + // ordering for settings_ir works. HTTP2 has no requirement on ordering on + // the wire. + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x18, // Length: 24 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x01, // Param: HEADER_TABLE_SIZE + 0x00, 0x00, 0x00, 0x05, // Value: 5 + 0x00, 0x02, // Param: ENABLE_PUSH + 0x00, 0x00, 0x00, 0x06, // Value: 6 + 0x00, 0x03, // Param: MAX_CONCURRENT_STREAMS + 0x00, 0x00, 0x00, 0x07, // Value: 7 + 0x00, 0x04, // Param: INITIAL_WINDOW_SIZE + 0x00, 0x00, 0x00, 0x08, // Value: 8 + }; + + SpdySettingsIR settings_ir; + settings_ir.AddSetting(SETTINGS_HEADER_TABLE_SIZE, 5); + settings_ir.AddSetting(SETTINGS_ENABLE_PUSH, 6); + settings_ir.AddSetting(SETTINGS_MAX_CONCURRENT_STREAMS, 7); + settings_ir.AddSetting(SETTINGS_INITIAL_WINDOW_SIZE, 8); + SpdySerializedFrame frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "Empty SETTINGS frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x00, // Length: 0 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + }; + SpdySettingsIR settings_ir; + SpdySerializedFrame frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } +} + +TEST_P(SpdyFramerTest, CreatePingFrame) { + { + const char kDescription[] = "PING frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x08, // Length: 8 + 0x06, // Type: PING + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x12, 0x34, 0x56, 0x78, // Opaque + 0x9a, 0xbc, 0xde, 0xff, // Data + }; + const unsigned char kH2FrameDataWithAck[] = { + 0x00, 0x00, 0x08, // Length: 8 + 0x06, // Type: PING + 0x01, // Flags: ACK + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x12, 0x34, 0x56, 0x78, // Opaque + 0x9a, 0xbc, 0xde, 0xff, // Data + }; + const SpdyPingId kPingId = 0x123456789abcdeffULL; + SpdyPingIR ping_ir(kPingId); + // Tests SpdyPingIR when the ping is not an ack. + ASSERT_FALSE(ping_ir.is_ack()); + SpdySerializedFrame frame(framer_.SerializePing(ping_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializePing(ping_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + + // Tests SpdyPingIR when the ping is an ack. + ping_ir.set_is_ack(true); + frame = framer_.SerializePing(ping_ir); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializePing(ping_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameDataWithAck, + ABSL_ARRAYSIZE(kH2FrameDataWithAck)); + } +} + +TEST_P(SpdyFramerTest, CreateGoAway) { + { + const char kDescription[] = "GOAWAY frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x0a, // Length: 10 + 0x07, // Type: GOAWAY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x00, 0x00, 0x00, // Last: 0 + 0x00, 0x00, 0x00, 0x00, // Error: NO_ERROR + 0x47, 0x41, // Description + }; + SpdyGoAwayIR goaway_ir(/* last_good_stream_id = */ 0, ERROR_CODE_NO_ERROR, + "GA"); + SpdySerializedFrame frame(framer_.SerializeGoAway(goaway_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeGoAway(goaway_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "GOAWAY frame with max stream ID, status"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x0a, // Length: 10 + 0x07, // Type: GOAWAY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x7f, 0xff, 0xff, 0xff, // Last: 0x7fffffff + 0x00, 0x00, 0x00, 0x02, // Error: INTERNAL_ERROR + 0x47, 0x41, // Description + }; + SpdyGoAwayIR goaway_ir(/* last_good_stream_id = */ 0x7FFFFFFF, + ERROR_CODE_INTERNAL_ERROR, "GA"); + SpdySerializedFrame frame(framer_.SerializeGoAway(goaway_ir)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeGoAway(goaway_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } +} + +TEST_P(SpdyFramerTest, CreateHeadersUncompressed) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + + { + const char kDescription[] = "HEADERS frame, no FIN"; + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x12, // Length: 18 + 0x01, // Type: HEADERS + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + }; + // frame-format on + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.SetHeader("bar", "foo"); + headers.SetHeader("foo", "bar"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header name, FIN, max stream ID"; + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x0f, // Length: 15 + 0x01, // Type: HEADERS + 0x05, // Flags: END_STREAM|END_HEADERS + 0x7f, 0xff, 0xff, 0xff, // Stream: 2147483647 + + 0x00, // Unindexed Entry + 0x00, // Name Len: 0 + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + }; + // frame-format on + SpdyHeadersIR headers(/* stream_id = */ 0x7fffffff); + headers.set_fin(true); + headers.SetHeader("", "foo"); + headers.SetHeader("foo", "bar"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header val, FIN, max stream ID"; + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x0f, // Length: 15 + 0x01, // Type: HEADERS + 0x05, // Flags: END_STREAM|END_HEADERS + 0x7f, 0xff, 0xff, 0xff, // Stream: 2147483647 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x00, // Value Len: 0 + }; + // frame-format on + SpdyHeadersIR headers_ir(/* stream_id = */ 0x7fffffff); + headers_ir.set_fin(true); + headers_ir.SetHeader("bar", "foo"); + headers_ir.SetHeader("foo", ""); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers_ir, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header val, FIN, max stream ID, pri"; + + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x14, // Length: 20 + 0x01, // Type: HEADERS + 0x25, // Flags: END_STREAM|END_HEADERS|PRIORITY + 0x7f, 0xff, 0xff, 0xff, // Stream: 2147483647 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0xdb, // Weight: 220 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x00, // Value Len: 0 + }; + // frame-format on + SpdyHeadersIR headers_ir(/* stream_id = */ 0x7fffffff); + headers_ir.set_fin(true); + headers_ir.set_has_priority(true); + headers_ir.set_weight(220); + headers_ir.SetHeader("bar", "foo"); + headers_ir.SetHeader("foo", ""); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers_ir, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header val, FIN, max stream ID, pri, " + "exclusive=true, parent_stream=0"; + + // frame-format off + const unsigned char kV4FrameData[] = { + 0x00, 0x00, 0x14, // Length: 20 + 0x01, // Type: HEADERS + 0x25, // Flags: END_STREAM|END_HEADERS|PRIORITY + 0x7f, 0xff, 0xff, 0xff, // Stream: 2147483647 + 0x80, 0x00, 0x00, 0x00, // Parent: 0 (Exclusive) + 0xdb, // Weight: 220 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x00, // Value Len: 0 + }; + // frame-format on + SpdyHeadersIR headers_ir(/* stream_id = */ 0x7fffffff); + headers_ir.set_fin(true); + headers_ir.set_has_priority(true); + headers_ir.set_weight(220); + headers_ir.set_exclusive(true); + headers_ir.set_parent_stream_id(0); + headers_ir.SetHeader("bar", "foo"); + headers_ir.SetHeader("foo", ""); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers_ir, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kV4FrameData, + ABSL_ARRAYSIZE(kV4FrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header val, FIN, max stream ID, pri, " + "exclusive=false, parent_stream=max stream ID"; + + // frame-format off + const unsigned char kV4FrameData[] = { + 0x00, 0x00, 0x14, // Length: 20 + 0x01, // Type: HEADERS + 0x25, // Flags: END_STREAM|END_HEADERS|PRIORITY + 0x7f, 0xff, 0xff, 0xff, // Stream: 2147483647 + 0x7f, 0xff, 0xff, 0xff, // Parent: 2147483647 + 0xdb, // Weight: 220 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x00, // Value Len: 0 + }; + // frame-format on + SpdyHeadersIR headers_ir(/* stream_id = */ 0x7fffffff); + headers_ir.set_fin(true); + headers_ir.set_has_priority(true); + headers_ir.set_weight(220); + headers_ir.set_exclusive(false); + headers_ir.set_parent_stream_id(0x7fffffff); + headers_ir.SetHeader("bar", "foo"); + headers_ir.SetHeader("foo", ""); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers_ir, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kV4FrameData, + ABSL_ARRAYSIZE(kV4FrameData)); + } + + { + const char kDescription[] = + "HEADERS frame with a 0-length header name, FIN, max stream ID, padded"; + + // frame-format off + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x15, // Length: 21 + 0x01, // Type: HEADERS + 0x0d, // Flags: END_STREAM|END_HEADERS|PADDED + 0x7f, 0xff, 0xff, 0xff, // Stream: 2147483647 + 0x05, // PadLen: 5 trailing bytes + + 0x00, // Unindexed Entry + 0x00, // Name Len: 0 + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + + 0x00, 0x00, 0x00, 0x00, // Padding + 0x00, // Padding + }; + // frame-format on + SpdyHeadersIR headers_ir(/* stream_id = */ 0x7fffffff); + headers_ir.set_fin(true); + headers_ir.SetHeader("", "foo"); + headers_ir.SetHeader("foo", "bar"); + headers_ir.set_padding_len(6); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers_ir, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } +} + +TEST_P(SpdyFramerTest, CreateWindowUpdate) { + { + const char kDescription[] = "WINDOW_UPDATE frame"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x08, // Type: WINDOW_UPDATE + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x01, // Increment: 1 + }; + SpdySerializedFrame frame(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 1, /* delta = */ 1))); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 1, /* delta = */ 1), &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "WINDOW_UPDATE frame with max stream ID"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x08, // Type: WINDOW_UPDATE + 0x00, // Flags: none + 0x7f, 0xff, 0xff, 0xff, // Stream: 0x7fffffff + 0x00, 0x00, 0x00, 0x01, // Increment: 1 + }; + SpdySerializedFrame frame(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 0x7FFFFFFF, /* delta = */ 1))); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 0x7FFFFFFF, /* delta = */ 1), + &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } + + { + const char kDescription[] = "WINDOW_UPDATE frame with max window delta"; + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x08, // Type: WINDOW_UPDATE + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x7f, 0xff, 0xff, 0xff, // Increment: 0x7fffffff + }; + SpdySerializedFrame frame(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 1, /* delta = */ 0x7FFFFFFF))); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 1, /* delta = */ 0x7FFFFFFF), + &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kH2FrameData, + ABSL_ARRAYSIZE(kH2FrameData)); + } +} + +TEST_P(SpdyFramerTest, CreatePushPromiseUncompressed) { + { + // Test framing PUSH_PROMISE without padding. + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + const char kDescription[] = "PUSH_PROMISE frame without padding"; + + // frame-format off + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x16, // Length: 22 + 0x05, // Type: PUSH_PROMISE + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x29, // Stream: 41 + 0x00, 0x00, 0x00, 0x3a, // Promise: 58 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + }; + // frame-format on + + SpdyPushPromiseIR push_promise(/* stream_id = */ 41, + /* promised_stream_id = */ 58); + push_promise.SetHeader("bar", "foo"); + push_promise.SetHeader("foo", "bar"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer, push_promise, use_output_ ? &output_ : nullptr)); + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); + } + + { + // Test framing PUSH_PROMISE with one byte of padding. + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + const char kDescription[] = "PUSH_PROMISE frame with one byte of padding"; + + // frame-format off + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x17, // Length: 23 + 0x05, // Type: PUSH_PROMISE + 0x0c, // Flags: END_HEADERS|PADDED + 0x00, 0x00, 0x00, 0x29, // Stream: 41 + 0x00, // PadLen: 0 trailing bytes + 0x00, 0x00, 0x00, 0x3a, // Promise: 58 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + }; + // frame-format on + + SpdyPushPromiseIR push_promise(/* stream_id = */ 41, + /* promised_stream_id = */ 58); + push_promise.set_padding_len(1); + push_promise.SetHeader("bar", "foo"); + push_promise.SetHeader("foo", "bar"); + output_.Reset(); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer, push_promise, use_output_ ? &output_ : nullptr)); + + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); + } + + { + // Test framing PUSH_PROMISE with 177 bytes of padding. + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + const char kDescription[] = "PUSH_PROMISE frame with 177 bytes of padding"; + + // frame-format off + // clang-format off + const unsigned char kFrameData[] = { + 0x00, 0x00, 0xc7, // Length: 199 + 0x05, // Type: PUSH_PROMISE + 0x0c, // Flags: END_HEADERS|PADDED + 0x00, 0x00, 0x00, 0x2a, // Stream: 42 + 0xb0, // PadLen: 176 trailing bytes + 0x00, 0x00, 0x00, 0x39, // Promise: 57 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + + // Padding of 176 0x00(s). + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + // clang-format on + // frame-format on + + SpdyPushPromiseIR push_promise(/* stream_id = */ 42, + /* promised_stream_id = */ 57); + push_promise.set_padding_len(177); + push_promise.SetHeader("bar", "foo"); + push_promise.SetHeader("foo", "bar"); + output_.Reset(); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer, push_promise, use_output_ ? &output_ : nullptr)); + + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); + } +} + +// Regression test for https://crbug.com/464748. +TEST_P(SpdyFramerTest, GetNumberRequiredContinuationFrames) { + EXPECT_EQ(1u, GetNumberRequiredContinuationFrames(16383 + 16374)); + EXPECT_EQ(2u, GetNumberRequiredContinuationFrames(16383 + 16374 + 1)); + EXPECT_EQ(2u, GetNumberRequiredContinuationFrames(16383 + 2 * 16374)); + EXPECT_EQ(3u, GetNumberRequiredContinuationFrames(16383 + 2 * 16374 + 1)); +} + +TEST_P(SpdyFramerTest, CreateContinuationUncompressed) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + const char kDescription[] = "CONTINUATION frame"; + + // frame-format off + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x12, // Length: 18 + 0x09, // Type: CONTINUATION + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x2a, // Stream: 42 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + }; + // frame-format on + + Http2HeaderBlock header_block; + header_block["bar"] = "foo"; + header_block["foo"] = "bar"; + HpackEncoder encoder; + encoder.DisableCompression(); + std::string buffer = encoder.EncodeHeaderBlock(header_block); + + SpdyContinuationIR continuation(/* stream_id = */ 42); + continuation.take_encoding(std::move(buffer)); + continuation.set_end_headers(true); + + SpdySerializedFrame frame(framer.SerializeContinuation(continuation)); + if (use_output_) { + ASSERT_TRUE(framer.SerializeContinuation(continuation, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +// Test that if we send an unexpected CONTINUATION +// we signal an error (but don't crash). +TEST_P(SpdyFramerTest, SendUnexpectedContinuation) { + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + // frame-format off + char kH2FrameData[] = { + 0x00, 0x00, 0x12, // Length: 18 + 0x09, // Type: CONTINUATION + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x2a, // Stream: 42 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x62, 0x61, 0x72, // bar + 0x03, // Value Len: 3 + 0x66, 0x6f, 0x6f, // foo + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x66, 0x6f, 0x6f, // foo + 0x03, // Value Len: 3 + 0x62, 0x61, 0x72, // bar + }; + // frame-format on + + SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); + + // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(42, 18, 0x9, 0x4)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, _)); + EXPECT_GT(frame.size(), deframer_->ProcessInput(frame.data(), frame.size())); + EXPECT_TRUE(deframer_->HasError()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, CreatePushPromiseThenContinuationUncompressed) { + { + // Test framing in a case such that a PUSH_PROMISE frame, with one byte of + // padding, cannot hold all the data payload, which is overflowed to the + // consecutive CONTINUATION frame. + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + const char kDescription[] = + "PUSH_PROMISE and CONTINUATION frames with one byte of padding"; + + // frame-format off + const unsigned char kPartialPushPromiseFrameData[] = { + 0x00, 0x3f, 0xf6, // Length: 16374 + 0x05, // Type: PUSH_PROMISE + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x2a, // Stream: 42 + 0x00, // PadLen: 0 trailing bytes + 0x00, 0x00, 0x00, 0x39, // Promise: 57 + + 0x00, // Unindexed Entry + 0x03, // Name Len: 3 + 0x78, 0x78, 0x78, // xxx + 0x7f, 0x80, 0x7f, // Value Len: 16361 + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + }; + const unsigned char kContinuationFrameData[] = { + 0x00, 0x00, 0x16, // Length: 22 + 0x09, // Type: CONTINUATION + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x2a, // Stream: 42 + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, 0x78, 0x78, 0x78, // xxxx + 0x78, // x + }; + // frame-format on + + SpdyPushPromiseIR push_promise(/* stream_id = */ 42, + /* promised_stream_id = */ 57); + push_promise.set_padding_len(1); + std::string big_value(kHttp2MaxControlFrameSendSize, 'x'); + push_promise.SetHeader("xxx", big_value); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer, push_promise, use_output_ ? &output_ : nullptr)); + + // The entire frame should look like below: + // Name Length in Byte + // ------------------------------------------- Begin of PUSH_PROMISE frame + // PUSH_PROMISE header 9 + // Pad length field 1 + // Promised stream 4 + // Length field of key 2 + // Content of key 3 + // Length field of value 3 + // Part of big_value 16361 + // ------------------------------------------- Begin of CONTINUATION frame + // CONTINUATION header 9 + // Remaining of big_value 22 + // ------------------------------------------- End + + // Length of everything listed above except big_value. + int len_non_data_payload = 31; + EXPECT_EQ(kHttp2MaxControlFrameSendSize + len_non_data_payload, + frame.size()); + + // Partially compare the PUSH_PROMISE frame against the template. + const unsigned char* frame_data = + reinterpret_cast(frame.data()); + CompareCharArraysWithHexError(kDescription, frame_data, + ABSL_ARRAYSIZE(kPartialPushPromiseFrameData), + kPartialPushPromiseFrameData, + ABSL_ARRAYSIZE(kPartialPushPromiseFrameData)); + + // Compare the CONTINUATION frame against the template. + frame_data += kHttp2MaxControlFrameSendSize; + CompareCharArraysWithHexError( + kDescription, frame_data, ABSL_ARRAYSIZE(kContinuationFrameData), + kContinuationFrameData, ABSL_ARRAYSIZE(kContinuationFrameData)); + } +} + +TEST_P(SpdyFramerTest, CreateAltSvc) { + const char kDescription[] = "ALTSVC frame"; + const unsigned char kType = SerializeFrameType(SpdyFrameType::ALTSVC); + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x49, kType, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x06, 'o', + 'r', 'i', 'g', 'i', 'n', 'p', 'i', 'd', '1', '=', '"', 'h', + 'o', 's', 't', ':', '4', '4', '3', '"', ';', ' ', 'm', 'a', + '=', '5', ',', 'p', '%', '2', '2', '%', '3', 'D', 'i', '%', + '3', 'A', 'd', '=', '"', 'h', '_', '\\', '\\', 'o', '\\', '"', + 's', 't', ':', '1', '2', '3', '"', ';', ' ', 'm', 'a', '=', + '4', '2', ';', ' ', 'v', '=', '"', '2', '4', '"'}; + SpdyAltSvcIR altsvc_ir(/* stream_id = */ 3); + altsvc_ir.set_origin("origin"); + altsvc_ir.add_altsvc(SpdyAltSvcWireFormat::AlternativeService( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector())); + altsvc_ir.add_altsvc(SpdyAltSvcWireFormat::AlternativeService( + "p\"=i:d", "h_\\o\"st", 123, 42, + SpdyAltSvcWireFormat::VersionVector{24})); + SpdySerializedFrame frame(framer_.SerializeFrame(altsvc_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(altsvc_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +TEST_P(SpdyFramerTest, CreatePriority) { + const char kDescription[] = "PRIORITY frame"; + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x05, // Length: 5 + 0x02, // Type: PRIORITY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x02, // Stream: 2 + 0x80, 0x00, 0x00, 0x01, // Parent: 1 (Exclusive) + 0x10, // Weight: 17 + }; + SpdyPriorityIR priority_ir(/* stream_id = */ 2, + /* parent_stream_id = */ 1, + /* weight = */ 17, + /* exclusive = */ true); + SpdySerializedFrame frame(framer_.SerializeFrame(priority_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(priority_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +TEST_P(SpdyFramerTest, CreatePriorityUpdate) { + const char kDescription[] = "PRIORITY_UPDATE frame"; + const unsigned char kType = + SerializeFrameType(SpdyFrameType::PRIORITY_UPDATE); + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x07, // frame length + kType, // frame type + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 for PRIORITY_UPDATE + 0x00, 0x00, 0x00, 0x03, // prioritized stream ID + 'u', '=', '0'}; // priority field value + SpdyPriorityUpdateIR priority_update_ir(/* stream_id = */ 0, + /* prioritized_stream_id = */ 3, + /* priority_field_value = */ "u=0"); + SpdySerializedFrame frame(framer_.SerializeFrame(priority_update_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(priority_update_ir, &output_), + frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +TEST_P(SpdyFramerTest, CreateAcceptCh) { + const char kDescription[] = "ACCEPT_CH frame"; + const unsigned char kType = SerializeFrameType(SpdyFrameType::ACCEPT_CH); + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x2d, // frame length + kType, // frame type + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 for ACCEPT_CH + 0x00, 0x0f, // origin length + 'w', 'w', 'w', '.', 'e', 'x', // origin + 'a', 'm', 'p', 'l', 'e', '.', // + 'c', 'o', 'm', // + 0x00, 0x03, // value length + 'f', 'o', 'o', // value + 0x00, 0x10, // origin length + 'm', 'a', 'i', 'l', '.', 'e', // + 'x', 'a', 'm', 'p', 'l', 'e', // + '.', 'c', 'o', 'm', // + 0x00, 0x03, // value length + 'b', 'a', 'r'}; // value + SpdyAcceptChIR accept_ch_ir( + {{"www.example.com", "foo"}, {"mail.example.com", "bar"}}); + SpdySerializedFrame frame(framer_.SerializeFrame(accept_ch_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(accept_ch_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +TEST_P(SpdyFramerTest, CreateUnknown) { + const char kDescription[] = "Unknown frame"; + const uint8_t kType = 0xaf; + const uint8_t kFlags = 0x11; + const uint8_t kLength = strlen(kDescription); + const unsigned char kFrameData[] = { + 0x00, 0x00, kLength, // Length: 13 + kType, // Type: undefined + kFlags, // Flags: arbitrary, undefined + 0x00, 0x00, 0x00, 0x02, // Stream: 2 + 0x55, 0x6e, 0x6b, 0x6e, // "Unkn" + 0x6f, 0x77, 0x6e, 0x20, // "own " + 0x66, 0x72, 0x61, 0x6d, // "fram" + 0x65, // "e" + }; + SpdyUnknownIR unknown_ir(/* stream_id = */ 2, + /* type = */ kType, + /* flags = */ kFlags, + /* payload = */ kDescription); + SpdySerializedFrame frame(framer_.SerializeFrame(unknown_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(unknown_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +// Test serialization of a SpdyUnknownIR with a defined type, a length field +// that does not match the payload size and in fact exceeds framer limits, and a +// stream ID that effectively flips the reserved bit. +TEST_P(SpdyFramerTest, CreateUnknownUnchecked) { + const char kDescription[] = "Unknown frame"; + const uint8_t kType = 0x00; + const uint8_t kFlags = 0x11; + const uint8_t kLength = std::numeric_limits::max(); + const unsigned int kStreamId = kStreamIdMask + 42; + const unsigned char kFrameData[] = { + 0x00, 0x00, kLength, // Length: 16426 + kType, // Type: DATA, defined + kFlags, // Flags: arbitrary, undefined + 0x80, 0x00, 0x00, 0x29, // Stream: 2147483689 + 0x55, 0x6e, 0x6b, 0x6e, // "Unkn" + 0x6f, 0x77, 0x6e, 0x20, // "own " + 0x66, 0x72, 0x61, 0x6d, // "fram" + 0x65, // "e" + }; + TestSpdyUnknownIR unknown_ir(/* stream_id = */ kStreamId, + /* type = */ kType, + /* flags = */ kFlags, + /* payload = */ kDescription); + unknown_ir.set_length(kLength); + SpdySerializedFrame frame(framer_.SerializeFrame(unknown_ir)); + if (use_output_) { + EXPECT_EQ(framer_.SerializeFrame(unknown_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + CompareFrame(kDescription, frame, kFrameData, ABSL_ARRAYSIZE(kFrameData)); +} + +TEST_P(SpdyFramerTest, ReadCompressedHeadersHeaderBlock) { + SpdyHeadersIR headers_ir(/* stream_id = */ 1); + headers_ir.SetHeader("alpha", "beta"); + headers_ir.SetHeader("gamma", "delta"); + SpdySerializedFrame control_frame(SpdyFramerPeer::SerializeHeaders( + &framer_, headers_ir, use_output_ ? &output_ : nullptr)); + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.control_frame_header_data_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_EQ(0, visitor.end_of_stream_count_); + EXPECT_EQ(headers_ir.header_block(), visitor.headers_); +} + +TEST_P(SpdyFramerTest, ReadCompressedHeadersHeaderBlockWithHalfClose) { + SpdyHeadersIR headers_ir(/* stream_id = */ 1); + headers_ir.set_fin(true); + headers_ir.SetHeader("alpha", "beta"); + headers_ir.SetHeader("gamma", "delta"); + SpdySerializedFrame control_frame(SpdyFramerPeer::SerializeHeaders( + &framer_, headers_ir, use_output_ ? &output_ : nullptr)); + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.control_frame_header_data_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_EQ(1, visitor.end_of_stream_count_); + EXPECT_EQ(headers_ir.header_block(), visitor.headers_); +} + +TEST_P(SpdyFramerTest, TooLargeHeadersFrameUsesContinuation) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.set_padding_len(256); + + // Exact payload length will change with HPACK, but this should be long + // enough to cause an overflow. + const size_t kBigValueSize = kHttp2MaxControlFrameSendSize; + std::string big_value(kBigValueSize, 'x'); + headers.SetHeader("aa", big_value); + SpdySerializedFrame control_frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers, use_output_ ? &output_ : nullptr)); + EXPECT_GT(control_frame.size(), kHttp2MaxControlFrameSendSize); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(1, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); +} + +TEST_P(SpdyFramerTest, MultipleContinuationFramesWithIterator) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + auto headers = std::make_unique(/* stream_id = */ 1); + headers->set_padding_len(256); + + // Exact payload length will change with HPACK, but this should be long + // enough to cause an overflow. + const size_t kBigValueSize = kHttp2MaxControlFrameSendSize; + std::string big_valuex(kBigValueSize, 'x'); + headers->SetHeader("aa", big_valuex); + std::string big_valuez(kBigValueSize, 'z'); + headers->SetHeader("bb", big_valuez); + + SpdyFramer::SpdyHeaderFrameIterator frame_it(&framer, std::move(headers)); + + EXPECT_TRUE(frame_it.HasNextFrame()); + EXPECT_GT(frame_it.NextFrame(&output_), 0u); + SpdySerializedFrame headers_frame(output_.Begin(), output_.Size(), false); + EXPECT_EQ(headers_frame.size(), kHttp2MaxControlFrameSendSize); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(headers_frame.data()), + headers_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + + output_.Reset(); + EXPECT_TRUE(frame_it.HasNextFrame()); + EXPECT_GT(frame_it.NextFrame(&output_), 0u); + SpdySerializedFrame first_cont_frame(output_.Begin(), output_.Size(), false); + EXPECT_EQ(first_cont_frame.size(), kHttp2MaxControlFrameSendSize); + + visitor.SimulateInFramer( + reinterpret_cast(first_cont_frame.data()), + first_cont_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(1, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + + output_.Reset(); + EXPECT_TRUE(frame_it.HasNextFrame()); + EXPECT_GT(frame_it.NextFrame(&output_), 0u); + SpdySerializedFrame second_cont_frame(output_.Begin(), output_.Size(), false); + EXPECT_LT(second_cont_frame.size(), kHttp2MaxControlFrameSendSize); + + visitor.SimulateInFramer( + reinterpret_cast(second_cont_frame.data()), + second_cont_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(2, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + + EXPECT_FALSE(frame_it.HasNextFrame()); +} + +TEST_P(SpdyFramerTest, PushPromiseFramesWithIterator) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + auto push_promise = + std::make_unique(/* stream_id = */ 1, + /* promised_stream_id = */ 2); + push_promise->set_padding_len(256); + + // Exact payload length will change with HPACK, but this should be long + // enough to cause an overflow. + const size_t kBigValueSize = kHttp2MaxControlFrameSendSize; + std::string big_valuex(kBigValueSize, 'x'); + push_promise->SetHeader("aa", big_valuex); + std::string big_valuez(kBigValueSize, 'z'); + push_promise->SetHeader("bb", big_valuez); + + SpdyFramer::SpdyPushPromiseFrameIterator frame_it(&framer, + std::move(push_promise)); + + EXPECT_TRUE(frame_it.HasNextFrame()); + EXPECT_GT(frame_it.NextFrame(&output_), 0u); + SpdySerializedFrame push_promise_frame(output_.Begin(), output_.Size(), + false); + EXPECT_EQ(push_promise_frame.size(), kHttp2MaxControlFrameSendSize); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(push_promise_frame.data()), + push_promise_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.push_promise_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + + EXPECT_TRUE(frame_it.HasNextFrame()); + output_.Reset(); + EXPECT_GT(frame_it.NextFrame(&output_), 0u); + SpdySerializedFrame first_cont_frame(output_.Begin(), output_.Size(), false); + + EXPECT_EQ(first_cont_frame.size(), kHttp2MaxControlFrameSendSize); + visitor.SimulateInFramer( + reinterpret_cast(first_cont_frame.data()), + first_cont_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.push_promise_frame_count_); + EXPECT_EQ(1, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + + EXPECT_TRUE(frame_it.HasNextFrame()); + output_.Reset(); + EXPECT_GT(frame_it.NextFrame(&output_), 0u); + SpdySerializedFrame second_cont_frame(output_.Begin(), output_.Size(), false); + EXPECT_LT(second_cont_frame.size(), kHttp2MaxControlFrameSendSize); + + visitor.SimulateInFramer( + reinterpret_cast(second_cont_frame.data()), + second_cont_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.push_promise_frame_count_); + EXPECT_EQ(2, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + + EXPECT_FALSE(frame_it.HasNextFrame()); +} + +class SpdyControlFrameIteratorTest : public quiche::test::QuicheTest { + public: + SpdyControlFrameIteratorTest() : output_(output_buffer, kSize) {} + + void RunTest(std::unique_ptr ir) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + SpdySerializedFrame frame(framer.SerializeFrame(*ir)); + std::unique_ptr it = + SpdyFramer::CreateIterator(&framer, std::move(ir)); + EXPECT_TRUE(it->HasNextFrame()); + EXPECT_EQ(it->NextFrame(&output_), frame.size()); + EXPECT_FALSE(it->HasNextFrame()); + } + + private: + ArrayOutputBuffer output_; +}; + +TEST_F(SpdyControlFrameIteratorTest, RstStreamFrameWithIterator) { + auto ir = std::make_unique(0, ERROR_CODE_PROTOCOL_ERROR); + RunTest(std::move(ir)); +} + +TEST_F(SpdyControlFrameIteratorTest, SettingsFrameWithIterator) { + auto ir = std::make_unique(); + uint32_t kValue = 0x0a0b0c0d; + SpdyKnownSettingsId kId = SETTINGS_INITIAL_WINDOW_SIZE; + ir->AddSetting(kId, kValue); + RunTest(std::move(ir)); +} + +TEST_F(SpdyControlFrameIteratorTest, PingFrameWithIterator) { + const SpdyPingId kPingId = 0x123456789abcdeffULL; + auto ir = std::make_unique(kPingId); + RunTest(std::move(ir)); +} + +TEST_F(SpdyControlFrameIteratorTest, GoAwayFrameWithIterator) { + auto ir = std::make_unique(0, ERROR_CODE_NO_ERROR, "GA"); + RunTest(std::move(ir)); +} + +TEST_F(SpdyControlFrameIteratorTest, WindowUpdateFrameWithIterator) { + auto ir = std::make_unique(1, 1); + RunTest(std::move(ir)); +} + +TEST_F(SpdyControlFrameIteratorTest, AtlSvcFrameWithIterator) { + auto ir = std::make_unique(3); + ir->set_origin("origin"); + ir->add_altsvc(SpdyAltSvcWireFormat::AlternativeService( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector())); + ir->add_altsvc(SpdyAltSvcWireFormat::AlternativeService( + "p\"=i:d", "h_\\o\"st", 123, 42, + SpdyAltSvcWireFormat::VersionVector{24})); + RunTest(std::move(ir)); +} + +TEST_F(SpdyControlFrameIteratorTest, PriorityFrameWithIterator) { + auto ir = std::make_unique(2, 1, 17, true); + RunTest(std::move(ir)); +} + +TEST_P(SpdyFramerTest, TooLargePushPromiseFrameUsesContinuation) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + SpdyPushPromiseIR push_promise(/* stream_id = */ 1, + /* promised_stream_id = */ 2); + push_promise.set_padding_len(256); + + // Exact payload length will change with HPACK, but this should be long + // enough to cause an overflow. + const size_t kBigValueSize = kHttp2MaxControlFrameSendSize; + std::string big_value(kBigValueSize, 'x'); + push_promise.SetHeader("aa", big_value); + SpdySerializedFrame control_frame(SpdyFramerPeer::SerializePushPromise( + &framer, push_promise, use_output_ ? &output_ : nullptr)); + EXPECT_GT(control_frame.size(), kHttp2MaxControlFrameSendSize); + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_TRUE(visitor.header_buffer_valid_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.push_promise_frame_count_); + EXPECT_EQ(1, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); +} + +// Check that the framer stops delivering header data chunks once the visitor +// declares it doesn't want any more. This is important to guard against +// "zip bomb" types of attacks. +TEST_P(SpdyFramerTest, ControlFrameMuchTooLarge) { + const size_t kHeaderBufferChunks = 4; + const size_t kHeaderBufferSize = + kHttp2DefaultFramePayloadLimit / kHeaderBufferChunks; + const size_t kBigValueSize = kHeaderBufferSize * 2; + std::string big_value(kBigValueSize, 'x'); + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.set_fin(true); + headers.SetHeader("aa", big_value); + SpdySerializedFrame control_frame(SpdyFramerPeer::SerializeHeaders( + &framer_, headers, use_output_ ? &output_ : nullptr)); + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + visitor.set_header_buffer_size(kHeaderBufferSize); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + // It's up to the visitor to ignore extraneous header data; the framer + // won't throw an error. + EXPECT_GT(visitor.header_bytes_received_, visitor.header_buffer_size_); + EXPECT_EQ(1, visitor.end_of_stream_count_); +} + +TEST_P(SpdyFramerTest, ControlFrameSizesAreValidated) { + // Create a GoAway frame that has a few extra bytes at the end. + const size_t length = 20; + + // HTTP/2 GOAWAY frames are only bound by a minimal length, since they may + // carry opaque data. Verify that minimal length is tested. + ASSERT_GT(kGoawayFrameMinimumSize, kFrameHeaderSize); + const size_t less_than_min_length = + kGoawayFrameMinimumSize - kFrameHeaderSize - 1; + ASSERT_LE(less_than_min_length, std::numeric_limits::max()); + const unsigned char kH2Len = static_cast(less_than_min_length); + const unsigned char kH2FrameData[] = { + 0x00, 0x00, kH2Len, // Length: min length - 1 + 0x07, // Type: GOAWAY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x00, 0x00, 0x00, // Last: 0 + 0x00, 0x00, 0x00, // Truncated Status Field + }; + const size_t pad_length = length + kFrameHeaderSize - sizeof(kH2FrameData); + std::string pad(pad_length, 'A'); + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + + visitor.SimulateInFramer(kH2FrameData, sizeof(kH2FrameData)); + visitor.SimulateInFramer(reinterpret_cast(pad.c_str()), + pad.length()); + + EXPECT_EQ(1, visitor.error_count_); // This generated an error. + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(0, visitor.goaway_count_); // Frame not parsed. +} + +TEST_P(SpdyFramerTest, ReadZeroLenSettingsFrame) { + SpdySettingsIR settings_ir; + SpdySerializedFrame control_frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + control_frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + SetFrameLength(&control_frame, 0); + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), kFrameHeaderSize); + // Zero-len settings frames are permitted as of HTTP/2. + EXPECT_EQ(0, visitor.error_count_); +} + +// Tests handling of SETTINGS frames with invalid length. +TEST_P(SpdyFramerTest, ReadBogusLenSettingsFrame) { + SpdySettingsIR settings_ir; + + // Add settings to more than fill the frame so that we don't get a buffer + // overflow when calling SimulateInFramer() below. These settings must be + // distinct parameters because SpdySettingsIR has a map for settings, and + // will collapse multiple copies of the same parameter. + settings_ir.AddSetting(SETTINGS_INITIAL_WINDOW_SIZE, 0x00000002); + settings_ir.AddSetting(SETTINGS_MAX_CONCURRENT_STREAMS, 0x00000002); + SpdySerializedFrame control_frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + control_frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + const size_t kNewLength = 8; + SetFrameLength(&control_frame, kNewLength); + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + kFrameHeaderSize + kNewLength); + // Should generate an error, since its not possible to have a + // settings frame of length kNewLength. + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); +} + +// Tests handling of larger SETTINGS frames. +TEST_P(SpdyFramerTest, ReadLargeSettingsFrame) { + SpdySettingsIR settings_ir; + settings_ir.AddSetting(SETTINGS_HEADER_TABLE_SIZE, 5); + settings_ir.AddSetting(SETTINGS_ENABLE_PUSH, 6); + settings_ir.AddSetting(SETTINGS_MAX_CONCURRENT_STREAMS, 7); + + SpdySerializedFrame control_frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + control_frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + + // Read all at once. + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(3, visitor.setting_count_); + EXPECT_EQ(1, visitor.settings_ack_sent_); + + // Read data in small chunks. + size_t framed_data = 0; + size_t unframed_data = control_frame.size(); + size_t kReadChunkSize = 5; // Read five bytes at a time. + while (unframed_data > 0) { + size_t to_read = std::min(kReadChunkSize, unframed_data); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data() + framed_data), + to_read); + unframed_data -= to_read; + framed_data += to_read; + } + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(3 * 2, visitor.setting_count_); + EXPECT_EQ(2, visitor.settings_ack_sent_); +} + +// Tests handling of SETTINGS frame with duplicate entries. +TEST_P(SpdyFramerTest, ReadDuplicateSettings) { + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x12, // Length: 18 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x01, // Param: HEADER_TABLE_SIZE + 0x00, 0x00, 0x00, 0x02, // Value: 2 + 0x00, 0x01, // Param: HEADER_TABLE_SIZE + 0x00, 0x00, 0x00, 0x03, // Value: 3 + 0x00, 0x03, // Param: MAX_CONCURRENT_STREAMS + 0x00, 0x00, 0x00, 0x03, // Value: 3 + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kH2FrameData, sizeof(kH2FrameData)); + + // In HTTP/2, duplicate settings are allowed; + // each setting replaces the previous value for that setting. + EXPECT_EQ(3, visitor.setting_count_); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.settings_ack_sent_); +} + +// Tests handling of SETTINGS frame with a setting we don't recognize. +TEST_P(SpdyFramerTest, ReadUnknownSettingsId) { + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x06, // Length: 6 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x10, // Param: 16 + 0x00, 0x00, 0x00, 0x02, // Value: 2 + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kH2FrameData, sizeof(kH2FrameData)); + + // In HTTP/2, we ignore unknown settings because of extensions. However, we + // pass the SETTINGS to the visitor, which can decide how to handle them. + EXPECT_EQ(1, visitor.setting_count_); + EXPECT_EQ(0, visitor.error_count_); +} + +TEST_P(SpdyFramerTest, ReadKnownAndUnknownSettingsWithExtension) { + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x18, // Length: 24 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x10, // Param: 16 + 0x00, 0x00, 0x00, 0x02, // Value: 2 + 0x00, 0x5f, // Param: 95 + 0x00, 0x01, 0x00, 0x02, // Value: 65538 + 0x00, 0x02, // Param: ENABLE_PUSH + 0x00, 0x00, 0x00, 0x01, // Value: 1 + 0x00, 0x08, // Param: ENABLE_CONNECT_PROTOCOL + 0x00, 0x00, 0x00, 0x01, // Value: 1 + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + TestExtension extension; + visitor.set_extension_visitor(&extension); + visitor.SimulateInFramer(kH2FrameData, sizeof(kH2FrameData)); + + // In HTTP/2, we ignore unknown settings because of extensions. However, we + // pass the SETTINGS to the visitor, which can decide how to handle them. + EXPECT_EQ(4, visitor.setting_count_); + EXPECT_EQ(0, visitor.error_count_); + + // The extension receives only the non-standard SETTINGS. + EXPECT_THAT( + extension.settings_received_, + testing::ElementsAre(testing::Pair(16, 2), testing::Pair(95, 65538))); +} + +// Tests handling of SETTINGS frame with entries out of order. +TEST_P(SpdyFramerTest, ReadOutOfOrderSettings) { + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x12, // Length: 18 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x02, // Param: ENABLE_PUSH + 0x00, 0x00, 0x00, 0x02, // Value: 2 + 0x00, 0x01, // Param: HEADER_TABLE_SIZE + 0x00, 0x00, 0x00, 0x03, // Value: 3 + 0x00, 0x03, // Param: MAX_CONCURRENT_STREAMS + 0x00, 0x00, 0x00, 0x03, // Value: 3 + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kH2FrameData, sizeof(kH2FrameData)); + + // In HTTP/2, settings are allowed in any order. + EXPECT_EQ(3, visitor.setting_count_); + EXPECT_EQ(0, visitor.error_count_); +} + +TEST_P(SpdyFramerTest, ProcessSettingsAckFrame) { + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x00, // Length: 0 + 0x04, // Type: SETTINGS + 0x01, // Flags: ACK + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(0, visitor.setting_count_); + EXPECT_EQ(1, visitor.settings_ack_received_); +} + +TEST_P(SpdyFramerTest, ProcessDataFrameWithPadding) { + const int kPaddingLen = 119; + const char data_payload[] = "hello"; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + SpdyDataIR data_ir(/* stream_id = */ 1, data_payload); + data_ir.set_padding_len(kPaddingLen); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + + int bytes_consumed = 0; + + // Send the frame header. + EXPECT_CALL(visitor, + OnCommonHeader(1, kPaddingLen + strlen(data_payload), 0x0, 0x8)); + EXPECT_CALL(visitor, + OnDataFrameHeader(1, kPaddingLen + strlen(data_payload), false)); + QUICHE_CHECK_EQ(kDataFrameMinimumSize, + deframer_->ProcessInput(frame.data(), kDataFrameMinimumSize)); + QUICHE_CHECK_EQ(deframer_->state(), + Http2DecoderAdapter::SPDY_READ_DATA_FRAME_PADDING_LENGTH); + QUICHE_CHECK_EQ(deframer_->spdy_framer_error(), + Http2DecoderAdapter::SPDY_NO_ERROR); + bytes_consumed += kDataFrameMinimumSize; + + // Send the padding length field. + EXPECT_CALL(visitor, OnStreamPadLength(1, kPaddingLen - 1)); + QUICHE_CHECK_EQ(1u, + deframer_->ProcessInput(frame.data() + bytes_consumed, 1)); + QUICHE_CHECK_EQ(deframer_->state(), + Http2DecoderAdapter::SPDY_FORWARD_STREAM_FRAME); + QUICHE_CHECK_EQ(deframer_->spdy_framer_error(), + Http2DecoderAdapter::SPDY_NO_ERROR); + bytes_consumed += 1; + + // Send the first two bytes of the data payload, i.e., "he". + EXPECT_CALL(visitor, OnStreamFrameData(1, _, 2)); + QUICHE_CHECK_EQ(2u, + deframer_->ProcessInput(frame.data() + bytes_consumed, 2)); + QUICHE_CHECK_EQ(deframer_->state(), + Http2DecoderAdapter::SPDY_FORWARD_STREAM_FRAME); + QUICHE_CHECK_EQ(deframer_->spdy_framer_error(), + Http2DecoderAdapter::SPDY_NO_ERROR); + bytes_consumed += 2; + + // Send the rest three bytes of the data payload, i.e., "llo". + EXPECT_CALL(visitor, OnStreamFrameData(1, _, 3)); + QUICHE_CHECK_EQ(3u, + deframer_->ProcessInput(frame.data() + bytes_consumed, 3)); + QUICHE_CHECK_EQ(deframer_->state(), + Http2DecoderAdapter::SPDY_CONSUME_PADDING); + QUICHE_CHECK_EQ(deframer_->spdy_framer_error(), + Http2DecoderAdapter::SPDY_NO_ERROR); + bytes_consumed += 3; + + // Send the first 100 bytes of the padding payload. + EXPECT_CALL(visitor, OnStreamPadding(1, 100)); + QUICHE_CHECK_EQ(100u, + deframer_->ProcessInput(frame.data() + bytes_consumed, 100)); + QUICHE_CHECK_EQ(deframer_->state(), + Http2DecoderAdapter::SPDY_CONSUME_PADDING); + QUICHE_CHECK_EQ(deframer_->spdy_framer_error(), + Http2DecoderAdapter::SPDY_NO_ERROR); + bytes_consumed += 100; + + // Send rest of the padding payload. + EXPECT_CALL(visitor, OnStreamPadding(1, 18)); + QUICHE_CHECK_EQ(18u, + deframer_->ProcessInput(frame.data() + bytes_consumed, 18)); + QUICHE_CHECK_EQ(deframer_->state(), + Http2DecoderAdapter::SPDY_READY_FOR_FRAME); + QUICHE_CHECK_EQ(deframer_->spdy_framer_error(), + Http2DecoderAdapter::SPDY_NO_ERROR); +} + +TEST_P(SpdyFramerTest, ReadWindowUpdate) { + SpdySerializedFrame control_frame(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 1, /* delta = */ 2))); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 1, /* delta = */ 2), &output_)); + control_frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_EQ(1u, visitor.last_window_update_stream_); + EXPECT_EQ(2, visitor.last_window_update_delta_); +} + +TEST_P(SpdyFramerTest, ReadCompressedPushPromise) { + SpdyPushPromiseIR push_promise(/* stream_id = */ 42, + /* promised_stream_id = */ 57); + push_promise.SetHeader("foo", "bar"); + push_promise.SetHeader("bar", "foofoo"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer_, push_promise, use_output_ ? &output_ : nullptr)); + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + visitor.SimulateInFramer(reinterpret_cast(frame.data()), + frame.size()); + EXPECT_EQ(42u, visitor.last_push_promise_stream_); + EXPECT_EQ(57u, visitor.last_push_promise_promised_stream_); + EXPECT_EQ(push_promise.header_block(), visitor.headers_); +} + +TEST_P(SpdyFramerTest, ReadHeadersWithContinuation) { + // frame-format off + const unsigned char kInput[] = { + 0x00, 0x00, 0x14, // Length: 20 + 0x01, // Type: HEADERS + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x03, // PadLen: 3 trailing bytes + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', 'o', 'o', 'k', 'i', 'e', // Name + 0x07, // Value Len: 7 + 'f', 'o', 'o', '=', 'b', 'a', 'r', // Value + 0x00, 0x00, 0x00, // Padding + + 0x00, 0x00, 0x14, // Length: 20 + 0x09, // Type: CONTINUATION + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', 'o', 'o', 'k', 'i', 'e', // Name + 0x08, // Value Len: 7 + 'b', 'a', 'z', '=', 'b', 'i', 'n', 'g', // Value + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', // Name (split) + + 0x00, 0x00, 0x12, // Length: 18 + 0x09, // Type: CONTINUATION + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 'o', 'o', 'k', 'i', 'e', // Name (continued) + 0x00, // Value Len: 0 + 0x00, // Unindexed Entry + 0x04, // Name Len: 4 + 'n', 'a', 'm', 'e', // Name + 0x05, // Value Len: 5 + 'v', 'a', 'l', 'u', 'e', // Value + }; + // frame-format on + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(2, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_EQ(0, visitor.end_of_stream_count_); + + EXPECT_THAT( + visitor.headers_, + testing::ElementsAre(testing::Pair("cookie", "foo=bar; baz=bing; "), + testing::Pair("name", "value"))); +} + +TEST_P(SpdyFramerTest, ReadHeadersWithContinuationAndFin) { + // frame-format off + const unsigned char kInput[] = { + 0x00, 0x00, 0x10, // Length: 20 + 0x01, // Type: HEADERS + 0x01, // Flags: END_STREAM + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', 'o', 'o', 'k', 'i', 'e', // Name + 0x07, // Value Len: 7 + 'f', 'o', 'o', '=', 'b', 'a', 'r', // Value + + 0x00, 0x00, 0x14, // Length: 20 + 0x09, // Type: CONTINUATION + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', 'o', 'o', 'k', 'i', 'e', // Name + 0x08, // Value Len: 7 + 'b', 'a', 'z', '=', 'b', 'i', 'n', 'g', // Value + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', // Name (split) + + 0x00, 0x00, 0x12, // Length: 18 + 0x09, // Type: CONTINUATION + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 'o', 'o', 'k', 'i', 'e', // Name (continued) + 0x00, // Value Len: 0 + 0x00, // Unindexed Entry + 0x04, // Name Len: 4 + 'n', 'a', 'm', 'e', // Name + 0x05, // Value Len: 5 + 'v', 'a', 'l', 'u', 'e', // Value + }; + // frame-format on + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(2, visitor.continuation_count_); + EXPECT_EQ(1, visitor.fin_flag_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_EQ(1, visitor.end_of_stream_count_); + + EXPECT_THAT( + visitor.headers_, + testing::ElementsAre(testing::Pair("cookie", "foo=bar; baz=bing; "), + testing::Pair("name", "value"))); +} + +TEST_P(SpdyFramerTest, ReadPushPromiseWithContinuation) { + // frame-format off + const unsigned char kInput[] = { + 0x00, 0x00, 0x17, // Length: 23 + 0x05, // Type: PUSH_PROMISE + 0x08, // Flags: PADDED + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x02, // PadLen: 2 trailing bytes + 0x00, 0x00, 0x00, 0x2a, // Promise: 42 + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', 'o', 'o', 'k', 'i', 'e', // Name + 0x07, // Value Len: 7 + 'f', 'o', 'o', '=', 'b', 'a', 'r', // Value + 0x00, 0x00, // Padding + + 0x00, 0x00, 0x14, // Length: 20 + 0x09, // Type: CONTINUATION + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', 'o', 'o', 'k', 'i', 'e', // Name + 0x08, // Value Len: 7 + 'b', 'a', 'z', '=', 'b', 'i', 'n', 'g', // Value + 0x00, // Unindexed Entry + 0x06, // Name Len: 6 + 'c', // Name (split) + + 0x00, 0x00, 0x12, // Length: 18 + 0x09, // Type: CONTINUATION + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 'o', 'o', 'k', 'i', 'e', // Name (continued) + 0x00, // Value Len: 0 + 0x00, // Unindexed Entry + 0x04, // Name Len: 4 + 'n', 'a', 'm', 'e', // Name + 0x05, // Value Len: 5 + 'v', 'a', 'l', 'u', 'e', // Value + }; + // frame-format on + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1u, visitor.last_push_promise_stream_); + EXPECT_EQ(42u, visitor.last_push_promise_promised_stream_); + EXPECT_EQ(2, visitor.continuation_count_); + EXPECT_EQ(0, visitor.zero_length_control_frame_header_data_count_); + EXPECT_EQ(0, visitor.end_of_stream_count_); + + EXPECT_THAT( + visitor.headers_, + testing::ElementsAre(testing::Pair("cookie", "foo=bar; baz=bing; "), + testing::Pair("name", "value"))); +} + +// Receiving an unknown frame when a continuation is expected should +// result in a SPDY_UNEXPECTED_FRAME error +TEST_P(SpdyFramerTest, ReceiveUnknownMidContinuation) { + const unsigned char kInput[] = { + 0x00, 0x00, 0x10, // Length: 16 + 0x01, // Type: HEADERS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + + 0x00, 0x00, 0x14, // Length: 20 + 0xa9, // Type: UnknownFrameType(169) + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // Payload + 0x6f, 0x6b, 0x69, 0x65, // + 0x08, 0x62, 0x61, 0x7a, // + 0x3d, 0x62, 0x69, 0x6e, // + 0x67, 0x00, 0x06, 0x63, // + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + // Assume the unknown frame is allowed + visitor.on_unknown_frame_result_ = true; + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0u, visitor.header_buffer_length_); +} + +// Receiving an unknown frame when a continuation is expected should +// result in a SPDY_UNEXPECTED_FRAME error +TEST_P(SpdyFramerTest, ReceiveUnknownMidContinuationWithExtension) { + const unsigned char kInput[] = { + 0x00, 0x00, 0x10, // Length: 16 + 0x01, // Type: HEADERS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + + 0x00, 0x00, 0x14, // Length: 20 + 0xa9, // Type: UnknownFrameType(169) + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // Payload + 0x6f, 0x6b, 0x69, 0x65, // + 0x08, 0x62, 0x61, 0x7a, // + 0x3d, 0x62, 0x69, 0x6e, // + 0x67, 0x00, 0x06, 0x63, // + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + TestExtension extension; + visitor.set_extension_visitor(&extension); + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0u, visitor.header_buffer_length_); +} + +TEST_P(SpdyFramerTest, ReceiveContinuationOnWrongStream) { + const unsigned char kInput[] = { + 0x00, 0x00, 0x10, // Length: 16 + 0x01, // Type: HEADERS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + + 0x00, 0x00, 0x14, // Length: 20 + 0x09, // Type: CONTINUATION + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x02, // Stream: 2 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x08, 0x62, 0x61, 0x7a, // + 0x3d, 0x62, 0x69, 0x6e, // + 0x67, 0x00, 0x06, 0x63, // + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0u, visitor.header_buffer_length_); +} + +TEST_P(SpdyFramerTest, ReadContinuationOutOfOrder) { + const unsigned char kInput[] = { + 0x00, 0x00, 0x18, // Length: 24 + 0x09, // Type: CONTINUATION + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0u, visitor.header_buffer_length_); +} + +TEST_P(SpdyFramerTest, ExpectContinuationReceiveData) { + const unsigned char kInput[] = { + 0x00, 0x00, 0x10, // Length: 16 + 0x01, // Type: HEADERS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + + 0x00, 0x00, 0x00, // Length: 0 + 0x00, // Type: DATA + 0x01, // Flags: END_STREAM + 0x00, 0x00, 0x00, 0x04, // Stream: 4 + + 0xde, 0xad, 0xbe, 0xef, // Truncated Frame Header + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0u, visitor.header_buffer_length_); + EXPECT_EQ(0, visitor.data_frame_count_); +} + +TEST_P(SpdyFramerTest, ExpectContinuationReceiveControlFrame) { + const unsigned char kInput[] = { + 0x00, 0x00, 0x10, // Length: 16 + 0x01, // Type: HEADERS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + + 0x00, 0x00, 0x10, // Length: 16 + 0x01, // Type: HEADERS + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x06, 0x63, 0x6f, // HPACK + 0x6f, 0x6b, 0x69, 0x65, // + 0x07, 0x66, 0x6f, 0x6f, // + 0x3d, 0x62, 0x61, 0x72, // + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kInput, sizeof(kInput)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); + EXPECT_EQ(1, visitor.headers_frame_count_); + EXPECT_EQ(0, visitor.continuation_count_); + EXPECT_EQ(0u, visitor.header_buffer_length_); + EXPECT_EQ(0, visitor.data_frame_count_); +} + +TEST_P(SpdyFramerTest, ReadGarbage) { + unsigned char garbage_frame[256]; + memset(garbage_frame, ~0, sizeof(garbage_frame)); + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(garbage_frame, sizeof(garbage_frame)); + EXPECT_EQ(1, visitor.error_count_); +} + +TEST_P(SpdyFramerTest, ReadUnknownExtensionFrame) { + // The unrecognized frame type should still have a valid length. + const unsigned char unknown_frame[] = { + 0x00, 0x00, 0x08, // Length: 8 + 0xff, // Type: UnknownFrameType(255) + 0xff, // Flags: 0xff + 0xff, 0xff, 0xff, 0xff, // Stream: 0x7fffffff (R-bit set) + 0xff, 0xff, 0xff, 0xff, // Payload + 0xff, 0xff, 0xff, 0xff, // + }; + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + + // Simulate the case where the stream id validation checks out. + visitor.on_unknown_frame_result_ = true; + visitor.SimulateInFramer(unknown_frame, ABSL_ARRAYSIZE(unknown_frame)); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.unknown_frame_count_); + EXPECT_EQ(8, visitor.unknown_payload_len_); + + // Follow it up with a valid control frame to make sure we handle + // subsequent frames correctly. + SpdySettingsIR settings_ir; + settings_ir.AddSetting(SETTINGS_HEADER_TABLE_SIZE, 10); + SpdySerializedFrame control_frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + control_frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.setting_count_); + EXPECT_EQ(1, visitor.settings_ack_sent_); +} + +TEST_P(SpdyFramerTest, ReadUnknownExtensionFrameWithExtension) { + // The unrecognized frame type should still have a valid length. + const unsigned char unknown_frame[] = { + 0x00, 0x00, 0x14, // Length: 20 + 0xff, // Type: UnknownFrameType(255) + 0xff, // Flags: 0xff + 0xff, 0xff, 0xff, 0xff, // Stream: 0x7fffffff (R-bit set) + 0xff, 0xff, 0xff, 0xff, // Payload + 0xff, 0xff, 0xff, 0xff, // + 0xff, 0xff, 0xff, 0xff, // + 0xff, 0xff, 0xff, 0xff, // + 0xff, 0xff, 0xff, 0xff, // + }; + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + TestExtension extension; + visitor.set_extension_visitor(&extension); + visitor.SimulateInFramer(unknown_frame, ABSL_ARRAYSIZE(unknown_frame)); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(0x7fffffffu, extension.stream_id_); + EXPECT_EQ(20u, extension.length_); + EXPECT_EQ(255, extension.type_); + EXPECT_EQ(0xff, extension.flags_); + EXPECT_EQ(std::string(20, '\xff'), extension.payload_); + + // Follow it up with a valid control frame to make sure we handle + // subsequent frames correctly. + SpdySettingsIR settings_ir; + settings_ir.AddSetting(SETTINGS_HEADER_TABLE_SIZE, 10); + SpdySerializedFrame control_frame(framer_.SerializeSettings(settings_ir)); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data()), + control_frame.size()); + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.setting_count_); + EXPECT_EQ(1, visitor.settings_ack_sent_); +} + +TEST_P(SpdyFramerTest, ReadGarbageWithValidLength) { + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x08, // Length: 8 + 0xff, // Type: UnknownFrameType(255) + 0xff, // Flags: 0xff + 0xff, 0xff, 0xff, 0xff, // Stream: 0x7fffffff (R-bit set) + 0xff, 0xff, 0xff, 0xff, // Payload + 0xff, 0xff, 0xff, 0xff, // + }; + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, ABSL_ARRAYSIZE(kFrameData)); + EXPECT_EQ(1, visitor.error_count_); +} + +TEST_P(SpdyFramerTest, ReadGarbageHPACKEncoding) { + const unsigned char kInput[] = { + 0x00, 0x12, 0x01, // Length: 4609 + 0x04, // Type: SETTINGS + 0x00, // Flags: none + 0x00, 0x00, 0x01, 0xef, // Stream: 495 + 0xef, 0xff, // Param: 61439 + 0xff, 0xff, 0xff, 0xff, // Value: 4294967295 + 0xff, 0xff, // Param: 0xffff + 0xff, 0xff, 0xff, 0xff, // Value: 4294967295 + 0xff, 0xff, 0xff, 0xff, // Settings (Truncated) + 0xff, // + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kInput, ABSL_ARRAYSIZE(kInput)); + EXPECT_EQ(1, visitor.error_count_); +} + +TEST_P(SpdyFramerTest, SizesTest) { + EXPECT_EQ(9u, kFrameHeaderSize); + EXPECT_EQ(9u, kDataFrameMinimumSize); + EXPECT_EQ(9u, kHeadersFrameMinimumSize); + EXPECT_EQ(14u, kPriorityFrameSize); + EXPECT_EQ(13u, kRstStreamFrameSize); + EXPECT_EQ(9u, kSettingsFrameMinimumSize); + EXPECT_EQ(13u, kPushPromiseFrameMinimumSize); + EXPECT_EQ(17u, kPingFrameSize); + EXPECT_EQ(17u, kGoawayFrameMinimumSize); + EXPECT_EQ(13u, kWindowUpdateFrameSize); + EXPECT_EQ(9u, kContinuationFrameMinimumSize); + EXPECT_EQ(11u, kGetAltSvcFrameMinimumSize); + EXPECT_EQ(9u, kFrameMinimumSize); + + EXPECT_EQ(16384u, kHttp2DefaultFramePayloadLimit); + EXPECT_EQ(16393u, kHttp2DefaultFrameSizeLimit); +} + +TEST_P(SpdyFramerTest, StateToStringTest) { + EXPECT_STREQ("ERROR", Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_ERROR)); + EXPECT_STREQ("FRAME_COMPLETE", Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_FRAME_COMPLETE)); + EXPECT_STREQ("READY_FOR_FRAME", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_READY_FOR_FRAME)); + EXPECT_STREQ("READING_COMMON_HEADER", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_READING_COMMON_HEADER)); + EXPECT_STREQ("CONTROL_FRAME_PAYLOAD", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_CONTROL_FRAME_PAYLOAD)); + EXPECT_STREQ("IGNORE_REMAINING_PAYLOAD", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_IGNORE_REMAINING_PAYLOAD)); + EXPECT_STREQ("FORWARD_STREAM_FRAME", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_FORWARD_STREAM_FRAME)); + EXPECT_STREQ( + "SPDY_CONTROL_FRAME_BEFORE_HEADER_BLOCK", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_CONTROL_FRAME_BEFORE_HEADER_BLOCK)); + EXPECT_STREQ("SPDY_CONTROL_FRAME_HEADER_BLOCK", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_CONTROL_FRAME_HEADER_BLOCK)); + EXPECT_STREQ("SPDY_SETTINGS_FRAME_PAYLOAD", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_SETTINGS_FRAME_PAYLOAD)); + EXPECT_STREQ("SPDY_ALTSVC_FRAME_PAYLOAD", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_ALTSVC_FRAME_PAYLOAD)); + EXPECT_STREQ("UNKNOWN_STATE", + Http2DecoderAdapter::StateToString( + Http2DecoderAdapter::SPDY_ALTSVC_FRAME_PAYLOAD + 1)); +} + +TEST_P(SpdyFramerTest, SpdyFramerErrorToStringTest) { + EXPECT_STREQ("NO_ERROR", Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_NO_ERROR)); + EXPECT_STREQ("INVALID_STREAM_ID", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_INVALID_STREAM_ID)); + EXPECT_STREQ("INVALID_CONTROL_FRAME", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME)); + EXPECT_STREQ("CONTROL_PAYLOAD_TOO_LARGE", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_CONTROL_PAYLOAD_TOO_LARGE)); + EXPECT_STREQ("DECOMPRESS_FAILURE", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_DECOMPRESS_FAILURE)); + EXPECT_STREQ("INVALID_PADDING", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_INVALID_PADDING)); + EXPECT_STREQ("INVALID_DATA_FRAME_FLAGS", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_INVALID_DATA_FRAME_FLAGS)); + EXPECT_STREQ("UNEXPECTED_FRAME", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME)); + EXPECT_STREQ("INTERNAL_FRAMER_ERROR", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_INTERNAL_FRAMER_ERROR)); + EXPECT_STREQ("INVALID_CONTROL_FRAME_SIZE", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE)); + EXPECT_STREQ("OVERSIZED_PAYLOAD", + Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::SPDY_OVERSIZED_PAYLOAD)); + EXPECT_STREQ("UNKNOWN_ERROR", Http2DecoderAdapter::SpdyFramerErrorToString( + Http2DecoderAdapter::LAST_ERROR)); + EXPECT_STREQ("UNKNOWN_ERROR", + Http2DecoderAdapter::SpdyFramerErrorToString( + static_cast( + Http2DecoderAdapter::LAST_ERROR + 1))); +} + +TEST_P(SpdyFramerTest, DataFrameFlagsV4) { + uint8_t valid_data_flags = DATA_FLAG_FIN | DATA_FLAG_PADDED; + + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyDataIR data_ir(/* stream_id = */ 1, "hello"); + SpdySerializedFrame frame(framer_.SerializeData(data_ir)); + SetFrameFlags(&frame, flags); + + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x0, flags)); + if (flags & ~valid_data_flags) { + EXPECT_CALL(visitor, OnError(_, _)); + } else { + EXPECT_CALL(visitor, OnDataFrameHeader(1, 5, flags & DATA_FLAG_FIN)); + if (flags & DATA_FLAG_PADDED) { + // The first byte of payload is parsed as padding length, but 'h' + // (0x68) is too large a padding length for a 5 byte payload. + EXPECT_CALL(visitor, OnStreamPadding(_, 1)); + // Expect Error since the frame ends prematurely. + EXPECT_CALL(visitor, OnError(_, _)); + } else { + EXPECT_CALL(visitor, OnStreamFrameData(_, _, 5)); + if (flags & DATA_FLAG_FIN) { + EXPECT_CALL(visitor, OnStreamEnd(_)); + } + } + } + + deframer_->ProcessInput(frame.data(), frame.size()); + if (flags & ~valid_data_flags) { + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_DATA_FRAME_FLAGS, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + } else if (flags & DATA_FLAG_PADDED) { + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_PADDING, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + } else { + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + } + deframer_ = std::make_unique(); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, RstStreamFrameFlags) { + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + SpdyRstStreamIR rst_stream(/* stream_id = */ 13, ERROR_CODE_CANCEL); + SpdySerializedFrame frame(framer_.SerializeRstStream(rst_stream)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeRstStream(rst_stream, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + SetFrameFlags(&frame, flags); + + EXPECT_CALL(visitor, OnCommonHeader(13, 4, 0x3, flags)); + EXPECT_CALL(visitor, OnRstStream(13, ERROR_CODE_CANCEL)); + + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + deframer_ = std::make_unique(); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, SettingsFrameFlags) { + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + SpdySettingsIR settings_ir; + settings_ir.AddSetting(SETTINGS_INITIAL_WINDOW_SIZE, 16); + SpdySerializedFrame frame(framer_.SerializeSettings(settings_ir)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeSettings(settings_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + SetFrameFlags(&frame, flags); + + EXPECT_CALL(visitor, OnCommonHeader(0, 6, 0x4, flags)); + if (flags & SETTINGS_FLAG_ACK) { + EXPECT_CALL(visitor, OnError(_, _)); + } else { + EXPECT_CALL(visitor, OnSettings()); + EXPECT_CALL(visitor, OnSetting(SETTINGS_INITIAL_WINDOW_SIZE, 16)); + EXPECT_CALL(visitor, OnSettingsEnd()); + } + + deframer_->ProcessInput(frame.data(), frame.size()); + if (flags & SETTINGS_FLAG_ACK) { + // The frame is invalid because ACK frames should have no payload. + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + } else { + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + } + deframer_ = std::make_unique(); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, GoawayFrameFlags) { + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyGoAwayIR goaway_ir(/* last_good_stream_id = */ 97, ERROR_CODE_NO_ERROR, + "test"); + SpdySerializedFrame frame(framer_.SerializeGoAway(goaway_ir)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializeGoAway(goaway_ir, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + SetFrameFlags(&frame, flags); + + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x7, flags)); + EXPECT_CALL(visitor, OnGoAway(97, ERROR_CODE_NO_ERROR)); + EXPECT_CALL(visitor, OnGoAwayFrameData) + .WillRepeatedly(testing::Return(true)); + + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + deframer_ = std::make_unique(); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, HeadersFrameFlags) { + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + Http2DecoderAdapter deframer; + deframer.set_visitor(&visitor); + + SpdyHeadersIR headers_ir(/* stream_id = */ 57); + if (flags & HEADERS_FLAG_PRIORITY) { + headers_ir.set_weight(3); + headers_ir.set_has_priority(true); + headers_ir.set_parent_stream_id(5); + headers_ir.set_exclusive(true); + } + headers_ir.SetHeader("foo", "bar"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializeHeaders( + &framer, headers_ir, use_output_ ? &output_ : nullptr)); + uint8_t set_flags = flags & ~HEADERS_FLAG_PADDED; + SetFrameFlags(&frame, set_flags); + + // Expected callback values + SpdyStreamId stream_id = 57; + bool has_priority = false; + int weight = 0; + SpdyStreamId parent_stream_id = 0; + bool exclusive = false; + bool fin = flags & CONTROL_FLAG_FIN; + bool end = flags & HEADERS_FLAG_END_HEADERS; + if (flags & HEADERS_FLAG_PRIORITY) { + has_priority = true; + weight = 3; + parent_stream_id = 5; + exclusive = true; + } + EXPECT_CALL(visitor, OnCommonHeader(stream_id, _, 0x1, set_flags)); + EXPECT_CALL(visitor, OnHeaders(stream_id, _, has_priority, weight, + parent_stream_id, exclusive, fin, end)); + EXPECT_CALL(visitor, OnHeaderFrameStart(57)).Times(1); + if (end) { + EXPECT_CALL(visitor, OnHeaderFrameEnd(57)).Times(1); + } + if (flags & DATA_FLAG_FIN && end) { + EXPECT_CALL(visitor, OnStreamEnd(_)); + } else { + // Do not close the stream if we are expecting a CONTINUATION frame. + EXPECT_CALL(visitor, OnStreamEnd(_)).Times(0); + } + + deframer.ProcessInput(frame.data(), frame.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer.spdy_framer_error()); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, PingFrameFlags) { + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + SpdySerializedFrame frame(framer_.SerializePing(SpdyPingIR(42))); + SetFrameFlags(&frame, flags); + + EXPECT_CALL(visitor, OnCommonHeader(0, 8, 0x6, flags)); + EXPECT_CALL(visitor, OnPing(42, flags & PING_FLAG_ACK)); + + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + deframer_ = std::make_unique(); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, WindowUpdateFrameFlags) { + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdySerializedFrame frame(framer_.SerializeWindowUpdate( + SpdyWindowUpdateIR(/* stream_id = */ 4, /* delta = */ 1024))); + SetFrameFlags(&frame, flags); + + EXPECT_CALL(visitor, OnCommonHeader(4, 4, 0x8, flags)); + EXPECT_CALL(visitor, OnWindowUpdate(4, 1024)); + + deframer_->ProcessInput(frame.data(), frame.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + deframer_ = std::make_unique(); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, PushPromiseFrameFlags) { + const SpdyStreamId client_id = 123; // Must be odd. + const SpdyStreamId promised_id = 22; // Must be even. + uint8_t flags = 0; + do { + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + testing::StrictMock debug_visitor; + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + Http2DecoderAdapter deframer; + deframer.set_visitor(&visitor); + deframer.set_debug_visitor(&debug_visitor); + framer.set_debug_visitor(&debug_visitor); + + EXPECT_CALL( + debug_visitor, + OnSendCompressedFrame(client_id, SpdyFrameType::PUSH_PROMISE, _, _)); + + SpdyPushPromiseIR push_promise(client_id, promised_id); + push_promise.SetHeader("foo", "bar"); + SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( + &framer, push_promise, use_output_ ? &output_ : nullptr)); + // TODO(jgraettinger): Add padding to SpdyPushPromiseIR, + // and implement framing. + SetFrameFlags(&frame, flags & ~HEADERS_FLAG_PADDED); + + bool end = flags & PUSH_PROMISE_FLAG_END_PUSH_PROMISE; + EXPECT_CALL(debug_visitor, OnReceiveCompressedFrame( + client_id, SpdyFrameType::PUSH_PROMISE, _)); + EXPECT_CALL(visitor, OnCommonHeader(client_id, _, 0x5, + flags & ~HEADERS_FLAG_PADDED)); + EXPECT_CALL(visitor, OnPushPromise(client_id, promised_id, end)); + EXPECT_CALL(visitor, OnHeaderFrameStart(client_id)).Times(1); + if (end) { + EXPECT_CALL(visitor, OnHeaderFrameEnd(client_id)).Times(1); + } + + deframer.ProcessInput(frame.data(), frame.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer.spdy_framer_error()); + } while (++flags != 0); +} + +TEST_P(SpdyFramerTest, ContinuationFrameFlags) { + uint8_t flags = 0; + do { + if (use_output_) { + output_.Reset(); + } + SCOPED_TRACE(testing::Message() + << "Flags " << std::hex << static_cast(flags)); + + testing::StrictMock visitor; + testing::StrictMock debug_visitor; + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + Http2DecoderAdapter deframer; + deframer.set_visitor(&visitor); + deframer.set_debug_visitor(&debug_visitor); + framer.set_debug_visitor(&debug_visitor); + + EXPECT_CALL(debug_visitor, + OnSendCompressedFrame(42, SpdyFrameType::HEADERS, _, _)); + EXPECT_CALL(debug_visitor, + OnReceiveCompressedFrame(42, SpdyFrameType::HEADERS, _)); + EXPECT_CALL(visitor, OnCommonHeader(42, _, 0x1, 0)); + EXPECT_CALL(visitor, OnHeaders(42, _, false, 0, 0, false, false, false)); + EXPECT_CALL(visitor, OnHeaderFrameStart(42)).Times(1); + + SpdyHeadersIR headers_ir(/* stream_id = */ 42); + headers_ir.SetHeader("foo", "bar"); + SpdySerializedFrame frame0; + if (use_output_) { + EXPECT_TRUE(framer.SerializeHeaders(headers_ir, &output_)); + frame0 = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } else { + frame0 = framer.SerializeHeaders(headers_ir); + } + SetFrameFlags(&frame0, 0); + + SpdyContinuationIR continuation(/* stream_id = */ 42); + SpdySerializedFrame frame1; + if (use_output_) { + char* begin = output_.Begin() + output_.Size(); + ASSERT_TRUE(framer.SerializeContinuation(continuation, &output_)); + frame1 = + SpdySerializedFrame(begin, output_.Size() - frame0.size(), false); + } else { + frame1 = framer.SerializeContinuation(continuation); + } + SetFrameFlags(&frame1, flags); + + EXPECT_CALL(debug_visitor, + OnReceiveCompressedFrame(42, SpdyFrameType::CONTINUATION, _)); + EXPECT_CALL(visitor, OnCommonHeader(42, _, 0x9, flags)); + EXPECT_CALL(visitor, + OnContinuation(42, _, flags & HEADERS_FLAG_END_HEADERS)); + bool end = flags & HEADERS_FLAG_END_HEADERS; + if (end) { + EXPECT_CALL(visitor, OnHeaderFrameEnd(42)).Times(1); + } + + deframer.ProcessInput(frame0.data(), frame0.size()); + deframer.ProcessInput(frame1.data(), frame1.size()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer.spdy_framer_error()); + } while (++flags != 0); +} + +// TODO(mlavan): Add TEST_P(SpdyFramerTest, AltSvcFrameFlags) + +// Test handling of a RST_STREAM with out-of-bounds status codes. +TEST_P(SpdyFramerTest, RstStreamStatusBounds) { + const unsigned char kH2RstStreamInvalid[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Error: NO_ERROR + }; + const unsigned char kH2RstStreamNumStatusCodes[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0xff, // Error: 255 + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(1, 4, 0x3, 0x0)); + EXPECT_CALL(visitor, OnRstStream(1, ERROR_CODE_NO_ERROR)); + deframer_->ProcessInput(reinterpret_cast(kH2RstStreamInvalid), + ABSL_ARRAYSIZE(kH2RstStreamInvalid)); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + deframer_ = std::make_unique(); + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(1, 4, 0x3, 0x0)); + EXPECT_CALL(visitor, OnRstStream(1, ERROR_CODE_INTERNAL_ERROR)); + deframer_->ProcessInput( + reinterpret_cast(kH2RstStreamNumStatusCodes), + ABSL_ARRAYSIZE(kH2RstStreamNumStatusCodes)); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Test handling of GOAWAY frames with out-of-bounds status code. +TEST_P(SpdyFramerTest, GoAwayStatusBounds) { + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x0a, // Length: 10 + 0x07, // Type: GOAWAY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x00, 0x00, 0x01, // Last: 1 + 0xff, 0xff, 0xff, 0xff, // Error: 0xffffffff + 0x47, 0x41, // Description + }; + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 10, 0x7, 0x0)); + EXPECT_CALL(visitor, OnGoAway(1, ERROR_CODE_INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnGoAwayFrameData).WillRepeatedly(testing::Return(true)); + deframer_->ProcessInput(reinterpret_cast(kH2FrameData), + ABSL_ARRAYSIZE(kH2FrameData)); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Tests handling of a GOAWAY frame with out-of-bounds stream ID. +TEST_P(SpdyFramerTest, GoAwayStreamIdBounds) { + const unsigned char kH2FrameData[] = { + 0x00, 0x00, 0x08, // Length: 8 + 0x07, // Type: GOAWAY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0xff, 0xff, 0xff, 0xff, // Last: 0x7fffffff (R-bit set) + 0x00, 0x00, 0x00, 0x00, // Error: NO_ERROR + }; + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 8, 0x7, 0x0)); + EXPECT_CALL(visitor, OnGoAway(0x7fffffff, ERROR_CODE_NO_ERROR)); + EXPECT_CALL(visitor, OnGoAwayFrameData).WillRepeatedly(testing::Return(true)); + deframer_->ProcessInput(reinterpret_cast(kH2FrameData), + ABSL_ARRAYSIZE(kH2FrameData)); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, OnAltSvcWithOrigin) { + const SpdyStreamId kStreamId = 0; // Stream id must be zero if origin given. + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyAltSvcWireFormat::AlternativeService altsvc1( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector()); + SpdyAltSvcWireFormat::AlternativeService altsvc2( + "p\"=i:d", "h_\\o\"st", 123, 42, SpdyAltSvcWireFormat::VersionVector{24}); + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + altsvc_vector.push_back(altsvc1); + altsvc_vector.push_back(altsvc2); + EXPECT_CALL(visitor, OnCommonHeader(kStreamId, _, 0x0A, 0x0)); + EXPECT_CALL(visitor, + OnAltSvc(kStreamId, absl::string_view("o_r|g!n"), altsvc_vector)); + + SpdyAltSvcIR altsvc_ir(kStreamId); + altsvc_ir.set_origin("o_r|g!n"); + altsvc_ir.add_altsvc(altsvc1); + altsvc_ir.add_altsvc(altsvc2); + SpdySerializedFrame frame(framer_.SerializeFrame(altsvc_ir)); + if (use_output_) { + output_.Reset(); + EXPECT_EQ(framer_.SerializeFrame(altsvc_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + deframer_->ProcessInput(frame.data(), frame.size()); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, OnAltSvcNoOrigin) { + const SpdyStreamId kStreamId = 1; + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + SpdyAltSvcWireFormat::AlternativeService altsvc1( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector()); + SpdyAltSvcWireFormat::AlternativeService altsvc2( + "p\"=i:d", "h_\\o\"st", 123, 42, SpdyAltSvcWireFormat::VersionVector{24}); + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + altsvc_vector.push_back(altsvc1); + altsvc_vector.push_back(altsvc2); + EXPECT_CALL(visitor, OnCommonHeader(kStreamId, _, 0x0A, 0x0)); + EXPECT_CALL(visitor, + OnAltSvc(kStreamId, absl::string_view(""), altsvc_vector)); + + SpdyAltSvcIR altsvc_ir(kStreamId); + altsvc_ir.add_altsvc(altsvc1); + altsvc_ir.add_altsvc(altsvc2); + SpdySerializedFrame frame(framer_.SerializeFrame(altsvc_ir)); + deframer_->ProcessInput(frame.data(), frame.size()); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, OnAltSvcEmptyProtocolId) { + const SpdyStreamId kStreamId = 0; // Stream id must be zero if origin given. + + testing::StrictMock visitor; + + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(kStreamId, _, 0x0A, 0x0)); + EXPECT_CALL(visitor, + OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, _)); + + SpdyAltSvcIR altsvc_ir(kStreamId); + altsvc_ir.set_origin("o1"); + altsvc_ir.add_altsvc(SpdyAltSvcWireFormat::AlternativeService( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector())); + altsvc_ir.add_altsvc(SpdyAltSvcWireFormat::AlternativeService( + "", "h1", 443, 10, SpdyAltSvcWireFormat::VersionVector())); + SpdySerializedFrame frame(framer_.SerializeFrame(altsvc_ir)); + if (use_output_) { + output_.Reset(); + EXPECT_EQ(framer_.SerializeFrame(altsvc_ir, &output_), frame.size()); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + deframer_->ProcessInput(frame.data(), frame.size()); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, OnAltSvcBadLengths) { + const unsigned char kType = SerializeFrameType(SpdyFrameType::ALTSVC); + const unsigned char kFrameDataOriginLenLargerThanFrame[] = { + 0x00, 0x00, 0x05, kType, 0x00, 0x00, 0x00, + 0x00, 0x03, 0x42, 0x42, 'f', 'o', 'o', + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + + deframer_->set_visitor(&visitor); + visitor.SimulateInFramer(kFrameDataOriginLenLargerThanFrame, + sizeof(kFrameDataOriginLenLargerThanFrame)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, + visitor.deframer_.spdy_framer_error()); +} + +// Tests handling of ALTSVC frames delivered in small chunks. +TEST_P(SpdyFramerTest, ReadChunkedAltSvcFrame) { + SpdyAltSvcIR altsvc_ir(/* stream_id = */ 1); + SpdyAltSvcWireFormat::AlternativeService altsvc1( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector()); + SpdyAltSvcWireFormat::AlternativeService altsvc2( + "p\"=i:d", "h_\\o\"st", 123, 42, SpdyAltSvcWireFormat::VersionVector{24}); + altsvc_ir.add_altsvc(altsvc1); + altsvc_ir.add_altsvc(altsvc2); + + SpdySerializedFrame control_frame(framer_.SerializeAltSvc(altsvc_ir)); + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + + // Read data in small chunks. + size_t framed_data = 0; + size_t unframed_data = control_frame.size(); + size_t kReadChunkSize = 5; // Read five bytes at a time. + while (unframed_data > 0) { + size_t to_read = std::min(kReadChunkSize, unframed_data); + visitor.SimulateInFramer( + reinterpret_cast(control_frame.data() + framed_data), + to_read); + unframed_data -= to_read; + framed_data += to_read; + } + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.altsvc_count_); + ASSERT_NE(nullptr, visitor.test_altsvc_ir_); + ASSERT_EQ(2u, visitor.test_altsvc_ir_->altsvc_vector().size()); + EXPECT_TRUE(visitor.test_altsvc_ir_->altsvc_vector()[0] == altsvc1); + EXPECT_TRUE(visitor.test_altsvc_ir_->altsvc_vector()[1] == altsvc2); +} + +// While RFC7838 Section 4 says that an ALTSVC frame on stream 0 with empty +// origin MUST be ignored, it is not implemented at the framer level: instead, +// such frames are passed on to the consumer. +TEST_P(SpdyFramerTest, ReadAltSvcFrame) { + constexpr struct { + uint32_t stream_id; + const char* origin; + } test_cases[] = {{0, ""}, + {1, ""}, + {0, "https://www.example.com"}, + {1, "https://www.example.com"}}; + for (const auto& test_case : test_cases) { + SpdyAltSvcIR altsvc_ir(test_case.stream_id); + SpdyAltSvcWireFormat::AlternativeService altsvc( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector()); + altsvc_ir.add_altsvc(altsvc); + altsvc_ir.set_origin(test_case.origin); + SpdySerializedFrame frame(framer_.SerializeAltSvc(altsvc_ir)); + + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + deframer_->set_visitor(&visitor); + deframer_->ProcessInput(frame.data(), frame.size()); + + EXPECT_EQ(0, visitor.error_count_); + EXPECT_EQ(1, visitor.altsvc_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); + } +} + +// An ALTSVC frame with invalid Alt-Svc-Field-Value results in an error. +TEST_P(SpdyFramerTest, ErrorOnAltSvcFrameWithInvalidValue) { + // Alt-Svc-Field-Value must be "clear" or must contain an "=" character + // per RFC7838 Section 3. + const char kFrameData[] = { + 0x00, 0x00, 0x16, // Length: 22 + 0x0a, // Type: ALTSVC + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, // Origin-Len: 0 + 0x74, 0x68, 0x69, 0x73, // thisisnotavalidvalue + 0x69, 0x73, 0x6e, 0x6f, 0x74, 0x61, 0x76, 0x61, + 0x6c, 0x69, 0x64, 0x76, 0x61, 0x6c, 0x75, 0x65, + }; + + TestSpdyVisitor visitor(SpdyFramer::ENABLE_COMPRESSION); + deframer_->set_visitor(&visitor); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(1, visitor.error_count_); + EXPECT_EQ(0, visitor.altsvc_count_); + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, + deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +TEST_P(SpdyFramerTest, ReadPriorityUpdateFrame) { + const char kFrameData[] = { + 0x00, 0x00, 0x07, // payload length + 0x10, // frame type PRIORITY_UPDATE + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 + 0x00, 0x00, 0x00, 0x03, // prioritized stream ID, must not be zero + 'f', 'o', 'o' // priority field value + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 7, 0x10, 0x0)); + EXPECT_CALL(visitor, OnPriorityUpdate(3, "foo")); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + EXPECT_FALSE(deframer_->HasError()); +} + +TEST_P(SpdyFramerTest, ReadPriorityUpdateFrameWithEmptyPriorityFieldValue) { + const char kFrameData[] = { + 0x00, 0x00, 0x04, // payload length + 0x10, // frame type PRIORITY_UPDATE + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 + 0x00, 0x00, 0x00, 0x03 // prioritized stream ID, must not be zero + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 4, 0x10, 0x0)); + EXPECT_CALL(visitor, OnPriorityUpdate(3, "")); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + EXPECT_FALSE(deframer_->HasError()); +} + +TEST_P(SpdyFramerTest, PriorityUpdateFrameWithEmptyPayload) { + const char kFrameData[] = { + 0x00, 0x00, 0x00, // payload length + 0x10, // frame type PRIORITY_UPDATE + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 0, 0x10, 0x0)); + EXPECT_CALL(visitor, + OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, _)); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + EXPECT_TRUE(deframer_->HasError()); +} + +TEST_P(SpdyFramerTest, PriorityUpdateFrameWithShortPayload) { + const char kFrameData[] = { + 0x00, 0x00, 0x02, // payload length + 0x10, // frame type PRIORITY_UPDATE + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 + 0x00, 0x01 // payload not long enough to hold 32 bits of prioritized + // stream ID + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 2, 0x10, 0x0)); + EXPECT_CALL(visitor, + OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, _)); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + EXPECT_TRUE(deframer_->HasError()); +} + +TEST_P(SpdyFramerTest, PriorityUpdateFrameOnIncorrectStream) { + const char kFrameData[] = { + 0x00, 0x00, 0x04, // payload length + 0x10, // frame type PRIORITY_UPDATE + 0x00, // flags + 0x00, 0x00, 0x00, 0x01, // invalid stream ID, must be 0 + 0x00, 0x00, 0x00, 0x01, // prioritized stream ID, must not be zero + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(1, 4, 0x10, 0x0)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + EXPECT_TRUE(deframer_->HasError()); +} + +TEST_P(SpdyFramerTest, PriorityUpdateFramePrioritizingIncorrectStream) { + const char kFrameData[] = { + 0x00, 0x00, 0x04, // payload length + 0x10, // frame type PRIORITY_UPDATE + 0x00, // flags + 0x00, 0x00, 0x00, 0x00, // stream ID, must be 0 + 0x00, 0x00, 0x00, 0x00, // prioritized stream ID, must not be zero + }; + + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + + EXPECT_CALL(visitor, OnCommonHeader(0, 4, 0x10, 0x0)); + EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); + deframer_->ProcessInput(kFrameData, sizeof(kFrameData)); + EXPECT_TRUE(deframer_->HasError()); +} + +// Tests handling of PRIORITY frames. +TEST_P(SpdyFramerTest, ReadPriority) { + SpdyPriorityIR priority(/* stream_id = */ 3, + /* parent_stream_id = */ 1, + /* weight = */ 256, + /* exclusive = */ false); + SpdySerializedFrame frame(framer_.SerializePriority(priority)); + if (use_output_) { + output_.Reset(); + ASSERT_TRUE(framer_.SerializePriority(priority, &output_)); + frame = SpdySerializedFrame(output_.Begin(), output_.Size(), false); + } + testing::StrictMock visitor; + deframer_->set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(3, 5, 0x2, 0x0)); + EXPECT_CALL(visitor, OnPriority(3, 1, 256, false)); + deframer_->ProcessInput(frame.data(), frame.size()); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_NO_ERROR, deframer_->spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + deframer_->spdy_framer_error()); +} + +// Tests handling of PRIORITY frame with incorrect size. +TEST_P(SpdyFramerTest, ReadIncorrectlySizedPriority) { + // PRIORITY frame of size 4, which isn't correct. + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x02, // Type: PRIORITY + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x01, // Priority (Truncated) + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, visitor.deframer_.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); +} + +// Tests handling of PING frame with incorrect size. +TEST_P(SpdyFramerTest, ReadIncorrectlySizedPing) { + // PING frame of size 4, which isn't correct. + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x04, // Length: 4 + 0x06, // Type: PING + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x00, // Stream: 0 + 0x00, 0x00, 0x00, 0x01, // Ping (Truncated) + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, visitor.deframer_.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); +} + +// Tests handling of WINDOW_UPDATE frame with incorrect size. +TEST_P(SpdyFramerTest, ReadIncorrectlySizedWindowUpdate) { + // WINDOW_UPDATE frame of size 3, which isn't correct. + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x03, // Length: 3 + 0x08, // Type: WINDOW_UPDATE + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x01, // WindowUpdate (Truncated) + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, visitor.deframer_.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); +} + +// Tests handling of RST_STREAM frame with incorrect size. +TEST_P(SpdyFramerTest, ReadIncorrectlySizedRstStream) { + // RST_STREAM frame of size 3, which isn't correct. + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x03, // Length: 3 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x01, // RstStream (Truncated) + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, visitor.deframer_.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); +} + +// Regression test for https://crbug.com/548674: +// RST_STREAM with payload must not be accepted. +TEST_P(SpdyFramerTest, ReadInvalidRstStreamWithPayload) { + const unsigned char kFrameData[] = { + 0x00, 0x00, 0x07, // Length: 7 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x00, // Error: NO_ERROR + 'f', 'o', 'o' // Payload: "foo" + }; + + TestSpdyVisitor visitor(SpdyFramer::DISABLE_COMPRESSION); + visitor.SimulateInFramer(kFrameData, sizeof(kFrameData)); + + EXPECT_EQ(Http2DecoderAdapter::SPDY_ERROR, visitor.deframer_.state()); + EXPECT_EQ(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, + visitor.deframer_.spdy_framer_error()) + << Http2DecoderAdapter::SpdyFramerErrorToString( + visitor.deframer_.spdy_framer_error()); +} + +// Test that SpdyFramer processes all passed input in one call to ProcessInput. +TEST_P(SpdyFramerTest, ProcessAllInput) { + auto visitor = + std::make_unique(SpdyFramer::DISABLE_COMPRESSION); + deframer_->set_visitor(visitor.get()); + + // Create two input frames. + SpdyHeadersIR headers(/* stream_id = */ 1); + headers.SetHeader("alpha", "beta"); + headers.SetHeader("gamma", "charlie"); + headers.SetHeader("cookie", "key1=value1; key2=value2"); + SpdySerializedFrame headers_frame(SpdyFramerPeer::SerializeHeaders( + &framer_, headers, use_output_ ? &output_ : nullptr)); + + const char four_score[] = "Four score and seven years ago"; + SpdyDataIR four_score_ir(/* stream_id = */ 1, four_score); + SpdySerializedFrame four_score_frame(framer_.SerializeData(four_score_ir)); + + // Put them in a single buffer (new variables here to make it easy to + // change the order and type of frames). + SpdySerializedFrame frame1 = std::move(headers_frame); + SpdySerializedFrame frame2 = std::move(four_score_frame); + + const size_t frame1_size = frame1.size(); + const size_t frame2_size = frame2.size(); + + QUICHE_VLOG(1) << "frame1_size = " << frame1_size; + QUICHE_VLOG(1) << "frame2_size = " << frame2_size; + + std::string input_buffer; + input_buffer.append(frame1.data(), frame1_size); + input_buffer.append(frame2.data(), frame2_size); + + const char* buf = input_buffer.data(); + const size_t buf_size = input_buffer.size(); + + QUICHE_VLOG(1) << "buf_size = " << buf_size; + + size_t processed = deframer_->ProcessInput(buf, buf_size); + EXPECT_EQ(buf_size, processed); + EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_->state()); + EXPECT_EQ(1, visitor->headers_frame_count_); + EXPECT_EQ(1, visitor->data_frame_count_); + EXPECT_EQ(strlen(four_score), static_cast(visitor->data_bytes_)); +} + +namespace { +void CheckFrameAndIRSize(SpdyFrameIR* ir, SpdyFramer* framer, + ArrayOutputBuffer* output_buffer) { + output_buffer->Reset(); + SpdyFrameType type = ir->frame_type(); + size_t ir_size = ir->size(); + framer->SerializeFrame(*ir, output_buffer); + if (type == SpdyFrameType::HEADERS || type == SpdyFrameType::PUSH_PROMISE) { + // For HEADERS and PUSH_PROMISE, the size is an estimate. + EXPECT_GE(ir_size, output_buffer->Size() * 9 / 10); + EXPECT_LT(ir_size, output_buffer->Size() * 11 / 10); + } else { + EXPECT_EQ(ir_size, output_buffer->Size()); + } +} +} // namespace + +TEST_P(SpdyFramerTest, SpdyFrameIRSize) { + SpdyFramer framer(SpdyFramer::DISABLE_COMPRESSION); + + const char bytes[] = "this is a very short data frame"; + SpdyDataIR data_ir(1, absl::string_view(bytes, ABSL_ARRAYSIZE(bytes))); + CheckFrameAndIRSize(&data_ir, &framer, &output_); + + SpdyRstStreamIR rst_ir(/* stream_id = */ 1, ERROR_CODE_PROTOCOL_ERROR); + CheckFrameAndIRSize(&rst_ir, &framer, &output_); + + SpdySettingsIR settings_ir; + settings_ir.AddSetting(SETTINGS_HEADER_TABLE_SIZE, 5); + settings_ir.AddSetting(SETTINGS_ENABLE_PUSH, 6); + settings_ir.AddSetting(SETTINGS_MAX_CONCURRENT_STREAMS, 7); + CheckFrameAndIRSize(&settings_ir, &framer, &output_); + + SpdyPingIR ping_ir(42); + CheckFrameAndIRSize(&ping_ir, &framer, &output_); + + SpdyGoAwayIR goaway_ir(97, ERROR_CODE_NO_ERROR, "Goaway description"); + CheckFrameAndIRSize(&goaway_ir, &framer, &output_); + + SpdyHeadersIR headers_ir(1); + headers_ir.SetHeader("alpha", "beta"); + headers_ir.SetHeader("gamma", "charlie"); + headers_ir.SetHeader("cookie", "key1=value1; key2=value2"); + CheckFrameAndIRSize(&headers_ir, &framer, &output_); + + SpdyHeadersIR headers_ir_with_continuation(1); + headers_ir_with_continuation.SetHeader("alpha", std::string(100000, 'x')); + headers_ir_with_continuation.SetHeader("beta", std::string(100000, 'x')); + headers_ir_with_continuation.SetHeader("cookie", "key1=value1; key2=value2"); + CheckFrameAndIRSize(&headers_ir_with_continuation, &framer, &output_); + + SpdyWindowUpdateIR window_update_ir(4, 1024); + CheckFrameAndIRSize(&window_update_ir, &framer, &output_); + + SpdyPushPromiseIR push_promise_ir(3, 8); + push_promise_ir.SetHeader("alpha", std::string(100000, 'x')); + push_promise_ir.SetHeader("beta", std::string(100000, 'x')); + push_promise_ir.SetHeader("cookie", "key1=value1; key2=value2"); + CheckFrameAndIRSize(&push_promise_ir, &framer, &output_); + + SpdyAltSvcWireFormat::AlternativeService altsvc1( + "pid1", "host", 443, 5, SpdyAltSvcWireFormat::VersionVector()); + SpdyAltSvcWireFormat::AlternativeService altsvc2( + "p\"=i:d", "h_\\o\"st", 123, 42, SpdyAltSvcWireFormat::VersionVector{24}); + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + altsvc_vector.push_back(altsvc1); + altsvc_vector.push_back(altsvc2); + SpdyAltSvcIR altsvc_ir(0); + altsvc_ir.set_origin("o_r|g!n"); + altsvc_ir.add_altsvc(altsvc1); + altsvc_ir.add_altsvc(altsvc2); + CheckFrameAndIRSize(&altsvc_ir, &framer, &output_); + + SpdyPriorityIR priority_ir(3, 1, 256, false); + CheckFrameAndIRSize(&priority_ir, &framer, &output_); + + const char kDescription[] = "Unknown frame"; + const uint8_t kType = 0xaf; + const uint8_t kFlags = 0x11; + SpdyUnknownIR unknown_ir(2, kType, kFlags, kDescription); + CheckFrameAndIRSize(&unknown_ir, &framer, &output_); +} + +} // namespace test + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_headers_handler_interface.h b/quiche/spdy/core/spdy_headers_handler_interface.h new file mode 100644 index 000000000000..629874123b6d --- /dev/null +++ b/quiche/spdy/core/spdy_headers_handler_interface.h @@ -0,0 +1,39 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_HEADERS_HANDLER_INTERFACE_H_ +#define QUICHE_SPDY_CORE_SPDY_HEADERS_HANDLER_INTERFACE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +// This interface defines how an object that accepts header data should behave. +// It is used by both SpdyHeadersBlockParser and HpackDecoder. +class QUICHE_EXPORT SpdyHeadersHandlerInterface { + public: + virtual ~SpdyHeadersHandlerInterface() {} + + // A callback method which notifies when the parser starts handling a new + // header block. Will only be called once per block, even if it extends into + // CONTINUATION frames. + virtual void OnHeaderBlockStart() = 0; + + // A callback method which notifies on a header key value pair. Multiple + // values for a given key will be emitted as multiple calls to OnHeader. + virtual void OnHeader(absl::string_view key, absl::string_view value) = 0; + + // A callback method which notifies when the parser finishes handling a + // header block (i.e. the containing frame has the END_HEADERS flag set). + // Also indicates the total number of bytes in this block. + virtual void OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t compressed_header_bytes) = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_HEADERS_HANDLER_INTERFACE_H_ diff --git a/quiche/spdy/core/spdy_intrusive_list.h b/quiche/spdy/core/spdy_intrusive_list.h new file mode 100644 index 000000000000..bdf1614e7a91 --- /dev/null +++ b/quiche/spdy/core/spdy_intrusive_list.h @@ -0,0 +1,341 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_INTRUSIVE_LIST_H_ +#define QUICHE_SPDY_CORE_SPDY_INTRUSIVE_LIST_H_ + +// A SpdyIntrusiveList<> is a doubly-linked list where the link pointers are +// embedded in the elements. They are circularly linked making insertion and +// removal into a known position constant time and branch-free operations. +// +// Usage is similar to an STL list<> where feasible, but there are important +// differences. First and foremost, the elements must derive from the +// SpdyIntrusiveLink<> base class: +// +// struct Foo : public SpdyIntrusiveLink { +// // ... +// } +// +// SpdyIntrusiveList l; +// l.push_back(new Foo); +// l.push_front(new Foo); +// l.erase(&l.front()); +// l.erase(&l.back()); +// +// Intrusive lists are primarily useful when you would have considered embedding +// link pointers in your class directly for space or performance reasons. An +// SpdyIntrusiveLink<> is the size of 2 pointers, usually 16 bytes on 64-bit +// systems. Intrusive lists do not perform memory allocation (unlike the STL +// list<> class) and thus may use less memory than list<>. In particular, if the +// list elements are pointers to objects, using a list<> would perform an extra +// memory allocation for each list node structure, while an SpdyIntrusiveList<> +// would not. +// +// Note that SpdyIntrusiveLink is exempt from the C++ style guide's limitations +// on multiple inheritance, so it's fine to inherit from both SpdyIntrusiveLink +// and a base class, even if the base class is not a pure interface. +// +// Because the list pointers are embedded in the objects stored in an +// SpdyIntrusiveList<>, erasing an item from a list is constant time. Consider +// the following: +// +// map foo_map; +// list foo_list; +// +// foo_list.push_back(&foo_map["bar"]); +// foo_list.erase(&foo_map["bar"]); // Compile error! +// +// The problem here is that a Foo* doesn't know where on foo_list it resides, +// so removal requires iteration over the list. Various tricks can be performed +// to overcome this. For example, a foo_list::iterator can be stored inside of +// the Foo object. But at that point you'd be better off using an +// SpdyIntrusiveList<>: +// +// map foo_map; +// SpdyIntrusiveList foo_list; +// +// foo_list.push_back(&foo_map["bar"]); +// foo_list.erase(&foo_map["bar"]); // Yeah! +// +// Note that SpdyIntrusiveLists come with a few limitations. The primary +// limitation is that the SpdyIntrusiveLink<> base class is not copyable or +// assignable. The result is that STL algorithms which mutate the order of +// iterators, such as reverse() and unique(), will not work by default with +// SpdyIntrusiveLists. In order to allow these algorithms to work you'll need to +// define swap() and/or operator= for your class. +// +// Another limitation is that the SpdyIntrusiveList<> structure itself is not +// copyable or assignable since an item/link combination can only exist on one +// SpdyIntrusiveList<> at a time. This limitation is a result of the link +// pointers for an item being intrusive in the item itself. For example, the +// following will not compile: +// +// FooList a; +// FooList b(a); // no copy constructor +// b = a; // no assignment operator +// +// The similar STL code does work since the link pointers are external to the +// item: +// +// list a; +// a.push_back(new int); +// list b(a); +// QUICHE_CHECK(a.front() == b.front()); +// +// Note that SpdyIntrusiveList::size() runs in O(N) time. + +#include + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +template +class SpdyIntrusiveList; + +template +class QUICHE_EXPORT SpdyIntrusiveLink { + protected: + // We declare the constructor protected so that only derived types and the + // befriended list can construct this. + SpdyIntrusiveLink() : next_(nullptr), prev_(nullptr) {} + +#ifndef SWIG + SpdyIntrusiveLink(const SpdyIntrusiveLink&) = delete; + SpdyIntrusiveLink& operator=(const SpdyIntrusiveLink&) = delete; +#endif // SWIG + + private: + // We befriend the matching list type so that it can manipulate the links + // while they are kept private from others. + friend class SpdyIntrusiveList; + + // Encapsulates the logic to convert from a link to its derived type. + T* cast_to_derived() { return static_cast(this); } + const T* cast_to_derived() const { return static_cast(this); } + + SpdyIntrusiveLink* next_; + SpdyIntrusiveLink* prev_; +}; + +template +class QUICHE_EXPORT SpdyIntrusiveList { + template + class iterator_impl; + + public: + typedef T value_type; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef value_type& reference; + typedef const value_type& const_reference; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + typedef SpdyIntrusiveLink link_type; + typedef iterator_impl iterator; + typedef iterator_impl const_iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + + SpdyIntrusiveList() { clear(); } + // After the move constructor the moved-from list will be empty. + // + // NOTE: There is no move assign operator (for now). + // The reason is that at the moment 'clear()' does not unlink the nodes. + // It makes is_linked() return true when it should return false. + // If such node is removed from the list (e.g. from its destructor), or is + // added to another list - a memory corruption will occur. + // Admitedly the destructor does not unlink the nodes either, but move-assign + // will likely make the problem more prominent. +#ifndef SWIG + SpdyIntrusiveList(SpdyIntrusiveList&& src) noexcept { + clear(); + if (src.empty()) return; + sentinel_link_.next_ = src.sentinel_link_.next_; + sentinel_link_.prev_ = src.sentinel_link_.prev_; + // Fix head and tail nodes of the list. + sentinel_link_.prev_->next_ = &sentinel_link_; + sentinel_link_.next_->prev_ = &sentinel_link_; + src.clear(); + } +#endif // SWIG + + iterator begin() { return iterator(sentinel_link_.next_); } + const_iterator begin() const { return const_iterator(sentinel_link_.next_); } + iterator end() { return iterator(&sentinel_link_); } + const_iterator end() const { return const_iterator(&sentinel_link_); } + reverse_iterator rbegin() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + bool empty() const { return (sentinel_link_.next_ == &sentinel_link_); } + // This runs in O(N) time. + size_type size() const { return std::distance(begin(), end()); } + size_type max_size() const { return size_type(-1); } + + reference front() { return *begin(); } + const_reference front() const { return *begin(); } + reference back() { return *(--end()); } + const_reference back() const { return *(--end()); } + + static iterator insert(iterator position, T* obj) { + return insert_link(position.link(), obj); + } + void push_front(T* obj) { insert(begin(), obj); } + void push_back(T* obj) { insert(end(), obj); } + + static iterator erase(T* obj) { + link_type* obj_link = obj; + // Fix up the next and previous links for the previous and next objects. + obj_link->next_->prev_ = obj_link->prev_; + obj_link->prev_->next_ = obj_link->next_; + // Zero out the next and previous links for the removed item. This will + // cause any future attempt to remove the item from the list to cause a + // crash instead of possibly corrupting the list structure. + link_type* next_link = obj_link->next_; + obj_link->next_ = nullptr; + obj_link->prev_ = nullptr; + return iterator(next_link); + } + + static iterator erase(iterator position) { + return erase(position.operator->()); + } + void pop_front() { erase(begin()); } + void pop_back() { erase(--end()); } + + // Check whether the given element is linked into some list. Note that this + // does *not* check whether it is linked into a particular list. + // Also, if clear() is used to clear the containing list, is_linked() will + // still return true even though obj is no longer in any list. + static bool is_linked(const T* obj) { + return obj->link_type::next_ != nullptr; + } + + void clear() { + sentinel_link_.next_ = sentinel_link_.prev_ = &sentinel_link_; + } + void swap(SpdyIntrusiveList& x) { + SpdyIntrusiveList tmp; + tmp.splice(tmp.begin(), *this); + this->splice(this->begin(), x); + x.splice(x.begin(), tmp); + } + + void splice(iterator pos, SpdyIntrusiveList& src) { + splice(pos, src.begin(), src.end()); + } + + void splice(iterator pos, iterator i) { splice(pos, i, std::next(i)); } + + void splice(iterator pos, iterator first, iterator last) { + if (first == last) return; + + link_type* const last_prev = last.link()->prev_; + + // Remove from the source. + first.link()->prev_->next_ = last.operator->(); + last.link()->prev_ = first.link()->prev_; + + // Attach to the destination. + first.link()->prev_ = pos.link()->prev_; + pos.link()->prev_->next_ = first.operator->(); + last_prev->next_ = pos.operator->(); + pos.link()->prev_ = last_prev; + } + + private: + static iterator insert_link(link_type* next_link, T* obj) { + link_type* obj_link = obj; + obj_link->next_ = next_link; + link_type* const initial_next_prev = next_link->prev_; + obj_link->prev_ = initial_next_prev; + initial_next_prev->next_ = obj_link; + next_link->prev_ = obj_link; + return iterator(obj_link); + } + + // The iterator implementation is parameterized on a potentially qualified + // variant of T and the matching qualified link type. Essentially, QualifiedT + // will either be 'T' or 'const T', the latter for a const_iterator. + template + class QUICHE_EXPORT iterator_impl { + public: + using iterator_category = std::bidirectional_iterator_tag; + using value_type = QualifiedT; + using difference_type = std::ptrdiff_t; + using pointer = QualifiedT*; + using reference = QualifiedT&; + + iterator_impl() = default; + iterator_impl(QualifiedLinkT* link) : link_(link) {} + iterator_impl(const iterator_impl& x) = default; + iterator_impl& operator=(const iterator_impl& x) = default; + + // Allow converting and comparing across iterators where the pointer + // assignment and comparisons (respectively) are allowed. + template + iterator_impl(const iterator_impl& x) : link_(x.link_) {} + template + bool operator==(const iterator_impl& x) const { + return link_ == x.link_; + } + template + bool operator!=(const iterator_impl& x) const { + return link_ != x.link_; + } + + reference operator*() const { return *operator->(); } + pointer operator->() const { return link_->cast_to_derived(); } + + QualifiedLinkT* link() const { return link_; } + +#ifndef SWIG // SWIG can't wrap these operator overloads. + iterator_impl& operator++() { + link_ = link_->next_; + return *this; + } + iterator_impl operator++(int /*unused*/) { + iterator_impl tmp = *this; + ++*this; + return tmp; + } + iterator_impl& operator--() { + link_ = link_->prev_; + return *this; + } + iterator_impl operator--(int /*unused*/) { + iterator_impl tmp = *this; + --*this; + return tmp; + } +#endif // SWIG + + private: + // Ensure iterators can access other iterators node directly. + template + friend class iterator_impl; + + QualifiedLinkT* link_ = nullptr; + }; + + // This bare link acts as the sentinel node. + link_type sentinel_link_; + + // These are private and undefined to prevent copying and assigning. + SpdyIntrusiveList(const SpdyIntrusiveList&); + void operator=(const SpdyIntrusiveList&); +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_INTRUSIVE_LIST_H_ diff --git a/quiche/spdy/core/spdy_intrusive_list_test.cc b/quiche/spdy/core/spdy_intrusive_list_test.cc new file mode 100644 index 000000000000..ad916250ff26 --- /dev/null +++ b/quiche/spdy/core/spdy_intrusive_list_test.cc @@ -0,0 +1,420 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_intrusive_list.h" + +#include +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { +namespace test { + +struct ListId2 {}; + +struct TestItem : public SpdyIntrusiveLink, + public SpdyIntrusiveLink { + int n; +}; +typedef SpdyIntrusiveList TestList; +typedef std::list CanonicalList; + +void swap(TestItem &a, TestItem &b) { + using std::swap; + swap(a.n, b.n); +} + +class IntrusiveListTest : public quiche::test::QuicheTest { + protected: + void CheckLists() { + CheckLists(l1, ll1); + if (quiche::test::QuicheTest::HasFailure()) return; + CheckLists(l2, ll2); + } + + void CheckLists(const TestList &list_a, const CanonicalList &list_b) { + ASSERT_EQ(list_a.size(), list_b.size()); + TestList::const_iterator it_a = list_a.begin(); + CanonicalList::const_iterator it_b = list_b.begin(); + while (it_a != list_a.end()) { + EXPECT_EQ(&*it_a++, *it_b++); + } + EXPECT_EQ(list_a.end(), it_a); + EXPECT_EQ(list_b.end(), it_b); + } + + void PrepareLists(int num_elems_1, int num_elems_2 = 0) { + FillLists(&l1, &ll1, e, num_elems_1); + FillLists(&l2, &ll2, e + num_elems_1, num_elems_2); + } + + void FillLists(TestList *list_a, CanonicalList *list_b, TestItem *elems, + int num_elems) { + list_a->clear(); + list_b->clear(); + for (int i = 0; i < num_elems; ++i) { + list_a->push_back(elems + i); + list_b->push_back(elems + i); + } + CheckLists(*list_a, *list_b); + } + + TestItem e[10]; + TestList l1, l2; + CanonicalList ll1, ll2; +}; + +TEST(NewIntrusiveListTest, Basic) { + TestList list1; + + EXPECT_EQ(sizeof(SpdyIntrusiveLink), sizeof(void *) * 2); + + for (int i = 0; i < 10; ++i) { + TestItem *e = new TestItem; + e->n = i; + list1.push_front(e); + } + EXPECT_EQ(list1.size(), 10u); + + // Verify we can reverse a list because we defined swap for TestItem. + std::reverse(list1.begin(), list1.end()); + EXPECT_EQ(list1.size(), 10u); + + // Check both const and non-const forward iteration. + const TestList &clist1 = list1; + int i = 0; + TestList::iterator iter = list1.begin(); + for (; iter != list1.end(); ++iter, ++i) { + EXPECT_EQ(iter->n, i); + } + EXPECT_EQ(iter, clist1.end()); + EXPECT_NE(iter, clist1.begin()); + i = 0; + iter = list1.begin(); + for (; iter != list1.end(); ++iter, ++i) { + EXPECT_EQ(iter->n, i); + } + EXPECT_EQ(iter, clist1.end()); + EXPECT_NE(iter, clist1.begin()); + + EXPECT_EQ(list1.front().n, 0); + EXPECT_EQ(list1.back().n, 9); + + // Verify we can swap 2 lists. + TestList list2; + list2.swap(list1); + EXPECT_EQ(list1.size(), 0u); + EXPECT_EQ(list2.size(), 10u); + + // Check both const and non-const reverse iteration. + const TestList &clist2 = list2; + TestList::reverse_iterator riter = list2.rbegin(); + i = 9; + for (; riter != list2.rend(); ++riter, --i) { + EXPECT_EQ(riter->n, i); + } + EXPECT_EQ(riter, clist2.rend()); + EXPECT_NE(riter, clist2.rbegin()); + + riter = list2.rbegin(); + i = 9; + for (; riter != list2.rend(); ++riter, --i) { + EXPECT_EQ(riter->n, i); + } + EXPECT_EQ(riter, clist2.rend()); + EXPECT_NE(riter, clist2.rbegin()); + + while (!list2.empty()) { + TestItem *e = &list2.front(); + list2.pop_front(); + delete e; + } +} + +TEST(NewIntrusiveListTest, Erase) { + TestList l; + TestItem *e[10]; + + // Create a list with 10 items. + for (int i = 0; i < 10; ++i) { + e[i] = new TestItem; + l.push_front(e[i]); + } + + // Test that erase works. + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(l.size(), (10u - i)); + + TestList::iterator iter = l.erase(e[i]); + EXPECT_NE(iter, TestList::iterator(e[i])); + + EXPECT_EQ(l.size(), (10u - i - 1)); + delete e[i]; + } +} + +TEST(NewIntrusiveListTest, Insert) { + TestList l; + TestList::iterator iter = l.end(); + TestItem *e[10]; + + // Create a list with 10 items. + for (int i = 9; i >= 0; --i) { + e[i] = new TestItem; + iter = l.insert(iter, e[i]); + EXPECT_EQ(&(*iter), e[i]); + } + + EXPECT_EQ(l.size(), 10u); + + // Verify insertion order. + iter = l.begin(); + for (TestItem *item : e) { + EXPECT_EQ(&(*iter), item); + iter = l.erase(item); + delete item; + } +} + +TEST(NewIntrusiveListTest, Move) { + // Move contructible. + + { // Move-construct from an empty list. + TestList src; + TestList dest(std::move(src)); + EXPECT_TRUE(dest.empty()); + } + + { // Move-construct from a single item list. + TestItem e; + TestList src; + src.push_front(&e); + + TestList dest(std::move(src)); + EXPECT_TRUE(src.empty()); // NOLINT bugprone-use-after-move + ASSERT_THAT(dest.size(), 1); + EXPECT_THAT(&dest.front(), &e); + EXPECT_THAT(&dest.back(), &e); + } + + { // Move-construct from a list with multiple items. + TestItem items[10]; + TestList src; + for (TestItem &e : items) src.push_back(&e); + + TestList dest(std::move(src)); + EXPECT_TRUE(src.empty()); // NOLINT bugprone-use-after-move + // Verify the items on the destination list. + ASSERT_THAT(dest.size(), 10); + int i = 0; + for (TestItem &e : dest) { + EXPECT_THAT(&e, &items[i++]) << " for index " << i; + } + } +} + +TEST(NewIntrusiveListTest, StaticInsertErase) { + TestList l; + TestItem e[2]; + TestList::iterator i = l.begin(); + TestList::insert(i, &e[0]); + TestList::insert(&e[0], &e[1]); + TestList::erase(&e[0]); + TestList::erase(TestList::iterator(&e[1])); + EXPECT_TRUE(l.empty()); +} + +TEST_F(IntrusiveListTest, Splice) { + // We verify that the contents of this secondary list aren't affected by any + // of the splices. + SpdyIntrusiveList secondary_list; + for (int i = 0; i < 3; ++i) { + secondary_list.push_back(&e[i]); + } + + // Test the basic cases: + // - The lists range from 0 to 2 elements. + // - The insertion point ranges from begin() to end() + // - The transfered range has multiple sizes and locations in the source. + for (int l1_count = 0; l1_count < 3; ++l1_count) { + for (int l2_count = 0; l2_count < 3; ++l2_count) { + for (int pos = 0; pos <= l1_count; ++pos) { + for (int first = 0; first <= l2_count; ++first) { + for (int last = first; last <= l2_count; ++last) { + PrepareLists(l1_count, l2_count); + + l1.splice(std::next(l1.begin(), pos), std::next(l2.begin(), first), + std::next(l2.begin(), last)); + ll1.splice(std::next(ll1.begin(), pos), ll2, + std::next(ll2.begin(), first), + std::next(ll2.begin(), last)); + + CheckLists(); + + ASSERT_EQ(3u, secondary_list.size()); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(&e[i], &*std::next(secondary_list.begin(), i)); + } + } + } + } + } + } +} + +// Build up a set of classes which form "challenging" type hierarchies to use +// with an SpdyIntrusiveList. +struct BaseLinkId {}; +struct DerivedLinkId {}; + +struct AbstractBase : public SpdyIntrusiveLink { + virtual ~AbstractBase() = 0; + virtual std::string name() { return "AbstractBase"; } +}; +AbstractBase::~AbstractBase() {} +struct DerivedClass : public SpdyIntrusiveLink, + public AbstractBase { + ~DerivedClass() override {} + std::string name() override { return "DerivedClass"; } +}; +struct VirtuallyDerivedBaseClass : public virtual AbstractBase { + ~VirtuallyDerivedBaseClass() override = 0; + std::string name() override { return "VirtuallyDerivedBaseClass"; } +}; +VirtuallyDerivedBaseClass::~VirtuallyDerivedBaseClass() {} +struct VirtuallyDerivedClassA + : public SpdyIntrusiveLink, + public virtual VirtuallyDerivedBaseClass { + ~VirtuallyDerivedClassA() override {} + std::string name() override { return "VirtuallyDerivedClassA"; } +}; +struct NonceClass { + virtual ~NonceClass() {} + int data_; +}; +struct VirtuallyDerivedClassB + : public SpdyIntrusiveLink, + public virtual NonceClass, + public virtual VirtuallyDerivedBaseClass { + ~VirtuallyDerivedClassB() override {} + std::string name() override { return "VirtuallyDerivedClassB"; } +}; +struct VirtuallyDerivedClassC + : public SpdyIntrusiveLink, + public virtual AbstractBase, + public virtual NonceClass, + public virtual VirtuallyDerivedBaseClass { + ~VirtuallyDerivedClassC() override {} + std::string name() override { return "VirtuallyDerivedClassC"; } +}; + +// Test for multiple layers between the element type and the link. +namespace templated_base_link { +template +struct AbstractBase : public SpdyIntrusiveLink { + virtual ~AbstractBase() = 0; +}; +template +AbstractBase::~AbstractBase() {} +struct DerivedClass : public AbstractBase { + int n; +}; +} // namespace templated_base_link + +TEST(NewIntrusiveListTest, HandleInheritanceHierarchies) { + { + SpdyIntrusiveList list; + DerivedClass elements[2]; + EXPECT_TRUE(list.empty()); + list.push_back(&elements[0]); + EXPECT_EQ(1u, list.size()); + list.push_back(&elements[1]); + EXPECT_EQ(2u, list.size()); + list.pop_back(); + EXPECT_EQ(1u, list.size()); + list.pop_back(); + EXPECT_TRUE(list.empty()); + } + { + SpdyIntrusiveList list; + VirtuallyDerivedClassA elements[2]; + EXPECT_TRUE(list.empty()); + list.push_back(&elements[0]); + EXPECT_EQ(1u, list.size()); + list.push_back(&elements[1]); + EXPECT_EQ(2u, list.size()); + list.pop_back(); + EXPECT_EQ(1u, list.size()); + list.pop_back(); + EXPECT_TRUE(list.empty()); + } + { + SpdyIntrusiveList list; + VirtuallyDerivedClassC elements[2]; + EXPECT_TRUE(list.empty()); + list.push_back(&elements[0]); + EXPECT_EQ(1u, list.size()); + list.push_back(&elements[1]); + EXPECT_EQ(2u, list.size()); + list.pop_back(); + EXPECT_EQ(1u, list.size()); + list.pop_back(); + EXPECT_TRUE(list.empty()); + } + { + SpdyIntrusiveList list; + DerivedClass d1; + VirtuallyDerivedClassA d2; + VirtuallyDerivedClassB d3; + VirtuallyDerivedClassC d4; + EXPECT_TRUE(list.empty()); + list.push_back(&d1); + EXPECT_EQ(1u, list.size()); + list.push_back(&d2); + EXPECT_EQ(2u, list.size()); + list.push_back(&d3); + EXPECT_EQ(3u, list.size()); + list.push_back(&d4); + EXPECT_EQ(4u, list.size()); + SpdyIntrusiveList::iterator it = list.begin(); + EXPECT_EQ("DerivedClass", (it++)->name()); + EXPECT_EQ("VirtuallyDerivedClassA", (it++)->name()); + EXPECT_EQ("VirtuallyDerivedClassB", (it++)->name()); + EXPECT_EQ("VirtuallyDerivedClassC", (it++)->name()); + } + { + SpdyIntrusiveList list; + templated_base_link::DerivedClass elements[2]; + EXPECT_TRUE(list.empty()); + list.push_back(&elements[0]); + EXPECT_EQ(1u, list.size()); + list.push_back(&elements[1]); + EXPECT_EQ(2u, list.size()); + list.pop_back(); + EXPECT_EQ(1u, list.size()); + list.pop_back(); + EXPECT_TRUE(list.empty()); + } +} + +class IntrusiveListTagTypeTest : public quiche::test::QuicheTest { + protected: + struct Tag {}; + class Element : public SpdyIntrusiveLink {}; +}; + +TEST_F(IntrusiveListTagTypeTest, TagTypeListID) { + SpdyIntrusiveList list; + { + Element e; + list.push_back(&e); + } +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/spdy_no_op_visitor.cc b/quiche/spdy/core/spdy_no_op_visitor.cc new file mode 100644 index 000000000000..d1ab89b9b661 --- /dev/null +++ b/quiche/spdy/core/spdy_no_op_visitor.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_no_op_visitor.h" + +#include + +namespace spdy { + +SpdyNoOpVisitor::SpdyNoOpVisitor() { + static_assert(std::is_abstract::value == false, + "Need to update SpdyNoOpVisitor."); +} +SpdyNoOpVisitor::~SpdyNoOpVisitor() = default; + +SpdyHeadersHandlerInterface* SpdyNoOpVisitor::OnHeaderFrameStart( + SpdyStreamId /*stream_id*/) { + return this; +} + +bool SpdyNoOpVisitor::OnUnknownFrame(SpdyStreamId /*stream_id*/, + uint8_t /*frame_type*/) { + return true; +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_no_op_visitor.h b/quiche/spdy/core/spdy_no_op_visitor.h new file mode 100644 index 000000000000..b20c9c30b111 --- /dev/null +++ b/quiche/spdy/core/spdy_no_op_visitor.h @@ -0,0 +1,91 @@ +// Copyright (c) 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// SpdyNoOpVisitor implements several of the visitor and handler interfaces +// to make it easier to write tests that need to provide instances. Other +// interfaces can be added as needed. + +#ifndef QUICHE_SPDY_CORE_SPDY_NO_OP_VISITOR_H_ +#define QUICHE_SPDY_CORE_SPDY_NO_OP_VISITOR_H_ + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace spdy { + +class QUICHE_EXPORT SpdyNoOpVisitor : public SpdyFramerVisitorInterface, + public SpdyFramerDebugVisitorInterface, + public SpdyHeadersHandlerInterface { + public: + SpdyNoOpVisitor(); + ~SpdyNoOpVisitor() override; + + // SpdyFramerVisitorInterface methods: + void OnError(http2::Http2DecoderAdapter::SpdyFramerError /*error*/, + std::string /*detailed_error*/) override {} + SpdyHeadersHandlerInterface* OnHeaderFrameStart( + SpdyStreamId stream_id) override; + void OnHeaderFrameEnd(SpdyStreamId /*stream_id*/) override {} + void OnDataFrameHeader(SpdyStreamId /*stream_id*/, size_t /*length*/, + bool /*fin*/) override {} + void OnStreamFrameData(SpdyStreamId /*stream_id*/, const char* /*data*/, + size_t /*len*/) override {} + void OnStreamEnd(SpdyStreamId /*stream_id*/) override {} + void OnStreamPadding(SpdyStreamId /*stream_id*/, size_t /*len*/) override {} + void OnRstStream(SpdyStreamId /*stream_id*/, + SpdyErrorCode /*error_code*/) override {} + void OnSetting(SpdySettingsId /*id*/, uint32_t /*value*/) override {} + void OnPing(SpdyPingId /*unique_id*/, bool /*is_ack*/) override {} + void OnSettingsEnd() override {} + void OnSettingsAck() override {} + void OnGoAway(SpdyStreamId /*last_accepted_stream_id*/, + SpdyErrorCode /*error_code*/) override {} + void OnHeaders(SpdyStreamId /*stream_id*/, size_t /*payload_length*/, + bool /*has_priority*/, int /*weight*/, + SpdyStreamId /*parent_stream_id*/, bool /*exclusive*/, + bool /*fin*/, bool /*end*/) override {} + void OnWindowUpdate(SpdyStreamId /*stream_id*/, + int /*delta_window_size*/) override {} + void OnPushPromise(SpdyStreamId /*stream_id*/, + SpdyStreamId /*promised_stream_id*/, + bool /*end*/) override {} + void OnContinuation(SpdyStreamId /*stream_id*/, size_t /*payload_size*/, + bool /*end*/) override {} + void OnAltSvc(SpdyStreamId /*stream_id*/, absl::string_view /*origin*/, + const SpdyAltSvcWireFormat::AlternativeServiceVector& + /*altsvc_vector*/) override {} + void OnPriority(SpdyStreamId /*stream_id*/, SpdyStreamId /*parent_stream_id*/, + int /*weight*/, bool /*exclusive*/) override {} + void OnPriorityUpdate(SpdyStreamId /*prioritized_stream_id*/, + absl::string_view /*priority_field_value*/) override {} + bool OnUnknownFrame(SpdyStreamId /*stream_id*/, + uint8_t /*frame_type*/) override; + void OnUnknownFrameStart(SpdyStreamId /*stream_id*/, size_t /*length*/, + uint8_t /*type*/, uint8_t /*flags*/) override {} + void OnUnknownFramePayload(SpdyStreamId /*stream_id*/, + absl::string_view /*payload*/) override {} + + // SpdyFramerDebugVisitorInterface methods: + void OnSendCompressedFrame(SpdyStreamId /*stream_id*/, SpdyFrameType /*type*/, + size_t /*payload_len*/, + size_t /*frame_len*/) override {} + void OnReceiveCompressedFrame(SpdyStreamId /*stream_id*/, + SpdyFrameType /*type*/, + size_t /*frame_len*/) override {} + + // SpdyHeadersHandlerInterface methods: + void OnHeaderBlockStart() override {} + void OnHeader(absl::string_view /*key*/, + absl::string_view /*value*/) override {} + void OnHeaderBlockEnd(size_t /* uncompressed_header_bytes */, + size_t /* compressed_header_bytes */) override {} +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_NO_OP_VISITOR_H_ diff --git a/quiche/spdy/core/spdy_pinnable_buffer_piece.cc b/quiche/spdy/core/spdy_pinnable_buffer_piece.cc new file mode 100644 index 000000000000..4448d909d1fc --- /dev/null +++ b/quiche/spdy/core/spdy_pinnable_buffer_piece.cc @@ -0,0 +1,36 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_pinnable_buffer_piece.h" + +#include + +namespace spdy { + +SpdyPinnableBufferPiece::SpdyPinnableBufferPiece() + : buffer_(nullptr), length_(0) {} + +SpdyPinnableBufferPiece::~SpdyPinnableBufferPiece() = default; + +void SpdyPinnableBufferPiece::Pin() { + if (!storage_ && buffer_ != nullptr && length_ != 0) { + storage_.reset(new char[length_]); + std::copy(buffer_, buffer_ + length_, storage_.get()); + buffer_ = storage_.get(); + } +} + +void SpdyPinnableBufferPiece::Swap(SpdyPinnableBufferPiece* other) { + size_t length = length_; + length_ = other->length_; + other->length_ = length; + + const char* buffer = buffer_; + buffer_ = other->buffer_; + other->buffer_ = buffer; + + storage_.swap(other->storage_); +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_pinnable_buffer_piece.h b/quiche/spdy/core/spdy_pinnable_buffer_piece.h new file mode 100644 index 000000000000..d73a400b4a14 --- /dev/null +++ b/quiche/spdy/core/spdy_pinnable_buffer_piece.h @@ -0,0 +1,53 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_PINNABLE_BUFFER_PIECE_H_ +#define QUICHE_SPDY_CORE_SPDY_PINNABLE_BUFFER_PIECE_H_ + +#include + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +class SpdyPrefixedBufferReader; + +// Helper class of SpdyPrefixedBufferReader. +// Represents a piece of consumed buffer which may (or may not) own its +// underlying storage. Users may "pin" the buffer at a later time to ensure +// a SpdyPinnableBufferPiece owns and retains storage of the buffer. +struct QUICHE_EXPORT SpdyPinnableBufferPiece { + public: + SpdyPinnableBufferPiece(); + ~SpdyPinnableBufferPiece(); + + const char* buffer() const { return buffer_; } + + explicit operator absl::string_view() const { + return absl::string_view(buffer_, length_); + } + + // Allocates and copies the buffer to internal storage. + void Pin(); + + bool IsPinned() const { return storage_ != nullptr; } + + // Swaps buffers, including internal storage, with |other|. + void Swap(SpdyPinnableBufferPiece* other); + + private: + friend class SpdyPrefixedBufferReader; + + const char* buffer_; + size_t length_; + // Null iff |buffer_| isn't pinned. + std::unique_ptr storage_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_PINNABLE_BUFFER_PIECE_H_ diff --git a/quiche/spdy/core/spdy_pinnable_buffer_piece_test.cc b/quiche/spdy/core/spdy_pinnable_buffer_piece_test.cc new file mode 100644 index 000000000000..984be68df730 --- /dev/null +++ b/quiche/spdy/core/spdy_pinnable_buffer_piece_test.cc @@ -0,0 +1,80 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_pinnable_buffer_piece.h" + +#include + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/spdy_prefixed_buffer_reader.h" + +namespace spdy { + +namespace test { + +class SpdyPinnableBufferPieceTest : public quiche::test::QuicheTest { + protected: + SpdyPrefixedBufferReader Build(const std::string& prefix, + const std::string& suffix) { + prefix_ = prefix; + suffix_ = suffix; + return SpdyPrefixedBufferReader(prefix_.data(), prefix_.length(), + suffix_.data(), suffix_.length()); + } + std::string prefix_, suffix_; +}; + +TEST_F(SpdyPinnableBufferPieceTest, Pin) { + SpdyPrefixedBufferReader reader = Build("foobar", ""); + SpdyPinnableBufferPiece piece; + EXPECT_TRUE(reader.ReadN(6, &piece)); + + // Piece points to underlying prefix storage. + EXPECT_EQ(absl::string_view("foobar"), absl::string_view(piece)); + EXPECT_FALSE(piece.IsPinned()); + EXPECT_EQ(prefix_.data(), piece.buffer()); + + piece.Pin(); + + // Piece now points to allocated storage. + EXPECT_EQ(absl::string_view("foobar"), absl::string_view(piece)); + EXPECT_TRUE(piece.IsPinned()); + EXPECT_NE(prefix_.data(), piece.buffer()); + + // Pinning again has no effect. + const char* buffer = piece.buffer(); + piece.Pin(); + EXPECT_EQ(buffer, piece.buffer()); +} + +TEST_F(SpdyPinnableBufferPieceTest, Swap) { + SpdyPrefixedBufferReader reader = Build("foobar", ""); + SpdyPinnableBufferPiece piece1, piece2; + EXPECT_TRUE(reader.ReadN(4, &piece1)); + EXPECT_TRUE(reader.ReadN(2, &piece2)); + + piece1.Pin(); + + EXPECT_EQ(absl::string_view("foob"), absl::string_view(piece1)); + EXPECT_TRUE(piece1.IsPinned()); + EXPECT_EQ(absl::string_view("ar"), absl::string_view(piece2)); + EXPECT_FALSE(piece2.IsPinned()); + + piece1.Swap(&piece2); + + EXPECT_EQ(absl::string_view("ar"), absl::string_view(piece1)); + EXPECT_FALSE(piece1.IsPinned()); + EXPECT_EQ(absl::string_view("foob"), absl::string_view(piece2)); + EXPECT_TRUE(piece2.IsPinned()); + + SpdyPinnableBufferPiece empty; + piece2.Swap(&empty); + + EXPECT_EQ(absl::string_view(""), absl::string_view(piece2)); + EXPECT_FALSE(piece2.IsPinned()); +} + +} // namespace test + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_prefixed_buffer_reader.cc b/quiche/spdy/core/spdy_prefixed_buffer_reader.cc new file mode 100644 index 000000000000..8b5a252fae85 --- /dev/null +++ b/quiche/spdy/core/spdy_prefixed_buffer_reader.cc @@ -0,0 +1,84 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_prefixed_buffer_reader.h" + +#include + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { + +SpdyPrefixedBufferReader::SpdyPrefixedBufferReader(const char* prefix, + size_t prefix_length, + const char* suffix, + size_t suffix_length) + : prefix_(prefix), + suffix_(suffix), + prefix_length_(prefix_length), + suffix_length_(suffix_length) {} + +size_t SpdyPrefixedBufferReader::Available() { + return prefix_length_ + suffix_length_; +} + +bool SpdyPrefixedBufferReader::ReadN(size_t count, char* out) { + if (Available() < count) { + return false; + } + + if (prefix_length_ >= count) { + // Read is fully satisfied by the prefix. + std::copy(prefix_, prefix_ + count, out); + prefix_ += count; + prefix_length_ -= count; + return true; + } else if (prefix_length_ != 0) { + // Read is partially satisfied by the prefix. + out = std::copy(prefix_, prefix_ + prefix_length_, out); + count -= prefix_length_; + prefix_length_ = 0; + // Fallthrough to suffix read. + } + QUICHE_DCHECK(suffix_length_ >= count); + // Read is satisfied by the suffix. + std::copy(suffix_, suffix_ + count, out); + suffix_ += count; + suffix_length_ -= count; + return true; +} + +bool SpdyPrefixedBufferReader::ReadN(size_t count, + SpdyPinnableBufferPiece* out) { + if (Available() < count) { + return false; + } + + out->storage_.reset(); + out->length_ = count; + + if (prefix_length_ >= count) { + // Read is fully satisfied by the prefix. + out->buffer_ = prefix_; + prefix_ += count; + prefix_length_ -= count; + return true; + } else if (prefix_length_ != 0) { + // Read is only partially satisfied by the prefix. We need to allocate + // contiguous storage as the read spans the prefix & suffix. + out->storage_.reset(new char[count]); + out->buffer_ = out->storage_.get(); + ReadN(count, out->storage_.get()); + return true; + } else { + QUICHE_DCHECK(suffix_length_ >= count); + // Read is fully satisfied by the suffix. + out->buffer_ = suffix_; + suffix_ += count; + suffix_length_ -= count; + return true; + } +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_prefixed_buffer_reader.h b/quiche/spdy/core/spdy_prefixed_buffer_reader.h new file mode 100644 index 000000000000..c102ee6b6925 --- /dev/null +++ b/quiche/spdy/core/spdy_prefixed_buffer_reader.h @@ -0,0 +1,43 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_PREFIXED_BUFFER_READER_H_ +#define QUICHE_SPDY_CORE_SPDY_PREFIXED_BUFFER_READER_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/spdy/core/spdy_pinnable_buffer_piece.h" + +namespace spdy { + +// Reader class which simplifies reading contiguously from +// from a disjoint buffer prefix & suffix. +class QUICHE_EXPORT SpdyPrefixedBufferReader { + public: + SpdyPrefixedBufferReader(const char* prefix, size_t prefix_length, + const char* suffix, size_t suffix_length); + + // Returns number of bytes available to be read. + size_t Available(); + + // Reads |count| bytes, copying into |*out|. Returns true on success, + // false if not enough bytes were available. + bool ReadN(size_t count, char* out); + + // Reads |count| bytes, returned in |*out|. Returns true on success, + // false if not enough bytes were available. + bool ReadN(size_t count, SpdyPinnableBufferPiece* out); + + private: + const char* prefix_; + const char* suffix_; + + size_t prefix_length_; + size_t suffix_length_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_PREFIXED_BUFFER_READER_H_ diff --git a/quiche/spdy/core/spdy_prefixed_buffer_reader_test.cc b/quiche/spdy/core/spdy_prefixed_buffer_reader_test.cc new file mode 100644 index 000000000000..c013c4cb745c --- /dev/null +++ b/quiche/spdy/core/spdy_prefixed_buffer_reader_test.cc @@ -0,0 +1,131 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_prefixed_buffer_reader.h" + +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { + +namespace test { + +using testing::ElementsAreArray; + +class SpdyPrefixedBufferReaderTest : public quiche::test::QuicheTest { + protected: + SpdyPrefixedBufferReader Build(const std::string& prefix, + const std::string& suffix) { + prefix_ = prefix; + suffix_ = suffix; + return SpdyPrefixedBufferReader(prefix_.data(), prefix_.length(), + suffix_.data(), suffix_.length()); + } + std::string prefix_, suffix_; +}; + +TEST_F(SpdyPrefixedBufferReaderTest, ReadRawFromPrefix) { + SpdyPrefixedBufferReader reader = Build("foobar", ""); + EXPECT_EQ(6u, reader.Available()); + + char buffer[] = "123456"; + EXPECT_FALSE(reader.ReadN(10, buffer)); // Not enough buffer. + EXPECT_TRUE(reader.ReadN(6, buffer)); + EXPECT_THAT(buffer, ElementsAreArray("foobar")); + EXPECT_EQ(0u, reader.Available()); +} + +TEST_F(SpdyPrefixedBufferReaderTest, ReadPieceFromPrefix) { + SpdyPrefixedBufferReader reader = Build("foobar", ""); + EXPECT_EQ(6u, reader.Available()); + + SpdyPinnableBufferPiece piece; + EXPECT_FALSE(reader.ReadN(10, &piece)); // Not enough buffer. + EXPECT_TRUE(reader.ReadN(6, &piece)); + EXPECT_FALSE(piece.IsPinned()); + EXPECT_EQ(absl::string_view("foobar"), absl::string_view(piece)); + EXPECT_EQ(0u, reader.Available()); +} + +TEST_F(SpdyPrefixedBufferReaderTest, ReadRawFromSuffix) { + SpdyPrefixedBufferReader reader = Build("", "foobar"); + EXPECT_EQ(6u, reader.Available()); + + char buffer[] = "123456"; + EXPECT_FALSE(reader.ReadN(10, buffer)); // Not enough buffer. + EXPECT_TRUE(reader.ReadN(6, buffer)); + EXPECT_THAT(buffer, ElementsAreArray("foobar")); + EXPECT_EQ(0u, reader.Available()); +} + +TEST_F(SpdyPrefixedBufferReaderTest, ReadPieceFromSuffix) { + SpdyPrefixedBufferReader reader = Build("", "foobar"); + EXPECT_EQ(6u, reader.Available()); + + SpdyPinnableBufferPiece piece; + EXPECT_FALSE(reader.ReadN(10, &piece)); // Not enough buffer. + EXPECT_TRUE(reader.ReadN(6, &piece)); + EXPECT_FALSE(piece.IsPinned()); + EXPECT_EQ(absl::string_view("foobar"), absl::string_view(piece)); + EXPECT_EQ(0u, reader.Available()); +} + +TEST_F(SpdyPrefixedBufferReaderTest, ReadRawSpanning) { + SpdyPrefixedBufferReader reader = Build("foob", "ar"); + EXPECT_EQ(6u, reader.Available()); + + char buffer[] = "123456"; + EXPECT_FALSE(reader.ReadN(10, buffer)); // Not enough buffer. + EXPECT_TRUE(reader.ReadN(6, buffer)); + EXPECT_THAT(buffer, ElementsAreArray("foobar")); + EXPECT_EQ(0u, reader.Available()); +} + +TEST_F(SpdyPrefixedBufferReaderTest, ReadPieceSpanning) { + SpdyPrefixedBufferReader reader = Build("foob", "ar"); + EXPECT_EQ(6u, reader.Available()); + + SpdyPinnableBufferPiece piece; + EXPECT_FALSE(reader.ReadN(10, &piece)); // Not enough buffer. + EXPECT_TRUE(reader.ReadN(6, &piece)); + EXPECT_TRUE(piece.IsPinned()); + EXPECT_EQ(absl::string_view("foobar"), absl::string_view(piece)); + EXPECT_EQ(0u, reader.Available()); +} + +TEST_F(SpdyPrefixedBufferReaderTest, ReadMixed) { + SpdyPrefixedBufferReader reader = Build("abcdef", "hijkl"); + EXPECT_EQ(11u, reader.Available()); + + char buffer[] = "1234"; + SpdyPinnableBufferPiece piece; + + EXPECT_TRUE(reader.ReadN(3, buffer)); + EXPECT_THAT(buffer, ElementsAreArray("abc4")); + EXPECT_EQ(8u, reader.Available()); + + EXPECT_TRUE(reader.ReadN(2, buffer)); + EXPECT_THAT(buffer, ElementsAreArray("dec4")); + EXPECT_EQ(6u, reader.Available()); + + EXPECT_TRUE(reader.ReadN(3, &piece)); + EXPECT_EQ(absl::string_view("fhi"), absl::string_view(piece)); + EXPECT_TRUE(piece.IsPinned()); + EXPECT_EQ(3u, reader.Available()); + + EXPECT_TRUE(reader.ReadN(2, &piece)); + EXPECT_EQ(absl::string_view("jk"), absl::string_view(piece)); + EXPECT_FALSE(piece.IsPinned()); + EXPECT_EQ(1u, reader.Available()); + + EXPECT_TRUE(reader.ReadN(1, buffer)); + EXPECT_THAT(buffer, ElementsAreArray("lec4")); + EXPECT_EQ(0u, reader.Available()); +} + +} // namespace test + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_protocol.cc b/quiche/spdy/core/spdy_protocol.cc new file mode 100644 index 000000000000..35c2c40a5ac3 --- /dev/null +++ b/quiche/spdy/core/spdy_protocol.cc @@ -0,0 +1,616 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_protocol.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "quiche/common/platform/api/quiche_bug_tracker.h" + +namespace spdy { + +const char* const kHttp2ConnectionHeaderPrefix = + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +std::ostream& operator<<(std::ostream& out, SpdyKnownSettingsId id) { + return out << static_cast(id); +} + +std::ostream& operator<<(std::ostream& out, SpdyFrameType frame_type) { + return out << SerializeFrameType(frame_type); +} + +SpdyPriority ClampSpdy3Priority(SpdyPriority priority) { + static_assert(std::numeric_limits::min() == kV3HighestPriority, + "The value of given priority shouldn't be smaller than highest " + "priority. Check this invariant explicitly."); + if (priority > kV3LowestPriority) { + QUICHE_BUG(spdy_bug_22_1) + << "Invalid priority: " << static_cast(priority); + return kV3LowestPriority; + } + return priority; +} + +int ClampHttp2Weight(int weight) { + if (weight < kHttp2MinStreamWeight) { + QUICHE_BUG(spdy_bug_22_2) << "Invalid weight: " << weight; + return kHttp2MinStreamWeight; + } + if (weight > kHttp2MaxStreamWeight) { + QUICHE_BUG(spdy_bug_22_3) << "Invalid weight: " << weight; + return kHttp2MaxStreamWeight; + } + return weight; +} + +int Spdy3PriorityToHttp2Weight(SpdyPriority priority) { + priority = ClampSpdy3Priority(priority); + const float kSteps = 255.9f / 7.f; + return static_cast(kSteps * (7.f - priority)) + 1; +} + +SpdyPriority Http2WeightToSpdy3Priority(int weight) { + weight = ClampHttp2Weight(weight); + const float kSteps = 255.9f / 7.f; + return static_cast(7.f - (weight - 1) / kSteps); +} + +bool IsDefinedFrameType(uint8_t frame_type_field) { + switch (static_cast(frame_type_field)) { + case SpdyFrameType::DATA: + return true; + case SpdyFrameType::HEADERS: + return true; + case SpdyFrameType::PRIORITY: + return true; + case SpdyFrameType::RST_STREAM: + return true; + case SpdyFrameType::SETTINGS: + return true; + case SpdyFrameType::PUSH_PROMISE: + return true; + case SpdyFrameType::PING: + return true; + case SpdyFrameType::GOAWAY: + return true; + case SpdyFrameType::WINDOW_UPDATE: + return true; + case SpdyFrameType::CONTINUATION: + return true; + case SpdyFrameType::ALTSVC: + return true; + case SpdyFrameType::PRIORITY_UPDATE: + return true; + case SpdyFrameType::ACCEPT_CH: + return true; + } + return false; +} + +SpdyFrameType ParseFrameType(uint8_t frame_type_field) { + QUICHE_BUG_IF(spdy_bug_22_4, !IsDefinedFrameType(frame_type_field)) + << "Frame type not defined: " << static_cast(frame_type_field); + return static_cast(frame_type_field); +} + +uint8_t SerializeFrameType(SpdyFrameType frame_type) { + return static_cast(frame_type); +} + +bool IsValidHTTP2FrameStreamId(SpdyStreamId current_frame_stream_id, + SpdyFrameType frame_type_field) { + if (current_frame_stream_id == 0) { + switch (frame_type_field) { + case SpdyFrameType::DATA: + case SpdyFrameType::HEADERS: + case SpdyFrameType::PRIORITY: + case SpdyFrameType::RST_STREAM: + case SpdyFrameType::CONTINUATION: + case SpdyFrameType::PUSH_PROMISE: + // These frame types must specify a stream + return false; + default: + return true; + } + } else { + switch (frame_type_field) { + case SpdyFrameType::GOAWAY: + case SpdyFrameType::SETTINGS: + case SpdyFrameType::PING: + // These frame types must not specify a stream + return false; + default: + return true; + } + } +} + +const char* FrameTypeToString(SpdyFrameType frame_type) { + switch (frame_type) { + case SpdyFrameType::DATA: + return "DATA"; + case SpdyFrameType::RST_STREAM: + return "RST_STREAM"; + case SpdyFrameType::SETTINGS: + return "SETTINGS"; + case SpdyFrameType::PING: + return "PING"; + case SpdyFrameType::GOAWAY: + return "GOAWAY"; + case SpdyFrameType::HEADERS: + return "HEADERS"; + case SpdyFrameType::WINDOW_UPDATE: + return "WINDOW_UPDATE"; + case SpdyFrameType::PUSH_PROMISE: + return "PUSH_PROMISE"; + case SpdyFrameType::CONTINUATION: + return "CONTINUATION"; + case SpdyFrameType::PRIORITY: + return "PRIORITY"; + case SpdyFrameType::ALTSVC: + return "ALTSVC"; + case SpdyFrameType::PRIORITY_UPDATE: + return "PRIORITY_UPDATE"; + case SpdyFrameType::ACCEPT_CH: + return "ACCEPT_CH"; + } + return "UNKNOWN_FRAME_TYPE"; +} + +bool ParseSettingsId(SpdySettingsId wire_setting_id, + SpdyKnownSettingsId* setting_id) { + if (wire_setting_id != SETTINGS_EXPERIMENT_SCHEDULER && + (wire_setting_id < SETTINGS_MIN || wire_setting_id > SETTINGS_MAX)) { + return false; + } + + *setting_id = static_cast(wire_setting_id); + // This switch ensures that the casted value is valid. The default case is + // explicitly omitted to have compile-time guarantees that new additions to + // |SpdyKnownSettingsId| must also be handled here. + switch (*setting_id) { + case SETTINGS_HEADER_TABLE_SIZE: + case SETTINGS_ENABLE_PUSH: + case SETTINGS_MAX_CONCURRENT_STREAMS: + case SETTINGS_INITIAL_WINDOW_SIZE: + case SETTINGS_MAX_FRAME_SIZE: + case SETTINGS_MAX_HEADER_LIST_SIZE: + case SETTINGS_ENABLE_CONNECT_PROTOCOL: + case SETTINGS_DEPRECATE_HTTP2_PRIORITIES: + case SETTINGS_EXPERIMENT_SCHEDULER: + return true; + } + return false; +} + +std::string SettingsIdToString(SpdySettingsId id) { + SpdyKnownSettingsId known_id; + if (!ParseSettingsId(id, &known_id)) { + return absl::StrCat("SETTINGS_UNKNOWN_", absl::Hex(uint32_t{id})); + } + + switch (known_id) { + case SETTINGS_HEADER_TABLE_SIZE: + return "SETTINGS_HEADER_TABLE_SIZE"; + case SETTINGS_ENABLE_PUSH: + return "SETTINGS_ENABLE_PUSH"; + case SETTINGS_MAX_CONCURRENT_STREAMS: + return "SETTINGS_MAX_CONCURRENT_STREAMS"; + case SETTINGS_INITIAL_WINDOW_SIZE: + return "SETTINGS_INITIAL_WINDOW_SIZE"; + case SETTINGS_MAX_FRAME_SIZE: + return "SETTINGS_MAX_FRAME_SIZE"; + case SETTINGS_MAX_HEADER_LIST_SIZE: + return "SETTINGS_MAX_HEADER_LIST_SIZE"; + case SETTINGS_ENABLE_CONNECT_PROTOCOL: + return "SETTINGS_ENABLE_CONNECT_PROTOCOL"; + case SETTINGS_DEPRECATE_HTTP2_PRIORITIES: + return "SETTINGS_DEPRECATE_HTTP2_PRIORITIES"; + case SETTINGS_EXPERIMENT_SCHEDULER: + return "SETTINGS_EXPERIMENT_SCHEDULER"; + } + + return absl::StrCat("SETTINGS_UNKNOWN_", absl::Hex(uint32_t{id})); +} + +SpdyErrorCode ParseErrorCode(uint32_t wire_error_code) { + if (wire_error_code > ERROR_CODE_MAX) { + return ERROR_CODE_INTERNAL_ERROR; + } + + return static_cast(wire_error_code); +} + +const char* ErrorCodeToString(SpdyErrorCode error_code) { + switch (error_code) { + case ERROR_CODE_NO_ERROR: + return "NO_ERROR"; + case ERROR_CODE_PROTOCOL_ERROR: + return "PROTOCOL_ERROR"; + case ERROR_CODE_INTERNAL_ERROR: + return "INTERNAL_ERROR"; + case ERROR_CODE_FLOW_CONTROL_ERROR: + return "FLOW_CONTROL_ERROR"; + case ERROR_CODE_SETTINGS_TIMEOUT: + return "SETTINGS_TIMEOUT"; + case ERROR_CODE_STREAM_CLOSED: + return "STREAM_CLOSED"; + case ERROR_CODE_FRAME_SIZE_ERROR: + return "FRAME_SIZE_ERROR"; + case ERROR_CODE_REFUSED_STREAM: + return "REFUSED_STREAM"; + case ERROR_CODE_CANCEL: + return "CANCEL"; + case ERROR_CODE_COMPRESSION_ERROR: + return "COMPRESSION_ERROR"; + case ERROR_CODE_CONNECT_ERROR: + return "CONNECT_ERROR"; + case ERROR_CODE_ENHANCE_YOUR_CALM: + return "ENHANCE_YOUR_CALM"; + case ERROR_CODE_INADEQUATE_SECURITY: + return "INADEQUATE_SECURITY"; + case ERROR_CODE_HTTP_1_1_REQUIRED: + return "HTTP_1_1_REQUIRED"; + } + return "UNKNOWN_ERROR_CODE"; +} + +const char* WriteSchedulerTypeToString(WriteSchedulerType type) { + switch (type) { + case WriteSchedulerType::LIFO: + return "LIFO"; + case WriteSchedulerType::SPDY: + return "SPDY"; + case WriteSchedulerType::HTTP2: + return "HTTP2"; + case WriteSchedulerType::FIFO: + return "FIFO"; + } + return "UNKNOWN"; +} + +size_t GetNumberRequiredContinuationFrames(size_t size) { + QUICHE_DCHECK_GT(size, kHttp2MaxControlFrameSendSize); + size_t overflow = size - kHttp2MaxControlFrameSendSize; + int payload_size = + kHttp2MaxControlFrameSendSize - kContinuationFrameMinimumSize; + // This is ceiling(overflow/payload_size) using integer arithmetics. + return (overflow - 1) / payload_size + 1; +} + +const char* const kHttp2Npn = "h2"; + +const char* const kHttp2AuthorityHeader = ":authority"; +const char* const kHttp2MethodHeader = ":method"; +const char* const kHttp2PathHeader = ":path"; +const char* const kHttp2SchemeHeader = ":scheme"; +const char* const kHttp2ProtocolHeader = ":protocol"; + +const char* const kHttp2StatusHeader = ":status"; + +bool SpdyFrameIR::fin() const { return false; } + +int SpdyFrameIR::flow_control_window_consumed() const { return 0; } + +bool SpdyFrameWithFinIR::fin() const { return fin_; } + +SpdyFrameWithHeaderBlockIR::SpdyFrameWithHeaderBlockIR( + SpdyStreamId stream_id, Http2HeaderBlock header_block) + : SpdyFrameWithFinIR(stream_id), header_block_(std::move(header_block)) {} + +SpdyFrameWithHeaderBlockIR::~SpdyFrameWithHeaderBlockIR() = default; + +SpdyDataIR::SpdyDataIR(SpdyStreamId stream_id, absl::string_view data) + : SpdyFrameWithFinIR(stream_id), + data_(nullptr), + data_len_(0), + padded_(false), + padding_payload_len_(0) { + SetDataDeep(data); +} + +SpdyDataIR::SpdyDataIR(SpdyStreamId stream_id, const char* data) + : SpdyDataIR(stream_id, absl::string_view(data)) {} + +SpdyDataIR::SpdyDataIR(SpdyStreamId stream_id, std::string data) + : SpdyFrameWithFinIR(stream_id), + data_store_(std::make_unique(std::move(data))), + data_(data_store_->data()), + data_len_(data_store_->size()), + padded_(false), + padding_payload_len_(0) {} + +SpdyDataIR::SpdyDataIR(SpdyStreamId stream_id) + : SpdyFrameWithFinIR(stream_id), + data_(nullptr), + data_len_(0), + padded_(false), + padding_payload_len_(0) {} + +SpdyDataIR::~SpdyDataIR() = default; + +void SpdyDataIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitData(*this); +} + +SpdyFrameType SpdyDataIR::frame_type() const { return SpdyFrameType::DATA; } + +int SpdyDataIR::flow_control_window_consumed() const { + return padded_ ? 1 + padding_payload_len_ + data_len_ : data_len_; +} + +size_t SpdyDataIR::size() const { + return kFrameHeaderSize + + (padded() ? 1 + padding_payload_len() + data_len() : data_len()); +} + +SpdyRstStreamIR::SpdyRstStreamIR(SpdyStreamId stream_id, + SpdyErrorCode error_code) + : SpdyFrameIR(stream_id) { + set_error_code(error_code); +} + +SpdyRstStreamIR::~SpdyRstStreamIR() = default; + +void SpdyRstStreamIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitRstStream(*this); +} + +SpdyFrameType SpdyRstStreamIR::frame_type() const { + return SpdyFrameType::RST_STREAM; +} + +size_t SpdyRstStreamIR::size() const { return kRstStreamFrameSize; } + +SpdySettingsIR::SpdySettingsIR() : is_ack_(false) {} + +SpdySettingsIR::~SpdySettingsIR() = default; + +void SpdySettingsIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitSettings(*this); +} + +SpdyFrameType SpdySettingsIR::frame_type() const { + return SpdyFrameType::SETTINGS; +} + +size_t SpdySettingsIR::size() const { + return kFrameHeaderSize + values_.size() * kSettingsOneSettingSize; +} + +void SpdyPingIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitPing(*this); +} + +SpdyFrameType SpdyPingIR::frame_type() const { return SpdyFrameType::PING; } + +size_t SpdyPingIR::size() const { return kPingFrameSize; } + +SpdyGoAwayIR::SpdyGoAwayIR(SpdyStreamId last_good_stream_id, + SpdyErrorCode error_code, + absl::string_view description) + : description_(description) { + set_last_good_stream_id(last_good_stream_id); + set_error_code(error_code); +} + +SpdyGoAwayIR::SpdyGoAwayIR(SpdyStreamId last_good_stream_id, + SpdyErrorCode error_code, const char* description) + : SpdyGoAwayIR(last_good_stream_id, error_code, + absl::string_view(description)) {} + +SpdyGoAwayIR::SpdyGoAwayIR(SpdyStreamId last_good_stream_id, + SpdyErrorCode error_code, std::string description) + : description_store_(std::move(description)), + description_(description_store_) { + set_last_good_stream_id(last_good_stream_id); + set_error_code(error_code); +} + +SpdyGoAwayIR::~SpdyGoAwayIR() = default; + +void SpdyGoAwayIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitGoAway(*this); +} + +SpdyFrameType SpdyGoAwayIR::frame_type() const { return SpdyFrameType::GOAWAY; } + +size_t SpdyGoAwayIR::size() const { + return kGoawayFrameMinimumSize + description_.size(); +} + +SpdyContinuationIR::SpdyContinuationIR(SpdyStreamId stream_id) + : SpdyFrameIR(stream_id), end_headers_(false) {} + +SpdyContinuationIR::~SpdyContinuationIR() = default; + +void SpdyContinuationIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitContinuation(*this); +} + +SpdyFrameType SpdyContinuationIR::frame_type() const { + return SpdyFrameType::CONTINUATION; +} + +size_t SpdyContinuationIR::size() const { + // We don't need to get the size of CONTINUATION frame directly. It is + // calculated in HEADERS or PUSH_PROMISE frame. + QUICHE_DLOG(WARNING) << "Shouldn't not call size() for CONTINUATION frame."; + return 0; +} + +void SpdyHeadersIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitHeaders(*this); +} + +SpdyFrameType SpdyHeadersIR::frame_type() const { + return SpdyFrameType::HEADERS; +} + +size_t SpdyHeadersIR::size() const { + size_t size = kHeadersFrameMinimumSize; + + if (padded_) { + // Padding field length. + size += 1; + size += padding_payload_len_; + } + + if (has_priority_) { + size += 5; + } + + // Assume no hpack encoding is applied. + size += header_block().TotalBytesUsed() + + header_block().size() * kPerHeaderHpackOverhead; + if (size > kHttp2MaxControlFrameSendSize) { + size += GetNumberRequiredContinuationFrames(size) * + kContinuationFrameMinimumSize; + } + return size; +} + +void SpdyWindowUpdateIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitWindowUpdate(*this); +} + +SpdyFrameType SpdyWindowUpdateIR::frame_type() const { + return SpdyFrameType::WINDOW_UPDATE; +} + +size_t SpdyWindowUpdateIR::size() const { return kWindowUpdateFrameSize; } + +void SpdyPushPromiseIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitPushPromise(*this); +} + +SpdyFrameType SpdyPushPromiseIR::frame_type() const { + return SpdyFrameType::PUSH_PROMISE; +} + +size_t SpdyPushPromiseIR::size() const { + size_t size = kPushPromiseFrameMinimumSize; + + if (padded_) { + // Padding length field. + size += 1; + size += padding_payload_len_; + } + + size += header_block().TotalBytesUsed(); + if (size > kHttp2MaxControlFrameSendSize) { + size += GetNumberRequiredContinuationFrames(size) * + kContinuationFrameMinimumSize; + } + return size; +} + +SpdyAltSvcIR::SpdyAltSvcIR(SpdyStreamId stream_id) : SpdyFrameIR(stream_id) {} + +SpdyAltSvcIR::~SpdyAltSvcIR() = default; + +void SpdyAltSvcIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitAltSvc(*this); +} + +SpdyFrameType SpdyAltSvcIR::frame_type() const { return SpdyFrameType::ALTSVC; } + +size_t SpdyAltSvcIR::size() const { + size_t size = kGetAltSvcFrameMinimumSize; + size += origin_.length(); + // TODO(yasong): estimates the size without serializing the vector. + std::string str = + SpdyAltSvcWireFormat::SerializeHeaderFieldValue(altsvc_vector_); + size += str.size(); + return size; +} + +void SpdyPriorityIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitPriority(*this); +} + +SpdyFrameType SpdyPriorityIR::frame_type() const { + return SpdyFrameType::PRIORITY; +} + +size_t SpdyPriorityIR::size() const { return kPriorityFrameSize; } + +void SpdyPriorityUpdateIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitPriorityUpdate(*this); +} + +SpdyFrameType SpdyPriorityUpdateIR::frame_type() const { + return SpdyFrameType::PRIORITY_UPDATE; +} + +size_t SpdyPriorityUpdateIR::size() const { + return kPriorityUpdateFrameMinimumSize + priority_field_value_.size(); +} + +void SpdyAcceptChIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitAcceptCh(*this); +} + +SpdyFrameType SpdyAcceptChIR::frame_type() const { + return SpdyFrameType::ACCEPT_CH; +} + +size_t SpdyAcceptChIR::size() const { + size_t total_size = kAcceptChFrameMinimumSize; + for (const AcceptChOriginValuePair& entry : entries_) { + total_size += entry.origin.size() + entry.value.size() + + kAcceptChFramePerEntryOverhead; + } + return total_size; +} + +void SpdyUnknownIR::Visit(SpdyFrameVisitor* visitor) const { + return visitor->VisitUnknown(*this); +} + +SpdyFrameType SpdyUnknownIR::frame_type() const { + return static_cast(type()); +} + +size_t SpdyUnknownIR::size() const { + return kFrameHeaderSize + payload_.size(); +} + +int SpdyUnknownIR::flow_control_window_consumed() const { + if (frame_type() == SpdyFrameType::DATA) { + return payload_.size(); + } else { + return 0; + } +} + +// Wire size of pad length field. +const size_t kPadLengthFieldSize = 1; + +size_t GetHeaderFrameSizeSansBlock(const SpdyHeadersIR& header_ir) { + size_t min_size = kFrameHeaderSize; + if (header_ir.padded()) { + min_size += kPadLengthFieldSize; + min_size += header_ir.padding_payload_len(); + } + if (header_ir.has_priority()) { + min_size += 5; + } + return min_size; +} + +size_t GetPushPromiseFrameSizeSansBlock( + const SpdyPushPromiseIR& push_promise_ir) { + size_t min_size = kPushPromiseFrameMinimumSize; + if (push_promise_ir.padded()) { + min_size += kPadLengthFieldSize; + min_size += push_promise_ir.padding_payload_len(); + } + return min_size; +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_protocol.h b/quiche/spdy/core/spdy_protocol.h new file mode 100644 index 000000000000..d475cb04ceaa --- /dev/null +++ b/quiche/spdy/core/spdy_protocol.h @@ -0,0 +1,1126 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file contains some protocol structures for use with SPDY 3 and HTTP 2 +// The SPDY 3 spec can be found at: +// http://dev.chromium.org/spdy/spdy-protocol/spdy-protocol-draft3 + +#ifndef QUICHE_SPDY_CORE_SPDY_PROTOCOL_H_ +#define QUICHE_SPDY_CORE_SPDY_PROTOCOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_alt_svc_wire_format.h" +#include "quiche/spdy/core/spdy_bitmasks.h" + +namespace spdy { + +// A stream ID is a 31-bit entity. +using SpdyStreamId = uint32_t; + +// A SETTINGS ID is a 16-bit entity. +using SpdySettingsId = uint16_t; + +// Specifies the stream ID used to denote the current session (for +// flow control). +const SpdyStreamId kSessionFlowControlStreamId = 0; + +// 0 is not a valid stream ID for any other purpose than flow control. +const SpdyStreamId kInvalidStreamId = 0; + +// Max stream id. +const SpdyStreamId kMaxStreamId = 0x7fffffff; + +// The maximum possible frame payload size allowed by the spec. +const uint32_t kSpdyMaxFrameSizeLimit = (1 << 24) - 1; + +// The initial value for the maximum frame payload size as per the spec. This is +// the maximum control frame size we accept. +const uint32_t kHttp2DefaultFramePayloadLimit = 1 << 14; + +// The maximum size of the control frames that we send, including the size of +// the header. This limit is arbitrary. We can enforce it here or at the +// application layer. We chose the framing layer, but this can be changed (or +// removed) if necessary later down the line. +const size_t kHttp2MaxControlFrameSendSize = kHttp2DefaultFramePayloadLimit - 1; + +// Number of octets in the frame header. +const size_t kFrameHeaderSize = 9; + +// The initial value for the maximum frame payload size as per the spec. This is +// the maximum control frame size we accept. +const uint32_t kHttp2DefaultFrameSizeLimit = + kHttp2DefaultFramePayloadLimit + kFrameHeaderSize; + +// The initial value for the maximum size of the header list, "unlimited" (max +// unsigned 32-bit int) as per the spec. +const uint32_t kSpdyInitialHeaderListSizeLimit = 0xFFFFFFFF; + +// Maximum window size for a Spdy stream or session. +const int32_t kSpdyMaximumWindowSize = 0x7FFFFFFF; // Max signed 32bit int + +// Maximum padding size in octets for one DATA or HEADERS or PUSH_PROMISE frame. +const int32_t kPaddingSizePerFrame = 256; + +// The HTTP/2 connection preface, which must be the first bytes sent by the +// client upon starting an HTTP/2 connection, and which must be followed by a +// SETTINGS frame. Note that even though |kHttp2ConnectionHeaderPrefix| is +// defined as a string literal with a null terminator, the actual connection +// preface is only the first |kHttp2ConnectionHeaderPrefixSize| bytes, which +// excludes the null terminator. +QUICHE_EXPORT extern const char* const kHttp2ConnectionHeaderPrefix; +const int kHttp2ConnectionHeaderPrefixSize = 24; + +// Wire values for HTTP2 frame types. +enum class SpdyFrameType : uint8_t { + DATA = 0x00, + HEADERS = 0x01, + PRIORITY = 0x02, + RST_STREAM = 0x03, + SETTINGS = 0x04, + PUSH_PROMISE = 0x05, + PING = 0x06, + GOAWAY = 0x07, + WINDOW_UPDATE = 0x08, + CONTINUATION = 0x09, + // ALTSVC is a public extension. + ALTSVC = 0x0a, + PRIORITY_UPDATE = 0x10, + ACCEPT_CH = 0x89, +}; + +// Flags on data packets. +enum SpdyDataFlags { + DATA_FLAG_NONE = 0x00, + DATA_FLAG_FIN = 0x01, + DATA_FLAG_PADDED = 0x08, +}; + +// Flags on control packets +enum SpdyControlFlags { + CONTROL_FLAG_NONE = 0x00, + CONTROL_FLAG_FIN = 0x01, +}; + +enum SpdyPingFlags { + PING_FLAG_ACK = 0x01, +}; + +// Used by HEADERS, PUSH_PROMISE, and CONTINUATION. +enum SpdyHeadersFlags { + HEADERS_FLAG_END_HEADERS = 0x04, + HEADERS_FLAG_PADDED = 0x08, + HEADERS_FLAG_PRIORITY = 0x20, +}; + +enum SpdyPushPromiseFlags { + PUSH_PROMISE_FLAG_END_PUSH_PROMISE = 0x04, + PUSH_PROMISE_FLAG_PADDED = 0x08, +}; + +enum Http2SettingsControlFlags { + SETTINGS_FLAG_ACK = 0x01, +}; + +// Wire values of HTTP/2 setting identifiers. +enum SpdyKnownSettingsId : SpdySettingsId { + // HPACK header table maximum size. + SETTINGS_HEADER_TABLE_SIZE = 0x1, + SETTINGS_MIN = SETTINGS_HEADER_TABLE_SIZE, + // Whether or not server push (PUSH_PROMISE) is enabled. + SETTINGS_ENABLE_PUSH = 0x2, + // The maximum number of simultaneous live streams in each direction. + SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, + // Initial window size in bytes + SETTINGS_INITIAL_WINDOW_SIZE = 0x4, + // The size of the largest frame payload that a receiver is willing to accept. + SETTINGS_MAX_FRAME_SIZE = 0x5, + // The maximum size of header list that the sender is prepared to accept. + SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, + // Enable Websockets over HTTP/2, see + // https://httpwg.org/specs/rfc8441.html + SETTINGS_ENABLE_CONNECT_PROTOCOL = 0x8, + // Disable HTTP/2 priorities, see + // https://tools.ietf.org/html/draft-ietf-httpbis-priority-02. + SETTINGS_DEPRECATE_HTTP2_PRIORITIES = 0x9, + SETTINGS_MAX = SETTINGS_DEPRECATE_HTTP2_PRIORITIES, + // Experimental setting used to configure an alternative write scheduler. + SETTINGS_EXPERIMENT_SCHEDULER = 0xFF45, +}; + +// This explicit operator is needed, otherwise compiler finds +// overloaded operator to be ambiguous. +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + SpdyKnownSettingsId id); + +// This operator is needed, because SpdyFrameType is an enum class, +// therefore implicit conversion to underlying integer type is not allowed. +QUICHE_EXPORT std::ostream& operator<<(std::ostream& out, + SpdyFrameType frame_type); + +using SettingsMap = std::map; + +// HTTP/2 error codes, RFC 7540 Section 7. +enum SpdyErrorCode : uint32_t { + ERROR_CODE_NO_ERROR = 0x0, + ERROR_CODE_PROTOCOL_ERROR = 0x1, + ERROR_CODE_INTERNAL_ERROR = 0x2, + ERROR_CODE_FLOW_CONTROL_ERROR = 0x3, + ERROR_CODE_SETTINGS_TIMEOUT = 0x4, + ERROR_CODE_STREAM_CLOSED = 0x5, + ERROR_CODE_FRAME_SIZE_ERROR = 0x6, + ERROR_CODE_REFUSED_STREAM = 0x7, + ERROR_CODE_CANCEL = 0x8, + ERROR_CODE_COMPRESSION_ERROR = 0x9, + ERROR_CODE_CONNECT_ERROR = 0xa, + ERROR_CODE_ENHANCE_YOUR_CALM = 0xb, + ERROR_CODE_INADEQUATE_SECURITY = 0xc, + ERROR_CODE_HTTP_1_1_REQUIRED = 0xd, + ERROR_CODE_MAX = ERROR_CODE_HTTP_1_1_REQUIRED +}; + +// Type of priority write scheduler. +enum class WriteSchedulerType { + LIFO, // Last added stream has the highest priority. + SPDY, // Uses SPDY priorities described in + // https://www.chromium.org/spdy/spdy-protocol/spdy-protocol-draft3-1#TOC-2.3.3-Stream-priority. + HTTP2, // Uses HTTP2 (tree-style) priority described in + // https://tools.ietf.org/html/rfc7540#section-5.3. + FIFO, // Stream with the smallest stream ID has the highest priority. +}; + +// A SPDY priority is a number between 0 and 7 (inclusive). +typedef uint8_t SpdyPriority; + +// Lowest and Highest here refer to SPDY priorities as described in +// https://www.chromium.org/spdy/spdy-protocol/spdy-protocol-draft3-1#TOC-2.3.3-Stream-priority +const SpdyPriority kV3HighestPriority = 0; +const SpdyPriority kV3LowestPriority = 7; + +// Returns SPDY 3.x priority value clamped to the valid range of [0, 7]. +QUICHE_EXPORT SpdyPriority ClampSpdy3Priority(SpdyPriority priority); + +// HTTP/2 stream weights are integers in range [1, 256], as specified in RFC +// 7540 section 5.3.2. Default stream weight is defined in section 5.3.5. +const int kHttp2MinStreamWeight = 1; +const int kHttp2MaxStreamWeight = 256; +const int kHttp2DefaultStreamWeight = 16; + +// Returns HTTP/2 weight clamped to the valid range of [1, 256]. +QUICHE_EXPORT int ClampHttp2Weight(int weight); + +// Maps SPDY 3.x priority value in range [0, 7] to HTTP/2 weight value in range +// [1, 256], where priority 0 (i.e. highest precedence) corresponds to maximum +// weight 256 and priority 7 (lowest precedence) corresponds to minimum weight +// 1. +QUICHE_EXPORT int Spdy3PriorityToHttp2Weight(SpdyPriority priority); + +// Maps HTTP/2 weight value in range [1, 256] to SPDY 3.x priority value in +// range [0, 7], where minimum weight 1 corresponds to priority 7 (lowest +// precedence) and maximum weight 256 corresponds to priority 0 (highest +// precedence). +QUICHE_EXPORT SpdyPriority Http2WeightToSpdy3Priority(int weight); + +// Reserved ID for root stream of HTTP/2 stream dependency tree, as specified +// in RFC 7540 section 5.3.1. +const unsigned int kHttp2RootStreamId = 0; + +typedef uint64_t SpdyPingId; + +// Returns true if a given on-the-wire enumeration of a frame type is defined +// in a standardized HTTP/2 specification, false otherwise. +QUICHE_EXPORT bool IsDefinedFrameType(uint8_t frame_type_field); + +// Parses a frame type from an on-the-wire enumeration. +// Behavior is undefined for invalid frame type fields; consumers should first +// use IsValidFrameType() to verify validity of frame type fields. +QUICHE_EXPORT SpdyFrameType ParseFrameType(uint8_t frame_type_field); + +// Serializes a frame type to the on-the-wire value. +QUICHE_EXPORT uint8_t SerializeFrameType(SpdyFrameType frame_type); + +// (HTTP/2) All standard frame types except WINDOW_UPDATE are +// (stream-specific xor connection-level). Returns false iff we know +// the given frame type does not align with the given streamID. +QUICHE_EXPORT bool IsValidHTTP2FrameStreamId( + SpdyStreamId current_frame_stream_id, SpdyFrameType frame_type_field); + +// Serialize |frame_type| to string for logging/debugging. +QUICHE_EXPORT const char* FrameTypeToString(SpdyFrameType frame_type); + +// If |wire_setting_id| is the on-the-wire representation of a defined SETTINGS +// parameter, parse it to |*setting_id| and return true. +QUICHE_EXPORT bool ParseSettingsId(SpdySettingsId wire_setting_id, + SpdyKnownSettingsId* setting_id); + +// Returns a string representation of the |id| for logging/debugging. Returns +// the |id| prefixed with "SETTINGS_UNKNOWN_" for unknown SETTINGS IDs. To parse +// the |id| into a SpdyKnownSettingsId (if applicable), use ParseSettingsId(). +QUICHE_EXPORT std::string SettingsIdToString(SpdySettingsId id); + +// Parse |wire_error_code| to a SpdyErrorCode. +// Treat unrecognized error codes as INTERNAL_ERROR +// as recommended by the HTTP/2 specification. +QUICHE_EXPORT SpdyErrorCode ParseErrorCode(uint32_t wire_error_code); + +// Serialize RST_STREAM or GOAWAY frame error code to string +// for logging/debugging. +QUICHE_EXPORT const char* ErrorCodeToString(SpdyErrorCode error_code); + +// Serialize |type| to string for logging/debugging. +QUICHE_EXPORT const char* WriteSchedulerTypeToString(WriteSchedulerType type); + +// Minimum size of a frame, in octets. +const size_t kFrameMinimumSize = kFrameHeaderSize; + +// Minimum frame size for variable size frame types (includes mandatory fields), +// frame size for fixed size frames, in octets. + +const size_t kDataFrameMinimumSize = kFrameHeaderSize; +const size_t kHeadersFrameMinimumSize = kFrameHeaderSize; +// PRIORITY frame has stream_dependency (4 octets) and weight (1 octet) fields. +const size_t kPriorityFrameSize = kFrameHeaderSize + 5; +// RST_STREAM frame has error_code (4 octets) field. +const size_t kRstStreamFrameSize = kFrameHeaderSize + 4; +const size_t kSettingsFrameMinimumSize = kFrameHeaderSize; +const size_t kSettingsOneSettingSize = + sizeof(uint32_t) + sizeof(SpdySettingsId); +// PUSH_PROMISE frame has promised_stream_id (4 octet) field. +const size_t kPushPromiseFrameMinimumSize = kFrameHeaderSize + 4; +// PING frame has opaque_bytes (8 octet) field. +const size_t kPingFrameSize = kFrameHeaderSize + 8; +// GOAWAY frame has last_stream_id (4 octet) and error_code (4 octet) fields. +const size_t kGoawayFrameMinimumSize = kFrameHeaderSize + 8; +// WINDOW_UPDATE frame has window_size_increment (4 octet) field. +const size_t kWindowUpdateFrameSize = kFrameHeaderSize + 4; +const size_t kContinuationFrameMinimumSize = kFrameHeaderSize; +// ALTSVC frame has origin_len (2 octets) field. +const size_t kGetAltSvcFrameMinimumSize = kFrameHeaderSize + 2; +// PRIORITY_UPDATE frame has prioritized_stream_id (4 octets) field. +const size_t kPriorityUpdateFrameMinimumSize = kFrameHeaderSize + 4; +// ACCEPT_CH frame may have empty payload. +const size_t kAcceptChFrameMinimumSize = kFrameHeaderSize; +// Each ACCEPT_CH frame entry has a 16-bit origin length and a 16-bit value +// length. +const size_t kAcceptChFramePerEntryOverhead = 4; + +// Maximum possible configurable size of a frame in octets. +const size_t kMaxFrameSizeLimit = kSpdyMaxFrameSizeLimit + kFrameHeaderSize; +// Size of a header block size field. +const size_t kSizeOfSizeField = sizeof(uint32_t); +// Initial window size for a stream in bytes. +const int32_t kInitialStreamWindowSize = 64 * 1024 - 1; +// Initial window size for a session in bytes. +const int32_t kInitialSessionWindowSize = 64 * 1024 - 1; +// The NPN string for HTTP2, "h2". +QUICHE_EXPORT extern const char* const kHttp2Npn; +// An estimate size of the HPACK overhead for each header field. 1 bytes for +// indexed literal, 1 bytes for key literal and length encoding, and 2 bytes for +// value literal and length encoding. +const size_t kPerHeaderHpackOverhead = 4; + +// Names of pseudo-headers defined for HTTP/2 requests. +QUICHE_EXPORT extern const char* const kHttp2AuthorityHeader; +QUICHE_EXPORT extern const char* const kHttp2MethodHeader; +QUICHE_EXPORT extern const char* const kHttp2PathHeader; +QUICHE_EXPORT extern const char* const kHttp2SchemeHeader; +QUICHE_EXPORT extern const char* const kHttp2ProtocolHeader; + +// Name of pseudo-header defined for HTTP/2 responses. +QUICHE_EXPORT extern const char* const kHttp2StatusHeader; + +QUICHE_EXPORT size_t GetNumberRequiredContinuationFrames(size_t size); + +// Variant type (i.e. tagged union) that is either a SPDY 3.x priority value, +// or else an HTTP/2 stream dependency tuple {parent stream ID, weight, +// exclusive bit}. Templated to allow for use by QUIC code; SPDY and HTTP/2 +// code should use the concrete type instantiation SpdyStreamPrecedence. +template +class QUICHE_EXPORT StreamPrecedence { + public: + // Constructs instance that is a SPDY 3.x priority. Clamps priority value to + // the valid range [0, 7]. + explicit StreamPrecedence(SpdyPriority priority) + : is_spdy3_priority_(true), + spdy3_priority_(ClampSpdy3Priority(priority)) {} + + // Constructs instance that is an HTTP/2 stream weight, parent stream ID, and + // exclusive bit. Clamps stream weight to the valid range [1, 256]. + StreamPrecedence(StreamIdType parent_id, int weight, bool is_exclusive) + : is_spdy3_priority_(false), + http2_stream_dependency_{parent_id, ClampHttp2Weight(weight), + is_exclusive} {} + + // Intentionally copyable, to support pass by value. + StreamPrecedence(const StreamPrecedence& other) = default; + StreamPrecedence& operator=(const StreamPrecedence& other) = default; + + // Returns true if this instance is a SPDY 3.x priority, or false if this + // instance is an HTTP/2 stream dependency. + bool is_spdy3_priority() const { return is_spdy3_priority_; } + + // Returns SPDY 3.x priority value. If |is_spdy3_priority()| is true, this is + // the value provided at construction, clamped to the legal priority + // range. Otherwise, it is the HTTP/2 stream weight mapped to a SPDY 3.x + // priority value, where minimum weight 1 corresponds to priority 7 (lowest + // precedence) and maximum weight 256 corresponds to priority 0 (highest + // precedence). + SpdyPriority spdy3_priority() const { + return is_spdy3_priority_ + ? spdy3_priority_ + : Http2WeightToSpdy3Priority(http2_stream_dependency_.weight); + } + + // Returns HTTP/2 parent stream ID. If |is_spdy3_priority()| is false, this is + // the value provided at construction, otherwise it is |kHttp2RootStreamId|. + StreamIdType parent_id() const { + return is_spdy3_priority_ ? kHttp2RootStreamId + : http2_stream_dependency_.parent_id; + } + + // Returns HTTP/2 stream weight. If |is_spdy3_priority()| is false, this is + // the value provided at construction, clamped to the legal weight + // range. Otherwise, it is the SPDY 3.x priority value mapped to an HTTP/2 + // stream weight, where priority 0 (i.e. highest precedence) corresponds to + // maximum weight 256 and priority 7 (lowest precedence) corresponds to + // minimum weight 1. + int weight() const { + return is_spdy3_priority_ ? Spdy3PriorityToHttp2Weight(spdy3_priority_) + : http2_stream_dependency_.weight; + } + + // Returns HTTP/2 parent stream exclusivity. If |is_spdy3_priority()| is + // false, this is the value provided at construction, otherwise it is false. + bool is_exclusive() const { + return !is_spdy3_priority_ && http2_stream_dependency_.is_exclusive; + } + + // Facilitates test assertions. + bool operator==(const StreamPrecedence& other) const { + if (is_spdy3_priority()) { + return other.is_spdy3_priority() && + (spdy3_priority() == other.spdy3_priority()); + } else { + return !other.is_spdy3_priority() && (parent_id() == other.parent_id()) && + (weight() == other.weight()) && + (is_exclusive() == other.is_exclusive()); + } + } + + bool operator!=(const StreamPrecedence& other) const { + return !(*this == other); + } + + private: + struct QUICHE_EXPORT Http2StreamDependency { + StreamIdType parent_id; + int weight; + bool is_exclusive; + }; + + bool is_spdy3_priority_; + union { + SpdyPriority spdy3_priority_; + Http2StreamDependency http2_stream_dependency_; + }; +}; + +typedef StreamPrecedence SpdyStreamPrecedence; + +class SpdyFrameVisitor; + +// Intermediate representation for HTTP2 frames. +class QUICHE_EXPORT SpdyFrameIR { + public: + virtual ~SpdyFrameIR() {} + + virtual void Visit(SpdyFrameVisitor* visitor) const = 0; + virtual SpdyFrameType frame_type() const = 0; + SpdyStreamId stream_id() const { return stream_id_; } + virtual bool fin() const; + // Returns an estimate of the size of the serialized frame, without applying + // compression. May not be exact. + virtual size_t size() const = 0; + + // Returns the number of bytes of flow control window that would be consumed + // by this frame if written to the wire. + virtual int flow_control_window_consumed() const; + + protected: + SpdyFrameIR() : stream_id_(0) {} + explicit SpdyFrameIR(SpdyStreamId stream_id) : stream_id_(stream_id) {} + SpdyFrameIR(const SpdyFrameIR&) = delete; + SpdyFrameIR& operator=(const SpdyFrameIR&) = delete; + + private: + SpdyStreamId stream_id_; +}; + +// Abstract class intended to be inherited by IRs that have the option of a FIN +// flag. +class QUICHE_EXPORT SpdyFrameWithFinIR : public SpdyFrameIR { + public: + ~SpdyFrameWithFinIR() override {} + bool fin() const override; + void set_fin(bool fin) { fin_ = fin; } + + protected: + explicit SpdyFrameWithFinIR(SpdyStreamId stream_id) + : SpdyFrameIR(stream_id), fin_(false) {} + SpdyFrameWithFinIR(const SpdyFrameWithFinIR&) = delete; + SpdyFrameWithFinIR& operator=(const SpdyFrameWithFinIR&) = delete; + + private: + bool fin_; +}; + +// Abstract class intended to be inherited by IRs that contain a header +// block. Implies SpdyFrameWithFinIR. +class QUICHE_EXPORT SpdyFrameWithHeaderBlockIR : public SpdyFrameWithFinIR { + public: + ~SpdyFrameWithHeaderBlockIR() override; + + const Http2HeaderBlock& header_block() const { return header_block_; } + void set_header_block(Http2HeaderBlock header_block) { + // Deep copy. + header_block_ = std::move(header_block); + } + void SetHeader(absl::string_view name, absl::string_view value) { + header_block_[name] = value; + } + + protected: + SpdyFrameWithHeaderBlockIR(SpdyStreamId stream_id, + Http2HeaderBlock header_block); + SpdyFrameWithHeaderBlockIR(const SpdyFrameWithHeaderBlockIR&) = delete; + SpdyFrameWithHeaderBlockIR& operator=(const SpdyFrameWithHeaderBlockIR&) = + delete; + + private: + Http2HeaderBlock header_block_; +}; + +class QUICHE_EXPORT SpdyDataIR : public SpdyFrameWithFinIR { + public: + // Performs a deep copy on data. + SpdyDataIR(SpdyStreamId stream_id, absl::string_view data); + + // Performs a deep copy on data. + SpdyDataIR(SpdyStreamId stream_id, const char* data); + + // Moves data into data_store_. Makes a copy if passed a non-movable string. + SpdyDataIR(SpdyStreamId stream_id, std::string data); + + // Use in conjunction with SetDataShallow() for shallow-copy on data. + explicit SpdyDataIR(SpdyStreamId stream_id); + SpdyDataIR(const SpdyDataIR&) = delete; + SpdyDataIR& operator=(const SpdyDataIR&) = delete; + + ~SpdyDataIR() override; + + const char* data() const { return data_; } + size_t data_len() const { return data_len_; } + + bool padded() const { return padded_; } + + int padding_payload_len() const { return padding_payload_len_; } + + void set_padding_len(int padding_len) { + QUICHE_DCHECK_GT(padding_len, 0); + QUICHE_DCHECK_LE(padding_len, kPaddingSizePerFrame); + padded_ = true; + // The pad field takes one octet on the wire. + padding_payload_len_ = padding_len - 1; + } + + // Deep-copy of data (keep private copy). + void SetDataDeep(absl::string_view data) { + data_store_ = std::make_unique(data.data(), data.size()); + data_ = data_store_->data(); + data_len_ = data.size(); + } + + // Shallow-copy of data (do not keep private copy). + void SetDataShallow(absl::string_view data) { + data_store_.reset(); + data_ = data.data(); + data_len_ = data.size(); + } + + // Use this method if we don't have a contiguous buffer and only + // need a length. + void SetDataShallow(size_t len) { + data_store_.reset(); + data_ = nullptr; + data_len_ = len; + } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + int flow_control_window_consumed() const override; + + size_t size() const override; + + private: + // Used to store data that this SpdyDataIR should own. + std::unique_ptr data_store_; + const char* data_; + size_t data_len_; + + bool padded_; + // padding_payload_len_ = desired padding length - len(padding length field). + int padding_payload_len_; +}; + +class QUICHE_EXPORT SpdyRstStreamIR : public SpdyFrameIR { + public: + SpdyRstStreamIR(SpdyStreamId stream_id, SpdyErrorCode error_code); + SpdyRstStreamIR(const SpdyRstStreamIR&) = delete; + SpdyRstStreamIR& operator=(const SpdyRstStreamIR&) = delete; + + ~SpdyRstStreamIR() override; + + SpdyErrorCode error_code() const { return error_code_; } + void set_error_code(SpdyErrorCode error_code) { error_code_ = error_code; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + SpdyErrorCode error_code_; +}; + +class QUICHE_EXPORT SpdySettingsIR : public SpdyFrameIR { + public: + SpdySettingsIR(); + SpdySettingsIR(const SpdySettingsIR&) = delete; + SpdySettingsIR& operator=(const SpdySettingsIR&) = delete; + ~SpdySettingsIR() override; + + // Overwrites as appropriate. + const SettingsMap& values() const { return values_; } + void AddSetting(SpdySettingsId id, int32_t value) { values_[id] = value; } + + bool is_ack() const { return is_ack_; } + void set_is_ack(bool is_ack) { is_ack_ = is_ack; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + SettingsMap values_; + bool is_ack_; +}; + +class QUICHE_EXPORT SpdyPingIR : public SpdyFrameIR { + public: + explicit SpdyPingIR(SpdyPingId id) : id_(id), is_ack_(false) {} + SpdyPingIR(const SpdyPingIR&) = delete; + SpdyPingIR& operator=(const SpdyPingIR&) = delete; + SpdyPingId id() const { return id_; } + + bool is_ack() const { return is_ack_; } + void set_is_ack(bool is_ack) { is_ack_ = is_ack; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + SpdyPingId id_; + bool is_ack_; +}; + +class QUICHE_EXPORT SpdyGoAwayIR : public SpdyFrameIR { + public: + // References description, doesn't copy it, so description must outlast + // this SpdyGoAwayIR. + SpdyGoAwayIR(SpdyStreamId last_good_stream_id, SpdyErrorCode error_code, + absl::string_view description); + + // References description, doesn't copy it, so description must outlast + // this SpdyGoAwayIR. + SpdyGoAwayIR(SpdyStreamId last_good_stream_id, SpdyErrorCode error_code, + const char* description); + + // Moves description into description_store_, so caller doesn't need to + // keep description live after constructing this SpdyGoAwayIR. + SpdyGoAwayIR(SpdyStreamId last_good_stream_id, SpdyErrorCode error_code, + std::string description); + SpdyGoAwayIR(const SpdyGoAwayIR&) = delete; + SpdyGoAwayIR& operator=(const SpdyGoAwayIR&) = delete; + + ~SpdyGoAwayIR() override; + + SpdyStreamId last_good_stream_id() const { return last_good_stream_id_; } + void set_last_good_stream_id(SpdyStreamId last_good_stream_id) { + QUICHE_DCHECK_EQ(0u, last_good_stream_id & ~kStreamIdMask); + last_good_stream_id_ = last_good_stream_id; + } + SpdyErrorCode error_code() const { return error_code_; } + void set_error_code(SpdyErrorCode error_code) { + // TODO(hkhalil): Check valid ranges of error_code? + error_code_ = error_code; + } + + const absl::string_view& description() const { return description_; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + SpdyStreamId last_good_stream_id_; + SpdyErrorCode error_code_; + const std::string description_store_; + const absl::string_view description_; +}; + +class QUICHE_EXPORT SpdyHeadersIR : public SpdyFrameWithHeaderBlockIR { + public: + explicit SpdyHeadersIR(SpdyStreamId stream_id) + : SpdyHeadersIR(stream_id, Http2HeaderBlock()) {} + SpdyHeadersIR(SpdyStreamId stream_id, Http2HeaderBlock header_block) + : SpdyFrameWithHeaderBlockIR(stream_id, std::move(header_block)) {} + SpdyHeadersIR(const SpdyHeadersIR&) = delete; + SpdyHeadersIR& operator=(const SpdyHeadersIR&) = delete; + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + bool has_priority() const { return has_priority_; } + void set_has_priority(bool has_priority) { has_priority_ = has_priority; } + int weight() const { return weight_; } + void set_weight(int weight) { weight_ = weight; } + SpdyStreamId parent_stream_id() const { return parent_stream_id_; } + void set_parent_stream_id(SpdyStreamId id) { parent_stream_id_ = id; } + bool exclusive() const { return exclusive_; } + void set_exclusive(bool exclusive) { exclusive_ = exclusive; } + bool padded() const { return padded_; } + int padding_payload_len() const { return padding_payload_len_; } + void set_padding_len(int padding_len) { + QUICHE_DCHECK_GT(padding_len, 0); + QUICHE_DCHECK_LE(padding_len, kPaddingSizePerFrame); + padded_ = true; + // The pad field takes one octet on the wire. + padding_payload_len_ = padding_len - 1; + } + + private: + bool has_priority_ = false; + int weight_ = kHttp2DefaultStreamWeight; + SpdyStreamId parent_stream_id_ = 0; + bool exclusive_ = false; + bool padded_ = false; + int padding_payload_len_ = 0; +}; + +class QUICHE_EXPORT SpdyWindowUpdateIR : public SpdyFrameIR { + public: + SpdyWindowUpdateIR(SpdyStreamId stream_id, int32_t delta) + : SpdyFrameIR(stream_id) { + set_delta(delta); + } + SpdyWindowUpdateIR(const SpdyWindowUpdateIR&) = delete; + SpdyWindowUpdateIR& operator=(const SpdyWindowUpdateIR&) = delete; + + int32_t delta() const { return delta_; } + void set_delta(int32_t delta) { + QUICHE_DCHECK_LE(0, delta); + QUICHE_DCHECK_LE(delta, kSpdyMaximumWindowSize); + delta_ = delta; + } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + int32_t delta_; +}; + +class QUICHE_EXPORT SpdyPushPromiseIR : public SpdyFrameWithHeaderBlockIR { + public: + SpdyPushPromiseIR(SpdyStreamId stream_id, SpdyStreamId promised_stream_id) + : SpdyPushPromiseIR(stream_id, promised_stream_id, Http2HeaderBlock()) {} + SpdyPushPromiseIR(SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + Http2HeaderBlock header_block) + : SpdyFrameWithHeaderBlockIR(stream_id, std::move(header_block)), + promised_stream_id_(promised_stream_id), + padded_(false), + padding_payload_len_(0) {} + SpdyPushPromiseIR(const SpdyPushPromiseIR&) = delete; + SpdyPushPromiseIR& operator=(const SpdyPushPromiseIR&) = delete; + SpdyStreamId promised_stream_id() const { return promised_stream_id_; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + bool padded() const { return padded_; } + int padding_payload_len() const { return padding_payload_len_; } + void set_padding_len(int padding_len) { + QUICHE_DCHECK_GT(padding_len, 0); + QUICHE_DCHECK_LE(padding_len, kPaddingSizePerFrame); + padded_ = true; + // The pad field takes one octet on the wire. + padding_payload_len_ = padding_len - 1; + } + + private: + SpdyStreamId promised_stream_id_; + + bool padded_; + int padding_payload_len_; +}; + +class QUICHE_EXPORT SpdyContinuationIR : public SpdyFrameIR { + public: + explicit SpdyContinuationIR(SpdyStreamId stream_id); + SpdyContinuationIR(const SpdyContinuationIR&) = delete; + SpdyContinuationIR& operator=(const SpdyContinuationIR&) = delete; + ~SpdyContinuationIR() override; + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + bool end_headers() const { return end_headers_; } + void set_end_headers(bool end_headers) { end_headers_ = end_headers; } + const std::string& encoding() const { return encoding_; } + void take_encoding(std::string encoding) { encoding_ = std::move(encoding); } + size_t size() const override; + + private: + std::string encoding_; + bool end_headers_; +}; + +class QUICHE_EXPORT SpdyAltSvcIR : public SpdyFrameIR { + public: + explicit SpdyAltSvcIR(SpdyStreamId stream_id); + SpdyAltSvcIR(const SpdyAltSvcIR&) = delete; + SpdyAltSvcIR& operator=(const SpdyAltSvcIR&) = delete; + ~SpdyAltSvcIR() override; + + std::string origin() const { return origin_; } + const SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector() const { + return altsvc_vector_; + } + + void set_origin(std::string origin) { origin_ = std::move(origin); } + void add_altsvc(const SpdyAltSvcWireFormat::AlternativeService& altsvc) { + altsvc_vector_.push_back(altsvc); + } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + std::string origin_; + SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector_; +}; + +class QUICHE_EXPORT SpdyPriorityIR : public SpdyFrameIR { + public: + SpdyPriorityIR(SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive) + : SpdyFrameIR(stream_id), + parent_stream_id_(parent_stream_id), + weight_(weight), + exclusive_(exclusive) {} + SpdyPriorityIR(const SpdyPriorityIR&) = delete; + SpdyPriorityIR& operator=(const SpdyPriorityIR&) = delete; + SpdyStreamId parent_stream_id() const { return parent_stream_id_; } + int weight() const { return weight_; } + bool exclusive() const { return exclusive_; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + SpdyStreamId parent_stream_id_; + int weight_; + bool exclusive_; +}; + +class QUICHE_EXPORT SpdyPriorityUpdateIR : public SpdyFrameIR { + public: + SpdyPriorityUpdateIR(SpdyStreamId stream_id, + SpdyStreamId prioritized_stream_id, + std::string priority_field_value) + : SpdyFrameIR(stream_id), + prioritized_stream_id_(prioritized_stream_id), + priority_field_value_(std::move(priority_field_value)) {} + SpdyPriorityUpdateIR(const SpdyPriorityUpdateIR&) = delete; + SpdyPriorityUpdateIR& operator=(const SpdyPriorityUpdateIR&) = delete; + SpdyStreamId prioritized_stream_id() const { return prioritized_stream_id_; } + const std::string& priority_field_value() const { + return priority_field_value_; + } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + private: + SpdyStreamId prioritized_stream_id_; + std::string priority_field_value_; +}; + +struct QUICHE_EXPORT AcceptChOriginValuePair { + std::string origin; + std::string value; + bool operator==(const AcceptChOriginValuePair& rhs) const { + return origin == rhs.origin && value == rhs.value; + } +}; + +class QUICHE_EXPORT SpdyAcceptChIR : public SpdyFrameIR { + public: + SpdyAcceptChIR(std::vector entries) + : entries_(std::move(entries)) {} + SpdyAcceptChIR(const SpdyAcceptChIR&) = delete; + SpdyAcceptChIR& operator=(const SpdyAcceptChIR&) = delete; + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + size_t size() const override; + + const std::vector& entries() const { + return entries_; + } + + private: + std::vector entries_; +}; + +// Represents a frame of unrecognized type. +class QUICHE_EXPORT SpdyUnknownIR : public SpdyFrameIR { + public: + SpdyUnknownIR(SpdyStreamId stream_id, uint8_t type, uint8_t flags, + std::string payload) + : SpdyFrameIR(stream_id), + type_(type), + flags_(flags), + length_(payload.size()), + payload_(std::move(payload)) {} + SpdyUnknownIR(const SpdyUnknownIR&) = delete; + SpdyUnknownIR& operator=(const SpdyUnknownIR&) = delete; + uint8_t type() const { return type_; } + uint8_t flags() const { return flags_; } + size_t length() const { return length_; } + const std::string& payload() const { return payload_; } + + void Visit(SpdyFrameVisitor* visitor) const override; + + SpdyFrameType frame_type() const override; + + int flow_control_window_consumed() const override; + + size_t size() const override; + + protected: + // Allows subclasses to overwrite the default payload length. + void set_length(size_t length) { length_ = length; } + + private: + uint8_t type_; + uint8_t flags_; + size_t length_; + const std::string payload_; +}; + +class QUICHE_EXPORT SpdySerializedFrame { + public: + SpdySerializedFrame() + : frame_(const_cast("")), size_(0), owns_buffer_(false) {} + + // Create a valid SpdySerializedFrame using a pre-created buffer. + // If |owns_buffer| is true, this class takes ownership of the buffer and will + // delete it on cleanup. The buffer must have been created using new char[]. + // If |owns_buffer| is false, the caller retains ownership of the buffer and + // is responsible for making sure the buffer outlives this frame. In other + // words, this class does NOT create a copy of the buffer. + SpdySerializedFrame(char* data, size_t size, bool owns_buffer) + : frame_(data), size_(size), owns_buffer_(owns_buffer) {} + + SpdySerializedFrame(SpdySerializedFrame&& other) + : frame_(other.frame_), + size_(other.size_), + owns_buffer_(other.owns_buffer_) { + // |other| is no longer responsible for the buffer. + other.owns_buffer_ = false; + } + SpdySerializedFrame(const SpdySerializedFrame&) = delete; + SpdySerializedFrame& operator=(const SpdySerializedFrame&) = delete; + + SpdySerializedFrame& operator=(SpdySerializedFrame&& other) { + // Free buffer if necessary. + if (owns_buffer_) { + delete[] frame_; + } + // Take over |other|. + frame_ = other.frame_; + size_ = other.size_; + owns_buffer_ = other.owns_buffer_; + // |other| is no longer responsible for the buffer. + other.owns_buffer_ = false; + return *this; + } + + ~SpdySerializedFrame() { + if (owns_buffer_) { + delete[] frame_; + } + } + + // Provides access to the frame bytes, which is a buffer containing the frame + // packed as expected for sending over the wire. + char* data() const { return frame_; } + + // Returns the actual size of the underlying buffer. + size_t size() const { return size_; } + + operator absl::string_view() const { + return absl::string_view{frame_, size_}; + } + + operator std::string() const { return std::string{frame_, size_}; } + + // Returns a buffer containing the contents of the frame, of which the caller + // takes ownership, and clears this SpdySerializedFrame. + char* ReleaseBuffer() { + char* buffer; + if (owns_buffer_) { + // If the buffer is owned, relinquish ownership to the caller. + buffer = frame_; + owns_buffer_ = false; + } else { + // Otherwise, we need to make a copy to give to the caller. + buffer = new char[size_]; + memcpy(buffer, frame_, size_); + } + *this = SpdySerializedFrame(); + return buffer; + } + + protected: + char* frame_; + + private: + size_t size_; + bool owns_buffer_; +}; + +// This interface is for classes that want to process SpdyFrameIRs without +// having to know what type they are. An instance of this interface can be +// passed to a SpdyFrameIR's Visit method, and the appropriate type-specific +// method of this class will be called. +class QUICHE_EXPORT SpdyFrameVisitor { + public: + SpdyFrameVisitor() {} + SpdyFrameVisitor(const SpdyFrameVisitor&) = delete; + SpdyFrameVisitor& operator=(const SpdyFrameVisitor&) = delete; + virtual ~SpdyFrameVisitor() {} + + virtual void VisitRstStream(const SpdyRstStreamIR& rst_stream) = 0; + virtual void VisitSettings(const SpdySettingsIR& settings) = 0; + virtual void VisitPing(const SpdyPingIR& ping) = 0; + virtual void VisitGoAway(const SpdyGoAwayIR& goaway) = 0; + virtual void VisitHeaders(const SpdyHeadersIR& headers) = 0; + virtual void VisitWindowUpdate(const SpdyWindowUpdateIR& window_update) = 0; + virtual void VisitPushPromise(const SpdyPushPromiseIR& push_promise) = 0; + virtual void VisitContinuation(const SpdyContinuationIR& continuation) = 0; + virtual void VisitAltSvc(const SpdyAltSvcIR& altsvc) = 0; + virtual void VisitPriority(const SpdyPriorityIR& priority) = 0; + virtual void VisitData(const SpdyDataIR& data) = 0; + virtual void VisitPriorityUpdate( + const SpdyPriorityUpdateIR& priority_update) = 0; + virtual void VisitAcceptCh(const SpdyAcceptChIR& accept_ch) = 0; + virtual void VisitUnknown(const SpdyUnknownIR& /*unknown*/) { + // TODO(birenroy): make abstract. + } +}; + +// Optionally, and in addition to SpdyFramerVisitorInterface, a class supporting +// SpdyFramerDebugVisitorInterface may be used in conjunction with SpdyFramer in +// order to extract debug/internal information about the SpdyFramer as it +// operates. +// +// Most HTTP2 implementations need not bother with this interface at all. +class QUICHE_EXPORT SpdyFramerDebugVisitorInterface { + public: + virtual ~SpdyFramerDebugVisitorInterface() {} + + // Called after compressing a frame with a payload of + // a list of name-value pairs. + // |payload_len| is the uncompressed payload size. + // |frame_len| is the compressed frame size. + virtual void OnSendCompressedFrame(SpdyStreamId /*stream_id*/, + SpdyFrameType /*type*/, + size_t /*payload_len*/, + size_t /*frame_len*/) {} + + // Called when a frame containing a compressed payload of + // name-value pairs is received. + // |frame_len| is the compressed frame size. + virtual void OnReceiveCompressedFrame(SpdyStreamId /*stream_id*/, + SpdyFrameType /*type*/, + size_t /*frame_len*/) {} +}; + +// Calculates the number of bytes required to serialize a SpdyHeadersIR, not +// including the bytes to be used for the encoded header set. +size_t GetHeaderFrameSizeSansBlock(const SpdyHeadersIR& header_ir); + +// Calculates the number of bytes required to serialize a SpdyPushPromiseIR, +// not including the bytes to be used for the encoded header set. +size_t GetPushPromiseFrameSizeSansBlock( + const SpdyPushPromiseIR& push_promise_ir); + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_PROTOCOL_H_ diff --git a/quiche/spdy/core/spdy_protocol_test.cc b/quiche/spdy/core/spdy_protocol_test.cc new file mode 100644 index 000000000000..4602e7f70967 --- /dev/null +++ b/quiche/spdy/core/spdy_protocol_test.cc @@ -0,0 +1,275 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_protocol.h" + +#include +#include +#include + +#include "quiche/common/platform/api/quiche_expect_bug.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/spdy_bitmasks.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +namespace spdy { + +std::ostream& operator<<(std::ostream& os, + const SpdyStreamPrecedence precedence) { + if (precedence.is_spdy3_priority()) { + os << "SpdyStreamPrecedence[spdy3_priority=" << precedence.spdy3_priority() + << "]"; + } else { + os << "SpdyStreamPrecedence[parent_id=" << precedence.parent_id() + << ", weight=" << precedence.weight() + << ", is_exclusive=" << precedence.is_exclusive() << "]"; + } + return os; +} + +namespace test { + +TEST(SpdyProtocolTest, ClampSpdy3Priority) { + EXPECT_QUICHE_BUG(EXPECT_EQ(7, ClampSpdy3Priority(8)), "Invalid priority: 8"); + EXPECT_EQ(kV3LowestPriority, ClampSpdy3Priority(kV3LowestPriority)); + EXPECT_EQ(kV3HighestPriority, ClampSpdy3Priority(kV3HighestPriority)); +} + +TEST(SpdyProtocolTest, ClampHttp2Weight) { + EXPECT_QUICHE_BUG(EXPECT_EQ(kHttp2MinStreamWeight, ClampHttp2Weight(0)), + "Invalid weight: 0"); + EXPECT_QUICHE_BUG(EXPECT_EQ(kHttp2MaxStreamWeight, ClampHttp2Weight(300)), + "Invalid weight: 300"); + EXPECT_EQ(kHttp2MinStreamWeight, ClampHttp2Weight(kHttp2MinStreamWeight)); + EXPECT_EQ(kHttp2MaxStreamWeight, ClampHttp2Weight(kHttp2MaxStreamWeight)); +} + +TEST(SpdyProtocolTest, Spdy3PriorityToHttp2Weight) { + EXPECT_EQ(256, Spdy3PriorityToHttp2Weight(0)); + EXPECT_EQ(220, Spdy3PriorityToHttp2Weight(1)); + EXPECT_EQ(183, Spdy3PriorityToHttp2Weight(2)); + EXPECT_EQ(147, Spdy3PriorityToHttp2Weight(3)); + EXPECT_EQ(110, Spdy3PriorityToHttp2Weight(4)); + EXPECT_EQ(74, Spdy3PriorityToHttp2Weight(5)); + EXPECT_EQ(37, Spdy3PriorityToHttp2Weight(6)); + EXPECT_EQ(1, Spdy3PriorityToHttp2Weight(7)); +} + +TEST(SpdyProtocolTest, Http2WeightToSpdy3Priority) { + EXPECT_EQ(0u, Http2WeightToSpdy3Priority(256)); + EXPECT_EQ(0u, Http2WeightToSpdy3Priority(221)); + EXPECT_EQ(1u, Http2WeightToSpdy3Priority(220)); + EXPECT_EQ(1u, Http2WeightToSpdy3Priority(184)); + EXPECT_EQ(2u, Http2WeightToSpdy3Priority(183)); + EXPECT_EQ(2u, Http2WeightToSpdy3Priority(148)); + EXPECT_EQ(3u, Http2WeightToSpdy3Priority(147)); + EXPECT_EQ(3u, Http2WeightToSpdy3Priority(111)); + EXPECT_EQ(4u, Http2WeightToSpdy3Priority(110)); + EXPECT_EQ(4u, Http2WeightToSpdy3Priority(75)); + EXPECT_EQ(5u, Http2WeightToSpdy3Priority(74)); + EXPECT_EQ(5u, Http2WeightToSpdy3Priority(38)); + EXPECT_EQ(6u, Http2WeightToSpdy3Priority(37)); + EXPECT_EQ(6u, Http2WeightToSpdy3Priority(2)); + EXPECT_EQ(7u, Http2WeightToSpdy3Priority(1)); +} + +TEST(SpdyProtocolTest, IsValidHTTP2FrameStreamId) { + // Stream-specific frames must have non-zero stream ids + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::DATA)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::DATA)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::HEADERS)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::HEADERS)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::PRIORITY)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::PRIORITY)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::RST_STREAM)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::RST_STREAM)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::CONTINUATION)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::CONTINUATION)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::PUSH_PROMISE)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::PUSH_PROMISE)); + + // Connection-level frames must have zero stream ids + EXPECT_FALSE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::GOAWAY)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::GOAWAY)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::SETTINGS)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::SETTINGS)); + EXPECT_FALSE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::PING)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::PING)); + + // Frames that are neither stream-specific nor connection-level + // should not have their stream id declared invalid + EXPECT_TRUE(IsValidHTTP2FrameStreamId(1, SpdyFrameType::WINDOW_UPDATE)); + EXPECT_TRUE(IsValidHTTP2FrameStreamId(0, SpdyFrameType::WINDOW_UPDATE)); +} + +TEST(SpdyProtocolTest, ParseSettingsId) { + SpdyKnownSettingsId setting_id; + EXPECT_FALSE(ParseSettingsId(0, &setting_id)); + EXPECT_TRUE(ParseSettingsId(1, &setting_id)); + EXPECT_EQ(SETTINGS_HEADER_TABLE_SIZE, setting_id); + EXPECT_TRUE(ParseSettingsId(2, &setting_id)); + EXPECT_EQ(SETTINGS_ENABLE_PUSH, setting_id); + EXPECT_TRUE(ParseSettingsId(3, &setting_id)); + EXPECT_EQ(SETTINGS_MAX_CONCURRENT_STREAMS, setting_id); + EXPECT_TRUE(ParseSettingsId(4, &setting_id)); + EXPECT_EQ(SETTINGS_INITIAL_WINDOW_SIZE, setting_id); + EXPECT_TRUE(ParseSettingsId(5, &setting_id)); + EXPECT_EQ(SETTINGS_MAX_FRAME_SIZE, setting_id); + EXPECT_TRUE(ParseSettingsId(6, &setting_id)); + EXPECT_EQ(SETTINGS_MAX_HEADER_LIST_SIZE, setting_id); + EXPECT_FALSE(ParseSettingsId(7, &setting_id)); + EXPECT_TRUE(ParseSettingsId(8, &setting_id)); + EXPECT_EQ(SETTINGS_ENABLE_CONNECT_PROTOCOL, setting_id); + EXPECT_TRUE(ParseSettingsId(9, &setting_id)); + EXPECT_EQ(SETTINGS_DEPRECATE_HTTP2_PRIORITIES, setting_id); + EXPECT_FALSE(ParseSettingsId(10, &setting_id)); + EXPECT_FALSE(ParseSettingsId(0xFF44, &setting_id)); + EXPECT_TRUE(ParseSettingsId(0xFF45, &setting_id)); + EXPECT_EQ(SETTINGS_EXPERIMENT_SCHEDULER, setting_id); + EXPECT_FALSE(ParseSettingsId(0xFF46, &setting_id)); +} + +TEST(SpdyProtocolTest, SettingsIdToString) { + struct { + SpdySettingsId setting_id; + const std::string expected_string; + } test_cases[] = { + {0, "SETTINGS_UNKNOWN_0"}, + {SETTINGS_HEADER_TABLE_SIZE, "SETTINGS_HEADER_TABLE_SIZE"}, + {SETTINGS_ENABLE_PUSH, "SETTINGS_ENABLE_PUSH"}, + {SETTINGS_MAX_CONCURRENT_STREAMS, "SETTINGS_MAX_CONCURRENT_STREAMS"}, + {SETTINGS_INITIAL_WINDOW_SIZE, "SETTINGS_INITIAL_WINDOW_SIZE"}, + {SETTINGS_MAX_FRAME_SIZE, "SETTINGS_MAX_FRAME_SIZE"}, + {SETTINGS_MAX_HEADER_LIST_SIZE, "SETTINGS_MAX_HEADER_LIST_SIZE"}, + {7, "SETTINGS_UNKNOWN_7"}, + {SETTINGS_ENABLE_CONNECT_PROTOCOL, "SETTINGS_ENABLE_CONNECT_PROTOCOL"}, + {SETTINGS_DEPRECATE_HTTP2_PRIORITIES, + "SETTINGS_DEPRECATE_HTTP2_PRIORITIES"}, + {0xa, "SETTINGS_UNKNOWN_a"}, + {0xFF44, "SETTINGS_UNKNOWN_ff44"}, + {0xFF45, "SETTINGS_EXPERIMENT_SCHEDULER"}, + {0xFF46, "SETTINGS_UNKNOWN_ff46"}}; + for (auto test_case : test_cases) { + EXPECT_EQ(test_case.expected_string, + SettingsIdToString(test_case.setting_id)); + } +} + +TEST(SpdyStreamPrecedenceTest, Basic) { + SpdyStreamPrecedence spdy3_prec(2); + EXPECT_TRUE(spdy3_prec.is_spdy3_priority()); + EXPECT_EQ(2, spdy3_prec.spdy3_priority()); + EXPECT_EQ(kHttp2RootStreamId, spdy3_prec.parent_id()); + EXPECT_EQ(Spdy3PriorityToHttp2Weight(2), spdy3_prec.weight()); + EXPECT_FALSE(spdy3_prec.is_exclusive()); + + for (bool is_exclusive : {true, false}) { + SpdyStreamPrecedence h2_prec(7, 123, is_exclusive); + EXPECT_FALSE(h2_prec.is_spdy3_priority()); + EXPECT_EQ(Http2WeightToSpdy3Priority(123), h2_prec.spdy3_priority()); + EXPECT_EQ(7u, h2_prec.parent_id()); + EXPECT_EQ(123, h2_prec.weight()); + EXPECT_EQ(is_exclusive, h2_prec.is_exclusive()); + } +} + +TEST(SpdyStreamPrecedenceTest, Clamping) { + EXPECT_QUICHE_BUG(EXPECT_EQ(7, SpdyStreamPrecedence(8).spdy3_priority()), + "Invalid priority: 8"); + EXPECT_QUICHE_BUG(EXPECT_EQ(kHttp2MinStreamWeight, + SpdyStreamPrecedence(3, 0, false).weight()), + "Invalid weight: 0"); + EXPECT_QUICHE_BUG(EXPECT_EQ(kHttp2MaxStreamWeight, + SpdyStreamPrecedence(3, 300, false).weight()), + "Invalid weight: 300"); +} + +TEST(SpdyStreamPrecedenceTest, Copying) { + SpdyStreamPrecedence prec1(3); + SpdyStreamPrecedence copy1(prec1); + EXPECT_TRUE(copy1.is_spdy3_priority()); + EXPECT_EQ(3, copy1.spdy3_priority()); + + SpdyStreamPrecedence prec2(4, 5, true); + SpdyStreamPrecedence copy2(prec2); + EXPECT_FALSE(copy2.is_spdy3_priority()); + EXPECT_EQ(4u, copy2.parent_id()); + EXPECT_EQ(5, copy2.weight()); + EXPECT_TRUE(copy2.is_exclusive()); + + copy1 = prec2; + EXPECT_FALSE(copy1.is_spdy3_priority()); + EXPECT_EQ(4u, copy1.parent_id()); + EXPECT_EQ(5, copy1.weight()); + EXPECT_TRUE(copy1.is_exclusive()); + + copy2 = prec1; + EXPECT_TRUE(copy2.is_spdy3_priority()); + EXPECT_EQ(3, copy2.spdy3_priority()); +} + +TEST(SpdyStreamPrecedenceTest, Equals) { + EXPECT_EQ(SpdyStreamPrecedence(3), SpdyStreamPrecedence(3)); + EXPECT_NE(SpdyStreamPrecedence(3), SpdyStreamPrecedence(4)); + + EXPECT_EQ(SpdyStreamPrecedence(1, 2, false), + SpdyStreamPrecedence(1, 2, false)); + EXPECT_NE(SpdyStreamPrecedence(1, 2, false), + SpdyStreamPrecedence(2, 2, false)); + EXPECT_NE(SpdyStreamPrecedence(1, 2, false), + SpdyStreamPrecedence(1, 3, false)); + EXPECT_NE(SpdyStreamPrecedence(1, 2, false), + SpdyStreamPrecedence(1, 2, true)); + + SpdyStreamPrecedence spdy3_prec(3); + SpdyStreamPrecedence h2_prec(spdy3_prec.parent_id(), spdy3_prec.weight(), + spdy3_prec.is_exclusive()); + EXPECT_NE(spdy3_prec, h2_prec); +} + +TEST(SpdyDataIRTest, Construct) { + // Confirm that it makes a string of zero length from a + // absl::string_view(nullptr). + absl::string_view s1; + SpdyDataIR d1(/* stream_id = */ 1, s1); + EXPECT_EQ(0u, d1.data_len()); + EXPECT_NE(nullptr, d1.data()); + + // Confirms makes a copy of char array. + const char s2[] = "something"; + SpdyDataIR d2(/* stream_id = */ 2, s2); + EXPECT_EQ(absl::string_view(d2.data(), d2.data_len()), s2); + EXPECT_NE(absl::string_view(d1.data(), d1.data_len()), s2); + EXPECT_EQ((int)d1.data_len(), d1.flow_control_window_consumed()); + + // Confirm copies a const string. + const std::string foo = "foo"; + SpdyDataIR d3(/* stream_id = */ 3, foo); + EXPECT_EQ(foo, d3.data()); + EXPECT_EQ((int)d3.data_len(), d3.flow_control_window_consumed()); + + // Confirm copies a non-const string. + std::string bar = "bar"; + SpdyDataIR d4(/* stream_id = */ 4, bar); + EXPECT_EQ("bar", bar); + EXPECT_EQ("bar", absl::string_view(d4.data(), d4.data_len())); + + // Confirm moves an rvalue reference. Note that the test string "baz" is too + // short to trigger the move optimization, and instead a copy occurs. + std::string baz = "the quick brown fox"; + SpdyDataIR d5(/* stream_id = */ 5, std::move(baz)); + EXPECT_EQ("", baz); + EXPECT_EQ(absl::string_view(d5.data(), d5.data_len()), "the quick brown fox"); + + // Confirms makes a copy of string literal. + SpdyDataIR d7(/* stream_id = */ 7, "something else"); + EXPECT_EQ(absl::string_view(d7.data(), d7.data_len()), "something else"); + + SpdyDataIR d8(/* stream_id = */ 8, "shawarma"); + d8.set_padding_len(20); + EXPECT_EQ(28, d8.flow_control_window_consumed()); +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/core/spdy_simple_arena.cc b/quiche/spdy/core/spdy_simple_arena.cc new file mode 100644 index 000000000000..0b3b7985214b --- /dev/null +++ b/quiche/spdy/core/spdy_simple_arena.cc @@ -0,0 +1,106 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_simple_arena.h" + +#include +#include + +#include "quiche/common/platform/api/quiche_logging.h" + +namespace spdy { + +SpdySimpleArena::SpdySimpleArena(size_t block_size) : block_size_(block_size) {} + +SpdySimpleArena::~SpdySimpleArena() = default; + +SpdySimpleArena::SpdySimpleArena(SpdySimpleArena&& other) = default; +SpdySimpleArena& SpdySimpleArena::operator=(SpdySimpleArena&& other) = default; + +char* SpdySimpleArena::Alloc(size_t size) { + Reserve(size); + Block& b = blocks_.back(); + QUICHE_DCHECK_GE(b.size, b.used + size); + char* out = b.data.get() + b.used; + b.used += size; + return out; +} + +char* SpdySimpleArena::Realloc(char* original, size_t oldsize, size_t newsize) { + QUICHE_DCHECK(!blocks_.empty()); + Block& last = blocks_.back(); + if (last.data.get() <= original && original < last.data.get() + last.size) { + // (original, oldsize) is in the last Block. + QUICHE_DCHECK_GE(last.data.get() + last.used, original + oldsize); + if (original + oldsize == last.data.get() + last.used) { + // (original, oldsize) was the most recent allocation, + if (original + newsize < last.data.get() + last.size) { + // (original, newsize) fits in the same Block. + last.used += newsize - oldsize; + return original; + } + } + } + char* out = Alloc(newsize); + memcpy(out, original, oldsize); + return out; +} + +char* SpdySimpleArena::Memdup(const char* data, size_t size) { + char* out = Alloc(size); + memcpy(out, data, size); + return out; +} + +void SpdySimpleArena::Free(char* data, size_t size) { + if (blocks_.empty()) { + return; + } + Block& b = blocks_.back(); + if (size <= b.used && data + size == b.data.get() + b.used) { + // The memory region passed by the caller was the most recent allocation + // from the final block in this arena. + b.used -= size; + } +} + +void SpdySimpleArena::Reset() { + blocks_.clear(); + status_.bytes_allocated_ = 0; +} + +void SpdySimpleArena::Reserve(size_t additional_space) { + if (blocks_.empty()) { + AllocBlock(std::max(additional_space, block_size_)); + } else { + const Block& last = blocks_.back(); + if (last.size < last.used + additional_space) { + AllocBlock(std::max(additional_space, block_size_)); + } + } +} + +void SpdySimpleArena::AllocBlock(size_t size) { + blocks_.push_back(Block(size)); + status_.bytes_allocated_ += size; +} + +SpdySimpleArena::Block::Block(size_t s) : data(new char[s]), size(s), used(0) {} + +SpdySimpleArena::Block::~Block() = default; + +SpdySimpleArena::Block::Block(SpdySimpleArena::Block&& other) + : size(other.size), used(other.used) { + data = std::move(other.data); +} + +SpdySimpleArena::Block& SpdySimpleArena::Block::operator=( + SpdySimpleArena::Block&& other) { + size = other.size; + used = other.used; + data = std::move(other.data); + return *this; +} + +} // namespace spdy diff --git a/quiche/spdy/core/spdy_simple_arena.h b/quiche/spdy/core/spdy_simple_arena.h new file mode 100644 index 000000000000..dae6879d835f --- /dev/null +++ b/quiche/spdy/core/spdy_simple_arena.h @@ -0,0 +1,77 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_SPDY_SIMPLE_ARENA_H_ +#define QUICHE_SPDY_CORE_SPDY_SIMPLE_ARENA_H_ + +#include +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +// Allocates large blocks of memory, and doles them out in smaller chunks. +// Not thread-safe. +class QUICHE_EXPORT SpdySimpleArena { + public: + class QUICHE_EXPORT Status { + private: + friend class SpdySimpleArena; + size_t bytes_allocated_; + + public: + Status() : bytes_allocated_(0) {} + size_t bytes_allocated() const { return bytes_allocated_; } + }; + + // Blocks allocated by this arena will be at least |block_size| bytes. + explicit SpdySimpleArena(size_t block_size); + ~SpdySimpleArena(); + + // Copy and assign are not allowed. + SpdySimpleArena() = delete; + SpdySimpleArena(const SpdySimpleArena&) = delete; + SpdySimpleArena& operator=(const SpdySimpleArena&) = delete; + + // Move is allowed. + SpdySimpleArena(SpdySimpleArena&& other); + SpdySimpleArena& operator=(SpdySimpleArena&& other); + + char* Alloc(size_t size); + char* Realloc(char* original, size_t oldsize, size_t newsize); + char* Memdup(const char* data, size_t size); + + // If |data| and |size| describe the most recent allocation made from this + // arena, the memory is reclaimed. Otherwise, this method is a no-op. + void Free(char* data, size_t size); + + void Reset(); + + Status status() const { return status_; } + + private: + struct QUICHE_EXPORT Block { + std::unique_ptr data; + size_t size = 0; + size_t used = 0; + + explicit Block(size_t s); + ~Block(); + + Block(Block&& other); + Block& operator=(Block&& other); + }; + + void Reserve(size_t additional_space); + void AllocBlock(size_t size); + + size_t block_size_; + std::vector blocks_; + Status status_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_SPDY_SIMPLE_ARENA_H_ diff --git a/quiche/spdy/core/spdy_simple_arena_test.cc b/quiche/spdy/core/spdy_simple_arena_test.cc new file mode 100644 index 000000000000..9708375374be --- /dev/null +++ b/quiche/spdy/core/spdy_simple_arena_test.cc @@ -0,0 +1,141 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/core/spdy_simple_arena.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" + +namespace spdy { +namespace { + +size_t kDefaultBlockSize = 2048; +const char kTestString[] = "This is a decently long test string."; + +TEST(SpdySimpleArenaTest, NoAllocationOnConstruction) { + SpdySimpleArena arena(kDefaultBlockSize); + EXPECT_EQ(0u, arena.status().bytes_allocated()); +} + +TEST(SpdySimpleArenaTest, Memdup) { + SpdySimpleArena arena(kDefaultBlockSize); + const size_t length = strlen(kTestString); + char* c = arena.Memdup(kTestString, length); + EXPECT_NE(nullptr, c); + EXPECT_NE(c, kTestString); + EXPECT_EQ(absl::string_view(c, length), kTestString); +} + +TEST(SpdySimpleArenaTest, MemdupLargeString) { + SpdySimpleArena arena(10 /* block size */); + const size_t length = strlen(kTestString); + char* c = arena.Memdup(kTestString, length); + EXPECT_NE(nullptr, c); + EXPECT_NE(c, kTestString); + EXPECT_EQ(absl::string_view(c, length), kTestString); +} + +TEST(SpdySimpleArenaTest, MultipleBlocks) { + SpdySimpleArena arena(40 /* block size */); + std::vector strings = { + "One decently long string.", "Another string.", + "A third string that will surely go in a different block."}; + std::vector copies; + for (const std::string& s : strings) { + absl::string_view sp(arena.Memdup(s.data(), s.size()), s.size()); + copies.push_back(sp); + } + EXPECT_EQ(strings.size(), copies.size()); + for (size_t i = 0; i < strings.size(); ++i) { + EXPECT_EQ(copies[i], strings[i]); + } +} + +TEST(SpdySimpleArenaTest, UseAfterReset) { + SpdySimpleArena arena(kDefaultBlockSize); + const size_t length = strlen(kTestString); + char* c = arena.Memdup(kTestString, length); + arena.Reset(); + c = arena.Memdup(kTestString, length); + EXPECT_NE(nullptr, c); + EXPECT_NE(c, kTestString); + EXPECT_EQ(absl::string_view(c, length), kTestString); +} + +TEST(SpdySimpleArenaTest, Free) { + SpdySimpleArena arena(kDefaultBlockSize); + const size_t length = strlen(kTestString); + // Freeing memory not owned by the arena should be a no-op, and freeing + // before any allocations from the arena should be a no-op. + arena.Free(const_cast(kTestString), length); + char* c1 = arena.Memdup("Foo", 3); + char* c2 = arena.Memdup(kTestString, length); + arena.Free(const_cast(kTestString), length); + char* c3 = arena.Memdup("Bar", 3); + char* c4 = arena.Memdup(kTestString, length); + EXPECT_NE(c1, c2); + EXPECT_NE(c1, c3); + EXPECT_NE(c1, c4); + EXPECT_NE(c2, c3); + EXPECT_NE(c2, c4); + EXPECT_NE(c3, c4); + // Freeing c4 should succeed, since it was the most recent allocation. + arena.Free(c4, length); + // Freeing c2 should be a no-op. + arena.Free(c2, length); + // c5 should reuse memory that was previously used by c4. + char* c5 = arena.Memdup("Baz", 3); + EXPECT_EQ(c4, c5); +} + +TEST(SpdySimpleArenaTest, Alloc) { + SpdySimpleArena arena(kDefaultBlockSize); + const size_t length = strlen(kTestString); + char* c1 = arena.Alloc(length); + char* c2 = arena.Alloc(2 * length); + char* c3 = arena.Alloc(3 * length); + char* c4 = arena.Memdup(kTestString, length); + EXPECT_EQ(c1 + length, c2); + EXPECT_EQ(c2 + 2 * length, c3); + EXPECT_EQ(c3 + 3 * length, c4); + EXPECT_EQ(absl::string_view(c4, length), kTestString); +} + +TEST(SpdySimpleArenaTest, Realloc) { + SpdySimpleArena arena(kDefaultBlockSize); + const size_t length = strlen(kTestString); + // Simple realloc that fits in the block. + char* c1 = arena.Memdup(kTestString, length); + char* c2 = arena.Realloc(c1, length, 2 * length); + EXPECT_TRUE(c1); + EXPECT_EQ(c1, c2); + EXPECT_EQ(absl::string_view(c1, length), kTestString); + // Multiple reallocs. + char* c3 = arena.Memdup(kTestString, length); + EXPECT_EQ(c2 + 2 * length, c3); + EXPECT_EQ(absl::string_view(c3, length), kTestString); + char* c4 = arena.Realloc(c3, length, 2 * length); + EXPECT_EQ(c3, c4); + EXPECT_EQ(absl::string_view(c4, length), kTestString); + char* c5 = arena.Realloc(c4, 2 * length, 3 * length); + EXPECT_EQ(c4, c5); + EXPECT_EQ(absl::string_view(c5, length), kTestString); + char* c6 = arena.Memdup(kTestString, length); + EXPECT_EQ(c5 + 3 * length, c6); + EXPECT_EQ(absl::string_view(c6, length), kTestString); + // Realloc that does not fit in the remainder of the first block. + char* c7 = arena.Realloc(c6, length, kDefaultBlockSize); + EXPECT_EQ(absl::string_view(c7, length), kTestString); + arena.Free(c7, kDefaultBlockSize); + char* c8 = arena.Memdup(kTestString, length); + EXPECT_NE(c6, c7); + EXPECT_EQ(c7, c8); + EXPECT_EQ(absl::string_view(c8, length), kTestString); +} + +} // namespace +} // namespace spdy diff --git a/quiche/spdy/core/zero_copy_output_buffer.h b/quiche/spdy/core/zero_copy_output_buffer.h new file mode 100644 index 000000000000..2ab92ed3e18d --- /dev/null +++ b/quiche/spdy/core/zero_copy_output_buffer.h @@ -0,0 +1,32 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_CORE_ZERO_COPY_OUTPUT_BUFFER_H_ +#define QUICHE_SPDY_CORE_ZERO_COPY_OUTPUT_BUFFER_H_ + +#include + +#include "quiche/common/platform/api/quiche_export.h" + +namespace spdy { + +class QUICHE_EXPORT ZeroCopyOutputBuffer { + public: + virtual ~ZeroCopyOutputBuffer() {} + + // Returns the next available segment of memory to write. Will always return + // the same segment until AdvanceWritePtr is called. + virtual void Next(char** data, int* size) = 0; + + // After writing to a buffer returned from Next(), the caller should call + // this method to indicate how many bytes were written. + virtual void AdvanceWritePtr(int64_t count) = 0; + + // Returns the available capacity of the buffer. + virtual uint64_t BytesFree() const = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_ZERO_COPY_OUTPUT_BUFFER_H_ diff --git a/quiche/spdy/test_tools/mock_spdy_framer_visitor.cc b/quiche/spdy/test_tools/mock_spdy_framer_visitor.cc new file mode 100644 index 000000000000..079ff3775f99 --- /dev/null +++ b/quiche/spdy/test_tools/mock_spdy_framer_visitor.cc @@ -0,0 +1,17 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/test_tools/mock_spdy_framer_visitor.h" + +namespace spdy { + +namespace test { + +MockSpdyFramerVisitor::MockSpdyFramerVisitor() { DelegateHeaderHandling(); } + +MockSpdyFramerVisitor::~MockSpdyFramerVisitor() = default; + +} // namespace test + +} // namespace spdy diff --git a/quiche/spdy/test_tools/mock_spdy_framer_visitor.h b/quiche/spdy/test_tools/mock_spdy_framer_visitor.h new file mode 100644 index 000000000000..93a5ca7221b0 --- /dev/null +++ b/quiche/spdy/test_tools/mock_spdy_framer_visitor.h @@ -0,0 +1,125 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_TEST_TOOLS_MOCK_SPDY_FRAMER_VISITOR_H_ +#define QUICHE_SPDY_TEST_TOOLS_MOCK_SPDY_FRAMER_VISITOR_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/spdy/core/http2_frame_decoder_adapter.h" +#include "quiche/spdy/core/recording_headers_handler.h" +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +namespace spdy { + +namespace test { + +class QUICHE_NO_EXPORT MockSpdyFramerVisitor + : public SpdyFramerVisitorInterface { + public: + MockSpdyFramerVisitor(); + ~MockSpdyFramerVisitor() override; + + MOCK_METHOD(void, OnError, + (http2::Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error), + (override)); + MOCK_METHOD(void, OnCommonHeader, + (SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags), + (override)); + MOCK_METHOD(void, OnDataFrameHeader, + (SpdyStreamId stream_id, size_t length, bool fin), (override)); + MOCK_METHOD(void, OnStreamFrameData, + (SpdyStreamId stream_id, const char* data, size_t len), + (override)); + MOCK_METHOD(void, OnStreamEnd, (SpdyStreamId stream_id), (override)); + MOCK_METHOD(void, OnStreamPadLength, (SpdyStreamId stream_id, size_t value), + (override)); + MOCK_METHOD(void, OnStreamPadding, (SpdyStreamId stream_id, size_t len), + (override)); + MOCK_METHOD(SpdyHeadersHandlerInterface*, OnHeaderFrameStart, + (SpdyStreamId stream_id), (override)); + MOCK_METHOD(void, OnHeaderFrameEnd, (SpdyStreamId stream_id), (override)); + MOCK_METHOD(void, OnRstStream, + (SpdyStreamId stream_id, SpdyErrorCode error_code), (override)); + MOCK_METHOD(void, OnSettings, (), (override)); + MOCK_METHOD(void, OnSetting, (SpdySettingsId id, uint32_t value), (override)); + MOCK_METHOD(void, OnPing, (SpdyPingId unique_id, bool is_ack), (override)); + MOCK_METHOD(void, OnSettingsEnd, (), (override)); + MOCK_METHOD(void, OnSettingsAck, (), (override)); + MOCK_METHOD(void, OnGoAway, + (SpdyStreamId last_accepted_stream_id, SpdyErrorCode error_code), + (override)); + MOCK_METHOD(bool, OnGoAwayFrameData, (const char* goaway_data, size_t len), + (override)); + MOCK_METHOD(void, OnHeaders, + (SpdyStreamId stream_id, size_t payload_length, bool has_priority, + int weight, SpdyStreamId parent_stream_id, bool exclusive, + bool fin, bool end), + (override)); + MOCK_METHOD(void, OnWindowUpdate, + (SpdyStreamId stream_id, int delta_window_size), (override)); + MOCK_METHOD(void, OnPushPromise, + (SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + bool end), + (override)); + MOCK_METHOD(void, OnContinuation, + (SpdyStreamId stream_id, size_t payload_length, bool end), + (override)); + MOCK_METHOD( + void, OnAltSvc, + (SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector), + (override)); + MOCK_METHOD(void, OnPriority, + (SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive), + (override)); + MOCK_METHOD(void, OnPriorityUpdate, + (SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value), + (override)); + MOCK_METHOD(bool, OnUnknownFrame, + (SpdyStreamId stream_id, uint8_t frame_type), (override)); + MOCK_METHOD(void, OnUnknownFrameStart, + (SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags), + (override)); + MOCK_METHOD(void, OnUnknownFramePayload, + (SpdyStreamId stream_id, absl::string_view payload), (override)); + + void DelegateHeaderHandling() { + ON_CALL(*this, OnHeaderFrameStart(testing::_)) + .WillByDefault(testing::Invoke( + this, &MockSpdyFramerVisitor::ReturnTestHeadersHandler)); + ON_CALL(*this, OnHeaderFrameEnd(testing::_)) + .WillByDefault(testing::Invoke( + this, &MockSpdyFramerVisitor::ResetTestHeadersHandler)); + } + + SpdyHeadersHandlerInterface* ReturnTestHeadersHandler( + SpdyStreamId /* stream_id */) { + if (headers_handler_ == nullptr) { + headers_handler_ = std::make_unique(); + } + return headers_handler_.get(); + } + + void ResetTestHeadersHandler(SpdyStreamId /* stream_id */) { + headers_handler_.reset(); + } + + std::unique_ptr headers_handler_; +}; + +} // namespace test + +} // namespace spdy + +#endif // QUICHE_SPDY_TEST_TOOLS_MOCK_SPDY_FRAMER_VISITOR_H_ diff --git a/quiche/spdy/test_tools/spdy_test_utils.cc b/quiche/spdy/test_tools/spdy_test_utils.cc new file mode 100644 index 000000000000..fc5962bd9e6e --- /dev/null +++ b/quiche/spdy/test_tools/spdy_test_utils.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "quiche/spdy/test_tools/spdy_test_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "quiche/common/platform/api/quiche_logging.h" +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/common/quiche_endian.h" + +namespace spdy { +namespace test { + +std::string HexDumpWithMarks(const unsigned char* data, int length, + const bool* marks, int mark_length) { + static const char kHexChars[] = "0123456789abcdef"; + static const int kColumns = 4; + + const int kSizeLimit = 1024; + if (length > kSizeLimit || mark_length > kSizeLimit) { + QUICHE_LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes."; + length = std::min(length, kSizeLimit); + mark_length = std::min(mark_length, kSizeLimit); + } + + std::string hex; + for (const unsigned char* row = data; length > 0; + row += kColumns, length -= kColumns) { + for (const unsigned char* p = row; p < row + 4; ++p) { + if (p < row + length) { + const bool mark = + (marks && (p - data) < mark_length && marks[p - data]); + hex += mark ? '*' : ' '; + hex += kHexChars[(*p & 0xf0) >> 4]; + hex += kHexChars[*p & 0x0f]; + hex += mark ? '*' : ' '; + } else { + hex += " "; + } + } + hex = hex + " "; + + for (const unsigned char* p = row; p < row + 4 && p < row + length; ++p) { + hex += (*p >= 0x20 && *p <= 0x7f) ? (*p) : '.'; + } + + hex = hex + '\n'; + } + return hex; +} + +void CompareCharArraysWithHexError(const std::string& description, + const unsigned char* actual, + const int actual_len, + const unsigned char* expected, + const int expected_len) { + const int min_len = std::min(actual_len, expected_len); + const int max_len = std::max(actual_len, expected_len); + std::unique_ptr marks(new bool[max_len]); + bool identical = (actual_len == expected_len); + for (int i = 0; i < min_len; ++i) { + if (actual[i] != expected[i]) { + marks[i] = true; + identical = false; + } else { + marks[i] = false; + } + } + for (int i = min_len; i < max_len; ++i) { + marks[i] = true; + } + if (identical) return; + ADD_FAILURE() << "Description:\n" + << description << "\n\nExpected:\n" + << HexDumpWithMarks(expected, expected_len, marks.get(), + max_len) + << "\nActual:\n" + << HexDumpWithMarks(actual, actual_len, marks.get(), max_len); +} + +void SetFrameFlags(SpdySerializedFrame* frame, uint8_t flags) { + frame->data()[4] = flags; +} + +void SetFrameLength(SpdySerializedFrame* frame, size_t length) { + QUICHE_CHECK_GT(1u << 14, length); + { + int32_t wire_length = quiche::QuicheEndian::HostToNet32(length); + memcpy(frame->data(), reinterpret_cast(&wire_length) + 1, 3); + } +} + +} // namespace test +} // namespace spdy diff --git a/quiche/spdy/test_tools/spdy_test_utils.h b/quiche/spdy/test_tools/spdy_test_utils.h new file mode 100644 index 000000000000..366f2490adbb --- /dev/null +++ b/quiche/spdy/test_tools/spdy_test_utils.h @@ -0,0 +1,41 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_SPDY_TEST_TOOLS_SPDY_TEST_UTILS_H_ +#define QUICHE_SPDY_TEST_TOOLS_SPDY_TEST_UTILS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quiche/spdy/core/http2_header_block.h" +#include "quiche/spdy/core/spdy_protocol.h" + +namespace spdy { + +inline bool operator==(absl::string_view x, + const Http2HeaderBlock::ValueProxy& y) { + return y.operator==(x); +} + +namespace test { + +std::string HexDumpWithMarks(const unsigned char* data, int length, + const bool* marks, int mark_length); + +void CompareCharArraysWithHexError(const std::string& description, + const unsigned char* actual, + const int actual_len, + const unsigned char* expected, + const int expected_len); + +void SetFrameFlags(SpdySerializedFrame* frame, uint8_t flags); + +void SetFrameLength(SpdySerializedFrame* frame, size_t length); + +} // namespace test +} // namespace spdy + +#endif // QUICHE_SPDY_TEST_TOOLS_SPDY_TEST_UTILS_H_ diff --git a/quiche/web_transport/test_tools/mock_web_transport.h b/quiche/web_transport/test_tools/mock_web_transport.h new file mode 100644 index 000000000000..9fdecfcc5eff --- /dev/null +++ b/quiche/web_transport/test_tools/mock_web_transport.h @@ -0,0 +1,80 @@ +// Copyright 2023 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Pre-defined mocks for the APIs in web_transport.h. + +#ifndef QUICHE_WEB_TRANSPORT_TEST_TOOLS_MOCK_WEB_TRANSPORT_H_ +#define QUICHE_WEB_TRANSPORT_TEST_TOOLS_MOCK_WEB_TRANSPORT_H_ + +#include "quiche/common/platform/api/quiche_test.h" +#include "quiche/web_transport/web_transport.h" + +namespace webtransport { +namespace test { + +class QUICHE_NO_EXPORT MockStreamVisitor : public StreamVisitor { + MOCK_METHOD(void, OnCanRead, (), (override)); + MOCK_METHOD(void, OnCanWrite, (), (override)); + MOCK_METHOD(void, OnResetStreamReceived, (StreamErrorCode), (override)); + MOCK_METHOD(void, OnStopSendingReceived, (StreamErrorCode), (override)); + MOCK_METHOD(void, OnWriteSideInDataRecvdState, (), (override)); +}; + +class QUICHE_NO_EXPORT MockStream : public Stream { + MOCK_METHOD(ReadResult, Read, (absl::Span buffer), (override)); + MOCK_METHOD(ReadResult, Read, (std::string * output), (override)); + MOCK_METHOD(absl::Status, Writev, + (absl::Span data, + const quiche::StreamWriteOptions& options), + (override)); + MOCK_METHOD(bool, CanWrite, (), (const, override)); + MOCK_METHOD(void, AbruptlyTerminate, (absl::Status), (override)); + MOCK_METHOD(size_t, ReadableBytes, (), (const, override)); + MOCK_METHOD(StreamId, GetStreamId, (), (const, override)); + MOCK_METHOD(void, ResetWithUserCode, (StreamErrorCode error), (override)); + MOCK_METHOD(void, SendStopSending, (StreamErrorCode error), (override)); + MOCK_METHOD(void, ResetDueToInternalError, (), (override)); + MOCK_METHOD(void, MaybeResetDueToStreamObjectGone, (), (override)); + MOCK_METHOD(StreamVisitor*, visitor, (), (override)); + MOCK_METHOD(void, SetVisitor, (std::unique_ptr visitor), + (override)); +}; + +class QUICHE_NO_EXPORT MockSessionVisitor : public SessionVisitor { + MOCK_METHOD(void, OnSessionReady, (const spdy::Http2HeaderBlock& headers), + (override)); + MOCK_METHOD(void, OnSessionClosed, + (SessionErrorCode error_code, const std::string& error_message), + (override)); + MOCK_METHOD(void, OnIncomingBidirectionalStreamAvailable, (), (override)); + MOCK_METHOD(void, OnIncomingUnidirectionalStreamAvailable, (), (override)); + MOCK_METHOD(void, OnDatagramReceived, (absl::string_view datagram), + (override)); + MOCK_METHOD(void, OnCanCreateNewOutgoingBidirectionalStream, (), (override)); + MOCK_METHOD(void, OnCanCreateNewOutgoingUnidirectionalStream, (), (override)); +}; + +class QUICHE_NO_EXPORT MockSession : public Session { + public: + MOCK_METHOD(void, CloseSession, + (SessionErrorCode error_code, absl::string_view error_message), + (override)); + MOCK_METHOD(Stream*, AcceptIncomingBidirectionalStream, (), (override)); + MOCK_METHOD(Stream*, AcceptIncomingUnidirectionalStream, (), (override)); + MOCK_METHOD(bool, CanOpenNextOutgoingBidirectionalStream, (), (override)); + MOCK_METHOD(bool, CanOpenNextOutgoingUnidirectionalStream, (), (override)); + MOCK_METHOD(Stream*, OpenOutgoingBidirectionalStream, (), (override)); + MOCK_METHOD(Stream*, OpenOutgoingUnidirectionalStream, (), (override)); + MOCK_METHOD(Stream*, GetStreamById, (StreamId), (override)); + MOCK_METHOD(DatagramStatus, SendOrQueueDatagram, (absl::string_view datagram), + (override)); + MOCK_METHOD(uint64_t, GetMaxDatagramSize, (), (const, override)); + MOCK_METHOD(void, SetDatagramMaxTimeInQueue, + (absl::Duration max_time_in_queue), (override)); +}; + +} // namespace test +} // namespace webtransport + +#endif // QUICHE_WEB_TRANSPORT_TEST_TOOLS_MOCK_WEB_TRANSPORT_H_ diff --git a/quiche/web_transport/web_transport.h b/quiche/web_transport/web_transport.h new file mode 100644 index 000000000000..7ea008527fc5 --- /dev/null +++ b/quiche/web_transport/web_transport.h @@ -0,0 +1,221 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This header contains interfaces that abstract away different backing +// protocols for WebTransport. + +#ifndef QUICHE_WEB_TRANSPORT_WEB_TRANSPORT_H_ +#define QUICHE_WEB_TRANSPORT_WEB_TRANSPORT_H_ + +#include +#include +#include + +// The dependencies of this API should be kept minimal and independent of +// specific transport implementations. +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "quiche/common/platform/api/quiche_export.h" +#include "quiche/common/quiche_stream.h" +#include "quiche/spdy/core/http2_header_block.h" + +namespace webtransport { + +// A numeric ID uniquely identifying a WebTransport stream. Note that by design, +// those IDs are not available in the Web API, and the IDs do not necessarily +// match between client and server perspective, since there may be a proxy +// between them. +using StreamId = uint32_t; +// Application-specific error code used for resetting either the read or the +// write half of the stream. +using StreamErrorCode = uint8_t; +// Application-specific error code used for closing a WebTransport session. +using SessionErrorCode = uint32_t; + +// An outcome of a datagram send call. +enum class DatagramStatusCode { + // Datagram has been successfully sent or placed into the datagram queue. + kSuccess, + // Datagram has not been sent since the underlying QUIC connection is blocked + // by the congestion control. Note that this can only happen if the queue is + // full. + kBlocked, + // Datagram has not been sent since it is too large to fit into a single + // UDP packet. + kTooBig, + // An unspecified internal error. + kInternalError, +}; + +// An outcome of a datagram send call, in both enum and human-readable form. +struct QUICHE_EXPORT DatagramStatus { + explicit DatagramStatus(DatagramStatusCode code, std::string error_message) + : code(code), error_message(std::move(error_message)) {} + + DatagramStatusCode code; + std::string error_message; +}; + +enum class StreamType { + kUnidirectional, + kBidirectional, +}; + +// The stream visitor is an application-provided object that gets notified about +// events related to a WebTransport stream. The visitor object is owned by the +// stream itself, meaning that if the stream is ever fully closed, the visitor +// will be garbage-collected. +class QUICHE_EXPORT StreamVisitor : public quiche::WriteStreamVisitor { + public: + virtual ~StreamVisitor() {} + + // Called whenever the stream has readable data available. + virtual void OnCanRead() = 0; + + // Called when RESET_STREAM is received for the stream. + virtual void OnResetStreamReceived(StreamErrorCode error) = 0; + // Called when STOP_SENDING is received for the stream. + virtual void OnStopSendingReceived(StreamErrorCode error) = 0; + // Called when the write side of the stream is closed and all of the data sent + // has been acknowledged ("Data Recvd" state of RFC 9000). Primarily used by + // the state machine of the Web API. + virtual void OnWriteSideInDataRecvdState() = 0; +}; + +// A stream (either bidirectional or unidirectional) that is contained within a +// WebTransport session. +class QUICHE_EXPORT Stream : public quiche::WriteStream { + public: + struct QUICHE_EXPORT ReadResult { + // Number of bytes actually read. + size_t bytes_read; + // Whether the FIN has been received; if true, no further data will arrive + // on the stream, and the stream object can be soon potentially garbage + // collected. + bool fin; + }; + + virtual ~Stream() {} + + // Reads at most |buffer.size()| bytes into |buffer|. + [[nodiscard]] virtual ReadResult Read(absl::Span buffer) = 0; + // Reads all available data and appends it to the end of |output|. + [[nodiscard]] virtual ReadResult Read(std::string* output) = 0; + + // Indicates the number of bytes that can be read from the stream. + virtual size_t ReadableBytes() const = 0; + + // An ID that is unique within the session. Those are not exposed to the user + // via the web API, but can be used internally for bookkeeping and + // diagnostics. + virtual StreamId GetStreamId() const = 0; + + // Resets the read or the write side of the stream with the specified error + // code. + virtual void ResetWithUserCode(StreamErrorCode error) = 0; + virtual void SendStopSending(StreamErrorCode error) = 0; + + // A general-purpose stream reset method that may be used when a specific + // error code is not available. + virtual void ResetDueToInternalError() = 0; + // If the stream has not been already reset, reset the stream. This is + // primarily used in the JavaScript API when the stream object has been + // garbage collected. + virtual void MaybeResetDueToStreamObjectGone() = 0; + + virtual StreamVisitor* visitor() = 0; + virtual void SetVisitor(std::unique_ptr visitor) = 0; +}; + +// Visitor that gets notified about events related to a WebTransport session. +class QUICHE_EXPORT SessionVisitor { + public: + virtual ~SessionVisitor() {} + + // Notifies the visitor when the session is ready to exchange application + // data. + virtual void OnSessionReady(const spdy::Http2HeaderBlock& headers) = 0; + + // Notifies the visitor when the session has been closed. + virtual void OnSessionClosed(SessionErrorCode error_code, + const std::string& error_message) = 0; + + // Notifies the visitor when a new stream has been received. The stream in + // question can be retrieved using AcceptIncomingBidirectionalStream() or + // AcceptIncomingUnidirectionalStream(). + virtual void OnIncomingBidirectionalStreamAvailable() = 0; + virtual void OnIncomingUnidirectionalStreamAvailable() = 0; + + // Notifies the visitor when a new datagram has been received. + virtual void OnDatagramReceived(absl::string_view datagram) = 0; + + // Notifies the visitor that a new outgoing stream can now be created. + virtual void OnCanCreateNewOutgoingBidirectionalStream() = 0; + virtual void OnCanCreateNewOutgoingUnidirectionalStream() = 0; +}; + +// An abstract interface for a WebTransport session. +// +// *** AN IMPORTANT NOTE ABOUT STREAM LIFETIMES *** +// Stream objects are managed internally by the underlying QUIC stack, and can +// go away at any time due to the peer resetting the stream. Because of that, +// any pointers to the stream objects returned by this class MUST NEVER be +// retained long-term, except inside the stream visitor (the stream visitor is +// owned by the stream object). If you need to store a reference to a stream, +// consider one of the two following options: +// (1) store a stream ID, +// (2) store a weak pointer to the stream visitor, and then access the stream +// via the said visitor (the visitor is guaranteed to be alive as long as +// the stream is alive). +class QUICHE_EXPORT Session { + public: + virtual ~Session() {} + + // Closes the WebTransport session in question with the specified |error_code| + // and |error_message|. + virtual void CloseSession(SessionErrorCode error_code, + absl::string_view error_message) = 0; + + // Return the earliest incoming stream that has been received by the session + // but has not been accepted. Returns nullptr if there are no incoming + // streams. See the class note regarding the lifetime of the returned stream + // object. + virtual Stream* AcceptIncomingBidirectionalStream() = 0; + virtual Stream* AcceptIncomingUnidirectionalStream() = 0; + + // Returns true if flow control allows opening a new stream. + // + // IMPORTANT: See the class note regarding the lifetime of the returned stream + // object. + virtual bool CanOpenNextOutgoingBidirectionalStream() = 0; + virtual bool CanOpenNextOutgoingUnidirectionalStream() = 0; + + // Opens a new WebTransport stream, or returns nullptr if that is not possible + // due to flow control. See the class note regarding the lifetime of the + // returned stream object. + // + // IMPORTANT: See the class note regarding the lifetime of the returned stream + // object. + virtual Stream* OpenOutgoingBidirectionalStream() = 0; + virtual Stream* OpenOutgoingUnidirectionalStream() = 0; + + // Returns the WebTransport stream with the corresponding ID. + // + // IMPORTANT: See the class note regarding the lifetime of the returned stream + // object. + virtual Stream* GetStreamById(StreamId id) = 0; + + virtual DatagramStatus SendOrQueueDatagram(absl::string_view datagram) = 0; + // Returns a conservative estimate of the largest datagram size that the + // session would be able to send. + virtual uint64_t GetMaxDatagramSize() const = 0; + // Sets the largest duration that a datagram can spend in the queue before + // being silently dropped. + virtual void SetDatagramMaxTimeInQueue(absl::Duration max_time_in_queue) = 0; +}; + +} // namespace webtransport + +#endif // QUICHE_WEB_TRANSPORT_WEB_TRANSPORT_H_